Skip to content
Merged
11 changes: 8 additions & 3 deletions .github/workflows/hopper-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ jobs:
set -x
pip uninstall -y triton
source ~/env-3.4.sh
env | grep -E '^(LLVM_SYSPATH)=' >> $GITHUB_ENV || true
MAX_JOBS=32 python3 -m pip install . --no-build-isolation

- name: FlagTree Build on NVidia (triton_v3.5.x branch)
Expand All @@ -92,6 +93,7 @@ jobs:
set -x
pip uninstall -y triton
source ~/env-3.5.sh
env | grep -E '^(LLVM_SYSPATH)=' >> $GITHUB_ENV || true
MAX_JOBS=32 python3 -m pip install . --no-build-isolation

- name: FlagTree Test on NVidia (triton_v3.4.x branch)
Expand Down Expand Up @@ -153,6 +155,9 @@ jobs:
python3 python/tutorials/hints/08/08-grouped-gemm.py --only_unit_test
python3 python/tutorials/hints/11/11-programmatic-dependent-launch.py --only_unit_test
# flagtree tle raw
# python3 python/tutorials/tle/raw/01-vector-add.py
# python3 python/tutorials/tle/raw/02-fused-softmax.py
# python3 python/tutorials/tle/raw/03-matrix-multiplication.py
python3 python/tutorials/tle/raw/01-vector-add.py
python3 python/tutorials/tle/raw/02-fused-softmax.py
python3 python/tutorials/tle/raw/03-matrix-multiplication.py
python3 python/tutorials/tle/raw/04-hello-world.py
python3 python/tutorials/tle/raw/05-topk.py
python3 python/tutorials/tle/raw/06-test-vassert.py
4 changes: 2 additions & 2 deletions python/triton/experimental/tle/raw/mlir/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .runtime import EdslMLIRJITFunction
from .utils import vprintf
from .utils import vprintf, vassert

__all__ = ["EdslMLIRJITFunction", "vprintf"]
__all__ = ["EdslMLIRJITFunction", "vprintf", "vassert"]
11 changes: 11 additions & 0 deletions python/triton/experimental/tle/raw/mlir/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> func.FuncOp:
output_tys: List[ir.Type] = []
output_indices: List[int] = []
for idx, arg in enumerate(node.args.args):
# issue#328 [bug]edsl InOut&Input anno F722 error
# https://github.com/flagos-ai/FlagTree/issues/328
# use while find method to fix the bug,
# remember replace below arg.annotation.slice.value with type_str
'''
slice_node = arg.annotation.slice
if isinstance(slice_node, ast.Subscript):
type_str = slice_node.slice.value
else:
type_str = slice_node.value
'''
if arg.annotation.value.id == "InOut":
ty: ir.Type = ir.Type.parse(arg.annotation.slice.value)
operand_tys += [ty]
Expand Down
85 changes: 82 additions & 3 deletions python/triton/experimental/tle/raw/mlir/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations
from abc import abstractmethod
import base64
from hashlib import blake2s
import inspect
import os
from typing import TYPE_CHECKING, Any, Final, List
from typing_extensions import override
from hashlib import blake2s

from mlir import ir
from mlir.dialects import arith, func, llvm
import base64
from mlir.dialects import arith, func, llvm, scf

if TYPE_CHECKING:
from .codegen import EdslMLIRCodeGenerator

Expand Down Expand Up @@ -71,3 +75,78 @@ def call(self, codegen: EdslMLIRCodeGenerator) -> func.CallOp:

def vprintf(*args) -> VPrintf:
    """Build a ``VPrintf`` external call from the given positional arguments.

    The arguments are handed to ``VPrintf`` as one tuple, exactly as
    collected by ``*args``; nothing is inspected or converted here.
    """
    packed = args
    return VPrintf(packed)


