diff --git a/pytext/torchscript/module.py b/pytext/torchscript/module.py index 255db81ea..dfab60b73 100644 --- a/pytext/torchscript/module.py +++ b/pytext/torchscript/module.py @@ -69,12 +69,12 @@ def __init__( @torch.jit.script_method def forward( self, - dense_feat: List[List[float]], texts: Optional[List[str]] = None, # multi_texts is of shape [batch_size, num_columns] multi_texts: Optional[List[List[str]]] = None, tokens: Optional[List[List[str]]] = None, languages: Optional[List[str]] = None, + dense_feat: Optional[List[List[float]]] = None, ): inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(texts, multi_texts), @@ -107,13 +107,13 @@ def __init__( @torch.jit.script_method def forward( self, - right_dense_feat: List[List[float]], - left_dense_feat: List[List[float]], texts: Optional[List[str]] = None, # multi_texts is of shape [batch_size, num_columns] multi_texts: Optional[List[List[str]]] = None, tokens: Optional[List[List[str]]] = None, languages: Optional[List[str]] = None, + right_dense_feat: Optional[List[List[float]]] = None, + left_dense_feat: Optional[List[List[float]]] = None, ): inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(texts, multi_texts),