Skip to content

Commit

Permalink
Fix bug with --content_loss_type=1
Browse files Browse the repository at this point in the history
  • Loading branch information
titu1994 committed Jun 27, 2017
1 parent 98c688f commit 17c7b92
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
6 changes: 3 additions & 3 deletions INetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,13 +426,13 @@ def style_loss(style, combination, mask_path=None, nb_channels=None):
def content_loss(base, combination):
channel_dim = 0 if K.image_dim_ordering() == "th" else -1

channels = K.shape(base)[channel_dim]
channels = K.int_shape(base)[channel_dim]
size = img_width * img_height

if args.content_loss_type == 1:
multiplier = 1 / (2. * channels ** 0.5 * size ** 0.5)
multiplier = 1. / (2. * (channels ** 0.5) * (size ** 0.5))
elif args.content_loss_type == 2:
multiplier = 1 / (channels * size)
multiplier = 1. / (channels * size)
else:
multiplier = 1.

Expand Down
6 changes: 3 additions & 3 deletions Network.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,13 +422,13 @@ def style_loss(style, combination, mask_path=None, nb_channels=None):
def content_loss(base, combination):
channel_dim = 0 if K.image_dim_ordering() == "th" else -1

channels = K.shape(base)[channel_dim]
channels = K.int_shape(base)[channel_dim]
size = img_width * img_height

if args.content_loss_type == 1:
multiplier = 1 / (2. * channels ** 0.5 * size ** 0.5)
multiplier = 1. / (2. * (channels ** 0.5) * (size ** 0.5))
elif args.content_loss_type == 2:
multiplier = 1 / (channels * size)
multiplier = 1. / (channels * size)
else:
multiplier = 1.

Expand Down
6 changes: 3 additions & 3 deletions script_helper/Script/INetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,13 +426,13 @@ def style_loss(style, combination, mask_path=None, nb_channels=None):
def content_loss(base, combination):
channel_dim = 0 if K.image_dim_ordering() == "th" else -1

channels = K.shape(base)[channel_dim]
channels = K.int_shape(base)[channel_dim]
size = img_width * img_height

if args.content_loss_type == 1:
multiplier = 1 / (2. * channels ** 0.5 * size ** 0.5)
multiplier = 1. / (2. * (channels ** 0.5) * (size ** 0.5))
elif args.content_loss_type == 2:
multiplier = 1 / (channels * size)
multiplier = 1. / (channels * size)
else:
multiplier = 1.

Expand Down
6 changes: 3 additions & 3 deletions script_helper/Script/Network.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,13 +422,13 @@ def style_loss(style, combination, mask_path=None, nb_channels=None):
def content_loss(base, combination):
channel_dim = 0 if K.image_dim_ordering() == "th" else -1

channels = K.shape(base)[channel_dim]
channels = K.int_shape(base)[channel_dim]
size = img_width * img_height

if args.content_loss_type == 1:
multiplier = 1 / (2. * channels ** 0.5 * size ** 0.5)
multiplier = 1. / (2. * (channels ** 0.5) * (size ** 0.5))
elif args.content_loss_type == 2:
multiplier = 1 / (channels * size)
multiplier = 1. / (channels * size)
else:
multiplier = 1.

Expand Down

0 comments on commit 17c7b92

Please sign in to comment.