Skip to content

Commit

Permalink
remove a bunch of allocations from packet encode, drop rogue clients'…
Browse files Browse the repository at this point in the history
… packets
  • Loading branch information
r58Playz committed Apr 28, 2024
1 parent ce26609 commit 855fa61
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 74 deletions.
2 changes: 1 addition & 1 deletion wisp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ impl MuxInner {
}
Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.send_async(data).await;
let _ = stream.stream.try_send(data);
if stream.stream_type == StreamType::Tcp {
stream.flow_control.store(
stream
Expand Down
133 changes: 69 additions & 64 deletions wisp/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ impl TryFrom<u8> for CloseReason {
}
}

trait Encode {
fn encode(self, bytes: &mut BytesMut);
}

/// Packet used to create a new stream.
///
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect).
Expand Down Expand Up @@ -120,9 +124,9 @@ impl ConnectPacket {
}
}

impl TryFrom<Bytes> for ConnectPacket {
impl TryFrom<BytesMut> for ConnectPacket {
type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
if bytes.remaining() < (1 + 2) {
return Err(Self::Error::PacketTooSmall);
}
Expand All @@ -134,13 +138,11 @@ impl TryFrom<Bytes> for ConnectPacket {
}
}

impl From<ConnectPacket> for Bytes {
fn from(packet: ConnectPacket) -> Self {
let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len());
encoded.put_u8(packet.stream_type.into());
encoded.put_u16_le(packet.destination_port);
encoded.extend(packet.destination_hostname.bytes());
encoded.freeze()
impl Encode for ConnectPacket {
fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.stream_type.into());
bytes.put_u16_le(self.destination_port);
bytes.extend(self.destination_hostname.bytes());
}
}

Expand All @@ -160,9 +162,9 @@ impl ContinuePacket {
}
}

