diff --git a/pytext/loss/__init__.py b/pytext/loss/__init__.py index 7a4932828..abeda0a7f 100644 --- a/pytext/loss/__init__.py +++ b/pytext/loss/__init__.py @@ -4,6 +4,7 @@ from .loss import ( AUCPRHingeLoss, BinaryCrossEntropyLoss, + BinaryCrossEntropyWithLogitsLoss, CosineEmbeddingLoss, CrossEntropyLoss, KLDivergenceBCELoss, @@ -26,6 +27,7 @@ "CrossEntropyLoss", "CosineEmbeddingLoss", "BinaryCrossEntropyLoss", + "BinaryCrossEntropyWithLogitsLoss", "MultiLabelSoftMarginLoss", "KLDivergenceBCELoss", "KLDivergenceCELoss", diff --git a/pytext/loss/loss.py b/pytext/loss/loss.py index 4fcebabd0..c99956399 100644 --- a/pytext/loss/loss.py +++ b/pytext/loss/loss.py @@ -69,6 +69,29 @@ def __call__(self, log_probs, targets, reduce=True): ) +class BinaryCrossEntropyWithLogitsLoss(Loss): + class Config(ConfigBase): + reduce: bool = True + + def __call__(self, logits, targets, reduce=True): + """ + Computes 1-vs-all binary cross entropy loss for multiclass classification. However, unlike BinaryCrossEntropyLoss, we require targets to be a one-hot vector. + """ + + target_labels = targets[0].float() + + """ + `F.binary_cross_entropy_with_logits` requires the + output of the previous function be already a FloatTensor. + """ + + loss = F.binary_cross_entropy_with_logits( + precision.maybe_float(logits), target_labels, reduction="none" + ) + + return loss.sum(-1).mean() if reduce else loss.sum(-1) + + class BinaryCrossEntropyLoss(Loss): class Config(ConfigBase): reweight_negative: bool = True