Custom Callbacks
This guide explains how to define a custom callback class and make it
discoverable by the training library through the additional_module argument of
easy_torch.preparation.complete_prepare_trainer.
1. Folder Structure
Make sure your project is organized as follows:
project_root/
├── ntb/
│   └── your_notebook.ipynb
└── src/
    ├── __init__.py
    └── my_additional_module/
        ├── __init__.py
        └── callbacks.py
The my_additional_module folder will contain all your custom components,
including callbacks.
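If the folders do not exist yet, the layout above can be scaffolded with a few lines of standard-library Python. This is only a convenience sketch: the helper name create_skeleton is hypothetical and is not part of easy_torch.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def create_skeleton(project_root: Path) -> list[Path]:
    """Create the folders and empty files of the layout; return the file paths."""
    files = [
        project_root / "src" / "__init__.py",
        project_root / "src" / "my_additional_module" / "__init__.py",
        project_root / "src" / "my_additional_module" / "callbacks.py",
    ]
    # The notebook folder holds no required files, so only the directory is made.
    (project_root / "ntb").mkdir(parents=True, exist_ok=True)
    for file_path in files:
        file_path.parent.mkdir(parents=True, exist_ok=True)
        file_path.touch(exist_ok=True)  # existing content is left untouched
    return files


if __name__ == "__main__":
    # Demonstrate in a throwaway directory rather than the real project.
    with tempfile.TemporaryDirectory() as tmp:
        for created in create_skeleton(Path(tmp) / "project_root"):
            print(created.relative_to(tmp))
```

The files are created empty; their contents are filled in by the following steps.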
2. Define Your Callback
Inside src/my_additional_module/callbacks.py, define your custom callback
class. The class must inherit from pytorch_lightning.callbacks.Callback (or the
equivalent base callback class used by your framework).
# src/my_additional_module/callbacks.py
import pytorch_lightning as pl


class CustomCallback(pl.callbacks.Callback):
    """Prints the learning rate at the end of each training epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        # Read the learning rate of the first parameter group of the first optimizer.
        lr = trainer.optimizers[0].param_groups[0]["lr"]
        print(f"Current learning rate: {lr:.6f}")
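A callback that is meant to receive parameters from the configuration needs a constructor that accepts them. The variant below is a minimal sketch of that idea. The class name ConfigurableLRCallback and its parameters every_n_epochs and precision are hypothetical, and the try/except fallback is there only so the sketch can run where Lightning is not installed.

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    # Fallback so the sketch can run without Lightning installed.
    Callback = object


class ConfigurableLRCallback(Callback):
    """Prints the learning rate every `every_n_epochs` epochs (hypothetical example)."""

    def __init__(self, every_n_epochs: int = 1, precision: int = 6):
        super().__init__()
        self.every_n_epochs = every_n_epochs
        self.precision = precision

    def on_train_epoch_end(self, trainer, pl_module):
        # Skip epochs that are not a multiple of the configured interval.
        if (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return
        lr = trainer.optimizers[0].param_groups[0]["lr"]
        print(f"Epoch {trainer.current_epoch}: learning rate {lr:.{self.precision}f}")
```

Giving every constructor argument a default keeps the callback usable even when the configuration supplies no parameters.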
3. Package Initialization
Ensure the following minimal __init__.py files exist:
# src/__init__.py
from . import my_additional_module
# src/my_additional_module/__init__.py
from . import callbacks
from .callbacks import CustomCallback # optional
4. Make the Package Importable
If your notebook or script is located in ntb/, add the project root
directory to Python’s module search path at the top of your notebook. The
snippet assumes the current working directory is ntb/:
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
Then import your module:
from src import my_additional_module
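The relative ".." above breaks if the notebook is started from a different working directory. A slightly more defensive sketch walks upward until it finds the folder that contains src/. The helper name add_project_root and its marker parameter are hypothetical, not part of easy_torch.

```python
import sys
from pathlib import Path


def add_project_root(start: Path, marker: str = "src") -> Path:
    """Walk upward from `start` until a directory containing `marker` is found,
    put that directory on sys.path (only once), and return it."""
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).is_dir():
            root = str(candidate)
            if root not in sys.path:
                sys.path.insert(0, root)
            return candidate
    raise FileNotFoundError(f"No '{marker}' directory found at or above {start}")


# In a notebook under ntb/, the call would typically be:
# add_project_root(Path.cwd())
```

Checking for an existing entry avoids growing sys.path each time the notebook cell is re-run.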
5. Reference the Callback in Your Config
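In your configuration file or dictionary, specify the callback by its class
name. The nested dictionary holds the callback's parameters; param1 is a
placeholder here, and the CustomCallback from step 2 would need an __init__
that accepts it if the parameters are passed to the constructor: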
cfg = {
    "model": {
        "trainer_params": {
            "callbacks": [
                {"CustomCallback": {"param1": "value1"}}
            ]
        }
    }
}
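How easy_torch turns the name in the config into an object is not described in this guide. One plausible mechanism, assumed here purely for illustration, is an attribute lookup on the additional module followed by instantiation with the nested dictionary as keyword arguments. The function build_callbacks and the stand-in module below are hypothetical and do not reproduce the library's code.

```python
from types import ModuleType


def build_callbacks(callback_cfg: list, module: ModuleType) -> list:
    """Instantiate each {"ClassName": {kwargs}} entry by looking the class up on `module`."""
    instances = []
    for entry in callback_cfg:
        for name, params in entry.items():
            cls = getattr(module, name, None)
            if cls is None:
                raise AttributeError(f"{module.__name__} does not expose '{name}'")
            # Assumption: the nested dict is forwarded as constructor kwargs.
            instances.append(cls(**(params or {})))
    return instances


# Stand-in for my_additional_module so the sketch runs on its own.
demo_module = ModuleType("my_additional_module")


class CustomCallback:
    def __init__(self, param1=None):
        self.param1 = param1


demo_module.CustomCallback = CustomCallback

callbacks = build_callbacks([{"CustomCallback": {"param1": "value1"}}], demo_module)
print(callbacks[0].param1)  # prints "value1"
```

Under this assumption, the class has to be reachable as an attribute of the module passed in, which is what the imports in the package's __init__.py files provide.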
6. Initialize the Trainer
When creating your trainer, pass the additional module that includes the callbacks:
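import easy_torch
from src import my_additional_module

# cfg is the configuration from step 5; experiment_id is assumed to be defined earlier.
trainer = easy_torch.preparation.complete_prepare_trainer(
    cfg,
    experiment_id,
    additional_module=my_additional_module,
)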