Skip to content
86 changes: 86 additions & 0 deletions src/ninetoothed/aot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import ast
import ctypes
import itertools
import pathlib
import re
import shutil
import subprocess
import tempfile
import textwrap
Expand Down Expand Up @@ -39,6 +42,8 @@ def aot(
with open(output_path, "w") as f:
f.write(output_content)

return _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)


def _aot(func, caller, kernel_name, num_warps, num_stages):
def _find_tensor_by_source_name(tensors, name):
Expand Down Expand Up @@ -351,6 +356,25 @@ def visit_Lambda(self, node):
return node


class _ArgumentTensor(ctypes.Structure):
_fields_ = [
("data", ctypes.c_void_p),
("shape", ctypes.POINTER(ctypes.c_uint64)),
("strides", ctypes.POINTER(ctypes.c_int64)),
]

@staticmethod
def from_torch_tensor(tensor):
data = ctypes.c_void_p(tensor.data_ptr())
shape = (ctypes.c_uint64 * len(tensor.shape))(*tensor.shape)
strides = (ctypes.c_int64 * len(tensor.stride()))(*tensor.stride())

arg_tensor = _ArgumentTensor(data, shape, strides)
arg_tensor._torch_tensor = tensor

return arg_tensor


def _compile(path, name, signature, grid, num_warps, num_stages):
with tempfile.TemporaryDirectory() as temp_dir:
output_dir = pathlib.Path(temp_dir)
Expand Down Expand Up @@ -389,3 +413,65 @@ def _compile(path, name, signature, grid, num_warps, num_stages):
output_contents[file.name.replace(output_name, name)] = f.read()

return signature_hash, output_contents


def _generate_launch_func(kernel_name, output_dir):
    """Compile the generated sources for *kernel_name* and return a Python launcher.

    Builds the shared library under *output_dir*, loads it, and wraps its
    ``launch_<kernel_name>`` entry point in a function that converts Python
    arguments (torch tensors, dtype-name strings) into their C-compatible
    forms before the call.

    :param kernel_name: Base name of the generated kernel sources.
    :param output_dir: Directory containing the generated ``.cpp`` files.
    :return: A callable ``(*args, **kwargs) -> None`` that launches the kernel
        on the caller's current CUDA stream.
    :raises RuntimeError: If the C launch function returns a nonzero code.
    """
    import torch

    output_dir = pathlib.Path(output_dir)

    _compile_library(kernel_name, output_dir)
    library = _load_library(kernel_name, output_dir)
    launch_func = getattr(library, f"launch_{kernel_name}")

    # Map dtype names to their positional index once, instead of rebuilding
    # the key tuple and linearly scanning it for every argument of every call.
    dtype_indices = {name: index for index, name in enumerate(_DTYPE_MAPPING)}

    def _run_launch_func(*args, **kwargs):
        arguments = []

        for arg in itertools.chain(args, kwargs.values()):
            if isinstance(arg, torch.Tensor):
                # The struct keeps a reference to the tensor, so its storage
                # stays alive until the launch call returns.
                argument = _ArgumentTensor.from_torch_tensor(arg)
            elif isinstance(arg, str) and arg in dtype_indices:
                argument = dtype_indices[arg]
            else:
                argument = arg

            arguments.append(argument)

        # Launch on the caller's current CUDA stream; a nonzero return value
        # is the CUDA error code reported by the generated launcher.
        result = launch_func(
            ctypes.c_void_p(torch.cuda.current_stream().cuda_stream), *arguments
        )

        if result != 0:
            raise RuntimeError(f"Kernel launch failed with error code: {result}.")

    return _run_launch_func


def _compile_library(kernel_name, output_dir):
command = [
"nvcc",
"-shared",
"-Xcompiler",
"-fPIC",
"-lcuda",
"-o",
output_dir / f"{kernel_name}.so",
] + list(output_dir.glob(f"{kernel_name}*.cpp"))

subprocess.run(command, check=True)


def _load_library(kernel_name, kernel_dir):
    """Load the compiled shared library for *kernel_name* from *kernel_dir*.

    The ``.so`` is copied to a uniquely named temporary file before loading —
    presumably so that recompiled libraries at the same path are loaded fresh
    rather than served from dlopen's per-path cache (verify against callers).
    """
    shared_object = kernel_dir / f"{kernel_name}.so"

    with tempfile.NamedTemporaryFile(suffix=".so") as scratch:
        shutil.copy(shared_object, scratch.name)
        # NOTE(review): the temp file is deleted when the `with` exits; the
        # already-mapped library remains usable on POSIX, but this would not
        # hold on Windows — confirm the supported platforms.
        loaded = ctypes.CDLL(scratch.name)

    return loaded
