Skip to content

Commit

Permalink
Revert to stream using channels instead of io.Pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
pyropy committed Nov 25, 2024
1 parent 5236ff2 commit 76e141f
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 103 deletions.
222 changes: 138 additions & 84 deletions p2p/transport/memory/stream.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package memory

import (
"bytes"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"time"

"github.com/libp2p/go-libp2p/core/network"
Expand All @@ -14,76 +17,164 @@ type stream struct {
id int64
conn *conn

read *io.PipeReader
write *io.PipeWriter
writeC chan []byte
wrMu sync.Mutex // Serialize Write operations
buf *bytes.Buffer // Buffer for partial reads

reset chan struct{}
close chan struct{}
closed chan struct{}
// Used by local Read to interact with remote Write.
rdRx <-chan []byte

writeErr error
// Used by local Write to interact with remote Read.
wrTx chan<- []byte

once sync.Once // Protects closing localDone
localDone chan struct{}
remoteDone <-chan struct{}

reset chan struct{}
close chan struct{}
readClosed atomic.Bool
writeClosed atomic.Bool
}

var ErrClosed = errors.New("stream closed")

func newStreamPair() (*stream, *stream) {
ra, wb := io.Pipe()
rb, wa := io.Pipe()

sa := newStream(wa, ra, network.DirOutbound)
sb := newStream(wb, rb, network.DirInbound)
io.Pipe()

cb1 := make(chan []byte, 1)
cb2 := make(chan []byte, 1)

done1 := make(chan struct{})
done2 := make(chan struct{})

sa := &stream{
id: streamCounter.Add(1),
rdRx: cb1,
wrTx: cb2,
buf: new(bytes.Buffer),
localDone: done1, remoteDone: done2,
reset: make(chan struct{}, 1),
close: make(chan struct{}, 1),
}
sb := &stream{
rdRx: cb2,
wrTx: cb1,
buf: new(bytes.Buffer),
localDone: done2, remoteDone: done1,
reset: make(chan struct{}, 1),
close: make(chan struct{}, 1),
}

return sa, sb
}

func newStream(w *io.PipeWriter, r *io.PipeReader, _ network.Direction) *stream {
func newStream(rdRx <-chan []byte, wrTx chan<- []byte, localDone chan struct{}, remoteDone <-chan struct{}) *stream {
s := &stream{
id: streamCounter.Add(1),
read: r,
write: w,
writeC: make(chan []byte),
reset: make(chan struct{}, 1),
close: make(chan struct{}, 1),
closed: make(chan struct{}),
rdRx: rdRx,
wrTx: wrTx,
localDone: localDone,
remoteDone: remoteDone,
reset: make(chan struct{}, 1),
close: make(chan struct{}, 1),
}

go s.writeLoop()
return s
}

func (s *stream) Write(p []byte) (int, error) {
cpy := make([]byte, len(p))
copy(cpy, p)
func (p *stream) Write(b []byte) (int, error) {
if p.writeClosed.Load() {
return 0, ErrClosed
}

n, err := p.write(b)
if err != nil && err != io.ErrClosedPipe {
err = &net.OpError{Op: "write", Net: "pipe", Err: err}
}
return n, err
}

func (p *stream) write(b []byte) (n int, err error) {
switch {
case isClosedChan(p.localDone):
return 0, io.ErrClosedPipe
case isClosedChan(p.remoteDone):
return 0, io.ErrClosedPipe
}

p.wrMu.Lock() // Ensure entirety of b is written together
defer p.wrMu.Unlock()

select {
case <-s.closed:
return 0, s.writeErr
case s.writeC <- cpy:
case <-p.close:
return n, ErrClosed
case <-p.reset:
return n, network.ErrReset
case p.wrTx <- b:
n += len(b)
case <-p.localDone:
return n, io.ErrClosedPipe
case <-p.remoteDone:
return n, io.ErrClosedPipe
}

return len(p), nil
return n, nil
}

func (s *stream) Read(p []byte) (int, error) {
return s.read.Read(p)
func (p *stream) Read(b []byte) (int, error) {
if p.readClosed.Load() {
return 0, ErrClosed
}

n, err := p.read(b)
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
err = &net.OpError{Op: "read", Net: "pipe", Err: err}
}

return n, err
}

func (p *stream) read(b []byte) (n int, err error) {
switch {
case isClosedChan(p.localDone):
return 0, io.ErrClosedPipe
case isClosedChan(p.remoteDone):
return 0, io.EOF
}

select {
case <-p.reset:
return n, network.ErrReset
case bw, ok := <-p.rdRx:
if !ok {
p.readClosed.Store(true)
return 0, io.EOF
}

p.buf.Write(bw)
case <-p.localDone:
return 0, io.ErrClosedPipe
case <-p.remoteDone:
return 0, io.EOF
default:
n, err = p.buf.Read(b)
}

return n, err
}

func (s *stream) CloseWrite() error {
select {
case s.close <- struct{}{}:
default:
}
<-s.closed
if !errors.Is(s.writeErr, ErrClosed) {
return s.writeErr
}
return nil

s.writeClosed.Store(true)
return nil
}

func (s *stream) CloseRead() error {
return s.read.CloseWithError(ErrClosed)
s.readClosed.Store(true)
return nil
}

func (s *stream) Close() error {
Expand All @@ -92,15 +183,15 @@ func (s *stream) Close() error {
}

func (s *stream) Reset() error {
// Cancel any pending reads/writes with an error.
s.write.CloseWithError(network.ErrReset)
s.read.CloseWithError(network.ErrReset)

select {
case s.reset <- struct{}{}:
default:
}
<-s.closed

s.once.Do(func() {
close(s.localDone)
})

// No meaningful error case here.
return nil
}
Expand All @@ -117,48 +208,11 @@ func (s *stream) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}

func (s *stream) writeLoop() {
defer s.teardown()

for {
// Reset takes precedent.
select {
case <-s.reset:
s.writeErr = network.ErrReset
return
default:
}

select {
case <-s.reset:
s.writeErr = network.ErrReset
return
case <-s.close:
s.writeErr = s.write.Close()
if s.writeErr == nil {
s.writeErr = ErrClosed
}
return
case p := <-s.writeC:
if _, err := s.write.Write(p); err != nil {
s.cancelWrite(err)
return
}
}
}
}

func (s *stream) cancelWrite(err error) {
s.write.CloseWithError(err)
s.writeErr = err
}

func (s *stream) teardown() {
// at this point, no streams are writing.
if s.conn != nil {
s.conn.removeStream(s.id)
func isClosedChan(c <-chan struct{}) bool {
select {
case <-c:
return true
default:
return false
}

// Mark as closed.
close(s.closed)
}
36 changes: 17 additions & 19 deletions p2p/transport/memory/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,62 +8,60 @@ import (

func TestStreamSimpleReadWriteClose(t *testing.T) {
t.Parallel()
clientStr, serverStr := newStreamPair()
streamLocal, streamRemote := newStreamPair()

// send a foobar from the client
n, err := clientStr.Write([]byte("foobar"))
n, err := streamLocal.Write([]byte("foobar"))
require.NoError(t, err)
require.Equal(t, 6, n)
require.NoError(t, clientStr.CloseWrite())
require.NoError(t, streamLocal.CloseWrite())

// writing after closing should error
_, err = clientStr.Write([]byte("foobar"))
_, err = streamLocal.Write([]byte("foobar"))
require.Error(t, err)

// now read all the data on the server side
b, err := io.ReadAll(serverStr)
b, err := io.ReadAll(streamRemote)
require.NoError(t, err)
require.Equal(t, []byte("foobar"), b)

// reading again should give another io.EOF
n, err = serverStr.Read(make([]byte, 10))
n, err = streamRemote.Read(make([]byte, 10))
require.Zero(t, n)
require.ErrorIs(t, err, io.EOF)

// send something back
_, err = serverStr.Write([]byte("lorem ipsum"))
_, err = streamRemote.Write([]byte("lorem ipsum"))
require.NoError(t, err)
require.NoError(t, serverStr.CloseWrite())
require.NoError(t, streamRemote.CloseWrite())

// and read it at the client
b, err = io.ReadAll(clientStr)
b, err = io.ReadAll(streamLocal)
require.NoError(t, err)
require.Equal(t, []byte("lorem ipsum"), b)

// stream is only cleaned up on calling Close or Reset
clientStr.Close()
serverStr.Close()
// Need to call Close for cleanup. Otherwise the FIN_ACK is never read
require.NoError(t, serverStr.Close())
require.NoError(t, streamLocal.Close())
require.NoError(t, streamRemote.Close())
}

func TestStreamPartialReads(t *testing.T) {
t.Parallel()
clientStr, serverStr := newStreamPair()
streamLocal, streamRemote := newStreamPair()

_, err := serverStr.Write([]byte("foobar"))
_, err := streamRemote.Write([]byte("foobar"))
require.NoError(t, err)
require.NoError(t, serverStr.CloseWrite())
require.NoError(t, streamRemote.CloseWrite())

n, err := clientStr.Read([]byte{}) // empty read
n, err := streamLocal.Read([]byte{}) // empty read
require.NoError(t, err)
require.Zero(t, n)
b := make([]byte, 3)
n, err = clientStr.Read(b)
n, err = streamLocal.Read(b)
require.Equal(t, 3, n)
require.NoError(t, err)
require.Equal(t, []byte("foo"), b)
b, err = io.ReadAll(clientStr)
b, err = io.ReadAll(streamLocal)
require.NoError(t, err)
require.Equal(t, []byte("bar"), b)
}

0 comments on commit 76e141f

Please sign in to comment.