diff --git a/Cargo.lock b/Cargo.lock index ad525671a6a..03acc93774a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5069,6 +5069,7 @@ dependencies = [ "nym-crypto", "nym-pemstore", "nym-sphinx", + "nym-task", "rand", "serde", "serde_json", diff --git a/common/client-libs/gateway-client/src/client/mod.rs b/common/client-libs/gateway-client/src/client/mod.rs index e544a610b6a..d899ec29285 100644 --- a/common/client-libs/gateway-client/src/client/mod.rs +++ b/common/client-libs/gateway-client/src/client/mod.rs @@ -417,6 +417,8 @@ impl GatewayClient { self.local_identity.as_ref(), self.gateway_identity, self.cfg.bandwidth.require_tickets, + #[cfg(not(target_arch = "wasm32"))] + self.task_client.clone(), ) .await .map_err(GatewayClientError::RegistrationFailure), diff --git a/common/gateway-requests/Cargo.toml b/common/gateway-requests/Cargo.toml index c1dadc47321..72a2bdfe9d9 100644 --- a/common/gateway-requests/Cargo.toml +++ b/common/gateway-requests/Cargo.toml @@ -24,6 +24,7 @@ zeroize = { workspace = true } nym-crypto = { path = "../crypto" } nym-pemstore = { path = "../pemstore" } nym-sphinx = { path = "../nymsphinx" } +nym-task = { path = "../task" } nym-credentials = { path = "../credentials" } nym-credentials-interface = { path = "../credentials-interface" } diff --git a/common/gateway-requests/src/registration/handshake/client.rs b/common/gateway-requests/src/registration/handshake/client.rs index edb30ebf206..902a520bcaf 100644 --- a/common/gateway-requests/src/registration/handshake/client.rs +++ b/common/gateway-requests/src/registration/handshake/client.rs @@ -25,6 +25,7 @@ impl<'a> ClientHandshake<'a> { identity: &'a nym_crypto::asymmetric::identity::KeyPair, gateway_pubkey: identity::PublicKey, expects_credential_usage: bool, + #[cfg(not(target_arch = "wasm32"))] shutdown: nym_task::TaskClient, ) -> Self where S: Stream + Sink + Unpin + Send + 'a, @@ -35,6 +36,8 @@ impl<'a> ClientHandshake<'a> { identity, Some(gateway_pubkey), expects_credential_usage, + #[cfg(not(target_arch = "wasm32"))] + shutdown, ); ClientHandshake { diff --git a/common/gateway-requests/src/registration/handshake/error.rs b/common/gateway-requests/src/registration/handshake/error.rs index 8fee78fd388..6e82ef040cd 100644 --- a/common/gateway-requests/src/registration/handshake/error.rs +++ b/common/gateway-requests/src/registration/handshake/error.rs @@ -25,6 +25,8 @@ pub enum HandshakeError { MalformedRequest, #[error("sent request was malformed")] HandshakeFailure, + #[error("received shutdown")] + ReceivedShutdown, #[error("timed out waiting for a handshake message")] Timeout, diff --git a/common/gateway-requests/src/registration/handshake/gateway.rs b/common/gateway-requests/src/registration/handshake/gateway.rs index bb0003d0a30..a42851a3200 100644 --- a/common/gateway-requests/src/registration/handshake/gateway.rs +++ b/common/gateway-requests/src/registration/handshake/gateway.rs @@ -8,6 +8,7 @@ use futures::future::BoxFuture; use futures::task::{Context, Poll}; use futures::{Future, Sink, Stream}; use nym_crypto::asymmetric::encryption; +use nym_task::TaskClient; use rand::{CryptoRng, RngCore}; use std::pin::Pin; use tungstenite::Message as WsMessage; @@ -22,11 +23,12 @@ impl<'a> GatewayHandshake<'a> { ws_stream: &'a mut S, identity: &'a nym_crypto::asymmetric::identity::KeyPair, received_init_payload: Vec, + shutdown: TaskClient, ) -> Self where S: Stream + Sink + Unpin + Send + 'a, { - let mut state = State::new(rng, ws_stream, identity, None, true); + let mut state = State::new(rng, ws_stream, identity, None, true, shutdown); GatewayHandshake { handshake_future: Box::pin(async move { // If any step along the way failed (that are non-network related), diff --git a/common/gateway-requests/src/registration/handshake/mod.rs b/common/gateway-requests/src/registration/handshake/mod.rs index 4e21fcfc0df..d3faec79037 100644 --- a/common/gateway-requests/src/registration/handshake/mod.rs +++ b/common/gateway-requests/src/registration/handshake/mod.rs @@ -8,6 +8,8 @@ use self::gateway::GatewayHandshake; pub use self::shared_key::{SharedKeySize, SharedKeys}; use futures::{Sink, Stream}; use nym_crypto::asymmetric::identity; +#[cfg(not(target_arch = "wasm32"))] +use nym_task::TaskClient; use rand::{CryptoRng, RngCore}; use tungstenite::{Error as WsError, Message as WsMessage}; @@ -31,6 +33,7 @@ pub async fn client_handshake<'a, S>( identity: &'a identity::KeyPair, gateway_pubkey: identity::PublicKey, expects_credential_usage: bool, + #[cfg(not(target_arch = "wasm32"))] shutdown: TaskClient, ) -> Result where S: Stream + Sink + Unpin + Send + 'a, @@ -41,6 +44,8 @@ where identity, gateway_pubkey, expects_credential_usage, + #[cfg(not(target_arch = "wasm32"))] + shutdown, ) .await } @@ -51,11 +56,12 @@ pub async fn gateway_handshake<'a, S>( ws_stream: &'a mut S, identity: &'a identity::KeyPair, received_init_payload: Vec, + shutdown: TaskClient, ) -> Result where S: Stream + Sink + Unpin + Send + 'a, { - GatewayHandshake::new(rng, ws_stream, identity, received_init_payload).await + GatewayHandshake::new(rng, ws_stream, identity, received_init_payload, shutdown).await } /* diff --git a/common/gateway-requests/src/registration/handshake/state.rs b/common/gateway-requests/src/registration/handshake/state.rs index 328b5d59cb8..70cd48620c2 100644 --- a/common/gateway-requests/src/registration/handshake/state.rs +++ b/common/gateway-requests/src/registration/handshake/state.rs @@ -13,6 +13,8 @@ use nym_crypto::{ symmetric::stream_cipher, }; use nym_sphinx::params::{GatewayEncryptionAlgorithm, GatewaySharedKeyHkdfAlgorithm}; +#[cfg(not(target_arch = "wasm32"))] +use nym_task::TaskClient; use rand::{CryptoRng, RngCore}; use tracing::log::*; @@ -48,6 +50,10 @@ pub(crate) struct State<'a, S> { // this field is really out of place here, however, we need to propagate this information somehow // in order to establish correct protocol for backwards compatibility reasons expects_credential_usage: bool, + + // channel to receive shutdown signal + #[cfg(not(target_arch = "wasm32"))] + shutdown: TaskClient, } impl<'a, S> State<'a, S> { @@ -57,6 +63,7 @@ impl<'a, S> State<'a, S> { identity: &'a identity::KeyPair, remote_pubkey: Option, expects_credential_usage: bool, + #[cfg(not(target_arch = "wasm32"))] shutdown: TaskClient, ) -> Self { let ephemeral_keypair = encryption::KeyPair::new(rng); State { @@ -66,6 +73,8 @@ impl<'a, S> State<'a, S> { remote_pubkey, derived_shared_keys: None, expects_credential_usage, + #[cfg(not(target_arch = "wasm32"))] + shutdown, } } @@ -199,46 +208,76 @@ impl<'a, S> State<'a, S> { self.remote_pubkey = Some(remote_pubkey) } + fn on_wg_msg(msg: Option) -> Result>, HandshakeError> { + let Some(msg) = msg else { + return Err(HandshakeError::ClosedStream); + }; + + let Ok(msg) = msg else { + return Err(HandshakeError::NetworkError); + }; + match msg { + WsMessage::Text(ref ws_msg) => { + match types::RegistrationHandshake::from_str(ws_msg) { + Ok(reg_handshake_msg) => { + match reg_handshake_msg { + // hehe, that's a bit disgusting that the type system requires we explicitly ignore the + // protocol_version field that we actually never attach at this point + // yet another reason for the overdue refactor + types::RegistrationHandshake::HandshakePayload { data, .. } => { + Ok(Some(data)) + } + types::RegistrationHandshake::HandshakeError { message } => { + Err(HandshakeError::RemoteError(message)) + } + } + } + Err(_) => { + error!("Received a non-handshake message during the registration handshake! It's getting dropped. The received content was: '{msg}'"); + Ok(None) + } + } + } + _ => { + error!("Received non-text message during registration handshake"); + Ok(None) + } + } + } + + #[cfg(not(target_arch = "wasm32"))] async fn _receive_handshake_message(&mut self) -> Result, HandshakeError> where S: Stream + Unpin, { loop { - let Some(msg) = self.ws_stream.next().await else { - return Err(HandshakeError::ClosedStream); - }; - - let Ok(msg) = msg else { - return Err(HandshakeError::NetworkError); - }; - - match msg { - WsMessage::Text(ref ws_msg) => { - match types::RegistrationHandshake::from_str(ws_msg) { - Ok(reg_handshake_msg) => { - return match reg_handshake_msg { - // hehe, that's a bit disgusting that the type system requires we explicitly ignore the - // protocol_version field that we actually never attach at this point - // yet another reason for the overdue refactor - types::RegistrationHandshake::HandshakePayload { data, .. } => { - Ok(data) - } - types::RegistrationHandshake::HandshakeError { message } => { - Err(HandshakeError::RemoteError(message)) - } - }; - } - Err(_) => { - error!("Received a non-handshake message during the registration handshake! It's getting dropped. The received content was: '{msg}'"); - continue; - } - } + tokio::select! { + biased; + _ = self.shutdown.recv() => return Err(HandshakeError::ReceivedShutdown), + msg = self.ws_stream.next() => { + let Some(ret) = Self::on_wg_msg(msg)? else { + continue; + }; + return Ok(ret); } - _ => error!("Received non-text message during registration handshake"), } } } + #[cfg(target_arch = "wasm32")] + async fn _receive_handshake_message(&mut self) -> Result, HandshakeError> + where + S: Stream + Unpin, + { + loop { + let msg = self.ws_stream.next().await; + let Some(ret) = Self::on_wg_msg(msg)? else { + continue; + }; + return Ok(ret); + } + } + pub(crate) async fn receive_handshake_message(&mut self) -> Result, HandshakeError> where S: Stream + Unpin, diff --git a/common/task/src/manager.rs b/common/task/src/manager.rs index ab03870dc01..8a4c21a3117 100644 --- a/common/task/src/manager.rs +++ b/common/task/src/manager.rs @@ -3,7 +3,6 @@ use futures::{future::pending, FutureExt, SinkExt, StreamExt}; use log::{log, Level}; -use std::future::Future; use std::sync::atomic::{AtomicBool, Ordering}; use std::{error::Error, time::Duration}; use tokio::sync::{ @@ -368,17 +367,6 @@ impl TaskClient { self.named(name) } - pub async fn run_future(&mut self, fut: Fut) -> Option - where - Fut: Future, - { - tokio::select! { - biased; - _ = self.recv() => None, - res = fut => Some(res) - } - } - // Create a dummy that will never report that we should shutdown. pub fn dummy() -> TaskClient { let (_notify_tx, notify_rx) = watch::channel(()); diff --git a/gateway/src/node/client_handling/websocket/connection_handler/fresh.rs b/gateway/src/node/client_handling/websocket/connection_handler/fresh.rs index 09865328964..fc5f453e257 100644 --- a/gateway/src/node/client_handling/websocket/connection_handler/fresh.rs +++ b/gateway/src/node/client_handling/websocket/connection_handler/fresh.rs @@ -20,6 +20,7 @@ use nym_gateway_requests::{ }; use nym_mixnet_client::forwarder::MixForwardingSender; use nym_sphinx::DestinationAddressBytes; +use nym_task::TaskClient; use rand::{CryptoRng, Rng}; use std::net::SocketAddr; use std::time::Duration; @@ -111,6 +112,9 @@ pub(crate) enum InitialAuthenticationError { #[error("failed to upgrade the client handler: {source}")] HandlerUpgradeFailure { source: RequestHandlingError }, + + #[error("received shutdown")] + ReceivedShutdown, } impl InitialAuthenticationError { @@ -127,6 +131,7 @@ pub(crate) struct FreshHandler { pub(crate) outbound_mix_sender: MixForwardingSender, pub(crate) socket_connection: SocketStream, pub(crate) peer_address: SocketAddr, + pub(crate) shutdown: TaskClient, // currently unused (but populated) pub(crate) negotiated_protocol: Option, @@ -149,6 +154,7 @@ where active_clients_store: ActiveClientsStore, shared_state: CommonHandlerState, peer_address: SocketAddr, + shutdown: TaskClient, ) -> Self { FreshHandler { rng, @@ -158,6 +164,7 @@ where peer_address, negotiated_protocol: None, shared_state, + shutdown, } } @@ -201,6 +208,7 @@ where ws_stream, self.shared_state.local_identity.as_ref(), init_msg, + self.shutdown.clone(), ) .await } @@ -720,111 +728,125 @@ where { trace!("Started waiting for authenticate/register request..."); - while let Some(msg) = self.read_websocket_message().await { - let msg = match msg { - Ok(msg) => msg, - Err(source) => { - debug!("failed to obtain message from websocket stream! stopping connection handler: {source}"); - return Err(InitialAuthenticationError::FailedToReadMessage { source }); - } - }; + let mut shutdown = self.shutdown.clone(); + loop { + tokio::select! { + biased; + _ = shutdown.recv() => { + trace!("received shutdown signal while performing initial authentication"); + return Err(InitialAuthenticationError::ReceivedShutdown); + }, + msg = self.read_websocket_message() => { + let Some(msg) = msg else { + break; + }; - if msg.is_close() { - return Err(InitialAuthenticationError::CloseMessage); - } + let msg = match msg { + Ok(msg) => msg, + Err(source) => { + debug!("failed to obtain message from websocket stream! stopping connection handler: {source}"); + return Err(InitialAuthenticationError::FailedToReadMessage { source }); + } + }; - // ONLY handle 'Authenticate' or 'Register' requests, ignore everything else - match msg { - // we have explicitly checked for close message - Message::Close(_) => unreachable!(), - Message::Text(text_msg) => { - let (mix_sender, mix_receiver) = mpsc::unbounded(); - return match self.handle_initial_authentication_request(text_msg).await { - Err(err) => { - debug!("authentication failure: {err}"); - - // try to send error to the client - if let Err(source) = - self.send_websocket_message(err.to_error_message()).await - { - debug!("Failed to send authentication error response: {source}"); - return Err(InitialAuthenticationError::ErrorResponseSendFailure { - source, - }); - } - // return the underlying error - Err(err) + if msg.is_close() { + return Err(InitialAuthenticationError::CloseMessage); + } + + // ONLY handle 'Authenticate' or 'Register' requests, ignore everything else + match msg { + // we have explicitly checked for close message + Message::Close(_) => unreachable!(), + Message::Text(text_msg) => { + let (mix_sender, mix_receiver) = mpsc::unbounded(); + return match self.handle_initial_authentication_request(text_msg).await { + Err(err) => { + debug!("authentication failure: {err}"); + + // try to send error to the client + if let Err(source) = + self.send_websocket_message(err.to_error_message()).await + { + debug!("Failed to send authentication error response: {source}"); + return Err(InitialAuthenticationError::ErrorResponseSendFailure { + source, + }); + } + // return the underlying error + Err(err) + } + Ok(auth_result) => { + // try to send auth response back to the client + if let Err(source) = self + .send_websocket_message(auth_result.server_response.into()) + .await + { + debug!("Failed to send authentication response: {source}"); + return Err(InitialAuthenticationError::ResponseSendFailure { + source, + }); + } + + if let Some(client_details) = auth_result.client_details { + // Channel for handlers to ask other handlers if they are still active. + let (is_active_request_sender, is_active_request_receiver) = + mpsc::unbounded(); + self.active_clients_store.insert_remote( + client_details.address, + mix_sender, + is_active_request_sender, + ); + AuthenticatedHandler::upgrade( + self, + client_details, + mix_receiver, + is_active_request_receiver, + ) + .await + .map_err(|source| { + InitialAuthenticationError::HandlerUpgradeFailure { source } + }) + } else { + // honestly, it's been so long I don't remember under what conditions its possible (if at all) + // to have empty client details + Err(InitialAuthenticationError::EmptyClientDetails) + } + } + }; } - Ok(auth_result) => { - // try to send auth response back to the client + Message::Binary(_) => { + // perhaps logging level should be reduced here, let's leave it for now and see what happens + // if client is working correctly, this should have never happened + debug!("possibly received a sphinx packet without prior authentication. Request is going to be ignored"); if let Err(source) = self - .send_websocket_message(auth_result.server_response.into()) + .send_websocket_message( + ServerResponse::new_error( + "binary request without prior authentication", + ) + .into(), + ) .await { - debug!("Failed to send authentication response: {source}"); - return Err(InitialAuthenticationError::ResponseSendFailure { + return Err(InitialAuthenticationError::ErrorResponseSendFailure { source, }); } - - if let Some(client_details) = auth_result.client_details { - // Channel for handlers to ask other handlers if they are still active. - let (is_active_request_sender, is_active_request_receiver) = - mpsc::unbounded(); - self.active_clients_store.insert_remote( - client_details.address, - mix_sender, - is_active_request_sender, - ); - AuthenticatedHandler::upgrade( - self, - client_details, - mix_receiver, - is_active_request_receiver, - ) - .await - .map_err(|source| { - InitialAuthenticationError::HandlerUpgradeFailure { source } - }) - } else { - // honestly, it's been so long I don't remember under what conditions its possible (if at all) - // to have empty client details - Err(InitialAuthenticationError::EmptyClientDetails) - } + return Err(InitialAuthenticationError::BinaryRequestWithoutAuthentication); } + + _ => continue, }; } - Message::Binary(_) => { - // perhaps logging level should be reduced here, let's leave it for now and see what happens - // if client is working correctly, this should have never happened - debug!("possibly received a sphinx packet without prior authentication. Request is going to be ignored"); - if let Err(source) = self - .send_websocket_message( - ServerResponse::new_error( - "binary request without prior authentication", - ) - .into(), - ) - .await - { - return Err(InitialAuthenticationError::ErrorResponseSendFailure { - source, - }); - } - return Err(InitialAuthenticationError::BinaryRequestWithoutAuthentication); - } - - _ => continue, - }; + } } Err(InitialAuthenticationError::ClosedConnection) } - pub(crate) async fn start_handling(self, shutdown: nym_task::TaskClient) + pub(crate) async fn start_handling(self) where S: AsyncRead + AsyncWrite + Unpin + Send, { - super::handle_connection(self, shutdown).await + super::handle_connection(self).await } } diff --git a/gateway/src/node/client_handling/websocket/connection_handler/mod.rs b/gateway/src/node/client_handling/websocket/connection_handler/mod.rs index 656ac91c65c..5eca9ad070a 100644 --- a/gateway/src/node/client_handling/websocket/connection_handler/mod.rs +++ b/gateway/src/node/client_handling/websocket/connection_handler/mod.rs @@ -8,7 +8,6 @@ use nym_gateway_requests::registration::handshake::SharedKeys; use nym_gateway_requests::ServerResponse; use nym_gateway_storage::Storage; use nym_sphinx::DestinationAddressBytes; -use nym_task::TaskClient; use rand::{CryptoRng, Rng}; use std::time::Duration; use time::OffsetDateTime; @@ -24,6 +23,8 @@ pub(crate) mod authenticated; pub(crate) mod ecash; mod fresh; +const WEBSOCKET_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(1_500); + // TODO: note for my future self to consider the following idea: // split the socket connection into sink and stream // stream will be for reading explicit requests @@ -87,53 +88,47 @@ impl InitialAuthResult { // imo there's no point in including the peer address in anything higher than debug #[instrument(level = "debug", skip_all, fields(peer = %handle.peer_address))] -pub(crate) async fn handle_connection( - mut handle: FreshHandler, - mut shutdown: TaskClient, -) where +pub(crate) async fn handle_connection(mut handle: FreshHandler) +where R: Rng + CryptoRng, S: AsyncRead + AsyncWrite + Unpin + Send, St: Storage + Clone + 'static, { // If the connection handler abruptly stops, we shouldn't signal global shutdown - shutdown.mark_as_success(); + handle.shutdown.mark_as_success(); - match shutdown - .run_future(handle.perform_websocket_handshake()) - .await + match tokio::time::timeout( + WEBSOCKET_HANDSHAKE_TIMEOUT, + handle.perform_websocket_handshake(), + ) + .await { - None => { - trace!("received shutdown signal while performing websocket handshake"); + Err(timeout_err) => { + warn!("websocket handshake timedout: {timeout_err}"); return; } - Some(Err(err)) => { + Ok(Err(err)) => { warn!("Failed to complete WebSocket handshake: {err}. Stopping the handler"); return; } - _ => (), + _ => {} } trace!("Managed to perform websocket handshake!"); - match shutdown - .run_future(handle.perform_initial_authentication()) - .await - { - None => { - trace!("received shutdown signal while performing initial authentication"); - return; - } + let shutdown = handle.shutdown.clone(); + match handle.perform_initial_authentication().await { // For storage error, we want to print the extended storage error, but without // including it in the error that's returned to the clients - Some(Err(InitialAuthenticationError::StorageError(err))) => { + Err(InitialAuthenticationError::StorageError(err)) => { warn!("authentication has failed: {err}"); return; } - Some(Err(err)) => { + Err(err) => { warn!("authentication has failed: {err}"); return; } - Some(Ok(auth_handle)) => auth_handle.listen_for_requests(shutdown).await, + Ok(auth_handle) => auth_handle.listen_for_requests(shutdown).await, } trace!("The handler is done!"); diff --git a/gateway/src/node/client_handling/websocket/listener.rs b/gateway/src/node/client_handling/websocket/listener.rs index 8c0b98a4401..be3c6c51a28 100644 --- a/gateway/src/node/client_handling/websocket/listener.rs +++ b/gateway/src/node/client_handling/websocket/listener.rs @@ -54,6 +54,7 @@ where connection = tcp_listener.accept() => { match connection { Ok((socket, remote_addr)) => { + let shutdown = shutdown.clone().named(format!("ClientConnectionHandler_{remote_addr}")); trace!("received a socket connection from {remote_addr}"); // TODO: I think we *REALLY* need a mechanism for having a maximum number of connected // clients or spawned tokio tasks -> perhaps a worker system? @@ -64,9 +65,9 @@ where active_clients_store.clone(), self.shared_state.clone(), remote_addr, + shutdown, ); - let shutdown = shutdown.clone().named(format!("ClientConnectionHandler_{remote_addr}")); - tokio::spawn(async move { handle.start_handling(shutdown).await }); + tokio::spawn(handle.start_handling()); } Err(err) => warn!("failed to get client: {err}"), }