diff --git a/ml-kem/Cargo.toml b/ml-kem/Cargo.toml index cff7ccc..a346650 100644 --- a/ml-kem/Cargo.toml +++ b/ml-kem/Cargo.toml @@ -20,6 +20,7 @@ default = ["std"] std = ["sha3/std"] deterministic = [] # Expose deterministic generation and encapsulation functions zeroize = ["dep:zeroize"] +decap_key = [] # Use seed for decapsulation key (default behaviour) or not. If set, will use standard decapsulation key. [dependencies] kem = "0.3.0-pre.0" diff --git a/ml-kem/src/kem.rs b/ml-kem/src/kem.rs index 4abc52a..0408b39 100644 --- a/ml-kem/src/kem.rs +++ b/ml-kem/src/kem.rs @@ -1,6 +1,8 @@ use core::convert::Infallible; use core::marker::PhantomData; use hybrid_array::typenum::U32; +#[cfg(not(feature = "decap_key"))] +use hybrid_array::typenum::U64; use rand_core::CryptoRngCore; use crate::crypto::{rand, G, H, J}; @@ -18,10 +20,19 @@ pub use ::kem::{Decapsulate, Encapsulate}; /// A shared key resulting from an ML-KEM transaction pub(crate) type SharedKey = B32; -/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an -/// encapsulated shared key. +#[cfg(not(feature = "decap_key"))] #[derive(Clone, Debug, PartialEq)] -pub struct DecapsulationKey

+struct DecapsulationSeed

+where + P: KemParams, +{ + d: B32, + z: B32, + _phantom: PhantomData

, +} + +#[derive(Clone, Debug, PartialEq)] +struct DecapsulationKeyInner

where P: KemParams, { @@ -30,8 +41,29 @@ where z: B32, } +/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an +/// encapsulated shared key. +#[cfg(feature = "decap_key")] +#[derive(Clone, Debug, PartialEq)] +pub struct DecapsulationKey

+where + P: KemParams, +{ + key: DecapsulationKeyInner

, +} +/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an +/// encapsulated shared key. +#[cfg(not(feature = "decap_key"))] +#[derive(Clone, Debug, PartialEq)] +pub struct DecapsulationKey

+where + P: KemParams, +{ + key: DecapsulationSeed

, +} + #[cfg(feature = "zeroize")] -impl

Drop for DecapsulationKey

+impl

Drop for DecapsulationKeyInner

where P: KemParams, { @@ -41,10 +73,59 @@ where } } +#[cfg(all(feature = "zeroize", not(feature = "decap_key")))] +impl

Drop for DecapsulationSeed

+where + P: KemParams, +{ + fn drop(&mut self) { + self.d.zeroize(); + self.z.zeroize(); + } +} + +#[cfg(feature = "zeroize")] +impl

Zeroize for DecapsulationKeyInner

+where + P: KemParams, +{ + fn zeroize(&mut self) { + self.dk_pke.zeroize(); + self.z.zeroize(); + } +} + +#[cfg(all(feature = "zeroize", not(feature = "decap_key")))] +impl

Zeroize for DecapsulationSeed

+where + P: KemParams, +{ + fn zeroize(&mut self) { + self.d.zeroize(); + self.z.zeroize(); + } +} + +#[cfg(feature = "zeroize")] +impl

Drop for DecapsulationKey

+where + P: KemParams, +{ + fn drop(&mut self) { + self.key.zeroize(); + } +} + +#[cfg(feature = "zeroize")] +impl

ZeroizeOnDrop for DecapsulationKeyInner

where P: KemParams {} + +#[cfg(all(feature = "zeroize", not(feature = "decap_key")))] +impl

ZeroizeOnDrop for DecapsulationSeed

where P: KemParams {} + #[cfg(feature = "zeroize")] impl

ZeroizeOnDrop for DecapsulationKey

where P: KemParams {} -impl

EncodedSizeUser for DecapsulationKey

+impl

EncodedSizeUser for DecapsulationKeyInner

where P: KemParams, { @@ -75,6 +156,59 @@ where } } +#[cfg(not(feature = "decap_key"))] +impl

EncodedSizeUser for DecapsulationSeed

+where + P: KemParams, +{ + type EncodedSize = U64; + + #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec + fn from_bytes(enc: &Encoded) -> Self { + let (d, z) = P::split_seed(enc); + + Self { + d: d.clone(), + z: z.clone(), + _phantom: PhantomData, + } + } + + fn as_bytes(&self) -> Encoded { + self.d.clone().concat(self.z.clone()) + } +} + +impl

EncodedSizeUser for DecapsulationKey

+where + P: KemParams, +{ + #[cfg(feature = "decap_key")] + type EncodedSize = DecapsulationKeySize

; + #[cfg(not(feature = "decap_key"))] + type EncodedSize = U64; + + #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec + fn from_bytes(enc: &Encoded) -> Self { + #[cfg(feature = "decap_key")] + { + Self { + key: DecapsulationKeyInner::

::from_bytes(enc), + } + } + #[cfg(not(feature = "decap_key"))] + { + Self { + key: DecapsulationSeed::

::from_bytes(enc), + } + } + } + + fn as_bytes(&self) -> Encoded { + self.key.as_bytes() + } +} + // 0xff if x == y, 0x00 otherwise fn constant_time_eq(x: u8, y: u8) -> u8 { let diff = x ^ y; @@ -82,7 +216,7 @@ fn constant_time_eq(x: u8, y: u8) -> u8 { 0u8.wrapping_sub(is_zero >> 7) } -impl

