Skip to content

Commit

Permalink
add ability to send protocol extension packets
Browse files Browse the repository at this point in the history
  • Loading branch information
r58Playz committed Apr 17, 2024
1 parent fd94f12 commit 6c41c54
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 33 deletions.
12 changes: 6 additions & 6 deletions server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ async fn accept_http(
}
}

async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool, WispError> {
async fn handle_mux(packet: ConnectPacket, stream: MuxStream) -> Result<bool, WispError> {
let uri = format!(
"{}:{}",
packet.destination_hostname, packet.destination_port
Expand Down Expand Up @@ -318,8 +318,8 @@ async fn accept_ws(
println!("{:?}: connected", addr);
// to prevent memory ""leaks"" because users are sending in packets way too fast the buffer
// size is set to 128
let (mut mux, fut) = if mux_options.enforce_auth {
let (mut mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?;
let (mux, fut) = if mux_options.enforce_auth {
let (mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?;
if !mux
.supported_extension_ids
.iter()
Expand Down Expand Up @@ -354,7 +354,7 @@ async fn accept_ws(
}
});

while let Some((packet, mut stream)) = mux.server_new_stream().await {
while let Some((packet, stream)) = mux.server_new_stream().await {
tokio::spawn(async move {
if (mux_options.block_non_http
&& !(packet.destination_port == 80 || packet.destination_port == 443))
Expand Down Expand Up @@ -386,8 +386,8 @@ async fn accept_ws(
}
}
}
let mut close_err = stream.get_close_handle();
let mut close_ok = stream.get_close_handle();
let close_err = stream.get_close_handle();
let close_ok = stream.get_close_handle();
let _ = handle_mux(packet, stream)
.or_else(|err| async move {
let _ = close_err.close(CloseReason::Unexpected).await;
Expand Down
4 changes: 2 additions & 2 deletions simple-wisp-client/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
extensions.push(Box::new(auth));
}

let (mut mux, fut) = if opts.wisp_v1 {
let (mux, fut) = if opts.wisp_v1 {
ClientMux::new(rx, tx, None).await?
} else {
ClientMux::new(rx, tx, Some(extensions.as_slice())).await?
Expand Down Expand Up @@ -212,7 +212,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {

let start_time = Instant::now();
for _ in 0..opts.streams {
let (mut cr, mut cw) = mux
let (cr, cw) = mux
.client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port)
.await?
.into_split();
Expand Down
19 changes: 11 additions & 8 deletions wisp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,9 @@ impl MuxInner {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::SendBytes(packet, channel) => {
let _ = channel.send(self.tx.write_frame(ws::Frame::binary(packet)).await);
}
WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async {
let stream_id = next_free_stream_id;
Expand Down Expand Up @@ -552,11 +555,11 @@ impl ServerMux {
}

/// Wait for a stream to be created.
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> {
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
self.muxstream_recv.recv_async().await.ok()
}

async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.close_tx
.send_async(WsEvent::EndFut(reason))
.await
Expand All @@ -567,15 +570,15 @@ impl ServerMux {
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called.
pub async fn close(&mut self) -> Result<(), WispError> {
pub async fn close(&self) -> Result<(), WispError> {
self.close_internal(None).await
}

/// Close all streams and send an extension incompatibility error to the client.
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succed after
/// this function is called.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
}
Expand Down Expand Up @@ -696,7 +699,7 @@ impl ClientMux {

/// Create a new stream, multiplexed through Wisp.
pub async fn client_new_stream(
&mut self,
&self,
stream_type: StreamType,
host: String,
port: u16,
Expand All @@ -717,7 +720,7 @@ impl ClientMux {
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}

async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
self.stream_tx
.send_async(WsEvent::EndFut(reason))
.await
Expand All @@ -728,15 +731,15 @@ impl ClientMux {
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close(&mut self) -> Result<(), WispError> {
pub async fn close(&self) -> Result<(), WispError> {
self.close_internal(None).await
}

/// Close all streams and send an extension incompatibility error to the client.
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
self.close_internal(Some(CloseReason::IncompatibleExtensions))
.await
}
Expand Down
20 changes: 13 additions & 7 deletions wisp/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,14 @@ impl Packet {
}
}

pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> Bytes {
let mut encoded = BytesMut::with_capacity(1 + 4 + bytes.len());
encoded.put_u8(packet_type);
encoded.put_u32_le(stream_id);
encoded.extend(bytes);
encoded.freeze()
}

fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> {
use PacketType as P;
Ok(Self {
Expand Down Expand Up @@ -494,13 +502,11 @@ impl TryFrom<Bytes> for Packet {

impl From<Packet> for Bytes {
fn from(packet: Packet) -> Self {
let inner_u8 = packet.packet_type.as_u8();
let inner = Bytes::from(packet.packet_type);
let mut encoded = BytesMut::with_capacity(1 + 4 + inner.len());
encoded.put_u8(inner_u8);
encoded.put_u32_le(packet.stream_id);
encoded.extend(inner);
encoded.freeze()
Packet::raw_encode(
packet.packet_type.as_u8(),
packet.stream_id,
packet.packet_type.into(),
)
}
}

Expand Down
62 changes: 52 additions & 10 deletions wisp/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::{

pub(crate) enum WsEvent {
SendPacket(Packet, oneshot::Sender<Result<(), WispError>>),
SendBytes(Bytes, oneshot::Sender<Result<(), WispError>>),
Close(Packet, oneshot::Sender<Result<(), WispError>>),
CreateStream(
StreamType,
Expand Down Expand Up @@ -49,7 +50,7 @@ pub struct MuxStreamRead {

impl MuxStreamRead {
/// Read an event from the stream.
pub async fn read(&mut self) -> Option<Bytes> {
pub async fn read(&self) -> Option<Bytes> {
if self.is_closed.load(Ordering::Acquire) {
return None;
}
Expand Down Expand Up @@ -79,7 +80,7 @@ impl MuxStreamRead {
}

pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
Box::pin(stream::unfold(self, |mut rx| async move {
Box::pin(stream::unfold(self, |rx| async move {
Some((rx.read().await?, rx))
}))
}
Expand All @@ -100,7 +101,7 @@ pub struct MuxStreamWrite {

impl MuxStreamWrite {
/// Write data to the stream.
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
Expand Down Expand Up @@ -147,8 +148,17 @@ impl MuxStreamWrite {
}
}

/// Get a protocol extension stream to send protocol extension packets.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
MuxProtocolExtensionStream {
stream_id: self.stream_id,
tx: self.tx.clone(),
is_closed: self.is_closed.clone(),
}
}

/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
Expand All @@ -171,12 +181,12 @@ impl MuxStreamWrite {
let handle = self.get_close_handle();
Box::pin(sink_unfold::unfold(
self,
|mut tx, data| async move {
|tx, data| async move {
tx.write(data).await?;
Ok(tx)
},
handle,
move |mut handle| async {
move |handle| async {
handle.close(CloseReason::Unknown).await?;
Ok(handle)
},
Expand Down Expand Up @@ -246,12 +256,12 @@ impl MuxStream {
}

/// Read an event from the stream.
pub async fn read(&mut self) -> Option<Bytes> {
pub async fn read(&self) -> Option<Bytes> {
self.rx.read().await
}

/// Write data to the stream.
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
self.tx.write(data).await
}

Expand All @@ -270,8 +280,13 @@ impl MuxStream {
self.tx.get_close_handle()
}

/// Get a protocol extension stream to send protocol extension packets.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
self.tx.get_protocol_extension_stream()
}

/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
self.tx.close(reason).await
}

Expand Down Expand Up @@ -300,7 +315,7 @@ pub struct MuxStreamCloser {

impl MuxStreamCloser {
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
Expand All @@ -320,6 +335,33 @@ impl MuxStreamCloser {
}
}

/// Stream for sending arbitrary protocol extension packets.
pub struct MuxProtocolExtensionStream {
/// ID of the stream.
pub stream_id: u32,
tx: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>,
}

impl MuxProtocolExtensionStream {
/// Send a protocol extension packet.
pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx
.send_async(WsEvent::SendBytes(
Packet::raw_encode(packet_type, self.stream_id, data),
tx,
))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
Ok(())
}
}

pin_project! {
/// Multiplexor stream that implements futures `Stream + Sink`.
pub struct MuxStreamIo {
Expand Down

0 comments on commit 6c41c54

Please sign in to comment.