diff --git a/src/main.py b/src/main.py index 180ee5e..09a2064 100644 --- a/src/main.py +++ b/src/main.py @@ -39,7 +39,7 @@ def main(path_to_data: str, Model = LSTMModel( vocab_size=vocab_size, - n_classes=2 + n_classes=n_classes ) if n_classes > 2: