Skip to content

Commit f875ec3

Browse files
committed
Add demo_tilelang_to_ninetoothed.py
1 parent d9b4a82 commit f875ec3

1 file changed

Lines changed: 186 additions & 0 deletions

File tree

demo_tilelang_to_ninetoothed.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import torch
2+
import triton
3+
4+
import ops.ninetoothed.kernels.mm
5+
import ops.tilelang.kernels.mm
6+
import ops.triton.kernels.mm
7+
import tilelang_to_ninetoothed
8+
9+
# Tile extents handed to ops.tilelang.kernels.mm.mm below.
BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 32

# Kernel objects exported by the hand-written NineToothed and Triton modules.
ninetoothed_mm_kernel = ops.ninetoothed.kernels.mm.kernel
triton_mm_kernel = ops.triton.kernels.mm.kernel

# TileLang kernel obtained by calling the module's ``mm`` with the module's own
# M/N/K values followed by the tile extents defined above.
tilelang_mm_kernel = ops.tilelang.kernels.mm.mm(
    ops.tilelang.kernels.mm.M,
    ops.tilelang.kernels.mm.N,
    ops.tilelang.kernels.mm.K,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    BLOCK_SIZE_K,
)

# Result of running the TileLang -> NineToothed transform on the TileLang
# ``mm`` callable itself (not on the instantiated ``tilelang_mm_kernel``).
ninetoothed_mm_kernel_from_tilelang = (
    tilelang_to_ninetoothed.transform_tilelang_to_ninetoothed(
        ops.tilelang.kernels.mm.mm
    )
)
31+
32+
33+
def ninetoothed_mm(input, other):
    """Run the hand-written NineToothed matmul kernel on ``input`` and ``other``.

    A new tensor with ``input.shape[0]`` rows and ``other.shape[1]`` columns
    is allocated using ``input``'s dtype and device, passed to
    ``ninetoothed_mm_kernel`` as its third argument, and returned.
    """
    rows, cols = input.shape[0], other.shape[1]
    result = torch.empty((rows, cols), dtype=input.dtype, device=input.device)
    ninetoothed_mm_kernel(input, other, result)
    return result
40+
41+
42+
def triton_mm(input, other):
    """Run the hand-written Triton matmul kernel on ``input`` and ``other``.

    The kernel receives the two operands, a freshly allocated output tensor,
    the three problem extents, and the first two strides of each tensor.
    The allocated output tensor is returned.
    """
    m, k = input.shape[0], input.shape[1]
    n = other.shape[1]
    result = torch.empty((m, n), dtype=input.dtype, device=input.device)

    def grid(meta):
        # One-dimensional launch: one program per output tile, with the tile
        # extents read from the launch metadata.
        tiles_along_m = triton.cdiv(m, meta["BLOCK_SIZE_M"])
        tiles_along_n = triton.cdiv(n, meta["BLOCK_SIZE_N"])
        return (tiles_along_m * tiles_along_n,)

    launcher = triton_mm_kernel[grid]
    launcher(
        input,
        other,
        result,
        m,
        n,
        k,
        input.stride(0),
        input.stride(1),
        other.stride(0),
        other.stride(1),
        result.stride(0),
        result.stride(1),
    )

    return result
68+
69+
70+
def tilelang_mm(input, other):
    """Run the module-level TileLang matmul kernel on ``input`` and ``other``.

    Allocates an ``(input.shape[0], other.shape[1])`` tensor matching
    ``input``'s dtype and device, passes it to ``tilelang_mm_kernel`` after
    the two operands, and returns it.
    """
    result = torch.empty(
        (input.shape[0], other.shape[1]),
        dtype=input.dtype,
        device=input.device,
    )
    tilelang_mm_kernel(input, other, result)
    return result
