production data used in fp32_gemm, bf16_gemm and softmax
Summary:
Adds support for using more production data in TritonBench. The last operator left for the first cut of metric changes is HSTU.

Weights are working here as well.

Reviewed By: xuzhao9

Differential Revision: D65779069

fbshipit-source-id: a81237e39e407c47b3304e0bc9c4a00aebefb73c
adamomainz authored and facebook-github-bot committed Nov 12, 2024
1 parent 3c83e0b commit eba50ab
Showing 4 changed files with 36 additions and 6 deletions.
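
For orientation, the shape-selection precedence these diffs introduce looks roughly like the sketch below. This is not literal code from the commit: select_shapes is a hypothetical helper, and of the three operators only gemm and softmax additionally gate on IS_FBCODE.

# Sketch of the precedence added in this commit (hypothetical helper):
# production shapes win when requested; otherwise each operator keeps
# its existing sources (CLI m/n/k, CSV input, llama presets, ...).
def select_shapes(tb_args, name, op_type, fallback_shapes):
    if tb_args.production_shapes:  # gemm and softmax also require IS_FBCODE
        return get_production_shapes(name, op_type)
    return fallback_shapes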
5 changes: 4 additions & 1 deletion tritonbench/operators/fp8_gemm_rowwise/operator.py
@@ -6,6 +6,7 @@
 
 import torch
 import triton
+from tritonbench.utils.data_utils import get_production_shapes
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
@@ -113,7 +114,9 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.use_cuda_graphs = True
         addmm_args = parse_args(self.extra_args)
-        if addmm_args.m and addmm_args.n and addmm_args.k:
+        if tb_args.production_shapes:
+            self.shapes = get_production_shapes(self.name, "fp8_gemm")
+        elif addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
         elif addmm_args.llama:
             self.shapes = gemm_shapes()
13 changes: 11 additions & 2 deletions tritonbench/operators/gemm/operator.py
@@ -8,6 +8,7 @@
 import torch
 import torch._inductor.config as inductor_config
 import triton
+from tritonbench.utils.data_utils import get_production_shapes
 
 from tritonbench.utils.path_utils import REPO_PATH
 
@@ -129,7 +130,9 @@ def __init__(
         self.use_cuda_graphs = False
         gemm_args = parse_args(self.extra_args)
         self.layout = gemm_args.layout
-        if gemm_args.input:
+        if IS_FBCODE and tb_args.production_shapes:
+            self.shapes = get_production_shapes(self.name, f"{tb_args.precision}_gemm")
+        elif gemm_args.input:
             self.shapes = read_shapes_from_csv(gemm_args.input)
         elif gemm_args.splitk:
             self.shapes = SPLIT_K_SHAPES
@@ -286,7 +289,13 @@ def _scaled_randn(*args, scale: float, **kwargs) -> torch.Tensor:
 
     def get_input_iter(self) -> Generator:
         for shape in self.shapes:
-            m, n, k, bias = shape
+            if len(shape) == 4:
+                m, n, k, bias = shape
+            elif len(shape) == 3:
+                m, n, k = shape
+                bias = None
+            else:
+                raise ValueError(f"Invalid shape {shape}")
             a = self._scaled_randn(
                 (m, k), scale=k, device=self.device, dtype=self.dtype
             )
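
The new unpacking in get_input_iter accepts both bias-carrying 4-tuples (e.g. from CSV input) and the bare (m, n, k) triples that production data supplies. The same logic as a standalone sketch, for reference:

# Sketch of the tolerant unpacking above: 4-tuples carry a bias column,
# 3-tuples (e.g. production shapes) default bias to None.
def unpack_shape(shape):
    if len(shape) == 4:
        m, n, k, bias = shape
    elif len(shape) == 3:
        m, n, k = shape
        bias = None
    else:
        raise ValueError(f"Invalid shape {shape}")
    return m, n, k, bias

assert unpack_shape((8, 16, 32)) == (8, 16, 32, None)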
10 changes: 7 additions & 3 deletions tritonbench/operators/softmax/operator.py
@@ -3,10 +3,12 @@
 import torch
 import triton
 import triton.language as tl
+from tritonbench.utils.data_utils import get_production_shapes
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
+    IS_FBCODE,
     register_benchmark,
     register_metric,
 )
@@ -101,13 +103,15 @@ def _inner():
 
     def get_input_iter(self):
         M = 4096
-        for i in range(2, 100):
-            N = 128 * i
+        shapes = ((M, 128 * i) for i in range(2, 100))
+        if IS_FBCODE and self.tb_args.production_shapes:
+            shapes = get_production_shapes(self.name, "softmax")
+        for M, N in shapes:
             yield (torch.randn([M, N], dtype=self.dtype, device=self.device),)
 
     def get_x_val(self, example_inputs) -> int:
         shape = example_inputs[0].size()
-        return shape[1]
+        return [shape[0], shape[1]]
 
     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics) -> float:
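
Note that the default path is now a generator of (M, N) pairs rather than a bare loop over N, which is what lets the production branch swap in a different shape source. A quick sanity check of the generator form (the spelling tuple(M, 128 * i) would raise TypeError once iterated, since tuple() takes a single iterable; the fix above uses a plain tuple literal):

# Sanity check for the generator of (M, N) pairs used above.
M = 4096
shapes = ((M, 128 * i) for i in range(2, 100))
assert next(shapes) == (4096, 256)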
14 changes: 14 additions & 0 deletions tritonbench/utils/data_utils.py
@@ -0,0 +1,14 @@
+from .triton_op import IS_FBCODE
+
+
+def get_production_shapes(op_name, op_type):
+    """Gets a list of production shapes for the given op for benchmarking."""
+    if IS_FBCODE:
+        from .fb.durin_data import productionDataLoader
+
+        return [
+            shape
+            for shape in productionDataLoader.get_shapes_from_frozen_durin(
+                op_name, op_type
+            )
+        ]
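
Since the helper only returns inside fbcode and otherwise falls through to None, out-of-fbcode callers need a fallback. A minimal usage sketch; the "softmax" op_name and the fallback shape list here are illustrative, not from the commit:

from tritonbench.utils.data_utils import get_production_shapes

shapes = get_production_shapes("softmax", "softmax")
if shapes is None:  # not in fbcode: no production data available
    shapes = [(4096, 128 * i) for i in range(2, 100)]  # illustrative fallback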
