Skip to content
Merged
4 changes: 2 additions & 2 deletions python/triton/experimental/tle/raw/mlir/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .runtime import EdslMLIRJITFunction
from .utils import vprintf
from .utils import vprintf, vassert

__all__ = ["EdslMLIRJITFunction", "vprintf"]
__all__ = ["EdslMLIRJITFunction", "vprintf", "vassert"]
79 changes: 78 additions & 1 deletion python/triton/experimental/tle/raw/mlir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from typing_extensions import override

from mlir import ir
from mlir.dialects import arith, func, llvm
from mlir.dialects import arith, func, llvm, scf
import inspect
import os

if TYPE_CHECKING:
from .codegen import EdslMLIRCodeGenerator
Expand Down Expand Up @@ -70,3 +72,78 @@ def call(self, codegen: EdslMLIRCodeGenerator) -> func.CallOp:

def vprintf(*args) -> VPrintf:
return VPrintf(args)


class Assert(ExternalCall):

def __init__(self, cond, msg, file_name, func_name, line_no, *args, **kwargs) -> None:
dependencies = [cond] + list(args)
super().__init__("__assertfail", dependencies, **kwargs)
self.cond = cond
self.msg = msg
self.file_name = file_name
self.func_name = func_name
self.line_no = line_no
self.print_args = args

@override
def build(self) -> func.FuncOp:
ptr_type = ir.Type.parse("!llvm.ptr")
i32_type = ir.IntegerType.get_signless(32)
i64_type = ir.IntegerType.get_signless(64)

return func.FuncOp(self.keyword, ir.FunctionType.get([ptr_type, ptr_type, i32_type, ptr_type, i64_type], []),
visibility="private")

@override
def call(self, codegen: EdslMLIRCodeGenerator) -> Any:
func_op = self.decl(codegen)

true_const = arith.constant(ir.IntegerType.get_signless(1), 1)
is_false = arith.xori(self.cond, true_const)

if_op = scf.IfOp(is_false)
with ir.InsertionPoint(if_op.then_block):

debug_args = [self.msg]
if self.print_args:
debug_args.extend(self.print_args)
VPrintf(debug_args).call(codegen)

# 1. Message String
msg_global = self.global_string(self.msg, codegen)
msg_ptr = llvm.AddressOfOp(ir.Type.parse("!llvm.ptr"), msg_global.sym_name.value)

# 2. File Name String
file_global = self.global_string(self.file_name, codegen)
file_ptr = llvm.AddressOfOp(ir.Type.parse("!llvm.ptr"), file_global.sym_name.value)

# 3. Line Number (Integer)
line_val = arith.constant(ir.IntegerType.get_signless(32), self.line_no)

# 4. Function Name String
func_global = self.global_string(self.func_name, codegen)
func_ptr = llvm.AddressOfOp(ir.Type.parse("!llvm.ptr"), func_global.sym_name.value)

# 5. Char Size
char_size_val = arith.constant(ir.IntegerType.get_signless(64), 1)

#__assertfail
func.call([], ir.FlatSymbolRefAttr.get(func_op.name.value),
[msg_ptr, file_ptr, line_val, func_ptr, char_size_val])

scf.yield_([])

return if_op


def vassert(cond, fmt, *args):
frame = inspect.currentframe().f_back
try:
filename = os.path.basename(frame.f_code.co_filename)
funcname = frame.f_code.co_name
lineno = frame.f_lineno
finally:
del frame

return Assert(cond, fmt, filename, funcname, lineno, *args)
63 changes: 63 additions & 0 deletions python/tutorials/tle/raw/06-test-vassert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import triton
from triton.experimental.tle.raw import dialect
from triton.experimental.tle.raw.mlir import vprintf, vassert
import triton.experimental.tle.language.raw as tle_raw
import torch
import sys

from mlir.dialects import nvvm, arith
from mlir import ir

DEVICE = triton.runtime.driver.active.get_active_torch_device()


@dialect(name="mlir")
def edsl_assert_test():
tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32))
bidx = nvvm.read_ptx_sreg_ctaid_x(ir.IntegerType.get_signless(32))

c0 = arith.constant(ir.IntegerType.get_signless(32), 0)
c1 = arith.constant(ir.IntegerType.get_signless(32), 1)
cond_false = arith.cmpi(arith.CmpIPredicate.eq, c0, c1)

vassert(cond_false, "TEST ASSERT: Block %d, Thread %d should fail!\n", bidx, tidx)

vprintf("ERROR: This line should NOT be reached! bidx=%d\n", bidx)


@triton.jit
def assert_kernel():
tle_raw.call(edsl_assert_test, [], [])


def run_test():
print(">>> Starting Assert Test (Expect Crash)...")

try:
assert_kernel[(1, )]()
torch.cuda.synchronize()

except RuntimeError as e:
msg = str(e)
if "device-side assert triggered" in msg or "unspecified launch failure" in msg:
print("\n✅ [SUCCESS] Assert triggered successfully!")
print(f" Captured Error: {msg}")
return True
else:
print(f"\n❌ [FAIL] Caught unexpected RuntimeError: {msg}")
return False

except Exception as e:
print(f"\n❌ [FAIL] Caught unexpected exception: {type(e)}")
print(e)
return False

else:
print("\n❌ [FAIL] Kernel finished without error (Assert did NOT trigger)")
return False


if __name__ == "__main__":
success = run_test()
if not success:
sys.exit(1)