Skip to content

Commit

Permalink
added user_id as an option parameter to getappointments cli command
Browse files Browse the repository at this point in the history
Signed-off-by: aruokhai <[email protected]>

removed todo

Signed-off-by: aruokhai <[email protected]>

added broader unit test
  • Loading branch information
aruokhai committed Apr 4, 2024
1 parent 9eb02c9 commit 4dfb470
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 53 deletions.
2 changes: 1 addition & 1 deletion teos/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::configure()
.extern_path(".common.teos.v2", "::teos-common::protos")
.type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]")
.field_attribute("user_id", "#[serde(with = \"hex::serde\")]")
.field_attribute("GetUserRequest.user_id", "#[serde(with = \"hex::serde\")]")
.field_attribute("tower_id", "#[serde(with = \"hex::serde\")]")
.field_attribute(
"user_ids",
Expand Down
3 changes: 2 additions & 1 deletion teos/proto/teos/v2/appointment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ package teos.v2;
import "common/teos/v2/appointment.proto";

message GetAppointmentsRequest {
// Request the information of appointments with specific locator.
// Request the information of appointments with specific locator and user_id (optional) .

bytes locator = 1;
optional bytes user_id = 2;
}

message GetAppointmentsResponse {
Expand Down
49 changes: 34 additions & 15 deletions teos/src/api/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::sync::{Arc, Condvar, Mutex};
use tonic::{Code, Request, Response, Status};
use triggered::Trigger;

use crate::extended_appointment::UUID;
use crate::protos as msgs;
use crate::protos::private_tower_services_server::PrivateTowerServices;
use crate::protos::public_tower_services_server::PublicTowerServices;
Expand Down Expand Up @@ -280,31 +279,44 @@ impl PrivateTowerServices for Arc<InternalAPI> {
.map_or("an unknown address".to_owned(), |a| a.to_string())
);

let mut matching_appointments = vec![];
let locator = Locator::from_slice(&request.into_inner().locator).map_err(|_| {
let req_data = request.into_inner();
let locator = Locator::from_slice(&req_data.locator).map_err(|_| {
Status::new(
Code::InvalidArgument,
"The provided locator does not match the expected format (16-byte hexadecimal string)",
)
})?;

for (_, appointment) in self
let user_id = req_data
.user_id
.map(|id| UserId::from_slice(&id))
.transpose()
.map_err(|_| {
Status::new(
Code::InvalidArgument,
"The Provided user_id does not match expected format (33-byte hex string)",
)
})?;

let appointments: Vec<Appointment> = self
.watcher
.get_watcher_appointments_with_locator(locator)
.get_watcher_appointments_with_locator(locator, user_id)
.into_values()
.map(|appointment| appointment.inner)
.collect();

let mut matching_appointments: Vec<common_msgs::AppointmentData> = appointments
.into_iter()
{
matching_appointments.push(common_msgs::AppointmentData {
.map(|appointment| common_msgs::AppointmentData {
appointment_data: Some(
common_msgs::appointment_data::AppointmentData::Appointment(
appointment.inner.into(),
),
common_msgs::appointment_data::AppointmentData::Appointment(appointment.into()),
),
})
}
.collect();

for (_, tracker) in self
.watcher
.get_responder_trackers_with_locator(locator)
.get_responder_trackers_with_locator(locator, user_id)
.into_iter()
{
matching_appointments.push(common_msgs::AppointmentData {
Expand Down Expand Up @@ -390,7 +402,6 @@ impl PrivateTowerServices for Arc<InternalAPI> {
Some((info, locators)) => Ok(Response::new(msgs::GetUserResponse {
available_slots: info.available_slots,
subscription_expiry: info.subscription_expiry,
// TODO: Should make `get_appointments` queryable using the (user_id, locator) pair for consistency.
appointments: locators
.into_iter()
.map(|locator| locator.to_vec())
Expand Down Expand Up @@ -511,7 +522,10 @@ mod tests_private_api {

let locator = Locator::new(get_random_tx().txid()).to_vec();
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest { locator }))
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator,
user_id: None,
}))
.await
.unwrap()
.into_inner();
Expand Down Expand Up @@ -548,6 +562,7 @@ mod tests_private_api {
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
user_id: None,
}))
.await
.unwrap()
Expand Down Expand Up @@ -599,6 +614,7 @@ mod tests_private_api {
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
user_id: None,
}))
.await
.unwrap()
Expand Down Expand Up @@ -747,7 +763,10 @@ mod tests_private_api {

assert_eq!(response.available_slots, SLOTS - 1);
assert_eq!(response.subscription_expiry, START_HEIGHT as u32 + DURATION);
assert_eq!(response.appointments, Vec::from([appointment.inner.locator.to_vec()]));
assert_eq!(
response.appointments,
Vec::from([appointment.inner.locator.to_vec()])
);
}

#[tokio::test]
Expand Down
32 changes: 20 additions & 12 deletions teos/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,28 @@ async fn main() {
println!("{}", pretty_json(&appointments.into_inner()).unwrap());
}
Command::GetAppointments(appointments_data) => {
match Locator::from_hex(&appointments_data.locator) {
Ok(locator) => {
match client
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
}))
.await
{
Ok(appointments) => {
println!("{}", pretty_json(&appointments.into_inner()).unwrap())
match appointments_data
.user_id
.map(|id| UserId::from_str(&id).map(|user_id| user_id.to_vec()))
.transpose()
{
Ok(user_id) => match Locator::from_hex(&appointments_data.locator) {
Ok(locator) => {
match client
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
user_id,
}))
.await
{
Ok(appointments) => {
println!("{}", pretty_json(&appointments.into_inner()).unwrap())
}
Err(status) => handle_error(status.message()),
}
Err(status) => handle_error(status.message()),
}
}
Err(e) => handle_error(e),
},
Err(e) => handle_error(e),
};
}
Expand Down
2 changes: 2 additions & 0 deletions teos/src/cli_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub struct GetUserData {
pub struct GetAppointmentsData {
/// The locator of the appointments (16-byte hexadecimal string).
pub locator: String,
/// The user identifier (33-byte compressed public key).
pub user_id: Option<String>,
}

/// Holds all the command line options and commands.
Expand Down
155 changes: 135 additions & 20 deletions teos/src/dbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,23 +330,30 @@ impl DBM {
/// matching this locator. If no locator is given, all the appointments in the database would be returned.
pub(crate) fn load_appointments(
&self,
locator: Option<Locator>,
locator_and_userid: Option<(Locator, Option<UserId>)>,
) -> HashMap<UUID, ExtendedAppointment> {
let mut appointments = HashMap::new();

let mut sql =
"SELECT a.UUID, a.locator, a.encrypted_blob, a.to_self_delay, a.user_signature, a.start_block, a.user_id
FROM appointments as a LEFT JOIN trackers as t ON a.UUID=t.UUID WHERE t.UUID IS NULL".to_string();
// If a locator was passed, filter based on it.
if locator.is_some() {
sql.push_str(" AND a.locator=(?)");
}

// If a locator and an optional user_id were passed, filter based on it.
if let Some((_, user_id)) = locator_and_userid {
sql.push_str(" AND a.locator=(?1)");
if user_id.is_some() {
sql.push_str(" AND a.user_id=(?2)");
}
};

let mut stmt = self.connection.prepare(&sql).unwrap();

let mut rows = if let Some(locator) = locator {
stmt.query([locator.to_vec()]).unwrap()
} else {
stmt.query([]).unwrap()
let mut rows = match locator_and_userid {
Some((locator, None)) => stmt.query([locator.to_vec()]).unwrap(),
Some((locator, Some(user_id))) => {
stmt.query([locator.to_vec(), user_id.to_vec()]).unwrap()
}
_ => stmt.query([]).unwrap(),
};

while let Ok(Some(row)) = rows.next() {
Expand Down Expand Up @@ -596,23 +603,30 @@ impl DBM {
/// matching this locator. If no locator is given, all the trackers in the database would be returned.
pub(crate) fn load_trackers(
&self,
locator: Option<Locator>,
locator_and_userid: Option<(Locator, Option<UserId>)>,
) -> HashMap<UUID, TransactionTracker> {
let mut trackers = HashMap::new();

let mut sql = "SELECT t.UUID, t.dispute_tx, t.penalty_tx, t.height, t.confirmed, a.user_id
FROM trackers as t INNER JOIN appointments as a ON t.UUID=a.UUID"
.to_string();
// If a locator was passed, filter based on it.
if locator.is_some() {
sql.push_str(" WHERE a.locator=(?)");

// If a locator and an optional user_id were passed, filter based on it.
if let Some((_, user_id)) = locator_and_userid {
sql.push_str(" AND a.locator=(?1)");
if user_id.is_some() {
sql.push_str(" AND a.user_id=(?2)");
}
}

let mut stmt = self.connection.prepare(&sql).unwrap();

let mut rows = if let Some(locator) = locator {
stmt.query([locator.to_vec()]).unwrap()
} else {
stmt.query([]).unwrap()
let mut rows = match locator_and_userid {
Some((locator, None)) => stmt.query([locator.to_vec()]).unwrap(),
Some((locator, Some(user_id))) => {
stmt.query([locator.to_vec(), user_id.to_vec()]).unwrap()
}
_ => stmt.query([]).unwrap(),
};

while let Ok(Some(row)) = rows.next() {
Expand Down Expand Up @@ -1157,7 +1171,7 @@ mod tests {
}

// Validate that no other appointments than the ones with our locator are returned.
assert_eq!(dbm.load_appointments(Some(locator)), appointments);
assert_eq!(dbm.load_appointments(Some((locator, None))), appointments);

// If an appointment has an associated tracker, it should not be loaded since it is seen
// as a triggered appointment
Expand All @@ -1175,7 +1189,61 @@ mod tests {
dbm.store_tracker(uuid, &tracker).unwrap();

// We should get all the appointments matching our locator back except from the triggered one
assert_eq!(dbm.load_appointments(Some(locator)), appointments);
assert_eq!(dbm.load_appointments(Some((locator, None))), appointments);
}

#[test]
fn test_load_appointments_with_locator_and_user_id() {
let dbm = DBM::in_memory().unwrap();

let mut appointments = HashMap::new();
let dispute_tx = get_random_tx();
let dispute_txid = dispute_tx.txid();
let locator = Locator::new(dispute_txid);

// create user id
let user_id = get_random_user_id();
let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY);
dbm.store_user(user_id, &user).unwrap();

// Create and store a particular appointment
let (uuid, appointment) =
generate_dummy_appointment_with_user(user_id, Some(&dispute_txid));
dbm.store_appointment(uuid, &appointment).unwrap();
appointments.insert(uuid, appointment.clone());

// create random appointments
for _ in 1..11 {
let (uuid, appointment) = generate_dummy_appointment_with_user(user_id, None);
dbm.store_appointment(uuid, &appointment).unwrap();
appointments.insert(uuid, appointment);
}

// Returns empty if no appointment matches both userid and locator
assert_eq!(
dbm.load_appointments(Some((locator, Some(get_random_user_id()))),),
HashMap::new()
);
assert_eq!(
dbm.load_appointments(Some((get_random_locator(), Some(user_id))),),
HashMap::new()
);

// Returns particular appointments if they match both userid and locator
assert_eq!(
dbm.load_appointments(Some((locator, Some(user_id))),),
HashMap::from([(uuid, appointment)])
);

// Create a tracker from existing appointment
let tracker = get_random_tracker(user_id, ConfirmationStatus::InMempoolSince(100));
dbm.store_tracker(uuid, &tracker).unwrap();

// ensure that no tracker is returned
assert_eq!(
dbm.load_appointments(Some((locator, Some(user_id))),),
HashMap::new()
);
}

#[test]
Expand Down Expand Up @@ -1544,7 +1612,54 @@ mod tests {
dbm.store_tracker(uuid, &tracker).unwrap();
}

assert_eq!(dbm.load_trackers(Some(locator)), trackers);
assert_eq!(dbm.load_trackers(Some((locator, None))), trackers);
}

#[test]
fn test_load_trackers_with_locator_and_user_id() {
let dbm = DBM::in_memory().unwrap();
let mut trackers = HashMap::new();
let dispute_tx = get_random_tx();
let dispute_txid = dispute_tx.txid();
let locator = Locator::new(dispute_txid);
let status = ConfirmationStatus::InMempoolSince(42);

// create user id
let user_id = get_random_user_id();
let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY);
dbm.store_user(user_id, &user).unwrap();

// Create and store a particular tracker
let (uuid, appointment) =
generate_dummy_appointment_with_user(user_id, Some(&dispute_txid));
let tracker = get_random_tracker(user_id, status);
dbm.store_appointment(uuid, &appointment).unwrap();
dbm.store_tracker(uuid, &tracker).unwrap();
trackers.insert(uuid, tracker.clone());

// create random trackers
for _ in 1..11 {
let (uuid, appointment) = generate_dummy_appointment_with_user(user_id, None);
let tracker = get_random_tracker(user_id, status);
dbm.store_appointment(uuid, &appointment).unwrap();
dbm.store_tracker(uuid, &tracker).unwrap();
}

// Returns empty if no tracker matches both userid and locator
assert_eq!(
dbm.load_trackers(Some((locator, Some(get_random_user_id()))),),
HashMap::new()
);
assert_eq!(
dbm.load_trackers(Some((get_random_locator(), Some(user_id))),),
HashMap::new()
);

// Returns particular trackers if they match both userid and locator
assert_eq!(
dbm.load_trackers(Some((locator, Some(user_id))),),
HashMap::from([(uuid, tracker)])
);
}

#[test]
Expand Down
Loading

0 comments on commit 4dfb470

Please sign in to comment.