From 39983c2a3f2845b6c2dc7983591f458ad23ce652 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Tue, 22 Oct 2024 21:05:46 +0200 Subject: [PATCH 01/12] Add memory transport --- p2p/transport/memory/conn.go | 76 ++++++++++++++++++ p2p/transport/memory/stream.go | 124 ++++++++++++++++++++++++++++++ p2p/transport/memory/transport.go | 7 ++ 3 files changed, 207 insertions(+) create mode 100644 p2p/transport/memory/conn.go create mode 100644 p2p/transport/memory/stream.go create mode 100644 p2p/transport/memory/transport.go diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go new file mode 100644 index 0000000000..2665f081bb --- /dev/null +++ b/p2p/transport/memory/conn.go @@ -0,0 +1,76 @@ +package memory + +import ( + "context" + "sync" + + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" +) + +type conn struct { + transport *transport + scope network.ConnManagementScope + + localPeer peer.ID + localMultiaddr ma.Multiaddr + + remotePeerID peer.ID + remotePubKey ic.PubKey + remoteMultiaddr ma.Multiaddr + + closed bool + closeOnce sync.Once +} + +var _ tpt.CapableConn = &conn{} + +func (c *conn) Close() error { + c.closeOnce.Do(func() { + c.closed = true + c.transport.removeConn(c) + }) + + return nil +} + +func (c *conn) IsClosed() bool { + return c.closed +} + +func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { + return newStream(), nil +} + +func (c *conn) AcceptStream() (network.MuxedStream, error) { + return nil, nil +} + +func (c *conn) LocalPeer() peer.ID { return c.localPeer } + +// RemotePeer returns the peer ID of the remote peer. +func (c *conn) RemotePeer() peer.ID { return c.remotePeerID } + +// RemotePublicKey returns the public key of the remote peer. +func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } + +// LocalMultiaddr returns the local Multiaddr associated +func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.localMultiaddr } + +// RemoteMultiaddr returns the remote Multiaddr associated +func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remoteMultiaddr } + +func (c *conn) Transport() tpt.Transport { + // TODO: return c.transport + return nil +} + +func (c *conn) Scope() network.ConnScope { return c.scope } + +// ConnState is the state of security connection. +func (c *conn) ConnState() network.ConnectionState { + return network.ConnectionState{Transport: "memory"} +} diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go new file mode 100644 index 0000000000..6b8555dcba --- /dev/null +++ b/p2p/transport/memory/stream.go @@ -0,0 +1,124 @@ +package memory + +import ( + "errors" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/network" +) + +type stream struct { + inC <-chan []byte + outC chan<- []byte + + readCloseC chan struct{} + writeCloseC chan struct{} + + mu sync.Mutex + closed bool + + deadline time.Time + readDeadline time.Time + writeDeadline time.Time +} + +func newStream() *stream { + return &stream{ + inC: make(<-chan []byte), + outC: make(chan<- []byte), + readCloseC: make(chan struct{}), + writeCloseC: make(chan struct{}), + } +} + +func (s *stream) Read(b []byte) (n int, err error) { + if s.closed { + return 0, network.ErrReset + } + + select { + case <-s.readCloseC: + err = network.ErrReset + case r, ok := <-s.inC: + if !ok { + err = network.ErrReset + } else { + n = copy(b, r) + } + } + + return n, err +} + +func (s *stream) Write(b []byte) (n int, err error) { + select { + case <-s.writeCloseC: + err = network.ErrReset + case s.outC <- b: + n = len(b) + default: + err = network.ErrReset + } + + return n, err +} + +func (s *stream) Reset() error { + s.CloseWrite() + s.CloseRead() + return nil +} + +func (s *stream) Close() error { + s.CloseRead() + + s.mu.Lock() + s.closed = true + s.mu.Unlock() + + return nil +} + +func (s *stream) CloseRead() error { + select { + case s.readCloseC <- struct{}{}: + default: + return errors.New("failed to close stream read") + } + + return nil +} + +func (s *stream) CloseWrite() error { + select { + case s.writeCloseC <- struct{}{}: + default: + return errors.New("failed to close stream write") + } + + return nil +} + +func (s *stream) SetDeadline(deadline time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.deadline = deadline + return nil +} + +func (s *stream) SetReadDeadline(readDeadline time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.readDeadline = readDeadline + return nil +} +func (s *stream) SetWriteDeadline(writeDeadline time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.writeDeadline = writeDeadline + return nil +} diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go new file mode 100644 index 0000000000..d5850d7fb0 --- /dev/null +++ b/p2p/transport/memory/transport.go @@ -0,0 +1,7 @@ +package memory + +type transport struct { +} + +func (t *transport) removeConn(c *conn) { +} From 37899d3c0d9cd1fdeb54b6e955deb74a833757c7 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Wed, 23 Oct 2024 22:46:00 +0200 Subject: [PATCH 02/12] Daily commit --- p2p/transport/memory/conn.go | 65 ++++++++++++++++++--- p2p/transport/memory/listener.go | 62 +++++++++++++++++++++ p2p/transport/memory/stream.go | 63 ++++++++++----------- p2p/transport/memory/transport.go | 93 +++++++++++++++++++++++++++++++ 4 files changed, 240 insertions(+), 43 deletions(-) create mode 100644 p2p/transport/memory/listener.go diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index 2665f081bb..6dac7f87a1 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -3,6 +3,7 @@ package memory import ( "context" "sync" + "sync/atomic" ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" @@ -12,6 +13,8 @@ import ( ) type conn struct { + id int32 + transport *transport scope network.ConnManagementScope @@ -22,15 +25,35 @@ type conn struct { remotePubKey ic.PubKey remoteMultiaddr ma.Multiaddr - closed bool + isClosed atomic.Bool closeOnce sync.Once + + mu sync.Mutex + + streamC chan *stream + + nextStreamID atomic.Int32 + streams map[int32]network.MuxedStream } var _ tpt.CapableConn = &conn{} +func newConnection(id int32, s *stream) *conn { + c := &conn{ + id: id, + streamC: make(chan *stream, 1), + streams: make(map[int32]network.MuxedStream), + } + + streamID := c.nextStreamID.Add(1) + c.addStream(streamID, s) + + return c +} + func (c *conn) Close() error { c.closeOnce.Do(func() { - c.closed = true + c.isClosed.Store(true) c.transport.removeConn(c) }) @@ -38,15 +61,26 @@ func (c *conn) Close() error { } func (c *conn) IsClosed() bool { - return c.closed + return c.isClosed.Load() } func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { - return newStream(), nil + id := c.nextStreamID.Add(1) + ra := make(chan []byte) + wa := make(chan []byte) + + return newStream(id, ra, wa), nil } func (c *conn) AcceptStream() (network.MuxedStream, error) { - return nil, nil + select { + case in := <-c.streamC: + id := c.nextStreamID.Add(1) + s := newStream(id, in.outC, in.inC) + c.addStream(id, s) + + return s, nil + } } func (c *conn) LocalPeer() peer.ID { return c.localPeer } @@ -64,13 +98,28 @@ func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.localMultiaddr } func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remoteMultiaddr } func (c *conn) Transport() tpt.Transport { - // TODO: return c.transport - return nil + return c.transport } -func (c *conn) Scope() network.ConnScope { return c.scope } +func (c *conn) Scope() network.ConnScope { + return c.scope +} // ConnState is the state of security connection. func (c *conn) ConnState() network.ConnectionState { return network.ConnectionState{Transport: "memory"} } + +func (c *conn) addStream(id int32, stream network.MuxedStream) { + c.mu.Lock() + defer c.mu.Unlock() + + c.streams[id] = stream +} + +func (c *conn) removeStream(id int32) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.streams, id) +} diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go new file mode 100644 index 0000000000..a53f317815 --- /dev/null +++ b/p2p/transport/memory/listener.go @@ -0,0 +1,62 @@ +package memory + +import ( + "context" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" + "net" + "sync" + "sync/atomic" +) + +type listener struct { + ctx context.Context + cancel context.CancelFunc + laddr ma.Multiaddr + + mu sync.Mutex + connID atomic.Int32 + streamCh chan *stream + connections map[int32]*conn +} + +func (l *listener) Multiaddr() ma.Multiaddr { + return l.laddr +} + +func newListener(laddr ma.Multiaddr, streamCh chan *stream) tpt.Listener { + ctx, cancel := context.WithCancel(context.Background()) + return &listener{ + ctx: ctx, + cancel: cancel, + laddr: laddr, + streamCh: streamCh, + } +} + +// Accept accepts new connections. +func (l *listener) Accept() (tpt.CapableConn, error) { + select { + case s := <-l.streamCh: + l.mu.Lock() + defer l.mu.Unlock() + + id := l.connID.Add(1) + c := newConnection(id, s) + l.connections[id] = c + return nil, nil + case <-l.ctx.Done(): + return nil, l.ctx.Err() + } +} + +// Close closes the listener. +func (l *listener) Close() error { + l.cancel() + return nil +} + +// Addr returns the address of this listener. +func (l *listener) Addr() net.Addr { + return nil +} diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index 6b8555dcba..e816daf952 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -2,38 +2,36 @@ package memory import ( "errors" - "sync" + "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/network" ) type stream struct { - inC <-chan []byte - outC chan<- []byte + id int32 + + inC chan []byte + outC chan []byte readCloseC chan struct{} writeCloseC chan struct{} - mu sync.Mutex - closed bool - - deadline time.Time - readDeadline time.Time - writeDeadline time.Time + closed atomic.Bool } -func newStream() *stream { +func newStream(id int32, in, out chan []byte) *stream { return &stream{ - inC: make(<-chan []byte), - outC: make(chan<- []byte), + id: id, + inC: in, + outC: out, readCloseC: make(chan struct{}), writeCloseC: make(chan struct{}), } } func (s *stream) Read(b []byte) (n int, err error) { - if s.closed { + if s.closed.Load() { return 0, network.ErrReset } @@ -52,6 +50,10 @@ func (s *stream) Read(b []byte) (n int, err error) { } func (s *stream) Write(b []byte) (n int, err error) { + if s.closed.Load() { + return 0, network.ErrReset + } + select { case <-s.writeCloseC: err = network.ErrReset @@ -65,18 +67,21 @@ func (s *stream) Write(b []byte) (n int, err error) { } func (s *stream) Reset() error { - s.CloseWrite() - s.CloseRead() + if err := s.CloseWrite(); err != nil { + return err + } + if err := s.CloseRead(); err != nil { + return err + } return nil } func (s *stream) Close() error { - s.CloseRead() - - s.mu.Lock() - s.closed = true - s.mu.Unlock() + if err := s.CloseRead(); err != nil { + return err + } + s.closed.Store(true) return nil } @@ -100,25 +105,13 @@ func (s *stream) CloseWrite() error { return nil } -func (s *stream) SetDeadline(deadline time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - - s.deadline = deadline +func (s *stream) SetDeadline(_ time.Time) error { return nil } -func (s *stream) SetReadDeadline(readDeadline time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - - s.readDeadline = readDeadline +func (s *stream) SetReadDeadline(_ time.Time) error { return nil } -func (s *stream) SetWriteDeadline(writeDeadline time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - - s.writeDeadline = writeDeadline +func (s *stream) SetWriteDeadline(_ time.Time) error { return nil } diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index d5850d7fb0..e7f0f27e02 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -1,7 +1,100 @@ package memory +import ( + "context" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" + "sync" + "sync/atomic" +) + type transport struct { + rcmgr network.ResourceManager + + mu sync.RWMutex + + connID atomic.Int32 + listeners map[ma.Multiaddr]*listener + connections map[int32]*conn +} + +func NewTransport() *transport { + return &transport{ + connections: make(map[int32]*conn), + } +} + +func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) + if err != nil { + return nil, err + } + + c, err := t.dialWithScope(ctx, raddr, p, scope) + if err != nil { + return nil, err + } + + return c, nil +} + +func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { + if err := scope.SetPeer(p); err != nil { + return nil, err + } + + // TODO: Check if there is an existing listener for this address + t.mu.RLock() + defer t.mu.RUnlock() + l := t.listeners[raddr] + + in := make(chan []byte) + out := make(chan []byte) + s := newStream(0, in, out) + l.streamCh <- s + + return newConnection(0, s), nil +} + +func (t *transport) CanDial(addr ma.Multiaddr) bool { + return true +} + +func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { + // TODO: Figure out correct channel type + return newListener(laddr, nil), nil +} + +func (t *transport) Proxy() bool { + return false +} + +// Protocols returns the set of protocols handled by this transport. +func (t *transport) Protocols() []int { + return []int{777} +} + +func (t *transport) String() string { + return "MemoryTransport" +} + +func (t *transport) Close() error { + // TODO: Go trough all listeners and close them + return nil +} + +func (t *transport) addConn(c *conn) { + t.mu.Lock() + defer t.mu.Unlock() + + t.connections[c.id] = c } func (t *transport) removeConn(c *conn) { + t.mu.Lock() + defer t.mu.Unlock() + + delete(t.connections, c.id) } From 0df38f9190159806c237412c955d767f68a7d1d8 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Thu, 24 Oct 2024 22:03:54 +0200 Subject: [PATCH 03/12] Daily commit --- p2p/test/transport/transport_test.go | 16 ++++++++ p2p/transport/memory/conn.go | 12 +++--- p2p/transport/memory/listener.go | 24 ++++++----- p2p/transport/memory/stream.go | 39 +++++++----------- p2p/transport/memory/stream_test.go | 55 ++++++++++++++++++++++++++ p2p/transport/memory/transport.go | 59 ++++++++++++++++++++++------ 6 files changed, 151 insertions(+), 54 deletions(-) create mode 100644 p2p/transport/memory/stream_test.go diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 7cfab5f3ca..e353ba6526 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -31,6 +31,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" + libp2pmemory "github.com/libp2p/go-libp2p/p2p/transport/memory" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "go.uber.org/mock/gomock" @@ -156,6 +157,21 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "Memory", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.Transport(libp2pmemory.NewTransport)) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/memory/1234")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, } func TestPing(t *testing.T) { diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index 6dac7f87a1..b01f05ed1a 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -2,6 +2,7 @@ package memory import ( "context" + "io" "sync" "sync/atomic" @@ -66,8 +67,8 @@ func (c *conn) IsClosed() bool { func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { id := c.nextStreamID.Add(1) - ra := make(chan []byte) - wa := make(chan []byte) + // TODO: Figure out how to exchange the pipes between the two streams + ra, wa := io.Pipe() return newStream(id, ra, wa), nil } @@ -76,10 +77,9 @@ func (c *conn) AcceptStream() (network.MuxedStream, error) { select { case in := <-c.streamC: id := c.nextStreamID.Add(1) - s := newStream(id, in.outC, in.inC) - c.addStream(id, s) + c.addStream(id, in) - return s, nil + return in, nil } } @@ -88,7 +88,7 @@ func (c *conn) LocalPeer() peer.ID { return c.localPeer } // RemotePeer returns the peer ID of the remote peer. func (c *conn) RemotePeer() peer.ID { return c.remotePeerID } -// RemotePublicKey returns the public key of the remote peer. +// RemotePublicKey returns the public pkey of the remote peer. func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } // LocalMultiaddr returns the local Multiaddr associated diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go index a53f317815..8041e02aae 100644 --- a/p2p/transport/memory/listener.go +++ b/p2p/transport/memory/listener.go @@ -6,17 +6,16 @@ import ( ma "github.com/multiformats/go-multiaddr" "net" "sync" - "sync/atomic" ) type listener struct { + t *transport ctx context.Context cancel context.CancelFunc laddr ma.Multiaddr mu sync.Mutex - connID atomic.Int32 - streamCh chan *stream + connCh chan *conn connections map[int32]*conn } @@ -24,27 +23,26 @@ func (l *listener) Multiaddr() ma.Multiaddr { return l.laddr } -func newListener(laddr ma.Multiaddr, streamCh chan *stream) tpt.Listener { +func newListener(t *transport, laddr ma.Multiaddr) *listener { ctx, cancel := context.WithCancel(context.Background()) return &listener{ - ctx: ctx, - cancel: cancel, - laddr: laddr, - streamCh: streamCh, + t: t, + ctx: ctx, + cancel: cancel, + laddr: laddr, + connCh: make(chan *conn, listenerQueueSize), } } // Accept accepts new connections. func (l *listener) Accept() (tpt.CapableConn, error) { select { - case s := <-l.streamCh: + case c := <-l.connCh: l.mu.Lock() defer l.mu.Unlock() - id := l.connID.Add(1) - c := newConnection(id, s) - l.connections[id] = c - return nil, nil + l.connections[c.id] = c + return c, nil case <-l.ctx.Done(): return nil, l.ctx.Err() } diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index e816daf952..4e425ee5af 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -2,6 +2,7 @@ package memory import ( "errors" + "io" "sync/atomic" "time" @@ -11,8 +12,8 @@ import ( type stream struct { id int32 - inC chan []byte - outC chan []byte + r *io.PipeReader + w *io.PipeWriter readCloseC chan struct{} writeCloseC chan struct{} @@ -20,50 +21,40 @@ type stream struct { closed atomic.Bool } -func newStream(id int32, in, out chan []byte) *stream { +func newStream(id int32, r *io.PipeReader, w *io.PipeWriter) *stream { return &stream{ id: id, - inC: in, - outC: out, - readCloseC: make(chan struct{}), - writeCloseC: make(chan struct{}), + r: r, + w: w, + readCloseC: make(chan struct{}, 1), + writeCloseC: make(chan struct{}, 1), } } -func (s *stream) Read(b []byte) (n int, err error) { +func (s *stream) Read(b []byte) (int, error) { if s.closed.Load() { return 0, network.ErrReset } select { case <-s.readCloseC: - err = network.ErrReset - case r, ok := <-s.inC: - if !ok { - err = network.ErrReset - } else { - n = copy(b, r) - } + return 0, network.ErrReset + default: + return s.r.Read(b) } - - return n, err } -func (s *stream) Write(b []byte) (n int, err error) { +func (s *stream) Write(b []byte) (int, error) { if s.closed.Load() { return 0, network.ErrReset } select { case <-s.writeCloseC: - err = network.ErrReset - case s.outC <- b: - n = len(b) + return 0, network.ErrReset default: - err = network.ErrReset + return s.w.Write(b) } - - return n, err } func (s *stream) Reset() error { diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go new file mode 100644 index 0000000000..844000cd9d --- /dev/null +++ b/p2p/transport/memory/stream_test.go @@ -0,0 +1,55 @@ +package memory + +import ( + "github.com/stretchr/testify/require" + "io" + "testing" +) + +func TestStreamSimpleReadWriteClose(t *testing.T) { + //client, server := getDetachedDataChannels(t) + ra, wb := io.Pipe() + rb, wa := io.Pipe() + + clientStr := newStream(0, ra, wa) + serverStr := newStream(0, rb, wb) + + // send a foobar from the client + n, err := clientStr.Write([]byte("foobar")) + require.NoError(t, err) + require.Equal(t, 6, n) + require.NoError(t, clientStr.CloseWrite()) + // writing after closing should error + _, err = clientStr.Write([]byte("foobar")) + require.Error(t, err) + //require.False(t, clientDone.Load()) + + // now read all the data on the server side + b, err := io.ReadAll(serverStr) + require.NoError(t, err) + require.Equal(t, []byte("foobar"), b) + // reading again should give another io.EOF + n, err = serverStr.Read(make([]byte, 10)) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + //require.False(t, serverDone.Load()) + + // send something back + _, err = serverStr.Write([]byte("lorem ipsum")) + require.NoError(t, err) + require.NoError(t, serverStr.CloseWrite()) + + // and read it at the client + //require.False(t, clientDone.Load()) + b, err = io.ReadAll(clientStr) + require.NoError(t, err) + require.Equal(t, []byte("lorem ipsum"), b) + + // stream is only cleaned up on calling Close or Reset + clientStr.Close() + serverStr.Close() + //require.Eventually(t, func() bool { return clientDone.Load() }, 5*time.Second, 100*time.Millisecond) + // Need to call Close for cleanup. Otherwise the FIN_ACK is never read + require.NoError(t, serverStr.Close()) + //require.Eventually(t, func() bool { return serverDone.Load() }, 5*time.Second, 100*time.Millisecond) +} diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index e7f0f27e02..02eb1d24ee 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -2,28 +2,52 @@ package memory import ( "context" + ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/pnet" tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" + "io" "sync" "sync/atomic" ) +const ( + listenerQueueSize = 16 +) + type transport struct { + pkey ic.PrivKey + pid peer.ID + psk pnet.PSK rcmgr network.ResourceManager mu sync.RWMutex connID atomic.Int32 - listeners map[ma.Multiaddr]*listener + listeners map[string]*listener connections map[int32]*conn } -func NewTransport() *transport { +func NewTransport(key ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { + if rcmgr == nil { + rcmgr = &network.NullResourceManager{} + } + + id, err := peer.IDFromPrivateKey(key) + if err != nil { + return nil, err + } + return &transport{ + rcmgr: rcmgr, + pid: id, + pkey: key, + psk: psk, + listeners: make(map[string]*listener), connections: make(map[int32]*conn), - } + }, nil } func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { @@ -48,14 +72,16 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee // TODO: Check if there is an existing listener for this address t.mu.RLock() defer t.mu.RUnlock() - l := t.listeners[raddr] + l := t.listeners[raddr.String()] - in := make(chan []byte) - out := make(chan []byte) - s := newStream(0, in, out) - l.streamCh <- s + ra, wb := io.Pipe() + rb, wa := io.Pipe() + in, out := newStream(0, ra, wb), newStream(0, rb, wa) + inId, outId := t.connID.Add(1), t.connID.Add(1) - return newConnection(0, s), nil + l.connCh <- newConnection(inId, in) + + return newConnection(outId, out), nil } func (t *transport) CanDial(addr ma.Multiaddr) bool { @@ -63,8 +89,15 @@ func (t *transport) CanDial(addr ma.Multiaddr) bool { } func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { - // TODO: Figure out correct channel type - return newListener(laddr, nil), nil + // TODO: Check if we need to add scope via conn mngr + l := newListener(t, laddr) + + t.mu.Lock() + defer t.mu.Unlock() + + t.listeners[laddr.String()] = l + + return l, nil } func (t *transport) Proxy() bool { @@ -82,6 +115,10 @@ func (t *transport) String() string { func (t *transport) Close() error { // TODO: Go trough all listeners and close them + for _, l := range t.listeners { + l.Close() + } + return nil } From 67da1924dd80f832e39850cf785701f3182ad6aa Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Wed, 6 Nov 2024 15:54:05 +0100 Subject: [PATCH 04/12] Upgrade go-multiaddr --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 41d7730d39..d37a6ff09f 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b github.com/mr-tron/base58 v1.2.0 github.com/multiformats/go-base32 v0.1.0 - github.com/multiformats/go-multiaddr v0.13.0 + github.com/multiformats/go-multiaddr v0.14.0 github.com/multiformats/go-multiaddr-dns v0.4.0 github.com/multiformats/go-multiaddr-fmt v0.1.0 github.com/multiformats/go-multibase v0.2.0 diff --git a/go.sum b/go.sum index df6db73cff..7b0c780196 100644 --- a/go.sum +++ b/go.sum @@ -233,8 +233,8 @@ github.com/multiformats/go-base32 v0.1.0/go.mod h1:Kj3tFY6zNr+ABYMqeUNeGvkIC/UYg github.com/multiformats/go-base36 v0.2.0 h1:lFsAbNOGeKtuKozrtBsAkSVhv1p9D0/qedU9rQyccr0= github.com/multiformats/go-base36 v0.2.0/go.mod h1:qvnKE++v+2MWCfePClUEjE78Z7P2a1UV0xHgWc0hkp4= github.com/multiformats/go-multiaddr v0.1.1/go.mod h1:aMKBKNEYmzmDmxfX88/vz+J5IU55txyt0p4aiWVohjo= -github.com/multiformats/go-multiaddr v0.13.0 h1:BCBzs61E3AGHcYYTv8dqRH43ZfyrqM8RXVPT8t13tLQ= -github.com/multiformats/go-multiaddr v0.13.0/go.mod h1:sBXrNzucqkFJhvKOiwwLyqamGa/P5EIXNPLovyhQCII= +github.com/multiformats/go-multiaddr v0.14.0 h1:bfrHrJhrRuh/NXH5mCnemjpbGjzRw/b+tJFOD41g2tU= +github.com/multiformats/go-multiaddr v0.14.0/go.mod h1:6EkVAxtznq2yC3QT5CM1UTAwG0GTP3EWAIcjHuzQ+r4= github.com/multiformats/go-multiaddr-dns v0.4.0 h1:P76EJ3qzBXpUXZ3twdCDx/kvagMsNo0LMFXpyms/zgU= github.com/multiformats/go-multiaddr-dns v0.4.0/go.mod h1:7hfthtB4E4pQwirrz+J0CcDUfbWzTqEzVyYKKIKpgkc= github.com/multiformats/go-multiaddr-fmt v0.1.0 h1:WLEFClPycPkp4fnIzoFoV9FVd49/eQsuaL3/CWe167E= From e2a5865925445b58964c86a225b11bc01f3e85a9 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Fri, 8 Nov 2024 09:07:55 +0100 Subject: [PATCH 05/12] Daily commit --- p2p/transport/memory/conn.go | 42 +++++---- p2p/transport/memory/listener.go | 30 ++++--- p2p/transport/memory/stream.go | 56 ++++++------ p2p/transport/memory/stream_test.go | 5 +- p2p/transport/memory/transport.go | 131 +++++++++++++++++++++------- 5 files changed, 175 insertions(+), 89 deletions(-) diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index b01f05ed1a..d864e93316 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -39,11 +39,24 @@ type conn struct { var _ tpt.CapableConn = &conn{} -func newConnection(id int32, s *stream) *conn { +func newConnection( + id int32, + s *stream, + localPeer peer.ID, + localMultiaddr ma.Multiaddr, + remotePubKey ic.PubKey, + remotePeer peer.ID, + remoteMultiaddr ma.Multiaddr, +) *conn { c := &conn{ - id: id, - streamC: make(chan *stream, 1), - streams: make(map[int32]network.MuxedStream), + id: id, + localPeer: localPeer, + localMultiaddr: localMultiaddr, + remotePubKey: remotePubKey, + remotePeerID: remotePeer, + remoteMultiaddr: remoteMultiaddr, + streamC: make(chan *stream, 1), + streams: make(map[int32]network.MuxedStream), } streamID := c.nextStreamID.Add(1) @@ -66,21 +79,20 @@ func (c *conn) IsClosed() bool { } func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { - id := c.nextStreamID.Add(1) - // TODO: Figure out how to exchange the pipes between the two streams - ra, wa := io.Pipe() + ra, wb := io.Pipe() + rb, wa := io.Pipe() + inConnId, outConnId := c.nextStreamID.Add(1), c.nextStreamID.Add(1) + inStream, outStream := newStream(inConnId, ra, wb), newStream(outConnId, rb, wa) - return newStream(id, ra, wa), nil + c.streamC <- inStream + return outStream, nil } func (c *conn) AcceptStream() (network.MuxedStream, error) { - select { - case in := <-c.streamC: - id := c.nextStreamID.Add(1) - c.addStream(id, in) - - return in, nil - } + in := <-c.streamC + id := c.nextStreamID.Add(1) + c.addStream(id, in) + return in, nil } func (c *conn) LocalPeer() peer.ID { return c.localPeer } diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go index 8041e02aae..39e8acfb29 100644 --- a/p2p/transport/memory/listener.go +++ b/p2p/transport/memory/listener.go @@ -2,10 +2,15 @@ package memory import ( "context" - tpt "github.com/libp2p/go-libp2p/core/transport" - ma "github.com/multiformats/go-multiaddr" "net" "sync" + + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" +) + +const ( + listenerQueueSize = 16 ) type listener struct { @@ -26,25 +31,30 @@ func (l *listener) Multiaddr() ma.Multiaddr { func newListener(t *transport, laddr ma.Multiaddr) *listener { ctx, cancel := context.WithCancel(context.Background()) return &listener{ - t: t, - ctx: ctx, - cancel: cancel, - laddr: laddr, - connCh: make(chan *conn, listenerQueueSize), + t: t, + ctx: ctx, + cancel: cancel, + laddr: laddr, + connCh: make(chan *conn, listenerQueueSize), + connections: make(map[int32]*conn), } } // Accept accepts new connections. func (l *listener) Accept() (tpt.CapableConn, error) { select { - case c := <-l.connCh: + case <-l.ctx.Done(): + return nil, tpt.ErrListenerClosed + case c, ok := <-l.connCh: + if !ok { + return nil, tpt.ErrListenerClosed + } + l.mu.Lock() defer l.mu.Unlock() l.connections[c.id] = c return c, nil - case <-l.ctx.Done(): - return nil, l.ctx.Err() } } diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index 4e425ee5af..101ae516da 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -1,7 +1,6 @@ package memory import ( - "errors" "io" "sync/atomic" "time" @@ -12,8 +11,9 @@ import ( type stream struct { id int32 - r *io.PipeReader - w *io.PipeWriter + r *io.PipeReader + w *io.PipeWriter + writeC chan []byte readCloseC chan struct{} writeCloseC chan struct{} @@ -22,26 +22,33 @@ type stream struct { } func newStream(id int32, r *io.PipeReader, w *io.PipeWriter) *stream { - return &stream{ + s := &stream{ id: id, r: r, w: w, + writeC: make(chan []byte, 1), readCloseC: make(chan struct{}, 1), writeCloseC: make(chan struct{}, 1), } + + go func() { + for { + select { + case b := <-s.writeC: + if _, err := w.Write(b); err != nil { + return + } + case <-s.writeCloseC: + return + } + } + }() + + return s } func (s *stream) Read(b []byte) (int, error) { - if s.closed.Load() { - return 0, network.ErrReset - } - - select { - case <-s.readCloseC: - return 0, network.ErrReset - default: - return s.r.Read(b) - } + return s.r.Read(b) } func (s *stream) Write(b []byte) (int, error) { @@ -52,8 +59,8 @@ func (s *stream) Write(b []byte) (int, error) { select { case <-s.writeCloseC: return 0, network.ErrReset - default: - return s.w.Write(b) + case s.writeC <- b: + return len(b), nil } } @@ -68,31 +75,22 @@ func (s *stream) Reset() error { } func (s *stream) Close() error { - if err := s.CloseRead(); err != nil { - return err - } - - s.closed.Store(true) + s.CloseRead() + s.CloseWrite() return nil } func (s *stream) CloseRead() error { - select { - case s.readCloseC <- struct{}{}: - default: - return errors.New("failed to close stream read") - } - - return nil + return s.r.CloseWithError(network.ErrReset) } func (s *stream) CloseWrite() error { select { case s.writeCloseC <- struct{}{}: default: - return errors.New("failed to close stream write") } + s.closed.Store(true) return nil } diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index 844000cd9d..33c3cbdc64 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -1,9 +1,10 @@ package memory import ( - "github.com/stretchr/testify/require" "io" "testing" + + "github.com/stretchr/testify/require" ) func TestStreamSimpleReadWriteClose(t *testing.T) { @@ -12,7 +13,7 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { rb, wa := io.Pipe() clientStr := newStream(0, ra, wa) - serverStr := newStream(0, rb, wb) + serverStr := newStream(1, rb, wb) // send a foobar from the client n, err := clientStr.Write([]byte("foobar")) diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index 02eb1d24ee..5016e3a7dd 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -2,51 +2,110 @@ package memory import ( "context" + "errors" + "io" + "sync" + "sync/atomic" + ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/pnet" tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" - "io" - "sync" - "sync/atomic" ) -const ( - listenerQueueSize = 16 -) +type hub struct { + mu sync.RWMutex + closeOnce sync.Once + pubKeys map[peer.ID]ic.PubKey + listeners map[string]*listener +} + +func (h *hub) addListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + h.listeners[addr] = l +} + +func (h *hub) removeListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + delete(h.listeners, addr) +} + +func (h *hub) getListener(addr string) (*listener, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + l, ok := h.listeners[addr] + return l, ok +} + +func (h *hub) addPubKey(p peer.ID, pk ic.PubKey) { + h.mu.Lock() + defer h.mu.Unlock() + + h.pubKeys[p] = pk +} + +func (h *hub) getPubKey(p peer.ID) (ic.PubKey, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + pk, ok := h.pubKeys[p] + return pk, ok +} + +func (h *hub) close() { + h.closeOnce.Do(func() { + h.mu.Lock() + defer h.mu.Unlock() + + for _, l := range h.listeners { + l.Close() + } + }) +} + +var memhub = &hub{ + listeners: make(map[string]*listener), + pubKeys: make(map[peer.ID]ic.PubKey), +} type transport struct { - pkey ic.PrivKey - pid peer.ID - psk pnet.PSK - rcmgr network.ResourceManager + psk pnet.PSK + rcmgr network.ResourceManager + localPeerID peer.ID + localPrivKey ic.PrivKey + localPubKey ic.PubKey mu sync.RWMutex connID atomic.Int32 - listeners map[string]*listener connections map[int32]*conn } -func NewTransport(key ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { +func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { if rcmgr == nil { rcmgr = &network.NullResourceManager{} } - id, err := peer.IDFromPrivateKey(key) + id, err := peer.IDFromPrivateKey(privKey) if err != nil { return nil, err } + memhub.addPubKey(id, privKey.GetPublic()) return &transport{ - rcmgr: rcmgr, - pid: id, - pkey: key, - psk: psk, - listeners: make(map[string]*listener), - connections: make(map[int32]*conn), + psk: psk, + rcmgr: rcmgr, + localPeerID: id, + localPrivKey: privKey, + localPubKey: privKey.GetPublic(), + connections: make(map[int32]*conn), }, nil } @@ -64,28 +123,37 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return c, nil } -func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { - if err := scope.SetPeer(p); err != nil { +func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { + if err := scope.SetPeer(rpid); err != nil { return nil, err } // TODO: Check if there is an existing listener for this address t.mu.RLock() defer t.mu.RUnlock() - l := t.listeners[raddr.String()] + l, ok := memhub.getListener(raddr.String()) + if !ok { + return nil, errors.New("failed to get listener") + } + + remotePubKey, ok := memhub.getPubKey(rpid) + if !ok { + return nil, errors.New("failed to get remote public key") + } ra, wb := io.Pipe() rb, wa := io.Pipe() - in, out := newStream(0, ra, wb), newStream(0, rb, wa) - inId, outId := t.connID.Add(1), t.connID.Add(1) + inConnId, outConnId := t.connID.Add(1), t.connID.Add(1) + inStream, outStream := newStream(0, ra, wb), newStream(0, rb, wa) - l.connCh <- newConnection(inId, in) + l.connCh <- newConnection(inConnId, inStream, rpid, raddr, t.localPubKey, t.localPeerID, nil) - return newConnection(outId, out), nil + return newConnection(outConnId, outStream, t.localPeerID, nil, remotePubKey, rpid, raddr), nil } func (t *transport) CanDial(addr ma.Multiaddr) bool { - return true + _, exists := memhub.getListener(addr.String()) + return exists } func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { @@ -95,7 +163,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { t.mu.Lock() defer t.mu.Unlock() - t.listeners[laddr.String()] = l + memhub.addListener(laddr.String(), l) return l, nil } @@ -106,7 +174,7 @@ func (t *transport) Proxy() bool { // Protocols returns the set of protocols handled by this transport. func (t *transport) Protocols() []int { - return []int{777} + return []int{ma.P_MEMORY} } func (t *transport) String() string { @@ -115,10 +183,7 @@ func (t *transport) String() string { func (t *transport) Close() error { // TODO: Go trough all listeners and close them - for _, l := range t.listeners { - l.Close() - } - + memhub.close() return nil } From 079bd3e5657177fa6e5d6fc5c40e476eaa950953 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Sat, 16 Nov 2024 15:56:56 +0100 Subject: [PATCH 06/12] Use plain channels to send data between streams --- p2p/transport/memory/conn.go | 45 ++++---- p2p/transport/memory/listener.go | 5 +- p2p/transport/memory/stream.go | 144 +++++++++++++++---------- p2p/transport/memory/stream_test.go | 12 +-- p2p/transport/memory/transport.go | 35 +++--- p2p/transport/memory/transport_test.go | 70 ++++++++++++ 6 files changed, 201 insertions(+), 110 deletions(-) create mode 100644 p2p/transport/memory/transport_test.go diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index d864e93316..515fb43625 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -2,7 +2,6 @@ package memory import ( "context" - "io" "sync" "sync/atomic" @@ -14,7 +13,7 @@ import ( ) type conn struct { - id int32 + id int64 transport *transport scope network.ConnManagementScope @@ -26,21 +25,19 @@ type conn struct { remotePubKey ic.PubKey remoteMultiaddr ma.Multiaddr - isClosed atomic.Bool - closeOnce sync.Once - mu sync.Mutex - streamC chan *stream + closed atomic.Bool + closeOnce sync.Once - nextStreamID atomic.Int32 - streams map[int32]network.MuxedStream + streamC chan *stream + streams map[int64]network.MuxedStream } var _ tpt.CapableConn = &conn{} func newConnection( - id int32, + t *transport, s *stream, localPeer peer.ID, localMultiaddr ma.Multiaddr, @@ -49,40 +46,36 @@ func newConnection( remoteMultiaddr ma.Multiaddr, ) *conn { c := &conn{ - id: id, + id: connCounter.Add(1), + transport: t, localPeer: localPeer, localMultiaddr: localMultiaddr, remotePubKey: remotePubKey, remotePeerID: remotePeer, remoteMultiaddr: remoteMultiaddr, streamC: make(chan *stream, 1), - streams: make(map[int32]network.MuxedStream), + streams: make(map[int64]network.MuxedStream), } - streamID := c.nextStreamID.Add(1) - c.addStream(streamID, s) - + c.addStream(s.id, s) return c } func (c *conn) Close() error { - c.closeOnce.Do(func() { - c.isClosed.Store(true) - c.transport.removeConn(c) - }) + c.closed.Store(true) + for _, s := range c.streams { + s.Close() + } return nil } func (c *conn) IsClosed() bool { - return c.isClosed.Load() + return c.closed.Load() } func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { - ra, wb := io.Pipe() - rb, wa := io.Pipe() - inConnId, outConnId := c.nextStreamID.Add(1), c.nextStreamID.Add(1) - inStream, outStream := newStream(inConnId, ra, wb), newStream(outConnId, rb, wa) + inStream, outStream := newStreamPair() c.streamC <- inStream return outStream, nil @@ -90,7 +83,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { func (c *conn) AcceptStream() (network.MuxedStream, error) { in := <-c.streamC - id := c.nextStreamID.Add(1) + id := streamCounter.Add(1) c.addStream(id, in) return in, nil } @@ -122,14 +115,14 @@ func (c *conn) ConnState() network.ConnectionState { return network.ConnectionState{Transport: "memory"} } -func (c *conn) addStream(id int32, stream network.MuxedStream) { +func (c *conn) addStream(id int64, stream network.MuxedStream) { c.mu.Lock() defer c.mu.Unlock() c.streams[id] = stream } -func (c *conn) removeStream(id int32) { +func (c *conn) removeStream(id int64) { c.mu.Lock() defer c.mu.Unlock() diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go index 39e8acfb29..54417e2a8b 100644 --- a/p2p/transport/memory/listener.go +++ b/p2p/transport/memory/listener.go @@ -21,7 +21,7 @@ type listener struct { mu sync.Mutex connCh chan *conn - connections map[int32]*conn + connections map[int64]*conn } func (l *listener) Multiaddr() ma.Multiaddr { @@ -36,7 +36,7 @@ func newListener(t *transport, laddr ma.Multiaddr) *listener { cancel: cancel, laddr: laddr, connCh: make(chan *conn, listenerQueueSize), - connections: make(map[int32]*conn), + connections: make(map[int64]*conn), } } @@ -53,6 +53,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) { l.mu.Lock() defer l.mu.Unlock() + c.transport = l.t l.connections[c.id] = c return c, nil } diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index 101ae516da..66d8879f88 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -1,106 +1,132 @@ package memory import ( + "errors" "io" + "net" "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/network" ) +// stream implements network.Stream type stream struct { - id int32 + id int64 - r *io.PipeReader - w *io.PipeWriter - writeC chan []byte + write chan byte + read chan byte - readCloseC chan struct{} - writeCloseC chan struct{} + reset chan struct{} + closeRead chan struct{} + closeWrite chan struct{} + closed atomic.Bool +} + +var ErrClosed = errors.New("stream closed") - closed atomic.Bool +func newStreamPair() (*stream, *stream) { + ra, rb := make(chan byte, 4096), make(chan byte, 4096) + wa, wb := rb, ra + + in := newStream(rb, wb, network.DirInbound) + out := newStream(ra, wa, network.DirOutbound) + return in, out } -func newStream(id int32, r *io.PipeReader, w *io.PipeWriter) *stream { +func newStream(r, w chan byte, _ network.Direction) *stream { s := &stream{ - id: id, - r: r, - w: w, - writeC: make(chan []byte, 1), - readCloseC: make(chan struct{}, 1), - writeCloseC: make(chan struct{}, 1), + id: streamCounter.Add(1), + read: r, + write: w, + reset: make(chan struct{}, 1), + closeRead: make(chan struct{}, 1), + closeWrite: make(chan struct{}, 1), } - go func() { - for { - select { - case b := <-s.writeC: - if _, err := w.Write(b); err != nil { - return - } - case <-s.writeCloseC: - return - } - } - }() - return s } -func (s *stream) Read(b []byte) (int, error) { - return s.r.Read(b) -} - -func (s *stream) Write(b []byte) (int, error) { +// How to handle errors with writes? +func (s *stream) Write(p []byte) (n int, err error) { if s.closed.Load() { - return 0, network.ErrReset + return 0, ErrClosed } - select { - case <-s.writeCloseC: - return 0, network.ErrReset - case s.writeC <- b: - return len(b), nil + for i := 0; i < len(p); i++ { + select { + case <-s.reset: + err = network.ErrReset + return + case <-s.closeWrite: + err = ErrClosed + return + case s.write <- p[i]: + n = i + default: + err = io.ErrClosedPipe + } } + + return n + 1, err } -func (s *stream) Reset() error { - if err := s.CloseWrite(); err != nil { - return err +func (s *stream) Read(p []byte) (n int, err error) { + if s.closed.Load() { + return 0, ErrClosed } - if err := s.CloseRead(); err != nil { - return err + + for n = 0; n < len(p); n++ { + select { + case <-s.reset: + err = network.ErrReset + return + case <-s.closeRead: + err = ErrClosed + return + case b, ok := <-s.read: + if !ok { + err = io.EOF + return + } + p[n] = b + default: + err = io.EOF + return + } } - return nil + + return } -func (s *stream) Close() error { - s.CloseRead() - s.CloseWrite() +func (s *stream) CloseWrite() error { + s.closeWrite <- struct{}{} return nil } func (s *stream) CloseRead() error { - return s.r.CloseWithError(network.ErrReset) + s.closeRead <- struct{}{} + return nil } -func (s *stream) CloseWrite() error { - select { - case s.writeCloseC <- struct{}{}: - default: - } - +func (s *stream) Close() error { s.closed.Store(true) return nil } -func (s *stream) SetDeadline(_ time.Time) error { +func (s *stream) Reset() error { + s.reset <- struct{}{} return nil } -func (s *stream) SetReadDeadline(_ time.Time) error { - return nil +func (s *stream) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -func (s *stream) SetWriteDeadline(_ time.Time) error { - return nil + +func (s *stream) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (s *stream) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index 33c3cbdc64..cd5149c685 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -8,22 +8,17 @@ import ( ) func TestStreamSimpleReadWriteClose(t *testing.T) { - //client, server := getDetachedDataChannels(t) - ra, wb := io.Pipe() - rb, wa := io.Pipe() - - clientStr := newStream(0, ra, wa) - serverStr := newStream(1, rb, wb) + clientStr, serverStr := newStreamPair() // send a foobar from the client n, err := clientStr.Write([]byte("foobar")) require.NoError(t, err) require.Equal(t, 6, n) require.NoError(t, clientStr.CloseWrite()) + // writing after closing should error _, err = clientStr.Write([]byte("foobar")) require.Error(t, err) - //require.False(t, clientDone.Load()) // now read all the data on the server side b, err := io.ReadAll(serverStr) @@ -33,7 +28,6 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { n, err = serverStr.Read(make([]byte, 10)) require.Zero(t, n) require.ErrorIs(t, err, io.EOF) - //require.False(t, serverDone.Load()) // send something back _, err = serverStr.Write([]byte("lorem ipsum")) @@ -49,8 +43,6 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { // stream is only cleaned up on calling Close or Reset clientStr.Close() serverStr.Close() - //require.Eventually(t, func() bool { return clientDone.Load() }, 5*time.Second, 100*time.Millisecond) // Need to call Close for cleanup. Otherwise the FIN_ACK is never read require.NoError(t, serverStr.Close()) - //require.Eventually(t, func() bool { return serverDone.Load() }, 5*time.Second, 100*time.Millisecond) } diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index 5016e3a7dd..a13c737437 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -3,7 +3,6 @@ package memory import ( "context" "errors" - "io" "sync" "sync/atomic" @@ -13,6 +12,14 @@ import ( "github.com/libp2p/go-libp2p/core/pnet" tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" +) + +var ( + connCounter atomic.Int64 + streamCounter atomic.Int64 + listenerCounter atomic.Int64 + dialMatcher = mafmt.Base(ma.P_MEMORY) ) type hub struct { @@ -84,8 +91,7 @@ type transport struct { mu sync.RWMutex - connID atomic.Int32 - connections map[int32]*conn + connections map[int64]*conn } func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { @@ -105,7 +111,7 @@ func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManage localPeerID: id, localPrivKey: privKey, localPubKey: privKey.GetPublic(), - connections: make(map[int32]*conn), + connections: make(map[int64]*conn), }, nil } @@ -141,19 +147,16 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid return nil, errors.New("failed to get remote public key") } - ra, wb := io.Pipe() - rb, wa := io.Pipe() - inConnId, outConnId := t.connID.Add(1), t.connID.Add(1) - inStream, outStream := newStream(0, ra, wb), newStream(0, rb, wa) - - l.connCh <- newConnection(inConnId, inStream, rpid, raddr, t.localPubKey, t.localPeerID, nil) + inStream, outStream := newStreamPair() + inConn := newConnection(t, outStream, t.localPeerID, nil, remotePubKey, rpid, raddr) + outConn := newConnection(nil, inStream, rpid, raddr, t.localPubKey, t.localPeerID, nil) + l.connCh <- outConn - return newConnection(outConnId, outStream, t.localPeerID, nil, remotePubKey, rpid, raddr), nil + return inConn, nil } func (t *transport) CanDial(addr ma.Multiaddr) bool { - _, exists := memhub.getListener(addr.String()) - return exists + return dialMatcher.Matches(addr) } func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { @@ -184,6 +187,12 @@ func (t *transport) String() string { func (t *transport) Close() error { // TODO: Go trough all listeners and close them memhub.close() + + for _, c := range t.connections { + c.Close() + delete(t.connections, c.id) + } + return nil } diff --git a/p2p/transport/memory/transport_test.go b/p2p/transport/memory/transport_test.go new file mode 100644 index 0000000000..f83f0d1280 --- /dev/null +++ b/p2p/transport/memory/transport_test.go @@ -0,0 +1,70 @@ +package memory + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "io" + "testing" + + ic "github.com/libp2p/go-libp2p/core/crypto" + tpt "github.com/libp2p/go-libp2p/core/transport" + + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func getTransport(t *testing.T) tpt.Transport { + t.Helper() + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) + require.NoError(t, err) + tr, err := NewTransport(key, nil, nil) + require.NoError(t, err) + return tr +} + +func TestMemoryProtocol(t *testing.T) { + tr := getTransport(t) + defer tr.(io.Closer).Close() + + protocols := tr.Protocols() + if len(protocols) > 1 { + t.Fatalf("expected at most one protocol, got %v", protocols) + } + + if protocols[0] != ma.P_MEMORY { + t.Fatalf("expected the supported protocol to be memory, got %d", protocols[0]) + } +} + +func TestCanDial(t *testing.T) { + tr := getTransport(t) + defer tr.(io.Closer).Close() + + invalid := []string{ + "/ip4/127.0.0.1/udp/1234", + "/ip4/5.5.5.5/tcp/1234", + "/dns/google.com/udp/443/quic-v1", + "/ip4/127.0.0.1/udp/1234/quic", + } + valid := []string{ + "/memory/1234", + "/memory/1337123", + } + for _, s := range invalid { + invalidAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if tr.CanDial(invalidAddr) { + t.Errorf("didn't expect to be able to dial a non-memory address (%s)", invalidAddr) + } + } + for _, s := range valid { + validAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if !tr.CanDial(validAddr) { + t.Errorf("expected to be able to dial memory address (%s)", validAddr) + } + } +} From e3203f2f8768f067c6a15f1d7f0c084102690104 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Thu, 21 Nov 2024 10:12:24 +0100 Subject: [PATCH 07/12] Daily commit --- p2p/test/transport/transport_test.go | 2 - p2p/transport/memory/conn.go | 6 +-- p2p/transport/memory/stream.go | 58 ++++++++++++++++------------ p2p/transport/memory/stream_test.go | 24 +++++++++++- test-plans/go.mod | 2 +- test-plans/go.sum | 1 + 6 files changed, 62 insertions(+), 31 deletions(-) diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index e353ba6526..7010b3d0ce 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -30,9 +30,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" - tls "github.com/libp2p/go-libp2p/p2p/security/tls" libp2pmemory "github.com/libp2p/go-libp2p/p2p/transport/memory" - libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "go.uber.org/mock/gomock" ma "github.com/multiformats/go-multiaddr" diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index 515fb43625..d08b942a70 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -63,7 +63,8 @@ func newConnection( func (c *conn) Close() error { c.closed.Store(true) - for _, s := range c.streams { + for id, s := range c.streams { + c.removeStream(id) s.Close() } @@ -83,8 +84,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { func (c *conn) AcceptStream() (network.MuxedStream, error) { in := <-c.streamC - id := streamCounter.Add(1) - c.addStream(id, in) + c.addStream(in.id, in) return in, nil } diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index 66d8879f88..aa20eebd23 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -27,10 +27,9 @@ var ErrClosed = errors.New("stream closed") func newStreamPair() (*stream, *stream) { ra, rb := make(chan byte, 4096), make(chan byte, 4096) - wa, wb := rb, ra - in := newStream(rb, wb, network.DirInbound) - out := newStream(ra, wa, network.DirOutbound) + in := newStream(rb, ra, network.DirInbound) + out := newStream(ra, rb, network.DirOutbound) return in, out } @@ -47,28 +46,34 @@ func newStream(r, w chan byte, _ network.Direction) *stream { return s } -// How to handle errors with writes? func (s *stream) Write(p []byte) (n int, err error) { if s.closed.Load() { return 0, ErrClosed } - for i := 0; i < len(p); i++ { + select { + case <-s.reset: + return 0, network.ErrReset + case <-s.closeWrite: + return 0, ErrClosed + default: + } + + for n < len(p) { select { - case <-s.reset: - err = network.ErrReset - return case <-s.closeWrite: - err = ErrClosed - return - case s.write <- p[i]: - n = i + return n, ErrClosed + case <-s.reset: + return n, network.ErrReset + case s.write <- p[n]: + n++ default: err = io.ErrClosedPipe + return } } - return n + 1, err + return } func (s *stream) Read(p []byte) (n int, err error) { @@ -76,27 +81,32 @@ func (s *stream) Read(p []byte) (n int, err error) { return 0, ErrClosed } - for n = 0; n < len(p); n++ { + select { + case <-s.reset: + return 0, network.ErrReset + case <-s.closeRead: + return 0, ErrClosed + default: + } + + for n < len(p) { select { - case <-s.reset: - err = network.ErrReset - return case <-s.closeRead: - err = ErrClosed - return + return n, ErrClosed + case <-s.reset: + return n, network.ErrReset case b, ok := <-s.read: if !ok { - err = io.EOF - return + return n, ErrClosed } p[n] = b + n++ default: - err = io.EOF - return + return n, io.EOF } } - return + return n, nil } func (s *stream) CloseWrite() error { diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index cd5149c685..b23d3f7bb4 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -8,6 +8,7 @@ import ( ) func TestStreamSimpleReadWriteClose(t *testing.T) { + // t.Parallel() clientStr, serverStr := newStreamPair() // send a foobar from the client @@ -24,6 +25,7 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { b, err := io.ReadAll(serverStr) require.NoError(t, err) require.Equal(t, []byte("foobar"), b) + // reading again should give another io.EOF n, err = serverStr.Read(make([]byte, 10)) require.Zero(t, n) @@ -35,7 +37,6 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { require.NoError(t, serverStr.CloseWrite()) // and read it at the client - //require.False(t, clientDone.Load()) b, err = io.ReadAll(clientStr) require.NoError(t, err) require.Equal(t, []byte("lorem ipsum"), b) @@ -46,3 +47,24 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { // Need to call Close for cleanup. Otherwise the FIN_ACK is never read require.NoError(t, serverStr.Close()) } + +func TestStreamPartialReads(t *testing.T) { + // t.Parallel() + clientStr, serverStr := newStreamPair() + + _, err := serverStr.Write([]byte("foobar")) + require.NoError(t, err) + require.NoError(t, serverStr.CloseWrite()) + + n, err := clientStr.Read([]byte{}) // empty read + require.NoError(t, err) + require.Zero(t, n) + b := make([]byte, 3) + n, err = clientStr.Read(b) + require.Equal(t, 3, n) + require.NoError(t, err) + require.Equal(t, []byte("foo"), b) + b, err = io.ReadAll(clientStr) + require.NoError(t, err) + require.Equal(t, []byte("bar"), b) +} diff --git a/test-plans/go.mod b/test-plans/go.mod index 1c07eba15d..5b0d9980bf 100644 --- a/test-plans/go.mod +++ b/test-plans/go.mod @@ -7,7 +7,7 @@ toolchain go1.22.1 require ( github.com/go-redis/redis/v8 v8.11.5 github.com/libp2p/go-libp2p v0.0.0 - github.com/multiformats/go-multiaddr v0.13.0 + github.com/multiformats/go-multiaddr v0.14.0 ) require ( diff --git a/test-plans/go.sum b/test-plans/go.sum index 5517e8d41d..cd5f393530 100644 --- a/test-plans/go.sum +++ b/test-plans/go.sum @@ -185,6 +185,7 @@ github.com/multiformats/go-base36 v0.2.0/go.mod h1:qvnKE++v+2MWCfePClUEjE78Z7P2a github.com/multiformats/go-multiaddr v0.1.1/go.mod h1:aMKBKNEYmzmDmxfX88/vz+J5IU55txyt0p4aiWVohjo= github.com/multiformats/go-multiaddr v0.13.0 h1:BCBzs61E3AGHcYYTv8dqRH43ZfyrqM8RXVPT8t13tLQ= github.com/multiformats/go-multiaddr v0.13.0/go.mod h1:sBXrNzucqkFJhvKOiwwLyqamGa/P5EIXNPLovyhQCII= +github.com/multiformats/go-multiaddr v0.14.0/go.mod h1:6EkVAxtznq2yC3QT5CM1UTAwG0GTP3EWAIcjHuzQ+r4= github.com/multiformats/go-multiaddr-dns v0.4.0 h1:P76EJ3qzBXpUXZ3twdCDx/kvagMsNo0LMFXpyms/zgU= github.com/multiformats/go-multiaddr-dns v0.4.0/go.mod h1:7hfthtB4E4pQwirrz+J0CcDUfbWzTqEzVyYKKIKpgkc= github.com/multiformats/go-multiaddr-fmt v0.1.0 h1:WLEFClPycPkp4fnIzoFoV9FVd49/eQsuaL3/CWe167E= From b96f386361bddb23d2c963c0037c71ec220da53c Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Thu, 21 Nov 2024 22:17:51 +0100 Subject: [PATCH 08/12] Daily commit --- p2p/test/transport/transport_test.go | 34 +----- p2p/transport/memory/conn.go | 10 +- p2p/transport/memory/stream.go | 162 +++++++++++++------------ p2p/transport/memory/stream_test.go | 4 +- p2p/transport/memory/transport.go | 16 +-- p2p/transport/memory/transport_test.go | 2 + 6 files changed, 106 insertions(+), 122 deletions(-) diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 203c382124..e353ba6526 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -30,7 +30,9 @@ import ( "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" + tls "github.com/libp2p/go-libp2p/p2p/security/tls" libp2pmemory "github.com/libp2p/go-libp2p/p2p/transport/memory" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "go.uber.org/mock/gomock" ma "github.com/multiformats/go-multiaddr" @@ -98,38 +100,6 @@ var transportsToTest = []TransportTestCase{ return h }, }, - { - Name: "TCP-Shared / TLS / Yamux", - HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { - libp2pOpts := transformOpts(opts) - libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) - libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New)) - libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport)) - if opts.NoListen { - libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) - } else { - libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0")) - } - h, err := libp2p.New(libp2pOpts...) - require.NoError(t, err) - return h - }, - }, - { - Name: "WebSocket-Shared", - HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { - libp2pOpts := transformOpts(opts) - libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) - if opts.NoListen { - libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) - } else { - libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws")) - } - h, err := libp2p.New(libp2pOpts...) - require.NoError(t, err) - return h - }, - }, { Name: "WebSocket", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index d08b942a70..df9b08b4b2 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -63,8 +63,8 @@ func newConnection( func (c *conn) Close() error { c.closed.Store(true) - for id, s := range c.streams { - c.removeStream(id) + for _, s := range c.streams { + //c.removeStream(id) s.Close() } @@ -76,10 +76,10 @@ func (c *conn) IsClosed() bool { } func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { - inStream, outStream := newStreamPair() + sl, sr := newStreamPair() - c.streamC <- inStream - return outStream, nil + c.streamC <- sr + return sl, nil } func (c *conn) AcceptStream() (network.MuxedStream, error) { diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index aa20eebd23..a445a712b2 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -3,8 +3,8 @@ package memory import ( "errors" "io" + "log" "net" - "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/network" @@ -14,118 +14,96 @@ import ( type stream struct { id int64 - write chan byte - read chan byte + read *io.PipeReader + write *io.PipeWriter + writeC chan []byte - reset chan struct{} - closeRead chan struct{} - closeWrite chan struct{} - closed atomic.Bool + reset chan struct{} + close chan struct{} + closed chan struct{} + + writeErr error } var ErrClosed = errors.New("stream closed") func newStreamPair() (*stream, *stream) { - ra, rb := make(chan byte, 4096), make(chan byte, 4096) + ra, wb := io.Pipe() + rb, wa := io.Pipe() + + sa := newStream(rb, wa, network.DirOutbound) + sb := newStream(ra, wb, network.DirInbound) - in := newStream(rb, ra, network.DirInbound) - out := newStream(ra, rb, network.DirOutbound) - return in, out + return sa, sb } -func newStream(r, w chan byte, _ network.Direction) *stream { +func newStream(r *io.PipeReader, w *io.PipeWriter, _ network.Direction) *stream { s := &stream{ - id: streamCounter.Add(1), - read: r, - write: w, - reset: make(chan struct{}, 1), - closeRead: make(chan struct{}, 1), - closeWrite: make(chan struct{}, 1), + id: streamCounter.Add(1), + read: r, + write: w, + writeC: make(chan []byte), + reset: make(chan struct{}, 1), + close: make(chan struct{}, 1), + closed: make(chan struct{}), } + log.Println("newStream", "id", s.id) + go s.writeLoop() return s } -func (s *stream) Write(p []byte) (n int, err error) { - if s.closed.Load() { - return 0, ErrClosed - } +func (s *stream) Write(p []byte) (int, error) { + cpy := make([]byte, len(p)) + copy(cpy, p) select { - case <-s.reset: - return 0, network.ErrReset - case <-s.closeWrite: - return 0, ErrClosed - default: + case <-s.closed: + return 0, s.writeErr + case s.writeC <- cpy: } - for n < len(p) { - select { - case <-s.closeWrite: - return n, ErrClosed - case <-s.reset: - return n, network.ErrReset - case s.write <- p[n]: - n++ - default: - err = io.ErrClosedPipe - return - } - } - - return + return len(p), nil } func (s *stream) Read(p []byte) (n int, err error) { - if s.closed.Load() { - return 0, ErrClosed - } + return s.read.Read(p) +} +func (s *stream) CloseWrite() error { select { - case <-s.reset: - return 0, network.ErrReset - case <-s.closeRead: - return 0, ErrClosed + case s.close <- struct{}{}: default: } - - for n < len(p) { - select { - case <-s.closeRead: - return n, ErrClosed - case <-s.reset: - return n, network.ErrReset - case b, ok := <-s.read: - if !ok { - return n, ErrClosed - } - p[n] = b - n++ - default: - return n, io.EOF - } + log.Println("waiting close", "id", s.id) + <-s.closed + log.Println("closed write", "id", s.id) + if !errors.Is(s.writeErr, ErrClosed) { + return s.writeErr } - - return n, nil -} - -func (s *stream) CloseWrite() error { - s.closeWrite <- struct{}{} return nil + } func (s *stream) CloseRead() error { - s.closeRead <- struct{}{} - return nil + return s.read.CloseWithError(ErrClosed) } func (s *stream) Close() error { - s.closed.Store(true) - return nil + _ = s.CloseRead() + return s.CloseWrite() } func (s *stream) Reset() error { - s.reset <- struct{}{} + s.write.CloseWithError(network.ErrReset) + s.read.CloseWithError(network.ErrReset) + + select { + case s.reset <- struct{}{}: + default: + } + <-s.closed + return nil } @@ -140,3 +118,35 @@ func (s *stream) SetReadDeadline(t time.Time) error { func (s *stream) SetWriteDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } + +func (s *stream) writeLoop() { + defer close(s.closed) + defer log.Println("closing write", "id", s.id) + + for { + // Reset takes precedent. + select { + case <-s.reset: + s.writeErr = network.ErrReset + return + default: + } + + select { + case <-s.reset: + s.writeErr = network.ErrReset + return + case <-s.close: + s.writeErr = s.write.Close() + if s.writeErr == nil { + s.writeErr = ErrClosed + } + return + case p := <-s.writeC: + if _, err := s.write.Write(p); err != nil { + s.writeErr = err + return + } + } + } +} diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index b23d3f7bb4..5ce3939863 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -8,7 +8,7 @@ import ( ) func TestStreamSimpleReadWriteClose(t *testing.T) { - // t.Parallel() + t.Parallel() clientStr, serverStr := newStreamPair() // send a foobar from the client @@ -49,7 +49,7 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { } func TestStreamPartialReads(t *testing.T) { - // t.Parallel() + t.Parallel() clientStr, serverStr := newStreamPair() _, err := serverStr.Write([]byte("foobar")) diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index a13c737437..56cde6a3f6 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -147,12 +147,14 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid return nil, errors.New("failed to get remote public key") } - inStream, outStream := newStreamPair() - inConn := newConnection(t, outStream, t.localPeerID, nil, remotePubKey, rpid, raddr) - outConn := newConnection(nil, inStream, rpid, raddr, t.localPubKey, t.localPeerID, nil) - l.connCh <- outConn + sl, sr := newStreamPair() - return inConn, nil + lconn := newConnection(t, sl, t.localPeerID, nil, remotePubKey, rpid, raddr) + rconn := newConnection(nil, sr, rpid, raddr, t.localPubKey, t.localPeerID, nil) + + l.connCh <- rconn + + return lconn, nil } func (t *transport) CanDial(addr ma.Multiaddr) bool { @@ -186,11 +188,11 @@ func (t *transport) String() string { func (t *transport) Close() error { // TODO: Go trough all listeners and close them - memhub.close() + //memhub.close() for _, c := range t.connections { c.Close() - delete(t.connections, c.id) + //delete(t.connections, c.id) } return nil diff --git a/p2p/transport/memory/transport_test.go b/p2p/transport/memory/transport_test.go index f83f0d1280..f17835f3ff 100644 --- a/p2p/transport/memory/transport_test.go +++ b/p2p/transport/memory/transport_test.go @@ -26,6 +26,7 @@ func getTransport(t *testing.T) tpt.Transport { } func TestMemoryProtocol(t *testing.T) { + t.Parallel() tr := getTransport(t) defer tr.(io.Closer).Close() @@ -40,6 +41,7 @@ func TestMemoryProtocol(t *testing.T) { } func TestCanDial(t *testing.T) { + t.Parallel() tr := getTransport(t) defer tr.(io.Closer).Close() From 5236ff28dc0cae89e45c5379a393a9d7b524070d Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Fri, 22 Nov 2024 18:59:17 +0100 Subject: [PATCH 09/12] Daily commit --- p2p/transport/memory/conn.go | 27 +++++-- p2p/transport/memory/hub.go | 80 +++++++++++++++++++++ p2p/transport/memory/listener.go | 4 ++ p2p/transport/memory/stream.go | 38 ++++++---- p2p/transport/memory/stream_test.go | 3 +- p2p/transport/memory/transport.go | 107 +++++----------------------- 6 files changed, 149 insertions(+), 110 deletions(-) create mode 100644 p2p/transport/memory/hub.go diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index df9b08b4b2..61c81462cc 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -13,10 +13,12 @@ import ( ) type conn struct { - id int64 + id int64 + rconn *conn - transport *transport scope network.ConnManagementScope + listener *listener + transport *transport localPeer peer.ID localMultiaddr ma.Multiaddr @@ -62,11 +64,11 @@ func newConnection( } func (c *conn) Close() error { - c.closed.Store(true) - for _, s := range c.streams { - //c.removeStream(id) - s.Close() - } + c.closeOnce.Do(func() { + c.closed.Store(true) + go c.rconn.Close() + c.teardown() + }) return nil } @@ -79,11 +81,14 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { sl, sr := newStreamPair() c.streamC <- sr + sl.conn = c + c.addStream(sl.id, sl) return sl, nil } func (c *conn) AcceptStream() (network.MuxedStream, error) { in := <-c.streamC + in.conn = c c.addStream(in.id, in) return in, nil } @@ -128,3 +133,11 @@ func (c *conn) removeStream(id int64) { delete(c.streams, id) } + +func (c *conn) teardown() { + for _, s := range c.streams { + s.Reset() + } + + // TODO: remove self from listener +} diff --git a/p2p/transport/memory/hub.go b/p2p/transport/memory/hub.go new file mode 100644 index 0000000000..55b85ccbad --- /dev/null +++ b/p2p/transport/memory/hub.go @@ -0,0 +1,80 @@ +package memory + +import ( + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" + "sync" + "sync/atomic" +) + +var ( + connCounter atomic.Int64 + streamCounter atomic.Int64 + listenerCounter atomic.Int64 + dialMatcher = mafmt.Base(ma.P_MEMORY) + memhub = newHub() +) + +type hub struct { + mu sync.RWMutex + closeOnce sync.Once + pubKeys map[peer.ID]ic.PubKey + listeners map[string]*listener +} + +func newHub() *hub { + return &hub{ + pubKeys: make(map[peer.ID]ic.PubKey), + listeners: make(map[string]*listener), + } +} + +func (h *hub) addListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + h.listeners[addr] = l +} + +func (h *hub) removeListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + delete(h.listeners, addr) +} + +func (h *hub) getListener(addr string) (*listener, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + l, ok := h.listeners[addr] + return l, ok +} + +func (h *hub) addPubKey(p peer.ID, pk ic.PubKey) { + h.mu.Lock() + defer h.mu.Unlock() + + h.pubKeys[p] = pk +} + +func (h *hub) getPubKey(p peer.ID) (ic.PubKey, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + pk, ok := h.pubKeys[p] + return pk, ok +} + +func (h *hub) close() { + h.closeOnce.Do(func() { + h.mu.Lock() + defer h.mu.Unlock() + + for _, l := range h.listeners { + l.Close() + } + }) +} diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go index 54417e2a8b..81d5e484d0 100644 --- a/p2p/transport/memory/listener.go +++ b/p2p/transport/memory/listener.go @@ -14,6 +14,8 @@ const ( ) type listener struct { + id int64 + t *transport ctx context.Context cancel context.CancelFunc @@ -31,6 +33,7 @@ func (l *listener) Multiaddr() ma.Multiaddr { func newListener(t *transport, laddr ma.Multiaddr) *listener { ctx, cancel := context.WithCancel(context.Background()) return &listener{ + id: listenerCounter.Add(1), t: t, ctx: ctx, cancel: cancel, @@ -53,6 +56,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) { l.mu.Lock() defer l.mu.Unlock() + c.listener = l c.transport = l.t l.connections[c.id] = c return c, nil diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index a445a712b2..6aa59bfc7b 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -3,7 +3,6 @@ package memory import ( "errors" "io" - "log" "net" "time" @@ -12,7 +11,8 @@ import ( // stream implements network.Stream type stream struct { - id int64 + id int64 + conn *conn read *io.PipeReader write *io.PipeWriter @@ -31,13 +31,13 @@ func newStreamPair() (*stream, *stream) { ra, wb := io.Pipe() rb, wa := io.Pipe() - sa := newStream(rb, wa, network.DirOutbound) - sb := newStream(ra, wb, network.DirInbound) + sa := newStream(wa, ra, network.DirOutbound) + sb := newStream(wb, rb, network.DirInbound) return sa, sb } -func newStream(r *io.PipeReader, w *io.PipeWriter, _ network.Direction) *stream { +func newStream(w *io.PipeWriter, r *io.PipeReader, _ network.Direction) *stream { s := &stream{ id: streamCounter.Add(1), read: r, @@ -47,7 +47,6 @@ func newStream(r *io.PipeReader, w *io.PipeWriter, _ network.Direction) *stream close: make(chan struct{}, 1), closed: make(chan struct{}), } - log.Println("newStream", "id", s.id) go s.writeLoop() return s @@ -66,7 +65,7 @@ func (s *stream) Write(p []byte) (int, error) { return len(p), nil } -func (s *stream) Read(p []byte) (n int, err error) { +func (s *stream) Read(p []byte) (int, error) { return s.read.Read(p) } @@ -75,9 +74,7 @@ func (s *stream) CloseWrite() error { case s.close <- struct{}{}: default: } - log.Println("waiting close", "id", s.id) <-s.closed - log.Println("closed write", "id", s.id) if !errors.Is(s.writeErr, ErrClosed) { return s.writeErr } @@ -95,6 +92,7 @@ func (s *stream) Close() error { } func (s *stream) Reset() error { + // Cancel any pending reads/writes with an error. s.write.CloseWithError(network.ErrReset) s.read.CloseWithError(network.ErrReset) @@ -103,7 +101,7 @@ func (s *stream) Reset() error { default: } <-s.closed - + // No meaningful error case here. return nil } @@ -120,8 +118,7 @@ func (s *stream) SetWriteDeadline(t time.Time) error { } func (s *stream) writeLoop() { - defer close(s.closed) - defer log.Println("closing write", "id", s.id) + defer s.teardown() for { // Reset takes precedent. @@ -144,9 +141,24 @@ func (s *stream) writeLoop() { return case p := <-s.writeC: if _, err := s.write.Write(p); err != nil { - s.writeErr = err + s.cancelWrite(err) return } } } } + +func (s *stream) cancelWrite(err error) { + s.write.CloseWithError(err) + s.writeErr = err +} + +func (s *stream) teardown() { + // at this point, no streams are writing. + if s.conn != nil { + s.conn.removeStream(s.id) + } + + // Mark as closed. + close(s.closed) +} diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index 5ce3939863..1f3e8fb0cd 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -1,10 +1,9 @@ package memory import ( + "github.com/stretchr/testify/require" "io" "testing" - - "github.com/stretchr/testify/require" ) func TestStreamSimpleReadWriteClose(t *testing.T) { diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index 56cde6a3f6..39d152e7d3 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -3,85 +3,15 @@ package memory import ( "context" "errors" - "sync" - "sync/atomic" - ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/pnet" tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" - mafmt "github.com/multiformats/go-multiaddr-fmt" -) - -var ( - connCounter atomic.Int64 - streamCounter atomic.Int64 - listenerCounter atomic.Int64 - dialMatcher = mafmt.Base(ma.P_MEMORY) + "sync" ) -type hub struct { - mu sync.RWMutex - closeOnce sync.Once - pubKeys map[peer.ID]ic.PubKey - listeners map[string]*listener -} - -func (h *hub) addListener(addr string, l *listener) { - h.mu.Lock() - defer h.mu.Unlock() - - h.listeners[addr] = l -} - -func (h *hub) removeListener(addr string, l *listener) { - h.mu.Lock() - defer h.mu.Unlock() - - delete(h.listeners, addr) -} - -func (h *hub) getListener(addr string) (*listener, bool) { - h.mu.RLock() - defer h.mu.RUnlock() - - l, ok := h.listeners[addr] - return l, ok -} - -func (h *hub) addPubKey(p peer.ID, pk ic.PubKey) { - h.mu.Lock() - defer h.mu.Unlock() - - h.pubKeys[p] = pk -} - -func (h *hub) getPubKey(p peer.ID) (ic.PubKey, bool) { - h.mu.RLock() - defer h.mu.RUnlock() - - pk, ok := h.pubKeys[p] - return pk, ok -} - -func (h *hub) close() { - h.closeOnce.Do(func() { - h.mu.Lock() - defer h.mu.Unlock() - - for _, l := range h.listeners { - l.Close() - } - }) -} - -var memhub = &hub{ - listeners: make(map[string]*listener), - pubKeys: make(map[peer.ID]ic.PubKey), -} - type transport struct { psk pnet.PSK rcmgr network.ResourceManager @@ -129,15 +59,12 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return c, nil } -func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { +func (t *transport) dialWithScope(_ context.Context, raddr ma.Multiaddr, rpid peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { if err := scope.SetPeer(rpid); err != nil { return nil, err } - // TODO: Check if there is an existing listener for this address - t.mu.RLock() - defer t.mu.RUnlock() - l, ok := memhub.getListener(raddr.String()) + rl, ok := memhub.getListener(raddr.String()) if !ok { return nil, errors.New("failed to get listener") } @@ -147,14 +74,10 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid return nil, errors.New("failed to get remote public key") } - sl, sr := newStreamPair() - - lconn := newConnection(t, sl, t.localPeerID, nil, remotePubKey, rpid, raddr) - rconn := newConnection(nil, sr, rpid, raddr, t.localPubKey, t.localPeerID, nil) + lc, rc := t.newConnPair(remotePubKey, rpid, raddr) - l.connCh <- rconn - - return lconn, nil + rl.connCh <- rc + return lc, nil } func (t *transport) CanDial(addr ma.Multiaddr) bool { @@ -164,10 +87,6 @@ func (t *transport) CanDial(addr ma.Multiaddr) bool { func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { // TODO: Check if we need to add scope via conn mngr l := newListener(t, laddr) - - t.mu.Lock() - defer t.mu.Unlock() - memhub.addListener(laddr.String(), l) return l, nil @@ -188,7 +107,8 @@ func (t *transport) String() string { func (t *transport) Close() error { // TODO: Go trough all listeners and close them - //memhub.close() + t.mu.Lock() + defer t.mu.Unlock() for _, c := range t.connections { c.Close() @@ -211,3 +131,14 @@ func (t *transport) removeConn(c *conn) { delete(t.connections, c.id) } + +func (t *transport) newConnPair(remotePubKey ic.PubKey, rpid peer.ID, raddr ma.Multiaddr) (*conn, *conn) { + sl, sr := newStreamPair() + + lc := newConnection(t, sl, t.localPeerID, nil, remotePubKey, rpid, raddr) + rc := newConnection(nil, sr, rpid, raddr, t.localPubKey, t.localPeerID, nil) + + lc.rconn = rc + rc.rconn = lc + return lc, rc +} From f511c2d8c1252578994fae9b06ed65f6e9ad1c9b Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Mon, 25 Nov 2024 11:56:49 +0100 Subject: [PATCH 10/12] Revert to stream using channels instead of io.Pipe --- p2p/transport/memory/stream.go | 204 +++++++++++++++++----------- p2p/transport/memory/stream_test.go | 36 +++-- 2 files changed, 138 insertions(+), 102 deletions(-) diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index 6aa59bfc7b..d196aea0fc 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -1,9 +1,12 @@ package memory import ( + "bytes" "errors" "io" "net" + "sync" + "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/network" @@ -14,59 +17,133 @@ type stream struct { id int64 conn *conn - read *io.PipeReader - write *io.PipeWriter - writeC chan []byte + wrMu sync.Mutex // Serialize Write operations + buf *bytes.Buffer // Buffer for partial reads - reset chan struct{} - close chan struct{} - closed chan struct{} + // Used by local Read to interact with remote Write. + rdRx <-chan []byte - writeErr error + // Used by local Write to interact with remote Read. + wrTx chan<- []byte + + once sync.Once // Protects closing localDone + localDone chan struct{} + remoteDone <-chan struct{} + + reset chan struct{} + close chan struct{} + readClosed atomic.Bool + writeClosed atomic.Bool } var ErrClosed = errors.New("stream closed") func newStreamPair() (*stream, *stream) { - ra, wb := io.Pipe() - rb, wa := io.Pipe() + cb1 := make(chan []byte, 1) + cb2 := make(chan []byte, 1) + + done1 := make(chan struct{}) + done2 := make(chan struct{}) - sa := newStream(wa, ra, network.DirOutbound) - sb := newStream(wb, rb, network.DirInbound) + sa := newStream(cb1, cb2, done1, done2) + sb := newStream(cb2, cb1, done2, done1) return sa, sb } -func newStream(w *io.PipeWriter, r *io.PipeReader, _ network.Direction) *stream { +func newStream(rdRx <-chan []byte, wrTx chan<- []byte, localDone chan struct{}, remoteDone <-chan struct{}) *stream { s := &stream{ - id: streamCounter.Add(1), - read: r, - write: w, - writeC: make(chan []byte), - reset: make(chan struct{}, 1), - close: make(chan struct{}, 1), - closed: make(chan struct{}), + rdRx: rdRx, + wrTx: wrTx, + buf: new(bytes.Buffer), + localDone: localDone, + remoteDone: remoteDone, + reset: make(chan struct{}, 1), + close: make(chan struct{}, 1), } - go s.writeLoop() return s } -func (s *stream) Write(p []byte) (int, error) { - cpy := make([]byte, len(p)) - copy(cpy, p) +func (p *stream) Write(b []byte) (int, error) { + if p.writeClosed.Load() { + return 0, ErrClosed + } + + n, err := p.write(b) + if err != nil && err != io.ErrClosedPipe { + err = &net.OpError{Op: "write", Net: "pipe", Err: err} + } + return n, err +} + +func (p *stream) write(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.ErrClosedPipe + } + + p.wrMu.Lock() // Ensure entirety of b is written together + defer p.wrMu.Unlock() select { - case <-s.closed: - return 0, s.writeErr - case s.writeC <- cpy: + case <-p.close: + return n, ErrClosed + case <-p.reset: + return n, network.ErrReset + case p.wrTx <- b: + n += len(b) + case <-p.localDone: + return n, io.ErrClosedPipe + case <-p.remoteDone: + return n, io.ErrClosedPipe } - return len(p), nil + return n, nil } -func (s *stream) Read(p []byte) (int, error) { - return s.read.Read(p) +func (p *stream) Read(b []byte) (int, error) { + if p.readClosed.Load() { + return 0, ErrClosed + } + + n, err := p.read(b) + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + err = &net.OpError{Op: "read", Net: "pipe", Err: err} + } + + return n, err +} + +func (p *stream) read(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.EOF + } + + select { + case <-p.reset: + return n, network.ErrReset + case bw, ok := <-p.rdRx: + if !ok { + p.readClosed.Store(true) + return 0, io.EOF + } + + p.buf.Write(bw) + case <-p.localDone: + return 0, io.ErrClosedPipe + case <-p.remoteDone: + return 0, io.EOF + default: + n, err = p.buf.Read(b) + } + + return n, err } func (s *stream) CloseWrite() error { @@ -74,16 +151,14 @@ func (s *stream) CloseWrite() error { case s.close <- struct{}{}: default: } - <-s.closed - if !errors.Is(s.writeErr, ErrClosed) { - return s.writeErr - } - return nil + s.writeClosed.Store(true) + return nil } func (s *stream) CloseRead() error { - return s.read.CloseWithError(ErrClosed) + s.readClosed.Store(true) + return nil } func (s *stream) Close() error { @@ -92,15 +167,15 @@ func (s *stream) Close() error { } func (s *stream) Reset() error { - // Cancel any pending reads/writes with an error. - s.write.CloseWithError(network.ErrReset) - s.read.CloseWithError(network.ErrReset) - select { case s.reset <- struct{}{}: default: } - <-s.closed + + s.once.Do(func() { + close(s.localDone) + }) + // No meaningful error case here. return nil } @@ -117,48 +192,11 @@ func (s *stream) SetWriteDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -func (s *stream) writeLoop() { - defer s.teardown() - - for { - // Reset takes precedent. - select { - case <-s.reset: - s.writeErr = network.ErrReset - return - default: - } - - select { - case <-s.reset: - s.writeErr = network.ErrReset - return - case <-s.close: - s.writeErr = s.write.Close() - if s.writeErr == nil { - s.writeErr = ErrClosed - } - return - case p := <-s.writeC: - if _, err := s.write.Write(p); err != nil { - s.cancelWrite(err) - return - } - } - } -} - -func (s *stream) cancelWrite(err error) { - s.write.CloseWithError(err) - s.writeErr = err -} - -func (s *stream) teardown() { - // at this point, no streams are writing. - if s.conn != nil { - s.conn.removeStream(s.id) +func isClosedChan(c <-chan struct{}) bool { + select { + case <-c: + return true + default: + return false } - - // Mark as closed. - close(s.closed) } diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index 1f3e8fb0cd..f154bbbcdd 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -8,62 +8,60 @@ import ( func TestStreamSimpleReadWriteClose(t *testing.T) { t.Parallel() - clientStr, serverStr := newStreamPair() + streamLocal, streamRemote := newStreamPair() // send a foobar from the client - n, err := clientStr.Write([]byte("foobar")) + n, err := streamLocal.Write([]byte("foobar")) require.NoError(t, err) require.Equal(t, 6, n) - require.NoError(t, clientStr.CloseWrite()) + require.NoError(t, streamLocal.CloseWrite()) // writing after closing should error - _, err = clientStr.Write([]byte("foobar")) + _, err = streamLocal.Write([]byte("foobar")) require.Error(t, err) // now read all the data on the server side - b, err := io.ReadAll(serverStr) + b, err := io.ReadAll(streamRemote) require.NoError(t, err) require.Equal(t, []byte("foobar"), b) // reading again should give another io.EOF - n, err = serverStr.Read(make([]byte, 10)) + n, err = streamRemote.Read(make([]byte, 10)) require.Zero(t, n) require.ErrorIs(t, err, io.EOF) // send something back - _, err = serverStr.Write([]byte("lorem ipsum")) + _, err = streamRemote.Write([]byte("lorem ipsum")) require.NoError(t, err) - require.NoError(t, serverStr.CloseWrite()) + require.NoError(t, streamRemote.CloseWrite()) // and read it at the client - b, err = io.ReadAll(clientStr) + b, err = io.ReadAll(streamLocal) require.NoError(t, err) require.Equal(t, []byte("lorem ipsum"), b) // stream is only cleaned up on calling Close or Reset - clientStr.Close() - serverStr.Close() - // Need to call Close for cleanup. Otherwise the FIN_ACK is never read - require.NoError(t, serverStr.Close()) + require.NoError(t, streamLocal.Close()) + require.NoError(t, streamRemote.Close()) } func TestStreamPartialReads(t *testing.T) { t.Parallel() - clientStr, serverStr := newStreamPair() + streamLocal, streamRemote := newStreamPair() - _, err := serverStr.Write([]byte("foobar")) + _, err := streamRemote.Write([]byte("foobar")) require.NoError(t, err) - require.NoError(t, serverStr.CloseWrite()) + require.NoError(t, streamRemote.CloseWrite()) - n, err := clientStr.Read([]byte{}) // empty read + n, err := streamLocal.Read([]byte{}) // empty read require.NoError(t, err) require.Zero(t, n) b := make([]byte, 3) - n, err = clientStr.Read(b) + n, err = streamLocal.Read(b) require.Equal(t, 3, n) require.NoError(t, err) require.Equal(t, []byte("foo"), b) - b, err = io.ReadAll(clientStr) + b, err = io.ReadAll(streamLocal) require.NoError(t, err) require.Equal(t, []byte("bar"), b) } From 35ebf85e4243d3aed12683a8a210f04914fe59df Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Fri, 29 Nov 2024 15:44:02 +0100 Subject: [PATCH 11/12] Update errors when remote stream is closed - Return io.EOF when reading if remote stream is closed - Return network.ErrReset when reading/writing if remote stream is reset - Return io.ErrPipeClosed when reading if remote stream is closed --- p2p/transport/memory/conn.go | 9 +- p2p/transport/memory/conn_test.go | 1 + p2p/transport/memory/stream.go | 148 +++++++++++++++------------- p2p/transport/memory/stream_test.go | 51 ++++++++++ 4 files changed, 140 insertions(+), 69 deletions(-) create mode 100644 p2p/transport/memory/conn_test.go diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index 61c81462cc..2f06fd6d6c 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -2,6 +2,7 @@ package memory import ( "context" + "log" "sync" "sync/atomic" @@ -79,10 +80,11 @@ func (c *conn) IsClosed() bool { func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { sl, sr := newStreamPair() - - c.streamC <- sr sl.conn = c c.addStream(sl.id, sl) + log.Println("opening stream", sl.id, sr.id) + + c.rconn.streamC <- sr return sl, nil } @@ -135,7 +137,8 @@ func (c *conn) removeStream(id int64) { } func (c *conn) teardown() { - for _, s := range c.streams { + for id, s := range c.streams { + log.Println("tearing down stream", id) s.Reset() } diff --git a/p2p/transport/memory/conn_test.go b/p2p/transport/memory/conn_test.go new file mode 100644 index 0000000000..05af74b9e7 --- /dev/null +++ b/p2p/transport/memory/conn_test.go @@ -0,0 +1 @@ +package memory diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index d196aea0fc..ff146ef110 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -6,12 +6,31 @@ import ( "io" "net" "sync" - "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/network" ) +// onceError is an object that will only store an error once. +type onceError struct { + sync.Mutex // guards following + err error +} + +func (a *onceError) Store(err error) { + a.Lock() + defer a.Unlock() + if a.err != nil { + return + } + a.err = err +} +func (a *onceError) Load() error { + a.Lock() + defer a.Unlock() + return a.err +} + // stream implements network.Stream type stream struct { id int64 @@ -26,14 +45,16 @@ type stream struct { // Used by local Write to interact with remote Read. wrTx chan<- []byte - once sync.Once // Protects closing localDone + closeOnce sync.Once // Protects closing localDone localDone chan struct{} remoteDone <-chan struct{} - reset chan struct{} - close chan struct{} - readClosed atomic.Bool - writeClosed atomic.Bool + resetOnce sync.Once // Protects closing localReset + localReset chan struct{} + remoteReset <-chan struct{} + + rerr onceError + werr onceError } var ErrClosed = errors.New("stream closed") @@ -45,42 +66,47 @@ func newStreamPair() (*stream, *stream) { done1 := make(chan struct{}) done2 := make(chan struct{}) - sa := newStream(cb1, cb2, done1, done2) - sb := newStream(cb2, cb1, done2, done1) + reset1 := make(chan struct{}) + reset2 := make(chan struct{}) + + sa := newStream(cb1, cb2, done1, done2, reset1, reset2) + sb := newStream(cb2, cb1, done2, done1, reset2, reset1) return sa, sb } -func newStream(rdRx <-chan []byte, wrTx chan<- []byte, localDone chan struct{}, remoteDone <-chan struct{}) *stream { +func newStream(rdRx <-chan []byte, wrTx chan<- []byte, localDone chan struct{}, remoteDone <-chan struct{}, localReset chan struct{}, remoteReset <-chan struct{}) *stream { s := &stream{ - rdRx: rdRx, - wrTx: wrTx, - buf: new(bytes.Buffer), - localDone: localDone, - remoteDone: remoteDone, - reset: make(chan struct{}, 1), - close: make(chan struct{}, 1), + id: streamCounter.Add(1), + rdRx: rdRx, + wrTx: wrTx, + buf: new(bytes.Buffer), + localDone: localDone, + remoteDone: remoteDone, + localReset: localReset, + remoteReset: remoteReset, } return s } func (p *stream) Write(b []byte) (int, error) { - if p.writeClosed.Load() { - return 0, ErrClosed + if err := p.werr.Load(); err != nil { + return 0, err } - n, err := p.write(b) - if err != nil && err != io.ErrClosedPipe { - err = &net.OpError{Op: "write", Net: "pipe", Err: err} - } - return n, err + return p.write(b) + //if err != nil && err != io.ErrClosedPipe && err != network.ErrReset { + // err = &net.OpError{Op: "write", Net: "pipe", Err: err} + //} + // + //return n, err } func (p *stream) write(b []byte) (n int, err error) { switch { - case isClosedChan(p.localDone): - return 0, io.ErrClosedPipe + case isClosedChan(p.remoteReset): + return 0, network.ErrReset case isClosedChan(p.remoteDone): return 0, io.ErrClosedPipe } @@ -89,91 +115,81 @@ func (p *stream) write(b []byte) (n int, err error) { defer p.wrMu.Unlock() select { - case <-p.close: - return n, ErrClosed - case <-p.reset: - return n, network.ErrReset case p.wrTx <- b: n += len(b) - case <-p.localDone: - return n, io.ErrClosedPipe - case <-p.remoteDone: - return n, io.ErrClosedPipe } return n, nil } func (p *stream) Read(b []byte) (int, error) { - if p.readClosed.Load() { - return 0, ErrClosed + if err := p.rerr.Load(); err != nil { + return 0, err } - n, err := p.read(b) - if err != nil && err != io.EOF && err != io.ErrClosedPipe { - err = &net.OpError{Op: "read", Net: "pipe", Err: err} - } - - return n, err + return p.read(b) + //if err != nil && err != io.EOF && err != io.ErrClosedPipe && err != network.ErrReset { + // err = &net.OpError{Op: "read", Net: "pipe", Err: err} + //} + // + //return n, err } func (p *stream) read(b []byte) (n int, err error) { + var readErr error + switch { - case isClosedChan(p.localDone): - return 0, io.ErrClosedPipe + case isClosedChan(p.remoteReset): + err = network.ErrReset case isClosedChan(p.remoteDone): - return 0, io.EOF + err = io.EOF } select { - case <-p.reset: - return n, network.ErrReset case bw, ok := <-p.rdRx: if !ok { - p.readClosed.Store(true) - return 0, io.EOF + err = io.EOF + p.rerr.Store(err) + return } p.buf.Write(bw) - case <-p.localDone: - return 0, io.ErrClosedPipe - case <-p.remoteDone: - return 0, io.EOF default: - n, err = p.buf.Read(b) + } + + n, readErr = p.buf.Read(b) + if err == nil { + err = readErr } return n, err } func (s *stream) CloseWrite() error { - select { - case s.close <- struct{}{}: - default: - } - - s.writeClosed.Store(true) + s.werr.Store(ErrClosed) return nil } func (s *stream) CloseRead() error { - s.readClosed.Store(true) + s.rerr.Store(ErrClosed) return nil } func (s *stream) Close() error { + s.closeOnce.Do(func() { + close(s.localDone) + }) + _ = s.CloseRead() return s.CloseWrite() } func (s *stream) Reset() error { - select { - case s.reset <- struct{}{}: - default: - } + s.rerr.Store(network.ErrReset) + s.werr.Store(network.ErrReset) - s.once.Do(func() { - close(s.localDone) + s.resetOnce.Do(func() { + close(s.localReset) }) // No meaningful error case here. diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index f154bbbcdd..09e8425993 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -1,9 +1,12 @@ package memory import ( + "errors" + "github.com/libp2p/go-libp2p/core/network" "github.com/stretchr/testify/require" "io" "testing" + "time" ) func TestStreamSimpleReadWriteClose(t *testing.T) { @@ -65,3 +68,51 @@ func TestStreamPartialReads(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte("bar"), b) } + +func TestStreamResets(t *testing.T) { + clientStr, serverStr := newStreamPair() + + // send a foobar from the client + _, err := clientStr.Write([]byte("foobar")) + require.NoError(t, err) + _, err = serverStr.Write([]byte("lorem ipsum")) + require.NoError(t, err) + require.NoError(t, clientStr.Reset()) // resetting resets both directions + // attempting to write more data should result in a reset error + _, err = clientStr.Write([]byte("foobar")) + require.ErrorIs(t, err, network.ErrReset) + // read what the server sent + b, err := io.ReadAll(clientStr) + require.Empty(t, b) + require.ErrorIs(t, err, network.ErrReset) + + // read the data on the server side + b, err = io.ReadAll(serverStr) + require.Equal(t, []byte("foobar"), b) + require.ErrorIs(t, err, network.ErrReset) + require.Eventually(t, func() bool { + _, err := serverStr.Write([]byte("foobar")) + return errors.Is(err, network.ErrReset) + }, time.Second, 50*time.Millisecond) + serverStr.Close() +} + +func TestStreamReadAfterClose(t *testing.T) { + clientStr, serverStr := newStreamPair() + + serverStr.Close() + b := make([]byte, 1) + _, err := clientStr.Read(b) + require.Equal(t, io.EOF, err) + _, err = clientStr.Read(nil) + require.Equal(t, io.EOF, err) + + clientStr, serverStr = newStreamPair() + + serverStr.Reset() + b = make([]byte, 1) + _, err = clientStr.Read(b) + require.ErrorIs(t, err, network.ErrReset) + _, err = clientStr.Read(nil) + require.ErrorIs(t, err, network.ErrReset) +} From 2e3378f0c498face8a86096d05f04e822581fab5 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Fri, 6 Dec 2024 10:52:08 +0100 Subject: [PATCH 12/12] Add read and write timeouts --- p2p/transport/memory/stream.go | 70 ++++++++++++++++++-- p2p/transport/memory/transport_test.go | 89 ++++++++++++++++++++++---- 2 files changed, 143 insertions(+), 16 deletions(-) diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index ff146ef110..3ff5cf9c63 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -4,7 +4,7 @@ import ( "bytes" "errors" "io" - "net" + "os" "sync" "time" @@ -53,6 +53,10 @@ type stream struct { localReset chan struct{} remoteReset <-chan struct{} + mu sync.RWMutex + readDeadline time.Time + writeDeadline time.Time + rerr onceError werr onceError } @@ -111,15 +115,39 @@ func (p *stream) write(b []byte) (n int, err error) { return 0, io.ErrClosedPipe } + p.mu.RLock() + writeDeadline := p.writeDeadline + p.mu.RUnlock() + + if !writeDeadline.IsZero() && time.Now().After(writeDeadline) { + return 0, os.ErrDeadlineExceeded + } + var ( + writeDeadlineTimer *time.Timer + writeDeadlineChan <-chan time.Time + ) + defer func() { + if writeDeadlineTimer != nil { + writeDeadlineTimer.Stop() + } + }() + + if !writeDeadline.IsZero() { + writeDeadlineTimer = time.NewTimer(time.Until(writeDeadline)) + writeDeadlineChan = writeDeadlineTimer.C + } + p.wrMu.Lock() // Ensure entirety of b is written together defer p.wrMu.Unlock() select { + case <-writeDeadlineChan: + err = os.ErrDeadlineExceeded case p.wrTx <- b: n += len(b) } - return n, nil + return n, err } func (p *stream) Read(b []byte) (int, error) { @@ -145,7 +173,32 @@ func (p *stream) read(b []byte) (n int, err error) { err = io.EOF } + p.mu.RLock() + readDeadline := p.readDeadline + p.mu.RUnlock() + + if !readDeadline.IsZero() && time.Now().After(readDeadline) { + return 0, os.ErrDeadlineExceeded + } + + var ( + readDeadlineTimer *time.Timer + readDeadlineChan <-chan time.Time + ) + defer func() { + if readDeadlineTimer != nil { + readDeadlineTimer.Stop() + } + }() + + if !readDeadline.IsZero() { + readDeadlineTimer = time.NewTimer(time.Until(readDeadline)) + readDeadlineChan = readDeadlineTimer.C + } + select { + case <-readDeadlineChan: + err = os.ErrDeadlineExceeded case bw, ok := <-p.rdRx: if !ok { err = io.EOF @@ -197,15 +250,22 @@ func (s *stream) Reset() error { } func (s *stream) SetDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} + _ = s.SetReadDeadline(t) + return s.SetWriteDeadline(t) } func (s *stream) SetReadDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} + s.mu.Lock() + defer s.mu.Unlock() + s.readDeadline = t + return nil } func (s *stream) SetWriteDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} + s.mu.Lock() + defer s.mu.Unlock() + s.writeDeadline = t + return nil } func isClosedChan(c <-chan struct{}) bool { diff --git a/p2p/transport/memory/transport_test.go b/p2p/transport/memory/transport_test.go index f17835f3ff..437576b30d 100644 --- a/p2p/transport/memory/transport_test.go +++ b/p2p/transport/memory/transport_test.go @@ -1,33 +1,36 @@ package memory import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" + "context" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" "io" "testing" - ic "github.com/libp2p/go-libp2p/core/crypto" tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) -func getTransport(t *testing.T) tpt.Transport { +func getTransport(t *testing.T) (tpt.Transport, peer.ID) { t.Helper() - rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + privKey, _, err := crypto.GenerateKeyPair(crypto.Ed25519, -1) require.NoError(t, err) - key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) + rcmgr := &network.NullResourceManager{} require.NoError(t, err) - tr, err := NewTransport(key, nil, nil) + tr, err := NewTransport(privKey, nil, rcmgr) require.NoError(t, err) - return tr + peerID, err := peer.IDFromPrivateKey(privKey) + require.NoError(t, err) + t.Cleanup(func() { rcmgr.Close() }) + return tr, peerID } func TestMemoryProtocol(t *testing.T) { t.Parallel() - tr := getTransport(t) + tr, _ := getTransport(t) defer tr.(io.Closer).Close() protocols := tr.Protocols() @@ -42,7 +45,7 @@ func TestMemoryProtocol(t *testing.T) { func TestCanDial(t *testing.T) { t.Parallel() - tr := getTransport(t) + tr, _ := getTransport(t) defer tr.(io.Closer).Close() invalid := []string{ @@ -70,3 +73,67 @@ func TestCanDial(t *testing.T) { } } } + +func TestTransport_Listen(t *testing.T) { + t.Parallel() + server, _ := getTransport(t) + defer server.(io.Closer).Close() + + addr, err := ma.NewMultiaddr("/memory/1234") + require.NoError(t, err) + serverListener, err := server.Listen(addr) + require.NoError(t, err) + defer serverListener.Close() + lma := serverListener.Multiaddr() + require.Equal(t, addr, lma) +} + +func TestTransport_Dial(t *testing.T) { + t.Parallel() + server, serverPeerID := getTransport(t) + client, clientPeerID := getTransport(t) + defer func() { + if server != nil { + err := server.(io.Closer).Close() + require.NoError(t, err) + } + }() + + defer func() { + if client != nil { + err := client.(io.Closer).Close() + require.NoError(t, err) + } + }() + + serverAddr, err := ma.NewMultiaddr("/memory/1234") + require.NoError(t, err) + serverListener, err := server.Listen(serverAddr) + require.NoError(t, err) + defer func() { + if serverListener != nil { + err = serverListener.Close() + require.NoError(t, err) + } + }() + + c, err := client.Dial(context.Background(), serverAddr, serverPeerID) + require.NoError(t, err) + defer func() { + if c != nil { + err = c.Close() + require.NoError(t, err) + } + }() + + require.Equal(t, serverAddr, c.RemoteMultiaddr()) + require.Equal(t, clientPeerID, c.LocalPeer()) + require.Equal(t, serverPeerID, c.RemotePeer()) + + // Try to dial address with no listener + otherAddr, err := ma.NewMultiaddr("/memory/4321") + require.NoError(t, err) + + _, err = client.Dial(context.Background(), otherAddr, serverPeerID) + require.Error(t, err) +}