Merged
Changes from all commits
30 commits
b5f36ad
tilelang frontend v2
kurisu6912 Oct 24, 2025
9d6659c
syntax sugar: defining a local var by annotation
kurisu6912 Oct 27, 2025
f887fb4
Merge branch 'main' into frontend-v2
kurisu6912 Oct 27, 2025
09d8aec
[Refactor] fix type linting warning like `T.float32`
kurisu6912 Oct 27, 2025
4c75e85
Add tl.local_var_init for new tl.float32
kurisu6912 Oct 27, 2025
8dce258
allow passing default argument as function annotation
kurisu6912 Oct 27, 2025
e3815c6
allow default arguments as annotation
kurisu6912 Oct 27, 2025
66ee7b2
fix lint error
kurisu6912 Oct 27, 2025
7745736
minor fix
kurisu6912 Oct 27, 2025
65bd4fc
[Refactor] refactor tilelang.jit and tilelang.autotune
kurisu6912 Oct 27, 2025
d592fbf
minor fix
kurisu6912 Oct 27, 2025
eff7916
minor fix
kurisu6912 Oct 27, 2025
6f69f02
minor fix
kurisu6912 Oct 27, 2025
20feef2
fix metal get function name
kurisu6912 Oct 28, 2025
015416b
add par_compile impl and tests
kurisu6912 Oct 28, 2025
61bfbdd
Merge branch 'refactor-jit-autotune' into frontend-v2
kurisu6912 Oct 28, 2025
a7e2027
Type consistency on tvm datatype
kurisu6912 Oct 28, 2025
4d0bc85
fix lint error
kurisu6912 Oct 28, 2025
0dfe4e3
add more warning in frontend
kurisu6912 Oct 29, 2025
f8a6f32
Merge branch 'main' of https://github.com/tile-ai/tilelang into front…
LeiWang1999 Oct 31, 2025
b29da36
update tvm version
kurisu6912 Oct 31, 2025
6eef76c
Merge branch 'main' into frontend-v2
kurisu6912 Oct 31, 2025
f1be506
Minor fix on tvm_ffi annotations
kurisu6912 Oct 31, 2025
0206356
Merge branch 'frontend-v2' of https://github.com/kurisu6912/tilelang …
kurisu6912 Oct 31, 2025
570be68
add document and examples
kurisu6912 Nov 3, 2025
5c80575
fix lint error
kurisu6912 Nov 3, 2025
f5a160e
Merge branch 'main' into frontend-v2
kurisu6912 Nov 3, 2025
a507ba4
Simplify index calculations in example_chunk_o_bwd.py
LeiWang1999 Nov 3, 2025
e09c1b7
minor fix
LeiWang1999 Nov 3, 2025
7fe3c2d
lint fix
LeiWang1999 Nov 3, 2025
7 changes: 3 additions & 4 deletions examples/gdn/example_chunk_o_bwd.py
@@ -7,8 +7,6 @@
import tilelang.language as T
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401

print(tilelang.__file__)

# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
@@ -256,8 +254,9 @@ def kernel(
# for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV):
- i_k, i_v = i_kv // block_DV, i_kv % block_DV
- dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v]
+ dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv %
+                                   block_DV] * dh_shared[i_kv // block_DV,
+                                                         i_kv % block_DV]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0]

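Editor's note: a quick standalone check, in plain Python with arbitrary example sizes, that the flattened-index arithmetic in this hunk is the usual row-major split, so the removed two-variable form and the added inline form address the same (i_k, i_v) element.

block_DK, block_DV = 4, 8  # arbitrary example sizes
seen = set()
for i_kv in range(block_DK * block_DV):
    i_k, i_v = i_kv // block_DV, i_kv % block_DV  # same decomposition as the kernel loop
    assert i_kv == i_k * block_DV + i_v
    seen.add((i_k, i_v))
assert seen == {(k, v) for k in range(block_DK) for v in range(block_DV)}
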
74 changes: 74 additions & 0 deletions testing/python/jit/test_tilelang_jit_parcompile.py
@@ -0,0 +1,74 @@
import tilelang.testing
import tilelang
import torch


@tilelang.jit(
out_idx=-1, # create the output tensor during runtime
verbose=True,
)
def matmul_kernel_jit(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A=False,
trans_B=True,
in_dtype='float16',
out_dtype='float32',
accum_dtype='float32',
num_stages=2,
threads=128,
):
Comment on lines +6 to +24

🛠️ Refactor suggestion | 🟠 Major

Make the test portable across CUDA/MPS and select the correct backend for Metal.

The test currently hard-codes CUDA and the default execution backend; Metal requires execution_backend="torch". Choose the device and backend at import time and skip the test if neither device is available.

+import pytest
 import tilelang.testing
 import tilelang
 import torch
 
-@tilelang.jit(
-    out_idx=-1,  # create the output tensor during runtime
-    verbose=True,
-)
+# Device/backend selection
+USE_CUDA = torch.cuda.is_available()
+USE_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+DEVICE = "cuda" if USE_CUDA else ("mps" if USE_MPS else None)
+EXEC_BACKEND = "torch" if (USE_MPS and not USE_CUDA) else "cython"
+
+@tilelang.jit(
+    out_idx=-1,  # create the output tensor during runtime
+    verbose=True,
+    execution_backend=EXEC_BACKEND,
+)
 def matmul_kernel_jit(
@@
 def test_par_compile():
+    if DEVICE is None:
+        pytest.skip("No CUDA or MPS device available for JIT test.")
     configs = [
         (1024, 1024, 1024, 128, 128, 32),
         (2048, 2048, 2048, 256, 256, 64),
         (4096, 4096, 4096, 64, 64, 128),
     ]
     kernels = matmul_kernel_jit.par_compile(configs)
     for (M, N, K, _, _, _), kernel in zip(configs, kernels):
-        A = torch.randn(M, K, dtype=torch.float16).cuda()
-        B = torch.randn(N, K, dtype=torch.float16).cuda()
+        A = torch.randn(M, K, dtype=torch.float16, device=DEVICE)
+        B = torch.randn(N, K, dtype=torch.float16, device=DEVICE)
         ref = (A @ B.T).float()
         C = kernel(A, B)
         tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)

Also applies to: 58-71
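
For reference, the proposed device/backend selection as a standalone helper (editor's sketch; the helper names and the "cython" fallback are illustrative assumptions, not part of the PR):

import pytest
import torch

def pick_device_and_backend():
    # Prefer CUDA; fall back to Apple MPS, which needs the torch execution backend.
    if torch.cuda.is_available():
        return "cuda", "cython"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps", "torch"
    return None, None

DEVICE, EXEC_BACKEND = pick_device_and_backend()

def require_device():
    # Skip JIT tests on machines with neither CUDA nor MPS available.
    if DEVICE is None:
        pytest.skip("No CUDA or MPS device available for JIT test.")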

A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

import tilelang.language as T

@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])

return main


def test_par_compile():
configs = [
(1024, 1024, 1024, 128, 128, 32),
(2048, 2048, 2048, 256, 256, 64),
(4096, 4096, 4096, 64, 64, 128),
]
kernels = matmul_kernel_jit.par_compile(configs)
for (M, N, K, _, _, _), kernel in zip(configs, kernels):
A = torch.randn(M, K, dtype=torch.float16).cuda()
B = torch.randn(N, K, dtype=torch.float16).cuda()
ref = (A @ B.T).float()
C = kernel(A, B)
tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
tilelang.testing.main()
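
Editor's note: par_compile takes a list of argument tuples and returns the compiled kernels in the same order. A sketch of the sequential equivalent for the same configs, assuming the usual tilelang.jit behaviour in which calling the decorated function with one config compiles and returns a single kernel:

configs = [
    (1024, 1024, 1024, 128, 128, 32),
    (2048, 2048, 2048, 256, 256, 64),
    (4096, 4096, 4096, 64, 64, 128),
]
kernels_seq = [matmul_kernel_jit(*cfg) for cfg in configs]  # one compile call per config
kernels_par = matmul_kernel_jit.par_compile(configs)  # the parallel API exercised above
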
277 changes: 277 additions & 0 deletions testing/python/language/test_tilelang_language_frontend_v2.py
@@ -0,0 +1,277 @@
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import tvm


def test_argument():

@T.prim_func
def test_argument(
t_1: T.bool,
t_2: T.short,
t_3: T.int,
t_4: T.long,
t_5: T.half,
t_6: T.float,
t_7: T.long,
t_8: T.int8,
t_9: T.int16,
t_10: T.int32,
t_11: T.int64,
t_12: T.uint8,
t_13: T.uint16,
t_14: T.uint32,
t_15: T.uint64,
t_16: T.float8_e4m3fn,
t_17: T.float8_e4m3fnuz,
t_18: T.float8_e5m2,
t_19: T.float8_e5m2fnuz,
t_20: T.float8_e8m0fnu,
t_21: T.float16,
t_22: T.bfloat16,
t_23: T.float32,
t_24: T.float64,
):
pass
Comment on lines +8 to +37

⚠️ Potential issue | 🟡 Minor

Duplicate dtype at parameter positions 4 and 7.

Both t_4 and t_7 are typed as T.long (lines 15 and 18). This appears to be unintentional duplication. Based on the dtype coverage pattern, position 7 should likely use a different dtype (e.g., T.double or T.float64).

Apply this diff to test a distinct dtype at position 7:

     t_4: T.long,
     t_5: T.half,
     t_6: T.float,
-    t_7: T.long,
+    t_7: T.double,
     t_8: T.int8,
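
If the suggested T.double is adopted, a small sanity check along these lines would confirm it is distinct from T.long and matches torch (editor's sketch; it assumes T.double is exported as a float64 alias alongside the other C-style aliases such as T.long and T.float):

import tilelang.language as T
import torch

assert T.double == T.float64  # double is the float64 alias
assert T.double != T.long  # and is distinct from the 64-bit integer alias
assert T.double == torch.double  # torch interop, as exercised by test_torch_eq below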


def test_expr():
from tilelang.language.v2.dtypes import _all_dtypes
errors = []
for name in _all_dtypes:
dtype = getattr(T, name)
assert isinstance(dtype, tvm.DataType), f"{dtype} is not tvm.DataType"
try:
dtype(1.0)
dtype()
except TypeError:
pass
except Exception:
errors.append(name)
assert not errors


# def test_var_decl_sugar():

# @T.prim_func
# def test_var_decl_sugar():
# with T.Kernel(128, 128) as (bx, by):
# var_1: T.bool = 1.0
# var_2: T.short = 1.0
# var_3: T.int = 1.0
# var_4: T.long = 1.0
# var_5: T.half = 1.0
# var_6: T.float = 1.0
# var_7: T.long = 1.0
# var_8: T.int8 = 1.0
# var_9: T.int16 = 1.0
# var_10: T.int32 = 1.0
# var_11: T.int64 = 1.0
# var_12: T.uint8 = 1.0
# var_13: T.uint16 = 1.0
# var_14: T.uint32 = 1.0
# var_15: T.uint64 = 1.0
# var_16: T.float8_e4m3fn = 1.0
# var_17: T.float8_e4m3fnuz = 1.0
# var_18: T.float8_e5m2 = 1.0
# var_19: T.float8_e5m2fnuz = 1.0
# var_20: T.float8_e8m0fnu = 1.0
# var_21: T.float16 = 1.0
# var_22: T.bfloat16 = 1.0
# var_23: T.float32 = 1.0
# var_24: T.float64 = 1.0
# var_1: T.bool = var_1
# var_2: T.short = var_2
# var_3: T.int = var_3
# var_4: T.long = var_4
# var_5: T.half = var_5
# var_6: T.float = var_6
# var_7: T.long = var_7
# var_8: T.int8 = var_8
# var_9: T.int16 = var_9
# var_10: T.int32 = var_10
# var_11: T.int64 = var_11
# var_12: T.uint8 = var_12
# var_13: T.uint16 = var_13
# var_14: T.uint32 = var_14
# var_15: T.uint64 = var_15
# var_16: T.float8_e4m3fn = var_16
# var_17: T.float8_e4m3fnuz = var_17
# var_18: T.float8_e5m2 = var_18
# var_19: T.float8_e5m2fnuz = var_19
# var_20: T.float8_e8m0fnu = var_20
# var_21: T.float16 = var_21
# var_22: T.bfloat16 = var_22
# var_23: T.float32 = var_23
# var_24: T.float64 = var_24

# s = test_var_decl_sugar.script()
# for i in range(1, 25):
# assert f'var_{i}_1' in s
# assert 'tl.local_var_init' in s


def test_dtype_str_repr():

@T.prim_func
def test_str_repr():
buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') # noqa F841
buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') # noqa F841
buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') # noqa F841
buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') # noqa F841
buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') # noqa F841
buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') # noqa F841
buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') # noqa F841
buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') # noqa F841
buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') # noqa F841
buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') # noqa F841
buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') # noqa F841
buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') # noqa F841
buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') # noqa F841
buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') # noqa F841
buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') # noqa F841
buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') # noqa F841
buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') # noqa F841
buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') # noqa F841
buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') # noqa F841
buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') # noqa F841
buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') # noqa F841
buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841


def test_torch_eq():
dtypes = [
T.bool,
T.short,
T.int,
T.long,
T.half,
T.float,
T.long,
T.int8,
T.int16,
T.int32,
T.int64,
T.uint8,
T.uint16,
T.uint32,
T.uint64,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.float8_e8m0fnu,
T.float16,
T.bfloat16,
T.float32,
T.float64,
]
torch_dtypes = [
torch.bool,
torch.short,
torch.int,
torch.long,
torch.half,
torch.float,
torch.long,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]
for a, b in zip(dtypes, torch_dtypes):
assert a == b, f"{a} and {b} are not equal"
Comment on lines +147 to +200

⚠️ Potential issue | 🟡 Minor

Duplicate dtype comparison at indices 3 and 6.

Both T.long and torch.long appear twice in the test lists (lines 151/154 and 177/180). This appears to be unintentional duplication.

Apply this diff to test a different dtype at position 6:

         T.half,
         T.float,
-        T.long,
+        T.double,
         T.int8,
         torch.half,
         torch.float,
-        torch.long,
+        torch.double,
         torch.int8,
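
One way to keep the two lists in lockstep and rule out this kind of accidental duplication is to derive both sides from a single list of attribute names (editor's sketch of a hypothetical refactor, not part of the PR; it assumes every name below, including the suggested double, is exported by both tilelang.language and torch):

names = [
    "bool", "short", "int", "long", "half", "float", "double",
    "int8", "int16", "int32", "int64",
    "uint8", "uint16", "uint32", "uint64",
    "float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz", "float8_e8m0fnu",
    "float16", "bfloat16", "float32", "float64",
]
for name in names:
    tl_dtype, torch_dtype = getattr(T, name), getattr(torch, name)
    assert tl_dtype == torch_dtype, f"{name}: {tl_dtype} and {torch_dtype} are not equal"
    assert T.dtype(torch_dtype) == tl_dtype, f"dtype conversion error for {name}"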

assert T.dtype(b) == a, "dtype conversion error"


def test_var_assign():

@tilelang.jit(out_idx=-1)
@T.prim_func
def test_var_assign(A: T.Tensor((2,), T.int32)):
with T.Kernel(1) as _:
a: T.int32 = 1
b: T.int32 = a
a = 2
d: T.int32 = a
A[0] = b
A[1] = d

res = test_var_assign()()
assert res[0] == 1
assert res[1] == 2


def test_marco_return():

⚠️ Potential issue | 🟡 Minor

Fix typo in function name.

The function name test_marco_return contains a typo; it should be test_macro_return to accurately reflect that it tests macro functionality.

Apply this diff to fix the typo:

-def test_marco_return():
+def test_macro_return():


@T.macro
def macro_return_constant():
return 0

@T.macro
def macro_return_frame(x):
return T.alloc_var(T.float32, init=x)

@T.macro
def macro_return_expr(x):
y = x + 1.0
return y

@T.macro
def macro_apply_func(x, fn):
return fn(x)

def check(x, ty):
assert isinstance(x, ty)

@T.prim_func
def test_macro_return():
with T.Kernel(1) as _:
a = macro_return_constant()
b = macro_return_frame(3.0)
c = macro_return_expr(4.0)
d = macro_apply_func(5.0, lambda x: x * 2.0)
check(a, (int, float, T.PrimExpr))
check(b, T.PrimExpr)
check(c, T.PrimExpr)
check(d, T.PrimExpr)


def test_prim_func_generator():

@T.prim_func(generator=True)
def prim_func_gen(
A=T.Tensor((128,), T.float32), # noqa: B008
B=T.Tensor((128,), T.float32), # noqa: B008
):
with T.Kernel(128) as (tx,):
T.copy(A[tx], B[tx])

prim_func_gen()

@T.prim_func
def foo() -> T.Tensor((128,), T.float32):
pass

assert isinstance(foo, T.PrimFunc)


if __name__ == '__main__':
tilelang.testing.main()
2 changes: 1 addition & 1 deletion testing/python/language/test_tilelang_language_let.py
@@ -11,7 +11,7 @@ def main(A_ptr: T.handle):

for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
- b: T.float32x4 = A[0, 0:4]
+ b = A[0, 0:4]
A[0, 4:8] = b

mod = tvm.IRModule({"main": main})
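
Editorial note on this one-line change: under the frontend-v2 sugar introduced in this PR, an annotated assignment such as b: T.float32x4 = ... inside a kernel appears to declare a local variable (the commented-out test_var_decl_sugar above expects tl.local_var_init in the printed script), so the un-annotated form presumably preserves the plain binding that this let test checks. A minimal illustration of the annotated form, adapted from test_var_assign (hypothetical standalone sketch, not part of the PR):

import tilelang.language as T

@T.prim_func
def annotated_locals():
    with T.Kernel(1) as _:
        x: T.float32 = 1.0  # annotated assignment declares a typed local variable
        y: T.int32 = 2  # each annotation introduces its own local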