diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index df9b08b4b2..61c81462cc 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -13,10 +13,12 @@ import ( ) type conn struct { - id int64 + id int64 + rconn *conn - transport *transport scope network.ConnManagementScope + listener *listener + transport *transport localPeer peer.ID localMultiaddr ma.Multiaddr @@ -62,11 +64,11 @@ func newConnection( } func (c *conn) Close() error { - c.closed.Store(true) - for _, s := range c.streams { - //c.removeStream(id) - s.Close() - } + c.closeOnce.Do(func() { + c.closed.Store(true) + go c.rconn.Close() + c.teardown() + }) return nil } @@ -79,11 +81,14 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { sl, sr := newStreamPair() c.streamC <- sr + sl.conn = c + c.addStream(sl.id, sl) return sl, nil } func (c *conn) AcceptStream() (network.MuxedStream, error) { in := <-c.streamC + in.conn = c c.addStream(in.id, in) return in, nil } @@ -128,3 +133,11 @@ func (c *conn) removeStream(id int64) { delete(c.streams, id) } + +func (c *conn) teardown() { + for _, s := range c.streams { + s.Reset() + } + + // TODO: remove self from listener +} diff --git a/p2p/transport/memory/hub.go b/p2p/transport/memory/hub.go new file mode 100644 index 0000000000..55b85ccbad --- /dev/null +++ b/p2p/transport/memory/hub.go @@ -0,0 +1,80 @@ +package memory + +import ( + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" + "sync" + "sync/atomic" +) + +var ( + connCounter atomic.Int64 + streamCounter atomic.Int64 + listenerCounter atomic.Int64 + dialMatcher = mafmt.Base(ma.P_MEMORY) + memhub = newHub() +) + +type hub struct { + mu sync.RWMutex + closeOnce sync.Once + pubKeys map[peer.ID]ic.PubKey + listeners map[string]*listener +} + +func newHub() *hub { + return &hub{ + pubKeys: make(map[peer.ID]ic.PubKey), + listeners: make(map[string]*listener), + } +} + +func (h *hub) addListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + h.listeners[addr] = l +} + +func (h *hub) removeListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + delete(h.listeners, addr) +} + +func (h *hub) getListener(addr string) (*listener, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + l, ok := h.listeners[addr] + return l, ok +} + +func (h *hub) addPubKey(p peer.ID, pk ic.PubKey) { + h.mu.Lock() + defer h.mu.Unlock() + + h.pubKeys[p] = pk +} + +func (h *hub) getPubKey(p peer.ID) (ic.PubKey, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + pk, ok := h.pubKeys[p] + return pk, ok +} + +func (h *hub) close() { + h.closeOnce.Do(func() { + h.mu.Lock() + defer h.mu.Unlock() + + for _, l := range h.listeners { + l.Close() + } + }) +} diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go index 54417e2a8b..81d5e484d0 100644 --- a/p2p/transport/memory/listener.go +++ b/p2p/transport/memory/listener.go @@ -14,6 +14,8 @@ const ( ) type listener struct { + id int64 + t *transport ctx context.Context cancel context.CancelFunc @@ -31,6 +33,7 @@ func (l *listener) Multiaddr() ma.Multiaddr { func newListener(t *transport, laddr ma.Multiaddr) *listener { ctx, cancel := context.WithCancel(context.Background()) return &listener{ + id: listenerCounter.Add(1), t: t, ctx: ctx, cancel: cancel, @@ -53,6 +56,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) { l.mu.Lock() defer l.mu.Unlock() + c.listener = l c.transport = l.t l.connections[c.id] = c return c, nil diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index a445a712b2..6aa59bfc7b 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -3,7 +3,6 @@ package memory import ( "errors" "io" - "log" "net" "time" @@ -12,7 +11,8 @@ import ( // stream implements network.Stream type stream struct { - id int64 + id int64 + conn *conn read *io.PipeReader write *io.PipeWriter @@ -31,13 +31,13 @@ func newStreamPair() (*stream, *stream) { ra, wb := io.Pipe() rb, wa := io.Pipe() - sa := newStream(rb, wa, network.DirOutbound) - sb := newStream(ra, wb, network.DirInbound) + sa := newStream(wa, ra, network.DirOutbound) + sb := newStream(wb, rb, network.DirInbound) return sa, sb } -func newStream(r *io.PipeReader, w *io.PipeWriter, _ network.Direction) *stream { +func newStream(w *io.PipeWriter, r *io.PipeReader, _ network.Direction) *stream { s := &stream{ id: streamCounter.Add(1), read: r, @@ -47,7 +47,6 @@ func newStream(r *io.PipeReader, w *io.PipeWriter, _ network.Direction) *stream close: make(chan struct{}, 1), closed: make(chan struct{}), } - log.Println("newStream", "id", s.id) go s.writeLoop() return s @@ -66,7 +65,7 @@ func (s *stream) Write(p []byte) (int, error) { return len(p), nil } -func (s *stream) Read(p []byte) (n int, err error) { +func (s *stream) Read(p []byte) (int, error) { return s.read.Read(p) } @@ -75,9 +74,7 @@ func (s *stream) CloseWrite() error { case s.close <- struct{}{}: default: } - log.Println("waiting close", "id", s.id) <-s.closed - log.Println("closed write", "id", s.id) if !errors.Is(s.writeErr, ErrClosed) { return s.writeErr } @@ -95,6 +92,7 @@ func (s *stream) Close() error { } func (s *stream) Reset() error { + // Cancel any pending reads/writes with an error. s.write.CloseWithError(network.ErrReset) s.read.CloseWithError(network.ErrReset) @@ -103,7 +101,7 @@ func (s *stream) Reset() error { default: } <-s.closed - + // No meaningful error case here. return nil } @@ -120,8 +118,7 @@ func (s *stream) SetWriteDeadline(t time.Time) error { } func (s *stream) writeLoop() { - defer close(s.closed) - defer log.Println("closing write", "id", s.id) + defer s.teardown() for { // Reset takes precedent. @@ -144,9 +141,24 @@ func (s *stream) writeLoop() { return case p := <-s.writeC: if _, err := s.write.Write(p); err != nil { - s.writeErr = err + s.cancelWrite(err) return } } } } + +func (s *stream) cancelWrite(err error) { + s.write.CloseWithError(err) + s.writeErr = err +} + +func (s *stream) teardown() { + // at this point, no streams are writing. + if s.conn != nil { + s.conn.removeStream(s.id) + } + + // Mark as closed. + close(s.closed) +} diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index 5ce3939863..1f3e8fb0cd 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -1,10 +1,9 @@ package memory import ( + "github.com/stretchr/testify/require" "io" "testing" - - "github.com/stretchr/testify/require" ) func TestStreamSimpleReadWriteClose(t *testing.T) { diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index 56cde6a3f6..39d152e7d3 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -3,85 +3,15 @@ package memory import ( "context" "errors" - "sync" - "sync/atomic" - ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/pnet" tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" - mafmt "github.com/multiformats/go-multiaddr-fmt" -) - -var ( - connCounter atomic.Int64 - streamCounter atomic.Int64 - listenerCounter atomic.Int64 - dialMatcher = mafmt.Base(ma.P_MEMORY) + "sync" ) -type hub struct { - mu sync.RWMutex - closeOnce sync.Once - pubKeys map[peer.ID]ic.PubKey - listeners map[string]*listener -} - -func (h *hub) addListener(addr string, l *listener) { - h.mu.Lock() - defer h.mu.Unlock() - - h.listeners[addr] = l -} - -func (h *hub) removeListener(addr string, l *listener) { - h.mu.Lock() - defer h.mu.Unlock() - - delete(h.listeners, addr) -} - -func (h *hub) getListener(addr string) (*listener, bool) { - h.mu.RLock() - defer h.mu.RUnlock() - - l, ok := h.listeners[addr] - return l, ok -} - -func (h *hub) addPubKey(p peer.ID, pk ic.PubKey) { - h.mu.Lock() - defer h.mu.Unlock() - - h.pubKeys[p] = pk -} - -func (h *hub) getPubKey(p peer.ID) (ic.PubKey, bool) { - h.mu.RLock() - defer h.mu.RUnlock() - - pk, ok := h.pubKeys[p] - return pk, ok -} - -func (h *hub) close() { - h.closeOnce.Do(func() { - h.mu.Lock() - defer h.mu.Unlock() - - for _, l := range h.listeners { - l.Close() - } - }) -} - -var memhub = &hub{ - listeners: make(map[string]*listener), - pubKeys: make(map[peer.ID]ic.PubKey), -} - type transport struct { psk pnet.PSK rcmgr network.ResourceManager @@ -129,15 +59,12 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return c, nil } -func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { +func (t *transport) dialWithScope(_ context.Context, raddr ma.Multiaddr, rpid peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { if err := scope.SetPeer(rpid); err != nil { return nil, err } - // TODO: Check if there is an existing listener for this address - t.mu.RLock() - defer t.mu.RUnlock() - l, ok := memhub.getListener(raddr.String()) + rl, ok := memhub.getListener(raddr.String()) if !ok { return nil, errors.New("failed to get listener") } @@ -147,14 +74,10 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid return nil, errors.New("failed to get remote public key") } - sl, sr := newStreamPair() - - lconn := newConnection(t, sl, t.localPeerID, nil, remotePubKey, rpid, raddr) - rconn := newConnection(nil, sr, rpid, raddr, t.localPubKey, t.localPeerID, nil) + lc, rc := t.newConnPair(remotePubKey, rpid, raddr) - l.connCh <- rconn - - return lconn, nil + rl.connCh <- rc + return lc, nil } func (t *transport) CanDial(addr ma.Multiaddr) bool { @@ -164,10 +87,6 @@ func (t *transport) CanDial(addr ma.Multiaddr) bool { func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { // TODO: Check if we need to add scope via conn mngr l := newListener(t, laddr) - - t.mu.Lock() - defer t.mu.Unlock() - memhub.addListener(laddr.String(), l) return l, nil @@ -188,7 +107,8 @@ func (t *transport) String() string { func (t *transport) Close() error { // TODO: Go trough all listeners and close them - //memhub.close() + t.mu.Lock() + defer t.mu.Unlock() for _, c := range t.connections { c.Close() @@ -211,3 +131,14 @@ func (t *transport) removeConn(c *conn) { delete(t.connections, c.id) } + +func (t *transport) newConnPair(remotePubKey ic.PubKey, rpid peer.ID, raddr ma.Multiaddr) (*conn, *conn) { + sl, sr := newStreamPair() + + lc := newConnection(t, sl, t.localPeerID, nil, remotePubKey, rpid, raddr) + rc := newConnection(nil, sr, rpid, raddr, t.localPubKey, t.localPeerID, nil) + + lc.rconn = rc + rc.rconn = lc + return lc, rc +}