This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Back out "Open-source non-autoregressive optimization" #1605

Open · wants to merge 1 commit into base: main
16 changes: 0 additions & 16 deletions pytext/loss/__init__.py
@@ -18,15 +18,7 @@
NLLLoss,
PairwiseRankingLoss,
SourceType,
MaxMarginLoss,
)
from .regularized_loss import (
LabelSmoothingLoss,
SamplewiseLabelSmoothingLoss,
NARSequenceLoss,
NARSamplewiseSequenceLoss,
)
from .regularizer import UniformRegularizer, EntropyRegularizer, AdaptiveRegularizer


__all__ = [
@@ -46,12 +38,4 @@
"PairwiseRankingLoss",
"LabelSmoothedCrossEntropyLoss",
"SourceType",
"LabelSmoothingLoss",
"SamplewiseLabelSmoothingLoss",
"MaxMarginLoss",
"NARSequenceLoss",
"NARSamplewiseSequenceLoss",
"UniformRegularizer",
"EntropyRegularizer",
"AdaptiveRegularizer",
]
109 changes: 0 additions & 109 deletions pytext/loss/loss.py
@@ -12,19 +12,6 @@
from torch import nn


def maybe_log_normalize(logits, logits_type, dim=-1):
"""Optionally log normalizes logits on the given dimension."""

if logits_type == SourceType.LOGITS:
return F.log_softmax(logits, dim)
elif logits_type == SourceType.PROBS:
return logits.log()
elif logits_type == SourceType.LOG_PROBS:
return logits
else:
raise NotImplementedError


class SourceType(Enum):
LOG_PROBS = "log_probs"
LOGITS = "logits"
@@ -620,99 +607,3 @@ def __call__(self, logits, targets, reduce=True):
self.label_smoothing_loss = label_smoothing_loss

return (1.0 - self.beta) * cross_entropy_loss + self.beta * label_smoothing_loss


class MaxMarginLoss(Loss):
"""
Computes a max-margin loss for structured prediction:
max(0, m + cost(Y',Y) + S(Y'|X) - S(Y|X))

Here, we require the score of the gold sequence S(Y|X) to be _at least_
m + cost(Y',Y) higher than the score of a hypothesis sequence S(Y'|X),
where m = margin and cost(Y',Y) = Hamming distance between Y' and Y.

To efficiently search for the sequence with the largest margin violation
when cost(Y',Y) is included, we greedily decode a sequence with score
S(Y'|X) + I[Y'!=Y]. Intuitively, this forces our model to score the
gold label above other candidate labels.
"""

class Config(ConfigBase):
# Enables m (when m > 0)
margin: float = 0.0
# Enables cost(Y',Y)
use_cost: bool = False
# Multiplies cost(Y',Y) by this amount
cost_scale: float = 1.0

def __init__(self, config, pad_index=1, *args, **kwargs):
self.margin = config.margin
self.use_cost = config.use_cost
self.cost_scale = config.cost_scale
self.pad_index = pad_index

def get_sequence_scores(self, logits, indices, mask):
"""
Computes the score of the sequence: sum_i S(Y_i|Y_<i, X) (note that
the scores do not have to be normalized)

Implementation-wise, using each index in `indices`, we reference its
score in `logits`, then sum over the sequence dimension while ignoring
pad tokens.
"""

token_scores = logits.gather(2, indices.unsqueeze(2)).squeeze(2) * mask # B x T
sequence_score = token_scores.sum(1) # B

return sequence_score

def get_sequence_costs(self, preds, targets, mask):
"""
Computes the Hamming distance for each (pred, target) in the batch, which
is defined as sum_i I[pred_i != target_i]. Pad tokens are ignored.
Optionally, we can use `cost_scale` to increase/decrease the cost.
"""

return self.cost_scale * ((preds != targets) & mask).sum(1) # B

def cost_augment(self, logits, targets):
"""
Loss-augmented inference requires searching for the sequence with the
largest margin violation. Because we use Hamming distance as the
sequence-level cost, we can augment token-level scores with a scalar cost,
which satisfies this global property.

For example, given model scores [-0.5, 0, -1.0] (where i* = 1 is the
gold index), we add Hamming costs I[i != i*], resulting in cost-augmented
scores [0.5, 0, 0]. Here, because i = 0 is now the predicted index,
our model is required to push the gold score higher than other scores.
"""

sequence_cost = self.cost_scale * torch.ones_like(logits) # B x T x V
gold_cost = torch.zeros_like(targets).to(logits.dtype).unsqueeze(2) # B x T x 1
sequence_cost.scatter_(2, targets.unsqueeze(2), gold_cost)
logits.add_(sequence_cost)

return logits

def __call__(self, logits, targets, reduce=True):
if self.use_cost:
logits = self.cost_augment(logits, targets) # B x T x V

preds = logits.argmax(2) # B x T

mask = targets.ne(self.pad_index) # B x T
ref_scores = self.get_sequence_scores(logits, targets, mask) # B
hyp_scores = self.get_sequence_scores(logits, preds, mask) # B

loss = hyp_scores - ref_scores

if self.use_cost:
loss += self.get_sequence_costs(preds, targets, mask) # B

if self.margin > 0.0:
loss += self.margin

loss.clip_(min=0.0) # B

return loss.sum() if reduce else loss
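
As context for what this backout removes, below is a minimal usage sketch of MaxMarginLoss written only against the signatures visible in the diff above (Config fields margin, use_cost, cost_scale; pad_index defaulting to 1). The import path, keyword-style Config construction, tensor shapes, and values are illustrative assumptions, not code taken from PyText's tests.

# Illustrative sketch only: exercises the removed MaxMarginLoss with random tensors.
# Assumes the pre-backout pytext.loss exports shown in the first diff and that
# MaxMarginLoss.Config (a ConfigBase) accepts keyword construction.
import torch

from pytext.loss import MaxMarginLoss

config = MaxMarginLoss.Config(margin=1.0, use_cost=True, cost_scale=1.0)
loss_fn = MaxMarginLoss(config, pad_index=1)

batch_size, seq_len, vocab_size = 2, 5, 8
logits = torch.randn(batch_size, seq_len, vocab_size)            # B x T x V unnormalized token scores
targets = torch.randint(2, vocab_size, (batch_size, seq_len))     # B x T gold indices, avoiding pad_index=1

# Per-sequence hinge loss max(0, margin + cost(Y',Y) + S(Y'|X) - S(Y|X)),
# summed over the batch when reduce=True. Note cost_augment modifies logits in place.
loss = loss_fn(logits, targets, reduce=True)
print(float(loss))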