From 66816daabd3647f256802100eec0ed0790eae409 Mon Sep 17 00:00:00 2001
From: Adam Mainz
Date: Mon, 25 Nov 2024 17:11:32 -0800
Subject: [PATCH] fp8 gemm was using fp32 by default

Summary: why was fp8 defaulting to fp32 before?

Reviewed By: xuzhao9

Differential Revision: D66474486

fbshipit-source-id: 560ae17c93ce225e74b6a91bb6147536535089c1
---
 .../operators/fp8_fused_quant_gemm_rowwise/operator.py | 2 +-
 tritonbench/operators/fp8_gemm/fp8_gemm.py             | 1 +
 tritonbench/operators/fp8_gemm_blockwise/operator.py   | 2 +-
 tritonbench/operators/fp8_gemm_rowwise/operator.py     | 4 ++--
 4 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py b/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py
index 58d210b..0f87a01 100644
--- a/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py
+++ b/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py
@@ -72,7 +72,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py
index 556c56c..d277528 100644
--- a/tritonbench/operators/fp8_gemm/fp8_gemm.py
+++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -43,6 +43,7 @@ def parse_args(args):
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
diff --git a/tritonbench/operators/fp8_gemm_blockwise/operator.py b/tritonbench/operators/fp8_gemm_blockwise/operator.py
index 0bcf507..d1f9bc7 100644
--- a/tritonbench/operators/fp8_gemm_blockwise/operator.py
+++ b/tritonbench/operators/fp8_gemm_blockwise/operator.py
@@ -115,7 +115,7 @@ def fp8_block_quantize(
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
diff --git a/tritonbench/operators/fp8_gemm_rowwise/operator.py b/tritonbench/operators/fp8_gemm_rowwise/operator.py
index f65383e..cae6fd3 100644
--- a/tritonbench/operators/fp8_gemm_rowwise/operator.py
+++ b/tritonbench/operators/fp8_gemm_rowwise/operator.py
@@ -109,7 +109,7 @@ def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
@@ -118,7 +118,7 @@ def __init__(
         self.use_cuda_graphs = True
         addmm_args = parse_args(self.extra_args)
         if hasattr(tb_args, "production_shapes") and tb_args.production_shapes:
-            self.shapes = get_production_shapes(self.name, "fp8_gemm")
+            self.shapes = get_production_shapes(self.name, "fp32_gemm")
         elif addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
         elif addmm_args.llama:
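
For readers outside the codebase, here is a minimal sketch of how a class-level `DEFAULT_PRECISION` can act as the fallback when no precision is supplied on the command line. The `BenchmarkOperator` body, the `precision` attribute on `tb_args`, and the `FP8GemmOperator` name below are assumptions made for illustration, not tritonbench's actual implementation.

```python
# Hypothetical sketch: an operator-level DEFAULT_PRECISION used as the
# fallback precision when the CLI does not specify one. Names and the
# resolution logic are assumptions, not the real tritonbench code.
import argparse
from typing import List, Optional


class BenchmarkOperator:
    # Base-class fallback; subclasses override it. Before the patch above,
    # the fp8 GEMM operators effectively declared (or inherited) "fp32".
    DEFAULT_PRECISION = "fp32"

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        # Use the CLI value when given; otherwise fall back to the class default.
        self.precision = getattr(tb_args, "precision", None) or self.DEFAULT_PRECISION


class FP8GemmOperator(BenchmarkOperator):
    DEFAULT_METRICS = ["tflops", "gbps", "latency"]
    DEFAULT_PRECISION = "fp8"  # mirrors the change in fp8_gemm/fp8_gemm.py


if __name__ == "__main__":
    args = argparse.Namespace(precision=None)  # no precision passed on the CLI
    op = FP8GemmOperator(args)
    print(op.precision)  # "fp8" with the override; "fp32" without it
```

With a fallback of this shape, leaving `DEFAULT_PRECISION` at the base-class "fp32" value would quietly run the fp8 GEMM benchmarks with fp32 inputs whenever the user did not pass a precision explicitly, which is the behavior the patch corrects.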