Skip to content

Commit

Permalink
revert
Browse files Browse the repository at this point in the history
  • Loading branch information
v-chen_data committed Dec 1, 2024
1 parent a255e90 commit 91bbd25
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
from tests.common.events import EventCounterCallback


@pytest.mark.parametrize('event', list(Event))
def test_event_values(event: Event):
assert event.name.lower() == event.value


class TestEventCalls:

eval_subset_num_batches = 1
Expand Down Expand Up @@ -103,9 +108,9 @@ def get_trainer(self, precision='fp32', max_duration='1ep', save_interval='1ep',
)
@pytest.mark.parametrize('save_interval', ['1ep', '1ba'])
def test_event_calls(self, world_size, device, deepspeed_zero_stage, use_fsdp, precision, save_interval):
# Handle '1ba' save interval separately to optimize speed
# handle 1ba save interval separately to optimize speed
if save_interval == '1ba':
# Mock the save_checkpoint method to speed up batch saves
# mock the save_checkpoint method to speed up batch saves
with patch('composer.trainer.trainer.Trainer.save_checkpoint') as mock_save:
mock_save.return_value = None
self._run_event_calls_test(
Expand Down Expand Up @@ -160,8 +165,8 @@ def _run_event_calls_test(
save_interval=save_interval,
eval_interval=Time.from_timestring(save_interval),
)

trainer.fit()

self._assert_expected_event_calls(trainer, Time.from_timestring(save_interval), num_epochs=num_epochs)

def _assert_expected_event_calls(self, trainer: Trainer, eval_interval: Time, num_epochs: int):
Expand Down

0 comments on commit 91bbd25

Please sign in to comment.