# Import necessary libraries
import multiprocessing
import torch
import pytorch_lightning as pl
import torchmetrics
from copy import deepcopy
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import LambdaLR, SequentialLR
#import wandb
import os
import math
from ray.train.lightning import prepare_trainer as prepare_ray_trainer
import ray.train.lightning as ray_lightning
# Import modules and functions from local files
from .model import BaseNN
from . import metrics as custom_metrics
from . import losses as custom_losses # Ensure your custom losses are imported
from . import callbacks as custom_callbacks # Ensure your custom losses are imported
from . import utils
from . import process
# Function to prepare data loaders
[docs]
def prepare_data_loaders(data, split_keys=None, dtypes=None, **loader_params):
    """Build one ``DataLoader`` (over a ``TensorDataset``) per data split.

    Args:
        data: Mapping from data key (e.g. ``"train_x"``) to array-like values
            accepted by ``torch.tensor``.
        split_keys: Mapping from split name to the list of data keys whose
            tensors make up that split's ``TensorDataset``. Defaults to
            ``{"train": ["train_x", "train_y"], "val": ["val_x", "val_y"],
            "test": ["test_x", "test_y"]}``.
        dtypes: ``None`` (let ``torch.tensor`` infer), a ``torch.dtype`` or a
            dtype name such as ``"float32"`` applied to every key, or a dict
            mapping data key to dtype / dtype name (unlisted keys are inferred).
        **loader_params: ``DataLoader`` keyword arguments overriding the
            defaults below. A dict value is treated as per-split: the entry for
            the current split is used, and the argument is omitted (so the
            ``DataLoader`` default applies) for splits the dict does not name.

    Returns:
        Dict mapping split name to its ``DataLoader``.

    Raises:
        NotImplementedError: If ``dtypes`` has an unsupported type.
    """
    if split_keys is None:
        split_keys = {"train": ["train_x", "train_y"], "val": ["val_x", "val_y"], "test": ["test_x", "test_y"]}
    default_loader_params = {
        "num_workers": multiprocessing.cpu_count(),
        "pin_memory": True,
        "persistent_workers": True,
        "drop_last": {"train": False, "val": False, "test": False},
        "shuffle": {"train": True, "val": False, "test": False},
    }
    # The persistent_workers default is only valid with worker processes, so
    # remember whether the caller chose it explicitly.
    persistent_workers_given = "persistent_workers" in loader_params
    loader_params = {**default_loader_params, **loader_params}
    split_dtypes = _resolve_split_dtypes(dtypes, split_keys)

    loaders = {}
    for split_name, data_keys in split_keys.items():
        split_loader_params = _select_split_params(deepcopy(loader_params), split_name)
        if not persistent_workers_given and split_loader_params.get("num_workers", 0) == 0:
            # DataLoader raises ValueError for persistent_workers=True with num_workers=0.
            split_loader_params["persistent_workers"] = False
        tensors = [torch.tensor(data[key], dtype=split_dtypes[split_name][key]) for key in data_keys]
        loaders[split_name] = DataLoader(TensorDataset(*tensors), **split_loader_params)
    return loaders


def _resolve_split_dtypes(dtypes, split_keys):
    """Expand ``dtypes`` into ``{split_name: {data_key: torch.dtype or None}}``."""
    if dtypes is None or isinstance(dtypes, (str, torch.dtype)):
        if isinstance(dtypes, str):
            dtypes = getattr(torch, dtypes)
        return {split_name: {key: dtypes for key in data_keys} for split_name, data_keys in split_keys.items()}
    if isinstance(dtypes, dict):
        resolved = {}
        for split_name, data_keys in split_keys.items():
            resolved[split_name] = {}
            for key in data_keys:
                dtype = dtypes.get(key)
                resolved[split_name][key] = getattr(torch, dtype) if isinstance(dtype, str) else dtype
        return resolved
    raise NotImplementedError(f"Unsupported dtype: {dtypes}")


def _select_split_params(params, split_name):
    """Return ``params`` with every dict value replaced by its ``split_name`` entry.

    Dict values that have no entry for ``split_name`` are dropped instead of
    being forwarded whole (a dict is truthy, so e.g. ``shuffle`` would turn on).
    """
    selected = {}
    for key, value in params.items():
        if isinstance(value, dict):
            if split_name not in value:
                continue
            value = value[split_name]
        selected[key] = value
    return selected
