From aa67b621dfecef1d0b2fe5fc5207e3af17fa503d Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 14:57:18 -0800
Subject: [PATCH] Enable gemm and more operators in the CI (#56)

Summary:
Now that test isolation has been implemented in
https://github.com/pytorch-labs/tritonbench/pull/55, we can enable more
operators in the CI.

Pull Request resolved: https://github.com/pytorch-labs/tritonbench/pull/56

Reviewed By: FindHao

Differential Revision: D66189246

Pulled By: xuzhao9

fbshipit-source-id: 22f01b2e5b64956f6e2985f87be785efc977e46b
---
 test/test_gpu/skip_tests_h100_pytorch.yaml   | 48 ++++++++++++-------
 .../operators/flash_attention/operator.py    |  2 +-
 tritonbench/operators/gemm/operator.py       | 35 ++++++++------
 3 files changed, 52 insertions(+), 33 deletions(-)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index ce263437..d137d3d9 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -1,33 +1,45 @@
-# Tests we skip in OSS CI
-# This file is regarding to the Triton version bundled with pytorch
-# Use <op-name> to skip an entire operator
-# Use <impl-name> to skip an impl
+# Tests we skip in triton-pytorch + OSS CI
+# triton-pytorch is the triton version bundled with pytorch nightly
+# We need to skip kernels that only work on triton-main
+# Usage:
+#   op-name: to skip an entire operator
+#   op-name:\n\t- impl-name to skip an impl
 bf16xint16_gemm:
+  # LLVM ERROR: mma16816 data type not supported
   - bf16xint16
-# TODO: we have many buggy backends for flash_attention
-# Need to fix them in the CI
 flash_attention:
-#  - triton_tutorial_flash_v2_tma
-#  - triton_op_flash_v2
-#  - xformers_splitk
-#  - colfax_cutlass
-#  - tk
-#  - sdpa
-#  - cudnn
-#  - flex_attention
+  # FIXME: enable colfax_cutlass and tk
+  - xformers
+  - xformers_splitk
+  - colfax_cutlass
+  - tk
+  # triton_tutorial_* kernels require triton-main
+  - triton_tutorial_flash_v2
+  - triton_tutorial_flash_v2_opt
+  - triton_tutorial_flash_v2_tma
+  - triton_tutorial_flash_v2_ws
+  - triton_tutorial_flash_v2_tma_ws
 fp8_attention:
   - colfax_fmha
-  # triton_flash_v2 now requires the main branch of Triton
-  # pytorch version does not work
+  # triton_flash_v2 requires triton-main
  - triton_flash_v2
+# fp8_fused_quant_gemm_rowwise requires fb-only kernels
 fp8_fused_quant_gemm_rowwise:
 fp8_gemm:
+  # triton_*_persistent requires triton-main
   - triton_persistent_fp8_gemm
   - triton_tma_persistent_fp8_gemm
 fp8_gemm_rowwise:
 gemm:
-grouped_gemm:
-int4_gemm:
+  # triton_*_persistent_* requires triton-main
+  - triton_persistent_matmul
+  - triton_tma_persistent_matmul
+  - triton_tma_persistent_cached_matmul
+  - hstu_triton_matmul
+  - colfax_cutlass_matmul
+  # FIXME: PT2 CUTLASS backend failed
+  - pt2_cutlass_matmul
+# jagged tests are slow, so disable them in OSS
 jagged_layer_norm:
 jagged_mean:
 jagged_softmax:
diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index de21bfd1..29a47759 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -43,7 +43,7 @@ from tritonbench.utils.triton_op import IS_FBCODE

 try:
-    with add_path(SUBMODULE_PATH.joinpath("kernels")):
+    with add_path(str(SUBMODULE_PATH.joinpath("kernels"))):
         from kernels.flash_attention import attention as triton_op_FA2
     HAS_KERNELS = True
 except BaseException:
diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index e6b8fe94..d12a2d73 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -1,10 +1,8 @@
 import argparse
 import csv
 import os
-import statistics
 from typing import Any, Callable, Generator, List, Optional, Tuple

-import numpy
 import torch
 import torch._inductor.config as inductor_config
 import triton
@@ -24,18 +22,27 @@ from .kernels import matmul as kernels
 from .partition_k import matmul_partition_k
-from .persistent_matmul import (
-    matmul_persistent,
-    matmul_tma_persistent,
-    matmul_tma_persistent_cached,
-)
+
+try:
+    from .persistent_matmul import (
+        matmul_persistent,
+        matmul_tma_persistent,
+        matmul_tma_persistent_cached,
+    )
+
+    HAS_PRESISTENT = True
+except ModuleNotFoundError:
+    HAS_PRESISTENT = False
+
 from .triton_matmul import (
     matmul as triton_tutorial_matmul,
     matmul_kernel as triton_tutorial_matmul_kernel,
 )

-if inductor_config.is_fbcode():
-    from hammer.ops.triton.triton_matmul import triton_matmul as hstu_triton_matmul
+if IS_FBCODE:
+    from hammer.ops.triton.triton_matmul import (
+        triton_matmul as hstu_triton_matmul_kernel,
+    )

     HAS_HAMMER = True
 else:
@@ -158,14 +165,14 @@ def matmul_partition_k(self, a, b, bias) -> Callable:
         else:
             return lambda: matmul_partition_k(a, bt)

-    @register_benchmark()
+    @register_benchmark(enabled=HAS_PRESISTENT)
     def triton_persistent_matmul(self, a, b, bias) -> Callable:
         if not bias == None:
             return lambda: matmul_persistent(a, b) + bias
         else:
             return lambda: matmul_persistent(a, b)

-    @register_benchmark(enabled=not IS_FBCODE)
+    @register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
     def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
         b = b.T.contiguous()
         if not bias == None:
@@ -173,7 +180,7 @@ def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
         else:
             return lambda: matmul_tma_persistent(a, b)

-    @register_benchmark(enabled=not IS_FBCODE)
+    @register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
     def triton_tma_persistent_cached_matmul(self, a, b, bias) -> Callable:
         b = b.T.contiguous()
         if not bias == None:
@@ -216,9 +223,9 @@ def op():
     @register_benchmark(enabled=HAS_HAMMER)
     def hstu_triton_matmul(self, a, b, bias) -> Callable:
         if not bias == None:
-            return lambda: hstu_triton_matmul(a, b) + bias
+            return lambda: hstu_triton_matmul_kernel(a, b) + bias
         else:
-            return lambda: hstu_triton_matmul(a, b)
+            return lambda: hstu_triton_matmul_kernel(a, b)

     @register_benchmark(enabled=bool(colfax_gemm))
     def colfax_cutlass_matmul(self, a, b, bias) -> Callable:
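
Note for readers porting this pattern to other operators: the gemm changes above pair an optional import guard (a try/except on ModuleNotFoundError that sets a capability flag) with the benchmark decorator's enabled= argument, so a backend that cannot be imported under triton-pytorch simply drops out of the run instead of crashing collection. The sketch below is a minimal, self-contained illustration of that shape only; maybe_register, _BENCHMARKS, and the top-level persistent_matmul module are hypothetical stand-ins, not tritonbench's actual register_benchmark API or module layout.

# Minimal sketch of the guard pattern used in the gemm operator diff above.
# `persistent_matmul`, `maybe_register`, and `_BENCHMARKS` are illustrative
# stand-ins, not real tritonbench names.
from typing import Callable, Dict

try:
    # Kernels that only build against triton-main; with the triton bundled
    # in pytorch nightly this import fails, and we record that in a flag.
    from persistent_matmul import matmul_persistent  # hypothetical module

    HAS_PERSISTENT = True
except ModuleNotFoundError:
    HAS_PERSISTENT = False

_BENCHMARKS: Dict[str, Callable] = {}


def maybe_register(enabled: bool = True) -> Callable[[Callable], Callable]:
    """Register a benchmark only when its dependencies imported cleanly."""

    def decorator(fn: Callable) -> Callable:
        if enabled:
            _BENCHMARKS[fn.__name__] = fn
        return fn  # the function itself is returned either way

    return decorator


@maybe_register(enabled=HAS_PERSISTENT)
def triton_persistent_matmul(a, b):
    # Only called when HAS_PERSISTENT is True, so the name resolves.
    return matmul_persistent(a, b)


if __name__ == "__main__":
    # Without the optional module installed this prints an empty registry,
    # mirroring how the gated kernels disappear from the CI matrix.
    print(sorted(_BENCHMARKS))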