Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions common/client-libs/gateway-client/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,8 @@ impl<C, St> GatewayClient<C, St> {
self.local_identity.as_ref(),
self.gateway_identity,
self.cfg.bandwidth.require_tickets,
#[cfg(not(target_arch = "wasm32"))]
self.task_client.clone(),
)
.await
.map_err(GatewayClientError::RegistrationFailure),
Expand Down
1 change: 1 addition & 0 deletions common/gateway-requests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ zeroize = { workspace = true }
nym-crypto = { path = "../crypto" }
nym-pemstore = { path = "../pemstore" }
nym-sphinx = { path = "../nymsphinx" }
nym-task = { path = "../task" }

nym-credentials = { path = "../credentials" }
nym-credentials-interface = { path = "../credentials-interface" }
Expand Down
3 changes: 3 additions & 0 deletions common/gateway-requests/src/registration/handshake/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ impl<'a> ClientHandshake<'a> {
identity: &'a nym_crypto::asymmetric::identity::KeyPair,
gateway_pubkey: identity::PublicKey,
expects_credential_usage: bool,
#[cfg(not(target_arch = "wasm32"))] shutdown: nym_task::TaskClient,
) -> Self
where
S: Stream<Item = WsItem> + Sink<WsMessage> + Unpin + Send + 'a,
Expand All @@ -35,6 +36,8 @@ impl<'a> ClientHandshake<'a> {
identity,
Some(gateway_pubkey),
expects_credential_usage,
#[cfg(not(target_arch = "wasm32"))]
shutdown,
);

