fp8 gemm was using fp32 by default
Summary: why was fp8 defaulting to fp32 before?

Reviewed By: xuzhao9

Differential Revision: D66474486

fbshipit-source-id: 560ae17c93ce225e74b6a91bb6147536535089c1
Authored by adamomainz; committed by facebook-github-bot on Nov 26, 2024.
1 parent ebdb921 · commit 66816da
Showing 4 changed files with 5 additions and 4 deletions.
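The attribute changed throughout this commit, DEFAULT_PRECISION, is the class-level default an operator falls back to when no precision is supplied on the command line. A minimal sketch of that fallback, assuming a tb_args.precision attribute and a dtype lookup table (both hypothetical stand-ins for illustration, not the actual tritonbench implementation):

import argparse

# Hypothetical mapping from precision names to dtype names.
DTYPE_BY_PRECISION = {"fp8": "float8_e4m3fn", "fp16": "float16", "fp32": "float32"}

class BenchmarkOperator:
    DEFAULT_PRECISION = "fp32"  # framework-wide fallback

    def __init__(self, tb_args: argparse.Namespace):
        # An explicit precision on the command line wins; otherwise the
        # operator's class-level default applies, which is what this commit
        # flips from fp32 to fp8 for the fp8 gemm operators.
        precision = getattr(tb_args, "precision", None) or self.DEFAULT_PRECISION
        self.dtype_name = DTYPE_BY_PRECISION[precision]

class FP8GemmOperator(BenchmarkOperator):
    DEFAULT_PRECISION = "fp8"  # the new default introduced by this commit

op = FP8GemmOperator(argparse.Namespace(precision=None))
print(op.dtype_name)  # -> float8_e4m3fn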
@@ -72,7 +72,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
1 change: 1 addition & 0 deletions tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -43,6 +43,7 @@ def parse_args(args):
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
2 changes: 1 addition & 1 deletion tritonbench/operators/fp8_gemm_blockwise/operator.py
@@ -115,7 +115,7 @@ def fp8_block_quantize(
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
4 changes: 2 additions & 2 deletions tritonbench/operators/fp8_gemm_rowwise/operator.py
@@ -109,7 +109,7 @@ def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
@@ -118,7 +118,7 @@ def __init__(
         self.use_cuda_graphs = True
         addmm_args = parse_args(self.extra_args)
         if hasattr(tb_args, "production_shapes") and tb_args.production_shapes:
-            self.shapes = get_production_shapes(self.name, "fp8_gemm")
+            self.shapes = get_production_shapes(self.name, "fp32_gemm")
         elif addmm_args.m and addmm_args.n and addmm_args.k:
            self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
         elif addmm_args.llama:
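With these defaults in place, invoking any of the fp8 gemm operators without an explicit precision override should exercise fp8 kernels directly. A hypothetical invocation, assuming tritonbench's usual run.py entry point and --op flag (neither is shown in this commit):

python run.py --op fp8_gemm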