# Function to prepare trainer parameters with experiment ID
[docs]
def prepare_experiment_id(original_trainer_params, experiment_id, cfg=None):
    """Return a deep copy of ``original_trainer_params`` scoped to ``experiment_id``.

    * For every dict entry in ``"callbacks"``, a ``ModelCheckpoint`` config gets
      ``experiment_id + "/"`` appended to its ``dirpath``; any other callback
      name only prints a warning. Non-dict entries are left alone.
    * If a ``"logger"`` entry exists, ``experiment_id + "/"`` is appended to
      ``logger["params"]["save_dir"]``. For a ``WandbLogger``, the ``id`` and
      ``name`` params are set to ``experiment_id``, and ``config`` to ``cfg``
      when ``cfg`` is given.

    The caller's dict is never modified.
    """
    params = deepcopy(original_trainer_params)

    if "callbacks" in params:
        dict_entries = (entry for entry in params["callbacks"] if isinstance(entry, dict))
        for entry in dict_entries:
            for cb_name, cb_params in entry.items():
                if cb_name != "ModelCheckpoint":
                    # Only checkpoint directories know how to be scoped per experiment.
                    print(f"Warning: {cb_name} not recognized for adding experiment_id")
                    continue
                cb_params["dirpath"] += experiment_id + "/"

    if "logger" in params:
        logger_cfg = params["logger"]
        logger_kwargs = logger_cfg["params"]
        logger_kwargs["save_dir"] += experiment_id + "/"
        if logger_cfg["name"] == "WandbLogger":
            logger_kwargs["id"] = experiment_id
            logger_kwargs["name"] = experiment_id
            if cfg is not None:
                logger_kwargs["config"] = cfg

    return params
# Function to prepare callbacks
[docs]
def prepare_callbacks(trainer_params, additional_module=None, seed=42):
    """Turn the ``"callbacks"`` entry of ``trainer_params`` into callback objects.

    Each element of ``trainer_params["callbacks"]`` is either a dict mapping
    callback class names to their keyword arguments (every pair is built with
    ``get_single_callback``) or an already-constructed object, which is kept
    as is. The global RNG is seeded with ``seed`` first.

    Args:
        trainer_params: Trainer configuration dict; without a ``"callbacks"``
            key the result is an empty list.
        additional_module: Optional module searched first for callback classes.
        seed: Value passed to ``pl.seed_everything``.

    Returns:
        List of callbacks in configuration order.
    """
    pl.seed_everything(seed, verbose=False)
    configured = trainer_params["callbacks"] if "callbacks" in trainer_params else ()
    built = []
    for entry in configured:
        if not isinstance(entry, dict):
            # Already an object: pass it through untouched.
            built.append(entry)
            continue
        built.extend(
            get_single_callback(name, params, additional_module)
            for name, params in entry.items()
        )
    return built
[docs]
def remove_keys_from_dict(input_dict, keys_to_remove):
    """Delete ``keys_to_remove`` from ``input_dict`` and every nested dict, in place.

    Only values that are themselves dicts are descended into; dicts inside
    lists or other containers are left untouched. A non-dict ``input_dict``
    is returned unchanged.

    Returns:
        The same ``input_dict`` object, for convenient chaining.
    """
    if not isinstance(input_dict, dict):
        return input_dict
    for doomed in keys_to_remove:
        input_dict.pop(doomed, None)
    for child in input_dict.values():
        remove_keys_from_dict(child, keys_to_remove)
    return input_dict
# def log_wandb(trainer_params):
# items_to_delete = ['__nosave__', 'emission_tracker', 'metrics',
# 'data_folder', 'log_params', 'step_routing']
# cfg = exp_utils.cfg.load_configuration()
# exp_found, experiment_id = exp_utils.exp.get_set_experiment_id(cfg)
# if not exp_found:
# wandb.login(key=trainer_params["logger"]["key"])
# if trainer_params["logger"]["entity"] is not None:
# wandb.init(entity=trainer_params["logger"]["entity"],
# project=trainer_params["logger"]["project"],
# name = cfg['__exp__.name'] + "_" + experiment_id,
# id = experiment_id,
# config = remove_keys_from_dict(cfg, items_to_delete))
# else:
# wandb.init(project=trainer_params["logger"]["project"],
# name = cfg['__exp__.name'] + "_" + experiment_id,
# id = experiment_id,
# config = remove_keys_from_dict(cfg, items_to_delete))
# Function to prepare a logger based on trainer parameters
[docs]
def prepare_logger(trainer_params, additional_module=None, seed=42):
    """Instantiate the logger described by ``trainer_params["logger"]``.

    The entry is expected to be a dict ``{"name": cls_name, "params": {...}}``
    whose ``params`` include ``"save_dir"``. That directory is created if
    needed, and the class (searched in ``additional_module`` then
    ``pl.loggers``) is built with ``params``. A non-dict entry (e.g. an
    already-built logger, ``False`` or ``None``) is returned unchanged.

    Args:
        trainer_params: Trainer configuration dict.
        additional_module: Optional module searched first for the logger class.
        seed: Value passed to ``pl.seed_everything``.

    Returns:
        The logger, or ``None`` when ``trainer_params`` has no ``"logger"`` key.
    """
    pl.seed_everything(seed, verbose=False)
    if "logger" not in trainer_params:
        return None
    logger_cfg = trainer_params["logger"]
    if not isinstance(logger_cfg, dict):
        return logger_cfg
    logger_params = logger_cfg["params"]
    # exist_ok=True avoids the exists()/makedirs() race when several runs
    # (e.g. parallel tuning trials) create the same directory concurrently.
    os.makedirs(logger_params["save_dir"], exist_ok=True)
    logger_cls = get_function(logger_cfg["name"], additional_module, pl.loggers)
    # TODO: support multiple loggers.
    return logger_cls(**logger_params)
