Add a safe wrapper for cross entropy. #142

Open: wants to merge 1 commit into base: master
5 changes: 4 additions & 1 deletion luminoth/models/fasterrcnn/rcnn.py
@@ -5,6 +5,9 @@
 from luminoth.models.fasterrcnn.rcnn_target import RCNNTarget
 from luminoth.models.fasterrcnn.roi_pool import ROIPoolingLayer
 from luminoth.utils.losses import smooth_l1_loss
+from luminoth.utils.safe_wrappers import (
+    safe_softmax_cross_entropy_with_logits
+)
 from luminoth.utils.vars import (
     get_initializer, layer_summaries, variable_summaries,
     get_activation_function
@@ -304,7 +307,7 @@ def loss(self, prediction_dict):

             # We get cross entropy loss of each proposal.
             cross_entropy_per_proposal = (
-                tf.nn.softmax_cross_entropy_with_logits(
+                safe_softmax_cross_entropy_with_logits(
                     labels=cls_target_one_hot, logits=cls_score_labeled
                 )
             )
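Note: the swap above presumably guards the case where no proposals end up labeled, so `cls_score_labeled` (and `cls_target_one_hot`) can have zero rows; the wrapper added below returns an empty loss tensor instead of evaluating the cross entropy op on an empty batch.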
5 changes: 4 additions & 1 deletion luminoth/models/fasterrcnn/rpn.py
@@ -10,6 +10,9 @@
 from .rpn_target import RPNTarget
 from .rpn_proposal import RPNProposal
 from luminoth.utils.losses import smooth_l1_loss
+from luminoth.utils.safe_wrappers import (
+    safe_softmax_cross_entropy_with_logits
+)
 from luminoth.utils.vars import (
     get_initializer, layer_summaries, variable_summaries,
     get_activation_function
@@ -257,7 +260,7 @@ def loss(self, prediction_dict):
             cls_target = tf.one_hot(labels, depth=2)

             # Equivalent to log loss
-            ce_per_anchor = tf.nn.softmax_cross_entropy_with_logits(
+            ce_per_anchor = safe_softmax_cross_entropy_with_logits(
                 labels=cls_target, logits=cls_score
             )
             prediction_dict['cross_entropy_per_anchor'] = ce_per_anchor
16 changes: 16 additions & 0 deletions luminoth/utils/safe_wrappers.py
@@ -0,0 +1,16 @@
import tensorflow as tf


def safe_softmax_cross_entropy_with_logits(
        labels, logits, name='safe_cross_entropy'):
    """Softmax cross entropy that tolerates an empty `logits` batch.

    Wraps `tf.nn.softmax_cross_entropy_with_logits` in a `tf.cond` so
    that, when `logits` has zero rows, an empty tensor of the same
    dtype is returned instead of evaluating the op.
    """
    with tf.name_scope(name):
        # True only when there is at least one row of logits.
        safety_condition = tf.greater(
            tf.shape(logits)[0], 0, name='safety_condition'
        )
        return tf.cond(
            safety_condition,
            true_fn=lambda: tf.nn.softmax_cross_entropy_with_logits(
                labels=labels, logits=logits
            ),
            false_fn=lambda: tf.constant([], dtype=logits.dtype)
        )
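
For reference, a minimal usage sketch of the wrapper, assuming the TF1-style graph and session API that Luminoth targets (the placeholder shapes and feed values here are illustrative, not part of the PR):

import numpy as np
import tensorflow as tf

from luminoth.utils.safe_wrappers import (
    safe_softmax_cross_entropy_with_logits
)

# The batch dimension is unknown when the graph is built and may be
# zero at run time.
labels = tf.placeholder(tf.float32, shape=(None, 2))
logits = tf.placeholder(tf.float32, shape=(None, 2))
loss = safe_softmax_cross_entropy_with_logits(labels=labels, logits=logits)

with tf.Session() as sess:
    # Empty batch: the cond short-circuits to an empty constant.
    out = sess.run(loss, feed_dict={
        labels: np.zeros((0, 2), np.float32),
        logits: np.zeros((0, 2), np.float32),
    })
    print(out.shape)  # (0,)

    # Non-empty batch: behaves exactly like
    # tf.nn.softmax_cross_entropy_with_logits.
    out = sess.run(loss, feed_dict={
        labels: [[1.0, 0.0]],
        logits: [[2.0, -1.0]],
    })
    print(out.shape)  # (1,)

Since tf.cond picks a branch at run time, the same graph handles both the empty and non-empty cases without being rebuilt.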