Multi Label Cross Entropy Loss
Summary: Creating a BinaryCrossEntropy loss function with logit based loss

Differential Revision: D25440017

fbshipit-source-id: 84d481e6146e9210422f8d1b3f94691404ee608c
Michael Marlen authored and facebook-github-bot committed Dec 15, 2020
1 parent 80bcfec commit 5068b68
Showing 2 changed files with 25 additions and 0 deletions.
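
For context on the "logit based loss" mentioned in the summary: `F.binary_cross_entropy_with_logits` folds the sigmoid into the loss computation, which is numerically more stable than applying `torch.sigmoid` first and calling `F.binary_cross_entropy` on the resulting probabilities. A small illustration of that equivalence (not part of this commit):

import torch
import torch.nn.functional as F

# Toy batch: 4 examples, 3 classes, one-hot targets.
logits = torch.randn(4, 3)
targets = torch.eye(3)[torch.tensor([0, 2, 1, 0])]

# Logit-based loss: the sigmoid is applied internally.
loss_logits = F.binary_cross_entropy_with_logits(logits, targets)

# Same quantity computed on probabilities (less stable for large logits).
loss_probs = F.binary_cross_entropy(torch.sigmoid(logits), targets)

assert torch.allclose(loss_logits, loss_probs, atol=1e-6)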
2 changes: 2 additions & 0 deletions pytext/loss/__init__.py
@@ -4,6 +4,7 @@
from .loss import (
    AUCPRHingeLoss,
    BinaryCrossEntropyLoss,
    BinaryCrossEntropyWithLogitsLoss,
    CosineEmbeddingLoss,
    CrossEntropyLoss,
    KLDivergenceBCELoss,
@@ -26,6 +27,7 @@
"CrossEntropyLoss",
"CosineEmbeddingLoss",
"BinaryCrossEntropyLoss",
"BinaryCrossEntropyWithLogitsLoss",
"MultiLabelSoftMarginLoss",
"KLDivergenceBCELoss",
"KLDivergenceCELoss",
23 changes: 23 additions & 0 deletions pytext/loss/loss.py
@@ -69,6 +69,29 @@ def __call__(self, log_probs, targets, reduce=True):
        )


class BinaryCrossEntropyWithLogitsLoss(Loss):
    class Config(ConfigBase):
        reduce: bool = True

    def __call__(self, logits, targets, reduce=True):
        """
        Computes 1-vs-all binary cross entropy loss for multiclass
        classification. Unlike BinaryCrossEntropyLoss, this loss requires
        targets to be a one-hot vector.
        """

        target_labels = targets[0].float()

        # `F.binary_cross_entropy_with_logits` requires the output of the
        # previous function to already be a FloatTensor.
        loss = F.binary_cross_entropy_with_logits(
            precision.maybe_float(logits), target_labels, reduction="none"
        )

        return loss.sum(-1).mean() if reduce else loss.sum(-1)


class BinaryCrossEntropyLoss(Loss):
    class Config(ConfigBase):
        reweight_negative: bool = True
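
A rough usage sketch of the new class (not from this commit; the direct construction from its Config and the `targets` tuple layout are assumptions inferred from the `targets[0]` access in `__call__`):

import torch

from pytext.loss import BinaryCrossEntropyWithLogitsLoss

# Hypothetical construction; in PyText the loss is normally built from its
# Config by the surrounding model/output-layer plumbing.
loss_fn = BinaryCrossEntropyWithLogitsLoss(BinaryCrossEntropyWithLogitsLoss.Config())

logits = torch.randn(8, 5)                         # batch of 8, 5 labels
one_hot = torch.eye(5)[torch.randint(0, 5, (8,))]  # one-hot targets

# Per-label BCE is summed over labels, then averaged over the batch.
loss = loss_fn(logits, (one_hot,), reduce=True)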