# Function to prepare strategy
[docs]
def prepare_strategy(trainer_params, additional_module=None):
    """Resolve the ``"strategy"`` entry of ``trainer_params`` for ``pl.Trainer``.

    The entry may be:

    * a name (``str``) or a dict ``{"name": ..., "params": {...}}``. A name that
      is an attribute of ``additional_module``, ``pl.strategies`` or
      ``ray.train.lightning`` (searched in that order) is instantiated with the
      params. Any other name is returned as a plain string so the Trainer can
      resolve it itself (e.g. a registered alias such as ``"ddp"``).
    * any other object (e.g. an already-built strategy), returned unchanged.

    Returns:
        The strategy object or name; ``"auto"`` when none is configured.
    """
    if "strategy" not in trainer_params or trainer_params["strategy"] is None:
        return "auto"
    strategy_info = trainer_params["strategy"]
    if isinstance(strategy_info, str):
        strategy_name, strategy_params = strategy_info, {}
    elif isinstance(strategy_info, dict):
        strategy_name = strategy_info["name"]
        strategy_params = strategy_info.get("params") or {}
    else:
        # Already a strategy object (previously this raised UnboundLocalError).
        return strategy_info
    # Search silently: not finding the name is the normal case for string
    # aliases, which must be handed to the Trainer unchanged.
    for module in (additional_module, pl.strategies, ray_lightning):
        if module is not None and hasattr(module, strategy_name):
            return getattr(module, strategy_name)(**strategy_params)
    return strategy_name
[docs]
def prepare_plugins(trainer_params, additional_module=None):
    """Instantiate the plugins listed in ``trainer_params["plugins"]``.

    Each element is a class name (``str``), a dict
    ``{"name": ..., "params": {...}}``, or an already-built object. Names are
    looked up in ``additional_module``, ``pl.plugins`` and
    ``ray.train.lightning`` (in that order) and instantiated with the params.
    Objects are kept as they are.

    Returns:
        List of plugins (empty when the key is missing or ``None``).
    """
    plugins = []
    for plugin_info in trainer_params.get("plugins") or []:
        if isinstance(plugin_info, str):
            plugin_name, plugin_params = plugin_info, {}
        elif isinstance(plugin_info, dict):
            plugin_name = plugin_info["name"]
            plugin_params = plugin_info.get("params") or {}
        else:
            # Already-built plugin. Previously this either raised
            # UnboundLocalError or silently rebuilt the previous plugin.
            plugins.append(plugin_info)
            continue
        plugin_cls = get_function(plugin_name, additional_module, pl.plugins, ray_lightning)
        plugins.append(plugin_cls(**plugin_params))
    return plugins
# Function to prepare a PyTorch Lightning Trainer instance
[docs]
def prepare_trainer(seed=42, raytune=False, **trainer_kwargs):
    """Create a ``pl.Trainer`` after seeding every RNG.

    Checkpointing is disabled and accelerator/devices are ``"auto"`` unless
    ``trainer_kwargs`` overrides them. All kwargs go to ``pl.Trainer``.

    Args:
        seed: Value passed to ``pl.seed_everything``.
        raytune: When true, the trainer is wrapped with Ray Train's
            ``prepare_trainer`` before being returned.
        **trainer_kwargs: Keyword arguments for ``pl.Trainer``.
    """
    pl.seed_everything(seed, verbose=False)
    trainer_params = {"enable_checkpointing": False, "accelerator": "auto", "devices": "auto"}
    trainer_params.update(trainer_kwargs)
    trainer = pl.Trainer(**trainer_params)
    return prepare_ray_trainer(trainer) if raytune else trainer
