Commit f0a8702

No public description
PiperOrigin-RevId: 618625497
superbobry authored and copybara-github committed Mar 24, 2024
1 parent 7d2f7e0 commit f0a8702
Showing 4 changed files with 10 additions and 10 deletions.
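
All four files make the same substitution: the import of jax.experimental.maps is replaced by jax.interpreters.pxla, and every reference to maps.thread_resources becomes pxla.thread_resources. Below is a minimal sketch of the shared usage pattern after this change, assuming a JAX version in which thread_resources is exposed from jax.interpreters.pxla; the helper name current_data_sharding is hypothetical and used only for illustration.

# Sketch only, not part of this commit: the pattern the changed files share.
# Assumes a JAX version where thread_resources lives in jax.interpreters.pxla.
import jax
from jax.interpreters import pxla  # previously: from jax.experimental import maps


def current_data_sharding(mesh=None):
  """Hypothetical helper: shards arrays across every axis of the active mesh."""
  # Fall back to the mesh of the surrounding context, as the changed code does.
  mesh = mesh or pxla.thread_resources.env.physical_mesh
  assert mesh is not None and not mesh.empty, f'No mesh or empty mesh. {mesh=}'
  pspec = jax.sharding.PartitionSpec(mesh.axis_names,)
  return jax.sharding.NamedSharding(mesh, pspec)
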
6 changes: 3 additions & 3 deletions vmoe/data/pjit_utils.py
@@ -20,7 +20,7 @@
from absl import logging
from clu.data import dataset_iterator
import jax
-from jax.experimental import maps
+from jax.interpreters import pxla
import numpy as np
import tensorflow as tf

@@ -33,7 +33,7 @@ def get_dataset_shape_dtype_struct(
mesh: Optional[Mesh] = None,
) -> PyTree:
"""Returns the jax.ShapeDtypeStruct."""
-mesh = mesh or maps.thread_resources.env.physical_mesh
+mesh = mesh or pxla.thread_resources.env.physical_mesh
assert mesh is not None and not mesh.empty, f'No mesh or empty mesh. {mesh=}'

pspec = jax.sharding.PartitionSpec(mesh.axis_names,)
@@ -73,7 +73,7 @@ def prefetch_to_device(
The original items from the iterator where each ndarray is now sharded as
specified by `axis_resources`.
"""
-mesh = mesh or maps.thread_resources.env.physical_mesh
+mesh = mesh or pxla.thread_resources.env.physical_mesh
assert mesh is not None and not mesh.empty, f'No mesh or empty mesh. {mesh=}'

pspec = jax.sharding.PartitionSpec(mesh.axis_names,)
4 changes: 2 additions & 2 deletions vmoe/evaluate/evaluator.py
@@ -24,8 +24,8 @@
import flax.core
import flax.struct
import jax
-from jax.experimental import maps
from jax.experimental import pjit
+from jax.interpreters import pxla
import jax.numpy as jnp
from vmoe import utils
from vmoe.data import input_pipeline
@@ -143,7 +143,7 @@ def _make_callback_fn(cls, *, apply_fn, loss_fn, label_pred_fn, datasets,
sum_correct=jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32),
sum_loss=jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32),
rngs=jax.eval_shape(lambda: utils.make_rngs(rng_keys, 0)))
-mesh = maps.thread_resources.env.physical_mesh
+mesh = pxla.thread_resources.env.physical_mesh
assert not mesh.empty, 'The physical mesh is empty.'
sharding = jax.sharding.NamedSharding(mesh, PartitionSpec())
eval_state_dtype_struct = tree_map(
4 changes: 2 additions & 2 deletions vmoe/partitioning.py
@@ -69,7 +69,7 @@
import flax.traverse_util
import jax
from jax import lax
-from jax.experimental import maps
+from jax.interpreters import pxla
import numpy as np

AxisResourcesRegexes = Sequence[Tuple[str, 'UnparsedPartitionSpec']]
@@ -423,7 +423,7 @@ def search_partition_spec(key: str, value: Any) -> PartitionSpec:

def with_sharding_constraint(x: PyTree, partition_spec: PartitionSpec):
"""Specifies a partition_spec for a given array to help pjit's sharding."""
-if maps.thread_resources.env.physical_mesh.empty or partition_spec is None:
+if pxla.thread_resources.env.physical_mesh.empty or partition_spec is None:
return x
else:
return lax.with_sharding_constraint(x, partition_spec)
6 changes: 3 additions & 3 deletions vmoe/train/trainer.py
@@ -30,8 +30,8 @@
import flax.traverse_util
import jax
from jax import lax
-from jax.experimental import maps
from jax.experimental import pjit
+from jax.interpreters import pxla
import jax.numpy as jnp
import ml_collections
import numpy as np
@@ -289,7 +289,7 @@ def create_or_reuse_train_state(
Returns:
A TrainState.
"""
-mesh = mesh or maps.thread_resources.env.physical_mesh
+mesh = mesh or pxla.thread_resources.env.physical_mesh
# Flatten input train state and keep the ShapeDtypeStruct for the arrays that
# must be created from scratch.
train_state_dict = flax.traverse_util.flatten_dict(
@@ -395,7 +395,7 @@ def restore_or_create_train_state(
Returns:
A TrainState and (optionally) the last_seen_index.
"""
-mesh = mesh or maps.thread_resources.env.physical_mesh
+mesh = mesh or pxla.thread_resources.env.physical_mesh
train_state_shape_dtype = jax.eval_shape(initialize_fn)
train_state_axis_resources = partitioning.tree_axis_resources_from_regexes(
tree=train_state_shape_dtype,
