Skip to content

Commit

Permalink
Use shard_map instead of xmap
Browse files Browse the repository at this point in the history
xmap is soft-deprecated with shard_map being the recommended alternative.

See https://jax.readthedocs.io/en/latest/notebooks/shard_map.html.

PiperOrigin-RevId: 610862919
  • Loading branch information
superbobry authored and copybara-github committed Feb 27, 2024
1 parent c95fb0a commit 90f8f94
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
31 changes: 25 additions & 6 deletions vmoe/initialization/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@

import flax.struct
import jax
from jax.experimental import shard_map
from jax.interpreters import pxla
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import scipy.ndimage
from vmoe import partitioning
from vmoe import utils

Array = jax.Array
UnparsedRules = Sequence[Union['Rule', Tuple[Any, ...]]]

shard_map = shard_map.shard_map

get_array_sharding_or_default = partitioning.get_array_sharding_or_default


Expand Down Expand Up @@ -290,12 +295,26 @@ class ZoomTransformation(Transformation):
@classmethod
def _zoom(cls, source, callback_shape_dtype, zoom):
# Wrap scipy.ndimage.zoom with a _pure_callback call.
return jax.pure_callback(
lambda xx: scipy.ndimage.zoom(xx, zoom, order=1),
callback_shape_dtype,
source,
vectorized=False,
)
def _pure_callback_zoom(x):
return jax.pure_callback(
lambda xx: scipy.ndimage.zoom(xx, zoom, order=1),
callback_shape_dtype,
x,
vectorized=False,
)

mesh = pxla.thread_resources.env.physical_mesh
if mesh.empty:
return _pure_callback_zoom(source)
else:
source = partitioning.with_sharding_constraint(source, P())
return shard_map(
_pure_callback_zoom,
mesh,
in_specs=P(*(None,) * source.ndim),
out_specs=P(*(None,) * source.ndim),
check_rep=False,
)(source)

def __call__(self) -> Array:
source = self.source
Expand Down
3 changes: 0 additions & 3 deletions vmoe/initialization/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
import numpy as np
from vmoe.initialization import rules as _rules

jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)


class ZoomTest(parameterized.TestCase):

Expand Down

0 comments on commit 90f8f94

Please sign in to comment.