Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 210 additions & 16 deletions test/common/envPreCheck/run_env_preCheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any, Dict, List, Optional, Tuple

import yaml
from common.config_utils import config_utils

CODE_ROOT = Path(__file__).resolve().parent
Custom_SSH_DIR = (CODE_ROOT / "ssh_keys").resolve()
Expand All @@ -18,23 +19,22 @@
LOCAL_SSH_KEY = Custom_SSH_DIR / "id_rsa"
LOCAL_SSH_KEY_PUB = Custom_SSH_DIR / "id_rsa.pub"

config_file = Path(__file__).parent.parent.parent / "config.yaml"
with open(config_file, "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
MASTER_IP = config.get("Env_preCheck", {}).get("master_ip", "")
WORKER_IP = config.get("Env_preCheck", {}).get("worker_ip", "")
ASCEND_RT_VISIBLE_DEVICES = config.get("Env_preCheck", {}).get(
"ascend_rt_visible_devices", ""
)
NODE_NUM = config.get("Env_preCheck", {}).get("node_num", "")
MODEL_PATH = config.get("Env_preCheck", {}).get("model_path", "")
HF_MODEL_NAME = config.get("Env_preCheck", {}).get("hf_model_name", "")
MIDDLE_PAGE = config.get("Env_preCheck", {}).get("middle_page", "")
MASTER_IP = config_utils.get_nested_config("Env_preCheck.master_ip", "")
WORKER_IP = config_utils.get_nested_config("Env_preCheck.worker_ip", "")
ASCEND_RT_VISIBLE_DEVICES = config_utils.get_nested_config(
"Env_preCheck.ascend_rt_visible_devices", ""
)
NODE_NUM = config_utils.get_nested_config("Env_preCheck.node_num", "")
MODEL_PATH = config_utils.get_nested_config("Env_preCheck.model_path", "")
HF_MODEL_NAME = config_utils.get_nested_config("Env_preCheck.hf_model_name", "")
MIDDLE_PAGE = config_utils.get_nested_config("Env_preCheck.middle_page", "")
KVCACHE_BLOCK_NUMBER = config_utils.get_nested_config(
"Env_preCheck.kvCache_block_number", ""
)
STORAGE_BACKENDS = config_utils.get_nested_config("Env_preCheck.storage_backends", "")

KVCACHE_BLOCK_NUMBER = config.get("Env_preCheck", {}).get(
"kvCache_block_number", ""
)
STORAGE_BACKENDS = config.get("Env_preCheck", {}).get("storage_backends", "")

config_file = Path(__file__).parent.parent.parent / "config.yaml"


def run_command(
Expand Down Expand Up @@ -1398,3 +1398,197 @@ def run_bandwidth_check():
bw_summary[key] = None

return bw_summary


def run_mlnx_qos():
"""
通过 SSH 在宿主机上执行 QoS 配置,支持 NVIDIA / Ascend 环境自动判断
每个阶段都会记录执行状态,并最终给出 global_status
"""
result = {
"env": {"type": None, "status": False},
"mapping": [],
"mapping_status": False,
"interfaces": {},
"rc_local": {"status": None},
"dispatcher": {"status": None},
"global_status": True,
"errors": [],
}

def ssh_run(cmd):
full_cmd = f'ssh -i "{LOCAL_SSH_KEY}" root@{MASTER_IP} "{cmd}"'
p = subprocess.run(
full_cmd,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
return p.returncode == 0, p.stdout.strip(), p.stderr.strip()

# ------------------------------------------------------------------
# 1. 环境判断
# ------------------------------------------------------------------
print("[STEP] Detecting remote environment...")
if ssh_run("command -v nvidia-smi")[0]:
env_type = "Nvidia"
elif ssh_run("command -v npu-smi")[0] or ssh_run("test -d /usr/local/Ascend")[0]:
env_type = "Ascend"
else:
env_type = "unknown"

result["env"]["type"] = env_type
result["env"]["status"] = env_type != "unknown"

if not result["env"]["status"]:
print("[FAIL] Environment detection failed")
result["errors"].append("Unable to detect NVIDIA or Ascend environment")
result["global_status"] = False
return result

print(f"[OK] Detected environment: {env_type}")

# ------------------------------------------------------------------
# 2. 获取接口列表
# ------------------------------------------------------------------
print("[STEP] Collecting IB interfaces...")
ok, out, err = ssh_run("ibdev2netdev")
if not ok:
print("[FAIL] Failed to run ibdev2netdev")
result["errors"].append(err)
result["global_status"] = False
return result

mapping = []
for line in out.splitlines():
if "==>" in line:
ibdev = line.split("==>")[0].split()[0]
netdev = line.split("==>")[1].split()[0]
mapping.append((ibdev, netdev))

result["mapping"] = mapping
result["mapping_status"] = bool(mapping)

if not mapping:
print("[FAIL] No IB interfaces found")
result["global_status"] = False
return result

print(f"[OK] Found interfaces: {mapping}")

# ------------------------------------------------------------------
# 3. QoS / 流控配置
# ------------------------------------------------------------------
print("[STEP] Applying QoS configuration...")
for ibdev, netdev in mapping:
iface_ok = True
iface_result = {}

def run_cmd(name, cmd):
nonlocal iface_ok
ok, _, err = ssh_run(cmd)
iface_result[name] = ok
if not ok:
iface_ok = False
result["errors"].append(f"{netdev}: {name} failed: {err}")
return ok

print(f" -> Configuring {netdev} ({ibdev})")

run_cmd("pfc", f"mlnx_qos -i {netdev} --pfc 1,0,0,0,1,0,0,0 --trust dscp")
run_cmd("prio_tc", f"mlnx_qos -i {netdev} --prio_tc 0,0,0,0,4,0,0,0")

if env_type == "Nvidia":
run_cmd(
"tcbw",
f"mlnx_qos -i {netdev} --tsa ets,ets,ets,ets,ets,ets,ets,ets "
f"--tcbw 10,90,0,0,0,0,0,0",
)
else:
run_cmd(
"tcbw",
f"mlnx_qos -i {netdev} --tsa ets,ets,ets,ets,ets,ets,ets,ets "
f"--tcbw 10,0,0,0,90,0,0,0",
)

run_cmd("cma_roce_tos", f"cma_roce_tos -d {ibdev} -t 128")

iface_result["status"] = iface_ok
result["interfaces"][netdev] = iface_result

if iface_ok:
print(f" [OK] {netdev} configured successfully")
else:
print(f" [FAIL] {netdev} configuration failed")
result["global_status"] = False

# ------------------------------------------------------------------
# 4. rc-local 确认
# ------------------------------------------------------------------
if env_type == "Ascend":
print("[STEP] Configuring rc-local")

rc_ok, _, _ = ssh_run(
'test -f /etc/rc.local || echo -e "#!/bin/bash\\nexit 0" > /etc/rc.local'
)
result["rc_local"]["status"] = rc_ok

if not rc_ok:
result["errors"].append("rc-local setup failed")
result["global_status"] = False
print("[FAIL] rc-local configuration failed")
else:
print("[OK] rc-local configured")

# ------------------------------------------------------------------
# 配置永久生效 脚本写入(存在则覆盖,不存在则创建)
# ------------------------------------------------------------------
DISPATCH_SCRIPT = "/etc/NetworkManager/dispatcher.d/pre-up.d/set-nvmeof-qos"
print(f"[STEP] Writing dispatcher script: {DISPATCH_SCRIPT}")

dispatch_content = """#!/bin/sh
# NetworkManager dispatcher script for NVMe-oF QoS
# Triggered on pre-up

netdev=$1
action=$2

if [ "x${action}" != "xpre-up" ]; then
exit 0
fi

logger "[set-nvmeof-qos] pre-up triggered for ${netdev}"

ibdev=$(ibdev2netdev | awk -v dev="${netdev}" '$2 == dev {print $1}')
[ -z "${ibdev}" ] && exit 0

mlnx_qos -i ${netdev} --pfc 1,0,0,0,1,0,0,0 --trust=dscp
mlnx_qos -i ${netdev} --prio_tc 0,0,0,0,4,0,0,0
mlnx_qos -i ${netdev} --tsa ets,ets,ets,ets,ets,ets,ets,ets --tcbw 10,0,0,0,90,0,0,0
cma_roce_tos -d ${ibdev} -t 128

exit 0
"""

disp_ok, _, err = ssh_run(
f"""
mkdir -p $(dirname {DISPATCH_SCRIPT}) && \
cat > {DISPATCH_SCRIPT} << 'EOF'
{dispatch_content}
EOF
chmod 755 {DISPATCH_SCRIPT}
"""
)

result["dispatcher"]["status"] = disp_ok

if disp_ok:
print("[OK] dispatcher script written successfully")
else:
print(f"[FAIL] dispatcher script write failed: {err}")
result["global_status"] = False
result["errors"].append("dispatcher script setup failed")

print(f"[RESULT] Global QoS status: {result['global_status']}")
return result
12 changes: 12 additions & 0 deletions test/suites/E2E/test_environment_precheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,15 @@ def test_check_bandwidth():
assert (
bandwidth["fetch"] < 0.85 * EXPECTED_FETCH_BANDWIDTH
), f"Fetch bandwidth too high: {bandwidth['fetch']} GB/s"


@pytest.mark.stage(1)
@pytest.mark.platform("npu")
@pytest.mark.feature("test_apply_mlnx_qos")
def test_remote_mlnx_qos():
"""
远程配置宿主机流控 QoS 配置
"""
qos_result = run_mlnx_qos()
print(qos_result)
assert qos_result["global_status"] is True