Skip to content

Commit 5f10450

Browse files
committed
Some more kron work. Figured out why some tests fail, implemented a deterministic rng state load but too slow so skipping some tests for now.
1 parent cd21e80 commit 5f10450

File tree

3 files changed

+186
-81
lines changed

3 files changed

+186
-81
lines changed

tests/test_optim.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def _build_params_dict_single(weight, bias, **kwargs):
290290
return [dict(params=bias, **kwargs)]
291291

292292

293-
@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
293+
@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*', 'kron*')))
294294
def test_optim_factory(optimizer):
295295
assert issubclass(get_optimizer_class(optimizer, bind_defaults=False), torch.optim.Optimizer)
296296

@@ -386,6 +386,14 @@ def test_adam(optimizer):
386386
_test_model(optimizer, dict(lr=5e-2))
387387

388388

389+
@pytest.mark.parametrize('optimizer', ['kron'])
390+
def test_kron(optimizer):
391+
_test_rosenbrock(
392+
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
393+
)
394+
_test_model(optimizer, dict(lr=1e-3))
395+
396+
389397
@pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
390398
def test_adopt(optimizer):
391399
_test_rosenbrock(

timm/optim/_optim_factory.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -697,9 +697,16 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
697697
OptimInfo(
698698
name='kron',
699699
opt_class=Kron,
700-
description='',
700+
description='PSGD optimizer with Kronecker-factored preconditioner',
701701
has_momentum=True,
702702
),
703+
OptimInfo(
704+
name='kronw',
705+
opt_class=Kron,
706+
description='PSGD optimizer with Kronecker-factored preconditioner and decoupled weight decay',
707+
has_momentum=True,
708+
defaults={'decoupled_decay': True}
709+
),
703710
OptimInfo(
704711
name='laprop',
705712
opt_class=LaProp,

0 commit comments

Comments
 (0)