PyTorch Early Stopping

Basic Early Stopping Implementation in PyTorch

import torch

class EarlyStopping:
    """Signal when a monitored validation metric has stopped improving.

    Call the instance once per epoch with the latest validation value. It
    returns True once ``patience`` consecutive calls have passed without an
    improvement of more than ``min_delta`` over the best value seen so far.
    """

    def __init__(self, patience=5, min_delta=0.001, mode="min"):
        """
        Args:
            patience (int): Number of consecutive calls without improvement
                after which early stopping is signalled.
            min_delta (float): Minimum change in the monitored quantity that
                counts as an improvement.
            mode (str): "min" if lower values are better (e.g. a loss),
                "max" if higher values are better (e.g. an accuracy).

        Raises:
            ValueError: If ``mode`` is neither "min" nor "max".
        """
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        # Start from the worst possible value so the first call always counts
        # as an improvement.
        self.best_loss = float("inf") if mode == "min" else float("-inf")
        self.counter = 0  # consecutive calls without improvement

    def _is_improvement(self, value):
        """Return True if ``value`` beats ``best_loss`` by more than ``min_delta``."""
        if self.mode == "min":
            return value < self.best_loss - self.min_delta
        return value > self.best_loss + self.min_delta

    def __call__(self, val_loss):
        """Record one validation result and report whether to stop.

        Args:
            val_loss: The monitored value for this epoch, either a Python
                number or a single-element tensor.

        Returns:
            bool: True once ``patience`` consecutive calls have failed to
            improve on the best value, otherwise False.
        """
        # Convert to a plain float so that a tensor is not retained in
        # ``best_loss``. Keeping the tensor would also keep its device memory
        # and any autograd graph alive.
        value = float(val_loss)
        if self._is_improvement(value):
            self.best_loss = value
            self.counter = 0  # reset patience on improvement
        else:
            self.counter += 1
        return self.counter >= self.patience

# Example usage in a training loop. ``train_model`` and ``validate_model`` are
# placeholders for your own training and validation functions.
early_stopping = EarlyStopping(min_delta=0.001, patience=5)

for epoch in range(100):  # upper bound on the number of epochs
    train_loss = train_model()
    val_loss = validate_model()

    print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")

    # Keep going to the next epoch unless the stopper reports its condition.
    if not early_stopping(val_loss):
        continue

    print("Early stopping triggered!")
    break

Early Stopping with Model Checkpointing & Best Weights Restoration

Last updated