Skip to content

Commit 17b2c76

Browse files
authored
Merge pull request #414 from pq-code-package/caddq-asm
Add native implementation for `poly_caddq`
2 parents 6e4570f + 78cf598 commit 17b2c76

File tree

8 files changed

+151
-0
lines changed

8 files changed

+151
-0
lines changed

mldsa/native/aarch64/meta.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define MLD_USE_NATIVE_REJ_UNIFORM_ETA4
1616
#define MLD_USE_NATIVE_POLY_DECOMPOSE_32
1717
#define MLD_USE_NATIVE_POLY_DECOMPOSE_88
18+
#define MLD_USE_NATIVE_POLY_CADDQ
1819

1920
/* Identifier for this backend so that source and assembly files
2021
* in the build can be appropriately guarded. */
@@ -107,6 +108,11 @@ static MLD_INLINE void mld_poly_decompose_88_native(int32_t *a1, int32_t *a0,
107108
mld_poly_decompose_88_asm(a1, a0, a);
108109
}
109110

111+
/* AArch64 backend hook for poly_caddq: passes the coefficient array
 * straight through to the assembly routine mld_poly_caddq_asm, which
 * updates it in place. */
static MLD_INLINE void mld_poly_caddq_native(int32_t a[MLDSA_N])
112+
{
113+
mld_poly_caddq_asm(a);
114+
}
115+
110116
#endif /* !__ASSEMBLER__ */
111117

112118
#endif /* !MLD_NATIVE_AARCH64_META_H */

mldsa/native/aarch64/src/arith_native_aarch64.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,7 @@ void mld_poly_decompose_32_asm(int32_t *a1, int32_t *a0, const int32_t *a);
6868
#define mld_poly_decompose_88_asm MLD_NAMESPACE(poly_decompose_88_asm)
6969
void mld_poly_decompose_88_asm(int32_t *a1, int32_t *a0, const int32_t *a);
7070

