Skip to content

Commit 50bf332

Browse files
author
Andrey Ryabtsev
committed
apply changes
1 parent 18ff37d commit 50bf332

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

Diff for: functions.py

+30
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,28 @@ def composite4(fg, bg, a):
99
im = im.astype(np.uint8)
1010
return im
1111

12+
def compose_image_withshift(alpha_pred,fg_pred,bg,seg):
13+
14+
image_sh=torch.zeros(fg_pred.shape).cuda()
15+
16+
for t in range(0,fg_pred.shape[0]):
17+
al_tmp=to_image(seg[t,...]).squeeze(2)
18+
where = np.array(np.where((al_tmp>0.1).astype(np.float32)))
19+
x1, y1 = np.amin(where, axis=1)
20+
x2, y2 = np.amax(where, axis=1)
21+
22+
#select shift
23+
n=np.random.randint(-(y1-10),al_tmp.shape[1]-y2-10)
24+
#n positive indicates shift to right
25+
alpha_pred_sh=torch.cat((alpha_pred[t,:,:,-n:],alpha_pred[t,:,:,:-n]),dim=2)
26+
fg_pred_sh=torch.cat((fg_pred[t,:,:,-n:],fg_pred[t,:,:,:-n]),dim=2)
27+
28+
alpha_pred_sh=(alpha_pred_sh+1)/2
29+
30+
image_sh[t,...]=fg_pred_sh*alpha_pred_sh + (1-alpha_pred_sh)*bg[t,...]
31+
32+
return Variable(image_sh.cuda())
33+
1234
def get_bbox(mask,R,C):
1335
where = np.array(np.where(mask))
1436
x1, y1 = np.amin(where, axis=1)
@@ -76,3 +98,11 @@ def to_image(rec0):
7698
rec0[rec0<0]=0
7799
return rec0
78100

101+
def write_tb_log(image,tag,log_writer,i):
102+
# image1
103+
output_to_show = image.cpu().data[0:4,...]
104+
output_to_show = (output_to_show + 1)/2.0
105+
grid = torchvision.utils.make_grid(output_to_show,nrow=4)
106+
107+
log_writer.add_image(tag, grid, i + 1)
108+

Diff for: networks.py

+97
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,100 @@ def forward(self, x):
243243
out = x + self.conv_block(x)
244244
return out
245245

246+
247+
##################################### Discriminators ####################################################
248+
249+
class MultiscaleDiscriminator(nn.Module):
250+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
251+
use_sigmoid=False, num_D=3, getIntermFeat=False):
252+
super(MultiscaleDiscriminator, self).__init__()
253+
self.num_D = num_D
254+
self.n_layers = n_layers
255+
self.getIntermFeat = getIntermFeat
256+
257+
for i in range(num_D):
258+
netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
259+
if getIntermFeat:
260+
for j in range(n_layers+2):
261+
setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
262+
else:
263+
setattr(self, 'layer'+str(i), netD.model)
264+
265+
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
266+
267+
def singleD_forward(self, model, input):
268+
if self.getIntermFeat:
269+
result = [input]
270+
for i in range(len(model)):
271+
result.append(model[i](result[-1]))
272+
return result[1:]
273+
else:
274+
return [model(input)]
275+
276+
def forward(self, input):
277+
num_D = self.num_D
278+
result = []
279+
input_downsampled = input
280+
for i in range(num_D):
281+
if self.getIntermFeat:
282+
model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
283+
else:
284+
model = getattr(self, 'layer'+str(num_D-1-i))
285+
result.append(self.singleD_forward(model, input_downsampled))
286+
if i != (num_D-1):
287+
input_downsampled = self.downsample(input_downsampled)
288+
return result
289+
290+
# Defines the PatchGAN discriminator with the specified arguments.
291+
class NLayerDiscriminator(nn.Module):
292+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
293+
super(NLayerDiscriminator, self).__init__()
294+
self.getIntermFeat = getIntermFeat
295+
self.n_layers = n_layers
296+
297+
kw = 4
298+
padw = int(np.ceil((kw-1.0)/2))
299+
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
300+
301+
nf = ndf
302+
for n in range(1, n_layers):
303+
nf_prev = nf
304+
nf = min(nf * 2, 512)
305+
sequence += [[
306+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
307+
norm_layer(nf), nn.LeakyReLU(0.2, True)
308+
]]
309+
310+
nf_prev = nf
311+
nf = min(nf * 2, 512)
312+
sequence += [[
313+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
314+
norm_layer(nf),
315+
nn.LeakyReLU(0.2, True)
316+
]]
317+
318+
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
319+
320+
if use_sigmoid:
321+
sequence += [[nn.Sigmoid()]]
322+
323+
if getIntermFeat:
324+
for n in range(len(sequence)):
325+
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
326+
else:
327+
sequence_stream = []
328+
for n in range(len(sequence)):
329+
sequence_stream += sequence[n]
330+
self.model = nn.Sequential(*sequence_stream)
331+
332+
def forward(self, input):
333+
if self.getIntermFeat:
334+
res = [input]
335+
for n in range(self.n_layers+2):
336+
model = getattr(self, 'model'+str(n))
337+
res.append(model(res[-1]))
338+
return res[1:]
339+
else:
340+
return self.model(input)
341+
342+

0 commit comments

Comments
 (0)