# Function to prepare a loss function
[docs]
def prepare_loss(loss_info, *additional_modules, split_keys=None, seed=42):
    """Build the loss module(s) for every dataloader of every split.

    Args:
        loss_info: Loss configuration. It is either one configuration shared by
            all splits and dataloaders, or a dict containing every key of
            ``split_keys`` (per-split form).
            In the per-split form, each value is indexed by dataloader position
            (e.g. a list with one configuration per dataloader). A plain
            string, or a dict that has no key for that dataloader index, is
            reused for every dataloader of the split.
            A single configuration is either:

            * a loss class name (``str``), instantiated without arguments, or
            * a dict ``{alias: {"name": cls_name, "params": {...}}, ...}`` with
              an optional ``"__weight__"`` entry. It becomes a
              ``torch.nn.ModuleDict`` (aliases in sorted order) whose
              ``__weight__`` attribute is the given value, or
              ``torch.ones(number_of_losses)`` when absent.
        *additional_modules: Modules searched for loss classes before the
            project's ``losses`` module and ``torch.nn``.
        split_keys: Mapping from split name to number of dataloaders. Defaults
            to ``{"train": 1, "val": 2, "test": 3}``.
        seed: Value passed to ``pl.seed_everything`` before building.

    Returns:
        ``utils.RobustModuleDict`` mapping each split name to a
        ``torch.nn.ModuleList`` with one loss entry per dataloader.

    Raises:
        NotImplementedError: If a configuration is neither ``str`` nor ``dict``.
    """
    if split_keys is None:
        split_keys = {"train": 1, "val": 2, "test": 3}
    pl.seed_everything(seed, verbose=False)
    # Per-split form only if every split name is a key of loss_info.
    already_split = isinstance(loss_info, dict) and all(key in loss_info for key in split_keys)

    losses = {}
    for split_name, num_dataloaders in split_keys.items():
        split_losses = []
        for dataloader_idx in range(num_dataloaders):
            if already_split:
                split_entry = loss_info[split_name]
                if isinstance(split_entry, str) or (isinstance(split_entry, dict) and dataloader_idx not in split_entry):
                    # One configuration shared by all dataloaders of the split
                    # (indexing a string would pick a single character).
                    entry = split_entry
                else:
                    entry = split_entry[dataloader_idx]
            else:
                # Same configuration for every split and dataloader.
                entry = loss_info
            split_losses.append(_build_loss_entry(entry, additional_modules))
        losses[split_name] = torch.nn.ModuleList(split_losses)
    return utils.RobustModuleDict(losses)


def _build_loss_entry(loss_entry, additional_modules):
    """Instantiate one dataloader's loss configuration (class name or dict of named losses)."""
    if isinstance(loss_entry, str):
        return get_single_loss(loss_entry, {}, *additional_modules)
    if isinstance(loss_entry, dict):
        named_losses = {}
        for alias, loss_cfg in sorted(loss_entry.items()):
            if alias == "__weight__":
                continue
            named_losses[alias] = get_single_loss(loss_cfg["name"], loss_cfg.get("params") or {}, *additional_modules)
        loss = torch.nn.ModuleDict(named_losses)
        # NOTE(review): weights given as a sequence presumably follow the sorted alias order.
        loss.__weight__ = loss_entry.get("__weight__", torch.ones(len(loss)))
        return loss
    raise NotImplementedError
[docs]
def get_single_loss(loss_name, loss_params, *additional_modules):
    """Instantiate loss class ``loss_name`` with keyword arguments ``loss_params``.

    The class is searched in ``additional_modules``, then the project's
    ``losses`` module, then ``torch.nn``.
    """
    loss_cls = get_function(loss_name, *additional_modules, custom_losses, torch.nn)
    return loss_cls(**loss_params)
[docs]
def get_single_callback(callback_name, callback_params, additional_module=None):
    """Instantiate callback class ``callback_name`` with ``callback_params``.

    The class is searched in ``additional_module``, the project's
    ``callbacks`` module, ``pl.callbacks`` and ``ray.train.lightning``, in
    that order.

    Args:
        callback_name: Class name to look up.
        callback_params: Keyword arguments for the class; ``None`` (e.g. a
            config entry with an empty body) means no arguments.
        additional_module: Optional module searched first.
    """
    callback_cls = get_function(callback_name, additional_module, custom_callbacks, pl.callbacks, ray_lightning)
    return callback_cls(**(callback_params or {}))
[docs]
def get_function(function_name, *modules):
    """Return attribute ``function_name`` from the first of ``modules`` that defines it.

    The lookup is delegated to ``get_correct_package``, which raises
    ``NotImplementedError`` when no module provides the name.
    """
    owner = get_correct_package(function_name, *modules)
    return getattr(owner, function_name)