ClientHandshake {
Expand Down
2 changes: 2 additions & 0 deletions common/gateway-requests/src/registration/handshake/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ pub enum HandshakeError {
MalformedRequest,
#[error("sent request was malformed")]
HandshakeFailure,
#[error("received shutdown")]
ReceivedShutdown,

#[error("timed out waiting for a handshake message")]
Timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use futures::future::BoxFuture;
use futures::task::{Context, Poll};
use futures::{Future, Sink, Stream};
use nym_crypto::asymmetric::encryption;
use nym_task::TaskClient;
use rand::{CryptoRng, RngCore};
use std::pin::Pin;
use tungstenite::Message as WsMessage;
Expand All @@ -22,11 +23,12 @@ impl<'a> GatewayHandshake<'a> {
ws_stream: &'a mut S,
identity: &'a nym_crypto::asymmetric::identity::KeyPair,
received_init_payload: Vec<u8>,
shutdown: TaskClient,
) -> Self
where
S: Stream<Item = WsItem> + Sink<WsMessage> + Unpin + Send + 'a,
{
let mut state = State::new(rng, ws_stream, identity, None, true);
let mut state = State::new(rng, ws_stream, identity, None, true, shutdown);
GatewayHandshake {
handshake_future: Box::pin(async move {
// If any step along the way failed (that are non-network related),
Expand Down
8 changes: 7 additions & 1 deletion common/gateway-requests/src/registration/handshake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use self::gateway::GatewayHandshake;
pub use self::shared_key::{SharedKeySize, SharedKeys};
use futures::{Sink, Stream};
use nym_crypto::asymmetric::identity;
#[cfg(not(target_arch = "wasm32"))]
use nym_task::TaskClient;
use rand::{CryptoRng, RngCore};
use tungstenite::{Error as WsError, Message as WsMessage};

Expand All @@ -31,6 +33,7 @@ pub async fn client_handshake<'a, S>(
identity: &'a identity::KeyPair,
gateway_pubkey: identity::PublicKey,
expects_credential_usage: bool,
#[cfg(not(target_arch = "wasm32"))] shutdown: TaskClient,
) -> Result<SharedKeys, HandshakeError>
where
S: Stream<Item = WsItem> + Sink<WsMessage> + Unpin + Send + 'a,
Expand All @@ -41,6 +44,8 @@ where
identity,
gateway_pubkey,
expects_credential_usage,
#[cfg(not(target_arch = "wasm32"))]
shutdown,
)
.await
}
Expand All @@ -51,11 +56,12 @@ pub async fn gateway_handshake<'a, S>(
ws_stream: &'a mut S,
identity: &'a identity::KeyPair,
received_init_payload: Vec<u8>,
shutdown: TaskClient,
) -> Result<SharedKeys, HandshakeError>
where
S: Stream<Item = WsItem> + Sink<WsMessage> + Unpin + Send + 'a,
{
GatewayHandshake::new(rng, ws_stream, identity, received_init_payload).await
GatewayHandshake::new(rng, ws_stream, identity, received_init_payload, shutdown).await
}

/*
Expand Down
99 changes: 69 additions & 30 deletions common/gateway-requests/src/registration/handshake/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use nym_crypto::{
symmetric::stream_cipher,
};
use nym_sphinx::params::{GatewayEncryptionAlgorithm, GatewaySharedKeyHkdfAlgorithm};
#[cfg(not(target_arch = "wasm32"))]
use nym_task::TaskClient;
use rand::{CryptoRng, RngCore};
use tracing::log::*;

Expand Down Expand Up @@ -48,6 +50,10 @@ pub(crate) struct State<'a, S> {
// this field is really out of place here, however, we need to propagate this information somehow
// in order to establish correct protocol for backwards compatibility reasons
expects_credential_usage: bool,

// channel to receive shutdown signal
#[cfg(not(target_arch = "wasm32"))]
shutdown: TaskClient,
}

impl<'a, S> State<'a, S> {
Expand All @@ -57,6 +63,7 @@ impl<'a, S> State<'a, S> {
identity: &'a identity::KeyPair,
remote_pubkey: Option<identity::PublicKey>,
expects_credential_usage: bool,
#[cfg(not(target_arch = "wasm32"))] shutdown: TaskClient,
) -> Self {
let ephemeral_keypair = encryption::KeyPair::new(rng);
State {
Expand All @@ -66,6 +73,8 @@ impl<'a, S> State<'a, S> {
remote_pubkey,
derived_shared_keys: None,
expects_credential_usage,
#[cfg(not(target_arch = "wasm32"))]
shutdown,
}
}

Expand Down Expand Up @@ -199,46 +208,76 @@ impl<'a, S> State<'a, S> {
self.remote_pubkey = Some(remote_pubkey)
}

fn on_wg_msg(msg: Option<WsItem>) -> Result<Option<Vec<u8>>, HandshakeError> {
let Some(msg) = msg else {
return Err(HandshakeError::ClosedStream);
};

let Ok(msg) = msg else {
return Err(HandshakeError::NetworkError);
};
match msg {
WsMessage::Text(ref ws_msg) => {
match types::RegistrationHandshake::from_str(ws_msg) {
Ok(reg_handshake_msg) => {
match reg_handshake_msg {
// hehe, that's a bit disgusting that the type system requires we explicitly ignore the
// protocol_version field that we actually never attach at this point
// yet another reason for the overdue refactor
types::RegistrationHandshake::HandshakePayload { data, .. } => {
Ok(Some(data))
}
types::RegistrationHandshake::HandshakeError { message } => {
Err(HandshakeError::RemoteError(message))
}
}
}
Err(_) => {
error!("Received a non-handshake message during the registration handshake! It's getting dropped. The received content was: '{msg}'");
Ok(None)
}
}
}
_ => {
error!("Received non-text message during registration handshake");
Ok(None)
}
}
}

#[cfg(not(target_arch = "wasm32"))]
async fn _receive_handshake_message(&mut self) -> Result<Vec<u8>, HandshakeError>
where
S: Stream<Item = WsItem> + Unpin,
{
loop {
let Some(msg) = self.ws_stream.next().await else {
return Err(HandshakeError::ClosedStream);
};

let Ok(msg) = msg else {
return Err(HandshakeError::NetworkError);
};

match msg {
WsMessage::Text(ref ws_msg) => {
match types::RegistrationHandshake::from_str(ws_msg) {
Ok(reg_handshake_msg) => {
return match reg_handshake_msg {
// hehe, that's a bit disgusting that the type system requires we explicitly ignore the
// protocol_version field that we actually never attach at this point
// yet another reason for the overdue refactor
types::RegistrationHandshake::HandshakePayload { data, .. } => {
Ok(data)
}
types::RegistrationHandshake::HandshakeError { message } => {
Err(HandshakeError::RemoteError(message))
}
};
}
Err(_) => {
error!("Received a non-handshake message during the registration handshake! It's getting dropped. The received content was: '{msg}'");
continue;
}
}
tokio::select! {
biased;
_ = self.shutdown.recv() => return Err(HandshakeError::ReceivedShutdown),
msg = self.ws_stream.next() => {
let Some(ret) = Self::on_wg_msg(msg)? else {
continue;
};
return Ok(ret);
}
_ => error!("Received non-text message during registration handshake"),
}
}
}

#[cfg(target_arch = "wasm32")]
async fn _receive_handshake_message(&mut self) -> Result<Vec<u8>, HandshakeError>
where
S: Stream<Item = WsItem> + Unpin,
{
loop {
let msg = self.ws_stream.next().await;
let Some(ret) = Self::on_wg_msg(msg)? else {
continue;
};
return Ok(ret);
}
}

pub(crate) async fn receive_handshake_message(&mut self) -> Result<Vec<u8>, HandshakeError>
where
S: Stream<Item = WsItem> + Unpin,
Expand Down
12 changes: 0 additions & 12 deletions common/task/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

use futures::{future::pending, FutureExt, SinkExt, StreamExt};
use log::{log, Level};
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::{error::Error, time::Duration};
use tokio::sync::{
Expand Down Expand Up @@ -368,17 +367,6 @@ impl TaskClient {
self.named(name)
}

pub async fn run_future<Fut, T>(&mut self, fut: Fut) -> Option<T>
where
Fut: Future<Output = T>,
{
tokio::select! {
biased;
_ = self.recv() => None,
res = fut => Some(res)
}
}

// Create a dummy that will never report that we should shutdown.
pub fn dummy() -> TaskClient {
let (_notify_tx, notify_rx) = watch::channel(());
Expand Down
Loading