29 changes: 25 additions & 4 deletions docs/source/overview/architecture.rst
@@ -41,15 +41,36 @@ Nsight Python collects `gpu__time_duration.sum` by default. To collect other NVI
...

**Derived Metrics**
Define a Python function that computes metrics like TFLOPs based on runtime and input configuration:
Define a Python function that computes custom metrics like TFLOPS based on runtime and input configuration:

.. code-block:: python

def tflops(t, m, n, k):
def tflops_scalar(t, m, n, k):
return 2 * m * n * k / (t / 1e9) / 1e12

@nsight.analyze.kernel(configs=[(1024, 1024, 64)], derive_metric=tflops)
def benchmark(m, n, k):
@nsight.analyze.kernel(configs=[(1024, 1024, 64)], derive_metric=tflops_scalar)
def benchmark1(m, n, k):
...

# or
def tflops_dict(t, m, n, k):
return {"TFLOPS": 2 * m * n * k / (t / 1e9) / 1e12}

@nsight.analyze.kernel(configs=[(1024, 1024, 64)], derive_metric=tflops_dict)
def benchmark2(m, n, k):
...

# or
def tflops_and_arith_intensity(t, m, n, k):
flops = 2 * m * n * k
tflops = flops / (t / 1e9) / 1e12
memory_bytes = (m * k + k * n + m * n) * 4
return {
"TFLOPS": tflops,
"ArithIntensity": flops / memory_bytes,
}

@nsight.analyze.kernel(configs=[(1024, 1024, 64)], derive_metric=tflops_and_arith_intensity)
def benchmark3(m, n, k):
...
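
To plot a derived metric, pass its name to the plot decorator's `metric`
parameter: the function name for a scalar return, or the dictionary key for a
dictionary return. A minimal sketch mirroring `examples/03_custom_metrics.py`
(`benchmark4` and `tflops.png` are illustrative names):

.. code-block:: python

    @nsight.analyze.plot(filename="tflops.png", metric="TFLOPS")
    @nsight.analyze.kernel(configs=[(1024, 1024, 64)], derive_metric=tflops_dict)
    def benchmark4(m, n, k):
        ...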

**Relative Metrics**
159 changes: 137 additions & 22 deletions examples/03_custom_metrics.py
@@ -2,64 +2,170 @@
# SPDX-License-Identifier: Apache-2.0

"""
Example 3: Custom Metrics (TFLOPs)
===================================
Example 3: Custom Metrics (TFLOPs and Arithmetic Intensity)
============================================================

This example shows how to compute custom metrics from timing data.
This example demonstrates two patterns for using `derive_metric` to compute
custom performance metrics from profiling data.

New concepts:
- Using `derive_metric` to compute custom values (e.g., TFLOPs)
- Customizing plot labels with `ylabel`
- The `annotate_points` parameter to show values on the plot

Additional insights on `derive_metric` usage patterns:
1. `derive_metric` can return either:
- A single scalar value (e.g., TFLOPS)
- A dictionary containing one or more derived metrics (e.g., {"TFLOPS": value} or
{"TFLOPS": value, "ArithIntensity": value})
2. When using `derive_metric`, plotting requires explicit metric specification:
- Without `derive_metric`: Only one collected metric exists -> plot(metric=None) works
- With `derive_metric`: Multiple metrics exist -> MUST specify which metric to plot
3. How to specify which metric to plot in different scenarios:
- For scalar returns: plot(metric="function_name")
- For dictionary returns: plot(metric="dictionary_key")
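
Condensed illustration of the mapping above (hypothetical function names,
not part of the library API; the full patterns are shown below):

    def tflops(t, n):                  # scalar return
        return 2 * n**3 / (t / 1e9) / 1e12
    # -> plot(metric="tflops")

    def metrics(t, n):                 # dictionary return
        return {"TFLOPS": 2 * n**3 / (t / 1e9) / 1e12}
    # -> plot(metric="TFLOPS")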
"""

import torch

import nsight

# Matrix sizes to benchmark: 2^11, 2^12, 2^13
sizes = [(2**i,) for i in range(11, 14)]


# ------------------------------------------------------------------------------
# Pattern 1: Returning a single scalar value
# ------------------------------------------------------------------------------


def compute_tflops(time_ns: float, n: int) -> float:
"""
Compute TFLOPs/s for matrix multiplication.
Compute TFLOPS for matrix multiplication.

Custom metric function signature:
- First argument: the measured metric (time in nanoseconds by default)
- Remaining arguments: must match the decorated function's signature
This function demonstrates the first pattern: returning a single scalar value.

In this example:
- time_ns: The measured metric (gpu__time_duration.sum in nanoseconds)
- n: Matches the 'n' parameter from benchmark_tflops(n)
Function signature convention for `derive_metric`:
- First argument: the measured base metric (default: gpu__time_duration.sum in nanoseconds)
- Remaining arguments: must match the decorated function's parameters

If your function was benchmark(size, dtype, batch), your metric function
would be: my_metric(time_ns, size, dtype, batch)
Note: When `derive_metric` returns a single value, the plot decorator's
`metric` parameter must be set to the FUNCTION NAME (as a string,
"compute_tflops" in this case).

Args:
time_ns: Kernel execution time in nanoseconds (automatically passed)
n: Matrix size (n x n) - matches benchmark_tflops parameter

Returns:
TFLOPs/s (higher is better)
TFLOPS (higher is better)
"""
# Matrix multiplication: 2*n^3 FLOPs (n^3 multiplies + n^3 adds)
# Matrix multiplication FLOPs: 2 * n^3 (n^3 multiplications + n^3 additions)
flops = 2 * n * n * n
# Convert ns to seconds and FLOPs to TFLOPs

# Compute TFLOPS
tflops = flops / (time_ns / 1e9) / 1e12

# This function could also return a dictionary with a single metric, such as
# {"TFLOPS": tflops}, but then the plot decorator's "metric" must be
# set to "TFLOPS" instead of "compute_tflops".
return tflops


@nsight.analyze.plot(
filename="03_custom_metrics.png",
ylabel="Performance (TFLOPs/s)", # Custom y-axis label
annotate_points=True, # Show values on the plot
filename="03_custom_metrics_tflops.png",
metric="compute_tflops",  # Must match the name of the `derive_metric` function when it returns a scalar
ylabel="Performance (TFLOPS)",
annotate_points=True,
)
@nsight.analyze.kernel(
configs=sizes, runs=10, derive_metric=compute_tflops # Use custom metric
configs=sizes, runs=10, derive_metric=compute_tflops # Single scalar return
)
def benchmark_tflops(n: int) -> None:
"""
Benchmark matmul and display results in TFLOPs/s.
Benchmark matrix multiplication and display results in TFLOPS.

This example shows:
- When `derive_metric` returns a single value, the plot `metric` parameter
must be the function name ("compute_tflops")
- Without `derive_metric`, only one collected metric exists (time duration),
so no metric needs to be specified; plot(metric=None) works by default
- With `derive_metric`, there is more than one metric (time duration + derived),
so we must explicitly specify which metric to plot
"""
a = torch.randn(n, n, device="cuda")
b = torch.randn(n, n, device="cuda")

with nsight.annotate("matmul"):
_ = a @ b


# ------------------------------------------------------------------------------
# Pattern 2: Returning a dictionary of metrics
# ------------------------------------------------------------------------------


def compute_tflops_and_arithmetic_intensity(time_ns: float, n: int) -> dict[str, float]:
"""
Compute both TFLOPS and Arithmetic Intensity for matrix multiplication.

This function demonstrates the second pattern: returning a dictionary
containing multiple derived metrics.

Important: When derive_metric returns a dictionary, the plot decorator's
`metric` parameter must be set to a KEY from the dictionary (as a string).

Note: A single scalar value could also be returned as a dictionary with
one key-value pair for consistency, but returning the scalar directly is
more concise.

Args:
time_ns: Kernel execution time in nanoseconds (automatically passed)
n: Matrix size (n x n) - matches the benchmark_tflops_and_arithmetic_intensity parameter

Returns:
Dictionary with two metrics (TFLOPS and ArithIntensity)
"""
# Matrix multiplication FLOPs: 2 * n^3
flops = 2 * n * n * n

# Compute TFLOPS
tflops = flops / (time_ns / 1e9) / 1e12

# Memory access calculation:
# - Input matrices: n * n each, Output matrix: n * n
# - Float32 datatype (4 bytes per element)
memory_bytes = (n * n + n * n + n * n) * 4

# Arithmetic Intensity = FLOPs / Bytes accessed
arithmetic_intensity = flops / memory_bytes

return {
"TFLOPS": tflops,
"ArithIntensity": arithmetic_intensity,
}


@nsight.analyze.plot(
filename="03_custom_metrics_arith_intensity.png",
metric="ArithIntensity", # Must be a key from the returned dictionary
ylabel="Arithmetic Intensity (FLOPs/Byte)",
annotate_points=True,
)
@nsight.analyze.kernel(
configs=sizes,
runs=10,
derive_metric=compute_tflops_and_arithmetic_intensity, # Dictionary return
)
def benchmark_tflops_and_arithmetic_intensity(n: int) -> None:
"""
Benchmark matrix multiplication with multiple derived metrics.

This example shows:
- When `derive_metric` returns a dictionary, the plot metric parameter
must be a key from that dictionary (e.g., "ArithIntensity")
- You can have multiple derived metrics but plot only one at a time
- All derived metrics are available in the ProfileResults object
"""
a = torch.randn(n, n, device="cuda")
b = torch.randn(n, n, device="cuda")
@@ -69,9 +175,18 @@ def benchmark_tflops(n: int) -> None:


def main() -> None:
result = benchmark_tflops()
# Run single-metric benchmark
print("Running TFLOPs benchmark (scalar return pattern)...")
benchmark_tflops()
print("✓ TFLOPs benchmark complete! Check '03_custom_metrics_tflops.png'\n")

# Run multi-metric benchmark
print("Running combined benchmark (dictionary return pattern)...")
result = benchmark_tflops_and_arithmetic_intensity()
print(result.to_dataframe())
print("✓ TFLOPs benchmark complete! Check '03_custom_metrics.png'")

print("\n✓ TFLOPs and Arithmetic Intensity benchmark complete! ", end="")
print("Check '03_custom_metrics_arith_intensity.png'")


if __name__ == "__main__":
5 changes: 3 additions & 2 deletions examples/04_multi_parameter.py
@@ -29,7 +29,7 @@
configs = list(itertools.product(sizes, dtypes))


def compute_tflops(time_ns: float, *conf: Any) -> float:
def compute_tflops(time_ns: float, *conf: Any) -> dict[str, float]:
"""
Compute TFLOPs/s.

@@ -45,11 +45,12 @@ def compute_tflops(time_ns: float, *conf: Any) -> float:

flops = 2 * n * n * n
tflops: float = flops / (time_ns / 1e9) / 1e12
return tflops
return {"TFLOPS": tflops}


@nsight.analyze.plot(
filename="04_multi_parameter.png",
metric="TFLOPS",
ylabel="Performance (TFLOPs/s)",
annotate_points=True,
)
8 changes: 5 additions & 3 deletions examples/05_subplots.py
@@ -28,16 +28,17 @@
configs = list(itertools.product(sizes, dtypes, transpose))


def compute_tflops(time_ns: float, *conf: Any) -> float:
def compute_tflops(time_ns: float, *conf: Any) -> dict[str, float]:
"""Compute TFLOPs/s using *conf to extract only what we need."""
n: int = conf[0] # Extract size (dtype and transpose not needed)
flops = 2 * n * n * n
tflops: float = flops / (time_ns / 1e9) / 1e12
return tflops
return {"TFLOPS": tflops}


@nsight.analyze.plot(
filename="05_subplots.png",
metric="TFLOPS",
title="Matrix Multiplication Performance",
ylabel="TFLOPs/s",
row_panels=["dtype"], # Create row for each dtype
@@ -61,7 +62,8 @@ def benchmark_with_subplots(n: int, dtype: torch.dtype, transpose: bool) -> None


def main() -> None:
benchmark_with_subplots()
result = benchmark_with_subplots()
print(result.to_dataframe())
print("✓ Subplot benchmark complete! Check '05_subplots.png'")


6 changes: 4 additions & 2 deletions examples/06_plot_customization.py
@@ -22,14 +22,15 @@
sizes = [(2**i,) for i in range(11, 14)]


def compute_tflops(time_ns: float, n: int) -> float:
def compute_tflops(time_ns: float, n: int) -> dict[str, float]:
flops = 2 * n * n * n
return flops / (time_ns / 1e9) / 1e12
return {"TFLOPS": flops / (time_ns / 1e9) / 1e12}


# Example 1: Bar chart
@nsight.analyze.plot(
filename="06_bar_chart.png",
metric="TFLOPS",
title="Matrix Multiplication Performance",
ylabel="TFLOPs/s",
plot_type="bar", # Use bar chart instead of line plot
@@ -68,6 +69,7 @@ def custom_style(fig: Any) -> None:

@nsight.analyze.plot(
filename="06_custom_plot.png",
metric="TFLOPS",
plot_callback=custom_style, # Apply custom styling
)
@nsight.analyze.kernel(configs=sizes, runs=10, derive_metric=compute_tflops)