avx512: Avoid continuously reloading the look up table
AndersTrier committed Oct 9, 2024
1 parent 945c03e commit 78224ac
Showing 1 changed file with 60 additions and 28 deletions.
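In essence, the commit hoists the lookup-table expansion out of the hot loop: the broadcast-and-insert sequence that builds the four 512-bit tables now runs once per call (in the new `LutAvx512::from`) instead of once per 64-byte chunk inside `mul_512`, and the resulting `Copy` struct is passed by value so the tables can stay in registers. Below is a minimal sketch of that pattern with plain byte arrays standing in for the `__m512i` registers; `Lut`, `ExpandedLut`, `mul_chunks`, and `mul_chunk` are illustrative stand-ins, not the crate's actual types or functions.

    // Sketch only: scalar stand-ins for the AVX-512 registers and intrinsics.

    /// Stand-in for the per-log_m table (the real code reads self.mul128[log_m as usize]).
    struct Lut {
        lo: [u8; 16],
        hi: [u8; 16],
    }

    /// Stand-in for LutAvx512: the table expanded once into a reusable, Copy form.
    #[derive(Copy, Clone)]
    struct ExpandedLut {
        lo: [u8; 16],
        hi: [u8; 16],
    }

    impl From<&Lut> for ExpandedLut {
        fn from(lut: &Lut) -> Self {
            // In the real code this is the broadcast + _mm512_inserti64x4 work;
            // after the commit it runs once per call rather than once per chunk.
            ExpandedLut { lo: lut.lo, hi: lut.hi }
        }
    }

    /// After the commit: expand the table once, then reuse it for every chunk.
    fn mul_chunks(chunks: &mut [[u8; 64]], lut: &Lut) {
        let expanded = ExpandedLut::from(lut); // hoisted out of the loop
        for chunk in chunks.iter_mut() {
            mul_chunk(chunk, expanded);
        }
    }

    /// Stand-in for mul_512: combine low- and high-nibble lookups with XOR.
    fn mul_chunk(chunk: &mut [u8; 64], lut: ExpandedLut) {
        for b in chunk.iter_mut() {
            *b = lut.lo[(*b & 0x0f) as usize] ^ lut.hi[(*b >> 4) as usize];
        }
    }

Before the change, the equivalent of `ExpandedLut::from` ran inside `mul_512`, so every chunk (and every `muladd_512` call in the FFT butterflies) rebuilt the tables from memory; passing the prebuilt struct by value is what the commit title means by no longer continuously reloading the lookup table.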
src/engine/engine_avx512.rs (88 changes: 60 additions & 28 deletions)
@@ -93,24 +93,22 @@ impl Default for Avx512 {
 //
 //

-impl Avx512 {
-    #[target_feature(enable = "avx512f,avx512vl,avx512bw")]
-    unsafe fn mul_avx512(&self, x: &mut [[u8; 64]], log_m: GfElement) {
-        let lut = &self.mul128[log_m as usize];
-
-        for chunk in x.iter_mut() {
-            let x_ptr = chunk.as_mut_ptr() as *mut i32;
-            unsafe {
-                let x = _mm512_loadu_si512(x_ptr);
-                let prod = Self::mul_512(x, lut);
-                _mm512_storeu_si512(x_ptr, prod);
-            }
-        }
-    }
+#[derive(Copy, Clone)]
+struct LutAvx512 {
+    t0_t2_lo: __m512i,
+    t0_t2_hi: __m512i,
+    t1_t3_lo: __m512i,
+    t1_t3_hi: __m512i,
+}

-    // Impelemntation of LEO_MUL_256
+impl From<&Multiply128lutT> for LutAvx512 {
     #[inline(always)]
-    fn mul_512(value: __m512i, lut: &Multiply128lutT) -> __m512i {
+    fn from(lut: &Multiply128lutT) -> Self {
+        let t0_t2_lo: __m512i;
+        let t0_t2_hi: __m512i;
+        let t1_t3_lo: __m512i;
+        let t1_t3_hi: __m512i;
+
         unsafe {
             let t0_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
                 &lut.lo[0] as *const u128 as *const __m128i,
@@ -138,20 +136,53 @@ impl Avx512 {
                 &lut.hi[3] as *const u128 as *const __m128i,
             ));

-            let t0_t2_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t0_lo), t2_lo, 1);
-            let t0_t2_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t0_hi), t2_hi, 1);
-            let t1_t3_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t1_lo), t3_lo, 1);
-            let t1_t3_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t1_hi), t3_hi, 1);
+            t0_t2_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t0_lo), t2_lo, 1);
+            t0_t2_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t0_hi), t2_hi, 1);
+            t1_t3_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t1_lo), t3_lo, 1);
+            t1_t3_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t1_hi), t3_hi, 1);
+        }

+        LutAvx512 {
+            t0_t2_lo,
+            t0_t2_hi,
+            t1_t3_lo,
+            t1_t3_hi,
+        }
+    }
+}
+
+impl Avx512 {
+    #[target_feature(enable = "avx512f,avx512vl,avx512bw")]
+    unsafe fn mul_avx512(&self, x: &mut [[u8; 64]], log_m: GfElement) {
+        let lut = &self.mul128[log_m as usize];
+
+        let lut_avx512 = LutAvx512::from(lut);
+
+        for chunk in x.iter_mut() {
+            let x_ptr = chunk.as_mut_ptr() as *mut i32;
+            unsafe {
+                let x = _mm512_loadu_si512(x_ptr);
+                let prod = Self::mul_512(x, lut_avx512);
+                _mm512_storeu_si512(x_ptr, prod);
+            }
+        }
+    }
+
+    // Impelemntation of LEO_MUL_256
+    #[inline(always)]
+    fn mul_512(value: __m512i, lut_avx512: LutAvx512) -> __m512i {
+        unsafe {
             let clr_mask = _mm512_set1_epi8(0x0f);

             let data = _mm512_and_si512(value, clr_mask);
-            let mut prod_lo_512 = _mm512_shuffle_epi8(t0_t2_lo, data);
-            let mut prod_hi_512 = _mm512_shuffle_epi8(t0_t2_hi, data);
+            let mut prod_lo_512 = _mm512_shuffle_epi8(lut_avx512.t0_t2_lo, data);
+            let mut prod_hi_512 = _mm512_shuffle_epi8(lut_avx512.t0_t2_hi, data);

             let data = _mm512_and_si512(_mm512_srli_epi64(value, 4), clr_mask);
-            prod_lo_512 = _mm512_xor_si512(prod_lo_512, _mm512_shuffle_epi8(t1_t3_lo, data));
-            prod_hi_512 = _mm512_xor_si512(prod_hi_512, _mm512_shuffle_epi8(t1_t3_hi, data));
+            prod_lo_512 =
+                _mm512_xor_si512(prod_lo_512, _mm512_shuffle_epi8(lut_avx512.t1_t3_lo, data));
+            prod_hi_512 =
+                _mm512_xor_si512(prod_hi_512, _mm512_shuffle_epi8(lut_avx512.t1_t3_hi, data));

             // XOR first half with second half of vector
             let prod_lo = _mm256_xor_si256(
@@ -169,9 +200,8 @@ impl Avx512 {

     //// {x_lo, x_hi} ^= {y_lo, y_hi} * log_m
     // Implementation of LEO_MULADD_256
-    #[allow(clippy::too_many_arguments)]
     #[inline(always)]
-    fn muladd_512(x: __m512i, y: __m512i, lut: &Multiply128lutT) -> __m512i {
+    fn muladd_512(x: __m512i, y: __m512i, lut: LutAvx512) -> __m512i {
         unsafe {
             let prod = Self::mul_512(y, lut);
             _mm512_xor_si512(x, prod)
@@ -188,6 +218,7 @@ impl Avx512 {
     #[inline(always)]
     fn fft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
         let lut = &self.mul128[log_m as usize];
+        let lut_avx512 = LutAvx512::from(lut);

         for (x_chunk, y_chunk) in zip(x.iter_mut(), y.iter_mut()) {
             let x_ptr = x_chunk.as_mut_ptr() as *mut i32;
@@ -197,7 +228,7 @@ impl Avx512 {
                 let mut x = _mm512_loadu_si512(x_ptr);
                 let mut y = _mm512_loadu_si512(y_ptr);

-                x = Self::muladd_512(x, y, lut);
+                x = Self::muladd_512(x, y, lut_avx512);
                 y = _mm512_xor_si512(y, x);

                 _mm512_storeu_si512(x_ptr, x);
Expand Down Expand Up @@ -317,6 +348,7 @@ impl Avx512 {
#[inline(always)]
fn ifft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
let lut = &self.mul128[log_m as usize];
let lut_avx512 = LutAvx512::from(lut);

for (x_chunk, y_chunk) in zip(&mut x.iter_mut(), &mut y.iter_mut()) {
let x_ptr = x_chunk.as_mut_ptr() as *mut i32;
@@ -327,7 +359,7 @@ impl Avx512 {
                 let mut y = _mm512_loadu_si512(y_ptr);

                 y = _mm512_xor_si512(y, x);
-                x = Self::muladd_512(x, y, lut);
+                x = Self::muladd_512(x, y, lut_avx512);

                 _mm512_storeu_si512(x_ptr, x);
                 _mm512_storeu_si512(y_ptr, y);
