Skip to content

Commit

Permalink
fix: get rid of reduce mocking (for testing)
Browse files Browse the repository at this point in the history
  • Loading branch information
flxst committed Jan 30, 2024
1 parent f4e3c56 commit dfbefcb
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions tests/test_gym.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from unittest.mock import call, patch
from unittest.mock import call

import torch

from modalities.batch import DatasetBatch
from modalities.gym import Gym
from modalities.running_env.fsdp.reducer import Reducer


def test_run_cpu_only(
Expand Down Expand Up @@ -37,15 +36,13 @@ def test_run_cpu_only(
llm_data_loader_mock.__len__ = lambda _: num_batches

gym = Gym(trainer=trainer, evaluator=evaluator_mock, loss_fun=loss_mock, num_ranks=num_ranks)
with patch.object(Reducer, "reduce", return_value=None) as reduce_mock:
gym.run(
model=nn_model_mock,
optimizer=optimizer_mock,
callback_interval_in_batches=int(num_batches),
train_data_loader=llm_data_loader_mock,
evaluation_data_loaders=[],
checkpointing=checkpointing_mock,
)
nn_model_mock.forward.assert_has_calls([call(b.samples) for b in batches])
optimizer_mock.step.assert_called()
reduce_mock.assert_called()
gym.run(
model=nn_model_mock,
optimizer=optimizer_mock,
callback_interval_in_batches=int(num_batches),
train_data_loader=llm_data_loader_mock,
evaluation_data_loaders=[],
checkpointing=checkpointing_mock,
)
nn_model_mock.forward.assert_has_calls([call(b.samples) for b in batches])
optimizer_mock.step.assert_called()

0 comments on commit dfbefcb

Please sign in to comment.