diff --git a/teos/proto/teos/v2/appointment.proto b/teos/proto/teos/v2/appointment.proto index ecb873ee..07aab6f1 100644 --- a/teos/proto/teos/v2/appointment.proto +++ b/teos/proto/teos/v2/appointment.proto @@ -4,7 +4,8 @@ package teos.v2; import "common/teos/v2/appointment.proto"; message GetAppointmentsRequest { - // Request the information of appointments with specific locator and user_id (optional) . + // Request the information of appointments with specific locator. + // If a user id is provided (optional), request only appointments belonging to that user. bytes locator = 1; optional bytes user_id = 2; diff --git a/teos/src/api/internal.rs b/teos/src/api/internal.rs index fdfd13dd..7a5e5ee5 100644 --- a/teos/src/api/internal.rs +++ b/teos/src/api/internal.rs @@ -298,26 +298,23 @@ impl PrivateTowerServices for Arc { ) })?; - let appointments: Vec = self + let mut matching_appointments: Vec = self .watcher .get_watcher_appointments_with_locator(locator, user_id) .into_values() - .map(|appointment| appointment.inner) - .collect(); - - let mut matching_appointments: Vec = appointments - .into_iter() .map(|appointment| common_msgs::AppointmentData { appointment_data: Some( - common_msgs::appointment_data::AppointmentData::Appointment(appointment.into()), + common_msgs::appointment_data::AppointmentData::Appointment( + appointment.inner.into(), + ), ), }) .collect(); - for (_, tracker) in self + for tracker in self .watcher .get_responder_trackers_with_locator(locator, user_id) - .into_iter() + .into_values() { matching_appointments.push(common_msgs::AppointmentData { appointment_data: Some(common_msgs::appointment_data::AppointmentData::Tracker( @@ -445,6 +442,8 @@ mod tests_private_api { use bitcoin::hashes::Hash; use bitcoin::Txid; + use rand::{self, thread_rng, Rng}; + use crate::responder::{ConfirmationStatus, TransactionTracker}; use crate::test_utils::{ create_api, generate_dummy_appointment, generate_dummy_appointment_with_user, @@ -453,7 +452,7 @@ mod tests_private_api { use crate::watcher::Breach; use teos_common::cryptography::{self, get_random_keypair}; - use teos_common::test_utils::get_random_user_id; + use teos_common::test_utils::{get_random_locator, get_random_user_id}; #[tokio::test] async fn test_get_all_appointments() { @@ -531,18 +530,33 @@ mod tests_private_api { .into_inner(); assert!(matches!(response, msgs::GetAppointmentsResponse { .. })); + + let user_id = get_random_user_id().to_vec(); + let locator = get_random_locator().to_vec(); + let response = internal_api + .get_appointments(Request::new(msgs::GetAppointmentsRequest { + locator, + user_id: Some(user_id), + })) + .await + .unwrap() + .into_inner(); + + assert!(matches!(response, msgs::GetAppointmentsResponse { .. })); } #[tokio::test] - async fn test_get_appointments_watcher() { + async fn test_get_appointments_watcher_without_userid() { let (internal_api, _s) = create_api().await; for i in 0..3 { // Create a dispute tx to be used for creating different dummy appointments with the same locator. let dispute_txid = get_random_tx().txid(); + let locator = Locator::new(dispute_txid); // The number of different appointments to create for this dispute tx. - let appointments_to_create = 4 * i + 7; + let random_number = 4 * i + 7; + let appointments_to_create = random_number; // Add that many appointments to the watcher. for _ in 0..appointments_to_create { @@ -556,8 +570,6 @@ mod tests_private_api { .unwrap(); } - let locator = Locator::new(dispute_txid); - // Query for the current locator and assert it retrieves correct appointments. let response = internal_api .get_appointments(Request::new(msgs::GetAppointmentsRequest { @@ -584,6 +596,62 @@ mod tests_private_api { } } + #[tokio::test] + async fn test_get_appointments_watcher_with_userid() { + let (internal_api, _s) = create_api().await; + + for i in 0..3 { + // Create a dispute tx to be used for creating different dummy appointments with the same locator. + let dispute_txid = get_random_tx().txid(); + let locator = Locator::new(dispute_txid); + + // The number of different appointments to create for this dispute tx. + let random_number = 4 * i + 7; + let appointments_to_create = random_number; + + let mut random_users_list = Vec::new(); + + // Add that many appointments to the watcher. + for _ in 0..appointments_to_create { + let (user_sk, user_pk) = get_random_keypair(); + let user_id = UserId(user_pk); + internal_api.watcher.register(user_id).unwrap(); + random_users_list.push(user_id); + let appointment = generate_dummy_appointment(Some(&dispute_txid)).inner; + let signature = cryptography::sign(&appointment.to_vec(), &user_sk).unwrap(); + internal_api + .watcher + .add_appointment(appointment, signature) + .unwrap(); + } + + for user_id in random_users_list.into_iter() { + let response = internal_api + .get_appointments(Request::new(msgs::GetAppointmentsRequest { + locator: locator.to_vec(), + user_id: Some(user_id.to_vec()), + })) + .await + .unwrap() + .into_inner(); + + // Verify that only a single appointment is returned + assert_eq!(response.appointments.len(), 1); + + // Verify that the appointment have the current locator + assert!(matches!( + response.appointments[0].appointment_data, + Some(common_msgs::appointment_data::AppointmentData::Appointment( + common_msgs::Appointment { + locator: ref app_loc, + .. + } + )) if Locator::from_slice(app_loc).unwrap() == locator + )); + } + } + } + #[tokio::test] async fn test_get_appointments_responder() { let (internal_api, _s) = create_api().await; @@ -596,11 +664,19 @@ mod tests_private_api { // The number of different trackers to create for this dispute tx. let trackers_to_create = 4 * i + 7; + let random_tracker_num = thread_rng().gen_range(0..trackers_to_create); + let random_user_id = get_random_user_id(); + // Add that many trackers to the responder. - for _ in 0..trackers_to_create { + for i in 0..trackers_to_create { + let user_id = if i == random_tracker_num { + random_user_id + } else { + get_random_user_id() + }; let tracker = TransactionTracker::new( breach.clone(), - get_random_user_id(), + user_id, ConfirmationStatus::ConfirmedIn(100), ); internal_api @@ -610,7 +686,7 @@ mod tests_private_api { let locator = Locator::new(dispute_tx.txid()); - // Query for the current locator and assert it retrieves correct trackers. + // Query for the current locator without the optional user_id. let response = internal_api .get_appointments(Request::new(msgs::GetAppointmentsRequest { locator: locator.to_vec(), @@ -620,7 +696,7 @@ mod tests_private_api { .unwrap() .into_inner(); - // The response should contain `trackers_to_create` trackers, all with dispute txid that matches with the locator of the current iteration. + // Verify that the response should contain `trackers_to_create` trackers, all with dispute txid that matches with the locator of the current iteration. assert_eq!(response.appointments.len(), trackers_to_create); for app_data in response.appointments { assert!(matches!( @@ -633,6 +709,28 @@ mod tests_private_api { )) if Locator::new(Txid::from_slice(dispute_txid).unwrap()) == locator )); } + + // Query for the current locator with the optional user_id present. + let response = internal_api + .get_appointments(Request::new(msgs::GetAppointmentsRequest { + locator: locator.to_vec(), + user_id: Some(random_user_id.to_vec()), + })) + .await + .unwrap() + .into_inner(); + + // Verify that only a single appointment is returned and the correct locator is found + assert_eq!(response.appointments.len(), 1); + assert!(matches!( + response.appointments[0].appointment_data, + Some(common_msgs::appointment_data::AppointmentData::Tracker( + common_msgs::Tracker { + ref dispute_txid, + .. + } + )) if Locator::new(Txid::from_slice(dispute_txid).unwrap()) == locator + )); } } diff --git a/teos/src/config.rs b/teos/src/config.rs index 109dd91b..f1db3de0 100644 --- a/teos/src/config.rs +++ b/teos/src/config.rs @@ -18,13 +18,10 @@ pub fn data_dir_absolute_path(data_dir: String) -> PathBuf { pub fn from_file(path: &PathBuf) -> T { match std::fs::read(path) { - Ok(file_content) => toml::from_slice::(&file_content).map_or_else( - |e| { - eprintln!("Couldn't parse config file: {e}"); - T::default() - }, - |config| config, - ), + Ok(file_content) => toml::from_slice::(&file_content).unwrap_or_else(|e| { + eprintln!("Couldn't parse config file: {e}"); + T::default() + }), Err(_) => T::default(), } } @@ -392,7 +389,7 @@ mod tests { assert_eq!(config.api_bind, expected_value); // Check the rest of fields are equal. The easiest is to just the field back and compare with a clone - config.api_bind = config_clone.api_bind.clone(); + config.api_bind.clone_from(&config_clone.api_bind); assert_eq!(config, config_clone); } diff --git a/teos/src/dbm.rs b/teos/src/dbm.rs index 3d38e348..0c173a54 100644 --- a/teos/src/dbm.rs +++ b/teos/src/dbm.rs @@ -338,13 +338,15 @@ impl DBM { "SELECT a.UUID, a.locator, a.encrypted_blob, a.to_self_delay, a.user_signature, a.start_block, a.user_id FROM appointments as a LEFT JOIN trackers as t ON a.UUID=t.UUID WHERE t.UUID IS NULL".to_string(); - // If a locator and an optional user_id were passed, filter based on it. - if let Some((_, user_id)) = locator_and_userid { + // If a locator was passed, filter based on it. + if locator_and_userid.is_some() { sql.push_str(" AND a.locator=(?1)"); - if user_id.is_some() { - sql.push_str(" AND a.user_id=(?2)"); - } - }; + } + + // If a user_id is passed, filter even more. + if locator_and_userid.is_some_and(|inner| inner.1.is_some()) { + sql.push_str(" AND a.user_id=(?2)"); + } let mut stmt = self.connection.prepare(&sql).unwrap(); @@ -611,12 +613,14 @@ impl DBM { FROM trackers as t INNER JOIN appointments as a ON t.UUID=a.UUID" .to_string(); - // If a locator and an optional user_id were passed, filter based on it. - if let Some((_, user_id)) = locator_and_userid { + // If a locator was passed, filter based on it. + if locator_and_userid.is_some() { sql.push_str(" AND a.locator=(?1)"); - if user_id.is_some() { - sql.push_str(" AND a.user_id=(?2)"); - } + } + + // If a user_id is passed, filter even more. + if locator_and_userid.is_some_and(|inner| inner.1.is_some()) { + sql.push_str(" AND a.user_id=(?2)"); } let mut stmt = self.connection.prepare(&sql).unwrap(); @@ -774,8 +778,8 @@ mod tests { use crate::rpc_errors; use crate::test_utils::{ - generate_dummy_appointment, generate_dummy_appointment_with_user, generate_uuid, - get_random_tracker, get_random_tx, AVAILABLE_SLOTS, SUBSCRIPTION_EXPIRY, + generate_dummy_appointment, generate_dummy_appointment_with_user, generate_dummy_tracker, + generate_uuid, get_random_tracker, get_random_tx, AVAILABLE_SLOTS, SUBSCRIPTION_EXPIRY, SUBSCRIPTION_START, }; @@ -1201,7 +1205,7 @@ mod tests { let dispute_txid = dispute_tx.txid(); let locator = Locator::new(dispute_txid); - // create user id + // Create user id let user_id = get_random_user_id(); let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY); dbm.store_user(user_id, &user).unwrap(); @@ -1212,14 +1216,14 @@ mod tests { dbm.store_appointment(uuid, &appointment).unwrap(); appointments.insert(uuid, appointment.clone()); - // create random appointments + // Create random appointments for _ in 1..11 { let (uuid, appointment) = generate_dummy_appointment_with_user(user_id, None); dbm.store_appointment(uuid, &appointment).unwrap(); appointments.insert(uuid, appointment); } - // Returns empty if no appointment matches both userid and locator + // Verify that no appointment is returned if there is not an exact match of user_id + locator assert_eq!( dbm.load_appointments(Some((locator, Some(get_random_user_id()))),), HashMap::new() @@ -1229,17 +1233,17 @@ mod tests { HashMap::new() ); - // Returns particular appointments if they match both userid and locator + // Verify that the expected appointment is returned, if the correct user_id and locator is given assert_eq!( dbm.load_appointments(Some((locator, Some(user_id))),), HashMap::from([(uuid, appointment)]) ); - // Create a tracker from existing appointment - let tracker = get_random_tracker(user_id, ConfirmationStatus::InMempoolSince(100)); + // Create a tracker from the existing appointment + let tracker = generate_dummy_tracker(user_id, dispute_tx.clone()); dbm.store_tracker(uuid, &tracker).unwrap(); - // ensure that no tracker is returned + // Verify that an appointment is not returned, if it is triggered (there's a tracker for it) assert_eq!( dbm.load_appointments(Some((locator, Some(user_id))),), HashMap::new() @@ -1622,9 +1626,8 @@ mod tests { let dispute_tx = get_random_tx(); let dispute_txid = dispute_tx.txid(); let locator = Locator::new(dispute_txid); - let status = ConfirmationStatus::InMempoolSince(42); - // create user id + // Create user id let user_id = get_random_user_id(); let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY); dbm.store_user(user_id, &user).unwrap(); @@ -1632,20 +1635,20 @@ mod tests { // Create and store a particular tracker let (uuid, appointment) = generate_dummy_appointment_with_user(user_id, Some(&dispute_txid)); - let tracker = get_random_tracker(user_id, status); + let tracker = generate_dummy_tracker(user_id, dispute_tx.clone()); dbm.store_appointment(uuid, &appointment).unwrap(); dbm.store_tracker(uuid, &tracker).unwrap(); trackers.insert(uuid, tracker.clone()); - // create random trackers + // Create random trackers for _ in 1..11 { let (uuid, appointment) = generate_dummy_appointment_with_user(user_id, None); - let tracker = get_random_tracker(user_id, status); + let tracker = generate_dummy_tracker(user_id, dispute_tx.clone()); dbm.store_appointment(uuid, &appointment).unwrap(); dbm.store_tracker(uuid, &tracker).unwrap(); } - // Returns empty if no tracker matches both userid and locator + // Verify that no tracker is returned if there is not an exact match of user_id + locator assert_eq!( dbm.load_trackers(Some((locator, Some(get_random_user_id()))),), HashMap::new() @@ -1655,7 +1658,7 @@ mod tests { HashMap::new() ); - // Returns particular trackers if they match both userid and locator + // Verify that the expected tracker is returned if both the correct user_id and locator are provided assert_eq!( dbm.load_trackers(Some((locator, Some(user_id))),), HashMap::from([(uuid, tracker)]) diff --git a/teos/src/test_utils.rs b/teos/src/test_utils.rs index 4dba0a93..2366b195 100644 --- a/teos/src/test_utils.rs +++ b/teos/src/test_utils.rs @@ -341,6 +341,17 @@ pub(crate) fn get_random_tracker( TransactionTracker::new(breach, user_id, status) } +pub(crate) fn generate_dummy_tracker( + user_id: UserId, + dispute_tx: Transaction, +) -> TransactionTracker { + TransactionTracker::new( + Breach::new(dispute_tx.clone(), get_random_tx()), + user_id, + ConfirmationStatus::ConfirmedIn(100), + ) +} + pub(crate) fn store_appointment_and_its_user(dbm: &DBM, appointment: &ExtendedAppointment) { dbm.store_user( appointment.user_id, diff --git a/teos/src/tx_index.rs b/teos/src/tx_index.rs index c9d22b4e..404c0474 100644 --- a/teos/src/tx_index.rs +++ b/teos/src/tx_index.rs @@ -378,7 +378,7 @@ mod tests { // Check that the block data is not in the cache anymore assert_eq!(cache.blocks().len(), cache.size - i - 1); assert!(!cache.blocks().contains(&header.block_hash())); - assert!(cache.tx_in_block.get(&header.block_hash()).is_none()); + assert!(!cache.tx_in_block.contains_key(&header.block_hash())); for locator in locators.iter() { assert!(!cache.contains_key(locator)); } diff --git a/teos/src/watcher.rs b/teos/src/watcher.rs index 1ee78af1..23147dec 100644 --- a/teos/src/watcher.rs +++ b/teos/src/watcher.rs @@ -423,7 +423,8 @@ impl Watcher { self.dbm.lock().unwrap().load_appointments(None) } - /// Gets all the appointments matching a specific locator and an optional user id from the [Watcher] (from the database). + /// Gets all the appointments matching a specific locator + /// If a user id is provided (optional), only the appointments matching that user are returned pub(crate) fn get_watcher_appointments_with_locator( &self, locator: Locator, @@ -435,12 +436,12 @@ impl Watcher { .load_appointments(Some((locator, user_id))) } - /// Gets all the trackers stored in the [Responder] (from the database). + /// Gets all the trackers stored in the [Responder]. pub(crate) fn get_all_responder_trackers(&self) -> HashMap { self.dbm.lock().unwrap().load_trackers(None) } - /// Gets all the trackers matching a specific locator and an optional user id from the [Responder] (from the database). + /// Gets all the trackers matching a specific locator and an optional user id from the [Responder]. pub(crate) fn get_responder_trackers_with_locator( &self, locator: Locator, diff --git a/watchtower-plugin/src/main.rs b/watchtower-plugin/src/main.rs index 5778e8d0..1b845a64 100755 --- a/watchtower-plugin/src/main.rs +++ b/watchtower-plugin/src/main.rs @@ -385,7 +385,7 @@ async fn abandon_tower( ) -> Result { let tower_id = TowerId::try_from(v).map_err(|e| anyhow!(e))?; let mut state = plugin.state().lock().unwrap(); - if state.towers.get(&tower_id).is_some() { + if state.towers.contains_key(&tower_id) { state.remove_tower(tower_id).unwrap(); Ok(json!(format!("{tower_id} successfully abandoned"))) } else { diff --git a/watchtower-plugin/src/retrier.rs b/watchtower-plugin/src/retrier.rs index 0920e366..e61c9b33 100644 --- a/watchtower-plugin/src/retrier.rs +++ b/watchtower-plugin/src/retrier.rs @@ -432,7 +432,7 @@ impl Retrier { // Create a new scope so we can get all the data only locking the WTClient once. let (tower_id, status, net_addr, user_id, user_sk, proxy) = { let wt_client = self.wt_client.lock().unwrap(); - if wt_client.towers.get(&self.tower_id).is_none() { + if !wt_client.towers.contains_key(&self.tower_id) { return Err(Error::permanent(RetryError::Abandoned)); }