From 198880b61158a1a17d5f1ca40aeea44bf359a796 Mon Sep 17 00:00:00 2001 From: wlxjhlf Date: Tue, 21 Jan 2025 09:43:42 +0800 Subject: [PATCH 1/4] dot --- benchmark/test_reduction_perf.py | 19 ++++++++ src/flag_gems/__init__.py | 1 + src/flag_gems/ops/__init__.py | 2 + src/flag_gems/ops/dot.py | 78 ++++++++++++++++++++++++++++++++ tests/test_reduction_ops.py | 17 +++++++ 5 files changed, 117 insertions(+) create mode 100644 src/flag_gems/ops/dot.py diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py index 6cc50495b..00460412d 100644 --- a/benchmark/test_reduction_perf.py +++ b/benchmark/test_reduction_perf.py @@ -11,6 +11,7 @@ Benchmark, Config, GenericBenchmark2DOnly, + GenericBenchmark, SkipVersion, generate_tensor_input, unary_input_fn, @@ -186,3 +187,21 @@ def count_nonzero_input_fn(shape, dtype, device): dtypes=FLOAT_DTYPES, ) bench.run() + + +@pytest.mark.dot +def test_perf_dot(): + def dot_input_fn(shape, dtype, device): + inp = generate_tensor_input(shape, dtype=dtype, device=device) + if inp.dim() > 1: + inp = inp.flatten() + yield inp, inp + + bench = GenericBenchmark( + input_fn = dot_input_fn, + op_name = "dot", + torch_op = torch.dot, + dtypes = FLOAT_DTYPES, + ) + + bench.run() \ No newline at end of file diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 0ca58a4a8..4c1988359 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -209,6 +209,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar): ("logical_and", logical_and, Autograd.disable), ("logical_xor", logical_xor, Autograd.disable), ("logical_not", logical_not, Autograd.disable), + ("dot", dot, Autograd.disable), ), user_unused_ops_list=[] if unused is None else unused, lib=lib, diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 354a14cca..35bb3247a 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -127,6 +127,7 @@ from .where import where_scalar_other, where_scalar_self, where_self, where_self_out from .zeros import zeros from .zeros_like import zeros_like +from .dot import dot __all__ = [ "all", @@ -289,4 +290,5 @@ "logical_xor", "logical_not", "sort", + "dot", ] diff --git a/src/flag_gems/ops/dot.py b/src/flag_gems/ops/dot.py new file mode 100644 index 000000000..d18408ce1 --- /dev/null +++ b/src/flag_gems/ops/dot.py @@ -0,0 +1,78 @@ +import logging +import math + +import torch +import triton +import triton.language as tl + +from .. 
import runtime +from ..runtime import torch_device_fn +from ..utils import libentry +from ..utils import triton_lang_extension as tle + + + + +@libentry() +@triton.jit +def dot_kernel_1( + x_ptr, + y_ptr, + mid_ptr, + N, + BLOCK_SIZE: tl.constexpr +): + pid = tle.program_id(0) + block_start = pid * BLOCK_SIZE + + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < N + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + + partial_sum = tl.sum(x * y) + tl.store(mid_ptr + pid, partial_sum) + + +@libentry() +@triton.jit +def dot_kernel_2( + mid_ptr, + out_ptr, + M, + BLOCK_MID: tl.constexpr +): + offset = tl.arange(0, BLOCK_MID) + mid = mid_ptr + offset + mask = offset < M + mid_val = tl.load(mid, mask=mask, other=0.0) + out_val = tl.sum(mid_val) + tl.store(out_ptr, out_val) + + +def dot(x, y): + logging.debug("Triton Dot Product") + + assert x.shape == y.shape, "Input vectors must have the same shape" + assert x.dim() == 1, "Input must be 1D tensors" + + N = x.shape[0] + + block_size = triton.next_power_of_2(math.ceil(math.sqrt(N))) + mid_size = triton.cdiv(N, block_size) + block_mid = triton.next_power_of_2(mid_size) + + grid_1 = (mid_size, 1, 1) + grid_2 = (1, 1, 1) + + mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device) + out = torch.empty([], dtype=x.dtype, device=x.device) + + with torch_device_fn.device(x.device): + dot_kernel_1[grid_1](x, y, mid, N, block_size) + dot_kernel_2[grid_2](mid, out, mid_size, block_mid) + + return out + + diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index 37f4ab2c2..438288155 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -13,6 +13,7 @@ REDUCTION_SHAPES, REDUCTION_SMALL_SHAPES, SHAPE_STRIDES, + UT_SHAPES_1D, SkipVersion, gems_assert_close, gems_assert_equal, @@ -874,3 +875,19 @@ def test_accuracy_depthwise2d( inp, weight, kernel, bias=None, stride=stride, padding=padding, dilation=1 ) gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.dot +@pytest.mark.parametrize("shape", UT_SHAPES_1D) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_dot_tensor_tensor(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + ref_inp1 = to_reference(inp1, False) + ref_inp2 = to_reference(inp2, False) + + ref_out = torch.dot(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.dot(inp1, inp2) + + gems_assert_close(res_out, ref_out, dtype, equal_nan=True) From e329cd13d3ebe7323e9786f6b0f836baa248670e Mon Sep 17 00:00:00 2001 From: wlxjhyf Date: Mon, 24 Feb 2025 12:50:54 +0800 Subject: [PATCH 2/4] dot kernel --- src/flag_gems/ops/dot.py | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/src/flag_gems/ops/dot.py b/src/flag_gems/ops/dot.py index d18408ce1..5b0494159 100644 --- a/src/flag_gems/ops/dot.py +++ b/src/flag_gems/ops/dot.py @@ -11,6 +11,27 @@ from ..utils import triton_lang_extension as tle +@libentry() +@triton.jit +def dot_kernel( + x_ptr, + y_ptr, + out_ptr, + N, + BLOCK_SIZE: tl.constexpr +): + pid = tle.program_id(0) + block_start = pid * BLOCK_SIZE + + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < N + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + + partial_sum = tl.sum(x * y) + + 
tl.atomic_add(out_ptr, partial_sum) @libentry() @@ -57,9 +78,10 @@ def dot(x, y): assert x.shape == y.shape, "Input vectors must have the same shape" assert x.dim() == 1, "Input must be 1D tensors" - N = x.shape[0] - + N = x.shape[0] block_size = triton.next_power_of_2(math.ceil(math.sqrt(N))) + + # if N <= 2560000: mid_size = triton.cdiv(N, block_size) block_mid = triton.next_power_of_2(mid_size) @@ -73,6 +95,19 @@ def dot(x, y): dot_kernel_1[grid_1](x, y, mid, N, block_size) dot_kernel_2[grid_2](mid, out, mid_size, block_mid) + # else: + # grid_size = triton.cdiv(N, block_size) + # grid = (grid_size,1,1) + + # with torch_device_fn.device(x.device): + # if x.dtype != torch.float32: + # out = torch.zeros([], dtype=torch.float32, device=x.device) + # dot_kernel[grid](x, y, out, N, block_size) + # out = out.to(x.dtype) + # else: + # out = torch.zeros([], dtype=x.dtype, device=x.device) + # dot_kernel[grid](x, y, out, N, block_size) + return out From d452bc395bc82b8273181497bedb5c247991e30b Mon Sep 17 00:00:00 2001 From: wlxjhyf Date: Tue, 25 Feb 2025 13:20:36 +0000 Subject: [PATCH 3/4] fix_format_error --- benchmark/test_reduction_perf.py | 16 ++++++------- src/flag_gems/ops/__init__.py | 2 +- src/flag_gems/ops/dot.py | 40 ++++++++------------------------ tests/test_reduction_ops.py | 8 ++----- 4 files changed, 21 insertions(+), 45 deletions(-) diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py index 325e6d192..2b48fffd8 100644 --- a/benchmark/test_reduction_perf.py +++ b/benchmark/test_reduction_perf.py @@ -12,7 +12,6 @@ Config, GenericBenchmark, GenericBenchmark2DOnly, - SkipVersion, generate_tensor_input, unary_input_fn, ) @@ -227,16 +226,17 @@ def dot_input_fn(shape, dtype, device): if inp.dim() > 1: inp = inp.flatten() yield inp, inp - + bench = GenericBenchmark( - input_fn = dot_input_fn, - op_name = "dot", - torch_op = torch.dot, - dtypes = FLOAT_DTYPES, + input_fn=dot_input_fn, + op_name="dot", + torch_op=torch.dot, + dtypes=FLOAT_DTYPES, ) - + bench.run() - + + class quantileBenchmark(GenericBenchmark): def set_more_shapes(self): more_shapes_1d = [(4,), (1024,), (65535)] diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 0f520750d..26b01521c 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -31,6 +31,7 @@ from .diag_embed import diag_embed from .diagonal import diagonal_backward from .div import div_mode, floor_divide, remainder, true_divide +from .dot import dot from .dropout import native_dropout from .embedding import embedding from .eq import eq, eq_scalar @@ -138,7 +139,6 @@ from .where import where_scalar_other, where_scalar_self, where_self, where_self_out from .zeros import zeros from .zeros_like import zeros_like -from .dot import dot __all__ = [ "log_sigmoid", diff --git a/src/flag_gems/ops/dot.py b/src/flag_gems/ops/dot.py index 5b0494159..ca877e4d7 100644 --- a/src/flag_gems/ops/dot.py +++ b/src/flag_gems/ops/dot.py @@ -5,7 +5,6 @@ import triton import triton.language as tl -from .. 
import runtime from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -13,41 +12,29 @@ @libentry() @triton.jit -def dot_kernel( - x_ptr, - y_ptr, - out_ptr, - N, - BLOCK_SIZE: tl.constexpr -): +def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr): pid = tle.program_id(0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) - + mask = offsets < N x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) partial_sum = tl.sum(x * y) - + tl.atomic_add(out_ptr, partial_sum) @libentry() @triton.jit -def dot_kernel_1( - x_ptr, - y_ptr, - mid_ptr, - N, - BLOCK_SIZE: tl.constexpr -): +def dot_kernel_1(x_ptr, y_ptr, mid_ptr, N, BLOCK_SIZE: tl.constexpr): pid = tle.program_id(0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) - + mask = offsets < N x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) @@ -58,12 +45,7 @@ def dot_kernel_1( @libentry() @triton.jit -def dot_kernel_2( - mid_ptr, - out_ptr, - M, - BLOCK_MID: tl.constexpr -): +def dot_kernel_2(mid_ptr, out_ptr, M, BLOCK_MID: tl.constexpr): offset = tl.arange(0, BLOCK_MID) mid = mid_ptr + offset mask = offset < M @@ -74,17 +56,17 @@ def dot_kernel_2( def dot(x, y): logging.debug("Triton Dot Product") - + assert x.shape == y.shape, "Input vectors must have the same shape" assert x.dim() == 1, "Input must be 1D tensors" - N = x.shape[0] + N = x.shape[0] block_size = triton.next_power_of_2(math.ceil(math.sqrt(N))) - + # if N <= 2560000: mid_size = triton.cdiv(N, block_size) block_mid = triton.next_power_of_2(mid_size) - + grid_1 = (mid_size, 1, 1) grid_2 = (1, 1, 1) @@ -109,5 +91,3 @@ def dot(x, y): # dot_kernel[grid](x, y, out, N, block_size) return out - - diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index de294c28e..f9a1dda5f 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -952,10 +952,6 @@ def test_accuracy_depthwise2d( gems_assert_close(res_out, ref_out, dtype) - - - - INDEX_PUT_SHAPE_ACC_FALSE = ( ((2**28,), ((2**16,),), (2**16,)), ((32, 32), ((8,), (8,)), (8,)), @@ -1055,8 +1051,8 @@ def test_accuracy_mse_loss(shape, dtype, reduction): with flag_gems.use_gems(): res_out = torch.nn.functional.mse_loss(inp, target, reduction=reduction) gems_assert_close(res_out, ref_out, dtype, equal_nan=True, reduce_dim=shape[dim]) - - + + @pytest.mark.dot @pytest.mark.parametrize("shape", UT_SHAPES_1D) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) From 7d9f6035f10bd352bfcd23447835b3d1f661e616 Mon Sep 17 00:00:00 2001 From: wlxjhyf Date: Wed, 26 Feb 2025 22:31:14 +0800 Subject: [PATCH 4/4] fix with single kernel in small input --- src/flag_gems/ops/dot.py | 62 +++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/src/flag_gems/ops/dot.py b/src/flag_gems/ops/dot.py index ca877e4d7..77818ad9b 100644 --- a/src/flag_gems/ops/dot.py +++ b/src/flag_gems/ops/dot.py @@ -9,7 +9,6 @@ from ..utils import libentry from ..utils import triton_lang_extension as tle - @libentry() @triton.jit def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr): @@ -22,9 +21,8 @@ def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr): x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - 
partial_sum = tl.sum(x * y) - - tl.atomic_add(out_ptr, partial_sum) + sum = tl.sum(x * y) + tl.store(out_ptr, sum) @libentry() @@ -61,33 +59,33 @@ def dot(x, y): assert x.dim() == 1, "Input must be 1D tensors" N = x.shape[0] - block_size = triton.next_power_of_2(math.ceil(math.sqrt(N))) - - # if N <= 2560000: - mid_size = triton.cdiv(N, block_size) - block_mid = triton.next_power_of_2(mid_size) - - grid_1 = (mid_size, 1, 1) - grid_2 = (1, 1, 1) - - mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device) - out = torch.empty([], dtype=x.dtype, device=x.device) - - with torch_device_fn.device(x.device): - dot_kernel_1[grid_1](x, y, mid, N, block_size) - dot_kernel_2[grid_2](mid, out, mid_size, block_mid) - - # else: - # grid_size = triton.cdiv(N, block_size) - # grid = (grid_size,1,1) - - # with torch_device_fn.device(x.device): - # if x.dtype != torch.float32: - # out = torch.zeros([], dtype=torch.float32, device=x.device) - # dot_kernel[grid](x, y, out, N, block_size) - # out = out.to(x.dtype) - # else: - # out = torch.zeros([], dtype=x.dtype, device=x.device) - # dot_kernel[grid](x, y, out, N, block_size) + + # Only when N is less than TRITON_MAX_TENSOR_NUMEL can it be processed with a single kernel, and performance is better when N < 4096 + if N >= 4096: + block_size = triton.next_power_of_2(math.ceil(math.sqrt(N))) + + mid_size = triton.cdiv(N, block_size) + block_mid = triton.next_power_of_2(mid_size) + + grid_1 = (mid_size, 1, 1) + grid_2 = (1, 1, 1) + + mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device) + out = torch.empty([], dtype=x.dtype, device=x.device) + + with torch_device_fn.device(x.device): + dot_kernel_1[grid_1](x, y, mid, N, block_size) + dot_kernel_2[grid_2](mid, out, mid_size, block_mid) + + else: + block_size = triton.next_power_of_2(math.ceil(N)) + + grid = (1, 1, 1) + + out = torch.empty([], dtype=torch.float32, device=x.device) + + with torch_device_fn.device(x.device): + dot_kernel[grid](x, y, out, N, block_size) + out = out.to(x.dtype) return out
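
The final version of dot.py chooses between two strategies by input length: for N >= 4096 it runs a two-stage reduction, where dot_kernel_1 writes one float32 partial sum per block of roughly sqrt(N) elements into the mid buffer and dot_kernel_2 sums that buffer in a single program, while for smaller inputs a single dot_kernel program reduces everything at once and the result is cast back to the input dtype. The sketch below is only an illustration of that two-stage reduction layout in plain PyTorch (the block-size choice mirrors the kernel's next_power_of_2(ceil(sqrt(N)))); it is not the Triton implementation itself and is not part of this patch series.

    import math
    import torch

    def two_stage_dot_reference(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Plain-PyTorch illustration of the two-stage reduction used by the
        # Triton kernels above: stage 1 produces one float32 partial sum per
        # block, stage 2 sums those partial sums. Illustration only.
        assert x.shape == y.shape and x.dim() == 1
        n = x.shape[0]
        v = math.ceil(math.sqrt(n))
        block_size = 1 << (v - 1).bit_length()          # next power of 2 >= ceil(sqrt(N))
        prod = x.to(torch.float32) * y.to(torch.float32)
        pad = (-n) % block_size                          # zero-pad, mirroring the masked loads
        prod = torch.nn.functional.pad(prod, (0, pad))
        partial = prod.view(-1, block_size).sum(dim=1)   # stage 1: per-block sums ("mid")
        return partial.sum().to(x.dtype)                 # stage 2: final reduction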
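
For reference, the new op is reached through the aten registration added in src/flag_gems/__init__.py; the snippet below mirrors the accuracy test added in tests/test_reduction_ops.py and assumes flag_gems is installed with a supported Triton device.

    import torch
    import flag_gems

    x = torch.randn(1024, dtype=torch.float16, device=flag_gems.device)
    y = torch.randn(1024, dtype=torch.float16, device=flag_gems.device)

    with flag_gems.use_gems():      # route torch.dot to the Triton kernels registered above
        res = torch.dot(x, y)

    ref = torch.dot(x.to(torch.float32), y.to(torch.float32))  # float32 reference
    print(res.item(), ref.item())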