-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils.py
162 lines (133 loc) · 4.26 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import json
import math
import os
import random
import sys
# import visdom
import time
import numpy as np
import scipy.misc
import torch
def get_mask_index(filename, mask_object):
    """Get the indices of objects from a metadata file.

    filename: path of a JSON metadata file; the mapping stored under
        ``["header"]["masks"]`` is searched (presumably object name -> mask
        index; verify against the data producer)
    mask_object: iterable of names to look for; an entry matches every key of
        the mapping that contains it as a substring

    Returns a dict mapping each entry of ``mask_object`` to the list of values
    whose key contains that entry, in the mapping's order.
    """
    # Materialize once: the names are iterated repeatedly below, and a one-shot
    # iterator (e.g. a generator) would otherwise produce empty lists.
    names = list(mask_object)
    # JSON interchange text is UTF-8; don't depend on the platform locale.
    with open(filename, encoding="utf-8") as file:
        masks_grayscale = json.load(file)["header"]["masks"]
    out = {name: [] for name in names}
    for obj_name, idx in masks_grayscale.items():
        for name in names:
            if name in obj_name:
                out[name].append(idx)
    return out
def checkpoint(epoch, model, log, opt):
    """Save a checkpoint into the ``opt.checkpoint`` directory.

    epoch: current epoch, forwarded to ``model.save``
    model: object exposing ``save(directory, epoch)``
    log: JSON-serializable statistics gathered during optimization; written
        to ``log.txt``
    opt: options namespace; ``vars(opt)`` is written to ``opt.txt``

    The checkpoint directory (and any missing parents) is created if needed.
    Both text files are overwritten on every call.
    """
    # makedirs creates missing parents and tolerates the directory already
    # existing (or appearing concurrently), unlike isdir() + mkdir().
    os.makedirs(opt.checkpoint, exist_ok=True)
    with open(os.path.join(opt.checkpoint, "opt.txt"), "w") as f:
        json.dump(vars(opt), f)
    with open(os.path.join(opt.checkpoint, "log.txt"), "w") as f:
        json.dump(log, f)
    model.save(opt.checkpoint, epoch)
def stack(input_, n, seq):
    """Tile images into a single grid for visualization.

    input_: indexable as ``input_[sample][step]``, each item a tensor with at
        least 3 dims (concatenated along dims 1 and 2)
    n: number of samples (grid rows) to include
    seq: number of sequence steps (grid columns) to include

    Each row places the ``seq`` steps of one sample side by side along dim 2;
    rows are stacked along dim 1. Returns the grid as a NumPy array on the CPU.
    """
    rows = [
        torch.cat([input_[row][step] for step in range(seq)], 2)
        for row in range(n)
    ]
    grid = torch.cat(rows, 1)
    return grid.cpu().numpy()
def Viz(opt):
    """Build a per-batch visualization callback.

    opt: options namespace. Attributes read: ``visdom``, ``checkpoint``,
        ``image_save``, ``image_save_interval``, ``visdom_interval``, ``name``.

    Creates ``opt.checkpoint`` (and missing parents) if needed, then returns
    ``inner(img, curve, epoch, batch_idx, nbatch, set_)`` which:
    - if ``opt.image_save`` is set, writes ``img`` as a PNG into the
      checkpoint directory every ``opt.image_save_interval`` batches;
    - if ``opt.visdom`` is set, every ``opt.visdom_interval`` batches shows
      ``img`` in visdom (titled with epoch/set/batch) and plots each entry of
      ``curve`` that has more than one point.
    """
    if opt.visdom:
        # Imported lazily so visdom is only required when plotting is enabled.
        import visdom

        vis = visdom.Visdom()
    # Timestamp-based window id prefix; the short sleep presumably keeps ids
    # from two back-to-back Viz() calls distinct.
    visdom_id = "id_" + str(time.time())
    time.sleep(1e-3)
    os.makedirs(opt.checkpoint, exist_ok=True)

    def inner(img, curve, epoch, batch_idx, nbatch, set_):
        if opt.image_save and batch_idx % opt.image_save_interval == 0:
            # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; this
            # branch needs an older SciPy (with Pillow) or a replacement writer.
            scipy.misc.imsave(
                "%s/img_%04d_%s_%06d.png" % (opt.checkpoint, epoch, set_, batch_idx),
                # img is presumably channels-first (C, H, W); moved to (H, W, C).
                np.transpose(img, (1, 2, 0)),
            )
        if opt.visdom and batch_idx % opt.visdom_interval == 0:
            options = {
                "title": "Epoch %02d - %s - Batch %06d/%06d"
                % (epoch, set_, batch_idx, nbatch)
            }
            # Pass the title so it actually shows on the image window.
            vis.image(img=img, win=visdom_id + set_, env=opt.name, opts=options)
            for i, c in curve.items():
                # A line plot needs at least two points.
                if len(c) > 1:
                    vis.line(
                        Y=np.array(c),
                        X=np.arange(len(c)),
                        win=visdom_id + str(i),
                        env=opt.name,
                        opts=dict(title=i),
                    )

    return inner
