Add Op dot #430

Open
wants to merge 6 commits into base: master
18 changes: 18 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -219,6 +219,24 @@ def count_nonzero_input_fn(shape, dtype, device):
    bench.run()


@pytest.mark.dot
def test_perf_dot():
    def dot_input_fn(shape, dtype, device):
        inp = generate_tensor_input(shape, dtype=dtype, device=device)
        if inp.dim() > 1:
            inp = inp.flatten()
        yield inp, inp

    bench = GenericBenchmark(
        input_fn=dot_input_fn,
        op_name="dot",
        torch_op=torch.dot,
        dtypes=FLOAT_DTYPES,
    )

    bench.run()


class quantileBenchmark(GenericBenchmark):
    def set_more_shapes(self):
        more_shapes_1d = [(4,), (1024,), (65535)]
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -216,6 +216,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
("logical_and", logical_and, Autograd.disable),
("logical_xor", logical_xor, Autograd.disable),
("logical_not", logical_not, Autograd.disable),
("dot", dot, Autograd.disable),
("index_put", index_put, Autograd.disable),
("log_sigmoid", log_sigmoid, Autograd.disable),
("vdot", vdot, Autograd.disable),
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -31,6 +31,7 @@
from .diag_embed import diag_embed
from .diagonal import diagonal_backward
from .div import div_mode, floor_divide, remainder, true_divide
from .dot import dot
from .dropout import native_dropout
from .embedding import embedding
from .eq import eq, eq_scalar
@@ -303,6 +304,7 @@
    "logical_xor",
    "logical_not",
    "sort",
    "dot",
    "nll_loss_forward",
    "nll_loss_backward",
    "nll_loss2d_forward",
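Note: with the two registrations above, `aten::dot` is routed to the Triton implementation whenever FlagGems is enabled. A minimal usage sketch (assuming a device supported by FlagGems; the shape and dtype below are arbitrary, chosen only for illustration):

```python
import torch
import flag_gems

# Arbitrary example inputs on the FlagGems device.
x = torch.randn(8192, dtype=torch.float16, device=flag_gems.device)
y = torch.randn(8192, dtype=torch.float16, device=flag_gems.device)

ref = torch.dot(x, y)        # eager PyTorch reference, outside the context
with flag_gems.use_gems():   # inside the context, torch.dot dispatches to the Triton op
    res = torch.dot(x, y)
```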
91 changes: 91 additions & 0 deletions src/flag_gems/ops/dot.py
@@ -0,0 +1,91 @@
import logging
import math

import torch
import triton
import triton.language as tl

from ..runtime import torch_device_fn
from ..utils import libentry
from ..utils import triton_lang_extension as tle

@libentry()
@triton.jit
def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tle.program_id(0)
    block_start = pid * BLOCK_SIZE

    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    out_val = tl.sum(x * y)
    tl.store(out_ptr, out_val)


@libentry()
@triton.jit
def dot_kernel_1(x_ptr, y_ptr, mid_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tle.program_id(0)
    block_start = pid * BLOCK_SIZE

    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    partial_sum = tl.sum(x * y)
    tl.store(mid_ptr + pid, partial_sum)


@libentry()
@triton.jit
def dot_kernel_2(mid_ptr, out_ptr, M, BLOCK_MID: tl.constexpr):
    offset = tl.arange(0, BLOCK_MID)
    mid = mid_ptr + offset
    mask = offset < M
    mid_val = tl.load(mid, mask=mask, other=0.0)
    out_val = tl.sum(mid_val)
    tl.store(out_ptr, out_val)


def dot(x, y):
    logging.debug("Triton Dot Product")

    assert x.shape == y.shape, "Input vectors must have the same shape"
    assert x.dim() == 1, "Inputs must be 1D tensors"

    N = x.shape[0]

    # A single kernel can only handle N < TRITON_MAX_TENSOR_NUMEL; it is also the faster path when N < 4096.
    if N >= 4096:
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(N)))

        mid_size = triton.cdiv(N, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        grid_1 = (mid_size, 1, 1)
        grid_2 = (1, 1, 1)

        mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
        out = torch.empty([], dtype=x.dtype, device=x.device)

        with torch_device_fn.device(x.device):
            dot_kernel_1[grid_1](x, y, mid, N, block_size)
            dot_kernel_2[grid_2](mid, out, mid_size, block_mid)

    else:
        block_size = triton.next_power_of_2(math.ceil(N))

        grid = (1, 1, 1)

        out = torch.empty([], dtype=torch.float32, device=x.device)

        with torch_device_fn.device(x.device):
            dot_kernel[grid](x, y, out, N, block_size)
        out = out.to(x.dtype)

    return out
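For the two-stage branch, the launch configuration follows directly from `N`. A small worked example of the sizing (the value of `N` is arbitrary, chosen only for illustration):

```python
import math
import triton

# Worked sizing example for the N >= 4096 branch.
N = 1_000_000
block_size = triton.next_power_of_2(math.ceil(math.sqrt(N)))  # 1024 elements per program
mid_size = triton.cdiv(N, block_size)                         # 977 partial sums from dot_kernel_1
block_mid = triton.next_power_of_2(mid_size)                  # 1024-wide reduction in dot_kernel_2
print(block_size, mid_size, block_mid)                        # 1024 977 1024
```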
17 changes: 17 additions & 0 deletions tests/test_reduction_ops.py
@@ -14,6 +14,7 @@
    REDUCTION_SHAPES,
    REDUCTION_SMALL_SHAPES,
    SHAPE_STRIDES,
    UT_SHAPES_1D,
    SkipVersion,
    gems_assert_close,
    gems_assert_equal,
@@ -1050,3 +1051,19 @@ def test_accuracy_mse_loss(shape, dtype, reduction):
    with flag_gems.use_gems():
        res_out = torch.nn.functional.mse_loss(inp, target, reduction=reduction)
    gems_assert_close(res_out, ref_out, dtype, equal_nan=True, reduce_dim=shape[dim])


@pytest.mark.dot
@pytest.mark.parametrize("shape", UT_SHAPES_1D)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_dot_tensor_tensor(shape, dtype):
    inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
    inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
    ref_inp1 = to_reference(inp1, False)
    ref_inp2 = to_reference(inp2, False)

    ref_out = torch.dot(ref_inp1, ref_inp2)
    with flag_gems.use_gems():
        res_out = torch.dot(inp1, inp2)

    gems_assert_close(res_out, ref_out, dtype, equal_nan=True)
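Since `dot` is also exported from `flag_gems.ops`, it can be spot-checked without going through the dispatcher. A rough sketch (the size and tolerances here are ad hoc, not the ones `gems_assert_close` derives from the dtype):

```python
import torch
import flag_gems
from flag_gems.ops import dot  # exported in src/flag_gems/ops/__init__.py by this PR

x = torch.randn(4096, dtype=torch.float32, device=flag_gems.device)
y = torch.randn(4096, dtype=torch.float32, device=flag_gems.device)

# Compare the Triton op directly against eager torch.dot on the same inputs.
torch.testing.assert_close(dot(x, y), torch.dot(x, y), rtol=1e-4, atol=1e-4)
```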