Custom Metrics
1. How to Define Metrics in model.yaml
You can define metrics in two ways:
Simple format
metrics:
- MetricName
Advanced format (for metrics that return dictionaries)
metrics:
- FakeMetricCollection:
    metric_class: MetricName
If your metric returns a dictionary (e.g., multiple values like {"ndcg@5": ..., "ndcg@10": ...}), you need to wrap it using FakeMetricCollection to avoid issues with Lightning and TorchMetrics.
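To make "returns a dictionary" concrete, here is a minimal, dependency-free sketch of such a metric. The class name NDCGAtK and its top_k parameter are hypothetical and used only for illustration. The class mirrors the update()/compute() shape of a TorchMetrics metric but does not subclass torchmetrics.Metric, so a real metric would need adapting.

```python
import math


class NDCGAtK:
    """Sketch of a metric whose compute() returns a dict (hypothetical class).

    Mirrors the update()/compute() shape of a TorchMetrics metric,
    but is plain Python and does not subclass torchmetrics.Metric.
    """

    def __init__(self, top_k=(5, 10)):
        self.top_k = tuple(top_k)
        self.totals = {k: 0.0 for k in self.top_k}
        self.count = 0

    @staticmethod
    def _dcg(relevances):
        # Discounted cumulative gain: rel_i / log2(i + 2) for 0-based rank i.
        return sum(rel / math.log2(i + 2) for i, rel in enumerate(relevances))

    def update(self, scores, relevance):
        # Rank items by predicted score, highest first.
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        ranked = [relevance[i] for i in order]
        ideal = sorted(relevance, reverse=True)
        for k in self.top_k:
            ideal_dcg = self._dcg(ideal[:k])
            if ideal_dcg > 0:
                self.totals[k] += self._dcg(ranked[:k]) / ideal_dcg
        self.count += 1

    def compute(self):
        # One entry per cutoff; a dict-shaped result like this is what
        # the Advanced format is meant for.
        return {f"ndcg@{k}": self.totals[k] / max(self.count, 1) for k in self.top_k}


metric = NDCGAtK(top_k=(5, 10))
metric.update(scores=[0.9, 0.2, 0.7], relevance=[1, 0, 1])
result = metric.compute()
print(result)
```

A metric shaped like this produces one value per key, which is the case the Advanced format is meant for.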
2. How to Load Metrics in Your Code
In your main.py, use the following:
from src import your_module
from easy_torch.preparation import prepare_metrics
metrics = prepare_metrics(cfg["model"]["metrics"], your_module)
The function will try to locate and load each metric class by searching in this order:
your_module, then easy_lightning.metrics, then torchmetrics
If a class is not found, an error will be raised.
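The lookup order described above can be sketched as a first-match search over an ordered list of modules. This is an illustration only, not the actual implementation of prepare_metrics. The helper resolve_class, the stand-in modules, and the choice of AttributeError are all assumptions.

```python
import types


def resolve_class(name, modules):
    """Return the first attribute called `name` found in `modules`, in order.

    Hypothetical helper: the real prepare_metrics may differ.
    """
    for module in modules:
        if hasattr(module, name):
            return getattr(module, name)
    searched = ", ".join(m.__name__ for m in modules)
    # The exception type is an assumption; the source only says an error is raised.
    raise AttributeError(f"Metric '{name}' not found in: {searched}")


# Stand-in modules, built only to show how the ordering behaves.
user_module = types.ModuleType("your_module")
library_module = types.ModuleType("library_metrics")


class UserNDCG:
    pass


class LibraryNDCG:
    pass


class LibraryAccuracy:
    pass


user_module.NDCG = UserNDCG
library_module.NDCG = LibraryNDCG
library_module.Accuracy = LibraryAccuracy

search_order = [user_module, library_module]

# Defined in both modules: the earlier module in the list wins.
ndcg_cls = resolve_class("NDCG", search_order)

# Defined only in the later module: the search falls through to it.
accuracy_cls = resolve_class("Accuracy", search_order)

# Defined nowhere: an error is raised.
try:
    resolve_class("Missing", search_order)
    missing_raised = False
except AttributeError:
    missing_raised = True
```

Under this sketch, a class defined in your own module shadows a class of the same name in the later modules, which is why the user module is searched first.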
Loading from multiple modules
metrics = prepare_metrics(cfg["model"]["metrics"], [your_module, another_module])
3. Using the Metrics in the Standard Pipeline
If you’re using the standard EasyLightning pipeline, you can inject the metric list or collection into:
model_params["metrics"]
This ensures that all metrics are correctly used and logged during training and evaluation.
Additional Notes
For additional details about how inputs and outputs are handled internally, see the step_routing section.