From 9505e61acc43ba309f858704f550f5ad91f7ecd8 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 2 Aug 2023 12:58:22 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 553231237 --- vmoe/partitioning.py | 4 ++-- vmoe/train/trainer.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vmoe/partitioning.py b/vmoe/partitioning.py index a980f7d..a9e82cf 100644 --- a/vmoe/partitioning.py +++ b/vmoe/partitioning.py @@ -68,8 +68,8 @@ from absl import logging import flax.traverse_util import jax +from jax import lax from jax.experimental import maps -from jax.experimental import pjit import numpy as np AxisResourcesRegexes = Sequence[Tuple[str, 'UnparsedPartitionSpec']] @@ -426,4 +426,4 @@ def with_sharding_constraint(x: PyTree, partition_spec: PartitionSpec): if maps.thread_resources.env.physical_mesh.empty or partition_spec is None: return x else: - return pjit.with_sharding_constraint(x, partition_spec) + return lax.with_sharding_constraint(x, partition_spec) diff --git a/vmoe/train/trainer.py b/vmoe/train/trainer.py index 1d61c87..081ccca 100644 --- a/vmoe/train/trainer.py +++ b/vmoe/train/trainer.py @@ -29,6 +29,7 @@ import flax.training.train_state import flax.traverse_util import jax +from jax import lax from jax.experimental import maps from jax.experimental import pjit import jax.numpy as jnp @@ -104,8 +105,8 @@ def new_grad_and_metrics_fn(params, images, labels, rngs): pspec = jax.sharding.PartitionSpec(('expert', 'replica')) images = images.reshape((-1, microsteps) + images.shape[1:]) labels = labels.reshape((-1, microsteps) + labels.shape[1:]) - images = pjit.with_sharding_constraint(images, pspec) - labels = pjit.with_sharding_constraint(labels, pspec) + images = lax.with_sharding_constraint(images, pspec) + labels = lax.with_sharding_constraint(labels, pspec) def accum_fn(i, state): grad, rngs, metrics = state