From 40094eb5a7728f01bc51f270c0552e52e410c0ee Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 14 Oct 2025 11:09:48 +0800 Subject: [PATCH] Split `torch.py` into `torch` package with per-op modules --- src/ntops/torch.py | 592 ------------------ src/ntops/torch/__init__.py | 77 +++ src/ntops/torch/abs.py | 15 + src/ntops/torch/add.py | 15 + src/ntops/torch/addmm.py | 18 + src/ntops/torch/bitwise_and.py | 15 + src/ntops/torch/bitwise_not.py | 17 + src/ntops/torch/bitwise_or.py | 15 + src/ntops/torch/bmm.py | 18 + src/ntops/torch/clamp.py | 15 + src/ntops/torch/cos.py | 15 + src/ntops/torch/div.py | 15 + src/ntops/torch/dropout.py | 27 + src/ntops/torch/eq.py | 15 + src/ntops/torch/exp.py | 15 + src/ntops/torch/ge.py | 15 + src/ntops/torch/gelu.py | 14 + src/ntops/torch/gt.py | 15 + src/ntops/torch/isinf.py | 14 + src/ntops/torch/isnan.py | 14 + src/ntops/torch/layer_norm.py | 33 + src/ntops/torch/le.py | 15 + src/ntops/torch/lt.py | 15 + src/ntops/torch/mm.py | 18 + src/ntops/torch/mul.py | 15 + src/ntops/torch/ne.py | 15 + src/ntops/torch/neg.py | 15 + src/ntops/torch/pow.py | 15 + src/ntops/torch/relu.py | 17 + src/ntops/torch/rms_norm.py | 31 + src/ntops/torch/rotary_position_embedding.py | 29 + src/ntops/torch/rsqrt.py | 15 + .../torch/scaled_dot_product_attention.py | 108 ++++ src/ntops/torch/sigmoid.py | 15 + src/ntops/torch/silu.py | 17 + src/ntops/torch/sin.py | 15 + src/ntops/torch/softmax.py | 16 + src/ntops/torch/sub.py | 15 + src/ntops/torch/tanh.py | 15 + src/ntops/torch/utils.py | 25 + 40 files changed, 823 insertions(+), 592 deletions(-) delete mode 100644 src/ntops/torch.py create mode 100644 src/ntops/torch/__init__.py create mode 100644 src/ntops/torch/abs.py create mode 100644 src/ntops/torch/add.py create mode 100644 src/ntops/torch/addmm.py create mode 100644 src/ntops/torch/bitwise_and.py create mode 100644 src/ntops/torch/bitwise_not.py create mode 100644 src/ntops/torch/bitwise_or.py create mode 100644 src/ntops/torch/bmm.py create 
mode 100644 src/ntops/torch/clamp.py create mode 100644 src/ntops/torch/cos.py create mode 100644 src/ntops/torch/div.py create mode 100644 src/ntops/torch/dropout.py create mode 100644 src/ntops/torch/eq.py create mode 100644 src/ntops/torch/exp.py create mode 100644 src/ntops/torch/ge.py create mode 100644 src/ntops/torch/gelu.py create mode 100644 src/ntops/torch/gt.py create mode 100644 src/ntops/torch/isinf.py create mode 100644 src/ntops/torch/isnan.py create mode 100644 src/ntops/torch/layer_norm.py create mode 100644 src/ntops/torch/le.py create mode 100644 src/ntops/torch/lt.py create mode 100644 src/ntops/torch/mm.py create mode 100644 src/ntops/torch/mul.py create mode 100644 src/ntops/torch/ne.py create mode 100644 src/ntops/torch/neg.py create mode 100644 src/ntops/torch/pow.py create mode 100644 src/ntops/torch/relu.py create mode 100644 src/ntops/torch/rms_norm.py create mode 100644 src/ntops/torch/rotary_position_embedding.py create mode 100644 src/ntops/torch/rsqrt.py create mode 100644 src/ntops/torch/scaled_dot_product_attention.py create mode 100644 src/ntops/torch/sigmoid.py create mode 100644 src/ntops/torch/silu.py create mode 100644 src/ntops/torch/sin.py create mode 100644 src/ntops/torch/softmax.py create mode 100644 src/ntops/torch/sub.py create mode 100644 src/ntops/torch/tanh.py create mode 100644 src/ntops/torch/utils.py diff --git a/src/ntops/torch.py b/src/ntops/torch.py deleted file mode 100644 index 2d5d69a..0000000 --- a/src/ntops/torch.py +++ /dev/null @@ -1,592 +0,0 @@ -import functools -import math -import random - -import ninetoothed -import torch - -import ntops -from ntops.kernels.scaled_dot_product_attention import CausalVariant - - -def abs(input, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.abs.premake, input.ndim) - - kernel(input, out) - - return out - - -def add(input, other, *, alpha=1, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = 
_cached_make(ntops.kernels.add.premake, input.ndim) - - kernel(input, other, alpha, out) - - return out - - -def addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): - m, _ = mat1.shape - _, n = mat2.shape - - if out is None: - out = torch.empty((m, n), dtype=input.dtype, device=input.device) - - kernel = _cached_make(ntops.kernels.addmm.premake) - - kernel(input, mat1, mat2, beta, alpha, out, _get_matmul_input_precision()) - - return out - - -def bitwise_and(input, other, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.bitwise_and.premake, input.ndim) - - kernel(input, other, out) - - return out - - -def bitwise_not(input, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make( - ntops.kernels.bitwise_not.premake, input.ndim, input.dtype == torch.bool - ) - - kernel(input, out) - - return out - - -def bitwise_or(input, other, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.bitwise_or.premake, input.ndim) - - kernel(input, other, out) - - return out - - -def bmm(input, mat2, *, out=None): - b, m, _ = input.shape - _, _, n = mat2.shape - - if out is None: - out = torch.empty((b, m, n), dtype=input.dtype, device=input.device) - - kernel = _cached_make(ntops.kernels.bmm.premake) - - kernel(input, mat2, out, _get_matmul_input_precision()) - - return out - - -def clamp(input, min=None, max=None, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.clamp.premake, input.ndim) - - kernel(input, min, max, out) - - return out - - -def cos(input, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.cos.premake, input.ndim) - - kernel(input, out) - - return out - - -def div(input, other, *, rounding_mode=None, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.div.premake, 
input.ndim, rounding_mode) - - kernel(input, other, out) - - return out - - -def dropout(input, p=0.5, training=True, inplace=False): - if not training or p == 0: - if inplace: - return input - else: - return input.clone() - - seed = random.randrange(0, 2**31) - - if inplace: - output = input - else: - output = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.dropout.premake, input.ndim) - - kernel(input, p, seed, output) - - return output - - -def exp(input, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.exp.premake, input.ndim) - - kernel(input, out) - - return out - - -def ge(input, other, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.ge.premake, input.ndim) - - kernel(input, other, out) - - return out - - -def eq(input, other, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.eq.premake, input.ndim) - - kernel(input, other, out) - - return out - - -def gelu(input, approximate="none"): - output = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.gelu.premake, input.ndim, approximate) - - kernel(input, output) - - return output - - -def gt(input, other, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.gt.premake, input.ndim) - - kernel(input, other, out) - - return out - - -def isinf(input): - output = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.isinf.premake, input.ndim) - - kernel(input, output) - - return output - - -def isnan(input): - output = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.isnan.premake, input.ndim) - - kernel(input, output) - - return output - - -def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): - if isinstance(normalized_shape, int): - normalized_shape = (normalized_shape,) - - normalized_shape = tuple(normalized_shape) - - if 
weight is None: - weight = torch.ones_like(input) - else: - weight = weight.expand_as(input) - - if bias is None: - bias = torch.zeros_like(input) - else: - bias = bias.expand_as(input) - - output = torch.empty_like(input) - - kernel = _cached_make( - ntops.kernels.layer_norm.premake, input.ndim, normalized_shape - ) - - kernel(input, weight, bias, eps, output, math.prod(normalized_shape)) - - return output - - -def mm(input, mat2, *, out=None): - m, _ = input.shape - _, n = mat2.shape - - if out is None: - out = torch.empty((m, n), dtype=input.dtype, device=input.device) - - kernel = _cached_make(ntops.kernels.mm.premake) - - kernel(input, mat2, out, _get_matmul_input_precision()) - - return out - - -def le(input, other, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.le.premake, input.ndim) - - kernel(input, other, out) - - return out - - -def lt(input, other, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.lt.premake, input.ndim) - - kernel(input, other, out) - - return out - - -def mul(input, other, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.mul.premake, input.ndim) - - kernel(input, other, out) - - return out - - -def ne(input, other, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.ne.premake, input.ndim) - - kernel(input, other, out) - - return out - - -def neg(input, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.neg.premake, input.ndim) - - kernel(input, out) - - return out - - -def pow(input, exponent, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.pow.premake, input.ndim) - - kernel(input, exponent, out) - - return out - - -def relu(input, inplace=False): - if inplace: - output = input - else: - output = 
torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.relu.premake, input.ndim) - - kernel(input, output) - - return output - - -def rms_norm(input, normalized_shape, weight=None, eps=None): - if isinstance(normalized_shape, int): - normalized_shape = (normalized_shape,) - - normalized_shape = tuple(normalized_shape) - - if weight is None: - weight = torch.ones_like(input) - else: - weight = weight.expand_as(input) - - if eps is None: - eps = torch.finfo(input.dtype).eps - - output = torch.empty_like(input) - - kernel = _cached_make( - ntops.kernels.rms_norm.premake, input.ndim, len(normalized_shape) - ) - - kernel(input, weight, eps, output, math.prod(normalized_shape)) - - return output - - -def rotary_position_embedding( - input, sin_table, cos_table, interleaved=True, inplace=False -): - if inplace: - output = input - else: - output = torch.empty_like(input) - - batch_size, _, num_heads, _ = input.shape - - sin_table = sin_table[None, :, None, :].expand(batch_size, -1, num_heads, -1) - cos_table = cos_table[None, :, None, :].expand(batch_size, -1, num_heads, -1) - - kernel = _cached_make( - ntops.kernels.rotary_position_embedding.premake, - input.ndim, - interleaved=interleaved, - num_warps=1, - ) - - kernel(input, sin_table, cos_table, output) - - return output - - -def rsqrt(input, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.rsqrt.premake, input.ndim) - - kernel(input, out) - - return out - - -def scaled_dot_product_attention( - query, - key, - value, - attn_mask=None, - dropout_p=0, - is_causal=False, - scale=None, - enable_gqa=False, - causal_variant=None, - present_key=None, - present_value=None, - present_key_slot=None, - present_value_slot=None, -): - # TODO: Support `dropout_p`. - assert dropout_p == 0, "`dropout_p` is not supported yet." - - assert attn_mask is None or not is_causal, ( - "Cannot use `attn_mask` and `is_causal` together." 
- ) - - num_heads_q = query.shape[-3] - num_heads_kv = key.shape[-3] - - assert num_heads_kv == value.shape[-3], ( - "Number of heads in `key` and `value` must be the same." - ) - - if not enable_gqa: - assert num_heads_q == num_heads_kv, ( - "Number of heads in `query`, `key`, and `value` must be the same when GQA is not enabled." - ) - else: - assert num_heads_q % num_heads_kv == 0, ( - "Number of heads in `query` must be divisible by number of heads in `key` and `value` when GQA is enabled." - ) - - mask_shape = query.shape[:-1] + (key.shape[-2],) - - if attn_mask is not None: - with_attn_mask = True - - if attn_mask.dtype == torch.bool: - attn_mask = torch.where(attn_mask, 0, float("-inf")) - - attn_mask = attn_mask.expand(mask_shape) - else: - with_attn_mask = False - - attn_mask = torch.empty(mask_shape, device="meta") - - if scale is None: - scale = 1 / math.sqrt(query.shape[-1]) - - if causal_variant is None: - causal_variant = CausalVariant.UPPER_LEFT - - if present_key is not None: - with_kv_cache = True - else: - with_kv_cache = False - - output = torch.empty_like(query, dtype=value.dtype) - - kernel = _cached_make( - ntops.kernels.scaled_dot_product_attention.premake, with_kv_cache - ) - - if with_kv_cache: - kernel( - query, - key, - value, - present_key, - present_value, - present_key_slot, - present_value_slot, - attn_mask, - is_causal, - scale, - output, - with_attn_mask, - causal_variant, - ) - else: - kernel( - query, - key, - value, - attn_mask, - is_causal, - scale, - output, - with_attn_mask, - causal_variant, - ) - - return output - - -def sigmoid(input, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.sigmoid.premake, input.ndim) - - kernel(input, out) - - return out - - -def silu(input, inplace=False): - if inplace: - output = input - else: - output = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.silu.premake, input.ndim) - - kernel(input, output) - - return output - - 
-def sin(input, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.sin.premake, input.ndim) - - kernel(input, out) - - return out - - -def softmax(input, dim, dtype=None): - tensor_dtype = dtype if dtype is not None else input.dtype - - output = torch.empty_like(input, dtype=tensor_dtype) - - kernel = _cached_make(ntops.kernels.softmax.premake, input.ndim, dim) - - kernel(input, output) - - return output - - -def sub(input, other, *, alpha=1, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.sub.premake, input.ndim) - - kernel(input, other, alpha, out) - - return out - - -def tanh(input, *, out=None): - if out is None: - out = torch.empty_like(input) - - kernel = _cached_make(ntops.kernels.tanh.premake, input.ndim) - - kernel(input, out) - - return out - - -@functools.cache -def _cached_make( - premake, *args, num_warps=None, num_stages=None, max_num_configs=None, **keywords -): - return ninetoothed.make( - *premake(*args, **keywords), - num_warps=num_warps, - num_stages=num_stages, - max_num_configs=max_num_configs, - ) - - -def _get_matmul_input_precision(): - if torch.get_float32_matmul_precision() == "highest": - return ntops.kernels.mm.InputPrecisionVariant.IEEE - - return ntops.kernels.mm.InputPrecisionVariant.TF32 diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py new file mode 100644 index 0000000..056c0eb --- /dev/null +++ b/src/ntops/torch/__init__.py @@ -0,0 +1,77 @@ +from ntops.torch.abs import abs +from ntops.torch.add import add +from ntops.torch.addmm import addmm +from ntops.torch.bitwise_and import bitwise_and +from ntops.torch.bitwise_not import bitwise_not +from ntops.torch.bitwise_or import bitwise_or +from ntops.torch.bmm import bmm +from ntops.torch.clamp import clamp +from ntops.torch.cos import cos +from ntops.torch.div import div +from ntops.torch.dropout import dropout +from ntops.torch.eq import eq +from 
ntops.torch.exp import exp +from ntops.torch.ge import ge +from ntops.torch.gelu import gelu +from ntops.torch.gt import gt +from ntops.torch.isinf import isinf +from ntops.torch.isnan import isnan +from ntops.torch.layer_norm import layer_norm +from ntops.torch.le import le +from ntops.torch.lt import lt +from ntops.torch.mm import mm +from ntops.torch.mul import mul +from ntops.torch.ne import ne +from ntops.torch.neg import neg +from ntops.torch.pow import pow +from ntops.torch.relu import relu +from ntops.torch.rms_norm import rms_norm +from ntops.torch.rotary_position_embedding import rotary_position_embedding +from ntops.torch.rsqrt import rsqrt +from ntops.torch.scaled_dot_product_attention import scaled_dot_product_attention +from ntops.torch.sigmoid import sigmoid +from ntops.torch.silu import silu +from ntops.torch.sin import sin +from ntops.torch.softmax import softmax +from ntops.torch.sub import sub +from ntops.torch.tanh import tanh + +__all__ = [ + "abs", + "add", + "addmm", + "bitwise_and", + "bitwise_not", + "bitwise_or", + "bmm", + "clamp", + "cos", + "div", + "dropout", + "eq", + "exp", + "ge", + "gelu", + "gt", + "isinf", + "isnan", + "layer_norm", + "le", + "lt", + "mm", + "mul", + "ne", + "neg", + "pow", + "relu", + "rms_norm", + "rotary_position_embedding", + "rsqrt", + "scaled_dot_product_attention", + "sigmoid", + "silu", + "sin", + "softmax", + "sub", + "tanh", +] diff --git a/src/ntops/torch/abs.py b/src/ntops/torch/abs.py new file mode 100644 index 0000000..c412a75 --- /dev/null +++ b/src/ntops/torch/abs.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def abs(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.abs.premake, input.ndim) + + kernel(input, out) + + return out diff --git a/src/ntops/torch/add.py b/src/ntops/torch/add.py new file mode 100644 index 0000000..1004c93 --- /dev/null +++ b/src/ntops/torch/add.py @@ -0,0 
+1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def add(input, other, *, alpha=1, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.add.premake, input.ndim) + + kernel(input, other, alpha, out) + + return out diff --git a/src/ntops/torch/addmm.py b/src/ntops/torch/addmm.py new file mode 100644 index 0000000..1845f69 --- /dev/null +++ b/src/ntops/torch/addmm.py @@ -0,0 +1,18 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make, _get_matmul_input_precision + + +def addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): + m, _ = mat1.shape + _, n = mat2.shape + + if out is None: + out = torch.empty((m, n), dtype=input.dtype, device=input.device) + + kernel = _cached_make(ntops.kernels.addmm.premake) + + kernel(input, mat1, mat2, beta, alpha, out, _get_matmul_input_precision()) + + return out diff --git a/src/ntops/torch/bitwise_and.py b/src/ntops/torch/bitwise_and.py new file mode 100644 index 0000000..0530a5d --- /dev/null +++ b/src/ntops/torch/bitwise_and.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def bitwise_and(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.bitwise_and.premake, input.ndim) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/bitwise_not.py b/src/ntops/torch/bitwise_not.py new file mode 100644 index 0000000..4318633 --- /dev/null +++ b/src/ntops/torch/bitwise_not.py @@ -0,0 +1,17 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def bitwise_not(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make( + ntops.kernels.bitwise_not.premake, input.ndim, input.dtype == torch.bool + ) + + kernel(input, out) + + return out diff --git a/src/ntops/torch/bitwise_or.py b/src/ntops/torch/bitwise_or.py new file mode 100644 index 
0000000..850bf12 --- /dev/null +++ b/src/ntops/torch/bitwise_or.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def bitwise_or(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.bitwise_or.premake, input.ndim) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/bmm.py b/src/ntops/torch/bmm.py new file mode 100644 index 0000000..80eb735 --- /dev/null +++ b/src/ntops/torch/bmm.py @@ -0,0 +1,18 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make, _get_matmul_input_precision + + +def bmm(input, mat2, *, out=None): + b, m, _ = input.shape + _, _, n = mat2.shape + + if out is None: + out = torch.empty((b, m, n), dtype=input.dtype, device=input.device) + + kernel = _cached_make(ntops.kernels.bmm.premake) + + kernel(input, mat2, out, _get_matmul_input_precision()) + + return out diff --git a/src/ntops/torch/clamp.py b/src/ntops/torch/clamp.py new file mode 100644 index 0000000..bf5046a --- /dev/null +++ b/src/ntops/torch/clamp.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def clamp(input, min=None, max=None, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.clamp.premake, input.ndim) + + kernel(input, min, max, out) + + return out diff --git a/src/ntops/torch/cos.py b/src/ntops/torch/cos.py new file mode 100644 index 0000000..c7a72ac --- /dev/null +++ b/src/ntops/torch/cos.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def cos(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.cos.premake, input.ndim) + + kernel(input, out) + + return out diff --git a/src/ntops/torch/div.py b/src/ntops/torch/div.py new file mode 100644 index 0000000..8eda876 --- /dev/null +++ b/src/ntops/torch/div.py @@ -0,0 
+1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def div(input, other, *, rounding_mode=None, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.div.premake, input.ndim, rounding_mode) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/dropout.py b/src/ntops/torch/dropout.py new file mode 100644 index 0000000..6e95ffe --- /dev/null +++ b/src/ntops/torch/dropout.py @@ -0,0 +1,27 @@ +import random + +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def dropout(input, p=0.5, training=True, inplace=False): + if not training or p == 0: + if inplace: + return input + else: + return input.clone() + + seed = random.randrange(0, 2**31) + + if inplace: + output = input + else: + output = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.dropout.premake, input.ndim) + + kernel(input, p, seed, output) + + return output diff --git a/src/ntops/torch/eq.py b/src/ntops/torch/eq.py new file mode 100644 index 0000000..73dbed7 --- /dev/null +++ b/src/ntops/torch/eq.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def eq(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.eq.premake, input.ndim) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/exp.py b/src/ntops/torch/exp.py new file mode 100644 index 0000000..64ac6f7 --- /dev/null +++ b/src/ntops/torch/exp.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def exp(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.exp.premake, input.ndim) + + kernel(input, out) + + return out diff --git a/src/ntops/torch/ge.py b/src/ntops/torch/ge.py new file mode 100644 index 0000000..019469c --- /dev/null +++ b/src/ntops/torch/ge.py @@ -0,0 +1,15 @@ 
+import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def ge(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.ge.premake, input.ndim) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/gelu.py b/src/ntops/torch/gelu.py new file mode 100644 index 0000000..10b2ef9 --- /dev/null +++ b/src/ntops/torch/gelu.py @@ -0,0 +1,14 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def gelu(input, approximate="none"): + output = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.gelu.premake, input.ndim, approximate) + + kernel(input, output) + + return output diff --git a/src/ntops/torch/gt.py b/src/ntops/torch/gt.py new file mode 100644 index 0000000..5404ffd --- /dev/null +++ b/src/ntops/torch/gt.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def gt(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.gt.premake, input.ndim) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/isinf.py b/src/ntops/torch/isinf.py new file mode 100644 index 0000000..60e2756 --- /dev/null +++ b/src/ntops/torch/isinf.py @@ -0,0 +1,14 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def isinf(input): + output = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.isinf.premake, input.ndim) + + kernel(input, output) + + return output diff --git a/src/ntops/torch/isnan.py b/src/ntops/torch/isnan.py new file mode 100644 index 0000000..7baf32c --- /dev/null +++ b/src/ntops/torch/isnan.py @@ -0,0 +1,14 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def isnan(input): + output = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.isnan.premake, input.ndim) + + kernel(input, output) + + return output diff --git 
a/src/ntops/torch/layer_norm.py b/src/ntops/torch/layer_norm.py new file mode 100644 index 0000000..0bd641d --- /dev/null +++ b/src/ntops/torch/layer_norm.py @@ -0,0 +1,33 @@ +import math + +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + + normalized_shape = tuple(normalized_shape) + + if weight is None: + weight = torch.ones_like(input) + else: + weight = weight.expand_as(input) + + if bias is None: + bias = torch.zeros_like(input) + else: + bias = bias.expand_as(input) + + output = torch.empty_like(input) + + kernel = _cached_make( + ntops.kernels.layer_norm.premake, input.ndim, normalized_shape + ) + + kernel(input, weight, bias, eps, output, math.prod(normalized_shape)) + + return output diff --git a/src/ntops/torch/le.py b/src/ntops/torch/le.py new file mode 100644 index 0000000..25811ca --- /dev/null +++ b/src/ntops/torch/le.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def le(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.le.premake, input.ndim) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/lt.py b/src/ntops/torch/lt.py new file mode 100644 index 0000000..b58590a --- /dev/null +++ b/src/ntops/torch/lt.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def lt(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.lt.premake, input.ndim) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/mm.py b/src/ntops/torch/mm.py new file mode 100644 index 0000000..43c004b --- /dev/null +++ b/src/ntops/torch/mm.py @@ -0,0 +1,18 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make, 
_get_matmul_input_precision + + +def mm(input, mat2, *, out=None): + m, _ = input.shape + _, n = mat2.shape + + if out is None: + out = torch.empty((m, n), dtype=input.dtype, device=input.device) + + kernel = _cached_make(ntops.kernels.mm.premake) + + kernel(input, mat2, out, _get_matmul_input_precision()) + + return out diff --git a/src/ntops/torch/mul.py b/src/ntops/torch/mul.py new file mode 100644 index 0000000..b93547a --- /dev/null +++ b/src/ntops/torch/mul.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def mul(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.mul.premake, input.ndim) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/ne.py b/src/ntops/torch/ne.py new file mode 100644 index 0000000..fc1568e --- /dev/null +++ b/src/ntops/torch/ne.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def ne(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.ne.premake, input.ndim) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/neg.py b/src/ntops/torch/neg.py new file mode 100644 index 0000000..9d11d4a --- /dev/null +++ b/src/ntops/torch/neg.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def neg(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.neg.premake, input.ndim) + + kernel(input, out) + + return out diff --git a/src/ntops/torch/pow.py b/src/ntops/torch/pow.py new file mode 100644 index 0000000..6248b1b --- /dev/null +++ b/src/ntops/torch/pow.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def pow(input, exponent, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = 
# Concatenated contents of the per-op modules introduced by this patch
# (src/ntops/torch/*.py), reformatted as one listing. In the repository,
# each `# --- file ---` section is its own module; `_cached_make` lives in
# src/ntops/torch/utils.py and is imported by every op module via
# `from ntops.torch.utils import _cached_make`.

import functools
import math

import ninetoothed
import torch

import ntops
from ntops.kernels.scaled_dot_product_attention import CausalVariant


# --- src/ntops/torch/utils.py ---
@functools.cache
def _cached_make(
    premake, *args, num_warps=None, num_stages=None, max_num_configs=None, **keywords
):
    """Build and memoize a ninetoothed kernel from a ``premake`` factory.

    The cache keys on the factory plus every positional and keyword
    argument, so each distinct op/ndim/option combination is compiled at
    most once per process.
    """
    prepared = premake(*args, **keywords)
    return ninetoothed.make(
        *prepared,
        num_warps=num_warps,
        num_stages=num_stages,
        max_num_configs=max_num_configs,
    )


def _get_matmul_input_precision():
    """Map torch's float32 matmul precision setting onto the mm kernel's enum.

    ``"highest"`` selects full IEEE fp32; every other setting falls back
    to TF32.
    """
    variants = ntops.kernels.mm.InputPrecisionVariant
    if torch.get_float32_matmul_precision() == "highest":
        return variants.IEEE
    return variants.TF32


# --- src/ntops/torch/relu.py ---
def relu(input, inplace=False):
    """Element-wise ReLU, writing into ``input`` when ``inplace`` is true."""
    output = input if inplace else torch.empty_like(input)
    kernel = _cached_make(ntops.kernels.relu.premake, input.ndim)
    kernel(input, output)
    return output


# --- src/ntops/torch/rms_norm.py ---
def rms_norm(input, normalized_shape, weight=None, eps=None):
    """RMS normalization over the trailing ``normalized_shape`` dimensions.

    ``weight`` defaults to ones (no rescaling) and is broadcast to
    ``input``'s shape; ``eps`` defaults to the dtype's machine epsilon.
    """
    if isinstance(normalized_shape, int):
        shape = (normalized_shape,)
    else:
        shape = tuple(normalized_shape)

    weight = torch.ones_like(input) if weight is None else weight.expand_as(input)
    eps = torch.finfo(input.dtype).eps if eps is None else eps

    output = torch.empty_like(input)
    kernel = _cached_make(ntops.kernels.rms_norm.premake, input.ndim, len(shape))
    # The kernel receives the number of normalized elements explicitly.
    kernel(input, weight, eps, output, math.prod(shape))
    return output


# --- src/ntops/torch/rotary_position_embedding.py ---
def rotary_position_embedding(
    input, sin_table, cos_table, interleaved=True, inplace=False
):
    """Apply rotary position embeddings to ``input``.

    Assumes ``input`` is 4-D, unpacked as
    (batch, sequence, heads, head_dim) — TODO confirm against the kernel.
    The per-position sin/cos tables are broadcast across batch and heads.
    """
    output = input if inplace else torch.empty_like(input)

    batch_size, _, num_heads, _ = input.shape

    def broadcast(table):
        # (seq, dim) -> (batch, seq, heads, dim) as a zero-copy expand.
        return table[None, :, None, :].expand(batch_size, -1, num_heads, -1)

    kernel = _cached_make(
        ntops.kernels.rotary_position_embedding.premake,
        input.ndim,
        interleaved=interleaved,
        num_warps=1,
    )
    kernel(input, broadcast(sin_table), broadcast(cos_table), output)
    return output


# --- src/ntops/torch/rsqrt.py ---
def rsqrt(input, *, out=None):
    """Element-wise reciprocal square root; allocates ``out`` if not given."""
    out = torch.empty_like(input) if out is None else out
    kernel = _cached_make(ntops.kernels.rsqrt.premake, input.ndim)
    kernel(input, out)
    return out


# --- src/ntops/torch/scaled_dot_product_attention.py ---
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
    causal_variant=None,
    present_key=None,
    present_value=None,
    present_key_slot=None,
    present_value_slot=None,
):
    """Fused scaled dot-product attention, optionally with a KV cache.

    Mirrors ``torch.nn.functional.scaled_dot_product_attention`` with
    extras: ``causal_variant`` selects the causal-mask alignment, and the
    ``present_*`` arguments enable the KV-cache kernel variant (supplying
    ``present_key`` switches variants).
    """
    # TODO: Support `dropout_p`.
    assert dropout_p == 0, "`dropout_p` is not supported yet."
    assert attn_mask is None or not is_causal, (
        "Cannot use `attn_mask` and `is_causal` together."
    )

    num_heads_q = query.shape[-3]
    num_heads_kv = key.shape[-3]
    assert num_heads_kv == value.shape[-3], (
        "Number of heads in `key` and `value` must be the same."
    )
    if enable_gqa:
        assert num_heads_q % num_heads_kv == 0, (
            "Number of heads in `query` must be divisible by number of heads in `key` and `value` when GQA is enabled."
        )
    else:
        assert num_heads_q == num_heads_kv, (
            "Number of heads in `query`, `key`, and `value` must be the same when GQA is not enabled."
        )

    mask_shape = query.shape[:-1] + (key.shape[-2],)
    with_attn_mask = attn_mask is not None
    if with_attn_mask:
        if attn_mask.dtype == torch.bool:
            # Boolean masks mean "True = attend"; convert to additive form.
            # NOTE(review): result dtype follows torch's promotion of
            # (int, float), not query.dtype — confirm the kernel casts.
            attn_mask = torch.where(attn_mask, 0, float("-inf"))
        attn_mask = attn_mask.expand(mask_shape)
    else:
        # Placeholder carrying only shape/stride metadata; never read.
        attn_mask = torch.empty(mask_shape, device="meta")

    if scale is None:
        scale = 1 / math.sqrt(query.shape[-1])
    if causal_variant is None:
        causal_variant = CausalVariant.UPPER_LEFT

    with_kv_cache = present_key is not None
    output = torch.empty_like(query, dtype=value.dtype)

    kernel = _cached_make(
        ntops.kernels.scaled_dot_product_attention.premake, with_kv_cache
    )

    # Both kernel variants share a prefix and suffix of arguments; the
    # KV-cache variant inserts the cache tensors in between.
    arguments = [query, key, value]
    if with_kv_cache:
        arguments += [present_key, present_value, present_key_slot, present_value_slot]
    arguments += [attn_mask, is_causal, scale, output, with_attn_mask, causal_variant]
    kernel(*arguments)

    return output


