Skip to content

Commit 3590d8e

Browse files
committed
Add test cases for ninetoothed.build's auto-tuning support
1 parent e7c1863 commit 3590d8e

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed

tests/test_aot_auto_tuning.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import functools
2+
import shutil
3+
4+
import pytest
5+
import torch
6+
7+
import ninetoothed
8+
import ninetoothed.generation
9+
from ninetoothed import Tensor
10+
from tests.utils import get_available_devices
11+
12+
13+
def arrangement(input, other, alpha, output, block_size=None):
    """Tile the 1-D operands for element-wise add; ``alpha`` passes through untiled."""
    # Fall back to a symbolic, auto-tunable block size when none is supplied.
    block_size = ninetoothed.block_size() if block_size is None else block_size

    tile_shape = (block_size,)

    return (
        input.tile(tile_shape),
        other.tile(tile_shape),
        alpha,
        output.tile(tile_shape),
    )
23+
24+
25+
def application(input, other, alpha, output):
    # Element-wise kernel body: output <- input + alpha * other.
    # NOTE(review): the assignment looks dead (hence the `noqa: F841`), but
    # ninetoothed presumably compiles this function from its AST, where
    # rebinding the `output` parameter denotes a store into the output tile —
    # confirm before "simplifying" this away.
    output = input + alpha * other  # noqa: F841
27+
28+
29+
def premake(size=None, dtype=None, block_size=None):
    """Assemble the (arrangement, application, tensors) triple for an add kernel."""

    def _vector():
        # Each operand gets its own distinct 1-D symbolic tensor of length `size`.
        return Tensor(shape=(size,), dtype=dtype)

    tensors = (
        _vector(),
        _vector(),
        Tensor(0, dtype=ninetoothed.float64),  # 0-dim scalar `alpha`
        _vector(),
    )

    # Bind the (possibly symbolic) block size into the arrangement function.
    return functools.partial(arrangement, block_size=block_size), application, tensors
40+
41+
42+
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize(
    "dtype, ninetoothed_dtype, rtol, atol",
    (
        (torch.float32, ninetoothed.float32, 1e-5, 1e-5),
        (torch.float16, ninetoothed.float16, 1e-3, 1e-3),
    ),
)
@pytest.mark.parametrize("size", (20260128, 1127))
def test_auto_tuning(size, dtype, device, ninetoothed_dtype, rtol, atol):
    """Build an auto-tuned AOT `add` kernel and check it against `torch.add`.

    Builds the kernel from `premake` with several candidate configs, launches
    it on random inputs, and compares the result with the PyTorch reference
    within the dtype-appropriate tolerances.
    """
    caller = device
    kernel_name = "add"
    output_dir = ninetoothed.generation.CACHE_DIR / "test_auto_tuning"

    # `parents=True, exist_ok=True`: CACHE_DIR may not exist yet, and a
    # previously interrupted run may have left the directory behind —
    # a bare `mkdir()` would then fail every subsequent run.
    output_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Each config is (args, kwargs for `premake`, compilation options);
        # multiple block sizes (and warp/stage counts) per (size, dtype)
        # give the builder genuine candidates to tune over.
        configs = (
            ((), {"size": 20260128, "dtype": ninetoothed.float16, "block_size": 256}, {}),
            ((), {"size": 20260128, "dtype": ninetoothed.float16, "block_size": 1024}, {}),
            ((), {"size": 20260128, "dtype": ninetoothed.float32, "block_size": 512}, {}),
            ((), {"size": 20260128, "dtype": ninetoothed.float32, "block_size": 1024}, {}),
            (
                (),
                {"size": 1127, "dtype": ninetoothed.float16, "block_size": 64},
                {"num_warps": 4},
            ),
            (
                (),
                {"size": 1127, "dtype": ninetoothed.float16, "block_size": 64},
                {"num_warps": 8},
            ),
            (
                (),
                {"size": 1127, "dtype": ninetoothed.float16, "block_size": 256},
                {"num_warps": 4, "num_stages": 1},
            ),
            (
                (),
                {"size": 1127, "dtype": ninetoothed.float16, "block_size": 256},
                {"num_warps": 8, "num_stages": 1},
            ),
            ((), {"size": 1127, "dtype": ninetoothed.float32, "block_size": 512}, {}),
        )

        kernel = ninetoothed.build(
            premake,
            configs,
            meta_parameters=("block_size",),
            caller=caller,
            kernel_name=kernel_name,
            output_dir=output_dir,
        )

        input = torch.randn((size,), dtype=dtype, device=device)
        other = torch.randn((size,), dtype=dtype, device=device)
        alpha = torch.randn((), dtype=torch.float64)
        output = torch.empty_like(input)

        kernel(input, other, alpha, output, size, ninetoothed_dtype)
    finally:
        # Always remove generated artifacts, even when the build or launch
        # fails — otherwise leftovers break the next run's setup.
        shutil.rmtree(output_dir, ignore_errors=True)

    expected = torch.add(input, other, alpha=alpha)

    assert torch.allclose(output, expected, rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)