diff --git a/docs/source/overview/architecture.rst b/docs/source/overview/architecture.rst index e01dbd5..7a2f5d7 100644 --- a/docs/source/overview/architecture.rst +++ b/docs/source/overview/architecture.rst @@ -41,15 +41,36 @@ Nsight Python collects `gpu__time_duration.sum` by default. To collect other NVI ... **Derived Metrics** -Define a Python function that computes metrics like TFLOPs based on runtime and input configuration: +Define a Python function that computes custom metrics like TFLOPS based on runtime and input configuration: .. code-block:: python - def tflops(t, m, n, k): + def tflops_scalar(t, m, n, k): return 2 * m * n * k / (t / 1e9) / 1e12 - @nsight.analyze.kernel(configs=[(1024, 1024, 64)], derive_metric=tflops) - def benchmark(m, n, k): + @nsight.analyze.kernel(configs=[(1024, 1024, 64)], derive_metric=tflops_scalar) + def benchmark1(m, n, k): + ... + + # or + def tflops_dict(t, m, n, k): + return {"TFLOPS": 2 * m * n * k / (t / 1e9) / 1e12} + + @nsight.analyze.kernel(configs=[(1024, 1024, 64)], derive_metric=tflops_dict) + def benchmark2(m, n, k): + ... + + # or + def tflops_and_arith_intensity(t, m, n, k): + tflops = 2 * m * n * k / (t / 1e9) / 1e12 + memory_bytes = (m * k + k * n + m * n) * 4 + return { + "TFLOPS": tflops, + "ArithIntensity": tflops / memory_bytes, + } + + @nsight.analyze.kernel(configs=[(1024, 1024, 64)], derive_metric=tflops_and_arith_intensity) + def benchmark3(m, n, k): ... **Relative Metrics** diff --git a/examples/03_custom_metrics.py b/examples/03_custom_metrics.py index 723043d..4e19285 100644 --- a/examples/03_custom_metrics.py +++ b/examples/03_custom_metrics.py @@ -2,64 +2,170 @@ # SPDX-License-Identifier: Apache-2.0 """ -Example 3: Custom Metrics (TFLOPs) -=================================== +Example 3: Custom Metrics (TFLOPs and Arithmetic Intensity) +============================================================ -This example shows how to compute custom metrics from timing data. +This example demonstrates two patterns for using `derive_metric` to compute +custom performance metrics from profiling data. New concepts: - Using `derive_metric` to compute custom values (e.g., TFLOPs) - Customizing plot labels with `ylabel` - The `annotate_points` parameter to show values on the plot + +Additional insights on `derive_metric` usage patterns: +1. `derive_metric` can return either: + - A single scalar value (e.g., TFLOPS) + - A dictionary containing one or more derived metrics (e.g., {"TFLOPS": value} or + {"TFLOPS": value, "ArithIntensity": value}) +2. When using `derive_metric`, plotting requires explicit metric specification: + - Without `derive_metric`: Only one collected metric exists -> plot(metric=None) works + - With `derive_metric`: Multiple metrics exist -> MUST specify which metric to plot +3. How to specify which metric to plot in different scenarios: + - For scalar returns: plot(metric="function_name") + - For dictionary returns: plot(metric="dictionary_key") """ import torch import nsight +# Matrix sizes to benchmark: 2^11, 2^12, 2^13 sizes = [(2**i,) for i in range(11, 14)] +# ------------------------------------------------------------------------------ +# Pattern 1: Returning a single scalar value +# ------------------------------------------------------------------------------ + + def compute_tflops(time_ns: float, n: int) -> float: """ - Compute TFLOPs/s for matrix multiplication. + Compute TFLOPS for matrix multiplication. 
- Custom metric function signature: - - First argument: the measured metric (time in nanoseconds by default) - - Remaining arguments: must match the decorated function's signature + This function demonstrates the first pattern: returning a single scalar value. - In this example: - - time_ns: The measured metric (gpu__time_duration.sum in nanoseconds) - - n: Matches the 'n' parameter from benchmark_tflops(n) + Function signature convention for `derive_metric`: + - First argument: the measured base metric (default: gpu__time_duration.sum in nanoseconds) + - Remaining arguments: must match the decorated function's parameters - If your function was benchmark(size, dtype, batch), your metric function - would be: my_metric(time_ns, size, dtype, batch) + Note: When `derive_metric` returns a single value, the plot decorator's + `metric` parameter must be set to the FUNCTION NAME (as a string, + "compute_tflops" in this case). Args: time_ns: Kernel execution time in nanoseconds (automatically passed) n: Matrix size (n x n) - matches benchmark_tflops parameter Returns: - TFLOPs/s (higher is better) + TFLOPS (higher is better) """ - # Matrix multiplication: 2*n^3 FLOPs (n^3 multiplies + n^3 adds) + # Matrix multiplication FLOPs: 2 * n^3 (n^3 multiplications + n^3 additions) flops = 2 * n * n * n - # Convert ns to seconds and FLOPs to TFLOPs + + # Compute TFLOPS tflops = flops / (time_ns / 1e9) / 1e12 + + # This function can also return a directory of one metric, such as + # {"TFLOPS": tflops}, but the "metric" of the plot decorator must be + # set to "TFLOPS" instead of "compute_tflops". return tflops @nsight.analyze.plot( - filename="03_custom_metrics.png", - ylabel="Performance (TFLOPs/s)", # Custom y-axis label - annotate_points=True, # Show values on the plot + filename="03_custom_metrics_tflops.png", + metric="compute_tflops", # Must match the function name of `derive_metric` when returning scalar + ylabel="Performance (TFLOPS)", + annotate_points=True, ) @nsight.analyze.kernel( - configs=sizes, runs=10, derive_metric=compute_tflops # Use custom metric + configs=sizes, runs=10, derive_metric=compute_tflops # Single scalar return ) def benchmark_tflops(n: int) -> None: """ - Benchmark matmul and display results in TFLOPs/s. + Benchmark matrix multiplication and display results in TFLOPS. + + This example shows: + - When `derive_metric` returns a single value, the plot metric parameter + must be the function name ("compute_tflops") + - Without `derive_metric`, there's only one collected metric (time duration), + we don't need to specify a metric because plot(metric=None) works by default + - With `derive_metric`, we have >1 metrics (time duration + derived), so we must + explicitly specify which metric to plot + """ + a = torch.randn(n, n, device="cuda") + b = torch.randn(n, n, device="cuda") + + with nsight.annotate("matmul"): + _ = a @ b + + +# ------------------------------------------------------------------------------ +# Pattern 2: Returning a dictionary of metrics +# ------------------------------------------------------------------------------ + + +def compute_tflops_and_arithmetic_intensity(time_ns: float, n: int) -> dict[str, float]: + """ + Compute both TFLOPS and Arithmetic Intensity for matrix multiplication. + + This function demonstrates the second pattern: returning a dictionary + containing multiple derived metrics. + + Important: When derive_metric returns a dictionary, the plot decorator's + `metric` parameter must be set to a KEY from the dictionary (as a string). 
+ + Note: A single scalar value could also be returned as a dictionary with + one key-value pair for consistency, but returning the scalar directly is + more concise. + + Args: + time_ns: Kernel execution time in nanoseconds (automatically passed) + n: Matrix size (n x n) - matches benchmark_tflops parameter + + Returns: + Dictionary with two metrics (TFLOPS and ArithIntensity) + """ + # Matrix multiplication FLOPs: 2 * n^3 + flops = 2 * n * n * n + + # Compute TFLOPS + tflops = flops / (time_ns / 1e9) / 1e12 + + # Memory access calculation: + # - Input matrices: n * n each, Output matrix: n * n + # - Float32 datatype (4 bytes per element) + memory_bytes = (n * n + n * n + n * n) * 4 + + # Arithmetic Intensity = FLOPs / Bytes accessed + arithmetic_intensity = flops / memory_bytes + + return { + "TFLOPS": tflops, + "ArithIntensity": arithmetic_intensity, + } + + +@nsight.analyze.plot( + filename="03_custom_metrics_arith_intensity.png", + metric="ArithIntensity", # Must be a key from the returned dictionary + ylabel="Arithmetic Intensity (FLOPs/Byte)", + annotate_points=True, +) +@nsight.analyze.kernel( + configs=sizes, + runs=10, + derive_metric=compute_tflops_and_arithmetic_intensity, # Dictionary return +) +def benchmark_tflops_and_arithmetic_intensity(n: int) -> None: + """ + Benchmark matrix multiplication with multiple derived metrics. + + This example shows: + - When `derive_metric` returns a dictionary, the plot metric parameter + must be a key from that dictionary (e.g., "ArithIntensity") + - You can have multiple derived metrics but plot only one at a time + - All derived metrics are available in the ProfileResults object """ a = torch.randn(n, n, device="cuda") b = torch.randn(n, n, device="cuda") @@ -69,9 +175,18 @@ def benchmark_tflops(n: int) -> None: def main() -> None: - result = benchmark_tflops() + # Run single-metric benchmark + print("Running TFLOPs benchmark (scalar return pattern)...") + benchmark_tflops() + print("āœ“ TFLOPs benchmark complete! Check '03_custom_metrics_tflops.png'\n") + + # Run multi-metric benchmark + print("Running combined benchmark (dictionary return pattern)...") + result = benchmark_tflops_and_arithmetic_intensity() print(result.to_dataframe()) - print("āœ“ TFLOPs benchmark complete! Check '03_custom_metrics.png'") + + print("\nāœ“ TFLOPs and Arithmetic Intensity benchmark complete! ", end="") + print("Check '03_custom_metrics_arith_intensity.png'") if __name__ == "__main__": diff --git a/examples/04_multi_parameter.py b/examples/04_multi_parameter.py index b345073..06d8f27 100644 --- a/examples/04_multi_parameter.py +++ b/examples/04_multi_parameter.py @@ -29,7 +29,7 @@ configs = list(itertools.product(sizes, dtypes)) -def compute_tflops(time_ns: float, *conf: Any) -> float: +def compute_tflops(time_ns: float, *conf: Any) -> dict[str, float]: """ Compute TFLOPs/s. 
@@ -45,11 +45,12 @@ def compute_tflops(time_ns: float, *conf: Any) -> float: flops = 2 * n * n * n tflops: float = flops / (time_ns / 1e9) / 1e12 - return tflops + return {"TFLOPS": tflops} @nsight.analyze.plot( filename="04_multi_parameter.png", + metric="TFLOPS", ylabel="Performance (TFLOPs/s)", annotate_points=True, ) diff --git a/examples/05_subplots.py b/examples/05_subplots.py index d750e00..51a7a0d 100644 --- a/examples/05_subplots.py +++ b/examples/05_subplots.py @@ -28,16 +28,17 @@ configs = list(itertools.product(sizes, dtypes, transpose)) -def compute_tflops(time_ns: float, *conf: Any) -> float: +def compute_tflops(time_ns: float, *conf: Any) -> dict[str, float]: """Compute TFLOPs/s using *conf to extract only what we need.""" n: int = conf[0] # Extract size (dtype and transpose not needed) flops = 2 * n * n * n tflops: float = flops / (time_ns / 1e9) / 1e12 - return tflops + return {"TFLOPS": tflops} @nsight.analyze.plot( filename="05_subplots.png", + metric="TFLOPS", title="Matrix Multiplication Performance", ylabel="TFLOPs/s", row_panels=["dtype"], # Create row for each dtype @@ -61,7 +62,8 @@ def benchmark_with_subplots(n: int, dtype: torch.dtype, transpose: bool) -> None def main() -> None: - benchmark_with_subplots() + result = benchmark_with_subplots() + print(result.to_dataframe()) print("āœ“ Subplot benchmark complete! Check '05_subplots.png'") diff --git a/examples/06_plot_customization.py b/examples/06_plot_customization.py index 4cb0594..e84bdff 100644 --- a/examples/06_plot_customization.py +++ b/examples/06_plot_customization.py @@ -22,14 +22,15 @@ sizes = [(2**i,) for i in range(11, 14)] -def compute_tflops(time_ns: float, n: int) -> float: +def compute_tflops(time_ns: float, n: int) -> dict[str, float]: flops = 2 * n * n * n - return flops / (time_ns / 1e9) / 1e12 + return {"TFLOPS": flops / (time_ns / 1e9) / 1e12} # Example 1: Bar chart @nsight.analyze.plot( filename="06_bar_chart.png", + metric="TFLOPS", title="Matrix Multiplication Performance", ylabel="TFLOPs/s", plot_type="bar", # Use bar chart instead of line plot @@ -68,6 +69,7 @@ def custom_style(fig: Any) -> None: @nsight.analyze.plot( filename="06_custom_plot.png", + metric="TFLOPS", plot_callback=custom_style, # Apply custom styling ) @nsight.analyze.kernel(configs=sizes, runs=10, derive_metric=compute_tflops) diff --git a/examples/09_advanced_metric_custom.py b/examples/09_advanced_metric_custom.py index 2080609..937ebb9 100644 --- a/examples/09_advanced_metric_custom.py +++ b/examples/09_advanced_metric_custom.py @@ -15,21 +15,21 @@ import nsight -sizes = [(2**i,) for i in range(10, 13)] +sizes = [(2**i,) for i in range(10, 12)] -def compute_avg_insts( +def compute_insts_statistics( ld_insts: int, st_insts: int, launch_sm_count: int, n: int -) -> float: +) -> dict[str, float]: """ - Compute average shared memory load/store instructions per SM. + Compute shared memory instruction statistics per SM. 
Custom metric function signature: - First several arguments: the measured metrics, must match the order - of metrics in @kernel decorator + of metrics specified in the @kernel decorator - Remaining arguments: must match the decorated function's signature - In this example: + In this example (metrics must be listed in this exact order in @kernel): - ld_insts: Total shared memory load instructions (from smsp__inst_executed_pipe_lsu.shared_op_ld.sum metric) - st_insts: Total shared memory store instructions @@ -45,21 +45,35 @@ def compute_avg_insts( n: Matrix size (n x n) - parameter from the decorated benchmark function Returns: - Average shared memory load/store instructions per SM + Dictionary containing four derived metrics: + - "ld_insts_per_sm": Average load instructions per SM + - "st_insts_per_sm": Average store instructions per SM + - "insts_total": Total shared memory instructions (load + store) + - "insts_per_sm": Average total instructions per SM """ + ld_insts_per_sm = ld_insts / launch_sm_count + st_insts_per_sm = st_insts / launch_sm_count + insts_total = ld_insts + st_insts insts_per_sm = (ld_insts + st_insts) / launch_sm_count - return insts_per_sm + + return { + "ld_insts_per_sm": ld_insts_per_sm, + "st_insts_per_sm": st_insts_per_sm, + "insts_total": insts_total, + "insts_per_sm": insts_per_sm, + } @nsight.analyze.plot( filename="09_advanced_metric_custom.png", + metric="insts_per_sm", ylabel="Average Shared Memory Load/Store Instructions per SM", # Custom y-axis label annotate_points=True, # Show values on the plot ) @nsight.analyze.kernel( configs=sizes, runs=10, - derive_metric=compute_avg_insts, # Use custom metric + derive_metric=compute_insts_statistics, # Use custom metric metrics=[ "smsp__sass_inst_executed_op_shared_ld.sum", "smsp__sass_inst_executed_op_shared_st.sum", @@ -72,10 +86,15 @@ def benchmark_avg_insts(n: int) -> None: """ a = torch.randn(n, n, device="cuda") b = torch.randn(n, n, device="cuda") + c = torch.randn(2 * n, 2 * n, device="cuda") + d = torch.randn(2 * n, 2 * n, device="cuda") - with nsight.annotate("matmul"): + with nsight.annotate("@-operator"): _ = a @ b + with nsight.annotate("torch-matmul"): + _ = torch.matmul(c, d) + def main() -> None: result = benchmark_avg_insts() diff --git a/examples/11_output_csv.py b/examples/11_output_csv.py index 5e983ce..23e1921 100644 --- a/examples/11_output_csv.py +++ b/examples/11_output_csv.py @@ -21,7 +21,7 @@ # Get current directory for output current_dir = os.path.dirname(os.path.abspath(__file__)) -output_prefix = f"{current_dir}/example10_" +output_prefix = f"{current_dir}/example11_" # Matrix sizes to benchmark @@ -101,7 +101,7 @@ def read_and_display_csv_files() -> None: # Find CSV files csv_files = [] for file in os.listdir(current_dir): - if file.startswith("example10_") and file.endswith(".csv"): + if file.startswith("example11_") and file.endswith(".csv"): csv_files.append(os.path.join(current_dir, file)) for file_path in sorted(csv_files): @@ -129,7 +129,7 @@ def read_and_display_csv_files() -> None: def main() -> None: # Clean up any previous output files for old_file in os.listdir(current_dir): - if old_file.startswith("example10_") and old_file.endswith( + if old_file.startswith("example11_") and old_file.endswith( (".csv", ".ncu-rep", ".log") ): os.remove(os.path.join(current_dir, old_file)) diff --git a/nsight/analyze.py b/nsight/analyze.py index e8bbbc4..27fd9b0 100644 --- a/nsight/analyze.py +++ b/nsight/analyze.py @@ -12,6 +12,7 @@ import matplotlib import matplotlib.figure import numpy 
as np +from numpy.typing import NDArray import nsight.collection as collection import nsight.visualization as visualization @@ -31,7 +32,7 @@ def kernel( *, configs: Iterable[Any] | None = None, runs: int = 1, - derive_metric: Callable[..., float] | None = None, + derive_metric: Callable[..., float | dict[str, float]] | None = None, normalize_against: str | None = None, output: Literal["quiet", "progress", "verbose"] = "progress", metrics: Sequence[str] = ["gpu__time_duration.sum"], @@ -52,7 +53,7 @@ def kernel( *, configs: Iterable[Any] | None = None, runs: int = 1, - derive_metric: Callable[..., float] | None = None, + derive_metric: Callable[..., float | dict[str, float]] | None = None, normalize_against: str | None = None, output: Literal["quiet", "progress", "verbose"] = "progress", metrics: Sequence[str] = ["gpu__time_duration.sum"], @@ -105,7 +106,12 @@ def wrapped_function(*args, configs=None, **kwargs) -> ProfileResults This can be used to compute derived metrics like TFLOPs that cannot be captured by ncu directly. The function takes the metric values and the arguments of the profile-decorated function and returns the new - metric. The parameter order requirements for the custom function: + metrics. Return value can be either: + + - A single value (float/int): Will be added as a new metric with the function name as the metric name. For lambda functions, the metric name will be ``""``. + - A dictionary: Keys will be used as metric names, values as metric values. + + The parameter order requirements for the custom function: - First several arguments: Must exactly match the order of metrics declared in the @kernel decorator. These arguments will receive the actual measured values of those metrics. - Remaining arguments: Must exactly match the signature of the decorated function. In other words, the original function's parameters are passed in order. @@ -175,8 +181,7 @@ def wrapped_function(*args, configs=None, **kwargs) -> ProfileResults - ``Annotation``: Name of the annotated region being profiled - ``Value``: Raw metric values collected by the profiler - - ``Metric``: The metrics being collected (e.g., ``gpu__time_duration.sum``) - - ``Transformed``: Name of the function used to transform the metrics (specified via ``derive_metric``), or ``False`` if no transformation was applied. For lambda functions, this shows ``""`` + - ``Metric``: The metrics being collected (e.g., ``gpu__time_duration.sum``) and the metrics being derived - ``Kernel``: Name of the GPU kernel(s) launched - ``GPU``: GPU device name - ``Host``: Host machine name @@ -197,8 +202,7 @@ def wrapped_function(*args, configs=None, **kwargs) -> ProfileResults - ``CI95_Upper``: Upper bound of the 95% confidence interval - ``RelativeStdDevPct``: Standard deviation as a percentage of the mean - ``StableMeasurement``: Boolean indicating if the measurement is stable (low variance). The measurement is stable if ``RelativeStdDevPct`` < 2 % . - - ``Metric``: The metrics being collected - - ``Transformed``: Name of the function used to transform the metrics (specified via ``derive_metric``), or ``False`` if no transformation was applied. 
For lambda functions, this shows ``""`` + - ``Metric``: The metrics being collected and the metrics being derived - ``Kernel``: Name of the GPU kernel(s) launched - ``GPU``: GPU device name - ``Host``: Host machine name @@ -259,29 +263,63 @@ def _create_profiler() -> collection.core.NsightProfiler: return profiler(_func) # type: ignore[return-value] -def _validate_metric(result: collection.core.ProfileResults) -> None: +def _legalize_metric(result: collection.core.ProfileResults, metric: str | None) -> str: """ - Check if ProfileResults contains only a single metric. + Legalize the metric parameter for plotting. + + This function ensures that the provided metric parameter is valid for plotting. + If the metric is None, it verifies there is exactly one metric available and + returns that metric's name. If a metric is specified, it validates that the + metric exists in the profile results and returns itself. Args: - result: ProfileResults object + result: ProfileResults object containing profiling data. + metric: The name of the metric to plot, or None to use the single metric + if only one exists. + + Returns: + The validated metric name for plotting. Raises: - ValueError: If multiple metrics are detected + ValueError: Raised in the following cases: + 1. When `metric` is None and multiple metrics are found. + 2. When `metric` is None and no metrics are found. + 3. When `metric` is specified but not present in the results. """ df = result.to_dataframe() - # Check for multiple metrics in "Metric" column - unique_metrics = df["Metric"].unique() - if len(unique_metrics) > 1: - raise ValueError( - f"Cannot visualize {len(unique_metrics)} > 1 metrics with the " - "@nsight.analyze.plot decorator." - ) + # Extract unique metric names from the DataFrame + unique_metrics: NDArray[Any] = df["Metric"].unique() + n_metrics = len(unique_metrics) + + if metric is None: + # Case: No metric specified - ensure there is exactly one metric + if n_metrics > 1: + raise ValueError( + f"Cannot visualize multiple metrics ({n_metrics} > 1) with the " + "@nsight.analyze.plot decorator. Please specify a metric to plot. " + f"The metrics found are: {unique_metrics.tolist()}" + ) + elif n_metrics == 0: + raise ValueError( + "No metrics found in the profile results. Please ensure that the " + "profiler is configured correctly." + ) + else: # n_metrics == 1 + return str(unique_metrics[0]) + else: + # Case: Metric specified - ensure it exists in the results + if metric not in unique_metrics: + raise ValueError( + f"Metric '{metric}' not found in the profile results. Please check " + f"the metric name and try again. The metrics found are: {unique_metrics.tolist()}." + ) + return metric def plot( filename: str = "plot.png", + metric: str | None = None, *, title: str = "Nsight Analyze Kernel Plot Results", ylabel: str | None = None, @@ -320,6 +358,10 @@ def wrapped_function(*args, configs=None, **kwargs) -> ProfileResults def my_func(...): Args: + metric: The specific metric to plot (e.g., ``gpu__time_duration.sum``). + If None (default), the function will plot the single metric found in + the ``ProfileResults`` object containing profiling data, if only one exists, + otherwise it will raise an error if multiple metrics are present. Default: ``None`` filename: Filename to save the plot. Default: ``'plot'`` title: Title for the plot. Default: ``'Nsight Analyze Kernel Plot Results'`` ylabel: Label for the y-axis in the generated plot. 
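
The hunk above documents the new ``metric`` argument alongside ``_legalize_metric``; for orientation, here is a minimal usage sketch in the style of examples/03 and 04 from this PR. The function names, output filename, and config sizes below are illustrative only (they are not part of the changeset), and running it assumes the same CUDA-plus-ncu setup the bundled examples already require:

.. code-block:: python

    import torch
    import nsight


    def bytes_and_tflops(time_ns: float, n: int) -> dict[str, float]:
        # 2 * n^3 FLOPs for an n x n matmul; the byte count is a rough
        # float32 estimate (three n x n matrices touched), used here only
        # to show a multi-key dictionary return.
        flops = 2 * n * n * n
        return {
            "TFLOPS": flops / (time_ns / 1e9) / 1e12,
            "Bytes": float(3 * n * n * 4),
        }


    # The dictionary return adds two derived metrics on top of
    # gpu__time_duration.sum, so plot() must name one of them; leaving
    # metric=None here would trigger the "Cannot visualize multiple metrics"
    # ValueError raised by _legalize_metric.
    @nsight.analyze.plot(filename="metric_param_demo.png", metric="TFLOPS", ylabel="TFLOPS")
    @nsight.analyze.kernel(configs=[(1024,), (2048,)], runs=5, derive_metric=bytes_and_tflops)
    def demo_matmul(n: int) -> None:
        a = torch.randn(n, n, device="cuda")
        b = torch.randn(n, n, device="cuda")
        with nsight.annotate("matmul"):
            _ = a @ b
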
@@ -358,11 +400,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: result = func(*args, **kwargs) if "NSPY_NCU_PROFILE" not in os.environ: - # Check for multiple metrics or complex data structures - _validate_metric(result) + # Legalize the user-provided metric for plotting + legalized_metric = _legalize_metric(result, metric) visualization.visualize( result.to_dataframe(), + metric=legalized_metric, row_panels=row_panels, col_panels=col_panels, x_keys=x_keys, diff --git a/nsight/collection/core.py b/nsight/collection/core.py index c2eb19a..d0d0b85 100644 --- a/nsight/collection/core.py +++ b/nsight/collection/core.py @@ -271,7 +271,7 @@ class ProfileSettings: Will display a progress bar, detailed output for each config along with the profiler logs """ - derive_metric: Callable[..., float] | None + derive_metric: Callable[..., float | dict[str, float]] | None """ A function to transform the collected metrics. This can be used to compute derived metrics like TFLOPs that cannot @@ -339,8 +339,7 @@ def to_dataframe(self) -> pd.DataFrame: - ``CI95_Upper``: Upper bound of the 95% confidence interval - ``RelativeStdDevPct``: Standard deviation as a percentage of the mean - ``StableMeasurement``: Boolean indicating if the measurement is stable (low variance). The measurement is stable if ``RelativeStdDevPct`` < 2 % . - - ``Metric``: The metrics being collected - - ``Transformed``: Name of the function used to transform the metrics (specified via ``derive_metric``), or ``False`` if no transformation was applied. For lambda functions, this shows ``""`` + - ``Metric``: The metrics being collected and the metrics being derived - ``Kernel``: Name of the GPU kernel(s) launched - ``GPU``: GPU device name - ``Host``: Host machine name diff --git a/nsight/extraction.py b/nsight/extraction.py index 113518f..f715301 100644 --- a/nsight/extraction.py +++ b/nsight/extraction.py @@ -19,7 +19,7 @@ import inspect import socket from collections.abc import Callable, Sequence -from typing import Any, List, Tuple +from typing import Any, List, Tuple, TypeAlias import ncu_report import numpy as np @@ -29,6 +29,9 @@ from nsight import exceptions, utils from nsight.utils import is_scalar +DerivedValue: TypeAlias = float | int | None +DerivedValueDict: TypeAlias = dict[str, DerivedValue] + def extract_ncu_action_data(action: Any, metrics: Sequence[str]) -> utils.NCUActionData: """ @@ -63,45 +66,6 @@ def extract_ncu_action_data(action: Any, metrics: Sequence[str]) -> utils.NCUAct ) -def explode_dataframe(df: pd.DataFrame) -> pd.DataFrame: - """ - Explode columns with list/tuple/np.ndarray values into multiple rows. - Two scenarios: - - 1. No derived metrics (all "Transformed" = False): - - All columns maybe contain multiple values (lists/arrays). - - Use `explode()` to flatten each list element into separate rows. - - 2. With derived metrics: - - Metric columns contain either single-element lists or scalars. - - Only flatten single-element lists to scalars, don't create new rows. - - Args: - df: Dataframe to be exploded. - - Returns: - Exploded dataframe. - """ - df_explode = None - if df["Transformed"].eq(False).all(): - # 1: No derived metrics - explode all columns with sequences into rows. - df_explode = df.apply(pd.Series.explode).reset_index(drop=True) - else: - # 2: With derived metrics - only explode columns with single-value sequences. 
- df_explode = df.apply( - lambda col: ( - col.apply( - lambda x: ( - x[0] - if isinstance(x, (list, tuple, np.ndarray)) and len(x) == 1 - else x - ) - ) - ) - ) - return df_explode - - def extract_df_from_report( report_path: str, metrics: Sequence[str], @@ -145,13 +109,16 @@ def extract_df_from_report( ) annotations: List[str] = [] - all_values: List[NDArray[Any] | None] = [] + all_values: List[Tuple[Any, ...] | None] = [] + all_transformed_values: List[ + List[DerivedValue] | DerivedValue | NDArray[Any] | None + ] = [] kernel_names: List[str] = [] gpus: List[str] = [] compute_clocks: List[int] = [] memory_clocks: List[int] = [] all_metrics: List[Tuple[str, ...]] = [] - all_transformed_metrics: List[str | bool] = [] + all_transformed_metrics: List[List[str] | str | bool] = [] hostnames: List[str] = [] sig = inspect.signature(func) @@ -245,19 +212,46 @@ def extract_df_from_report( # evaluate the measured metrics values = data.values if derive_metric is not None: - derived_metric: float | int | None = ( + if not callable(derive_metric): + raise TypeError("derive_metric must be a callable function") + + if values is not None: + derive_metric_params = inspect.signature(derive_metric).parameters + has_varargs: bool = any( + p.kind == inspect.Parameter.VAR_POSITIONAL + for p in derive_metric_params.values() + ) + actual_params = None if has_varargs else len(derive_metric_params) + # If there are varargs, skip the check + if actual_params is not None: + expected_params = len(values) + len(conf) + if actual_params != expected_params: + raise ValueError( + f"derive_metric expects {expected_params} parameters " + f"({len(values)} metric values + {len(conf)} configs), " + f"but has {actual_params} parameters" + ) + + derived_metric: DerivedValueDict | DerivedValue = ( None if values is None else derive_metric(*values, *conf) ) - values = derived_metric # type: ignore[assignment] - derive_metric_name = derive_metric.__name__ - all_transformed_metrics.append(derive_metric_name) - else: - all_transformed_metrics.append(False) - - all_values.append(values) + if isinstance(derived_metric, dict): + # If the derived metric is a dict, then we have multiple metrics + # and use the keys of the dict as metric names. + metric_names, metric_values = zip( + *[(k, v) for k, v in derived_metric.items()] + ) + all_transformed_values.append(list(metric_values)) + all_transformed_metrics.append(list(metric_names)) + else: + # If the derived metric is a scalar, then we have a single metric + # and use the name of the function as the metric name. 
+ all_transformed_values.append(derived_metric) + all_transformed_metrics.append(derive_metric.__name__) # gather remaining required data annotations.append(annotation) + all_values.append(tuple(values) if values is not None else None) all_metrics.append(tuple(metrics)) hostnames.append(socket.gethostname()) # Add a field for every config argument @@ -270,7 +264,6 @@ def extract_df_from_report( "Annotation": annotations, "Value": all_values, "Metric": all_metrics, - "Transformed": all_transformed_metrics, "Kernel": kernel_names, "GPU": gpus, "Host": hostnames, @@ -283,6 +276,30 @@ def extract_df_from_report( df_data[arg_name] = arg_values # Explode the dataframe - df = explode_dataframe(pd.DataFrame(df_data)) + df = pd.DataFrame(df_data).apply(pd.Series.explode).reset_index(drop=True) + + if derive_metric is not None: + transformed_df_data = { + "Annotation": annotations, + "Value": all_transformed_values, + "Metric": all_transformed_metrics, + "Kernel": kernel_names, + "GPU": gpus, + "Host": hostnames, + "ComputeClock": compute_clocks, + "MemoryClock": memory_clocks, + } + + for arg_name, arg_values in arg_arrays.items(): + transformed_df_data[arg_name] = arg_values + + transformed_df = ( + pd.DataFrame(transformed_df_data) + .apply(pd.Series.explode) + .reset_index(drop=True) + ) + + # Concat the two dataframes + df = pd.concat([df, transformed_df], ignore_index=True) return df diff --git a/nsight/transformation.py b/nsight/transformation.py index 2c4a9bc..4e61f69 100644 --- a/nsight/transformation.py +++ b/nsight/transformation.py @@ -76,13 +76,13 @@ def convert_non_sortable_columns(dframe: pd.DataFrame) -> pd.DataFrame: } # The columns to aggregate except for the function parameters - groupby_columns = ["Annotation", "Metric", "Transformed"] + groupby_columns = ["Annotation", "Metric"] # Add assertion-based unique selection for remaining fields remaining_fields = [ col for col in df.columns - if col not in [*groupby_columns, "Value", "_original_order"] + func_fields + if col not in [*groupby_columns, "Value", "_original_order", *func_fields] ] for col in remaining_fields: @@ -136,11 +136,11 @@ def convert_non_sortable_columns(dframe: pd.DataFrame) -> pd.DataFrame: ), f"Annotation '{normalize_against}' not found in data." # Columns of normalization dataframe to merge on - merge_on = func_fields + ["Metric", "Transformed"] + merge_on = [*func_fields, "Metric"] # Create a DataFrame to hold the normalization values normalization_df = agg_df[agg_df["Annotation"] == normalize_against][ - merge_on + ["AvgValue"] + [*merge_on, "AvgValue"] ] normalization_df = normalization_df.rename( columns={"AvgValue": "NormalizationValue"} diff --git a/nsight/visualization.py b/nsight/visualization.py index 79857ef..b90dcde 100644 --- a/nsight/visualization.py +++ b/nsight/visualization.py @@ -20,6 +20,7 @@ def visualize( agg_df: str | pd.DataFrame, + metric: str | None, row_panels: Sequence[str] | None, col_panels: Sequence[str] | None, x_keys: Sequence[str] | None = None, @@ -43,6 +44,10 @@ def visualize( Args: agg_df: Aggregated profiling data or path to CSV file. + metric: The specific metric to plot (e.g., ``gpu__time_duration.sum``). + If None (default), the function will plot the single metric found in + the ``ProfileResults`` object containing profiling data, if only one exists, + otherwise it will raise an error if multiple metrics are present. Default: ``None`` row_panels: List of fields for whose unique values to create a new subplot along the vertical axis. 
col_panels: List of fields for whose unique values @@ -69,6 +74,10 @@ def visualize( agg_df, pd.DataFrame ), f"agg_df must be a pandas DataFrame or a CSV file path, not {type(agg_df)}" + # Filter by metric + if metric is not None: + agg_df = agg_df[agg_df["Metric"] == metric].copy() + row_panels = row_panels or [] col_panels = col_panels or [] @@ -81,7 +90,7 @@ def visualize( # Build Configuration field excluding variant_fields annotation_idx = agg_df.columns.get_loc("AvgValue") - func_fields = list(agg_df.columns[3:annotation_idx]) + func_fields = list(agg_df.columns[2:annotation_idx]) subplot_fields = row_panels + col_panels # type: ignore[operator] non_panel_fields = [ field @@ -202,7 +211,7 @@ def visualize( config_fields = x_keys else: annotation_idx = local_df.columns.get_loc("AvgValue") - func_fields = list(local_df.columns[3:annotation_idx]) + func_fields = list(local_df.columns[2:annotation_idx]) subplot_fields = row_panels + col_panels # type: ignore[operator] config_exclude = set(variant_fields or []) config_fields = [ diff --git a/tests/test_api_params.py b/tests/test_api_params.py index 0873080..260778b 100644 --- a/tests/test_api_params.py +++ b/tests/test_api_params.py @@ -52,6 +52,7 @@ def get_app_args() -> argparse.Namespace: # nsight.analyze.plot() parameters # TBD no command line arguments yet for: row_panels, col_panels, x_keys, annotate_points, show_aggregate parser.add_argument("--plot-title", "-l", default="test", help="Plot title") + parser.add_argument("--plot-metric", "-e", default=None, help="Plot metric") parser.add_argument( "--plot-filename", "-f", default="params_test1.png", help="Plot filename" ) @@ -86,6 +87,7 @@ def einsum(a: torch.Tensor, b: torch.Tensor) -> Any: @nsight.analyze.plot( title=args.plot_title, + metric=args.plot_metric, filename=args.plot_filename, plot_type=args.plot_type, print_data=args.plot_print_data, diff --git a/tests/test_collection.py b/tests/test_collection.py index 4e430e6..4f4d11c 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -3,7 +3,7 @@ import subprocess import sys -from typing import Any, Dict +from typing import Any from unittest.mock import MagicMock, call, patch import pytest @@ -95,7 +95,7 @@ class EnvMatcher(dict[str, str]): def __eq__(self, other: object) -> bool: if not isinstance(other, dict): return False - subset: Dict[str, str] = self + subset: dict[str, str] = self return all(item in other.items() for item in subset.items()) pytest.helpers = type("helpers", (), {})() @@ -103,7 +103,7 @@ def __eq__(self, other: object) -> bool: def mock_any_command_string() -> Matcher: return Matcher("any-ncu-command") - def env_contains(expected_subset: Dict[str, str]) -> EnvMatcher: + def env_contains(expected_subset: dict[str, str]) -> EnvMatcher: return EnvMatcher(expected_subset) pytest.helpers.mock_any_command_string = mock_any_command_string diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 45aa88e..957e09d 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -307,19 +307,26 @@ def no_args_with_transform() -> None: assert result is not None, "ProfileResults should be returned" df = result.to_dataframe() - # Should have exactly 1 row - assert len(df) == 1, f"Expected 1 row in dataframe, got {len(df)}" + # Expected: exactly 2 rows, one for each metric ("gpu__time_duration.sum" and "custom_metric") + assert len(df) == 2, f"Expected 2 row in dataframe, got {len(df)}" + + # Verify the collected metric + assert ( + df["Metric"].iloc[0] == "gpu__time_duration.sum" + ), 
f"Expected 'gpu__time_duration.sum' in Metric column, got {df['Metric'].iloc[0]}" # Verify the transformation was applied assert ( - df["Transformed"].iloc[0] == "custom_metric" - ), f"Expected 'custom_metric' in Transformed column, got {df['Transformed'].iloc[0]}" + df["Metric"].iloc[1] == "custom_metric" + ), f"Expected 'custom_metric' in Metric column, got {df['Metric'].iloc[1]}" - # Verify the value is positive (transformed metric should still be positive) - assert df["AvgValue"].iloc[0] > 0, "Expected positive transformed metric value" + # Verify the values are positive (transformed metric should still be positive) + assert all(df["AvgValue"] > 0), "Expected positive transformed metric value" # Verify runs parameter was respected - assert df["NumRuns"].iloc[0] == 2, f"Expected 2 runs, got {df['NumRuns'].iloc[0]}" + assert all( + df["NumRuns"] == 2 + ), f"Expected 2 runs, got {list(df['NumRuns'][df['NumRuns'] == 2].values)}" # ---------------------------------------------------------------------------- @@ -930,7 +937,7 @@ def multiple_kernels_replay_test(n: int) -> None: # ============================================================================ -def _compute_custom_metric(time_ns: float, x: int, y: int) -> float: +def _compute_custom_metric_1(time_ns: float, x: int, y: int) -> float: """Transform time in nanoseconds to a custom metric based on matrix size.""" # Custom formula: operations per second (arbitrary for testing) operations = x * y @@ -938,18 +945,35 @@ def _compute_custom_metric(time_ns: float, x: int, y: int) -> float: return operations / time_s if time_s > 0 else 0.0 +def _compute_custom_metric_2(time_ns: float, x: int, y: int) -> dict[str, float]: + """Transform time in nanoseconds to a custom metric based on matrix size.""" + # Custom formula: operations per second (arbitrary for testing) + operations = x * y + time_s = time_ns / 1e9 + return {"Custom Metric": operations / time_s if time_s > 0 else 0.0} + + @pytest.mark.parametrize( "derive_metric_func,expected_name", [ - (_compute_custom_metric, "_compute_custom_metric"), + (_compute_custom_metric_1, "_compute_custom_metric_1"), + (_compute_custom_metric_2, "Custom Metric"), (lambda time_ns, x, y: (x * y) / (time_ns / 1e9) / 1e9, ""), + ( + lambda time_ns, x, y: {"Custom Metric": (x * y) / (time_ns / 1e9) / 1e9}, + "Custom Metric", + ), ], ) # type: ignore[untyped-decorator] def test_parameter_derive_metric(derive_metric_func: Any, expected_name: str) -> None: """Test the derive_metric parameter to transform collected metrics.""" + configs = [(100, 100), (200, 200)] + # Number of raw collected metric rows, excluding any derived metrics + raw_metric_rows = len(configs) # as metrics = 1, annotations = 1 + @nsight.analyze.kernel( - configs=[(100, 100), (200, 200)], + configs=configs, runs=2, output="quiet", derive_metric=derive_metric_func, @@ -963,17 +987,21 @@ def profiled_func(x: int, y: int) -> None: # Verify the transformed metric is present df = profile_output.to_dataframe() - assert "Transformed" in df.columns, "Transformed column should exist" - assert ( - df["Transformed"].iloc[0] == expected_name - ), f"Transformed column should show '{expected_name}'" + + assert all( + df["Metric"][0:raw_metric_rows] == "gpu__time_duration.sum" + ), f"Metric column (row-{0} ~ row-{raw_metric_rows-1}) should show 'gpu__time_duration.sum'" + + assert all( + df["Metric"][raw_metric_rows : len(df)] == expected_name + ), f"Metric column (row-{raw_metric_rows} ~ row-{len(df)-1}) should show '{expected_name}'" # Verify the metric 
values are transformed (should be positive numbers) assert "AvgValue" in df.columns, "AvgValue column should exist" assert all(df["AvgValue"] > 0), "All derived metric values should be positive" - # Verify we have results for both configs - assert len(df) == 2, "Should have results for 2 configurations" + # Verify we have results for all combinations (2 configs * 2 metrics = 4 rows) + assert len(df) == 2 * 2, "Should have results for 2 configurations * 2 metrics" # ============================================================================ @@ -1016,7 +1044,7 @@ def profiled_func(x: int, y: int) -> None: @pytest.mark.parametrize( # type: ignore[untyped-decorator] - "metrics, expected_result", + "metrics,expected_result", [ pytest.param( [ @@ -1042,12 +1070,12 @@ def profiled_func(x: int, y: int) -> None: ), ], ) -def test_parameter_metric(metrics: Sequence[str], expected_result: str) -> None: +def test_parameter_metrics(metrics: Sequence[str], expected_result: str) -> None: @nsight.analyze.plot(filename="plot.png", ylabel="Instructions") @nsight.analyze.kernel(configs=[(100, 100), (200, 200)], runs=2, metrics=metrics) def profiled_func(x: int, y: int) -> None: - _simple_kernel_impl(x, y, "test_parameter_metric") + _simple_kernel_impl(x, y, "test_parameter_metrics") # Run profiling if expected_result == "invalid_single": @@ -1075,7 +1103,115 @@ def profiled_func(x: int, y: int) -> None: with pytest.raises( ValueError, match=( - f"Cannot visualize {len(metrics)} > 1 metrics with the @nsight.analyze.plot decorator." + rf"Cannot visualize multiple metrics \({len(metrics)} > 1\) with the @nsight\.analyze\.plot decorator\." ), ): profiled_func() + + +# ============================================================================ +# metric parameter test +# ============================================================================ + + +@pytest.mark.parametrize( # type: ignore[untyped-decorator] + "metrics,metric_param,expected_result", + [ + pytest.param( + [ + "gpu__time_duration.sum", + ], + None, + "valid_single_none", + id="valid_single_none", + ), + pytest.param( + [ + "gpu__time_duration.sum", + ], + "gpu__time_duration.sum", + "valid_single_specified", + id="valid_single_specified", + ), + pytest.param( + [ + "smsp__inst_executed.sum", + "smsp__inst_issued.sum", + ], + "smsp__inst_executed.sum", + "valid_multiple_specified", + id="valid_multiple_specified", + ), + pytest.param( + [ + "smsp__inst_executed.sum", + "smsp__inst_issued.sum", + ], + None, + "invalid_multiple_none", + id="invalid_multiple_none", + ), + pytest.param( + [ + "smsp__inst_executed.sum", + "smsp__inst_issued.sum", + ], + "invalid_metric", + "invalid_metric_specified", + id="invalid_metric_specified", + ), + ], +) +def test_legalize_metric_in_plot( + metrics: Sequence[str], metric_param: str | None, expected_result: str +) -> None: + + @nsight.analyze.plot( + filename="plot.png", ylabel="Instructions", metric=metric_param + ) + @nsight.analyze.kernel(configs=[(100, 100), (200, 200)], runs=2, metrics=metrics) + def profiled_func(x: int, y: int) -> None: + _simple_kernel_impl(x, y, "test_legalize_metric_in_plot") + + if expected_result == "invalid_multiple_none": + with pytest.raises( + ValueError, + match=( + rf"Cannot visualize multiple metrics \({len(metrics)} > 1\) with the @nsight\.analyze\.plot decorator\." 
+ ), + ): + profiled_func() + elif expected_result == "invalid_metric_specified": + with pytest.raises( + ValueError, + match=(rf"Metric '{metric_param}' not found in the profile results\."), + ): + profiled_func() + elif expected_result in [ + "valid_single_none", + "valid_single_specified", + "valid_multiple_specified", + ]: + profile_output = profiled_func() + df = profile_output.to_dataframe() + print(df) + + # Check if the dataframe has the right metrics + assert all( + df["Metric"].isin(metrics) + ), f"Invalid metric name {df.loc[df['Metric'] != metrics[0], 'Metric'].iloc[0]} found in output dataframe" + + if metric_param is None: + expected_metric = metrics[0] + else: + expected_metric = metric_param + + metric_rows = df[df["Metric"] == expected_metric] + assert ( + len(metric_rows) == 2 # 2 configs + ), f"Expected metric '{expected_metric}' not found in results" + + # Check if the metric values are valid + assert all( + df["AvgValue"].notna() & (df["AvgValue"] > 0) + ), f"Invalid AvgValue for metric {metrics}"