diff --git a/.github/workflows/hopper-build-and-test.yml b/.github/workflows/hopper-build-and-test.yml index e3af19318..fb20a4362 100644 --- a/.github/workflows/hopper-build-and-test.yml +++ b/.github/workflows/hopper-build-and-test.yml @@ -83,6 +83,7 @@ jobs: set -x pip uninstall -y triton source ~/env-3.4.sh + env | grep -E '^(LLVM_SYSPATH)=' >> $GITHUB_ENV || true MAX_JOBS=32 python3 -m pip install . --no-build-isolation - name: FlagTree Build on NVidia (triton_v3.5.x branch) @@ -92,6 +93,7 @@ jobs: set -x pip uninstall -y triton source ~/env-3.5.sh + env | grep -E '^(LLVM_SYSPATH)=' >> $GITHUB_ENV || true MAX_JOBS=32 python3 -m pip install . --no-build-isolation - name: FlagTree Test on NVidia (triton_v3.4.x branch) @@ -153,6 +155,9 @@ jobs: python3 python/tutorials/hints/08/08-grouped-gemm.py --only_unit_test python3 python/tutorials/hints/11/11-programmatic-dependent-launch.py --only_unit_test # flagtree tle raw - # python3 python/tutorials/tle/raw/01-vector-add.py - # python3 python/tutorials/tle/raw/02-fused-softmax.py - # python3 python/tutorials/tle/raw/03-matrix-multiplication.py + python3 python/tutorials/tle/raw/01-vector-add.py + python3 python/tutorials/tle/raw/02-fused-softmax.py + python3 python/tutorials/tle/raw/03-matrix-multiplication.py + python3 python/tutorials/tle/raw/04-hello-world.py + python3 python/tutorials/tle/raw/05-topk.py + python3 python/tutorials/tle/raw/06-test-vassert.py diff --git a/python/triton/experimental/tle/raw/mlir/__init__.py b/python/triton/experimental/tle/raw/mlir/__init__.py index ddb6b8bd7..447446151 100644 --- a/python/triton/experimental/tle/raw/mlir/__init__.py +++ b/python/triton/experimental/tle/raw/mlir/__init__.py @@ -1,4 +1,4 @@ from .runtime import EdslMLIRJITFunction -from .utils import vprintf +from .utils import vprintf, vassert -__all__ = ["EdslMLIRJITFunction", "vprintf"] +__all__ = ["EdslMLIRJITFunction", "vprintf", "vassert"] diff --git a/python/triton/experimental/tle/raw/mlir/codegen.py b/python/triton/experimental/tle/raw/mlir/codegen.py index 436a9aeaf..150600dca 100644 --- a/python/triton/experimental/tle/raw/mlir/codegen.py +++ b/python/triton/experimental/tle/raw/mlir/codegen.py @@ -91,6 +91,17 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> func.FuncOp: output_tys: List[ir.Type] = [] output_indices: List[int] = [] for idx, arg in enumerate(node.args.args): + # issue#328 [bug]edsl InOut&Input anno F722 error + # https://github.com/flagos-ai/FlagTree/issues/328 + # use while find method to fix the bug, + # remember replace below arg.annotation.slice.value with type_str + ''' + slice_node = arg.annotation.slice + if isinstance(slice_node, ast.Subscript): + type_str = slice_node.slice.value + else: + type_str = slice_node.value + ''' if arg.annotation.value.id == "InOut": ty: ir.Type = ir.Type.parse(arg.annotation.slice.value) operand_tys += [ty] diff --git a/python/triton/experimental/tle/raw/mlir/utils.py b/python/triton/experimental/tle/raw/mlir/utils.py index 51764a9e8..e77ed7048 100644 --- a/python/triton/experimental/tle/raw/mlir/utils.py +++ b/python/triton/experimental/tle/raw/mlir/utils.py @@ -1,11 +1,15 @@ from __future__ import annotations from abc import abstractmethod +import base64 +from hashlib import blake2s +import inspect +import os from typing import TYPE_CHECKING, Any, Final, List from typing_extensions import override -from hashlib import blake2s + from mlir import ir -from mlir.dialects import arith, func, llvm -import base64 +from mlir.dialects import arith, func, llvm, scf + if TYPE_CHECKING: from .codegen import EdslMLIRCodeGenerator @@ -71,3 +75,78 @@ def call(self, codegen: EdslMLIRCodeGenerator) -> func.CallOp: def vprintf(*args) -> VPrintf: return VPrintf(args) + + +class Assert(ExternalCall): + + def __init__(self, cond, msg, file_name, func_name, line_no, *args, **kwargs) -> None: + dependencies = [cond] + list(args) + super().__init__("__assertfail", dependencies, **kwargs) + self.cond = cond + self.msg = msg + self.file_name = file_name + self.func_name = func_name + self.line_no = line_no + self.print_args = args + + @override + def build(self) -> func.FuncOp: + ptr_type = ir.Type.parse("!llvm.ptr") + i32_type = ir.IntegerType.get_signless(32) + i64_type = ir.IntegerType.get_signless(64) + + return func.FuncOp(self.keyword, ir.FunctionType.get([ptr_type, ptr_type, i32_type, ptr_type, i64_type], []), + visibility="private") + + @override + def call(self, codegen: EdslMLIRCodeGenerator) -> Any: + func_op = self.decl(codegen) + + true_const = arith.constant(ir.IntegerType.get_signless(1), 1) + is_false = arith.xori(self.cond, true_const) + + if_op = scf.IfOp(is_false) + with ir.InsertionPoint(if_op.then_block): + + debug_args = [self.msg] + if self.print_args: + debug_args.extend(self.print_args) + VPrintf(debug_args).call(codegen) + + # 1. Message String + msg_global = self.global_string(self.msg, codegen) + msg_ptr = llvm.AddressOfOp(ir.Type.parse("!llvm.ptr"), msg_global.sym_name.value) + + # 2. File Name String + file_global = self.global_string(self.file_name, codegen) + file_ptr = llvm.AddressOfOp(ir.Type.parse("!llvm.ptr"), file_global.sym_name.value) + + # 3. Line Number (Integer) + line_val = arith.constant(ir.IntegerType.get_signless(32), self.line_no) + + # 4. Function Name String + func_global = self.global_string(self.func_name, codegen) + func_ptr = llvm.AddressOfOp(ir.Type.parse("!llvm.ptr"), func_global.sym_name.value) + + # 5. Char Size + char_size_val = arith.constant(ir.IntegerType.get_signless(64), 1) + + #__assertfail + func.call([], ir.FlatSymbolRefAttr.get(func_op.name.value), + [msg_ptr, file_ptr, line_val, func_ptr, char_size_val]) + + scf.yield_([]) + + return if_op + + +def vassert(cond, fmt, *args): + frame = inspect.currentframe().f_back + try: + filename = os.path.basename(frame.f_code.co_filename) + funcname = frame.f_code.co_name + lineno = frame.f_lineno + finally: + del frame + + return Assert(cond, fmt, filename, funcname, lineno, *args) diff --git a/python/tutorials/tle/raw/05-topk.py b/python/tutorials/tle/raw/05-topk.py index c71ef4d14..559a11b60 100644 --- a/python/tutorials/tle/raw/05-topk.py +++ b/python/tutorials/tle/raw/05-topk.py @@ -7,6 +7,7 @@ import torch import triton from triton.experimental.tle.raw import dialect, InOut, Input +from triton.experimental.tle.raw.mlir import vassert import triton.experimental.tle.language.raw as tle_raw import triton.language as tl @@ -40,6 +41,17 @@ def edsl1(thre_bin_sum_buf: InOut["memref"], l_new_topk_buf: InOut["me tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32)) bidx = nvvm.read_ptx_sreg_ctaid_x(ir.IntegerType.get_signless(32)) bdimx = nvvm.read_ptx_sreg_ntid_x(ir.IntegerType.get_signless(32)) # blockDim.x + + # --- Start: Runtime Assertion for BlockDim.x == 1024 --- + i32_ty = ir.IntegerType.get_signless(32) + c1024 = arith.constant(i32_ty, 1024) + is_valid_dim = arith.cmpi(arith.CmpIPredicate.eq, bdimx, c1024) + c0 = arith.constant(i32_ty, 0) + is_not_thread_0 = arith.cmpi(arith.CmpIPredicate.ne, tidx, c0) + should_pass = arith.ori(is_valid_dim, is_not_thread_0) + vassert(should_pass, "Runtime Error: BlockDim.x is incorrect, expected 1024.\n") + # --- End: Runtime Assertion --- + i32_ty = ir.IntegerType.get_signless(32) i16_ty = ir.IntegerType.get_signless(16) index_ty = ir.IndexType.get() diff --git a/python/tutorials/tle/raw/06-test-vassert.py b/python/tutorials/tle/raw/06-test-vassert.py new file mode 100644 index 000000000..d0608be4c --- /dev/null +++ b/python/tutorials/tle/raw/06-test-vassert.py @@ -0,0 +1,63 @@ +import triton +from triton.experimental.tle.raw import dialect +from triton.experimental.tle.raw.mlir import vprintf, vassert +import triton.experimental.tle.language.raw as tle_raw +import torch +import sys + +from mlir.dialects import nvvm, arith +from mlir import ir + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@dialect(name="mlir") +def edsl_assert_test(): + tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32)) + bidx = nvvm.read_ptx_sreg_ctaid_x(ir.IntegerType.get_signless(32)) + + c0 = arith.constant(ir.IntegerType.get_signless(32), 0) + c1 = arith.constant(ir.IntegerType.get_signless(32), 1) + cond_false = arith.cmpi(arith.CmpIPredicate.eq, c0, c1) + + vassert(cond_false, "TEST ASSERT: Block %d, Thread %d should fail!\n", bidx, tidx) + + vprintf("ERROR: This line should NOT be reached! bidx=%d\n", bidx) + + +@triton.jit +def assert_kernel(): + tle_raw.call(edsl_assert_test, [], []) + + +def run_test(): + print(">>> Starting Assert Test (Expect Crash)...") + + try: + assert_kernel[(1, )]() + torch.cuda.synchronize() + + except RuntimeError as e: + msg = str(e) + if "device-side assert triggered" in msg or "unspecified launch failure" in msg: + print("\n✅ [SUCCESS] Assert triggered successfully!") + print(f" Captured Error: {msg}") + return True + else: + print(f"\n❌ [FAIL] Caught unexpected RuntimeError: {msg}") + return False + + except Exception as e: + print(f"\n❌ [FAIL] Caught unexpected exception: {type(e)}") + print(e) + return False + + else: + print("\n❌ [FAIL] Kernel finished without error (Assert did NOT trigger)") + return False + + +if __name__ == "__main__": + success = run_test() + if not success: + sys.exit(1)