main.py
import argparse
import torch
from pathlib import Path
from utils.architecture import CNN, fit
from utils.dataset import prepare_subset, prepare_datasets
from utils.config import load_config

# Load the base configuration; CLI flags below can override beta and h.
config = load_config("config.yaml")

# --- parse optional CLI overrides ---
parser = argparse.ArgumentParser()
parser.add_argument("--beta", type=float, help="Override beta value")
parser.add_argument("--h", type=float, help="Override h value")
parser.add_argument("--save-interval", type=int, default=10, help="Save every N epochs")
args = parser.parse_args()
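
# Example invocation (illustrative values):
#   python main.py --beta 0.440 --h 0.050 --save-interval 5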

# --- apply overrides (CLI values take precedence over the config) ---
beta = args.beta if args.beta is not None else config.parameters.beta
h = args.h if args.h is not None else config.parameters.h

# Use the GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
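# Optional sanity check: echo the effective run settings.
print(f"Using device={device}, beta={beta:.3f}, h={h:.3f}")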

grids, attrs, train_idx, valid_idx, test_idx = prepare_subset(
    f"../data/gridstates_training_{beta:.3f}_{h:.3f}.hdf5",
    test_size=config.dataset.test_size,
    total_samples=1000,
)

train_dl, valid_dl, test_dl, train_ds, valid_ds, test_ds = prepare_datasets(
    grids, attrs, train_idx, valid_idx, test_idx,
    device,
    config.dataset.batch_size,
    augment=config.dataset.augment,
)

model = CNN(
    channels=config.model.channels,
    num_cnn_layers=config.model.num_cnn_layers,
    num_fc_layers=config.model.num_fc_layers,
).to(device)

# Exclude BatchNorm and bias parameters from weight decay.
decay, no_decay = [], []
for name, param in model.named_parameters():
    if "bn" in name or "bias" in name:
        no_decay.append(param)
    else:
        decay.append(param)
optimizer = torch.optim.AdamW([
    {"params": decay, "weight_decay": config.training.weight_decay},
    {"params": no_decay, "weight_decay": 0.0},
], lr=config.training.lr)

if config.training.loss.lower() == "smoothl1loss":
    loss_func = torch.nn.SmoothL1Loss()
elif config.training.loss.lower() == "mse":
    loss_func = torch.nn.MSELoss()
else:
    raise ValueError(f"Unsupported loss: {config.training.loss}")

# Periodic checkpoints go to a local models/ directory keyed by beta and h;
# the final model path below is built from config.paths.save_dir instead.
save_dir = Path(f"models/{beta:.3f}_{h:.3f}")
save_dir.mkdir(parents=True, exist_ok=True)
save_path = (
    f"{config.paths.save_dir}/"
    f"{config.model.type}_ch{config.model.channels}_cn{config.model.num_cnn_layers}_fc{config.model.num_fc_layers}_"
    f"{beta:.3f}_{h:.3f}.pth"
)
print(f"Periodic checkpoints will be saved to: {save_dir}")
print(f"Model will be saved to: {save_path}")
print(f"Starting training for {config.model.type.upper()}...")

fit(config.training.epochs, model, loss_func, optimizer,
    train_dl, valid_dl, device, config, save_path=save_path,
    save_dir=str(save_dir), save_interval=args.save_interval)
print("Training complete.")