@@ -243,3 +243,100 @@ def forward(self, x):
        out = x + self.conv_block(x)
        return out

+
+##################################### Discriminators ####################################################
+
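+# MultiscaleDiscriminator wraps num_D PatchGAN discriminators and applies each one
+# to a progressively downsampled copy of the input image.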
+class MultiscaleDiscriminator(nn.Module):
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
+                 use_sigmoid=False, num_D=3, getIntermFeat=False):
+        super(MultiscaleDiscriminator, self).__init__()
+        self.num_D = num_D
+        self.n_layers = n_layers
+        self.getIntermFeat = getIntermFeat
+
+        for i in range(num_D):
+            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
+            if getIntermFeat:
+                for j in range(n_layers + 2):
+                    setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
+            else:
+                setattr(self, 'layer' + str(i), netD.model)
+
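+        # 3x3 average pooling with stride 2 roughly halves the resolution between scales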
+        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
+
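+    # Run one discriminator; return all intermediate activations when getIntermFeat is set,
+    # otherwise just the final prediction map.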
+    def singleD_forward(self, model, input):
+        if self.getIntermFeat:
+            result = [input]
+            for i in range(len(model)):
+                result.append(model[i](result[-1]))
+            return result[1:]
+        else:
+            return [model(input)]
+
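+    # Evaluate every scale, downsampling the input after each discriminator.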
+    def forward(self, input):
+        num_D = self.num_D
+        result = []
+        input_downsampled = input
+        for i in range(num_D):
+            if self.getIntermFeat:
+                model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in range(self.n_layers + 2)]
+            else:
+                model = getattr(self, 'layer' + str(num_D - 1 - i))
+            result.append(self.singleD_forward(model, input_downsampled))
+            if i != (num_D - 1):
+                input_downsampled = self.downsample(input_downsampled)
+        return result
+
+# Defines the PatchGAN discriminator with the specified arguments.
+class NLayerDiscriminator(nn.Module):
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
+        super(NLayerDiscriminator, self).__init__()
+        self.getIntermFeat = getIntermFeat
+        self.n_layers = n_layers
+
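+        # 4x4 convolution kernels; padw = ceil((kw - 1) / 2) = 2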
+        kw = 4
+        padw = int(np.ceil((kw - 1.0) / 2))
+        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
+
+        nf = ndf
+        for n in range(1, n_layers):
+            nf_prev = nf
+            nf = min(nf * 2, 512)
+            sequence += [[
+                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
+                norm_layer(nf), nn.LeakyReLU(0.2, True)
+            ]]
+
+        nf_prev = nf
+        nf = min(nf * 2, 512)
+        sequence += [[
+            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
+            norm_layer(nf),
+            nn.LeakyReLU(0.2, True)
+        ]]
+
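+        # Final 1-channel convolution produces the patch-wise real/fake prediction map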
+        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+
+        if use_sigmoid:
+            sequence += [[nn.Sigmoid()]]
+
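+        # Register each block separately when intermediate features are needed
+        # (e.g. for a feature-matching loss); otherwise flatten into one nn.Sequential.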
+        if getIntermFeat:
+            for n in range(len(sequence)):
+                setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
+        else:
+            sequence_stream = []
+            for n in range(len(sequence)):
+                sequence_stream += sequence[n]
+            self.model = nn.Sequential(*sequence_stream)
+
+    def forward(self, input):
+        if self.getIntermFeat:
+            res = [input]
+            for n in range(self.n_layers + 2):
+                model = getattr(self, 'model' + str(n))
+                res.append(model(res[-1]))
+            return res[1:]
+        else:
+            return self.model(input)
+
+