Skip to content

Commit

Permalink
tests: add explanatory comments
Browse files Browse the repository at this point in the history
  • Loading branch information
janden committed Jul 17, 2023
1 parent 24616a8 commit 32e86b0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/cufinufft/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def test_type1(dtype, shape, M, tol, output_arg):

plan = Plan(1, shape, eps=tol, dtype=complex_dtype)

# Since k_gpu is an array of shape (dim, M), this will expand to
# plan.setpts(k_gpu[0], ..., k_gpu[dim]), allowing us to handle all
# dimensions with the same call.
plan.setpts(*k_gpu)

if output_arg:
Expand Down
6 changes: 6 additions & 0 deletions python/cufinufft/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_simple_type1(dtype, shape, n_trans, M, tol, output_arg):

dim = len(shape)

# Select which function to call based on dimension.
fun = {1: cufinufft.nufft1d1,
2: cufinufft.nufft2d1,
3: cufinufft.nufft3d1}[dim]
Expand All @@ -38,7 +39,11 @@ def test_simple_type1(dtype, shape, n_trans, M, tol, output_arg):
c_gpu = gpuarray.to_gpu(c)

if output_arg:
# Ensure that output array has proper shape i.e., (N1, ...) for no
# batch, (1, N1, ...) for batch of size one, and (n, N1, ...) for
# batch of size n.
fk_gpu = gpuarray.GPUArray(n_trans + shape, dtype=complex_dtype)

fun(*k_gpu, c_gpu, out=fk_gpu, eps=tol)
else:
fk_gpu = fun(*k_gpu, c_gpu, shape, eps=tol)
Expand Down Expand Up @@ -71,6 +76,7 @@ def test_simple_type2(dtype, shape, n_trans, M, tol, output_arg):

if output_arg:
c_gpu = gpuarray.GPUArray(n_trans + (M,), dtype=complex_dtype)

fun(*k_gpu, fk_gpu, eps=tol, out=c_gpu)
else:
c_gpu = fun(*k_gpu, fk_gpu, eps=tol)
Expand Down

0 comments on commit 32e86b0

Please sign in to comment.