Skip to content

Commit

Permalink
Verify that dali->dlpack->torch is zero-copy.
Browse files Browse the repository at this point in the history
Signed-off-by: Michał Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Oct 29, 2024
1 parent 3a73c87 commit 370640a
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion dali/test/python/dlpack/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,25 @@ def _test_pipe():


@attr("pytorch")
def test_dlpack():
def test_dlpack_is_zero_copy():
print("Testing dlpack")
# get a DALI pipeline that produces batches of very large tensors
pipe = _test_pipe(batch_size=1, experimental_exec_dynamic=True)
pipe.build()

s = torch.cuda.Stream(0)
with torch.cuda.stream(s):
(out,) = pipe.run(s)
t0 = torch.from_dlpack(out[0])
t1 = torch.from_dlpack(out[0])
assert t0[0, 0] == 54
assert t1[0, 0] == 54
t0[0, 0] = 12345
assert t1[0, 0] == 12345, "t1 and t0 should point to the same memory"


@attr("pytorch")
def test_dlpack_no_corruption():
print("Testing dlpack")
# get a DALI pipeline that produces batches of very large tensors
pipe = _test_pipe(experimental_exec_dynamic=True)
Expand Down

0 comments on commit 370640a

Please sign in to comment.