9 changes: 8 additions & 1 deletion src/ninetoothed/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import pathlib

import ninetoothed
from ninetoothed.aot import _DTYPE_MAPPING, _HEADER_PATH, _MACRO_MAPPING
from ninetoothed.aot import (
_DTYPE_MAPPING,
_HEADER_PATH,
_MACRO_MAPPING,
_generate_launch_func,
)


def build(premake, configs, *, caller=None, kernel_name=None, output_dir=None):
Expand Down Expand Up @@ -102,6 +107,8 @@ def build(premake, configs, *, caller=None, kernel_name=None, output_dir=None):
(output_dir / source_file_name).write_text(source_content)
(output_dir / header_file_name).write_text(header_content)

return _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)


def _make(premake, config, caller, kernel_name, output_dir):
args, kwargs, compilation_configs = config
Expand Down
99 changes: 12 additions & 87 deletions tests/test_aot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import ctypes
import functools
import itertools
import pathlib
import subprocess

import pytest
import torch
Expand All @@ -15,7 +11,6 @@
import tests.test_conv2d as conv2d
import tests.test_matmul as matmul
from ninetoothed import Tensor
from ninetoothed.aot import _DTYPE_MAPPING
from tests.utils import get_available_devices


Expand All @@ -40,7 +35,7 @@ def _application(input, other, output):
kernel_name = f"add{_generate_kernel_name_suffix()}"
output_dir = ninetoothed.generation.CACHE_DIR

ninetoothed.make(
kernel = ninetoothed.make(
_arrangement,
_application,
tensors,
Expand All @@ -49,8 +44,6 @@ def _application(input, other, output):
output_dir=output_dir,
)

launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)

shape = (size,)

if test_multi_device:
Expand All @@ -67,7 +60,7 @@ def _application(input, other, output):
other = torch.randn(shape, dtype=dtype, device=device)
output = torch.empty_like(input)

_run_launch_func(launch_func, input, other, output)
kernel(input, other, output)

expected = torch.add(input, other)

Expand All @@ -93,7 +86,7 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
kernel_name = f"addmm{_generate_kernel_name_suffix()}"
output_dir = ninetoothed.generation.CACHE_DIR

ninetoothed.make(
kernel = ninetoothed.make(
arrangement,
application,
tensors,
Expand All @@ -102,8 +95,6 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
output_dir=output_dir,
)

launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)

input = torch.randn((m, n), dtype=dtype, device=device)
mat1 = torch.randn((m, k), dtype=dtype, device=device)
mat2 = torch.randn((k, n), dtype=dtype, device=device)
Expand All @@ -113,7 +104,7 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
(mat1.shape[0], mat2.shape[1]), dtype=mat1.dtype, device=mat1.device
)

_run_launch_func(launch_func, input, mat1, mat2, beta, alpha, output)
kernel(input, mat1, mat2, beta, alpha, output)

expected = torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha)

