Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553231237
  • Loading branch information
hawkinsp authored and copybara-github committed Aug 2, 2023
1 parent e00dbc2 commit 9505e61
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions vmoe/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@
from absl import logging
import flax.traverse_util
import jax
from jax import lax
from jax.experimental import maps
from jax.experimental import pjit
import numpy as np

AxisResourcesRegexes = Sequence[Tuple[str, 'UnparsedPartitionSpec']]
Expand Down Expand Up @@ -426,4 +426,4 @@ def with_sharding_constraint(x: PyTree, partition_spec: PartitionSpec):
if maps.thread_resources.env.physical_mesh.empty or partition_spec is None:
return x
else:
return pjit.with_sharding_constraint(x, partition_spec)
return lax.with_sharding_constraint(x, partition_spec)
5 changes: 3 additions & 2 deletions vmoe/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import flax.training.train_state
import flax.traverse_util
import jax
from jax import lax
from jax.experimental import maps
from jax.experimental import pjit
import jax.numpy as jnp
Expand Down Expand Up @@ -104,8 +105,8 @@ def new_grad_and_metrics_fn(params, images, labels, rngs):
pspec = jax.sharding.PartitionSpec(('expert', 'replica'))
images = images.reshape((-1, microsteps) + images.shape[1:])
labels = labels.reshape((-1, microsteps) + labels.shape[1:])
images = pjit.with_sharding_constraint(images, pspec)
labels = pjit.with_sharding_constraint(labels, pspec)
images = lax.with_sharding_constraint(images, pspec)
labels = lax.with_sharding_constraint(labels, pspec)

def accum_fn(i, state):
grad, rngs, metrics = state
Expand Down

0 comments on commit 9505e61

Please sign in to comment.