From 07000eefd2277e5bc7326b15a296811d4c48f21d Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 13:22:44 -0500
Subject: [PATCH 01/11] Enable gemm

---
 test/test_gpu/skip_tests_h100_pytorch.yaml | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index ce263437..aa6847d6 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -25,9 +25,6 @@ fp8_gemm:
   - triton_persistent_fp8_gemm
   - triton_tma_persistent_fp8_gemm
 fp8_gemm_rowwise:
-gemm:
-grouped_gemm:
-int4_gemm:
 jagged_layer_norm:
 jagged_mean:
 jagged_softmax:

From f2963f1aa581e51159105ebec46775263c09ee3c Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 13:35:46 -0500
Subject: [PATCH 02/11] Disable gemm on persistent

---
 test/test_gpu/skip_tests_h100_pytorch.yaml |  4 ++++
 tritonbench/operators/gemm/operator.py     | 21 +++++++++++++--------
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index aa6847d6..a75eff52 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -25,6 +25,10 @@ fp8_gemm:
   - triton_persistent_fp8_gemm
   - triton_tma_persistent_fp8_gemm
 fp8_gemm_rowwise:
+gemm:
+  - triton_persistent_matmul
+  - triton_tma_persistent_matmul
+  - triton_tma_persistent_cached_matmul
 jagged_layer_norm:
 jagged_mean:
 jagged_softmax:
diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index 7bbae383..d32f0a7c 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -24,11 +24,16 @@
 from .kernels import matmul as kernels
 from .partition_k import matmul_partition_k
-from .persistent_matmul import (
-    matmul_persistent,
-    matmul_tma_persistent,
-    matmul_tma_persistent_cached,
-)
+try:
+    from .persistent_matmul import (
+        matmul_persistent,
+        matmul_tma_persistent,
+        matmul_tma_persistent_cached,
+    )
+    HAS_PRESISTENT = True
+except ModuleNotFoundError:
+    HAS_PRESISTENT = False
+
 from .triton_matmul import (
     matmul as triton_tutorial_matmul,
     matmul_kernel as triton_tutorial_matmul_kernel,
 )
@@ -158,14 +163,14 @@ def matmul_partition_k(self, a, b, bias) -> Callable:
         else:
             return lambda: matmul_partition_k(a, bt)
 
-    @register_benchmark()
+    @register_benchmark(enabled=HAS_PRESISTENT)
     def triton_persistent_matmul(self, a, b, bias) -> Callable:
         if not bias == None:
             return lambda: matmul_persistent(a, b) + bias
         else:
             return lambda: matmul_persistent(a, b)
 
-    @register_benchmark(enabled=not IS_FBCODE)
+    @register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
     def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
         b = b.T.contiguous()
         if not bias == None:
@@ -173,7 +178,7 @@ def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
         else:
             return lambda: matmul_tma_persistent(a, b)
 
-    @register_benchmark(enabled=not IS_FBCODE)
+    @register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
     def triton_tma_persistent_cached_matmul(self, a, b, bias) -> Callable:
         b = b.T.contiguous()
         if not bias == None:

From da326da22cdabdc68400fd5f4bd80f7cc0180d69 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 13:36:53 -0500
Subject: [PATCH 03/11] Fix lint

---
 tritonbench/operators/gemm/operator.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index d32f0a7c..5ac4b87c 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -24,12 +24,14 @@
 from .kernels import matmul as kernels
 from .partition_k import matmul_partition_k
+
 try:
     from .persistent_matmul import (
         matmul_persistent,
         matmul_tma_persistent,
         matmul_tma_persistent_cached,
     )
+
     HAS_PRESISTENT = True
 except ModuleNotFoundError:
     HAS_PRESISTENT = False
 

From 796538e8e4c5814ba3d3df808d01d25a8131acdb Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 13:45:24 -0500
Subject: [PATCH 04/11] Add hstu kernel

---
 test/test_gpu/skip_tests_h100_pytorch.yaml |  1 +
 tritonbench/operators/gemm/operator.py     | 12 ++++--------
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index a75eff52..8f780d7c 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -29,6 +29,7 @@ gemm:
   - triton_persistent_matmul
   - triton_tma_persistent_matmul
   - triton_tma_persistent_cached_matmul
+  - hstu_triton_matmul
 jagged_layer_norm:
 jagged_mean:
 jagged_softmax:
diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index 5ac4b87c..835b3a44 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -1,10 +1,7 @@
 import argparse
 import csv
 import os
-import statistics
 from typing import Any, Callable, Generator, List, Optional, Tuple
-
-import numpy
 import torch
 import torch._inductor.config as inductor_config
 import triton
@@ -41,9 +38,8 @@
     matmul_kernel as triton_tutorial_matmul_kernel,
 )
 
-if inductor_config.is_fbcode():
-    from hammer.ops.triton.triton_matmul import triton_matmul as hstu_triton_matmul
-
+if IS_FBCODE:
+    from hammer.ops.triton.triton_matmul import triton_matmul as hstu_triton_matmul_kernel
     HAS_HAMMER = True
 else:
     HAS_HAMMER = False
@@ -204,9 +200,9 @@ def aten_matmul(self, a, b, bias) -> Callable:
     @register_benchmark(enabled=HAS_HAMMER)
     def hstu_triton_matmul(self, a, b, bias) -> Callable:
         if not bias == None:
-            return lambda: hstu_triton_matmul(a, b) + bias
+            return lambda: hstu_triton_matmul_kernel(a, b) + bias
         else:
-            return lambda: hstu_triton_matmul(a, b)
+            return lambda: hstu_triton_matmul_kernel(a, b)
 
     @register_benchmark(enabled=bool(colfax_gemm))
     def colfax_cutlass_matmul(self, a, b, bias) -> Callable:

From dc2d5a83adad87c36b089400329e32ac92909886 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 15:07:43 -0500
Subject: [PATCH 05/11] Skip colfax_gemm

---
 test/test_gpu/skip_tests_h100_pytorch.yaml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index 8f780d7c..dcfa347d 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -30,6 +30,7 @@ gemm:
   - triton_tma_persistent_matmul
   - triton_tma_persistent_cached_matmul
   - hstu_triton_matmul
+  - colfax_gemm
 jagged_layer_norm:
 jagged_mean:
 jagged_softmax:

From 5809e27d14acd78c63280a4fd6f5de5722b18907 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 15:14:37 -0500
Subject: [PATCH 06/11] Fix linting

---
 test/test_gpu/skip_tests_h100_pytorch.yaml | 2 +-
 tritonbench/operators/gemm/operator.py     | 6 +++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index dcfa347d..17250050 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -30,7 +30,7 @@ gemm:
   - triton_tma_persistent_matmul
   - triton_tma_persistent_cached_matmul
   - hstu_triton_matmul
-  - colfax_gemm
+  - colfax_cutlass_matmul
 jagged_layer_norm:
 jagged_mean:
 jagged_softmax:
diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index 835b3a44..3293f31a 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -2,6 +2,7 @@ import argparse
 import csv
 import os
 from typing import Any, Callable, Generator, List, Optional, Tuple
+
 import torch
 import torch._inductor.config as inductor_config
 import triton
@@ -39,7 +40,10 @@
 )
 
 if IS_FBCODE:
-    from hammer.ops.triton.triton_matmul import triton_matmul as hstu_triton_matmul_kernel
+    from hammer.ops.triton.triton_matmul import (
+        triton_matmul as hstu_triton_matmul_kernel,
+    )
+
     HAS_HAMMER = True
 else:
     HAS_HAMMER = False

From fe7d67b1e80d83ab7e1a02f25c2e64dfe74fe371 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 15:21:38 -0500
Subject: [PATCH 07/11] Skip the broken pt2 cutlass backend

---
 test/test_gpu/skip_tests_h100_pytorch.yaml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index 17250050..0d280e88 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -31,6 +31,8 @@ gemm:
   - triton_tma_persistent_cached_matmul
   - hstu_triton_matmul
   - colfax_cutlass_matmul
+  # FIXME: PT2 CUTLASS backend failed
+  - pt2_cutlass_matmul
 jagged_layer_norm:
 jagged_mean:
 jagged_softmax:

From bb4e82c487a788c2471d162917154c2f8248b3cd Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 15:38:31 -0500
Subject: [PATCH 08/11] Enable flash_attention

---
 test/test_gpu/skip_tests_h100_pytorch.yaml | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index 0d280e88..df70b555 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -3,23 +3,24 @@
 # Use to skip an entire operator
 # Use to skip an impl
 bf16xint16_gemm:
+  # LLVM ERROR: mma16816 data type not supported
   - bf16xint16
-# TODO: we have many buggy backends for flash_attention
-# Need to fix them in the CI
 flash_attention:
-# - triton_tutorial_flash_v2_tma
-# - triton_op_flash_v2
-# - xformers_splitk
-# - colfax_cutlass
-# - tk
-# - sdpa
-# - cudnn
-# - flex_attention
+  - xformers
+  - xformers_splitk
+  - colfax_cutlass
+  - tk
+  - triton_tutorial_flash_v2
+  - triton_tutorial_flash_v2_opt
+  - triton_tutorial_flash_v2_tma
+  - triton_tutorial_flash_v2_ws
+  - triton_tutorial_flash_v2_tma_ws
 fp8_attention:
   - colfax_fmha
   # triton_flash_v2 now requires the main branch of Triton
   # pytorch version does not work
   - triton_flash_v2
+# All requires fb-import kernels
 fp8_fused_quant_gemm_rowwise:
 fp8_gemm:
   - triton_persistent_fp8_gemm

From 7e801cb07283d5e134a9942db23384079a72fd5f Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 15:43:11 -0500
Subject: [PATCH 09/11] Enable flash_attention kernels

---
 tritonbench/operators/flash_attention/operator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index de21bfd1..29a47759 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -43,7 +43,7 @@
 from tritonbench.utils.triton_op import IS_FBCODE
 
 try:
-    with add_path(SUBMODULE_PATH.joinpath("kernels")):
+    with add_path(str(SUBMODULE_PATH.joinpath("kernels"))):
         from kernels.flash_attention import attention as triton_op_FA2
     HAS_KERNELS = True
 except BaseException:

From e8a588e48c4cc09399fd0be7a9fd325e7e604e27 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 15:46:30 -0500
Subject: [PATCH 10/11] Add more coverage

---
 test/test_gpu/skip_tests_h100_pytorch.yaml | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index df70b555..5baa8047 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -1,15 +1,19 @@
-# Tests we skip in OSS CI
-# This file is regarding to the Triton version bundled with pytorch
-# Use to skip an entire operator
-# Use to skip an impl
+# Tests we skip in triton-pytorch + OSS CI
+# triton-pytorch is the triton version bundled with pytorch nightly
+# We need to skip kernels that only work on triton-main
+# Usage:
+# op-name: to skip an entire operator
+# op-name:\n\t- impl-name to skip an impl
 bf16xint16_gemm:
   # LLVM ERROR: mma16816 data type not supported
   - bf16xint16
 flash_attention:
+  # FIXME: enable colfax_cutlass and tk
   - xformers
   - xformers_splitk
   - colfax_cutlass
   - tk
+  # triton_tutorial_* kernels require triton-main
   - triton_tutorial_flash_v2
   - triton_tutorial_flash_v2_opt
   - triton_tutorial_flash_v2_tma
@@ -17,16 +21,17 @@ flash_attention:
   - triton_tutorial_flash_v2_tma_ws
 fp8_attention:
   - colfax_fmha
-  # triton_flash_v2 now requires the main branch of Triton
-  # pytorch version does not work
+  # triton_flash_v2 requires triton-main
   - triton_flash_v2
-# All requires fb-import kernels
+# fp8_fused_quant_gemm_rowwise requires fb-only kernels
 fp8_fused_quant_gemm_rowwise:
 fp8_gemm:
+  # triton_*_persistent requires triton-main
   - triton_persistent_fp8_gemm
   - triton_tma_persistent_fp8_gemm
 fp8_gemm_rowwise:
 gemm:
+  # triton_*_persistent_* requires triton-main
   - triton_persistent_matmul
   - triton_tma_persistent_matmul
   - triton_tma_persistent_cached_matmul
@@ -34,10 +39,6 @@ gemm:
   - colfax_cutlass_matmul
   # FIXME: PT2 CUTLASS backend failed
   - pt2_cutlass_matmul
-jagged_layer_norm:
-jagged_mean:
-jagged_softmax:
-jagged_sum:
 ragged_attention:
   - hstu_triton_ragged_attention_persistent
 test_op:

From 7e177c67a682c84aa9c618358e61e4a0f895e477 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 16:01:42 -0500
Subject: [PATCH 11/11] Disable jagged tests for speed

---
 test/test_gpu/skip_tests_h100_pytorch.yaml | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index 5baa8047..d137d3d9 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -39,6 +39,11 @@ gemm:
   - colfax_cutlass_matmul
   # FIXME: PT2 CUTLASS backend failed
   - pt2_cutlass_matmul
+# jagged tests are slow, so disable them in OSS
+jagged_layer_norm:
+jagged_mean:
+jagged_softmax:
+jagged_sum:
 ragged_attention:
   - hstu_triton_ragged_attention_persistent
 test_op:
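Note (not part of the patch series): patches 02, 04, and 06 all apply the same gating pattern -- wrap an optional kernel import in try/except, record the outcome in a module-level flag (spelled HAS_PRESISTENT / HAS_HAMMER in the diffs above), and pass that flag to register_benchmark(enabled=...) so the backend is skipped instead of crashing when the import is unavailable. The sketch below is a minimal, self-contained illustration of that pattern only; register_benchmark here is a simplified stand-in for the real tritonbench decorator, IS_FBCODE is a hard-coded placeholder, and the probe import of triton merely stands in for the optional kernel modules.

# Minimal sketch of the optional-import gating pattern used in these patches.
# register_benchmark() below is a simplified stand-in, NOT the real
# tritonbench decorator; IS_FBCODE is likewise a hard-coded placeholder.
from typing import Callable, Dict

IS_FBCODE = False  # placeholder for the fbcode detection used in operator.py

try:
    # operator.py guards `from .persistent_matmul import ...` this way;
    # the sketch probes a generic optional dependency instead.
    import triton  # noqa: F401

    HAS_PERSISTENT = True  # spelled HAS_PRESISTENT in the actual patch
except ModuleNotFoundError:
    HAS_PERSISTENT = False

BENCHMARKS: Dict[str, Callable[[], str]] = {}


def register_benchmark(enabled: bool = True) -> Callable:
    """Register the decorated benchmark only when its prerequisites are met."""

    def decorator(fn: Callable[[], str]) -> Callable[[], str]:
        if enabled:
            BENCHMARKS[fn.__name__] = fn
        return fn

    return decorator


@register_benchmark(enabled=HAS_PERSISTENT)
def triton_persistent_matmul() -> str:
    return "persistent matmul would run here"


@register_benchmark(enabled=not IS_FBCODE and HAS_PERSISTENT)
def triton_tma_persistent_matmul() -> str:
    return "TMA persistent matmul would run here"


if __name__ == "__main__":
    # Only benchmarks whose imports resolved are registered; the rest are
    # skipped quietly, which the skip yaml then mirrors on the CI side.
    print(sorted(BENCHMARKS))

The skip_tests_h100_pytorch.yaml entries touched by these patches are the CI-side counterpart of the same idea: each listed impl is excluded from the OSS test run even when its registration would otherwise succeed.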