Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 82 additions & 31 deletions sdp/pyfftw_sdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,99 @@

class SLIDING_DOT_PRODUCT:
# https://stackoverflow.com/a/30615425/2955541
def __init__(self):
self.m = 0
def __init__(self, max_n=2**20):
"""
Parameters
----------
max_n : int
Maximum length to preallocate arrays for. This will be the size of the
the real-valued array. A complex-valued array of size `1 + (max_n // 2)`
will also be preallocated.
"""
self.n = 0
self.threads = 1
self.rfft_Q_obj = None
self.rfft_T_obj = None
self.irfft_obj = None

def __call__(self, Q, T):
if Q.shape[0] != self.m or T.shape[0] != self.n:
self.m = Q.shape[0]
self.next_fast_n = 0

# Preallocate arrays
self.real_arr = pyfftw.empty_aligned(max_n, dtype="float64")
self.complex_arr = pyfftw.empty_aligned(1 + (max_n // 2), dtype="complex128")

# Store FFTW objects in a dict keyed by `next_fast_n`, where n=len(T)
self.rfft_objects = {}
self.irfft_objects = {}

def __call__(self, Q, T, threads=1, planning_flag="FFTW_MEASURE"):
m = Q.shape[0]
if self.n != T.shape[0]:
self.n = T.shape[0]
shape = pyfftw.next_fast_len(self.n)
self.rfft_Q_obj = pyfftw.builders.rfft(
np.empty(self.m), overwrite_input=True, n=shape, threads=self.threads
self.next_fast_n = pyfftw.next_fast_len(self.n)

# Update preallocated arrays if needed
if self.next_fast_n > len(self.real_arr):
self.real_arr = pyfftw.empty_aligned(self.next_fast_n, dtype="float64")
self.complex_arr = pyfftw.empty_aligned(
1 + (self.next_fast_n // 2), dtype="complex128"
)
self.rfft_T_obj = pyfftw.builders.rfft(
np.empty(self.n), overwrite_input=True, n=shape, threads=self.threads

real_arr = self.real_arr[: self.next_fast_n]
complex_arr = self.complex_arr[: 1 + (self.next_fast_n // 2)]

# Get or create FFTW objects
rfft_obj = self.rfft_objects.get(self.next_fast_n, None)
if rfft_obj is None:
rfft_obj = pyfftw.FFTW(
input_array=real_arr,
output_array=complex_arr,
direction="FFTW_FORWARD",
flags=(planning_flag,),
threads=threads,
)
self.irfft_obj = pyfftw.builders.irfft(
self.rfft_Q_obj.output_array,
overwrite_input=True,
n=shape,
threads=self.threads,
self.rfft_objects[self.next_fast_n] = rfft_obj
else:
rfft_obj.update_arrays(real_arr, complex_arr)

irfft_obj = self.irfft_objects.get(self.next_fast_n, None)
if irfft_obj is None:
irfft_obj = pyfftw.FFTW(
input_array=complex_arr,
output_array=real_arr,
direction="FFTW_BACKWARD",
flags=(planning_flag, "FFTW_DESTROY_INPUT"),
threads=threads,
)
self.irfft_objects[self.next_fast_n] = irfft_obj
else:
irfft_obj.update_arrays(complex_arr, real_arr)

# RFFT(T)
real_arr[: self.n] = T
real_arr[self.n :] = 0.0
rfft_obj.execute() # output is in complex_arr
complex_arr_T = complex_arr.copy()

# RFFT(Q)
# Scale by 1/next_fast_n to account for
# FFTW's unnormalized inverse FFT via execute()
real_arr[:m] = Q[::-1] / self.next_fast_n
real_arr[m:] = 0.0
rfft_obj.execute() # output is in complex_arr

# RFFT(T) * RFFT(Q)
np.multiply(complex_arr, complex_arr_T, out=complex_arr)

Qr = Q[::-1] # Reverse/flip Q
rfft_padded_Q = self.rfft_Q_obj(Qr)
rfft_padded_T = self.rfft_T_obj(T)
# IRFFT
# input is in complex_arr
irfft_obj.execute() # output is in real_arr

return self.irfft_obj(np.multiply(rfft_padded_Q, rfft_padded_T)).real[
self.m - 1 : self.n
]
return real_arr[m - 1 : self.n]


_sliding_dot_product = SLIDING_DOT_PRODUCT()
_sliding_dot_product = SLIDING_DOT_PRODUCT(max_n=2**20)


def setup(Q, T):
_sliding_dot_product(Q, T)
def setup(Q, T, threads=1, planning_flag="FFTW_MEASURE"):
_sliding_dot_product(Q, T, threads=threads, planning_flag=planning_flag)
return


def sliding_dot_product(Q, T):
return _sliding_dot_product(Q, T)
def sliding_dot_product(Q, T, threads=1, planning_flag="FFTW_MEASURE"):
return _sliding_dot_product(Q, T, threads=threads, planning_flag=planning_flag)
14 changes: 14 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,17 @@ def test_setup():
raise e

return


def test_pyfftw_sdp_max_n():
from sdp.pyfftw_sdp import SLIDING_DOT_PRODUCT

sliding_dot_product = SLIDING_DOT_PRODUCT(max_n=2**10)
T = np.random.rand(2**12) # len(T) is larger than max_n
Q = np.random.rand(2**8)

comp = sliding_dot_product(Q, T)
ref = naive_sliding_dot_product(Q, T)
np.testing.assert_allclose(comp, ref)

return
2 changes: 1 addition & 1 deletion timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="./timing.py -noheader -pmin 6 -pmax 23 -pdiff 3 pyfftw challenger"
description="./timing.py -pmin 6 -pmax 23 -pdiff 3 pyfftw challenger"
)
parser.add_argument("-noheader", default=False, action="store_true")
parser.add_argument(
Expand Down
Loading