Skip to content

Commit

Permalink
added more test coverage for optional user id \n improved comment syn…
Browse files Browse the repository at this point in the history
…tax and meaning
  • Loading branch information
aruokhai committed Jul 1, 2024
1 parent 4dfb470 commit ca497ee
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 60 deletions.
3 changes: 2 additions & 1 deletion teos/proto/teos/v2/appointment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ package teos.v2;
import "common/teos/v2/appointment.proto";

message GetAppointmentsRequest {
// Request the information of appointments with specific locator and user_id (optional) .
// Request the information of appointments with specific locator.
// If a user id is provided (optional), request only appointments belonging to that user.

bytes locator = 1;
optional bytes user_id = 2;
Expand Down
134 changes: 116 additions & 18 deletions teos/src/api/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,26 +298,23 @@ impl PrivateTowerServices for Arc<InternalAPI> {
)
})?;

let appointments: Vec<Appointment> = self
let mut matching_appointments: Vec<common_msgs::AppointmentData> = self
.watcher
.get_watcher_appointments_with_locator(locator, user_id)
.into_values()
.map(|appointment| appointment.inner)
.collect();

let mut matching_appointments: Vec<common_msgs::AppointmentData> = appointments
.into_iter()
.map(|appointment| common_msgs::AppointmentData {
appointment_data: Some(
common_msgs::appointment_data::AppointmentData::Appointment(appointment.into()),
common_msgs::appointment_data::AppointmentData::Appointment(
appointment.inner.into(),
),
),
})
.collect();

