Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 40 additions & 18 deletions common/client-core/src/init/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,34 @@ const MEASUREMENTS: usize = 3;
const CONN_TIMEOUT: Duration = Duration::from_millis(1500);
const PING_TIMEOUT: Duration = Duration::from_millis(1000);

struct GatewayWithLatency<'a> {
gateway: &'a gateway::Node,
// The abstraction that some of these helpers use
pub trait ConnectableGateway {
fn identity(&self) -> &identity::PublicKey;
fn clients_address(&self) -> String;
fn is_wss(&self) -> bool;
}

impl ConnectableGateway for gateway::Node {
fn identity(&self) -> &identity::PublicKey {
self.identity()
}

fn clients_address(&self) -> String {
self.clients_address()
}

fn is_wss(&self) -> bool {
self.clients_wss_port.is_some()
}
}

struct GatewayWithLatency<'a, G: ConnectableGateway> {
gateway: &'a G,
latency: Duration,
}

impl<'a> GatewayWithLatency<'a> {
fn new(gateway: &'a gateway::Node, latency: Duration) -> Self {
impl<'a, G: ConnectableGateway> GatewayWithLatency<'a, G> {
fn new(gateway: &'a G, latency: Duration) -> Self {
GatewayWithLatency { gateway, latency }
}
}
Expand Down Expand Up @@ -130,11 +151,14 @@ async fn connect(endpoint: &str) -> Result<WsConn, ClientCoreError> {
JSWebsocket::new(endpoint).map_err(|_| ClientCoreError::GatewayJsConnectionFailure)
}

async fn measure_latency(gateway: &gateway::Node) -> Result<GatewayWithLatency, ClientCoreError> {
async fn measure_latency<G>(gateway: &G) -> Result<GatewayWithLatency<G>, ClientCoreError>
where
G: ConnectableGateway,
{
let addr = gateway.clients_address();
trace!(
"establishing connection to {} ({addr})...",
gateway.identity_key,
gateway.identity(),
);
let mut stream = connect(&addr).await?;

Expand Down Expand Up @@ -177,7 +201,7 @@ async fn measure_latency(gateway: &gateway::Node) -> Result<GatewayWithLatency,
let count = results.len() as u64;
if count == 0 {
return Err(ClientCoreError::NoGatewayMeasurements {
identity: gateway.identity_key.to_base58_string(),
identity: gateway.identity().to_base58_string(),
});
}

Expand All @@ -187,11 +211,11 @@ async fn measure_latency(gateway: &gateway::Node) -> Result<GatewayWithLatency,
Ok(GatewayWithLatency::new(gateway, avg))
}

pub async fn choose_gateway_by_latency<R: Rng>(
pub async fn choose_gateway_by_latency<'a, R: Rng, G: ConnectableGateway + Clone>(
rng: &mut R,
gateways: &[gateway::Node],
gateways: &[G],
must_use_tls: bool,
) -> Result<gateway::Node, ClientCoreError> {
) -> Result<G, ClientCoreError> {
let gateways = filter_by_tls(gateways, must_use_tls)?;

info!(
Expand Down Expand Up @@ -223,21 +247,19 @@ pub async fn choose_gateway_by_latency<R: Rng>(

info!(
"chose gateway {} with average latency of {:?}",
chosen.gateway.identity_key, chosen.latency
chosen.gateway.identity(),
chosen.latency
);

Ok(chosen.gateway.clone())
}

fn filter_by_tls(
gateways: &[gateway::Node],
fn filter_by_tls<G: ConnectableGateway>(
gateways: &[G],
must_use_tls: bool,
) -> Result<Vec<&gateway::Node>, ClientCoreError> {
) -> Result<Vec<&G>, ClientCoreError> {
if must_use_tls {
let filtered = gateways
.iter()
.filter(|g| g.clients_wss_port.is_some())
.collect::<Vec<_>>();
let filtered = gateways.iter().filter(|g| g.is_wss()).collect::<Vec<_>>();

if filtered.is_empty() {
return Err(ClientCoreError::NoWssGateways);
Expand Down