From 89293c755c02e67ba3f7fd58eeae02785f5d2afe Mon Sep 17 00:00:00 2001
From: Jinjie Liu
Date: Mon, 2 Feb 2026 11:09:37 +0800
Subject: [PATCH] move signature from annotations to decl
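
Kernel operand types used to be scraped from per-argument annotations
(Input[...], InOut[...], Num[...]); they are now declared once on the
decorator as signature=(outputs, inputs), where every entry is an MLIR
type string or an ir.Type. The code generator prepends the outputs to
the operand list of the generated func.func and returns them from it,
so the Python parameters themselves are annotated plainly as ir.Value.
A minimal before/after sketch (the memref<2xi32> type below is an
illustrative placeholder, not taken from this patch):

    # before: operand types lived in the argument annotations
    @dialect(name="mlir")
    def edsl(y: InOut["memref<2xi32>"], x: Input["memref<2xi32>"]): ...

    # after: the ([outputs], [inputs]) pair lives on the decorator
    @dialect(name="mlir", signature=(["memref<2xi32>"], ["memref<2xi32>"]))
    def edsl(y: ir.Value, x: ir.Value): ...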

Signed-off-by: Jinjie Liu
---
 .../triton/experimental/tle/raw/__init__.py   |   3 +-
 .../experimental/tle/raw/mlir/codegen.py      |  53 +++++----
 .../experimental/tle/raw/mlir/runtime.py      |  22 +++-
 python/triton/experimental/tle/raw/runtime.py |  15 ++-
 python/triton/experimental/tle/raw/typing.py  |  15 ---
 .../tutorials/hints/06/06-fused-attention.py  |   3 +-
 python/tutorials/tle/raw/01-vector-add.py     |   9 +-
 python/tutorials/tle/raw/02-fused-softmax.py  |  11 +-
 .../tle/raw/03-matrix-multiplication.py       | 112 ++++++++++++------
 python/tutorials/tle/raw/04-hello-world.py    |   2 +-
 python/tutorials/tle/raw/05-topk.py           |  82 +++++++++----
 11 files changed, 212 insertions(+), 115 deletions(-)
 delete mode 100644 python/triton/experimental/tle/raw/typing.py

diff --git a/python/triton/experimental/tle/raw/__init__.py b/python/triton/experimental/tle/raw/__init__.py
index d5b5ce650..bd8f8e8da 100644
--- a/python/triton/experimental/tle/raw/__init__.py
+++ b/python/triton/experimental/tle/raw/__init__.py
@@ -1,4 +1,3 @@
 from .runtime import dialect
-from .typing import Input, InOut
 
-__all__ = ["dialect", "Input", "InOut"]
+__all__ = ["dialect"]
diff --git a/python/triton/experimental/tle/raw/mlir/codegen.py b/python/triton/experimental/tle/raw/mlir/codegen.py
index 436a9aeaf..ca8d471e9 100644
--- a/python/triton/experimental/tle/raw/mlir/codegen.py
+++ b/python/triton/experimental/tle/raw/mlir/codegen.py
@@ -1,5 +1,6 @@
 import ast
-from typing import Any, Dict, Final, List, Optional, Sequence
+from functools import cached_property
+from typing import Any, Dict, Final, List, Optional, Sequence, Tuple, Union
 from typing_extensions import override
 
 from mlir import ir
@@ -16,10 +17,19 @@ def __init__(self, name: str, *args, **kwargs) -> None:
 
 class EdslMLIRCodeGenerator(ast.NodeVisitor):
 
-    def __init__(self, absfilename: str, lscope: Dict[str, Any] = None, gscope: Dict[str, Any] = {},
-                 context: Optional[ir.Context] = None, *args, **kwargs) -> None:
+    def __init__(
+        self,
+        absfilename: str,
+        signature: Tuple[Sequence[Union[str, ir.Type]], Sequence[Union[str, ir.Type]]],
+        lscope: Optional[Dict[str, Any]] = None,
+        gscope: Dict[str, Any] = {},
+        context: Optional[ir.Context] = None,
+        *args,
+        **kwargs,
+    ) -> None:
         super().__init__(*args, **kwargs)
         self.absfilename: Final[str] = absfilename
+        self.signature: Final[Tuple[Sequence[Union[str, ir.Type]], Sequence[Union[str, ir.Type]]]] = signature
         self.lscope: Final[Dict[str, Any]] = {**lscope}
         self.gscope: Final[Dict[str, Any]] = {**gscope}
         self.decls: Final[Dict[str, func.FuncOp]] = {}
@@ -87,32 +97,14 @@ def visit_For(self, node: ast.For) -> None:
     @override
     def visit_FunctionDef(self, node: ast.FunctionDef) -> func.FuncOp:
         with self.context, ir.Location.file(self.absfilename, node.lineno, node.col_offset):
-            operand_tys: List[ir.Type] = []
-            output_tys: List[ir.Type] = []
-            output_indices: List[int] = []
-            for idx, arg in enumerate(node.args.args):
-                if arg.annotation.value.id == "InOut":
-                    ty: ir.Type = ir.Type.parse(arg.annotation.slice.value)
-                    operand_tys += [ty]
-                    output_tys += [ty]
-                    output_indices += [idx]
-                elif arg.annotation.value.id == "Input":
-                    ty: ir.Type = ir.Type.parse(arg.annotation.slice.value)
-                    operand_tys += [ty]
-                elif arg.annotation.value.id == "Num":
-                    ty: ir.Type = ir.Type.parse(arg.annotation.slice.value)
-                    operand_tys += [ty]
-                else:
-                    raise NotImplementedError(f"unsupported argument annotation: {ast.dump(arg.annotation)}")
-            fnty: ir.FunctionType = ir.FunctionType.get(operand_tys, output_tys)
-            fn: func.FuncOp = func.FuncOp(node.name, fnty, visibility="public")
+            fn: func.FuncOp = func.FuncOp(node.name, self.funcdef, visibility="public")
             block: ir.Block = fn.add_entry_block()
             for k, arg in zip(map(lambda arg: arg.arg, node.args.args), block.arguments):
                 self.lscope[k] = arg
             with ir.InsertionPoint(block):
                 for stmt in node.body:
                     self.visit(stmt)
-            func.return_([block.arguments[idx] for idx in output_indices])
+            func.return_([arg for arg, _ in zip(block.arguments, self.funcdef.results)])
         return fn
 
     @override
@@ -156,3 +148,18 @@ def visit_With(self, node: ast.With) -> None:
             with self.visit(item.context_expr):
                 for stmt in node.body:
                     self.visit(stmt)
+
+    @cached_property
+    def funcdef(self) -> ir.FunctionType:
+        outputs, inputs = self.signature
+        outputs: List[ir.Type] = [self.canonicalize(t) for t in outputs]
+        inputs: List[ir.Type] = [self.canonicalize(t) for t in inputs]
+        operands: List[ir.Type] = [*outputs, *inputs]
+        return ir.FunctionType.get(operands, outputs)
+
+    @staticmethod
+    def canonicalize(type: Union[str, ir.Type]) -> ir.Type:
+        if isinstance(type, ir.Type):
+            return type
+        else:
+            return ir.Type.parse(type)
diff --git a/python/triton/experimental/tle/raw/mlir/runtime.py b/python/triton/experimental/tle/raw/mlir/runtime.py
index 190947045..1113471d7 100644
--- a/python/triton/experimental/tle/raw/mlir/runtime.py
+++ b/python/triton/experimental/tle/raw/mlir/runtime.py
@@ -3,7 +3,7 @@
 import copy
 from functools import cached_property
 import inspect
-from typing import Any, Dict, Final, List, Optional
+from typing import Any, Dict, Final, List, Optional, Sequence, Tuple, Union
 
 from mlir import ir
 from mlir.passmanager import PassManager
@@ -13,15 +13,29 @@
 
 class EdslMLIRJITFunction(object):
 
-    def __init__(self, fn: Any, pipeline: List[str], context: Optional[ir.Context] = None, *args, **kwargs) -> None:
+    def __init__(
+        self,
+        fn: Any,
+        signature: Tuple[Sequence[Union[str, ir.Type]], Sequence[Union[str, ir.Type]]],
+        pipeline: List[str],
+        context: Optional[ir.Context] = None,
+        *args,
+        **kwargs,
+    ) -> None:
         super().__init__(*args, **kwargs)
         self.fn: Final[Any] = fn
+        self.signature: Final[Tuple[Sequence[Union[str, ir.Type]], Sequence[Union[str, ir.Type]]]] = signature
         self.pipeline: Final[List[str]] = [*pipeline]
         self.context: Final[ir.Context] = ir.Context() if context is None else context
         self.__triton_builtin__: Final[bool] = True
 
     def __deepcopy__(self, memo: Dict[int, Any]) -> EdslMLIRJITFunction:
-        return self.__class__(copy.deepcopy(self.fn, memo), copy.deepcopy(self.pipeline, memo), self.context)
+        return self.__class__(
+            copy.deepcopy(self.fn, memo),
+            copy.deepcopy(self.signature, memo),
+            copy.deepcopy(self.pipeline, memo),
+            self.context,
+        )
 
     @cached_property
     def ast(self) -> ast.Module:
@@ -41,7 +55,7 @@ def globals(self) -> Dict[str, Any]:
 
     @cached_property
     def codegen(self) -> EdslMLIRCodeGenerator:
-        return EdslMLIRCodeGenerator(self.absfilename, {}, self.globals, self.context)
+        return EdslMLIRCodeGenerator(self.absfilename, self.signature, {}, self.globals, self.context)
 
     @property
     def ir(self) -> ir.Module:
diff --git a/python/triton/experimental/tle/raw/runtime.py b/python/triton/experimental/tle/raw/runtime.py
index cc95f0159..6b0f2c40b 100644
--- a/python/triton/experimental/tle/raw/runtime.py
+++ b/python/triton/experimental/tle/raw/runtime.py
@@ -1,16 +1,19 @@
 from .mlir import EdslMLIRJITFunction
-from typing import List
+from typing import List, Sequence, Tuple, Union
+
+from mlir import ir
 
 registry = {"mlir": EdslMLIRJITFunction}
 
 
-def dialect(*, name: str, pipeline: List[str] = [
-    "convert-scf-to-cf", "finalize-memref-to-llvm", "convert-arith-to-llvm", "convert-cf-to-llvm",
-    "convert-func-to-llvm", "convert-index-to-llvm", "convert-nvvm-to-llvm", "cse"
-]):
+def dialect(*, name: str, signature: Tuple[Sequence[Union[str, ir.Type]], Sequence[Union[str, ir.Type]]],
+            pipeline: List[str] = [
+                "convert-scf-to-cf", "finalize-memref-to-llvm", "convert-arith-to-llvm", "convert-cf-to-llvm",
+                "convert-func-to-llvm", "convert-index-to-llvm", "convert-nvvm-to-llvm", "cse"
+            ]):
 
     def decorator(fn):
-        edsl = registry[name](fn, pipeline=pipeline)
+        edsl = registry[name](fn, signature=signature, pipeline=pipeline)
         return edsl
 
     return decorator
diff --git a/python/triton/experimental/tle/raw/typing.py b/python/triton/experimental/tle/raw/typing.py
deleted file mode 100644
index b03fd65f2..000000000
--- a/python/triton/experimental/tle/raw/typing.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from typing import Annotated
-
-from mlir import ir
-
-
-class InOut:
-
-    def __class_getitem__(cls, desc: str) -> Annotated[ir.MemRefType, str]:
-        return Annotated[ir.MemRefType, desc]
-
-
-class Input:
-
-    def __class_getitem__(cls, desc: str) -> Annotated[ir.MemRefType, str]:
-        return Annotated[ir.MemRefType, desc]
diff --git a/python/tutorials/hints/06/06-fused-attention.py b/python/tutorials/hints/06/06-fused-attention.py
index 6f54f627e..c1030f7d1 100644
--- a/python/tutorials/hints/06/06-fused-attention.py
+++ b/python/tutorials/hints/06/06-fused-attention.py
@@ -255,7 +255,8 @@ def _attn_bwd_preprocess(O, DO,  #
     off_n = tl.arange(0, HEAD_DIM)
     # load
     o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])  # @hint: shared_memory
-    do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)  # @hint: shared_memory
+    do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(
+        tl.float32)  # @hint: shared_memory
     delta = tl.sum(o * do, axis=1)
     # write-back
     tl.store(Delta + off_hz * N_CTX + off_m, delta)
diff --git a/python/tutorials/tle/raw/01-vector-add.py b/python/tutorials/tle/raw/01-vector-add.py
index f2e2f5ecf..19b6bdbf3 100644
--- a/python/tutorials/tle/raw/01-vector-add.py
+++ b/python/tutorials/tle/raw/01-vector-add.py
@@ -3,15 +3,14 @@
 import torch
 import triton
 import triton.language as tl
-from triton.experimental.tle.raw import dialect, Input
+from triton.experimental.tle.raw import dialect
 import triton.experimental.tle.language.raw as tle_raw
 
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
-@dialect(name="mlir")
-def edsl(output: Input["!llvm.ptr<1>"], x: Input["!llvm.ptr<1>"], y: Input["!llvm.ptr<1>"],  # noqa: F722,
-         n_elements: Input["i32"]):  # noqa: F821
+@dialect(name="mlir", signature=([], ["!llvm.ptr<1>", "!llvm.ptr<1>", "!llvm.ptr<1>", "i32"]))
+def edsl(output: ir.Value, x: ir.Value, y: ir.Value, n_elements: ir.Value):
     tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32))
     bdimx = nvvm.read_ptx_sreg_ntid_x(ir.IntegerType.get_signless(32))
     bidx = nvvm.read_ptx_sreg_ctaid_x(ir.IntegerType.get_signless(32))
@@ -49,7 +48,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
     output = torch.empty_like(x)
     assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
     n_elements = output.numel()
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )
     add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
     return output
diff --git a/python/tutorials/tle/raw/02-fused-softmax.py b/python/tutorials/tle/raw/02-fused-softmax.py
index 787e5637b..6947f6ec7 100644
--- a/python/tutorials/tle/raw/02-fused-softmax.py
+++ b/python/tutorials/tle/raw/02-fused-softmax.py
@@ -3,7 +3,7 @@
 import torch
 import triton
 import triton.language as tl
-from triton.experimental.tle.raw import dialect, InOut, Input
+from triton.experimental.tle.raw import dialect
 import triton.experimental.tle.language.raw as tle_raw
 
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
@@ -18,8 +18,11 @@ def naive_softmax(x):
     return ret
 
 
-@dialect(name="mlir")
-def edsl(y: InOut["memref"], x: Input["memref"]):  # noqa: F722
+@dialect(
+    name="mlir",
+    signature=(["memref, 3>"], ["memref, 3>"]),
+)
+def edsl(y: ir.Value, x: ir.Value):
     tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32))
     bdimx = nvvm.read_ptx_sreg_ntid_x(ir.IntegerType.get_signless(32))
     tidx = arith.index_cast(ir.IndexType.get(), tidx)
@@ -112,7 +115,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
     col_offsets = tl.arange(0, BLOCK_SIZE)
     input_ptrs = row_start_ptr + col_offsets
     mask = col_offsets < n_cols
-    row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
+    row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
     softmax_output = tl.zeros_like(row)
     output_row_start_ptr = output_ptr + row_idx * output_row_stride
     output_ptrs = output_row_start_ptr + col_offsets
diff --git a/python/tutorials/tle/raw/03-matrix-multiplication.py b/python/tutorials/tle/raw/03-matrix-multiplication.py
index cb580f989..637ffc586 100644
--- a/python/tutorials/tle/raw/03-matrix-multiplication.py
+++ b/python/tutorials/tle/raw/03-matrix-multiplication.py
@@ -3,7 +3,7 @@
 import torch
 import triton
 import triton.language as tl
-from triton.experimental.tle.raw import dialect, InOut, Input
+from triton.experimental.tle.raw import dialect
 import triton.experimental.tle.language.raw as tle_raw
 
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
@@ -11,45 +11,58 @@
 
 def get_autotune_config():
     return [
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
+        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3,
                       num_warps=8),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
+        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4,
                       num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
+        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4,
                       num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
+        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4,
                       num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
+        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4,
                       num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
+        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4,
                       num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
+        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5,
                       num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
+        triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5,
                       num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
-                      num_warps=8),
-        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
-                      num_warps=8),
-        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
+        triton.Config(
+            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+            num_stages=3,
+            num_warps=8,
+        ),
+        triton.Config(
+            {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+            num_stages=3,
+            num_warps=8,
+        ),
+        triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4,
                       num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
+        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4,
                       num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
+        triton.Config(
+            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4,
                       num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
+        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4,
                       num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
+        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4,
                       num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
-                      num_warps=4)
     ]
 
 
-@dialect(name="mlir")
-def edsl(c: InOut["memref, 3>"],  # noqa: F722
-         a: Input["memref, 3>"],  # noqa: F722
-         b: Input["memref, 3>"]):  # noqa: F722
+@dialect(
+    name="mlir",
+    signature=(
+        ["memref, 3>"],
+        ["memref, 3>", "memref, 3>"],
+    ),
+)
+def edsl(c: ir.Value, a: ir.Value, b: ir.Value):
     tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32))
     bdimx = nvvm.read_ptx_sreg_ntid_x(ir.IntegerType.get_signless(32))
     tidx = arith.index_cast(ir.IndexType.get(), tidx)
@@ -61,8 +74,12 @@ def edsl(c: InOut["memref, 3>"],  # noqa: F7
     for i in scf.for_(tidx, numel, bdimx):
         row = arith.divsi(i, n)
         col = arith.remsi(i, n)
-        for j, arg, result in scf.for_(arith.constant(ir.IndexType.get(), 0), k, arith.constant(ir.IndexType.get(), 1),
-                                       [arith.constant(ir.F32Type.get(), 0.0)]):
+        for j, arg, result in scf.for_(
+            arith.constant(ir.IndexType.get(), 0),
+            k,
+            arith.constant(ir.IndexType.get(), 1),
+            [arith.constant(ir.F32Type.get(), 0.0)],
+        ):
             a_val = memref.load(a, [row, j])
             b_val = memref.load(b, [j, col])
             c_val = arith.addf(arg, arith.extf(ir.F32Type.get(), arith.mulf(a_val, b_val)))
@@ -74,12 +91,28 @@ def edsl(c: InOut["memref, 3>"],  # noqa: F7
 
 @triton.autotune(
     configs=get_autotune_config(),
-    key=['M', 'N', 'K'],
+    key=["M", "N", "K"],
 )
 @triton.jit
-def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
-                  BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
-                  GROUP_SIZE_M: tl.constexpr, ACTIVATION: tl.constexpr):
+def matmul_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    M,
+    N,
+    K,
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,  #
+    GROUP_SIZE_M: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+):
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
@@ -131,9 +164,22 @@ def matmul(a, b, activation=""):
     K, N = b.shape
     # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
-    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
-    matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
-                        ACTIVATION=activation)
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
+    matmul_kernel[grid](
+        a,
+        b,
+        c,
+        M,
+        N,
+        K,
+        a.stride(0),
+        a.stride(1),
+        b.stride(0),
+        b.stride(1),
+        c.stride(0),
+        c.stride(1),
+        ACTIVATION=activation,
+    )
     return c
diff --git a/python/tutorials/tle/raw/04-hello-world.py b/python/tutorials/tle/raw/04-hello-world.py
index 58295c0ec..68fcc2759 100644
--- a/python/tutorials/tle/raw/04-hello-world.py
+++ b/python/tutorials/tle/raw/04-hello-world.py
@@ -10,7 +10,7 @@
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
-@dialect(name="mlir")
+@dialect(name="mlir", signature=([], []))
 def edsl():
     tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32))
     bidx = nvvm.read_ptx_sreg_ctaid_x(ir.IntegerType.get_signless(32))
diff --git a/python/tutorials/tle/raw/05-topk.py b/python/tutorials/tle/raw/05-topk.py
index c71ef4d14..405a8f9cc 100644
--- a/python/tutorials/tle/raw/05-topk.py
+++ b/python/tutorials/tle/raw/05-topk.py
@@ -6,7 +6,7 @@
 from mlir import ir
 import torch
 import triton
-from triton.experimental.tle.raw import dialect, InOut, Input
+from triton.experimental.tle.raw import dialect
 import triton.experimental.tle.language.raw as tle_raw
 import triton.language as tl
 
@@ -31,12 +31,38 @@ def convert_to_uint32(x):
 
 
 # NOTE: current implementation requires a thread number of 1024
-@dialect(name="mlir")
-def edsl1(thre_bin_sum_buf: InOut["memref"], l_new_topk_buf: InOut["memref"],
-          s_threshold_bin_id: Input["memref"], indices_base: Input["!llvm.ptr<1>"],
-          s_input_ids_base: Input["!llvm.ptr<1>"], inputs: Input["!llvm.ptr<1>"],
-          s_histogram: Input["memref"], l_start_idx: Input["i32"], l_end_idx: Input["i32"], S: Input["i32"],
-          BS: Input["i32"], K_tensor: Input["memref"]):
+@dialect(
+    name="mlir",
+    signature=(
+        ["memref, 3>", "memref, 3>"],
+        [
+            "memref, 3>",
+            "!llvm.ptr<1>",
+            "!llvm.ptr<1>",
+            "!llvm.ptr<1>",
+            "memref, 3>",
+            "i32",
+            "i32",
+            "i32",
+            "i32",
+            "memref, 3>",
+        ],
+    ),
+)
+def edsl(
+    thre_bin_sum_buf: ir.Value,
+    l_new_topk_buf: ir.Value,
+    s_threshold_bin_id: ir.Value,
+    indices_base: ir.Value,
+    s_input_ids_base: ir.Value,
+    inputs: ir.Value,
+    s_histogram: ir.Value,
+    l_start_idx: ir.Value,
+    l_end_idx: ir.Value,
+    S: ir.Value,
+    BS: ir.Value,
+    K_tensor: ir.Value,
+):
     tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32))
     bidx = nvvm.read_ptx_sreg_ctaid_x(ir.IntegerType.get_signless(32))
     bdimx = nvvm.read_ptx_sreg_ntid_x(ir.IntegerType.get_signless(32))  # blockDim.x
@@ -67,13 +93,15 @@ def edsl1(thre_bin_sum_buf: InOut["memref"], l_new_topk_buf: InOut["me
         s_i32 = arith.index_cast(i32_ty, s)
         input_idx_i32 = arith.addi(arith.muli(s_i32, BS), tidx)
         cond = arith.andi(
-            arith.andi(arith.cmpi(arith.CmpIPredicate.slt, input_idx_i32, l_end_idx),
-                       arith.cmpi(arith.CmpIPredicate.sge, input_idx_i32, l_start_idx)),
-            arith.cmpi(arith.CmpIPredicate.slt, input_idx_i32, S))
+            arith.andi(
+                arith.cmpi(arith.CmpIPredicate.slt, input_idx_i32, l_end_idx),
+                arith.cmpi(arith.CmpIPredicate.sge, input_idx_i32, l_start_idx),
+            ),
+            arith.cmpi(arith.CmpIPredicate.slt, input_idx_i32, S),
+        )
         if_stmt = scf.if_([], cond)
         thenblock = if_stmt.opview.thenRegion.blocks.append()
         with ir.InsertionPoint(thenblock):
-
             base_offset = arith.muli(bidx, S)
             full_offset = arith.addi(base_offset, input_idx_i32)
@@ -297,14 +325,14 @@ def kernel_bucket_sort_topk_triton(  # grid(B, BS)
     TS = tl.cdiv(S, BS)
     for s in range(TS):
         input_idx = s * BS + tl.arange(0, BS)
-        input_mask = ((input_idx < l_end_idx) & (input_idx >= l_start_idx) & (input_idx < S))
+        input_mask = (input_idx < l_end_idx) & (input_idx >= l_start_idx) & (input_idx < S)
         input = tl.load(s_base + input_idx, input_mask, other=float("-inf")).to(tl.float32)
         inval_int16 = convert_to_uint16(input)
         s_histogram += inval_int16.to(tl.int32).histogram(HISTOGRAM_SIZE)
 
     s_histogram = s_histogram.cumsum(0, reverse=True)  # Suffix sum
-    mv_idx = (tl.arange(1, HISTOGRAM_SIZE + 1) % HISTOGRAM_SIZE)  # Construct offset index matrix
+    mv_idx = tl.arange(1, HISTOGRAM_SIZE + 1) % HISTOGRAM_SIZE  # Construct offset index matrix
     cond = (s_histogram > l_new_topk) & ((s_histogram.gather(mv_idx, 0) <= l_new_topk) | (mv_idx == 0))
     l_threshold_bin_id = cond.argmax(0)
@@ -314,7 +342,7 @@ def kernel_bucket_sort_topk_triton(  # grid(B, BS)
     thre_bin_sum = 0
     for s in range(TS):
         input_idx = s * BS + tl.arange(0, BS)
-        input_mask = ((input_idx < l_end_idx) & (input_idx >= l_start_idx) & (input_idx < S))
+        input_mask = (input_idx < l_end_idx) & (input_idx >= l_start_idx) & (input_idx < S)
         input = tl.load(s_base + input_idx, input_mask, other=float("-inf")).to(tl.float32)
         inval_int16 = convert_to_uint16(input)
         # This method would slow down the speed, so using other=float("-inf") saves time.
@@ -376,7 +404,7 @@ def kernel_bucket_sort_topk_triton(  # grid(B, BS)
                                   (24 - round * 8)) & 0xFF  # Ensure all bits except the last eight are zero
         s_histogram += inval_int32.to(tl.int32).histogram(HISTOGRAM_SIZE)
     s_histogram = s_histogram.cumsum(0, reverse=True)  # Suffix sum
-    mv_idx = (tl.arange(1, HISTOGRAM_SIZE + 1) % HISTOGRAM_SIZE)  # Construct offset index matrix
+    mv_idx = tl.arange(1, HISTOGRAM_SIZE + 1) % HISTOGRAM_SIZE  # Construct offset index matrix
     cond = (s_histogram > l_new_topk) & ((s_histogram.gather(mv_idx, 0) <= l_new_topk) | (mv_idx == 0))
     l_threshold_bin_id = cond.argmax(0)
     l_new_topk -= tl.where(tl.arange(0, HISTOGRAM_SIZE) == l_threshold_bin_id + 1, s_histogram, 0).max(0)
@@ -455,16 +483,29 @@ def kernel_bucket_sort_topk_edsl(  # grid(B,)
     # Kernel1: Compute histogram
     s_histogram = tl.zeros([HISTOGRAM_SIZE], dtype=tl.int32)
 
-    # Kernel2: Call edsl1 for topk selection (threshold calculated in edsl1)
+    # Kernel2: Call edsl for topk selection (threshold calculated in edsl)
     thre_bin_sum_buf = tl.zeros([1], dtype=tl.int32)
     l_new_topk_buf = tl.zeros([1], dtype=tl.int32)
     s_threshold_bin_id = tl.zeros([1], dtype=tl.int32)
     s = S
     bs = BS
     k_tensor = tl.full([1], K, dtype=tl.int32)  # Convert constexpr to tensor
-    thre_bin_sum_buf, l_new_topk_buf = tle_raw.call(edsl1, [thre_bin_sum_buf, l_new_topk_buf], [
-        s_threshold_bin_id, indices_base, s_input_ids_base, inputs, s_histogram, l_start_idx, l_end_idx, s, bs, k_tensor
-    ])
+    thre_bin_sum_buf, l_new_topk_buf = tle_raw.call(
+        edsl,
+        [thre_bin_sum_buf, l_new_topk_buf],
+        [
+            s_threshold_bin_id,
+            indices_base,
+            s_input_ids_base,
+            inputs,
+            s_histogram,
+            l_start_idx,
+            l_end_idx,
+            s,
+            bs,
+            k_tensor,
+        ],
+    )
     thre_bin_sum = thre_bin_sum_buf.max(0)
     l_new_topk = l_new_topk_buf.max(0)
@@ -485,7 +526,7 @@ def kernel_bucket_sort_topk_edsl(  # grid(B,)
         inval_int32 = (convert_to_uint32(s_input) >> (24 - round * 8)) & 0xFF
         s_histogram += inval_int32.to(tl.int32).histogram(HISTOGRAM_SIZE)
     s_histogram = s_histogram.cumsum(0, reverse=True)  # Suffix sum
-    mv_idx = (tl.arange(1, HISTOGRAM_SIZE + 1) % HISTOGRAM_SIZE)  # Construct offset index matrix
+    mv_idx = tl.arange(1, HISTOGRAM_SIZE + 1) % HISTOGRAM_SIZE  # Construct offset index matrix
     cond = (s_histogram > l_new_topk) & ((s_histogram.gather(mv_idx, 0) <= l_new_topk) | (mv_idx == 0))
     l_threshold_bin_id = cond.argmax(0)
     l_new_topk -= tl.where(tl.arange(0, HISTOGRAM_SIZE) == l_threshold_bin_id + 1, s_histogram, 0).max(0)
@@ -573,7 +614,6 @@ def bucket_sort_topk(inputs, starts, ends, topk, kernel: Literal["triton", "tle"
 
 
 def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048, kernel: Literal["triton", "tle"] = "triton"):
-    batch = 64
     seq_len = 32 * 1024
     topk = 2048