-
Notifications
You must be signed in to change notification settings - Fork 17
/
train.py
98 lines (85 loc) · 3.15 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from lib.loader import get_loader, InfiniteLoader
from lib.model.recycle_gan import ReCycleGAN
from lib.utils import visualizeSingle
from lib import augmentations as aug
from parse import parse_train_args
from torch.autograd import Variable
from torchvision import transforms
from torch.utils import data
from tqdm import tqdm
import torch
import os
"""
This script define the training procedure of Re-cycle GAN
"""
def eval(args, model, video_a, video_b):
    """
    Render the frame at the first valid time step and visualize the result.

    NOTE(review): this function shadows the builtin ``eval``; the name is
    kept unchanged for backward compatibility with existing callers.

    Arg:    args    - The argparse object (reads args.device and args.t)
            model   - The nn.Module represent the Re-cycle GAN model
            video_a - The video sequence in domain A (assumes BTCHW layout)
            video_b - The video sequence in domain B (assumes BTCHW layout)
    """
    # BTCHW -> T * BCHW: split along the time axis and move each frame to the device
    true_a_seq = [frame.squeeze(1).to(args.device) for frame in torch.chunk(video_a, video_a.size(1), dim = 1)]
    true_b_seq = [frame.squeeze(1).to(args.device) for frame in torch.chunk(video_b, video_b.size(1), dim = 1)]
    # The frame at time step t is the rendering target in each domain
    true_a = true_a_seq[args.t]
    true_b = true_b_seq[args.t]
    # Render a single image; the try/finally guarantees the model is restored
    # to training mode even if the forward pass or visualization raises
    model.eval()
    try:
        with torch.no_grad():
            images = model(
                true_a = true_a,
                true_b = true_b,
                true_a_seq = true_a_seq[:args.t],
                true_b_seq = true_b_seq[:args.t],
                warning = False
            )
        visualizeSingle(images)
    finally:
        model.train()
def train(args):
    """
    This function define the training procedure

    Arg: args - The argparse argument (dataset paths, sizes, iteration
                counts, device and checkpoint destination)
    """
    # Create the data loader; InfiniteLoader re-cycles the underlying
    # DataLoader until max_iter batches have been produced
    loader = InfiniteLoader(
        loader = data.DataLoader(
            dataset = get_loader(args.dataset)(
                root = [args.A, args.B],
                transform = aug.Compose([
                    # aug.RandomRotate(10),
                    aug.RandomHorizontallyFlip(),
                    aug.ToTensor(),
                    aug.ToFloat(),
                    aug.Transpose(aug.BHWC2BCHW),
                    aug.Resize(size_tuple = (args.H, args.W)),
                    aug.Normalize()
                ]),
                T = args.T,
                t = args.t,
                use_cv = True,
            ), batch_size = args.batch_size, shuffle = True
        ), max_iter = args.n_iter
    )
    # Create the model and initialize
    model = ReCycleGAN(A_channel = args.A_channel, B_channel = args.B_channel, T = args.T, t = args.t, r = args.r, device = args.device)
    if os.path.exists(args.resume):
        # map_location lets a checkpoint saved on GPU be resumed on CPU (and vice versa)
        model.load_state_dict(torch.load(args.resume, map_location = args.device))
    model.train()
    # Work!
    bar = tqdm(loader)
    for i, (video_a, video_b) in enumerate(bar):
        # Update parameters
        model.setInput(video_a, video_b)
        model.backward()
        # set_description refreshes the bar itself; no explicit refresh needed
        bar.set_description(f"G: {model.loss_G.item()} D: {model.loss_D.item()}")
        # Record render result periodically (skip the very first iteration)
        if i % args.record_iter == 0 and i != 0:
            torch.save(model.state_dict(), args.det)
            eval(args, model, video_a, video_b)
    # Save the final weights so progress after the last record point is not lost
    torch.save(model.state_dict(), args.det)
if __name__ == '__main__':
    # Parse the command-line arguments and launch the training procedure
    train(parse_train_args())