mock and reuse
v-chen_data committed Dec 1, 2024
1 parent e0a87af commit 8910f25
Showing 1 changed file with 21 additions and 12 deletions.
tests/test_events.py
@@ -16,27 +16,31 @@
 from tests.common.events import EventCounterCallback
 
 
-@pytest.fixture(scope='session')
+@pytest.fixture
 def train_dataset():
     return RandomClassificationDataset(size=16)
 
 
-@pytest.fixture(scope='session')
+@pytest.fixture
 def eval_dataset():
     return RandomClassificationDataset(size=16)
 
 
-@pytest.fixture(scope='session')
+@pytest.fixture
 def model():
     return SimpleModel()
 
 
-@pytest.fixture(scope='session')
-def optimizer(model):
-    return torch.optim.Adam(model.parameters())
+@pytest.fixture
+def optimizer():
+
+    def _create_optimizer(model):
+        return torch.optim.Adam(model.parameters())
+
+    return _create_optimizer
 
 
-@pytest.fixture(scope='session')
+@pytest.fixture
 def evaluator1(eval_dataset):
     return DataLoader(
         dataset=eval_dataset,
@@ -47,7 +51,7 @@ def evaluator1(eval_dataset):
     )
 
 
-@pytest.fixture(scope='session')
+@pytest.fixture
 def evaluator2(eval_dataset):
     return DataLoader(
         dataset=eval_dataset,
@@ -120,9 +124,14 @@ def test_event_calls(
     evaluator2,
     event_counter_callback,
 ):
 
+    def mock_forward(*args, **kwargs):
+        input_tensor = args[0]
+        batch_size = input_tensor.size(0)
+        return torch.zeros(batch_size, 2)
+
     with patch.object(Trainer, 'save_checkpoint', return_value=None):
         # mock forward method
-        with patch.object(model, 'forward', return_value=torch.tensor(0.0)):
+        with patch.object(model, 'forward', side_effect=mock_forward):
             # initialize the Trainer with the current parameters
             deepspeed_config = None
             if deepspeed_zero_stage:
@@ -153,7 +162,7 @@ def test_event_calls(
                 eval_subset_num_batches=1,
                 max_duration='1ep',
                 save_interval=save_interval,
-                optimizers=optimizer,
+                optimizers=optimizer(model),  # Create optimizer with the wrapped model
                 callbacks=[event_counter_callback],
                 device=device,
                 deepspeed_config=deepspeed_config,
@@ -166,7 +175,7 @@ def test_event_calls(
             state = trainer_instance.state
 
             assert state.dataloader_len is not None
-            total_steps = 1 * int(state.dataloader_len)
+            total_steps = 1 * int(state.dataloader_len)  # 1 epoch
             batch_size = state.train_dataloader.batch_size  # type: ignore
             assert batch_size is not None
             assert state.device_train_microbatch_size is not None
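The commit combines two pytest techniques: a factory-as-fixture (optimizer now returns a callable, so each test binds the optimizer to its own, possibly mocked, model) and side_effect mocking (unlike a fixed return_value, the fake forward runs on every call and can shape its output to the incoming batch). A minimal, self-contained sketch of both patterns; make_optimizer and test_mock_and_reuse are illustrative names, not code from this repository:

from unittest.mock import patch

import pytest
import torch


@pytest.fixture
def make_optimizer():
    # Factory-as-fixture: return a callable rather than an optimizer,
    # so the test decides which model's parameters to bind.
    def _make(model):
        return torch.optim.Adam(model.parameters())

    return _make


def test_mock_and_reuse(make_optimizer):
    model = torch.nn.Linear(4, 2)
    optimizer = make_optimizer(model)

    def mock_forward(*args, **kwargs):
        # side_effect is invoked on every call, so the fake output
        # can match the actual batch size of the input tensor.
        batch = args[0]
        return torch.zeros(batch.size(0), 2)

    with patch.object(model, 'forward', side_effect=mock_forward):
        out = model(torch.randn(8, 4))

    assert out.shape == (8, 2)
    assert len(optimizer.param_groups) == 1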
