diff --git a/vmoe/initialization/rules.py b/vmoe/initialization/rules.py index 4d0ac7c..7c848a5 100644 --- a/vmoe/initialization/rules.py +++ b/vmoe/initialization/rules.py @@ -21,7 +21,6 @@ import flax.struct import jax -from jax.experimental import maps import jax.numpy as jnp import scipy.ndimage from vmoe import partitioning @@ -291,25 +290,12 @@ class ZoomTransformation(Transformation): @classmethod def _zoom(cls, source, callback_shape_dtype, zoom): # Wrap scipy.ndimage.zoom with a _pure_callback call. - def _pure_callback_zoom(x): - return jax.pure_callback( - lambda xx: scipy.ndimage.zoom(xx, zoom, order=1), - callback_shape_dtype, x, vectorized=False) - # When using pjit, we need to wrap the pure callback with xmap. - # TODO(jpuigcerver): Simplify this when pure_callback is well supported with - # pjit. - mesh = maps.thread_resources.env.physical_mesh - if mesh.empty: - return _pure_callback_zoom(source) - else: - source = partitioning.with_sharding_constraint( - source, partitioning.PartitionSpec()) - return maps.xmap( - _pure_callback_zoom, - in_axes=((None,) * source.ndim,), - out_axes=(None,) * source.ndim, - axis_resources={n: n for n in mesh.axis_names}, - axis_sizes=mesh.shape)(source) + return jax.pure_callback( + lambda xx: scipy.ndimage.zoom(xx, zoom, order=1), + callback_shape_dtype, + source, + vectorized=False, + ) def __call__(self) -> Array: source = self.source