-
Notifications
You must be signed in to change notification settings - Fork 2
/
04_training.py
55 lines (43 loc) · 2.14 KB
/
04_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Fourth step of our approach: train a GCN."""
from torch_geometric.data import DataLoader
from pg_networks.gcn import GCN
from pg_networks.dynamic_edge import DynamicEdge
import src.config as cfg
from src.davis_2016 import DAVIS2016
from src.solver import Solver
from src.vis_utils import save_loss
def _build_dataset(augmentation_count, train):
    """Construct one DAVIS 2016 dataset split from the settings in ``cfg``.

    The train and validation splits share every ``DAVIS2016`` constructor
    argument except the augmentation count and the ``train`` flag. Building
    both here keeps them from drifting apart.

    Args:
        augmentation_count: Forwarded as the ``DAVIS2016`` argument that
            follows ``cfg.K``. The training split passes
            ``cfg.AUGMENTATION_COUNT`` and validation passes ``0``.
        train: Forwarded as ``DAVIS2016(..., train=train)``. Presumably
            selects the train or validation split; confirm in ``DAVIS2016``.

    Returns:
        The constructed ``DAVIS2016`` dataset.
    """
    return DAVIS2016(cfg.PYTORCH_GEOMETRIC_DAVIS_2016_DATASET_PATH,
                     cfg.ANNOTATIONS_AUGMENTED_FOLDERS_PATH,
                     cfg.CONTOURS_FOLDERS_PATH,
                     cfg.IMAGES_AUGMENTED_FOLDERS_PATH,
                     cfg.TRANSLATIONS_FOLDERS_PATH,
                     cfg.PARENT_MODEL_PATH,
                     cfg.LAYER, cfg.K, augmentation_count,
                     cfg.SKIP_SEQUENCES,
                     cfg.TRAIN_SEQUENCES[:cfg.NUM_TRAIN_SEQUENCES],
                     cfg.VAL_SEQUENCES[:cfg.NUM_VAL_SEQUENCES],
                     train=train)


def main():
    """Build the train/val datasets, train a GCN with ``Solver``, then call ``save_loss``."""
    # Train and val dataset. Validation passes 0 where training passes
    # cfg.AUGMENTATION_COUNT (presumably: no augmented copies for validation).
    train_set = _build_dataset(cfg.AUGMENTATION_COUNT, train=True)
    val_set = _build_dataset(0, train=False)
    print("Train size: %i" % len(train_set))
    print("Val size: %i" % len(val_set))

    # Train and val DataLoader: only the training split is shuffled.
    # NOTE(review): newer PyTorch Geometric releases provide DataLoader as
    # torch_geometric.loader.DataLoader; the top-of-file import may need updating.
    train_loader = DataLoader(train_set, batch_size=cfg.BATCH_SIZE,
                              shuffle=True)
    val_loader = DataLoader(val_set, batch_size=cfg.BATCH_SIZE, shuffle=False)

    # Model input/output widths come from the first training sample.
    # Assumes sample.y is 2-D with output channels on axis 1 (TODO confirm).
    sample = train_set[0]
    model = GCN(in_channels=sample.num_features,
                out_channels=sample.y.shape[1])

    # Load model and run the solver.
    solver = Solver(optim_args={"lr": cfg.LEARNING_RATE,
                                "weight_decay": cfg.WEIGHT_DECAY})
    solver.train(model, train_loader, val_loader,
                 num_epochs=cfg.NUM_EPOCHS, log_nth=1000, verbose=True)
    save_loss(solver)


if __name__ == "__main__":
    main()