fix: deprecation of tree_flatten / tree_unflatten #21

Open
wants to merge 1 commit into base: main
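For context, a minimal sketch of the migration this PR applies (illustrative pytree and variable names, not taken from the repo): newer JAX releases deprecate the top-level aliases jax.tree_flatten / jax.tree_unflatten in favor of jax.tree_util.tree_flatten / jax.tree_util.tree_unflatten, which take the same arguments and return the same values.

import jax
import jax.numpy as jnp

# Hypothetical pytree, for illustration only.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

# Before (deprecated): leaves, treedef = jax.tree_flatten(params)
leaves, treedef = jax.tree_util.tree_flatten(params)

# Before (deprecated): rebuilt = jax.tree_unflatten(treedef, leaves)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
assert jax.tree_util.tree_structure(rebuilt) == treedef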
26 changes: 13 additions & 13 deletions precondition/distributed_shampoo.py
@@ -2166,7 +2166,7 @@ def sharded_init_fn(params):
     Args:
       params: the parameters that should be updated.
     """
-    params_flat, treedef = jax.tree_flatten(params)
+    params_flat, treedef = jax.tree_util.tree_flatten(params)
     # Find max size to pad to.
     max_size = 0
     for param in params_flat:
@@ -2227,7 +2227,7 @@ def sharded_init_fn(params):
               index_start,
               sizes))

-    local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+    local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
     to_pad = -len(padded_statistics) % num_devices_for_pjit
     if max_size == 0:
       to_pad = num_devices_for_pjit
@@ -2286,9 +2286,9 @@ def sharded_init_partition_spec_fn(params, params_partition_spec,
       partition_spec_for_statistics: PartitionSpec for the statistics.
     """
     # Parallel lists of spec, and params.
-    param_pspec_flat, _ = jax.tree_flatten(
+    param_pspec_flat, _ = jax.tree_util.tree_flatten(
         params_partition_spec, is_leaf=lambda x: x is None)
-    params_flat, treedef = jax.tree_flatten(params)
+    params_flat, treedef = jax.tree_util.tree_flatten(params)
     assert param_pspec_flat
     assert params_flat
     # Step is replicated across cores.
@@ -2333,7 +2333,7 @@ def sharded_init_partition_spec_fn(params, params_partition_spec,
               index_start,
               sizes))

-    local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+    local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
     global_stats = GlobalShardedParameterStats(partition_spec_for_statistics,  # pytype: disable=wrong-arg-types  # numpy-scalars
                                                partition_spec_for_statistics,
                                                jax.sharding.PartitionSpec())
@@ -2349,7 +2349,7 @@ def sharded_init_shape_and_dtype_fn(params):
       params: A pytree with params.
     """
     # Parallel lists of spec, and params.
-    params_flat, treedef = jax.tree_flatten(params)
+    params_flat, treedef = jax.tree_util.tree_flatten(params)
     assert params_flat
     # Step is replicated across cores.
     # None means cores.
@@ -2395,7 +2395,7 @@ def sharded_init_shape_and_dtype_fn(params):
               sizes,
           ))

-    local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+    local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
     max_statistics_size = _max_statistics_size_from_params(params_flat)
     to_pad = -num_statistics % num_devices_for_pjit
     num_statistics += to_pad
@@ -2427,7 +2427,7 @@ def sharded_update_fn(grads, state, params):
     Returns:
      A tuple containing the new parameters and the new optimizer state.
    """
-    params_flat, treedef = jax.tree_flatten(params)
+    params_flat, treedef = jax.tree_util.tree_flatten(params)
     grads_flat = treedef.flatten_up_to(grads)

     global_stats = state.stats.global_stats
@@ -2450,7 +2450,7 @@ def sharded_update_fn(grads, state, params):
         new_stats_flat, params_flat)
     updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

-    updates = jax.tree_unflatten(treedef, updates_flat)
+    updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
     new_local_stats_flat = []
     for new_stat, local_stat in zip(new_stats_flat, local_stats_flat):
       new_local_stats_flat.append(
@@ -2563,7 +2563,7 @@ def _update_preconditioners():
     if generate_training_metrics:
       new_local_stats_flat = _add_metrics_into_local_stats(
           new_local_stats_flat, metrics, ~perform_step)
-    new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
+    new_local_stats = jax.tree_util.tree_unflatten(treedef, new_local_stats_flat)
     errors = metrics.inverse_pth_root_errors
     errors = errors.reshape((-1, 1, 1))
     predicate = jnp.logical_or(
@@ -3630,7 +3630,7 @@ def update_fn(grads, state, params):
     Returns:
      A tuple containing the new parameters and the new optimizer state.
    """
-    params_flat, treedef = jax.tree_flatten(params)
+    params_flat, treedef = jax.tree_util.tree_flatten(params)
     stats_flat = treedef.flatten_up_to(state.stats)
     grads_flat = treedef.flatten_up_to(grads)
     stats_grads = grads_flat
@@ -3646,8 +3646,8 @@ def update_fn(grads, state, params):
         new_stats_flat, params_flat)
     updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

-    updates = jax.tree_unflatten(treedef, updates_flat)
-    new_stats = jax.tree_unflatten(treedef, new_stats_flat)
+    updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
+    new_stats = jax.tree_util.tree_unflatten(treedef, new_stats_flat)

     new_state = ShampooState(count=state.count + 1, stats=new_stats)
     return updates, new_state