From 8a8c1fc229b1c6d10b3904b8760298cc97bf6d2b Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Thu, 9 May 2024 14:56:54 -0700 Subject: [PATCH] Report tflops by default for gemm; fix exception handling (#2259) Summary: TFLOPS is the core metric for gemm. Along the way I hit some bugs and weirdness: - You couldn't Ctrl-C out of tritonbench, because the `finally` clause contained a return, which [suppresses the exception](https://docs.python.org/3/tutorial/errors.html#defining-clean-up-actions) - In general I don't think the framework should catch RuntimeErrors, it makes it really hard to debug stuff because the desired result just ends up missing - In fact we had a typo (`metric` instead of `metrics`) in the framework code that was never caught because it was caught and suppressed Pull Request resolved: https://github.com/pytorch/benchmark/pull/2259 Test Plan: ``` python run_benchmark.py triton --op gemm --splitk ``` Reviewed By: xuzhao9 Differential Revision: D57171806 Pulled By: bertmaher fbshipit-source-id: 74568625ad10907d9def8916abfc2f6292cdc6d6 --- torchbenchmark/operators/gemm/operator.py | 10 ++-------- torchbenchmark/util/triton_op.py | 2 -- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/torchbenchmark/operators/gemm/operator.py b/torchbenchmark/operators/gemm/operator.py index bbe0cf6058..21d537e176 100644 --- a/torchbenchmark/operators/gemm/operator.py +++ b/torchbenchmark/operators/gemm/operator.py @@ -89,7 +89,7 @@ class Operator(BenchmarkOperator): - DEFAULT_METRICS = ["latency", "speedup", "accuracy"] + DEFAULT_METRICS = ["latency", "speedup", "accuracy", "tflops"] DEFAULT_PRECISION = "fp16" def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): @@ -202,13 +202,7 @@ def get_input_iter(self) -> Generator: def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool: output = fn() baseline_output = baseline_fn() - accuracy = True - try: - torch.testing.assert_close(output, baseline_output) - except 
Exception: - accuracy = False - finally: - return accuracy + return torch.allclose(output, baseline_output) def plot(self): @triton.testing.perf_report( diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py index c5154bb444..c79f60e40e 100644 --- a/torchbenchmark/util/triton_op.py +++ b/torchbenchmark/util/triton_op.py @@ -752,8 +752,6 @@ def _init_extra_metrics() -> Dict[str, Any]: metrics.extra_metrics[metric_name] = func(fn, self.example_inputs, metrics) except torch.cuda.OutOfMemoryError: metrics.error_msg = "CUDA OOM" - except RuntimeError as e: - metrics.error_msg = str(e) return metrics def get_peak_mem(