class Assert(ExternalCall):
    """Assertion emitted as a guarded call to the external ``__assertfail`` symbol.

    The generated IR inverts ``cond`` and, inside an ``scf.if`` on the
    inverted value, first prints the message through ``VPrintf`` and then
    calls ``__assertfail`` with the message, file name, line number,
    function name and a character size of 1.

    NOTE(review): ``__assertfail`` is presumably the device runtime's
    assertion handler -- confirm against the target toolchain.
    """

    def __init__(self, cond, msg, file_name, func_name, line_no, *args, **kwargs) -> None:
        """Store the condition, the message and the source location.

        ``cond`` followed by every extra positional argument is passed to the
        base class as the dependency list; ``kwargs`` are forwarded unchanged.
        The extra positional arguments are also kept as ``print_args`` for the
        diagnostic print emitted by :meth:`call`.
        """
        super().__init__("__assertfail", [cond, *args], **kwargs)
        self.cond = cond
        self.print_args = args
        self.msg = msg
        self.file_name, self.func_name, self.line_no = file_name, func_name, line_no

    @override
    def build(self) -> func.FuncOp:
        """Create the private declaration ``(ptr, ptr, i32, ptr, i64) -> ()`` named ``self.keyword``."""
        ptr = ir.Type.parse("!llvm.ptr")
        i32 = ir.IntegerType.get_signless(32)
        i64 = ir.IntegerType.get_signless(64)
        signature = ir.FunctionType.get([ptr, ptr, i32, ptr, i64], [])
        return func.FuncOp(self.keyword, signature, visibility="private")

    @override
    def call(self, codegen: EdslMLIRCodeGenerator) -> Any:
        """Emit the guarded failure path and return the ``scf.IfOp`` that wraps it."""
        callee = self.decl(codegen)

        # XOR with the 1-bit constant 1 inverts the condition, so the
        # then-block below is taken when the asserted condition does not hold.
        one_bit = arith.constant(ir.IntegerType.get_signless(1), 1)
        failed = arith.xori(self.cond, one_bit)

        def string_ptr(text):
            # Materialise `text` through global_string and take its address.
            symbol = self.global_string(text, codegen)
            return llvm.AddressOfOp(ir.Type.parse("!llvm.ptr"), symbol.sym_name.value)

        guard = scf.IfOp(failed)
        with ir.InsertionPoint(guard.then_block):
            # Print the message (plus any extra arguments) before failing.
            VPrintf([self.msg, *self.print_args]).call(codegen)

            # Operands are created in the order the callee expects them:
            # message, file, line, function, char size.
            message = string_ptr(self.msg)
            file_name = string_ptr(self.file_name)
            line = arith.constant(ir.IntegerType.get_signless(32), self.line_no)
            function = string_ptr(self.func_name)
            char_size = arith.constant(ir.IntegerType.get_signless(64), 1)

            func.call([], ir.FlatSymbolRefAttr.get(callee.name.value),
                      [message, file_name, line, function, char_size])

            scf.yield_([])

        return guard


def vassert(cond, fmt, *args):
    """Create an ``Assert`` tagged with the caller's source location.

    Args:
        cond: Condition value passed through to ``Assert``.
        fmt: Message string passed through to ``Assert``.
        *args: Extra values forwarded to ``Assert`` after the location data.

    Returns:
        The ``Assert`` built from ``cond``, ``fmt``, the caller's file base
        name, function name and line number, and ``args``.

    The location is read from the frame that called this function.
    ``inspect.currentframe()`` may return ``None`` on interpreters without
    stack-frame support; in that case placeholder location values are used
    instead of raising ``AttributeError``.
    """
    current = inspect.currentframe()
    caller = current.f_back if current is not None else None
    try:
        if caller is None:
            # No frame information available: fall back to placeholders.
            filename, funcname, lineno = "<unknown>", "<unknown>", 0
        else:
            filename = os.path.basename(caller.f_code.co_filename)
            funcname = caller.f_code.co_name
            lineno = caller.f_lineno
    finally:
        # Drop the frame references promptly to avoid reference cycles.
        del current, caller

    return Assert(cond, fmt, filename, funcname, lineno, *args)
12 changes: 12 additions & 0 deletions python/tutorials/tle/raw/05-topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
import triton
from triton.experimental.tle.raw import dialect, InOut, Input
from triton.experimental.tle.raw.mlir import vassert
import triton.experimental.tle.language.raw as tle_raw
import triton.language as tl

