Source code for easy_torch.utils

import torch
from typing import List, Tuple, Dict

# Class added because torch.nn.ModuleDict doesn't allow certain keys to be used if they conflict with existing class attributes
# https://github.com/pytorch/pytorch/issues/71203
SUFFIX = "____"
SUFFIX_LENGTH = len(SUFFIX)
[docs] class RobustModuleDict(torch.nn.ModuleDict): """ Torch ModuleDict wrapper that permits keys with any name, including those that would otherwise conflict with class attributes. Torch's `ModuleDict` does not allow certain keys (e.g., 'type', 'forward') because they clash with existing methods or attributes of `nn.Module`, raising errors like `KeyError`. Example: > torch.nn.ModuleDict({'type': torch.nn.Module()}) # Raises KeyError. > RobustModuleDict({'type': torch.nn.Module()}) # Works fine. This class mitigates possible conflicts by using a key-suffixing protocol. Args: init_dict (Dict[str, torch.nn.Module], optional): Initial dictionary of modules. If provided, it initializes the `RobustModuleDict` with these modules. Defaults to None. Returns: None """ def __init__(self, init_dict: Dict[str, torch.nn.Module] = None) -> None: super().__init__() #self.module_dict = torch.nn.ModuleDict() if init_dict is not None: self.update(init_dict) # Retrieve a module using the original (unsuffixed) key. # Internally appends the suffix to avoid naming conflicts. def __getitem__(self, key) -> torch.nn.Module: return super().__getitem__(key + SUFFIX) # Set a module using the original (unsuffixed) key. # Internally appends the suffix to avoid naming conflicts. def __setitem__(self, key: str, module: torch.nn.Module) -> None: super().__setitem__(key + SUFFIX, module) # Return the number of modules in the dictionary. def __len__(self) -> int: return super().__len__() # Return a list of keys without the suffix. # This allows users to access the original keys.
[docs] def keys(self) -> List[str]: return [key[:-SUFFIX_LENGTH] for key in super().keys()]
# Return a list of modules.
[docs] def values(self) -> List[torch.nn.Module]: return list(super().values())
# def values(self) -> List[torch.nn.Module]: # return [module for _, module in self.items()] # Return a list of (key, module) tuples without the suffix in the keys. # This allows users to access both the original keys and their corresponding modules.
[docs] def items(self) -> List[Tuple[str, torch.nn.Module]]: return [ (key[:-SUFFIX_LENGTH], module) for key, module in super().items() ]
# Update the dictionary with a new set of modules.
[docs] def update(self, modules: Dict[str, torch.nn.Module]) -> None: for key, module in modules.items(): self[key] = module
# def __next__(self) -> None: # return next(iter(self)) # Allow iteration over the keys of the dictionary. def __iter__(self): return iter(self.keys()) # Check if a key exists in the dictionary. def __contains__(self, key: str) -> bool: return super().__contains__(key + SUFFIX)