-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_pairwise.py
32 lines (27 loc) · 1.05 KB
/
run_pairwise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import argparse
import json
import os

import torch.utils.data as data

from network.model import GMARAFT_Denoiser
from train.trainer import Trainer
from loader.loader_cine import CineDatasetPairwise
# Placeholder path kept from the original script; override it with --config.
DEFAULT_CONFIG_PATH = "/path_to/configs/train_cine_group.json"


def parse_args(argv=None):
    """Parse command-line options for the pairwise training entry point.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` whose ``config`` attribute is the path to the
        JSON training configuration (defaults to ``DEFAULT_CONFIG_PATH``).
    """
    parser = argparse.ArgumentParser(
        description="Train GMARAFT_Denoiser on cine image pairs.")
    parser.add_argument(
        "--config",
        default=DEFAULT_CONFIG_PATH,
        help="path to the JSON training config (default: %(default)s)",
    )
    # parse_known_args: the original script ignored argv entirely, so options
    # it does not know about (e.g. ones injected by a launcher) must not abort
    # the run.
    args, _unknown = parser.parse_known_args(argv)
    return args


def load_config(path):
    """Load and return the JSON training configuration stored at *path*."""
    with open(path, "r", encoding="utf-8") as file:
        return json.load(file)


def dataset_mode(config):
    """Return ``'debug'`` if ``config['debug']`` is truthy, else ``'train'``.

    A missing ``debug`` key is treated as ``False`` instead of raising.
    """
    return "debug" if config.get("debug", False) else "train"


def main(argv=None):
    """Build the model and training data loader from the config, then train.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        The ``Trainer`` instance, after ``trainer.run()`` has returned.
    """
    config = load_config(parse_args(argv).config)

    # load model
    model = GMARAFT_Denoiser()
    model.cuda()
    model.train()

    # read data
    loader_cfg = config['data_loader']
    train_dataset = CineDatasetPairwise(loader_cfg, mode=dataset_mode(config))
    train_loader = data.DataLoader(
        train_dataset,
        batch_size=loader_cfg['batch_size'],
        pin_memory=True,
        shuffle=True,
        num_workers=loader_cfg['num_workers'],
        drop_last=True,
    )
    print(f'Loader has {len(train_dataset)} cine image pairs')
    print('steps per epoch', len(train_loader))

    # run training
    trainer = Trainer(config, model=model, data_loader=train_loader)
    trainer.run()
    return trainer


# The guard is required: with num_workers > 0 under the "spawn" start method
# (the default on Windows and macOS), every DataLoader worker re-imports this
# module, and unguarded top-level code would rebuild the model and start
# training again inside each worker.
if __name__ == "__main__":
    main()