Snips & Tips
Pytorch Early Stopping

Basic Early Stopping Implementation in PyTorch

Early stopping halts training once the monitored validation loss has not improved by at least min_delta for patience consecutive epochs, which helps prevent overfitting and wasted compute.

import torch

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.001):
        """
        Args:
            patience (int): How many epochs to wait after last time validation loss improved.
            min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0  # Reset patience counter
        else:
            self.counter += 1  # Increase counter if no improvement

        return self.counter >= self.patience  # Return True if early stopping condition met

# Example Usage in a Training Loop
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

for epoch in range(100):  # Example max epochs
    train_loss = train_model()  # Your training function
    val_loss = validate_model()  # Your validation function

    print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")

    if early_stopping(val_loss):
        print("Early stopping triggered!")
        break
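The train_model and validate_model calls above are placeholders for your own loops. As a self-contained sketch (the linear model, synthetic regression data, and hyperparameters are illustrative assumptions, not from the original), here is the same EarlyStopping class driving a small real training loop:

```python
import torch
import torch.nn as nn

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

# Synthetic regression data: y = 3x + noise (illustrative)
torch.manual_seed(0)
X_train, X_val = torch.randn(200, 1), torch.randn(50, 1)
y_train = 3 * X_train + 0.1 * torch.randn(200, 1)
y_val = 3 * X_val + 0.1 * torch.randn(50, 1)

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = nn.MSELoss()
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

for epoch in range(200):
    # One full-batch training step
    model.train()
    optimizer.zero_grad()
    train_loss = loss_fn(model(X_train), y_train)
    train_loss.backward()
    optimizer.step()

    # Validation pass without gradient tracking
    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(X_val), y_val).item()

    if early_stopping(val_loss):
        print(f"Stopped at epoch {epoch+1}, best val loss {early_stopping.best_loss:.4f}")
        break
```

Once the loss plateaus near the noise floor, improvements fall below min_delta and training stops well before the 200-epoch cap.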

Early Stopping with Model Checkpointing & Best Weights Restoration

import torch
import os

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.001, save_path="best_model.pth"):
        """
        Args:
            patience (int): How many epochs to wait after the last improvement.
            min_delta (float): Minimum improvement to reset patience.
            save_path (str): Path to save the best model.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.save_path = save_path
        self.best_loss = float("inf")
        self.counter = 0
        self.best_model = None

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)  # Save best model
        else:
            self.counter += 1

        if self.counter >= self.patience:
            print("Early stopping triggered!")
            self.restore_checkpoint(model)
            return True
        
        return False

    def save_checkpoint(self, model):
        """Save the model's state_dict."""
        torch.save(model.state_dict(), self.save_path)
        print(f"Checkpoint saved at {self.save_path}")

    def restore_checkpoint(self, model):
        """Load the best model's state_dict."""
        if os.path.exists(self.save_path):
            model.load_state_dict(torch.load(self.save_path, weights_only=True))  # weights_only avoids unpickling arbitrary objects
            print("Restored best model weights.")

# Example Training Loop Usage
def train_model():
    return torch.randn(1).item() + 1  # Simulated train loss

def validate_model():
    return torch.randn(1).item()  # Simulated validation loss

model = torch.nn.Linear(10, 1)  # Example model
early_stopping = EarlyStopping(patience=3, min_delta=0.01)

for epoch in range(100):
    train_loss = train_model()
    val_loss = validate_model()

    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    if early_stopping(val_loss, model):
        break
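Both classes above assume a quantity to minimize. A common variant (sketched here as an assumption, not part of the original snippets) adds a mode flag so the same class can also monitor a metric to maximize, such as validation accuracy:

```python
class EarlyStopping:
    """Early stopping that monitors a quantity to minimize ('min') or maximize ('max')."""

    def __init__(self, patience=5, min_delta=0.001, mode="min"):
        assert mode in ("min", "max")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = float("inf") if mode == "min" else float("-inf")
        self.counter = 0

    def __call__(self, value):
        if self.mode == "min":
            improved = value < self.best - self.min_delta
        else:
            improved = value > self.best + self.min_delta

        if improved:
            self.best = value
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

# Monitoring validation accuracy (a quantity to maximize)
stopper = EarlyStopping(patience=2, min_delta=0.01, mode="max")
for acc in [0.70, 0.75, 0.76, 0.755, 0.757]:
    if stopper(acc):
        print("Stopping: accuracy plateaued")
        break
```

In the example, 0.76 does not beat the best accuracy (0.75) by more than min_delta (0.01), so the counter grows and stopping triggers after two non-improving epochs.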