diff --git a/precondition/distributed_shampoo.py b/precondition/distributed_shampoo.py
index 258810c..45f14a4 100644
--- a/precondition/distributed_shampoo.py
+++ b/precondition/distributed_shampoo.py
@@ -2166,7 +2166,7 @@ def sharded_init_fn(params):
     Args:
       params: the parameters that should be updated.
     """
-    params_flat, treedef = jax.tree_flatten(params)
+    params_flat, treedef = jax.tree_util.tree_flatten(params)
     # Find max size to pad to.
     max_size = 0
     for param in params_flat:
@@ -2227,7 +2227,7 @@ def sharded_init_fn(params):
             index_start,
             sizes))

-    local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+    local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
     to_pad = -len(padded_statistics) % num_devices_for_pjit
     if max_size == 0:
       to_pad = num_devices_for_pjit
@@ -2286,9 +2286,9 @@ def sharded_init_partition_spec_fn(params, params_partition_spec,
       partition_spec_for_statistics: PartitionSpec for the statistics.
     """
     # Parallel lists of spec, and params.
-    param_pspec_flat, _ = jax.tree_flatten(
+    param_pspec_flat, _ = jax.tree_util.tree_flatten(
         params_partition_spec, is_leaf=lambda x: x is None)
-    params_flat, treedef = jax.tree_flatten(params)
+    params_flat, treedef = jax.tree_util.tree_flatten(params)
     assert param_pspec_flat
     assert params_flat
     # Step is replicated across cores.
@@ -2333,7 +2333,7 @@ def sharded_init_partition_spec_fn(params, params_partition_spec,
             index_start,
             sizes))

-    local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+    local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
     global_stats = GlobalShardedParameterStats(partition_spec_for_statistics,  # pytype: disable=wrong-arg-types  # numpy-scalars
                                                partition_spec_for_statistics,
                                                jax.sharding.PartitionSpec())
@@ -2349,7 +2349,7 @@ def sharded_init_shape_and_dtype_fn(params):
       params: A pytree with params.
     """
     # Parallel lists of spec, and params.
-    params_flat, treedef = jax.tree_flatten(params)
+    params_flat, treedef = jax.tree_util.tree_flatten(params)
     assert params_flat
     # Step is replicated across cores.
     # None means cores.
@@ -2395,7 +2395,7 @@ def sharded_init_shape_and_dtype_fn(params):
             sizes,
         ))

-    local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+    local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
     max_statistics_size = _max_statistics_size_from_params(params_flat)
     to_pad = -num_statistics % num_devices_for_pjit
     num_statistics += to_pad
@@ -2427,7 +2427,7 @@ def sharded_update_fn(grads, state, params):
     Returns:
      A tuple containing the new parameters and the new optimizer state.
    """
-    params_flat, treedef = jax.tree_flatten(params)
+    params_flat, treedef = jax.tree_util.tree_flatten(params)
     grads_flat = treedef.flatten_up_to(grads)

     global_stats = state.stats.global_stats
@@ -2450,7 +2450,7 @@ def sharded_update_fn(grads, state, params):
                                       new_stats_flat, params_flat)

     updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
-    updates = jax.tree_unflatten(treedef, updates_flat)
+    updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
     new_local_stats_flat = []
     for new_stat, local_stat in zip(new_stats_flat, local_stats_flat):
       new_local_stats_flat.append(
@@ -2563,7 +2563,7 @@ def _update_preconditioners():
       if generate_training_metrics:
         new_local_stats_flat = _add_metrics_into_local_stats(
             new_local_stats_flat, metrics, ~perform_step)
-      new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
+      new_local_stats = jax.tree_util.tree_unflatten(treedef, new_local_stats_flat)
       errors = metrics.inverse_pth_root_errors
       errors = errors.reshape((-1, 1, 1))
       predicate = jnp.logical_or(
@@ -3630,7 +3630,7 @@ def update_fn(grads, state, params):
     Returns:
       A tuple containing the new parameters and the new optimizer state.
     """
-    params_flat, treedef = jax.tree_flatten(params)
+    params_flat, treedef = jax.tree_util.tree_flatten(params)
     stats_flat = treedef.flatten_up_to(state.stats)
     grads_flat = treedef.flatten_up_to(grads)
     stats_grads = grads_flat
@@ -3646,8 +3646,8 @@ def update_fn(grads, state, params):
                                       new_stats_flat, params_flat)

     updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
-    updates = jax.tree_unflatten(treedef, updates_flat)
-    new_stats = jax.tree_unflatten(treedef, new_stats_flat)
+    updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
+    new_stats = jax.tree_util.tree_unflatten(treedef, new_stats_flat)
     new_state = ShampooState(count=state.count + 1, stats=new_stats)
     return updates, new_state
