Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 51 additions & 7 deletions tentacle/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::{Framed, FramedParts, FramedRead, FramedWrite, LengthDelimitedCodec};
use tokio_util::codec::{Framed, FramedParts, LengthDelimitedCodec};
use yamux::{Control, Session as YamuxSession, StreamHandle};

use crate::{
Expand All @@ -36,6 +36,19 @@ pub trait AsyncRw: AsyncWrite + AsyncRead {}

impl<T: AsyncRead + AsyncWrite> AsyncRw for T {}

fn split_spawn_framed<T, U>(
part: FramedParts<T, U>,
) -> (
futures::stream::SplitSink<Framed<T, U>, bytes::Bytes>,
futures::stream::SplitStream<Framed<T, U>>,
)
where
T: AsyncRead + AsyncWrite,
U: crate::traits::Codec,
{
Framed::from_parts(part).split()
}

/// Event generated/received by the Session
pub(crate) enum SessionEvent {
/// Session close event
Expand Down Expand Up @@ -432,13 +445,13 @@ impl Session {

match proto.spawn {
Some(ref spawn) => {
let (read, write) = crate::runtime::split(raw_part.io);
let mut part = FramedParts::new(raw_part.io, (proto.codec)());
part.read_buf = raw_part.read_buf;
part.write_buf = raw_part.write_buf;
let (write, read) = split_spawn_framed(part);
let read_part = {
let mut frame = FramedRead::new(read, (proto.codec)());
*frame.read_buffer_mut() = raw_part.read_buf;

SubstreamReadPart {
substream: frame,
substream: read,
before_receive: before_receive_fn,
proto_id,
stream_id: self.next_stream,
Expand All @@ -455,7 +468,7 @@ impl Session {
.proto_id(proto_id)
.stream_id(self.next_stream)
.config(self.config)
.build(FramedWrite::new(write, (proto.codec)()));
.build(write);

crate::runtime::spawn(write_part.for_each(|_| future::ready(())));
spawn.spawn(self.context.clone(), &self.service_control, read_part);
Expand Down Expand Up @@ -800,6 +813,37 @@ impl Stream for Session {
}
}

#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
use super::split_spawn_framed;
use bytes::{Bytes, BytesMut};
use futures::StreamExt;
use tokio::io::duplex;
use tokio_util::codec::{Encoder, FramedParts, LengthDelimitedCodec};

#[tokio::test]
async fn split_spawn_framed_reads_buffered_first_frame() {
let (io, _peer) = duplex(64);
let mut codec = LengthDelimitedCodec::new();
let mut read_buf = BytesMut::new();
codec
.encode(Bytes::from_static(b"init"), &mut read_buf)
.expect("encode buffered first frame");

let mut parts = FramedParts::new(io, codec);
parts.read_buf = read_buf;

let (_write, mut read) = split_spawn_framed(parts);
let first = read
.next()
.await
.expect("buffered frame should be available")
.expect("buffered frame should decode");

assert_eq!(first.freeze(), Bytes::from_static(b"init"));
}
}

pub(crate) struct SessionMeta {
config: SessionConfig,
protocol_configs_by_name: HashMap<String, Arc<Meta>>,
Expand Down
18 changes: 10 additions & 8 deletions tentacle/src/substream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use futures::{SinkExt, StreamExt, channel::mpsc, prelude::*, stream::iter};
use futures::{
SinkExt, StreamExt, channel::mpsc, prelude::*, stream::SplitSink, stream::SplitStream,
stream::iter,
};
use log::debug;
use std::{
collections::VecDeque,
Expand All @@ -8,7 +11,7 @@ use std::{
task::{Context, Poll},
};
use tokio::io::AsyncWrite;
use tokio_util::codec::{Framed, FramedRead, FramedWrite, length_delimited::LengthDelimitedCodec};
use tokio_util::codec::{Framed, length_delimited::LengthDelimitedCodec};

use crate::{
ProtocolId, StreamId,
Expand Down Expand Up @@ -589,7 +592,7 @@ impl SubstreamBuilder {
/* Code organization under read-write separation */

pub(crate) struct SubstreamWritePart<U> {
substream: FramedWrite<crate::runtime::WriteHalf<StreamHandle>, U>,
substream: SplitSink<Framed<StreamHandle, U>, bytes::Bytes>,
id: StreamId,
proto_id: ProtocolId,

Expand Down Expand Up @@ -748,7 +751,7 @@ where
Poll::Ready(None) => {
// Must be session close
self.dead = true;
if let Poll::Ready(Err(e)) = Pin::new(self.substream.get_mut()).poll_shutdown(cx) {
if let Poll::Ready(Err(e)) = Pin::new(&mut self.substream).poll_close(cx) {
log::trace!("sub stream poll shutdown err {}", e)
}
Poll::Ready(None)
Expand Down Expand Up @@ -777,7 +780,7 @@ where

fn close_proto_stream(&mut self, cx: &mut Context) {
self.event_receiver.close();
if let Poll::Ready(Err(e)) = Pin::new(self.substream.get_mut()).poll_shutdown(cx) {
if let Poll::Ready(Err(e)) = Pin::new(&mut self.substream).poll_close(cx) {
log::trace!("sub stream poll shutdown err {}", e)
}
if !self.context.closed.load(Ordering::SeqCst) {
Expand Down Expand Up @@ -860,8 +863,7 @@ where

/// Protocol Stream read part
pub struct SubstreamReadPart {
pub(crate) substream:
FramedRead<crate::runtime::ReadHalf<StreamHandle>, Box<dyn Codec + Send + 'static>>,
pub(crate) substream: SplitStream<Framed<StreamHandle, Box<dyn Codec + Send + 'static>>>,
pub(crate) before_receive: Option<BeforeReceive>,
pub(crate) proto_id: ProtocolId,
pub(crate) stream_id: StreamId,
Expand Down Expand Up @@ -964,7 +966,7 @@ impl SubstreamWritePartBuilder {

pub fn build<U>(
self,
substream: FramedWrite<crate::runtime::WriteHalf<StreamHandle>, U>,
substream: SplitSink<Framed<StreamHandle, U>, bytes::Bytes>,
) -> SubstreamWritePart<U>
where
U: Codec,
Expand Down