diff --git a/.github/workflows/release-check.yml b/.github/workflows/release-check.yml index 0b5ff6070f..4d37f9b4a3 100644 --- a/.github/workflows/release-check.yml +++ b/.github/workflows/release-check.yml @@ -16,4 +16,4 @@ concurrency: jobs: release-check: - uses: ipdxco/unified-github-workflows/.github/workflows/release-check.yml@v1.0 + uses: marcopolo/unified-github-workflows/.github/workflows/release-check.yml@e66cb9667a2e1148efda4591e29c56258eaf385b diff --git a/FUNDING.json b/FUNDING.json new file mode 100644 index 0000000000..5952e90cb0 --- /dev/null +++ b/FUNDING.json @@ -0,0 +1,10 @@ +{ + "opRetro": { + "projectId": "0xc71faa1bcb4ceb785816c3f22823377e9e5e7c48649badd9f0a0ce491f20d4b3" + }, + "drips": { + "filecoin": { + "ownedBy": "0x53DCAf729e11022D5b8949946f6601211C662B38" + } + } + } diff --git a/config/config.go b/config/config.go index 07bee93c60..900c06bc30 100644 --- a/config/config.go +++ b/config/config.go @@ -36,7 +36,9 @@ import ( circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" + "github.com/libp2p/go-libp2p/p2p/protocol/identify" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "github.com/prometheus/client_golang/prometheus" @@ -142,6 +144,10 @@ type Config struct { CustomUDPBlackHoleSuccessCounter bool IPv6BlackHoleSuccessCounter *swarm.BlackHoleSuccessCounter CustomIPv6BlackHoleSuccessCounter bool + + UserFxOptions []fx.Option + + ShareTCPListener bool } func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) { @@ -286,6 +292,12 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }), fx.Provide(func() pnet.PSK { return cfg.PSK }), fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }), + fx.Provide(func(gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *tcpreuse.ConnMgr { + if !cfg.ShareTCPListener { + return nil + } + return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, gater, rcmgr) + }), fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn { hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool { quicAddrPorts := map[string]struct{}{} @@ -482,6 +494,9 @@ func (cfg *Config) NewNode() (host.Host, error) { return sw, nil }), fx.Provide(cfg.newBasicHost), + fx.Provide(func(bh *bhost.BasicHost) identify.IDService { + return bh.IDService() + }), fx.Provide(func(bh *bhost.BasicHost) host.Host { return bh }), @@ -536,6 +551,8 @@ func (cfg *Config) NewNode() (host.Host, error) { fxopts = append(fxopts, fx.Invoke(func(bho *routed.RoutedHost) { rh = bho })) } + fxopts = append(fxopts, cfg.UserFxOptions...) + app := fx.New(fxopts...) if err := app.Start(context.Background()); err != nil { return nil, err diff --git a/core/routing/routing.go b/core/routing/routing.go index b995b052b9..bb8de71541 100644 --- a/core/routing/routing.go +++ b/core/routing/routing.go @@ -18,17 +18,18 @@ var ErrNotFound = errors.New("routing: not found") // type/operation. var ErrNotSupported = errors.New("routing: operation or key not supported") -// ContentRouting is a value provider layer of indirection. It is used to find -// information about who has what content. -// -// Content is identified by CID (content identifier), which encodes a hash -// of the identified content in a future-proof manner. -type ContentRouting interface { +// ContentProviding is able to announce where to find content on the Routing +// system. +type ContentProviding interface { // Provide adds the given cid to the content routing system. If 'true' is // passed, it also announces it, otherwise it is just kept in the local // accounting of which objects are being provided. Provide(context.Context, cid.Cid, bool) error +} +// ContentDiscovery is able to retrieve providers for a given CID using +// the Routing system. +type ContentDiscovery interface { // Search for peers who are able to provide a given key // // When count is 0, this method will return an unbounded number of @@ -36,6 +37,16 @@ type ContentRouting interface { FindProvidersAsync(context.Context, cid.Cid, int) <-chan peer.AddrInfo } +// ContentRouting is a value provider layer of indirection. It is used to find +// information about who has what content. +// +// Content is identified by CID (content identifier), which encodes a hash +// of the identified content in a future-proof manner. +type ContentRouting interface { + ContentProviding + ContentDiscovery +} + // PeerRouting is a way to find address information about certain peers. // This can be implemented by a simple lookup table, a tracking server, // or even a DHT. diff --git a/dashboards/prometheus.yml b/dashboards/prometheus.yml index f091718860..89534d55dd 100644 --- a/dashboards/prometheus.yml +++ b/dashboards/prometheus.yml @@ -6,7 +6,7 @@ alerting: alertmanagers: - scheme: http timeout: 10s - api_version: v1 + api_version: v2 static_configs: - targets: [] scrape_configs: diff --git a/funding.json b/funding.json deleted file mode 100644 index f32c67b079..0000000000 --- a/funding.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "opRetro": { - "projectId": "0xc71faa1bcb4ceb785816c3f22823377e9e5e7c48649badd9f0a0ce491f20d4b3" - } - } \ No newline at end of file diff --git a/fx_options_test.go b/fx_options_test.go new file mode 100644 index 0000000000..48ac79b53d --- /dev/null +++ b/fx_options_test.go @@ -0,0 +1,60 @@ +package libp2p + +import ( + "testing" + + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/protocol/identify" + "github.com/stretchr/testify/require" + "go.uber.org/fx" +) + +func TestGetPeerID(t *testing.T) { + var id peer.ID + host, err := New( + WithFxOption(fx.Populate(&id)), + ) + require.NoError(t, err) + defer host.Close() + + require.Equal(t, host.ID(), id) + +} + +func TestGetEventBus(t *testing.T) { + var eb event.Bus + host, err := New( + NoTransports, + WithFxOption(fx.Populate(&eb)), + ) + require.NoError(t, err) + defer host.Close() + + require.NotNil(t, eb) +} + +func TestGetHost(t *testing.T) { + var h host.Host + host, err := New( + NoTransports, + WithFxOption(fx.Populate(&h)), + ) + require.NoError(t, err) + defer host.Close() + + require.NotNil(t, h) +} + +func TestGetIDService(t *testing.T) { + var id identify.IDService + host, err := New( + NoTransports, + WithFxOption(fx.Populate(&id)), + ) + require.NoError(t, err) + defer host.Close() + + require.NotNil(t, id) +} diff --git a/go.mod b/go.mod index d37a6ff09f..38caff5d85 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,7 @@ require ( github.com/multiformats/go-multibase v0.2.0 github.com/multiformats/go-multicodec v0.9.0 github.com/multiformats/go-multihash v0.2.3 - github.com/multiformats/go-multistream v0.5.0 + github.com/multiformats/go-multistream v0.6.0 github.com/multiformats/go-varint v0.0.7 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 github.com/pion/datachannel v1.5.9 @@ -55,7 +55,7 @@ require ( github.com/pion/webrtc/v3 v3.3.4 github.com/prometheus/client_golang v1.20.5 github.com/prometheus/client_model v0.6.1 - github.com/quic-go/quic-go v0.48.0 + github.com/quic-go/quic-go v0.48.1 github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 github.com/raulk/go-watchdog v1.3.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 7b0c780196..082c64848f 100644 --- a/go.sum +++ b/go.sum @@ -246,8 +246,8 @@ github.com/multiformats/go-multicodec v0.9.0/go.mod h1:L3QTQvMIaVBkXOXXtVmYE+LI1 github.com/multiformats/go-multihash v0.0.8/go.mod h1:YSLudS+Pi8NHE7o6tb3D8vrpKa63epEDmG8nTduyAew= github.com/multiformats/go-multihash v0.2.3 h1:7Lyc8XfX/IY2jWb/gI7JP+o7JEq9hOa7BFvVU9RSh+U= github.com/multiformats/go-multihash v0.2.3/go.mod h1:dXgKXCXjBzdscBLk9JkjINiEsCKRVch90MdaGiKsvSM= -github.com/multiformats/go-multistream v0.5.0 h1:5htLSLl7lvJk3xx3qT/8Zm9J4K8vEOf/QGkvOGQAyiE= -github.com/multiformats/go-multistream v0.5.0/go.mod h1:n6tMZiwiP2wUsR8DgfDWw1dydlEqV3l6N3/GBsX6ILA= +github.com/multiformats/go-multistream v0.6.0 h1:ZaHKbsL404720283o4c/IHQXiS6gb8qAN5EIJ4PN5EA= +github.com/multiformats/go-multistream v0.6.0/go.mod h1:MOyoG5otO24cHIg8kf9QW2/NozURlkP/rvi2FQJyCPg= github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8= github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -333,8 +333,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.48.0 h1:2TCyvBrMu1Z25rvIAlnp2dPT4lgh/uTqLqiXVpp5AeU= -github.com/quic-go/quic-go v0.48.0/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= +github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA= +github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 h1:4WFk6u3sOT6pLa1kQ50ZVdm8BQFgJNA117cepZxtLIg= github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66/go.mod h1:Vp72IJajgeOL6ddqrAhmp7IM9zbTcgkQxD/YdxrVwMw= github.com/raulk/go-watchdog v1.3.0 h1:oUmdlHxdkXRJlwfG0O9omj8ukerm8MEQavSiDTEtBsk= diff --git a/libp2p_test.go b/libp2p_test.go index b290227fc1..3de82946d8 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -59,7 +59,7 @@ func TestTransportConstructor(t *testing.T) { _ connmgr.ConnectionGater, upgrader transport.Upgrader, ) transport.Transport { - tpt, err := tcp.NewTCPTransport(upgrader, nil) + tpt, err := tcp.NewTCPTransport(upgrader, nil, nil) require.NoError(t, err) return tpt } @@ -751,3 +751,27 @@ func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config { }}, } } + +func TestSharedTCPAddr(t *testing.T) { + h, err := New( + ShareTCPListener(), + Transport(tcp.NewTCPTransport), + Transport(websocket.New), + ListenAddrStrings("/ip4/0.0.0.0/tcp/8888"), + ListenAddrStrings("/ip4/0.0.0.0/tcp/8888/ws"), + ) + require.NoError(t, err) + sawTCP := false + sawWS := false + for _, addr := range h.Addrs() { + if strings.HasSuffix(addr.String(), "/tcp/8888") { + sawTCP = true + } + if strings.HasSuffix(addr.String(), "/tcp/8888/ws") { + sawWS = true + } + } + require.True(t, sawTCP) + require.True(t, sawWS) + h.Close() +} diff --git a/options.go b/options.go index 8e720137c0..0329b7e60b 100644 --- a/options.go +++ b/options.go @@ -634,3 +634,24 @@ func IPv6BlackHoleSuccessCounter(f *swarm.BlackHoleSuccessCounter) Option { return nil } } + +// WithFxOption adds a user provided fx.Option to the libp2p constructor. +// Experimental: This option is subject to change or removal. +func WithFxOption(opts ...fx.Option) Option { + return func(cfg *Config) error { + cfg.UserFxOptions = append(cfg.UserFxOptions, opts...) + return nil + } +} + +// ShareTCPListener shares the same listen address between TCP and Websocket +// transports. This lets both transports use the same TCP port. +// +// Currently this behavior is Opt-in. In a future release this will be the +// default, and this option will be removed. +func ShareTCPListener() Option { + return func(cfg *Config) error { + cfg.ShareTCPListener = true + return nil + } +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index f7d3c5275a..820411bd27 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -122,9 +122,10 @@ type HostOpts struct { // MultistreamMuxer is essential for the *BasicHost and will use a sensible default value if omitted. MultistreamMuxer *msmux.MultistreamMuxer[protocol.ID] - // NegotiationTimeout determines the read and write timeouts on streams. - // If 0 or omitted, it will use DefaultNegotiationTimeout. - // If below 0, timeouts on streams will be deactivated. + // NegotiationTimeout determines the read and write timeouts when negotiating + // protocols for streams. If 0 or omitted, it will use + // DefaultNegotiationTimeout. If below 0, timeouts on streams will be + // deactivated. NegotiationTimeout time.Duration // AddrsFactory holds a function which can be used to override or filter the result of Addrs. @@ -689,6 +690,14 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) { // to create one. If ProtocolID is "", writes no header. // (Thread-safe) func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (str network.Stream, strErr error) { + if _, ok := ctx.Deadline(); !ok { + if h.negtimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, h.negtimeout) + defer cancel() + } + } + // If the caller wants to prevent the host from dialing, it should use the NoDial option. if nodial, _ := network.GetNoDial(ctx); !nodial { err := h.Connect(ctx, peer.AddrInfo{ID: p}) diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 1ab98aae9d..2a7a772976 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -2,6 +2,7 @@ package basichost import ( "context" + "encoding/binary" "fmt" "io" "reflect" @@ -941,3 +942,56 @@ func TestTrimHostAddrList(t *testing.T) { }) } } + +func TestHostTimeoutNewStream(t *testing.T) { + h1, err := NewHost(swarmt.GenSwarm(t), nil) + require.NoError(t, err) + h1.Start() + defer h1.Close() + + const proto = "/testing" + h2 := swarmt.GenSwarm(t) + + h2.SetStreamHandler(func(s network.Stream) { + // First message is multistream header. Just echo it + msHeader := []byte("\x19/multistream/1.0.0\n") + _, err := s.Read(msHeader) + assert.NoError(t, err) + _, err = s.Write(msHeader) + assert.NoError(t, err) + + buf := make([]byte, 1024) + n, err := s.Read(buf) + assert.NoError(t, err) + + msgLen, varintN := binary.Uvarint(buf[:n]) + buf = buf[varintN:] + proto := buf[:int(msgLen)] + if string(proto) == "/ipfs/id/1.0.0\n" { + // Signal we don't support identify + na := []byte("na\n") + n := binary.PutUvarint(buf, uint64(len(na))) + copy(buf[n:], na) + + _, err = s.Write(buf[:int(n)+len(na)]) + assert.NoError(t, err) + } else { + // Stall + time.Sleep(5 * time.Second) + } + t.Log("Resetting") + s.Reset() + }) + + err = h1.Connect(context.Background(), peer.AddrInfo{ + ID: h2.LocalPeer(), + Addrs: h2.ListenAddresses(), + }) + require.NoError(t, err) + + // No context passed in, fallback to negtimeout + h1.negtimeout = time.Second + _, err = h1.NewStream(context.Background(), h2.LocalPeer(), proto) + require.Error(t, err) + require.ErrorContains(t, err, "context deadline exceeded") +} diff --git a/p2p/host/eventbus/basic.go b/p2p/host/eventbus/basic.go index 42365a7916..af6a74bd02 100644 --- a/p2p/host/eventbus/basic.go +++ b/p2p/host/eventbus/basic.go @@ -6,10 +6,20 @@ import ( "reflect" "sync" "sync/atomic" + "time" + logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/event" ) +type logInterface interface { + Errorf(string, ...interface{}) +} + +var log logInterface = logging.Logger("eventbus") + +const slowConsumerWarningTimeout = time.Second + // ///////////////////// // BUS @@ -116,6 +126,7 @@ type wildcardSub struct { w *wildcardNode metricsTracer MetricsTracer name string + closeOnce sync.Once } func (w *wildcardSub) Out() <-chan interface{} { @@ -123,10 +134,13 @@ func (w *wildcardSub) Out() <-chan interface{} { } func (w *wildcardSub) Close() error { - w.w.removeSink(w.ch) - if w.metricsTracer != nil { - w.metricsTracer.RemoveSubscriber(reflect.TypeOf(event.WildcardSubscription)) - } + w.closeOnce.Do(func() { + w.w.removeSink(w.ch) + if w.metricsTracer != nil { + w.metricsTracer.RemoveSubscriber(reflect.TypeOf(event.WildcardSubscription)) + } + }) + return nil } @@ -145,6 +159,7 @@ type sub struct { dropper func(reflect.Type) metricsTracer MetricsTracer name string + closeOnce sync.Once } func (s *sub) Name() string { @@ -162,31 +177,32 @@ func (s *sub) Close() error { for range s.ch { } }() - - for _, n := range s.nodes { - n.lk.Lock() - - for i := 0; i < len(n.sinks); i++ { - if n.sinks[i].ch == s.ch { - n.sinks[i], n.sinks[len(n.sinks)-1] = n.sinks[len(n.sinks)-1], nil - n.sinks = n.sinks[:len(n.sinks)-1] - - if s.metricsTracer != nil { - s.metricsTracer.RemoveSubscriber(n.typ) + s.closeOnce.Do(func() { + for _, n := range s.nodes { + n.lk.Lock() + + for i := 0; i < len(n.sinks); i++ { + if n.sinks[i].ch == s.ch { + n.sinks[i], n.sinks[len(n.sinks)-1] = n.sinks[len(n.sinks)-1], nil + n.sinks = n.sinks[:len(n.sinks)-1] + + if s.metricsTracer != nil { + s.metricsTracer.RemoveSubscriber(n.typ) + } + break } - break } - } - tryDrop := len(n.sinks) == 0 && n.nEmitters.Load() == 0 + tryDrop := len(n.sinks) == 0 && n.nEmitters.Load() == 0 - n.lk.Unlock() + n.lk.Unlock() - if tryDrop { - s.dropper(n.typ) + if tryDrop { + s.dropper(n.typ) + } } - } - close(s.ch) + close(s.ch) + }) return nil } @@ -322,6 +338,8 @@ type wildcardNode struct { nSinks atomic.Int32 sinks []*namedSink metricsTracer MetricsTracer + + slowConsumerTimer *time.Timer } func (n *wildcardNode) addSink(sink *namedSink) { @@ -336,6 +354,12 @@ func (n *wildcardNode) addSink(sink *namedSink) { } func (n *wildcardNode) removeSink(ch chan interface{}) { + go func() { + // drain the event channel, will return when closed and drained. + // this is necessary to unblock publishes to this channel. + for range ch { + } + }() n.nSinks.Add(-1) // ok to do outside the lock n.Lock() for i := 0; i < len(n.sinks); i++ { @@ -348,6 +372,8 @@ func (n *wildcardNode) removeSink(ch chan interface{}) { n.Unlock() } +var wildcardType = reflect.TypeOf(event.WildcardSubscription) + func (n *wildcardNode) emit(evt interface{}) { if n.nSinks.Load() == 0 { return @@ -360,7 +386,16 @@ func (n *wildcardNode) emit(evt interface{}) { // record channel full events before blocking sendSubscriberMetrics(n.metricsTracer, sink) - sink.ch <- evt + select { + case sink.ch <- evt: + default: + slowConsumerTimer := emitAndLogError(n.slowConsumerTimer, wildcardType, evt, sink) + defer func() { + n.Lock() + n.slowConsumerTimer = slowConsumerTimer + n.Unlock() + }() + } } n.RUnlock() } @@ -379,6 +414,8 @@ type node struct { sinks []*namedSink metricsTracer MetricsTracer + + slowConsumerTimer *time.Timer } func newNode(typ reflect.Type, metricsTracer MetricsTracer) *node { @@ -404,11 +441,37 @@ func (n *node) emit(evt interface{}) { // Sending metrics before sending on channel allows us to // record channel full events before blocking sendSubscriberMetrics(n.metricsTracer, sink) - sink.ch <- evt + select { + case sink.ch <- evt: + default: + n.slowConsumerTimer = emitAndLogError(n.slowConsumerTimer, n.typ, evt, sink) + } } n.lk.Unlock() } +func emitAndLogError(timer *time.Timer, typ reflect.Type, evt interface{}, sink *namedSink) *time.Timer { + // Slow consumer. Log a warning if stalled for the timeout + if timer == nil { + timer = time.NewTimer(slowConsumerWarningTimeout) + } else { + timer.Reset(slowConsumerWarningTimeout) + } + + select { + case sink.ch <- evt: + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + log.Errorf("subscriber named \"%s\" is a slow consumer of %s. This can lead to libp2p stalling and hard to debug issues.", sink.name, typ) + // Continue to stall since there's nothing else we can do. + sink.ch <- evt + } + + return timer +} + func sendSubscriberMetrics(metricsTracer MetricsTracer, sink *namedSink) { if metricsTracer != nil { metricsTracer.SubscriberQueueLength(sink.name, len(sink.ch)+1) diff --git a/p2p/host/eventbus/basic_test.go b/p2p/host/eventbus/basic_test.go index 57362ce9b7..530a46ff19 100644 --- a/p2p/host/eventbus/basic_test.go +++ b/p2p/host/eventbus/basic_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "strings" "sync" "sync/atomic" "testing" @@ -13,6 +14,7 @@ import ( "github.com/libp2p/go-libp2p-testing/race" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -131,6 +133,85 @@ func TestEmitNoSubNoBlock(t *testing.T) { em.Emit(EventA{}) } +type mockLogger struct { + mu sync.Mutex + logs []string +} + +func (m *mockLogger) Errorf(format string, args ...interface{}) { + m.mu.Lock() + defer m.mu.Unlock() + m.logs = append(m.logs, fmt.Sprintf(format, args...)) +} + +func (m *mockLogger) Logs() []string { + m.mu.Lock() + defer m.mu.Unlock() + return m.logs +} + +func (m *mockLogger) Clear() { + m.mu.Lock() + defer m.mu.Unlock() + m.logs = nil +} + +func TestEmitLogsErrorOnStall(t *testing.T) { + oldLogger := log + defer func() { + log = oldLogger + }() + ml := &mockLogger{} + log = ml + + bus1 := NewBus() + bus2 := NewBus() + + eventSub, err := bus1.Subscribe(new(EventA)) + if err != nil { + t.Fatal(err) + } + + wildcardSub, err := bus2.Subscribe(event.WildcardSubscription) + if err != nil { + t.Fatal(err) + } + + testCases := []event.Subscription{eventSub, wildcardSub} + eventBuses := []event.Bus{bus1, bus2} + + for i, sub := range testCases { + bus := eventBuses[i] + em, err := bus.Emitter(new(EventA)) + if err != nil { + t.Fatal(err) + } + defer em.Close() + + go func() { + for i := 0; i < subSettingsDefault.buffer+2; i++ { + em.Emit(EventA{}) + } + }() + + require.EventuallyWithT(t, func(collect *assert.CollectT) { + logs := ml.Logs() + found := false + for _, log := range logs { + if strings.Contains(log, "slow consumer") { + found = true + break + } + } + assert.True(collect, found, "expected to find slow consumer log") + }, 3*time.Second, 500*time.Millisecond) + ml.Clear() + + // Close the subscriber so the worker can finish. + sub.Close() + } +} + func TestEmitOnClosed(t *testing.T) { bus := NewBus() @@ -313,10 +394,13 @@ func TestManyWildcardSubscriptions(t *testing.T) { require.NoError(t, em1.Emit(EventA{})) require.NoError(t, em2.Emit(EventB(1))) - // the first five still have 2 events, while the other five have 4 events. - for _, s := range subs[:5] { - require.Len(t, s.Out(), 2) - } + // the first five 0 events because it was closed. The other five + // have 4 events. + require.EventuallyWithT(t, func(t *assert.CollectT) { + for _, s := range subs[:5] { + require.Len(t, s.Out(), 0, "expected closed subscription to have flushed events") + } + }, 2*time.Second, 100*time.Millisecond) for _, s := range subs[5:] { require.Len(t, s.Out(), 4) @@ -326,6 +410,10 @@ func TestManyWildcardSubscriptions(t *testing.T) { for _, s := range subs { require.NoError(t, s.Close()) } + + for _, s := range subs { + require.Zero(t, s.(*wildcardSub).w.nSinks.Load()) + } } func TestWildcardValidations(t *testing.T) { @@ -481,6 +569,17 @@ func TestSubFailFully(t *testing.T) { } } +func TestSubCloseMultiple(t *testing.T) { + bus := NewBus() + + sub, err := bus.Subscribe([]interface{}{new(EventB)}) + require.NoError(t, err) + err = sub.Close() + require.NoError(t, err) + err = sub.Close() + require.NoError(t, err) +} + func testMany(t testing.TB, subs, emits, msgs int, stateful bool) { if race.WithRace() && subs+emits > 5000 { t.SkipNow() diff --git a/p2p/http/example_test.go b/p2p/http/example_test.go index 7073b7f0e3..8e94f6e7e0 100644 --- a/p2p/http/example_test.go +++ b/p2p/http/example_test.go @@ -6,6 +6,7 @@ import ( "log" "net" "net/http" + "regexp" "strings" "github.com/libp2p/go-libp2p" @@ -125,18 +126,24 @@ func ExampleHost_overLibp2pStreams() { // Output: Hello HTTP } +var tcpPortRE = regexp.MustCompile(`/tcp/(\d+)`) + func ExampleHost_Serve() { server := libp2phttp.Host{ InsecureAllowHTTP: true, // For our example, we'll allow insecure HTTP - ListenAddrs: []ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/50221/http")}, + ListenAddrs: []ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/0/http")}, } go server.Serve() defer server.Close() - fmt.Println(server.Addrs()) + for _, a := range server.Addrs() { + s := a.String() + addrWithoutSpecificPort := tcpPortRE.ReplaceAllString(s, "/tcp/") + fmt.Println(addrWithoutSpecificPort) + } - // Output: [/ip4/127.0.0.1/tcp/50221/http] + // Output: /ip4/127.0.0.1/tcp//http } func ExampleHost_SetHTTPHandler() { diff --git a/p2p/http/libp2phttp.go b/p2p/http/libp2phttp.go index d2a544d224..4dad47dd43 100644 --- a/p2p/http/libp2phttp.go +++ b/p2p/http/libp2phttp.go @@ -359,6 +359,7 @@ func (h *Host) Serve() error { expectedErrCount := len(h.httpTransport.listeners) select { case <-h.httpTransport.closeListeners: + err = http.ErrServerClosed case err = <-errCh: expectedErrCount-- } diff --git a/p2p/http/libp2phttp_test.go b/p2p/http/libp2phttp_test.go index fa8687e308..ae3285b9a3 100644 --- a/p2p/http/libp2phttp_test.go +++ b/p2p/http/libp2phttp_test.go @@ -30,6 +30,7 @@ import ( httpping "github.com/libp2p/go-libp2p/p2p/http/ping" libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic" ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -996,3 +997,20 @@ func TestImpliedHostIsSet(t *testing.T) { } } + +func TestErrServerClosed(t *testing.T) { + server := libp2phttp.Host{ + InsecureAllowHTTP: true, + ListenAddrs: []ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/0/http")}, + } + + done := make(chan struct{}) + go func() { + err := server.Serve() + assert.Equal(t, http.ErrServerClosed, err) + close(done) + }() + + server.Close() + <-done +} diff --git a/p2p/net/pnet/psk_conn.go b/p2p/net/pnet/psk_conn.go index c600c8d093..b36d434904 100644 --- a/p2p/net/pnet/psk_conn.go +++ b/p2p/net/pnet/psk_conn.go @@ -3,6 +3,7 @@ package pnet import ( "crypto/cipher" "crypto/rand" + "fmt" "io" "net" @@ -33,7 +34,7 @@ func (c *pskConn) Read(out []byte) (int, error) { nonce := make([]byte, 24) _, err := io.ReadFull(c.Conn, nonce) if err != nil { - return 0, errShortNonce + return 0, fmt.Errorf("%w: %w", errShortNonce, err) } c.readS20 = salsa20.New(c.psk, nonce) } diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index ed4f00ff58..d264fd1230 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -84,7 +84,7 @@ func makeSwarmWithNoListenAddrs(t *testing.T, opts ...Option) *Swarm { upgrader := makeUpgrader(t, s) var tcpOpts []tcp.Option tcpOpts = append(tcpOpts, tcp.DisableReuseport()) - tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) + tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...) require.NoError(t, err) if err := s.AddTransport(tcpTransport); err != nil { t.Fatal(err) diff --git a/p2p/net/swarm/swarm_addr_test.go b/p2p/net/swarm/swarm_addr_test.go index 435866e920..43e76716e5 100644 --- a/p2p/net/swarm/swarm_addr_test.go +++ b/p2p/net/swarm/swarm_addr_test.go @@ -79,7 +79,7 @@ func TestDialAddressSelection(t *testing.T) { s, err := swarm.NewSwarm("local", nil, eventbus.NewBus()) require.NoError(t, err) - tcpTr, err := tcp.NewTCPTransport(nil, nil) + tcpTr, err := tcp.NewTCPTransport(nil, nil, nil) require.NoError(t, err) require.NoError(t, s.AddTransport(tcpTr)) reuse, err := quicreuse.NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{}) diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 0d5020fa04..35fc567fd8 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -627,7 +627,7 @@ func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, updC // Trust the transport? Yeah... right. if connC.RemotePeer() != p { connC.Close() - err = fmt.Errorf("BUG in transport %T: tried to dial %s, dialed %s", p, connC.RemotePeer(), tpt) + err = fmt.Errorf("BUG in transport %T: tried to dial %s, dialed %s", tpt, p, connC.RemotePeer()) log.Error(err) return nil, err } diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index 0ef43cf62e..add6f5cbba 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -53,7 +53,7 @@ func TestAddrsForDial(t *testing.T) { ps.AddPrivKey(id, priv) t.Cleanup(func() { ps.Close() }) - tpt, err := websocket.New(nil, &network.NullResourceManager{}) + tpt, err := websocket.New(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(ResolverFromMaDNS{resolver})) require.NoError(t, err) @@ -100,7 +100,7 @@ func TestDedupAddrsForDial(t *testing.T) { require.NoError(t, err) defer s.Close() - tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}) + tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(tpt) require.NoError(t, err) @@ -134,7 +134,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { }) // Add a tcp transport so that we know we can dial a tcp multiaddr and we don't filter it out. - tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}) + tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(tpt) require.NoError(t, err) @@ -151,7 +151,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { err = s.AddTransport(wtTpt) require.NoError(t, err) - wsTpt, err := websocket.New(nil, &network.NullResourceManager{}) + wsTpt, err := websocket.New(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(wsTpt) require.NoError(t, err) diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index 2bbe8b27a5..773314a1b8 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -164,7 +164,7 @@ func GenSwarm(t testing.TB, opts ...Option) *swarm.Swarm { if cfg.disableReuseport { tcpOpts = append(tcpOpts, tcp.DisableReuseport()) } - tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) + tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...) require.NoError(t, err) if err := s.AddTransport(tcpTransport); err != nil { t.Fatal(err) diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index 8af2791b36..c2e81d2e93 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -84,23 +84,33 @@ func (l *listener) handleIncoming() { } catcher.Reset() - // gate the connection if applicable - if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { - log.Debugf("gater blocked incoming connection on local addr %s from %s", - maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) - if err := maconn.Close(); err != nil { - log.Warnf("failed to close incoming connection rejected by gater: %s", err) - } - continue + // Check if we already have a connection scope. See the comment in tcpreuse/listener.go for an explanation. + var connScope network.ConnManagementScope + if sc, ok := maconn.(interface { + Scope() network.ConnManagementScope + }); ok { + connScope = sc.Scope() } + if connScope == nil { + // gate the connection if applicable + if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { + log.Debugf("gater blocked incoming connection on local addr %s from %s", + maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) + if err := maconn.Close(); err != nil { + log.Warnf("failed to close incoming connection rejected by gater: %s", err) + } + continue + } - connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr()) - if err != nil { - log.Debugw("resource manager blocked accept of new connection", "error", err) - if err := maconn.Close(); err != nil { - log.Warnf("failed to incoming connection rejected by resource manager: %s", err) + var err error + connScope, err = l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr()) + if err != nil { + log.Debugw("resource manager blocked accept of new connection", "error", err) + if err := maconn.Close(); err != nil { + log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err) + } + continue } - continue } // The go routine below calls Release when the context is diff --git a/p2p/protocol/circuitv2/relay/relay.go b/p2p/protocol/circuitv2/relay/relay.go index 2ee237d97b..326d17781b 100644 --- a/p2p/protocol/circuitv2/relay/relay.go +++ b/p2p/protocol/circuitv2/relay/relay.go @@ -118,7 +118,7 @@ func (r *Relay) Close() error { r.host.RemoveStreamHandler(proto.ProtoIDv2Hop) r.host.Network().StopNotify(r.notifiee) - r.scope.Done() + defer r.scope.Done() r.cancel() r.gc() if r.metricsTracer != nil { @@ -315,7 +315,7 @@ func (r *Relay) handleConnect(s network.Stream, msg *pbv2.HopMessage) pbv2.Statu connStTime := time.Now() cleanup := func() { - span.Done() + defer span.Done() r.mx.Lock() r.rmConn(src) r.rmConn(dest.ID) diff --git a/p2p/protocol/circuitv2/relay/relay_test.go b/p2p/protocol/circuitv2/relay/relay_test.go index e5d32b0c96..f6b63e32de 100644 --- a/p2p/protocol/circuitv2/relay/relay_test.go +++ b/p2p/protocol/circuitv2/relay/relay_test.go @@ -60,7 +60,7 @@ func getNetHosts(t *testing.T, ctx context.Context, n int) (hosts []host.Host, u upgrader := swarmt.GenUpgrader(t, netw, nil) upgraders = append(upgraders, upgrader) - tpt, err := tcp.NewTCPTransport(upgrader, nil) + tpt, err := tcp.NewTCPTransport(upgrader, nil, nil) if err != nil { t.Fatal(err) } diff --git a/p2p/protocol/holepunch/tracer.go b/p2p/protocol/holepunch/tracer.go index 82e0ebfc0f..3ba06f653d 100644 --- a/p2p/protocol/holepunch/tracer.go +++ b/p2p/protocol/holepunch/tracer.go @@ -20,13 +20,10 @@ const ( func WithTracer(et EventTracer) Option { return func(hps *Service) error { hps.tracer = &tracer{ - et: et, - mt: nil, - self: hps.host.ID(), - peers: make(map[peer.ID]struct { - counter int - last time.Time - }), + et: et, + mt: nil, + self: hps.host.ID(), + peers: make(map[peer.ID]peerInfo), } return nil } @@ -36,13 +33,10 @@ func WithTracer(et EventTracer) Option { func WithMetricsTracer(mt MetricsTracer) Option { return func(hps *Service) error { hps.tracer = &tracer{ - et: nil, - mt: mt, - self: hps.host.ID(), - peers: make(map[peer.ID]struct { - counter int - last time.Time - }), + et: nil, + mt: mt, + self: hps.host.ID(), + peers: make(map[peer.ID]peerInfo), } return nil } @@ -52,13 +46,10 @@ func WithMetricsTracer(mt MetricsTracer) Option { func WithMetricsAndEventTracer(mt MetricsTracer, et EventTracer) Option { return func(hps *Service) error { hps.tracer = &tracer{ - et: et, - mt: mt, - self: hps.host.ID(), - peers: make(map[peer.ID]struct { - counter int - last time.Time - }), + et: et, + mt: mt, + self: hps.host.ID(), + peers: make(map[peer.ID]peerInfo), } return nil } @@ -74,10 +65,12 @@ type tracer struct { ctxCancel context.CancelFunc mutex sync.Mutex - peers map[peer.ID]struct { - counter int - last time.Time - } + peers map[peer.ID]peerInfo +} + +type peerInfo struct { + counter int + last time.Time } type EventTracer interface { diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index 1733a4166c..b6d5240ba6 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -347,7 +347,8 @@ func (ids *idService) sendPushes(ctx context.Context) { defer func() { <-sem }() ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - str, err := ids.Host.NewStream(ctx, c.RemotePeer(), IDPush) + + str, err := newStreamAndNegotiate(ctx, c, IDPush) if err != nil { // connection might have been closed recently return } @@ -437,25 +438,38 @@ func (ids *idService) IdentifyWait(c network.Conn) <-chan struct{} { return e.IdentifyWaitChan } -func (ids *idService) identifyConn(c network.Conn) error { - ctx, cancel := context.WithTimeout(context.Background(), Timeout) - defer cancel() +// newStreamAndNegotiate opens a new stream on the given connection and negotiates the given protocol. +func newStreamAndNegotiate(ctx context.Context, c network.Conn, proto protocol.ID) (network.Stream, error) { s, err := c.NewStream(network.WithAllowLimitedConn(ctx, "identify")) if err != nil { log.Debugw("error opening identify stream", "peer", c.RemotePeer(), "error", err) - return err + return nil, err + } + err = s.SetDeadline(time.Now().Add(Timeout)) + if err != nil { + return nil, err } - s.SetDeadline(time.Now().Add(Timeout)) - if err := s.SetProtocol(ID); err != nil { + if err := s.SetProtocol(proto); err != nil { log.Warnf("error setting identify protocol for stream: %s", err) - s.Reset() + _ = s.Reset() } // ok give the response to our handler. - if err := msmux.SelectProtoOrFail(ID, s); err != nil { + if err := msmux.SelectProtoOrFail(proto, s); err != nil { log.Infow("failed negotiate identify protocol with peer", "peer", c.RemotePeer(), "error", err) - s.Reset() + _ = s.Reset() + return nil, err + } + return s, nil +} + +func (ids *idService) identifyConn(c network.Conn) error { + ctx, cancel := context.WithTimeout(context.Background(), Timeout) + defer cancel() + s, err := newStreamAndNegotiate(network.WithAllowLimitedConn(ctx, "identify"), c, ID) + if err != nil { + log.Debugw("error opening identify stream", "peer", c.RemotePeer(), "error", err) return err } diff --git a/p2p/protocol/identify/obsaddr.go b/p2p/protocol/identify/obsaddr.go index ffe60345e1..06e54bf5fd 100644 --- a/p2p/protocol/identify/obsaddr.go +++ b/p2p/protocol/identify/obsaddr.go @@ -335,6 +335,11 @@ func (o *ObservedAddrManager) worker() { } } +func isRelayedAddress(a ma.Multiaddr) bool { + _, err := a.ValueForProtocol(ma.P_CIRCUIT) + return err == nil +} + func (o *ObservedAddrManager) shouldRecordObservation(conn connMultiaddrs, observed ma.Multiaddr) (shouldRecord bool, localTW thinWaist, observedTW thinWaist) { if conn == nil || observed == nil { return false, thinWaist{}, thinWaist{} @@ -350,6 +355,12 @@ func (o *ObservedAddrManager) shouldRecordObservation(conn connMultiaddrs, obser return false, thinWaist{}, thinWaist{} } + // Ignore p2p-circuit addresses. These are the observed address of the relay. + // Not useful for us. + if isRelayedAddress(observed) { + return false, thinWaist{}, thinWaist{} + } + // we should only use ObservedAddr when our connection's LocalAddr is one // of our ListenAddrs. If we Dial out using an ephemeral addr, knowing that // address's external mapping is not very useful because the port will not be @@ -410,7 +421,7 @@ func (o *ObservedAddrManager) maybeRecordObservation(conn connMultiaddrs, observ if !shouldRecord { return } - log.Debugw("added own observed listen addr", "observed", observed) + log.Debugw("added own observed listen addr", "conn", conn, "observed", observed) o.mu.Lock() defer o.mu.Unlock() diff --git a/p2p/protocol/identify/obsaddr_glass_test.go b/p2p/protocol/identify/obsaddr_glass_test.go index 31fd4f5726..3211aa5f54 100644 --- a/p2p/protocol/identify/obsaddr_glass_test.go +++ b/p2p/protocol/identify/obsaddr_glass_test.go @@ -53,6 +53,24 @@ func TestShouldRecordObservationWithWebTransport(t *testing.T) { require.True(t, shouldRecord) } +func TestShouldNotRecordObservationWithRelayedAddr(t *testing.T) { + listenAddr := ma.StringCast("/ip4/1.2.3.4/udp/8888/quic-v1/p2p-circuit") + ifaceAddr := ma.StringCast("/ip4/10.0.0.2/udp/9999/quic-v1") + listenAddrs := func() []ma.Multiaddr { return []ma.Multiaddr{listenAddr} } + ifaceListenAddrs := func() ([]ma.Multiaddr, error) { return []ma.Multiaddr{ifaceAddr}, nil } + addrs := func() []ma.Multiaddr { return []ma.Multiaddr{listenAddr} } + + c := &mockConn{ + local: listenAddr, + remote: ma.StringCast("/ip4/1.2.3.6/udp/1236/quic-v1/p2p-circuit"), + } + observedAddr := ma.StringCast("/ip4/1.2.3.4/udp/1231/quic-v1/p2p-circuit") + o, err := NewObservedAddrManager(listenAddrs, addrs, ifaceListenAddrs, normalize) + require.NoError(t, err) + shouldRecord, _, _ := o.shouldRecordObservation(c, observedAddr) + require.False(t, shouldRecord) +} + func TestShouldRecordObservationWithNAT64Addr(t *testing.T) { listenAddr1 := ma.StringCast("/ip4/0.0.0.0/tcp/1234") ifaceAddr1 := ma.StringCast("/ip4/10.0.0.2/tcp/4321") diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index df53da6eeb..99ce67b521 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -2,6 +2,8 @@ package transport_integration import ( "context" + "encoding/binary" + "net/netip" "strings" "testing" "time" @@ -30,6 +32,23 @@ func stripCertHash(addr ma.Multiaddr) ma.Multiaddr { return addr } +func addrPort(addr ma.Multiaddr) netip.AddrPort { + a := netip.Addr{} + p := uint16(0) + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_IP4 || c.Protocol().Code == ma.P_IP6 { + a, _ = netip.AddrFromSlice(c.RawValue()) + return false + } + if c.Protocol().Code == ma.P_UDP || c.Protocol().Code == ma.P_TCP { + p = binary.BigEndian.Uint16(c.RawValue()) + return true + } + return false + }) + return netip.AddrPortFrom(a, p) +} + func TestInterceptPeerDial(t *testing.T) { if race.WithRace() { t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") @@ -173,10 +192,14 @@ func TestInterceptAccept(t *testing.T) { // remove the certhash component from WebTransport addresses require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) }).AnyTimes() + } else if strings.Contains(tc.Name, "WebSocket-Shared") { + connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { + require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr())) + }) } else { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) + require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr()) }) } diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 7010b3d0ce..203c382124 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -98,6 +98,38 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "TCP-Shared / TLS / Yamux", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) + libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New)) + libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport)) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, + { + Name: "WebSocket-Shared", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, { Name: "WebSocket", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { diff --git a/p2p/transport/tcp/metrics.go b/p2p/transport/tcp/metrics.go index 213ee2200a..50820d870c 100644 --- a/p2p/transport/tcp/metrics.go +++ b/p2p/transport/tcp/metrics.go @@ -24,7 +24,7 @@ var ( const collectFrequency = 10 * time.Second -var collector *aggregatingCollector +var defaultCollector *aggregatingCollector var initMetricsOnce sync.Once @@ -34,8 +34,8 @@ func initMetrics() { bytesSentDesc = prometheus.NewDesc("tcp_sent_bytes", "TCP bytes sent", nil, nil) bytesRcvdDesc = prometheus.NewDesc("tcp_rcvd_bytes", "TCP bytes received", nil, nil) - collector = newAggregatingCollector() - prometheus.MustRegister(collector) + defaultCollector = newAggregatingCollector() + prometheus.MustRegister(defaultCollector) const direction = "direction" @@ -196,7 +196,7 @@ func (c *aggregatingCollector) Collect(metrics chan<- prometheus.Metric) { func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) { c.mutex.Lock() - collector.removeConn(conn.id) + c.removeConn(conn.id) c.mutex.Unlock() closedConns.WithLabelValues(direction).Inc() } @@ -204,6 +204,8 @@ func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) { type tracingConn struct { id uint64 + collector *aggregatingCollector + startTime time.Time isClient bool @@ -213,7 +215,8 @@ type tracingConn struct { closeErr error } -func newTracingConn(c manet.Conn, isClient bool) (*tracingConn, error) { +// newTracingConn wraps a manet.Conn with a tracingConn. A nil collector will use the default collector. +func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (*tracingConn, error) { initMetricsOnce.Do(func() { initMetrics() }) conn, err := tcp.NewConn(c) if err != nil { @@ -224,8 +227,12 @@ func newTracingConn(c manet.Conn, isClient bool) (*tracingConn, error) { isClient: isClient, Conn: c, tcpConn: conn, + collector: collector, + } + if tc.collector == nil { + tc.collector = defaultCollector } - tc.id = collector.AddConn(tc) + tc.id = tc.collector.AddConn(tc) newConns.WithLabelValues(tc.getDirection()).Inc() return tc, nil } @@ -239,7 +246,7 @@ func (c *tracingConn) getDirection() string { func (c *tracingConn) Close() error { c.closeOnce.Do(func() { - collector.ClosedConn(c, c.getDirection()) + c.collector.ClosedConn(c, c.getDirection()) c.closeErr = c.Conn.Close() }) return c.closeErr @@ -258,10 +265,12 @@ func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) { type tracingListener struct { manet.Listener + collector *aggregatingCollector } -func newTracingListener(l manet.Listener) *tracingListener { - return &tracingListener{Listener: l} +// newTracingListener wraps a manet.Listener with a tracingListener. A nil collector will use the default collector. +func newTracingListener(l manet.Listener, collector *aggregatingCollector) *tracingListener { + return &tracingListener{Listener: l, collector: collector} } func (l *tracingListener) Accept() (manet.Conn, error) { @@ -269,5 +278,5 @@ func (l *tracingListener) Accept() (manet.Conn, error) { if err != nil { return nil, err } - return newTracingConn(conn, false) + return newTracingConn(conn, l.collector, false) } diff --git a/p2p/transport/tcp/metrics_none.go b/p2p/transport/tcp/metrics_none.go index 8538b30c89..cbee982070 100644 --- a/p2p/transport/tcp/metrics_none.go +++ b/p2p/transport/tcp/metrics_none.go @@ -6,5 +6,9 @@ package tcp import manet "github.com/multiformats/go-multiaddr/net" -func newTracingConn(c manet.Conn, _ bool) (manet.Conn, error) { return c, nil } -func newTracingListener(l manet.Listener) manet.Listener { return l } +type aggregatingCollector struct{} + +func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (manet.Conn, error) { + return c, nil +} +func newTracingListener(l manet.Listener, collector *aggregatingCollector) manet.Listener { return l } diff --git a/p2p/transport/tcp/metrics_unix_test.go b/p2p/transport/tcp/metrics_unix_test.go new file mode 100644 index 0000000000..0a09526206 --- /dev/null +++ b/p2p/transport/tcp/metrics_unix_test.go @@ -0,0 +1,54 @@ +// go:build: unix + +package tcp + +import ( + "testing" + + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" + ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" + + "github.com/stretchr/testify/require" +) + +func TestTcpTransportCollectsMetricsWithSharedTcpSocket(t *testing.T) { + + peerA, ia := makeInsecureMuxer(t) + _, ib := makeInsecureMuxer(t) + + sharedTCPSocketA := tcpreuse.NewConnMgr(false, nil, nil) + sharedTCPSocketB := tcpreuse.NewConnMgr(false, nil, nil) + + ua, err := tptu.New(ia, muxers, nil, nil, nil) + require.NoError(t, err) + ta, err := NewTCPTransport(ua, nil, sharedTCPSocketA, WithMetrics()) + require.NoError(t, err) + ub, err := tptu.New(ib, muxers, nil, nil, nil) + require.NoError(t, err) + tb, err := NewTCPTransport(ub, nil, sharedTCPSocketB, WithMetrics()) + require.NoError(t, err) + + zero := "/ip4/127.0.0.1/tcp/0" + + // Not running any test that needs more than 1 conn because the testsuite + // opens multiple conns via multiple listeners, which is not expected to work + // with the shared TCP socket. + subtestsToRun := []ttransport.TransportSubTestFn{ + ttransport.SubtestProtocols, + ttransport.SubtestBasic, + ttransport.SubtestCancel, + ttransport.SubtestPingPong, + + // Stolen from the stream muxer test suite. + ttransport.SubtestStress1Conn1Stream1Msg, + ttransport.SubtestStress1Conn1Stream100Msg, + ttransport.SubtestStress1Conn100Stream100Msg, + ttransport.SubtestStress1Conn1000Stream10Msg, + ttransport.SubtestStress1Conn100Stream100Msg10MB, + ttransport.SubtestStreamOpenStress, + ttransport.SubtestStreamReset, + } + + ttransport.SubtestTransportWithFs(t, ta, tb, zero, peerA, subtestsToRun) +} diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index d52bb96019..61e68941e2 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -13,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/net/reuseport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" logging "github.com/ipfs/go-log/v2" ma "github.com/multiformats/go-multiaddr" @@ -33,6 +34,9 @@ type canKeepAlive interface { var _ canKeepAlive = &net.TCPConn{} +// Deprecated: Use tcpreuse.ReuseportIsAvailable +var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable + func tryKeepAlive(conn net.Conn, keepAlive bool) { keepAliveConn, ok := conn.(canKeepAlive) if !ok { @@ -122,20 +126,25 @@ type TcpTransport struct { disableReuseport bool // Explicitly disable reuseport. enableMetrics bool + // share and demultiplex TCP listeners across multiple transports + sharedTcp *tcpreuse.ConnMgr + // TCP connect timeout connectTimeout time.Duration rcmgr network.ResourceManager reuse reuseport.Transport + + metricsCollector *aggregatingCollector } var _ transport.Transport = &TcpTransport{} var _ transport.DialUpdater = &TcpTransport{} // NewTCPTransport creates a tcp transport object that tracks dialers and listeners -// created. It represents an entire TCP stack (though it might not necessarily be). -func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) { +// created. +func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*TcpTransport, error) { if rcmgr == nil { rcmgr = &network.NullResourceManager{} } @@ -143,6 +152,7 @@ func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, upgrader: upgrader, connectTimeout: defaultConnectTimeout, // can be set by using the WithConnectionTimeout option rcmgr: rcmgr, + sharedTcp: sharedTCP, } for _, o := range opts { if err := o(tr); err != nil { @@ -168,6 +178,10 @@ func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Co defer cancel() } + if t.sharedTcp != nil { + return t.sharedTcp.DialContext(ctx, raddr) + } + if t.UseReuseport() { return t.reuse.DialContext(ctx, raddr) } @@ -212,7 +226,7 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p c := conn if t.enableMetrics { var err error - c, err = newTracingConn(conn, true) + c, err = newTracingConn(conn, t.metricsCollector, true) if err != nil { return nil, err } @@ -233,10 +247,10 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p // UseReuseport returns true if reuseport is enabled and available. func (t *TcpTransport) UseReuseport() bool { - return !t.disableReuseport && ReuseportIsAvailable() + return !t.disableReuseport && tcpreuse.ReuseportIsAvailable() } -func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) { +func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, error) { if t.UseReuseport() { return t.reuse.Listen(laddr) } @@ -245,12 +259,20 @@ func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) { // Listen listens on the given multiaddr. func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { - list, err := t.maListen(laddr) + var list manet.Listener + var err error + + if t.sharedTcp == nil { + list, err = t.unsharedMAListen(laddr) + } else { + list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect) + } if err != nil { return nil, err } + if t.enableMetrics { - list = newTracingListener(&tcpListener{list, 0}) + list = newTracingListener(&tcpListener{list, 0}, t.metricsCollector) } return t.upgrader.UpgradeListener(t, list), nil } diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index a57a65e420..1f939d92be 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -14,6 +14,7 @@ import ( "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" ma "github.com/multiformats/go-multiaddr" @@ -31,19 +32,19 @@ func TestTcpTransport(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil) + tb, err := NewTCPTransport(ub, nil, nil) require.NoError(t, err) zero := "/ip4/127.0.0.1/tcp/0" ttransport.SubtestTransport(t, ta, tb, zero, peerA) - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestTcpTransportWithMetrics(t *testing.T) { @@ -52,11 +53,11 @@ func TestTcpTransportWithMetrics(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil, WithMetrics()) + ta, err := NewTCPTransport(ua, nil, nil, WithMetrics()) require.NoError(t, err) ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil, WithMetrics()) + tb, err := NewTCPTransport(ub, nil, nil, WithMetrics()) require.NoError(t, err) zero := "/ip4/127.0.0.1/tcp/0" @@ -72,7 +73,7 @@ func TestResourceManager(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) require.NoError(t, err) @@ -81,7 +82,7 @@ func TestResourceManager(t *testing.T) { ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) rcmgr := mocknetwork.NewMockResourceManager(ctrl) - tb, err := NewTCPTransport(ub, rcmgr) + tb, err := NewTCPTransport(ub, rcmgr, nil) require.NoError(t, err) t.Run("success", func(t *testing.T) { @@ -119,16 +120,16 @@ func TestTcpTransportCantDialDNS(t *testing.T) { require.NoError(t, err) var u transport.Upgrader - tpt, err := NewTCPTransport(u, nil) + tpt, err := NewTCPTransport(u, nil, nil) require.NoError(t, err) if tpt.CanDial(dnsa) { t.Fatal("shouldn't be able to dial dns") } - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestTcpTransportCantListenUtp(t *testing.T) { @@ -137,15 +138,15 @@ func TestTcpTransportCantListenUtp(t *testing.T) { require.NoError(t, err) var u transport.Upgrader - tpt, err := NewTCPTransport(u, nil) + tpt, err := NewTCPTransport(u, nil, nil) require.NoError(t, err) _, err = tpt.Listen(utpa) require.Error(t, err, "shouldn't be able to listen on utp addr with tcp transport") - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestDialWithUpdates(t *testing.T) { @@ -154,7 +155,7 @@ func TestDialWithUpdates(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) require.NoError(t, err) @@ -162,7 +163,7 @@ func TestDialWithUpdates(t *testing.T) { ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil) + tb, err := NewTCPTransport(ub, nil, nil) require.NoError(t, err) updCh := make(chan transport.DialUpdate, 1) diff --git a/p2p/transport/tcpreuse/connwithscope.go b/p2p/transport/tcpreuse/connwithscope.go new file mode 100644 index 0000000000..ca66f20325 --- /dev/null +++ b/p2p/transport/tcpreuse/connwithscope.go @@ -0,0 +1,26 @@ +package tcpreuse + +import ( + "fmt" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn" + manet "github.com/multiformats/go-multiaddr/net" +) + +type connWithScope struct { + sampledconn.ManetTCPConnInterface + scope network.ConnManagementScope +} + +func (c connWithScope) Scope() network.ConnManagementScope { + return c.scope +} + +func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) { + if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok { + return &connWithScope{tcpconn, scope}, nil + } + + return nil, fmt.Errorf("manet.Conn is not a TCP Conn") +} diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go new file mode 100644 index 0000000000..f9175ecfdb --- /dev/null +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -0,0 +1,100 @@ +package tcpreuse + +import ( + "errors" + "fmt" + "time" + + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn" + manet "github.com/multiformats/go-multiaddr/net" +) + +// This is reading the first 3 bytes of the first packet after the handshake. +// It's set to the default TCP connect timeout in the TCP Transport. +// +// A var so we can change it in tests. +var identifyConnTimeout = 5 * time.Second + +type DemultiplexedConnType int + +const ( + DemultiplexedConnType_Unknown DemultiplexedConnType = iota + DemultiplexedConnType_MultistreamSelect + DemultiplexedConnType_HTTP + DemultiplexedConnType_TLS +) + +func (t DemultiplexedConnType) String() string { + switch t { + case DemultiplexedConnType_MultistreamSelect: + return "MultistreamSelect" + case DemultiplexedConnType_HTTP: + return "HTTP" + case DemultiplexedConnType_TLS: + return "TLS" + default: + return fmt.Sprintf("Unknown(%d)", int(t)) + } +} + +func (t DemultiplexedConnType) IsKnown() bool { + return t >= 1 || t <= 3 +} + +// identifyConnType attempts to identify the connection type by peeking at the +// first few bytes. +// Its Callers must not use the passed in Conn after this function returns. +// If an error is returned, the connection will be closed. +func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) { + if err := c.SetReadDeadline(time.Now().Add(identifyConnTimeout)); err != nil { + closeErr := c.Close() + return 0, nil, errors.Join(err, closeErr) + } + + s, peekedConn, err := sampledconn.PeekBytes(c) + if err != nil { + closeErr := c.Close() + return 0, nil, errors.Join(err, closeErr) + } + + if err := peekedConn.SetReadDeadline(time.Time{}); err != nil { + closeErr := peekedConn.Close() + return 0, nil, errors.Join(err, closeErr) + } + + if IsMultistreamSelect(s) { + return DemultiplexedConnType_MultistreamSelect, peekedConn, nil + } + if IsTLS(s) { + return DemultiplexedConnType_TLS, peekedConn, nil + } + if IsHTTP(s) { + return DemultiplexedConnType_HTTP, peekedConn, nil + } + return DemultiplexedConnType_Unknown, peekedConn, nil +} + +// Matchers are implemented here instead of in the transports so we can easily fuzz them together. +type Prefix = [3]byte + +func IsMultistreamSelect(s Prefix) bool { + return string(s[:]) == "\x13/m" +} + +func IsHTTP(s Prefix) bool { + switch string(s[:]) { + case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT": + return true + default: + return false + } +} + +func IsTLS(s Prefix) bool { + switch string(s[:]) { + case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03": + return true + default: + return false + } +} diff --git a/p2p/transport/tcpreuse/demultiplex_test.go b/p2p/transport/tcpreuse/demultiplex_test.go new file mode 100644 index 0000000000..e201f2ca75 --- /dev/null +++ b/p2p/transport/tcpreuse/demultiplex_test.go @@ -0,0 +1,50 @@ +package tcpreuse + +import "testing" + +func FuzzClash(f *testing.F) { + // make untyped literals type correctly + add := func(a, b, c byte) { f.Add(a, b, c) } + + // multistream-select + add('\x13', '/', 'm') + // http + add('G', 'E', 'T') + add('H', 'E', 'A') + add('P', 'O', 'S') + add('P', 'U', 'T') + add('D', 'E', 'L') + add('C', 'O', 'N') + add('O', 'P', 'T') + add('T', 'R', 'A') + add('P', 'A', 'T') + // tls + add('\x16', '\x03', '\x01') + add('\x16', '\x03', '\x02') + add('\x16', '\x03', '\x03') + add('\x16', '\x03', '\x04') + + f.Fuzz(func(t *testing.T, a, b, c byte) { + s := Prefix{a, b, c} + var total uint + + ms := IsMultistreamSelect(s) + if ms { + total++ + } + + http := IsHTTP(s) + if http { + total++ + } + + tls := IsTLS(s) + if tls { + total++ + } + + if total > 1 { + t.Errorf("clash on: %q; ms: %v; http: %v; tls: %v", s, ms, http, tls) + } + }) +} diff --git a/p2p/transport/tcpreuse/dialer.go b/p2p/transport/tcpreuse/dialer.go new file mode 100644 index 0000000000..ad634583ed --- /dev/null +++ b/p2p/transport/tcpreuse/dialer.go @@ -0,0 +1,16 @@ +package tcpreuse + +import ( + "context" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +// DialContext is like Dial but takes a context. +func (t *ConnMgr) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) { + if t.useReuseport() { + return t.reuse.DialContext(ctx, raddr) + } + var d manet.Dialer + return d.DialContext(ctx, raddr) +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go new file mode 100644 index 0000000000..7324b45849 --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go @@ -0,0 +1,89 @@ +package sampledconn + +import ( + "errors" + "io" + "net" + "syscall" + "time" + + manet "github.com/multiformats/go-multiaddr/net" +) + +const peekSize = 3 + +type PeekedBytes = [peekSize]byte + +var errNotSupported = errors.New("not supported on this platform") + +var ErrNotTCPConn = errors.New("passed conn is not a TCPConn") + +func PeekBytes(conn manet.Conn) (PeekedBytes, manet.Conn, error) { + if c, ok := conn.(syscall.Conn); ok { + b, err := OSPeekConn(c) + if err == nil { + return b, conn, nil + } + if err != errNotSupported { + return PeekedBytes{}, nil, err + } + // Fallback to wrapping the coonn + } + + if c, ok := conn.(ManetTCPConnInterface); ok { + return newFallbackSampledConn(c) + } + + return PeekedBytes{}, nil, ErrNotTCPConn +} + +type fallbackPeekingConn struct { + ManetTCPConnInterface + peekedBytes PeekedBytes + bytesPeeked uint8 +} + +// tcpConnInterface is the interface for TCPConn's functions +// NOTE: `SyscallConn() (syscall.RawConn, error)` is here to make using this as +// a TCP Conn easier, but it's a potential footgun as you could skipped the +// peeked bytes if using the fallback +type tcpConnInterface interface { + net.Conn + syscall.Conn + + CloseRead() error + CloseWrite() error + + SetLinger(sec int) error + SetKeepAlive(keepalive bool) error + SetKeepAlivePeriod(d time.Duration) error + SetNoDelay(noDelay bool) error + MultipathTCP() (bool, error) + + io.ReaderFrom + io.WriterTo +} + +type ManetTCPConnInterface interface { + manet.Conn + tcpConnInterface +} + +func newFallbackSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *fallbackPeekingConn, error) { + s := &fallbackPeekingConn{ManetTCPConnInterface: conn} + _, err := io.ReadFull(conn, s.peekedBytes[:]) + if err != nil { + return s.peekedBytes, nil, err + } + return s.peekedBytes, s, nil +} + +func (sc *fallbackPeekingConn) Read(b []byte) (int, error) { + if int(sc.bytesPeeked) != len(sc.peekedBytes) { + red := copy(b, sc.peekedBytes[sc.bytesPeeked:]) + sc.bytesPeeked += uint8(red) + return red, nil + } + + return sc.ManetTCPConnInterface.Read(b) +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go new file mode 100644 index 0000000000..5197052fab --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go @@ -0,0 +1,11 @@ +//go:build !unix + +package sampledconn + +import ( + "syscall" +) + +func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { + return PeekedBytes{}, errNotSupported +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go new file mode 100644 index 0000000000..d5b31009e2 --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go @@ -0,0 +1,78 @@ +package sampledconn + +import ( + "io" + "syscall" + "testing" + "time" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + + "github.com/stretchr/testify/assert" +) + +func TestSampledConn(t *testing.T) { + testCases := []string{ + "platform", + "fallback", + } + + // Start a TCP server + listener, err := manet.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) + assert.NoError(t, err) + defer listener.Close() + + serverAddr := listener.Multiaddr() + + // Server goroutine + go func() { + for i := 0; i < len(testCases); i++ { + conn, err := listener.Accept() + assert.NoError(t, err) + defer conn.Close() + + // Write some data to the connection + _, err = conn.Write([]byte("hello")) + assert.NoError(t, err) + } + }() + + // Give the server a moment to start + time.Sleep(100 * time.Millisecond) + + for _, tc := range testCases { + t.Run(tc, func(t *testing.T) { + // Create a TCP client + clientConn, err := manet.Dial(serverAddr) + assert.NoError(t, err) + defer clientConn.Close() + + if tc == "platform" { + // Wrap the client connection in SampledConn + peeked, clientConn, err := PeekBytes(clientConn.(interface { + manet.Conn + syscall.Conn + })) + assert.NoError(t, err) + assert.Equal(t, "hel", string(peeked[:])) + + buf := make([]byte, 5) + _, err = io.ReadFull(clientConn, buf) + assert.NoError(t, err) + assert.Equal(t, "hello", string(buf)) + } else { + // Wrap the client connection in SampledConn + sample, sampledConn, err := newFallbackSampledConn(clientConn.(ManetTCPConnInterface)) + assert.NoError(t, err) + assert.Equal(t, "hel", string(sample[:])) + + buf := make([]byte, 5) + _, err = io.ReadFull(sampledConn, buf) + assert.NoError(t, err) + assert.Equal(t, "hello", string(buf)) + + } + }) + } +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go new file mode 100644 index 0000000000..9847e8d4be --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go @@ -0,0 +1,42 @@ +//go:build unix + +package sampledconn + +import ( + "errors" + "syscall" +) + +func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { + s := PeekedBytes{} + + rawConn, err := conn.SyscallConn() + if err != nil { + return s, err + } + + readBytes := 0 + var readErr error + err = rawConn.Read(func(fd uintptr) bool { + for readBytes < peekSize { + var n int + n, _, readErr = syscall.Recvfrom(int(fd), s[readBytes:], syscall.MSG_PEEK) + if errors.Is(readErr, syscall.EAGAIN) { + return false + } + if readErr != nil { + return true + } + readBytes += n + } + return true + }) + if readErr != nil { + return s, readErr + } + if err != nil { + return s, err + } + + return s, nil +} diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go new file mode 100644 index 0000000000..d94186e7ec --- /dev/null +++ b/p2p/transport/tcpreuse/listener.go @@ -0,0 +1,327 @@ +package tcpreuse + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/connmgr" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/net/reuseport" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +const acceptQueueSize = 64 // It is fine to read 3 bytes from 64 connections in parallel. + +// How long we wait for a connection to be accepted before dropping it. +const acceptTimeout = 30 * time.Second + +var log = logging.Logger("tcp-demultiplex") + +// ConnMgr enables you to share the same listen address between TCP and WebSocket transports. +type ConnMgr struct { + enableReuseport bool + reuse reuseport.Transport + connGater connmgr.ConnectionGater + rcmgr network.ResourceManager + + mx sync.Mutex + listeners map[string]*multiplexedListener +} + +func NewConnMgr(enableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr { + if rcmgr == nil { + rcmgr = &network.NullResourceManager{} + } + return &ConnMgr{ + enableReuseport: enableReuseport, + reuse: reuseport.Transport{}, + connGater: gater, + rcmgr: rcmgr, + listeners: make(map[string]*multiplexedListener), + } +} + +func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) { + if t.useReuseport() { + return t.reuse.Listen(listenAddr) + } else { + return manet.Listen(listenAddr) + } +} + +func (t *ConnMgr) useReuseport() bool { + return t.enableReuseport && ReuseportIsAvailable() +} + +func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) { + haveTCP := false + addr, _ := ma.SplitFunc(listenAddr, func(c ma.Component) bool { + if haveTCP { + return true + } + if c.Protocol().Code == ma.P_TCP { + haveTCP = true + } + return false + }) + if !haveTCP { + return nil, fmt.Errorf("invalid listen addr %s, need tcp address", listenAddr) + } + return addr, nil +} + +// DemultiplexedListen returns a listener for laddr listening for `connType` connections. The connections +// accepted from returned listeners need to be upgraded with a `transport.Upgrader`. +// NOTE: All listeners for port 0 share the same underlying socket, so they have the same specific port. +func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) { + if !connType.IsKnown() { + return nil, fmt.Errorf("unknown connection type: %s", connType) + } + laddr, err := getTCPAddr(laddr) + if err != nil { + return nil, err + } + + t.mx.Lock() + defer t.mx.Unlock() + ml, ok := t.listeners[laddr.String()] + if ok { + dl, err := ml.DemultiplexedListen(connType) + if err != nil { + return nil, err + } + return dl, nil + } + + l, err := t.maListen(laddr) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithCancel(context.Background()) + cancelFunc := func() error { + cancel() + t.mx.Lock() + defer t.mx.Unlock() + delete(t.listeners, laddr.String()) + delete(t.listeners, l.Multiaddr().String()) + return l.Close() + } + ml = &multiplexedListener{ + Listener: l, + listeners: make(map[DemultiplexedConnType]*demultiplexedListener), + ctx: ctx, + closeFn: cancelFunc, + connGater: t.connGater, + rcmgr: t.rcmgr, + } + t.listeners[laddr.String()] = ml + t.listeners[l.Multiaddr().String()] = ml + + dl, err := ml.DemultiplexedListen(connType) + if err != nil { + cerr := ml.Close() + return nil, errors.Join(err, cerr) + } + + ml.wg.Add(1) + go ml.run() + + return dl, nil +} + +var _ manet.Listener = &demultiplexedListener{} + +type multiplexedListener struct { + manet.Listener + listeners map[DemultiplexedConnType]*demultiplexedListener + mx sync.RWMutex + + connGater connmgr.ConnectionGater + rcmgr network.ResourceManager + ctx context.Context + closeFn func() error + wg sync.WaitGroup +} + +var ErrListenerExists = errors.New("listener already exists for this conn type on this address") + +func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) { + if !connType.IsKnown() { + return nil, fmt.Errorf("unknown connection type: %s", connType) + } + + m.mx.Lock() + defer m.mx.Unlock() + if _, ok := m.listeners[connType]; ok { + return nil, ErrListenerExists + } + + ctx, cancel := context.WithCancel(m.ctx) + l := &demultiplexedListener{ + buffer: make(chan manet.Conn), + inner: m.Listener, + ctx: ctx, + cancelFunc: cancel, + closeFn: func() error { m.removeDemultiplexedListener(connType); return nil }, + } + + m.listeners[connType] = l + + return l, nil +} + +func (m *multiplexedListener) run() error { + defer m.Close() + defer m.wg.Done() + acceptQueue := make(chan struct{}, acceptQueueSize) + for { + c, err := m.Listener.Accept() + if err != nil { + return err + } + + // Gate and resource limit the connection here. + // If done after sampling the connection, we'll be vulnerable to DOS attacks by a single peer + // which clogs up our entire connection queue. + // This duplicates the responsibility of gating and resource limiting between here and the upgrader. The + // alternative without duplication requires moving the process of upgrading the connection here, which forces + // us to establish the websocket connection here. That is more duplication, or a significant breaking change. + // + // Bugs around multiple calls to OpenConnection or InterceptAccept are prevented by the transport + // integration tests. + if m.connGater != nil && !m.connGater.InterceptAccept(c) { + log.Debugf("gater blocked incoming connection on local addr %s from %s", + c.LocalMultiaddr(), c.RemoteMultiaddr()) + if err := c.Close(); err != nil { + log.Warnf("failed to close incoming connection rejected by gater: %s", err) + } + continue + } + connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr()) + if err != nil { + log.Debugw("resource manager blocked accept of new connection", "error", err) + if err := c.Close(); err != nil { + log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err) + } + continue + } + + select { + case acceptQueue <- struct{}{}: + // NOTE: We can drop the connection, but this is similar to the behaviour in the upgrader. + case <-m.ctx.Done(): + c.Close() + log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr()) + } + + m.wg.Add(1) + go func() { + defer func() { <-acceptQueue }() + defer m.wg.Done() + ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout) + defer cancelCtx() + t, c, err := identifyConnType(c) + if err != nil { + connScope.Done() + log.Debugf("error demultiplexing connection: %s", err.Error()) + return + } + + connWithScope, err := manetConnWithScope(c, connScope) + if err != nil { + connScope.Done() + closeErr := c.Close() + err = errors.Join(err, closeErr) + log.Debugf("error wrapping connection with scope: %s", err.Error()) + return + } + + m.mx.RLock() + demux, ok := m.listeners[t] + m.mx.RUnlock() + if !ok { + closeErr := connWithScope.Close() + if closeErr != nil { + log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error()) + } else { + log.Debugf("no registered listener for demultiplex connection %s", t) + } + return + } + + select { + case demux.buffer <- connWithScope: + case <-ctx.Done(): + connWithScope.Close() + } + }() + } +} + +func (m *multiplexedListener) Close() error { + m.mx.Lock() + for _, l := range m.listeners { + l.cancelFunc() + } + err := m.closeListener() + m.mx.Unlock() + m.wg.Wait() + return err +} + +func (m *multiplexedListener) closeListener() error { + lerr := m.Listener.Close() + cerr := m.closeFn() + return errors.Join(lerr, cerr) +} + +func (m *multiplexedListener) removeDemultiplexedListener(c DemultiplexedConnType) { + m.mx.Lock() + defer m.mx.Unlock() + + delete(m.listeners, c) + if len(m.listeners) == 0 { + m.closeListener() + m.mx.Unlock() + m.wg.Wait() + m.mx.Lock() + } +} + +type demultiplexedListener struct { + buffer chan manet.Conn + inner manet.Listener + ctx context.Context + cancelFunc context.CancelFunc + closeFn func() error +} + +func (m *demultiplexedListener) Accept() (manet.Conn, error) { + select { + case c := <-m.buffer: + return c, nil + case <-m.ctx.Done(): + return nil, transport.ErrListenerClosed + } +} + +func (m *demultiplexedListener) Close() error { + m.cancelFunc() + return m.closeFn() +} + +func (m *demultiplexedListener) Multiaddr() ma.Multiaddr { + return m.inner.Multiaddr() +} + +func (m *demultiplexedListener) Addr() net.Addr { + return m.inner.Addr() +} diff --git a/p2p/transport/tcpreuse/listener_test.go b/p2p/transport/tcpreuse/listener_test.go new file mode 100644 index 0000000000..b5dc49f2c1 --- /dev/null +++ b/p2p/transport/tcpreuse/listener_test.go @@ -0,0 +1,476 @@ +package tcpreuse + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "net" + "net/http" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "github.com/multiformats/go-multistream" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func selfSignedTLSConfig(t *testing.T) *tls.Config { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + certTemplate := x509.Certificate{ + SerialNumber: &big.Int{}, + Subject: pkix.Name{ + Organization: []string{"Test"}, + }, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv) + require.NoError(t, err) + + cert := tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: priv, + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + return tlsConfig +} + +type wsHandler struct{ conns chan *websocket.Conn } + +func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + u := websocket.Upgrader{} + c, _ := u.Upgrade(w, r, http.Header{}) + wh.conns <- c +} + +func TestListenerSingle(t *testing.T) { + listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") + const N = 64 + for _, enableReuseport := range []bool{true, false} { + t.Run(fmt.Sprintf("multistream-reuseport:%v", enableReuseport), func(t *testing.T) { + cm := NewConnMgr(enableReuseport, nil, nil) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + go func() { + d := net.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := d.DialContext(ctx, l.Addr().Network(), l.Addr().String()) + if err != nil { + t.Error("failed to dial", err, i) + return + } + lconn := multistream.NewMSSelect(conn, "a") + buf := make([]byte, 10) + _, err = lconn.Write([]byte("hello-multistream")) + if err != nil { + t.Error(err) + } + _, err = lconn.Read(buf) + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + var wg sync.WaitGroup + for i := 0; i < N; i++ { + c, err := l.Accept() + require.NoError(t, err) + wg.Add(1) + go func() { + defer wg.Done() + cc := multistream.NewMSSelect(c, "a") + defer cc.Close() + buf := make([]byte, 30) + n, err := cc.Read(buf) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello-multistream", string(buf[:n])) { + return + } + }() + } + wg.Wait() + }) + + t.Run(fmt.Sprintf("WebSocket-reuseport:%v", enableReuseport), func(t *testing.T) { + cm := NewConnMgr(enableReuseport, nil, nil) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + http.Serve(manet.NetListener(l), wh) + }() + go func() { + d := websocket.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", l.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("hello")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + var wg sync.WaitGroup + for i := 0; i < N; i++ { + c := <-wh.conns + wg.Add(1) + go func() { + defer wg.Done() + defer c.Close() + msgType, buf, err := c.ReadMessage() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "hello", string(buf)) { + return + } + }() + } + wg.Wait() + }) + + t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", enableReuseport), func(t *testing.T) { + cm := NewConnMgr(enableReuseport, nil, nil) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) + require.NoError(t, err) + defer l.Close() + wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)} + s.ServeTLS(manet.NetListener(l), "", "") + }() + go func() { + d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", l.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("hello")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + var wg sync.WaitGroup + for i := 0; i < N; i++ { + c := <-wh.conns + wg.Add(1) + go func() { + defer wg.Done() + defer c.Close() + msgType, buf, err := c.ReadMessage() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "hello", string(buf)) { + return + } + }() + } + wg.Wait() + }) + } +} + +func TestListenerMultiplexed(t *testing.T) { + listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") + const N = 20 + for _, enableReuseport := range []bool{true, false} { + cm := NewConnMgr(enableReuseport, nil, nil) + msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + defer msl.Close() + + wsl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + defer wsl.Close() + require.Equal(t, wsl.Multiaddr(), msl.Multiaddr()) + wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + http.Serve(manet.NetListener(wsl), wh) + }() + + wssl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) + require.NoError(t, err) + defer wssl.Close() + require.Equal(t, wssl.Multiaddr(), wsl.Multiaddr()) + whs := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + s := http.Server{Handler: whs, TLSConfig: selfSignedTLSConfig(t)} + s.ServeTLS(manet.NetListener(wssl), "", "") + }() + + // multistream connections + go func() { + d := net.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := d.DialContext(ctx, msl.Addr().Network(), msl.Addr().String()) + if err != nil { + t.Error("failed to dial", err, i) + return + } + lconn := multistream.NewMSSelect(conn, "a") + buf := make([]byte, 10) + _, err = lconn.Write([]byte("multistream")) + if err != nil { + t.Error(err) + } + _, err = lconn.Read(buf) + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + // ws connections + go func() { + d := websocket.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", msl.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("websocket")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + // wss connections + go func() { + d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", msl.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("websocket-tls")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < N; i++ { + c, err := msl.Accept() + if !assert.NoError(t, err) { + return + } + wg.Add(1) + go func() { + defer wg.Done() + cc := multistream.NewMSSelect(c, "a") + defer cc.Close() + buf := make([]byte, 20) + n, err := cc.Read(buf) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "multistream", string(buf[:n])) { + return + } + }() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < N; i++ { + c := <-wh.conns + wg.Add(1) + go func() { + defer wg.Done() + defer c.Close() + msgType, buf, err := c.ReadMessage() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "websocket", string(buf)) { + return + } + }() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < N; i++ { + c := <-whs.conns + wg.Add(1) + go func() { + defer wg.Done() + defer c.Close() + msgType, buf, err := c.ReadMessage() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "websocket-tls", string(buf)) { + return + } + }() + } + }() + wg.Wait() + } +} + +func TestListenerClose(t *testing.T) { + testClose := func(listenAddr ma.Multiaddr) { + // listen on port 0 + cm := NewConnMgr(false, nil, nil) + ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + wl.Close() + + wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + + ml.Close() + + mll, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + + mll.Close() + wl.Close() + + ml, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + + // Now listen on the specific port previously used + listenAddr = ml.Multiaddr() + wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + wl.Close() + + wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + + ml.Close() + wl.Close() + } + listenAddrs := []ma.Multiaddr{ma.StringCast("/ip4/0.0.0.0/tcp/0"), ma.StringCast("/ip6/::/tcp/0")} + for _, listenAddr := range listenAddrs { + testClose(listenAddr) + } +} + +func setDeferReset[T any](t testing.TB, ptr *T, val T) { + t.Helper() + orig := *ptr + *ptr = val + t.Cleanup(func() { *ptr = orig }) +} + +// TestHitTimeout asserts that we don't panic in case we fail to peek at the connection. +func TestHitTimeout(t *testing.T) { + setDeferReset(t, &identifyConnTimeout, 100*time.Millisecond) + // listen on port 0 + cm := NewConnMgr(false, nil, nil) + + listenAddr := ma.StringCast("/ip4/127.0.0.1/tcp/0") + ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + defer ml.Close() + + tcpConn, err := net.Dial(ml.Addr().Network(), ml.Addr().String()) + require.NoError(t, err) + + // Stall tcp conn for over the timeout. + time.Sleep(identifyConnTimeout + 100*time.Millisecond) + + tcpConn.Close() +} diff --git a/p2p/transport/tcp/reuseport.go b/p2p/transport/tcpreuse/reuseport.go similarity index 81% rename from p2p/transport/tcp/reuseport.go rename to p2p/transport/tcpreuse/reuseport.go index ba09304622..a2529c0bda 100644 --- a/p2p/transport/tcp/reuseport.go +++ b/p2p/transport/tcpreuse/reuseport.go @@ -1,4 +1,4 @@ -package tcp +package tcpreuse import ( "os" @@ -11,13 +11,13 @@ import ( // It default to true. const envReuseport = "LIBP2P_TCP_REUSEPORT" -// envReuseportVal stores the value of envReuseport. defaults to true. -var envReuseportVal = true +// EnvReuseportVal stores the value of envReuseport. defaults to true. +var EnvReuseportVal = true func init() { v := strings.ToLower(os.Getenv(envReuseport)) if v == "false" || v == "f" || v == "0" { - envReuseportVal = false + EnvReuseportVal = false log.Infof("REUSEPORT disabled (LIBP2P_TCP_REUSEPORT=%s)", v) } } @@ -31,5 +31,5 @@ func init() { // If this becomes a sought after feature, we could add this to the config. // In the end, reuseport is a stop-gap. func ReuseportIsAvailable() bool { - return envReuseportVal && reuseport.Available() + return EnvReuseportVal && reuseport.Available() } diff --git a/p2p/transport/testsuite/utils_suite.go b/p2p/transport/testsuite/utils_suite.go index 5e488397a5..8b002f8900 100644 --- a/p2p/transport/testsuite/utils_suite.go +++ b/p2p/transport/testsuite/utils_suite.go @@ -11,7 +11,9 @@ import ( ma "github.com/multiformats/go-multiaddr" ) -var Subtests = []func(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID){ +type TransportSubTestFn func(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) + +var Subtests = []TransportSubTestFn{ SubtestProtocols, SubtestBasic, SubtestCancel, @@ -33,12 +35,17 @@ func getFunctionName(i interface{}) string { } func SubtestTransport(t *testing.T, ta, tb transport.Transport, addr string, peerA peer.ID) { + t.Helper() + SubtestTransportWithFs(t, ta, tb, addr, peerA, Subtests) +} + +func SubtestTransportWithFs(t *testing.T, ta, tb transport.Transport, addr string, peerA peer.ID, tests []TransportSubTestFn) { maddr, err := ma.NewMultiaddr(addr) if err != nil { t.Fatal(err) } - for _, f := range Subtests { + for _, f := range tests { t.Run(getFunctionName(f), func(t *testing.T) { f(t, ta, tb, maddr, peerA) }) diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index d4ba3c0550..c3e2f29799 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -33,8 +33,12 @@ func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local } func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote } const ( - candidateSetupTimeout = 20 * time.Second - DefaultMaxInFlightConnections = 10 + candidateSetupTimeout = 10 * time.Second + // This is higher than other transports(64) as there's no way to detect a peer that has gone away after + // sending the initial connection request message(STUN Binding request). Such peers take up a goroutine + // till connection timeout. As the number of handshakes in parallel is still guarded by the resource + // manager, this higher number is okay. + DefaultMaxInFlightConnections = 128 ) type listener struct { @@ -325,8 +329,7 @@ func (l *listener) Multiaddr() ma.Multiaddr { // addOnConnectionStateChangeCallback adds the OnConnectionStateChange to the PeerConnection. // The channel returned here: // * is closed when the state changes to Connection -// * receives an error when the state changes to Failed -// * doesn't receive anything (nor is closed) when the state changes to Disconnected +// * receives an error when the state changes to Failed or Closed or Disconnected func addOnConnectionStateChangeCallback(pc *webrtc.PeerConnection) <-chan error { errC := make(chan error, 1) var once sync.Once @@ -334,17 +337,15 @@ func addOnConnectionStateChangeCallback(pc *webrtc.PeerConnection) <-chan error switch pc.ConnectionState() { case webrtc.PeerConnectionStateConnected: once.Do(func() { close(errC) }) - case webrtc.PeerConnectionStateFailed: + // PeerConnectionStateFailed happens when we fail to negotiate the connection. + // PeerConnectionStateDisconnected happens when we disconnect immediately after connecting. + // PeerConnectionStateClosed happens when we close the peer connection locally, not when remote closes. We don't need + // to error in this case, but it's a no-op, so it doesn't hurt. + case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed, webrtc.PeerConnectionStateDisconnected: once.Do(func() { errC <- errors.New("peerconnection failed") close(errC) }) - case webrtc.PeerConnectionStateDisconnected: - // the connection can move to a disconnected state and back to a connected state without ICE renegotiation. - // This could happen when underlying UDP packets are lost, and therefore the connection moves to the disconnected state. - // If the connection then receives packets on the connection, it can move back to the connected state. - // If no packets are received until the failed timeout is triggered, the connection moves to the failed state. - log.Warn("peerconnection disconnected") } }) return errC diff --git a/p2p/transport/websocket/addrs_test.go b/p2p/transport/websocket/addrs_test.go index 3c5ba502a9..50a8b9e823 100644 --- a/p2p/transport/websocket/addrs_test.go +++ b/p2p/transport/websocket/addrs_test.go @@ -69,7 +69,7 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) { } func TestListeningOnDNSAddr(t *testing.T) { - ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil) + ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil) require.NoError(t, err) addr := ln.Multiaddr() first, rest := ma.SplitFirst(addr) diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 30b70055d0..ce51611703 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -1,6 +1,7 @@ package websocket import ( + "errors" "io" "net" "sync" @@ -8,6 +9,8 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" ws "github.com/gorilla/websocket" ) @@ -22,20 +25,53 @@ type Conn struct { secure bool DefaultMessageType int reader io.Reader - closeOnce sync.Once + closeOnceVal func() error + laddr ma.Multiaddr + raddr ma.Multiaddr readLock, writeLock sync.Mutex } var _ net.Conn = (*Conn)(nil) +var _ manet.Conn = (*Conn)(nil) // NewConn creates a Conn given a regular gorilla/websocket Conn. +// +// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release. func NewConn(raw *ws.Conn, secure bool) *Conn { - return &Conn{ + lna := NewAddrWithScheme(raw.LocalAddr().String(), secure) + laddr, err := manet.FromNetAddr(lna) + if err != nil { + log.Errorf("BUG: invalid localaddr on websocket conn", raw.LocalAddr()) + return nil + } + + rna := NewAddrWithScheme(raw.RemoteAddr().String(), secure) + raddr, err := manet.FromNetAddr(rna) + if err != nil { + log.Errorf("BUG: invalid remoteaddr on websocket conn", raw.RemoteAddr()) + return nil + } + + c := &Conn{ Conn: raw, secure: secure, DefaultMessageType: ws.BinaryMessage, + laddr: laddr, + raddr: raddr, } + c.closeOnceVal = sync.OnceValue(c.closeOnceFn) + return c +} + +// LocalMultiaddr implements manet.Conn. +func (c *Conn) LocalMultiaddr() ma.Multiaddr { + return c.laddr +} + +// RemoteMultiaddr implements manet.Conn. +func (c *Conn) RemoteMultiaddr() ma.Multiaddr { + return c.raddr } func (c *Conn) Read(b []byte) (int, error) { @@ -99,26 +135,31 @@ func (c *Conn) Write(b []byte) (n int, err error) { return len(b), nil } -// Close closes the connection. Only the first call to Close will receive the -// close error, subsequent and concurrent calls will return nil. +func (c *Conn) Scope() network.ConnManagementScope { + nc := c.NetConn() + if sc, ok := nc.(interface { + Scope() network.ConnManagementScope + }); ok { + return sc.Scope() + } + return nil +} + +// Close closes the connection. +// subsequent and concurrent calls will return the same error value. // This method is thread-safe. func (c *Conn) Close() error { - var err error - c.closeOnce.Do(func() { - err1 := c.Conn.WriteControl( - ws.CloseMessage, - ws.FormatCloseMessage(ws.CloseNormalClosure, "closed"), - time.Now().Add(GracefulCloseTimeout), - ) - err2 := c.Conn.Close() - switch { - case err1 != nil: - err = err1 - case err2 != nil: - err = err2 - } - }) - return err + return c.closeOnceVal() +} + +func (c *Conn) closeOnceFn() error { + err1 := c.Conn.WriteControl( + ws.CloseMessage, + ws.FormatCloseMessage(ws.CloseNormalClosure, "closed"), + time.Now().Add(GracefulCloseTimeout), + ) + err2 := c.Conn.Close() + return errors.Join(err1, err2) } func (c *Conn) LocalAddr() net.Addr { diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 8071ddb814..dd399aa079 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -4,14 +4,16 @@ import ( "crypto/tls" "errors" "fmt" - "go.uber.org/zap" "net" "net/http" "sync" + "go.uber.org/zap" + logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" @@ -50,7 +52,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { // newListener creates a new listener from a raw net.Listener. // tlsConf may be nil (for unencrypted websockets). -func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { +func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) { parsed, err := parseWebsocketMultiaddr(a) if err != nil { return nil, err @@ -60,19 +62,36 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a) } - lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) - if err != nil { - return nil, err - } - nl, err := net.Listen(lnet, lnaddr) - if err != nil { - return nil, err + var nl net.Listener + + if sharedTcp == nil { + lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) + if err != nil { + return nil, err + } + nl, err = net.Listen(lnet, lnaddr) + if err != nil { + return nil, err + } + } else { + var connType tcpreuse.DemultiplexedConnType + if parsed.isWSS { + connType = tcpreuse.DemultiplexedConnType_TLS + } else { + connType = tcpreuse.DemultiplexedConnType_HTTP + } + mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType) + if err != nil { + return nil, err + } + nl = manet.NetListener(mal) } laddr, err := manet.FromNetAddr(nl.Addr()) if err != nil { return nil, err } + first, _ := ma.SplitFirst(a) // Don't resolve dns addresses. // We want to be able to announce domain names, so the peer can validate the TLS certificate. @@ -111,7 +130,12 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { // The upgrader writes a response for us. return } - + nc := NewConn(c, l.isWss) + if nc == nil { + c.Close() + w.WriteHeader(500) + return + } select { case l.incoming <- NewConn(c, l.isWss): case <-l.closed: @@ -126,13 +150,7 @@ func (l *listener) Accept() (manet.Conn, error) { if !ok { return nil, transport.ErrListenerClosed } - - mnc, err := manet.WrapNetConn(c) - if err != nil { - c.Close() - return nil, err - } - return mnc, nil + return c, nil case <-l.closed: return nil, transport.ErrListenerClosed } diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 0f07617dc7..e24cb88c6d 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" @@ -87,11 +88,13 @@ type WebsocketTransport struct { tlsClientConf *tls.Config tlsConf *tls.Config + + sharedTcp *tcpreuse.ConnMgr } var _ transport.Transport = (*WebsocketTransport)(nil) -func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*WebsocketTransport, error) { +func New(u transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*WebsocketTransport, error) { if rcmgr == nil { rcmgr = &network.NullResourceManager{} } @@ -99,6 +102,7 @@ func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (* upgrader: u, rcmgr: rcmgr, tlsClientConf: &tls.Config{}, + sharedTcp: sharedTCP, } for _, opt := range opts { if err := opt(t); err != nil { @@ -233,7 +237,7 @@ func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) { if t.tlsConf != nil { tlsConf = t.tlsConf.Clone() } - l, err := newListener(a, tlsConf) + l, err := newListener(a, tlsConf, t.sharedTcp) if err != nil { return nil, err } diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 8f912c4138..9ca03775a2 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -154,7 +154,7 @@ func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID } id, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSConfig(tlsConf)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSConfig(tlsConf)) if err != nil { t.Fatal(err) } @@ -237,7 +237,7 @@ func TestHostHeaderWss(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig)) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -256,7 +256,7 @@ func TestDialWss(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig)) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -279,7 +279,7 @@ func TestDialWssNoClientCert(t *testing.T) { require.Contains(t, serverMA.String(), "tls") _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -294,12 +294,12 @@ func TestDialWssNoClientCert(t *testing.T) { func TestWebsocketTransport(t *testing.T) { peerA, ua := newUpgrader(t) - ta, err := New(ua, nil) + ta, err := New(ua, nil, nil) if err != nil { t.Fatal(err) } _, ub := newUpgrader(t) - tb, err := New(ub, nil) + tb, err := New(ub, nil, nil) if err != nil { t.Fatal(err) } @@ -325,7 +325,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { opts = append(opts, WithTLSConfig(tlsConf)) } server, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, opts...) + tpt, err := New(u, &network.NullResourceManager{}, nil, opts...) require.NoError(t, err) l, err := tpt.Listen(laddr) require.NoError(t, err) @@ -344,7 +344,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { opts = append(opts, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) } _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, opts...) + tpt, err := New(u, &network.NullResourceManager{}, nil, opts...) require.NoError(t, err) c, err := tpt.Dial(context.Background(), l.Multiaddr(), server) require.NoError(t, err) @@ -382,7 +382,7 @@ func TestWebsocketConnection(t *testing.T) { func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) addr := ma.StringCast("/ip4/127.0.0.1/tcp/0/wss") _, err = tpt.Listen(addr) @@ -391,7 +391,7 @@ func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) { func TestWebsocketListenSecureAndInsecure(t *testing.T) { serverID, serverUpgrader := newUpgrader(t) - server, err := New(serverUpgrader, &network.NullResourceManager{}, WithTLSConfig(generateTLSConfig(t))) + server, err := New(serverUpgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t))) require.NoError(t, err) lnInsecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) @@ -401,7 +401,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { t.Run("insecure", func(t *testing.T) { _, clientUpgrader := newUpgrader(t) - client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) require.NoError(t, err) // dialing the insecure address should succeed @@ -418,7 +418,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { t.Run("secure", func(t *testing.T) { _, clientUpgrader := newUpgrader(t) - client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) require.NoError(t, err) // dialing the insecure address should succeed @@ -436,7 +436,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { func TestConcurrentClose(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) if err != nil { @@ -474,7 +474,7 @@ func TestConcurrentClose(t *testing.T) { func TestWriteZero(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) if err != nil { t.Fatal(err) } diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index 0525124711..d914398e0e 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -71,7 +71,7 @@ func (c *conn) allowWindowIncrease(size uint64) bool { // It must be called even if the peer closed the connection in order for // garbage collection to properly work in this package. func (c *conn) Close() error { - c.scope.Done() + defer c.scope.Done() c.transport.removeConn(c.session) err := c.session.CloseWithError(0, "") _ = c.qconn.CloseWithError(1, "") diff --git a/test-plans/go.mod b/test-plans/go.mod index 5b0d9980bf..e61b3e1889 100644 --- a/test-plans/go.mod +++ b/test-plans/go.mod @@ -61,7 +61,7 @@ require ( github.com/multiformats/go-multibase v0.2.0 // indirect github.com/multiformats/go-multicodec v0.9.0 // indirect github.com/multiformats/go-multihash v0.2.3 // indirect - github.com/multiformats/go-multistream v0.5.0 // indirect + github.com/multiformats/go-multistream v0.6.0 // indirect github.com/multiformats/go-varint v0.0.7 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/onsi/ginkgo/v2 v2.20.2 // indirect @@ -90,7 +90,7 @@ require ( github.com/prometheus/common v0.60.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect - github.com/quic-go/quic-go v0.48.0 // indirect + github.com/quic-go/quic-go v0.48.1 // indirect github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 // indirect github.com/raulk/go-watchdog v1.3.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect diff --git a/test-plans/go.sum b/test-plans/go.sum index cd5f393530..d37b876a6f 100644 --- a/test-plans/go.sum +++ b/test-plans/go.sum @@ -197,8 +197,8 @@ github.com/multiformats/go-multicodec v0.9.0/go.mod h1:L3QTQvMIaVBkXOXXtVmYE+LI1 github.com/multiformats/go-multihash v0.0.8/go.mod h1:YSLudS+Pi8NHE7o6tb3D8vrpKa63epEDmG8nTduyAew= github.com/multiformats/go-multihash v0.2.3 h1:7Lyc8XfX/IY2jWb/gI7JP+o7JEq9hOa7BFvVU9RSh+U= github.com/multiformats/go-multihash v0.2.3/go.mod h1:dXgKXCXjBzdscBLk9JkjINiEsCKRVch90MdaGiKsvSM= -github.com/multiformats/go-multistream v0.5.0 h1:5htLSLl7lvJk3xx3qT/8Zm9J4K8vEOf/QGkvOGQAyiE= -github.com/multiformats/go-multistream v0.5.0/go.mod h1:n6tMZiwiP2wUsR8DgfDWw1dydlEqV3l6N3/GBsX6ILA= +github.com/multiformats/go-multistream v0.6.0 h1:ZaHKbsL404720283o4c/IHQXiS6gb8qAN5EIJ4PN5EA= +github.com/multiformats/go-multistream v0.6.0/go.mod h1:MOyoG5otO24cHIg8kf9QW2/NozURlkP/rvi2FQJyCPg= github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8= github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -280,8 +280,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.48.0 h1:2TCyvBrMu1Z25rvIAlnp2dPT4lgh/uTqLqiXVpp5AeU= -github.com/quic-go/quic-go v0.48.0/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= +github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA= +github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 h1:4WFk6u3sOT6pLa1kQ50ZVdm8BQFgJNA117cepZxtLIg= github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66/go.mod h1:Vp72IJajgeOL6ddqrAhmp7IM9zbTcgkQxD/YdxrVwMw= github.com/raulk/go-watchdog v1.3.0 h1:oUmdlHxdkXRJlwfG0O9omj8ukerm8MEQavSiDTEtBsk= diff --git a/version.json b/version.json index 53072426c1..707e97ed5b 100644 --- a/version.json +++ b/version.json @@ -1,3 +1,3 @@ { - "version": "v0.36.2" + "version": "v0.37.0" }