Skip to content

Commit

Permalink
added more test coverage for optional user id \n improved comment syn…
Browse files Browse the repository at this point in the history
…tax and meaning
  • Loading branch information
aruokhai committed Jun 29, 2024
1 parent 4dfb470 commit be0b6f2
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 49 deletions.
3 changes: 2 additions & 1 deletion teos/proto/teos/v2/appointment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ package teos.v2;
import "common/teos/v2/appointment.proto";

message GetAppointmentsRequest {
// Request the information of appointments with specific locator and user_id (optional) .
// Request the information of appointments with specific locator.
// If a user id is provided (optional), request only appointments belonging to that user.

bytes locator = 1;
optional bytes user_id = 2;
Expand Down
134 changes: 116 additions & 18 deletions teos/src/api/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,26 +298,23 @@ impl PrivateTowerServices for Arc<InternalAPI> {
)
})?;

let appointments: Vec<Appointment> = self
let mut matching_appointments: Vec<common_msgs::AppointmentData> = self
.watcher
.get_watcher_appointments_with_locator(locator, user_id)
.into_values()
.map(|appointment| appointment.inner)
.collect();

let mut matching_appointments: Vec<common_msgs::AppointmentData> = appointments
.into_iter()
.map(|appointment| common_msgs::AppointmentData {
appointment_data: Some(
common_msgs::appointment_data::AppointmentData::Appointment(appointment.into()),
common_msgs::appointment_data::AppointmentData::Appointment(
appointment.inner.into(),
),
),
})
.collect();

