Skip to content

Commit feb4ca4

Browse files
committed
Refactor kernels to use premake instead of make
1 parent 231c8ea commit feb4ca4

35 files changed

+339
-201
lines changed

src/ntops/kernels/abs.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

@@ -11,6 +10,9 @@ def application(input, output):
1110
output = ntl.abs(input) # noqa: F841
1211

1312

14-
@functools.cache
15-
def make(ndim):
16-
return ninetoothed.make(arrangement, application, (Tensor(ndim), Tensor(ndim)))
13+
def premake(ndim, dtype=None, block_size=None):
14+
arrangement_ = functools.partial(arrangement, block_size=block_size)
15+
16+
tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
17+
18+
return arrangement_, application, tensors

src/ntops/kernels/add.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
from ninetoothed import Tensor
54

65
from ntops.kernels.element_wise import arrangement
@@ -10,8 +9,14 @@ def application(input, other, alpha, output):
109
output = input + alpha * other # noqa: F841
1110

1211

13-
@functools.cache
14-
def make(ndim):
15-
tensors = (Tensor(ndim), Tensor(ndim), Tensor(0), Tensor(ndim))
12+
def premake(ndim, dtype=None, block_size=None):
13+
arrangement_ = functools.partial(arrangement, block_size=block_size)
1614

17-
return ninetoothed.make(arrangement, application, tensors)
15+
tensors = (
16+
Tensor(ndim, dtype=dtype),
17+
Tensor(ndim, dtype=dtype),
18+
Tensor(0, dtype=dtype),
19+
Tensor(ndim, dtype=dtype),
20+
)
21+
22+
return arrangement_, application, tensors

src/ntops/kernels/addmm.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

@@ -47,8 +46,21 @@ def application(input, x, y, beta, alpha, output):
4746
output = beta * input + alpha * mm_output
4847

4948

50-
@functools.cache
51-
def make():
52-
tensors = (Tensor(2), Tensor(2), Tensor(2), Tensor(0), Tensor(0), Tensor(2))
49+
def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None):
50+
arrangement_ = functools.partial(
51+
arrangement,
52+
block_size_m=block_size_m,
53+
block_size_n=block_size_n,
54+
block_size_k=block_size_k,
55+
)
56+
57+
tensors = (
58+
Tensor(2, dtype=dtype),
59+
Tensor(2, dtype=dtype),
60+
Tensor(2, dtype=dtype),
61+
Tensor(0, dtype=dtype),
62+
Tensor(0, dtype=dtype),
63+
Tensor(2, dtype=dtype),
64+
)
5365

54-
return ninetoothed.make(arrangement, application, tensors)
66+
return arrangement_, application, tensors

src/ntops/kernels/bitwise_and.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
from ninetoothed import Tensor
54

65
from ntops.kernels.element_wise import arrangement
@@ -10,8 +9,13 @@ def application(input, other, output):
109
output = input & other # noqa: F841
1110

1211

13-
@functools.cache
14-
def make(ndim):
15-
tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim))
12+
def premake(ndim, dtype=None, block_size=None):
13+
arrangement_ = functools.partial(arrangement, block_size=block_size)
1614

17-
return ninetoothed.make(arrangement, application, tensors)
15+
tensors = (
16+
Tensor(ndim, dtype=dtype),
17+
Tensor(ndim, dtype=dtype),
18+
Tensor(ndim, dtype=dtype),
19+
)
20+
21+
return arrangement_, application, tensors

src/ntops/kernels/bitwise_not.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

@@ -15,10 +14,11 @@ def logical_application(input, output):
1514
output = ntl.where(input, False, True) # noqa: F841
1615

1716

18-
@functools.cache
19-
def make(ndim, logical=False):
20-
tensors = (Tensor(ndim), Tensor(ndim))
17+
def premake(ndim, logical=False, dtype=None, block_size=None):
18+
arrangement_ = functools.partial(arrangement, block_size=block_size)
2119

2220
application = logical_application if logical else bitwise_application
2321

24-
return ninetoothed.make(arrangement, application, tensors)
22+
tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
23+
24+
return arrangement_, application, tensors