[docs]
def prepare_metrics(metrics_info, *additional_modules, split_keys=None, seed=42):
    """Build the metric modules for every dataloader of every split.

    Args:
        metrics_info: Metrics configuration. It is either one configuration
            shared by all splits and dataloaders, or a dict containing every
            key of ``split_keys`` (per-split form).
            In the per-split form, each value is indexed by dataloader position
            (e.g. a list with one configuration per dataloader). A dict that
            has no key for that dataloader index is reused for every dataloader
            of the split.
            A single configuration is either a list of metric names (built
            without arguments) or a dict ``{metric_name: params}``, where
            ``params`` may be ``None`` for no arguments. Names containing
            ``FakeMetricCollection`` use the ``"Wrapper:Metric"`` syntax
            handled by ``handle_FakeMetricCollection``.
        *additional_modules: Modules searched for metric classes before the
            project's ``metrics`` module and ``torchmetrics``.
        split_keys: Mapping from split name to number of dataloaders. Defaults
            to ``{"train": 1, "val": 2, "test": 3}``.
        seed: Value passed to ``pl.seed_everything`` before each metric.

    Returns:
        ``utils.RobustModuleDict`` mapping each split name to a
        ``torch.nn.ModuleList`` of ``torch.nn.ModuleDict`` (one per dataloader).

    Raises:
        NotImplementedError: If a configuration is neither a list nor a dict.
    """
    if split_keys is None:
        split_keys = {"train": 1, "val": 2, "test": 3}
    already_split = isinstance(metrics_info, dict) and all(key in metrics_info for key in split_keys)

    metrics = {}
    for split_name, num_dataloaders in split_keys.items():
        per_dataloader = []
        for dataloader_idx in range(num_dataloaders):
            if already_split:
                split_entry = metrics_info[split_name]
                if isinstance(split_entry, dict) and dataloader_idx not in split_entry:
                    # One metrics configuration shared by all dataloaders of the split.
                    entry = split_entry
                else:
                    entry = split_entry[dataloader_idx]
            else:
                entry = metrics_info
            per_dataloader.append(torch.nn.ModuleDict(_build_metric_group(entry, additional_modules, seed)))
        metrics[split_name] = torch.nn.ModuleList(per_dataloader)
    return utils.RobustModuleDict(metrics)


def _build_metric_group(metrics_entry, additional_modules, seed):
    """Instantiate every metric of one dataloader's configuration (list of names or name->params dict)."""
    if not isinstance(metrics_entry, (list, dict)):
        raise NotImplementedError
    group = {}
    for metric_name in metrics_entry:
        metric_params = {} if isinstance(metrics_entry, list) else (metrics_entry[metric_name] or {})
        # Reseed before each metric so its construction is reproducible.
        pl.seed_everything(seed, verbose=False)
        metric_name, true_metric_name, metric_params = handle_FakeMetricCollection(metric_name, metric_params, *additional_modules)
        group[true_metric_name] = get_function(metric_name, *additional_modules, custom_metrics, torchmetrics)(**metric_params)
    return group
[docs]
def handle_FakeMetricCollection(metric_name, metric_params, *additional_modules):
    """Expand the ``"<FakeMetricCollection...>:<MetricClass>"`` naming convention.

    For a name containing ``FakeMetricCollection``, the part before ``:`` is
    the wrapper class to instantiate and the part after it is the wrapped
    metric. The wrapped class (looked up in ``additional_modules``, the
    project's ``metrics`` module, then ``torchmetrics``) is added to a *new*
    params dict under ``"metric_class"``. Any other name is returned unchanged.

    Returns:
        ``(class_name_to_build, key_to_store_under, params)``.
    """
    if "FakeMetricCollection" not in metric_name:
        return metric_name, metric_name, metric_params
    wrapper_name, wrapped_name = metric_name.split(":")
    wrapped_cls = get_function(wrapped_name, *additional_modules, custom_metrics, torchmetrics)
    # Fresh dict so the caller's configuration is not modified.
    params_with_class = {**metric_params, "metric_class": wrapped_cls}
    return wrapper_name, wrapped_name, params_with_class
[docs]
def prepare_optimizer(name, params=None, seed=42):
    """Return a factory that builds ``torch.optim.<name>`` for given model parameters.

    Args:
        name: Name of an optimizer class in ``torch.optim`` (e.g. ``"Adam"``).
            It is looked up when the factory is called, not here.
        params: Keyword arguments for the optimizer. ``None`` (the default, and
            what an empty config entry produces) means no extra arguments.
        seed: Value passed to ``pl.seed_everything`` before returning.

    Returns:
        ``callable(model_params) -> torch.optim.Optimizer``.
    """
    pl.seed_everything(seed, verbose=False)
    # None replaces the former shared mutable ``{}`` default and tolerates an
    # explicit ``params: None`` in the configuration.
    optimizer_kwargs = {} if params is None else params
    return lambda model_params: getattr(torch.optim, name)(model_params, **optimizer_kwargs)
