diff --git a/docmae/train.py b/docmae/train.py index 3592371..f61b060 100644 --- a/docmae/train.py +++ b/docmae/train.py @@ -123,7 +123,7 @@ def train(args, config: dict): filename="epoch_{epoch:02d}", monitor="val/loss", mode="min", - save_top_k=2, + save_top_k=1, ), ] @@ -138,6 +138,7 @@ def train(args, config: dict): max_steps=config["training"].get("steps", -1), num_sanity_val_steps=1, enable_progress_bar=config["progress_bar"], + limit_train_batches=20 ) hidden_dim = config["model"]["hidden_dim"]