diff --git a/src/hexfft/__init__.py b/src/hexfft/__init__.py index 0bf3edc..c948369 100644 --- a/src/hexfft/__init__.py +++ b/src/hexfft/__init__.py @@ -1,2 +1,2 @@ -from .hexfft import fftshift, ifftshift, fft, ifft, FFT +from .hexfft import fft, ifft, FFT from .array import HexArray diff --git a/src/hexfft/hexfft.py b/src/hexfft/hexfft.py index 982a8ad..0be76a6 100644 --- a/src/hexfft/hexfft.py +++ b/src/hexfft/hexfft.py @@ -503,79 +503,6 @@ def mersereau_ifft(PX): return px.T / 4 -def fftshift(X): - N = X.shape[0] - n1, n2 = X.indices - m = hsupport(N, X.pattern).astype(bool) - shifted = HexArray(np.zeros_like(X), X.pattern) - if X.pattern == "oblique": - regI = (n1 < N // 2) & (n2 < N // 2) - regII = m & (n1 < n2) & (n2 >= N // 2) - regIII = m & (n2 <= n1) & (n1 >= N // 2) - - _regI = (n1 >= N // 2) & (n2 >= N // 2) - _regII = m & (n1 >= n2) & (n2 < N // 2) - _regIII = m & (n2 > n1) & (n1 < N // 2) - - shifted[_regI] = X[regI] - shifted[_regII] = X[regII] - shifted[_regIII] = X[regIII] - - elif X.pattern == "offset": - m = m.T - n2 = n2 - N // 4 - regI = m & (n1 < N // 2) & (n2 < N // 2) - regII = m & (n1 <= n2) & (n2 >= N // 2) - regIII = m & (n2 < n1) & (n1 >= N // 2) - - _regI = m & (n1 >= N // 2) & (n2 >= N // 2) - _regII = m & (n1 > n2) & (n2 < N // 2) - _regIII = m & (n2 >= n1) & (n1 < N // 2) - - shifted[_regI.T] = X[regI.T] - shifted[_regII.T] = X[regII.T] - shifted[_regIII.T] = X[regIII.T] - - return shifted - - -def ifftshift(X): - N = X.shape[0] - n1, n2 = X.indices - m = hsupport(N, X.pattern).astype(bool) - shifted = HexArray(np.zeros_like(X), X.pattern) - - if X.pattern == "oblique": - _regI = (n1 < N // 2) & (n2 < N // 2) - _regII = m & (n1 < n2) & (n2 >= N // 2) - _regIII = m & (n2 <= n1) & (n1 >= N // 2) - - regI = (n1 >= N // 2) & (n2 >= N // 2) - regII = m & (n1 >= n2) & (n2 < N // 2) - regIII = m & (n2 > n1) & (n1 < N // 2) - - shifted[_regI] = X[regI] - shifted[_regII] = X[regII] - shifted[_regIII] = X[regIII] - - elif X.pattern == "offset": - m = m.T - n2 = n2 - N // 4 - regI = m & (n1 < N // 2) & (n2 < N // 2) - regII = m & (n1 <= n2) & (n2 >= N // 2) - regIII = m & (n2 < n1) & (n1 >= N // 2) - - _regI = m & (n1 >= N // 2) & (n2 >= N // 2) - _regII = m & (n1 > n2) & (n2 < N // 2) - _regIII = m & (n2 >= n1) & (n1 < N // 2) - - shifted[regI.T] = X[_regI.T] - shifted[regII.T] = X[_regII.T] - shifted[regIII.T] = X[_regIII.T] - - return shifted - - def _hexdft_pgram(px): """""" dtype = px.dtype diff --git a/src/hexfft/utils.py b/src/hexfft/utils.py index 11ae8d0..da46ed0 100644 --- a/src/hexfft/utils.py +++ b/src/hexfft/utils.py @@ -111,14 +111,88 @@ def pgram_to_hex(p, N, pattern="oblique"): return HexArray(h, pattern) -def filter_shift(x): +def fftshift(X): + N = X.shape[0] + n1, n2 = X.indices + m = hsupport(N, X.pattern).astype(bool) + shifted = HexArray(np.zeros_like(X), X.pattern) + if X.pattern == "oblique": + regI = (n1 < N // 2) & (n2 < N // 2) + regII = m & (n1 < n2) & (n2 >= N // 2) + regIII = m & (n2 <= n1) & (n1 >= N // 2) + + _regI = (n1 >= N // 2) & (n2 >= N // 2) + _regII = m & (n1 >= n2) & (n2 < N // 2) + _regIII = m & (n2 > n1) & (n1 < N // 2) + + shifted[_regI] = X[regI] + shifted[_regII] = X[regII] + shifted[_regIII] = X[regIII] + + elif X.pattern == "offset": + m = m.T + n2 = n2 - N // 4 + regI = m & (n1 < N // 2) & (n2 < N // 2) + regII = m & (n1 <= n2) & (n2 >= N // 2) + regIII = m & (n2 < n1) & (n1 >= N // 2) + + _regI = m & (n1 >= N // 2) & (n2 >= N // 2) + _regII = m & (n1 > n2) & (n2 < N // 2) + _regIII = m & (n2 >= n1) & (n1 < N // 2) + + shifted[_regI.T] = X[regI.T] + shifted[_regII.T] = X[regII.T] + shifted[_regIII.T] = X[regIII.T] + + return shifted + + +def ifftshift(X): + N = X.shape[0] + n1, n2 = X.indices + m = hsupport(N, X.pattern).astype(bool) + shifted = HexArray(np.zeros_like(X), X.pattern) + + if X.pattern == "oblique": + _regI = (n1 < N // 2) & (n2 < N // 2) + _regII = m & (n1 < n2) & (n2 >= N // 2) + _regIII = m & (n2 <= n1) & (n1 >= N // 2) + + regI = (n1 >= N // 2) & (n2 >= N // 2) + regII = m & (n1 >= n2) & (n2 < N // 2) + regIII = m & (n2 > n1) & (n1 < N // 2) + + shifted[_regI] = X[regI] + shifted[_regII] = X[regII] + shifted[_regIII] = X[regIII] + + elif X.pattern == "offset": + m = m.T + n2 = n2 - N // 4 + regI = m & (n1 < N // 2) & (n2 < N // 2) + regII = m & (n1 <= n2) & (n2 >= N // 2) + regIII = m & (n2 < n1) & (n1 >= N // 2) + + _regI = m & (n1 >= N // 2) & (n2 >= N // 2) + _regII = m & (n1 > n2) & (n2 < N // 2) + _regIII = m & (n2 >= n1) & (n1 < N // 2) + + shifted[regI.T] = X[_regI.T] + shifted[regII.T] = X[_regII.T] + shifted[regIII.T] = X[_regIII.T] + + return shifted + + +def filter_shift(x, periodicity="rect"): """ - Shift the quadrants of a HexArray to move the origin to/from + Shift the quadrants/thirds of a HexArray to move the origin to/from the center of the grid. Useful if a filter kernel is easier to define the origin at the center of the grid. """ - if not isinstance(x, HexArray): - x = HexArray(x) + + if periodicity == "hex": + return ifftshift(x) N1, N2 = x.shape out = HexArray(np.zeros_like(x), x.pattern) diff --git a/tests/test_api.py b/tests/test_api.py index b707741..4b60c80 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,4 +1,4 @@ -from hexfft import HexArray, fft, ifft, FFT, fftshift, ifftshift +from hexfft import HexArray, fft, ifft, FFT from hexfft.array import rect_shift, rect_unshift import numpy as np import pytest diff --git a/tests/test_hexfft.py b/tests/test_hexfft.py index 10fdb43..c2fe878 100644 --- a/tests/test_hexfft.py +++ b/tests/test_hexfft.py @@ -1,4 +1,4 @@ -from hexfft import fftshift, ifftshift, HexArray +from hexfft import fft, ifft, FFT, HexArray from hexfft.hexfft import ( _hexdft_pgram, _hexidft_pgram, @@ -8,15 +8,14 @@ mersereau_ifft, rect_fft, rect_ifft, - FFT, - fft, - ifft, ) from hexfft.utils import ( hsupport, pgram_to_hex, nice_test_function, hex_to_pgram, + fftshift, + ifftshift, ) from hexfft.array import rect_shift, rect_unshift from hexfft.reference import (