Report tflops by default for gemm; fix exception handling (pytorch#2259)
Summary:
TFLOPS is the core metric for gemm.  Along the way I hit some bugs and weirdness:

- You couldn't Ctrl-C out of tritonbench, because the `finally` clause contained a return, which [suppresses the exception](https://docs.python.org/3/tutorial/errors.html#defining-clean-up-actions)
- In general, I don't think the framework should catch RuntimeErrors; it makes bugs really hard to debug, because the desired result just ends up missing
- In fact we had a typo (`metric` instead of `metrics`) in the framework code that was never noticed, because the exception it raised was caught and suppressed
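A minimal repro of the first bullet (this is just the bug pattern, not tritonbench code): a `return` inside a `finally` clause discards any in-flight exception, even a `KeyboardInterrupt`, which is why Ctrl-C appeared to do nothing.

```python
def get_accuracy():
    accuracy = True
    try:
        raise KeyboardInterrupt  # simulate Ctrl-C mid-benchmark
    except Exception:
        # KeyboardInterrupt subclasses BaseException, not Exception,
        # so it is NOT caught here...
        accuracy = False
    finally:
        # ...but the return here suppresses it anyway.
        return accuracy

print(get_accuracy())  # True -- the interrupt never escapes
```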

Pull Request resolved: pytorch#2259

Test Plan:
```
python run_benchmark.py triton --op gemm --splitk
```

Reviewed By: xuzhao9

Differential Revision: D57171806

Pulled By: bertmaher

fbshipit-source-id: 74568625ad10907d9def8916abfc2f6292cdc6d6
bertmaher authored and facebook-github-bot committed May 9, 2024
1 parent c1f2dc8 commit 8a8c1fc
Showing 2 changed files with 2 additions and 10 deletions.
torchbenchmark/operators/gemm/operator.py (2 additions, 8 deletions):

```diff
@@ -89,7 +89,7 @@


 class Operator(BenchmarkOperator):
-    DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
+    DEFAULT_METRICS = ["latency", "speedup", "accuracy", "tflops"]
     DEFAULT_PRECISION = "fp16"

     def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
@@ -202,13 +202,7 @@ def get_input_iter(self) -> Generator:
     def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         output = fn()
         baseline_output = baseline_fn()
-        accuracy = True
-        try:
-            torch.testing.assert_close(output, baseline_output)
-        except Exception:
-            accuracy = False
-        finally:
-            return accuracy
+        return torch.allclose(output, baseline_output)

     def plot(self):
         @triton.testing.perf_report(
```
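The first hunk just adds `"tflops"` to `DEFAULT_METRICS`. For a gemm, tflops is conventionally derived from the 2·M·N·K flop count of an M×K by K×N matmul divided by runtime; a minimal sketch of the arithmetic (the helper name and millisecond latency unit are assumptions, not tritonbench's actual implementation):

```python
def gemm_tflops(m: int, n: int, k: int, latency_ms: float) -> float:
    # An (m x k) @ (k x n) matmul performs m*n*k multiply-adds = 2*m*n*k flops.
    flops = 2.0 * m * n * k
    # Convert ms -> s, then flops/s -> tflops.
    return flops / (latency_ms * 1e-3) / 1e12

print(round(gemm_tflops(4096, 4096, 4096, 1.0), 1))  # 137.4
```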
torchbenchmark/util/triton_op.py (2 deletions):

```diff
@@ -752,8 +752,6 @@ def _init_extra_metrics() -> Dict[str, Any]:
                 metrics.extra_metrics[metric_name] = func(fn, self.example_inputs, metrics)
             except torch.cuda.OutOfMemoryError:
                 metrics.error_msg = "CUDA OOM"
-            except RuntimeError as e:
-                metrics.error_msg = str(e)
         return metrics

     def get_peak_mem(
```
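With the blanket `except RuntimeError` gone, only CUDA OOM is converted into an error message; any other RuntimeError now propagates with a real traceback. The narrowed pattern, sketched with a stand-in exception class (`FakeOOM` and `collect_metric` are hypothetical names, not tritonbench APIs):

```python
class FakeOOM(RuntimeError):
    """Stand-in for torch.cuda.OutOfMemoryError, which subclasses RuntimeError."""

def collect_metric(func):
    # Only the OOM subclass is translated into an error message; other
    # RuntimeErrors propagate so the failure surfaces where it happened.
    try:
        return func(), None
    except FakeOOM:
        return None, "CUDA OOM"

def boom():
    raise FakeOOM("out of memory")

print(collect_metric(lambda: 1.23))  # (1.23, None)
print(collect_metric(boom))          # (None, 'CUDA OOM')
```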
