diff --git a/.ci/tritonbench/test-gpu.sh b/.ci/tritonbench/test-gpu.sh
index 7e35898a..b41331aa 100644
--- a/.ci/tritonbench/test-gpu.sh
+++ b/.ci/tritonbench/test-gpu.sh
@@ -8,8 +8,11 @@
 fi
 
 . "${SETUP_SCRIPT}"
 
-# FIXME: patch hstu
+# FIXME: patch and install hstu
 sudo apt-get install -y patch
 python install.py --hstu
+# FIXME: install colfax
+python install.py --colfax
+
 python -m unittest test.test_gpu.main -v
diff --git a/.gitmodules b/.gitmodules
index cbea5d9e..00209831 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -19,3 +19,6 @@
 [submodule "submodules/xformers"]
 	path = submodules/xformers
 	url = https://github.com/facebookresearch/xformers.git
+[submodule "submodules/cutlass"]
+	path = submodules/cutlass
+	url = https://github.com/NVIDIA/cutlass.git
diff --git a/install.py b/install.py
index c2df7f61..4e2b7ab9 100644
--- a/install.py
+++ b/install.py
@@ -52,12 +52,6 @@ def test_fbgemm():
     print("OK")
 
 
-def install_cutlass():
-    from tools.cutlass_kernels.install import install_colfax_cutlass
-
-    install_colfax_cutlass()
-
-
 def install_fa2(compile=False):
     if compile:
         # compile from source (slow)
@@ -83,12 +77,6 @@ def install_liger():
     subprocess.check_call(cmd)
 
 
-def install_tk():
-    from tools.tk.install import install_tk
-
-    install_tk()
-
-
 def install_xformers():
     os_env = os.environ.copy()
     os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a"
@@ -101,7 +89,7 @@ parser = argparse.ArgumentParser(allow_abbrev=False)
 parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
 parser.add_argument(
-    "--cutlass", action="store_true", help="Install optional CUTLASS kernels"
+    "--colfax", action="store_true", help="Install optional Colfax CUTLASS kernels"
 )
 parser.add_argument(
     "--fa2", action="store_true", help="Install optional flash_attention 2 kernels"
 )
@@ -139,14 +127,18 @@
 if args.fa3 or args.all:
     logger.info("[tritonbench] installing fa3...")
     install_fa3()
-if args.cutlass or args.all:
-    logger.info("[tritonbench] installing cutlass-kernels...")
-    install_cutlass()
+if args.colfax or args.all:
+    logger.info("[tritonbench] installing colfax cutlass-kernels...")
+    from tools.cutlass_kernels.install import install_colfax_cutlass
+
+    install_colfax_cutlass()
 if args.jax or args.all:
     logger.info("[tritonbench] installing jax...")
     install_jax()
 if args.tk or args.all:
     logger.info("[tritonbench] installing thunderkittens...")
+    from tools.tk.install import install_tk
+
     install_tk()
 if args.liger or args.all:
     logger.info("[tritonbench] installing liger-kernels...")
diff --git a/submodules/cutlass b/submodules/cutlass
new file mode 160000
index 00000000..bbe579a9
--- /dev/null
+++ b/submodules/cutlass
@@ -0,0 +1 @@
+Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49
diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py
index 5e900d76..b9c0025d 100644
--- a/test/test_gpu/main.py
+++ b/test/test_gpu/main.py
@@ -30,12 +30,10 @@
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Ops that we can skip the unit tests
-SKIP_OPS = {
-    "test_op",
-}
+# Ops that we run forward only
+FWD_ONLY_OPS = skip_tests.get("fwd_only_ops", [])
 
-TEST_OPERATORS = set(list_operators_by_collection(op_collection="default")) - SKIP_OPS
+TEST_OPERATORS = set(list_operators_by_collection(op_collection="default"))
 
 
 def check_ci_output(op):
@@ -67,6 +65,8 @@
     op.run()
     check_ci_output(op)
     # Test backward (if applicable)
