diff --git a/BIBLIOGRAPHY.md b/BIBLIOGRAPHY.md index b8954d05..e8a264d6 100644 --- a/BIBLIOGRAPHY.md +++ b/BIBLIOGRAPHY.md @@ -150,6 +150,10 @@ source code and documentation. - [mldsa/native/x86_64/src/intt.S](mldsa/native/x86_64/src/intt.S) - [mldsa/native/x86_64/src/ntt.S](mldsa/native/x86_64/src/ntt.S) - [mldsa/native/x86_64/src/nttunpack.S](mldsa/native/x86_64/src/nttunpack.S) + - [mldsa/native/x86_64/src/pointwise.S](mldsa/native/x86_64/src/pointwise.S) + - [mldsa/native/x86_64/src/pointwise_acc_l4.S](mldsa/native/x86_64/src/pointwise_acc_l4.S) + - [mldsa/native/x86_64/src/pointwise_acc_l5.S](mldsa/native/x86_64/src/pointwise_acc_l5.S) + - [mldsa/native/x86_64/src/pointwise_acc_l7.S](mldsa/native/x86_64/src/pointwise_acc_l7.S) - [mldsa/native/x86_64/src/poly_caddq_avx2.c](mldsa/native/x86_64/src/poly_caddq_avx2.c) - [mldsa/native/x86_64/src/poly_chknorm_avx2.c](mldsa/native/x86_64/src/poly_chknorm_avx2.c) - [mldsa/native/x86_64/src/poly_decompose_32_avx2.c](mldsa/native/x86_64/src/poly_decompose_32_avx2.c) diff --git a/mldsa/native/aarch64/meta.h b/mldsa/native/aarch64/meta.h index 3c328a03..38b0d513 100644 --- a/mldsa/native/aarch64/meta.h +++ b/mldsa/native/aarch64/meta.h @@ -21,6 +21,10 @@ #define MLD_USE_NATIVE_POLY_CHKNORM #define MLD_USE_NATIVE_POLYZ_UNPACK_17 #define MLD_USE_NATIVE_POLYZ_UNPACK_19 +#define MLD_USE_NATIVE_POINTWISE_MONTGOMERY +#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 +#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 +#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 /* Identifier for this backend so that source and assembly files * in the build can be appropriately guarded. */ @@ -147,5 +151,36 @@ static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r, mld_polyz_unpack_19_asm(r, buf, mld_polyz_unpack_19_indices); } +static MLD_INLINE void mld_poly_pointwise_montgomery_native( + int32_t out[MLDSA_N], const int32_t in0[MLDSA_N], + const int32_t in1[MLDSA_N]) +{ + mld_poly_pointwise_montgomery_asm(out, in0, in1); +} + +static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native( + int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N], + const int32_t v[4][MLDSA_N]) +{ + mld_polyvecl_pointwise_acc_montgomery_l4_asm(w, (const int32_t *)u, + (const int32_t *)v); +} + +static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native( + int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N], + const int32_t v[5][MLDSA_N]) +{ + mld_polyvecl_pointwise_acc_montgomery_l5_asm(w, (const int32_t *)u, + (const int32_t *)v); +} + +static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native( + int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N], + const int32_t v[7][MLDSA_N]) +{ + mld_polyvecl_pointwise_acc_montgomery_l7_asm(w, (const int32_t *)u, + (const int32_t *)v); +} + #endif /* !__ASSEMBLER__ */ #endif /* !MLD_NATIVE_AARCH64_META_H */ diff --git a/mldsa/native/aarch64/src/arith_native_aarch64.h b/mldsa/native/aarch64/src/arith_native_aarch64.h index 92db2a51..c55a205c 100644 --- a/mldsa/native/aarch64/src/arith_native_aarch64.h +++ b/mldsa/native/aarch64/src/arith_native_aarch64.h @@ -93,4 +93,24 @@ void mld_polyz_unpack_17_asm(int32_t *r, const uint8_t *buf, void mld_polyz_unpack_19_asm(int32_t *r, const uint8_t *buf, const uint8_t *indices); +#define mld_poly_pointwise_montgomery_asm \ + MLD_NAMESPACE(poly_pointwise_montgomery_asm) +void mld_poly_pointwise_montgomery_asm(int32_t *, const int32_t *, + const int32_t *); + +#define mld_polyvecl_pointwise_acc_montgomery_l4_asm \ + 
MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l4_asm) +void mld_polyvecl_pointwise_acc_montgomery_l4_asm(int32_t *, const int32_t *, + const int32_t *); + +#define mld_polyvecl_pointwise_acc_montgomery_l5_asm \ + MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l5_asm) +void mld_polyvecl_pointwise_acc_montgomery_l5_asm(int32_t *, const int32_t *, + const int32_t *); + +#define mld_polyvecl_pointwise_acc_montgomery_l7_asm \ + MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l7_asm) +void mld_polyvecl_pointwise_acc_montgomery_l7_asm(int32_t *, const int32_t *, + const int32_t *); + #endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */ diff --git a/mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l4.S b/mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l4.S new file mode 100644 index 00000000..32a8a9a5 --- /dev/null +++ b/mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l4.S @@ -0,0 +1,126 @@ +/* Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_AARCH64) + +.macro montgomery_reduce_long res, inl, inh + uzp1 t0.4s, \inl\().4s, \inh\().4s + mul t0.4s, t0.4s, modulus_twisted.4s + smlal \inl\().2d, t0.2s, modulus.2s + smlal2 \inh\().2d, t0.4s, modulus.4s + uzp2 \res\().4s, \inl\().4s, \inh\().4s +.endm + +.macro load_polys a, b, a_ptr, b_ptr + ldr q_\()\a, [\a_ptr], #16 + ldr q_\()\b, [\b_ptr], #16 +.endm + +.macro pmull dl, dh, a, b + smull \dl\().2d, \a\().2s, \b\().2s + smull2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro pmlal dl, dh, a, b + smlal \dl\().2d, \a\().2s, \b\().2s + smlal2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +out_ptr .req x0 +a0_ptr .req x1 +b0_ptr .req x2 +a1_ptr .req x3 +b1_ptr .req x4 +a2_ptr .req x5 +b2_ptr .req x6 +a3_ptr .req x7 +b3_ptr .req x8 +count .req x9 +wtmp .req w9 + +modulus .req v0 +modulus_twisted .req v1 + +aa .req v2 +bb .req v3 +res .req v4 +resl .req v5 +resh .req v6 +t0 .req v7 + +q_aa .req q2 +q_bb .req q3 +q_res .req q4 + +.text +.global MLD_ASM_NAMESPACE(polyvecl_pointwise_acc_montgomery_l4_asm) +.balign 4 +MLD_ASM_FN_SYMBOL(polyvecl_pointwise_acc_montgomery_l4_asm) + push_stack + + // load q = 8380417 + movz wtmp, #57345 + movk wtmp, #127, lsl #16 + dup modulus.4s, wtmp + + // load -q^-1 = 4236238847 + movz wtmp, #57343 + movk wtmp, #64639, lsl #16 + dup modulus_twisted.4s, wtmp + + // Computed bases of vector entries + add a1_ptr, a0_ptr, #(1 * 1024) + add a2_ptr, a0_ptr, #(2 * 1024) + add a3_ptr, a0_ptr, #(3 * 1024) + + add b1_ptr, b0_ptr, #(1 * 1024) + add b2_ptr, b0_ptr, #(2 * 1024) + add b3_ptr, b0_ptr, #(3 * 1024) + + mov count, #(MLDSA_N / 4) +l4_loop_start: + load_polys aa, bb, a0_ptr, b0_ptr + pmull resl, resh, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr + pmlal resl, resh, aa, bb + load_polys aa, bb, a2_ptr, b2_ptr + pmlal resl, resh, aa, bb + load_polys aa, bb, a3_ptr, b3_ptr + pmlal resl, resh, aa, bb + + montgomery_reduce_long res, resl, resh + + str q_res, [out_ptr], #16 + + subs count, count, #1 + cbnz count, l4_loop_start + + pop_stack + ret +#endif /* 
MLD_ARITH_BACKEND_AARCH64 */ diff --git a/mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l5.S b/mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l5.S new file mode 100644 index 00000000..eea407e0 --- /dev/null +++ b/mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l5.S @@ -0,0 +1,132 @@ +/* Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_AARCH64) + +.macro montgomery_reduce_long res, inl, inh + uzp1 t0.4s, \inl\().4s, \inh\().4s + mul t0.4s, t0.4s, modulus_twisted.4s + smlal \inl\().2d, t0.2s, modulus.2s + smlal2 \inh\().2d, t0.4s, modulus.4s + uzp2 \res\().4s, \inl\().4s, \inh\().4s +.endm + +.macro load_polys a, b, a_ptr, b_ptr + ldr q_\()\a, [\a_ptr], #16 + ldr q_\()\b, [\b_ptr], #16 +.endm + +.macro pmull dl, dh, a, b + smull \dl\().2d, \a\().2s, \b\().2s + smull2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro pmlal dl, dh, a, b + smlal \dl\().2d, \a\().2s, \b\().2s + smlal2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +out_ptr .req x0 +a0_ptr .req x1 +b0_ptr .req x2 +a1_ptr .req x3 +b1_ptr .req x4 +a2_ptr .req x5 +b2_ptr .req x6 +a3_ptr .req x7 +b3_ptr .req x8 +a4_ptr .req x9 +b4_ptr .req x10 +count .req x11 +wtmp .req w11 + +modulus .req v0 +modulus_twisted .req v1 + +aa .req v2 +bb .req v3 +res .req v4 +resl .req v5 +resh .req v6 +t0 .req v7 + +q_aa .req q2 +q_bb .req q3 +q_res .req q4 + +.text +.global MLD_ASM_NAMESPACE(polyvecl_pointwise_acc_montgomery_l5_asm) +.balign 4 +MLD_ASM_FN_SYMBOL(polyvecl_pointwise_acc_montgomery_l5_asm) + push_stack + + // load q = 8380417 + movz wtmp, #57345 + movk wtmp, #127, lsl #16 + dup modulus.4s, wtmp + + // load -q^-1 = 4236238847 + movz wtmp, #57343 + movk wtmp, #64639, lsl #16 + dup modulus_twisted.4s, wtmp + + // Computed bases of vector entries + add a1_ptr, a0_ptr, #(1 * 1024) + add a2_ptr, a0_ptr, #(2 * 1024) + add a3_ptr, a0_ptr, #(3 * 1024) + add a4_ptr, a0_ptr, #(4 * 1024) + + add b1_ptr, b0_ptr, #(1 * 1024) + add b2_ptr, b0_ptr, #(2 * 1024) + add b3_ptr, b0_ptr, #(3 * 1024) + add b4_ptr, b0_ptr, #(4 * 1024) + + mov count, #(MLDSA_N / 4) +l5_loop_start: + load_polys aa, bb, a0_ptr, b0_ptr + pmull resl, resh, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr + pmlal resl, resh, aa, bb + load_polys aa, bb, a2_ptr, b2_ptr + pmlal resl, resh, aa, bb + load_polys aa, bb, a3_ptr, b3_ptr + pmlal resl, resh, aa, bb + load_polys aa, bb, a4_ptr, b4_ptr + pmlal resl, resh, aa, bb + + montgomery_reduce_long res, resl, resh + + str q_res, [out_ptr], #16 + + subs count, count, #1 + cbnz count, l5_loop_start + + pop_stack + ret +#endif /* MLD_ARITH_BACKEND_AARCH64 */ diff --git a/mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l7.S b/mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l7.S new file mode 100644 index 00000000..153f9a7f --- /dev/null +++ b/mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l7.S @@ -0,0 +1,144 @@ +/* Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +#include "../../../common.h" 
+#if defined(MLD_ARITH_BACKEND_AARCH64) + +.macro montgomery_reduce_long res, inl, inh + uzp1 t0.4s, \inl\().4s, \inh\().4s + mul t0.4s, t0.4s, modulus_twisted.4s + smlal \inl\().2d, t0.2s, modulus.2s + smlal2 \inh\().2d, t0.4s, modulus.4s + uzp2 \res\().4s, \inl\().4s, \inh\().4s +.endm + +.macro load_polys a, b, a_ptr, b_ptr + ldr q_\()\a, [\a_ptr], #16 + ldr q_\()\b, [\b_ptr], #16 +.endm + +.macro pmull dl, dh, a, b + smull \dl\().2d, \a\().2s, \b\().2s + smull2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro pmlal dl, dh, a, b + smlal \dl\().2d, \a\().2s, \b\().2s + smlal2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +out_ptr .req x0 +a0_ptr .req x1 +b0_ptr .req x2 +a1_ptr .req x3 +b1_ptr .req x4 +a2_ptr .req x5 +b2_ptr .req x6 +a3_ptr .req x7 +b3_ptr .req x8 +a4_ptr .req x9 +b4_ptr .req x10 +a5_ptr .req x11 +b5_ptr .req x12 +a6_ptr .req x13 +b6_ptr .req x14 +count .req x15 +wtmp .req w15 + +modulus .req v0 +modulus_twisted .req v1 + +aa .req v2 +bb .req v3 +res .req v4 +resl .req v5 +resh .req v6 +t0 .req v7 + +q_aa .req q2 +q_bb .req q3 +q_res .req q4 + +.text +.global MLD_ASM_NAMESPACE(polyvecl_pointwise_acc_montgomery_l7_asm) +.balign 4 +MLD_ASM_FN_SYMBOL(polyvecl_pointwise_acc_montgomery_l7_asm) + push_stack + + // load q = 8380417 + movz wtmp, #57345 + movk wtmp, #127, lsl #16 + dup modulus.4s, wtmp + + // load -q^-1 = 4236238847 + movz wtmp, #57343 + movk wtmp, #64639, lsl #16 + dup modulus_twisted.4s, wtmp + + // Computed bases of vector entries + add a1_ptr, a0_ptr, #(1 * 1024) + add a2_ptr, a0_ptr, #(2 * 1024) + add a3_ptr, a0_ptr, #(3 * 1024) + add a4_ptr, a0_ptr, #(4 * 1024) + add a5_ptr, a1_ptr, #(4 * 1024) + add a6_ptr, a2_ptr, #(4 * 1024) + + add b1_ptr, b0_ptr, #(1 * 1024) + add b2_ptr, b0_ptr, #(2 * 1024) + add b3_ptr, b0_ptr, #(3 * 1024) + add b4_ptr, b0_ptr, #(4 * 1024) + add b5_ptr, b1_ptr, #(4 * 1024) + add b6_ptr, b2_ptr, #(4 * 1024) + + mov count, #(MLDSA_N / 4) +l7_loop_start: + load_polys aa, bb, a0_ptr, b0_ptr + pmull resl, resh, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr + pmlal resl, resh, aa, bb + load_polys aa, bb, a2_ptr, b2_ptr + pmlal resl, resh, aa, bb + load_polys aa, bb, a3_ptr, b3_ptr + pmlal resl, resh, aa, bb + load_polys aa, bb, a4_ptr, b4_ptr + pmlal resl, resh, aa, bb + load_polys aa, bb, a5_ptr, b5_ptr + pmlal resl, resh, aa, bb + load_polys aa, bb, a6_ptr, b6_ptr + pmlal resl, resh, aa, bb + + montgomery_reduce_long res, resl, resh + + str q_res, [out_ptr], #16 + + subs count, count, #1 + cbnz count, l7_loop_start + + pop_stack + ret +#endif /* MLD_ARITH_BACKEND_AARCH64 */ diff --git a/mldsa/native/aarch64/src/pointwise_montgomery.S b/mldsa/native/aarch64/src/pointwise_montgomery.S new file mode 100644 index 00000000..505783a7 --- /dev/null +++ b/mldsa/native/aarch64/src/pointwise_montgomery.S @@ -0,0 +1,119 @@ +/* Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_AARCH64) + +.macro montgomery_reduce_long res, inl, inh + uzp1 t0.4s, \inl\().4s, \inh\().4s + mul t0.4s, t0.4s, modulus_twisted.4s + smlal \inl\().2d, t0.2s, modulus.2s + 
smlal2 \inh\().2d, t0.4s, modulus.4s + uzp2 \res\().4s, \inl\().4s, \inh\().4s +.endm + + +.macro pmull dl, dh, a, b + smull \dl\().2d, \a\().2s, \b\().2s + smull2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro pmlal dl, dh, a, b + smlal \dl\().2d, \a\().2s, \b\().2s + smlal2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +out_ptr .req x0 +a0_ptr .req x1 +b0_ptr .req x2 +count .req x3 +wtmp .req w3 + +modulus .req v0 +modulus_twisted .req v1 + +aa .req v2 +bb .req v3 +res .req v4 +resl .req v5 +resh .req v6 +t0 .req v7 + +q_aa .req q2 +q_bb .req q3 +q_res .req q4 + +.text +.global MLD_ASM_NAMESPACE(poly_pointwise_montgomery_asm) +.balign 4 +MLD_ASM_FN_SYMBOL(poly_pointwise_montgomery_asm) + push_stack + + // load q = 8380417 + movz wtmp, #57345 + movk wtmp, #127, lsl #16 + dup modulus.4s, wtmp + + // load -q^-1 = 4236238847 + movz wtmp, #57343 + movk wtmp, #64639, lsl #16 + dup modulus_twisted.4s, wtmp + mov count, #(MLDSA_N / 4) +loop_start: + + + ldr q_aa, [a0_ptr], #64 + ldr q_bb, [b0_ptr], #64 + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + str q_res, [out_ptr], #64 + + ldr q_aa, [a0_ptr, #-48] + ldr q_bb, [b0_ptr, #-48] + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + str q_res, [out_ptr, #-48] + + ldr q_aa, [a0_ptr, #-32] + ldr q_bb, [b0_ptr, #-32] + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + str q_res, [out_ptr, #-32] + + ldr q_aa, [a0_ptr, #-16] + ldr q_bb, [b0_ptr, #-16] + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + str q_res, [out_ptr, #-16] + + subs count, count, #4 + cbnz count, loop_start + + pop_stack + ret +#endif /* MLD_ARITH_BACKEND_AARCH64 */ diff --git a/mldsa/native/api.h b/mldsa/native/api.h index d0119da8..fe1ededc 100644 --- a/mldsa/native/api.h +++ b/mldsa/native/api.h @@ -298,4 +298,82 @@ static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a); static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a); #endif /* MLD_USE_NATIVE_POLYZ_UNPACK_19 */ +#if defined(MLD_USE_NATIVE_POINTWISE_MONTGOMERY) +/************************************************* + * Name: mld_poly_pointwise_montgomery_native + * + * Description: Pointwise multiplication of polynomials in NTT domain + * with Montgomery reduction. + * + * Computes c[i] = a[i] * b[i] * R^(-1) mod q for all i, + * where R = 2^32. + * + * Arguments: - int32_t c[MLDSA_N]: pointer to output polynomial + * - const int32_t a[MLDSA_N]: pointer to first input polynomial + * - const int32_t b[MLDSA_N]: pointer to second input polynomial + **************************************************/ +static MLD_INLINE void mld_poly_pointwise_montgomery_native( + int32_t c[MLDSA_N], const int32_t a[MLDSA_N], const int32_t b[MLDSA_N]); +#endif /* MLD_USE_NATIVE_POINTWISE_MONTGOMERY */ + +#if defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4) +/************************************************* + * Name: mld_polyvecl_pointwise_acc_montgomery_l4_native + * + * Description: Native implementation of mld_polyvecl_pointwise_acc_montgomery for L = 4.
+ * Pointwise multiply vectors of polynomials of length L, multiply + * resulting vector by 2^{-32} and add (accumulate) polynomials + * in it. Input/output vectors are in NTT domain representation. + * + * Arguments: - int32_t w[MLDSA_N]: pointer to output polynomial + * - const int32_t u[MLDSA_L][MLDSA_N]: pointer to first input + *vector + * - const int32_t v[MLDSA_L][MLDSA_N]: pointer to second input + *vector + **************************************************/ +static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native( + int32_t w[MLDSA_N], const int32_t u[MLDSA_L][MLDSA_N], + const int32_t v[MLDSA_L][MLDSA_N]); +#endif /* MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 */ + +#if defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5) +/************************************************* + * Name: mld_polyvecl_pointwise_acc_montgomery_l5_native + * + * Description: Native implementation of mld_polyvecl_pointwise_acc_montgomery for L = 5. + * Pointwise multiply vectors of polynomials of length L, multiply + * resulting vector by 2^{-32} and add (accumulate) polynomials + * in it. Input/output vectors are in NTT domain representation. + * + * Arguments: - int32_t w[MLDSA_N]: pointer to output polynomial + * - const int32_t u[MLDSA_L][MLDSA_N]: pointer to first input + *vector + * - const int32_t v[MLDSA_L][MLDSA_N]: pointer to second input + *vector + **************************************************/ +static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native( + int32_t w[MLDSA_N], const int32_t u[MLDSA_L][MLDSA_N], + const int32_t v[MLDSA_L][MLDSA_N]); +#endif /* MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 */ + +#if defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7) +/************************************************* + * Name: mld_polyvecl_pointwise_acc_montgomery_l7_native + * + * Description: Native implementation of mld_polyvecl_pointwise_acc_montgomery for L = 7. + * Pointwise multiply vectors of polynomials of length L, multiply + * resulting vector by 2^{-32} and add (accumulate) polynomials + * in it. Input/output vectors are in NTT domain representation.
+ * + * Arguments: - int32_t w[MLDSA_N]: pointer to output polynomial + * - const int32_t u[MLDSA_L][MLDSA_N]: pointer to first input + *vector + * - const int32_t v[MLDSA_L][MLDSA_N]: pointer to second input + *vector + **************************************************/ +static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native( + int32_t w[MLDSA_N], const int32_t u[MLDSA_L][MLDSA_N], + const int32_t v[MLDSA_L][MLDSA_N]); +#endif /* MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 */ + #endif /* !MLD_NATIVE_API_H */ diff --git a/mldsa/native/x86_64/meta.h b/mldsa/native/x86_64/meta.h index 3feba6be..1dd3b49c 100644 --- a/mldsa/native/x86_64/meta.h +++ b/mldsa/native/x86_64/meta.h @@ -25,6 +25,10 @@ #define MLD_USE_NATIVE_POLY_CHKNORM #define MLD_USE_NATIVE_POLYZ_UNPACK_17 #define MLD_USE_NATIVE_POLYZ_UNPACK_19 +#define MLD_USE_NATIVE_POINTWISE_MONTGOMERY +#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 +#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 +#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 #if !defined(__ASSEMBLER__) #include @@ -151,6 +155,37 @@ static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a) mld_polyz_unpack_19_avx2((__m256i *)r, a); } +static MLD_INLINE void mld_poly_pointwise_montgomery_native( + int32_t c[MLDSA_N], const int32_t a[MLDSA_N], const int32_t b[MLDSA_N]) +{ + mld_pointwise_avx2((__m256i *)c, (const __m256i *)a, (const __m256i *)b, + mld_qdata.vec); +} + +static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native( + int32_t w[MLDSA_N], const int32_t u[MLDSA_L][MLDSA_N], + const int32_t v[MLDSA_L][MLDSA_N]) +{ + mld_pointwise_acc_l4_avx2((__m256i *)w, (const __m256i *)u, + (const __m256i *)v, mld_qdata.vec); +} + +static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native( + int32_t w[MLDSA_N], const int32_t u[MLDSA_L][MLDSA_N], + const int32_t v[MLDSA_L][MLDSA_N]) +{ + mld_pointwise_acc_l5_avx2((__m256i *)w, (const __m256i *)u, + (const __m256i *)v, mld_qdata.vec); +} + +static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native( + int32_t w[MLDSA_N], const int32_t u[MLDSA_L][MLDSA_N], + const int32_t v[MLDSA_L][MLDSA_N]) +{ + mld_pointwise_acc_l7_avx2((__m256i *)w, (const __m256i *)u, + (const __m256i *)v, mld_qdata.vec); +} + #endif /* !__ASSEMBLER__ */ #endif /* !MLD_NATIVE_X86_64_META_H */ diff --git a/mldsa/native/x86_64/src/arith_native_x86_64.h b/mldsa/native/x86_64/src/arith_native_x86_64.h index bd6dfb55..822b29bc 100644 --- a/mldsa/native/x86_64/src/arith_native_x86_64.h +++ b/mldsa/native/x86_64/src/arith_native_x86_64.h @@ -78,4 +78,20 @@ void mld_polyz_unpack_17_avx2(__m256i *r, const uint8_t *a); #define mld_polyz_unpack_19_avx2 MLD_NAMESPACE(mld_polyz_unpack_19_avx2) void mld_polyz_unpack_19_avx2(__m256i *r, const uint8_t *a); +#define mld_pointwise_avx2 MLD_NAMESPACE(pointwise_avx2) +void mld_pointwise_avx2(__m256i *c, const __m256i *a, const __m256i *b, + const __m256i *qdata); + +#define mld_pointwise_acc_l4_avx2 MLD_NAMESPACE(pointwise_acc_l4_avx2) +void mld_pointwise_acc_l4_avx2(__m256i *c, const __m256i *a, const __m256i *b, + const __m256i *qdata); + +#define mld_pointwise_acc_l5_avx2 MLD_NAMESPACE(pointwise_acc_l5_avx2) +void mld_pointwise_acc_l5_avx2(__m256i *c, const __m256i *a, const __m256i *b, + const __m256i *qdata); + +#define mld_pointwise_acc_l7_avx2 MLD_NAMESPACE(pointwise_acc_l7_avx2) +void mld_pointwise_acc_l7_avx2(__m256i *c, const __m256i *a, const __m256i *b, + const __m256i *qdata); + #endif /* 
!MLD_NATIVE_X86_64_SRC_ARITH_NATIVE_X86_64_H */ diff --git a/mldsa/native/x86_64/src/pointwise.S b/mldsa/native/x86_64/src/pointwise.S new file mode 100644 index 00000000..36328038 --- /dev/null +++ b/mldsa/native/x86_64/src/pointwise.S @@ -0,0 +1,150 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ + !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) + +#include "consts.h" + + .intel_syntax noprefix + .text + +/* + * void mld_pointwise_avx2(__m256i *c, const __m256i *a, const __m256i *b, const __m256i *qdata) + * + * Pointwise multiplication of polynomials in NTT domain with Montgomery reduction + * + * Arguments: + * rdi: pointer to output polynomial c + * rsi: pointer to input polynomial a + * rdx: pointer to input polynomial b + * rcx: pointer to qdata constants + */ + .balign 4 + .global MLD_ASM_NAMESPACE(pointwise_avx2) +MLD_ASM_FN_SYMBOL(pointwise_avx2) + +// Load constants + vmovdqa ymm0, [rcx + (MLD_AVX2_BACKEND_DATA_OFFSET_8XQINV)*4] + vmovdqa ymm1, [rcx + (MLD_AVX2_BACKEND_DATA_OFFSET_8XQ)*4] + + xor eax, eax +_looptop1: +// Load + vmovdqa ymm2, [rsi] + vmovdqa ymm4, [rsi + 32] + vmovdqa ymm6, [rsi + 64] + vmovdqa ymm10, [rdx] + vmovdqa ymm12, [rdx + 32] + vmovdqa ymm14, [rdx + 64] + vpsrlq ymm3, ymm2, 32 + vpsrlq ymm5, ymm4, 32 + vmovshdup ymm7, ymm6 + vpsrlq ymm11, ymm10, 32 + vpsrlq ymm13, ymm12, 32 + vmovshdup ymm15, ymm14 + +// Multiply + vpmuldq ymm2, ymm2, ymm10 + vpmuldq ymm3, ymm3, ymm11 + vpmuldq ymm4, ymm4, ymm12 + vpmuldq ymm5, ymm5, ymm13 + vpmuldq ymm6, ymm6, ymm14 + vpmuldq ymm7, ymm7, ymm15 + +// Reduce + vpmuldq ymm10, ymm0, ymm2 + vpmuldq ymm11, ymm0, ymm3 + vpmuldq ymm12, ymm0, ymm4 + vpmuldq ymm13, ymm0, ymm5 + vpmuldq ymm14, ymm0, ymm6 + vpmuldq ymm15, ymm0, ymm7 + vpmuldq ymm10, ymm1, ymm10 + vpmuldq ymm11, ymm1, ymm11 + vpmuldq ymm12, ymm1, ymm12 + vpmuldq ymm13, ymm1, ymm13 + vpmuldq ymm14, ymm1, ymm14 + vpmuldq ymm15, ymm1, ymm15 + vpsubq ymm2, ymm2, ymm10 + vpsubq ymm3, ymm3, ymm11 + vpsubq ymm4, ymm4, ymm12 + vpsubq ymm5, ymm5, ymm13 + vpsubq ymm6, ymm6, ymm14 + vpsubq ymm7, ymm7, ymm15 + vpsrlq ymm2, ymm2, 32 + vpsrlq ymm4, ymm4, 32 + vmovshdup ymm6, ymm6 + +// Store + vpblendd ymm2, ymm2, ymm3, 0xAA + vpblendd ymm4, ymm4, ymm5, 0xAA + vpblendd ymm6, ymm6, ymm7, 0xAA + vmovdqa [rdi], ymm2 + vmovdqa [rdi + 32], ymm4 + vmovdqa [rdi + 64], ymm6 + + add rdi, 96 + add rsi, 96 + add rdx, 96 + add eax, 1 + cmp eax, 10 + jb _looptop1 + + vmovdqa ymm2, [rsi] + vmovdqa ymm4, [rsi + 32] + vmovdqa ymm10, [rdx] + vmovdqa ymm12, [rdx + 32] + vpsrlq ymm3, ymm2, 32 + vpsrlq ymm5, ymm4, 32 + vmovshdup ymm11, ymm10 + vmovshdup ymm13, ymm12 + +// Multiply + vpmuldq ymm2, ymm2, ymm10 + vpmuldq ymm3, ymm3, ymm11 + vpmuldq ymm4, ymm4, ymm12 + vpmuldq ymm5, ymm5, ymm13 + +// Reduce + vpmuldq ymm10, ymm0, ymm2 + vpmuldq ymm11, ymm0, ymm3 + vpmuldq ymm12, ymm0, ymm4 + vpmuldq ymm13, ymm0, ymm5 + vpmuldq ymm10, ymm1, ymm10 + vpmuldq ymm11, ymm1, ymm11 + vpmuldq ymm12, ymm1, ymm12 + vpmuldq ymm13, ymm1, ymm13 + vpsubq ymm2, ymm2, ymm10 + vpsubq ymm3, ymm3, ymm11 + vpsubq ymm4, ymm4, ymm12 + 
vpsubq ymm5, ymm5, ymm13 + vpsrlq ymm2, ymm2, 32 + vmovshdup ymm4, ymm4 + +// Store + vpblendd ymm2, ymm3, ymm2, 0x55 + vpblendd ymm4, ymm5, ymm4, 0x55 + vmovdqa [rdi], ymm2 + vmovdqa [rdi + 32], ymm4 + + ret + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + */ diff --git a/mldsa/native/x86_64/src/pointwise_acc_l4.S b/mldsa/native/x86_64/src/pointwise_acc_l4.S new file mode 100644 index 00000000..3b2c45fd --- /dev/null +++ b/mldsa/native/x86_64/src/pointwise_acc_l4.S @@ -0,0 +1,125 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ + !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) + +#include "consts.h" + + .intel_syntax noprefix + .text + +.macro pointwise off +// Load + vmovdqa ymm6, [rsi + \off] + vmovdqa ymm8, [rsi + \off + 32] + vmovdqa ymm10, [rdx + \off] + vmovdqa ymm12, [rdx + \off + 32] + vpsrlq ymm7, ymm6, 32 + vpsrlq ymm9, ymm8, 32 + vmovshdup ymm11, ymm10 + vmovshdup ymm13, ymm12 + +// Multiply + vpmuldq ymm6, ymm6, ymm10 + vpmuldq ymm7, ymm7, ymm11 + vpmuldq ymm8, ymm8, ymm12 + vpmuldq ymm9, ymm9, ymm13 +.endm + +.macro acc + vpaddq ymm2, ymm6, ymm2 + vpaddq ymm3, ymm7, ymm3 + vpaddq ymm4, ymm8, ymm4 + vpaddq ymm5, ymm9, ymm5 +.endm + +/* + * void mld_pointwise_acc_l4_avx2(__m256i *c, const __m256i *a, const __m256i *b, const __m256i *qdata) + * + * Pointwise multiplication with accumulation across multiple polynomial vectors + * + * Arguments: + * rdi: pointer to output polynomial c + * rsi: pointer to input polynomial a (multiple vectors) + * rdx: pointer to input polynomial b (multiple vectors) + * rcx: pointer to qdata constants + */ + .balign 4 + .global MLD_ASM_NAMESPACE(pointwise_acc_l4_avx2) +MLD_ASM_FN_SYMBOL(pointwise_acc_l4_avx2) + +// Load constants + vmovdqa ymm0, [rcx + (MLD_AVX2_BACKEND_DATA_OFFSET_8XQINV)*4] + vmovdqa ymm1, [rcx + (MLD_AVX2_BACKEND_DATA_OFFSET_8XQ)*4] + + xor eax, eax +_looptop2: + pointwise 0 + +// Move + vmovdqa ymm2, ymm6 + vmovdqa ymm3, ymm7 + vmovdqa ymm4, ymm8 + vmovdqa ymm5, ymm9 + + pointwise 1024 + acc + + pointwise 2048 + acc + + pointwise 3072 + acc + +// Reduce + vpmuldq ymm6, ymm0, ymm2 + vpmuldq ymm7, ymm0, ymm3 + vpmuldq ymm8, ymm0, ymm4 + vpmuldq ymm9, ymm0, ymm5 + vpmuldq ymm6, ymm1, ymm6 + vpmuldq ymm7, ymm1, ymm7 + vpmuldq ymm8, ymm1, ymm8 + vpmuldq ymm9, ymm1, ymm9 + vpsubq ymm2, ymm2, ymm6 + vpsubq ymm3, ymm3, ymm7 + vpsubq ymm4, ymm4, ymm8 + vpsubq ymm5, ymm5, ymm9 + vpsrlq ymm2, ymm2, 32 + vmovshdup ymm4, ymm4 + +// Store + vpblendd ymm2, ymm2, ymm3, 0xAA + vpblendd ymm4, ymm4, ymm5, 0xAA + + vmovdqa [rdi], ymm2 + vmovdqa [rdi + 32], ymm4 + + add rsi, 64 + add rdx, 64 + add rdi, 64 + add eax, 1 + cmp eax, 16 + jb _looptop2 + + ret + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + */ diff --git a/mldsa/native/x86_64/src/pointwise_acc_l5.S b/mldsa/native/x86_64/src/pointwise_acc_l5.S new file mode 100644 index 00000000..46d5e722 --- /dev/null +++ b/mldsa/native/x86_64/src/pointwise_acc_l5.S @@ -0,0 +1,128 @@ +/* + * Copyright (c) The mldsa-native project authors + * 
SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ + !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) + +#include "consts.h" + + .intel_syntax noprefix + .text + +.macro pointwise off +// Load + vmovdqa ymm6, [rsi + \off] + vmovdqa ymm8, [rsi + \off + 32] + vmovdqa ymm10, [rdx + \off] + vmovdqa ymm12, [rdx + \off + 32] + vpsrlq ymm7, ymm6, 32 + vpsrlq ymm9, ymm8, 32 + vmovshdup ymm11, ymm10 + vmovshdup ymm13, ymm12 + +// Multiply + vpmuldq ymm6, ymm6, ymm10 + vpmuldq ymm7, ymm7, ymm11 + vpmuldq ymm8, ymm8, ymm12 + vpmuldq ymm9, ymm9, ymm13 +.endm + +.macro acc + vpaddq ymm2, ymm6, ymm2 + vpaddq ymm3, ymm7, ymm3 + vpaddq ymm4, ymm8, ymm4 + vpaddq ymm5, ymm9, ymm5 +.endm + +/* + * void mld_pointwise_acc_l5_avx2(__m256i *c, const __m256i *a, const __m256i *b, const __m256i *qdata) + * + * Pointwise multiplication with accumulation across multiple polynomial vectors + * + * Arguments: + * rdi: pointer to output polynomial c + * rsi: pointer to input polynomial a (multiple vectors) + * rdx: pointer to input polynomial b (multiple vectors) + * rcx: pointer to qdata constants + */ + .balign 4 + .global MLD_ASM_NAMESPACE(pointwise_acc_l5_avx2) +MLD_ASM_FN_SYMBOL(pointwise_acc_l5_avx2) + +// Load constants + vmovdqa ymm0, [rcx + (MLD_AVX2_BACKEND_DATA_OFFSET_8XQINV)*4] + vmovdqa ymm1, [rcx + (MLD_AVX2_BACKEND_DATA_OFFSET_8XQ)*4] + + xor eax, eax +_looptop2: + pointwise 0 + +// Move + vmovdqa ymm2, ymm6 + vmovdqa ymm3, ymm7 + vmovdqa ymm4, ymm8 + vmovdqa ymm5, ymm9 + + pointwise 1024 + acc + + pointwise 2048 + acc + + pointwise 3072 + acc + + pointwise 4096 + acc + +// Reduce + vpmuldq ymm6, ymm0, ymm2 + vpmuldq ymm7, ymm0, ymm3 + vpmuldq ymm8, ymm0, ymm4 + vpmuldq ymm9, ymm0, ymm5 + vpmuldq ymm6, ymm1, ymm6 + vpmuldq ymm7, ymm1, ymm7 + vpmuldq ymm8, ymm1, ymm8 + vpmuldq ymm9, ymm1, ymm9 + vpsubq ymm2, ymm2, ymm6 + vpsubq ymm3, ymm3, ymm7 + vpsubq ymm4, ymm4, ymm8 + vpsubq ymm5, ymm5, ymm9 + vpsrlq ymm2, ymm2, 32 + vmovshdup ymm4, ymm4 + +// Store + vpblendd ymm2, ymm2, ymm3, 0xAA + vpblendd ymm4, ymm4, ymm5, 0xAA + + vmovdqa [rdi], ymm2 + vmovdqa [rdi + 32], ymm4 + + add rsi, 64 + add rdx, 64 + add rdi, 64 + add eax, 1 + cmp eax, 16 + jb _looptop2 + + ret + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + */ diff --git a/mldsa/native/x86_64/src/pointwise_acc_l7.S b/mldsa/native/x86_64/src/pointwise_acc_l7.S new file mode 100644 index 00000000..0c375efc --- /dev/null +++ b/mldsa/native/x86_64/src/pointwise_acc_l7.S @@ -0,0 +1,134 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. 
+ */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ + !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) + +#include "consts.h" + + .intel_syntax noprefix + .text + +.macro pointwise off +// Load + vmovdqa ymm6, [rsi + \off] + vmovdqa ymm8, [rsi + \off + 32] + vmovdqa ymm10, [rdx + \off] + vmovdqa ymm12, [rdx + \off + 32] + vpsrlq ymm7, ymm6, 32 + vpsrlq ymm9, ymm8, 32 + vmovshdup ymm11, ymm10 + vmovshdup ymm13, ymm12 + +// Multiply + vpmuldq ymm6, ymm6, ymm10 + vpmuldq ymm7, ymm7, ymm11 + vpmuldq ymm8, ymm8, ymm12 + vpmuldq ymm9, ymm9, ymm13 +.endm + +.macro acc + vpaddq ymm2, ymm6, ymm2 + vpaddq ymm3, ymm7, ymm3 + vpaddq ymm4, ymm8, ymm4 + vpaddq ymm5, ymm9, ymm5 +.endm + +/* + * void mld_pointwise_acc_l7_avx2(__m256i *c, const __m256i *a, const __m256i *b, const __m256i *qdata) + * + * Pointwise multiplication with accumulation across multiple polynomial vectors + * + * Arguments: + * rdi: pointer to output polynomial c + * rsi: pointer to input polynomial a (multiple vectors) + * rdx: pointer to input polynomial b (multiple vectors) + * rcx: pointer to qdata constants + */ + .balign 4 + .global MLD_ASM_NAMESPACE(pointwise_acc_l7_avx2) +MLD_ASM_FN_SYMBOL(pointwise_acc_l7_avx2) + +// Load constants + vmovdqa ymm0, [rcx + (MLD_AVX2_BACKEND_DATA_OFFSET_8XQINV)*4] + vmovdqa ymm1, [rcx + (MLD_AVX2_BACKEND_DATA_OFFSET_8XQ)*4] + + xor eax, eax +_looptop2: + pointwise 0 + +// Move + vmovdqa ymm2, ymm6 + vmovdqa ymm3, ymm7 + vmovdqa ymm4, ymm8 + vmovdqa ymm5, ymm9 + + pointwise 1024 + acc + + pointwise 2048 + acc + + pointwise 3072 + acc + + pointwise 4096 + acc + + pointwise 5120 + acc + + pointwise 6144 + acc + +// Reduce + vpmuldq ymm6, ymm0, ymm2 + vpmuldq ymm7, ymm0, ymm3 + vpmuldq ymm8, ymm0, ymm4 + vpmuldq ymm9, ymm0, ymm5 + vpmuldq ymm6, ymm1, ymm6 + vpmuldq ymm7, ymm1, ymm7 + vpmuldq ymm8, ymm1, ymm8 + vpmuldq ymm9, ymm1, ymm9 + vpsubq ymm2, ymm2, ymm6 + vpsubq ymm3, ymm3, ymm7 + vpsubq ymm4, ymm4, ymm8 + vpsubq ymm5, ymm5, ymm9 + vpsrlq ymm2, ymm2, 32 + vmovshdup ymm4, ymm4 + +// Store + vpblendd ymm2, ymm2, ymm3, 0xAA + vpblendd ymm4, ymm4, ymm5, 0xAA + + vmovdqa [rdi], ymm2 + vmovdqa [rdi + 32], ymm4 + + add rsi, 64 + add rdx, 64 + add rdi, 64 + add eax, 1 + cmp eax, 16 + jb _looptop2 + + ret + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + */ diff --git a/mldsa/poly.c b/mldsa/poly.c index dc1ee809..b7d99704 100644 --- a/mldsa/poly.c +++ b/mldsa/poly.c @@ -182,6 +182,13 @@ MLD_INTERNAL_API void mld_poly_pointwise_montgomery(mld_poly *c, const mld_poly *a, const mld_poly *b) { +#if defined(MLD_USE_NATIVE_POINTWISE_MONTGOMERY) + /* TODO: proof */ + mld_assert_abs_bound(a->coeffs, MLDSA_N, MLD_NTT_BOUND); + mld_assert_abs_bound(b->coeffs, MLDSA_N, MLD_NTT_BOUND); + mld_poly_pointwise_montgomery_native(c->coeffs, a->coeffs, b->coeffs); + mld_assert_abs_bound(c->coeffs, MLDSA_N, MLDSA_Q); +#else /* MLD_USE_NATIVE_POINTWISE_MONTGOMERY */ unsigned int i; mld_assert_abs_bound(a->coeffs, MLDSA_N, MLD_NTT_BOUND); @@ -197,6 +204,7 @@ void mld_poly_pointwise_montgomery(mld_poly *c, const mld_poly *a, } mld_assert_abs_bound(c->coeffs, MLDSA_N, MLDSA_Q); +#endif /* !MLD_USE_NATIVE_POINTWISE_MONTGOMERY */ } MLD_INTERNAL_API diff --git a/mldsa/polyvec.c b/mldsa/polyvec.c index 01e4492a..d51270b9 100644 --- a/mldsa/polyvec.c +++ b/mldsa/polyvec.c @@ -283,6 +283,39 @@ MLD_INTERNAL_API void mld_polyvecl_pointwise_acc_montgomery(mld_poly *w, const mld_polyvecl *u, const mld_polyvecl *v) { +#if 
defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4) && \ + MLD_CONFIG_PARAMETER_SET == 44 + /* TODO: proof */ + mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q); + mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND); + mld_polyvecl_pointwise_acc_montgomery_l4_native( + w->coeffs, (const int32_t(*)[MLDSA_N])u->vec, + (const int32_t(*)[MLDSA_N])v->vec); + mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q); +#elif defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5) && \ + MLD_CONFIG_PARAMETER_SET == 65 + /* TODO: proof */ + mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q); + mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND); + mld_polyvecl_pointwise_acc_montgomery_l5_native( + w->coeffs, (const int32_t(*)[MLDSA_N])u->vec, + (const int32_t(*)[MLDSA_N])v->vec); + mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q); +#elif defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7) && \ + MLD_CONFIG_PARAMETER_SET == 87 + /* TODO: proof */ + mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q); + mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND); + mld_polyvecl_pointwise_acc_montgomery_l7_native( + w->coeffs, (const int32_t(*)[MLDSA_N])u->vec, + (const int32_t(*)[MLDSA_N])v->vec); + mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q); +#else /* !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 && \ + MLD_CONFIG_PARAMETER_SET == 44) && \ + !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 && \ + MLD_CONFIG_PARAMETER_SET == 65) && \ + MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 && \ + MLD_CONFIG_PARAMETER_SET == 87 */ unsigned int i, j; mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q); mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND); @@ -320,6 +353,12 @@ void mld_polyvecl_pointwise_acc_montgomery(mld_poly *w, const mld_polyvecl *u, } mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q); +#endif /* !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 && \ + MLD_CONFIG_PARAMETER_SET == 44) && \ + !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 && \ + MLD_CONFIG_PARAMETER_SET == 65) && \ + !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 && \ + MLD_CONFIG_PARAMETER_SET == 87) */ } MLD_INTERNAL_API diff --git a/test/bench_components_mldsa.c b/test/bench_components_mldsa.c index c0a68aaa..3fa0495c 100644 --- a/test/bench_components_mldsa.c +++ b/test/bench_components_mldsa.c @@ -10,6 +10,7 @@ #include #include "../mldsa/ntt.h" #include "../mldsa/poly.h" +#include "../mldsa/polyvec.h" #include "../mldsa/randombytes.h" #include "hal.h" @@ -22,29 +23,36 @@ static int cmp_uint64_t(const void *a, const void *b) return (int)((*((const uint64_t *)a)) - (*((const uint64_t *)b))); } -#define BENCH(txt, code) \ - for (i = 0; i < NTESTS; i++) \ - { \ - mld_randombytes((uint8_t *)data0, sizeof(data0)); \ - for (j = 0; j < NWARMUP; j++) \ - { \ - code; \ - } \ - \ - t0 = get_cyclecounter(); \ - for (j = 0; j < NITERATIONS; j++) \ - { \ - code; \ - } \ - t1 = get_cyclecounter(); \ - (cyc)[i] = t1 - t0; \ - } \ - qsort((cyc), NTESTS, sizeof(uint64_t), cmp_uint64_t); \ +#define BENCH(txt, code) \ + for (i = 0; i < NTESTS; i++) \ + { \ + mld_randombytes((uint8_t *)data0, sizeof(data0)); \ + mld_randombytes((uint8_t *)&polyvecl_a, sizeof(polyvecl_a)); \ + mld_randombytes((uint8_t *)&polyvecl_b, sizeof(polyvecl_b)); \ + mld_randombytes((uint8_t *)polyvecl_mat, sizeof(polyvecl_mat)); \ + for (j = 0; j < NWARMUP; j++) \ + { \ + code; \ + } \ + \ + t0 = get_cyclecounter(); \ + for (j = 0; j < 
NITERATIONS; j++) \ + { \ + code; \ + } \ + t1 = get_cyclecounter(); \ + (cyc)[i] = t1 - t0; \ + } \ + qsort((cyc), NTESTS, sizeof(uint64_t), cmp_uint64_t); \ printf(txt " cycles=%" PRIu64 "\n", (cyc)[NTESTS >> 1] / NITERATIONS); static int bench(void) { MLD_ALIGN int32_t data0[256]; + MLD_ALIGN mld_poly poly_out; + MLD_ALIGN mld_polyvecl polyvecl_a, polyvecl_b; + MLD_ALIGN mld_polyveck polyveck_out; + MLD_ALIGN mld_polyvecl polyvecl_mat[MLDSA_K]; uint64_t cyc[NTESTS]; unsigned i, j; uint64_t t0, t1; @@ -53,6 +61,14 @@ static int bench(void) BENCH("poly_ntt", mld_poly_ntt((mld_poly *)data0)) BENCH("poly_invntt_tomont", mld_poly_invntt_tomont((mld_poly *)data0)) + /* pointwise */ + BENCH("polyvecl_pointwise_acc_montgomery", + mld_polyvecl_pointwise_acc_montgomery(&poly_out, &polyvecl_a, + &polyvecl_b)) + BENCH("polyvec_matrix_pointwise_montgomery", + mld_polyvec_matrix_pointwise_montgomery(&polyveck_out, polyvecl_mat, + &polyvecl_b)) + return 0; }
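
Reviewer note (not part of the patch): as a cross-check of the new backends, below is a minimal, self-contained C sketch of the scalar semantics they are meant to match: one Montgomery reduction per output coefficient, with the length-L products accumulated in 64 bits before a single reduction, mirroring the AArch64 and AVX2 loops above. The constants and helper names (montgomery_reduce, pointwise_montgomery_ref, pointwise_acc_montgomery_ref) are written out locally for illustration only and are not taken verbatim from mldsa/reduce.c or mldsa/polyvec.c.

#include <stdint.h>

#define MLDSA_N 256
#define MLDSA_Q 8380417 /* ML-DSA modulus q */
#define QINV 58728449   /* q^(-1) mod 2^32 */

/* Montgomery reduction: for |a| <= 2^31 * q, returns r = a * 2^(-32) mod q
 * with |r| < q, using the same formulation as the scalar reference code. */
static int32_t montgomery_reduce(int64_t a)
{
  int32_t t = (int32_t)((int64_t)(int32_t)a * QINV);
  return (int32_t)((a - (int64_t)t * MLDSA_Q) >> 32);
}

/* Scalar semantics of mld_poly_pointwise_montgomery_native:
 * c[i] = a[i] * b[i] * 2^(-32) mod q for all i. */
void pointwise_montgomery_ref(int32_t c[MLDSA_N], const int32_t a[MLDSA_N],
                              const int32_t b[MLDSA_N])
{
  unsigned i;
  for (i = 0; i < MLDSA_N; i++)
  {
    c[i] = montgomery_reduce((int64_t)a[i] * b[i]);
  }
}

/* Scalar semantics of the l4/l5/l7 accumulating variants:
 * w[i] = (sum over k < l of u[k][i] * v[k][i]) * 2^(-32) mod q.
 * For the coefficient bounds asserted in mld_polyvecl_pointwise_acc_montgomery,
 * the 64-bit sum stays within the input range of montgomery_reduce, so a
 * single reduction per output coefficient suffices. */
void pointwise_acc_montgomery_ref(int32_t w[MLDSA_N], unsigned l,
                                  const int32_t (*u)[MLDSA_N],
                                  const int32_t (*v)[MLDSA_N])
{
  unsigned i, k;
  for (i = 0; i < MLDSA_N; i++)
  {
    int64_t acc = 0;
    for (k = 0; k < l; k++)
    {
      acc += (int64_t)u[k][i] * v[k][i];
    }
    w[i] = montgomery_reduce(acc);
  }
}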