diff --git a/hexl/ntt/inv-ntt-avx512.cpp b/hexl/ntt/inv-ntt-avx512.cpp index a94d4e91..fd672e39 100644 --- a/hexl/ntt/inv-ntt-avx512.cpp +++ b/hexl/ntt/inv-ntt-avx512.cpp @@ -64,6 +64,7 @@ template void InverseTransformFromBitReverseAVX512( template inline void InvButterfly(__m512i* X, __m512i* Y, __m512i W, __m512i W_precon, __m512i neg_modulus, __m512i twice_modulus) { + // Compute T first to allow in-place update of X __m512i Y_minus_2q = _mm512_sub_epi64(*Y, twice_modulus); __m512i T = _mm512_sub_epi64(*X, Y_minus_2q); @@ -71,6 +72,10 @@ inline void InvButterfly(__m512i* X, __m512i* Y, __m512i W, __m512i W_precon, // No need for modulus reduction, since inputs are in [0, q) *X = _mm512_add_epi64(*X, *Y); } else { + // Algorithm 3 computes (X >= 2q) ? (X - 2q) : X + // We instead compute (X - 2q >= 0) ? (X - 2q) : X + // This allows us to use the faster _mm512_movepi64_mask rather than + // _mm512_cmp_epu64_mask to create the mask. *X = _mm512_add_epi64(*X, Y_minus_2q); __mmask8 sign_bits = _mm512_movepi64_mask(*X); *X = _mm512_mask_add_epi64(*X, sign_bits, *X, twice_modulus);