for (_, tracker) in self
for tracker in self
.watcher
.get_responder_trackers_with_locator(locator, user_id)
.into_iter()
.into_values()
{
matching_appointments.push(common_msgs::AppointmentData {
appointment_data: Some(common_msgs::appointment_data::AppointmentData::Tracker(
Expand Down Expand Up @@ -445,6 +442,8 @@ mod tests_private_api {
use bitcoin::hashes::Hash;
use bitcoin::Txid;

use rand::{self, thread_rng, Rng};

use crate::responder::{ConfirmationStatus, TransactionTracker};
use crate::test_utils::{
create_api, generate_dummy_appointment, generate_dummy_appointment_with_user,
Expand All @@ -453,7 +452,7 @@ mod tests_private_api {
use crate::watcher::Breach;

use teos_common::cryptography::{self, get_random_keypair};
use teos_common::test_utils::get_random_user_id;
use teos_common::test_utils::{get_random_locator, get_random_user_id};

#[tokio::test]
async fn test_get_all_appointments() {
Expand Down Expand Up @@ -531,18 +530,33 @@ mod tests_private_api {
.into_inner();

assert!(matches!(response, msgs::GetAppointmentsResponse { .. }));

let user_id = get_random_user_id().to_vec();
let locator = get_random_locator().to_vec();
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator,
user_id: Some(user_id),
}))
.await
.unwrap()
.into_inner();

assert!(matches!(response, msgs::GetAppointmentsResponse { .. }));
}

#[tokio::test]
async fn test_get_appointments_watcher() {
async fn test_get_appointments_watcher_without_userid() {
let (internal_api, _s) = create_api().await;

for i in 0..3 {
// Create a dispute tx to be used for creating different dummy appointments with the same locator.
let dispute_txid = get_random_tx().txid();
let locator = Locator::new(dispute_txid);

// The number of different appointments to create for this dispute tx.
let appointments_to_create = 4 * i + 7;
let random_number = 4 * i + 7;
let appointments_to_create = random_number;

// Add that many appointments to the watcher.
for _ in 0..appointments_to_create {
Expand All @@ -556,8 +570,6 @@ mod tests_private_api {
.unwrap();
}

let locator = Locator::new(dispute_txid);

// Query for the current locator and assert it retrieves correct appointments.
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
Expand All @@ -584,6 +596,62 @@ mod tests_private_api {
}
}

#[tokio::test]
async fn test_get_appointments_watcher_with_userid() {
let (internal_api, _s) = create_api().await;

for i in 0..3 {
// Create a dispute tx to be used for creating different dummy appointments with the same locator.
let dispute_txid = get_random_tx().txid();
let locator = Locator::new(dispute_txid);

// The number of different appointments to create for this dispute tx.
let random_number = 4 * i + 7;
let appointments_to_create = random_number;

let mut random_users_list = Vec::new();

// Add that many appointments to the watcher.
for _ in 0..appointments_to_create {
let (user_sk, user_pk) = get_random_keypair();
let user_id = UserId(user_pk);
internal_api.watcher.register(user_id).unwrap();
random_users_list.push(user_id);
let appointment = generate_dummy_appointment(Some(&dispute_txid)).inner;
let signature = cryptography::sign(&appointment.to_vec(), &user_sk).unwrap();
internal_api
.watcher
.add_appointment(appointment, signature)
.unwrap();
}

for user_id in random_users_list.into_iter() {
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
user_id: Some(user_id.to_vec()),
}))
.await
.unwrap()
.into_inner();

// Verify that only a single appointment is returned
assert_eq!(response.appointments.len(), 1);

// Verify that the appointment have the current locator
assert!(matches!(
response.appointments[0].appointment_data,
Some(common_msgs::appointment_data::AppointmentData::Appointment(
common_msgs::Appointment {
locator: ref app_loc,
..
}
)) if Locator::from_slice(app_loc).unwrap() == locator
));
}
}
}

#[tokio::test]
async fn test_get_appointments_responder() {
let (internal_api, _s) = create_api().await;
Expand All @@ -596,11 +664,19 @@ mod tests_private_api {
// The number of different trackers to create for this dispute tx.
let trackers_to_create = 4 * i + 7;

let random_tracker_num = thread_rng().gen_range(0..trackers_to_create);
let random_user_id = get_random_user_id();

// Add that many trackers to the responder.
for _ in 0..trackers_to_create {
for i in 0..trackers_to_create {
let user_id = if i == random_tracker_num {
random_user_id
} else {
get_random_user_id()
};
let tracker = TransactionTracker::new(
breach.clone(),
get_random_user_id(),
user_id,
ConfirmationStatus::ConfirmedIn(100),
);
internal_api
Expand All @@ -610,7 +686,7 @@ mod tests_private_api {

let locator = Locator::new(dispute_tx.txid());

// Query for the current locator and assert it retrieves correct trackers.
// Query for the current locator without the optional user_id.
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
Expand All @@ -620,7 +696,7 @@ mod tests_private_api {
.unwrap()
.into_inner();

// The response should contain `trackers_to_create` trackers, all with dispute txid that matches with the locator of the current iteration.
// Verify that the response should contain `trackers_to_create` trackers, all with dispute txid that matches with the locator of the current iteration.
assert_eq!(response.appointments.len(), trackers_to_create);
for app_data in response.appointments {
assert!(matches!(
Expand All @@ -633,6 +709,28 @@ mod tests_private_api {
)) if Locator::new(Txid::from_slice(dispute_txid).unwrap()) == locator
));
}

// Query for the current locator with the optional user_id present.
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
user_id: Some(random_user_id.to_vec()),
}))
.await
.unwrap()
.into_inner();

// Verify that only a single appointment is returned and the correct locator is found
assert_eq!(response.appointments.len(), 1);
assert!(matches!(
response.appointments[0].appointment_data,
Some(common_msgs::appointment_data::AppointmentData::Tracker(
common_msgs::Tracker {
ref dispute_txid,
..
}
)) if Locator::new(Txid::from_slice(dispute_txid).unwrap()) == locator
));
}
}

Expand Down
57 changes: 30 additions & 27 deletions teos/src/dbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,15 @@ impl DBM {
"SELECT a.UUID, a.locator, a.encrypted_blob, a.to_self_delay, a.user_signature, a.start_block, a.user_id
FROM appointments as a LEFT JOIN trackers as t ON a.UUID=t.UUID WHERE t.UUID IS NULL".to_string();

// If a locator and an optional user_id were passed, filter based on it.
if let Some((_, user_id)) = locator_and_userid {
// If a locator was passed, filter based on it.
if locator_and_userid.is_some() {
sql.push_str(" AND a.locator=(?1)");
if user_id.is_some() {
sql.push_str(" AND a.user_id=(?2)");
}
};
}

// If a user_id is passed, filter even more.
if locator_and_userid.is_some_and(|inner| inner.1.is_some()) {
sql.push_str(" AND a.user_id=(?2)");
}

let mut stmt = self.connection.prepare(&sql).unwrap();

Expand Down Expand Up @@ -611,12 +613,14 @@ impl DBM {
FROM trackers as t INNER JOIN appointments as a ON t.UUID=a.UUID"
.to_string();

// If a locator and an optional user_id were passed, filter based on it.
if let Some((_, user_id)) = locator_and_userid {
// If a locator was passed, filter based on it.
if locator_and_userid.is_some() {
sql.push_str(" AND a.locator=(?1)");
if user_id.is_some() {
sql.push_str(" AND a.user_id=(?2)");
}
}

// If a user_id is passed, filter even more.
if locator_and_userid.is_some_and(|inner| inner.1.is_some()) {
sql.push_str(" AND a.user_id=(?2)");
}

let mut stmt = self.connection.prepare(&sql).unwrap();
Expand Down Expand Up @@ -774,8 +778,8 @@ mod tests {

use crate::rpc_errors;
use crate::test_utils::{
generate_dummy_appointment, generate_dummy_appointment_with_user, generate_uuid,
get_random_tracker, get_random_tx, AVAILABLE_SLOTS, SUBSCRIPTION_EXPIRY,
generate_dummy_appointment, generate_dummy_appointment_with_user, generate_dummy_tracker,
generate_uuid, get_random_tracker, get_random_tx, AVAILABLE_SLOTS, SUBSCRIPTION_EXPIRY,
SUBSCRIPTION_START,
};

Expand Down Expand Up @@ -1201,7 +1205,7 @@ mod tests {
let dispute_txid = dispute_tx.txid();
let locator = Locator::new(dispute_txid);

// create user id
// Create user id
let user_id = get_random_user_id();
let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY);
dbm.store_user(user_id, &user).unwrap();
Expand All @@ -1212,14 +1216,14 @@ mod tests {
dbm.store_appointment(uuid, &appointment).unwrap();
appointments.insert(uuid, appointment.clone());

// create random appointments
// Create random appointments
for _ in 1..11 {
let (uuid, appointment) = generate_dummy_appointment_with_user(user_id, None);
dbm.store_appointment(uuid, &appointment).unwrap();
appointments.insert(uuid, appointment);
}

// Returns empty if no appointment matches both userid and locator
// Verify that no appointment is returned if there is not an exact match of user_id + locator
assert_eq!(
dbm.load_appointments(Some((locator, Some(get_random_user_id()))),),
HashMap::new()
Expand All @@ -1229,17 +1233,17 @@ mod tests {
HashMap::new()
);

// Returns particular appointments if they match both userid and locator
// Verify that the expected appointment is returned, if the correct user_id and locator is given
assert_eq!(
dbm.load_appointments(Some((locator, Some(user_id))),),
HashMap::from([(uuid, appointment)])
);

// Create a tracker from existing appointment
let tracker = get_random_tracker(user_id, ConfirmationStatus::InMempoolSince(100));
// Create a tracker from the existing appointment
let tracker = generate_dummy_tracker(user_id, dispute_tx.clone());
dbm.store_tracker(uuid, &tracker).unwrap();

// ensure that no tracker is returned
// Verify that an appointment is not returned, if it is triggered (there's a tracker for it)
assert_eq!(
dbm.load_appointments(Some((locator, Some(user_id))),),
HashMap::new()
Expand Down Expand Up @@ -1622,30 +1626,29 @@ mod tests {
let dispute_tx = get_random_tx();
let dispute_txid = dispute_tx.txid();
let locator = Locator::new(dispute_txid);
let status = ConfirmationStatus::InMempoolSince(42);

// create user id
// Create user id
let user_id = get_random_user_id();
let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY);
dbm.store_user(user_id, &user).unwrap();

// Create and store a particular tracker
let (uuid, appointment) =
generate_dummy_appointment_with_user(user_id, Some(&dispute_txid));
let tracker = get_random_tracker(user_id, status);
let tracker = generate_dummy_tracker(user_id, dispute_tx.clone());
dbm.store_appointment(uuid, &appointment).unwrap();
dbm.store_tracker(uuid, &tracker).unwrap();
trackers.insert(uuid, tracker.clone());

// create random trackers
// Create random trackers
for _ in 1..11 {
let (uuid, appointment) = generate_dummy_appointment_with_user(user_id, None);
let tracker = get_random_tracker(user_id, status);
let tracker = generate_dummy_tracker(user_id, dispute_tx.clone());
dbm.store_appointment(uuid, &appointment).unwrap();
dbm.store_tracker(uuid, &tracker).unwrap();
}

// Returns empty if no tracker matches both userid and locator
// Verify that no tracker is returned if there is not an exact match of user_id + locator
assert_eq!(
dbm.load_trackers(Some((locator, Some(get_random_user_id()))),),
HashMap::new()
Expand All @@ -1655,7 +1658,7 @@ mod tests {
HashMap::new()
);

// Returns particular trackers if they match both userid and locator
// Verify that the expected tracker is returned if both the correct user_id and locator are provided
assert_eq!(
dbm.load_trackers(Some((locator, Some(user_id))),),
HashMap::from([(uuid, tracker)])
Expand Down
Loading

0 comments on commit be0b6f2

Please sign in to comment.