Add peak memory usage and footprint measurement (#34)
Summary:
Fix #28
Add `gpu_peak_mem,mem_footprint,cpu_peak_mem` to `--metrics`.

Pull Request resolved: #34

Test Plan:
```
% python run.py --op fused_linear_cross_entropy --num-inputs 4 --metrics latency,gpu_peak_mem,mem_footprint,cpu_peak_mem
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [03:13<00:00, 48.38s/it]
  x_val    LMHeadCE-gpu_peak_mem    LMHeadCE-cpu_peak_mem    LMHeadCE-latency    LigerLMHeadCE-gpu_peak_mem    LigerLMHeadCE-mem_footprint    LigerLMHeadCE-cpu_peak_mem    LigerLMHeadCE-latency    inductor_fused_linear_cross_entropy-gpu_peak_mem    inductor_fused_linear_cross_entropy-mem_footprint    inductor_fused_linear_cross_entropy-cpu_peak_mem    inductor_fused_linear_cross_entropy-latency
-------  -----------------------  -----------------------  ------------------  ----------------------------  -----------------------------  ----------------------------  -----------------------  --------------------------------------------------  ---------------------------------------------------  --------------------------------------------------  ---------------------------------------------
      0              8.506082304                  85.0289             145.928                       6.66886                        1.27549                       93.7701                  537.321                                             6.40477                                              1.32809                                             96.5104                                        166.69
      1             12.775916544                  96.5116             285.919                       7.00013                        1.8251                        96.5116                 1007.9                                               8.57329                                              1.4902                                              96.5722                                        290.965
      2             21.315585024                  96.5722             579.465                       7.66267                        2.78175                       97.0613                 6807.01                                             12.9103                                               1.65105                                             97.1784                                        583.152
      3                 CUDA OOM                                                                    8.98984                                                      97.6245                 3416.81                                             21.5844                                                                                                   97.7367                                       1219.02
```
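
For reference, the `mem_footprint` columns above report the ratio of the baseline's GPU peak memory to the op's GPU peak memory (see the `metrics.mem_footprint` computation in the diff below), and `gpu_peak_mem`/`cpu_peak_mem` are reported in GB. A minimal sketch of that arithmetic, using the `x_val = 0` row of the table; the variable names are illustrative only and not part of the commit:

```
# Sketch only: reproduces the mem_footprint arithmetic from the diff below,
# using the x_val = 0 row of the table above.
baseline_gpu_peak_mem = 8.506082304  # LMHeadCE-gpu_peak_mem (GB)
liger_gpu_peak_mem = 6.66886         # LigerLMHeadCE-gpu_peak_mem (GB)

# mem_footprint = baseline peak GPU memory / op peak GPU memory, so a value
# greater than 1 means the op needs less peak GPU memory than the baseline.
mem_footprint = baseline_gpu_peak_mem / liger_gpu_peak_mem
print(f"{mem_footprint:.5f}")  # ~1.27549, matching LigerLMHeadCE-mem_footprint
```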

Reviewed By: xuzhao9

Differential Revision: D65302730

Pulled By: FindHao

fbshipit-source-id: de29e60e03965a7a98e0d5ba8ff5b32d60a8a85f
FindHao authored and facebook-github-bot committed Nov 5, 2024
1 parent 136d272 commit 4cd607b
Showing 1 changed file with 118 additions and 19 deletions.
137 changes: 118 additions & 19 deletions tritonbench/utils/triton_op.py
@@ -18,10 +18,10 @@
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import numpy
import psutil
import tabulate
import torch
import triton

from tritonbench.components.ncu import analyzer as ncu_analyzer
from tritonbench.utils.env_utils import (
apply_precision,
@@ -62,7 +62,7 @@ class BenchmarkOperatorBackend:
REGISTERED_METRICS: Dict[str, List[str]] = {}
REGISTERED_X_VALS: Dict[str, str] = {}
BASELINE_BENCHMARKS: Dict[str, str] = {}
-BASELINE_SKIP_METRICS = set(["speedup", "accuracy"])
BASELINE_SKIP_METRICS = {"speedup", "accuracy", "mem_footprint"}
X_ONLY_METRICS = set(["hw_roofline"])
PRECISION_DTYPE_MAPPING = {
"fp32": torch.float32,
@@ -225,6 +225,8 @@ class BenchmarkOperatorMetrics:
best_config: Optional[Dict[str, Any]] = None
# extra metrics
extra_metrics: Optional[Dict[str, float]] = None
# mem footprint
mem_footprint: Optional[float] = None


BUILTIN_METRICS = {x.name for x in fields(BenchmarkOperatorMetrics)} - {"extra_metrics"}
Expand Down Expand Up @@ -906,7 +908,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
fn = self._get_bm_func(fn_name)
if baseline:
self.baseline_fn = fn
if set(["latency", "tflops", "speedup", "compile_time"]) & set(
if {"latency", "tflops", "speedup", "compile_time"} & set(
self.required_metrics
):
if self.use_cuda_graphs:
@@ -925,6 +927,29 @@ def _init_extra_metrics() -> Dict[str, Any]:
return_mode="median",
grad_to_none=self.get_grad_to_none(self.example_inputs),
)
if {"gpu_peak_mem", "gpu_mem_footprint", "cpu_peak_mem"} & set(
self.required_metrics
):
metrics.cpu_peak_mem, metrics.gpu_peak_mem = self.get_peak_mem(
fn,
grad_to_none=self.get_grad_to_none(self.example_inputs),
required_metrics=self.required_metrics,
use_cuda_graphs=self.use_cuda_graphs,
)
if (
"mem_footprint" in self.required_metrics
and "gpu_peak_mem" in self.required_metrics
and self.baseline_metrics
):
if (
self.baseline_metrics.gpu_peak_mem is not None
and metrics.gpu_peak_mem is not None
):
metrics.mem_footprint = (
self.baseline_metrics.gpu_peak_mem / metrics.gpu_peak_mem
)
else:
metrics.mem_footprint = None
if "walltime" in self.required_metrics:
metrics.walltime = do_bench_walltime(
fn,
@@ -942,13 +967,6 @@ def _init_extra_metrics() -> Dict[str, Any]:
if self.baseline_metrics and self.baseline_metrics.error_msg
else None
)
-if (
-"cpu_peak_mem" in self.required_metrics
-or "gpu_peak_mem" in self.required_metrics
-):
-metrics.cpu_peak_mem, _device_id, metrics.gpu_peak_mem = (
-self.get_peak_mem(fn, self.tb_args.metrics_gpu_backend)
-)
if not baseline and "accuracy" in self.required_metrics:
metrics.accuracy = (
self._get_accuracy(fn, self.baseline_fn)
Expand Down Expand Up @@ -1068,16 +1086,97 @@ def _init_extra_metrics() -> Dict[str, Any]:
metrics.error_msg = str(e)
return metrics

def do_bench_cudagraph_mem(
self, fn, n_repeat=2, grad_to_none=None, device_type="cuda"
):
if torch.cuda.current_stream() == torch.cuda.default_stream():
raise RuntimeError(
"Cannot capture graph in default stream. Please use side stream in benchmark code."
)
# warmup
fn()
if grad_to_none is not None:
for x in grad_to_none:
x.detach_()
x.requires_grad_(True)
x.grad = None
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
fn()
torch.cuda.synchronize()
g.replay()
torch.cuda.synchronize()
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for _ in range(n_repeat):
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
fn()
torch.cuda.synchronize()

def do_bench_mem(self, fn, n_repeat=2, grad_to_none=None, device_type="cuda"):
di = torch._dynamo.device_interface.get_interface_for_device(device_type)
# warmup
fn()
di.synchronize()
# benchmark
for _ in range(n_repeat):
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
fn()
di.synchronize()

def get_peak_mem(
-self, fn: Callable, metrics_memory_usage_backend: str
-) -> Tuple[Optional[float], Optional[str], Optional[float]]:
-raise NotImplementedError("Peak GPU Memory is not supported yet.")
-# return get_peak_memory(
-#     func=fn,
-#     device=self.device,
-#     metrics_needed=["gpu_peak_mem", "cpu_peak_mem"],
-#     metrics_gpu_backend=metrics_memory_usage_backend,
-# )
self,
fn: Callable,
grad_to_none: Optional[List[torch.Tensor]] = None,
required_metrics: Optional[List[str]] = None,
use_cuda_graphs: bool = False,
device_type: str = "cuda",
) -> Tuple[Optional[float], Optional[float]]:
"""Measures peak CPU and GPU memory usage during function execution.
Args:
fn (Callable): The function to measure memory usage for.
grad_to_none (Optional[List[torch.Tensor]], optional): List of tensors whose gradients
should be set to None between iterations. Defaults to None.
required_metrics (Optional[List[str]], optional): List of metrics to measure.
Supported values: ["gpu_peak_mem", "mem_footprint", "cpu_peak_mem"].
Defaults to None.
use_cuda_graphs (bool, optional): Whether to use CUDA graphs for measurement.
Defaults to False.
device_type (str, optional): Device to measure memory for ("cuda" or "cpu").
Defaults to "cuda".
Returns:
Tuple[Optional[float], Optional[float]]: A tuple containing:
- Peak CPU memory usage in GB (None if not requested)
- Peak GPU memory usage in GB (None if not requested or not on CUDA)
"""
gpu_peak_mem = None
cpu_peak_mem = None
if device_type == "cuda":
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
if use_cuda_graphs:
self.do_bench_cudagraph_mem(
fn, n_repeat=2, grad_to_none=grad_to_none, device_type=device_type
)
else:
self.do_bench_mem(
fn, n_repeat=2, grad_to_none=grad_to_none, device_type=device_type
)
if device_type == "cuda" and (
{"gpu_peak_mem", "mem_footprint"} & set(required_metrics)
):
gpu_peak_mem = torch.cuda.max_memory_allocated() / 10**9
if "cpu_peak_mem" in required_metrics:
total = psutil.virtual_memory().total
percentage = psutil.Process(os.getpid()).memory_percent()
cpu_peak_mem = percentage * total / 10**9
return cpu_peak_mem, gpu_peak_mem

def nsys_rep(self, input_id: int, fn_name: str) -> str:
import subprocess
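
One usage note on the CUDA-graph path added above: `do_bench_cudagraph_mem` raises a RuntimeError when invoked on the default CUDA stream, so memory measurement with `use_cuda_graphs` enabled has to run on a side stream. A minimal sketch of the standard PyTorch side-stream pattern that satisfies that check; `bench_fn` is a hypothetical stand-in for the operator's benchmark callable and is not part of the commit:

```
import torch

def bench_fn():
    # Hypothetical stand-in for the measured op.
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    return a @ b

# Run under a non-default stream so torch.cuda.current_stream() is not the
# default stream, which is what do_bench_cudagraph_mem checks before capture.
side_stream = torch.cuda.Stream()
with torch.cuda.stream(side_stream):
    bench_fn()
torch.cuda.synchronize()
```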
