Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 633764452
  • Loading branch information
Jake VanderPlas authored and copybara-github committed May 15, 2024
1 parent da5922d commit 3cbb975
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion vmoe/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def mix(x):
a = alpha.reshape(alpha.shape + (1,) * (x.ndim - batch_ndim))
return sum(a[i] * jnp.roll(x, -i, axis=roll_axis) for i in range(shape[-1]))
arrays = list(map(mix, arrays))
return jax.tree_unflatten(treedef, arrays)
return jax.tree.unflatten(treedef, arrays)


def override_base_config(
Expand Down

0 comments on commit 3cbb975

Please sign in to comment.