diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index 6aa59bfc7b..823513a279 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -1,9 +1,12 @@ package memory import ( + "bytes" "errors" "io" "net" + "sync" + "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/network" @@ -14,59 +17,149 @@ type stream struct { id int64 conn *conn - read *io.PipeReader - write *io.PipeWriter - writeC chan []byte + wrMu sync.Mutex // Serialize Write operations + buf *bytes.Buffer // Buffer for partial reads - reset chan struct{} - close chan struct{} - closed chan struct{} + // Used by local Read to interact with remote Write. + rdRx <-chan []byte - writeErr error + // Used by local Write to interact with remote Read. + wrTx chan<- []byte + + once sync.Once // Protects closing localDone + localDone chan struct{} + remoteDone <-chan struct{} + + reset chan struct{} + close chan struct{} + readClosed atomic.Bool + writeClosed atomic.Bool } var ErrClosed = errors.New("stream closed") func newStreamPair() (*stream, *stream) { - ra, wb := io.Pipe() - rb, wa := io.Pipe() - - sa := newStream(wa, ra, network.DirOutbound) - sb := newStream(wb, rb, network.DirInbound) + io.Pipe() + + cb1 := make(chan []byte, 1) + cb2 := make(chan []byte, 1) + + done1 := make(chan struct{}) + done2 := make(chan struct{}) + + sa := &stream{ + id: streamCounter.Add(1), + rdRx: cb1, + wrTx: cb2, + buf: new(bytes.Buffer), + localDone: done1, remoteDone: done2, + reset: make(chan struct{}, 1), + close: make(chan struct{}, 1), + } + sb := &stream{ + rdRx: cb2, + wrTx: cb1, + buf: new(bytes.Buffer), + localDone: done2, remoteDone: done1, + reset: make(chan struct{}, 1), + close: make(chan struct{}, 1), + } return sa, sb } -func newStream(w *io.PipeWriter, r *io.PipeReader, _ network.Direction) *stream { +func newStream(rdRx <-chan []byte, wrTx chan<- []byte, localDone chan struct{}, remoteDone <-chan struct{}) *stream { s := &stream{ - id: streamCounter.Add(1), - read: r, - write: w, - writeC: make(chan []byte), - reset: make(chan struct{}, 1), - close: make(chan struct{}, 1), - closed: make(chan struct{}), + rdRx: rdRx, + wrTx: wrTx, + localDone: localDone, + remoteDone: remoteDone, + reset: make(chan struct{}, 1), + close: make(chan struct{}, 1), } - go s.writeLoop() return s } -func (s *stream) Write(p []byte) (int, error) { - cpy := make([]byte, len(p)) - copy(cpy, p) +func (p *stream) Write(b []byte) (int, error) { + if p.writeClosed.Load() { + return 0, ErrClosed + } + + n, err := p.write(b) + if err != nil && err != io.ErrClosedPipe { + err = &net.OpError{Op: "write", Net: "pipe", Err: err} + } + return n, err +} + +func (p *stream) write(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.ErrClosedPipe + } + + p.wrMu.Lock() // Ensure entirety of b is written together + defer p.wrMu.Unlock() select { - case <-s.closed: - return 0, s.writeErr - case s.writeC <- cpy: + case <-p.close: + return n, ErrClosed + case <-p.reset: + return n, network.ErrReset + case p.wrTx <- b: + n += len(b) + case <-p.localDone: + return n, io.ErrClosedPipe + case <-p.remoteDone: + return n, io.ErrClosedPipe } - return len(p), nil + return n, nil } -func (s *stream) Read(p []byte) (int, error) { - return s.read.Read(p) +func (p *stream) Read(b []byte) (int, error) { + if p.readClosed.Load() { + return 0, ErrClosed + } + + n, err := p.read(b) + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + err = &net.OpError{Op: "read", Net: "pipe", Err: err} + } + + return n, err +} + +func (p *stream) read(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.EOF + } + + select { + case <-p.reset: + return n, network.ErrReset + case bw, ok := <-p.rdRx: + if !ok { + p.readClosed.Store(true) + return 0, io.EOF + } + + p.buf.Write(bw) + case <-p.localDone: + return 0, io.ErrClosedPipe + case <-p.remoteDone: + return 0, io.EOF + default: + n, err = p.buf.Read(b) + } + + return n, err } func (s *stream) CloseWrite() error { @@ -74,16 +167,14 @@ func (s *stream) CloseWrite() error { case s.close <- struct{}{}: default: } - <-s.closed - if !errors.Is(s.writeErr, ErrClosed) { - return s.writeErr - } - return nil + s.writeClosed.Store(true) + return nil } func (s *stream) CloseRead() error { - return s.read.CloseWithError(ErrClosed) + s.readClosed.Store(true) + return nil } func (s *stream) Close() error { @@ -92,15 +183,15 @@ func (s *stream) Close() error { } func (s *stream) Reset() error { - // Cancel any pending reads/writes with an error. - s.write.CloseWithError(network.ErrReset) - s.read.CloseWithError(network.ErrReset) - select { case s.reset <- struct{}{}: default: } - <-s.closed + + s.once.Do(func() { + close(s.localDone) + }) + // No meaningful error case here. return nil } @@ -117,48 +208,11 @@ func (s *stream) SetWriteDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -func (s *stream) writeLoop() { - defer s.teardown() - - for { - // Reset takes precedent. - select { - case <-s.reset: - s.writeErr = network.ErrReset - return - default: - } - - select { - case <-s.reset: - s.writeErr = network.ErrReset - return - case <-s.close: - s.writeErr = s.write.Close() - if s.writeErr == nil { - s.writeErr = ErrClosed - } - return - case p := <-s.writeC: - if _, err := s.write.Write(p); err != nil { - s.cancelWrite(err) - return - } - } - } -} - -func (s *stream) cancelWrite(err error) { - s.write.CloseWithError(err) - s.writeErr = err -} - -func (s *stream) teardown() { - // at this point, no streams are writing. - if s.conn != nil { - s.conn.removeStream(s.id) +func isClosedChan(c <-chan struct{}) bool { + select { + case <-c: + return true + default: + return false } - - // Mark as closed. - close(s.closed) } diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index 1f3e8fb0cd..f154bbbcdd 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -8,62 +8,60 @@ import ( func TestStreamSimpleReadWriteClose(t *testing.T) { t.Parallel() - clientStr, serverStr := newStreamPair() + streamLocal, streamRemote := newStreamPair() // send a foobar from the client - n, err := clientStr.Write([]byte("foobar")) + n, err := streamLocal.Write([]byte("foobar")) require.NoError(t, err) require.Equal(t, 6, n) - require.NoError(t, clientStr.CloseWrite()) + require.NoError(t, streamLocal.CloseWrite()) // writing after closing should error - _, err = clientStr.Write([]byte("foobar")) + _, err = streamLocal.Write([]byte("foobar")) require.Error(t, err) // now read all the data on the server side - b, err := io.ReadAll(serverStr) + b, err := io.ReadAll(streamRemote) require.NoError(t, err) require.Equal(t, []byte("foobar"), b) // reading again should give another io.EOF - n, err = serverStr.Read(make([]byte, 10)) + n, err = streamRemote.Read(make([]byte, 10)) require.Zero(t, n) require.ErrorIs(t, err, io.EOF) // send something back - _, err = serverStr.Write([]byte("lorem ipsum")) + _, err = streamRemote.Write([]byte("lorem ipsum")) require.NoError(t, err) - require.NoError(t, serverStr.CloseWrite()) + require.NoError(t, streamRemote.CloseWrite()) // and read it at the client - b, err = io.ReadAll(clientStr) + b, err = io.ReadAll(streamLocal) require.NoError(t, err) require.Equal(t, []byte("lorem ipsum"), b) // stream is only cleaned up on calling Close or Reset - clientStr.Close() - serverStr.Close() - // Need to call Close for cleanup. Otherwise the FIN_ACK is never read - require.NoError(t, serverStr.Close()) + require.NoError(t, streamLocal.Close()) + require.NoError(t, streamRemote.Close()) } func TestStreamPartialReads(t *testing.T) { t.Parallel() - clientStr, serverStr := newStreamPair() + streamLocal, streamRemote := newStreamPair() - _, err := serverStr.Write([]byte("foobar")) + _, err := streamRemote.Write([]byte("foobar")) require.NoError(t, err) - require.NoError(t, serverStr.CloseWrite()) + require.NoError(t, streamRemote.CloseWrite()) - n, err := clientStr.Read([]byte{}) // empty read + n, err := streamLocal.Read([]byte{}) // empty read require.NoError(t, err) require.Zero(t, n) b := make([]byte, 3) - n, err = clientStr.Read(b) + n, err = streamLocal.Read(b) require.Equal(t, 3, n) require.NoError(t, err) require.Equal(t, []byte("foo"), b) - b, err = io.ReadAll(clientStr) + b, err = io.ReadAll(streamLocal) require.NoError(t, err) require.Equal(t, []byte("bar"), b) }