Skip to content

Commit

Permalink
Merge pull request #2125 from sumana-2705/sumana-2705/steps_newsrec
Browse files Browse the repository at this point in the history
Modified 2 files to update newsrec model
  • Loading branch information
miguelgfierro committed Jul 7, 2024
2 parents b775038 + 65797dc commit ae2338c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
4 changes: 4 additions & 0 deletions recommenders/models/newsrec/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def fit(
valid_behaviors_file,
test_news_file=None,
test_behaviors_file=None,
step_limit=None,

):
"""Fit the model with train_file. Evaluate the model on valid_file per epoch to observe the training status.
If test_news_file is not None, evaluate it too.
Expand All @@ -212,6 +214,8 @@ def fit(
)

for batch_data_input in tqdm_util:
if step_limit is not None and step>=step_limit:
break

step_result = self.train(batch_data_input)
step_data_loss = step_result
Expand Down
8 changes: 4 additions & 4 deletions tests/smoke/recommenders/recommender/test_newsrec_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_model_nrms(mind_resource_path):
assert model.run_eval(valid_news_file, valid_behaviors_file) is not None
assert isinstance(
model.fit(
train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file
train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file,step_limit=10
),
BaseModel,
)
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_model_naml(mind_resource_path):
assert model.run_eval(valid_news_file, valid_behaviors_file) is not None
assert isinstance(
model.fit(
train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file
train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file,step_limit=10
),
BaseModel,
)
Expand Down Expand Up @@ -166,7 +166,7 @@ def test_model_lstur(mind_resource_path):
assert model.run_eval(valid_news_file, valid_behaviors_file) is not None
assert isinstance(
model.fit(
train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file
train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file,step_limit=10
),
BaseModel,
)
Expand Down Expand Up @@ -217,7 +217,7 @@ def test_model_npa(mind_resource_path):
assert model.run_eval(valid_news_file, valid_behaviors_file) is not None
assert isinstance(
model.fit(
train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file
train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file,step_limit=10
),
BaseModel,
)

0 comments on commit ae2338c

Please sign in to comment.