From 3b42a4ac9cfcf5bcda10db8c08430e0116252536 Mon Sep 17 00:00:00 2001 From: Doctor G Date: Thu, 26 Feb 2026 18:44:34 +0000 Subject: [PATCH] Add early stopping support to LightningModel Add early_stopping, patience, monitor, and mode to default_train_params and wire up an EarlyStopping callback in train_on_dataset when enabled. Co-authored-by: Ona --- src/grelu/lightning/__init__.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/grelu/lightning/__init__.py b/src/grelu/lightning/__init__.py index a1e3473..7159fde 100644 --- a/src/grelu/lightning/__init__.py +++ b/src/grelu/lightning/__init__.py @@ -16,7 +16,7 @@ import pytorch_lightning as pl import torch from einops import rearrange -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import CSVLogger, WandbLogger from torch import Tensor, nn, optim from torch.utils.data import DataLoader @@ -57,6 +57,10 @@ "class_weights": None, "total_weight": None, "accumulate_grad_batches": 1, + "early_stopping": False, + "patience": 5, + "monitor": "val_loss", + "mode": "min", } @@ -545,6 +549,16 @@ def train_on_dataset( else: raise Exception("Checkpoint type must be a bool or dict") + # Early stopping + if self.train_params["early_stopping"]: + checkpoint_callbacks.append( + EarlyStopping( + monitor=self.train_params["monitor"], + patience=self.train_params["patience"], + mode=self.train_params["mode"], + ) + ) + # Get device accelerator, devices = self.parse_devices(self.train_params["devices"])