diff --git a/common/client-libs/gateway-client/src/client/mod.rs b/common/client-libs/gateway-client/src/client/mod.rs index eeadec5ea13..612349315f4 100644 --- a/common/client-libs/gateway-client/src/client/mod.rs +++ b/common/client-libs/gateway-client/src/client/mod.rs @@ -272,7 +272,7 @@ impl GatewayClient { ) -> Result<(), GatewayClientError> { if let Some(shared_key) = self.shared_key() { let encrypted = message.encrypt(&*shared_key)?; - Box::pin(self.send_websocket_message(encrypted)).await?; + Box::pin(self.send_websocket_message_without_response(encrypted)).await?; Ok(()) } else { Err(GatewayClientError::ConnectionInInvalidState) @@ -330,9 +330,80 @@ impl GatewayClient { } } + /// Attempt to send a websocket message to the gateway without waiting for any response + async fn send_websocket_message_without_response( + &mut self, + msg: impl Into, + ) -> Result<(), GatewayClientError> { + match self.connection { + SocketState::Available(ref mut conn) => Ok(conn.send(msg.into()).await?), + SocketState::PartiallyDelegated(ref mut partially_delegated) => { + if let Err(err) = partially_delegated.send_without_response(msg.into()).await { + error!("failed to send message without response - {err}..."); + // we must ensure we do not leave the task still active + if let Err(err) = self.recover_socket_connection().await { + error!("... and the delegated stream has also errored out - {err}") + } + Err(err) + } else { + Ok(()) + } + } + SocketState::NotConnected => Err(GatewayClientError::ConnectionNotEstablished), + _ => Err(GatewayClientError::ConnectionInInvalidState), + } + } + + // A very nasty hack due to lack of id tags on messages - send a non-sphinx packet websocket + // message and wait until first non 'Send' response within timeout + pub async fn send_websocket_message_with_non_send_response( + &mut self, + msg: impl Into, + ) -> Result { + let should_restart_mixnet_listener = if self.connection.is_partially_delegated() { + self.recover_socket_connection().await?; + true + } else { + false + }; + + let conn = match self.connection { + SocketState::Available(ref mut conn) => conn, + SocketState::NotConnected => return Err(GatewayClientError::ConnectionNotEstablished), + _ => return Err(GatewayClientError::ConnectionInInvalidState), + }; + conn.send(msg.into()).await?; + + let timeout = sleep(self.cfg.connection.response_timeout_duration); + tokio::pin!(timeout); + + let response = loop { + tokio::select! { + _ = &mut timeout => { + break Err(GatewayClientError::Timeout); + } + // note: the below will also listen for shutdown signals + msg = self.read_control_response() => { + match msg { + Ok(res) => if !res.is_send() { + break Ok(res); + }, + Err(err) => break Err(err), + } + } + } + }; + + if should_restart_mixnet_listener { + self.start_listening_for_mixnet_messages()?; + } + response + } + + /// Attempt to send a websocket message to the gateway and wait until we receive a response. // If we want to send a message (with response), we need to have a full control over the socket, // as we need to be able to write the request and read the subsequent response - pub async fn send_websocket_message( + pub async fn send_websocket_message_with_response( &mut self, msg: impl Into, ) -> Result { @@ -387,29 +458,6 @@ impl GatewayClient { } } - async fn send_websocket_message_without_response( - &mut self, - msg: Message, - ) -> Result<(), GatewayClientError> { - match self.connection { - SocketState::Available(ref mut conn) => Ok(conn.send(msg).await?), - SocketState::PartiallyDelegated(ref mut partially_delegated) => { - if let Err(err) = partially_delegated.send_without_response(msg).await { - error!("failed to send message without response - {err}..."); - // we must ensure we do not leave the task still active - if let Err(err) = self.recover_socket_connection().await { - error!("... and the delegated stream has also errored out - {err}") - } - Err(err) - } else { - Ok(()) - } - } - SocketState::NotConnected => Err(GatewayClientError::ConnectionNotEstablished), - _ => Err(GatewayClientError::ConnectionInInvalidState), - } - } - fn check_gateway_protocol( &self, gateway_protocol: Option, @@ -535,7 +583,10 @@ impl GatewayClient { .encrypt(legacy_key)?; info!("sending upgrade request and awaiting the acknowledgement back"); - let (ciphertext, nonce) = match self.send_websocket_message(upgrade_request).await? { + let (ciphertext, nonce) = match self + .send_websocket_message_with_response(upgrade_request) + .await? + { ServerResponse::EncryptedResponse { ciphertext, nonce } => (ciphertext, nonce), ServerResponse::Error { message } => { return Err(GatewayClientError::GatewayError(message)) @@ -567,7 +618,7 @@ impl GatewayClient { &mut self, msg: ClientControlRequest, ) -> Result<(), GatewayClientError> { - match self.send_websocket_message(msg).await? { + match self.send_websocket_message_with_response(msg).await? { ServerResponse::Authenticate { protocol_version, status, @@ -717,13 +768,16 @@ impl GatewayClient { } } + /// Attempt to retrieve the currently supported gateway protocol version of the remote. pub async fn get_gateway_protocol(&mut self) -> Result { if !self.connection.is_established() { return Err(GatewayClientError::ConnectionNotEstablished); } match self - .send_websocket_message(ClientControlRequest::SupportedProtocol {}) + .send_websocket_message_with_non_send_response( + ClientControlRequest::SupportedProtocol {}, + ) .await? { ServerResponse::SupportedProtocol { version } => Ok(version), @@ -740,7 +794,10 @@ impl GatewayClient { credential, self.shared_key.as_ref().unwrap(), )?; - let bandwidth_remaining = match self.send_websocket_message(msg).await? { + let bandwidth_remaining = match self + .send_websocket_message_with_non_send_response(msg) + .await? + { ServerResponse::Bandwidth { available_total } => Ok(available_total), ServerResponse::Error { message } => Err(GatewayClientError::GatewayError(message)), ServerResponse::TypedError { error } => { @@ -758,7 +815,10 @@ impl GatewayClient { async fn try_claim_testnet_bandwidth(&mut self) -> Result<(), GatewayClientError> { let msg = ClientControlRequest::ClaimFreeTestnetBandwidth; - let bandwidth_remaining = match self.send_websocket_message(msg).await? { + let bandwidth_remaining = match self + .send_websocket_message_with_non_send_response(msg) + .await? + { ServerResponse::Bandwidth { available_total } => Ok(available_total), ServerResponse::Error { message } => Err(GatewayClientError::GatewayError(message)), other => Err(GatewayClientError::UnexpectedResponse { name: other.name() }), diff --git a/common/gateway-requests/src/types/text_response.rs b/common/gateway-requests/src/types/text_response.rs index d60f6adcc6a..c3418649cfd 100644 --- a/common/gateway-requests/src/types/text_response.rs +++ b/common/gateway-requests/src/types/text_response.rs @@ -112,6 +112,10 @@ impl ServerResponse { _ => false, } } + + pub fn is_send(&self) -> bool { + matches!(self, ServerResponse::Send { .. }) + } } impl From for Message {