diff --git a/common/client-core/src/init/helpers.rs b/common/client-core/src/init/helpers.rs index f0fbe7c23f1..7db4e3a01ec 100644 --- a/common/client-core/src/init/helpers.rs +++ b/common/client-core/src/init/helpers.rs @@ -46,13 +46,34 @@ const MEASUREMENTS: usize = 3; const CONN_TIMEOUT: Duration = Duration::from_millis(1500); const PING_TIMEOUT: Duration = Duration::from_millis(1000); -struct GatewayWithLatency<'a> { - gateway: &'a gateway::Node, +// The abstraction that some of these helpers use +pub trait ConnectableGateway { + fn identity(&self) -> &identity::PublicKey; + fn clients_address(&self) -> String; + fn is_wss(&self) -> bool; +} + +impl ConnectableGateway for gateway::Node { + fn identity(&self) -> &identity::PublicKey { + self.identity() + } + + fn clients_address(&self) -> String { + self.clients_address() + } + + fn is_wss(&self) -> bool { + self.clients_wss_port.is_some() + } +} + +struct GatewayWithLatency<'a, G: ConnectableGateway> { + gateway: &'a G, latency: Duration, } -impl<'a> GatewayWithLatency<'a> { - fn new(gateway: &'a gateway::Node, latency: Duration) -> Self { +impl<'a, G: ConnectableGateway> GatewayWithLatency<'a, G> { + fn new(gateway: &'a G, latency: Duration) -> Self { GatewayWithLatency { gateway, latency } } } @@ -130,11 +151,14 @@ async fn connect(endpoint: &str) -> Result { JSWebsocket::new(endpoint).map_err(|_| ClientCoreError::GatewayJsConnectionFailure) } -async fn measure_latency(gateway: &gateway::Node) -> Result { +async fn measure_latency(gateway: &G) -> Result, ClientCoreError> +where + G: ConnectableGateway, +{ let addr = gateway.clients_address(); trace!( "establishing connection to {} ({addr})...", - gateway.identity_key, + gateway.identity(), ); let mut stream = connect(&addr).await?; @@ -177,7 +201,7 @@ async fn measure_latency(gateway: &gateway::Node) -> Result Result( +pub async fn choose_gateway_by_latency<'a, R: Rng, G: ConnectableGateway + Clone>( rng: &mut R, - gateways: &[gateway::Node], + gateways: &[G], must_use_tls: bool, -) -> Result { +) -> Result { let gateways = filter_by_tls(gateways, must_use_tls)?; info!( @@ -223,21 +247,19 @@ pub async fn choose_gateway_by_latency( info!( "chose gateway {} with average latency of {:?}", - chosen.gateway.identity_key, chosen.latency + chosen.gateway.identity(), + chosen.latency ); Ok(chosen.gateway.clone()) } -fn filter_by_tls( - gateways: &[gateway::Node], +fn filter_by_tls( + gateways: &[G], must_use_tls: bool, -) -> Result, ClientCoreError> { +) -> Result, ClientCoreError> { if must_use_tls { - let filtered = gateways - .iter() - .filter(|g| g.clients_wss_port.is_some()) - .collect::>(); + let filtered = gateways.iter().filter(|g| g.is_wss()).collect::>(); if filtered.is_empty() { return Err(ClientCoreError::NoWssGateways);