Add Op dot #430

Open · wants to merge 6 commits into base: master

Changes from 1 commit
19 changes: 19 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -11,6 +11,7 @@
Benchmark,
Config,
GenericBenchmark2DOnly,
GenericBenchmark,
SkipVersion,
generate_tensor_input,
unary_input_fn,
@@ -186,3 +187,21 @@ def count_nonzero_input_fn(shape, dtype, device):
dtypes=FLOAT_DTYPES,
)
bench.run()


@pytest.mark.dot
def test_perf_dot():
    def dot_input_fn(shape, dtype, device):
        inp = generate_tensor_input(shape, dtype=dtype, device=device)
        if inp.dim() > 1:
            inp = inp.flatten()
        yield inp, inp

    bench = GenericBenchmark(
        input_fn=dot_input_fn,
        op_name="dot",
        torch_op=torch.dot,
        dtypes=FLOAT_DTYPES,
    )
    bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -209,6 +209,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
("logical_and", logical_and, Autograd.disable),
("logical_xor", logical_xor, Autograd.disable),
("logical_not", logical_not, Autograd.disable),
("dot", dot, Autograd.disable),
),
user_unused_ops_list=[] if unused is None else unused,
lib=lib,
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -127,6 +127,7 @@
from .where import where_scalar_other, where_scalar_self, where_self, where_self_out
from .zeros import zeros
from .zeros_like import zeros_like
from .dot import dot

__all__ = [
"all",
@@ -289,4 +290,5 @@
"logical_xor",
"logical_not",
"sort",
"dot",
]
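
Registering "dot" in both modules above is what routes torch.dot to the Triton kernel when flag_gems is active. A minimal usage sketch (illustrative only; assumes flag_gems is installed and a supported device is available):

import torch

import flag_gems

# example 1-D float tensors; any matching shapes and float dtypes work
a = torch.randn(1024, device=flag_gems.device)
b = torch.randn(1024, device=flag_gems.device)

with flag_gems.use_gems():  # routes torch.dot calls to the registered Triton op
    result = torch.dot(a, b)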
78 changes: 78 additions & 0 deletions src/flag_gems/ops/dot.py
@@ -0,0 +1,78 @@
import logging
import math

import torch
import triton
import triton.language as tl

from .. import runtime
from ..runtime import torch_device_fn
from ..utils import libentry
from ..utils import triton_lang_extension as tle




@libentry()
@triton.jit
def dot_kernel_1(
    x_ptr,
    y_ptr,
    mid_ptr,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    block_start = pid * BLOCK_SIZE

    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    partial_sum = tl.sum(x * y)
    tl.store(mid_ptr + pid, partial_sum)


@libentry()
@triton.jit
def dot_kernel_2(
    mid_ptr,
    out_ptr,
    M,
    BLOCK_MID: tl.constexpr,
):
    offset = tl.arange(0, BLOCK_MID)
    mid = mid_ptr + offset
    mask = offset < M
    mid_val = tl.load(mid, mask=mask, other=0.0)
    out_val = tl.sum(mid_val)
    tl.store(out_ptr, out_val)


def dot(x, y):
    logging.debug("Triton Dot Product")

    assert x.shape == y.shape, "Input vectors must have the same shape"
    assert x.dim() == 1, "Inputs must be 1D tensors"

    N = x.shape[0]

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(N)))
    mid_size = triton.cdiv(N, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    grid_1 = (mid_size, 1, 1)
    grid_2 = (1, 1, 1)

    mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
    out = torch.empty([], dtype=x.dtype, device=x.device)

    with torch_device_fn.device(x.device):
        dot_kernel_1[grid_1](x, y, mid, N, block_size)
        dot_kernel_2[grid_2](mid, out, mid_size, block_mid)
Collaborator:
I think it's better to take tensor stride into consideration, but it's a good implementation!
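
As an illustration of the stride point (not part of this PR; the extra stride arguments and the kernel name are hypothetical), the first-pass kernel could fold each input's element stride into its load offsets:

import triton
import triton.language as tl


@triton.jit
def dot_kernel_1_strided(
    x_ptr,
    y_ptr,
    mid_ptr,
    N,
    x_stride,  # element stride of x, e.g. x.stride(0) passed from the host
    y_stride,  # element stride of y
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    # multiplying offsets by the stride makes non-contiguous 1-D views read correctly
    x = tl.load(x_ptr + offsets * x_stride, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets * y_stride, mask=mask, other=0.0).to(tl.float32)
    tl.store(mid_ptr + pid, tl.sum(x * y))

The host wrapper would then pass x.stride(0) and y.stride(0) at launch; the simpler alternative is to call .contiguous() on non-contiguous inputs before the existing kernels.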

Contributor:
Can we resort to a single persistent kernel when the input numel is small enough?

Collaborator:
reasonable.

Author @wlxjhyf commented on Feb 24, 2025:
I tried to implement dot in a single kernel using atomic_add, but even for very small input numel the performance was not good. I still kept that code in the function dot_kernel.
[screenshot: performance of kernel1 + kernel2]
[screenshot: performance of the single kernel]

Contributor:
I probably didn't make myself clear. What I suggested is adding a one-pass branch to handle small inputs. We don't have to use atomic_add on either branch; the two-pass branch still exists.

Author:
Understood!


    return out
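
Regarding the atomic_add experiment mentioned in the thread above: the PR's actual dot_kernel is not shown in this commit, but a single-kernel atomic_add version generally has the shape sketched below (illustrative only; the kernel name is hypothetical). The contention of every program on the single output slot is what tends to hurt performance.

import triton
import triton.language as tl


@triton.jit
def dot_kernel_atomic(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    # every program accumulates its partial sum into the same output element
    tl.atomic_add(out_ptr, tl.sum(x * y))

On the host side the output has to be a zero-initialized float32 scalar (e.g. torch.zeros([], dtype=torch.float32, device=x.device)), since the partial sums accumulate into it.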


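Picking up the one-pass suggestion from the review thread: a minimal sketch of how a small-input branch could sit alongside the existing two-pass path. The threshold and the dot_kernel_one_pass name are hypothetical; dot_kernel_1 and dot_kernel_2 refer to the kernels defined in the diff above.

import math

import torch
import triton
import triton.language as tl


@triton.jit
def dot_kernel_one_pass(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    # a single program reduces the whole vector in one shot
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    tl.store(out_ptr, tl.sum(x * y))


ONE_PASS_MAX_N = 4096  # hypothetical cut-off; would need benchmarking


def dot_with_one_pass_branch(x, y):
    N = x.shape[0]
    out = torch.empty([], dtype=x.dtype, device=x.device)
    if N <= ONE_PASS_MAX_N:
        # small input: one launch, one program, no intermediate buffer
        dot_kernel_one_pass[(1, 1, 1)](x, y, out, N, triton.next_power_of_2(N))
    else:
        # unchanged two-pass path from dot.py above
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(N)))
        mid_size = triton.cdiv(N, block_size)
        mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
        dot_kernel_1[(mid_size, 1, 1)](x, y, mid, N, block_size)
        dot_kernel_2[(1, 1, 1)](mid, out, mid_size, triton.next_power_of_2(mid_size))
    return out
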
17 changes: 17 additions & 0 deletions tests/test_reduction_ops.py
@@ -13,6 +13,7 @@
REDUCTION_SHAPES,
REDUCTION_SMALL_SHAPES,
SHAPE_STRIDES,
UT_SHAPES_1D,
SkipVersion,
gems_assert_close,
gems_assert_equal,
@@ -874,3 +875,19 @@ def test_accuracy_depthwise2d(
inp, weight, kernel, bias=None, stride=stride, padding=padding, dilation=1
)
gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.dot
@pytest.mark.parametrize("shape", UT_SHAPES_1D)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_dot_tensor_tensor(shape, dtype):
    inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
    inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
    ref_inp1 = to_reference(inp1, False)
    ref_inp2 = to_reference(inp2, False)

    ref_out = torch.dot(ref_inp1, ref_inp2)
    with flag_gems.use_gems():
        res_out = torch.dot(inp1, inp2)

    gems_assert_close(res_out, ref_out, dtype, equal_nan=True)