diff --git a/src/ntops/kernels/rms_norm.py b/src/ntops/kernels/rms_norm.py
index 4fb9f1d..2646aea 100644
--- a/src/ntops/kernels/rms_norm.py
+++ b/src/ntops/kernels/rms_norm.py
@@ -1,6 +1,6 @@
 import functools
-import math
 
+import ninetoothed
 import ninetoothed.language as ntl
 from ninetoothed import Tensor
 
@@ -20,17 +20,24 @@ def application(input, weight, eps, output, num_normalized_elements):
         output[i] = input[i] / rms * weight[i]
 
 
-def premake(ndim, normalized_shape, dtype=None, block_size=None):
-    dims = tuple(-(dim + 1) for dim in range(len(normalized_shape)))
+def premake(
+    ndim,
+    num_normalized_dims,
+    input_dtype=None,
+    weight_dtype=None,
+    output_dtype=None,
+    block_size=None,
+):
+    dims = tuple(-(dim + 1) for dim in range(num_normalized_dims))
 
     arrangement_ = functools.partial(arrangement, dim=dims, block_size=block_size)
 
     tensors = (
-        Tensor(ndim, other=0, dtype=dtype),
-        Tensor(ndim, dtype=dtype),
-        Tensor(0, dtype=dtype),
-        Tensor(ndim, dtype=dtype),
-        Tensor(0, dtype=dtype, constexpr=True, value=math.prod(normalized_shape)),
+        Tensor(ndim, other=0, dtype=input_dtype),
+        Tensor(ndim, dtype=weight_dtype),
+        Tensor(0, dtype=ninetoothed.float32),
+        Tensor(ndim, dtype=output_dtype),
+        Tensor(0, dtype=ninetoothed.uint64),
     )
 
     return arrangement_, application, tensors
diff --git a/src/ntops/torch.py b/src/ntops/torch.py
index 562c6d0..e834e9d 100644
--- a/src/ntops/torch.py
+++ b/src/ntops/torch.py
@@ -394,7 +394,9 @@ def rms_norm(input, normalized_shape, weight=None, eps=None):
 
     output = torch.empty_like(input)
 
-    kernel = _cached_make(ntops.kernels.rms_norm.premake, input.ndim, normalized_shape)
+    kernel = _cached_make(
+        ntops.kernels.rms_norm.premake, input.ndim, len(normalized_shape)
+    )
 
     kernel(input, weight, eps, output, math.prod(normalized_shape))
 