Skip to content

Commit

Permalink
Removed unnecessary pjit/xmap hack in ZoomTransformation
Browse files Browse the repository at this point in the history
xmap is soft-deprecated with shard_map being the recommended alternative.

See https://jax.readthedocs.io/en/latest/notebooks/shard_map.html.

PiperOrigin-RevId: 610748229
  • Loading branch information
superbobry authored and copybara-github committed Feb 27, 2024
1 parent a0f68af commit c95fb0a
Showing 1 changed file with 6 additions and 20 deletions.
26 changes: 6 additions & 20 deletions vmoe/initialization/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import flax.struct
import jax
from jax.experimental import maps
import jax.numpy as jnp
import scipy.ndimage
from vmoe import partitioning
Expand Down Expand Up @@ -291,25 +290,12 @@ class ZoomTransformation(Transformation):
@classmethod
def _zoom(cls, source, callback_shape_dtype, zoom):
# Wrap scipy.ndimage.zoom with a _pure_callback call.
def _pure_callback_zoom(x):
return jax.pure_callback(
lambda xx: scipy.ndimage.zoom(xx, zoom, order=1),
callback_shape_dtype, x, vectorized=False)
# When using pjit, we need to wrap the pure callback with xmap.
# TODO(jpuigcerver): Simplify this when pure_callback is well supported with
# pjit.
mesh = maps.thread_resources.env.physical_mesh
if mesh.empty:
return _pure_callback_zoom(source)
else:
source = partitioning.with_sharding_constraint(
source, partitioning.PartitionSpec())
return maps.xmap(
_pure_callback_zoom,
in_axes=((None,) * source.ndim,),
out_axes=(None,) * source.ndim,
axis_resources={n: n for n in mesh.axis_names},
axis_sizes=mesh.shape)(source)
return jax.pure_callback(
lambda xx: scipy.ndimage.zoom(xx, zoom, order=1),
callback_shape_dtype,
source,
vectorized=False,
)

def __call__(self) -> Array:
source = self.source
Expand Down

0 comments on commit c95fb0a

Please sign in to comment.