diff --git a/src/substream/mod.rs b/src/substream/mod.rs index 83a3b136..eff703ec 100644 --- a/src/substream/mod.rs +++ b/src/substream/mod.rs @@ -34,16 +34,17 @@ use crate::transport::websocket; use bytes::{Buf, Bytes, BytesMut}; use futures::{Sink, Stream}; +use indexmap::{map::Entry, IndexMap}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use unsigned_varint::{decode, encode}; use std::{ - collections::{hash_map::Entry, HashMap, VecDeque}, + collections::VecDeque, fmt, hash::Hash, io::ErrorKind, pin::Pin, - task::{Context, Poll}, + task::{Context, Poll, Waker}, }; /// Logging target for the file. @@ -765,7 +766,9 @@ where K: SubstreamSetKey, S: Stream> + Unpin, { - substreams: HashMap, + substreams: IndexMap, + poll_index: usize, + waker: Option, } impl SubstreamSet @@ -776,7 +779,9 @@ where /// Create new [`SubstreamSet`]. pub fn new() -> Self { Self { - substreams: HashMap::new(), + substreams: IndexMap::new(), + poll_index: 0, + waker: None, } } @@ -785,6 +790,7 @@ where match self.substreams.entry(key) { Entry::Vacant(entry) => { entry.insert(substream); + self.waker.take().map(|waker| waker.wake()); } Entry::Occupied(_) => { tracing::error!(?key, "substream already exists"); @@ -795,7 +801,15 @@ where /// Remove substream from the set. pub fn remove(&mut self, key: &K) -> Option { - self.substreams.remove(key) + // The `swap_remove()` changes the order of elements in the map, + // however it completes in O(1). This is acceptable since the + // alternative of calling `shift_remove()` would be O(n). + let Some(substream) = self.substreams.swap_remove(key) else { + return None; + }; + + self.waker.take().map(|waker| waker.wake()); + Some(substream) } /// Get mutable reference to stored substream. @@ -825,8 +839,15 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let inner = Pin::into_inner(self); - // TODO: poll the streams more randomly - for (key, mut substream) in inner.substreams.iter_mut() { + let len = inner.substreams.len(); + + for _ in 0..len { + let index = inner.poll_index % len; + inner.poll_index = (inner.poll_index + 1) % len; + + let (key, mut substream) = + inner.substreams.get_index_mut(index).expect("Index within range; qed"); + match Pin::new(&mut substream).poll_next(cx) { Poll::Pending => continue, Poll::Ready(Some(data)) => return Poll::Ready(Some((*key, data))), @@ -835,6 +856,7 @@ where } } + inner.waker = Some(cx.waker().clone()); Poll::Pending } }