diff --git a/src/ntops/__init__.py b/src/ntops/__init__.py index a4de9bf..2d690be 100644 --- a/src/ntops/__init__.py +++ b/src/ntops/__init__.py @@ -1,3 +1,3 @@ -from ntops import torch +from ntops import kernels, torch -__all__ = ["torch"] +__all__ = ["kernels", "torch"] diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index e69de29..084e52c 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -0,0 +1,79 @@ +from ntops.kernels import ( + abs, + add, + addmm, + bitwise_and, + bitwise_not, + bitwise_or, + bmm, + clamp, + cos, + div, + dropout, + eq, + exp, + ge, + gelu, + gt, + isinf, + isnan, + layer_norm, + le, + lt, + mm, + mul, + ne, + neg, + pow, + relu, + rms_norm, + rotary_position_embedding, + rsqrt, + scaled_dot_product_attention, + sigmoid, + silu, + sin, + softmax, + sub, + tanh, +) + +__all__ = [ + "abs", + "add", + "addmm", + "bitwise_and", + "bitwise_not", + "bitwise_or", + "bmm", + "clamp", + "cos", + "div", + "dropout", + "eq", + "exp", + "ge", + "gelu", + "gt", + "isinf", + "isnan", + "layer_norm", + "le", + "lt", + "mm", + "mul", + "ne", + "neg", + "pow", + "relu", + "rms_norm", + "rotary_position_embedding", + "rsqrt", + "scaled_dot_product_attention", + "sigmoid", + "silu", + "sin", + "softmax", + "sub", + "tanh", +] diff --git a/src/ntops/torch.py b/src/ntops/torch.py index 3312b96..2d5d69a 100644 --- a/src/ntops/torch.py +++ b/src/ntops/torch.py @@ -5,43 +5,7 @@ import ninetoothed import torch -import ntops.kernels.abs -import ntops.kernels.add -import ntops.kernels.addmm -import ntops.kernels.bitwise_and -import ntops.kernels.bitwise_not -import ntops.kernels.bitwise_or -import ntops.kernels.bmm -import ntops.kernels.clamp -import ntops.kernels.cos -import ntops.kernels.div -import ntops.kernels.dropout -import ntops.kernels.eq -import ntops.kernels.exp -import ntops.kernels.ge -import ntops.kernels.gelu -import ntops.kernels.gt -import ntops.kernels.isinf -import ntops.kernels.isnan -import ntops.kernels.layer_norm -import ntops.kernels.le -import ntops.kernels.lt -import ntops.kernels.mm -import ntops.kernels.mul -import ntops.kernels.ne -import ntops.kernels.neg -import ntops.kernels.pow -import ntops.kernels.relu -import ntops.kernels.rms_norm -import ntops.kernels.rotary_position_embedding -import ntops.kernels.rsqrt -import ntops.kernels.scaled_dot_product_attention -import ntops.kernels.sigmoid -import ntops.kernels.silu -import ntops.kernels.sin -import ntops.kernels.softmax -import ntops.kernels.sub -import ntops.kernels.tanh +import ntops from ntops.kernels.scaled_dot_product_attention import CausalVariant