Skip to content

Commit b2623d4

Browse files
authored
Merge pull request #46 from InfiniTensor/develop-layer-norm
Add `layer_norm` operator
2 parents 7a96cb5 + 3da0418 commit b2623d4

File tree

3 files changed

+112
-0
lines changed

3 files changed

+112
-0
lines changed

src/ntops/kernels/layer_norm.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import functools
2+
import math
3+
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Tensor
6+
7+
from ntops.kernels.reduction import arrangement
8+
9+
10+
def application(input, weight, bias, eps, output, num_normalized_elements):
    """Per-tile layer-norm computation, written in the ninetoothed DSL.

    Normalizes ``input`` over the arranged (trailing) dimensions using the
    standard layer-norm formula ``(x - mean) / sqrt(var + eps) * weight + bias``,
    accumulating statistics in float32 and writing the result into ``output``.

    Args:
        input: arranged input tile; declared with ``other=0`` in ``premake``,
            so out-of-bounds elements read as 0 — presumably this is why the
            mean accumulation below needs no explicit mask (TODO confirm
            against ninetoothed's padding semantics).
        weight: affine scale, pre-expanded by the caller to match ``input``.
        bias: affine shift, pre-expanded by the caller to match ``input``.
        eps: small constant added to the variance for numerical stability.
        output: destination tile, written in-place.
        num_normalized_elements: true (unpadded) element count of the
            normalized dimensions; declared constexpr in ``premake``.
    """
    # Accumulate the sum for the mean in float32 regardless of input dtype.
    _mean = ntl.zeros(input.dtype.shape, dtype=ntl.float32)

    for i in range(input.shape[0]):
        _mean += ntl.cast(input[i], ntl.float32)

    # Divide by the true element count, not the padded tile size.
    mean = ntl.sum(_mean, 0) / num_normalized_elements

    _var = ntl.zeros(input.dtype.shape, dtype=ntl.float32)

    for i in range(input.shape[0]):
        diff = ntl.cast(input[i], ntl.float32) - mean
        # Padded elements read as 0, so their diff would be -mean; zero them
        # out explicitly so they do not contribute to the variance.
        diff = ntl.where(input[i].offsets(-1) < input.source.shape[-1], diff, 0)
        _var += diff * diff

    var = ntl.sum(_var, 0) / num_normalized_elements

    std = ntl.sqrt(var + eps)

    # Normalize and apply the affine transform; NOTE(review): no explicit mask
    # here — presumably out-of-bounds stores are suppressed by the framework
    # (TODO confirm).
    for i in range(input.shape[0]):
        output[i] = (ntl.cast(input[i], ntl.float32) - mean) / std * weight[i] + bias[i]
31+
32+
33+
def premake(ndim, normalized_shape, dtype=None, block_size=None):
    """Prepare the pieces needed to build the layer_norm kernel.

    Args:
        ndim: rank of the input/output tensors.
        normalized_shape: shape of the trailing dimensions to normalize over.
        dtype: element dtype for the tensor specs (``None`` leaves it open).
        block_size: block size forwarded to the reduction ``arrangement``.

    Returns:
        A ``(arrangement, application, tensors)`` triple consumable by the
        kernel factory.
    """
    # The normalized dimensions are the trailing ones: -1, -2, ...,
    # down to -len(normalized_shape).
    reduction_dims = tuple(range(-1, -len(normalized_shape) - 1, -1))

    bound_arrangement = functools.partial(
        arrangement, dim=reduction_dims, block_size=block_size
    )

    # Specs, in kernel-argument order. ``other=0`` makes out-of-bounds reads
    # of the input yield zero inside the application.
    input_spec = Tensor(ndim, other=0, dtype=dtype)
    weight_spec = Tensor(ndim, dtype=dtype)
    bias_spec = Tensor(ndim, dtype=dtype)
    eps_spec = Tensor(0, dtype=dtype)
    output_spec = Tensor(ndim, dtype=dtype)
    # Compile-time constant: the true number of normalized elements.
    count_spec = Tensor(
        0, dtype=dtype, constexpr=True, value=math.prod(normalized_shape)
    )

    tensors = (
        input_spec,
        weight_spec,
        bias_spec,
        eps_spec,
        output_spec,
        count_spec,
    )

    return bound_arrangement, application, tensors

