From a664f4dab76cf0f482f01f1701235acf3a77e33d Mon Sep 17 00:00:00 2001
From: IanTayler
Date: Tue, 12 Dec 2017 18:25:52 -0300
Subject: [PATCH] Add a safe wrapper for cross entropy.

This fixes an error we were getting when softmax_cross_entropy_with_logits
received empty tensors.
---
 luminoth/models/fasterrcnn/rcnn.py |  5 ++++-
 luminoth/models/fasterrcnn/rpn.py  |  5 ++++-
 luminoth/utils/safe_wrappers.py    | 16 ++++++++++++++++
 3 files changed, 24 insertions(+), 2 deletions(-)
 create mode 100644 luminoth/utils/safe_wrappers.py

diff --git a/luminoth/models/fasterrcnn/rcnn.py b/luminoth/models/fasterrcnn/rcnn.py
index ecf1dd73..50421f5d 100644
--- a/luminoth/models/fasterrcnn/rcnn.py
+++ b/luminoth/models/fasterrcnn/rcnn.py
@@ -5,6 +5,9 @@
 from luminoth.models.fasterrcnn.rcnn_target import RCNNTarget
 from luminoth.models.fasterrcnn.roi_pool import ROIPoolingLayer
 from luminoth.utils.losses import smooth_l1_loss
+from luminoth.utils.safe_wrappers import (
+    safe_softmax_cross_entropy_with_logits
+)
 from luminoth.utils.vars import (
     get_initializer, layer_summaries, variable_summaries,
     get_activation_function
@@ -304,7 +307,7 @@ def loss(self, prediction_dict):
 
             # We get cross entropy loss of each proposal.
             cross_entropy_per_proposal = (
-                tf.nn.softmax_cross_entropy_with_logits(
+                safe_softmax_cross_entropy_with_logits(
                     labels=cls_target_one_hot, logits=cls_score_labeled
                 )
             )
diff --git a/luminoth/models/fasterrcnn/rpn.py b/luminoth/models/fasterrcnn/rpn.py
index bfaa5200..8921578a 100644
--- a/luminoth/models/fasterrcnn/rpn.py
+++ b/luminoth/models/fasterrcnn/rpn.py
@@ -10,6 +10,9 @@
 from .rpn_target import RPNTarget
 from .rpn_proposal import RPNProposal
 from luminoth.utils.losses import smooth_l1_loss
+from luminoth.utils.safe_wrappers import (
+    safe_softmax_cross_entropy_with_logits
+)
 from luminoth.utils.vars import (
     get_initializer, layer_summaries, variable_summaries,
     get_activation_function
@@ -257,7 +260,7 @@ def loss(self, prediction_dict):
 
             cls_target = tf.one_hot(labels, depth=2)
             # Equivalent to log loss
-            ce_per_anchor = tf.nn.softmax_cross_entropy_with_logits(
+            ce_per_anchor = safe_softmax_cross_entropy_with_logits(
                 labels=cls_target, logits=cls_score
             )
             prediction_dict['cross_entropy_per_anchor'] = ce_per_anchor
diff --git a/luminoth/utils/safe_wrappers.py b/luminoth/utils/safe_wrappers.py
new file mode 100644
index 00000000..31210f45
--- /dev/null
+++ b/luminoth/utils/safe_wrappers.py
@@ -0,0 +1,16 @@
+import tensorflow as tf
+
+
+def safe_softmax_cross_entropy_with_logits(
+        labels, logits, name='safe_cross_entropy'):
+    with tf.name_scope(name):
+        safety_condition = tf.greater(
+            tf.shape(logits)[0], 0, name='safety_condition'
+        )
+        return tf.cond(
+            safety_condition,
+            true_fn=lambda: tf.nn.softmax_cross_entropy_with_logits(
+                labels=labels, logits=logits
+            ),
+            false_fn=lambda: tf.constant([], dtype=logits.dtype)
+        )
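
Below is a minimal usage sketch, not part of the patch, showing the behaviour the wrapper is meant to provide. It assumes Luminoth and a TF 1.x runtime are installed and runs in graph mode; the two-class shapes and the feed values are illustrative assumptions, not taken from the patch.

import numpy as np
import tensorflow as tf

from luminoth.utils.safe_wrappers import safe_softmax_cross_entropy_with_logits

# Assumed for illustration: a 2-class target, as in the RPN's
# object/background classification.
num_classes = 2
labels = tf.placeholder(tf.float32, shape=(None, num_classes))
logits = tf.placeholder(tf.float32, shape=(None, num_classes))

loss = safe_softmax_cross_entropy_with_logits(labels=labels, logits=logits)

with tf.Session() as sess:
    # Normal case: one row per proposal/anchor, one loss value per row.
    print(sess.run(loss, feed_dict={
        labels: [[1., 0.], [0., 1.]],
        logits: [[2., -1.], [0.5, 0.3]],
    }))

    # Degenerate case: zero rows. The tf.cond guard takes the false branch
    # and returns an empty tensor instead of evaluating the cross-entropy op
    # on empty inputs.
    empty = np.zeros((0, num_classes), dtype=np.float32)
    print(sess.run(loss, feed_dict={labels: empty, logits: empty}))

Returning tf.constant([], dtype=logits.dtype) from the false branch keeps both tf.cond branches at the same dtype and rank as the per-row losses, so downstream code that consumes the per-proposal or per-anchor losses still receives a tensor of the expected kind when nothing survives the sampling step.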