From b5f36ad4d4da2e8fb1853ffc8992ca57889267d8 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Fri, 24 Oct 2025 18:28:08 +0800 Subject: [PATCH 01/24] tilelang frontend v2 --- examples/gdn/example_chunk_o_bwd.py | 4 +- tilelang/language/__init__.py | 6 +- tilelang/language/v2/__init__.py | 1 + tilelang/language/v2/ast.py | 489 ++++++++++++++++++++++++++++ tilelang/language/v2/builder.py | 383 ++++++++++++++++++++++ tilelang/language/v2/dtypes.py | 198 +++++++++++ tilelang/language/v2/utils.py | 106 ++++++ 7 files changed, 1182 insertions(+), 5 deletions(-) create mode 100644 tilelang/language/v2/__init__.py create mode 100644 tilelang/language/v2/ast.py create mode 100644 tilelang/language/v2/builder.py create mode 100644 tilelang/language/v2/dtypes.py create mode 100644 tilelang/language/v2/utils.py diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index 76b4792df..3f69b6b68 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -256,8 +256,8 @@ def kernel( # for i_kv in T.Parallel(block_DK * block_DV): # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] for i_kv in T.Parallel(block_DK * block_DV): - i_k, i_v = i_kv // block_DV, i_kv % block_DV - dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v] + i_k, i_v_1 = i_kv // block_DV, i_kv % block_DV + dg_last_fragment[i_kv] = h_shared[i_k, i_v_1] * dh_shared[i_k, i_v_1] T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) dg_last_local[0] += dg_last_fragment_scalar[0] diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 1a26b53d0..6cac4a90d 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -8,9 +8,9 @@ # upstream tir script is fully compatible from tvm.script.parser.tir import * from . import overrides as _overrides # noqa: F401 -from .tir import ( - prim_func, # noqa: F401 -) + +# from .tir import prim_func, macro, # noqa: F401 +from .v2 import prim_func, macro # noqa: F401 from .tir.ir import * # noqa: F401 from tilelang.layout import Layout, Fragment # noqa: F401 from .proxy import ( diff --git a/tilelang/language/v2/__init__.py b/tilelang/language/v2/__init__.py new file mode 100644 index 000000000..23b907b37 --- /dev/null +++ b/tilelang/language/v2/__init__.py @@ -0,0 +1 @@ +from .builder import prim_func, macro # noqa: F401 diff --git a/tilelang/language/v2/ast.py b/tilelang/language/v2/ast.py new file mode 100644 index 000000000..1ed9c3d0d --- /dev/null +++ b/tilelang/language/v2/ast.py @@ -0,0 +1,489 @@ +from __future__ import annotations +import ast +from typing import Callable, ContextManager, Iterable, Any, Literal, ParamSpec, TypeVar +import inspect +# from .utils import get_ast, get_compiled_object +from . import utils + +_span_attrs = ['lineno', 'col_offset', 'end_lineno', 'end_col_offset'] + + +def ast_has_span(ast: ast.AST) -> bool: + return all(hasattr(ast, attr) for attr in _span_attrs) + + +def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int]: + if not ast_has_span(ast): + return None + return tuple(getattr(ast, attr) for attr in _span_attrs) + + +def ast_set_span(ast: ast.AST, span: tuple[int, int, int, int]): + if not ast_has_span(ast): + return + for attr, value in zip(_span_attrs, span): + setattr(ast, attr, value) + + +class QuoteVisitor(ast.NodeTransformer): + + def __init__(self, names: dict[str, ast.AST], passes: list[Any] | None = None, span=None): + self.names = names + self.passes = passes or [] + self.span = span + + def generic_visit(self, node: ast.AST): + if self.span is not None: + ast_set_span(node, self.span) + return super().generic_visit(node) + + def visit_Name(self, node: ast.Name) -> Any: + if node.id in self.names: + return self.names[node.id] + else: + return node + + def visit_Pass(self, node: ast.Pass) -> Any: + item = self.passes.pop(0) + return item if item else node + + +def quote(expr: str, *, passes: list[Any] | None = None, span=None, **kws) -> list[ast.AST]: + tree = ast.parse(expr) + if isinstance(span, ast.AST): + span = ast_get_span(span) + tree = QuoteVisitor(kws, passes, span).visit(tree) + return tree.body + + +def quote1(expr: str, *, passes: list[Any] | None = None, span=None, **kws) -> ast.AST: + res = quote(expr, passes=passes, span=span, **kws) + assert len(res) == 1 + return res[0] + + +def quote_expr(expr: str, **kws) -> ast.expr: + res = quote1(expr, **kws) + assert isinstance(res, ast.Expr) + return res.value + + +Operator = Literal['Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow', 'LShift', 'RShift', + 'BitOr', 'BitXor', 'BitAnd', 'FloorDiv'] +BoolOp = Literal['And', 'Or'] + + +def get_operator_name(operator: ast.operator) -> Operator: + return operator.__class__.__name__ + + +def get_boolop_name(boolop: ast.boolop) -> BoolOp: + return boolop.__class__.__name__ + + +_T = TypeVar('_T') + + +def eval_op(op: Operator, left: Any, right: Any) -> Any: + if op == 'Add': + return left + right + if op == 'Sub': + return left - right + if op == 'Mult': + return left * right + if op == 'MatMult': + return left @ right + if op == 'Div': + return left / right + if op == 'Mod': + return left % right + if op == 'Pow': + return left**right + if op == 'LShift': + return left << right + if op == 'RShift': + return left >> right + if op == 'BitOr': + return left | right + if op == 'BitXor': + return left ^ right + if op == 'BitAnd': + return left & right + if op == 'FloorDiv': + return left // right + raise ValueError(f'Unknown operator: {op}') + + +def eval_aug_assign(op: Operator, left: Any, sl: slice, right: Any) -> Any: + if op == 'Add': + left[sl] += right + return left + if op == 'Sub': + left[sl] -= right + return left + if op == 'Mult': + left[sl] *= right + return left + if op == 'MatMult': + left[sl] @= right + return left + if op == 'Div': + left[sl] /= right + return left + if op == 'Mod': + left[sl] %= right + return left + if op == 'Pow': + left[sl] **= right + return left + if op == 'LShift': + left[sl] <<= right + return left + if op == 'RShift': + left[sl] >>= right + return left + if op == 'BitOr': + left[sl] |= right + return left + if op == 'BitXor': + left[sl] ^= right + return left + if op == 'BitAnd': + left[sl] &= right + return left + if op == 'FloorDiv': + left[sl] //= right + return left + raise ValueError(f'Unknown operator: {op}') + + +class BaseBuilder: + + def get_parent_locals(self): + return inspect.currentframe().f_back.f_back.f_locals + + def ctx_if(self, cond) -> Iterable[_T]: + yield cond + + def ctx_then(self, val: _T) -> Iterable[None]: + if val: + yield + + def ctx_else(self, val: _T) -> Iterable[None]: + if not val: + yield + + def eval(self, val: Any): # noqa: B027 + pass + + def ctx_for(self, range: Iterable[Any]) -> Iterable[Any]: + return range + + def ctx_continue(self) -> bool: + return True + + def ctx_break(self) -> bool: + return True + + def ctx_while(self, cond: Callable[[], Any]) -> Iterable[None]: + while cond(): + yield + + def bind(self, name: str, value: Any) -> Any: + return value + + def assign_slice(self, lval: Any, sl: slice, value: Any): + lval[sl] = value + + def aug_assign(self, op: Operator, target: Any, aug_value: Any) -> Any: + return eval_op(op, target, aug_value) + + def aug_assign_slice(self, op: Operator, target: Any, sl: slice, aug_value: Any): + eval_aug_assign(op, target, sl, aug_value) + + def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any]) -> Any: + if op == 'And': + return left and right() + if op == 'Or': + return left or right() + raise ValueError(f'Unknown boolop: {op}') + + def ifexp(self, cond: Any, then: Callable[[], Any], otherwise: Callable[[], Any]) -> Any: + return then() if cond else otherwise() + + def ret(self, value: Any) -> Any: + return value + + def ctx_with(self, ctx: ContextManager[Any]) -> ContextManager[Any]: + return ctx + + def assert_expr(self, cond: Any, msg: Any): + assert cond, msg + + def rval(self, name: str, value: Any): + return value + + def arg(self, name: str, value: Any): + return value + + def override(self, name: str): + return globals()[name] + + +class DSLMutator(ast.NodeTransformer): + + def __init__(self): + self.tmp_counter = 0 + + def get_tmp(self) -> str: + name = f"__{self.tmp_counter}" + self.tmp_counter += 1 + return name + + def visit_If(self, node: ast.If): + node = self.generic_visit(node) + br = self.get_tmp() + if len(node.orelse) == 0: + return quote( + f"for {br} in __tb.ctx_if(cond):\n" + f" for _ in __tb.ctx_then({br}):\n" + " pass\n", + cond=node.test, + passes=[node.body], + span=node, + ) + return quote( + f"for {br} in __tb.ctx_if(cond):\n" + f" for _ in __tb.ctx_then({br}):\n" + f" pass\n" + f" for _ in __tb.ctx_else({br}):\n" + f" pass\n", + cond=node.test, + passes=[node.body, node.orelse], + span=node, + ) + + def visit_Expr(self, node: ast.Expr): + node = self.generic_visit(node) + return quote("__tb.eval(value)", value=node.value, span=node) + + def _parse_names(self, target: ast.expr): + if isinstance(target, ast.Name): + return f"'{target.id}'" + elif isinstance(target, ast.Tuple): + return ("(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)") + else: + raise SyntaxError("Unsupported for target") + + def visit_For(self, node: ast.For): + node = self.generic_visit(node) + tmp = self.get_tmp() + # names = self._parse_names(node.target) + var = ast.Name(tmp, ctx=ast.Load()) + ast_set_span(var, ast_get_span(node.target)) + stmts = self._emit_assign_target(node.target, var) + return quote( + f"for {tmp} in __tb.ctx_for(range):\n" + " pass\n", + target=node.target, + range=node.iter, + passes=[stmts + node.body], + span=node, + ) + + def visit_Continue(self, node: ast.Continue): + node = self.generic_visit(node) + return quote("if __tb.ctx_continue(): continue", span=node) + + def visit_Break(self, node: ast.Break): + node = self.generic_visit(node) + return quote("if __tb.ctx_break(): break", span=node) + + def _emit_assign_target(self, target: ast.expr, rval: ast.expr) -> list[ast.AST]: + if isinstance(target, ast.Name): + return quote( + f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target) + elif isinstance(target, ast.Subscript): + return quote( + "__tb.assign_slice(lval, slice, value)", + lval=target.value, + slice=target.slice, + value=rval, + span=target, + ) + else: + unpacked = [] + + def _visit_target(target: ast.expr) -> str: + if isinstance(target, (ast.Name, ast.Subscript)): + tmp = self.get_tmp() + unpacked.append((tmp, target)) + res = ast.Name(id=tmp, ctx=target.ctx) + ast_set_span(res, ast_get_span(target)) + return res + elif isinstance(target, ast.Tuple): + elts = [_visit_target(elt) for elt in target.elts] + res = ast.Tuple(elts=elts, ctx=target.ctx) + ast_set_span(res, ast_get_span(target)) + return res + + unpack_stmt = ast.Assign(targets=[_visit_target(target)], value=rval) + ast_set_span(unpack_stmt, ast_get_span(target)) + stmts = [unpack_stmt] + bind_lvals = [] + bind_rvals = [] + + def flush_binds(): + if bind_lvals: + stmts.append( + quote1(f'{", ".join(bind_lvals)}, = {", ".join(bind_rvals)},', span=target)) + bind_lvals.clear() + bind_rvals.clear() + + for tmp, target in unpacked: + if isinstance(target, ast.Name): + bind_lvals.append(target.id) + bind_rvals.append(f'__tb.bind("{target.id}", {tmp})') + elif isinstance(target, ast.Subscript): + flush_binds() + stmts.append( + quote1( + f'__tb.assign_slice(lval, slice, {tmp})', + lval=target.value, + slice=target.slice, + span=target)) + else: + raise NotImplementedError(f'Unsupported target: {target}') + flush_binds() + return stmts + + def visit_Assign(self, node: ast.Assign) -> list[ast.AST]: + node = self.generic_visit(node) + rval = node.value + stmts = [] + for target in reversed(node.targets): + stmts.extend(self._emit_assign_target(target, rval)) + rval = target + return stmts + + def visit_AugAssign(self, node: ast.AugAssign) -> list[ast.AST]: + node = self.generic_visit(node) + target, rval = node.target, node.value + op = get_operator_name(node.op) + if isinstance(target, ast.Name): + return quote( + f"name = __tb.aug_assign('{op}', {target.id}, value)", + name=target, + value=rval, + span=node) + elif isinstance(target, ast.Subscript): + return quote( + f"__tb.aug_assign_slice('{op}', lval, slice, value)", + lval=target.value, + slice=target.slice, + value=rval, + span=node, + ) + else: + return node + + def visit_While(self, node): + return quote1( + "for _ in __tb.ctx_while(lambda: cond):\n pass", + cond=node.test, + passes=[node.body], + span=node) + + def visit_FunctionDef(self, node: ast.FunctionDef): + node = self.generic_visit(node) + all_args = node.args.posonlyargs + node.args.args + if node.args.vararg is not None: + all_args += node.args.vararg + all_args += node.args.kwonlyargs + stmts = [] + for arg in all_args: + name = arg.arg + if arg.annotation is not None: + arg_stmt = quote1(f'{name} = __tb.arg("{name}", {name})', span=arg) + else: + arg_stmt = quote1(f'{name} = __tb.arg("{name}", {name})', span=arg) + arg.annotation = None + stmts.append(arg_stmt) + node.body = stmts + node.body + node.decorator_list.clear() + return quote1( + "def ir_generator(__tb):\n" + " range = __tb.override('range')\n" + " pass\n" + f" return {node.name}", + passes=[node], + ) + + def visit_BoolOp(self, node: ast.BoolOp): + node = self.generic_visit(node) + op_name = get_boolop_name(node.op) + last = node.values[-1] + for i in reversed(range(len(node.values) - 1)): + last = quote_expr( + expr=f"__tb.boolop('{op_name}', left, lambda: right)", + left=node.values[i], + right=last, + span=node, + ) + return last + + def visit_Compare(self, node: ast.Compare) -> ast.expr: + node = self.generic_visit(node) + left = node.left + split = [] + for op, comp in zip(node.ops, node.comparators): + cmp = ast.Compare(left=left, ops=[op], comparators=[comp]) + ast_set_span(cmp, ast_get_span(node)) + split.append(cmp) + left = comp + last = split[-1] + for i in reversed(range(len(split) - 1)): + last = quote_expr( + "__tb.boolop('And', left, lambda: right)", left=split[i], right=last, span=node) + return last + + def visit_IfExp(self, node: ast.IfExp) -> ast.Expr: + node = self.generic_visit(node) + return quote_expr( + '__tb.ifexp(cond, lambda: then, lambda: otherwise)', + cond=node.test, + then=node.body, + otherwise=node.orelse, + span=node) + + def visit_Return(self, node: ast.Return): + node = self.generic_visit(node) + return quote("return __tb.ret(value)", value=node.value, span=node) + + def visit_With(self, node: ast.With): + node = self.generic_visit(node) + for expr in node.items: + expr.context_expr = quote_expr("__tb.ctx_with(e)", e=expr.context_expr, span=expr) + return node + + def visit_Assert(self, node: ast.Assert): + node = self.generic_visit(node) + return quote("__tb.assert_expr(cond, msg)", cond=node.test, msg=node.msg, span=node) + + def visit_Name(self, node: ast.Name): + if isinstance(node.ctx, ast.Load): + return quote_expr(f"__tb.rval('{node.id}', {node.id})", span=node) + return node + + +_P = ParamSpec('_P') + + +def mutate(func: Callable[_P, _T]) -> Callable[[BaseBuilder], Callable[_P, _T]]: + tree = utils.get_ast(func) + filename = inspect.getsourcefile(func) or inspect.getfile(func) + tree = DSLMutator().visit(tree) + fn = utils.get_compiled_object(tree, "ir_generator", filename, + utils.inspect_function_capture(func)) + fn.__source__ = ast.unparse(tree) + return fn diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py new file mode 100644 index 000000000..a8c1f3155 --- /dev/null +++ b/tilelang/language/v2/builder.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +from contextlib import contextmanager +import functools +from tilelang.language.kernel import KernelLaunchFrame +from tvm.ffi.container import Map +from tvm.ir.base import Span +from .ast import BaseBuilder, eval_op, mutate +import tvm +from tvm.tir import Buffer +from tvm.script.ir_builder import tir, IRBuilder +from tvm.tir.expr import EqualOp, NotEqualOp, PrimExpr +from typing import Callable, ContextManager, Any, Generic, ParamSpec, Self, TypeVar +import threading +import logging + +logger = logging.getLogger(__name__) + + +def unwrap_expr(expr) -> PrimExpr | int | float: + if isinstance(expr, tir.meta_var): + expr = expr.value + elif isinstance(expr, Buffer) and expr.scope() == 'local.var': + expr = tir.BufferLoad(expr, indices=[0]) + elif isinstance(expr, (EqualOp, NotEqualOp)): + expr = expr.asobject() + elif isinstance(expr, tir.IntImm) and expr.dtype == 'int32': + expr = expr.value + return expr + + +def unwrap_cond(expr): + expr = unwrap_expr(expr) + if isinstance(expr, PrimExpr): + return expr + elif isinstance(expr, Buffer): + raise TypeError(f"Buffer `{expr}` cannot be used as condition directly.") + elif isinstance(expr, (int, bool, tuple, list)): + return expr + else: + logger.warning(f"Python expression `{expr}` is used in TileLang. ", stack_info=True) + return expr + + +thread_local_storage = threading.local() + + +class DummyFrame: + + def __enter__(self): + ... + + def __exit__(self, exc_type, exc_value, traceback): + ... + + +class MacroFrame(DummyFrame): + ... + + +class BoolOpFrame(DummyFrame): + ... + + +class ConstIfFrame(DummyFrame): + ... + + +class BlockFrame(DummyFrame): + ... + + +AnyFrame = tir.frame.IRBuilderFrame | DummyFrame + +TIR_CONTROL_FRAME = ( + tir.frame.WhileFrame, + tir.frame.ForFrame, + tir.frame.IfFrame, + tir.frame.PrimFuncFrame, +) + +TIR_VAR_SCOPE_FRAME = ( + tir.frame.WhileFrame, + tir.frame.ForFrame, + tir.frame.IfFrame, + tir.frame.PrimFuncFrame, + MacroFrame, + KernelLaunchFrame, +) + + +class Builder(BaseBuilder): + + def __init__(self, arg_annot: dict[str, Any]): + self.arg_annot = arg_annot + self.frames: list[AnyFrame] = [] + self.ir_builder = IRBuilder() + self.name_inside_frame: dict[str, AnyFrame] = {} + + @classmethod + def current(cls) -> Self: + builder = thread_local_storage.builder + assert builder is not None, "No active Builder found in the current thread." + return builder + + @contextmanager + def prim_func(self, name): + thread_local_storage.builder = self + with self.ir_builder, self.with_frame(tir.prim_func()): + tir.func_name(name) + yield + + @contextmanager + def macro(self, name=None): + if self.find_frame_idx(BoolOpFrame) is not None: + raise RuntimeError( + f"Macro `{name}` is used inside boolean expressions, " + "please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs") + save = self.name_inside_frame + self.name_inside_frame = {} + with self.with_frame(MacroFrame()): + yield + self.name_inside_frame = save + + def get(self): + return self.ir_builder.get() + + def find_frame_idx(self, frame: type | tuple[type, ...], start=0) -> int | None: + for idx in reversed(range(start, len(self.frames))): + f = self.frames[idx] + if isinstance(f, frame): + return idx + + def enter_frame(self, frame: ContextManager): + self.frames.append(frame) + return frame.__enter__() + + @contextmanager + def with_frame(self, frame: ContextManager | None): + pop_idx = len(self.frames) + yield self.enter_frame(frame) + while len(self.frames) > pop_idx: + self.frames.pop().__exit__(None, None, None) + + class _has_if_frame: + ... + + def ctx_if(self, cond): + cond = unwrap_cond(cond) + if isinstance(cond, PrimExpr): + with self.with_frame(tir.If(cond)): + yield self._has_if_frame + else: + with self.with_frame(ConstIfFrame()): + yield cond + + def ctx_then(self, val): + if val is self._has_if_frame: + with self.with_frame(tir.Then()): + yield + else: + with self.with_frame(BlockFrame()): + if val: + yield + + def ctx_else(self, val): + if val is self._has_if_frame: + with self.with_frame(tir.Else()): + yield + else: + with self.with_frame(BlockFrame()): + if not val: + yield + + def eval(self, val: Any): + val = unwrap_expr(val) + if val is None: + pass + elif isinstance(val, tir.frame.IRBuilderFrame): + self.enter_frame(val) + elif isinstance(val, PrimExpr): + tir.evaluate(val) + elif isinstance(val, (int, bool)): + self.enter_frame(tir.evaluate(tvm.tir.const(val))) + elif isinstance(val, str): + pass + elif isinstance(val, tvm.tir.stmt.BufferStore): + self.enter_frame(tir.buffer_store(val.buffer, val.value, val.indices, val.predicate)) + else: + raise TypeError(f"Unsupported eval value: {val} of type {type(val)}") + + def ctx_for(self, it): + it = unwrap_expr(it) + if isinstance(it, range): + assert it.step == 1, "Only step=1 is supported in range for now." + it = tir.serial(it.start, it.stop) + if not isinstance(it, tir.frame.ForFrame): + raise TypeError( + f"Invalid for loop, got {it}({type(it)}), expect one of the following: " + "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding") + with self.with_frame(it) as v: + yield v + + def ctx_continue(self): + raise RuntimeError("continue is not supported in TileLang builder") + + def ctx_break(self): + raise RuntimeError("break is not supported in TileLang builder") + + def ctx_while(self, cond): + raise RuntimeError("while loops are not supported in TileLang builder") + + def bind(self, name, value): + if name == '_': + return value + locals = self.get_parent_locals() + orig_value = locals.get(name, None) + # handle var + if isinstance(orig_value, Buffer) and orig_value.scope() == 'local.var': + tir.buffer_store(orig_value, value, 0) + return orig_value + res = self.bind_immutable(name, value) + frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME) + assert frame is not None, f"Variable `{name}` is not defined inside any control flow." + self.name_inside_frame[name] = self.frames[frame] + return res + + def bind_immutable(self, name, value): + if isinstance(value, tir.meta_var): + return value.value + elif isinstance(value, tir.frame.IRBuilderFrame): + return self.enter_frame(value) + elif isinstance(value, (Buffer, tir.IterVar, tir.Var)): + IRBuilder.name(name, value) + return value + else: + try: + value = tvm.runtime.convert(value) + except TypeError: + return value + frame = tir.LetStmt(value) + var = frame.var + IRBuilder.name(name, var) + return self.enter_frame(frame) + + def assign_slice(self, lval: Any, sl: slice, value: Any): + if isinstance(lval, Buffer): + tir.buffer_store(lval, value, sl) + else: + return super().assign_slice(lval, sl, value) + + def aug_assign(self, op, target, aug_value): + if isinstance(target, Buffer) and target.scope() == 'local.var': + tir.buffer_store(target, eval_op(op, target, aug_value), 0) + if isinstance(target, Buffer): + raise RuntimeError("Augmented assignment is not supported for Buffer") + else: + return super().aug_assign(op, target, aug_value) + + def aug_assign_slice(self, op, target, sl, aug_value): + if isinstance(target, Buffer): + tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl) + else: + return super().aug_assign_slice(op, target, sl, aug_value) + + def boolop(self, op, left, right): + left = unwrap_cond(left) + if isinstance(left, PrimExpr): + with self.with_frame(BoolOpFrame()): + if op == 'And': + return tir.And(left, right()) + if op == 'Or': + return tir.Or(left, right()) + raise RuntimeError(f"Unsupported boolean operator: {op}") + else: + return super().boolop(op, left, right) + + def ifexp(self, cond, then, otherwise): + cond = unwrap_cond(cond) + if isinstance(cond, PrimExpr): + with self.with_frame(BoolOpFrame()): + return tir.if_then_else(cond, then(), otherwise()) + else: + return super().ifexp(cond, then, otherwise) + + def ret(self, value): + last_macro = self.find_frame_idx(MacroFrame) + if last_macro is not None: + frame = self.find_frame_idx(TIR_CONTROL_FRAME, start=last_macro) + if frame is not None: + raise NotImplementedError( + "Return from control flow is not supported yet. " + "You can't return inside `if`, `for`, `while` blocks in a macro. " + "You should allocate a var before the control flow, assign value inside the blocks, " + "and return the var after the control flow. i.e.\n" + "```\n" + "@T.macro\n" \ + "def my_macro(cond):\n" + " a: T.float16 = ...\n" + " if cond:\n" + " a = 1.0\n" + " return a\n" + "```" + ) + return super().ret(value) + + def ctx_with(self, ctx): + if isinstance(ctx, tir.frame.IRBuilderFrame): + return self.with_frame(ctx) + else: + return super().ctx_with(ctx) + + def assert_expr(self, cond, msg): + cond = unwrap_cond(cond) + if isinstance(cond, PrimExpr): + self.enter_frame(tir.Assert(cond, msg)) + else: + super().assert_expr(cond, msg) + + def rval(self, name: str, value: Any) -> Any: + if name in self.name_inside_frame: + frame = self.name_inside_frame[name] + if frame not in self.frames: + raise RuntimeError( + f"Use variable `{name}` outside its defining region, defined in frame: {frame}, current frames: {self.frames}." + ) + if isinstance(value, tir.IntImm): + return value.value + if isinstance(value, Buffer) and value.scope() == 'local.var': + return tir.BufferLoad(value, indices=[0]) + return super().rval(name, value) + + def arg(self, name, value): + if self.find_frame_idx(MacroFrame) is not None: + return value + else: + annot = self.arg_annot[name] + if callable(annot): + annot = annot() + return tir.arg(name, annot) + + def override(self, name: str): + if name == 'range': + return tir.serial + raise ValueError(f'Unknown override: {name}') + + +_P = ParamSpec('_P') +_T = TypeVar('_T') + + +class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): + params: list[tvm.tir.Var | tvm.tir.Buffer] + body: tvm.tir.Stmt + ret_type: tvm.ir.Type + buffer_map: Map[tvm.tir.Var, tvm.tir.Buffer] + attrs: tvm.Attrs | None + span: Span | None + + +def macro(func: Callable[_P, _T]) -> PrimFunc[_P, _T]: + ir_gen = mutate(func) + + @functools.wraps(func) + def macro_wrapper(*args, **kwargs): + builder = Builder.current() + with builder.macro(func.__name__): + res = ir_gen(builder)(*args, **kwargs) + return res + + return macro_wrapper + + +def prim_func(func: Callable[_P, _T]) -> PrimFunc[_P, _T]: + # hints = get_type_hints(func) + hints = func.__annotations__ + ir_gen = mutate(func) + builder = Builder(hints) + with builder.prim_func(func.__name__): + ir_gen(builder)(*hints) + res = builder.get() + res.ir_gen = ir_gen + return res diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py new file mode 100644 index 000000000..37f9d44d0 --- /dev/null +++ b/tilelang/language/v2/dtypes.py @@ -0,0 +1,198 @@ +from tilelang import tvm +from tvm import ir +import torch +import ctypes +from typing import TYPE_CHECKING +from tvm import tir +import tvm.script.ir_builder.tir._ffi_api as tb_ffi + + +class VoidPtr: + ... + + +AnyDType = ir.Type | str | type | torch.dtype | tvm.DataType + +_dtype_cvt = [ + (None, 'handle', ctypes.c_long, 'long'), # use long to repr void* + (bool, 'bool', ctypes.c_bool, 'bool'), + (int, 'int32', ctypes.c_int32, 'int'), + (float, 'float32', ctypes.c_float, 'float'), + (torch.short, 'int16', ctypes.c_int16, 'short'), + (torch.int, 'int32', ctypes.c_int32, 'int'), + (torch.long, 'int64', ctypes.c_int64, 'long long'), + (torch.half, 'float16', None, None), + (torch.float, 'float32', ctypes.c_float, 'float'), + (torch.double, 'float64', ctypes.c_double, 'double'), + + # (pytype, 'tvm dtype str', 'ctypes dtype', 'cffi dtype') + (torch.bool, 'bool', ctypes.c_bool, 'bool'), + (torch.int8, 'int8', ctypes.c_int8, 'char'), + (torch.int16, 'int16', ctypes.c_int16, 'short'), + (torch.int32, 'int32', ctypes.c_int32, 'int'), + (torch.int64, 'int64', ctypes.c_int64, 'long long'), + (torch.uint8, 'uint8', ctypes.c_uint8, 'unsigned char'), + (torch.uint16, 'uint16', ctypes.c_uint16, 'unsigned short'), + (torch.uint32, 'uint32', ctypes.c_uint32, 'unsigned int'), + (torch.uint64, 'uint64', ctypes.c_uint64, 'unsigned long long'), + (torch.float16, 'float16', None, None), + (torch.float32, 'float32', ctypes.c_float, 'float'), + (torch.float64, 'float64', ctypes.c_double, 'double'), + (torch.float8_e4m3fn, 'float8_e4m3fn', None, None), + (torch.float8_e4m3fnuz, 'float8_e4m3fnuz', None, None), + (torch.float8_e5m2, 'float8_e5m2', None, None), + (torch.float8_e5m2fnuz, 'float8_e5m2fnuz', None, None), + (torch.float8_e8m0fnu, 'float8_e8m0fnu', None, None), + (torch.bfloat16, 'bfloat16', None, None), +] + + +def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x): + return { + smapper(item[sidx]): dmapper(item[didx]) + for item in _dtype_cvt + if item[didx] is not None and item[sidx] is not None + } + + +_dtype_tvm2py = _create_type_mapper(1, 0, lambda x: tvm.DataType(x)) +_dtype_tvm2ctype = _create_type_mapper(1, 2, lambda x: tvm.DataType(x)) +_dtype_tvm2cffi = _create_type_mapper(1, 3, lambda x: tvm.DataType(x)) + + +class dtype: + __cvt = _create_type_mapper(0, 1) + + def __init__(self, value: AnyDType): + if isinstance(value, dtype): + value = value.name + if not isinstance(value, str): + if value not in self.__cvt: + raise TypeError( + f"Unsupported dtype: {value}, expected one of {list(self.__cvt.keys())}") + value = self.__cvt[value] + self.name = value + + def __eq__(self, other: AnyDType): + if isinstance(other, str): + return str.__eq__(self.name, other) + if other in self.__cvt: + return str.__eq__(self.name, self.__cvt[other]) + return NotImplemented + + def __req__(self, other: AnyDType): + if isinstance(other, str): + return str.__eq__(self.name, other) + if other in self.__cvt: + return str.__eq__(self.name, self.__cvt[other]) + return NotImplemented + + def __ne__(self, other: AnyDType): + if isinstance(other, str): + return str.__ne__(self.name, other) + if other in self.__cvt: + return str.__ne__(self.name, self.__cvt[other]) + return NotImplemented + + def __rne__(self, other: AnyDType): + if isinstance(other, str): + return str.__ne__(self.name, other) + if other in self.__cvt: + return str.__ne__(self.name, self.__cvt[other]) + return NotImplemented + + def __repr__(self): + return f"dtype({str.__repr__(self.name)})" + + def __hash__(self): + return str.__hash__(self.name) + + def __call__(self, expr=None, is_size_var: bool = False) -> tir.Var: + return getattr(tb_ffi, self.name.title())(expr, is_size_var) + + def get_tvm_dtype(self) -> tvm.DataType: + return tvm.DataType(self.name) + + +def get_tvm_dtype(value: AnyDType) -> tvm.DataType: + if isinstance(value, (tvm.DataType, ir.Type)): + return value + if isinstance(value, dtype): + return value.get_tvm_dtype() + return dtype(value).get_tvm_dtype() + + +if TYPE_CHECKING: + + class int8(dtype): + ... + + class int16(dtype): + ... + + class int32(dtype): + ... + + class int64(dtype): + ... + + class uint8(dtype): + ... + + class uint16(dtype): + ... + + class uint32(dtype): + ... + + class uint64(dtype): + ... + + class float16(dtype): + ... + + class float32(dtype): + ... + + class float64(dtype): + ... + + class bool(dtype): + ... + + class float8_e4m3fn(dtype): + ... + + class float8_e4m3fnuz(dtype): + ... + + class float8_e5m2(dtype): + ... + + class float8_e5m2fnuz(dtype): + ... + + class float8_e8m0fnu(dtype): + ... + + class bfloat16(dtype): + ... +else: + int8 = dtype('int8') + int16 = dtype('int16') + int32 = dtype('int32') + int64 = dtype('int64') + uint8 = dtype('uint8') + uint16 = dtype('uint16') + uint32 = dtype('uint32') + uint64 = dtype('uint64') + float16 = dtype('float16') + float32 = dtype('float32') + float64 = dtype('float64') + bool = dtype('bool') + float8_e4m3fn = dtype('float8_e4m3fn') + float8_e4m3fnuz = dtype('float8_e4m3fnuz') + float8_e5m2 = dtype('float8_e5m2') + float8_e5m2fnuz = dtype('float8_e5m2fnuz') + float8_e8m0fnu = dtype('float8_e8m0fnu') + bfloat16 = dtype('bfloat16') diff --git a/tilelang/language/v2/utils.py b/tilelang/language/v2/utils.py new file mode 100644 index 000000000..daa9c6d0c --- /dev/null +++ b/tilelang/language/v2/utils.py @@ -0,0 +1,106 @@ +from __future__ import annotations +import ast +import inspect +from typing import Any, Callable, Literal +from tilelang import env +from hashlib import sha256 +import linecache + + +def disk_compile(source, name): + cache_dir = env.TILELANG_CACHE_DIR + if cache_dir is not None: + import os + save_dir = os.path.join(cache_dir, "py-cache") + os.makedirs(save_dir, exist_ok=True) + hash_sfx = sha256(source.encode('utf-8')).hexdigest()[:8] + path = os.path.join(save_dir, f"{name}.{hash_sfx}.py") + with open(path, 'w') as f: + f.write(source) + linecache.cache[path] = (len(source), None, source.splitlines(), path) + return compile(source, path, "exec") + + +def _remove_leading_ident(source: str): + lines = source.splitlines() + if not lines: + return source + ident_size = len(lines[0]) - len(lines[0].lstrip()) + return "\n".join([line[ident_size:] if len(line) >= ident_size else line for line in lines]) + + +def get_func_nonlocals(func): + """A modified version of `inspect.getclosurevars`""" + + if inspect.ismethod(func): + func = func.__func__ + + if not inspect.isfunction(func): + raise TypeError(f"{func!r} is not a Python function") + + code = func.__code__ + # Nonlocal references are named in co_freevars and resolved + # by looking them up in __closure__ by positional index + nonlocal_vars = {} + if func.__closure__ is not None: + for var, cell in zip(code.co_freevars, func.__closure__): + try: + nonlocal_vars[var] = cell.cell_contents + except ValueError as err: + # cell_contents may raise ValueError if the cell is empty. + if "empty" not in str(err): + raise + return nonlocal_vars + + +def inspect_function_capture(func: Callable) -> dict[str, Any]: + """Capture function non-locals and global variables. + + Parameters + ---------- + func : Callable + The function to inspect. + + Returns + ------- + res : Dict[str, Any] + The function variables map with non-local or global variables. + """ + captured = { + **func.__globals__, # type: ignore + **get_func_nonlocals(func), + } + return captured + + +def get_ast(func: Callable): + _, start = inspect.getsourcelines(func) + filename = inspect.getsourcefile(func) or inspect.getfile(func) + source = inspect.getsource(func) + source = _remove_leading_ident(source) + source = '\n' * (start - 1) + source + tree = ast.parse(source, filename=filename) + return tree + + +CompileMethod = Literal['direct', 'disk'] + + +def get_compiled_object(source: str | ast.AST, + name: str, + filename: str = None, + globals: dict[str, Any] = None): + if isinstance(source, ast.AST): + assert filename is not None, "filename must be provided when source is an AST" + try: + if isinstance(source, ast.AST): + ast.fix_missing_locations(source) + compiled = compile(source, filename, 'exec') + else: + compiled = disk_compile(source, name) + except Exception as e: + source_str = source if isinstance(source, str) else ast.unparse(source) + raise RuntimeError(f'Failed to compile source for {name}:\n{source_str}') from e + locs = {} + exec(compiled, globals, locs) + return locs[name] From 9d6659c463183a8bee4ecfce6f1e34e265f2d1f9 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 27 Oct 2025 14:31:47 +0800 Subject: [PATCH 02/24] syntax sugar: defining a local var by annotation --- tilelang/language/v2/ast.py | 75 ++++++++++++++----- tilelang/language/v2/builder.py | 125 +++++++++++++++++++++++++------- 2 files changed, 156 insertions(+), 44 deletions(-) diff --git a/tilelang/language/v2/ast.py b/tilelang/language/v2/ast.py index 1ed9c3d0d..7dc7f31fa 100644 --- a/tilelang/language/v2/ast.py +++ b/tilelang/language/v2/ast.py @@ -157,7 +157,12 @@ def eval_aug_assign(op: Operator, left: Any, sl: slice, right: Any) -> Any: raise ValueError(f'Unknown operator: {op}') +class _empty: + ... + + class BaseBuilder: + empty = _empty def get_parent_locals(self): return inspect.currentframe().f_back.f_back.f_locals @@ -189,10 +194,13 @@ def ctx_while(self, cond: Callable[[], Any]) -> Iterable[None]: while cond(): yield - def bind(self, name: str, value: Any) -> Any: + def bind(self, name: str, value: Any, annot: Any = empty) -> Any: + return value + + def unwrap_value(self, value): return value - def assign_slice(self, lval: Any, sl: slice, value: Any): + def assign_slice(self, lval: Any, sl: slice, value: Any, annot: Any = empty): lval[sl] = value def aug_assign(self, op: Operator, target: Any, aug_value: Any) -> Any: @@ -273,7 +281,8 @@ def _parse_names(self, target: ast.expr): elif isinstance(target, ast.Tuple): return ("(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)") else: - raise SyntaxError("Unsupported for target") + s = ast.unparse(target) + raise NotImplementedError(f"Unsupported for target `{s}`") def visit_For(self, node: ast.For): node = self.generic_visit(node) @@ -299,18 +308,42 @@ def visit_Break(self, node: ast.Break): node = self.generic_visit(node) return quote("if __tb.ctx_break(): break", span=node) - def _emit_assign_target(self, target: ast.expr, rval: ast.expr) -> list[ast.AST]: + def _emit_assign_target(self, + target: ast.expr, + rval: ast.expr, + annot: ast.expr = None) -> list[ast.AST]: if isinstance(target, ast.Name): - return quote( - f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target) + if annot is None: + return quote( + f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target) + else: + return quote( + f'name = __tb.bind("{target.id}", value, annot)', + name=target, + value=rval, + annot=annot, + span=target) + elif isinstance(target, ast.Attribute): + s = ast.unparse(target) + raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`') elif isinstance(target, ast.Subscript): - return quote( - "__tb.assign_slice(lval, slice, value)", - lval=target.value, - slice=target.slice, - value=rval, - span=target, - ) + if annot is None: + return quote( + "__tb.assign_slice(lval, slice, value)", + lval=target.value, + slice=target.slice, + value=rval, + span=target, + ) + else: + return quote( + "__tb.assign_slice(lval, slice, value, annot)", + lval=target.value, + slice=target.slice, + value=rval, + annot=annot, + span=target, + ) else: unpacked = [] @@ -327,7 +360,9 @@ def _visit_target(target: ast.expr) -> str: ast_set_span(res, ast_get_span(target)) return res - unpack_stmt = ast.Assign(targets=[_visit_target(target)], value=rval) + unpack_stmt = ast.Assign( + targets=[_visit_target(target)], + value=quote_expr('__tb.unwrap_value(rval)', rval=rval, span=rval)) ast_set_span(unpack_stmt, ast_get_span(target)) stmts = [unpack_stmt] bind_lvals = [] @@ -353,7 +388,8 @@ def flush_binds(): slice=target.slice, span=target)) else: - raise NotImplementedError(f'Unsupported target: {target}') + s = ast.unparse(target) + raise NotImplementedError(f'Unsupported target: {s}') flush_binds() return stmts @@ -387,6 +423,11 @@ def visit_AugAssign(self, node: ast.AugAssign) -> list[ast.AST]: else: return node + def visit_AnnAssign(self, node: ast.AnnAssign): + node = self.generic_visit(node) + rval = node.value or quote_expr('__tb.empty', span=node, annot=node) + return self._emit_assign_target(node.target, rval, annot=node.annotation) + def visit_While(self, node): return quote1( "for _ in __tb.ctx_while(lambda: cond):\n pass", @@ -412,7 +453,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef): node.body = stmts + node.body node.decorator_list.clear() return quote1( - "def ir_generator(__tb):\n" + f"def {node.name}(__tb):\n" " range = __tb.override('range')\n" " pass\n" f" return {node.name}", @@ -483,7 +524,7 @@ def mutate(func: Callable[_P, _T]) -> Callable[[BaseBuilder], Callable[_P, _T]]: tree = utils.get_ast(func) filename = inspect.getsourcefile(func) or inspect.getfile(func) tree = DSLMutator().visit(tree) - fn = utils.get_compiled_object(tree, "ir_generator", filename, + fn = utils.get_compiled_object(tree, func.__name__, filename, utils.inspect_function_capture(func)) fn.__source__ = ast.unparse(tree) return fn diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index a8c1f3155..8df4d8e90 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -1,7 +1,9 @@ from __future__ import annotations from contextlib import contextmanager -import functools +from dataclasses import dataclass + +import torch from tilelang.language.kernel import KernelLaunchFrame from tvm.ffi.container import Map from tvm.ir.base import Span @@ -9,8 +11,10 @@ import tvm from tvm.tir import Buffer from tvm.script.ir_builder import tir, IRBuilder -from tvm.tir.expr import EqualOp, NotEqualOp, PrimExpr -from typing import Callable, ContextManager, Any, Generic, ParamSpec, Self, TypeVar +from tvm.tir.expr import EqualOp, NotEqualOp, PrimExpr, Var +from typing import Callable, ContextManager, Any, Generic, Hashable, ParamSpec, Self, TypeVar +from .dtypes import get_tvm_dtype +from types import EllipsisType import threading import logging @@ -89,9 +93,13 @@ class BlockFrame(DummyFrame): ) +def is_var(v: Any) -> bool: + return isinstance(v, Buffer) and v.scope() == 'local.var' + + class Builder(BaseBuilder): - def __init__(self, arg_annot: dict[str, Any]): + def __init__(self, arg_annot: dict[str, Any] = None): self.arg_annot = arg_annot self.frames: list[AnyFrame] = [] self.ir_builder = IRBuilder() @@ -210,21 +218,46 @@ def ctx_break(self): def ctx_while(self, cond): raise RuntimeError("while loops are not supported in TileLang builder") - def bind(self, name, value): - if name == '_': - return value + def bind(self, name, value, annot=BaseBuilder.empty): locals = self.get_parent_locals() orig_value = locals.get(name, None) - # handle var - if isinstance(orig_value, Buffer) and orig_value.scope() == 'local.var': + # annotation like tl.float32 + if callable(annot): + annot_val = annot() + if isinstance(annot_val, tir.Var): + orig_value = tir.alloc_buffer((1,), dtype=annot_val.dtype, scope='local.var') + IRBuilder.name(name, orig_value) + if isinstance(value, EllipsisType) or value is self.empty: + return orig_value + # if orig_value is a local.var, we use buffer_store to modify it immutably + # however, if rvalue is also a local.var, this is a new binding, + # we should not use buffer_store, and bind it instead + # ```py + # a = tl.alloc_var('float32') # bind var `a` + # a = tl.alloc_var('float32') # bind a new var `a_1` + # b = a # get value of var `b = a_1[0]`` + # c = tl.alloc_var('float32') # bind var `c` + # c = a # get and assign `c[0] = a_1[0]` + # ``` + if is_var(orig_value) and not is_var(value): tir.buffer_store(orig_value, value, 0) return orig_value res = self.bind_immutable(name, value) - frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME) - assert frame is not None, f"Variable `{name}` is not defined inside any control flow." - self.name_inside_frame[name] = self.frames[frame] + if name != '_': + frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME) + assert frame is not None, f"Variable `{name}` is not defined inside any control flow." + self.name_inside_frame[name] = self.frames[frame] return res + def unwrap_value(self, value): + # handle bx, by = tl.Kernel(128, 128), rval is frame + if isinstance(value, tir.meta_var): + return value.value + elif isinstance(value, tir.frame.IRBuilderFrame): + return self.enter_frame(value) + else: + return value + def bind_immutable(self, name, value): if isinstance(value, tir.meta_var): return value.value @@ -243,7 +276,10 @@ def bind_immutable(self, name, value): IRBuilder.name(name, var) return self.enter_frame(frame) - def assign_slice(self, lval: Any, sl: slice, value: Any): + def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty): + if annot is not self.empty: + logger.warning( + "Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2) if isinstance(lval, Buffer): tir.buffer_store(lval, value, sl) else: @@ -333,11 +369,14 @@ def rval(self, name: str, value: Any) -> Any: def arg(self, name, value): if self.find_frame_idx(MacroFrame) is not None: return value + if isinstance(value, (Buffer, Var)): + return tir.arg(name, value) + elif hasattr(value, '__tl_arg__'): + return value.__tl_arg__(name, self) + elif isinstance(value, Hashable): + return value else: - annot = self.arg_annot[name] - if callable(annot): - annot = annot() - return tir.arg(name, annot) + raise TypeError(f"Unsupported argument type: {type(value)} for argument `{name}`.") def override(self, name: str): if name == 'range': @@ -345,10 +384,27 @@ def override(self, name: str): raise ValueError(f'Unknown override: {name}') +def __torch_tensor_tl_arg__(self: torch.Tensor, name: str, builder: Builder): + buffer = tir.buffer( + self.shape, get_tvm_dtype(self.dtype), strides=self.stride(), scope='global') + return tir.arg(name, buffer) + + +torch.Tensor.__tl_arg__ = __torch_tensor_tl_arg__ + _P = ParamSpec('_P') _T = TypeVar('_T') +@dataclass +class IRGenerator(Generic[_P, _T]): + func: Callable[[BaseBuilder], Callable[_P, _T]] + source: str + + def __call__(self, tb: BaseBuilder) -> Callable[_P, _T]: + return self.func(tb) + + class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): params: list[tvm.tir.Var | tvm.tir.Buffer] body: tvm.tir.Stmt @@ -356,28 +412,43 @@ class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): buffer_map: Map[tvm.tir.Var, tvm.tir.Buffer] attrs: tvm.Attrs | None span: Span | None + ir_gen: IRGenerator[_P, _T] + source: str -def macro(func: Callable[_P, _T]) -> PrimFunc[_P, _T]: - ir_gen = mutate(func) +@dataclass +class Macro(Generic[_P, _T]): + name: str + ir_gen: IRGenerator[_P, _T] + + @property + def source(self) -> str: + return self.ir_gen.source - @functools.wraps(func) - def macro_wrapper(*args, **kwargs): + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: builder = Builder.current() - with builder.macro(func.__name__): - res = ir_gen(builder)(*args, **kwargs) + with builder.macro(self.name): + res = self.ir_gen(builder)(*args, **kwargs) return res - return macro_wrapper + +def macro(func: Callable[_P, _T]) -> Macro[_P, _T]: + ir_gen = mutate(func) + ir_gen = IRGenerator(func=ir_gen, source=ir_gen.__source__) + return Macro(func.__name__, ir_gen) def prim_func(func: Callable[_P, _T]) -> PrimFunc[_P, _T]: - # hints = get_type_hints(func) hints = func.__annotations__ + for k in hints: + if callable(hints[k]): + hints[k] = hints[k]() ir_gen = mutate(func) - builder = Builder(hints) + ir_gen = IRGenerator(func=ir_gen, source=ir_gen.__source__) + builder = Builder() with builder.prim_func(func.__name__): - ir_gen(builder)(*hints) + ir_gen(builder)(**hints) res = builder.get() res.ir_gen = ir_gen + res.source = ir_gen.source return res From 09d8aec19bc29e53b5a7b000158c7f01b9c1184b Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 27 Oct 2025 14:47:49 +0800 Subject: [PATCH 03/24] [Refactor] fix type linting warning like `T.float32` --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 5bf17a346..9cda9b611 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 5bf17a34602931e7d7e01cbccf358a21fe972779 +Subproject commit 9cda9b611ba9d91a1d42b561767f40aba0afcd78 From 4c75e85046ae2754ad00154d5125a2b26fca8b21 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 27 Oct 2025 15:00:20 +0800 Subject: [PATCH 04/24] Add tl.local_var_init for new tl.float32 --- tilelang/language/v2/builder.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 8df4d8e90..0f7b5c2b9 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -11,7 +11,7 @@ import tvm from tvm.tir import Buffer from tvm.script.ir_builder import tir, IRBuilder -from tvm.tir.expr import EqualOp, NotEqualOp, PrimExpr, Var +from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, Var from typing import Callable, ContextManager, Any, Generic, Hashable, ParamSpec, Self, TypeVar from .dtypes import get_tvm_dtype from types import EllipsisType @@ -229,6 +229,12 @@ def bind(self, name, value, annot=BaseBuilder.empty): IRBuilder.name(name, orig_value) if isinstance(value, EllipsisType) or value is self.empty: return orig_value + elif isinstance(value, (int, float, IntImm, FloatImm)): + tir.block_attr( + {'tl.local_var_init': { + orig_value.data: tvm.runtime.convert(value) + }}) + return orig_value # if orig_value is a local.var, we use buffer_store to modify it immutably # however, if rvalue is also a local.var, this is a new binding, # we should not use buffer_store, and bind it instead From 8dce2586ee526843e42244c71363dd3a7dc25f58 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 27 Oct 2025 15:20:40 +0800 Subject: [PATCH 05/24] allow passing default argument as function annotation --- tilelang/language/v2/builder.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 0f7b5c2b9..8622c18e8 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from dataclasses import dataclass +import inspect import torch from tilelang.language.kernel import KernelLaunchFrame @@ -420,6 +421,7 @@ class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): span: Span | None ir_gen: IRGenerator[_P, _T] source: str + orig_func: Callable[_P, _T] @dataclass @@ -445,16 +447,22 @@ def macro(func: Callable[_P, _T]) -> Macro[_P, _T]: def prim_func(func: Callable[_P, _T]) -> PrimFunc[_P, _T]: - hints = func.__annotations__ - for k in hints: - if callable(hints[k]): - hints[k] = hints[k]() + sig = inspect.signature(func) + annot = func.__annotations__ + for param in sig.parameters.values(): + if param.default is param.empty: + if param.annotation is param.empty: + raise TypeError(f"Parameter `{param.name}` in prim_func `{func.__name__}` " + "must have type annotation or default value.") + param.default = param.annotation + args = sig.bind() ir_gen = mutate(func) ir_gen = IRGenerator(func=ir_gen, source=ir_gen.__source__) - builder = Builder() + builder = Builder(annot) with builder.prim_func(func.__name__): - ir_gen(builder)(**hints) + ir_gen(builder)(*args.args, **args.kwargs) res = builder.get() res.ir_gen = ir_gen res.source = ir_gen.source + res.orig_func = func return res From e3815c6a61b5b305f5624314f960ab0546d97efc Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 27 Oct 2025 15:55:42 +0800 Subject: [PATCH 06/24] allow default arguments as annotation --- tilelang/language/v2/builder.py | 66 +++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 19 deletions(-) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 8622c18e8..916797c3e 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass import inspect +import typing import torch from tilelang.language.kernel import KernelLaunchFrame @@ -13,7 +14,7 @@ from tvm.tir import Buffer from tvm.script.ir_builder import tir, IRBuilder from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, Var -from typing import Callable, ContextManager, Any, Generic, Hashable, ParamSpec, Self, TypeVar +from typing import Callable, ContextManager, Any, Generic, Hashable, ParamSpec, Self, TypeVar, get_type_hints from .dtypes import get_tvm_dtype from types import EllipsisType import threading @@ -399,18 +400,16 @@ def __torch_tensor_tl_arg__(self: torch.Tensor, name: str, builder: Builder): torch.Tensor.__tl_arg__ = __torch_tensor_tl_arg__ + _P = ParamSpec('_P') _T = TypeVar('_T') @dataclass class IRGenerator(Generic[_P, _T]): - func: Callable[[BaseBuilder], Callable[_P, _T]] + gen: Callable[[BaseBuilder], Callable[_P, _T]] source: str - def __call__(self, tb: BaseBuilder) -> Callable[_P, _T]: - return self.func(tb) - class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): params: list[tvm.tir.Var | tvm.tir.Buffer] @@ -427,6 +426,7 @@ class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): @dataclass class Macro(Generic[_P, _T]): name: str + orig_func: Callable[_P, _T] ir_gen: IRGenerator[_P, _T] @property @@ -436,31 +436,59 @@ def source(self) -> str: def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: builder = Builder.current() with builder.macro(self.name): - res = self.ir_gen(builder)(*args, **kwargs) + res = self.ir_gen.gen(builder)(*args, **kwargs) return res -def macro(func: Callable[_P, _T]) -> Macro[_P, _T]: +def build_ir_generator(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: ir_gen = mutate(func) - ir_gen = IRGenerator(func=ir_gen, source=ir_gen.__source__) - return Macro(func.__name__, ir_gen) + ir_gen = IRGenerator(gen=ir_gen, source=ir_gen.__source__) + return ir_gen + + +def macro(func: Callable[_P, _T]) -> Macro[_P, _T]: + return Macro( + name=func.__name__, + orig_func=func, + ir_gen=build_ir_generator(func) + ) def prim_func(func: Callable[_P, _T]) -> PrimFunc[_P, _T]: sig = inspect.signature(func) annot = func.__annotations__ - for param in sig.parameters.values(): - if param.default is param.empty: - if param.annotation is param.empty: - raise TypeError(f"Parameter `{param.name}` in prim_func `{func.__name__}` " - "must have type annotation or default value.") - param.default = param.annotation - args = sig.bind() - ir_gen = mutate(func) - ir_gen = IRGenerator(func=ir_gen, source=ir_gen.__source__) + if any(map(lambda x: isinstance(x, str), annot)): + try: + annot = get_type_hints(func) + except Exception as e: + raise RuntimeError( + f"Failed to get type hints for function `{func.__name__}`. \n" + "Note: if you are using `from __future__ import annotations`, type hints may be missing, \n" + "To fix this, please use default argument instead of type annotations: \n" + "```py\n" + "def foo(a=tl.Tensor((128, 128), 'float32'), b=tl.float32()): ..." + "```" + ) from e + args = [] + kwargs = {} + for name, param in sig.parameters.items(): + if param.annotation is not param.empty: + if callable(param.annotation): + value = param.annotation() + else: + value = param.annotation + elif param.default is not param.empty: + value = param.default + else: + value = Builder.empty + if param.kind == param.POSITIONAL_ONLY: + args.append(value) + else: + kwargs[name] = value + ir_gen = build_ir_generator(func) builder = Builder(annot) with builder.prim_func(func.__name__): - ir_gen(builder)(*args.args, **args.kwargs) + ir_gen.gen(builder)(*args, **kwargs) res = builder.get() res.ir_gen = ir_gen res.source = ir_gen.source From 66ee7b2ab11509eba687eea59aa466d7bf2a939d Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 27 Oct 2025 16:00:32 +0800 Subject: [PATCH 07/24] fix lint error --- tilelang/language/v2/builder.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 916797c3e..a50b40877 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -3,7 +3,6 @@ from contextlib import contextmanager from dataclasses import dataclass import inspect -import typing import torch from tilelang.language.kernel import KernelLaunchFrame @@ -400,7 +399,6 @@ def __torch_tensor_tl_arg__(self: torch.Tensor, name: str, builder: Builder): torch.Tensor.__tl_arg__ = __torch_tensor_tl_arg__ - _P = ParamSpec('_P') _T = TypeVar('_T') @@ -447,11 +445,7 @@ def build_ir_generator(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: def macro(func: Callable[_P, _T]) -> Macro[_P, _T]: - return Macro( - name=func.__name__, - orig_func=func, - ir_gen=build_ir_generator(func) - ) + return Macro(name=func.__name__, orig_func=func, ir_gen=build_ir_generator(func)) def prim_func(func: Callable[_P, _T]) -> PrimFunc[_P, _T]: @@ -467,8 +461,7 @@ def prim_func(func: Callable[_P, _T]) -> PrimFunc[_P, _T]: "To fix this, please use default argument instead of type annotations: \n" "```py\n" "def foo(a=tl.Tensor((128, 128), 'float32'), b=tl.float32()): ..." - "```" - ) from e + "```") from e args = [] kwargs = {} for name, param in sig.parameters.items(): From 77457369c8a0f5c23c4b3680dd10efc7fd2d890d Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 27 Oct 2025 16:55:23 +0800 Subject: [PATCH 08/24] minor fix --- tilelang/language/v2/builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index a50b40877..7fdce0371 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -190,11 +190,11 @@ def eval(self, val: Any): elif isinstance(val, PrimExpr): tir.evaluate(val) elif isinstance(val, (int, bool)): - self.enter_frame(tir.evaluate(tvm.tir.const(val))) + tir.evaluate(tvm.tir.const(val)) elif isinstance(val, str): pass elif isinstance(val, tvm.tir.stmt.BufferStore): - self.enter_frame(tir.buffer_store(val.buffer, val.value, val.indices, val.predicate)) + tir.buffer_store(val.buffer, val.value, val.indices, val.predicate) else: raise TypeError(f"Unsupported eval value: {val} of type {type(val)}") From 65bd4fc4ee42db5df559cc9f68c1ef51e1ed08ab Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 27 Oct 2025 16:58:06 +0800 Subject: [PATCH 09/24] [Refactor] refactor tilelang.jit and tilelang.autotune --- tilelang/autotuner/tuner.py | 215 +++++++++++------------------ tilelang/jit/__init__.py | 260 +++++++++++++++--------------------- 2 files changed, 179 insertions(+), 296 deletions(-) diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index e94ac7466..d096bc5a6 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -4,17 +4,19 @@ and performance optimization through configuration search. """ from __future__ import annotations +from dataclasses import dataclass import tilelang from tilelang import tvm as tvm +from tilelang.jit import JITImpl +from tilelang.jit.kernel import JITKernel from tvm.tir import PrimFunc, Var from tvm.target import Target import inspect from functools import partial -from typing import (Callable, Literal, Any, overload) +from typing import (Callable, Generic, Literal, Any, ParamSpec, TypeVar) from tqdm import tqdm import logging -import functools import concurrent.futures import torch import os @@ -30,7 +32,6 @@ from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult from tilelang.autotuner.capture import get_autotune_inputs from tilelang.utils.target import determine_target -from tilelang.jit.param import _P, _RProg from tilelang import __version__ @@ -585,9 +586,13 @@ def __call__(self) -> Any: return self.run() -class _AutoTunerImplementation: - # Overload __init__ to help type checkers understand the effect of return_program - # The '-> None' is for __init__ itself. The crucial part is Literal for return_program. +_P = ParamSpec('_P') +_T = TypeVar('_T') + + +@dataclass +class AutoTuneImpl(Generic[_P, _T]): + jit_impl: JITImpl warmup: int = 25 rep: int = 100 @@ -603,125 +608,51 @@ class _AutoTunerImplementation: manual_check_prog: Callable = None cache_input_tensors: bool = False - def __init__(self, - configs: dict | Callable, - warmup: int = 25, - rep: int = 100, - timeout: int = 100, - supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, - ref_prog: Callable = None, - supply_prog: Callable = None, - rtol: float = 1e-2, - atol: float = 1e-2, - max_mismatched_ratio: float = 0.01, - skip_check: bool = False, - manual_check_prog: Callable = None, - cache_input_tensors: bool = False) -> None: - """Initialize the AutoTunerImplementation. + def __post_init__(self): + self._tuner_cache = {} + + def get_tunner(self): + autotuner = AutoTuner( + self.jit_impl.func, configs=self.configs).set_profile_args( + supply_type=self.supply_type, + ref_prog=self.ref_prog, + supply_prog=self.supply_prog, + rtol=self.rtol, + atol=self.atol, + max_mismatched_ratio=self.max_mismatched_ratio, + skip_check=self.skip_check, + manual_check_prog=self.manual_check_prog, + cache_input_tensors=self.cache_input_tensors, + ).set_compile_args( + out_idx=self.jit_impl.out_idx, + execution_backend=self.jit_impl.execution_backend, + target=self.jit_impl.target, + target_host=self.jit_impl.target_host, + verbose=self.jit_impl.verbose, + pass_configs=self.jit_impl.pass_configs, + ) + autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout) + return autotuner + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel: + key_args_tuple = args + key_kwargs_tuple = tuple(sorted(kwargs.items())) + key = (key_args_tuple, key_kwargs_tuple) + if key not in self._tuner_cache: - Args: - configs: Configuration space to explore during auto-tuning. - warmup: Number of warmup iterations before timing. - rep: Number of repetitions for timing measurements. - timeout: Maximum time (in seconds) allowed for each configuration. - supply_type: Strategy for generating input tensors (random/zeros/etc) - ref_prog: Reference implementation for validation - supply_prog: Custom function to provide input tensors - rtol: Relative tolerance for numerical validation - atol: Absolute tolerance for numerical validation - max_mismatched_ratio: Allowed percentage of mismatched values - skip_check: Bypass validation against reference implementation - manual_check_prog: Custom validation function - cache_input_tensors: Reuse input tensors across trials - """ - # Configuration and benchmarking parameters - self.configs = configs # Search space of tuning configurations - self.warmup = warmup # Warmup iterations for stable measurements - self.rep = rep # Measurement repetitions for statistics - self.timeout = timeout # Per-configuration timeout threshold - - # Tensor handling and validation setup - self.supply_type = supply_type # Input tensor generation strategy - self.ref_prog = ref_prog # Ground truth implementation - self.supply_prog = supply_prog # Custom input data provider - self.rtol = rtol # Relative error tolerance - self.atol = atol # Absolute error tolerance - self.max_mismatched_ratio = max_mismatched_ratio # Allowed mismatch - - # Validation control flags - self.skip_check = skip_check # Bypass accuracy verification - self.manual_check_prog = manual_check_prog # Custom validation - self.cache_input_tensors = cache_input_tensors # Reuse inputs - - # Cache for storing tuned kernel implementations - self._tuner_cache: dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel - - # This tells the type checker what the *wrapper* function will return. - # this is for linting, please do not remove it. - @overload - def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]: - ... - - @overload - def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]: - ... - - # Actual implementation of __call__ - def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]: - warmup = self.warmup - rep = self.rep - timeout = self.timeout - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - - key_args_tuple = args - key_kwargs_tuple = tuple(sorted(kwargs.items())) - key = (key_args_tuple, key_kwargs_tuple) - - if key not in self._tuner_cache: - - def jit_compile(**config_arg): - return fn(*args, **kwargs, __tune_params=config_arg) - - compile_arguments = fn(__return_compile_arguments=True) - - autotuner = AutoTuner( - fn, configs=self.configs).set_profile_args( - supply_type=self.supply_type, - ref_prog=self.ref_prog, - supply_prog=self.supply_prog, - rtol=self.rtol, - atol=self.atol, - max_mismatched_ratio=self.max_mismatched_ratio, - skip_check=self.skip_check, - manual_check_prog=self.manual_check_prog, - cache_input_tensors=self.cache_input_tensors, - ).set_compile_args( - out_idx=compile_arguments['out_idx'], - execution_backend=compile_arguments['execution_backend'], - target=compile_arguments['target'], - target_host=compile_arguments['target_host'], - verbose=compile_arguments['verbose'], - pass_configs=compile_arguments['pass_configs'], - ) - - autotuner.jit_compile = jit_compile - autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters) - - autotuner.run = partial(autotuner.run, warmup, rep, timeout) - - artifact = autotuner.run() - - self._tuner_cache[key] = artifact.kernel - - return self._tuner_cache[key] - - return wrapper + def jit_compile(**config_arg): + return self.jit_impl(*args, **kwargs, __tune_params=config_arg) + + autotuner = self.get_tunner() + autotuner.jit_compile = jit_compile + autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters) + artifact = autotuner.run() + self._tuner_cache[key] = artifact.kernel + return self._tuner_cache[key] def autotune( # This is the new public interface - func: Callable[_P, _RProg] | PrimFunc | None = None, + func: Callable[_P, _T] | PrimFunc | None = None, *, # Indicates subsequent arguments are keyword-only configs: dict | Callable, # profile arguments @@ -795,22 +726,26 @@ def autotune( # This is the new public interface elif isinstance(func, PrimFunc): raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") else: - # Case 2: Used as @autotune(...) to configure, or func_or_out_idx is meant as out_idx. - # Create a _AutoTunerImplementation instance with the provided/defaulted arguments. - # This instance is a decorator that will be applied to the function later. - configured_decorator = _AutoTunerImplementation( - configs=configs, - warmup=warmup, - rep=rep, - timeout=timeout, - supply_type=supply_type, - ref_prog=ref_prog, - supply_prog=supply_prog, - rtol=rtol, - atol=atol, - max_mismatched_ratio=max_mismatched_ratio, - skip_check=skip_check, - manual_check_prog=manual_check_prog, - cache_input_tensors=cache_input_tensors, - ) - return configured_decorator + + def decorator(impl): + assert isinstance( + impl, JITImpl + ), "The @autotune decorator can only be applied to @tilelang.jit decorated instances." + return AutoTuneImpl( + jit_impl=impl, + configs=configs, + warmup=warmup, + rep=rep, + timeout=timeout, + supply_type=supply_type, + ref_prog=ref_prog, + supply_prog=supply_prog, + rtol=rtol, + atol=atol, + max_mismatched_ratio=max_mismatched_ratio, + skip_check=skip_check, + manual_check_prog=manual_check_prog, + cache_input_tensors=cache_input_tensors, + ) + + return decorator diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 2080a00c6..aef72935a 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -5,9 +5,14 @@ """ from __future__ import annotations +from dataclasses import dataclass +import inspect from typing import ( Any, Callable, + Generic, + ParamSpec, + TypeVar, overload, Literal, ) @@ -21,8 +26,7 @@ from tilelang.cache import cached from os import path, makedirs from logging import getLogger -import functools -from tilelang.jit.param import Kernel, _P, _RProg +from tilelang.jit.param import Kernel logger = getLogger(__name__) @@ -79,8 +83,13 @@ def compile( ) -class _JitImplementation: +_P = ParamSpec('_P') +_T = TypeVar('_T') + +@dataclass +class JITImpl(Generic[_P, _T]): + func: Callable[_P, _T] out_idx: list[int] | int | None target: str | Target target_host: str | Target @@ -89,149 +98,98 @@ class _JitImplementation: pass_configs: dict[str, Any] | None debug_root_path: str | None compile_flags: list[str] | str | None + func_source: str + signature: inspect.Signature - def __init__(self, - out_idx: Any = None, - target: str | Target = "auto", - target_host: str | Target = None, - execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None): - """ - Initializes the JIT compiler decorator. - - Parameters - ---------- - out_idx : Any, optional - Index(es) of the output tensors to return from the compiled kernel - (default: None, meaning all outputs are returned or determined by the kernel itself). - target : Union[str, Target], optional - Compilation target for TVM. Can be a string (e.g., "cuda", "llvm") - or a TVM Target object. If "auto", the target is determined automatically - (default: "auto"). - target_host : Union[str, Target], optional - Target host for cross-compilation, similar to `target` (default: None). - execution_backend : Literal["dlpack", "ctypes", "cython"], optional - The backend used for kernel execution and argument passing. - "dlpack" is generally preferred for zero-copy tensor passing with compatible frameworks. - "ctypes" uses standard C types. "cython" uses Cython for potentially faster execution. - (default: "cython"). - verbose : bool, optional - If True, enables verbose logging during compilation (default: False). - pass_configs : Optional[Dict[str, Any]], optional - A dictionary of configurations for TVM's pass context. These can fine-tune - the compilation process. Examples include "tir.disable_vectorize" - (default: None). - debug_root_path : Optional[str], optional - If provided, the compiled kernel's source code will be saved to a file - in this directory. This is useful for debugging the generated code. - If None, no debug information is saved (default: None). - If a relative path is given, it's made absolute relative to the project root - or current working directory. - compile_flags : Optional[Union[List[str], str]], optional - Additional compilation flags to pass to the compiler. - If None, no additional compilation flags are passed (default: None). - """ - self.out_idx = out_idx - self.execution_backend = execution_backend - self.target = target - self.target_host = target_host - self.verbose = verbose - self.pass_configs = pass_configs - self.compile_flags = compile_flags - - # Corrected debug_root_path handling - self.debug_root_path = debug_root_path + def __post_init__(self): if self.debug_root_path is not None and not path.isabs(self.debug_root_path): try: base_path = path.dirname(path.dirname(path.dirname(__file__))) self.debug_root_path = path.join(base_path, self.debug_root_path) except NameError: self.debug_root_path = path.abspath(self.debug_root_path) - self._kernel_cache: dict[tuple, Kernel] = {} - # This tells the type checker what the *wrapper* function will return. - # this is for linting, please do not remove it. - @overload - def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, Kernel]]: - ... - - @overload - def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]: - ... - - # Actual implementation of __call__ - def __call__( - self, - func: Callable[_P, _RProg] # func is Union[Callable[_P, _RProg], PrimFunc] in original - ) -> Callable[_P, Any]: - - @functools.wraps(func) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: - # Separate out the tuning parameters from the user's kwargs - tune_params = kwargs.pop('__tune_params', {}) - # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache - return_compile_arguments = kwargs.pop('__return_compile_arguments', False) - if return_compile_arguments: - compile_args = { - 'out_idx': self.out_idx, - 'execution_backend': self.execution_backend, - 'target': self.target, - 'target_host': self.target_host, - 'verbose': self.verbose, - 'pass_configs': self.pass_configs, - 'compile_flags': self.compile_flags, - } - return compile_args - - key_args_tuple = args - key_kwargs_tuple = tuple(sorted(kwargs.items())) - tuned_key_kwargs_tuple = tuple(sorted(tune_params.items())) - key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple) - - if key not in self._kernel_cache: - # Ensure 'func' (the original user function) is used correctly - program_result_source = func - if isinstance(program_result_source, PrimFunc): - program_result = program_result_source - elif callable(program_result_source): - program_result = program_result_source(*args, **kwargs, **tune_params) - else: - raise ValueError(f"Invalid function type: {type(program_result_source)}") - - kernel_result = compile( - program_result, - out_idx=self.out_idx, - execution_backend=self.execution_backend, - target=self.target, - target_host=self.target_host, - verbose=self.verbose, - pass_configs=self.pass_configs, - compile_flags=self.compile_flags, - ) - - if self.debug_root_path: - func_name = getattr(func, '__name__', 'jit_kernel') # Use func for name - kernel_file = f'tilelang_jit_kernel_{func_name}.c' - program_file = f'tilelang_jit_program_{func_name}.py' - makedirs(self.debug_root_path, exist_ok=True) - with open(path.join(self.debug_root_path, kernel_file), 'w') as f: - print(kernel_result.get_kernel_source(), file=f) - with open(path.join(self.debug_root_path, program_file), 'w') as f: - print(program_result.script(), file=f) - - self._kernel_cache[key] = kernel_result - - return self._kernel_cache[key] - - return wrapper + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel: + # Separate out the tuning parameters from the user's kwargs + tune_params = kwargs.pop('__tune_params', {}) + # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache + return_compile_arguments = kwargs.pop('__return_compile_arguments', False) + if return_compile_arguments: + compile_args = { + 'out_idx': self.out_idx, + 'execution_backend': self.execution_backend, + 'target': self.target, + 'target_host': self.target_host, + 'verbose': self.verbose, + 'pass_configs': self.pass_configs, + 'compile_flags': self.compile_flags, + } + return compile_args + + key_args_tuple = args + key_kwargs_tuple = tuple(sorted(kwargs.items())) + tuned_key_kwargs_tuple = tuple(sorted(tune_params.items())) + key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple) + + if key not in self._kernel_cache: + # Ensure 'func' (the original user function) is used correctly + program_result_source = self.func + if isinstance(program_result_source, PrimFunc): + program_result = program_result_source + elif callable(program_result_source): + program_result = program_result_source(*args, **kwargs, **tune_params) + else: + raise ValueError(f"Invalid function type: {type(program_result_source)}") + + kernel_result = compile( + program_result, + out_idx=self.out_idx, + execution_backend=self.execution_backend, + target=self.target, + target_host=self.target_host, + verbose=self.verbose, + pass_configs=self.pass_configs, + compile_flags=self.compile_flags, + ) + + if self.debug_root_path: + func_name = getattr(self.func, '__name__', 'jit_kernel') # Use func for name + kernel_file = f'tilelang_jit_kernel_{func_name}.c' + program_file = f'tilelang_jit_program_{func_name}.py' + makedirs(self.debug_root_path, exist_ok=True) + with open(path.join(self.debug_root_path, kernel_file), 'w') as f: + print(kernel_result.get_kernel_source(), file=f) + with open(path.join(self.debug_root_path, program_file), 'w') as f: + print(program_result.script(), file=f) + + self._kernel_cache[key] = kernel_result + + return self._kernel_cache[key] + + +@overload +def jit(func: Callable[_P, _T]) -> JITImpl[_P, _T]: + ... + + +@overload +def jit( + *, # Indicates subsequent arguments are keyword-only + out_idx: Any = None, + target: str | Target = "auto", + target_host: str | Target = None, + execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + debug_root_path: str | None = None, + compile_flags: list[str] | str | None = None +) -> Callable[[Callable[_P, _T]], JITImpl[_P, _T]]: + ... def jit( # This is the new public interface - func: Callable[_P, _RProg] | PrimFunc | None = None, + func: Callable[_P, _T] | PrimFunc | None = None, *, # Indicates subsequent arguments are keyword-only out_idx: Any = None, target: str | Target = "auto", @@ -275,32 +233,22 @@ def jit( # This is the new public interface if isinstance(compile_flags, str): compile_flags = [compile_flags] - if callable(func): - # Case 1: Used as @jit (func_or_out_idx is the function, others are defaults) - # Create a default _JitImplementation instance and apply it to the function. - default_decorator = _JitImplementation( - out_idx=out_idx, # Explicitly None for the default case + def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]: + return JITImpl( + func, + out_idx=out_idx, target=target, target_host=target_host, execution_backend=execution_backend, verbose=verbose, pass_configs=pass_configs, debug_root_path=debug_root_path, - compile_flags=compile_flags) - return default_decorator(func) - elif isinstance(func, PrimFunc): - raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") + compile_flags=compile_flags, + func_source=inspect.getsource(func), + signature=inspect.signature(func), + ) + + if callable(func): + return decorator(func) else: - # Case 2: Used as @jit(...) to configure, or func_or_out_idx is meant as out_idx. - # Create a _JitImplementation instance with the provided/defaulted arguments. - # This instance is a decorator that will be applied to the function later. - configured_decorator = _JitImplementation( - out_idx=out_idx, # Pass along; could be an actual out_idx or None - target=target, - target_host=target_host, - execution_backend=execution_backend, - verbose=verbose, - pass_configs=pass_configs, - debug_root_path=debug_root_path, - compile_flags=compile_flags) - return configured_decorator + return decorator From d592fbf672fd8f19c716f8c57cb7cbd13cc92661 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 27 Oct 2025 17:58:13 +0800 Subject: [PATCH 10/24] minor fix --- .../test_tilelang_transform_layout_inference.py | 14 ++++++-------- ...lelang_transform_legalize_safe_memory_access.py | 2 +- .../test_tilelang_transform_lower_tile_op.py | 14 ++++++-------- tilelang/language/v2/builder.py | 5 ++++- 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/testing/python/transform/test_tilelang_transform_layout_inference.py b/testing/python/transform/test_tilelang_transform_layout_inference.py index dd7f7e2ce..5f9bc3240 100644 --- a/testing/python/transform/test_tilelang_transform_layout_inference.py +++ b/testing/python/transform/test_tilelang_transform_layout_inference.py @@ -15,9 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): N = tvm.te.var("n") K = tvm.te.var("k") - @tvm.script.ir.ir_module - class Before: - + def before(): @T.prim_func def main(B: T.Tensor((K, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): @@ -37,10 +35,9 @@ def main(B: T.Tensor((K, N), dtype),): t // (block_N // vec_load_b), bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec], T.float16(0)) + return tvm.IRModule({'main': main}) - @tvm.script.ir.ir_module - class After: - + def after(): @T.prim_func def main(B: T.Tensor((K, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): @@ -76,12 +73,13 @@ def main(B: T.Tensor((K, N), dtype),): t // (block_N // vec_load_b), bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec], T.float16(0)) + return tvm.IRModule({'main': main}) with tvm.target.Target(auto_target): - mod = tvm.tir.transform.BindTarget(auto_target)(Before) + mod = tvm.tir.transform.BindTarget(auto_target)(before()) mod = tl.transform.LayoutInference()(mod) mod = tvm.tir.transform.Simplify()(mod) - ref_mod = tvm.tir.transform.BindTarget(auto_target)(After) + ref_mod = tvm.tir.transform.BindTarget(auto_target)(after()) ref_mod = tvm.tir.transform.Simplify()(ref_mod) # Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass # This loop is "for vec in T.parallel(1)", diff --git a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py index 5202ab647..e5215db25 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py +++ b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py @@ -41,7 +41,7 @@ def assert_vectorize_access(M: int = 64, N: int = 64): def issue_1013_buggy_kernel(): # NOTE: This kernel is mainly to test some corner cases in boundary check - num_tokens = T.dynamic('num_tokens') + num_tokens = T.Var('num_tokens', 'int32') num_threads = 128 @T.prim_func diff --git a/testing/python/transform/test_tilelang_transform_lower_tile_op.py b/testing/python/transform/test_tilelang_transform_lower_tile_op.py index 1729072d2..27d32a259 100644 --- a/testing/python/transform/test_tilelang_transform_lower_tile_op.py +++ b/testing/python/transform/test_tilelang_transform_lower_tile_op.py @@ -15,19 +15,16 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): N = tvm.te.var("n") K = tvm.te.var("k") - @tvm.script.ir.ir_module - class Before: - + def before(): @T.prim_func def main(B: T.Tensor((K, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared((block_K, block_N), dtype) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): T.copy(B[k * block_K, bx * block_N], B_shared) + return tvm.IRModule({'main': main}) - @tvm.script.ir.ir_module - class After: - + def after(): @T.prim_func def main(B: T.Tensor((K, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): @@ -63,12 +60,13 @@ def main(B: T.Tensor((K, N), dtype),): t // (block_N // vec_load_b), bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec], T.float16(0)) + return tvm.IRModule({'main': main}) with tvm.transform.PassContext(): - mod = tvm.tir.transform.BindTarget(auto_target)(Before) + mod = tvm.tir.transform.BindTarget(auto_target)(before()) mod = tl.transform.LowerTileOp()(mod) mod = tvm.tir.transform.Simplify()(mod) - ref_mod = tvm.tir.transform.BindTarget(auto_target)(After) + ref_mod = tvm.tir.transform.BindTarget(auto_target)(after()) ref_mod = tvm.tir.transform.Simplify()(ref_mod) # Note(tzj): The structures are equal except the argument in "T.reads" function. # The difference is just between the first index and the indices range, which is totally equivalent diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 7fdce0371..aeb88dfb1 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -273,6 +273,8 @@ def bind_immutable(self, name, value): elif isinstance(value, (Buffer, tir.IterVar, tir.Var)): IRBuilder.name(name, value) return value + elif isinstance(value, (tuple, list, tvm.ffi.Array)): + return value else: try: value = tvm.runtime.convert(value) @@ -365,7 +367,8 @@ def rval(self, name: str, value: Any) -> Any: frame = self.name_inside_frame[name] if frame not in self.frames: raise RuntimeError( - f"Use variable `{name}` outside its defining region, defined in frame: {frame}, current frames: {self.frames}." + f"Use immutable variable `{name}` outside its defining region, did you forget **alloc_var**?\n" + f"variable `{name}` is defined in frame: {frame}, current frames: {self.frames}." ) if isinstance(value, tir.IntImm): return value.value From eff7916fc17f36b6d434e20344c07fe97618adc4 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 27 Oct 2025 17:58:23 +0800 Subject: [PATCH 11/24] minor fix --- tilelang/language/v2/ast.py | 17 ++++++++++++----- tilelang/language/v2/builder.py | 19 +++++++------------ tilelang/language/v2/utils.py | 2 +- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/tilelang/language/v2/ast.py b/tilelang/language/v2/ast.py index 7dc7f31fa..5b23a1d20 100644 --- a/tilelang/language/v2/ast.py +++ b/tilelang/language/v2/ast.py @@ -396,11 +396,18 @@ def flush_binds(): def visit_Assign(self, node: ast.Assign) -> list[ast.AST]: node = self.generic_visit(node) rval = node.value - stmts = [] - for target in reversed(node.targets): - stmts.extend(self._emit_assign_target(target, rval)) - rval = target - return stmts + if len(node.targets) == 1: + return self._emit_assign_target(node.targets[0], rval) + else: + tmp_name = self.get_tmp() + tmp_store = ast.Name(tmp_name, ctx=ast.Store()) + tmp_load = ast.Name(tmp_name, ctx=ast.Load()) + ast_set_span(tmp_store, node.targets[0]) + ast_set_span(tmp_load, node.targets[0]) + stmt = self._emit_assign_target(tmp_store, rval) + for target in node.targets: + stmt.extend(self._emit_assign_target(target, tmp_load)) + return stmt def visit_AugAssign(self, node: ast.AugAssign) -> list[ast.AST]: node = self.generic_visit(node) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index aeb88dfb1..5068186cd 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -29,7 +29,7 @@ def unwrap_expr(expr) -> PrimExpr | int | float: expr = tir.BufferLoad(expr, indices=[0]) elif isinstance(expr, (EqualOp, NotEqualOp)): expr = expr.asobject() - elif isinstance(expr, tir.IntImm) and expr.dtype == 'int32': + elif isinstance(expr, IntImm) and expr.dtype == 'int32': expr = expr.value return expr @@ -257,10 +257,9 @@ def bind(self, name, value, annot=BaseBuilder.empty): return res def unwrap_value(self, value): + value = unwrap_expr(value) # handle bx, by = tl.Kernel(128, 128), rval is frame - if isinstance(value, tir.meta_var): - return value.value - elif isinstance(value, tir.frame.IRBuilderFrame): + if isinstance(value, tir.frame.IRBuilderFrame): return self.enter_frame(value) else: return value @@ -295,9 +294,9 @@ def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty return super().assign_slice(lval, sl, value) def aug_assign(self, op, target, aug_value): - if isinstance(target, Buffer) and target.scope() == 'local.var': - tir.buffer_store(target, eval_op(op, target, aug_value), 0) - if isinstance(target, Buffer): + if is_var(target): + tir.buffer_store(target, eval_op(op, target[0], aug_value), 0) + elif isinstance(target, Buffer): raise RuntimeError("Augmented assignment is not supported for Buffer") else: return super().aug_assign(op, target, aug_value) @@ -370,11 +369,7 @@ def rval(self, name: str, value: Any) -> Any: f"Use immutable variable `{name}` outside its defining region, did you forget **alloc_var**?\n" f"variable `{name}` is defined in frame: {frame}, current frames: {self.frames}." ) - if isinstance(value, tir.IntImm): - return value.value - if isinstance(value, Buffer) and value.scope() == 'local.var': - return tir.BufferLoad(value, indices=[0]) - return super().rval(name, value) + return unwrap_expr(value) def arg(self, name, value): if self.find_frame_idx(MacroFrame) is not None: diff --git a/tilelang/language/v2/utils.py b/tilelang/language/v2/utils.py index daa9c6d0c..739ecd1eb 100644 --- a/tilelang/language/v2/utils.py +++ b/tilelang/language/v2/utils.py @@ -100,7 +100,7 @@ def get_compiled_object(source: str | ast.AST, compiled = disk_compile(source, name) except Exception as e: source_str = source if isinstance(source, str) else ast.unparse(source) - raise RuntimeError(f'Failed to compile source for {name}:\n{source_str}') from e + raise RuntimeError(f'Failed to compile source for {name}, Error: {e}:\n{source_str}') from e locs = {} exec(compiled, globals, locs) return locs[name] From 6f69f0243ead1e32e4fecc3d30a6f1ca2865df89 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 27 Oct 2025 17:58:47 +0800 Subject: [PATCH 12/24] minor fix --- .../transform/test_tilelang_transform_layout_inference.py | 4 ++++ .../python/transform/test_tilelang_transform_lower_tile_op.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/testing/python/transform/test_tilelang_transform_layout_inference.py b/testing/python/transform/test_tilelang_transform_layout_inference.py index 5f9bc3240..66415aacb 100644 --- a/testing/python/transform/test_tilelang_transform_layout_inference.py +++ b/testing/python/transform/test_tilelang_transform_layout_inference.py @@ -16,6 +16,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): K = tvm.te.var("k") def before(): + @T.prim_func def main(B: T.Tensor((K, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): @@ -35,9 +36,11 @@ def main(B: T.Tensor((K, N), dtype),): t // (block_N // vec_load_b), bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec], T.float16(0)) + return tvm.IRModule({'main': main}) def after(): + @T.prim_func def main(B: T.Tensor((K, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): @@ -73,6 +76,7 @@ def main(B: T.Tensor((K, N), dtype),): t // (block_N // vec_load_b), bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec], T.float16(0)) + return tvm.IRModule({'main': main}) with tvm.target.Target(auto_target): diff --git a/testing/python/transform/test_tilelang_transform_lower_tile_op.py b/testing/python/transform/test_tilelang_transform_lower_tile_op.py index 27d32a259..07dbd53f1 100644 --- a/testing/python/transform/test_tilelang_transform_lower_tile_op.py +++ b/testing/python/transform/test_tilelang_transform_lower_tile_op.py @@ -16,15 +16,18 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): K = tvm.te.var("k") def before(): + @T.prim_func def main(B: T.Tensor((K, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared((block_K, block_N), dtype) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): T.copy(B[k * block_K, bx * block_N], B_shared) + return tvm.IRModule({'main': main}) def after(): + @T.prim_func def main(B: T.Tensor((K, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): @@ -60,6 +63,7 @@ def main(B: T.Tensor((K, N), dtype),): t // (block_N // vec_load_b), bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec], T.float16(0)) + return tvm.IRModule({'main': main}) with tvm.transform.PassContext(): From 20feef28bb70b071eb25db2fb5ea6af3a1495cab Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Tue, 28 Oct 2025 10:52:48 +0800 Subject: [PATCH 13/24] fix metal get function name --- tilelang/jit/adapter/torch/metal.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tilelang/jit/adapter/torch/metal.py b/tilelang/jit/adapter/torch/metal.py index 30e84ad71..0b1bc0098 100644 --- a/tilelang/jit/adapter/torch/metal.py +++ b/tilelang/jit/adapter/torch/metal.py @@ -27,7 +27,11 @@ def __init__( # compile_flags: Optional[List[str]] = None ): self.kernel_global_source = kernel_global_source - self.kernel_name = func_or_mod.__name__ + '_kernel' + if isinstance(func_or_mod, tir.PrimFunc): + func_name = func_or_mod.attrs['global_symbol'] + else: + func_name = func_or_mod.__name__ + self.kernel_name = func_name + '_kernel' self.verbose = verbose self.block_info = [1, 1, 1] @@ -43,7 +47,7 @@ def __init__( self.grid_info["xyz".index(tag[-1])] = extent break else: - raise AssertionError(f'no kernel with name {func_or_mod.__name__}') + raise AssertionError(f'no kernel with name {func_name}') # print(self.block_info, self.grid_info) super().__init__(func_or_mod, result_idx=result_idx, params=params) From 015416b9f9bdbd98c5eb607e8749769ad3c84c29 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Tue, 28 Oct 2025 12:39:50 +0800 Subject: [PATCH 14/24] add par_compile impl and tests --- .../jit/test_tilelang_jit_parcompile.py | 75 ++++++++ tilelang/__init__.py | 2 +- tilelang/autotuner/tuner.py | 6 +- tilelang/jit/__init__.py | 169 ++++++++++++++---- 4 files changed, 215 insertions(+), 37 deletions(-) create mode 100644 testing/python/jit/test_tilelang_jit_parcompile.py diff --git a/testing/python/jit/test_tilelang_jit_parcompile.py b/testing/python/jit/test_tilelang_jit_parcompile.py new file mode 100644 index 000000000..deaef6487 --- /dev/null +++ b/testing/python/jit/test_tilelang_jit_parcompile.py @@ -0,0 +1,75 @@ +from tilelang import tvm +import tilelang.testing +import tilelang +import torch + +@tilelang.jit( + out_idx=-1, # create the output tensor during runtime + verbose=True, +) +def matmul_kernel_jit( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A=False, + trans_B=True, + in_dtype='float16', + out_dtype='float32', + accum_dtype='float32', + num_stages=2, + threads=128, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def test_par_compile(): + configs = [ + (1024, 1024, 1024, 128, 128, 32), + (2048, 2048, 2048, 256, 256, 64), + (4096, 4096, 4096, 64, 64, 128), + ] + ker = matmul_kernel_jit(1024, 1024, 1024, 128, 128, 32) + kernels = matmul_kernel_jit.par_compile(configs) + for (M, N, K, _, _, _), kernel in zip(configs, kernels): + A = torch.randn(M, K, dtype=torch.float16).cuda() + B = torch.randn(N, K, dtype=torch.float16).cuda() + ref = (A @ B.T).float() + C = kernel(A, B) + tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + tilelang.testing.main() \ No newline at end of file diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 98c2a6b37..bd978e5b1 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -4,7 +4,7 @@ import logging import warnings -from tqdm import tqdm +from tqdm.auto import tqdm from importlib.metadata import PackageNotFoundError, version diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index d096bc5a6..cc474dc45 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -15,7 +15,7 @@ import inspect from functools import partial from typing import (Callable, Generic, Literal, Any, ParamSpec, TypeVar) -from tqdm import tqdm +from tqdm.auto import tqdm import logging import concurrent.futures import torch @@ -525,12 +525,12 @@ def inner(**config_arg): # latency, ref_latency = target_fn(jit_kernel) latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel) except TimeoutException: - logger.info( + logger.warning( f"A timeout occurred while testing config {config}, checkout autotuner.log for more details" ) continue except Exception: - logger.info( + logger.warning( f"An error occurred while testing config {config}, checkout autotuner.log for more details" ) logger.debug(f"Error: {traceback.format_exc()}") diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index aef72935a..18a59dfcc 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -11,6 +11,7 @@ Any, Callable, Generic, + Iterable, ParamSpec, TypeVar, overload, @@ -27,6 +28,9 @@ from os import path, makedirs from logging import getLogger from tilelang.jit.param import Kernel +import concurrent.futures + +from tqdm.auto import tqdm logger = getLogger(__name__) @@ -83,6 +87,72 @@ def compile( ) +def par_compile(funcs: Iterable[PrimFunc], + out_idx: list[int] | int | None = None, + execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", + target: str | Target = "auto", + target_host: str | Target | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | str | None = None, + num_workers: int = None, + ignore_error: bool = False) -> list[JITKernel]: + """ + Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels. + Parameters + ---------- + funcs : Iterable[tvm.tir.PrimFunc] + The TileLang TIR functions to compile and wrap. + out_idx : Union[List[int], int], optional + Index(es) of the output tensors to return (default: None). + execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional + Execution backend to use for kernel execution (default: "cython"). + target : Union[str, Target], optional + Compilation target, either as a string or a TVM Target object (default: "auto"). + target_host : Union[str, Target], optional + Target host for cross-compilation (default: None). + verbose : bool, optional + Whether to enable verbose output (default: False). + pass_configs : dict, optional + Additional keyword arguments to pass to the Compiler PassContext. + Refer to `tilelang.transform.PassConfigKey` for supported options. + """ + with concurrent.futures.ThreadPoolExecutor(num_workers, 'tl-par-comp') as executor: + futures = [] + future_map = {} + for i, func in enumerate(funcs): + future = executor.submit( + compile, + func=func, + out_idx=out_idx, + execution_backend=execution_backend, + target=target, + target_host=target_host, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + future_map[future] = i + futures.append(future) + results = [... for _ in futures] + for future in tqdm( + concurrent.futures.as_completed(futures), + total=len(futures), + desc="Parallel Compiling", + ): + idx = future_map[future] + if ignore_error: + try: + results[idx] = future.result() + except Exception as e: + logger.warning(f"Error compiling function at index {idx}: {e}") + results[idx] = None + else: + results[idx] = future.result() + return results + return results + + _P = ParamSpec('_P') _T = TypeVar('_T') @@ -91,9 +161,9 @@ def compile( class JITImpl(Generic[_P, _T]): func: Callable[_P, _T] out_idx: list[int] | int | None + execution_backend: Literal["dlpack", "ctypes", "cython"] target: str | Target target_host: str | Target - execution_backend: Literal["dlpack", "ctypes", "cython"] verbose: bool pass_configs: dict[str, Any] | None debug_root_path: str | None @@ -110,6 +180,69 @@ def __post_init__(self): self.debug_root_path = path.abspath(self.debug_root_path) self._kernel_cache: dict[tuple, Kernel] = {} + def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc: + program_result_source = self.func + if isinstance(program_result_source, PrimFunc): + program_result = program_result_source + elif callable(program_result_source): + program_result = program_result_source(*args, **kwargs) + else: + raise ValueError(f"Invalid function type: {type(program_result_source)}") + return program_result + + def par_compile(self, + configs: Iterable[dict[str, Any] | tuple[str, Any]], + num_workers: int = None, + ignore_error: bool = False) -> list[JITKernel]: + configs = list(configs) + funcs = [] + for cfg in tqdm(configs, desc='Elaborating'): + if isinstance(cfg, tuple): + funcs.append(self.get_tir(*cfg)) + elif isinstance(cfg, dict): + funcs.append(self.get_tir(**cfg)) + else: + raise ValueError(f"Invalid config type: {type(cfg)}, expected tuple or dict.") + return par_compile( + funcs, + out_idx=self.out_idx, + execution_backend=self.execution_backend, + target=self.target, + target_host=self.target_host, + verbose=self.verbose, + pass_configs=self.pass_configs, + compile_flags=self.compile_flags, + num_workers=num_workers, + ignore_error=ignore_error) + + def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel: + func = self.get_tir(*args, **kwargs) + kernel_result = compile( + func, + out_idx=self.out_idx, + execution_backend=self.execution_backend, + target=self.target, + target_host=self.target_host, + verbose=self.verbose, + pass_configs=self.pass_configs, + compile_flags=self.compile_flags, + ) + + if self.debug_root_path: + if isinstance(self.func, PrimFunc): + func_name = self.func.attrs['global_symbol'] + else: + func_name = getattr(self.func, '__name__', 'jit_kernel') + kernel_file = f'tilelang_jit_kernel_{func_name}.c' + program_file = f'tilelang_jit_program_{func_name}.py' + makedirs(self.debug_root_path, exist_ok=True) + with open(path.join(self.debug_root_path, kernel_file), 'w') as f: + print(kernel_result.get_kernel_source(), file=f) + with open(path.join(self.debug_root_path, program_file), 'w') as f: + print(func.script(), file=f) + + return kernel_result + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel: # Separate out the tuning parameters from the user's kwargs tune_params = kwargs.pop('__tune_params', {}) @@ -133,37 +266,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel: key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple) if key not in self._kernel_cache: - # Ensure 'func' (the original user function) is used correctly - program_result_source = self.func - if isinstance(program_result_source, PrimFunc): - program_result = program_result_source - elif callable(program_result_source): - program_result = program_result_source(*args, **kwargs, **tune_params) - else: - raise ValueError(f"Invalid function type: {type(program_result_source)}") - - kernel_result = compile( - program_result, - out_idx=self.out_idx, - execution_backend=self.execution_backend, - target=self.target, - target_host=self.target_host, - verbose=self.verbose, - pass_configs=self.pass_configs, - compile_flags=self.compile_flags, - ) - - if self.debug_root_path: - func_name = getattr(self.func, '__name__', 'jit_kernel') # Use func for name - kernel_file = f'tilelang_jit_kernel_{func_name}.c' - program_file = f'tilelang_jit_program_{func_name}.py' - makedirs(self.debug_root_path, exist_ok=True) - with open(path.join(self.debug_root_path, kernel_file), 'w') as f: - print(kernel_result.get_kernel_source(), file=f) - with open(path.join(self.debug_root_path, program_file), 'w') as f: - print(program_result.script(), file=f) - - self._kernel_cache[key] = kernel_result + self._kernel_cache[key] = self.compile(*args, **kwargs, **tune_params) return self._kernel_cache[key] @@ -237,9 +340,9 @@ def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]: return JITImpl( func, out_idx=out_idx, + execution_backend=execution_backend, target=target, target_host=target_host, - execution_backend=execution_backend, verbose=verbose, pass_configs=pass_configs, debug_root_path=debug_root_path, From a7e202775cea21361d705037d369dc66440f85dd Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Tue, 28 Oct 2025 16:17:18 +0800 Subject: [PATCH 15/24] Type consistency on tvm datatype 1. isinstance(tl.float32, tvm.DataType) == True 2. Allow `tl.float32` as function annotations 3. Allow `tl.float32` as argument to be passed to `tl.alloc` or other functions --- .../jit/test_tilelang_jit_parcompile.py | 5 +- ... => test_tilelang_language_chain_equal.py} | 0 .../language/test_tilelang_language_dtype.py | 214 ++++++ ...tilelang_transform_multi_version_buffer.py | 4 +- tilelang/jit/__init__.py | 62 +- tilelang/jit/kernel.py | 9 +- tilelang/language/__init__.py | 2 +- tilelang/language/v2/__init__.py | 3 +- tilelang/language/v2/builder.py | 68 +- tilelang/language/v2/dtypes.py | 713 ++++++++++++++---- 10 files changed, 860 insertions(+), 220 deletions(-) rename testing/python/language/{test_tilelang_laguange_chain_equal.py => test_tilelang_language_chain_equal.py} (100%) create mode 100644 testing/python/language/test_tilelang_language_dtype.py diff --git a/testing/python/jit/test_tilelang_jit_parcompile.py b/testing/python/jit/test_tilelang_jit_parcompile.py index deaef6487..e7bcec412 100644 --- a/testing/python/jit/test_tilelang_jit_parcompile.py +++ b/testing/python/jit/test_tilelang_jit_parcompile.py @@ -1,8 +1,8 @@ -from tilelang import tvm import tilelang.testing import tilelang import torch + @tilelang.jit( out_idx=-1, # create the output tensor during runtime verbose=True, @@ -61,7 +61,6 @@ def test_par_compile(): (2048, 2048, 2048, 256, 256, 64), (4096, 4096, 4096, 64, 64, 128), ] - ker = matmul_kernel_jit(1024, 1024, 1024, 128, 128, 32) kernels = matmul_kernel_jit.par_compile(configs) for (M, N, K, _, _, _), kernel in zip(configs, kernels): A = torch.randn(M, K, dtype=torch.float16).cuda() @@ -72,4 +71,4 @@ def test_par_compile(): if __name__ == "__main__": - tilelang.testing.main() \ No newline at end of file + tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_laguange_chain_equal.py b/testing/python/language/test_tilelang_language_chain_equal.py similarity index 100% rename from testing/python/language/test_tilelang_laguange_chain_equal.py rename to testing/python/language/test_tilelang_language_chain_equal.py diff --git a/testing/python/language/test_tilelang_language_dtype.py b/testing/python/language/test_tilelang_language_dtype.py new file mode 100644 index 000000000..d6f6b0e74 --- /dev/null +++ b/testing/python/language/test_tilelang_language_dtype.py @@ -0,0 +1,214 @@ +import tilelang +import tilelang.language as T +import torch +import tilelang.testing +import tvm + +def test_argument(): + @T.prim_func + def test_argument( + t_1: T.bool, + t_2: T.short, + t_3: T.int, + t_4: T.long, + t_5: T.half, + t_6: T.float, + t_7: T.long, + t_8: T.int8, + t_9: T.int16, + t_10: T.int32, + t_11: T.int64, + t_12: T.uint8, + t_13: T.uint16, + t_14: T.uint32, + t_15: T.uint64, + t_16: T.float8_e4m3fn, + t_17: T.float8_e4m3fnuz, + t_18: T.float8_e5m2, + t_19: T.float8_e5m2fnuz, + t_20: T.float8_e8m0fnu, + t_21: T.float16, + t_22: T.bfloat16, + t_23: T.float32, + t_24: T.float64, + ): + pass + + +def test_expr(): + from tilelang.language.v2.dtypes import _all_dtypes + errors = [] + for name in _all_dtypes: + dtype = getattr(T, name) + assert isinstance(dtype, tvm.DataType), f"{dtype} is not tvm.DataType" + try: + dtype(1.0) + dtype() + except TypeError as e: + pass + except Exception as e: + errors.append(name) + assert not errors + + +def test_var_decl_sugar(): + @T.prim_func + def test_var_decl_sugar(): + with T.Kernel(128, 128) as (bx, by): + var_1: T.bool = 1.0 + var_2: T.short = 1.0 + var_3: T.int = 1.0 + var_4: T.long = 1.0 + var_5: T.half = 1.0 + var_6: T.float = 1.0 + var_7: T.long = 1.0 + var_8: T.int8 = 1.0 + var_9: T.int16 = 1.0 + var_10: T.int32 = 1.0 + var_11: T.int64 = 1.0 + var_12: T.uint8 = 1.0 + var_13: T.uint16 = 1.0 + var_14: T.uint32 = 1.0 + var_15: T.uint64 = 1.0 + var_16: T.float8_e4m3fn = 1.0 + var_17: T.float8_e4m3fnuz = 1.0 + var_18: T.float8_e5m2 = 1.0 + var_19: T.float8_e5m2fnuz = 1.0 + var_20: T.float8_e8m0fnu = 1.0 + var_21: T.float16 = 1.0 + var_22: T.bfloat16 = 1.0 + var_23: T.float32 = 1.0 + var_24: T.float64 = 1.0 + var_1: T.bool = var_1 + var_2: T.short = var_2 + var_3: T.int = var_3 + var_4: T.long = var_4 + var_5: T.half = var_5 + var_6: T.float = var_6 + var_7: T.long = var_7 + var_8: T.int8 = var_8 + var_9: T.int16 = var_9 + var_10: T.int32 = var_10 + var_11: T.int64 = var_11 + var_12: T.uint8 = var_12 + var_13: T.uint16 = var_13 + var_14: T.uint32 = var_14 + var_15: T.uint64 = var_15 + var_16: T.float8_e4m3fn = var_16 + var_17: T.float8_e4m3fnuz = var_17 + var_18: T.float8_e5m2 = var_18 + var_19: T.float8_e5m2fnuz = var_19 + var_20: T.float8_e8m0fnu = var_20 + var_21: T.float16 = var_21 + var_22: T.bfloat16 = var_22 + var_23: T.float32 = var_23 + var_24: T.float64 = var_24 + + s = test_var_decl_sugar.script() + for i in range(1, 25): + assert f'var_{i}_1' in s + assert f'tl.local_var_init' in s + +def test_dtype_str_repr(): + @T.prim_func + def test_str_repr(): + buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') + buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') + buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') + buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') + buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') + buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') + buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') + buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') + buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') + buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') + buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') + buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') + buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') + buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') + buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') + buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') + buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') + buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') + buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') + buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') + buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') + buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') + buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') + buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') + +def test_torch_eq(): + dtypes = [ + T.bool, + T.short, + T.int, + T.long, + T.half, + T.float, + T.long, + T.int8, + T.int16, + T.int32, + T.int64, + T.uint8, + T.uint16, + T.uint32, + T.uint64, + T.float8_e4m3fn, + T.float8_e4m3fnuz, + T.float8_e5m2, + T.float8_e5m2fnuz, + T.float8_e8m0fnu, + T.float16, + T.bfloat16, + T.float32, + T.float64, + ] + torch_dtypes = [ + torch.bool, + torch.short, + torch.int, + torch.long, + torch.half, + torch.float, + torch.long, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + torch.float8_e8m0fnu, + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + ] + for a, b in zip(dtypes, torch_dtypes): + assert a == b, f"{a} and {b} are not equal" + + +def test_var_assign(): + @tilelang.jit(out_idx=-1) + @T.prim_func + def test_var_assign(A: T.Tensor((2,), T.int32)): + with T.Kernel(1) as _: + a: T.int32 = 1 + b: T.int32 = a + a = 2 + d: T.int32 = a + A[0] = b + A[1] = d + res = test_var_assign()() + assert res[0] == 1 + assert res[1] == 2 + + +if __name__ == '__main__': + tilelang.testing.main() diff --git a/testing/python/transform/test_tilelang_transform_multi_version_buffer.py b/testing/python/transform/test_tilelang_transform_multi_version_buffer.py index 6c9b5c539..ddb7f6662 100644 --- a/testing/python/transform/test_tilelang_transform_multi_version_buffer.py +++ b/testing/python/transform/test_tilelang_transform_multi_version_buffer.py @@ -113,7 +113,7 @@ def before(scales: T.Tensor((4,), "float32")): shared = T.alloc_buffer((8,), "float32", scope="shared.dyn") accum = T.alloc_buffer((8,), "float32", scope="local") for k in T.serial(4, annotations={"num_stages": T.int32(2)}): - value: T.float32 = scales[k] + value = scales[k] for i in T.serial(8): shared[i] = value for i in T.serial(8): @@ -125,7 +125,7 @@ def after(scales: T.Tensor((4,), "float32")): shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn") accum = T.alloc_buffer((8,), "float32", scope="local") for k in T.serial(4, annotations={"num_stages": T.int32(2)}): - value: T.float32 = scales[k] + value = scales[k] for i in T.serial(8): shared[k % 2, i] = value for i in T.serial(8): diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 18a59dfcc..d64ea7967 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -18,8 +18,8 @@ Literal, ) from tilelang import tvm as tvm +from tilelang.language.v2 import PrimFunc from tilelang.jit.adapter.utils import is_metal_target -from tvm.tir import PrimFunc from tvm.target import Target from tilelang.jit.kernel import JITKernel @@ -34,9 +34,13 @@ logger = getLogger(__name__) +_P = ParamSpec('_P') +_KP = ParamSpec('_KP') +_T = TypeVar('_T') + def compile( - func: PrimFunc = None, + func: PrimFunc[_KP, _T] = None, out_idx: list[int] | int | None = None, execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", target: str | Target = "auto", @@ -44,7 +48,7 @@ def compile( verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | str | None = None, -) -> JITKernel: +) -> JITKernel[_KP, _T]: """ Compile the given TileLang PrimFunc with TVM and build a JITKernel. Parameters @@ -87,7 +91,7 @@ def compile( ) -def par_compile(funcs: Iterable[PrimFunc], +def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], out_idx: list[int] | int | None = None, execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", target: str | Target = "auto", @@ -96,7 +100,7 @@ def par_compile(funcs: Iterable[PrimFunc], pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | str | None = None, num_workers: int = None, - ignore_error: bool = False) -> list[JITKernel]: + ignore_error: bool = False) -> list[JITKernel[_KP, _T]]: """ Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels. Parameters @@ -153,13 +157,9 @@ def par_compile(funcs: Iterable[PrimFunc], return results -_P = ParamSpec('_P') -_T = TypeVar('_T') - - @dataclass -class JITImpl(Generic[_P, _T]): - func: Callable[_P, _T] +class JITImpl(Generic[_P, _KP, _T]): + func: Callable[_P, _T] | PrimFunc[_KP, _T] out_idx: list[int] | int | None execution_backend: Literal["dlpack", "ctypes", "cython"] target: str | Target @@ -180,7 +180,7 @@ def __post_init__(self): self.debug_root_path = path.abspath(self.debug_root_path) self._kernel_cache: dict[tuple, Kernel] = {} - def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc: + def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_KP, _T]: program_result_source = self.func if isinstance(program_result_source, PrimFunc): program_result = program_result_source @@ -193,7 +193,7 @@ def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc: def par_compile(self, configs: Iterable[dict[str, Any] | tuple[str, Any]], num_workers: int = None, - ignore_error: bool = False) -> list[JITKernel]: + ignore_error: bool = False) -> list[JITKernel[_KP, _T]]: configs = list(configs) funcs = [] for cfg in tqdm(configs, desc='Elaborating'): @@ -215,7 +215,7 @@ def par_compile(self, num_workers=num_workers, ignore_error=ignore_error) - def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel: + def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]: func = self.get_tir(*args, **kwargs) kernel_result = compile( func, @@ -243,7 +243,7 @@ def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel: return kernel_result - def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel: + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]: # Separate out the tuning parameters from the user's kwargs tune_params = kwargs.pop('__tune_params', {}) # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache @@ -272,22 +272,22 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel: @overload -def jit(func: Callable[_P, _T]) -> JITImpl[_P, _T]: +def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T]: ... @overload def jit( - *, # Indicates subsequent arguments are keyword-only - out_idx: Any = None, - target: str | Target = "auto", - target_host: str | Target = None, - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None -) -> Callable[[Callable[_P, _T]], JITImpl[_P, _T]]: + *, # Indicates subsequent arguments are keyword-only + out_idx: Any = None, + target: str | Target = "auto", + target_host: str | Target = None, + execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + debug_root_path: str | None = None, + compile_flags: list[str] | str | None = None +) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T]]: ... @@ -337,6 +337,10 @@ def jit( # This is the new public interface compile_flags = [compile_flags] def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]: + if isinstance(func, PrimFunc): + orig_func = func.orig_func + else: + orig_func = func return JITImpl( func, out_idx=out_idx, @@ -347,11 +351,11 @@ def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]: pass_configs=pass_configs, debug_root_path=debug_root_path, compile_flags=compile_flags, - func_source=inspect.getsource(func), - signature=inspect.signature(func), + func_source=inspect.getsource(orig_func), + signature=inspect.signature(orig_func), ) - if callable(func): + if func is not None: return decorator(func) else: return decorator diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 7fe307bfd..b560ef8bd 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Callable, Literal +from typing import Any, Callable, Generic, Literal, ParamSpec, TypeVar from tilelang.jit.adapter.utils import is_metal_target from tvm.target import Target @@ -17,8 +17,11 @@ logger = logging.getLogger(__name__) +_P = ParamSpec('_P') +_T = TypeVar('_T') -class JITKernel: + +class JITKernel(Generic[_P, _T]): """ A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions. @@ -170,7 +173,7 @@ def from_database( instance.torch_function = instance.adapter.func return instance - def __call__(self, *args: Any, **kwds: Any) -> Any: + def __call__(self, *args: _P.args, **kwds: _P.kwargs) -> _T: """ Invokes the compiled function with the given arguments. diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 16af88ff6..114d7b715 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -10,7 +10,7 @@ from . import overrides as _overrides # noqa: F401 # from .tir import prim_func, macro, # noqa: F401 -from .v2 import prim_func, macro # noqa: F401 +from .v2 import * # noqa: F401 from .tir.ir import * # noqa: F401 from tilelang.layout import Layout, Fragment # noqa: F401 from .proxy import ( diff --git a/tilelang/language/v2/__init__.py b/tilelang/language/v2/__init__.py index 23b907b37..b86b378ae 100644 --- a/tilelang/language/v2/__init__.py +++ b/tilelang/language/v2/__init__.py @@ -1 +1,2 @@ -from .builder import prim_func, macro # noqa: F401 +from .builder import prim_func, macro, PrimFunc # noqa: F401 +from .dtypes import * diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 5068186cd..0ddef17d9 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -13,7 +13,7 @@ from tvm.tir import Buffer from tvm.script.ir_builder import tir, IRBuilder from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, Var -from typing import Callable, ContextManager, Any, Generic, Hashable, ParamSpec, Self, TypeVar, get_type_hints +from typing import TYPE_CHECKING, Callable, ContextManager, Any, Generic, Hashable, ParamSpec, Self, TypeVar, ForwardRef from .dtypes import get_tvm_dtype from types import EllipsisType import threading @@ -333,9 +333,8 @@ def ret(self, value): frame = self.find_frame_idx(TIR_CONTROL_FRAME, start=last_macro) if frame is not None: raise NotImplementedError( - "Return from control flow is not supported yet. " - "You can't return inside `if`, `for`, `while` blocks in a macro. " - "You should allocate a var before the control flow, assign value inside the blocks, " + "Return from control flow is not supported yet. \n" + "You should allocate a var before the control flow, assign value inside the blocks, \n" "and return the var after the control flow. i.e.\n" "```\n" "@T.macro\n" \ @@ -407,16 +406,20 @@ class IRGenerator(Generic[_P, _T]): source: str -class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): - params: list[tvm.tir.Var | tvm.tir.Buffer] - body: tvm.tir.Stmt - ret_type: tvm.ir.Type - buffer_map: Map[tvm.tir.Var, tvm.tir.Buffer] - attrs: tvm.Attrs | None - span: Span | None - ir_gen: IRGenerator[_P, _T] - source: str - orig_func: Callable[_P, _T] +if TYPE_CHECKING: + + class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): + params: list[tvm.tir.Var | tvm.tir.Buffer] + body: tvm.tir.Stmt + ret_type: tvm.ir.Type + buffer_map: Map[tvm.tir.Var, tvm.tir.Buffer] + attrs: tvm.Attrs | None + span: Span | None + ir_gen: IRGenerator[_P, _T] | None + source: str | None + orig_func: Callable[_P, _T] | None +else: + PrimFunc = tvm.tir.PrimFunc @dataclass @@ -446,20 +449,33 @@ def macro(func: Callable[_P, _T]) -> Macro[_P, _T]: return Macro(name=func.__name__, orig_func=func, ir_gen=build_ir_generator(func)) +from typing import _eval_type + + +def get_type_hints(func): + annot = getattr(func, '__annotations__', None) + if annot is None: + raise TypeError(f'Failed to get function type hints, {func} is not a function') + hints = {} + type_params = getattr(func, "__type_params__", ()) + globalns = getattr(func, '__globals__', {}) + localns = globalns + for name, value in annot.items(): + if isinstance(value, tvm.DataType): + hints[name] = value + continue + if value is None: + value = type(None) + if isinstance(value, str): + value = ForwardRef(value, is_argument=True, is_class=False) + + hints[name] = _eval_type(value, globalns=globalns, localns=localns, type_params=type_params) + return hints + + def prim_func(func: Callable[_P, _T]) -> PrimFunc[_P, _T]: sig = inspect.signature(func) - annot = func.__annotations__ - if any(map(lambda x: isinstance(x, str), annot)): - try: - annot = get_type_hints(func) - except Exception as e: - raise RuntimeError( - f"Failed to get type hints for function `{func.__name__}`. \n" - "Note: if you are using `from __future__ import annotations`, type hints may be missing, \n" - "To fix this, please use default argument instead of type annotations: \n" - "```py\n" - "def foo(a=tl.Tensor((128, 128), 'float32'), b=tl.float32()): ..." - "```") from e + annot = get_type_hints(func) args = [] kwargs = {} for name, param in sig.parameters.items(): diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index 37f9d44d0..ff3388dbd 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -14,36 +14,37 @@ class VoidPtr: AnyDType = ir.Type | str | type | torch.dtype | tvm.DataType _dtype_cvt = [ - (None, 'handle', ctypes.c_long, 'long'), # use long to repr void* - (bool, 'bool', ctypes.c_bool, 'bool'), - (int, 'int32', ctypes.c_int32, 'int'), - (float, 'float32', ctypes.c_float, 'float'), - (torch.short, 'int16', ctypes.c_int16, 'short'), - (torch.int, 'int32', ctypes.c_int32, 'int'), - (torch.long, 'int64', ctypes.c_int64, 'long long'), - (torch.half, 'float16', None, None), - (torch.float, 'float32', ctypes.c_float, 'float'), - (torch.double, 'float64', ctypes.c_double, 'double'), + (None, 'handle', ctypes.c_long, 'long', None), # use long to repr void* + (bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'), + (int, 'int32', ctypes.c_int32, 'int', 'Int32'), + (float, 'float32', ctypes.c_float, 'float', 'Float32'), + (torch.short, 'int16', ctypes.c_int16, 'short', 'Int16'), + (torch.int, 'int32', ctypes.c_int32, 'int', 'Int32'), + (torch.long, 'int64', ctypes.c_int64, 'long long', 'Int64'), + (torch.half, 'float16', None, None, 'Float16'), + (torch.float, 'float32', ctypes.c_float, 'float', 'Float32'), + (torch.double, 'float64', ctypes.c_double, 'double', 'Float64'), # (pytype, 'tvm dtype str', 'ctypes dtype', 'cffi dtype') - (torch.bool, 'bool', ctypes.c_bool, 'bool'), - (torch.int8, 'int8', ctypes.c_int8, 'char'), - (torch.int16, 'int16', ctypes.c_int16, 'short'), - (torch.int32, 'int32', ctypes.c_int32, 'int'), - (torch.int64, 'int64', ctypes.c_int64, 'long long'), - (torch.uint8, 'uint8', ctypes.c_uint8, 'unsigned char'), - (torch.uint16, 'uint16', ctypes.c_uint16, 'unsigned short'), - (torch.uint32, 'uint32', ctypes.c_uint32, 'unsigned int'), - (torch.uint64, 'uint64', ctypes.c_uint64, 'unsigned long long'), - (torch.float16, 'float16', None, None), - (torch.float32, 'float32', ctypes.c_float, 'float'), - (torch.float64, 'float64', ctypes.c_double, 'double'), - (torch.float8_e4m3fn, 'float8_e4m3fn', None, None), - (torch.float8_e4m3fnuz, 'float8_e4m3fnuz', None, None), - (torch.float8_e5m2, 'float8_e5m2', None, None), - (torch.float8_e5m2fnuz, 'float8_e5m2fnuz', None, None), - (torch.float8_e8m0fnu, 'float8_e8m0fnu', None, None), - (torch.bfloat16, 'bfloat16', None, None), + (torch.bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'), + (torch.int8, 'int8', ctypes.c_int8, 'char', 'Int8'), + (torch.int16, 'int16', ctypes.c_int16, 'short', 'Int16'), + (torch.int32, 'int32', ctypes.c_int32, 'int', 'Int32'), + (torch.int64, 'int64', ctypes.c_int64, 'long long', 'Int64'), + (torch.uint8, 'uint8', ctypes.c_uint8, 'unsigned char', 'UInt8'), + (torch.uint16, 'uint16', ctypes.c_uint16, 'unsigned short', 'UInt16'), + (torch.uint32, 'uint32', ctypes.c_uint32, 'unsigned int', 'UInt32'), + (torch.uint64, 'uint64', ctypes.c_uint64, 'unsigned long long', 'UInt64'), + (torch.float16, 'float16', None, None, 'Float16'), + (torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'), + (torch.float64, 'float64', ctypes.c_double, 'double', 'Float64'), + (None, 'float8_e4m3', None, None, 'Float8E4M3'), + (torch.float8_e4m3fn, 'float8_e4m3fn', None, None, 'Float8E4M3FN'), + (torch.float8_e4m3fnuz, 'float8_e4m3fnuz', None, None, 'Float8E4M3FNUZ'), + (torch.float8_e5m2, 'float8_e5m2', None, None, 'Float8E5M2'), + (torch.float8_e5m2fnuz, 'float8_e5m2fnuz', None, None, 'Float8E5M2FNUZ'), + (torch.float8_e8m0fnu, 'float8_e8m0fnu', None, None, 'Float8E8M0FNU'), + (torch.bfloat16, 'bfloat16', None, None, 'BFloat16'), ] @@ -55,144 +56,546 @@ def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x): } +_dtype_py2tvmstr = _create_type_mapper(0, 1) +_dtype_tvmstr2fficall = _create_type_mapper(1, 4, dmapper=lambda x: getattr(tb_ffi, x)) _dtype_tvm2py = _create_type_mapper(1, 0, lambda x: tvm.DataType(x)) _dtype_tvm2ctype = _create_type_mapper(1, 2, lambda x: tvm.DataType(x)) _dtype_tvm2cffi = _create_type_mapper(1, 3, lambda x: tvm.DataType(x)) -class dtype: - __cvt = _create_type_mapper(0, 1) - - def __init__(self, value: AnyDType): - if isinstance(value, dtype): - value = value.name - if not isinstance(value, str): - if value not in self.__cvt: - raise TypeError( - f"Unsupported dtype: {value}, expected one of {list(self.__cvt.keys())}") - value = self.__cvt[value] - self.name = value - - def __eq__(self, other: AnyDType): - if isinstance(other, str): - return str.__eq__(self.name, other) - if other in self.__cvt: - return str.__eq__(self.name, self.__cvt[other]) - return NotImplemented - - def __req__(self, other: AnyDType): - if isinstance(other, str): - return str.__eq__(self.name, other) - if other in self.__cvt: - return str.__eq__(self.name, self.__cvt[other]) - return NotImplemented - - def __ne__(self, other: AnyDType): - if isinstance(other, str): - return str.__ne__(self.name, other) - if other in self.__cvt: - return str.__ne__(self.name, self.__cvt[other]) - return NotImplemented - - def __rne__(self, other: AnyDType): - if isinstance(other, str): - return str.__ne__(self.name, other) - if other in self.__cvt: - return str.__ne__(self.name, self.__cvt[other]) - return NotImplemented - - def __repr__(self): - return f"dtype({str.__repr__(self.name)})" - - def __hash__(self): - return str.__hash__(self.name) - - def __call__(self, expr=None, is_size_var: bool = False) -> tir.Var: - return getattr(tb_ffi, self.name.title())(expr, is_size_var) - - def get_tvm_dtype(self) -> tvm.DataType: - return tvm.DataType(self.name) +def __dtype_eq__(self: tvm.DataType, other: AnyDType): + if isinstance(other, str): + return str.__eq__(self, other) + if other in _dtype_py2tvmstr: + return str.__eq__(self, _dtype_py2tvmstr[other]) + return NotImplemented + + +def __dtype_ne__(self: tvm.DataType, other: AnyDType): + if isinstance(other, str): + return str.__ne__(self, other) + if other in _dtype_py2tvmstr: + return str.__ne__(self, _dtype_py2tvmstr[other]) + return NotImplemented + + +def __dtype_call__(self: tvm.DataType, expr=None, is_size_var: bool = False) -> tir.Var: + if self in _dtype_tvmstr2fficall: + return _dtype_tvmstr2fficall[self](expr, is_size_var) + # try to construct the ffi call + if self.startswith('uint'): + val = 'UInt' + self[4:] + elif self.startswith('int'): + val = 'Int' + self[3:] + elif self.startswith('float'): + val = 'Float' + self[5:] + elif self.startswith('bfloat'): + val = 'BFloat' + self[6:] + else: + raise TypeError(f'Invalid type {self}') + if '_' in val: + first, second = val.split('_', maxsplit=1) + val = first + second.upper() + call = getattr(tb_ffi, val, None) + if call is None: + raise TypeError(f'Convert to datatype `{self}` is not supported by tvm, calling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`') + return call(expr, is_size_var) + + +def __dtype_new__(cls, value: AnyDType) -> tvm.DataType: + if isinstance(value, str): + val = str.__new__(cls, value) + elif value in _dtype_py2tvmstr: + val = str.__new__(cls, _dtype_py2tvmstr[value]) + else: + expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.values())) + raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}") + val.__tvm_ffi_dtype__ = tvm.ffi.core.DataType(val) + return val + + +tvm.DataType.__eq__ = __dtype_eq__ +tvm.DataType.__req__ = __dtype_eq__ +tvm.DataType.__ne__ = __dtype_ne__ +tvm.DataType.__rne__ = __dtype_ne__ +tvm.DataType.__call__ = __dtype_call__ +tvm.DataType.__new__ = __dtype_new__ def get_tvm_dtype(value: AnyDType) -> tvm.DataType: if isinstance(value, (tvm.DataType, ir.Type)): return value - if isinstance(value, dtype): - return value.get_tvm_dtype() - return dtype(value).get_tvm_dtype() + return tvm.DataType(value) if TYPE_CHECKING: - - class int8(dtype): - ... - - class int16(dtype): - ... - - class int32(dtype): - ... - - class int64(dtype): - ... - - class uint8(dtype): - ... - - class uint16(dtype): - ... - - class uint32(dtype): - ... - - class uint64(dtype): - ... - - class float16(dtype): - ... - - class float32(dtype): - ... - - class float64(dtype): - ... - - class bool(dtype): - ... - - class float8_e4m3fn(dtype): - ... - - class float8_e4m3fnuz(dtype): - ... - - class float8_e5m2(dtype): - ... - - class float8_e5m2fnuz(dtype): - ... - - class float8_e8m0fnu(dtype): - ... - - class bfloat16(dtype): - ... + class bool(tvm.DataType): ... + class short(tvm.DataType): ... + class int(tvm.DataType): ... + class long(tvm.DataType): ... + class half(tvm.DataType): ... + class float(tvm.DataType): ... + class double(tvm.DataType): ... + class int8(tvm.DataType): ... + class int16(tvm.DataType): ... + class int32(tvm.DataType): ... + class int64(tvm.DataType): ... + class int8x4(tvm.DataType): ... + class int16x4(tvm.DataType): ... + class int32x4(tvm.DataType): ... + class int64x4(tvm.DataType): ... + class int8x8(tvm.DataType): ... + class int16x8(tvm.DataType): ... + class int32x8(tvm.DataType): ... + class int64x8(tvm.DataType): ... + class int8x16(tvm.DataType): ... + class int16x16(tvm.DataType): ... + class int32x16(tvm.DataType): ... + class int64x16(tvm.DataType): ... + class int8x32(tvm.DataType): ... + class int16x32(tvm.DataType): ... + class int32x32(tvm.DataType): ... + class int64x32(tvm.DataType): ... + class int8x64(tvm.DataType): ... + class int16x64(tvm.DataType): ... + class int32x64(tvm.DataType): ... + class int64x64(tvm.DataType): ... + class uint8(tvm.DataType): ... + class uint16(tvm.DataType): ... + class uint32(tvm.DataType): ... + class uint64(tvm.DataType): ... + class uint8x4(tvm.DataType): ... + class uint16x4(tvm.DataType): ... + class uint32x4(tvm.DataType): ... + class uint64x4(tvm.DataType): ... + class uint8x8(tvm.DataType): ... + class uint16x8(tvm.DataType): ... + class uint32x8(tvm.DataType): ... + class uint64x8(tvm.DataType): ... + class uint8x16(tvm.DataType): ... + class uint16x16(tvm.DataType): ... + class uint32x16(tvm.DataType): ... + class uint64x16(tvm.DataType): ... + class uint8x32(tvm.DataType): ... + class uint16x32(tvm.DataType): ... + class uint32x32(tvm.DataType): ... + class uint64x32(tvm.DataType): ... + class uint8x64(tvm.DataType): ... + class uint16x64(tvm.DataType): ... + class uint32x64(tvm.DataType): ... + class uint64x64(tvm.DataType): ... + class float16(tvm.DataType): ... + class float32(tvm.DataType): ... + class float64(tvm.DataType): ... + class float16x2(tvm.DataType): ... + class float32x2(tvm.DataType): ... + class float64x2(tvm.DataType): ... + class float16x4(tvm.DataType): ... + class float32x4(tvm.DataType): ... + class float64x4(tvm.DataType): ... + class float16x8(tvm.DataType): ... + class float32x8(tvm.DataType): ... + class float64x8(tvm.DataType): ... + class float16x16(tvm.DataType): ... + class float32x16(tvm.DataType): ... + class float64x16(tvm.DataType): ... + class float16x32(tvm.DataType): ... + class float32x32(tvm.DataType): ... + class float64x32(tvm.DataType): ... + class float16x64(tvm.DataType): ... + class float32x64(tvm.DataType): ... + class float64x64(tvm.DataType): ... + class float8_e3m4(tvm.DataType): ... + class float8_e3m4x2(tvm.DataType): ... + class float8_e3m4x4(tvm.DataType): ... + class float8_e3m4x8(tvm.DataType): ... + class float8_e3m4x16(tvm.DataType): ... + class float8_e3m4x32(tvm.DataType): ... + class float8_e3m4x64(tvm.DataType): ... + class float8_e4m3(tvm.DataType): ... + class float8_e4m3x2(tvm.DataType): ... + class float8_e4m3x4(tvm.DataType): ... + class float8_e4m3x8(tvm.DataType): ... + class float8_e4m3x16(tvm.DataType): ... + class float8_e4m3x32(tvm.DataType): ... + class float8_e4m3x64(tvm.DataType): ... + class float8_e4m3b11fnuz(tvm.DataType): ... + class float8_e4m3b11fnuzx2(tvm.DataType): ... + class float8_e4m3b11fnuzx4(tvm.DataType): ... + class float8_e4m3b11fnuzx8(tvm.DataType): ... + class float8_e4m3b11fnuzx16(tvm.DataType): ... + class float8_e4m3b11fnuzx32(tvm.DataType): ... + class float8_e4m3b11fnuzx64(tvm.DataType): ... + class float8_e4m3fn(tvm.DataType): ... + class float8_e4m3fnx2(tvm.DataType): ... + class float8_e4m3fnx4(tvm.DataType): ... + class float8_e4m3fnx8(tvm.DataType): ... + class float8_e4m3fnx16(tvm.DataType): ... + class float8_e4m3fnx32(tvm.DataType): ... + class float8_e4m3fnx64(tvm.DataType): ... + class float8_e4m3fnuz(tvm.DataType): ... + class float8_e4m3fnuzx2(tvm.DataType): ... + class float8_e4m3fnuzx4(tvm.DataType): ... + class float8_e4m3fnuzx8(tvm.DataType): ... + class float8_e4m3fnuzx16(tvm.DataType): ... + class float8_e4m3fnuzx32(tvm.DataType): ... + class float8_e4m3fnuzx64(tvm.DataType): ... + class float8_e5m2(tvm.DataType): ... + class float8_e5m2x2(tvm.DataType): ... + class float8_e5m2x4(tvm.DataType): ... + class float8_e5m2x8(tvm.DataType): ... + class float8_e5m2x16(tvm.DataType): ... + class float8_e5m2x32(tvm.DataType): ... + class float8_e5m2x64(tvm.DataType): ... + class float8_e5m2fnuz(tvm.DataType): ... + class float8_e5m2fnuzx2(tvm.DataType): ... + class float8_e5m2fnuzx4(tvm.DataType): ... + class float8_e5m2fnuzx8(tvm.DataType): ... + class float8_e5m2fnuzx16(tvm.DataType): ... + class float8_e5m2fnuzx32(tvm.DataType): ... + class float8_e5m2fnuzx64(tvm.DataType): ... + class float8_e8m0fnu(tvm.DataType): ... + class float8_e8m0fnux2(tvm.DataType): ... + class float8_e8m0fnux4(tvm.DataType): ... + class float8_e8m0fnux8(tvm.DataType): ... + class float8_e8m0fnux16(tvm.DataType): ... + class float8_e8m0fnux32(tvm.DataType): ... + class float8_e8m0fnux64(tvm.DataType): ... + class float6_e2m3fn(tvm.DataType): ... + class float6_e2m3fnx2(tvm.DataType): ... + class float6_e2m3fnx4(tvm.DataType): ... + class float6_e2m3fnx8(tvm.DataType): ... + class float6_e2m3fnx16(tvm.DataType): ... + class float6_e2m3fnx32(tvm.DataType): ... + class float6_e2m3fnx64(tvm.DataType): ... + class float6_e3m2fn(tvm.DataType): ... + class float6_e3m2fnx2(tvm.DataType): ... + class float6_e3m2fnx4(tvm.DataType): ... + class float6_e3m2fnx8(tvm.DataType): ... + class float6_e3m2fnx16(tvm.DataType): ... + class float6_e3m2fnx32(tvm.DataType): ... + class float6_e3m2fnx64(tvm.DataType): ... + class float4_e2m1fn(tvm.DataType): ... + class float4_e2m1fnx2(tvm.DataType): ... + class float4_e2m1fnx4(tvm.DataType): ... + class float4_e2m1fnx8(tvm.DataType): ... + class float4_e2m1fnx16(tvm.DataType): ... + class float4_e2m1fnx32(tvm.DataType): ... + class float4_e2m1fnx64(tvm.DataType): ... + class bfloat16(tvm.DataType): ... else: - int8 = dtype('int8') - int16 = dtype('int16') - int32 = dtype('int32') - int64 = dtype('int64') - uint8 = dtype('uint8') - uint16 = dtype('uint16') - uint32 = dtype('uint32') - uint64 = dtype('uint64') - float16 = dtype('float16') - float32 = dtype('float32') - float64 = dtype('float64') - bool = dtype('bool') - float8_e4m3fn = dtype('float8_e4m3fn') - float8_e4m3fnuz = dtype('float8_e4m3fnuz') - float8_e5m2 = dtype('float8_e5m2') - float8_e5m2fnuz = dtype('float8_e5m2fnuz') - float8_e8m0fnu = dtype('float8_e8m0fnu') - bfloat16 = dtype('bfloat16') + bool = tvm.DataType('bool') + short = tvm.DataType('int16') + int = tvm.DataType('int32') + long = tvm.DataType('int64') + half = tvm.DataType('float16') + float = tvm.DataType('float32') + double = tvm.DataType('float64') + int8 = tvm.DataType('int8') + int16 = tvm.DataType('int16') + int32 = tvm.DataType('int32') + int64 = tvm.DataType('int64') + int8x4 = tvm.DataType('int8x4') + int16x4 = tvm.DataType('int16x4') + int32x4 = tvm.DataType('int32x4') + int64x4 = tvm.DataType('int64x4') + int8x8 = tvm.DataType('int8x8') + int16x8 = tvm.DataType('int16x8') + int32x8 = tvm.DataType('int32x8') + int64x8 = tvm.DataType('int64x8') + int8x16 = tvm.DataType('int8x16') + int16x16 = tvm.DataType('int16x16') + int32x16 = tvm.DataType('int32x16') + int64x16 = tvm.DataType('int64x16') + int8x32 = tvm.DataType('int8x32') + int16x32 = tvm.DataType('int16x32') + int32x32 = tvm.DataType('int32x32') + int64x32 = tvm.DataType('int64x32') + int8x64 = tvm.DataType('int8x64') + int16x64 = tvm.DataType('int16x64') + int32x64 = tvm.DataType('int32x64') + int64x64 = tvm.DataType('int64x64') + uint8 = tvm.DataType('uint8') + uint16 = tvm.DataType('uint16') + uint32 = tvm.DataType('uint32') + uint64 = tvm.DataType('uint64') + uint8x4 = tvm.DataType('uint8x4') + uint16x4 = tvm.DataType('uint16x4') + uint32x4 = tvm.DataType('uint32x4') + uint64x4 = tvm.DataType('uint64x4') + uint8x8 = tvm.DataType('uint8x8') + uint16x8 = tvm.DataType('uint16x8') + uint32x8 = tvm.DataType('uint32x8') + uint64x8 = tvm.DataType('uint64x8') + uint8x16 = tvm.DataType('uint8x16') + uint16x16 = tvm.DataType('uint16x16') + uint32x16 = tvm.DataType('uint32x16') + uint64x16 = tvm.DataType('uint64x16') + uint8x32 = tvm.DataType('uint8x32') + uint16x32 = tvm.DataType('uint16x32') + uint32x32 = tvm.DataType('uint32x32') + uint64x32 = tvm.DataType('uint64x32') + uint8x64 = tvm.DataType('uint8x64') + uint16x64 = tvm.DataType('uint16x64') + uint32x64 = tvm.DataType('uint32x64') + uint64x64 = tvm.DataType('uint64x64') + float16 = tvm.DataType('float16') + float32 = tvm.DataType('float32') + float64 = tvm.DataType('float64') + float16x2 = tvm.DataType('float16x2') + float32x2 = tvm.DataType('float32x2') + float64x2 = tvm.DataType('float64x2') + float16x4 = tvm.DataType('float16x4') + float32x4 = tvm.DataType('float32x4') + float64x4 = tvm.DataType('float64x4') + float16x8 = tvm.DataType('float16x8') + float32x8 = tvm.DataType('float32x8') + float64x8 = tvm.DataType('float64x8') + float16x16 = tvm.DataType('float16x16') + float32x16 = tvm.DataType('float32x16') + float64x16 = tvm.DataType('float64x16') + float16x32 = tvm.DataType('float16x32') + float32x32 = tvm.DataType('float32x32') + float64x32 = tvm.DataType('float64x32') + float16x64 = tvm.DataType('float16x64') + float32x64 = tvm.DataType('float32x64') + float64x64 = tvm.DataType('float64x64') + float8_e3m4 = tvm.DataType('float8_e3m4') + float8_e3m4x2 = tvm.DataType('float8_e3m4x2') + float8_e3m4x4 = tvm.DataType('float8_e3m4x4') + float8_e3m4x8 = tvm.DataType('float8_e3m4x8') + float8_e3m4x16 = tvm.DataType('float8_e3m4x16') + float8_e3m4x32 = tvm.DataType('float8_e3m4x32') + float8_e3m4x64 = tvm.DataType('float8_e3m4x64') + float8_e4m3 = tvm.DataType('float8_e4m3') + float8_e4m3x2 = tvm.DataType('float8_e4m3x2') + float8_e4m3x4 = tvm.DataType('float8_e4m3x4') + float8_e4m3x8 = tvm.DataType('float8_e4m3x8') + float8_e4m3x16 = tvm.DataType('float8_e4m3x16') + float8_e4m3x32 = tvm.DataType('float8_e4m3x32') + float8_e4m3x64 = tvm.DataType('float8_e4m3x64') + float8_e4m3b11fnuz = tvm.DataType('float8_e4m3b11fnuz') + float8_e4m3b11fnuzx2 = tvm.DataType('float8_e4m3b11fnuzx2') + float8_e4m3b11fnuzx4 = tvm.DataType('float8_e4m3b11fnuzx4') + float8_e4m3b11fnuzx8 = tvm.DataType('float8_e4m3b11fnuzx8') + float8_e4m3b11fnuzx16 = tvm.DataType('float8_e4m3b11fnuzx16') + float8_e4m3b11fnuzx32 = tvm.DataType('float8_e4m3b11fnuzx32') + float8_e4m3b11fnuzx64 = tvm.DataType('float8_e4m3b11fnuzx64') + float8_e4m3fn = tvm.DataType('float8_e4m3fn') + float8_e4m3fnx2 = tvm.DataType('float8_e4m3fnx2') + float8_e4m3fnx4 = tvm.DataType('float8_e4m3fnx4') + float8_e4m3fnx8 = tvm.DataType('float8_e4m3fnx8') + float8_e4m3fnx16 = tvm.DataType('float8_e4m3fnx16') + float8_e4m3fnx32 = tvm.DataType('float8_e4m3fnx32') + float8_e4m3fnx64 = tvm.DataType('float8_e4m3fnx64') + float8_e4m3fnuz = tvm.DataType('float8_e4m3fnuz') + float8_e4m3fnuzx2 = tvm.DataType('float8_e4m3fnuzx2') + float8_e4m3fnuzx4 = tvm.DataType('float8_e4m3fnuzx4') + float8_e4m3fnuzx8 = tvm.DataType('float8_e4m3fnuzx8') + float8_e4m3fnuzx16 = tvm.DataType('float8_e4m3fnuzx16') + float8_e4m3fnuzx32 = tvm.DataType('float8_e4m3fnuzx32') + float8_e4m3fnuzx64 = tvm.DataType('float8_e4m3fnuzx64') + float8_e5m2 = tvm.DataType('float8_e5m2') + float8_e5m2x2 = tvm.DataType('float8_e5m2x2') + float8_e5m2x4 = tvm.DataType('float8_e5m2x4') + float8_e5m2x8 = tvm.DataType('float8_e5m2x8') + float8_e5m2x16 = tvm.DataType('float8_e5m2x16') + float8_e5m2x32 = tvm.DataType('float8_e5m2x32') + float8_e5m2x64 = tvm.DataType('float8_e5m2x64') + float8_e5m2fnuz = tvm.DataType('float8_e5m2fnuz') + float8_e5m2fnuzx2 = tvm.DataType('float8_e5m2fnuzx2') + float8_e5m2fnuzx4 = tvm.DataType('float8_e5m2fnuzx4') + float8_e5m2fnuzx8 = tvm.DataType('float8_e5m2fnuzx8') + float8_e5m2fnuzx16 = tvm.DataType('float8_e5m2fnuzx16') + float8_e5m2fnuzx32 = tvm.DataType('float8_e5m2fnuzx32') + float8_e5m2fnuzx64 = tvm.DataType('float8_e5m2fnuzx64') + float8_e8m0fnu = tvm.DataType('float8_e8m0fnu') + float8_e8m0fnux2 = tvm.DataType('float8_e8m0fnux2') + float8_e8m0fnux4 = tvm.DataType('float8_e8m0fnux4') + float8_e8m0fnux8 = tvm.DataType('float8_e8m0fnux8') + float8_e8m0fnux16 = tvm.DataType('float8_e8m0fnux16') + float8_e8m0fnux32 = tvm.DataType('float8_e8m0fnux32') + float8_e8m0fnux64 = tvm.DataType('float8_e8m0fnux64') + float6_e2m3fn = tvm.DataType('float6_e2m3fn') + float6_e2m3fnx2 = tvm.DataType('float6_e2m3fnx2') + float6_e2m3fnx4 = tvm.DataType('float6_e2m3fnx4') + float6_e2m3fnx8 = tvm.DataType('float6_e2m3fnx8') + float6_e2m3fnx16 = tvm.DataType('float6_e2m3fnx16') + float6_e2m3fnx32 = tvm.DataType('float6_e2m3fnx32') + float6_e2m3fnx64 = tvm.DataType('float6_e2m3fnx64') + float6_e3m2fn = tvm.DataType('float6_e3m2fn') + float6_e3m2fnx2 = tvm.DataType('float6_e3m2fnx2') + float6_e3m2fnx4 = tvm.DataType('float6_e3m2fnx4') + float6_e3m2fnx8 = tvm.DataType('float6_e3m2fnx8') + float6_e3m2fnx16 = tvm.DataType('float6_e3m2fnx16') + float6_e3m2fnx32 = tvm.DataType('float6_e3m2fnx32') + float6_e3m2fnx64 = tvm.DataType('float6_e3m2fnx64') + float4_e2m1fn = tvm.DataType('float4_e2m1fn') + float4_e2m1fnx2 = tvm.DataType('float4_e2m1fnx2') + float4_e2m1fnx4 = tvm.DataType('float4_e2m1fnx4') + float4_e2m1fnx8 = tvm.DataType('float4_e2m1fnx8') + float4_e2m1fnx16 = tvm.DataType('float4_e2m1fnx16') + float4_e2m1fnx32 = tvm.DataType('float4_e2m1fnx32') + float4_e2m1fnx64 = tvm.DataType('float4_e2m1fnx64') + bfloat16 = tvm.DataType('bfloat16') + +_all_dtypes = [ + 'bool', + 'short', + 'int', + 'long', + 'half', + 'float', + 'double', + 'int8', + 'int16', + 'int32', + 'int64', + 'int8x4', + 'int16x4', + 'int32x4', + 'int64x4', + 'int8x8', + 'int16x8', + 'int32x8', + 'int64x8', + 'int8x16', + 'int16x16', + 'int32x16', + 'int64x16', + 'int8x32', + 'int16x32', + 'int32x32', + 'int64x32', + 'int8x64', + 'int16x64', + 'int32x64', + 'int64x64', + 'uint8', + 'uint16', + 'uint32', + 'uint64', + 'uint8x4', + 'uint16x4', + 'uint32x4', + 'uint64x4', + 'uint8x8', + 'uint16x8', + 'uint32x8', + 'uint64x8', + 'uint8x16', + 'uint16x16', + 'uint32x16', + 'uint64x16', + 'uint8x32', + 'uint16x32', + 'uint32x32', + 'uint64x32', + 'uint8x64', + 'uint16x64', + 'uint32x64', + 'uint64x64', + 'float16', + 'float32', + 'float64', + 'float16x2', + 'float32x2', + 'float64x2', + 'float16x4', + 'float32x4', + 'float64x4', + 'float16x8', + 'float32x8', + 'float64x8', + 'float16x16', + 'float32x16', + 'float64x16', + 'float16x32', + 'float32x32', + 'float64x32', + 'float16x64', + 'float32x64', + 'float64x64', + 'float8_e3m4', + 'float8_e3m4x2', + 'float8_e3m4x4', + 'float8_e3m4x8', + 'float8_e3m4x16', + 'float8_e3m4x32', + 'float8_e3m4x64', + 'float8_e4m3', + 'float8_e4m3x2', + 'float8_e4m3x4', + 'float8_e4m3x8', + 'float8_e4m3x16', + 'float8_e4m3x32', + 'float8_e4m3x64', + 'float8_e4m3b11fnuz', + 'float8_e4m3b11fnuzx2', + 'float8_e4m3b11fnuzx4', + 'float8_e4m3b11fnuzx8', + 'float8_e4m3b11fnuzx16', + 'float8_e4m3b11fnuzx32', + 'float8_e4m3b11fnuzx64', + 'float8_e4m3fn', + 'float8_e4m3fnx2', + 'float8_e4m3fnx4', + 'float8_e4m3fnx8', + 'float8_e4m3fnx16', + 'float8_e4m3fnx32', + 'float8_e4m3fnx64', + 'float8_e4m3fnuz', + 'float8_e4m3fnuzx2', + 'float8_e4m3fnuzx4', + 'float8_e4m3fnuzx8', + 'float8_e4m3fnuzx16', + 'float8_e4m3fnuzx32', + 'float8_e4m3fnuzx64', + 'float8_e5m2', + 'float8_e5m2x2', + 'float8_e5m2x4', + 'float8_e5m2x8', + 'float8_e5m2x16', + 'float8_e5m2x32', + 'float8_e5m2x64', + 'float8_e5m2fnuz', + 'float8_e5m2fnuzx2', + 'float8_e5m2fnuzx4', + 'float8_e5m2fnuzx8', + 'float8_e5m2fnuzx16', + 'float8_e5m2fnuzx32', + 'float8_e5m2fnuzx64', + 'float8_e8m0fnu', + 'float8_e8m0fnux2', + 'float8_e8m0fnux4', + 'float8_e8m0fnux8', + 'float8_e8m0fnux16', + 'float8_e8m0fnux32', + 'float8_e8m0fnux64', + 'float6_e2m3fn', + 'float6_e2m3fnx2', + 'float6_e2m3fnx4', + 'float6_e2m3fnx8', + 'float6_e2m3fnx16', + 'float6_e2m3fnx32', + 'float6_e2m3fnx64', + 'float6_e3m2fn', + 'float6_e3m2fnx2', + 'float6_e3m2fnx4', + 'float6_e3m2fnx8', + 'float6_e3m2fnx16', + 'float6_e3m2fnx32', + 'float6_e3m2fnx64', + 'float4_e2m1fn', + 'float4_e2m1fnx2', + 'float4_e2m1fnx4', + 'float4_e2m1fnx8', + 'float4_e2m1fnx16', + 'float4_e2m1fnx32', + 'float4_e2m1fnx64', + 'bfloat16', +] + +__all__ = _all_dtypes + [ + 'AnyDType', 'get_tvm_dtype', +] From 4d0bc855f3e2a262714e4cb83899ac7dbe380227 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Tue, 28 Oct 2025 16:58:02 +0800 Subject: [PATCH 16/24] fix lint error --- .../language/test_tilelang_language_dtype.py | 110 ++-- tilelang/language/v2/dtypes.py | 622 +++++++++++++----- 2 files changed, 525 insertions(+), 207 deletions(-) diff --git a/testing/python/language/test_tilelang_language_dtype.py b/testing/python/language/test_tilelang_language_dtype.py index d6f6b0e74..45a2f4531 100644 --- a/testing/python/language/test_tilelang_language_dtype.py +++ b/testing/python/language/test_tilelang_language_dtype.py @@ -4,33 +4,35 @@ import tilelang.testing import tvm + def test_argument(): + @T.prim_func def test_argument( - t_1: T.bool, - t_2: T.short, - t_3: T.int, - t_4: T.long, - t_5: T.half, - t_6: T.float, - t_7: T.long, - t_8: T.int8, - t_9: T.int16, - t_10: T.int32, - t_11: T.int64, - t_12: T.uint8, - t_13: T.uint16, - t_14: T.uint32, - t_15: T.uint64, - t_16: T.float8_e4m3fn, - t_17: T.float8_e4m3fnuz, - t_18: T.float8_e5m2, - t_19: T.float8_e5m2fnuz, - t_20: T.float8_e8m0fnu, - t_21: T.float16, - t_22: T.bfloat16, - t_23: T.float32, - t_24: T.float64, + t_1: T.bool, + t_2: T.short, + t_3: T.int, + t_4: T.long, + t_5: T.half, + t_6: T.float, + t_7: T.long, + t_8: T.int8, + t_9: T.int16, + t_10: T.int32, + t_11: T.int64, + t_12: T.uint8, + t_13: T.uint16, + t_14: T.uint32, + t_15: T.uint64, + t_16: T.float8_e4m3fn, + t_17: T.float8_e4m3fnuz, + t_18: T.float8_e5m2, + t_19: T.float8_e5m2fnuz, + t_20: T.float8_e8m0fnu, + t_21: T.float16, + t_22: T.bfloat16, + t_23: T.float32, + t_24: T.float64, ): pass @@ -44,14 +46,15 @@ def test_expr(): try: dtype(1.0) dtype() - except TypeError as e: + except TypeError: pass - except Exception as e: + except Exception: errors.append(name) assert not errors def test_var_decl_sugar(): + @T.prim_func def test_var_decl_sugar(): with T.Kernel(128, 128) as (bx, by): @@ -107,35 +110,38 @@ def test_var_decl_sugar(): s = test_var_decl_sugar.script() for i in range(1, 25): assert f'var_{i}_1' in s - assert f'tl.local_var_init' in s + assert 'tl.local_var_init' in s + def test_dtype_str_repr(): + @T.prim_func def test_str_repr(): - buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') - buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') - buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') - buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') - buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') - buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') - buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') - buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') - buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') - buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') - buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') - buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') - buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') - buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') - buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') - buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') - buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') - buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') - buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') - buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') - buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') - buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') - buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') - buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') + buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') # noqa F841 + buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') # noqa F841 + buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') # noqa F841 + buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841 + buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') # noqa F841 + buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') # noqa F841 + buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841 + buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') # noqa F841 + buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') # noqa F841 + buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') # noqa F841 + buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') # noqa F841 + buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') # noqa F841 + buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') # noqa F841 + buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') # noqa F841 + buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') # noqa F841 + buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') # noqa F841 + buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') # noqa F841 + buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') # noqa F841 + buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') # noqa F841 + buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') # noqa F841 + buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') # noqa F841 + buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') # noqa F841 + buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') # noqa F841 + buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841 + def test_torch_eq(): dtypes = [ @@ -195,6 +201,7 @@ def test_torch_eq(): def test_var_assign(): + @tilelang.jit(out_idx=-1) @T.prim_func def test_var_assign(A: T.Tensor((2,), T.int32)): @@ -205,6 +212,7 @@ def test_var_assign(A: T.Tensor((2,), T.int32)): d: T.int32 = a A[0] = b A[1] = d + res = test_var_assign()() assert res[0] == 1 assert res[1] == 2 diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index ff3388dbd..d7f5af74d 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -98,7 +98,8 @@ def __dtype_call__(self: tvm.DataType, expr=None, is_size_var: bool = False) -> val = first + second.upper() call = getattr(tb_ffi, val, None) if call is None: - raise TypeError(f'Convert to datatype `{self}` is not supported by tvm, calling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`') + raise TypeError(f"Convert to datatype `{self}` is not supported by tvm\n" + f"calling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`") return call(expr, is_size_var) @@ -129,160 +130,468 @@ def get_tvm_dtype(value: AnyDType) -> tvm.DataType: if TYPE_CHECKING: - class bool(tvm.DataType): ... - class short(tvm.DataType): ... - class int(tvm.DataType): ... - class long(tvm.DataType): ... - class half(tvm.DataType): ... - class float(tvm.DataType): ... - class double(tvm.DataType): ... - class int8(tvm.DataType): ... - class int16(tvm.DataType): ... - class int32(tvm.DataType): ... - class int64(tvm.DataType): ... - class int8x4(tvm.DataType): ... - class int16x4(tvm.DataType): ... - class int32x4(tvm.DataType): ... - class int64x4(tvm.DataType): ... - class int8x8(tvm.DataType): ... - class int16x8(tvm.DataType): ... - class int32x8(tvm.DataType): ... - class int64x8(tvm.DataType): ... - class int8x16(tvm.DataType): ... - class int16x16(tvm.DataType): ... - class int32x16(tvm.DataType): ... - class int64x16(tvm.DataType): ... - class int8x32(tvm.DataType): ... - class int16x32(tvm.DataType): ... - class int32x32(tvm.DataType): ... - class int64x32(tvm.DataType): ... - class int8x64(tvm.DataType): ... - class int16x64(tvm.DataType): ... - class int32x64(tvm.DataType): ... - class int64x64(tvm.DataType): ... - class uint8(tvm.DataType): ... - class uint16(tvm.DataType): ... - class uint32(tvm.DataType): ... - class uint64(tvm.DataType): ... - class uint8x4(tvm.DataType): ... - class uint16x4(tvm.DataType): ... - class uint32x4(tvm.DataType): ... - class uint64x4(tvm.DataType): ... - class uint8x8(tvm.DataType): ... - class uint16x8(tvm.DataType): ... - class uint32x8(tvm.DataType): ... - class uint64x8(tvm.DataType): ... - class uint8x16(tvm.DataType): ... - class uint16x16(tvm.DataType): ... - class uint32x16(tvm.DataType): ... - class uint64x16(tvm.DataType): ... - class uint8x32(tvm.DataType): ... - class uint16x32(tvm.DataType): ... - class uint32x32(tvm.DataType): ... - class uint64x32(tvm.DataType): ... - class uint8x64(tvm.DataType): ... - class uint16x64(tvm.DataType): ... - class uint32x64(tvm.DataType): ... - class uint64x64(tvm.DataType): ... - class float16(tvm.DataType): ... - class float32(tvm.DataType): ... - class float64(tvm.DataType): ... - class float16x2(tvm.DataType): ... - class float32x2(tvm.DataType): ... - class float64x2(tvm.DataType): ... - class float16x4(tvm.DataType): ... - class float32x4(tvm.DataType): ... - class float64x4(tvm.DataType): ... - class float16x8(tvm.DataType): ... - class float32x8(tvm.DataType): ... - class float64x8(tvm.DataType): ... - class float16x16(tvm.DataType): ... - class float32x16(tvm.DataType): ... - class float64x16(tvm.DataType): ... - class float16x32(tvm.DataType): ... - class float32x32(tvm.DataType): ... - class float64x32(tvm.DataType): ... - class float16x64(tvm.DataType): ... - class float32x64(tvm.DataType): ... - class float64x64(tvm.DataType): ... - class float8_e3m4(tvm.DataType): ... - class float8_e3m4x2(tvm.DataType): ... - class float8_e3m4x4(tvm.DataType): ... - class float8_e3m4x8(tvm.DataType): ... - class float8_e3m4x16(tvm.DataType): ... - class float8_e3m4x32(tvm.DataType): ... - class float8_e3m4x64(tvm.DataType): ... - class float8_e4m3(tvm.DataType): ... - class float8_e4m3x2(tvm.DataType): ... - class float8_e4m3x4(tvm.DataType): ... - class float8_e4m3x8(tvm.DataType): ... - class float8_e4m3x16(tvm.DataType): ... - class float8_e4m3x32(tvm.DataType): ... - class float8_e4m3x64(tvm.DataType): ... - class float8_e4m3b11fnuz(tvm.DataType): ... - class float8_e4m3b11fnuzx2(tvm.DataType): ... - class float8_e4m3b11fnuzx4(tvm.DataType): ... - class float8_e4m3b11fnuzx8(tvm.DataType): ... - class float8_e4m3b11fnuzx16(tvm.DataType): ... - class float8_e4m3b11fnuzx32(tvm.DataType): ... - class float8_e4m3b11fnuzx64(tvm.DataType): ... - class float8_e4m3fn(tvm.DataType): ... - class float8_e4m3fnx2(tvm.DataType): ... - class float8_e4m3fnx4(tvm.DataType): ... - class float8_e4m3fnx8(tvm.DataType): ... - class float8_e4m3fnx16(tvm.DataType): ... - class float8_e4m3fnx32(tvm.DataType): ... - class float8_e4m3fnx64(tvm.DataType): ... - class float8_e4m3fnuz(tvm.DataType): ... - class float8_e4m3fnuzx2(tvm.DataType): ... - class float8_e4m3fnuzx4(tvm.DataType): ... - class float8_e4m3fnuzx8(tvm.DataType): ... - class float8_e4m3fnuzx16(tvm.DataType): ... - class float8_e4m3fnuzx32(tvm.DataType): ... - class float8_e4m3fnuzx64(tvm.DataType): ... - class float8_e5m2(tvm.DataType): ... - class float8_e5m2x2(tvm.DataType): ... - class float8_e5m2x4(tvm.DataType): ... - class float8_e5m2x8(tvm.DataType): ... - class float8_e5m2x16(tvm.DataType): ... - class float8_e5m2x32(tvm.DataType): ... - class float8_e5m2x64(tvm.DataType): ... - class float8_e5m2fnuz(tvm.DataType): ... - class float8_e5m2fnuzx2(tvm.DataType): ... - class float8_e5m2fnuzx4(tvm.DataType): ... - class float8_e5m2fnuzx8(tvm.DataType): ... - class float8_e5m2fnuzx16(tvm.DataType): ... - class float8_e5m2fnuzx32(tvm.DataType): ... - class float8_e5m2fnuzx64(tvm.DataType): ... - class float8_e8m0fnu(tvm.DataType): ... - class float8_e8m0fnux2(tvm.DataType): ... - class float8_e8m0fnux4(tvm.DataType): ... - class float8_e8m0fnux8(tvm.DataType): ... - class float8_e8m0fnux16(tvm.DataType): ... - class float8_e8m0fnux32(tvm.DataType): ... - class float8_e8m0fnux64(tvm.DataType): ... - class float6_e2m3fn(tvm.DataType): ... - class float6_e2m3fnx2(tvm.DataType): ... - class float6_e2m3fnx4(tvm.DataType): ... - class float6_e2m3fnx8(tvm.DataType): ... - class float6_e2m3fnx16(tvm.DataType): ... - class float6_e2m3fnx32(tvm.DataType): ... - class float6_e2m3fnx64(tvm.DataType): ... - class float6_e3m2fn(tvm.DataType): ... - class float6_e3m2fnx2(tvm.DataType): ... - class float6_e3m2fnx4(tvm.DataType): ... - class float6_e3m2fnx8(tvm.DataType): ... - class float6_e3m2fnx16(tvm.DataType): ... - class float6_e3m2fnx32(tvm.DataType): ... - class float6_e3m2fnx64(tvm.DataType): ... - class float4_e2m1fn(tvm.DataType): ... - class float4_e2m1fnx2(tvm.DataType): ... - class float4_e2m1fnx4(tvm.DataType): ... - class float4_e2m1fnx8(tvm.DataType): ... - class float4_e2m1fnx16(tvm.DataType): ... - class float4_e2m1fnx32(tvm.DataType): ... - class float4_e2m1fnx64(tvm.DataType): ... - class bfloat16(tvm.DataType): ... + + class bool(tvm.DataType): + ... + + class short(tvm.DataType): + ... + + class int(tvm.DataType): + ... + + class long(tvm.DataType): + ... + + class half(tvm.DataType): + ... + + class float(tvm.DataType): + ... + + class double(tvm.DataType): + ... + + class int8(tvm.DataType): + ... + + class int16(tvm.DataType): + ... + + class int32(tvm.DataType): + ... + + class int64(tvm.DataType): + ... + + class int8x4(tvm.DataType): + ... + + class int16x4(tvm.DataType): + ... + + class int32x4(tvm.DataType): + ... + + class int64x4(tvm.DataType): + ... + + class int8x8(tvm.DataType): + ... + + class int16x8(tvm.DataType): + ... + + class int32x8(tvm.DataType): + ... + + class int64x8(tvm.DataType): + ... + + class int8x16(tvm.DataType): + ... + + class int16x16(tvm.DataType): + ... + + class int32x16(tvm.DataType): + ... + + class int64x16(tvm.DataType): + ... + + class int8x32(tvm.DataType): + ... + + class int16x32(tvm.DataType): + ... + + class int32x32(tvm.DataType): + ... + + class int64x32(tvm.DataType): + ... + + class int8x64(tvm.DataType): + ... + + class int16x64(tvm.DataType): + ... + + class int32x64(tvm.DataType): + ... + + class int64x64(tvm.DataType): + ... + + class uint8(tvm.DataType): + ... + + class uint16(tvm.DataType): + ... + + class uint32(tvm.DataType): + ... + + class uint64(tvm.DataType): + ... + + class uint8x4(tvm.DataType): + ... + + class uint16x4(tvm.DataType): + ... + + class uint32x4(tvm.DataType): + ... + + class uint64x4(tvm.DataType): + ... + + class uint8x8(tvm.DataType): + ... + + class uint16x8(tvm.DataType): + ... + + class uint32x8(tvm.DataType): + ... + + class uint64x8(tvm.DataType): + ... + + class uint8x16(tvm.DataType): + ... + + class uint16x16(tvm.DataType): + ... + + class uint32x16(tvm.DataType): + ... + + class uint64x16(tvm.DataType): + ... + + class uint8x32(tvm.DataType): + ... + + class uint16x32(tvm.DataType): + ... + + class uint32x32(tvm.DataType): + ... + + class uint64x32(tvm.DataType): + ... + + class uint8x64(tvm.DataType): + ... + + class uint16x64(tvm.DataType): + ... + + class uint32x64(tvm.DataType): + ... + + class uint64x64(tvm.DataType): + ... + + class float16(tvm.DataType): + ... + + class float32(tvm.DataType): + ... + + class float64(tvm.DataType): + ... + + class float16x2(tvm.DataType): + ... + + class float32x2(tvm.DataType): + ... + + class float64x2(tvm.DataType): + ... + + class float16x4(tvm.DataType): + ... + + class float32x4(tvm.DataType): + ... + + class float64x4(tvm.DataType): + ... + + class float16x8(tvm.DataType): + ... + + class float32x8(tvm.DataType): + ... + + class float64x8(tvm.DataType): + ... + + class float16x16(tvm.DataType): + ... + + class float32x16(tvm.DataType): + ... + + class float64x16(tvm.DataType): + ... + + class float16x32(tvm.DataType): + ... + + class float32x32(tvm.DataType): + ... + + class float64x32(tvm.DataType): + ... + + class float16x64(tvm.DataType): + ... + + class float32x64(tvm.DataType): + ... + + class float64x64(tvm.DataType): + ... + + class float8_e3m4(tvm.DataType): + ... + + class float8_e3m4x2(tvm.DataType): + ... + + class float8_e3m4x4(tvm.DataType): + ... + + class float8_e3m4x8(tvm.DataType): + ... + + class float8_e3m4x16(tvm.DataType): + ... + + class float8_e3m4x32(tvm.DataType): + ... + + class float8_e3m4x64(tvm.DataType): + ... + + class float8_e4m3(tvm.DataType): + ... + + class float8_e4m3x2(tvm.DataType): + ... + + class float8_e4m3x4(tvm.DataType): + ... + + class float8_e4m3x8(tvm.DataType): + ... + + class float8_e4m3x16(tvm.DataType): + ... + + class float8_e4m3x32(tvm.DataType): + ... + + class float8_e4m3x64(tvm.DataType): + ... + + class float8_e4m3b11fnuz(tvm.DataType): + ... + + class float8_e4m3b11fnuzx2(tvm.DataType): + ... + + class float8_e4m3b11fnuzx4(tvm.DataType): + ... + + class float8_e4m3b11fnuzx8(tvm.DataType): + ... + + class float8_e4m3b11fnuzx16(tvm.DataType): + ... + + class float8_e4m3b11fnuzx32(tvm.DataType): + ... + + class float8_e4m3b11fnuzx64(tvm.DataType): + ... + + class float8_e4m3fn(tvm.DataType): + ... + + class float8_e4m3fnx2(tvm.DataType): + ... + + class float8_e4m3fnx4(tvm.DataType): + ... + + class float8_e4m3fnx8(tvm.DataType): + ... + + class float8_e4m3fnx16(tvm.DataType): + ... + + class float8_e4m3fnx32(tvm.DataType): + ... + + class float8_e4m3fnx64(tvm.DataType): + ... + + class float8_e4m3fnuz(tvm.DataType): + ... + + class float8_e4m3fnuzx2(tvm.DataType): + ... + + class float8_e4m3fnuzx4(tvm.DataType): + ... + + class float8_e4m3fnuzx8(tvm.DataType): + ... + + class float8_e4m3fnuzx16(tvm.DataType): + ... + + class float8_e4m3fnuzx32(tvm.DataType): + ... + + class float8_e4m3fnuzx64(tvm.DataType): + ... + + class float8_e5m2(tvm.DataType): + ... + + class float8_e5m2x2(tvm.DataType): + ... + + class float8_e5m2x4(tvm.DataType): + ... + + class float8_e5m2x8(tvm.DataType): + ... + + class float8_e5m2x16(tvm.DataType): + ... + + class float8_e5m2x32(tvm.DataType): + ... + + class float8_e5m2x64(tvm.DataType): + ... + + class float8_e5m2fnuz(tvm.DataType): + ... + + class float8_e5m2fnuzx2(tvm.DataType): + ... + + class float8_e5m2fnuzx4(tvm.DataType): + ... + + class float8_e5m2fnuzx8(tvm.DataType): + ... + + class float8_e5m2fnuzx16(tvm.DataType): + ... + + class float8_e5m2fnuzx32(tvm.DataType): + ... + + class float8_e5m2fnuzx64(tvm.DataType): + ... + + class float8_e8m0fnu(tvm.DataType): + ... + + class float8_e8m0fnux2(tvm.DataType): + ... + + class float8_e8m0fnux4(tvm.DataType): + ... + + class float8_e8m0fnux8(tvm.DataType): + ... + + class float8_e8m0fnux16(tvm.DataType): + ... + + class float8_e8m0fnux32(tvm.DataType): + ... + + class float8_e8m0fnux64(tvm.DataType): + ... + + class float6_e2m3fn(tvm.DataType): + ... + + class float6_e2m3fnx2(tvm.DataType): + ... + + class float6_e2m3fnx4(tvm.DataType): + ... + + class float6_e2m3fnx8(tvm.DataType): + ... + + class float6_e2m3fnx16(tvm.DataType): + ... + + class float6_e2m3fnx32(tvm.DataType): + ... + + class float6_e2m3fnx64(tvm.DataType): + ... + + class float6_e3m2fn(tvm.DataType): + ... + + class float6_e3m2fnx2(tvm.DataType): + ... + + class float6_e3m2fnx4(tvm.DataType): + ... + + class float6_e3m2fnx8(tvm.DataType): + ... + + class float6_e3m2fnx16(tvm.DataType): + ... + + class float6_e3m2fnx32(tvm.DataType): + ... + + class float6_e3m2fnx64(tvm.DataType): + ... + + class float4_e2m1fn(tvm.DataType): + ... + + class float4_e2m1fnx2(tvm.DataType): + ... + + class float4_e2m1fnx4(tvm.DataType): + ... + + class float4_e2m1fnx8(tvm.DataType): + ... + + class float4_e2m1fnx16(tvm.DataType): + ... + + class float4_e2m1fnx32(tvm.DataType): + ... + + class float4_e2m1fnx64(tvm.DataType): + ... + + class bfloat16(tvm.DataType): + ... else: bool = tvm.DataType('bool') short = tvm.DataType('int16') @@ -597,5 +906,6 @@ class bfloat16(tvm.DataType): ... ] __all__ = _all_dtypes + [ - 'AnyDType', 'get_tvm_dtype', + 'AnyDType', + 'get_tvm_dtype', ] From 0dfe4e37c0f72419d0efe103407b926fb4f1c07f Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Wed, 29 Oct 2025 11:50:40 +0800 Subject: [PATCH 17/24] add more warning in frontend --- tilelang/language/v2/builder.py | 72 +++++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 0ddef17d9..605fe36dd 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -12,7 +12,7 @@ import tvm from tvm.tir import Buffer from tvm.script.ir_builder import tir, IRBuilder -from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, Var +from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var from typing import TYPE_CHECKING, Callable, ContextManager, Any, Generic, Hashable, ParamSpec, Self, TypeVar, ForwardRef from .dtypes import get_tvm_dtype from types import EllipsisType @@ -23,34 +23,42 @@ def unwrap_expr(expr) -> PrimExpr | int | float: + ''' + unwrap expr and convert it into PrimExpr like + ''' if isinstance(expr, tir.meta_var): expr = expr.value elif isinstance(expr, Buffer) and expr.scope() == 'local.var': expr = tir.BufferLoad(expr, indices=[0]) elif isinstance(expr, (EqualOp, NotEqualOp)): expr = expr.asobject() - elif isinstance(expr, IntImm) and expr.dtype == 'int32': - expr = expr.value return expr def unwrap_cond(expr): + ''' + unwrap expr and convert to bool condition + ''' expr = unwrap_expr(expr) - if isinstance(expr, PrimExpr): + if isinstance(expr, (IntImm, FloatImm, StringImm)): + return bool(expr.value) + elif isinstance(expr, PrimExpr): return expr elif isinstance(expr, Buffer): raise TypeError(f"Buffer `{expr}` cannot be used as condition directly.") - elif isinstance(expr, (int, bool, tuple, list)): - return expr + elif isinstance(expr, (int, bool)) or expr is None: + return bool(expr) else: - logger.warning(f"Python expression `{expr}` is used in TileLang. ", stack_info=True) - return expr + logger.warning( + f"Python expression `{expr}` is used as condition in TileLang, \n" + "this is treated as a constant expression. ", stack_info=True, stacklevel=3) + return bool(expr) thread_local_storage = threading.local() -class DummyFrame: +class Frame: def __enter__(self): ... @@ -59,23 +67,30 @@ def __exit__(self, exc_type, exc_value, traceback): ... -class MacroFrame(DummyFrame): +class MacroFrame(Frame): ... -class BoolOpFrame(DummyFrame): +class BoolOpFrame(Frame): ... -class ConstIfFrame(DummyFrame): +class ConstIfFrame(Frame): ... -class BlockFrame(DummyFrame): +class BlockFrame(Frame): ... -AnyFrame = tir.frame.IRBuilderFrame | DummyFrame +class ContinueFrame(Frame): + ... + +class BreakFrame(Frame): + ... + +ContinueOrBreak = ContinueFrame | BreakFrame +AnyFrame = tir.frame.IRBuilderFrame | Frame TIR_CONTROL_FRAME = ( tir.frame.WhileFrame, @@ -144,6 +159,15 @@ def enter_frame(self, frame: ContextManager): self.frames.append(frame) return frame.__enter__() + def check_continue_break(self): + idx = self.find_frame_idx(ContinueOrBreak) + if idx is not None: + logger.warning( + 'Writing code after continue/break may cause undefined behavior in tilelang.', + stack_info=True, + stacklevel=3 + ) + @contextmanager def with_frame(self, frame: ContextManager | None): pop_idx = len(self.frames) @@ -155,6 +179,7 @@ class _has_if_frame: ... def ctx_if(self, cond): + self.check_continue_break() cond = unwrap_cond(cond) if isinstance(cond, PrimExpr): with self.with_frame(tir.If(cond)): @@ -199,6 +224,7 @@ def eval(self, val: Any): raise TypeError(f"Unsupported eval value: {val} of type {type(val)}") def ctx_for(self, it): + self.check_continue_break() it = unwrap_expr(it) if isinstance(it, range): assert it.step == 1, "Only step=1 is supported in range for now." @@ -211,15 +237,21 @@ def ctx_for(self, it): yield v def ctx_continue(self): + self.check_continue_break() + self.enter_frame(ContinueFrame()) raise RuntimeError("continue is not supported in TileLang builder") def ctx_break(self): + self.check_continue_break() + self.enter_frame(BreakFrame()) raise RuntimeError("break is not supported in TileLang builder") def ctx_while(self, cond): + self.check_continue_break() raise RuntimeError("while loops are not supported in TileLang builder") def bind(self, name, value, annot=BaseBuilder.empty): + self.check_continue_break() locals = self.get_parent_locals() orig_value = locals.get(name, None) # annotation like tl.float32 @@ -285,6 +317,7 @@ def bind_immutable(self, name, value): return self.enter_frame(frame) def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty): + self.check_continue_break() if annot is not self.empty: logger.warning( "Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2) @@ -294,6 +327,7 @@ def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty return super().assign_slice(lval, sl, value) def aug_assign(self, op, target, aug_value): + self.check_continue_break() if is_var(target): tir.buffer_store(target, eval_op(op, target[0], aug_value), 0) elif isinstance(target, Buffer): @@ -302,6 +336,7 @@ def aug_assign(self, op, target, aug_value): return super().aug_assign(op, target, aug_value) def aug_assign_slice(self, op, target, sl, aug_value): + self.check_continue_break() if isinstance(target, Buffer): tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl) else: @@ -328,6 +363,9 @@ def ifexp(self, cond, then, otherwise): return super().ifexp(cond, then, otherwise) def ret(self, value): + self.check_continue_break() + # handle return T.alloc_var() + value = self.unwrap_value(value) last_macro = self.find_frame_idx(MacroFrame) if last_macro is not None: frame = self.find_frame_idx(TIR_CONTROL_FRAME, start=last_macro) @@ -345,15 +383,17 @@ def ret(self, value): " return a\n" "```" ) - return super().ret(value) + return value def ctx_with(self, ctx): + self.check_continue_break() if isinstance(ctx, tir.frame.IRBuilderFrame): return self.with_frame(ctx) else: return super().ctx_with(ctx) def assert_expr(self, cond, msg): + self.check_continue_break() cond = unwrap_cond(cond) if isinstance(cond, PrimExpr): self.enter_frame(tir.Assert(cond, msg)) @@ -368,7 +408,7 @@ def rval(self, name: str, value: Any) -> Any: f"Use immutable variable `{name}` outside its defining region, did you forget **alloc_var**?\n" f"variable `{name}` is defined in frame: {frame}, current frames: {self.frames}." ) - return unwrap_expr(value) + return self.unwrap_value(value) def arg(self, name, value): if self.find_frame_idx(MacroFrame) is not None: From b29da364dcdd011b591397c11e43542e84554c2a Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Fri, 31 Oct 2025 16:01:43 +0800 Subject: [PATCH 18/24] update tvm version --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 9cda9b611..fa576ec5c 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 9cda9b611ba9d91a1d42b561767f40aba0afcd78 +Subproject commit fa576ec5ce704b44226cce6a13d0a2bb525f10c2 From f1be5065d2cdb1890f238a525e18108a4ec634f2 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Fri, 31 Oct 2025 18:20:34 +0800 Subject: [PATCH 19/24] Minor fix on tvm_ffi annotations --- .../language/test_tilelang_language_dtype.py | 1 + .../language/test_tilelang_language_let.py | 4 +- tilelang/language/v2/builder.py | 123 ++- tilelang/language/v2/dtypes.py | 968 ++++++------------ 4 files changed, 407 insertions(+), 689 deletions(-) diff --git a/testing/python/language/test_tilelang_language_dtype.py b/testing/python/language/test_tilelang_language_dtype.py index 45a2f4531..2303acba9 100644 --- a/testing/python/language/test_tilelang_language_dtype.py +++ b/testing/python/language/test_tilelang_language_dtype.py @@ -198,6 +198,7 @@ def test_torch_eq(): ] for a, b in zip(dtypes, torch_dtypes): assert a == b, f"{a} and {b} are not equal" + assert T.dtype(b) == a, f"dtype convertion error" def test_var_assign(): diff --git a/testing/python/language/test_tilelang_language_let.py b/testing/python/language/test_tilelang_language_let.py index 8cc5b1fa6..a2af09c67 100644 --- a/testing/python/language/test_tilelang_language_let.py +++ b/testing/python/language/test_tilelang_language_let.py @@ -11,12 +11,12 @@ def main(A_ptr: T.handle): for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): - b: T.float32x4 = A[0, 0:4] + b = A[0, 0:4] A[0, 4:8] = b mod = tvm.IRModule({"main": main}) mod = tvm.compile(mod, target="cuda") - assert "float4 b" in mod.mod.imported_modules[0].get_source() + assert "float4 b" in mod.mod.imports[0].inspect_source() if __name__ == "__main__": diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 605fe36dd..9f85dcd34 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -4,9 +4,8 @@ from dataclasses import dataclass import inspect -import torch from tilelang.language.kernel import KernelLaunchFrame -from tvm.ffi.container import Map +from tvm_ffi.container import Map from tvm.ir.base import Span from .ast import BaseBuilder, eval_op, mutate import tvm @@ -14,7 +13,7 @@ from tvm.script.ir_builder import tir, IRBuilder from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var from typing import TYPE_CHECKING, Callable, ContextManager, Any, Generic, Hashable, ParamSpec, Self, TypeVar, ForwardRef -from .dtypes import get_tvm_dtype +from . import dtypes as dt from types import EllipsisType import threading import logging @@ -115,8 +114,7 @@ def is_var(v: Any) -> bool: class Builder(BaseBuilder): - def __init__(self, arg_annot: dict[str, Any] = None): - self.arg_annot = arg_annot + def __init__(self): self.frames: list[AnyFrame] = [] self.ir_builder = IRBuilder() self.name_inside_frame: dict[str, AnyFrame] = {} @@ -211,6 +209,12 @@ def eval(self, val: Any): if val is None: pass elif isinstance(val, tir.frame.IRBuilderFrame): + if isinstance(val, tir.frame.ForFrame): + logger.warning( + f'Evaluating a for frame to may cause undefined behavior in tilelang.', + stack_info=True, + stacklevel=1, + ) self.enter_frame(val) elif isinstance(val, PrimExpr): tir.evaluate(val) @@ -226,9 +230,6 @@ def eval(self, val: Any): def ctx_for(self, it): self.check_continue_break() it = unwrap_expr(it) - if isinstance(it, range): - assert it.step == 1, "Only step=1 is supported in range for now." - it = tir.serial(it.start, it.stop) if not isinstance(it, tir.frame.ForFrame): raise TypeError( f"Invalid for loop, got {it}({type(it)}), expect one of the following: " @@ -285,6 +286,12 @@ def bind(self, name, value, annot=BaseBuilder.empty): if name != '_': frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME) assert frame is not None, f"Variable `{name}` is not defined inside any control flow." + if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames: + logger.warning( + f'Variable `{name}` shadows another declared value, Are you forgetting to use alloc it as a var?', + stack_info=True, + stacklevel=2, + ) self.name_inside_frame[name] = self.frames[frame] return res @@ -300,6 +307,12 @@ def bind_immutable(self, name, value): if isinstance(value, tir.meta_var): return value.value elif isinstance(value, tir.frame.IRBuilderFrame): + if isinstance(value, tir.frame.ForFrame): + logger.warning( + f'Binding a for frame to variable may cause undefined behavior in tilelang.', + stack_info=True, + stacklevel=2, + ) return self.enter_frame(value) elif isinstance(value, (Buffer, tir.IterVar, tir.Var)): IRBuilder.name(name, value) @@ -397,8 +410,8 @@ def assert_expr(self, cond, msg): cond = unwrap_cond(cond) if isinstance(cond, PrimExpr): self.enter_frame(tir.Assert(cond, msg)) - else: - super().assert_expr(cond, msg) + elif not cond: + raise AssertionError(msg) def rval(self, name: str, value: Any) -> Any: if name in self.name_inside_frame: @@ -412,15 +425,19 @@ def rval(self, name: str, value: Any) -> Any: def arg(self, name, value): if self.find_frame_idx(MacroFrame) is not None: - return value + if isinstance(value, (str, StringImm)): + # this is a workaround for string argument in macro + return value + else: + return self.bind(name, value) if isinstance(value, (Buffer, Var)): return tir.arg(name, value) - elif hasattr(value, '__tl_arg__'): - return value.__tl_arg__(name, self) - elif isinstance(value, Hashable): - return value + elif value is self.empty: + raise ValueError(f'Argument `{name}` is not annotated') + # elif isinstance(value, Hashable): + # return value else: - raise TypeError(f"Unsupported argument type: {type(value)} for argument `{name}`.") + raise TypeError(f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.") def override(self, name: str): if name == 'range': @@ -428,14 +445,6 @@ def override(self, name: str): raise ValueError(f'Unknown override: {name}') -def __torch_tensor_tl_arg__(self: torch.Tensor, name: str, builder: Builder): - buffer = tir.buffer( - self.shape, get_tvm_dtype(self.dtype), strides=self.stride(), scope='global') - return tir.arg(name, buffer) - - -torch.Tensor.__tl_arg__ = __torch_tensor_tl_arg__ - _P = ParamSpec('_P') _T = TypeVar('_T') @@ -501,43 +510,57 @@ def get_type_hints(func): globalns = getattr(func, '__globals__', {}) localns = globalns for name, value in annot.items(): + if name == 'return': + continue if isinstance(value, tvm.DataType): hints[name] = value continue if value is None: value = type(None) if isinstance(value, str): + _, v = value.split('.', maxsplit=1) + if v in dt._all_dtypes: + try: + hints[name] = eval(value, globalns, localns) + continue + except Exception as e: + pass value = ForwardRef(value, is_argument=True, is_class=False) - hints[name] = _eval_type(value, globalns=globalns, localns=localns, type_params=type_params) return hints -def prim_func(func: Callable[_P, _T]) -> PrimFunc[_P, _T]: +def _is_static_annot(annot: Any) -> bool: + return isinstance(annot, (dt.dtype, Buffer, Var)) + + +def prim_func(func: Callable[_P, _T]) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]: sig = inspect.signature(func) annot = get_type_hints(func) - args = [] - kwargs = {} - for name, param in sig.parameters.items(): - if param.annotation is not param.empty: - if callable(param.annotation): - value = param.annotation() - else: - value = param.annotation - elif param.default is not param.empty: - value = param.default - else: - value = Builder.empty - if param.kind == param.POSITIONAL_ONLY: - args.append(value) - else: - kwargs[name] = value + + for k in annot: + if callable(annot[k]): + annot[k] = annot[k]() + + all_arg_annotated = all([x in annot for x in sig.parameters]) + all_annot_are_static = all([_is_static_annot(x) for x in annot.values()]) ir_gen = build_ir_generator(func) - builder = Builder(annot) - with builder.prim_func(func.__name__): - ir_gen.gen(builder)(*args, **kwargs) - res = builder.get() - res.ir_gen = ir_gen - res.source = ir_gen.source - res.orig_func = func - return res + + def prim_func_generator(*args, **kwargs): + builder = Builder() + with builder.prim_func(func.__name__): + ir_gen.gen(builder)(*args, **kwargs) + res = builder.get() + res.ir_gen = ir_gen + res.source = ir_gen.source + res.orig_func = func + return res + + prim_func_generator.ir_gen = ir_gen + prim_func_generator.source = ir_gen.source + prim_func_generator.orig_func = func + + if all_arg_annotated and all_annot_are_static: + return prim_func_generator(**annot) + else: + return prim_func_generator diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index d7f5af74d..f5413874a 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -1,17 +1,14 @@ from tilelang import tvm from tvm import ir +import tvm_ffi import torch import ctypes from typing import TYPE_CHECKING from tvm import tir import tvm.script.ir_builder.tir._ffi_api as tb_ffi - -class VoidPtr: - ... - - -AnyDType = ir.Type | str | type | torch.dtype | tvm.DataType +dtype = tvm.DataType +AnyDType = ir.Type | str | type | torch.dtype | dtype _dtype_cvt = [ (None, 'handle', ctypes.c_long, 'long', None), # use long to repr void* @@ -58,12 +55,12 @@ def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x): _dtype_py2tvmstr = _create_type_mapper(0, 1) _dtype_tvmstr2fficall = _create_type_mapper(1, 4, dmapper=lambda x: getattr(tb_ffi, x)) -_dtype_tvm2py = _create_type_mapper(1, 0, lambda x: tvm.DataType(x)) -_dtype_tvm2ctype = _create_type_mapper(1, 2, lambda x: tvm.DataType(x)) -_dtype_tvm2cffi = _create_type_mapper(1, 3, lambda x: tvm.DataType(x)) +_dtype_tvm2py = _create_type_mapper(1, 0, lambda x: dtype(x)) +_dtype_tvm2ctype = _create_type_mapper(1, 2, lambda x: dtype(x)) +_dtype_tvm2cffi = _create_type_mapper(1, 3, lambda x: dtype(x)) -def __dtype_eq__(self: tvm.DataType, other: AnyDType): +def __dtype_eq__(self: dtype, other: AnyDType): if isinstance(other, str): return str.__eq__(self, other) if other in _dtype_py2tvmstr: @@ -71,7 +68,7 @@ def __dtype_eq__(self: tvm.DataType, other: AnyDType): return NotImplemented -def __dtype_ne__(self: tvm.DataType, other: AnyDType): +def __dtype_ne__(self: dtype, other: AnyDType): if isinstance(other, str): return str.__ne__(self, other) if other in _dtype_py2tvmstr: @@ -79,7 +76,7 @@ def __dtype_ne__(self: tvm.DataType, other: AnyDType): return NotImplemented -def __dtype_call__(self: tvm.DataType, expr=None, is_size_var: bool = False) -> tir.Var: +def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var: if self in _dtype_tvmstr2fficall: return _dtype_tvmstr2fficall[self](expr, is_size_var) # try to construct the ffi call @@ -103,7 +100,7 @@ def __dtype_call__(self: tvm.DataType, expr=None, is_size_var: bool = False) -> return call(expr, is_size_var) -def __dtype_new__(cls, value: AnyDType) -> tvm.DataType: +def __dtype_new__(cls, value: AnyDType) -> dtype: if isinstance(value, str): val = str.__new__(cls, value) elif value in _dtype_py2tvmstr: @@ -111,642 +108,338 @@ def __dtype_new__(cls, value: AnyDType) -> tvm.DataType: else: expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.values())) raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}") - val.__tvm_ffi_dtype__ = tvm.ffi.core.DataType(val) + val.__tvm_ffi_dtype__ = tvm_ffi.core.DataType(val) return val -tvm.DataType.__eq__ = __dtype_eq__ -tvm.DataType.__req__ = __dtype_eq__ -tvm.DataType.__ne__ = __dtype_ne__ -tvm.DataType.__rne__ = __dtype_ne__ -tvm.DataType.__call__ = __dtype_call__ -tvm.DataType.__new__ = __dtype_new__ +dtype.__eq__ = __dtype_eq__ +dtype.__req__ = __dtype_eq__ +dtype.__ne__ = __dtype_ne__ +dtype.__rne__ = __dtype_ne__ +dtype.__call__ = __dtype_call__ +dtype.__new__ = __dtype_new__ -def get_tvm_dtype(value: AnyDType) -> tvm.DataType: - if isinstance(value, (tvm.DataType, ir.Type)): +def get_tvm_dtype(value: AnyDType) -> dtype: + if isinstance(value, (dtype, ir.Type)): return value - return tvm.DataType(value) + return dtype(value) if TYPE_CHECKING: - class bool(tvm.DataType): - ... - - class short(tvm.DataType): - ... - - class int(tvm.DataType): - ... - - class long(tvm.DataType): - ... - - class half(tvm.DataType): - ... - - class float(tvm.DataType): - ... - - class double(tvm.DataType): - ... - - class int8(tvm.DataType): - ... - - class int16(tvm.DataType): - ... - - class int32(tvm.DataType): - ... - - class int64(tvm.DataType): - ... - - class int8x4(tvm.DataType): - ... - - class int16x4(tvm.DataType): - ... - - class int32x4(tvm.DataType): - ... - - class int64x4(tvm.DataType): - ... - - class int8x8(tvm.DataType): - ... - - class int16x8(tvm.DataType): - ... - - class int32x8(tvm.DataType): - ... - - class int64x8(tvm.DataType): - ... - - class int8x16(tvm.DataType): - ... - - class int16x16(tvm.DataType): - ... - - class int32x16(tvm.DataType): - ... - - class int64x16(tvm.DataType): - ... - - class int8x32(tvm.DataType): - ... - - class int16x32(tvm.DataType): - ... - - class int32x32(tvm.DataType): - ... - - class int64x32(tvm.DataType): - ... - - class int8x64(tvm.DataType): - ... - - class int16x64(tvm.DataType): - ... - - class int32x64(tvm.DataType): - ... - - class int64x64(tvm.DataType): - ... - - class uint8(tvm.DataType): - ... - - class uint16(tvm.DataType): - ... - - class uint32(tvm.DataType): - ... - - class uint64(tvm.DataType): - ... - - class uint8x4(tvm.DataType): - ... - - class uint16x4(tvm.DataType): - ... - - class uint32x4(tvm.DataType): - ... - - class uint64x4(tvm.DataType): - ... - - class uint8x8(tvm.DataType): - ... - - class uint16x8(tvm.DataType): - ... - - class uint32x8(tvm.DataType): - ... - - class uint64x8(tvm.DataType): - ... - - class uint8x16(tvm.DataType): - ... - - class uint16x16(tvm.DataType): - ... - - class uint32x16(tvm.DataType): - ... - - class uint64x16(tvm.DataType): - ... - - class uint8x32(tvm.DataType): - ... - - class uint16x32(tvm.DataType): - ... - - class uint32x32(tvm.DataType): - ... - - class uint64x32(tvm.DataType): - ... - - class uint8x64(tvm.DataType): - ... - - class uint16x64(tvm.DataType): - ... - - class uint32x64(tvm.DataType): - ... - - class uint64x64(tvm.DataType): - ... - - class float16(tvm.DataType): - ... - - class float32(tvm.DataType): - ... - - class float64(tvm.DataType): - ... - - class float16x2(tvm.DataType): - ... - - class float32x2(tvm.DataType): - ... - - class float64x2(tvm.DataType): - ... - - class float16x4(tvm.DataType): - ... - - class float32x4(tvm.DataType): - ... - - class float64x4(tvm.DataType): - ... - - class float16x8(tvm.DataType): - ... - - class float32x8(tvm.DataType): - ... - - class float64x8(tvm.DataType): - ... - - class float16x16(tvm.DataType): - ... - - class float32x16(tvm.DataType): - ... - - class float64x16(tvm.DataType): - ... - - class float16x32(tvm.DataType): - ... - - class float32x32(tvm.DataType): - ... - - class float64x32(tvm.DataType): - ... - - class float16x64(tvm.DataType): - ... - - class float32x64(tvm.DataType): - ... - - class float64x64(tvm.DataType): - ... - - class float8_e3m4(tvm.DataType): - ... - - class float8_e3m4x2(tvm.DataType): - ... - - class float8_e3m4x4(tvm.DataType): - ... - - class float8_e3m4x8(tvm.DataType): - ... - - class float8_e3m4x16(tvm.DataType): - ... - - class float8_e3m4x32(tvm.DataType): - ... - - class float8_e3m4x64(tvm.DataType): - ... - - class float8_e4m3(tvm.DataType): - ... - - class float8_e4m3x2(tvm.DataType): - ... - - class float8_e4m3x4(tvm.DataType): - ... - - class float8_e4m3x8(tvm.DataType): - ... - - class float8_e4m3x16(tvm.DataType): - ... - - class float8_e4m3x32(tvm.DataType): - ... - - class float8_e4m3x64(tvm.DataType): - ... - - class float8_e4m3b11fnuz(tvm.DataType): - ... - - class float8_e4m3b11fnuzx2(tvm.DataType): - ... - - class float8_e4m3b11fnuzx4(tvm.DataType): - ... - - class float8_e4m3b11fnuzx8(tvm.DataType): - ... - - class float8_e4m3b11fnuzx16(tvm.DataType): - ... - - class float8_e4m3b11fnuzx32(tvm.DataType): - ... - - class float8_e4m3b11fnuzx64(tvm.DataType): - ... - - class float8_e4m3fn(tvm.DataType): - ... - - class float8_e4m3fnx2(tvm.DataType): - ... - - class float8_e4m3fnx4(tvm.DataType): - ... - - class float8_e4m3fnx8(tvm.DataType): - ... - - class float8_e4m3fnx16(tvm.DataType): - ... - - class float8_e4m3fnx32(tvm.DataType): - ... - - class float8_e4m3fnx64(tvm.DataType): - ... - - class float8_e4m3fnuz(tvm.DataType): - ... - - class float8_e4m3fnuzx2(tvm.DataType): - ... - - class float8_e4m3fnuzx4(tvm.DataType): - ... - - class float8_e4m3fnuzx8(tvm.DataType): - ... - - class float8_e4m3fnuzx16(tvm.DataType): - ... - - class float8_e4m3fnuzx32(tvm.DataType): - ... - - class float8_e4m3fnuzx64(tvm.DataType): - ... - - class float8_e5m2(tvm.DataType): - ... - - class float8_e5m2x2(tvm.DataType): - ... - - class float8_e5m2x4(tvm.DataType): - ... - - class float8_e5m2x8(tvm.DataType): - ... - - class float8_e5m2x16(tvm.DataType): - ... - - class float8_e5m2x32(tvm.DataType): - ... - - class float8_e5m2x64(tvm.DataType): - ... - - class float8_e5m2fnuz(tvm.DataType): - ... - - class float8_e5m2fnuzx2(tvm.DataType): - ... - - class float8_e5m2fnuzx4(tvm.DataType): - ... - - class float8_e5m2fnuzx8(tvm.DataType): - ... - - class float8_e5m2fnuzx16(tvm.DataType): - ... - - class float8_e5m2fnuzx32(tvm.DataType): - ... - - class float8_e5m2fnuzx64(tvm.DataType): - ... - - class float8_e8m0fnu(tvm.DataType): - ... - - class float8_e8m0fnux2(tvm.DataType): - ... - - class float8_e8m0fnux4(tvm.DataType): - ... - - class float8_e8m0fnux8(tvm.DataType): - ... - - class float8_e8m0fnux16(tvm.DataType): - ... - - class float8_e8m0fnux32(tvm.DataType): - ... - - class float8_e8m0fnux64(tvm.DataType): - ... - - class float6_e2m3fn(tvm.DataType): - ... - - class float6_e2m3fnx2(tvm.DataType): - ... - - class float6_e2m3fnx4(tvm.DataType): - ... - - class float6_e2m3fnx8(tvm.DataType): - ... - - class float6_e2m3fnx16(tvm.DataType): - ... - - class float6_e2m3fnx32(tvm.DataType): - ... - - class float6_e2m3fnx64(tvm.DataType): - ... - - class float6_e3m2fn(tvm.DataType): - ... - - class float6_e3m2fnx2(tvm.DataType): - ... - - class float6_e3m2fnx4(tvm.DataType): - ... - - class float6_e3m2fnx8(tvm.DataType): - ... - - class float6_e3m2fnx16(tvm.DataType): - ... - - class float6_e3m2fnx32(tvm.DataType): - ... - - class float6_e3m2fnx64(tvm.DataType): - ... - - class float4_e2m1fn(tvm.DataType): - ... - - class float4_e2m1fnx2(tvm.DataType): - ... - - class float4_e2m1fnx4(tvm.DataType): - ... - - class float4_e2m1fnx8(tvm.DataType): - ... - - class float4_e2m1fnx16(tvm.DataType): - ... - - class float4_e2m1fnx32(tvm.DataType): - ... - - class float4_e2m1fnx64(tvm.DataType): - ... + # yapf: disable + class bool(dtype): ... + class short(dtype): ... + class int(dtype): ... + class long(dtype): ... + class half(dtype): ... + class float(dtype): ... + class double(dtype): ... + class int8(dtype): ... + class int16(dtype): ... + class int32(dtype): ... + class int64(dtype): ... + class int8x4(dtype): ... + class int16x4(dtype): ... + class int32x4(dtype): ... + class int64x4(dtype): ... + class int8x8(dtype): ... + class int16x8(dtype): ... + class int32x8(dtype): ... + class int64x8(dtype): ... + class int8x16(dtype): ... + class int16x16(dtype): ... + class int32x16(dtype): ... + class int64x16(dtype): ... + class int8x32(dtype): ... + class int16x32(dtype): ... + class int32x32(dtype): ... + class int64x32(dtype): ... + class int8x64(dtype): ... + class int16x64(dtype): ... + class int32x64(dtype): ... + class int64x64(dtype): ... + class uint8(dtype): ... + class uint16(dtype): ... + class uint32(dtype): ... + class uint64(dtype): ... + class uint8x4(dtype): ... + class uint16x4(dtype): ... + class uint32x4(dtype): ... + class uint64x4(dtype): ... + class uint8x8(dtype): ... + class uint16x8(dtype): ... + class uint32x8(dtype): ... + class uint64x8(dtype): ... + class uint8x16(dtype): ... + class uint16x16(dtype): ... + class uint32x16(dtype): ... + class uint64x16(dtype): ... + class uint8x32(dtype): ... + class uint16x32(dtype): ... + class uint32x32(dtype): ... + class uint64x32(dtype): ... + class uint8x64(dtype): ... + class uint16x64(dtype): ... + class uint32x64(dtype): ... + class uint64x64(dtype): ... + class float16(dtype): ... + class float32(dtype): ... + class float64(dtype): ... + class float16x2(dtype): ... + class float32x2(dtype): ... + class float64x2(dtype): ... + class float16x4(dtype): ... + class float32x4(dtype): ... + class float64x4(dtype): ... + class float16x8(dtype): ... + class float32x8(dtype): ... + class float64x8(dtype): ... + class float16x16(dtype): ... + class float32x16(dtype): ... + class float64x16(dtype): ... + class float16x32(dtype): ... + class float32x32(dtype): ... + class float64x32(dtype): ... + class float16x64(dtype): ... + class float32x64(dtype): ... + class float64x64(dtype): ... + class float8_e3m4(dtype): ... + class float8_e3m4x2(dtype): ... + class float8_e3m4x4(dtype): ... + class float8_e3m4x8(dtype): ... + class float8_e3m4x16(dtype): ... + class float8_e3m4x32(dtype): ... + class float8_e3m4x64(dtype): ... + class float8_e4m3(dtype): ... + class float8_e4m3x2(dtype): ... + class float8_e4m3x4(dtype): ... + class float8_e4m3x8(dtype): ... + class float8_e4m3x16(dtype): ... + class float8_e4m3x32(dtype): ... + class float8_e4m3x64(dtype): ... + class float8_e4m3b11fnuz(dtype): ... + class float8_e4m3b11fnuzx2(dtype): ... + class float8_e4m3b11fnuzx4(dtype): ... + class float8_e4m3b11fnuzx8(dtype): ... + class float8_e4m3b11fnuzx16(dtype): ... + class float8_e4m3b11fnuzx32(dtype): ... + class float8_e4m3b11fnuzx64(dtype): ... + class float8_e4m3fn(dtype): ... + class float8_e4m3fnx2(dtype): ... + class float8_e4m3fnx4(dtype): ... + class float8_e4m3fnx8(dtype): ... + class float8_e4m3fnx16(dtype): ... + class float8_e4m3fnx32(dtype): ... + class float8_e4m3fnx64(dtype): ... + class float8_e4m3fnuz(dtype): ... + class float8_e4m3fnuzx2(dtype): ... + class float8_e4m3fnuzx4(dtype): ... + class float8_e4m3fnuzx8(dtype): ... + class float8_e4m3fnuzx16(dtype): ... + class float8_e4m3fnuzx32(dtype): ... + class float8_e4m3fnuzx64(dtype): ... + class float8_e5m2(dtype): ... + class float8_e5m2x2(dtype): ... + class float8_e5m2x4(dtype): ... + class float8_e5m2x8(dtype): ... + class float8_e5m2x16(dtype): ... + class float8_e5m2x32(dtype): ... + class float8_e5m2x64(dtype): ... + class float8_e5m2fnuz(dtype): ... + class float8_e5m2fnuzx2(dtype): ... + class float8_e5m2fnuzx4(dtype): ... + class float8_e5m2fnuzx8(dtype): ... + class float8_e5m2fnuzx16(dtype): ... + class float8_e5m2fnuzx32(dtype): ... + class float8_e5m2fnuzx64(dtype): ... + class float8_e8m0fnu(dtype): ... + class float8_e8m0fnux2(dtype): ... + class float8_e8m0fnux4(dtype): ... + class float8_e8m0fnux8(dtype): ... + class float8_e8m0fnux16(dtype): ... + class float8_e8m0fnux32(dtype): ... + class float8_e8m0fnux64(dtype): ... + class float6_e2m3fn(dtype): ... + class float6_e2m3fnx2(dtype): ... + class float6_e2m3fnx4(dtype): ... + class float6_e2m3fnx8(dtype): ... + class float6_e2m3fnx16(dtype): ... + class float6_e2m3fnx32(dtype): ... + class float6_e2m3fnx64(dtype): ... + class float6_e3m2fn(dtype): ... + class float6_e3m2fnx2(dtype): ... + class float6_e3m2fnx4(dtype): ... + class float6_e3m2fnx8(dtype): ... + class float6_e3m2fnx16(dtype): ... + class float6_e3m2fnx32(dtype): ... + class float6_e3m2fnx64(dtype): ... + class float4_e2m1fn(dtype): ... + class float4_e2m1fnx2(dtype): ... + class float4_e2m1fnx4(dtype): ... + class float4_e2m1fnx8(dtype): ... + class float4_e2m1fnx16(dtype): ... + class float4_e2m1fnx32(dtype): ... + class float4_e2m1fnx64(dtype): ... + class bfloat16(dtype): ... + # yapf: enable - class bfloat16(tvm.DataType): - ... else: - bool = tvm.DataType('bool') - short = tvm.DataType('int16') - int = tvm.DataType('int32') - long = tvm.DataType('int64') - half = tvm.DataType('float16') - float = tvm.DataType('float32') - double = tvm.DataType('float64') - int8 = tvm.DataType('int8') - int16 = tvm.DataType('int16') - int32 = tvm.DataType('int32') - int64 = tvm.DataType('int64') - int8x4 = tvm.DataType('int8x4') - int16x4 = tvm.DataType('int16x4') - int32x4 = tvm.DataType('int32x4') - int64x4 = tvm.DataType('int64x4') - int8x8 = tvm.DataType('int8x8') - int16x8 = tvm.DataType('int16x8') - int32x8 = tvm.DataType('int32x8') - int64x8 = tvm.DataType('int64x8') - int8x16 = tvm.DataType('int8x16') - int16x16 = tvm.DataType('int16x16') - int32x16 = tvm.DataType('int32x16') - int64x16 = tvm.DataType('int64x16') - int8x32 = tvm.DataType('int8x32') - int16x32 = tvm.DataType('int16x32') - int32x32 = tvm.DataType('int32x32') - int64x32 = tvm.DataType('int64x32') - int8x64 = tvm.DataType('int8x64') - int16x64 = tvm.DataType('int16x64') - int32x64 = tvm.DataType('int32x64') - int64x64 = tvm.DataType('int64x64') - uint8 = tvm.DataType('uint8') - uint16 = tvm.DataType('uint16') - uint32 = tvm.DataType('uint32') - uint64 = tvm.DataType('uint64') - uint8x4 = tvm.DataType('uint8x4') - uint16x4 = tvm.DataType('uint16x4') - uint32x4 = tvm.DataType('uint32x4') - uint64x4 = tvm.DataType('uint64x4') - uint8x8 = tvm.DataType('uint8x8') - uint16x8 = tvm.DataType('uint16x8') - uint32x8 = tvm.DataType('uint32x8') - uint64x8 = tvm.DataType('uint64x8') - uint8x16 = tvm.DataType('uint8x16') - uint16x16 = tvm.DataType('uint16x16') - uint32x16 = tvm.DataType('uint32x16') - uint64x16 = tvm.DataType('uint64x16') - uint8x32 = tvm.DataType('uint8x32') - uint16x32 = tvm.DataType('uint16x32') - uint32x32 = tvm.DataType('uint32x32') - uint64x32 = tvm.DataType('uint64x32') - uint8x64 = tvm.DataType('uint8x64') - uint16x64 = tvm.DataType('uint16x64') - uint32x64 = tvm.DataType('uint32x64') - uint64x64 = tvm.DataType('uint64x64') - float16 = tvm.DataType('float16') - float32 = tvm.DataType('float32') - float64 = tvm.DataType('float64') - float16x2 = tvm.DataType('float16x2') - float32x2 = tvm.DataType('float32x2') - float64x2 = tvm.DataType('float64x2') - float16x4 = tvm.DataType('float16x4') - float32x4 = tvm.DataType('float32x4') - float64x4 = tvm.DataType('float64x4') - float16x8 = tvm.DataType('float16x8') - float32x8 = tvm.DataType('float32x8') - float64x8 = tvm.DataType('float64x8') - float16x16 = tvm.DataType('float16x16') - float32x16 = tvm.DataType('float32x16') - float64x16 = tvm.DataType('float64x16') - float16x32 = tvm.DataType('float16x32') - float32x32 = tvm.DataType('float32x32') - float64x32 = tvm.DataType('float64x32') - float16x64 = tvm.DataType('float16x64') - float32x64 = tvm.DataType('float32x64') - float64x64 = tvm.DataType('float64x64') - float8_e3m4 = tvm.DataType('float8_e3m4') - float8_e3m4x2 = tvm.DataType('float8_e3m4x2') - float8_e3m4x4 = tvm.DataType('float8_e3m4x4') - float8_e3m4x8 = tvm.DataType('float8_e3m4x8') - float8_e3m4x16 = tvm.DataType('float8_e3m4x16') - float8_e3m4x32 = tvm.DataType('float8_e3m4x32') - float8_e3m4x64 = tvm.DataType('float8_e3m4x64') - float8_e4m3 = tvm.DataType('float8_e4m3') - float8_e4m3x2 = tvm.DataType('float8_e4m3x2') - float8_e4m3x4 = tvm.DataType('float8_e4m3x4') - float8_e4m3x8 = tvm.DataType('float8_e4m3x8') - float8_e4m3x16 = tvm.DataType('float8_e4m3x16') - float8_e4m3x32 = tvm.DataType('float8_e4m3x32') - float8_e4m3x64 = tvm.DataType('float8_e4m3x64') - float8_e4m3b11fnuz = tvm.DataType('float8_e4m3b11fnuz') - float8_e4m3b11fnuzx2 = tvm.DataType('float8_e4m3b11fnuzx2') - float8_e4m3b11fnuzx4 = tvm.DataType('float8_e4m3b11fnuzx4') - float8_e4m3b11fnuzx8 = tvm.DataType('float8_e4m3b11fnuzx8') - float8_e4m3b11fnuzx16 = tvm.DataType('float8_e4m3b11fnuzx16') - float8_e4m3b11fnuzx32 = tvm.DataType('float8_e4m3b11fnuzx32') - float8_e4m3b11fnuzx64 = tvm.DataType('float8_e4m3b11fnuzx64') - float8_e4m3fn = tvm.DataType('float8_e4m3fn') - float8_e4m3fnx2 = tvm.DataType('float8_e4m3fnx2') - float8_e4m3fnx4 = tvm.DataType('float8_e4m3fnx4') - float8_e4m3fnx8 = tvm.DataType('float8_e4m3fnx8') - float8_e4m3fnx16 = tvm.DataType('float8_e4m3fnx16') - float8_e4m3fnx32 = tvm.DataType('float8_e4m3fnx32') - float8_e4m3fnx64 = tvm.DataType('float8_e4m3fnx64') - float8_e4m3fnuz = tvm.DataType('float8_e4m3fnuz') - float8_e4m3fnuzx2 = tvm.DataType('float8_e4m3fnuzx2') - float8_e4m3fnuzx4 = tvm.DataType('float8_e4m3fnuzx4') - float8_e4m3fnuzx8 = tvm.DataType('float8_e4m3fnuzx8') - float8_e4m3fnuzx16 = tvm.DataType('float8_e4m3fnuzx16') - float8_e4m3fnuzx32 = tvm.DataType('float8_e4m3fnuzx32') - float8_e4m3fnuzx64 = tvm.DataType('float8_e4m3fnuzx64') - float8_e5m2 = tvm.DataType('float8_e5m2') - float8_e5m2x2 = tvm.DataType('float8_e5m2x2') - float8_e5m2x4 = tvm.DataType('float8_e5m2x4') - float8_e5m2x8 = tvm.DataType('float8_e5m2x8') - float8_e5m2x16 = tvm.DataType('float8_e5m2x16') - float8_e5m2x32 = tvm.DataType('float8_e5m2x32') - float8_e5m2x64 = tvm.DataType('float8_e5m2x64') - float8_e5m2fnuz = tvm.DataType('float8_e5m2fnuz') - float8_e5m2fnuzx2 = tvm.DataType('float8_e5m2fnuzx2') - float8_e5m2fnuzx4 = tvm.DataType('float8_e5m2fnuzx4') - float8_e5m2fnuzx8 = tvm.DataType('float8_e5m2fnuzx8') - float8_e5m2fnuzx16 = tvm.DataType('float8_e5m2fnuzx16') - float8_e5m2fnuzx32 = tvm.DataType('float8_e5m2fnuzx32') - float8_e5m2fnuzx64 = tvm.DataType('float8_e5m2fnuzx64') - float8_e8m0fnu = tvm.DataType('float8_e8m0fnu') - float8_e8m0fnux2 = tvm.DataType('float8_e8m0fnux2') - float8_e8m0fnux4 = tvm.DataType('float8_e8m0fnux4') - float8_e8m0fnux8 = tvm.DataType('float8_e8m0fnux8') - float8_e8m0fnux16 = tvm.DataType('float8_e8m0fnux16') - float8_e8m0fnux32 = tvm.DataType('float8_e8m0fnux32') - float8_e8m0fnux64 = tvm.DataType('float8_e8m0fnux64') - float6_e2m3fn = tvm.DataType('float6_e2m3fn') - float6_e2m3fnx2 = tvm.DataType('float6_e2m3fnx2') - float6_e2m3fnx4 = tvm.DataType('float6_e2m3fnx4') - float6_e2m3fnx8 = tvm.DataType('float6_e2m3fnx8') - float6_e2m3fnx16 = tvm.DataType('float6_e2m3fnx16') - float6_e2m3fnx32 = tvm.DataType('float6_e2m3fnx32') - float6_e2m3fnx64 = tvm.DataType('float6_e2m3fnx64') - float6_e3m2fn = tvm.DataType('float6_e3m2fn') - float6_e3m2fnx2 = tvm.DataType('float6_e3m2fnx2') - float6_e3m2fnx4 = tvm.DataType('float6_e3m2fnx4') - float6_e3m2fnx8 = tvm.DataType('float6_e3m2fnx8') - float6_e3m2fnx16 = tvm.DataType('float6_e3m2fnx16') - float6_e3m2fnx32 = tvm.DataType('float6_e3m2fnx32') - float6_e3m2fnx64 = tvm.DataType('float6_e3m2fnx64') - float4_e2m1fn = tvm.DataType('float4_e2m1fn') - float4_e2m1fnx2 = tvm.DataType('float4_e2m1fnx2') - float4_e2m1fnx4 = tvm.DataType('float4_e2m1fnx4') - float4_e2m1fnx8 = tvm.DataType('float4_e2m1fnx8') - float4_e2m1fnx16 = tvm.DataType('float4_e2m1fnx16') - float4_e2m1fnx32 = tvm.DataType('float4_e2m1fnx32') - float4_e2m1fnx64 = tvm.DataType('float4_e2m1fnx64') - bfloat16 = tvm.DataType('bfloat16') + bool = dtype('bool') + short = dtype('int16') + int = dtype('int32') + long = dtype('int64') + half = dtype('float16') + float = dtype('float32') + double = dtype('float64') + int8 = dtype('int8') + int16 = dtype('int16') + int32 = dtype('int32') + int64 = dtype('int64') + int8x4 = dtype('int8x4') + int16x4 = dtype('int16x4') + int32x4 = dtype('int32x4') + int64x4 = dtype('int64x4') + int8x8 = dtype('int8x8') + int16x8 = dtype('int16x8') + int32x8 = dtype('int32x8') + int64x8 = dtype('int64x8') + int8x16 = dtype('int8x16') + int16x16 = dtype('int16x16') + int32x16 = dtype('int32x16') + int64x16 = dtype('int64x16') + int8x32 = dtype('int8x32') + int16x32 = dtype('int16x32') + int32x32 = dtype('int32x32') + int64x32 = dtype('int64x32') + int8x64 = dtype('int8x64') + int16x64 = dtype('int16x64') + int32x64 = dtype('int32x64') + int64x64 = dtype('int64x64') + uint8 = dtype('uint8') + uint16 = dtype('uint16') + uint32 = dtype('uint32') + uint64 = dtype('uint64') + uint8x4 = dtype('uint8x4') + uint16x4 = dtype('uint16x4') + uint32x4 = dtype('uint32x4') + uint64x4 = dtype('uint64x4') + uint8x8 = dtype('uint8x8') + uint16x8 = dtype('uint16x8') + uint32x8 = dtype('uint32x8') + uint64x8 = dtype('uint64x8') + uint8x16 = dtype('uint8x16') + uint16x16 = dtype('uint16x16') + uint32x16 = dtype('uint32x16') + uint64x16 = dtype('uint64x16') + uint8x32 = dtype('uint8x32') + uint16x32 = dtype('uint16x32') + uint32x32 = dtype('uint32x32') + uint64x32 = dtype('uint64x32') + uint8x64 = dtype('uint8x64') + uint16x64 = dtype('uint16x64') + uint32x64 = dtype('uint32x64') + uint64x64 = dtype('uint64x64') + float16 = dtype('float16') + float32 = dtype('float32') + float64 = dtype('float64') + float16x2 = dtype('float16x2') + float32x2 = dtype('float32x2') + float64x2 = dtype('float64x2') + float16x4 = dtype('float16x4') + float32x4 = dtype('float32x4') + float64x4 = dtype('float64x4') + float16x8 = dtype('float16x8') + float32x8 = dtype('float32x8') + float64x8 = dtype('float64x8') + float16x16 = dtype('float16x16') + float32x16 = dtype('float32x16') + float64x16 = dtype('float64x16') + float16x32 = dtype('float16x32') + float32x32 = dtype('float32x32') + float64x32 = dtype('float64x32') + float16x64 = dtype('float16x64') + float32x64 = dtype('float32x64') + float64x64 = dtype('float64x64') + float8_e3m4 = dtype('float8_e3m4') + float8_e3m4x2 = dtype('float8_e3m4x2') + float8_e3m4x4 = dtype('float8_e3m4x4') + float8_e3m4x8 = dtype('float8_e3m4x8') + float8_e3m4x16 = dtype('float8_e3m4x16') + float8_e3m4x32 = dtype('float8_e3m4x32') + float8_e3m4x64 = dtype('float8_e3m4x64') + float8_e4m3 = dtype('float8_e4m3') + float8_e4m3x2 = dtype('float8_e4m3x2') + float8_e4m3x4 = dtype('float8_e4m3x4') + float8_e4m3x8 = dtype('float8_e4m3x8') + float8_e4m3x16 = dtype('float8_e4m3x16') + float8_e4m3x32 = dtype('float8_e4m3x32') + float8_e4m3x64 = dtype('float8_e4m3x64') + float8_e4m3b11fnuz = dtype('float8_e4m3b11fnuz') + float8_e4m3b11fnuzx2 = dtype('float8_e4m3b11fnuzx2') + float8_e4m3b11fnuzx4 = dtype('float8_e4m3b11fnuzx4') + float8_e4m3b11fnuzx8 = dtype('float8_e4m3b11fnuzx8') + float8_e4m3b11fnuzx16 = dtype('float8_e4m3b11fnuzx16') + float8_e4m3b11fnuzx32 = dtype('float8_e4m3b11fnuzx32') + float8_e4m3b11fnuzx64 = dtype('float8_e4m3b11fnuzx64') + float8_e4m3fn = dtype('float8_e4m3fn') + float8_e4m3fnx2 = dtype('float8_e4m3fnx2') + float8_e4m3fnx4 = dtype('float8_e4m3fnx4') + float8_e4m3fnx8 = dtype('float8_e4m3fnx8') + float8_e4m3fnx16 = dtype('float8_e4m3fnx16') + float8_e4m3fnx32 = dtype('float8_e4m3fnx32') + float8_e4m3fnx64 = dtype('float8_e4m3fnx64') + float8_e4m3fnuz = dtype('float8_e4m3fnuz') + float8_e4m3fnuzx2 = dtype('float8_e4m3fnuzx2') + float8_e4m3fnuzx4 = dtype('float8_e4m3fnuzx4') + float8_e4m3fnuzx8 = dtype('float8_e4m3fnuzx8') + float8_e4m3fnuzx16 = dtype('float8_e4m3fnuzx16') + float8_e4m3fnuzx32 = dtype('float8_e4m3fnuzx32') + float8_e4m3fnuzx64 = dtype('float8_e4m3fnuzx64') + float8_e5m2 = dtype('float8_e5m2') + float8_e5m2x2 = dtype('float8_e5m2x2') + float8_e5m2x4 = dtype('float8_e5m2x4') + float8_e5m2x8 = dtype('float8_e5m2x8') + float8_e5m2x16 = dtype('float8_e5m2x16') + float8_e5m2x32 = dtype('float8_e5m2x32') + float8_e5m2x64 = dtype('float8_e5m2x64') + float8_e5m2fnuz = dtype('float8_e5m2fnuz') + float8_e5m2fnuzx2 = dtype('float8_e5m2fnuzx2') + float8_e5m2fnuzx4 = dtype('float8_e5m2fnuzx4') + float8_e5m2fnuzx8 = dtype('float8_e5m2fnuzx8') + float8_e5m2fnuzx16 = dtype('float8_e5m2fnuzx16') + float8_e5m2fnuzx32 = dtype('float8_e5m2fnuzx32') + float8_e5m2fnuzx64 = dtype('float8_e5m2fnuzx64') + float8_e8m0fnu = dtype('float8_e8m0fnu') + float8_e8m0fnux2 = dtype('float8_e8m0fnux2') + float8_e8m0fnux4 = dtype('float8_e8m0fnux4') + float8_e8m0fnux8 = dtype('float8_e8m0fnux8') + float8_e8m0fnux16 = dtype('float8_e8m0fnux16') + float8_e8m0fnux32 = dtype('float8_e8m0fnux32') + float8_e8m0fnux64 = dtype('float8_e8m0fnux64') + float6_e2m3fn = dtype('float6_e2m3fn') + float6_e2m3fnx2 = dtype('float6_e2m3fnx2') + float6_e2m3fnx4 = dtype('float6_e2m3fnx4') + float6_e2m3fnx8 = dtype('float6_e2m3fnx8') + float6_e2m3fnx16 = dtype('float6_e2m3fnx16') + float6_e2m3fnx32 = dtype('float6_e2m3fnx32') + float6_e2m3fnx64 = dtype('float6_e2m3fnx64') + float6_e3m2fn = dtype('float6_e3m2fn') + float6_e3m2fnx2 = dtype('float6_e3m2fnx2') + float6_e3m2fnx4 = dtype('float6_e3m2fnx4') + float6_e3m2fnx8 = dtype('float6_e3m2fnx8') + float6_e3m2fnx16 = dtype('float6_e3m2fnx16') + float6_e3m2fnx32 = dtype('float6_e3m2fnx32') + float6_e3m2fnx64 = dtype('float6_e3m2fnx64') + float4_e2m1fn = dtype('float4_e2m1fn') + float4_e2m1fnx2 = dtype('float4_e2m1fnx2') + float4_e2m1fnx4 = dtype('float4_e2m1fnx4') + float4_e2m1fnx8 = dtype('float4_e2m1fnx8') + float4_e2m1fnx16 = dtype('float4_e2m1fnx16') + float4_e2m1fnx32 = dtype('float4_e2m1fnx32') + float4_e2m1fnx64 = dtype('float4_e2m1fnx64') + bfloat16 = dtype('bfloat16') _all_dtypes = [ 'bool', @@ -906,6 +599,7 @@ class bfloat16(tvm.DataType): ] __all__ = _all_dtypes + [ + 'dtype', 'AnyDType', 'get_tvm_dtype', ] From 570be6851c7c05a62c40c5f9402e4333cab558c1 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 3 Nov 2025 12:31:36 +0800 Subject: [PATCH 20/24] add document and examples --- ... => test_tilelang_language_frontend_v2.py} | 167 ++++++++---- tilelang/language/v2/ast.py | 39 ++- tilelang/language/v2/builder.py | 249 ++++++++++++------ tilelang/language/v2/dtypes.py | 6 +- 4 files changed, 320 insertions(+), 141 deletions(-) rename testing/python/language/{test_tilelang_language_dtype.py => test_tilelang_language_frontend_v2.py} (59%) diff --git a/testing/python/language/test_tilelang_language_dtype.py b/testing/python/language/test_tilelang_language_frontend_v2.py similarity index 59% rename from testing/python/language/test_tilelang_language_dtype.py rename to testing/python/language/test_tilelang_language_frontend_v2.py index 2303acba9..c54ab706e 100644 --- a/testing/python/language/test_tilelang_language_dtype.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -53,64 +53,64 @@ def test_expr(): assert not errors -def test_var_decl_sugar(): +# def test_var_decl_sugar(): - @T.prim_func - def test_var_decl_sugar(): - with T.Kernel(128, 128) as (bx, by): - var_1: T.bool = 1.0 - var_2: T.short = 1.0 - var_3: T.int = 1.0 - var_4: T.long = 1.0 - var_5: T.half = 1.0 - var_6: T.float = 1.0 - var_7: T.long = 1.0 - var_8: T.int8 = 1.0 - var_9: T.int16 = 1.0 - var_10: T.int32 = 1.0 - var_11: T.int64 = 1.0 - var_12: T.uint8 = 1.0 - var_13: T.uint16 = 1.0 - var_14: T.uint32 = 1.0 - var_15: T.uint64 = 1.0 - var_16: T.float8_e4m3fn = 1.0 - var_17: T.float8_e4m3fnuz = 1.0 - var_18: T.float8_e5m2 = 1.0 - var_19: T.float8_e5m2fnuz = 1.0 - var_20: T.float8_e8m0fnu = 1.0 - var_21: T.float16 = 1.0 - var_22: T.bfloat16 = 1.0 - var_23: T.float32 = 1.0 - var_24: T.float64 = 1.0 - var_1: T.bool = var_1 - var_2: T.short = var_2 - var_3: T.int = var_3 - var_4: T.long = var_4 - var_5: T.half = var_5 - var_6: T.float = var_6 - var_7: T.long = var_7 - var_8: T.int8 = var_8 - var_9: T.int16 = var_9 - var_10: T.int32 = var_10 - var_11: T.int64 = var_11 - var_12: T.uint8 = var_12 - var_13: T.uint16 = var_13 - var_14: T.uint32 = var_14 - var_15: T.uint64 = var_15 - var_16: T.float8_e4m3fn = var_16 - var_17: T.float8_e4m3fnuz = var_17 - var_18: T.float8_e5m2 = var_18 - var_19: T.float8_e5m2fnuz = var_19 - var_20: T.float8_e8m0fnu = var_20 - var_21: T.float16 = var_21 - var_22: T.bfloat16 = var_22 - var_23: T.float32 = var_23 - var_24: T.float64 = var_24 - - s = test_var_decl_sugar.script() - for i in range(1, 25): - assert f'var_{i}_1' in s - assert 'tl.local_var_init' in s +# @T.prim_func +# def test_var_decl_sugar(): +# with T.Kernel(128, 128) as (bx, by): +# var_1: T.bool = 1.0 +# var_2: T.short = 1.0 +# var_3: T.int = 1.0 +# var_4: T.long = 1.0 +# var_5: T.half = 1.0 +# var_6: T.float = 1.0 +# var_7: T.long = 1.0 +# var_8: T.int8 = 1.0 +# var_9: T.int16 = 1.0 +# var_10: T.int32 = 1.0 +# var_11: T.int64 = 1.0 +# var_12: T.uint8 = 1.0 +# var_13: T.uint16 = 1.0 +# var_14: T.uint32 = 1.0 +# var_15: T.uint64 = 1.0 +# var_16: T.float8_e4m3fn = 1.0 +# var_17: T.float8_e4m3fnuz = 1.0 +# var_18: T.float8_e5m2 = 1.0 +# var_19: T.float8_e5m2fnuz = 1.0 +# var_20: T.float8_e8m0fnu = 1.0 +# var_21: T.float16 = 1.0 +# var_22: T.bfloat16 = 1.0 +# var_23: T.float32 = 1.0 +# var_24: T.float64 = 1.0 +# var_1: T.bool = var_1 +# var_2: T.short = var_2 +# var_3: T.int = var_3 +# var_4: T.long = var_4 +# var_5: T.half = var_5 +# var_6: T.float = var_6 +# var_7: T.long = var_7 +# var_8: T.int8 = var_8 +# var_9: T.int16 = var_9 +# var_10: T.int32 = var_10 +# var_11: T.int64 = var_11 +# var_12: T.uint8 = var_12 +# var_13: T.uint16 = var_13 +# var_14: T.uint32 = var_14 +# var_15: T.uint64 = var_15 +# var_16: T.float8_e4m3fn = var_16 +# var_17: T.float8_e4m3fnuz = var_17 +# var_18: T.float8_e5m2 = var_18 +# var_19: T.float8_e5m2fnuz = var_19 +# var_20: T.float8_e8m0fnu = var_20 +# var_21: T.float16 = var_21 +# var_22: T.bfloat16 = var_22 +# var_23: T.float32 = var_23 +# var_24: T.float64 = var_24 + +# s = test_var_decl_sugar.script() +# for i in range(1, 25): +# assert f'var_{i}_1' in s +# assert 'tl.local_var_init' in s def test_dtype_str_repr(): @@ -198,7 +198,7 @@ def test_torch_eq(): ] for a, b in zip(dtypes, torch_dtypes): assert a == b, f"{a} and {b} are not equal" - assert T.dtype(b) == a, f"dtype convertion error" + assert T.dtype(b) == a, "dtype conversion error" def test_var_assign(): @@ -219,5 +219,56 @@ def test_var_assign(A: T.Tensor((2,), T.int32)): assert res[1] == 2 +def test_marco_return(): + @T.macro + def macro_return_constant(): + return 0 + + @T.macro + def macro_return_frame(x): + return T.alloc_var(T.float32, init=x) + + @T.macro + def macro_return_expr(x): + y = x + 1.0 + return y + + @T.macro + def macro_apply_func(x, fn): + return fn(x) + + def check(x, ty): + assert isinstance(x, ty) + + @T.prim_func + def test_macro_return(): + with T.Kernel(1) as _: + a = macro_return_constant() + b = macro_return_frame(3.0) + c = macro_return_expr(4.0) + d = macro_apply_func(5.0, lambda x: x * 2.0) + check(a, (int, float, T.PrimExpr)) + check(b, T.PrimExpr) + check(c, T.PrimExpr) + check(d, T.PrimExpr) + + +def test_prim_func_generator(): + @T.prim_func(generator=True) + def prim_func_gen( + A=T.Tensor((128,), T.float32), + B=T.Tensor((128,), T.float32), + ): + with T.Kernel(128) as (tx,): + T.copy(A[tx], B[tx]) + pf = prim_func_gen() + + @T.prim_func + def foo() -> T.Tensor((128,), T.float32): + pass + assert isinstance(foo, T.PrimFunc) + + + if __name__ == '__main__': tilelang.testing.main() diff --git a/tilelang/language/v2/ast.py b/tilelang/language/v2/ast.py index 5b23a1d20..34e74d64b 100644 --- a/tilelang/language/v2/ast.py +++ b/tilelang/language/v2/ast.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast -from typing import Callable, ContextManager, Iterable, Any, Literal, ParamSpec, TypeVar +from dataclasses import dataclass +from typing import Callable, ContextManager, Generic, Iterable, Any, Literal, ParamSpec, TypeVar import inspect # from .utils import get_ast, get_compiled_object from . import utils @@ -527,11 +528,41 @@ def visit_Name(self, node: ast.Name): _P = ParamSpec('_P') -def mutate(func: Callable[_P, _T]) -> Callable[[BaseBuilder], Callable[_P, _T]]: +@dataclass +class IRGenerator(Generic[_P, _T]): + gen: Callable[[BaseBuilder], Callable[_P, _T]] + source: str + + +def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: + """ + Transform a Python function into an IR (Intermediate Representation) generator. + This function takes a regular Python function and performs AST (Abstract Syntax Tree) + transformation to create an IRGenerator that can be used for code generation purposes. + Args: + func (Callable[_P, _T]): The Python function to be transformed. This should be a + callable that will be analyzed and mutated at the AST level. The function's + signature is preserved through generic type parameters _P (parameters) and + _T (return type). + Returns: + IRGenerator[_P, _T]: An IRGenerator instance wrapping the transformed function. + The generator contains: + - gen: The compiled and mutated version of the original function + - source: The unparsed source code of the transformed AST as a string + Example: + >>> @mutate + ... def my_function(x: int) -> int: + ... return x * 2 + >>> # my_function is now an IRGenerator that can be used for code generation + Note: + - The original function's closure variables and captured context are preserved + - The transformation is performed at compile-time through AST manipulation + - The returned IRGenerator maintains type information from the original function + """ + tree = utils.get_ast(func) filename = inspect.getsourcefile(func) or inspect.getfile(func) tree = DSLMutator().visit(tree) fn = utils.get_compiled_object(tree, func.__name__, filename, utils.inspect_function_capture(func)) - fn.__source__ = ast.unparse(tree) - return fn + return IRGenerator(gen=fn, source=ast.unparse(tree)) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 9f85dcd34..3bae9ecd1 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -7,14 +7,13 @@ from tilelang.language.kernel import KernelLaunchFrame from tvm_ffi.container import Map from tvm.ir.base import Span -from .ast import BaseBuilder, eval_op, mutate +from .ast import BaseBuilder, IRGenerator, eval_op, mutate import tvm from tvm.tir import Buffer from tvm.script.ir_builder import tir, IRBuilder from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var -from typing import TYPE_CHECKING, Callable, ContextManager, Any, Generic, Hashable, ParamSpec, Self, TypeVar, ForwardRef +from typing import TYPE_CHECKING, Callable, ContextManager, Any, Generic, ParamSpec, Self, TypeVar, ForwardRef from . import dtypes as dt -from types import EllipsisType import threading import logging @@ -50,7 +49,9 @@ def unwrap_cond(expr): else: logger.warning( f"Python expression `{expr}` is used as condition in TileLang, \n" - "this is treated as a constant expression. ", stack_info=True, stacklevel=3) + "this is treated as a constant expression. ", + stack_info=True, + stacklevel=3) return bool(expr) @@ -58,6 +59,10 @@ def unwrap_cond(expr): class Frame: + ''' + Frame are virtual context managers used in frontend only + They do not have any runtime representation in the generated TIR. + ''' def __enter__(self): ... @@ -85,9 +90,11 @@ class BlockFrame(Frame): class ContinueFrame(Frame): ... + class BreakFrame(Frame): ... + ContinueOrBreak = ContinueFrame | BreakFrame AnyFrame = tir.frame.IRBuilderFrame | Frame @@ -163,8 +170,7 @@ def check_continue_break(self): logger.warning( 'Writing code after continue/break may cause undefined behavior in tilelang.', stack_info=True, - stacklevel=3 - ) + stacklevel=3) @contextmanager def with_frame(self, frame: ContextManager | None): @@ -211,7 +217,7 @@ def eval(self, val: Any): elif isinstance(val, tir.frame.IRBuilderFrame): if isinstance(val, tir.frame.ForFrame): logger.warning( - f'Evaluating a for frame to may cause undefined behavior in tilelang.', + 'Evaluating a for frame may cause undefined behavior in tilelang.', stack_info=True, stacklevel=1, ) @@ -239,13 +245,15 @@ def ctx_for(self, it): def ctx_continue(self): self.check_continue_break() + # add a dummy frame for checking code after continue/break self.enter_frame(ContinueFrame()) - raise RuntimeError("continue is not supported in TileLang builder") + tir.evaluate(tir.continue_loop()) def ctx_break(self): self.check_continue_break() + # add a dummy frame for checking code after continue/break self.enter_frame(BreakFrame()) - raise RuntimeError("break is not supported in TileLang builder") + tir.evaluate(tir.break_loop()) def ctx_while(self, cond): self.check_continue_break() @@ -256,19 +264,20 @@ def bind(self, name, value, annot=BaseBuilder.empty): locals = self.get_parent_locals() orig_value = locals.get(name, None) # annotation like tl.float32 - if callable(annot): - annot_val = annot() - if isinstance(annot_val, tir.Var): - orig_value = tir.alloc_buffer((1,), dtype=annot_val.dtype, scope='local.var') - IRBuilder.name(name, orig_value) - if isinstance(value, EllipsisType) or value is self.empty: - return orig_value - elif isinstance(value, (int, float, IntImm, FloatImm)): - tir.block_attr( - {'tl.local_var_init': { - orig_value.data: tvm.runtime.convert(value) - }}) - return orig_value + # temporarily disable annotation based var declaration, for better pull request separation + # if callable(annot): + # annot_val = annot() + # if isinstance(annot_val, tir.Var): + # orig_value = tir.alloc_buffer((1,), dtype=annot_val.dtype, scope='local.var') + # IRBuilder.name(name, orig_value) + # if isinstance(value, EllipsisType) or value is self.empty: + # return orig_value + # elif isinstance(value, (int, float, IntImm, FloatImm)): + # tir.block_attr( + # {'tl.local_var_init': { + # orig_value.data: tvm.runtime.convert(value) + # }}) + # return orig_value # if orig_value is a local.var, we use buffer_store to modify it immutably # however, if rvalue is also a local.var, this is a new binding, # we should not use buffer_store, and bind it instead @@ -288,7 +297,7 @@ def bind(self, name, value, annot=BaseBuilder.empty): assert frame is not None, f"Variable `{name}` is not defined inside any control flow." if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames: logger.warning( - f'Variable `{name}` shadows another declared value, Are you forgetting to use alloc it as a var?', + f'Variable `{name}` shadows another declared value, Are you forgetting to allocate it as a var?', stack_info=True, stacklevel=2, ) @@ -309,7 +318,7 @@ def bind_immutable(self, name, value): elif isinstance(value, tir.frame.IRBuilderFrame): if isinstance(value, tir.frame.ForFrame): logger.warning( - f'Binding a for frame to variable may cause undefined behavior in tilelang.', + 'Binding a for frame to variable may cause undefined behavior in tilelang.', stack_info=True, stacklevel=2, ) @@ -390,7 +399,7 @@ def ret(self, value): "```\n" "@T.macro\n" \ "def my_macro(cond):\n" - " a: T.float16 = ...\n" + " a = T.alloc_var(T.float16)\n" " if cond:\n" " a = 1.0\n" " return a\n" @@ -425,11 +434,10 @@ def rval(self, name: str, value: Any) -> Any: def arg(self, name, value): if self.find_frame_idx(MacroFrame) is not None: - if isinstance(value, (str, StringImm)): - # this is a workaround for string argument in macro - return value - else: + if isinstance(value, (PrimExpr, int, float)): return self.bind(name, value) + else: + return value if isinstance(value, (Buffer, Var)): return tir.arg(name, value) elif value is self.empty: @@ -437,7 +445,8 @@ def arg(self, name, value): # elif isinstance(value, Hashable): # return value else: - raise TypeError(f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.") + raise TypeError( + f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.") def override(self, name: str): if name == 'range': @@ -448,13 +457,6 @@ def override(self, name: str): _P = ParamSpec('_P') _T = TypeVar('_T') - -@dataclass -class IRGenerator(Generic[_P, _T]): - gen: Callable[[BaseBuilder], Callable[_P, _T]] - source: str - - if TYPE_CHECKING: class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): @@ -488,14 +490,42 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: return res -def build_ir_generator(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: - ir_gen = mutate(func) - ir_gen = IRGenerator(gen=ir_gen, source=ir_gen.__source__) - return ir_gen - - -def macro(func: Callable[_P, _T]) -> Macro[_P, _T]: - return Macro(name=func.__name__, orig_func=func, ir_gen=build_ir_generator(func)) +def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]: + """ + Decorator that converts a Python function into a TileLang macro. + TileLang macro is very similar to PrimFunc, it can be used in prim_func or another macro. + Parameters + ---------- + func : Callable[_P, _T] + The Python function to be converted into a macro. This function will be analyzed + and transformed into an IR generation function. The function can take any parameters + (_P) and return any type (_T). + Returns + ------- + Macro[_P, _T] + A Macro object that wraps the original function with IR generation capabilities. + The returned Macro preserves the original function's signature (parameters _P and + return type _T) while adding metaprogramming capabilities. + Example: + -------- + >>> @macro + ... def my_macro(x: T.int32) -> T.int32: + ... return x ** 2 + >>> @prim_func + ... def my_func(A: T.Tensor((10,), T.int32), B: T.Tensor((10,), T.int32)): + ... with T.Kernel(1) as _: + ... for i in T.serial(10): + ... B[i] = my_macro(A[i]) + See Also + -------- + Macro : The class that wraps macro functions + mutate : The function that transforms Python code into IR generators + """ + + def impl(func: Callable[_P, _T]) -> Macro[_P, _T]: + return Macro(name=func.__name__, orig_func=func, ir_gen=mutate(func)) + + return impl(func) if func is not None else impl from typing import _eval_type @@ -518,12 +548,15 @@ def get_type_hints(func): if value is None: value = type(None) if isinstance(value, str): + # this branch handles T.float32 style annotation + # since they are string, directly evaluating them usually causes NameError + # so we need to split and evaluate them separately _, v = value.split('.', maxsplit=1) if v in dt._all_dtypes: try: hints[name] = eval(value, globalns, localns) continue - except Exception as e: + except Exception: pass value = ForwardRef(value, is_argument=True, is_class=False) hints[name] = _eval_type(value, globalns=globalns, localns=localns, type_params=type_params) @@ -534,33 +567,97 @@ def _is_static_annot(annot: Any) -> bool: return isinstance(annot, (dt.dtype, Buffer, Var)) -def prim_func(func: Callable[_P, _T]) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]: - sig = inspect.signature(func) - annot = get_type_hints(func) - - for k in annot: - if callable(annot[k]): - annot[k] = annot[k]() - - all_arg_annotated = all([x in annot for x in sig.parameters]) - all_annot_are_static = all([_is_static_annot(x) for x in annot.values()]) - ir_gen = build_ir_generator(func) - - def prim_func_generator(*args, **kwargs): - builder = Builder() - with builder.prim_func(func.__name__): - ir_gen.gen(builder)(*args, **kwargs) - res = builder.get() - res.ir_gen = ir_gen - res.source = ir_gen.source - res.orig_func = func - return res - - prim_func_generator.ir_gen = ir_gen - prim_func_generator.source = ir_gen.source - prim_func_generator.orig_func = func - - if all_arg_annotated and all_annot_are_static: - return prim_func_generator(**annot) - else: - return prim_func_generator +def prim_func(func: Callable[_P, _T] = None, + *, + generator: bool = False) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]: + """ + Decorator to create a primitive function (PrimFunc) for TileLang IR generation. + This decorator transforms a Python function into a TileLang primitive function by analyzing + its type annotations and generating intermediate representation (IR) code. It supports both + immediate construction (when all parameters are statically annotated) and generator mode + (for dynamic construction). + Parameters + ---------- + func : Callable[_P, _T], optional + The function to be decorated. Can be None when using decorator with arguments. + generator : bool, default=False + If True, returns a generator function that creates PrimFunc instances on demand. + If False, attempts to create a PrimFunc immediately using type annotations. + Returns + ------- + PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]] + - If `generator=False` and all parameters are statically annotated: returns a PrimFunc instance + - If `generator=True`: returns a callable that generates PrimFunc instances when invoked + - If used without parentheses: returns the decorator implementation function + Examples + -------- + Static annotation mode (immediate construction): + >>> @prim_func + ... def add_kernel(A: T.Buffer((128,), T.float32), + ... B: T.Buffer((128,), T.float32)): + ... for i in T.grid(128): + ... B[i] = A[i] + 1.0 + Generator mode (dynamic construction): + >>> @prim_func(generator=True) + ... def dynamic_kernel(A=T.Tensor((128,), T.float32)): + ... # function body + ... pass + >>> kernel_instance = dynamic_kernel() + With custom parameters: + >>> @prim_func(generator=True) + ... def parameterized_kernel(size: int = 128): + ... # function body using size parameter + ... pass + >>> kernel = parameterized_kernel(size=256) + See Also + -------- + Builder : The IR builder class used for constructing primitive functions + mutate : Function used to generate IR from the decorated function + """ + + def impl(func: Callable[_P, _T]) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]: + sig = inspect.signature(func) + annot = get_type_hints(func) + + for k in annot: + if callable(annot[k]): + annot[k] = annot[k]() + + # check whether all arguments are annotated + all_arg_annotated = all([x in annot for x in sig.parameters]) + # check whether all annotations are Buffer/Var/dtype + all_annot_are_static = all([_is_static_annot(x) for x in annot.values()]) + ir_gen = mutate(func) + + def prim_func_generator(*args, **kwargs): + builder = Builder() + with builder.prim_func(func.__name__): + ir_gen.gen(builder)(*args, **kwargs) + res = builder.get() + res.ir_gen = ir_gen + res.source = ir_gen.source + res.orig_func = func + return res + + prim_func_generator.ir_gen = ir_gen + prim_func_generator.source = ir_gen.source + prim_func_generator.orig_func = func + + if generator: + return prim_func_generator + + if all_arg_annotated and all_annot_are_static: + return prim_func_generator(**annot) + else: + raise ValueError( + "Some arguments are not supported or statically annotated, \n" + "please check the annotations or set generator=True to get a prim_func generator.\n" + f"Argument Annotations: {annot}\n" + "Example usage of generator:\n" + "```py\n" + "@prim_func(generator=True)\n" + "def my_func(a=T.Tensor((128,), T.float32)): ...\n" + "return my_func()\n" + "```") + + return impl(func) if func is not None else impl diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index f5413874a..def59845b 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -441,7 +441,7 @@ class bfloat16(dtype): ... float4_e2m1fnx64 = dtype('float4_e2m1fnx64') bfloat16 = dtype('bfloat16') -_all_dtypes = [ +_all_dtypes = { 'bool', 'short', 'int', @@ -596,9 +596,9 @@ class bfloat16(dtype): ... 'float4_e2m1fnx32', 'float4_e2m1fnx64', 'bfloat16', -] +} -__all__ = _all_dtypes + [ +__all__ = list(_all_dtypes) + [ 'dtype', 'AnyDType', 'get_tvm_dtype', From 5c80575e15841cdfa30debc0cee88bdff17b9a04 Mon Sep 17 00:00:00 2001 From: Freebase6912 Date: Mon, 3 Nov 2025 12:40:01 +0800 Subject: [PATCH 21/24] fix lint error --- .../language/test_tilelang_language_frontend_v2.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index c54ab706e..b4ca94232 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -220,6 +220,7 @@ def test_var_assign(A: T.Tensor((2,), T.int32)): def test_marco_return(): + @T.macro def macro_return_constant(): return 0 @@ -254,20 +255,22 @@ def test_macro_return(): def test_prim_func_generator(): + @T.prim_func(generator=True) def prim_func_gen( - A=T.Tensor((128,), T.float32), - B=T.Tensor((128,), T.float32), + A=T.Tensor((128,), T.float32), # noqa: B008 + B=T.Tensor((128,), T.float32), # noqa: B008 ): with T.Kernel(128) as (tx,): T.copy(A[tx], B[tx]) - pf = prim_func_gen() + + prim_func_gen() @T.prim_func def foo() -> T.Tensor((128,), T.float32): pass - assert isinstance(foo, T.PrimFunc) + assert isinstance(foo, T.PrimFunc) if __name__ == '__main__': From a507ba497e39ec900477f1c7fd3bd8fb1f05857a Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:28:38 +0800 Subject: [PATCH 22/24] Simplify index calculations in example_chunk_o_bwd.py Refactor index calculations for dg_last_fragment assignment. --- examples/gdn/example_chunk_o_bwd.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index 3f69b6b68..ce2671115 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -7,8 +7,6 @@ import tilelang.language as T from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 -print(tilelang.__file__) - # Add your fla repository path to sys.path # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae # sys.path.insert(0, "/home/tzj/flash-linear-attention") @@ -256,8 +254,7 @@ def kernel( # for i_kv in T.Parallel(block_DK * block_DV): # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] for i_kv in T.Parallel(block_DK * block_DV): - i_k, i_v_1 = i_kv // block_DV, i_kv % block_DV - dg_last_fragment[i_kv] = h_shared[i_k, i_v_1] * dh_shared[i_k, i_v_1] + dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) dg_last_local[0] += dg_last_fragment_scalar[0] From e09c1b77a725c45cfe2728b0a73c149c9a44096b Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Mon, 3 Nov 2025 16:54:38 +0800 Subject: [PATCH 23/24] minor fix --- .../test_tilelang_transform_legalize_safe_memory_access.py | 2 +- tilelang/language/symbolics.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py index e5215db25..5202ab647 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py +++ b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py @@ -41,7 +41,7 @@ def assert_vectorize_access(M: int = 64, N: int = 64): def issue_1013_buggy_kernel(): # NOTE: This kernel is mainly to test some corner cases in boundary check - num_tokens = T.Var('num_tokens', 'int32') + num_tokens = T.dynamic('num_tokens') num_threads = 128 @T.prim_func diff --git a/tilelang/language/symbolics.py b/tilelang/language/symbolics.py index 92b9d5bab..928edf82c 100644 --- a/tilelang/language/symbolics.py +++ b/tilelang/language/symbolics.py @@ -7,7 +7,6 @@ __all__ = ["dynamic", "symbolic"] -@deprecated("T.dynamic(...)", "tir.Var(...)", "v0.1.9") def dynamic(name: str, dtype: str = "int32"): """ Create a TIR dynamic symbolic variable. @@ -22,7 +21,7 @@ def dynamic(name: str, dtype: str = "int32"): return tir.Var(name, dtype) -@deprecated("T.symbolic(...)", "T.dynamic(...)") +@deprecated("T.symbolic(...)", "T.dynamic(...)", "v0.1.9") def symbolic(name: str, dtype: str = "int32"): """Deprecated alias for `T.dynamic`.""" return tir.Var(name, dtype) From 7fe3c2d324debd63fd4b6ca0ded7364a78a87f1f Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Mon, 3 Nov 2025 17:27:13 +0800 Subject: [PATCH 24/24] lint fix --- examples/gdn/example_chunk_o_bwd.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index ce2671115..7e87a2c4f 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -254,7 +254,9 @@ def kernel( # for i_kv in T.Parallel(block_DK * block_DV): # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] for i_kv in T.Parallel(block_DK * block_DV): - dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] + dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % + block_DV] * dh_shared[i_kv // block_DV, + i_kv % block_DV] T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) dg_last_local[0] += dg_last_fragment_scalar[0]