diff --git a/pytext/torchscript/module.py b/pytext/torchscript/module.py index bcc98a067..dfab60b73 100644 --- a/pytext/torchscript/module.py +++ b/pytext/torchscript/module.py @@ -69,12 +69,12 @@ def __init__( @torch.jit.script_method def forward( self, - dense_feat: List[List[float]], texts: Optional[List[str]] = None, # multi_texts is of shape [batch_size, num_columns] multi_texts: Optional[List[List[str]]] = None, tokens: Optional[List[List[str]]] = None, languages: Optional[List[str]] = None, + dense_feat: Optional[List[List[float]]] = None, ): inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(texts, multi_texts), @@ -107,13 +107,13 @@ def __init__( @torch.jit.script_method def forward( self, - right_dense_feat: List[List[float]], - left_dense_feat: List[List[float]], texts: Optional[List[str]] = None, # multi_texts is of shape [batch_size, num_columns] multi_texts: Optional[List[List[str]]] = None, tokens: Optional[List[List[str]]] = None, languages: Optional[List[str]] = None, + right_dense_feat: Optional[List[List[float]]] = None, + left_dense_feat: Optional[List[List[float]]] = None, ): inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(texts, multi_texts), @@ -150,7 +150,7 @@ def inference_interface(self, argument_type: str): # LANGUAGES = 3 # DENSE_FEAT = 4 - if (sel.argno != -1): + if self.argno != -1: raise RuntimeError("Cannot change argument type.") if argument_type == "texts": self.argno = TEXTS