DeepSeek V3 model training error #387

Open
Nemo-HelloWorld opened this issue Feb 28, 2025 · 4 comments

Comments

@Nemo-HelloWorld

I used the yaml file provided in the 0.6.5 release examples for configuration; the error message is as follows:

[default0]:Traceback (most recent call last):
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/flagscale/train/train_deepseek_v3.py", line 433, in
[default0]: pretrain(
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/flagscale/train/train.py", line 423, in pretrain
[default0]: iteration, num_floating_point_operations_so_far = train(
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/flagscale/train/train.py", line 1659, in train
[default0]: train_step(forward_step_func,
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/flagscale/train/train.py", line 863, in train_step
[default0]: losses_reduced = forward_backward_func(
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/megatron/megatron/core/pipeline_parallel/schedules.py", line 1742, in forward_backward_pipelining_without_interleaving
[default0]: output_tensor, num_tokens = forward_step(
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/megatron/megatron/core/pipeline_parallel/schedules.py", line 275, in forward_step
[default0]: output_tensor, loss_func = forward_step_func(data_iterator, model)
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/flagscale/train/train_deepseek_v3.py", line 322, in forward_step
[default0]: output_tensor = model(tokens, position_ids, attention_mask,
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[default0]: return self._call_impl(*args, **kwargs)
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[default0]: return forward_call(*args, **kwargs)
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/megatron/megatron/core/distributed/data_parallel_base.py", line 22, in forward
[default0]: return self.module(*inputs, **kwargs)
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[default0]: return self._call_impl(*args, **kwargs)
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[default0]: return forward_call(*args, **kwargs)
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/megatron/megatron/legacy/model/module.py", line 189, in forward
[default0]: outputs = self.module(*inputs, **kwargs)
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[default0]: return self._call_impl(*args, **kwargs)
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[default0]: return forward_call(*args, **kwargs)
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/flagscale/train/models/deepseek_v3/deepseek_v3_model.py", line 220, in forward
[default0]: hidden_states = self.decoder(
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[default0]: return self._call_impl(*args, **kwargs)
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[default0]: return forward_call(*args, **kwargs)
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/megatron/megatron/core/transformer/transformer_block.py", line 619, in forward
[default0]: hidden_states, context = layer(
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/megatron/megatron/core/transformer/transformer_layer.py", line 503, in __call__
[default0]: return super(MegatronModule, self).__call__(*args, **kwargs)
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[default0]: return self._call_impl(*args, **kwargs)
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[default0]: return forward_call(*args, **kwargs)
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/megatron/megatron/core/transformer/transformer_layer.py", line 391, in forward
[default0]: attention_output_with_bias = self.self_attention(
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[default0]: return self._call_impl(*args, **kwargs)
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[default0]: return forward_call(*args, **kwargs)
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/megatron/megatron/core/transformer/multi_latent_attention.py", line 165, in forward
[default0]: core_attn_out = self.core_attention(
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
[default0]: return self._call_impl(*args, **kwargs)
[default0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[default0]: return forward_call(*args, **kwargs)
[default0]: File "/data2/nfs/liyucong/FlagScale-0.6.5/megatron/megatron/core/extensions/transformer_engine.py", line 804, in forward
[default0]: core_attn_out = super().forward(
[default0]: File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 2488, in forward
[default0]: assert (key_layer.shape == value_layer.shape
[default0]:AssertionError: Keys and values must have the same shape!
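
For anyone hitting the same failure: the assertion is raised inside Transformer Engine's DotProductAttention when the key and value tensors it receives have different shapes. A small debug sketch (not part of FlagScale, and assuming q/k/v are passed positionally to core_attention as in Megatron's MLA forward) to print the shapes that actually reach it:

# Debug sketch: attach a forward pre-hook to every core_attention module so the
# query/key/value shapes reaching Transformer Engine are printed just before the
# failing assert. Assumption: q/k/v are the first three positional arguments.
def log_attn_shapes(module, args):
    q, k, v = args[:3]
    print(f"core_attention shapes: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}")

# `model` stands for the built DeepSeek-V3 module (hypothetical variable name).
for name, submodule in model.named_modules():
    if name.endswith("core_attention"):
        submodule.register_forward_pre_hook(log_attn_shapes)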

@lxd-cumt
Collaborator

Could you please provide the yaml config?

@Nemo-HelloWorld
Author

Nemo-HelloWorld commented Feb 28, 2025

system:
  no_shared_fs: ${experiment.runner.no_shared_fs}
  num_workers: 16
  tensor_model_parallel_size: 2
  pipeline_model_parallel_size: 2
  expert_model_parallel_size: 2
  context_parallel_size: 1
  disable_bias_linear: true
  reset_position_ids: True
  reset_attention_mask: True
  add_qkv_bias: true
  qk_layernorm: true
  sequence_parallel: true
  use_distributed_optimizer: true
  overlap_grad_reduce: true
  overlap_param_gather: true
  precision:
    bf16: true
    initial_loss_scale: 522893
    min_loss_scale: 1.0
    attention_softmax_in_fp32: true
    accumulate_allreduce_grads_in_fp32: true
  logging:
    log_interval: 1
    tensorboard_log_interval: 1
    wandb_project: ${experiment.exp_name}
    wandb_exp_name: ${experiment.exp_name}
    log_timers_to_tensorboard: true
    log_validation_ppl_to_tensorboard: true
    log_throughput: true
    log_params_norm: true
    log_num_zeros_in_grad: true
    log_memory_to_tensorboard: true
  checkpoint:
    save_interval: ${experiment.save_steps}
    load: ${experiment.load}
    ckpt_format: ${experiment.ckpt_format}

model:
  transformer_impl: transformer_engine
  num_layers: 24
  hidden_size: 2048
  ffn_hidden_size: 1408
  num_attention_heads: 16
  group_query_attention: true
  num_query_groups: 16 # num_key_value_heads
  seq_length: 4096
  max_position_embeddings: 4096
  norm_epsilon: 1e-6
  use_rotary_position_embeddings: true
  rotary_base: 10000
  norm_init_weight: 0.5
  swiglu: true
  normalization: RMSNorm
  init_method_std: 0.02
  attention_dropout: 0.0
  hidden_dropout: 0.0
  position_embedding_type: rope
  untie_embeddings_and_output_weights: true
  no_position_embedding: true
  no_rope_fusion: true

  # mla args ==================
  multi_latent_attention: true
  q_lora_rank: 1536
  kv_lora_rank: 512
  qk_head_dim: 128
  qk_pos_emb_head_dim: 64
  v_head_dim: 128

  # moe args ===================
  moe_shared_expert_intermediate_size: 1408
  num_experts: 64
  moe_router_load_balancing_type: "aux_loss"
  moe_router_score_function: sigmoid
  moe_router_enable_expert_bias: true
  moe_router_bias_update_rate: 0.001
  moe_aux_loss_coeff: 0.02
  moe_layer_freq: "[0]+[1]*23"
  # node limited routing
  moe_router_num_groups: 1
  moe_router_group_topk: 1
  moe_router_topk: 6

  # training
  seed: ${experiment.seed}
  train_iters: 10000
  micro_batch_size: 1
  global_batch_size: 32
  eval_interval: 1000
  eval_iters: 5

  # optimizer
  no_load_optim: True
  no_load_rng: True
  optimizer: adam
  lr: 0.0005
  min_lr: 0.00005
  weight_decay: 0.1
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-6
  clip_grad: 1.0
  lr_warmup_fraction: 0.02
  lr_decay_iters: 120000
  lr_decay_style: cosine

data:
  data_path: /data2/nfs/llama-dataset/merged-1t/merged-1t
  split: 1
  tokenizer:
    tokenizer_type: Llama2Tokenizer
    tokenizer_model: /data2/nfs/llama-dataset/tokenizer.model 
    vocab_size: 32000
  # no_create_attention_mask_in_dataloader: true
  # data_path: /data2/nfs/llama-dataset/merged-1t/merged-1t
  # split: 1
  # no_mmap_bin_files: true
  # tokenizer:
  #   tokenizer_type: QwenTokenizerFS
  #   tokenizer_path: /data2/nfs/liyucong/FlagScale-0.6.5/examples/aquila/qwentokenizer
  #   vocab_size: 151851
  #   make_vocab_size_divisible_by: 64
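
A note on the mla args above (my reading, not confirmed by the maintainers): with multi_latent_attention the per-head key dimension is qk_head_dim + qk_pos_emb_head_dim, while the per-head value dimension is v_head_dim, so keys and values end up with different shapes. Transformer Engine releases that still assert key_layer.shape == value_layer.shape will then fail exactly as in the traceback. The arithmetic for this config:

# Sketch of the head-dim arithmetic implied by the mla args above.
qk_head_dim = 128          # non-rope part of each query/key head
qk_pos_emb_head_dim = 64   # rope part of each query/key head
v_head_dim = 128           # value head dim

key_head_dim = qk_head_dim + qk_pos_emb_head_dim
print(key_head_dim, v_head_dim)   # 192 vs 128 -> key and value shapes differ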

@Nemo-HelloWorld
Author

# Demo for testing
# TODO: Release 671B version
defaults:
  - _self_
  - train: train_deepseek_v3.yaml

experiment:
  exp_name: DeepSeek-V3-Test
  seed: 42
  save_steps: 10000
  load: None
  exp_dir: ./outputs_deepseek_v3
  ckpt_format: torch
  task:
    type: train
    backend: megatron
    entrypoint: flagscale/train/train_deepseek_v3.py
  runner:
    per_node_task: false
    no_shared_fs: false
    rdzv_backend: static
    hostfile: null
  cmds:
    before_start: "ulimit -n 1048576 && source /root/miniconda3/bin/activate flagscale"
  envs:
    LOGLEVEL: "INFO"
    CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7"
    CUDA_DEVICE_MAX_CONNECTIONS: 1

action: run

hydra:
  run:
    dir: ${experiment.exp_dir}/hydra

@lxd-cumt
Collaborator

lxd-cumt commented Feb 28, 2025

[image]
Could you provide your transformer engine version?
We have tested transformer engine v1.11 and it works without issues.
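
A quick way to report the installed version (standard import, not FlagScale-specific):

# Print the installed Transformer Engine version; the failing assert lives in
# transformer_engine/pytorch/attention.py, so the version determines whether
# differing key/value head dims are supported.
import transformer_engine
print(transformer_engine.__version__)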
