Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 90 additions & 30 deletions common/client-libs/gateway-client/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ impl<C, St> GatewayClient<C, St> {
) -> Result<(), GatewayClientError> {
if let Some(shared_key) = self.shared_key() {
let encrypted = message.encrypt(&*shared_key)?;
Box::pin(self.send_websocket_message(encrypted)).await?;
Box::pin(self.send_websocket_message_without_response(encrypted)).await?;
Ok(())
} else {
Err(GatewayClientError::ConnectionInInvalidState)
Expand Down Expand Up @@ -330,9 +330,80 @@ impl<C, St> GatewayClient<C, St> {
}
}

/// Attempt to send a websocket message to the gateway without waiting for any response
async fn send_websocket_message_without_response(
&mut self,
msg: impl Into<Message>,
) -> Result<(), GatewayClientError> {
match self.connection {
SocketState::Available(ref mut conn) => Ok(conn.send(msg.into()).await?),
SocketState::PartiallyDelegated(ref mut partially_delegated) => {
if let Err(err) = partially_delegated.send_without_response(msg.into()).await {
error!("failed to send message without response - {err}...");
// we must ensure we do not leave the task still active
if let Err(err) = self.recover_socket_connection().await {
error!("... and the delegated stream has also errored out - {err}")
}
Err(err)
} else {
Ok(())
}
}
SocketState::NotConnected => Err(GatewayClientError::ConnectionNotEstablished),
_ => Err(GatewayClientError::ConnectionInInvalidState),
}
}

// A very nasty hack due to lack of id tags on messages - send a non-sphinx packet websocket
// message and wait until first non 'Send' response within timeout
pub async fn send_websocket_message_with_non_send_response(
&mut self,
msg: impl Into<Message>,
) -> Result<ServerResponse, GatewayClientError> {
let should_restart_mixnet_listener = if self.connection.is_partially_delegated() {
self.recover_socket_connection().await?;
true
} else {
false
};

let conn = match self.connection {
SocketState::Available(ref mut conn) => conn,
SocketState::NotConnected => return Err(GatewayClientError::ConnectionNotEstablished),
_ => return Err(GatewayClientError::ConnectionInInvalidState),
};
conn.send(msg.into()).await?;

let timeout = sleep(self.cfg.connection.response_timeout_duration);
tokio::pin!(timeout);

let response = loop {
tokio::select! {
_ = &mut timeout => {
break Err(GatewayClientError::Timeout);
}
// note: the below will also listen for shutdown signals
msg = self.read_control_response() => {
match msg {
Ok(res) => if !res.is_send() {
break Ok(res);
},
Err(err) => break Err(err),
}
}
}
};

if should_restart_mixnet_listener {
self.start_listening_for_mixnet_messages()?;
}
response
}

/// Attempt to send a websocket message to the gateway and wait until we receive a response.
// If we want to send a message (with response), we need to have a full control over the socket,
// as we need to be able to write the request and read the subsequent response
pub async fn send_websocket_message(
pub async fn send_websocket_message_with_response(
&mut self,
msg: impl Into<Message>,
) -> Result<ServerResponse, GatewayClientError> {
Expand Down Expand Up @@ -387,29 +458,6 @@ impl<C, St> GatewayClient<C, St> {
}
}

async fn send_websocket_message_without_response(
&mut self,
msg: Message,
) -> Result<(), GatewayClientError> {
match self.connection {
SocketState::Available(ref mut conn) => Ok(conn.send(msg).await?),
SocketState::PartiallyDelegated(ref mut partially_delegated) => {
if let Err(err) = partially_delegated.send_without_response(msg).await {
error!("failed to send message without response - {err}...");
// we must ensure we do not leave the task still active
if let Err(err) = self.recover_socket_connection().await {
error!("... and the delegated stream has also errored out - {err}")
}
Err(err)
} else {
Ok(())
}
}
SocketState::NotConnected => Err(GatewayClientError::ConnectionNotEstablished),
_ => Err(GatewayClientError::ConnectionInInvalidState),
}
}

fn check_gateway_protocol(
&self,
gateway_protocol: Option<u8>,
Expand Down Expand Up @@ -535,7 +583,10 @@ impl<C, St> GatewayClient<C, St> {
.encrypt(legacy_key)?;

info!("sending upgrade request and awaiting the acknowledgement back");
let (ciphertext, nonce) = match self.send_websocket_message(upgrade_request).await? {
let (ciphertext, nonce) = match self
.send_websocket_message_with_response(upgrade_request)
.await?
{
ServerResponse::EncryptedResponse { ciphertext, nonce } => (ciphertext, nonce),
ServerResponse::Error { message } => {
return Err(GatewayClientError::GatewayError(message))
Expand Down Expand Up @@ -567,7 +618,7 @@ impl<C, St> GatewayClient<C, St> {
&mut self,
msg: ClientControlRequest,
) -> Result<(), GatewayClientError> {
match self.send_websocket_message(msg).await? {
match self.send_websocket_message_with_response(msg).await? {
ServerResponse::Authenticate {
protocol_version,
status,
Expand Down Expand Up @@ -717,13 +768,16 @@ impl<C, St> GatewayClient<C, St> {
}
}

/// Attempt to retrieve the currently supported gateway protocol version of the remote.
pub async fn get_gateway_protocol(&mut self) -> Result<u8, GatewayClientError> {
if !self.connection.is_established() {
return Err(GatewayClientError::ConnectionNotEstablished);
}

match self
.send_websocket_message(ClientControlRequest::SupportedProtocol {})
.send_websocket_message_with_non_send_response(
ClientControlRequest::SupportedProtocol {},
)
.await?
{
ServerResponse::SupportedProtocol { version } => Ok(version),
Expand All @@ -740,7 +794,10 @@ impl<C, St> GatewayClient<C, St> {
credential,
self.shared_key.as_ref().unwrap(),
)?;
let bandwidth_remaining = match self.send_websocket_message(msg).await? {
let bandwidth_remaining = match self
.send_websocket_message_with_non_send_response(msg)
.await?
{
ServerResponse::Bandwidth { available_total } => Ok(available_total),
ServerResponse::Error { message } => Err(GatewayClientError::GatewayError(message)),
ServerResponse::TypedError { error } => {
Expand All @@ -758,7 +815,10 @@ impl<C, St> GatewayClient<C, St> {

async fn try_claim_testnet_bandwidth(&mut self) -> Result<(), GatewayClientError> {
let msg = ClientControlRequest::ClaimFreeTestnetBandwidth;
let bandwidth_remaining = match self.send_websocket_message(msg).await? {
let bandwidth_remaining = match self
.send_websocket_message_with_non_send_response(msg)
.await?
{
ServerResponse::Bandwidth { available_total } => Ok(available_total),
ServerResponse::Error { message } => Err(GatewayClientError::GatewayError(message)),
other => Err(GatewayClientError::UnexpectedResponse { name: other.name() }),
Expand Down
4 changes: 4 additions & 0 deletions common/gateway-requests/src/types/text_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ impl ServerResponse {
_ => false,
}
}

pub fn is_send(&self) -> bool {
matches!(self, ServerResponse::Send { .. })
}
}

impl From<ServerResponse> for Message {
Expand Down
Loading