From 5068b6830c8a1d1dd046826ea18f4c481c9d49c1 Mon Sep 17 00:00:00 2001
From: Michael Marlen
Date: Tue, 15 Dec 2020 09:19:54 -0800
Subject: [PATCH] Multi Label Cross Entropy Loss

Summary: Creating a BinaryCrossEntropy loss function with logit-based loss

Differential Revision: D25440017

fbshipit-source-id: 84d481e6146e9210422f8d1b3f94691404ee608c
---
 pytext/loss/__init__.py |  2 ++
 pytext/loss/loss.py     | 23 +++++++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/pytext/loss/__init__.py b/pytext/loss/__init__.py
index 7a4932828..abeda0a7f 100644
--- a/pytext/loss/__init__.py
+++ b/pytext/loss/__init__.py
@@ -4,6 +4,7 @@
 from .loss import (
     AUCPRHingeLoss,
     BinaryCrossEntropyLoss,
+    BinaryCrossEntropyWithLogitsLoss,
     CosineEmbeddingLoss,
     CrossEntropyLoss,
     KLDivergenceBCELoss,
@@ -26,6 +27,7 @@
     "CrossEntropyLoss",
     "CosineEmbeddingLoss",
     "BinaryCrossEntropyLoss",
+    "BinaryCrossEntropyWithLogitsLoss",
     "MultiLabelSoftMarginLoss",
     "KLDivergenceBCELoss",
     "KLDivergenceCELoss",
diff --git a/pytext/loss/loss.py b/pytext/loss/loss.py
index 4fcebabd0..c99956399 100644
--- a/pytext/loss/loss.py
+++ b/pytext/loss/loss.py
@@ -69,6 +69,29 @@ def __call__(self, log_probs, targets, reduce=True):
         )
 
 
+class BinaryCrossEntropyWithLogitsLoss(Loss):
+    class Config(ConfigBase):
+        reduce: bool = True
+
+    def __call__(self, logits, targets, reduce=True):
+        """
+        Computes 1-vs-all binary cross entropy loss for multiclass
+        classification. Unlike BinaryCrossEntropyLoss, this loss
+        expects `targets` to contain one-hot label vectors.
+        """
+        target_labels = targets[0].float()
+
+        # `F.binary_cross_entropy_with_logits` expects FloatTensor
+        # inputs: the targets are cast to float above, and
+        # `precision.maybe_float` casts the logits when needed.
+        loss = F.binary_cross_entropy_with_logits(
+            precision.maybe_float(logits), target_labels, reduction="none"
+        )
+
+        # Sum over the label dimension; average over the batch if reducing.
+        return loss.sum(-1).mean() if reduce else loss.sum(-1)
+
+
 class BinaryCrossEntropyLoss(Loss):
     class Config(ConfigBase):
         reweight_negative: bool = True
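
As a quick sanity check (not part of the patch itself), below is a standalone sketch in plain PyTorch of what the new loss computes. The shapes, tensor values, and variable names are illustrative assumptions, not taken from PyText; only the `F.binary_cross_entropy_with_logits` call with reduction="none" and the sum/mean reduction mirror the patch.

# Standalone sketch (plain PyTorch, outside PyText) of the reduce=True
# path added by this patch: per-label BCE with logits, summed over the
# label dimension and averaged over the batch.
import torch
import torch.nn.functional as F

# Illustrative shapes and values, not taken from the patch.
batch_size, num_classes = 4, 3
logits = torch.randn(batch_size, num_classes)
# One-hot float targets, as BinaryCrossEntropyWithLogitsLoss expects.
targets = F.one_hot(torch.tensor([0, 2, 1, 2]), num_classes).float()

# reduction="none" keeps the (batch_size, num_classes) matrix of
# per-label losses, mirroring the call in the patch.
per_label = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")

# For a target y in {0, 1} and logit x, each entry is
# -(y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))).
manual = -(
    targets * torch.log(torch.sigmoid(logits))
    + (1 - targets) * torch.log(1 - torch.sigmoid(logits))
)
assert torch.allclose(per_label, manual, atol=1e-6)

# reduce=True path: sum over labels, then mean over the batch.
loss = per_label.sum(-1).mean()
print(loss.item())

Note that summing over the label dimension before averaging means each example contributes its total binary loss across all classes, so every example is weighted equally in the batch mean regardless of how many labels fire.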