diff --git a/python/cufinufft/cufinufft/_plan.py b/python/cufinufft/cufinufft/_plan.py index 70973ba00..263c12cd7 100644 --- a/python/cufinufft/cufinufft/_plan.py +++ b/python/cufinufft/cufinufft/_plan.py @@ -5,6 +5,8 @@ """ import atexit +import collections.abc +import numbers import sys import warnings @@ -108,8 +110,12 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None, else: raise TypeError("Expected complex64 or complex128.") - if isinstance(n_modes, int): + if isinstance(n_modes, numbers.Integral): n_modes = (n_modes,) + elif isinstance(n_modes, collections.abc.Iterable): + n_modes = tuple(n_modes) + else: + raise ValueError(f"Invalid n_modes '{n_modes}'") self.dim = len(n_modes) self.type = nufft_type