Skip to content

Commit 8e5a4b5

Browse files
authored
Fix peek reregister after would block (#1895)
1 parent 80896b3 commit 8e5a4b5

File tree

3 files changed

+131
-1
lines changed

3 files changed

+131
-1
lines changed

src/net/tcp/stream.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ impl TcpStream {
212212
/// Successive calls return the same data. This is accomplished by passing
213213
/// `MSG_PEEK` as a flag to the underlying recv system call.
214214
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
215-
self.inner.peek(buf)
215+
// Need to re-register if `peek` returns `WouldBlock`
216+
// to ensure the socket will receive more events once it is ready again.
217+
self.inner.do_io(|inner| inner.peek(buf))
216218
}
217219

218220
/// Execute an I/O operation ensuring that the socket receives more events

tests/tcp_stream.rs

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,3 +876,129 @@ fn send_oob_data<S: AsRawFd>(stream: &S, data: &[u8]) -> io::Result<usize> {
876876
}
877877
}
878878
}
879+
880+
#[test]
881+
fn peek_ok() {
882+
let mut buf = [0; 2];
883+
let (mut poll, mut events) = init_with_poll();
884+
885+
let listener = net::TcpListener::bind(any_local_address()).unwrap();
886+
let sockaddr = listener.local_addr().unwrap();
887+
let thread_handle = thread::spawn(move || listener.accept().unwrap());
888+
let stream1 = net::TcpStream::connect(sockaddr).unwrap();
889+
let (mut stream2, _) = thread_handle.join().unwrap();
890+
891+
stream1.set_nonblocking(true).unwrap();
892+
let mut stream1 = TcpStream::from_std(stream1);
893+
894+
poll.registry()
895+
.register(&mut stream1, ID1, Interest::READABLE)
896+
.unwrap();
897+
898+
expect_no_events(&mut poll, &mut events);
899+
900+
assert_eq!(stream2.write(&[0]).unwrap(), 1);
901+
// peek multiple times until we get a byte
902+
peek_until_ok(&mut buf, &mut stream1, 1);
903+
// a successful peek shouldn't remove readable interest
904+
// so we should still get a readable event
905+
expect_events(
906+
&mut poll,
907+
&mut events,
908+
vec![ExpectEvent::new(ID1, Readiness::READABLE)],
909+
);
910+
}
911+
912+
fn peek_until_ok<const N: usize>(buf: &mut [u8; N], stream1: &mut TcpStream, expected: usize) {
913+
loop {
914+
let res = stream1.peek(buf);
915+
match res {
916+
Ok(x) => {
917+
assert_eq!(x, expected);
918+
break;
919+
}
920+
Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
921+
_ => panic!("Unexpected error: {:?}", res),
922+
}
923+
}
924+
}
925+
926+
#[test]
927+
fn peek_would_block() {
928+
let mut buf = [0; 1];
929+
let (mut poll, mut events) = init_with_poll();
930+
931+
let listener = net::TcpListener::bind(any_local_address()).unwrap();
932+
let sockaddr = listener.local_addr().unwrap();
933+
let thread_handle = thread::spawn(move || listener.accept().unwrap());
934+
let stream1 = net::TcpStream::connect(sockaddr).unwrap();
935+
let (mut stream2, _) = thread_handle.join().unwrap();
936+
937+
stream1.set_nonblocking(true).unwrap();
938+
let mut stream1 = TcpStream::from_std(stream1);
939+
940+
poll.registry()
941+
.register(&mut stream1, ID1, Interest::READABLE)
942+
.unwrap();
943+
944+
expect_no_events(&mut poll, &mut events);
945+
946+
assert_eq!(stream2.write(&[0]).unwrap(), 1);
947+
expect_events(
948+
&mut poll,
949+
&mut events,
950+
vec![ExpectEvent::new(ID1, Readiness::READABLE)],
951+
);
952+
953+
assert_eq!(stream1.read(&mut buf).unwrap(), 1);
954+
assert_would_block(stream1.peek(&mut buf));
955+
956+
assert_eq!(stream2.write(&[0, 1, 2, 3]).unwrap(), 4);
957+
958+
expect_events(
959+
&mut poll,
960+
&mut events,
961+
vec![ExpectEvent::new(ID1, Readiness::READABLE)],
962+
);
963+
}
964+
965+
#[test]
966+
fn read_peek_would_block() {
967+
let mut buf = [0; 1];
968+
let (mut poll, mut events) = init_with_poll();
969+
970+
let listener = net::TcpListener::bind(any_local_address()).unwrap();
971+
let sockaddr = listener.local_addr().unwrap();
972+
let thread_handle = thread::spawn(move || listener.accept().unwrap());
973+
let stream1 = net::TcpStream::connect(sockaddr).unwrap();
974+
let (mut stream2, _) = thread_handle.join().unwrap();
975+
976+
stream1.set_nonblocking(true).unwrap();
977+
let mut stream1 = TcpStream::from_std(stream1);
978+
979+
poll.registry()
980+
.register(&mut stream1, ID1, Interest::READABLE)
981+
.unwrap();
982+
983+
assert_would_block(stream1.read(&mut buf));
984+
985+
assert_eq!(stream2.write(&[0]).unwrap(), 1);
986+
987+
expect_events(
988+
&mut poll,
989+
&mut events,
990+
vec![ExpectEvent::new(ID1, Readiness::READABLE)],
991+
);
992+
993+
assert_eq!(stream1.read(&mut buf).unwrap(), 1);
994+
995+
assert_would_block(stream1.peek(&mut buf));
996+
997+
assert_eq!(stream2.write(&[1]).unwrap(), 1);
998+
999+
expect_events(
1000+
&mut poll,
1001+
&mut events,
1002+
vec![ExpectEvent::new(ID1, Readiness::READABLE)],
1003+
);
1004+
}

tests/util/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ impl From<Interest> for Readiness {
134134
}
135135
}
136136

137+
#[track_caller]
137138
pub fn expect_events(poll: &mut Poll, events: &mut Events, mut expected: Vec<ExpectEvent>) {
138139
// In a lot of calls we expect more then one event, but it could be that
139140
// poll returns the first event only in a single call. To be a bit more
@@ -164,6 +165,7 @@ pub fn expect_events(poll: &mut Poll, events: &mut Events, mut expected: Vec<Exp
164165
);
165166
}
166167

168+
#[track_caller]
167169
pub fn expect_no_events(poll: &mut Poll, events: &mut Events) {
168170
poll.poll(events, Some(Duration::from_millis(50)))
169171
.expect("unable to poll");

0 commit comments

Comments
 (0)