diff --git a/CHANGELOG.md b/CHANGELOG.md
index e1379979e..7f4d0c094 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -68,6 +68,7 @@ The following emojis are used to highlight certain changes:
 - `gateway`: `NewCacheBlockStore` and `NewCarBackend` will use `prometheus.DefaultRegisterer` when a custom one is not specified via `WithPrometheusRegistry` [#722](https://github.com/ipfs/boxo/pull/722)
 - `filestore`: added opt-in `WithMMapReader` option to `FileManager` to enable memory-mapped file reads [#665](https://github.com/ipfs/boxo/pull/665)
 - `bitswap/routing` `ProviderQueryManager` does not require calling `Startup` separate from `New`. [#741](https://github.com/ipfs/boxo/pull/741)
+- `bitswap/routing` `ProviderQueryManager` no longer uses a lifecycle context; call `Close` to shut it down.
 
 ### Changed
 
diff --git a/bitswap/client/bitswap_with_sessions_test.go b/bitswap/client/bitswap_with_sessions_test.go
index 5d5ac8226..2fee84217 100644
--- a/bitswap/client/bitswap_with_sessions_test.go
+++ b/bitswap/client/bitswap_with_sessions_test.go
@@ -117,9 +117,6 @@ func assertBlockListsFrom(from peer.ID, got, exp []blocks.Block) error {
 // TestCustomProviderQueryManager tests that nothing breaks if we use a custom
 // PQM when creating bitswap.
 func TestCustomProviderQueryManager(t *testing.T) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
 	vnet := getVirtualNetwork()
 	router := mockrouting.NewServer()
 	ig := testinstance.NewTestInstanceGenerator(vnet, router, nil, nil)
@@ -130,10 +127,15 @@ func TestCustomProviderQueryManager(t *testing.T) {
 	b := ig.Next()
 
 	// Replace bitswap in instance a with our customized one.
-	pqm, err := providerquerymanager.New(ctx, a.Adapter, router.Client(a.Identity))
+	pqm, err := providerquerymanager.New(a.Adapter, router.Client(a.Identity))
 	if err != nil {
 		t.Fatal(err)
 	}
+	defer pqm.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	bs := bitswap.New(ctx, a.Adapter, pqm, a.Blockstore,
 		bitswap.WithClientOption(client.WithDefaultProviderQueryManager(false)))
 	a.Exchange.Close() // close old to be sure.
diff --git a/bitswap/client/client.go b/bitswap/client/client.go
index 5f950588a..a115d07f6 100644
--- a/bitswap/client/client.go
+++ b/bitswap/client/client.go
@@ -182,7 +182,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, providerFinder Pr
 
 	if bs.providerFinder != nil && bs.defaultProviderQueryManager {
 		// network can do dialing.
-		pqm, err := rpqm.New(ctx, network, bs.providerFinder,
+		pqm, err := rpqm.New(network, bs.providerFinder,
 			rpqm.WithMaxInProcessRequests(16),
 			rpqm.WithMaxProviders(10),
 			rpqm.WithMaxTimeout(10*time.Second))
@@ -512,6 +512,9 @@ func (bs *Client) Close() error {
 		close(bs.closing)
 		bs.sm.Shutdown()
 		bs.cancel()
+		if bs.pqm != nil {
+			bs.pqm.Close()
+		}
 		bs.notif.Shutdown()
 	})
 	return nil
diff --git a/routing/providerquerymanager/providerquerymanager.go b/routing/providerquerymanager/providerquerymanager.go
index 98497ee66..592f7f814 100644
--- a/routing/providerquerymanager/providerquerymanager.go
+++ b/routing/providerquerymanager/providerquerymanager.go
@@ -85,7 +85,8 @@ type cancelRequestMessage struct {
 // - ensure two findprovider calls for the same block don't run concurrently
 // - manage timeouts
 type ProviderQueryManager struct {
-	ctx                   context.Context
+	closeOnce             sync.Once
+	closing               chan struct{}
 	dialer                ProviderQueryDialer
 	router                ProviderQueryRouter
 	providerQueryMessages chan providerQueryMessage
@@ -133,9 +134,9 @@ func WithMaxProviders(count int) Option {
 
 // New initializes a new ProviderQueryManager for a given context and a given
 // network provider.
-func New(ctx context.Context, dialer ProviderQueryDialer, router ProviderQueryRouter, opts ...Option) (*ProviderQueryManager, error) {
+func New(dialer ProviderQueryDialer, router ProviderQueryRouter, opts ...Option) (*ProviderQueryManager, error) {
 	pqm := &ProviderQueryManager{
-		ctx:                   ctx,
+		closing:               make(chan struct{}),
 		dialer:                dialer,
 		router:                router,
 		providerQueryMessages: make(chan providerQueryMessage),
@@ -155,6 +156,12 @@ func New(ctx context.Context, dialer ProviderQueryDialer, router ProviderQueryRo
 	return pqm, nil
 }
 
+func (pqm *ProviderQueryManager) Close() {
+	pqm.closeOnce.Do(func() {
+		close(pqm.closing)
+	})
+}
+
 type inProgressRequest struct {
 	providersSoFar []peer.AddrInfo
 	incoming       chan peer.AddrInfo
@@ -180,7 +187,7 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context,
 		k:                     k,
 		inProgressRequestChan: inProgressRequestChan,
 	}:
-	case <-pqm.ctx.Done():
+	case <-pqm.closing:
 		ch := make(chan peer.AddrInfo)
 		close(ch)
 		span.End()
@@ -196,7 +203,7 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context,
 	// get to receiveProviders.
 	var receivedInProgressRequest inProgressRequest
 	select {
-	case <-pqm.ctx.Done():
+	case <-pqm.closing:
 		ch := make(chan peer.AddrInfo)
 		close(ch)
 		span.End()
@@ -256,7 +263,7 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
 
 	for receivedProviders.Len() > 0 || incomingProviders != nil {
 		select {
-		case <-pqm.ctx.Done():
+		case <-pqm.closing:
 			return
 		case <-sessionCtx.Done():
 			if incomingProviders != nil {
@@ -300,7 +307,7 @@ func (pqm *ProviderQueryManager) cancelProviderRequest(ctx context.Context, k ci
 			if !ok {
 				return
 			}
-		case <-pqm.ctx.Done():
+		case <-pqm.closing:
 			return
 		}
 	}
@@ -316,13 +323,13 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
 	}
 
 	// Read find provider requests until channel is closed. The channel is
-	// closed as soon as pqm.ctx is canceled, so there is no need to select on
-	// that context here.
+	// closed as soon as pqm.Close is called, so there is no need to select on
+	// any other channel to detect shutdown.
 	for fpr := range pqm.providerRequestsProcessing.Out() {
 		if findSem != nil {
 			select {
 			case findSem <- struct{}{}:
-			case <-pqm.ctx.Done():
+			case <-pqm.closing:
 				return
 			}
 		}
@@ -362,7 +369,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
 						k: k,
 						p: p,
 					}:
-					case <-pqm.ctx.Done():
+					case <-pqm.closing:
 						return
 					}
 				}(p)
@@ -374,7 +381,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
 				ctx: ctx,
 				k:   k,
 			}:
-			case <-pqm.ctx.Done():
+			case <-pqm.closing:
 			}
 		}(fpr.ctx, fpr.k)
 	}
@@ -402,7 +409,7 @@ func (pqm *ProviderQueryManager) run() {
 		case nextMessage := <-pqm.providerQueryMessages:
 			nextMessage.debugMessage()
 			nextMessage.handle(pqm)
-		case <-pqm.ctx.Done():
+		case <-pqm.closing:
 			return
 		}
 	}
@@ -423,7 +430,7 @@ func (rpm *receivedProviderMessage) handle(pqm *ProviderQueryManager) {
 	for listener := range requestStatus.listeners {
 		select {
 		case listener <- rpm.p:
-		case <-pqm.ctx.Done():
+		case <-pqm.closing:
 			return
 		}
 	}
@@ -458,12 +465,12 @@ func (npqm *newProvideQueryMessage) debugMessage() {
 func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
 	requestStatus, ok := pqm.inProgressRequestStatuses[npqm.k]
 	if !ok {
-		ctx, cancelFn := context.WithCancel(pqm.ctx)
+		ctx, cancelFn := context.WithCancel(context.Background())
 		span := trace.SpanFromContext(npqm.ctx)
 		span.AddEvent("NewQuery", trace.WithAttributes(attribute.Stringer("cid", npqm.k)))
 		ctx = trace.ContextWithSpan(ctx, span)
 
-		// Use context derived from pqm.ctx here, and not the context from the
+		// Use context derived from background here, and not the context from the
 		// request (npqm.ctx), because this inProgressRequestStatus applies to
 		// all in-progress requests for the CID (npqm.k).
 		//
@@ -486,7 +493,7 @@ func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
 			k:   npqm.k,
 			ctx: ctx,
 		}:
-		case <-pqm.ctx.Done():
+		case <-pqm.closing:
 			return
 		}
 	} else {
@@ -502,7 +509,7 @@ func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
 			providersSoFar: requestStatus.providersSoFar,
 			incoming:       inProgressChan,
 		}:
-		case <-pqm.ctx.Done():
+		case <-pqm.closing:
 		}
 	}
 
diff --git a/routing/providerquerymanager/providerquerymanager_test.go b/routing/providerquerymanager/providerquerymanager_test.go
index 1be26c4e3..8026c5364 100644
--- a/routing/providerquerymanager/providerquerymanager_test.go
+++ b/routing/providerquerymanager/providerquerymanager_test.go
@@ -74,11 +74,11 @@ func TestNormalSimultaneousFetch(t *testing.T) {
 		peersFound: peers,
 		delay:      1 * time.Millisecond,
 	}
-	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
+	providerQueryManager := mustNotErr(New(fpd, fpn))
+	defer providerQueryManager.Close()
 	keys := random.Cids(2)
 
-	sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	sessionCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
 	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
 	secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[1], 0)
@@ -111,11 +111,11 @@ func TestDedupingProviderRequests(t *testing.T) {
 		peersFound: peers,
 		delay:      1 * time.Millisecond,
 	}
-	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
+	providerQueryManager := mustNotErr(New(fpd, fpn))
+	defer providerQueryManager.Close()
 	key := random.Cids(1)[0]
 
-	sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	sessionCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
 	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
 	secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
@@ -151,12 +151,13 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) {
 		peersFound: peers,
 		delay:      1 * time.Millisecond,
 	}
 
-	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
+	providerQueryManager := mustNotErr(New(fpd, fpn))
+	defer providerQueryManager.Close()
 	key := random.Cids(1)[0]
 
 	// first session will cancel before done
+	ctx := context.Background()
 	firstSessionCtx, firstCancel := context.WithTimeout(ctx, 3*time.Millisecond)
 	defer firstCancel()
 	firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key, 0)
@@ -195,14 +196,13 @@ func TestCancelManagerExitsGracefully(t *testing.T) {
 		peersFound: peers,
 		delay:      1 * time.Millisecond,
 	}
 
-	ctx := context.Background()
-	managerCtx, managerCancel := context.WithTimeout(ctx, 5*time.Millisecond)
-	defer managerCancel()
-	providerQueryManager := mustNotErr(New(managerCtx, fpd, fpn))
+	providerQueryManager := mustNotErr(New(fpd, fpn))
+	defer providerQueryManager.Close()
+	time.AfterFunc(5*time.Millisecond, providerQueryManager.Close)
 	key := random.Cids(1)[0]
 
-	sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
+	sessionCtx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
 	defer cancel()
 	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
 	secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
@@ -232,12 +232,12 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) {
 		peersFound: peers,
 		delay:      1 * time.Millisecond,
 	}
 
-	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
+	providerQueryManager := mustNotErr(New(fpd, fpn))
+	defer providerQueryManager.Close()
 	key := random.Cids(1)[0]
 
-	sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
+	sessionCtx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
 	defer cancel()
 	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
 	secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
@@ -266,13 +266,11 @@ func TestRateLimitingRequests(t *testing.T) {
 		peersFound: peers,
 		delay:      5 * time.Millisecond,
 	}
-	ctx := context.Background()
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxInProcessRequests(maxInProcessRequests)))
+	providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxInProcessRequests(maxInProcessRequests)))
+	defer providerQueryManager.Close()
 	keys := random.Cids(maxInProcessRequests + 1)
 
-	sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	sessionCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
 	var requestChannels []<-chan peer.AddrInfo
 	for i := 0; i < maxInProcessRequests+1; i++ {
@@ -307,11 +305,11 @@ func TestUnlimitedRequests(t *testing.T) {
 		peersFound: peers,
 		delay:      5 * time.Millisecond,
 	}
-	ctx := context.Background()
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxInProcessRequests(0)))
+	providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxInProcessRequests(0)))
+	defer providerQueryManager.Close()
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
 	keys := random.Cids(inProcessRequests)
 
 	sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
 	defer cancel()
@@ -346,11 +344,11 @@ func TestFindProviderTimeout(t *testing.T) {
 		peersFound: peers,
 		delay:      10 * time.Millisecond,
 	}
-	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxTimeout(2*time.Millisecond)))
+	providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxTimeout(2*time.Millisecond)))
+	defer providerQueryManager.Close()
 	keys := random.Cids(1)
 
-	sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	sessionCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
 	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
 	var firstPeersReceived []peer.AddrInfo
@@ -369,11 +367,11 @@ func TestFindProviderPreCanceled(t *testing.T) {
 		peersFound: peers,
 		delay:      1 * time.Millisecond,
 	}
-	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxTimeout(100*time.Millisecond)))
+	providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxTimeout(100*time.Millisecond)))
+	defer providerQueryManager.Close()
 	keys := random.Cids(1)
 
-	sessionCtx, cancel := context.WithCancel(ctx)
+	sessionCtx, cancel := context.WithCancel(context.Background())
 	cancel()
 	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
 	if firstRequestChan == nil {
@@ -393,11 +391,11 @@ func TestCancelFindProvidersAfterCompletion(t *testing.T) {
 		peersFound: peers,
 		delay:      1 * time.Millisecond,
 	}
-	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxTimeout(100*time.Millisecond)))
+	providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxTimeout(100*time.Millisecond)))
+	defer providerQueryManager.Close()
 	keys := random.Cids(1)
 
-	sessionCtx, cancel := context.WithCancel(ctx)
+	sessionCtx, cancel := context.WithCancel(context.Background())
 	firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
 	<-firstRequestChan                // wait for everything to start.
 	time.Sleep(10 * time.Millisecond) // wait for the incoming providers to stop.
@@ -425,11 +423,11 @@ func TestLimitedProviders(t *testing.T) {
 		peersFound: peers,
 		delay:      1 * time.Millisecond,
 	}
-	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxProviders(max), WithMaxTimeout(100*time.Millisecond)))
+	providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxProviders(max), WithMaxTimeout(100*time.Millisecond)))
+	defer providerQueryManager.Close()
 	keys := random.Cids(1)
 
-	providersChan := providerQueryManager.FindProvidersAsync(ctx, keys[0], 0)
+	providersChan := providerQueryManager.FindProvidersAsync(context.Background(), keys[0], 0)
 	total := 0
 	for range providersChan {
 		total++
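
Below is a minimal caller-side sketch (not part of the patch) of the new lifecycle this diff introduces: `New` no longer takes a context, and the manager is shut down with `Close` instead of by canceling a lifecycle context. The helper name `findProviders` and the `dialer`/`router` values are illustrative; they are assumed to satisfy the package's `ProviderQueryDialer` and `ProviderQueryRouter` interfaces.

```go
package example

import (
	"context"
	"fmt"
	"time"

	"github.com/ipfs/boxo/routing/providerquerymanager"
	"github.com/ipfs/go-cid"
)

// findProviders is a hypothetical helper showing the post-patch lifecycle:
// construct without a context, query with per-call session contexts, and
// Close when done.
func findProviders(dialer providerquerymanager.ProviderQueryDialer, router providerquerymanager.ProviderQueryRouter, key cid.Cid) error {
	// New no longer takes a lifecycle context (see the signature change above).
	pqm, err := providerquerymanager.New(dialer, router,
		providerquerymanager.WithMaxTimeout(10*time.Second))
	if err != nil {
		return err
	}
	// Close replaces canceling the old lifecycle context; the sync.Once in the
	// patch makes repeated calls safe.
	defer pqm.Close()

	// Per-query lifetime is still bounded by the session context passed here.
	sessionCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	for p := range pqm.FindProvidersAsync(sessionCtx, key, 0) {
		fmt.Println("found provider:", p.ID)
	}
	return nil
}
```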
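For readers unfamiliar with the pattern the patch adopts: every `<-pqm.ctx.Done()` select case becomes `<-pqm.closing`, a channel that `Close` closes exactly once via `sync.Once`. A generic, self-contained sketch of that close-once shutdown signal, using illustrative names (`worker`, `jobs`) that are not from the patch:

```go
package example

import "sync"

// worker mirrors the shutdown mechanics of ProviderQueryManager in miniature.
type worker struct {
	closeOnce sync.Once
	closing   chan struct{}
	jobs      chan func()
}

func newWorker() *worker {
	w := &worker{
		closing: make(chan struct{}),
		jobs:    make(chan func()),
	}
	go w.run()
	return w
}

// Close is safe to call any number of times; only the first call closes the channel.
func (w *worker) Close() {
	w.closeOnce.Do(func() { close(w.closing) })
}

func (w *worker) run() {
	for {
		select {
		case job := <-w.jobs:
			job()
		case <-w.closing:
			// A closed channel is always ready to receive, so every select
			// that includes this case unblocks promptly after Close.
			return
		}
	}
}
```

Compared with storing a lifecycle `context.Context` in the struct, the closed channel carries no values and no deadline, which matches the change's intent that shutdown be an explicit caller action rather than something inherited from a parent context.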