impl TryFrom<Bytes> for ContinuePacket {
impl TryFrom<BytesMut> for ContinuePacket {
type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
if bytes.remaining() < 4 {
return Err(Self::Error::PacketTooSmall);
}
Expand All @@ -172,11 +174,9 @@ impl TryFrom<Bytes> for ContinuePacket {
}
}

impl From<ContinuePacket> for Bytes {
fn from(packet: ContinuePacket) -> Self {
let mut encoded = BytesMut::with_capacity(4);
encoded.put_u32_le(packet.buffer_remaining);
encoded.freeze()
impl Encode for ContinuePacket {
fn encode(self, bytes: &mut BytesMut) {
bytes.put_u32_le(self.buffer_remaining);
}
}

Expand All @@ -197,9 +197,9 @@ impl ClosePacket {
}
}

impl TryFrom<Bytes> for ClosePacket {
impl TryFrom<BytesMut> for ClosePacket {
type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
if bytes.remaining() < 1 {
return Err(Self::Error::PacketTooSmall);
}
Expand All @@ -209,11 +209,9 @@ impl TryFrom<Bytes> for ClosePacket {
}
}

impl From<ClosePacket> for Bytes {
fn from(packet: ClosePacket) -> Self {
let mut encoded = BytesMut::with_capacity(1);
encoded.put_u8(packet.reason as u8);
encoded.freeze()
impl Encode for ClosePacket {
fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.reason as u8);
}
}

Expand All @@ -237,15 +235,13 @@ pub struct InfoPacket {
pub extensions: Vec<AnyProtocolExtension>,
}

impl From<InfoPacket> for Bytes {
fn from(value: InfoPacket) -> Self {
let mut bytes = BytesMut::with_capacity(2);
bytes.put_u8(value.version.major);
bytes.put_u8(value.version.minor);
for extension in value.extensions {
impl Encode for InfoPacket {
fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.version.major);
bytes.put_u8(self.version.minor);
for extension in self.extensions {
bytes.extend(Bytes::from(extension));
}
bytes.freeze()
}
}

Expand Down Expand Up @@ -276,21 +272,32 @@ impl PacketType {
P::Info(_) => 0x05,
}
}
}

impl From<PacketType> for Bytes {
fn from(packet: PacketType) -> Self {
pub(crate) fn get_packet_size(&self) -> usize {
use PacketType as P;
match packet {
P::Connect(x) => x.into(),
P::Data(x) => x,
P::Continue(x) => x.into(),
P::Close(x) => x.into(),
P::Info(x) => x.into(),
match self {
P::Connect(p) => 1 + 2 + p.destination_hostname.len(),
P::Data(p) => p.len(),
P::Continue(_) => 4,
P::Close(_) => 1,
P::Info(_) => 2,
}
}
}

impl Encode for PacketType {
fn encode(self, bytes: &mut BytesMut) {
use PacketType as P;
match self {
P::Connect(x) => x.encode(bytes),
P::Data(x) => bytes.extend(x),
P::Continue(x) => x.encode(bytes),
P::Close(x) => x.encode(bytes),
P::Info(x) => x.encode(bytes),
};
}
}

/// Wisp protocol packet.
#[derive(Debug, Clone)]
pub struct Packet {
Expand Down Expand Up @@ -362,21 +369,13 @@ impl Packet {
}
}

pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> BytesMut {
let mut encoded = BytesMut::with_capacity(1 + 4 + bytes.len());
encoded.put_u8(packet_type);
encoded.put_u32_le(stream_id);
encoded.extend(bytes);
encoded
}

fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> {
fn parse_packet(packet_type: u8, mut bytes: BytesMut) -> Result<Self, WispError> {
use PacketType as P;
Ok(Self {
stream_id: bytes.get_u32_le(),
packet_type: match packet_type {
0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
0x02 => P::Data(bytes),
0x02 => P::Data(bytes.freeze()),
0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
0x04 => P::Close(ClosePacket::try_from(bytes)?),
// 0x05 is handled seperately
Expand All @@ -396,7 +395,7 @@ impl Packet {
if frame.opcode != OpCode::Binary {
return Err(WispError::WsFrameInvalidType);
}
let mut bytes = frame.payload.freeze();
let mut bytes = frame.payload;
if bytes.remaining() < 1 {
return Err(WispError::PacketTooSmall);
}
Expand All @@ -420,8 +419,8 @@ impl Packet {
if frame.opcode != OpCode::Binary {
return Err(WispError::WsFrameInvalidType);
}
let mut bytes = frame.payload.freeze();
if bytes.remaining() < 1 {
let mut bytes = frame.payload;
if bytes.remaining() < 5 {
return Err(WispError::PacketTooSmall);
}
let packet_type = bytes.get_u8();
Expand All @@ -432,7 +431,7 @@ impl Packet {
})),
0x02 => Ok(Some(Self {
stream_id: bytes.get_u32_le(),
packet_type: PacketType::Data(bytes),
packet_type: PacketType::Data(bytes.freeze()),
})),
0x03 => Ok(Some(Self {
stream_id: bytes.get_u32_le(),
Expand All @@ -448,7 +447,7 @@ impl Packet {
.iter_mut()
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
{
extension.handle_packet(bytes, read, write).await?;
extension.handle_packet(bytes.freeze(), read, write).await?;
Ok(None)
} else {
Err(WispError::InvalidPacketType)
Expand All @@ -458,7 +457,7 @@ impl Packet {
}

fn parse_info(
mut bytes: Bytes,
mut bytes: BytesMut,
role: Role,
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
) -> Result<Self, WispError> {
Expand Down Expand Up @@ -507,9 +506,17 @@ impl Packet {
}
}

impl TryFrom<Bytes> for Packet {
impl Encode for Packet {
fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.packet_type.as_u8());
bytes.put_u32_le(self.stream_id);
self.packet_type.encode(bytes);
}
}

impl TryFrom<BytesMut> for Packet {
type Error = WispError;
fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
if bytes.remaining() < 1 {
return Err(Self::Error::PacketTooSmall);
}
Expand All @@ -520,11 +527,9 @@ impl TryFrom<Bytes> for Packet {

impl From<Packet> for BytesMut {
fn from(packet: Packet) -> Self {
Packet::raw_encode(
packet.packet_type.as_u8(),
packet.stream_id,
packet.packet_type.into(),
)
let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size());
packet.encode(&mut encoded);
encoded
}
}

Expand All @@ -537,12 +542,12 @@ impl TryFrom<ws::Frame> for Packet {
if frame.opcode != ws::OpCode::Binary {
return Err(Self::Error::WsFrameInvalidType);
}
frame.payload.freeze().try_into()
Packet::try_from(frame.payload)
}
}

impl From<Packet> for ws::Frame {
fn from(packet: Packet) -> Self {
Self::binary(packet.into())
Self::binary(BytesMut::from(packet))
}
}
16 changes: 7 additions & 9 deletions wisp/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
};

pub use async_io_stream::IoStream;
use bytes::Bytes;
use bytes::{BufMut, Bytes, BytesMut};
use event_listener::Event;
use flume as mpsc;
use futures::{
Expand Down Expand Up @@ -114,7 +114,7 @@ impl MuxStreamWrite {
}

self.tx
.write_frame(Packet::new_data(self.stream_id, data).into())
.write_frame(Frame::from(Packet::new_data(self.stream_id, data)))
.await?;

if self.role == Role::Client && self.stream_type == StreamType::Tcp {
Expand Down Expand Up @@ -348,13 +348,11 @@ impl MuxProtocolExtensionStream {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
self.tx
.write_frame(Frame::binary(Packet::raw_encode(
packet_type,
self.stream_id,
data,
)))
.await
let mut encoded = BytesMut::with_capacity(1 + 4 + data.len());
encoded.put_u8(packet_type);
encoded.put_u32_le(self.stream_id);
encoded.extend(data);
self.tx.write_frame(Frame::binary(encoded)).await
}
}

Expand Down

0 comments on commit 855fa61

Please sign in to comment.