import torch.nn.functional as F
import torchmetrics
import torch
# Custom Accuracy to compute accuracy with Soft Labels as a torchmetrics.Metric
class SoftLabelsAccuracy(torchmetrics.Metric):
    """Classification accuracy for targets given as per-class scores (soft labels).

    Predictions and targets are both reduced to a class index with ``argmax``
    along dimension 1; a position counts as correct when the two indices agree.
    """

    # Metadata consumed by torchmetrics utilities: argmax carries no useful
    # gradient, and a larger accuracy is better.
    is_differentiable = False
    higher_is_better = True

    def __init__(self):
        super().__init__()
        # Running counters; both are reduced with "sum" when states are synced.
        self.add_state("correct", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, input: torch.Tensor, target: torch.Tensor):
        """Accumulate the matches between ``input`` and ``target`` for one batch.

        Args:
            input: Predicted class scores; the class axis is dimension 1.
            target: Soft labels; the class axis is dimension 1.
        """
        matches = input.argmax(dim=1) == target.argmax(dim=1)
        self.correct += matches.sum()
        # Count every compared position rather than only the leading dimension,
        # so inputs shaped (N, C, ...) cannot yield an accuracy above 1. For
        # plain (N, C) inputs this equals ``target.shape[0]``.
        self.total += matches.numel()

    def compute(self):
        """Return the ratio of correct predictions to positions seen so far."""
        return self.correct.float() / self.total
# Function to compute accuracy for neural network predictions
# def nn_accuracy(y_hat, y):
# # Apply softmax to predictions and get the class with the highest probability
# soft_y_hat = F.softmax(y_hat).argmax(dim=-1)
# soft_y = y.argmax(dim=-1)
# # Calculate accuracy by comparing predicted and actual class labels
# acc = (soft_y_hat.int() == soft_y.int()).float().mean()
# return acc
# Custom Accuracy to compute accuracy with Soft Labels as a torch.Module
# class SoftLabelsAccuracy(torch.nn.Module):
# def __init__(self):
# super().__init__()
# def forward(self, preds: torch.Tensor, target: torch.Tensor):
# # Calculate accuracy by comparing predicted and actual class labels
# return (preds.argmax(dim=1) == target.argmax(dim=1)).float().mean()
class BatchLength(torchmetrics.Metric):
    """Track the average, maximum and minimum size of one batch dimension.

    Every ``update`` records ``batch.shape[batch_dim]``; ``compute`` reports the
    statistics in a dictionary whose keys are listed in ``out_keys``.

    Args:
        batch_dim (int): Index of the dimension whose size is measured.
    """

    def __init__(self, batch_dim=1):
        super().__init__()
        # Keys of the dictionary returned by ``compute``.
        self.out_keys = ["avg", "max", "min"]
        self.batch_dim = batch_dim
        self.add_state("total", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0.), dist_reduce_fx="sum")
        # Start from +/- infinity so the first observed length always wins.
        self.add_state("min", default=torch.tensor(float("inf")), dist_reduce_fx="min")
        self.add_state("max", default=torch.tensor(float("-inf")), dist_reduce_fx="max")

    def update(self, batch: torch.Tensor, *args, **kwargs):
        """Record the size of ``batch`` along ``batch_dim``.

        Args:
            batch (torch.Tensor): Batch whose shape is inspected; its values
                are not read.
            *args: Ignored.
            **kwargs: Ignored.
        """
        # Build the scalar with the states' dtype and device so the
        # element-wise min/max below never mix devices.
        length = self.total.new_tensor(float(batch.shape[self.batch_dim]))
        self.total += length
        self.min = torch.minimum(self.min, length)
        self.max = torch.maximum(self.max, length)
        self.count += 1

    def compute(self):
        """Return a dict with the ``"min"``, ``"max"`` and ``"avg"`` lengths.

        ``"avg"`` is the summed length divided by the number of updates.
        """
        return {
            "min": self.min,
            "max": self.max,
            "avg": self.total / self.count,
        }
class FakeMetricCollection(torchmetrics.MetricCollection):
    """Expose a dict-returning metric as one collection entry per output key.

    Only the entry for the first key (in sorted order) wraps a real instance of
    ``metric_class`` and therefore runs ``update`` / ``compute``. The remaining
    entries are ``FakeMetric`` proxies that read their value back from that
    instance.

    Args:
        metric_class: Metric class whose ``compute`` returns a dictionary.
        keys_name (str): Name of the attribute on a ``metric_class`` instance
            that lists the output keys.
        *args: Positional arguments forwarded to ``metric_class``.
        **kwargs: Keyword arguments forwarded to ``metric_class``.

    Raises:
        ValueError: If the metric declares no output keys.
    """

    def __init__(self, metric_class, keys_name="out_keys", *args, **kwargs):
        # Throwaway instance, used only to discover the output keys.
        probe = metric_class(*args, **kwargs)
        # Normalise to strings before sorting: the names are used both as
        # collection keys and as attribute names on the primary metric.
        # NOTE(review): the proxies depend on the primary entry being computed
        # first; this assumes the collection visits entries in sorted-name
        # order -- confirm against the installed torchmetrics version.
        keys = sorted(str(k) for k in getattr(probe, keys_name))
        if not keys:
            raise ValueError(
                f"{metric_class.__name__}.{keys_name} must list at least one key"
            )

        primary_key, *proxy_keys = keys
        primary = make_fake_class(metric_class)(primary_key, *args, **kwargs)
        metrics = {primary_key: primary}
        for key in proxy_keys:
            metrics[key] = FakeMetric(primary, key=key)
        super().__init__(metrics)
def make_fake_class(base_class):
    """Build a subclass of ``base_class`` that reports a single output key.

    ``base_class.compute`` must return a mapping. The generated subclass stores
    every entry of that mapping on the instance, as an attribute named after
    the stringified key, and returns only the entry selected at construction.

    Args:
        base_class: Class whose ``compute`` returns a dictionary.

    Returns:
        A subclass whose constructor takes the primary key as its first
        argument and forwards all remaining arguments to ``base_class``.
    """

    class FakeTrueMetric(base_class):
        def __init__(self, key, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.primary_key = str(key)

        def compute(self, *args, **kwargs):
            out = super().compute(*args, **kwargs)
            # Index by stringified key so the lookup below agrees with
            # ``primary_key`` even when the original keys are not strings.
            by_name = {str(key): value for key, value in out.items()}
            for name, value in by_name.items():
                setattr(self, name, value)
            return by_name.get(self.primary_key)

    return FakeTrueMetric
class FakeMetric(torchmetrics.Metric):
    """Stateless proxy that reports one attribute stored on another metric.

    The wrapped metric is expected to save its per-key results as attributes
    during its own ``compute`` (as the classes built by ``make_fake_class``
    do); this proxy only reads the attribute back.

    Args:
        true_metric: Metric instance that performs the real computation.
        key: Name of the attribute on ``true_metric`` to report.
        *args: Positional arguments forwarded to ``torchmetrics.Metric``.
        **kwargs: Keyword arguments forwarded to ``torchmetrics.Metric``.
    """

    def __init__(self, true_metric, key, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.true_metric = true_metric
        self.key = key

    def update(self, *args, **kwargs):
        """Do nothing; all arguments are ignored."""
        pass

    def compute(self, *args, **kwargs):
        """Return ``true_metric``'s attribute named by ``key``, or ``None`` if unset."""
        # Attribute names must be strings; stringify so a non-string key
        # resolves instead of raising ``TypeError``.
        return getattr(self.true_metric, str(self.key), None)