1+ # config/diffusion_config.yaml
2+ # Configuration for diffusion model training
3+
4+ data :
5+ configuration_path : " dataset_config.notebook_api_pytorch_data.json"
6+
7+ # Model architecture configuration
8+ model :
9+ name : " diffusion_icenet"
10+ filter_size : 3 # Convolution kernel size
11+ n_filters_factor : 0.5 # Scaling factor for number of filters
12+ timesteps : 1000 # Number of diffusion steps (T)
13+
14+ # Training configuration
15+ train :
16+ seed : 45 # Random seed for reproducibility
17+
18+ # Optimizer settings
19+ optimizer :
20+ learning_rate : 5e-4 # Learning rate
21+
22+ # DataLoader settings
23+ dataloader :
24+ batch_size : 8 # Batch size
25+ n_workers : 8 # Number of data loading workers
26+ persistent_workers : true # Keep workers alive between epochs
27+ shuffle : false # Shuffle training data
28+
29+ # PyTorch Lightning Trainer settings
30+ trainer :
31+ accelerator : " auto" # Use auto-detection (GPU if available)
32+ devices : -1 # Use all available devices
33+ log_every_n_steps : 5 # Log frequency
34+ max_epochs : 75 # Maximum number of epochs
35+ num_sanity_val_steps : 1 # Validation steps before training
36+ fast_dev_run : false # Set to true for quick testing
37+ precision : 32 # Training precision (16, 32, or 64)
38+
39+ # Training callbacks
40+ callbacks :
41+ # Model checkpointing
42+ checkpoint :
43+ _target_ : lightning.pytorch.callbacks.ModelCheckpoint
44+ monitor : " val_accuracy" # Metric to monitor
45+ mode : " max" # Maximize or minimize the metric
46+ save_top_k : 3 # Save top 3 models
47+ filename : " diffusion-{epoch:02d}-{val_accuracy:.4f}"
48+ save_last : true # Also save the last checkpoint
49+
50+ # Early stopping
51+ early_stopping :
52+ _target_ : lightning.pytorch.callbacks.EarlyStopping
53+ monitor : " val_accuracy" # Metric to monitor
54+ patience : 25 # Number of epochs to wait
55+ verbose : true # Print early stopping info
56+ mode : " max" # Maximize the metric
57+
58+ # Logging configuration (comment out if not needed)
59+ # loggers:
60+ # # TensorBoard logging
61+ # tensorboard:
62+ # _target_: lightning.pytorch.loggers.TensorBoardLogger
63+ # save_dir: "./logs"
64+ # name: "diffusion_experiment"
65+
66+ # Weights & Biases logging (uncomment if you use wandb)
67+ # wandb:
68+ # _target_: lightning.pytorch.loggers.WandbLogger
69+ # project: "diffusion_icenet"
70+ # log_model: true
71+ # offline: false
72+
73+ # Hydra configuration
74+ hydra :
75+ run :
76+ dir : ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
77+ job :
78+ chdir : false # Don't change working directory
0 commit comments