3 changes: 1 addition & 2 deletions python/triton/experimental/tle/raw/__init__.py
@@ -1,4 +1,3 @@
 from .runtime import dialect
-from .typing import Input, InOut
 
-__all__ = ["dialect", "Input", "InOut"]
+__all__ = ["dialect"]
53 changes: 30 additions & 23 deletions python/triton/experimental/tle/raw/mlir/codegen.py
@@ -1,5 +1,6 @@
 import ast
-from typing import Any, Dict, Final, List, Optional, Sequence
+from functools import cached_property
+from typing import Any, Dict, Final, List, Optional, Sequence, Tuple, Union
 from typing_extensions import override
 
 from mlir import ir
@@ -16,10 +17,19 @@ def __init__(self, name: str, *args, **kwargs) -> None:
 
 class EdslMLIRCodeGenerator(ast.NodeVisitor):
 
-    def __init__(self, absfilename: str, lscope: Dict[str, Any] = None, gscope: Dict[str, Any] = {},
-                 context: Optional[ir.Context] = None, *args, **kwargs) -> None:
+    def __init__(
+        self,
+        absfilename: str,
+        signature: Tuple[Sequence[Union[str, ir.Type]], Sequence[Union[str, ir.Type]]],
+        lscope: Dict[str, Any] = None,
+        gscope: Dict[str, Any] = {},
+        context: Optional[ir.Context] = None,
+        *args,
+        **kwargs,
+    ) -> None:
         super().__init__(*args, **kwargs)
         self.absfilename: Final[str] = absfilename
+        self.signature: Final[Tuple[Sequence[Union[str, ir.Type]], Sequence[Union[str, ir.Type]]]] = signature
         self.lscope: Final[Dict[str, Any]] = {**lscope}
         self.gscope: Final[Dict[str, Any]] = {**gscope}
         self.decls: Final[Dict[str, func.FuncOp]] = {}
@@ -87,32 +97,14 @@ def visit_For(self, node: ast.For) -> None:
     @override
     def visit_FunctionDef(self, node: ast.FunctionDef) -> func.FuncOp:
         with self.context, ir.Location.file(self.absfilename, node.lineno, node.col_offset):
-            operand_tys: List[ir.Type] = []
-            output_tys: List[ir.Type] = []
-            output_indices: List[int] = []
-            for idx, arg in enumerate(node.args.args):
-                if arg.annotation.value.id == "InOut":
-                    ty: ir.Type = ir.Type.parse(arg.annotation.slice.value)
-                    operand_tys += [ty]
-                    output_tys += [ty]
-                    output_indices += [idx]
-                elif arg.annotation.value.id == "Input":
-                    ty: ir.Type = ir.Type.parse(arg.annotation.slice.value)
-                    operand_tys += [ty]
-                elif arg.annotation.value.id == "Num":
-                    ty: ir.Type = ir.Type.parse(arg.annotation.slice.value)
-                    operand_tys += [ty]
-                else:
-                    raise NotImplementedError(f"unsupported argument annotation: {ast.dump(arg.annotation)}")
-            fnty: ir.FunctionType = ir.FunctionType.get(operand_tys, output_tys)
-            fn: func.FuncOp = func.FuncOp(node.name, fnty, visibility="public")
+            fn: func.FuncOp = func.FuncOp(node.name, self.funcdef, visibility="public")
             block: ir.Block = fn.add_entry_block()
             for k, arg in zip(map(lambda arg: arg.arg, node.args.args), block.arguments):
                 self.lscope[k] = arg
             with ir.InsertionPoint(block):
                 for stmt in node.body:
                     self.visit(stmt)
-                func.return_([block.arguments[idx] for idx in output_indices])
+                func.return_([arg for arg, _ in zip(block.arguments, self.funcdef.results)])
             return fn
 
     @override
@@ -156,3 +148,18 @@ def visit_With(self, node: ast.With) -> None:
             with self.visit(item.context_expr):
                 for stmt in node.body:
                     self.visit(stmt)
+
+    @cached_property
+    def funcdef(self) -> ir.FunctionType:
+        outputs, inputs = self.signature
+        outputs: List[ir.Type] = [self.canonicalize(t) for t in outputs]
+        inputs: List[ir.Type] = [self.canonicalize(t) for t in inputs]
+        operands: List[ir.Type] = [*outputs, *inputs]
+        return ir.FunctionType.get(operands, outputs)
+
+    @staticmethod
+    def canonicalize(type: Union[str, ir.Type]) -> ir.Type:
+        if isinstance(type, ir.Type):
+            return type
+        else:
+            return ir.Type.parse(type)
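Note on the new `funcdef` property: it folds what `visit_FunctionDef` used to derive from per-argument `Input`/`InOut` annotations into a single `ir.FunctionType`, with the declared outputs placed first in the operand list and repeated as results. A standalone sketch of the same mapping (illustrative names, assuming only the upstream MLIR Python bindings):

```python
from typing import List, Sequence, Tuple, Union

from mlir import ir


def functype_from_signature(
    signature: Tuple[Sequence[Union[str, ir.Type]], Sequence[Union[str, ir.Type]]],
) -> ir.FunctionType:
    outputs, inputs = signature
    # Accept either ready-made ir.Type objects or their textual MLIR form.
    canon: List[ir.Type] = [t if isinstance(t, ir.Type) else ir.Type.parse(t) for t in (*outputs, *inputs)]
    out_tys: List[ir.Type] = canon[:len(outputs)]
    # Operands are outputs followed by inputs; results mirror the outputs,
    # which is what an InOut annotation used to express per argument.
    return ir.FunctionType.get(canon, out_tys)


with ir.Context():
    print(functype_from_signature(([], ["!llvm.ptr<1>", "i32"])))  # (!llvm.ptr<1>, i32) -> ()
```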
22 changes: 18 additions & 4 deletions python/triton/experimental/tle/raw/mlir/runtime.py
@@ -3,7 +3,7 @@
 import copy
 from functools import cached_property
 import inspect
-from typing import Any, Dict, Final, List, Optional
+from typing import Any, Dict, Final, List, Optional, Sequence, Tuple, Union
 
 from mlir import ir
 from mlir.passmanager import PassManager
@@ -13,15 +13,29 @@
 
 class EdslMLIRJITFunction(object):
 
-    def __init__(self, fn: Any, pipeline: List[str], context: Optional[ir.Context] = None, *args, **kwargs) -> None:
+    def __init__(
+        self,
+        fn: Any,
+        signature: Tuple[Sequence[Union[str, ir.Type]], Sequence[Union[str, ir.Type]]],
+        pipeline: List[str],
+        context: Optional[ir.Context] = None,
+        *args,
+        **kwargs,
+    ) -> None:
         super().__init__(*args, **kwargs)
         self.fn: Final[Any] = fn
+        self.signature: Final[Tuple[Sequence[Union[str, ir.Type]], Sequence[Union[str, ir.Type]]]] = signature
         self.pipeline: Final[List[str]] = [*pipeline]
         self.context: Final[ir.Context] = ir.Context() if context is None else context
         self.__triton_builtin__: Final[bool] = True
 
     def __deepcopy__(self, memo: Dict[int, Any]) -> EdslMLIRJITFunction:
-        return self.__class__(copy.deepcopy(self.fn, memo), copy.deepcopy(self.pipeline, memo), self.context)
+        return self.__class__(
+            copy.deepcopy(self.fn, memo),
+            copy.deepcopy(self.signature, memo),
+            copy.deepcopy(self.pipeline, memo),
+            self.context,
+        )
 
     @cached_property
     def ast(self) -> ast.Module:
@@ -41,7 +55,7 @@ def globals(self) -> Dict[str, Any]:
 
     @cached_property
     def codegen(self) -> EdslMLIRCodeGenerator:
-        return EdslMLIRCodeGenerator(self.absfilename, {}, self.globals, self.context)
+        return EdslMLIRCodeGenerator(self.absfilename, self.signature, {}, self.globals, self.context)
 
     @property
     def ir(self) -> ir.Module:
15 changes: 9 additions & 6 deletions python/triton/experimental/tle/raw/runtime.py
@@ -1,16 +1,19 @@
 from .mlir import EdslMLIRJITFunction
-from typing import List
+from typing import List, Sequence, Tuple, Union
 
+from mlir import ir
+
 registry = {"mlir": EdslMLIRJITFunction}
 
 
-def dialect(*, name: str, pipeline: List[str] = [
-        "convert-scf-to-cf", "finalize-memref-to-llvm", "convert-arith-to-llvm", "convert-cf-to-llvm",
-        "convert-func-to-llvm", "convert-index-to-llvm", "convert-nvvm-to-llvm", "cse"
-]):
+def dialect(*, name: str, signature: Tuple[Sequence[Union[str, ir.Type]], Sequence[Union[str, ir.Type]]],
+            pipeline: List[str] = [
+                "convert-scf-to-cf", "finalize-memref-to-llvm", "convert-arith-to-llvm", "convert-cf-to-llvm",
+                "convert-func-to-llvm", "convert-index-to-llvm", "convert-nvvm-to-llvm", "cse"
+            ]):
 
     def decorator(fn):
-        edsl = registry[name](fn, pipeline=pipeline)
+        edsl = registry[name](fn, signature=signature, pipeline=pipeline)
        return edsl
 
     return decorator
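Usage-wise, a kernel now declares its MLIR types once in the decorator rather than per argument. A minimal sketch of the new call shape (the kernel name, types, and body below are hypothetical, not from the patch):

```python
from mlir import ir
from triton.experimental.tle.raw import dialect


# signature = (outputs, inputs): `y` is written and returned, `n` is read-only.
@dialect(name="mlir", signature=(["memref<?xf32, 3>"], ["i32"]))
def scale(y: ir.Value, n: ir.Value):  # block arguments arrive outputs-first, then inputs
    ...
```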
15 changes: 0 additions & 15 deletions python/triton/experimental/tle/raw/typing.py

This file was deleted.

3 changes: 2 additions & 1 deletion python/tutorials/hints/06/06-fused-attention.py
@@ -255,7 +255,8 @@ def _attn_bwd_preprocess(O, DO, #
     off_n = tl.arange(0, HEAD_DIM)
     # load
     o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])  # @hint: shared_memory
-    do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)  # @hint: shared_memory
+    do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(
+        tl.float32)  # @hint: shared_memory
     delta = tl.sum(o * do, axis=1)
     # write-back
     tl.store(Delta + off_hz * N_CTX + off_m, delta)
9 changes: 4 additions & 5 deletions python/tutorials/tle/raw/01-vector-add.py
@@ -3,15 +3,14 @@
 import torch
 import triton
 import triton.language as tl
-from triton.experimental.tle.raw import dialect, Input
+from triton.experimental.tle.raw import dialect
 import triton.experimental.tle.language.raw as tle_raw
 
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
-@dialect(name="mlir")
-def edsl(output: Input["!llvm.ptr<1>"], x: Input["!llvm.ptr<1>"], y: Input["!llvm.ptr<1>"],  # noqa: F722,
-         n_elements: Input["i32"]):  # noqa: F821
+@dialect(name="mlir", signature=([], ["!llvm.ptr<1>", "!llvm.ptr<1>", "!llvm.ptr<1>", "i32"]))
+def edsl(output: ir.Value, x: ir.Value, y: ir.Value, n_elements: ir.Value):
     tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32))
     bdimx = nvvm.read_ptx_sreg_ntid_x(ir.IntegerType.get_signless(32))
     bidx = nvvm.read_ptx_sreg_ctaid_x(ir.IntegerType.get_signless(32))
@@ -49,7 +48,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
     output = torch.empty_like(x)
     assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
     n_elements = output.numel()
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )
     add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
     return output
11 changes: 7 additions & 4 deletions python/tutorials/tle/raw/02-fused-softmax.py
@@ -3,7 +3,7 @@
 import torch
 import triton
 import triton.language as tl
-from triton.experimental.tle.raw import dialect, InOut, Input
+from triton.experimental.tle.raw import dialect
 import triton.experimental.tle.language.raw as tle_raw
 
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
@@ -18,8 +18,11 @@ def naive_softmax(x):
     return ret
 
 
-@dialect(name="mlir")
-def edsl(y: InOut["memref<?xf32, 3>"], x: Input["memref<?xf32, 3>"]):  # noqa: F722
+@dialect(
+    name="mlir",
+    signature=(["memref<?xf32, strided<[?], offset: ?>, 3>"], ["memref<?xf32, strided<[?], offset: ?>, 3>"]),
+)
+def edsl(y: ir.Value, x: ir.Value):
     tidx = nvvm.read_ptx_sreg_tid_x(ir.IntegerType.get_signless(32))
     bdimx = nvvm.read_ptx_sreg_ntid_x(ir.IntegerType.get_signless(32))
     tidx = arith.index_cast(ir.IndexType.get(), tidx)
@@ -112,7 +115,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
     col_offsets = tl.arange(0, BLOCK_SIZE)
     input_ptrs = row_start_ptr + col_offsets
     mask = col_offsets < n_cols
-    row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
+    row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
     softmax_output = tl.zeros_like(row)
     output_row_start_ptr = output_ptr + row_idx * output_row_stride
     output_ptrs = output_row_start_ptr + col_offsets
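The softmax memref types above now carry an explicit `strided<[?], offset: ?>` layout, and the `y` memref is both an operand and the result (what `InOut` used to express). A quick sanity check of the type string and the resulting function type (a sketch against the MLIR Python bindings, not part of the patch):

```python
from mlir import ir

with ir.Context():
    ty = ir.Type.parse("memref<?xf32, strided<[?], offset: ?>, 3>")
    # With signature = ([ty], [ty]), funcdef builds (ty, ty) -> (ty):
    print(ir.FunctionType.get([ty, ty], [ty]))
```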