Expand Down Expand Up @@ -155,7 +146,7 @@ def test_attention(
kernel_name = f"attention{_generate_kernel_name_suffix()}"
output_dir = ninetoothed.generation.CACHE_DIR

ninetoothed.make(
kernel = ninetoothed.make(
arrangement,
application,
tensors,
Expand All @@ -164,8 +155,6 @@ def test_attention(
output_dir=output_dir,
)

launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)

shape = (batch_size, num_heads, seq_len, emb_dim)

query = torch.randn(shape, dtype=dtype, device=device)
Expand All @@ -174,7 +163,7 @@ def test_attention(
is_causal = torch.tensor(True)
output = torch.empty(shape, dtype=dtype, device=device)

_run_launch_func(launch_func, query, key, value, is_causal, output)
kernel(query, key, value, is_causal, output)

expected = F.scaled_dot_product_attention(
query, key, value, is_causal=True, scale=1
Expand All @@ -200,7 +189,7 @@ def test_matmul(m, n, k, dtype, device, ninetoothed_dtype):
kernel_name = f"matmul{_generate_kernel_name_suffix()}"
output_dir = ninetoothed.generation.CACHE_DIR

ninetoothed.make(
kernel = ninetoothed.make(
arrangement,
application,
tensors,
Expand All @@ -209,13 +198,11 @@ def test_matmul(m, n, k, dtype, device, ninetoothed_dtype):
output_dir=output_dir,
)

launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)

lhs = torch.randn((m, k), dtype=dtype, device=device)
rhs = torch.randn((k, n), dtype=dtype, device=device)
output = torch.empty((lhs.shape[0], rhs.shape[1]), dtype=dtype, device=device)

_run_launch_func(launch_func, lhs, rhs, output)
kernel(lhs, rhs, output)

expected = torch.matmul(lhs, rhs)

Expand Down Expand Up @@ -266,7 +253,7 @@ def test_conv2d(
((), {"block_size_m": 128, "block_size_n": 32, "block_size_k": 64}, {}),
)

ninetoothed.build(
kernel = ninetoothed.build(
premake,
configs,
caller=caller,
Expand All @@ -276,7 +263,7 @@ def test_conv2d(
else:
arrangement, application, tensors = premake()

ninetoothed.make(
kernel = ninetoothed.make(
arrangement,
application,
tensors,
Expand All @@ -285,8 +272,6 @@ def test_conv2d(
output_dir=output_dir,
)

launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)

p = h - r + 1
q = w - s + 1

Expand All @@ -295,77 +280,17 @@ def test_conv2d(
output = torch.empty(n, k, p, q, dtype=dtype, device=device)

if test_build:
config = (
tuple(_DTYPE_MAPPING.keys()).index(ninetoothed_dtype),
constexpr_shapes,
) + tuple(configs[0][1].values())
config = (ninetoothed_dtype, constexpr_shapes) + tuple(configs[0][1].values())
else:
config = ()

_run_launch_func(launch_func, input, filter, output, *config)
kernel(input, filter, output, *config)

expected = F.conv2d(input, filter)

assert torch.allclose(output, expected, rtol=rtol, atol=atol)


class _ArgumentTensor(ctypes.Structure):
_fields_ = [
("data", ctypes.c_void_p),
("shape", ctypes.POINTER(ctypes.c_uint64)),
("strides", ctypes.POINTER(ctypes.c_int64)),
]

@staticmethod
def from_torch_tensor(tensor):
data = ctypes.c_void_p(tensor.data_ptr())
shape = (ctypes.c_uint64 * len(tensor.shape))(*tensor.shape)
strides = (ctypes.c_int64 * len(tensor.stride()))(*tensor.stride())

return _ArgumentTensor(data, shape, strides)


def _run_launch_func(launch_func, *args, **kwargs):
    """Invoke *launch_func* on a fresh CUDA stream, converting tensor arguments.

    Positional and keyword argument values are forwarded in order; torch
    tensors are wrapped as `_ArgumentTensor` structs, everything else is
    passed through unchanged.
    """
    stream = torch.cuda.Stream()

    converted = []
    for value in itertools.chain(args, kwargs.values()):
        if isinstance(value, torch.Tensor):
            converted.append(_ArgumentTensor.from_torch_tensor(value))
        else:
            converted.append(value)

    with torch.cuda.stream(stream):
        launch_func(ctypes.c_void_p(stream.cuda_stream), *converted)


def _generate_launch_func(kernel_name, output_dir):
    """Compile the kernel's generated sources and return the ctypes launch symbol."""
    output_dir = pathlib.Path(output_dir)

    _compile_library(kernel_name, output_dir)
    library = _load_library(kernel_name, output_dir)

    # The generated C++ exposes one entry point named after the kernel.
    return getattr(library, f"launch_{kernel_name}")


def _compile_library(kernel_name, output_dir):
    """Link the generated ``.cpp`` sources for *kernel_name* into a shared library."""
    flags = [
        "nvcc",
        "-shared",
        "-Xcompiler",
        "-fPIC",
        "-lcuda",
        "-o",
        output_dir / f"{kernel_name}.so",
    ]
    sources = list(output_dir.glob(f"{kernel_name}*.cpp"))

    # Raises CalledProcessError if nvcc fails.
    subprocess.run(flags + sources, check=True)


def _load_library(kernel_name, kernel_dir):
    """Open the compiled shared object for *kernel_name* and return its ctypes handle."""
    library_path = kernel_dir / f"{kernel_name}.so"
    return ctypes.CDLL(library_path)


def _generate_kernel_name_suffix():
count = _generate_kernel_name_suffix._kernel_count
_generate_kernel_name_suffix._kernel_count += 1
Expand Down