From 9322d8d0af2ac6e87191c8d05d0d7f1ec234b73b Mon Sep 17 00:00:00 2001 From: Chris Smith <1979423+chris13524@users.noreply.github.com> Date: Mon, 15 Jan 2024 08:30:56 -0500 Subject: [PATCH] fix: subscribe twice sends 2 welcome notifications (#280) * fix: subscribe not idempotent * chore: add comment --- src/model/helpers.rs | 20 +- .../handlers/notify_subscribe.rs | 123 +++++----- tests/integration.rs | 216 ++++++++++++++++-- 3 files changed, 275 insertions(+), 84 deletions(-) diff --git a/src/model/helpers.rs b/src/model/helpers.rs index 5a85e090..2e89cba0 100644 --- a/src/model/helpers.rs +++ b/src/model/helpers.rs @@ -314,10 +314,13 @@ pub async fn get_project_topics( } #[derive(Debug, FromRow)] -pub struct SubscriberWithId { +pub struct SubscribeResponse { pub id: Uuid, #[sqlx(try_from = "String")] pub account: AccountId, + #[sqlx(try_from = "String")] + pub topic: Topic, + pub inserted: bool, } // TODO test idempotency @@ -330,11 +333,10 @@ pub async fn upsert_subscriber( notify_topic: Topic, postgres: &PgPool, metrics: Option<&Metrics>, -) -> Result { +) -> Result { let mut txn = postgres.begin().await?; - // Note that sym_key and topic are updated on conflict. This could be implemented return the existing value like subscribe-topic does, - // but no reason to currently: https://walletconnect.slack.com/archives/C044SKFKELR/p1701994415291179?thread_ts=1701960403.729959&cid=C044SKFKELR + // `xmax = 0`: https://stackoverflow.com/a/39204667 let query = " INSERT INTO subscriber ( @@ -347,13 +349,15 @@ pub async fn upsert_subscriber( VALUES ($1, $2, $3, $4, $5) ON CONFLICT (project, get_address_lower(account)) DO UPDATE SET updated_at=now(), - sym_key=$3, - topic=$4, expiry=$5 - RETURNING id, account + RETURNING + id, + account, + topic, + (xmax = 0) AS inserted "; let start = Instant::now(); - let subscriber = sqlx::query_as::(query) + let subscriber = sqlx::query_as::(query) .bind(project) .bind(account.as_ref()) .bind(hex::encode(notify_key)) diff --git a/src/services/websocket_server/handlers/notify_subscribe.rs b/src/services/websocket_server/handlers/notify_subscribe.rs index 55301ffc..e9a39153 100644 --- a/src/services/websocket_server/handlers/notify_subscribe.rs +++ b/src/services/websocket_server/handlers/notify_subscribe.rs @@ -87,7 +87,6 @@ pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { info!("response_topic: {response_topic}"); let msg: NotifyRequest = decrypt_message(envelope, &sym_key)?; - let id = msg.id; let request_auth = from_jwt::(&msg.params.subscription_auth)?; info!( @@ -133,41 +132,42 @@ pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { (account, domain) }; - let secret = StaticSecret::random_from_rng(chacha20poly1305::aead::OsRng); - - // Technically we don't need to derive based on client_public_key anymore; we just need a symkey. But this is technical - // debt from when clients derived the same symkey on their end via Diffie-Hellman. But now they use the value from - // watch subscriptions. - let notify_key = derive_key(&client_public_key, &secret)?; - let scope = parse_scope(&request_auth.scp)?; - let notify_topic = topic_from_key(¬ify_key); - - let project_id = project.project_id; - info!( - "Registering account: {account} with topic: {notify_topic} at project: {project_id}. \ - Scope: {scope:?}. RPC ID: {id:?}", - ); - - info!("Timing: Upserting subscriber"); - let subscriber = upsert_subscriber( - project.id, - account.clone(), - scope.clone(), - ¬ify_key, - notify_topic.clone(), - &state.postgres, - state.metrics.as_ref(), - ) - .await?; + let subscriber = { + // Technically we don't need to derive based on client_public_key anymore; we just need a symkey. But this is technical + // debt from when clients derived the same symkey on their end via Diffie-Hellman. But now they use the value from + // watch subscriptions. + let secret = StaticSecret::random_from_rng(chacha20poly1305::aead::OsRng); + let notify_key = derive_key(&client_public_key, &secret)?; + let notify_topic = topic_from_key(¬ify_key); + + info!("Timing: Upserting subscriber"); + upsert_subscriber( + project.id, + account.clone(), + scope.clone(), + ¬ify_key, + notify_topic, + &state.postgres, + state.metrics.as_ref(), + ) + .await? + }; info!("Timing: Finished upserting subscriber"); + let notify_topic = subscriber.topic; + // TODO do in same transaction as upsert_subscriber() state .notify_webhook( - project_id.as_ref(), + project.project_id.as_ref(), + // TODO uncomment when `WebhookNotificationEvent::Updated` exists + // if subscriber.inserted { WebhookNotificationEvent::Subscribed, + // } else { + // WebhookNotificationEvent::Updated + // }, account.as_ref(), ) .await?; @@ -182,7 +182,7 @@ pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { info!("Timing: Recording SubscriberUpdateParams"); state.analytics.client(SubscriberUpdateParams { project_pk: project.id, - project_id, + project_id: project.project_id, pk: subscriber.id, account: subscriber.account, // Use a consistent account for analytics rather than the per-request one updated_by_iss: request_iss_client_id.to_did_key().into(), @@ -267,39 +267,44 @@ pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { .await?; info!("Timing: Finished publishing noop to notify_topic"); - let welcome_notification = - get_welcome_notification(project.id, &state.postgres, state.metrics.as_ref()).await?; - if let Some(welcome_notification) = welcome_notification { - info!("Welcome notification enabled"); - if welcome_notification.enabled && scope.contains(&welcome_notification.r#type) { - info!("Scope contains welcome notification type, sending welcome notification"); - let notification = upsert_notification( - Uuid::new_v4().to_string(), - project.id, - Notification { - r#type: welcome_notification.r#type, - title: welcome_notification.title, - body: welcome_notification.body, - url: welcome_notification.url, - icon: None, - }, - &state.postgres, - state.metrics.as_ref(), - ) - .await?; - - upsert_subscriber_notifications( - notification.id, - &[subscriber.id], - &state.postgres, - state.metrics.as_ref(), - ) - .await?; + // TODO do in same txn as upsert_subscriber() + if subscriber.inserted { + let welcome_notification = + get_welcome_notification(project.id, &state.postgres, state.metrics.as_ref()).await?; + if let Some(welcome_notification) = welcome_notification { + info!("Welcome notification enabled"); + if welcome_notification.enabled && scope.contains(&welcome_notification.r#type) { + info!("Scope contains welcome notification type, sending welcome notification"); + let notification = upsert_notification( + Uuid::new_v4().to_string(), + project.id, + Notification { + r#type: welcome_notification.r#type, + title: welcome_notification.title, + body: welcome_notification.body, + url: welcome_notification.url, + icon: None, + }, + &state.postgres, + state.metrics.as_ref(), + ) + .await?; + + upsert_subscriber_notifications( + notification.id, + &[subscriber.id], + &state.postgres, + state.metrics.as_ref(), + ) + .await?; + } else { + info!("Scope does not contain welcome notification type, not sending welcome notification"); + } } else { - info!("Scope does not contain welcome notification type, not sending welcome notification"); + info!("Welcome notification not enabled"); } } else { - info!("Welcome notification not enabled"); + info!("Subscriber already existed, not sending welcome notification"); } send_to_subscription_watchers( diff --git a/tests/integration.rs b/tests/integration.rs index e0278a36..6f99aaf7 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -38,8 +38,8 @@ use { get_subscriber_topics, get_subscribers_for_project_in, get_subscriptions_by_account_and_maybe_app, get_welcome_notification, set_welcome_notification, upsert_project, upsert_subscriber, - GetNotificationsParams, GetNotificationsResult, SubscriberAccountAndScopes, - SubscriberWithId, WelcomeNotification, + GetNotificationsParams, GetNotificationsResult, SubscribeResponse, + SubscriberAccountAndScopes, WelcomeNotification, }, types::AccountId, }, @@ -1822,7 +1822,7 @@ async fn test_notify_subscriber_rate_limit(notify_server: &NotifyServerContext) let scope = HashSet::from([notification_type]); let notify_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let notify_topic = topic_from_key(¬ify_key); - let SubscriberWithId { + let SubscribeResponse { id: subscriber_id, .. } = upsert_subscriber( project.id, @@ -1933,7 +1933,7 @@ async fn test_notify_subscriber_rate_limit_single(notify_server: &NotifyServerCo let scope = HashSet::from([notification_type]); let notify_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let notify_topic = topic_from_key(¬ify_key); - let SubscriberWithId { + let SubscribeResponse { id: subscriber_id1, .. } = upsert_subscriber( project.id, @@ -1951,7 +1951,7 @@ async fn test_notify_subscriber_rate_limit_single(notify_server: &NotifyServerCo let scope = HashSet::from([notification_type]); let notify_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let notify_topic = topic_from_key(¬ify_key); - let SubscriberWithId { + let SubscribeResponse { id: _subscriber_id2, .. } = upsert_subscriber( @@ -2064,7 +2064,7 @@ async fn test_ignores_invalid_scopes(notify_server: &NotifyServerContext) { let scope = HashSet::from([Uuid::new_v4(), Uuid::new_v4()]); let notify_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let notify_topic = topic_from_key(¬ify_key); - let SubscriberWithId { id: subscriber, .. } = upsert_subscriber( + let SubscribeResponse { id: subscriber, .. } = upsert_subscriber( project.id, account.clone(), scope.clone(), @@ -2294,7 +2294,7 @@ async fn test_dead_letter_and_giveup_checks() { let subscriber_sym_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let subscriber_topic = topic_from_key(&subscriber_sym_key); let subscriber_scope = HashSet::from([Uuid::new_v4(), Uuid::new_v4()]); - let SubscriberWithId { + let SubscribeResponse { id: subscriber_id, .. } = upsert_subscriber( project.id, @@ -4800,7 +4800,7 @@ async fn get_notifications_0() { let subscriber_sym_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let subscriber_topic = topic_from_key(&subscriber_sym_key); let subscriber_scope = HashSet::from([Uuid::new_v4(), Uuid::new_v4()]); - let SubscriberWithId { id: subscriber, .. } = upsert_subscriber( + let SubscribeResponse { id: subscriber, .. } = upsert_subscriber( project.id, account_id.clone(), subscriber_scope.clone(), @@ -4855,7 +4855,7 @@ async fn get_notifications_1() { let subscriber_sym_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let subscriber_topic = topic_from_key(&subscriber_sym_key); let subscriber_scope = HashSet::from([Uuid::new_v4(), Uuid::new_v4()]); - let SubscriberWithId { id: subscriber, .. } = upsert_subscriber( + let SubscribeResponse { id: subscriber, .. } = upsert_subscriber( project.id, account_id.clone(), subscriber_scope.clone(), @@ -4937,7 +4937,7 @@ async fn get_notifications_4() { let subscriber_sym_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let subscriber_topic = topic_from_key(&subscriber_sym_key); let subscriber_scope = HashSet::from([Uuid::new_v4(), Uuid::new_v4()]); - let SubscriberWithId { id: subscriber, .. } = upsert_subscriber( + let SubscribeResponse { id: subscriber, .. } = upsert_subscriber( project.id, account_id.clone(), subscriber_scope.clone(), @@ -5039,7 +5039,7 @@ async fn get_notifications_5() { let subscriber_sym_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let subscriber_topic = topic_from_key(&subscriber_sym_key); let subscriber_scope = HashSet::from([Uuid::new_v4(), Uuid::new_v4()]); - let SubscriberWithId { id: subscriber, .. } = upsert_subscriber( + let SubscribeResponse { id: subscriber, .. } = upsert_subscriber( project.id, account_id.clone(), subscriber_scope.clone(), @@ -5137,7 +5137,7 @@ async fn get_notifications_6() { let subscriber_sym_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let subscriber_topic = topic_from_key(&subscriber_sym_key); let subscriber_scope = HashSet::from([Uuid::new_v4(), Uuid::new_v4()]); - let SubscriberWithId { id: subscriber, .. } = upsert_subscriber( + let SubscribeResponse { id: subscriber, .. } = upsert_subscriber( project.id, account_id.clone(), subscriber_scope.clone(), @@ -5260,7 +5260,7 @@ async fn get_notifications_7() { let subscriber_sym_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let subscriber_topic = topic_from_key(&subscriber_sym_key); let subscriber_scope = HashSet::from([Uuid::new_v4(), Uuid::new_v4()]); - let SubscriberWithId { id: subscriber, .. } = upsert_subscriber( + let SubscribeResponse { id: subscriber, .. } = upsert_subscriber( project.id, account_id.clone(), subscriber_scope.clone(), @@ -5367,7 +5367,7 @@ async fn different_created_at() { let subscriber_sym_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let subscriber_topic = topic_from_key(&subscriber_sym_key); let subscriber_scope = HashSet::from([Uuid::new_v4(), Uuid::new_v4()]); - let SubscriberWithId { id: subscriber, .. } = upsert_subscriber( + let SubscribeResponse { id: subscriber, .. } = upsert_subscriber( project.id, account_id.clone(), subscriber_scope.clone(), @@ -5474,7 +5474,7 @@ async fn duplicate_created_at() { let subscriber_sym_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); let subscriber_topic = topic_from_key(&subscriber_sym_key); let subscriber_scope = HashSet::from([Uuid::new_v4(), Uuid::new_v4()]); - let SubscriberWithId { id: subscriber, .. } = upsert_subscriber( + let SubscribeResponse { id: subscriber, .. } = upsert_subscriber( project.id, account_id.clone(), subscriber_scope.clone(), @@ -6104,6 +6104,181 @@ async fn e2e_send_welcome_notification(notify_server: &NotifyServerContext) { assert_eq!(welcome_notification.url, gotten_notification.url); } +#[test_context(NotifyServerContext)] +#[tokio::test] +async fn e2e_send_single_welcome_notification(notify_server: &NotifyServerContext) { + let ( + relay_ws_client, + mut rx, + account, + identity_key_details, + project_id, + app_domain, + app_client_id, + app_key_agreement_key, + _app_authentication_key, + notify_server_client_id, + watch_topic_key, + ) = setup_project_and_watch(notify_server.url.clone()).await; + let project = get_project_by_project_id(project_id.clone(), ¬ify_server.postgres, None) + .await + .unwrap(); + + let notification_type = Uuid::new_v4(); + set_welcome_notification( + project.id, + WelcomeNotification { + enabled: true, + r#type: notification_type, + title: "title".to_owned(), + body: "body".to_owned(), + url: None, + }, + ¬ify_server.postgres, + None, + ) + .await + .unwrap(); + + let _notify_key = subscribe_to_notifications( + &relay_ws_client, + &mut rx, + &account, + &identity_key_details, + app_domain.clone(), + &app_client_id, + app_key_agreement_key, + ¬ify_server_client_id, + watch_topic_key, + HashSet::from([notification_type]), + ) + .await; + + let notify_key = subscribe_to_notifications( + &relay_ws_client, + &mut rx, + &account, + &identity_key_details, + app_domain.clone(), + &app_client_id, + app_key_agreement_key, + ¬ify_server_client_id, + watch_topic_key, + HashSet::from([notification_type]), + ) + .await; + + let result = get_notifications( + &relay_ws_client, + &mut rx, + &account, + &identity_key_details, + &app_domain, + &app_client_id, + notify_key, + GetNotificationsParams { + limit: 5, + after: None, + }, + ) + .await; + assert_eq!(result.notifications.len(), 1); +} + +#[test_context(NotifyServerContext)] +#[tokio::test] +async fn subscribe_idempotent_keeps_symkey(notify_server: &NotifyServerContext) { + let ( + relay_ws_client, + mut rx, + account, + identity_key_details, + _project_id, + app_domain, + app_client_id, + app_key_agreement_key, + _app_authentication_key, + notify_server_client_id, + watch_topic_key, + ) = setup_project_and_watch(notify_server.url.clone()).await; + + let notify_key1 = subscribe_to_notifications( + &relay_ws_client, + &mut rx, + &account, + &identity_key_details, + app_domain.clone(), + &app_client_id, + app_key_agreement_key, + ¬ify_server_client_id, + watch_topic_key, + HashSet::from([Uuid::new_v4()]), + ) + .await; + + let notify_key2 = subscribe_to_notifications( + &relay_ws_client, + &mut rx, + &account, + &identity_key_details, + app_domain.clone(), + &app_client_id, + app_key_agreement_key, + ¬ify_server_client_id, + watch_topic_key, + HashSet::from([Uuid::new_v4()]), + ) + .await; + + assert_eq!(notify_key1, notify_key2); +} + +#[test_context(NotifyServerContext)] +#[tokio::test] +async fn subscribe_idempotent_updates_notification_types(notify_server: &NotifyServerContext) { + let ( + relay_ws_client, + mut rx, + account, + identity_key_details, + _project_id, + app_domain, + app_client_id, + app_key_agreement_key, + _app_authentication_key, + _notify_server_client_id, + _watch_topic_key, + ) = setup_project_and_watch(notify_server.url.clone()).await; + + let notification_types = HashSet::from([Uuid::new_v4()]); + let subs = subscribe_v1( + &relay_ws_client, + &mut rx, + &account, + &identity_key_details, + app_key_agreement_key, + &app_client_id, + app_domain.clone(), + notification_types.clone(), + ) + .await; + assert_eq!(subs[0].scope, notification_types); + + let notification_types = HashSet::from([Uuid::new_v4()]); + let subs = subscribe_v1( + &relay_ws_client, + &mut rx, + &account, + &identity_key_details, + app_key_agreement_key, + &app_client_id, + app_domain, + notification_types.clone(), + ) + .await; + assert_eq!(subs[0].scope, notification_types); +} + #[test_context(NotifyServerContext)] #[tokio::test] async fn e2e_doesnt_send_welcome_notification(notify_server: &NotifyServerContext) { @@ -8342,13 +8517,20 @@ impl<'s> MigrationSource<'s> for &'s StopMigrator { } } +#[derive(Debug, FromRow)] +pub struct RawSubscribeResponse { + pub id: Uuid, + #[sqlx(try_from = "String")] + pub account: AccountId, +} + pub async fn raw_upsert_subscriber( project: Uuid, account: AccountId, notify_key: &[u8; 32], notify_topic: Topic, postgres: &PgPool, -) -> Result { +) -> Result { let query = " INSERT INTO subscriber ( project, @@ -8360,7 +8542,7 @@ pub async fn raw_upsert_subscriber( VALUES ($1, $2, $3, $4, $5) RETURNING id, account "; - let subscriber = sqlx::query_as::(query) + let subscriber = sqlx::query_as::(query) .bind(project) .bind(account.as_ref()) .bind(hex::encode(notify_key))