77+
78+
79+
def ninetoothed_from_tilelang_mm(input, other):
    """Run the NineToothed kernel that was converted from the TileLang source.

    Both operands must be exactly two-dimensional, because their shapes are
    unpacked into two values each. The converted kernel is called with the
    operands, a freshly allocated ``(M, N)`` output, the problem extents as
    ``M``/``N``/``K`` keywords, and tile extents fixed at 64 in every
    dimension. The output tensor is returned.
    """
    rows, inner = input.shape
    _, cols = other.shape

    result = torch.empty((rows, cols), dtype=input.dtype, device=input.device)

    # Problem extents plus the hard-coded 64x64x64 tiling used for this path.
    launch_kwargs = {
        "M": rows,
        "N": cols,
        "K": inner,
        "block_M": 64,
        "block_N": 64,
        "block_K": 64,
    }
    ninetoothed_mm_kernel_from_tilelang(input, other, result, **launch_kwargs)

    return result
90+
91+
92+
def torch_mm(input, other):
    """Return ``torch.mm(input, other)``; the PyTorch baseline for the demo."""
    product = torch.mm(input, other)
    return product
94+
95+
96+
if __name__ == "__main__":
    # Script entry point: run every matmul implementation once on the same
    # random operands, print and compare the results against the NineToothed
    # output, then benchmark all implementations over a range of sizes.

    # Fixed seed so the random operands are the same on every run.
    torch.manual_seed(0)

    shape = (512, 512)
    dtype = torch.float16
    device = "cuda"

    input = torch.randn(shape, dtype=dtype, device=device)
    other = torch.randn(shape, dtype=dtype, device=device)

    # Single source of truth for every implementation under test:
    # (provider id, legend name, plot style, callable). The first entry is the
    # reference that all the others are compared against.
    providers = [
        ("ninetoothed", "NineToothed", ("blue", "-"), ninetoothed_mm),
        ("torch", "PyTorch", ("green", "-"), torch_mm),
        ("triton", "Triton", ("orange", "-"), triton_mm),
        ("tilelang", "TileLang", ("cyan", "-"), tilelang_mm),
        (
            "ninetoothed_from_tilelang",
            "NineToothed from TileLang",
            ("purple", "-"),
            ninetoothed_from_tilelang_mm,
        ),
    ]
    implementations = {key: fn for key, _, _, fn in providers}

    # Run each implementation once, in table order, and show its result.
    outputs = {key: fn(input, other) for key, _, _, fn in providers}
    for result in outputs.values():
        print(result)

    # The Triton result must be identical to the reference (zero tolerance);
    # every other comparison uses torch.allclose's default tolerances.
    tolerance_overrides = {"triton": {"atol": 0, "rtol": 0}}
    reference = outputs["ninetoothed"]
    for key, label, _, _ in providers[1:]:
        tolerances = tolerance_overrides.get(key, {})
        if torch.allclose(reference, outputs[key], **tolerances):
            print(f"✅ NineToothed and {label} match.")
        else:
            print(f"❌ NineToothed and {label} differ.")

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m", "n", "k"],
            x_vals=[2**i for i in range(8, 13)],
            x_log=True,
            line_arg="provider",
            line_vals=[key for key, _, _, _ in providers],
            line_names=[label for _, label, _, _ in providers],
            styles=[style for _, _, style, _ in providers],
            ylabel="ms",
            plot_name="mm-performance",
            args={},
        )
    )
    def benchmark(m, n, k, provider):
        """Time one implementation on random ``(m, k) @ (k, n)`` operands.

        Returns the value produced by ``triton.testing.do_bench``.

        Raises:
            ValueError: if ``provider`` is not one of the ids in ``providers``.
        """
        # Fail loudly on an unknown id instead of falling through every
        # branch and hitting an unbound result variable.
        try:
            run = implementations[provider]
        except KeyError:
            raise ValueError(f"unknown provider: {provider!r}") from None

        lhs = torch.randn((m, k), dtype=dtype, device=device)
        rhs = torch.randn((k, n), dtype=dtype, device=device)

        return triton.testing.do_bench(lambda: run(lhs, rhs))

    benchmark.run(show_plots=True, print_data=True, save_path=".")

0 commit comments

Comments
 (0)