🚀 PyTorch Template with Accelerate

Notice: This repository is a modified version of @Kyyle2114/Pytorch-Template, enhanced to support training with Hugging Face's Accelerate and improved with comprehensive type hints, error handling, and best practices.

Python 3.10+ · PyTorch · Accelerate · Code style: Ruff

A structured and modular template for deep learning model training with Hugging Face's Accelerate.

✨ Features

🔧 Core Functionality

  • Distributed Training: Full support via Hugging Face Accelerate
  • Type Safety: Comprehensive type hints throughout the codebase
  • Error Handling: Robust exception handling with specific error types
  • Modular Design: Clean separation of concerns with well-defined interfaces
  • Configuration Management: YAML-based configuration with Pydantic validation

📊 Monitoring & Logging

  • Weights & Biases Integration: Automatic experiment tracking and visualization
  • Early Stopping: Configurable early stopping to prevent overfitting (see the sketch after this list)
  • Model Checkpointing: Automatic saving of best and last models
  • Progress Tracking: Detailed logging of training progress and metrics
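
A minimal sketch of the early-stopping bookkeeping mentioned above; the class name and method are illustrative, not the exact API in engines/engine_train.py:

class EarlyStopping:
    """Stop training when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 50, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.counter = 0

    def step(self, val_loss: float) -> bool:
        # Return True when training should stop.
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience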

πŸ—οΈ Code Quality

  • PEP 8 Compliance: Consistent code formatting with Ruff
  • Documentation: Google-style docstrings for all functions and classes
  • Input Validation: Pydantic-based config validation with informative error messages
  • Memory Management: Efficient memory usage with proper cleanup

🔄 Training Pipeline

  • Learning Rate Scheduling: Cosine annealing with warm-up restarts (sketched after this list)
  • Gradient Accumulation: Support for effective batch size scaling
  • Mixed Precision: FP16 training for improved performance
  • Gradient Clipping: Configurable gradient norm clipping
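
The scheduler in utils/lr_sched.py provides the warm-up plus cosine behaviour listed above. A rough single-cycle sketch of that idea (the function name, signature, and epoch-based stepping are assumptions, and the actual implementation may restart the cosine cycle):

import math

def warmup_cosine_lr(epoch: float, base_lr: float, warmup_epochs: int,
                     total_epochs: int, min_lr: float = 0.0) -> float:
    # Linear warm-up, then half-cycle cosine decay towards min_lr.
    if epoch < warmup_epochs:
        return base_lr * epoch / max(warmup_epochs, 1)
    progress = (epoch - warmup_epochs) / max(total_epochs - warmup_epochs, 1)
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))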

🚀 Quick Start

Prerequisites

  • Python 3.10 or higher
  • CUDA-compatible GPU (optional but recommended)
  • Git

Installation

  1. Clone the repository:
git clone https://github.com/your-username/Pytorch-Template-Accelerate.git
cd Pytorch-Template-Accelerate
  2. Create a virtual environment:
# Using conda
conda create -n pytorch-template python=3.10
conda activate pytorch-template
  3. Install dependencies:
pip install -r requirements.txt
  4. Configure Accelerate (first time only):
accelerate config

πŸƒβ€β™‚οΈ Running Training

Using the training script

chmod +x train.sh
bash train.sh

Using Python directly

# Single GPU
python main.py --config config/default.yaml

# Multi-GPU with Accelerate
accelerate launch main.py --config config/default.yaml

# With config overrides
python main.py --config config/default.yaml --set training.lr=0.0001 --set data.batch_size=32

πŸ“ Project Structure

Pytorch-Template-Accelerate/
├── config/
│   ├── __init__.py          # Config package initialization
│   ├── default.yaml         # Default configuration file
│   └── schemas.py           # Pydantic configuration schemas
├── engines/
│   ├── __init__.py          # Engine package initialization
│   └── engine_train.py      # Training and evaluation logic
├── models/
│   ├── __init__.py          # Models package initialization
│   └── cnn.py               # CNN model implementation
├── utils/
│   ├── __init__.py          # Utils package initialization
│   ├── datasets.py          # Dataset handling utilities
│   ├── lr_sched.py          # Learning rate schedulers
│   └── misc.py              # Miscellaneous utilities
├── dataset/                 # Dataset storage
├── output_dir/              # Training outputs (auto-created)
├── main.py                  # Main training script
├── train.sh                 # Training execution script
├── requirements.txt         # Python dependencies
├── .gitignore               # Git ignore rules
└── README.md                # This file

⚙️ Configuration

YAML Configuration File

Configuration is managed through YAML files with Pydantic validation. The default configuration file is config/default.yaml:

# --- General Configuration ---
general:
  seed: 42
  output_dir: ./output_dir

# --- Data Configuration ---
data:
  dataset_path: ./dataset
  batch_size: 16
  num_workers: 4

# --- Training Configuration ---
training:
  epoch: 100
  patience: 50
  lr: 1e-3
  weight_decay: 1e-4
  grad_accum_steps: 1
  warmup_epochs: 10
  clip_grad: null   # null for no gradient clipping

# --- WandB Configuration ---
wandb:
  project_name: Model-Training
  run_name: Model-Training
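
These values are validated against the Pydantic schemas in config/schemas.py. A minimal sketch of what such schemas can look like; the class and field layout below mirrors the YAML above but is illustrative rather than a copy of schemas.py:

from typing import Optional

from pydantic import BaseModel

class GeneralConfig(BaseModel):
    seed: int = 42
    output_dir: str = "./output_dir"

class DataConfig(BaseModel):
    dataset_path: str = "./dataset"
    batch_size: int = 16
    num_workers: int = 4

class TrainingConfig(BaseModel):
    epoch: int = 100
    patience: int = 50
    lr: float = 1e-3
    weight_decay: float = 1e-4
    grad_accum_steps: int = 1
    warmup_epochs: int = 10
    clip_grad: Optional[float] = None  # `null` in YAML becomes None

class WandbConfig(BaseModel):
    project_name: str = "Model-Training"
    run_name: str = "Model-Training"

class Config(BaseModel):
    general: GeneralConfig
    data: DataConfig
    training: TrainingConfig
    wandb: WandbConfig

With a structure like this, a wrong type or a missing field raises a descriptive ValidationError instead of failing deep inside training.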

Command Line Arguments

Argument | Type | Default | Description
-c, --config | str | config/default.yaml | Path to YAML configuration file
--help_config | flag | - | Show detailed help for configuration parameters
--set | KEY=VALUE | - | Override any config value (can be used multiple times)

Configuration Override Examples

# Override single value
python main.py --config config/default.yaml --set training.lr=0.0001

# Override multiple values
python main.py --config config/default.yaml \
    --set general.seed=42 \
    --set data.batch_size=32 \
    --set training.epoch=50

# Show config help
python main.py --help_config
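
Each --set argument names a dotted path into the nested configuration. A minimal, hypothetical sketch of how such an override can be applied to a plain dict; main.py's actual parsing may differ, e.g. in how values are converted:

import yaml

def apply_override(cfg: dict, dotted_key: str, raw_value: str) -> None:
    # Walk the nested dict along the dotted path and set the final key.
    *parents, last = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node[key]
    # Let YAML infer the value type: "0.0001" -> float, "32" -> int, "null" -> None.
    node[last] = yaml.safe_load(raw_value)

# Example: --set training.lr=0.0001
cfg = {"training": {"lr": 1e-3}}
apply_override(cfg, "training.lr", "0.0001")
print(cfg["training"]["lr"])  # 0.0001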

Configuration Parameters

Section | Parameter | Type | Default | Description
general | seed | int | 42 | Random seed for reproducibility
general | output_dir | str | ./output_dir | Output directory for checkpoints and logs
data | dataset_path | str | ./dataset | Dataset root path
data | batch_size | int | 16 | Batch size per GPU
data | num_workers | int | 4 | Number of data loading workers
training | epoch | int | 100 | Total number of training epochs
training | patience | int | 50 | Early stopping patience
training | lr | float | 1e-3 | Base learning rate
training | weight_decay | float | 1e-4 | Weight decay for optimizer
training | grad_accum_steps | int | 1 | Gradient accumulation steps
training | warmup_epochs | int | 10 | Number of warmup epochs
training | clip_grad | float/null | null | Gradient clipping norm (null for no clipping)
wandb | project_name | str | Model-Training | WandB project name
wandb | run_name | str | Model-Training | WandB run name

🔧 Customization

Adding a New Model

  1. Create your model in models/your_model.py:
import torch
import torch.nn as nn

class YourModel(nn.Module):
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        # Your model implementation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Forward pass implementation
        return x
  2. Update models/__init__.py:
from .your_model import YourModel
__all__ = ['SimpleCNNforCIFAR10', 'YourModel']
  3. Modify main.py to use your model:
model = models.YourModel(num_classes=10)
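
Before launching a full run, a quick forward-pass check helps catch shape mistakes early. The input size below assumes CIFAR-10-style 3×32×32 images (matching the bundled SimpleCNNforCIFAR10); adjust it to your data:

import torch

import models

model = models.YourModel(num_classes=10)
dummy = torch.randn(2, 3, 32, 32)  # batch of 2 fake images
print(model(dummy).shape)          # expect torch.Size([2, 10]) once the model is implemented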

Adding a New Dataset

  1. Implement your dataset in utils/datasets.py:
class YourDataset(Dataset):
    def __init__(self, data_path: str, transform: Optional[torch.nn.Module] = None) -> None:
        # Your dataset implementation (also define __len__ and __getitem__)
        pass
  2. Create a factory function:
def make_your_dataset(dataset_path: str, train: bool = True, transform=None):
    # Use `train` inside YourDataset to select the split if needed
    return YourDataset(dataset_path, transform)
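
main.py can then build its DataLoaders from this factory. A hedged sketch of that wiring (literal values are shown here; in main.py they come from the validated data config section, and the exact code may differ):

from torch.utils.data import DataLoader

from utils.datasets import make_your_dataset  # the factory from step 2

train_set = make_your_dataset("./dataset", train=True)
val_set = make_your_dataset("./dataset", train=False)

train_loader = DataLoader(train_set, batch_size=16, num_workers=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, num_workers=4, shuffle=False)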

Creating a Custom Configuration

  1. Create a new YAML file (e.g., config/my_config.yaml):
general:
  seed: 123
  output_dir: ./my_experiment

data:
  dataset_path: /path/to/your/data
  batch_size: 64
  num_workers: 8

training:
  epoch: 200
  patience: 30
  lr: 5e-4
  weight_decay: 1e-5
  grad_accum_steps: 2
  warmup_epochs: 5
  clip_grad: 1.0

wandb:
  project_name: My-Project
  run_name: experiment-1
  2. Run training with your config:
python main.py --config config/my_config.yaml

πŸ› Troubleshooting

Common Issues

1. CUDA Out of Memory

  • Reduce data.batch_size in your config
  • Increase training.grad_accum_steps to maintain the effective batch size (per-process batch size × grad_accum_steps × number of processes); e.g. halving data.batch_size while doubling training.grad_accum_steps keeps it unchanged

2. WandB Login Issues

  • Set WANDB_API_KEY environment variable
  • Use wandb login command
  • Set WANDB_MODE=offline for offline logging

3. Import Errors

  • Ensure all dependencies are installed: pip install -r requirements.txt
  • Check Python version: requires 3.10+

4. Multi-GPU Training Issues

  • Run accelerate config to set up distributed training
  • Ensure CUDA_VISIBLE_DEVICES is set correctly in train.sh

5. Config Validation Errors

  • Check YAML syntax in your config file
  • Ensure all required fields are present
  • Use --help_config to see available parameters

📊 Monitoring

WandB Dashboard

Training metrics automatically logged to WandB include:

  • Training/validation loss
  • Learning rate schedule
  • Model accuracy
  • Best model metrics (epoch, loss, accuracy)
  • Hardware utilization
  • Hyperparameter configurations

Local Logs

  • JSON logs: output_dir/log.txt (see the reading sketch after this list)
  • Best model: output_dir/best_model/
  • Last model: output_dir/last_model/
  • Configuration: output_dir/config.yaml
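
If, as is common for templates like this, each line of output_dir/log.txt is a self-contained JSON record for one epoch (an assumption; check utils/misc.py for the exact format), the training history can be loaded for quick offline inspection:

import json
from pathlib import Path

records = []
with Path("output_dir/log.txt").open() as f:
    for line in f:
        line = line.strip()
        if line:
            records.append(json.loads(line))

# e.g. look at the most recently logged epoch
print(records[-1] if records else "no records yet")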

🤝 Contributing

  1. Fork the repository
  2. Create a feature branch: git checkout -b feature/amazing-feature
  3. Make your changes with proper type hints and documentation
  4. Run tests and ensure code quality
  5. Commit your changes: git commit -m 'Add amazing feature'
  6. Push to the branch: git push origin feature/amazing-feature
  7. Open a Pull Request

📚 References


⭐ Star this repository if you find it helpful!
