Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/ntops/kernels/rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import functools
import math

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def application(input, weight, eps, output, num_normalized_elements):
_rms = ntl.zeros(input.dtype.shape, dtype=ntl.float32)

for i in range(input.shape[0]):
input_i = ntl.cast(input[i], ntl.float32)
_rms += input_i * input_i

rms = ntl.sqrt(ntl.sum(_rms) / num_normalized_elements + eps)

for i in range(input.shape[0]):
output[i] = input[i] / rms * weight[i]


def premake(ndim, normalized_shape, dtype=None, block_size=None):
dims = tuple(-(dim + 1) for dim in range(len(normalized_shape)))

arrangement_ = functools.partial(arrangement, dim=dims, block_size=block_size)

tensors = (
Tensor(ndim, other=0, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(0, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(0, dtype=dtype, constexpr=True, value=math.prod(normalized_shape)),
)

return arrangement_, application, tensors
24 changes: 24 additions & 0 deletions src/ntops/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import ntops.kernels.neg
import ntops.kernels.pow
import ntops.kernels.relu
import ntops.kernels.rms_norm
import ntops.kernels.rsqrt
import ntops.kernels.scaled_dot_product_attention
import ntops.kernels.sigmoid
Expand Down Expand Up @@ -347,6 +348,29 @@ def relu(input, inplace=False):
return output


def rms_norm(input, normalized_shape, weight=None, eps=None):
if isinstance(normalized_shape, int):
normalized_shape = (normalized_shape,)

normalized_shape = tuple(normalized_shape)

if weight is None:
weight = torch.ones_like(input)
else:
weight = weight.expand_as(input)

if eps is None:
eps = torch.finfo(input.dtype).eps

output = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.rms_norm.premake, input.ndim, normalized_shape)

kernel(input, weight, eps, output, math.prod(normalized_shape))

return output


def rsqrt(input, *, out=None):
if out is None:
out = torch.empty_like(input)
Expand Down
32 changes: 32 additions & 0 deletions tests/test_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import random

import pytest
import torch

import ntops.torch
from tests.skippers import skip_if_cuda_not_available
from tests.utils import generate_arguments


@skip_if_cuda_not_available
@pytest.mark.parametrize("eps", (None, 0, 1e-5, 1e-3))
@pytest.mark.parametrize("weight_is_none", (False, True))
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol, weight_is_none, eps):
device = "cuda"

input = torch.randn(shape, dtype=dtype, device=device)
normalized_shape = shape[-random.randint(1, len(shape)) :]
if weight_is_none:
weight = None
else:
weight = torch.randn(normalized_shape, dtype=dtype, device=device)

ninetoothed_output = ntops.torch.rms_norm(
input, normalized_shape, weight=weight, eps=eps
)
reference_output = torch.nn.functional.rms_norm(
input, normalized_shape, weight=weight, eps=eps
)

assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)