[docs]
def prepare_scheduler(scheduler_info, *additional_modules, seed=42):
    """Return a factory ``optimizer -> LR scheduler`` described by ``scheduler_info``.

    Args:
        scheduler_info: Dict with:

            * ``"name"``: scheduler class, searched in ``additional_modules``
              then ``torch.optim.lr_scheduler``;
            * ``"params"`` (optional): keyword arguments for that class;
            * ``"warmup_params"`` (optional): ``epochs`` (default 0),
              ``type`` (``linear``/``constant``/``exponential``/``cosine``,
              default ``linear``), ``start_factor`` (0.0), ``end_factor`` (1.0),
              and ``function`` (custom ``f(epoch, warmup_epochs) -> factor``,
              which takes precedence over ``type``).
        *additional_modules: Modules searched for the scheduler class.
        seed: Accepted for API compatibility but unused; building a scheduler
            consumes no randomness. For backward compatibility with the former
            signature ``(scheduler_info, seed=42, *additional_modules)``, an
            ``int`` passed as the first extra positional argument is taken as
            the seed rather than as a module.

    Returns:
        A callable taking an optimizer. With a positive number of warmup
        epochs, it builds a ``SequentialLR`` that runs a ``LambdaLR`` warmup
        for ``epochs`` epochs and then the main scheduler. Otherwise it builds
        the main scheduler alone.

    Raises:
        NotImplementedError: For an unknown warmup ``type``.
        ValueError: For exponential warmup with ``start_factor <= 0``, or type
            ``custom`` without a ``function``.
    """
    if additional_modules and isinstance(additional_modules[0], int):
        seed, additional_modules = additional_modules[0], additional_modules[1:]
    name = scheduler_info["name"]
    params = scheduler_info.get("params", {})

    def build_main_scheduler(optimizer):
        return get_function(name, *additional_modules, torch.optim.lr_scheduler)(optimizer, **params)

    warmup_cfg = scheduler_info.get("warmup_params")
    if warmup_cfg is None:
        return build_main_scheduler

    warmup_epochs = warmup_cfg.get("epochs", 0)
    if warmup_epochs <= 0:
        # A SequentialLR with milestone 0 would still run the first epoch at the
        # warmup factor (start_factor, 0.0 by default), so skip warmup entirely.
        return build_main_scheduler

    # Built (and validated) now, so configuration errors surface immediately.
    warmup_lambda = _make_warmup_lambda(
        warmup_epochs,
        warmup_cfg.get("type", "linear"),
        warmup_cfg.get("start_factor", 0.0),
        warmup_cfg.get("end_factor", 1.0),
        warmup_cfg.get("function", None),
    )

    def create_scheduler(optimizer):
        warmup_sched = LambdaLR(optimizer, lr_lambda=warmup_lambda)
        main_sched = build_main_scheduler(optimizer)
        # Switch from warmup to the main scheduler after ``warmup_epochs`` steps.
        return SequentialLR(optimizer, schedulers=[warmup_sched, main_sched], milestones=[warmup_epochs])

    return create_scheduler


def _make_warmup_lambda(warmup_epochs, warmup_type, start_factor, end_factor, custom_func):
    """Return the ``epoch -> LR multiplier`` function used during warmup."""
    if custom_func is not None:
        return lambda epoch: custom_func(epoch, warmup_epochs)
    span = float(max(1, warmup_epochs))

    if warmup_type == "linear":
        def factor(epoch):
            t = epoch / span
            return start_factor + (end_factor - start_factor) * t
    elif warmup_type == "constant":
        def factor(epoch):
            return start_factor
    elif warmup_type == "exponential":
        if start_factor <= 0:
            raise ValueError(f"Exponential warmup requires start_factor > 0 (got {start_factor})")

        def factor(epoch):
            t = epoch / span
            return start_factor * ((end_factor / start_factor) ** t)
    elif warmup_type == "cosine":
        def factor(epoch):
            t = epoch / span
            return start_factor + (end_factor - start_factor) * (0.5 * (1 - math.cos(math.pi * t)))
    elif warmup_type == "custom":
        raise ValueError("Warmup type 'custom' requires a 'function' entry in warmup_params")
    else:
        raise NotImplementedError(f"Unsupported warmup type: {warmup_type}. Please select from ['linear', 'constant', 'exponential', 'cosine', 'custom']")
    return factor
[docs]
def prepare_model(model_cfg):
    """Create a ``BaseNN`` from ``model_cfg`` after seeding every RNG.

    ``model_cfg["seed"]`` is passed to ``pl.seed_everything`` so weight
    initialisation is reproducible. The whole ``model_cfg`` (``"seed"``
    included) is then forwarded to ``BaseNN`` as keyword arguments.
    """
    seed = model_cfg["seed"]
    pl.seed_everything(seed, verbose=False)
    return BaseNN(**model_cfg)
[docs]
def prepare_emission_tracker(experiment_id, **tracker_kwargs):
    """Create a codecarbon ``EmissionsTracker`` writing under ``<output_dir><experiment_id>/``.

    ``output_dir`` defaults to ``"../out/log/"``; ``experiment_id`` is appended
    by plain string concatenation. A ``use`` key, if present, is removed and
    not forwarded. All other kwargs go to ``EmissionsTracker``. The resolved
    output directory is printed.
    """
    from codecarbon import EmissionsTracker

    tracker_kwargs.pop("use", None)
    base_dir = tracker_kwargs.get("output_dir", "../out/log/")
    tracker_kwargs["output_dir"] = base_dir + experiment_id + "/"
    print(f"Tracker output directory: {tracker_kwargs['output_dir']}")
    return EmissionsTracker(**tracker_kwargs)
