
Commit f74fd56

FindHao authored and facebook-github-bot committed

Rename mem_footprint to mem_footprint_compression_ratio (#66)

Summary:
Fix #44

Pull Request resolved: #66

Reviewed By: xuzhao9

Differential Revision: D66312472

Pulled By: FindHao

fbshipit-source-id: 299a37622f5d066461f8779f5d7b15f72ae5d27a

1 parent bde7986 · commit f74fd56

File tree

1 file changed: +12 −10 lines

tritonbench/utils/triton_op.py (+12 −10)
@@ -63,7 +63,7 @@ class BenchmarkOperatorBackend:
 REGISTERED_METRICS: Dict[str, List[str]] = {}
 REGISTERED_X_VALS: Dict[str, str] = {}
 BASELINE_BENCHMARKS: Dict[str, str] = {}
-BASELINE_SKIP_METRICS = {"speedup", "accuracy", "mem_footprint"}
+BASELINE_SKIP_METRICS = {"speedup", "accuracy", "mem_footprint_compression_ratio"}
 X_ONLY_METRICS = set(["hw_roofline"])
 PRECISION_DTYPE_MAPPING = {
     "fp32": torch.float32,
@@ -227,7 +227,7 @@ class BenchmarkOperatorMetrics:
     # extra metrics
     extra_metrics: Optional[Dict[str, float]] = None
     # mem footprint
-    mem_footprint: Optional[float] = None
+    mem_footprint_compression_ratio: Optional[float] = None
 
 
 BUILTIN_METRICS = {x.name for x in fields(BenchmarkOperatorMetrics)} - {"extra_metrics"}
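Because BUILTIN_METRICS is derived from the dataclass fields, renaming the field is enough to rename the user-visible metric. A self-contained sketch of that mechanism (field list trimmed down for illustration, not the full class):

from dataclasses import dataclass, fields
from typing import Dict, Optional

@dataclass
class BenchmarkOperatorMetrics:
    latency: Optional[float] = None
    gpu_peak_mem: Optional[float] = None
    # field renamed by this commit
    mem_footprint_compression_ratio: Optional[float] = None
    extra_metrics: Optional[Dict[str, float]] = None

# Every field except extra_metrics becomes a built-in metric name.
BUILTIN_METRICS = {x.name for x in fields(BenchmarkOperatorMetrics)} - {"extra_metrics"}
# -> {"latency", "gpu_peak_mem", "mem_footprint_compression_ratio"}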
@@ -953,29 +953,31 @@ def _init_extra_metrics() -> Dict[str, Any]:
             if not self.tb_args.bypass_fail:
                 raise e
             metrics.latency = None
-        if {"gpu_peak_mem", "gpu_mem_footprint", "cpu_peak_mem"} & set(
-            self.required_metrics
-        ):
+        if {
+            "gpu_peak_mem",
+            "gpu_mem_footprint_compression_ratio",
+            "cpu_peak_mem",
+        } & set(self.required_metrics):
             metrics.cpu_peak_mem, metrics.gpu_peak_mem = self.get_peak_mem(
                 fn,
                 grad_to_none=self.get_grad_to_none(self.example_inputs),
                 required_metrics=self.required_metrics,
                 use_cuda_graphs=self.use_cuda_graphs,
             )
             if (
-                "mem_footprint" in self.required_metrics
+                "mem_footprint_compression_ratio" in self.required_metrics
                 and "gpu_peak_mem" in self.required_metrics
                 and self.baseline_metrics
             ):
                 if (
                     self.baseline_metrics.gpu_peak_mem is not None
                     and metrics.gpu_peak_mem is not None
                 ):
-                    metrics.mem_footprint = (
+                    metrics.mem_footprint_compression_ratio = (
                         self.baseline_metrics.gpu_peak_mem / metrics.gpu_peak_mem
                     )
                 else:
-                    metrics.mem_footprint = None
+                    metrics.mem_footprint_compression_ratio = None
         if "walltime" in self.required_metrics:
             metrics.walltime = do_bench_walltime(
                 fn,
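
The metric itself is a plain ratio: the baseline's peak GPU memory divided by the candidate's, so values above 1.0 mean the candidate uses less memory than the baseline, which is what motivates the clearer name. A standalone sketch of the computation mirroring the diff (compression_ratio is a hypothetical helper name):

from typing import Optional

def compression_ratio(
    baseline_gpu_peak_mem: Optional[float],
    candidate_gpu_peak_mem: Optional[float],
) -> Optional[float]:
    # baseline / candidate, as in the diff: > 1.0 means the candidate's
    # peak GPU memory footprint is smaller than the baseline's.
    if baseline_gpu_peak_mem is None or candidate_gpu_peak_mem is None:
        return None
    return baseline_gpu_peak_mem / candidate_gpu_peak_mem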
@@ -1180,7 +1182,7 @@ def get_peak_mem(
         grad_to_none (Optional[List[torch.Tensor]], optional): List of tensors whose gradients
             should be set to None between iterations. Defaults to None.
         required_metrics (Optional[List[str]], optional): List of metrics to measure.
-            Supported values: ["gpu_peak_mem", "mem_footprint", "cpu_peak_mem"].
+            Supported values: ["gpu_peak_mem", "mem_footprint_compression_ratio", "cpu_peak_mem"].
             Defaults to None.
         use_cuda_graphs (bool, optional): Whether to use CUDA graphs for measurement.
             Defaults to False.
@@ -1206,7 +1208,7 @@ def get_peak_mem(
         fn, n_repeat=2, grad_to_none=grad_to_none, device_type=device_type
     )
     if device_type == "cuda" and (
-        {"gpu_peak_mem", "mem_footprint"} & set(required_metrics)
+        {"gpu_peak_mem", "mem_footprint_compression_ratio"} & set(required_metrics)
     ):
         gpu_peak_mem = torch.cuda.max_memory_allocated() / 10**9
     if "cpu_peak_mem" in required_metrics:
