
Problem converting weights #280

Closed
Jayce1kk opened this issue Jul 3, 2024 · 3 comments


Jayce1kk commented Jul 3, 2024

Problem

After converting the weights, the following error occurs when running evaluation/validation:

> number of parameters on (tensor, pipeline) model parallel rank (0, 0): 630167424
 loading release checkpoint from /raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1
[rank0]: Traceback (most recent call last):
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/examples/qwen2/evaluate_mcore_qwen.py", line 207, in <module>
[rank0]:     main()
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/examples/qwen2/evaluate_mcore_qwen.py", line 193, in main
[rank0]:     load_checkpoint(model, None, None)
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/Megatron-LM-240612/megatron/training/checkpointing.py", line 807, in load_checkpoint
[rank0]:     model[0].load_state_dict(state_dict['model'], strict=strict)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
[rank0]:     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[rank0]: RuntimeError: Error(s) in loading state_dict for GPTModel:
[rank0]:        Missing key(s) in state_dict: "decoder.layers.0.self_attention.linear_proj._extra_state", "decoder.layers.0.self_attention.linear_qkv._extra_state", "decoder.layers.0.mlp.linear_fc1._extra_state", "decoder.layers.0.mlp.linear_fc2._extra_state", "decoder.layers.1.self_attention.linear_proj._extra_state", 

....

"decoder.layers.21.mlp.linear_fc2._extra_state", "decoder.layers.22.self_attention.linear_proj._extra_state", "decoder.layers.22.self_attention.linear_qkv._extra_state", "decoder.layers.22.mlp.linear_fc1._extra_state", "decoder.layers.22.mlp.linear_fc2._extra_state", "decoder.layers.23.self_attention.linear_proj._extra_state", "decoder.layers.23.self_attention.linear_qkv._extra_state", "decoder.layers.23.mlp.linear_fc1._extra_state", "decoder.layers.23.mlp.linear_fc2._extra_state". 
E0703 10:00:55.575000 140162604175808 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 1450033) of binary: /usr/bin/python

Weight conversion command

sh hf2mcore_qwen2_convertor.sh \
0.5B \
/raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B \
/raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1  \
1  \
1  \
1 \
fp32 \
true \
false 

Evaluation command

sh run_evaluate_mcore_qwen.sh \
0.5B \
1 \
256 \
256 \
bf16 \
1 \
1 \
sel \
true \
false \
false \
true \
/raid/LLM_train/Pai-Megatron-Patch/qwen-datasets/alpaca_zh-qwen-valid.json \
/raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1

Full error output

torchrun --nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 19751 evaluate_mcore_qwen.py --valid-data-path /raid/LLM_train/Pai-Megatron-Patch/qwen-datasets/alpaca_zh-qwen-valid.json --micro-batch-size 1 --num-layers 24 --hidden-size 896 --num-attention-heads 14 --seq-length 256 --max-position-embeddings 131072 --ffn-hidden-size 4864 --log-interval 1 --eval-interval 100 --eval-iters 10 --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1 --no-load-optim --no-load-rng --seed 1234 --num-workers 0 --max-padding-length 256 --extra-vocab-size 293 --patch-tokenizer-type LLamaTokenizer --dataset LLama-Pretrain-Raw --swiglu --normalization RMSNorm --norm-epsilon 1e-6 --use-rotary-position-embeddings --no-rope-fusion --position-embedding-type rope --rotary-base 1000000 --untie-embeddings-and-output-weights --disable-bias-linear --add-qkv-bias --group-query-attention --num-query-groups 2 --eod-mask-loss --bf16 --load /raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1 --transformer-impl transformer_engine --recompute-activations --use-distributed-optimizer
INFO:datasets:PyTorch version 2.4.0a0+07cecf4168.nv24.5 available.
in oss file
/raid/LLM_train/Pai-Megatron-Patch/megatron_patch/model/llava/clip_encoder.py:26: UserWarning: The cvcuda environment does not exist. Install cvcuda and use it
  warnings.warn("The cvcuda environment does not exist. Install cvcuda and use it")
using world size: 1, data-parallel size: 1, context-parallel size: 1 tensor-model-parallel size: 1, pipeline-model-parallel size: 1 
setting global batch size to 1
WARNING: Setting args.overlap_p2p_comm to False since non-interleaved schedule does not support overlapping p2p communication
accumulate and all-reduce gradients in fp32 for bfloat16 data type.
using torch.bfloat16 for parameters ...
------------------------ arguments ------------------------
  accumulate_allreduce_grads_in_fp32 .............. True
  adam_beta1 ...................................... 0.9
  adam_beta2 ...................................... 0.999
  adam_eps ........................................ 1e-08
  adaptive_seq_len ................................ False
  add_bias_attn_fc ................................ True
  add_bias_linear ................................. False
  add_bias_linear_fc .............................. True
  add_position_embedding .......................... True
  add_qkv_bias .................................... True
  adlr_autoresume ................................. False
  adlr_autoresume_interval ........................ 1000
  apply_layernorm_1p .............................. False
  apply_query_key_layer_scaling ................... False
  apply_residual_connection_post_layernorm ........ False
  apply_rope_fusion ............................... False
  async_save ...................................... None
  async_tensor_model_parallel_allreduce ........... True
  attention_dropout ............................... 0.1
  attention_head_type ............................. None
  attention_softmax_in_fp32 ....................... False
  auto_detect_ckpt_format ......................... False
  barrier_with_L1_time ............................ True
  bert_binary_head ................................ True
  bert_embedder_type .............................. megatron
  bert_load ....................................... None
  bf16 ............................................ True
  bias_dropout_fusion ............................. True
  bias_gelu_fusion ................................ False
  bias_swiglu_fusion .............................. True
  biencoder_projection_dim ........................ 0
  biencoder_shared_query_context_model ............ False
  block_data_path ................................. None
  calculate_per_token_loss ........................ False
  check_for_nan_in_loss_and_grad .................. True
  check_weight_hash_across_dp_replicas_interval ... None
  ckpt_assume_constant_structure .................. False
  ckpt_fully_parallel_load ........................ False
  ckpt_fully_parallel_save ........................ False
  ckpt_step ....................................... None
  classes_fraction ................................ 1.0
  clip_grad ....................................... 1.0
  clone_scatter_output_in_embedding ............... True
  consumed_train_samples .......................... 0
  consumed_valid_samples .......................... 0
  context_parallel_size ........................... 1
  convert_checkpoint_from_megatron_to_transformers  False
  create_attention_mask_in_dataloader ............. True
  cvcuda_image_processing ......................... False
  data_cache_path ................................. None
  data_dir ........................................ None
  data_parallel_random_init ....................... False
  data_parallel_size .............................. 1
  data_path ....................................... None
  data_per_class_fraction ......................... 1.0
  data_sharding ................................... True
  dataloader_type ................................. single
  dataset ......................................... LLama-Pretrain-Raw
  ddp_average_in_collective ....................... False
  ddp_bucket_size ................................. None
  decoder_num_layers .............................. None
  decoder_seq_length .............................. None
  decoupled_lr .................................... None
  decoupled_min_lr ................................ None
  delay_grad_reduce ............................... True
  delay_param_gather .............................. False
  deprecated_use_mcore_models ..................... False
  deterministic_mode .............................. False
  dino_bottleneck_size ............................ 256
  dino_freeze_last_layer .......................... 1
  dino_head_hidden_size ........................... 2048
  dino_local_crops_number ......................... 10
  dino_local_img_size ............................. 96
  dino_norm_last_layer ............................ False
  dino_teacher_temp ............................... 0.07
  dino_warmup_teacher_temp ........................ 0.04
  dino_warmup_teacher_temp_epochs ................. 30
  disable_straggler_on_startup .................... False
  dist_ckpt_format ................................ torch_dist
  distribute_saved_activations .................... False
  distributed_backend ............................. nccl
  distributed_timeout_minutes ..................... 10
  embed_layernorm ................................. False
  embedding_path .................................. None
  empty_unused_memory_level ....................... 0
  enable_one_logger ............................... False
  enable_parallel_output .......................... True
  enable_shared_expert ............................ False
  encoder_num_layers .............................. 24
  encoder_seq_length .............................. 256
  end_weight_decay ................................ 0.01
  eod_mask_loss ................................... True
  epochs .......................................... None
  eval_dev ........................................ False
  eval_fp32 ....................................... False
  eval_interval ................................... 100
  eval_iters ...................................... 10
  evidence_data_path .............................. None
  exit_duration_in_mins ........................... None
  exit_interval ................................... None
  exit_on_missing_checkpoint ...................... False
  exit_signal_handler ............................. False
  expert_interval ................................. 2
  expert_model_parallel_size ...................... 1
  expert_tensor_parallelism ....................... False
  extra_vocab_size ................................ 293
  ffn_hidden_size ................................. 4864
  finetune ........................................ False
  fp16 ............................................ False
  fp16_lm_cross_entropy ........................... False
  fp32_residual_connection ........................ False
  fp8 ............................................. None
  fp8_amax_compute_algo ........................... most_recent
  fp8_amax_history_len ............................ 1
  fp8_interval .................................... 1
  fp8_margin ...................................... 0
  fp8_wgrad ....................................... True
  freeze_clip_vision_tower ........................ False
  freeze_llm ...................................... False
  generation_length ............................... None
  global_batch_size ............................... 1
  glu_activation .................................. None
  gradient_accumulation_fusion .................... True
  group_query_attention ........................... True
  head_lr_mult .................................... 1.0
  hidden_dropout .................................. 0.1
  hidden_size ..................................... 896
  hysteresis ...................................... 2
  ict_head_size ................................... None
  ict_load ........................................ None
  image_aspect_ratio .............................. square
  image_folder .................................... 
  image_size ...................................... None
  img_h ........................................... 224
  img_w ........................................... 224
  indexer_batch_size .............................. 128
  indexer_log_interval ............................ 1000
  inference_batch_times_seqlen_threshold .......... 512
  init_method_std ................................. 0.02
  init_method_xavier_uniform ...................... False
  initial_loss_scale .............................. 4294967296
  input_len ....................................... 1
  intermediate_size ............................... None
  iter_per_epoch .................................. 1250
  keep_last ....................................... False
  kv_channels ..................................... 64
  kv_lora_rank .................................... None
  lazy_mpu_init ................................... None
  load ............................................ /raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1
  local_rank ...................................... None
  log_batch_size_to_tensorboard ................... False
  log_interval .................................... 1
  log_learning_rate_to_tensorboard ................ True
  log_loss_scale_to_tensorboard ................... True
  log_memory_to_tensorboard ....................... False
  log_num_zeros_in_grad ........................... False
  log_params_norm ................................. False
  log_progress .................................... False
  log_straggler ................................... False
  log_throughput .................................. False
  log_timers_to_tensorboard ....................... False
  log_validation_ppl_to_tensorboard ............... False
  log_world_size_to_tensorboard ................... False
  logging_level ................................... None
  loss_scale ...................................... None
  loss_scale_window ............................... 1000
  lr .............................................. None
  lr_decay_iters .................................. None
  lr_decay_samples ................................ None
  lr_decay_style .................................. linear
  lr_warmup_fraction .............................. None
  lr_warmup_init .................................. 0.0
  lr_warmup_iters ................................. 0
  lr_warmup_samples ............................... 0
  make_vocab_size_divisible_by .................... 128
  manual_gc ....................................... False
  manual_gc_eval .................................. True
  manual_gc_interval .............................. 0
  mask_factor ..................................... 1.0
  mask_prob ....................................... 0.15
  mask_type ....................................... random
  masked_softmax_fusion ........................... True
  max_padding_length .............................. 256
  max_position_embeddings ......................... 131072
  max_tokens_to_oom ............................... 12000
  merge_file ...................................... None
  micro_batch_size ................................ 1
  min_loss_scale .................................. 1.0
  min_lr .......................................... 0.0
  mm_projector_type ............................... None
  mm_use_im_patch_token ........................... False
  mm_use_im_start_end ............................. False
  mm_vision_select_layer .......................... None
  mmap_bin_files .................................. True
  mock_data ....................................... False
  moe ............................................. False
  moe_aux_loss_coeff .............................. 0.0
  moe_eval_capacity_factor ........................ 1.0
  moe_expert_capacity_factor ...................... None
  moe_expert_parallel_size ........................ None
  moe_extended_tp ................................. False
  moe_ffn_hidden_size ............................. None
  moe_grouped_gemm ................................ False
  moe_input_feature_slicing ....................... False
  moe_input_jitter_eps ............................ None
  moe_layer_freq .................................. 1
  moe_layer_recompute ............................. False
  moe_loss_coeff .................................. 0.01
  moe_min_capacity ................................ 4
  moe_pad_expert_input_to_capacity ................ False
  moe_per_layer_logging ........................... False
  moe_router_load_balancing_type .................. aux_loss
  moe_router_topk ................................. 2
  moe_token_dispatcher_type ....................... allgather
  moe_token_drop_policy ........................... probs
  moe_topk ........................................ 1
  moe_train_capacity_factor ....................... 1.0
  moe_z_loss_coeff ................................ None
  n_head_kv ....................................... None
  nccl_communicator_config_path ................... None
  no_load_optim ................................... True
  no_load_rng ..................................... True
  no_persist_layer_norm ........................... False
  no_save_optim ................................... None
  no_save_rng ..................................... None
  norm_epsilon .................................... 1e-06
  normalization ................................... RMSNorm
  num_attention_heads ............................. 14
  num_channels .................................... 3
  num_classes ..................................... 1000
  num_dataset_builder_threads ..................... 1
  num_experts ..................................... None
  num_fewshot ..................................... None
  num_layers ...................................... 24
  num_layers_per_virtual_pipeline_stage ........... None
  num_query_groups ................................ 2
  num_shared_experts .............................. None
  num_workers ..................................... 0
  one_logger_entity ............................... hwinf_dcm
  one_logger_project .............................. e2e-tracking
  one_logger_run_name ............................. None
  onnx_safe ....................................... None
  openai_gelu ..................................... False
  optimizer ....................................... adam
  out_seq_length .................................. 1024
  output_bert_embeddings .......................... False
  overlap_grad_reduce ............................. False
  overlap_p2p_comm ................................ False
  overlap_param_gather ............................ False
  override_opt_param_scheduler .................... False
  params_dtype .................................... torch.bfloat16
  patch_dim ....................................... 16
  patch_size ...................................... None
  patch_tokenizer_type ............................ LLamaTokenizer
  perform_initialization .......................... True
  pipeline_model_parallel_size .................... 1
  pipeline_model_parallel_split_rank .............. None
  position_embedding_type ......................... rope
  position_encoding_2d ............................ False
  pretrained_checkpoint ........................... None
  profile ......................................... False
  profile_ranks ................................... [0]
  profile_step_end ................................ 12
  profile_step_start .............................. 10
  q_lora_rank ..................................... None
  qk_layernorm .................................... False
  qk_nope_head_dim ................................ None
  qk_rope_head_dim ................................ None
  query_in_block_prob ............................. 0.1
  rampup_batch_size ............................... None
  rank ............................................ 0
  recompute_granularity ........................... selective
  recompute_method ................................ None
  recompute_num_layers ............................ None
  repetition_penalty .............................. 1.1
  reset_attention_mask ............................ False
  reset_position_ids .............................. False
  retriever_report_topk_accuracies ................ []
  retriever_score_scaling ......................... False
  retriever_seq_length ............................ 256
  retro_add_retriever ............................. False
  retro_attention_gate ............................ 1
  retro_cyclic_train_iters ........................ None
  retro_encoder_attention_dropout ................. 0.1
  retro_encoder_hidden_dropout .................... 0.1
  retro_encoder_layers ............................ 2
  retro_num_neighbors ............................. 2
  retro_num_retrieved_chunks ...................... 2
  retro_project_dir ............................... None
  retro_verify_neighbor_count ..................... True
  rotary_base ..................................... 1000000
  rotary_interleaved .............................. False
  rotary_percent .................................. 1.0
  rotary_scale_factor ............................. 1
  rotary_scaling_factor ........................... 1
  rotary_seq_len_interpolation_factor ............. None
  router_type ..................................... topk
  sample_rate ..................................... 1.0
  save ............................................ None
  save_interval ................................... None
  scatter_gather_tensors_in_pipeline .............. True
  seed ............................................ 1234
  seq_length ...................................... 256
  sequence_parallel ............................... False
  sgd_momentum .................................... 0.9
  shared_moe_ffn_hidden_size ...................... None
  short_seq_prob .................................. 0.1
  skip_train ...................................... False
  sliding_window .................................. None
  source_seq_len .................................. None
  spec ............................................ None
  split ........................................... None
  squared_relu .................................... False
  standalone_embedding_stage ...................... False
  start_weight_decay .............................. 0.01
  straggler_ctrlr_port ............................ 65535
  straggler_minmax_count .......................... 1
  swiglu .......................................... True
  swin_backbone_type .............................. tiny
  target_seq_len .................................. None
  task_list ....................................... all
  temperature ..................................... 1.0
  tensor_model_parallel_size ...................... 1
  tensorboard_dir ................................. None
  tensorboard_log_interval ........................ 1
  tensorboard_queue_size .......................... 1000
  test_data_path .................................. None
  test_mode ....................................... False
  text_generate_gt_file ........................... 
  text_generate_input_file ........................ 
  text_generate_output_file ....................... 
  time ............................................ False
  timing_log_level ................................ 0
  timing_log_option ............................... minmax
  titles_data_path ................................ None
  tokenizer_model ................................. None
  tokenizer_type .................................. NullTokenizer
  top_k ........................................... 0
  top_p ........................................... 0.0
  tp_comm_bulk_dgrad .............................. True
  tp_comm_bulk_wgrad .............................. True
  tp_comm_overlap ................................. False
  tp_comm_overlap_ag .............................. True
  tp_comm_overlap_cfg ............................. None
  tp_comm_overlap_rs .............................. True
  tp_comm_overlap_rs_dgrad ........................ False
  tp_comm_split_ag ................................ True
  tp_comm_split_rs ................................ True
  train_data ...................................... None
  train_data_path ................................. None
  train_iters ..................................... None
  train_samples ................................... None
  transformer_impl ................................ transformer_engine
  transformer_pipeline_model_parallel_size ........ 1
  transformer_timers .............................. False
  transformer_type ................................ megatron
  tune_mm_mlp_adapter ............................. False
  untie_embeddings_and_output_weights ............. True
  use_alibi_mask .................................. False
  use_checkpoint_args ............................. False
  use_checkpoint_opt_param_scheduler .............. False
  use_cpu_initialization .......................... None
  use_dist_ckpt ................................... False
  use_distributed_optimizer ....................... True
  use_flash_attn .................................. False
  use_legacy_models ............................... False
  use_llama2_rotary_position_embeddings ........... False
  use_mistral_rotary_position_embeddings .......... False
  use_normhead .................................... False
  use_one_sent_docs ............................... False
  use_ring_exchange_p2p ........................... False
  use_rotary_position_embeddings .................. True
  use_tp_pp_dp_mapping ............................ False
  use_tutel ....................................... False
  v_head_dim ...................................... None
  valid_data ...................................... None
  valid_data_path ................................. ['/raid/LLM_train/Pai-Megatron-Patch/qwen-datasets/alpaca_zh-qwen-valid.json']
  variable_seq_lengths ............................ False
  verbosity ....................................... INFO
  version ......................................... plain
  virtual_pipeline_model_parallel_size ............ None
  vision_backbone_type ............................ vit
  vision_pretraining .............................. False
  vision_pretraining_type ......................... classify
  vision_tower .................................... 
  vocab_extra_ids ................................. 0
  vocab_file ...................................... None
  vocab_size ...................................... -1
  wandb_exp_name .................................. 
  wandb_project ................................... 
  wandb_save_dir .................................. 
  weight_decay .................................... 0.01
  weight_decay_incr_style ......................... constant
  world_size ...................................... 1
  yaml_cfg ........................................ None
  z_loss_weight ................................... 0.0
-------------------- end of arguments ---------------------
setting number of micro-batches to constant 1
> building NullTokenizer tokenizer ...
 > padded vocab (size: 0) with 0 dummy tokens (new size: 0)
> initializing torch distributed ...
> initialized tensor model parallel with size 1
> initialized pipeline model parallel with size 1
> setting random seeds to 1234 ...
> compiling dataset index builder ...
make: Entering directory '/raid/LLM_train/Pai-Megatron-Patch/Megatron-LM-240612/megatron/core/datasets'
make: Nothing to be done for 'default'.
make: Leaving directory '/raid/LLM_train/Pai-Megatron-Patch/Megatron-LM-240612/megatron/core/datasets'
>>> done with dataset index builder. Compilation time: 0.500 seconds
WARNING: constraints for invoking optimized fused softmax kernel are not met. We default back to unfused kernel invocations.
> compiling and loading fused kernels ...
>>> done with compiling and loading fused kernels. Compilation time: 0.695 seconds
> building LLamaTokenizer tokenizer ...
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Running Encoding (num_proc=16): 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:08<00:00, 14.60 examples/s]
1000it [00:00, 130834.86it/s]
  >> total number of samples: 997
> building LLamaTokenizer tokenizer ...
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
building Qwen2 model ...
 > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 630167424
 loading release checkpoint from /raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1
[rank0]: Traceback (most recent call last):
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/examples/qwen2/evaluate_mcore_qwen.py", line 207, in <module>
[rank0]:     main()
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/examples/qwen2/evaluate_mcore_qwen.py", line 193, in main
[rank0]:     load_checkpoint(model, None, None)
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/Megatron-LM-240612/megatron/training/checkpointing.py", line 807, in load_checkpoint
[rank0]:     model[0].load_state_dict(state_dict['model'], strict=strict)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
[rank0]:     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[rank0]: RuntimeError: Error(s) in loading state_dict for GPTModel:
[rank0]:        Missing key(s) in state_dict: "decoder.layers.0.self_attention.linear_proj._extra_state", "decoder.layers.0.self_attention.linear_qkv._extra_state", "decoder.layers.0.mlp.linear_fc1._extra_state", "decoder.layers.0.mlp.linear_fc2._extra_state", "decoder.layers.1.self_attention.linear_proj._extra_state", "decoder.layers.1.self_attention.linear_qkv._extra_state", "decoder.layers.1.mlp.linear_fc1._extra_state", "decoder.layers.1.mlp.linear_fc2._extra_state", "decoder.layers.2.self_attention.linear_proj._extra_state", "decoder.layers.2.self_attention.linear_qkv._extra_state", "decoder.layers.2.mlp.linear_fc1._extra_state", "decoder.layers.2.mlp.linear_fc2._extra_state", "decoder.layers.3.self_attention.linear_proj._extra_state", "decoder.layers.3.self_attention.linear_qkv._extra_state", "decoder.layers.3.mlp.linear_fc1._extra_state", "decoder.layers.3.mlp.linear_fc2._extra_state", "decoder.layers.4.self_attention.linear_proj._extra_state", "decoder.layers.4.self_attention.linear_qkv._extra_state", "decoder.layers.4.mlp.linear_fc1._extra_state", "decoder.layers.4.mlp.linear_fc2._extra_state", "decoder.layers.5.self_attention.linear_proj._extra_state", "decoder.layers.5.self_attention.linear_qkv._extra_state", "decoder.layers.5.mlp.linear_fc1._extra_state", "decoder.layers.5.mlp.linear_fc2._extra_state", "decoder.layers.6.self_attention.linear_proj._extra_state", "decoder.layers.6.self_attention.linear_qkv._extra_state", "decoder.layers.6.mlp.linear_fc1._extra_state", "decoder.layers.6.mlp.linear_fc2._extra_state", "decoder.layers.7.self_attention.linear_proj._extra_state", "decoder.layers.7.self_attention.linear_qkv._extra_state", "decoder.layers.7.mlp.linear_fc1._extra_state", "decoder.layers.7.mlp.linear_fc2._extra_state", "decoder.layers.8.self_attention.linear_proj._extra_state", "decoder.layers.8.self_attention.linear_qkv._extra_state", "decoder.layers.8.mlp.linear_fc1._extra_state", "decoder.layers.8.mlp.linear_fc2._extra_state", "decoder.layers.9.self_attention.linear_proj._extra_state", "decoder.layers.9.self_attention.linear_qkv._extra_state", "decoder.layers.9.mlp.linear_fc1._extra_state", "decoder.layers.9.mlp.linear_fc2._extra_state", "decoder.layers.10.self_attention.linear_proj._extra_state", "decoder.layers.10.self_attention.linear_qkv._extra_state", "decoder.layers.10.mlp.linear_fc1._extra_state", "decoder.layers.10.mlp.linear_fc2._extra_state", "decoder.layers.11.self_attention.linear_proj._extra_state", "decoder.layers.11.self_attention.linear_qkv._extra_state", "decoder.layers.11.mlp.linear_fc1._extra_state", "decoder.layers.11.mlp.linear_fc2._extra_state", "decoder.layers.12.self_attention.linear_proj._extra_state", "decoder.layers.12.self_attention.linear_qkv._extra_state", "decoder.layers.12.mlp.linear_fc1._extra_state", "decoder.layers.12.mlp.linear_fc2._extra_state", "decoder.layers.13.self_attention.linear_proj._extra_state", "decoder.layers.13.self_attention.linear_qkv._extra_state", "decoder.layers.13.mlp.linear_fc1._extra_state", "decoder.layers.13.mlp.linear_fc2._extra_state", "decoder.layers.14.self_attention.linear_proj._extra_state", "decoder.layers.14.self_attention.linear_qkv._extra_state", "decoder.layers.14.mlp.linear_fc1._extra_state", "decoder.layers.14.mlp.linear_fc2._extra_state", "decoder.layers.15.self_attention.linear_proj._extra_state", "decoder.layers.15.self_attention.linear_qkv._extra_state", "decoder.layers.15.mlp.linear_fc1._extra_state", "decoder.layers.15.mlp.linear_fc2._extra_state", 
"decoder.layers.16.self_attention.linear_proj._extra_state", "decoder.layers.16.self_attention.linear_qkv._extra_state", "decoder.layers.16.mlp.linear_fc1._extra_state", "decoder.layers.16.mlp.linear_fc2._extra_state", "decoder.layers.17.self_attention.linear_proj._extra_state", "decoder.layers.17.self_attention.linear_qkv._extra_state", "decoder.layers.17.mlp.linear_fc1._extra_state", "decoder.layers.17.mlp.linear_fc2._extra_state", "decoder.layers.18.self_attention.linear_proj._extra_state", "decoder.layers.18.self_attention.linear_qkv._extra_state", "decoder.layers.18.mlp.linear_fc1._extra_state", "decoder.layers.18.mlp.linear_fc2._extra_state", "decoder.layers.19.self_attention.linear_proj._extra_state", "decoder.layers.19.self_attention.linear_qkv._extra_state", "decoder.layers.19.mlp.linear_fc1._extra_state", "decoder.layers.19.mlp.linear_fc2._extra_state", "decoder.layers.20.self_attention.linear_proj._extra_state", "decoder.layers.20.self_attention.linear_qkv._extra_state", "decoder.layers.20.mlp.linear_fc1._extra_state", "decoder.layers.20.mlp.linear_fc2._extra_state", "decoder.layers.21.self_attention.linear_proj._extra_state", "decoder.layers.21.self_attention.linear_qkv._extra_state", "decoder.layers.21.mlp.linear_fc1._extra_state", "decoder.layers.21.mlp.linear_fc2._extra_state", "decoder.layers.22.self_attention.linear_proj._extra_state", "decoder.layers.22.self_attention.linear_qkv._extra_state", "decoder.layers.22.mlp.linear_fc1._extra_state", "decoder.layers.22.mlp.linear_fc2._extra_state", "decoder.layers.23.self_attention.linear_proj._extra_state", "decoder.layers.23.self_attention.linear_qkv._extra_state", "decoder.layers.23.mlp.linear_fc1._extra_state", "decoder.layers.23.mlp.linear_fc2._extra_state". 
E0703 10:00:55.575000 140162604175808 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 1450033) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.4.0a0+07cecf4168.nv24.5', 'console_scripts', 'torchrun')())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 879, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
evaluate_mcore_qwen.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-07-03_10:00:55
  host      : 6a459af124ab
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 1450033)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
Jayce1kk (Author) commented Jul 3, 2024

In the weight conversion code, I changed line 482 of /Pai-Megatron-Patch/toolkits/model_checkpoints_convertor/qwen/hf2mcore_qwen2_dense_and_moe_gqa.py from

if full_model[k] is None or "_extra_state" in k:
    full_model.pop(k)

to the code below. The error no longer occurs, but I don't know whether this affects subsequent training:

if full_model[k] is None:
    full_model.pop(k)
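
As a side note, one way to confirm what the modified conversion actually writes is to inspect the saved checkpoint's keys. This is a minimal sketch; the release/mp_rank_00/model_optim_rng.pt layout and the checkpoint path are assumptions based on the commands and log above, not taken from the repository:

import torch

# Load the converted Megatron-Core checkpoint on CPU and count how many
# TransformerEngine "_extra_state" entries survived the conversion.
ckpt = torch.load(
    "/raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/"
    "Qwen2-0.5B-hf-mcore-te-tp1-pp1/release/mp_rank_00/model_optim_rng.pt",
    map_location="cpu",
)
extra = [k for k in ckpt["model"] if "_extra_state" in k]
print(f"{len(extra)} _extra_state entries kept out of {len(ckpt['model'])} keys")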

jerryli1981 (Collaborator) commented

Hi, when only the _extra_state keys show up in the error it is not actually a problem; you just need to set strict=False.
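
For reference, that change would look roughly like this at the call site shown in the traceback. This is a sketch; it assumes this Megatron-LM snapshot's load_checkpoint accepts a strict keyword and forwards it to load_state_dict, as the traceback's strict=strict suggests:

# examples/qwen2/evaluate_mcore_qwen.py, around line 193 in the traceback above
load_checkpoint(model, None, None, strict=False)

With strict=False, the missing TransformerEngine _extra_state buffers are reported instead of raising, and the converted weights load normally.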

Jayce1kk (Author) commented

OK, thank you very much for your reply.
