
Pretrain Models issues. #1

Open
Xuguowei-hub opened this issue May 6, 2024 · 3 comments

Xuguowei-hub commented May 6, 2024

Thank you very much for your contribution to the community.

1. There seems to be a discrepancy between the file name given on the command line and the actual file name: the weight file in the download link is called clip_best.pth, while the command line expects clip_best_fid.pth.

    python eval_t2m.py --resume-pth pretrained/net_best_fid.pth --clip_path pretrained/clip_best_fid.pth
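If so, the command should presumably point at the file that actually exists:

    python eval_t2m.py --resume-pth pretrained/net_best_fid.pth --clip_path pretrained/clip_best.pth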

2. I'm getting a KeyError: the checkpoint file net_best_fid.pth doesn't seem to contain the key 'net' that the code expects, so the load below fails.

print('loading checkpoint from {}'.format(args.resume_pth))
ckpt = torch.load(args.resume_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.eval()
net.cuda()
[screenshot of the KeyError traceback]
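For reference, a minimal way to check which top-level keys the checkpoint actually contains (same path as in the command above):

import torch

# Inspect the checkpoint's top-level keys instead of assuming ckpt['net'] exists.
ckpt = torch.load('pretrained/net_best_fid.pth', map_location='cpu')
print(list(ckpt.keys()))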

Looking forward to your reply!


thevisad commented May 6, 2024

@Xuguowei-hub beat me to the punch on this. I dug a little deeper: the trained checkpoint contains the key 'trans', not 'net' as the code expects, and switching the code to use 'trans' instead introduces other errors. As noted above, clip_best_fid.pth does not exist, but clip_best.pth does.

(base) thevisad@prettygirl:~/Stable-Text-to-Motion-Framework$ conda activate SATO
(SATO) thevisad@prettygirl:~/Stable-Text-to-Motion-Framework$ python3
Python 3.8.11 (default, Aug  3 2021, 15:09:35)
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>>
>>> # Load the checkpoint
>>> checkpoint_path = 'pretrained/net_best_fid.pth'
>>> ckpt = torch.load(checkpoint_path, map_location='cpu')
>>>
>>> # Print all keys in the checkpoint dictionary
>>> print("Keys in the checkpoint:", ckpt.keys())
Keys in the checkpoint: dict_keys(['trans'])
>>>

>>> checkpoint_path = 'pretrained/clip_best.pth'
>>> ckpt = torch.load(checkpoint_path, map_location='cpu')
>>> print("Keys in the checkpoint:", ckpt.keys())
Keys in the checkpoint: odict_keys(['positional_embedding', 'text_projection', 'logit_scale', 'visual.class_embedding', 'visual.positional_embedding', 'visual.proj', 'visual.conv1.weight', 'visual.ln_pre.weight', 'visual.ln_pre.bias', 'visual.transformer.resblocks.0.attn.in_proj_weight', 'visual.transformer.resblocks.0.attn.in_proj_bias', 'visual.transformer.resblocks.0.attn.out_proj.weight', 'visual.transformer.resblocks.0.attn.out_proj.bias', 'visual.transformer.resblocks.0.ln_1.weight', 'visual.transformer.resblocks.0.ln_1.bias', 'visual.transformer.resblocks.0.mlp.c_fc.weight', 'visual.transformer.resblocks.0.mlp.c_fc.bias', 'visual.transformer.resblocks.0.mlp.c_proj.weight', 'visual.transformer.resblocks.0.mlp.c_proj.bias', 'visual.transformer.resblocks.0.ln_2.weight', 'visual.transformer.resblocks.0.ln_2.bias', 'visual.transformer.resblocks.1.attn.in_proj_weight', 'visual.transformer.resblocks.1.attn.in_proj_bias', 'visual.transformer.resblocks.1.attn.out_proj.weight', 'visual.transformer.resblocks.1.attn.out_proj.bias', 'visual.transformer.resblocks.1.ln_1.weight', 'visual.transformer.resblocks.1.ln_1.bias', 'visual.transformer.resblocks.1.mlp.c_fc.weight', 'visual.transformer.resblocks.1.mlp.c_fc.bias', 'visual.transformer.resblocks.1.mlp.c_proj.weight', 'visual.transformer.resblocks.1.mlp.c_proj.bias', 'visual.transformer.resblocks.1.ln_2.weight', 'visual.transformer.resblocks.1.ln_2.bias', 'visual.transformer.resblocks.2.attn.in_proj_weight', 'visual.transformer.resblocks.2.attn.in_proj_bias', 'visual.transformer.resblocks.2.attn.out_proj.weight', 'visual.transformer.resblocks.2.attn.out_proj.bias', 'visual.transformer.resblocks.2.ln_1.weight', 'visual.transformer.resblocks.2.ln_1.bias', 'visual.transformer.resblocks.2.mlp.c_fc.weight', 'visual.transformer.resblocks.2.mlp.c_fc.bias', 'visual.transformer.resblocks.2.mlp.c_proj.weight', 'visual.transformer.resblocks.2.mlp.c_proj.bias', 'visual.transformer.resblocks.2.ln_2.weight', 'visual.transformer.resblocks.2.ln_2.bias', 'visual.transformer.resblocks.3.attn.in_proj_weight', 'visual.transformer.resblocks.3.attn.in_proj_bias', 'visual.transformer.resblocks.3.attn.out_proj.weight', 'visual.transformer.resblocks.3.attn.out_proj.bias', 'visual.transformer.resblocks.3.ln_1.weight', 'visual.transformer.resblocks.3.ln_1.bias', 'visual.transformer.resblocks.3.mlp.c_fc.weight', 'visual.transformer.resblocks.3.mlp.c_fc.bias', 'visual.transformer.resblocks.3.mlp.c_proj.weight', 'visual.transformer.resblocks.3.mlp.c_proj.bias', 'visual.transformer.resblocks.3.ln_2.weight', 'visual.transformer.resblocks.3.ln_2.bias', 'visual.transformer.resblocks.4.attn.in_proj_weight', 'visual.transformer.resblocks.4.attn.in_proj_bias', 'visual.transformer.resblocks.4.attn.out_proj.weight', 'visual.transformer.resblocks.4.attn.out_proj.bias', 'visual.transformer.resblocks.4.ln_1.weight', 'visual.transformer.resblocks.4.ln_1.bias', 'visual.transformer.resblocks.4.mlp.c_fc.weight', 'visual.transformer.resblocks.4.mlp.c_fc.bias', 'visual.transformer.resblocks.4.mlp.c_proj.weight', 'visual.transformer.resblocks.4.mlp.c_proj.bias', 'visual.transformer.resblocks.4.ln_2.weight', 'visual.transformer.resblocks.4.ln_2.bias', 'visual.transformer.resblocks.5.attn.in_proj_weight', 'visual.transformer.resblocks.5.attn.in_proj_bias', 'visual.transformer.resblocks.5.attn.out_proj.weight', 'visual.transformer.resblocks.5.attn.out_proj.bias', 'visual.transformer.resblocks.5.ln_1.weight', 'visual.transformer.resblocks.5.ln_1.bias', 
'visual.transformer.resblocks.5.mlp.c_fc.weight', 'visual.transformer.resblocks.5.mlp.c_fc.bias', 'visual.transformer.resblocks.5.mlp.c_proj.weight', 'visual.transformer.resblocks.5.mlp.c_proj.bias', 'visual.transformer.resblocks.5.ln_2.weight', 'visual.transformer.resblocks.5.ln_2.bias', 'visual.transformer.resblocks.6.attn.in_proj_weight', 'visual.transformer.resblocks.6.attn.in_proj_bias', 'visual.transformer.resblocks.6.attn.out_proj.weight', 'visual.transformer.resblocks.6.attn.out_proj.bias', 'visual.transformer.resblocks.6.ln_1.weight', 'visual.transformer.resblocks.6.ln_1.bias', 'visual.transformer.resblocks.6.mlp.c_fc.weight', 'visual.transformer.resblocks.6.mlp.c_fc.bias', 'visual.transformer.resblocks.6.mlp.c_proj.weight', 'visual.transformer.resblocks.6.mlp.c_proj.bias', 'visual.transformer.resblocks.6.ln_2.weight', 'visual.transformer.resblocks.6.ln_2.bias', 'visual.transformer.resblocks.7.attn.in_proj_weight', 'visual.transformer.resblocks.7.attn.in_proj_bias', 'visual.transformer.resblocks.7.attn.out_proj.weight', 'visual.transformer.resblocks.7.attn.out_proj.bias', 'visual.transformer.resblocks.7.ln_1.weight', 'visual.transformer.resblocks.7.ln_1.bias', 'visual.transformer.resblocks.7.mlp.c_fc.weight', 'visual.transformer.resblocks.7.mlp.c_fc.bias', 'visual.transformer.resblocks.7.mlp.c_proj.weight', 'visual.transformer.resblocks.7.mlp.c_proj.bias', 'visual.transformer.resblocks.7.ln_2.weight', 'visual.transformer.resblocks.7.ln_2.bias', 'visual.transformer.resblocks.8.attn.in_proj_weight', 'visual.transformer.resblocks.8.attn.in_proj_bias', 'visual.transformer.resblocks.8.attn.out_proj.weight', 'visual.transformer.resblocks.8.attn.out_proj.bias', 'visual.transformer.resblocks.8.ln_1.weight', 'visual.transformer.resblocks.8.ln_1.bias', 'visual.transformer.resblocks.8.mlp.c_fc.weight', 'visual.transformer.resblocks.8.mlp.c_fc.bias', 'visual.transformer.resblocks.8.mlp.c_proj.weight', 'visual.transformer.resblocks.8.mlp.c_proj.bias', 'visual.transformer.resblocks.8.ln_2.weight', 'visual.transformer.resblocks.8.ln_2.bias', 'visual.transformer.resblocks.9.attn.in_proj_weight', 'visual.transformer.resblocks.9.attn.in_proj_bias', 'visual.transformer.resblocks.9.attn.out_proj.weight', 'visual.transformer.resblocks.9.attn.out_proj.bias', 'visual.transformer.resblocks.9.ln_1.weight', 'visual.transformer.resblocks.9.ln_1.bias', 'visual.transformer.resblocks.9.mlp.c_fc.weight', 'visual.transformer.resblocks.9.mlp.c_fc.bias', 'visual.transformer.resblocks.9.mlp.c_proj.weight', 'visual.transformer.resblocks.9.mlp.c_proj.bias', 'visual.transformer.resblocks.9.ln_2.weight', 'visual.transformer.resblocks.9.ln_2.bias', 'visual.transformer.resblocks.10.attn.in_proj_weight', 'visual.transformer.resblocks.10.attn.in_proj_bias', 'visual.transformer.resblocks.10.attn.out_proj.weight', 'visual.transformer.resblocks.10.attn.out_proj.bias', 'visual.transformer.resblocks.10.ln_1.weight', 'visual.transformer.resblocks.10.ln_1.bias', 'visual.transformer.resblocks.10.mlp.c_fc.weight', 'visual.transformer.resblocks.10.mlp.c_fc.bias', 'visual.transformer.resblocks.10.mlp.c_proj.weight', 'visual.transformer.resblocks.10.mlp.c_proj.bias', 'visual.transformer.resblocks.10.ln_2.weight', 'visual.transformer.resblocks.10.ln_2.bias', 'visual.transformer.resblocks.11.attn.in_proj_weight', 'visual.transformer.resblocks.11.attn.in_proj_bias', 'visual.transformer.resblocks.11.attn.out_proj.weight', 'visual.transformer.resblocks.11.attn.out_proj.bias', 'visual.transformer.resblocks.11.ln_1.weight', 
'visual.transformer.resblocks.11.ln_1.bias', 'visual.transformer.resblocks.11.mlp.c_fc.weight', 'visual.transformer.resblocks.11.mlp.c_fc.bias', 'visual.transformer.resblocks.11.mlp.c_proj.weight', 'visual.transformer.resblocks.11.mlp.c_proj.bias', 'visual.transformer.resblocks.11.ln_2.weight', 'visual.transformer.resblocks.11.ln_2.bias', 'visual.ln_post.weight', 'visual.ln_post.bias', 'transformer.resblocks.0.attn.in_proj_weight', 'transformer.resblocks.0.attn.in_proj_bias', 'transformer.resblocks.0.attn.out_proj.weight', 'transformer.resblocks.0.attn.out_proj.bias', 'transformer.resblocks.0.ln_1.weight', 'transformer.resblocks.0.ln_1.bias', 'transformer.resblocks.0.mlp.c_fc.weight', 'transformer.resblocks.0.mlp.c_fc.bias', 'transformer.resblocks.0.mlp.c_proj.weight', 'transformer.resblocks.0.mlp.c_proj.bias', 'transformer.resblocks.0.ln_2.weight', 'transformer.resblocks.0.ln_2.bias', 'transformer.resblocks.1.attn.in_proj_weight', 'transformer.resblocks.1.attn.in_proj_bias', 'transformer.resblocks.1.attn.out_proj.weight', 'transformer.resblocks.1.attn.out_proj.bias', 'transformer.resblocks.1.ln_1.weight', 'transformer.resblocks.1.ln_1.bias', 'transformer.resblocks.1.mlp.c_fc.weight', 'transformer.resblocks.1.mlp.c_fc.bias', 'transformer.resblocks.1.mlp.c_proj.weight', 'transformer.resblocks.1.mlp.c_proj.bias', 'transformer.resblocks.1.ln_2.weight', 'transformer.resblocks.1.ln_2.bias', 'transformer.resblocks.2.attn.in_proj_weight', 'transformer.resblocks.2.attn.in_proj_bias', 'transformer.resblocks.2.attn.out_proj.weight', 'transformer.resblocks.2.attn.out_proj.bias', 'transformer.resblocks.2.ln_1.weight', 'transformer.resblocks.2.ln_1.bias', 'transformer.resblocks.2.mlp.c_fc.weight', 'transformer.resblocks.2.mlp.c_fc.bias', 'transformer.resblocks.2.mlp.c_proj.weight', 'transformer.resblocks.2.mlp.c_proj.bias', 'transformer.resblocks.2.ln_2.weight', 'transformer.resblocks.2.ln_2.bias', 'transformer.resblocks.3.attn.in_proj_weight', 'transformer.resblocks.3.attn.in_proj_bias', 'transformer.resblocks.3.attn.out_proj.weight', 'transformer.resblocks.3.attn.out_proj.bias', 'transformer.resblocks.3.ln_1.weight', 'transformer.resblocks.3.ln_1.bias', 'transformer.resblocks.3.mlp.c_fc.weight', 'transformer.resblocks.3.mlp.c_fc.bias', 'transformer.resblocks.3.mlp.c_proj.weight', 'transformer.resblocks.3.mlp.c_proj.bias', 'transformer.resblocks.3.ln_2.weight', 'transformer.resblocks.3.ln_2.bias', 'transformer.resblocks.4.attn.in_proj_weight', 'transformer.resblocks.4.attn.in_proj_bias', 'transformer.resblocks.4.attn.out_proj.weight', 'transformer.resblocks.4.attn.out_proj.bias', 'transformer.resblocks.4.ln_1.weight', 'transformer.resblocks.4.ln_1.bias', 'transformer.resblocks.4.mlp.c_fc.weight', 'transformer.resblocks.4.mlp.c_fc.bias', 'transformer.resblocks.4.mlp.c_proj.weight', 'transformer.resblocks.4.mlp.c_proj.bias', 'transformer.resblocks.4.ln_2.weight', 'transformer.resblocks.4.ln_2.bias', 'transformer.resblocks.5.attn.in_proj_weight', 'transformer.resblocks.5.attn.in_proj_bias', 'transformer.resblocks.5.attn.out_proj.weight', 'transformer.resblocks.5.attn.out_proj.bias', 'transformer.resblocks.5.ln_1.weight', 'transformer.resblocks.5.ln_1.bias', 'transformer.resblocks.5.mlp.c_fc.weight', 'transformer.resblocks.5.mlp.c_fc.bias', 'transformer.resblocks.5.mlp.c_proj.weight', 'transformer.resblocks.5.mlp.c_proj.bias', 'transformer.resblocks.5.ln_2.weight', 'transformer.resblocks.5.ln_2.bias', 'transformer.resblocks.6.attn.in_proj_weight', 'transformer.resblocks.6.attn.in_proj_bias', 
'transformer.resblocks.6.attn.out_proj.weight', 'transformer.resblocks.6.attn.out_proj.bias', 'transformer.resblocks.6.ln_1.weight', 'transformer.resblocks.6.ln_1.bias', 'transformer.resblocks.6.mlp.c_fc.weight', 'transformer.resblocks.6.mlp.c_fc.bias', 'transformer.resblocks.6.mlp.c_proj.weight', 'transformer.resblocks.6.mlp.c_proj.bias', 'transformer.resblocks.6.ln_2.weight', 'transformer.resblocks.6.ln_2.bias', 'transformer.resblocks.7.attn.in_proj_weight', 'transformer.resblocks.7.attn.in_proj_bias', 'transformer.resblocks.7.attn.out_proj.weight', 'transformer.resblocks.7.attn.out_proj.bias', 'transformer.resblocks.7.ln_1.weight', 'transformer.resblocks.7.ln_1.bias', 'transformer.resblocks.7.mlp.c_fc.weight', 'transformer.resblocks.7.mlp.c_fc.bias', 'transformer.resblocks.7.mlp.c_proj.weight', 'transformer.resblocks.7.mlp.c_proj.bias', 'transformer.resblocks.7.ln_2.weight', 'transformer.resblocks.7.ln_2.bias', 'transformer.resblocks.8.attn.in_proj_weight', 'transformer.resblocks.8.attn.in_proj_bias', 'transformer.resblocks.8.attn.out_proj.weight', 'transformer.resblocks.8.attn.out_proj.bias', 'transformer.resblocks.8.ln_1.weight', 'transformer.resblocks.8.ln_1.bias', 'transformer.resblocks.8.mlp.c_fc.weight', 'transformer.resblocks.8.mlp.c_fc.bias', 'transformer.resblocks.8.mlp.c_proj.weight', 'transformer.resblocks.8.mlp.c_proj.bias', 'transformer.resblocks.8.ln_2.weight', 'transformer.resblocks.8.ln_2.bias', 'transformer.resblocks.9.attn.in_proj_weight', 'transformer.resblocks.9.attn.in_proj_bias', 'transformer.resblocks.9.attn.out_proj.weight', 'transformer.resblocks.9.attn.out_proj.bias', 'transformer.resblocks.9.ln_1.weight', 'transformer.resblocks.9.ln_1.bias', 'transformer.resblocks.9.mlp.c_fc.weight', 'transformer.resblocks.9.mlp.c_fc.bias', 'transformer.resblocks.9.mlp.c_proj.weight', 'transformer.resblocks.9.mlp.c_proj.bias', 'transformer.resblocks.9.ln_2.weight', 'transformer.resblocks.9.ln_2.bias', 'transformer.resblocks.10.attn.in_proj_weight', 'transformer.resblocks.10.attn.in_proj_bias', 'transformer.resblocks.10.attn.out_proj.weight', 'transformer.resblocks.10.attn.out_proj.bias', 'transformer.resblocks.10.ln_1.weight', 'transformer.resblocks.10.ln_1.bias', 'transformer.resblocks.10.mlp.c_fc.weight', 'transformer.resblocks.10.mlp.c_fc.bias', 'transformer.resblocks.10.mlp.c_proj.weight', 'transformer.resblocks.10.mlp.c_proj.bias', 'transformer.resblocks.10.ln_2.weight', 'transformer.resblocks.10.ln_2.bias', 'transformer.resblocks.11.attn.in_proj_weight', 'transformer.resblocks.11.attn.in_proj_bias', 'transformer.resblocks.11.attn.out_proj.weight', 'transformer.resblocks.11.attn.out_proj.bias', 'transformer.resblocks.11.ln_1.weight', 'transformer.resblocks.11.ln_1.bias', 'transformer.resblocks.11.mlp.c_fc.weight', 'transformer.resblocks.11.mlp.c_fc.bias', 'transformer.resblocks.11.mlp.c_proj.weight', 'transformer.resblocks.11.mlp.c_proj.bias', 'transformer.resblocks.11.ln_2.weight', 'transformer.resblocks.11.ln_2.bias', 'token_embedding.weight', 'ln_final.weight', 'ln_final.bias'])
>>>
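Since clip_best.pth is evidently a flat state dict (parameter names at the top level, with 12 residual blocks, which is consistent with a ViT-B/32-style model), it should load directly into a CLIP model with no wrapper key to unpack. A minimal sketch, where the architecture name is my assumption:

import torch
import clip  # OpenAI CLIP package

# clip_best.pth stores a flat state_dict, so it loads directly,
# with no wrapper key such as 'net' or 'trans' to unpack first.
clip_model, _ = clip.load("ViT-B/32", device="cpu", jit=False)  # architecture is an assumption
sd = torch.load("pretrained/clip_best.pth", map_location="cpu")
clip_model.load_state_dict(sd)
clip_model.eval()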

Chatonz (Contributor) commented May 6, 2024

Thank you very much for pointing out the issues. Regarding the first error with the file name, we have made the corresponding correction.

For the second issue, we may not have made this clear in the usage instructions. The weights loaded here should be those of the VQ-VAE, and for this weight we used the same weights as t2m-gpt, without any adjustments. We have also uploaded the weights to the link.
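A minimal sketch of the intended loading scheme, given the key names established in this thread (the helper function and the example VQ-VAE path are illustrative assumptions, not the repository's actual API):

import torch
from torch import nn

def load_nested_state_dict(model: nn.Module, path: str, key: str) -> nn.Module:
    # Load a checkpoint whose state dict is nested under `key`.
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt[key], strict=True)
    return model

# The VQ-VAE checkpoint (reused from t2m-gpt) nests its weights under 'net',
# while net_best_fid.pth nests the transformer under 'trans' (see above).
# `vqvae` and `trans_encoder` stand in for the models built in eval_t2m.py:
# vqvae = load_nested_state_dict(vqvae, 'pretrained/VQVAE/net_last.pth', 'net')  # path is an assumption
# trans_encoder = load_nested_state_dict(trans_encoder, 'pretrained/net_best_fid.pth', 'trans')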

Xuguowei-hub (Author) commented

Thank you very much for your reply! Will you provide additional tutorials on demo visualization in the future?
