diff --git a/.github/workflows/cluster_analyse.yml b/.github/workflows/cluster_analyse.yml
new file mode 100644
index 00000000000..cfe41a6af43
--- /dev/null
+++ b/.github/workflows/cluster_analyse.yml
@@ -0,0 +1,46 @@
+name: cluster_analyse
+
+on:
+  push:
+    branches:
+      - main
+      - v0.*
+  pull_request:
+    branches:
+      - main
+      - v0.*
+    paths:
+      - "**/*.py"
+      - .github/workflows/cluster_analyse.yml
+      - "tests/tools/cluster_analysis/**"
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
+permissions:
+  contents: read
+
+jobs:
+  cluster_analyse:
+    runs-on: ubuntu-latest
+    timeout-minutes: 5
+    strategy:
+      matrix:
+        python-version: ["3.11"]
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: |
+          pip3 install pandas plotly numpy pytest
+
+      - name: Run cluster_analyse tests
+        run: |
+          pytest -s -x tests/tools/cluster_analysis
\ No newline at end of file
diff --git a/docs/cluster_analyse.md b/docs/cluster_analyse.md
new file mode 100644
index 00000000000..404c454447e
--- /dev/null
+++ b/docs/cluster_analyse.md
@@ -0,0 +1,88 @@
+# Cluster Analyse - RL Timeline Visualization Tool
+
+## 1. Overview
+
+Cluster Analyse is a visualization tool for quick analysis of reinforcement learning performance data. It parses profiling data collected with the VeRL framework and renders a Timeline chart of the RL stages.
+
+### Key Features
+
+- **Data parsing**: parses profiling data collected by the VeRL framework in multiple formats
+- **Parallel processing**: parses the performance data of multiple ranks with multiprocessing to speed up analysis
+- **Timeline visualization**: generates an interactive Timeline Gantt chart that shows the event distribution of every rank
+- **Performance analysis**: the Timeline chart exposes issues such as load imbalance across devices and long-tail inference, helping with performance tuning
+
+### Dependencies
+
+- Python
+- Pandas
+- Plotly
+- NumPy
+
+## 2. Quick Start
+
+### 2.1 Install dependencies
+
+```bash
+pip install pandas plotly numpy
+```
+
+### 2.2 Collect profiling data
+
+Collect performance data with the VeRL framework; see the detailed guide:
+
+[VeRL NPU Profiling Tutorial](https://github.com/verl-project/verl/blob/main/docs/ascend_tutorial/ascend_profiling_zh.rst)
+
+### 2.3 Run the analysis script
+
+```bash
+python cluster_analysis.py --input-path <profiling_data_dir> --output-path <output_dir>
+```
+
+## 3. Command-Line Arguments
+
+| Argument | Default | Description |
+|------|--------|------|
+| `--input-path` | `test` | Raw path of the profiling data |
+| `--profiler-type` | `mstx` | Profiler data type (currently only mstx is supported) |
+| `--data-type` | `text` | Profiling file format (currently only text is supported) |
+| `--output-path` | `test` | Output path |
+| `--vis-type` | `html` | Visualization type (currently only html is supported) |
+| `--rank-list` | `all` | Rank ID list (currently only "all" is supported) |
+
+### Example
+
+```bash
+# Basic usage
+python cluster_analysis.py --input-path ./data --output-path ./output
+```
+
+## 4. Output
+
+The tool writes an HTML file (default `rl_timeline.html`) to the specified output path, containing:
+
+- **Interactive Timeline Gantt chart**: shows the event distribution of every rank over time
+- **Hover information**: hovering over an event shows its details (name, start/end time, duration, etc.)
+- **Sorting**: supports the default order or sorting by Rank ID
+- **Zoom and navigation**: supports chart zooming and time-axis navigation
+
+### Chart Interactions
+
+1. **Hover mode toggle**:
+   - "Hover: Current Only" - show only the event currently hovered
+   - "Hover: All Ranks" - show the information of all ranks at the same time point
+
+2. **Y-axis sorting toggle**:
+   - "Sort: Default" - default order
+   - "Sort: By Rank ID" - sort by Rank ID
+
+3. **Export image**: click the camera icon in the upper-right corner to export a PNG image
+
+## 5. Notes
+
+1. The RL analysis currently only supports processing all ranks (the `--rank-list` argument does not yet support filtering)
+2. Profiling data must be collected at level0 or above
+3. Collect in discrete mode (`discrete=True`)
+4. MSTX data format requirements:
+   - The input path must contain `*_ascend_pt` directories
+   - Each ascend_pt directory must contain `profiler_info_*.json` files
+   - The trace_view.json file is located in the `ASCEND_PROFILER_OUTPUT` subdirectory
\ No newline at end of file
diff --git a/tests/tools/cluster_analysis/test_cluster_analysis.py b/tests/tools/cluster_analysis/test_cluster_analysis.py
new file mode 100644
index 00000000000..6b207064197
--- /dev/null
+++ b/tests/tools/cluster_analysis/test_cluster_analysis.py
@@ -0,0 +1,795 @@
+"""
+Integration tests for cluster_analysis module.
+
+Tests cover:
+- Parser registry and MstxClusterParser implementation
+- Visualizer registry and visualization functions
+- Full pipeline integration test
+"""
+
+import json
+import os
+from pathlib import Path
+from unittest.mock import MagicMock, Mock, patch
+
+import pandas as pd
+import pytest
+
+from verl.tools.cluster_analysis.cluster_analysis import main
+from verl.tools.cluster_analysis.mstx_parser import MstxClusterParser
+from verl.tools.cluster_analysis.parser import (
+    CLUSTER_PARSER_REGISTRY,
+    BaseClusterParser,
+    get_cluster_parser_cls,
+    register_cluster_parser,
+)
+from verl.tools.cluster_analysis.schema import Constant, DataMap, EventRow, FigureConfig
+from verl.tools.cluster_analysis.visualizer import (
+    CLUSTER_VISUALIZER_REGISTRY,
+    build_traces,
+    build_y_mappings,
+    downsample_if_needed,
+    generate_rl_timeline,
+    get_cluster_visualizer_fn,
+    load_and_preprocess,
+    merge_short_events,
+    register_cluster_visualizer,
+)
+
+
+# =============================================================================
+# Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def sample_trace_view_json_data():
+    """Sample trace_view.json data with Overlap Analysis process."""
+    return [
+        {
+            "ph": "M",
+            "pid": 12345,
+            "args": {"name": "Overlap Analysis"},
+        },
+        {
+            "ph": "X",
+            "pid": 12345,
+            "tid": 12345,
+            "ts": 1000000,  # 1 second in microseconds
+            "dur": 500000,  # 0.5 seconds in microseconds
+            "name": "overlap_event",
+            "args": {"category": "test"},
+        },
+    ]
+
+
+@pytest.fixture
+def mock_mstx_profiler_structure(tmp_path, sample_trace_view_json_data):
+    """
+    Create mock MSTX profiler directory structure.
+ + Structure: + tmp_path/ + └── rollout_generate/ + └── 20250101_120000_ascend_pt/ + ├── profiler_info_0.json + ├── profiler_info_1.json + ├── profiler_metadata.json + └── ASCEND_PROFILER_OUTPUT/ + └── trace_view.json + """ + role_dir = tmp_path / "rollout_generate" + role_dir.mkdir() + + timestamp_dir = role_dir / "20250101_120000_ascend_pt" + timestamp_dir.mkdir() + + # Create profiler_info_0.json + (timestamp_dir / "profiler_info_0.json").write_text('{"device": "npu:0"}') + + # Create profiler_info_1.json + (timestamp_dir / "profiler_info_1.json").write_text('{"device": "npu:1"}') + + # Create profiler_metadata.json + (timestamp_dir / "profiler_metadata.json").write_text( + json.dumps({"role": "rollout_generate", "device_type": "ascend"}) + ) + + # Create ASCEND_PROFILER_OUTPUT directory + ascend_output = timestamp_dir / "ASCEND_PROFILER_OUTPUT" + ascend_output.mkdir() + + # Create trace_view.json + (ascend_output / "trace_view.json").write_text(json.dumps(sample_trace_view_json_data)) + + return str(tmp_path) + + +@pytest.fixture +def sample_event_dataframe(): + """Create a sample DataFrame with event data.""" + return pd.DataFrame( + [ + { + "name": "generate_sequence", + "role": "rollout_generate", + "domain": "default", + "start_time_ms": 0.0, + "end_time_ms": 100.0, + "duration_ms": 100.0, + "rank_id": 0, + "tid": 12345, + }, + { + "name": "compute_log_prob", + "role": "actor_compute_log_prob", + "domain": "default", + "start_time_ms": 100.0, + "end_time_ms": 200.0, + "duration_ms": 100.0, + "rank_id": 1, + "tid": 12346, + }, + { + "name": "generate_sequence", + "role": "rollout_generate", + "domain": "default", + "start_time_ms": 200.0, + "end_time_ms": 350.0, + "duration_ms": 150.0, + "rank_id": 0, + "tid": 12345, + }, + ] + ) + + +@pytest.fixture +def sample_large_dataframe(): + """Create a large DataFrame for testing downsampling.""" + data = [] + for i in range(6000): + data.append({ + "name": f"event_{i % 100}", + "role": f"role_{i % 10}", + "domain": "default", + "start_time_ms": float(i), + "end_time_ms": float(i + 1), + "duration_ms": 1.0, + "rank_id": i % 5, + "tid": 12345, + }) + return pd.DataFrame(data) + + +# ============================================================================= +# Parser Registry Tests +# ============================================================================= + + +class TestParserRegistry: + """Tests for parser registry functionality.""" + + def test_register_cluster_parser(self): + """Test registering a custom parser.""" + + @register_cluster_parser("test_parser") + class TestParser(BaseClusterParser): + def allocate_prof_data(self, input_path: str) -> list[DataMap]: + return [] + + def parse_analysis_data( + self, profiler_data_path: str, rank_id: int, role: str + ) -> list[EventRow]: + return [] + + assert "test_parser" in CLUSTER_PARSER_REGISTRY + assert CLUSTER_PARSER_REGISTRY["test_parser"] == TestParser + + # Cleanup + del CLUSTER_PARSER_REGISTRY["test_parser"] + + def test_get_cluster_parser_cls_success(self): + """Test getting a registered parser class.""" + parser_cls = get_cluster_parser_cls("mstx") + assert parser_cls == MstxClusterParser + + def test_get_cluster_parser_cls_failure(self): + """Test getting an unregistered parser raises ValueError.""" + with pytest.raises(ValueError, match="Unsupported cluster parser: unknown"): + get_cluster_parser_cls("unknown") + + +# ============================================================================= +# MstxClusterParser Tests +# 
============================================================================= + + +class TestMstxClusterParser: + """Tests for MstxClusterParser implementation.""" + + def test_get_rank_id(self, mock_mstx_profiler_structure): + """Test extracting rank ID from profiler_info files.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: mock_mstx_profiler_structure, + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + timestamp_dir = Path(mock_mstx_profiler_structure) / "rollout_generate" / "20250101_120000_ascend_pt" + rank_id = parser._get_rank_id(str(timestamp_dir)) + + assert rank_id == 0 + + def test_get_rank_id_invalid(self, tmp_path): + """Test extracting rank ID from directory without profiler_info files.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: str(tmp_path), + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + # Create a directory without profiler_info files + empty_dir = tmp_path / "empty_dir" + empty_dir.mkdir() + + rank_id = parser._get_rank_id(str(empty_dir)) + assert rank_id == -1 + + def test_get_task_role(self, mock_mstx_profiler_structure): + """Test extracting role from profiler_metadata.json.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: mock_mstx_profiler_structure, + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + timestamp_dir = Path(mock_mstx_profiler_structure) / "rollout_generate" / "20250101_120000_ascend_pt" + role = parser._get_task_role(str(timestamp_dir)) + + assert role == "rollout_generate" + + def test_get_task_role_no_metadata(self, tmp_path): + """Test extracting role from directory without profiler_metadata.json.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: str(tmp_path), + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + # Create a directory without profiler_metadata.json + empty_dir = tmp_path / "empty_dir" + empty_dir.mkdir() + + role = parser._get_task_role(str(empty_dir)) + assert role is None + + def test_get_profiler_data_path(self, mock_mstx_profiler_structure): + """Test building profiler data path for text type.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: mock_mstx_profiler_structure, + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + data_path = parser._get_profiler_data_path(0, mock_mstx_profiler_structure) + + expected = os.path.join(mock_mstx_profiler_structure, Constant.ASCEND_PROFILER_OUTPUT, "trace_view.json") + assert data_path == expected + + def test_get_profiler_data_path_unsupported_type(self): + """Test building profiler data path for unsupported data type.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: "/tmp", + Constant.DATA_TYPE: "unsupported", + Constant.RANK_LIST: "all", + }) + + with pytest.raises(ValueError, match="Unsupported data type: unsupported"): + parser._get_profiler_data_path(0, "/tmp") + + def test_allocate_prof_data(self, mock_mstx_profiler_structure): + """Test allocating profiler data from directory structure.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: mock_mstx_profiler_structure, + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + data_maps = parser.allocate_prof_data(mock_mstx_profiler_structure) + + assert len(data_maps) == 1 + assert data_maps[0]["rank_id"] == 0 + assert data_maps[0]["role"] == "rollout_generate" + assert "ASCEND_PROFILER_OUTPUT" in data_maps[0]["profiler_data_path"] + + def test_parse_analysis_data(self, mock_mstx_profiler_structure): + """Test parsing analysis data from 
trace_view.json.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: mock_mstx_profiler_structure, + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + timestamp_dir = Path(mock_mstx_profiler_structure) / "rollout_generate" / "20250101_120000_ascend_pt" + profiler_data_path = os.path.join(str(timestamp_dir), Constant.ASCEND_PROFILER_OUTPUT, "trace_view.json") + + events = parser.parse_analysis_data(profiler_data_path, 0, "rollout_generate") + + assert len(events) == 1 + assert events[0]["name"] == "rollout_generate" + assert events[0]["role"] == "rollout_generate" + assert events[0]["domain"] == "default" + assert events[0]["rank_id"] == 0 + assert events[0]["start_time_ms"] == pytest.approx(1000.0) # 1 second in microseconds / 1000 + assert events[0]["end_time_ms"] == pytest.approx(1500.0) # 1.5 seconds + assert events[0]["duration_ms"] == pytest.approx(500.0) # 0.5 seconds + + def test_parse_analysis_data_no_overlap_analysis(self, tmp_path): + """Test parsing analysis data without Overlap Analysis process.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: str(tmp_path), + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + # Create a trace_view.json without Overlap Analysis + trace_data = [ + { + "ph": "X", + "pid": 12345, + "tid": 12345, + "ts": 1000000, + "dur": 500000, + "name": "other_event", + "args": {"category": "test"}, + } + ] + + trace_file = tmp_path / "trace_view.json" + trace_file.write_text(json.dumps(trace_data)) + + events = parser.parse_analysis_data(str(trace_file), 0, "test_role") + + assert len(events) == 0 + + def test_get_rank_path_with_role_all(self, mock_mstx_profiler_structure): + """Test getting rank paths with role for all ranks.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: mock_mstx_profiler_structure, + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + # Build data_map manually + timestamp_dir = Path(mock_mstx_profiler_structure) / "rollout_generate" / "20250101_120000_ascend_pt" + data_map = {("rollout_generate", 0): [str(timestamp_dir)]} + + data_paths = parser._get_rank_path_with_role(data_map) + + assert len(data_paths) == 1 + assert data_paths[0]["rank_id"] == 0 + assert data_paths[0]["role"] == "rollout_generate" + assert "trace_view.json" in data_paths[0]["profiler_data_path"] + + def test_get_rank_path_with_role_specific(self): + """Test getting rank paths with specific rank list (should return empty).""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: "/tmp", + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "0,1", + }) + + data_paths = parser._get_rank_path_with_role({}) + + assert len(data_paths) == 0 + + def test_get_data_map(self, mock_mstx_profiler_structure): + """Test building data map from path list.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: mock_mstx_profiler_structure, + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + timestamp_dir = Path(mock_mstx_profiler_structure) / "rollout_generate" / "20250101_120000_ascend_pt" + path_list = [ + {"role": "rollout_generate", "path": str(timestamp_dir)} + ] + + data_map = parser._get_data_map(path_list) + + assert ("rollout_generate", 0) in data_map + assert len(data_map[("rollout_generate", 0)]) == 1 + + +# ============================================================================= +# BaseClusterParser Tests +# ============================================================================= + + +class TestBaseClusterParser: + """Tests for BaseClusterParser 
functionality.""" + + def test_reducer_func(self): + """Test reducer function aggregates mapper results.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: "/tmp", + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + # Mock mapper results + mapper_res = [ + [ + {"name": "event1", "role": "role1", "domain": "default", "start_time_ms": 100.0, "end_time_ms": 200.0, "duration_ms": 100.0, "rank_id": 0, "tid": 1}, + {"name": "event2", "role": "role1", "domain": "default", "start_time_ms": 50.0, "end_time_ms": 150.0, "duration_ms": 100.0, "rank_id": 0, "tid": 1}, + ], + [ + {"name": "event3", "role": "role2", "domain": "default", "start_time_ms": 200.0, "end_time_ms": 300.0, "duration_ms": 100.0, "rank_id": 1, "tid": 2}, + ], + ] + + parser.reducer_func(mapper_res) + + df = parser.get_data() + assert df is not None + assert len(df) == 3 + # Check sorted by start_time_ms + assert df.iloc[0]["start_time_ms"] == 50.0 + assert df.iloc[1]["start_time_ms"] == 100.0 + assert df.iloc[2]["start_time_ms"] == 200.0 + + def test_reducer_func_empty_results(self): + """Test reducer function with empty results.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: "/tmp", + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + parser.reducer_func([]) + + df = parser.get_data() + assert df is None + + def test_reducer_func_with_none_results(self): + """Test reducer function with None results.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: "/tmp", + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + mapper_res = [None, [], [{"name": "event1", "role": "role1", "domain": "default", "start_time_ms": 100.0, "end_time_ms": 200.0, "duration_ms": 100.0, "rank_id": 0, "tid": 1}]] + + parser.reducer_func(mapper_res) + + df = parser.get_data() + assert df is not None + assert len(df) == 1 + + def test_mapper_func_mock(self, mock_mstx_profiler_structure): + """Test mapper function with mocked ProcessPoolExecutor.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: mock_mstx_profiler_structure, + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + data_maps = parser.allocate_prof_data(mock_mstx_profiler_structure) + + with patch.object(parser, "_mapper_func", wraps=parser._mapper_func) as mock_mapper: + mock_mapper.return_value = [ + {"name": "event1", "role": "rollout_generate", "domain": "default", "start_time_ms": 1000.0, "end_time_ms": 1500.0, "duration_ms": 500.0, "rank_id": 0, "tid": 12345} + ] + + results = parser.mapper_func(data_maps) + + assert len(results) == 1 + + def test_mapper_func_empty_data_maps(self): + """Test mapper function with empty data maps.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: "/tmp", + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + results = parser.mapper_func([]) + + assert results == [] + + def test_mapper_func_missing_profiler_data_path(self): + """Test mapper function with missing profiler data path.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: "/tmp", + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + data_maps = [ + {Constant.RANK_ID: 0, Constant.ROLE: "role1", Constant.PROFILER_DATA_PATH: ""} + ] + + result = parser._mapper_func(data_maps[0]) + + assert result is None + + def test_parse_full_pipeline(self, mock_mstx_profiler_structure): + """Test full parse pipeline with mock multiprocessing.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: mock_mstx_profiler_structure, + Constant.DATA_TYPE: 
Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + with patch("concurrent.futures.ProcessPoolExecutor"): + df = parser.parse() + + assert df is not None + assert len(df) >= 1 + + def test_clean_data(self): + """Test cleaning data.""" + parser = MstxClusterParser({ + Constant.INPUT_PATH: "/tmp", + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + # Set some dummy data + mapper_res = [ + [{"name": "event1", "role": "role1", "domain": "default", "start_time_ms": 100.0, "end_time_ms": 200.0, "duration_ms": 100.0, "rank_id": 0, "tid": 1}] + ] + parser.reducer_func(mapper_res) + + assert parser.get_data() is not None + + parser.clean_data() + + assert parser.get_data() is None + + +# ============================================================================= +# Visualizer Registry Tests +# ============================================================================= + + +class TestVisualizerRegistry: + """Tests for visualizer registry functionality.""" + + def test_register_cluster_visualizer(self): + """Test registering a custom visualizer.""" + + @register_cluster_visualizer("test_visualizer") + def test_visualizer(data, output_path, config): + pass + + assert "test_visualizer" in CLUSTER_VISUALIZER_REGISTRY + assert CLUSTER_VISUALIZER_REGISTRY["test_visualizer"] == test_visualizer + + # Cleanup + del CLUSTER_VISUALIZER_REGISTRY["test_visualizer"] + + def test_get_cluster_visualizer_fn_success(self): + """Test getting a registered visualizer function.""" + visualizer_fn = get_cluster_visualizer_fn("html") + assert callable(visualizer_fn) + + def test_get_cluster_visualizer_fn_failure(self): + """Test getting an unregistered visualizer raises ValueError.""" + with pytest.raises(ValueError, match="Unsupported cluster visualizer: unknown"): + get_cluster_visualizer_fn("unknown") + + +# ============================================================================= +# Visualizer Tests +# ============================================================================= + + +class TestVisualizerFunctions: + """Tests for visualizer utility functions.""" + + def test_load_and_preprocess_valid_data(self, sample_event_dataframe): + """Test load_and_preprocess with valid DataFrame.""" + df, t0 = load_and_preprocess(sample_event_dataframe) + + assert df is not None + assert "Role" in df.columns + assert "Name" in df.columns + assert "Rank ID" in df.columns + assert "Start" in df.columns + assert "Finish" in df.columns + assert "Duration" in df.columns + assert t0 == 0.0 + assert df["Start"].min() == 0.0 # Relative time + + def test_load_and_preprocess_none_input(self): + """Test load_and_preprocess with None input.""" + with pytest.raises(ValueError, match="input_data: None is None"): + load_and_preprocess(None) + + def test_load_and_preprocess_missing_columns(self): + """Test load_and_preprocess with missing required columns.""" + df_invalid = pd.DataFrame({"role": ["test"], "name": ["test"]}) + + with pytest.raises(ValueError, match="Required column missing"): + load_and_preprocess(df_invalid) + + def test_load_and_preprocess_t0_offset(self): + """Test load_and_preprocess calculates correct t0 offset.""" + df_data = pd.DataFrame([ + {"role": "test", "name": "event1", "rank_id": 0, "start_time_ms": 1000.0, "end_time_ms": 1100.0, "duration_ms": 100.0, "tid": 1}, + {"role": "test", "name": "event2", "rank_id": 0, "start_time_ms": 1100.0, "end_time_ms": 1300.0, "duration_ms": 200.0, "tid": 1}, + ]) + + df, t0 = load_and_preprocess(df_data) + + assert t0 == 1000.0 + assert 
df["Start"].min() == 0.0
+
+    def test_merge_short_events(self):
+        """Test merging events shorter than threshold."""
+        df_data = pd.DataFrame([
+            {"Role": "role1", "Rank ID": 0, "Name": "event1", "Start": 0.0, "Finish": 5.0, "Duration": 5.0, "Y_Label": "role1 - Rank 0"},
+            {"Role": "role1", "Rank ID": 0, "Name": "event1", "Start": 10.0, "Finish": 15.0, "Duration": 5.0, "Y_Label": "role1 - Rank 0"},
+            {"Role": "role1", "Rank ID": 0, "Name": "event1", "Start": 20.0, "Finish": 40.0, "Duration": 20.0, "Y_Label": "role1 - Rank 0"},
+        ])
+
+        df_merged = merge_short_events(df_data, threshold_ms=10.0)
+
+        # Should merge the two short events into one
+        assert len(df_merged) == 2
+        # One long event and one merged event
+        assert df_merged.iloc[0]["Duration"] == 20.0  # Long event unchanged
+        assert df_merged.iloc[1]["Duration"] == 15.0  # Merged span: 15.0 - 0.0
+
+    def test_merge_short_events_no_short(self):
+        """Test merging events when all are longer than threshold."""
+        df_data = pd.DataFrame([
+            {"Role": "role1", "Rank ID": 0, "Name": "event1", "Start": 0.0, "Finish": 20.0, "Duration": 20.0, "Y_Label": "role1 - Rank 0"},
+            {"Role": "role1", "Rank ID": 0, "Name": "event1", "Start": 30.0, "Finish": 50.0, "Duration": 20.0, "Y_Label": "role1 - Rank 0"},
+        ])
+
+        df_merged = merge_short_events(df_data, threshold_ms=10.0)
+
+        # Should not merge anything
+        assert len(df_merged) == 2
+
+    def test_downsample_if_needed_small_df(self, sample_event_dataframe):
+        """Test downsampling with small DataFrame (no downsampling)."""
+        df_downsampled = downsample_if_needed(sample_event_dataframe)
+
+        # Should not downsample
+        assert len(df_downsampled) == len(sample_event_dataframe)
+
+    def test_downsample_if_needed_large_df(self, sample_large_dataframe):
+        """Test downsampling with large DataFrame."""
+        df_downsampled = downsample_if_needed(sample_large_dataframe, max_records=5000)
+
+        # Should downsample to at most max_records
+        assert len(df_downsampled) <= 5000
+
+    def test_build_y_mappings(self, sample_event_dataframe):
+        """Test building Y-axis mappings."""
+        df_processed, _ = load_and_preprocess(sample_event_dataframe)
+        y_mappings, spacing = build_y_mappings(df_processed)
+
+        assert "default" in y_mappings
+        assert "by_rank" in y_mappings
+        assert "bar_height" in y_mappings
+        assert spacing > 0
+        assert y_mappings["bar_height"] > 0
+
+    def test_build_traces(self, sample_event_dataframe):
+        """Test building Plotly traces."""
+        df_processed, _ = load_and_preprocess(sample_event_dataframe)
+        y_mappings, _ = build_y_mappings(df_processed)
+
+        traces = build_traces(df_processed, y_mappings)
+
+        assert len(traces) > 0
+        # Each trace should be a Plotly Bar object
+        assert all(hasattr(trace, "base") for trace in traces)
+
+    @patch("verl.tools.cluster_analysis.visualizer.go.Figure")
+    @patch("verl.tools.cluster_analysis.visualizer.save_html")
+    def test_generate_rl_timeline(self, mock_save_html, mock_figure, sample_event_dataframe):
+        """Test generating RL timeline."""
+        mock_fig = MagicMock()
+        mock_figure.return_value = mock_fig
+
+        result = generate_rl_timeline(sample_event_dataframe, "/tmp/output")
+
+        # Should call save_html
+        mock_save_html.assert_called_once()
+        # Should return the figure
+        assert result == mock_fig
+
+    def test_load_and_preprocess_empty_df(self):
+        """Test load_and_preprocess with empty DataFrame."""
+        df_empty = pd.DataFrame(columns=["role", "name", "rank_id", "start_time_ms", "end_time_ms"])
+
+        df, t0 = load_and_preprocess(df_empty)
+
+        assert df.empty
+        
assert t0 == 0.0 + + def test_load_and_preprocess_invalid_finish_time(self): + """Test load_and_preprocess filters invalid finish times.""" + df_data = pd.DataFrame([ + {"role": "test", "name": "event1", "rank_id": 0, "start_time_ms": 100.0, "end_time_ms": 50.0, "duration_ms": -50.0, "tid": 1}, + {"role": "test", "name": "event2", "rank_id": 0, "start_time_ms": 100.0, "end_time_ms": 200.0, "duration_ms": 100.0, "tid": 1}, + ]) + + df, t0 = load_and_preprocess(df_data) + + # Should filter out the invalid event + assert len(df) == 1 + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +class TestIntegration: + """Integration tests for full pipeline.""" + + def test_full_pipeline_with_mock_data(self, mock_mstx_profiler_structure, tmp_path): + """Test full pipeline from parsing to visualization.""" + # Parse data + parser = MstxClusterParser({ + Constant.INPUT_PATH: mock_mstx_profiler_structure, + Constant.DATA_TYPE: Constant.TEXT, + Constant.RANK_LIST: "all", + }) + + with patch("concurrent.futures.ProcessPoolExecutor"): + df = parser.parse() + + assert df is not None + assert len(df) >= 1 + + # Visualize data + output_dir = str(tmp_path / "output") + + with patch("verl.tools.cluster_analysis.visualizer.go.Figure") as mock_figure: + mock_fig = MagicMock() + mock_figure.return_value = mock_fig + + with patch("verl.tools.cluster_analysis.visualizer.save_html"): + generate_rl_timeline(df, output_dir) + + # Verify figure was created + mock_figure.assert_called_once() + + @patch("sys.argv", ["cluster_analysis.py", "--input-path", "/tmp", "--profiler-type", "mstx"]) + @patch("verl.tools.cluster_analysis.cluster_analysis.get_cluster_parser_cls") + @patch("verl.tools.cluster_analysis.cluster_analysis.get_cluster_visualizer_fn") + def test_main_function(self, mock_get_visualizer, mock_get_parser, mock_mstx_profiler_structure): + """Test main CLI entry point.""" + # Mock parser + mock_parser = MagicMock() + mock_parser_instance = MagicMock() + mock_parser_instance.parse.return_value = pd.DataFrame([ + {"role": "test", "name": "event1", "rank_id": 0, "start_time_ms": 100.0, "end_time_ms": 200.0, "duration_ms": 100.0, "tid": 1} + ]) + mock_parser.return_value = mock_parser_instance + mock_get_parser.return_value = mock_parser + + # Mock visualizer + mock_visualizer = MagicMock() + mock_get_visualizer.return_value = mock_visualizer + + # Run main + main() + + # Verify parser was called + mock_get_parser.assert_called_with("mstx") + mock_parser_instance.parse.assert_called_once() + + # Verify visualizer was called + mock_get_visualizer.assert_called_with("html") + mock_visualizer.assert_called_once() \ No newline at end of file diff --git a/verl/tools/cluster_analysis/cluster_analysis.py b/verl/tools/cluster_analysis/cluster_analysis.py new file mode 100644 index 00000000000..ceab630feb6 --- /dev/null +++ b/verl/tools/cluster_analysis/cluster_analysis.py @@ -0,0 +1,38 @@ +import argparse + +import mstx_parser # register built-in parsers via decorators +from parser import get_cluster_parser_cls +from schema import Constant +from visualizer import get_cluster_visualizer_fn + + +def main(): + arg_parser = argparse.ArgumentParser(description="Cluster scheduling visualization") + arg_parser.add_argument("--input-path", default="test", help="Raw path of profiling data") + arg_parser.add_argument("--profiler-type", default="mstx", help="Profiler type") + 
arg_parser.add_argument("--data-type", default="text", help="Profiling file format") + arg_parser.add_argument("--output-path", default="test", help="Output path") + arg_parser.add_argument("--vis-type", default="html", help="Visualization type") + arg_parser.add_argument("--rank-list", type=str, help="Rank id list", default="all") + args = arg_parser.parse_args() + + # Prepare parser configuration + parser_config = { + Constant.INPUT_PATH: args.input_path, + Constant.DATA_TYPE: args.data_type, # Default to TEXT type + Constant.RANK_LIST: args.rank_list, + } + visualizer_config = {} + + # Get and call parser + parser_cls = get_cluster_parser_cls(args.profiler_type) + parser = parser_cls(parser_config) + data = parser.parse() + + # Get and Call visualizer + visualizer_fn = get_cluster_visualizer_fn(args.vis_type) + visualizer_fn(data, args.output_path, visualizer_config) + + +if __name__ == "__main__": + main() diff --git a/verl/tools/cluster_analysis/mstx_parser.py b/verl/tools/cluster_analysis/mstx_parser.py new file mode 100644 index 00000000000..7da1f033bb0 --- /dev/null +++ b/verl/tools/cluster_analysis/mstx_parser.py @@ -0,0 +1,252 @@ +from collections import defaultdict +from pathlib import Path +import json +import logging +import os +from schema import Constant, DataMap, EventRow +from parser import BaseClusterParser, register_cluster_parser + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], +) +logger = logging.getLogger(__name__) + +@register_cluster_parser("mstx") +class MstxClusterParser(BaseClusterParser): + + def __init__(self, params) -> None: + super().__init__(params) + + # TODO: Future support for parsing with MSTX events + def _parse_rl_mstx_event(self, profiler_data_path: str, rank_id: int, role: str) -> list[EventRow]: + """Parse MSTX json and return rows whose args contain event_type and domain as a DataFrame. + + Args: + profiler_data_path: Path to the MSTX json file. + rank_id: Rank id to attach to each row. + role: Role string to attach to each row. + """ + data: list[dict] = [] + events: list[dict] = [] + + with open(profiler_data_path, encoding="utf-8") as f: + data = json.load(f) + + if data is None or not data: + logger.warning(f"Rank {rank_id}: No MSTX events found in json") + return events + + for row in data: + args = row.get("args") + if not isinstance(args, dict): + continue + if "event_type" not in args or "domain" not in args: + continue + # Convert to milliseconds + us_to_ms = Constant.US_TO_MS + + # Validate required fields exist + if "ts" not in row or "dur" not in row: + logger.warning("Row missing required fields: ts or dur. Skipping row.") + continue + + try: + # Convert to float and calculate millisecond values + start_time_ms = float(row["ts"]) / us_to_ms + duration_ms = float(row["dur"]) / us_to_ms + end_time_ms = start_time_ms + duration_ms + + except (ValueError, TypeError) as e: + logger.warning(f"Failed to convert time values: {e}. Row data: {row}. 
Skipping row.") + continue + + event_data = { + "name": row["name"], + "role": role, + "domain": args["domain"], + "start_time_ms": start_time_ms, + "end_time_ms": end_time_ms, + "duration_ms": duration_ms, + "rank_id": rank_id, + "tid": row["tid"], + } + + events.append(event_data) + + return events + + def parse_analysis_data(self, profiler_data_path: str, rank_id: int, role: str) -> list[EventRow]: + data: list[dict] = [] + events: list[EventRow] = [] + + with open(profiler_data_path, encoding="utf-8") as f: + data = json.load(f) + + if data is None or not data: + logger.warning(f"Rank {rank_id}: No rollout events found in json") + return events + + process_id = None + start_ids = None + end_ids = None + for row in data: + if row.get("ph") == "M" and row.get("args").get("name") == "Overlap Analysis": + process_id = row.get("pid") + break + + if process_id is None: + logger.warning(f"Rank {rank_id}: Overlap Analysis process not found in json") + return events + + for row in data: + if row.get("pid") != process_id: + continue + + args = row.get("args") + if not isinstance(args, dict): + continue + + # Validate required fields exist + if "ts" not in row or "dur" not in row: + logger.warning("Row missing required fields: ts or dur. Skipping row.") + continue + + try: + # Convert to float and calculate millisecond values + start_time_ns = float(row["ts"]) + duration_ns = float(row["dur"]) + end_time_ns = start_time_ns + duration_ns + + if start_ids is None or start_time_ns < start_ids: + start_ids = start_time_ns + if end_ids is None or end_time_ns > end_ids: + end_ids = end_time_ns + + except (ValueError, TypeError) as e: + logger.warning(f"Failed to convert time values: {e}. Row data: {row}. Skipping row.") + continue + + if start_ids is None or end_ids is None: + logger.warning(f"Rank {rank_id}: No valid timing rows for Overlap Analysis") + return events + + # Convert to milliseconds + us_to_ms = Constant.US_TO_MS + start_time_ms = start_ids / us_to_ms + duration_ms = (end_ids - start_ids) / us_to_ms + end_time_ms = start_time_ms + duration_ms + + event_data = { + "name": role, + "role": role, + "domain": "default", + "start_time_ms": start_time_ms, + "end_time_ms": end_time_ms, + "duration_ms": duration_ms, + "rank_id": rank_id, + "tid": process_id, + } + + events.append(event_data) + + return events + + def allocate_prof_data(self, input_path: str) -> list[DataMap]: + """Allocate and process profiling data maps from input path.""" + ascend_pt_dirs = [] + for root, dirs, _ in os.walk(input_path): + for dir_name in dirs: + if dir_name.endswith(Constant.ASCEND_PROFILER_SUFFIX): + path = os.path.join(root, dir_name) + ascend_pt_dirs.append({"role": Path(path).parent.name, "path": path}) + data_map = self._get_data_map(ascend_pt_dirs) + data_maps = self._get_rank_path_with_role(data_map) + return data_maps + + def _get_profiler_data_path(self, rank_id, data_path): + if self._data_type == Constant.TEXT: + return os.path.join(data_path, Constant.ASCEND_PROFILER_OUTPUT, "trace_view.json") + else: + raise ValueError(f"Unsupported data type: {self._data_type}. Supported type are: ['text']") + + def _get_rank_path_with_role(self, data_map) -> list[DataMap]: + """Get json path information for all ranks. + + This function is intentionally decoupled from class state; pass required + dependencies in via arguments. 
+ """ + + if self._rank_list != "all": + logger.error("RL analysis currently only supports processing all ranks") + return [] + + rank_ids_with_role= list(data_map.keys()) + data_paths: list[dict] = [] + for task_role, rank_id in rank_ids_with_role: + rank_path_list = data_map[(task_role, rank_id)] + profiler_data_path_list = [self._get_profiler_data_path(rank_id, rank_path) for rank_path in rank_path_list] + for profiler_data_path in profiler_data_path_list: + data_path_dict = { + Constant.RANK_ID: rank_id, + Constant.ROLE: task_role, + Constant.PROFILER_DATA_PATH: "", + } + + if os.path.exists(profiler_data_path): + data_path_dict[Constant.PROFILER_DATA_PATH] = profiler_data_path + data_paths.append(data_path_dict) + else: + logger.warning( + f"Profiler data file not found, rank id: {rank_id}, data path: {profiler_data_path}." + ) + return data_paths + + def _get_data_map(self, path_list): + data_map = {} + rank_id_map = defaultdict(list) + for path_info in path_list: + role = path_info.get("role") + dir_name = path_info.get("path") + rank_id = self._get_rank_id(dir_name) + task_role = self._get_task_role(dir_name) + if task_role is None: + task_role = role + if rank_id < 0: + logger.error(f"direct:{dir_name} fail to get rankid or rankid invalid.") + continue + # For RL Analysis + rank_id_map[(task_role, rank_id)].append(dir_name) + try: + for map_key, dir_list in rank_id_map.items(): + dir_list.sort(key=lambda x: x.split("_")[-3]) + data_map[map_key] = dir_list + except Exception as e: + raise RuntimeError("Found invalid directory name!") from e + return data_map + + def _get_rank_id(self, dir_name: str): + files = os.listdir(dir_name) + for file_name in files: + if file_name.startswith(Constant.ASCEND_PROFILER_INFO_HEAD) and file_name.endswith( + Constant.JSON_EXTENSION + ): + rank_id_str = file_name[len(Constant.ASCEND_PROFILER_INFO_HEAD) : -1 * len(Constant.JSON_EXTENSION)] + try: + rank_id = int(rank_id_str) + except ValueError: + rank_id = -1 + return rank_id + return -1 + + def _get_task_role(self, dir_name: str): + files = os.listdir(dir_name) + for file_name in files: + if file_name == Constant.ASCEND_PROFILER_METADATA_JSON: + with open(os.path.join(dir_name, file_name), encoding="utf-8") as f: + config = json.load(f) + task_role = config.get("role") + if task_role: + return task_role + return None diff --git a/verl/tools/cluster_analysis/parser.py b/verl/tools/cluster_analysis/parser.py new file mode 100644 index 00000000000..6750492264d --- /dev/null +++ b/verl/tools/cluster_analysis/parser.py @@ -0,0 +1,184 @@ +import logging +import multiprocessing +from abc import ABC, abstractmethod +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import Callable, Optional + +import pandas as pd +from schema import Constant, DataMap, EventRow + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], +) +logger = logging.getLogger(__name__) + + +class BaseClusterParser(ABC): + def __init__(self, params) -> None: + self.events_summary: Optional[pd.DataFrame] = None + self.input_path = params.get(Constant.INPUT_PATH, "") + self._data_type = params.get(Constant.DATA_TYPE, {}) + rank_list = params.get(Constant.RANK_LIST, "all") + self._rank_list = ( + rank_list if rank_list == "all" else [int(rank) for rank in rank_list.split(",") if rank.isdigit()] + ) + + def parse(self) -> pd.DataFrame: + """Run parsing and return the parsed DataFrame.""" + _data_maps = 
self.allocate_prof_data(self.input_path) + mapper_res = self.mapper_func(_data_maps) + self.reducer_func(mapper_res) + return self.get_data() + + def mapper_func(self, data_maps: list[DataMap]): + if not data_maps: + logger.info("No data maps to process") + return [] + + total_ranks = len(data_maps) + max_workers = min(total_ranks, multiprocessing.cpu_count()) + logger.info(f"Starting parallel processing: {total_ranks} ranks with {max_workers} workers") + + results = [] + completed = 0 + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks + future_to_rank = { + executor.submit(self._mapper_func, data_map): data_map[Constant.RANK_ID] for data_map in data_maps + } + + for future in as_completed(future_to_rank): + rank_id = future_to_rank[future] + completed += 1 + progress = (completed / total_ranks) * 100 + try: + result = future.result() + results.append(result) + logger.info(f"Completed rank {rank_id}: {completed}/{total_ranks} ({progress:.1f}%)") + except Exception as e: + logger.error(f"Failed to process rank {rank_id}: {e}") + + logger.info(f"Parallel processing completed: {completed}/{total_ranks} ranks processed") + return results + + def _mapper_func(self, data_map: DataMap) -> list[EventRow]: + """Collect RL performance data from a single rank""" + profiler_data_path = data_map.get(Constant.PROFILER_DATA_PATH) + rank_id = data_map.get(Constant.RANK_ID) + role = data_map.get(Constant.ROLE) + + if not profiler_data_path: + logger.warning(f"Rank {rank_id}: profiler_data_path not found") + return None + + return self.parse_analysis_data(profiler_data_path, rank_id, role) + + def reducer_func(self, mapper_res): + """Process data collected from all ranks""" + # Flatten valid results from all ranks + reduce_results: list[dict] = [] + for result in mapper_res: + if not result: + continue + if isinstance(result, list): + reduce_results.extend(result) + else: + raise TypeError(f"parse_analysis_data must return list[dict] or None, got {type(result)}") + + if not reduce_results: + logger.warning("No valid data collected from any rank") + return + + reduce_results.sort(key=lambda x: x["start_time_ms"]) + self.events_summary = pd.DataFrame(reduce_results) + + def clean_data(self) -> None: + self.events_summary = None + + def get_data(self) -> pd.DataFrame: + return self.events_summary + + @abstractmethod + def allocate_prof_data(self, input_path: str) -> list[DataMap]: + """ + Allocate and organize profiling data from the input path. + + This method is responsible for: + 1. Scanning the input directory for profiling data files + 2. Identifying ranks and their corresponding profiler data paths + 3. 
Returning a list of DataMap objects that map each rank to its data + + Args: + input_path: Root directory path containing profiling data + + Returns: + list[DataMap]: A list of dictionaries, where each dict contains: + - rank_id (int): The rank identifier + - role (str): The RL role name (e.g., 'rollout_generate', 'actor_compute_log_prob') + - profiler_data_path (str): Path to the profiler data file for this rank + + Important: + - Must return a list, even if empty + - Each DataMap must contain all three required keys: 'rank_id', 'role', 'profiler_data_path' + - profiler_data_path should point to an existing file; empty string indicates missing data + - The returned list is used by mapper_func for parallel processing + """ + raise NotImplementedError + + @abstractmethod + def parse_analysis_data(self, profiler_data_path: str, rank_id: int, role: str) -> list[EventRow]: + """ + Parse profiling data for a specific rank and return event information. + + This method is responsible for: + 1. Reading the profiler data file (JSON, DB, etc.) + 2. Extracting timing events with their metadata + 3. Converting time units to milliseconds (start_time_ms, end_time_ms, duration_ms) + + Args: + profiler_data_path: Path to the profiler data file for this rank + rank_id: The rank identifier (for logging and data attribution) + role: The RL role name (e.g., 'rollout_generate', 'actor_compute_log_prob') + + Returns: + list[EventRow]: A list of event dictionaries, where each dict contains: + - name (str): Event name (e.g., 'generate_sequence', 'compute_log_prob') + - role (str): The RL role name (same as input parameter) + - domain (str): Event domain (e.g., 'default', 'communication_group') + - start_time_ms (float): Event start time in milliseconds + - end_time_ms (float): Event end time in milliseconds + - duration_ms (float): Event duration in milliseconds (end_time_ms - start_time_ms) + - rank_id (int): The rank identifier (same as input parameter) + - tid (int | str): Thread ID or process ID + + Important: + - Must return a list, even if empty (no events found) + - All time values must be in milliseconds (ms) + - Must satisfy: end_time_ms > start_time_ms > 0 + - Must satisfy: duration_ms = end_time_ms - start_time_ms + - The returned list is aggregated across all ranks and sorted by start_time_ms + - If profiler_data_path is invalid or no events found, return an empty list + """ + raise NotImplementedError + + +CLUSTER_PARSER_REGISTRY: dict[str, type[BaseClusterParser]] = {} + + +def register_cluster_parser(name: str) -> Callable[type[BaseClusterParser], type[BaseClusterParser]]: + def decorator(cls: type[BaseClusterParser]) -> type[BaseClusterParser]: + CLUSTER_PARSER_REGISTRY[name] = cls + return cls + + return decorator + + +def get_cluster_parser_cls(name): + if name not in CLUSTER_PARSER_REGISTRY: + raise ValueError( + f"Unsupported cluster parser: {name}. 
Supported cls are: {list(CLUSTER_PARSER_REGISTRY.keys())}" + ) + return CLUSTER_PARSER_REGISTRY[name] diff --git a/verl/tools/cluster_analysis/schema.py b/verl/tools/cluster_analysis/schema.py new file mode 100644 index 00000000000..9ca6b990e60 --- /dev/null +++ b/verl/tools/cluster_analysis/schema.py @@ -0,0 +1,66 @@ +from logging import RootLogger +import os +import stat +from dataclasses import dataclass +from typing import TypedDict + +class DataMap(TypedDict): + rank_id: int + role: str + profiler_data_path: str + +class EventRow(TypedDict): + name: str + role: str + domain: str + start_time_ms: float + end_time_ms: float + duration_ms: float + rank_id: int + tid: int | str + +@dataclass +class FigureConfig: + title_prefix: str + t0: float + y_mappings: dict + y_axis_spacing: int = 60 + chart_height_min: int = 800 + chart_height_max: int = 3000 + xaxis_max_pad_ratio: float = 0.02 + nticks: int = 15 + margin_left: int = 180 + margin_right: int = 50 + margin_top: int = 80 + margin_bottom: int = 50 + +class Constant: + ROLE = "role" + COMMUNICATION_GROUP_DOMAIN = "communication_group" + # params + INPUT_PATH = "input_path" + DATA_MAP = "data_map" + DATA_TYPE = "data_type" + PROFILER_TYPE = "profiler_type" + RANK_LIST = "rank_list" + RANK_ID = "rank_id" + PROFILER_DATA_PATH = "profiler_data_path" + + # for Ascend profile + ASCEND_PROFILER_OUTPUT = "ASCEND_PROFILER_OUTPUT" + ASCEND_PROFILER_SUFFIX = "ascend_pt" + ASCEND_PROFILER_INFO_HEAD = "profiler_info_" + ASCEND_PROFILER_METADATA_JSON = "profiler_metadata.json" + + # result files type + TEXT = "text" + DB = "db" + JSON_EXTENSION = ".json" + + # Unit Conversion + US_TO_MS = 1000 + NS_TO_US = 1000 + + # file authority + WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP + WRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC diff --git a/verl/tools/cluster_analysis/visualizer.py b/verl/tools/cluster_analysis/visualizer.py new file mode 100644 index 00000000000..6c2e930deb8 --- /dev/null +++ b/verl/tools/cluster_analysis/visualizer.py @@ -0,0 +1,346 @@ +import os +from typing import Callable + +import numpy as np +import pandas as pd +import plotly.graph_objects as go +from schema import FigureConfig + +ClusterVisualizerFn = Callable[ + [pd.DataFrame, str, dict], + None, +] + +COLOR_PALETTE = [ + "#4e79a7", + "#f28e8b", + "#59a14f", + "#b07aa1", + "#9c755f", + "#76b7b2", + "#edc948", + "#bab0ab", + "#8cd17d", + "#ff9da7", +] + +CLUSTER_VISUALIZER_REGISTRY: dict[str, ClusterVisualizerFn] = {} + + +def register_cluster_visualizer(name: str) -> Callable[[ClusterVisualizerFn], ClusterVisualizerFn]: + def decorator(func: ClusterVisualizerFn) -> ClusterVisualizerFn: + CLUSTER_VISUALIZER_REGISTRY[name] = func + return func + + return decorator + + +def get_cluster_visualizer_fn(fn_name): + if fn_name not in CLUSTER_VISUALIZER_REGISTRY: + raise ValueError( + f"Unsupported cluster visualizer: {fn_name}. 
Supported fns are: {list(CLUSTER_VISUALIZER_REGISTRY.keys())}" + ) + return CLUSTER_VISUALIZER_REGISTRY[fn_name] + + +@register_cluster_visualizer("html") +def cluster_visualizer_html(data: pd.DataFrame, output_path: str, config: dict) -> None: + generate_rl_timeline(data, output_path) + print("in html") + + +@register_cluster_visualizer("chart") +def cluster_visualizer_chart(data: pd.DataFrame, output_path: str, config: dict) -> None: + print("in chart") + + +def generate_rl_timeline( + input_data: pd.DataFrame, + output_dir=None, + output_filename="rl_timeline.html", + title_prefix="RL Timeline", +): + """ + Generate an RL event timeline Gantt chart with interactive Y-axis sorting by Rank ID. + + Args: + input_data: A pandas DataFrame containing events_summary data. + DataFrame should have columns: role, domain, rank_id, start_time_ms, end_time_ms + output_dir: Directory to save the HTML file + output_filename: Name of the output HTML file + title_prefix: Prefix for the chart title + """ + df, t0 = load_and_preprocess(input_data) + df = merge_short_events(df) + df = downsample_if_needed(df) + y_mappings, y_axis_spacing = build_y_mappings(df) + traces = build_traces(df, y_mappings["default"]) + cfg = FigureConfig( + title_prefix=title_prefix, + t0=t0, + y_mappings=y_mappings, + y_axis_spacing=y_axis_spacing, + ) + fig = assemble_figure(traces, df, cfg) + save_html(fig, output_dir, output_filename) + return fig + + +def load_and_preprocess(input_data: pd.DataFrame) -> tuple[pd.DataFrame, float]: + """ + Load and preprocess data from a pandas DataFrame. + + Args: + input_data: A pandas DataFrame containing events_summary data + + Returns: + Tuple of (preprocessed DataFrame, t0 offset) + """ + if input_data is None: + raise ValueError(f"input_data: {input_data} is None!") + + df = input_data.copy() + + df.rename( + columns={ + "role": "Role", + "name": "Name", + "rank_id": "Rank ID", + "start_time_ms": "Start", + "end_time_ms": "Finish", + }, + inplace=True, + errors="ignore", + ) + + required = ["Role", "Name", "Rank ID", "Start", "Finish"] + for col in required: + if col not in df.columns: + raise ValueError(f"Required column missing: {col}") + + df = df.dropna(subset=required).copy() + df["Start"] = pd.to_numeric(df["Start"], errors="coerce") + df["Finish"] = pd.to_numeric(df["Finish"], errors="coerce") + df["Rank ID"] = pd.to_numeric(df["Rank ID"], errors="coerce").astype("Int64") + df = df.dropna(subset=["Start", "Finish", "Rank ID"]) + df = df[df["Finish"] > df["Start"]].copy() + df["Duration"] = df["Finish"] - df["Start"] + + if df.empty: + return df, 0.0 + + t0 = df["Start"].min() + df["Start"] -= t0 + df["Finish"] -= t0 + df["Duration"] = df["Finish"] - df["Start"] + return df, t0 + + +def merge_short_events(df: pd.DataFrame, threshold_ms: float = 10.0) -> pd.DataFrame: + def _merge_group(g: pd.DataFrame) -> pd.DataFrame: + short = g[g["Duration"] < threshold_ms] + long = g[g["Duration"] >= threshold_ms] + if short.empty: + return long + merged = pd.DataFrame( + [ + { + "Start": short["Start"].min(), + "Finish": short["Finish"].max(), + "Role": short.iloc[0]["Role"], + "Rank ID": short.iloc[0]["Rank ID"], + "Name": short.iloc[0]["Name"], + "Duration": short["Finish"].max() - short["Start"].min(), + } + ] + ) + return pd.concat([long, merged], ignore_index=True) + + return df.groupby(["Role", "Rank ID", "Name"], group_keys=False).apply(_merge_group).reset_index(drop=True) + + +def downsample_if_needed( + df: pd.DataFrame, + max_records: int = 5000, + random_state: int = 42, +) -> 
pd.DataFrame: + if len(df) <= max_records: + return df + n_domains = df["Name"].nunique() + samples_per_domain = max_records // max(1, n_domains) + + def _sample_domain(g: pd.DataFrame) -> pd.DataFrame: + if len(g) <= samples_per_domain: + return g + return g.sample(n=samples_per_domain, random_state=random_state) + + return df.groupby("Name", group_keys=False).apply(_sample_domain).reset_index(drop=True) + + +def build_y_mappings(df: pd.DataFrame): + df["Y_Label"] = df["Role"] + " - Rank " + df["Rank ID"].astype(str) + unique_y_labels = df["Y_Label"].unique() + + def _extract_rank(label: str): + try: + return int(label.split(" - Rank ")[-1]) + except Exception: + return float("inf") + + y_axis_spacing = max(60, min(100, 800 // max(1, len(unique_y_labels)))) + bar_height = y_axis_spacing * 0.8 + + y_labels_default = unique_y_labels + mapping_default = {label: i * y_axis_spacing for i, label in enumerate(y_labels_default)} + df["Y_default"] = df["Y_Label"].map(mapping_default) + + y_labels_by_rank = sorted(unique_y_labels, key=lambda x: (_extract_rank(x), x)) + mapping_by_rank = {label: i * y_axis_spacing for i, label in enumerate(y_labels_by_rank)} + df["Y_by_rank"] = df["Y_Label"].map(mapping_by_rank) + + return { + "default": mapping_default, + "by_rank": mapping_by_rank, + "bar_height": bar_height, + }, y_axis_spacing + + +def build_traces(df: pd.DataFrame, y_mapping: dict): + unique_domains = df["Name"].unique() + color_map = {dom: COLOR_PALETTE[i % len(COLOR_PALETTE)] for i, dom in enumerate(unique_domains)} + bar_height = y_mapping.get("bar_height", 48) + + traces = [] + for domain in unique_domains: + dom_df = df[df["Name"] == domain] + trace = go.Bar( + base=dom_df["Start"], + x=dom_df["Duration"], + y=dom_df["Y_default"], + orientation="h", + name=domain, + marker_color=color_map[domain], + width=bar_height, + hovertemplate=( + "%{data.name}
" + "Start: %{base:.3f} ms
" + "End: %{customdata[1]:.3f} ms
" + "Duration: %{x:.3f} ms
" + "Rank: %{customdata[0]}" + ), + customdata=np.column_stack([dom_df["Y_Label"], dom_df["Finish"]]), + showlegend=True, + textposition="none", + ) + traces.append(trace) + return traces + +def assemble_figure(traces: list[go.Bar], df: pd.DataFrame, cfg: FigureConfig) -> go.Figure: + max_time = df["Finish"].max() + unique_y_labels = sorted(df["Y_Label"].unique()) + + h = max( + cfg.chart_height_min, + min(len(unique_y_labels) * cfg.y_axis_spacing, cfg.chart_height_max), + ) + + fig = go.Figure(data=traces) + fig.update_layout( + title=f"{cfg.title_prefix} (Relative Time, Origin = {cfg.t0:.3f} ms)", + xaxis_title="Time (ms, Relative)", + yaxis_title="Module - Rank", + xaxis=dict( + range=[0, max_time * (1 + cfg.xaxis_max_pad_ratio)], + tickformat=".1f", + nticks=cfg.nticks, + ), + yaxis=dict( + tickmode="array", + tickvals=list(cfg.y_mappings["default"].values()), + ticktext=list(cfg.y_mappings["default"].keys()), + autorange="reversed", + ), + barmode="overlay", + height=h, + hovermode="closest", + legend_title="Event Type", + margin=dict( + l=cfg.margin_left, + r=cfg.margin_right, + t=cfg.margin_top, + b=cfg.margin_bottom, + ), + updatemenus=[ + dict( + type="buttons", + direction="left", + buttons=[ + dict( + args=[{"hovermode": "closest"}], + label="Hover: Current Only", + method="relayout", + ), + dict( + args=[{"hovermode": "x unified"}], + label="Hover: All Ranks", + method="relayout", + ), + ], + pad={"r": 10, "t": 10}, + showactive=True, + x=0.7, + xanchor="left", + y=1.07, + yanchor="top", + ), + dict( + type="buttons", + direction="left", + buttons=[ + dict( + args=[ + {"y": [df[df["Name"] == t.name]["Y_default"].tolist() for t in traces]}, + { + "yaxis.tickvals": list(cfg.y_mappings["default"].values()), + "yaxis.ticktext": list(cfg.y_mappings["default"].keys()), + }, + ], + label="Sort: Default", + method="update", + ), + dict( + args=[ + {"y": [df[df["Name"] == t.name]["Y_by_rank"].tolist() for t in traces]}, + { + "yaxis.tickvals": list(cfg.y_mappings["by_rank"].values()), + "yaxis.ticktext": list(cfg.y_mappings["by_rank"].keys()), + }, + ], + label="Sort: By Rank ID", + method="update", + ), + ], + pad={"r": 10, "t": 10}, + showactive=True, + x=0.85, + xanchor="left", + y=1.07, + yanchor="top", + ), + ], + ) + return fig + +def save_html(fig: go.Figure, output_dir: str, output_filename: str): + os.makedirs(output_dir, exist_ok=True) + out_path = os.path.join(output_dir, output_filename) + fig.write_html( + out_path, + include_plotlyjs="cdn", + full_html=True, + config={ + "displaylogo": False, + "displayModeBar": True, + "toImageButtonOptions": {"format": "png", "scale": 2}, + }, + )