::kem::Decapsulate, SharedKey> for DecapsulationKey

+impl

::kem::Decapsulate, SharedKey> for DecapsulationKeyInner

where P: KemParams, { @@ -117,15 +251,46 @@ where } } -impl

DecapsulationKey

+#[cfg(not(feature = "decap_key"))] +impl

::kem::Decapsulate, SharedKey> for DecapsulationSeed

where P: KemParams, { - /// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`]. - pub fn encapsulation_key(&self) -> &EncapsulationKey

{ - &self.ek + type Error = Infallible; + + fn decapsulate( + &self, + encapsulated_key: &EncodedCiphertext

, + ) -> Result { + DecapsulationKeyInner::

::generate_deterministic(&self.d, &self.z) + .decapsulate(encapsulated_key) } +} +impl

::kem::Decapsulate, SharedKey> for DecapsulationKey

+where + P: KemParams, +{ + type Error = Infallible; + + fn decapsulate( + &self, + encapsulated_key: &EncodedCiphertext

, + ) -> Result { + self.key.decapsulate(encapsulated_key) + } +} + +impl

DecapsulationKeyInner

+where + P: KemParams, +{ + /// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKeyInner`]. + pub fn encapsulation_key(&self) -> EncapsulationKey

{ + self.ek.clone() + } + + #[cfg(feature = "decap_key")] pub(crate) fn generate(rng: &mut impl CryptoRngCore) -> Self { let d: B32 = rand(rng); let z: B32 = rand(rng); @@ -142,6 +307,85 @@ where } } +#[cfg(not(feature = "decap_key"))] +impl

DecapsulationSeed

+where + P: KemParams, +{ + /// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationSeed`]. + #[must_use] + pub fn encapsulation_key(&self) -> EncapsulationKey

{ + DecapsulationKeyInner::

::generate_deterministic(&self.d, &self.z) + .encapsulation_key() + .clone() + } + + pub(crate) fn generate(rng: &mut impl CryptoRngCore) -> Self { + let d: B32 = rand(rng); + let z: B32 = rand(rng); + Self { + d, + z, + _phantom: PhantomData, + } + } + + #[must_use] + #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec + #[cfg(feature = "deterministic")] + pub(crate) fn generate_deterministic(d: &B32, z: &B32) -> Self { + Self { + d: *d, + z: *z, + _phantom: PhantomData, + } + } +} + +impl

DecapsulationKey

+where + P: KemParams, +{ + /// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`]. + #[must_use] + pub fn encapsulation_key(&self) -> EncapsulationKey

{ + self.key.encapsulation_key() + } + + pub(crate) fn generate(rng: &mut impl CryptoRngCore) -> Self { + #[cfg(not(feature = "decap_key"))] + { + DecapsulationKey { + key: DecapsulationSeed::

::generate(rng), + } + } + #[cfg(feature = "decap_key")] + { + DecapsulationKey { + key: DecapsulationKeyInner::

::generate(rng), + } + } + } + + #[must_use] + #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec + #[cfg(feature = "deterministic")] + pub(crate) fn generate_deterministic(d: &B32, z: &B32) -> Self { + #[cfg(not(feature = "decap_key"))] + { + DecapsulationKey { + key: DecapsulationSeed::

::generate_deterministic(d, z), + } + } + #[cfg(feature = "decap_key")] + { + DecapsulationKey { + key: DecapsulationKeyInner::

::generate_deterministic(d, z), + } + } + } +} + /// An `EncapsulationKey` provides the ability to encapsulate a shared key so that it can only be /// decapsulated by the holder of the corresponding decapsulation key. #[derive(Clone, Debug, PartialEq)] diff --git a/ml-kem/src/param.rs b/ml-kem/src/param.rs index 2a9497c..6fb2a58 100644 --- a/ml-kem/src/param.rs +++ b/ml-kem/src/param.rs @@ -251,12 +251,15 @@ pub trait KemParams: PkeParams { &B32, &B32, ); + + fn split_seed(enc: &EncodedDecapsulationSeed) -> (&B32, &B32); } pub type DecapsulationKeySize

=

::DecapsulationKeySize; pub type EncapsulationKeySize

=

::EncryptionKeySize; pub type EncodedDecapsulationKey

= Array::DecapsulationKeySize>; +pub type EncodedDecapsulationSeed = Array; impl

KemParams for P where @@ -295,4 +298,9 @@ where let (dk_pke, ek_pke) = enc.split_ref(); (dk_pke, ek_pke, h, z) } + + fn split_seed(enc: &EncodedDecapsulationSeed) -> (&B32, &B32) { + let (d, z) = enc.split_ref(); + (d, z) + } } diff --git a/ml-kem/tests/encap-decap.rs b/ml-kem/tests/encap-decap.rs index 64d0247..26cab67 100644 --- a/ml-kem/tests/encap-decap.rs +++ b/ml-kem/tests/encap-decap.rs @@ -1,4 +1,5 @@ #![cfg(feature = "deterministic")] +#![cfg(feature = "decap_key")] use ml_kem::*; diff --git a/ml-kem/tests/key-gen.rs b/ml-kem/tests/key-gen.rs index 3c855fd..fdf7834 100644 --- a/ml-kem/tests/key-gen.rs +++ b/ml-kem/tests/key-gen.rs @@ -1,4 +1,5 @@ #![cfg(feature = "deterministic")] +#![cfg(feature = "decap_key")] use ml_kem::*;