Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ntops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
relu,
rms_norm,
rotary_position_embedding,
round,
rsqrt,
scaled_dot_product_attention,
sigmoid,
Expand Down Expand Up @@ -74,6 +75,7 @@
"relu",
"rms_norm",
"rotary_position_embedding",
"round",
"rsqrt",
"scaled_dot_product_attention",
"sigmoid",
Expand Down
35 changes: 35 additions & 0 deletions src/ntops/kernels/round.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor
from ninetoothed.language import libdevice

from ntops.kernels.element_wise import arrangement


def application(input, output):
    # Element-wise round to the nearest integer. The input tile is cast to
    # float32 so libdevice.nearbyint is applicable to lower-precision dtypes
    # (e.g. float16/bfloat16); the store into `output` converts the result
    # back to the output tensor's dtype.
    # NOTE(review): assigning to the `output` parameter is how ninetoothed
    # applications write results (the DSL traces this function's source),
    # hence the `noqa: F841` for the apparently-unused assignment.
    output = libdevice.nearbyint(ntl.cast(input, ntl.float32))  # noqa: F841


def application_with_decimals(input, factor, inv_factor, output):
    # Round to a given number of decimal places via scale -> round -> unscale:
    # multiply by `factor` (10**decimals), round to the nearest integer, then
    # multiply by `inv_factor` (10**-decimals). `factor` and `inv_factor` are
    # passed in as float64 scalars (see `premake`).
    scaled = input * ntl.cast(
        factor, input.dtype
    )  # multiply in the input's original precision to match torch's behavior
    # Cast to float32 so libdevice.nearbyint applies to low-precision dtypes;
    # the float64 `inv_factor` keeps the unscaling in higher precision before
    # the store converts back to `output`'s dtype.
    output = libdevice.nearbyint(ntl.cast(scaled, ntl.float32)) * inv_factor  # noqa: F841


def premake(ndim, decimals=0, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for a round kernel.

    When ``decimals`` is zero the kernel rounds directly to the nearest
    integer and only needs the input and output tensors. Otherwise the
    scaled variant is selected, which additionally takes two float64
    scalars (``factor`` and ``inv_factor``) supplied at call time.
    """
    bound_arrangement = functools.partial(arrangement, block_size=block_size)

    if decimals == 0:
        # Fast path: plain nearest-integer rounding, input and output only.
        io_tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
        return bound_arrangement, application, io_tensors

    # Scaled path: input, factor scalar, inverse-factor scalar, output.
    scaled_tensors = (
        Tensor(ndim, dtype=dtype),
        Tensor(0, dtype=ninetoothed.float64),
        Tensor(0, dtype=ninetoothed.float64),
        Tensor(ndim, dtype=dtype),
    )
    return bound_arrangement, application_with_decimals, scaled_tensors
2 changes: 2 additions & 0 deletions src/ntops/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ntops.torch.relu import relu
from ntops.torch.rms_norm import rms_norm
from ntops.torch.rotary_position_embedding import rotary_position_embedding
from ntops.torch.round import round
from ntops.torch.rsqrt import rsqrt
from ntops.torch.scaled_dot_product_attention import scaled_dot_product_attention
from ntops.torch.sigmoid import sigmoid
Expand Down Expand Up @@ -74,6 +75,7 @@
"relu",
"rms_norm",
"rotary_position_embedding",
"round",
"rsqrt",
"scaled_dot_product_attention",
"sigmoid",
Expand Down
20 changes: 20 additions & 0 deletions src/ntops/torch/round.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def round(input, decimals=0, *, out=None):
    """Round ``input`` element-wise, optionally to ``decimals`` decimal places.

    Mirrors ``torch.round(input, decimals=decimals, out=out)``.

    Args:
        input: The input tensor.
        decimals: Number of decimal places to round to; may be negative.
        out: Optional output tensor. Allocated with ``torch.empty_like``
            when omitted.

    Returns:
        The rounded tensor (``out`` when provided).
    """
    if out is None:
        out = torch.empty_like(input)

    if decimals == 0:
        # Fast path: kernel that rounds straight to the nearest integer.
        kernel = _cached_make(ntops.kernels.round.premake, input.ndim)
        kernel(input, out)
    else:
        # All non-zero `decimals` share one compiled kernel: the scale
        # factors are passed at run time, so the cache key only needs to
        # select the "with decimals" kernel variant. Use the int 1 rather
        # than a bool for that sentinel (premake declares `decimals` as an
        # int; the previous `decimals=True` only worked because True == 1).
        factor = 10.0**decimals
        inv_factor = 1.0 / factor
        kernel = _cached_make(ntops.kernels.round.premake, input.ndim, decimals=1)
        kernel(input, factor, inv_factor, out)

    return out
17 changes: 17 additions & 0 deletions tests/test_round.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import torch

import ntops
from tests.skippers import skip_if_cuda_not_available
from tests.utils import generate_arguments


@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_round(shape, dtype, device, rtol, atol):
    """Check ntops.torch.round against torch.round on both kernel paths."""
    input = torch.randn(shape, dtype=dtype, device=device)

    # Default path (decimals == 0).
    ninetoothed_output = ntops.torch.round(input)
    reference_output = torch.round(input)

    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)

    # Scaled path (decimals != 0) — exercises the application_with_decimals
    # kernel, which the original test never covered.
    ninetoothed_output = ntops.torch.round(input, decimals=2)
    reference_output = torch.round(input, decimals=2)

    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)