Skip to content

Commit e90795e

Browse files
committed
Add pow operator
1 parent 71f6897 commit e90795e

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

src/ntops/kernels/pow.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import functools
2+
3+
import ninetoothed
4+
from ninetoothed import Tensor
5+
from ninetoothed.language import libdevice
6+
7+
from ntops.kernels.element_wise import arrangement
8+
9+
10+
def application(input, exponent, output):
11+
output = libdevice.pow(input, exponent) # noqa: F841
12+
13+
14+
@functools.cache
15+
def make(ndim):
16+
tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim))
17+
18+
return ninetoothed.make(arrangement, application, tensors)

src/ntops/torch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import ntops.kernels.mul
2727
import ntops.kernels.ne
2828
import ntops.kernels.neg
29+
import ntops.kernels.pow
2930
import ntops.kernels.relu
3031
import ntops.kernels.rsqrt
3132
import ntops.kernels.sigmoid
@@ -314,6 +315,17 @@ def neg(input, *, out=None):
314315
return out
315316

316317

318+
def pow(input, exponent, *, out=None):
319+
if out is None:
320+
out = torch.empty_like(input)
321+
322+
kernel = ntops.kernels.pow.make(input.ndim)
323+
324+
kernel(input, exponent, out)
325+
326+
return out
327+
328+
317329
def relu(input, inplace=False):
318330
if inplace:
319331
output = input

tests/test_pow.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import pytest
2+
import torch
3+
4+
import ntops.torch
5+
from tests.skippers import skip_if_cuda_not_available
6+
from tests.utils import generate_arguments
7+
8+
9+
@skip_if_cuda_not_available
10+
@pytest.mark.parametrize(*generate_arguments())
11+
def test_cuda(shape, dtype, atol, rtol):
12+
# TODO: Test for `float16` later.
13+
if dtype is torch.float16:
14+
return
15+
16+
device = "cuda"
17+
18+
input = torch.randn(shape, dtype=dtype, device=device)
19+
exponent = torch.randn(shape, dtype=dtype, device=device)
20+
21+
ninetoothed_output = ntops.torch.pow(input, exponent)
22+
reference_output = torch.pow(input, exponent)
23+
24+
assert torch.allclose(
25+
ninetoothed_output, reference_output, atol=atol, rtol=rtol, equal_nan=True
26+
)

0 commit comments

Comments
 (0)