diff --git a/src/sse/sse_radix4.rs b/src/sse/sse_radix4.rs index 6d8e27e..f344925 100644 --- a/src/sse/sse_radix4.rs +++ b/src/sse/sse_radix4.rs @@ -1,23 +1,25 @@ use num_complex::Complex; +use num_integer::Integer; use std::any::TypeId; +use std::arch::x86_64::{__m128, __m128d}; use std::sync::Arc; -use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut}; +use crate::array_utils::{self, reverse_bits, workaround_transmute, workaround_transmute_mut}; use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::sse::sse_utils::transpose_complex_2x2_f32; use crate::{common::FftNum, FftDirection}; use crate::{Direction, Fft, Length}; use super::SseNum; -use super::sse_vector::{Rotation90, SseArray, SseArrayMut, SseVector}; +use super::sse_vector::{Deinterleaved, SseArray, SseArrayMut, SseVector}; /// FFT algorithm optimized for power-of-two sizes, SSE accelerated version. /// This is designed to be used via a Planner, and not created directly. pub struct SseRadix4 { - twiddles: Box<[S::VectorType]>, - rotation: Rotation90, + twiddles: Box<[Deinterleaved]>, base_fft: Arc>, base_len: usize, @@ -52,7 +54,7 @@ impl SseRadix4 { let base_len = base_fft.len(); // note that we can eventually release this restriction - we just need to update the rest of the code in here to handle remainders - assert!(base_len % (2 * S::VectorType::COMPLEX_PER_VECTOR) == 0 && base_len > 0); + assert!(base_len % (S::VectorType::SCALAR_PER_VECTOR) == 0 && base_len > 0); let len = base_len * (1 << (k * 2)); @@ -61,28 +63,38 @@ impl SseRadix4 { // but mixed radix only does one step and then calls itself recusrively, and this algorithm does every layer all the way down // so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up const ROW_COUNT: usize = 4; - let mut cross_fft_len = base_len * ROW_COUNT; + let mut cross_fft_len = base_len; let mut twiddle_factors = Vec::with_capacity(len * 2); - while cross_fft_len <= len { - let num_scalar_columns = cross_fft_len / ROW_COUNT; - let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR; + while cross_fft_len < len { + let num_scalar_columns = cross_fft_len; + cross_fft_len *= ROW_COUNT; + + let (quotient, remainder) = + num_scalar_columns.div_rem(&S::VectorType::SCALAR_PER_VECTOR); + let num_vector_columns = quotient + if remainder > 0 { 1 } else { 0 }; for i in 0..num_vector_columns { for k in 1..ROW_COUNT { - twiddle_factors.push(SseVector::make_mixedradix_twiddle_chunk( - i * S::VectorType::COMPLEX_PER_VECTOR, + let twiddle0 = SseVector::make_mixedradix_twiddle_chunk( + i * S::VectorType::SCALAR_PER_VECTOR, + k, + cross_fft_len, + direction, + ); + let twiddle1 = SseVector::make_mixedradix_twiddle_chunk( + i * S::VectorType::SCALAR_PER_VECTOR + S::VectorType::COMPLEX_PER_VECTOR, k, cross_fft_len, direction, - )); + ); + let deinterleaved_twiddles = SseVector::deinterleave(twiddle0, twiddle1); + twiddle_factors.push(deinterleaved_twiddles); } } - cross_fft_len *= ROW_COUNT; } Self { twiddles: twiddle_factors.into_boxed_slice(), - rotation: SseVector::make_rotate90(direction), base_fft, base_len, @@ -103,7 +115,16 @@ impl SseRadix4 { if self.len() == self.base_len { output.copy_from_slice(input); } else { - bitreversed_transpose::, 4>(self.base_len, input, output); + // Hack: Making a f32 vs f64 agaonstic version of this seems hard. 
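For reference, the twiddle entries built above are stored split: one vector of real parts and one vector of imaginary parts per column, instead of interleaved complex values. A minimal scalar sketch of that layout, with plain arrays standing in for `__m128` (illustrative only, not the SIMD path):

```rust
use num_complex::Complex;

// Split four interleaved complex twiddles into a "re" lane and an "im" lane,
// which is the per-column layout the Deinterleaved twiddle entries store.
fn deinterleave4(tw: [Complex<f32>; 4]) -> ([f32; 4], [f32; 4]) {
    (
        [tw[0].re, tw[1].re, tw[2].re, tw[3].re],
        [tw[0].im, tw[1].im, tw[2].im, tw[3].im],
    )
}
```

With this layout a complex multiply needs only independent per-lane mul/add/sub and no in-vector shuffles, which is what `Deinterleaved::mul_complex` in sse_vector.rs relies on.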
Avoid it for now, and hopefully we can make one later + if TypeId::of::() == TypeId::of::() { + let input = workaround_transmute(input); + let output = workaround_transmute_mut(output); + sse_bitreversed_transpose_f32(self.base_len, input, output); + } else { + let input = workaround_transmute(input); + let output = workaround_transmute_mut(output); + sse_bitreversed_transpose_f64(self.base_len, input, output); + } } // Base-level FFTs @@ -111,81 +132,308 @@ impl SseRadix4 { // cross-FFTs const ROW_COUNT: usize = 4; - let mut cross_fft_len = self.base_len * ROW_COUNT; - let mut layer_twiddles: &[S::VectorType] = &self.twiddles; - - while cross_fft_len <= input.len() { - let num_rows = input.len() / cross_fft_len; - let num_scalar_columns = cross_fft_len / ROW_COUNT; - let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR; - - for i in 0..num_rows { - butterfly_4::( - &mut output[i * cross_fft_len..], - layer_twiddles, - num_scalar_columns, - &self.rotation, - ) + let mut cross_fft_len = self.base_len; + let mut layer_twiddles: &[Deinterleaved] = &self.twiddles; + + while cross_fft_len < input.len() { + let columns = cross_fft_len; + let first = cross_fft_len == self.base_len; + cross_fft_len *= ROW_COUNT; + let last = cross_fft_len == self.len(); + + if first && last { + for data in output.chunks_exact_mut(cross_fft_len) { + butterfly_4::(data, layer_twiddles, columns, self.direction) + } + } else if first { + for data in output.chunks_exact_mut(cross_fft_len) { + butterfly_4::(data, layer_twiddles, columns, self.direction) + } + } else if last { + for data in output.chunks_exact_mut(cross_fft_len) { + butterfly_4::(data, layer_twiddles, columns, self.direction) + } + } else { + for data in output.chunks_exact_mut(cross_fft_len) { + butterfly_4::(data, layer_twiddles, columns, self.direction) + } } // skip past all the twiddle factors used in this layer + let (quotient, remainder) = columns.div_rem(&S::VectorType::SCALAR_PER_VECTOR); + let num_vector_columns = quotient + if remainder > 0 { 1 } else { 0 }; + let twiddle_offset = num_vector_columns * (ROW_COUNT - 1); layer_twiddles = &layer_twiddles[twiddle_offset..]; - - cross_fft_len *= ROW_COUNT; } } } boilerplate_fft_sse_oop!(SseRadix4, |this: &SseRadix4<_, _>| this.len); +#[inline(always)] +fn load_debug_checked(buffer: &[T], idx: usize) -> T { + debug_assert!(idx < buffer.len()); + unsafe { *buffer.get_unchecked(idx) } +} + +#[inline(always)] +unsafe fn load( + buffer: &[Complex], + idx: usize, +) -> Deinterleaved { + if DEINTERLEAVE { + let a = buffer.load_complex(idx); + let b = buffer.load_complex(idx + S::VectorType::COMPLEX_PER_VECTOR); + SseVector::deinterleave(a, b) + } else { + let a = buffer.load_complex(idx); + let b = buffer.load_complex(idx + S::VectorType::COMPLEX_PER_VECTOR); + Deinterleaved { re: a, im: b } + } +} + +#[inline(always)] +unsafe fn store( + mut buffer: &mut [Complex], + vector: Deinterleaved, + idx: usize, +) { + if INTERLEAVE { + let (a, b) = SseVector::interleave(vector); + buffer.store_complex(a, idx); + buffer.store_complex(b, idx + S::VectorType::COMPLEX_PER_VECTOR); + } else { + buffer.store_complex(vector.re, idx); + buffer.store_complex(vector.im, idx + S::VectorType::COMPLEX_PER_VECTOR); + } +} + +#[inline(never)] #[target_feature(enable = "sse4.1")] -unsafe fn butterfly_4( +unsafe fn butterfly_4( data: &mut [Complex], - twiddles: &[S::VectorType], - num_ffts: usize, - rotation: &Rotation90, + twiddles: &[Deinterleaved], + num_scalar_columns: usize, + direction: 
FftDirection, ) { - let unroll_offset = S::VectorType::COMPLEX_PER_VECTOR; - - let mut idx = 0usize; - let mut buffer: &mut [Complex] = workaround_transmute_mut(data); - for tw in twiddles - .chunks_exact(6) - .take(num_ffts / (S::VectorType::COMPLEX_PER_VECTOR * 2)) - { - let mut scratcha = [ - buffer.load_complex(idx + 0 * num_ffts), - buffer.load_complex(idx + 1 * num_ffts), - buffer.load_complex(idx + 2 * num_ffts), - buffer.load_complex(idx + 3 * num_ffts), + let num_vector_columns = num_scalar_columns / S::VectorType::SCALAR_PER_VECTOR; + let buffer: &mut [Complex] = workaround_transmute_mut(data); + + for i in 0..num_vector_columns { + let idx = i * S::VectorType::SCALAR_PER_VECTOR; + let tw_idx = i * 3; + let mut scratch = [ + load::(buffer, idx + 0 * num_scalar_columns), + load::(buffer, idx + 1 * num_scalar_columns), + load::(buffer, idx + 2 * num_scalar_columns), + load::(buffer, idx + 3 * num_scalar_columns), ]; - let mut scratchb = [ - buffer.load_complex(idx + 0 * num_ffts + unroll_offset), - buffer.load_complex(idx + 1 * num_ffts + unroll_offset), - buffer.load_complex(idx + 2 * num_ffts + unroll_offset), - buffer.load_complex(idx + 3 * num_ffts + unroll_offset), + + let tw1 = load_debug_checked(twiddles, tw_idx + 0); + let tw2 = load_debug_checked(twiddles, tw_idx + 1); + let tw3 = load_debug_checked(twiddles, tw_idx + 2); + + scratch[1] = Deinterleaved::mul_complex(scratch[1], tw1); + scratch[2] = Deinterleaved::mul_complex(scratch[2], tw2); + scratch[3] = Deinterleaved::mul_complex(scratch[3], tw3); + + let scratch = Deinterleaved::butterfly4(scratch, direction); + + store::(buffer, scratch[0], idx + 0 * num_scalar_columns); + store::(buffer, scratch[1], idx + 1 * num_scalar_columns); + store::(buffer, scratch[2], idx + 2 * num_scalar_columns); + store::(buffer, scratch[3], idx + 3 * num_scalar_columns); + } +} + +#[inline(always)] +unsafe fn load4_complex_f32(buffer: &[Complex], idx: usize) -> [__m128; 2] { + [ + buffer.load_complex(idx), + buffer.load_complex(idx + ::VectorType::COMPLEX_PER_VECTOR), + ] +} +#[inline(always)] +unsafe fn transpose_complex_4x4_f32(rows: [[__m128; 2]; 4]) -> [[__m128; 2]; 4] { + let transposed0 = transpose_complex_2x2_f32(rows[0][0], rows[1][0]); + let transposed1 = transpose_complex_2x2_f32(rows[0][1], rows[1][1]); + let transposed2 = transpose_complex_2x2_f32(rows[2][0], rows[3][0]); + let transposed3 = transpose_complex_2x2_f32(rows[2][1], rows[3][1]); + + [ + [transposed0[0], transposed2[0]], + [transposed0[1], transposed2[1]], + [transposed1[0], transposed3[0]], + [transposed1[1], transposed3[1]], + ] +} +#[inline(always)] +unsafe fn store4_complex_f32(mut buffer: &mut [Complex], data: [__m128; 2], idx: usize) { + buffer.store_complex(data[0], idx); + buffer.store_complex( + data[1], + idx + ::VectorType::COMPLEX_PER_VECTOR, + ); +} + +// Utility to help reorder data as a part of computing RadixD FFTs. Conceputally, it works like a transpose, but with the column indexes bit-reversed. +// Use a lookup table to avoid repeating the slow bit reverse operations. +// Unrolling the outer loop by a factor D helps speed things up. +// const parameter D (for Divisor) determines the divisor to use for the "bit reverse", and how much to unroll. `input.len() / height` must be a power of D. 
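As a mental model for the reordering the comment above describes, here is a minimal scalar sketch of the same digit-reversed transpose. The helper names are hypothetical and `width` is assumed to be a power of four; this is a reference picture, not the SIMD code path that follows.

```rust
use num_complex::Complex;

// Reverse `digits` base-4 digits of `value`, e.g. 0b01_10_11 -> 0b11_10_01.
fn reverse_base4_digits(mut value: usize, digits: u32) -> usize {
    let mut reversed = 0;
    for _ in 0..digits {
        reversed = (reversed << 2) | (value & 0b11);
        value >>= 2;
    }
    reversed
}

// Scalar model: read rows of `width` elements, write each column out at the
// position given by its base-4-digit-reversed index. The SIMD versions below
// do the same thing, moving 4 columns x SCALAR_PER_VECTOR rows per iteration.
fn bitreversed_transpose_scalar(
    height: usize,
    input: &[Complex<f32>],
    output: &mut [Complex<f32>],
) {
    let width = input.len() / height;
    let rev_digits = width.trailing_zeros() / 2; // width must be a power of 4
    for x in 0..width {
        let x_rev = reverse_base4_digits(x, rev_digits);
        for y in 0..height {
            output[x_rev * height + y] = input[y * width + x];
        }
    }
}
```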
+#[inline(never)] +#[target_feature(enable = "sse4.1")] +pub unsafe fn sse_bitreversed_transpose_f32( + height: usize, + input: &[Complex], + output: &mut [Complex], +) { + let width = input.len() / height; + const WIDTH_UNROLL: usize = 4; + const HEIGHT_UNROLL: usize = ::VectorType::SCALAR_PER_VECTOR; + + // Let's make sure the arguments are ok + assert!( + height % ::VectorType::SCALAR_PER_VECTOR == 0 + && width % WIDTH_UNROLL == 0 + && input.len() % height == 0 + && input.len() == output.len() + ); + + let width_bits = width.trailing_zeros(); + let d_bits = WIDTH_UNROLL.trailing_zeros(); + + // verify that width is a power of d + assert!(width_bits % d_bits == 0); + let rev_digits = width_bits / d_bits; + let strided_width = width / WIDTH_UNROLL; + let strided_height = height / HEIGHT_UNROLL; + for x in 0..strided_width { + let x_rev = [ + reverse_bits::(WIDTH_UNROLL * x + 0, rev_digits) * height, + reverse_bits::(WIDTH_UNROLL * x + 1, rev_digits) * height, + reverse_bits::(WIDTH_UNROLL * x + 2, rev_digits) * height, + reverse_bits::(WIDTH_UNROLL * x + 3, rev_digits) * height, + ]; + + // Assert that the the bit reversed indices will not exceed the length of the output. + // we add HEIGHT_UNROLL * y to each x_rev, which goes up to height exclusive + // so verify that r + height isn't more than our length + for r in x_rev { + assert!(r <= input.len() - height); + } + for y in 0..strided_height { + unsafe { + // Load data in HEIGHT_UNROLL rows, with each row containing WIDTH_UNROLL=4 complex elements + // for f32, HEIGHT_UNROLL=4, this translates to 4 rows of 2 SSE vectors each, + // overall storing 4x4=16 complex elements + let base_input_idx = WIDTH_UNROLL * x + 0 + y * HEIGHT_UNROLL * width; + let rows = [ + load4_complex_f32(input, base_input_idx + width * 0), + load4_complex_f32(input, base_input_idx + width * 1), + load4_complex_f32(input, base_input_idx + width * 2), + load4_complex_f32(input, base_input_idx + width * 3), + ]; + let transposed = transpose_complex_4x4_f32(rows); + + store4_complex_f32(output, transposed[0], HEIGHT_UNROLL * y + x_rev[0]); + store4_complex_f32(output, transposed[1], HEIGHT_UNROLL * y + x_rev[1]); + store4_complex_f32(output, transposed[2], HEIGHT_UNROLL * y + x_rev[2]); + store4_complex_f32(output, transposed[3], HEIGHT_UNROLL * y + x_rev[3]); + } + } + } +} + +#[inline(always)] +unsafe fn load4_complex_f64(buffer: &[Complex], idx: usize) -> [__m128d; 4] { + [ + buffer.load_complex(idx + 0), + buffer.load_complex(idx + 1), + buffer.load_complex(idx + 2), + buffer.load_complex(idx + 3), + ] +} +#[inline(always)] +unsafe fn transpose_complex_4x2_f64(rows: [[__m128d; 4]; 2]) -> [[__m128d; 2]; 4] { + [ + [rows[0][0], rows[1][0]], + [rows[0][1], rows[1][1]], + [rows[0][2], rows[1][2]], + [rows[0][3], rows[1][3]], + ] +} +#[inline(always)] +unsafe fn store2_complex_f64(mut buffer: &mut [Complex], data: [__m128d; 2], idx: usize) { + buffer.store_complex(data[0], idx); + buffer.store_complex( + data[1], + idx + ::VectorType::COMPLEX_PER_VECTOR, + ); +} + +// Utility to help reorder data as a part of computing RadixD FFTs. Conceputally, it works like a transpose, but with the column indexes bit-reversed. +// Use a lookup table to avoid repeating the slow bit reverse operations. +// Unrolling the outer loop by a factor D helps speed things up. +// const parameter D (for Divisor) determines the divisor to use for the "bit reverse", and how much to unroll. `input.len() / height` must be a power of D. 
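One more reference sketch before the f64 variant: `transpose_complex_4x4_f32` above assembles a 4x4 transpose out of 2x2 sub-transposes by swapping the two off-diagonal quadrants. A generic scalar model of that quadrant trick (assumed names, not the intrinsic code):

```rust
// Transpose a 4x4 block by transposing each 2x2 quadrant and swapping the two
// off-diagonal quadrants, mirroring transpose_complex_4x4_f32. Plain values
// stand in for the two-complex __m128 halves.
fn transpose_4x4<T: Copy>(rows: [[T; 4]; 4]) -> [[T; 4]; 4] {
    // Transposed 2x2 quadrant whose top-left corner is (r, c).
    let quad = |r: usize, c: usize| {
        [
            [rows[r][c], rows[r + 1][c]],
            [rows[r][c + 1], rows[r + 1][c + 1]],
        ]
    };
    let (tl, tr, bl, br) = (quad(0, 0), quad(0, 2), quad(2, 0), quad(2, 2));
    [
        [tl[0][0], tl[0][1], bl[0][0], bl[0][1]],
        [tl[1][0], tl[1][1], bl[1][0], bl[1][1]],
        [tr[0][0], tr[0][1], br[0][0], br[0][1]],
        [tr[1][0], tr[1][1], br[1][0], br[1][1]],
    ]
}
```

The f64 path below needs no in-vector shuffles at all, since a `__m128d` holds exactly one `Complex<f64>`; its "transpose" is just the regrouping done in `transpose_complex_4x2_f64`.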
+#[inline(never)] +#[target_feature(enable = "sse4.1")] +pub unsafe fn sse_bitreversed_transpose_f64( + height: usize, + input: &[Complex], + output: &mut [Complex], +) { + let width = input.len() / height; + const WIDTH_UNROLL: usize = 4; + const HEIGHT_UNROLL: usize = ::VectorType::SCALAR_PER_VECTOR; + + // Let's make sure the arguments are ok + assert!( + height % ::VectorType::SCALAR_PER_VECTOR == 0 + && width % WIDTH_UNROLL == 0 + && input.len() % height == 0 + && input.len() == output.len() + ); + + let width_bits = width.trailing_zeros(); + let d_bits = WIDTH_UNROLL.trailing_zeros(); + + // verify that width is a power of d + assert!(width_bits % d_bits == 0); + let rev_digits = width_bits / d_bits; + let strided_width = width / WIDTH_UNROLL; + let strided_height = height / HEIGHT_UNROLL; + for x in 0..strided_width { + let x_rev = [ + reverse_bits::(WIDTH_UNROLL * x + 0, rev_digits) * height, + reverse_bits::(WIDTH_UNROLL * x + 1, rev_digits) * height, + reverse_bits::(WIDTH_UNROLL * x + 2, rev_digits) * height, + reverse_bits::(WIDTH_UNROLL * x + 3, rev_digits) * height, ]; - scratcha[1] = SseVector::mul_complex(scratcha[1], tw[0]); - scratcha[2] = SseVector::mul_complex(scratcha[2], tw[1]); - scratcha[3] = SseVector::mul_complex(scratcha[3], tw[2]); - scratchb[1] = SseVector::mul_complex(scratchb[1], tw[3]); - scratchb[2] = SseVector::mul_complex(scratchb[2], tw[4]); - scratchb[3] = SseVector::mul_complex(scratchb[3], tw[5]); - - let scratcha = SseVector::column_butterfly4(scratcha, *rotation); - let scratchb = SseVector::column_butterfly4(scratchb, *rotation); - - buffer.store_complex(scratcha[0], idx + 0 * num_ffts); - buffer.store_complex(scratchb[0], idx + 0 * num_ffts + unroll_offset); - buffer.store_complex(scratcha[1], idx + 1 * num_ffts); - buffer.store_complex(scratchb[1], idx + 1 * num_ffts + unroll_offset); - buffer.store_complex(scratcha[2], idx + 2 * num_ffts); - buffer.store_complex(scratchb[2], idx + 2 * num_ffts + unroll_offset); - buffer.store_complex(scratcha[3], idx + 3 * num_ffts); - buffer.store_complex(scratchb[3], idx + 3 * num_ffts + unroll_offset); - - idx += S::VectorType::COMPLEX_PER_VECTOR * 2; + // Assert that the the bit reversed indices will not exceed the length of the output. 
+ // we add HEIGHT_UNROLL * y to each x_rev, which goes up to height exclusive + // so verify that r + height isn't more than our length + for r in x_rev { + assert!(r <= input.len() - height); + } + for y in 0..strided_height { + unsafe { + // Load data in HEIGHT_UNROLL rows, with each row containing WIDTH_UNROLL=4 complex elements + // for f64, HEIGHT_UNROLL=2, this translates to 2 rows of 4 SSE vectors each, + // overall storing 2x4=8 complex elements + let base_input_idx = WIDTH_UNROLL * x + 0 + y * HEIGHT_UNROLL * width; + let rows = [ + load4_complex_f64(input, base_input_idx + width * 0), + load4_complex_f64(input, base_input_idx + width * 1), + ]; + let transposed = transpose_complex_4x2_f64(rows); + + store2_complex_f64(output, transposed[0], HEIGHT_UNROLL * y + x_rev[0]); + store2_complex_f64(output, transposed[1], HEIGHT_UNROLL * y + x_rev[1]); + store2_complex_f64(output, transposed[2], HEIGHT_UNROLL * y + x_rev[2]); + store2_complex_f64(output, transposed[3], HEIGHT_UNROLL * y + x_rev[3]); + } + } } } diff --git a/src/sse/sse_vector.rs b/src/sse/sse_vector.rs index c223d01..cd87812 100644 --- a/src/sse/sse_vector.rs +++ b/src/sse/sse_vector.rs @@ -148,6 +148,12 @@ pub trait SseVector: Copy + Debug + Send + Sync { unsafe fn store_partial_lo_complex(ptr: *mut Complex, data: Self); unsafe fn store_partial_hi_complex(ptr: *mut Complex, data: Self); + // math ops + unsafe fn add(a: Self, b: Self) -> Self; + unsafe fn sub(a: Self, b: Self) -> Self; + unsafe fn mul(a: Self, b: Self) -> Self; + unsafe fn neg(a: Self) -> Self; + /// Generates a chunk of twiddle factors starting at (X,Y) and incrementing X `COMPLEX_PER_VECTOR` times. /// The result will be [twiddle(x*y, len), twiddle((x+1)*y, len), twiddle((x+2)*y, len), ...] for as many complex numbers fit in a vector unsafe fn make_mixedradix_twiddle_chunk( @@ -157,6 +163,12 @@ pub trait SseVector: Copy + Debug + Send + Sync { direction: FftDirection, ) -> Self; + /// De-interleaves the complex numbers in a and b so that all of the reals are in one vector, and all the imaginaries are in the other + unsafe fn deinterleave(a: Self, b: Self) -> Deinterleaved; + + /// Interleaves the provided real and imaginary values into interleaved complex numbers + unsafe fn interleave(values: Deinterleaved) -> (Self, Self); + /// Pairwise multiply the complex numbers in `left` with the complex numbers in `right`. 
unsafe fn mul_complex(left: Self, right: Self) -> Self; @@ -207,6 +219,24 @@ impl SseVector for __m128 { _mm_storeh_pd(ptr as *mut f64, _mm_castps_pd(data)); } + #[inline(always)] + unsafe fn add(a: Self, b: Self) -> Self { + _mm_add_ps(a, b) + } + #[inline(always)] + unsafe fn sub(a: Self, b: Self) -> Self { + _mm_sub_ps(a, b) + } + #[inline(always)] + unsafe fn mul(a: Self, b: Self) -> Self { + _mm_mul_ps(a, b) + } + #[inline(always)] + unsafe fn neg(a: Self) -> Self { + let neg_vector = _mm_set1_ps(-0.0); + _mm_xor_ps(a, neg_vector) + } + #[inline(always)] unsafe fn make_mixedradix_twiddle_chunk( x: usize, @@ -222,6 +252,21 @@ impl SseVector for __m128 { twiddle_chunk.as_slice().load_complex(0) } + #[inline(always)] + unsafe fn deinterleave(a: Self, b: Self) -> Deinterleaved { + Deinterleaved { + re: _mm_shuffle_ps(a, b, 0x88), + im: _mm_shuffle_ps(a, b, 0xDD), + } + } + #[inline(always)] + unsafe fn interleave(values: Deinterleaved) -> (Self, Self) { + ( + _mm_unpacklo_ps(values.re, values.im), + _mm_unpackhi_ps(values.re, values.im), + ) + } + #[inline(always)] unsafe fn mul_complex(left: Self, right: Self) -> Self { //SSE3, taken from Intel performance manual @@ -308,6 +353,24 @@ impl SseVector for __m128d { unimplemented!("Impossible to do a partial store of complex f64's"); } + #[inline(always)] + unsafe fn add(a: Self, b: Self) -> Self { + _mm_add_pd(a, b) + } + #[inline(always)] + unsafe fn sub(a: Self, b: Self) -> Self { + _mm_sub_pd(a, b) + } + #[inline(always)] + unsafe fn mul(a: Self, b: Self) -> Self { + _mm_mul_pd(a, b) + } + #[inline(always)] + unsafe fn neg(a: Self) -> Self { + let neg_vector = _mm_set1_pd(-0.0); + _mm_xor_pd(a, neg_vector) + } + #[inline(always)] unsafe fn make_mixedradix_twiddle_chunk( x: usize, @@ -323,6 +386,21 @@ impl SseVector for __m128d { twiddle_chunk.as_slice().load_complex(0) } + #[inline(always)] + unsafe fn deinterleave(a: Self, b: Self) -> Deinterleaved { + Deinterleaved { + re: _mm_unpacklo_pd(a, b), + im: _mm_unpackhi_pd(a, b), + } + } + #[inline(always)] + unsafe fn interleave(values: Deinterleaved) -> (Self, Self) { + ( + _mm_unpacklo_pd(values.re, values.im), + _mm_unpackhi_pd(values.re, values.im), + ) + } + #[inline(always)] unsafe fn mul_complex(left: Self, right: Self) -> Self { // SSE3, taken from Intel performance manual @@ -491,3 +569,112 @@ where self.output.store_partial_hi_complex(vector, index); } } + +#[derive(Copy, Clone, Debug)] +pub struct Deinterleaved { + pub re: V, + pub im: V, +} + +impl Deinterleaved { + pub unsafe fn butterfly2(values: [Self; 2]) -> [Self; 2] { + let tmp0 = Self { + re: SseVector::add(values[0].re, values[1].re), + im: SseVector::add(values[0].im, values[1].im), + }; + let tmp1 = Self { + re: SseVector::sub(values[0].re, values[1].re), + im: SseVector::sub(values[0].im, values[1].im), + }; + [tmp0, tmp1] + } + pub unsafe fn butterfly4(values: [Self; 4], direction: FftDirection) -> [Self; 4] { + let tmp0 = Self::butterfly2([values[0], values[2]]); + let mut tmp1 = Self::butterfly2([values[1], values[3]]); + + // butterflies for butterfly4 is just swapping the re and im of tmp1[1], then negating one or the other based on direction + if direction == FftDirection::Forward { + let tmp = tmp1[1].re; + tmp1[1].re = tmp1[1].im; + tmp1[1].im = SseVector::neg(tmp); + } else { + let tmp = tmp1[1].im; + tmp1[1].im = tmp1[1].re; + tmp1[1].re = SseVector::neg(tmp); + } + + let out0 = Self::butterfly2([tmp0[0], tmp1[0]]); + let out1 = Self::butterfly2([tmp0[1], tmp1[1]]); + + [out0[0], out1[0], out0[1], out1[1]] + 
} + pub unsafe fn mul_complex(a: Self, b: Self) -> Self { + Self { + re: SseVector::sub(SseVector::mul(a.re, b.re), SseVector::mul(a.im, b.im)), + im: SseVector::add(SseVector::mul(a.re, b.im), SseVector::mul(a.im, b.re)), + } + } +} + +#[cfg(test)] +mod unit_tests { + use std::arch::x86_64::{_mm_store_pd, _mm_store_ps}; + + use num_complex::Complex; + use num_traits::Zero; + + use super::{SseArray, SseArrayMut, SseVector}; + + #[test] + fn test_interleave() { + unsafe { + let data_complex: &[Complex] = &[ + Complex { re: 0.0, im: 0.5 }, + Complex { re: 1.0, im: 1.5 }, + Complex { re: 2.0, im: 2.5 }, + Complex { re: 3.0, im: 3.5 }, + ]; + + let deinterleaved = + SseVector::deinterleave(data_complex.load_complex(0), data_complex.load_complex(2)); + + let mut deinterleaved_re = [0.0f32; 4]; + let mut deinterleaved_im = [0.0f32; 4]; + _mm_store_ps(deinterleaved_re.as_mut_ptr(), deinterleaved.re); + _mm_store_ps(deinterleaved_im.as_mut_ptr(), deinterleaved.im); + + assert_eq!(deinterleaved_re, [0.0, 1.0, 2.0, 3.0]); + assert_eq!(deinterleaved_im, [0.5, 1.5, 2.5, 3.5]); + + let (reinterleaved_a, reinterleaved_b) = SseVector::interleave(deinterleaved); + let mut reinterleaved_complex: &mut [Complex] = &mut [Complex::zero(); 4]; + reinterleaved_complex.store_complex(reinterleaved_a, 0); + reinterleaved_complex.store_complex(reinterleaved_b, 2); + + assert_eq!(data_complex, reinterleaved_complex); + } + + unsafe { + let data_complex: &[Complex] = + &[Complex { re: 0.0, im: 0.5 }, Complex { re: 1.0, im: 1.5 }]; + + let deinterleaved = + SseVector::deinterleave(data_complex.load_complex(0), data_complex.load_complex(1)); + + let mut deinterleaved_re = [0.0f64; 2]; + let mut deinterleaved_im = [0.0f64; 2]; + _mm_store_pd(deinterleaved_re.as_mut_ptr(), deinterleaved.re); + _mm_store_pd(deinterleaved_im.as_mut_ptr(), deinterleaved.im); + + assert_eq!(deinterleaved_re, [0.0, 1.0]); + assert_eq!(deinterleaved_im, [0.5, 1.5]); + + let (reinterleaved_a, reinterleaved_b) = SseVector::interleave(deinterleaved); + let mut reinterleaved_complex: &mut [Complex] = &mut [Complex::zero(); 2]; + reinterleaved_complex.store_complex(reinterleaved_a, 0); + reinterleaved_complex.store_complex(reinterleaved_b, 1); + + assert_eq!(data_complex, reinterleaved_complex); + } + } +}
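As a scalar reference for what one column of `butterfly_4` computes: apply the three per-row twiddles, then a length-4 DFT down the column. This sketch uses `num_complex` only and mirrors the `[out0[0], out1[0], out0[1], out1[1]]` ordering returned by `Deinterleaved::butterfly4`; it is a possible test oracle, not the SIMD code path.

```rust
use num_complex::Complex;

fn column_butterfly4_scalar(
    mut col: [Complex<f32>; 4],
    twiddles: [Complex<f32>; 3],
    forward: bool,
) -> [Complex<f32>; 4] {
    // Twiddle rows 1..=3 (row 0 always has twiddle 1).
    col[1] *= twiddles[0];
    col[2] *= twiddles[1];
    col[3] *= twiddles[2];

    // Radix-4 butterfly: two radix-2 stages plus rotating one term by -i/+i.
    let (a, b) = (col[0] + col[2], col[0] - col[2]);
    let (c, d) = (col[1] + col[3], col[1] - col[3]);
    let d = if forward {
        Complex::new(d.im, -d.re) // multiply by -i
    } else {
        Complex::new(-d.im, d.re) // multiply by +i
    };
    [a + c, b + d, a - c, b - d]
}
```

Comparing this against `Deinterleaved::mul_complex` plus `Deinterleaved::butterfly4` applied to deinterleaved inputs would make a natural companion to the interleave/deinterleave test above.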