diff --git a/sdp/pyfftw_sdp.py b/sdp/pyfftw_sdp.py index bd0a0e9..55e9454 100644 --- a/sdp/pyfftw_sdp.py +++ b/sdp/pyfftw_sdp.py @@ -4,48 +4,120 @@ class SLIDING_DOT_PRODUCT: # https://stackoverflow.com/a/30615425/2955541 - def __init__(self): - self.m = 0 - self.n = 0 - self.threads = 1 - self.rfft_Q_obj = None - self.rfft_T_obj = None - self.irfft_obj = None - - def __call__(self, Q, T): - if Q.shape[0] != self.m or T.shape[0] != self.n: - self.m = Q.shape[0] - self.n = T.shape[0] - shape = pyfftw.next_fast_len(self.n) - self.rfft_Q_obj = pyfftw.builders.rfft( - np.empty(self.m), overwrite_input=True, n=shape, threads=self.threads + def __init__(self, max_n=2**20): + """ + Parameters + ---------- + max_n : int + Maximum length to preallocate arrays for. This will be the size of the + the real-valued array. A complex-valued array of size `1 + (max_n // 2)` + will also be preallocated. + """ + # Preallocate arrays + self.real_arr = pyfftw.empty_aligned(max_n, dtype="float64") + self.complex_arr = pyfftw.empty_aligned(1 + (max_n // 2), dtype="complex128") + + # Store FFTW objects, keyed by (next_fast_n, n_threads, planning_flag) + self.rfft_objects = {} + self.irfft_objects = {} + + def __call__(self, Q, T, n_threads=1, planning_flag="FFTW_MEASURE"): + """ + Compute the sliding dot product between `Q` and `T` using FFTW via pyfftw. + + Parameters + ---------- + Q : numpy.ndarray + Query array or subsequence. + + T : numpy.ndarray + Time series or sequence. + + n_threads : int, default=1 + Number of threads to use for FFTW computations. + + planning_flag : str, default="FFTW_MEASURE" + The planning flag that will be used in FFTW for planning. + See pyfftw documentation for details. Current options include: + "FFTW_ESTIMATE", "FFTW_MEASURE", "FFTW_PATIENT", and "FFTW_EXHAUSTIVE". + + Returns + ------- + out : numpy.ndarray + Sliding dot product between `Q` and `T`. + """ + m = Q.shape[0] + n = T.shape[0] + next_fast_n = pyfftw.next_fast_len(n) + + # Update preallocated arrays if needed + if next_fast_n > len(self.real_arr): + self.real_arr = pyfftw.empty_aligned(next_fast_n, dtype="float64") + self.complex_arr = pyfftw.empty_aligned( + 1 + (next_fast_n // 2), dtype="complex128" ) - self.rfft_T_obj = pyfftw.builders.rfft( - np.empty(self.n), overwrite_input=True, n=shape, threads=self.threads + + real_arr = self.real_arr[:next_fast_n] + complex_arr = self.complex_arr[: 1 + (next_fast_n // 2)] + + # Get or create FFTW objects + key = (next_fast_n, n_threads, planning_flag) + + rfft_obj = self.rfft_objects.get(key, None) + if rfft_obj is None: + rfft_obj = pyfftw.FFTW( + input_array=real_arr, + output_array=complex_arr, + direction="FFTW_FORWARD", + flags=(planning_flag,), + threads=n_threads, ) - self.irfft_obj = pyfftw.builders.irfft( - self.rfft_Q_obj.output_array, - overwrite_input=True, - n=shape, - threads=self.threads, + self.rfft_objects[key] = rfft_obj + else: + rfft_obj.update_arrays(real_arr, complex_arr) + + irfft_obj = self.irfft_objects.get(key, None) + if irfft_obj is None: + irfft_obj = pyfftw.FFTW( + input_array=complex_arr, + output_array=real_arr, + direction="FFTW_BACKWARD", + flags=(planning_flag, "FFTW_DESTROY_INPUT"), + threads=n_threads, ) + self.irfft_objects[key] = irfft_obj + else: + irfft_obj.update_arrays(complex_arr, real_arr) + + # RFFT(T) + real_arr[:n] = T + real_arr[n:] = 0.0 + rfft_obj.execute() # output is in complex_arr + complex_arr_T = complex_arr.copy() + + # RFFT(Q) + # Scale by 1/next_fast_n to account for + # FFTW's unnormalized inverse FFT via execute() + real_arr[:m] = Q[::-1] / next_fast_n + real_arr[m:] = 0.0 + rfft_obj.execute() # output is in complex_arr + + # RFFT(T) * RFFT(Q) + np.multiply(complex_arr, complex_arr_T, out=complex_arr) - Qr = Q[::-1] # Reverse/flip Q - rfft_padded_Q = self.rfft_Q_obj(Qr) - rfft_padded_T = self.rfft_T_obj(T) + # IRFFT (input is in complex_arr) + irfft_obj.execute() # output is in real_arr - return self.irfft_obj(np.multiply(rfft_padded_Q, rfft_padded_T)).real[ - self.m - 1 : self.n - ] + return real_arr[m - 1 : n] _sliding_dot_product = SLIDING_DOT_PRODUCT() -def setup(Q, T): - _sliding_dot_product(Q, T) +def setup(Q, T, n_threads=1, planning_flag="FFTW_MEASURE"): + _sliding_dot_product(Q, T, n_threads=n_threads, planning_flag=planning_flag) return -def sliding_dot_product(Q, T): - return _sliding_dot_product(Q, T) +def sliding_dot_product(Q, T, n_threads=1, planning_flag="FFTW_MEASURE"): + return _sliding_dot_product(Q, T, n_threads=n_threads, planning_flag=planning_flag) diff --git a/test.py b/test.py index c72a6b3..d1b80e8 100644 --- a/test.py +++ b/test.py @@ -192,3 +192,21 @@ def test_setup(): raise e return + + +def test_pyfftw_sdp_max_n(): + # When `len(T)` larger than `max_n` in pyfftw_sdp, + # the internal preallocated arrays should be resized. + # This test checks that functionality. + from sdp.pyfftw_sdp import SLIDING_DOT_PRODUCT + + T = np.random.rand(2**12) + Q = np.random.rand(2**8) + + sliding_dot_product = SLIDING_DOT_PRODUCT(max_n=2**10) + comp = sliding_dot_product(Q, T) + ref = naive_sliding_dot_product(Q, T) + + np.testing.assert_allclose(comp, ref) + + return diff --git a/timing.py b/timing.py index 0bc4af1..7fc2a6c 100755 --- a/timing.py +++ b/timing.py @@ -9,7 +9,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( - description="./timing.py -noheader -pmin 6 -pmax 23 -pdiff 3 pyfftw challenger" + description="./timing.py -pmin 6 -pmax 23 -pdiff 3 pyfftw challenger" ) parser.add_argument("-noheader", default=False, action="store_true") parser.add_argument(