Skip to content

Commit 2e4ece7

Browse files
authored
Merge pull request #36 from InfiniTensor/develop-silu
Add `silu` operator
2 parents ba321c4 + 90fa74d commit 2e4ece7

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

src/ntops/kernels/silu.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Tensor
6+
7+
from ntops.kernels.element_wise import arrangement
8+
9+
10+
def application(input, output):
    """Element-wise SiLU: write ``input * sigmoid(input)`` into ``output``.

    Uses the identity x * sigmoid(x) == x / (1 + exp(-x)).

    NOTE(review): the rebinding of ``output`` looks dead (hence the
    ``noqa: F841``), but the ninetoothed DSL presumably traces this
    assignment as the store into the output tensor — confirm against the
    framework's kernel-tracing semantics before restructuring this line.
    """
    # Cast to float32 before exp — presumably for numeric stability with
    # half-precision inputs (TODO confirm).
    output = input / (1 + ntl.exp(-ntl.cast(input, ntl.float32)))  # noqa: F841
12+
13+
14+
@functools.cache
def make(ndim):
    """Build (and memoize) the SiLU kernel for tensors of rank *ndim*."""
    # One symbolic tensor each for the input and the output operands.
    operands = tuple(Tensor(ndim) for _ in range(2))
    return ninetoothed.make(arrangement, application, operands)

src/ntops/torch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import ntops.kernels.relu
3131
import ntops.kernels.rsqrt
3232
import ntops.kernels.sigmoid
33+
import ntops.kernels.silu
3334
import ntops.kernels.sin
3435
import ntops.kernels.softmax
3536
import ntops.kernels.sub
@@ -362,6 +363,19 @@ def sigmoid(input, *, out=None):
362363
return out
363364

364365

366+
def silu(input, inplace=False):
    """Apply SiLU (``x * sigmoid(x)``) element-wise using the ninetoothed kernel.

    When ``inplace`` is true the result is written back into ``input``;
    otherwise a fresh uninitialized tensor of the same shape/dtype/device
    is allocated for the result.
    """
    output = input if inplace else torch.empty_like(input)

    # Kernels are specialized (and cached) per tensor rank.
    kernel = ntops.kernels.silu.make(input.ndim)
    kernel(input, output)

    return output
377+
378+
365379
def sin(input, *, out=None):
366380
if out is None:
367381
out = torch.empty_like(input)

tests/test_silu.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
import torch
3+
import torch.nn.functional as F
4+
5+
import ntops.torch
6+
from tests.skippers import skip_if_cuda_not_available
7+
from tests.utils import generate_arguments
8+
9+
10+
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
    """Compare the ninetoothed SiLU against ``F.silu`` on random CUDA input."""
    x = torch.randn(shape, dtype=dtype, device="cuda")

    # TODO: Add `inplace` tests later.
    actual = ntops.torch.silu(x)
    expected = F.silu(x)

    assert torch.allclose(actual, expected, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)