-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun-dgn.py
131 lines (98 loc) · 4.32 KB
/
run-dgn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
run-dgn.py
Script to train the DGN (DepthGenerativeNetwork) on the FAT dataset.
"""
import os
import math
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from dgn import DepthGenerativeNetwork
from dataset import custom_save_img, normalize_depth
from dataset import FATDataset, Rescale, RandomCrop, ToTensor, RandomVerticalFlip
import pendulum
# ---------------------------------------------------------------------------
# Run configuration (module-level constants read by the training entry point)
# ---------------------------------------------------------------------------

# Use the first CUDA device when one is available, otherwise run on the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if cuda else torch.device("cpu")

# Run identifier; it is embedded in the name of the log/checkpoint directory.
commit = "0.5"
save_dir = "./log-" + commit

# Step label of the checkpoint to resume from ("model-<fine_tune>.pt" inside
# the log directory). The entry point treats the string 'none' as "no offset".
fine_tune = "77000"

# Mini-batch size handed to the DataLoader.
batch_size = 8
# Size passed to both RandomCrop and Rescale in the dataset transform.
crop_size = 240
# When True the model is wrapped in nn.DataParallel before training.
data_parallel = False
# Gradient-step budget of the training loop (2 * 10**5).
gradient_steps = 200000
if __name__ == '__main__':
    # Training entry point.
    #
    # Builds the FAT training DataLoader, restores the checkpoint named by
    # `fine_tune` when it exists (otherwise creates a fresh
    # DepthGenerativeNetwork), then runs Adam for `gradient_steps` gradient
    # steps. Every 100 steps it logs and anneals the learning rate and the
    # pixel standard deviation `sigma`, every 500 steps it saves a comparison
    # image, and every 1000 steps it writes a checkpoint into `save_dir`.
    print(" - Train id: {}\n"
          " - fine_tune model: {}\n"
          " - batch_size: {}\n"
          " - corp_size: {}\n".format(commit, fine_tune, batch_size, crop_size))
    print(" - note: min max depth, rgb mul 1/255, no batch norm")
    print(pendulum.now())

    # Global step offset contributed by the checkpoint we resume from; it is
    # added to the local counter for checkpoint names, logs and schedules.
    # NOTE(review): the offset is applied even when the checkpoint file is
    # missing and a fresh model is created below -- confirm that is intended.
    tuned = 0
    if fine_tune != 'none':
        tuned += int(fine_tune)

    train_set = FATDataset(
        "./dataset/fat", "train",
        trans=Compose([RandomCrop(crop_size), Rescale(crop_size),
                       RandomVerticalFlip(), ToTensor()]))
    dataloader = DataLoader(train_set, batch_size, shuffle=True, num_workers=2)

    # Pixel variance: final and initial values of the annealing schedule
    sigma_f, sigma_i = 0.7, 2.0
    # Learning rate: final and initial values of the annealing schedule
    mu_f, mu_i = 5*10**(-5), 5*10**(-4)
    # NOTE(review): training starts from the *final* schedule values and only
    # moves onto the schedule at the first annealing update (step 100).
    mu, sigma = mu_f, sigma_f

    # Race-free replacement for "check existence, then mkdir".
    os.makedirs(save_dir, exist_ok=True)

    model_path = os.path.join(save_dir, "model-{}.pt".format(fine_tune))
    if os.path.exists(model_path):
        # torch.load unpickles a complete model object: only load checkpoints
        # you trust.
        # NOTE(review): recent PyTorch releases may default to
        # weights_only=True, which rejects full-model pickles -- confirm
        # against the installed version.
        model = torch.load(model_path, map_location=lambda storage, _: storage).to(device)
        # Checkpoints written from a DataParallel run wrap the real network.
        if isinstance(model, nn.DataParallel):
            model = model.module
    else:
        # Create model and optimizer
        model = DepthGenerativeNetwork(x_dim=6, y_dim=1, r_dim=256, h_dim=128, z_dim=64).to(device)

    # Model optimisations
    model = nn.DataParallel(model) if data_parallel else model
    optimizer = torch.optim.Adam(model.parameters(), lr=mu)

    # Number of gradient steps taken in this run
    s = 0
    while s < gradient_steps:
        for batch in dataloader:
            img_d = batch['depth'].to(device)
            img_l = batch['left'].to(device)
            img_r = batch['right'].to(device)
            # Stack the left and right images along dim 1.
            img_cat = torch.cat([img_l, img_r], 1)

            # presumably: predicted depth mean, target depth, per-element KL
            # terms -- TODO confirm against DepthGenerativeNetwork.forward
            img_d_mu, img_d_q, kld = model(img_d, img_cat)

            # If more than one GPU we must take new shape into account.
            # Kept in a local name so the module-level `batch_size` setting
            # is no longer overwritten.
            n_samples = img_d_q.size(0)

            # Negative log likelihood
            nll = - Normal(img_d_mu, sigma).log_prob(img_d_q)

            # Average over the batch dimension, then sum what remains.
            reconstruction = torch.mean(nll.view(n_samples, -1), dim=0).sum()
            kl_divergence = torch.mean(kld.view(n_samples, -1), dim=0).sum()

            # Loss that is minimised: NLL + KL (the original names it `elbo`).
            elbo = reconstruction + kl_divergence
            elbo.backward()

            optimizer.step()
            optimizer.zero_grad()

            s += 1
            step = s + tuned

            # Keep a checkpoint every 1000 steps
            if s % 1000 == 0:
                torch.save(model, os.path.join(save_dir, "model-{}.pt".format(step)))
                print("model-{}.pt saved.".format(step))

            # Log and anneal the parameters every 100 steps
            if s % 100 == 0:
                with torch.no_grad():
                    print("|Steps: {}\t|NLL: {}\t|KL: {}\t|sigma: {}\t|".format(step, reconstruction.item(),
                                                                                kl_divergence.item(), sigma))

                    if s % 500 == 0:
                        img_show = torch.cat([img_d_q, normalize_depth(img_d_mu)], 0)
                        custom_save_img(img_show, os.path.join(save_dir, "result_{}.png".format(step)),
                                        n_row=n_samples)

                # Anneal learning rate
                mu = max(mu_f + (mu_i - mu_f) * (1 - step / (1.6 * 10 ** 6)), mu_f)
                annealed_lr = mu * math.sqrt(1 - 0.999 ** step) / (1 - 0.9 ** step)
                # Assigning `optimizer.lr` only creates an unused attribute;
                # the rate Adam applies is stored per parameter group.
                for group in optimizer.param_groups:
                    group['lr'] = annealed_lr

                # Anneal pixel variance
                sigma = max(sigma_f + (sigma_i - sigma_f) * (1 - step / gradient_steps), sigma_f)

            # Stop exactly at the step budget instead of finishing the epoch.
            if s >= gradient_steps:
                break

    # The final model goes to the current working directory, not `save_dir`.
    torch.save(model, "model-final.pt")
    print(pendulum.now())