
Making seq2seq_model more torchscript-friendly
Summary:
seq2seq_model.py changes:
- removes the dependence on _ when unpacking the tensor dicts, indexing into them through named class constants instead (see the sketch below)
- adds explicit None checks on dict_feats in place of truthiness tests (see the smoke test after the diff)
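
The first change swaps tuple unpacking for explicit indexing. A minimal sketch of the pattern under torch.jit.script — the free function, the (tokens, lengths, token_ranges) layout, and the literal indices are illustrative stand-ins, not PyText's actual code:

import torch
from typing import Tuple


@torch.jit.script
def arrange(
    src_seq_tokens: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Before the commit: src_tokens, src_lengths, _ = src_seq_tokens
    # After: index each position explicitly; the commit names the
    # positions SRC_TOKENS_TENSORIZER_INDEX = 0 and
    # SRC_LENGTHS_TENSORIZER_INDEX = 1 on the model class.
    src_tokens = src_seq_tokens[0]
    src_lengths = src_seq_tokens[1]
    return src_tokens, src_lengths

The target-side tuple gets the same treatment through TRG_TOKENS_TENSORIZER_INDEX and TRG_LENGTHS_TENSORIZER_INDEX.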

Differential Revision: D23673319

fbshipit-source-id: a50efc6418af57cb2fa89273ddb14d51ca13b0e5
shreydesai authored and facebook-github-bot committed Sep 13, 2020
1 parent 3e7b626 commit 4d401d3
Showing 1 changed file: pytext/models/seq_models/seq2seq_model.py (17 additions, 5 deletions)
@@ -35,6 +35,11 @@ class Seq2SeqModel(Model):
     Sequence to sequence model using an encoder-decoder architecture.
     """
 
+    SRC_TOKENS_TENSORIZER_INDEX = 0
+    SRC_LENGTHS_TENSORIZER_INDEX = 1
+    TRG_TOKENS_TENSORIZER_INDEX = 0
+    TRG_LENGTHS_TENSORIZER_INDEX = 1
+
     class Config(Model.Config):
         class ModelInput(Model.Config.ModelInput):
             src_seq_tokens: TokenTensorizer.Config = TokenTensorizer.Config()
@@ -107,8 +112,13 @@ def arrange_model_inputs(
         torch.Tensor,
         torch.Tensor,
     ]:
-        src_tokens, src_lengths, _ = tensor_dict["src_seq_tokens"]
-        trg_tokens, trg_lengths, _ = tensor_dict["trg_seq_tokens"]
+        src_seq_tokens = tensor_dict["src_seq_tokens"]
+        trg_seq_tokens = tensor_dict["trg_seq_tokens"]
+
+        src_tokens = src_seq_tokens[self.SRC_TOKENS_TENSORIZER_INDEX]
+        src_lengths = src_seq_tokens[self.SRC_LENGTHS_TENSORIZER_INDEX]
+        trg_tokens = trg_seq_tokens[self.TRG_TOKENS_TENSORIZER_INDEX]
+        trg_lengths = trg_seq_tokens[self.TRG_LENGTHS_TENSORIZER_INDEX]
 
         def _shift_target(in_sequences, seq_lens, eos_idx, pad_idx):
             shifted_sequence = GetTensor(
@@ -136,7 +146,9 @@ def _shift_target(in_sequences, seq_lens, eos_idx, pad_idx):
         )
 
     def arrange_targets(self, tensor_dict):
-        trg_tokens, trg_lengths, _ = tensor_dict["trg_seq_tokens"]
+        trg_seq_tokens = tensor_dict["trg_seq_tokens"]
+        trg_tokens = trg_seq_tokens[self.TRG_TOKENS_TENSORIZER_INDEX]
+        trg_lengths = trg_seq_tokens[self.TRG_LENGTHS_TENSORIZER_INDEX]
         return (trg_tokens, trg_lengths)
 
     def __init__(
@@ -196,7 +208,7 @@ def forward(
     ):
         additional_features: List[List[torch.Tensor]] = []
 
-        if dict_feats:
+        if dict_feats is not None:
             additional_features.append(list(dict_feats))
 
         if contextual_token_embedding is not None:
@@ -206,7 +218,7 @@
             src_tokens, additional_features, src_lengths, trg_tokens
         )
 
-        if dict_feats:
+        if dict_feats is not None:
             (
                 output_dict["dict_tokens"],
                 output_dict["dict_weights"],