diff --git a/Cargo.lock b/Cargo.lock index 9823a6ca10..437a5ec908 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -853,6 +853,7 @@ dependencies = [ "log", "log-reload", "macro_rules_attribute 0.2.2", + "mls-crypto-provider", "obfuscate", "paste", "proteus-wasm", @@ -2898,6 +2899,7 @@ dependencies = [ "hex", "log", "openmls", + "openmls_basic_credential", "rand", "sha2", ] diff --git a/crypto-ffi/Cargo.toml b/crypto-ffi/Cargo.toml index e16856fd47..3d087cce36 100644 --- a/crypto-ffi/Cargo.toml +++ b/crypto-ffi/Cargo.toml @@ -25,24 +25,25 @@ default = ["proteus"] proteus = ["core-crypto/proteus", "dep:proteus-wasm"] [dependencies] -thiserror.workspace = true -cfg-if.workspace = true -futures-util.workspace = true -async-trait.workspace = true -tls_codec.workspace = true async-lock.workspace = true -log.workspace = true -log-reload.workspace = true -serde_json.workspace = true -derive_more.workspace = true -proteus-wasm = { workspace = true, optional = true } +async-trait.workspace = true +cfg-if.workspace = true core-crypto-keystore.workspace = true core-crypto-macros.workspace = true core-crypto.workspace = true +derive_more.workspace = true +futures-util.workspace = true +hex.workspace = true +log-reload.workspace = true +log.workspace = true +mls-crypto-provider.workspace = true obfuscate.workspace = true -rmp-serde.workspace = true paste = "1.0.15" -hex.workspace = true +proteus-wasm = { workspace = true, optional = true } +rmp-serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true +tls_codec.workspace = true # see https://github.com/RustCrypto/hashes/issues/404 [target.'cfg(not(any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")))'.dependencies] diff --git a/crypto-ffi/bindings/js/src/CoreCrypto.ts b/crypto-ffi/bindings/js/src/CoreCrypto.ts index 97dd209bfa..801d06846e 100644 --- a/crypto-ffi/bindings/js/src/CoreCrypto.ts +++ b/crypto-ffi/bindings/js/src/CoreCrypto.ts @@ -70,6 +70,7 @@ export type { } from "./CoreCryptoInstance"; export { + Credential, CredentialType, WirePolicy, GroupInfoEncryptionType, diff --git a/crypto-ffi/bindings/js/src/CoreCryptoMLS.ts b/crypto-ffi/bindings/js/src/CoreCryptoMLS.ts index 68605ecfab..bbfe5505d4 100644 --- a/crypto-ffi/bindings/js/src/CoreCryptoMLS.ts +++ b/crypto-ffi/bindings/js/src/CoreCryptoMLS.ts @@ -3,6 +3,7 @@ import { BufferedDecryptedMessage as BufferedDecryptedMessageFfi, CommitBundle as CommitBundleFfi, CredentialType, + Credential, DecryptedMessage as DecryptedMessageFfi, DeviceStatus, MlsGroupInfoEncryptionType as GroupInfoEncryptionType, @@ -20,6 +21,7 @@ import { } from "./autogenerated/core-crypto-ffi"; export { + Credential, CredentialType, DeviceStatus, GroupInfoEncryptionType, diff --git a/crypto-ffi/bindings/js/test/bun/credential.test.ts b/crypto-ffi/bindings/js/test/bun/credential.test.ts new file mode 100644 index 0000000000..9cb9015e42 --- /dev/null +++ b/crypto-ffi/bindings/js/test/bun/credential.test.ts @@ -0,0 +1,27 @@ +import { setup, teardown } from "./utils"; +import { afterEach, test, beforeEach, describe, expect } from "bun:test"; +import { + ciphersuiteDefault, + ClientId, + Credential, + CredentialType, +} from "../../src/CoreCrypto"; + +beforeEach(async () => { + await setup(); +}); + +afterEach(async () => { + await teardown(); +}); + +describe("credentials", () => { + test("basic credential can be created", async () => { + const credential = Credential.basic( + ciphersuiteDefault(), + new ClientId(Buffer.from("any random client id here")) + ); + expect(credential.type()).toEqual(CredentialType.Basic); + expect(credential.earliest_validity()).toEqual(0n); + }); +}); diff --git a/crypto-ffi/bindings/jvm/src/main/kotlin/com/wire/crypto/MlsModel.kt b/crypto-ffi/bindings/jvm/src/main/kotlin/com/wire/crypto/MlsModel.kt index 39d8e41e10..de450c189a 100644 --- a/crypto-ffi/bindings/jvm/src/main/kotlin/com/wire/crypto/MlsModel.kt +++ b/crypto-ffi/bindings/jvm/src/main/kotlin/com/wire/crypto/MlsModel.kt @@ -32,3 +32,10 @@ fun ByteArray.toAvsSecret() = SecretKey(this) /** Construct a GroupInfo from bytes */ fun ByteArray.toGroupInfo() = GroupInfo(this) + +/** Construct a new Credential from ciphersuite and client id */ +@Throws(CoreCryptoException::class) +fun Credential.Companion.basic( + ciphersuite: Ciphersuite, + clientId: ClientId +): Credential = credentialBasic(ciphersuite, clientId) diff --git a/crypto-ffi/bindings/jvm/src/test/kotlin/com/wire/crypto/MLSTest.kt b/crypto-ffi/bindings/jvm/src/test/kotlin/com/wire/crypto/MLSTest.kt index 5121b6d44d..8728590590 100644 --- a/crypto-ffi/bindings/jvm/src/test/kotlin/com/wire/crypto/MLSTest.kt +++ b/crypto-ffi/bindings/jvm/src/test/kotlin/com/wire/crypto/MLSTest.kt @@ -475,4 +475,14 @@ class MLSTest : HasMockDeliveryService() { ) } } + + @Test + fun can_construct_basic_credential(): TestResult { + val scope = TestScope() + return scope.runTest { + val credential = Credential.basic(CIPHERSUITE_DEFAULT, genClientId()) + assertEquals(credential.type(), CredentialType.BASIC) + assertEquals(credential.earliestValidity(), 0u) + } + } } diff --git a/crypto-ffi/src/ciphersuite.rs b/crypto-ffi/src/ciphersuite.rs index ca59a51738..d067dbb338 100644 --- a/crypto-ffi/src/ciphersuite.rs +++ b/crypto-ffi/src/ciphersuite.rs @@ -4,7 +4,7 @@ //! it doesn't work on newtypes around external enums. We therefore redefine the ciphersuites enum //! here with appropriate annotations such that it gets exported to all relevant bindings. -use core_crypto::{CiphersuiteName, MlsCiphersuite}; +use core_crypto::{Ciphersuite as CryptoCiphersuite, MlsCiphersuite}; #[cfg(target_family = "wasm")] use wasm_bindgen::prelude::*; @@ -42,7 +42,7 @@ pub enum Ciphersuite { MLS_256_DHKEMP384_AES256GCM_SHA384_P384 = 0x0007, } -impl From for CiphersuiteName { +impl From for MlsCiphersuite { #[inline] fn from(value: Ciphersuite) -> Self { (value as u16) @@ -51,26 +51,26 @@ impl From for CiphersuiteName { } } -impl From for Ciphersuite { +impl From for Ciphersuite { #[inline] - fn from(value: CiphersuiteName) -> Self { + fn from(value: MlsCiphersuite) -> Self { (value as u16) .try_into() .expect("mls Ciphersuite is a subset of ffi Ciphersuite") } } -impl From for MlsCiphersuite { +impl From for CryptoCiphersuite { #[inline] fn from(value: Ciphersuite) -> Self { - CiphersuiteName::from(value).into() + MlsCiphersuite::from(value).into() } } -impl From for Ciphersuite { +impl From for Ciphersuite { #[inline] - fn from(value: MlsCiphersuite) -> Self { - CiphersuiteName::from(value).into() + fn from(value: CryptoCiphersuite) -> Self { + MlsCiphersuite::from(value).into() } } diff --git a/crypto-ffi/src/core_crypto/conversation.rs b/crypto-ffi/src/core_crypto/conversation.rs index 8cb8567a36..da2a197092 100644 --- a/crypto-ffi/src/core_crypto/conversation.rs +++ b/crypto-ffi/src/core_crypto/conversation.rs @@ -56,7 +56,7 @@ impl CoreCryptoFfi { .map_err(RecursiveError::mls_client("getting raw conversation by id"))? .ciphersuite() .await; - Ok(Ciphersuite::from(core_crypto::CiphersuiteName::from(cs))) + Ok(Ciphersuite::from(core_crypto::MlsCiphersuite::from(cs))) } /// See [core_crypto::Session::conversation_exists] diff --git a/crypto-ffi/src/core_crypto/e2ei/identities.rs b/crypto-ffi/src/core_crypto/e2ei/identities.rs index 81382d6dce..29ff61446a 100644 --- a/crypto-ffi/src/core_crypto/e2ei/identities.rs +++ b/crypto-ffi/src/core_crypto/e2ei/identities.rs @@ -38,8 +38,8 @@ impl CoreCryptoFfi { .get_device_identities(&device_ids) .await? .into_iter() - .map(WireIdentity::from) - .collect::>(); + .map(WireIdentity::try_from) + .collect::>>()?; #[cfg(target_family = "wasm")] let wire_identities = serde_wasm_bindgen::to_value(&wire_identities).expect("device identities can always be serialized"); @@ -64,8 +64,14 @@ impl CoreCryptoFfi { let identities = conversation.get_user_identities(user_ids.as_slice()).await?; let identities = identities .into_iter() - .map(|(k, v)| (k, v.into_iter().map(WireIdentity::from).collect())) - .collect::>>(); + .map(|(k, v)| -> CoreCryptoResult<_> { + let identities = v + .into_iter() + .map(WireIdentity::try_from) + .collect::>>()?; + Ok((k, identities)) + }) + .collect::>>()?; #[cfg(target_family = "wasm")] let identities = serde_wasm_bindgen::to_value(&identities).expect("user identities can always be serialized"); Ok(identities) diff --git a/crypto-ffi/src/core_crypto/e2ei/mod.rs b/crypto-ffi/src/core_crypto/e2ei/mod.rs index e1e78311c6..3d9268f405 100644 --- a/crypto-ffi/src/core_crypto/e2ei/mod.rs +++ b/crypto-ffi/src/core_crypto/e2ei/mod.rs @@ -17,7 +17,7 @@ impl CoreCryptoFfi { /// See [core_crypto::Session::e2ei_is_enabled] pub async fn e2ei_is_enabled(&self, ciphersuite: Ciphersuite) -> CoreCryptoResult { - let signature_scheme = core_crypto::MlsCiphersuite::from(ciphersuite).signature_algorithm(); + let signature_scheme = core_crypto::Ciphersuite::from(ciphersuite).signature_algorithm(); self.inner .e2ei_is_enabled(signature_scheme) .await diff --git a/crypto-ffi/src/core_crypto_context/e2ei.rs b/crypto-ffi/src/core_crypto_context/e2ei.rs index a08190e3a4..750ada583e 100644 --- a/crypto-ffi/src/core_crypto_context/e2ei.rs +++ b/crypto-ffi/src/core_crypto_context/e2ei.rs @@ -198,7 +198,7 @@ impl CoreCryptoContext { /// See [core_crypto::Session::e2ei_is_enabled] pub async fn e2ei_is_enabled(&self, ciphersuite: Ciphersuite) -> CoreCryptoResult { - let sc = core_crypto::MlsCiphersuite::from(ciphersuite).signature_algorithm(); + let sc = core_crypto::Ciphersuite::from(ciphersuite).signature_algorithm(); self.inner .e2ei_is_enabled(sc) .await @@ -216,7 +216,7 @@ impl CoreCryptoContext { let conversation = self.inner.conversation(conversation_id.as_ref()).await?; let wire_ids = conversation.get_device_identities(device_ids.as_slice()).await?; - Ok(wire_ids.into_iter().map(Into::into).collect()) + wire_ids.into_iter().map(TryInto::try_into).collect() } /// See [core_crypto::mls::conversation::Conversation::get_user_identities] @@ -233,8 +233,14 @@ impl CoreCryptoContext { let user_ids = conversation.get_user_identities(user_ids.as_slice()).await?; let user_ids = user_ids .into_iter() - .map(|(k, v)| (k, v.into_iter().map(WireIdentity::from).collect())) - .collect::>>(); + .map(|(k, v)| -> CoreCryptoResult<_> { + let identities = v + .into_iter() + .map(WireIdentity::try_from) + .collect::>>()?; + Ok((k, identities)) + }) + .collect::>>()?; #[cfg(target_family = "wasm")] let user_ids = serde_wasm_bindgen::to_value(&user_ids)?; Ok(user_ids) diff --git a/crypto-ffi/src/core_crypto_context/mls.rs b/crypto-ffi/src/core_crypto_context/mls.rs index c422d47f84..b8af81948f 100644 --- a/crypto-ffi/src/core_crypto_context/mls.rs +++ b/crypto-ffi/src/core_crypto_context/mls.rs @@ -1,6 +1,6 @@ use core_crypto::{ - ClientIdentifier, KeyPackageIn, MlsCiphersuite, MlsConversationConfiguration, RecursiveError, VerifiableGroupInfo, - mls::conversation::Conversation as _, transaction_context::Error as TransactionError, + Ciphersuite as CryptoCiphersuite, ClientIdentifier, KeyPackageIn, MlsConversationConfiguration, RecursiveError, + VerifiableGroupInfo, mls::conversation::Conversation as _, transaction_context::Error as TransactionError, }; use tls_codec::{Deserialize as _, Serialize as _}; #[cfg(target_family = "wasm")] @@ -60,7 +60,10 @@ impl CoreCryptoContext { self.inner .mls_init( ClientIdentifier::Basic(client_id.as_cc()), - &ciphersuites.into_iter().map(MlsCiphersuite::from).collect::>(), + &ciphersuites + .into_iter() + .map(CryptoCiphersuite::from) + .collect::>(), ) .await?; Ok(()) diff --git a/crypto-ffi/src/credential.rs b/crypto-ffi/src/credential.rs new file mode 100644 index 0000000000..39b0c30e25 --- /dev/null +++ b/crypto-ffi/src/credential.rs @@ -0,0 +1,66 @@ +use core_crypto::{Ciphersuite as CryptoCiphersuite, Credential as CryptoCredential}; +use mls_crypto_provider::RustCrypto; +#[cfg(target_family = "wasm")] +use wasm_bindgen::prelude::*; + +use crate::{Ciphersuite, CoreCryptoResult, CredentialType, client_id::ClientIdMaybeArc}; + +/// A cryptographic credential. +/// +/// This is tied to a particular client via either its client id or certificate bundle, +/// depending on its credential type, but is independent of any client instance or storage. +/// +/// To attach to a particular client instance and store, see [`CoreCryptoContext::add_credential`][crate::CoreCryptoContext::add_credential]. +#[derive(Debug, Clone, derive_more::From, derive_more::Into)] +#[cfg_attr(target_family = "wasm", wasm_bindgen, derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(not(target_family = "wasm"), derive(uniffi::Object))] +pub struct Credential(CryptoCredential); + +impl Credential { + fn basic_impl(ciphersuite: Ciphersuite, client_id: ClientIdMaybeArc) -> CoreCryptoResult { + let crypto = RustCrypto::default(); + CryptoCredential::basic( + CryptoCiphersuite::from(ciphersuite).signature_algorithm(), + &client_id.as_cc(), + crypto, + ) + .map(Into::into) + .map_err(Into::into) + } +} + +/// Generate a basic credential. +/// +/// The result is independent of any client instance and the database; it lives in memory only. +#[cfg(not(target_family = "wasm"))] +#[uniffi::export] +pub fn credential_basic(ciphersuite: Ciphersuite, client_id: ClientIdMaybeArc) -> CoreCryptoResult { + Credential::basic_impl(ciphersuite, client_id) +} + +#[cfg(target_family = "wasm")] +#[wasm_bindgen] +impl Credential { + /// Generate a basic credential. + /// + /// The result is independent of any client instance and the database; it lives in memory only. + pub fn basic(ciphersuite: Ciphersuite, client_id: ClientIdMaybeArc) -> CoreCryptoResult { + Credential::basic_impl(ciphersuite, client_id) + } +} + +#[cfg_attr(target_family = "wasm", wasm_bindgen)] +#[cfg_attr(not(target_family = "wasm"), uniffi::export)] +impl Credential { + /// Get the type of this credential. + pub fn r#type(&self) -> CoreCryptoResult { + self.0.mls_credential().credential_type().try_into() + } + + /// Get the earliest possible validity of this credential, expressed as seconds after the unix epoch. + /// + /// Basic credentials have no defined earliest validity and will always return 0. + pub fn earliest_validity(&self) -> u64 { + self.0.earliest_validity() + } +} diff --git a/crypto-ffi/src/credential_type.rs b/crypto-ffi/src/credential_type.rs index 6e8dec9061..d9dfa7b746 100644 --- a/crypto-ffi/src/credential_type.rs +++ b/crypto-ffi/src/credential_type.rs @@ -1,6 +1,8 @@ #[cfg(target_family = "wasm")] use wasm_bindgen::prelude::*; +use crate::CoreCryptoError; + /// Type of Credential #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[cfg_attr(target_family = "wasm", wasm_bindgen, derive(serde::Serialize, serde::Deserialize))] @@ -14,20 +16,22 @@ pub enum CredentialType { X509 = 0x02, } -impl From for CredentialType { - fn from(value: core_crypto::MlsCredentialType) -> Self { +impl TryFrom for CredentialType { + type Error = CoreCryptoError; + fn try_from(value: core_crypto::CredentialType) -> Result { match value { - core_crypto::MlsCredentialType::Basic => Self::Basic, - core_crypto::MlsCredentialType::X509 => Self::X509, + core_crypto::CredentialType::Basic => Ok(Self::Basic), + core_crypto::CredentialType::X509 => Ok(Self::X509), + core_crypto::CredentialType::Unknown(_) => Err(CoreCryptoError::ad_hoc("unknown credential type")), } } } -impl From for core_crypto::MlsCredentialType { - fn from(value: CredentialType) -> core_crypto::MlsCredentialType { +impl From for core_crypto::CredentialType { + fn from(value: CredentialType) -> core_crypto::CredentialType { match value { - CredentialType::Basic => core_crypto::MlsCredentialType::Basic, - CredentialType::X509 => core_crypto::MlsCredentialType::X509, + CredentialType::Basic => core_crypto::CredentialType::Basic, + CredentialType::X509 => core_crypto::CredentialType::X509, } } } diff --git a/crypto-ffi/src/decrypted_message.rs b/crypto-ffi/src/decrypted_message.rs index ed19380a5c..9a47a71469 100644 --- a/crypto-ffi/src/decrypted_message.rs +++ b/crypto-ffi/src/decrypted_message.rs @@ -68,16 +68,17 @@ impl TryFrom for DecryptedMessage { .transpose()?; #[expect(deprecated)] - Ok(Self { + let msg = Self { message: from.app_msg, is_active: from.is_active, commit_delay: from.delay, sender_client_id: from.sender_client_id.map(ClientId::from_cc), has_epoch_changed: from.has_epoch_changed, - identity: from.identity.into(), + identity: from.identity.try_into()?, buffered_messages, crl_new_distribution_points: from.crl_new_distribution_points.into(), - }) + }; + Ok(msg) } } @@ -124,14 +125,15 @@ impl TryFrom for BufferedDecryptedMessage fn try_from(from: MlsBufferedConversationDecryptMessage) -> Result { #[expect(deprecated)] - Ok(Self { + let msg = Self { message: from.app_msg, is_active: from.is_active, commit_delay: from.delay, sender_client_id: from.sender_client_id.map(ClientId::from_cc), has_epoch_changed: from.has_epoch_changed, - identity: from.identity.into(), + identity: from.identity.try_into()?, crl_new_distribution_points: from.crl_new_distribution_points.into(), - }) + }; + Ok(msg) } } diff --git a/crypto-ffi/src/error/core_crypto.rs b/crypto-ffi/src/error/core_crypto.rs index 13f2e9897a..49a7d6827a 100644 --- a/crypto-ffi/src/error/core_crypto.rs +++ b/crypto-ffi/src/error/core_crypto.rs @@ -241,6 +241,7 @@ macro_rules! impl_from_via_recursive_error { impl_from_via_recursive_error!( core_crypto::mls::Error, core_crypto::mls::conversation::Error, + core_crypto::mls::credential::Error, core_crypto::e2e_identity::Error, core_crypto::transaction_context::Error, ); diff --git a/crypto-ffi/src/identity/wire.rs b/crypto-ffi/src/identity/wire.rs index d64458db2a..4caef352ed 100644 --- a/crypto-ffi/src/identity/wire.rs +++ b/crypto-ffi/src/identity/wire.rs @@ -1,7 +1,7 @@ #[cfg(target_family = "wasm")] use wasm_bindgen::prelude::*; -use crate::{CredentialType, X509Identity}; +use crate::{CoreCryptoError, CredentialType, X509Identity}; /// Represents the identity claims identifying a client /// Those claims are verifiable by any member in the group @@ -31,15 +31,17 @@ pub struct WireIdentity { pub x509_identity: Option, } -impl From for WireIdentity { - fn from(i: core_crypto::WireIdentity) -> Self { - Self { +impl TryFrom for WireIdentity { + type Error = CoreCryptoError; + fn try_from(i: core_crypto::WireIdentity) -> Result { + let identity = Self { client_id: i.client_id, status: i.status.into(), thumbprint: i.thumbprint, - credential_type: i.credential_type.into(), + credential_type: i.credential_type.try_into()?, x509_identity: i.x509_identity.map(Into::into), - } + }; + Ok(identity) } } diff --git a/crypto-ffi/src/lib.rs b/crypto-ffi/src/lib.rs index baf4fd5aba..df04cec790 100644 --- a/crypto-ffi/src/lib.rs +++ b/crypto-ffi/src/lib.rs @@ -16,6 +16,7 @@ mod client_id; mod configuration; mod core_crypto; mod core_crypto_context; +mod credential; mod credential_type; mod crl; mod database; @@ -49,6 +50,9 @@ pub(crate) use core_crypto::{ e2ei::identities::UserIdentities, }; pub use core_crypto_context::CoreCryptoContext; +pub use credential::Credential; +#[cfg(not(target_family = "wasm"))] +pub use credential::credential_basic; pub use credential_type::CredentialType; pub use crl::CrlRegistration; pub use database::{ diff --git a/crypto-macros/src/debug.rs b/crypto-macros/src/debug.rs index 449c9d2466..204e3cf10b 100644 --- a/crypto-macros/src/debug.rs +++ b/crypto-macros/src/debug.rs @@ -49,7 +49,7 @@ fn parse_type(ty: &Type) -> BytesType { type_string.retain(|c| !c.is_whitespace()); match type_string.as_str() { "Option>" => BytesType::OptionalBytes, - "Vec" => BytesType::Bytes, + "Vec" | "[u8]" => BytesType::Bytes, _ => BytesType::Other, } } diff --git a/crypto/Cargo.toml b/crypto/Cargo.toml index e37fb1c061..77f32a3784 100644 --- a/crypto/Cargo.toml +++ b/crypto/Cargo.toml @@ -36,7 +36,7 @@ hex.workspace = true futures-util.workspace = true openmls = { workspace = true, features = ["crypto-subtle"] } -openmls_basic_credential.workspace = true +openmls_basic_credential = { workspace = true, features = ["clonable"] } openmls_x509_credential.workspace = true openmls_traits.workspace = true mls-crypto-provider.workspace = true diff --git a/crypto/benches/create_group.rs b/crypto/benches/create_group.rs index 54e5dd5287..fdb57c79f3 100644 --- a/crypto/benches/create_group.rs +++ b/crypto/benches/create_group.rs @@ -1,6 +1,6 @@ use std::hint::black_box; -use core_crypto::{MlsConversationConfiguration, MlsCredentialType, MlsCustomConfiguration}; +use core_crypto::{CredentialType, MlsConversationConfiguration, MlsCustomConfiguration}; use criterion::{ BatchSize, BenchmarkId, Criterion, async_executor::SmolExecutor as FuturesExecutor, criterion_group, criterion_main, }; @@ -32,10 +32,7 @@ fn create_group_bench(c: &mut Criterion) { }, |(central, id, cfg)| async move { let context = central.new_transaction().await.unwrap(); - context - .new_conversation(&id, MlsCredentialType::Basic, cfg) - .await - .unwrap(); + context.new_conversation(&id, CredentialType::Basic, cfg).await.unwrap(); context.finish().await.unwrap(); black_box(()); }, @@ -62,7 +59,7 @@ fn join_from_welcome_bench(c: &mut Criterion) { let (bob_central, ..) = new_central(ciphersuite, credential.as_ref(), in_memory).await; let bob_context = bob_central.new_transaction().await.unwrap(); let bob_kpbs = bob_context - .get_or_create_client_keypackages(ciphersuite, MlsCredentialType::Basic, 1) + .get_or_create_client_keypackages(ciphersuite, CredentialType::Basic, 1) .await .unwrap(); let bob_kp = bob_kpbs.first().unwrap().clone(); @@ -120,7 +117,7 @@ fn join_from_group_info_bench(c: &mut Criterion) { .join_by_external_commit( group_info, MlsCustomConfiguration::default(), - MlsCredentialType::Basic, + CredentialType::Basic, ) .await .unwrap(), diff --git a/crypto/benches/key_package.rs b/crypto/benches/key_package.rs index d09fe076b4..9f111078da 100644 --- a/crypto/benches/key_package.rs +++ b/crypto/benches/key_package.rs @@ -1,6 +1,6 @@ use std::hint::black_box; -use core_crypto::MlsCredentialType; +use core_crypto::CredentialType; use criterion::{ BatchSize, Criterion, async_executor::SmolExecutor as FuturesExecutor, criterion_group, criterion_main, }; @@ -15,8 +15,8 @@ fn generate_key_package_bench(c: &mut Criterion) { for (case, ciphersuite, credential, in_memory) in MlsTestCase::values() { let credential_type = credential .as_ref() - .map(|_| MlsCredentialType::X509) - .unwrap_or(MlsCredentialType::Basic); + .map(|_| CredentialType::X509) + .unwrap_or(CredentialType::Basic); for i in (GROUP_RANGE).step_by(GROUP_STEP) { group.bench_with_input(case.benchmark_id(i + 1, in_memory), &i, |b, i| { b.to_async(FuturesExecutor).iter_batched( @@ -44,8 +44,8 @@ fn count_key_packages_bench(c: &mut Criterion) { for (case, ciphersuite, credential, in_memory) in MlsTestCase::values() { let credential_type = credential .as_ref() - .map(|_| MlsCredentialType::X509) - .unwrap_or(MlsCredentialType::Basic); + .map(|_| CredentialType::X509) + .unwrap_or(CredentialType::Basic); for i in (GROUP_RANGE).step_by(GROUP_STEP) { group.bench_with_input(case.benchmark_id(i + 1, in_memory), &i, |b, i| { b.to_async(FuturesExecutor).iter_batched( diff --git a/crypto/benches/utils/mls.rs b/crypto/benches/utils/mls.rs index eec5c78e4b..7c56a02c53 100644 --- a/crypto/benches/utils/mls.rs +++ b/crypto/benches/utils/mls.rs @@ -5,8 +5,8 @@ use std::{ use async_lock::RwLock; use core_crypto::{ - CertificateBundle, ClientId, ConnectionType, ConversationId, CoreCrypto, Database, DatabaseKey, HistorySecret, - MlsCiphersuite, MlsCommitBundle, MlsConversationConfiguration, MlsCredentialType, MlsCustomConfiguration, + CertificateBundle, Ciphersuite, ClientId, ConnectionType, ConversationId, CoreCrypto, CredentialType, Database, + DatabaseKey, HistorySecret, MlsCommitBundle, MlsConversationConfiguration, MlsCustomConfiguration, MlsGroupInfoBundle, MlsTransport, MlsTransportData, MlsTransportResponse, Session, SessionConfig, }; use criterion::BenchmarkId; @@ -18,7 +18,7 @@ use openmls::{ }, }; use openmls_basic_credential::SignatureKeyPair; -use openmls_traits::{OpenMlsCryptoProvider, random::OpenMlsRand, types::Ciphersuite}; +use openmls_traits::{OpenMlsCryptoProvider, random::OpenMlsRand, types::Ciphersuite as MlsCiphersuite}; use rand::distributions::{Alphanumeric, DistString}; use tls_codec::Deserialize; @@ -35,31 +35,35 @@ pub enum MlsTestCase { } impl MlsTestCase { - pub fn get(&self) -> (Self, MlsCiphersuite, Option) { + pub fn get(&self) -> (Self, Ciphersuite, Option) { match self { MlsTestCase::Basic_Ciphersuite1 => ( *self, - Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519.into(), + MlsCiphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519.into(), None, ), #[cfg(feature = "test-all-cipher")] - MlsTestCase::Basic_Ciphersuite2 => { - (*self, Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256.into(), None) - } + MlsTestCase::Basic_Ciphersuite2 => ( + *self, + MlsCiphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256.into(), + None, + ), #[cfg(feature = "test-all-cipher")] MlsTestCase::Basic_Ciphersuite3 => ( *self, - Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519.into(), + MlsCiphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519.into(), None, ), #[cfg(feature = "test-all-cipher")] - MlsTestCase::Basic_Ciphersuite7 => { - (*self, Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384.into(), None) - } + MlsTestCase::Basic_Ciphersuite7 => ( + *self, + MlsCiphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384.into(), + None, + ), } } - pub fn values() -> impl Iterator, bool)> { + pub fn values() -> impl Iterator, bool)> { [ MlsTestCase::Basic_Ciphersuite1, #[cfg(feature = "test-all-cipher")] @@ -121,7 +125,7 @@ impl Display for MlsTestCase { } pub async fn setup_mls( - ciphersuite: MlsCiphersuite, + ciphersuite: Ciphersuite, credential: Option<&CertificateBundle>, in_memory: bool, ) -> (CoreCrypto, ConversationId, Arc) { @@ -132,7 +136,7 @@ pub async fn setup_mls( context .new_conversation( &id, - MlsCredentialType::Basic, + CredentialType::Basic, MlsConversationConfiguration { ciphersuite, ..Default::default() @@ -146,7 +150,7 @@ pub async fn setup_mls( } pub async fn new_central( - ciphersuite: MlsCiphersuite, + ciphersuite: Ciphersuite, // TODO: always None for the moment. Need to update the benches with some realistic certificates. Tracking issue: WPB-9589 _credential: Option<&CertificateBundle>, in_memory: bool, @@ -161,7 +165,7 @@ pub async fn new_central( let db = Database::open(connection_type, &DatabaseKey::generate()).await.unwrap(); let cfg = SessionConfig::builder() .database(db) - .client_id(client_id.as_bytes().into()) + .client_id(client_id.as_bytes().to_owned().into()) .ciphersuites([ciphersuite]) .build() .validate() @@ -190,7 +194,7 @@ pub fn conversation_id() -> ConversationId { pub async fn add_clients( central: &mut Session, id: &ConversationId, - ciphersuite: MlsCiphersuite, + ciphersuite: Ciphersuite, nb_clients: usize, main_client_delivery_service: Arc, ) -> (Vec, VerifiableGroupInfo) { @@ -199,7 +203,7 @@ pub async fn add_clients( let mut key_packages = vec![]; for _ in 0..nb_clients { let (kp, id) = rand_key_package(ciphersuite).await; - client_ids.push(id.as_slice().into()); + client_ids.push(id); key_packages.push(kp.into()) } @@ -225,7 +229,7 @@ pub async fn add_clients( } pub async fn setup_mls_and_add_clients( - cipher_suite: MlsCiphersuite, + cipher_suite: Ciphersuite, credential: Option<&CertificateBundle>, in_memory: bool, client_count: usize, @@ -248,12 +252,12 @@ pub async fn setup_mls_and_add_clients( (core_crypto, id, client_ids, group_info, delivery_service) } -fn create_signature_keypair(backend: &MlsCryptoProvider, ciphersuite: Ciphersuite) -> SignatureKeyPair { +fn create_signature_keypair(backend: &MlsCryptoProvider, ciphersuite: MlsCiphersuite) -> SignatureKeyPair { let mut rng = backend.rand().borrow_rand().unwrap(); SignatureKeyPair::new(ciphersuite.signature_algorithm(), &mut *rng).unwrap() } -pub async fn rand_key_package(ciphersuite: MlsCiphersuite) -> (KeyPackage, ClientId) { +pub async fn rand_key_package(ciphersuite: Ciphersuite) -> (KeyPackage, ClientId) { let client_id = Alphanumeric .sample_string(&mut rand::thread_rng(), 16) .as_bytes() @@ -261,7 +265,7 @@ pub async fn rand_key_package(ciphersuite: MlsCiphersuite) -> (KeyPackage, Clien let key = DatabaseKey::generate(); let key_store = Database::open(ConnectionType::InMemory, &key).await.unwrap(); let backend = MlsCryptoProvider::new(key_store); - let cs: Ciphersuite = ciphersuite.into(); + let cs: MlsCiphersuite = ciphersuite.into(); let signer = create_signature_keypair(&backend, cs); let cred = Credential::new_basic(client_id.clone()); @@ -286,7 +290,7 @@ pub async fn invite( from: &mut Session, other: &mut Session, id: &ConversationId, - ciphersuite: MlsCiphersuite, + ciphersuite: Ciphersuite, delivery_service: Arc, ) { let core_crypto = CoreCrypto::from(from.clone()); @@ -294,7 +298,7 @@ pub async fn invite( let core_crypto = CoreCrypto::from(other.clone()); let other_context = core_crypto.new_transaction().await.unwrap(); let other_kps = other_context - .get_or_create_client_keypackages(ciphersuite, MlsCredentialType::Basic, 1) + .get_or_create_client_keypackages(ciphersuite, CredentialType::Basic, 1) .await .unwrap(); let other_kp = other_kps.first().unwrap().clone(); diff --git a/crypto/src/e2e_identity/crypto.rs b/crypto/src/e2e_identity/crypto.rs index e933eca188..f325858ab3 100644 --- a/crypto/src/e2e_identity/crypto.rs +++ b/crypto/src/e2e_identity/crypto.rs @@ -1,25 +1,25 @@ use mls_crypto_provider::PkiKeypair; use openmls_basic_credential::SignatureKeyPair as OpenMlsSignatureKeyPair; -use openmls_traits::types::{Ciphersuite, SignatureScheme}; +use openmls_traits::types::{Ciphersuite as MlsCiphersuite, SignatureScheme}; use wire_e2e_identity::prelude::JwsAlgorithm; use zeroize::Zeroize; use super::error::*; -use crate::{MlsCiphersuite, MlsError}; +use crate::{Ciphersuite, MlsError}; -impl TryFrom for JwsAlgorithm { +impl TryFrom for JwsAlgorithm { type Error = Error; - fn try_from(cs: MlsCiphersuite) -> Result { - let cs = openmls_traits::types::Ciphersuite::from(cs); + fn try_from(cs: Ciphersuite) -> Result { + let cs = MlsCiphersuite::from(cs); Ok(match cs { - Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 - | Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 => JwsAlgorithm::Ed25519, - Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 => JwsAlgorithm::P256, - Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384 => JwsAlgorithm::P384, - Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521 => JwsAlgorithm::P521, - Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448 - | Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448 => return Err(Error::NotYetSupported), + MlsCiphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + | MlsCiphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 => JwsAlgorithm::Ed25519, + MlsCiphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 => JwsAlgorithm::P256, + MlsCiphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384 => JwsAlgorithm::P384, + MlsCiphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521 => JwsAlgorithm::P521, + MlsCiphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448 + | MlsCiphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448 => return Err(Error::NotYetSupported), }) } } diff --git a/crypto/src/e2e_identity/enrollment/crypto.rs b/crypto/src/e2e_identity/enrollment/crypto.rs index 9782b15dca..44641dfd20 100644 --- a/crypto/src/e2e_identity/enrollment/crypto.rs +++ b/crypto/src/e2e_identity/enrollment/crypto.rs @@ -1,17 +1,13 @@ use mls_crypto_provider::{MlsCryptoProvider, RustCrypto}; use openmls::prelude::SignatureScheme; -use openmls_traits::{OpenMlsCryptoProvider as _, crypto::OpenMlsCrypto as _}; +use openmls_traits::crypto::OpenMlsCrypto as _; use super::{Error, Result}; -use crate::{MlsCiphersuite, MlsError, e2e_identity::crypto::E2eiSignatureKeypair}; +use crate::{Ciphersuite, MlsError, e2e_identity::crypto::E2eiSignatureKeypair}; impl super::E2eiEnrollment { - pub(crate) fn new_sign_key( - ciphersuite: MlsCiphersuite, - backend: &MlsCryptoProvider, - ) -> Result { + pub(crate) fn new_sign_key(ciphersuite: Ciphersuite, backend: &MlsCryptoProvider) -> Result { let (sk, _) = backend - .crypto() .signature_key_gen(ciphersuite.signature_algorithm()) .map_err(MlsError::wrap("performing signature keygen"))?; E2eiSignatureKeypair::try_new(ciphersuite.signature_algorithm(), sk) diff --git a/crypto/src/e2e_identity/enrollment/mod.rs b/crypto/src/e2e_identity/enrollment/mod.rs index dc44400b56..ee5089624e 100644 --- a/crypto/src/e2e_identity/enrollment/mod.rs +++ b/crypto/src/e2e_identity/enrollment/mod.rs @@ -9,7 +9,7 @@ use wire_e2e_identity::{RustyE2eIdentity, prelude::E2eiAcmeAuthorization}; use zeroize::Zeroize as _; use super::{EnrollmentHandle, Error, Json, Result, crypto::E2eiSignatureKeypair, id::QualifiedE2eiClientId, types}; -use crate::{ClientId, KeystoreError, MlsCiphersuite, MlsError}; +use crate::{Ciphersuite, ClientId, KeystoreError, MlsError}; /// Wire end to end identity solution for fetching a x509 certificate which identifies a client. #[derive(Debug, serde::Serialize, serde::Deserialize)] @@ -27,7 +27,7 @@ pub struct E2eiEnrollment { device_authz: Option, valid_order: Option, finalize: Option, - pub(super) ciphersuite: MlsCiphersuite, + pub(super) ciphersuite: Ciphersuite, has_called_new_oidc_challenge_request: bool, } @@ -56,7 +56,7 @@ impl E2eiEnrollment { team: Option, expiry_sec: u32, backend: &MlsCryptoProvider, - ciphersuite: MlsCiphersuite, + ciphersuite: Ciphersuite, sign_keypair: Option, has_called_new_oidc_challenge_request: bool, ) -> Result { @@ -88,7 +88,7 @@ impl E2eiEnrollment { }) } - pub(crate) fn ciphersuite(&self) -> &MlsCiphersuite { + pub(crate) fn ciphersuite(&self) -> &Ciphersuite { &self.ciphersuite } diff --git a/crypto/src/e2e_identity/enrollment/test_utils.rs b/crypto/src/e2e_identity/enrollment/test_utils.rs index 08bc156a68..8d0d63aa10 100644 --- a/crypto/src/e2e_identity/enrollment/test_utils.rs +++ b/crypto/src/e2e_identity/enrollment/test_utils.rs @@ -3,7 +3,7 @@ use mls_crypto_provider::PkiKeypair; use serde_json::json; use crate::{ - CertificateBundle, MlsCredentialType, RecursiveError, + CertificateBundle, CredentialType, RecursiveError, e2e_identity::{E2eiEnrollment, Result, id::QualifiedE2eiClientId}, test_utils::{SessionContext, TestContext, context::TEAM, x509::X509TestChain}, transaction_context::TransactionContext, @@ -58,7 +58,7 @@ pub(crate) fn init_activation_or_rotation(wrapper: E2eiInitWrapper<'_>) -> InitF let E2eiInitWrapper { context: cc, case } = wrapper; let cs = case.ciphersuite(); match case.credential_type { - MlsCredentialType::Basic => { + CredentialType::Basic => { cc.e2ei_new_activation_enrollment( NEW_DISPLAY_NAME.to_string(), NEW_HANDLE.to_string(), @@ -68,7 +68,7 @@ pub(crate) fn init_activation_or_rotation(wrapper: E2eiInitWrapper<'_>) -> InitF ) .await } - MlsCredentialType::X509 => { + CredentialType::X509 => { cc.e2ei_new_rotate_enrollment( Some(NEW_DISPLAY_NAME.to_string()), Some(NEW_HANDLE.to_string()), @@ -78,6 +78,7 @@ pub(crate) fn init_activation_or_rotation(wrapper: E2eiInitWrapper<'_>) -> InitF ) .await } + CredentialType::Unknown(_) => panic!("unknown credential types are unsupported"), } .map_err(RecursiveError::transaction("creating new enrollment")) .map_err(Into::into) diff --git a/crypto/src/e2e_identity/id.rs b/crypto/src/e2e_identity/id.rs index f44eb00937..74ca771762 100644 --- a/crypto/src/e2e_identity/id.rs +++ b/crypto/src/e2e_identity/id.rs @@ -104,7 +104,7 @@ impl QualifiedE2eiClientId { } pub fn from_str_unchecked(s: &str) -> Self { - Self(s.as_bytes().into()) + Self(s.as_bytes().to_owned().into()) } } diff --git a/crypto/src/e2e_identity/identity.rs b/crypto/src/e2e_identity/identity.rs index 8a4c3fe334..001a893d15 100644 --- a/crypto/src/e2e_identity/identity.rs +++ b/crypto/src/e2e_identity/identity.rs @@ -4,7 +4,7 @@ use x509_cert::der::pem::LineEnding; use super::{Error, Result}; use crate::{ - MlsCredentialType, + CredentialType, e2e_identity::{device_status::DeviceStatus, id::WireQualifiedClientId}, }; @@ -19,8 +19,8 @@ pub struct WireIdentity { /// Status of the Credential at the moment T when this object is created pub status: DeviceStatus, /// Indicates whether the credential is Basic or X509 - pub credential_type: MlsCredentialType, - /// In case 'credential_type' is [MlsCredentialType::X509] this is populated + pub credential_type: CredentialType, + /// In case 'credential_type' is [CredentialType::X509] this is populated pub x509_identity: Option, } @@ -60,7 +60,7 @@ impl<'a> TryFrom<(wire_e2e_identity::prelude::WireIdentity, &'a [u8])> for WireI client_id: client_id.try_into()?, status: i.status.into(), thumbprint: i.thumbprint, - credential_type: MlsCredentialType::X509, + credential_type: CredentialType::X509, x509_identity: Some(X509Identity { handle: i.handle.to_string(), display_name: i.display_name, diff --git a/crypto/src/ephemeral.rs b/crypto/src/ephemeral.rs index 7a9c35cef1..6d2dcf999a 100644 --- a/crypto/src/ephemeral.rs +++ b/crypto/src/ephemeral.rs @@ -21,14 +21,16 @@ //! Any attempt to encrypt a message will fail because the client cannot retrieve the signature key from //! its keystore. +use std::borrow::Borrow; + use core_crypto_keystore::{ConnectionType, Database}; use mls_crypto_provider::DatabaseKey; use obfuscate::{Obfuscate, Obfuscated}; use openmls::prelude::KeyPackageSecretEncapsulation; use crate::{ - ClientId, ClientIdentifier, CoreCrypto, Error, MlsCiphersuite, MlsCredentialType, MlsError, RecursiveError, Result, - Session, SessionConfig, + Ciphersuite, ClientId, ClientIdRef, ClientIdentifier, CoreCrypto, CredentialType, Error, MlsError, RecursiveError, + Result, Session, SessionConfig, }; /// We always instantiate history clients with this prefix in their client id, so @@ -56,7 +58,7 @@ impl Obfuscate for HistorySecret { /// Create a new [`CoreCrypto`] with an **uninitialized** mls session. /// /// You must initialize the session yourself before using this! -async fn in_memory_cc_with_ciphersuite(ciphersuite: impl Into) -> Result { +async fn in_memory_cc_with_ciphersuite(ciphersuite: impl Into) -> Result { let db = Database::open(ConnectionType::InMemory, &DatabaseKey::generate()) .await .unwrap(); @@ -85,7 +87,7 @@ async fn in_memory_cc_with_ciphersuite(ciphersuite: impl Into) - /// Note that this is a crate-private function; the public interface for this feature is /// [`Conversation::generate_history_secret`][core_crypto::mls::conversation::Conversation::generate_history_secret]. /// This implementation lives here instead of there for organizational reasons. -pub(crate) async fn generate_history_secret(ciphersuite: MlsCiphersuite) -> Result { +pub(crate) async fn generate_history_secret(ciphersuite: Ciphersuite) -> Result { // generate a new completely arbitrary client id let client_id = uuid::Uuid::new_v4(); let client_id = format!("{HISTORY_CLIENT_ID_PREFIX}-{client_id}"); @@ -97,13 +99,13 @@ pub(crate) async fn generate_history_secret(ciphersuite: MlsCiphersuite) -> Resu .new_transaction() .await .map_err(RecursiveError::transaction("creating new transaction"))?; - cc.init(identifier, &[ciphersuite], &cc.crypto_provider) + cc.init(identifier, &[ciphersuite.signature_algorithm()]) .await .map_err(RecursiveError::mls_client("initializing ephemeral cc"))?; // we can generate a key package from the ephemeral cc and ciphersutite let [key_package] = tx - .get_or_create_client_keypackages(ciphersuite, MlsCredentialType::Basic, 1) + .get_or_create_client_keypackages(ciphersuite, CredentialType::Basic, 1) .await .map_err(RecursiveError::transaction("generating keypackages"))? .try_into() @@ -117,8 +119,8 @@ pub(crate) async fn generate_history_secret(ciphersuite: MlsCiphersuite) -> Resu Ok(HistorySecret { client_id, key_package }) } -pub(crate) fn is_history_client(client_id: &ClientId) -> bool { - client_id.starts_with(HISTORY_CLIENT_ID_PREFIX.as_bytes()) +pub(crate) fn is_history_client(client_id: impl Borrow) -> bool { + client_id.borrow().starts_with(HISTORY_CLIENT_ID_PREFIX.as_bytes()) } impl CoreCrypto { diff --git a/crypto/src/error/recursive.rs b/crypto/src/error/recursive.rs index 414f5f9a91..0d62a68131 100644 --- a/crypto/src/error/recursive.rs +++ b/crypto/src/error/recursive.rs @@ -34,6 +34,10 @@ pub enum RecursiveError { context: &'static str, source: Box, }, + MlsCredentialRef { + context: &'static str, + source: Box, + }, #[cfg(test)] Test(Box), } @@ -87,6 +91,15 @@ impl RecursiveError { source: Box::new(into_source.into()), } } + + pub fn mls_credential_ref>( + context: &'static str, + ) -> impl FnOnce(E) -> Self { + move |into_source| Self::MlsCredentialRef { + context, + source: Box::new(into_source.into()), + } + } } impl std::fmt::Display for RecursiveError { @@ -101,6 +114,7 @@ impl std::fmt::Display for RecursiveError { RecursiveError::MlsClient { context, .. } => context, RecursiveError::MlsConversation { context, .. } => context, RecursiveError::MlsCredential { context, .. } => context, + RecursiveError::MlsCredentialRef { context, .. } => context, RecursiveError::TransactionContext { context, .. } => context, #[cfg(test)] RecursiveError::Test(e) => return e.deref().fmt(f), @@ -118,6 +132,7 @@ impl std::error::Error for RecursiveError { RecursiveError::MlsClient { source, .. } => Some(source.as_ref()), RecursiveError::MlsConversation { source, .. } => Some(source.as_ref()), RecursiveError::MlsCredential { source, .. } => Some(source.as_ref()), + RecursiveError::MlsCredentialRef { source, .. } => Some(source.as_ref()), RecursiveError::TransactionContext { source, .. } => Some(source.as_ref()), #[cfg(test)] RecursiveError::Test(source) => Some(source.as_ref()), @@ -127,7 +142,9 @@ impl std::error::Error for RecursiveError { /// Like [`Into`], but different, because we don't actually want to implement `Into` for our subordinate error types. /// -/// By forcing ourselves to map errors everywhere in order for question mark operators to work, we ensure that +/// By forcing ourselves to map errors everywhere in order for question mark operators to work, we ensure that we can +/// take the opportunity to include a little bit of manual context. Pervasively done, this means that our errors have +/// quite a lot of contextual information about the call stack and what precisely has gone wrong. pub trait ToRecursiveError { /// Construct a recursive error given the current context fn construct_recursive(self, context: &'static str) -> RecursiveError; @@ -155,5 +172,6 @@ impl_to_recursive_error_for!( crate::mls::session::Error => MlsClient, crate::mls::conversation::Error => MlsConversation, crate::mls::credential::Error => MlsCredential, + crate::mls::credential::credential_ref::Error => MlsCredentialRef, crate::transaction_context::Error => TransactionContext, ); diff --git a/crypto/src/lib.rs b/crypto/src/lib.rs index ad4b765ac9..c80db277be 100644 --- a/crypto/src/lib.rs +++ b/crypto/src/lib.rs @@ -33,8 +33,8 @@ pub use mls_crypto_provider::{EntropySeed, MlsCryptoProvider, RawEntropySeed}; pub use openmls::{ group::{MlsGroup, MlsGroupConfig}, prelude::{ - Ciphersuite as CiphersuiteName, Credential, GroupEpoch, KeyPackage, KeyPackageIn, KeyPackageRef, MlsMessageIn, - Node, group_info::VerifiableGroupInfo, + Ciphersuite as MlsCiphersuite, CredentialType, GroupEpoch, KeyPackage, KeyPackageIn, KeyPackageRef, + MlsMessageIn, Node, group_info::VerifiableGroupInfo, }, }; #[cfg(feature = "proteus")] @@ -54,7 +54,7 @@ pub use crate::{ RecursiveError, Result, ToRecursiveError, }, mls::{ - ciphersuite::MlsCiphersuite, + ciphersuite::Ciphersuite, conversation::{ ConversationId, MlsConversation, commit::MlsCommitBundle, @@ -64,12 +64,12 @@ pub use crate::{ proposal::MlsProposalBundle, welcome::WelcomeBundle, }, - credential::{typ::MlsCredentialType, x509::CertificateBundle}, + credential::{Credential, CredentialRef, FindFilters as CredentialFindFilters, x509::CertificateBundle}, proposal::{MlsProposal, MlsProposalRef}, session::{ EpochObserver, HistoryObserver, Session, config::{SessionConfig, ValidatedSessionConfig}, - id::ClientId, + id::{ClientId, ClientIdRef}, identifier::ClientIdentifier, key_package::INITIAL_KEYING_MATERIAL_COUNT, user_id::UserId, diff --git a/crypto/src/mls/ciphersuite.rs b/crypto/src/mls/ciphersuite.rs index 3814b92b9d..686cb588d6 100644 --- a/crypto/src/mls/ciphersuite.rs +++ b/crypto/src/mls/ciphersuite.rs @@ -1,16 +1,18 @@ -use openmls_traits::types::{Ciphersuite, HashType}; +use openmls_traits::types::HashType; use wire_e2e_identity::prelude::HashAlgorithm; use super::{Error, Result}; -use crate::CiphersuiteName; +use crate::MlsCiphersuite; -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, derive_more::Deref, serde::Serialize, serde::Deserialize)] +#[derive( + Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash, derive_more::Deref, serde::Serialize, serde::Deserialize, +)] #[serde(transparent)] #[repr(transparent)] /// A wrapper for the OpenMLS Ciphersuite, so that we are able to provide a default value. -pub struct MlsCiphersuite(pub(crate) Ciphersuite); +pub struct Ciphersuite(pub(crate) MlsCiphersuite); -impl MlsCiphersuite { +impl Ciphersuite { pub(crate) fn e2ei_hash_alg(&self) -> HashAlgorithm { match self.0.hash_algorithm() { HashType::Sha2_256 => HashAlgorithm::SHA256, @@ -20,35 +22,35 @@ impl MlsCiphersuite { } } -impl Default for MlsCiphersuite { +impl Default for Ciphersuite { fn default() -> Self { - Self(Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519) + Self(MlsCiphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519) } } -impl From for MlsCiphersuite { - fn from(value: Ciphersuite) -> Self { +impl From for Ciphersuite { + fn from(value: MlsCiphersuite) -> Self { Self(value) } } -impl From for Ciphersuite { - fn from(ciphersuite: MlsCiphersuite) -> Self { +impl From for MlsCiphersuite { + fn from(ciphersuite: Ciphersuite) -> Self { ciphersuite.0 } } -impl From for u16 { - fn from(cs: MlsCiphersuite) -> Self { +impl From for u16 { + fn from(cs: Ciphersuite) -> Self { (&cs.0).into() } } -impl TryFrom for MlsCiphersuite { +impl TryFrom for Ciphersuite { type Error = Error; fn try_from(c: u16) -> Result { - Ok(CiphersuiteName::try_from(c) + Ok(MlsCiphersuite::try_from(c) .map_err(|_| Error::UnknownCiphersuite)? .into()) } diff --git a/crypto/src/mls/conversation/config.rs b/crypto/src/mls/conversation/config.rs index c751717613..27f785b7a2 100644 --- a/crypto/src/mls/conversation/config.rs +++ b/crypto/src/mls/conversation/config.rs @@ -10,15 +10,14 @@ use openmls::prelude::{ RequiredCapabilitiesExtension, SenderRatchetConfiguration, WireFormatPolicy, }; use openmls_traits::{ - OpenMlsCryptoProvider, crypto::OpenMlsCrypto, - types::{Ciphersuite, SignatureScheme}, + types::{Ciphersuite as MlsCiphersuite, SignatureScheme}, }; use serde::{Deserialize, Serialize}; use wire_e2e_identity::prelude::parse_json_jwk; use super::Result; -use crate::{MlsCiphersuite, MlsError, RecursiveError}; +use crate::{Ciphersuite, MlsError, RecursiveError}; /// Sets the config in OpenMls for the oldest possible epoch(past current) that a message can be decrypted pub(crate) const MAX_PAST_EPOCHS: usize = 3; @@ -34,7 +33,7 @@ pub(crate) const MAXIMUM_FORWARD_DISTANCE: u32 = 1000; #[derive(Debug, Clone, Default)] pub struct MlsConversationConfiguration { /// The `OpenMls` Ciphersuite used in the group - pub ciphersuite: MlsCiphersuite, + pub ciphersuite: Ciphersuite, /// Delivery service public signature key and credential pub external_senders: Vec, /// Implementation specific configuration @@ -54,12 +53,12 @@ impl MlsConversationConfiguration { &[CredentialType::Basic, CredentialType::X509]; /// Conservative sensible defaults - pub(crate) const DEFAULT_SUPPORTED_CIPHERSUITES: &'static [Ciphersuite] = &[ - Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, - Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256, - Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, - Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384, - Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521, + pub(crate) const DEFAULT_SUPPORTED_CIPHERSUITES: &'static [MlsCiphersuite] = &[ + MlsCiphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, + MlsCiphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256, + MlsCiphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, + MlsCiphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384, + MlsCiphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521, ]; /// Not used at the moment @@ -125,7 +124,6 @@ impl MlsConversationConfiguration { backend: &MlsCryptoProvider, ) -> Result { backend - .crypto() .validate_signature_key(signature_scheme, &key[..]) .map_err(MlsError::wrap("validating signature key"))?; let key = OpenMlsSignaturePublicKey::new(key.into(), signature_scheme) diff --git a/crypto/src/mls/conversation/conversation_guard/commit.rs b/crypto/src/mls/conversation/conversation_guard/commit.rs index 601c335d0d..55265b09ef 100644 --- a/crypto/src/mls/conversation/conversation_guard/commit.rs +++ b/crypto/src/mls/conversation/conversation_guard/commit.rs @@ -1,17 +1,19 @@ //! The methods in this module all produce or handle commits. +use std::borrow::Borrow; + use openmls::prelude::{KeyPackageIn, LeafNode}; use super::history_sharing::HistoryClientUpdateOutcome; use crate::{ - ClientId, LeafError, MlsCredentialType, MlsError, MlsGroupInfoBundle, MlsTransportResponse, RecursiveError, + ClientIdRef, CredentialType, LeafError, MlsError, MlsGroupInfoBundle, MlsTransportResponse, RecursiveError, e2e_identity::NewCrlDistributionPoints, mls::{ conversation::{ Conversation as _, ConversationGuard, ConversationWithMls as _, Error, Result, commit::MlsCommitBundle, }, credential::{ - CredentialBundle, + Credential, crl::{extract_crl_uris_from_credentials, get_new_crl_distribution_points}, }, }, @@ -109,7 +111,7 @@ impl ConversationGuard { ) -> Result<(NewCrlDistributionPoints, MlsCommitBundle)> { self.ensure_no_pending_commit().await?; let backend = self.crypto_provider().await?; - let credential = self.credential_bundle().await?; + let credential = self.credential().await?; let signer = credential.signature_key(); let mut conversation = self.conversation_mut().await; @@ -152,10 +154,10 @@ impl ConversationGuard { /// # Arguments /// * `id` - group/conversation id /// * `clients` - list of client ids to be removed from the group - pub async fn remove_members(&mut self, clients: &[ClientId]) -> Result<()> { + pub async fn remove_members(&mut self, clients: &[impl Borrow]) -> Result<()> { self.ensure_no_pending_commit().await?; let backend = self.crypto_provider().await?; - let credential = self.credential_bundle().await?; + let credential = self.credential().await?; let signer = credential.signature_key(); let mut conversation = self.inner.write().await; @@ -165,7 +167,7 @@ impl ConversationGuard { .filter_map(|kp| { clients .iter() - .any(move |client_id| client_id.as_slice() == kp.credential.identity()) + .any(move |client_id| client_id.borrow() == kp.credential.identity()) .then_some(kp.index) }) .collect::>(); @@ -205,19 +207,16 @@ impl ConversationGuard { /// [crate::transaction_context::TransactionContext::e2ei_new_activation_enrollment] or /// [crate::transaction_context::TransactionContext::e2ei_new_rotate_enrollment] and having saved it with /// [crate::transaction_context::TransactionContext::save_x509_credential]. - pub async fn e2ei_rotate(&mut self, cb: Option<&CredentialBundle>) -> Result<()> { + pub async fn e2ei_rotate(&mut self, cb: Option<&Credential>) -> Result<()> { let client = &self.session().await?; let conversation = self.conversation().await; let cb = match cb { Some(cb) => cb, - None => &client - .find_most_recent_credential_bundle( - conversation.ciphersuite().signature_algorithm(), - MlsCredentialType::X509, - ) + None => &*client + .find_most_recent_credential(conversation.ciphersuite().signature_algorithm(), CredentialType::X509) .await - .map_err(RecursiveError::mls_client("finding most recent x509 credential bundle"))?, + .map_err(RecursiveError::mls_client("finding most recent x509 credential"))?, }; let mut leaf_node = conversation @@ -237,7 +236,7 @@ impl ConversationGuard { pub(crate) async fn update_key_material_inner( &mut self, - cb: Option<&CredentialBundle>, + cb: Option<&Credential>, leaf_node: Option, ) -> Result { self.ensure_no_pending_commit().await?; @@ -245,12 +244,12 @@ impl ConversationGuard { let backend = &self.crypto_provider().await?; let mut conversation = self.conversation_mut().await; let cb = match cb { - None => &conversation.find_most_recent_credential_bundle(session).await?, + None => &conversation.find_most_recent_credential(session).await?, Some(cb) => cb, }; let (commit, welcome, group_info) = conversation .group - .explicit_self_update(backend, &cb.signature_key, leaf_node) + .explicit_self_update(backend, &cb.signature_key_pair, leaf_node) .await .map_err(MlsError::wrap("group self update"))?; @@ -288,7 +287,7 @@ impl ConversationGuard { return Ok(None); } - let signer = &inner.find_most_recent_credential_bundle(session).await?.signature_key; + let signer = &inner.find_most_recent_credential(session).await?.signature_key_pair; let (commit, welcome, gi) = inner .group @@ -317,7 +316,7 @@ impl ConversationGuard { if proposals.is_empty() { return Ok(None); } - let signer = &inner.find_most_recent_credential_bundle(session).await?.signature_key; + let signer = &inner.find_most_recent_credential(session).await?.signature_key_pair; let (commit, welcome, gi) = inner .group diff --git a/crypto/src/mls/conversation/conversation_guard/decrypt/buffer_commit.rs b/crypto/src/mls/conversation/conversation_guard/decrypt/buffer_commit.rs index 2117c7c93e..c0b0e360b1 100644 --- a/crypto/src/mls/conversation/conversation_guard/decrypt/buffer_commit.rs +++ b/crypto/src/mls/conversation/conversation_guard/decrypt/buffer_commit.rs @@ -1,4 +1,4 @@ -use core_crypto_keystore::{connection::FetchFromDatabase as _, entities::MlsBufferedCommit}; +use core_crypto_keystore::{connection::FetchFromDatabase as _, entities::StoredBufferedCommit}; use log::info; use openmls::framing::MlsMessageIn; use openmls_traits::OpenMlsCryptoProvider as _; @@ -16,7 +16,7 @@ impl ConversationGuard { let conversation = self.conversation().await; info!(group_id = conversation.id(); "buffering commit"); - let buffered_commit = MlsBufferedCommit::new(conversation.id().to_bytes(), commit.as_ref().to_owned()); + let buffered_commit = StoredBufferedCommit::new(conversation.id().to_bytes(), commit.as_ref().to_owned()); self.crypto_provider() .await? @@ -34,9 +34,9 @@ impl ConversationGuard { self.crypto_provider() .await? .keystore() - .find::(conversation.id()) + .find::(conversation.id()) .await - .map(|option| option.map(MlsBufferedCommit::into_commit_data)) + .map(|option| option.map(StoredBufferedCommit::into_commit_data)) .map_err(KeystoreError::wrap("attempting to retrieve buffered commit")) .map_err(Into::into) } @@ -69,7 +69,7 @@ impl ConversationGuard { self.crypto_provider() .await? .keystore() - .remove::(conversation.id()) + .remove::(conversation.id()) .await .map_err(KeystoreError::wrap("attempting to clear buffered commit")) .map_err(Into::into) diff --git a/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs b/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs index 01b596ea4d..9f8e254ccb 100644 --- a/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs +++ b/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs @@ -223,7 +223,7 @@ impl ConversationGuard { ) .map_err(RecursiveError::mls_credential("extracting identity"))?; - let sender_client_id: ClientId = credential.credential.identity().into(); + let sender_client_id: ClientId = credential.credential.identity().to_owned().into(); let decrypted = match message.into_content() { ProcessedMessageContent::ApplicationMessage(app_msg) => { @@ -564,7 +564,7 @@ impl ConversationGuard { let state = Session::compute_conversation_state( self.ciphersuite().await, credentials.iter(), - crate::MlsCredentialType::X509, + crate::CredentialType::X509, backend.authentication_service().borrow().await.as_ref(), ) .await; diff --git a/crypto/src/mls/conversation/conversation_guard/encrypt.rs b/crypto/src/mls/conversation/conversation_guard/encrypt.rs index eddc77816f..32961bee8c 100644 --- a/crypto/src/mls/conversation/conversation_guard/encrypt.rs +++ b/crypto/src/mls/conversation/conversation_guard/encrypt.rs @@ -20,7 +20,7 @@ impl ConversationGuard { /// from OpenMls and the KeyStore pub async fn encrypt_message(&mut self, message: impl AsRef<[u8]>) -> Result> { let backend = self.crypto_provider().await?; - let credential = self.credential_bundle().await?; + let credential = self.credential().await?; let signer = credential.signature_key(); let mut inner = self.conversation_mut().await; let encrypted = inner diff --git a/crypto/src/mls/conversation/conversation_guard/history_sharing.rs b/crypto/src/mls/conversation/conversation_guard/history_sharing.rs index c76805fcfd..84a2931715 100644 --- a/crypto/src/mls/conversation/conversation_guard/history_sharing.rs +++ b/crypto/src/mls/conversation/conversation_guard/history_sharing.rs @@ -4,7 +4,7 @@ use itertools::{Either, Itertools as _}; use super::{ConversationGuard, Error, Result}; use crate::{ - HistorySecret, MlsCommitBundle, RecursiveError, + ClientIdRef, HistorySecret, MlsCommitBundle, RecursiveError, mls::conversation::{Conversation as _, ConversationWithMls, conversation_guard::commit::TransportedCommitPolicy}, }; @@ -84,7 +84,7 @@ impl ConversationGuard { // at most one history client in a conversation? // Then we could use something like `into_iter().find_map()` to lazily evaluate client ids, but this way we're making sure to // remove any history client, and not just the first one we find. - history_client_ids.retain(crate::ephemeral::is_history_client); + history_client_ids.retain(|client_id| crate::ephemeral::is_history_client(client_id)); if history_client_ids.is_empty() { log::warn!("History sharing is already disabled."); @@ -122,7 +122,8 @@ impl ConversationGuard { // Distinguish between history clients and other clients for the following operations let (existing_history_clients, other_clients): (HashSet<_>, HashSet<_>) = conversation.group().members().partition_map(|member| { - let is_history_client = crate::ephemeral::is_history_client(&member.credential.identity().into()); + let is_history_client = + crate::ephemeral::is_history_client(ClientIdRef::new(member.credential.identity())); let member_index = member.index; if is_history_client { Either::Left(member_index) diff --git a/crypto/src/mls/conversation/conversation_guard/mod.rs b/crypto/src/mls/conversation/conversation_guard/mod.rs index a9ab8a266f..07a93a3f72 100644 --- a/crypto/src/mls/conversation/conversation_guard/mod.rs +++ b/crypto/src/mls/conversation/conversation_guard/mod.rs @@ -9,7 +9,7 @@ use super::{ConversationWithMls, Error, MlsConversation, Result}; use crate::{ KeystoreError, LeafError, MlsGroupInfoBundle, MlsTransport, RecursiveError, group_store::GroupStoreValue, - mls::{conversation::ConversationIdRef, credential::CredentialBundle}, + mls::{conversation::ConversationIdRef, credential::Credential}, transaction_context::TransactionContext, }; mod commit; @@ -121,11 +121,11 @@ impl ConversationGuard { } } - async fn credential_bundle(&self) -> Result> { + async fn credential(&self) -> Result> { let client = self.session().await?; let inner = self.conversation().await; inner - .find_current_credential_bundle(&client) + .find_current_credential(&client) .await .map_err(|_| Error::IdentityInitializationError) } diff --git a/crypto/src/mls/conversation/credential.rs b/crypto/src/mls/conversation/credential.rs new file mode 100644 index 0000000000..07e4f6b9f7 --- /dev/null +++ b/crypto/src/mls/conversation/credential.rs @@ -0,0 +1,46 @@ +use std::collections::HashMap; + +use openmls::prelude::{Credential as MlsCredential, CredentialType, CredentialWithKey, SignaturePublicKey}; + +use super::{Error, Result}; +use crate::MlsConversation; + +impl MlsConversation { + /// Returns all members credentials from the group/conversation + pub fn members(&self) -> HashMap, MlsCredential> { + self.group.members().fold(HashMap::new(), |mut acc, kp| { + let credential = kp.credential; + let id = credential.identity().to_vec(); + acc.entry(id).or_insert(credential); + acc + }) + } + + /// Returns all members credentials with their signature public key from the group/conversation + pub fn members_with_key(&self) -> HashMap, CredentialWithKey> { + self.group.members().fold(HashMap::new(), |mut acc, kp| { + let credential = kp.credential; + let id = credential.identity().to_vec(); + let signature_key = SignaturePublicKey::from(kp.signature_key); + let credential = CredentialWithKey { + credential, + signature_key, + }; + acc.entry(id).or_insert(credential); + acc + }) + } + + pub(crate) fn own_mls_credential(&self) -> Result<&MlsCredential> { + let credential = self + .group + .own_leaf_node() + .ok_or(Error::MlsGroupInvalidState("own_leaf_node not present in group"))? + .credential(); + Ok(credential) + } + + pub(crate) fn own_credential_type(&self) -> Result { + self.own_mls_credential().map(|credential| credential.credential_type()) + } +} diff --git a/crypto/src/mls/conversation/id.rs b/crypto/src/mls/conversation/id.rs new file mode 100644 index 0000000000..33d080c457 --- /dev/null +++ b/crypto/src/mls/conversation/id.rs @@ -0,0 +1,95 @@ +use std::{ + borrow::{Borrow, Cow}, + ops::Deref, +}; + +/// A unique identifier for a group/conversation. The identifier must be unique within a client. +#[derive( + core_crypto_macros::Debug, + derive_more::AsRef, + derive_more::From, + derive_more::Into, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Clone, +)] +#[sensitive] +#[as_ref([u8])] +#[from(&[u8], Vec)] +pub struct ConversationId(Vec); + +impl Borrow for ConversationId { + fn borrow(&self) -> &ConversationIdRef { + ConversationIdRef::new(&self.0) + } +} + +impl Deref for ConversationId { + type Target = ConversationIdRef; + + fn deref(&self) -> &Self::Target { + ConversationIdRef::new(&self.0) + } +} + +impl From for Cow<'_, [u8]> { + fn from(value: ConversationId) -> Self { + Cow::Owned(value.0) + } +} + +impl<'a> From<&'a ConversationId> for Cow<'a, [u8]> { + fn from(value: &'a ConversationId) -> Self { + Cow::Borrowed(value.as_ref()) + } +} + +/// Reference to a ConversationId. +/// +/// This type is `!Sized` and is only ever seen as a reference, like `str` or `[u8]`. +// +// pattern from https://stackoverflow.com/a/64990850 +#[repr(transparent)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ConversationIdRef([u8]); + +impl ConversationIdRef { + /// Creates a `ConversationId` Ref, needed to implement `Borrow` for `T` + pub fn new(bytes: &Bytes) -> &ConversationIdRef + where + Bytes: AsRef<[u8]> + ?Sized, + { + // safety: because of `repr(transparent)` we know that `ConversationIdRef` has a memory layout + // identical to `[u8]`, so we can perform this cast + unsafe { &*(bytes.as_ref() as *const [u8] as *const ConversationIdRef) } + } +} + +impl ConversationIdRef { + pub(crate) fn to_bytes(&self) -> Vec { + self.0.to_owned() + } +} + +impl ToOwned for ConversationIdRef { + type Owned = ConversationId; + + fn to_owned(&self) -> Self::Owned { + ConversationId(self.0.to_owned()) + } +} + +impl AsRef<[u8]> for ConversationIdRef { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl<'a> From<&'a ConversationIdRef> for Cow<'a, [u8]> { + fn from(value: &'a ConversationIdRef) -> Self { + Cow::Borrowed(value.as_ref()) + } +} diff --git a/crypto/src/mls/conversation/merge.rs b/crypto/src/mls/conversation/merge.rs index eb75b475b6..36448ffd03 100644 --- a/crypto/src/mls/conversation/merge.rs +++ b/crypto/src/mls/conversation/merge.rs @@ -11,7 +11,7 @@ //! | 1+ pend. Proposal | ❌ | ✅ | //! -use core_crypto_keystore::entities::MlsEncryptionKeyPair; +use core_crypto_keystore::entities::StoredEncryptionKeyPair; use mls_crypto_provider::MlsCryptoProvider; use openmls_traits::OpenMlsCryptoProvider; @@ -35,7 +35,7 @@ impl MlsConversation { // ..so if there's any, we clear them after the commit is merged for oln in &previous_own_leaf_nodes { let ek = oln.encryption_key().as_slice(); - let _ = backend.key_store().remove::(ek).await; + let _ = backend.key_store().remove::(ek).await; } client diff --git a/crypto/src/mls/conversation/mod.rs b/crypto/src/mls/conversation/mod.rs index c3830e877a..88356fb851 100644 --- a/crypto/src/mls/conversation/mod.rs +++ b/crypto/src/mls/conversation/mod.rs @@ -11,56 +11,55 @@ //! | merge | ❌ | ❌ | ✅ | ✅ | //! | decrypt | ✅ | ✅ | ✅ | ✅ | -use std::{ - borrow::{Borrow, Cow}, - collections::{HashMap, HashSet}, - ops::Deref, - sync::Arc, -}; - -use config::MlsConversationConfiguration; -use core_crypto_keystore::CryptoKeystoreMls; -use itertools::Itertools as _; -use log::trace; -use mls_crypto_provider::{Database, MlsCryptoProvider}; -use openmls::{ - group::MlsGroup, - prelude::{Credential, CredentialWithKey, LeafNodeIndex, Proposal, SignaturePublicKey}, -}; -use openmls_traits::{OpenMlsCryptoProvider, types::SignatureScheme}; - -use crate::{ - ClientId, E2eiConversationState, KeystoreError, LeafError, MlsCiphersuite, MlsCredentialType, MlsError, - RecursiveError, WireIdentity, mls::Session, -}; - pub(crate) mod commit; mod commit_delay; pub(crate) mod config; pub(crate) mod conversation_guard; +mod credential; mod duplicate; #[cfg(test)] mod durability; mod error; pub(crate) mod group_info; +mod id; mod immutable_conversation; pub(crate) mod merge; mod orphan_welcome; mod own_commit; pub(crate) mod pending_conversation; +mod persistence; pub(crate) mod proposal; mod renew; pub(crate) mod welcome; mod wipe; -pub use conversation_guard::ConversationGuard; -pub use error::{Error, Result}; -pub use immutable_conversation::ImmutableConversation; +use std::{ + collections::{HashMap, HashSet}, + ops::Deref, + sync::Arc, +}; -use super::credential::CredentialBundle; +use itertools::Itertools as _; +use log::trace; +use mls_crypto_provider::MlsCryptoProvider; +use openmls::{ + group::MlsGroup, + prelude::{LeafNodeIndex, Proposal}, +}; +use openmls_traits::{OpenMlsCryptoProvider, types::SignatureScheme}; + +use self::config::MlsConversationConfiguration; +pub use self::{ + conversation_guard::ConversationGuard, + error::{Error, Result}, + id::{ConversationId, ConversationIdRef}, + immutable_conversation::ImmutableConversation, +}; +use super::credential::Credential; use crate::{ - UserId, - mls::{HasSessionAndCrypto, credential::ext::CredentialExt as _}, + Ciphersuite, ClientId, CredentialType, E2eiConversationState, LeafError, MlsError, RecursiveError, UserId, + WireIdentity, + mls::{HasSessionAndCrypto, Session, credential::ext::CredentialExt as _}, }; /// The base layer for [Conversation]. @@ -110,7 +109,7 @@ pub trait Conversation<'a>: ConversationWithMls<'a> { } /// Returns the ciphersuite of a given conversation - async fn ciphersuite(&'a self) -> MlsCiphersuite { + async fn ciphersuite(&'a self) -> Ciphersuite { self.conversation().await.ciphersuite() } @@ -143,7 +142,7 @@ pub trait Conversation<'a>: ConversationWithMls<'a> { inner .group() .members() - .map(|kp| ClientId::from(kp.credential.identity())) + .map(|kp| ClientId::from(kp.credential.identity().to_owned())) .collect() } @@ -171,7 +170,7 @@ pub trait Conversation<'a>: ConversationWithMls<'a> { let state = Session::compute_conversation_state( inner.ciphersuite(), inner.group.members_credentials(), - MlsCredentialType::X509, + CredentialType::X509, authentication_service.borrow().await.as_ref(), ) .await; @@ -196,7 +195,7 @@ pub trait Conversation<'a>: ConversationWithMls<'a> { conversation .members_with_key() .into_iter() - .filter(|(id, _)| device_ids.contains(&ClientId::from(id.as_slice()))) + .filter(|(id, _)| device_ids.iter().any(|client_id| *client_id == id.as_slice())) .map(|(_, c)| { c.extract_identity(conversation.ciphersuite(), env) .map_err(RecursiveError::mls_credential("extracting identity")) @@ -265,97 +264,6 @@ pub trait Conversation<'a>: ConversationWithMls<'a> { impl<'a, T: ConversationWithMls<'a>> Conversation<'a> for T {} -/// A unique identifier for a group/conversation. The identifier must be unique within a client. -#[derive( - core_crypto_macros::Debug, - derive_more::AsRef, - derive_more::From, - derive_more::Into, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - Clone, -)] -#[sensitive] -#[as_ref([u8])] -#[from(&[u8], Vec)] -pub struct ConversationId(Vec); - -impl From for Cow<'_, [u8]> { - fn from(value: ConversationId) -> Self { - Cow::Owned(value.0) - } -} - -impl<'a> From<&'a ConversationId> for Cow<'a, [u8]> { - fn from(value: &'a ConversationId) -> Self { - Cow::Borrowed(value.as_ref()) - } -} - -/// Reference to a ConversationId. -/// -/// This type is `!Sized` and is only ever seen as a reference, like `str` or `[u8]`. -// -// pattern from https://stackoverflow.com/a/64990850 -#[repr(transparent)] -#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ConversationIdRef([u8]); - -impl ConversationIdRef { - /// Creates a `ConversationId` Ref, needed to implement `Borrow` for `T` - pub fn new(bytes: &Bytes) -> &ConversationIdRef - where - Bytes: AsRef<[u8]> + ?Sized, - { - // safety: because of `repr(transparent)` we know that `ConversationIdRef` has a memory layout - // identical to `[u8]`, so we can perform this cast - unsafe { &*(bytes.as_ref() as *const [u8] as *const ConversationIdRef) } - } -} - -impl ConversationIdRef { - fn to_bytes(&self) -> Vec { - self.as_ref().to_owned() - } -} - -impl Borrow for ConversationId { - fn borrow(&self) -> &ConversationIdRef { - ConversationIdRef::new(&self.0) - } -} - -impl Deref for ConversationId { - type Target = ConversationIdRef; - - fn deref(&self) -> &Self::Target { - ConversationIdRef::new(&self.0) - } -} - -impl ToOwned for ConversationIdRef { - type Owned = ConversationId; - - fn to_owned(&self) -> Self::Owned { - ConversationId(self.0.to_owned()) - } -} - -impl AsRef<[u8]> for ConversationIdRef { - fn as_ref(&self) -> &[u8] { - &self.0 - } -} - -impl<'a> From<&'a ConversationIdRef> for Cow<'a, [u8]> { - fn from(value: &'a ConversationIdRef) -> Self { - Cow::Borrowed(value.as_ref()) - } -} - /// This is a wrapper on top of the OpenMls's [MlsGroup], that provides Core Crypto specific functionality /// /// This type will store the state of a group. With the [MlsGroup] it holds, it provides all @@ -385,19 +293,19 @@ impl MlsConversation { pub async fn create( id: ConversationId, author_client: &Session, - creator_credential_type: MlsCredentialType, + creator_credential_type: CredentialType, configuration: MlsConversationConfiguration, backend: &MlsCryptoProvider, ) -> Result { let (cs, ct) = (configuration.ciphersuite, creator_credential_type); let cb = author_client - .get_most_recent_or_create_credential_bundle(backend, cs.signature_algorithm(), ct) + .get_most_recent_or_create_credential(backend, cs.signature_algorithm(), ct) .await - .map_err(RecursiveError::mls_client("getting or creating credential bundle"))?; + .map_err(RecursiveError::mls_client("getting or creating credential"))?; let group = MlsGroup::new_with_group_id( backend, - &cb.signature_key, + &cb.signature_key_pair, &configuration.as_openmls_default_configuration()?, openmls::prelude::GroupId::from_slice(id.as_ref()), cb.to_mls_credential_with_key(), @@ -441,24 +349,6 @@ impl MlsConversation { Ok(conversation) } - /// Internal API: restore the conversation from a persistence-saved serialized Group State. - pub(crate) fn from_serialized_state(buf: Vec, parent_id: Option) -> Result { - let group: MlsGroup = - core_crypto_keystore::deser(&buf).map_err(KeystoreError::wrap("deserializing group state"))?; - let id = ConversationId::from(group.group_id().as_slice()); - let configuration = MlsConversationConfiguration { - ciphersuite: group.ciphersuite().into(), - ..Default::default() - }; - - Ok(Self { - id, - group, - parent_id, - configuration, - }) - } - /// Group/conversation id pub fn id(&self) -> &ConversationId { &self.id @@ -468,16 +358,6 @@ impl MlsConversation { &self.group } - /// Returns all members credentials from the group/conversation - pub fn members(&self) -> HashMap, Credential> { - self.group.members().fold(HashMap::new(), |mut acc, kp| { - let credential = kp.credential; - let id = credential.identity().to_vec(); - acc.entry(id).or_insert(credential); - acc - }) - } - /// Get actual group members and subtract pending remove proposals pub fn members_in_next_epoch(&self) -> Vec { let pending_removals = self.pending_removals(); @@ -486,7 +366,7 @@ impl MlsConversation { .members() .filter_map(|kp| { if !pending_removals.contains(&kp.index) { - Some(kp.credential.identity().into()) + Some(kp.credential.identity().to_owned().into()) } else { trace!(client_index:% = kp.index; "Client is pending removal"); None @@ -507,49 +387,7 @@ impl MlsConversation { .collect::>() } - /// Returns all members credentials with their signature public key from the group/conversation - pub fn members_with_key(&self) -> HashMap, CredentialWithKey> { - self.group.members().fold(HashMap::new(), |mut acc, kp| { - let credential = kp.credential; - let id = credential.identity().to_vec(); - let signature_key = SignaturePublicKey::from(kp.signature_key); - let credential = CredentialWithKey { - credential, - signature_key, - }; - acc.entry(id).or_insert(credential); - acc - }) - } - - pub(crate) async fn persist_group_when_changed(&mut self, keystore: &Database, force: bool) -> Result<()> { - if force || self.group.state_changed() == openmls::group::InnerState::Changed { - keystore - .mls_group_persist( - &self.id, - &core_crypto_keystore::ser(&self.group).map_err(KeystoreError::wrap("serializing group state"))?, - self.parent_id.as_ref().map(|id| id.as_ref()), - ) - .await - .map_err(KeystoreError::wrap("persisting mls group"))?; - - self.group.set_state(openmls::group::InnerState::Persisted); - } - - Ok(()) - } - - pub(crate) fn own_credential_type(&self) -> Result { - Ok(self - .group - .own_leaf_node() - .ok_or(Error::MlsGroupInvalidState("own_leaf_node not present in group"))? - .credential() - .credential_type() - .into()) - } - - pub(crate) fn ciphersuite(&self) -> MlsCiphersuite { + pub(crate) fn ciphersuite(&self) -> Ciphersuite { self.configuration.ciphersuite } @@ -557,7 +395,7 @@ impl MlsConversation { self.ciphersuite().signature_algorithm() } - pub(crate) async fn find_current_credential_bundle(&self, client: &Session) -> Result> { + pub(crate) async fn find_current_credential(&self, client: &Session) -> Result> { let own_leaf = self.group.own_leaf().ok_or(LeafError::InternalMlsError)?; let sc = self.ciphersuite().signature_algorithm(); let ct = self @@ -565,28 +403,30 @@ impl MlsConversation { .map_err(RecursiveError::mls_conversation("getting own credential type"))?; client - .find_credential_bundle_by_public_key(sc, ct, own_leaf.signature_key()) + .find_credential_by_public_key(sc, ct, own_leaf.signature_key()) .await - .map_err(RecursiveError::mls_client("finding current credential bundle")) + .map_err(RecursiveError::mls_client("finding current credential")) .map_err(Into::into) } - pub(crate) async fn find_most_recent_credential_bundle(&self, client: &Session) -> Result> { + pub(crate) async fn find_most_recent_credential(&self, client: &Session) -> Result> { let sc = self.ciphersuite().signature_algorithm(); let ct = self .own_credential_type() .map_err(RecursiveError::mls_conversation("getting own credential type"))?; client - .find_most_recent_credential_bundle(sc, ct) + .find_most_recent_credential(sc, ct) .await - .map_err(RecursiveError::mls_client("finding most recent credential bundle")) + .map_err(RecursiveError::mls_client("finding most recent credential")) .map_err(Into::into) } } #[cfg(test)] pub mod test_utils { + use openmls::prelude::SignaturePublicKey; + use super::*; impl MlsConversation { @@ -655,7 +495,7 @@ mod tests { mod wire_identity_getters { use super::Error; use crate::{ - ClientId, DeviceStatus, E2eiConversationState, MlsCredentialType, mls::conversation::Conversation, + ClientId, CredentialType, DeviceStatus, E2eiConversationState, mls::conversation::Conversation, test_utils::*, }; @@ -900,7 +740,7 @@ mod tests { assert_eq!(ios_ids, android_ids); assert!(ios_ids.iter().all(|i| { - matches!(i.credential_type, MlsCredentialType::Basic) + matches!(i.credential_type, CredentialType::Basic) && matches!(i.status, DeviceStatus::Valid) && i.x509_identity.is_none() && !i.thumbprint.is_empty() diff --git a/crypto/src/mls/conversation/persistence.rs b/crypto/src/mls/conversation/persistence.rs new file mode 100644 index 0000000000..532fadb629 --- /dev/null +++ b/crypto/src/mls/conversation/persistence.rs @@ -0,0 +1,67 @@ +use std::collections::HashMap; + +use core_crypto_keystore::{ + CryptoKeystoreMls as _, + connection::FetchFromDatabase as _, + entities::{EntityFindParams, PersistedMlsGroup}, +}; +use mls_crypto_provider::Database; +use openmls::group::{InnerState, MlsGroup}; + +use super::Result; +use crate::{ConversationId, KeystoreError, MlsConversation, MlsConversationConfiguration}; + +impl MlsConversation { + pub(crate) async fn persist_group_when_changed(&mut self, keystore: &Database, force: bool) -> Result<()> { + if force || self.group.state_changed() == InnerState::Changed { + keystore + .mls_group_persist( + &self.id, + &core_crypto_keystore::ser(&self.group).map_err(KeystoreError::wrap("serializing group state"))?, + self.parent_id.as_ref().map(|id| id.as_ref()), + ) + .await + .map_err(KeystoreError::wrap("persisting mls group"))?; + + self.group.set_state(InnerState::Persisted); + } + + Ok(()) + } + + /// restore the conversation from a persistence-saved serialized Group State. + pub(crate) fn from_serialized_state(buf: Vec, parent_id: Option) -> Result { + let group: MlsGroup = + core_crypto_keystore::deser(&buf).map_err(KeystoreError::wrap("deserializing group state"))?; + let id = ConversationId::from(group.group_id().as_slice()); + let configuration = MlsConversationConfiguration { + ciphersuite: group.ciphersuite().into(), + ..Default::default() + }; + + Ok(Self { + id, + group, + parent_id, + configuration, + }) + } + + /// Effectively [`Database::mls_groups_restore`] but with better types + pub(crate) async fn load_all(keystore: &Database) -> Result> { + let groups = keystore + .find_all::(EntityFindParams::default()) + .await + .map_err(KeystoreError::wrap("finding all persisted mls groups"))?; + groups + .into_iter() + .map(|group| { + // we can't just destructure the fields straight out of the group, because we derive `Zeroize`, which zeroizes on drop, + // which means we are forced to clone all the group's fields, because otherwise the drop impl couldn't run. + let conversation = + Self::from_serialized_state(group.state.clone(), group.parent_id.clone().map(Into::into))?; + Ok((group.id.clone().into(), conversation)) + }) + .collect() + } +} diff --git a/crypto/src/mls/conversation/proposal.rs b/crypto/src/mls/conversation/proposal.rs index b01f16cbce..004d3c7ffe 100644 --- a/crypto/src/mls/conversation/proposal.rs +++ b/crypto/src/mls/conversation/proposal.rs @@ -26,10 +26,10 @@ impl MlsConversation { key_package: KeyPackageIn, ) -> Result { let signer = &self - .find_current_credential_bundle(client) + .find_current_credential(client) .await .map_err(|_| Error::IdentityInitializationError)? - .signature_key; + .signature_key_pair; let crl_new_distribution_points = get_new_crl_distribution_points( backend, @@ -62,10 +62,10 @@ impl MlsConversation { member: LeafNodeIndex, ) -> Result { let signer = &self - .find_current_credential_bundle(client) + .find_current_credential(client) .await .map_err(|_| Error::IdentityInitializationError)? - .signature_key; + .signature_key_pair; let proposal = self .group .propose_remove_member(backend, signer, member) @@ -94,13 +94,13 @@ impl MlsConversation { leaf_node: Option, ) -> Result { let msg_signer = &self - .find_current_credential_bundle(client) + .find_current_credential(client) .await .map_err(|_| Error::IdentityInitializationError)? - .signature_key; + .signature_key_pair; let proposal = if let Some(leaf_node) = leaf_node { - let leaf_node_signer = &self.find_most_recent_credential_bundle(client).await?.signature_key; + let leaf_node_signer = &self.find_most_recent_credential(client).await?.signature_key_pair; self.group .propose_explicit_self_update(backend, msg_signer, leaf_node, leaf_node_signer) diff --git a/crypto/src/mls/conversation/renew.rs b/crypto/src/mls/conversation/renew.rs index 918cad7612..91444176da 100644 --- a/crypto/src/mls/conversation/renew.rs +++ b/crypto/src/mls/conversation/renew.rs @@ -1,4 +1,4 @@ -use core_crypto_keystore::entities::MlsEncryptionKeyPair; +use core_crypto_keystore::entities::StoredEncryptionKeyPair; use mls_crypto_provider::MlsCryptoProvider; use openmls::prelude::{LeafNode, LeafNodeIndex, Proposal, QueuedProposal, Sender, StagedCommit}; use openmls_traits::OpenMlsCryptoProvider; @@ -131,7 +131,7 @@ impl MlsConversation { // encryption key from the keystore otherwise we would have a leak backend .key_store() - .remove::(leaf_node.encryption_key().as_slice()) + .remove::(leaf_node.encryption_key().as_slice()) .await .map_err(KeystoreError::wrap("removing mls encryption keypair"))?; } @@ -144,9 +144,9 @@ impl MlsConversation { let sc = self.signature_scheme(); let ct = self.own_credential_type()?; let cb = client - .find_most_recent_credential_bundle(sc, ct) + .find_most_recent_credential(sc, ct) .await - .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?; + .map_err(RecursiveError::mls_client("finding most recent credential"))?; leaf_node.set_credential_with_key(cb.to_mls_credential_with_key()); diff --git a/crypto/src/mls/credential/credential_ref/error.rs b/crypto/src/mls/credential/credential_ref/error.rs new file mode 100644 index 0000000000..407964cae7 --- /dev/null +++ b/crypto/src/mls/credential/credential_ref/error.rs @@ -0,0 +1,42 @@ +// We allow missing documentation in the error module because the types are generally self-descriptive. +#![allow(missing_docs)] + +use super::super::error::CredentialValidationError; + +pub(crate) type Result = core::result::Result; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("signature keypair not found")] + KeypairNotFound, + #[error("credential not found")] + CredentialNotFound, + #[error("credential failed to validate")] + ValidationFailed(#[from] CredentialValidationError), + #[error(transparent)] + Keystore(#[from] crate::KeystoreError), + #[error(transparent)] + Recursive(#[from] crate::RecursiveError), + #[error("TLS serializing {item}")] + TlsSerialize { + #[source] + source: tls_codec::Error, + item: &'static str, + }, + #[error("TLS deserializing {item}")] + TlsDeserialize { + #[source] + source: tls_codec::Error, + item: &'static str, + }, +} + +impl Error { + pub fn tls_serialize(item: &'static str) -> impl FnOnce(tls_codec::Error) -> Self { + move |source| Self::TlsSerialize { source, item } + } + + pub fn tls_deserialize(item: &'static str) -> impl FnOnce(tls_codec::Error) -> Self { + move |source| Self::TlsDeserialize { source, item } + } +} diff --git a/crypto/src/mls/credential/credential_ref/find.rs b/crypto/src/mls/credential/credential_ref/find.rs new file mode 100644 index 0000000000..2f71c21b6f --- /dev/null +++ b/crypto/src/mls/credential/credential_ref/find.rs @@ -0,0 +1,140 @@ +use core_crypto_keystore::{ + connection::FetchFromDatabase as _, + entities::{EntityFindParams, StoredCredential}, +}; +use mls_crypto_provider::Database; +use openmls::prelude::{Credential as MlsCredential, CredentialType, SignatureScheme}; +use tls_codec::Deserialize as _; + +use super::{super::keypairs, Error, Result}; +use crate::{ClientId, Credential, CredentialRef, KeystoreError, RecursiveError, mls::session::id::ClientIdRef}; + +/// Filters to narrow down the set of credentials returned from [`Credential::find`][super::Credential::find]. +/// +/// Filters which are unset allow any value. +/// +/// ## Example +/// +/// ```no_run +/// # use crypto::src::mls::{Credential, FindFilters}; +/// # use openmls::prelude::CredentialType; +/// # let database = todo!(); +/// # let client_id = todo!(); +/// // get all basic credentials for a client +/// let credentials = Credential::find( +/// &database, +/// FindFilters::builder() +/// .client_id(&client_id) +/// .credential_type(CredentialType::Basic) +/// .build() +/// )?; +/// # println!("{credentials:?}"); +/// ``` +#[derive(Debug, Default, typed_builder::TypedBuilder)] +pub struct FindFilters<'a> { + /// Client ID to search for + #[builder(default, setter(strip_option))] + pub client_id: Option<&'a ClientIdRef>, + /// Signature scheme / ciphersuite to build for + #[builder(default, setter(strip_option))] + pub signature_scheme: Option, + /// Credential type to build for + #[builder(default, setter(strip_option))] + pub credential_type: Option, +} + +impl CredentialRef { + /// Find all credentials in the database matching the provided filters. + /// + /// If you have all the components of a filter, it is more efficient to use those to directly + /// [construct a `CredentialRef`][CredentialRef::new]. + // + // Our database does not currently support indices or even in-db searching, so this moves all data + // from the DB to the runtime, decodes everything, and then filters. This is obviously suboptimal, + // but that's only going to be improved with WPB-20839. + pub async fn find(database: &Database, find_filters: FindFilters<'_>) -> Result> { + let FindFilters { + client_id, + signature_scheme, + credential_type, + } = find_filters; + + let mut stored_keypairs = keypairs::load_all(database) + .await + .map_err(RecursiveError::mls_credential( + "loading all keypairs while finding credentials", + ))?; + if let Some(signature_scheme) = signature_scheme { + stored_keypairs.retain(|keypair| keypair.signature_scheme == signature_scheme as u16); + } + if stored_keypairs.is_empty() { + return Ok(Vec::new()); + } + let stored_keypairs = stored_keypairs + .iter() + .map(|stored_keypair| { + keypairs::deserialize(stored_keypair) + .map_err(RecursiveError::mls_credential( + "deserializing keypair while finding credentials", + )) + .map_err(Into::into) + }) + .collect::>>()?; + + let partial_credentials = database + .find_all::(EntityFindParams::default()) + .await + .map_err(KeystoreError::wrap("finding all credentials"))? + .into_iter() + .filter(|stored| { + client_id + .map(|client_id| client_id.as_ref() == stored.id) + .unwrap_or(true) + }) + .map(|stored| -> Result<_> { + let mls_credential = MlsCredential::tls_deserialize_exact(&stored.credential) + .map_err(Error::tls_deserialize("Credential"))?; + Ok((mls_credential, stored)) + }); + + let mut out = Vec::new(); + for partial in partial_credentials { + let (ref mls_credential, ref stored_credential) = partial?; + + if !credential_type + .map(|credential_type| credential_type == mls_credential.credential_type()) + .unwrap_or(true) + { + // credential type did not match + continue; + } + + for signature_key_pair in &stored_keypairs { + if Credential::validate_mls_credential( + mls_credential, + <&ClientIdRef>::from(&stored_credential.id), + signature_key_pair, + ) + .is_err() + { + // this probably doesn't happen often, but no point getting weird about it if it does; + // just indicates it's not a match + continue; + } + + out.push(Self { + client_id: ClientId(stored_credential.id.clone()), + r#type: mls_credential.credential_type(), + signature_scheme: signature_key_pair.signature_scheme(), + }) + } + } + + Ok(out) + } + + /// Load all credentials from the database + pub async fn get_all(database: &Database) -> Result> { + Self::find(database, FindFilters::default()).await + } +} diff --git a/crypto/src/mls/credential/credential_ref/mod.rs b/crypto/src/mls/credential/credential_ref/mod.rs new file mode 100644 index 0000000000..838a18070a --- /dev/null +++ b/crypto/src/mls/credential/credential_ref/mod.rs @@ -0,0 +1,69 @@ +//! Definitions and implementations for [`CredentialRef`]. + +mod error; +mod find; +mod persistence; + +use openmls::prelude::{CredentialType, SignatureScheme}; + +pub(crate) use self::error::Result; +pub use self::{ + error::Error, + find::{FindFilters, FindFiltersBuilder}, +}; +use crate::{ClientId, ClientIdRef}; + +/// A reference to a credential which has been stored in the database. +/// +/// This serves two purposes: +/// +/// 1. Credentials can be quite large; we'd really like to avoid passing them +/// back and forth across the FFI boundary more than is strictly required. +/// Therefore, we use this type which is substantially more compact. +/// 2. It serves as proof of persistence. If you have a `CredentialRef`, you know +/// that the credential it refers to has been saved in the database. +/// This gives us a typesafe way to require that credentials are saved before +/// they are added to a [`Session`][crate::Session]. +/// +/// Created with [`Credential::save`][crate::Credential::save]. +/// +/// This reference is _not_ a literal reference in memory. +/// It is instead the key from which a credential can be retrieved. +/// This means that it is stable over time and across the FFI boundary. +#[derive( + core_crypto_macros::Debug, Clone, derive_more::From, derive_more::Into, serde::Serialize, serde::Deserialize, +)] +pub struct CredentialRef { + client_id: ClientId, + r#type: CredentialType, + signature_scheme: SignatureScheme, +} + +impl CredentialRef { + /// Construct an instance from its parts. + /// + /// This _must_ remain crate-private so that we can use this type + /// as proof of persistence! + pub(crate) const fn new(client_id: ClientId, r#type: CredentialType, signature_scheme: SignatureScheme) -> Self { + Self { + client_id, + r#type, + signature_scheme, + } + } + + /// Get the client ID associated with this credential + pub fn client_id(&self) -> &ClientIdRef { + self.client_id.as_ref() + } + + /// Get the credential type associated with this credential + pub fn r#type(&self) -> CredentialType { + self.r#type + } + + /// Get the signature scheme associated with this credential + pub fn signature_scheme(&self) -> SignatureScheme { + self.signature_scheme + } +} diff --git a/crypto/src/mls/credential/credential_ref/persistence.rs b/crypto/src/mls/credential/credential_ref/persistence.rs new file mode 100644 index 0000000000..6a5430e697 --- /dev/null +++ b/crypto/src/mls/credential/credential_ref/persistence.rs @@ -0,0 +1,173 @@ +//! Persistence for [`CredentialRef`], i.e. loading actual credentials from the keystore given a ref. +//! +//! It is not logically required that these methods are crate-private, but they aren't likely to be +//! useful to end users. Clients building on the CC API can't do anything useful with a full [`Credential`], +//! and it's wasteful to transfer one across the FFI boundary. + +use std::cmp::Ordering; + +use core_crypto_keystore::{ + connection::FetchFromDatabase as _, + entities::{EntityFindParams, StoredCredential, StoredSignatureKeypair}, +}; +use mls_crypto_provider::Database; +use openmls::prelude::Credential as MlsCredential; +use tls_codec::Deserialize as _; + +use super::{super::keypairs, Error, Result}; +use crate::{Credential, CredentialRef, KeystoreError, RecursiveError}; + +/// Helper struct caching relevant data from the keystore. +/// +/// This is generated by [`Credential::load_cache`]. It's only relevant when you want to +/// load multiple credentials at once. For single loads, prefer [`Credential::load`]. +// +// We'd very much like it if in the future we could do filtering at the database level, +// obviating the requirement for this cache structure. See WPB-20839 and WPB-20844. +pub(crate) struct Cache { + keypairs: Vec, + credentials: Vec, +} + +impl CredentialRef { + /// Load all credentials which match this ref from the database. + /// + /// Note that this does not attach the credential to any Session; it just does the data manipulation. + /// + /// The database schema currently permits multiple credentials to exist simultaneously which match a given credential ref. + /// Therefore, this function returns all of them, ordered by `earliest_validity`. If you only need the first, + /// consider using [`Self::load_first`] or [`Self::load_first_with_cache`]. + /// + /// Due to database limitations we currently cannot efficiently retrieve only those keypairs of interest; + /// if you are going to be loading several references in a row, it is more efficient to first fetch all + /// stored keypairs with [`Self::load_cache`] and then call [`Self::load_with_cache`]. + pub(crate) async fn load(&self, database: &Database) -> Result> { + let cache = Self::load_cache(database).await?; + let mut credentials = self.load_with_cache(&cache).await?.collect::>>()?; + credentials.sort_by_key(|credential| credential.earliest_validity); + Ok(credentials) + } + + /// Load the first credential which matches this ref from the database. + /// + /// Note that this does not attach the credential to any Session; it just does the data manipulation. + /// + /// The database schema currently permits multiple credentials to exist simultaneously which match a given credential ref. + /// If you need any of them beyond the first when ordered by `earliest_validity`, use [`Self::load`]. + /// + /// Due to database limitations we currently cannot efficiently retrieve only those keypairs of interest; + /// if you are going to be loading several references in a row, it is more efficient to first fetch all + /// stored keypairs with [`Self::load_cache`] and then call [`Self::load_first_with_cache`]. + // + // We should evaluate later if this method is worth retaining, but for now let's keep the impl + // in case we want it in the future. + #[expect(dead_code)] + pub(crate) async fn load_first(&self, database: &Database) -> Result { + let cache = Self::load_cache(database).await?; + self.load_first_with_cache(&cache).await + } + + /// Helper to prefetch relevant keypairs when loading multiple credentials at a time. + /// + /// Only useful when preparing to call [`Self::load_with_cache`] multiple times. + /// For loading a single credential, prefer [`Self::load`]. + pub(crate) async fn load_cache(database: &Database) -> Result { + let keypairs = keypairs::load_all(database) + .await + .map_err(RecursiveError::mls_credential("loading all keypairs for cache"))?; + let mut credentials = database + .find_all::(EntityFindParams::default()) + .await + .map_err(KeystoreError::wrap("finding all mls credentials"))?; + credentials.sort_by_key(|credential| credential.created_at); + Ok(Cache { keypairs, credentials }) + } + + /// Load the first credential which matches this ref from the database. + /// + /// Note that this does not attach the credential to any Session; it just does the data manipulation. + /// + /// The database schema currently permits multiple credentials to exist simultaneously which match a given credential ref. + /// If you need any of them beyond the first when ordered by `earliest_validity`, use [`Self::load_with_cache`]. + /// + /// If you are only loading a single credential ref, it may be simpler to call [`Self::load_first`]. + pub(crate) async fn load_first_with_cache(&self, cache: &Cache) -> Result { + self.load_with_cache(cache) + .await? + .min_by(|a, b| { + // errors are the min value so they are propagated + // otherwise we return the min by `earliest_validity` + match (a, b) { + (Err(_), Err(_)) => Ordering::Equal, + (Err(_), Ok(_)) => Ordering::Less, + (Ok(_), Err(_)) => Ordering::Greater, + (Ok(a), Ok(b)) => a.earliest_validity.cmp(&b.earliest_validity), + } + }) + .ok_or(Error::CredentialNotFound) + .flatten() + } + + /// Load this credential from the database. + /// + /// Note that this does not attach the credential to any Session; it just does the data manipulation. + /// + /// The database schema currently permits multiple credentials to exist simultaneously which match a given credential ref. + /// Therefore, this function returns a possibly-empty iterator over all of them, ordered by `earliest_validity`. + /// + /// If you are only loading a single credential ref, it may be simpler to call [`Self::load`]. + pub(crate) async fn load_with_cache<'cache, 'cref>( + &'cref self, + Cache { keypairs, credentials }: &'cache Cache, + ) -> Result>> + where + 'cref: 'cache, + { + let signature_key_pair = keypairs::find_matching(keypairs, self.client_id(), self.signature_scheme()) + .await + .map_err(RecursiveError::mls_credential( + "finding matching key pairs while loading credential", + ))? + .ok_or(Error::KeypairNotFound)?; + + let iter = credentials + .iter() + // this is the only check we can currently do at the DB level: match the client id + .filter(|stored_credential| stored_credential.id == self.client_id().as_slice()) + // from here we can at least deserialize the credential + .map(move |stored_credential| { + let mls_credential = MlsCredential::tls_deserialize(&mut stored_credential.credential.as_slice()) + .map_err(Error::tls_deserialize("mls credential"))?; + let earliest_validity = stored_credential.created_at; + Ok(Credential { + signature_key_pair: signature_key_pair.clone(), + mls_credential, + earliest_validity, + }) + }) + // after deserialization, we can filter out any results which do not match the conditions in the credential ref + // but pass through any errors + .filter(|credential_result| { + credential_result + .as_ref() + .map(|credential| { + credential.signature_key_pair.signature_scheme() == self.signature_scheme() + && credential.mls_credential.credential_type() == self.r#type() + }) + .unwrap_or(true) + }) + // we also need to ensure that the credential validates + .map(|credential_result| { + credential_result.and_then(|credential| { + Credential::validate_mls_credential( + &credential.mls_credential, + self.client_id(), + &credential.signature_key_pair, + )?; + Ok(credential) + }) + }); + + Ok(iter) + } +} diff --git a/crypto/src/mls/credential/error.rs b/crypto/src/mls/credential/error.rs index 9f54235f7a..2a7c3941f9 100644 --- a/crypto/src/mls/credential/error.rs +++ b/crypto/src/mls/credential/error.rs @@ -35,4 +35,36 @@ pub enum Error { Mls(#[from] crate::MlsError), #[error(transparent)] Recursive(#[from] crate::RecursiveError), + #[error("TLS serializing {item}")] + TlsSerialize { + #[source] + source: tls_codec::Error, + item: &'static str, + }, + #[error("TLS deserializing {item}")] + TlsDeserialize { + #[source] + source: tls_codec::Error, + item: &'static str, + }, +} + +impl Error { + pub fn tls_serialize(item: &'static str) -> impl FnOnce(tls_codec::Error) -> Self { + move |source| Self::TlsSerialize { source, item } + } + + pub fn tls_deserialize(item: &'static str) -> impl FnOnce(tls_codec::Error) -> Self { + move |source| Self::TlsDeserialize { source, item } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum CredentialValidationError { + #[error("identity or public key did not match")] + WrongCredential, + #[error("public key not extractable from certificate")] + NoPublicKey, + #[error(transparent)] + Recursive(#[from] crate::RecursiveError), } diff --git a/crypto/src/mls/credential/ext.rs b/crypto/src/mls/credential/ext.rs index 9e2ff52e2e..f23320a511 100644 --- a/crypto/src/mls/credential/ext.rs +++ b/crypto/src/mls/credential/ext.rs @@ -4,15 +4,15 @@ use wire_e2e_identity::prelude::{HashAlgorithm, JwsAlgorithm, compute_raw_key_th use x509_cert::{Certificate, der::Decode}; use super::{Error, Result}; -use crate::{DeviceStatus, MlsCiphersuite, MlsCredentialType, RecursiveError, WireIdentity}; +use crate::{Ciphersuite, DeviceStatus, RecursiveError, WireIdentity}; #[allow(dead_code)] pub(crate) trait CredentialExt { fn parse_leaf_cert(&self) -> Result>; - fn get_type(&self) -> Result; + fn get_type(&self) -> Result; fn extract_identity( &self, - cs: MlsCiphersuite, + cs: Ciphersuite, env: Option<&wire_e2e_identity::prelude::x509::revocation::PkiEnvironment>, ) -> Result; fn extract_public_key(&self) -> Result>>; @@ -24,13 +24,13 @@ impl CredentialExt for CredentialWithKey { self.credential.parse_leaf_cert() } - fn get_type(&self) -> Result { + fn get_type(&self) -> Result { self.credential.get_type() } fn extract_identity( &self, - cs: MlsCiphersuite, + cs: Ciphersuite, env: Option<&wire_e2e_identity::prelude::x509::revocation::PkiEnvironment>, ) -> Result { match self.credential.mls_credential() { @@ -46,7 +46,7 @@ impl CredentialExt for CredentialWithKey { Ok(WireIdentity { client_id, - credential_type: MlsCredentialType::Basic, + credential_type: CredentialType::Basic, thumbprint, status: DeviceStatus::Valid, x509_identity: None, @@ -72,17 +72,17 @@ impl CredentialExt for Credential { } } - fn get_type(&self) -> Result { + fn get_type(&self) -> Result { match self.credential_type() { - CredentialType::Basic => Ok(MlsCredentialType::Basic), - CredentialType::X509 => Ok(MlsCredentialType::X509), + CredentialType::Basic => Ok(CredentialType::Basic), + CredentialType::X509 => Ok(CredentialType::X509), CredentialType::Unknown(_) => Err(Error::UnsupportedCredentialType), } } fn extract_identity( &self, - _cs: MlsCiphersuite, + _cs: Ciphersuite, _env: Option<&wire_e2e_identity::prelude::x509::revocation::PkiEnvironment>, ) -> Result { // This should not be called directly because one does not have the signature public key and hence @@ -110,13 +110,13 @@ impl CredentialExt for openmls::prelude::Certificate { Ok(Some(leaf)) } - fn get_type(&self) -> Result { - Ok(MlsCredentialType::X509) + fn get_type(&self) -> Result { + Ok(CredentialType::X509) } fn extract_identity( &self, - cs: MlsCiphersuite, + cs: Ciphersuite, env: Option<&wire_e2e_identity::prelude::x509::revocation::PkiEnvironment>, ) -> Result { let leaf = self.certificates.first().ok_or(Error::InvalidIdentity)?; @@ -145,7 +145,7 @@ impl CredentialExt for openmls::prelude::Certificate { } } -fn compute_thumbprint(cs: MlsCiphersuite, raw_key: &[u8]) -> Result { +fn compute_thumbprint(cs: Ciphersuite, raw_key: &[u8]) -> Result { let sign_alg = match cs.signature_algorithm() { SignatureScheme::ED25519 => JwsAlgorithm::Ed25519, SignatureScheme::ECDSA_SECP256R1_SHA256 => JwsAlgorithm::P256, diff --git a/crypto/src/mls/credential/keypairs.rs b/crypto/src/mls/credential/keypairs.rs new file mode 100644 index 0000000000..e262671c20 --- /dev/null +++ b/crypto/src/mls/credential/keypairs.rs @@ -0,0 +1,81 @@ +use core_crypto_keystore::{ + connection::FetchFromDatabase as _, + entities::{EntityFindParams, StoredSignatureKeypair}, +}; +use mls_crypto_provider::Database; +use openmls::prelude::{OpenMlsCrypto, SignatureScheme}; +use openmls_basic_credential::SignatureKeyPair; +use tls_codec::{Deserialize as _, Serialize as _}; + +use super::{Error, Result}; +use crate::{KeystoreError, MlsError, mls::session::id::ClientIdRef}; + +/// Load all stored keypairs from the keystore +/// +/// Ensures the keypairs are sorted in order of their creation date. +pub(super) async fn load_all(database: &Database) -> Result> { + database + .find_all::(EntityFindParams::default()) + .await + .map_err(KeystoreError::wrap("finding all mls signature keypairs")) + .map_err(Into::into) +} + +/// Generate a new keypair in-memory with the specificed signature scheme +pub(super) fn generate(crypto: impl OpenMlsCrypto, signature_scheme: SignatureScheme) -> Result { + let (private_key, public_key) = crypto + .signature_key_gen(signature_scheme) + .map_err(MlsError::wrap("generating signature key"))?; + Ok(SignatureKeyPair::from_raw(signature_scheme, private_key, public_key)) +} + +/// Store a keypair in the keystore, attached to a particular client id +pub(super) async fn store(database: &Database, id: &ClientIdRef, keypair: &SignatureKeyPair) -> Result<()> { + let data = keypair + .tls_serialize_detached() + .map_err(Error::tls_serialize("keypair"))?; + + debug_assert!( + { + let deserialized = + SignatureKeyPair::tls_deserialize_exact(&data).expect("keypair deserializes without error"); + deserialized.signature_scheme() == keypair.signature_scheme() + && deserialized.public() == keypair.public() + && deserialized.private() == keypair.private() + }, + "serialized keypair data must deserialize correctly" + ); + + let stored_keypair = StoredSignatureKeypair::new( + keypair.signature_scheme(), + keypair.public().to_owned(), + data, + id.as_slice().into(), + ); + database + .save(stored_keypair) + .await + .map_err(KeystoreError::wrap("storing keypairs in keystore"))?; + + Ok(()) +} + +/// Deserialize a [`StoredSignatureKeypair`] into a [`SignatureKeyPair`] +pub(super) fn deserialize(stored: &StoredSignatureKeypair) -> Result { + SignatureKeyPair::tls_deserialize_exact(&stored.keypair) + .map_err(KeystoreError::wrap("deserializing keypair from keystore")) + .map_err(Into::into) +} + +/// Retrieve the first keypair from the list which matches the provided signature scheme and client id +pub(super) async fn find_matching( + keypairs: &[StoredSignatureKeypair], + client_id: impl AsRef<[u8]>, + signature_scheme: SignatureScheme, +) -> Result> { + keypairs + .iter() + .find(|stored| stored.credential_id == client_id.as_ref() && stored.signature_scheme == signature_scheme as u16) + .map(deserialize) + .transpose() +} diff --git a/crypto/src/mls/credential/mod.rs b/crypto/src/mls/credential/mod.rs index 9191940f87..15530fb25b 100644 --- a/crypto/src/mls/credential/mod.rs +++ b/crypto/src/mls/credential/mod.rs @@ -1,106 +1,178 @@ -use std::{ - cmp::Ordering, - hash::{Hash, Hasher}, -}; - -use openmls::prelude::{Credential, CredentialWithKey}; -use openmls_basic_credential::SignatureKeyPair; +//! This module focuses on [`Credential`]s: cryptographic assertions of identity. +//! +//! Credentials can be basic, or based on an x509 certificate chain. +pub(crate) mod credential_ref; pub(crate) mod crl; mod error; pub(crate) mod ext; -pub(crate) mod typ; +mod keypairs; +mod persistence; pub(crate) mod x509; -pub(crate) use error::{Error, Result}; +use std::hash::{Hash, Hasher}; + +use openmls::prelude::{Credential as MlsCredential, CredentialWithKey, MlsCredentialType, SignatureScheme}; +use openmls_basic_credential::SignatureKeyPair; +use openmls_traits::crypto::OpenMlsCrypto; + +pub(crate) use self::error::Result; +pub use self::{ + credential_ref::{CredentialRef, FindFilters, FindFiltersBuilder}, + error::Error, +}; +use crate::{ + ClientId, ClientIdRef, RecursiveError, + mls::credential::{error::CredentialValidationError, ext::CredentialExt as _}, +}; -#[derive(Debug, serde::Serialize, serde::Deserialize)] -pub struct CredentialBundle { - pub(crate) credential: Credential, - pub(crate) signature_key: SignatureKeyPair, - pub(crate) created_at: u64, +/// A cryptographic credential. +/// +/// This is tied to a particular client via either its client id or certificate bundle, +/// depending on its credential type, but is independent of any client instance or storage. +/// +/// To attach to a particular client instance and store, see [`Session::add_credential`][crate::Session::add_credential]. +/// +/// Note: the current database design makes some questionable assumptions: +/// +/// - There are always either 0 or 1 `StoredSignatureKeypair` instances in the DB for a particular signature scheme +/// - There may be multiple `StoredCredential` instances in the DB for a particular signature scheme, but they all share +/// the same `ClientId` / signing key. In other words, the same signing keypair is _reused_ between credentials. +/// - Practically, the code ensures that there is a 1:1 correspondence between signing scheme <-> identity/credential, +/// and we need to maintain that property for now. +/// +/// Work is ongoing to fix those limitations; see WPB-20844. Until that is resolved, we enforce those restrictions by +/// raising errors as required to preserve DB integrity. +#[derive(core_crypto_macros::Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct Credential { + /// MLS internal credential. Stores the credential type. + pub(crate) mls_credential: MlsCredential, + /// Public and private keys, and the signature scheme. + #[sensitive] + pub(crate) signature_key_pair: SignatureKeyPair, + /// Earliest valid time of creation for this credential. + /// + /// This is represented as seconds after the unix epoch. + /// + /// Only meaningful for X509, where it is the "valid_from" claim of the leaf credential. + /// For basic credentials, this is always 0. + pub(crate) earliest_validity: u64, } -impl CredentialBundle { - pub fn credential(&self) -> &Credential { - &self.credential +impl Credential { + /// Ensure that the provided `MlsCredential` matches the client id / signature key provided + pub(crate) fn validate_mls_credential( + mls_credential: &MlsCredential, + client_id: &ClientIdRef, + signature_key: &SignatureKeyPair, + ) -> Result<(), CredentialValidationError> { + match mls_credential.mls_credential() { + MlsCredentialType::Basic(_) => { + if client_id.as_slice() != mls_credential.identity() { + return Err(CredentialValidationError::WrongCredential); + } + } + MlsCredentialType::X509(cert) => { + let certificate_public_key = cert + .extract_public_key() + .map_err(RecursiveError::mls_credential( + "extracting public key from certificate in credential validation", + ))? + .ok_or(CredentialValidationError::NoPublicKey)?; + if signature_key.public() != certificate_public_key { + return Err(CredentialValidationError::WrongCredential); + } + } + } + Ok(()) + } + + /// Generate a basic credential. + /// + /// The result is independent of any client instance and the database; it lives in memory only. + pub fn basic(signature_scheme: SignatureScheme, client_id: ClientId, crypto: impl OpenMlsCrypto) -> Result { + let signature_key_pair = keypairs::generate(crypto, signature_scheme)?; + + Ok(Self { + mls_credential: MlsCredential::new_basic(client_id.into_inner()), + signature_key_pair, + earliest_validity: 0, + }) } + /// Get the Openmls Credential type. + /// + /// This stores the credential type (basic/x509). + pub fn mls_credential(&self) -> &MlsCredential { + &self.mls_credential + } + + /// Get a reference to the `SignatureKeyPair`. pub(crate) fn signature_key(&self) -> &SignatureKeyPair { - &self.signature_key + &self.signature_key_pair } + /// Generate a `CredentialWithKey`, which combines the credential type with the public portion of the keypair. pub fn to_mls_credential_with_key(&self) -> CredentialWithKey { CredentialWithKey { - credential: self.credential.clone(), - signature_key: self.signature_key.to_public_vec().into(), + credential: self.mls_credential.clone(), + signature_key: self.signature_key_pair.to_public_vec().into(), } } -} -impl From for CredentialWithKey { - fn from(cb: CredentialBundle) -> Self { - Self { - credential: cb.credential, - signature_key: cb.signature_key.public().into(), - } + /// Earliest valid time of creation for this credential. + /// + /// This is represented as seconds after the unix epoch. + /// + /// Only meaningful for X509, where it is the "valid_from" claim of the leaf credential. + /// For basic credentials, this is always 0 when the credential is first created. + /// It is updated upon being persisted to the database. + pub fn earliest_validity(&self) -> u64 { + self.earliest_validity + } + + /// Get the client ID associated with this credential + pub fn client_id(&self) -> &ClientIdRef { + self.mls_credential.identity().into() } } -impl Clone for CredentialBundle { - fn clone(&self) -> Self { +impl From for CredentialWithKey { + fn from(cb: Credential) -> Self { Self { - credential: self.credential.clone(), - signature_key: SignatureKeyPair::from_raw( - self.signature_key.signature_scheme(), - self.signature_key.private().to_vec(), - self.signature_key.to_public_vec(), - ), - created_at: self.created_at, + credential: cb.mls_credential, + signature_key: cb.signature_key_pair.public().into(), } } } -impl Eq for CredentialBundle {} -impl PartialEq for CredentialBundle { +impl Eq for Credential {} +impl PartialEq for Credential { fn eq(&self, other: &Self) -> bool { - self.credential.eq(&other.credential) - && self.created_at.eq(&other.created_at) - && self - .signature_key - .signature_scheme() - .eq(&other.signature_key.signature_scheme()) - && self.signature_key.public().eq(other.signature_key.public()) + self.mls_credential == other.mls_credential && self.earliest_validity == other.earliest_validity && { + let sk = &self.signature_key_pair; + let ok = &other.signature_key_pair; + sk.signature_scheme() == ok.signature_scheme() && sk.public() == ok.public() && sk.private() == ok.private() + } } } -impl Hash for CredentialBundle { +impl Hash for Credential { fn hash(&self, state: &mut H) { - self.created_at.hash(state); - self.signature_key.signature_scheme().hash(state); - self.signature_key.public().hash(state); - self.credential().identity().hash(state); - match self.credential().mls_credential() { - openmls::prelude::MlsCredentialType::X509(cert) => { + self.earliest_validity.hash(state); + self.signature_key_pair.signature_scheme().hash(state); + self.signature_key_pair.public().hash(state); + // self.mls_credential.credential_type().hash(state); // not implemented for Reasons, idk + self.mls_credential.identity().hash(state); + match self.mls_credential().mls_credential() { + MlsCredentialType::X509(cert) => { cert.certificates.hash(state); } - openmls::prelude::MlsCredentialType::Basic(_) => {} + MlsCredentialType::Basic(_) => {} }; } } -impl Ord for CredentialBundle { - fn cmp(&self, other: &Self) -> Ordering { - self.created_at.cmp(&other.created_at) - } -} - -impl PartialOrd for CredentialBundle { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - // TODO: ensure certificate signature must match the group's ciphersuite ; fails otherwise. Tracking issue: WPB-9632 // Requires more than 1 ciphersuite supported at the moment. #[cfg(test)] @@ -111,7 +183,7 @@ mod tests { use super::{x509::CertificateBundle, *}; use crate::{ - ClientIdentifier, E2eiConversationState, MlsCredentialType, + ClientIdentifier, CredentialType, E2eiConversationState, mls::{conversation::Conversation as _, credential::x509::CertificatePrivateKey}, test_utils::{ x509::{CertificateParams, X509TestChain}, @@ -145,8 +217,9 @@ mod tests { let ([x509_session], [basic_session]) = case.sessions_mixed_credential_types().await; // That way the conversation creator (Alice) will have a different credential type than Bob let (alice, bob, alice_credential_type) = match case.credential_type { - MlsCredentialType::Basic => (x509_session, basic_session, MlsCredentialType::X509), - MlsCredentialType::X509 => (basic_session, x509_session, MlsCredentialType::Basic), + CredentialType::Basic => (x509_session, basic_session, CredentialType::X509), + CredentialType::X509 => (basic_session, x509_session, CredentialType::Basic), + _ => panic!("only basic and x509 credential types supported"), }; let conversation = case @@ -307,7 +380,7 @@ mod tests { // Charlie is a basic client that tries to join (i.e. emulates guest links in Wire) let conversation = conversation - .invite_with_credential_type_notify(MlsCredentialType::Basic, [&charlie]) + .invite_with_credential_type_notify(CredentialType::Basic, [&charlie]) .await; assert_eq!( diff --git a/crypto/src/mls/credential/persistence.rs b/crypto/src/mls/credential/persistence.rs new file mode 100644 index 0000000000..acd89ca208 --- /dev/null +++ b/crypto/src/mls/credential/persistence.rs @@ -0,0 +1,43 @@ +use core_crypto_keystore::entities::StoredCredential; +use mls_crypto_provider::Database; +use tls_codec::Serialize as _; + +use super::{Error, Result}; +use crate::{Credential, CredentialRef, KeystoreError, mls::credential::keypairs}; + +impl Credential { + /// Update all the fields that were updated by the DB during the save. + /// [`::pre_save`][core_crypto_keystore::entities::EntityTransactionExt::pre_save]. + fn update_from(&mut self, stored: StoredCredential) { + self.earliest_validity = stored.created_at; + } + + /// Persist this credential into the database. + /// + /// Returns a reference which is stable over time and across the FFI boundary. + pub async fn save(&mut self, database: &Database) -> Result { + keypairs::store(database, self.client_id(), &self.signature_key_pair).await?; + + let credential_data = self + .mls_credential + .tls_serialize_detached() + .map_err(Error::tls_serialize("credential"))?; + + let stored_credential = database + .save(StoredCredential { + id: self.client_id().to_owned().into_inner(), + credential: credential_data, + created_at: Default::default(), // updated by the `.save` impl + }) + .await + .map_err(KeystoreError::wrap("saving credential"))?; + + self.update_from(stored_credential); + + Ok(CredentialRef::new( + self.client_id().to_owned(), + self.mls_credential.credential_type(), + self.signature_key_pair.signature_scheme(), + )) + } +} diff --git a/crypto/src/mls/credential/typ.rs b/crypto/src/mls/credential/typ.rs deleted file mode 100644 index 700f4ad047..0000000000 --- a/crypto/src/mls/credential/typ.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::unreachable; - -use openmls::prelude::CredentialType; - -/// Lists all the supported Credential types. Could list in the future some types not supported by -/// openmls such as Verifiable Presentation -#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)] -#[repr(u8)] -pub enum MlsCredentialType { - /// Basic credential i.e. a KeyPair - #[default] - Basic = 0x01, - /// A x509 certificate generally obtained through e2e identity enrollment process - X509 = 0x02, -} - -impl From for MlsCredentialType { - fn from(value: CredentialType) -> Self { - match value { - CredentialType::Basic => MlsCredentialType::Basic, - CredentialType::X509 => MlsCredentialType::X509, - _ => unreachable!("Unknown credential type"), - } - } -} - -impl From for CredentialType { - fn from(value: MlsCredentialType) -> Self { - match value { - MlsCredentialType::Basic => CredentialType::Basic, - MlsCredentialType::X509 => CredentialType::X509, - } - } -} diff --git a/crypto/src/mls/credential/x509.rs b/crypto/src/mls/credential/x509.rs index 790fa656e3..38b173a7d8 100644 --- a/crypto/src/mls/credential/x509.rs +++ b/crypto/src/mls/credential/x509.rs @@ -1,6 +1,11 @@ +#[cfg(test)] +use std::collections::HashMap; + #[cfg(test)] use mls_crypto_provider::PkiKeypair; +use openmls::prelude::Credential as MlsCredential; use openmls_traits::types::SignatureScheme; +use openmls_x509_credential::CertificateKeyPair; use wire_e2e_identity::prelude::{HashAlgorithm, WireIdentityReader}; #[cfg(test)] use x509_cert::der::Encode; @@ -9,7 +14,7 @@ use zeroize::Zeroize; use super::{Error, Result}; #[cfg(test)] use crate::test_utils::x509::X509Certificate; -use crate::{ClientId, RecursiveError, e2e_identity::id::WireQualifiedClientId}; +use crate::{ClientId, Credential, MlsError, RecursiveError, e2e_identity::id::WireQualifiedClientId}; #[derive(core_crypto_macros::Debug, Clone, Zeroize)] #[zeroize(drop)] @@ -66,6 +71,28 @@ impl CertificateBundle { } } +impl Credential { + /// Create a new x509 credential from a certificate bundle. + pub fn x509(cert: CertificateBundle) -> Result { + let created_at = cert + .get_created_at() + .map_err(RecursiveError::mls_credential("getting credetntial created at"))?; + let (sk, ..) = cert.private_key.into_parts(); + let chain = cert.certificate_chain; + + let kp = CertificateKeyPair::new(sk, chain.clone()).map_err(MlsError::wrap("creating certificate key pair"))?; + + let credential = MlsCredential::new_x509(chain).map_err(MlsError::wrap("creating x509 credential"))?; + + let cb = Credential { + mls_credential: credential, + signature_key_pair: kp.0, + earliest_validity: created_at, + }; + Ok(cb) + } +} + #[cfg(test)] fn new_rand_client(domain: Option) -> (String, String) { let rand_str = |n: usize| { @@ -163,15 +190,20 @@ impl CertificateBundle { } } + pub fn rand_identifier_certs( + client_id: &ClientId, + signers: &[&crate::test_utils::x509::X509Certificate], + ) -> HashMap { + signers + .iter() + .map(|signer| (signer.signature_scheme, Self::rand(client_id, signer))) + .collect() + } + pub fn rand_identifier( - name: &str, + client_id: &ClientId, signers: &[&crate::test_utils::x509::X509Certificate], ) -> crate::ClientIdentifier { - crate::ClientIdentifier::X509( - signers - .iter() - .map(|signer| (signer.signature_scheme, Self::rand(&name.into(), signer))) - .collect::>(), - ) + crate::ClientIdentifier::X509(Self::rand_identifier_certs(client_id, signers)) } } diff --git a/crypto/src/mls/mod.rs b/crypto/src/mls/mod.rs index f1dbb8d794..95e6751bdc 100644 --- a/crypto/src/mls/mod.rs +++ b/crypto/src/mls/mod.rs @@ -4,7 +4,7 @@ use crate::{ClientId, MlsConversation, Session}; pub(crate) mod ciphersuite; pub mod conversation; -pub(crate) mod credential; +pub mod credential; mod error; pub(crate) mod proposal; pub(crate) mod session; @@ -23,7 +23,7 @@ pub(crate) trait HasSessionAndCrypto: Send { mod tests { use crate::{ - CertificateBundle, ClientIdentifier, CoreCrypto, MlsCredentialType, SessionConfig, + CertificateBundle, ClientIdentifier, CoreCrypto, CredentialType, SessionConfig, mls::Session, test_utils::{x509::X509TestChain, *}, transaction_context::Error as TransactionError, @@ -67,7 +67,7 @@ mod tests { mod invariants { use super::*; - use crate::{MlsCiphersuite, mls}; + use crate::{Ciphersuite, mls}; #[apply(all_cred_cipher)] async fn can_create_from_valid_configuration(mut case: TestContext) { @@ -95,7 +95,7 @@ mod tests { let config_err = SessionConfig::builder() .database(db) .client_id("".into()) - .ciphersuites([MlsCiphersuite::default()]) + .ciphersuites([Ciphersuite::default()]) .build() .validate() .unwrap_err(); @@ -148,6 +148,8 @@ mod tests { async fn can_2_phase_init_central(mut case: TestContext) { let db = case.create_persistent_db().await; Box::pin(async move { + use crate::ClientId; + let x509_test_chain = X509TestChain::init_empty(case.signature_scheme()); let configuration = SessionConfig::builder() .database(db) @@ -164,12 +166,13 @@ mod tests { assert!(!context.session().await.unwrap().is_ready().await); // phase 2: init mls_client - let client_id = "alice"; + let client_id = ClientId::from("alice"); let identifier = match case.credential_type { - MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.into()), - MlsCredentialType::X509 => { - CertificateBundle::rand_identifier(client_id, &[x509_test_chain.find_local_intermediate_ca()]) + CredentialType::Basic => ClientIdentifier::Basic(client_id.into()), + CredentialType::X509 => { + CertificateBundle::rand_identifier(&client_id, &[x509_test_chain.find_local_intermediate_ca()]) } + CredentialType::Unknown(_) => panic!("unknown credential types are unsupported"), }; context.mls_init(identifier, &[case.ciphersuite()]).await.unwrap(); assert!(context.session().await.unwrap().is_ready().await); diff --git a/crypto/src/mls/session/config.rs b/crypto/src/mls/session/config.rs index 24c6c87710..692f630fa3 100644 --- a/crypto/src/mls/session/config.rs +++ b/crypto/src/mls/session/config.rs @@ -4,7 +4,7 @@ use typed_builder::TypedBuilder; use crate::{ ClientId, mls::{ - ciphersuite::MlsCiphersuite, + ciphersuite::Ciphersuite, error::{Error, Result}, }, }; @@ -22,8 +22,8 @@ pub struct SessionConfig { #[builder(default, setter(strip_option(fallback = client_id_opt)))] pub client_id: Option, /// All supported ciphersuites in this session - #[builder(default, setter(transform = |iter: impl IntoIterator| iter.into_iter().collect()))] - pub ciphersuites: Vec, + #[builder(default, setter(transform = |iter: impl IntoIterator| iter.into_iter().collect()))] + pub ciphersuites: Vec, } /// Validated configuration parameters for [Session][crate::mls::session::Session]. @@ -33,7 +33,7 @@ pub struct SessionConfig { pub struct ValidatedSessionConfig { pub(super) database: Database, pub(super) client_id: Option, - pub(super) ciphersuites: Vec, + pub(super) ciphersuites: Vec, } impl SessionConfig { diff --git a/crypto/src/mls/session/credential.rs b/crypto/src/mls/session/credential.rs new file mode 100644 index 0000000000..44cf240ebd --- /dev/null +++ b/crypto/src/mls/session/credential.rs @@ -0,0 +1,142 @@ +use openmls_traits::OpenMlsCryptoProvider as _; + +use super::{Error, Result}; +use crate::{CredentialFindFilters, CredentialRef, MlsConversation, RecursiveError, Session}; + +impl Session { + /// Find all credentials which match the specified conditions. + /// + /// If no filters are set, this is equivalent to [`get_credentials`][Self::get_credentials]. + /// + /// This is a convenience method entirely equivalent to [CredentialRef::find]; + /// the only difference is that it automatically includes the appropriate + /// [`Database`][core_crypto_keystore::Database] reference. + pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result> { + let database = self.crypto_provider.keystore(); + CredentialRef::find(&database, find_filters) + .await + .map_err(RecursiveError::mls_credential_ref("finding credentials")) + .map_err(Into::into) + } + + /// Get all credentials + /// + /// This is a convenience method entirely equivalent to [CredentialRef::get_all]; + /// the only difference is that it automatically includes the appropriate + /// [`Database`][core_crypto_keystore::Database] reference. + pub async fn get_credentials(&self) -> Result> { + let database = self.crypto_provider.keystore(); + CredentialRef::get_all(&database) + .await + .map_err(RecursiveError::mls_credential_ref("getting all credentials")) + .map_err(Into::into) + } + + /// Add a credential to the identities of this session. + /// + /// Note that this accepts a [`CredentialRef`], _not_ a raw [`Credential`][crate::Credential]. + /// This is because a `CredentialRef` serves as proof of persistence. Only credentials + /// which have been persisted are eligible to be included in a session. + pub async fn add_credential(&self, credential_ref: &CredentialRef) -> Result<()> { + if *credential_ref.client_id() != self.id().await? { + return Err(Error::WrongCredential); + } + + // The primary key situation of `Credential` is a bad joke. + // We have no idea how many credentials might be attached to a particular ref, or even + // how they may be related. + // + // Happily, our identities structure has set semantics, so let's lean (heavily) on that. + + // `.load` allocates, but also sorts by `earliest_validity`, which we want + let credentials = + credential_ref + .load(&self.crypto_provider.keystore()) + .await + .map_err(RecursiveError::mls_credential_ref( + "loading all matching credentials in `add_credential`", + ))?; + + let mut inner = self.inner.write().await; + let inner = inner.as_mut().ok_or(Error::MlsNotInitialized)?; + + for credential in credentials { + inner + .identities + .push_credential(credential.signature_key_pair.signature_scheme(), credential) + .await?; + } + + Ok(()) + } + + /// Remove a credential from the identities of this session. + /// + /// First checks that the credential is not used in any conversation. + /// Removes both the credential itself and also any key packages which were generated from it. + pub async fn remove_credential(&self, credential_ref: &CredentialRef) -> Result<()> { + // setup + if *credential_ref.client_id() != self.id().await? { + return Err(Error::WrongCredential); + } + + let database = self.crypto_provider.keystore(); + + let credentials = credential_ref + .load(&database) + .await + .map_err(RecursiveError::mls_credential_ref( + "loading all credentials from ref to remove from session identities", + ))?; + + // in a perfect world, we'd pre-cache the mls credentials in a set structure of some sort for faster querying. + // unfortunately, `MlsCredential` is `!Hash` and `!Ord`, so both the standard sets are out. + // so whatever, linear scan over the credentials every time will have to do. + + // ensure this credential is not in use by any conversation + for (conversation_id, conversation) in + MlsConversation::load_all(&database) + .await + .map_err(RecursiveError::mls_conversation( + "loading all conversations to check if the credential to be removed is present", + ))? + { + let converation_credential = conversation + .own_mls_credential() + .map_err(RecursiveError::mls_conversation("geting conversation credential"))?; + if credentials + .iter() + .any(|credential| credential.mls_credential() == converation_credential) + { + return Err(Error::CredentialStillInUse(conversation_id)); + } + } + + // remove any key packages generated by this credential + let keypackages = self.find_all_keypackages(&self.crypto_provider.keystore()).await?; + let keypackages_from_this_credential = keypackages.iter().filter_map(|(_stored_key_package, key_package)| { + credentials + .iter() + .any(|credential| key_package.leaf_node().credential() == credential.mls_credential()) + // if computing the hash reference fails, we will just not delete the key package + .then(|| key_package.hash_ref(self.crypto_provider.crypto()).ok()).flatten() + }); + self.prune_keypackages(&self.crypto_provider, keypackages_from_this_credential) + .await?; + + // remove all credentials associated with this ref + // do this last so we only remove the actual credential after the keypackages are all gone, + // and keep the lock open as briefly as possible + let mut inner = self.inner.write().await; + let inner = inner.as_mut().ok_or(Error::MlsNotInitialized)?; + for credential in credentials { + inner + .identities + .remove(credential.mls_credential()) + .await + .map_err(RecursiveError::mls_client("removing credential"))?; + } + + Ok(()) + } +} diff --git a/crypto/src/mls/session/e2e_identity.rs b/crypto/src/mls/session/e2e_identity.rs index 5f367334c9..6e4ba54ffd 100644 --- a/crypto/src/mls/session/e2e_identity.rs +++ b/crypto/src/mls/session/e2e_identity.rs @@ -6,7 +6,7 @@ use openmls_traits::OpenMlsCryptoProvider as _; use wire_e2e_identity::prelude::WireIdentityReader as _; use super::{Error, Result, Session}; -use crate::{E2eiConversationState, MlsCiphersuite, MlsCredentialType, MlsError, mls::session::CredentialExt as _}; +use crate::{Ciphersuite, CredentialType, E2eiConversationState, MlsError, mls::session::CredentialExt as _}; impl Session { /// Returns whether the E2EI PKI environment is setup (i.e. Root CA, Intermediates, CRLs) @@ -17,11 +17,11 @@ impl Session { /// Returns true when end-to-end-identity is enabled for the given SignatureScheme pub async fn e2ei_is_enabled(&self, signature_scheme: SignatureScheme) -> Result { let x509_result = self - .find_most_recent_credential_bundle(signature_scheme, MlsCredentialType::X509) + .find_most_recent_credential(signature_scheme, CredentialType::X509) .await; match x509_result { - Err(Error::CredentialNotFound(MlsCredentialType::X509)) => { - self.find_most_recent_credential_bundle(signature_scheme, MlsCredentialType::Basic) + Err(Error::CredentialNotFound(CredentialType::X509)) => { + self.find_most_recent_credential(signature_scheme, CredentialType::Basic) .await?; Ok(false) } @@ -52,7 +52,7 @@ impl Session { Ok(Self::compute_conversation_state( cs, credentials, - MlsCredentialType::X509, + CredentialType::X509, self.crypto_provider.authentication_service().borrow().await.as_ref(), ) .await) @@ -63,7 +63,7 @@ impl Session { pub async fn get_credential_in_use( &self, group_info: VerifiableGroupInfo, - credential_type: MlsCredentialType, + credential_type: CredentialType, ) -> Result { let cs = group_info.ciphersuite().into(); // Not verifying the supplied the GroupInfo here could let attackers lure the clients about @@ -83,9 +83,9 @@ impl Session { .await } pub(crate) async fn get_credential_in_use_in_ratchet_tree( - ciphersuite: MlsCiphersuite, + ciphersuite: Ciphersuite, ratchet_tree: RatchetTree, - credential_type: MlsCredentialType, + credential_type: CredentialType, env: Option<&wire_e2e_identity::prelude::x509::revocation::PkiEnvironment>, ) -> Result { let credentials = ratchet_tree.iter().filter_map(|n| match n { @@ -98,9 +98,9 @@ impl Session { /// _credential_type will be used in the future to get the usage of VC Credentials, even Basics one. /// Right now though, we do not need anything other than X509 so let's keep things simple. pub(crate) async fn compute_conversation_state<'a>( - ciphersuite: MlsCiphersuite, + ciphersuite: Ciphersuite, credentials: impl Iterator, - _credential_type: MlsCredentialType, + _credential_type: CredentialType, env: Option<&wire_e2e_identity::prelude::x509::revocation::PkiEnvironment>, ) -> E2eiConversationState { let mut is_e2ei = false; diff --git a/crypto/src/mls/session/error.rs b/crypto/src/mls/session/error.rs index 73146c598c..d035b32fab 100644 --- a/crypto/src/mls/session/error.rs +++ b/crypto/src/mls/session/error.rs @@ -3,6 +3,8 @@ // We allow missing documentation in the error module because the types are generally self-descriptive. #![allow(missing_docs)] +use crate::ConversationId; + pub(crate) type Result = core::result::Result; #[derive(Debug, thiserror::Error)] @@ -11,12 +13,12 @@ pub enum Error { InvalidUserId, #[error("X509 certificate bundle set was empty")] NoX509CertificateBundle, - #[error("Tried to insert an already existing CredentialBundle")] - CredentialBundleConflict, + #[error("Tried to insert an already existing Credential")] + CredentialConflict, #[error("A MLS operation was requested but MLS hasn't been initialized on this instance")] MlsNotInitialized, #[error("A Credential of type {0:?} was not found locally which is very likely an implementation error")] - CredentialNotFound(crate::MlsCredentialType), + CredentialNotFound(crate::CredentialType), #[error("supplied signature scheme was not valid")] InvalidSignatureScheme, /// The keystore has no knowledge of such client; this shouldn't happen as Client::init is failsafe (find-else-create) @@ -31,6 +33,10 @@ pub enum Error { IdentityAlreadyPresent, #[error("The supplied credential does not match the id or signature schemes provided")] WrongCredential, + #[error("Credentials of type {0} are unknown")] + UnknownCredential(u16), + #[error("this credential is still in use by the conversation with id \"{}\"", hex::encode(.0))] + CredentialStillInUse(ConversationId), #[error("An EpochObserver has already been registered; reregistration is not possible")] EpochObserverAlreadyExists, #[error("An HistoryHandler has already been registered; reregistration is not possible")] diff --git a/crypto/src/mls/session/id.rs b/crypto/src/mls/session/id.rs index ba2ff49849..303f4586de 100644 --- a/crypto/src/mls/session/id.rs +++ b/crypto/src/mls/session/id.rs @@ -1,3 +1,9 @@ +use std::{ + borrow::{Borrow, Cow}, + fmt, + ops::Deref, +}; + use super::error::Error; /// A Client identifier @@ -6,20 +12,23 @@ use super::error::Error; /// mobile, etc. Users can have multiple clients. /// More information [here](https://messaginglayersecurity.rocks/mls-architecture/draft-ietf-mls-architecture.html#name-group-members-and-clients) #[derive( - core_crypto_macros::Debug, Clone, PartialEq, Eq, Hash, derive_more::Deref, serde::Serialize, serde::Deserialize, + core_crypto_macros::Debug, + Clone, + Eq, + PartialOrd, + Ord, + Hash, + derive_more::From, + derive_more::Into, + serde::Serialize, + serde::Deserialize, )] #[sensitive] pub struct ClientId(pub(crate) Vec); -impl From<&[u8]> for ClientId { - fn from(value: &[u8]) -> Self { - Self(value.into()) - } -} - -impl From> for ClientId { - fn from(value: Vec) -> Self { - Self(value) +impl ClientId { + pub(crate) fn into_inner(self) -> Vec { + self.0 } } @@ -35,17 +44,29 @@ impl From for Box<[u8]> { } } -#[cfg(test)] -impl From<&str> for ClientId { - fn from(value: &str) -> Self { - Self(value.as_bytes().into()) +impl Deref for ClientId { + type Target = ClientIdRef; + + fn deref(&self) -> &Self::Target { + ClientIdRef::new(&self.0) } } -#[allow(clippy::from_over_into)] -impl Into> for ClientId { - fn into(self) -> Vec { - self.0 +impl AsRef<[u8]> for ClientId { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl AsRef for ClientId { + fn as_ref(&self) -> &ClientIdRef { + ClientIdRef::new(&self.0) + } +} + +impl From for Cow<'_, [u8]> { + fn from(value: ClientId) -> Self { + Cow::Owned(value.0) } } @@ -65,6 +86,163 @@ impl std::str::FromStr for ClientId { } } +impl PartialEq for ClientId +where + ClientIdRef: PartialEq, +{ + fn eq(&self, other: &T) -> bool { + (**self).eq(other) + } +} + +#[cfg(test)] +impl From<&str> for ClientId { + fn from(value: &str) -> Self { + Self(value.as_bytes().into()) + } +} + +/// Reference to a [`ClientId`]. +/// +/// This type is `!Sized` and is only ever seen as a reference, like `str` or `[u8]`. +// +// pattern from https://stackoverflow.com/a/64990850 +#[repr(transparent)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, derive_more::Deref)] +pub struct ClientIdRef([u8]); + +impl ClientIdRef { + /// Creates a `ClientId` Ref, needed to implement `Borrow` for `T` + pub fn new(bytes: &Bytes) -> &ClientIdRef + where + Bytes: AsRef<[u8]> + ?Sized, + { + // safety: because of `repr(transparent)` we know that `ClientIdRef` has a memory layout + // identical to `[u8]`, so we can perform this cast + unsafe { &*(bytes.as_ref() as *const [u8] as *const ClientIdRef) } + } + + /// View this reference as a byte slice + pub fn as_slice(&self) -> &[u8] { + self.as_ref() + } +} + +impl<'a> From<&'a [u8]> for &'a ClientIdRef { + fn from(value: &'a [u8]) -> Self { + ClientIdRef::new(value) + } +} + +impl<'a> From<&'a Vec> for &'a ClientIdRef { + fn from(value: &'a Vec) -> Self { + ClientIdRef::new(value.as_slice()) + } +} + +impl Borrow for ClientId { + fn borrow(&self) -> &ClientIdRef { + ClientIdRef::new(&self.0) + } +} + +impl Borrow for &'_ ClientId { + fn borrow(&self) -> &ClientIdRef { + ClientIdRef::new(&*self.0) + } +} + +impl ToOwned for ClientIdRef { + type Owned = ClientId; + + fn to_owned(&self) -> Self::Owned { + ClientId(self.0.to_owned()) + } +} + +impl AsRef<[u8]> for ClientIdRef { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl<'a> From<&'a ClientIdRef> for Cow<'a, [u8]> { + fn from(value: &'a ClientIdRef) -> Self { + Cow::Borrowed(value.as_ref()) + } +} + +impl PartialEq for ClientIdRef { + fn eq(&self, other: &ClientId) -> bool { + &self.0 == other.as_slice() + } +} + +impl PartialEq<[u8]> for ClientIdRef { + fn eq(&self, other: &[u8]) -> bool { + &self.0 == other + } +} + +impl PartialEq<&'_ [u8]> for ClientIdRef { + fn eq(&self, other: &&'_ [u8]) -> bool { + &self.0 == *other + } +} + +macro_rules! impl_eq { + ($( $t:ty => |$self:ident, $other:ident| $impl:expr ; )+) => { + $( + impl PartialEq<$t> for ClientIdRef { + fn eq(&self, other: &$t) -> bool { + let $self = self; + let $other = other; + $impl + } + } + + impl PartialEq for $t { + fn eq(&self, other: &ClientIdRef) -> bool { + other.eq(self) + } + } + + impl PartialEq<$t> for &'_ ClientIdRef { + fn eq(&self, other: &$t) -> bool { + let $self = self; + let $other = other; + $impl + } + } + + impl PartialEq<&'_ ClientIdRef> for $t { + fn eq(&self, other: &&'_ ClientIdRef) -> bool { + other.eq(self) + } + } + )+ + }; +} + +impl_eq!( + Vec => |me, other| me.0.eq(other.as_slice()); + Cow<'_, ClientIdRef> => |me, other| me.eq(&other.as_slice()); +); + +// we can't use `core_crypto_macros::Debug` to generate this because `ClientIdRef: !Sized`, +// and the `log` crate maintainers did not explicitly opt-in to allowing `!Sized` in their +// `Value::from_debug` impl, even though it might make sense to. +// +// this has the consequence that we can't natively log a `ClientIdRef` as a value; +// if we want to, we have to do `id_ref.to_owned()`. Which might be ok. +impl fmt::Debug for ClientIdRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("ClientIdRef") + .field(&obfuscate::Obfuscated::from(&self.0)) + .finish() + } +} + #[cfg(test)] impl ClientId { pub(crate) fn to_user_id(&self) -> String { diff --git a/crypto/src/mls/session/identifier.rs b/crypto/src/mls/session/identifier.rs index 7f23145482..3c7d0d66a7 100644 --- a/crypto/src/mls/session/identifier.rs +++ b/crypto/src/mls/session/identifier.rs @@ -1,17 +1,14 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; -use mls_crypto_provider::MlsCryptoProvider; +use openmls::prelude::CredentialType; use openmls_traits::types::SignatureScheme; -use super::{ - CredentialBundle, - error::{Error, Result}, -}; -use crate::{CertificateBundle, ClientId, RecursiveError, Session}; +use super::error::{Error, Result}; +use crate::{CertificateBundle, ClientId, RecursiveError, mls::session::id::ClientIdRef}; /// Used by consumers to initializes a MLS client. Encompasses all the client types available. /// Could be enriched later with Verifiable Presentations. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, derive_more::From)] pub enum ClientIdentifier { /// Basic keypair Basic(ClientId), @@ -22,7 +19,7 @@ pub enum ClientIdentifier { impl ClientIdentifier { /// Extract the unique [ClientId] from an identifier. Use with parsimony as, in case of a x509 /// certificate this leads to parsing the certificate - pub fn get_id(&self) -> Result> { + pub fn get_id(&self) -> Result> { match self { ClientIdentifier::Basic(id) => Ok(std::borrow::Cow::Borrowed(id)), ClientIdentifier::X509(certs) => { @@ -38,35 +35,11 @@ impl ClientIdentifier { } } - /// Generate a new CredentialBundle (Credential + KeyPair) for each ciphersuite. - /// This method does not persist them in the keystore ! - pub fn generate_credential_bundles( - self, - backend: &MlsCryptoProvider, - signature_schemes: HashSet, - ) -> Result> { + /// The credential type for this identifier + pub fn credential_type(&self) -> CredentialType { match self { - ClientIdentifier::Basic(id) => signature_schemes.iter().try_fold( - Vec::with_capacity(signature_schemes.len()), - |mut acc, &sc| -> Result<_> { - let cb = Session::new_basic_credential_bundle(&id, sc, backend)?; - acc.push((sc, id.clone(), cb)); - Ok(acc) - }, - ), - ClientIdentifier::X509(certs) => { - let cap = certs.len(); - certs - .into_iter() - .try_fold(Vec::with_capacity(cap), |mut acc, (sc, cert)| -> Result<_> { - let id = cert - .get_client_id() - .map_err(RecursiveError::mls_credential("getting client id"))?; - let cb = Session::new_x509_credential_bundle(cert)?; - acc.push((sc, id, cb)); - Ok(acc) - }) - } + ClientIdentifier::Basic(_) => CredentialType::Basic, + ClientIdentifier::X509(_) => CredentialType::X509, } } } diff --git a/crypto/src/mls/session/identities.rs b/crypto/src/mls/session/identities.rs index 0f21b06f0c..89ea9dc88b 100644 --- a/crypto/src/mls/session/identities.rs +++ b/crypto/src/mls/session/identities.rs @@ -1,72 +1,69 @@ use std::{collections::HashMap, ops::Deref, sync::Arc}; -use openmls::prelude::{Credential, SignaturePublicKey}; +use openmls::prelude::{Credential as MlsCredential, SignaturePublicKey}; use openmls_traits::types::SignatureScheme; use crate::{ - Session, - mls::{ - credential::{CredentialBundle, typ::MlsCredentialType}, - session::{ - SessionInner, - error::{Error, Result}, - }, + Credential, CredentialType, Session, + mls::session::{ + SessionInner, + error::{Error, Result}, }, }; /// In memory Map of a Session's identities: one per SignatureScheme. -/// We need `indexmap::IndexSet` because each `CredentialBundle` has to be unique and insertion +/// We need `indexmap::IndexSet` because each `Credential` has to be unique at insertion /// order matters in order to keep values sorted by time `created_at` so that we can identify most recent ones. /// -/// We keep each credential bundle inside an arc to avoid cloning them, as X509 credentials can get quite large. +/// We keep each credential inside an arc to avoid cloning them, as X509 credentials can get quite large. #[derive(Debug, Clone)] -pub(crate) struct Identities(HashMap>>); +pub(crate) struct Identities(HashMap>>); impl Identities { pub(crate) fn new(capacity: usize) -> Self { Self(HashMap::with_capacity(capacity)) } - pub(crate) async fn find_credential_bundle_by_public_key( + pub(crate) async fn find_credential_by_public_key( &self, sc: SignatureScheme, - ct: MlsCredentialType, + ct: CredentialType, pk: &SignaturePublicKey, - ) -> Option> { + ) -> Option> { self.0 .get(&sc)? .iter() .find(|c| { - let ct_match = ct == c.credential.credential_type().into(); - let pk_match = c.signature_key.public() == pk.as_slice(); + let ct_match = ct == c.mls_credential.credential_type(); + let pk_match = c.signature_key_pair.public() == pk.as_slice(); ct_match && pk_match }) .cloned() } - pub(crate) async fn find_most_recent_credential_bundle( + pub(crate) async fn find_most_recent_credential( &self, sc: SignatureScheme, - ct: MlsCredentialType, - ) -> Option> { + ct: CredentialType, + ) -> Option> { self.0 .get(&sc)? .iter() - .rfind(|c| ct == c.credential.credential_type().into()) + .rfind(|c| ct == c.mls_credential.credential_type()) .cloned() } /// Having `cb` requiring ownership kinda forces the caller to first persist it in the keystore and /// only then store it in this in-memory map - pub(crate) async fn push_credential_bundle(&mut self, sc: SignatureScheme, cb: CredentialBundle) -> Result<()> { - // this would mean we have messed something up and that we do no init this CredentialBundle from a keypair just inserted in the keystore - debug_assert_ne!(cb.created_at, 0); + pub(crate) async fn push_credential(&mut self, sc: SignatureScheme, cb: Credential) -> Result<()> { + // this would mean we have messed something up and that we do no init this Credential from a keypair just inserted in the keystore + debug_assert_ne!(cb.earliest_validity, 0); match self.0.get_mut(&sc) { Some(cbs) => { let already_exists = !cbs.insert(Arc::new(cb)); if already_exists { - return Err(Error::CredentialBundleConflict); + return Err(Error::CredentialConflict); } } None => { @@ -76,43 +73,43 @@ impl Identities { Ok(()) } - pub(crate) async fn remove(&mut self, credential: &Credential) -> Result<()> { + pub(crate) async fn remove(&mut self, credential: &MlsCredential) -> Result<()> { self.0.iter_mut().for_each(|(_, cbs)| { - cbs.retain(|c| c.credential() != credential); + cbs.retain(|c| c.mls_credential() != credential); }); Ok(()) } - pub(crate) fn iter(&self) -> impl Iterator)> + '_ { + pub(crate) fn iter(&self) -> impl Iterator)> + '_ { self.0.iter().flat_map(|(sc, cb)| cb.iter().map(|c| (*sc, c.clone()))) } } impl Session { - pub(crate) async fn find_most_recent_credential_bundle( + pub(crate) async fn find_most_recent_credential( &self, sc: SignatureScheme, - ct: MlsCredentialType, - ) -> Result> { + ct: CredentialType, + ) -> Result> { match self.inner.read().await.deref() { None => Err(Error::MlsNotInitialized), Some(SessionInner { identities, .. }) => identities - .find_most_recent_credential_bundle(sc, ct) + .find_most_recent_credential(sc, ct) .await .ok_or(Error::CredentialNotFound(ct)), } } - pub(crate) async fn find_credential_bundle_by_public_key( + pub(crate) async fn find_credential_by_public_key( &self, sc: SignatureScheme, - ct: MlsCredentialType, + ct: CredentialType, pk: &SignaturePublicKey, - ) -> Result> { + ) -> Result> { match self.inner.read().await.deref() { None => Err(Error::MlsNotInitialized), Some(SessionInner { identities, .. }) => identities - .find_credential_bundle_by_public_key(sc, ct, pk) + .find_credential_by_public_key(sc, ct, pk) .await .ok_or(Error::CredentialNotFound(ct)), } @@ -142,16 +139,16 @@ mod tests { let [mut central] = case.sessions().await; Box::pin(async move { let cert = central.get_intermediate_ca().cloned(); - let old = central.new_credential_bundle(&case, cert.as_ref()).await; + let old = central.new_credential(&case, cert.as_ref()).await; // wait to make sure we're not in the same second smol::Timer::after(core::time::Duration::from_secs(1)).await; - let new = central.new_credential_bundle(&case, cert.as_ref()).await; + let new = central.new_credential(&case, cert.as_ref()).await; assert_ne!(old, new); let found = central - .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type) + .find_most_recent_credential(case.signature_scheme(), case.credential_type) .await .unwrap(); assert_eq!(found.as_ref(), &new); @@ -169,16 +166,16 @@ mod tests { let mut to_search = None; for i in 0..N { let cert = central.get_intermediate_ca().cloned(); - let cb = central.new_credential_bundle(&case, cert.as_ref()).await; + let cb = central.new_credential(&case, cert.as_ref()).await; if i == r { to_search = Some(cb.clone()); } } let to_search = to_search.unwrap(); - let pk = SignaturePublicKey::from(to_search.signature_key.public()); + let pk = SignaturePublicKey::from(to_search.signature_key_pair.public()); let client = central.transaction.session().await.unwrap(); let found = client - .find_credential_bundle_by_public_key(case.signature_scheme(), case.credential_type, &pk) + .find_credential_by_public_key(case.signature_scheme(), case.credential_type, &pk) .await .unwrap(); assert_eq!(&to_search, found.as_ref()); @@ -197,8 +194,8 @@ mod tests { let client = central.session().await; let prev_count = client.identities_count().await.unwrap(); let cert = central.get_intermediate_ca().cloned(); - // this calls 'push_credential_bundle' under the hood - central.new_credential_bundle(&case, cert.as_ref()).await; + // this calls 'push_credential' under the hood + central.new_credential(&case, cert.as_ref()).await; let next_count = client.identities_count().await.unwrap(); assert_eq!(next_count, prev_count + 1); }) @@ -210,7 +207,7 @@ mod tests { let [mut central] = case.sessions().await; Box::pin(async move { let cert = central.get_intermediate_ca().cloned(); - let cb = central.new_credential_bundle(&case, cert.as_ref()).await; + let cb = central.new_credential(&case, cert.as_ref()).await; let client = central.transaction.session().await.unwrap(); let push = client .save_identity( @@ -220,10 +217,7 @@ mod tests { cb, ) .await; - assert!(matches!( - push.unwrap_err(), - mls::session::Error::CredentialBundleConflict - )); + assert!(matches!(push.unwrap_err(), mls::session::Error::CredentialConflict)); }) .await } diff --git a/crypto/src/mls/session/key_package.rs b/crypto/src/mls/session/key_package.rs index 281c2e1a2a..5fe741a586 100644 --- a/crypto/src/mls/session/key_package.rs +++ b/crypto/src/mls/session/key_package.rs @@ -2,17 +2,19 @@ use std::collections::{HashMap, HashSet}; use core_crypto_keystore::{ connection::FetchFromDatabase, - entities::{EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage}, + entities::{EntityFindParams, StoredEncryptionKeyPair, StoredHpkePrivateKey, StoredKeypackage}, }; use mls_crypto_provider::{Database, MlsCryptoProvider}; -use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime}; +use openmls::prelude::{ + Credential as MlsCredential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime, +}; use openmls_traits::OpenMlsCryptoProvider; use tls_codec::{Deserialize, Serialize}; use super::{Error, Result}; use crate::{ - KeystoreError, MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType, MlsError, Session, - mls::{credential::CredentialBundle, session::SessionInner}, + Ciphersuite, Credential, CredentialType, KeystoreError, MlsConversationConfiguration, MlsError, Session, + mls::session::SessionInner, }; /// Default number of KeyPackages a client generates the first time it's created @@ -34,11 +36,11 @@ impl Session { /// /// # Errors /// KeyStore and OpenMls errors - pub async fn generate_one_keypackage_from_credential_bundle( + pub async fn generate_one_keypackage_from_credential( &self, backend: &MlsCryptoProvider, - cs: MlsCiphersuite, - cb: &CredentialBundle, + cs: Ciphersuite, + cb: &Credential, ) -> Result { let guard = self.inner.read().await; let SessionInner { @@ -54,10 +56,10 @@ impl Session { version: openmls::versions::ProtocolVersion::default(), }, backend, - &cb.signature_key, + &cb.signature_key_pair, CredentialWithKey { - credential: cb.credential.clone(), - signature_key: cb.signature_key.public().into(), + credential: cb.mls_credential.clone(), + signature_key: cb.signature_key_pair.public().into(), }, ) .await @@ -79,8 +81,8 @@ impl Session { pub async fn request_key_packages( &self, count: usize, - ciphersuite: MlsCiphersuite, - credential_type: MlsCredentialType, + ciphersuite: Ciphersuite, + credential_type: CredentialType, backend: &MlsCryptoProvider, ) -> Result> { // Auto-prune expired keypackages on request @@ -94,14 +96,14 @@ impl Session { .into_iter() // TODO: do this filtering in SQL when the schema is updated. Tracking issue: WPB-9599 .filter(|kp| - kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type) + kp.ciphersuite() == ciphersuite.0 && kp.leaf_node().credential().credential_type() == credential_type) .collect::>(); let kpb_count = existing_kps.len(); let mut kps = if count > kpb_count { let to_generate = count - kpb_count; let cb = self - .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type) + .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type) .await?; self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate) .await? @@ -118,15 +120,15 @@ impl Session { pub(crate) async fn generate_new_keypackages( &self, backend: &MlsCryptoProvider, - ciphersuite: MlsCiphersuite, - cb: &CredentialBundle, + ciphersuite: Ciphersuite, + cb: &Credential, count: usize, ) -> Result> { let mut kps = Vec::with_capacity(count); for _ in 0..count { let kp = self - .generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb) + .generate_one_keypackage_from_credential(backend, ciphersuite, cb) .await?; kps.push(kp); } @@ -138,10 +140,10 @@ impl Session { pub async fn valid_keypackages_count( &self, backend: &MlsCryptoProvider, - ciphersuite: MlsCiphersuite, - credential_type: MlsCredentialType, + ciphersuite: Ciphersuite, + credential_type: CredentialType, ) -> Result { - let kps: Vec = backend + let kps: Vec = backend .key_store() .find_all(EntityFindParams::default()) .await @@ -154,7 +156,7 @@ impl Session { // TODO: do this filtering in SQL when the schema is updated. Tracking issue: WPB-9599 .filter(|kp| { kp.as_ref() - .map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type) + .map(|b| b.ciphersuite() == ciphersuite.0 && b.leaf_node().credential().credential_type() == credential_type) .unwrap_or_default() }) { @@ -216,10 +218,7 @@ impl Session { let kp_ref = kp .hash_ref(backend.crypto()) .map_err(MlsError::wrap("computing keypackage hashref"))?; - grouped_kps - .entry(cred) - .and_modify(|kprfs| kprfs.push(kp_ref.clone())) - .or_insert(vec![kp_ref]); + grouped_kps.entry(cred).or_default().push(kp_ref); } for (credential, kps) in &grouped_kps { @@ -232,7 +231,7 @@ impl Session { .cred_delete_by_credential(credential.clone()) .await .map_err(KeystoreError::wrap("deleting credential"))?; - let credential = Credential::tls_deserialize(&mut credential.as_slice()) + let credential = MlsCredential::tls_deserialize(&mut credential.as_slice()) .map_err(Error::tls_deserialize("credential"))?; identities.remove(&credential).await?; } @@ -247,7 +246,7 @@ impl Session { /// * Signature KeyPairs & Credentials (use [Self::prune_keypackages_and_credential]) async fn _prune_keypackages<'a>( &self, - kps: &'a [(MlsKeyPackage, KeyPackage)], + kps: &'a [(StoredKeypackage, KeyPackage)], keystore: &Database, refs: impl IntoIterator, ) -> Result, Error> { @@ -276,15 +275,15 @@ impl Session { // note: we're cloning the iterator here, not the data for (kp, kp_ref) in kp_to_delete.clone() { keystore - .remove::(kp_ref.as_slice()) + .remove::(kp_ref.as_slice()) .await .map_err(KeystoreError::wrap("removing key package from keystore"))?; keystore - .remove::(kp.hpke_init_key().as_slice()) + .remove::(kp.hpke_init_key().as_slice()) .await .map_err(KeystoreError::wrap("removing private key from keystore"))?; keystore - .remove::(kp.leaf_node().encryption_key().as_slice()) + .remove::(kp.leaf_node().encryption_key().as_slice()) .await .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?; } @@ -292,8 +291,11 @@ impl Session { Ok(kp_to_delete.map(|(_, kpref)| kpref.as_slice()).collect()) } - async fn find_all_keypackages(&self, keystore: &Database) -> Result> { - let kps: Vec = keystore + pub(super) async fn find_all_keypackages( + &self, + keystore: &Database, + ) -> Result> { + let kps: Vec = keystore .find_all(EntityFindParams::default()) .await .map_err(KeystoreError::wrap("finding all keypackages"))?; @@ -396,7 +398,7 @@ mod tests { // Generate 5 Basic key packages first let _basic_key_packages = session_context .transaction - .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::Basic, 5) + .get_or_create_client_keypackages(cipher_suite, CredentialType::Basic, 5) .await .unwrap(); @@ -433,14 +435,16 @@ mod tests { // Request X509 key packages let x509_key_packages = session_context .transaction - .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::X509, 5) + .get_or_create_client_keypackages(cipher_suite, CredentialType::X509, 5) .await .unwrap(); // Verify that the key packages are X509 assert!( - x509_key_packages.iter().all(|kp| MlsCredentialType::X509 - == MlsCredentialType::from(kp.leaf_node().credential().credential_type())) + x509_key_packages + .iter() + .all(|kp| CredentialType::X509 + == CredentialType::from(kp.leaf_node().credential().credential_type())) ); }) .await diff --git a/crypto/src/mls/session/mod.rs b/crypto/src/mls/session/mod.rs index eea38e8761..308d74e754 100644 --- a/crypto/src/mls/session/mod.rs +++ b/crypto/src/mls/session/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod config; +mod credential; pub(crate) mod e2e_identity; mod epoch_observer; mod error; @@ -9,13 +10,12 @@ pub(crate) mod identities; pub(crate) mod key_package; pub(crate) mod user_id; -use std::{collections::HashSet, ops::Deref, sync::Arc}; +use std::{ops::Deref, sync::Arc}; use async_lock::RwLock; use core_crypto_keystore::{ CryptoKeystoreError, Database, - connection::FetchFromDatabase, - entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair}, + entities::{StoredCredential, StoredSignatureKeypair}, }; pub use epoch_observer::EpochObserver; pub(crate) use error::{Error, Result}; @@ -24,20 +24,17 @@ use identities::Identities; use key_package::KEYPACKAGE_DEFAULT_LIFETIME; use log::debug; use mls_crypto_provider::{EntropySeed, MlsCryptoProvider}; -use openmls::prelude::{Credential, CredentialType}; -use openmls_basic_credential::SignatureKeyPair; -use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme}; -use openmls_x509_credential::CertificateKeyPair; -use tls_codec::{Deserialize, Serialize}; +use openmls_traits::{OpenMlsCryptoProvider, types::SignatureScheme}; +use tls_codec::Serialize; use crate::{ - CertificateBundle, ClientId, ClientIdentifier, CoreCrypto, HistorySecret, KeystoreError, LeafError, MlsCiphersuite, - MlsCredentialType, MlsError, MlsTransport, RecursiveError, ValidatedSessionConfig, + CertificateBundle, Ciphersuite, ClientId, ClientIdRef, ClientIdentifier, CoreCrypto, CredentialRef, CredentialType, + HistorySecret, KeystoreError, LeafError, MlsError, MlsTransport, RecursiveError, ValidatedSessionConfig, group_store::GroupStore, mls::{ self, HasSessionAndCrypto, conversation::{ConversationIdRef, ImmutableConversation}, - credential::{CredentialBundle, ext::CredentialExt}, + credential::{Credential, credential_ref, ext::CredentialExt}, }, }; @@ -109,7 +106,7 @@ impl Session { // doing all subsequent actions inside a single transaction, though it forces us to clone // a few Arcs and locks. let session = Self { - crypto_provider: mls_backend.clone(), + crypto_provider: mls_backend, inner: Default::default(), transport: Arc::new(None.into()), epoch_observer: Arc::new(None.into()), @@ -124,7 +121,13 @@ impl Session { if let Some(id) = client_id { cc.mls - .init(ClientIdentifier::Basic(id), ciphersuites.as_slice(), &mls_backend) + .init( + ClientIdentifier::Basic(id), + &ciphersuites + .iter() + .map(|ciphersuite| ciphersuite.signature_algorithm()) + .collect::>(), + ) .await .map_err(RecursiveError::mls_client("initializing mls client"))? } @@ -148,57 +151,67 @@ impl Session { } /// Initializes the client. - /// If the client's cryptographic material is already stored in the keystore, it loads it - /// Otherwise, it is being created. /// - /// # Arguments - /// * `identifier` - client identifier ; either a [ClientId] or a x509 certificate chain - /// * `ciphersuites` - all ciphersuites this client is supposed to support - /// * `backend` - the KeyStore and crypto provider to read identities from - /// - /// # Errors - /// KeyStore and OpenMls errors can happen - pub async fn init( - &self, - identifier: ClientIdentifier, - ciphersuites: &[MlsCiphersuite], - backend: &MlsCryptoProvider, - ) -> Result<()> { + /// Loads any cryptographic material already present in the keystore, but does not create any. + /// If no credentials are present in the keystore, then one _must_ be created and added to the + /// session before it can be used. + pub async fn init(&self, identifier: ClientIdentifier, signature_schemes: &[SignatureScheme]) -> Result<()> { self.ensure_unready().await?; - let id = identifier.get_id()?; + let client_id = identifier.get_id()?.into_owned(); - let credentials = backend - .key_store() - .find_all::(EntityFindParams::default()) + let mut identities = Identities::new(signature_schemes.len()); + let cache = CredentialRef::load_cache(&self.crypto_provider.keystore()) .await - .map_err(KeystoreError::wrap("finding all mls credentials"))?; - - let credentials = credentials - .into_iter() - .filter(|mls_credential| mls_credential.id.as_slice() == id.as_slice()) - .map(|mls_credential| -> Result<_> { - let credential = Credential::tls_deserialize(&mut mls_credential.credential.as_slice()) - .map_err(Error::tls_deserialize("mls credential"))?; - Ok((credential, mls_credential.created_at)) - }) - .collect::>>()?; - - if credentials.is_empty() { - debug!(ciphersuites:? = ciphersuites; "Generating client"); - self.generate(identifier, backend, ciphersuites).await?; - } else { - let signature_schemes = ciphersuites - .iter() - .map(|cs| cs.signature_algorithm()) - .collect::>(); - let load_result = self.load(backend, id.as_ref(), credentials, signature_schemes).await; - if let Err(Error::ClientSignatureNotFound) = load_result { - debug!(ciphersuites:? = ciphersuites; "Client signature not found. Generating client"); - self.generate(identifier, backend, ciphersuites).await?; - } else { - load_result?; + .map_err(RecursiveError::mls_credential_ref( + "loading credential ref cache while initializing session", + ))?; + for signature_scheme in signature_schemes { + // This is a _speculative_ credential ref. If it doesn't exist in the DB, + // that's not a problem, it just means the user has not created / stored the credential + // prior to initializing this session. + let credential_ref = CredentialRef::new( + identifier + .get_id() + .map_err(RecursiveError::mls_client( + "getting id from identifier to make credential ref", + ))? + .into_owned(), + identifier.credential_type(), + *signature_scheme, + ); + + match credential_ref.load_with_cache(&cache).await { + Err(credential_ref::Error::CredentialNotFound) => { + // no worries, do nothing, it's fine + } + Err(err) => { + return Err(RecursiveError::mls_credential_ref( + "attempting to load credential refs in session init", + )(err) + .into()); + } + Ok(credential_result_iter) => { + for credential_result in credential_result_iter { + // if the credential _exists_ but we couldn't load it, that's worth failing for + let credential = credential_result + .map_err(RecursiveError::mls_credential_ref("loading credential in session init"))?; + identities + .push_credential(credential.signature_key_pair.signature_scheme(), credential) + .await + .map_err(RecursiveError::mls_client( + "pushing credential to identities when initializing session", + ))?; + } + } } - }; + } + + self.replace_inner(SessionInner { + id: client_id, + identities, + keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME, + }) + .await; Ok(()) } @@ -249,54 +262,14 @@ impl Session { /// * `credential_type` - of the credential to look for pub async fn public_key( &self, - ciphersuite: MlsCiphersuite, - credential_type: MlsCredentialType, + ciphersuite: Ciphersuite, + credential_type: CredentialType, ) -> crate::mls::Result> { let cb = self - .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type) + .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type) .await - .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?; - Ok(cb.signature_key.to_public_vec()) - } - - pub(crate) fn new_basic_credential_bundle( - id: &ClientId, - sc: SignatureScheme, - backend: &MlsCryptoProvider, - ) -> Result { - let (sk, pk) = backend - .crypto() - .signature_key_gen(sc) - .map_err(MlsError::wrap("generating a signature key"))?; - - let signature_key = SignatureKeyPair::from_raw(sc, sk, pk); - let credential = Credential::new_basic(id.to_vec()); - let cb = CredentialBundle { - credential, - signature_key, - created_at: 0, - }; - - Ok(cb) - } - - pub(crate) fn new_x509_credential_bundle(cert: CertificateBundle) -> Result { - let created_at = cert - .get_created_at() - .map_err(RecursiveError::mls_credential("getting credetntial created at"))?; - let (sk, ..) = cert.private_key.into_parts(); - let chain = cert.certificate_chain; - - let kp = CertificateKeyPair::new(sk, chain.clone()).map_err(MlsError::wrap("creating certificate key pair"))?; - - let credential = Credential::new_x509(chain).map_err(MlsError::wrap("creating x509 credential"))?; - - let cb = CredentialBundle { - credential, - signature_key: kp.0, - created_at, - }; - Ok(cb) + .map_err(RecursiveError::mls_client("finding most recent credential"))?; + Ok(cb.signature_key_pair.to_public_vec()) } /// Checks if a given conversation id exists locally @@ -339,117 +312,6 @@ impl Session { .map_err(Into::into) } - /// Generates a brand new client from scratch - pub(crate) async fn generate( - &self, - identifier: ClientIdentifier, - backend: &MlsCryptoProvider, - ciphersuites: &[MlsCiphersuite], - ) -> Result<()> { - self.ensure_unready().await?; - let id = identifier.get_id()?; - let signature_schemes = ciphersuites - .iter() - .map(|cs| cs.signature_algorithm()) - .collect::>(); - self.replace_inner(SessionInner { - id: id.into_owned(), - identities: Identities::new(signature_schemes.len()), - keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME, - }) - .await; - - let identities = identifier.generate_credential_bundles(backend, signature_schemes)?; - - for (sc, id, cb) in identities { - self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?; - } - - Ok(()) - } - - /// Loads the client from the keystore. - pub(crate) async fn load( - &self, - backend: &MlsCryptoProvider, - id: &ClientId, - mut credentials: Vec<(Credential, u64)>, - signature_schemes: HashSet, - ) -> Result<()> { - self.ensure_unready().await?; - let mut identities = Identities::new(signature_schemes.len()); - - // ensures we load credentials in chronological order - credentials.sort_by_key(|(_, timestamp)| *timestamp); - - let stored_signature_keypairs = backend - .key_store() - .find_all::(EntityFindParams::default()) - .await - .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?; - - for signature_scheme in signature_schemes { - let signature_keypair = stored_signature_keypairs - .iter() - .find(|skp| skp.signature_scheme == (signature_scheme as u16)); - - let signature_key = if let Some(kp) = signature_keypair { - SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice()) - .map_err(Error::tls_deserialize("signature keypair"))? - } else { - let (private_key, public_key) = backend - .crypto() - .signature_key_gen(signature_scheme) - .map_err(MlsError::wrap("generating signature key"))?; - let keypair = SignatureKeyPair::from_raw(signature_scheme, private_key, public_key.clone()); - let raw_keypair = keypair - .tls_serialize_detached() - .map_err(Error::tls_serialize("raw keypair"))?; - let store_keypair = - MlsSignatureKeyPair::new(signature_scheme, public_key, raw_keypair, id.as_slice().into()); - backend - .key_store() - .save(store_keypair.clone()) - .await - .map_err(KeystoreError::wrap("storing keypairs in keystore"))?; - SignatureKeyPair::tls_deserialize(&mut store_keypair.keypair.as_slice()) - .map_err(Error::tls_deserialize("signature keypair"))? - }; - - for (credential, created_at) in &credentials { - match credential.mls_credential() { - openmls::prelude::MlsCredentialType::Basic(_) => { - if id.as_slice() != credential.identity() { - return Err(Error::WrongCredential); - } - } - openmls::prelude::MlsCredentialType::X509(cert) => { - let spk = cert - .extract_public_key() - .map_err(RecursiveError::mls_credential("extracting public key"))? - .ok_or(LeafError::InternalMlsError)?; - if signature_key.public() != spk { - return Err(Error::WrongCredential); - } - } - }; - let cb = CredentialBundle { - credential: credential.clone(), - signature_key: signature_key.clone(), - created_at: *created_at, - }; - identities.push_credential_bundle(signature_scheme, cb).await?; - } - } - self.replace_inner(SessionInner { - id: id.clone(), - identities, - keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME, - }) - .await; - Ok(()) - } - /// Restore from an external [`HistorySecret`]. pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> { self.ensure_unready().await?; @@ -475,10 +337,10 @@ impl Session { pub(crate) async fn save_identity( &self, keystore: &Database, - id: Option<&ClientId>, + id: Option<&ClientIdRef>, signature_scheme: SignatureScheme, - mut credential_bundle: CredentialBundle, - ) -> Result { + mut credential: Credential, + ) -> Result { let mut guard = self.inner.write().await; let SessionInner { id: existing_id, @@ -486,45 +348,43 @@ impl Session { .. } = guard.as_mut().ok_or(Error::MlsNotInitialized)?; - let id = id.unwrap_or(existing_id); + let id = id.unwrap_or(existing_id.as_ref()); - let credential = credential_bundle - .credential + let credential_data = credential + .mls_credential .tls_serialize_detached() - .map_err(Error::tls_serialize("credential bundle"))?; - let credential = MlsCredential { - id: id.clone().into(), - credential, + .map_err(Error::tls_serialize("credential"))?; + let stored_credential = StoredCredential { + id: id.as_slice().to_owned(), + credential: credential_data, created_at: 0, }; - let credential = keystore - .save(credential) + let stored_credential = keystore + .save(stored_credential) .await .map_err(KeystoreError::wrap("saving credential"))?; - let sign_kp = MlsSignatureKeyPair::new( + let sign_kp = StoredSignatureKeypair::new( signature_scheme, - credential_bundle.signature_key.to_public_vec(), - credential_bundle - .signature_key + credential.signature_key_pair.to_public_vec(), + credential + .signature_key_pair .tls_serialize_detached() .map_err(Error::tls_serialize("signature keypair"))?, - id.clone().into(), + id.as_slice().to_owned(), ); keystore.save(sign_kp).await.map_err(|e| match e { - CryptoKeystoreError::AlreadyExists(_) => Error::CredentialBundleConflict, + CryptoKeystoreError::AlreadyExists(_) => Error::CredentialConflict, _ => KeystoreError::wrap("saving mls signature key pair")(e).into(), })?; - // set the creation date of the signature keypair which is the same for the CredentialBundle - credential_bundle.created_at = credential.created_at; + // set the creation date of the signature keypair which is the same for the Credential + credential.earliest_validity = stored_credential.created_at; - identities - .push_credential_bundle(signature_scheme, credential_bundle.clone()) - .await?; + identities.push_credential(signature_scheme, credential.clone()).await?; - Ok(credential_bundle) + Ok(credential) } /// Retrieves the client's client id. This is free-form and not inspected. @@ -541,58 +401,57 @@ impl Session { None => false, Some(SessionInner { identities, .. }) => identities .iter() - .any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509), + .any(|(_, cred)| cred.mls_credential().credential_type() == CredentialType::X509), } } - pub(crate) async fn get_most_recent_or_create_credential_bundle( + pub(crate) async fn get_most_recent_or_create_credential( &self, backend: &MlsCryptoProvider, sc: SignatureScheme, - ct: MlsCredentialType, - ) -> Result> { + ct: CredentialType, + ) -> Result> { match ct { - MlsCredentialType::Basic => { - self.init_basic_credential_bundle_if_missing(backend, sc).await?; - self.find_most_recent_credential_bundle(sc, ct).await + CredentialType::Basic => { + self.init_basic_credential_if_missing(backend, sc).await?; + self.find_most_recent_credential(sc, ct).await } - MlsCredentialType::X509 => self - .find_most_recent_credential_bundle(sc, ct) - .await - .map_err(|e| match e { - Error::CredentialNotFound(_) => LeafError::E2eiEnrollmentNotDone.into(), - _ => e, - }), + CredentialType::X509 => self.find_most_recent_credential(sc, ct).await.map_err(|e| match e { + Error::CredentialNotFound(_) => LeafError::E2eiEnrollmentNotDone.into(), + _ => e, + }), + CredentialType::Unknown(n) => Err(Error::UnknownCredential(n)), } } - pub(crate) async fn init_basic_credential_bundle_if_missing( + pub(crate) async fn init_basic_credential_if_missing( &self, backend: &MlsCryptoProvider, sc: SignatureScheme, ) -> Result<()> { - let existing_cb = self - .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic) - .await; + let existing_cb = self.find_most_recent_credential(sc, CredentialType::Basic).await; if matches!(existing_cb, Err(Error::CredentialNotFound(_))) { let id = self.id().await?; - debug!(id:% = &id; "Initializing basic credential bundle"); - let cb = Self::new_basic_credential_bundle(&id, sc, backend)?; + debug!(id:% = &id; "Initializing basic credential"); + let cb = Credential::basic(sc, id, backend).map_err(RecursiveError::mls_credential( + "generating credential to replace missing", + ))?; self.save_identity(&backend.keystore(), None, sc, cb).await?; } Ok(()) } - pub(crate) async fn save_new_x509_credential_bundle( + pub(crate) async fn save_new_x509_credential( &self, keystore: &Database, sc: SignatureScheme, cb: CertificateBundle, - ) -> Result { + ) -> Result { let id = cb .get_client_id() .map_err(RecursiveError::mls_credential("getting client id"))?; - let cb = Self::new_x509_credential_bundle(cb)?; + let cb = + Credential::x509(cb).map_err(RecursiveError::mls_credential("generating new x509 credential to save"))?; self.save_identity(keystore, Some(&id), sc, cb).await } } @@ -618,15 +477,25 @@ mod tests { let user_uuid = uuid::Uuid::new_v4(); let rnd_id = rand::random::(); let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated()); - let identity = match case.credential_type { - MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()), - MlsCredentialType::X509 => { + let client_id = ClientId(client_id.into_bytes()); + + let mut credential; + match case.credential_type { + CredentialType::Basic => { + credential = Credential::basic(case.signature_scheme(), client_id, &self.crypto_provider).unwrap(); + } + CredentialType::X509 => { let signer = signer.expect("Missing intermediate CA"); - CertificateBundle::rand_identifier(&client_id, &[signer]) + let certs = CertificateBundle::rand_identifier_certs(&client_id, &[signer]); + let cert = certs.get(&signer.signature_scheme).unwrap(); + credential = Credential::x509(cert.to_owned()).unwrap(); } + CredentialType::Unknown(_) => panic!("unknown credential types are unsupported"), }; - let backend = self.crypto_provider.clone(); - self.generate(identity, &backend, &[case.ciphersuite()]).await?; + + let credential_ref = credential.save(&self.crypto_provider.keystore()).await.unwrap(); + self.add_credential(&credential_ref).await.unwrap(); + Ok(()) } @@ -643,30 +512,27 @@ mod tests { pub(crate) async fn generate_one_keypackage( &self, backend: &MlsCryptoProvider, - cs: MlsCiphersuite, - ct: MlsCredentialType, + cs: Ciphersuite, + ct: CredentialType, ) -> Result { - let cb = self - .find_most_recent_credential_bundle(cs.signature_algorithm(), ct) - .await?; - self.generate_one_keypackage_from_credential_bundle(backend, cs, &cb) - .await + let cb = self.find_most_recent_credential(cs.signature_algorithm(), ct).await?; + self.generate_one_keypackage_from_credential(backend, cs, &cb).await } /// Count the entities pub async fn count_entities(&self) -> EntitiesCount { let keystore = self.crypto_provider.keystore(); - let credential = keystore.count::().await.unwrap(); - let encryption_keypair = keystore.count::().await.unwrap(); - let epoch_encryption_keypair = keystore.count::().await.unwrap(); - let enrollment = keystore.count::().await.unwrap(); + let credential = keystore.count::().await.unwrap(); + let encryption_keypair = keystore.count::().await.unwrap(); + let epoch_encryption_keypair = keystore.count::().await.unwrap(); + let enrollment = keystore.count::().await.unwrap(); let group = keystore.count::().await.unwrap(); - let hpke_private_key = keystore.count::().await.unwrap(); - let key_package = keystore.count::().await.unwrap(); + let hpke_private_key = keystore.count::().await.unwrap(); + let key_package = keystore.count::().await.unwrap(); let pending_group = keystore.count::().await.unwrap(); let pending_messages = keystore.count::().await.unwrap(); - let psk_bundle = keystore.count::().await.unwrap(); - let signature_keypair = keystore.count::().await.unwrap(); + let psk_bundle = keystore.count::().await.unwrap(); + let signature_keypair = keystore.count::().await.unwrap(); EntitiesCount { credential, encryption_keypair, diff --git a/crypto/src/proteus.rs b/crypto/src/proteus.rs index 1de52dd5dd..073c09660b 100644 --- a/crypto/src/proteus.rs +++ b/crypto/src/proteus.rs @@ -589,7 +589,7 @@ mod tests { use super::*; use crate::{ - CertificateBundle, ClientIdentifier, MlsCredentialType, Session, SessionConfig, + CertificateBundle, ClientIdentifier, CredentialType, Session, SessionConfig, test_utils::{proteus_utils::*, x509::X509TestChain, *}, }; @@ -622,6 +622,8 @@ mod tests { #[apply(all_cred_cipher)] async fn cc_can_2_phase_init(case: TestContext) { + use crate::ClientId; + #[cfg(not(target_family = "wasm"))] let (path, db_file) = tmp_db_file(); #[cfg(target_family = "wasm")] @@ -645,12 +647,13 @@ mod tests { // proteus is initialized, prekeys can be generated assert!(transaction.proteus_new_prekey(1).await.is_ok()); // 👇 and so a unique 'client_id' can be fetched from wire-server - let client_id = "alice"; + let client_id = ClientId::from("alice"); let identifier = match case.credential_type { - MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.into()), - MlsCredentialType::X509 => { - CertificateBundle::rand_identifier(client_id, &[x509_test_chain.find_local_intermediate_ca()]) + CredentialType::Basic => ClientIdentifier::Basic(client_id.into()), + CredentialType::X509 => { + CertificateBundle::rand_identifier(&client_id, &[x509_test_chain.find_local_intermediate_ca()]) } + CredentialType::Unknown(_) => panic!("unknown credential types are unsupported"), }; transaction.mls_init(identifier, &[case.ciphersuite()]).await.unwrap(); // expect MLS to work diff --git a/crypto/src/test_utils/context.rs b/crypto/src/test_utils/context.rs index 840707ca8a..48eddca4be 100644 --- a/crypto/src/test_utils/context.rs +++ b/crypto/src/test_utils/context.rs @@ -3,12 +3,13 @@ use std::sync::Arc; use core_crypto_keystore::{ connection::FetchFromDatabase, entities::{ - EntityFindParams, MlsCredential, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage, MlsSignatureKeyPair, + EntityFindParams, StoredCredential, StoredEncryptionKeyPair, StoredHpkePrivateKey, StoredKeypackage, + StoredSignatureKeypair, }, }; use openmls::prelude::{ - Credential, CredentialWithKey, CryptoConfig, ExternalSender, HpkePublicKey, KeyPackage, KeyPackageIn, Lifetime, - SignaturePublicKey, + Credential as MlsCredential, CredentialWithKey, CryptoConfig, ExternalSender, HpkePublicKey, KeyPackage, + KeyPackageIn, Lifetime, SignaturePublicKey, }; use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme}; use tls_codec::Serialize; @@ -20,13 +21,13 @@ use super::{ test_conversation::operation_guard::{Commit, OperationGuard}, }; use crate::{ - CertificateBundle, CoreCrypto, MlsCiphersuite, MlsConversationConfiguration, MlsConversationDecryptMessage, - MlsCredentialType, RecursiveError, Session, WireIdentity, + CertificateBundle, Ciphersuite, CoreCrypto, CredentialType, MlsConversationConfiguration, + MlsConversationDecryptMessage, RecursiveError, WireIdentity, e2e_identity::{ device_status::DeviceStatus, id::{QualifiedE2eiClientId, WireQualifiedClientId}, }, - mls::credential::{CredentialBundle, ext::CredentialExt}, + mls::credential::{Credential, ext::CredentialExt}, test_utils::{SessionContext, TestContext, x509::X509Certificate}, }; @@ -50,7 +51,7 @@ impl SessionContext { pub async fn new_keypackage(&self, case: &TestContext, lifetime: Lifetime) -> KeyPackage { let cb = self - .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type) + .find_most_recent_credential(case.signature_scheme(), case.credential_type) .await .unwrap(); KeyPackage::builder() @@ -62,23 +63,23 @@ impl SessionContext { version: openmls::versions::ProtocolVersion::default(), }, &self.transaction.mls_provider().await.unwrap(), - &cb.signature_key, + &cb.signature_key_pair, CredentialWithKey { - credential: cb.credential.clone(), - signature_key: cb.signature_key.public().into(), + credential: cb.mls_credential.clone(), + signature_key: cb.signature_key_pair.public().into(), }, ) .await .unwrap() } - pub async fn count_key_package(&self, cs: MlsCiphersuite, ct: Option) -> usize { + pub async fn count_key_package(&self, cs: Ciphersuite, ct: Option) -> usize { self.transaction .mls_provider() .await .unwrap() .key_store() - .find_all::(EntityFindParams::default()) + .find_all::(EntityFindParams::default()) .await .unwrap() .into_iter() @@ -95,7 +96,7 @@ impl SessionContext { self.rand_key_package_of_type(case, case.credential_type).await } - pub async fn rand_key_package_of_type(&self, case: &TestContext, ct: MlsCredentialType) -> KeyPackageIn { + pub async fn rand_key_package_of_type(&self, case: &TestContext, ct: CredentialType) -> KeyPackageIn { let client = self.transaction.session().await.unwrap(); client .generate_one_keypackage(&self.transaction.mls_provider().await.unwrap(), case.ciphersuite(), ct) @@ -122,90 +123,87 @@ impl SessionContext { pub async fn client_signature_key(&self, case: &TestContext) -> SignaturePublicKey { let (sc, ct) = (case.signature_scheme(), case.credential_type); let client = self.session().await; - let cb = client.find_most_recent_credential_bundle(sc, ct).await.unwrap(); - SignaturePublicKey::from(cb.signature_key.public()) + let cb = client.find_most_recent_credential(sc, ct).await.unwrap(); + SignaturePublicKey::from(cb.signature_key_pair.public()) } pub async fn get_user_id(&self) -> String { WireQualifiedClientId::from(self.get_client_id().await).get_user_id() } - pub async fn new_credential_bundle( - &mut self, - case: &TestContext, - signer: Option<&X509Certificate>, - ) -> CredentialBundle { + pub async fn new_credential(&mut self, case: &TestContext, signer: Option<&X509Certificate>) -> Credential { let backend = &self.transaction.mls_provider().await.unwrap(); let transaction = &self.transaction.keystore().await.unwrap(); let client = self.session().await; let client_id = client.id().await.unwrap(); match case.credential_type { - MlsCredentialType::Basic => { - let cb = Session::new_basic_credential_bundle(&client_id, case.signature_scheme(), backend).unwrap(); + CredentialType::Basic => { + let cb = Credential::basic(case.signature_scheme(), client_id, backend).unwrap(); client .save_identity(&backend.keystore(), None, case.signature_scheme(), cb) .await .unwrap() } - MlsCredentialType::X509 => { + CredentialType::X509 => { let cert_bundle = CertificateBundle::rand(&client_id, signer.unwrap()); client - .save_new_x509_credential_bundle(transaction, case.signature_scheme(), cert_bundle) + .save_new_x509_credential(transaction, case.signature_scheme(), cert_bundle) .await .unwrap() } + CredentialType::Unknown(_) => panic!("unknown credential types are unsupported"), } } - pub async fn find_most_recent_credential_bundle( + pub async fn find_most_recent_credential( &self, sc: SignatureScheme, - ct: MlsCredentialType, - ) -> Option> { - self.session.find_most_recent_credential_bundle(sc, ct).await.ok() + ct: CredentialType, + ) -> Option> { + self.session.find_most_recent_credential(sc, ct).await.ok() } - pub async fn find_credential_bundle( + pub async fn find_credential( &self, sc: SignatureScheme, - ct: MlsCredentialType, + ct: CredentialType, pk: &SignaturePublicKey, - ) -> Option> { + ) -> Option> { self.session() .await - .find_credential_bundle_by_public_key(sc, ct, pk) + .find_credential_by_public_key(sc, ct, pk) .await .ok() } - pub async fn find_signature_keypair_from_keystore(&self, id: &[u8]) -> Option { + pub async fn find_signature_keypair_from_keystore(&self, id: &[u8]) -> Option { self.transaction .keystore() .await .unwrap() - .find::(id) + .find::(id) .await .unwrap() } - pub async fn find_hpke_private_key_from_keystore(&self, skp: &HpkePublicKey) -> Option { + pub async fn find_hpke_private_key_from_keystore(&self, skp: &HpkePublicKey) -> Option { self.transaction .keystore() .await .unwrap() - .find::(&skp.tls_serialize_detached().unwrap()) + .find::(&skp.tls_serialize_detached().unwrap()) .await .unwrap() } - pub async fn find_credential_from_keystore(&self, cb: &CredentialBundle) -> Option { - let credential = cb.credential.tls_serialize_detached().unwrap(); + pub async fn find_credential_from_keystore(&self, cb: &Credential) -> Option { + let credential = cb.mls_credential.tls_serialize_detached().unwrap(); self.transaction .keystore() .await .unwrap() - .find_all::(EntityFindParams::default()) + .find_all::(EntityFindParams::default()) .await .unwrap() .into_iter() @@ -217,7 +215,7 @@ impl SessionContext { .keystore() .await .unwrap() - .count::() + .count::() .await .unwrap() } @@ -227,7 +225,7 @@ impl SessionContext { .keystore() .await .unwrap() - .count::() + .count::() .await .unwrap() } @@ -237,7 +235,7 @@ impl SessionContext { .keystore() .await .unwrap() - .count::() + .count::() .await .unwrap() } @@ -248,12 +246,12 @@ impl SessionContext { handle: &str, display_name: &str, signer: &X509Certificate, - ) -> CredentialBundle { + ) -> Credential { let cid = QualifiedE2eiClientId::try_from(self.get_client_id().await.as_slice()).unwrap(); let new_cert = CertificateBundle::new(handle, display_name, Some(&cid), None, signer); let client = self.session().await; client - .save_new_x509_credential_bundle( + .save_new_x509_credential( &self.transaction.keystore().await.unwrap(), case.signature_scheme(), new_cert, @@ -265,8 +263,8 @@ impl SessionContext { pub(crate) async fn create_key_packages_and_update_credential_in_all_conversations<'a>( &self, all_conversations: Vec>, - cb: &CredentialBundle, - cipher_suite: MlsCiphersuite, + cb: &Credential, + cipher_suite: Ciphersuite, key_package_count: usize, ) -> Result> { let mut commits = Vec::with_capacity(all_conversations.len()); @@ -303,9 +301,9 @@ impl SessionContext { pub async fn verify_sender_identity(&self, case: &TestContext, decrypted: &MlsConversationDecryptMessage) { let (sc, ct) = (case.signature_scheme(), case.credential_type); let client = self.session().await; - let sender_cb = client.find_most_recent_credential_bundle(sc, ct).await.unwrap(); + let sender_cb = client.find_most_recent_credential(sc, ct).await.unwrap(); - if let openmls::prelude::MlsCredentialType::X509(certificate) = &sender_cb.credential().mls_credential() { + if let openmls::prelude::MlsCredentialType::X509(certificate) = &sender_cb.mls_credential().mls_credential() { let mls_identity = certificate.extract_identity(case.ciphersuite(), None).unwrap(); let mls_client_id = mls_identity.client_id.as_bytes(); @@ -371,7 +369,7 @@ impl SessionContext { let signature_key = SignaturePublicKey::from(pk); - let credential = Credential::new_basic(b"server".to_vec()); + let credential = MlsCredential::new_basic(b"server".to_vec()); ExternalSender::new(signature_key, credential) } diff --git a/crypto/src/test_utils/mod.rs b/crypto/src/test_utils/mod.rs index f25cc9f3d6..254bd3c360 100644 --- a/crypto/src/test_utils/mod.rs +++ b/crypto/src/test_utils/mod.rs @@ -32,7 +32,7 @@ use crate::{ test_utils::x509::{CertificateParams, X509TestChain, X509TestChainActorArg, X509TestChainArgs}, transaction_context::TransactionContext, }; -pub use crate::{ClientIdentifier, INITIAL_KEYING_MATERIAL_COUNT, MlsCredentialType}; +pub use crate::{ClientIdentifier, CredentialType, INITIAL_KEYING_MATERIAL_COUNT}; pub const GROUP_SAMPLE_SIZE: usize = 9; @@ -90,6 +90,7 @@ use crate::{RecursiveError::Test, ephemeral::HistorySecret, test_utils::TestErro pub struct SessionContext { pub transaction: TransactionContext, pub session: Session, + pub identifier: ClientIdentifier, mls_transport: Arc>>, x509_test_chain: Arc>, history_observer: Arc>>>, @@ -139,7 +140,7 @@ impl SessionContext { } transaction - .mls_init(identifier, &[context.cfg.ciphersuite]) + .mls_init(identifier.clone(), &[context.cfg.ciphersuite]) .await .map_err(RecursiveError::transaction("mls init"))?; session.provide_transport(transport.clone()).await; @@ -147,6 +148,7 @@ impl SessionContext { let result = Self { transaction, session, + identifier, mls_transport: Arc::new(RwLock::new(transport)), x509_test_chain: Arc::new(chain.cloned()), history_observer: Default::default(), @@ -170,6 +172,7 @@ impl SessionContext { Self { transaction, session, + identifier: todo!("how can we extract an x509 ClientIdentifier from a CC session?"), mls_transport: Arc::new(RwLock::new(transport)), x509_test_chain: Arc::new(chain.cloned()), history_observer: Default::default(), @@ -197,6 +200,7 @@ impl SessionContext { Self { transaction: context.clone(), session: cc.mls, + identifier: todo!("should client id even be part of validated session config?"), mls_transport: Arc::new(RwLock::new(transport.clone())), x509_test_chain: None.into(), history_observer: Default::default(), diff --git a/crypto/src/test_utils/test_context.rs b/crypto/src/test_utils/test_context.rs index f41606bcad..9332ac87cf 100644 --- a/crypto/src/test_utils/test_context.rs +++ b/crypto/src/test_utils/test_context.rs @@ -9,64 +9,62 @@ use super::{ init_x509_test_chain, tmp_db_file, x509::{CertificateParams, X509TestChain}, }; +pub use crate::{Ciphersuite, CredentialType, MlsConversationConfiguration, MlsCustomConfiguration, MlsWirePolicy}; use crate::{ ClientId, ConnectionType, Database, DatabaseKey, e2e_identity::id::{QualifiedE2eiClientId, WireQualifiedClientId}, test_utils::SessionContext, }; -pub use crate::{ - MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType, MlsCustomConfiguration, MlsWirePolicy, -}; #[template] #[rstest( case, case::basic_cs1(TestContext::new( - crate::MlsCredentialType::Basic, + crate::CredentialType::Basic, openmls::prelude::Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 )), case::cert_cs1(TestContext::new( - crate::MlsCredentialType::X509, + crate::CredentialType::X509, openmls::prelude::Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 )), #[cfg(feature = "test-all-cipher")] case::basic_cs2(TestContext::new( - crate::MlsCredentialType::Basic, + crate::CredentialType::Basic, openmls::prelude::Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 )), #[cfg(feature = "test-all-cipher")] case::cert_cs2(TestContext::new( - crate::MlsCredentialType::X509, + crate::CredentialType::X509, openmls::prelude::Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 )), #[cfg(feature = "test-all-cipher")] case::basic_cs3(TestContext::new( - crate::MlsCredentialType::Basic, + crate::CredentialType::Basic, openmls::prelude::Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 )), #[cfg(feature = "test-all-cipher")] case::cert_cs3(TestContext::new( - crate::MlsCredentialType::X509, + crate::CredentialType::X509, openmls::prelude::Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 )), #[cfg(feature = "test-all-cipher")] case::basic_cs5(TestContext::new( - crate::MlsCredentialType::Basic, + crate::CredentialType::Basic, openmls::prelude::Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521 )), #[cfg(feature = "test-all-cipher")] case::cert_cs5(TestContext::new( - crate::MlsCredentialType::X509, + crate::CredentialType::X509, openmls::prelude::Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521 )), #[cfg(feature = "test-all-cipher")] case::basic_cs7(TestContext::new( - crate::MlsCredentialType::Basic, + crate::CredentialType::Basic, openmls::prelude::Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384 )), #[cfg(feature = "test-all-cipher")] case::cert_cs7(TestContext::new( - crate::MlsCredentialType::X509, + crate::CredentialType::X509, openmls::prelude::Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384 )), case::pure_ciphertext(TestContext::default_cipher()), @@ -77,14 +75,14 @@ pub fn all_cred_cipher(case: TestContext) {} #[derive(Debug, Clone)] pub struct TestContext { - pub credential_type: MlsCredentialType, + pub credential_type: CredentialType, pub cfg: MlsConversationConfiguration, pub transport: Arc, pub db: Option<(Database, Option>)>, } impl TestContext { - pub fn new(credential_type: MlsCredentialType, cs: openmls::prelude::Ciphersuite) -> Self { + pub fn new(credential_type: CredentialType, cs: openmls::prelude::Ciphersuite) -> Self { Self { credential_type, cfg: MlsConversationConfiguration { @@ -95,7 +93,7 @@ impl TestContext { } } - pub fn ciphersuite(&self) -> MlsCiphersuite { + pub fn ciphersuite(&self) -> Ciphersuite { self.cfg.ciphersuite } @@ -118,7 +116,7 @@ impl TestContext { pub fn default_x509() -> Self { Self { - credential_type: MlsCredentialType::X509, + credential_type: CredentialType::X509, cfg: MlsConversationConfiguration::default(), transport: Arc::::default(), db: None, @@ -132,11 +130,11 @@ impl TestContext { } pub fn is_x509(&self) -> bool { - matches!(self.credential_type, MlsCredentialType::X509) + matches!(self.credential_type, CredentialType::X509) } pub fn is_basic(&self) -> bool { - matches!(self.credential_type, MlsCredentialType::Basic) + matches!(self.credential_type, CredentialType::Basic) } pub fn is_pure_ciphertext(&self) -> bool { @@ -233,14 +231,14 @@ impl TestContext { pub async fn sessions_basic(&self) -> [SessionContext; N] { let client_ids = self.basic_client_ids::(); - return self.sessions_inner(client_ids, None, MlsCredentialType::Basic).await; + return self.sessions_inner(client_ids, None, CredentialType::Basic).await; } pub async fn sessions_basic_with_pki_env(&self) -> [SessionContext; N] { let client_ids = self.basic_client_ids::(); let test_chain = X509TestChain::init_empty(self.signature_scheme()); return self - .sessions_inner(client_ids, Some(&test_chain), MlsCredentialType::Basic) + .sessions_inner(client_ids, Some(&test_chain), CredentialType::Basic) .await; } @@ -252,9 +250,7 @@ impl TestContext { let x509_sessions = self.sessions_x509().await; let chain = x509_sessions[0].x509_chain_unchecked(); let basic_ids = self.basic_client_ids(); - let basic_sessions = self - .sessions_inner(basic_ids, Some(chain), MlsCredentialType::Basic) - .await; + let basic_sessions = self.sessions_inner(basic_ids, Some(chain), CredentialType::Basic).await; (x509_sessions, basic_sessions) } @@ -263,7 +259,7 @@ impl TestContext { client_ids: [ClientId; N], ) -> [SessionContext; N] { let test_chain = self.test_chain(&client_ids, &[], None).await; - self.sessions_inner(client_ids, Some(&test_chain), MlsCredentialType::X509) + self.sessions_inner(client_ids, Some(&test_chain), CredentialType::X509) .await } @@ -273,7 +269,7 @@ impl TestContext { revoked_display_names: &[String], ) -> [SessionContext; N] { let test_chain = self.test_chain(&client_ids, revoked_display_names, None).await; - self.sessions_inner(client_ids, Some(&test_chain), MlsCredentialType::X509) + self.sessions_inner(client_ids, Some(&test_chain), CredentialType::X509) .await } @@ -286,9 +282,9 @@ impl TestContext { &self, client_ids: [ClientId; N], chain: Option<&X509TestChain>, - credential_type: MlsCredentialType, + credential_type: CredentialType, ) -> [SessionContext; N] { - let identifiers = if credential_type == MlsCredentialType::X509 { + let identifiers = if credential_type == CredentialType::X509 { self.x509_identifiers(client_ids, chain.expect("must instantiate an x509 chain in x509 tests")) .await } else { @@ -341,11 +337,11 @@ impl TestContext { }; let mut chain2 = self.test_chain(&client_ids2, revoked_display_names, Some(params)).await; chain1.cross_sign(&mut chain2); - self.sessions_inner(client_ids2, Some(&chain2), MlsCredentialType::X509) + self.sessions_inner(client_ids2, Some(&chain2), CredentialType::X509) .await }; let sessions1 = self - .sessions_inner(client_ids1, Some(&chain1), MlsCredentialType::X509) + .sessions_inner(client_ids1, Some(&chain1), CredentialType::X509) .await; (sessions1, sessions2) } @@ -386,7 +382,7 @@ impl TestContext { /// The first member is required, and is the conversation's creator. pub async fn create_conversation_with_credential_type<'a>( &'a self, - credential_type: MlsCredentialType, + credential_type: CredentialType, members: impl IntoIterator, ) -> TestConversation<'a> { self.create_heterogeneous_conversation(credential_type, credential_type, members) @@ -398,8 +394,8 @@ impl TestContext { /// The first member is required, and is the conversation's creator. pub async fn create_heterogeneous_conversation<'a>( &'a self, - creator_credential_type: MlsCredentialType, - member_credential_type: MlsCredentialType, + creator_credential_type: CredentialType, + member_credential_type: CredentialType, members: impl IntoIterator, ) -> TestConversation<'a> { let mut members = members.into_iter(); @@ -424,7 +420,7 @@ impl TestContext { impl Default for TestContext { fn default() -> Self { Self { - credential_type: MlsCredentialType::Basic, + credential_type: CredentialType::Basic, cfg: MlsConversationConfiguration::default(), transport: Arc::::default(), db: None, diff --git a/crypto/src/test_utils/test_conversation/commit.rs b/crypto/src/test_utils/test_conversation/commit.rs index 4ac1295d9a..b46f208a01 100644 --- a/crypto/src/test_utils/test_conversation/commit.rs +++ b/crypto/src/test_utils/test_conversation/commit.rs @@ -6,10 +6,10 @@ use super::{ operation_guard::{AddGuard, Commit, OperationGuard, TestOperation}, }; use crate::{ - MlsCredentialType, + CredentialType, mls::{ conversation::{ConversationWithMls as _, pending_conversation::PendingConversation}, - credential::CredentialBundle, + credential::Credential, }, }; @@ -31,7 +31,7 @@ impl<'a> TestConversation<'a> { /// Like [Self::invite_notify], but the key packages of the invited members will be of the provided credential type. pub async fn invite_with_credential_type_notify( self, - credential_type: MlsCredentialType, + credential_type: CredentialType, sessions: impl IntoIterator, ) -> TestConversation<'a> { self.invite_with_credential_type(credential_type, sessions) @@ -43,7 +43,7 @@ impl<'a> TestConversation<'a> { /// Like [Self::invite], but the key packages of the invited members will be of the provided credential type. pub async fn invite_with_credential_type( self, - credential_type: MlsCredentialType, + credential_type: CredentialType, sessions: impl IntoIterator, ) -> OperationGuard<'a, Commit> { let new_members = sessions.into_iter().collect::>(); @@ -96,20 +96,20 @@ impl<'a> TestConversation<'a> { } /// Replace the existing credential with an x509 one and notify all members. - pub async fn e2ei_rotate_notify(self, credential_bundle: Option<&CredentialBundle>) -> TestConversation<'a> { - self.e2ei_rotate(credential_bundle).await.notify_members().await + pub async fn e2ei_rotate_notify(self, credential: Option<&Credential>) -> TestConversation<'a> { + self.e2ei_rotate(credential).await.notify_members().await } /// Create an update commit with a leaf node containing x509 credentials, that hasn't been merged by the actor. /// On [OperationGuard::notify_members], the actor will receive this commit. - pub async fn e2ei_rotate_unmerged(self, credential_bundle: &CredentialBundle) -> OperationGuard<'a, Commit> { + pub async fn e2ei_rotate_unmerged(self, credential: &Credential) -> OperationGuard<'a, Commit> { let mut conversation_guard = self.guard().await; let conversation = conversation_guard.conversation().await; let mut leaf_node = conversation.group.own_leaf().unwrap().clone(); drop(conversation); - leaf_node.set_credential_with_key(credential_bundle.to_mls_credential_with_key()); + leaf_node.set_credential_with_key(credential.to_mls_credential_with_key()); let commit = conversation_guard - .update_key_material_inner(Some(credential_bundle), Some(leaf_node)) + .update_key_material_inner(Some(credential), Some(leaf_node)) .await .unwrap() .commit; @@ -118,19 +118,16 @@ impl<'a> TestConversation<'a> { } /// Like [Self::e2ei_rotate_notify], but also when notifying other members, call [SessionContext::verify_sender_identity]. - pub async fn e2ei_rotate_notify_and_verify_sender( - self, - credential_bundle: Option<&CredentialBundle>, - ) -> TestConversation<'a> { - self.e2ei_rotate(credential_bundle) + pub async fn e2ei_rotate_notify_and_verify_sender(self, credential: Option<&Credential>) -> TestConversation<'a> { + self.e2ei_rotate(credential) .await .notify_members_and_verify_sender() .await } /// Replace the existing credential with an x509 one. - pub async fn e2ei_rotate(self, credential_bundle: Option<&CredentialBundle>) -> OperationGuard<'a, Commit> { - self.guard().await.e2ei_rotate(credential_bundle).await.unwrap(); + pub async fn e2ei_rotate(self, credential: Option<&Credential>) -> OperationGuard<'a, Commit> { + self.guard().await.e2ei_rotate(credential).await.unwrap(); let commit = self.transport().await.latest_commit_bundle().await.commit; let committer_index = self.actor_index(); OperationGuard::new(TestOperation::Update, commit, self, [committer_index]) diff --git a/crypto/src/test_utils/test_conversation/mod.rs b/crypto/src/test_utils/test_conversation/mod.rs index 6d49794260..5b790e28ed 100644 --- a/crypto/src/test_utils/test_conversation/mod.rs +++ b/crypto/src/test_utils/test_conversation/mod.rs @@ -2,12 +2,12 @@ use std::sync::Arc; use openmls::{group::QueuedProposal, prelude::group_info::VerifiableGroupInfo}; -use super::{MessageExt as _, MlsCredentialType, MlsTransportTestExt, SessionContext, TestContext, TestError}; +use super::{CredentialType, MessageExt as _, MlsTransportTestExt, SessionContext, TestContext, TestError}; use crate::{ ConversationId, E2eiConversationState, MlsProposalRef, RecursiveError, mls::{ conversation::{Conversation, ConversationGuard, ConversationWithMls as _}, - credential::{CredentialBundle, ext::CredentialExt as _}, + credential::{Credential, ext::CredentialExt as _}, }, }; @@ -34,11 +34,11 @@ impl<'a> TestConversation<'a> { Self::new_with_credential_type(case, creator, case.credential_type).await } - /// Like [Self::new], but with the provided [MlsCredentialType]. + /// Like [Self::new], but with the provided [CredentialType]. pub async fn new_with_credential_type( case: &'a TestContext, creator: &'a SessionContext, - credential_type: MlsCredentialType, + credential_type: CredentialType, ) -> Self { let id = super::conversation_id(); creator @@ -81,25 +81,29 @@ impl<'a> TestConversation<'a> { } pub(crate) async fn export_group_info(&self) -> VerifiableGroupInfo { - let credential = self.credential_bundle().await; + let credential = self.credential().await; let conversation = self.guard().await; let conversation = conversation.conversation().await; let group = conversation.group(); let gi = group - .export_group_info(&self.actor().session.crypto_provider, &credential.signature_key, true) + .export_group_info( + &self.actor().session.crypto_provider, + &credential.signature_key_pair, + true, + ) .unwrap(); gi.group_info().unwrap() } - /// Find the actor's credential bundle used in this conversation. - pub(crate) async fn credential_bundle(&self) -> Arc { + /// Find the actor's credential used in this conversation. + pub(crate) async fn credential(&self) -> Arc { let conversation = self.guard().await; let conversation = conversation.conversation().await; conversation - .find_current_credential_bundle(&self.actor().session) + .find_current_credential(&self.actor().session) .await - .expect("expecting credential bundle") + .expect("expecting credential") } /// Count the members. Also, assert that the count is the same from the point of view of every member. @@ -267,7 +271,7 @@ impl<'a> TestConversation<'a> { self.actor() .transaction - .get_credential_in_use(gi, MlsCredentialType::X509) + .get_credential_in_use(gi, CredentialType::X509) .await .unwrap() } @@ -373,9 +377,9 @@ impl<'a> TestConversation<'a> { let cb = self .actor() .session - .find_most_recent_credential_bundle(self.case.signature_scheme(), MlsCredentialType::X509) + .find_most_recent_credential(self.case.signature_scheme(), CredentialType::X509) .await - .expect("x509 credential bundle"); + .expect("x509 credential"); let cs = guard.ciphersuite().await; let local_identity = cb.to_mls_credential_with_key().extract_identity(cs, None).unwrap(); assert_eq!(&local_identity.client_id.as_bytes(), &cid.0); @@ -390,7 +394,7 @@ impl<'a> TestConversation<'a> { // the keystore let signature_key = self .actor() - .find_signature_keypair_from_keystore(cb.signature_key.public()) + .find_signature_keypair_from_keystore(cb.signature_key_pair.public()) .await .unwrap(); let signature_key = openmls::prelude::SignaturePublicKey::from(signature_key.pk.as_slice()); diff --git a/crypto/src/test_utils/test_conversation/proposal.rs b/crypto/src/test_utils/test_conversation/proposal.rs index fc5fed6e43..f6d8466b20 100644 --- a/crypto/src/test_utils/test_conversation/proposal.rs +++ b/crypto/src/test_utils/test_conversation/proposal.rs @@ -142,13 +142,18 @@ impl<'a> TestConversation<'a> { let sender_index = openmls::prelude::SenderExtensionIndex::new(sender_index); let (sc, ct) = (self.case.signature_scheme(), self.case.credential_type); - let cb = external_actor.find_most_recent_credential_bundle(sc, ct).await.unwrap(); + let cb = external_actor.find_most_recent_credential(sc, ct).await.unwrap(); let group_id = openmls::group::GroupId::from_slice(self.id().as_ref()); let epoch = self.guard().await.epoch().await; - let proposal = - ExternalProposal::new_remove(to_remove_index, group_id, epoch.into(), &cb.signature_key, sender_index) - .unwrap(); + let proposal = ExternalProposal::new_remove( + to_remove_index, + group_id, + epoch.into(), + &cb.signature_key_pair, + sender_index, + ) + .unwrap(); OperationGuard::new(TestOperation::Remove(to_remove), proposal, self, []) } } diff --git a/crypto/src/transaction_context/conversation/external_commit.rs b/crypto/src/transaction_context/conversation/external_commit.rs index 05ff4dc337..d0cf93a5da 100644 --- a/crypto/src/transaction_context/conversation/external_commit.rs +++ b/crypto/src/transaction_context/conversation/external_commit.rs @@ -4,7 +4,7 @@ use openmls::prelude::{MlsGroup, group_info::VerifiableGroupInfo}; use super::{Error, Result}; use crate::{ - ConversationId, LeafError, MlsCiphersuite, MlsCommitBundle, MlsConversationConfiguration, MlsCredentialType, + Ciphersuite, ConversationId, CredentialType, LeafError, MlsCommitBundle, MlsConversationConfiguration, MlsCustomConfiguration, MlsError, MlsGroupInfoBundle, RecursiveError, WelcomeBundle, mls, mls::{ conversation::{ConversationIdRef, pending_conversation::PendingConversation}, @@ -27,8 +27,8 @@ impl TransactionContext { /// * `group_info` - a GroupInfo wrapped in a MLS message. it can be obtained by deserializing a TLS serialized `GroupInfo` object /// * `custom_cfg` - configuration of the MLS conversation fetched from the Delivery Service /// * `credential_type` - kind of [openmls::prelude::Credential] to use for joining this group. - /// If [MlsCredentialType::Basic] is chosen and no Credential has been created yet for it, - /// a new one will be generated. When [MlsCredentialType::X509] is chosen, it fails when no + /// If [CredentialType::Basic] is chosen and no Credential has been created yet for it, + /// a new one will be generated. When [CredentialType::X509] is chosen, it fails when no /// [openmls::prelude::Credential] has been created for the given Ciphersuite. /// /// # Returns [WelcomeBundle] @@ -39,7 +39,7 @@ impl TransactionContext { &self, group_info: VerifiableGroupInfo, custom_cfg: MlsCustomConfiguration, - credential_type: MlsCredentialType, + credential_type: CredentialType, ) -> Result { let (commit_bundle, welcome_bundle, mut pending_conversation) = self .create_external_join_commit(group_info, custom_cfg, credential_type) @@ -67,16 +67,16 @@ impl TransactionContext { &self, group_info: VerifiableGroupInfo, custom_cfg: MlsCustomConfiguration, - credential_type: MlsCredentialType, + credential_type: CredentialType, ) -> Result<(MlsCommitBundle, WelcomeBundle, PendingConversation)> { let client = &self.session().await?; - let cs: MlsCiphersuite = group_info.ciphersuite().into(); + let cs: Ciphersuite = group_info.ciphersuite().into(); let mls_provider = self.mls_provider().await?; let cb = client - .get_most_recent_or_create_credential_bundle(&mls_provider, cs.signature_algorithm(), credential_type) + .get_most_recent_or_create_credential(&mls_provider, cs.signature_algorithm(), credential_type) .await - .map_err(RecursiveError::mls_client("getting or creating credential bundle"))?; + .map_err(RecursiveError::mls_client("getting or creating credential"))?; let configuration = MlsConversationConfiguration { ciphersuite: cs, @@ -86,7 +86,7 @@ impl TransactionContext { let (group, commit, group_info) = MlsGroup::join_by_external_commit( &mls_provider, - &cb.signature_key, + &cb.signature_key_pair, None, group_info, &configuration @@ -400,15 +400,12 @@ mod tests { let ct = group.credential().unwrap().credential_type(); let cs = group.ciphersuite(); let client = alice.session().await; - let cb = client - .find_most_recent_credential_bundle(cs.into(), ct.into()) - .await - .unwrap(); + let cb = client.find_most_recent_credential(cs.into(), ct.into()).await.unwrap(); let gi = group .export_group_info( &alice.transaction.mls_provider().await.unwrap(), - &cb.signature_key, + &cb.signature_key_pair, // joining by external commit assumes we include a ratchet tree, but this `false` // says to leave it out false, diff --git a/crypto/src/transaction_context/conversation/external_proposal.rs b/crypto/src/transaction_context/conversation/external_proposal.rs index c7053c88a6..096d8b933a 100644 --- a/crypto/src/transaction_context/conversation/external_proposal.rs +++ b/crypto/src/transaction_context/conversation/external_proposal.rs @@ -2,8 +2,7 @@ use openmls::prelude::{GroupEpoch, GroupId, JoinProposal, MlsMessageOut}; use super::Result; use crate::{ - ConversationId, LeafError, MlsCiphersuite, MlsError, RecursiveError, - mls::{self, credential::typ::MlsCredentialType}, + Ciphersuite, ConversationId, CredentialType, LeafError, MlsError, RecursiveError, mls, transaction_context::TransactionContext, }; @@ -22,15 +21,15 @@ impl TransactionContext { /// /// # Errors /// Errors resulting from the creation of the proposal within OpenMls. - /// Fails when `credential_type` is [MlsCredentialType::X509] and no Credential has been created + /// Fails when `credential_type` is [CredentialType::X509] and no Credential has been created /// for it beforehand with [TransactionContext::e2ei_mls_init_only] or variants. #[cfg_attr(test, crate::dispotent)] pub async fn new_external_add_proposal( &self, conversation_id: ConversationId, epoch: GroupEpoch, - ciphersuite: MlsCiphersuite, - credential_type: MlsCredentialType, + ciphersuite: Ciphersuite, + credential_type: CredentialType, ) -> Result { let group_id = GroupId::from_slice(conversation_id.as_ref()); let mls_provider = self @@ -43,39 +42,35 @@ impl TransactionContext { .await .map_err(RecursiveError::transaction("getting mls client"))?; let cb = client - .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type) + .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type) .await; let cb = match (cb, credential_type) { (Ok(cb), _) => cb, - (Err(mls::session::Error::CredentialNotFound(_)), MlsCredentialType::Basic) => { - // If a Basic CredentialBundle does not exist, just create one instead of failing + (Err(mls::session::Error::CredentialNotFound(_)), CredentialType::Basic) => { + // If a Basic Credential does not exist, just create one instead of failing client - .init_basic_credential_bundle_if_missing(&mls_provider, ciphersuite.signature_algorithm()) + .init_basic_credential_if_missing(&mls_provider, ciphersuite.signature_algorithm()) .await - .map_err(RecursiveError::mls_client( - "initializing basic credential bundle if missing", - ))?; + .map_err(RecursiveError::mls_client("initializing basic credential if missing"))?; client - .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type) + .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type) .await .map_err(RecursiveError::mls_client( - "finding most recent credential bundle (which we just created)", + "finding most recent credential (which we just created)", ))? } - (Err(mls::session::Error::CredentialNotFound(_)), MlsCredentialType::X509) => { + (Err(mls::session::Error::CredentialNotFound(_)), CredentialType::X509) => { return Err(LeafError::E2eiEnrollmentNotDone.into()); } - (Err(e), _) => return Err(RecursiveError::mls_client("finding most recent credential bundle")(e).into()), + (Err(e), _) => return Err(RecursiveError::mls_client("finding most recent credential")(e).into()), }; let kp = client - .generate_one_keypackage_from_credential_bundle(&mls_provider, ciphersuite, &cb) + .generate_one_keypackage_from_credential(&mls_provider, ciphersuite, &cb) .await - .map_err(RecursiveError::mls_client( - "generating one keypackage from credential bundle", - ))?; + .map_err(RecursiveError::mls_client("generating one keypackage from credential"))?; - JoinProposal::new(kp, group_id, epoch, &cb.signature_key) + JoinProposal::new(kp, group_id, epoch, &cb.signature_key_pair) .map_err(MlsError::wrap("creating join proposal")) .map_err(Into::into) } diff --git a/crypto/src/transaction_context/conversation/mod.rs b/crypto/src/transaction_context/conversation/mod.rs index 9fd69bd83e..542b2cdf6f 100644 --- a/crypto/src/transaction_context/conversation/mod.rs +++ b/crypto/src/transaction_context/conversation/mod.rs @@ -10,7 +10,7 @@ use core_crypto_keystore::{connection::FetchFromDatabase as _, entities::Persist use super::{Error, Result, TransactionContext}; use crate::{ - KeystoreError, LeafError, MlsConversation, MlsConversationConfiguration, MlsCredentialType, RecursiveError, + CredentialType, KeystoreError, LeafError, MlsConversation, MlsConversationConfiguration, RecursiveError, mls::conversation::{ConversationGuard, ConversationIdRef, pending_conversation::PendingConversation}, }; @@ -63,7 +63,7 @@ impl TransactionContext { pub async fn new_conversation( &self, id: &ConversationIdRef, - creator_credential_type: MlsCredentialType, + creator_credential_type: CredentialType, config: MlsConversationConfiguration, ) -> Result<()> { if self.conversation_exists(id).await? || self.pending_conversation_exists(id).await? { diff --git a/crypto/src/transaction_context/conversation/proposal.rs b/crypto/src/transaction_context/conversation/proposal.rs index 2500372cba..208e04edef 100644 --- a/crypto/src/transaction_context/conversation/proposal.rs +++ b/crypto/src/transaction_context/conversation/proposal.rs @@ -158,10 +158,13 @@ mod tests { let conversation = case.create_conversation([&alice]).await; let id = conversation.id().clone(); - let remove_proposal = alice.transaction.new_remove_proposal(&id, b"unknown"[..].into()).await; + let remove_proposal = alice + .transaction + .new_remove_proposal(&id, b"unknown".as_slice().to_owned().into()) + .await; assert!(matches!( remove_proposal.unwrap_err(), - Error::ClientNotFound(client_id) if client_id == b"unknown"[..].into() + Error::ClientNotFound(client_id) if client_id == b"unknown".as_slice() )); }) .await diff --git a/crypto/src/transaction_context/e2e_identity/conversation_state.rs b/crypto/src/transaction_context/e2e_identity/conversation_state.rs index 3b9c658e32..d5559483c5 100644 --- a/crypto/src/transaction_context/e2e_identity/conversation_state.rs +++ b/crypto/src/transaction_context/e2e_identity/conversation_state.rs @@ -2,7 +2,7 @@ use openmls::{messages::group_info::VerifiableGroupInfo, prelude::Node}; use openmls_traits::OpenMlsCryptoProvider; use super::Result; -use crate::{MlsCredentialType, MlsError, RecursiveError, Session, transaction_context::TransactionContext}; +use crate::{CredentialType, MlsError, RecursiveError, Session, transaction_context::TransactionContext}; /// Indicates the state of a Conversation regarding end-to-end identity. /// @@ -50,14 +50,14 @@ impl TransactionContext { }); let auth_service = auth_service.borrow().await; - Ok(Session::compute_conversation_state(cs, credentials, MlsCredentialType::X509, auth_service.as_ref()).await) + Ok(Session::compute_conversation_state(cs, credentials, CredentialType::X509, auth_service.as_ref()).await) } /// See [crate::mls::session::Session::get_credential_in_use]. pub async fn get_credential_in_use( &self, group_info: VerifiableGroupInfo, - credential_type: MlsCredentialType, + credential_type: CredentialType, ) -> Result { let cs = group_info.ciphersuite().into(); // Not verifying the supplied the GroupInfo here could let attackers lure the clients about @@ -89,7 +89,7 @@ impl TransactionContext { #[cfg(test)] mod tests { use super::*; - use crate::{CertificateBundle, MlsCredentialType, Session, test_utils::*}; + use crate::{CertificateBundle, Credential, CredentialType, test_utils::*}; // testing the case where both Bob & Alice have the same Credential type #[apply(all_cred_cipher)] @@ -99,7 +99,7 @@ mod tests { let conversation = case.create_conversation([&alice, &bob]).await; match case.credential_type { - MlsCredentialType::Basic => { + CredentialType::Basic => { let alice_state = conversation.e2ei_state().await; let bob_state = conversation.e2ei_state_of(&bob).await; assert_eq!(alice_state, E2eiConversationState::NotEnabled); @@ -108,7 +108,7 @@ mod tests { let state = conversation.e2ei_state_via_group_info().await; assert_eq!(state, E2eiConversationState::NotEnabled); } - MlsCredentialType::X509 => { + CredentialType::X509 => { let alice_state = conversation.e2ei_state().await; let bob_state = conversation.e2ei_state_of(&bob).await; assert_eq!(alice_state, E2eiConversationState::Verified); @@ -117,6 +117,8 @@ mod tests { let state = conversation.e2ei_state_via_group_info().await; assert_eq!(state, E2eiConversationState::Verified); } + + CredentialType::Unknown(_) => panic!("unknown credential types are unsupported"), } }) .await @@ -129,8 +131,9 @@ mod tests { Box::pin(async move { // That way the conversation creator (Alice) will have a different credential type than Bob let (alice, bob, alice_credential_type) = match case.credential_type { - MlsCredentialType::Basic => (x509_session, basic_session, MlsCredentialType::X509), - MlsCredentialType::X509 => (basic_session, x509_session, MlsCredentialType::Basic), + CredentialType::Basic => (x509_session, basic_session, CredentialType::X509), + CredentialType::X509 => (basic_session, x509_session, CredentialType::Basic), + CredentialType::Unknown(_) => panic!("unknown credential types are unsupported"), }; let conversation = case @@ -164,14 +167,14 @@ mod tests { let intermediate_ca = alice.x509_chain_unchecked().find_local_intermediate_ca(); let cert = CertificateBundle::new_with_default_values(intermediate_ca, Some(expiration_time)); - let cb = Session::new_x509_credential_bundle(cert.clone()).unwrap(); + let cb = Credential::x509(cert.clone()).unwrap(); let conversation = conversation.e2ei_rotate_notify(Some(&cb)).await; let alice_client = alice.transaction.session().await.unwrap(); let alice_provider = alice.transaction.mls_provider().await.unwrap(); // Needed because 'e2ei_rotate' does not do it directly and it's required for 'get_group_info' alice_client - .save_new_x509_credential_bundle(&alice_provider.keystore(), case.signature_scheme(), cert) + .save_new_x509_credential(&alice_provider.keystore(), case.signature_scheme(), cert) .await .unwrap(); @@ -191,7 +194,7 @@ mod tests { let state = alice .transaction - .get_credential_in_use(gi, MlsCredentialType::X509) + .get_credential_in_use(gi, CredentialType::X509) .await .unwrap(); assert_eq!(state, E2eiConversationState::NotVerified); @@ -225,7 +228,7 @@ mod tests { let cert_bundle = CertificateBundle::from_certificate_and_issuer(&alice_cert.certificate, alice_intermediate_ca); - let cb = Session::new_x509_credential_bundle(cert_bundle.clone()).unwrap(); + let cb = Credential::x509(cert_bundle.clone()).unwrap(); let conversation = conversation.e2ei_rotate_notify(Some(&cb)).await; let alice_client = alice.session().await; @@ -233,7 +236,7 @@ mod tests { // Needed because 'e2ei_rotate' does not do it directly and it's required for 'get_group_info' alice_client - .save_new_x509_credential_bundle(&alice_provider.keystore(), case.signature_scheme(), cert_bundle) + .save_new_x509_credential(&alice_provider.keystore(), case.signature_scheme(), cert_bundle) .await .unwrap(); diff --git a/crypto/src/transaction_context/e2e_identity/enabled.rs b/crypto/src/transaction_context/e2e_identity/enabled.rs index b551325524..c42e341edb 100644 --- a/crypto/src/transaction_context/e2e_identity/enabled.rs +++ b/crypto/src/transaction_context/e2e_identity/enabled.rs @@ -25,7 +25,7 @@ mod tests { use openmls_traits::types::SignatureScheme; use super::super::Error; - use crate::{MlsCredentialType, RecursiveError, mls, test_utils::*}; + use crate::{CredentialType, RecursiveError, mls, test_utils::*}; #[apply(all_cred_cipher)] async fn should_be_false_when_basic_and_true_when_x509(case: TestContext) { @@ -33,8 +33,9 @@ mod tests { Box::pin(async move { let e2ei_is_enabled = cc.transaction.e2ei_is_enabled(case.signature_scheme()).await.unwrap(); match case.credential_type { - MlsCredentialType::Basic => assert!(!e2ei_is_enabled), - MlsCredentialType::X509 => assert!(e2ei_is_enabled), + CredentialType::Basic => assert!(!e2ei_is_enabled), + CredentialType::X509 => assert!(e2ei_is_enabled), + CredentialType::Unknown(_) => panic!("unknown credential types are unsupported"), }; }) .await diff --git a/crypto/src/transaction_context/e2e_identity/error.rs b/crypto/src/transaction_context/e2e_identity/error.rs index 16e0c488ee..fa744ebeb8 100644 --- a/crypto/src/transaction_context/e2e_identity/error.rs +++ b/crypto/src/transaction_context/e2e_identity/error.rs @@ -3,7 +3,7 @@ // We allow missing documentation in the error module because the types are generally self-descriptive. #![allow(missing_docs)] -use crate::MlsCredentialType; +use crate::CredentialType; pub type Result = core::result::Result; @@ -12,7 +12,7 @@ pub enum Error { #[error("Incorrect usage of this API")] ImplementationError, #[error("Expected a MLS client with credential type {0:?} but none found")] - MissingExistingClient(MlsCredentialType), + MissingExistingClient(CredentialType), #[error( "We already have an ACME Root Trust Anchor registered. Cannot proceed but this is usually indicative of double registration and can be ignored" )] diff --git a/crypto/src/transaction_context/e2e_identity/mod.rs b/crypto/src/transaction_context/e2e_identity/mod.rs index b0f287c3cb..da1a673fda 100644 --- a/crypto/src/transaction_context/e2e_identity/mod.rs +++ b/crypto/src/transaction_context/e2e_identity/mod.rs @@ -15,7 +15,7 @@ use wire_e2e_identity::prelude::x509::extract_crl_uris; use super::TransactionContext; use crate::{ - CertificateBundle, ClientId, ClientIdentifier, E2eiEnrollment, MlsCiphersuite, RecursiveError, + CertificateBundle, Ciphersuite, ClientId, ClientIdentifier, E2eiEnrollment, RecursiveError, e2e_identity::NewCrlDistributionPoints, mls::credential::{crl::get_new_crl_distribution_points, x509::CertificatePrivateKey}, }; @@ -36,7 +36,7 @@ impl TransactionContext { handle: String, team: Option, expiry_sec: u32, - ciphersuite: MlsCiphersuite, + ciphersuite: Ciphersuite, ) -> Result { let signature_keypair = None; // fresh install without a Basic client. Supplying None will generate a new keypair E2eiEnrollment::try_new( @@ -182,7 +182,7 @@ mod tests { // verify the created client can create a conversation let conversation = case - .create_conversation_with_credential_type(MlsCredentialType::X509, [&session]) + .create_conversation_with_credential_type(CredentialType::X509, [&session]) .await; conversation .guard() diff --git a/crypto/src/transaction_context/e2e_identity/rotate.rs b/crypto/src/transaction_context/e2e_identity/rotate.rs index 5701c0aca1..c06f9d83cd 100644 --- a/crypto/src/transaction_context/e2e_identity/rotate.rs +++ b/crypto/src/transaction_context/e2e_identity/rotate.rs @@ -1,10 +1,10 @@ -use core_crypto_keystore::{CryptoKeystoreMls, connection::FetchFromDatabase, entities::MlsKeyPackage}; +use core_crypto_keystore::{CryptoKeystoreMls, connection::FetchFromDatabase, entities::StoredKeypackage}; use openmls::prelude::KeyPackage; use openmls_traits::OpenMlsCryptoProvider; use super::error::{Error, Result}; use crate::{ - CertificateBundle, E2eiEnrollment, KeystoreError, MlsCiphersuite, MlsCredentialType, MlsError, RecursiveError, + CertificateBundle, Ciphersuite, CredentialType, E2eiEnrollment, KeystoreError, MlsError, RecursiveError, e2e_identity::NewCrlDistributionPoints, mls::credential::{ext::CredentialExt, x509::CertificatePrivateKey}, transaction_context::TransactionContext, @@ -22,7 +22,7 @@ impl TransactionContext { handle: String, team: Option, expiry_sec: u32, - ciphersuite: MlsCiphersuite, + ciphersuite: Ciphersuite, ) -> Result { let mls_provider = self .mls_provider() @@ -33,10 +33,10 @@ impl TransactionContext { .session() .await .map_err(RecursiveError::transaction("getting mls client"))? - .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), MlsCredentialType::Basic) + .find_most_recent_credential(ciphersuite.signature_algorithm(), CredentialType::Basic) .await - .map_err(|_| Error::MissingExistingClient(MlsCredentialType::Basic))?; - let client_id = cb.credential().identity().into(); + .map_err(|_| Error::MissingExistingClient(CredentialType::Basic))?; + let client_id = cb.mls_credential().identity().to_owned().into(); let sign_keypair = Some( cb.signature_key() @@ -70,7 +70,7 @@ impl TransactionContext { handle: Option, team: Option, expiry_sec: u32, - ciphersuite: MlsCiphersuite, + ciphersuite: Ciphersuite, ) -> Result { let mls_provider = self .mls_provider() @@ -81,10 +81,10 @@ impl TransactionContext { .session() .await .map_err(RecursiveError::transaction("getting mls client"))? - .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), MlsCredentialType::X509) + .find_most_recent_credential(ciphersuite.signature_algorithm(), CredentialType::X509) .await - .map_err(|_| Error::MissingExistingClient(MlsCredentialType::X509))?; - let client_id = cb.credential().identity().into(); + .map_err(|_| Error::MissingExistingClient(CredentialType::X509))?; + let client_id = cb.mls_credential().identity().to_owned().into(); let sign_keypair = Some( cb.signature_key() .try_into() @@ -167,7 +167,7 @@ impl TransactionContext { .map_err(RecursiveError::transaction("getting mls provider"))?; client - .save_new_x509_credential_bundle( + .save_new_x509_credential( &self .mls_provider() .await @@ -177,21 +177,21 @@ impl TransactionContext { cert_bundle, ) .await - .map_err(RecursiveError::mls_client("saving new x509 credential bundle"))?; + .map_err(RecursiveError::mls_client("saving new x509 credential"))?; Ok(crl_new_distribution_points) } /// Deletes all key packages whose leaf node's credential does not match the most recently /// saved x509 credential with the provided signature scheme. - pub async fn delete_stale_key_packages(&self, cipher_suite: MlsCiphersuite) -> Result<()> { + pub async fn delete_stale_key_packages(&self, cipher_suite: Ciphersuite) -> Result<()> { let signature_scheme = cipher_suite.signature_algorithm(); let keystore = self .keystore() .await .map_err(RecursiveError::transaction("getting keystore"))?; let nb_kp = keystore - .count::() + .count::() .await .map_err(KeystoreError::wrap("counting key packages"))?; let kps: Vec = keystore @@ -204,9 +204,9 @@ impl TransactionContext { .map_err(RecursiveError::transaction("getting mls client"))?; let cb = client - .find_most_recent_credential_bundle(signature_scheme, MlsCredentialType::X509) + .find_most_recent_credential(signature_scheme, CredentialType::X509) .await - .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?; + .map_err(RecursiveError::mls_client("finding most recent credential"))?; let mut kp_refs = vec![]; @@ -216,7 +216,7 @@ impl TransactionContext { .map_err(RecursiveError::transaction("getting mls provider"))?; for kp in kps { let kp_cred = kp.leaf_node().credential().mls_credential(); - let local_cred = cb.credential().mls_credential(); + let local_cred = cb.mls_credential().mls_credential(); if kp_cred != local_cred { let kpr = kp .hash_ref(provider.crypto()) @@ -235,7 +235,7 @@ impl TransactionContext { mod tests { use std::collections::HashSet; - use core_crypto_keystore::entities::{EntityFindParams, MlsCredential}; + use core_crypto_keystore::entities::{EntityFindParams, StoredCredential}; use openmls::prelude::SignaturePublicKey; use tls_codec::Deserialize; @@ -288,12 +288,12 @@ mod tests { assert_eq!(before_rotate.credential, 1); let old_credential = alice - .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type) + .find_most_recent_credential(case.signature_scheme(), case.credential_type) .await .unwrap() .clone(); - let is_renewal = case.credential_type == MlsCredentialType::X509; + let is_renewal = case.credential_type == CredentialType::X509; let (mut enrollment, cert) = e2ei_utils::e2ei_enrollment( &alice, @@ -314,7 +314,7 @@ mod tests { .unwrap(); let cb = alice - .find_most_recent_credential_bundle(case.signature_scheme(), MlsCredentialType::X509) + .find_most_recent_credential(case.signature_scheme(), CredentialType::X509) .await .unwrap(); @@ -366,10 +366,10 @@ mod tests { // But first let's verify the previous credential material is present assert!( alice - .find_credential_bundle( + .find_credential( case.signature_scheme(), case.credential_type, - &old_credential.signature_key.public().into() + &old_credential.signature_key_pair.public().into() ) .await .is_some() @@ -388,7 +388,7 @@ mod tests { // and the signature keypair is still present assert!( alice - .find_signature_keypair_from_keystore(old_credential.signature_key.public()) + .find_signature_keypair_from_keystore(old_credential.signature_key_pair.public()) .await .is_some() ); @@ -403,12 +403,12 @@ mod tests { // Alice should just have the number of X509 KeyPackages she requested let nb_x509_kp = alice - .count_key_package(case.ciphersuite(), Some(MlsCredentialType::X509)) + .count_key_package(case.ciphersuite(), Some(CredentialType::X509)) .await; assert_eq!(nb_x509_kp, NB_KEY_PACKAGE); // in both cases, Alice should not anymore have any Basic KeyPackage let nb_basic_kp = alice - .count_key_package(case.ciphersuite(), Some(MlsCredentialType::Basic)) + .count_key_package(case.ciphersuite(), Some(CredentialType::Basic)) .await; assert_eq!(nb_basic_kp, 0); @@ -432,7 +432,7 @@ mod tests { let conversation = case .create_conversation([&charlie]) .await - .invite_with_credential_type_notify(MlsCredentialType::X509, [&alice]) + .invite_with_credential_type_notify(CredentialType::X509, [&alice]) .await; assert!(conversation.is_functional_and_contains([&alice, &charlie]).await); }) @@ -448,7 +448,7 @@ mod tests { case.create_conversation([&alice]).await; let old_cb = alice - .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type) + .find_most_recent_credential(case.signature_scheme(), case.credential_type) .await .unwrap() .clone(); @@ -457,7 +457,7 @@ mod tests { // we only have a precision of 1 second for the `created_at` field of the Credential smol::Timer::after(core::time::Duration::from_secs(1)).await; - let is_renewal = case.credential_type == MlsCredentialType::X509; + let is_renewal = case.credential_type == CredentialType::X509; let (mut enrollment, cert) = e2ei_utils::e2ei_enrollment( &alice, @@ -479,7 +479,7 @@ mod tests { // So alice has a new Credential as expected let cb = alice - .find_most_recent_credential_bundle(case.signature_scheme(), MlsCredentialType::X509) + .find_most_recent_credential(case.signature_scheme(), CredentialType::X509) .await .unwrap(); let identity = cb @@ -496,25 +496,24 @@ mod tests { ); // but keeps her old one since it's referenced from some KeyPackages - let old_spk = SignaturePublicKey::from(old_cb.signature_key.public()); + let old_spk = SignaturePublicKey::from(old_cb.signature_key_pair.public()); let old_cb_found = alice - .find_credential_bundle(case.signature_scheme(), case.credential_type, &old_spk) + .find_credential(case.signature_scheme(), case.credential_type, &old_spk) .await .unwrap(); assert_eq!(old_cb, old_cb_found); - let (cid, all_credentials, scs, old_nb_identities) = { + let (scs, old_nb_identities) = { let alice_client = alice.session().await; let old_nb_identities = alice_client.identities_count().await.unwrap(); // Let's simulate an app crash, client gets deleted and restored from keystore - let cid = alice_client.id().await.unwrap(); let scs = HashSet::from([case.signature_scheme()]); let all_credentials = alice .transaction .keystore() .await .unwrap() - .find_all::(EntityFindParams::default()) + .find_all::(EntityFindParams::default()) .await .unwrap() .into_iter() @@ -525,7 +524,8 @@ mod tests { }) .collect::>(); assert_eq!(all_credentials.len(), 2); - (cid, all_credentials, scs, old_nb_identities) + + (scs, old_nb_identities) }; let backend = &alice.transaction.mls_provider().await.unwrap(); backend.keystore().commit_transaction().await.unwrap(); @@ -534,11 +534,14 @@ mod tests { let new_client = alice.session.clone(); new_client.reset().await; - new_client.load(backend, &cid, all_credentials, scs).await.unwrap(); + new_client + .init(alice.identifier.clone(), &scs.iter().copied().collect::>()) + .await + .unwrap(); // Verify that Alice has the same credentials let cb = new_client - .find_most_recent_credential_bundle(case.signature_scheme(), MlsCredentialType::X509) + .find_most_recent_credential(case.signature_scheme(), CredentialType::X509) .await .unwrap(); let identity = cb @@ -576,7 +579,7 @@ mod tests { let e2ei_utils::E2eiInitWrapper { context: cc, case } = wrapper; let cs = case.ciphersuite(); match case.credential_type { - MlsCredentialType::Basic => { + CredentialType::Basic => { cc.e2ei_new_activation_enrollment( ALICE_NEW_DISPLAY_NAME.to_string(), ALICE_NEW_HANDLE.to_string(), @@ -586,7 +589,7 @@ mod tests { ) .await } - MlsCredentialType::X509 => { + CredentialType::X509 => { cc.e2ei_new_rotate_enrollment( Some(ALICE_NEW_DISPLAY_NAME.to_string()), Some(ALICE_NEW_HANDLE.to_string()), @@ -596,13 +599,14 @@ mod tests { ) .await } + CredentialType::Unknown(_) => panic!("unknown credential types are unsupported"), } .map_err(RecursiveError::transaction("creating new enrollment")) .map_err(Into::into) }) } - let is_renewal = case.credential_type == MlsCredentialType::X509; + let is_renewal = case.credential_type == CredentialType::X509; let (mut enrollment, cert) = e2ei_utils::e2ei_enrollment( &alice, @@ -636,7 +640,7 @@ mod tests { let e2ei_utils::E2eiInitWrapper { context: cc, case } = wrapper; let cs = case.ciphersuite(); match case.credential_type { - MlsCredentialType::Basic => { + CredentialType::Basic => { cc.e2ei_new_activation_enrollment( BOB_NEW_DISPLAY_NAME.to_string(), BOB_NEW_HANDLE.to_string(), @@ -646,7 +650,7 @@ mod tests { ) .await } - MlsCredentialType::X509 => { + CredentialType::X509 => { cc.e2ei_new_rotate_enrollment( Some(BOB_NEW_DISPLAY_NAME.to_string()), Some(BOB_NEW_HANDLE.to_string()), @@ -656,12 +660,13 @@ mod tests { ) .await } + CredentialType::Unknown(_) => panic!("unknown credential types are unsupported"), } .map_err(RecursiveError::transaction("creating new enrollment")) .map_err(Into::into) }) } - let is_renewal = case.credential_type == MlsCredentialType::X509; + let is_renewal = case.credential_type == CredentialType::X509; let (mut enrollment, cert) = e2ei_utils::e2ei_enrollment( &bob, @@ -836,7 +841,7 @@ mod tests { let [alice, bob] = case.sessions_basic_with_pki_env().await; Box::pin(async move { let conversation = case - .create_conversation_with_credential_type(MlsCredentialType::Basic, [&alice, &bob]) + .create_conversation_with_credential_type(CredentialType::Basic, [&alice, &bob]) .await; let id = conversation.id().clone(); @@ -860,7 +865,7 @@ mod tests { .await .unwrap(); let alice_old_identity = alice_old_identities.first().unwrap(); - assert_eq!(alice_old_identity.credential_type, MlsCredentialType::Basic); + assert_eq!(alice_old_identity.credential_type, CredentialType::Basic); assert_eq!(alice_old_identity.x509_identity, None); // Alice issues an Update commit to replace her current identity diff --git a/crypto/src/transaction_context/key_package.rs b/crypto/src/transaction_context/key_package.rs index ae2a342410..62ed12fdf1 100644 --- a/crypto/src/transaction_context/key_package.rs +++ b/crypto/src/transaction_context/key_package.rs @@ -3,7 +3,7 @@ use openmls::prelude::{KeyPackage, KeyPackageRef}; use super::{Result, TransactionContext}; -use crate::{MlsCiphersuite, MlsCredentialType, RecursiveError}; +use crate::{Ciphersuite, CredentialType, RecursiveError}; impl TransactionContext { /// Returns `amount_requested` OpenMLS [openmls::key_packages::KeyPackage]s. @@ -21,8 +21,8 @@ impl TransactionContext { /// Errors can happen when accessing the KeyStore pub async fn get_or_create_client_keypackages( &self, - ciphersuite: MlsCiphersuite, - credential_type: MlsCredentialType, + ciphersuite: Ciphersuite, + credential_type: CredentialType, amount_requested: usize, ) -> Result> { let session = self.session().await?; @@ -38,11 +38,11 @@ impl TransactionContext { .map_err(Into::into) } - /// Returns the count of valid, non-expired, unclaimed keypackages in store for the given [MlsCiphersuite] and [MlsCredentialType] + /// Returns the count of valid, non-expired, unclaimed keypackages in store for the given [MlsCiphersuite] and [CredentialType] pub async fn client_valid_key_packages_count( &self, - ciphersuite: MlsCiphersuite, - credential_type: MlsCredentialType, + ciphersuite: Ciphersuite, + credential_type: CredentialType, ) -> Result { let session = self.session().await?; session diff --git a/crypto/src/transaction_context/mod.rs b/crypto/src/transaction_context/mod.rs index cb26ef6780..d19fc93817 100644 --- a/crypto/src/transaction_context/mod.rs +++ b/crypto/src/transaction_context/mod.rs @@ -14,8 +14,8 @@ use openmls_traits::OpenMlsCryptoProvider as _; #[cfg(feature = "proteus")] use crate::proteus::ProteusCentral; use crate::{ - ClientId, ClientIdentifier, CoreCrypto, KeystoreError, MlsCiphersuite, MlsConversation, MlsCredentialType, - MlsError, MlsTransport, RecursiveError, Session, group_store::GroupStore, mls::HasSessionAndCrypto, + Ciphersuite, ClientId, ClientIdentifier, CoreCrypto, CredentialType, KeystoreError, MlsConversation, MlsError, + MlsTransport, RecursiveError, Session, group_store::GroupStore, mls::HasSessionAndCrypto, }; pub mod conversation; pub mod e2e_identity; @@ -215,10 +215,16 @@ impl TransactionContext { } /// Initializes the MLS client of [super::CoreCrypto]. - pub async fn mls_init(&self, identifier: ClientIdentifier, ciphersuites: &[MlsCiphersuite]) -> Result<()> { + pub async fn mls_init(&self, identifier: ClientIdentifier, ciphersuites: &[Ciphersuite]) -> Result<()> { let mls_client = self.session().await?; mls_client - .init(identifier, ciphersuites, &self.mls_provider().await?) + .init( + identifier, + &ciphersuites + .iter() + .map(|ciphersuite| ciphersuite.signature_algorithm()) + .collect::>(), + ) .await .map_err(RecursiveError::mls_client("initializing mls client"))?; @@ -237,16 +243,16 @@ impl TransactionContext { /// Returns the client's public key. pub async fn client_public_key( &self, - ciphersuite: MlsCiphersuite, - credential_type: MlsCredentialType, + ciphersuite: Ciphersuite, + credential_type: CredentialType, ) -> Result> { let cb = self .session() .await? - .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type) + .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type) .await - .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?; - Ok(cb.signature_key.to_public_vec()) + .map_err(RecursiveError::mls_client("finding most recent credential"))?; + Ok(cb.signature_key_pair.to_public_vec()) } /// see [Session::id] diff --git a/crypto/src/transaction_context/test_utils.rs b/crypto/src/transaction_context/test_utils.rs index f46860fac2..ee1c899100 100644 --- a/crypto/src/transaction_context/test_utils.rs +++ b/crypto/src/transaction_context/test_utils.rs @@ -1,9 +1,9 @@ use core_crypto_keystore::{ connection::FetchFromDatabase as _, entities::{ - E2eiEnrollment, MlsCredential, MlsEncryptionKeyPair, MlsEpochEncryptionKeyPair, MlsHpkePrivateKey, - MlsKeyPackage, MlsPendingMessage, MlsPskBundle, MlsSignatureKeyPair, PersistedMlsGroup, - PersistedMlsPendingGroup, + MlsPendingMessage, PersistedMlsGroup, PersistedMlsPendingGroup, StoredCredential, StoredE2eiEnrollment, + StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, + StoredSignatureKeypair, }, }; @@ -28,17 +28,17 @@ impl TransactionContext { /// Count the entities pub async fn count_entities(&self) -> EntitiesCount { let keystore = self.keystore().await.unwrap(); - let credential = keystore.count::().await.unwrap(); - let encryption_keypair = keystore.count::().await.unwrap(); - let epoch_encryption_keypair = keystore.count::().await.unwrap(); - let enrollment = keystore.count::().await.unwrap(); + let credential = keystore.count::().await.unwrap(); + let encryption_keypair = keystore.count::().await.unwrap(); + let epoch_encryption_keypair = keystore.count::().await.unwrap(); + let enrollment = keystore.count::().await.unwrap(); let group = keystore.count::().await.unwrap(); - let hpke_private_key = keystore.count::().await.unwrap(); - let key_package = keystore.count::().await.unwrap(); + let hpke_private_key = keystore.count::().await.unwrap(); + let key_package = keystore.count::().await.unwrap(); let pending_group = keystore.count::().await.unwrap(); let pending_messages = keystore.count::().await.unwrap(); - let psk_bundle = keystore.count::().await.unwrap(); - let signature_keypair = keystore.count::().await.unwrap(); + let psk_bundle = keystore.count::().await.unwrap(); + let signature_keypair = keystore.count::().await.unwrap(); EntitiesCount { credential, encryption_keypair, diff --git a/interop/src/clients/corecrypto/ffi.rs b/interop/src/clients/corecrypto/ffi.rs index 93e178b767..edeaa2a744 100644 --- a/interop/src/clients/corecrypto/ffi.rs +++ b/interop/src/clients/corecrypto/ffi.rs @@ -31,8 +31,8 @@ pub(crate) struct CoreCryptoFfiClient { impl CoreCryptoFfiClient { pub(crate) async fn new() -> Result { let client_id = uuid::Uuid::new_v4(); - let client_id_bytes: Vec = client_id.as_hyphenated().to_string().as_bytes().into(); - let client_id = Arc::new(ClientId::from(core_crypto::ClientId::from(&client_id_bytes[..]))); + let client_id_bytes: Vec = client_id.as_hyphenated().to_string().into_bytes(); + let client_id = Arc::new(ClientId::from(core_crypto::ClientId::from(client_id_bytes.clone()))); let temp_file = NamedTempFile::with_prefix("interop-ffi-keystore-")?; let key = DatabaseKey::from_cc(core_crypto::DatabaseKey::generate()); let db = Database::open(&temp_file.path().to_string_lossy(), key) @@ -129,7 +129,7 @@ impl EmulatedMlsClient for CoreCryptoFfiClient { } async fn kick_client(&self, conversation_id: &[u8], client_id: &[u8]) -> Result<()> { - let client_id = Arc::new(ClientId::from(core_crypto::ClientId::from(client_id))); + let client_id = Arc::new(ClientId::from(core_crypto::ClientId::from(client_id.to_owned()))); let conversation_id = conversation_id.into(); let extractor = TransactionHelper::new(move |context| async move { context diff --git a/interop/src/clients/corecrypto/native.rs b/interop/src/clients/corecrypto/native.rs index e4ee77040b..7b86da56d7 100644 --- a/interop/src/clients/corecrypto/native.rs +++ b/interop/src/clients/corecrypto/native.rs @@ -27,7 +27,7 @@ impl CoreCryptoNativeClient { async fn internal_new(deferred: bool) -> Result { let client_id = uuid::Uuid::new_v4(); - let cid = (!deferred).then(|| client_id.as_hyphenated().to_string().as_bytes().into()); + let cid = (!deferred).then(|| client_id.as_hyphenated().to_string().into_bytes().into()); let db = Database::open(ConnectionType::InMemory, &DatabaseKey::generate()) .await @@ -83,7 +83,7 @@ impl EmulatedMlsClient for CoreCryptoNativeClient { let transaction = self.cc.new_transaction().await?; let start = std::time::Instant::now(); let kp = transaction - .get_or_create_client_keypackages(CIPHERSUITE_IN_USE.into(), MlsCredentialType::Basic, 1) + .get_or_create_client_keypackages(CIPHERSUITE_IN_USE.into(), CredentialType::Basic, 1) .await? .pop() .unwrap(); @@ -109,7 +109,7 @@ impl EmulatedMlsClient for CoreCryptoNativeClient { ..Default::default() }; transaction - .new_conversation(&conversation_id, MlsCredentialType::Basic, config) + .new_conversation(&conversation_id, CredentialType::Basic, config) .await?; } @@ -132,7 +132,7 @@ impl EmulatedMlsClient for CoreCryptoNativeClient { transaction .conversation(&conversation_id) .await? - .remove_members(&[client_id.to_owned().into()]) + .remove_members(&[ClientIdRef::new(client_id)]) .await?; transaction.finish().await?; diff --git a/interop/src/clients/mod.rs b/interop/src/clients/mod.rs index b5d90f27a2..85bc1d7f49 100644 --- a/interop/src/clients/mod.rs +++ b/interop/src/clients/mod.rs @@ -1,7 +1,7 @@ #![allow(clippy::assign_op_pattern)] use anyhow::Result; -use core_crypto::MlsCiphersuite; +use core_crypto::Ciphersuite; pub(crate) mod corecrypto; @@ -79,5 +79,5 @@ pub(crate) trait EmulatedProteusClient: EmulatedClient { #[async_trait::async_trait(?Send)] #[allow(dead_code)] pub(crate) trait EmulatedE2eIdentityClient: EmulatedClient { - async fn e2ei_new_enrollment(&mut self, ciphersuite: MlsCiphersuite) -> Result<()>; + async fn e2ei_new_enrollment(&mut self, ciphersuite: Ciphersuite) -> Result<()>; } diff --git a/interop/src/main.rs b/interop/src/main.rs index f8bfbe5bd9..357e8ef19e 100644 --- a/interop/src/main.rs +++ b/interop/src/main.rs @@ -3,9 +3,9 @@ use std::sync::Arc; use anyhow::{Result, anyhow}; -use core_crypto::CiphersuiteName; #[cfg(target_family = "wasm")] use core_crypto::DatabaseKey; +use core_crypto::MlsCiphersuite; use tls_codec::Serialize; #[cfg(not(target_family = "wasm"))] @@ -20,7 +20,7 @@ const MLS_MAIN_CLIENTID: &[u8] = b"test_main"; const MLS_CONVERSATION_ID: &[u8] = b"test_conversation"; const ROUNDTRIP_MSG_AMOUNT: usize = 100; -const CIPHERSUITE_IN_USE: CiphersuiteName = CiphersuiteName::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519; +const CIPHERSUITE_IN_USE: MlsCiphersuite = MlsCiphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519; // TODO: Add support for Android emulator. Tracking issue: WPB-9646 // TODO: Add support for iOS emulator when on macOS. Tracking issue: WPB-9646 @@ -151,7 +151,7 @@ async fn run_mls_test(chrome_driver_addr: &std::net::SocketAddr, web_server: &st let mut clients = create_mls_clients(chrome_driver_addr, web_server).await; let configuration = SessionConfig::builder() .database(db) - .client_id(MLS_MAIN_CLIENTID.into()) + .client_id(MLS_MAIN_CLIENTID.to_owned().into()) .ciphersuites([CIPHERSUITE_IN_USE.into()]) .build() .validate()?; @@ -169,7 +169,7 @@ async fn run_mls_test(chrome_driver_addr: &std::net::SocketAddr, web_server: &st cc.provide_transport(success_provider.clone()).await; let transaction = cc.new_transaction().await?; transaction - .new_conversation(&conversation_id, MlsCredentialType::Basic, config) + .new_conversation(&conversation_id, CredentialType::Basic, config) .await?; spinner.success("[MLS] Step 0: Initializing clients [OK]"); @@ -307,7 +307,7 @@ async fn run_proteus_test(chrome_driver_addr: &std::net::SocketAddr, web_server: let configuration = SessionConfig::builder() .database(db) - .client_id(MLS_MAIN_CLIENTID.into()) + .client_id(MLS_MAIN_CLIENTID.to_owned().into()) .ciphersuites([CIPHERSUITE_IN_USE.into()]) .build() .validate()?; diff --git a/keystore-dump/src/main.rs b/keystore-dump/src/main.rs index 706980fe5c..1c23962eb8 100644 --- a/keystore-dump/src/main.rs +++ b/keystore-dump/src/main.rs @@ -44,7 +44,7 @@ async fn main() -> anyhow::Result<()> { let mut credentials: Vec = vec![]; for cred in keystore - .find_all::(Default::default()) + .find_all::(Default::default()) .await? .into_iter() { @@ -64,7 +64,7 @@ async fn main() -> anyhow::Result<()> { let mut signature_keypairs: Vec = vec![]; for kp in keystore - .find_all::(Default::default()) + .find_all::(Default::default()) .await? .into_iter() { @@ -78,7 +78,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("mls_signature_keypairs", &signature_keypairs)?; let hpke_sks: Vec = keystore - .find_all::(Default::default()) + .find_all::(Default::default()) .await? .into_iter() .map(|hpke_sk| postcard::from_bytes::(&hpke_sk.sk)) @@ -86,7 +86,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("mls_hpke_private_keys", &hpke_sks)?; let hpke_keypairs: Vec = keystore - .find_all::(Default::default()) + .find_all::(Default::default()) .await? .into_iter() .map(|hpke_kp| postcard::from_bytes::(&hpke_kp.sk)) @@ -94,7 +94,11 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("mls_hpke_keypairs", &hpke_keypairs)?; let mut external_psks: std::collections::HashMap = Default::default(); - for psk in keystore.find_all::(Default::default()).await?.into_iter() { + for psk in keystore + .find_all::(Default::default()) + .await? + .into_iter() + { let mls_psk = postcard::from_bytes::(&psk.psk)?; external_psks.insert(hex::encode(&psk.psk_id), mls_psk); } @@ -102,7 +106,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("external_psks", &external_psks)?; let keypackages: Vec = keystore - .find_all::(Default::default()) + .find_all::(Default::default()) .await? .into_iter() .map(|kp| postcard::from_bytes::(&kp.keypackage)) @@ -110,7 +114,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("mls_keypackages", &keypackages)?; let e2ei_enrollments: Vec = keystore - .find_all::(Default::default()) + .find_all::(Default::default()) .await? .into_iter() .map(|enrollment| serde_json::from_slice::(&enrollment.content)) diff --git a/keystore/src/connection/platform/wasm/migrations/db_key_type_to_bytes.rs b/keystore/src/connection/platform/wasm/migrations/db_key_type_to_bytes.rs index c0b035f195..f014fead94 100644 --- a/keystore/src/connection/platform/wasm/migrations/db_key_type_to_bytes.rs +++ b/keystore/src/connection/platform/wasm/migrations/db_key_type_to_bytes.rs @@ -8,10 +8,10 @@ use crate::{ CryptoKeystoreError, CryptoKeystoreResult, DatabaseKey, connection::platform::wasm::rekey::rekey_entities, entities::{ - E2eiAcmeCA, E2eiCrl, E2eiEnrollment, E2eiIntermediateCert, E2eiRefreshToken, Entity as _, EntityBase as _, - MlsCredential, MlsEncryptionKeyPair, MlsEpochEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage, - MlsPendingMessage, MlsPskBundle, MlsSignatureKeyPair, PersistedMlsGroup, PersistedMlsPendingGroup, - ProteusIdentity, ProteusPrekey, ProteusSession, + E2eiAcmeCA, E2eiCrl, E2eiIntermediateCert, E2eiRefreshToken, Entity as _, EntityBase as _, MlsPendingMessage, + PersistedMlsGroup, PersistedMlsPendingGroup, ProteusIdentity, ProteusPrekey, ProteusSession, StoredCredential, + StoredE2eiEnrollment, StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, + StoredKeypackage, StoredPskBundle, StoredSignatureKeypair, }, }; @@ -35,17 +35,17 @@ pub(crate) async fn migrate_db_key_type_to_bytes( old_cipher, new_cipher, [ - MlsCredential, - MlsSignatureKeyPair, - MlsHpkePrivateKey, - MlsEncryptionKeyPair, - MlsEpochEncryptionKeyPair, - MlsPskBundle, - MlsKeyPackage, + StoredCredential, + StoredSignatureKeypair, + StoredHpkePrivateKey, + StoredEncryptionKeyPair, + StoredEpochEncryptionKeypair, + StoredPskBundle, + StoredKeypackage, PersistedMlsGroup, PersistedMlsPendingGroup, MlsPendingMessage, - E2eiEnrollment, + StoredE2eiEnrollment, E2eiRefreshToken, E2eiAcmeCA, E2eiIntermediateCert, diff --git a/keystore/src/connection/platform/wasm/migrations/v0.rs b/keystore/src/connection/platform/wasm/migrations/v0.rs index 5fbcd92f74..c888065080 100644 --- a/keystore/src/connection/platform/wasm/migrations/v0.rs +++ b/keystore/src/connection/platform/wasm/migrations/v0.rs @@ -5,22 +5,23 @@ use idb::{ use super::{DB_VERSION_0, Metabuilder}; use crate::entities::{ - E2eiAcmeCA, E2eiCrl, E2eiEnrollment, E2eiIntermediateCert, E2eiRefreshToken, EntityBase as _, MlsCredential, - MlsEncryptionKeyPair, MlsEpochEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage, MlsPendingMessage, MlsPskBundle, - MlsSignatureKeyPair, PersistedMlsGroup, PersistedMlsPendingGroup, ProteusIdentity, ProteusPrekey, ProteusSession, + E2eiAcmeCA, E2eiCrl, E2eiIntermediateCert, E2eiRefreshToken, EntityBase as _, MlsPendingMessage, PersistedMlsGroup, + PersistedMlsPendingGroup, ProteusIdentity, ProteusPrekey, ProteusSession, StoredCredential, StoredE2eiEnrollment, + StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, + StoredSignatureKeypair, }; pub(super) fn get_builder(name: &str) -> Metabuilder { let idb_builder = Metabuilder::new(name) .version(DB_VERSION_0) .add_object_store( - ObjectStoreBuilder::new(MlsCredential::COLLECTION_NAME) + ObjectStoreBuilder::new(StoredCredential::COLLECTION_NAME) .auto_increment(false) .add_index(IndexBuilder::new("id".into(), KeyPath::new_single("id"))) .add_index(IndexBuilder::new("credential".into(), KeyPath::new_single("credential")).unique(true)), ) .add_object_store( - ObjectStoreBuilder::new(MlsSignatureKeyPair::COLLECTION_NAME) + ObjectStoreBuilder::new(StoredSignatureKeypair::COLLECTION_NAME) .auto_increment(false) .add_index(IndexBuilder::new( "signature_scheme".into(), @@ -29,27 +30,27 @@ pub(super) fn get_builder(name: &str) -> Metabuilder { .add_index(IndexBuilder::new("signature_pk".into(), KeyPath::new_single("pk"))), ) .add_object_store( - ObjectStoreBuilder::new(MlsHpkePrivateKey::COLLECTION_NAME) + ObjectStoreBuilder::new(StoredHpkePrivateKey::COLLECTION_NAME) .auto_increment(false) .add_index(IndexBuilder::new("pk".into(), KeyPath::new_single("pk")).unique(true)), ) .add_object_store( - ObjectStoreBuilder::new(MlsEncryptionKeyPair::COLLECTION_NAME) + ObjectStoreBuilder::new(StoredEncryptionKeyPair::COLLECTION_NAME) .auto_increment(false) .add_index(IndexBuilder::new("pk".into(), KeyPath::new_single("pk")).unique(true)), ) .add_object_store( - ObjectStoreBuilder::new(MlsEpochEncryptionKeyPair::COLLECTION_NAME) + ObjectStoreBuilder::new(StoredEpochEncryptionKeypair::COLLECTION_NAME) .auto_increment(false) .add_index(IndexBuilder::new("id".into(), KeyPath::new_single("id")).unique(true)), ) .add_object_store( - ObjectStoreBuilder::new(MlsPskBundle::COLLECTION_NAME) + ObjectStoreBuilder::new(StoredPskBundle::COLLECTION_NAME) .auto_increment(false) .add_index(IndexBuilder::new("psk_id".into(), KeyPath::new_single("psk_id")).unique(true)), ) .add_object_store( - ObjectStoreBuilder::new(MlsKeyPackage::COLLECTION_NAME) + ObjectStoreBuilder::new(StoredKeypackage::COLLECTION_NAME) .auto_increment(false) .add_index( IndexBuilder::new("keypackage_ref".into(), KeyPath::new_single("keypackage_ref")).unique(true), @@ -71,7 +72,7 @@ pub(super) fn get_builder(name: &str) -> Metabuilder { .add_index(IndexBuilder::new("id".into(), KeyPath::new_single("id"))), ) .add_object_store( - ObjectStoreBuilder::new(E2eiEnrollment::COLLECTION_NAME) + ObjectStoreBuilder::new(StoredE2eiEnrollment::COLLECTION_NAME) .auto_increment(false) .add_index(IndexBuilder::new("id".into(), KeyPath::new_single("id")).unique(true)), ) diff --git a/keystore/src/connection/platform/wasm/migrations/v3.rs b/keystore/src/connection/platform/wasm/migrations/v3.rs index ae4ad19cab..8310d4957d 100644 --- a/keystore/src/connection/platform/wasm/migrations/v3.rs +++ b/keystore/src/connection/platform/wasm/migrations/v3.rs @@ -6,7 +6,7 @@ use idb::{ use super::{DB_VERSION_3, Metabuilder}; use crate::{ CryptoKeystoreResult, - entities::{EntityBase as _, MlsBufferedCommit}, + entities::{EntityBase as _, StoredBufferedCommit}, }; /// Open IDB once with the new builder and close it, this will add the new object store. @@ -17,11 +17,11 @@ pub(super) async fn migrate(name: &str) -> CryptoKeystoreResult { Ok(version) } -/// Add a new object store for the MlsBufferedCommit struct. +/// Add a new object store for the StoredBufferedCommit struct. pub(super) fn get_builder(name: &str) -> Metabuilder { let previous_builder = super::v2::get_builder(name); previous_builder.version(DB_VERSION_3).add_object_store( - ObjectStoreBuilder::new(MlsBufferedCommit::COLLECTION_NAME) + ObjectStoreBuilder::new(StoredBufferedCommit::COLLECTION_NAME) .auto_increment(false) .add_index( IndexBuilder::new("conversation_id".into(), KeyPath::new_single("conversation_id")).unique(true), diff --git a/keystore/src/connection/platform/wasm/mod.rs b/keystore/src/connection/platform/wasm/mod.rs index b5f72456ec..08ecfd6f50 100644 --- a/keystore/src/connection/platform/wasm/mod.rs +++ b/keystore/src/connection/platform/wasm/mod.rs @@ -8,10 +8,10 @@ use crate::{ DatabaseConnection, DatabaseConnectionRequirements, DatabaseKey, platform::wasm::migrations::open_and_migrate, }, entities::{ - E2eiAcmeCA, E2eiCrl, E2eiEnrollment, E2eiIntermediateCert, Entity as _, EntityBase as _, MlsCredential, - MlsEncryptionKeyPair, MlsEpochEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage, MlsPendingMessage, - MlsPskBundle, MlsSignatureKeyPair, PersistedMlsGroup, PersistedMlsPendingGroup, ProteusIdentity, ProteusPrekey, - ProteusSession, + E2eiAcmeCA, E2eiCrl, E2eiIntermediateCert, Entity as _, EntityBase as _, MlsPendingMessage, PersistedMlsGroup, + PersistedMlsPendingGroup, ProteusIdentity, ProteusPrekey, ProteusSession, StoredCredential, + StoredE2eiEnrollment, StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, + StoredKeypackage, StoredPskBundle, StoredSignatureKeypair, }, }; @@ -87,17 +87,17 @@ impl<'a> DatabaseConnection<'a> for WasmConnection { old_cipher, new_cipher, [ - MlsCredential, - MlsSignatureKeyPair, - MlsHpkePrivateKey, - MlsEncryptionKeyPair, - MlsEpochEncryptionKeyPair, - MlsPskBundle, - MlsKeyPackage, + StoredCredential, + StoredSignatureKeypair, + StoredHpkePrivateKey, + StoredEncryptionKeyPair, + StoredEpochEncryptionKeypair, + StoredPskBundle, + StoredKeypackage, PersistedMlsGroup, PersistedMlsPendingGroup, MlsPendingMessage, - E2eiEnrollment, + StoredE2eiEnrollment, E2eiAcmeCA, E2eiIntermediateCert, E2eiCrl, diff --git a/keystore/src/entities/mls.rs b/keystore/src/entities/mls.rs index 49a8b8db1a..3049227d7a 100644 --- a/keystore/src/entities/mls.rs +++ b/keystore/src/entities/mls.rs @@ -95,7 +95,8 @@ pub struct MlsPendingMessage { serde::Serialize, serde::Deserialize, )] -pub struct MlsBufferedCommit { +#[entity(collection_name = "mls_buffered_commits")] +pub struct StoredBufferedCommit { // we'd ideally just call this field `conversation_id`, but as of right now the // Entity macro does not yet support id columns not named `id` #[id(hex, column = "conversation_id_hex")] @@ -104,7 +105,7 @@ pub struct MlsBufferedCommit { commit_data: Vec, } -impl MlsBufferedCommit { +impl StoredBufferedCommit { /// Create a new `Self` from conversation id and the commit data. pub fn new(conversation_id: Vec, commit_data: Vec) -> Self { Self { @@ -129,7 +130,7 @@ impl MlsBufferedCommit { /// Entity representing a persisted `Credential` #[derive(core_crypto_macros::Debug, Clone, PartialEq, Eq, Zeroize, serde::Serialize, serde::Deserialize)] #[zeroize(drop)] -pub struct MlsCredential { +pub struct StoredCredential { #[sensitive] pub id: Vec, #[sensitive] @@ -146,7 +147,7 @@ pub trait MlsCredentialExt: Entity { /// Entity representing a persisted `SignatureKeyPair` #[derive(core_crypto_macros::Debug, Clone, PartialEq, Eq, Zeroize, serde::Serialize, serde::Deserialize)] #[zeroize(drop)] -pub struct MlsSignatureKeyPair { +pub struct StoredSignatureKeypair { pub signature_scheme: u16, #[sensitive] pub pk: Vec, @@ -156,7 +157,7 @@ pub struct MlsSignatureKeyPair { pub credential_id: Vec, } -impl MlsSignatureKeyPair { +impl StoredSignatureKeypair { pub fn new(signature_scheme: SignatureScheme, pk: Vec, keypair: Vec, credential_id: Vec) -> Self { Self { signature_scheme: signature_scheme as u16, @@ -171,7 +172,7 @@ impl MlsSignatureKeyPair { #[derive(core_crypto_macros::Debug, Clone, PartialEq, Eq, Zeroize, serde::Serialize, serde::Deserialize)] #[zeroize(drop)] #[sensitive] -pub struct MlsHpkePrivateKey { +pub struct StoredHpkePrivateKey { pub sk: Vec, pub pk: Vec, } @@ -180,12 +181,12 @@ pub struct MlsHpkePrivateKey { #[derive(core_crypto_macros::Debug, Clone, PartialEq, Eq, Zeroize, serde::Serialize, serde::Deserialize)] #[zeroize(drop)] #[sensitive] -pub struct MlsEncryptionKeyPair { +pub struct StoredEncryptionKeyPair { pub sk: Vec, pub pk: Vec, } -/// Entity representing a list of [MlsEncryptionKeyPair] +/// Entity representing a list of [StoredEncryptionKeyPair] #[derive( core_crypto_macros::Debug, Clone, @@ -198,7 +199,7 @@ pub struct MlsEncryptionKeyPair { )] #[zeroize(drop)] #[entity(collection_name = "mls_epoch_encryption_keypairs")] -pub struct MlsEpochEncryptionKeyPair { +pub struct StoredEpochEncryptionKeypair { #[id(hex, column = "id_hex")] pub id: Vec, #[sensitive] @@ -209,7 +210,7 @@ pub struct MlsEpochEncryptionKeyPair { #[derive(core_crypto_macros::Debug, Clone, PartialEq, Eq, Zeroize, serde::Serialize, serde::Deserialize)] #[zeroize(drop)] #[sensitive] -pub struct MlsPskBundle { +pub struct StoredPskBundle { pub psk_id: Vec, pub psk: Vec, } @@ -227,7 +228,7 @@ pub struct MlsPskBundle { )] #[zeroize(drop)] #[entity(collection_name = "mls_keypackages")] -pub struct MlsKeyPackage { +pub struct StoredKeypackage { #[id(hex, column = "keypackage_ref_hex")] pub keypackage_ref: Vec, #[sensitive] @@ -248,7 +249,7 @@ pub struct MlsKeyPackage { )] #[zeroize(drop)] #[entity(collection_name = "e2ei_enrollment", no_upsert)] -pub struct E2eiEnrollment { +pub struct StoredE2eiEnrollment { pub id: Vec, pub content: Vec, } diff --git a/keystore/src/entities/platform/generic/mls/credential.rs b/keystore/src/entities/platform/generic/mls/credential.rs index 8d58da8f68..3011e7d9b6 100644 --- a/keystore/src/entities/platform/generic/mls/credential.rs +++ b/keystore/src/entities/platform/generic/mls/credential.rs @@ -3,16 +3,46 @@ use std::{ time::SystemTime, }; +use rusqlite::Transaction; + use crate::{ CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection, TransactionWrapper}, entities::{ - Entity, EntityBase, EntityFindParams, EntityTransactionExt, MlsCredential, MlsCredentialExt, StringEntityId, + Entity, EntityBase, EntityFindParams, EntityTransactionExt, MlsCredentialExt, StoredCredential, StringEntityId, }, }; +impl StoredCredential { + fn load(transaction: &Transaction<'_>, rowid: i64, created_at: u64) -> CryptoKeystoreResult { + let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, "mls_credentials", "id", rowid, true)?; + + let mut id = vec![]; + blob.read_to_end(&mut id)?; + blob.close()?; + + let mut blob = transaction.blob_open( + rusqlite::DatabaseName::Main, + "mls_credentials", + "credential", + rowid, + true, + )?; + + let mut credential = vec![]; + blob.read_to_end(&mut credential)?; + blob.close()?; + + Ok(Self { + id, + credential, + created_at, + }) + } +} + #[async_trait::async_trait] -impl Entity for MlsCredential { +impl Entity for StoredCredential { fn id_raw(&self) -> &[u8] { self.id.as_slice() } @@ -33,40 +63,18 @@ impl Entity for MlsCredential { params.to_sql() ); - let mut stmt = transaction.prepare_cached(&query)?; - let mut rows = stmt.query_map([], |r| Ok((r.get(0)?, r.get(1)?)))?; - let entities = rows.try_fold(Vec::new(), |mut acc, rowid_result| { - use std::io::Read as _; - let (rowid, created_at) = rowid_result?; - - let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, "mls_credentials", "id", rowid, true)?; - - let mut id = vec![]; - blob.read_to_end(&mut id)?; - blob.close()?; - - let mut blob = transaction.blob_open( - rusqlite::DatabaseName::Main, - "mls_credentials", - "credential", - rowid, - true, - )?; - - let mut credential = vec![]; - blob.read_to_end(&mut credential)?; - blob.close()?; - - acc.push(Self { - id, - credential, - created_at, - }); - - crate::CryptoKeystoreResult::Ok(acc) - })?; - - Ok(entities) + transaction + .prepare_cached(&query)? + .query_map([], |row| { + let rowid = row.get(0)?; + let created_at = row.get(1)?; + Ok((rowid, created_at)) + })? + .map(|rowid_result| -> CryptoKeystoreResult<_> { + let (rowid, created_at) = rowid_result?; + Self::load(&transaction, rowid, created_at) + }) + .collect() } async fn find_one( @@ -76,35 +84,15 @@ impl Entity for MlsCredential { let mut conn = conn.conn().await; let transaction = conn.transaction()?; use rusqlite::OptionalExtension as _; - let maybe_rowid = transaction + transaction .query_row( "SELECT rowid, unixepoch(created_at) FROM mls_credentials WHERE id = ?", [id.as_slice()], |r| Ok((r.get::<_, i64>(0)?, r.get(1)?)), ) - .optional()?; - - if let Some((rowid, created_at)) = maybe_rowid { - let mut blob = transaction.blob_open( - rusqlite::DatabaseName::Main, - "mls_credentials", - "credential", - rowid, - true, - )?; - - let mut credential = Vec::with_capacity(blob.len()); - blob.read_to_end(&mut credential)?; - blob.close()?; - - Ok(Some(Self { - id: id.to_bytes(), - credential, - created_at, - })) - } else { - Ok(None) - } + .optional()? + .map(|(rowid, created_at)| Self::load(&transaction, rowid, created_at)) + .transpose() } async fn count(conn: &mut Self::ConnectionType) -> crate::CryptoKeystoreResult { @@ -115,22 +103,22 @@ impl Entity for MlsCredential { } #[async_trait::async_trait] -impl EntityBase for MlsCredential { +impl EntityBase for StoredCredential { type ConnectionType = KeystoreDatabaseConnection; type AutoGeneratedFields = u64; const COLLECTION_NAME: &'static str = "mls_credentials"; fn to_missing_key_err_kind() -> MissingKeyErrorKind { - MissingKeyErrorKind::MlsCredential + MissingKeyErrorKind::StoredCredential } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::MlsCredential(self) + crate::transaction::dynamic_dispatch::Entity::StoredCredential(self) } } #[async_trait::async_trait] -impl EntityTransactionExt for MlsCredential { +impl EntityTransactionExt for StoredCredential { async fn save(&self, transaction: &TransactionWrapper<'_>) -> crate::CryptoKeystoreResult<()> { Self::ConnectionType::check_buffer_size(self.id.len())?; Self::ConnectionType::check_buffer_size(self.credential.len())?; @@ -189,7 +177,7 @@ impl EntityTransactionExt for MlsCredential { } #[async_trait::async_trait] -impl MlsCredentialExt for MlsCredential { +impl MlsCredentialExt for StoredCredential { async fn delete_by_credential( transaction: &TransactionWrapper<'_>, credential: Vec, diff --git a/keystore/src/entities/platform/generic/mls/encryption_keypair.rs b/keystore/src/entities/platform/generic/mls/encryption_keypair.rs index 6bbb6f2fec..fa0eb75b04 100644 --- a/keystore/src/entities/platform/generic/mls/encryption_keypair.rs +++ b/keystore/src/entities/platform/generic/mls/encryption_keypair.rs @@ -4,13 +4,13 @@ use crate::{ CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection, TransactionWrapper}, entities::{ - Entity, EntityBase, EntityFindParams, EntityIdStringExt, EntityTransactionExt, MlsEncryptionKeyPair, + Entity, EntityBase, EntityFindParams, EntityIdStringExt, EntityTransactionExt, StoredEncryptionKeyPair, StringEntityId, }, }; #[async_trait::async_trait] -impl Entity for MlsEncryptionKeyPair { +impl Entity for StoredEncryptionKeyPair { fn id_raw(&self) -> &[u8] { self.pk.as_slice() } @@ -115,13 +115,13 @@ impl Entity for MlsEncryptionKeyPair { } #[async_trait::async_trait] -impl EntityBase for MlsEncryptionKeyPair { +impl EntityBase for StoredEncryptionKeyPair { type ConnectionType = KeystoreDatabaseConnection; type AutoGeneratedFields = (); const COLLECTION_NAME: &'static str = "mls_encryption_keypairs"; fn to_missing_key_err_kind() -> MissingKeyErrorKind { - MissingKeyErrorKind::MlsEncryptionKeyPair + MissingKeyErrorKind::StoredEncryptionKeyPair } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { @@ -130,7 +130,7 @@ impl EntityBase for MlsEncryptionKeyPair { } #[async_trait::async_trait] -impl EntityTransactionExt for MlsEncryptionKeyPair { +impl EntityTransactionExt for StoredEncryptionKeyPair { async fn save(&self, transaction: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> { use rusqlite::ToSql as _; diff --git a/keystore/src/entities/platform/generic/mls/hpke_private_key.rs b/keystore/src/entities/platform/generic/mls/hpke_private_key.rs index 8788651448..a51bffe0bb 100644 --- a/keystore/src/entities/platform/generic/mls/hpke_private_key.rs +++ b/keystore/src/entities/platform/generic/mls/hpke_private_key.rs @@ -6,13 +6,13 @@ use crate::{ CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection, TransactionWrapper}, entities::{ - Entity, EntityBase, EntityFindParams, EntityIdStringExt, EntityTransactionExt, MlsHpkePrivateKey, + Entity, EntityBase, EntityFindParams, EntityIdStringExt, EntityTransactionExt, StoredHpkePrivateKey, StringEntityId, }, }; #[async_trait::async_trait] -impl Entity for MlsHpkePrivateKey { +impl Entity for StoredHpkePrivateKey { fn id_raw(&self) -> &[u8] { self.pk.as_slice() } @@ -97,13 +97,13 @@ impl Entity for MlsHpkePrivateKey { } #[async_trait::async_trait] -impl EntityBase for MlsHpkePrivateKey { +impl EntityBase for StoredHpkePrivateKey { type ConnectionType = KeystoreDatabaseConnection; type AutoGeneratedFields = (); const COLLECTION_NAME: &'static str = "mls_hpke_private_keys"; fn to_missing_key_err_kind() -> MissingKeyErrorKind { - MissingKeyErrorKind::MlsHpkePrivateKey + MissingKeyErrorKind::StoredHpkePrivateKey } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { @@ -112,7 +112,7 @@ impl EntityBase for MlsHpkePrivateKey { } #[async_trait::async_trait] -impl EntityTransactionExt for MlsHpkePrivateKey { +impl EntityTransactionExt for StoredHpkePrivateKey { async fn save(&self, transaction: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> { Self::ConnectionType::check_buffer_size(self.sk.len())?; Self::ConnectionType::check_buffer_size(self.pk.len())?; diff --git a/keystore/src/entities/platform/generic/mls/psk_bundle.rs b/keystore/src/entities/platform/generic/mls/psk_bundle.rs index b0e35b0449..7266857bf1 100644 --- a/keystore/src/entities/platform/generic/mls/psk_bundle.rs +++ b/keystore/src/entities/platform/generic/mls/psk_bundle.rs @@ -4,12 +4,12 @@ use crate::{ CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection, TransactionWrapper}, entities::{ - Entity, EntityBase, EntityFindParams, EntityIdStringExt, EntityTransactionExt, MlsPskBundle, StringEntityId, + Entity, EntityBase, EntityFindParams, EntityIdStringExt, EntityTransactionExt, StoredPskBundle, StringEntityId, }, }; #[async_trait::async_trait] -impl Entity for MlsPskBundle { +impl Entity for StoredPskBundle { fn id_raw(&self) -> &[u8] { self.psk_id.as_slice() } @@ -89,13 +89,13 @@ impl Entity for MlsPskBundle { } #[async_trait::async_trait] -impl EntityBase for MlsPskBundle { +impl EntityBase for StoredPskBundle { type ConnectionType = KeystoreDatabaseConnection; type AutoGeneratedFields = (); const COLLECTION_NAME: &'static str = "mls_psk_bundles"; fn to_missing_key_err_kind() -> MissingKeyErrorKind { - MissingKeyErrorKind::MlsPskBundle + MissingKeyErrorKind::StoredPskBundle } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { @@ -104,7 +104,7 @@ impl EntityBase for MlsPskBundle { } #[async_trait::async_trait] -impl EntityTransactionExt for MlsPskBundle { +impl EntityTransactionExt for StoredPskBundle { async fn save(&self, transaction: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> { use rusqlite::ToSql as _; Self::ConnectionType::check_buffer_size(self.psk.len())?; diff --git a/keystore/src/entities/platform/generic/mls/signature_keypair.rs b/keystore/src/entities/platform/generic/mls/signature_keypair.rs index 737e266250..a15380de6b 100644 --- a/keystore/src/entities/platform/generic/mls/signature_keypair.rs +++ b/keystore/src/entities/platform/generic/mls/signature_keypair.rs @@ -1,13 +1,62 @@ use std::io::{Read, Write}; +use rusqlite::Transaction; + use crate::{ CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection, TransactionWrapper}, - entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, MlsSignatureKeyPair, StringEntityId}, + entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StoredSignatureKeypair, StringEntityId}, }; +impl StoredSignatureKeypair { + fn load(transaction: &Transaction<'_>, rowid: i64, signature_scheme: u16) -> crate::CryptoKeystoreResult { + let mut blob = transaction.blob_open( + rusqlite::DatabaseName::Main, + "mls_signature_keypairs", + "keypair", + rowid, + true, + )?; + + let mut keypair = vec![]; + blob.read_to_end(&mut keypair)?; + blob.close()?; + + let mut blob = transaction.blob_open( + rusqlite::DatabaseName::Main, + "mls_signature_keypairs", + "pk", + rowid, + true, + )?; + + let mut pk = vec![]; + blob.read_to_end(&mut pk)?; + blob.close()?; + + let mut blob = transaction.blob_open( + rusqlite::DatabaseName::Main, + "mls_signature_keypairs", + "credential_id", + rowid, + true, + )?; + + let mut credential_id = vec![]; + blob.read_to_end(&mut credential_id)?; + blob.close()?; + + Ok(Self { + signature_scheme, + keypair, + pk, + credential_id, + }) + } +} + #[async_trait::async_trait] -impl Entity for MlsSignatureKeyPair { +impl Entity for StoredSignatureKeypair { fn id_raw(&self) -> &[u8] { self.pk.as_slice() } @@ -23,59 +72,18 @@ impl Entity for MlsSignatureKeyPair { params.to_sql() ); - let mut stmt = transaction.prepare_cached(&query)?; - let mut rows = stmt.query_map([], |r| Ok((r.get(0)?, r.get(1)?)))?; - let entities = rows.try_fold(Vec::new(), |mut acc, rowid_result| { - use std::io::Read as _; - let (rowid, signature_scheme) = rowid_result?; - - let mut blob = transaction.blob_open( - rusqlite::DatabaseName::Main, - "mls_signature_keypairs", - "keypair", - rowid, - true, - )?; - - let mut keypair = vec![]; - blob.read_to_end(&mut keypair)?; - blob.close()?; - - let mut blob = transaction.blob_open( - rusqlite::DatabaseName::Main, - "mls_signature_keypairs", - "pk", - rowid, - true, - )?; - - let mut pk = vec![]; - blob.read_to_end(&mut pk)?; - blob.close()?; - - let mut blob = transaction.blob_open( - rusqlite::DatabaseName::Main, - "mls_signature_keypairs", - "credential_id", - rowid, - true, - )?; - - let mut credential_id = vec![]; - blob.read_to_end(&mut credential_id)?; - blob.close()?; - - acc.push(Self { - signature_scheme, - keypair, - pk, - credential_id, - }); - - crate::CryptoKeystoreResult::Ok(acc) - })?; - - Ok(entities) + transaction + .prepare_cached(&query)? + .query_map([], |row| { + let rowid = row.get(0)?; + let signature_scheme = row.get(1)?; + Ok((rowid, signature_scheme)) + })? + .map(|rowid_result| -> crate::CryptoKeystoreResult<_> { + let (rowid, signature_scheme) = rowid_result?; + Self::load(&transaction, rowid, signature_scheme) + }) + .collect() } async fn find_one( @@ -85,60 +93,15 @@ impl Entity for MlsSignatureKeyPair { let mut conn = conn.conn().await; let transaction = conn.transaction()?; use rusqlite::OptionalExtension as _; - let maybe_rowid = transaction + transaction .query_row( "SELECT rowid, signature_scheme FROM mls_signature_keypairs WHERE pk = ?", [id.as_slice()], |r| Ok((r.get::<_, i64>(0)?, r.get(1)?)), ) - .optional()?; - - if let Some((rowid, signature_scheme)) = maybe_rowid { - let mut blob = transaction.blob_open( - rusqlite::DatabaseName::Main, - "mls_signature_keypairs", - "pk", - rowid, - true, - )?; - - let mut pk = Vec::with_capacity(blob.len()); - blob.read_to_end(&mut pk)?; - blob.close()?; - - let mut blob = transaction.blob_open( - rusqlite::DatabaseName::Main, - "mls_signature_keypairs", - "keypair", - rowid, - true, - )?; - - let mut keypair = Vec::with_capacity(blob.len()); - blob.read_to_end(&mut keypair)?; - blob.close()?; - - let mut blob = transaction.blob_open( - rusqlite::DatabaseName::Main, - "mls_signature_keypairs", - "credential_id", - rowid, - true, - )?; - - let mut credential_id = Vec::with_capacity(blob.len()); - blob.read_to_end(&mut credential_id)?; - blob.close()?; - - Ok(Some(Self { - signature_scheme, - pk, - keypair, - credential_id, - })) - } else { - Ok(None) - } + .optional()? + .map(|(rowid, signature_scheme)| Self::load(&transaction, rowid, signature_scheme)) + .transpose() } async fn count(conn: &mut Self::ConnectionType) -> crate::CryptoKeystoreResult { @@ -149,13 +112,13 @@ impl Entity for MlsSignatureKeyPair { } #[async_trait::async_trait] -impl EntityBase for MlsSignatureKeyPair { +impl EntityBase for StoredSignatureKeypair { type ConnectionType = KeystoreDatabaseConnection; type AutoGeneratedFields = (); const COLLECTION_NAME: &'static str = "mls_signature_keypairs"; fn to_missing_key_err_kind() -> MissingKeyErrorKind { - MissingKeyErrorKind::MlsSignatureKeyPair + MissingKeyErrorKind::StoredSignatureKeypair } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { @@ -164,7 +127,7 @@ impl EntityBase for MlsSignatureKeyPair { } #[async_trait::async_trait] -impl EntityTransactionExt for MlsSignatureKeyPair { +impl EntityTransactionExt for StoredSignatureKeypair { async fn save(&self, transaction: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> { Self::ConnectionType::check_buffer_size(self.keypair.len())?; Self::ConnectionType::check_buffer_size(self.pk.len())?; diff --git a/keystore/src/entities/platform/wasm/mls/credential.rs b/keystore/src/entities/platform/wasm/mls/credential.rs index 364a3e48a4..771f30aa3f 100644 --- a/keystore/src/entities/platform/wasm/mls/credential.rs +++ b/keystore/src/entities/platform/wasm/mls/credential.rs @@ -5,27 +5,27 @@ use crate::{ CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection, storage::WasmStorageTransaction}, entities::{ - Entity, EntityBase, EntityFindParams, EntityTransactionExt, MlsCredential, MlsCredentialExt, StringEntityId, + Entity, EntityBase, EntityFindParams, EntityTransactionExt, MlsCredentialExt, StoredCredential, StringEntityId, }, }; #[async_trait::async_trait(?Send)] -impl EntityBase for MlsCredential { +impl EntityBase for StoredCredential { type ConnectionType = KeystoreDatabaseConnection; type AutoGeneratedFields = u64; const COLLECTION_NAME: &'static str = "mls_credentials"; fn to_missing_key_err_kind() -> MissingKeyErrorKind { - MissingKeyErrorKind::MlsCredential + MissingKeyErrorKind::StoredCredential } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::MlsCredential(self) + crate::transaction::dynamic_dispatch::Entity::StoredCredential(self) } } #[async_trait::async_trait(?Send)] -impl EntityTransactionExt for MlsCredential { +impl EntityTransactionExt for StoredCredential { async fn pre_save<'a>(&'a mut self) -> CryptoKeystoreResult { let now = SystemTime::now(); let created_at = now @@ -38,7 +38,7 @@ impl EntityTransactionExt for MlsCredential { } #[async_trait::async_trait(?Send)] -impl Entity for MlsCredential { +impl Entity for StoredCredential { fn id_raw(&self) -> &[u8] { self.id.as_slice() } @@ -80,7 +80,7 @@ impl Entity for MlsCredential { } #[async_trait::async_trait(?Send)] -impl MlsCredentialExt for MlsCredential { +impl MlsCredentialExt for StoredCredential { async fn delete_by_credential( transaction: &WasmStorageTransaction<'_>, credential: Vec, @@ -100,7 +100,7 @@ impl MlsCredentialExt for MlsCredential { return Err(CryptoKeystoreError::NotFound(reason, value)); }; - let mut credential = serde_wasm_bindgen::from_value::(entity_raw)?; + let mut credential = serde_wasm_bindgen::from_value::(entity_raw)?; credential.decrypt(cipher)?; let id = JsValue::from(credential.id.clone()); diff --git a/keystore/src/entities/platform/wasm/mls/encryption_keypair.rs b/keystore/src/entities/platform/wasm/mls/encryption_keypair.rs index 0a4d999223..54c0ab63dd 100644 --- a/keystore/src/entities/platform/wasm/mls/encryption_keypair.rs +++ b/keystore/src/entities/platform/wasm/mls/encryption_keypair.rs @@ -1,17 +1,17 @@ use crate::{ CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection}, - entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, MlsEncryptionKeyPair, StringEntityId}, + entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StoredEncryptionKeyPair, StringEntityId}, }; #[async_trait::async_trait(?Send)] -impl EntityBase for MlsEncryptionKeyPair { +impl EntityBase for StoredEncryptionKeyPair { type ConnectionType = KeystoreDatabaseConnection; type AutoGeneratedFields = (); const COLLECTION_NAME: &'static str = "mls_encryption_keypairs"; fn to_missing_key_err_kind() -> MissingKeyErrorKind { - MissingKeyErrorKind::MlsEncryptionKeyPair + MissingKeyErrorKind::StoredEncryptionKeyPair } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { @@ -20,10 +20,10 @@ impl EntityBase for MlsEncryptionKeyPair { } #[async_trait::async_trait(?Send)] -impl EntityTransactionExt for MlsEncryptionKeyPair {} +impl EntityTransactionExt for StoredEncryptionKeyPair {} #[async_trait::async_trait(?Send)] -impl Entity for MlsEncryptionKeyPair { +impl Entity for StoredEncryptionKeyPair { fn id_raw(&self) -> &[u8] { self.pk.as_slice() } diff --git a/keystore/src/entities/platform/wasm/mls/hpke_private_key.rs b/keystore/src/entities/platform/wasm/mls/hpke_private_key.rs index 9135644cf0..9664db0b86 100644 --- a/keystore/src/entities/platform/wasm/mls/hpke_private_key.rs +++ b/keystore/src/entities/platform/wasm/mls/hpke_private_key.rs @@ -1,17 +1,17 @@ use crate::{ CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection}, - entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, MlsHpkePrivateKey, StringEntityId}, + entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StoredHpkePrivateKey, StringEntityId}, }; #[async_trait::async_trait(?Send)] -impl EntityBase for MlsHpkePrivateKey { +impl EntityBase for StoredHpkePrivateKey { type ConnectionType = KeystoreDatabaseConnection; type AutoGeneratedFields = (); const COLLECTION_NAME: &'static str = "mls_hpke_private_keys"; fn to_missing_key_err_kind() -> MissingKeyErrorKind { - MissingKeyErrorKind::MlsHpkePrivateKey + MissingKeyErrorKind::StoredHpkePrivateKey } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { @@ -20,10 +20,10 @@ impl EntityBase for MlsHpkePrivateKey { } #[async_trait::async_trait(?Send)] -impl EntityTransactionExt for MlsHpkePrivateKey {} +impl EntityTransactionExt for StoredHpkePrivateKey {} #[async_trait::async_trait(?Send)] -impl Entity for MlsHpkePrivateKey { +impl Entity for StoredHpkePrivateKey { fn id_raw(&self) -> &[u8] { self.pk.as_slice() } diff --git a/keystore/src/entities/platform/wasm/mls/psk_bundle.rs b/keystore/src/entities/platform/wasm/mls/psk_bundle.rs index 577a3bcf22..655c3c9450 100644 --- a/keystore/src/entities/platform/wasm/mls/psk_bundle.rs +++ b/keystore/src/entities/platform/wasm/mls/psk_bundle.rs @@ -1,17 +1,17 @@ use crate::{ CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection}, - entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, MlsPskBundle, StringEntityId}, + entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StoredPskBundle, StringEntityId}, }; #[async_trait::async_trait(?Send)] -impl EntityBase for MlsPskBundle { +impl EntityBase for StoredPskBundle { type ConnectionType = KeystoreDatabaseConnection; type AutoGeneratedFields = (); const COLLECTION_NAME: &'static str = "mls_psk_bundles"; fn to_missing_key_err_kind() -> MissingKeyErrorKind { - MissingKeyErrorKind::MlsPskBundle + MissingKeyErrorKind::StoredPskBundle } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { @@ -20,10 +20,10 @@ impl EntityBase for MlsPskBundle { } #[async_trait::async_trait(?Send)] -impl EntityTransactionExt for MlsPskBundle {} +impl EntityTransactionExt for StoredPskBundle {} #[async_trait::async_trait(?Send)] -impl Entity for MlsPskBundle { +impl Entity for StoredPskBundle { fn id_raw(&self) -> &[u8] { self.psk_id.as_slice() } diff --git a/keystore/src/entities/platform/wasm/mls/signature_keypair.rs b/keystore/src/entities/platform/wasm/mls/signature_keypair.rs index 852826f201..a427dc3a9c 100644 --- a/keystore/src/entities/platform/wasm/mls/signature_keypair.rs +++ b/keystore/src/entities/platform/wasm/mls/signature_keypair.rs @@ -1,17 +1,17 @@ use crate::{ CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection}, - entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, MlsSignatureKeyPair, StringEntityId}, + entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StoredSignatureKeypair, StringEntityId}, }; #[async_trait::async_trait(?Send)] -impl EntityBase for MlsSignatureKeyPair { +impl EntityBase for StoredSignatureKeypair { type ConnectionType = KeystoreDatabaseConnection; type AutoGeneratedFields = (); const COLLECTION_NAME: &'static str = "mls_signature_keypairs"; fn to_missing_key_err_kind() -> MissingKeyErrorKind { - MissingKeyErrorKind::MlsSignatureKeyPair + MissingKeyErrorKind::StoredSignatureKeypair } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { @@ -20,10 +20,10 @@ impl EntityBase for MlsSignatureKeyPair { } #[async_trait::async_trait(?Send)] -impl EntityTransactionExt for MlsSignatureKeyPair {} +impl EntityTransactionExt for StoredSignatureKeypair {} #[async_trait::async_trait(?Send)] -impl Entity for MlsSignatureKeyPair { +impl Entity for StoredSignatureKeypair { fn id_raw(&self) -> &[u8] { self.pk.as_slice() } diff --git a/keystore/src/error.rs b/keystore/src/error.rs index 2f47ed831a..2ef2ee5b5c 100644 --- a/keystore/src/error.rs +++ b/keystore/src/error.rs @@ -4,21 +4,21 @@ pub enum MissingKeyErrorKind { #[error("Consumer Data")] ConsumerData, #[error("MLS KeyPackage")] - MlsKeyPackage, + StoredKeypackage, #[error("MLS SignatureKeyPair")] - MlsSignatureKeyPair, + StoredSignatureKeypair, #[error("MLS HpkePrivateKey")] - MlsHpkePrivateKey, + StoredHpkePrivateKey, #[error("MLS EncryptionKeyPair")] - MlsEncryptionKeyPair, + StoredEncryptionKeyPair, #[error("MLS Epoch EncryptionKeyPair")] - MlsEpochEncryptionKeyPair, + StoredEpochEncryptionKeypair, #[error("MLS PreSharedKeyBundle")] - MlsPskBundle, - #[error("MLS CredentialBundle")] - MlsCredential, + StoredPskBundle, + #[error("MLS Credential")] + StoredCredential, #[error("MLS Buffered Commit")] - MlsBufferedCommit, + StoredBufferedCommit, #[error("MLS Persisted Group")] PersistedMlsGroup, #[error("MLS Persisted Pending Group")] @@ -26,7 +26,7 @@ pub enum MissingKeyErrorKind { #[error("MLS Pending Messages")] MlsPendingMessages, #[error("End-to-end identity enrollment")] - E2eiEnrollment, + StoredE2eiEnrollment, #[error("OIDC refresh token")] E2eiRefreshToken, #[error("End-to-end identity root trust anchor CA cert")] diff --git a/keystore/src/mls.rs b/keystore/src/mls.rs index 4af1d55eb3..f8476cfeca 100644 --- a/keystore/src/mls.rs +++ b/keystore/src/mls.rs @@ -5,8 +5,8 @@ use crate::{ CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind, connection::FetchFromDatabase, entities::{ - E2eiEnrollment, EntityFindParams, MlsEncryptionKeyPair, MlsEpochEncryptionKeyPair, MlsHpkePrivateKey, - MlsKeyPackage, MlsPskBundle, MlsSignatureKeyPair, PersistedMlsGroup, PersistedMlsPendingGroup, + EntityFindParams, PersistedMlsGroup, PersistedMlsPendingGroup, StoredE2eiEnrollment, StoredEncryptionKeyPair, + StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, StoredSignatureKeypair, }, }; @@ -123,15 +123,9 @@ pub trait CryptoKeystoreMls: Sized { #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] impl CryptoKeystoreMls for crate::Database { async fn mls_fetch_keypackages(&self, count: u32) -> CryptoKeystoreResult> { - cfg_if::cfg_if! { - if #[cfg(not(target_family = "wasm"))] { - let reverse = true; - } else { - let reverse = false; - } - } + let reverse = !cfg!(target_family = "wasm"); let keypackages = self - .find_all::(EntityFindParams { + .find_all::(EntityFindParams { limit: Some(count), offset: None, reverse, @@ -214,7 +208,7 @@ impl CryptoKeystoreMls for crate::Database { } async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> { - self.save(E2eiEnrollment { + self.save(StoredE2eiEnrollment { id: id.into(), content: content.into(), }) @@ -225,12 +219,12 @@ impl CryptoKeystoreMls for crate::Database { async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult> { // someone who has time could try to optimize this but honestly it's really on the cold path let enrollment = self - .find::(id) + .find::(id) .await? .ok_or(CryptoKeystoreError::MissingKeyInStore( - MissingKeyErrorKind::E2eiEnrollment, + MissingKeyErrorKind::StoredE2eiEnrollment, ))?; - self.remove::(id).await?; + self.remove::(id).await?; Ok(enrollment.content.clone()) } } @@ -275,7 +269,7 @@ impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database // Having an empty credential id seems tolerable, since the SignatureKeyPair type is retrieved from the key store via its public key. let credential_id = vec![]; - let kp = MlsSignatureKeyPair::new( + let kp = StoredSignatureKeypair::new( concrete_signature_keypair.signature_scheme(), k.into(), data, @@ -284,29 +278,29 @@ impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database self.save(kp).await?; } MlsEntityId::KeyPackage => { - let kp = MlsKeyPackage { + let kp = StoredKeypackage { keypackage_ref: k.into(), keypackage: data, }; self.save(kp).await?; } MlsEntityId::HpkePrivateKey => { - let kp = MlsHpkePrivateKey { pk: k.into(), sk: data }; + let kp = StoredHpkePrivateKey { pk: k.into(), sk: data }; self.save(kp).await?; } MlsEntityId::PskBundle => { - let kp = MlsPskBundle { + let kp = StoredPskBundle { psk_id: k.into(), psk: data, }; self.save(kp).await?; } MlsEntityId::EncryptionKeyPair => { - let kp = MlsEncryptionKeyPair { pk: k.into(), sk: data }; + let kp = StoredEncryptionKeyPair { pk: k.into(), sk: data }; self.save(kp).await?; } MlsEntityId::EpochEncryptionKeyPair => { - let kp = MlsEpochEncryptionKeyPair { + let kp = StoredEpochEncryptionKeypair { id: k.into(), keypairs: data, }; @@ -331,27 +325,27 @@ impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database deser(&group.state).ok() } MlsEntityId::SignatureKeyPair => { - let sig: MlsSignatureKeyPair = self.find(k).await.ok().flatten()?; + let sig: StoredSignatureKeypair = self.find(k).await.ok().flatten()?; deser(&sig.keypair).ok() } MlsEntityId::KeyPackage => { - let kp: MlsKeyPackage = self.find(k).await.ok().flatten()?; + let kp: StoredKeypackage = self.find(k).await.ok().flatten()?; deser(&kp.keypackage).ok() } MlsEntityId::HpkePrivateKey => { - let hpke_pk: MlsHpkePrivateKey = self.find(k).await.ok().flatten()?; + let hpke_pk: StoredHpkePrivateKey = self.find(k).await.ok().flatten()?; deser(&hpke_pk.sk).ok() } MlsEntityId::PskBundle => { - let psk_bundle: MlsPskBundle = self.find(k).await.ok().flatten()?; + let psk_bundle: StoredPskBundle = self.find(k).await.ok().flatten()?; deser(&psk_bundle.psk).ok() } MlsEntityId::EncryptionKeyPair => { - let kp: MlsEncryptionKeyPair = self.find(k).await.ok().flatten()?; + let kp: StoredEncryptionKeyPair = self.find(k).await.ok().flatten()?; deser(&kp.sk).ok() } MlsEntityId::EpochEncryptionKeyPair => { - let kp: MlsEpochEncryptionKeyPair = self.find(k).await.ok().flatten()?; + let kp: StoredEpochEncryptionKeypair = self.find(k).await.ok().flatten()?; deser(&kp.keypairs).ok() } } @@ -360,12 +354,12 @@ impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database async fn delete(&self, k: &[u8]) -> Result<(), Self::Error> { match V::ID { MlsEntityId::GroupState => self.remove::(k).await?, - MlsEntityId::SignatureKeyPair => self.remove::(k).await?, - MlsEntityId::HpkePrivateKey => self.remove::(k).await?, - MlsEntityId::KeyPackage => self.remove::(k).await?, - MlsEntityId::PskBundle => self.remove::(k).await?, - MlsEntityId::EncryptionKeyPair => self.remove::(k).await?, - MlsEntityId::EpochEncryptionKeyPair => self.remove::(k).await?, + MlsEntityId::SignatureKeyPair => self.remove::(k).await?, + MlsEntityId::HpkePrivateKey => self.remove::(k).await?, + MlsEntityId::KeyPackage => self.remove::(k).await?, + MlsEntityId::PskBundle => self.remove::(k).await?, + MlsEntityId::EncryptionKeyPair => self.remove::(k).await?, + MlsEntityId::EpochEncryptionKeyPair => self.remove::(k).await?, } Ok(()) diff --git a/keystore/src/transaction/dynamic_dispatch.rs b/keystore/src/transaction/dynamic_dispatch.rs index ec277b6db4..09672a47a8 100644 --- a/keystore/src/transaction/dynamic_dispatch.rs +++ b/keystore/src/transaction/dynamic_dispatch.rs @@ -9,28 +9,28 @@ use crate::{ CryptoKeystoreError, CryptoKeystoreResult, connection::TransactionWrapper, entities::{ - ConsumerData, E2eiAcmeCA, E2eiCrl, E2eiEnrollment, E2eiIntermediateCert, EntityBase, EntityTransactionExt, - MlsBufferedCommit, MlsCredential, MlsEncryptionKeyPair, MlsEpochEncryptionKeyPair, MlsHpkePrivateKey, - MlsKeyPackage, MlsPendingMessage, MlsPskBundle, MlsSignatureKeyPair, PersistedMlsGroup, - PersistedMlsPendingGroup, StringEntityId, UniqueEntity, + ConsumerData, E2eiAcmeCA, E2eiCrl, E2eiIntermediateCert, EntityBase, EntityTransactionExt, MlsPendingMessage, + PersistedMlsGroup, PersistedMlsPendingGroup, StoredBufferedCommit, StoredCredential, StoredE2eiEnrollment, + StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, + StoredSignatureKeypair, StringEntityId, UniqueEntity, }, }; #[derive(Debug)] pub enum Entity { ConsumerData(ConsumerData), - SignatureKeyPair(MlsSignatureKeyPair), - HpkePrivateKey(MlsHpkePrivateKey), - MlsKeyPackage(MlsKeyPackage), - PskBundle(MlsPskBundle), - EncryptionKeyPair(MlsEncryptionKeyPair), - MlsEpochEncryptionKeyPair(MlsEpochEncryptionKeyPair), - MlsCredential(MlsCredential), - MlsBufferedCommit(MlsBufferedCommit), + SignatureKeyPair(StoredSignatureKeypair), + HpkePrivateKey(StoredHpkePrivateKey), + StoredKeypackage(StoredKeypackage), + PskBundle(StoredPskBundle), + EncryptionKeyPair(StoredEncryptionKeyPair), + StoredEpochEncryptionKeypair(StoredEpochEncryptionKeypair), + StoredCredential(StoredCredential), + StoredBufferedCommit(StoredBufferedCommit), PersistedMlsGroup(PersistedMlsGroup), PersistedMlsPendingGroup(PersistedMlsPendingGroup), MlsPendingMessage(MlsPendingMessage), - E2eiEnrollment(E2eiEnrollment), + StoredE2eiEnrollment(StoredE2eiEnrollment), #[cfg(target_family = "wasm")] E2eiRefreshToken(E2eiRefreshToken), E2eiAcmeCA(E2eiAcmeCA), @@ -52,12 +52,12 @@ pub enum EntityId { PskBundle(Vec), EncryptionKeyPair(Vec), EpochEncryptionKeyPair(Vec), - MlsCredential(Vec), - MlsBufferedCommit(Vec), + StoredCredential(Vec), + StoredBufferedCommit(Vec), PersistedMlsGroup(Vec), PersistedMlsPendingGroup(Vec), MlsPendingMessage(Vec), - E2eiEnrollment(Vec), + StoredE2eiEnrollment(Vec), #[cfg(target_family = "wasm")] E2eiRefreshToken(Vec), E2eiAcmeCA(Vec), @@ -80,12 +80,12 @@ impl EntityId { EntityId::PskBundle(vec) => vec.as_slice().into(), EntityId::EncryptionKeyPair(vec) => vec.as_slice().into(), EntityId::EpochEncryptionKeyPair(vec) => vec.as_slice().into(), - EntityId::MlsCredential(vec) => vec.as_slice().into(), - EntityId::MlsBufferedCommit(vec) => vec.as_slice().into(), + EntityId::StoredCredential(vec) => vec.as_slice().into(), + EntityId::StoredBufferedCommit(vec) => vec.as_slice().into(), EntityId::PersistedMlsGroup(vec) => vec.as_slice().into(), EntityId::PersistedMlsPendingGroup(vec) => vec.as_slice().into(), EntityId::MlsPendingMessage(vec) => vec.as_slice().into(), - EntityId::E2eiEnrollment(vec) => vec.as_slice().into(), + EntityId::StoredE2eiEnrollment(vec) => vec.as_slice().into(), #[cfg(target_family = "wasm")] EntityId::E2eiRefreshToken(vec) => vec.as_slice().into(), EntityId::E2eiAcmeCA(vec) => vec.as_slice().into(), @@ -102,18 +102,18 @@ impl EntityId { pub(crate) fn from_collection_name(entity_id: &'static str, id: &[u8]) -> CryptoKeystoreResult { match entity_id { - MlsSignatureKeyPair::COLLECTION_NAME => Ok(Self::SignatureKeyPair(id.into())), - MlsHpkePrivateKey::COLLECTION_NAME => Ok(Self::HpkePrivateKey(id.into())), - MlsKeyPackage::COLLECTION_NAME => Ok(Self::KeyPackage(id.into())), - MlsPskBundle::COLLECTION_NAME => Ok(Self::PskBundle(id.into())), - MlsEncryptionKeyPair::COLLECTION_NAME => Ok(Self::EncryptionKeyPair(id.into())), - MlsEpochEncryptionKeyPair::COLLECTION_NAME => Ok(Self::EpochEncryptionKeyPair(id.into())), - MlsBufferedCommit::COLLECTION_NAME => Ok(Self::MlsBufferedCommit(id.into())), + StoredSignatureKeypair::COLLECTION_NAME => Ok(Self::SignatureKeyPair(id.into())), + StoredHpkePrivateKey::COLLECTION_NAME => Ok(Self::HpkePrivateKey(id.into())), + StoredKeypackage::COLLECTION_NAME => Ok(Self::KeyPackage(id.into())), + StoredPskBundle::COLLECTION_NAME => Ok(Self::PskBundle(id.into())), + StoredEncryptionKeyPair::COLLECTION_NAME => Ok(Self::EncryptionKeyPair(id.into())), + StoredEpochEncryptionKeypair::COLLECTION_NAME => Ok(Self::EpochEncryptionKeyPair(id.into())), + StoredBufferedCommit::COLLECTION_NAME => Ok(Self::StoredBufferedCommit(id.into())), PersistedMlsGroup::COLLECTION_NAME => Ok(Self::PersistedMlsGroup(id.into())), PersistedMlsPendingGroup::COLLECTION_NAME => Ok(Self::PersistedMlsPendingGroup(id.into())), - MlsCredential::COLLECTION_NAME => Ok(Self::MlsCredential(id.into())), + StoredCredential::COLLECTION_NAME => Ok(Self::StoredCredential(id.into())), MlsPendingMessage::COLLECTION_NAME => Ok(Self::MlsPendingMessage(id.into())), - E2eiEnrollment::COLLECTION_NAME => Ok(Self::E2eiEnrollment(id.into())), + StoredE2eiEnrollment::COLLECTION_NAME => Ok(Self::StoredE2eiEnrollment(id.into())), E2eiCrl::COLLECTION_NAME => Ok(Self::E2eiCrl(id.into())), E2eiAcmeCA::COLLECTION_NAME => Ok(Self::E2eiAcmeCA(id.into())), #[cfg(target_family = "wasm")] @@ -131,17 +131,17 @@ impl EntityId { pub(crate) fn collection_name(&self) -> &'static str { match self { - EntityId::SignatureKeyPair(_) => MlsSignatureKeyPair::COLLECTION_NAME, - EntityId::KeyPackage(_) => MlsKeyPackage::COLLECTION_NAME, - EntityId::PskBundle(_) => MlsPskBundle::COLLECTION_NAME, - EntityId::EncryptionKeyPair(_) => MlsEncryptionKeyPair::COLLECTION_NAME, - EntityId::EpochEncryptionKeyPair(_) => MlsEpochEncryptionKeyPair::COLLECTION_NAME, - EntityId::MlsCredential(_) => MlsCredential::COLLECTION_NAME, - EntityId::MlsBufferedCommit(_) => MlsBufferedCommit::COLLECTION_NAME, + EntityId::SignatureKeyPair(_) => StoredSignatureKeypair::COLLECTION_NAME, + EntityId::KeyPackage(_) => StoredKeypackage::COLLECTION_NAME, + EntityId::PskBundle(_) => StoredPskBundle::COLLECTION_NAME, + EntityId::EncryptionKeyPair(_) => StoredEncryptionKeyPair::COLLECTION_NAME, + EntityId::EpochEncryptionKeyPair(_) => StoredEpochEncryptionKeypair::COLLECTION_NAME, + EntityId::StoredCredential(_) => StoredCredential::COLLECTION_NAME, + EntityId::StoredBufferedCommit(_) => StoredBufferedCommit::COLLECTION_NAME, EntityId::PersistedMlsGroup(_) => PersistedMlsGroup::COLLECTION_NAME, EntityId::PersistedMlsPendingGroup(_) => PersistedMlsPendingGroup::COLLECTION_NAME, EntityId::MlsPendingMessage(_) => MlsPendingMessage::COLLECTION_NAME, - EntityId::E2eiEnrollment(_) => E2eiEnrollment::COLLECTION_NAME, + EntityId::StoredE2eiEnrollment(_) => StoredE2eiEnrollment::COLLECTION_NAME, #[cfg(target_family = "wasm")] EntityId::E2eiRefreshToken(_) => E2eiRefreshToken::COLLECTION_NAME, EntityId::E2eiAcmeCA(_) => E2eiAcmeCA::COLLECTION_NAME, @@ -153,7 +153,7 @@ impl EntityId { EntityId::ProteusPrekey(_) => ProteusPrekey::COLLECTION_NAME, #[cfg(feature = "proteus-keystore")] EntityId::ProteusSession(_) => ProteusSession::COLLECTION_NAME, - EntityId::HpkePrivateKey(_) => MlsHpkePrivateKey::COLLECTION_NAME, + EntityId::HpkePrivateKey(_) => StoredHpkePrivateKey::COLLECTION_NAME, } } } @@ -163,18 +163,18 @@ pub async fn execute_save(tx: &TransactionWrapper<'_>, entity: &Entity) -> Crypt Entity::ConsumerData(consumer_data) => consumer_data.replace(tx).await, Entity::SignatureKeyPair(mls_signature_key_pair) => mls_signature_key_pair.save(tx).await, Entity::HpkePrivateKey(mls_hpke_private_key) => mls_hpke_private_key.save(tx).await, - Entity::MlsKeyPackage(mls_key_package) => mls_key_package.save(tx).await, + Entity::StoredKeypackage(mls_key_package) => mls_key_package.save(tx).await, Entity::PskBundle(mls_psk_bundle) => mls_psk_bundle.save(tx).await, Entity::EncryptionKeyPair(mls_encryption_key_pair) => mls_encryption_key_pair.save(tx).await, - Entity::MlsEpochEncryptionKeyPair(mls_epoch_encryption_key_pair) => { + Entity::StoredEpochEncryptionKeypair(mls_epoch_encryption_key_pair) => { mls_epoch_encryption_key_pair.save(tx).await } - Entity::MlsCredential(mls_credential) => mls_credential.save(tx).await, - Entity::MlsBufferedCommit(mls_pending_commit) => mls_pending_commit.save(tx).await, + Entity::StoredCredential(mls_credential) => mls_credential.save(tx).await, + Entity::StoredBufferedCommit(mls_pending_commit) => mls_pending_commit.save(tx).await, Entity::PersistedMlsGroup(persisted_mls_group) => persisted_mls_group.save(tx).await, Entity::PersistedMlsPendingGroup(persisted_mls_pending_group) => persisted_mls_pending_group.save(tx).await, Entity::MlsPendingMessage(mls_pending_message) => mls_pending_message.save(tx).await, - Entity::E2eiEnrollment(e2ei_enrollment) => e2ei_enrollment.save(tx).await, + Entity::StoredE2eiEnrollment(e2ei_enrollment) => e2ei_enrollment.save(tx).await, #[cfg(target_family = "wasm")] Entity::E2eiRefreshToken(e2ei_refresh_token) => e2ei_refresh_token.replace(tx).await, Entity::E2eiAcmeCA(e2ei_acme_ca) => e2ei_acme_ca.replace(tx).await, @@ -191,18 +191,18 @@ pub async fn execute_save(tx: &TransactionWrapper<'_>, entity: &Entity) -> Crypt pub async fn execute_delete(tx: &TransactionWrapper<'_>, entity_id: &EntityId) -> CryptoKeystoreResult<()> { match entity_id { - id @ EntityId::SignatureKeyPair(_) => MlsSignatureKeyPair::delete(tx, id.as_id()).await, - id @ EntityId::HpkePrivateKey(_) => MlsHpkePrivateKey::delete(tx, id.as_id()).await, - id @ EntityId::KeyPackage(_) => MlsKeyPackage::delete(tx, id.as_id()).await, - id @ EntityId::PskBundle(_) => MlsPskBundle::delete(tx, id.as_id()).await, - id @ EntityId::EncryptionKeyPair(_) => MlsEncryptionKeyPair::delete(tx, id.as_id()).await, - id @ EntityId::EpochEncryptionKeyPair(_) => MlsEpochEncryptionKeyPair::delete(tx, id.as_id()).await, - id @ EntityId::MlsCredential(_) => MlsCredential::delete(tx, id.as_id()).await, - id @ EntityId::MlsBufferedCommit(_) => MlsBufferedCommit::delete(tx, id.as_id()).await, + id @ EntityId::SignatureKeyPair(_) => StoredSignatureKeypair::delete(tx, id.as_id()).await, + id @ EntityId::HpkePrivateKey(_) => StoredHpkePrivateKey::delete(tx, id.as_id()).await, + id @ EntityId::KeyPackage(_) => StoredKeypackage::delete(tx, id.as_id()).await, + id @ EntityId::PskBundle(_) => StoredPskBundle::delete(tx, id.as_id()).await, + id @ EntityId::EncryptionKeyPair(_) => StoredEncryptionKeyPair::delete(tx, id.as_id()).await, + id @ EntityId::EpochEncryptionKeyPair(_) => StoredEpochEncryptionKeypair::delete(tx, id.as_id()).await, + id @ EntityId::StoredCredential(_) => StoredCredential::delete(tx, id.as_id()).await, + id @ EntityId::StoredBufferedCommit(_) => StoredBufferedCommit::delete(tx, id.as_id()).await, id @ EntityId::PersistedMlsGroup(_) => PersistedMlsGroup::delete(tx, id.as_id()).await, id @ EntityId::PersistedMlsPendingGroup(_) => PersistedMlsPendingGroup::delete(tx, id.as_id()).await, id @ EntityId::MlsPendingMessage(_) => MlsPendingMessage::delete(tx, id.as_id()).await, - id @ EntityId::E2eiEnrollment(_) => E2eiEnrollment::delete(tx, id.as_id()).await, + id @ EntityId::StoredE2eiEnrollment(_) => StoredE2eiEnrollment::delete(tx, id.as_id()).await, #[cfg(target_family = "wasm")] id @ EntityId::E2eiRefreshToken(_) => E2eiRefreshToken::delete(tx, id.as_id()).await, id @ EntityId::E2eiAcmeCA(_) => E2eiAcmeCA::delete(tx, id.as_id()).await, diff --git a/keystore/src/transaction/mod.rs b/keystore/src/transaction/mod.rs index 9176ba2d46..ec74b846ed 100644 --- a/keystore/src/transaction/mod.rs +++ b/keystore/src/transaction/mod.rs @@ -54,7 +54,7 @@ impl KeystoreTransaction { let table = cache_guard.entry(E::COLLECTION_NAME.to_string()).or_default(); let serialized = postcard::to_stdvec(&entity)?; // Use merge_key() because `id_raw()` is not always unique for records. - // For `MlsCredential`, `id_raw()` is the `CLientId`. + // For `StoredCredential`, `id_raw()` is the `CLientId`. // For `MlsPendingMessage` it's the id of the group it belongs to. table.insert(entity.merge_key(), Zeroizing::new(serialized)); Ok(entity) @@ -104,7 +104,7 @@ impl KeystoreTransaction { pub(crate) async fn cred_delete_by_credential(&self, cred: Vec) -> CryptoKeystoreResult<()> { let mut cache_guard = self.cache.write().await; - if let Entry::Occupied(mut table) = cache_guard.entry(MlsCredential::COLLECTION_NAME.to_string()) { + if let Entry::Occupied(mut table) = cache_guard.entry(StoredCredential::COLLECTION_NAME.to_string()) { table.get_mut().retain(|_, value| **value != cred); } @@ -292,7 +292,7 @@ impl KeystoreTransaction { maybe_credential: &E, deleted_credentials: &[Vec], ) -> bool { - let Some(credential) = maybe_credential.downcast::() else { + let Some(credential) = maybe_credential.downcast::() else { return false; }; deleted_credentials.contains(&credential.credential) @@ -312,8 +312,8 @@ impl KeystoreTransaction { /// commit_transaction!( /// transaction, db, /// [ -/// (identifier_01, MlsCredential), -/// (identifier_02, MlsSignatureKeyPair), +/// (identifier_01, StoredCredential), +/// (identifier_02, StoredSignatureKeypair), /// ], /// ); /// @@ -322,8 +322,8 @@ impl KeystoreTransaction { /// commit_transaction!( /// transaction, db, /// [ -/// (identifier_01, MlsCredential), -/// (identifier_02, MlsSignatureKeyPair), +/// (identifier_01, StoredCredential), +/// (identifier_02, StoredSignatureKeypair), /// ], /// proteus_types: [ /// (identifier_03, ProteusPrekey), @@ -386,7 +386,7 @@ macro_rules! commit_transaction { } for deleted_credential in $keystore_transaction.deleted_credentials.read().await.iter() { - MlsCredential::delete_by_credential(&tx, deleted_credential.to_owned()).await?; + StoredCredential::delete_by_credential(&tx, deleted_credential.to_owned()).await?; } tx.commit_tx().await?; @@ -400,17 +400,17 @@ impl KeystoreTransaction { commit_transaction!( self, db, [ - (identifier_01, MlsCredential), - (identifier_02, MlsSignatureKeyPair), - (identifier_03, MlsHpkePrivateKey), - (identifier_04, MlsEncryptionKeyPair), - (identifier_05, MlsEpochEncryptionKeyPair), - (identifier_06, MlsPskBundle), - (identifier_07, MlsKeyPackage), + (identifier_01, StoredCredential), + (identifier_02, StoredSignatureKeypair), + (identifier_03, StoredHpkePrivateKey), + (identifier_04, StoredEncryptionKeyPair), + (identifier_05, StoredEpochEncryptionKeypair), + (identifier_06, StoredPskBundle), + (identifier_07, StoredKeypackage), (identifier_08, PersistedMlsGroup), (identifier_09, PersistedMlsPendingGroup), (identifier_10, MlsPendingMessage), - (identifier_11, E2eiEnrollment), + (identifier_11, StoredE2eiEnrollment), // (identifier_12, E2eiRefreshToken), (identifier_13, E2eiAcmeCA), (identifier_14, E2eiIntermediateCert), diff --git a/keystore/tests/mls.rs b/keystore/tests/mls.rs index 55be9cc83a..35ffe8c014 100644 --- a/keystore/tests/mls.rs +++ b/keystore/tests/mls.rs @@ -18,8 +18,8 @@ mod tests { use core_crypto_keystore::{ MissingKeyErrorKind, entities::{ - EntityBase, MlsCredential, MlsHpkePrivateKey, MlsKeyPackage, MlsPskBundle, MlsSignatureKeyPair, - PersistedMlsGroup, PersistedMlsPendingGroup, + EntityBase, PersistedMlsGroup, PersistedMlsPendingGroup, StoredCredential, StoredHpkePrivateKey, + StoredKeypackage, StoredPskBundle, StoredSignatureKeypair, }, }; use mls_crypto_provider::MlsCryptoProvider; @@ -30,13 +30,13 @@ mod tests { #[wasm_bindgen_test] fn mls_entities_have_correct_error_kinds() { assert_eq!( - MlsCredential::to_missing_key_err_kind(), - MissingKeyErrorKind::MlsCredential + StoredCredential::to_missing_key_err_kind(), + MissingKeyErrorKind::StoredCredential ); assert_eq!( - MlsKeyPackage::to_missing_key_err_kind(), - MissingKeyErrorKind::MlsKeyPackage + StoredKeypackage::to_missing_key_err_kind(), + MissingKeyErrorKind::StoredKeypackage ); assert_eq!( @@ -50,23 +50,23 @@ mod tests { ); assert_eq!( - MlsHpkePrivateKey::to_missing_key_err_kind(), - MissingKeyErrorKind::MlsHpkePrivateKey + StoredHpkePrivateKey::to_missing_key_err_kind(), + MissingKeyErrorKind::StoredHpkePrivateKey ); assert_eq!( - MlsSignatureKeyPair::to_missing_key_err_kind(), - MissingKeyErrorKind::MlsSignatureKeyPair + StoredSignatureKeypair::to_missing_key_err_kind(), + MissingKeyErrorKind::StoredSignatureKeypair ); assert_eq!( - MlsPskBundle::to_missing_key_err_kind(), - MissingKeyErrorKind::MlsPskBundle + StoredPskBundle::to_missing_key_err_kind(), + MissingKeyErrorKind::StoredPskBundle ); } #[apply(all_storage_types)] - pub async fn can_add_read_delete_credential_bundle_openmls_traits(context: KeystoreTestContext) { + pub async fn can_add_read_delete_credential_openmls_traits(context: KeystoreTestContext) { use core_crypto_keystore::connection::FetchFromDatabase; use itertools::Itertools as _; use openmls_basic_credential::SignatureKeyPair; @@ -82,7 +82,7 @@ mod tests { let credential_id: Vec = credential.identity().into(); - let store_credential = MlsCredential { + let store_credential = StoredCredential { id: credential_id.clone(), credential: credential.tls_serialize_detached().unwrap(), created_at: 0, @@ -96,7 +96,7 @@ mod tests { ) .unwrap(); - let store_keypair = MlsSignatureKeyPair::new( + let store_keypair = StoredSignatureKeypair::new( keypair.signature_scheme(), keypair.to_public_vec(), keypair.tls_serialize_detached().unwrap(), @@ -107,14 +107,14 @@ mod tests { let (credential_from_store,) = backend .key_store() - .find_all::(Default::default()) + .find_all::(Default::default()) .await .unwrap() .into_iter() .filter(|cred| cred.id == credential_id) .collect_tuple() .expect("credentials should be exactly one"); - let keypair2: MlsSignatureKeyPair = backend.key_store().find(keypair.public()).await.unwrap().unwrap(); + let keypair2: StoredSignatureKeypair = backend.key_store().find(keypair.public()).await.unwrap().unwrap(); assert_eq!(keypair2.credential_id, credential_from_store.id); @@ -132,7 +132,7 @@ mod tests { .unwrap(); backend .key_store() - .remove::(keypair.public()) + .remove::(keypair.public()) .await .unwrap(); } @@ -222,14 +222,14 @@ mod tests { // let key_string = key_id.as_hyphenated().to_string(); - // let entity = MlsKeyPackage { + // let entity = StoredKeypackage { // id: key_string.clone(), // key: keypackage_bundle.to_key_store_value().unwrap(), // }; // backend.key_store().save(entity).await.unwrap(); - // let entity2: MlsKeyPackage = backend.key_store().find(key_string.as_bytes()).await.unwrap().unwrap(); + // let entity2: StoredKeypackage = backend.key_store().find(key_string.as_bytes()).await.unwrap().unwrap(); // let bundle2 = KeyPackageBundle::from_key_store_value(&entity2.key).unwrap(); // let (b1_kp, (b1_sk, b1_ls)) = keypackage_bundle.into_parts(); @@ -240,7 +240,7 @@ mod tests { // backend // .key_store() - // .remove::(key_string.as_bytes()) + // .remove::(key_string.as_bytes()) // .await // .unwrap(); diff --git a/keystore/tests/z_entities.rs b/keystore/tests/z_entities.rs index 43c2670d2b..a7256bfd4a 100644 --- a/keystore/tests/z_entities.rs +++ b/keystore/tests/z_entities.rs @@ -46,7 +46,7 @@ macro_rules! test_for_entity { mod tests_impl { use core_crypto_keystore::{ connection::{FetchFromDatabase, KeystoreDatabaseConnection}, - entities::{Entity, EntityFindParams, EntityTransactionExt, MlsCredential, MlsPendingMessage}, + entities::{Entity, EntityFindParams, EntityTransactionExt, MlsPendingMessage, StoredCredential}, }; use super::common::*; @@ -76,8 +76,12 @@ mod tests_impl { .pop() .unwrap(); assert_eq!(*pending_message, pending_message_from_store); - } else if let Some(credential) = entity.downcast::() { - let mut credential_from_store = store.find::(&entity.merge_key()).await.unwrap().unwrap(); + } else if let Some(credential) = entity.downcast::() { + let mut credential_from_store = store + .find::(&entity.merge_key()) + .await + .unwrap() + .unwrap(); credential_from_store.equalize(); assert_eq!(*credential, credential_from_store); } else { @@ -162,16 +166,16 @@ mod tests { test_for_entity!(test_persisted_mls_group, PersistedMlsGroup); test_for_entity!(test_persisted_mls_pending_group, PersistedMlsPendingGroup); test_for_entity!(test_mls_pending_message, MlsPendingMessage ignore_update:true ignore_find_many:true); - test_for_entity!(test_mls_credential, MlsCredential ignore_update:true); - test_for_entity!(test_mls_keypackage, MlsKeyPackage); - test_for_entity!(test_mls_signature_keypair, MlsSignatureKeyPair ignore_update:true); - test_for_entity!(test_mls_psk_bundle, MlsPskBundle); - test_for_entity!(test_mls_encryption_keypair, MlsEncryptionKeyPair); - test_for_entity!(test_mls_epoch_encryption_keypair, MlsEpochEncryptionKeyPair); - test_for_entity!(test_mls_hpke_private_key, MlsHpkePrivateKey); + test_for_entity!(test_mls_credential, StoredCredential ignore_update:true); + test_for_entity!(test_mls_keypackage, StoredKeypackage); + test_for_entity!(test_mls_signature_keypair, StoredSignatureKeypair ignore_update:true); + test_for_entity!(test_mls_psk_bundle, StoredPskBundle); + test_for_entity!(test_mls_encryption_keypair, StoredEncryptionKeyPair); + test_for_entity!(test_mls_epoch_encryption_keypair, StoredEpochEncryptionKeypair); + test_for_entity!(test_mls_hpke_private_key, StoredHpkePrivateKey); test_for_entity!(test_e2ei_intermediate_cert, E2eiIntermediateCert); test_for_entity!(test_e2ei_crl, E2eiCrl); - test_for_entity!(test_e2ei_enrollment, E2eiEnrollment ignore_update:true); + test_for_entity!(test_e2ei_enrollment, StoredE2eiEnrollment ignore_update:true); cfg_if::cfg_if! { if #[cfg(feature = "proteus-keystore")] { @@ -189,7 +193,7 @@ mod tests { pub async fn update_e2ei_enrollment_emits_error(context: KeystoreTestContext) { let store = context.store(); - let mut entity = E2eiEnrollment::random(); + let mut entity = StoredE2eiEnrollment::random(); store.save(entity.clone()).await.unwrap(); store.commit_transaction().await.unwrap(); @@ -202,7 +206,7 @@ mod tests { assert!(matches!( error, - CryptoKeystoreError::AlreadyExists(E2eiEnrollment::COLLECTION_NAME) + CryptoKeystoreError::AlreadyExists(StoredE2eiEnrollment::COLLECTION_NAME) )); // It's required by cleanup to have a running transaction before finishing the test @@ -214,9 +218,9 @@ mod tests { #[cfg(test)] pub mod utils { use core_crypto_keystore::entities::{ - E2eiEnrollment, MlsCredential, MlsEncryptionKeyPair, MlsEpochEncryptionKeyPair, MlsHpkePrivateKey, - MlsKeyPackage, MlsPendingMessage, MlsPskBundle, MlsSignatureKeyPair, PersistedMlsGroup, - PersistedMlsPendingGroup, ProteusSession, + MlsPendingMessage, PersistedMlsGroup, PersistedMlsPendingGroup, ProteusSession, StoredCredential, + StoredE2eiEnrollment, StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, + StoredKeypackage, StoredPskBundle, StoredSignatureKeypair, }; use rand::Rng as _; @@ -324,17 +328,17 @@ pub mod utils { }; } - impl_entity_random_update_ext!(MlsKeyPackage, blob_fields=[keypackage,], additional_fields=[(keypackage_ref: uuid::Uuid::new_v4().hyphenated().to_string().into()),]); - impl_entity_random_update_ext!(MlsCredential, blob_fields=[credential,], additional_fields=[(id: uuid::Uuid::new_v4().hyphenated().to_string().into()),(created_at: 0; auto-generated:true),]); - impl_entity_random_update_ext!(MlsSignatureKeyPair, blob_fields=[pk,keypair,credential_id,], additional_fields=[(signature_scheme: rand::random()),]); - impl_entity_random_update_ext!(MlsHpkePrivateKey, blob_fields=[pk id_like:true,sk,]); - impl_entity_random_update_ext!(MlsEncryptionKeyPair, blob_fields=[pk id_like:true,sk,]); - impl_entity_random_update_ext!(MlsPskBundle, blob_fields=[psk,psk_id id_like:true,]); + impl_entity_random_update_ext!(StoredKeypackage, blob_fields=[keypackage,], additional_fields=[(keypackage_ref: uuid::Uuid::new_v4().hyphenated().to_string().into()),]); + impl_entity_random_update_ext!(StoredCredential, blob_fields=[credential,], additional_fields=[(id: uuid::Uuid::new_v4().hyphenated().to_string().into()),(created_at: 0; auto-generated:true),]); + impl_entity_random_update_ext!(StoredSignatureKeypair, blob_fields=[pk,keypair,credential_id,], additional_fields=[(signature_scheme: rand::random()),]); + impl_entity_random_update_ext!(StoredHpkePrivateKey, blob_fields=[pk id_like:true,sk,]); + impl_entity_random_update_ext!(StoredEncryptionKeyPair, blob_fields=[pk id_like:true,sk,]); + impl_entity_random_update_ext!(StoredPskBundle, blob_fields=[psk,psk_id id_like:true,]); impl_entity_random_update_ext!(PersistedMlsGroup, id_field=id, blob_fields=[state,], additional_fields=[(parent_id: None),]); impl_entity_random_update_ext!(PersistedMlsPendingGroup, id_field=id, blob_fields=[state,custom_configuration,], additional_fields=[(parent_id: None),]); impl_entity_random_update_ext!(MlsPendingMessage, id_field = foreign_id, blob_fields = [message,]); - impl_entity_random_update_ext!(E2eiEnrollment, id_field = id, blob_fields = [content,]); - impl_entity_random_update_ext!(MlsEpochEncryptionKeyPair, id_field = id, blob_fields = [keypairs,]); + impl_entity_random_update_ext!(StoredE2eiEnrollment, id_field = id, blob_fields = [content,]); + impl_entity_random_update_ext!(StoredEpochEncryptionKeypair, id_field = id, blob_fields = [keypairs,]); impl EntityRandomExt for core_crypto_keystore::entities::E2eiIntermediateCert { fn random() -> Self { diff --git a/mls-provider/src/crypto_provider.rs b/mls-provider/src/crypto_provider.rs index ee5f638849..137f0ee1d6 100644 --- a/mls-provider/src/crypto_provider.rs +++ b/mls-provider/src/crypto_provider.rs @@ -234,6 +234,7 @@ impl OpenMlsCrypto for RustCrypto { } } + /// Generate a `(secret key, public key)` pair from a signature scheme. fn signature_key_gen(&self, alg: SignatureScheme) -> Result<(Vec, Vec), CryptoError> { let mut rng = self.rng.write().map_err(|_| CryptoError::InsufficientRandomness)?; diff --git a/mls-provider/src/lib.rs b/mls-provider/src/lib.rs index 519c0e0b8a..e0f018f8b0 100644 --- a/mls-provider/src/lib.rs +++ b/mls-provider/src/lib.rs @@ -8,6 +8,13 @@ mod pki; pub use crypto_provider::RustCrypto; pub use error::{MlsProviderError, MlsProviderResult}; +use openmls_traits::{ + crypto::OpenMlsCrypto, + types::{ + AeadType, Ciphersuite, CryptoError, ExporterSecret, HashType, HpkeCiphertext, HpkeConfig, HpkeKeyPair, + KemOutput, SignatureScheme, + }, +}; pub use pki::{CertProfile, CertificateGenerationArgs, PkiKeypair}; use crate::pki::PkiEnvironmentProvider; @@ -146,3 +153,132 @@ impl openmls_traits::OpenMlsCryptoProvider for MlsCryptoProvider { &self.pki_env } } + +/// Passthrough implementation of crypto functionality for references to `MlsCryptoProvider`. +impl OpenMlsCrypto for &MlsCryptoProvider { + fn supports(&self, ciphersuite: Ciphersuite) -> Result<(), CryptoError> { + self.crypto.supports(ciphersuite) + } + + fn supported_ciphersuites(&self) -> Vec { + self.crypto.supported_ciphersuites() + } + + fn hkdf_extract( + &self, + hash_type: HashType, + salt: &[u8], + ikm: &[u8], + ) -> Result { + self.crypto.hkdf_extract(hash_type, salt, ikm) + } + + fn hkdf_expand( + &self, + hash_type: HashType, + prk: &[u8], + info: &[u8], + okm_len: usize, + ) -> Result { + self.crypto.hkdf_expand(hash_type, prk, info, okm_len) + } + + fn hash(&self, hash_type: HashType, data: &[u8]) -> Result, CryptoError> { + self.crypto.hash(hash_type, data) + } + + fn aead_encrypt( + &self, + alg: AeadType, + key: &[u8], + data: &[u8], + nonce: &[u8], + aad: &[u8], + ) -> Result, CryptoError> { + self.crypto.aead_encrypt(alg, key, data, nonce, aad) + } + + fn aead_decrypt( + &self, + alg: AeadType, + key: &[u8], + ct_tag: &[u8], + nonce: &[u8], + aad: &[u8], + ) -> Result, CryptoError> { + self.crypto.aead_decrypt(alg, key, ct_tag, nonce, aad) + } + + fn signature_key_gen(&self, alg: SignatureScheme) -> Result<(Vec, Vec), CryptoError> { + self.crypto.signature_key_gen(alg) + } + + fn signature_public_key_len(&self, alg: SignatureScheme) -> usize { + self.crypto.signature_public_key_len(alg) + } + + fn verify_signature( + &self, + alg: SignatureScheme, + data: &[u8], + pk: &[u8], + signature: &[u8], + ) -> Result<(), CryptoError> { + self.crypto.verify_signature(alg, data, pk, signature) + } + + fn sign(&self, alg: SignatureScheme, data: &[u8], key: &[u8]) -> Result, CryptoError> { + self.crypto.sign(alg, data, key) + } + + fn hpke_seal( + &self, + config: HpkeConfig, + pk_r: &[u8], + info: &[u8], + aad: &[u8], + ptxt: &[u8], + ) -> Result { + self.crypto.hpke_seal(config, pk_r, info, aad, ptxt) + } + + fn hpke_open( + &self, + config: HpkeConfig, + input: &HpkeCiphertext, + sk_r: &[u8], + info: &[u8], + aad: &[u8], + ) -> Result, CryptoError> { + self.crypto.hpke_open(config, input, sk_r, info, aad) + } + + fn hpke_setup_sender_and_export( + &self, + config: HpkeConfig, + pk_r: &[u8], + info: &[u8], + exporter_context: &[u8], + exporter_length: usize, + ) -> Result<(KemOutput, ExporterSecret), CryptoError> { + self.crypto + .hpke_setup_sender_and_export(config, pk_r, info, exporter_context, exporter_length) + } + + fn hpke_setup_receiver_and_export( + &self, + config: HpkeConfig, + enc: &[u8], + sk_r: &[u8], + info: &[u8], + exporter_context: &[u8], + exporter_length: usize, + ) -> Result { + self.crypto + .hpke_setup_receiver_and_export(config, enc, sk_r, info, exporter_context, exporter_length) + } + + fn derive_hpke_keypair(&self, config: HpkeConfig, ikm: &[u8]) -> Result { + self.crypto.derive_hpke_keypair(config, ikm) + } +} diff --git a/obfuscate/Cargo.toml b/obfuscate/Cargo.toml index 7f56a7274f..b4284bb5cb 100644 --- a/obfuscate/Cargo.toml +++ b/obfuscate/Cargo.toml @@ -5,12 +5,13 @@ edition = "2024" rust-version = "1.90" [dependencies] +derive_more.workspace = true hex.workspace = true -sha2.workspace = true +log.workspace = true +openmls_basic_credential.workspace = true openmls.workspace = true -derive_more.workspace = true rand.workspace = true -log.workspace = true +sha2.workspace = true [lints] workspace = true diff --git a/obfuscate/src/impls/mod.rs b/obfuscate/src/impls/mod.rs index fc3078b100..d2338e0744 100644 --- a/obfuscate/src/impls/mod.rs +++ b/obfuscate/src/impls/mod.rs @@ -1,2 +1,3 @@ pub mod openmls; +pub mod openmls_basic_credential; pub mod std; diff --git a/obfuscate/src/impls/openmls_basic_credential.rs b/obfuscate/src/impls/openmls_basic_credential.rs new file mode 100644 index 0000000000..9c3fec2226 --- /dev/null +++ b/obfuscate/src/impls/openmls_basic_credential.rs @@ -0,0 +1,16 @@ +use std::fmt::Formatter; + +use hex::ToHex as _; +use openmls_basic_credential::SignatureKeyPair; + +use crate::{Obfuscate, compute_hash}; + +impl Obfuscate for SignatureKeyPair { + fn obfuscate(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.debug_struct("SignatureKeyPair") + .field("signature_scheme", &self.signature_scheme()) + .field("public", &self.public().encode_hex::()) + .field("private", &compute_hash(self.private())) + .finish() + } +} diff --git a/obfuscate/src/impls/std.rs b/obfuscate/src/impls/std.rs index 9216d70b8c..53e1bbe39e 100644 --- a/obfuscate/src/impls/std.rs +++ b/obfuscate/src/impls/std.rs @@ -2,12 +2,18 @@ use std::fmt::Formatter; use crate::{Obfuscate, compute_hash}; -impl Obfuscate for Vec { +impl Obfuscate for [u8] { fn obfuscate(&self, f: &mut Formatter<'_>) -> core::fmt::Result { f.write_str(hex::encode(compute_hash(self)).as_str()) } } +impl Obfuscate for Vec { + fn obfuscate(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + (**self).obfuscate(f) + } +} + impl Obfuscate for Vec { fn obfuscate(&self, f: &mut Formatter<'_>) -> core::fmt::Result { f.write_str("[")?; diff --git a/obfuscate/src/lib.rs b/obfuscate/src/lib.rs index 079eb1e6d7..6676376526 100644 --- a/obfuscate/src/lib.rs +++ b/obfuscate/src/lib.rs @@ -38,8 +38,8 @@ pub fn compute_hash(bytes: &[u8]) -> [u8; 10] { /// This wrapper lets us log a partial hash of the sensitive item, so we have deterministic loggable non-sensitive /// aliases for all our sensitive values. #[derive(From)] -pub struct Obfuscated<'a, T: Obfuscate>(&'a T); -impl<'a, T: Obfuscate> core::fmt::Debug for Obfuscated<'a, T> { +pub struct Obfuscated<'a, T: Obfuscate + ?Sized>(&'a T); +impl<'a, T: Obfuscate + ?Sized> core::fmt::Debug for Obfuscated<'a, T> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.0.obfuscate(f) }