Source code for easy_torch.torchvision_utils

# Import necessary libraries
import torch  # Import the PyTorch library for deep learning
import torchvision  # Import torchvision for pre-trained models
from types import MethodType  # Import MethodType for method modification
from .model import BaseNN  # Import the BaseNN class from the model module
from copy import deepcopy  # Import deepcopy for model copying

# Function to get a pre-trained TorchVision model
def get_torchvision_model(name, torchvision_params={}, in_channels=None, out_features=None, out_as_image=False, keep_image_size=False, **kwargs):
    """
    Build a TorchVision model by name and optionally adapt its input/output layers.

    Parameters:
    - name: Name of the TorchVision model, as accepted by torchvision.models.get_model.
    - torchvision_params: Keyword arguments forwarded to torchvision.models.get_model.
    - in_channels: If not None, the first convolution is replaced to accept this many channels.
    - out_features: Number of output features/channels for the replaced output layer (optional).
    - out_as_image: If truthy, the classification head is dropped and the last convolution is replaced.
    - keep_image_size: Only consulted when out_as_image is truthy; rewrites paddings/strides.
    - kwargs: Accepted but not used by this function.

    Returns:
    - module: The (possibly modified) TorchVision model.
    """
    net = torchvision.models.get_model(name, **torchvision_params)

    if in_channels is not None:
        change_in_channels(name, net, in_channels)

    # Classifier-style output: at most the final layer is resized.
    if not out_as_image:
        if out_features is not None:
            change_fc_out_features(name, net, out_features)
        return net

    # Image-style output: strip the head, then optionally preserve spatial size.
    change_conv_out_features(name, net, out_features)
    if keep_image_size:
        change_all_paddings(name, net)
    return net
# Function to split and get a TorchVision model by name
def get_torchvision_model_split_name(name):
    """
    Resolve a dotted name relative to torchvision.models.

    Parameters:
    - name: Dotted attribute path (e.g. "a.b"); each part is looked up with getattr in turn.

    Returns:
    - app: The object found at the end of the attribute path.
    """
    resolved = torchvision.models
    for part in name.split("."):
        resolved = getattr(resolved, part)
    return resolved
# Function to change the number of input channels in the model
def change_in_channels(name, module, in_channels):
    """
    Replace the model's first convolution so that it accepts `in_channels` inputs.

    The new layer has the same type, out_channels, kernel_size, stride and
    padding as the old one and keeps a bias only if the old one had a bias.
    It is freshly initialised: the old layer's weights are not copied.

    Parameters:
    - name: Name of the model; selects where the first convolution lives.
    - module: The model to be modified (in place).
    - in_channels: Number of input channels for the new convolution.

    Returns:
    - None

    Raises:
    - NotImplementedError: if `name` is not one of the supported families.
    """
    # "deeplab" must be tested before "resnet": names such as
    # "deeplabv3_resnet50" contain both substrings, and the segmentation model
    # has no top-level conv1 attribute.
    if "deeplab" in name:
        backbone = module.backbone
        if hasattr(backbone, "conv1"):
            # ResNet-backed variant: the stem convolution is backbone.conv1.
            module_section = backbone
            attr_name = "conv1"
        else:
            # Otherwise the stem is the first entry of the first backbone block.
            module_section = getattr(backbone, "0")
            attr_name = "0"
    elif "resnet" in name:
        module_section = module
        attr_name = "conv1"
    elif "squeezenet" in name:
        module_section = module.features
        attr_name = "0"
    elif name in ["mc3_18"]:
        module_section = module.stem
        attr_name = "0"
    else:
        raise NotImplementedError("Model name", name)

    current_conv = getattr(module_section, attr_name)
    new_conv = type(current_conv)(
        in_channels=in_channels,
        out_channels=current_conv.out_channels,
        kernel_size=current_conv.kernel_size,
        stride=current_conv.stride,
        padding=current_conv.padding,
        bias=current_conv.bias is not None,
    )
    setattr(module_section, attr_name, new_conv)
