Shape mismatch running vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012 checkpoint #182

Open
seliayeu opened this issue Aug 15, 2024 · 1 comment


@seliayeu

I am unable to run the vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012 checkpoint with the config returned by get_config in vmoe/configs/vmoe_paper/vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012. This code fragment:

```python
import jax
from vmoe.nn import models
from vmoe.data import input_pipeline
from vmoe.checkpoints import partitioned
from vmoe.configs.vmoe_paper.vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012 import get_config

model = models.VisionTransformerMoe(**get_config()["model"])
checkpoint = partitioned.restore_checkpoint("gs://vmoe_checkpoints/vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012", tree=None)

IMAGE_SIZE = 384
BATCH_SIZE = 1

image = jax.random.uniform(key=jax.random.key(1), shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3))
model.apply({'params': checkpoint}, image)
```

fails with the following error:

```
TypeError: cannot reshape array of shape (1, 577, 768) (size 443136) into shape (-1, 4616, 768) because the product of specified axis sizes (3545088) does not evenly divide 443136.
```

Am I using the config incorrectly? Issue #160 seems to describe the same problem.
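The numbers in the error message are consistent with a fixed token-group size inside the MoE layers. At 384×384 resolution, a B/16 model splits each image into (384 / 16)² = 576 patches, and 577 tokens per image is consistent with one extra class token (an assumption here). The target shape's 4616 tokens is exactly 8 × 577, so the layer appears to reshape the flattened tokens into groups of eight images' worth of tokens. That only divides evenly when the batch size is a multiple of 8. The following plain-Python sketch reproduces this arithmetic; `tokens_per_image` and `fits_group_size` are hypothetical helpers, not part of vmoe.

```python
# Hypothetical helpers (not part of vmoe) that reproduce the arithmetic
# behind the reshape error for ViT-B/16 at 384x384 resolution.

def tokens_per_image(image_size: int, patch_size: int = 16, extra_tokens: int = 1) -> int:
    """Tokens per image: one per patch plus `extra_tokens`.

    extra_tokens=1 assumes a single class token, which matches the 577
    reported in the error message.
    """
    return (image_size // patch_size) ** 2 + extra_tokens


def fits_group_size(batch_size: int, image_size: int, group_size: int) -> bool:
    """True if the batch's total token count splits evenly into groups."""
    return (batch_size * tokens_per_image(image_size)) % group_size == 0


TOKENS = tokens_per_image(384)            # 576 patches + 1 class token = 577
GROUP_SIZE = 4616                         # taken from the error message
IMAGES_PER_GROUP = GROUP_SIZE // TOKENS   # 8 images per group

valid = [b for b in range(1, 33) if fits_group_size(b, 384, GROUP_SIZE)]
print(TOKENS, IMAGES_PER_GROUP, valid)  # 577 8 [8, 16, 24, 32]
```

With BATCH_SIZE = 1 there are only 577 tokens, which cannot fill even one group of 4616. Another route may be to lower the group size in the model's MoE settings so it matches 577 × batch size; where exactly that value lives in the config is an assumption to check against the config file before relying on it.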

@jason-adriel

I ran into this issue as well. Setting BATCH_SIZE to a multiple of 8 fixes the problem for me.
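To evaluate a single image with this workaround, one option is to zero-pad the batch up to the next multiple of 8 and keep only the first rows of the model output. The sketch below assumes the eight-images-per-group figure inferred from the error message; `pad_to_multiple` is a hypothetical helper, and `jax.numpy.pad` takes the same arguments if you prefer to stay in JAX. One caveat: if tokens in the same group compete for expert capacity, the padding images could in principle change routing for the real image, so comparing against a batch of eight real images is a useful sanity check.

```python
import numpy as np

# Assumption: 4616 tokens per group / 577 tokens per image = 8 images per group.
IMAGES_PER_GROUP = 8


def pad_to_multiple(images: np.ndarray, multiple: int = IMAGES_PER_GROUP):
    """Zero-pad the batch axis so its length is a multiple of `multiple`.

    Returns the padded array and the original batch size, so the caller
    can slice the model outputs back to the real examples.
    """
    batch = images.shape[0]
    pad = (-batch) % multiple  # number of dummy images to append
    pad_width = [(0, pad)] + [(0, 0)] * (images.ndim - 1)
    return np.pad(images, pad_width), batch


image = np.random.default_rng(1).uniform(size=(1, 384, 384, 3)).astype(np.float32)
padded, n = pad_to_multiple(image)
print(padded.shape, n)  # (8, 384, 384, 3) 1
```

After calling `model.apply` on `padded`, only the first `n` rows of the logits correspond to real images.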