# --- src/ntops/torch/sigmoid.py ---
def sigmoid(input, *, out=None):
    """Element-wise logistic sigmoid; allocates ``out`` if not given."""
    out = torch.empty_like(input) if out is None else out
    kernel = _cached_make(ntops.kernels.sigmoid.premake, input.ndim)
    kernel(input, out)
    return out


# --- src/ntops/torch/silu.py ---
def silu(input, inplace=False):
    """Element-wise SiLU, writing into ``input`` when ``inplace`` is true."""
    output = input if inplace else torch.empty_like(input)
    kernel = _cached_make(ntops.kernels.silu.premake, input.ndim)
    kernel(input, output)
    return output


# --- src/ntops/torch/sin.py ---
def sin(input, *, out=None):
    """Element-wise sine; allocates ``out`` if not given."""
    out = torch.empty_like(input) if out is None else out
    kernel = _cached_make(ntops.kernels.sin.premake, input.ndim)
    kernel(input, out)
    return out


# --- src/ntops/torch/softmax.py ---
def softmax(input, dim, dtype=None):
    """Softmax along ``dim``; the output dtype defaults to ``input``'s."""
    output = torch.empty_like(
        input, dtype=input.dtype if dtype is None else dtype
    )
    kernel = _cached_make(ntops.kernels.softmax.premake, input.ndim, dim)
    kernel(input, output)
    return output


# --- src/ntops/torch/sub.py ---
def sub(input, other, *, alpha=1, out=None):
    """Compute ``input - alpha * other``; allocates ``out`` if not given."""
    out = torch.empty_like(input) if out is None else out
    kernel = _cached_make(ntops.kernels.sub.premake, input.ndim)
    kernel(input, other, alpha, out)
    return out


# --- src/ntops/torch/tanh.py ---
def tanh(input, *, out=None):
    """Element-wise hyperbolic tangent; allocates ``out`` if not given."""
    out = torch.empty_like(input) if out is None else out
    kernel = _cached_make(ntops.kernels.tanh.premake, input.ndim)
    kernel(input, out)
    return out