diff --git a/test/common/envPreCheck/run_env_preCheck.py b/test/common/envPreCheck/run_env_preCheck.py
index daae41032..95f08cc55 100644
--- a/test/common/envPreCheck/run_env_preCheck.py
+++ b/test/common/envPreCheck/run_env_preCheck.py
@@ -10,6 +10,7 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 import yaml
+from common.config_utils import config_utils
 
 CODE_ROOT = Path(__file__).resolve().parent
 Custom_SSH_DIR = (CODE_ROOT / "ssh_keys").resolve()
@@ -18,23 +19,22 @@
 LOCAL_SSH_KEY = Custom_SSH_DIR / "id_rsa"
 LOCAL_SSH_KEY_PUB = Custom_SSH_DIR / "id_rsa.pub"
 
-config_file = Path(__file__).parent.parent.parent / "config.yaml"
-with open(config_file, "r", encoding="utf-8") as f:
-    config = yaml.safe_load(f)
-    MASTER_IP = config.get("Env_preCheck", {}).get("master_ip", "")
-    WORKER_IP = config.get("Env_preCheck", {}).get("worker_ip", "")
-    ASCEND_RT_VISIBLE_DEVICES = config.get("Env_preCheck", {}).get(
-        "ascend_rt_visible_devices", ""
-    )
-    NODE_NUM = config.get("Env_preCheck", {}).get("node_num", "")
-    MODEL_PATH = config.get("Env_preCheck", {}).get("model_path", "")
-    HF_MODEL_NAME = config.get("Env_preCheck", {}).get("hf_model_name", "")
-    MIDDLE_PAGE = config.get("Env_preCheck", {}).get("middle_page", "")
+MASTER_IP = config_utils.get_nested_config("Env_preCheck.master_ip", "")
+WORKER_IP = config_utils.get_nested_config("Env_preCheck.worker_ip", "")
+ASCEND_RT_VISIBLE_DEVICES = config_utils.get_nested_config(
+    "Env_preCheck.ascend_rt_visible_devices", ""
+)
+NODE_NUM = config_utils.get_nested_config("Env_preCheck.node_num", "")
+MODEL_PATH = config_utils.get_nested_config("Env_preCheck.model_path", "")
+HF_MODEL_NAME = config_utils.get_nested_config("Env_preCheck.hf_model_name", "")
+MIDDLE_PAGE = config_utils.get_nested_config("Env_preCheck.middle_page", "")
+KVCACHE_BLOCK_NUMBER = config_utils.get_nested_config(
+    "Env_preCheck.kvCache_block_number", ""
+)
+STORAGE_BACKENDS = config_utils.get_nested_config("Env_preCheck.storage_backends", "")
 
