From 3cbb975594eb47fafe7e7d6491a6df4c63ad0532 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 14 May 2024 18:05:33 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 633764452 --- vmoe/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vmoe/train/trainer.py b/vmoe/train/trainer.py index 989e5f0..84e8169 100644 --- a/vmoe/train/trainer.py +++ b/vmoe/train/trainer.py @@ -634,7 +634,7 @@ def mix(x): a = alpha.reshape(alpha.shape + (1,) * (x.ndim - batch_ndim)) return sum(a[i] * jnp.roll(x, -i, axis=roll_axis) for i in range(shape[-1])) arrays = list(map(mix, arrays)) - return jax.tree_unflatten(treedef, arrays) + return jax.tree.unflatten(treedef, arrays) def override_base_config(