
How to use the full-parameter training checkpoint for sample_t2i.py? #107

a5704607 opened this issue Jun 20, 2024 · 7 comments

@a5704607

PYTHONPATH=./ sh hydit/train.sh --index-file dataset/porcelain/jsons/porcelain.json
I used this command for full-parameter training, and the checkpoint was saved under 005-dit_g2_full_1024p/checkpoints. How can I use that checkpoint with sample_t2i.py?

@MichaelFan01

Same question.

@a5704607 (Author)

> Same question.

I think you can use 005-dit_g2_full_1024p/checkpoints/0010000.pt/mp_rank_00_model_states.pt

@MichaelFan01

> Same question.
>
> I think you can use 005-dit_g2_full_1024p/checkpoints/0010000.pt/mp_rank_00_model_states.pt

I tried it, but it doesn't quite work:
File "/maindata/data/shared/multimodal/michael.fan/aigc-apps/HunyuanDiT-main/hydit/inference.py", line 223, in __init__
    self.model.load_state_dict(state_dict)
File "/maindata/data/shared/multimodal/michael.fan/env/miniconda3/envs/hunyuanDit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for HunYuanDiT:
Missing key(s) in state_dict: "text_embedding_padding", "mlp_t5.0.weight", "mlp_t5.0.bias", "mlp_t5.2.weight", "mlp_t5.2.bias", "pooler.positional_embedding", "pooler.k_proj.weight", "pooler.k_proj.bias", "pooler.q_proj.weight", "pooler.q_proj.bias", "pooler.v_proj.weight", "pooler.v_proj.bias", "pooler.c_proj.weight", "pooler.c_proj.bias", "style_embedder.weight", "x_embedder.proj.weight", "x_embedder.proj.bias", "t_embedder.mlp.0.weight", "t_embedder.mlp.0.bias", "t_embedder.mlp.2.weight", "t_embedder.mlp.2.bias", "extra_embedder.0.weight", "extra_embedder.0.bias", "extra_embedder.2.weight", "extra_embedder.2.bias", "blocks.0.norm1.weight", "blocks.0.norm1.bias", "blocks.0.attn1.Wqkv.weight", "blocks.0.attn1.Wqkv.bias", "blocks.0.attn1.q_norm.weight", "blocks.0.attn1.q_norm.bias", "blocks.0.attn1.k_norm.weight", "blocks.0.attn1.k_norm.bias", "blocks.0.attn1.out_proj.weight", "blocks.0.attn1.out_proj.bias", "blocks.0.norm2.weight", "blocks.0.norm2.bias", "blocks.0.mlp.fc1.weight", "blocks.0.mlp.fc1.bias", "blocks.0.mlp.fc2.weight", "blocks.0.mlp.fc2.bias", "blocks.0.default_modulation.1.weight", "blocks.0.default_modulation.1.bias", "blocks.0.attn2.q_proj.weight", "blocks.0.attn2.q_proj.bias", "blocks.0.attn2.kv_proj.weight", "blocks.0.attn2.kv_proj.bias", "blocks.0.attn2.q_norm.weight", "blocks.0.attn2.q_norm.bias", "blocks.0.attn2.k_norm.weight", "blocks.0.attn2.k_norm.bias", "blocks.0.attn2.out_proj.weight", "blocks.0.attn2.out_proj.bias", "blocks.0.norm3.weight", "blocks.0.norm3.bias", "blocks.1.norm1.weight", "blocks.1.norm1.bias", "blocks.1.attn1.Wqkv.weight", "blocks.1.attn1.Wqkv.bias", "blocks.1.attn1.q_norm.weight", "blocks.1.attn1.q_norm.bias", "blocks.1.attn1.k_norm.weight", "blocks.1.attn1.k_norm.bias", "blocks.1.attn1.out_proj.weight", "blocks.1.attn1.out_proj.bias", "blocks.1.norm2.weight", "blocks.1.norm2.bias", "blocks.1.mlp.fc1.weight", "blocks.1.mlp.fc1.bias", "blocks.1.mlp.fc2.weight", "blocks.1.mlp.fc2.bias", "blocks.1.default_modulation.1.weight", "blocks.1.default_modulation.1.bias", "blocks.1.attn2.q_proj.weight", "blocks.1.attn2.q_proj.bias", "blocks.1.attn2.kv_proj.weight", "blocks.1.attn2.kv_proj.bias", "blocks.1.attn2.q_norm.weight", "blocks.1.attn2.q_norm.bias", "blocks.1.attn2.k_norm.weight", "blocks.1.attn2.k_norm.bias", "blocks.1.attn2.out_proj.weight", "blocks.1.attn2.out_proj.bias", "blocks.1.norm3.weight", "blocks.1.norm3.bias", "blocks.2.norm1.weight", "blocks.2.norm1.bias", "blocks.2.attn1.Wqkv.weight", "blocks.2.attn1.Wqkv.bias", "blocks.2.attn1.q_norm.weight", "blocks.2.attn1.q_norm.bias", "blocks.2.attn1.k_norm.weight", "blocks.2.attn1.k_norm.bias", "blocks.2.attn1.out_proj.weight", "blocks.2.attn1.out_proj.bias"

@a5704607 (Author)

> I tried it, but it doesn't quite work: RuntimeError: Error(s) in loading state_dict for HunYuanDiT: Missing key(s) in state_dict: ... (same traceback as above)

key: 'ema'

@MichaelFan01

> key: 'ema'

Thanks, that fixed it.

@jonathanyin12

Running into the same issue. What do you mean by this?

> key: 'ema'

Is there some way to simply point to the checkpoint path using the command line?

@jonathanyin12

Figured it out. After loading the checkpoint from 'mp_rank_00_model_states.pt', you have to index the loaded dict with the 'ema' key before calling load_state_dict. Code ref
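
For anyone else who lands here, this is a minimal sketch of that workaround, assuming the checkpoint layout from this thread. The path is a placeholder, and `model` stands in for an already-constructed HunYuanDiT instance (built the same way hydit/inference.py builds it for sample_t2i.py):

```python
import torch

# Placeholder path to the DeepSpeed checkpoint produced by full-parameter
# training; adjust the experiment folder and step number to match your run.
ckpt_path = "005-dit_g2_full_1024p/checkpoints/0010000.pt/mp_rank_00_model_states.pt"

# The file holds a dict with several entries, not a bare state dict.
checkpoint = torch.load(ckpt_path, map_location="cpu")
print(checkpoint.keys())  # 'ema' should be among the keys

# Index with 'ema' before handing the weights to the model.
state_dict = checkpoint["ema"]

# `model` stands in for the HunYuanDiT instance built by hydit/inference.py;
# its construction is omitted in this sketch.
model.load_state_dict(state_dict)
```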
