@@ -36,7 +36,7 @@ use rand_core::OsRng;
3636use sha3:: digest:: core_api:: XofReaderCoreWrapper ;
3737use sha3:: digest:: { ExtendableOutput , XofReader } ;
3838use sha3:: { Sha3_256 , Shake256 , Shake256ReaderCore } ;
39- use x25519_dalek:: { x25519 , X25519_BASEPOINT_BYTES } ;
39+ use x25519_dalek:: { EphemeralSecret , PublicKey , StaticSecret } ;
4040#[ cfg( feature = "zeroize" ) ]
4141use zeroize:: { Zeroize , ZeroizeOnDrop } ;
4242
@@ -69,7 +69,7 @@ pub type SharedSecret = [u8; 32];
6969#[ derive( Clone , PartialEq ) ]
7070pub struct EncapsulationKey {
7171 pk_m : MlKem768EncapsulationKey ,
72- pk_x : x25519_dalek :: PublicKey ,
72+ pk_x : PublicKey ,
7373}
7474
7575impl Encapsulate < Ciphertext , SharedSecret > for EncapsulationKey {
@@ -82,28 +82,23 @@ impl Encapsulate<Ciphertext, SharedSecret> for EncapsulationKey {
8282 // Swapped order of operations compared to RFC, so that usage of the rng matches the RFC
8383 let ( ct_m, ss_m) = self . pk_m . encapsulate ( rng) ?;
8484
85- let ek_x: SharedSecret = generate ( rng) ;
86- let ct_x = x25519 ( ek_x, X25519_BASEPOINT_BYTES ) ;
87- let ss_x = x25519 ( ek_x, self . pk_x . to_bytes ( ) ) ;
85+ let ek_x = EphemeralSecret :: random_from_rng ( rng) ;
86+ // Equal to ct_x = x25519(ek_x, BASE_POINT)
87+ let ct_x = PublicKey :: from ( & ek_x) ;
88+ // Equal to ss_x = x25519(ek_x, pk_x)
89+ let ss_x = ek_x. diffie_hellman ( & self . pk_x ) ;
8890
8991 let ss = combiner ( & ss_m, & ss_x, & ct_x, & self . pk_x ) ;
90-
91- #[ cfg( feature = "zeroize" ) ]
92- {
93- let mut ss_x = ss_x;
94- ss_x. zeroize ( ) ;
95- }
96-
9792 let ct = Ciphertext { ct_m, ct_x } ;
9893 Ok ( ( ct, ss) )
9994 }
10095}
10196
10297impl EncapsulationKey {
10398 /// Convert the key to the following format:
104- /// ML-KEM-768 public key(1184 bytes) | X25519 public key(32 bytes).
99+ /// ML-KEM-768 public key(1184 bytes) || X25519 public key(32 bytes).
105100 #[ must_use]
106- pub fn as_bytes ( & self ) -> [ u8 ; ENCAPSULATION_KEY_SIZE ] {
101+ pub fn to_bytes ( & self ) -> [ u8 ; ENCAPSULATION_KEY_SIZE ] {
107102 let mut buffer = [ 0u8 ; ENCAPSULATION_KEY_SIZE ] ;
108103 buffer[ 0 ..1184 ] . copy_from_slice ( & self . pk_m . as_bytes ( ) ) ;
109104 buffer[ 1184 ..1216 ] . copy_from_slice ( self . pk_x . as_bytes ( ) ) ;
@@ -119,7 +114,7 @@ impl From<&[u8; ENCAPSULATION_KEY_SIZE]> for EncapsulationKey {
119114
120115 let mut pk_x = [ 0 ; 32 ] ;
121116 pk_x. copy_from_slice ( & value[ 1184 ..] ) ;
122- let pk_x = x25519_dalek :: PublicKey :: from ( pk_x) ;
117+ let pk_x = PublicKey :: from ( pk_x) ;
123118 EncapsulationKey { pk_m, pk_x }
124119 }
125120}
@@ -138,16 +133,13 @@ impl Decapsulate<Ciphertext, SharedSecret> for DecapsulationKey {
138133 #[ allow( clippy:: similar_names) ] // So we can use the names as in the RFC
139134 fn decapsulate ( & self , ct : & Ciphertext ) -> Result < SharedSecret , Self :: Error > {
140135 let ( sk_m, sk_x, _pk_m, pk_x) = self . expand_key ( ) ;
136+
141137 let ss_m = sk_m. decapsulate ( & ct. ct_m ) ?;
142- let ss_x = x25519 ( sk_x. to_bytes ( ) , ct. ct_x ) ;
143- let ss = combiner ( & ss_m, & ss_x, & ct. ct_x , & pk_x) ;
144138
145- #[ cfg( feature = "zeroize" ) ]
146- {
147- let mut ss_x = ss_x;
148- ss_x. zeroize ( ) ;
149- }
139+ // equal to ss_x = x25519(sk_x, ct_x)
140+ let ss_x = sk_x. diffie_hellman ( & ct. ct_x ) ;
150141
142+ let ss = combiner ( & ss_m, & ss_x, & ct. ct_x , & pk_x) ;
151143 Ok ( ss)
152144 }
153145}
@@ -176,9 +168,9 @@ impl DecapsulationKey {
176168 & self ,
177169 ) -> (
178170 MlKem768DecapsulationKey ,
179- x25519_dalek :: StaticSecret ,
171+ StaticSecret ,
180172 MlKem768EncapsulationKey ,
181- x25519_dalek :: PublicKey ,
173+ PublicKey ,
182174 ) {
183175 use sha3:: digest:: Update ;
184176 let mut shaker = Shake256 :: default ( ) ;
@@ -190,8 +182,8 @@ impl DecapsulationKey {
190182 let ( sk_m, pk_m) = MlKem768 :: generate_deterministic ( & d, & z) ;
191183
192184 let sk_x = read_from ( & mut expanded) ;
193- let sk_x = x25519_dalek :: StaticSecret :: from ( sk_x) ;
194- let pk_x = x25519_dalek :: PublicKey :: from ( & sk_x) ;
185+ let sk_x = StaticSecret :: from ( sk_x) ;
186+ let pk_x = PublicKey :: from ( & sk_x) ;
195187
196188 ( sk_m, sk_x, pk_m, pk_x)
197189 }
@@ -214,17 +206,17 @@ impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
214206#[ cfg_attr( feature = "zeroize" , derive( Zeroize , ZeroizeOnDrop ) ) ]
215207pub struct Ciphertext {
216208 ct_m : ArrayN < u8 , 1088 > ,
217- ct_x : [ u8 ; 32 ] ,
209+ ct_x : PublicKey ,
218210}
219211
220212impl Ciphertext {
221213 /// Convert the ciphertext to the following format:
222- /// ML-KEM-768 ciphertext(1088 bytes) | X25519 ciphertext(32 bytes).
214+ /// ML-KEM-768 ciphertext(1088 bytes) || X25519 ciphertext(32 bytes).
223215 #[ must_use]
224- pub fn as_bytes ( & self ) -> [ u8 ; CIPHERTEXT_SIZE ] {
216+ pub fn to_bytes ( & self ) -> [ u8 ; CIPHERTEXT_SIZE ] {
225217 let mut buffer = [ 0 ; CIPHERTEXT_SIZE ] ;
226218 buffer[ 0 ..1088 ] . copy_from_slice ( & self . ct_m ) ;
227- buffer[ 1088 ..] . copy_from_slice ( & self . ct_x ) ;
219+ buffer[ 1088 ..] . copy_from_slice ( self . ct_x . as_bytes ( ) ) ;
228220 buffer
229221 }
230222}
@@ -238,7 +230,7 @@ impl From<&[u8; CIPHERTEXT_SIZE]> for Ciphertext {
238230
239231 Ciphertext {
240232 ct_m : ct_m. into ( ) ,
241- ct_x,
233+ ct_x : ct_x . into ( ) ,
242234 }
243235 }
244236}
@@ -258,9 +250,9 @@ pub fn generate_key_pair(rng: &mut impl CryptoRngCore) -> (DecapsulationKey, Enc
258250
259251fn combiner (
260252 ss_m : & B32 ,
261- ss_x : & [ u8 ; 32 ] ,
262- ct_x : & [ u8 ; 32 ] ,
263- pk_x : & x25519_dalek :: PublicKey ,
253+ ss_x : & x25519_dalek :: SharedSecret ,
254+ ct_x : & PublicKey ,
255+ pk_x : & PublicKey ,
264256) -> SharedSecret {
265257 use sha3:: Digest ;
266258
@@ -292,8 +284,8 @@ mod tests {
292284
293285 use super :: * ;
294286
295- struct SeedRng {
296- seed : Vec < u8 > ,
287+ pub ( crate ) struct SeedRng {
288+ pub ( crate ) seed : Vec < u8 > ,
297289 }
298290
299291 impl SeedRng {
@@ -360,14 +352,14 @@ mod tests {
360352 let mut seed = SeedRng :: new ( test_vector. seed ) ;
361353 let ( sk, pk) = generate_key_pair ( & mut seed) ;
362354
363- assert_eq ! ( sk. as_bytes( ) . to_vec ( ) , test_vector. sk) ;
364- assert_eq ! ( pk. as_bytes ( ) . to_vec ( ) , test_vector. pk) ;
355+ assert_eq ! ( sk. as_bytes( ) , & test_vector. sk) ;
356+ assert_eq ! ( & pk. to_bytes ( ) , test_vector. pk. as_slice ( ) ) ;
365357
366358 let mut eseed = SeedRng :: new ( test_vector. eseed ) ;
367359 let ( ct, ss) = pk. encapsulate ( & mut eseed) . unwrap ( ) ;
368360
369361 assert_eq ! ( ss, test_vector. ss) ;
370- assert_eq ! ( ct. as_bytes ( ) . to_vec ( ) , test_vector. ct) ;
362+ assert_eq ! ( & ct. to_bytes ( ) , test_vector. ct. as_slice ( ) ) ;
371363
372364 let ss = sk. decapsulate ( & ct) . unwrap ( ) ;
373365 assert_eq ! ( ss, test_vector. ss) ;
@@ -379,10 +371,10 @@ mod tests {
379371
380372 let ct_a = Ciphertext {
381373 ct_m : generate ( & mut rng) . into ( ) ,
382- ct_x : generate ( & mut rng) ,
374+ ct_x : generate ( & mut rng) . into ( ) ,
383375 } ;
384376
385- let bytes = ct_a. as_bytes ( ) ;
377+ let bytes = ct_a. to_bytes ( ) ;
386378
387379 let ct_b = Ciphertext :: from ( & bytes) ;
388380
@@ -395,10 +387,10 @@ mod tests {
395387 let pk = sk. encapsulation_key ( ) ;
396388
397389 let sk_bytes = sk. as_bytes ( ) ;
398- let pk_bytes = pk. as_bytes ( ) ;
390+ let pk_bytes = pk. to_bytes ( ) ;
399391
400392 let sk_b = DecapsulationKey :: from ( * sk_bytes) ;
401- let pk_b = EncapsulationKey :: from ( & pk_bytes. clone ( ) ) ;
393+ let pk_b = EncapsulationKey :: from ( & pk_bytes) ;
402394
403395 assert ! ( sk == sk_b) ;
404396 assert ! ( pk == pk_b) ;
0 commit comments