Expand Down Expand Up @@ -40,6 +41,17 @@ def edsl1(thre_bin_sum_buf: InOut["memref<?xi32, 3>"], l_new_topk_buf: InOut["me
tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32))
bidx = nvvm.read_ptx_sreg_ctaid_x(ir.IntegerType.get_signless(32))
bdimx = nvvm.read_ptx_sreg_ntid_x(ir.IntegerType.get_signless(32)) # blockDim.x

# --- Start: Runtime Assertion for BlockDim.x == 1024 ---
i32_ty = ir.IntegerType.get_signless(32)
c1024 = arith.constant(i32_ty, 1024)
is_valid_dim = arith.cmpi(arith.CmpIPredicate.eq, bdimx, c1024)
c0 = arith.constant(i32_ty, 0)
is_not_thread_0 = arith.cmpi(arith.CmpIPredicate.ne, tidx, c0)
should_pass = arith.ori(is_valid_dim, is_not_thread_0)
vassert(should_pass, "Runtime Error: BlockDim.x is incorrect, expected 1024.\n")
# --- End: Runtime Assertion ---

i32_ty = ir.IntegerType.get_signless(32)
i16_ty = ir.IntegerType.get_signless(16)
index_ty = ir.IndexType.get()
Expand Down
63 changes: 63 additions & 0 deletions python/tutorials/tle/raw/06-test-vassert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import triton
from triton.experimental.tle.raw import dialect
from triton.experimental.tle.raw.mlir import vprintf, vassert
import triton.experimental.tle.language.raw as tle_raw
import torch
import sys

from mlir.dialects import nvvm, arith
from mlir import ir

# Torch device reported by the active Triton driver. It is not referenced
# again in the visible part of this script.
DEVICE = triton.runtime.driver.active.get_active_torch_device()


# EDSL function (MLIR dialect) that asserts on a condition which is always
# false, followed by a print that should therefore never be observed.
@dialect(name="mlir")
def edsl_assert_test():
    # Thread index and block index along x, read as 32-bit integers.
    tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32))
    bidx = nvvm.read_ptx_sreg_ctaid_x(ir.IntegerType.get_signless(32))

    # Comparing the constants 0 and 1 for equality always yields false.
    c0 = arith.constant(ir.IntegerType.get_signless(32), 0)
    c1 = arith.constant(ir.IntegerType.get_signless(32), 1)
    cond_false = arith.cmpi(arith.CmpIPredicate.eq, c0, c1)

    # Expected to fire; bidx and tidx are supplied as the extra print arguments.
    vassert(cond_false, "TEST ASSERT: Block %d, Thread %d should fail!\n", bidx, tidx)

    # Sentinel: seeing this output means execution continued past the assert.
    vprintf("ERROR: This line should NOT be reached! bidx=%d\n", bidx)


# Triton kernel whose only statement invokes the EDSL assert test; both list
# arguments passed to tle_raw.call are empty.
@triton.jit
def assert_kernel():
    tle_raw.call(edsl_assert_test, [], [])


def run_test():
    """Launch ``assert_kernel`` on a one-element grid and check that it fails.

    Returns:
        ``True`` when a ``RuntimeError`` is raised whose text contains one of
        the expected failure markers; ``False`` when the kernel completes
        without error, or when any other exception (or a ``RuntimeError``
        with different text) is raised.
    """
    print(">>> Starting Assert Test (Expect Crash)...")

    try:
        assert_kernel[(1, )]()
        # Synchronize so that a failure raised by the launched kernel is
        # reported inside this try block.
        torch.cuda.synchronize()

    except RuntimeError as err:
        text = str(err)
        expected_markers = ("device-side assert triggered", "unspecified launch failure")
        if not any(marker in text for marker in expected_markers):
            print(f"\n❌ [FAIL] Caught unexpected RuntimeError: {text}")
            return False
        print("\n✅ [SUCCESS] Assert triggered successfully!")
        print(f" Captured Error: {text}")
        return True

    except Exception as err:
        print(f"\n❌ [FAIL] Caught unexpected exception: {type(err)}")
        print(err)
        return False

    # Reaching this point means nothing was raised at all.
    print("\n❌ [FAIL] Kernel finished without error (Assert did NOT trigger)")
    return False


if __name__ == "__main__":
    # Report a failed check through a non-zero process exit status.
    passed = run_test()
    if passed:
        pass
    else:
        sys.exit(1)