Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 69 additions & 31 deletions sdp/pyfftw_sdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,86 @@

class SLIDING_DOT_PRODUCT:
# https://stackoverflow.com/a/30615425/2955541
def __init__(self):
self.m = 0
def __init__(self, max_shape=2**10):
"""
Parameters
----------
max_shape : int
Maximum shape to preallocate arrays for. This will be the size of the
the real-valued array. A complex-valued array of size 1 + max_shape // 2
will also be preallocated.
"""
self.n = 0
self.threads = 1
self.rfft_Q_obj = None
self.rfft_T_obj = None
self.irfft_obj = None

def __call__(self, Q, T):
if Q.shape[0] != self.m or T.shape[0] != self.n:
self.m = Q.shape[0]
self.shape = 0
self.real_arr = pyfftw.empty_aligned(max_shape, dtype="float64")
self.complex_arr = pyfftw.empty_aligned(1 + max_shape // 2, dtype="complex128")
self.rfft_irfft_objects = {}

def __call__(self, Q, T, threads=1, planning_flag="FFTW_MEASURE"):
m = Q.shape[0]
if self.n != T.shape[0]:
self.n = T.shape[0]
shape = pyfftw.next_fast_len(self.n)
self.rfft_Q_obj = pyfftw.builders.rfft(
np.empty(self.m), overwrite_input=True, n=shape, threads=self.threads
self.shape = pyfftw.next_fast_len(self.n)

if self.shape > len(self.real_arr):
self.real_arr = pyfftw.empty_aligned(self.shape, dtype="float64")
self.complex_arr = pyfftw.empty_aligned(
1 + self.shape // 2, dtype="complex128"
)
self.rfft_T_obj = pyfftw.builders.rfft(
np.empty(self.n), overwrite_input=True, n=shape, threads=self.threads
real_arr = self.real_arr[: self.shape]
complex_arr = self.complex_arr[: 1 + self.shape // 2]

rfft_irfft_obj = self.rfft_irfft_objects.get(self.shape, None)
if rfft_irfft_obj is None:
rfft_obj = pyfftw.FFTW(
input_array=real_arr,
output_array=complex_arr,
direction="FFTW_FORWARD",
flags=(planning_flag,),
threads=threads,
)
self.irfft_obj = pyfftw.builders.irfft(
self.rfft_Q_obj.output_array,
overwrite_input=True,
n=shape,
threads=self.threads,
irfft_obj = pyfftw.FFTW(
input_array=complex_arr,
output_array=real_arr,
direction="FFTW_BACKWARD",
flags=(planning_flag, "FFTW_DESTROY_INPUT"),
threads=threads,
)
self.rfft_irfft_objects[self.shape] = (rfft_obj, irfft_obj)
else:
rfft_obj, irfft_obj = rfft_irfft_obj
rfft_obj.update_arrays(real_arr, complex_arr)
irfft_obj.update_arrays(complex_arr, real_arr)

# RFFT(T)
real_arr[: self.n] = T
real_arr[self.n :] = 0.0
rfft_obj.execute() # output is in self.complex_arr
complex_arr_T = complex_arr.copy()

# RFFT(Q)
real_arr[:m] = Q[::-1] / self.shape # reversed Q and scale
real_arr[m:] = 0.0
rfft_obj.execute() # output is in self.complex_arr

# RFFT(T) * RFFT(Q)
np.multiply(complex_arr, complex_arr_T, out=complex_arr)

Qr = Q[::-1] # Reverse/flip Q
rfft_padded_Q = self.rfft_Q_obj(Qr)
rfft_padded_T = self.rfft_T_obj(T)
# IRFFT
# input is in self.complex_arr
# output is in self.real_arr
irfft_obj.execute()

return self.irfft_obj(np.multiply(rfft_padded_Q, rfft_padded_T)).real[
self.m - 1 : self.n
]
return real_arr[m - 1 : self.n]


_sliding_dot_product = SLIDING_DOT_PRODUCT()
_sliding_dot_product = SLIDING_DOT_PRODUCT(max_shape=2**10)


def setup(Q, T):
_sliding_dot_product(Q, T)
def setup(Q, T, threads=1, planning_flag="FFTW_MEASURE"):
_sliding_dot_product(Q, T, threads=threads, planning_flag=planning_flag)
return


def sliding_dot_product(Q, T):
return _sliding_dot_product(Q, T)
def sliding_dot_product(Q, T, threads=1, planning_flag="FFTW_MEASURE"):
return _sliding_dot_product(Q, T, threads=threads, planning_flag=planning_flag)
2 changes: 1 addition & 1 deletion timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="./timing.py -noheader -pmin 6 -pmax 23 -pdiff 3 pyfftw challenger"
description="./timing.py -pmin 6 -pmax 23 -pdiff 3 pyfftw challenger"
)
parser.add_argument("-noheader", default=False, action="store_true")
parser.add_argument(
Expand Down
Loading