diff --git a/.gitignore b/.gitignore index 82f9275..fbfdfdd 100644 --- a/.gitignore +++ b/.gitignore @@ -160,3 +160,10 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# Evaluation results +*.csv +*.html +*.json +*.png +*.tex diff --git a/README.md b/README.md index 03557df..dcf435c 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,20 @@ # NineToothed Examples -This repository contains examples for [NineToothed](https://github.com/InfiniTensor/ninetoothed), including implementations of several common compute kernels written using NineToothed. +This repository contains examples of [NineToothed](https://github.com/InfiniTensor/ninetoothed), including implementations of several common compute kernels written using NineToothed. ## Usage After cloning this repository, you can run any of the examples using Python. For instance, to run the matrix multiplication example, execute the following command: ```bash -python matmul.py +python mm.py ``` ### Autotuning Behavior -By default, the examples apply autotuning, which may take several minutes or longer to complete for complex kernels. If you wish to disable autotuning, you can replace symbol definitions with concrete values. Consider the following example: +Some examples apply autotuning, which may take several minutes or longer to complete for complex kernels. If you wish to disable autotuning, you can replace symbol definitions with concrete values. + +Consider the following example: ```python BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True) @@ -29,6 +31,8 @@ BLOCK_SIZE = 1024 These approaches allow you to obtain results in seconds. However, selecting optimal values is crucial for good performance. Experiment with different values to determine the best configuration. +Note: Please don't forget to also disable the autotuning of the corresponding Triton compute kernels. + ## Third-Party Code and Licenses This project includes code modified or inspired from the following open-source repositories: diff --git a/add.py b/add.py index 06e244c..9b34f89 100644 --- a/add.py +++ b/add.py @@ -1,118 +1,70 @@ -import ninetoothed import torch import triton -import triton.language as tl -from ninetoothed import Symbol, Tensor -BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True) - - -@ninetoothed.jit -def add_kernel( - lhs: Tensor(1).tile((BLOCK_SIZE,)), - rhs: Tensor(1).tile((BLOCK_SIZE,)), - output: Tensor(1).tile((BLOCK_SIZE,)), -): - output = lhs + rhs # noqa: F841 - - -def add(lhs, rhs): - output = torch.empty_like(lhs) - - add_kernel(lhs, rhs, output) - - return output - - -@triton.jit -def triton_add_kernel( - lhs_ptr, - rhs_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - - lhs = tl.load(lhs_ptr + offsets, mask=mask) - rhs = tl.load(rhs_ptr + offsets, mask=mask) - output = lhs + rhs - - tl.store(output_ptr + offsets, output, mask=mask) - - -def triton_add(lhs, rhs): - output = torch.empty_like(lhs) - n_elements = output.numel() - - def grid(meta): - return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - - triton_add_kernel[grid](lhs, rhs, output, n_elements, BLOCK_SIZE=1024) - - return output - - -torch.manual_seed(0) -size = 98432 -lhs = torch.rand(size, device="cuda") -rhs = torch.rand(size, device="cuda") -ninetoothed_output = add(lhs, rhs) -torch_output = lhs + rhs -triton_output = triton_add(lhs, rhs) -print(ninetoothed_output) -print(torch_output) -print(triton_output) -if torch.allclose(ninetoothed_output, torch_output): - print("✅ NineToothed and PyTorch match.") -else: - print("❌ NineToothed and PyTorch differ.") -if torch.allclose(ninetoothed_output, triton_output): - print("✅ NineToothed and Triton match.") -else: - print("❌ NineToothed and Triton differ.") - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["size"], - x_vals=[2**i for i in range(12, 28, 1)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="GB/s", - plot_name="vector-addition-performance", - args={}, +import ops.ninetoothed.torch +import ops.triton.torch + +if __name__ == "__main__": + torch.manual_seed(0) + + size = 98432 + dtype = torch.float16 + device = "cuda" + + input = torch.randn(size, dtype=dtype, device=device) + other = torch.randn(size, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.add(input, other) + torch_output = input + other + triton_output = ops.triton.torch.add(input, other) + + print(ninetoothed_output) + print(torch_output) + print(triton_output) + + if torch.allclose(ninetoothed_output, torch_output): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["size"], + x_vals=[2**i for i in range(18, 28)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="add-performance", + args={}, + ) ) -) -def benchmark(size, provider): - lhs = torch.rand(size, device="cuda", dtype=torch.float32) - rhs = torch.rand(size, device="cuda", dtype=torch.float32) - quantiles = [0.5, 0.2, 0.8] + def benchmark(size, provider): + input = torch.randn(size, dtype=dtype, device=device) + other = torch.randn(size, dtype=dtype, device=device) - if provider == "ninetoothed": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: add(lhs, rhs), quantiles=quantiles - ) - elif provider == "torch": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: lhs + rhs, quantiles=quantiles - ) - elif provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: triton_add(lhs, rhs), quantiles=quantiles - ) + ninetoothed_output = ops.ninetoothed.torch.add(input, other) + torch_output = torch.add(input, other) + triton_output = ops.triton.torch.add(input, other) - def gbps(ms): - return 3 * lhs.numel() * lhs.element_size() / ms * 1e-6 + assert torch.allclose(ninetoothed_output, torch_output) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) - return gbps(ms), gbps(max_ms), gbps(min_ms) + if provider == "ninetoothed": + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.add(input, other) + ) + elif provider == "torch": + ms = triton.testing.do_bench(lambda: torch.add(input, other)) + elif provider == "triton": + ms = triton.testing.do_bench(lambda: ops.triton.torch.add(input, other)) + return ms -benchmark.run(print_data=True, show_plots=True, save_path=".") + benchmark.run(print_data=True, show_plots=True, save_path=".") diff --git a/addmm.py b/addmm.py index 13579a8..59b21ad 100644 --- a/addmm.py +++ b/addmm.py @@ -1,258 +1,35 @@ import random -import ninetoothed import torch import triton -import triton.language as tl -from ninetoothed import Tensor - -import matmul - - -def arrangement(input, mat1, mat2, beta, alpha, output): - _, _, input_arranged = matmul.arrangement(mat1, mat2, input) - - mat1_arrange, mat2_arranged, output_arranged = matmul.arrangement( - mat1, mat2, output - ) - - return input_arranged, mat1_arrange, mat2_arranged, beta, alpha, output_arranged - - -def application(input, mat1, mat2, beta, alpha, output): - matmul.application(mat1, mat2, output) - output = beta * input + alpha * output - - -addmm_kernel = ninetoothed.make( - arrangement, - application, - (Tensor(2), Tensor(2), Tensor(2), Tensor(0), Tensor(0), Tensor(2)), -) - - -def addmm(input, mat1, mat2, beta=1, alpha=1): - output = torch.empty( - (mat1.shape[0], mat2.shape[1]), dtype=torch.float16, device=mat1.device - ) - - addmm_kernel(input, mat1, mat2, beta, alpha, output) - - return output - - -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - ], - key=["m", "n", "k"], -) -@triton.jit -def triton_addmm_kernel( - input_ptr, - mat1_ptr, - mat2_ptr, - output_ptr, - m, - n, - k, - input_stride_m, - input_stride_n, - mat1_stride_m, - mat1_stride_k, - mat2_stride_k, - mat2_stride_n, - output_stride_m, - output_stride_n, - beta, - alpha, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - pid = tl.program_id(0) - num_pid_m = tl.cdiv(m, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(n, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n - offs_k = tl.arange(0, BLOCK_SIZE_K) - mat1_ptrs = mat1_ptr + ( - offs_am[:, None] * mat1_stride_m + offs_k[None, :] * mat1_stride_k - ) - mat2_ptrs = mat2_ptr + ( - offs_k[:, None] * mat2_stride_k + offs_bn[None, :] * mat2_stride_n - ) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)): - mat1 = tl.load( - mat1_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K, other=0.0 - ) - mat2 = tl.load( - mat2_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K, other=0.0 - ) - accumulator = tl.dot(mat1, mat2, accumulator) - mat1_ptrs += BLOCK_SIZE_K * mat1_stride_k - mat2_ptrs += BLOCK_SIZE_K * mat2_stride_k - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - mask_c = (offs_cm[:, None] < m) & (offs_cn[None, :] < n) - - input_ptrs = ( - input_ptr - + input_stride_m * offs_cm[:, None] - + input_stride_n * offs_cn[None, :] - ) - input = tl.load(input_ptrs, mask=mask_c) - - output = beta * input + alpha * accumulator.to(tl.float16) - - output_ptrs = ( - output_ptr - + output_stride_m * offs_cm[:, None] - + output_stride_n * offs_cn[None, :] - ) - tl.store(output_ptrs, output, mask=mask_c) - - -def triton_addmm(input, mat1, mat2, beta=1, alpha=1): - output = torch.empty( - (mat1.shape[0], mat2.shape[1]), dtype=torch.float16, device=mat1.device - ) - - def grid(meta): - return ( - triton.cdiv(mat1.shape[0], meta["BLOCK_SIZE_M"]) - * triton.cdiv(mat2.shape[1], meta["BLOCK_SIZE_N"]), - ) - - triton_addmm_kernel[grid]( - input, - mat1, - mat2, - output, - mat1.shape[0], - mat2.shape[1], - mat1.shape[1], - input.stride(0), - input.stride(1), - mat1.stride(0), - mat1.stride(1), - mat2.stride(0), - mat2.stride(1), - output.stride(0), - output.stride(1), - beta, - alpha, - ) - - return output +import ops.ninetoothed.torch +import ops.triton.torch if __name__ == "__main__": random.seed(0) torch.manual_seed(0) + shape = (512, 512) - input = torch.randn(shape, dtype=torch.float16, device="cuda") - mat1 = torch.randn(shape, dtype=torch.float16, device="cuda") - mat2 = torch.randn(shape, dtype=torch.float16, device="cuda") + dtype = torch.float16 + device = "cuda" + + input = torch.randn(shape, dtype=dtype, device=device) + mat1 = torch.randn(shape, dtype=dtype, device=device) + mat2 = torch.randn(shape, dtype=dtype, device=device) beta = random.uniform(0, 1) alpha = random.uniform(0, 1) - ninetoothed_output = addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + ninetoothed_output = ops.ninetoothed.torch.addmm( + input, mat1, mat2, beta=beta, alpha=alpha + ) torch_output = torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha) - triton_output = triton_addmm(input, mat1, mat2, beta=beta, alpha=alpha) + triton_output = ops.triton.torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + print(ninetoothed_output) print(torch_output) print(triton_output) + if torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01): print("✅ NineToothed and PyTorch match.") else: @@ -270,38 +47,35 @@ def grid(meta): line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="TFLOPS", + ylabel="ms", plot_name="addmm-performance", args={}, ) ) def benchmark(m, n, k, provider): - input = torch.randn((m, n), dtype=torch.float16, device="cuda") - mat1 = torch.randn((m, k), dtype=torch.float16, device="cuda") - mat2 = torch.randn((k, n), dtype=torch.float16, device="cuda") + input = torch.randn((m, n), dtype=dtype, device=device) + mat1 = torch.randn((m, k), dtype=dtype, device=device) + mat2 = torch.randn((k, n), dtype=dtype, device=device) beta = random.uniform(0, 1) alpha = random.uniform(0, 1) - quantiles = [0.5, 0.2, 0.8] if provider == "ninetoothed": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: addmm(input, mat1, mat2, beta=beta, alpha=alpha), - quantiles=quantiles, + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.addmm( + input, mat1, mat2, beta=beta, alpha=alpha + ) ) elif provider == "torch": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha), - quantiles=quantiles, + ms = triton.testing.do_bench( + lambda: torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha) ) elif provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: triton_addmm(input, mat1, mat2, beta=beta, alpha=alpha), - quantiles=quantiles, + ms = triton.testing.do_bench( + lambda: ops.triton.torch.addmm( + input, mat1, mat2, beta=beta, alpha=alpha + ) ) - def perf(ms): - return (2 * m * n * k + 3 * m * n) * 1e-12 / (ms * 1e-3) - - return perf(ms), perf(max_ms), perf(min_ms) + return ms benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/attention.py b/attention.py deleted file mode 100644 index 48e51ec..0000000 --- a/attention.py +++ /dev/null @@ -1,375 +0,0 @@ -import functools -import math - -import ninetoothed -import ninetoothed.language as ntl -import torch -import torch.nn as nn -import torch.nn.functional as F -import triton -import triton.language as tl -from ninetoothed import Symbol, Tensor -from transformers.models.llama.modeling_llama import repeat_kv - -import rope - - -def arrangement(q, k, v, scale, o): - BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) - BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) - - def arrange_q_or_o(input): - arranged = input.tile((1, 1, BLOCK_SIZE_M, -1)) - arranged.dtype = arranged.dtype.squeeze((0, 1)) - - return arranged - - def arrange_k_or_v(input): - arranged = ( - input.tile((1, 1, BLOCK_SIZE_N, -1)) - .tile((1, 1, -1, -1)) - .expand((-1, -1, q_arranged.shape[-2], -1)) - ) - arranged.dtype = arranged.dtype.squeeze((0, 1, 3)) - arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) - - return arranged - - q_arranged = arrange_q_or_o(q) - - return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), scale, arrange_q_or_o(o) - - -def application(q, k, v, scale, o): - q_loaded = (q * scale * 1.44269504089).to(ntl.float16) - - acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32) - l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32) - m_i = ntl.full((q.shape[-2],), float("-inf"), dtype=ntl.float32) - - for i in range(k.shape[0]): - qk = ntl.dot(q_loaded, ntl.trans(k[i])) - - m_ij = ntl.maximum(m_i, ntl.max(qk, 1)) - p = ntl.exp2(qk - m_ij[:, None]) - l_ij = ntl.sum(p, 1) - - alpha = ntl.exp2(m_i - m_ij) - acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i]) - m_i = m_ij - l_i = l_i * alpha + l_ij - - acc /= l_i[:, None] - o = acc # noqa: F841 - - -q, k, v, o = ( - Tensor(4, shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128})) - for _ in range(4) -) -attention_kernel = ninetoothed.make(arrangement, application, (q, k, v, Tensor(0), o)) - - -def attention(q, k, v, scale=None): - if scale is None: - scale = 1 / math.sqrt(q.shape[-1]) - - o = torch.empty_like(q, dtype=v.dtype) - - attention_kernel(q, k, v, scale, o) - - return o - - -class Attention(nn.Module): - def __init__(self, other): - super().__init__() - - self.__dict__ = other.__dict__ - - def forward( - self, - hidden_states, - position_embeddings, - attention_mask, - past_key_value, - cache_position, - **kwargs, - ): - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape) - key_states = self.k_proj(hidden_states).view(hidden_shape) - value_states = self.v_proj(hidden_states).view(hidden_shape) - - cos_table, sin_table = position_embeddings - - _rope(query_states, sin_table, cos_table) - _rope(key_states, sin_table, cos_table) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if past_key_value is not None: - cache_kwargs = { - "sin": sin_table, - "cos": cos_table, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_dtype = torch.float16 - attn_output = attention( - query_states.to(attn_dtype), - key_states.to(attn_dtype), - value_states.to(attn_dtype), - scale=self.scaling, - ).to(query_states.dtype) - attn_output = attn_output.transpose(1, 2) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - - return attn_output, None - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=8 - ), - ], - key=["EMB_DIM"], -) -@triton.jit -def triton_attention_kernel( - q_ptr, - k_ptr, - v_ptr, - o_ptr, - q_stride_z, - q_stride_h, - q_stride_m, - q_stride_k, - k_stride_z, - k_stride_h, - k_stride_n, - k_stride_k, - v_stride_z, - v_stride_h, - v_stride_k, - v_stride_n, - o_stride_z, - o_stride_h, - o_stride_m, - o_stride_n, - scale, - SEQ_LEN: tl.constexpr, - EMB_DIM: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, -): - off_m = tl.program_id(0) - off_h = tl.program_id(1) - off_z = tl.program_id(2) - - offs_m_start = off_m * BLOCK_SIZE_M - - q_off = off_z * q_stride_z + off_h * q_stride_h - q_block_ptr = tl.make_block_ptr( - base=q_ptr + q_off, - shape=(SEQ_LEN, EMB_DIM), - strides=(q_stride_m, q_stride_k), - offsets=(offs_m_start, 0), - block_shape=(BLOCK_SIZE_M, EMB_DIM), - order=(1, 0), - ) - k_off = off_z * k_stride_z + off_h * k_stride_h - k_block_ptr = tl.make_block_ptr( - base=k_ptr + k_off, - shape=(EMB_DIM, SEQ_LEN), - strides=(k_stride_k, k_stride_n), - offsets=(0, 0), - block_shape=(EMB_DIM, BLOCK_SIZE_N), - order=(0, 1), - ) - v_off = off_z * v_stride_z + off_h * v_stride_h - v_block_ptr = tl.make_block_ptr( - base=v_ptr + v_off, - shape=(SEQ_LEN, EMB_DIM), - strides=(v_stride_k, v_stride_n), - offsets=(0, 0), - block_shape=(BLOCK_SIZE_N, EMB_DIM), - order=(1, 0), - ) - o_off = off_z * o_stride_z + off_h * o_stride_h - o_block_ptr = tl.make_block_ptr( - base=o_ptr + o_off, - shape=(SEQ_LEN, EMB_DIM), - strides=(o_stride_m, o_stride_n), - offsets=(offs_m_start, 0), - block_shape=(BLOCK_SIZE_M, EMB_DIM), - order=(1, 0), - ) - - q = (tl.load(q_block_ptr) * scale * 1.44269504089).to(q_block_ptr.type.element_ty) - - acc = tl.zeros((BLOCK_SIZE_M, EMB_DIM), dtype=tl.float32) - l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32) - m_i = tl.full((BLOCK_SIZE_M,), float("-inf"), dtype=tl.float32) - - for _ in range(0, tl.cdiv(SEQ_LEN, BLOCK_SIZE_N)): - k = tl.load(k_block_ptr) - - qk = tl.dot(q, k) - - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - p = tl.exp2(qk) - l_ij = tl.sum(p, 1) - - v = tl.load(v_block_ptr) - alpha = tl.exp2(m_i - m_ij) - acc = acc * alpha[:, None] + tl.dot(p.to(v_block_ptr.type.element_ty), v) - m_i = m_ij - l_i = l_i * alpha + l_ij - - v_block_ptr = tl.advance(v_block_ptr, (BLOCK_SIZE_N, 0)) - k_block_ptr = tl.advance(k_block_ptr, (0, BLOCK_SIZE_N)) - - acc /= l_i[:, None] - - tl.store(o_block_ptr, acc.to(o_ptr.type.element_ty)) - - -def triton_attention(q, k, v, scale=None): - o = torch.empty_like(q) - - batch_size, num_heads, seq_len, emb_dim = q.shape - - if scale is None: - scale = 1 / math.sqrt(emb_dim) - - def grid(meta): - return ( - triton.cdiv(seq_len, meta["BLOCK_SIZE_M"]), - num_heads, - batch_size, - ) - - triton_attention_kernel[grid]( - q, - k, - v, - o, - *q.stride(), - *k.stride(), - *v.stride(), - *o.stride(), - scale=scale, - SEQ_LEN=seq_len, - EMB_DIM=emb_dim, - ) - - return o - - -_rope_kernel = ninetoothed.make( - functools.partial(rope.arrangement, interleaved=False), - rope.application, - rope.tensors, -) - - -def _rope(x, sin_table, cos_table): - _, _, num_heads, _ = x.shape - sin_table = sin_table.unsqueeze(2).expand(-1, -1, num_heads, -1) - cos_table = cos_table.unsqueeze(2).expand(-1, -1, num_heads, -1) - - _rope_kernel(x, sin_table, cos_table) - - -if __name__ == "__main__": - torch.manual_seed(0) - shape = (2, 4, 1024, 64) - dtype = torch.float16 - q = torch.randn(shape, dtype=dtype, device="cuda") - k = torch.randn(shape, dtype=dtype, device="cuda") - v = torch.randn(shape, dtype=dtype, device="cuda") - - ninetoothed_output = attention(q, k, v) - torch_output = F.scaled_dot_product_attention(q, k, v) - triton_output = triton_attention(q, k, v) - print(ninetoothed_output) - print(torch_output) - print(triton_output) - if torch.allclose(ninetoothed_output, torch_output, atol=0.01): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0.01): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["seq_len"], - x_vals=[2**i for i in range(10, 15)], - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="TFLOPS", - plot_name="attention-performance", - args={}, - ) - ) - def benchmark(seq_len, provider): - batch_size, num_heads, emb_dim = 4, 32, 64 - shape = (batch_size, num_heads, seq_len, emb_dim) - dtype = torch.float16 - q = torch.randn(shape, dtype=dtype, device="cuda") - k = torch.randn(shape, dtype=dtype, device="cuda") - v = torch.randn(shape, dtype=dtype, device="cuda") - - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: attention(q, k, v)) - elif provider == "torch": - ms = triton.testing.do_bench( - lambda: F.scaled_dot_product_attention(q, k, v, scale=1) - ) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_attention(q, k, v)) - - def perf(ms): - flops_per_matmul = 2 * batch_size * num_heads * seq_len * seq_len * emb_dim - total_flops = 2 * flops_per_matmul - - return total_flops * 1e-12 / (ms * 1e-3) - - return perf(ms) - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/bmm.py b/bmm.py index 7cc1b30..3243511 100644 --- a/bmm.py +++ b/bmm.py @@ -1,54 +1,8 @@ -import ninetoothed import torch -from ninetoothed import Symbol, Tensor - -import matmul - -BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) -BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) -BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True) - - -def arrangement( - lhs, - rhs, - output, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, -): - output_arranged = output.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_N)) - output_arranged.dtype = output_arranged.dtype.squeeze(0) - - lhs_arranged = lhs.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_K)) - lhs_arranged = lhs_arranged.tile((1, 1, -1)) - lhs_arranged = lhs_arranged.expand((-1, -1, output_arranged.shape[-1])) - lhs_arranged.dtype = lhs_arranged.dtype.squeeze((0, 1)) - lhs_arranged.dtype.dtype = lhs_arranged.dtype.dtype.squeeze(0) - - rhs_arranged = rhs.tile((1, BLOCK_SIZE_K, BLOCK_SIZE_N)) - rhs_arranged = rhs_arranged.tile((1, -1, 1)) - rhs_arranged = rhs_arranged.expand((-1, output_arranged.shape[-2], -1)) - rhs_arranged.dtype = rhs_arranged.dtype.squeeze((0, 2)) - rhs_arranged.dtype.dtype = rhs_arranged.dtype.dtype.squeeze(0) - - return lhs_arranged, rhs_arranged, output_arranged - - -bmm_kernel = ninetoothed.make( - arrangement, matmul.application, (Tensor(3), Tensor(3), Tensor(3)) -) - - -def bmm(lhs, rhs): - output = torch.empty( - (lhs.shape[0], lhs.shape[-2], rhs.shape[-1]), dtype=lhs.dtype, device=lhs.device - ) - - bmm_kernel(lhs, rhs, output) - - return output +import triton +import ops.ninetoothed.torch +import ops.triton.torch if __name__ == "__main__": torch.manual_seed(0) @@ -56,16 +10,61 @@ def bmm(lhs, rhs): batch_size, m, n, k = 4, 512, 2028, 1024 dtype = torch.float16 device = "cuda" - lhs = torch.randn(batch_size, m, k, dtype=dtype, device=device) - rhs = torch.randn(batch_size, k, n, dtype=dtype, device=device) - ninetoothed_output = bmm(lhs, rhs) - torch_output = torch.bmm(lhs, rhs) + input = torch.randn(batch_size, m, k, dtype=dtype, device=device) + other = torch.randn(batch_size, k, n, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.bmm(input, other) + torch_output = torch.bmm(input, other) + triton_output = ops.triton.torch.bmm(input, other) print(ninetoothed_output) print(torch_output) + print(triton_output) if torch.allclose(ninetoothed_output, torch_output): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["m", "n", "k"], + x_vals=[2**i for i in range(3, 13)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="bmm-performance", + args={"b": 4}, + ) + ) + def benchmark(b, m, n, k, provider): + input = torch.randn((b, m, k), dtype=dtype, device=device) + other = torch.randn((b, k, n), dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.bmm(input, other) + torch_output = torch.bmm(input, other) + triton_output = ops.triton.torch.bmm(input, other) + + assert torch.allclose(ninetoothed_output, torch_output) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + + if provider == "ninetoothed": + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.bmm(input, other) + ) + elif provider == "torch": + ms = triton.testing.do_bench(lambda: torch.bmm(input, other)) + elif provider == "triton": + ms = triton.testing.do_bench(lambda: ops.triton.torch.bmm(input, other)) + + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/conv2d.py b/conv2d.py index d05a146..c7bd8bb 100644 --- a/conv2d.py +++ b/conv2d.py @@ -1,336 +1,75 @@ -import ninetoothed import torch import torch.nn.functional as F import triton -import triton.language as tl -from ninetoothed import Tensor - -import matmul - - -def arrangement(input, filter, output): - input_tiled = input.tile((1, *filter.shape[1:]), strides=(-1, -1, 1, 1)) - input_squeezed = input_tiled.squeeze(1) - input_squeezed.dtype = input_squeezed.dtype.squeeze(0) - input_raveled = input_squeezed.ravel() - input_flattened = input_raveled.flatten(end_dim=3).flatten(start_dim=1) - - filter_flattened = filter.flatten(start_dim=1) - filter_permuted = filter_flattened.permute((1, 0)) - - output_flattened = output.permute((0, 2, 3, 1)).flatten(end_dim=3) - - return matmul.arrangement(input_flattened, filter_permuted, output_flattened) - - -filter_shape_options = ( - None, - None, - {"constexpr": True, "upper_bound": 16}, - {"constexpr": True, "upper_bound": 16}, -) -tensors = (Tensor(4), Tensor(4, shape_options=filter_shape_options), Tensor(4)) -conv2d_kernel = ninetoothed.make(arrangement, matmul.application, tensors) - - -def conv2d(input, filter): - n, _, h, w = input.shape - k, _, r, s = filter.shape - p = h - r + 1 - q = w - s + 1 - - output = torch.empty((n, k, p, q), device=input.device, dtype=input.dtype) - - conv2d_kernel(input, filter, output) - - return output - - -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - ], - key=["n", "c", "h", "w", "k", "r", "s"], -) -@triton.jit -def triton_conv2d_kernel( - input_ptr, - filter_ptr, - output_ptr, - n, - c, - h, - w, - k, - r, - s, - input_stride_n, - input_stride_c, - input_stride_h, - input_stride_w, - filter_stride_k, - filter_stride_c, - filter_stride_r, - filter_stride_s, - output_stride_n, - output_stride_k, - output_stride_p, - output_stride_q, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - p = h - r + 1 - q = w - s + 1 - - gemm_m = n * p * q - gemm_n = k - gemm_k = c * r * s - - pid = tl.program_id(0) - num_pid_gemm_m = tl.cdiv(gemm_m, BLOCK_SIZE_M) - num_pid_gemm_n = tl.cdiv(gemm_n, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_gemm_n - group_id = pid // num_pid_in_group - first_pid_gemm_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_gemm_m - first_pid_gemm_m, GROUP_SIZE_M) - pid_gemm_m = first_pid_gemm_m + ((pid % num_pid_in_group) % group_size_m) - pid_gemm_n = (pid % num_pid_in_group) // group_size_m - - offs_gemm_i = pid_gemm_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_gemm_j = pid_gemm_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - offs_n = offs_gemm_i // (p * q) - offs_k = offs_gemm_j - npq_residual = offs_gemm_i % (p * q) - offs_p = npq_residual // q - offs_q = npq_residual % q - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for i in range(0, tl.cdiv(gemm_k, BLOCK_SIZE_K)): - offs_gemm_k = i * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - - offs_c = offs_gemm_k // (r * s) - crs_residual = offs_gemm_k % (r * s) - offs_r = crs_residual // s - offs_s = crs_residual % s - - offs_h = offs_p[:, None] + offs_r[None, :] - offs_w = offs_q[:, None] + offs_s[None, :] - - input_ptrs = ( - input_ptr - + offs_n[:, None] * input_stride_n - + offs_c[None, :] * input_stride_c - + offs_h * input_stride_h - + offs_w * input_stride_w - ) - input_mask = ( - (offs_n[:, None] < n) & (offs_c[None, :] < c) & (offs_h < h) & (offs_w < w) - ) - - filter_ptrs = ( - filter_ptr - + offs_k[None, :] * filter_stride_k - + offs_c[:, None] * filter_stride_c - + offs_r[:, None] * filter_stride_r - + offs_s[:, None] * filter_stride_s - ) - filter_mask = (offs_k[None, :] < k) & ( - (offs_c < c) & (offs_r < r) & (offs_s < s) - )[:, None] - - input = tl.load(input_ptrs, mask=input_mask, other=0.0) - filter = tl.load(filter_ptrs, mask=filter_mask, other=0.0) - - accumulator = tl.dot(input, filter, accumulator) - - output = accumulator.to(tl.float16) - - output_ptrs = ( - output_ptr - + offs_n[:, None] * output_stride_n - + offs_k[None, :] * output_stride_k - + offs_p[:, None] * output_stride_p - + offs_q[:, None] * output_stride_q - ) - output_mask = ( - (offs_n[:, None] < n) - & (offs_k[None, :] < k) - & (offs_p[:, None] < p) - & (offs_q[:, None] < q) - ) - - tl.store(output_ptrs, output, mask=output_mask) - - -def triton_conv2d(input, filter): - n, c, h, w = input.shape - k, _, r, s = filter.shape - p = h - r + 1 - q = w - s + 1 - - output = torch.empty((n, k, p, q), device=input.device, dtype=input.dtype) - - def grid(meta): - return ( - triton.cdiv(n * p * q, meta["BLOCK_SIZE_M"]) - * triton.cdiv(k, meta["BLOCK_SIZE_N"]), - ) - - triton_conv2d_kernel[grid]( - input, - filter, - output, - n, - c, - h, - w, - k, - r, - s, - *input.stride(), - *filter.stride(), - *output.stride(), - ) - - return output +import ops.ninetoothed.torch +import ops.triton.torch if __name__ == "__main__": torch.manual_seed(0) + n, c, h, w = 4, 3, 224, 224 k, _, r, s = 8, c, 3, 3 - input = torch.randn(n, c, h, w, device="cuda") - filter = torch.randn(k, c, r, s, device="cuda") - ninetoothed_output = conv2d(input, filter) + dtype = torch.float16 + device = "cuda" + + input = torch.randn(n, c, h, w, dtype=dtype, device=device) + filter = torch.randn(k, c, r, s, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.conv2d(input, filter) torch_output = F.conv2d(input, filter) - triton_output = triton_conv2d(input, filter) + triton_output = ops.triton.torch.conv2d(input, filter) + print(ninetoothed_output) print(torch_output) print(triton_output) + if torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0.01, rtol=0.01): + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): print("✅ NineToothed and Triton match.") else: print("❌ NineToothed and Triton differ.") @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["h", "w"], - x_vals=[8 * i for i in range(2, 33)], + x_names=["n"], + x_vals=[2**i for i in range(1, 11)], + x_log=True, line_arg="provider", line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="TFLOPS", - plot_name="2d-convolution-performance", + ylabel="ms", + plot_name="conv2d-performance", args={}, ) ) - def benchmark(h, w, provider): - n, c, _, _ = 64, 3, h, w - k, _, r, s = 64, c, 3, 3 - input = torch.randn((n, c, h, w), device="cuda") - filter = torch.randn((k, c, r, s), device="cuda") + def benchmark(n, provider): + _, c, h, w = n, 512, 14, 14 + k, _, r, s = 512, c, 3, 3 + + input = torch.randn((n, c, h, w), dtype=dtype, device=device) + filter = torch.randn((k, c, r, s), dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.conv2d(input, filter) + torch_output = F.conv2d(input, filter) + triton_output = ops.triton.torch.conv2d(input, filter) + + assert torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: conv2d(input, filter)) + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.conv2d(input, filter) + ) elif provider == "torch": ms = triton.testing.do_bench(lambda: F.conv2d(input, filter)) elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_conv2d(input, filter)) - - def perf(ms): - p = h - r + 1 - q = w - s + 1 - - return 2 * n * k * p * q * c * r * s * 1e-12 / (ms * 1e-3) + ms = triton.testing.do_bench(lambda: ops.triton.torch.conv2d(input, filter)) - return perf(ms) + return ms benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/evaluate_code.py b/evaluate_code.py new file mode 100644 index 0000000..a199a55 --- /dev/null +++ b/evaluate_code.py @@ -0,0 +1,185 @@ +import functools +import json +import os.path +from pathlib import Path + +import pandas as pd + +_PARENT_PATH = Path(__file__).parent + +_OPS_PATH = _PARENT_PATH / "ops" + +_NINETOOTHED_KERNELS_PATH = _OPS_PATH / "ninetoothed" / "kernels" + +_TRITON_KERNELS_PATH = _OPS_PATH / "triton" / "kernels" + +_BACKSLASH_CHAR = "\\" + + +def _generate_cc_table(): + path = _PARENT_PATH / "cc.json" + + metric_names = {"complexity": "$G$"} + + data = json.loads(path.read_text()) + + data = { + kernel: { + metric_names["complexity"]: sum(block["complexity"] for block in blocks) + } + for kernel, blocks in data.items() + if "torch" not in kernel + } + + df = _generate_table(data, metric_names.values()) + + return df + + +def _generate_mi_table(): + path = _PARENT_PATH / "mi.json" + + metric_names = {"mi": "$MI$"} + + data = json.loads(path.read_text()) + + data = { + kernel: { + latex_name: metrics[raw_name] + for raw_name, latex_name in metric_names.items() + } + for kernel, metrics in data.items() + if "torch" not in kernel + } + + df = _generate_table(data, metric_names.values()) + + return df + + +def _generate_raw_table(): + path = _PARENT_PATH / "raw.json" + + metric_names = {"loc": "LOC", "lloc": "LLOC", "sloc": "SLOC"} + + data = json.loads(path.read_text()) + + data = { + kernel: { + latex_name: metrics[raw_name] + for raw_name, latex_name in metric_names.items() + } + for kernel, metrics in data.items() + if "torch" not in kernel + } + + df = _generate_table(data, metric_names.values()) + + return df + + +def _generate_hal_table(): + path = _PARENT_PATH / "hal.json" + + metric_names = { + "vocabulary": "$\\eta$", + "length": "$N$", + "volume": "$V$", + "difficulty": "$D$", + } + + data = json.loads(path.read_text()) + + data = { + kernel: { + latex_name: metrics["total"][raw_name] + for raw_name, latex_name in metric_names.items() + } + for kernel, metrics in data.items() + if "torch" not in kernel + } + + df = _generate_table(data, metric_names.values()) + + return df + + +def _generate_table(data, metric_names): + kernel_names = sorted( + set( + os.path.splitext(os.path.basename(kernel_name))[0] + for kernel_name in data.keys() + ) + ) + + def _key_from_kernel_name(path, kernel_name): + return str(path / f"{kernel_name}.py").removeprefix(str(_PARENT_PATH))[1:] + + data = { + f"{_BACKSLASH_CHAR}texttt{{{kernel_name.replace('scaled_dot_product_attention', 'sdpa').replace('rotary_position_embedding', 'rope').replace('_', f'{_BACKSLASH_CHAR}_')}}}": { + "Triton": { + metric_name: data[ + _key_from_kernel_name(_TRITON_KERNELS_PATH, kernel_name) + ][metric_name] + for metric_name in metric_names + }, + "NineToothed": { + metric_name: data[ + _key_from_kernel_name(_NINETOOTHED_KERNELS_PATH, kernel_name) + ][metric_name] + for metric_name in metric_names + }, + } + for kernel_name in kernel_names + } + + df = pd.DataFrame.from_dict( + { + (outer_key, inner_key): value + for outer_key, inner_dict in data.items() + for inner_key, value in inner_dict.items() + }, + orient="index", + ) + + df.index = pd.MultiIndex.from_tuples(df.index) + + return df + + +def _highlight(df): + new_df = pd.DataFrame("", index=df.index, columns=df.columns) + + for _, group in df[ + ["LOC", "LLOC", "SLOC", "$G$", "$\\eta$", "$N$", "$V$", "$D$"] + ].groupby(level=0): + mask = group == group.min() + + new_df.update( + mask.replace(True, "background-color: green!20").replace(False, "") + ) + + for _, group in df[["$MI$"]].groupby(level=0): + mask = group == group.max() + + new_df.update( + mask.replace(True, "background-color: green!20").replace(False, "") + ) + + return new_df + + +if __name__ == "__main__": + raw_table = _generate_raw_table() + cc_table = _generate_cc_table() + hal_table = _generate_hal_table() + mi_table = _generate_mi_table() + + df = functools.reduce( + lambda left, right: pd.merge(left, right, left_index=True, right_index=True), + (raw_table, cc_table, hal_table, mi_table), + ) + + styler = df.style.apply(_highlight, axis=None).format(precision=2) + + print(styler.to_latex(hrules=True, multicol_align="c", convert_css=True)) diff --git a/evaluate_performance.py b/evaluate_performance.py new file mode 100644 index 0000000..7b9af13 --- /dev/null +++ b/evaluate_performance.py @@ -0,0 +1,50 @@ +import json + +import matplotlib.pyplot as plt +import pandas as pd + +from evaluate_code import _BACKSLASH_CHAR +from run_experiments import ALL_MAX_NEW_TOKENS, BACKENDS + +if __name__ == "__main__": + plt.rcParams["figure.dpi"] = 600 + plt.rcParams["font.family"] = "Linux Biolinum" + + df = pd.read_csv("microbenchmark_data.csv") + + for task in df["Task"]: + latex_item = f"\item {_BACKSLASH_CHAR}texttt{{{task.replace('scaled_dot_product_attention', 'sdpa').replace('rotary_position_embedding', 'rope').replace('_', f'{_BACKSLASH_CHAR}_')}}}" + print(latex_item) + + df.index += 1 + df.plot(kind="bar", rot=0) + plt.ylabel("Execution Time (ms)") + plt.xlabel("Task") + plt.grid(False) + plt.tight_layout() + plt.savefig("microbenchmark-results.png") + + data = {"Output Length": [], "NineToothed": [], "Triton": [], "PyTorch": []} + + for max_new_tokens in ALL_MAX_NEW_TOKENS: + data["Output Length"].append(max_new_tokens) + + for backend in BACKENDS: + with open(f"infer_{max_new_tokens}_{backend}.json") as f: + num_tokens_per_second = json.load(f)["num_tokens_per_second"] + + if backend == "ninetoothed": + data["NineToothed"].append(num_tokens_per_second) + elif backend == "triton": + data["Triton"].append(num_tokens_per_second) + elif backend == "torch": + data["PyTorch"].append(num_tokens_per_second) + + df = pd.DataFrame(data) + + df.set_index("Output Length").plot(kind="bar", rot=0) + plt.ylabel("Throughput (TPS)") + plt.xlabel("Output Length") + plt.grid(False) + plt.tight_layout() + plt.savefig("benchmark-results.png") diff --git a/fused_rms_norm.py b/fused_rms_norm.py new file mode 100644 index 0000000..a32ce9b --- /dev/null +++ b/fused_rms_norm.py @@ -0,0 +1,114 @@ +from contextlib import contextmanager + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton + +import ops.ninetoothed.torch +import ops.triton.torch + + +class RMSNorm(nn.Module): + fused_rms_norm = None + + def __init__(self, other): + super().__init__() + + self.__dict__ = other.__dict__ + + def forward(self, x): + return type(self).fused_rms_norm(x, self.weight, self.variance_epsilon) + + +@contextmanager +def rms_norm_backend(backend_name): + def _torch_fused_rms_norm(x, w, eps): + return F.rms_norm(x, x.shape[-1:], w, eps) + + _prev_impl = RMSNorm.fused_rms_norm + + if backend_name == "ninetoothed": + impl = ops.ninetoothed.torch.fused_rms_norm + elif backend_name == "triton": + impl = ops.triton.torch.fused_rms_norm + elif backend_name == "torch": + impl = _torch_fused_rms_norm + else: + raise ValueError(f"unknown backend: `{backend_name}`") + + RMSNorm.fused_rms_norm = impl + + try: + yield + finally: + RMSNorm.fused_rms_norm = _prev_impl + + +if __name__ == "__main__": + torch.manual_seed(0) + + dtype = torch.float16 + device = "cuda" + + x = torch.randn(1151, 8192, dtype=dtype, device=device) + w = torch.randn(8192, dtype=dtype, device=device) + eps = 1e-5 + + ninetoothed_output = ops.ninetoothed.torch.fused_rms_norm(x, w, eps) + torch_output = F.rms_norm(x, x.shape[-1:], w, eps) + triton_output = ops.triton.torch.fused_rms_norm(x, w, eps) + + print(ninetoothed_output) + print(torch_output) + print(triton_output) + + if torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.005): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["n"], + x_vals=[2**i for i in range(5, 15)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="fused-rms-norm-performance", + args={"m": 4096}, + ) + ) + def benchmark(m, n, provider): + x = torch.randn(m, n, dtype=dtype, device=device) + w = torch.randn(n, dtype=dtype, device=device) + eps = 1e-5 + + ninetoothed_output = ops.ninetoothed.torch.fused_rms_norm(x, w, eps) + torch_output = F.rms_norm(x, x.shape[-1:], w, eps) + triton_output = ops.triton.torch.fused_rms_norm(x, w, eps) + + assert torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005) + assert torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.005) + + if provider == "ninetoothed": + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.fused_rms_norm(x, w, eps) + ) + elif provider == "torch": + ms = triton.testing.do_bench(lambda: F.rms_norm(x, x.shape[-1:], w, eps)) + elif provider == "triton": + ms = triton.testing.do_bench( + lambda: ops.triton.torch.fused_rms_norm(x, w, eps) + ) + + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/infer.py b/infer.py index 84580db..c5b2ab3 100644 --- a/infer.py +++ b/infer.py @@ -1,11 +1,18 @@ import argparse +import json +import time +import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from attention import Attention -from linear import Linear -from rms_norm import RMSNorm -from silu import SiLU +from fused_rms_norm import RMSNorm, rms_norm_backend +from linear import Linear, bmm_backend +from scaled_dot_product_attention import ( + Attention, + rotary_position_embedding_backend, + scaled_dot_product_attention_backend, +) +from silu import SiLU, silu_backend from utils import replace_module if __name__ == "__main__": @@ -38,6 +45,24 @@ default="cpu", help='Device to use for inference (e.g., "cuda", "cpu").', ) + parser.add_argument( + "--backend", + type=str, + default="ninetoothed", + help='Backend to use for inference (e.g., "ninetoothed", "triton", "torch").', + ) + parser.add_argument( + "--num-warmup-iterations", + type=int, + default=0, + help="For profiling. The number of warmup iterations to run before measuring performance.", + ) + parser.add_argument( + "--num-profiling-iterations", + type=int, + default=1, + help="For profiling. The number of iterations to run for performance measurement.", + ) args = parser.parse_args() @@ -45,6 +70,12 @@ prompts = args.prompts max_new_tokens = args.max_new_tokens device = args.device + backend = args.backend + num_warmup_iterations = args.num_warmup_iterations + num_profiling_iterations = args.num_profiling_iterations + + assert num_profiling_iterations >= 1 + assert num_warmup_iterations >= 0 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device) @@ -52,13 +83,52 @@ tokenizer.pad_token = tokenizer.eos_token model.generation_config.pad_token_id = tokenizer.pad_token_id - replace_module(model, Attention) - replace_module(model, Linear) - replace_module(model, RMSNorm) - replace_module(model, SiLU) + if backend != "torch": + replace_module(model, Attention) + replace_module(model, Linear) + replace_module(model, RMSNorm) + replace_module(model, SiLU) inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(device) - outputs = model.generate(**inputs, max_new_tokens=max_new_tokens) + + with ( + bmm_backend(backend), + rms_norm_backend(backend), + rotary_position_embedding_backend(backend), + scaled_dot_product_attention_backend(backend), + silu_backend(backend), + ): + for _ in range(num_warmup_iterations): + model.generate(**inputs, max_new_tokens=max_new_tokens) + + if device == "cuda": + torch.cuda.synchronize() + + start_time = time.time() + + for _ in range(num_profiling_iterations): + outputs = model.generate(**inputs, max_new_tokens=max_new_tokens) + + if device == "cuda": + torch.cuda.synchronize() + + end_time = time.time() + strings = tokenizer.batch_decode(outputs, skip_special_tokens=True) - print(strings) + average_time = (end_time - start_time) / num_profiling_iterations + num_input_tokens = inputs["input_ids"].size(-1) + num_output_tokens = outputs.size(-1) - num_input_tokens + num_tokens_per_second = num_output_tokens / average_time + + print( + json.dumps( + { + "strings": strings, + "average_time": average_time, + "num_input_tokens": num_input_tokens, + "num_output_tokens": num_output_tokens, + "num_tokens_per_second": num_tokens_per_second, + } + ) + ) diff --git a/linear.py b/linear.py index 8234064..50bf179 100644 --- a/linear.py +++ b/linear.py @@ -1,13 +1,41 @@ +from contextlib import contextmanager + +import torch import torch.nn as nn -from bmm import bmm +import ops.ninetoothed.torch class Linear(nn.Module): + bmm = None + def __init__(self, other): super().__init__() self.__dict__ = other.__dict__ def forward(self, input): - return bmm(input, self.weight.T.unsqueeze(0).expand(input.shape[0], -1, -1)) + return type(self).bmm( + input, self.weight.T.unsqueeze(0).expand(input.shape[0], -1, -1) + ) + + +@contextmanager +def bmm_backend(backend_name): + _prev_impl = Linear.bmm + + if backend_name == "ninetoothed": + impl = ops.ninetoothed.torch.bmm + elif backend_name == "triton": + impl = ops.triton.torch.bmm + elif backend_name == "torch": + impl = torch.bmm + else: + raise ValueError(f"unknown backend: `{backend_name}`") + + Linear.bmm = impl + + try: + yield + finally: + Linear.bmm = _prev_impl diff --git a/matmul.py b/matmul.py deleted file mode 100644 index d5aa445..0000000 --- a/matmul.py +++ /dev/null @@ -1,284 +0,0 @@ -import ninetoothed -import ninetoothed.language as ntl -import torch -import triton -import triton.language as tl -from ninetoothed import Symbol, Tensor - - -def arrangement(lhs, rhs, output): - BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) - BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) - BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True) - - output_tiled = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N)) - - lhs_tiled = ( - lhs.tile((BLOCK_SIZE_M, BLOCK_SIZE_K)) - .tile((1, -1)) - .expand((-1, output_tiled.shape[1])) - ) - lhs_tiled.dtype = lhs_tiled.dtype.squeeze(0) - - rhs_tiled = ( - rhs.tile((BLOCK_SIZE_K, BLOCK_SIZE_N)) - .tile((-1, 1)) - .expand((output_tiled.shape[0], -1)) - ) - rhs_tiled.dtype = rhs_tiled.dtype.squeeze(1) - - return lhs_tiled, rhs_tiled, output_tiled - - -def application(lhs, rhs, output): - accumulator = ntl.zeros(output.shape, dtype=ntl.float32) - for k in range(lhs.shape[0]): - accumulator += ntl.dot(lhs[k], rhs[k]) - output = accumulator.to(ntl.float16) - - -matmul_kernel = ninetoothed.make( - arrangement, application, (Tensor(2), Tensor(2), Tensor(2)) -) - - -def matmul(lhs, rhs): - output = torch.empty( - (lhs.shape[0], rhs.shape[1]), device=lhs.device, dtype=torch.float16 - ) - - matmul_kernel(lhs, rhs, output) - - return output - - -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - ], - key=["m", "n", "k"], -) -@triton.jit -def triton_matmul_kernel( - lhs_ptr, - rhs_ptr, - output_ptr, - m, - n, - k, - lhs_stride_m, - lhs_stride_k, - rhs_stride_k, - rhs_stride_n, - output_stride_m, - output_stride_n, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - pid = tl.program_id(0) - num_pid_m = tl.cdiv(m, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(n, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n - offs_k = tl.arange(0, BLOCK_SIZE_K) - lhs_ptrs = lhs_ptr + ( - offs_am[:, None] * lhs_stride_m + offs_k[None, :] * lhs_stride_k - ) - rhs_ptrs = rhs_ptr + ( - offs_k[:, None] * rhs_stride_k + offs_bn[None, :] * rhs_stride_n - ) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)): - lhs = tl.load(lhs_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K, other=0.0) - rhs = tl.load(rhs_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(lhs, rhs, accumulator) - lhs_ptrs += BLOCK_SIZE_K * lhs_stride_k - rhs_ptrs += BLOCK_SIZE_K * rhs_stride_k - output = accumulator.to(tl.float16) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - output_ptrs = ( - output_ptr - + output_stride_m * offs_cm[:, None] - + output_stride_n * offs_cn[None, :] - ) - output_mask = (offs_cm[:, None] < m) & (offs_cn[None, :] < n) - tl.store(output_ptrs, output, mask=output_mask) - - -def triton_matmul(lhs, rhs): - output = torch.empty( - (lhs.shape[0], rhs.shape[1]), device=lhs.device, dtype=torch.float16 - ) - - def grid(meta): - return ( - triton.cdiv(lhs.shape[0], meta["BLOCK_SIZE_M"]) - * triton.cdiv(rhs.shape[1], meta["BLOCK_SIZE_N"]), - ) - - triton_matmul_kernel[grid]( - lhs, - rhs, - output, - lhs.shape[0], - rhs.shape[1], - lhs.shape[1], - lhs.stride(0), - lhs.stride(1), - rhs.stride(0), - rhs.stride(1), - output.stride(0), - output.stride(1), - ) - - return output - - -if __name__ == "__main__": - torch.manual_seed(0) - shape = (512, 512) - lhs = torch.randn(shape, device="cuda", dtype=torch.float16) - rhs = torch.randn(shape, device="cuda", dtype=torch.float16) - ninetoothed_output = matmul(lhs, rhs) - torch_output = torch.matmul(lhs, rhs) - triton_output = triton_matmul(lhs, rhs) - print(ninetoothed_output) - print(torch_output) - print(triton_output) - if torch.allclose(ninetoothed_output, torch_output): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["m", "n", "k"], - x_vals=[128 * i for i in range(2, 33)], - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="TFLOPS", - plot_name="matrix-multiplication-performance", - args={}, - ) - ) - def benchmark(m, n, k, provider): - lhs = torch.randn((m, k), device="cuda", dtype=torch.float16) - rhs = torch.randn((k, n), device="cuda", dtype=torch.float16) - quantiles = [0.5, 0.2, 0.8] - - if provider == "ninetoothed": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: matmul(lhs, rhs), quantiles=quantiles - ) - elif provider == "torch": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: torch.matmul(lhs, rhs), quantiles=quantiles - ) - elif provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: triton_matmul(lhs, rhs), quantiles=quantiles - ) - - def perf(ms): - return 2 * m * n * k * 1e-12 / (ms * 1e-3) - - return perf(ms), perf(max_ms), perf(min_ms) - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/max_pool2d.py b/max_pool2d.py index 844c2c4..584ee53 100644 --- a/max_pool2d.py +++ b/max_pool2d.py @@ -52,13 +52,18 @@ def max_pool2d(input, window_shape): if __name__ == "__main__": torch.manual_seed(0) + input_shape = (32, 3, 64, 64) window_shape = (3, 3) + input = torch.randn(input_shape, dtype=torch.float16, device="cuda") + ninetoothed_output = max_pool2d(input, window_shape) torch_output = F.max_pool2d(input, window_shape, ceil_mode=True) + print(ninetoothed_output) print(torch_output) + if torch.allclose(ninetoothed_output, torch_output): print("✅ NineToothed and PyTorch match.") else: @@ -72,8 +77,8 @@ def max_pool2d(input, window_shape): line_vals=["ninetoothed", "torch"], line_names=["NineToothed", "PyTorch"], styles=[("blue", "-"), ("green", "-")], - ylabel="GB/s", - plot_name="2d-max-pooling-performance", + ylabel="ms", + plot_name="max-pool2d-performance", args={}, ) ) @@ -90,9 +95,6 @@ def benchmark(h, w, provider): elif provider == "torch": ms = triton.testing.do_bench(lambda: F.max_pool2d(input, window_shape)) - def gbps(ms): - return 2 * input.numel() * input.element_size() / ms * 1e-6 - - return gbps(ms) + return ms benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/mm.py b/mm.py new file mode 100644 index 0000000..f4d0b95 --- /dev/null +++ b/mm.py @@ -0,0 +1,68 @@ +import torch +import triton + +import ops.ninetoothed.torch +import ops.triton.torch + +if __name__ == "__main__": + torch.manual_seed(0) + + shape = (512, 512) + dtype = torch.float16 + device = "cuda" + + input = torch.randn(shape, dtype=dtype, device=device) + other = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.mm(input, other) + torch_output = torch.mm(input, other) + triton_output = ops.triton.torch.mm(input, other) + + print(ninetoothed_output) + print(torch_output) + print(triton_output) + + if torch.allclose(ninetoothed_output, torch_output): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["m", "n", "k"], + x_vals=[2**i for i in range(3, 13)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="mm-performance", + args={}, + ) + ) + def benchmark(m, n, k, provider): + input = torch.randn((m, k), dtype=dtype, device=device) + other = torch.randn((k, n), dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.mm(input, other) + torch_output = torch.mm(input, other) + triton_output = ops.triton.torch.mm(input, other) + + assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + + if provider == "ninetoothed": + ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.mm(input, other)) + elif provider == "torch": + ms = triton.testing.do_bench(lambda: torch.mm(input, other)) + elif provider == "triton": + ms = triton.testing.do_bench(lambda: ops.triton.torch.mm(input, other)) + + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/ops/ninetoothed/kernels/add.py b/ops/ninetoothed/kernels/add.py new file mode 100644 index 0000000..1dfb4be --- /dev/null +++ b/ops/ninetoothed/kernels/add.py @@ -0,0 +1,21 @@ +import ninetoothed +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(input, other, output, BLOCK_SIZE=BLOCK_SIZE): + input_arranged = input.tile((BLOCK_SIZE,)) + other_arranged = other.tile((BLOCK_SIZE,)) + output_arranged = output.tile((BLOCK_SIZE,)) + + return input_arranged, other_arranged, output_arranged + + +def application(input, other, output): + output = input + other # noqa: F841 + + +tensors = tuple(Tensor(1) for _ in range(3)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/kernels/addmm.py b/ops/ninetoothed/kernels/addmm.py new file mode 100644 index 0000000..495246a --- /dev/null +++ b/ops/ninetoothed/kernels/addmm.py @@ -0,0 +1,22 @@ +import ninetoothed +from ninetoothed import Tensor + +import ops.ninetoothed.kernels.mm as mm + + +def arrangement(input, mat1, mat2, beta, alpha, output): + _, _, input_arranged = mm.arrangement(mat1, mat2, input) + + mat1_arranged, mat2_arranged, output_arranged = mm.arrangement(mat1, mat2, output) + + return input_arranged, mat1_arranged, mat2_arranged, beta, alpha, output_arranged + + +def application(input, mat1, mat2, beta, alpha, output): + mm.application(mat1, mat2, output) + output = beta * input + alpha * output + + +tensors = (Tensor(2), Tensor(2), Tensor(2), Tensor(0), Tensor(0), Tensor(2)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/kernels/bmm.py b/ops/ninetoothed/kernels/bmm.py new file mode 100644 index 0000000..6b34635 --- /dev/null +++ b/ops/ninetoothed/kernels/bmm.py @@ -0,0 +1,39 @@ +import ninetoothed +from ninetoothed import Tensor, block_size + +from ops.ninetoothed.kernels.mm import application + +BLOCK_SIZE_M = block_size() +BLOCK_SIZE_N = block_size() +BLOCK_SIZE_K = block_size() + + +def arrangement( + input, + other, + output, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, +): + output_arranged = output.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_N)) + output_arranged.dtype = output_arranged.dtype.squeeze(0) + + input_arranged = input.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_K)) + input_arranged = input_arranged.tile((1, 1, -1)) + input_arranged = input_arranged.expand((-1, -1, output_arranged.shape[-1])) + input_arranged.dtype = input_arranged.dtype.squeeze((0, 1)) + input_arranged.dtype.dtype = input_arranged.dtype.dtype.squeeze(0) + + other_arranged = other.tile((1, BLOCK_SIZE_K, BLOCK_SIZE_N)) + other_arranged = other_arranged.tile((1, -1, 1)) + other_arranged = other_arranged.expand((-1, output_arranged.shape[-2], -1)) + other_arranged.dtype = other_arranged.dtype.squeeze((0, 2)) + other_arranged.dtype.dtype = other_arranged.dtype.dtype.squeeze(0) + + return input_arranged, other_arranged, output_arranged + + +tensors = (Tensor(3), Tensor(3), Tensor(3)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/kernels/conv2d.py b/ops/ninetoothed/kernels/conv2d.py new file mode 100644 index 0000000..5f2a49e --- /dev/null +++ b/ops/ninetoothed/kernels/conv2d.py @@ -0,0 +1,25 @@ +import ninetoothed +from ninetoothed import Tensor + +import ops.ninetoothed.kernels.mm as mm + + +def arrangement(input, filter, output): + input_arranged = input.tile((1, *filter.shape[1:]), strides=(-1, -1, 1, 1)) + input_arranged = input_arranged.squeeze(1) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1) + + filter_arranged = filter.flatten(start_dim=1) + filter_arranged = filter_arranged.permute((1, 0)) + + output_arranged = output.permute((0, 2, 3, 1)).flatten(end_dim=3) + + return mm.arrangement(input_arranged, filter_arranged, output_arranged) + + +shape_options = {"constexpr": True} +tensors = tuple(Tensor(4, shape_options=shape_options) for _ in range(3)) + +kernel = ninetoothed.make(arrangement, mm.application, tensors) diff --git a/ops/ninetoothed/kernels/fused_rms_norm.py b/ops/ninetoothed/kernels/fused_rms_norm.py new file mode 100644 index 0000000..5242179 --- /dev/null +++ b/ops/ninetoothed/kernels/fused_rms_norm.py @@ -0,0 +1,22 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(x, w, eps, y, BLOCK_SIZE=BLOCK_SIZE): + def arrange(tensor): + return tensor.tile((1, BLOCK_SIZE)) + + return arrange(x), arrange(w), eps, arrange(y) + + +def application(x, w, eps, y): + x_fp32 = ntl.cast(x, ntl.float32) + y = x_fp32 * ntl.rsqrt(ntl.sum(x_fp32 * x_fp32) / x.shape[-1] + eps) * w # noqa: F841 + + +tensors = (Tensor(2), Tensor(2), Tensor(0), Tensor(2)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/kernels/mm.py b/ops/ninetoothed/kernels/mm.py new file mode 100644 index 0000000..abb054d --- /dev/null +++ b/ops/ninetoothed/kernels/mm.py @@ -0,0 +1,44 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor, block_size + +BLOCK_SIZE_M = block_size() +BLOCK_SIZE_N = block_size() +BLOCK_SIZE_K = block_size() + + +def arrangement( + input, + other, + output, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, +): + output_arranged = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N)) + + input_arranged = input.tile((BLOCK_SIZE_M, BLOCK_SIZE_K)) + input_arranged = input_arranged.tile((1, -1)) + input_arranged = input_arranged.expand((-1, output_arranged.shape[1])) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + + other_arranged = other.tile((BLOCK_SIZE_K, BLOCK_SIZE_N)) + other_arranged = other_arranged.tile((-1, 1)) + other_arranged = other_arranged.expand((output_arranged.shape[0], -1)) + other_arranged.dtype = other_arranged.dtype.squeeze(1) + + return input_arranged, other_arranged, output_arranged + + +def application(input, other, output): + accumulator = ntl.zeros(output.shape, dtype=ntl.float32) + + for k in range(input.shape[0]): + accumulator += ntl.dot(input[k], other[k]) + + output = accumulator + + +tensors = (Tensor(2), Tensor(2), Tensor(2)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/kernels/rms_norm.py b/ops/ninetoothed/kernels/rms_norm.py new file mode 100644 index 0000000..f2b926e --- /dev/null +++ b/ops/ninetoothed/kernels/rms_norm.py @@ -0,0 +1,21 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(input, eps, output, BLOCK_SIZE=BLOCK_SIZE): + return input.tile((1, BLOCK_SIZE)), eps, output.tile((1, BLOCK_SIZE)) + + +def application(input, eps, output): + input_fp32 = ntl.cast(input, ntl.float32) + output = input_fp32 * ntl.rsqrt( # noqa: F841 + ntl.sum(input_fp32 * input_fp32) / input.shape[-1] + eps + ) + + +tensors = (Tensor(2), Tensor(0), Tensor(2)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/kernels/rotary_position_embedding.py b/ops/ninetoothed/kernels/rotary_position_embedding.py new file mode 100644 index 0000000..f55d9cf --- /dev/null +++ b/ops/ninetoothed/kernels/rotary_position_embedding.py @@ -0,0 +1,56 @@ +import functools + +import ninetoothed +from ninetoothed import Tensor + + +def arrangement(input, sin_table, cos_table, interleaved=True): + emb_dim = input.shape[-1] + tile_shape = (1, 1, 1, emb_dim // 2) + + if interleaved: + strides = (-1, -1, -1, 1) + dilation = (1, 1, 1, 2) + else: + strides = None + dilation = None + + input_arranged = input.tile(tile_shape, strides=strides, dilation=dilation) + input_arranged = input_arranged.tile((1, 1, 1, 2)) + input_arranged.dtype = input_arranged.dtype.squeeze((0, 1, 2)) + input_arranged.dtype.dtype = input_arranged.dtype.dtype.squeeze((0, 1, 2)) + + sin_table_arranged = sin_table.tile(tile_shape) + sin_table_arranged.dtype = sin_table_arranged.dtype.squeeze((0, 1, 2)) + + cos_table_arranged = cos_table.tile(tile_shape) + cos_table_arranged.dtype = cos_table_arranged.dtype.squeeze((0, 1, 2)) + + return input_arranged, sin_table_arranged, cos_table_arranged + + +def application(input, sin_table, cos_table): + sin_table_loaded = sin_table + cos_table_loaded = cos_table + + input_0 = input[0] + input_1 = input[1] + + input[0] = input_0 * cos_table_loaded - input_1 * sin_table_loaded + input[1] = input_0 * sin_table_loaded + input_1 * cos_table_loaded + + +inputs = tuple(Tensor(4, shape_options={"constexpr": True}) for _ in range(3)) + +interleaved_kernel = ninetoothed.make( + functools.partial(arrangement, interleaved=True), application, inputs +) +non_interleaved_kernel = ninetoothed.make( + functools.partial(arrangement, interleaved=False), application, inputs +) + + +def kernel(input, sin_table, cos_table, interleaved=True): + return (interleaved_kernel if interleaved else non_interleaved_kernel)( + input, sin_table, cos_table + ) diff --git a/ops/ninetoothed/kernels/scaled_dot_product_attention.py b/ops/ninetoothed/kernels/scaled_dot_product_attention.py new file mode 100644 index 0000000..2d3f887 --- /dev/null +++ b/ops/ninetoothed/kernels/scaled_dot_product_attention.py @@ -0,0 +1,60 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor, block_size + +BLOCK_SIZE_M = block_size() +BLOCK_SIZE_N = block_size() + + +def arrangement( + q, k, v, scale, o, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N +): + def arrange_q_or_o(input): + arranged = input.tile((1, 1, BLOCK_SIZE_M, -1)) + arranged.dtype = arranged.dtype.squeeze((0, 1)) + + return arranged + + def arrange_k_or_v(input): + arranged = input.tile((1, 1, BLOCK_SIZE_N, -1)) + arranged = arranged.tile((1, 1, -1, -1)) + arranged = arranged.expand((-1, -1, q_arranged.shape[-2], -1)) + arranged.dtype = arranged.dtype.squeeze((0, 1, 3)) + arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) + + return arranged + + q_arranged = arrange_q_or_o(q) + + return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), scale, arrange_q_or_o(o) + + +def application(q, k, v, scale, o): + q_loaded = (q * scale * 1.44269504089).to(q.dtype) + + acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32) + l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32) + m_i = ntl.full((q.shape[-2],), float("-inf"), dtype=ntl.float32) + + for i in range(k.shape[0]): + qk = ntl.dot(q_loaded, ntl.trans(k[i])) + qk = ntl.where(k[i].offsets(-2) < k.source.shape[-2], qk, float("-inf")) + + m_ij = ntl.maximum(m_i, ntl.max(qk, 1)) + p = ntl.exp2(qk - m_ij[:, None]) + l_ij = ntl.sum(p, 1) + + alpha = ntl.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + ntl.dot(p.to(v.dtype.dtype), v[i]) + m_i = m_ij + l_i = l_i * alpha + l_ij + + acc /= l_i[:, None] + o = acc.to(o.dtype) # noqa: F841 + + +shape_options = (None, None, None, {"constexpr": True, "upper_bound": 128}) +q, k, v, o = (Tensor(4, shape_options=shape_options) for _ in range(4)) +tensors = (q, k, v, Tensor(0), o) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/kernels/silu.py b/ops/ninetoothed/kernels/silu.py new file mode 100644 index 0000000..3ead189 --- /dev/null +++ b/ops/ninetoothed/kernels/silu.py @@ -0,0 +1,19 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(input, output, BLOCK_SIZE=BLOCK_SIZE): + return input.tile((BLOCK_SIZE,)), output.tile((BLOCK_SIZE,)) + + +def application(input, output): + input_loaded = input + output = input_loaded * ntl.sigmoid(ntl.cast(input_loaded, ntl.float32)) # noqa: F841 + + +tensors = (Tensor(1), Tensor(1)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/kernels/softmax.py b/ops/ninetoothed/kernels/softmax.py new file mode 100644 index 0000000..6d24e66 --- /dev/null +++ b/ops/ninetoothed/kernels/softmax.py @@ -0,0 +1,24 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(input, output, BLOCK_SIZE=BLOCK_SIZE): + return input.tile((1, BLOCK_SIZE)), output.tile((1, BLOCK_SIZE)) + + +def application(input, output): + input_loaded = input + + row_minus_max = input_loaded - ntl.max(input_loaded) + numerator = ntl.exp(row_minus_max) + denominator = ntl.sum(numerator) + + output = numerator / denominator # noqa: F841 + + +tensors = (Tensor(2, other=float("-inf")), Tensor(2)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/kernels/swiglu.py b/ops/ninetoothed/kernels/swiglu.py new file mode 100644 index 0000000..63b7cd6 --- /dev/null +++ b/ops/ninetoothed/kernels/swiglu.py @@ -0,0 +1,20 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(a, b, c, BLOCK_SIZE=BLOCK_SIZE): + return a.tile((BLOCK_SIZE,)), b.tile((BLOCK_SIZE,)), c.tile((BLOCK_SIZE,)) + + +def application(a, b, c): + b_loaded = b + gate = b_loaded * ntl.sigmoid(ntl.cast(b_loaded, ntl.float32)) + c = a * gate # noqa: F841 + + +tensors = (Tensor(1), Tensor(1), Tensor(1)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py new file mode 100644 index 0000000..fe0824d --- /dev/null +++ b/ops/ninetoothed/torch.py @@ -0,0 +1,145 @@ +import math + +import torch + +import ops.ninetoothed.kernels.add +import ops.ninetoothed.kernels.addmm +import ops.ninetoothed.kernels.bmm +import ops.ninetoothed.kernels.conv2d +import ops.ninetoothed.kernels.fused_rms_norm +import ops.ninetoothed.kernels.mm +import ops.ninetoothed.kernels.rms_norm +import ops.ninetoothed.kernels.rotary_position_embedding +import ops.ninetoothed.kernels.scaled_dot_product_attention +import ops.ninetoothed.kernels.silu +import ops.ninetoothed.kernels.softmax +import ops.ninetoothed.kernels.swiglu + + +def add(input, other): + output = torch.empty_like(input) + + ops.ninetoothed.kernels.add.kernel(input, other, output, BLOCK_SIZE=1024) + + return output + + +def addmm(input, mat1, mat2, beta=1, alpha=1): + output_shape = (mat1.shape[0], mat2.shape[1]) + output = torch.empty(output_shape, dtype=mat1.dtype, device=mat1.device) + + ops.ninetoothed.kernels.addmm.kernel(input, mat1, mat2, beta, alpha, output) + + return output + + +def bmm(lhs, rhs): + output_shape = (lhs.shape[0], lhs.shape[-2], rhs.shape[-1]) + output = torch.empty(output_shape, dtype=lhs.dtype, device=lhs.device) + + ops.ninetoothed.kernels.bmm.kernel(lhs, rhs, output) + + return output + + +def conv2d(input, filter): + n, _, h, w = input.shape + k, _, r, s = filter.shape + p = h - r + 1 + q = w - s + 1 + + output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device) + + ops.ninetoothed.kernels.conv2d.kernel(input, filter, output) + + return output + + +def fused_rms_norm(x, w, eps=None): + if eps is None: + eps = torch.finfo(x.dtype).eps() + + x_2d = x.view(-1, x.shape[-1]) + w_2d = w.expand_as(x_2d) + y_2d = torch.empty_like(x_2d) + + ops.ninetoothed.kernels.fused_rms_norm.kernel( + x_2d, w_2d, eps, y_2d, BLOCK_SIZE=x.shape[-1] + ) + + return y_2d.view(x.shape) + + +def mm(input, other): + output_shape = (input.shape[0], other.shape[1]) + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + + ops.ninetoothed.kernels.mm.kernel(input, other, output) + + return output + + +def rms_norm(input, eps=None): + if eps is None: + eps = torch.finfo(input.dtype).eps + + output = torch.empty_like(input) + + ops.ninetoothed.kernels.rms_norm.kernel( + input, eps, output, BLOCK_SIZE=input.shape[-1] + ) + + return output + + +def rotary_position_embedding(input, sin_table, cos_table, interleaved=True): + batch_size, _, num_heads, _ = input.shape + + output = input.clone() + sin_table = sin_table[None, :, None, :].expand(batch_size, -1, num_heads, -1) + cos_table = cos_table[None, :, None, :].expand(batch_size, -1, num_heads, -1) + + ops.ninetoothed.kernels.rotary_position_embedding.kernel( + output, sin_table, cos_table, interleaved + ) + + return output + + +def scaled_dot_product_attention(q, k, v, scale=None): + if scale is None: + scale = 1 / math.sqrt(q.shape[-1]) + + o = torch.empty_like(q) + + ops.ninetoothed.kernels.scaled_dot_product_attention.kernel(q, k, v, scale, o) + + return o + + +def silu(input): + input_flat = input.flatten() + output_flat = torch.empty_like(input_flat) + + ops.ninetoothed.kernels.silu.kernel(input_flat, output_flat, BLOCK_SIZE=1024) + + return output_flat.view_as(input) + + +def softmax(input): + output = torch.empty_like(input) + + ops.ninetoothed.kernels.softmax.kernel(input, output, BLOCK_SIZE=input.shape[-1]) + + return output + + +def swiglu(a, b): + a_flat = a.flatten() + b_flat = b.flatten() + + c = torch.empty_like(a_flat) + + ops.ninetoothed.kernels.swiglu.kernel(a_flat, b_flat, c, BLOCK_SIZE=1024) + + return c.view_as(a) diff --git a/ops/triton/kernels/add.py b/ops/triton/kernels/add.py new file mode 100644 index 0000000..e7c1eb3 --- /dev/null +++ b/ops/triton/kernels/add.py @@ -0,0 +1,17 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel(input_ptr, other_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < num_elements + + input = tl.load(input_ptr + offsets, mask=mask) + other = tl.load(other_ptr + offsets, mask=mask) + output = input + other + + tl.store(output_ptr + offsets, output, mask=mask) diff --git a/ops/triton/kernels/addmm.py b/ops/triton/kernels/addmm.py new file mode 100644 index 0000000..18fa693 --- /dev/null +++ b/ops/triton/kernels/addmm.py @@ -0,0 +1,101 @@ +import itertools + +import triton +import triton.language as tl + + +@triton.autotune( + configs=tuple( + triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": 8, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m, block_size_n, block_size_k, num_stages, num_warps in itertools.product( + (32, 64, 128), (32, 64, 128, 256), (32, 64), (3, 4, 5), (2, 4, 8) + ) + ), + key=["m", "n", "k"], +) +@triton.jit +def kernel( + input_ptr, + mat1_ptr, + mat2_ptr, + output_ptr, + m, + n, + k, + input_stride_m, + input_stride_n, + mat1_stride_m, + mat1_stride_k, + mat2_stride_k, + mat2_stride_n, + output_stride_m, + output_stride_n, + beta, + alpha, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(m, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(n, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n + offs_k = tl.arange(0, BLOCK_SIZE_K) + mat1_ptrs = mat1_ptr + ( + offs_am[:, None] * mat1_stride_m + offs_k[None, :] * mat1_stride_k + ) + mat2_ptrs = mat2_ptr + ( + offs_k[:, None] * mat2_stride_k + offs_bn[None, :] * mat2_stride_n + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + mat1 = tl.load(mat1_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K) + mat2 = tl.load(mat2_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K) + + accumulator = tl.dot(mat1, mat2, accumulator) + + mat1_ptrs += BLOCK_SIZE_K * mat1_stride_k + mat2_ptrs += BLOCK_SIZE_K * mat2_stride_k + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + mask_c = (offs_cm[:, None] < m) & (offs_cn[None, :] < n) + + input_ptrs = ( + input_ptr + + input_stride_m * offs_cm[:, None] + + input_stride_n * offs_cn[None, :] + ) + + input = tl.load(input_ptrs, mask=mask_c) + + output = beta * input + alpha * accumulator.to(tl.float16) + + output_ptrs = ( + output_ptr + + output_stride_m * offs_cm[:, None] + + output_stride_n * offs_cn[None, :] + ) + + tl.store(output_ptrs, output, mask=mask_c) diff --git a/ops/triton/kernels/bmm.py b/ops/triton/kernels/bmm.py new file mode 100644 index 0000000..45cc1dd --- /dev/null +++ b/ops/triton/kernels/bmm.py @@ -0,0 +1,93 @@ +import itertools + +import triton +import triton.language as tl + + +@triton.autotune( + configs=tuple( + triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": 8, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m, block_size_n, block_size_k, num_stages, num_warps in itertools.product( + (32, 64, 128), (32, 64, 128, 256), (32, 64), (3, 4, 5), (2, 4, 8) + ) + ), + key=["m", "n", "k"], +) +@triton.jit +def kernel( + input_ptr, + other_ptr, + output_ptr, + m, + n, + k, + input_stride_b, + input_stride_m, + input_stride_k, + other_stride_b, + other_stride_k, + other_stride_n, + output_stride_b, + output_stride_m, + output_stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(m, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(n, BLOCK_SIZE_N) + num_pid_per_batch = num_pid_m * num_pid_n + bid = pid // num_pid_per_batch + pid_in_batch = pid % num_pid_per_batch + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid_in_batch // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid_in_batch % num_pid_in_group) % group_size_m) + pid_n = (pid_in_batch % num_pid_in_group) // group_size_m + + input_ptr_batch = input_ptr + bid * input_stride_b + other_ptr_batch = other_ptr + bid * other_stride_b + output_ptr_batch = output_ptr + bid * output_stride_b + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n + offs_k = tl.arange(0, BLOCK_SIZE_K) + + input_ptrs = input_ptr_batch + ( + offs_am[:, None] * input_stride_m + offs_k[None, :] * input_stride_k + ) + other_ptrs = other_ptr_batch + ( + offs_k[:, None] * other_stride_k + offs_bn[None, :] * other_stride_n + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + input = tl.load(input_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K) + other = tl.load(other_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K) + accumulator = tl.dot(input, other, accumulator) + input_ptrs += BLOCK_SIZE_K * input_stride_k + other_ptrs += BLOCK_SIZE_K * other_stride_k + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + output_ptrs = output_ptr_batch + ( + output_stride_m * offs_cm[:, None] + output_stride_n * offs_cn[None, :] + ) + + tl.store( + output_ptrs, accumulator, mask=(offs_cm[:, None] < m) & (offs_cn[None, :] < n) + ) diff --git a/ops/triton/kernels/conv2d.py b/ops/triton/kernels/conv2d.py new file mode 100644 index 0000000..7238276 --- /dev/null +++ b/ops/triton/kernels/conv2d.py @@ -0,0 +1,129 @@ +import itertools + +import triton +import triton.language as tl + + +@triton.autotune( + configs=tuple( + triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m, block_size_n, block_size_k, num_stages, num_warps in itertools.product( + (32, 64, 128), (32, 64, 128, 256), (32, 64), (3, 4, 5), (2, 4, 8) + ) + ), + key=["N", "C", "H", "W", "C", "R", "S"], +) +@triton.jit +def kernel( + input_ptr, + filter_ptr, + output_ptr, + N: tl.constexpr, + C: tl.constexpr, + H: tl.constexpr, + W: tl.constexpr, + K: tl.constexpr, + R: tl.constexpr, + S: tl.constexpr, + input_stride_n, + input_stride_c, + input_stride_h, + input_stride_w, + filter_stride_k, + filter_stride_c, + filter_stride_r, + filter_stride_s, + output_stride_n, + output_stride_k, + output_stride_p, + output_stride_q, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + P: tl.constexpr = H - R + 1 + Q: tl.constexpr = W - S + 1 + + GEMM_N: tl.constexpr = K + GEMM_K: tl.constexpr = C * R * S + + pid = tl.program_id(0) + num_pid_gemm_n = tl.cdiv(GEMM_N, BLOCK_SIZE_N) + pid_gemm_m = pid // num_pid_gemm_n + pid_gemm_n = pid % num_pid_gemm_n + + offs_gemm_i = pid_gemm_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_gemm_j = pid_gemm_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + offs_n = offs_gemm_i // (P * Q) + offs_k = offs_gemm_j + npq_residual = offs_gemm_i % (P * Q) + offs_p = npq_residual // Q + offs_q = npq_residual % Q + + input_offs_gemm_m = ( + offs_n * input_stride_n + offs_p * input_stride_h + offs_q * input_stride_w + )[:, None] + filter_offs_gemm_n = (offs_k * filter_stride_k)[None, :] + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for i in range(0, tl.cdiv(GEMM_K, BLOCK_SIZE_K)): + offs_gemm_k = i * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + offs_c = offs_gemm_k // (R * S) + crs_residual = offs_gemm_k % (R * S) + offs_r = crs_residual // S + offs_s = crs_residual % S + + input_offs_gemm_n = ( + offs_c * input_stride_c + offs_r * input_stride_h + offs_s * input_stride_w + )[None, :] + input_ptrs = input_ptr + input_offs_gemm_m + input_offs_gemm_n + input_mask = ((offs_n < N) & (offs_p < P) & (offs_q < Q))[:, None] & ( + (offs_c < C) & (offs_r < R) & (offs_s < S) + )[None, :] + + input = tl.load(input_ptrs, mask=input_mask) + + filter_offs_gemm_m = ( + offs_c * filter_stride_c + + offs_r * filter_stride_r + + offs_s * filter_stride_s + )[:, None] + filter_ptrs = filter_ptr + filter_offs_gemm_m + filter_offs_gemm_n + filter_mask = (offs_k[None, :] < K) & ( + (offs_c < C) & (offs_r < R) & (offs_s < S) + )[:, None] + + filter = tl.load(filter_ptrs, mask=filter_mask) + + accumulator = tl.dot(input, filter, accumulator) + + output = accumulator + + output_ptrs = ( + output_ptr + + ( + offs_n * output_stride_n + + offs_p * output_stride_p + + offs_q * output_stride_q + )[:, None] + + (offs_k * output_stride_k)[None, :] + ) + output_mask = ( + (offs_n[:, None] < N) + & (offs_k[None, :] < K) + & (offs_p[:, None] < P) + & (offs_q[:, None] < Q) + ) + + tl.store(output_ptrs, output, mask=output_mask) diff --git a/ops/triton/kernels/fused_rms_norm.py b/ops/triton/kernels/fused_rms_norm.py new file mode 100644 index 0000000..28a9616 --- /dev/null +++ b/ops/triton/kernels/fused_rms_norm.py @@ -0,0 +1,32 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel( + x_ptr, + w_ptr, + y_ptr, + num_cols, + x_stride, + w_stride, + y_stride, + eps: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < num_cols + + x_ptrs = x_ptr + row_idx * x_stride + col_offsets + w_ptrs = w_ptr + row_idx * w_stride + col_offsets + + x = tl.load(x_ptrs, mask=mask).to(tl.float32) + w = tl.load(w_ptrs, mask=mask).to(tl.float32) + + y = x * tl.rsqrt(tl.sum(x * x) / num_cols + eps) * w + + y_ptrs = y_ptr + row_idx * y_stride + col_offsets + + tl.store(y_ptrs, y, mask=mask) diff --git a/ops/triton/kernels/mm.py b/ops/triton/kernels/mm.py new file mode 100644 index 0000000..d2a4d55 --- /dev/null +++ b/ops/triton/kernels/mm.py @@ -0,0 +1,85 @@ +import itertools + +import triton +import triton.language as tl + + +@triton.autotune( + configs=tuple( + triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": 8, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m, block_size_n, block_size_k, num_stages, num_warps in itertools.product( + (32, 64, 128), (32, 64, 128, 256), (32, 64), (3, 4, 5), (2, 4, 8) + ) + ), + key=["m", "n", "k"], +) +@triton.jit +def kernel( + input_ptr, + other_ptr, + output_ptr, + m, + n, + k, + input_stride_m, + input_stride_k, + other_stride_k, + other_stride_n, + output_stride_m, + output_stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(m, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(n, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n + offs_k = tl.arange(0, BLOCK_SIZE_K) + input_ptrs = input_ptr + ( + offs_am[:, None] * input_stride_m + offs_k[None, :] * input_stride_k + ) + other_ptrs = other_ptr + ( + offs_k[:, None] * other_stride_k + offs_bn[None, :] * other_stride_n + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + input = tl.load(input_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K) + other = tl.load(other_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K) + + accumulator = tl.dot(input, other, accumulator) + + input_ptrs += BLOCK_SIZE_K * input_stride_k + other_ptrs += BLOCK_SIZE_K * other_stride_k + + output = accumulator + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + output_ptrs = ( + output_ptr + + output_stride_m * offs_cm[:, None] + + output_stride_n * offs_cn[None, :] + ) + + tl.store(output_ptrs, output, mask=(offs_cm[:, None] < m) & (offs_cn[None, :] < n)) diff --git a/ops/triton/kernels/rms_norm.py b/ops/triton/kernels/rms_norm.py new file mode 100644 index 0000000..852d06f --- /dev/null +++ b/ops/triton/kernels/rms_norm.py @@ -0,0 +1,27 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel( + input_ptr, + output_ptr, + num_cols, + input_stride, + output_stride, + eps: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = input_ptr + row_idx * input_stride + col_offsets + mask = col_offsets < num_cols + + input = tl.load(input_ptrs, mask=mask).to(tl.float32) + + output = input * tl.rsqrt(tl.sum(input * input) / num_cols + eps) + + output_ptrs = output_ptr + row_idx * output_stride + col_offsets + + tl.store(output_ptrs, output, mask=mask) diff --git a/ops/triton/kernels/rotary_position_embedding.py b/ops/triton/kernels/rotary_position_embedding.py new file mode 100644 index 0000000..2e2aa2d --- /dev/null +++ b/ops/triton/kernels/rotary_position_embedding.py @@ -0,0 +1,53 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel( + tensor_ptr, + sin_table_ptr, + cos_table_ptr, + tensor_stride_n, + tensor_stride_l, + tensor_stride_h, + tensor_stride_e, + sin_table_stride_l, + cos_table_stride_l, + emb_dim, + INTERLEAVED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + off_n = tl.program_id(0) + off_l = tl.program_id(1) + off_h = tl.program_id(2) + + offs = tl.arange(0, BLOCK_SIZE) + + half_emb_dim = emb_dim // 2 + mask = offs < half_emb_dim + + sin_table = tl.load(sin_table_ptr + off_l * sin_table_stride_l + offs, mask=mask) + cos_table = tl.load(cos_table_ptr + off_l * cos_table_stride_l + offs, mask=mask) + + even_offs = ( + off_n * tensor_stride_n + off_l * tensor_stride_l + off_h * tensor_stride_h + ) + odd_offs = ( + off_n * tensor_stride_n + off_l * tensor_stride_l + off_h * tensor_stride_h + ) + + if INTERLEAVED: + even_offs += (2 * offs) * tensor_stride_e + odd_offs += (2 * offs + 1) * tensor_stride_e + else: + even_offs += offs * tensor_stride_e + odd_offs += (offs + half_emb_dim) * tensor_stride_e + + even_ptrs = tensor_ptr + even_offs + odd_ptrs = tensor_ptr + odd_offs + + even = tl.load(even_ptrs, mask=mask) + odd = tl.load(odd_ptrs, mask=mask) + + tl.store(even_ptrs, even * cos_table - odd * sin_table, mask=mask) + tl.store(odd_ptrs, even * sin_table + odd * cos_table, mask=mask) diff --git a/ops/triton/kernels/scaled_dot_product_attention.py b/ops/triton/kernels/scaled_dot_product_attention.py new file mode 100644 index 0000000..79e9970 --- /dev/null +++ b/ops/triton/kernels/scaled_dot_product_attention.py @@ -0,0 +1,120 @@ +import itertools + +import triton +import triton.language as tl + + +@triton.autotune( + configs=tuple( + triton.Config( + {"BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n}, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m, block_size_n, num_stages, num_warps in itertools.product( + (32, 64, 128, 256), (32, 64, 128), (2, 3, 4, 5), (4, 8) + ) + ), + key=["EMB_DIM"], +) +@triton.jit +def kernel( + q_ptr, + k_ptr, + v_ptr, + o_ptr, + q_stride_z, + q_stride_h, + q_stride_m, + q_stride_k, + k_stride_z, + k_stride_h, + k_stride_n, + k_stride_k, + v_stride_z, + v_stride_h, + v_stride_k, + v_stride_n, + o_stride_z, + o_stride_h, + o_stride_m, + o_stride_n, + scale, + seq_len_q, + seq_len_k_v, + EMB_DIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + off_m = tl.program_id(0) + off_h = tl.program_id(1) + off_z = tl.program_id(2) + + offs_m_start = off_m * BLOCK_SIZE_M + + q_off = off_z * q_stride_z + off_h * q_stride_h + q_block_ptr = tl.make_block_ptr( + base=q_ptr + q_off, + shape=(seq_len_q, EMB_DIM), + strides=(q_stride_m, q_stride_k), + offsets=(offs_m_start, 0), + block_shape=(BLOCK_SIZE_M, EMB_DIM), + order=(1, 0), + ) + k_off = off_z * k_stride_z + off_h * k_stride_h + k_block_ptr = tl.make_block_ptr( + base=k_ptr + k_off, + shape=(EMB_DIM, seq_len_k_v), + strides=(k_stride_k, k_stride_n), + offsets=(0, 0), + block_shape=(EMB_DIM, BLOCK_SIZE_N), + order=(0, 1), + ) + v_off = off_z * v_stride_z + off_h * v_stride_h + v_block_ptr = tl.make_block_ptr( + base=v_ptr + v_off, + shape=(seq_len_k_v, EMB_DIM), + strides=(v_stride_k, v_stride_n), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_N, EMB_DIM), + order=(1, 0), + ) + o_off = off_z * o_stride_z + off_h * o_stride_h + o_block_ptr = tl.make_block_ptr( + base=o_ptr + o_off, + shape=(seq_len_q, EMB_DIM), + strides=(o_stride_m, o_stride_n), + offsets=(offs_m_start, 0), + block_shape=(BLOCK_SIZE_M, EMB_DIM), + order=(1, 0), + ) + + q = tl.load(q_block_ptr, boundary_check=(0, 1)) + q = (q * scale * 1.44269504089).to(q_block_ptr.type.element_ty) + + acc = tl.zeros((BLOCK_SIZE_M, EMB_DIM), dtype=tl.float32) + l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32) + m_i = tl.full((BLOCK_SIZE_M,), float("-inf"), dtype=tl.float32) + + for i in range(0, tl.cdiv(seq_len_k_v, BLOCK_SIZE_N)): + k = tl.load(k_block_ptr, boundary_check=(0, 1)) + + mask = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) < seq_len_k_v + qk = tl.where(mask, tl.dot(q, k), float("-inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + v = tl.load(v_block_ptr, boundary_check=(0, 1)) + alpha = tl.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + tl.dot(p.to(v_block_ptr.type.element_ty), v) + m_i = m_ij + l_i = l_i * alpha + l_ij + + v_block_ptr = tl.advance(v_block_ptr, (BLOCK_SIZE_N, 0)) + k_block_ptr = tl.advance(k_block_ptr, (0, BLOCK_SIZE_N)) + + acc /= l_i[:, None] + + tl.store(o_block_ptr, acc.to(o_ptr.type.element_ty), boundary_check=(0, 1)) diff --git a/ops/triton/kernels/silu.py b/ops/triton/kernels/silu.py new file mode 100644 index 0000000..b852a4c --- /dev/null +++ b/ops/triton/kernels/silu.py @@ -0,0 +1,16 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel(input_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < num_elements + + input = tl.load(input_ptr + offsets, mask=mask) + output = input * tl.sigmoid(tl.cast(input, tl.float32)) + + tl.store(output_ptr + offsets, output, mask=mask) diff --git a/ops/triton/kernels/softmax.py b/ops/triton/kernels/softmax.py new file mode 100644 index 0000000..1514907 --- /dev/null +++ b/ops/triton/kernels/softmax.py @@ -0,0 +1,31 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel( + input_ptr, + output_ptr, + input_stride, + output_stride, + num_cols, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + + row_start_ptr = input_ptr + row_idx * input_stride + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + mask = col_offsets < num_cols + + row = tl.load(input_ptrs, mask=mask, other=float("-inf")) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + output_row_start_ptr = output_ptr + row_idx * output_stride + output_ptrs = output_row_start_ptr + col_offsets + + tl.store(output_ptrs, softmax_output, mask=mask) diff --git a/ops/triton/kernels/swiglu.py b/ops/triton/kernels/swiglu.py new file mode 100644 index 0000000..a154af4 --- /dev/null +++ b/ops/triton/kernels/swiglu.py @@ -0,0 +1,18 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel(a_ptr, b_ptr, c_ptr, num_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < num_elements + + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + + silu_b = b * tl.sigmoid(tl.cast(b, tl.float32)) + c = a * silu_b + + tl.store(c_ptr + offsets, c, mask=mask) diff --git a/ops/triton/torch.py b/ops/triton/torch.py new file mode 100644 index 0000000..6d40df3 --- /dev/null +++ b/ops/triton/torch.py @@ -0,0 +1,300 @@ +import math + +import torch +import triton + +import ops.triton.kernels.add +import ops.triton.kernels.addmm +import ops.triton.kernels.bmm +import ops.triton.kernels.conv2d +import ops.triton.kernels.fused_rms_norm +import ops.triton.kernels.mm +import ops.triton.kernels.rms_norm +import ops.triton.kernels.rotary_position_embedding +import ops.triton.kernels.scaled_dot_product_attention +import ops.triton.kernels.silu +import ops.triton.kernels.softmax +import ops.triton.kernels.swiglu + + +def add(input, other): + num_elements = input.numel() + + output = torch.empty_like(input) + + def grid(meta): + return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),) + + ops.triton.kernels.add.kernel[grid]( + input, other, output, num_elements, BLOCK_SIZE=1024 + ) + + return output + + +def addmm(input, mat1, mat2, beta=1, alpha=1): + output_shape = (mat1.shape[0], mat2.shape[1]) + output = torch.empty(output_shape, dtype=mat1.dtype, device=mat1.device) + + def grid(meta): + return ( + triton.cdiv(mat1.shape[0], meta["BLOCK_SIZE_M"]) + * triton.cdiv(mat2.shape[1], meta["BLOCK_SIZE_N"]), + ) + + ops.triton.kernels.addmm.kernel[grid]( + input, + mat1, + mat2, + output, + mat1.shape[0], + mat2.shape[1], + mat1.shape[1], + input.stride(0), + input.stride(1), + mat1.stride(0), + mat1.stride(1), + mat2.stride(0), + mat2.stride(1), + output.stride(0), + output.stride(1), + beta, + alpha, + ) + + return output + + +def bmm(input, other): + batch, m, k = input.shape + _, _, n = other.shape + output = torch.empty((batch, m, n), dtype=input.dtype, device=input.device) + + def grid(meta): + return ( + batch + * triton.cdiv(m, meta["BLOCK_SIZE_M"]) + * triton.cdiv(n, meta["BLOCK_SIZE_N"]), + ) + + ops.triton.kernels.bmm.kernel[grid]( + input, + other, + output, + m, + n, + k, + *input.stride(), + *other.stride(), + *output.stride(), + ) + + return output + + +def conv2d(input, filter): + n, c, h, w = input.shape + k, _, r, s = filter.shape + p = h - r + 1 + q = w - s + 1 + + output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device) + + def grid(meta): + return ( + triton.cdiv(n * p * q, meta["BLOCK_SIZE_M"]) + * triton.cdiv(k, meta["BLOCK_SIZE_N"]), + ) + + ops.triton.kernels.conv2d.kernel[grid]( + input, + filter, + output, + n, + c, + h, + w, + k, + r, + s, + *input.stride(), + *filter.stride(), + *output.stride(), + ) + + return output + + +def fused_rms_norm(x, w, eps=None): + if eps is None: + eps = torch.finfo(x.dtype).eps + + x_2d = x.view(-1, x.shape[-1]) + w_2d = w.expand_as(x_2d) + y_2d = torch.empty_like(x_2d) + + ops.triton.kernels.fused_rms_norm.kernel[(x_2d.shape[-2],)]( + x_2d, + w_2d, + y_2d, + x_2d.shape[-1], + x_2d.stride(0), + w_2d.stride(0), + y_2d.stride(0), + eps, + BLOCK_SIZE=triton.next_power_of_2(x_2d.shape[-1]), + ) + + return y_2d.view(x.shape) + + +def mm(input, other): + output_shape = (input.shape[0], other.shape[1]) + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + + def grid(meta): + return ( + triton.cdiv(input.shape[0], meta["BLOCK_SIZE_M"]) + * triton.cdiv(other.shape[1], meta["BLOCK_SIZE_N"]), + ) + + ops.triton.kernels.mm.kernel[grid]( + input, + other, + output, + input.shape[0], + other.shape[1], + input.shape[1], + input.stride(0), + input.stride(1), + other.stride(0), + other.stride(1), + output.stride(0), + output.stride(1), + ) + + return output + + +def rms_norm(input, eps=None): + if eps is None: + eps = torch.finfo(input.dtype).eps + + output = torch.empty_like(input) + + ops.triton.kernels.rms_norm.kernel[(input.shape[-2],)]( + input, + output, + input.shape[-1], + input.stride(-2), + output.stride(-2), + eps, + BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]), + ) + + return output + + +def rotary_position_embedding(input, sin_table, cos_table, interleaved=True): + batch_size, seq_len, num_heads, emb_dim = input.shape + + BLOCK_SIZE = triton.next_power_of_2(emb_dim // 2) + + output = input.clone() + + ops.triton.kernels.rotary_position_embedding.kernel[ + (batch_size, seq_len, num_heads) + ]( + output, + sin_table, + cos_table, + *input.stride(), + sin_table.stride(0), + cos_table.stride(0), + emb_dim, + INTERLEAVED=interleaved, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return output + + +def scaled_dot_product_attention(q, k, v, scale=None): + batch_size, num_heads, seq_len_q, emb_dim = q.shape + _, _, seq_len_k_v, _ = k.shape + + if scale is None: + scale = 1 / math.sqrt(emb_dim) + + o = torch.empty_like(q) + + def grid(meta): + return ( + triton.cdiv(seq_len_q, meta["BLOCK_SIZE_M"]), + num_heads, + batch_size, + ) + + ops.triton.kernels.scaled_dot_product_attention.kernel[grid]( + q, + k, + v, + o, + *q.stride(), + *k.stride(), + *v.stride(), + *o.stride(), + scale=scale, + seq_len_q=seq_len_q, + seq_len_k_v=seq_len_k_v, + EMB_DIM=emb_dim, + ) + + return o + + +def silu(input): + input_flat = input.flatten() + output_flat = torch.empty_like(input_flat) + num_elements = input_flat.numel() + + def grid(meta): + return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),) + + ops.triton.kernels.silu.kernel[grid]( + input_flat, output_flat, num_elements, BLOCK_SIZE=1024 + ) + + return output_flat.view_as(input) + + +def softmax(input): + output = torch.empty_like(input) + + ops.triton.kernels.softmax.kernel[(input.shape[0],)]( + input, + output, + input.stride(0), + output.stride(0), + input.shape[1], + BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]), + ) + + return output + + +def swiglu(a, b): + a_flat = a.flatten() + b_flat = b.flatten() + c_flat = torch.empty_like(a_flat) + + num_elements = a_flat.numel() + + def grid(meta): + return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),) + + ops.triton.kernels.swiglu.kernel[grid]( + a_flat, b_flat, c_flat, num_elements, BLOCK_SIZE=1024 + ) + + return c_flat.view_as(a) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8844f71 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +ninetoothed +torch +matplotlib +pandas +transformers +radon diff --git a/rms_norm.py b/rms_norm.py index e0afb48..c06ee7f 100644 --- a/rms_norm.py +++ b/rms_norm.py @@ -1,41 +1,68 @@ -import ninetoothed -import ninetoothed.language as ntl import torch -import torch.nn as nn -from ninetoothed import Symbol, Tensor +import torch.nn.functional as F +import triton -BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) +import ops.ninetoothed.torch +import ops.triton.torch +if __name__ == "__main__": + torch.manual_seed(0) -@ninetoothed.jit -def fused_rms_norm_kernel( - x: Tensor(2).tile((1, BLOCK_SIZE)), - w: Tensor(2).tile((1, BLOCK_SIZE)), - y: Tensor(2).tile((1, BLOCK_SIZE)), - eps: Tensor(0), -): - x_fp32 = ntl.cast(x, ntl.float32) - y = x_fp32 * ntl.rsqrt(ntl.sum(x_fp32 * x_fp32) / x.shape[-1] + eps) * w # noqa: F841 + dtype = torch.float16 + device = "cuda" + input = torch.randn(1151, 8192, dtype=dtype, device=device) -def fused_rms_norm(x, w, eps=None): - if eps is None: - eps = torch.finfo(x.dtype).eps() + ninetoothed_output = ops.ninetoothed.torch.rms_norm(input) + torch_output = F.rms_norm(input, input.shape[-1:]) + triton_output = ops.triton.torch.rms_norm(input) - x_2d = x.view(-1, x.shape[-1]) - w_2d = w.expand_as(x_2d) - y_2d = torch.empty_like(x_2d) + print(ninetoothed_output) + print(torch_output) + print(triton_output) - fused_rms_norm_kernel(x_2d, w_2d, y_2d, eps, BLOCK_SIZE=x.shape[-1]) + if torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") - return y_2d.view(x.shape) + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["n"], + x_vals=[2**i for i in range(5, 15)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="rms-norm-performance", + args={"m": 4096}, + ) + ) + def benchmark(m, n, provider): + input = torch.randn(m, n, dtype=dtype, device=device) + ninetoothed_output = ops.ninetoothed.torch.rms_norm(input) + torch_output = F.rms_norm(input, input.shape[-1:]) + triton_output = ops.triton.torch.rms_norm(input) -class RMSNorm(nn.Module): - def __init__(self, other): - super().__init__() + assert torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) - self.__dict__ = other.__dict__ + if provider == "ninetoothed": + ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.rms_norm(input)) + elif provider == "torch": + ms = triton.testing.do_bench( + lambda: torch.rms_norm(input, input.shape[-1:]) + ) + elif provider == "triton": + ms = triton.testing.do_bench(lambda: ops.triton.torch.rms_norm(input)) - def forward(self, x): - return fused_rms_norm(x, self.weight, self.variance_epsilon) + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/rope.py b/rope.py deleted file mode 100644 index f660a13..0000000 --- a/rope.py +++ /dev/null @@ -1,234 +0,0 @@ -import ninetoothed -import torch -import triton -import triton.language as tl -from ninetoothed import Tensor - - -def arrangement(tensor, sin_table, cos_table, interleaved=True): - emb_dim = tensor.shape[-1] - tile_shape = (1, 1, 1, emb_dim // 2) - - if interleaved: - strides = (-1, -1, -1, 1) - dilation = (1, 1, 1, 2) - else: - strides = None - dilation = None - - tensor_arranged = tensor.tile(tile_shape, strides=strides, dilation=dilation) - tensor_arranged = tensor_arranged.tile((1, 1, 1, 2)) - tensor_arranged.dtype = tensor_arranged.dtype.squeeze((0, 1, 2)) - tensor_arranged.dtype.dtype = tensor_arranged.dtype.dtype.squeeze((0, 1, 2)) - - sin_table_arranged = sin_table.tile(tile_shape) - sin_table_arranged.dtype = sin_table_arranged.dtype.squeeze((0, 1, 2)) - - cos_table_arranged = cos_table.tile(tile_shape) - cos_table_arranged.dtype = cos_table_arranged.dtype.squeeze((0, 1, 2)) - - return tensor_arranged, sin_table_arranged, cos_table_arranged - - -def application(tensor, sin_table, cos_table): - sin_table_loaded = sin_table - cos_table_loaded = cos_table - - tensor_0 = tensor[0] - tensor_1 = tensor[1] - - tensor[0] = tensor_0 * cos_table_loaded - tensor_1 * sin_table_loaded - tensor[1] = tensor_0 * sin_table_loaded + tensor_1 * cos_table_loaded - - -tensors = tuple(Tensor(4, shape_options={"constexpr": True}) for _ in range(3)) -rope_kernel = ninetoothed.make(arrangement, application, tensors) - - -def rope(tensor, sin_table, cos_table): - batch_size, _, num_heads, _ = tensor.shape - - sin_table = sin_table.unsqueeze(1).unsqueeze(0) - sin_table = sin_table.expand(batch_size, -1, num_heads, -1) - cos_table = cos_table.unsqueeze(1).unsqueeze(0) - cos_table = cos_table.expand(batch_size, -1, num_heads, -1) - - tensor_cloned = tensor.clone() - rope_kernel(tensor_cloned, sin_table, cos_table) - - return tensor_cloned - - -@triton.jit -def triton_rope_kernel( - tensor_ptr, - sin_table_ptr, - cos_table_ptr, - tensor_stride_n, - tensor_stride_l, - tensor_stride_h, - tensor_stride_e, - sin_table_stride_l, - cos_table_stride_l, - emb_dim, - BLOCK_SIZE: tl.constexpr, -): - off_n = tl.program_id(0) - off_l = tl.program_id(1) - off_h = tl.program_id(2) - - offs = tl.arange(0, BLOCK_SIZE) - - half_emb_dim = emb_dim // 2 - mask = offs < half_emb_dim - - sin_table = tl.load(sin_table_ptr + off_l * sin_table_stride_l + offs, mask=mask) - cos_table = tl.load(cos_table_ptr + off_l * cos_table_stride_l + offs, mask=mask) - - even_offs = ( - off_n * tensor_stride_n - + off_l * tensor_stride_l - + off_h * tensor_stride_h - + (2 * offs) * tensor_stride_e - ) - odd_offs = ( - off_n * tensor_stride_n - + off_l * tensor_stride_l - + off_h * tensor_stride_h - + (2 * offs + 1) * tensor_stride_e - ) - - even_ptrs = tensor_ptr + even_offs - odd_ptrs = tensor_ptr + odd_offs - - even = tl.load(even_ptrs, mask=mask) - odd = tl.load(odd_ptrs, mask=mask) - - tl.store(even_ptrs, even * cos_table - odd * sin_table, mask=mask) - tl.store(odd_ptrs, even * sin_table + odd * cos_table, mask=mask) - - -def triton_rope( - tensor: torch.Tensor, sin_table: torch.Tensor, cos_table: torch.Tensor -) -> torch.Tensor: - batch_size, seq_len, num_heads, emb_dim = tensor.shape - - assert emb_dim % 2 == 0, "The embedding dimension must be even." - - BLOCK_SIZE = triton.next_power_of_2(emb_dim // 2) - if BLOCK_SIZE > 1024: - BLOCK_SIZE = 1024 - - grid = (batch_size, seq_len, num_heads) - - tensor_cloned = tensor.clone() - - triton_rope_kernel[grid]( - tensor_cloned, - sin_table, - cos_table, - *tensor.stride(), - sin_table.stride(0), - cos_table.stride(0), - emb_dim, - BLOCK_SIZE=BLOCK_SIZE, - ) - - return tensor_cloned - - -def torch_rope(input, sin_table, cos_table): - batch_size, seq_len, num_heads, emb_dim = input.shape - - assert emb_dim % 2 == 0, "The embedding dimension must be even." - - pair_wise_input = input.view(batch_size, seq_len, num_heads, emb_dim // 2, 2) - sin_table = sin_table[None, :, None, :] - cos_table = cos_table[None, :, None, :] - - pair_0, pair_1 = pair_wise_input[..., 0], pair_wise_input[..., 1] - rotated_pair_0 = pair_0 * cos_table - pair_1 * sin_table - rotated_pair_1 = pair_0 * sin_table + pair_1 * cos_table - - output = torch.stack((rotated_pair_0, rotated_pair_1), dim=-1).view(input.shape) - - return output - - -def _generate_sin_and_cos_tables( - seq_len, emb_dim, base=10000, dtype=torch.float32, device="cuda" -): - assert emb_dim % 2 == 0, "The embedding dimension must be even." - - theta = base ** ( - -2 * (torch.arange(emb_dim // 2, dtype=dtype, device=device) / emb_dim) - ) - - positions = torch.arange(seq_len, dtype=dtype, device=device).unsqueeze(1) - sin_table = torch.sin(positions * theta) - cos_table = torch.cos(positions * theta) - - return sin_table, cos_table - - -if __name__ == "__main__": - torch.manual_seed(0) - batch_size, seq_len, num_heads, emb_dim = 4, 128, 8, 64 - sin_table, cos_table = _generate_sin_and_cos_tables(seq_len, emb_dim) - dtype = torch.float32 - device = "cuda" - x = torch.randn(batch_size, seq_len, num_heads, emb_dim, dtype=dtype, device=device) - ninetoothed_output = rope(x, sin_table, cos_table) - torch_output = torch_rope(x, sin_table, cos_table) - triton_output = triton_rope(x, sin_table, cos_table) - print(ninetoothed_output) - print(torch_output) - print(triton_output) - if torch.allclose(ninetoothed_output, torch_output, atol=0.001): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["seq_len"], - x_vals=[2**i for i in range(5, 15)], - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="GB/s", - plot_name="rope-performance", - args={}, - ) - ) - def benchmark(seq_len, provider): - batch_size, num_heads, emb_dim = 4, 32, 64 - shape = (batch_size, seq_len, num_heads, emb_dim) - dtype = torch.float16 - device = "cuda" - - x = torch.randn(shape, dtype=dtype, device=device) - sin_table, cos_table = _generate_sin_and_cos_tables(seq_len, emb_dim) - - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: rope(x, sin_table, cos_table)) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: torch_rope(x, sin_table, cos_table)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_rope(x, sin_table, cos_table)) - - def gbps(ms): - x_bytes = x.numel() * x.element_size() - sin_table_bytes = sin_table.numel() * sin_table.element_size() - cos_table_bytes = cos_table.numel() * cos_table.element_size() - - return (x_bytes + sin_table_bytes + cos_table_bytes) / ms * 1e-6 - - return gbps(ms) - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/rotary_position_embedding.py b/rotary_position_embedding.py new file mode 100644 index 0000000..02719f6 --- /dev/null +++ b/rotary_position_embedding.py @@ -0,0 +1,122 @@ +import torch +import triton + +import ops.ninetoothed.torch +import ops.triton.torch + + +def torch_rotary_position_embedding(input, sin_table, cos_table, interleaved=True): + batch_size, seq_len, num_heads, emb_dim = input.shape + + assert emb_dim % 2 == 0, "The embedding dimension must be even." + + sin_table = sin_table[None, :, None, :] + cos_table = cos_table[None, :, None, :] + + if interleaved: + pair_wise_input = input.view(batch_size, seq_len, num_heads, emb_dim // 2, 2) + input_0, input_1 = pair_wise_input[..., 0], pair_wise_input[..., 1] + input_0_rotated = input_0 * cos_table - input_1 * sin_table + input_1_rotated = input_0 * sin_table + input_1 * cos_table + + return torch.stack((input_0_rotated, input_1_rotated), dim=-1).view(input.shape) + else: + input_0 = x[..., : x.shape[-1] // 2] + input_1 = x[..., x.shape[-1] // 2 :] + input_0_rotated = input_0 * cos_table - input_1 * sin_table + input_1_rotated = input_0 * sin_table + input_1 * cos_table + + return torch.cat((input_0_rotated, input_1_rotated), dim=-1) + + +def _generate_sin_and_cos_tables( + seq_len, emb_dim, base=10000, dtype=torch.float32, device="cuda" +): + assert emb_dim % 2 == 0, "The embedding dimension must be even." + + theta = base ** ( + -2 * (torch.arange(emb_dim // 2, dtype=dtype, device=device) / emb_dim) + ) + + positions = torch.arange(seq_len, dtype=dtype, device=device).unsqueeze(1) + sin_table = torch.sin(positions * theta) + cos_table = torch.cos(positions * theta) + + return sin_table, cos_table + + +if __name__ == "__main__": + torch.manual_seed(0) + + batch_size, seq_len, num_heads, emb_dim = 4, 128, 8, 64 + dtype = torch.float32 + device = "cuda" + + sin_table, cos_table = _generate_sin_and_cos_tables(seq_len, emb_dim) + x = torch.randn(batch_size, seq_len, num_heads, emb_dim, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.rotary_position_embedding( + x, sin_table, cos_table, interleaved=False + ) + torch_output = torch_rotary_position_embedding( + x, sin_table, cos_table, interleaved=False + ) + triton_output = ops.triton.torch.rotary_position_embedding( + x, sin_table, cos_table, interleaved=False + ) + + print(ninetoothed_output) + print(torch_output) + print(triton_output) + + if torch.allclose(ninetoothed_output, torch_output, atol=0.001): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["seq_len"], + x_vals=[2**i for i in range(5, 15)], + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="rotary-position-embedding-performance", + args={}, + ) + ) + def benchmark(seq_len, provider): + batch_size, num_heads, emb_dim = 4, 32, 64 + shape = (batch_size, seq_len, num_heads, emb_dim) + dtype = torch.float16 + device = "cuda" + + sin_table, cos_table = _generate_sin_and_cos_tables(seq_len, emb_dim) + x = torch.randn(shape, dtype=dtype, device=device) + + if provider == "ninetoothed": + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.rotary_position_embedding( + x, sin_table, cos_table + ) + ) + elif provider == "torch": + ms = triton.testing.do_bench( + lambda: torch_rotary_position_embedding(x, sin_table, cos_table) + ) + elif provider == "triton": + ms = triton.testing.do_bench( + lambda: ops.triton.torch.rotary_position_embedding( + x, sin_table, cos_table + ) + ) + + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/run_experiments.py b/run_experiments.py new file mode 100644 index 0000000..1b0b68f --- /dev/null +++ b/run_experiments.py @@ -0,0 +1,192 @@ +import argparse +import functools +import random +import subprocess + +import pandas as pd +import torch +import torch.nn.functional +import triton + +import ops.ninetoothed.torch +import ops.triton.torch +import rotary_position_embedding + +PROMPTS = ( + "The emergence of deep learning domain-specific languages (DSLs) has substantially reduced the obstacles in developing high-performance, cross-platform compute kernels, but current DSLs", + "Driven by recent advancements in the AI industry, the AI accelerator sector has increasingly diversified, with vendors developing their own hardware architectures and programming models, such as NVIDIA", +) + +NUM_WARMUP_ITERATIONS = 1 + +NUM_PROFILING_ITERATIONS = 3 + +BACKENDS = ("ninetoothed", "triton", "torch") + +ALL_MAX_NEW_TOKENS = (128, 512, 2048) + + +def _run_task(op_name, dtype, device, *arg_shapes, **kwarg_shapes): + ninetoothed_op = getattr(ops.ninetoothed.torch, op_name) + triton_op = getattr(ops.triton.torch, op_name) + + if op_name == "rotary_position_embedding": + torch_op = rotary_position_embedding.torch_rotary_position_embedding + else: + torch_op = ( + getattr(torch, op_name) + if hasattr(torch, op_name) + else getattr(torch.nn.functional, op_name) + ) + + if op_name == "rms_norm": + torch_op = functools.partial(torch_op, normalized_shape=arg_shapes[0][-1:]) + elif op_name == "softmax": + torch_op = functools.partial(torch_op, dim=-1) + + args = tuple( + torch.randn(shape, dtype=dtype, device=device) if shape else random.gauss(0, 1) + for shape in arg_shapes + ) + kwargs = { + key: torch.randn(shape, dtype=dtype, device=device) + if shape + else random.gauss(0, 1) + for key, shape in kwarg_shapes.items() + } + + arg_shape_string = ", ".join(str(shape) for shape in arg_shapes) + kwarg_shape_string = ", ".join( + f"{key}={shape}" for key, shape in kwarg_shapes.items() + ) + shape_string = ( + f"{arg_shape_string}, {kwarg_shape_string}" + if kwarg_shape_string + else arg_shape_string + ) + + task_description = f"{op_name}({shape_string})" + + return task_description, _benchmark_ops( + (ninetoothed_op, triton_op, torch_op), *args, **kwargs + ) + + +def _benchmark_ops(ops, *args, **kwargs): + assert all( + torch.allclose( + op(*args, **kwargs), ops[0](*args, **kwargs), rtol=0.01, atol=0.01 + ) + for op in ops[1:] + ) + + return tuple(triton.testing.do_bench(lambda: op(*args, **kwargs)) for op in ops) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run experiments.") + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the model or model identifier from Hugging Face.", + ) + + args = parser.parse_args() + + model_name_or_path = args.model + + random.seed(0) + torch.manual_seed(0) + + with open("torch.utils.collect_env.log", "w") as f: + subprocess.run( + ("python", "-m", "torch.utils.collect_env"), stdout=f, stderr=f, check=True + ) + + radon_commands = ( + ( + "radon", + "cc", + "--show-complexity", + "--json", + "--output-file", + "cc.json", + "ops/", + ), + ("radon", "mi", "--show", "--json", "--output-file", "mi.json", "ops/"), + ("radon", "raw", "--json", "--output-file", "raw.json", "ops/"), + ("radon", "hal", "--json", "--output-file", "hal.json", "ops/"), + ) + + for command in radon_commands: + subprocess.run(command, check=True) + + with open("code_evaluation.tex", "w") as f: + subprocess.run(("python", "evaluate_code.py"), stdout=f, check=True) + + dtype = torch.float16 + device = "cuda" + + tasks = ( + ("add", ((4096 * 4096,), (4096 * 4096,)), {}), + ( + "addmm", + ((4096, 4096), (4096, 4096), (4096, 4096)), + {"beta": (), "alpha": ()}, + ), + ("bmm", ((4, 2048, 2048), (4, 2048, 2048)), {}), + ("conv2d", ((4, 512, 14, 14), (512, 512, 3, 3)), {}), + ("mm", ((4096, 4096), (4096, 4096)), {}), + ("rms_norm", ((4096, 4096),), {}), + ("rotary_position_embedding", ((4, 1024, 48, 64), (1024, 32), (1024, 32)), {}), + ( + "scaled_dot_product_attention", + ((4, 48, 1024, 64), (4, 48, 1024, 64), (4, 48, 1024, 64)), + {}, + ), + ("silu", ((4096 * 4096,),), {}), + ("softmax", ((4096, 4096),), {}), + ) + + data = {"Task": [], "NineToothed": [], "Triton": [], "PyTorch": []} + + for name, args, kwargs in tasks: + description, results = _run_task(name, dtype, device, *args, **kwargs) + + data["Task"].append(description) + + for i, provider in enumerate(("NineToothed", "Triton", "PyTorch")): + data[provider].append(results[i]) + + pd.DataFrame(data).set_index("Task").to_csv("microbenchmark_data.csv") + + for max_new_tokens in ALL_MAX_NEW_TOKENS: + for backend in BACKENDS: + with open(f"infer_{max_new_tokens}_{backend}.json", "w") as f: + subprocess.run( + ( + "python", + "infer.py", + "--model", + model_name_or_path, + "--prompts", + *PROMPTS, + "--max-new-tokens", + str(max_new_tokens), + "--device", + "cuda", + "--backend", + "ninetoothed", + "--num-warmup-iterations", + str(NUM_WARMUP_ITERATIONS), + "--num-profiling-iterations", + str(NUM_PROFILING_ITERATIONS), + ), + stdout=f, + check=True, + ) + + with open("performance_evaluation.tex", "w") as f: + subprocess.run(("python", "evaluate_performance.py"), stdout=f, check=True) diff --git a/scaled_dot_product_attention.py b/scaled_dot_product_attention.py new file mode 100644 index 0000000..11f0ec0 --- /dev/null +++ b/scaled_dot_product_attention.py @@ -0,0 +1,196 @@ +from contextlib import contextmanager + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +from transformers.models.llama.modeling_llama import repeat_kv + +import ops.ninetoothed.torch +import ops.triton.torch +from rotary_position_embedding import torch_rotary_position_embedding + + +class Attention(nn.Module): + scaled_dot_product_attention = None + + rotary_position_embedding = None + + def __init__(self, other): + super().__init__() + + self.__dict__ = other.__dict__ + + def forward( + self, + hidden_states, + position_embeddings, + attention_mask, + past_key_value, + cache_position, + **kwargs, + ): + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + cos_table, sin_table = position_embeddings + sin_table = sin_table[0] + cos_table = cos_table[0] + + query_states = type(self).rotary_position_embedding( + query_states, sin_table, cos_table + ) + key_states = type(self).rotary_position_embedding( + key_states, sin_table, cos_table + ) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin_table, + "cos": cos_table, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = type(self).scaled_dot_product_attention( + query_states, key_states, value_states, scale=self.scaling + ) + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +@contextmanager +def scaled_dot_product_attention_backend(backend_name): + _prev_impl = Attention.scaled_dot_product_attention + + if backend_name == "ninetoothed": + impl = ops.ninetoothed.torch.scaled_dot_product_attention + elif backend_name == "triton": + impl = ops.triton.torch.scaled_dot_product_attention + elif backend_name == "torch": + impl = F.scaled_dot_product_attention + else: + raise ValueError(f"unknown backend: `{backend_name}`") + + Attention.scaled_dot_product_attention = impl + + try: + yield + finally: + Attention.scaled_dot_product_attention = _prev_impl + + +@contextmanager +def rotary_position_embedding_backend(backend_name): + _prev_impl = Attention.rotary_position_embedding + + if backend_name == "ninetoothed": + impl = ops.ninetoothed.torch.rotary_position_embedding + elif backend_name == "triton": + impl = ops.triton.torch.rotary_position_embedding + elif backend_name == "torch": + impl = torch_rotary_position_embedding + else: + raise ValueError(f"unknown backend: `{backend_name}`") + + Attention.rotary_position_embedding = impl + + try: + yield + finally: + Attention.rotary_position_embedding = _prev_impl + + +if __name__ == "__main__": + torch.manual_seed(0) + + q_o_shape = (2, 8, 1024, 64) + k_v_shape = (2, 8, 1024, 64) + dtype = torch.float16 + device = "cuda" + + q = torch.randn(q_o_shape, dtype=dtype, device=device) + k = torch.randn(k_v_shape, dtype=dtype, device=device) + v = torch.randn(k_v_shape, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.scaled_dot_product_attention(q, k, v) + torch_output = F.scaled_dot_product_attention(q, k, v) + triton_output = ops.triton.torch.scaled_dot_product_attention(q, k, v) + + print(ninetoothed_output) + print(torch_output) + print(triton_output) + + if torch.allclose(ninetoothed_output, torch_output, atol=0.01): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=1e-3, rtol=0): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["seq_len"], + x_vals=[2**i for i in range(7, 17)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="scaled-dot-product-attention-performance", + args={}, + ) + ) + def benchmark(seq_len, provider): + batch_size, num_heads, emb_dim = 4, 32, 64 + shape = (batch_size, num_heads, seq_len, emb_dim) + dtype = torch.float16 + device = "cuda" + + q = torch.randn(shape, dtype=dtype, device=device) + k = torch.randn(shape, dtype=dtype, device=device) + v = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.scaled_dot_product_attention(q, k, v) + torch_output = F.scaled_dot_product_attention(q, k, v) + triton_output = ops.triton.torch.scaled_dot_product_attention(q, k, v) + + assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) + assert torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.001) + + if provider == "ninetoothed": + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.scaled_dot_product_attention(q, k, v) + ) + elif provider == "torch": + ms = triton.testing.do_bench( + lambda: F.scaled_dot_product_attention(q, k, v) + ) + elif provider == "triton": + ms = triton.testing.do_bench( + lambda: ops.triton.torch.scaled_dot_product_attention(q, k, v) + ) + + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/silu.py b/silu.py index 00d6006..fafae5d 100644 --- a/silu.py +++ b/silu.py @@ -1,67 +1,104 @@ -import ninetoothed -import ninetoothed.language as ntl +from contextlib import contextmanager + import torch import torch.nn as nn import torch.nn.functional as F -from ninetoothed import Symbol, Tensor - -BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) -BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) -BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True) - - -def arrangement( - input, - output, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, -): - tile_shape = (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K) - - return input.tile(tile_shape), output.tile(tile_shape) +import triton +import ops.ninetoothed.torch +import ops.triton.torch -def application(input, output): - output = input * ntl.sigmoid(input) # noqa: F841 +class SiLU(nn.Module): + silu = None -silu_kernel = ninetoothed.make(arrangement, application, (Tensor(3), Tensor(3))) - + def __init__(self, other): + super().__init__() -def silu(input): - output = torch.empty_like(input) + self.__dict__ = other.__dict__ - silu_kernel(input, output) + def forward(self, input): + return type(self).silu(input) - return output +@contextmanager +def silu_backend(backend_name): + _prev_impl = SiLU.silu -class SiLU(nn.Module): - def __init__(self, other): - super().__init__() + if backend_name == "ninetoothed": + impl = ops.ninetoothed.torch.silu + elif backend_name == "triton": + impl = ops.triton.torch.silu + elif backend_name == "torch": + impl = F.silu + else: + raise ValueError(f"unknown backend: `{backend_name}`") - self.__dict__ = other.__dict__ + SiLU.silu = impl - def forward(self, input): - return silu(input) + try: + yield + finally: + SiLU.silu = _prev_impl if __name__ == "__main__": torch.manual_seed(0) shape = (8, 256, 512) - dtype = torch.float32 + dtype = torch.float16 device = "cuda" + input = torch.randn(shape, dtype=dtype, device=device) - ninetoothed_output = silu(input) + ninetoothed_output = ops.ninetoothed.torch.silu(input) torch_output = F.silu(input) + triton_output = ops.triton.torch.silu(input) print(ninetoothed_output) print(torch_output) + print(triton_output) - if torch.allclose(ninetoothed_output, torch_output): + if torch.allclose(ninetoothed_output, torch_output, atol=1e-3, rtol=1e-3): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["m", "n", "k"], + x_vals=[2**i for i in range(3, 10)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="silu-performance", + args={}, + ) + ) + def benchmark(m, n, k, provider): + input = torch.randn(m, n, k, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.silu(input) + torch_output = F.silu(input) + triton_output = ops.triton.torch.silu(input) + + assert torch.allclose(ninetoothed_output, torch_output, atol=0.001) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + + if provider == "ninetoothed": + ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.silu(input)) + elif provider == "torch": + ms = triton.testing.do_bench(lambda: F.silu(input)) + elif provider == "triton": + ms = triton.testing.do_bench(lambda: ops.triton.torch.silu(input)) + + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/softmax.py b/softmax.py index d42f316..f979928 100644 --- a/softmax.py +++ b/softmax.py @@ -1,123 +1,65 @@ -import ninetoothed -import ninetoothed.language as ntl import torch import triton -import triton.language as tl -from ninetoothed import Symbol, Tensor -BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) - - -@ninetoothed.jit -def softmax_kernel( - input_row: Tensor(2, other=float("-inf")).tile((1, BLOCK_SIZE)), - output_row: Tensor(2).tile((1, BLOCK_SIZE)), -): - row_minus_max = input_row - ntl.max(input_row) - numerator = ntl.exp(row_minus_max) - denominator = ntl.sum(numerator) - output_row = numerator / denominator # noqa: F841 - - -def softmax(input): - output = torch.empty_like(input) - - softmax_kernel(input, output, BLOCK_SIZE=input.shape[-1]) - - return output - - -@triton.jit -def triton_softmax_kernel( - input_ptr, - output_ptr, - input_row_stride, - output_row_stride, - n_rows, - n_cols, - BLOCK_SIZE: tl.constexpr, -): - row_idx = tl.program_id(0) - - row_start_ptr = input_ptr + row_idx * input_row_stride - col_offsets = tl.arange(0, BLOCK_SIZE) - input_ptrs = row_start_ptr + col_offsets - mask = col_offsets < n_cols - - row = tl.load(input_ptrs, mask=mask, other=float("-inf")) - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - output_row_start_ptr = output_ptr + row_idx * output_row_stride - output_ptrs = output_row_start_ptr + col_offsets - tl.store(output_ptrs, softmax_output, mask=mask) - - -def triton_softmax(input): - output = torch.empty_like(input) - - triton_softmax_kernel[(input.shape[0],)]( - input, - output, - input.stride(0), - output.stride(0), - input.shape[0], - input.shape[1], - BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]), - ) - - return output - - -torch.manual_seed(0) -input = torch.randn(1823, 781, device="cuda") -ninetoothed_output = softmax(input) -torch_output = torch.softmax(input, axis=-1) -triton_output = triton_softmax(input) -print(ninetoothed_output) -print(torch_output) -print(triton_output) -if torch.allclose(ninetoothed_output, torch_output): - print("✅ NineToothed and PyTorch match.") -else: - print("❌ NineToothed and PyTorch differ.") -if torch.allclose(ninetoothed_output, triton_output): - print("✅ NineToothed and Triton match.") -else: - print("❌ NineToothed and Triton differ.") - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["n"], - x_vals=[128 * i for i in range(2, 100)], - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="GB/s", - plot_name="softmax-performance", - args={"m": 4096}, +import ops.ninetoothed.torch +import ops.triton.torch + +if __name__ == "__main__": + torch.manual_seed(0) + + dtype = torch.float16 + device = "cuda" + + input = torch.randn(1823, 781, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.softmax(input) + torch_output = torch.softmax(input, axis=-1) + triton_output = ops.triton.torch.softmax(input) + + print(ninetoothed_output) + print(torch_output) + print(triton_output) + + if torch.allclose(ninetoothed_output, torch_output, atol=0.001): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["n"], + x_vals=[2**i for i in range(5, 15)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="softmax-performance", + args={"m": 4096}, + ) ) -) -def benchmark(m, n, provider): - input = torch.randn(m, n, device="cuda", dtype=torch.float32) - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) + def benchmark(m, n, provider): + input = torch.randn(m, n, dtype=dtype, device=device) - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: softmax(input)) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: torch.softmax(input, axis=-1)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_softmax(input)) + ninetoothed_output = ops.ninetoothed.torch.softmax(input) + torch_output = torch.softmax(input, axis=-1) + triton_output = ops.triton.torch.softmax(input) - def gbps(ms): - return 2 * input.numel() * input.element_size() * 1e-6 / ms + assert torch.allclose(ninetoothed_output, torch_output, atol=0.001) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) - return gbps(ms) + if provider == "ninetoothed": + ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.softmax(input)) + elif provider == "torch": + ms = triton.testing.do_bench(lambda: torch.softmax(input, axis=-1)) + elif provider == "triton": + ms = triton.testing.do_bench(lambda: ops.triton.torch.softmax(input)) + return ms -benchmark.run(show_plots=True, print_data=True, save_path=".") + benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/swiglu.py b/swiglu.py index 9c1454b..edf70c1 100644 --- a/swiglu.py +++ b/swiglu.py @@ -1,66 +1,9 @@ -import ninetoothed -import ninetoothed.language as ntl import torch import torch.nn.functional as F import triton -import triton.language as tl -from ninetoothed import Symbol, Tensor -BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True) - - -@ninetoothed.jit -def swiglu_kernel( - a: Tensor(1).tile((BLOCK_SIZE,)), - b: Tensor(1).tile((BLOCK_SIZE,)), - c: Tensor(1).tile((BLOCK_SIZE,)), -): - b_loaded = b - gate = b_loaded * ntl.sigmoid(ntl.cast(b_loaded, ntl.float32)) - c = a * gate # noqa: F841 - - -def swiglu(a, b): - a_1d = a.flatten() - b_1d = b.flatten() - - c = torch.empty_like(a_1d) - - swiglu_kernel(a_1d, b_1d, c) - - return c.view_as(a) - - -@triton.jit -def triton_swiglu_kernel( - a_ptr, b_ptr, c_ptr, num_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr -): - pid = tl.program_id(0) - offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < num_elements - - a = tl.load(a_ptr + offsets, mask=mask, other=0.0) - b = tl.load(b_ptr + offsets, mask=mask, other=0.0) - - silu_b = b * tl.sigmoid(tl.cast(b, tl.float32)) - c = a * silu_b - - tl.store(c_ptr + offsets, c, mask=mask) - - -def triton_swiglu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - # Flatten the inputs so that the kernel always works on 1D tensors - a_flat = a.flatten() - b_flat = b.flatten() - c_flat = torch.empty_like(a_flat) - num_elements = a_flat.numel() - - def grid(meta): - return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),) - - triton_swiglu_kernel[grid](a_flat, b_flat, c_flat, num_elements, BLOCK_SIZE=1024) - - return c_flat.view_as(a) +import ops.ninetoothed.torch +import ops.triton.torch def torch_swiglu( @@ -76,13 +19,15 @@ def torch_swiglu( shape = (13, 3) dtype = torch.float16 device = "cuda" - a = torch.rand(shape, dtype=dtype, device=device) - b = torch.rand(shape, dtype=dtype, device=device) - c = torch.rand(shape, dtype=dtype, device=device) - ninetoothed_output = swiglu(a, b) + a = torch.randn(shape, dtype=dtype, device=device) + b = torch.randn(shape, dtype=dtype, device=device) + c = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.swiglu(a, b) torch_output = torch_swiglu(a, b) - triton_output = triton_swiglu(a, b) + triton_output = ops.triton.torch.swiglu(a, b) + print(ninetoothed_output) print(torch_output) print(triton_output) @@ -105,36 +50,24 @@ def torch_swiglu( line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="GB/s", + ylabel="ms", plot_name="swiglu-performance", args={}, ) ) def benchmark(m, n, provider): shape = (m, n) - dtype = torch.float16 - device = "cuda" - a = torch.rand(shape, dtype=dtype, device=device) - b = torch.rand(shape, dtype=dtype, device=device) - quantiles = [0.5, 0.2, 0.8] + a = torch.randn(shape, dtype=dtype, device=device) + b = torch.randn(shape, dtype=dtype, device=device) if provider == "ninetoothed": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: swiglu(a, b), quantiles=quantiles - ) + ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.swiglu(a, b)) elif provider == "torch": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: torch_swiglu(a, b), quantiles=quantiles - ) + ms = triton.testing.do_bench(lambda: torch_swiglu(a, b)) elif provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: triton_swiglu(a, b), quantiles=quantiles - ) - - def gbps(ms): - return 3 * a.numel() * a.element_size() / ms * 1e-6 + ms = triton.testing.do_bench(lambda: ops.triton.torch.swiglu(a, b)) - return gbps(ms), gbps(max_ms), gbps(min_ms) + return ms benchmark.run(print_data=True, show_plots=True, save_path=".")