Skip to content

Commit

Permalink
cuda-python-bugfix: allow passage of cuda stream objects or pointers
Browse files Browse the repository at this point in the history
Pointers are dangerous because they aren't reference counted and so can be deleted before
cufinufft cleans up (#417). This patch allows the user to pass either a stream object OR a
pointer in the `gpu_stream` argument of Plan
  • Loading branch information
blackwer committed Jan 22, 2024
1 parent 7633d26 commit 410e9a6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
16 changes: 16 additions & 0 deletions python/cufinufft/cufinufft/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ def get_array_module(obj):
return "generic"


def get_stream_ptr(obj):
framework = get_array_module(obj)
if isinstance(obj, int):
return obj
if framework == 'numba':
return obj.handle.value
if framework == 'pycuda':
return obj.handle
if framework == 'torch':
return obj.cuda_stream
# Unknown / generic / cupy
if hasattr(obj, 'ptr'):
return obj.ptr
raise TypeError("Unknown cuda stream pointer type")


def get_array_size(obj):
array_module = get_array_module(obj)

Expand Down
16 changes: 11 additions & 5 deletions python/cufinufft/cufinufft/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,26 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None,
# Extract list of valid field names.
field_names = [name for name, _ in self._opts._fields_]

# Initialize a list for references to objects
# we want to keep around for life of instance.
self._references = []

# Assign field names from kwargs if they match up, otherwise error.
for k, v in kwargs.items():
if k in field_names:
setattr(self._opts, k, v)
if k == 'gpu_stream':
# need reference to stream object to prevent stream cleanup before plan object
self._references.append(v)
stream_ptr = _compat.get_stream_ptr(v)
setattr(self._opts, k, stream_ptr)
else:
setattr(self._opts, k, v)
else:
raise TypeError(f"Invalid option '{k}'")

# Initialize the plan.
self._init_plan()

# Initialize a list for references to objects
# we want to keep around for life of instance.
self._references = []

@staticmethod
def _default_opts():
"""
Expand Down
12 changes: 3 additions & 9 deletions python/cufinufft/examples/example3d2many_async_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test():

# Initialize the plan and set the points.
plan_stream = cupy.cuda.Stream(null=True)
plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=plan_stream.ptr)
plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=plan_stream)
plan.setpts(cupy.array(x), cupy.array(y), cupy.array(z))

# Using a simple front/back buffer approach. backbuffer is for DtoH transfers, and front for
Expand All @@ -54,9 +54,9 @@ def test():
back_fk_gpu = cupy.empty(fk_all[0].shape, fk_all[0].dtype)
front_c_gpu = cupy.empty(c_all_async[0].shape, c_all_async[0].dtype)
back_c_gpu = cupy.empty(c_all_async[0].shape, c_all_async[0].dtype)
front_plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=front_stream.ptr)
front_plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=front_stream)
front_plan.setpts(cupy.array(x), cupy.array(y), cupy.array(z))
back_plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=back_stream.ptr)
back_plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=back_stream)
back_plan.setpts(cupy.array(x), cupy.array(y), cupy.array(z))

# Run with async
Expand Down Expand Up @@ -101,10 +101,4 @@ def test():
print(f"speedup (naive / sync): {round(naive_time / sync_time, 2)}")
print(f"speedup (naive / async): {round(naive_time / async_time, 2)}")

# Since plans carry raw stream pointers which aren't reference counted, we need to make
# sure they're deleted before the stream objects that hold them. Otherwise, the stream
# might be deleted before cufinufft can use it in the deletion routines. Manually clear
# them out here.
del plan, front_plan, back_plan

test()

0 comments on commit 410e9a6

Please sign in to comment.