diff --git a/vmoe/checkpoints/base_test.py b/vmoe/checkpoints/base_test.py
index e88e72d..8e6bf40 100644
--- a/vmoe/checkpoints/base_test.py
+++ b/vmoe/checkpoints/base_test.py
@@ -88,13 +88,15 @@ def test_save_and_restore_multiple_checkpoints(self):
     # Restore checkpoints without specifying the tree structure.
     restored_filepath_tree_map = base.restore_multiple_checkpoints(
         {key: None for key in filepath_tree_map})
-    jax.tree_map(np.testing.assert_array_equal, restored_filepath_tree_map,
-                 filepath_tree_map)
+    jax.tree_util.tree_map(np.testing.assert_array_equal,
+                           restored_filepath_tree_map,
+                           filepath_tree_map)
     # Restore checkpoints specifying the tree structure.
     restored_filepath_tree_map = base.restore_multiple_checkpoints(
-        jax.tree_structure(filepath_tree_map).unflatten([1, 2, 3, 4]))
-    jax.tree_map(np.testing.assert_array_equal, restored_filepath_tree_map,
-                 filepath_tree_map)
+        jax.tree_util.tree_structure(filepath_tree_map).unflatten([1, 2, 3, 4]))
+    jax.tree_util.tree_map(np.testing.assert_array_equal,
+                           restored_filepath_tree_map,
+                           filepath_tree_map)
     # Check that the bfloat16 is restored properly.
     self.assertEqual(
         restored_filepath_tree_map[workdir + '/checkpoint_2'].dtype,
diff --git a/vmoe/checkpoints/partitioned.py b/vmoe/checkpoints/partitioned.py
index b873a2b..99cf06d 100644
--- a/vmoe/checkpoints/partitioned.py
+++ b/vmoe/checkpoints/partitioned.py
@@ -404,7 +404,7 @@ def _replace_jax_with_numpy_in_lazy_array_chunks(
     for lst in lac.chunks.values():
       for arr, _, _ in lst:
         id_to_array[id(arr)] = arr
-  id_to_array = jax.tree_map(np.asarray, id_to_array)
+  id_to_array = jax.tree_util.tree_map(np.asarray, id_to_array)
   for lac in ckpt_shard_to_lazy_array_chunks.values():
     for i, lst in lac.chunks.items():
       lac.chunks[i] = [
diff --git a/vmoe/checkpoints/serialization.py b/vmoe/checkpoints/serialization.py
index 19bbc37..f5b6723 100644
--- a/vmoe/checkpoints/serialization.py
+++ b/vmoe/checkpoints/serialization.py
@@ -109,7 +109,7 @@ def msgpack_serialize(pytree: PyTree, in_place: bool = False) -> bytes:
     msgpack-encoded bytes of pytree.
   """
   if not in_place:
-    pytree = jax.tree_map(lambda x: x, pytree)
+    pytree = jax.tree_util.tree_map(lambda x: x, pytree)
   pytree = _np_convert_in_place(pytree)
   pytree = _chunk_array_leaves_in_place(pytree)
   return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True)
@@ -137,7 +137,8 @@ def to_bytes(target: PyTree) -> bytes:
 _ndarray_to_bytes = flax.serialization._ndarray_to_bytes
 _ndarray_from_bytes = flax.serialization._ndarray_from_bytes
 _np_convert_in_place = flax.serialization._np_convert_in_place
-_unchunk_array_leaves_in_place = flax.serialization._unchunk_array_leaves_in_place
+_unchunk_array_leaves_in_place = (
+    flax.serialization._unchunk_array_leaves_in_place)
 _MAX_CHUNK_SIZE = flax.serialization.MAX_CHUNK_SIZE
 # pylint: enable=protected-access
@@ -146,6 +147,7 @@
 
 
 class _MsgpackExtType(enum.IntEnum):
+  # pylint: disable=invalid-name
   ndarray = 1
   native_complex = 2
   npscalar = 3
@@ -154,6 +156,7 @@ class _MsgpackExtType(enum.IntEnum):
   slice_nd = 6
   slice_nd_array = 7
   index_info = 8
+  # pylint: enable=invalid-name
 
 
 def _shaped_array_to_bytes(x: core.ShapedArray) -> bytes:
diff --git a/vmoe/data/pjit_utils.py b/vmoe/data/pjit_utils.py
index 6790aa2..0e674b6 100644
--- a/vmoe/data/pjit_utils.py
+++ b/vmoe/data/pjit_utils.py
@@ -107,7 +107,7 @@ def _to_global(x):
 
   def enqueue(n):
     for data in itertools.islice(iterator, n):
-      queue.append(jax.tree_map(_to_global, data))
+      queue.append(jax.tree_util.tree_map(_to_global, data))
 
   enqueue(size)
   while queue:
@@ -117,7 +117,7 @@ def enqueue(n):
   # If size is None, 0 or negative, simply create jax.Arrays without
   # prefetching.
   for data in iterator:
-    yield jax.tree_map(_to_global, data)
+    yield jax.tree_util.tree_map(_to_global, data)
 
 
 def put_to_devices(host_array: np.ndarray,
diff --git a/vmoe/evaluate/evaluator.py b/vmoe/evaluate/evaluator.py
index 61015f7..2ab4d7c 100644
--- a/vmoe/evaluate/evaluator.py
+++ b/vmoe/evaluate/evaluator.py
@@ -226,7 +226,7 @@ def evaluate_dataset(
   for batch in dataset:
     eval_state = eval_step_pjit(eval_state, params, batch['image'],
                                 batch['labels'], batch[VALID_KEY])
-  return jax.tree_map(lambda x: x.block_until_ready(), eval_state)
+  return jax.tree_util.tree_map(lambda x: x.block_until_ready(), eval_state)
 
 
 def evaluate_step(
diff --git a/vmoe/initialization/initialization.py b/vmoe/initialization/initialization.py
index 591a512..0e4984e 100644
--- a/vmoe/initialization/initialization.py
+++ b/vmoe/initialization/initialization.py
@@ -233,10 +233,11 @@ def initialize_from_vmoe(
   # each array read from the checkpoint.
   index = vmoe_checkpoint.restore_checkpoint(prefix + '.index')
   version = index.get('version', vmoe_checkpoint.Version.UNKNOWN)
-  shapes = jax.tree_map(lambda x: x.global_shape, index['index'])
+  shapes = jax.tree_util.tree_map(lambda x: x.global_shape, index['index'])
   if version == vmoe_checkpoint.Version.UNKNOWN:
-    if (jax.tree_structure(vmoe_serialization.to_state_dict(target)) !=
-        jax.tree_structure(shapes)):
+    target_state_dict = vmoe_serialization.to_state_dict(target)
+    if (jax.tree_util.tree_structure(target_state_dict) !=
+        jax.tree_util.tree_structure(shapes)):
       raise ValueError(
           'Initialization from V-MoE checkpoints created before 2022/06/22 '
           'is only possible when the structure of the checkpoint and target '
diff --git a/vmoe/moe.py b/vmoe/moe.py
index d6e963a..2992002 100644
--- a/vmoe/moe.py
+++ b/vmoe/moe.py
@@ -428,7 +428,7 @@ def wrapper(expert_fn: Callable[..., Any]):
 
   def transformed(scopes, dispatcher, *inputs):
     # Prepare inputs to be processed by each expert.
-    inputs = jax.tree_map(dispatcher.dispatch, inputs)
+    inputs = jax.tree_util.tree_map(dispatcher.dispatch, inputs)
     # Wrap the target with vmap, to pass different parameters and inputs to
     # each expert.
     outputs = flax.core.lift.vmap(
@@ -440,7 +440,7 @@ def transformed(scopes, dispatcher, *inputs):
     # Combine outputs.
     if has_aux:
       outputs, aux = outputs
-    outputs = jax.tree_map(dispatcher.combine, outputs)
+    outputs = jax.tree_util.tree_map(dispatcher.combine, outputs)
     return (outputs, aux) if has_aux else outputs
 
   return transformed
diff --git a/vmoe/moe_test.py b/vmoe/moe_test.py
index 3ff6d1a..8912847 100644
--- a/vmoe/moe_test.py
+++ b/vmoe/moe_test.py
@@ -491,7 +491,7 @@ def init(rng):
     # All parameters are partitioned across the first axis of the TPU mesh.
     # Thus, each group of devices in one "row" will share the same values for
     # all parameters.
-    param_partition_spec = jax.tree_map(
+    param_partition_spec = jax.tree_util.tree_map(
        lambda _: jax.sharding.PartitionSpec(('expert',)),
        jax.eval_shape(init, jax.random.PRNGKey(0)))
    init_pjit = pjit.pjit(
@@ -586,7 +586,7 @@ def __call__(self, x, w):
     data_axis_resources = PartitionSpec(('X', 'Y'))
     variables_axis_resources = flax.core.freeze({
         'params':
-            jax.tree_map(
+            jax.tree_util.tree_map(
                 lambda x: nn.partitioning.logical_to_mesh_axes(x, axis_rules),
                 nn.partitioning.get_axis_names(variables_shape['params_axes'])),
     })
diff --git a/vmoe/nn/external_test.py b/vmoe/nn/external_test.py
index 5c3bf65..25dc34f 100644
--- a/vmoe/nn/external_test.py
+++ b/vmoe/nn/external_test.py
@@ -57,16 +57,19 @@ def compute_grad(params):
 
       variables = model.init(jax.random.PRNGKey(0), x)
       params = variables['params']
+      grads = None
       for _ in range(2):
         grads = compute_grad(params)
-        params = jax.tree_map(lambda p, g: p - 0.01 * g, params, grads)
+        params = jax.tree_util.tree_map(
+            lambda p, g: p - 0.01 * g, params, grads)
       return grads
 
     x = jax.random.normal(jax.random.PRNGKey(0), (8, 64, 64, 3))
     grads = fn(x)
-    grads_norm = jax.tree_map(lambda x: jnp.linalg.norm(x.flatten()), grads)
-    zeros = jax.tree_map(jnp.zeros_like, grads_norm)
+    grads_norm = jax.tree_util.tree_map(
+        lambda x: jnp.linalg.norm(x.flatten()), grads)
+    zeros = jax.tree_util.tree_map(jnp.zeros_like, grads_norm)
     print(grads_norm)
     chex.assert_trees_all_equal_comparator(lambda x, y: x > y,
                                            '{} is not greater than {}'.format,
diff --git a/vmoe/nn/vit_moe_ensemble_test.py b/vmoe/nn/vit_moe_ensemble_test.py
index f3810c5..73099d8 100644
--- a/vmoe/nn/vit_moe_ensemble_test.py
+++ b/vmoe/nn/vit_moe_ensemble_test.py
@@ -43,7 +43,8 @@ def init(rngs, x):
 
     rngs = dict(params=jax.random.PRNGKey(0), gating=jax.random.PRNGKey(1))
     x = jax.random.normal(jax.random.PRNGKey(0), (16, 4, 4, 3))
-    shapes = jax.tree_map(lambda x: x.shape, jax.eval_shape(init, rngs, x))
+    shapes = jax.tree_util.tree_map(
+        lambda x: x.shape, jax.eval_shape(init, rngs, x))
     shapes = flax.core.unfreeze(shapes)
     expected_shapes = {
         'params': {
diff --git a/vmoe/nn/vit_moe_test.py b/vmoe/nn/vit_moe_test.py
index d5d0bce..6f52a6e 100644
--- a/vmoe/nn/vit_moe_test.py
+++ b/vmoe/nn/vit_moe_test.py
@@ -153,7 +153,7 @@ def init():
       x = jax.random.normal(jax.random.PRNGKey(0), (16, 4, 4, 3))
       return model.init(rngs, x)
 
-    shapes = jax.tree_map(
+    shapes = jax.tree_util.tree_map(
         lambda x: x.shape, flax.core.unfreeze(jax.eval_shape(init))
     )
     expected_shapes = {
@@ -237,8 +237,8 @@ def test_classifier(self, classifier, seq_length, params_subset):
     model = vit_moe.VisionTransformerMoe(**config)
     rngs = dict(params=jax.random.PRNGKey(0), gating=jax.random.PRNGKey(1))
     x = jax.ShapeDtypeStruct((16, 4, 4, 3), jax.numpy.float32)
-    shapes = jax.tree_map(lambda x: x.shape,
-                          jax.eval_shape(model.init, rngs, x))
+    shapes = jax.tree_util.tree_map(lambda x: x.shape,
+                                    jax.eval_shape(model.init, rngs, x))
     shapes = flax.core.unfreeze(shapes)
     self.assertDictEqual(shapes['params']['Encoder']['posembed_input'],
                          {'pos_embedding': (1, seq_length, 8)})
diff --git a/vmoe/projects/adversarial_attacks/attacks.py b/vmoe/projects/adversarial_attacks/attacks.py
index f3177d7..ddacfa1 100644
--- a/vmoe/projects/adversarial_attacks/attacks.py
+++ b/vmoe/projects/adversarial_attacks/attacks.py
@@ -120,7 +120,8 @@ def mask(a):
   num_changes = jnp.sum(y_m != y_0, dtype=jnp.int32)
   num_correct = jnp.stack([c_0.sum(), c_m.sum()], axis=0)
   sum_loss = jnp.stack([l_0.sum(), l_m.sum()], axis=0)
-  sum_iou_experts = jax.tree_map(sum_intersection_over_union, cw_0, cw_m)
+  sum_iou_experts = jax.tree_util.tree_map(
+      sum_intersection_over_union, cw_0, cw_m)
   new_attack_state = attack_state.update(
       rngs=new_rngs, num_images=num_images, num_changes=num_changes,
       num_correct=num_correct, sum_loss=sum_loss,
@@ -151,7 +152,8 @@ def stateless_attack_pgd(
   # If rngs are given, split them in num_updates + 1. The last one will be
   # returned at the end of this function, the others are used in each update.
   if rngs is not None:
-    rngs = jax.tree_map(lambda rng: jax.random.split(rng, num_updates +1), rngs)
+    rngs = jax.tree_util.tree_map(
+        lambda rng: jax.random.split(rng, num_updates + 1), rngs)
   # This computes gradients of loss_fn w.r.t. the images.
   @jax.grad
   def compute_loss_grads(x, rngs):
@@ -160,13 +162,15 @@ def compute_loss_grads(x, rngs):
     return loss
   # Performs an adversarial update on the given images.
   def update(i, x_c):
-    rngs_c = jax.tree_map(lambda x: x[i], rngs) if rngs is not None else None
+    rngs_c = jax.tree_util.tree_map(
+        lambda x: x[i], rngs) if rngs is not None else None
     dx = compute_loss_grads(x_c, rngs_c)
     x_n = x_c + delta * jnp.sign(dx)
     return x_n
   # Performs num_updates on the original images.
   images = jax.lax.fori_loop(0, num_updates, update, images)
-  rngs = jax.tree_map(lambda x: x[-1], rngs) if rngs is not None else None
+  rngs = jax.tree_util.tree_map(
+      lambda x: x[-1], rngs) if rngs is not None else None
   return images, rngs
diff --git a/vmoe/projects/adversarial_attacks/lib.py b/vmoe/projects/adversarial_attacks/lib.py
index da5f3f3..9ecc99a 100644
--- a/vmoe/projects/adversarial_attacks/lib.py
+++ b/vmoe/projects/adversarial_attacks/lib.py
@@ -53,7 +53,7 @@ def __init__(self, filepath: str, transfer_every_steps: int = 1):
 
   def _transfer(self):
     for key, values in self._temp.items():
-      self._data[key].extend(jax.tree_map(np.asarray, values))
+      self._data[key].extend(jax.tree_util.tree_map(np.asarray, values))
       self._temp[key] = []
 
   @property
@@ -99,6 +99,7 @@ def run_pgd_attack(
     config: ml_collections.ConfigDict, workdir: str, mesh: Mesh,
     writer: metric_writers.MetricWriter):
   """Run PGD attack on an entire dataset, using a model from an XM experiment."""
+  del mesh
   # Setup dataset and get the global shape of the image array.
   dataset = get_dataset(config.dataset)
   element_spec: ArraySpecDict = dataset.element_spec  # pytype: disable=annotation-type-mismatch
@@ -141,8 +142,8 @@ def init_state():
     rngs = utils.make_rngs(rng_keys, config.get('seed', 0))
     return attacks.AttackState.create(
         max_updates=config.num_updates, router_keys=router_keys, rngs=rngs)
-  state_axis_resources = jax.tree_map(lambda _: PartitionSpec(),
-                                      jax.eval_shape(init_state))
+  state_axis_resources = jax.tree_util.tree_map(lambda _: PartitionSpec(),
+                                                jax.eval_shape(init_state))
   init_state_pjit = pjit.pjit(
       init_state, in_shardings=(), out_shardings=state_axis_resources
   )
@@ -203,7 +204,7 @@ def compute_loss_predict_correct_cw_fn(images, labels, rngs):
         **{f'cw_0/{k}': v for k, v in cw_0.items()},
         **{f'cw_m/{k}': v for k, v in cw_m.items()})
   # Copy state from device to CPU and convert to numpy arrays.
-  state = jax.tree_map(np.asarray, state)
+  state = jax.tree_util.tree_map(np.asarray, state)
   # Process with index=0 saves the PGD state (it's the same for all processes).
   if jax.process_index() == 0:
     state_filepath = os.path.join(workdir, 'pgd_state.npz')
diff --git a/vmoe/projects/adversarial_attacks/lib_test.py b/vmoe/projects/adversarial_attacks/lib_test.py
index 8f14bde..9eacf6c 100644
--- a/vmoe/projects/adversarial_attacks/lib_test.py
+++ b/vmoe/projects/adversarial_attacks/lib_test.py
@@ -61,8 +61,8 @@ def __call__(self, x):
     flax_module = ModelWithRouting()
     image_size = (batch_size, image_size, image_size, 3)
     variables = flax_module.init(jax.random.PRNGKey(0), np.zeros(image_size))
-    variables_axis_resources = jax.tree_map(lambda _: lib.PartitionSpec(),
-                                            variables)
+    variables_axis_resources = jax.tree_util.tree_map(
+        lambda _: lib.PartitionSpec(), variables)
     router_keys = {'router/__call__'}
     loss_fn = lambda a, b, _: optax.softmax_cross_entropy(a, b)
     return (flax_module, variables, variables_axis_resources, loss_fn,
@@ -82,8 +82,8 @@ def __call__(self, x):
     flax_module = ModelWithoutRouting()
     image_size = (batch_size, image_size, image_size, 3)
     variables = flax_module.init(jax.random.PRNGKey(0), np.zeros(image_size))
-    variables_axis_resources = jax.tree_map(lambda _: lib.PartitionSpec(),
-                                            variables)
+    variables_axis_resources = jax.tree_util.tree_map(
+        lambda _: lib.PartitionSpec(), variables)
     loss_fn = lambda a, b, _: optax.softmax_cross_entropy(a, b)
     return flax_module, variables, variables_axis_resources, loss_fn, {}, {}
diff --git a/vmoe/projects/adversarial_attacks/restore.py b/vmoe/projects/adversarial_attacks/restore.py
index 90ff11e..5bf1245 100644
--- a/vmoe/projects/adversarial_attacks/restore.py
+++ b/vmoe/projects/adversarial_attacks/restore.py
@@ -47,7 +47,7 @@ def compute_loss_predict_cw_fn(x, y, rngs, *, apply_fn, loss_fn):
   # This is a dict mapping from each MoE layer to a binary array of shape
   # (batch_size, num_tokens, num_experts).
   combine_weights = get_combine_weights(intermediates)
-  combine_weights = jax.tree_map(
+  combine_weights = jax.tree_util.tree_map(
       lambda m: m.reshape(batch_size, -1, m.shape[1]), combine_weights)
   return loss, pred, correct, combine_weights
diff --git a/vmoe/train/optimizer_test.py b/vmoe/train/optimizer_test.py
index d613481..df9b2da 100644
--- a/vmoe/train/optimizer_test.py
+++ b/vmoe/train/optimizer_test.py
@@ -113,7 +113,7 @@ def step(_, state):
       new_params = optimizer.optax.apply_updates(params, updates)
       return new_params, new_tx_state
 
-    params = jax.tree_map(jnp.asarray, {'x': 0., 'y': 0.})
+    params = jax.tree_util.tree_map(jnp.asarray, {'x': 0., 'y': 0.})
     state = init_fn(params)
     return jax.lax.fori_loop(0, 200, step, (params, state))[0]
diff --git a/vmoe/train/trainer.py b/vmoe/train/trainer.py
index b80ffa0..4e07128 100644
--- a/vmoe/train/trainer.py
+++ b/vmoe/train/trainer.py
@@ -415,7 +415,8 @@ def _array_restore_args_fn(x: jax.ShapeDtypeStruct):
         dtype=x.dtype, sharding=x.sharding, global_shape=x.shape)
   restore_kwargs = {
       'state': {
-          'restore_args': jax.tree_map(_array_restore_args_fn, train_state),
+          'restore_args': jax.tree_util.tree_map(
+              _array_restore_args_fn, train_state),
      },
   }
   items = ckpt_manager.restore(
@@ -572,7 +573,7 @@ def mixup(
   Returns:
     A tree with the mixed arrays.
   """
-  arrays, treedef = jax.tree_flatten(tree)
+  arrays, treedef = jax.tree_util.tree_flatten(tree)
   if len(shape) < 2:
     raise ValueError(f"Mixup 'shape' has length {len(shape)}, but it must have "
                      'length >= 2.')
@@ -648,7 +649,7 @@ def compute_grads_and_metrics(params, images, labels, rngs):
     logits, metrics = state.apply_fn({'params': params}, images, rngs=rngs)
     metrics = dict(**metrics)
     metrics['main_loss'] = jnp.mean(loss_fn(logits, labels))
-    metrics = jax.tree_map(jnp.mean, metrics)
+    metrics = jax.tree_util.tree_map(jnp.mean, metrics)
     total_loss = metrics['main_loss'] + metrics.get('auxiliary_loss', 0.0)
     metrics['total_loss'] = total_loss
     return total_loss, (next_rngs, metrics)
diff --git a/vmoe/train/trainer_test.py b/vmoe/train/trainer_test.py
index 0fb575d..6e51871 100644
--- a/vmoe/train/trainer_test.py
+++ b/vmoe/train/trainer_test.py
@@ -732,9 +732,9 @@ def test_granularity(self, granularity):
     np.testing.assert_array_less(np.zeros_like(x),
                                  np.abs(metrics['images'] - x))
     # Convert PRNGKeys to int64 to subtract them.
-    rngs = jax.tree_map(lambda x: np.asarray(x, dtype=np.int64), rngs)
-    new_rngs = jax.tree_map(lambda x: np.asarray(x, dtype=np.int64),
-                            train_state.rngs)
+    rngs = jax.tree_util.tree_map(lambda x: np.asarray(x, dtype=np.int64), rngs)
+    new_rngs = jax.tree_util.tree_map(lambda x: np.asarray(x, dtype=np.int64),
+                                      train_state.rngs)
     # Check that both PRNGKeys have been updated.
     np.testing.assert_array_less(
         np.zeros_like(rngs['mixup']), np.abs(rngs['mixup'] - new_rngs['mixup']))
diff --git a/vmoe/utils.py b/vmoe/utils.py
index ab89971..7f279c8 100644
--- a/vmoe/utils.py
+++ b/vmoe/utils.py
@@ -151,8 +151,9 @@ def safe_zip(*iterables) -> Iterator[Tuple[Any, ...]]:
 
 
 def tree_rngs_split(rngs, num_splits=2):
   """Splits a PyTree of PRNGKeys into num_splits PyTrees."""
-  rngs = jax.tree_map(lambda rng: jax.random.split(rng, num_splits), rngs)
-  slice_rngs = lambda rngs, i: jax.tree_map(lambda rng: rng[i], rngs)
+  rngs = jax.tree_util.tree_map(
+      lambda rng: jax.random.split(rng, num_splits), rngs)
+  slice_rngs = lambda rngs, i: jax.tree_util.tree_map(lambda rng: rng[i], rngs)
   return tuple(slice_rngs(rngs, i) for i in range(num_splits))