diff --git a/vmoe/train/trainer.py b/vmoe/train/trainer.py index 989e5f0..84e8169 100644 --- a/vmoe/train/trainer.py +++ b/vmoe/train/trainer.py @@ -634,7 +634,7 @@ def mix(x): a = alpha.reshape(alpha.shape + (1,) * (x.ndim - batch_ndim)) return sum(a[i] * jnp.roll(x, -i, axis=roll_axis) for i in range(shape[-1])) arrays = list(map(mix, arrays)) - return jax.tree_unflatten(treedef, arrays) + return jax.tree.unflatten(treedef, arrays) def override_base_config(