Skip to content

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
v-chen_data committed Dec 1, 2024
1 parent e595c52 commit b72ad75
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions tests/test_events.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright ...
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import math
Expand Down Expand Up @@ -33,7 +33,7 @@ def get_trainer(self, precision='fp32', max_duration='1ep', save_interval='1ep',
# Minimal dataset size to reduce batches
train_dataset = RandomClassificationDataset(size=4)
eval_dataset = RandomClassificationDataset(size=4)
train_batch_size = 4
train_batch_size = 4

evaluator1 = DataLoader(
dataset=eval_dataset,
Expand Down Expand Up @@ -112,11 +112,36 @@ def test_event_calls(self, world_size, device, deepspeed_zero_stage, use_fsdp, p
# mock the save_checkpoint method to speed up batch saves
with patch('composer.trainer.trainer.Trainer.save_checkpoint') as mock_save:
mock_save.return_value = None
self._run_event_calls_test(world_size, device, deepspeed_zero_stage, use_fsdp, precision, save_interval, num_epochs=1)
self._run_event_calls_test(
world_size,
device,
deepspeed_zero_stage,
use_fsdp,
precision,
save_interval,
num_epochs=1,
)
else:
self._run_event_calls_test(world_size, device, deepspeed_zero_stage, use_fsdp, precision, save_interval, num_epochs=1)

def _run_event_calls_test(self, world_size, device, deepspeed_zero_stage, use_fsdp, precision, save_interval, num_epochs):
self._run_event_calls_test(
world_size,
device,
deepspeed_zero_stage,
use_fsdp,
precision,
save_interval,
num_epochs=1,
)

def _run_event_calls_test(
self,
world_size,
device,
deepspeed_zero_stage,
use_fsdp,
precision,
save_interval,
num_epochs,
):
save_interval = Time.from_timestring(save_interval)

deepspeed_config = None
Expand Down

0 comments on commit b72ad75

Please sign in to comment.