Skip to content

Commit

Permalink
make requiring protocol extensions easy
Browse files Browse the repository at this point in the history
  • Loading branch information
r58Playz committed Apr 21, 2024
1 parent 063b527 commit 01d7ac5
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 90 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion client/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,11 @@ pub async fn make_mux(
let (wtx, wrx) =
WebSocketWrapper::connect(url, vec![]).map_err(|_| WispError::WsImplSocketClosed)?;
wtx.wait_for_open().await;
ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await
Ok(
ClientMux::create(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())]))
.await?
.with_no_required_extensions(),
)
}

pub fn spawn_mux_fut(
Expand Down
50 changes: 17 additions & 33 deletions server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,45 +253,39 @@ async fn accept_http(
}
}

async fn handle_mux(packet: ConnectPacket, stream: MuxStream) -> Result<bool, WispError> {
async fn handle_mux(
packet: ConnectPacket,
stream: MuxStream,
) -> Result<bool, Box<dyn std::error::Error + Sync + Send>> {
let uri = format!(
"{}:{}",
packet.destination_hostname, packet.destination_port
);
match packet.stream_type {
StreamType::Tcp => {
let mut tcp_stream = TcpStream::connect(uri)
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
let mut tcp_stream = TcpStream::connect(uri).await?;
let mut mux_stream = stream.into_io().into_asyncrw();
copy_bidirectional(&mut mux_stream, &mut tcp_stream)
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
copy_bidirectional(&mut mux_stream, &mut tcp_stream).await?;
}
StreamType::Udp => {
let uri = lookup_host(uri)
.await
.map_err(|x| WispError::Other(Box::new(x)))?
.await?
.next()
.ok_or(WispError::InvalidUri)?;
let udp_socket = UdpSocket::bind(if uri.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" })
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
udp_socket
.connect(uri)
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
let udp_socket =
UdpSocket::bind(if uri.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" }).await?;
udp_socket.connect(uri).await?;
let mut data = vec![0u8; 65507]; // udp standard max datagram size
loop {
tokio::select! {
size = udp_socket.recv(&mut data).map_err(|x| WispError::Other(Box::new(x))) => {
size = udp_socket.recv(&mut data) => {
let size = size?;
stream.write(Bytes::copy_from_slice(&data[..size])).await?
},
event = stream.read() => {
match event {
Some(event) => {
let _ = udp_socket.send(&event).await.map_err(|x| WispError::Other(Box::new(x)))?;
let _ = udp_socket.send(&event).await?;
}
None => break,
}
Expand Down Expand Up @@ -319,28 +313,18 @@ async fn accept_ws(
// to prevent memory ""leaks"" because users are sending in packets way too fast the buffer
// size is set to 128
let (mux, fut) = if mux_options.enforce_auth {
let (mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?;
if !mux
.supported_extension_ids
.iter()
.any(|x| *x == PasswordProtocolExtension::ID)
{
println!(
"{:?}: client did not support auth or password was invalid",
addr
);
mux.close_extension_incompat().await?;
return Ok(());
}
(mux, fut)
ServerMux::create(rx, tx, 128, Some(mux_options.auth.as_slice()))
.await?
.with_required_extensions(&[PasswordProtocolExtension::ID]).await?
} else {
ServerMux::new(
ServerMux::create(
rx,
tx,
128,
Some(&[Box::new(UdpProtocolExtensionBuilder())]),
)
.await?
.with_no_required_extensions()
};

println!(
Expand Down
72 changes: 28 additions & 44 deletions simple-wisp-client/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,53 +156,33 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let rx = FragmentCollectorRead::new(rx);

let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
let mut extension_ids: Vec<u8> = Vec::new();
if opts.udp {
extensions.push(Box::new(UdpProtocolExtensionBuilder()));
extension_ids.push(UdpProtocolExtension::ID);
}
let enforce_auth = auth.is_some();
if let Some(auth) = auth {
extensions.push(Box::new(auth));
extension_ids.push(PasswordProtocolExtension::ID);
}

let (mux, fut) = if opts.wisp_v1 {
ClientMux::new(rx, tx, None).await?
ClientMux::create(rx, tx, None)
.await?
.with_no_required_extensions()
} else {
ClientMux::new(rx, tx, Some(extensions.as_slice())).await?
ClientMux::create(rx, tx, Some(extensions.as_slice()))
.await?
.with_required_extensions(extension_ids.as_slice()).await?
};

if opts.udp
&& !mux
.supported_extension_ids
.iter()
.any(|x| *x == UdpProtocolExtension::ID)
{
println!(
"server did not support udp, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
mux.close_extension_incompat().await?;
exit(1);
}
if enforce_auth
&& !mux
.supported_extension_ids
.iter()
.any(|x| *x == PasswordProtocolExtension::ID)
{
println!(
"server did not support passwords or password was incorrect, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
mux.close_extension_incompat().await?;
exit(1);
}

println!(
"connected and created ClientMux, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);

let mut threads = Vec::with_capacity(opts.streams * 2 + 3);
let mut threads = Vec::with_capacity(opts.streams + 4);
let mut reads = Vec::with_capacity(opts.streams);

threads.push(tokio::spawn(fut));

Expand All @@ -226,13 +206,15 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
#[allow(unreachable_code)]
Ok::<(), WispError>(())
}));
threads.push(tokio::spawn(async move {
loop {
cr.read().await;
}
}));
reads.push(cr);
}

threads.push(tokio::spawn(async move {
loop {
select_all(reads.iter().map(|x| Box::pin(x.read()))).await;
}
}));

let cnt_avg = cnt.clone();
threads.push(tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(100));
Expand Down Expand Up @@ -295,14 +277,16 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {

mux.close().await?;

println!(
"\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)",
cnt.get(),
opts.packet_size,
cnt.get() * opts.packet_size,
format_duration(duration_since),
(cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(),
);
if duration_since.as_secs() != 0 {
println!(
"\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)",
cnt.get(),
opts.packet_size,
cnt.get() * opts.packet_size,
format_duration(duration_since),
(cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(),
);
}

Ok(())
}
1 change: 0 additions & 1 deletion wisp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional =
flume = "0.11.0"
futures = "0.3.30"
futures-timer = "3.0.3"
futures-util = "0.3.30"
pin-project-lite = "0.2.13"
tokio = { version = "1.35.1", optional = true, default-features = false }

Expand Down
Loading

0 comments on commit 01d7ac5

Please sign in to comment.