From 0d220a55b32541c2234b223b0d8dbf39047d246e Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 10 Jan 2025 23:40:23 +0800 Subject: [PATCH 01/68] Replace auto-tuning with hard-coding --- attention.py | 31 +++-------------- conv2d.py | 93 ++++--------------------------------------------- matmul.py | 97 +++++----------------------------------------------- 3 files changed, 20 insertions(+), 201 deletions(-) diff --git a/attention.py b/attention.py index 545b51c..a115d7f 100644 --- a/attention.py +++ b/attention.py @@ -8,8 +8,8 @@ def arrangement(q, k, v, o): - BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) - BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) + BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", constexpr=True) + BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", constexpr=True) def arrange_q_or_o(input): arranged = input.tile((1, 1, BLOCK_SIZE_M, -1)) @@ -63,34 +63,11 @@ def application(q, k, v, o): def attention(q, k, v): o = torch.empty_like(q, dtype=v.dtype) - attention_kernel(q, k, v, o) + attention_kernel(q, k, v, o, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64) return o -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=8 - ), - ], - key=["EMB_DIM"], -) @triton.jit def triton_attention_kernel( q_ptr, @@ -214,6 +191,8 @@ def grid(meta): *o.stride(), SEQ_LEN=seq_len, EMB_DIM=emb_dim, + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=64, ) return o diff --git a/conv2d.py b/conv2d.py index e6740c1..f17062d 100644 --- a/conv2d.py +++ b/conv2d.py @@ -38,96 +38,13 @@ def conv2d(input, filter): output = torch.empty((n, k, p, q), device=input.device, dtype=input.dtype) - conv2d_kernel(input, filter, output) + conv2d_kernel( + input, filter, output, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64 + ) return output -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - ], - key=["n", "c", "h", "w", "k", "r", "s"], -) @triton.jit def triton_conv2d_kernel( input_ptr, @@ -270,6 +187,10 @@ def grid(meta): *input.stride(), *filter.stride(), *output.stride(), + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=64, + GROUP_SIZE_M=8, ) return output diff --git a/matmul.py b/matmul.py index d5aa445..74cca86 100644 --- a/matmul.py +++ b/matmul.py @@ -7,9 +7,9 @@ def arrangement(lhs, rhs, output): - BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) - BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) - BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True) + BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", constexpr=True) + BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", constexpr=True) + BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", constexpr=True) output_tiled = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N)) @@ -47,96 +47,11 @@ def matmul(lhs, rhs): (lhs.shape[0], rhs.shape[1]), device=lhs.device, dtype=torch.float16 ) - matmul_kernel(lhs, rhs, output) + matmul_kernel(lhs, rhs, output, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64) return output -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - ], - key=["m", "n", "k"], -) @triton.jit def triton_matmul_kernel( lhs_ptr, @@ -220,6 +135,10 @@ def grid(meta): rhs.stride(1), output.stride(0), output.stride(1), + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=64, + GROUP_SIZE_M=8, ) return output From 5ee0bb9ce1321d093ab0191c4f416f8a2cab0a09 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 11 Jan 2025 00:39:30 +0800 Subject: [PATCH 02/68] Use the half-precision floating-point format as the data type for arguments --- add.py | 9 +++++---- conv2d.py | 10 ++++++---- softmax.py | 6 +++--- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/add.py b/add.py index 06e244c..c0829e1 100644 --- a/add.py +++ b/add.py @@ -59,8 +59,9 @@ def grid(meta): torch.manual_seed(0) size = 98432 -lhs = torch.rand(size, device="cuda") -rhs = torch.rand(size, device="cuda") +dtype = torch.float16 +lhs = torch.rand(size, dtype=dtype, device="cuda") +rhs = torch.rand(size, dtype=dtype, device="cuda") ninetoothed_output = add(lhs, rhs) torch_output = lhs + rhs triton_output = triton_add(lhs, rhs) @@ -92,8 +93,8 @@ def grid(meta): ) ) def benchmark(size, provider): - lhs = torch.rand(size, device="cuda", dtype=torch.float32) - rhs = torch.rand(size, device="cuda", dtype=torch.float32) + lhs = torch.rand(size, device="cuda", dtype=torch.float16) + rhs = torch.rand(size, device="cuda", dtype=torch.float16) quantiles = [0.5, 0.2, 0.8] if provider == "ninetoothed": diff --git a/conv2d.py b/conv2d.py index f17062d..434ecea 100644 --- a/conv2d.py +++ b/conv2d.py @@ -200,8 +200,9 @@ def grid(meta): torch.manual_seed(0) n, c, h, w = 4, 3, 224, 224 k, _, r, s = 8, c, 3, 3 - input = torch.randn(n, c, h, w, device="cuda") - filter = torch.randn(k, c, r, s, device="cuda") + dtype = torch.float16 + input = torch.randn(n, c, h, w, dtype=dtype, device="cuda") + filter = torch.randn(k, c, r, s, dtype=dtype, device="cuda") ninetoothed_output = conv2d(input, filter) torch_output = F.conv2d(input, filter) triton_output = triton_conv2d(input, filter) @@ -233,8 +234,9 @@ def grid(meta): def benchmark(h, w, provider): n, c, _, _ = 64, 3, h, w k, _, r, s = 64, c, 3, 3 - input = torch.randn((n, c, h, w), device="cuda") - filter = torch.randn((k, c, r, s), device="cuda") + dtype = torch.float16 + input = torch.randn((n, c, h, w), dtype=dtype, device="cuda") + filter = torch.randn((k, c, r, s), dtype=dtype, device="cuda") if provider == "ninetoothed": ms = triton.testing.do_bench(lambda: conv2d(input, filter)) diff --git a/softmax.py b/softmax.py index d42f316..533a627 100644 --- a/softmax.py +++ b/softmax.py @@ -72,14 +72,14 @@ def triton_softmax(input): torch.manual_seed(0) -input = torch.randn(1823, 781, device="cuda") +input = torch.randn(1823, 781, dtype=torch.float16, device="cuda") ninetoothed_output = softmax(input) torch_output = torch.softmax(input, axis=-1) triton_output = triton_softmax(input) print(ninetoothed_output) print(torch_output) print(triton_output) -if torch.allclose(ninetoothed_output, torch_output): +if torch.allclose(ninetoothed_output, torch_output, atol=1e-5): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") @@ -103,7 +103,7 @@ def triton_softmax(input): ) ) def benchmark(m, n, provider): - input = torch.randn(m, n, device="cuda", dtype=torch.float32) + input = torch.randn(m, n, device="cuda", dtype=torch.float16) stream = torch.cuda.Stream() torch.cuda.set_stream(stream) From 79513295ba463f15d01ca137f0a2d7b2b97de66a Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 11 Jan 2025 12:06:02 +0800 Subject: [PATCH 03/68] Improve the performance of the 2D convolution compute kernel written in Triton --- conv2d.py | 123 ++++++++++++++++++++++++++---------------------------- 1 file changed, 59 insertions(+), 64 deletions(-) diff --git a/conv2d.py b/conv2d.py index 434ecea..7c6ad0c 100644 --- a/conv2d.py +++ b/conv2d.py @@ -23,11 +23,8 @@ def arrangement(input, filter, output): return matmul.arrangement(input_flattened, filter_permuted, output_flattened) -conv2d_kernel = ninetoothed.make( - arrangement, - matmul.application, - (Tensor(4), Tensor(4, constexpr_shape=True), Tensor(4)), -) +tensors = (Tensor(4, constexpr_shape=True) for _ in range(3)) +conv2d_kernel = ninetoothed.make(arrangement, matmul.application, tensors) def conv2d(input, filter): @@ -50,13 +47,13 @@ def triton_conv2d_kernel( input_ptr, filter_ptr, output_ptr, - n, - c, - h, - w, - k, - r, - s, + N: tl.constexpr, + C: tl.constexpr, + H: tl.constexpr, + W: tl.constexpr, + K: tl.constexpr, + R: tl.constexpr, + S: tl.constexpr, input_stride_n, input_stride_c, input_stride_h, @@ -74,16 +71,16 @@ def triton_conv2d_kernel( BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ): - p = h - r + 1 - q = w - s + 1 + P: tl.constexpr = H - R + 1 + Q: tl.constexpr = W - S + 1 - gemm_m = n * p * q - gemm_n = k - gemm_k = c * r * s + GEMM_M: tl.constexpr = N * P * Q + GEMM_N: tl.constexpr = K + GEMM_K: tl.constexpr = C * R * S pid = tl.program_id(0) - num_pid_gemm_m = tl.cdiv(gemm_m, BLOCK_SIZE_M) - num_pid_gemm_n = tl.cdiv(gemm_n, BLOCK_SIZE_N) + num_pid_gemm_m = tl.cdiv(GEMM_M, BLOCK_SIZE_M) + num_pid_gemm_n = tl.cdiv(GEMM_N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_gemm_n group_id = pid // num_pid_in_group first_pid_gemm_m = group_id * GROUP_SIZE_M @@ -94,49 +91,46 @@ def triton_conv2d_kernel( offs_gemm_i = pid_gemm_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_gemm_j = pid_gemm_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_n = offs_gemm_i // (p * q) + offs_n = offs_gemm_i // (P * Q) offs_k = offs_gemm_j - npq_residual = offs_gemm_i % (p * q) - offs_p = npq_residual // q - offs_q = npq_residual % q + npq_residual = offs_gemm_i % (P * Q) + offs_p = npq_residual // Q + offs_q = npq_residual % Q + + input_offs_gemm_m = ( + offs_n * input_stride_n + offs_p * input_stride_h + offs_q * input_stride_w + )[:, None] + filter_offs_gemm_n = (offs_k * filter_stride_k)[None, :] accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(0, tl.cdiv(gemm_k, BLOCK_SIZE_K)): + for i in range(0, tl.cdiv(GEMM_K, BLOCK_SIZE_K)): offs_gemm_k = i * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - offs_c = offs_gemm_k // (r * s) - crs_residual = offs_gemm_k % (r * s) - offs_r = crs_residual // s - offs_s = crs_residual % s - - offs_h = offs_p[:, None] + offs_r[None, :] - offs_w = offs_q[:, None] + offs_s[None, :] - - input_ptrs = ( - input_ptr - + offs_n[:, None] * input_stride_n - + offs_c[None, :] * input_stride_c - + offs_h * input_stride_h - + offs_w * input_stride_w - ) - input_mask = ( - (offs_n[:, None] < n) & (offs_c[None, :] < c) & (offs_h < h) & (offs_w < w) - ) - - filter_ptrs = ( - filter_ptr - + offs_k[None, :] * filter_stride_k - + offs_c[:, None] * filter_stride_c - + offs_r[:, None] * filter_stride_r - + offs_s[:, None] * filter_stride_s - ) - filter_mask = (offs_k[None, :] < k) & ( - (offs_c < c) & (offs_r < r) & (offs_s < s) + offs_c = offs_gemm_k // (R * S) + crs_residual = offs_gemm_k % (R * S) + offs_r = crs_residual // S + offs_s = crs_residual % S + + input_offs_gemm_n = ( + offs_c * input_stride_c + offs_r * input_stride_h + offs_s * input_stride_w + )[None, :] + input_ptrs = input_ptr + input_offs_gemm_m + input_offs_gemm_n + input_mask = ((offs_n < N) & (offs_p < P) & (offs_q < Q))[:, None] & ( + (offs_c < C) & (offs_r < R) & (offs_s < S) + )[None, :] + input = tl.load(input_ptrs, mask=input_mask) + + filter_offs_gemm_m = ( + offs_c * filter_stride_c + + offs_r * filter_stride_r + + offs_s * filter_stride_s )[:, None] - - input = tl.load(input_ptrs, mask=input_mask, other=0.0) - filter = tl.load(filter_ptrs, mask=filter_mask, other=0.0) + filter_ptrs = filter_ptr + filter_offs_gemm_m + filter_offs_gemm_n + filter_mask = (offs_k[None, :] < K) & ( + (offs_c < C) & (offs_r < R) & (offs_s < S) + )[:, None] + filter = tl.load(filter_ptrs, mask=filter_mask) accumulator = tl.dot(input, filter, accumulator) @@ -144,18 +138,19 @@ def triton_conv2d_kernel( output_ptrs = ( output_ptr - + offs_n[:, None] * output_stride_n - + offs_k[None, :] * output_stride_k - + offs_p[:, None] * output_stride_p - + offs_q[:, None] * output_stride_q + + ( + offs_n * output_stride_n + + offs_p * output_stride_p + + offs_q * output_stride_q + )[:, None] + + (offs_k * output_stride_k)[None, :] ) output_mask = ( - (offs_n[:, None] < n) - & (offs_k[None, :] < k) - & (offs_p[:, None] < p) - & (offs_q[:, None] < q) + (offs_n[:, None] < N) + & (offs_k[None, :] < K) + & (offs_p[:, None] < P) + & (offs_q[:, None] < Q) ) - tl.store(output_ptrs, output, mask=output_mask) From e2e0112584d72d578f12018377c2e121a752f7d3 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 11 Jan 2025 12:09:10 +0800 Subject: [PATCH 04/68] Update the benchmark input dimensions to vary based on batch size instead of height and width --- conv2d.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/conv2d.py b/conv2d.py index 7c6ad0c..6bb2174 100644 --- a/conv2d.py +++ b/conv2d.py @@ -215,8 +215,8 @@ def grid(meta): @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["h", "w"], - x_vals=[8 * i for i in range(2, 33)], + x_names=["n"], + x_vals=[2**i for i in range(11)], line_arg="provider", line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], @@ -226,9 +226,9 @@ def grid(meta): args={}, ) ) - def benchmark(h, w, provider): - n, c, _, _ = 64, 3, h, w - k, _, r, s = 64, c, 3, 3 + def benchmark(n, provider): + _, c, h, w = n, 512, 14, 14 + k, _, r, s = 512, c, 3, 3 dtype = torch.float16 input = torch.randn((n, c, h, w), dtype=dtype, device="cuda") filter = torch.randn((k, c, r, s), dtype=dtype, device="cuda") From b904705aa7db3b513e79b6ed64dda5e4ab9f8e06 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 11 Jan 2025 12:15:42 +0800 Subject: [PATCH 05/68] Remove the `other=0.0` argument from the `tl.load` calls in the `triton_matmul_kernel` function --- matmul.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/matmul.py b/matmul.py index 74cca86..c226de4 100644 --- a/matmul.py +++ b/matmul.py @@ -93,8 +93,8 @@ def triton_matmul_kernel( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)): - lhs = tl.load(lhs_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K, other=0.0) - rhs = tl.load(rhs_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K, other=0.0) + lhs = tl.load(lhs_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K) + rhs = tl.load(rhs_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K) accumulator = tl.dot(lhs, rhs, accumulator) lhs_ptrs += BLOCK_SIZE_K * lhs_stride_k rhs_ptrs += BLOCK_SIZE_K * rhs_stride_k From ec7bc2df27717cf96eac4fd7b1260ec4cb2e44ba Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 11 Jan 2025 17:29:16 +0800 Subject: [PATCH 06/68] Add `if __name__ == "__main__"` to `add.py` --- add.py | 117 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 58 insertions(+), 59 deletions(-) diff --git a/add.py b/add.py index c0829e1..d186f4b 100644 --- a/add.py +++ b/add.py @@ -57,63 +57,62 @@ def grid(meta): return output -torch.manual_seed(0) -size = 98432 -dtype = torch.float16 -lhs = torch.rand(size, dtype=dtype, device="cuda") -rhs = torch.rand(size, dtype=dtype, device="cuda") -ninetoothed_output = add(lhs, rhs) -torch_output = lhs + rhs -triton_output = triton_add(lhs, rhs) -print(ninetoothed_output) -print(torch_output) -print(triton_output) -if torch.allclose(ninetoothed_output, torch_output): - print("✅ NineToothed and PyTorch match.") -else: - print("❌ NineToothed and PyTorch differ.") -if torch.allclose(ninetoothed_output, triton_output): - print("✅ NineToothed and Triton match.") -else: - print("❌ NineToothed and Triton differ.") - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["size"], - x_vals=[2**i for i in range(12, 28, 1)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="GB/s", - plot_name="vector-addition-performance", - args={}, - ) -) -def benchmark(size, provider): - lhs = torch.rand(size, device="cuda", dtype=torch.float16) - rhs = torch.rand(size, device="cuda", dtype=torch.float16) - quantiles = [0.5, 0.2, 0.8] - - if provider == "ninetoothed": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: add(lhs, rhs), quantiles=quantiles - ) - elif provider == "torch": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: lhs + rhs, quantiles=quantiles +if __name__ == "__main__": + torch.manual_seed(0) + size = 98432 + dtype = torch.float16 + lhs = torch.rand(size, dtype=dtype, device="cuda") + rhs = torch.rand(size, dtype=dtype, device="cuda") + ninetoothed_output = add(lhs, rhs) + torch_output = lhs + rhs + triton_output = triton_add(lhs, rhs) + print(ninetoothed_output) + print(torch_output) + print(triton_output) + if torch.allclose(ninetoothed_output, torch_output): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["size"], + x_vals=[2**i for i in range(12, 28, 1)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="GB/s", + plot_name="vector-addition-performance", + args={}, ) - elif provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: triton_add(lhs, rhs), quantiles=quantiles - ) - - def gbps(ms): - return 3 * lhs.numel() * lhs.element_size() / ms * 1e-6 - - return gbps(ms), gbps(max_ms), gbps(min_ms) - - -benchmark.run(print_data=True, show_plots=True, save_path=".") + ) + def benchmark(size, provider): + lhs = torch.rand(size, device="cuda", dtype=torch.float16) + rhs = torch.rand(size, device="cuda", dtype=torch.float16) + quantiles = [0.5, 0.2, 0.8] + + if provider == "ninetoothed": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: add(lhs, rhs), quantiles=quantiles + ) + elif provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: lhs + rhs, quantiles=quantiles + ) + elif provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_add(lhs, rhs), quantiles=quantiles + ) + + def gbps(ms): + return 3 * lhs.numel() * lhs.element_size() / ms * 1e-6 + + return gbps(ms), gbps(max_ms), gbps(min_ms) + + benchmark.run(print_data=True, show_plots=True, save_path=".") From a8bfbd63a795e5e189e5a1a172ea50355c2839a2 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 11 Jan 2025 17:32:50 +0800 Subject: [PATCH 07/68] Add `if __name__ == "__main__"` to `softmax.py` --- softmax.py | 89 +++++++++++++++++++++++++++--------------------------- 1 file changed, 44 insertions(+), 45 deletions(-) diff --git a/softmax.py b/softmax.py index 533a627..92c03e0 100644 --- a/softmax.py +++ b/softmax.py @@ -71,53 +71,52 @@ def triton_softmax(input): return output -torch.manual_seed(0) -input = torch.randn(1823, 781, dtype=torch.float16, device="cuda") -ninetoothed_output = softmax(input) -torch_output = torch.softmax(input, axis=-1) -triton_output = triton_softmax(input) -print(ninetoothed_output) -print(torch_output) -print(triton_output) -if torch.allclose(ninetoothed_output, torch_output, atol=1e-5): - print("✅ NineToothed and PyTorch match.") -else: - print("❌ NineToothed and PyTorch differ.") -if torch.allclose(ninetoothed_output, triton_output): - print("✅ NineToothed and Triton match.") -else: - print("❌ NineToothed and Triton differ.") - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["n"], - x_vals=[128 * i for i in range(2, 100)], - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="GB/s", - plot_name="softmax-performance", - args={"m": 4096}, +if __name__ == "__main__": + torch.manual_seed(0) + input = torch.randn(1823, 781, dtype=torch.float16, device="cuda") + ninetoothed_output = softmax(input) + torch_output = torch.softmax(input, axis=-1) + triton_output = triton_softmax(input) + print(ninetoothed_output) + print(torch_output) + print(triton_output) + if torch.allclose(ninetoothed_output, torch_output, atol=1e-5): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["n"], + x_vals=[128 * i for i in range(2, 100)], + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="GB/s", + plot_name="softmax-performance", + args={"m": 4096}, + ) ) -) -def benchmark(m, n, provider): - input = torch.randn(m, n, device="cuda", dtype=torch.float16) - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) + def benchmark(m, n, provider): + input = torch.randn(m, n, device="cuda", dtype=torch.float16) + stream = torch.cuda.Stream() + torch.cuda.set_stream(stream) - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: softmax(input)) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: torch.softmax(input, axis=-1)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_softmax(input)) + if provider == "ninetoothed": + ms = triton.testing.do_bench(lambda: softmax(input)) + elif provider == "torch": + ms = triton.testing.do_bench(lambda: torch.softmax(input, axis=-1)) + elif provider == "triton": + ms = triton.testing.do_bench(lambda: triton_softmax(input)) - def gbps(ms): - return 2 * input.numel() * input.element_size() * 1e-6 / ms + def gbps(ms): + return 2 * input.numel() * input.element_size() * 1e-6 / ms - return gbps(ms) + return gbps(ms) - -benchmark.run(show_plots=True, print_data=True, save_path=".") + benchmark.run(show_plots=True, print_data=True, save_path=".") From f07b7484e8e823ed94c2ea7c853183edf7a3020e Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 11 Jan 2025 18:41:12 +0800 Subject: [PATCH 08/68] Add code size comparison --- code_size_comparison.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 code_size_comparison.py diff --git a/code_size_comparison.py b/code_size_comparison.py new file mode 100644 index 0000000..e3bd845 --- /dev/null +++ b/code_size_comparison.py @@ -0,0 +1,36 @@ +import matplotlib.pyplot as plt +import numpy as np + +plt.rcParams["figure.dpi"] = 600 +plt.rcParams["font.family"] = "JetBrains Mono" +plt.rcParams["font.weight"] = "bold" +plt.rcParams["axes.labelweight"] = "bold" + +kernels = ("add", "softmax", "matmul", "conv2d", "attention") +lines_of_code = {"Triton": (19, 26, 57, 110, 98), "NineToothed": (10, 12, 34, 17, 51)} + +x = np.arange(len(kernels)) +width = 0.4 +multiplier = 0 + +fig, ax = plt.subplots() + +for provider, lines in lines_of_code.items(): + offset = width * multiplier + rects = ax.bar(x + offset, lines, width, label=provider) + ax.bar_label(rects, fontsize=16) + multiplier += 1 + +ax.set_ylabel("Lines of Code", fontsize=16) +ax.tick_params(axis="y", labelsize=12, labelcolor="gray") +ax.set_xticks(x + width / 2, kernels, fontsize=12) +ax.xaxis.set_ticks_position("none") +ax.yaxis.set_ticks_position("none") +ax.legend(fontsize=10) +ax.spines[["top", "left", "right"]].set_visible(False) +ax.spines["bottom"].set_linewidth(1.5) +ax.grid(axis="y", linewidth=1.5) +ax.set_axisbelow(True) + +plt.show() +plt.savefig("code-size-comparison.png") From 33a29fe7fb0668fb42c031b0123f9b32f70e138d Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 11 Jan 2025 23:27:14 +0800 Subject: [PATCH 09/68] Add an example for Root Mean Square Layer Normalization --- rms_norm.py | 119 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 rms_norm.py diff --git a/rms_norm.py b/rms_norm.py new file mode 100644 index 0000000..f4ba67b --- /dev/null +++ b/rms_norm.py @@ -0,0 +1,119 @@ +import ninetoothed +import ninetoothed.language as ntl +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +@ninetoothed.jit +def rms_norm_kernel( + input: Tensor(2).tile((1, BLOCK_SIZE)), + output: Tensor(2).tile((1, BLOCK_SIZE)), + eps: Tensor(0), +): + input_fp32 = ntl.cast(input, ntl.float32) + output = input_fp32 * ntl.rsqrt( # noqa: F841 + ntl.sum(input_fp32 * input_fp32) / input.shape[-1] + eps + ) + + +def rms_norm(input, eps=1e-5): + output = torch.empty_like(input) + + rms_norm_kernel(input, output, eps, BLOCK_SIZE=input.shape[-1]) + + return output + + +@triton.jit +def triton_rms_norm_kernel( + input_ptr, + output_ptr, + num_cols, + input_row_stride, + output_row_stride, + eps: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets + mask = col_offsets < num_cols + input = tl.load(input_ptrs, mask=mask) + + output = input * tl.rsqrt(tl.sum(input * input) / num_cols + eps) + + output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets + tl.store(output_ptrs, output, mask=mask) + + +def triton_rms_norm(input, eps=1e-5): + output = torch.empty_like(input) + + triton_rms_norm_kernel[(input.shape[-2],)]( + input, + output, + input.shape[-1], + input.stride(-2), + output.stride(-2), + eps, + BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]), + ) + + return output + + +if __name__ == "__main__": + torch.manual_seed(0) + input = torch.randn(1151, 8192, dtype=torch.float16, device="cuda") + ninetoothed_output = rms_norm(input) + torch_output = F.rms_norm(input, input.shape[-1:]) + triton_output = triton_rms_norm(input) + print(ninetoothed_output) + print(torch_output) + print(triton_output) + if torch.allclose(ninetoothed_output, torch_output): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["n"], + x_vals=[512 * i for i in range(2, 32)], + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="GB/s", + plot_name="rms-norm-performance", + args={"m": 4096}, + ) + ) + def benchmark(m, n, provider): + input = torch.randn(m, n, dtype=torch.float16, device="cuda") + + if provider == "ninetoothed": + ms = triton.testing.do_bench(lambda: rms_norm(input)) + elif provider == "torch": + ms = triton.testing.do_bench( + lambda: torch.rms_norm(input, input.shape[-1:]) + ) + elif provider == "triton": + ms = triton.testing.do_bench(lambda: triton_rms_norm(input)) + + def gbps(ms): + return 2 * input.numel() * input.element_size() * 1e-6 / ms + + return gbps(ms) + + benchmark.run(show_plots=True, print_data=True, save_path=".") From 9275994acfc4a9d5e22bf4da77d728f096e9349a Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 11 Jan 2025 23:36:20 +0800 Subject: [PATCH 10/68] Remove the unused `n_rows` parameter in the `triton_softmax_kernel` function --- code_size_comparison.py | 2 +- softmax.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/code_size_comparison.py b/code_size_comparison.py index e3bd845..c2c6b0a 100644 --- a/code_size_comparison.py +++ b/code_size_comparison.py @@ -7,7 +7,7 @@ plt.rcParams["axes.labelweight"] = "bold" kernels = ("add", "softmax", "matmul", "conv2d", "attention") -lines_of_code = {"Triton": (19, 26, 57, 110, 98), "NineToothed": (10, 12, 34, 17, 51)} +lines_of_code = {"Triton": (19, 25, 57, 110, 98), "NineToothed": (10, 12, 34, 17, 51)} x = np.arange(len(kernels)) width = 0.4 diff --git a/softmax.py b/softmax.py index 92c03e0..5306490 100644 --- a/softmax.py +++ b/softmax.py @@ -33,7 +33,6 @@ def triton_softmax_kernel( output_ptr, input_row_stride, output_row_stride, - n_rows, n_cols, BLOCK_SIZE: tl.constexpr, ): @@ -63,7 +62,6 @@ def triton_softmax(input): output, input.stride(0), output.stride(0), - input.shape[0], input.shape[1], BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]), ) From 83f7596e70510a503b1660c32786bb3826cb7c0d Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sun, 12 Jan 2025 00:10:21 +0800 Subject: [PATCH 11/68] Add `rms_norm` data into `code_size_comparison.py` --- code_size_comparison.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/code_size_comparison.py b/code_size_comparison.py index c2c6b0a..7b69cf9 100644 --- a/code_size_comparison.py +++ b/code_size_comparison.py @@ -6,11 +6,14 @@ plt.rcParams["font.weight"] = "bold" plt.rcParams["axes.labelweight"] = "bold" -kernels = ("add", "softmax", "matmul", "conv2d", "attention") -lines_of_code = {"Triton": (19, 25, 57, 110, 98), "NineToothed": (10, 12, 34, 17, 51)} +kernels = ("add", "softmax", "rms_norm", "matmul", "conv2d", "attention") +lines_of_code = { + "Triton": (19, 25, 21, 57, 110, 98), + "NineToothed": (10, 12, 13, 34, 17, 51), +} x = np.arange(len(kernels)) -width = 0.4 +width = 0.25 multiplier = 0 fig, ax = plt.subplots() @@ -18,15 +21,15 @@ for provider, lines in lines_of_code.items(): offset = width * multiplier rects = ax.bar(x + offset, lines, width, label=provider) - ax.bar_label(rects, fontsize=16) + ax.bar_label(rects, fontsize=12) multiplier += 1 -ax.set_ylabel("Lines of Code", fontsize=16) -ax.tick_params(axis="y", labelsize=12, labelcolor="gray") -ax.set_xticks(x + width / 2, kernels, fontsize=12) +ax.set_ylabel("Lines of Code", fontsize=12) +ax.tick_params(axis="y", labelsize=10, labelcolor="gray") +ax.set_xticks(x + width / 2, kernels, fontsize=10) ax.xaxis.set_ticks_position("none") ax.yaxis.set_ticks_position("none") -ax.legend(fontsize=10) +ax.legend(ncols=2, fontsize=10) ax.spines[["top", "left", "right"]].set_visible(False) ax.spines["bottom"].set_linewidth(1.5) ax.grid(axis="y", linewidth=1.5) From 9d006c95eaacfba67d7c7d2f970c2d3bdfedee3c Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sun, 12 Jan 2025 03:10:01 +0800 Subject: [PATCH 12/68] Add performance comparison --- performance_comparison.py | 87 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 performance_comparison.py diff --git a/performance_comparison.py b/performance_comparison.py new file mode 100644 index 0000000..e85d886 --- /dev/null +++ b/performance_comparison.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass + +import matplotlib.pyplot as plt +import pandas as pd + +plt.rcParams["figure.figsize"] = [12, 6] +plt.rcParams["figure.dpi"] = 600 +plt.rcParams["font.family"] = "JetBrains Mono" +plt.rcParams["font.weight"] = "bold" +plt.rcParams["axes.titleweight"] = "bold" +plt.rcParams["axes.labelweight"] = "bold" + + +@dataclass +class KernelInformation: + name: str + memory_bound: bool + compute_bound: bool + perf_report_path: str + independent_variable: str + + +@dataclass +class CategoryInformation: + kernels: tuple + y_label: str + + +kernels = ( + KernelInformation("add", True, False, "vector-addition-performance.csv", "Length"), + KernelInformation( + "softmax", True, False, "softmax-performance.csv", "Number of Columns" + ), + KernelInformation( + "rms_norm", True, False, "rms-norm-performance.csv", "Number of Columns" + ), + KernelInformation( + "matmul", False, True, "matrix-multiplication-performance.csv", "Sizes" + ), + KernelInformation( + "conv2d", False, True, "2d-convolution-performance.csv", "Batch Size" + ), + KernelInformation( + "attention", False, True, "attention-performance.csv", "Sequence Length" + ), +) + +providers = ("Triton", "NineToothed") + +categories = ( + CategoryInformation( + tuple(kernel for kernel in kernels if kernel.memory_bound), "GB/s" + ), + CategoryInformation( + tuple(kernel for kernel in kernels if kernel.compute_bound), "TFLOPS" + ), +) + +num_rows = len(categories) +num_cols = max(len(category.kernels) for category in categories) + +fig, axs = plt.subplots(num_rows, num_cols) + +for row, category in enumerate(categories): + axs[row, 0].set_ylabel(category.y_label) + + for col, kernel in enumerate(category.kernels): + df = pd.read_csv(kernel.perf_report_path) + ax = axs[row, col] + + x = df.iloc[:, 0] + + for provider in providers: + y = df[provider] + + ax.plot(x, y, label=provider) + + ax.set_title(kernel.name) + ax.set_xlabel(kernel.independent_variable) + ax.set_xscale("log", base=2) + +fig.legend(providers, loc="upper center", ncols=len(providers)) +fig.tight_layout() +fig.subplots_adjust(top=0.9) + +plt.show() +plt.savefig("performance-comparison.png") From 0b12eb826805d0dc5add2ed803a8f41ba0dd364a Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sun, 12 Jan 2025 11:27:43 +0800 Subject: [PATCH 13/68] Fix a precision issue by casting the loaded data to `tl.float32` --- rms_norm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rms_norm.py b/rms_norm.py index f4ba67b..b538b4e 100644 --- a/rms_norm.py +++ b/rms_norm.py @@ -44,7 +44,7 @@ def triton_rms_norm_kernel( col_offsets = tl.arange(0, BLOCK_SIZE) input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets mask = col_offsets < num_cols - input = tl.load(input_ptrs, mask=mask) + input = tl.load(input_ptrs, mask=mask).to(tl.float32) output = input * tl.rsqrt(tl.sum(input * input) / num_cols + eps) @@ -77,7 +77,7 @@ def triton_rms_norm(input, eps=1e-5): print(ninetoothed_output) print(torch_output) print(triton_output) - if torch.allclose(ninetoothed_output, torch_output): + if torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") From be4f8c274833fc80a576e3d7bfb4ff46120662c4 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sun, 12 Jan 2025 11:44:59 +0800 Subject: [PATCH 14/68] Use `atol=0` and `rtol=0` in `when comparing NineToothed and Triton outputs --- add.py | 2 +- attention.py | 2 +- conv2d.py | 2 +- matmul.py | 2 +- rms_norm.py | 2 +- softmax.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/add.py b/add.py index d186f4b..8b591c7 100644 --- a/add.py +++ b/add.py @@ -73,7 +73,7 @@ def grid(meta): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output): + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): print("✅ NineToothed and Triton match.") else: print("❌ NineToothed and Triton differ.") diff --git a/attention.py b/attention.py index a115d7f..c945dcb 100644 --- a/attention.py +++ b/attention.py @@ -216,7 +216,7 @@ def grid(meta): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0.01): + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): print("✅ NineToothed and Triton match.") else: print("❌ NineToothed and Triton differ.") diff --git a/conv2d.py b/conv2d.py index 6bb2174..6190f22 100644 --- a/conv2d.py +++ b/conv2d.py @@ -208,7 +208,7 @@ def grid(meta): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0.01, rtol=0.01): + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): print("✅ NineToothed and Triton match.") else: print("❌ NineToothed and Triton differ.") diff --git a/matmul.py b/matmul.py index c226de4..000c13a 100644 --- a/matmul.py +++ b/matmul.py @@ -159,7 +159,7 @@ def grid(meta): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output): + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): print("✅ NineToothed and Triton match.") else: print("❌ NineToothed and Triton differ.") diff --git a/rms_norm.py b/rms_norm.py index b538b4e..cfc0b03 100644 --- a/rms_norm.py +++ b/rms_norm.py @@ -81,7 +81,7 @@ def triton_rms_norm(input, eps=1e-5): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output): + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): print("✅ NineToothed and Triton match.") else: print("❌ NineToothed and Triton differ.") diff --git a/softmax.py b/softmax.py index 5306490..a46717b 100644 --- a/softmax.py +++ b/softmax.py @@ -82,7 +82,7 @@ def triton_softmax(input): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output): + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): print("✅ NineToothed and Triton match.") else: print("❌ NineToothed and Triton differ.") From ee28ccb166f7bba8e39c52bf135f782d09103013 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sun, 12 Jan 2025 11:45:54 +0800 Subject: [PATCH 15/68] Add `requirements.txt` --- requirements.txt | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1ebe6ae --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +ninetoothed +torch +matplotlib +pandas From 11d09ddc8c4d48498341971cf9f3c4c386c2505c Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sun, 12 Jan 2025 17:15:33 +0800 Subject: [PATCH 16/68] Add correctness verification during performance testing --- add.py | 6 ++++++ attention.py | 6 ++++++ conv2d.py | 6 ++++++ matmul.py | 6 ++++++ rms_norm.py | 6 ++++++ softmax.py | 8 +++++++- 6 files changed, 37 insertions(+), 1 deletion(-) diff --git a/add.py b/add.py index 8b591c7..53a67de 100644 --- a/add.py +++ b/add.py @@ -97,6 +97,12 @@ def benchmark(size, provider): rhs = torch.rand(size, device="cuda", dtype=torch.float16) quantiles = [0.5, 0.2, 0.8] + ninetoothed_output = add(lhs, rhs) + torch_output = lhs + rhs + triton_output = triton_add(lhs, rhs) + assert torch.allclose(ninetoothed_output, torch_output) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + if provider == "ninetoothed": ms, min_ms, max_ms = triton.testing.do_bench( lambda: add(lhs, rhs), quantiles=quantiles diff --git a/attention.py b/attention.py index c945dcb..1aa138b 100644 --- a/attention.py +++ b/attention.py @@ -242,6 +242,12 @@ def benchmark(seq_len, provider): k = torch.randn(shape, dtype=dtype, device="cuda") v = torch.randn(shape, dtype=dtype, device="cuda") + ninetoothed_output = attention(q, k, v) + torch_output = F.scaled_dot_product_attention(q, k, v, scale=1) + triton_output = triton_attention(q, k, v) + assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + if provider == "ninetoothed": ms = triton.testing.do_bench(lambda: attention(q, k, v)) elif provider == "torch": diff --git a/conv2d.py b/conv2d.py index 6190f22..f592171 100644 --- a/conv2d.py +++ b/conv2d.py @@ -233,6 +233,12 @@ def benchmark(n, provider): input = torch.randn((n, c, h, w), dtype=dtype, device="cuda") filter = torch.randn((k, c, r, s), dtype=dtype, device="cuda") + ninetoothed_output = conv2d(input, filter) + torch_output = F.conv2d(input, filter) + triton_output = triton_conv2d(input, filter) + assert torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + if provider == "ninetoothed": ms = triton.testing.do_bench(lambda: conv2d(input, filter)) elif provider == "torch": diff --git a/matmul.py b/matmul.py index 000c13a..7204575 100644 --- a/matmul.py +++ b/matmul.py @@ -182,6 +182,12 @@ def benchmark(m, n, k, provider): rhs = torch.randn((k, n), device="cuda", dtype=torch.float16) quantiles = [0.5, 0.2, 0.8] + ninetoothed_output = matmul(lhs, rhs) + torch_output = torch.matmul(lhs, rhs) + triton_output = triton_matmul(lhs, rhs) + assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + if provider == "ninetoothed": ms, min_ms, max_ms = triton.testing.do_bench( lambda: matmul(lhs, rhs), quantiles=quantiles diff --git a/rms_norm.py b/rms_norm.py index cfc0b03..270c959 100644 --- a/rms_norm.py +++ b/rms_norm.py @@ -102,6 +102,12 @@ def triton_rms_norm(input, eps=1e-5): def benchmark(m, n, provider): input = torch.randn(m, n, dtype=torch.float16, device="cuda") + ninetoothed_output = rms_norm(input) + torch_output = F.rms_norm(input, input.shape[-1:]) + triton_output = triton_rms_norm(input) + assert torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + if provider == "ninetoothed": ms = triton.testing.do_bench(lambda: rms_norm(input)) elif provider == "torch": diff --git a/softmax.py b/softmax.py index a46717b..463bcfa 100644 --- a/softmax.py +++ b/softmax.py @@ -78,7 +78,7 @@ def triton_softmax(input): print(ninetoothed_output) print(torch_output) print(triton_output) - if torch.allclose(ninetoothed_output, torch_output, atol=1e-5): + if torch.allclose(ninetoothed_output, torch_output, atol=0.001): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") @@ -105,6 +105,12 @@ def benchmark(m, n, provider): stream = torch.cuda.Stream() torch.cuda.set_stream(stream) + ninetoothed_output = softmax(input) + torch_output = torch.softmax(input, axis=-1) + triton_output = triton_softmax(input) + assert torch.allclose(ninetoothed_output, torch_output, atol=0.001) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + if provider == "ninetoothed": ms = triton.testing.do_bench(lambda: softmax(input)) elif provider == "torch": From 9d5c7317907acbb9be7570a8a8a76cb26b28b581 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sun, 12 Jan 2025 17:46:33 +0800 Subject: [PATCH 17/68] Update benchmark `x_vals` ranges and use log scaling for performance testing --- add.py | 2 +- attention.py | 3 ++- conv2d.py | 3 ++- matmul.py | 3 ++- rms_norm.py | 3 ++- softmax.py | 3 ++- 6 files changed, 11 insertions(+), 6 deletions(-) diff --git a/add.py b/add.py index 53a67de..74bcab8 100644 --- a/add.py +++ b/add.py @@ -81,7 +81,7 @@ def grid(meta): @triton.testing.perf_report( triton.testing.Benchmark( x_names=["size"], - x_vals=[2**i for i in range(12, 28, 1)], + x_vals=[2**i for i in range(18, 28)], x_log=True, line_arg="provider", line_vals=["ninetoothed", "torch", "triton"], diff --git a/attention.py b/attention.py index 1aa138b..329fc89 100644 --- a/attention.py +++ b/attention.py @@ -224,7 +224,8 @@ def grid(meta): @triton.testing.perf_report( triton.testing.Benchmark( x_names=["seq_len"], - x_vals=[2**i for i in range(10, 15)], + x_vals=[2**i for i in range(6, 16)], + x_log=True, line_arg="provider", line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], diff --git a/conv2d.py b/conv2d.py index f592171..6e4b719 100644 --- a/conv2d.py +++ b/conv2d.py @@ -216,7 +216,8 @@ def grid(meta): @triton.testing.perf_report( triton.testing.Benchmark( x_names=["n"], - x_vals=[2**i for i in range(11)], + x_vals=[2**i for i in range(1, 11)], + x_log=True, line_arg="provider", line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], diff --git a/matmul.py b/matmul.py index 7204575..b5b2871 100644 --- a/matmul.py +++ b/matmul.py @@ -167,7 +167,8 @@ def grid(meta): @triton.testing.perf_report( triton.testing.Benchmark( x_names=["m", "n", "k"], - x_vals=[128 * i for i in range(2, 33)], + x_vals=[2**i for i in range(3, 13)], + x_log=True, line_arg="provider", line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], diff --git a/rms_norm.py b/rms_norm.py index 270c959..abefc8e 100644 --- a/rms_norm.py +++ b/rms_norm.py @@ -89,7 +89,8 @@ def triton_rms_norm(input, eps=1e-5): @triton.testing.perf_report( triton.testing.Benchmark( x_names=["n"], - x_vals=[512 * i for i in range(2, 32)], + x_vals=[2**i for i in range(5, 15)], + x_log=True, line_arg="provider", line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], diff --git a/softmax.py b/softmax.py index 463bcfa..97d5f53 100644 --- a/softmax.py +++ b/softmax.py @@ -90,7 +90,8 @@ def triton_softmax(input): @triton.testing.perf_report( triton.testing.Benchmark( x_names=["n"], - x_vals=[128 * i for i in range(2, 100)], + x_vals=[2**i for i in range(5, 15)], + x_log=True, line_arg="provider", line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], From 60c9a6e9dcf5d251eda5eece9acfae0aca572748 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 13 Jan 2025 22:33:32 +0800 Subject: [PATCH 18/68] Add statistics into `performance_comparison.py` --- performance_comparison.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/performance_comparison.py b/performance_comparison.py index e85d886..90c2f56 100644 --- a/performance_comparison.py +++ b/performance_comparison.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import matplotlib.pyplot as plt +import numpy as np import pandas as pd plt.rcParams["figure.figsize"] = [12, 6] @@ -61,6 +62,8 @@ class CategoryInformation: fig, axs = plt.subplots(num_rows, num_cols) +performance_differences = [] + for row, category in enumerate(categories): axs[row, 0].set_ylabel(category.y_label) @@ -70,11 +73,18 @@ class CategoryInformation: x = df.iloc[:, 0] + performance_differences.append((kernel, [])) + for provider in providers: y = df[provider] ax.plot(x, y, label=provider) + if provider == "NineToothed": + y_triton = df["Triton"] + diff = (y - y_triton) / y_triton * 100 + performance_differences[-1][-1].append(diff) + ax.set_title(kernel.name) ax.set_xlabel(kernel.independent_variable) ax.set_xscale("log", base=2) @@ -85,3 +95,28 @@ class CategoryInformation: plt.show() plt.savefig("performance-comparison.png") + +all_differences = [] +stats_data = [] + +for kernel, diffs in performance_differences: + all_differences.extend(diffs) + + kernel_stats = { + "Kernel": kernel.name, + "Mean": np.mean(diffs), + "Median": np.median(diffs), + } + + stats_data.append(kernel_stats) + +overall_stats = { + "Kernel": "Overall", + "Mean": np.mean(all_differences), + "Median": np.median(all_differences), +} + +stats_data.append(overall_stats) + +print("Relative Performance Change (%):") +print(pd.DataFrame(stats_data)) From 69b626dcd40428eaa50c0a50aadabdf02fec1cc9 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 13 Jan 2025 22:51:33 +0800 Subject: [PATCH 19/68] Add statistics into `code_size_comparison.py` --- code_size_comparison.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/code_size_comparison.py b/code_size_comparison.py index 7b69cf9..fa31a1a 100644 --- a/code_size_comparison.py +++ b/code_size_comparison.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import numpy as np +import pandas as pd plt.rcParams["figure.dpi"] = 600 plt.rcParams["font.family"] = "JetBrains Mono" @@ -37,3 +38,17 @@ plt.show() plt.savefig("code-size-comparison.png") + +print( + pd.DataFrame( + { + "Kernel": kernels, + "Relative Code Size Change (%)": [ + f"{ninetoothed_lines / triton_lines * 100:.2f}%" + for ninetoothed_lines, triton_lines in zip( + lines_of_code["NineToothed"], lines_of_code["Triton"] + ) + ], + } + ) +) From 13dc6be3c1de249c94cb6bed9a544515d4e9855c Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 14 Jan 2025 00:08:27 +0800 Subject: [PATCH 20/68] Add overall comparison into `code_size_comparison.py` --- code_size_comparison.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/code_size_comparison.py b/code_size_comparison.py index fa31a1a..1e733e5 100644 --- a/code_size_comparison.py +++ b/code_size_comparison.py @@ -42,12 +42,15 @@ print( pd.DataFrame( { - "Kernel": kernels, + "Kernel": kernels + ("Overall",), "Relative Code Size Change (%)": [ f"{ninetoothed_lines / triton_lines * 100:.2f}%" for ninetoothed_lines, triton_lines in zip( lines_of_code["NineToothed"], lines_of_code["Triton"] ) + ] + + [ + f"{sum(lines_of_code['NineToothed']) / sum(lines_of_code['Triton']) * 100:.2f}%" ], } ) From b5674eea28563071351cf8f3fad8c121e7026d41 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 14 Jan 2025 14:03:48 +0800 Subject: [PATCH 21/68] Remove the stream setting in `softmax.py` --- softmax.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/softmax.py b/softmax.py index 97d5f53..2641902 100644 --- a/softmax.py +++ b/softmax.py @@ -103,8 +103,6 @@ def triton_softmax(input): ) def benchmark(m, n, provider): input = torch.randn(m, n, device="cuda", dtype=torch.float16) - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) ninetoothed_output = softmax(input) torch_output = torch.softmax(input, axis=-1) From aafa34aca174385cff9189cbb9fe67d4b0620935 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 14 Jan 2025 14:05:49 +0800 Subject: [PATCH 22/68] Use `torch.randn` instead of `torch.rand` in `add.py` --- add.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/add.py b/add.py index 74bcab8..421955a 100644 --- a/add.py +++ b/add.py @@ -93,8 +93,8 @@ def grid(meta): ) ) def benchmark(size, provider): - lhs = torch.rand(size, device="cuda", dtype=torch.float16) - rhs = torch.rand(size, device="cuda", dtype=torch.float16) + lhs = torch.randn(size, device="cuda", dtype=torch.float16) + rhs = torch.randn(size, device="cuda", dtype=torch.float16) quantiles = [0.5, 0.2, 0.8] ninetoothed_output = add(lhs, rhs) From 0dfaeb0258d6195a168f635924fdfcf29bac8cd9 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 14 Jan 2025 14:08:36 +0800 Subject: [PATCH 23/68] Use time as the metric for measuring performance --- add.py | 20 +++------ attention.py | 10 +---- conv2d.py | 10 +---- matmul.py | 20 +++------ performance_comparison.py | 93 ++++++++++++++------------------------- rms_norm.py | 7 +-- softmax.py | 7 +-- 7 files changed, 50 insertions(+), 117 deletions(-) diff --git a/add.py b/add.py index 421955a..e07c571 100644 --- a/add.py +++ b/add.py @@ -87,7 +87,7 @@ def grid(meta): line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="GB/s", + ylabel="ms", plot_name="vector-addition-performance", args={}, ) @@ -95,7 +95,6 @@ def grid(meta): def benchmark(size, provider): lhs = torch.randn(size, device="cuda", dtype=torch.float16) rhs = torch.randn(size, device="cuda", dtype=torch.float16) - quantiles = [0.5, 0.2, 0.8] ninetoothed_output = add(lhs, rhs) torch_output = lhs + rhs @@ -104,21 +103,12 @@ def benchmark(size, provider): assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) if provider == "ninetoothed": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: add(lhs, rhs), quantiles=quantiles - ) + ms = triton.testing.do_bench(lambda: add(lhs, rhs)) elif provider == "torch": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: lhs + rhs, quantiles=quantiles - ) + ms = triton.testing.do_bench(lambda: lhs + rhs) elif provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: triton_add(lhs, rhs), quantiles=quantiles - ) + ms = triton.testing.do_bench(lambda: triton_add(lhs, rhs)) - def gbps(ms): - return 3 * lhs.numel() * lhs.element_size() / ms * 1e-6 - - return gbps(ms), gbps(max_ms), gbps(min_ms) + return ms benchmark.run(print_data=True, show_plots=True, save_path=".") diff --git a/attention.py b/attention.py index 329fc89..4c7cdb2 100644 --- a/attention.py +++ b/attention.py @@ -230,7 +230,7 @@ def grid(meta): line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="TFLOPS", + ylabel="ms", plot_name="attention-performance", args={}, ) @@ -258,12 +258,6 @@ def benchmark(seq_len, provider): elif provider == "triton": ms = triton.testing.do_bench(lambda: triton_attention(q, k, v)) - def perf(ms): - flops_per_matmul = 2 * batch_size * num_heads * seq_len * seq_len * emb_dim - total_flops = 2 * flops_per_matmul - - return total_flops * 1e-12 / (ms * 1e-3) - - return perf(ms) + return ms benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/conv2d.py b/conv2d.py index 6e4b719..17d22b4 100644 --- a/conv2d.py +++ b/conv2d.py @@ -222,7 +222,7 @@ def grid(meta): line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="TFLOPS", + ylabel="ms", plot_name="2d-convolution-performance", args={}, ) @@ -247,12 +247,6 @@ def benchmark(n, provider): elif provider == "triton": ms = triton.testing.do_bench(lambda: triton_conv2d(input, filter)) - def perf(ms): - p = h - r + 1 - q = w - s + 1 - - return 2 * n * k * p * q * c * r * s * 1e-12 / (ms * 1e-3) - - return perf(ms) + return ms benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/matmul.py b/matmul.py index b5b2871..bbd6a08 100644 --- a/matmul.py +++ b/matmul.py @@ -173,7 +173,7 @@ def grid(meta): line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="TFLOPS", + ylabel="ms", plot_name="matrix-multiplication-performance", args={}, ) @@ -181,7 +181,6 @@ def grid(meta): def benchmark(m, n, k, provider): lhs = torch.randn((m, k), device="cuda", dtype=torch.float16) rhs = torch.randn((k, n), device="cuda", dtype=torch.float16) - quantiles = [0.5, 0.2, 0.8] ninetoothed_output = matmul(lhs, rhs) torch_output = torch.matmul(lhs, rhs) @@ -190,21 +189,12 @@ def benchmark(m, n, k, provider): assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) if provider == "ninetoothed": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: matmul(lhs, rhs), quantiles=quantiles - ) + ms = triton.testing.do_bench(lambda: matmul(lhs, rhs)) elif provider == "torch": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: torch.matmul(lhs, rhs), quantiles=quantiles - ) + ms = triton.testing.do_bench(lambda: torch.matmul(lhs, rhs)) elif provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: triton_matmul(lhs, rhs), quantiles=quantiles - ) + ms = triton.testing.do_bench(lambda: triton_matmul(lhs, rhs)) - def perf(ms): - return 2 * m * n * k * 1e-12 / (ms * 1e-3) - - return perf(ms), perf(max_ms), perf(min_ms) + return ms benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/performance_comparison.py b/performance_comparison.py index 90c2f56..a6745a6 100644 --- a/performance_comparison.py +++ b/performance_comparison.py @@ -15,79 +15,50 @@ @dataclass class KernelInformation: name: str - memory_bound: bool - compute_bound: bool perf_report_path: str independent_variable: str -@dataclass -class CategoryInformation: - kernels: tuple - y_label: str - - kernels = ( - KernelInformation("add", True, False, "vector-addition-performance.csv", "Length"), - KernelInformation( - "softmax", True, False, "softmax-performance.csv", "Number of Columns" - ), - KernelInformation( - "rms_norm", True, False, "rms-norm-performance.csv", "Number of Columns" - ), - KernelInformation( - "matmul", False, True, "matrix-multiplication-performance.csv", "Sizes" - ), - KernelInformation( - "conv2d", False, True, "2d-convolution-performance.csv", "Batch Size" - ), - KernelInformation( - "attention", False, True, "attention-performance.csv", "Sequence Length" - ), + KernelInformation("add", "vector-addition-performance.csv", "Length"), + KernelInformation("softmax", "softmax-performance.csv", "Number of Columns"), + KernelInformation("rms_norm", "rms-norm-performance.csv", "Number of Columns"), + KernelInformation("matmul", "matrix-multiplication-performance.csv", "Sizes"), + KernelInformation("conv2d", "2d-convolution-performance.csv", "Batch Size"), + KernelInformation("attention", "attention-performance.csv", "Sequence Length"), ) providers = ("Triton", "NineToothed") -categories = ( - CategoryInformation( - tuple(kernel for kernel in kernels if kernel.memory_bound), "GB/s" - ), - CategoryInformation( - tuple(kernel for kernel in kernels if kernel.compute_bound), "TFLOPS" - ), -) - -num_rows = len(categories) -num_cols = max(len(category.kernels) for category in categories) +num_rows = 2 +num_cols = 3 fig, axs = plt.subplots(num_rows, num_cols) -performance_differences = [] - -for row, category in enumerate(categories): - axs[row, 0].set_ylabel(category.y_label) +performance_changes = [] - for col, kernel in enumerate(category.kernels): - df = pd.read_csv(kernel.perf_report_path) - ax = axs[row, col] +for i, kernel in enumerate(kernels): + df = pd.read_csv(kernel.perf_report_path) + ax = axs[i // num_cols, i % num_cols] - x = df.iloc[:, 0] + x = df.iloc[:, 0] - performance_differences.append((kernel, [])) + performance_changes.append((kernel, [])) - for provider in providers: - y = df[provider] + for provider in providers: + y = df[provider] - ax.plot(x, y, label=provider) + ax.plot(x, y, label=provider) - if provider == "NineToothed": - y_triton = df["Triton"] - diff = (y - y_triton) / y_triton * 100 - performance_differences[-1][-1].append(diff) + if provider == "NineToothed": + y_triton = df["Triton"] + change = (y - y_triton) / y_triton * 100 + performance_changes[-1][-1].append(change) - ax.set_title(kernel.name) - ax.set_xlabel(kernel.independent_variable) - ax.set_xscale("log", base=2) + ax.set_title(kernel.name) + ax.set_xlabel(kernel.independent_variable) + ax.set_ylabel("Execution Time (ms)") + ax.set_xscale("log", base=2) fig.legend(providers, loc="upper center", ncols=len(providers)) fig.tight_layout() @@ -96,24 +67,24 @@ class CategoryInformation: plt.show() plt.savefig("performance-comparison.png") -all_differences = [] +all_changes = [] stats_data = [] -for kernel, diffs in performance_differences: - all_differences.extend(diffs) +for kernel, changes in performance_changes: + all_changes.extend(changes) kernel_stats = { "Kernel": kernel.name, - "Mean": np.mean(diffs), - "Median": np.median(diffs), + "Mean": np.mean(changes), + "Median": np.median(changes), } stats_data.append(kernel_stats) overall_stats = { "Kernel": "Overall", - "Mean": np.mean(all_differences), - "Median": np.median(all_differences), + "Mean": np.mean(all_changes), + "Median": np.median(all_changes), } stats_data.append(overall_stats) diff --git a/rms_norm.py b/rms_norm.py index abefc8e..e44723b 100644 --- a/rms_norm.py +++ b/rms_norm.py @@ -95,7 +95,7 @@ def triton_rms_norm(input, eps=1e-5): line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="GB/s", + ylabel="ms", plot_name="rms-norm-performance", args={"m": 4096}, ) @@ -118,9 +118,6 @@ def benchmark(m, n, provider): elif provider == "triton": ms = triton.testing.do_bench(lambda: triton_rms_norm(input)) - def gbps(ms): - return 2 * input.numel() * input.element_size() * 1e-6 / ms - - return gbps(ms) + return ms benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/softmax.py b/softmax.py index 2641902..209995e 100644 --- a/softmax.py +++ b/softmax.py @@ -96,7 +96,7 @@ def triton_softmax(input): line_vals=["ninetoothed", "torch", "triton"], line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="GB/s", + ylabel="ms", plot_name="softmax-performance", args={"m": 4096}, ) @@ -117,9 +117,6 @@ def benchmark(m, n, provider): elif provider == "triton": ms = triton.testing.do_bench(lambda: triton_softmax(input)) - def gbps(ms): - return 2 * input.numel() * input.element_size() * 1e-6 / ms - - return gbps(ms) + return ms benchmark.run(show_plots=True, print_data=True, save_path=".") From bec933468633788821dafd7a1c1a2e4b8cfdb1e3 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 14 Jan 2025 14:31:00 +0800 Subject: [PATCH 24/68] Add PyTorch data into performance comparison --- performance_comparison.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/performance_comparison.py b/performance_comparison.py index a6745a6..c5a93a7 100644 --- a/performance_comparison.py +++ b/performance_comparison.py @@ -28,7 +28,7 @@ class KernelInformation: KernelInformation("attention", "attention-performance.csv", "Sequence Length"), ) -providers = ("Triton", "NineToothed") +providers = ("Triton", "NineToothed", "PyTorch") num_rows = 2 num_cols = 3 From 0cb93e53c2f5d9e39fa6d62385ea454e11025772 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 13 May 2025 22:08:44 +0800 Subject: [PATCH 25/68] Separate `add` kernels into modular packages --- add.py | 81 +++++++--------------------------- ops/ninetoothed/kernels/add.py | 21 +++++++++ ops/ninetoothed/torch.py | 11 +++++ ops/triton/kernels/add.py | 17 +++++++ ops/triton/torch.py | 19 ++++++++ 5 files changed, 85 insertions(+), 64 deletions(-) create mode 100644 ops/ninetoothed/kernels/add.py create mode 100644 ops/ninetoothed/torch.py create mode 100644 ops/triton/kernels/add.py create mode 100644 ops/triton/torch.py diff --git a/add.py b/add.py index 3d1e2af..782a7c7 100644 --- a/add.py +++ b/add.py @@ -1,57 +1,8 @@ -import ninetoothed import torch import triton -import triton.language as tl -from ninetoothed import Symbol, Tensor - -BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) - - -@ninetoothed.jit -def add_kernel( - lhs: Tensor(1).tile((BLOCK_SIZE,)), - rhs: Tensor(1).tile((BLOCK_SIZE,)), - output: Tensor(1).tile((BLOCK_SIZE,)), -): - output = lhs + rhs # noqa: F841 - - -def add(lhs, rhs): - output = torch.empty_like(lhs) - - add_kernel(lhs, rhs, output, BLOCK_SIZE=1024) - - return output - - -@triton.jit -def triton_add_kernel( - lhs_ptr, rhs_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr -): - pid = tl.program_id(0) - - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - - lhs = tl.load(lhs_ptr + offsets, mask=mask) - rhs = tl.load(rhs_ptr + offsets, mask=mask) - output = lhs + rhs - - tl.store(output_ptr + offsets, output, mask=mask) - - -def triton_add(lhs, rhs): - output = torch.empty_like(lhs) - n_elements = output.numel() - - def grid(meta): - return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - - triton_add_kernel[grid](lhs, rhs, output, n_elements, BLOCK_SIZE=1024) - - return output +import ops.ninetoothed.torch +import ops.triton.torch if __name__ == "__main__": torch.manual_seed(0) @@ -60,12 +11,12 @@ def grid(meta): dtype = torch.float16 device = "cuda" - lhs = torch.rand(size, dtype=dtype, device=device) - rhs = torch.rand(size, dtype=dtype, device=device) + input = torch.rand(size, dtype=dtype, device=device) + other = torch.rand(size, dtype=dtype, device=device) - ninetoothed_output = add(lhs, rhs) - torch_output = lhs + rhs - triton_output = triton_add(lhs, rhs) + ninetoothed_output = ops.ninetoothed.torch.add(input, other) + torch_output = input + other + triton_output = ops.triton.torch.add(input, other) print(ninetoothed_output) print(torch_output) @@ -95,22 +46,24 @@ def grid(meta): ) ) def benchmark(size, provider): - lhs = torch.randn(size, dtype=dtype, device=device) - rhs = torch.randn(size, dtype=dtype, device=device) + input = torch.randn(size, dtype=dtype, device=device) + other = torch.randn(size, dtype=dtype, device=device) - ninetoothed_output = add(lhs, rhs) - torch_output = torch.add(lhs, rhs) - triton_output = triton_add(lhs, rhs) + ninetoothed_output = ops.ninetoothed.torch.add(input, other) + torch_output = torch.add(input, other) + triton_output = ops.triton.torch.add(input, other) assert torch.allclose(ninetoothed_output, torch_output) assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: add(lhs, rhs)) + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.add(input, other) + ) elif provider == "torch": - ms = triton.testing.do_bench(lambda: torch.add(lhs, rhs)) + ms = triton.testing.do_bench(lambda: torch.add(input, other)) elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_add(lhs, rhs)) + ms = triton.testing.do_bench(lambda: ops.triton.torch.add(input, other)) return ms diff --git a/ops/ninetoothed/kernels/add.py b/ops/ninetoothed/kernels/add.py new file mode 100644 index 0000000..1dfb4be --- /dev/null +++ b/ops/ninetoothed/kernels/add.py @@ -0,0 +1,21 @@ +import ninetoothed +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(input, other, output, BLOCK_SIZE=BLOCK_SIZE): + input_arranged = input.tile((BLOCK_SIZE,)) + other_arranged = other.tile((BLOCK_SIZE,)) + output_arranged = output.tile((BLOCK_SIZE,)) + + return input_arranged, other_arranged, output_arranged + + +def application(input, other, output): + output = input + other # noqa: F841 + + +tensors = tuple(Tensor(1) for _ in range(3)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py new file mode 100644 index 0000000..7a6acd2 --- /dev/null +++ b/ops/ninetoothed/torch.py @@ -0,0 +1,11 @@ +import torch + +import ops.ninetoothed.kernels.add + + +def add(input, other): + output = torch.empty_like(input) + + ops.ninetoothed.kernels.add.kernel(input, other, output, BLOCK_SIZE=1024) + + return output diff --git a/ops/triton/kernels/add.py b/ops/triton/kernels/add.py new file mode 100644 index 0000000..e7c1eb3 --- /dev/null +++ b/ops/triton/kernels/add.py @@ -0,0 +1,17 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel(input_ptr, other_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < num_elements + + input = tl.load(input_ptr + offsets, mask=mask) + other = tl.load(other_ptr + offsets, mask=mask) + output = input + other + + tl.store(output_ptr + offsets, output, mask=mask) diff --git a/ops/triton/torch.py b/ops/triton/torch.py new file mode 100644 index 0000000..e88b200 --- /dev/null +++ b/ops/triton/torch.py @@ -0,0 +1,19 @@ +import torch +import triton + +import ops.triton.kernels.add + + +def add(input, other): + num_elements = input.numel() + + output = torch.empty_like(input) + + def grid(meta): + return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),) + + ops.triton.kernels.add.kernel[grid]( + input, other, output, num_elements, BLOCK_SIZE=1024 + ) + + return output From d2784337256243fff223921e11de75f83cdddf4e Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 13 May 2025 22:29:46 +0800 Subject: [PATCH 26/68] Rename `matmul` to `mm` and separate the kernels into modular packages --- addmm.py | 10 +- bmm.py | 4 +- conv2d.py | 6 +- matmul.py | 288 ---------------------------------- mm.py | 68 ++++++++ ops/ninetoothed/kernels/mm.py | 44 ++++++ ops/ninetoothed/torch.py | 10 ++ ops/triton/kernels/mm.py | 150 ++++++++++++++++++ ops/triton/torch.py | 29 ++++ 9 files changed, 310 insertions(+), 299 deletions(-) delete mode 100644 matmul.py create mode 100644 mm.py create mode 100644 ops/ninetoothed/kernels/mm.py create mode 100644 ops/triton/kernels/mm.py diff --git a/addmm.py b/addmm.py index e0c654b..a3de947 100644 --- a/addmm.py +++ b/addmm.py @@ -6,21 +6,19 @@ import triton.language as tl from ninetoothed import Tensor -import matmul +import mm def arrangement(input, mat1, mat2, beta, alpha, output): - _, _, input_arranged = matmul.arrangement(mat1, mat2, input) + _, _, input_arranged = mm.arrangement(mat1, mat2, input) - mat1_arranged, mat2_arranged, output_arranged = matmul.arrangement( - mat1, mat2, output - ) + mat1_arranged, mat2_arranged, output_arranged = mm.arrangement(mat1, mat2, output) return input_arranged, mat1_arranged, mat2_arranged, beta, alpha, output_arranged def application(input, mat1, mat2, beta, alpha, output): - matmul.application(mat1, mat2, output) + mm.application(mat1, mat2, output) output = beta * input + alpha * output diff --git a/bmm.py b/bmm.py index 78a009d..5c694f4 100644 --- a/bmm.py +++ b/bmm.py @@ -2,7 +2,7 @@ import torch from ninetoothed import Symbol, Tensor -import matmul +import mm BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) @@ -36,7 +36,7 @@ def arrangement( tensors = (Tensor(3), Tensor(3), Tensor(3)) -bmm_kernel = ninetoothed.make(arrangement, matmul.application, tensors) +bmm_kernel = ninetoothed.make(arrangement, mm.application, tensors) def bmm(lhs, rhs): diff --git a/conv2d.py b/conv2d.py index b742c11..ab79220 100644 --- a/conv2d.py +++ b/conv2d.py @@ -5,7 +5,7 @@ import triton.language as tl from ninetoothed import Tensor -import matmul +import mm def arrangement(input, filter, output): @@ -20,12 +20,12 @@ def arrangement(input, filter, output): output_flattened = output.permute((0, 2, 3, 1)).flatten(end_dim=3) - return matmul.arrangement(input_flattened, filter_permuted, output_flattened) + return mm.arrangement(input_flattened, filter_permuted, output_flattened) shape_options = {"constexpr": True, "upper_bound": 16} tensors = tuple(Tensor(4, shape_options=shape_options) for _ in range(3)) -conv2d_kernel = ninetoothed.make(arrangement, matmul.application, tensors) +conv2d_kernel = ninetoothed.make(arrangement, mm.application, tensors) def conv2d(input, filter): diff --git a/matmul.py b/matmul.py deleted file mode 100644 index 19ebc18..0000000 --- a/matmul.py +++ /dev/null @@ -1,288 +0,0 @@ -import ninetoothed -import ninetoothed.language as ntl -import torch -import triton -import triton.language as tl -from ninetoothed import Symbol, Tensor - - -def arrangement(lhs, rhs, output): - BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) - BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) - BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True) - - output_tiled = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N)) - - lhs_tiled = ( - lhs.tile((BLOCK_SIZE_M, BLOCK_SIZE_K)) - .tile((1, -1)) - .expand((-1, output_tiled.shape[1])) - ) - lhs_tiled.dtype = lhs_tiled.dtype.squeeze(0) - - rhs_tiled = ( - rhs.tile((BLOCK_SIZE_K, BLOCK_SIZE_N)) - .tile((-1, 1)) - .expand((output_tiled.shape[0], -1)) - ) - rhs_tiled.dtype = rhs_tiled.dtype.squeeze(1) - - return lhs_tiled, rhs_tiled, output_tiled - - -def application(lhs, rhs, output): - accumulator = ntl.zeros(output.shape, dtype=ntl.float32) - - for k in range(lhs.shape[0]): - accumulator += ntl.dot(lhs[k], rhs[k]) - - output = accumulator - - -tensors = (Tensor(2), Tensor(2), Tensor(2)) -matmul_kernel = ninetoothed.make(arrangement, application, tensors) - - -def matmul(lhs, rhs): - output_shape = (lhs.shape[0], rhs.shape[1]) - output = torch.empty(output_shape, dtype=lhs.dtype, device=lhs.device) - - matmul_kernel(lhs, rhs, output) - - return output - - -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - ], - key=["m", "n", "k"], -) -@triton.jit -def triton_matmul_kernel( - lhs_ptr, - rhs_ptr, - output_ptr, - m, - n, - k, - lhs_stride_m, - lhs_stride_k, - rhs_stride_k, - rhs_stride_n, - output_stride_m, - output_stride_n, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - pid = tl.program_id(0) - num_pid_m = tl.cdiv(m, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(n, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n - offs_k = tl.arange(0, BLOCK_SIZE_K) - lhs_ptrs = lhs_ptr + ( - offs_am[:, None] * lhs_stride_m + offs_k[None, :] * lhs_stride_k - ) - rhs_ptrs = rhs_ptr + ( - offs_k[:, None] * rhs_stride_k + offs_bn[None, :] * rhs_stride_n - ) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)): - lhs = tl.load(lhs_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K) - rhs = tl.load(rhs_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K) - accumulator = tl.dot(lhs, rhs, accumulator) - lhs_ptrs += BLOCK_SIZE_K * lhs_stride_k - rhs_ptrs += BLOCK_SIZE_K * rhs_stride_k - output = accumulator.to(tl.float16) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - output_ptrs = ( - output_ptr - + output_stride_m * offs_cm[:, None] - + output_stride_n * offs_cn[None, :] - ) - output_mask = (offs_cm[:, None] < m) & (offs_cn[None, :] < n) - tl.store(output_ptrs, output, mask=output_mask) - - -def triton_matmul(lhs, rhs): - output_shape = (lhs.shape[0], rhs.shape[1]) - output = torch.empty(output_shape, dtype=lhs.dtype, device=lhs.device) - - def grid(meta): - return ( - triton.cdiv(lhs.shape[0], meta["BLOCK_SIZE_M"]) - * triton.cdiv(rhs.shape[1], meta["BLOCK_SIZE_N"]), - ) - - triton_matmul_kernel[grid]( - lhs, - rhs, - output, - lhs.shape[0], - rhs.shape[1], - lhs.shape[1], - lhs.stride(0), - lhs.stride(1), - rhs.stride(0), - rhs.stride(1), - output.stride(0), - output.stride(1), - ) - - return output - - -if __name__ == "__main__": - torch.manual_seed(0) - - shape = (512, 512) - dtype = torch.float16 - device = "cuda" - - lhs = torch.randn(shape, dtype=dtype, device=device) - rhs = torch.randn(shape, dtype=dtype, device=device) - - ninetoothed_output = matmul(lhs, rhs) - torch_output = torch.matmul(lhs, rhs) - triton_output = triton_matmul(lhs, rhs) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["m", "n", "k"], - x_vals=[2**i for i in range(3, 13)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="matrix-multiplication-performance", - args={}, - ) - ) - def benchmark(m, n, k, provider): - lhs = torch.randn((m, k), dtype=dtype, device=device) - rhs = torch.randn((k, n), dtype=dtype, device=device) - - ninetoothed_output = matmul(lhs, rhs) - torch_output = torch.matmul(lhs, rhs) - triton_output = triton_matmul(lhs, rhs) - - assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) - assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) - - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: matmul(lhs, rhs)) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: torch.matmul(lhs, rhs)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_matmul(lhs, rhs)) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/mm.py b/mm.py new file mode 100644 index 0000000..33eb17f --- /dev/null +++ b/mm.py @@ -0,0 +1,68 @@ +import torch +import triton + +import ops.ninetoothed.torch +import ops.triton.torch + +if __name__ == "__main__": + torch.manual_seed(0) + + shape = (512, 512) + dtype = torch.float16 + device = "cuda" + + input = torch.randn(shape, dtype=dtype, device=device) + other = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.mm(input, other) + torch_output = torch.mm(input, other) + triton_output = ops.triton.torch.mm(input, other) + + print(ninetoothed_output) + print(torch_output) + print(triton_output) + + if torch.allclose(ninetoothed_output, torch_output): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["m", "n", "k"], + x_vals=[2**i for i in range(3, 13)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="matrix-multiplication-performance", + args={}, + ) + ) + def benchmark(m, n, k, provider): + input = torch.randn((m, k), dtype=dtype, device=device) + other = torch.randn((k, n), dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.mm(input, other) + torch_output = torch.mm(input, other) + triton_output = ops.triton.torch.mm(input, other) + + assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + + if provider == "ninetoothed": + ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.mm(input, other)) + elif provider == "torch": + ms = triton.testing.do_bench(lambda: torch.mm(input, other)) + elif provider == "triton": + ms = triton.testing.do_bench(lambda: ops.triton.torch.mm(input, other)) + + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/ops/ninetoothed/kernels/mm.py b/ops/ninetoothed/kernels/mm.py new file mode 100644 index 0000000..abb054d --- /dev/null +++ b/ops/ninetoothed/kernels/mm.py @@ -0,0 +1,44 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor, block_size + +BLOCK_SIZE_M = block_size() +BLOCK_SIZE_N = block_size() +BLOCK_SIZE_K = block_size() + + +def arrangement( + input, + other, + output, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, +): + output_arranged = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N)) + + input_arranged = input.tile((BLOCK_SIZE_M, BLOCK_SIZE_K)) + input_arranged = input_arranged.tile((1, -1)) + input_arranged = input_arranged.expand((-1, output_arranged.shape[1])) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + + other_arranged = other.tile((BLOCK_SIZE_K, BLOCK_SIZE_N)) + other_arranged = other_arranged.tile((-1, 1)) + other_arranged = other_arranged.expand((output_arranged.shape[0], -1)) + other_arranged.dtype = other_arranged.dtype.squeeze(1) + + return input_arranged, other_arranged, output_arranged + + +def application(input, other, output): + accumulator = ntl.zeros(output.shape, dtype=ntl.float32) + + for k in range(input.shape[0]): + accumulator += ntl.dot(input[k], other[k]) + + output = accumulator + + +tensors = (Tensor(2), Tensor(2), Tensor(2)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index 7a6acd2..669af16 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -1,6 +1,7 @@ import torch import ops.ninetoothed.kernels.add +import ops.ninetoothed.kernels.mm def add(input, other): @@ -9,3 +10,12 @@ def add(input, other): ops.ninetoothed.kernels.add.kernel(input, other, output, BLOCK_SIZE=1024) return output + + +def mm(input, other): + output_shape = (input.shape[0], other.shape[1]) + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + + ops.ninetoothed.kernels.mm.kernel(input, other, output) + + return output diff --git a/ops/triton/kernels/mm.py b/ops/triton/kernels/mm.py new file mode 100644 index 0000000..eae5fd6 --- /dev/null +++ b/ops/triton/kernels/mm.py @@ -0,0 +1,150 @@ +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + ], + key=["m", "n", "k"], +) +@triton.jit +def kernel( + input_ptr, + other_ptr, + output_ptr, + m, + n, + k, + input_stride_m, + input_stride_k, + other_stride_k, + other_stride_n, + output_stride_m, + output_stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(m, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(n, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n + offs_k = tl.arange(0, BLOCK_SIZE_K) + input_ptrs = input_ptr + ( + offs_am[:, None] * input_stride_m + offs_k[None, :] * input_stride_k + ) + other_ptrs = other_ptr + ( + offs_k[:, None] * other_stride_k + offs_bn[None, :] * other_stride_n + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + input = tl.load(input_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K) + other = tl.load(other_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K) + + accumulator = tl.dot(input, other, accumulator) + + input_ptrs += BLOCK_SIZE_K * input_stride_k + other_ptrs += BLOCK_SIZE_K * other_stride_k + + output = accumulator + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + output_ptrs = ( + output_ptr + + output_stride_m * offs_cm[:, None] + + output_stride_n * offs_cn[None, :] + ) + + tl.store(output_ptrs, output, mask=(offs_cm[:, None] < m) & (offs_cn[None, :] < n)) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index e88b200..34ec09e 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -2,6 +2,7 @@ import triton import ops.triton.kernels.add +import ops.triton.kernels.mm def add(input, other): @@ -17,3 +18,31 @@ def grid(meta): ) return output + + +def mm(input, other): + output_shape = (input.shape[0], other.shape[1]) + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + + def grid(meta): + return ( + triton.cdiv(input.shape[0], meta["BLOCK_SIZE_M"]) + * triton.cdiv(other.shape[1], meta["BLOCK_SIZE_N"]), + ) + + ops.triton.kernels.mm.kernel[grid]( + input, + other, + output, + input.shape[0], + other.shape[1], + input.shape[1], + input.stride(0), + input.stride(1), + other.stride(0), + other.stride(1), + output.stride(0), + output.stride(1), + ) + + return output From 932a1d51542d7941ea91566e172d0779004f0b4c Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 13 May 2025 23:14:58 +0800 Subject: [PATCH 27/68] Separate `addmm` kernels into modular packages --- addmm.py | 243 ++----------------------------- ops/ninetoothed/kernels/addmm.py | 22 +++ ops/ninetoothed/torch.py | 10 ++ ops/triton/kernels/addmm.py | 166 +++++++++++++++++++++ ops/triton/torch.py | 34 +++++ 5 files changed, 244 insertions(+), 231 deletions(-) create mode 100644 ops/ninetoothed/kernels/addmm.py create mode 100644 ops/triton/kernels/addmm.py diff --git a/addmm.py b/addmm.py index a3de947..59b21ad 100644 --- a/addmm.py +++ b/addmm.py @@ -1,235 +1,10 @@ import random -import ninetoothed import torch import triton -import triton.language as tl -from ninetoothed import Tensor - -import mm - - -def arrangement(input, mat1, mat2, beta, alpha, output): - _, _, input_arranged = mm.arrangement(mat1, mat2, input) - - mat1_arranged, mat2_arranged, output_arranged = mm.arrangement(mat1, mat2, output) - - return input_arranged, mat1_arranged, mat2_arranged, beta, alpha, output_arranged - - -def application(input, mat1, mat2, beta, alpha, output): - mm.application(mat1, mat2, output) - output = beta * input + alpha * output - - -tensors = (Tensor(2), Tensor(2), Tensor(2), Tensor(0), Tensor(0), Tensor(2)) -addmm_kernel = ninetoothed.make(arrangement, application, tensors) - - -def addmm(input, mat1, mat2, beta=1, alpha=1): - output_shape = (mat1.shape[0], mat2.shape[1]) - output = torch.empty(output_shape, dtype=mat1.dtype, device=mat1.device) - - addmm_kernel(input, mat1, mat2, beta, alpha, output) - - return output - - -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - ], - key=["m", "n", "k"], -) -@triton.jit -def triton_addmm_kernel( - input_ptr, - mat1_ptr, - mat2_ptr, - output_ptr, - m, - n, - k, - input_stride_m, - input_stride_n, - mat1_stride_m, - mat1_stride_k, - mat2_stride_k, - mat2_stride_n, - output_stride_m, - output_stride_n, - beta, - alpha, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - pid = tl.program_id(0) - num_pid_m = tl.cdiv(m, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(n, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n - offs_k = tl.arange(0, BLOCK_SIZE_K) - mat1_ptrs = mat1_ptr + ( - offs_am[:, None] * mat1_stride_m + offs_k[None, :] * mat1_stride_k - ) - mat2_ptrs = mat2_ptr + ( - offs_k[:, None] * mat2_stride_k + offs_bn[None, :] * mat2_stride_n - ) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)): - mat1 = tl.load( - mat1_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K, other=0.0 - ) - mat2 = tl.load( - mat2_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K, other=0.0 - ) - accumulator = tl.dot(mat1, mat2, accumulator) - mat1_ptrs += BLOCK_SIZE_K * mat1_stride_k - mat2_ptrs += BLOCK_SIZE_K * mat2_stride_k - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - mask_c = (offs_cm[:, None] < m) & (offs_cn[None, :] < n) - - input_ptrs = ( - input_ptr - + input_stride_m * offs_cm[:, None] - + input_stride_n * offs_cn[None, :] - ) - input = tl.load(input_ptrs, mask=mask_c) - - output = beta * input + alpha * accumulator.to(tl.float16) - - output_ptrs = ( - output_ptr - + output_stride_m * offs_cm[:, None] - + output_stride_n * offs_cn[None, :] - ) - tl.store(output_ptrs, output, mask=mask_c) - - -def triton_addmm(input, mat1, mat2, beta=1, alpha=1): - output_shape = (mat1.shape[0], mat2.shape[1]) - output = torch.empty(output_shape, dtype=mat1.dtype, device=mat1.device) - - def grid(meta): - return ( - triton.cdiv(mat1.shape[0], meta["BLOCK_SIZE_M"]) - * triton.cdiv(mat2.shape[1], meta["BLOCK_SIZE_N"]), - ) - - triton_addmm_kernel[grid]( - input, - mat1, - mat2, - output, - mat1.shape[0], - mat2.shape[1], - mat1.shape[1], - input.stride(0), - input.stride(1), - mat1.stride(0), - mat1.stride(1), - mat2.stride(0), - mat2.stride(1), - output.stride(0), - output.stride(1), - beta, - alpha, - ) - - return output +import ops.ninetoothed.torch +import ops.triton.torch if __name__ == "__main__": random.seed(0) @@ -245,9 +20,11 @@ def grid(meta): beta = random.uniform(0, 1) alpha = random.uniform(0, 1) - ninetoothed_output = addmm(input, mat1, mat2, beta=beta, alpha=alpha) + ninetoothed_output = ops.ninetoothed.torch.addmm( + input, mat1, mat2, beta=beta, alpha=alpha + ) torch_output = torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha) - triton_output = triton_addmm(input, mat1, mat2, beta=beta, alpha=alpha) + triton_output = ops.triton.torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha) print(ninetoothed_output) print(torch_output) @@ -284,7 +61,9 @@ def benchmark(m, n, k, provider): if provider == "ninetoothed": ms = triton.testing.do_bench( - lambda: addmm(input, mat1, mat2, beta=beta, alpha=alpha) + lambda: ops.ninetoothed.torch.addmm( + input, mat1, mat2, beta=beta, alpha=alpha + ) ) elif provider == "torch": ms = triton.testing.do_bench( @@ -292,7 +71,9 @@ def benchmark(m, n, k, provider): ) elif provider == "triton": ms = triton.testing.do_bench( - lambda: triton_addmm(input, mat1, mat2, beta=beta, alpha=alpha) + lambda: ops.triton.torch.addmm( + input, mat1, mat2, beta=beta, alpha=alpha + ) ) return ms diff --git a/ops/ninetoothed/kernels/addmm.py b/ops/ninetoothed/kernels/addmm.py new file mode 100644 index 0000000..495246a --- /dev/null +++ b/ops/ninetoothed/kernels/addmm.py @@ -0,0 +1,22 @@ +import ninetoothed +from ninetoothed import Tensor + +import ops.ninetoothed.kernels.mm as mm + + +def arrangement(input, mat1, mat2, beta, alpha, output): + _, _, input_arranged = mm.arrangement(mat1, mat2, input) + + mat1_arranged, mat2_arranged, output_arranged = mm.arrangement(mat1, mat2, output) + + return input_arranged, mat1_arranged, mat2_arranged, beta, alpha, output_arranged + + +def application(input, mat1, mat2, beta, alpha, output): + mm.application(mat1, mat2, output) + output = beta * input + alpha * output + + +tensors = (Tensor(2), Tensor(2), Tensor(2), Tensor(0), Tensor(0), Tensor(2)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index 669af16..aa0c898 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -1,6 +1,7 @@ import torch import ops.ninetoothed.kernels.add +import ops.ninetoothed.kernels.addmm import ops.ninetoothed.kernels.mm @@ -12,6 +13,15 @@ def add(input, other): return output +def addmm(input, mat1, mat2, beta=1, alpha=1): + output_shape = (mat1.shape[0], mat2.shape[1]) + output = torch.empty(output_shape, dtype=mat1.dtype, device=mat1.device) + + ops.ninetoothed.kernels.addmm.kernel(input, mat1, mat2, beta, alpha, output) + + return output + + def mm(input, other): output_shape = (input.shape[0], other.shape[1]) output = torch.empty(output_shape, dtype=input.dtype, device=input.device) diff --git a/ops/triton/kernels/addmm.py b/ops/triton/kernels/addmm.py new file mode 100644 index 0000000..5ba7dc5 --- /dev/null +++ b/ops/triton/kernels/addmm.py @@ -0,0 +1,166 @@ +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + ], + key=["m", "n", "k"], +) +@triton.jit +def kernel( + input_ptr, + mat1_ptr, + mat2_ptr, + output_ptr, + m, + n, + k, + input_stride_m, + input_stride_n, + mat1_stride_m, + mat1_stride_k, + mat2_stride_k, + mat2_stride_n, + output_stride_m, + output_stride_n, + beta, + alpha, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(m, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(n, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n + offs_k = tl.arange(0, BLOCK_SIZE_K) + mat1_ptrs = mat1_ptr + ( + offs_am[:, None] * mat1_stride_m + offs_k[None, :] * mat1_stride_k + ) + mat2_ptrs = mat2_ptr + ( + offs_k[:, None] * mat2_stride_k + offs_bn[None, :] * mat2_stride_n + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + mat1 = tl.load(mat1_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K) + mat2 = tl.load(mat2_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K) + + accumulator = tl.dot(mat1, mat2, accumulator) + + mat1_ptrs += BLOCK_SIZE_K * mat1_stride_k + mat2_ptrs += BLOCK_SIZE_K * mat2_stride_k + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + mask_c = (offs_cm[:, None] < m) & (offs_cn[None, :] < n) + + input_ptrs = ( + input_ptr + + input_stride_m * offs_cm[:, None] + + input_stride_n * offs_cn[None, :] + ) + + input = tl.load(input_ptrs, mask=mask_c) + + output = beta * input + alpha * accumulator.to(tl.float16) + + output_ptrs = ( + output_ptr + + output_stride_m * offs_cm[:, None] + + output_stride_n * offs_cn[None, :] + ) + + tl.store(output_ptrs, output, mask=mask_c) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index 34ec09e..b4d7c5f 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -2,6 +2,7 @@ import triton import ops.triton.kernels.add +import ops.triton.kernels.addmm import ops.triton.kernels.mm @@ -20,6 +21,39 @@ def grid(meta): return output +def addmm(input, mat1, mat2, beta=1, alpha=1): + output_shape = (mat1.shape[0], mat2.shape[1]) + output = torch.empty(output_shape, dtype=mat1.dtype, device=mat1.device) + + def grid(meta): + return ( + triton.cdiv(mat1.shape[0], meta["BLOCK_SIZE_M"]) + * triton.cdiv(mat2.shape[1], meta["BLOCK_SIZE_N"]), + ) + + ops.triton.kernels.addmm.kernel[grid]( + input, + mat1, + mat2, + output, + mat1.shape[0], + mat2.shape[1], + mat1.shape[1], + input.stride(0), + input.stride(1), + mat1.stride(0), + mat1.stride(1), + mat2.stride(0), + mat2.stride(1), + output.stride(0), + output.stride(1), + beta, + alpha, + ) + + return output + + def mm(input, other): output_shape = (input.shape[0], other.shape[1]) output = torch.empty(output_shape, dtype=input.dtype, device=input.device) From 97b2842c2c32873f111736d574aeeb1f97a9c3a9 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 13 May 2025 23:23:49 +0800 Subject: [PATCH 28/68] Separate `conv2d` kernels into modular packages --- conv2d.py | 286 ++---------------------------- ops/ninetoothed/kernels/conv2d.py | 25 +++ ops/ninetoothed/torch.py | 14 ++ ops/triton/kernels/conv2d.py | 202 +++++++++++++++++++++ ops/triton/torch.py | 34 ++++ 5 files changed, 287 insertions(+), 274 deletions(-) create mode 100644 ops/ninetoothed/kernels/conv2d.py create mode 100644 ops/triton/kernels/conv2d.py diff --git a/conv2d.py b/conv2d.py index ab79220..20dec55 100644 --- a/conv2d.py +++ b/conv2d.py @@ -1,275 +1,9 @@ -import ninetoothed import torch import torch.nn.functional as F import triton -import triton.language as tl -from ninetoothed import Tensor - -import mm - - -def arrangement(input, filter, output): - input_tiled = input.tile((1, *filter.shape[1:]), strides=(-1, -1, 1, 1)) - input_squeezed = input_tiled.squeeze(1) - input_squeezed.dtype = input_squeezed.dtype.squeeze(0) - input_raveled = input_squeezed.ravel() - input_flattened = input_raveled.flatten(end_dim=3).flatten(start_dim=1) - - filter_flattened = filter.flatten(start_dim=1) - filter_permuted = filter_flattened.permute((1, 0)) - - output_flattened = output.permute((0, 2, 3, 1)).flatten(end_dim=3) - - return mm.arrangement(input_flattened, filter_permuted, output_flattened) - - -shape_options = {"constexpr": True, "upper_bound": 16} -tensors = tuple(Tensor(4, shape_options=shape_options) for _ in range(3)) -conv2d_kernel = ninetoothed.make(arrangement, mm.application, tensors) - - -def conv2d(input, filter): - n, _, h, w = input.shape - k, _, r, s = filter.shape - p = h - r + 1 - q = w - s + 1 - - output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device) - - conv2d_kernel(input, filter, output) - - return output - - -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), - ], - key=["N", "C", "H", "W", "C", "R", "S"], -) -@triton.jit -def triton_conv2d_kernel( - input_ptr, - filter_ptr, - output_ptr, - N: tl.constexpr, - C: tl.constexpr, - H: tl.constexpr, - W: tl.constexpr, - K: tl.constexpr, - R: tl.constexpr, - S: tl.constexpr, - input_stride_n, - input_stride_c, - input_stride_h, - input_stride_w, - filter_stride_k, - filter_stride_c, - filter_stride_r, - filter_stride_s, - output_stride_n, - output_stride_k, - output_stride_p, - output_stride_q, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - P: tl.constexpr = H - R + 1 - Q: tl.constexpr = W - S + 1 - - GEMM_M: tl.constexpr = N * P * Q - GEMM_N: tl.constexpr = K - GEMM_K: tl.constexpr = C * R * S - - pid = tl.program_id(0) - num_pid_gemm_m = tl.cdiv(GEMM_M, BLOCK_SIZE_M) - num_pid_gemm_n = tl.cdiv(GEMM_N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_gemm_n - group_id = pid // num_pid_in_group - first_pid_gemm_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_gemm_m - first_pid_gemm_m, GROUP_SIZE_M) - pid_gemm_m = first_pid_gemm_m + ((pid % num_pid_in_group) % group_size_m) - pid_gemm_n = (pid % num_pid_in_group) // group_size_m - - offs_gemm_i = pid_gemm_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_gemm_j = pid_gemm_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - offs_n = offs_gemm_i // (P * Q) - offs_k = offs_gemm_j - npq_residual = offs_gemm_i % (P * Q) - offs_p = npq_residual // Q - offs_q = npq_residual % Q - - input_offs_gemm_m = ( - offs_n * input_stride_n + offs_p * input_stride_h + offs_q * input_stride_w - )[:, None] - filter_offs_gemm_n = (offs_k * filter_stride_k)[None, :] - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for i in range(0, tl.cdiv(GEMM_K, BLOCK_SIZE_K)): - offs_gemm_k = i * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - - offs_c = offs_gemm_k // (R * S) - crs_residual = offs_gemm_k % (R * S) - offs_r = crs_residual // S - offs_s = crs_residual % S - - input_offs_gemm_n = ( - offs_c * input_stride_c + offs_r * input_stride_h + offs_s * input_stride_w - )[None, :] - input_ptrs = input_ptr + input_offs_gemm_m + input_offs_gemm_n - input_mask = ((offs_n < N) & (offs_p < P) & (offs_q < Q))[:, None] & ( - (offs_c < C) & (offs_r < R) & (offs_s < S) - )[None, :] - input = tl.load(input_ptrs, mask=input_mask) - - filter_offs_gemm_m = ( - offs_c * filter_stride_c - + offs_r * filter_stride_r - + offs_s * filter_stride_s - )[:, None] - filter_ptrs = filter_ptr + filter_offs_gemm_m + filter_offs_gemm_n - filter_mask = (offs_k[None, :] < K) & ( - (offs_c < C) & (offs_r < R) & (offs_s < S) - )[:, None] - filter = tl.load(filter_ptrs, mask=filter_mask) - - accumulator = tl.dot(input, filter, accumulator) - - output = accumulator.to(tl.float16) - - output_ptrs = ( - output_ptr - + ( - offs_n * output_stride_n - + offs_p * output_stride_p - + offs_q * output_stride_q - )[:, None] - + (offs_k * output_stride_k)[None, :] - ) - output_mask = ( - (offs_n[:, None] < N) - & (offs_k[None, :] < K) - & (offs_p[:, None] < P) - & (offs_q[:, None] < Q) - ) - tl.store(output_ptrs, output, mask=output_mask) - - -def triton_conv2d(input, filter): - n, c, h, w = input.shape - k, _, r, s = filter.shape - p = h - r + 1 - q = w - s + 1 - - output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device) - - def grid(meta): - return ( - triton.cdiv(n * p * q, meta["BLOCK_SIZE_M"]) - * triton.cdiv(k, meta["BLOCK_SIZE_N"]), - ) - - triton_conv2d_kernel[grid]( - input, - filter, - output, - n, - c, - h, - w, - k, - r, - s, - *input.stride(), - *filter.stride(), - *output.stride(), - ) - - return output +import ops.ninetoothed.torch +import ops.triton.torch if __name__ == "__main__": torch.manual_seed(0) @@ -282,9 +16,9 @@ def grid(meta): input = torch.randn(n, c, h, w, dtype=dtype, device=device) filter = torch.randn(k, c, r, s, dtype=dtype, device=device) - ninetoothed_output = conv2d(input, filter) + ninetoothed_output = ops.ninetoothed.torch.conv2d(input, filter) torch_output = F.conv2d(input, filter) - triton_output = triton_conv2d(input, filter) + triton_output = ops.triton.torch.triton_conv2d(input, filter) print(ninetoothed_output) print(torch_output) @@ -320,19 +54,23 @@ def benchmark(n, provider): input = torch.randn((n, c, h, w), dtype=dtype, device=device) filter = torch.randn((k, c, r, s), dtype=dtype, device=device) - ninetoothed_output = conv2d(input, filter) + ninetoothed_output = ops.ninetoothed.torch.conv2d(input, filter) torch_output = F.conv2d(input, filter) - triton_output = triton_conv2d(input, filter) + triton_output = ops.triton.torch.triton_conv2d(input, filter) assert torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01) assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: conv2d(input, filter)) + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.conv2d(input, filter) + ) elif provider == "torch": ms = triton.testing.do_bench(lambda: F.conv2d(input, filter)) elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_conv2d(input, filter)) + ms = triton.testing.do_bench( + lambda: ops.triton.torch.triton_conv2d(input, filter) + ) return ms diff --git a/ops/ninetoothed/kernels/conv2d.py b/ops/ninetoothed/kernels/conv2d.py new file mode 100644 index 0000000..514bba9 --- /dev/null +++ b/ops/ninetoothed/kernels/conv2d.py @@ -0,0 +1,25 @@ +import ninetoothed +from ninetoothed import Tensor + +import ops.ninetoothed.kernels.mm as mm + + +def arrangement(input, filter, output): + input_arranged = input.tile((1, *filter.shape[1:]), strides=(-1, -1, 1, 1)) + input_arranged = input_arranged.squeeze(1) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1) + + filter_arranged = filter.flatten(start_dim=1) + filter_arranged = filter_arranged.permute((1, 0)) + + output_arranged = output.permute((0, 2, 3, 1)).flatten(end_dim=3) + + return mm.arrangement(input_arranged, filter_arranged, output_arranged) + + +shape_options = {"constexpr": True, "upper_bound": 16} +tensors = tuple(Tensor(4, shape_options=shape_options) for _ in range(3)) + +kernel = ninetoothed.make(arrangement, mm.application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index aa0c898..665b972 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -2,6 +2,7 @@ import ops.ninetoothed.kernels.add import ops.ninetoothed.kernels.addmm +import ops.ninetoothed.kernels.conv2d import ops.ninetoothed.kernels.mm @@ -22,6 +23,19 @@ def addmm(input, mat1, mat2, beta=1, alpha=1): return output +def conv2d(input, filter): + n, _, h, w = input.shape + k, _, r, s = filter.shape + p = h - r + 1 + q = w - s + 1 + + output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device) + + ops.ninetoothed.kernels.conv2d.kernel(input, filter, output) + + return output + + def mm(input, other): output_shape = (input.shape[0], other.shape[1]) output = torch.empty(output_shape, dtype=input.dtype, device=input.device) diff --git a/ops/triton/kernels/conv2d.py b/ops/triton/kernels/conv2d.py new file mode 100644 index 0000000..6e28ad7 --- /dev/null +++ b/ops/triton/kernels/conv2d.py @@ -0,0 +1,202 @@ +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + ], + key=["N", "C", "H", "W", "C", "R", "S"], +) +@triton.jit +def kernel( + input_ptr, + filter_ptr, + output_ptr, + N: tl.constexpr, + C: tl.constexpr, + H: tl.constexpr, + W: tl.constexpr, + K: tl.constexpr, + R: tl.constexpr, + S: tl.constexpr, + input_stride_n, + input_stride_c, + input_stride_h, + input_stride_w, + filter_stride_k, + filter_stride_c, + filter_stride_r, + filter_stride_s, + output_stride_n, + output_stride_k, + output_stride_p, + output_stride_q, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + P: tl.constexpr = H - R + 1 + Q: tl.constexpr = W - S + 1 + + GEMM_M: tl.constexpr = N * P * Q + GEMM_N: tl.constexpr = K + GEMM_K: tl.constexpr = C * R * S + + pid = tl.program_id(0) + num_pid_gemm_m = tl.cdiv(GEMM_M, BLOCK_SIZE_M) + num_pid_gemm_n = tl.cdiv(GEMM_N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_gemm_n + group_id = pid // num_pid_in_group + first_pid_gemm_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_gemm_m - first_pid_gemm_m, GROUP_SIZE_M) + pid_gemm_m = first_pid_gemm_m + ((pid % num_pid_in_group) % group_size_m) + pid_gemm_n = (pid % num_pid_in_group) // group_size_m + + offs_gemm_i = pid_gemm_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_gemm_j = pid_gemm_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + offs_n = offs_gemm_i // (P * Q) + offs_k = offs_gemm_j + npq_residual = offs_gemm_i % (P * Q) + offs_p = npq_residual // Q + offs_q = npq_residual % Q + + input_offs_gemm_m = ( + offs_n * input_stride_n + offs_p * input_stride_h + offs_q * input_stride_w + )[:, None] + filter_offs_gemm_n = (offs_k * filter_stride_k)[None, :] + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for i in range(0, tl.cdiv(GEMM_K, BLOCK_SIZE_K)): + offs_gemm_k = i * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + offs_c = offs_gemm_k // (R * S) + crs_residual = offs_gemm_k % (R * S) + offs_r = crs_residual // S + offs_s = crs_residual % S + + input_offs_gemm_n = ( + offs_c * input_stride_c + offs_r * input_stride_h + offs_s * input_stride_w + )[None, :] + input_ptrs = input_ptr + input_offs_gemm_m + input_offs_gemm_n + input_mask = ((offs_n < N) & (offs_p < P) & (offs_q < Q))[:, None] & ( + (offs_c < C) & (offs_r < R) & (offs_s < S) + )[None, :] + + input = tl.load(input_ptrs, mask=input_mask) + + filter_offs_gemm_m = ( + offs_c * filter_stride_c + + offs_r * filter_stride_r + + offs_s * filter_stride_s + )[:, None] + filter_ptrs = filter_ptr + filter_offs_gemm_m + filter_offs_gemm_n + filter_mask = (offs_k[None, :] < K) & ( + (offs_c < C) & (offs_r < R) & (offs_s < S) + )[:, None] + + filter = tl.load(filter_ptrs, mask=filter_mask) + + accumulator = tl.dot(input, filter, accumulator) + + output = accumulator + + output_ptrs = ( + output_ptr + + ( + offs_n * output_stride_n + + offs_p * output_stride_p + + offs_q * output_stride_q + )[:, None] + + (offs_k * output_stride_k)[None, :] + ) + output_mask = ( + (offs_n[:, None] < N) + & (offs_k[None, :] < K) + & (offs_p[:, None] < P) + & (offs_q[:, None] < Q) + ) + + tl.store(output_ptrs, output, mask=output_mask) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index b4d7c5f..cd63f97 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -3,6 +3,7 @@ import ops.triton.kernels.add import ops.triton.kernels.addmm +import ops.triton.kernels.conv2d import ops.triton.kernels.mm @@ -54,6 +55,39 @@ def grid(meta): return output +def triton_conv2d(input, filter): + n, c, h, w = input.shape + k, _, r, s = filter.shape + p = h - r + 1 + q = w - s + 1 + + output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device) + + def grid(meta): + return ( + triton.cdiv(n * p * q, meta["BLOCK_SIZE_M"]) + * triton.cdiv(k, meta["BLOCK_SIZE_N"]), + ) + + ops.triton.kernels.conv2d.kernel[grid]( + input, + filter, + output, + n, + c, + h, + w, + k, + r, + s, + *input.stride(), + *filter.stride(), + *output.stride(), + ) + + return output + + def mm(input, other): output_shape = (input.shape[0], other.shape[1]) output = torch.empty(output_shape, dtype=input.dtype, device=input.device) From 4db4e5f0af0c576f17b2296be92facbea923a8ab Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 13 May 2025 23:36:31 +0800 Subject: [PATCH 29/68] Separate `softmax` kernels into modular packages --- ops/ninetoothed/kernels/softmax.py | 24 +++++++++ ops/ninetoothed/torch.py | 9 ++++ ops/triton/kernels/softmax.py | 31 ++++++++++++ ops/triton/torch.py | 16 ++++++ softmax.py | 81 +++--------------------------- 5 files changed, 88 insertions(+), 73 deletions(-) create mode 100644 ops/ninetoothed/kernels/softmax.py create mode 100644 ops/triton/kernels/softmax.py diff --git a/ops/ninetoothed/kernels/softmax.py b/ops/ninetoothed/kernels/softmax.py new file mode 100644 index 0000000..6d24e66 --- /dev/null +++ b/ops/ninetoothed/kernels/softmax.py @@ -0,0 +1,24 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(input, output, BLOCK_SIZE=BLOCK_SIZE): + return input.tile((1, BLOCK_SIZE)), output.tile((1, BLOCK_SIZE)) + + +def application(input, output): + input_loaded = input + + row_minus_max = input_loaded - ntl.max(input_loaded) + numerator = ntl.exp(row_minus_max) + denominator = ntl.sum(numerator) + + output = numerator / denominator # noqa: F841 + + +tensors = (Tensor(2, other=float("-inf")), Tensor(2)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index 665b972..f51a0e8 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -4,6 +4,7 @@ import ops.ninetoothed.kernels.addmm import ops.ninetoothed.kernels.conv2d import ops.ninetoothed.kernels.mm +import ops.ninetoothed.kernels.softmax def add(input, other): @@ -43,3 +44,11 @@ def mm(input, other): ops.ninetoothed.kernels.mm.kernel(input, other, output) return output + + +def softmax(input): + output = torch.empty_like(input) + + ops.ninetoothed.kernels.softmax.kernel(input, output, BLOCK_SIZE=input.shape[-1]) + + return output diff --git a/ops/triton/kernels/softmax.py b/ops/triton/kernels/softmax.py new file mode 100644 index 0000000..1514907 --- /dev/null +++ b/ops/triton/kernels/softmax.py @@ -0,0 +1,31 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel( + input_ptr, + output_ptr, + input_stride, + output_stride, + num_cols, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + + row_start_ptr = input_ptr + row_idx * input_stride + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + mask = col_offsets < num_cols + + row = tl.load(input_ptrs, mask=mask, other=float("-inf")) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + output_row_start_ptr = output_ptr + row_idx * output_stride + output_ptrs = output_row_start_ptr + col_offsets + + tl.store(output_ptrs, softmax_output, mask=mask) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index cd63f97..ef692c1 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -5,6 +5,7 @@ import ops.triton.kernels.addmm import ops.triton.kernels.conv2d import ops.triton.kernels.mm +import ops.triton.kernels.softmax def add(input, other): @@ -114,3 +115,18 @@ def grid(meta): ) return output + + +def softmax(input): + output = torch.empty_like(input) + + ops.triton.kernels.softmax.kernel[(input.shape[0],)]( + input, + output, + input.stride(0), + output.stride(0), + input.shape[1], + BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]), + ) + + return output diff --git a/softmax.py b/softmax.py index ee417bc..f979928 100644 --- a/softmax.py +++ b/softmax.py @@ -1,73 +1,8 @@ -import ninetoothed -import ninetoothed.language as ntl import torch import triton -import triton.language as tl -from ninetoothed import Symbol, Tensor - -BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) - - -@ninetoothed.jit -def softmax_kernel( - input_row: Tensor(2, other=float("-inf")).tile((1, BLOCK_SIZE)), - output_row: Tensor(2).tile((1, BLOCK_SIZE)), -): - row_minus_max = input_row - ntl.max(input_row) - numerator = ntl.exp(row_minus_max) - denominator = ntl.sum(numerator) - output_row = numerator / denominator # noqa: F841 - - -def softmax(input): - output = torch.empty_like(input) - - softmax_kernel(input, output, BLOCK_SIZE=input.shape[-1]) - - return output - - -@triton.jit -def triton_softmax_kernel( - input_ptr, - output_ptr, - input_row_stride, - output_row_stride, - n_cols, - BLOCK_SIZE: tl.constexpr, -): - row_idx = tl.program_id(0) - - row_start_ptr = input_ptr + row_idx * input_row_stride - col_offsets = tl.arange(0, BLOCK_SIZE) - input_ptrs = row_start_ptr + col_offsets - mask = col_offsets < n_cols - - row = tl.load(input_ptrs, mask=mask, other=float("-inf")) - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - output_row_start_ptr = output_ptr + row_idx * output_row_stride - output_ptrs = output_row_start_ptr + col_offsets - tl.store(output_ptrs, softmax_output, mask=mask) - - -def triton_softmax(input): - output = torch.empty_like(input) - - triton_softmax_kernel[(input.shape[0],)]( - input, - output, - input.stride(0), - output.stride(0), - input.shape[1], - BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]), - ) - - return output +import ops.ninetoothed.torch +import ops.triton.torch if __name__ == "__main__": torch.manual_seed(0) @@ -77,9 +12,9 @@ def triton_softmax(input): input = torch.randn(1823, 781, dtype=dtype, device=device) - ninetoothed_output = softmax(input) + ninetoothed_output = ops.ninetoothed.torch.softmax(input) torch_output = torch.softmax(input, axis=-1) - triton_output = triton_softmax(input) + triton_output = ops.triton.torch.softmax(input) print(ninetoothed_output) print(torch_output) @@ -111,19 +46,19 @@ def triton_softmax(input): def benchmark(m, n, provider): input = torch.randn(m, n, dtype=dtype, device=device) - ninetoothed_output = softmax(input) + ninetoothed_output = ops.ninetoothed.torch.softmax(input) torch_output = torch.softmax(input, axis=-1) - triton_output = triton_softmax(input) + triton_output = ops.triton.torch.softmax(input) assert torch.allclose(ninetoothed_output, torch_output, atol=0.001) assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: softmax(input)) + ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.softmax(input)) elif provider == "torch": ms = triton.testing.do_bench(lambda: torch.softmax(input, axis=-1)) elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_softmax(input)) + ms = triton.testing.do_bench(lambda: ops.triton.torch.softmax(input)) return ms From 236add48b358a1e9075edc383fd73b63700009f4 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 14 May 2025 10:03:04 +0800 Subject: [PATCH 30/68] Separate `rms_norm` kernels into modular packages --- ops/ninetoothed/kernels/rms_norm.py | 21 ++++++++ ops/ninetoothed/torch.py | 14 +++++ ops/triton/kernels/rms_norm.py | 27 ++++++++++ ops/triton/torch.py | 20 ++++++++ rms_norm.py | 79 +++-------------------------- 5 files changed, 90 insertions(+), 71 deletions(-) create mode 100644 ops/ninetoothed/kernels/rms_norm.py create mode 100644 ops/triton/kernels/rms_norm.py diff --git a/ops/ninetoothed/kernels/rms_norm.py b/ops/ninetoothed/kernels/rms_norm.py new file mode 100644 index 0000000..f2b926e --- /dev/null +++ b/ops/ninetoothed/kernels/rms_norm.py @@ -0,0 +1,21 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(input, eps, output, BLOCK_SIZE=BLOCK_SIZE): + return input.tile((1, BLOCK_SIZE)), eps, output.tile((1, BLOCK_SIZE)) + + +def application(input, eps, output): + input_fp32 = ntl.cast(input, ntl.float32) + output = input_fp32 * ntl.rsqrt( # noqa: F841 + ntl.sum(input_fp32 * input_fp32) / input.shape[-1] + eps + ) + + +tensors = (Tensor(2), Tensor(0), Tensor(2)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index f51a0e8..e37891f 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -4,6 +4,7 @@ import ops.ninetoothed.kernels.addmm import ops.ninetoothed.kernels.conv2d import ops.ninetoothed.kernels.mm +import ops.ninetoothed.kernels.rms_norm import ops.ninetoothed.kernels.softmax @@ -46,6 +47,19 @@ def mm(input, other): return output +def rms_norm(input, eps=None): + if eps is None: + eps = torch.finfo(input.dtype).eps + + output = torch.empty_like(input) + + ops.ninetoothed.kernels.rms_norm.kernel( + input, eps, output, BLOCK_SIZE=input.shape[-1] + ) + + return output + + def softmax(input): output = torch.empty_like(input) diff --git a/ops/triton/kernels/rms_norm.py b/ops/triton/kernels/rms_norm.py new file mode 100644 index 0000000..852d06f --- /dev/null +++ b/ops/triton/kernels/rms_norm.py @@ -0,0 +1,27 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel( + input_ptr, + output_ptr, + num_cols, + input_stride, + output_stride, + eps: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = input_ptr + row_idx * input_stride + col_offsets + mask = col_offsets < num_cols + + input = tl.load(input_ptrs, mask=mask).to(tl.float32) + + output = input * tl.rsqrt(tl.sum(input * input) / num_cols + eps) + + output_ptrs = output_ptr + row_idx * output_stride + col_offsets + + tl.store(output_ptrs, output, mask=mask) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index ef692c1..b2481a2 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -5,6 +5,7 @@ import ops.triton.kernels.addmm import ops.triton.kernels.conv2d import ops.triton.kernels.mm +import ops.triton.kernels.rms_norm import ops.triton.kernels.softmax @@ -117,6 +118,25 @@ def grid(meta): return output +def rms_norm(input, eps=None): + if eps is None: + eps = torch.finfo(input.dtype).eps + + output = torch.empty_like(input) + + ops.triton.kernels.rms_norm.kernel[(input.shape[-2],)]( + input, + output, + input.shape[-1], + input.stride(-2), + output.stride(-2), + eps, + BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]), + ) + + return output + + def softmax(input): output = torch.empty_like(input) diff --git a/rms_norm.py b/rms_norm.py index 18f8429..c06ee7f 100644 --- a/rms_norm.py +++ b/rms_norm.py @@ -1,72 +1,9 @@ -import ninetoothed -import ninetoothed.language as ntl import torch import torch.nn.functional as F import triton -import triton.language as tl -from ninetoothed import Symbol, Tensor - -BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) - - -@ninetoothed.jit -def rms_norm_kernel( - input: Tensor(2).tile((1, BLOCK_SIZE)), - output: Tensor(2).tile((1, BLOCK_SIZE)), - eps: Tensor(0), -): - input_fp32 = ntl.cast(input, ntl.float32) - output = input_fp32 * ntl.rsqrt( # noqa: F841 - ntl.sum(input_fp32 * input_fp32) / input.shape[-1] + eps - ) - - -def rms_norm(input, eps=1e-5): - output = torch.empty_like(input) - - rms_norm_kernel(input, output, eps, BLOCK_SIZE=input.shape[-1]) - - return output - - -@triton.jit -def triton_rms_norm_kernel( - input_ptr, - output_ptr, - num_cols, - input_row_stride, - output_row_stride, - eps: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - row_idx = tl.program_id(0) - - col_offsets = tl.arange(0, BLOCK_SIZE) - input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets - mask = col_offsets < num_cols - input = tl.load(input_ptrs, mask=mask).to(tl.float32) - - output = input * tl.rsqrt(tl.sum(input * input) / num_cols + eps) - - output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets - tl.store(output_ptrs, output, mask=mask) - - -def triton_rms_norm(input, eps=1e-5): - output = torch.empty_like(input) - - triton_rms_norm_kernel[(input.shape[-2],)]( - input, - output, - input.shape[-1], - input.stride(-2), - output.stride(-2), - eps, - BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]), - ) - - return output +import ops.ninetoothed.torch +import ops.triton.torch if __name__ == "__main__": torch.manual_seed(0) @@ -76,9 +13,9 @@ def triton_rms_norm(input, eps=1e-5): input = torch.randn(1151, 8192, dtype=dtype, device=device) - ninetoothed_output = rms_norm(input) + ninetoothed_output = ops.ninetoothed.torch.rms_norm(input) torch_output = F.rms_norm(input, input.shape[-1:]) - triton_output = triton_rms_norm(input) + triton_output = ops.triton.torch.rms_norm(input) print(ninetoothed_output) print(torch_output) @@ -110,21 +47,21 @@ def triton_rms_norm(input, eps=1e-5): def benchmark(m, n, provider): input = torch.randn(m, n, dtype=dtype, device=device) - ninetoothed_output = rms_norm(input) + ninetoothed_output = ops.ninetoothed.torch.rms_norm(input) torch_output = F.rms_norm(input, input.shape[-1:]) - triton_output = triton_rms_norm(input) + triton_output = ops.triton.torch.rms_norm(input) assert torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005) assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: rms_norm(input)) + ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.rms_norm(input)) elif provider == "torch": ms = triton.testing.do_bench( lambda: torch.rms_norm(input, input.shape[-1:]) ) elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_rms_norm(input)) + ms = triton.testing.do_bench(lambda: ops.triton.torch.rms_norm(input)) return ms From a33a14c95ffd1c96161a627f13c013e5adab99c9 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 14 May 2025 10:06:36 +0800 Subject: [PATCH 31/68] Rename `ops.triton.torch.triton_conv2d` to `ops.triton.torch.conv2d` --- conv2d.py | 8 +++----- ops/triton/torch.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/conv2d.py b/conv2d.py index 20dec55..60e48c8 100644 --- a/conv2d.py +++ b/conv2d.py @@ -18,7 +18,7 @@ ninetoothed_output = ops.ninetoothed.torch.conv2d(input, filter) torch_output = F.conv2d(input, filter) - triton_output = ops.triton.torch.triton_conv2d(input, filter) + triton_output = ops.triton.torch.conv2d(input, filter) print(ninetoothed_output) print(torch_output) @@ -56,7 +56,7 @@ def benchmark(n, provider): ninetoothed_output = ops.ninetoothed.torch.conv2d(input, filter) torch_output = F.conv2d(input, filter) - triton_output = ops.triton.torch.triton_conv2d(input, filter) + triton_output = ops.triton.torch.conv2d(input, filter) assert torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01) assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) @@ -68,9 +68,7 @@ def benchmark(n, provider): elif provider == "torch": ms = triton.testing.do_bench(lambda: F.conv2d(input, filter)) elif provider == "triton": - ms = triton.testing.do_bench( - lambda: ops.triton.torch.triton_conv2d(input, filter) - ) + ms = triton.testing.do_bench(lambda: ops.triton.torch.conv2d(input, filter)) return ms diff --git a/ops/triton/torch.py b/ops/triton/torch.py index b2481a2..3c61942 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -57,7 +57,7 @@ def grid(meta): return output -def triton_conv2d(input, filter): +def conv2d(input, filter): n, c, h, w = input.shape k, _, r, s = filter.shape p = h - r + 1 From be1f6946f370b4ba0025782fbe64de1b25e7c35b Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 14 May 2025 10:36:29 +0800 Subject: [PATCH 32/68] Rename `attention` to `scaled_dot_product_attention` and separate the kernels into modular packages --- attention.py | 382 ------------------ infer.py | 2 +- .../kernels/scaled_dot_product_attention.py | 59 +++ ops/ninetoothed/torch.py | 14 + .../kernels/scaled_dot_product_attention.py | 126 ++++++ ops/triton/torch.py | 35 ++ scaled_dot_product_attention.py | 163 ++++++++ 7 files changed, 398 insertions(+), 383 deletions(-) delete mode 100644 attention.py create mode 100644 ops/ninetoothed/kernels/scaled_dot_product_attention.py create mode 100644 ops/triton/kernels/scaled_dot_product_attention.py create mode 100644 scaled_dot_product_attention.py diff --git a/attention.py b/attention.py deleted file mode 100644 index 9b90c28..0000000 --- a/attention.py +++ /dev/null @@ -1,382 +0,0 @@ -import functools -import math - -import ninetoothed -import ninetoothed.language as ntl -import torch -import torch.nn as nn -import torch.nn.functional as F -import triton -import triton.language as tl -from ninetoothed import Symbol, Tensor -from transformers.models.llama.modeling_llama import repeat_kv - -import rope - - -def arrangement(q, k, v, scale, o): - BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) - BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) - - def arrange_q_or_o(input): - arranged = input.tile((1, 1, BLOCK_SIZE_M, -1)) - arranged.dtype = arranged.dtype.squeeze((0, 1)) - - return arranged - - def arrange_k_or_v(input): - arranged = ( - input.tile((1, 1, BLOCK_SIZE_N, -1)) - .tile((1, 1, -1, -1)) - .expand((-1, -1, q_arranged.shape[-2], -1)) - ) - arranged.dtype = arranged.dtype.squeeze((0, 1, 3)) - arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) - - return arranged - - q_arranged = arrange_q_or_o(q) - - return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), scale, arrange_q_or_o(o) - - -def application(q, k, v, scale, o): - q_loaded = (q * scale * 1.44269504089).to(ntl.float16) - - acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32) - l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32) - m_i = ntl.full((q.shape[-2],), float("-inf"), dtype=ntl.float32) - - for i in range(k.shape[0]): - qk = ntl.dot(q_loaded, ntl.trans(k[i])) - - m_ij = ntl.maximum(m_i, ntl.max(qk, 1)) - p = ntl.exp2(qk - m_ij[:, None]) - l_ij = ntl.sum(p, 1) - - alpha = ntl.exp2(m_i - m_ij) - acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i]) - m_i = m_ij - l_i = l_i * alpha + l_ij - - acc /= l_i[:, None] - o = acc # noqa: F841 - - -shape_options = (None, None, None, {"constexpr": True, "upper_bound": 128}) -q, k, v, o = (Tensor(4, shape_options=shape_options) for _ in range(4)) -attention_kernel = ninetoothed.make(arrangement, application, (q, k, v, Tensor(0), o)) - - -def attention(q, k, v, scale=None): - if scale is None: - scale = 1 / math.sqrt(q.shape[-1]) - - o = torch.empty_like(q, dtype=v.dtype) - - attention_kernel(q, k, v, scale, o) - - return o - - -class Attention(nn.Module): - def __init__(self, other): - super().__init__() - - self.__dict__ = other.__dict__ - - def forward( - self, - hidden_states, - position_embeddings, - attention_mask, - past_key_value, - cache_position, - **kwargs, - ): - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape) - key_states = self.k_proj(hidden_states).view(hidden_shape) - value_states = self.v_proj(hidden_states).view(hidden_shape) - - cos_table, sin_table = position_embeddings - - _rope(query_states, sin_table, cos_table) - _rope(key_states, sin_table, cos_table) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if past_key_value is not None: - cache_kwargs = { - "sin": sin_table, - "cos": cos_table, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_dtype = torch.float16 - attn_output = attention( - query_states.to(attn_dtype), - key_states.to(attn_dtype), - value_states.to(attn_dtype), - scale=self.scaling, - ).to(query_states.dtype) - attn_output = attn_output.transpose(1, 2) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - - return attn_output, None - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=8 - ), - ], - key=["EMB_DIM"], -) -@triton.jit -def triton_attention_kernel( - q_ptr, - k_ptr, - v_ptr, - o_ptr, - q_stride_z, - q_stride_h, - q_stride_m, - q_stride_k, - k_stride_z, - k_stride_h, - k_stride_n, - k_stride_k, - v_stride_z, - v_stride_h, - v_stride_k, - v_stride_n, - o_stride_z, - o_stride_h, - o_stride_m, - o_stride_n, - scale, - SEQ_LEN: tl.constexpr, - EMB_DIM: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, -): - off_m = tl.program_id(0) - off_h = tl.program_id(1) - off_z = tl.program_id(2) - - offs_m_start = off_m * BLOCK_SIZE_M - - q_off = off_z * q_stride_z + off_h * q_stride_h - q_block_ptr = tl.make_block_ptr( - base=q_ptr + q_off, - shape=(SEQ_LEN, EMB_DIM), - strides=(q_stride_m, q_stride_k), - offsets=(offs_m_start, 0), - block_shape=(BLOCK_SIZE_M, EMB_DIM), - order=(1, 0), - ) - k_off = off_z * k_stride_z + off_h * k_stride_h - k_block_ptr = tl.make_block_ptr( - base=k_ptr + k_off, - shape=(EMB_DIM, SEQ_LEN), - strides=(k_stride_k, k_stride_n), - offsets=(0, 0), - block_shape=(EMB_DIM, BLOCK_SIZE_N), - order=(0, 1), - ) - v_off = off_z * v_stride_z + off_h * v_stride_h - v_block_ptr = tl.make_block_ptr( - base=v_ptr + v_off, - shape=(SEQ_LEN, EMB_DIM), - strides=(v_stride_k, v_stride_n), - offsets=(0, 0), - block_shape=(BLOCK_SIZE_N, EMB_DIM), - order=(1, 0), - ) - o_off = off_z * o_stride_z + off_h * o_stride_h - o_block_ptr = tl.make_block_ptr( - base=o_ptr + o_off, - shape=(SEQ_LEN, EMB_DIM), - strides=(o_stride_m, o_stride_n), - offsets=(offs_m_start, 0), - block_shape=(BLOCK_SIZE_M, EMB_DIM), - order=(1, 0), - ) - - q = (tl.load(q_block_ptr) * scale * 1.44269504089).to(q_block_ptr.type.element_ty) - - acc = tl.zeros((BLOCK_SIZE_M, EMB_DIM), dtype=tl.float32) - l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32) - m_i = tl.full((BLOCK_SIZE_M,), float("-inf"), dtype=tl.float32) - - for _ in range(0, tl.cdiv(SEQ_LEN, BLOCK_SIZE_N)): - k = tl.load(k_block_ptr) - - qk = tl.dot(q, k) - - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - p = tl.exp2(qk) - l_ij = tl.sum(p, 1) - - v = tl.load(v_block_ptr) - alpha = tl.exp2(m_i - m_ij) - acc = acc * alpha[:, None] + tl.dot(p.to(v_block_ptr.type.element_ty), v) - m_i = m_ij - l_i = l_i * alpha + l_ij - - v_block_ptr = tl.advance(v_block_ptr, (BLOCK_SIZE_N, 0)) - k_block_ptr = tl.advance(k_block_ptr, (0, BLOCK_SIZE_N)) - - acc /= l_i[:, None] - - tl.store(o_block_ptr, acc.to(o_ptr.type.element_ty)) - - -def triton_attention(q, k, v, scale=None): - batch_size, num_heads, seq_len, emb_dim = q.shape - - if scale is None: - scale = 1 / math.sqrt(emb_dim) - - o = torch.empty_like(q) - - def grid(meta): - return ( - triton.cdiv(seq_len, meta["BLOCK_SIZE_M"]), - num_heads, - batch_size, - ) - - triton_attention_kernel[grid]( - q, - k, - v, - o, - *q.stride(), - *k.stride(), - *v.stride(), - *o.stride(), - scale=scale, - SEQ_LEN=seq_len, - EMB_DIM=emb_dim, - ) - - return o - - -_rope_kernel = ninetoothed.make( - functools.partial(rope.arrangement, interleaved=False), - rope.application, - rope.tensors, -) - - -def _rope(x, sin_table, cos_table): - _, _, num_heads, _ = x.shape - sin_table = sin_table.unsqueeze(2).expand(-1, -1, num_heads, -1) - cos_table = cos_table.unsqueeze(2).expand(-1, -1, num_heads, -1) - - _rope_kernel(x, sin_table, cos_table) - - -if __name__ == "__main__": - torch.manual_seed(0) - - shape = (2, 4, 1024, 64) - dtype = torch.float16 - device = "cuda" - - q = torch.randn(shape, dtype=dtype, device=device) - k = torch.randn(shape, dtype=dtype, device=device) - v = torch.randn(shape, dtype=dtype, device=device) - - ninetoothed_output = attention(q, k, v) - torch_output = F.scaled_dot_product_attention(q, k, v) - triton_output = triton_attention(q, k, v) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output, atol=0.01): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.001): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["seq_len"], - x_vals=[2**i for i in range(7, 17)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="attention-performance", - args={}, - ) - ) - def benchmark(seq_len, provider): - batch_size, num_heads, emb_dim = 4, 32, 64 - shape = (batch_size, num_heads, seq_len, emb_dim) - dtype = torch.float16 - device = "cuda" - - q = torch.randn(shape, dtype=dtype, device=device) - k = torch.randn(shape, dtype=dtype, device=device) - v = torch.randn(shape, dtype=dtype, device=device) - - ninetoothed_output = attention(q, k, v) - torch_output = F.scaled_dot_product_attention(q, k, v) - triton_output = triton_attention(q, k, v) - - assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) - assert torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.001) - - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: attention(q, k, v)) - elif provider == "torch": - ms = triton.testing.do_bench( - lambda: F.scaled_dot_product_attention(q, k, v) - ) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_attention(q, k, v)) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/infer.py b/infer.py index c24a021..7f1a3c9 100644 --- a/infer.py +++ b/infer.py @@ -2,9 +2,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer -from attention import Attention from fused_rms_norm import RMSNorm from linear import Linear +from scaled_dot_product_attention import Attention from silu import SiLU from utils import replace_module diff --git a/ops/ninetoothed/kernels/scaled_dot_product_attention.py b/ops/ninetoothed/kernels/scaled_dot_product_attention.py new file mode 100644 index 0000000..24feb39 --- /dev/null +++ b/ops/ninetoothed/kernels/scaled_dot_product_attention.py @@ -0,0 +1,59 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor, block_size + +BLOCK_SIZE_M = block_size() +BLOCK_SIZE_N = block_size() + + +def arrangement( + q, k, v, scale, o, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N +): + def arrange_q_or_o(input): + arranged = input.tile((1, 1, BLOCK_SIZE_M, -1)) + arranged.dtype = arranged.dtype.squeeze((0, 1)) + + return arranged + + def arrange_k_or_v(input): + arranged = input.tile((1, 1, BLOCK_SIZE_N, -1)) + arranged = arranged.tile((1, 1, -1, -1)) + arranged = arranged.expand((-1, -1, q_arranged.shape[-2], -1)) + arranged.dtype = arranged.dtype.squeeze((0, 1, 3)) + arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) + + return arranged + + q_arranged = arrange_q_or_o(q) + + return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), scale, arrange_q_or_o(o) + + +def application(q, k, v, scale, o): + q_loaded = (q * scale * 1.44269504089).to(ntl.float16) + + acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32) + l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32) + m_i = ntl.full((q.shape[-2],), float("-inf"), dtype=ntl.float32) + + for i in range(k.shape[0]): + qk = ntl.dot(q_loaded, ntl.trans(k[i])) + + m_ij = ntl.maximum(m_i, ntl.max(qk, 1)) + p = ntl.exp2(qk - m_ij[:, None]) + l_ij = ntl.sum(p, 1) + + alpha = ntl.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i]) + m_i = m_ij + l_i = l_i * alpha + l_ij + + acc /= l_i[:, None] + o = acc # noqa: F841 + + +shape_options = (None, None, None, {"constexpr": True, "upper_bound": 128}) +q, k, v, o = (Tensor(4, shape_options=shape_options) for _ in range(4)) +tensors = (q, k, v, Tensor(0), o) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index e37891f..4ec08d5 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -1,3 +1,5 @@ +import math + import torch import ops.ninetoothed.kernels.add @@ -5,6 +7,7 @@ import ops.ninetoothed.kernels.conv2d import ops.ninetoothed.kernels.mm import ops.ninetoothed.kernels.rms_norm +import ops.ninetoothed.kernels.scaled_dot_product_attention import ops.ninetoothed.kernels.softmax @@ -60,6 +63,17 @@ def rms_norm(input, eps=None): return output +def scaled_dot_product_attention(q, k, v, scale=None): + if scale is None: + scale = 1 / math.sqrt(q.shape[-1]) + + o = torch.empty_like(q) + + ops.ninetoothed.kernels.scaled_dot_product_attention.kernel(q, k, v, scale, o) + + return o + + def softmax(input): output = torch.empty_like(input) diff --git a/ops/triton/kernels/scaled_dot_product_attention.py b/ops/triton/kernels/scaled_dot_product_attention.py new file mode 100644 index 0000000..32fa849 --- /dev/null +++ b/ops/triton/kernels/scaled_dot_product_attention.py @@ -0,0 +1,126 @@ +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=8 + ), + ], + key=["EMB_DIM"], +) +@triton.jit +def kernel( + q_ptr, + k_ptr, + v_ptr, + o_ptr, + q_stride_z, + q_stride_h, + q_stride_m, + q_stride_k, + k_stride_z, + k_stride_h, + k_stride_n, + k_stride_k, + v_stride_z, + v_stride_h, + v_stride_k, + v_stride_n, + o_stride_z, + o_stride_h, + o_stride_m, + o_stride_n, + scale, + seq_len, + EMB_DIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + off_m = tl.program_id(0) + off_h = tl.program_id(1) + off_z = tl.program_id(2) + + offs_m_start = off_m * BLOCK_SIZE_M + + q_off = off_z * q_stride_z + off_h * q_stride_h + q_block_ptr = tl.make_block_ptr( + base=q_ptr + q_off, + shape=(seq_len, EMB_DIM), + strides=(q_stride_m, q_stride_k), + offsets=(offs_m_start, 0), + block_shape=(BLOCK_SIZE_M, EMB_DIM), + order=(1, 0), + ) + k_off = off_z * k_stride_z + off_h * k_stride_h + k_block_ptr = tl.make_block_ptr( + base=k_ptr + k_off, + shape=(EMB_DIM, seq_len), + strides=(k_stride_k, k_stride_n), + offsets=(0, 0), + block_shape=(EMB_DIM, BLOCK_SIZE_N), + order=(0, 1), + ) + v_off = off_z * v_stride_z + off_h * v_stride_h + v_block_ptr = tl.make_block_ptr( + base=v_ptr + v_off, + shape=(seq_len, EMB_DIM), + strides=(v_stride_k, v_stride_n), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_N, EMB_DIM), + order=(1, 0), + ) + o_off = off_z * o_stride_z + off_h * o_stride_h + o_block_ptr = tl.make_block_ptr( + base=o_ptr + o_off, + shape=(seq_len, EMB_DIM), + strides=(o_stride_m, o_stride_n), + offsets=(offs_m_start, 0), + block_shape=(BLOCK_SIZE_M, EMB_DIM), + order=(1, 0), + ) + + q = (tl.load(q_block_ptr) * scale * 1.44269504089).to(q_block_ptr.type.element_ty) + + acc = tl.zeros((BLOCK_SIZE_M, EMB_DIM), dtype=tl.float32) + l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32) + m_i = tl.full((BLOCK_SIZE_M,), float("-inf"), dtype=tl.float32) + + for _ in range(0, tl.cdiv(seq_len, BLOCK_SIZE_N)): + k = tl.load(k_block_ptr) + + qk = tl.dot(q, k) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp2(qk) + l_ij = tl.sum(p, 1) + + v = tl.load(v_block_ptr) + alpha = tl.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + tl.dot(p.to(v_block_ptr.type.element_ty), v) + m_i = m_ij + l_i = l_i * alpha + l_ij + + v_block_ptr = tl.advance(v_block_ptr, (BLOCK_SIZE_N, 0)) + k_block_ptr = tl.advance(k_block_ptr, (0, BLOCK_SIZE_N)) + + acc /= l_i[:, None] + + tl.store(o_block_ptr, acc.to(o_ptr.type.element_ty)) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index 3c61942..6f3a8cd 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -1,3 +1,5 @@ +import math + import torch import triton @@ -6,6 +8,7 @@ import ops.triton.kernels.conv2d import ops.triton.kernels.mm import ops.triton.kernels.rms_norm +import ops.triton.kernels.scaled_dot_product_attention import ops.triton.kernels.softmax @@ -137,6 +140,38 @@ def rms_norm(input, eps=None): return output +def scaled_dot_product_attention(q, k, v, scale=None): + batch_size, num_heads, seq_len, emb_dim = q.shape + + if scale is None: + scale = 1 / math.sqrt(emb_dim) + + o = torch.empty_like(q) + + def grid(meta): + return ( + triton.cdiv(seq_len, meta["BLOCK_SIZE_M"]), + num_heads, + batch_size, + ) + + ops.triton.kernels.scaled_dot_product_attention.kernel[grid]( + q, + k, + v, + o, + *q.stride(), + *k.stride(), + *v.stride(), + *o.stride(), + scale=scale, + seq_len=seq_len, + EMB_DIM=emb_dim, + ) + + return o + + def softmax(input): output = torch.empty_like(input) diff --git a/scaled_dot_product_attention.py b/scaled_dot_product_attention.py new file mode 100644 index 0000000..3ac70dd --- /dev/null +++ b/scaled_dot_product_attention.py @@ -0,0 +1,163 @@ +import functools + +import ninetoothed +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +from transformers.models.llama.modeling_llama import repeat_kv + +import ops.ninetoothed.torch +import ops.triton.torch +import rope + + +class Attention(nn.Module): + def __init__(self, other): + super().__init__() + + self.__dict__ = other.__dict__ + + def forward( + self, + hidden_states, + position_embeddings, + attention_mask, + past_key_value, + cache_position, + **kwargs, + ): + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + cos_table, sin_table = position_embeddings + + _rope(query_states, sin_table, cos_table) + _rope(key_states, sin_table, cos_table) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin_table, + "cos": cos_table, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_dtype = torch.float16 + attn_output = ops.ninetoothed.torch.scaled_dot_product_attention( + query_states.to(attn_dtype), + key_states.to(attn_dtype), + value_states.to(attn_dtype), + scale=self.scaling, + ).to(query_states.dtype) + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +_rope_kernel = ninetoothed.make( + functools.partial(rope.arrangement, interleaved=False), + rope.application, + rope.tensors, +) + + +def _rope(x, sin_table, cos_table): + _, _, num_heads, _ = x.shape + sin_table = sin_table.unsqueeze(2).expand(-1, -1, num_heads, -1) + cos_table = cos_table.unsqueeze(2).expand(-1, -1, num_heads, -1) + + _rope_kernel(x, sin_table, cos_table) + + +if __name__ == "__main__": + torch.manual_seed(0) + + shape = (2, 4, 1024, 64) + dtype = torch.float16 + device = "cuda" + + q = torch.randn(shape, dtype=dtype, device=device) + k = torch.randn(shape, dtype=dtype, device=device) + v = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.scaled_dot_product_attention(q, k, v) + torch_output = F.scaled_dot_product_attention(q, k, v) + triton_output = ops.triton.torch.scaled_dot_product_attention(q, k, v) + + print(ninetoothed_output) + print(torch_output) + print(triton_output) + + if torch.allclose(ninetoothed_output, torch_output, atol=0.01): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.001): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["seq_len"], + x_vals=[2**i for i in range(7, 17)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="attention-performance", + args={}, + ) + ) + def benchmark(seq_len, provider): + batch_size, num_heads, emb_dim = 4, 32, 64 + shape = (batch_size, num_heads, seq_len, emb_dim) + dtype = torch.float16 + device = "cuda" + + q = torch.randn(shape, dtype=dtype, device=device) + k = torch.randn(shape, dtype=dtype, device=device) + v = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.scaled_dot_product_attention(q, k, v) + torch_output = F.scaled_dot_product_attention(q, k, v) + triton_output = ops.triton.torch.scaled_dot_product_attention(q, k, v) + + assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) + assert torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.001) + + if provider == "ninetoothed": + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.scaled_dot_product_attention(q, k, v) + ) + elif provider == "torch": + ms = triton.testing.do_bench( + lambda: F.scaled_dot_product_attention(q, k, v) + ) + elif provider == "triton": + ms = triton.testing.do_bench( + lambda: ops.triton.torch.scaled_dot_product_attention(q, k, v) + ) + + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") From 6952672e1f4857269cb8d76bf58ae783487f3738 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 14 May 2025 10:43:46 +0800 Subject: [PATCH 33/68] Use `dtype` access instead of hardcoding --- ops/ninetoothed/kernels/scaled_dot_product_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ops/ninetoothed/kernels/scaled_dot_product_attention.py b/ops/ninetoothed/kernels/scaled_dot_product_attention.py index 24feb39..3e0604b 100644 --- a/ops/ninetoothed/kernels/scaled_dot_product_attention.py +++ b/ops/ninetoothed/kernels/scaled_dot_product_attention.py @@ -30,7 +30,7 @@ def arrange_k_or_v(input): def application(q, k, v, scale, o): - q_loaded = (q * scale * 1.44269504089).to(ntl.float16) + q_loaded = (q * scale * 1.44269504089).to(q.dtype) acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32) l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32) @@ -44,7 +44,7 @@ def application(q, k, v, scale, o): l_ij = ntl.sum(p, 1) alpha = ntl.exp2(m_i - m_ij) - acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i]) + acc = acc * alpha[:, None] + ntl.dot(p.to(v.dtype.dtype), v[i]) m_i = m_ij l_i = l_i * alpha + l_ij From fc45cea2bbe6188e96a252dd6b202ca473c1b9da Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 14 May 2025 11:28:12 +0800 Subject: [PATCH 34/68] Fix the boundary issues --- .../kernels/scaled_dot_product_attention.py | 3 ++- .../kernels/scaled_dot_product_attention.py | 17 +++++++++-------- scaled_dot_product_attention.py | 4 ++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/ops/ninetoothed/kernels/scaled_dot_product_attention.py b/ops/ninetoothed/kernels/scaled_dot_product_attention.py index 3e0604b..2d3f887 100644 --- a/ops/ninetoothed/kernels/scaled_dot_product_attention.py +++ b/ops/ninetoothed/kernels/scaled_dot_product_attention.py @@ -38,6 +38,7 @@ def application(q, k, v, scale, o): for i in range(k.shape[0]): qk = ntl.dot(q_loaded, ntl.trans(k[i])) + qk = ntl.where(k[i].offsets(-2) < k.source.shape[-2], qk, float("-inf")) m_ij = ntl.maximum(m_i, ntl.max(qk, 1)) p = ntl.exp2(qk - m_ij[:, None]) @@ -49,7 +50,7 @@ def application(q, k, v, scale, o): l_i = l_i * alpha + l_ij acc /= l_i[:, None] - o = acc # noqa: F841 + o = acc.to(o.dtype) # noqa: F841 shape_options = (None, None, None, {"constexpr": True, "upper_bound": 128}) diff --git a/ops/triton/kernels/scaled_dot_product_attention.py b/ops/triton/kernels/scaled_dot_product_attention.py index 32fa849..48de10e 100644 --- a/ops/triton/kernels/scaled_dot_product_attention.py +++ b/ops/triton/kernels/scaled_dot_product_attention.py @@ -96,23 +96,24 @@ def kernel( order=(1, 0), ) - q = (tl.load(q_block_ptr) * scale * 1.44269504089).to(q_block_ptr.type.element_ty) + q = tl.load(q_block_ptr, boundary_check=(0, 1)) + q = (q * scale * 1.44269504089).to(q_block_ptr.type.element_ty) acc = tl.zeros((BLOCK_SIZE_M, EMB_DIM), dtype=tl.float32) l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32) m_i = tl.full((BLOCK_SIZE_M,), float("-inf"), dtype=tl.float32) - for _ in range(0, tl.cdiv(seq_len, BLOCK_SIZE_N)): - k = tl.load(k_block_ptr) + for i in range(0, tl.cdiv(seq_len, BLOCK_SIZE_N)): + k = tl.load(k_block_ptr, boundary_check=(0, 1)) - qk = tl.dot(q, k) + mask = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) < seq_len + qk = tl.where(mask, tl.dot(q, k), float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - p = tl.exp2(qk) + p = tl.exp2(qk - m_ij[:, None]) l_ij = tl.sum(p, 1) - v = tl.load(v_block_ptr) + v = tl.load(v_block_ptr, boundary_check=(0, 1)) alpha = tl.exp2(m_i - m_ij) acc = acc * alpha[:, None] + tl.dot(p.to(v_block_ptr.type.element_ty), v) m_i = m_ij @@ -123,4 +124,4 @@ def kernel( acc /= l_i[:, None] - tl.store(o_block_ptr, acc.to(o_ptr.type.element_ty)) + tl.store(o_block_ptr, acc.to(o_ptr.type.element_ty), boundary_check=(0, 1)) diff --git a/scaled_dot_product_attention.py b/scaled_dot_product_attention.py index 3ac70dd..616b64f 100644 --- a/scaled_dot_product_attention.py +++ b/scaled_dot_product_attention.py @@ -109,7 +109,7 @@ def _rope(x, sin_table, cos_table): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.001): + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): print("✅ NineToothed and Triton match.") else: print("❌ NineToothed and Triton differ.") @@ -143,7 +143,7 @@ def benchmark(seq_len, provider): triton_output = ops.triton.torch.scaled_dot_product_attention(q, k, v) assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) - assert torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.001) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) if provider == "ninetoothed": ms = triton.testing.do_bench( From db86285b7f197b34580401a1b24c44f3bcdc9d2e Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 14 May 2025 14:20:31 +0800 Subject: [PATCH 35/68] Improve the Triton `conv2d` implementation --- ops/triton/kernels/conv2d.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/ops/triton/kernels/conv2d.py b/ops/triton/kernels/conv2d.py index 6e28ad7..acc8460 100644 --- a/ops/triton/kernels/conv2d.py +++ b/ops/triton/kernels/conv2d.py @@ -9,7 +9,6 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, }, num_stages=3, num_warps=8, @@ -19,7 +18,6 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, }, num_stages=4, num_warps=4, @@ -29,7 +27,6 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, }, num_stages=4, num_warps=4, @@ -39,7 +36,6 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, }, num_stages=4, num_warps=4, @@ -49,7 +45,6 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, }, num_stages=4, num_warps=4, @@ -59,7 +54,6 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, }, num_stages=4, num_warps=4, @@ -69,7 +63,6 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, }, num_stages=5, num_warps=2, @@ -79,7 +72,6 @@ "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, }, num_stages=5, num_warps=2, @@ -114,24 +106,17 @@ def kernel( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, ): P: tl.constexpr = H - R + 1 Q: tl.constexpr = W - S + 1 - GEMM_M: tl.constexpr = N * P * Q GEMM_N: tl.constexpr = K GEMM_K: tl.constexpr = C * R * S pid = tl.program_id(0) - num_pid_gemm_m = tl.cdiv(GEMM_M, BLOCK_SIZE_M) num_pid_gemm_n = tl.cdiv(GEMM_N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_gemm_n - group_id = pid // num_pid_in_group - first_pid_gemm_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_gemm_m - first_pid_gemm_m, GROUP_SIZE_M) - pid_gemm_m = first_pid_gemm_m + ((pid % num_pid_in_group) % group_size_m) - pid_gemm_n = (pid % num_pid_in_group) // group_size_m + pid_gemm_m = pid // num_pid_gemm_n + pid_gemm_n = pid % num_pid_gemm_n offs_gemm_i = pid_gemm_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_gemm_j = pid_gemm_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) From 29bf2f960c4094eb286f465a1f081f6c631e84fd Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 14 May 2025 14:25:32 +0800 Subject: [PATCH 36/68] Refactor `import mm` to `import ops.ninetoothed.kernels.mm as mm` in `bmm.py` --- bmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmm.py b/bmm.py index 5c694f4..1cd832c 100644 --- a/bmm.py +++ b/bmm.py @@ -2,7 +2,7 @@ import torch from ninetoothed import Symbol, Tensor -import mm +import ops.ninetoothed.kernels.mm as mm BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) From f87582fe5ea19e559e89c89557b008957ecbab63 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 14 May 2025 23:59:07 +0800 Subject: [PATCH 37/68] Add `compare_code_metrics.py` --- code_size_comparison.py | 57 ------------ compare_code_metrics.py | 196 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+), 57 deletions(-) delete mode 100644 code_size_comparison.py create mode 100644 compare_code_metrics.py diff --git a/code_size_comparison.py b/code_size_comparison.py deleted file mode 100644 index 1e733e5..0000000 --- a/code_size_comparison.py +++ /dev/null @@ -1,57 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -plt.rcParams["figure.dpi"] = 600 -plt.rcParams["font.family"] = "JetBrains Mono" -plt.rcParams["font.weight"] = "bold" -plt.rcParams["axes.labelweight"] = "bold" - -kernels = ("add", "softmax", "rms_norm", "matmul", "conv2d", "attention") -lines_of_code = { - "Triton": (19, 25, 21, 57, 110, 98), - "NineToothed": (10, 12, 13, 34, 17, 51), -} - -x = np.arange(len(kernels)) -width = 0.25 -multiplier = 0 - -fig, ax = plt.subplots() - -for provider, lines in lines_of_code.items(): - offset = width * multiplier - rects = ax.bar(x + offset, lines, width, label=provider) - ax.bar_label(rects, fontsize=12) - multiplier += 1 - -ax.set_ylabel("Lines of Code", fontsize=12) -ax.tick_params(axis="y", labelsize=10, labelcolor="gray") -ax.set_xticks(x + width / 2, kernels, fontsize=10) -ax.xaxis.set_ticks_position("none") -ax.yaxis.set_ticks_position("none") -ax.legend(ncols=2, fontsize=10) -ax.spines[["top", "left", "right"]].set_visible(False) -ax.spines["bottom"].set_linewidth(1.5) -ax.grid(axis="y", linewidth=1.5) -ax.set_axisbelow(True) - -plt.show() -plt.savefig("code-size-comparison.png") - -print( - pd.DataFrame( - { - "Kernel": kernels + ("Overall",), - "Relative Code Size Change (%)": [ - f"{ninetoothed_lines / triton_lines * 100:.2f}%" - for ninetoothed_lines, triton_lines in zip( - lines_of_code["NineToothed"], lines_of_code["Triton"] - ) - ] - + [ - f"{sum(lines_of_code['NineToothed']) / sum(lines_of_code['Triton']) * 100:.2f}%" - ], - } - ) -) diff --git a/compare_code_metrics.py b/compare_code_metrics.py new file mode 100644 index 0000000..a63ce95 --- /dev/null +++ b/compare_code_metrics.py @@ -0,0 +1,196 @@ +import json +import os.path +from pathlib import Path + +import pandas as pd + +_PARENT_PATH = Path(__file__).parent + +_OPS_PATH = _PARENT_PATH / "ops" + +_NINETOOTHED_KERNELS_PATH = _OPS_PATH / "ninetoothed" / "kernels" + +_TRITON_KERNELS_PATH = _OPS_PATH / "triton" / "kernels" + + +def _generate_cc_table(): + path = _PARENT_PATH / "cc.json" + + metric_names = {"complexity": "$G$"} + + data = json.loads(path.read_text()) + + data = { + kernel: { + metric_names["complexity"]: sum(block["complexity"] for block in blocks) + } + for kernel, blocks in data.items() + if "torch" not in kernel + } + + df = _generate_table(data, metric_names.values()) + + styled_df = df.style.apply(_highlight_minimum, axis=None).format(precision=2) + + return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True) + + +def _generate_mi_table(): + path = _PARENT_PATH / "mi.json" + + metric_names = {"mi": "$MI$"} + + data = json.loads(path.read_text()) + + data = { + kernel: { + latex_name: metrics[raw_name] + for raw_name, latex_name in metric_names.items() + } + for kernel, metrics in data.items() + if "torch" not in kernel + } + + df = _generate_table(data, metric_names.values()) + + styled_df = df.style.apply(_highlight_maximum, axis=None).format(precision=2) + + return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True) + + +def _generate_raw_table(): + path = _PARENT_PATH / "raw.json" + + metric_names = {"loc": "LOC", "lloc": "LLOC", "sloc": "SLOC"} + + data = json.loads(path.read_text()) + + data = { + kernel: { + latex_name: metrics[raw_name] + for raw_name, latex_name in metric_names.items() + } + for kernel, metrics in data.items() + if "torch" not in kernel + } + + df = _generate_table(data, metric_names.values()) + + styled_df = df.style.apply(_highlight_minimum, axis=None).format(precision=2) + + return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True) + + +def _generate_hal_table(): + path = _PARENT_PATH / "hal.json" + + metric_names = { + "h1": "$\\eta_1$", + "h2": "$\\eta_2$", + "N1": "$N_1$", + "N2": "$N_2$", + "vocabulary": "$\\eta$", + "length": "$N$", + "calculated_length": "$\\hat{N}$", + "volume": "$V$", + "difficulty": "$D$", + "effort": "$E$", + "time": "$T$", + "bugs": "$B$", + } + + data = json.loads(path.read_text()) + + data = { + kernel: { + latex_name: metrics["total"][raw_name] + for raw_name, latex_name in metric_names.items() + } + for kernel, metrics in data.items() + if "torch" not in kernel + } + + df = _generate_table(data, metric_names.values()) + + styled_df = df.style.apply(_highlight_minimum, axis=None).format(precision=2) + + return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True) + + +def _generate_table(data, metric_names): + kernel_names = sorted( + set( + os.path.splitext(os.path.basename(kernel_name))[0] + for kernel_name in data.keys() + ) + ) + + def _key_from_kernel_name(path, kernel_name): + return str(path / f"{kernel_name}.py").removeprefix(str(_PARENT_PATH))[1:] + + data = { + f"\\texttt{{{kernel_name.replace('scaled_dot_product_attention', 'sdpa').replace('_', '\\_')}}}": { + "Triton": { + metric_name: data[ + _key_from_kernel_name(_TRITON_KERNELS_PATH, kernel_name) + ][metric_name] + for metric_name in metric_names + }, + "NineToothed": { + metric_name: data[ + _key_from_kernel_name(_NINETOOTHED_KERNELS_PATH, kernel_name) + ][metric_name] + for metric_name in metric_names + }, + } + for kernel_name in kernel_names + } + + df = pd.DataFrame.from_dict( + { + (outer_key, inner_key): value + for outer_key, inner_dict in data.items() + for inner_key, value in inner_dict.items() + }, + orient="index", + ) + + df.index = pd.MultiIndex.from_tuples(df.index) + + return df + + +def _highlight_minimum(df): + styles = pd.DataFrame("", index=df.index, columns=df.columns) + + for kernel, group in df.groupby(level=0): + mask = group == group.min() + + styles.update( + mask.replace(True, "background-color: green!20").replace(False, "") + ) + + return styles + + +def _highlight_maximum(df): + styles = pd.DataFrame("", index=df.index, columns=df.columns) + + for kernel, group in df.groupby(level=0): + mask = group == group.max() + + styles.update( + mask.replace(True, "background-color: green!20").replace(False, "") + ) + + return styles + + +if __name__ == "__main__": + for latex_code in ( + _generate_cc_table(), + _generate_mi_table(), + _generate_raw_table(), + _generate_hal_table(), + ): + print(latex_code) From 406c18ea29f20f6ada4dc02bec4a872318b89503 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Thu, 15 May 2025 15:46:03 +0800 Subject: [PATCH 38/68] Separate `bmm` kernels into modular packages --- bmm.py | 102 ++++++++++----------- ops/ninetoothed/kernels/bmm.py | 39 ++++++++ ops/ninetoothed/torch.py | 10 +++ ops/triton/kernels/bmm.py | 158 +++++++++++++++++++++++++++++++++ ops/triton/torch.py | 28 ++++++ 5 files changed, 286 insertions(+), 51 deletions(-) create mode 100644 ops/ninetoothed/kernels/bmm.py create mode 100644 ops/triton/kernels/bmm.py diff --git a/bmm.py b/bmm.py index 1cd832c..3809f20 100644 --- a/bmm.py +++ b/bmm.py @@ -1,52 +1,8 @@ -import ninetoothed import torch -from ninetoothed import Symbol, Tensor - -import ops.ninetoothed.kernels.mm as mm - -BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) -BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) -BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True) - - -def arrangement( - lhs, - rhs, - output, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, -): - output_arranged = output.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_N)) - output_arranged.dtype = output_arranged.dtype.squeeze(0) - - lhs_arranged = lhs.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_K)) - lhs_arranged = lhs_arranged.tile((1, 1, -1)) - lhs_arranged = lhs_arranged.expand((-1, -1, output_arranged.shape[-1])) - lhs_arranged.dtype = lhs_arranged.dtype.squeeze((0, 1)) - lhs_arranged.dtype.dtype = lhs_arranged.dtype.dtype.squeeze(0) - - rhs_arranged = rhs.tile((1, BLOCK_SIZE_K, BLOCK_SIZE_N)) - rhs_arranged = rhs_arranged.tile((1, -1, 1)) - rhs_arranged = rhs_arranged.expand((-1, output_arranged.shape[-2], -1)) - rhs_arranged.dtype = rhs_arranged.dtype.squeeze((0, 2)) - rhs_arranged.dtype.dtype = rhs_arranged.dtype.dtype.squeeze(0) - - return lhs_arranged, rhs_arranged, output_arranged - - -tensors = (Tensor(3), Tensor(3), Tensor(3)) -bmm_kernel = ninetoothed.make(arrangement, mm.application, tensors) - - -def bmm(lhs, rhs): - output_shape = (lhs.shape[0], lhs.shape[-2], rhs.shape[-1]) - output = torch.empty(output_shape, dtype=lhs.dtype, device=lhs.device) - - bmm_kernel(lhs, rhs, output) - - return output +import triton +import ops.ninetoothed.torch +import ops.triton.torch if __name__ == "__main__": torch.manual_seed(0) @@ -55,16 +11,60 @@ def bmm(lhs, rhs): dtype = torch.float16 device = "cuda" - lhs = torch.randn(batch_size, m, k, dtype=dtype, device=device) - rhs = torch.randn(batch_size, k, n, dtype=dtype, device=device) + input = torch.randn(batch_size, m, k, dtype=dtype, device=device) + other = torch.randn(batch_size, k, n, dtype=dtype, device=device) - ninetoothed_output = bmm(lhs, rhs) - torch_output = torch.bmm(lhs, rhs) + ninetoothed_output = ops.ninetoothed.torch.bmm(input, other) + torch_output = torch.bmm(input, other) + triton_output = ops.triton.torch.bmm(input, other) print(ninetoothed_output) print(torch_output) + print(triton_output) if torch.allclose(ninetoothed_output, torch_output): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["m", "n", "k"], + x_vals=[2**i for i in range(3, 13)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="batched-matrix-multiplication-performance", + args={"b": 4}, + ) + ) + def benchmark(b, m, n, k, provider): + input = torch.randn((b, m, k), dtype=dtype, device=device) + other = torch.randn((b, k, n), dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.bmm(input, other) + torch_output = torch.bmm(input, other) + triton_output = ops.triton.torch.bmm(input, other) + + assert torch.allclose(ninetoothed_output, torch_output) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + + if provider == "ninetoothed": + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.bmm(input, other) + ) + elif provider == "torch": + ms = triton.testing.do_bench(lambda: torch.bmm(input, other)) + elif provider == "triton": + ms = triton.testing.do_bench(lambda: ops.triton.torch.bmm(input, other)) + + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/ops/ninetoothed/kernels/bmm.py b/ops/ninetoothed/kernels/bmm.py new file mode 100644 index 0000000..6b34635 --- /dev/null +++ b/ops/ninetoothed/kernels/bmm.py @@ -0,0 +1,39 @@ +import ninetoothed +from ninetoothed import Tensor, block_size + +from ops.ninetoothed.kernels.mm import application + +BLOCK_SIZE_M = block_size() +BLOCK_SIZE_N = block_size() +BLOCK_SIZE_K = block_size() + + +def arrangement( + input, + other, + output, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, +): + output_arranged = output.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_N)) + output_arranged.dtype = output_arranged.dtype.squeeze(0) + + input_arranged = input.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_K)) + input_arranged = input_arranged.tile((1, 1, -1)) + input_arranged = input_arranged.expand((-1, -1, output_arranged.shape[-1])) + input_arranged.dtype = input_arranged.dtype.squeeze((0, 1)) + input_arranged.dtype.dtype = input_arranged.dtype.dtype.squeeze(0) + + other_arranged = other.tile((1, BLOCK_SIZE_K, BLOCK_SIZE_N)) + other_arranged = other_arranged.tile((1, -1, 1)) + other_arranged = other_arranged.expand((-1, output_arranged.shape[-2], -1)) + other_arranged.dtype = other_arranged.dtype.squeeze((0, 2)) + other_arranged.dtype.dtype = other_arranged.dtype.dtype.squeeze(0) + + return input_arranged, other_arranged, output_arranged + + +tensors = (Tensor(3), Tensor(3), Tensor(3)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index 4ec08d5..71828cf 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -4,6 +4,7 @@ import ops.ninetoothed.kernels.add import ops.ninetoothed.kernels.addmm +import ops.ninetoothed.kernels.bmm import ops.ninetoothed.kernels.conv2d import ops.ninetoothed.kernels.mm import ops.ninetoothed.kernels.rms_norm @@ -28,6 +29,15 @@ def addmm(input, mat1, mat2, beta=1, alpha=1): return output +def bmm(lhs, rhs): + output_shape = (lhs.shape[0], lhs.shape[-2], rhs.shape[-1]) + output = torch.empty(output_shape, dtype=lhs.dtype, device=lhs.device) + + ops.ninetoothed.kernels.bmm.kernel(lhs, rhs, output) + + return output + + def conv2d(input, filter): n, _, h, w = input.shape k, _, r, s = filter.shape diff --git a/ops/triton/kernels/bmm.py b/ops/triton/kernels/bmm.py new file mode 100644 index 0000000..bc18a74 --- /dev/null +++ b/ops/triton/kernels/bmm.py @@ -0,0 +1,158 @@ +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + ], + key=["m", "n", "k"], +) +@triton.jit +def kernel( + input_ptr, + other_ptr, + output_ptr, + m, + n, + k, + input_stride_b, + input_stride_m, + input_stride_k, + other_stride_b, + other_stride_k, + other_stride_n, + output_stride_b, + output_stride_m, + output_stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(m, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(n, BLOCK_SIZE_N) + num_pid_per_batch = num_pid_m * num_pid_n + bid = pid // num_pid_per_batch + pid_in_batch = pid % num_pid_per_batch + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid_in_batch // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid_in_batch % num_pid_in_group) % group_size_m) + pid_n = (pid_in_batch % num_pid_in_group) // group_size_m + + input_ptr_batch = input_ptr + bid * input_stride_b + other_ptr_batch = other_ptr + bid * other_stride_b + output_ptr_batch = output_ptr + bid * output_stride_b + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n + offs_k = tl.arange(0, BLOCK_SIZE_K) + + input_ptrs = input_ptr_batch + ( + offs_am[:, None] * input_stride_m + offs_k[None, :] * input_stride_k + ) + other_ptrs = other_ptr_batch + ( + offs_k[:, None] * other_stride_k + offs_bn[None, :] * other_stride_n + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + input = tl.load(input_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K) + other = tl.load(other_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K) + accumulator = tl.dot(input, other, accumulator) + input_ptrs += BLOCK_SIZE_K * input_stride_k + other_ptrs += BLOCK_SIZE_K * other_stride_k + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + output_ptrs = output_ptr_batch + ( + output_stride_m * offs_cm[:, None] + output_stride_n * offs_cn[None, :] + ) + + tl.store( + output_ptrs, accumulator, mask=(offs_cm[:, None] < m) & (offs_cn[None, :] < n) + ) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index 6f3a8cd..150c328 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -5,6 +5,7 @@ import ops.triton.kernels.add import ops.triton.kernels.addmm +import ops.triton.kernels.bmm import ops.triton.kernels.conv2d import ops.triton.kernels.mm import ops.triton.kernels.rms_norm @@ -60,6 +61,33 @@ def grid(meta): return output +def bmm(input, other): + batch, m, k = input.shape + _, _, n = other.shape + output = torch.empty((batch, m, n), dtype=input.dtype, device=input.device) + + def grid(meta): + return ( + batch + * triton.cdiv(m, meta["BLOCK_SIZE_M"]) + * triton.cdiv(n, meta["BLOCK_SIZE_N"]), + ) + + ops.triton.kernels.bmm.kernel[grid]( + input, + other, + output, + m, + n, + k, + *input.stride(), + *other.stride(), + *output.stride(), + ) + + return output + + def conv2d(input, filter): n, c, h, w = input.shape k, _, r, s = filter.shape From 114f12fade4121366871e24ad36fd5e8d9fa2d2c Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Thu, 15 May 2025 15:46:56 +0800 Subject: [PATCH 39/68] Separate `fused_rms_norm` kernels into modular packages --- fused_rms_norm.py | 101 ++++++++++++++++------ ops/ninetoothed/kernels/fused_rms_norm.py | 22 +++++ ops/ninetoothed/torch.py | 16 ++++ ops/triton/kernels/fused_rms_norm.py | 32 +++++++ ops/triton/torch.py | 24 +++++ 5 files changed, 168 insertions(+), 27 deletions(-) create mode 100644 ops/ninetoothed/kernels/fused_rms_norm.py create mode 100644 ops/triton/kernels/fused_rms_norm.py diff --git a/fused_rms_norm.py b/fused_rms_norm.py index e0afb48..8119674 100644 --- a/fused_rms_norm.py +++ b/fused_rms_norm.py @@ -1,41 +1,88 @@ -import ninetoothed -import ninetoothed.language as ntl import torch import torch.nn as nn -from ninetoothed import Symbol, Tensor +import torch.nn.functional as F +import triton -BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) +import ops.ninetoothed.torch +import ops.triton.torch -@ninetoothed.jit -def fused_rms_norm_kernel( - x: Tensor(2).tile((1, BLOCK_SIZE)), - w: Tensor(2).tile((1, BLOCK_SIZE)), - y: Tensor(2).tile((1, BLOCK_SIZE)), - eps: Tensor(0), -): - x_fp32 = ntl.cast(x, ntl.float32) - y = x_fp32 * ntl.rsqrt(ntl.sum(x_fp32 * x_fp32) / x.shape[-1] + eps) * w # noqa: F841 +class RMSNorm(nn.Module): + def __init__(self, other): + super().__init__() + self.__dict__ = other.__dict__ -def fused_rms_norm(x, w, eps=None): - if eps is None: - eps = torch.finfo(x.dtype).eps() + def forward(self, x): + return ops.ninetoothed.torch.fused_rms_norm( + x, self.weight, self.variance_epsilon + ) - x_2d = x.view(-1, x.shape[-1]) - w_2d = w.expand_as(x_2d) - y_2d = torch.empty_like(x_2d) - fused_rms_norm_kernel(x_2d, w_2d, y_2d, eps, BLOCK_SIZE=x.shape[-1]) +if __name__ == "__main__": + torch.manual_seed(0) - return y_2d.view(x.shape) + dtype = torch.float16 + device = "cuda" + x = torch.randn(1151, 8192, dtype=dtype, device=device) + w = torch.randn(8192, dtype=dtype, device=device) + eps = 1e-5 -class RMSNorm(nn.Module): - def __init__(self, other): - super().__init__() + ninetoothed_output = ops.ninetoothed.torch.fused_rms_norm(x, w, eps) + torch_output = F.rms_norm(x, x.shape[-1:], w, eps) + triton_output = ops.triton.torch.fused_rms_norm(x, w, eps) - self.__dict__ = other.__dict__ + print(ninetoothed_output) + print(torch_output) + print(triton_output) - def forward(self, x): - return fused_rms_norm(x, self.weight, self.variance_epsilon) + if torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005): + print("✅ NineToothed and PyTorch match.") + else: + print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.005): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["n"], + x_vals=[2**i for i in range(5, 15)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="fused-rms-norm-performance", + args={"m": 4096}, + ) + ) + def benchmark(m, n, provider): + x = torch.randn(m, n, dtype=dtype, device=device) + w = torch.randn(n, dtype=dtype, device=device) + eps = 1e-5 + + ninetoothed_output = ops.ninetoothed.torch.fused_rms_norm(x, w, eps) + torch_output = F.rms_norm(x, x.shape[-1:], w, eps) + triton_output = ops.triton.torch.fused_rms_norm(x, w, eps) + + assert torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005) + assert torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.005) + + if provider == "ninetoothed": + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.fused_rms_norm(x, w, eps) + ) + elif provider == "torch": + ms = triton.testing.do_bench(lambda: F.rms_norm(x, x.shape[-1:], w, eps)) + elif provider == "triton": + ms = triton.testing.do_bench( + lambda: ops.triton.torch.fused_rms_norm(x, w, eps) + ) + + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/ops/ninetoothed/kernels/fused_rms_norm.py b/ops/ninetoothed/kernels/fused_rms_norm.py new file mode 100644 index 0000000..5242179 --- /dev/null +++ b/ops/ninetoothed/kernels/fused_rms_norm.py @@ -0,0 +1,22 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(x, w, eps, y, BLOCK_SIZE=BLOCK_SIZE): + def arrange(tensor): + return tensor.tile((1, BLOCK_SIZE)) + + return arrange(x), arrange(w), eps, arrange(y) + + +def application(x, w, eps, y): + x_fp32 = ntl.cast(x, ntl.float32) + y = x_fp32 * ntl.rsqrt(ntl.sum(x_fp32 * x_fp32) / x.shape[-1] + eps) * w # noqa: F841 + + +tensors = (Tensor(2), Tensor(2), Tensor(0), Tensor(2)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index 71828cf..94b11e6 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -6,6 +6,7 @@ import ops.ninetoothed.kernels.addmm import ops.ninetoothed.kernels.bmm import ops.ninetoothed.kernels.conv2d +import ops.ninetoothed.kernels.fused_rms_norm import ops.ninetoothed.kernels.mm import ops.ninetoothed.kernels.rms_norm import ops.ninetoothed.kernels.scaled_dot_product_attention @@ -51,6 +52,21 @@ def conv2d(input, filter): return output +def fused_rms_norm(x, w, eps=None): + if eps is None: + eps = torch.finfo(x.dtype).eps() + + x_2d = x.view(-1, x.shape[-1]) + w_2d = w.expand_as(x_2d) + y_2d = torch.empty_like(x_2d) + + ops.ninetoothed.kernels.fused_rms_norm.kernel( + x_2d, w_2d, eps, y_2d, BLOCK_SIZE=x.shape[-1] + ) + + return y_2d.view(x.shape) + + def mm(input, other): output_shape = (input.shape[0], other.shape[1]) output = torch.empty(output_shape, dtype=input.dtype, device=input.device) diff --git a/ops/triton/kernels/fused_rms_norm.py b/ops/triton/kernels/fused_rms_norm.py new file mode 100644 index 0000000..28a9616 --- /dev/null +++ b/ops/triton/kernels/fused_rms_norm.py @@ -0,0 +1,32 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel( + x_ptr, + w_ptr, + y_ptr, + num_cols, + x_stride, + w_stride, + y_stride, + eps: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < num_cols + + x_ptrs = x_ptr + row_idx * x_stride + col_offsets + w_ptrs = w_ptr + row_idx * w_stride + col_offsets + + x = tl.load(x_ptrs, mask=mask).to(tl.float32) + w = tl.load(w_ptrs, mask=mask).to(tl.float32) + + y = x * tl.rsqrt(tl.sum(x * x) / num_cols + eps) * w + + y_ptrs = y_ptr + row_idx * y_stride + col_offsets + + tl.store(y_ptrs, y, mask=mask) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index 150c328..d8129f2 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -7,6 +7,7 @@ import ops.triton.kernels.addmm import ops.triton.kernels.bmm import ops.triton.kernels.conv2d +import ops.triton.kernels.fused_rms_norm import ops.triton.kernels.mm import ops.triton.kernels.rms_norm import ops.triton.kernels.scaled_dot_product_attention @@ -121,6 +122,29 @@ def grid(meta): return output +def fused_rms_norm(x, w, eps=None): + if eps is None: + eps = torch.finfo(x.dtype).eps + + x_2d = x.view(-1, x.shape[-1]) + w_2d = w.expand_as(x_2d) + y_2d = torch.empty_like(x_2d) + + ops.triton.kernels.fused_rms_norm.kernel[(x_2d.shape[-2],)]( + x_2d, + w_2d, + y_2d, + x_2d.shape[-1], + x_2d.stride(0), + w_2d.stride(0), + y_2d.stride(0), + eps, + BLOCK_SIZE=triton.next_power_of_2(x_2d.shape[-1]), + ) + + return y_2d.view(x.shape) + + def mm(input, other): output_shape = (input.shape[0], other.shape[1]) output = torch.empty(output_shape, dtype=input.dtype, device=input.device) From da3e0b5dfeee89ca0795cfcdf0f86132ec6b59cd Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Thu, 15 May 2025 15:48:13 +0800 Subject: [PATCH 40/68] Update the `bmm` function call in `linear.py` --- linear.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/linear.py b/linear.py index 8234064..3f42d3b 100644 --- a/linear.py +++ b/linear.py @@ -1,6 +1,6 @@ import torch.nn as nn -from bmm import bmm +import ops.ninetoothed.torch class Linear(nn.Module): @@ -10,4 +10,6 @@ def __init__(self, other): self.__dict__ = other.__dict__ def forward(self, input): - return bmm(input, self.weight.T.unsqueeze(0).expand(input.shape[0], -1, -1)) + return ops.ninetoothed.torch.bmm( + input, self.weight.T.unsqueeze(0).expand(input.shape[0], -1, -1) + ) From dbc429a72a4daa8d8da686b0400db3ee7efed4fc Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Thu, 15 May 2025 15:49:46 +0800 Subject: [PATCH 41/68] Separate `silu` kernels into modular packages --- ops/ninetoothed/kernels/silu.py | 19 ++++++++ ops/ninetoothed/torch.py | 10 ++++ ops/triton/kernels/silu.py | 16 +++++++ ops/triton/torch.py | 16 +++++++ silu.py | 85 +++++++++++++++++++-------------- 5 files changed, 109 insertions(+), 37 deletions(-) create mode 100644 ops/ninetoothed/kernels/silu.py create mode 100644 ops/triton/kernels/silu.py diff --git a/ops/ninetoothed/kernels/silu.py b/ops/ninetoothed/kernels/silu.py new file mode 100644 index 0000000..3ead189 --- /dev/null +++ b/ops/ninetoothed/kernels/silu.py @@ -0,0 +1,19 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(input, output, BLOCK_SIZE=BLOCK_SIZE): + return input.tile((BLOCK_SIZE,)), output.tile((BLOCK_SIZE,)) + + +def application(input, output): + input_loaded = input + output = input_loaded * ntl.sigmoid(ntl.cast(input_loaded, ntl.float32)) # noqa: F841 + + +tensors = (Tensor(1), Tensor(1)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index 94b11e6..49bbe5b 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -10,6 +10,7 @@ import ops.ninetoothed.kernels.mm import ops.ninetoothed.kernels.rms_norm import ops.ninetoothed.kernels.scaled_dot_product_attention +import ops.ninetoothed.kernels.silu import ops.ninetoothed.kernels.softmax @@ -100,6 +101,15 @@ def scaled_dot_product_attention(q, k, v, scale=None): return o +def silu(input): + input_flat = input.flatten() + output_flat = torch.empty_like(input_flat) + + ops.ninetoothed.kernels.silu.kernel(input_flat, output_flat, BLOCK_SIZE=1024) + + return output_flat.view_as(input) + + def softmax(input): output = torch.empty_like(input) diff --git a/ops/triton/kernels/silu.py b/ops/triton/kernels/silu.py new file mode 100644 index 0000000..b852a4c --- /dev/null +++ b/ops/triton/kernels/silu.py @@ -0,0 +1,16 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel(input_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < num_elements + + input = tl.load(input_ptr + offsets, mask=mask) + output = input * tl.sigmoid(tl.cast(input, tl.float32)) + + tl.store(output_ptr + offsets, output, mask=mask) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index d8129f2..d0cfbcb 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -11,6 +11,7 @@ import ops.triton.kernels.mm import ops.triton.kernels.rms_norm import ops.triton.kernels.scaled_dot_product_attention +import ops.triton.kernels.silu import ops.triton.kernels.softmax @@ -224,6 +225,21 @@ def grid(meta): return o +def silu(input): + input_flat = input.flatten() + output_flat = torch.empty_like(input_flat) + num_elements = input_flat.numel() + + def grid(meta): + return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),) + + ops.triton.kernels.silu.kernel[grid]( + input_flat, output_flat, num_elements, BLOCK_SIZE=1024 + ) + + return output_flat.view_as(input) + + def softmax(input): output = torch.empty_like(input) diff --git a/silu.py b/silu.py index 420cc0a..785f9c6 100644 --- a/silu.py +++ b/silu.py @@ -1,40 +1,10 @@ -import ninetoothed -import ninetoothed.language as ntl import torch import torch.nn as nn import torch.nn.functional as F -from ninetoothed import Symbol, Tensor +import triton -BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) -BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) -BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True) - - -def arrangement( - input, - output, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, -): - tile_shape = (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K) - - return input.tile(tile_shape), output.tile(tile_shape) - - -def application(input, output): - output = input * ntl.sigmoid(input) # noqa: F841 - - -silu_kernel = ninetoothed.make(arrangement, application, (Tensor(3), Tensor(3))) - - -def silu(input): - output = torch.empty_like(input) - - silu_kernel(input, output) - - return output +import ops.ninetoothed.torch +import ops.triton.torch class SiLU(nn.Module): @@ -44,25 +14,66 @@ def __init__(self, other): self.__dict__ = other.__dict__ def forward(self, input): - return silu(input) + return ops.ninetoothed.torch.silu(input) if __name__ == "__main__": torch.manual_seed(0) shape = (8, 256, 512) - dtype = torch.float32 + dtype = torch.float16 device = "cuda" input = torch.randn(shape, dtype=dtype, device=device) - ninetoothed_output = silu(input) + ninetoothed_output = ops.ninetoothed.torch.silu(input) torch_output = F.silu(input) + triton_output = ops.triton.torch.silu(input) print(ninetoothed_output) print(torch_output) + print(triton_output) - if torch.allclose(ninetoothed_output, torch_output): + if torch.allclose(ninetoothed_output, torch_output, atol=1e-3, rtol=1e-3): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): + print("✅ NineToothed and Triton match.") + else: + print("❌ NineToothed and Triton differ.") + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["m", "n", "k"], + x_vals=[2**i for i in range(3, 10)], + x_log=True, + line_arg="provider", + line_vals=["ninetoothed", "torch", "triton"], + line_names=["NineToothed", "PyTorch", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], + ylabel="ms", + plot_name="silu-performance", + args={}, + ) + ) + def benchmark(m, n, k, provider): + input = torch.randn(m, n, k, dtype=dtype, device=device) + + ninetoothed_output = ops.ninetoothed.torch.silu(input) + torch_output = F.silu(input) + triton_output = ops.triton.torch.silu(input) + + assert torch.allclose(ninetoothed_output, torch_output, atol=0.001) + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + + if provider == "ninetoothed": + ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.silu(input)) + elif provider == "torch": + ms = triton.testing.do_bench(lambda: F.silu(input)) + elif provider == "triton": + ms = triton.testing.do_bench(lambda: ops.triton.torch.silu(input)) + + return ms + + benchmark.run(show_plots=True, print_data=True, save_path=".") From ebe16814bfced4cf2676278a27837bd0c9d10359 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Thu, 15 May 2025 15:50:16 +0800 Subject: [PATCH 42/68] Separate `swiglu` kernels into modular packages --- ops/ninetoothed/kernels/swiglu.py | 20 +++++++++ ops/ninetoothed/torch.py | 12 ++++++ ops/triton/kernels/swiglu.py | 18 ++++++++ ops/triton/torch.py | 18 ++++++++ swiglu.py | 69 +++---------------------------- 5 files changed, 74 insertions(+), 63 deletions(-) create mode 100644 ops/ninetoothed/kernels/swiglu.py create mode 100644 ops/triton/kernels/swiglu.py diff --git a/ops/ninetoothed/kernels/swiglu.py b/ops/ninetoothed/kernels/swiglu.py new file mode 100644 index 0000000..63b7cd6 --- /dev/null +++ b/ops/ninetoothed/kernels/swiglu.py @@ -0,0 +1,20 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) + + +def arrangement(a, b, c, BLOCK_SIZE=BLOCK_SIZE): + return a.tile((BLOCK_SIZE,)), b.tile((BLOCK_SIZE,)), c.tile((BLOCK_SIZE,)) + + +def application(a, b, c): + b_loaded = b + gate = b_loaded * ntl.sigmoid(ntl.cast(b_loaded, ntl.float32)) + c = a * gate # noqa: F841 + + +tensors = (Tensor(1), Tensor(1), Tensor(1)) + +kernel = ninetoothed.make(arrangement, application, tensors) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index 49bbe5b..92f4db9 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -12,6 +12,7 @@ import ops.ninetoothed.kernels.scaled_dot_product_attention import ops.ninetoothed.kernels.silu import ops.ninetoothed.kernels.softmax +import ops.ninetoothed.kernels.swiglu def add(input, other): @@ -116,3 +117,14 @@ def softmax(input): ops.ninetoothed.kernels.softmax.kernel(input, output, BLOCK_SIZE=input.shape[-1]) return output + + +def swiglu(a, b): + a_flat = a.flatten() + b_flat = b.flatten() + + c = torch.empty_like(a_flat) + + ops.ninetoothed.kernels.swiglu.kernel(a_flat, b_flat, c, BLOCK_SIZE=1024) + + return c.view_as(a) diff --git a/ops/triton/kernels/swiglu.py b/ops/triton/kernels/swiglu.py new file mode 100644 index 0000000..a154af4 --- /dev/null +++ b/ops/triton/kernels/swiglu.py @@ -0,0 +1,18 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel(a_ptr, b_ptr, c_ptr, num_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < num_elements + + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + + silu_b = b * tl.sigmoid(tl.cast(b, tl.float32)) + c = a * silu_b + + tl.store(c_ptr + offsets, c, mask=mask) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index d0cfbcb..bc9efb1 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -13,6 +13,7 @@ import ops.triton.kernels.scaled_dot_product_attention import ops.triton.kernels.silu import ops.triton.kernels.softmax +import ops.triton.kernels.swiglu def add(input, other): @@ -253,3 +254,20 @@ def softmax(input): ) return output + + +def swiglu(a, b): + a_flat = a.flatten() + b_flat = b.flatten() + c_flat = torch.empty_like(a_flat) + + num_elements = a_flat.numel() + + def grid(meta): + return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),) + + ops.triton.kernels.swiglu.kernel[grid]( + a_flat, b_flat, c_flat, num_elements, BLOCK_SIZE=1024 + ) + + return c_flat.view_as(a) diff --git a/swiglu.py b/swiglu.py index 27d8504..3aa78b3 100644 --- a/swiglu.py +++ b/swiglu.py @@ -1,66 +1,9 @@ -import ninetoothed -import ninetoothed.language as ntl import torch import torch.nn.functional as F import triton -import triton.language as tl -from ninetoothed import Symbol, Tensor -BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) - - -@ninetoothed.jit -def swiglu_kernel( - a: Tensor(1).tile((BLOCK_SIZE,)), - b: Tensor(1).tile((BLOCK_SIZE,)), - c: Tensor(1).tile((BLOCK_SIZE,)), -): - b_loaded = b - gate = b_loaded * ntl.sigmoid(ntl.cast(b_loaded, ntl.float32)) - c = a * gate # noqa: F841 - - -def swiglu(a, b): - a_1d = a.flatten() - b_1d = b.flatten() - - c = torch.empty_like(a_1d) - - swiglu_kernel(a_1d, b_1d, c, BLOCK_SIZE=1024) - - return c.view_as(a) - - -@triton.jit -def triton_swiglu_kernel( - a_ptr, b_ptr, c_ptr, num_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr -): - pid = tl.program_id(0) - offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < num_elements - - a = tl.load(a_ptr + offsets, mask=mask, other=0.0) - b = tl.load(b_ptr + offsets, mask=mask, other=0.0) - - silu_b = b * tl.sigmoid(tl.cast(b, tl.float32)) - c = a * silu_b - - tl.store(c_ptr + offsets, c, mask=mask) - - -def triton_swiglu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - # Flatten the inputs so that the kernel always works on 1D tensors. - a_flat = a.flatten() - b_flat = b.flatten() - c_flat = torch.empty_like(a_flat) - num_elements = a_flat.numel() - - def grid(meta): - return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),) - - triton_swiglu_kernel[grid](a_flat, b_flat, c_flat, num_elements, BLOCK_SIZE=1024) - - return c_flat.view_as(a) +import ops.ninetoothed.torch +import ops.triton.torch def torch_swiglu( @@ -81,9 +24,9 @@ def torch_swiglu( b = torch.rand(shape, dtype=dtype, device=device) c = torch.rand(shape, dtype=dtype, device=device) - ninetoothed_output = swiglu(a, b) + ninetoothed_output = ops.ninetoothed.torch.swiglu(a, b) torch_output = torch_swiglu(a, b) - triton_output = triton_swiglu(a, b) + triton_output = ops.triton.torch.swiglu(a, b) print(ninetoothed_output) print(torch_output) @@ -119,11 +62,11 @@ def benchmark(m, n, provider): b = torch.rand(shape, dtype=dtype, device=device) if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: swiglu(a, b)) + ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.swiglu(a, b)) elif provider == "torch": ms = triton.testing.do_bench(lambda: torch_swiglu(a, b)) elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_swiglu(a, b)) + ms = triton.testing.do_bench(lambda: ops.triton.torch.swiglu(a, b)) return ms From 596455ff5e9b26f48f7d7b37da72db7524524427 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Thu, 15 May 2025 15:51:19 +0800 Subject: [PATCH 43/68] Fix the Triton implementation in `scaled_dot_product_attention.py` --- .../kernels/scaled_dot_product_attention.py | 15 ++++++++------- ops/triton/torch.py | 8 +++++--- scaled_dot_product_attention.py | 11 ++++++----- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/ops/triton/kernels/scaled_dot_product_attention.py b/ops/triton/kernels/scaled_dot_product_attention.py index 48de10e..148700d 100644 --- a/ops/triton/kernels/scaled_dot_product_attention.py +++ b/ops/triton/kernels/scaled_dot_product_attention.py @@ -48,7 +48,8 @@ def kernel( o_stride_m, o_stride_n, scale, - seq_len, + seq_len_q, + seq_len_k_v, EMB_DIM: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -62,7 +63,7 @@ def kernel( q_off = off_z * q_stride_z + off_h * q_stride_h q_block_ptr = tl.make_block_ptr( base=q_ptr + q_off, - shape=(seq_len, EMB_DIM), + shape=(seq_len_q, EMB_DIM), strides=(q_stride_m, q_stride_k), offsets=(offs_m_start, 0), block_shape=(BLOCK_SIZE_M, EMB_DIM), @@ -71,7 +72,7 @@ def kernel( k_off = off_z * k_stride_z + off_h * k_stride_h k_block_ptr = tl.make_block_ptr( base=k_ptr + k_off, - shape=(EMB_DIM, seq_len), + shape=(EMB_DIM, seq_len_k_v), strides=(k_stride_k, k_stride_n), offsets=(0, 0), block_shape=(EMB_DIM, BLOCK_SIZE_N), @@ -80,7 +81,7 @@ def kernel( v_off = off_z * v_stride_z + off_h * v_stride_h v_block_ptr = tl.make_block_ptr( base=v_ptr + v_off, - shape=(seq_len, EMB_DIM), + shape=(seq_len_k_v, EMB_DIM), strides=(v_stride_k, v_stride_n), offsets=(0, 0), block_shape=(BLOCK_SIZE_N, EMB_DIM), @@ -89,7 +90,7 @@ def kernel( o_off = off_z * o_stride_z + off_h * o_stride_h o_block_ptr = tl.make_block_ptr( base=o_ptr + o_off, - shape=(seq_len, EMB_DIM), + shape=(seq_len_q, EMB_DIM), strides=(o_stride_m, o_stride_n), offsets=(offs_m_start, 0), block_shape=(BLOCK_SIZE_M, EMB_DIM), @@ -103,10 +104,10 @@ def kernel( l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32) m_i = tl.full((BLOCK_SIZE_M,), float("-inf"), dtype=tl.float32) - for i in range(0, tl.cdiv(seq_len, BLOCK_SIZE_N)): + for i in range(0, tl.cdiv(seq_len_k_v, BLOCK_SIZE_N)): k = tl.load(k_block_ptr, boundary_check=(0, 1)) - mask = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) < seq_len + mask = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) < seq_len_k_v qk = tl.where(mask, tl.dot(q, k), float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index bc9efb1..63c721e 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -195,7 +195,8 @@ def rms_norm(input, eps=None): def scaled_dot_product_attention(q, k, v, scale=None): - batch_size, num_heads, seq_len, emb_dim = q.shape + batch_size, num_heads, seq_len_q, emb_dim = q.shape + _, _, seq_len_k_v, _ = k.shape if scale is None: scale = 1 / math.sqrt(emb_dim) @@ -204,7 +205,7 @@ def scaled_dot_product_attention(q, k, v, scale=None): def grid(meta): return ( - triton.cdiv(seq_len, meta["BLOCK_SIZE_M"]), + triton.cdiv(seq_len_q, meta["BLOCK_SIZE_M"]), num_heads, batch_size, ) @@ -219,7 +220,8 @@ def grid(meta): *v.stride(), *o.stride(), scale=scale, - seq_len=seq_len, + seq_len_q=seq_len_q, + seq_len_k_v=seq_len_k_v, EMB_DIM=emb_dim, ) diff --git a/scaled_dot_product_attention.py b/scaled_dot_product_attention.py index 616b64f..04443c8 100644 --- a/scaled_dot_product_attention.py +++ b/scaled_dot_product_attention.py @@ -89,13 +89,14 @@ def _rope(x, sin_table, cos_table): if __name__ == "__main__": torch.manual_seed(0) - shape = (2, 4, 1024, 64) + q_o_shape = (2, 8, 1024, 64) + k_v_shape = (2, 8, 1024, 64) dtype = torch.float16 device = "cuda" - q = torch.randn(shape, dtype=dtype, device=device) - k = torch.randn(shape, dtype=dtype, device=device) - v = torch.randn(shape, dtype=dtype, device=device) + q = torch.randn(q_o_shape, dtype=dtype, device=device) + k = torch.randn(k_v_shape, dtype=dtype, device=device) + v = torch.randn(k_v_shape, dtype=dtype, device=device) ninetoothed_output = ops.ninetoothed.torch.scaled_dot_product_attention(q, k, v) torch_output = F.scaled_dot_product_attention(q, k, v) @@ -109,7 +110,7 @@ def _rope(x, sin_table, cos_table): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): + if torch.allclose(ninetoothed_output, triton_output, atol=1e-3, rtol=0): print("✅ NineToothed and Triton match.") else: print("❌ NineToothed and Triton differ.") From 3585cfa646fca4a12b09b71158f09af41ca072f2 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Thu, 15 May 2025 15:53:50 +0800 Subject: [PATCH 44/68] Add inference profiling support --- infer.py | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/infer.py b/infer.py index 7f1a3c9..16ef327 100644 --- a/infer.py +++ b/infer.py @@ -1,5 +1,7 @@ import argparse +import time +import torch from transformers import AutoModelForCausalLM, AutoTokenizer from fused_rms_norm import RMSNorm @@ -38,6 +40,18 @@ default="cpu", help='Device to use for inference (e.g., "cuda", "cpu").', ) + parser.add_argument( + "--num-warmup-iterations", + type=int, + default=0, + help="For profiling. The number of warmup iterations to run before measuring performance.", + ) + parser.add_argument( + "--num-profiling-iterations", + type=int, + default=1, + help="For profiling. The number of iterations to run for performance measurement.", + ) args = parser.parse_args() @@ -45,6 +59,11 @@ prompts = args.prompts max_new_tokens = args.max_new_tokens device = args.device + num_warmup_iterations = args.num_warmup_iterations + num_profiling_iterations = args.num_profiling_iterations + + assert num_profiling_iterations >= 1 + assert num_warmup_iterations >= 0 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device) @@ -58,7 +77,25 @@ replace_module(model, SiLU) inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(device) - outputs = model.generate(**inputs, max_new_tokens=max_new_tokens) + + for _ in range(num_warmup_iterations): + model.generate(**inputs, max_new_tokens=max_new_tokens) + + if device == "cuda": + torch.cuda.synchronize() + + start_time = time.time() + + for _ in range(num_profiling_iterations): + outputs = model.generate(**inputs, max_new_tokens=max_new_tokens) + + if device == "cuda": + torch.cuda.synchronize() + + end_time = time.time() + avg_time_ms = (end_time - start_time) * 1000 / num_profiling_iterations + strings = tokenizer.batch_decode(outputs, skip_special_tokens=True) print(strings) + print(f"\nAverage inference time: {avg_time_ms:.4f} ms.") From db05592848e3b9bacce11d47ed4f539c87c08680 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 16 May 2025 00:36:52 +0800 Subject: [PATCH 45/68] Add Triton and PyTorch implementations of non-interleaved RoPE --- rope.py | 69 +++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 24 deletions(-) diff --git a/rope.py b/rope.py index 13109ab..33cb204 100644 --- a/rope.py +++ b/rope.py @@ -1,3 +1,5 @@ +import functools + import ninetoothed import torch import triton @@ -42,10 +44,22 @@ def application(tensor, sin_table, cos_table): tensors = tuple(Tensor(4, shape_options={"constexpr": True}) for _ in range(3)) -rope_kernel = ninetoothed.make(arrangement, application, tensors) + +interleaved_kernel = ninetoothed.make( + functools.partial(arrangement, interleaved=True), application, tensors +) +non_interleaved_kernel = ninetoothed.make( + functools.partial(arrangement, interleaved=False), application, tensors +) + + +def rope_kernel(tensor, sin_table, cos_table, interleaved=True): + return (interleaved_kernel if interleaved else non_interleaved_kernel)( + tensor, sin_table, cos_table + ) -def rope(tensor, sin_table, cos_table): +def rope(tensor, sin_table, cos_table, interleaved=True): batch_size, _, num_heads, _ = tensor.shape sin_table = sin_table.unsqueeze(1).unsqueeze(0) @@ -54,7 +68,7 @@ def rope(tensor, sin_table, cos_table): cos_table = cos_table.expand(batch_size, -1, num_heads, -1) tensor_cloned = tensor.clone() - rope_kernel(tensor_cloned, sin_table, cos_table) + rope_kernel(tensor_cloned, sin_table, cos_table, interleaved) return tensor_cloned @@ -71,6 +85,7 @@ def triton_rope_kernel( sin_table_stride_l, cos_table_stride_l, emb_dim, + INTERLEAVED: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): off_n = tl.program_id(0) @@ -86,18 +101,19 @@ def triton_rope_kernel( cos_table = tl.load(cos_table_ptr + off_l * cos_table_stride_l + offs, mask=mask) even_offs = ( - off_n * tensor_stride_n - + off_l * tensor_stride_l - + off_h * tensor_stride_h - + (2 * offs) * tensor_stride_e + off_n * tensor_stride_n + off_l * tensor_stride_l + off_h * tensor_stride_h ) odd_offs = ( - off_n * tensor_stride_n - + off_l * tensor_stride_l - + off_h * tensor_stride_h - + (2 * offs + 1) * tensor_stride_e + off_n * tensor_stride_n + off_l * tensor_stride_l + off_h * tensor_stride_h ) + if INTERLEAVED: + even_offs += (2 * offs) * tensor_stride_e + odd_offs += (2 * offs + 1) * tensor_stride_e + else: + even_offs += offs * tensor_stride_e + odd_offs += (offs + half_emb_dim) * tensor_stride_e + even_ptrs = tensor_ptr + even_offs odd_ptrs = tensor_ptr + odd_offs @@ -108,9 +124,7 @@ def triton_rope_kernel( tl.store(odd_ptrs, even * sin_table + odd * cos_table, mask=mask) -def triton_rope( - tensor: torch.Tensor, sin_table: torch.Tensor, cos_table: torch.Tensor -) -> torch.Tensor: +def triton_rope(tensor, sin_table, cos_table, interleaved=True): batch_size, seq_len, num_heads, emb_dim = tensor.shape assert emb_dim % 2 == 0, "The embedding dimension must be even." @@ -131,28 +145,35 @@ def triton_rope( sin_table.stride(0), cos_table.stride(0), emb_dim, + INTERLEAVED=interleaved, BLOCK_SIZE=BLOCK_SIZE, ) return tensor_cloned -def torch_rope(input, sin_table, cos_table): +def torch_rope(input, sin_table, cos_table, interleaved=True): batch_size, seq_len, num_heads, emb_dim = input.shape assert emb_dim % 2 == 0, "The embedding dimension must be even." - pair_wise_input = input.view(batch_size, seq_len, num_heads, emb_dim // 2, 2) sin_table = sin_table[None, :, None, :] cos_table = cos_table[None, :, None, :] - pair_0, pair_1 = pair_wise_input[..., 0], pair_wise_input[..., 1] - rotated_pair_0 = pair_0 * cos_table - pair_1 * sin_table - rotated_pair_1 = pair_0 * sin_table + pair_1 * cos_table + if interleaved: + pair_wise_input = input.view(batch_size, seq_len, num_heads, emb_dim // 2, 2) + input_0, input_1 = pair_wise_input[..., 0], pair_wise_input[..., 1] + input_0_rotated = input_0 * cos_table - input_1 * sin_table + input_1_rotated = input_0 * sin_table + input_1 * cos_table - output = torch.stack((rotated_pair_0, rotated_pair_1), dim=-1).view(input.shape) + return torch.stack((input_0_rotated, input_1_rotated), dim=-1).view(input.shape) + else: + input_0 = x[..., : x.shape[-1] // 2] + input_1 = x[..., x.shape[-1] // 2 :] + input_0_rotated = input_0 * cos_table - input_1 * sin_table + input_1_rotated = input_0 * sin_table + input_1 * cos_table - return output + return torch.cat((input_0_rotated, input_1_rotated), dim=-1) def _generate_sin_and_cos_tables( @@ -181,9 +202,9 @@ def _generate_sin_and_cos_tables( sin_table, cos_table = _generate_sin_and_cos_tables(seq_len, emb_dim) x = torch.randn(batch_size, seq_len, num_heads, emb_dim, dtype=dtype, device=device) - ninetoothed_output = rope(x, sin_table, cos_table) - torch_output = torch_rope(x, sin_table, cos_table) - triton_output = triton_rope(x, sin_table, cos_table) + ninetoothed_output = rope(x, sin_table, cos_table, interleaved=False) + torch_output = torch_rope(x, sin_table, cos_table, interleaved=False) + triton_output = triton_rope(x, sin_table, cos_table, interleaved=False) print(ninetoothed_output) print(torch_output) From 9db2eebb677730a988c970dc2a3f7e708172b79e Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 16 May 2025 00:43:13 +0800 Subject: [PATCH 46/68] Separate `rope` kernels into modular packages --- ops/ninetoothed/kernels/rope.py | 56 +++++++++++ ops/ninetoothed/torch.py | 13 +++ ops/triton/kernels/rope.py | 53 ++++++++++ ops/triton/torch.py | 23 +++++ rope.py | 165 +++----------------------------- 5 files changed, 157 insertions(+), 153 deletions(-) create mode 100644 ops/ninetoothed/kernels/rope.py create mode 100644 ops/triton/kernels/rope.py diff --git a/ops/ninetoothed/kernels/rope.py b/ops/ninetoothed/kernels/rope.py new file mode 100644 index 0000000..f55d9cf --- /dev/null +++ b/ops/ninetoothed/kernels/rope.py @@ -0,0 +1,56 @@ +import functools + +import ninetoothed +from ninetoothed import Tensor + + +def arrangement(input, sin_table, cos_table, interleaved=True): + emb_dim = input.shape[-1] + tile_shape = (1, 1, 1, emb_dim // 2) + + if interleaved: + strides = (-1, -1, -1, 1) + dilation = (1, 1, 1, 2) + else: + strides = None + dilation = None + + input_arranged = input.tile(tile_shape, strides=strides, dilation=dilation) + input_arranged = input_arranged.tile((1, 1, 1, 2)) + input_arranged.dtype = input_arranged.dtype.squeeze((0, 1, 2)) + input_arranged.dtype.dtype = input_arranged.dtype.dtype.squeeze((0, 1, 2)) + + sin_table_arranged = sin_table.tile(tile_shape) + sin_table_arranged.dtype = sin_table_arranged.dtype.squeeze((0, 1, 2)) + + cos_table_arranged = cos_table.tile(tile_shape) + cos_table_arranged.dtype = cos_table_arranged.dtype.squeeze((0, 1, 2)) + + return input_arranged, sin_table_arranged, cos_table_arranged + + +def application(input, sin_table, cos_table): + sin_table_loaded = sin_table + cos_table_loaded = cos_table + + input_0 = input[0] + input_1 = input[1] + + input[0] = input_0 * cos_table_loaded - input_1 * sin_table_loaded + input[1] = input_0 * sin_table_loaded + input_1 * cos_table_loaded + + +inputs = tuple(Tensor(4, shape_options={"constexpr": True}) for _ in range(3)) + +interleaved_kernel = ninetoothed.make( + functools.partial(arrangement, interleaved=True), application, inputs +) +non_interleaved_kernel = ninetoothed.make( + functools.partial(arrangement, interleaved=False), application, inputs +) + + +def kernel(input, sin_table, cos_table, interleaved=True): + return (interleaved_kernel if interleaved else non_interleaved_kernel)( + input, sin_table, cos_table + ) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index 92f4db9..7f2402d 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -9,6 +9,7 @@ import ops.ninetoothed.kernels.fused_rms_norm import ops.ninetoothed.kernels.mm import ops.ninetoothed.kernels.rms_norm +import ops.ninetoothed.kernels.rope import ops.ninetoothed.kernels.scaled_dot_product_attention import ops.ninetoothed.kernels.silu import ops.ninetoothed.kernels.softmax @@ -91,6 +92,18 @@ def rms_norm(input, eps=None): return output +def rope(input, sin_table, cos_table, interleaved=True): + batch_size, _, num_heads, _ = input.shape + + output = input.clone() + sin_table = sin_table[None, :, None, :].expand(batch_size, -1, num_heads, -1) + cos_table = cos_table[None, :, None, :].expand(batch_size, -1, num_heads, -1) + + ops.ninetoothed.kernels.rope.kernel(output, sin_table, cos_table, interleaved) + + return output + + def scaled_dot_product_attention(q, k, v, scale=None): if scale is None: scale = 1 / math.sqrt(q.shape[-1]) diff --git a/ops/triton/kernels/rope.py b/ops/triton/kernels/rope.py new file mode 100644 index 0000000..2e2aa2d --- /dev/null +++ b/ops/triton/kernels/rope.py @@ -0,0 +1,53 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel( + tensor_ptr, + sin_table_ptr, + cos_table_ptr, + tensor_stride_n, + tensor_stride_l, + tensor_stride_h, + tensor_stride_e, + sin_table_stride_l, + cos_table_stride_l, + emb_dim, + INTERLEAVED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + off_n = tl.program_id(0) + off_l = tl.program_id(1) + off_h = tl.program_id(2) + + offs = tl.arange(0, BLOCK_SIZE) + + half_emb_dim = emb_dim // 2 + mask = offs < half_emb_dim + + sin_table = tl.load(sin_table_ptr + off_l * sin_table_stride_l + offs, mask=mask) + cos_table = tl.load(cos_table_ptr + off_l * cos_table_stride_l + offs, mask=mask) + + even_offs = ( + off_n * tensor_stride_n + off_l * tensor_stride_l + off_h * tensor_stride_h + ) + odd_offs = ( + off_n * tensor_stride_n + off_l * tensor_stride_l + off_h * tensor_stride_h + ) + + if INTERLEAVED: + even_offs += (2 * offs) * tensor_stride_e + odd_offs += (2 * offs + 1) * tensor_stride_e + else: + even_offs += offs * tensor_stride_e + odd_offs += (offs + half_emb_dim) * tensor_stride_e + + even_ptrs = tensor_ptr + even_offs + odd_ptrs = tensor_ptr + odd_offs + + even = tl.load(even_ptrs, mask=mask) + odd = tl.load(odd_ptrs, mask=mask) + + tl.store(even_ptrs, even * cos_table - odd * sin_table, mask=mask) + tl.store(odd_ptrs, even * sin_table + odd * cos_table, mask=mask) diff --git a/ops/triton/torch.py b/ops/triton/torch.py index 63c721e..4eb8777 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -10,6 +10,7 @@ import ops.triton.kernels.fused_rms_norm import ops.triton.kernels.mm import ops.triton.kernels.rms_norm +import ops.triton.kernels.rope import ops.triton.kernels.scaled_dot_product_attention import ops.triton.kernels.silu import ops.triton.kernels.softmax @@ -194,6 +195,28 @@ def rms_norm(input, eps=None): return output +def rope(input, sin_table, cos_table, interleaved=True): + batch_size, seq_len, num_heads, emb_dim = input.shape + + BLOCK_SIZE = triton.next_power_of_2(emb_dim // 2) + + output = input.clone() + + ops.triton.kernels.rope.kernel[(batch_size, seq_len, num_heads)]( + output, + sin_table, + cos_table, + *input.stride(), + sin_table.stride(0), + cos_table.stride(0), + emb_dim, + INTERLEAVED=interleaved, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return output + + def scaled_dot_product_attention(q, k, v, scale=None): batch_size, num_heads, seq_len_q, emb_dim = q.shape _, _, seq_len_k_v, _ = k.shape diff --git a/rope.py b/rope.py index 33cb204..bac870a 100644 --- a/rope.py +++ b/rope.py @@ -1,155 +1,8 @@ -import functools - -import ninetoothed import torch import triton -import triton.language as tl -from ninetoothed import Tensor - - -def arrangement(tensor, sin_table, cos_table, interleaved=True): - emb_dim = tensor.shape[-1] - tile_shape = (1, 1, 1, emb_dim // 2) - - if interleaved: - strides = (-1, -1, -1, 1) - dilation = (1, 1, 1, 2) - else: - strides = None - dilation = None - - tensor_arranged = tensor.tile(tile_shape, strides=strides, dilation=dilation) - tensor_arranged = tensor_arranged.tile((1, 1, 1, 2)) - tensor_arranged.dtype = tensor_arranged.dtype.squeeze((0, 1, 2)) - tensor_arranged.dtype.dtype = tensor_arranged.dtype.dtype.squeeze((0, 1, 2)) - - sin_table_arranged = sin_table.tile(tile_shape) - sin_table_arranged.dtype = sin_table_arranged.dtype.squeeze((0, 1, 2)) - - cos_table_arranged = cos_table.tile(tile_shape) - cos_table_arranged.dtype = cos_table_arranged.dtype.squeeze((0, 1, 2)) - - return tensor_arranged, sin_table_arranged, cos_table_arranged - - -def application(tensor, sin_table, cos_table): - sin_table_loaded = sin_table - cos_table_loaded = cos_table - - tensor_0 = tensor[0] - tensor_1 = tensor[1] - tensor[0] = tensor_0 * cos_table_loaded - tensor_1 * sin_table_loaded - tensor[1] = tensor_0 * sin_table_loaded + tensor_1 * cos_table_loaded - - -tensors = tuple(Tensor(4, shape_options={"constexpr": True}) for _ in range(3)) - -interleaved_kernel = ninetoothed.make( - functools.partial(arrangement, interleaved=True), application, tensors -) -non_interleaved_kernel = ninetoothed.make( - functools.partial(arrangement, interleaved=False), application, tensors -) - - -def rope_kernel(tensor, sin_table, cos_table, interleaved=True): - return (interleaved_kernel if interleaved else non_interleaved_kernel)( - tensor, sin_table, cos_table - ) - - -def rope(tensor, sin_table, cos_table, interleaved=True): - batch_size, _, num_heads, _ = tensor.shape - - sin_table = sin_table.unsqueeze(1).unsqueeze(0) - sin_table = sin_table.expand(batch_size, -1, num_heads, -1) - cos_table = cos_table.unsqueeze(1).unsqueeze(0) - cos_table = cos_table.expand(batch_size, -1, num_heads, -1) - - tensor_cloned = tensor.clone() - rope_kernel(tensor_cloned, sin_table, cos_table, interleaved) - - return tensor_cloned - - -@triton.jit -def triton_rope_kernel( - tensor_ptr, - sin_table_ptr, - cos_table_ptr, - tensor_stride_n, - tensor_stride_l, - tensor_stride_h, - tensor_stride_e, - sin_table_stride_l, - cos_table_stride_l, - emb_dim, - INTERLEAVED: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - off_n = tl.program_id(0) - off_l = tl.program_id(1) - off_h = tl.program_id(2) - - offs = tl.arange(0, BLOCK_SIZE) - - half_emb_dim = emb_dim // 2 - mask = offs < half_emb_dim - - sin_table = tl.load(sin_table_ptr + off_l * sin_table_stride_l + offs, mask=mask) - cos_table = tl.load(cos_table_ptr + off_l * cos_table_stride_l + offs, mask=mask) - - even_offs = ( - off_n * tensor_stride_n + off_l * tensor_stride_l + off_h * tensor_stride_h - ) - odd_offs = ( - off_n * tensor_stride_n + off_l * tensor_stride_l + off_h * tensor_stride_h - ) - - if INTERLEAVED: - even_offs += (2 * offs) * tensor_stride_e - odd_offs += (2 * offs + 1) * tensor_stride_e - else: - even_offs += offs * tensor_stride_e - odd_offs += (offs + half_emb_dim) * tensor_stride_e - - even_ptrs = tensor_ptr + even_offs - odd_ptrs = tensor_ptr + odd_offs - - even = tl.load(even_ptrs, mask=mask) - odd = tl.load(odd_ptrs, mask=mask) - - tl.store(even_ptrs, even * cos_table - odd * sin_table, mask=mask) - tl.store(odd_ptrs, even * sin_table + odd * cos_table, mask=mask) - - -def triton_rope(tensor, sin_table, cos_table, interleaved=True): - batch_size, seq_len, num_heads, emb_dim = tensor.shape - - assert emb_dim % 2 == 0, "The embedding dimension must be even." - - BLOCK_SIZE = triton.next_power_of_2(emb_dim // 2) - if BLOCK_SIZE > 1024: - BLOCK_SIZE = 1024 - - grid = (batch_size, seq_len, num_heads) - - tensor_cloned = tensor.clone() - - triton_rope_kernel[grid]( - tensor_cloned, - sin_table, - cos_table, - *tensor.stride(), - sin_table.stride(0), - cos_table.stride(0), - emb_dim, - INTERLEAVED=interleaved, - BLOCK_SIZE=BLOCK_SIZE, - ) - - return tensor_cloned +import ops.ninetoothed.torch +import ops.triton.torch def torch_rope(input, sin_table, cos_table, interleaved=True): @@ -202,9 +55,11 @@ def _generate_sin_and_cos_tables( sin_table, cos_table = _generate_sin_and_cos_tables(seq_len, emb_dim) x = torch.randn(batch_size, seq_len, num_heads, emb_dim, dtype=dtype, device=device) - ninetoothed_output = rope(x, sin_table, cos_table, interleaved=False) + ninetoothed_output = ops.ninetoothed.torch.rope( + x, sin_table, cos_table, interleaved=False + ) torch_output = torch_rope(x, sin_table, cos_table, interleaved=False) - triton_output = triton_rope(x, sin_table, cos_table, interleaved=False) + triton_output = ops.triton.torch.rope(x, sin_table, cos_table, interleaved=False) print(ninetoothed_output) print(torch_output) @@ -242,11 +97,15 @@ def benchmark(seq_len, provider): x = torch.randn(shape, dtype=dtype, device=device) if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: rope(x, sin_table, cos_table)) + ms = triton.testing.do_bench( + lambda: ops.ninetoothed.torch.rope(x, sin_table, cos_table) + ) elif provider == "torch": ms = triton.testing.do_bench(lambda: torch_rope(x, sin_table, cos_table)) elif provider == "triton": - ms = triton.testing.do_bench(lambda: triton_rope(x, sin_table, cos_table)) + ms = triton.testing.do_bench( + lambda: ops.triton.torch.rope(x, sin_table, cos_table) + ) return ms From 8086ccc57e9317915b443ab056b073612132c571 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 16 May 2025 01:55:39 +0800 Subject: [PATCH 47/68] Add context managers to select the backend to use for inference --- fused_rms_norm.py | 32 +++++++++++++-- infer.py | 57 +++++++++++++++++--------- linear.py | 28 ++++++++++++- scaled_dot_product_attention.py | 72 +++++++++++++++++++++++---------- silu.py | 27 ++++++++++++- 5 files changed, 170 insertions(+), 46 deletions(-) diff --git a/fused_rms_norm.py b/fused_rms_norm.py index 8119674..a32ce9b 100644 --- a/fused_rms_norm.py +++ b/fused_rms_norm.py @@ -1,3 +1,5 @@ +from contextlib import contextmanager + import torch import torch.nn as nn import torch.nn.functional as F @@ -8,15 +10,39 @@ class RMSNorm(nn.Module): + fused_rms_norm = None + def __init__(self, other): super().__init__() self.__dict__ = other.__dict__ def forward(self, x): - return ops.ninetoothed.torch.fused_rms_norm( - x, self.weight, self.variance_epsilon - ) + return type(self).fused_rms_norm(x, self.weight, self.variance_epsilon) + + +@contextmanager +def rms_norm_backend(backend_name): + def _torch_fused_rms_norm(x, w, eps): + return F.rms_norm(x, x.shape[-1:], w, eps) + + _prev_impl = RMSNorm.fused_rms_norm + + if backend_name == "ninetoothed": + impl = ops.ninetoothed.torch.fused_rms_norm + elif backend_name == "triton": + impl = ops.triton.torch.fused_rms_norm + elif backend_name == "torch": + impl = _torch_fused_rms_norm + else: + raise ValueError(f"unknown backend: `{backend_name}`") + + RMSNorm.fused_rms_norm = impl + + try: + yield + finally: + RMSNorm.fused_rms_norm = _prev_impl if __name__ == "__main__": diff --git a/infer.py b/infer.py index 16ef327..99e862d 100644 --- a/infer.py +++ b/infer.py @@ -4,10 +4,14 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from fused_rms_norm import RMSNorm -from linear import Linear -from scaled_dot_product_attention import Attention -from silu import SiLU +from fused_rms_norm import RMSNorm, rms_norm_backend +from linear import Linear, bmm_backend +from scaled_dot_product_attention import ( + Attention, + rope_backend, + scaled_dot_product_attention_backend, +) +from silu import SiLU, silu_backend from utils import replace_module if __name__ == "__main__": @@ -40,6 +44,12 @@ default="cpu", help='Device to use for inference (e.g., "cuda", "cpu").', ) + parser.add_argument( + "--backend", + type=str, + default="ninetoothed", + help='Backend to use for inference (e.g., "ninetoothed", "triton", "torch").', + ) parser.add_argument( "--num-warmup-iterations", type=int, @@ -59,6 +69,7 @@ prompts = args.prompts max_new_tokens = args.max_new_tokens device = args.device + backend = args.backend num_warmup_iterations = args.num_warmup_iterations num_profiling_iterations = args.num_profiling_iterations @@ -71,31 +82,39 @@ tokenizer.pad_token = tokenizer.eos_token model.generation_config.pad_token_id = tokenizer.pad_token_id - replace_module(model, Attention) - replace_module(model, Linear) - replace_module(model, RMSNorm) - replace_module(model, SiLU) + if backend != "torch": + replace_module(model, Attention) + replace_module(model, Linear) + replace_module(model, RMSNorm) + replace_module(model, SiLU) inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(device) - for _ in range(num_warmup_iterations): - model.generate(**inputs, max_new_tokens=max_new_tokens) + with ( + bmm_backend(backend), + rms_norm_backend(backend), + rope_backend(backend), + scaled_dot_product_attention_backend(backend), + silu_backend(backend), + ): + for _ in range(num_warmup_iterations): + model.generate(**inputs, max_new_tokens=max_new_tokens) - if device == "cuda": - torch.cuda.synchronize() + if device == "cuda": + torch.cuda.synchronize() - start_time = time.time() + start_time = time.time() - for _ in range(num_profiling_iterations): - outputs = model.generate(**inputs, max_new_tokens=max_new_tokens) + for _ in range(num_profiling_iterations): + outputs = model.generate(**inputs, max_new_tokens=max_new_tokens) - if device == "cuda": - torch.cuda.synchronize() + if device == "cuda": + torch.cuda.synchronize() - end_time = time.time() - avg_time_ms = (end_time - start_time) * 1000 / num_profiling_iterations + end_time = time.time() strings = tokenizer.batch_decode(outputs, skip_special_tokens=True) + avg_time_ms = (end_time - start_time) * 1000 / num_profiling_iterations print(strings) print(f"\nAverage inference time: {avg_time_ms:.4f} ms.") diff --git a/linear.py b/linear.py index 3f42d3b..50bf179 100644 --- a/linear.py +++ b/linear.py @@ -1,15 +1,41 @@ +from contextlib import contextmanager + +import torch import torch.nn as nn import ops.ninetoothed.torch class Linear(nn.Module): + bmm = None + def __init__(self, other): super().__init__() self.__dict__ = other.__dict__ def forward(self, input): - return ops.ninetoothed.torch.bmm( + return type(self).bmm( input, self.weight.T.unsqueeze(0).expand(input.shape[0], -1, -1) ) + + +@contextmanager +def bmm_backend(backend_name): + _prev_impl = Linear.bmm + + if backend_name == "ninetoothed": + impl = ops.ninetoothed.torch.bmm + elif backend_name == "triton": + impl = ops.triton.torch.bmm + elif backend_name == "torch": + impl = torch.bmm + else: + raise ValueError(f"unknown backend: `{backend_name}`") + + Linear.bmm = impl + + try: + yield + finally: + Linear.bmm = _prev_impl diff --git a/scaled_dot_product_attention.py b/scaled_dot_product_attention.py index 04443c8..a137b13 100644 --- a/scaled_dot_product_attention.py +++ b/scaled_dot_product_attention.py @@ -1,6 +1,5 @@ -import functools +from contextlib import contextmanager -import ninetoothed import torch import torch.nn as nn import torch.nn.functional as F @@ -9,10 +8,14 @@ import ops.ninetoothed.torch import ops.triton.torch -import rope +from rope import torch_rope class Attention(nn.Module): + scaled_dot_product_attention = None + + rope = None + def __init__(self, other): super().__init__() @@ -35,9 +38,11 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape) cos_table, sin_table = position_embeddings + sin_table = sin_table[0] + cos_table = cos_table[0] - _rope(query_states, sin_table, cos_table) - _rope(key_states, sin_table, cos_table) + query_states = type(self).rope(query_states, sin_table, cos_table) + key_states = type(self).rope(key_states, sin_table, cos_table) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -56,13 +61,9 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_dtype = torch.float16 - attn_output = ops.ninetoothed.torch.scaled_dot_product_attention( - query_states.to(attn_dtype), - key_states.to(attn_dtype), - value_states.to(attn_dtype), - scale=self.scaling, - ).to(query_states.dtype) + attn_output = type(self).scaled_dot_product_attention( + query_states, key_states, value_states, scale=self.scaling + ) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -71,19 +72,46 @@ def forward( return attn_output, None -_rope_kernel = ninetoothed.make( - functools.partial(rope.arrangement, interleaved=False), - rope.application, - rope.tensors, -) +@contextmanager +def scaled_dot_product_attention_backend(backend_name): + _prev_impl = Attention.scaled_dot_product_attention + + if backend_name == "ninetoothed": + impl = ops.ninetoothed.torch.scaled_dot_product_attention + elif backend_name == "triton": + impl = ops.triton.torch.scaled_dot_product_attention + elif backend_name == "torch": + impl = F.scaled_dot_product_attention + else: + raise ValueError(f"unknown backend: `{backend_name}`") + + Attention.scaled_dot_product_attention = impl + + try: + yield + finally: + Attention.scaled_dot_product_attention = _prev_impl + +@contextmanager +def rope_backend(backend_name): + _prev_impl = Attention.rope + + if backend_name == "ninetoothed": + impl = ops.ninetoothed.torch.rope + elif backend_name == "triton": + impl = ops.triton.torch.rope + elif backend_name == "torch": + impl = torch_rope + else: + raise ValueError(f"unknown backend: `{backend_name}`") -def _rope(x, sin_table, cos_table): - _, _, num_heads, _ = x.shape - sin_table = sin_table.unsqueeze(2).expand(-1, -1, num_heads, -1) - cos_table = cos_table.unsqueeze(2).expand(-1, -1, num_heads, -1) + Attention.rope = impl - _rope_kernel(x, sin_table, cos_table) + try: + yield + finally: + Attention.rope = _prev_impl if __name__ == "__main__": diff --git a/silu.py b/silu.py index 785f9c6..fafae5d 100644 --- a/silu.py +++ b/silu.py @@ -1,3 +1,5 @@ +from contextlib import contextmanager + import torch import torch.nn as nn import torch.nn.functional as F @@ -8,13 +10,36 @@ class SiLU(nn.Module): + silu = None + def __init__(self, other): super().__init__() self.__dict__ = other.__dict__ def forward(self, input): - return ops.ninetoothed.torch.silu(input) + return type(self).silu(input) + + +@contextmanager +def silu_backend(backend_name): + _prev_impl = SiLU.silu + + if backend_name == "ninetoothed": + impl = ops.ninetoothed.torch.silu + elif backend_name == "triton": + impl = ops.triton.torch.silu + elif backend_name == "torch": + impl = F.silu + else: + raise ValueError(f"unknown backend: `{backend_name}`") + + SiLU.silu = impl + + try: + yield + finally: + SiLU.silu = _prev_impl if __name__ == "__main__": From d711d57c71e52edd2a15742bf28d15037c376499 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 16 May 2025 11:12:44 +0800 Subject: [PATCH 48/68] Relax tolerance in `scaled_dot_product_attention.py` --- scaled_dot_product_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scaled_dot_product_attention.py b/scaled_dot_product_attention.py index a137b13..ece4f1e 100644 --- a/scaled_dot_product_attention.py +++ b/scaled_dot_product_attention.py @@ -172,7 +172,7 @@ def benchmark(seq_len, provider): triton_output = ops.triton.torch.scaled_dot_product_attention(q, k, v) assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) - assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) + assert torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.001) if provider == "ninetoothed": ms = triton.testing.do_bench( From 379ca83b07b6879d2fc42f3fb818ba8fff5445dd Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 16 May 2025 11:18:00 +0800 Subject: [PATCH 49/68] Improve inference logging with JSON output and token-level performance metrics --- infer.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/infer.py b/infer.py index 99e862d..96e1b72 100644 --- a/infer.py +++ b/infer.py @@ -1,4 +1,5 @@ import argparse +import json import time import torch @@ -114,7 +115,20 @@ end_time = time.time() strings = tokenizer.batch_decode(outputs, skip_special_tokens=True) - avg_time_ms = (end_time - start_time) * 1000 / num_profiling_iterations - print(strings) - print(f"\nAverage inference time: {avg_time_ms:.4f} ms.") + average_time = (end_time - start_time) / num_profiling_iterations + num_input_tokens = inputs["input_ids"].size(-1) + num_output_tokens = outputs.size(-1) - num_input_tokens + num_tokens_per_second = num_output_tokens / average_time + + print( + json.dumps( + { + "strings": strings, + "average_time": average_time, + "num_input_tokens": num_input_tokens, + "num_output_tokens": num_output_tokens, + "num_tokens_per_second": num_tokens_per_second, + } + ) + ) From 14b858221b6dcc9aa106bd872e436f98e3c7edcc Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 16 May 2025 11:38:13 +0800 Subject: [PATCH 50/68] Extract the backslash character into a constant --- compare_code_metrics.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/compare_code_metrics.py b/compare_code_metrics.py index a63ce95..0f6c951 100644 --- a/compare_code_metrics.py +++ b/compare_code_metrics.py @@ -12,6 +12,8 @@ _TRITON_KERNELS_PATH = _OPS_PATH / "triton" / "kernels" +_BACKSLASH_CHAR = "\\" + def _generate_cc_table(): path = _PARENT_PATH / "cc.json" @@ -129,7 +131,7 @@ def _key_from_kernel_name(path, kernel_name): return str(path / f"{kernel_name}.py").removeprefix(str(_PARENT_PATH))[1:] data = { - f"\\texttt{{{kernel_name.replace('scaled_dot_product_attention', 'sdpa').replace('_', '\\_')}}}": { + f"{_BACKSLASH_CHAR}texttt{{{kernel_name.replace('scaled_dot_product_attention', 'sdpa').replace('_', f'{_BACKSLASH_CHAR}_')}}}": { "Triton": { metric_name: data[ _key_from_kernel_name(_TRITON_KERNELS_PATH, kernel_name) From a5319d2c2394a352a6e7fe82cb3fe2adb6749123 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 19 May 2025 11:34:51 +0800 Subject: [PATCH 51/68] Update `.gitignore` to exclude evaluation result files --- .gitignore | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.gitignore b/.gitignore index 82f9275..fbfdfdd 100644 --- a/.gitignore +++ b/.gitignore @@ -160,3 +160,10 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# Evaluation results +*.csv +*.html +*.json +*.png +*.tex From 72d552c4c331f98026bca1dd8f2790f0fa26fb8a Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 19 May 2025 11:35:56 +0800 Subject: [PATCH 52/68] Remove upper bound constraint in `conv2d.py` --- ops/ninetoothed/kernels/conv2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ops/ninetoothed/kernels/conv2d.py b/ops/ninetoothed/kernels/conv2d.py index 514bba9..5f2a49e 100644 --- a/ops/ninetoothed/kernels/conv2d.py +++ b/ops/ninetoothed/kernels/conv2d.py @@ -19,7 +19,7 @@ def arrangement(input, filter, output): return mm.arrangement(input_arranged, filter_arranged, output_arranged) -shape_options = {"constexpr": True, "upper_bound": 16} +shape_options = {"constexpr": True} tensors = tuple(Tensor(4, shape_options=shape_options) for _ in range(3)) kernel = ninetoothed.make(arrangement, mm.application, tensors) From c8ae164cf33dc124663c1f734b237de65f904b1d Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 19 May 2025 11:42:27 +0800 Subject: [PATCH 53/68] Rename `rope` to `rotary_position_embedding` --- compare_code_metrics.py | 2 +- infer.py | 4 +-- .../{rope.py => rotary_position_embedding.py} | 0 ops/ninetoothed/torch.py | 8 +++--- .../{rope.py => rotary_position_embedding.py} | 0 ops/triton/torch.py | 8 +++--- rope.py => rotary_position_embedding.py | 26 +++++++++++++------ scaled_dot_product_attention.py | 26 +++++++++++-------- 8 files changed, 46 insertions(+), 28 deletions(-) rename ops/ninetoothed/kernels/{rope.py => rotary_position_embedding.py} (100%) rename ops/triton/kernels/{rope.py => rotary_position_embedding.py} (100%) rename rope.py => rotary_position_embedding.py (80%) diff --git a/compare_code_metrics.py b/compare_code_metrics.py index 0f6c951..cdd92d5 100644 --- a/compare_code_metrics.py +++ b/compare_code_metrics.py @@ -131,7 +131,7 @@ def _key_from_kernel_name(path, kernel_name): return str(path / f"{kernel_name}.py").removeprefix(str(_PARENT_PATH))[1:] data = { - f"{_BACKSLASH_CHAR}texttt{{{kernel_name.replace('scaled_dot_product_attention', 'sdpa').replace('_', f'{_BACKSLASH_CHAR}_')}}}": { + f"{_BACKSLASH_CHAR}texttt{{{kernel_name.replace('scaled_dot_product_attention', 'sdpa').replace('rotary_position_embedding', 'rope').replace('_', f'{_BACKSLASH_CHAR}_')}}}": { "Triton": { metric_name: data[ _key_from_kernel_name(_TRITON_KERNELS_PATH, kernel_name) diff --git a/infer.py b/infer.py index 96e1b72..c5b2ab3 100644 --- a/infer.py +++ b/infer.py @@ -9,7 +9,7 @@ from linear import Linear, bmm_backend from scaled_dot_product_attention import ( Attention, - rope_backend, + rotary_position_embedding_backend, scaled_dot_product_attention_backend, ) from silu import SiLU, silu_backend @@ -94,7 +94,7 @@ with ( bmm_backend(backend), rms_norm_backend(backend), - rope_backend(backend), + rotary_position_embedding_backend(backend), scaled_dot_product_attention_backend(backend), silu_backend(backend), ): diff --git a/ops/ninetoothed/kernels/rope.py b/ops/ninetoothed/kernels/rotary_position_embedding.py similarity index 100% rename from ops/ninetoothed/kernels/rope.py rename to ops/ninetoothed/kernels/rotary_position_embedding.py diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index 7f2402d..fe0824d 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -9,7 +9,7 @@ import ops.ninetoothed.kernels.fused_rms_norm import ops.ninetoothed.kernels.mm import ops.ninetoothed.kernels.rms_norm -import ops.ninetoothed.kernels.rope +import ops.ninetoothed.kernels.rotary_position_embedding import ops.ninetoothed.kernels.scaled_dot_product_attention import ops.ninetoothed.kernels.silu import ops.ninetoothed.kernels.softmax @@ -92,14 +92,16 @@ def rms_norm(input, eps=None): return output -def rope(input, sin_table, cos_table, interleaved=True): +def rotary_position_embedding(input, sin_table, cos_table, interleaved=True): batch_size, _, num_heads, _ = input.shape output = input.clone() sin_table = sin_table[None, :, None, :].expand(batch_size, -1, num_heads, -1) cos_table = cos_table[None, :, None, :].expand(batch_size, -1, num_heads, -1) - ops.ninetoothed.kernels.rope.kernel(output, sin_table, cos_table, interleaved) + ops.ninetoothed.kernels.rotary_position_embedding.kernel( + output, sin_table, cos_table, interleaved + ) return output diff --git a/ops/triton/kernels/rope.py b/ops/triton/kernels/rotary_position_embedding.py similarity index 100% rename from ops/triton/kernels/rope.py rename to ops/triton/kernels/rotary_position_embedding.py diff --git a/ops/triton/torch.py b/ops/triton/torch.py index 4eb8777..6d40df3 100644 --- a/ops/triton/torch.py +++ b/ops/triton/torch.py @@ -10,7 +10,7 @@ import ops.triton.kernels.fused_rms_norm import ops.triton.kernels.mm import ops.triton.kernels.rms_norm -import ops.triton.kernels.rope +import ops.triton.kernels.rotary_position_embedding import ops.triton.kernels.scaled_dot_product_attention import ops.triton.kernels.silu import ops.triton.kernels.softmax @@ -195,14 +195,16 @@ def rms_norm(input, eps=None): return output -def rope(input, sin_table, cos_table, interleaved=True): +def rotary_position_embedding(input, sin_table, cos_table, interleaved=True): batch_size, seq_len, num_heads, emb_dim = input.shape BLOCK_SIZE = triton.next_power_of_2(emb_dim // 2) output = input.clone() - ops.triton.kernels.rope.kernel[(batch_size, seq_len, num_heads)]( + ops.triton.kernels.rotary_position_embedding.kernel[ + (batch_size, seq_len, num_heads) + ]( output, sin_table, cos_table, diff --git a/rope.py b/rotary_position_embedding.py similarity index 80% rename from rope.py rename to rotary_position_embedding.py index bac870a..0eb4e8b 100644 --- a/rope.py +++ b/rotary_position_embedding.py @@ -5,7 +5,7 @@ import ops.triton.torch -def torch_rope(input, sin_table, cos_table, interleaved=True): +def torch_rotary_position_embedding(input, sin_table, cos_table, interleaved=True): batch_size, seq_len, num_heads, emb_dim = input.shape assert emb_dim % 2 == 0, "The embedding dimension must be even." @@ -55,11 +55,15 @@ def _generate_sin_and_cos_tables( sin_table, cos_table = _generate_sin_and_cos_tables(seq_len, emb_dim) x = torch.randn(batch_size, seq_len, num_heads, emb_dim, dtype=dtype, device=device) - ninetoothed_output = ops.ninetoothed.torch.rope( + ninetoothed_output = ops.ninetoothed.torch.rotary_position_embedding( + x, sin_table, cos_table, interleaved=False + ) + torch_output = torch_rotary_position_embedding( + x, sin_table, cos_table, interleaved=False + ) + triton_output = ops.triton.torch.rotary_position_embedding( x, sin_table, cos_table, interleaved=False ) - torch_output = torch_rope(x, sin_table, cos_table, interleaved=False) - triton_output = ops.triton.torch.rope(x, sin_table, cos_table, interleaved=False) print(ninetoothed_output) print(torch_output) @@ -83,7 +87,7 @@ def _generate_sin_and_cos_tables( line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], ylabel="ms", - plot_name="rope-performance", + plot_name="rotary_position_embedding-performance", args={}, ) ) @@ -98,13 +102,19 @@ def benchmark(seq_len, provider): if provider == "ninetoothed": ms = triton.testing.do_bench( - lambda: ops.ninetoothed.torch.rope(x, sin_table, cos_table) + lambda: ops.ninetoothed.torch.rotary_position_embedding( + x, sin_table, cos_table + ) ) elif provider == "torch": - ms = triton.testing.do_bench(lambda: torch_rope(x, sin_table, cos_table)) + ms = triton.testing.do_bench( + lambda: torch_rotary_position_embedding(x, sin_table, cos_table) + ) elif provider == "triton": ms = triton.testing.do_bench( - lambda: ops.triton.torch.rope(x, sin_table, cos_table) + lambda: ops.triton.torch.rotary_position_embedding( + x, sin_table, cos_table + ) ) return ms diff --git a/scaled_dot_product_attention.py b/scaled_dot_product_attention.py index ece4f1e..de4c848 100644 --- a/scaled_dot_product_attention.py +++ b/scaled_dot_product_attention.py @@ -8,13 +8,13 @@ import ops.ninetoothed.torch import ops.triton.torch -from rope import torch_rope +from rotary_position_embedding import torch_rotary_position_embedding class Attention(nn.Module): scaled_dot_product_attention = None - rope = None + rotary_position_embedding = None def __init__(self, other): super().__init__() @@ -41,8 +41,12 @@ def forward( sin_table = sin_table[0] cos_table = cos_table[0] - query_states = type(self).rope(query_states, sin_table, cos_table) - key_states = type(self).rope(key_states, sin_table, cos_table) + query_states = type(self).rotary_position_embedding( + query_states, sin_table, cos_table + ) + key_states = type(self).rotary_position_embedding( + key_states, sin_table, cos_table + ) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -94,24 +98,24 @@ def scaled_dot_product_attention_backend(backend_name): @contextmanager -def rope_backend(backend_name): - _prev_impl = Attention.rope +def rotary_position_embedding_backend(backend_name): + _prev_impl = Attention.rotary_position_embedding if backend_name == "ninetoothed": - impl = ops.ninetoothed.torch.rope + impl = ops.ninetoothed.torch.rotary_position_embedding elif backend_name == "triton": - impl = ops.triton.torch.rope + impl = ops.triton.torch.rotary_position_embedding elif backend_name == "torch": - impl = torch_rope + impl = torch_rotary_position_embedding else: raise ValueError(f"unknown backend: `{backend_name}`") - Attention.rope = impl + Attention.rotary_position_embedding = impl try: yield finally: - Attention.rope = _prev_impl + Attention.rotary_position_embedding = _prev_impl if __name__ == "__main__": From 4ccc2bb1849d852d0738da23c5d22056202694f2 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 19 May 2025 14:22:59 +0800 Subject: [PATCH 54/68] Generate a single table instead of multiple tables --- compare_code_metrics.py | 67 +++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 40 deletions(-) diff --git a/compare_code_metrics.py b/compare_code_metrics.py index cdd92d5..a199a55 100644 --- a/compare_code_metrics.py +++ b/compare_code_metrics.py @@ -1,3 +1,4 @@ +import functools import json import os.path from pathlib import Path @@ -32,9 +33,7 @@ def _generate_cc_table(): df = _generate_table(data, metric_names.values()) - styled_df = df.style.apply(_highlight_minimum, axis=None).format(precision=2) - - return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True) + return df def _generate_mi_table(): @@ -55,9 +54,7 @@ def _generate_mi_table(): df = _generate_table(data, metric_names.values()) - styled_df = df.style.apply(_highlight_maximum, axis=None).format(precision=2) - - return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True) + return df def _generate_raw_table(): @@ -78,27 +75,17 @@ def _generate_raw_table(): df = _generate_table(data, metric_names.values()) - styled_df = df.style.apply(_highlight_minimum, axis=None).format(precision=2) - - return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True) + return df def _generate_hal_table(): path = _PARENT_PATH / "hal.json" metric_names = { - "h1": "$\\eta_1$", - "h2": "$\\eta_2$", - "N1": "$N_1$", - "N2": "$N_2$", "vocabulary": "$\\eta$", "length": "$N$", - "calculated_length": "$\\hat{N}$", "volume": "$V$", "difficulty": "$D$", - "effort": "$E$", - "time": "$T$", - "bugs": "$B$", } data = json.loads(path.read_text()) @@ -114,9 +101,7 @@ def _generate_hal_table(): df = _generate_table(data, metric_names.values()) - styled_df = df.style.apply(_highlight_minimum, axis=None).format(precision=2) - - return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True) + return df def _generate_table(data, metric_names): @@ -162,37 +147,39 @@ def _key_from_kernel_name(path, kernel_name): return df -def _highlight_minimum(df): - styles = pd.DataFrame("", index=df.index, columns=df.columns) +def _highlight(df): + new_df = pd.DataFrame("", index=df.index, columns=df.columns) - for kernel, group in df.groupby(level=0): + for _, group in df[ + ["LOC", "LLOC", "SLOC", "$G$", "$\\eta$", "$N$", "$V$", "$D$"] + ].groupby(level=0): mask = group == group.min() - styles.update( + new_df.update( mask.replace(True, "background-color: green!20").replace(False, "") ) - return styles - - -def _highlight_maximum(df): - styles = pd.DataFrame("", index=df.index, columns=df.columns) - - for kernel, group in df.groupby(level=0): + for _, group in df[["$MI$"]].groupby(level=0): mask = group == group.max() - styles.update( + new_df.update( mask.replace(True, "background-color: green!20").replace(False, "") ) - return styles + return new_df if __name__ == "__main__": - for latex_code in ( - _generate_cc_table(), - _generate_mi_table(), - _generate_raw_table(), - _generate_hal_table(), - ): - print(latex_code) + raw_table = _generate_raw_table() + cc_table = _generate_cc_table() + hal_table = _generate_hal_table() + mi_table = _generate_mi_table() + + df = functools.reduce( + lambda left, right: pd.merge(left, right, left_index=True, right_index=True), + (raw_table, cc_table, hal_table, mi_table), + ) + + styler = df.style.apply(_highlight, axis=None).format(precision=2) + + print(styler.to_latex(hrules=True, multicol_align="c", convert_css=True)) From f3a8db16425f79a07464de8f7e60eb3a329298d3 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 21 May 2025 10:18:51 +0800 Subject: [PATCH 55/68] Replace the manually specified lists of Triton `Config` objects with programmatically generated tuples --- ops/triton/kernels/addmm.py | 91 +++---------------- ops/triton/kernels/bmm.py | 91 +++---------------- ops/triton/kernels/conv2d.py | 84 +++-------------- ops/triton/kernels/mm.py | 91 +++---------------- .../kernels/scaled_dot_product_attention.py | 30 +++--- 5 files changed, 63 insertions(+), 324 deletions(-) diff --git a/ops/triton/kernels/addmm.py b/ops/triton/kernels/addmm.py index 5ba7dc5..18fa693 100644 --- a/ops/triton/kernels/addmm.py +++ b/ops/triton/kernels/addmm.py @@ -1,90 +1,25 @@ +import itertools + import triton import triton.language as tl @triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), + configs=tuple( triton.Config( { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, "GROUP_SIZE_M": 8, }, - num_stages=5, - num_warps=2, - ), - ], + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m, block_size_n, block_size_k, num_stages, num_warps in itertools.product( + (32, 64, 128), (32, 64, 128, 256), (32, 64), (3, 4, 5), (2, 4, 8) + ) + ), key=["m", "n", "k"], ) @triton.jit diff --git a/ops/triton/kernels/bmm.py b/ops/triton/kernels/bmm.py index bc18a74..45cc1dd 100644 --- a/ops/triton/kernels/bmm.py +++ b/ops/triton/kernels/bmm.py @@ -1,90 +1,25 @@ +import itertools + import triton import triton.language as tl @triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), + configs=tuple( triton.Config( { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, "GROUP_SIZE_M": 8, }, - num_stages=5, - num_warps=2, - ), - ], + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m, block_size_n, block_size_k, num_stages, num_warps in itertools.product( + (32, 64, 128), (32, 64, 128, 256), (32, 64), (3, 4, 5), (2, 4, 8) + ) + ), key=["m", "n", "k"], ) @triton.jit diff --git a/ops/triton/kernels/conv2d.py b/ops/triton/kernels/conv2d.py index acc8460..7238276 100644 --- a/ops/triton/kernels/conv2d.py +++ b/ops/triton/kernels/conv2d.py @@ -1,82 +1,24 @@ +import itertools + import triton import triton.language as tl @triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - }, - num_stages=5, - num_warps=2, - ), + configs=tuple( triton.Config( { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, }, - num_stages=5, - num_warps=2, - ), - ], + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m, block_size_n, block_size_k, num_stages, num_warps in itertools.product( + (32, 64, 128), (32, 64, 128, 256), (32, 64), (3, 4, 5), (2, 4, 8) + ) + ), key=["N", "C", "H", "W", "C", "R", "S"], ) @triton.jit diff --git a/ops/triton/kernels/mm.py b/ops/triton/kernels/mm.py index eae5fd6..d2a4d55 100644 --- a/ops/triton/kernels/mm.py +++ b/ops/triton/kernels/mm.py @@ -1,90 +1,25 @@ +import itertools + import triton import triton.language as tl @triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=5, - num_warps=2, - ), + configs=tuple( triton.Config( { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, "GROUP_SIZE_M": 8, }, - num_stages=5, - num_warps=2, - ), - ], + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m, block_size_n, block_size_k, num_stages, num_warps in itertools.product( + (32, 64, 128), (32, 64, 128, 256), (32, 64), (3, 4, 5), (2, 4, 8) + ) + ), key=["m", "n", "k"], ) @triton.jit diff --git a/ops/triton/kernels/scaled_dot_product_attention.py b/ops/triton/kernels/scaled_dot_product_attention.py index 148700d..79e9970 100644 --- a/ops/triton/kernels/scaled_dot_product_attention.py +++ b/ops/triton/kernels/scaled_dot_product_attention.py @@ -1,28 +1,20 @@ +import itertools + import triton import triton.language as tl @triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8 - ), + configs=tuple( triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=8 - ), - ], + {"BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n}, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m, block_size_n, num_stages, num_warps in itertools.product( + (32, 64, 128, 256), (32, 64, 128), (2, 3, 4, 5), (4, 8) + ) + ), key=["EMB_DIM"], ) @triton.jit From a9cefabf5bd2e3942632f2e49a93ec0015049a03 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 21 May 2025 10:32:12 +0800 Subject: [PATCH 56/68] Replace `torch.rand` with `torch.randn` --- add.py | 4 ++-- swiglu.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/add.py b/add.py index 782a7c7..06b6735 100644 --- a/add.py +++ b/add.py @@ -11,8 +11,8 @@ dtype = torch.float16 device = "cuda" - input = torch.rand(size, dtype=dtype, device=device) - other = torch.rand(size, dtype=dtype, device=device) + input = torch.randn(size, dtype=dtype, device=device) + other = torch.randn(size, dtype=dtype, device=device) ninetoothed_output = ops.ninetoothed.torch.add(input, other) torch_output = input + other diff --git a/swiglu.py b/swiglu.py index 3aa78b3..edf70c1 100644 --- a/swiglu.py +++ b/swiglu.py @@ -20,9 +20,9 @@ def torch_swiglu( dtype = torch.float16 device = "cuda" - a = torch.rand(shape, dtype=dtype, device=device) - b = torch.rand(shape, dtype=dtype, device=device) - c = torch.rand(shape, dtype=dtype, device=device) + a = torch.randn(shape, dtype=dtype, device=device) + b = torch.randn(shape, dtype=dtype, device=device) + c = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ops.ninetoothed.torch.swiglu(a, b) torch_output = torch_swiglu(a, b) @@ -58,8 +58,8 @@ def torch_swiglu( def benchmark(m, n, provider): shape = (m, n) - a = torch.rand(shape, dtype=dtype, device=device) - b = torch.rand(shape, dtype=dtype, device=device) + a = torch.randn(shape, dtype=dtype, device=device) + b = torch.randn(shape, dtype=dtype, device=device) if provider == "ninetoothed": ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.swiglu(a, b)) From 276c5abc7edd86443f22ad68274dd61cab8498bb Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 21 May 2025 10:41:43 +0800 Subject: [PATCH 57/68] Standardize `plot_name` values across scripts --- add.py | 2 +- bmm.py | 2 +- conv2d.py | 2 +- max_pool2d.py | 2 +- mm.py | 2 +- rotary_position_embedding.py | 2 +- scaled_dot_product_attention.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/add.py b/add.py index 06b6735..9b34f89 100644 --- a/add.py +++ b/add.py @@ -41,7 +41,7 @@ line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], ylabel="ms", - plot_name="vector-addition-performance", + plot_name="add-performance", args={}, ) ) diff --git a/bmm.py b/bmm.py index 3809f20..3243511 100644 --- a/bmm.py +++ b/bmm.py @@ -41,7 +41,7 @@ line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], ylabel="ms", - plot_name="batched-matrix-multiplication-performance", + plot_name="bmm-performance", args={"b": 4}, ) ) diff --git a/conv2d.py b/conv2d.py index 60e48c8..c7bd8bb 100644 --- a/conv2d.py +++ b/conv2d.py @@ -43,7 +43,7 @@ line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], ylabel="ms", - plot_name="2d-convolution-performance", + plot_name="conv2d-performance", args={}, ) ) diff --git a/max_pool2d.py b/max_pool2d.py index e76f588..584ee53 100644 --- a/max_pool2d.py +++ b/max_pool2d.py @@ -78,7 +78,7 @@ def max_pool2d(input, window_shape): line_names=["NineToothed", "PyTorch"], styles=[("blue", "-"), ("green", "-")], ylabel="ms", - plot_name="2d-max-pooling-performance", + plot_name="max-pool2d-performance", args={}, ) ) diff --git a/mm.py b/mm.py index 33eb17f..f4d0b95 100644 --- a/mm.py +++ b/mm.py @@ -41,7 +41,7 @@ line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], ylabel="ms", - plot_name="matrix-multiplication-performance", + plot_name="mm-performance", args={}, ) ) diff --git a/rotary_position_embedding.py b/rotary_position_embedding.py index 0eb4e8b..02719f6 100644 --- a/rotary_position_embedding.py +++ b/rotary_position_embedding.py @@ -87,7 +87,7 @@ def _generate_sin_and_cos_tables( line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], ylabel="ms", - plot_name="rotary_position_embedding-performance", + plot_name="rotary-position-embedding-performance", args={}, ) ) diff --git a/scaled_dot_product_attention.py b/scaled_dot_product_attention.py index de4c848..11f0ec0 100644 --- a/scaled_dot_product_attention.py +++ b/scaled_dot_product_attention.py @@ -157,7 +157,7 @@ def rotary_position_embedding_backend(backend_name): line_names=["NineToothed", "PyTorch", "Triton"], styles=[("blue", "-"), ("green", "-"), ("orange", "-")], ylabel="ms", - plot_name="attention-performance", + plot_name="scaled-dot-product-attention-performance", args={}, ) ) From f75f8414be562ce76fde6a0f8adac62e434e0fe7 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 21 May 2025 18:15:28 +0800 Subject: [PATCH 58/68] Add `compare_performance_metrics.py` --- compare_performance_metrics.py | 128 +++++++++++++++++++++++++++++++++ performance_comparison.py | 93 ------------------------ 2 files changed, 128 insertions(+), 93 deletions(-) create mode 100644 compare_performance_metrics.py delete mode 100644 performance_comparison.py diff --git a/compare_performance_metrics.py b/compare_performance_metrics.py new file mode 100644 index 0000000..098f007 --- /dev/null +++ b/compare_performance_metrics.py @@ -0,0 +1,128 @@ +import functools +import random + +import matplotlib.pyplot as plt +import pandas as pd +import torch +import torch.nn.functional +import triton + +import ops.ninetoothed.torch +import ops.triton.torch +import rotary_position_embedding +from compare_code_metrics import _BACKSLASH_CHAR + + +def _run_task(op_name, dtype, device, *arg_shapes, **kwarg_shapes): + ninetoothed_op = getattr(ops.ninetoothed.torch, op_name) + triton_op = getattr(ops.triton.torch, op_name) + + if op_name == "rotary_position_embedding": + torch_op = rotary_position_embedding.torch_rotary_position_embedding + else: + torch_op = ( + getattr(torch, op_name) + if hasattr(torch, op_name) + else getattr(torch.nn.functional, op_name) + ) + + if op_name == "rms_norm": + torch_op = functools.partial(torch_op, normalized_shape=arg_shapes[0][-1:]) + elif op_name == "softmax": + torch_op = functools.partial(torch_op, dim=-1) + + args = tuple( + torch.randn(shape, dtype=dtype, device=device) if shape else random.gauss(0, 1) + for shape in arg_shapes + ) + kwargs = { + key: torch.randn(shape, dtype=dtype, device=device) + if shape + else random.gauss(0, 1) + for key, shape in kwarg_shapes.items() + } + + arg_shape_string = ", ".join(str(shape) for shape in arg_shapes) + kwarg_shape_string = ", ".join( + f"{key}={shape}" for key, shape in kwarg_shapes.items() + ) + shape_string = ( + f"{arg_shape_string}, {kwarg_shape_string}" + if kwarg_shape_string + else arg_shape_string + ) + + task_description = f"{op_name}({shape_string})" + + return task_description, _benchmark_ops( + (ninetoothed_op, triton_op, torch_op), *args, **kwargs + ) + + +def _benchmark_ops(ops, *args, **kwargs): + assert all( + torch.allclose( + op(*args, **kwargs), ops[0](*args, **kwargs), rtol=0.01, atol=0.01 + ) + for op in ops[1:] + ) + + return tuple(triton.testing.do_bench(lambda: op(*args, **kwargs)) for op in ops) + + +if __name__ == "__main__": + random.seed(0) + torch.manual_seed(0) + + plt.rcParams["figure.dpi"] = 600 + plt.rcParams["font.family"] = "Linux Biolinum" + + dtype = torch.float16 + device = "cuda" + + tasks = ( + ("add", ((4096 * 4096,), (4096 * 4096,)), {}), + ( + "addmm", + ((4096, 4096), (4096, 4096), (4096, 4096)), + {"beta": (), "alpha": ()}, + ), + ("bmm", ((4, 2048, 2048), (4, 2048, 2048)), {}), + ("conv2d", ((4, 512, 14, 14), (512, 512, 3, 3)), {}), + ("mm", ((4096, 4096), (4096, 4096)), {}), + ("rms_norm", ((4096, 4096),), {}), + ("rotary_position_embedding", ((4, 1024, 48, 64), (1024, 32), (1024, 32)), {}), + ( + "scaled_dot_product_attention", + ((4, 48, 1024, 64), (4, 48, 1024, 64), (4, 48, 1024, 64)), + {}, + ), + ("silu", ((4096 * 4096,),), {}), + ("softmax", ((4096, 4096),), {}), + ) + + data = {"Task": [], "NineToothed": [], "Triton": [], "PyTorch": []} + + for name, args, kwargs in tasks: + description, results = _run_task(name, dtype, device, *args, **kwargs) + + latex_item = f"\item {_BACKSLASH_CHAR}texttt{{{description.replace('scaled_dot_product_attention', 'sdpa').replace('rotary_position_embedding', 'rope').replace('_', f'{_BACKSLASH_CHAR}_')}}}" + + print(latex_item) + + data["Task"].append(description) + + for i, provider in enumerate(("NineToothed", "Triton", "PyTorch")): + data[provider].append(results[i]) + + df = pd.DataFrame(data) + df.index += 1 + + df.set_index("Task").to_csv("performance-metrics.csv") + + df.plot(kind="bar", rot=0) + plt.ylabel("Execution Time (ms)") + plt.xlabel("Task") + plt.grid(False) + plt.tight_layout() + plt.savefig("performance-metrics.png") diff --git a/performance_comparison.py b/performance_comparison.py deleted file mode 100644 index c5a93a7..0000000 --- a/performance_comparison.py +++ /dev/null @@ -1,93 +0,0 @@ -from dataclasses import dataclass - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -plt.rcParams["figure.figsize"] = [12, 6] -plt.rcParams["figure.dpi"] = 600 -plt.rcParams["font.family"] = "JetBrains Mono" -plt.rcParams["font.weight"] = "bold" -plt.rcParams["axes.titleweight"] = "bold" -plt.rcParams["axes.labelweight"] = "bold" - - -@dataclass -class KernelInformation: - name: str - perf_report_path: str - independent_variable: str - - -kernels = ( - KernelInformation("add", "vector-addition-performance.csv", "Length"), - KernelInformation("softmax", "softmax-performance.csv", "Number of Columns"), - KernelInformation("rms_norm", "rms-norm-performance.csv", "Number of Columns"), - KernelInformation("matmul", "matrix-multiplication-performance.csv", "Sizes"), - KernelInformation("conv2d", "2d-convolution-performance.csv", "Batch Size"), - KernelInformation("attention", "attention-performance.csv", "Sequence Length"), -) - -providers = ("Triton", "NineToothed", "PyTorch") - -num_rows = 2 -num_cols = 3 - -fig, axs = plt.subplots(num_rows, num_cols) - -performance_changes = [] - -for i, kernel in enumerate(kernels): - df = pd.read_csv(kernel.perf_report_path) - ax = axs[i // num_cols, i % num_cols] - - x = df.iloc[:, 0] - - performance_changes.append((kernel, [])) - - for provider in providers: - y = df[provider] - - ax.plot(x, y, label=provider) - - if provider == "NineToothed": - y_triton = df["Triton"] - change = (y - y_triton) / y_triton * 100 - performance_changes[-1][-1].append(change) - - ax.set_title(kernel.name) - ax.set_xlabel(kernel.independent_variable) - ax.set_ylabel("Execution Time (ms)") - ax.set_xscale("log", base=2) - -fig.legend(providers, loc="upper center", ncols=len(providers)) -fig.tight_layout() -fig.subplots_adjust(top=0.9) - -plt.show() -plt.savefig("performance-comparison.png") - -all_changes = [] -stats_data = [] - -for kernel, changes in performance_changes: - all_changes.extend(changes) - - kernel_stats = { - "Kernel": kernel.name, - "Mean": np.mean(changes), - "Median": np.median(changes), - } - - stats_data.append(kernel_stats) - -overall_stats = { - "Kernel": "Overall", - "Mean": np.mean(all_changes), - "Median": np.median(all_changes), -} - -stats_data.append(overall_stats) - -print("Relative Performance Change (%):") -print(pd.DataFrame(stats_data)) From 123bd37d2d0928ad072887dbc6efb8f82847a26c Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 21 May 2025 18:34:30 +0800 Subject: [PATCH 59/68] Add `run_experiments.py` --- run_experiments.py | 82 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 run_experiments.py diff --git a/run_experiments.py b/run_experiments.py new file mode 100644 index 0000000..11f6ce8 --- /dev/null +++ b/run_experiments.py @@ -0,0 +1,82 @@ +import argparse +import subprocess + +PROMPTS = ( + "The emergence of deep learning domain-specific languages (DSLs) has substantially reduced the obstacles in developing high-performance, cross-platform compute kernels, but current DSLs", + "Driven by recent advancements in the AI industry, the AI accelerator sector has increasingly diversified, with vendors developing their own hardware architectures and programming models, such as NVIDIA", +) + +NUM_WARMUP_ITERATIONS = 1 + +NUM_PROFILING_ITERATIONS = 3 + +BACKENDS = ("ninetoothed", "triton", "torch") + +ALL_MAX_NEW_TOKENS = (128, 512, 2048) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run experiments.") + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the model or model identifier from Hugging Face.", + ) + + args = parser.parse_args() + + model_name_or_path = args.model + + radon_commands = ( + ( + "radon", + "cc", + "--show-complexity", + "--json", + "--output-file", + "cc.json", + "ops/", + ), + ("radon", "mi", "--show", "--json", "--output-file", "mi.json", "ops/"), + ("radon", "raw", "--json", "--output-file", "raw.json", "ops/"), + ("radon", "hal", "--json", "--output-file", "hal.json", "ops/"), + ) + + for command in radon_commands: + subprocess.run(command, check=True) + + with open("code_metrics.tex", "w") as f: + subprocess.run(("python", "compare_code_metrics.py"), stdout=f, check=True) + + for max_new_tokens in ALL_MAX_NEW_TOKENS: + for backend in BACKENDS: + with open(f"infer_{max_new_tokens}_{backend}.json", "w") as f: + subprocess.run( + ( + "python", + "infer.py", + "--model", + model_name_or_path, + "--prompts", + *PROMPTS, + "--max-new-tokens", + str(max_new_tokens), + "--device", + "cuda", + "--backend", + "ninetoothed", + "--num-warmup-iterations", + str(NUM_WARMUP_ITERATIONS), + "--num-profiling-iterations", + str(NUM_PROFILING_ITERATIONS), + ), + stdout=f, + check=True, + ) + + with open("performance_metrics.tex", "w") as f: + subprocess.run( + ("python", "compare_performance_metrics.py"), stdout=f, check=True + ) From 0f8477ce89d3f01a5bf97549689f4182d40fa477 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 21 May 2025 18:43:39 +0800 Subject: [PATCH 60/68] Run tasks in `run_experiments.py` instead of `compare_performance_metrics.py` --- compare_performance_metrics.py | 112 +-------------------------------- run_experiments.py | 109 ++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 109 deletions(-) diff --git a/compare_performance_metrics.py b/compare_performance_metrics.py index 098f007..48ba6f4 100644 --- a/compare_performance_metrics.py +++ b/compare_performance_metrics.py @@ -1,125 +1,19 @@ -import functools -import random - import matplotlib.pyplot as plt import pandas as pd -import torch -import torch.nn.functional -import triton -import ops.ninetoothed.torch -import ops.triton.torch -import rotary_position_embedding from compare_code_metrics import _BACKSLASH_CHAR - -def _run_task(op_name, dtype, device, *arg_shapes, **kwarg_shapes): - ninetoothed_op = getattr(ops.ninetoothed.torch, op_name) - triton_op = getattr(ops.triton.torch, op_name) - - if op_name == "rotary_position_embedding": - torch_op = rotary_position_embedding.torch_rotary_position_embedding - else: - torch_op = ( - getattr(torch, op_name) - if hasattr(torch, op_name) - else getattr(torch.nn.functional, op_name) - ) - - if op_name == "rms_norm": - torch_op = functools.partial(torch_op, normalized_shape=arg_shapes[0][-1:]) - elif op_name == "softmax": - torch_op = functools.partial(torch_op, dim=-1) - - args = tuple( - torch.randn(shape, dtype=dtype, device=device) if shape else random.gauss(0, 1) - for shape in arg_shapes - ) - kwargs = { - key: torch.randn(shape, dtype=dtype, device=device) - if shape - else random.gauss(0, 1) - for key, shape in kwarg_shapes.items() - } - - arg_shape_string = ", ".join(str(shape) for shape in arg_shapes) - kwarg_shape_string = ", ".join( - f"{key}={shape}" for key, shape in kwarg_shapes.items() - ) - shape_string = ( - f"{arg_shape_string}, {kwarg_shape_string}" - if kwarg_shape_string - else arg_shape_string - ) - - task_description = f"{op_name}({shape_string})" - - return task_description, _benchmark_ops( - (ninetoothed_op, triton_op, torch_op), *args, **kwargs - ) - - -def _benchmark_ops(ops, *args, **kwargs): - assert all( - torch.allclose( - op(*args, **kwargs), ops[0](*args, **kwargs), rtol=0.01, atol=0.01 - ) - for op in ops[1:] - ) - - return tuple(triton.testing.do_bench(lambda: op(*args, **kwargs)) for op in ops) - - if __name__ == "__main__": - random.seed(0) - torch.manual_seed(0) - plt.rcParams["figure.dpi"] = 600 plt.rcParams["font.family"] = "Linux Biolinum" - dtype = torch.float16 - device = "cuda" - - tasks = ( - ("add", ((4096 * 4096,), (4096 * 4096,)), {}), - ( - "addmm", - ((4096, 4096), (4096, 4096), (4096, 4096)), - {"beta": (), "alpha": ()}, - ), - ("bmm", ((4, 2048, 2048), (4, 2048, 2048)), {}), - ("conv2d", ((4, 512, 14, 14), (512, 512, 3, 3)), {}), - ("mm", ((4096, 4096), (4096, 4096)), {}), - ("rms_norm", ((4096, 4096),), {}), - ("rotary_position_embedding", ((4, 1024, 48, 64), (1024, 32), (1024, 32)), {}), - ( - "scaled_dot_product_attention", - ((4, 48, 1024, 64), (4, 48, 1024, 64), (4, 48, 1024, 64)), - {}, - ), - ("silu", ((4096 * 4096,),), {}), - ("softmax", ((4096, 4096),), {}), - ) - - data = {"Task": [], "NineToothed": [], "Triton": [], "PyTorch": []} - - for name, args, kwargs in tasks: - description, results = _run_task(name, dtype, device, *args, **kwargs) - - latex_item = f"\item {_BACKSLASH_CHAR}texttt{{{description.replace('scaled_dot_product_attention', 'sdpa').replace('rotary_position_embedding', 'rope').replace('_', f'{_BACKSLASH_CHAR}_')}}}" + df = pd.read_csv("performance-metrics.csv") + for task in df["Task"]: + latex_item = f"\item {_BACKSLASH_CHAR}texttt{{{task.replace('scaled_dot_product_attention', 'sdpa').replace('rotary_position_embedding', 'rope').replace('_', f'{_BACKSLASH_CHAR}_')}}}" print(latex_item) - data["Task"].append(description) - - for i, provider in enumerate(("NineToothed", "Triton", "PyTorch")): - data[provider].append(results[i]) - - df = pd.DataFrame(data) df.index += 1 - - df.set_index("Task").to_csv("performance-metrics.csv") - df.plot(kind="bar", rot=0) plt.ylabel("Execution Time (ms)") plt.xlabel("Task") diff --git a/run_experiments.py b/run_experiments.py index 11f6ce8..9cfc0b3 100644 --- a/run_experiments.py +++ b/run_experiments.py @@ -1,6 +1,17 @@ import argparse +import functools +import random import subprocess +import pandas as pd +import torch +import torch.nn.functional +import triton + +import ops.ninetoothed.torch +import ops.triton.torch +import rotary_position_embedding + PROMPTS = ( "The emergence of deep learning domain-specific languages (DSLs) has substantially reduced the obstacles in developing high-performance, cross-platform compute kernels, but current DSLs", "Driven by recent advancements in the AI industry, the AI accelerator sector has increasingly diversified, with vendors developing their own hardware architectures and programming models, such as NVIDIA", @@ -15,6 +26,63 @@ ALL_MAX_NEW_TOKENS = (128, 512, 2048) +def _run_task(op_name, dtype, device, *arg_shapes, **kwarg_shapes): + ninetoothed_op = getattr(ops.ninetoothed.torch, op_name) + triton_op = getattr(ops.triton.torch, op_name) + + if op_name == "rotary_position_embedding": + torch_op = rotary_position_embedding.torch_rotary_position_embedding + else: + torch_op = ( + getattr(torch, op_name) + if hasattr(torch, op_name) + else getattr(torch.nn.functional, op_name) + ) + + if op_name == "rms_norm": + torch_op = functools.partial(torch_op, normalized_shape=arg_shapes[0][-1:]) + elif op_name == "softmax": + torch_op = functools.partial(torch_op, dim=-1) + + args = tuple( + torch.randn(shape, dtype=dtype, device=device) if shape else random.gauss(0, 1) + for shape in arg_shapes + ) + kwargs = { + key: torch.randn(shape, dtype=dtype, device=device) + if shape + else random.gauss(0, 1) + for key, shape in kwarg_shapes.items() + } + + arg_shape_string = ", ".join(str(shape) for shape in arg_shapes) + kwarg_shape_string = ", ".join( + f"{key}={shape}" for key, shape in kwarg_shapes.items() + ) + shape_string = ( + f"{arg_shape_string}, {kwarg_shape_string}" + if kwarg_shape_string + else arg_shape_string + ) + + task_description = f"{op_name}({shape_string})" + + return task_description, _benchmark_ops( + (ninetoothed_op, triton_op, torch_op), *args, **kwargs + ) + + +def _benchmark_ops(ops, *args, **kwargs): + assert all( + torch.allclose( + op(*args, **kwargs), ops[0](*args, **kwargs), rtol=0.01, atol=0.01 + ) + for op in ops[1:] + ) + + return tuple(triton.testing.do_bench(lambda: op(*args, **kwargs)) for op in ops) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run experiments.") @@ -29,6 +97,9 @@ model_name_or_path = args.model + random.seed(0) + torch.manual_seed(0) + radon_commands = ( ( "radon", @@ -50,6 +121,44 @@ with open("code_metrics.tex", "w") as f: subprocess.run(("python", "compare_code_metrics.py"), stdout=f, check=True) + dtype = torch.float16 + device = "cuda" + + tasks = ( + ("add", ((4096 * 4096,), (4096 * 4096,)), {}), + ( + "addmm", + ((4096, 4096), (4096, 4096), (4096, 4096)), + {"beta": (), "alpha": ()}, + ), + ("bmm", ((4, 2048, 2048), (4, 2048, 2048)), {}), + ("conv2d", ((4, 512, 14, 14), (512, 512, 3, 3)), {}), + ("mm", ((4096, 4096), (4096, 4096)), {}), + ("rms_norm", ((4096, 4096),), {}), + ("rotary_position_embedding", ((4, 1024, 48, 64), (1024, 32), (1024, 32)), {}), + ( + "scaled_dot_product_attention", + ((4, 48, 1024, 64), (4, 48, 1024, 64), (4, 48, 1024, 64)), + {}, + ), + ("silu", ((4096 * 4096,),), {}), + ("softmax", ((4096, 4096),), {}), + ) + + data = {"Task": [], "NineToothed": [], "Triton": [], "PyTorch": []} + + for name, args, kwargs in tasks: + description, results = _run_task(name, dtype, device, *args, **kwargs) + + data["Task"].append(description) + + for i, provider in enumerate(("NineToothed", "Triton", "PyTorch")): + data[provider].append(results[i]) + + df = pd.DataFrame(data) + + df.set_index("Task").to_csv("performance-metrics.csv") + for max_new_tokens in ALL_MAX_NEW_TOKENS: for backend in BACKENDS: with open(f"infer_{max_new_tokens}_{backend}.json", "w") as f: From 27ae4b9d71adc2d1ad05ddca9a211003e97e3c76 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 21 May 2025 19:12:26 +0800 Subject: [PATCH 61/68] Add end-to-end model inference throughput comparison plot --- compare_performance_metrics.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/compare_performance_metrics.py b/compare_performance_metrics.py index 48ba6f4..f4955ea 100644 --- a/compare_performance_metrics.py +++ b/compare_performance_metrics.py @@ -1,7 +1,10 @@ +import json + import matplotlib.pyplot as plt import pandas as pd from compare_code_metrics import _BACKSLASH_CHAR +from run_experiments import ALL_MAX_NEW_TOKENS, BACKENDS if __name__ == "__main__": plt.rcParams["figure.dpi"] = 600 @@ -20,3 +23,28 @@ plt.grid(False) plt.tight_layout() plt.savefig("performance-metrics.png") + + data = {"Output Length": [], "NineToothed": [], "Triton": [], "PyTorch": []} + + for max_new_tokens in ALL_MAX_NEW_TOKENS: + data["Output Length"].append(max_new_tokens) + + for backend in BACKENDS: + with open(f"infer_{max_new_tokens}_{backend}.json") as f: + num_tokens_per_second = json.load(f)["num_tokens_per_second"] + + if backend == "ninetoothed": + data["NineToothed"].append(num_tokens_per_second) + elif backend == "triton": + data["Triton"].append(num_tokens_per_second) + elif backend == "torch": + data["PyTorch"].append(num_tokens_per_second) + + df = pd.DataFrame(data) + + df.set_index("Output Length").plot(kind="bar", rot=0) + plt.ylabel("Throughput (TPS)") + plt.xlabel("Output Length") + plt.grid(False) + plt.tight_layout() + plt.savefig("end-to-end-performance-metrics.png") From 05171082e06a68528ced8bc32598608631147ac7 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 23 May 2025 00:15:36 +0800 Subject: [PATCH 62/68] Add `transformers` and `radon` to `requirements.txt` --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 1ebe6ae..8844f71 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,5 @@ ninetoothed torch matplotlib pandas +transformers +radon From bd6a73001884cbd9b34056a7ccac283d9b0f2486 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 23 May 2025 08:39:07 +0800 Subject: [PATCH 63/68] Refactor CSV export to occur within task-processing loop --- run_experiments.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/run_experiments.py b/run_experiments.py index 9cfc0b3..f834fce 100644 --- a/run_experiments.py +++ b/run_experiments.py @@ -155,9 +155,7 @@ def _benchmark_ops(ops, *args, **kwargs): for i, provider in enumerate(("NineToothed", "Triton", "PyTorch")): data[provider].append(results[i]) - df = pd.DataFrame(data) - - df.set_index("Task").to_csv("performance-metrics.csv") + pd.DataFrame(data).set_index("Task").to_csv("performance-metrics.csv") for max_new_tokens in ALL_MAX_NEW_TOKENS: for backend in BACKENDS: From 772cae8a82cc01291fe909287076c3caefef4829 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 23 May 2025 08:45:23 +0800 Subject: [PATCH 64/68] Rename the output CSV file from `performance-metrics.csv` to `microbenchmark_data.csv` --- compare_performance_metrics.py | 2 +- run_experiments.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/compare_performance_metrics.py b/compare_performance_metrics.py index f4955ea..ab7f2c6 100644 --- a/compare_performance_metrics.py +++ b/compare_performance_metrics.py @@ -10,7 +10,7 @@ plt.rcParams["figure.dpi"] = 600 plt.rcParams["font.family"] = "Linux Biolinum" - df = pd.read_csv("performance-metrics.csv") + df = pd.read_csv("microbenchmark_data.csv") for task in df["Task"]: latex_item = f"\item {_BACKSLASH_CHAR}texttt{{{task.replace('scaled_dot_product_attention', 'sdpa').replace('rotary_position_embedding', 'rope').replace('_', f'{_BACKSLASH_CHAR}_')}}}" diff --git a/run_experiments.py b/run_experiments.py index f834fce..5e42f62 100644 --- a/run_experiments.py +++ b/run_experiments.py @@ -155,7 +155,7 @@ def _benchmark_ops(ops, *args, **kwargs): for i, provider in enumerate(("NineToothed", "Triton", "PyTorch")): data[provider].append(results[i]) - pd.DataFrame(data).set_index("Task").to_csv("performance-metrics.csv") + pd.DataFrame(data).set_index("Task").to_csv("microbenchmark_data.csv") for max_new_tokens in ALL_MAX_NEW_TOKENS: for backend in BACKENDS: From 287fc78b5d95c00b15b48cc64d8f6f484fa3336a Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 23 May 2025 08:50:55 +0800 Subject: [PATCH 65/68] Rename output image files --- compare_performance_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compare_performance_metrics.py b/compare_performance_metrics.py index ab7f2c6..f1edb4b 100644 --- a/compare_performance_metrics.py +++ b/compare_performance_metrics.py @@ -22,7 +22,7 @@ plt.xlabel("Task") plt.grid(False) plt.tight_layout() - plt.savefig("performance-metrics.png") + plt.savefig("microbenchmark-results.png") data = {"Output Length": [], "NineToothed": [], "Triton": [], "PyTorch": []} @@ -47,4 +47,4 @@ plt.xlabel("Output Length") plt.grid(False) plt.tight_layout() - plt.savefig("end-to-end-performance-metrics.png") + plt.savefig("benchmark-results.png") From b04acf85ac84c3123dad09758d73a35ac5c2a3bc Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 23 May 2025 08:54:19 +0800 Subject: [PATCH 66/68] Rename evaluation scripts and output files for consistency and clarity across code and performance evaluation --- compare_code_metrics.py => evaluate_code.py | 0 ...e_performance_metrics.py => evaluate_performance.py | 2 +- run_experiments.py | 10 ++++------ 3 files changed, 5 insertions(+), 7 deletions(-) rename compare_code_metrics.py => evaluate_code.py (100%) rename compare_performance_metrics.py => evaluate_performance.py (97%) diff --git a/compare_code_metrics.py b/evaluate_code.py similarity index 100% rename from compare_code_metrics.py rename to evaluate_code.py diff --git a/compare_performance_metrics.py b/evaluate_performance.py similarity index 97% rename from compare_performance_metrics.py rename to evaluate_performance.py index f1edb4b..7b9af13 100644 --- a/compare_performance_metrics.py +++ b/evaluate_performance.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import pandas as pd -from compare_code_metrics import _BACKSLASH_CHAR +from evaluate_code import _BACKSLASH_CHAR from run_experiments import ALL_MAX_NEW_TOKENS, BACKENDS if __name__ == "__main__": diff --git a/run_experiments.py b/run_experiments.py index 5e42f62..d0149aa 100644 --- a/run_experiments.py +++ b/run_experiments.py @@ -118,8 +118,8 @@ def _benchmark_ops(ops, *args, **kwargs): for command in radon_commands: subprocess.run(command, check=True) - with open("code_metrics.tex", "w") as f: - subprocess.run(("python", "compare_code_metrics.py"), stdout=f, check=True) + with open("code_evaluation.tex", "w") as f: + subprocess.run(("python", "evaluate_code.py"), stdout=f, check=True) dtype = torch.float16 device = "cuda" @@ -183,7 +183,5 @@ def _benchmark_ops(ops, *args, **kwargs): check=True, ) - with open("performance_metrics.tex", "w") as f: - subprocess.run( - ("python", "compare_performance_metrics.py"), stdout=f, check=True - ) + with open("performance_evaluation.tex", "w") as f: + subprocess.run(("python", "evaluate_performance.py"), stdout=f, check=True) From e8a0d1a88a36bef26a95175ea42782749351d42a Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 23 May 2025 08:58:30 +0800 Subject: [PATCH 67/68] Use `torch.utils.collect_env` in `run_experiments.py` --- run_experiments.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/run_experiments.py b/run_experiments.py index d0149aa..1b0b68f 100644 --- a/run_experiments.py +++ b/run_experiments.py @@ -100,6 +100,11 @@ def _benchmark_ops(ops, *args, **kwargs): random.seed(0) torch.manual_seed(0) + with open("torch.utils.collect_env.log", "w") as f: + subprocess.run( + ("python", "-m", "torch.utils.collect_env"), stdout=f, stderr=f, check=True + ) + radon_commands = ( ( "radon", From 7798e3949a35b59d9255d72a8caac128a9e89a46 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 2 Jul 2025 11:33:10 +0800 Subject: [PATCH 68/68] Update `README.md` --- README.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 03557df..dcf435c 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,20 @@ # NineToothed Examples -This repository contains examples for [NineToothed](https://github.com/InfiniTensor/ninetoothed), including implementations of several common compute kernels written using NineToothed. +This repository contains examples of [NineToothed](https://github.com/InfiniTensor/ninetoothed), including implementations of several common compute kernels written using NineToothed. ## Usage After cloning this repository, you can run any of the examples using Python. For instance, to run the matrix multiplication example, execute the following command: ```bash -python matmul.py +python mm.py ``` ### Autotuning Behavior -By default, the examples apply autotuning, which may take several minutes or longer to complete for complex kernels. If you wish to disable autotuning, you can replace symbol definitions with concrete values. Consider the following example: +Some examples apply autotuning, which may take several minutes or longer to complete for complex kernels. If you wish to disable autotuning, you can replace symbol definitions with concrete values. + +Consider the following example: ```python BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True) @@ -29,6 +31,8 @@ BLOCK_SIZE = 1024 These approaches allow you to obtain results in seconds. However, selecting optimal values is crucial for good performance. Experiment with different values to determine the best configuration. +Note: Please don't forget to also disable the autotuning of the corresponding Triton compute kernels. + ## Third-Party Code and Licenses This project includes code modified or inspired from the following open-source repositories: