Enable gemm and more operators in the CI #56

Closed
wants to merge 11 commits
48 changes: 30 additions & 18 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -1,33 +1,45 @@
# Tests we skip in OSS CI
# This file is regarding to the Triton version bundled with pytorch
# Use <op-name:> to skip an entire operator
# Use <op-name:\n - impl-name> to skip an impl
# Tests we skip in triton-pytorch + OSS CI
# triton-pytorch is the triton version bundled with pytorch nightly
# We need to skip kernels that only work on triton-main
# Usage:
# op-name: to skip an entire operator
# op-name:\n\t- impl-name to skip an impl
bf16xint16_gemm:
# LLVM ERROR: mma16816 data type not supported
- bf16xint16
# TODO: we have many buggy backends for flash_attention
# Need to fix them in the CI
flash_attention:
# - triton_tutorial_flash_v2_tma
# - triton_op_flash_v2
# - xformers_splitk
# - colfax_cutlass
# - tk
# - sdpa
# - cudnn
# - flex_attention
# FIXME: enable colfax_cutlass and tk
- xformers
- xformers_splitk
- colfax_cutlass
- tk
# triton_tutorial_* kernels require triton-main
- triton_tutorial_flash_v2
- triton_tutorial_flash_v2_opt
- triton_tutorial_flash_v2_tma
- triton_tutorial_flash_v2_ws
- triton_tutorial_flash_v2_tma_ws
fp8_attention:
- colfax_fmha
# triton_flash_v2 now requires the main branch of Triton
# pytorch version does not work
# triton_flash_v2 requires triton-main
- triton_flash_v2
# fp8_fused_quant_gemm_rowwise requires fb-only kernels
fp8_fused_quant_gemm_rowwise:
fp8_gemm:
# triton_*_persistent requires triton-main
- triton_persistent_fp8_gemm
- triton_tma_persistent_fp8_gemm
fp8_gemm_rowwise:
gemm:
grouped_gemm:
int4_gemm:
# triton_*_persistent_* requires triton-main
- triton_persistent_matmul
- triton_tma_persistent_matmul
- triton_tma_persistent_cached_matmul
- hstu_triton_matmul
- colfax_cutlass_matmul
# FIXME: PT2 CUTLASS backend failed
- pt2_cutlass_matmul
# jagged tests are slow, so disable them in OSS
jagged_layer_norm:
jagged_mean:
jagged_softmax:
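For reference, the skip file above maps an operator name to either nothing (skip the entire operator) or a list of impl names (skip only those impls), as its header comments describe. Below is a minimal sketch of how such a file could be consumed, assuming PyYAML is available; `load_skips` and `is_skipped` are hypothetical helpers for illustration, not the repository's actual loader.

```python
import yaml  # assumes PyYAML is available


def load_skips(path):
    # op -> None means "skip the entire operator";
    # op -> {impl, ...} means "skip only these impls".
    with open(path) as f:
        raw = yaml.safe_load(f) or {}
    return {op: (None if impls is None else set(impls)) for op, impls in raw.items()}


def is_skipped(skips, op, impl):
    if op not in skips:
        return False
    impls = skips[op]
    return impls is None or impl in impls
```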
2 changes: 1 addition & 1 deletion tritonbench/operators/flash_attention/operator.py
@@ -43,7 +43,7 @@
from tritonbench.utils.triton_op import IS_FBCODE

try:
    with add_path(SUBMODULE_PATH.joinpath("kernels")):
    with add_path(str(SUBMODULE_PATH.joinpath("kernels"))):
        from kernels.flash_attention import attention as triton_op_FA2
    HAS_KERNELS = True
except BaseException:
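The one-line fix above passes a plain string to `add_path` rather than a `pathlib.Path`. As a rough sketch of why that matters, here is an illustrative `add_path`-style helper (assumed for illustration, not taken from the repository): `sys.path` entries are conventionally plain strings, so the path is converted with `str()` before being added.

```python
import sys
from contextlib import contextmanager


@contextmanager
def add_path(path: str):
    # Temporarily prepend `path` to sys.path so submodule kernels can be imported.
    sys.path.insert(0, path)
    try:
        yield
    finally:
        # Remove the entry again, even if the import inside the block failed.
        try:
            sys.path.remove(path)
        except ValueError:
            pass
```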
35 changes: 21 additions & 14 deletions tritonbench/operators/gemm/operator.py
@@ -1,10 +1,8 @@
import argparse
import csv
import os
import statistics
from typing import Any, Callable, Generator, List, Optional, Tuple

import numpy
import torch
import torch._inductor.config as inductor_config
import triton
@@ -24,18 +22,27 @@

from .kernels import matmul as kernels
from .partition_k import matmul_partition_k
from .persistent_matmul import (
    matmul_persistent,
    matmul_tma_persistent,
    matmul_tma_persistent_cached,
)

try:
    from .persistent_matmul import (
        matmul_persistent,
        matmul_tma_persistent,
        matmul_tma_persistent_cached,
    )

    HAS_PRESISTENT = True
except ModuleNotFoundError:
    HAS_PRESISTENT = False

from .triton_matmul import (
    matmul as triton_tutorial_matmul,
    matmul_kernel as triton_tutorial_matmul_kernel,
)

if inductor_config.is_fbcode():
    from hammer.ops.triton.triton_matmul import triton_matmul as hstu_triton_matmul
if IS_FBCODE:
    from hammer.ops.triton.triton_matmul import (
        triton_matmul as hstu_triton_matmul_kernel,
    )

    HAS_HAMMER = True
else:
@@ -158,22 +165,22 @@ def matmul_partition_k(self, a, b, bias) -> Callable:
        else:
            return lambda: matmul_partition_k(a, bt)

    @register_benchmark()
    @register_benchmark(enabled=HAS_PRESISTENT)
    def triton_persistent_matmul(self, a, b, bias) -> Callable:
        if not bias == None:
            return lambda: matmul_persistent(a, b) + bias
        else:
            return lambda: matmul_persistent(a, b)

    @register_benchmark(enabled=not IS_FBCODE)
    @register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
    def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
        b = b.T.contiguous()
        if not bias == None:
            return lambda: matmul_tma_persistent(a, b) + bias
        else:
            return lambda: matmul_tma_persistent(a, b)

    @register_benchmark(enabled=not IS_FBCODE)
    @register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
    def triton_tma_persistent_cached_matmul(self, a, b, bias) -> Callable:
        b = b.T.contiguous()
        if not bias == None:
@@ -197,9 +204,9 @@ def aten_matmul(self, a, b, bias) -> Callable:
    @register_benchmark(enabled=HAS_HAMMER)
    def hstu_triton_matmul(self, a, b, bias) -> Callable:
        if not bias == None:
            return lambda: hstu_triton_matmul(a, b) + bias
            return lambda: hstu_triton_matmul_kernel(a, b) + bias
        else:
            return lambda: hstu_triton_matmul(a, b)
            return lambda: hstu_triton_matmul_kernel(a, b)

    @register_benchmark(enabled=bool(colfax_gemm))
    def colfax_cutlass_matmul(self, a, b, bias) -> Callable:
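The gemm changes follow a common optional-dependency pattern: attempt the import, record the result in a module-level flag, and gate benchmark registration on that flag so the operator still loads when the persistent kernels are unavailable (for example, on the Triton version bundled with PyTorch). The hammer import is also aliased to `hstu_triton_matmul_kernel` so it no longer shares a name with the `hstu_triton_matmul` benchmark method. A minimal sketch of the pattern follows; `optional_kernel`, `fancy_matmul`, and this simplified `register_benchmark` are hypothetical stand-ins, not the repository's real APIs.

```python
from typing import Callable

try:
    from optional_kernel import fancy_matmul  # may only exist with triton-main
    HAS_FANCY = True
except ModuleNotFoundError:
    HAS_FANCY = False


def register_benchmark(enabled: bool = True):
    # Simplified stand-in: tag the benchmark so a runner can drop disabled ones.
    def wrap(fn):
        fn._enabled = enabled
        return fn
    return wrap


class GemmOperator:
    @register_benchmark(enabled=HAS_FANCY)
    def fancy_matmul_bench(self, a, b) -> Callable:
        # Registered but disabled when the optional kernel failed to import.
        return lambda: fancy_matmul(a, b)
```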