Add Op dot #430
base: master
@@ -0,0 +1,78 @@
import logging
import math

import torch
import triton
import triton.language as tl

from .. import runtime
from ..runtime import torch_device_fn
from ..utils import libentry
from ..utils import triton_lang_extension as tle


@libentry()
@triton.jit
def dot_kernel_1(
    x_ptr,
    y_ptr,
    mid_ptr,
    N,
    BLOCK_SIZE: tl.constexpr
):
    # First pass: each program reduces one block to a partial dot product.
    pid = tle.program_id(0)
    block_start = pid * BLOCK_SIZE

    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    partial_sum = tl.sum(x * y)
    tl.store(mid_ptr + pid, partial_sum)


@libentry()
@triton.jit
def dot_kernel_2(
    mid_ptr,
    out_ptr,
    M,
    BLOCK_MID: tl.constexpr
):
    # Second pass: a single program sums the partial results into a scalar.
    offset = tl.arange(0, BLOCK_MID)
    mid = mid_ptr + offset
    mask = offset < M
    mid_val = tl.load(mid, mask=mask, other=0.0)
    out_val = tl.sum(mid_val)
    tl.store(out_ptr, out_val)


def dot(x, y):
    logging.debug("Triton Dot Product")

    assert x.shape == y.shape, "Input vectors must have the same shape"
    assert x.dim() == 1, "Inputs must be 1D tensors"

    N = x.shape[0]

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(N)))
    mid_size = triton.cdiv(N, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    grid_1 = (mid_size, 1, 1)
    grid_2 = (1, 1, 1)

    mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
    out = torch.empty([], dtype=x.dtype, device=x.device)

    with torch_device_fn.device(x.device):
        dot_kernel_1[grid_1](x, y, mid, N, block_size)
        dot_kernel_2[grid_2](mid, out, mid_size, block_mid)

    return out
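For reference, a quick correctness check against torch.dot; the import path below is hypothetical and should be adjusted to wherever this op is exposed in the package:

import torch

# Hypothetical import path for the dot function defined in this diff.
from flag_gems.ops.dot import dot

x = torch.randn(10000, device="cuda")
y = torch.randn(10000, device="cuda")

# Compare the Triton result with PyTorch's reference implementation.
torch.testing.assert_close(dot(x, y), torch.dot(x, y), rtol=1e-4, atol=1e-4)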
Review thread on the kernel launches:

Can we resort to a single persistent kernel when the input numel is small enough?

Reasonable.

I probably didn't make myself clear. What I suggested is adding a one-pass branch to handle small inputs. We don't have to use atomic_add in either branch; the two-pass branch still exists.

Understood!
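A rough sketch of the one-pass branch discussed above (not part of this PR): when N is small enough to fit in a single block, one program can load both vectors once and write the scalar result directly, skipping the mid buffer and the second launch. It assumes the same module context as the diff (triton, tl, torch, libentry, tle, torch_device_fn), and the kernel name, helper, and threshold are illustrative; dot would branch on N and keep the existing two-pass path for larger inputs.

@libentry()
@triton.jit
def dot_kernel_one_pass(
    x_ptr,
    y_ptr,
    out_ptr,
    N,
    BLOCK_SIZE: tl.constexpr
):
    # A single program reduces the whole vector in one pass, so no mid
    # buffer and no second kernel launch are needed.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    tl.store(out_ptr, tl.sum(x * y))


ONE_PASS_MAX_N = 8192  # illustrative threshold; would need tuning


def dot_one_pass(x, y, N):
    # Host-side branch for small inputs: one launch, grid of a single program.
    out = torch.empty([], dtype=x.dtype, device=x.device)
    block_size = triton.next_power_of_2(N)
    with torch_device_fn.device(x.device):
        dot_kernel_one_pass[(1, 1, 1)](x, y, out, N, block_size)
    return out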
I think it's better to take tensor stride into consideration, but it's a good implementation!
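The stride point could be addressed in at least two ways, neither of which is in this PR: the simplest is to call x = x.contiguous() and y = y.contiguous() on the host before launching, at the cost of a copy for non-contiguous inputs; the alternative sketched below passes the element strides into the first kernel and folds them into the addressing. The kernel name and signature are illustrative and assume the same module context as the diff.

@libentry()
@triton.jit
def dot_kernel_1_strided(
    x_ptr,
    y_ptr,
    mid_ptr,
    N,
    x_stride,
    y_stride,
    BLOCK_SIZE: tl.constexpr
):
    pid = tle.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    # Scale the logical index by each input's element stride so that
    # non-contiguous 1D views are read correctly.
    x = tl.load(x_ptr + offsets * x_stride, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets * y_stride, mask=mask, other=0.0).to(tl.float32)
    tl.store(mid_ptr + pid, tl.sum(x * y))

The host side would then launch it as dot_kernel_1_strided[grid_1](x, y, mid, N, x.stride(0), y.stride(0), block_size); the second-pass kernel is unchanged since mid is allocated contiguously.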