src/ntops/kernels/bitwise_or.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
from ninetoothed import Tensor
54

65
from ntops.kernels.element_wise import arrangement
@@ -10,8 +9,13 @@ def application(input, other, output):
109
output = input | other # noqa: F841
1110

1211

13-
@functools.cache
14-
def make(ndim):
15-
tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim))
12+
def premake(ndim, dtype=None, block_size=None):
13+
arrangement_ = functools.partial(arrangement, block_size=block_size)
1614

17-
return ninetoothed.make(arrangement, application, tensors)
15+
tensors = (
16+
Tensor(ndim, dtype=dtype),
17+
Tensor(ndim, dtype=dtype),
18+
Tensor(ndim, dtype=dtype),
19+
)
20+
21+
return arrangement_, application, tensors

src/ntops/kernels/bmm.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
from ninetoothed import Tensor
54

65
from ntops.kernels.mm import BLOCK_SIZE_K, BLOCK_SIZE_M, BLOCK_SIZE_N, application
@@ -36,6 +35,14 @@ def arrangement(
3635
return input_arranged, other_arranged, output_arranged
3736

3837

39-
@functools.cache
40-
def make():
41-
return ninetoothed.make(arrangement, application, (Tensor(3), Tensor(3), Tensor(3)))
38+
def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None):
39+
arrangement_ = functools.partial(
40+
arrangement,
41+
block_size_m=block_size_m,
42+
block_size_n=block_size_n,
43+
block_size_k=block_size_k,
44+
)
45+
46+
tensors = (Tensor(3, dtype=dtype), Tensor(3, dtype=dtype), Tensor(3, dtype=dtype))
47+
48+
return arrangement_, application, tensors

src/ntops/kernels/clamp.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

@@ -11,8 +10,14 @@ def application(input, min_val, max_val, output):
1110
output = ntl.clamp(input, min_val, max_val) # noqa: F841
1211

1312

14-
@functools.cache
15-
def make(ndim):
16-
tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim), Tensor(ndim))
13+
def premake(ndim, dtype=None, block_size=None):
14+
arrangement_ = functools.partial(arrangement, block_size=block_size)
1715

18-
return ninetoothed.make(arrangement, application, tensors)
16+
tensors = (
17+
Tensor(ndim, dtype=dtype),
18+
Tensor(ndim, dtype=dtype),
19+
Tensor(ndim, dtype=dtype),
20+
Tensor(ndim, dtype=dtype),
21+
)
22+
23+
return arrangement_, application, tensors

src/ntops/kernels/cos.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

@@ -11,6 +10,9 @@ def application(input, output):
1110
output = ntl.cos(input) # noqa: F841
1211

1312

14-
@functools.cache
15-
def make(ndim):
16-
return ninetoothed.make(arrangement, application, (Tensor(ndim), Tensor(ndim)))
13+
def premake(ndim, dtype=None, block_size=None):
14+
arrangement_ = functools.partial(arrangement, block_size=block_size)
15+
16+
tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
17+
18+
return arrangement_, application, tensors

src/ntops/kernels/div.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

@@ -19,15 +18,20 @@ def floor_application(input, other, output):
1918
output = ntl.floor(input / other) # noqa: F841
2019

2120

22-
@functools.cache
23-
def make(ndim, rounding_mode):
21+
def premake(ndim, rounding_mode, dtype=None, block_size=None):
22+
arrangement_ = functools.partial(arrangement, block_size=block_size)
23+
2424
if rounding_mode == "trunc":
2525
application = trunc_application
2626
elif rounding_mode == "floor":
2727
application = floor_application
2828
else:
2929
application = default_application
3030

31-
tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim))
31+
tensors = (
32+
Tensor(ndim, dtype=dtype),
33+
Tensor(ndim, dtype=dtype),
34+
Tensor(ndim, dtype=dtype),
35+
)
3236

33-
return ninetoothed.make(arrangement, application, tensors)
37+
return arrangement_, application, tensors

0 commit comments

Comments
 (0)