-    KVCACHE_BLOCK_NUMBER = config.get("Env_preCheck", {}).get(
-        "kvCache_block_number", ""
-    )
-    STORAGE_BACKENDS = config.get("Env_preCheck", {}).get("storage_backends", "")
+
+config_file = Path(__file__).parent.parent.parent / "config.yaml"
 
 
 def run_command(
@@ -1398,3 +1398,210 @@ def run_bandwidth_check():
             bw_summary[key] = None
 
     return bw_summary
+
+
+def run_mlnx_qos():
+    """
+    Apply RoCE QoS / flow-control settings on the master host over SSH.
+
+    The remote platform (NVIDIA or Ascend) is detected automatically. The
+    outcome of every stage is recorded in the returned dict, and
+    ``global_status`` turns False as soon as any stage fails.
+
+    Returns:
+        dict: per-stage results under the keys ``env``, ``mapping``,
+        ``mapping_status``, ``interfaces``, ``rc_local``, ``dispatcher``,
+        ``global_status`` and ``errors``.
+    """
+    import textwrap  # function-scope: only the dispatcher script needs it
+
+    result = {
+        "env": {"type": None, "status": False},
+        "mapping": [],
+        "mapping_status": False,
+        "interfaces": {},
+        "rc_local": {"status": None},
+        "dispatcher": {"status": None},
+        "global_status": True,
+        "errors": [],
+    }
+
+    def ssh_run(cmd, stdin_data=None):
+        """Run ``cmd`` in the remote shell; return (ok, stdout, stderr)."""
+        # An argv list (no local shell) hands ``cmd`` to the remote shell
+        # verbatim, so quotes and "$" in it are not re-parsed or expanded
+        # on this machine first.
+        p = subprocess.run(
+            ["ssh", "-i", str(LOCAL_SSH_KEY), f"root@{MASTER_IP}", cmd],
+            input=stdin_data,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+        return p.returncode == 0, p.stdout.strip(), p.stderr.strip()
+
+    def fail(message):
+        """Record ``message`` and mark the whole run as failed."""
+        result["errors"].append(message)
+        result["global_status"] = False
+
+    # ------------------------------------------------------------------
+    # 1. Detect the remote environment
+    # ------------------------------------------------------------------
+    print("[STEP] Detecting remote environment...")
+    if ssh_run("command -v nvidia-smi")[0]:
+        env_type = "Nvidia"
+    elif ssh_run("command -v npu-smi")[0] or ssh_run("test -d /usr/local/Ascend")[0]:
+        env_type = "Ascend"
+    else:
+        env_type = "unknown"
+
+    result["env"]["type"] = env_type
+    result["env"]["status"] = env_type != "unknown"
+
+    if not result["env"]["status"]:
+        print("[FAIL] Environment detection failed")
+        fail("Unable to detect NVIDIA or Ascend environment")
+        return result
+
+    print(f"[OK] Detected environment: {env_type}")
+
+    # ------------------------------------------------------------------
+    # 2. Collect the IB device -> net device mapping
+    # ------------------------------------------------------------------
+    print("[STEP] Collecting IB interfaces...")
+    ok, out, err = ssh_run("ibdev2netdev")
+    if not ok:
+        print("[FAIL] Failed to run ibdev2netdev")
+        fail(err or "ibdev2netdev failed")
+        return result
+
+    mapping = []
+    for line in out.splitlines():
+        left, sep, right = line.partition("==>")
+        ib_fields, net_fields = left.split(), right.split()
+        # Ignore lines with nothing on either side of "==>" rather than
+        # raising IndexError.
+        if sep and ib_fields and net_fields:
+            mapping.append((ib_fields[0], net_fields[0]))
+
+    result["mapping"] = mapping
+    result["mapping_status"] = bool(mapping)
+
+    if not mapping:
+        print("[FAIL] No IB interfaces found")
+        fail("No IB interfaces found in ibdev2netdev output")
+        return result
+
+    print(f"[OK] Found interfaces: {mapping}")
+
+    # ------------------------------------------------------------------
+    # 3. QoS / flow-control configuration
+    # ------------------------------------------------------------------
+    print("[STEP] Applying QoS configuration...")
+    tsa = "ets,ets,ets,ets,ets,ets,ets,ets"
+    # The bandwidth split is the only platform-specific setting.
+    tcbw = "10,90,0,0,0,0,0,0" if env_type == "Nvidia" else "10,0,0,0,90,0,0,0"
+    for ibdev, netdev in mapping:
+        print(f"  -> Configuring {netdev} ({ibdev})")
+        steps = (
+            ("pfc", f"mlnx_qos -i {netdev} --pfc 1,0,0,0,1,0,0,0 --trust dscp"),
+            ("prio_tc", f"mlnx_qos -i {netdev} --prio_tc 0,0,0,0,4,0,0,0"),
+            ("tcbw", f"mlnx_qos -i {netdev} --tsa {tsa} --tcbw {tcbw}"),
+            ("cma_roce_tos", f"cma_roce_tos -d {ibdev} -t 128"),
+        )
+
+        iface_result = {}
+        for name, cmd in steps:
+            ok, _, err = ssh_run(cmd)
+            iface_result[name] = ok
+            if not ok:
+                result["errors"].append(f"{netdev}: {name} failed: {err}")
+
+        iface_ok = all(iface_result.values())
+        iface_result["status"] = iface_ok
+        result["interfaces"][netdev] = iface_result
+
+        if iface_ok:
+            print(f"  [OK] {netdev} configured successfully")
+        else:
+            print(f"  [FAIL] {netdev} configuration failed")
+            result["global_status"] = False
+
+    # ------------------------------------------------------------------
+    # 4. Make sure /etc/rc.local exists (Ascend only)
+    # ------------------------------------------------------------------
+    if env_type == "Ascend":
+        print("[STEP] Configuring rc-local")
+
+        # printf inside single quotes reaches the remote shell unchanged;
+        # the file is only created when it does not exist yet.
+        rc_ok, _, _ = ssh_run(
+            "test -f /etc/rc.local || "
+            "printf '#!/bin/bash\\nexit 0\\n' > /etc/rc.local"
+        )
+        result["rc_local"]["status"] = rc_ok
+
+        if not rc_ok:
+            fail("rc-local setup failed")
+            print("[FAIL] rc-local configuration failed")
+        else:
+            print("[OK] rc-local configured")
+
+    # ------------------------------------------------------------------
+    # 5. Persist the settings: write the NetworkManager dispatcher script
+    #    (created if missing, overwritten if present)
+    # ------------------------------------------------------------------
+    DISPATCH_SCRIPT = "/etc/NetworkManager/dispatcher.d/pre-up.d/set-nvmeof-qos"
+    print(f"[STEP] Writing dispatcher script: {DISPATCH_SCRIPT}")
+
+    # NOTE(review): the script always uses the 10,0,0,0,90,0,0,0 bandwidth
+    # split, even when the live configuration above used the Nvidia one -
+    # confirm this is intended.
+    # ibdev2netdev prints "<ibdev> port <n> ==> <netdev> (<state>)", so the
+    # net device is awk field 5 (field 2 is the literal word "port").
+    dispatch_content = textwrap.dedent(
+        """\
+        #!/bin/sh
+        # NetworkManager dispatcher script for NVMe-oF QoS
+        # Triggered on pre-up
+
+        netdev=$1
+        action=$2
+
+        if [ "x${action}" != "xpre-up" ]; then
+            exit 0
+        fi
+
+        logger "[set-nvmeof-qos] pre-up triggered for ${netdev}"
+
+        ibdev=$(ibdev2netdev | awk -v dev="${netdev}" '$5 == dev {print $1}')
+        [ -z "${ibdev}" ] && exit 0
+
+        mlnx_qos -i ${netdev} --pfc 1,0,0,0,1,0,0,0 --trust=dscp
+        mlnx_qos -i ${netdev} --prio_tc 0,0,0,0,4,0,0,0
+        mlnx_qos -i ${netdev} --tsa ets,ets,ets,ets,ets,ets,ets,ets --tcbw 10,0,0,0,90,0,0,0
+        cma_roce_tos -d ${ibdev} -t 128
+
+        exit 0
+        """
+    )
+
+    # The script body travels over stdin, so no heredoc or nested quoting
+    # is needed and nothing in it is expanded before it reaches the file.
+    disp_ok, _, err = ssh_run(
+        f'mkdir -p "$(dirname {DISPATCH_SCRIPT})" && '
+        f"cat > {DISPATCH_SCRIPT} && chmod 755 {DISPATCH_SCRIPT}",
+        stdin_data=dispatch_content,
+    )
+
+    result["dispatcher"]["status"] = disp_ok
+
+    if disp_ok:
+        print("[OK] dispatcher script written successfully")
+    else:
+        print(f"[FAIL] dispatcher script write failed: {err}")
+        fail("dispatcher script setup failed")
+
+    print(f"[RESULT] Global QoS status: {result['global_status']}")
+    return result
diff --git a/test/suites/E2E/test_environment_precheck.py b/test/suites/E2E/test_environment_precheck.py
index 355652894..adb7140b3 100644
--- a/test/suites/E2E/test_environment_precheck.py
+++ b/test/suites/E2E/test_environment_precheck.py
@@ -264,3 +264,15 @@ def test_check_bandwidth():
     assert (
         bandwidth["fetch"] < 0.85 * EXPECTED_FETCH_BANDWIDTH
     ), f"Fetch bandwidth too high: {bandwidth['fetch']} GB/s"
+
+
+@pytest.mark.stage(1)
+@pytest.mark.platform("npu")
+@pytest.mark.feature("test_apply_mlnx_qos")
+def test_remote_mlnx_qos():
+    """
+    Apply the flow-control QoS configuration on the remote host.
+    """
+    qos_result = run_mlnx_qos()
+    print(qos_result)
+    assert qos_result["global_status"] is True, qos_result["errors"]