PyTorch Early Stopping
Basic Early Stopping Implementation in PyTorch
import torch
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the latest validation loss. It
    returns ``True`` once ``patience`` consecutive calls have failed to beat
    the best loss seen so far by more than ``min_delta``.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001):
        """
        Args:
            patience (int): How many epochs to wait after last time validation loss improved.
            min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')  # Lowest validation loss seen so far
        self.counter = 0  # Calls since the last improvement

    def __call__(self, val_loss) -> bool:
        """Record one validation loss and report whether training should stop.

        Args:
            val_loss: Validation loss for the epoch just finished. Any value
                accepted by ``float()`` works.

        Returns:
            bool: ``True`` when the no-improvement counter has reached
            ``patience``, otherwise ``False``.
        """
        # Store a plain float so best_loss never keeps the caller's object
        # (for example a tensor) alive between epochs.
        val_loss = float(val_loss)
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0  # Reset patience counter
        else:
            # A NaN loss also lands here: every comparison with NaN is False.
            self.counter += 1  # Increase counter if no improvement
        return self.counter >= self.patience  # True if early stopping condition met
# Example Usage in a Training Loop
# Stop once the validation loss has gone 5 epochs without improving by
# more than 0.001.
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

for epoch in range(100):  # Example max epochs
    train_loss = train_model()  # Your training function
    val_loss = validate_model()  # Your validation function
    print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")

    # The callable returns True once patience has run out.
    if early_stopping(val_loss):
        print("Early stopping triggered!")
        break
Early Stopping with Model Checkpointing & Best Weights Restoration
import torch
import os
class EarlyStopping:
    """Early stopping that checkpoints the best model and restores it on stop.

    Call the instance once per epoch with the validation loss and the model.
    Each improvement of more than ``min_delta`` saves the model's
    ``state_dict`` to ``save_path``. When ``patience`` consecutive calls bring
    no improvement, the saved weights are loaded back into the model and the
    call returns ``True``.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001,
                 save_path: str = "best_model.pth"):
        """
        Args:
            patience (int): How many epochs to wait after the last improvement.
            min_delta (float): Minimum improvement to reset patience.
            save_path (str): Path to save the best model.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.save_path = save_path
        self.best_loss = float("inf")
        self.counter = 0
        self.best_model = None  # Not used by this class; kept for existing callers
        # Set once __call__ has written a checkpoint, so a stale file left at
        # save_path by an earlier run is never mistaken for this run's best.
        self._checkpoint_saved = False

    def __call__(self, val_loss, model) -> bool:
        """Record one validation loss and report whether training should stop.

        Args:
            val_loss: Validation loss for the epoch just finished. Any value
                accepted by ``float()`` works.
            model: Object exposing ``state_dict()`` and ``load_state_dict()``.

        Returns:
            bool: ``True`` when the no-improvement counter has reached
            ``patience`` (best weights are restored first if this instance
            saved any), otherwise ``False``.
        """
        # Store a plain float so best_loss never keeps the caller's object
        # (for example a tensor) alive between epochs.
        val_loss = float(val_loss)
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)  # Save best model
            self._checkpoint_saved = True
        else:
            self.counter += 1

        if self.counter >= self.patience:
            print("Early stopping triggered!")
            # Only roll back to weights written during this run.
            if self._checkpoint_saved:
                self.restore_checkpoint(model)
            return True
        return False

    def save_checkpoint(self, model):
        """Save the model's state_dict."""
        # torch.save does not create missing directories itself.
        parent_dir = os.path.dirname(self.save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(model.state_dict(), self.save_path)
        print(f"Checkpoint saved at {self.save_path}")

    def restore_checkpoint(self, model):
        """Load the best model's state_dict, if a checkpoint file exists."""
        if os.path.exists(self.save_path):
            model.load_state_dict(torch.load(self.save_path))
            print("Restored best model weights.")
# Example Training Loop Usage
def train_model() -> float:
    """Return a simulated training loss: one standard-normal sample plus 1."""
    return torch.randn(1).item() + 1  # Simulated train loss
def validate_model() -> float:
    """Return a simulated validation loss: one standard-normal sample."""
    return torch.randn(1).item()  # Simulated validation loss
model = torch.nn.Linear(10, 1)  # Example model
early_stopping = EarlyStopping(patience=3, min_delta=0.01)

for epoch in range(100):
    train_loss = train_model()
    val_loss = validate_model()
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Checkpoints the model on improvement and returns True once patience
    # has run out.
    if early_stopping(val_loss, model):
        break
Last updated
Was this helpful?