-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
68 lines (54 loc) · 2.15 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import torch
import hydra
from omegaconf import OmegaConf, DictConfig
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from datasets import build_dataset
from models import build_model
from engine import build_module
# Allow TF32 tensor-core matmuls for float32 — faster on Ampere+ GPUs at
# slightly reduced precision.
torch.set_float32_matmul_precision("high")
# sets seeds for numpy, torch and python.random.
seed_everything(42, workers=True)
@hydra.main(config_path="conf", config_name="", version_base="1.3")
def run(cfg: DictConfig) -> None:
    """Train a model described by the Hydra config.

    Builds train/val dataloaders, the model, and the Lightning module from
    ``cfg``, then fits with TensorBoard logging, an epoch-wise LR monitor,
    and checkpointing on ``val/total_loss``.

    Args:
        cfg: Hydra-composed configuration. Required fields are marked with
            ``???`` in the config files.

    Raises:
        omegaconf.errors.MissingMandatoryValue: if a required field is unset.
        ValueError: if ``cfg.image_size`` is not square.
    """
    # Resolve the config eagerly so any missing required fields (???)
    # fail fast, before any expensive setup.
    OmegaConf.to_container(cfg, throw_on_missing=True)
    # Raise explicitly rather than `assert`, which is stripped under `-O`.
    if cfg.image_size[0] != cfg.image_size[1]:
        raise ValueError(
            f"Image size must be square, got {tuple(cfg.image_size)}")

    model_name = (
        f"{cfg.name}_{cfg.model.name}_{cfg.model.backbone}"
        f"_{cfg.dataset.name}_{cfg.image_size[0]}x{cfg.image_size[1]}")
    save_path = os.path.join(cfg.path.output_dir, model_name)

    train_loader = build_dataset(cfg, "train")
    val_loader = build_dataset(cfg, "val")

    model = build_model(cfg)
    # `pretrained` may be missing, None, or "" in the config — a plain
    # truthiness check treats all three as "no pretrained weights"
    # (the original `len(pretrained) > 0` raised TypeError on None).
    pretrained = cfg.model.get("pretrained", "")
    if pretrained:
        model.init_weights(pretrained)
        print(f"Loaded pretrained weights from {pretrained}")
    module = build_module(cfg, model, save_path)

    # Keep the single best checkpoint by validation loss, plus the latest
    # one (`save_last=True`) so training can be resumed.
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    ckpt_cb = ModelCheckpoint(
        dirpath=os.path.join(save_path, "weight"),
        filename="best",
        monitor='val/total_loss',
        mode='min',
        save_top_k=1,
        save_last=True)
    callbacks = [lr_monitor, ckpt_cb]

    logger = TensorBoardLogger(
        save_dir=cfg.path.log_dir,
        name=model_name)
    trainer = Trainer(accelerator='gpu',
                      devices=[cfg.gpu],
                      precision=32,
                      max_epochs=cfg.num_epochs,
                      # `deterministic=True` pairs with seed_everything for
                      # reproducible runs.
                      deterministic=True,
                      num_sanity_val_steps=1,
                      logger=logger,
                      callbacks=callbacks)
    trainer.fit(module, train_loader, val_loader)
# Script entry point: Hydra parses the CLI overrides and injects `cfg`
# into `run` via the decorator.
if __name__ == "__main__":
    run()