Skip to content

Commit b810744

Browse files
✨ Integrate Hydra configuration and adapt code to use it properly
1 parent 1ea9b8a commit b810744

File tree

3 files changed

+16306
-29522
lines changed

3 files changed

+16306
-29522
lines changed

configs/ddpm_config.yaml

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# config/diffusion_config.yaml
2+
# Configuration for diffusion model training
3+
4+
data:
5+
configuration_path: "dataset_config.notebook_api_pytorch_data.json"
6+
7+
# Model architecture configuration
8+
model:
9+
name: "diffusion_icenet"
10+
filter_size: 3 # Convolution kernel size
11+
n_filters_factor: 0.5 # Scaling factor for number of filters
12+
timesteps: 1000 # Number of diffusion steps (T)
13+
14+
# Training configuration
15+
train:
16+
seed: 45 # Random seed for reproducibility
17+
18+
# Optimizer settings
19+
optimizer:
20+
learning_rate: 5e-4 # Learning rate
21+
22+
# DataLoader settings
23+
dataloader:
24+
batch_size: 8 # Batch size
25+
n_workers: 8 # Number of data loading workers
26+
persistent_workers: true # Keep workers alive between epochs
27+
shuffle: false # Shuffle training data
28+
29+
# PyTorch Lightning Trainer settings
30+
trainer:
31+
accelerator: "auto" # Use auto-detection (GPU if available)
32+
devices: -1 # Use all available devices
33+
log_every_n_steps: 5 # Log frequency
34+
max_epochs: 75 # Maximum number of epochs
35+
num_sanity_val_steps: 1 # Validation steps before training
36+
fast_dev_run: false # Set to true for quick testing
37+
precision: 32 # Training precision (16, 32, or 64)
38+
39+
# Training callbacks
40+
callbacks:
41+
# Model checkpointing
42+
checkpoint:
43+
_target_: lightning.pytorch.callbacks.ModelCheckpoint
44+
monitor: "val_accuracy" # Metric to monitor
45+
mode: "max" # Maximize or minimize the metric
46+
save_top_k: 3 # Save top 3 models
47+
filename: "diffusion-{epoch:02d}-{val_accuracy:.4f}"
48+
save_last: true # Also save the last checkpoint
49+
50+
# Early stopping
51+
early_stopping:
52+
_target_: lightning.pytorch.callbacks.EarlyStopping
53+
monitor: "val_accuracy" # Metric to monitor
54+
patience: 25 # Number of epochs to wait
55+
verbose: true # Print early stopping info
56+
mode: "max" # Maximize the metric
57+
58+
# Logging configuration (comment out if not needed)
59+
# loggers:
60+
# # TensorBoard logging
61+
# tensorboard:
62+
# _target_: lightning.pytorch.loggers.TensorBoardLogger
63+
# save_dir: "./logs"
64+
# name: "diffusion_experiment"
65+
66+
# Weights & Biases logging (uncomment if you use wandb)
67+
# wandb:
68+
# _target_: lightning.pytorch.loggers.WandbLogger
69+
# project: "diffusion_icenet"
70+
# log_model: true
71+
# offline: false
72+
73+
# Hydra configuration
74+
hydra:
75+
run:
76+
dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
77+
job:
78+
chdir: false # Don't change working directory

0 commit comments

Comments
 (0)