-
Notifications
You must be signed in to change notification settings - Fork 314
[Language] Initial version of tilelang frontend v2 #1120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 18 commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
b5f36ad
tilelang frontend v2
kurisu6912 9d6659c
syntax sugar: defining a local var by annotation
kurisu6912 f887fb4
Merge branch 'main' into frontend-v2
kurisu6912 09d8aec
[Refactor] fix type linting warning like `T.float32`
kurisu6912 4c75e85
Add tl.local_var_init for new tl.float32
kurisu6912 8dce258
allow passing default argument as function annotation
kurisu6912 e3815c6
allow default arguments as annotation
kurisu6912 66ee7b2
fix lint error
kurisu6912 7745736
minor fix
kurisu6912 65bd4fc
[Refactor] refactor tilelang.jit and tilelang.autotune
kurisu6912 d592fbf
minor fix
kurisu6912 eff7916
minor fix
kurisu6912 6f69f02
minor fix
kurisu6912 20feef2
fix metal get function name
kurisu6912 015416b
add par_compile impl and tests
kurisu6912 61bfbdd
Merge branch 'refactor-jit-autotune' into frontend-v2
kurisu6912 a7e2027
Type consistency on tvm datatype
kurisu6912 4d0bc85
fix lint error
kurisu6912 0dfe4e3
add more warning in frontend
kurisu6912 f8a6f32
Merge branch 'main' of https://github.com/tile-ai/tilelang into front…
LeiWang1999 b29da36
update tvm version
kurisu6912 6eef76c
Merge branch 'main' into frontend-v2
kurisu6912 f1be506
Minor fix on tvm_ffi annotations
kurisu6912 0206356
Merge branch 'frontend-v2' of https://github.com/kurisu6912/tilelang …
kurisu6912 570be68
add document and examples
kurisu6912 5c80575
fix lint error
kurisu6912 f5a160e
Merge branch 'main' into frontend-v2
kurisu6912 a507ba4
Simplify index calculations in example_chunk_o_bwd.py
LeiWang1999 e09c1b7
minor fix
LeiWang1999 7fe3c2d
lint fix
LeiWang1999 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Submodule tvm
updated
from 5bf17a to 9cda9b
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| import tilelang.testing | ||
| import tilelang | ||
| import torch | ||
|
|
||
|
|
||
| @tilelang.jit( | ||
| out_idx=-1, # create the output tensor during runtime | ||
| verbose=True, | ||
| ) | ||
| def matmul_kernel_jit( | ||
| M, | ||
| N, | ||
| K, | ||
| block_M, | ||
| block_N, | ||
| block_K, | ||
| trans_A=False, | ||
| trans_B=True, | ||
| in_dtype='float16', | ||
| out_dtype='float32', | ||
| accum_dtype='float32', | ||
| num_stages=2, | ||
| threads=128, | ||
| ): | ||
| A_shape = (K, M) if trans_A else (M, K) | ||
| B_shape = (N, K) if trans_B else (K, N) | ||
| A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) | ||
| B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) | ||
|
|
||
| import tilelang.language as T | ||
|
|
||
| @T.prim_func | ||
| def main( | ||
| A: T.Tensor(A_shape, in_dtype), | ||
| B: T.Tensor(B_shape, in_dtype), | ||
| C: T.Tensor((M, N), out_dtype), | ||
| ): | ||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): | ||
| A_shared = T.alloc_shared(A_shared_shape, in_dtype) | ||
| B_shared = T.alloc_shared(B_shared_shape, in_dtype) | ||
| C_local = T.alloc_fragment((block_M, block_N), accum_dtype) | ||
| T.clear(C_local) | ||
| for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): | ||
| if trans_A: | ||
| T.copy(A[k * block_K, by * block_M], A_shared) | ||
| else: | ||
| T.copy(A[by * block_M, k * block_K], A_shared) | ||
| if trans_B: | ||
| T.copy(B[bx * block_N, k * block_K], B_shared) | ||
| else: | ||
| T.copy(B[k * block_K, bx * block_N], B_shared) | ||
| T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) | ||
| T.copy(C_local, C[by * block_M, bx * block_N]) | ||
|
|
||
| return main | ||
|
|
||
|
|
||
| def test_par_compile(): | ||
| configs = [ | ||
| (1024, 1024, 1024, 128, 128, 32), | ||
| (2048, 2048, 2048, 256, 256, 64), | ||
| (4096, 4096, 4096, 64, 64, 128), | ||
| ] | ||
| kernels = matmul_kernel_jit.par_compile(configs) | ||
| for (M, N, K, _, _, _), kernel in zip(configs, kernels): | ||
| A = torch.randn(M, K, dtype=torch.float16).cuda() | ||
| B = torch.randn(N, K, dtype=torch.float16).cuda() | ||
| ref = (A @ B.T).float() | ||
| C = kernel(A, B) | ||
| tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tilelang.testing.main() | ||
File renamed without changes.
222 changes: 222 additions & 0 deletions
222
testing/python/language/test_tilelang_language_dtype.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,222 @@ | ||
| import tilelang | ||
| import tilelang.language as T | ||
| import torch | ||
| import tilelang.testing | ||
| import tvm | ||
|
|
||
|
|
||
| def test_argument(): | ||
|
|
||
| @T.prim_func | ||
| def test_argument( | ||
| t_1: T.bool, | ||
| t_2: T.short, | ||
| t_3: T.int, | ||
| t_4: T.long, | ||
| t_5: T.half, | ||
| t_6: T.float, | ||
| t_7: T.long, | ||
| t_8: T.int8, | ||
| t_9: T.int16, | ||
| t_10: T.int32, | ||
| t_11: T.int64, | ||
| t_12: T.uint8, | ||
| t_13: T.uint16, | ||
| t_14: T.uint32, | ||
| t_15: T.uint64, | ||
| t_16: T.float8_e4m3fn, | ||
| t_17: T.float8_e4m3fnuz, | ||
| t_18: T.float8_e5m2, | ||
| t_19: T.float8_e5m2fnuz, | ||
| t_20: T.float8_e8m0fnu, | ||
| t_21: T.float16, | ||
| t_22: T.bfloat16, | ||
| t_23: T.float32, | ||
| t_24: T.float64, | ||
| ): | ||
| pass | ||
|
|
||
|
|
||
| def test_expr(): | ||
| from tilelang.language.v2.dtypes import _all_dtypes | ||
| errors = [] | ||
| for name in _all_dtypes: | ||
| dtype = getattr(T, name) | ||
| assert isinstance(dtype, tvm.DataType), f"{dtype} is not tvm.DataType" | ||
| try: | ||
| dtype(1.0) | ||
| dtype() | ||
| except TypeError: | ||
| pass | ||
| except Exception: | ||
| errors.append(name) | ||
| assert not errors | ||
|
|
||
|
|
||
| def test_var_decl_sugar(): | ||
|
|
||
| @T.prim_func | ||
| def test_var_decl_sugar(): | ||
| with T.Kernel(128, 128) as (bx, by): | ||
| var_1: T.bool = 1.0 | ||
| var_2: T.short = 1.0 | ||
| var_3: T.int = 1.0 | ||
| var_4: T.long = 1.0 | ||
| var_5: T.half = 1.0 | ||
| var_6: T.float = 1.0 | ||
| var_7: T.long = 1.0 | ||
| var_8: T.int8 = 1.0 | ||
| var_9: T.int16 = 1.0 | ||
| var_10: T.int32 = 1.0 | ||
| var_11: T.int64 = 1.0 | ||
| var_12: T.uint8 = 1.0 | ||
| var_13: T.uint16 = 1.0 | ||
| var_14: T.uint32 = 1.0 | ||
| var_15: T.uint64 = 1.0 | ||
| var_16: T.float8_e4m3fn = 1.0 | ||
| var_17: T.float8_e4m3fnuz = 1.0 | ||
| var_18: T.float8_e5m2 = 1.0 | ||
| var_19: T.float8_e5m2fnuz = 1.0 | ||
| var_20: T.float8_e8m0fnu = 1.0 | ||
| var_21: T.float16 = 1.0 | ||
| var_22: T.bfloat16 = 1.0 | ||
| var_23: T.float32 = 1.0 | ||
| var_24: T.float64 = 1.0 | ||
| var_1: T.bool = var_1 | ||
| var_2: T.short = var_2 | ||
| var_3: T.int = var_3 | ||
| var_4: T.long = var_4 | ||
| var_5: T.half = var_5 | ||
| var_6: T.float = var_6 | ||
| var_7: T.long = var_7 | ||
| var_8: T.int8 = var_8 | ||
| var_9: T.int16 = var_9 | ||
| var_10: T.int32 = var_10 | ||
| var_11: T.int64 = var_11 | ||
| var_12: T.uint8 = var_12 | ||
| var_13: T.uint16 = var_13 | ||
| var_14: T.uint32 = var_14 | ||
| var_15: T.uint64 = var_15 | ||
| var_16: T.float8_e4m3fn = var_16 | ||
| var_17: T.float8_e4m3fnuz = var_17 | ||
| var_18: T.float8_e5m2 = var_18 | ||
| var_19: T.float8_e5m2fnuz = var_19 | ||
| var_20: T.float8_e8m0fnu = var_20 | ||
| var_21: T.float16 = var_21 | ||
| var_22: T.bfloat16 = var_22 | ||
| var_23: T.float32 = var_23 | ||
| var_24: T.float64 = var_24 | ||
|
|
||
| s = test_var_decl_sugar.script() | ||
| for i in range(1, 25): | ||
| assert f'var_{i}_1' in s | ||
| assert 'tl.local_var_init' in s | ||
|
|
||
|
|
||
| def test_dtype_str_repr(): | ||
|
|
||
| @T.prim_func | ||
| def test_str_repr(): | ||
| buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') # noqa F841 | ||
| buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') # noqa F841 | ||
| buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') # noqa F841 | ||
| buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841 | ||
| buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') # noqa F841 | ||
| buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') # noqa F841 | ||
| buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841 | ||
| buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') # noqa F841 | ||
| buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') # noqa F841 | ||
| buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') # noqa F841 | ||
| buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') # noqa F841 | ||
| buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') # noqa F841 | ||
| buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') # noqa F841 | ||
| buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') # noqa F841 | ||
| buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') # noqa F841 | ||
| buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') # noqa F841 | ||
| buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') # noqa F841 | ||
| buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') # noqa F841 | ||
| buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') # noqa F841 | ||
| buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') # noqa F841 | ||
| buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') # noqa F841 | ||
| buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') # noqa F841 | ||
| buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') # noqa F841 | ||
| buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841 | ||
|
|
||
|
|
||
| def test_torch_eq(): | ||
| dtypes = [ | ||
| T.bool, | ||
| T.short, | ||
| T.int, | ||
| T.long, | ||
| T.half, | ||
| T.float, | ||
| T.long, | ||
| T.int8, | ||
| T.int16, | ||
| T.int32, | ||
| T.int64, | ||
| T.uint8, | ||
| T.uint16, | ||
| T.uint32, | ||
| T.uint64, | ||
| T.float8_e4m3fn, | ||
| T.float8_e4m3fnuz, | ||
| T.float8_e5m2, | ||
| T.float8_e5m2fnuz, | ||
| T.float8_e8m0fnu, | ||
| T.float16, | ||
| T.bfloat16, | ||
| T.float32, | ||
| T.float64, | ||
| ] | ||
| torch_dtypes = [ | ||
| torch.bool, | ||
| torch.short, | ||
| torch.int, | ||
| torch.long, | ||
| torch.half, | ||
| torch.float, | ||
| torch.long, | ||
| torch.int8, | ||
| torch.int16, | ||
| torch.int32, | ||
| torch.int64, | ||
| torch.uint8, | ||
| torch.uint16, | ||
| torch.uint32, | ||
| torch.uint64, | ||
| torch.float8_e4m3fn, | ||
| torch.float8_e4m3fnuz, | ||
| torch.float8_e5m2, | ||
| torch.float8_e5m2fnuz, | ||
| torch.float8_e8m0fnu, | ||
| torch.float16, | ||
| torch.bfloat16, | ||
| torch.float32, | ||
| torch.float64, | ||
| ] | ||
| for a, b in zip(dtypes, torch_dtypes): | ||
| assert a == b, f"{a} and {b} are not equal" | ||
|
|
||
|
|
||
| def test_var_assign(): | ||
|
|
||
| @tilelang.jit(out_idx=-1) | ||
| @T.prim_func | ||
| def test_var_assign(A: T.Tensor((2,), T.int32)): | ||
| with T.Kernel(1) as _: | ||
| a: T.int32 = 1 | ||
| b: T.int32 = a | ||
| a = 2 | ||
| d: T.int32 = a | ||
| A[0] = b | ||
| A[1] = d | ||
|
|
||
| res = test_var_assign()() | ||
| assert res[0] == 1 | ||
| assert res[1] == 2 | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| tilelang.testing.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major
Make test portable across CUDA/MPS and select correct backend for metal.
Currently hard-codes CUDA and default backend; metal requires execution_backend="torch". Choose device/backend at import time and skip if none.
Also applies to: 58-71