Source code for easy_torch.process

# Import necessary libraries
import pandas as pd  # Import Pandas library for data manipulation
import os  # Import the os library for working with the file system
import torch  # Import the PyTorch library for deep learning
import pytorch_lightning as pl  # Import PyTorch Lightning for training and logging
from .model import BaseNN  # Import the BaseNN class from the model module


[docs] def create_model(main_module, seed=42, **kwargs): """ Create a PyTorch Lightning model. Args: main_module (nn.Module): The main module of the model. seed (int, optional): Random seed for reproducibility (default: 42). **kwargs: Additional keyword arguments to pass to the BaseNN constructor. Returns: BaseNN: A PyTorch Lightning model wrapping the main_module. """ pl.seed_everything(seed, verbose=False) # Create the model using the BaseNN class model = BaseNN(main_module, **kwargs) return model
[docs] def train_model(trainer, model, loaders, train_key="train", val_key="val", seed=42, tracker=None, profiler=None): """ Trains a PyTorch Lightning model. Args: trainer (pl.Trainer): The PyTorch Lightning Trainer instance used to fit the model. model (pl.LightningModule): The model to be trained. loaders (Dict[str, DataLoader]): Dictionary containing DataLoaders. train_key (str): Key to select the training DataLoader from `loaders` (default: "train"). val_key (Union[str, list[str], None]): Key or list of keys to select validation DataLoaders, or None to skip validation (default: "val"). seed (int): Random seed for deterministic training (default: 42). tracker (Optional[object]): Optional tracker with `start()` and `stop()` methods. profiler (Optional[object]): Optional profiler with `start_profile()`, `stop_profile()`, and `print_model_profile(output_file)` methods. Returns: None """ # Set a random seed for deterministic training pl.seed_everything(seed, verbose=False) # Check if validation data loaders are specified and handle them accordingly # (single validation DataLoader if `val_key` is a string, or multiple if `val_key` is a list) if val_key is not None: if isinstance(val_key, str): val_dataloaders = loaders[val_key] elif isinstance(val_key, list): val_dataloaders = {key: loaders[key] for key in val_key} else: raise NotImplementedError else: val_dataloaders = None # Start the tracker and profiler if they are provided if tracker is not None: tracker.start() if profiler is not None: profiler.start_profile() # Train the model trainer.fit(model, loaders[train_key], val_dataloaders) # Stop the tracker and profiler if they are provided if tracker is not None: tracker.stop() if profiler is not None: profiler.print_model_profile(output_file = f"{profiler.output_dir}/train_flops.txt") profiler.stop_profile()
[docs] def validate_model(trainer, model, loaders, loaders_key="val", seed=42): """ Validates a PyTorch Lightning model. Args: trainer (pl.Trainer): The PyTorch Lightning Trainer instance used to run validation. model (pl.LightningModule): The trained model to be validated. loaders (Dict[str, DataLoader]): Dictionary of DataLoaders, keyed by names (e.g., 'train', 'val'). loaders_key (str): Key used to select the validation DataLoader from `loaders` (default: "val"). seed (int): Random seed for reproducibility during validation (default: 42). Returns: None """ pl.seed_everything(seed, workers=True, verbose=False) # Validate the model using the trainer trainer.validate(model, loaders[loaders_key])
[docs] def test_model(trainer, model, loaders, test_key="test", tracker=None, profiler=None, seed=42): """ Test a PyTorch Lightning model. Args: trainer (pl.Trainer): The PyTorch Lightning Trainer instance used to run validation. model (pl.LightningModule): The trained model to be tested. loaders (Dict[str, DataLoader]): Dictionary of DataLoaders, keyed by names (e.g., 'train', 'val'). test_key (str): Key used to select the test DataLoader from `loaders` (default: "test"). seed (int): Random seed for reproducibility during validation (default: 42). tracker (Optional[object]): Optional tracker with `start()` and `stop()` methods. profiler (Optional[object]): Optional profiler with `start_profile()`, `stop_profile()`, and `print_model_profile(output_file)` methods. Returns: None """ # Set a random seed for reproducibility pl.seed_everything(seed, workers=True, verbose=False) # Start the tracker and profiler if they are provided if tracker is not None: tracker.start() if profiler is not None: profiler.start_profile() # Check if test data loaders are specified and handle them accordingly # (single test DataLoader if `test_key` is a string, or multiple if `test_key` is a list) if isinstance(test_key, str): test_dataloaders = loaders[test_key] elif isinstance(test_key, list): test_dataloaders = {key: loaders[key] for key in test_key} else: raise NotImplementedError # Test the model using the trainer trainer.test(model, test_dataloaders) # Stop the tracker and profiler if they are provided if tracker is not None: tracker.stop() if profiler is not None: profiler.print_model_profile(output_file = f"{profiler.output_dir}/test_flops.txt") profiler.stop_profile()
# # (1) load the best checkpoint automatically (lightning tracks this for you during .fit()) # trainer.test(ckpt_path="best") # # (2) load the last available checkpoint (only works if `ModelCheckpoint(save_last=True)`) # trainer.test(ckpt_path="last") # Function to shutdown data loader workers in a distributed setting
[docs] def shutdown_dataloaders_workers(): """ Shutdown data loader workers in a distributed setting. Args: None Returns: None """ # Check if PyTorch is distributed initialized if torch.distributed.is_initialized(): torch.distributed.barrier() torch.distributed.destroy_process_group()
# Function to load a PyTorch Lightning model from a checkpoint
[docs] def load_model(model_cfg, path, **kwargs): """ Load a PyTorch Lightning model from a checkpoint. Args: model_cfg (dict): Configuration parameters for the model. path (str): Path to the checkpoint file. **kwargs: Additional keyword arguments to pass to the BaseNN constructor. Returns: The loaded PyTorch Lightning model. """ # Load the model from the checkpoint file using the BaseNN class model = BaseNN.load_from_checkpoint(path, **model_cfg, **kwargs) return model
# Function to load log data from a CSV file
[docs] def load_logs(name, exp_id, project_folder="../"): """ Load log data from a CSV file. Args: name(str): Name of the log file. exp_id(str): Experiment ID. project_folder(str): Path to the project folder (default: "../"). Returns: Loaded log data as a Pandas DataFrame. """ # Construct the file path to the log data file_path = os.path.join(project_folder, "out", "log", name, exp_id, "lightning_logs", "version_0", "metrics.csv") # Load CSV data into a Pandas DataFrame logs = pd.read_csv(file_path) return logs