77 changes: 77 additions & 0 deletions recipe/qwen3_ascend/README_zh.md
@@ -0,0 +1,77 @@
# Qwen3-235B-A22B RL Training Optimization Practice Example

## Overview
This example targets the Qwen3-235B-A22B model. Based on the [veRL open-source framework](https://github.com/volcengine/verl), it uses the MindSpeed and vllm-ascend frameworks natively supported by veRL to complete optimization and adaptation of the full RL training pipeline.

## Environment Setup

### Image Creation

The environment can be set up quickly using the image provided by vLLM-Ascend:
```shell
# image pull command
docker pull quay.io/ascend/vllm-ascend:v0.10.1rc1-a3
```
Using the image:
```shell
# run the following script to create the container; pass a container name, e.g. your_docker_name
bash run_container.sh your_docker_name
```

### Package Installation

1. Install the required Python dependencies.
```shell
pip3 install -r requirements.txt
```

2. Prepare the source code. The steps used in this example are as follows:
```shell
# veRL (commit-id: ac2f7)
git clone https://github.com/volcengine/verl.git
cd verl
# apply the required patch from PR 3427
git fetch origin pull/3427/head && git cherry-pick FETCH_HEAD
cd ..

# vLLM
git clone https://github.com/vllm-project/vllm.git
cd vllm
git checkout v0.10.1
cp -r vllm ../verl        # copy the vllm package directory into the verl working directory
cd ..

# vLLM-Ascend
git clone https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend
git checkout af62af
git fetch origin pull/2869/head && git cherry-pick FETCH_HEAD
git fetch origin pull/3005/head && git cherry-pick FETCH_HEAD
cp -r vllm_ascend ../verl # copy the vllm_ascend package directory into the verl working directory
cd ..

# MindSpeed
git clone https://gitee.com/ascend/MindSpeed.git
cd MindSpeed
git checkout 7ff81
cp -r mindspeed ../verl   # copy the mindspeed package directory into the verl working directory
cd ..

# Megatron-LM.core and others
pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1
pip install mathruler
```

## Prepare the Training Dataset and Model
- Place the dataset under `./data`. For dataset preparation, see the [veRL documentation](https://verl.readthedocs.io/en/latest/preparation/prepare_data.html).
- Place the model under `./Qwen3-235B-A22B`. Model download: [Qwen3-235B-A22B](https://huggingface.co/Qwen/Qwen3-235B-A22B). A minimal preparation sketch is shown below.
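
A rough preparation sketch, assuming the GSM8K dataset used by the training script and the preprocessing script shipped with the veRL repository (paths are illustrative; adjust them to your setup):
```shell
# preprocess GSM8K into parquet files under ./data/gsm8k (preprocessing script from the veRL repository)
python3 examples/data_preprocess/gsm8k.py --local_dir ./data/gsm8k

# download the model weights into ./Qwen3-235B-A22B (requires the huggingface_hub CLI)
huggingface-cli download Qwen/Qwen3-235B-A22B --local-dir ./Qwen3-235B-A22B
```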

## Run RL Post-Training
Edit the placeholder variables at the top of `ray_start_grpo_npu.sh` (see the example below), then launch the script on every node of the cluster; the master node starts training once all worker nodes have registered with Ray.
```shell
# launch RL post-training for Qwen3-235B-A22B from this recipe directory
bash ./ray_start_grpo_npu.sh # training script that loads the real weights
```
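
A minimal example of the placeholder values (both values below are illustrative assumptions; use your own master IP and NIC name):
```shell
MASTER_ADDR="192.0.2.10"      # IP address of the master node
SOCKET_IFNAME="enp189s0f0"    # NIC used for inter-node communication on this node
```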

## Performance
On a 64-card Atlas 900 A3 SuperPoD super-node cluster, loading real weights, with prefill/decode stage lengths of 1K and 3K respectively, the system throughput reaches 89 tps per card.

| Model | Hardware | GBS | n_samples | max_prompt_length | max_tokens | End-to-end tps/card |
|------------------|-----------------------|-----|-----------|-------------------|------------|---------------------|
| Qwen3-235B-A22B  | Atlas 900 A3 SuperPoD | 256 | 16        | 4096              | 3072       | 89                  |
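
For reference, 89 tps per card on the 64-card cluster corresponds to roughly 89 × 64 ≈ 5,700 tokens per second of end-to-end system throughput.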
86 changes: 86 additions & 0 deletions recipe/qwen3_ascend/ray_start_grpo_npu.sh
@@ -0,0 +1,86 @@
pkill -9 python
ray stop --force

export RAY_DEDUP_LOGS=0 # 0: disable Ray log deduplication, 1: enable it
export HYDRA_FULL_ERROR=1 # show the full Hydra error stack

ulimit -n 32768
mkdir -p logs

export HCCL_IF_BASE_PORT=24703

NNODES=8 # number of nodes
NPUS_PER_NODE=16 # number of NPUs per node
MASTER_ADDR="IP FOR MASTER NODE" # set this to the master node's IP address
SOCKET_IFNAME="SOCKET IFNAME FOR CURRENT NODE" # set this to the NIC used for inter-node communication on the current node
# obtain the current node IP
CURRENT_IP=$(ifconfig $SOCKET_IFNAME | grep -Eo 'inet (addr:)?([0-9]{1,3}\.){3}[0-9]{1,3}' | awk '{print $NF}')
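# note: if ifconfig is not available in the container, an equivalent (an assumption, using iproute2) would be:
#   CURRENT_IP=$(ip -4 -o addr show "$SOCKET_IFNAME" | awk '{print $4}' | cut -d/ -f1)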

export TP_SOCKET_IFNAME=$SOCKET_IFNAME
export HCCL_SOCKET_IFNAME=$SOCKET_IFNAME
export GLOO_SOCKET_IFNAME=$SOCKET_IFNAME

# configure environment variables
export CUDA_DEVICE_MAX_CONNECTIONS=1
export PYTORCH_NPU_ALLOC_CONF="expandable_segments:True"
export WORLD_SIZE=$(($NNODES*$NPUS_PER_NODE))
export MASTER_PORT=29444

export ASCEND_LAUNCH_BLOCKING=0 # debug switch: setting it to 1 gives accurate error stacks but severely degrades performance

export HCCL_CONNECT_TIMEOUT=600
export HCCL_EXEC_TIMEOUT=600
export HCCL_IF_BASE_PORT=64247
Reviewer comment (high): The environment variable HCCL_IF_BASE_PORT is defined here again, overwriting the value set on line 10. This is likely an error and can lead to unpredictable behavior in your network configuration. You should remove one of the two declarations to ensure consistency.

export VLLM_USE_V1=1 # use the vLLM V1 engine
export VLLM_ENABLE_GRAPH_MODE=1 # enable vLLM graph mode
export HCCL_OP_EXPANSION_MODE=AIV # expand HCCL communication operators on the AI vector cores (AIV)
export VLLM_ENABLE_MC2=1 # enable MC2 communication
export VLLM_DP_SIZE=128 # data-parallel size used by vLLM
# when the vLLM log level is INFO, enabling this prints the prefill and decode execution times
export VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE=0

export TASK_QUEUE_ENABLE=2 # enable level-2 optimization of the Ascend operator dispatch queue
export HCCL_BUFFSIZE=300 # HCCL buffer size


if [ "$MASTER_ADDR" = "$CURRENT_IP" ]; then
Reviewer comment (critical): This comparison is fragile. If the user forgets to change MASTER_ADDR from its placeholder value "IP FOR MASTER NODE", this condition will likely always evaluate to false. This would cause a node intended as a master to incorrectly act as a worker, leading to cluster setup failure. It is critical to add validation at the beginning of the script to ensure that MASTER_ADDR and SOCKET_IFNAME have been set to valid, non-placeholder values.

# the master node starts
ray start --head --port 6766 --dashboard-host=0.0.0.0 --node-ip-address=$CURRENT_IP --dashboard-port=8260 --resources='{"NPU": '$NPUS_PER_NODE'}'

while true; do
ray_status_output=$(ray status)
npu_count=$(echo "$ray_status_output" | grep -oP '(?<=/)\d+\.\d+(?=\s*NPU)' | head -n 1)
npu_count_int=$(echo "$npu_count" | awk '{print int($1)}')
device_count=$((npu_count_int / $NPUS_PER_NODE))

# determine whether device_count is equal to NNODES
if [ "$device_count" -eq "$NNODES" ]; then
echo "Ray cluster is ready with $device_count devices (from $npu_count NPU resources), starting Python script."
ray status
bash ./recipe/qwen3_ascend/run_grpo_qwen3_235b_npu_megatron.sh
break
else
echo "Waiting for Ray to allocate $NNODES devices. Current device count: $device_count"
sleep 5
fi
done
else
# worker nodes attempt to register with the Ray cluster on the master node until successful
while true; do
# try to connect to the Ray cluster
ray start --address="$MASTER_ADDR:6766" --resources='{"NPU": '$NPUS_PER_NODE'}' --node-ip-address=$CURRENT_IP

# check if the connection is successful
ray status
if [ $? -eq 0 ]; then
echo "Successfully connected to the Ray cluster!"
break
else
echo "Failed to connect to the Ray cluster. Retrying in 5 seconds..."
sleep 5
fi
done
fi

exit 127
Reviewer comment (critical): The script unconditionally exits with code 127, which is a critical flaw. For worker nodes, this causes the script to terminate immediately after connecting to the Ray cluster, effectively removing the worker. For the master node, the exit code 127 is misleading, as it usually indicates 'command not found'. Worker nodes must remain running to be part of the cluster. This line should be removed to allow worker processes to persist.

48 changes: 48 additions & 0 deletions recipe/qwen3_ascend/run_container.sh
@@ -0,0 +1,48 @@
# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the license.

#!/bin/bash
container_name=$1
Reviewer comment (high): The script requires a container name as an argument but doesn't validate its presence. If run without an argument, container_name will be empty, causing the docker run command to fail with an error about the --name flag. You should enforce that this argument is provided. Suggested change: replace `container_name=$1` with `container_name=${1:?"Error: Container name not provided. Usage: $0 <container_name>"}`.

# create
docker run -itd \
--device=/dev/davinci0 \
--device=/dev/davinci1 \
--device=/dev/davinci2 \
--device=/dev/davinci3 \
--device=/dev/davinci4 \
--device=/dev/davinci5 \
--device=/dev/davinci6 \
--device=/dev/davinci7 \
--device=/dev/davinci8 \
--device=/dev/davinci9 \
--device=/dev/davinci10 \
--device=/dev/davinci11 \
--device=/dev/davinci12 \
--device=/dev/davinci13 \
--device=/dev/davinci14 \
--device=/dev/davinci15 \
-v /usr/local/dcmi:/usr/local/dcmi \
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
-v /var/log/npu/slog/slogd:/var/log/npu/slog/slogd \
-v /usr/local/sbin/:/usr/local/sbin/ \
-v /data/:/data/ \
-v /home/:/home/ \
-v /etc/localtime:/etc/localtime \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v /dev/shm:/dev/shm \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm \
--device=/dev/hisi_hdc \
--net=host \
--name ${container_name} \
--privileged quay.io/ascend/vllm-ascend:v0.10.1rc1-a3 /bin/bash
Reviewer comment (high): Using the --privileged flag grants the container unrestricted root access to the host machine, which poses a significant security risk. It's strongly recommended to avoid this and instead grant only the specific capabilities required by the application (e.g., using --cap-add). If --privileged is absolutely necessary for hardware access, this should be clearly documented with a security warning.

# execute
docker exec -it -u root ${container_name} bash
119 changes: 119 additions & 0 deletions recipe/qwen3_ascend/run_grpo_qwen3_235b_npu_megatron.sh
@@ -0,0 +1,119 @@
#!/usr/bin/env bash
set -x

project_name='GRPO'
exp_name='GRPO-Qwen3-235B-Megatron-128rank-gbs256'

NNODES=8
NPUS_PER_NODE=16

adv_estimator=grpo

kl_coef=0.001
use_kl_loss=True
kl_loss_coef=0.001

max_prompt_length=$((1024 * 1))
max_response_length=$((1024 * 3))
max_num_batched_tokens=4096
ppo_mini_batch_size=256

train_prompt_bsz=256
n_resp_per_prompt=16

# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
CONFIG_PATH=${CONFIG_PATH:-"${RAY_DATA_HOME}/verl/trainer/config"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-235B-A22B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/Qwen3-235B-dist-ckpts"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"}

# Algorithm
temperature=0.9
top_p=0.9
top_k=50 # 0 for HF rollout, -1 for vLLM rollout

# Performance Related Parameter
offload=True
max_num_seqs=32
gen_tp=1

# Currently `enable_chunked_prefill` must be enabled in this script.
# Note that in vLLM Ascend this option is off by default, so the setting does not take effect there.
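# Parallelism sanity check (informational comment; arithmetic inferred from the settings below):
#   8 nodes x 16 NPUs = 128 ranks; the actor uses TP=2 x PP=16 = 32-way model parallelism,
#   leaving 128 / 32 = 4 data-parallel replicas for training, while rollout uses gen_tp=1.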
Reviewer comment (critical): The module path recipe.r1_ascend.main_ppo appears to be incorrect. This pull request adds files under recipe/qwen3_ascend/, and there is no r1_ascend directory visible. This will likely result in a ModuleNotFoundError at runtime. Please verify and correct the module path. It might need to be verl.trainer.main_ppo or another path corresponding to your project structure.

python3 -m recipe.r1_ascend.main_ppo \
--config-path="${CONFIG_PATH}" \
--config-name='ppo_megatron_trainer.yaml' \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.truncation='error' \
data.filter_overlong_prompts=True \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.rollout.max_num_seqs=${max_num_seqs} \
actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=8 \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=16 \
actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=1 \
actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=block \
actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=5 \
actor_rollout_ref.actor.load_weight=True \
actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \
actor_rollout_ref.actor.megatron.dist_checkpointing_path="${CKPTS_DIR}" \
actor_rollout_ref.actor.megatron.param_offload=${offload} \
actor_rollout_ref.actor.megatron.grad_offload=${offload} \
actor_rollout_ref.actor.megatron.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.ref.load_weight=True \
actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \
actor_rollout_ref.ref.megatron.dist_checkpointing_path="${CKPTS_DIR}" \
actor_rollout_ref.ref.megatron.param_offload=${offload} \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=True \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node="${NPUS_PER_NODE}" \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=5 \
trainer.save_freq=-1 \
trainer.total_epochs=1 \
trainer.device="npu" \
+actor_rollout_ref.actor.megatron.override_transformer_config.multi_head_latent_attention=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.pipeline_num_transformer_layers=[[5],[6],[6],[6],[6],[6],[6],[6],[6],[6],[6],[6],[6],[6],[6],[5]] \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type='alltoall' \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.use_fused_rotary_pos_emb=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.use_fused_swiglu=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.seq_length=2048 \
+actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=5 \
+actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=5 \
+actor_rollout_ref.actor.megatron.override_transformer_config.swap_optimizer=True $@