-
Notifications
You must be signed in to change notification settings - Fork 8
Implements copy function #113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
894410f
to
c2ca89c
Compare
940e3e9
to
3853f82
Compare
This PR will have to update examples that use |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, there are a few other places where we use iris.put
. But, I think there something about the semantics of the copy is not correct.
@@ -31,7 +31,7 @@ def get_kernel( | |||
# Loop over all ranks, get the stored data. | |||
# load to local register, accumulate. | |||
for target_rank in range(num_ranks): | |||
iris.get(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask) | |||
iris.copy(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this be:
iris.copy(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask) | |
iris.copy(data + offsets, results + offsets, target_rank, cur_rank, heap_bases, mask=mask) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code I suggest fails the test btw but it shouldn't according to the docstring.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume here the from_rank is always the current rank. I think if we want to allow for interchanging then perhaps this would more appropriate ? let me know what you think
@triton.jit
def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
assert cur_rank == from_rank or cur_rank == to_rank, "Cannot copy between two arbitrary ranks"
cur_base = tl.load(heap_bases + cur_rank)
from_base = tl.load(heap_bases + from_rank)
to_base = tl.load(heap_bases + to_rank)
src_ptr_int = tl.cast(src_ptr, tl.uint64)
src_offset = src_ptr_int - cur_base
dst_ptr_int = tl.cast(dst_ptr, tl.uint64)
dst_offset = dst_ptr_int - cur_base
from_base_byte = tl.cast(from_base, tl.pointer_type(tl.int8))
to_base_byte = tl.cast(to_base , tl.pointer_type(tl.int8))
translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype)
translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype)
data = tl.load(translated_src, mask=mask)
tl.store(translated_dst, data, mask=mask)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The solution you proposed here is good but adds that additional overhead of the two translates. I have been thinking about this and I am not sure if there is away to resolve this cleanly.
I don’t really like the put/get names but maybe we will just stick to them for now. Let’s keep this PR open for now and we can come back to it later if we get better ideas. Thanks for your time looking into this and sorry this feature was not very well thought through.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi no worries at all! Thanks a lot for taking the time to review my solution!
Motivation
Closes #98
Implements copy. I keep the get and put as wrappers to the copy function so that the tests pass
Technical Details
Test Plan
Test Result
Submission Checklist