# Function to change the output features of convolutional layers in the model
def change_conv_out_features(name, module, out_features=None):
    """
    Turn the model into a feature-map producer and replace its last convolution.

    Depending on the model family this drops the classification head, patches
    the forward pass to stop after the convolutional stages, and replaces one
    convolution with a freshly initialised layer of the same type whose
    out_channels is `out_features` (or the old out_channels when
    `out_features` is None). Weights of the replaced layer are not copied.

    NOTE(review): only the selected convolution is rebuilt; layers that consume
    its output (normalisation, residual additions, concatenations) are left
    untouched, so a changed channel count may not be compatible with them.

    Parameters:
    - name: Name of the model; selects which layers are touched.
    - module: The model to be modified (in place).
    - out_features: Number of output channels for the replaced convolution (optional).

    Returns:
    - None

    Raises:
    - NotImplementedError: if `name` is not one of the supported families.
    """
    # "deeplab" must be tested before "resnet": "deeplabv3_resnet50" contains
    # both substrings, and the segmentation model has no avgpool/fc to delete.
    if "deeplab" in name:
        module_section = module.classifier
        attr_name = "4"
    elif "resnet" in name:
        # Drop the classification head and stop the forward pass after layer4.
        module._forward_impl = MethodType(resnet_forward_impl, module)
        del module.avgpool
        del module.fc
        module_section = getattr(module.layer4, "1")
        attr_name = "conv2"
    elif "squeezenet" in name:
        # A bound method (rather than a lambda closing over `module`) keeps
        # copies of the model forwarding through their own `features`.
        module.forward = MethodType(_squeezenet_features_forward, module)
        del module.classifier
        module_section = getattr(module.features, "12")
        attr_name = "expand3x3"
    elif name in ["mc3_18"]:
        # Only the head is dropped for this model. Replacing the last
        # convolution is disabled because the residual connection would then
        # have a different size.
        module.forward = MethodType(video_resnet_forward, module)
        del module.avgpool
        del module.fc
        return
    else:
        raise NotImplementedError("Model name", name)

    current_conv = getattr(module_section, attr_name)
    new_out_channels = current_conv.out_channels if out_features is None else out_features
    new_conv = type(current_conv)(
        in_channels=current_conv.in_channels,
        out_channels=new_out_channels,
        kernel_size=current_conv.kernel_size,
        stride=current_conv.stride,
        padding=current_conv.padding,
        bias=current_conv.bias is not None,
    )
    setattr(module_section, attr_name, new_conv)


def _squeezenet_features_forward(self, x):
    """Forward pass that returns the output of `self.features` only."""
    return self.features(x)
# Function to change the output features of fully connected layers in the model
def change_fc_out_features(name, module, out_features):
    """
    Replace the model's final classification layer so it emits `out_features` values.

    The replacement has the same type and input size as the old layer and keeps
    a bias only if the old one had a bias. It is freshly initialised: the old
    layer's weights are not copied.

    Parameters:
    - name: Name of the model; selects which layer is replaced.
    - module: The model to be modified (in place).
    - out_features: Number of output features.

    Returns:
    - None

    Raises:
    - NotImplementedError: if `name` is not one of the supported families.
    """
    if "resnet" in name:
        # The head is a linear-style layer stored as `module.fc`.
        current_fc = module.fc
        module.fc = type(current_fc)(
            in_features=current_fc.in_features,
            out_features=out_features,
            bias=current_fc.bias is not None,
        )
    elif "squeezenet" in name:
        # The head is a convolution stored at index 1 of `module.classifier`.
        classifier = module.classifier
        current_fc = getattr(classifier, "1")
        setattr(classifier, "1", type(current_fc)(
            in_channels=current_fc.in_channels,
            out_channels=out_features,
            kernel_size=current_fc.kernel_size,
            stride=current_fc.stride,
            padding=current_fc.padding,
            bias=current_fc.bias is not None,
        ))
    else:
        raise NotImplementedError("Model name", name)
# Function to change padding in convolutional layers to "same"
def change_all_paddings(name, module):
    """
    Make the layers at `module.features[1..12]` preserve the spatial size.

    Conv2d layers get padding "same" and stride 1. MaxPool2d layers get
    stride 1 and an integer padding of kernel_size // 2, because pooling layers
    do not accept the string "same" as padding. For pooling this keeps the
    size exactly for odd kernel sizes with dilation 1.

    Only direct children of `module.features` are inspected; nested modules
    are not traversed.

    Parameters:
    - name: Name of the model.
    - module: The model to be modified (in place).

    Returns:
    - None

    Raises:
    - NotImplementedError: if `name` does not contain "squeezenet".
    """
    if "squeezenet" not in name:
        raise NotImplementedError("Model name", name)

    for i in range(1, 13):
        layer = getattr(module.features, str(i))
        if isinstance(layer, torch.nn.Conv2d):
            layer.padding = "same"
            layer.stride = 1  # padding "same" requires stride 1
        elif isinstance(layer, torch.nn.MaxPool2d):
            layer.padding = _same_pool_padding(layer.kernel_size)
            layer.stride = 1