src/ntops/torch.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import ntops.kernels.gt
2424
import ntops.kernels.isinf
2525
import ntops.kernels.isnan
26+
import ntops.kernels.layer_norm
2627
import ntops.kernels.le
2728
import ntops.kernels.lt
2829
import ntops.kernels.mm
@@ -256,6 +257,33 @@ def isnan(input):
256257
return output
257258

258259

260+
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    """Apply layer normalization over the trailing ``normalized_shape`` dims.

    Mirrors ``torch.nn.functional.layer_norm``: ``weight``/``bias`` are
    optional affine parameters of shape ``normalized_shape``; when absent an
    identity transform (ones/zeros) is used.

    Args:
        input: input tensor.
        normalized_shape: int or iterable of ints — the trailing dims to
            normalize over.
        weight: optional scale tensor of shape ``normalized_shape``.
        bias: optional shift tensor of shape ``normalized_shape``.
        eps: small constant added to the variance for numerical stability.

    Returns:
        A new tensor with the normalized (and affinely transformed) values.
    """
    # Accept a bare int the way torch.nn.functional.layer_norm does.
    shape = (
        (normalized_shape,)
        if isinstance(normalized_shape, int)
        else tuple(normalized_shape)
    )

    # Broadcast the affine parameters to the full input shape; use the
    # identity transform (ones/zeros) when they are not provided.
    full_weight = torch.ones_like(input) if weight is None else weight.expand_as(input)
    full_bias = torch.zeros_like(input) if bias is None else bias.expand_as(input)

    output = torch.empty_like(input)

    kernel = _cached_make(ntops.kernels.layer_norm.premake, input.ndim, shape)
    kernel(input, full_weight, full_bias, eps, output, math.prod(shape))

    return output
285+
286+
259287
def mm(input, mat2, *, out=None):
260288
m, _ = input.shape
261289
_, n = mat2.shape

tests/test_layer_norm.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import random
2+
3+
import pytest
4+
import torch
5+
6+
import ntops.torch
7+
from tests.skippers import skip_if_cuda_not_available
8+
from tests.utils import generate_arguments
9+
10+
11+
@skip_if_cuda_not_available
@pytest.mark.parametrize("eps", (1e-8, 1e-5, 1e-3))
@pytest.mark.parametrize("bias_is_none", (False, True))
@pytest.mark.parametrize("weight_is_none", (False, True))
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol, weight_is_none, bias_is_none, eps):
    """Compare ntops.torch.layer_norm against torch.nn.functional.layer_norm."""
    device = "cuda"

    input = torch.randn(shape, dtype=dtype, device=device)
    # Seed a local RNG from the parametrized inputs (string seeds are
    # deterministic across processes) so the chosen suffix length is
    # reproducible: the module-level random.randint was unseeded, which made
    # failing parameter combinations impossible to replay.
    rng = random.Random(f"{shape}-{dtype}-{weight_is_none}-{bias_is_none}-{eps}")
    # Normalize over a random-length trailing suffix of the shape.
    normalized_shape = shape[-rng.randint(1, len(shape)) :]
    if weight_is_none:
        weight = None
    else:
        weight = torch.randn(normalized_shape, dtype=dtype, device=device)
    if bias_is_none:
        bias = None
    else:
        bias = torch.randn(normalized_shape, dtype=dtype, device=device)

    ninetoothed_output = ntops.torch.layer_norm(
        input, normalized_shape, weight=weight, bias=bias, eps=eps
    )
    reference_output = torch.nn.functional.layer_norm(
        input, normalized_shape, weight=weight, bias=bias, eps=eps
    )

    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)