diff --git a/neqo-transport/src/streams.rs b/neqo-transport/src/streams.rs index 2f2e713c98..1cb386e4bd 100644 --- a/neqo-transport/src/streams.rs +++ b/neqo-transport/src/streams.rs @@ -122,20 +122,6 @@ impl Streams { /// # Errors /// When the frame is invalid. pub fn input_frame(&mut self, frame: &Frame, stats: &mut FrameStats) -> Res<()> { - if let Frame::ResetStream { stream_id, .. } - | Frame::StopSending { stream_id, .. } - | Frame::Stream { stream_id, .. } - | Frame::MaxStreamData { stream_id, .. } - | Frame::StreamDataBlocked { stream_id, .. } = frame - { - if stream_id.is_remote_initiated(self.role) - || self.local_stream_limits[stream_id.stream_type()].used() > stream_id.index() - { - // Remote stream, or local stream that was never initiated. - return Err(Error::StreamStateError); - } - } - match frame { Frame::ResetStream { stream_id, @@ -145,6 +131,8 @@ impl Streams { stats.reset_stream += 1; if let (_, Some(rs)) = self.obtain_stream(*stream_id)? { rs.reset(*application_error_code, *final_size)?; + } else if !self.ensure_existed_if_local(*stream_id) { + return Err(Error::StreamStateError); } } Frame::StopSending { @@ -156,6 +144,8 @@ impl Streams { .send_stream_stop_sending(*stream_id, *application_error_code); if let (Some(ss), _) = self.obtain_stream(*stream_id)? { ss.reset(*application_error_code); + } else if !self.ensure_existed_if_local(*stream_id) { + return Err(Error::StreamStateError); } } Frame::Stream { @@ -168,6 +158,8 @@ impl Streams { stats.stream += 1; if let (_, Some(rs)) = self.obtain_stream(*stream_id)? { rs.inbound_stream_frame(*fin, *offset, data)?; + } else if !self.ensure_existed_if_local(*stream_id) { + return Err(Error::StreamStateError); } } Frame::MaxData { maximum_data } => { @@ -186,6 +178,8 @@ impl Streams { stats.max_stream_data += 1; if let (Some(ss), _) = self.obtain_stream(*stream_id)? { ss.set_max_stream_data(*maximum_stream_data); + } else if !self.ensure_existed_if_local(*stream_id) { + return Err(Error::StreamStateError); } } Frame::MaxStreams { @@ -212,6 +206,8 @@ impl Streams { if let (_, Some(rs)) = self.obtain_stream(*stream_id)? { rs.send_flowc_update(); + } else if !self.ensure_existed_if_local(*stream_id) { + return Err(Error::StreamStateError); } } Frame::StreamsBlocked { .. } => { @@ -356,6 +352,10 @@ impl Streams { self.remote_stream_limits[StreamType::UniDi].add_retired(removed_uni); } + fn ensure_existed_if_local(&self, stream_id: StreamId) -> bool { + !stream_id.is_remote_initiated(self.role) + && self.local_stream_limits[stream_id.stream_type()].used() > stream_id.index() + } fn ensure_created_if_remote(&mut self, stream_id: StreamId) -> Res<()> { if !stream_id.is_remote_initiated(self.role) || !self.remote_stream_limits[stream_id.stream_type()].is_new_stream(stream_id)?