diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py index 9402d7af7..f8a88b1c9 100644 --- a/tests/test_schedulers.py +++ b/tests/test_schedulers.py @@ -37,11 +37,13 @@ def test_schedulers(self): constant_step = int(self.train_steps * self.constant_lr_ratio) remaining_step = self.train_steps - constant_step for _ in range(constant_step): + self.optimizer.step() self.lr_scheduler.step() self.assertEqual( self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio ) for _ in range(remaining_step): + self.optimizer.step() self.lr_scheduler.step() self.assertEqual( self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio