diff --git a/vmoe/initialization/rules.py b/vmoe/initialization/rules.py
index 4d0ac7c..7c848a5 100644
--- a/vmoe/initialization/rules.py
+++ b/vmoe/initialization/rules.py
@@ -21,7 +21,6 @@
 
 import flax.struct
 import jax
-from jax.experimental import maps
 import jax.numpy as jnp
 import scipy.ndimage
 from vmoe import partitioning
@@ -291,25 +290,12 @@ class ZoomTransformation(Transformation):
   @classmethod
   def _zoom(cls, source, callback_shape_dtype, zoom):
     # Wrap scipy.ndimage.zoom with a _pure_callback call.
-    def _pure_callback_zoom(x):
-      return jax.pure_callback(
-          lambda xx: scipy.ndimage.zoom(xx, zoom, order=1),
-          callback_shape_dtype, x, vectorized=False)
-    # When using pjit, we need to wrap the pure callback with xmap.
-    # TODO(jpuigcerver): Simplify this when pure_callback is well supported with
-    # pjit.
-    mesh = maps.thread_resources.env.physical_mesh
-    if mesh.empty:
-      return _pure_callback_zoom(source)
-    else:
-      source = partitioning.with_sharding_constraint(
-          source, partitioning.PartitionSpec())
-      return maps.xmap(
-          _pure_callback_zoom,
-          in_axes=((None,) * source.ndim,),
-          out_axes=(None,) * source.ndim,
-          axis_resources={n: n for n in mesh.axis_names},
-          axis_sizes=mesh.shape)(source)
+    return jax.pure_callback(
+        lambda xx: scipy.ndimage.zoom(xx, zoom, order=1),
+        callback_shape_dtype,
+        source,
+        vectorized=False,
+    )
 
   def __call__(self) -> Array:
     source = self.source