[docs]
def prepare_flops_profiler(model, experiment_id, **profiler_kwargs):
    """Create a DeepSpeed ``FlopsProfiler`` for ``model``.

    ``use`` and ``output_dir`` are stripped from the kwargs before they are
    forwarded to ``FlopsProfiler``. The profiler then gets an ``output_dir``
    attribute of ``<output_dir><experiment_id>/`` (``output_dir`` defaults to
    ``"../out/log/"``), which is printed.
    """
    from deepspeed.profiling.flops_profiler import FlopsProfiler

    profiler_kwargs.pop("use", None)
    base_dir = profiler_kwargs.pop("output_dir", "../out/log/")
    flops_profiler = FlopsProfiler(model, **profiler_kwargs)
    flops_profiler.output_dir = base_dir + experiment_id + "/"
    print(f"Profiler output directory: {flops_profiler.output_dir}")
    return flops_profiler
"""
# Prototype for logging different configurations for metrics and losses
def prepare_loss(loss_info):
'''
Prepare a loss function or multiple loss functions with different configurations.
Parameters:
- loss_info: Single loss function name or a list of loss function names with configurations.
Returns:
- loss: Dictionary containing loss functions and their respective configurations.
'''
if isinstance(loss_info, str):
iterate_on = {loss_info: {}}
elif isinstance(loss_info, list):
iterate_on = {metric_name: {} for metric_name in loss_info}
elif isinstance(loss_info, dict):
iterate_on = loss_info
else:
raise NotImplementedError
loss = {}
for loss_name, loss_params in iterate_on.items():
# Separate log_params from loss_params
loss_log_params = loss_params.pop("log_params", {})
loss_weight = loss_params.pop("weight", 1.0)
loss[loss_name] = {"loss": getattr(torch.nn, loss_name)(**loss_params), "log_params": loss_log_params, "weight": loss_weight}
return loss
def prepare_metrics(metrics_info):
'''
Prepare evaluation metrics or multiple metrics with different configurations.
Parameters:
- metrics_info: Single metric name or a list of metric names with configurations.
Returns:
- metrics: Dictionary containing metrics and their respective configurations.
'''
if isinstance(metrics_info, str):
iterate_on = {metrics_info: {}}
elif isinstance(metrics_info, list):
iterate_on = {metric_name: {} for metric_name in metrics_info}
elif isinstance(metrics_info, dict):
iterate_on = metrics_info
else:
raise NotImplementedError
metrics = {}
for metric_name, metric_params in iterate_on.items():
# Separate log_params from metric_params
metric_log_params = metric_params.pop("log_params", {})
metrics[metric_name] = {"metric": getattr(torchmetrics, metric_name)(**metric_params), "log_params": metric_log_params}
"""
# To solve OSError: [Errno 24] ---> Too many open files?
# sharing_strategy = "file_system"
# def set_worker_sharing_strategy(worker_id: int) -> None:
# torch.multiprocessing.set_sharing_strategy(sharing_strategy)
# torch.multiprocessing.set_sharing_strategy(sharing_strategy)
# Function to add experiment info to ModelCheckpoint
# def add_exp_info_to_ModelCheckpoint(callbacks_dict, add_to_dirpath):
# new_list = copy.deepcopy(callbacks_dict)
# for MC_index, dc in enumerate(new_list):
# if any([x == "ModelCheckpoint" for x in new_list]):
# break
# new_list[MC_index]["ModelCheckpoint"]["dirpath"] += str(add_to_dirpath)
# return new_list
# Function to express neurons per layers
# def express_neuron_per_layers(cfg_model_cfg, model_cfg):
# # probably not efficient since expressing all possible combinations
# num_neurons = model_cfg["num_neurons"]
# num_layers = model_cfg["num_layers"]
# neurons_per_layer = []
# for layer in num_layers:
# neurons_per_layer += list(it.product(num_neurons, repeat=layer))
# for cfg in [cfg_model_cfg, model_cfg]:
# cfg.pop('num_neurons', None)
# cfg.pop('num_layers', None)
# cfg["neurons_per_layer"] = neurons_per_layer
[docs]
def get_correct_package(name, *modules, raise_error=True):
    """Return the first module in ``modules`` that has an attribute ``name``.

    ``None`` entries are allowed and skipped: callers routinely pass
    ``getattr(obj, "...", None)`` or a ``None`` default for optional
    extension modules.

    Args:
        name: Attribute (function/class) name to look for.
        *modules: Candidate modules, searched in order.
        raise_error: If true, raise when ``name`` is not found; otherwise
            print a warning and return ``None``.

    Returns:
        The first module defining ``name``, or ``None`` (only when
        ``raise_error`` is false and nothing matched).

    Raises:
        NotImplementedError: If ``name`` is not found and ``raise_error`` is true.
    """
    for module in modules:
        if module is not None and hasattr(module, name):
            return module
    # getattr with a fallback: previously a None (or non-module) entry made
    # ``module.__name__`` raise AttributeError instead of this message.
    searched = ", ".join(getattr(module, "__name__", repr(module)) for module in modules if module is not None)
    message = f"The function/class {name} is not found in [{searched}]"
    if raise_error:
        raise NotImplementedError(message)
    print(f"Warning: {message}")
    return None
