Skip to content

Commit

Permalink
Daily commit
Browse files Browse the repository at this point in the history
  • Loading branch information
pyropy committed Nov 22, 2024
1 parent b96f386 commit 5236ff2
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 110 deletions.
27 changes: 20 additions & 7 deletions p2p/transport/memory/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (
)

type conn struct {
id int64
id int64
rconn *conn

transport *transport
scope network.ConnManagementScope
listener *listener
transport *transport

localPeer peer.ID
localMultiaddr ma.Multiaddr
Expand Down Expand Up @@ -62,11 +64,11 @@ func newConnection(
}

func (c *conn) Close() error {
c.closed.Store(true)
for _, s := range c.streams {
//c.removeStream(id)
s.Close()
}
c.closeOnce.Do(func() {
c.closed.Store(true)
go c.rconn.Close()
c.teardown()
})

return nil
}
Expand All @@ -79,11 +81,14 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
sl, sr := newStreamPair()

c.streamC <- sr
sl.conn = c
c.addStream(sl.id, sl)
return sl, nil
}

func (c *conn) AcceptStream() (network.MuxedStream, error) {
in := <-c.streamC
in.conn = c
c.addStream(in.id, in)
return in, nil
}
Expand Down Expand Up @@ -128,3 +133,11 @@ func (c *conn) removeStream(id int64) {

delete(c.streams, id)
}

func (c *conn) teardown() {
for _, s := range c.streams {
s.Reset()
}

// TODO: remove self from listener
}
80 changes: 80 additions & 0 deletions p2p/transport/memory/hub.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package memory

import (
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
"sync"
"sync/atomic"
)

var (
connCounter atomic.Int64
streamCounter atomic.Int64
listenerCounter atomic.Int64
dialMatcher = mafmt.Base(ma.P_MEMORY)
memhub = newHub()
)

type hub struct {
mu sync.RWMutex
closeOnce sync.Once
pubKeys map[peer.ID]ic.PubKey
listeners map[string]*listener
}

func newHub() *hub {
return &hub{
pubKeys: make(map[peer.ID]ic.PubKey),
listeners: make(map[string]*listener),
}
}

func (h *hub) addListener(addr string, l *listener) {
h.mu.Lock()
defer h.mu.Unlock()

h.listeners[addr] = l
}

func (h *hub) removeListener(addr string, l *listener) {
h.mu.Lock()
defer h.mu.Unlock()

delete(h.listeners, addr)
}

func (h *hub) getListener(addr string) (*listener, bool) {
h.mu.RLock()
defer h.mu.RUnlock()

l, ok := h.listeners[addr]
return l, ok
}

func (h *hub) addPubKey(p peer.ID, pk ic.PubKey) {
h.mu.Lock()
defer h.mu.Unlock()

h.pubKeys[p] = pk
}

func (h *hub) getPubKey(p peer.ID) (ic.PubKey, bool) {
h.mu.RLock()
defer h.mu.RUnlock()

pk, ok := h.pubKeys[p]
return pk, ok
}

func (h *hub) close() {
h.closeOnce.Do(func() {
h.mu.Lock()
defer h.mu.Unlock()

for _, l := range h.listeners {
l.Close()
}
})
}
4 changes: 4 additions & 0 deletions p2p/transport/memory/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ const (
)

type listener struct {
id int64

t *transport
ctx context.Context
cancel context.CancelFunc
Expand All @@ -31,6 +33,7 @@ func (l *listener) Multiaddr() ma.Multiaddr {
func newListener(t *transport, laddr ma.Multiaddr) *listener {
ctx, cancel := context.WithCancel(context.Background())
return &listener{
id: listenerCounter.Add(1),
t: t,
ctx: ctx,
cancel: cancel,
Expand All @@ -53,6 +56,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) {
l.mu.Lock()
defer l.mu.Unlock()

c.listener = l
c.transport = l.t
l.connections[c.id] = c
return c, nil
Expand Down
38 changes: 25 additions & 13 deletions p2p/transport/memory/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package memory
import (
"errors"
"io"
"log"
"net"
"time"

Expand All @@ -12,7 +11,8 @@ import (

// stream implements network.Stream
type stream struct {
id int64
id int64
conn *conn

read *io.PipeReader
write *io.PipeWriter
Expand All @@ -31,13 +31,13 @@ func newStreamPair() (*stream, *stream) {
ra, wb := io.Pipe()
rb, wa := io.Pipe()

sa := newStream(rb, wa, network.DirOutbound)
sb := newStream(ra, wb, network.DirInbound)
sa := newStream(wa, ra, network.DirOutbound)
sb := newStream(wb, rb, network.DirInbound)

return sa, sb
}

func newStream(r *io.PipeReader, w *io.PipeWriter, _ network.Direction) *stream {
func newStream(w *io.PipeWriter, r *io.PipeReader, _ network.Direction) *stream {
s := &stream{
id: streamCounter.Add(1),
read: r,
Expand All @@ -47,7 +47,6 @@ func newStream(r *io.PipeReader, w *io.PipeWriter, _ network.Direction) *stream
close: make(chan struct{}, 1),
closed: make(chan struct{}),
}
log.Println("newStream", "id", s.id)

go s.writeLoop()
return s
Expand All @@ -66,7 +65,7 @@ func (s *stream) Write(p []byte) (int, error) {
return len(p), nil
}

func (s *stream) Read(p []byte) (n int, err error) {
func (s *stream) Read(p []byte) (int, error) {
return s.read.Read(p)
}

Expand All @@ -75,9 +74,7 @@ func (s *stream) CloseWrite() error {
case s.close <- struct{}{}:
default:
}
log.Println("waiting close", "id", s.id)
<-s.closed
log.Println("closed write", "id", s.id)
if !errors.Is(s.writeErr, ErrClosed) {
return s.writeErr
}
Expand All @@ -95,6 +92,7 @@ func (s *stream) Close() error {
}

func (s *stream) Reset() error {
// Cancel any pending reads/writes with an error.
s.write.CloseWithError(network.ErrReset)
s.read.CloseWithError(network.ErrReset)

Expand All @@ -103,7 +101,7 @@ func (s *stream) Reset() error {
default:
}
<-s.closed

// No meaningful error case here.
return nil
}

Expand All @@ -120,8 +118,7 @@ func (s *stream) SetWriteDeadline(t time.Time) error {
}

func (s *stream) writeLoop() {
defer close(s.closed)
defer log.Println("closing write", "id", s.id)
defer s.teardown()

for {
// Reset takes precedent.
Expand All @@ -144,9 +141,24 @@ func (s *stream) writeLoop() {
return
case p := <-s.writeC:
if _, err := s.write.Write(p); err != nil {
s.writeErr = err
s.cancelWrite(err)
return
}
}
}
}

func (s *stream) cancelWrite(err error) {
s.write.CloseWithError(err)
s.writeErr = err
}

func (s *stream) teardown() {
// at this point, no streams are writing.
if s.conn != nil {
s.conn.removeStream(s.id)
}

// Mark as closed.
close(s.closed)
}
3 changes: 1 addition & 2 deletions p2p/transport/memory/stream_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package memory

import (
"github.com/stretchr/testify/require"
"io"
"testing"

"github.com/stretchr/testify/require"
)

func TestStreamSimpleReadWriteClose(t *testing.T) {
Expand Down
Loading

0 comments on commit 5236ff2

Please sign in to comment.