diff --git a/pytext/models/output_layers/doc_classification_output_layer.py b/pytext/models/output_layers/doc_classification_output_layer.py index fb0b71711..f61ea7341 100644 --- a/pytext/models/output_layers/doc_classification_output_layer.py +++ b/pytext/models/output_layers/doc_classification_output_layer.py @@ -17,6 +17,7 @@ KLDivergenceCELoss, LabelSmoothedCrossEntropyLoss, MultiLabelSoftMarginLoss, + BinaryCrossEntropyWithLogitsLoss, ) from pytext.utils.label import get_label_weights from torch import jit @@ -43,6 +44,7 @@ class Config(OutputLayerBase.Config): loss: Union[ CrossEntropyLoss.Config, BinaryCrossEntropyLoss.Config, + BinaryCrossEntropyWithLogitsLoss.Config, MultiLabelSoftMarginLoss.Config, AUCPRHingeLoss.Config, KLDivergenceBCELoss.Config, @@ -83,6 +85,8 @@ def from_config( cls = BinaryClassificationOutputLayer elif isinstance(loss, MultiLabelSoftMarginLoss): cls = MultiLabelOutputLayer + elif isinstance(loss, BinaryCrossEntropyWithLogitsLoss): + cls = MultiLabelOutputLayer else: cls = MulticlassOutputLayer