diff --git a/src/engine/engine_avx2.rs b/src/engine/engine_avx2.rs
index 40174e5..e2b7830 100644
--- a/src/engine/engine_avx2.rs
+++ b/src/engine/engine_avx2.rs
@@ -93,17 +93,64 @@ impl Default for Avx2 {
 //
 //
 
+#[derive(Copy, Clone)]
+struct LutAvx2 {
+    t0_lo: __m256i,
+    t1_lo: __m256i,
+    t2_lo: __m256i,
+    t3_lo: __m256i,
+    t0_hi: __m256i,
+    t1_hi: __m256i,
+    t2_hi: __m256i,
+    t3_hi: __m256i,
+}
+
+impl From<&Multiply128lutT> for LutAvx2 {
+    #[inline(always)]
+    fn from(lut: &Multiply128lutT) -> Self {
+        unsafe {
+            LutAvx2 {
+                t0_lo: _mm256_broadcastsi128_si256(_mm_loadu_si128(
+                    &lut.lo[0] as *const u128 as *const __m128i,
+                )),
+                t1_lo: _mm256_broadcastsi128_si256(_mm_loadu_si128(
+                    &lut.lo[1] as *const u128 as *const __m128i,
+                )),
+                t2_lo: _mm256_broadcastsi128_si256(_mm_loadu_si128(
+                    &lut.lo[2] as *const u128 as *const __m128i,
+                )),
+                t3_lo: _mm256_broadcastsi128_si256(_mm_loadu_si128(
+                    &lut.lo[3] as *const u128 as *const __m128i,
+                )),
+                t0_hi: _mm256_broadcastsi128_si256(_mm_loadu_si128(
+                    &lut.hi[0] as *const u128 as *const __m128i,
+                )),
+                t1_hi: _mm256_broadcastsi128_si256(_mm_loadu_si128(
+                    &lut.hi[1] as *const u128 as *const __m128i,
+                )),
+                t2_hi: _mm256_broadcastsi128_si256(_mm_loadu_si128(
+                    &lut.hi[2] as *const u128 as *const __m128i,
+                )),
+                t3_hi: _mm256_broadcastsi128_si256(_mm_loadu_si128(
+                    &lut.hi[3] as *const u128 as *const __m128i,
+                )),
+            }
+        }
+    }
+}
+
 impl Avx2 {
     #[target_feature(enable = "avx2")]
     unsafe fn mul_avx2(&self, x: &mut [[u8; 64]], log_m: GfElement) {
         let lut = &self.mul128[log_m as usize];
+        let lut_avx2 = LutAvx2::from(lut);
 
         for chunk in x.iter_mut() {
             let x_ptr = chunk.as_mut_ptr() as *mut __m256i;
             unsafe {
                 let x_lo = _mm256_loadu_si256(x_ptr);
                 let x_hi = _mm256_loadu_si256(x_ptr.add(1));
-                let (prod_lo, prod_hi) = Self::mul_256(x_lo, x_hi, lut);
+                let (prod_lo, prod_hi) = Self::mul_256(x_lo, x_hi, lut_avx2);
                 _mm256_storeu_si256(x_ptr, prod_lo);
                 _mm256_storeu_si256(x_ptr.add(1), prod_hi);
             }
@@ -112,54 +159,28 @@ impl Avx2 {
 
     // Implementation of LEO_MUL_256
     #[inline(always)]
-    fn mul_256(value_lo: __m256i, value_hi: __m256i, lut: &Multiply128lutT) -> (__m256i, __m256i) {
+    fn mul_256(value_lo: __m256i, value_hi: __m256i, lut_avx2: LutAvx2) -> (__m256i, __m256i) {
         let mut prod_lo: __m256i;
         let mut prod_hi: __m256i;
 
         unsafe {
-            let t0_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.lo[0] as *const u128 as *const __m128i,
-            ));
-            let t1_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.lo[1] as *const u128 as *const __m128i,
-            ));
-            let t2_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.lo[2] as *const u128 as *const __m128i,
-            ));
-            let t3_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.lo[3] as *const u128 as *const __m128i,
-            ));
-
-            let t0_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.hi[0] as *const u128 as *const __m128i,
-            ));
-            let t1_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.hi[1] as *const u128 as *const __m128i,
-            ));
-            let t2_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.hi[2] as *const u128 as *const __m128i,
-            ));
-            let t3_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.hi[3] as *const u128 as *const __m128i,
-            ));
-
             let clr_mask = _mm256_set1_epi8(0x0f);
 
             let data_0 = _mm256_and_si256(value_lo, clr_mask);
-            prod_lo = _mm256_shuffle_epi8(t0_lo, data_0);
-            prod_hi = _mm256_shuffle_epi8(t0_hi, data_0);
+            prod_lo = _mm256_shuffle_epi8(lut_avx2.t0_lo, data_0);
+            prod_hi = _mm256_shuffle_epi8(lut_avx2.t0_hi, data_0);
 
             let data_1 = _mm256_and_si256(_mm256_srli_epi64(value_lo, 4), clr_mask);
-            prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(t1_lo, data_1));
-            prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(t1_hi, data_1));
+            prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(lut_avx2.t1_lo, data_1));
+            prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(lut_avx2.t1_hi, data_1));
 
             let data_0 = _mm256_and_si256(value_hi, clr_mask);
-            prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(t2_lo, data_0));
-            prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(t2_hi, data_0));
+            prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(lut_avx2.t2_lo, data_0));
+            prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(lut_avx2.t2_hi, data_0));
 
             let data_1 = _mm256_and_si256(_mm256_srli_epi64(value_hi, 4), clr_mask);
-            prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(t3_lo, data_1));
-            prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(t3_hi, data_1));
+            prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(lut_avx2.t3_lo, data_1));
+            prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(lut_avx2.t3_hi, data_1));
         }
 
         (prod_lo, prod_hi)
@@ -173,9 +194,9 @@ impl Avx2 {
         mut x_hi: __m256i,
         y_lo: __m256i,
        y_hi: __m256i,
-        lut: &Multiply128lutT,
+        lut_avx2: LutAvx2,
     ) -> (__m256i, __m256i) {
-        let (prod_lo, prod_hi) = Self::mul_256(y_lo, y_hi, lut);
+        let (prod_lo, prod_hi) = Self::mul_256(y_lo, y_hi, lut_avx2);
         unsafe {
             x_lo = _mm256_xor_si256(x_lo, prod_lo);
             x_hi = _mm256_xor_si256(x_hi, prod_hi);
@@ -190,10 +211,10 @@ impl Avx2 {
 impl Avx2 {
     // Implementation of LEO_FFTB_256
     #[inline(always)]
-    fn fftb_256(&self, x: &mut [u8; 64], y: &mut [u8; 64], log_m: GfElement) {
-        let lut = &self.mul128[log_m as usize];
+    fn fftb_256(&self, x: &mut [u8; 64], y: &mut [u8; 64], lut_avx2: LutAvx2) {
         let x_ptr = x.as_mut_ptr() as *mut __m256i;
         let y_ptr = y.as_mut_ptr() as *mut __m256i;
+
         unsafe {
             let mut x_lo = _mm256_loadu_si256(x_ptr);
             let mut x_hi = _mm256_loadu_si256(x_ptr.add(1));
@@ -201,7 +222,7 @@ impl Avx2 {
             let mut y_lo = _mm256_loadu_si256(y_ptr);
             let mut y_hi = _mm256_loadu_si256(y_ptr.add(1));
 
-            (x_lo, x_hi) = Self::muladd_256(x_lo, x_hi, y_lo, y_hi, lut);
+            (x_lo, x_hi) = Self::muladd_256(x_lo, x_hi, y_lo, y_hi, lut_avx2);
 
             _mm256_storeu_si256(x_ptr, x_lo);
             _mm256_storeu_si256(x_ptr.add(1), x_hi);
@@ -217,8 +238,11 @@ impl Avx2 {
     // Partial butterfly, caller must do `GF_MODULUS` check with `xor`.
     #[inline(always)]
     fn fft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
+        let lut = &self.mul128[log_m as usize];
+        let lut_avx2 = LutAvx2::from(lut);
+
         for (x_chunk, y_chunk) in zip(x.iter_mut(), y.iter_mut()) {
-            self.fftb_256(x_chunk, y_chunk, log_m);
+            self.fftb_256(x_chunk, y_chunk, lut_avx2);
         }
     }
 
@@ -331,8 +355,7 @@ impl Avx2 {
 impl Avx2 {
     // Implementation of LEO_IFFTB_256
     #[inline(always)]
-    fn ifftb_256(&self, x: &mut [u8; 64], y: &mut [u8; 64], log_m: GfElement) {
-        let lut = &self.mul128[log_m as usize];
+    fn ifftb_256(&self, x: &mut [u8; 64], y: &mut [u8; 64], lut_avx2: LutAvx2) {
         let x_ptr = x.as_mut_ptr() as *mut __m256i;
         let y_ptr = y.as_mut_ptr() as *mut __m256i;
 
@@ -349,7 +372,7 @@ impl Avx2 {
             _mm256_storeu_si256(y_ptr, y_lo);
             _mm256_storeu_si256(y_ptr.add(1), y_hi);
 
-            (x_lo, x_hi) = Self::muladd_256(x_lo, x_hi, y_lo, y_hi, lut);
+            (x_lo, x_hi) = Self::muladd_256(x_lo, x_hi, y_lo, y_hi, lut_avx2);
 
             _mm256_storeu_si256(x_ptr, x_lo);
             _mm256_storeu_si256(x_ptr.add(1), x_hi);
@@ -358,8 +381,11 @@ impl Avx2 {
 
     #[inline(always)]
     fn ifft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
+        let lut = &self.mul128[log_m as usize];
+        let lut_avx2 = LutAvx2::from(lut);
+
         for (x_chunk, y_chunk) in zip(x.iter_mut(), y.iter_mut()) {
-            self.ifftb_256(x_chunk, y_chunk, log_m);
+            self.ifftb_256(x_chunk, y_chunk, lut_avx2);
         }
     }
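
The net effect of the diff is that the eight `_mm256_broadcastsi128_si256(_mm_loadu_si128(...))` table loads are hoisted out of the per-chunk loops: a `LutAvx2` is built once per `log_m` and then passed by value into `mul_256` / `muladd_256` for every 64-byte chunk. The following is a minimal, self-contained sketch of that hoisting pattern, not this crate's API; `broadcast_lut`, `shuffle_lo_nibbles`, the single 16-entry table, and the 32-byte chunk size are all illustrative assumptions.

#[cfg(target_arch = "x86_64")]
mod lut_hoist_sketch {
    use std::arch::x86_64::*;

    // Broadcast a 16-byte lookup table into both 128-bit lanes once, outside the data loop.
    // (Hypothetical helper, mirroring what `LutAvx2::from` does for all eight tables.)
    #[target_feature(enable = "avx2")]
    pub unsafe fn broadcast_lut(table: &[u8; 16]) -> __m256i {
        _mm256_broadcastsi128_si256(_mm_loadu_si128(table.as_ptr() as *const __m128i))
    }

    // Reuse the broadcast table for every chunk: each byte's low nibble indexes the table
    // via `_mm256_shuffle_epi8`, the same primitive `mul_256` builds its GF product from.
    #[target_feature(enable = "avx2")]
    pub unsafe fn shuffle_lo_nibbles(lut: __m256i, data: &mut [[u8; 32]]) {
        let clr_mask = _mm256_set1_epi8(0x0f);
        for chunk in data.iter_mut() {
            let ptr = chunk.as_mut_ptr() as *mut __m256i;
            let v = _mm256_loadu_si256(ptr);
            let lo_nibbles = _mm256_and_si256(v, clr_mask);
            _mm256_storeu_si256(ptr, _mm256_shuffle_epi8(lut, lo_nibbles));
        }
    }
}

Since `LutAvx2` derives `Copy`, passing it by value is cheap, and the compiler can keep the broadcast vectors in registers across the butterfly loops instead of re-reading `Multiply128lutT` for every chunk, which is presumably the motivation for the change.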