Easy_Torch package

Submodules

easy_torch.callbacks module

class easy_torch.callbacks.TimeCallback(log_params={})[source]

Bases: Callback

on_epoch_start()[source]
on_epoch_end(split_name)[source]
on_train_epoch_start(trainer, pl_module)[source]

Called when the train epoch begins.

on_train_epoch_end(trainer, pl_module)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

import torch
import lightning as L  # or: import pytorch_lightning as L


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
on_validation_epoch_start(trainer, pl_module)[source]

Called when the val epoch begins.

on_validation_epoch_end(trainer, pl_module)[source]

Called when the val epoch ends.

on_test_epoch_start(trainer, pl_module)[source]

Called when the test epoch begins.

on_test_epoch_end(trainer, pl_module)[source]

Called when the test epoch ends.
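The paired start/end hooks above suggest that TimeCallback measures how long each epoch takes per split. The sketch below illustrates that pattern under the assumption (not confirmed by this page) that a start time is stored when an epoch begins and the elapsed time is computed when it ends. EpochTimer and its elapsed attribute are hypothetical names, and the class is plain Python rather than a Lightning Callback so that it runs without extra dependencies.

```python
import time


class EpochTimer:
    """Hypothetical stand-in for a timing callback (not easy_torch's code)."""

    def __init__(self):
        self._start = {}   # split name -> start timestamp
        self.elapsed = {}  # split name -> seconds taken by the last epoch

    def on_epoch_start(self, split_name):
        self._start[split_name] = time.perf_counter()

    def on_epoch_end(self, split_name):
        self.elapsed[split_name] = time.perf_counter() - self._start.pop(split_name)

    # Lightning-style hooks delegate to the split-agnostic helpers above.
    def on_train_epoch_start(self, trainer, pl_module):
        self.on_epoch_start("train")

    def on_train_epoch_end(self, trainer, pl_module):
        self.on_epoch_end("train")


timer = EpochTimer()
timer.on_train_epoch_start(trainer=None, pl_module=None)
sum(range(10_000))  # stand-in for one training epoch
timer.on_train_epoch_end(trainer=None, pl_module=None)
print(f"train epoch took {timer.elapsed['train']:.6f} s")
```

The validation and test hooks would delegate in the same way with their own split names.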

class easy_torch.callbacks.TemperatureSlowdownCallback(threshold=80, sleep_time=10, every_n_epochs=5, devices=slice(None, None, None), nvidia_smi_path='C:\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe')[source]

Bases: Callback

on_validation_epoch_start(trainer, pl_module)[source]

Called when the val epoch begins.
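The constructor arguments (threshold, sleep_time, devices) suggest a loop that pauses training while the selected GPUs are too hot. A minimal sketch of that idea follows; the behavior is an assumption, not the library's implementation. wait_until_cool, read_temperatures, and max_checks are hypothetical names, and the temperature reader is injected so the example runs without a GPU or nvidia-smi.

```python
import time


def wait_until_cool(read_temperatures, threshold=80, sleep_time=10,
                    devices=slice(None), sleep=time.sleep, max_checks=100):
    """Sleep while any selected GPU is hotter than `threshold` (degrees C).

    `read_temperatures` is a hypothetical callable returning one temperature
    per GPU; a real callback would obtain these values from nvidia-smi.
    Returns the number of sleeps performed.
    """
    sleeps = 0
    for _ in range(max_checks):
        temps = read_temperatures()[devices]
        if not temps or max(temps) <= threshold:
            break
        sleep(sleep_time)
        sleeps += 1
    return sleeps


# Fake sensor: the first GPU cools down by 5 degrees on every reading.
readings = iter([[90, 70], [85, 70], [80, 70]])
n = wait_until_cool(lambda: next(readings), sleep=lambda seconds: None)
print(n)  # prints 2
```

Passing devices=slice(1, 2) would restrict the check to the second GPU only, mirroring the slice default in the signature above.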

class easy_torch.callbacks.TerminateOnNaNCallback[source]

Bases: Callback

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]

Called when the train batch ends.

Note

The value outputs["loss"] here will be the normalized value w.r.t accumulate_grad_batches of the loss returned from training_step.
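As a rough illustration of what a terminate-on-NaN check in on_train_batch_end can look like, the sketch below inspects outputs["loss"] and requests a stop through the trainer's should_stop flag. Whether easy_torch sets should_stop or stops in another way is an assumption; stop_if_not_finite is a hypothetical helper, and a SimpleNamespace stands in for the Lightning Trainer.

```python
import math
from types import SimpleNamespace


def stop_if_not_finite(trainer, outputs):
    """Set trainer.should_stop when the batch loss is NaN or infinite."""
    loss = outputs["loss"] if isinstance(outputs, dict) else outputs
    value = float(loss)  # works for Python floats and 0-dim tensors
    if not math.isfinite(value):
        trainer.should_stop = True
    return trainer.should_stop


trainer = SimpleNamespace(should_stop=False)  # stand-in for a Lightning Trainer
stop_if_not_finite(trainer, {"loss": 0.25})
print(trainer.should_stop)  # False: the loss is finite
stop_if_not_finite(trainer, {"loss": float("nan")})
print(trainer.should_stop)  # True: training should terminate
```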

easy_torch.losses module

class easy_torch.losses.PatriniLoss(noise_level, num_classes)[source]

Bases: Module

forward(input, target)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:
  • input (Tensor)

  • target (Tensor)

Return type:

Tensor

class easy_torch.losses.ForwardNRL(noise_level, num_classes)[source]

Bases: PatriniLoss

forward(input, target)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:
  • input (Tensor)

  • target (Tensor)

Return type:

Tensor

class easy_torch.losses.BackwardNRL(noise_level, num_classes)[source]

Bases: PatriniLoss

forward(input, target)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:
  • input (Tensor)

  • target (Tensor)

Return type:

Tensor
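The class names and the noise_level and num_classes arguments point to the forward and backward loss corrections of Patrini et al., which adjust the loss with a label-noise transition matrix. The sketch below shows only the forward correction and assumes symmetric noise; how easy_torch actually builds the matrix is not documented on this page. All function names here are hypothetical, and plain Python lists are used instead of tensors.

```python
import math


def symmetric_transition_matrix(noise_level, num_classes):
    """T[i][j] = P(noisy label = j | clean label = i) under symmetric noise."""
    off = noise_level / (num_classes - 1)
    return [[1.0 - noise_level if i == j else off for j in range(num_classes)]
            for i in range(num_classes)]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward_corrected_nll(logits, noisy_label, T):
    """Negative log-likelihood of the noisy label under T^T p."""
    p = softmax(logits)
    n = len(p)
    noisy_probs = [sum(T[i][j] * p[i] for i in range(n)) for j in range(n)]
    return -math.log(noisy_probs[noisy_label])


T = symmetric_transition_matrix(noise_level=0.2, num_classes=3)
agree = forward_corrected_nll([2.0, 0.0, 0.0], 0, T)
flipped = forward_corrected_nll([2.0, 0.0, 0.0], 1, T)
print(agree, flipped)
```

With noise_level=0 the matrix is the identity and the corrected loss reduces to ordinary cross-entropy; with a positive noise level, a label that disagrees with the prediction is penalized less.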

class easy_torch.losses.GCELoss(q=0.7)[source]

Bases: Module

Computes the Generalized Cross Entropy (GCE) loss, which is especially useful for training deep neural networks with noisy labels. Refer to “Generalized Cross Entropy Loss for Training Deep Neural Networks with Noisy Labels” <https://arxiv.org/abs/1805.07836>

Parameters:

q (float)

q

Box-Cox transformation parameter. Must be in (0,1].

Type:

float

epsilon

A small value to avoid undefined gradient.

Type:

float

softmax

Softmax function to convert raw scores to probabilities.

Type:

nn.Softmax

forward(input, target)[source]

Compute the GCE loss between the predictions and targets.

Parameters:
  • input (Tensor) – Predictions from the model (before softmax), shape: (batch_size, num_classes).

  • target (Tensor) – True labels (one-hot encoded), shape: (batch_size, num_classes).

Returns:

The mean GCE loss.

Return type:

torch.Tensor
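For reference, the GCE loss from the cited paper is (1 - p_y^q) / q, where p_y is the softmax probability assigned to the true class. The sketch below computes it in plain Python for a small batch. It is an illustration of the formula, not easy_torch's implementation; in particular it omits the epsilon attribute mentioned above, and gce_loss is a hypothetical name.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gce_loss(batch_logits, batch_one_hot, q=0.7):
    """Mean of (1 - p_y**q) / q, where p_y is the probability of the true class."""
    if not 0.0 < q <= 1.0:
        raise ValueError("q must be in (0, 1]")
    losses = []
    for logits, one_hot in zip(batch_logits, batch_one_hot):
        probs = softmax(logits)
        p_true = sum(p * t for p, t in zip(probs, one_hot))
        losses.append((1.0 - p_true ** q) / q)
    return sum(losses) / len(losses)


logits = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
targets = [[1, 0, 0], [0, 1, 0]]
print(gce_loss(logits, targets, q=0.7))
```

At q = 1 the loss equals 1 - p_y (a mean-absolute-error style loss), and as q approaches 0 it approaches the usual cross-entropy, which is why q trades off noise robustness against convergence speed.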

class easy_torch.losses.NCODLoss(sample_labels=None, num_examp=50000, num_classes=100, ratio_consistency=0, ratio_balance=0, total_epochs=4000, encoder_features=512)[source]

Bases: Module

init_param(mean=1e-08, std=1e-09)[source]
forward(index, outputs, label, out, flag, epoch)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

consistency_loss(output1, output2)[source]
soft_to_hard(x)[source]

easy_torch.metrics module

class easy_torch.metrics.SoftLabelsAccuracy[source]

Bases: Metric

update(input, target)[source]

Override this method to update the state variables of your metric class.

Parameters:
  • input (Tensor)

  • target (Tensor)

compute()[source]

Override this method to compute the final metric value.

This method will automatically synchronize state variables when running in distributed backend.

class easy_torch.metrics.BatchLength(batch_dim=1)[source]

Bases: Metric

A metric to compute the average batch length.

Parameters:

batch_dim (int) – The batch dimension (default: 1).

update(batch, *args, **kwargs)[source]

Updates the metric with the current batch length.

Parameters:
  • batch (Tensor) – The current batch, whose length is measured.

compute()[source]

Computes and returns the average batch length.
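The update()/compute() split used by these metrics can be sketched without torchmetrics: update() accumulates state per batch and compute() reduces it at the end. RunningBatchLength below is a hypothetical pure-Python analogue that measures length with len(batch); the real BatchLength presumably reads a tensor dimension selected by batch_dim, which is an assumption.

```python
class RunningBatchLength:
    """Hypothetical pure-Python analogue of an update()/compute() metric."""

    def __init__(self):
        self.total_length = 0  # sum of observed batch lengths
        self.num_batches = 0   # number of update() calls

    def update(self, batch):
        self.total_length += len(batch)
        self.num_batches += 1

    def compute(self):
        if self.num_batches == 0:
            return 0.0
        return self.total_length / self.num_batches


metric = RunningBatchLength()
for batch in ([1, 2, 3, 4], [5, 6, 7, 8], [9, 10]):  # the last batch is shorter
    metric.update(batch)
print(metric.compute())  # (4 + 4 + 2) / 3
```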

class easy_torch.metrics.FakeMetricCollection(metric_class, keys_name='out_keys', *args, **kwargs)[source]

Bases: MetricCollection

easy_torch.metrics.make_fake_class(base_class)[source]
class easy_torch.metrics.FakeMetric(true_metric, key, *args, **kwargs)[source]

Bases: Metric

update(*args, **kwargs)[source]

Override this method to update the state variables of your metric class.

compute(*args, **kwargs)[source]

Override this method to compute the final metric value.

This method will automatically synchronize state variables when running in distributed backend.

easy_torch.model module

class easy_torch.model.BaseNN(main_module, loss, optimizer, scheduler=None, metrics={}, log_params={}, step_routing={'loss_input_from_batch': [1], 'loss_input_from_model_output': None, 'metrics_input_from_batch': [1], 'metrics_input_from_model_output': None, 'model_input_from_batch': [0]}, **kwargs)[source]

Bases: LightningModule

Base class for a neural network model in PyTorch Lightning. This class serves as a base for creating neural network models with customizable components such as the main module, loss function, optimizer, and metrics. It also provides methods for logging, computing model outputs, losses, and metrics.

Parameters:
  • main_module (torch.nn.Module) – The main neural network module.

  • loss (torch.nn.Module or dict) – The primary loss function or a dictionary of loss functions.

  • optimizer (callable) – The optimizer function to be used for training.

  • scheduler (callable or dict, optional) – Learning rate scheduler function or a dictionary containing scheduler configuration.

  • metrics (dict) – A dictionary of metrics to be used for evaluation.

  • log_params (dict) – Parameters for logging, such as whether to log on epoch end.

  • step_routing (dict) – A dictionary defining how batch and model output are routed to the model, loss, and metrics.

log(name, value)[source]

Custom logging function that handles logging of metrics and values.

Parameters:
  • name (str) – Base name for the metric or value.

  • value (Any) – Value to log, which can be a scalar, dict, or torchmetrics.MetricCollection.

forward(*args, **kwargs)[source]

Forward pass through the main module.

Returns:

Output of the main model.

Return type:

torch.Tensor or Any

configure_optimizers()[source]

Configure the optimizer(s) and learning rate scheduler(s) for the model.

Returns:

A dictionary containing the optimizer and optionally the learning rate scheduler. The dictionary can contain:

  • “optimizer”: The optimizer instance.

  • “lr_scheduler”: A dictionary or callable for the learning rate scheduler.

Return type:

dict
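The shape of the returned dictionary can be sketched as follows. build_optimizer_config is a hypothetical helper, and the optimizer and scheduler factories are placeholders so the example runs without PyTorch; the nested "scheduler" and "monitor" keys follow Lightning's scheduler-configuration format. How BaseNN itself assembles this dictionary is not shown on this page.

```python
def build_optimizer_config(parameters, optimizer_fn, scheduler_fn=None, monitor=None):
    """Assemble a Lightning-style configure_optimizers() return value."""
    optimizer = optimizer_fn(parameters)
    config = {"optimizer": optimizer}
    if scheduler_fn is not None:
        lr_scheduler = {"scheduler": scheduler_fn(optimizer)}
        if monitor is not None:
            lr_scheduler["monitor"] = monitor  # metric watched by plateau-style schedulers
        config["lr_scheduler"] = lr_scheduler
    return config


# Placeholder factories; with PyTorch these could be, for example,
# functools.partial(torch.optim.Adam, lr=1e-3) and
# functools.partial(torch.optim.lr_scheduler.StepLR, step_size=10).
def fake_optimizer(params):
    return ("optimizer", list(params))


def fake_scheduler(optimizer):
    return ("scheduler", optimizer)


print(build_optimizer_config([1, 2], fake_optimizer))
print(build_optimizer_config([1, 2], fake_optimizer, fake_scheduler, monitor="val_loss"))
```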

step(batch, batch_idx, dataloader_idx, split_name)[source]

Common step function for processing a batch.

Parameters:
  • batch (Any) – Input batch from the dataloader.

  • batch_idx (int) – Index of the batch.

  • dataloader_idx (int) – Index of the dataloader (used for multi-dataloader scenarios).

  • split_name (str) – One of [“train”, “val”, “test”, “predict”].

Returns:

Dictionary containing model output, loss (if applicable), and metrics (if applicable).

Return type:

dict

compute_model_output(batch, model_input_from_batch)[source]

Compute the model output given a batch and the routing for model input.

Parameters:
  • batch (Any) – Input batch from the dataloader.

  • model_input_from_batch (list or dict) – Routing for model input from the batch.

Returns:

Output of the model.

Return type:

torch.Tensor or Any

get_input_args_kwargs(*args)[source]

Get positional arguments and keyword arguments from the provided args.

Parameters:

*args – A tuple of objects and their corresponding keys.

Returns:

  • input_args (list) – List of positional arguments extracted from the objects.

  • input_kwargs (dict) – Dictionary of input keyword arguments extracted from the objects.
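To make the routing idea concrete, the sketch below splits (object, keys) pairs into positional and keyword arguments. The conventions are assumptions inferred from the default step_routing shown in the BaseNN signature: a list of keys yields positional arguments, a dict of {argument_name: key} yields keyword arguments, and None passes the object whole. route_inputs is a hypothetical name, not the library function.

```python
def route_inputs(*pairs):
    """Split (object, keys) pairs into positional and keyword arguments."""
    input_args, input_kwargs = [], {}
    for obj, keys in pairs:
        if keys is None:
            input_args.append(obj)  # assumed: None passes the object whole
        elif isinstance(keys, dict):
            for arg_name, key in keys.items():
                input_kwargs[arg_name] = obj[key]
        else:
            input_args.extend(obj[key] for key in keys)
    return input_args, input_kwargs


batch = ("images", "labels")  # stand-in for an (x, y) batch of tensors

# Model input: first element of the batch, as in model_input_from_batch=[0].
print(route_inputs((batch, [0])))
# Loss input: the whole model output, then batch[1] as the target.
print(route_inputs(("logits", None), (batch, [1])))
# Dict routing: everything is passed by keyword.
print(route_inputs(({"logits": "z"}, {"input": "logits"}), (batch, {"target": 1})))
```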

compute_loss(loss_object, batch, loss_input_from_batch, model_output, loss_input_from_model_output, split_name, dataloader_idx)[source]

Compute the loss given a batch and the routing for loss input.

Parameters:
  • loss_object (torch.nn.Module or dict) – The loss function or a dictionary of loss functions.

  • batch (Any) – Input batch from the dataloader.

  • loss_input_from_batch (list or dict) – Routing for loss input from the batch.

  • model_output (torch.Tensor or Any) – Output of the model.

  • loss_input_from_model_output (list or dict) – Routing for loss input from the model output.

  • split_name (str) – Data split name.

  • dataloader_idx (int) – Index of the dataloader (used for multi-dataloader scenarios).

Returns:

Computed loss value.

Return type:

torch.Tensor

compute_metrics(batch, metrics_input_from_batch, model_output, metrics_input_from_model_output, split_name, dataloader_idx)[source]

Compute metrics using the specified metric functions.

Parameters:
  • batch (Any) – Input batch from the dataloader.

  • metrics_input_from_batch (list or dict) – Routing for metrics input from the batch.

  • model_output (torch.Tensor or Any) – Output of the model.

  • metrics_input_from_model_output (list or dict) – Routing for metrics input from the model output.

  • split_name (str) – Data split name.

  • dataloader_idx (int) – Index of the dataloader (used for multi-dataloader scenarios).

Returns:

Dictionary containing computed metric values.

Return type:

dict

get_key_if_dict_and_exists(obj, key)[source]
training_step(batch, batch_idx, dataloader_idx=0)[source]

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})

Note

If you don’t need to validate you don’t need to implement this method.

Note

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})

Note

If you don’t need to test you don’t need to implement this method.

Note

When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

Step function called during predict(). By default, it calls forward(). Override to add any processing logic.

The predict_step() is used to scale inference on multi-devices.

To prevent an OOM error, it is possible to use BasePredictionWriter callback to write the predictions to disk or database after each batch or on epoch end.

The BasePredictionWriter should be used while using a spawn based accelerator. This happens for Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8) as predictions won’t be returned.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

Predicted output (optional).

Example

class MyModel(LightningModule):

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)
easy_torch.model.get_torchvision_model(*args, seed=42, **kwargs)[source]
easy_torch.model.get_torchvision_model_as_decoder(example_datum, *args, **kwargs)[source]
easy_torch.model.load_torchvision_model(*args, **kwargs)[source]
class easy_torch.model.Identity[source]

Bases: Module

An Identity module that returns the input as is. This module can be used as a placeholder in a neural network architecture. It does not perform any operation on the input and simply returns it.

Returns:

The input tensor is returned unchanged.

Return type:

torch.Tensor

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class easy_torch.model.LambdaLayer(lambd)[source]

Bases: Module

A LambdaLayer module that applies a custom function to the input. It is useful for applying custom transformations or operations in a neural network.

Parameters:

lambd (callable) – A function that takes a tensor as input and returns a tensor as output.

Returns:

The output tensor after applying the custom function.

Return type:

torch.Tensor

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

easy_torch.preparation module

easy_torch.preparation.prepare_data_loaders(data, split_keys={'test': ['test_x', 'test_y'], 'train': ['train_x', 'train_y'], 'val': ['val_x', 'val_y']}, dtypes=None, **loader_params)[source]
easy_torch.preparation.prepare_experiment_id(original_trainer_params, experiment_id, cfg=None)[source]
easy_torch.preparation.prepare_callbacks(trainer_params, additional_module=None, seed=42)[source]
easy_torch.preparation.remove_keys_from_dict(input_dict, keys_to_remove)[source]

Recursively remove keys from a dictionary and all its sub-dictionaries.
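A recursive removal of this kind can be sketched as below. remove_keys is a hypothetical name, and this version returns a cleaned copy; whether the library function modifies its input in place or returns a new dictionary is not stated on this page.

```python
def remove_keys(input_dict, keys_to_remove):
    """Return a copy of input_dict without keys_to_remove, at any nesting depth."""
    cleaned = {}
    for key, value in input_dict.items():
        if key in keys_to_remove:
            continue
        if isinstance(value, dict):
            cleaned[key] = remove_keys(value, keys_to_remove)
        else:
            cleaned[key] = value
    return cleaned


cfg = {"lr": 0.1, "seed": 42, "logger": {"name": "csv", "seed": 7, "opts": {"seed": 1}}}
print(remove_keys(cfg, {"seed"}))
# {'lr': 0.1, 'logger': {'name': 'csv', 'opts': {}}}
```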

easy_torch.preparation.prepare_logger(trainer_params, additional_module=None, seed=42)[source]
easy_torch.preparation.prepare_strategy(trainer_params, additional_module=None)[source]
easy_torch.preparation.prepare_plugins(trainer_params, additional_module=None)[source]
easy_torch.preparation.prepare_trainer(seed=42, raytune=False, **trainer_kwargs)[source]
easy_torch.preparation.prepare_loss(loss_info, *additional_modules, split_keys={'test': 3, 'train': 1, 'val': 2}, seed=42)[source]
easy_torch.preparation.get_single_loss(loss_name, loss_params, *additional_modules)[source]
easy_torch.preparation.get_single_callback(callback_name, callback_params, additional_module=None)[source]
easy_torch.preparation.get_function(function_name, *modules)[source]
easy_torch.preparation.prepare_metrics(metrics_info, *additional_modules, split_keys={'test': 3, 'train': 1, 'val': 2}, seed=42)[source]
easy_torch.preparation.handle_FakeMetricCollection(metric_name, metric_params, *additional_modules)[source]
easy_torch.preparation.prepare_optimizer(name, params={}, seed=42)[source]
easy_torch.preparation.prepare_scheduler(scheduler_info, seed=42, *additional_modules)[source]
easy_torch.preparation.prepare_model(model_cfg)[source]
easy_torch.preparation.prepare_emission_tracker(experiment_id, **tracker_kwargs)[source]
easy_torch.preparation.prepare_flops_profiler(model, experiment_id, **profiler_kwargs)[source]
easy_torch.preparation.get_correct_package(name, *modules, raise_error=True)[source]
easy_torch.preparation.complete_prepare_trainer(cfg, experiment_id, model_params=None, additional_module={}, raytune=False)[source]
easy_torch.preparation.complete_prepare_model(cfg, main_module, *additional_modules, model_params=None)[source]
easy_torch.preparation.prepare_profiler(trainer_params, additional_module=None, seed=42)[source]

easy_torch.process module

easy_torch.process.create_model(main_module, seed=42, **kwargs)[source]

Create a PyTorch Lightning model.

Parameters:
  • main_module (nn.Module) – The main module of the model.

  • seed (int, optional) – Random seed for reproducibility (default: 42).

  • **kwargs – Additional keyword arguments to pass to the BaseNN constructor.

Returns:

A PyTorch Lightning model wrapping the main_module.

Return type:

BaseNN

easy_torch.process.train_model(trainer, model, loaders, train_key='train', val_key='val', seed=42, tracker=None, profiler=None)[source]

Trains a PyTorch Lightning model.

Parameters:
  • trainer (pl.Trainer) – The PyTorch Lightning Trainer instance used to fit the model.

  • model (pl.LightningModule) – The model to be trained.

  • loaders (Dict[str, DataLoader]) – Dictionary containing DataLoaders.

  • train_key (str) – Key to select the training DataLoader from loaders (default: “train”).

  • val_key (Union[str, list[str], None]) – Key or list of keys to select validation DataLoaders, or None to skip validation (default: “val”).

  • seed (int) – Random seed for deterministic training (default: 42).

  • tracker (Optional[object]) – Optional tracker with start() and stop() methods.

  • profiler (Optional[object]) – Optional profiler with start_profile(), stop_profile(), and print_model_profile(output_file) methods.

Returns:

None
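The interplay between loader keys, the trainer, and an optional tracker can be sketched as follows. This is an assumed control flow, not the body of train_model: train_with_tracker, RecordingTracker, and FakeTrainer are hypothetical, and the fake trainer only records the arguments that a Lightning Trainer.fit() call would receive.

```python
class RecordingTracker:
    """Hypothetical tracker exposing the start()/stop() interface."""

    def __init__(self):
        self.events = []

    def start(self):
        self.events.append("start")

    def stop(self):
        self.events.append("stop")


class FakeTrainer:
    """Stand-in for a Lightning Trainer; only records the fit() call."""

    def __init__(self):
        self.fit_calls = []

    def fit(self, model, train_dataloaders=None, val_dataloaders=None):
        self.fit_calls.append((model, train_dataloaders, val_dataloaders))


def train_with_tracker(trainer, model, loaders, train_key="train", val_key="val", tracker=None):
    """Select loaders by key and wrap trainer.fit() in tracker start/stop."""
    if val_key is None:
        val_loaders = None  # skip validation
    elif isinstance(val_key, str):
        val_loaders = loaders[val_key]
    else:
        val_loaders = [loaders[key] for key in val_key]
    if tracker is not None:
        tracker.start()
    try:
        trainer.fit(model, loaders[train_key], val_loaders)
    finally:
        if tracker is not None:
            tracker.stop()  # stop even if fit() raises


tracker, trainer = RecordingTracker(), FakeTrainer()
loaders = {"train": "train-loader", "val": "val-loader", "val2": "val2-loader"}
train_with_tracker(trainer, "model", loaders, val_key=["val", "val2"], tracker=tracker)
print(tracker.events, trainer.fit_calls)
```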

easy_torch.process.validate_model(trainer, model, loaders, loaders_key='val', seed=42)[source]

Validates a PyTorch Lightning model.

Parameters:
  • trainer (pl.Trainer) – The PyTorch Lightning Trainer instance used to run validation.

  • model (pl.LightningModule) – The trained model to be validated.

  • loaders (Dict[str, DataLoader]) – Dictionary of DataLoaders, keyed by names (e.g., ‘train’, ‘val’).

  • loaders_key (str) – Key used to select the validation DataLoader from loaders (default: “val”).

  • seed (int) – Random seed for reproducibility during validation (default: 42).

Returns:

None

easy_torch.process.test_model(trainer, model, loaders, test_key='test', tracker=None, profiler=None, seed=42)[source]

Test a PyTorch Lightning model.

Parameters:
  • trainer (pl.Trainer) – The PyTorch Lightning Trainer instance used to run testing.

  • model (pl.LightningModule) – The trained model to be tested.

  • loaders (Dict[str, DataLoader]) – Dictionary of DataLoaders, keyed by names (e.g., ‘train’, ‘val’).

  • test_key (str) – Key used to select the test DataLoader from loaders (default: “test”).

  • seed (int) – Random seed for reproducibility during testing (default: 42).

  • tracker (Optional[object]) – Optional tracker with start() and stop() methods.

  • profiler (Optional[object]) – Optional profiler with start_profile(), stop_profile(), and print_model_profile(output_file) methods.

Returns:

None

easy_torch.process.shutdown_dataloaders_workers()[source]

Shutdown data loader workers in a distributed setting.

Parameters:

None

Returns:

None

easy_torch.process.load_model(model_cfg, path, **kwargs)[source]

Load a PyTorch Lightning model from a checkpoint.

Parameters:
  • model_cfg (dict) – Configuration parameters for the model.

  • path (str) – Path to the checkpoint file.

  • **kwargs – Additional keyword arguments to pass to the BaseNN constructor.

Returns:

The loaded PyTorch Lightning model.

easy_torch.process.load_logs(name, exp_id, project_folder='../')[source]

Load log data from a CSV file.

Parameters:
  • name (str) – Name of the log file.

  • exp_id (str) – Experiment ID.

  • project_folder (str) – Path to the project folder (default: “../”).

Returns:

Loaded log data as a Pandas DataFrame.

easy_torch.torchvision_utils module

easy_torch.torchvision_utils.get_torchvision_model(name, torchvision_params={}, in_channels=None, out_features=None, out_as_image=False, keep_image_size=False, **kwargs)[source]

Get a pre-trained TorchVision model with optional modifications.

Parameters:
  • name – Name of the TorchVision model.

  • torchvision_params – Parameters for the TorchVision model.

  • in_channels – Number of input channels (optional).

  • out_features – Number of output features (optional).

  • out_as_image – Modify the model for image output (optional).

  • keep_image_size – Keep image size during modifications (optional).

  • kwargs – Additional keyword arguments.

Returns:

module – The modified TorchVision model.

easy_torch.torchvision_utils.get_torchvision_model_split_name(name)[source]

Split and get a TorchVision model by name.

Parameters:

name – Name of the TorchVision model.

Returns:

app – The TorchVision model.

easy_torch.torchvision_utils.change_in_channels(name, module, in_channels)[source]

Change the number of input channels in the model.

Parameters:
  • name – Name of the model.

  • module – The model to be modified.

  • in_channels – Number of input channels.

Returns:

None

easy_torch.torchvision_utils.change_conv_out_features(name, module, out_features=None)[source]

Change the output features of convolutional layers in the model.

Parameters:
  • name – Name of the model.

  • module – The model to be modified.

  • out_features – Number of output features (optional).

Returns:

None

easy_torch.torchvision_utils.change_fc_out_features(name, module, out_features)[source]

Change the output features of fully connected layers in the model.

Parameters:
  • name – Name of the model.

  • module – The model to be modified.

  • out_features – Number of output features.

Returns:

None

easy_torch.torchvision_utils.change_all_paddings(name, module)[source]

Change padding in convolutional layers to “same”.

Parameters:
  • name – Name of the model.

  • module – The model to be modified.

Returns:

None
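For context on what “same” padding achieves, the arithmetic below uses the standard convolution output-size formula (as given in the PyTorch Conv2d documentation) to show that, for stride 1 and an odd effective kernel, a padding of dilation * (kernel_size - 1) / 2 keeps the spatial size unchanged. Both helper names are hypothetical; this does not reproduce how change_all_paddings edits the layers.

```python
def conv_output_size(size, kernel_size, padding=0, stride=1, dilation=1):
    """Output length of a convolution along one dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def same_padding(kernel_size, dilation=1):
    """Symmetric padding that preserves the size for stride 1."""
    total = dilation * (kernel_size - 1)
    if total % 2:
        raise ValueError("an even effective kernel needs asymmetric padding")
    return total // 2


for k in (1, 3, 5, 7):
    p = same_padding(k)
    print(k, p, conv_output_size(224, k, padding=p))  # the size stays 224
```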

easy_torch.torchvision_utils.resnet_forward_impl(self, x)[source]

Custom forward method for ResNet.

Parameters:
  • self – The ResNet model.

  • x (Tensor) – Input tensor.

Returns:

x – Output tensor.

Return type:

Tensor

easy_torch.torchvision_utils.video_resnet_forward(self, x)[source]

Custom forward method for VideoResNet.

Parameters:
  • self – The VideoResNet model.

  • x (Tensor) – Input tensor.

Returns:

x – Output tensor.

Return type:

Tensor

easy_torch.torchvision_utils.load_torchvision_model(model_cfg, path)[source]

Load a TorchVision model from a checkpoint.

Parameters:
  • model_cfg – Configuration parameters for the model.

  • path – Path to the checkpoint file.

Returns:

model – The loaded TorchVision model.

easy_torch.torchvision_utils.invert_model(model, example_datum, keep_order=False)[source]

Invert a model.

Parameters:
  • model – The model to be inverted.

  • example_datum – An example datum for the model.

Returns:

inverted_model – The inverted model.

easy_torch.torchvision_utils.invert_layer(layer, current_input, inverted_layers=[])[source]

easy_torch.utils module

class easy_torch.utils.RobustModuleDict(init_dict=None)[source]

Bases: ModuleDict

Torch ModuleDict wrapper that permits keys with any name, including those that would otherwise conflict with class attributes.

Torch’s ModuleDict does not allow certain keys (e.g., ‘type’, ‘forward’) because they clash with existing methods or attributes of nn.Module, raising errors like KeyError.

Example

torch.nn.ModuleDict({'type': torch.nn.Module()})  # Raises KeyError.
RobustModuleDict({'type': torch.nn.Module()})  # Works fine.

This class mitigates possible conflicts by using a key-suffixing protocol.

Parameters:

init_dict (Dict[str, torch.nn.Module], optional) – Initial dictionary of modules. If provided, it initializes the RobustModuleDict with these modules. Defaults to None.

Returns:

None
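The key-suffixing idea can be illustrated with a plain dictionary wrapper: every key is stored with a suffix so it can never collide with a reserved attribute name, and the suffix is stripped again when keys are read back. SuffixedDict and the "_module" suffix are hypothetical; the suffix easy_torch actually uses is not documented on this page, and the sketch stores strings instead of modules so that it runs without PyTorch.

```python
class SuffixedDict:
    """Hypothetical illustration of key suffixing with a plain dict inside."""

    SUFFIX = "_module"  # assumed suffix, chosen for this example only

    def __init__(self, init_dict=None):
        self._store = {}
        if init_dict is not None:
            self.update(init_dict)

    def __setitem__(self, key, value):
        self._store[key + self.SUFFIX] = value

    def __getitem__(self, key):
        return self._store[key + self.SUFFIX]

    def __contains__(self, key):
        return key + self.SUFFIX in self._store

    def update(self, modules):
        for key, value in dict(modules).items():
            self[key] = value

    def keys(self):
        # Strip the suffix so callers see the names they originally used.
        return [k[: -len(self.SUFFIX)] for k in self._store]

    def items(self):
        return [(k[: -len(self.SUFFIX)], v) for k, v in self._store.items()]


d = SuffixedDict({"type": "encoder", "forward": "decoder"})
print(d.keys())   # ['type', 'forward']
print(d["type"])  # encoder
```

Callers keep using the original names, while the underlying container only ever sees the suffixed ones.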

keys()[source]

Return an iterable of the ModuleDict keys.

Return type:

List[str]

values()[source]

Return an iterable of the ModuleDict values.

Return type:

List[Module]

items()[source]

Return an iterable of the ModuleDict key/value pairs.

Return type:

List[Tuple[str, Module]]

update(modules)[source]

Update the ModuleDict with key-value pairs from a mapping, overwriting existing keys.

Note

If modules is an OrderedDict, a ModuleDict, or an iterable of key-value pairs, the order of new elements in it is preserved.

Parameters:

modules (iterable) – a mapping (dictionary) from string to Module, or an iterable of key-value pairs of type (string, Module)

Return type:

None