[brief] Add more tests.
[detailed]
- Expand the test coverage of the model functions to include the new
  stopping function (a usage sketch follows below).
- Add a case to ensure gradient accumulation works correctly when
  training by iteration (a standalone counting sketch follows the diff).
marovira committed May 3, 2024
1 parent ab35a02 commit 3ba57cd
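
For context, a sketch of how a model might use the new stopping hook. This is
a minimal example, not part of the commit: it assumes the usual helios alias
(import helios.model as hlm) and the hook names exercised in the diff below;
the patience counter and loss-tracking attributes are hypothetical.

import helios.model as hlm


class EarlyStopModel(hlm.Model):
    def __init__(self, patience: int = 3) -> None:
        super().__init__("early-stop-example")
        self.patience = patience           # hypothetical: stale-check limit
        self.best_loss = float("inf")      # hypothetical: best validation loss
        self.last_val_loss = float("inf")  # hypothetical: set by validation code
        self.stale_checks = 0              # validations without improvement

    def setup(self, fast_init: bool = False) -> None:
        pass  # a real model would build its networks and optimizers here

    def have_metrics_improved(self) -> bool:
        # Called by the trainer after validation; update the patience counter.
        if self.last_val_loss < self.best_loss:
            self.best_loss = self.last_val_loss
            self.stale_checks = 0
            return True
        self.stale_checks += 1
        return False

    def should_training_stop(self) -> bool:
        # The new hook: returning True asks the trainer to stop early.
        return self.stale_checks >= self.patience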
Showing 1 changed file with 43 additions and 0 deletions.
test/test_trainer.py (43 additions, 0 deletions)
@@ -57,6 +57,7 @@ def __init__(self) -> None:
"on_validation_batch_end": False,
"on_validation_end": False,
"have_metrics_improved": False,
"should_training_stop": False,
}

self.called_test_funs: dict[str, bool] = {
@@ -116,6 +117,10 @@ def have_metrics_improved(self) -> bool:
self.called_train_funs["have_metrics_improved"] = True
return True

def should_training_stop(self) -> bool:
self.called_train_funs["should_training_stop"] = True
return False

def on_testing_start(self) -> None:
self.called_test_funs["on_testing_start"] = True

@@ -150,6 +155,25 @@ def train_step(self, batch: torch.Tensor, state) -> None:
self.batches.append(as_np)


class AccumulationModel(hlm.Model):
def __init__(self, accumulation_steps: int = 1) -> None:
super().__init__("test-accumulation")
self.accumulation_steps = accumulation_steps
self.global_steps: int = 0
self.accumulated_steps: int = 0

def setup(self, fast_init: bool = False) -> None:
pass

def train_step(self, batch: torch.Tensor, state: hlt.TrainingState) -> None:
self.global_steps += 1
if self.global_steps % self.accumulation_steps == 0:
self.accumulated_steps += 1

assert state.global_iteration == self.global_steps
assert state.current_iteration == self.accumulated_steps


class TestTrainingUnit:
def test_from_str(self) -> None:
assert hlt.TrainingUnit.from_str("epoch") == hlt.TrainingUnit.EPOCH
@@ -348,3 +372,22 @@ def test_restart_iter(self, tmp_path: pathlib.Path) -> None:

def test_restart_epoch(self, tmp_path: pathlib.Path) -> None:
self.check_restart_trainer(hlt.TrainingUnit.EPOCH, tmp_path)

def check_accumulation(self, num_steps: int) -> None:
datamodule = RandomDatamodule()
model = AccumulationModel(num_steps)
trainer = hlt.Trainer(
train_unit=hlt.TrainingUnit.ITERATION,
total_steps=20,
use_cpu=True,
accumulation_steps=num_steps,
)

trainer.fit(model, datamodule)

def test_accumulation(self) -> None:
self.check_accumulation(1)
self.check_accumulation(2)
self.check_accumulation(4)
self.check_accumulation(5)
self.check_accumulation(10)

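The invariant the accumulation test asserts can be checked without the
trainer: state.global_iteration advances once per batch, while
state.current_iteration advances once per completed accumulation window. A
standalone sketch of that counting (plain Python, not part of the commit; the
counter names mirror the test's attributes):

def simulate(total_batches: int, accumulation_steps: int) -> None:
    global_iteration = 0   # advances on every batch seen
    current_iteration = 0  # advances once per accumulation window
    for _ in range(total_batches):
        global_iteration += 1
        if global_iteration % accumulation_steps == 0:
            current_iteration += 1
    assert current_iteration == total_batches // accumulation_steps


simulate(total_batches=20, accumulation_steps=4)  # 20 batches -> 5 steps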