From f5d84d2001952bb8e14a58e371df0e126e00b64f Mon Sep 17 00:00:00 2001 From: Matteo Hessel Date: Tue, 3 Dec 2024 02:49:12 -0800 Subject: [PATCH] Remove longtime deprecated functions. PiperOrigin-RevId: 702264822 --- t5x/optimizers.py | 41 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/t5x/optimizers.py b/t5x/optimizers.py index 2d663c8db..8b1283adf 100644 --- a/t5x/optimizers.py +++ b/t5x/optimizers.py @@ -174,6 +174,13 @@ def restore_state(self, state): # Optax Elementwise Wrapper +def _scale_by_schedule_ctor(state, params_axes): + del state, params_axes + return optax.ScaleByScheduleState( # pytype: disable=wrong-arg-types # numpy-scalars + count=None + ) + + class OptaxStatePartitionRules: """Collection of rules to partition optax states. @@ -218,16 +225,10 @@ class OptaxStatePartitionRules: mu=OptaxStatePartitionRules.derive_params_axes(state.mu, params_axes), nu=OptaxStatePartitionRules.derive_params_axes(state.nu, params_axes), ), - optax.ScaleByBeliefState: ( - lambda state, params_axes: optax.ScaleByBeliefState( # pytype: disable=wrong-arg-types # numpy-scalars - count=None, - mu=OptaxStatePartitionRules.derive_params_axes( - state.mu, params_axes - ), - nu=OptaxStatePartitionRules.derive_params_axes( - state.nu, params_axes - ), - ) + optax.ScaleByBeliefState: lambda state, params_axes: optax.ScaleByBeliefState( # pytype: disable=wrong-arg-types # numpy-scalars + count=None, + mu=OptaxStatePartitionRules.derive_params_axes(state.mu, params_axes), + nu=OptaxStatePartitionRules.derive_params_axes(state.nu, params_axes), ), optax.ScaleByLionState: lambda state, params_axes: optax.ScaleByLionState( # pytype: disable=wrong-arg-types # numpy-scalars count=None, @@ -258,9 +259,7 @@ class OptaxStatePartitionRules: optax.ScaleByTrustRatioState: ( lambda state, params_axes: optax.ScaleByTrustRatioState() ), - optax.ScaleByScheduleState: ( - lambda state, params_axes: optax.ScaleByScheduleState(count=None) # pytype: disable=wrong-arg-types # numpy-scalars - ), + optax.ScaleByScheduleState: _scale_by_schedule_ctor, optax.ZeroNansState: lambda state, params_axes: optax.ZeroNansState( found_nan=None ), @@ -272,14 +271,12 @@ class OptaxStatePartitionRules: state.inner_state, params_axes ) ), - optax.InjectHyperparamsState: ( - lambda state, params_axes: optax.InjectHyperparamsState( # pytype: disable=wrong-arg-types # jax-ndarray - count=None, - hyperparams=jax.tree.map(lambda x: None, state.hyperparams), - inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( - state.inner_state, params_axes - ), - ) + optax.InjectHyperparamsState: lambda state, params_axes: optax.InjectHyperparamsState( # pytype: disable=wrong-arg-types # jax-ndarray + count=None, + hyperparams=jax.tree.map(lambda x: None, state.hyperparams), + inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( + state.inner_state, params_axes + ), ), optax.MultiStepsState: lambda state, params_axes: optax.MultiStepsState( # pytype: disable=wrong-arg-types # jax-ndarray mini_step=None, @@ -299,7 +296,7 @@ class OptaxStatePartitionRules: ), ) ), - optax.MaybeUpdateState: lambda state, params_axes: optax.MaybeUpdateState( # pytype: disable=wrong-arg-types # jax-ndarray + optax.ConditionallyTransformState: lambda state, params_axes: optax.ConditionallyTransformState( # pytype: disable=wrong-arg-types # jax-ndarray inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( state.inner_state, params_axes ),