for (_, tracker) in self
for tracker in self
.watcher
.get_responder_trackers_with_locator(locator, user_id)
.into_iter()
.into_values()
{
matching_appointments.push(common_msgs::AppointmentData {
appointment_data: Some(common_msgs::appointment_data::AppointmentData::Tracker(
Expand Down Expand Up @@ -445,6 +442,8 @@ mod tests_private_api {
use bitcoin::hashes::Hash;
use bitcoin::Txid;

use rand::{self, thread_rng, Rng};

use crate::responder::{ConfirmationStatus, TransactionTracker};
use crate::test_utils::{
create_api, generate_dummy_appointment, generate_dummy_appointment_with_user,
Expand All @@ -453,7 +452,7 @@ mod tests_private_api {
use crate::watcher::Breach;

use teos_common::cryptography::{self, get_random_keypair};
use teos_common::test_utils::get_random_user_id;
use teos_common::test_utils::{get_random_locator, get_random_user_id};

#[tokio::test]
async fn test_get_all_appointments() {
Expand Down Expand Up @@ -531,18 +530,33 @@ mod tests_private_api {
.into_inner();

assert!(matches!(response, msgs::GetAppointmentsResponse { .. }));

let user_id = get_random_user_id().to_vec();
let locator = get_random_locator().to_vec();
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator,
user_id: Some(user_id),
}))
.await
.unwrap()
.into_inner();

assert!(matches!(response, msgs::GetAppointmentsResponse { .. }));
}

#[tokio::test]
async fn test_get_appointments_watcher() {
async fn test_get_appointments_watcher_without_userid() {
let (internal_api, _s) = create_api().await;

for i in 0..3 {
// Create a dispute tx to be used for creating different dummy appointments with the same locator.
let dispute_txid = get_random_tx().txid();
let locator = Locator::new(dispute_txid);

// The number of different appointments to create for this dispute tx.
let appointments_to_create = 4 * i + 7;
let random_number = 4 * i + 7;
let appointments_to_create = random_number;

// Add that many appointments to the watcher.
for _ in 0..appointments_to_create {
Expand All @@ -556,8 +570,6 @@ mod tests_private_api {
.unwrap();
}

let locator = Locator::new(dispute_txid);

// Query for the current locator and assert it retrieves correct appointments.
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
Expand All @@ -584,6 +596,62 @@ mod tests_private_api {
}
}

#[tokio::test]
async fn test_get_appointments_watcher_with_userid() {
let (internal_api, _s) = create_api().await;

for i in 0..3 {
// Create a dispute tx to be used for creating different dummy appointments with the same locator.
let dispute_txid = get_random_tx().txid();
let locator = Locator::new(dispute_txid);

// The number of different appointments to create for this dispute tx.
let random_number = 4 * i + 7;
let appointments_to_create = random_number;

let mut random_users_list = Vec::new();

// Add that many appointments to the watcher.
for _ in 0..appointments_to_create {
let (user_sk, user_pk) = get_random_keypair();
let user_id = UserId(user_pk);
internal_api.watcher.register(user_id).unwrap();
random_users_list.push(user_id);
let appointment = generate_dummy_appointment(Some(&dispute_txid)).inner;
let signature = cryptography::sign(&appointment.to_vec(), &user_sk).unwrap();
internal_api
.watcher
.add_appointment(appointment, signature)
.unwrap();
}

for user_id in random_users_list.into_iter() {
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
user_id: Some(user_id.to_vec()),
}))
.await
.unwrap()
.into_inner();

// Verify that only a single appointment is returned
assert_eq!(response.appointments.len(), 1);

// Verify that the appointment have the current locator
assert!(matches!(
response.appointments[0].appointment_data,
Some(common_msgs::appointment_data::AppointmentData::Appointment(
common_msgs::Appointment {
locator: ref app_loc,
..
}
)) if Locator::from_slice(app_loc).unwrap() == locator
));
}
}
}

#[tokio::test]
async fn test_get_appointments_responder() {
let (internal_api, _s) = create_api().await;
Expand All @@ -596,11 +664,19 @@ mod tests_private_api {
// The number of different trackers to create for this dispute tx.
let trackers_to_create = 4 * i + 7;

let random_tracker_num = thread_rng().gen_range(0..trackers_to_create);
let random_user_id = get_random_user_id();

// Add that many trackers to the responder.
for _ in 0..trackers_to_create {
for i in 0..trackers_to_create {
let user_id = if i == random_tracker_num {
random_user_id
} else {
get_random_user_id()
};
let tracker = TransactionTracker::new(
breach.clone(),
get_random_user_id(),
user_id,
ConfirmationStatus::ConfirmedIn(100),
);
internal_api
Expand All @@ -610,7 +686,7 @@ mod tests_private_api {

let locator = Locator::new(dispute_tx.txid());

// Query for the current locator and assert it retrieves correct trackers.
// Query for the current locator without the optional user_id.
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
Expand All @@ -620,7 +696,7 @@ mod tests_private_api {
.unwrap()
.into_inner();

// The response should contain `trackers_to_create` trackers, all with dispute txid that matches with the locator of the current iteration.
// Verify that the response should contain `trackers_to_create` trackers, all with dispute txid that matches with the locator of the current iteration.
assert_eq!(response.appointments.len(), trackers_to_create);
for app_data in response.appointments {
assert!(matches!(
Expand All @@ -633,6 +709,28 @@ mod tests_private_api {
)) if Locator::new(Txid::from_slice(dispute_txid).unwrap()) == locator
));
}

// Query for the current locator with the optional user_id present.
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
user_id: Some(random_user_id.to_vec()),
}))
.await
.unwrap()
.into_inner();

// Verify that only a single appointment is returned and the correct locator is found
assert_eq!(response.appointments.len(), 1);
assert!(matches!(
response.appointments[0].appointment_data,
Some(common_msgs::appointment_data::AppointmentData::Tracker(
common_msgs::Tracker {
ref dispute_txid,
..
}
)) if Locator::new(Txid::from_slice(dispute_txid).unwrap()) == locator
));
}
}

Expand Down
13 changes: 5 additions & 8 deletions teos/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@ pub fn data_dir_absolute_path(data_dir: String) -> PathBuf {

pub fn from_file<T: Default + serde::de::DeserializeOwned>(path: &PathBuf) -> T {
match std::fs::read(path) {
Ok(file_content) => toml::from_slice::<T>(&file_content).map_or_else(
|e| {
eprintln!("Couldn't parse config file: {e}");
T::default()
},
|config| config,
),
Ok(file_content) => toml::from_slice::<T>(&file_content).unwrap_or_else(|e| {
eprintln!("Couldn't parse config file: {e}");
T::default()
}),
Err(_) => T::default(),
}
}
Expand Down Expand Up @@ -392,7 +389,7 @@ mod tests {
assert_eq!(config.api_bind, expected_value);

// Check the rest of fields are equal. The easiest is to just the field back and compare with a clone
config.api_bind = config_clone.api_bind.clone();
config.api_bind.clone_from(&config_clone.api_bind);
assert_eq!(config, config_clone);
}

Expand Down
Loading

0 comments on commit ca497ee

Please sign in to comment.