Skip to content

Commit

Permalink
cuda: allow passing opts in cufinufft simple
Browse files Browse the repository at this point in the history
This way, you can specify more advanced options, like `gpu_method` using
the simple interface in the same way as the plan interface.
  • Loading branch information
janden committed Jul 31, 2023
1 parent c714ff2 commit 1f1d6c0
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions python/cufinufft/cufinufft/_simple.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
from cufinufft import Plan

def nufft1d1(x, data, n_modes=None, out=None, eps=1e-6, isign=1):
return _invoke_plan(1, 1, x, None, None, data, out, isign, eps, n_modes)
def nufft1d1(x, data, n_modes=None, out=None, eps=1e-6, isign=1, **kwargs):
return _invoke_plan(1, 1, x, None, None, data, out, isign, eps, n_modes,
kwargs)

def nufft1d2(x, data, out=None, eps=1e-6, isign=-1):
return _invoke_plan(1, 2, x, None, None, data, out, isign, eps)
def nufft1d2(x, data, out=None, eps=1e-6, isign=-1, **kwargs):
return _invoke_plan(1, 2, x, None, None, data, out, isign, eps, None,
kwargs)

def nufft2d1(x, y, data, n_modes=None, out=None, eps=1e-6, isign=1):
return _invoke_plan(2, 1, x, y, None, data, out, isign, eps, n_modes)
def nufft2d1(x, y, data, n_modes=None, out=None, eps=1e-6, isign=1, **kwargs):
return _invoke_plan(2, 1, x, y, None, data, out, isign, eps, n_modes,
kwargs)

def nufft2d2(x, y, data, out=None, eps=1e-6, isign=-1):
return _invoke_plan(2, 2, x, y, None, data, out, isign, eps)
def nufft2d2(x, y, data, out=None, eps=1e-6, isign=-1, **kwargs):
return _invoke_plan(2, 2, x, y, None, data, out, isign, eps, None, kwargs)

def nufft3d1(x, y, z, data, n_modes=None, out=None, eps=1e-6, isign=1):
return _invoke_plan(3, 1, x, y, z, data, out, isign, eps, n_modes)
def nufft3d1(x, y, z, data, n_modes=None, out=None, eps=1e-6, isign=1,
**kwargs):
return _invoke_plan(3, 1, x, y, z, data, out, isign, eps, n_modes, kwargs)

def nufft3d2(x, y, z, data, out=None, eps=1e-6, isign=-1):
return _invoke_plan(3, 2, x, y, z, data, out, isign, eps)
def nufft3d2(x, y, z, data, out=None, eps=1e-6, isign=-1, **kwargs):
return _invoke_plan(3, 2, x, y, z, data, out, isign, eps, None, kwargs)

def _invoke_plan(dim, nufft_type, x, y, z, data, out, isign, eps, n_modes=None):
def _invoke_plan(dim, nufft_type, x, y, z, data, out, isign, eps,
n_modes=None, kwargs=None):
dtype = data.dtype

n_trans = _get_ntrans(dim, nufft_type, data)
Expand All @@ -28,7 +33,7 @@ def _invoke_plan(dim, nufft_type, x, y, z, data, out, isign, eps, n_modes=None):
if nufft_type == 2:
n_modes = data.shape[-dim:]

plan = Plan(nufft_type, n_modes, n_trans, eps, isign, dtype)
plan = Plan(nufft_type, n_modes, n_trans, eps, isign, dtype, **kwargs)

plan.setpts(x, y, z)

Expand Down

0 comments on commit 1f1d6c0

Please sign in to comment.