From 3ba57cd58efe228506ec1cfe90884f5735aa4ffd Mon Sep 17 00:00:00 2001 From: "Mauricio A. Rovira Galvez" <8482308+marovira@users.noreply.github.com> Date: Fri, 3 May 2024 13:14:29 -0700 Subject: [PATCH] [brief] Add more tests. [detailed] - Expand the test coverage of the model functions to cover the new stopping function. - Add a case to ensure accumulation works correctly when training on iteration. --- test/test_trainer.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/test_trainer.py b/test/test_trainer.py index af5fc0f..8970bf3 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -57,6 +57,7 @@ def __init__(self) -> None: "on_validation_batch_end": False, "on_validation_end": False, "have_metrics_improved": False, + "should_training_stop": False, } self.called_test_funs: dict[str, bool] = { @@ -116,6 +117,10 @@ def have_metrics_improved(self) -> bool: self.called_train_funs["have_metrics_improved"] = True return True + def should_training_stop(self) -> bool: + self.called_train_funs["should_training_stop"] = True + return False + def on_testing_start(self) -> None: self.called_test_funs["on_testing_start"] = True @@ -150,6 +155,25 @@ def train_step(self, batch: torch.Tensor, state) -> None: self.batches.append(as_np) +class AccumulationModel(hlm.Model): + def __init__(self, accumulation_steps: int = 1) -> None: + super().__init__("test-accumulation") + self.accumulation_steps = accumulation_steps + self.global_steps: int = 0 + self.accumulated_steps: int = 0 + + def setup(self, fast_init: bool = False) -> None: + pass + + def train_step(self, batch: torch.Tensor, state: hlt.TrainingState) -> None: + self.global_steps += 1 + if self.global_steps % self.accumulation_steps == 0: + self.accumulated_steps += 1 + + assert state.global_iteration == self.global_steps + assert state.current_iteration == self.accumulated_steps + + class TestTrainingUnit: def test_from_str(self) -> None: assert hlt.TrainingUnit.from_str("epoch") == hlt.TrainingUnit.EPOCH @@ -348,3 +372,22 @@ def test_restart_iter(self, tmp_path: pathlib.Path) -> None: def test_restart_epoch(self, tmp_path: pathlib.Path) -> None: self.check_restart_trainer(hlt.TrainingUnit.EPOCH, tmp_path) + + def check_accumulation(self, num_steps: int) -> None: + datamodule = RandomDatamodule() + model = AccumulationModel(num_steps) + trainer = hlt.Trainer( + train_unit=hlt.TrainingUnit.ITERATION, + total_steps=20, + use_cpu=True, + accumulation_steps=num_steps, + ) + + trainer.fit(model, datamodule) + + def test_accumulation(self) -> None: + self.check_accumulation(1) + self.check_accumulation(2) + self.check_accumulation(4) + self.check_accumulation(5) + self.check_accumulation(10)