Commit cdf39ee

AArch64: Add native implementations of polyvecl_pointwise_acc_montgomery
This commit adds native implementations of polyvecl_pointwise_acc_montgomery written from scratch.

Co-authored-by: Matthias J. Kannwischer <[email protected]>
Signed-off-by: jammychiou1 <[email protected]>
1 parent d9e74fb commit cdf39ee
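
For orientation (not part of the commit itself): each new kernel computes the accumulated pointwise product of two length-L vectors of polynomials in the NTT domain, keeping the per-coefficient sum in 64 bits and applying a single Montgomery reduction at the end. The C sketch below illustrates the L = 4 case; the helper names are hypothetical and not taken from this repository, and only the constants q = 8380417 and q^-1 mod 2^32 = 58728449 (the standard ML-DSA values) are assumed.

/* Illustrative sketch, not code from this repository. */
#include <stdint.h>

#define MLDSA_N 256
#define MLDSA_Q 8380417
#define QINV 58728449 /* q^-1 mod 2^32 */

/* Montgomery reduction: returns a value congruent to a * 2^-32 (mod q). */
static int32_t montgomery_reduce_ref(int64_t a)
{
  int32_t t = (int32_t)((int64_t)(int32_t)a * QINV); /* a * q^-1 mod 2^32 */
  return (int32_t)((a - (int64_t)t * MLDSA_Q) >> 32);
}

/* Accumulated pointwise product for the L = 4 parameter set. */
static void polyvecl_pointwise_acc_montgomery_l4_ref(
    int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
    const int32_t v[4][MLDSA_N])
{
  for (unsigned i = 0; i < MLDSA_N; i++)
  {
    int64_t acc = 0;
    for (unsigned j = 0; j < 4; j++)
    {
      acc += (int64_t)u[j][i] * v[j][i]; /* 64-bit accumulation, as in the asm */
    }
    w[i] = montgomery_reduce_ref(acc); /* one reduction per coefficient */
  }
}

The assembly below performs the same computation four coefficients at a time; it loads the negated constant -q^-1 mod 2^32 = 4236238847 and folds t*q back in with an addition, which is an equivalent formulation of the same reduction.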

File tree

5 files changed: +444 -0 lines changed

mldsa/native/aarch64/meta.h

Lines changed: 27 additions & 0 deletions
@@ -22,6 +22,9 @@
 #define MLD_USE_NATIVE_POLYZ_UNPACK_17
 #define MLD_USE_NATIVE_POLYZ_UNPACK_19
 #define MLD_USE_NATIVE_POINTWISE_MONTGOMERY
+#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4
+#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5
+#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7
 
 /* Identifier for this backend so that source and assembly files
  * in the build can be appropriately guarded. */
@@ -155,5 +158,29 @@ static MLD_INLINE void mld_poly_pointwise_montgomery_native(
   mld_poly_pointwise_montgomery_asm(out, in0, in1);
 }
 
+static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
+    int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
+    const int32_t v[4][MLDSA_N])
+{
+  mld_polyvecl_pointwise_acc_montgomery_l4_asm(w, (const int32_t *)u,
+                                               (const int32_t *)v);
+}
+
+static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
+    int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
+    const int32_t v[5][MLDSA_N])
+{
+  mld_polyvecl_pointwise_acc_montgomery_l5_asm(w, (const int32_t *)u,
+                                               (const int32_t *)v);
+}
+
+static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
+    int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
+    const int32_t v[7][MLDSA_N])
+{
+  mld_polyvecl_pointwise_acc_montgomery_l7_asm(w, (const int32_t *)u,
+                                               (const int32_t *)v);
+}
+
 #endif /* !__ASSEMBLER__ */
 #endif /* !MLD_NATIVE_AARCH64_META_H */
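
Note that the wrappers above flatten the two-dimensional u and v arrays into plain int32_t pointers; the assembly relies on each polynomial occupying a contiguous block of MLDSA_N * sizeof(int32_t) = 256 * 4 = 1024 bytes, which is exactly the stride it uses further down when computing the base addresses of the vector entries. A minimal compile-time check of that layout assumption could look like this (hypothetical, not part of the commit):

/* Hypothetical layout check, not part of this commit. */
#include <assert.h>
#include <stdint.h>

#define MLDSA_N 256

/* The fixed 1024-byte stride in the assembly is only valid if one
 * polynomial spans exactly MLDSA_N contiguous 32-bit coefficients. */
static_assert(sizeof(int32_t[MLDSA_N]) == 1024,
              "polynomial must occupy exactly 1024 bytes");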

mldsa/native/aarch64/src/arith_native_aarch64.h

Lines changed: 15 additions & 0 deletions
@@ -98,4 +98,19 @@ void mld_polyz_unpack_19_asm(int32_t *r, const uint8_t *buf,
 void mld_poly_pointwise_montgomery_asm(int32_t *, const int32_t *,
                                        const int32_t *);
 
+#define mld_polyvecl_pointwise_acc_montgomery_l4_asm \
+  MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l4_asm)
+void mld_polyvecl_pointwise_acc_montgomery_l4_asm(int32_t *, const int32_t *,
+                                                  const int32_t *);
+
+#define mld_polyvecl_pointwise_acc_montgomery_l5_asm \
+  MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l5_asm)
+void mld_polyvecl_pointwise_acc_montgomery_l5_asm(int32_t *, const int32_t *,
+                                                  const int32_t *);
+
+#define mld_polyvecl_pointwise_acc_montgomery_l7_asm \
+  MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l7_asm)
+void mld_polyvecl_pointwise_acc_montgomery_l7_asm(int32_t *, const int32_t *,
+                                                  const int32_t *);
+
 #endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
/* Copyright (c) The mldsa-native project authors
 * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
 */

#include "../../../common.h"
#if defined(MLD_ARITH_BACKEND_AARCH64)

.macro montgomery_reduce_long res, inl, inh
        uzp1 t0.4s, \inl\().4s, \inh\().4s
        mul t0.4s, t0.4s, modulus_twisted.4s
        smlal \inl\().2d, t0.2s, modulus.2s
        smlal2 \inh\().2d, t0.4s, modulus.4s
        uzp2 \res\().4s, \inl\().4s, \inh\().4s
.endm

.macro load_polys a, b, a_ptr, b_ptr
        ldr q_\()\a, [\a_ptr], #16
        ldr q_\()\b, [\b_ptr], #16
.endm

.macro pmull dl, dh, a, b
        smull \dl\().2d, \a\().2s, \b\().2s
        smull2 \dh\().2d, \a\().4s, \b\().4s
.endm

.macro pmlal dl, dh, a, b
        smlal \dl\().2d, \a\().2s, \b\().2s
        smlal2 \dh\().2d, \a\().4s, \b\().4s
.endm

.macro save_vregs
        sub sp, sp, #(16*4)
        stp d8, d9, [sp, #16*0]
        stp d10, d11, [sp, #16*1]
        stp d12, d13, [sp, #16*2]
        stp d14, d15, [sp, #16*3]
.endm

.macro restore_vregs
        ldp d8, d9, [sp, #16*0]
        ldp d10, d11, [sp, #16*1]
        ldp d12, d13, [sp, #16*2]
        ldp d14, d15, [sp, #16*3]
        add sp, sp, #(16*4)
.endm

.macro push_stack
        save_vregs
.endm

.macro pop_stack
        restore_vregs
.endm

out_ptr         .req x0
a0_ptr          .req x1
b0_ptr          .req x2
a1_ptr          .req x3
b1_ptr          .req x4
a2_ptr          .req x5
b2_ptr          .req x6
a3_ptr          .req x7
b3_ptr          .req x8
count           .req x9
wtmp            .req w9

modulus         .req v0
modulus_twisted .req v1

aa              .req v2
bb              .req v3
res             .req v4
resl            .req v5
resh            .req v6
t0              .req v7

q_aa            .req q2
q_bb            .req q3
q_res           .req q4

.text
.global MLD_ASM_NAMESPACE(polyvecl_pointwise_acc_montgomery_l4_asm)
.balign 4
MLD_ASM_FN_SYMBOL(polyvecl_pointwise_acc_montgomery_l4_asm)
        push_stack

        // load q = 8380417
        movz wtmp, #57345
        movk wtmp, #127, lsl #16
        dup modulus.4s, wtmp

        // load -q^-1 = 4236238847
        movz wtmp, #57343
        movk wtmp, #64639, lsl #16
        dup modulus_twisted.4s, wtmp

        // Computed bases of vector entries
        add a1_ptr, a0_ptr, #(1 * 1024)
        add a2_ptr, a0_ptr, #(2 * 1024)
        add a3_ptr, a0_ptr, #(3 * 1024)

        add b1_ptr, b0_ptr, #(1 * 1024)
        add b2_ptr, b0_ptr, #(2 * 1024)
        add b3_ptr, b0_ptr, #(3 * 1024)

        mov count, #(MLDSA_N / 4)
l4_loop_start:
        load_polys aa, bb, a0_ptr, b0_ptr
        pmull resl, resh, aa, bb
        load_polys aa, bb, a1_ptr, b1_ptr
        pmlal resl, resh, aa, bb
        load_polys aa, bb, a2_ptr, b2_ptr
        pmlal resl, resh, aa, bb
        load_polys aa, bb, a3_ptr, b3_ptr
        pmlal resl, resh, aa, bb

        montgomery_reduce_long res, resl, resh

        str q_res, [out_ptr], #16

        subs count, count, #1
        cbnz count, l4_loop_start

        pop_stack
        ret
#endif /* MLD_ARITH_BACKEND_AARCH64 */
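
As a quick check on the immediates above (illustrative only, not part of the commit): the movz/movk pair for the modulus builds 127 * 2^16 + 57345 = 8380417 = q, and the pair for modulus_twisted builds 64639 * 2^16 + 57343 = 4236238847 = 2^32 - 58728449, i.e. -q^-1 mod 2^32. The reduction macro relies on q * (-q^-1) == -1 (mod 2^32), which the stand-alone snippet below verifies.

/* Stand-alone check of the constants used in the assembly above. */
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
  /* movz #57345; movk #127, lsl #16 */
  uint32_t q = ((uint32_t)127 << 16) | 57345u;
  /* movz #57343; movk #64639, lsl #16 */
  uint32_t minus_qinv = ((uint32_t)64639 << 16) | 57343u;

  assert(q == 8380417u);
  assert(minus_qinv == 4236238847u);
  /* q * (-q^-1) must be -1 modulo 2^32, so adding 1 wraps to 0. */
  assert(q * minus_qinv + 1u == 0u);

  printf("constants check out\n");
  return 0;
}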

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
/* Copyright (c) The mldsa-native project authors
 * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
 */

#include "../../../common.h"
#if defined(MLD_ARITH_BACKEND_AARCH64)

.macro montgomery_reduce_long res, inl, inh
        uzp1 t0.4s, \inl\().4s, \inh\().4s
        mul t0.4s, t0.4s, modulus_twisted.4s
        smlal \inl\().2d, t0.2s, modulus.2s
        smlal2 \inh\().2d, t0.4s, modulus.4s
        uzp2 \res\().4s, \inl\().4s, \inh\().4s
.endm

.macro load_polys a, b, a_ptr, b_ptr
        ldr q_\()\a, [\a_ptr], #16
        ldr q_\()\b, [\b_ptr], #16
.endm

.macro pmull dl, dh, a, b
        smull \dl\().2d, \a\().2s, \b\().2s
        smull2 \dh\().2d, \a\().4s, \b\().4s
.endm

.macro pmlal dl, dh, a, b
        smlal \dl\().2d, \a\().2s, \b\().2s
        smlal2 \dh\().2d, \a\().4s, \b\().4s
.endm

.macro save_vregs
        sub sp, sp, #(16*4)
        stp d8, d9, [sp, #16*0]
        stp d10, d11, [sp, #16*1]
        stp d12, d13, [sp, #16*2]
        stp d14, d15, [sp, #16*3]
.endm

.macro restore_vregs
        ldp d8, d9, [sp, #16*0]
        ldp d10, d11, [sp, #16*1]
        ldp d12, d13, [sp, #16*2]
        ldp d14, d15, [sp, #16*3]
        add sp, sp, #(16*4)
.endm

.macro push_stack
        save_vregs
.endm

.macro pop_stack
        restore_vregs
.endm

out_ptr         .req x0
a0_ptr          .req x1
b0_ptr          .req x2
a1_ptr          .req x3
b1_ptr          .req x4
a2_ptr          .req x5
b2_ptr          .req x6
a3_ptr          .req x7
b3_ptr          .req x8
a4_ptr          .req x9
b4_ptr          .req x10
count           .req x11
wtmp            .req w11

modulus         .req v0
modulus_twisted .req v1

aa              .req v2
bb              .req v3
res             .req v4
resl            .req v5
resh            .req v6
t0              .req v7

q_aa            .req q2
q_bb            .req q3
q_res           .req q4

.text
.global MLD_ASM_NAMESPACE(polyvecl_pointwise_acc_montgomery_l5_asm)
.balign 4
MLD_ASM_FN_SYMBOL(polyvecl_pointwise_acc_montgomery_l5_asm)
        push_stack

        // load q = 8380417
        movz wtmp, #57345
        movk wtmp, #127, lsl #16
        dup modulus.4s, wtmp

        // load -q^-1 = 4236238847
        movz wtmp, #57343
        movk wtmp, #64639, lsl #16
        dup modulus_twisted.4s, wtmp

        // Computed bases of vector entries
        add a1_ptr, a0_ptr, #(1 * 1024)
        add a2_ptr, a0_ptr, #(2 * 1024)
        add a3_ptr, a0_ptr, #(3 * 1024)
        add a4_ptr, a0_ptr, #(4 * 1024)

        add b1_ptr, b0_ptr, #(1 * 1024)
        add b2_ptr, b0_ptr, #(2 * 1024)
        add b3_ptr, b0_ptr, #(3 * 1024)
        add b4_ptr, b0_ptr, #(4 * 1024)

        mov count, #(MLDSA_N / 4)
l5_loop_start:
        load_polys aa, bb, a0_ptr, b0_ptr
        pmull resl, resh, aa, bb
        load_polys aa, bb, a1_ptr, b1_ptr
        pmlal resl, resh, aa, bb
        load_polys aa, bb, a2_ptr, b2_ptr
        pmlal resl, resh, aa, bb
        load_polys aa, bb, a3_ptr, b3_ptr
        pmlal resl, resh, aa, bb
        load_polys aa, bb, a4_ptr, b4_ptr
        pmlal resl, resh, aa, bb

        montgomery_reduce_long res, resl, resh

        str q_res, [out_ptr], #16

        subs count, count, #1
        cbnz count, l5_loop_start

        pop_stack
        ret
#endif /* MLD_ARITH_BACKEND_AARCH64 */
