diff --git a/src/model.py b/src/model.py index a199e1d..a384187 100644 --- a/src/model.py +++ b/src/model.py @@ -24,7 +24,8 @@ def __init__(self, ffn: int = 128, n_classes: int = None, do_normalization: bool = True, - pooling_strategy: str = 'avg'): + pooling_strategy: str = 'avg', + **kwargs): super(LSTMModel, self).__init__() self.embedding = torch.nn.Embedding(vocab_size, embed_dim)