Skip to content

Commit 5381917

Browse files
committed
Add relu.premake
1 parent fe097d3 commit 5381917

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/ntops/kernels/relu.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@ def application(input, output):
1010
output = max(0.0, input) # noqa: F841
1111

1212

13-
@functools.cache
14-
def make(ndim):
15-
tensors = (Tensor(ndim), Tensor(ndim))
13+
def premake(ndim, dtype, block_size):
14+
tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
15+
16+
return functools.partial(arrangement, block_size=block_size), application, tensors
1617

17-
return ninetoothed.make(arrangement, application, tensors)
18+
19+
@functools.cache
20+
def make(ndim, dtype=None, block_size=None):
21+
return ninetoothed.make(*premake(ndim, dtype, block_size))

0 commit comments

Comments
 (0)