def _same_pool_padding(kernel_size):
    """Return the integer padding (kernel_size // 2) for an int or per-dimension kernel size."""
    if isinstance(kernel_size, int):
        return kernel_size // 2
    return tuple(k // 2 for k in kernel_size)
# Custom forward method for ResNet
[docs] def resnet_forward_impl(self, x: torch.Tensor) -> torch.Tensor: """ Custom forward method for ResNet. Parameters: - self: The ResNet model. - x: Input tensor. Returns: - x: Output tensor. """ x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) # x = self.avgpool(x) # x = torch.flatten(x, 1) # x = self.fc(x) return x
# Custom forward method for VideoResNet
def video_resnet_forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass for a VideoResNet that stops after layer4.

    The average pooling, flattening and fully connected steps are not applied,
    so the result is whatever layer4 returns.

    Parameters:
    - self: The VideoResNet model.
    - x: Input tensor.

    Returns:
    - x: Output of layer4.
    """
    features = self.stem(x)
    for stage_name in ("layer1", "layer2", "layer3", "layer4"):
        stage = getattr(self, stage_name)
        features = stage(features)
    return features
# Function to load a TorchVision model from a checkpoint
def load_torchvision_model(model_cfg, path):
    """
    Rebuild a TorchVision model from its configuration and load it from a checkpoint.

    The model is created from model_cfg["name"] with model_cfg["torchvision_params"],
    its `fc` layer is replaced by a Linear layer with model_cfg["output_size"]
    outputs, and the result is handed to BaseNN.load_from_checkpoint together
    with every entry of model_cfg as keyword arguments.

    Parameters:
    - model_cfg: Mapping with at least "name", "torchvision_params" and "output_size".
    - path: Path to the checkpoint file.

    Returns:
    - model: The object returned by BaseNN.load_from_checkpoint.
    """
    # Look up the model constructor on torchvision.models and instantiate it.
    constructor = getattr(torchvision.models, model_cfg["name"])
    backbone = constructor(**model_cfg["torchvision_params"])

    # Swap the head for one sized to the configured number of outputs.
    head_inputs = backbone.fc.in_features
    backbone.fc = torch.nn.Linear(head_inputs, model_cfg["output_size"])

    return BaseNN.load_from_checkpoint(path, model=backbone, **model_cfg)
def invert_model(model, example_datum, keep_order=False):
    """
    Build a model made of the inverted counterparts of `model`'s direct children.

    The example datum is pushed through the children one by one so that each
    inverted layer can be sized from the input its original layer receives.
    Nested torch.nn.Sequential children are inverted recursively; every other
    child is delegated to invert_layer, which may also add auxiliary layers
    (e.g. to restore a shape) to the collected list.

    Parameters:
    - model: The model to be inverted.
    - example_datum: An example input for the model; it is detached and cloned, not modified.
    - keep_order: If True, inverted layers keep the original order; otherwise
      (default) they are reversed so that the last layer is undone first.

    Returns:
    - inverted_model: A torch.nn.Sequential holding the inverted layers.
    """
    current_input = example_datum.detach().clone()
    inverted_layers = []

    for _, layer in model.named_children():
        if isinstance(layer, torch.nn.Sequential):
            inverted_layer = invert_model(layer, current_input)
        else:
            inverted_layer, current_input = invert_layer(layer, current_input, inverted_layers)
        inverted_layers.append(inverted_layer)
        # Advance the example so the next child sees a realistic input shape.
        current_input = layer(current_input)

    if not keep_order:
        inverted_layers.reverse()
    inverted_model = torch.nn.Sequential(*inverted_layers)
    return inverted_model
def invert_layer(layer, current_input, inverted_layers=None):
    """
    Create the inverted counterpart of a single layer.

    Parameters:
    - layer: The layer to invert. Supported: Linear, MaxPool2d, AvgPool2d,
      AdaptiveAvgPool2d, Conv2d, BatchNorm2d, ReLU and torchvision's BasicBlock.
    - current_input: Example input that `layer` receives; used to size the inverse.
    - inverted_layers: List that collects auxiliary layers needed to restore
      shapes (Unflatten / Upsample). A new list is used when omitted.

    Returns:
    - inverted_layer: The newly created (untrained) inverse layer.
    - current_input: The example input, flattened if `layer` is a Linear
      layer fed with more than two dimensions, otherwise unchanged.

    Raises:
    - NotImplementedError: if the layer type is not supported.
    """
    # A fresh list per call; a shared default list would accumulate layers
    # across unrelated calls.
    if inverted_layers is None:
        inverted_layers = []

    if isinstance(layer, torch.nn.Linear):
        if len(current_input.shape) > 2:
            # The inverse must restore the shape that is flattened away here.
            inverted_layers.append(torch.nn.Unflatten(1, current_input.shape[1:]))
            current_input = torch.flatten(current_input, 1)
        inverted_layer = torch.nn.Linear(
            in_features=layer.out_features,
            out_features=layer.in_features,
            bias=layer.bias is not None,
        )
    elif isinstance(layer, (torch.nn.MaxPool2d, torch.nn.AvgPool2d, torch.nn.AdaptiveAvgPool2d)):
        # Pooling is undone by upsampling back to the pre-pooling spatial size.
        inverted_layer = torch.nn.Upsample(size=current_input.shape[2:])
    elif isinstance(layer, torch.nn.Conv2d):
        # When the stride does not divide the padded extent evenly, an extra
        # Upsample is collected to bring the size back to the original one.
        rems = [
            (current_input.shape[2 + i] + 2 * layer.padding[i]
             - layer.dilation[i] * (layer.kernel_size[i] - 1) - 1) % layer.stride[i]
            for i in range(2)
        ]
        if rems[0] != 0 or rems[1] != 0:
            inverted_layers.append(torch.nn.Upsample(size=(current_input.shape[2], current_input.shape[3])))
        inverted_layer = _transposed_conv_like(layer, swap_channels=True)
    elif isinstance(layer, torch.nn.BatchNorm2d):
        inverted_layer = torch.nn.BatchNorm2d(
            num_features=current_input.shape[1],
            eps=layer.eps,
            momentum=layer.momentum,
            affine=layer.affine,
            track_running_stats=layer.track_running_stats,
        )
    elif isinstance(layer, torch.nn.ReLU):
        inverted_layer = torch.nn.ReLU()
    elif isinstance(layer, torchvision.models.resnet.BasicBlock):
        inplanes, planes = layer.conv1.in_channels, layer.conv1.out_channels
        # Build a block with swapped channel counts, then make its
        # convolutions transposed ones.
        inverted_layer = torchvision.models.resnet.BasicBlock(
            inplanes=planes,
            planes=inplanes,
            stride=layer.stride,
            downsample=deepcopy(layer.downsample),
            norm_layer=type(layer.bn1),
        )
        # conv1/conv2 already have the swapped channel counts from the constructor.
        inverted_layer.conv1 = _transposed_conv_like(inverted_layer.conv1)
        inverted_layer.conv2 = _transposed_conv_like(inverted_layer.conv2)
        if inverted_layer.downsample is not None:
            # The downsample path is a copy of the original, so its channels
            # still need to be swapped.
            inverted_layer.downsample[0] = _transposed_conv_like(inverted_layer.downsample[0], swap_channels=True)
            inverted_layer.downsample[1], _ = invert_layer(inverted_layer.downsample[1], current_input, inverted_layers)
    else:
        raise NotImplementedError("Layer type", type(layer))

    return inverted_layer, current_input


def _transposed_conv_like(conv, swap_channels=False):
    """Return a ConvTranspose2d with `conv`'s hyper-parameters, optionally swapping in/out channels."""
    in_channels, out_channels = conv.in_channels, conv.out_channels
    if swap_channels:
        in_channels, out_channels = out_channels, in_channels
    return torch.nn.ConvTranspose2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        output_padding=conv.output_padding,
        groups=conv.groups,
        bias=conv.bias is not None,
        dilation=conv.dilation,
        padding_mode=conv.padding_mode,
    )