diff --git a/tentacle/src/session.rs b/tentacle/src/session.rs index 10ec6682..5d275f17 100644 --- a/tentacle/src/session.rs +++ b/tentacle/src/session.rs @@ -10,7 +10,7 @@ use std::{ time::Duration, }; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::{Framed, FramedParts, FramedRead, FramedWrite, LengthDelimitedCodec}; +use tokio_util::codec::{Framed, FramedParts, LengthDelimitedCodec}; use yamux::{Control, Session as YamuxSession, StreamHandle}; use crate::{ @@ -36,6 +36,19 @@ pub trait AsyncRw: AsyncWrite + AsyncRead {} impl AsyncRw for T {} +fn split_spawn_framed( + part: FramedParts, +) -> ( + futures::stream::SplitSink, bytes::Bytes>, + futures::stream::SplitStream>, +) +where + T: AsyncRead + AsyncWrite, + U: crate::traits::Codec, +{ + Framed::from_parts(part).split() +} + /// Event generated/received by the Session pub(crate) enum SessionEvent { /// Session close event @@ -432,13 +445,13 @@ impl Session { match proto.spawn { Some(ref spawn) => { - let (read, write) = crate::runtime::split(raw_part.io); + let mut part = FramedParts::new(raw_part.io, (proto.codec)()); + part.read_buf = raw_part.read_buf; + part.write_buf = raw_part.write_buf; + let (write, read) = split_spawn_framed(part); let read_part = { - let mut frame = FramedRead::new(read, (proto.codec)()); - *frame.read_buffer_mut() = raw_part.read_buf; - SubstreamReadPart { - substream: frame, + substream: read, before_receive: before_receive_fn, proto_id, stream_id: self.next_stream, @@ -455,7 +468,7 @@ impl Session { .proto_id(proto_id) .stream_id(self.next_stream) .config(self.config) - .build(FramedWrite::new(write, (proto.codec)())); + .build(write); crate::runtime::spawn(write_part.for_each(|_| future::ready(()))); spawn.spawn(self.context.clone(), &self.service_control, read_part); @@ -800,6 +813,37 @@ impl Stream for Session { } } +#[cfg(all(test, not(target_family = "wasm")))] +mod tests { + use super::split_spawn_framed; + use bytes::{Bytes, BytesMut}; + use futures::StreamExt; + use tokio::io::duplex; + use tokio_util::codec::{Encoder, FramedParts, LengthDelimitedCodec}; + + #[tokio::test] + async fn split_spawn_framed_reads_buffered_first_frame() { + let (io, _peer) = duplex(64); + let mut codec = LengthDelimitedCodec::new(); + let mut read_buf = BytesMut::new(); + codec + .encode(Bytes::from_static(b"init"), &mut read_buf) + .expect("encode buffered first frame"); + + let mut parts = FramedParts::new(io, codec); + parts.read_buf = read_buf; + + let (_write, mut read) = split_spawn_framed(parts); + let first = read + .next() + .await + .expect("buffered frame should be available") + .expect("buffered frame should decode"); + + assert_eq!(first.freeze(), Bytes::from_static(b"init")); + } +} + pub(crate) struct SessionMeta { config: SessionConfig, protocol_configs_by_name: HashMap>, diff --git a/tentacle/src/substream.rs b/tentacle/src/substream.rs index 2be804c1..391e366e 100644 --- a/tentacle/src/substream.rs +++ b/tentacle/src/substream.rs @@ -1,4 +1,7 @@ -use futures::{SinkExt, StreamExt, channel::mpsc, prelude::*, stream::iter}; +use futures::{ + SinkExt, StreamExt, channel::mpsc, prelude::*, stream::SplitSink, stream::SplitStream, + stream::iter, +}; use log::debug; use std::{ collections::VecDeque, @@ -8,7 +11,7 @@ use std::{ task::{Context, Poll}, }; use tokio::io::AsyncWrite; -use tokio_util::codec::{Framed, FramedRead, FramedWrite, length_delimited::LengthDelimitedCodec}; +use tokio_util::codec::{Framed, length_delimited::LengthDelimitedCodec}; use crate::{ ProtocolId, StreamId, @@ -589,7 +592,7 @@ impl SubstreamBuilder { /* Code organization under read-write separation */ pub(crate) struct SubstreamWritePart { - substream: FramedWrite, U>, + substream: SplitSink, bytes::Bytes>, id: StreamId, proto_id: ProtocolId, @@ -748,7 +751,7 @@ where Poll::Ready(None) => { // Must be session close self.dead = true; - if let Poll::Ready(Err(e)) = Pin::new(self.substream.get_mut()).poll_shutdown(cx) { + if let Poll::Ready(Err(e)) = Pin::new(&mut self.substream).poll_close(cx) { log::trace!("sub stream poll shutdown err {}", e) } Poll::Ready(None) @@ -777,7 +780,7 @@ where fn close_proto_stream(&mut self, cx: &mut Context) { self.event_receiver.close(); - if let Poll::Ready(Err(e)) = Pin::new(self.substream.get_mut()).poll_shutdown(cx) { + if let Poll::Ready(Err(e)) = Pin::new(&mut self.substream).poll_close(cx) { log::trace!("sub stream poll shutdown err {}", e) } if !self.context.closed.load(Ordering::SeqCst) { @@ -860,8 +863,7 @@ where /// Protocol Stream read part pub struct SubstreamReadPart { - pub(crate) substream: - FramedRead, Box>, + pub(crate) substream: SplitStream>>, pub(crate) before_receive: Option, pub(crate) proto_id: ProtocolId, pub(crate) stream_id: StreamId, @@ -964,7 +966,7 @@ impl SubstreamWritePartBuilder { pub fn build( self, - substream: FramedWrite, U>, + substream: SplitSink, bytes::Bytes>, ) -> SubstreamWritePart where U: Codec,