@@ -10,8 +10,6 @@ cimport numpy as np
1010from .common cimport pywt_index_t
1111from ._pywt cimport c_wavelet_from_object, cdata_t, Wavelet, _check_dtype
1212
13- include " config.pxi"
14-
1513np.import_array()
1614
1715
@@ -99,21 +97,20 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level,
9997 & cD[0 ], output_len, i)
10098 if retval < 0 :
10199 raise RuntimeError (" C swt failed." )
102- IF HAVE_C99_CPLX:
103- if cdata_t is np.complex128_t:
104- cD = np.zeros(output_len, dtype = np.complex128)
105- with nogil:
106- retval = c_wt.double_complex_swt_d(& data[0 ], data_size, wavelet.w,
107- & cD[0 ], output_len, i)
108- if retval < 0 :
109- raise RuntimeError (" C swt failed." )
110- elif cdata_t is np.complex64_t:
111- cD = np.zeros(output_len, dtype = np.complex64)
112- with nogil:
113- retval = c_wt.float_complex_swt_d(& data[0 ], data_size, wavelet.w,
114- & cD[0 ], output_len, i)
115- if retval < 0 :
116- raise RuntimeError (" C swt failed." )
100+ elif cdata_t is np.complex128_t:
101+ cD = np.zeros(output_len, dtype = np.complex128)
102+ with nogil:
103+ retval = c_wt.double_complex_swt_d(& data[0 ], data_size, wavelet.w,
104+ & cD[0 ], output_len, i)
105+ if retval < 0 :
106+ raise RuntimeError (" C swt failed." )
107+ elif cdata_t is np.complex64_t:
108+ cD = np.zeros(output_len, dtype = np.complex64)
109+ with nogil:
110+ retval = c_wt.float_complex_swt_d(& data[0 ], data_size, wavelet.w,
111+ & cD[0 ], output_len, i)
112+ if retval < 0 :
113+ raise RuntimeError (" C swt failed." )
117114
118115 # alloc memory, decompose A
119116 if cdata_t is np.float64_t:
@@ -130,21 +127,20 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level,
130127 & cA[0 ], output_len, i)
131128 if retval < 0 :
132129 raise RuntimeError (" C swt failed." )
133- IF HAVE_C99_CPLX:
134- if cdata_t is np.complex128_t:
135- cA = np.zeros(output_len, dtype = np.complex128)
136- with nogil:
137- retval = c_wt.double_complex_swt_a(& data[0 ], data_size, wavelet.w,
138- & cA[0 ], output_len, i)
139- if retval < 0 :
140- raise RuntimeError (" C swt failed." )
141- elif cdata_t is np.complex64_t:
142- cA = np.zeros(output_len, dtype = np.complex64)
143- with nogil:
144- retval = c_wt.float_complex_swt_a(& data[0 ], data_size, wavelet.w,
145- & cA[0 ], output_len, i)
146- if retval < 0 :
147- raise RuntimeError (" C swt failed." )
130+ elif cdata_t is np.complex128_t:
131+ cA = np.zeros(output_len, dtype = np.complex128)
132+ with nogil:
133+ retval = c_wt.double_complex_swt_a(& data[0 ], data_size, wavelet.w,
134+ & cA[0 ], output_len, i)
135+ if retval < 0 :
136+ raise RuntimeError (" C swt failed." )
137+ elif cdata_t is np.complex64_t:
138+ cA = np.zeros(output_len, dtype = np.complex64)
139+ with nogil:
140+ retval = c_wt.float_complex_swt_a(& data[0 ], data_size, wavelet.w,
141+ & cA[0 ], output_len, i)
142+ if retval < 0 :
143+ raise RuntimeError (" C swt failed." )
148144
149145 data = cA
150146 if not trim_approx:
@@ -253,58 +249,57 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
253249 if retval:
254250 raise RuntimeError (
255251 " C wavelet transform failed with error code %d " % retval)
252+ elif data.dtype == np.complex128:
253+ cA = np.zeros(output_shape, dtype = np.complex128)
254+ with nogil:
255+ retval = c_wt.double_complex_downcoef_axis(
256+ < double complex * > data.data, data_info,
257+ < double complex * > cA.data, output_info,
258+ wavelet.w, axis,
259+ common.COEF_APPROX, common.MODE_PERIODIZATION,
260+ i, common.SWT_TRANSFORM)
261+ if retval:
262+ raise RuntimeError (
263+ " C wavelet transform failed with error code %d " %
264+ retval)
265+ cD = np.zeros(output_shape, dtype = np.complex128)
266+ with nogil:
267+ retval = c_wt.double_complex_downcoef_axis(
268+ < double complex * > data.data, data_info,
269+ < double complex * > cD.data, output_info,
270+ wavelet.w, axis,
271+ common.COEF_DETAIL, common.MODE_PERIODIZATION,
272+ i, common.SWT_TRANSFORM)
273+ if retval:
274+ raise RuntimeError (
275+ " C wavelet transform failed with error code %d " %
276+ retval)
277+ elif data.dtype == np.complex64:
278+ cA = np.zeros(output_shape, dtype = np.complex64)
279+ with nogil:
280+ retval = c_wt.float_complex_downcoef_axis(
281+ < float complex * > data.data, data_info,
282+ < float complex * > cA.data, output_info,
283+ wavelet.w, axis,
284+ common.COEF_APPROX, common.MODE_PERIODIZATION,
285+ i, common.SWT_TRANSFORM)
286+ if retval:
287+ raise RuntimeError (
288+ " C wavelet transform failed with error code %d " %
289+ retval)
290+ cD = np.zeros(output_shape, dtype = np.complex64)
291+ with nogil:
292+ retval = c_wt.float_complex_downcoef_axis(
293+ < float complex * > data.data, data_info,
294+ < float complex * > cD.data, output_info,
295+ wavelet.w, axis,
296+ common.COEF_DETAIL, common.MODE_PERIODIZATION,
297+ i, common.SWT_TRANSFORM)
298+ if retval:
299+ raise RuntimeError (
300+ " C wavelet transform failed with error code %d " %
301+ retval)
256302
257- IF HAVE_C99_CPLX:
258- if data.dtype == np.complex128:
259- cA = np.zeros(output_shape, dtype = np.complex128)
260- with nogil:
261- retval = c_wt.double_complex_downcoef_axis(
262- < double complex * > data.data, data_info,
263- < double complex * > cA.data, output_info,
264- wavelet.w, axis,
265- common.COEF_APPROX, common.MODE_PERIODIZATION,
266- i, common.SWT_TRANSFORM)
267- if retval:
268- raise RuntimeError (
269- " C wavelet transform failed with error code %d " %
270- retval)
271- cD = np.zeros(output_shape, dtype = np.complex128)
272- with nogil:
273- retval = c_wt.double_complex_downcoef_axis(
274- < double complex * > data.data, data_info,
275- < double complex * > cD.data, output_info,
276- wavelet.w, axis,
277- common.COEF_DETAIL, common.MODE_PERIODIZATION,
278- i, common.SWT_TRANSFORM)
279- if retval:
280- raise RuntimeError (
281- " C wavelet transform failed with error code %d " %
282- retval)
283- elif data.dtype == np.complex64:
284- cA = np.zeros(output_shape, dtype = np.complex64)
285- with nogil:
286- retval = c_wt.float_complex_downcoef_axis(
287- < float complex * > data.data, data_info,
288- < float complex * > cA.data, output_info,
289- wavelet.w, axis,
290- common.COEF_APPROX, common.MODE_PERIODIZATION,
291- i, common.SWT_TRANSFORM)
292- if retval:
293- raise RuntimeError (
294- " C wavelet transform failed with error code %d " %
295- retval)
296- cD = np.zeros(output_shape, dtype = np.complex64)
297- with nogil:
298- retval = c_wt.float_complex_downcoef_axis(
299- < float complex * > data.data, data_info,
300- < float complex * > cD.data, output_info,
301- wavelet.w, axis,
302- common.COEF_DETAIL, common.MODE_PERIODIZATION,
303- i, common.SWT_TRANSFORM)
304- if retval:
305- raise RuntimeError (
306- " C wavelet transform failed with error code %d " %
307- retval)
308303 if retval == - 5 :
309304 raise TypeError (" Array must be floating point, not {}"
310305 .format(data.dtype))
0 commit comments