[docs]
def complete_prepare_trainer(cfg, experiment_id, model_params=None, additional_module=None, raytune=False):
    """Build a fully configured ``pl.Trainer`` from the experiment configuration.

    Args:
        cfg: Experiment configuration; ``cfg["model"]`` is deep-copied and used
            when ``model_params`` is not given.
        experiment_id: Identifier appended to checkpoint/log directories (see
            ``prepare_experiment_id``).
        model_params: Model configuration containing ``"trainer_params"``.
        additional_module: Optional object whose ``callbacks``, ``loggers``,
            ``strategies`` and ``plugins`` attributes (when present) are
            searched before the built-in packages. ``None`` means none.
        raytune: Forwarded to ``prepare_trainer``.

    Returns:
        The ``pl.Trainer`` (wrapped for Ray when ``raytune`` is true).
    """
    if model_params is None:
        model_params = deepcopy(cfg["model"])
    # prepare_experiment_id returns a deep copy, so model_params is not modified below.
    trainer_params = prepare_experiment_id(model_params["trainer_params"], experiment_id)
    trainer_params["callbacks"] = prepare_callbacks(trainer_params, getattr(additional_module, "callbacks", None))
    trainer_params["logger"] = prepare_logger(trainer_params, getattr(additional_module, "loggers", None))
    trainer_params["strategy"] = prepare_strategy(trainer_params, getattr(additional_module, "strategies", None))
    trainer_params["plugins"] = prepare_plugins(trainer_params, getattr(additional_module, "plugins", None))
    return prepare_trainer(**trainer_params, raytune=raytune)
[docs]
def complete_prepare_model(cfg, main_module, *additional_modules, model_params=None):
    """Build the model, resolving its losses, optimizer, scheduler and metrics.

    Args:
        cfg: Experiment configuration; ``cfg["model"]`` is used when
            ``model_params`` is not given.
        main_module: Forwarded to ``process.create_model`` as its first argument.
        *additional_modules: Extra lookup sources. For each one, its
            ``losses``, ``schedulers`` and ``metrics`` attributes are used when
            present, otherwise the object itself.
        model_params: Model configuration to use instead of ``cfg["model"]``.
            It is deep-copied, so the caller's dict is not modified.

    Returns:
        The object returned by ``process.create_model``.
    """
    # model_params used to be silently overwritten by cfg["model"]; honour it,
    # as complete_prepare_trainer does.
    model_params = deepcopy(cfg["model"] if model_params is None else model_params)

    def submodules(attr):
        return [getattr(module, attr, module) for module in additional_modules]

    model_params["loss"] = prepare_loss(model_params["loss"], *submodules("losses"))
    model_params["optimizer"] = prepare_optimizer(**model_params["optimizer"])
    scheduler_info = model_params.get("scheduler")
    if scheduler_info is not None:
        # prepare_scheduler accepts the seed as its second positional argument.
        # Pass it explicitly so the first extra module is not consumed as the seed.
        scheduler_seed = 42
        model_params["scheduler"] = prepare_scheduler(scheduler_info, scheduler_seed, *submodules("schedulers"))
    else:
        # A missing entry is treated like an explicit ``scheduler: None``.
        model_params["scheduler"] = None
    model_params["metrics"] = prepare_metrics(model_params["metrics"], *submodules("metrics"))
    return process.create_model(main_module, **model_params)
# Deprecated
[docs]
def prepare_profiler(trainer_params, additional_module=None, seed=42):
    """Deprecated: replace a dict ``trainer_params["profiler"]`` by a profiler object, in place.

    A dict entry ``{"name": cls_name, "params": {...}}`` (``params`` optional)
    is instantiated. The class is searched in ``additional_module``, the
    project's ``callbacks`` module, ``pl.callbacks``, ``ray.train.lightning``
    and finally ``pl.profilers``. Any other value is left untouched.

    Returns:
        The same ``trainer_params`` dict.
    """
    pl.seed_everything(seed, verbose=False)
    profiler_info = trainer_params.get("profiler")
    if isinstance(profiler_info, dict):
        # Lightning's built-in profilers live in pl.profilers, which the callback
        # lookup never searched. getattr keeps older Lightning versions working.
        profiler_cls = get_function(
            profiler_info["name"],
            additional_module,
            custom_callbacks,
            pl.callbacks,
            ray_lightning,
            getattr(pl, "profilers", None),
        )
        trainer_params["profiler"] = profiler_cls(**(profiler_info.get("params") or {}))
    return trainer_params