diff --git a/python/cufinufft/cufinufft/_compat.py b/python/cufinufft/cufinufft/_compat.py index 04e066a1a..3e743dd16 100644 --- a/python/cufinufft/cufinufft/_compat.py +++ b/python/cufinufft/cufinufft/_compat.py @@ -27,6 +27,22 @@ def get_array_module(obj): return "generic" +def get_stream_ptr(obj): + framework = get_array_module(obj) + if isinstance(obj, int): + return obj + if framework == 'numba': + return obj.handle.value + if framework == 'pycuda': + return obj.handle + if framework == 'torch': + return obj.cuda_stream + # Unknown / generic / cupy + if hasattr(obj, 'ptr'): + return obj.ptr + raise TypeError("Unknown cuda stream pointer type") + + def get_array_size(obj): array_module = get_array_module(obj) diff --git a/python/cufinufft/cufinufft/_plan.py b/python/cufinufft/cufinufft/_plan.py index 70973ba00..5e264b70a 100644 --- a/python/cufinufft/cufinufft/_plan.py +++ b/python/cufinufft/cufinufft/_plan.py @@ -125,20 +125,26 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None, # Extract list of valid field names. field_names = [name for name, _ in self._opts._fields_] + # Initialize a list for references to objects + # we want to keep around for life of instance. + self._references = [] + # Assign field names from kwargs if they match up, otherwise error. for k, v in kwargs.items(): if k in field_names: - setattr(self._opts, k, v) + if k == 'gpu_stream': + # need reference to stream object to prevent stream cleanup before plan object + self._references.append(v) + stream_ptr = _compat.get_stream_ptr(v) + setattr(self._opts, k, stream_ptr) + else: + setattr(self._opts, k, v) else: raise TypeError(f"Invalid option '{k}'") # Initialize the plan. self._init_plan() - # Initialize a list for references to objects - # we want to keep around for life of instance. - self._references = [] - @staticmethod def _default_opts(): """ diff --git a/python/cufinufft/examples/example3d2many_async_cupy.py b/python/cufinufft/examples/example3d2many_async_cupy.py index 6fbae36d1..64744468e 100644 --- a/python/cufinufft/examples/example3d2many_async_cupy.py +++ b/python/cufinufft/examples/example3d2many_async_cupy.py @@ -44,7 +44,7 @@ def test(): # Initialize the plan and set the points. plan_stream = cupy.cuda.Stream(null=True) - plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=plan_stream.ptr) + plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=plan_stream) plan.setpts(cupy.array(x), cupy.array(y), cupy.array(z)) # Using a simple front/back buffer approach. backbuffer is for DtoH transfers, and front for @@ -54,9 +54,9 @@ def test(): back_fk_gpu = cupy.empty(fk_all[0].shape, fk_all[0].dtype) front_c_gpu = cupy.empty(c_all_async[0].shape, c_all_async[0].dtype) back_c_gpu = cupy.empty(c_all_async[0].shape, c_all_async[0].dtype) - front_plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=front_stream.ptr) + front_plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=front_stream) front_plan.setpts(cupy.array(x), cupy.array(y), cupy.array(z)) - back_plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=back_stream.ptr) + back_plan = cufinufft.Plan(2, N, n_transf, eps=eps, dtype=complex_dtype, gpu_kerevalmeth=1, gpu_stream=back_stream) back_plan.setpts(cupy.array(x), cupy.array(y), cupy.array(z)) # Run with async @@ -101,10 +101,4 @@ def test(): print(f"speedup (naive / sync): {round(naive_time / sync_time, 2)}") print(f"speedup (naive / async): {round(naive_time / async_time, 2)}") - # Since plans carry raw stream pointers which aren't reference counted, we need to make - # sure they're deleted before the stream objects that hold them. Otherwise, the stream - # might be deleted before cufinufft can use it in the deletion routines. Manually clear - # them out here. - del plan, front_plan, back_plan - test()