Replace jax.tree_(*) functions with jax.tree_util.tree_\1.
PiperOrigin-RevId: 615001563
jpuigcerver authored and copybara-github committed Mar 12, 2024
1 parent 5361422 commit 647a51c
Showing 19 changed files with 64 additions and 47 deletions.
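
Every change below is the same mechanical rename: a deprecated top-level alias such as jax.tree_map, jax.tree_structure, or jax.tree_flatten becomes the equivalent function in jax.tree_util, with the arguments unchanged. A minimal before/after sketch (the params pytree here is illustrative, not taken from this commit):

import jax
import jax.numpy as jnp

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}

# Before: deprecated top-level aliases.
#   doubled = jax.tree_map(lambda x: 2 * x, params)
#   treedef = jax.tree_structure(params)

# After: the jax.tree_util equivalents used throughout this commit.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
treedef = jax.tree_util.tree_structure(params)
assert treedef == jax.tree_util.tree_structure(doubled)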
12 changes: 7 additions & 5 deletions vmoe/checkpoints/base_test.py
@@ -88,13 +88,15 @@ def test_save_and_restore_multiple_checkpoints(self):
     # Restore checkpoints without specifying the tree structure.
     restored_filepath_tree_map = base.restore_multiple_checkpoints(
         {key: None for key in filepath_tree_map})
-    jax.tree_map(np.testing.assert_array_equal, restored_filepath_tree_map,
-                 filepath_tree_map)
+    jax.tree_util.tree_map(np.testing.assert_array_equal,
+                           restored_filepath_tree_map,
+                           filepath_tree_map)
     # Restore checkpoints specifying the tree structure.
     restored_filepath_tree_map = base.restore_multiple_checkpoints(
-        jax.tree_structure(filepath_tree_map).unflatten([1, 2, 3, 4]))
-    jax.tree_map(np.testing.assert_array_equal, restored_filepath_tree_map,
-                 filepath_tree_map)
+        jax.tree_util.tree_structure(filepath_tree_map).unflatten([1, 2, 3, 4]))
+    jax.tree_util.tree_map(np.testing.assert_array_equal,
+                           restored_filepath_tree_map,
+                           filepath_tree_map)
     # Check that the bfloat16 is restored properly.
     self.assertEqual(
         restored_filepath_tree_map[workdir + '/checkpoint_2'].dtype,
2 changes: 1 addition & 1 deletion vmoe/checkpoints/partitioned.py
@@ -404,7 +404,7 @@ def _replace_jax_with_numpy_in_lazy_array_chunks(
     for lst in lac.chunks.values():
       for arr, _, _ in lst:
         id_to_array[id(arr)] = arr
-  id_to_array = jax.tree_map(np.asarray, id_to_array)
+  id_to_array = jax.tree_util.tree_map(np.asarray, id_to_array)
   for lac in ckpt_shard_to_lazy_array_chunks.values():
     for i, lst in lac.chunks.items():
       lac.chunks[i] = [
7 changes: 5 additions & 2 deletions vmoe/checkpoints/serialization.py
@@ -109,7 +109,7 @@ def msgpack_serialize(pytree: PyTree, in_place: bool = False) -> bytes:
     msgpack-encoded bytes of pytree.
   """
   if not in_place:
-    pytree = jax.tree_map(lambda x: x, pytree)
+    pytree = jax.tree_util.tree_map(lambda x: x, pytree)
   pytree = _np_convert_in_place(pytree)
   pytree = _chunk_array_leaves_in_place(pytree)
   return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True)
@@ -137,7 +137,8 @@ def to_bytes(target: PyTree) -> bytes:
 _ndarray_to_bytes = flax.serialization._ndarray_to_bytes
 _ndarray_from_bytes = flax.serialization._ndarray_from_bytes
 _np_convert_in_place = flax.serialization._np_convert_in_place
-_unchunk_array_leaves_in_place = flax.serialization._unchunk_array_leaves_in_place
+_unchunk_array_leaves_in_place = (
+    flax.serialization._unchunk_array_leaves_in_place)
 _MAX_CHUNK_SIZE = flax.serialization.MAX_CHUNK_SIZE
 # pylint: enable=protected-access
 
@@ -146,6 +147,7 @@ def to_bytes(target: PyTree) -> bytes:
 
 
 class _MsgpackExtType(enum.IntEnum):
+  # pylint: disable=invalid-name
   ndarray = 1
   native_complex = 2
   npscalar = 3
@@ -154,6 +156,7 @@ class _MsgpackExtType(enum.IntEnum):
   slice_nd = 6
   slice_nd_array = 7
   index_info = 8
+  # pylint: enable=invalid-name
 
 
 def _shaped_array_to_bytes(x: core.ShapedArray) -> bytes:
4 changes: 2 additions & 2 deletions vmoe/data/pjit_utils.py
@@ -107,7 +107,7 @@ def _to_global(x):
 
   def enqueue(n):
     for data in itertools.islice(iterator, n):
-      queue.append(jax.tree_map(_to_global, data))
+      queue.append(jax.tree_util.tree_map(_to_global, data))
 
   enqueue(size)
   while queue:
@@ -117,7 +117,7 @@ def enqueue(n):
   # If size is None, 0 or negative, simply create jax.Arrays without
   # prefetching.
   for data in iterator:
-    yield jax.tree_map(_to_global, data)
+    yield jax.tree_util.tree_map(_to_global, data)
 
 
 def put_to_devices(host_array: np.ndarray,
2 changes: 1 addition & 1 deletion vmoe/evaluate/evaluator.py
@@ -226,7 +226,7 @@ def evaluate_dataset(
   for batch in dataset:
     eval_state = eval_step_pjit(eval_state, params, batch['image'],
                                 batch['labels'], batch[VALID_KEY])
-  return jax.tree_map(lambda x: x.block_until_ready(), eval_state)
+  return jax.tree_util.tree_map(lambda x: x.block_until_ready(), eval_state)
 
 
 def evaluate_step(
7 changes: 4 additions & 3 deletions vmoe/initialization/initialization.py
@@ -233,10 +233,11 @@ def initialize_from_vmoe(
   # each array read from the checkpoint.
   index = vmoe_checkpoint.restore_checkpoint(prefix + '.index')
   version = index.get('version', vmoe_checkpoint.Version.UNKNOWN)
-  shapes = jax.tree_map(lambda x: x.global_shape, index['index'])
+  shapes = jax.tree_util.tree_map(lambda x: x.global_shape, index['index'])
   if version == vmoe_checkpoint.Version.UNKNOWN:
-    if (jax.tree_structure(vmoe_serialization.to_state_dict(target)) !=
-        jax.tree_structure(shapes)):
+    target_state_dict = vmoe_serialization.to_state_dict(target)
+    if (jax.tree_util.tree_structure(target_state_dict) !=
+        jax.tree_util.tree_structure(shapes)):
       raise ValueError(
           'Initialization from V-MoE checkpoints created before 2022/06/22 '
           'is only possible when the structure of the checkpoint and target '
4 changes: 2 additions & 2 deletions vmoe/moe.py
@@ -428,7 +428,7 @@ def wrapper(expert_fn: Callable[..., Any]):
 
   def transformed(scopes, dispatcher, *inputs):
     # Prepare inputs to be processed by each expert.
-    inputs = jax.tree_map(dispatcher.dispatch, inputs)
+    inputs = jax.tree_util.tree_map(dispatcher.dispatch, inputs)
     # Wrap the target with vmap, to pass different parameters and inputs to
     # each expert.
     outputs = flax.core.lift.vmap(
@@ -440,7 +440,7 @@ def transformed(scopes, dispatcher, *inputs):
     # Combine outputs.
     if has_aux:
       outputs, aux = outputs
-    outputs = jax.tree_map(dispatcher.combine, outputs)
+    outputs = jax.tree_util.tree_map(dispatcher.combine, outputs)
     return (outputs, aux) if has_aux else outputs
 
   return transformed
4 changes: 2 additions & 2 deletions vmoe/moe_test.py
@@ -491,7 +491,7 @@ def init(rng):
     # All parameters are partitioned across the first axis of the TPU mesh.
     # Thus, each group of devices in one "row" will share the same values for
     # all parameters.
-    param_partition_spec = jax.tree_map(
+    param_partition_spec = jax.tree_util.tree_map(
         lambda _: jax.sharding.PartitionSpec(('expert',)),
         jax.eval_shape(init, jax.random.PRNGKey(0)))
     init_pjit = pjit.pjit(
@@ -586,7 +586,7 @@ def __call__(self, x, w):
     data_axis_resources = PartitionSpec(('X', 'Y'))
     variables_axis_resources = flax.core.freeze({
         'params':
-            jax.tree_map(
+            jax.tree_util.tree_map(
                 lambda x: nn.partitioning.logical_to_mesh_axes(x, axis_rules),
                 nn.partitioning.get_axis_names(variables_shape['params_axes'])),
     })
9 changes: 6 additions & 3 deletions vmoe/nn/external_test.py
@@ -57,16 +57,19 @@ def compute_grad(params):
 
       variables = model.init(jax.random.PRNGKey(0), x)
       params = variables['params']
+      grads = None
      for _ in range(2):
        grads = compute_grad(params)
-        params = jax.tree_map(lambda p, g: p - 0.01 * g, params, grads)
+        params = jax.tree_util.tree_map(
+            lambda p, g: p - 0.01 * g, params, grads)
 
       return grads
 
     x = jax.random.normal(jax.random.PRNGKey(0), (8, 64, 64, 3))
     grads = fn(x)
-    grads_norm = jax.tree_map(lambda x: jnp.linalg.norm(x.flatten()), grads)
-    zeros = jax.tree_map(jnp.zeros_like, grads_norm)
+    grads_norm = jax.tree_util.tree_map(
+        lambda x: jnp.linalg.norm(x.flatten()), grads)
+    zeros = jax.tree_util.tree_map(jnp.zeros_like, grads_norm)
     print(grads_norm)
     chex.assert_trees_all_equal_comparator(lambda x, y: x > y,
                                            '{} is not greater than {}'.format,
3 changes: 2 additions & 1 deletion vmoe/nn/vit_moe_ensemble_test.py
@@ -43,7 +43,8 @@ def init(rngs, x):
 
     rngs = dict(params=jax.random.PRNGKey(0), gating=jax.random.PRNGKey(1))
     x = jax.random.normal(jax.random.PRNGKey(0), (16, 4, 4, 3))
-    shapes = jax.tree_map(lambda x: x.shape, jax.eval_shape(init, rngs, x))
+    shapes = jax.tree_util.tree_map(
+        lambda x: x.shape, jax.eval_shape(init, rngs, x))
     shapes = flax.core.unfreeze(shapes)
     expected_shapes = {
         'params': {
6 changes: 3 additions & 3 deletions vmoe/nn/vit_moe_test.py
@@ -153,7 +153,7 @@ def init():
       x = jax.random.normal(jax.random.PRNGKey(0), (16, 4, 4, 3))
       return model.init(rngs, x)
 
-    shapes = jax.tree_map(
+    shapes = jax.tree_util.tree_map(
         lambda x: x.shape, flax.core.unfreeze(jax.eval_shape(init))
     )
     expected_shapes = {
@@ -237,8 +237,8 @@ def test_classifier(self, classifier, seq_length, params_subset):
     model = vit_moe.VisionTransformerMoe(**config)
     rngs = dict(params=jax.random.PRNGKey(0), gating=jax.random.PRNGKey(1))
     x = jax.ShapeDtypeStruct((16, 4, 4, 3), jax.numpy.float32)
-    shapes = jax.tree_map(lambda x: x.shape,
-                          jax.eval_shape(model.init, rngs, x))
+    shapes = jax.tree_util.tree_map(lambda x: x.shape,
+                                    jax.eval_shape(model.init, rngs, x))
     shapes = flax.core.unfreeze(shapes)
     self.assertDictEqual(shapes['params']['Encoder']['posembed_input'],
                          {'pos_embedding': (1, seq_length, 8)})
12 changes: 8 additions & 4 deletions vmoe/projects/adversarial_attacks/attacks.py
@@ -120,7 +120,8 @@ def mask(a):
   num_changes = jnp.sum(y_m != y_0, dtype=jnp.int32)
   num_correct = jnp.stack([c_0.sum(), c_m.sum()], axis=0)
   sum_loss = jnp.stack([l_0.sum(), l_m.sum()], axis=0)
-  sum_iou_experts = jax.tree_map(sum_intersection_over_union, cw_0, cw_m)
+  sum_iou_experts = jax.tree_util.tree_map(
+      sum_intersection_over_union, cw_0, cw_m)
   new_attack_state = attack_state.update(
       rngs=new_rngs, num_images=num_images, num_changes=num_changes,
       num_correct=num_correct, sum_loss=sum_loss,
@@ -151,7 +152,8 @@ def stateless_attack_pgd(
   # If rngs are given, split them in num_updates + 1. The last one will be
   # returned at the end of this function, the others are used in each update.
   if rngs is not None:
-    rngs = jax.tree_map(lambda rng: jax.random.split(rng, num_updates +1), rngs)
+    rngs = jax.tree_util.tree_map(
+        lambda rng: jax.random.split(rng, num_updates +1), rngs)
   # This computes gradients of loss_fn w.r.t. the images.
   @jax.grad
   def compute_loss_grads(x, rngs):
@@ -160,13 +162,15 @@ def compute_loss_grads(x, rngs):
     return loss
   # Performs an adversarial update on the given images.
   def update(i, x_c):
-    rngs_c = jax.tree_map(lambda x: x[i], rngs) if rngs is not None else None
+    rngs_c = jax.tree_util.tree_map(
+        lambda x: x[i], rngs) if rngs is not None else None
     dx = compute_loss_grads(x_c, rngs_c)
     x_n = x_c + delta * jnp.sign(dx)
     return x_n
   # Performs num_updates on the original images.
   images = jax.lax.fori_loop(0, num_updates, update, images)
-  rngs = jax.tree_map(lambda x: x[-1], rngs) if rngs is not None else None
+  rngs = jax.tree_util.tree_map(
+      lambda x: x[-1], rngs) if rngs is not None else None
   return images, rngs
 
 
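The rng handling in stateless_attack_pgd above follows a common JAX pattern: split every key into num_updates + 1 slices up front, index one slice per fori_loop step, and return the last slice so the caller can keep drawing fresh randomness. A self-contained sketch of that pattern (the sign-of-noise update is illustrative, not the repo's loss gradient):

import jax
import jax.numpy as jnp

def run_with_per_step_keys(key, num_updates, x, delta=0.01):
  # Split once: row i seeds update i; the last row is handed back.
  keys = jax.random.split(key, num_updates + 1)
  def update(i, x_c):
    noise = jax.random.normal(keys[i], x_c.shape)
    return x_c + delta * jnp.sign(noise)
  x = jax.lax.fori_loop(0, num_updates, update, x)
  return x, keys[-1]
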
9 changes: 5 additions & 4 deletions vmoe/projects/adversarial_attacks/lib.py
@@ -53,7 +53,7 @@ def __init__(self, filepath: str, transfer_every_steps: int = 1):
 
   def _transfer(self):
     for key, values in self._temp.items():
-      self._data[key].extend(jax.tree_map(np.asarray, values))
+      self._data[key].extend(jax.tree_util.tree_map(np.asarray, values))
       self._temp[key] = []
 
   @property
@@ -99,6 +99,7 @@ def run_pgd_attack(
     config: ml_collections.ConfigDict, workdir: str, mesh: Mesh,
     writer: metric_writers.MetricWriter):
   """Run PGD attack on an entire dataset, using a model from an XM experiment."""
+  del mesh
   # Setup dataset and get the global shape of the image array.
   dataset = get_dataset(config.dataset)
   element_spec: ArraySpecDict = dataset.element_spec  # pytype: disable=annotation-type-mismatch
@@ -141,8 +142,8 @@ def init_state():
     rngs = utils.make_rngs(rng_keys, config.get('seed', 0))
     return attacks.AttackState.create(
         max_updates=config.num_updates, router_keys=router_keys, rngs=rngs)
-  state_axis_resources = jax.tree_map(lambda _: PartitionSpec(),
-                                      jax.eval_shape(init_state))
+  state_axis_resources = jax.tree_util.tree_map(lambda _: PartitionSpec(),
+                                                jax.eval_shape(init_state))
   init_state_pjit = pjit.pjit(
       init_state, in_shardings=(), out_shardings=state_axis_resources
   )
@@ -203,7 +204,7 @@ def compute_loss_predict_correct_cw_fn(images, labels, rngs):
         **{f'cw_0/{k}': v for k, v in cw_0.items()},
         **{f'cw_m/{k}': v for k, v in cw_m.items()})
   # Copy state from device to CPU and convert to numpy arrays.
-  state = jax.tree_map(np.asarray, state)
+  state = jax.tree_util.tree_map(np.asarray, state)
   # Process with index=0 saves the PGD state (it's the same for all processes).
   if jax.process_index() == 0:
     state_filepath = os.path.join(workdir, 'pgd_state.npz')
8 changes: 4 additions & 4 deletions vmoe/projects/adversarial_attacks/lib_test.py
@@ -61,8 +61,8 @@ def __call__(self, x):
     flax_module = ModelWithRouting()
     image_size = (batch_size, image_size, image_size, 3)
     variables = flax_module.init(jax.random.PRNGKey(0), np.zeros(image_size))
-    variables_axis_resources = jax.tree_map(lambda _: lib.PartitionSpec(),
-                                            variables)
+    variables_axis_resources = jax.tree_util.tree_map(
+        lambda _: lib.PartitionSpec(), variables)
     router_keys = {'router/__call__'}
     loss_fn = lambda a, b, _: optax.softmax_cross_entropy(a, b)
     return (flax_module, variables, variables_axis_resources, loss_fn,
@@ -82,8 +82,8 @@ def __call__(self, x):
     flax_module = ModelWithoutRouting()
     image_size = (batch_size, image_size, image_size, 3)
     variables = flax_module.init(jax.random.PRNGKey(0), np.zeros(image_size))
-    variables_axis_resources = jax.tree_map(lambda _: lib.PartitionSpec(),
-                                            variables)
+    variables_axis_resources = jax.tree_util.tree_map(
+        lambda _: lib.PartitionSpec(), variables)
     loss_fn = lambda a, b, _: optax.softmax_cross_entropy(a, b)
     return flax_module, variables, variables_axis_resources, loss_fn, {}, {}
 
2 changes: 1 addition & 1 deletion vmoe/projects/adversarial_attacks/restore.py
@@ -47,7 +47,7 @@ def compute_loss_predict_cw_fn(x, y, rngs, *, apply_fn, loss_fn):
   # This is a dict mapping from each MoE layer to a binary array of shape
   # (batch_size, num_tokens, num_experts).
   combine_weights = get_combine_weights(intermediates)
-  combine_weights = jax.tree_map(
+  combine_weights = jax.tree_util.tree_map(
       lambda m: m.reshape(batch_size, -1, m.shape[1]), combine_weights)
   return loss, pred, correct, combine_weights
 
2 changes: 1 addition & 1 deletion vmoe/train/optimizer_test.py
@@ -113,7 +113,7 @@ def step(_, state):
     new_params = optimizer.optax.apply_updates(params, updates)
     return new_params, new_tx_state
 
-  params = jax.tree_map(jnp.asarray, {'x': 0., 'y': 0.})
+  params = jax.tree_util.tree_map(jnp.asarray, {'x': 0., 'y': 0.})
   state = init_fn(params)
   return jax.lax.fori_loop(0, 200, step, (params, state))[0]
 
7 changes: 4 additions & 3 deletions vmoe/train/trainer.py
@@ -415,7 +415,8 @@ def _array_restore_args_fn(x: jax.ShapeDtypeStruct):
         dtype=x.dtype, sharding=x.sharding, global_shape=x.shape)
   restore_kwargs = {
       'state': {
-          'restore_args': jax.tree_map(_array_restore_args_fn, train_state),
+          'restore_args': jax.tree_util.tree_map(
+              _array_restore_args_fn, train_state),
       },
   }
   items = ckpt_manager.restore(
@@ -572,7 +573,7 @@ def mixup(
   Returns:
     A tree with the mixed arrays.
   """
-  arrays, treedef = jax.tree_flatten(tree)
+  arrays, treedef = jax.tree_util.tree_flatten(tree)
   if len(shape) < 2:
     raise ValueError(f"Mixup 'shape' has length {len(shape)}, but it must have "
                      'length >= 2.')
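
The tree_flatten call that mixup opens with is the standard flatten/operate/unflatten idiom: tree_flatten yields the leaves plus a treedef that can rebuild the original structure afterwards. A minimal sketch of the idiom (a two-way mix with a rolled batch; the tree is illustrative, not the repo's mixup weighting):

import jax
import jax.numpy as jnp

tree = {'images': jnp.ones((4, 8)), 'labels': jnp.zeros((4, 10))}
arrays, treedef = jax.tree_util.tree_flatten(tree)
# Mix every leaf with a copy of itself rolled along the batch axis.
mixed = [0.5 * a + 0.5 * jnp.roll(a, 1, axis=0) for a in arrays]
tree = jax.tree_util.tree_unflatten(treedef, mixed)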
@@ -648,7 +649,7 @@ def compute_grads_and_metrics(params, images, labels, rngs):
     logits, metrics = state.apply_fn({'params': params}, images, rngs=rngs)
     metrics = dict(**metrics)
     metrics['main_loss'] = jnp.mean(loss_fn(logits, labels))
-    metrics = jax.tree_map(jnp.mean, metrics)
+    metrics = jax.tree_util.tree_map(jnp.mean, metrics)
     total_loss = metrics['main_loss'] + metrics.get('auxiliary_loss', 0.0)
     metrics['total_loss'] = total_loss
     return total_loss, (next_rngs, metrics)
6 changes: 3 additions & 3 deletions vmoe/train/trainer_test.py
@@ -732,9 +732,9 @@ def test_granularity(self, granularity):
     np.testing.assert_array_less(np.zeros_like(x),
                                  np.abs(metrics['images'] - x))
     # Convert PRNGKeys to int64 to subtract them.
-    rngs = jax.tree_map(lambda x: np.asarray(x, dtype=np.int64), rngs)
-    new_rngs = jax.tree_map(lambda x: np.asarray(x, dtype=np.int64),
-                            train_state.rngs)
+    rngs = jax.tree_util.tree_map(lambda x: np.asarray(x, dtype=np.int64), rngs)
+    new_rngs = jax.tree_util.tree_map(lambda x: np.asarray(x, dtype=np.int64),
+                                      train_state.rngs)
     # Check that both PRNGKeys have been updated.
     np.testing.assert_array_less(
         np.zeros_like(rngs['mixup']), np.abs(rngs['mixup'] - new_rngs['mixup']))
5 changes: 3 additions & 2 deletions vmoe/utils.py
@@ -151,8 +151,9 @@ def safe_zip(*iterables) -> Iterator[Tuple[Any, ...]]:
 
 def tree_rngs_split(rngs, num_splits=2):
   """Splits a PyTree of PRNGKeys into num_splits PyTrees."""
-  rngs = jax.tree_map(lambda rng: jax.random.split(rng, num_splits), rngs)
-  slice_rngs = lambda rngs, i: jax.tree_map(lambda rng: rng[i], rngs)
+  rngs = jax.tree_util.tree_map(
+      lambda rng: jax.random.split(rng, num_splits), rngs)
+  slice_rngs = lambda rngs, i: jax.tree_util.tree_map(lambda rng: rng[i], rngs)
   return tuple(slice_rngs(rngs, i) for i in range(num_splits))
 
 

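For reference, a short usage sketch of tree_rngs_split as defined above (the rngs dict is illustrative):

import jax
from vmoe.utils import tree_rngs_split

rngs = {'params': jax.random.PRNGKey(0), 'gating': jax.random.PRNGKey(1)}
rngs_a, rngs_b = tree_rngs_split(rngs, num_splits=2)
# Each result mirrors the input structure, with an independent key per leaf.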