diff --git a/src/account_manager.rs b/src/account_manager.rs index c7bbc5ce1..208c48f77 100644 --- a/src/account_manager.rs +++ b/src/account_manager.rs @@ -9,9 +9,9 @@ use aes::cipher::{KeyIvInit, StreamCipher as _}; use hmac::digest::Output; use hmac::{Hmac, Mac}; use libsignal_protocol::{ - kem, GenericSignedPreKey, IdentityKey, IdentityKeyPair, IdentityKeyStore, - KeyPair, KyberPreKeyRecord, PrivateKey, ProtocolStore, PublicKey, - SenderKeyStore, SignedPreKeyRecord, Timestamp, + kem, Aci, GenericSignedPreKey, IdentityKey, IdentityKeyPair, + IdentityKeyStore, KeyPair, KyberPreKeyRecord, PrivateKey, ProtocolStore, + PublicKey, SenderKeyStore, ServiceIdKind, SignedPreKeyRecord, Timestamp, }; use prost::Message; use serde::{Deserialize, Serialize}; @@ -30,15 +30,15 @@ use crate::proto::sync_message::PniChangeNumber; use crate::proto::{DeviceName, SyncMessage}; use crate::provisioning::generate_registration_id; use crate::push_service::{ - AvatarWrite, DeviceActivationRequest, DeviceInfo, HttpAuthOverride, - RecaptchaAttributes, RegistrationMethod, ReqwestExt, ServiceIdType, - VerifyAccountResponse, DEFAULT_DEVICE_ID, + AvatarWrite, CaptchaAttributes, DeviceActivationRequest, DeviceInfo, + HttpAuthOverride, RegistrationMethod, ReqwestExt, VerifyAccountResponse, + DEFAULT_DEVICE_ID, }; use crate::sender::OutgoingPushMessage; +use crate::service_address::ServiceIdExt; use crate::session_store::SessionStoreExt; use crate::timestamp::TimestampExt as _; use crate::utils::{random_length_padding, BASE64_RELAXED}; -use crate::ServiceAddress; use crate::{ configuration::{Endpoint, ServiceCredentials}, pre_keys::PreKeyState, @@ -95,13 +95,13 @@ impl AccountManager { >( &mut self, protocol_store: &mut P, - service_id_type: ServiceIdType, + service_id_kind: ServiceIdKind, csprng: &mut R, use_last_resort_key: bool, ) -> Result<(), ServiceError> { let prekey_status = match self .service - .get_pre_key_status(service_id_type) + .get_pre_key_status(service_id_kind) .instrument(tracing::span!( tracing::Level::DEBUG, "Fetching pre key status" @@ -203,7 +203,7 @@ impl AccountManager { }; self.service - .register_pre_keys(service_id_type, pre_key_state) + .register_pre_keys(service_id_kind, pre_key_state) .instrument(tracing::span!( tracing::Level::DEBUG, "Uploading pre keys" @@ -495,7 +495,7 @@ impl AccountManager { pub async fn retrieve_profile( &mut self, - address: ServiceAddress, + address: Aci, ) -> Result { let profile_key = self.profile_key.expect("set profile key in AccountManager"); @@ -624,10 +624,10 @@ impl AccountManager { Endpoint::service("/v1/challenge"), HttpAuthOverride::NoOverride, )? - .json(&RecaptchaAttributes { - r#type: String::from("recaptcha"), - token: String::from(token), - captcha: String::from(captcha), + .json(&CaptchaAttributes { + challenge_type: "captcha", + token, + captcha, }) .send() .await? @@ -642,19 +642,19 @@ impl AccountManager { /// Should be called as the primary device to migrate from pre-PNI to PNI. /// /// This is the equivalent of Android's PnpInitializeDevicesJob or iOS' PniHelloWorldManager. - #[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci), fields(local_aci = %local_aci))] + #[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci), fields(local_aci = %local_aci.service_id_string()))] pub async fn pnp_initialize_devices< // XXX So many constraints here, all imposed by the MessageSender R: rand::Rng + rand::CryptoRng, - Aci: PreKeysStore + SessionStoreExt, - Pni: PreKeysStore, + AciStore: PreKeysStore + SessionStoreExt, + PniStore: PreKeysStore, AciOrPni: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone, >( &mut self, - aci_protocol_store: &mut Aci, - pni_protocol_store: &mut Pni, + aci_protocol_store: &mut AciStore, + pni_protocol_store: &mut PniStore, mut sender: MessageSender, - local_aci: ServiceAddress, + local_aci: Aci, e164: PhoneNumber, ) -> Result<(), MessageSenderError> { let mut csprng = rand::thread_rng(); @@ -665,7 +665,7 @@ impl AccountManager { // For every linked device, we generate a new set of pre-keys, and send them to the device. let local_device_ids = aci_protocol_store - .get_sub_device_sessions(&local_aci) + .get_sub_device_sessions(&local_aci.into()) .await?; let mut device_messages = @@ -809,7 +809,7 @@ impl AccountManager { let content: ContentBody = msg.into(); let msg = sender .create_encrypted_message( - &local_aci, + &local_aci.into(), None, local_device_id.into(), &content.into_proto().encode_to_vec(), diff --git a/src/cipher.rs b/src/cipher.rs index 645b4aed9..56ab4f8a5 100644 --- a/src/cipher.rs +++ b/src/cipher.rs @@ -9,8 +9,8 @@ use libsignal_protocol::{ CiphertextMessageType, DeviceId, IdentityKeyStore, KyberPreKeyStore, PreKeySignalMessage, PreKeyStore, ProtocolAddress, ProtocolStore, PublicKey, SealedSenderDecryptionResult, SenderCertificate, - SenderKeyDistributionMessage, SenderKeyStore, SessionStore, SignalMessage, - SignalProtocolError, SignedPreKeyStore, Timestamp, + SenderKeyDistributionMessage, SenderKeyStore, ServiceId, SessionStore, + SignalMessage, SignalProtocolError, SignedPreKeyStore, Timestamp, }; use prost::Message; use uuid::Uuid; @@ -22,7 +22,7 @@ use crate::{ sender::OutgoingPushMessage, session_store::SessionStoreExt, utils::BASE64_RELAXED, - ServiceAddress, + ServiceIdExt, }; /// Decrypts incoming messages and encrypts outgoing messages. /// @@ -271,13 +271,16 @@ where ) .await?; - let sender = ServiceAddress::try_from(sender_uuid.as_str()) - .map_err(|e| { - tracing::error!("{:?}", e); + let Some(sender) = + ServiceId::parse_from_service_id_string(&sender_uuid) + else { + return Err( SignalProtocolError::InvalidSealedSenderMessage( "invalid sender UUID".to_string(), ) - })?; + .into(), + ); + }; let needs_receipt = if envelope.source_service_id.is_some() { tracing::warn!(?envelope, "Received an unidentified delivery over an identified channel. Marking needs_receipt=false"); @@ -446,7 +449,7 @@ fn strip_padding(contents: &mut Vec) -> Result<(), ServiceError> { /// Equivalent of `SignalServiceCipher::getPreferredProtocolAddress` pub async fn get_preferred_protocol_address( session_store: &S, - address: &ServiceAddress, + address: &ServiceId, device_id: DeviceId, ) -> Result { let address = address.to_protocol_address(device_id); diff --git a/src/content.rs b/src/content.rs index b779c4c86..46829b58a 100644 --- a/src/content.rs +++ b/src/content.rs @@ -1,4 +1,4 @@ -use libsignal_protocol::ProtocolAddress; +use libsignal_protocol::{ProtocolAddress, ServiceId}; use std::fmt; use uuid::Uuid; @@ -12,6 +12,7 @@ pub use crate::{ SyncMessage, TypingMessage, }, push_service::ServiceError, + ServiceIdExt, }; mod data_message; @@ -19,8 +20,8 @@ mod story_message; #[derive(Clone, Debug)] pub struct Metadata { - pub sender: crate::ServiceAddress, - pub destination: crate::ServiceAddress, + pub sender: ServiceId, + pub destination: ServiceId, pub sender_device: u32, pub timestamp: u64, pub needs_receipt: bool, @@ -37,7 +38,7 @@ impl fmt::Display for Metadata { write!( f, "Metadata {{ sender: {}, guid: {} }}", - self.sender.to_service_id(), + self.sender.service_id_string(), // XXX: should this still be optional? self.server_guid .map(|u| u.to_string()) diff --git a/src/envelope.rs b/src/envelope.rs index 1d2240f37..ae19e0a1d 100644 --- a/src/envelope.rs +++ b/src/envelope.rs @@ -1,30 +1,12 @@ -use std::convert::{TryFrom, TryInto}; - use aes::cipher::block_padding::Pkcs7; use aes::cipher::{BlockDecryptMut, KeyIvInit}; +use libsignal_protocol::ServiceId; use prost::Message; -use crate::{ - configuration::SignalingKey, push_service::ServiceError, - utils::serde_optional_base64, ParseServiceAddressError, ServiceAddress, -}; +use crate::{configuration::SignalingKey, push_service::ServiceError}; pub use crate::proto::Envelope; -impl TryFrom for Envelope { - type Error = ParseServiceAddressError; - - fn try_from(entity: EnvelopeEntity) -> Result { - match entity.source_uuid.as_deref() { - Some(uuid) => { - let address = uuid.try_into()?; - Ok(Envelope::new_with_source(entity, address)) - }, - None => Ok(Envelope::new_from_entity(entity)), - } - } -} - impl Envelope { #[tracing::instrument(skip(input, signaling_key), fields(signaling_key_present = signaling_key.is_some(), input_size = input.len()))] pub fn decrypt( @@ -85,30 +67,6 @@ impl Envelope { } } - fn new_from_entity(entity: EnvelopeEntity) -> Self { - Envelope { - r#type: Some(entity.r#type), - timestamp: Some(entity.timestamp), - server_timestamp: Some(entity.server_timestamp), - server_guid: Some(entity.guid), - content: entity.content, - ..Default::default() - } - } - - fn new_with_source(entity: EnvelopeEntity, source: ServiceAddress) -> Self { - Envelope { - r#type: Some(entity.r#type), - source_device: Some(entity.source_device), - timestamp: Some(entity.timestamp), - server_timestamp: Some(entity.server_timestamp), - server_guid: Some(entity.guid), - source_service_id: Some(source.uuid.to_string()), - content: entity.content, - ..Default::default() - } - } - pub fn is_unidentified_sender(&self) -> bool { self.r#type() == crate::proto::envelope::Type::UnidentifiedSender } @@ -134,54 +92,27 @@ impl Envelope { self.story.unwrap_or(false) } - pub fn source_address(&self) -> ServiceAddress { + pub fn source_address(&self) -> ServiceId { match self.source_service_id.as_deref() { - Some(service_id) => ServiceAddress::try_from(service_id) - .expect("invalid source ProtocolAddress UUID or prefix"), + Some(service_id) => { + ServiceId::parse_from_service_id_string(service_id) + .expect("invalid source ProtocolAddress UUID or prefix") + }, None => panic!("source_service_id is set"), } } - pub fn destination_address(&self) -> ServiceAddress { + pub fn destination_address(&self) -> ServiceId { match self.destination_service_id.as_deref() { - Some(service_id) => ServiceAddress::try_from(service_id) - .expect("invalid destination ProtocolAddress UUID or prefix"), + Some(service_id) => ServiceId::parse_from_service_id_string( + service_id, + ) + .expect("invalid destination ProtocolAddress UUID or prefix"), None => panic!("destination_address is set"), } } } -#[derive(serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct EnvelopeEntity { - pub r#type: i32, - pub timestamp: u64, - pub source: Option, - pub source_uuid: Option, - pub source_device: u32, - #[serde(default)] - pub destination_uuid: Option, - #[serde(default, with = "serde_optional_base64")] - pub content: Option>, - pub server_timestamp: u64, - pub guid: String, - #[serde(default = "default_true")] - pub urgent: bool, - #[serde(default)] - pub story: bool, - #[serde(default, with = "serde_optional_base64")] - pub report_spam_token: Option>, -} - -fn default_true() -> bool { - true -} - -#[derive(serde::Serialize, serde::Deserialize)] -pub(crate) struct EnvelopeEntityList { - pub messages: Vec, -} - pub(crate) const SUPPORTED_VERSION: u8 = 1; pub(crate) const CIPHER_KEY_SIZE: usize = 32; pub(crate) const MAC_KEY_SIZE: usize = 20; diff --git a/src/groups_v2/model.rs b/src/groups_v2/model.rs index 92be11d0d..75990ba59 100644 --- a/src/groups_v2/model.rs +++ b/src/groups_v2/model.rs @@ -1,12 +1,11 @@ use std::{convert::TryFrom, convert::TryInto}; use derivative::Derivative; +use libsignal_protocol::ServiceId; use serde::{Deserialize, Serialize}; use uuid::Uuid; use zkgroup::profiles::ProfileKey; -use crate::ServiceAddress; - use super::GroupDecodingError; #[derive(Copy, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -34,7 +33,7 @@ impl PartialEq for Member { #[derive(Clone, Debug, PartialEq, Eq)] pub struct PendingMember { - pub address: ServiceAddress, + pub address: ServiceId, pub role: Role, pub added_by_uuid: Uuid, pub timestamp: u64, diff --git a/src/groups_v2/operations.rs b/src/groups_v2/operations.rs index 6b766631a..e14c1059e 100644 --- a/src/groups_v2/operations.rs +++ b/src/groups_v2/operations.rs @@ -148,7 +148,7 @@ impl GroupOperations { let added_by_uuid = self.decrypt_aci(&member.added_by_user_id)?; Ok(PendingMember { - address: service_id.into(), + address: service_id, role: inner_member.role.try_into()?, added_by_uuid: added_by_uuid.into(), timestamp: member.timestamp, diff --git a/src/lib.rs b/src/lib.rs index 2a881ca8d..1cc48816f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,7 +19,6 @@ pub mod messagepipe; pub mod models; pub mod pre_keys; pub mod profile_name; -pub mod profile_service; #[allow(clippy::derive_partial_eq_without_eq)] pub mod proto; pub mod provisioning; @@ -49,7 +48,6 @@ pub const GROUP_UPDATE_FLAG: u32 = 1; pub const GROUP_LEAVE_FLAG: u32 = 2; pub mod prelude { - pub use super::ServiceAddress; pub use crate::{ cipher::ServiceCipher, configuration::{ diff --git a/src/profile_service.rs b/src/profile_service.rs deleted file mode 100644 index 4e7ea9b42..000000000 --- a/src/profile_service.rs +++ /dev/null @@ -1,55 +0,0 @@ -use reqwest::Method; - -use crate::{ - configuration::Endpoint, - prelude::PushService, - push_service::{ - HttpAuthOverride, ReqwestExt, ServiceError, SignalServiceProfile, - }, - ServiceAddress, -}; - -pub struct ProfileService { - push_service: PushService, -} - -impl ProfileService { - pub fn from_socket(push_service: PushService) -> Self { - ProfileService { push_service } - } - - pub async fn retrieve_profile_by_id( - &mut self, - address: ServiceAddress, - profile_key: Option, - ) -> Result { - let path = match profile_key { - Some(key) => { - let version = - bincode::serialize(&key.get_profile_key_version( - address.aci().expect("profile by ACI ProtocolAddress"), - ))?; - let version = std::str::from_utf8(&version) - .expect("hex encoded profile key version"); - format!("/v1/profile/{}/{}", address.uuid, version) - }, - None => { - format!("/v1/profile/{}", address.uuid) - }, - }; - - self.push_service - .request( - Method::GET, - Endpoint::service(path), - HttpAuthOverride::NoOverride, - )? - .send() - .await? - .service_error_for_status() - .await? - .json() - .await - .map_err(Into::into) - } -} diff --git a/src/push_service/account.rs b/src/push_service/account.rs index 0c90ad98d..bbcab6f60 100644 --- a/src/push_service/account.rs +++ b/src/push_service/account.rs @@ -14,27 +14,6 @@ use crate::{ utils::{serde_optional_base64, serde_phone_number}, }; -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub enum ServiceIdType { - /// Account Identity (ACI) - /// - /// An account UUID without an associated phone number, probably in the future to a username - AccountIdentity, - /// Phone number identity (PNI) - /// - /// A UUID associated with a phone number - PhoneNumberIdentity, -} - -impl fmt::Display for ServiceIdType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ServiceIdType::AccountIdentity => f.write_str("aci"), - ServiceIdType::PhoneNumberIdentity => f.write_str("pni"), - } - } -} - #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ServiceIds { #[serde(rename = "uuid")] @@ -120,7 +99,8 @@ pub struct DeviceCapabilities { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct WhoAmIResponse { - pub uuid: Uuid, + #[serde(rename = "uuid")] + pub aci: Uuid, #[serde(default)] // nil when not present (yet) pub pni: Uuid, #[serde(with = "serde_phone_number")] diff --git a/src/push_service/error.rs b/src/push_service/error.rs index 2aad42db0..3b3cb420c 100644 --- a/src/push_service/error.rs +++ b/src/push_service/error.rs @@ -1,8 +1,8 @@ use aes::cipher::block_padding::UnpadError; -use libsignal_protocol::SignalProtocolError; +use libsignal_protocol::{ServiceIdKind, SignalProtocolError}; use zkgroup::ZkGroupDeserializationFailure; -use crate::{groups_v2::GroupDecodingError, ParseServiceAddressError}; +use crate::groups_v2::GroupDecodingError; use super::{ MismatchedDevices, ProofRequired, RegistrationLockFailure, StaleDevices, @@ -16,6 +16,9 @@ pub enum ServiceError { #[error("invalid URL: {0}")] InvalidUrl(#[from] url::ParseError), + #[error("wrong address type: {0}")] + InvalidAddressType(ServiceIdKind), + #[error("Error sending request: {reason}")] SendError { reason: String }, @@ -84,9 +87,6 @@ pub enum ServiceError { #[error("unsupported content")] UnsupportedContent, - #[error(transparent)] - ParseServiceAddress(#[from] ParseServiceAddressError), - #[error("Not found.")] NotFoundError, diff --git a/src/push_service/keys.rs b/src/push_service/keys.rs index d2cf6f091..4403ad1d2 100644 --- a/src/push_service/keys.rs +++ b/src/push_service/keys.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; -use libsignal_protocol::{IdentityKey, PreKeyBundle, SenderCertificate}; +use libsignal_protocol::{ + IdentityKey, PreKeyBundle, SenderCertificate, ServiceId, ServiceIdKind, +}; use reqwest::Method; use serde::Deserialize; @@ -10,12 +12,11 @@ use crate::{ push_service::PreKeyResponse, sender::OutgoingPushMessage, utils::serde_base64, - ServiceAddress, }; use super::{ response::ReqwestExt, HttpAuthOverride, PushService, SenderCertificateJson, - ServiceError, ServiceIdType, VerifyAccountResponse, + ServiceError, VerifyAccountResponse, }; #[derive(Debug, Deserialize, Default)] @@ -28,11 +29,11 @@ pub struct PreKeyStatus { impl PushService { pub async fn get_pre_key_status( &mut self, - service_id_type: ServiceIdType, + service_id_kind: ServiceIdKind, ) -> Result { self.request( Method::GET, - Endpoint::service(format!("/v2/keys?identity={}", service_id_type)), + Endpoint::service(format!("/v2/keys?identity={}", service_id_kind)), HttpAuthOverride::NoOverride, )? .send() @@ -46,12 +47,12 @@ impl PushService { pub async fn register_pre_keys( &mut self, - service_id_type: ServiceIdType, + service_id_kind: ServiceIdKind, pre_key_state: PreKeyState, ) -> Result<(), ServiceError> { self.request( Method::PUT, - Endpoint::service(format!("/v2/keys?identity={}", service_id_type)), + Endpoint::service(format!("/v2/keys?identity={}", service_id_kind)), HttpAuthOverride::NoOverride, )? .json(&pre_key_state) @@ -65,11 +66,14 @@ impl PushService { pub async fn get_pre_key( &mut self, - destination: &ServiceAddress, + destination: &ServiceId, device_id: u32, ) -> Result { - let path = - format!("/v2/keys/{}/{}?pq=true", destination.uuid, device_id); + let path = format!( + "/v2/keys/{}/{}", + destination.service_id_string(), + device_id + ); let mut pre_key_response: PreKeyResponse = self .request( @@ -93,13 +97,17 @@ impl PushService { pub(crate) async fn get_pre_keys( &mut self, - destination: &ServiceAddress, + destination: &ServiceId, device_id: u32, ) -> Result, ServiceError> { let path = if device_id == 1 { - format!("/v2/keys/{}/*?pq=true", destination.uuid) + format!("/v2/keys/{}/*", destination.service_id_string()) } else { - format!("/v2/keys/{}/{}?pq=true", destination.uuid, device_id) + format!( + "/v2/keys/{}/{}", + destination.service_id_string(), + device_id + ) }; let pre_key_response: PreKeyResponse = self .request( diff --git a/src/push_service/profile.rs b/src/push_service/profile.rs index c14b59b91..5553ceb1e 100644 --- a/src/push_service/profile.rs +++ b/src/push_service/profile.rs @@ -1,3 +1,4 @@ +use libsignal_protocol::Aci; use reqwest::Method; use serde::{Deserialize, Serialize}; use zkgroup::profiles::{ProfileKeyCommitment, ProfileKeyVersion}; @@ -8,7 +9,7 @@ use crate::{ profile_cipher::ProfileCipherError, push_service::AvatarWrite, utils::{serde_base64, serde_optional_base64}, - Profile, ServiceAddress, + Profile, }; use super::{DeviceCapabilities, HttpAuthOverride, PushService, ReqwestExt}; @@ -89,18 +90,17 @@ struct SignalServiceProfileWrite<'s> { impl PushService { pub async fn retrieve_profile_by_id( &mut self, - address: ServiceAddress, + address: Aci, profile_key: Option, ) -> Result { let path = if let Some(key) = profile_key { - let version = bincode::serialize(&key.get_profile_key_version( - address.aci().expect("profile by ACI ProtocolAddress"), - ))?; + let version = + bincode::serialize(&key.get_profile_key_version(address))?; let version = std::str::from_utf8(&version) .expect("hex encoded profile key version"); - format!("/v1/profile/{}/{}", address.uuid, version) + format!("/v1/profile/{}/{}", address.service_id_string(), version) } else { - format!("/v1/profile/{}", address.uuid) + format!("/v1/profile/{}", address.service_id_string()) }; // TODO: set locale to en_US self.request( diff --git a/src/push_service/registration.rs b/src/push_service/registration.rs index 8521d682a..f656fbb0f 100644 --- a/src/push_service/registration.rs +++ b/src/push_service/registration.rs @@ -81,11 +81,12 @@ pub struct DeviceActivationRequest { pub pni_pq_last_resort_pre_key: KyberPreKeyEntity, } -#[derive(Debug, Serialize, Deserialize)] -pub struct RecaptchaAttributes { - pub r#type: String, - pub token: String, - pub captcha: String, +#[derive(Debug, Serialize)] +pub struct CaptchaAttributes<'a> { + #[serde(rename = "type")] + pub challenge_type: &'a str, + pub token: &'a str, + pub captcha: &'a str, } #[derive(Debug, Clone, Deserialize)] diff --git a/src/sender.rs b/src/sender.rs index ae9d2cfec..d8ff4c97f 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -2,8 +2,9 @@ use std::{collections::HashSet, time::SystemTime}; use chrono::prelude::*; use libsignal_protocol::{ - process_prekey_bundle, DeviceId, IdentityKey, IdentityKeyPair, - ProtocolStore, SenderCertificate, SenderKeyStore, SignalProtocolError, + process_prekey_bundle, Aci, DeviceId, IdentityKey, IdentityKeyPair, Pni, + ProtocolStore, SenderCertificate, SenderKeyStore, ServiceId, + SignalProtocolError, }; use tracing::{debug, error, info, trace, warn}; use tracing_futures::Instrument; @@ -23,10 +24,11 @@ use crate::{ AttachmentPointer, SyncMessage, }, push_service::*, + service_address::ServiceIdExt, session_store::SessionStoreExt, unidentified_access::UnidentifiedAccess, + utils::serde_service_id, websocket::SignalWebSocket, - ServiceAddress, }; pub use crate::proto::{ContactDetails, GroupDetails}; @@ -42,7 +44,8 @@ pub struct OutgoingPushMessage { #[derive(serde::Serialize, Debug)] pub struct OutgoingPushMessages { - pub destination: uuid::Uuid, + #[serde(with = "serde_service_id")] + pub destination: ServiceId, pub timestamp: u64, pub messages: Vec, pub online: bool, @@ -58,7 +61,7 @@ pub type SendMessageResult = Result; #[derive(Debug, Clone)] pub struct SentMessage { - pub recipient: ServiceAddress, + pub recipient: ServiceId, pub used_identity_key: IdentityKey, pub unidentified: bool, pub needs_sync: bool, @@ -88,8 +91,8 @@ pub struct MessageSender { service: PushService, cipher: ServiceCipher, protocol_store: S, - local_aci: ServiceAddress, - local_pni: ServiceAddress, + local_aci: Aci, + local_pni: Pni, aci_identity: IdentityKeyPair, pni_identity: Option, device_id: DeviceId, @@ -119,7 +122,7 @@ pub enum MessageSenderError { SendSyncMessageError(sync_message::request::Type), #[error("Untrusted identity key with {address:?}")] - UntrustedIdentity { address: ServiceAddress }, + UntrustedIdentity { address: ServiceId }, #[error("Exceeded maximum number of retries")] MaximumRetriesLimitExceeded, @@ -127,8 +130,8 @@ pub enum MessageSenderError { #[error("Proof of type {options:?} required using token {token}")] ProofRequired { token: String, options: Vec }, - #[error("Recipient not found: {addr:?}")] - NotFound { addr: ServiceAddress }, + #[error("Recipient not found: {service_id:?}")] + NotFound { service_id: ServiceId }, #[error("no messages were encrypted: this should not really happen and most likely implies a logic error")] NoMessagesToSend, @@ -159,8 +162,8 @@ where service: PushService, cipher: ServiceCipher, protocol_store: S, - local_aci: impl Into, - local_pni: impl Into, + local_aci: impl Into, + local_pni: impl Into, aci_identity: IdentityKeyPair, pni_identity: Option, device_id: DeviceId, @@ -328,7 +331,7 @@ where async fn is_multi_device(&self) -> bool { if self.device_id == DEFAULT_DEVICE_ID.into() { self.protocol_store - .get_sub_device_sessions(&self.local_aci) + .get_sub_device_sessions(&self.local_aci.into()) .await .map_or(false, |s| !s.is_empty()) } else { @@ -338,12 +341,12 @@ where /// Send a message `content` to a single `recipient`. #[tracing::instrument( - skip(self, unidentified_access, message, recipient), - fields(unidentified_access = unidentified_access.is_some(), recipient = %recipient), + skip(self, unidentified_access, message), + fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()), )] pub async fn send_message( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, mut unidentified_access: Option, message: impl Into, timestamp: u64, @@ -417,7 +420,7 @@ where Some(&result), ); self.try_send_message( - self.local_aci, + self.local_aci.into(), None, &sync_message, timestamp, @@ -429,7 +432,11 @@ where if end_session { let n = self.protocol_store.delete_all_sessions(recipient).await?; - tracing::debug!("ended {} sessions with {}", n, recipient.uuid); + tracing::debug!( + "ended {} sessions with {}", + n, + recipient.raw_uuid() + ); } result @@ -447,7 +454,7 @@ where )] pub async fn send_message_to_group( &mut self, - recipients: impl AsRef<[(ServiceAddress, Option, bool)]>, + recipients: impl AsRef<[(ServiceId, Option, bool)]>, message: impl Into, timestamp: u64, online: bool, @@ -494,7 +501,7 @@ where // See Signal Android `SignalServiceMessageSender.java:2817` if let Err(error) = self .try_send_message( - self.local_aci, + self.local_aci.into(), None, &sync_message, timestamp, @@ -513,12 +520,12 @@ where /// Send a message (`content`) to an address (`recipient`). #[tracing::instrument( level = "trace", - skip(self, unidentified_access), - fields(unidentified_access = unidentified_access.is_some(), recipient = %recipient), + skip(self, unidentified_access, content_body, recipient), + fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()), )] async fn try_send_message( &mut self, - recipient: ServiceAddress, + recipient: ServiceId, mut unidentified_access: Option<&UnidentifiedAccess>, content_body: &ContentBody, timestamp: u64, @@ -555,7 +562,7 @@ where }; let messages = OutgoingPushMessages { - destination: recipient.uuid, + destination: recipient, timestamp, messages, online, @@ -656,7 +663,7 @@ where Err(ServiceError::NotFoundError) => { tracing::debug!("Not found when sending a message"); return Err(MessageSenderError::NotFound { - addr: recipient, + service_id: recipient, }); }, Err(e) => { @@ -675,11 +682,11 @@ where /// Upload contact details to the CDN and send a sync message #[tracing::instrument( skip(self, unidentified_access, contacts, recipient), - fields(unidentified_access = unidentified_access.is_some(), recipient = %recipient), + fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()), )] pub async fn send_contact_details( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, unidentified_access: Option, // XXX It may be interesting to use an intermediary type, // instead of ContactDetails directly, @@ -715,10 +722,10 @@ where } /// Send `Configuration` synchronization message - #[tracing::instrument(skip(self, recipient), fields(recipient = %recipient))] + #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))] pub async fn send_configuration( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, configuration: sync_message::Configuration, ) -> Result<(), MessageSenderError> { let msg = SyncMessage { @@ -734,10 +741,10 @@ where } /// Send `MessageRequestResponse` synchronization message with either a recipient ACI or a GroupV2 ID - #[tracing::instrument(skip(self, recipient), fields(recipient = %recipient))] + #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))] pub async fn send_message_request_response( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, thread: &ThreadIdentifier, action: message_request_response::Type, ) -> Result<(), MessageSenderError> { @@ -781,10 +788,10 @@ where } /// Send `Keys` synchronization message - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))] pub async fn send_keys( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, keys: sync_message::Keys, ) -> Result<(), MessageSenderError> { let msg = SyncMessage { @@ -803,7 +810,7 @@ where #[tracing::instrument(skip(self))] pub async fn send_sync_message_request( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, request_type: sync_message::request::Type, ) -> Result<(), MessageSenderError> { if self.device_id == DEFAULT_DEVICE_ID.into() { @@ -836,7 +843,7 @@ where &mut rand::thread_rng(), )?; Ok(crate::proto::PniSignatureMessage { - pni: Some(self.local_pni.uuid.as_bytes().to_vec()), + pni: Some(self.local_pni.service_id_binary()), signature: Some(signature.into()), }) } @@ -844,12 +851,12 @@ where // Equivalent with `getEncryptedMessages` #[tracing::instrument( level = "trace", - skip(self, unidentified_access, content, recipient), - fields(unidentified_access = unidentified_access.is_some(), recipient = %recipient), + skip(self, unidentified_access, content), + fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()), )] async fn create_encrypted_messages( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, unidentified_access: Option<&SenderCertificate>, content: &[u8], ) -> Result, MessageSenderError> { @@ -867,18 +874,14 @@ where devices.insert(DEFAULT_DEVICE_ID.into()); // never try to send messages to the sender device - match recipient.identity { - ServiceIdType::AccountIdentity => { - if recipient.aci().is_some() - && recipient.aci() == self.local_aci.aci() - { + match recipient { + ServiceId::Aci(aci) => { + if *aci == self.local_aci { devices.remove(&self.device_id); } }, - ServiceIdType::PhoneNumberIdentity => { - if recipient.pni().is_some() - && recipient.pni() == self.local_aci.pni() - { + ServiceId::Pni(pni) => { + if *pni == self.local_pni { devices.remove(&self.device_id); } }, @@ -952,12 +955,12 @@ where /// When no session with the recipient exists, we need to create one. #[tracing::instrument( level = "trace", - skip(self, unidentified_access, content, recipient), - fields(unidentified_access = unidentified_access.is_some(), recipient = %recipient), + skip(self, unidentified_access, content), + fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()), )] pub(crate) async fn create_encrypted_message( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, unidentified_access: Option<&SenderCertificate>, device_id: DeviceId, content: &[u8], @@ -993,7 +996,7 @@ where }, Err(ServiceError::NotFoundError) => { return Err(MessageSenderError::NotFound { - addr: *recipient, + service_id: *recipient, }); }, Err(e) => Err(e)?, @@ -1036,7 +1039,7 @@ where } fn create_multi_device_sent_transcript_content<'a>( - recipient: Option<&ServiceAddress>, + recipient: Option<&ServiceId>, content_body: ContentBody, timestamp: u64, send_message_results: impl IntoIterator, @@ -1060,7 +1063,7 @@ where } = sent; UnidentifiedDeliveryStatus { destination_service_id: Some( - recipient.uuid.to_string(), + recipient.service_id_string(), ), unidentified: Some(*unidentified), destination_identity_key: Some( @@ -1071,7 +1074,8 @@ where .collect(); ContentBody::SynchronizeMessage(SyncMessage { sent: Some(sync_message::Sent { - destination_service_id: recipient.map(|r| r.uuid.to_string()), + destination_service_id: recipient + .map(ServiceId::service_id_string), destination_e164: None, expiration_start_timestamp: data_message .as_ref() diff --git a/src/service_address.rs b/src/service_address.rs index de3680879..de83363ab 100644 --- a/src/service_address.rs +++ b/src/service_address.rs @@ -1,160 +1,39 @@ -use std::convert::TryFrom; +use libsignal_protocol::{Aci, DeviceId, Pni, ProtocolAddress, ServiceId}; -use libsignal_protocol::{DeviceId, ProtocolAddress, ServiceId}; -use uuid::Uuid; - -pub use crate::push_service::ServiceIdType; - -#[derive(thiserror::Error, Debug, Clone)] -pub enum ParseServiceAddressError { - #[error("Supplied UUID could not be parsed")] - InvalidUuid(#[from] uuid::Error), - - #[error("Envelope without UUID")] - NoUuid, -} +pub trait ServiceIdExt { + fn to_protocol_address( + self, + device_id: impl Into, + ) -> ProtocolAddress; -#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] -pub struct ServiceAddress { - pub uuid: Uuid, - pub identity: ServiceIdType, -} + fn aci(self) -> Option; -impl std::fmt::Display for ServiceAddress { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // This is used in ServiceAddress::to_service_id(&self), so keep this consistent. - match self.identity { - ServiceIdType::AccountIdentity => write!(f, "{}", self.uuid), - ServiceIdType::PhoneNumberIdentity => { - write!(f, "PNI:{}", self.uuid) - }, - } - } + fn pni(self) -> Option; } -impl ServiceAddress { - pub fn to_protocol_address( - &self, +impl ServiceIdExt for A +where + A: Into, +{ + fn to_protocol_address( + self, device_id: impl Into, ) -> ProtocolAddress { - match self.identity { - ServiceIdType::AccountIdentity => { - ProtocolAddress::new(self.uuid.to_string(), device_id.into()) - }, - ServiceIdType::PhoneNumberIdentity => ProtocolAddress::new( - format!("PNI:{}", self.uuid), - device_id.into(), - ), - } - } - - #[deprecated] - pub fn new_aci(uuid: Uuid) -> Self { - Self::from_aci(uuid) - } - - pub fn from_aci(uuid: Uuid) -> Self { - Self { - uuid, - identity: ServiceIdType::AccountIdentity, - } - } - - #[deprecated] - pub fn new_pni(uuid: Uuid) -> Self { - Self::from_pni(uuid) - } - - pub fn from_pni(uuid: Uuid) -> Self { - Self { - uuid, - identity: ServiceIdType::PhoneNumberIdentity, - } - } - - pub fn aci(&self) -> Option { - use libsignal_protocol::Aci; - match self.identity { - ServiceIdType::AccountIdentity => { - Some(Aci::from_uuid_bytes(self.uuid.into_bytes())) - }, - ServiceIdType::PhoneNumberIdentity => None, - } - } - - pub fn pni(&self) -> Option { - use libsignal_protocol::Pni; - match self.identity { - ServiceIdType::AccountIdentity => None, - ServiceIdType::PhoneNumberIdentity => { - Some(Pni::from_uuid_bytes(self.uuid.into_bytes())) - }, - } - } - - pub fn to_service_id(&self) -> String { - self.to_string() - } -} - -impl From for ServiceAddress { - fn from(service_id: ServiceId) -> Self { - match service_id { - ServiceId::Aci(service_id) => { - ServiceAddress::from_aci(service_id.into()) - }, - ServiceId::Pni(service_id) => { - ServiceAddress::from_pni(service_id.into()) - }, - } - } -} - -impl TryFrom<&ProtocolAddress> for ServiceAddress { - type Error = ParseServiceAddressError; - - fn try_from(addr: &ProtocolAddress) -> Result { - let value = addr.name(); - if let Some(pni) = value.strip_prefix("PNI:") { - Ok(ServiceAddress::from_pni(Uuid::parse_str(pni)?)) - } else { - Ok(ServiceAddress::from_aci(Uuid::parse_str(value)?)) - } - .map_err(|e| { - tracing::error!("Parsing ServiceAddress from {:?}", addr); - ParseServiceAddressError::InvalidUuid(e) - }) + let service_id: ServiceId = self.into(); + ProtocolAddress::new(service_id.service_id_string(), device_id.into()) } -} - -impl TryFrom<&str> for ServiceAddress { - type Error = ParseServiceAddressError; - fn try_from(value: &str) -> Result { - if let Some(pni) = value.strip_prefix("PNI:") { - Ok(ServiceAddress::from_pni(Uuid::parse_str(pni)?)) - } else { - Ok(ServiceAddress::from_aci(Uuid::parse_str(value)?)) + fn aci(self) -> Option { + match self.into() { + ServiceId::Aci(aci) => Some(aci), + ServiceId::Pni(_) => None, } - .map_err(|e| { - tracing::error!("Parsing ServiceAddress from '{}'", value); - ParseServiceAddressError::InvalidUuid(e) - }) } -} - -impl TryFrom<&[u8]> for ServiceAddress { - type Error = ParseServiceAddressError; - fn try_from(value: &[u8]) -> Result { - if let Some(pni) = value.strip_prefix(b"PNI:") { - Ok(ServiceAddress::from_pni(Uuid::from_slice(pni)?)) - } else { - Ok(ServiceAddress::from_aci(Uuid::from_slice(value)?)) + fn pni(self) -> Option { + match self.into() { + ServiceId::Aci(_) => None, + ServiceId::Pni(pni) => Some(pni), } - .map_err(|e| { - tracing::error!("Parsing ServiceAddress from {:?}", value); - ParseServiceAddressError::InvalidUuid(e) - }) } } diff --git a/src/session_store.rs b/src/session_store.rs index affbff09b..983b18717 100644 --- a/src/session_store.rs +++ b/src/session_store.rs @@ -1,7 +1,9 @@ use async_trait::async_trait; -use libsignal_protocol::{ProtocolAddress, SessionStore, SignalProtocolError}; +use libsignal_protocol::{ + ProtocolAddress, ServiceId, SessionStore, SignalProtocolError, +}; -use crate::{push_service::DEFAULT_DEVICE_ID, ServiceAddress}; +use crate::push_service::DEFAULT_DEVICE_ID; /// This is additional functions required to handle /// session deletion. It might be a candidate for inclusion into @@ -13,7 +15,7 @@ pub trait SessionStoreExt: SessionStore { /// This should return every device except for the main device [DEFAULT_DEVICE_ID]. async fn get_sub_device_sessions( &self, - name: &ServiceAddress, + name: &ServiceId, ) -> Result, SignalProtocolError>; /// Remove a session record for a recipient ID + device ID tuple. @@ -28,7 +30,7 @@ pub trait SessionStoreExt: SessionStore { /// Returns the number of deleted sessions. async fn delete_all_sessions( &self, - address: &ServiceAddress, + address: &ServiceId, ) -> Result; /// Remove a session record for a recipient ID + device ID tuple. @@ -48,10 +50,10 @@ pub trait SessionStoreExt: SessionStore { Ok(count) } - async fn compute_safety_number<'s>( - &'s self, - local_address: &'s ServiceAddress, - address: &'s ServiceAddress, + async fn compute_safety_number( + &self, + local_address: &ServiceId, + address: &ServiceId, ) -> Result where Self: Sized + libsignal_protocol::IdentityKeyStore, @@ -73,9 +75,9 @@ pub trait SessionStoreExt: SessionStore { let fp = libsignal_protocol::Fingerprint::new( 2, 5200, - local_address.uuid.as_bytes(), + local_address.raw_uuid().as_bytes(), local.identity_key(), - address.uuid.as_bytes(), + address.raw_uuid().as_bytes(), &ident, )?; fp.display_string() diff --git a/src/utils.rs b/src/utils.rs index 236b6968b..4d42ef4f7 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -290,3 +290,28 @@ pub mod serde_phone_number { .map_err(serde::de::Error::custom) } } + +pub mod serde_service_id { + use libsignal_protocol::ServiceId; + use serde::{Deserialize, Deserializer, Serializer}; + + pub fn serialize( + service_id: &ServiceId, + serializer: S, + ) -> Result + where + S: Serializer, + { + serializer.serialize_str(&service_id.service_id_string()) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + ServiceId::parse_from_service_id_string(&String::deserialize( + deserializer, + )?) + .ok_or_else(|| serde::de::Error::custom("invalid service ID string")) + } +} diff --git a/src/websocket/sender.rs b/src/websocket/sender.rs index 55cd4fbdd..36a1a120b 100644 --- a/src/websocket/sender.rs +++ b/src/websocket/sender.rs @@ -13,7 +13,10 @@ impl SignalWebSocket { messages: OutgoingPushMessages, ) -> Result { let request = WebSocketRequestMessage::new(Method::PUT) - .path(format!("/v1/messages/{}", messages.destination)) + .path(format!( + "/v1/messages/{}", + messages.destination.service_id_string() + )) .json(&messages)?; self.request_json(request).await } @@ -24,7 +27,10 @@ impl SignalWebSocket { access: &UnidentifiedAccess, ) -> Result { let request = WebSocketRequestMessage::new(Method::PUT) - .path(format!("/v1/messages/{}", messages.destination)) + .path(format!( + "/v1/messages/{}", + messages.destination.service_id_string() + )) .header( "Unidentified-Access-Key", BASE64_RELAXED.encode(&access.key),