71+
#define mld_poly_caddq_asm MLD_NAMESPACE(poly_caddq_asm)
72+
void mld_poly_caddq_asm(int32_t *a);
73+
7174
#endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
#include "../../../common.h"
6+
7+
#if defined(MLD_ARITH_BACKEND_AARCH64) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
8+
9+
// caddq: add q to every negative 32-bit lane of the vector register given
// as the macro argument; non-negative lanes are left unchanged.
// Depends on the register aliases tmp (scratch) and q_reg (q in every lane)
// being in effect wherever the macro is expanded.
.macro caddq inout
10+
// tmp = logical shift right by 31: 1 where the sign bit is set, else 0
ushr tmp.4s, \inout\().4s, #31
11+
// multiply-accumulate: lane += tmp * q, i.e. +q only for negative lanes
mla \inout\().4s, tmp.4s, q_reg.4s
12+
.endm
13+
14+
// In-place conditional addition of q over a buffer of 32-bit coefficients.
// The loop below runs 16 times and handles 64 bytes per iteration, so the
// routine touches exactly 1024 bytes (256 int32_t values) starting at x0.
// Clobbers x0, x1, w9, v0-v5 and the condition flags.
.global MLD_ASM_NAMESPACE(poly_caddq_asm)
15+
.balign 16
16+
MLD_ASM_FN_SYMBOL(poly_caddq_asm)
17+
// Function signature: void mld_poly_caddq_asm(int32_t *a)
18+
// x0: pointer to polynomial coefficients
19+
20+
// Register assignments
21+
a_ptr .req x0
22+
count .req x1
23+
q_reg .req v4
24+
tmp .req v5
25+
26+
// Load constants
27+
// MLDSA_Q = 8380417 = 0x7FE001
28+
// Build the 32-bit constant in two halves: low 16 bits, then bits 16..31.
movz w9, #0xE001
29+
movk w9, #0x7F, lsl #16
30+
dup q_reg.4s, w9 // Load Q values
31+
32+
// 64/4 = 16 loop iterations; each one processes four 16-byte vectors.
mov count, #64/4
33+
// NOTE(review): this is an ordinary (non-local) label; confirm it cannot
// collide with a same-named label if several assembly files are combined.
poly_caddq_loop:
34+
ldr q0, [a_ptr, #0*16]
35+
ldr q1, [a_ptr, #1*16]
36+
ldr q2, [a_ptr, #2*16]
37+
ldr q3, [a_ptr, #3*16]
38+
39+
caddq v0
40+
caddq v1
41+
caddq v2
42+
caddq v3
43+
44+
// q1..q3 are stored first because their offsets are relative to the
// not-yet-advanced a_ptr; the q0 store is post-indexed and moves a_ptr
// forward by 64 bytes for the next iteration.
str q1, [a_ptr, #1*16]
45+
str q2, [a_ptr, #2*16]
46+
str q3, [a_ptr, #3*16]
47+
str q0, [a_ptr], #4*16
48+
49+
subs count, count, #1
50+
bne poly_caddq_loop
51+
52+
ret
53+
54+
// Release the register aliases so they do not leak past this routine.
.unreq a_ptr
55+
.unreq count
56+
.unreq q_reg
57+
.unreq tmp
58+
59+
#endif /* MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED */

mldsa/native/api.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,5 +210,16 @@ static MLD_INLINE void mld_poly_decompose_88_native(int32_t *a1, int32_t *a0,
210210
const int32_t *a);
211211
#endif /* MLD_USE_NATIVE_POLY_DECOMPOSE_88 */
212212

213+
#if defined(MLD_USE_NATIVE_POLY_CADDQ)
214+
/*************************************************
215+
* Name: mld_poly_caddq_native
216+
*
217+
* Description: For all coefficients of in/out polynomial add Q if
218+
* coefficient is negative.
219+
*
220+
* Arguments: - int32_t *a: pointer to input/output polynomial
221+
**************************************************/
222+
static MLD_INLINE void mld_poly_caddq_native(int32_t a[MLDSA_N]);
223+
#endif /* MLD_USE_NATIVE_POLY_CADDQ */
213224

214225
#endif /* !MLD_NATIVE_API_H */

mldsa/native/x86_64/meta.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#define MLD_USE_NATIVE_REJ_UNIFORM_ETA4
2020
#define MLD_USE_NATIVE_POLY_DECOMPOSE_32
2121
#define MLD_USE_NATIVE_POLY_DECOMPOSE_88
22+
#define MLD_USE_NATIVE_POLY_CADDQ
2223

2324
#if !defined(__ASSEMBLER__)
2425
#include <string.h>
@@ -112,6 +113,11 @@ static MLD_INLINE void mld_poly_decompose_88_native(int32_t *a1, int32_t *a0,
112113
mld_poly_decompose_88_avx2((__m256i *)a1, (__m256i *)a0, (const __m256i *)a);
113114
}
114115

116+
/* x86_64 backend hook for poly_caddq: passes the coefficient array
 * straight through to mld_poly_caddq_avx2, which updates it in place. */
static MLD_INLINE void mld_poly_caddq_native(int32_t a[MLDSA_N])
117+
{
118+
mld_poly_caddq_avx2(a);
119+
}
120+
115121
#endif /* !__ASSEMBLER__ */
116122

117123
#endif /* !MLD_NATIVE_X86_64_META_H */

mldsa/native/x86_64/src/arith_native_x86_64.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,7 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a);
6060
#define mld_poly_decompose_88_avx2 MLD_NAMESPACE(mld_poly_decompose_88_avx2)
6161
void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a);
6262

63+
#define mld_poly_caddq_avx2 MLD_NAMESPACE(poly_caddq_avx2)
64+
void mld_poly_caddq_avx2(int32_t *r);
65+
6366
#endif /* !MLD_NATIVE_X86_64_SRC_ARITH_NATIVE_X86_64_H */
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
6+
/*
7+
* This file is derived from the public domain
8+
* AVX2 Dilithium implementation @[REF_AVX2].
9+
*/
10+
11+
#include "../../../common.h"
12+
13+
#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \
14+
!defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
15+
16+
#include <immintrin.h>
17+
#include "arith_native_x86_64.h"
18+
#include "consts.h"
19+
20+
/*************************************************
21+
* Name: mld_poly_caddq_avx2
22+
*
23+
* Description: For all coefficients of in/out polynomial add Q if
24+
* coefficient is negative.
25+
*
26+
* Arguments: - int32_t *r: pointer to input/output polynomial
27+
**************************************************/
28+
/* Processes MLDSA_N coefficients in place, eight 32-bit lanes per step.
 * The aligned load/store intrinsics used below require r to be 32-byte
 * aligned. */
void mld_poly_caddq_avx2(int32_t *r)
29+
{
30+
unsigned int i;
31+
__m256i f, g;
32+
/* MLDSA_Q broadcast to all eight lanes */
const __m256i q = _mm256_set1_epi32(MLDSA_Q);
33+
const __m256i zero = _mm256_setzero_si256();
34+
/* View the coefficient buffer as an array of 256-bit vectors */
__m256i *rr = (__m256i *)r;
35+
36+
for (i = 0; i < MLDSA_N / 8; i++)
37+
{
38+
f = _mm256_load_si256(&rr[i]);
39+
/* g = all-ones in every lane where 0 > f (negative), zero elsewhere */
g = _mm256_cmpgt_epi32(zero, f);
40+
/* g = q in negative lanes, 0 in the others */
g = _mm256_and_si256(g, q);
41+
/* add q only where the coefficient was negative */
f = _mm256_add_epi32(f, g);
42+
_mm256_store_si256(&rr[i], f);
43+
}
44+
}
45+
46+
#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \
47+
*/
48+
49+
/* Empty-CU marker for the configuration in which this backend is compiled
 * out. The tag names this file (poly_caddq) so that it cannot coincide with
 * the marker emitted by a different compilation unit. */
MLD_EMPTY_CU(avx2_poly_caddq)
50+
51+
#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && \
52+
!MLD_CONFIG_MULTILEVEL_NO_SHARED) */

mldsa/poly.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ void mld_poly_reduce(mld_poly *a)
3232
mld_assert_bound(a->coeffs, MLDSA_N, -REDUCE32_RANGE_MAX, REDUCE32_RANGE_MAX);
3333
}
3434

35+
36+
#if !defined(MLD_USE_NATIVE_POLY_CADDQ)
3537
MLD_INTERNAL_API
3638
void mld_poly_caddq(mld_poly *a)
3739
{
@@ -50,6 +52,15 @@ void mld_poly_caddq(mld_poly *a)
5052

5153
mld_assert_bound(a->coeffs, MLDSA_N, 0, MLDSA_Q);
5254
}
55+
#else /* !MLD_USE_NATIVE_POLY_CADDQ */
56+
/* Variant of mld_poly_caddq used when MLD_USE_NATIVE_POLY_CADDQ is defined:
 * the work is delegated to the backend's mld_poly_caddq_native, bracketed by
 * the debug bound checks below. */
MLD_INTERNAL_API
57+
void mld_poly_caddq(mld_poly *a)
58+
{
59+
/* Input check: coefficients bounded in absolute value by MLDSA_Q
 * (NOTE(review): strict vs. inclusive bound depends on the macro - confirm) */
mld_assert_abs_bound(a->coeffs, MLDSA_N, MLDSA_Q);
60+
mld_poly_caddq_native(a->coeffs);
61+
/* Output check: coefficients lie between 0 and MLDSA_Q */
mld_assert_bound(a->coeffs, MLDSA_N, 0, MLDSA_Q);
62+
}
63+
#endif /* MLD_USE_NATIVE_POLY_CADDQ */
5364

5465
/* Reference: We use destructive version (output=first input) to avoid
5566
* reasoning about aliasing in the CBMC specification */

0 commit comments

Comments
 (0)