diff --git a/recommenders/models/newsrec/models/base_model.py b/recommenders/models/newsrec/models/base_model.py index e0f4da017..344ea8f7c 100644 --- a/recommenders/models/newsrec/models/base_model.py +++ b/recommenders/models/newsrec/models/base_model.py @@ -186,6 +186,8 @@ def fit( valid_behaviors_file, test_news_file=None, test_behaviors_file=None, + step_limit=None, + ): """Fit the model with train_file. Evaluate the model on valid_file per epoch to observe the training status. If test_news_file is not None, evaluate it too. @@ -212,6 +214,8 @@ def fit( ) for batch_data_input in tqdm_util: + if step_limit is not None and step>=step_limit: + break step_result = self.train(batch_data_input) step_data_loss = step_result diff --git a/tests/smoke/recommenders/recommender/test_newsrec_model.py b/tests/smoke/recommenders/recommender/test_newsrec_model.py index 7cad05ba3..db609f098 100644 --- a/tests/smoke/recommenders/recommender/test_newsrec_model.py +++ b/tests/smoke/recommenders/recommender/test_newsrec_model.py @@ -62,7 +62,7 @@ def test_model_nrms(mind_resource_path): assert model.run_eval(valid_news_file, valid_behaviors_file) is not None assert isinstance( model.fit( - train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file + train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file,step_limit=10 ), BaseModel, ) @@ -115,7 +115,7 @@ def test_model_naml(mind_resource_path): assert model.run_eval(valid_news_file, valid_behaviors_file) is not None assert isinstance( model.fit( - train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file + train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file,step_limit=10 ), BaseModel, ) @@ -166,7 +166,7 @@ def test_model_lstur(mind_resource_path): assert model.run_eval(valid_news_file, valid_behaviors_file) is not None assert isinstance( model.fit( - train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file + train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file,step_limit=10 ), BaseModel, ) @@ -217,7 +217,7 @@ def test_model_npa(mind_resource_path): assert model.run_eval(valid_news_file, valid_behaviors_file) is not None assert isinstance( model.fit( - train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file + train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file,step_limit=10 ), BaseModel, )