Add hstu kernel
xuzhao9 committed Nov 19, 2024
1 parent da326da commit 796538e
Showing 2 changed files with 5 additions and 8 deletions.
1 change: 1 addition & 0 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -29,6 +29,7 @@ gemm:
- triton_persistent_matmul
- triton_tma_persistent_matmul
- triton_tma_persistent_cached_matmul
- hstu_triton_matmul
jagged_layer_norm:
jagged_mean:
jagged_softmax:
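(Not part of the diff: for context, each top-level key in this yaml names an operator and lists the benchmark impls to skip on the H100 CI runner, so this change excludes the new hstu_triton_matmul benchmark there. The loader is not shown in this commit; the short sketch below, assuming PyYAML and a hypothetical load_skips helper, only illustrates how such a file could be consumed.)

import yaml  # PyYAML, assumed available

def load_skips(path: str) -> dict:
    # Each top-level key is an operator (e.g. "gemm"); its value is a list of
    # benchmark impl names to skip, or empty to skip the whole operator.
    with open(path) as f:
        data = yaml.safe_load(f) or {}
    return {op: (impls or []) for op, impls in data.items()}

skips = load_skips("test/test_gpu/skip_tests_h100_pytorch.yaml")
# After this commit, "hstu_triton_matmul" appears in skips["gemm"].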
12 changes: 4 additions & 8 deletions tritonbench/operators/gemm/operator.py
@@ -1,10 +1,7 @@
import argparse
import csv
import os
import statistics
from typing import Any, Callable, Generator, List, Optional, Tuple

import numpy
import torch
import torch._inductor.config as inductor_config
import triton
@@ -41,9 +38,8 @@
matmul_kernel as triton_tutorial_matmul_kernel,
)

if inductor_config.is_fbcode():
from hammer.ops.triton.triton_matmul import triton_matmul as hstu_triton_matmul

if IS_FBCODE:
from hammer.ops.triton.triton_matmul import triton_matmul as hstu_triton_matmul_kernel
HAS_HAMMER = True
else:
HAS_HAMMER = False
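(Not part of the diff: the commit gates the import on a module-level IS_FBCODE flag, defined elsewhere in operator.py and not visible in this hunk, and renames the imported kernel to hstu_triton_matmul_kernel so it no longer shares a name with the benchmark method registered further down. An equivalent guard, shown only as an illustrative alternative and not what this commit does, probes for the fbcode-only package directly.)

try:
    # hammer only exists in fbcode builds; outside them the import fails.
    from hammer.ops.triton.triton_matmul import (
        triton_matmul as hstu_triton_matmul_kernel,
    )
    HAS_HAMMER = True
except ImportError:
    # Open-source builds: leave the benchmark disabled via
    # register_benchmark(enabled=HAS_HAMMER).
    HAS_HAMMER = False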
@@ -204,9 +200,9 @@ def aten_matmul(self, a, b, bias) -> Callable:
@register_benchmark(enabled=HAS_HAMMER)
def hstu_triton_matmul(self, a, b, bias) -> Callable:
if not bias == None:
return lambda: hstu_triton_matmul(a, b) + bias
return lambda: hstu_triton_matmul_kernel(a, b) + bias
else:
return lambda: hstu_triton_matmul(a, b)
return lambda: hstu_triton_matmul_kernel(a, b)

@register_benchmark(enabled=bool(colfax_gemm))
def colfax_cutlass_matmul(self, a, b, bias) -> Callable:
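(Not part of the diff: the last hunk only repoints the benchmark at the renamed hstu_triton_matmul_kernel alias. A condensed, self-contained sketch of the registration pattern follows; the register_benchmark stub, the torch.matmul stand-in for the hammer kernel, and the free-function form without self are simplifying assumptions rather than the tritonbench API, and `bias is not None` is the idiomatic spelling of the `not bias == None` check kept as context above.)

from typing import Callable, Optional

import torch

hstu_triton_matmul_kernel = torch.matmul  # stand-in for the hammer kernel
HAS_HAMMER = True  # pretend the kernel is available so the demo runs

def register_benchmark(enabled: bool = True):
    # Toy decorator: tag the function so a harness could skip disabled benchmarks.
    def wrap(fn):
        fn._benchmark_enabled = enabled
        return fn
    return wrap

@register_benchmark(enabled=HAS_HAMMER)
def hstu_triton_matmul(a: torch.Tensor, b: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> Callable:
    # Return a zero-argument callable so the harness can time repeated runs.
    if bias is not None:
        return lambda: hstu_triton_matmul_kernel(a, b) + bias
    return lambda: hstu_triton_matmul_kernel(a, b)

a, b = torch.randn(16, 32), torch.randn(32, 8)
print(hstu_triton_matmul(a, b, None)().shape)  # torch.Size([16, 8])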