def to_number(d):
    """Convert the values of dict d to plain numbers.

    d: dict whose values may be torch tensors, indexable containers or
        plain scalars

    Returns a new dict with the same keys where:
    - a tensor becomes the Python number of its first element (a 0-dim
      tensor, e.g. a loss, becomes its only value);
    - an indexable non-tensor value becomes ``value[0]``;
    - anything else (e.g. an int or float) is kept as-is.
    """
    out = {}
    for key, value in d.items():
        if torch.is_tensor(value):
            # ``value.data[0]`` raises IndexError on 0-dim tensors in current
            # PyTorch and returns a tensor, not a number, otherwise; flatten
            # and use ``.item()`` to get a Python scalar in every case.
            out[key] = value.detach().reshape(-1)[0].item()
        else:
            try:
                out[key] = value[0]
            except (TypeError, IndexError):
                # Not indexable (plain number) or a 0-dim / empty container.
                out[key] = value
    return out
def filter(d, patterns):
    """Return a new dict with only the entries of d whose key prefix matches.

    d: dict keyed by dotted names (e.g. "enc.weight")
    patterns: list of prefixes; an entry is kept when the part of its key
        before the first "." equals one of them exactly

    Entries are ordered by pattern first, then by their order in d.
    """
    return {
        name: params
        for prefix in patterns
        for name, params in d.items()
        if name.partition(".")[0] == prefix
    }
class to_namespace:
    """Simple object exposing the items of a dict as attributes."""

    def __init__(self, d):
        """Build the namespace from a dict.

        d: dict whose keys become attribute names and values attribute values
        """
        attributes = {key: value for key, value in d.items()}
        self.__dict__.update(attributes)
def slice_epoch(l, n):
    """Split range(l) into at most n consecutive slices of (approx.) equal length.

    l: number of indices to split
    n: maximum number of slices

    Every slice except possibly the last holds ceil(l / n) indices. Empty
    slices are never returned, so fewer than n slices come back when l is not
    large/divisible enough, e.g. ``slice_epoch(5, 4) == [[0, 1], [2, 3], [4]]``.
    Returns [] when l <= 0 (or n < 0); n == 0 raises ZeroDivisionError.
    """
    slice_size = math.ceil(l / n)
    if l <= 0 or slice_size <= 0:
        return []
    # At most n start offsets, since slice_size >= l / n.
    return [
        list(range(start, min(start + slice_size, l)))
        for start in range(0, l, slice_size)
    ]
def disableBNRunningMeanStd(m):
    """Freeze the running statistics of a BatchNorm module (no-op otherwise).

    m: module; only objects whose class name contains "BatchNorm" are touched.

    Stores the current momentum in ``m._saved_momentum`` and sets momentum to
    0, so the running mean/var are no longer updated by forward passes;
    ``enableBNRunningMeanStd`` restores it. Repeated calls keep the
    originally saved momentum.
    """
    if "BatchNorm" not in m.__class__.__name__:
        return
    # Don't overwrite an existing save: a second call would otherwise store 0
    # and the original momentum could never be restored.
    if not hasattr(m, "_saved_momentum"):
        m._saved_momentum = m.momentum
    m.momentum = 0
def enableBNRunningMeanStd(m):
    """Restore the momentum saved by ``disableBNRunningMeanStd``.

    m: module; only objects whose class name contains "BatchNorm" are touched.

    Copies ``m._saved_momentum`` back into ``m.momentum`` and deletes the saved
    attribute. If nothing was saved (never disabled, or already re-enabled)
    the module is left untouched, so repeated calls are safe.
    """
    if "BatchNorm" not in m.__class__.__name__:
        return
    if not hasattr(m, "_saved_momentum"):
        # Nothing to restore; leave the current momentum as it is.
        return
    m.momentum = m._saved_momentum
    del m._saved_momentum