
Commit 40094eb

Split torch.py into torch package with per-op modules
1 parent 39100bd commit 40094eb

40 files changed

Lines changed: 823 additions & 592 deletions
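The split is structural only: every op that used to live in src/ntops/torch.py now has its own module under src/ntops/torch/, and the package __init__ re-exports each public name. A minimal sketch of how callers reach the ops after this commit — the input tensor and the CUDA device are illustrative assumptions, not part of the diff:

import torch

import ntops.torch

# Assumes a CUDA device, since the wrappers dispatch to ntops kernels.
x = torch.randn(4, 4, device="cuda")

# The package-level name and the per-op module resolve to the same function.
y1 = ntops.torch.abs(x)

from ntops.torch.abs import abs as ntops_abs

y2 = ntops_abs(x)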

src/ntops/torch.py

Lines changed: 0 additions & 592 deletions
This file was deleted.

src/ntops/torch/__init__.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
from ntops.torch.abs import abs
from ntops.torch.add import add
from ntops.torch.addmm import addmm
from ntops.torch.bitwise_and import bitwise_and
from ntops.torch.bitwise_not import bitwise_not
from ntops.torch.bitwise_or import bitwise_or
from ntops.torch.bmm import bmm
from ntops.torch.clamp import clamp
from ntops.torch.cos import cos
from ntops.torch.div import div
from ntops.torch.dropout import dropout
from ntops.torch.eq import eq
from ntops.torch.exp import exp
from ntops.torch.ge import ge
from ntops.torch.gelu import gelu
from ntops.torch.gt import gt
from ntops.torch.isinf import isinf
from ntops.torch.isnan import isnan
from ntops.torch.layer_norm import layer_norm
from ntops.torch.le import le
from ntops.torch.lt import lt
from ntops.torch.mm import mm
from ntops.torch.mul import mul
from ntops.torch.ne import ne
from ntops.torch.neg import neg
from ntops.torch.pow import pow
from ntops.torch.relu import relu
from ntops.torch.rms_norm import rms_norm
from ntops.torch.rotary_position_embedding import rotary_position_embedding
from ntops.torch.rsqrt import rsqrt
from ntops.torch.scaled_dot_product_attention import scaled_dot_product_attention
from ntops.torch.sigmoid import sigmoid
from ntops.torch.silu import silu
from ntops.torch.sin import sin
from ntops.torch.softmax import softmax
from ntops.torch.sub import sub
from ntops.torch.tanh import tanh

__all__ = [
    "abs",
    "add",
    "addmm",
    "bitwise_and",
    "bitwise_not",
    "bitwise_or",
    "bmm",
    "clamp",
    "cos",
    "div",
    "dropout",
    "eq",
    "exp",
    "ge",
    "gelu",
    "gt",
    "isinf",
    "isnan",
    "layer_norm",
    "le",
    "lt",
    "mm",
    "mul",
    "ne",
    "neg",
    "pow",
    "relu",
    "rms_norm",
    "rotary_position_embedding",
    "rsqrt",
    "scaled_dot_product_attention",
    "sigmoid",
    "silu",
    "sin",
    "softmax",
    "sub",
    "tanh",
]
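Since __init__ imports every op and lists it in __all__, a star import exposes exactly these 37 names; note that several of them (abs, pow) shadow Python builtins in the importing namespace. Illustrative only:

from ntops.torch import *  # brings in the 37 names listed in __all__, nothing else

# abs and pow now refer to the ntops wrappers rather than the builtins.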

src/ntops/torch/abs.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def abs(input, *, out=None):
    if out is None:
        out = torch.empty_like(input)

    kernel = _cached_make(ntops.kernels.abs.premake, input.ndim)

    kernel(input, out)

    return out
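The unary wrappers in this commit all follow the pattern visible here: allocate out with torch.empty_like when none is passed, fetch a kernel specialized on input.ndim via _cached_make (caching behaviour inferred from the name; the helper itself is not shown in this diff), launch it, and return out. A hedged usage sketch, again assuming a CUDA tensor:

import torch

import ntops.torch

x = torch.tensor([-1.5, 2.0, -3.0], device="cuda")

# Without out=, the wrapper allocates the result itself.
y = ntops.torch.abs(x)

# With out=, it writes into the provided tensor and returns that same object.
buf = torch.empty_like(x)
z = ntops.torch.abs(x, out=buf)
assert z is buf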

src/ntops/torch/add.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def add(input, other, *, alpha=1, out=None):
    if out is None:
        out = torch.empty_like(input)

    kernel = _cached_make(ntops.kernels.add.premake, input.ndim)

    kernel(input, other, alpha, out)

    return out
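add forwards alpha straight to the kernel; presumably it computes input + alpha * other, mirroring torch.add, though the arithmetic itself lives in ntops.kernels.add and is not part of this diff. Sketch (tensors and device are assumptions):

import torch

import ntops.torch

a = torch.ones(3, device="cuda")
b = torch.full((3,), 2.0, device="cuda")

# Presumably a + alpha * b, i.e. [2.0, 2.0, 2.0] here.
c = ntops.torch.add(a, b, alpha=0.5)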

src/ntops/torch/addmm.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
import torch

import ntops
from ntops.torch.utils import _cached_make, _get_matmul_input_precision


def addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
    m, _ = mat1.shape
    _, n = mat2.shape

    if out is None:
        out = torch.empty((m, n), dtype=input.dtype, device=input.device)

    kernel = _cached_make(ntops.kernels.addmm.premake)

    kernel(input, mat1, mat2, beta, alpha, out, _get_matmul_input_precision())

    return out
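addmm sizes its output from mat1's row count and mat2's column count and forwards beta, alpha, and the matmul input precision to the kernel; presumably it computes beta * input + alpha * (mat1 @ mat2), as torch.addmm does. Shape sketch (values and device are illustrative assumptions):

import torch

import ntops.torch

inp = torch.randn(2, 4, device="cuda")
mat1 = torch.randn(2, 3, device="cuda")
mat2 = torch.randn(3, 4, device="cuda")

# out has shape (m, n) = (2, 4), taken from mat1's rows and mat2's columns.
out = ntops.torch.addmm(inp, mat1, mat2, beta=1.0, alpha=1.0)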

src/ntops/torch/bitwise_and.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def bitwise_and(input, other, *, out=None):
    if out is None:
        out = torch.empty_like(input)

    kernel = _cached_make(ntops.kernels.bitwise_and.premake, input.ndim)

    kernel(input, other, out)

    return out

src/ntops/torch/bitwise_not.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def bitwise_not(input, *, out=None):
    if out is None:
        out = torch.empty_like(input)

    kernel = _cached_make(
        ntops.kernels.bitwise_not.premake, input.ndim, input.dtype == torch.bool
    )

    kernel(input, out)

    return out
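bitwise_not is the one unary wrapper here that specializes on dtype as well as rank: it passes input.dtype == torch.bool into premake, presumably so the kernel can emit logical negation for bool tensors and integer inversion (~x) for everything else. Sketch (device and values assumed):

import torch

import ntops.torch

flags = torch.tensor([True, False], device="cuda")
ints = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")

# Separate kernel specializations for the bool and integer cases.
not_flags = ntops.torch.bitwise_not(flags)
not_ints = ntops.torch.bitwise_not(ints)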

src/ntops/torch/bitwise_or.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def bitwise_or(input, other, *, out=None):
    if out is None:
        out = torch.empty_like(input)

    kernel = _cached_make(ntops.kernels.bitwise_or.premake, input.ndim)

    kernel(input, other, out)

    return out

src/ntops/torch/bmm.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
import torch

import ntops
from ntops.torch.utils import _cached_make, _get_matmul_input_precision


def bmm(input, mat2, *, out=None):
    b, m, _ = input.shape
    _, _, n = mat2.shape

    if out is None:
        out = torch.empty((b, m, n), dtype=input.dtype, device=input.device)

    kernel = _cached_make(ntops.kernels.bmm.premake)

    kernel(input, mat2, out, _get_matmul_input_precision())

    return out
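bmm reads the batch size and row count from input and the column count from mat2, so it expects strict (b, m, k) x (b, k, n) operands with matching batch and inner dimensions; unlike torch.matmul, the wrapper itself does no broadcasting. Shape sketch (tensors and device assumed):

import torch

import ntops.torch

a = torch.randn(8, 2, 3, device="cuda")
b = torch.randn(8, 3, 4, device="cuda")

# (8, 2, 3) @ (8, 3, 4) -> out of shape (8, 2, 4).
c = ntops.torch.bmm(a, b)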

src/ntops/torch/clamp.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def clamp(input, min=None, max=None, *, out=None):
    if out is None:
        out = torch.empty_like(input)

    kernel = _cached_make(ntops.kernels.clamp.premake, input.ndim)

    kernel(input, min, max, out)

    return out
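clamp forwards min and max as-is, so a None bound presumably means "unbounded on that side", matching torch.clamp's semantics. Sketch (values and device assumed):

import torch

import ntops.torch

x = torch.tensor([-2.0, 0.5, 3.0], device="cuda")

# Both bounds: values are clipped into [-1.0, 1.0].
clipped = ntops.torch.clamp(x, min=-1.0, max=1.0)

# Only a lower bound; max=None leaves the top unbounded (assumed).
floored = ntops.torch.clamp(x, min=0.0)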
