Source code for easy_torch.callbacks

import time
import pytorch_lightning as pl
import platform, subprocess, time
import numpy as np
import torch

class TimeCallback(pl.callbacks.Callback):
    """Log the duration of every train / validation / test epoch.

    Durations (in seconds) are logged through ``pl_module.log`` under the keys
    ``"train_time"``, ``"val_time"`` and ``"test_time"``.

    Args:
        log_params: Extra keyword arguments forwarded to every
            ``pl_module.log`` call (e.g. ``{"prog_bar": True}``).
            ``None`` (the default) means no extra arguments.

    Attributes:
        elapsed_time: Duration in seconds of the most recently finished epoch
            of any split, or ``None`` before the first one finishes.
    """

    def __init__(self, log_params=None):
        super().__init__()
        # ``None`` default instead of ``{}`` avoids a shared mutable default.
        self.log_params = {} if log_params is None else log_params
        # One start time per split. Validation runs *inside* the training
        # epoch (between on_train_epoch_start and on_train_epoch_end), so a
        # single shared timer would make "train_time" measure only from the
        # start of validation.
        self._start_times = {}
        self.elapsed_time = None

    # NOTE: these helpers must not be called ``on_epoch_start`` /
    # ``on_epoch_end``: Lightning treats those as (legacy) hook names and
    # invokes or rejects them itself.
    def _start_timer(self, split_name):
        """Record the start of an epoch for *split_name*."""
        # perf_counter is monotonic, so durations are immune to clock changes.
        self._start_times[split_name] = time.perf_counter()

    def _stop_timer(self, split_name, pl_module):
        """Measure the epoch that just ended for *split_name* and log it."""
        start = self._start_times.pop(split_name, None)
        if start is None:
            # End hook without a matching start hook: nothing to measure.
            return
        self.elapsed_time = time.perf_counter() - start
        # Callbacks have no ``log`` method of their own; log via the module.
        pl_module.log(f"{split_name}_time", self.elapsed_time, **self.log_params)

    def on_train_epoch_start(self, trainer, pl_module):
        self._start_timer("train")

    def on_train_epoch_end(self, trainer, pl_module):
        self._stop_timer("train", pl_module)

    def on_validation_epoch_start(self, trainer, pl_module):
        self._start_timer("val")

    def on_validation_epoch_end(self, trainer, pl_module):
        self._stop_timer("val", pl_module)

    def on_test_epoch_start(self, trainer, pl_module):
        self._start_timer("test")

    def on_test_epoch_end(self, trainer, pl_module):
        self._stop_timer("test", pl_module)
class TemperatureSlowdownCallback(pl.callbacks.Callback):
    """Pause training while the hottest monitored GPU exceeds a threshold.

    Every ``every_n_epochs`` validation epochs, GPU temperatures are read
    with ``nvidia-smi``. While the hottest selected GPU is above
    ``threshold``, the callback sleeps ``sleep_time`` seconds and re-checks.
    Failures while reading temperatures are printed and otherwise ignored
    (best effort: the check never aborts training).

    Args:
        threshold: Temperature, in the units nvidia-smi reports for
            ``temperature.gpu`` (°C), above which training is paused.
        sleep_time: Seconds to sleep between re-checks.
        every_n_epochs: Run the check on every n-th validation epoch
            (the pre-fit sanity check is not counted). Must be >= 1.
        devices: Index, list of indices or slice selecting GPUs, in the
            order nvidia-smi lists them.
        nvidia_smi_path: Path to the nvidia-smi executable. Used on Windows
            only; on Linux ``nvidia-smi`` is resolved from ``PATH``.

    Raises:
        ValueError: if ``every_n_epochs`` < 1, or the OS is neither Windows
            nor Linux.
    """

    _QUERY_ARGS = ("--query-gpu=temperature.gpu", "--format=csv,noheader,nounits")

    def __init__(self, threshold=80, sleep_time=10, every_n_epochs=5, devices=slice(None),
                 nvidia_smi_path=r"C:\Program Files\NVIDIA Corporation\NVSMI\nvidia-smi.exe"):
        super().__init__()
        if every_n_epochs < 1:
            # Fail fast instead of a ZeroDivisionError at the first validation.
            raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}")
        self.epoch = 0
        self.threshold = threshold
        self.sleep_time = sleep_time
        self.every_n_epochs = every_n_epochs
        self.devices = devices

        system = platform.system()
        if system == "Windows":
            executable = nvidia_smi_path
        elif system == "Linux":
            executable = "nvidia-smi"
        else:
            raise ValueError(f"Unsupported OS: {system}. Only Windows and Linux are supported.")
        # Argument list run without a shell: paths containing spaces need no
        # quoting and no shell process is spawned per check.
        self.command = [executable, *self._QUERY_ARGS]

    def _read_temperatures(self):
        """Return one reading per GPU line of nvidia-smi output.

        Non-numeric readings (e.g. ``[N/A]``) become ``None`` instead of being
        dropped, so list positions keep matching GPU indices.
        """
        output = subprocess.check_output(
            self.command, text=True, stderr=subprocess.PIPE, timeout=15
        )
        readings = [line.strip() for line in output.strip().splitlines()]
        return [int(r) if r.isdigit() else None for r in readings]

    def _max_temperature(self):
        """Return the highest readable temperature among ``self.devices``.

        Raises:
            RuntimeError: if none of the selected GPUs reported a number.
        """
        all_temps = np.array(self._read_temperatures(), dtype=object)
        # atleast_1d: ``devices`` may be a single int selecting one GPU.
        selected = np.atleast_1d(all_temps[self.devices])
        valid = [t for t in selected if t is not None]
        if not valid:
            raise RuntimeError("nvidia-smi reported no readable temperature for the selected GPUs")
        return max(valid)

    def on_validation_epoch_start(self, trainer, pl_module):
        # The pre-fit sanity-check validation is not a real epoch.
        if getattr(trainer, "sanity_checking", False):
            return
        self.epoch += 1
        if self.epoch % self.every_n_epochs != 0:
            return
        try:
            # Keep sleeping until the hottest selected GPU cools down.
            while True:
                max_temp = self._max_temperature()
                if max_temp <= self.threshold:
                    break
                print(f"GPU temperature {max_temp}°C exceeds threshold {self.threshold}°C. Sleeping for {self.sleep_time} seconds.")
                time.sleep(self.sleep_time)
        except Exception as e:
            # Deliberately best-effort: a failed temperature check (missing
            # nvidia-smi, timeout, bad device index, ...) must not stop training.
            print(f"Error checking GPU temperature: {e}")
class TerminateOnNaNCallback(pl.callbacks.Callback):
    """Request a training stop as soon as a training step yields a NaN loss.

    The loss is read from ``outputs["loss"]`` when the training step returned
    a dict, otherwise ``outputs`` itself is used. Batches with no loss
    (``None``) are ignored.
    """

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        loss = outputs.get("loss") if isinstance(outputs, dict) else outputs
        if loss is None:
            return
        # as_tensor also accepts plain Python numbers (returns tensors as-is);
        # .any() avoids the "ambiguous boolean" error for non-scalar losses.
        if torch.isnan(torch.as_tensor(loss)).any():
            print(f"NaN loss detected at batch {batch_idx}. Stopping training.")
            trainer.should_stop = True