+    if tb_args.op in FWD_ONLY_OPS:
+        return
     if op.has_bwd():
         del op
tb_args.mode = "bwd" @@ -89,6 +89,8 @@ def _run_operator_in_task(op: str, args: List[str]): task.run() task.check_output() # Test backward (if applicable) + if op in FWD_ONLY_OPS: + return if task.get_attribute("has_bwd", method=True): task.del_op_instance() args.extend(["--bwd"]) diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml index d137d3d9..d7d1cb8b 100644 --- a/test/test_gpu/skip_tests_h100_pytorch.yaml +++ b/test/test_gpu/skip_tests_h100_pytorch.yaml @@ -8,10 +8,7 @@ bf16xint16_gemm: # LLVM ERROR: mma16816 data type not supported - bf16xint16 flash_attention: - # FIXME: enable colfax_cutlass and tk - - xformers - - xformers_splitk - - colfax_cutlass + # thunderkittens cannot handle the default input shapes - tk # triton_tutorial_* kernels require triton-main - triton_tutorial_flash_v2 @@ -37,8 +34,6 @@ gemm: - triton_tma_persistent_cached_matmul - hstu_triton_matmul - colfax_cutlass_matmul - # FIXME: PT2 CUTLASS backend failed - - pt2_cutlass_matmul # jagged tests are slow, so disable them in OSS jagged_layer_norm: jagged_mean: @@ -47,3 +42,5 @@ jagged_sum: ragged_attention: - hstu_triton_ragged_attention_persistent test_op: +fwd_only_ops: + - flash_attention diff --git a/tools/cutlass_kernels/install.py b/tools/cutlass_kernels/install.py index a4c552d9..9ea9d4a7 100644 --- a/tools/cutlass_kernels/install.py +++ b/tools/cutlass_kernels/install.py @@ -8,8 +8,7 @@ "/usr/local/cuda" if "CUDA_HOME" not in os.environ else os.environ["CUDA_HOME"] ) REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent -FBGEMM_PATH = REPO_PATH.joinpath("submodules", "FBGEMM", "fbgemm_gpu") -FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("external", "cutlass") +TORCH_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass") COLFAX_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass-kernels") COLFAX_CUTLASS_TRITONBENCH_PATH = REPO_PATH.joinpath("tools", "cutlass_kernels") @@ -41,9 +40,9 @@ f"-I{str(COLFAX_CUTLASS_PATH.joinpath('lib').resolve())}", f"-I{str(COLFAX_CUTLASS_PATH.joinpath('include').resolve())}", f"-I{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha').resolve())}", - f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('include').resolve())}", - f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}", - f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}", + f"-I{str(TORCH_CUTLASS_PATH.joinpath('include').resolve())}", + f"-I{str(TORCH_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}", + f"-I{str(TORCH_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}", f"-I{CUDA_HOME}/include", f"-I{str(TORCH_BASE_PATH.joinpath('include').resolve())}", f"-I{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('include').resolve())}", diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index 29a47759..fc5ba43e 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -85,7 +85,6 @@ ) except (ImportError, IOError, AttributeError): HAS_FLASH_V3 = False - pass # [Optional] xformers backend try: @@ -93,8 +92,10 @@ import xformers.ops.fmha as xformers_fmha # @manual=//fair/xformers:xformers from .test_fmha_utils import permute_qkv + + HAS_XFORMERS = True except (ImportError, IOError, AttributeError): - pass + HAS_XFORMERS = False # [Optional] colfax cutlass backend try: @@ -122,7 +123,6 @@ tk_fwd = torch.ops.tk except (ImportError, IOError, AttributeError): tk_fwd = None - 
-    tk_fwd_causal = None
 
 
 from typing import Any, Generator, List
@@ -145,9 +145,6 @@ def parse_op_args(args: List[str]):
     parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
     parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
     parser.add_argument("--causal", action="store_true", help="enable causal")
-    parser.add_argument(
-        "--xformers-splitk", action="store_true", help="benchmark xformers-split impl"
-    )
     return parser.parse_args(args)
 
 
