diff --git a/src/ninetoothed/aot.py b/src/ninetoothed/aot.py
index 3f9b0ce..1ecfea5 100644
--- a/src/ninetoothed/aot.py
+++ b/src/ninetoothed/aot.py
@@ -1,6 +1,9 @@
 import ast
+import ctypes
+import itertools
 import pathlib
 import re
+import shutil
 import subprocess
 import tempfile
 import textwrap
@@ -39,6 +42,8 @@ def aot(
         with open(output_path, "w") as f:
             f.write(output_content)
 
+    return _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
+
 
 def _aot(func, caller, kernel_name, num_warps, num_stages):
     def _find_tensor_by_source_name(tensors, name):
@@ -351,6 +356,29 @@ def visit_Lambda(self, node):
         return node
 
 
+class _ArgumentTensor(ctypes.Structure):
+    """C structure carrying a tensor's data pointer, shape, and strides."""
+
+    _fields_ = [
+        ("data", ctypes.c_void_p),
+        ("shape", ctypes.POINTER(ctypes.c_uint64)),
+        ("strides", ctypes.POINTER(ctypes.c_int64)),
+    ]
+
+    @staticmethod
+    def from_torch_tensor(tensor):
+        """Create an `_ArgumentTensor` describing `tensor`."""
+        data = ctypes.c_void_p(tensor.data_ptr())
+        shape = (ctypes.c_uint64 * len(tensor.shape))(*tensor.shape)
+        strides = (ctypes.c_int64 * len(tensor.stride()))(*tensor.stride())
+
+        arg_tensor = _ArgumentTensor(data, shape, strides)
+        # Keep the tensor referenced by the structure that points into its memory.
+        arg_tensor._torch_tensor = tensor
+
+        return arg_tensor
+
+
 def _compile(path, name, signature, grid, num_warps, num_stages):
     with tempfile.TemporaryDirectory() as temp_dir:
         output_dir = pathlib.Path(temp_dir)
@@ -389,3 +417,81 @@ def _compile(path, name, signature, grid, num_warps, num_stages):
                 output_contents[file.name.replace(output_name, name)] = f.read()
 
         return signature_hash, output_contents
+
+
+def _generate_launch_func(kernel_name, output_dir):
+    """Compile the kernel in `output_dir` and return a callable that launches it.
+
+    The returned callable wraps tensors as `_ArgumentTensor`, replaces data
+    type names with their position in `_DTYPE_MAPPING`, calls the library's
+    `launch_{kernel_name}` function with the current CUDA stream, and raises
+    `RuntimeError` when that function returns a non-zero code.
+    """
+    import torch
+
+    output_dir = pathlib.Path(output_dir)
+
+    _compile_library(kernel_name, output_dir)
+    library = _load_library(kernel_name, output_dir)
+    launch_func_name = f"launch_{kernel_name}"
+    launch_func = getattr(library, launch_func_name)
+
+    # Built once so that each launch does a lookup instead of a linear scan.
+    dtype_indices = {name: index for index, name in enumerate(_DTYPE_MAPPING)}
+
+    def _run_launch_func(*args, **kwargs):
+        arguments = []
+
+        # Keyword arguments are forwarded positionally, in the order given.
+        for arg in itertools.chain(args, kwargs.values()):
+            if isinstance(arg, torch.Tensor):
+                argument = _ArgumentTensor.from_torch_tensor(arg)
+            elif isinstance(arg, str) and arg in dtype_indices:
+                argument = dtype_indices[arg]
+            else:
+                argument = arg
+
+            arguments.append(argument)
+
+        result = launch_func(
+            ctypes.c_void_p(torch.cuda.current_stream().cuda_stream), *arguments
+        )
+
+        if result != 0:
+            raise RuntimeError(f"Kernel launch failed with error code: {result}.")
+
+    return _run_launch_func
+
+
+def _compile_library(kernel_name, output_dir):
+    """Compile the kernel's `.cpp` files in `output_dir` into a shared library."""
+    command = [
+        "nvcc",
+        "-shared",
+        "-Xcompiler",
+        "-fPIC",
+        "-lcuda",
+        "-o",
+        output_dir / f"{kernel_name}.so",
+    ] + list(output_dir.glob(f"{kernel_name}*.cpp"))
+
+    subprocess.run(command, check=True)
+
+
+def _load_library(kernel_name, kernel_dir):
+    """Load `{kernel_name}.so` from `kernel_dir` through a temporary copy.
+
+    The copy is made and opened while the temporary file still exists, so its
+    name is never released and recreated, and the copy is deleted afterwards.
+    """
+    suffix = ".so"
+
+    original_path = kernel_dir / f"{kernel_name}{suffix}"
+
+    with tempfile.NamedTemporaryFile(suffix=suffix) as temp_file:
+        shutil.copy(original_path, temp_file.name)
+
+        # On POSIX, the loaded library stays mapped after its file is unlinked.
+        library = ctypes.CDLL(temp_file.name)
+
+    return library
diff --git a/src/ninetoothed/build.py b/src/ninetoothed/build.py
index 116fe73..d55a590 100644
--- a/src/ninetoothed/build.py
+++ b/src/ninetoothed/build.py
@@ -5,7 +5,12 @@
 import pathlib
 
 import ninetoothed
-from ninetoothed.aot import _DTYPE_MAPPING, _HEADER_PATH, _MACRO_MAPPING
+from ninetoothed.aot import (
+    _DTYPE_MAPPING,
+    _HEADER_PATH,
+    _MACRO_MAPPING,
+    _generate_launch_func,
+)
 
 
 def build(premake, configs, *, caller=None, kernel_name=None, output_dir=None):
@@ -102,6 +107,8 @@ def build(premake, configs, *, caller=None, kernel_name=None, output_dir=None):
     (output_dir / source_file_name).write_text(source_content)
     (output_dir / header_file_name).write_text(header_content)
 
+    return _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
+
 
 def _make(premake, config, caller, kernel_name, output_dir):
     args, kwargs, compilation_configs = config
diff --git a/tests/test_aot.py b/tests/test_aot.py
index 6356507..9a217df 100644
--- a/tests/test_aot.py
+++ b/tests/test_aot.py
@@ -1,8 +1,4 @@
-import ctypes
 import functools
-import itertools
-import pathlib
-import subprocess
 
 import pytest
 import torch
@@ -15,7 +11,6 @@
 import tests.test_conv2d as conv2d
 import tests.test_matmul as matmul
 from ninetoothed import Tensor
-from ninetoothed.aot import _DTYPE_MAPPING
 from tests.utils import get_available_devices
 
 
@@ -40,7 +35,7 @@ def _application(input, other, output):
     kernel_name = f"add{_generate_kernel_name_suffix()}"
     output_dir = ninetoothed.generation.CACHE_DIR
 
-    ninetoothed.make(
+    kernel = ninetoothed.make(
         _arrangement,
         _application,
         tensors,
@@ -49,8 +44,6 @@ def _application(input, other, output):
         output_dir=output_dir,
     )
 
-    launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
-
     shape = (size,)
 
     if test_multi_device:
@@ -67,7 +60,7 @@ def _application(input, other, output):
     other = torch.randn(shape, dtype=dtype, device=device)
     output = torch.empty_like(input)
 
-    _run_launch_func(launch_func, input, other, output)
+    kernel(input, other, output)
 
     expected = torch.add(input, other)
 
@@ -93,7 +86,7 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
     kernel_name = f"addmm{_generate_kernel_name_suffix()}"
     output_dir = ninetoothed.generation.CACHE_DIR
 
-    ninetoothed.make(
+    kernel = ninetoothed.make(
         arrangement,
         application,
         tensors,
@@ -102,8 +95,6 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
         output_dir=output_dir,
     )
 
-    launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
-
     input = torch.randn((m, n), dtype=dtype, device=device)
     mat1 = torch.randn((m, k), dtype=dtype, device=device)
     mat2 = torch.randn((k, n), dtype=dtype, device=device)
@@ -113,7 +104,7 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
         (mat1.shape[0], mat2.shape[1]), dtype=mat1.dtype, device=mat1.device
     )
 
-    _run_launch_func(launch_func, input, mat1, mat2, beta, alpha, output)
+    kernel(input, mat1, mat2, beta, alpha, output)
 
     expected = torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
 
@@ -155,7 +146,7 @@ def test_attention(
     kernel_name = f"attention{_generate_kernel_name_suffix()}"
     output_dir = ninetoothed.generation.CACHE_DIR
 
-    ninetoothed.make(
+    kernel = ninetoothed.make(
         arrangement,
         application,
         tensors,
@@ -164,8 +155,6 @@ def test_attention(
         output_dir=output_dir,
     )
 
-    launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
-
     shape = (batch_size, num_heads, seq_len, emb_dim)
 
     query = torch.randn(shape, dtype=dtype, device=device)
@@ -174,7 +163,7 @@ def test_attention(
     is_causal = torch.tensor(True)
     output = torch.empty(shape, dtype=dtype, device=device)
 
-    _run_launch_func(launch_func, query, key, value, is_causal, output)
+    kernel(query, key, value, is_causal, output)
 
     expected = F.scaled_dot_product_attention(
         query, key, value, is_causal=True, scale=1
@@ -200,7 +189,7 @@ def test_matmul(m, n, k, dtype, device, ninetoothed_dtype):
     kernel_name = f"matmul{_generate_kernel_name_suffix()}"
     output_dir = ninetoothed.generation.CACHE_DIR
 
-    ninetoothed.make(
+    kernel = ninetoothed.make(
         arrangement,
         application,
         tensors,
@@ -209,13 +198,11 @@ def test_matmul(m, n, k, dtype, device, ninetoothed_dtype):
         output_dir=output_dir,
     )
 
-    launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
-
     lhs = torch.randn((m, k), dtype=dtype, device=device)
     rhs = torch.randn((k, n), dtype=dtype, device=device)
     output = torch.empty((lhs.shape[0], rhs.shape[1]), dtype=dtype, device=device)
 
-    _run_launch_func(launch_func, lhs, rhs, output)
+    kernel(lhs, rhs, output)
 
     expected = torch.matmul(lhs, rhs)
 
@@ -266,7 +253,7 @@ def test_conv2d(
             ((), {"block_size_m": 128, "block_size_n": 32, "block_size_k": 64}, {}),
         )
 
-        ninetoothed.build(
+        kernel = ninetoothed.build(
             premake,
             configs,
             caller=caller,
@@ -276,7 +263,7 @@ def test_conv2d(
     else:
         arrangement, application, tensors = premake()
 
-        ninetoothed.make(
+        kernel = ninetoothed.make(
             arrangement,
             application,
             tensors,
@@ -285,8 +272,6 @@ def test_conv2d(
             output_dir=output_dir,
         )
 
-    launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
-
     p = h - r + 1
     q = w - s + 1
 
@@ -295,77 +280,17 @@ def test_conv2d(
     output = torch.empty(n, k, p, q, dtype=dtype, device=device)
 
     if test_build:
-        config = (
-            tuple(_DTYPE_MAPPING.keys()).index(ninetoothed_dtype),
-            constexpr_shapes,
-        ) + tuple(configs[0][1].values())
+        config = (ninetoothed_dtype, constexpr_shapes) + tuple(configs[0][1].values())
     else:
         config = ()
 
-    _run_launch_func(launch_func, input, filter, output, *config)
+    kernel(input, filter, output, *config)
 
     expected = F.conv2d(input, filter)
 
     assert torch.allclose(output, expected, rtol=rtol, atol=atol)
 
 
-class _ArgumentTensor(ctypes.Structure):
-    _fields_ = [
-        ("data", ctypes.c_void_p),
-        ("shape", ctypes.POINTER(ctypes.c_uint64)),
-        ("strides", ctypes.POINTER(ctypes.c_int64)),
-    ]
-
-    @staticmethod
-    def from_torch_tensor(tensor):
-        data = ctypes.c_void_p(tensor.data_ptr())
-        shape = (ctypes.c_uint64 * len(tensor.shape))(*tensor.shape)
-        strides = (ctypes.c_int64 * len(tensor.stride()))(*tensor.stride())
-
-        return _ArgumentTensor(data, shape, strides)
-
-
-def _run_launch_func(launch_func, *args, **kwargs):
-    stream = torch.cuda.Stream()
-
-    arguments = tuple(
-        _ArgumentTensor.from_torch_tensor(arg) if isinstance(arg, torch.Tensor) else arg
-        for arg in itertools.chain(args, kwargs.values())
-    )
-
-    with torch.cuda.stream(stream):
-        launch_func(ctypes.c_void_p(stream.cuda_stream), *arguments)
-
-
-def _generate_launch_func(kernel_name, output_dir):
-    output_dir = pathlib.Path(output_dir)
-
-    _compile_library(kernel_name, output_dir)
-    library = _load_library(kernel_name, output_dir)
-    launch_func_name = f"launch_{kernel_name}"
-    launch_func = getattr(library, launch_func_name)
-
-    return launch_func
-
-
-def _compile_library(kernel_name, output_dir):
-    command = [
-        "nvcc",
-        "-shared",
-        "-Xcompiler",
-        "-fPIC",
-        "-lcuda",
-        "-o",
-        output_dir / f"{kernel_name}.so",
-    ] + list(output_dir.glob(f"{kernel_name}*.cpp"))
-
-    subprocess.run(command, check=True)
-
-
-def _load_library(kernel_name, kernel_dir):
-    return ctypes.CDLL(kernel_dir / f"{kernel_name}.so")
-
-
 def _generate_kernel_name_suffix():
     count = _generate_kernel_name_suffix._kernel_count
     _generate_kernel_name_suffix._kernel_count += 1