
Commit deb9183

FindHao authored and facebook-github-bot committed
Add nsys report analyzer (#65)
Summary:

This PR adds an nsys report analyzer that provides the following metrics:

```python
nsys_metrics_to_reports = {
    # the sum of kernel execution time
    "nsys_gpu_kernel_sum": ["cuda_gpu_kern_sum", "nvtx_sum"],
    # the overhead of kernel launch
    "nsys_launch_overhead": ["cuda_gpu_kern_sum", "nvtx_sum"],
    # the names of kernels
    "nsys_kernel_names": ["cuda_gpu_kern_sum"],
    # the durations of kernels
    "nsys_kernel_durations": ["cuda_gpu_kern_sum"],
    # the duration of nvtx range
    "nsys_nvtx_range_duration": ["nvtx_sum"],
    # the number of kernels
    "nsys_num_of_kernels": ["cuda_gpu_kern_sum"],
}
```

`nsys_gpu_kernel_sum` is the sum of GPU kernel execution time across GPUs, `nsys_nvtx_range_duration` is the total execution time of the operator, and `nsys_launch_overhead` is their difference, which indicates the kernel launch overhead. This is one way to measure execution time, as mentioned in #50.

Fix #67

Pull Request resolved: #65

Test Plan:

```
% python run.py --op rope --num-inputs 1 --metrics nsys_gpu_kernel_sum,nsys_launch_overhead,nsys_kernel_names,nsys_kernel_durations,nsys_nvtx_range_duration,nsys_num_of_kernels --csv --dump-csv
  0%|          | 0/1 [00:00<?, ?it/s]`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
  0%|          | 0/1 [00:00<?, ?it/s]`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-531e.qdstrm'
[1/1] [0%                          ] nsys_output.nsys-rep
Processing events...
[1/1] [========================100%] nsys_output.nsys-rep
Generated: /tmp/tritonbench/rope/nsys_traces/apply_rotary_pos_emb_0/nsys_output.nsys-rep
  0%|          | 0/1 [00:00<?, ?it/s]`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-39ea.qdstrm'
[1/1] [0%                          ] nsys_output.nsys-rep
Processing events...
[1/1] [========================100%] nsys_output.nsys-rep
Generated: /tmp/tritonbench/rope/nsys_traces/liger_rotary_pos_emb_0/nsys_output.nsys-rep
  0%|          | 0/1 [00:00<?, ?it/s]`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-e8bf.qdstrm'
[1/1] [0%                          ] nsys_output.nsys-rep
Processing events...
[1/1] [========================100%] nsys_output.nsys-rep
Generated: /tmp/tritonbench/rope/nsys_traces/inductor_rotary_pos_emb_full_op_0/nsys_output.nsys-rep
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:52<00:00, 52.40s/it]
(H, T);apply_rotary_pos_emb-nsys_kernel_names;apply_rotary_pos_emb-nsys_kernel_durations;apply_rotary_pos_emb-nsys_gpu_kernel_sum;apply_rotary_pos_emb-nsys_num_of_kernels;apply_rotary_pos_emb-nsys_launch_overhead;apply_rotary_pos_emb-nsys_nvtx_range_duration;liger_rotary_pos_emb-nsys_kernel_names;liger_rotary_pos_emb-nsys_kernel_durations;liger_rotary_pos_emb-nsys_gpu_kernel_sum;liger_rotary_pos_emb-nsys_num_of_kernels;liger_rotary_pos_emb-nsys_launch_overhead;liger_rotary_pos_emb-nsys_nvtx_range_duration;inductor_rotary_pos_emb_full_op-nsys_kernel_names;inductor_rotary_pos_emb_full_op-nsys_kernel_durations;inductor_rotary_pos_emb_full_op-nsys_gpu_kernel_sum;inductor_rotary_pos_emb_full_op-nsys_num_of_kernels;inductor_rotary_pos_emb_full_op-nsys_launch_overhead;inductor_rotary_pos_emb_full_op-nsys_nvtx_range_duration
(8192, 1024);['void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::native::BinaryFunctor<float, float, float, at::native::binary_internal::MulFunctor<float>>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)', 'void at::native::<unnamed>::CatArrayBatchedCopy<at::native::<unnamed>::OpaqueType<(unsigned int)4>, unsigned int, (int)4, (int)64, (int)64>(T1 *, at::native::<unnamed>::CatArrInputTensorMetadata<T1, T2, T4, T5>, at::native::<unnamed>::TensorSizeStride<T2, (unsigned int)4>, int, T2)', 'void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::native::CUDAFunctor_add<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)', 'void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::native::neg_kernel_cuda(at::TensorIteratorBase &)::[lambda() (instance 2)]::operator ()() const::[lambda() (instance 7)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)'];0.090065;0.351364;4;0.4534;0.804764;['_triton_rope'];0.049281;0.049281;1;0.176437;0.225718;['triton_poi_fused_add_cat_mul_0', 'triton_poi_fused_add_cat_mul_1'];0.0266885;0.053377;2;0.444969;0.498346
[TritonBench] Dumped csv to /tmp/tritonbench/op_rope__z_yqmrz.csv
```

Reviewed By: xuzhao9

Differential Revision: D66311127

Pulled By: FindHao

fbshipit-source-id: 085454e34a3e9aadb360309cc69885684a8a1758
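As a quick sanity check on how these three summary metrics relate, the launch overhead is simply the NVTX range duration minus the summed kernel time. The sketch below recomputes it from the `apply_rotary_pos_emb` columns of the test-plan CSV above; the numbers are copied from that output, and the snippet is illustrative only, not part of the committed code.

```python
# Illustrative only: values copied from the apply_rotary_pos_emb columns of the
# test-plan CSV above; all durations are in milliseconds.
nsys_gpu_kernel_sum = 0.351364       # sum of the four GPU kernel durations
nsys_nvtx_range_duration = 0.804764  # duration of the tritonbench NVTX range

# nsys_launch_overhead is defined as the difference between the two.
nsys_launch_overhead = nsys_nvtx_range_duration - nsys_gpu_kernel_sum
print(f"{nsys_launch_overhead:.4f} ms")  # 0.4534 ms, matching the CSV column
```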
1 parent bec6d9f commit deb9183

3 files changed · +160 −6 lines changed

tritonbench/components/ncu/analyzer.py renamed to tritonbench/components/ncu/ncu_analyzer.py

+1-1
```diff
@@ -152,7 +152,7 @@ def get_arithmetic_intensity(kernel):
 def read_ncu_report(report_path: str, required_metrics: List[str]):
     assert os.path.exists(
         report_path
-    ), f"The NCU report at {report_path} does not exist. Ensure you add --metrics ncu_rep to your benchmark run."
+    ), f"The NCU report at {report_path} does not exist."
     import_ncu_python_path()
     import ncu_report
 
```

tritonbench/components/ncu/nsys_analyzer.py

+105
@@ -0,0 +1,105 @@

```python
import csv
import os
import subprocess
from typing import Dict, List

# Mapping from nsys metrics to the nsys reports they require. Each value is a list of nsys report names.
nsys_metrics_to_reports = {
    # the sum of kernel execution time
    "nsys_gpu_kernel_sum": ["nvtx_kern_sum", "nvtx_sum"],
    # the overhead of kernel launch
    "nsys_launch_overhead": ["nvtx_kern_sum", "nvtx_sum"],
    # the names of kernels
    "nsys_kernel_names": ["nvtx_kern_sum"],
    # the durations of kernels
    "nsys_kernel_durations": ["nvtx_kern_sum"],
    # the duration of the nvtx range
    "nsys_nvtx_range_duration": ["nvtx_sum"],
    # the number of kernels
    "nsys_num_of_kernels": ["nvtx_kern_sum"],
}


def read_nsys_report(
    report_path: str, required_metrics: List[str]
) -> Dict[str, List[float]]:
    assert os.path.exists(
        report_path
    ), f"The nsys report at {report_path} does not exist."
    reports_required = []
    for metric in required_metrics:
        if metric in nsys_metrics_to_reports:
            reports_required.extend(nsys_metrics_to_reports[metric])
    reports_required = list(set(reports_required))
    assert reports_required, "No nsys reports required"
    cmd = f"nsys stats --report {','.join(reports_required)} --force-export=true --format csv --output . --force-overwrite=true {report_path}"
    try:
        subprocess.check_call(
            cmd.split(), stdout=subprocess.DEVNULL, stderr=subprocess.PIPE
        )
    except subprocess.CalledProcessError as e:
        print(f"Failed to run nsys command: {cmd}\nError: {e}")
        raise e
    # Get the base path and filename without extension
    base_path = os.path.dirname(report_path)
    base_name = os.path.splitext(os.path.basename(report_path))[0]

    results = {}
    csv_contents = {}

    for report in reports_required:
        csv_path = os.path.join(base_path, f"{base_name}_{report}.csv")
        if not os.path.exists(csv_path):
            raise RuntimeError(f"Expected CSV report not found at {csv_path}")

        # Read CSV using DictReader
        with open(csv_path, "r") as f:
            reader = csv.DictReader(f)
            csv_contents[report] = list(reader)
    kernel_duration = []
    kernel_names = []
    sum_kernel_duration = 0
    nvtx_range_duration = 0
    if "nvtx_kern_sum" in csv_contents:
        # gpu kernel execution time summary
        for row in csv_contents["nvtx_kern_sum"]:
            # use ms as the unit
            kernel_duration.append(float(row["Total Time (ns)"]) / 1_000_000)
            kernel_names.append(row["Kernel Name"])
        sum_kernel_duration = sum(kernel_duration)
    if "nvtx_sum" in csv_contents:
        # There should be only one row. The nvtx range is `:tritonbench_range`.
        assert len(csv_contents["nvtx_sum"]) == 1
        # TODO: nsys has a bug where the unit of the nvtx range duration is sometimes ms;
        # waiting for NVIDIA's reply.
        nvtx_range_duration = (
            float(csv_contents["nvtx_sum"][0]["Total Time (ns)"]) / 1_000_000
        )

    # Define the mapping of metrics to their values. The keys must be in nsys_metrics_to_reports.
    metrics_map = {
        # Because tritonbench takes the median of numerical values, we need to convert
        # the list of floats to a list of strings.
        "nsys_kernel_durations": [str(duration) for duration in kernel_duration],
        "nsys_kernel_names": kernel_names,
        "nsys_gpu_kernel_sum": sum_kernel_duration,
        "nsys_nvtx_range_duration": nvtx_range_duration,
        "nsys_launch_overhead": nvtx_range_duration - sum_kernel_duration,
        "nsys_num_of_kernels": len(kernel_names),
    }
    # Verify that metrics_map keys match nsys_metrics_to_reports keys
    assert set(metrics_map.keys()) == set(nsys_metrics_to_reports.keys()), (
        f"Mismatch between metrics_map keys and nsys_metrics_to_reports keys.\n"
        f"metrics_map keys: {set(metrics_map.keys())}\n"
        f"nsys_metrics_to_reports keys: {set(nsys_metrics_to_reports.keys())}"
    )
    # Add only the requested metrics to results
    results.update(
        {
            metric: metrics_map[metric]
            for metric in required_metrics
            if metric in metrics_map
        }
    )

    return results
```
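For reference, a minimal sketch of how `read_nsys_report` might be driven standalone on an already-collected trace. It assumes the `nsys` CLI is on `PATH` (the function shells out to `nsys stats`); the report path is taken from the test-plan output above, and the values in the comment are copied from that run as an example.

```python
# Hypothetical standalone usage; within tritonbench this call is made from triton_op.py.
from tritonbench.components.ncu import nsys_analyzer

report = "/tmp/tritonbench/rope/nsys_traces/liger_rotary_pos_emb_0/nsys_output.nsys-rep"
metrics = nsys_analyzer.read_nsys_report(
    report,
    ["nsys_gpu_kernel_sum", "nsys_nvtx_range_duration", "nsys_launch_overhead"],
)
# Durations are reported in milliseconds, e.g. (from the test-plan run above):
# {'nsys_gpu_kernel_sum': 0.049281, 'nsys_nvtx_range_duration': 0.225718,
#  'nsys_launch_overhead': 0.176437}
print(metrics)
```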

tritonbench/utils/triton_op.py

+54-5
```diff
@@ -23,7 +23,7 @@
 import torch
 import triton
 
-from tritonbench.components.ncu import analyzer as ncu_analyzer
+from tritonbench.components.ncu import ncu_analyzer, nsys_analyzer
 from tritonbench.utils.env_utils import (
     apply_precision,
     fresh_triton_cache,
@@ -68,7 +68,12 @@ class BenchmarkOperatorBackend:
 REGISTERED_METRICS: Dict[str, List[str]] = {}
 REGISTERED_X_VALS: Dict[str, str] = {}
 BASELINE_BENCHMARKS: Dict[str, str] = {}
-BASELINE_SKIP_METRICS = {"speedup", "accuracy", "mem_footprint_compression_ratio"}
+BASELINE_SKIP_METRICS = {
+    "speedup",
+    "accuracy",
+    "mem_footprint_compression_ratio",
+    "nsys_gpu_speedup",
+}
 X_ONLY_METRICS = set(["hw_roofline"])
 PRECISION_DTYPE_MAPPING = {
     "fp32": torch.float32,
@@ -222,6 +227,8 @@ class BenchmarkOperatorMetrics:
     mem_footprint_compression_ratio: Optional[float] = None
     # gbps
     gbps: Optional[float] = None
+    # speedup for the summary of kernel GPU time only
+    nsys_gpu_speedup: Optional[float] = None
 
 
 BUILTIN_METRICS = {x.name for x in fields(BenchmarkOperatorMetrics)} - {"extra_metrics"}
@@ -307,9 +314,25 @@ def select_metric(backend, m):
             )
             metric_val = _metrics_dict.get(metric, None)
             if isinstance(metric_val, list):
-                row.append(numpy.median(metric_val))
+                # Check if all elements are numbers before calculating median
+                if all(isinstance(x, Number) for x in metric_val):
+                    row.append(numpy.median(metric_val))
+                else:
+                    # For non-numeric lists, convert to string representation
+                    metric_val_str = str(metric_val)
+                    if ";" in metric_val_str:
+                        logger.warning(
+                            f"Metric value '{metric_val_str}' contains semicolon which may cause CSV parsing issues"
+                        )
+                    row.append(metric_val_str)
             elif isinstance(metric_val, bool):
                 row.append(1.0 if metric_val else 0.0)
+            elif isinstance(metric_val, str):
+                if ";" in metric_val:
+                    logger.warning(
+                        f"Metric value '{metric_val}' contains semicolon which may cause CSV parsing issues"
+                    )
+                row.append(metric_val)
             else:
                 row.append(metric_val)
         table.append(row)
@@ -1065,8 +1088,34 @@ def _init_extra_metrics() -> Dict[str, Any]:
                 metrics.ncu_rep_ir = self.ncu_trace(
                     input_id, fn_name, replay=True, profile_ir=True
                 )
-            if "nsys_rep" in self.required_metrics:
-                metrics.nsys_rep = self.nsys_rep(input_id, fn_name)
+            nsys_metrics = []
+            for metric_name in nsys_analyzer.nsys_metrics_to_reports.keys():
+                if metric_name in self.required_metrics:
+                    nsys_metrics.append(metric_name)
+
+            if "nsys_rep" in self.required_metrics or nsys_metrics:
+                nsys_rep_path = self.nsys_rep(input_id, fn_name)
+                metrics.nsys_rep = nsys_rep_path
+                if nsys_metrics:
+                    nsys_analyzer_results = nsys_analyzer.read_nsys_report(
+                        nsys_rep_path, nsys_metrics
+                    )
+                    for metric_name, metric_value in nsys_analyzer_results.items():
+                        metrics.extra_metrics[metric_name] = metric_value
+            if "nsys_gpu_speedup" in self.required_metrics:
+                baseline_nsys_gpu_kernel_sum = (
+                    self.baseline_metrics.extra_metrics.get("nsys_gpu_kernel_sum", None)
+                    if self.baseline_metrics
+                    else None
+                )
+                current_nsys_gpu_kernel_sum = metrics.extra_metrics.get(
+                    "nsys_gpu_kernel_sum", None
+                )
+                metrics.nsys_gpu_speedup = (
+                    baseline_nsys_gpu_kernel_sum / current_nsys_gpu_kernel_sum
+                    if baseline_nsys_gpu_kernel_sum and current_nsys_gpu_kernel_sum
+                    else None
+                )
             if "kineto_trace" in self.required_metrics:
                 metrics.kineto_trace = self.kineto_trace(input_id, fn)
             if "best_config" in self.required_metrics:
```
