Enable gemm and more operators in the CI (#56)
Summary:
Now that test isolation has been implemented in #55, we can enable more operators in the CI.

Pull Request resolved: #56

Reviewed By: FindHao

Differential Revision: D66189246

Pulled By: xuzhao9

fbshipit-source-id: 22f01b2e5b64956f6e2985f87be785efc977e46b
xuzhao9 committed Nov 20, 2024
1 parent de4c76b commit 69b49e0
Showing 3 changed files with 52 additions and 33 deletions.
48 changes: 30 additions & 18 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -1,33 +1,45 @@
# Tests we skip in OSS CI
# This file is regarding to the Triton version bundled with pytorch
# Use <op-name:> to skip an entire operator
# Use <op-name:\n - impl-name> to skip an impl
# Tests we skip in triton-pytorch + OSS CI
# triton-pytorch is the triton version bundled with pytorch nightly
# We need to skip kernels that only work on triton-main
# Usage:
# op-name: to skip an entire operator
# op-name:\n\t- impl-name to skip an impl
bf16xint16_gemm:
# LLVM ERROR: mma16816 data type not supported
- bf16xint16
# TODO: we have many buggy backends for flash_attention
# Need to fix them in the CI
flash_attention:
# - triton_tutorial_flash_v2_tma
# - triton_op_flash_v2
# - xformers_splitk
# - colfax_cutlass
# - tk
# - sdpa
# - cudnn
# - flex_attention
# FIXME: enable colfax_cutlass and tk
- xformers
- xformers_splitk
- colfax_cutlass
- tk
# triton_tutorial_* kernels require triton-main
- triton_tutorial_flash_v2
- triton_tutorial_flash_v2_opt
- triton_tutorial_flash_v2_tma
- triton_tutorial_flash_v2_ws
- triton_tutorial_flash_v2_tma_ws
fp8_attention:
- colfax_fmha
# triton_flash_v2 now requires the main branch of Triton
# pytorch version does not work
# triton_flash_v2 requires triton-main
- triton_flash_v2
# fp8_fused_quant_gemm_rowwise requires fb-only kernels
fp8_fused_quant_gemm_rowwise:
fp8_gemm:
# triton_*_persistent requires triton-main
- triton_persistent_fp8_gemm
- triton_tma_persistent_fp8_gemm
fp8_gemm_rowwise:
gemm:
grouped_gemm:
int4_gemm:
# triton_*_persistent_* requires triton-main
- triton_persistent_matmul
- triton_tma_persistent_matmul
- triton_tma_persistent_cached_matmul
- hstu_triton_matmul
- colfax_cutlass_matmul
# FIXME: PT2 CUTLASS backend failed
- pt2_cutlass_matmul
# jagged tests are slow, so disable them in OSS
jagged_layer_norm:
jagged_mean:
jagged_softmax:
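
For readers new to this skip file, its header comments describe the convention: a bare `op-name:` skips an entire operator, while nested `- impl-name` entries skip individual implementations. Below is a minimal sketch of how a CI harness could consume such a file; it assumes PyYAML, uses hypothetical helper names, and is not the actual tritonbench test loader.

```python
# Hedged sketch, not the real tritonbench loader: read the skip YAML and decide
# whether a given (operator, implementation) pair should run. Assumes PyYAML.
import yaml


def load_skips(path: str) -> dict:
    """Return {op_name: set_of_skipped_impls}; an empty set means the whole op is skipped."""
    with open(path) as f:
        raw = yaml.safe_load(f) or {}
    # A bare "op-name:" parses to None (skip the entire operator);
    # "op-name:" followed by "- impl-name" entries parses to a list of impls.
    return {op: set(impls or []) for op, impls in raw.items()}


def should_skip(skips: dict, op: str, impl: str) -> bool:
    if op not in skips:
        return False
    impls = skips[op]
    return not impls or impl in impls  # empty set => skip every impl of this op


skips = load_skips("test/test_gpu/skip_tests_h100_pytorch.yaml")
# Whole-operator skip: the impl name here is arbitrary because the bare key skips everything.
print(should_skip(skips, "fp8_fused_quant_gemm_rowwise", "any_impl"))  # True
# Impl-level skip: only the listed implementation is excluded.
print(should_skip(skips, "bf16xint16_gemm", "bf16xint16"))  # True
```
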
2 changes: 1 addition & 1 deletion tritonbench/operators/flash_attention/operator.py
@@ -43,7 +43,7 @@
from tritonbench.utils.triton_op import IS_FBCODE

try:
with add_path(SUBMODULE_PATH.joinpath("kernels")):
with add_path(str(SUBMODULE_PATH.joinpath("kernels"))):
from kernels.flash_attention import attention as triton_op_FA2
HAS_KERNELS = True
except BaseException:
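
The one-line flash_attention change above converts the pathlib result to str before handing it to add_path. Per CPython's documented behavior, non-string entries on sys.path are ignored by the import system, so a bare Path object would silently break the `kernels.flash_attention` import. A hedged sketch of the pattern follows; the real add_path helper in tritonbench may be implemented differently, and the submodule location below is a placeholder.

```python
# Hedged sketch: why str() around a pathlib.Path matters for sys.path-based
# imports. The real add_path helper in tritonbench may differ from this one.
import sys
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def add_path(path: str):
    """Temporarily prepend a directory to sys.path; entries should be strings."""
    sys.path.insert(0, path)
    try:
        yield
    finally:
        sys.path.remove(path)


SUBMODULE_PATH = Path("submodules/flash-attention")  # hypothetical location

# sys.path items that are not strings are ignored during import, so the Path
# object must be converted explicitly for the nested import to resolve:
with add_path(str(SUBMODULE_PATH.joinpath("kernels"))):
    pass  # `from kernels.flash_attention import attention` would go here
```
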
35 changes: 21 additions & 14 deletions tritonbench/operators/gemm/operator.py
@@ -1,10 +1,8 @@
import argparse
import csv
import os
import statistics
from typing import Any, Callable, Generator, List, Optional, Tuple

import numpy
import torch
import torch._inductor.config as inductor_config
import triton
@@ -24,18 +22,27 @@

from .kernels import matmul as kernels
from .partition_k import matmul_partition_k
from .persistent_matmul import (
matmul_persistent,
matmul_tma_persistent,
matmul_tma_persistent_cached,
)

try:
from .persistent_matmul import (
matmul_persistent,
matmul_tma_persistent,
matmul_tma_persistent_cached,
)

HAS_PRESISTENT = True
except ModuleNotFoundError:
HAS_PRESISTENT = False

from .triton_matmul import (
matmul as triton_tutorial_matmul,
matmul_kernel as triton_tutorial_matmul_kernel,
)

if inductor_config.is_fbcode():
from hammer.ops.triton.triton_matmul import triton_matmul as hstu_triton_matmul
if IS_FBCODE:
from hammer.ops.triton.triton_matmul import (
triton_matmul as hstu_triton_matmul_kernel,
)

HAS_HAMMER = True
else:
@@ -158,22 +165,22 @@ def matmul_partition_k(self, a, b, bias) -> Callable:
else:
return lambda: matmul_partition_k(a, bt)

@register_benchmark()
@register_benchmark(enabled=HAS_PRESISTENT)
def triton_persistent_matmul(self, a, b, bias) -> Callable:
if not bias == None:
return lambda: matmul_persistent(a, b) + bias
else:
return lambda: matmul_persistent(a, b)

@register_benchmark(enabled=not IS_FBCODE)
@register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
b = b.T.contiguous()
if not bias == None:
return lambda: matmul_tma_persistent(a, b) + bias
else:
return lambda: matmul_tma_persistent(a, b)

@register_benchmark(enabled=not IS_FBCODE)
@register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
def triton_tma_persistent_cached_matmul(self, a, b, bias) -> Callable:
b = b.T.contiguous()
if not bias == None:
@@ -216,9 +223,9 @@ def op():
@register_benchmark(enabled=HAS_HAMMER)
def hstu_triton_matmul(self, a, b, bias) -> Callable:
if not bias == None:
return lambda: hstu_triton_matmul(a, b) + bias
return lambda: hstu_triton_matmul_kernel(a, b) + bias
else:
return lambda: hstu_triton_matmul(a, b)
return lambda: hstu_triton_matmul_kernel(a, b)

@register_benchmark(enabled=bool(colfax_gemm))
def colfax_cutlass_matmul(self, a, b, bias) -> Callable:
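
The gemm changes follow the same enable-if-available pattern: the persistent-matmul kernels are imported inside a try/except ModuleNotFoundError, and the benchmarks that need them are registered with enabled=HAS_PRESISTENT, so the operator module still imports cleanly on the Triton version bundled with pytorch. Below is a standalone sketch of that gating pattern; the simplified register_benchmark here is only an illustration, not the real tritonbench decorator.

```python
# Illustration only: a simplified stand-in for tritonbench's register_benchmark
# showing how enabled= gates benchmarks whose optional kernels failed to import.
from typing import Callable, Dict

_BENCHMARKS: Dict[str, Callable] = {}


def register_benchmark(enabled: bool = True):
    def decorator(fn: Callable) -> Callable:
        if enabled:
            _BENCHMARKS[fn.__name__] = fn
        return fn  # the function still exists; it just is not registered to run
    return decorator


try:
    # Stand-in for ".persistent_matmul", which may fail to import on the
    # Triton version bundled with pytorch.
    from persistent_matmul import matmul_persistent  # type: ignore
    HAS_PRESISTENT = True  # flag name spelled as in the commit
except ModuleNotFoundError:
    HAS_PRESISTENT = False


@register_benchmark(enabled=HAS_PRESISTENT)
def triton_persistent_matmul(a, b):
    # Only reachable when the optional kernel imported successfully.
    return matmul_persistent(a, b)
```

Keeping the function defined but unregistered lets the rest of the gemm benchmarks run even when the optional kernels are unavailable, which appears to be what allows gemm to be enabled in the OSS CI.
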
