Skip to content

Commit

Permalink
Fix content_loss shape inference for Theano
Browse files Browse the repository at this point in the history
  • Loading branch information
titu1994 committed Aug 9, 2017
1 parent 669aa17 commit 037a8a7
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
5 changes: 4 additions & 1 deletion INetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,10 @@ def style_loss(style, combination, mask_path=None, nb_channels=None):
def content_loss(base, combination):
channel_dim = 0 if K.image_dim_ordering() == "th" else -1

channels = K.int_shape(base)[channel_dim]
try:
channels = K.int_shape(base)[channel_dim]
except TypeError:
channels = K.shape(base)[channel_dim]
size = img_width * img_height

if args.content_loss_type == 1:
Expand Down
5 changes: 4 additions & 1 deletion Network.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,10 @@ def style_loss(style, combination, mask_path=None, nb_channels=None):
def content_loss(base, combination):
channel_dim = 0 if K.image_dim_ordering() == "th" else -1

channels = K.int_shape(base)[channel_dim]
try:
channels = K.int_shape(base)[channel_dim]
except TypeError:
channels = K.shape(base)[channel_dim]
size = img_width * img_height

if args.content_loss_type == 1:
Expand Down
5 changes: 4 additions & 1 deletion script_helper/Script/INetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,10 @@ def style_loss(style, combination, mask_path=None, nb_channels=None):
def content_loss(base, combination):
channel_dim = 0 if K.image_dim_ordering() == "th" else -1

channels = K.int_shape(base)[channel_dim]
try:
channels = K.int_shape(base)[channel_dim]
except TypeError:
channels = K.shape(base)[channel_dim]
size = img_width * img_height

if args.content_loss_type == 1:
Expand Down
5 changes: 4 additions & 1 deletion script_helper/Script/Network.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,10 @@ def style_loss(style, combination, mask_path=None, nb_channels=None):
def content_loss(base, combination):
channel_dim = 0 if K.image_dim_ordering() == "th" else -1

channels = K.int_shape(base)[channel_dim]
try:
channels = K.int_shape(base)[channel_dim]
except TypeError:
channels = K.shape(base)[channel_dim]
size = img_width * img_height

if args.content_loss_type == 1:
Expand Down

0 comments on commit 037a8a7

Please sign in to comment.