@@ -168,7 +165,6 @@ def __init__(
         self.N_CTX = None
         self.causal = args.causal
         self.sm_scale = 1.3
-        self.xformers_splitk = args.xformers_splitk
 
     @register_benchmark()
     def aten(
@@ -335,7 +331,7 @@ def xformers_preprocess(
         )
         return fhma_input
 
-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=HAS_XFORMERS)
     def xformers(
         self,
         q: torch.Tensor,
@@ -346,7 +342,7 @@
         xformers_cutlass_fhma = xformers.ops.fmha.cutlass.FwOp
         return lambda: xformers_cutlass_fhma().apply(fhma_input, needs_gradient=False)
 
-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=HAS_XFORMERS)
     def xformers_splitk(
         self,
         q: torch.Tensor,
@@ -364,7 +360,7 @@ def colfax_cutlass_preprocess(self, q, k, v):
             torch.transpose(v, 1, 2),
         )
 
-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=bool(colfax_cutlass_fmha is not None))
     def colfax_cutlass(self, q, k, v):
         default_scale = 1.0 / math.sqrt(float(self.D_HEAD))
         colfax_q, colfax_k, colfax_v = self.colfax_cutlass_preprocess(q, k, v)
@@ -378,15 +374,13 @@ def colfax_cutlass(self, q, k, v):
             default_scale,
         )
 
-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=bool(tk_fwd is not None))
     def tk(self, q, k, v):
         o = torch.zeros_like(v)
+        l_tensor = torch.zeros_like(o).to(torch.float32)
 
         def tk_dispatcher():
-            if self.causal:
-                tk_fwd_causal.attention_forward_causal(q, k, v, o)
-            else:
-                tk_fwd.attention_forward(q, k, v, o)
+            tk_fwd.attention_forward(q, k, v, o, l_tensor, causal=self.causal)
             return o
 
         return tk_dispatcher
diff --git a/tritonbench/utils/env_utils.py b/tritonbench/utils/env_utils.py
index fd012e0f..fbfca74c 100644
--- a/tritonbench/utils/env_utils.py
+++ b/tritonbench/utils/env_utils.py
@@ -5,6 +5,8 @@
 
 from typing import Optional
 
+from tritonbench.utils.path_utils import REPO_PATH
+
 log = logging.getLogger(__name__)
 
 MAIN_RANDOM_SEED = 1337
@@ -22,6 +24,18 @@
 ]
 
 
+def set_env():
+    # set cutlass dir
+    # by default we use the cutlass version built with pytorch
+    import torch
+
+    current_cutlass_dir = torch._inductor.config.cuda.cutlass_dir
+    if not os.path.exists(current_cutlass_dir):
+        tb_cutlass_dir = REPO_PATH.joinpath("submodules", "cutlass")
+        if tb_cutlass_dir.is_dir():
+            torch._inductor.config.cuda.cutlass_dir = str(tb_cutlass_dir)
+
+
 def set_random_seed():
Helps with accuracy testing.""" import random diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index ccaf88cb..a21ebe38 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -26,6 +26,7 @@ from tritonbench.utils.env_utils import ( apply_precision, fresh_triton_cache, + set_env, set_random_seed, ) from tritonbench.utils.input import input_cast @@ -561,6 +562,7 @@ class BenchmarkOperator(metaclass=PostInitProcessor): def __init__( self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None ): + set_env() set_random_seed() self.name = _find_op_name_from_module_path(self.__class__.__module__) self._raw_extra_args = copy.deepcopy(extra_args) @@ -619,7 +621,7 @@ def __post__init__(self): def _get_bm_func(self, bm_func_name: str): fwd_fn_lambda = getattr(self, bm_func_name, None) - assert fwd_fn_lambda, ( + assert callable(fwd_fn_lambda), ( f"Could not find benchmark {bm_func_name} registered in {self.name}. " f"Available benchmarks: {REGISTERED_BENCHMARKS[self.name].keys()}. " )