diff --git a/vmoe/initialization/rules.py b/vmoe/initialization/rules.py
index 3e1825c..cdb6373 100644
--- a/vmoe/initialization/rules.py
+++ b/vmoe/initialization/rules.py
@@ -263,7 +263,7 @@ class ReshapeTransformation(Transformation):
   shape: Tuple[int, ...] = flax.struct.field(pytree_node=False)
 
   def __call__(self) -> Array:
-    return jnp.reshape(self.array, newshape=self.shape)
+    return jnp.reshape(self.array, shape=self.shape)
 
 
 class SqueezeTransformation(Transformation):