Merged

30 commits
b5f36ad
tilelang frontend v2
kurisu6912 Oct 24, 2025
9d6659c
syntax sugar: defining a local var by annotation
kurisu6912 Oct 27, 2025
f887fb4
Merge branch 'main' into frontend-v2
kurisu6912 Oct 27, 2025
09d8aec
[Refactor] fix type linting warning like `T.float32`
kurisu6912 Oct 27, 2025
4c75e85
Add tl.local_var_init for new tl.float32
kurisu6912 Oct 27, 2025
8dce258
allow passing default argument as function annotation
kurisu6912 Oct 27, 2025
e3815c6
allow default arguments as annotation
kurisu6912 Oct 27, 2025
66ee7b2
fix lint error
kurisu6912 Oct 27, 2025
7745736
minor fix
kurisu6912 Oct 27, 2025
65bd4fc
[Refactor] refactor tilelang.jit and tilelang.autotune
kurisu6912 Oct 27, 2025
d592fbf
minor fix
kurisu6912 Oct 27, 2025
eff7916
minor fix
kurisu6912 Oct 27, 2025
6f69f02
minor fix
kurisu6912 Oct 27, 2025
20feef2
fix metal get function name
kurisu6912 Oct 28, 2025
015416b
add par_compile impl and tests
kurisu6912 Oct 28, 2025
61bfbdd
Merge branch 'refactor-jit-autotune' into frontend-v2
kurisu6912 Oct 28, 2025
a7e2027
Type consistency on tvm datatype
kurisu6912 Oct 28, 2025
4d0bc85
fix lint error
kurisu6912 Oct 28, 2025
0dfe4e3
add more warning in frontend
kurisu6912 Oct 29, 2025
f8a6f32
Merge branch 'main' of https://github.com/tile-ai/tilelang into front…
LeiWang1999 Oct 31, 2025
b29da36
update tvm version
kurisu6912 Oct 31, 2025
6eef76c
Merge branch 'main' into frontend-v2
kurisu6912 Oct 31, 2025
f1be506
Minor fix on tvm_ffi annotations
kurisu6912 Oct 31, 2025
0206356
Merge branch 'frontend-v2' of https://github.com/kurisu6912/tilelang …
kurisu6912 Oct 31, 2025
570be68
add document and examples
kurisu6912 Nov 3, 2025
5c80575
fix lint error
kurisu6912 Nov 3, 2025
f5a160e
Merge branch 'main' into frontend-v2
kurisu6912 Nov 3, 2025
a507ba4
Simplify index calculations in example_chunk_o_bwd.py
LeiWang1999 Nov 3, 2025
e09c1b7
minor fix
LeiWang1999 Nov 3, 2025
7fe3c2d
lint fix
LeiWang1999 Nov 3, 2025
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 5bf17a to 9cda9b
4 changes: 2 additions & 2 deletions examples/gdn/example_chunk_o_bwd.py
@@ -256,8 +256,8 @@ def kernel(
# for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV):
i_k, i_v = i_kv // block_DV, i_kv % block_DV
dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v]
i_k, i_v_1 = i_kv // block_DV, i_kv % block_DV
dg_last_fragment[i_kv] = h_shared[i_k, i_v_1] * dh_shared[i_k, i_v_1]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0]

74 changes: 74 additions & 0 deletions testing/python/jit/test_tilelang_jit_parcompile.py
@@ -0,0 +1,74 @@
import tilelang.testing
import tilelang
import torch


@tilelang.jit(
out_idx=-1, # create the output tensor during runtime
verbose=True,
)
def matmul_kernel_jit(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A=False,
trans_B=True,
in_dtype='float16',
out_dtype='float32',
accum_dtype='float32',
num_stages=2,
threads=128,
):
Comment on lines +6 to +24

🛠️ Refactor suggestion | 🟠 Major

Make the test portable across CUDA/MPS and select the correct backend for Metal.

The test currently hard-codes CUDA and the default execution backend; Metal requires execution_backend="torch". Choose the device/backend at import time and skip the test if neither is available.

+import pytest
 import tilelang.testing
 import tilelang
 import torch
 
-@tilelang.jit(
-    out_idx=-1,  # create the output tensor during runtime
-    verbose=True,
-)
+# Device/backend selection
+USE_CUDA = torch.cuda.is_available()
+USE_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+DEVICE = "cuda" if USE_CUDA else ("mps" if USE_MPS else None)
+EXEC_BACKEND = "torch" if (USE_MPS and not USE_CUDA) else "cython"
+
+@tilelang.jit(
+    out_idx=-1,  # create the output tensor during runtime
+    verbose=True,
+    execution_backend=EXEC_BACKEND,
+)
 def matmul_kernel_jit(
@@
 def test_par_compile():
+    if DEVICE is None:
+        pytest.skip("No CUDA or MPS device available for JIT test.")
     configs = [
         (1024, 1024, 1024, 128, 128, 32),
         (2048, 2048, 2048, 256, 256, 64),
         (4096, 4096, 4096, 64, 64, 128),
     ]
     kernels = matmul_kernel_jit.par_compile(configs)
     for (M, N, K, _, _, _), kernel in zip(configs, kernels):
-        A = torch.randn(M, K, dtype=torch.float16).cuda()
-        B = torch.randn(N, K, dtype=torch.float16).cuda()
+        A = torch.randn(M, K, dtype=torch.float16, device=DEVICE)
+        B = torch.randn(N, K, dtype=torch.float16, device=DEVICE)
         ref = (A @ B.T).float()
         C = kernel(A, B)
         tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)

Also applies to: 58-71

A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

import tilelang.language as T

@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])

return main


def test_par_compile():
configs = [
(1024, 1024, 1024, 128, 128, 32),
(2048, 2048, 2048, 256, 256, 64),
(4096, 4096, 4096, 64, 64, 128),
]
kernels = matmul_kernel_jit.par_compile(configs)
for (M, N, K, _, _, _), kernel in zip(configs, kernels):
A = torch.randn(M, K, dtype=torch.float16).cuda()
B = torch.randn(N, K, dtype=torch.float16).cuda()
ref = (A @ B.T).float()
C = kernel(A, B)
tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
tilelang.testing.main()
222 changes: 222 additions & 0 deletions testing/python/language/test_tilelang_language_dtype.py
@@ -0,0 +1,222 @@
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import tvm


def test_argument():

@T.prim_func
def test_argument(
t_1: T.bool,
t_2: T.short,
t_3: T.int,
t_4: T.long,
t_5: T.half,
t_6: T.float,
t_7: T.long,
t_8: T.int8,
t_9: T.int16,
t_10: T.int32,
t_11: T.int64,
t_12: T.uint8,
t_13: T.uint16,
t_14: T.uint32,
t_15: T.uint64,
t_16: T.float8_e4m3fn,
t_17: T.float8_e4m3fnuz,
t_18: T.float8_e5m2,
t_19: T.float8_e5m2fnuz,
t_20: T.float8_e8m0fnu,
t_21: T.float16,
t_22: T.bfloat16,
t_23: T.float32,
t_24: T.float64,
):
pass


def test_expr():
from tilelang.language.v2.dtypes import _all_dtypes
errors = []
for name in _all_dtypes:
dtype = getattr(T, name)
assert isinstance(dtype, tvm.DataType), f"{dtype} is not tvm.DataType"
try:
dtype(1.0)
dtype()
except TypeError:
pass
except Exception:
errors.append(name)
assert not errors


def test_var_decl_sugar():

@T.prim_func
def test_var_decl_sugar():
with T.Kernel(128, 128) as (bx, by):
var_1: T.bool = 1.0
var_2: T.short = 1.0
var_3: T.int = 1.0
var_4: T.long = 1.0
var_5: T.half = 1.0
var_6: T.float = 1.0
var_7: T.long = 1.0
var_8: T.int8 = 1.0
var_9: T.int16 = 1.0
var_10: T.int32 = 1.0
var_11: T.int64 = 1.0
var_12: T.uint8 = 1.0
var_13: T.uint16 = 1.0
var_14: T.uint32 = 1.0
var_15: T.uint64 = 1.0
var_16: T.float8_e4m3fn = 1.0
var_17: T.float8_e4m3fnuz = 1.0
var_18: T.float8_e5m2 = 1.0
var_19: T.float8_e5m2fnuz = 1.0
var_20: T.float8_e8m0fnu = 1.0
var_21: T.float16 = 1.0
var_22: T.bfloat16 = 1.0
var_23: T.float32 = 1.0
var_24: T.float64 = 1.0
var_1: T.bool = var_1
var_2: T.short = var_2
var_3: T.int = var_3
var_4: T.long = var_4
var_5: T.half = var_5
var_6: T.float = var_6
var_7: T.long = var_7
var_8: T.int8 = var_8
var_9: T.int16 = var_9
var_10: T.int32 = var_10
var_11: T.int64 = var_11
var_12: T.uint8 = var_12
var_13: T.uint16 = var_13
var_14: T.uint32 = var_14
var_15: T.uint64 = var_15
var_16: T.float8_e4m3fn = var_16
var_17: T.float8_e4m3fnuz = var_17
var_18: T.float8_e5m2 = var_18
var_19: T.float8_e5m2fnuz = var_19
var_20: T.float8_e8m0fnu = var_20
var_21: T.float16 = var_21
var_22: T.bfloat16 = var_22
var_23: T.float32 = var_23
var_24: T.float64 = var_24

s = test_var_decl_sugar.script()
for i in range(1, 25):
assert f'var_{i}_1' in s
assert 'tl.local_var_init' in s


def test_dtype_str_repr():

@T.prim_func
def test_str_repr():
buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') # noqa F841
buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') # noqa F841
buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') # noqa F841
buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') # noqa F841
buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') # noqa F841
buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') # noqa F841
buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') # noqa F841
buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') # noqa F841
buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') # noqa F841
buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') # noqa F841
buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') # noqa F841
buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') # noqa F841
buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') # noqa F841
buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') # noqa F841
buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') # noqa F841
buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') # noqa F841
buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') # noqa F841
buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') # noqa F841
buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') # noqa F841
buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') # noqa F841
buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') # noqa F841
buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841


def test_torch_eq():
dtypes = [
T.bool,
T.short,
T.int,
T.long,
T.half,
T.float,
T.long,
T.int8,
T.int16,
T.int32,
T.int64,
T.uint8,
T.uint16,
T.uint32,
T.uint64,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.float8_e8m0fnu,
T.float16,
T.bfloat16,
T.float32,
T.float64,
]
torch_dtypes = [
torch.bool,
torch.short,
torch.int,
torch.long,
torch.half,
torch.float,
torch.long,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]
for a, b in zip(dtypes, torch_dtypes):
assert a == b, f"{a} and {b} are not equal"


def test_var_assign():

@tilelang.jit(out_idx=-1)
@T.prim_func
def test_var_assign(A: T.Tensor((2,), T.int32)):
with T.Kernel(1) as _:
a: T.int32 = 1
b: T.int32 = a
a = 2
d: T.int32 = a
A[0] = b
A[1] = d

res = test_var_assign()()
assert res[0] == 1
assert res[1] == 2


if __name__ == '__main__':
tilelang.testing.main()
@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n")
K = tvm.te.var("k")

@tvm.script.ir.ir_module
class Before:
def before():

@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
@@ -38,8 +37,9 @@ def main(B: T.Tensor((K, N), dtype),):
(block_N // vec_load_b) * (block_N // vec_load_b) + vec],
T.float16(0))

@tvm.script.ir.ir_module
class After:
return tvm.IRModule({'main': main})

def after():

@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
@@ -77,11 +77,13 @@ def main(B: T.Tensor((K, N), dtype),):
bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0))

return tvm.IRModule({'main': main})

with tvm.target.Target(auto_target):
mod = tvm.tir.transform.BindTarget(auto_target)(Before)
mod = tvm.tir.transform.BindTarget(auto_target)(before())
mod = tl.transform.LayoutInference()(mod)
mod = tvm.tir.transform.Simplify()(mod)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
ref_mod = tvm.tir.transform.Simplify()(ref_mod)
# Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass
# This loop is "for vec in T.parallel(1)",
@@ -41,7 +41,7 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
def issue_1013_buggy_kernel():
# NOTE: This kernel is mainly to test some corner cases in boundary check

num_tokens = T.dynamic('num_tokens')
num_tokens = T.Var('num_tokens', 'int32')
num_threads = 128

@T.prim_func
@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n")
K = tvm.te.var("k")

@tvm.script.ir.ir_module
class Before:
def before():

@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
@@ -25,8 +24,9 @@ def main(B: T.Tensor((K, N), dtype),):
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(B[k * block_K, bx * block_N], B_shared)

@tvm.script.ir.ir_module
class After:
return tvm.IRModule({'main': main})

def after():

@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
@@ -64,11 +64,13 @@ def main(B: T.Tensor((K, N), dtype),):
bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0))

return tvm.IRModule({'main': main})

with tvm.transform.PassContext():
mod = tvm.tir.transform.BindTarget(auto_target)(Before)
mod = tvm.tir.transform.BindTarget(auto_target)(before())
mod = tl.transform.LowerTileOp()(mod)
mod = tvm.tir.transform.Simplify()(mod)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
ref_mod = tvm.tir.transform.Simplify()(ref_mod)
# Note(tzj): The structures are equal except the argument in "T.reads" function.
# The difference is just between the first index and the indices range, which is totally equivalent
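The transform-test diffs above replace the decorated @tvm.script.ir.ir_module class Before: / class After: containers with plain before() / after() functions that build the prim_func and return an IRModule explicitly. A minimal, self-contained sketch of the resulting pattern (the shapes and kernel body here are illustrative, not taken from the tests):

import tvm
import tilelang.language as T


def before():

    @T.prim_func
    def main(B: T.Tensor((128, 128), "float16")):
        with T.Kernel(1) as _:
            B_shared = T.alloc_shared((128, 128), "float16")
            T.copy(B[0, 0], B_shared)

    return tvm.IRModule({"main": main})


# The tests then bind a target to the freshly built module before running passes:
mod = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))(before())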