diff --git a/.ci/tritonbench/test-gpu.sh b/.ci/tritonbench/test-gpu.sh
index bf762c37..7e35898a 100644
--- a/.ci/tritonbench/test-gpu.sh
+++ b/.ci/tritonbench/test-gpu.sh
@@ -8,7 +8,8 @@ fi
 
 . "${SETUP_SCRIPT}"
 
-# install deps
-pip install psutil tabulate
+# FIXME: patch hstu
+sudo apt-get install -y patch
+python install.py --hstu
 
-python -m unittest test.test_gpu.main
+python -m unittest test.test_gpu.main -v
diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml
index 22d525c9..b587d0d7 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr.yaml
@@ -3,6 +3,7 @@ on:
   pull_request:
     paths:
       - .ci/*
+      - test/test_gpu/*
       - tritonbench/*
       - .github/workflows/pr.yaml
   push:
diff --git a/run.py b/run.py
index cf8afc87..9d353c2b 100644
--- a/run.py
+++ b/run.py
@@ -55,12 +55,6 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorRe
         Opbench = load_opbench_by_name_from_loader(args)
     else:
         Opbench = load_opbench_by_name(args.op)
-    if args.fwd_bwd:
-        args.mode = "fwd_bwd"
-    if args.bwd:
-        args.mode = "bwd"
-    if args.fwd_no_grad:
-        args.mode = "fwd_no_grad"
     opbench = Opbench(
         tb_args=args,
         extra_args=extra_args,
diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py
index 9e952b2a..2738ba77 100644
--- a/test/test_gpu/main.py
+++ b/test/test_gpu/main.py
@@ -53,10 +53,9 @@ def check_ci_output(op):
     ), f"output impls: {output_impls} != ci_enabled impls: {ci_enabled_impls}"
 
 
-def _run_one_operator(
-    tb_args: argparse.Namespace,
-    extra_args: Optional[List[str]] = None,
-):
+def _run_one_operator(args: List[str]):
+    parser = get_parser(args)
+    tb_args, extra_args = parser.parse_known_args(args)
     if tb_args.op in skip_tests:
         # If the op itself is in the skip list, skip all tests
         if not skip_tests[tb_args.op]:
@@ -80,6 +79,31 @@ def _run_one_operator(
     )
 
 
+def _run_operator_in_task(op: str, args: List[str]):
+    from tritonbench.operators.op_task import OpTask
+
+    if op in skip_tests:
+        # If the op itself is in the skip list, skip all tests
+        if not skip_tests[op]:
+            return
+        skip = ",".join(skip_tests[op])
+        args.extend(["--skip", skip])
+    task = OpTask(op)
+    task.make_operator_instance(args=args)
+    task.run()
+    task.check_output()
+    task.del_op_instance()
+    # Test backward (if applicable)
+    try:
+        args.extend(["--bwd"])
+        task.make_operator_instance(args=args)
+        task.run()
+        task.check_output()
+    except NotImplementedError:
+        # Operator does not support backward, skip the test
+        pass
+
+
 def make_test(operator):
     def test_case(self):
         # Add `--test-only` to disable Triton autotune in tests
@@ -92,12 +116,10 @@ def test_case(self):
             "1",
             "--test-only",
         ]
-        parser = get_parser(args)
-        tb_args, extra_args = parser.parse_known_args(args)
-        _run_one_operator(
-            tb_args,
-            extra_args,
-        )
+        if IS_FBCODE:
+            _run_one_operator(args)
+        else:
+            _run_operator_in_task(op=operator, args=args)
 
     return test_case
 
diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index 6acc2094..ce263437 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -17,6 +17,9 @@ flash_attention:
   # - flex_attention
 fp8_attention:
   - colfax_fmha
+  # triton_flash_v2 now requires the main branch of Triton
+  # pytorch version does not work
+  - triton_flash_v2
 fp8_fused_quant_gemm_rowwise:
 fp8_gemm:
   - triton_persistent_fp8_gemm
@@ -29,9 +32,6 @@ jagged_layer_norm:
 jagged_mean:
 jagged_softmax:
 jagged_sum:
-layer_norm:
-low_mem_dropout:
-rms_norm:
-rope:
-template_attention:
+ragged_attention:
+  - hstu_triton_ragged_attention_persistent
 test_op:
diff --git a/tritonbench/operators/addmm/hstu.py b/tritonbench/operators/addmm/hstu.py
index 66cfd863..ddd08073 100644
--- a/tritonbench/operators/addmm/hstu.py
+++ b/tritonbench/operators/addmm/hstu.py
@@ -6,11 +6,8 @@ import triton
 
 from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH
 
-with add_path(str(SUBMODULE_PATH)):
-    triton_addmm = importlib.import_module(
-        "generative-recommenders.ops.triton.triton_addmm"
-    )
-    _addmm_fwd = triton_addmm._addmm_fwd
+with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))):
+    from generative_recommenders.ops.triton.triton_addmm import _addmm_fwd
 
 
 class _AddMmFunction(torch.autograd.Function):
diff --git a/tritonbench/operators/fp8_attention/operator.py b/tritonbench/operators/fp8_attention/operator.py
index 63d1e33b..0131ca73 100644
--- a/tritonbench/operators/fp8_attention/operator.py
+++ b/tritonbench/operators/fp8_attention/operator.py
@@ -10,7 +10,7 @@
 
 import torch
 
-from tritonbench.kernels.triton_fused_attention import attention as triton_attention
+from tritonbench.kernels.triton_fused_attention import attention_opt as triton_attention
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
@@ -110,7 +110,7 @@ def triton_flash_v2(
         triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v)
         # full fp8 will be enabled if type of q,k,v is fp8
         return lambda: triton_attention(
-            triton_q, triton_k, triton_v, False, self.sm_scale
+            triton_q, triton_k, triton_v, False, self.sm_scale, "base"
         )
 
     def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
diff --git a/tritonbench/operators/op_task.py b/tritonbench/operators/op_task.py
index 9c4d47b8..adc4c1de 100644
--- a/tritonbench/operators/op_task.py
+++ b/tritonbench/operators/op_task.py
@@ -53,7 +53,6 @@ class OpDetails:
 
     name: str
     exists: bool
-    metadata: Dict[str, Any]
 
 
 class OpTask(base_task.TaskBase):
@@ -111,7 +110,6 @@ def _maybe_import_operator(package: str, op_name: str) -> Dict[str, Any]:
         return {
             "name": op_name,
             "exists": Operator is not None,
-            "metadata": {},
         }
 
     # =========================================================================
@@ -121,14 +119,16 @@ def _maybe_import_operator(package: str, op_name: str) -> Dict[str, Any]:
     @base_task.run_in_worker(scoped=True)
     @staticmethod
     def make_operator_instance(
-        mode: str,
-        device: str,
-        extra_args: Optional[List[str]] = None,
+        args: List[str],
     ) -> None:
+        from tritonbench.utils.parser import get_parser
+
+        parser = get_parser()
+        tb_args, extra_args = parser.parse_known_args(args)
         Operator = globals()["Operator"]
+        parser = get_parser()
         op = Operator(
-            mode=mode,
-            device=device,
+            tb_args=tb_args,
             extra_args=extra_args,
         )
 
@@ -136,8 +136,7 @@ def make_operator_instance(
 
         gc.collect()
 
-        if device == "cuda":
-            torch.cuda.empty_cache()
+        if op.device == "cuda":
             maybe_sync = torch.cuda.synchronize
         else:
             maybe_sync = lambda: None
@@ -181,6 +180,25 @@ def get_attribute(
         else:
             return None
 
+    # =========================================================================
+    # == Check output is expected in the child process ========================
+    # =========================================================================
+    @base_task.run_in_worker(scoped=True)
+    @staticmethod
+    def check_output() -> None:
+        op = globals()["op"]
+        from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS
+
+        output = op.output
+        output_impls = output.result[0][1].keys()
+        ci_enabled_impls = [
+            x for x in REGISTERED_BENCHMARKS[output.op_name].keys() if x not in op._skip
+        ]
+        # Make sure that all the ci_enabled impls are in the output
+        assert set(output_impls) == set(
+            ci_enabled_impls
+        ), f"output impls: {output_impls} != ci_enabled impls: {ci_enabled_impls}"
+
     def del_op_instance(self):
         self.worker.run(
             """
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 7ae0c458..785583c7 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -522,6 +522,22 @@ def __call__(cls, *args, **kwargs):
         return obj
 
 
+def _translate_mode(tb_args):
+    def _has_and_true(attr):
+        if hasattr(tb_args, attr) and getattr(tb_args, attr):
+            return True
+        return False
+
+    if _has_and_true("fwd"):
+        tb_args.mode = "fwd"
+    if _has_and_true("bwd"):
+        tb_args.mode = "bwd"
+    if _has_and_true("fwd_bwd"):
+        tb_args.mode = "fwd_bwd"
+    if _has_and_true("fwd_no_grad"):
+        tb_args.mode = "fwd_no_grad"
+
+
 class BenchmarkOperator(metaclass=PostInitProcessor):
     mode: Mode = Mode.FWD
     test: str = "eval"
@@ -555,7 +571,7 @@ def __init__(
         self.use_cuda_graphs = (
             self.tb_args.cudagraph if self.tb_args.cudagraph else self.use_cuda_graphs
         )
-        # we accept both "fwd" and "eval"
+        _translate_mode(self.tb_args)
         if self.tb_args.mode == "fwd":
             self.mode = Mode.FWD
         elif self.tb_args.mode == "fwd_bwd":
@@ -565,7 +581,7 @@ def __init__(
         else:
             assert (
                 self.tb_args.mode == "bwd"
-            ), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd."
+            ), "We only accept test modes: fwd, bwd, fwd_bwd, or fwd_no_grad."
             self.mode = Mode.BWD
         self.device = tb_args.device
         self.required_metrics = (