Skip to content

Commit

Permalink
chore: no lifecycle context to shutdown ProviderQueryManager (#734)
Browse files Browse the repository at this point in the history
* no lifecycle context to shutdown ProviderQueryManager, use Close function instead.
  • Loading branch information
gammazero authored Dec 6, 2024
1 parent f6befaf commit ef25808
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 60 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ The following emojis are used to highlight certain changes:
- `gateway`: `NewCacheBlockStore` and `NewCarBackend` will use `prometheus.DefaultRegisterer` when a custom one is not specified via `WithPrometheusRegistry` [#722](https://github.com/ipfs/boxo/pull/722)
- `filestore`: added opt-in `WithMMapReader` option to `FileManager` to enable memory-mapped file reads [#665](https://github.com/ipfs/boxo/pull/665)
- `bitswap/routing` `ProviderQueryManager` does not require calling `Startup` separate from `New`. [#741](https://github.com/ipfs/boxo/pull/741)
- `bitswap/routing` `ProviderQueryManager` does not use lifecycle context; use `Close` to shut it down.

### Changed

Expand Down
10 changes: 6 additions & 4 deletions bitswap/client/bitswap_with_sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,6 @@ func assertBlockListsFrom(from peer.ID, got, exp []blocks.Block) error {
// TestCustomProviderQueryManager tests that nothing breaks if we use a custom
// PQM when creating bitswap.
func TestCustomProviderQueryManager(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vnet := getVirtualNetwork()
router := mockrouting.NewServer()
ig := testinstance.NewTestInstanceGenerator(vnet, router, nil, nil)
Expand All @@ -130,10 +127,15 @@ func TestCustomProviderQueryManager(t *testing.T) {
b := ig.Next()

// Replace bitswap in instance a with our customized one.
pqm, err := providerquerymanager.New(ctx, a.Adapter, router.Client(a.Identity))
pqm, err := providerquerymanager.New(a.Adapter, router.Client(a.Identity))
if err != nil {
t.Fatal(err)
}
defer pqm.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

bs := bitswap.New(ctx, a.Adapter, pqm, a.Blockstore,
bitswap.WithClientOption(client.WithDefaultProviderQueryManager(false)))
a.Exchange.Close() // close old to be sure.
Expand Down
5 changes: 4 additions & 1 deletion bitswap/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, providerFinder Pr

if bs.providerFinder != nil && bs.defaultProviderQueryManager {
// network can do dialing.
pqm, err := rpqm.New(ctx, network, bs.providerFinder,
pqm, err := rpqm.New(network, bs.providerFinder,
rpqm.WithMaxInProcessRequests(16),
rpqm.WithMaxProviders(10),
rpqm.WithMaxTimeout(10*time.Second))
Expand Down Expand Up @@ -512,6 +512,9 @@ func (bs *Client) Close() error {
close(bs.closing)
bs.sm.Shutdown()
bs.cancel()
if bs.pqm != nil {
bs.pqm.Close()
}
bs.notif.Shutdown()
})
return nil
Expand Down
43 changes: 25 additions & 18 deletions routing/providerquerymanager/providerquerymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ type cancelRequestMessage struct {
// - ensure two findprovider calls for the same block don't run concurrently
// - manage timeouts
type ProviderQueryManager struct {
ctx context.Context
closeOnce sync.Once
closing chan struct{}
dialer ProviderQueryDialer
router ProviderQueryRouter
providerQueryMessages chan providerQueryMessage
Expand Down Expand Up @@ -133,9 +134,9 @@ func WithMaxProviders(count int) Option {

// New initializes a new ProviderQueryManager for a given dialer and a given
// provider router. Call Close to shut the manager down.
func New(ctx context.Context, dialer ProviderQueryDialer, router ProviderQueryRouter, opts ...Option) (*ProviderQueryManager, error) {
func New(dialer ProviderQueryDialer, router ProviderQueryRouter, opts ...Option) (*ProviderQueryManager, error) {
pqm := &ProviderQueryManager{
ctx: ctx,
closing: make(chan struct{}),
dialer: dialer,
router: router,
providerQueryMessages: make(chan providerQueryMessage),
Expand All @@ -155,6 +156,12 @@ func New(ctx context.Context, dialer ProviderQueryDialer, router ProviderQueryRo
return pqm, nil
}

// Close signals shutdown of the ProviderQueryManager by closing its closing
// channel. The close is guarded by closeOnce, so calling Close more than once
// is safe and only the first call has any effect.
func (pqm *ProviderQueryManager) Close() {
	signalShutdown := func() {
		close(pqm.closing)
	}
	pqm.closeOnce.Do(signalShutdown)
}

type inProgressRequest struct {
providersSoFar []peer.AddrInfo
incoming chan peer.AddrInfo
Expand All @@ -180,7 +187,7 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context,
k: k,
inProgressRequestChan: inProgressRequestChan,
}:
case <-pqm.ctx.Done():
case <-pqm.closing:
ch := make(chan peer.AddrInfo)
close(ch)
span.End()
Expand All @@ -196,7 +203,7 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context,
// get to receiveProviders.
var receivedInProgressRequest inProgressRequest
select {
case <-pqm.ctx.Done():
case <-pqm.closing:
ch := make(chan peer.AddrInfo)
close(ch)
span.End()
Expand Down Expand Up @@ -256,7 +263,7 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k

for receivedProviders.Len() > 0 || incomingProviders != nil {
select {
case <-pqm.ctx.Done():
case <-pqm.closing:
return
case <-sessionCtx.Done():
if incomingProviders != nil {
Expand Down Expand Up @@ -300,7 +307,7 @@ func (pqm *ProviderQueryManager) cancelProviderRequest(ctx context.Context, k ci
if !ok {
return
}
case <-pqm.ctx.Done():
case <-pqm.closing:
return
}
}
Expand All @@ -316,13 +323,13 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
}

// Read find provider requests until channel is closed. The channel is
// closed as soon as pqm.ctx is canceled, so there is no need to select on
// that context here.
// closed as soon as pqm.Close is called, so there is no need to select on
// any other channel to detect shutdown.
for fpr := range pqm.providerRequestsProcessing.Out() {
if findSem != nil {
select {
case findSem <- struct{}{}:
case <-pqm.ctx.Done():
case <-pqm.closing:
return
}
}
Expand Down Expand Up @@ -362,7 +369,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
k: k,
p: p,
}:
case <-pqm.ctx.Done():
case <-pqm.closing:
return
}
}(p)
Expand All @@ -374,7 +381,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
ctx: ctx,
k: k,
}:
case <-pqm.ctx.Done():
case <-pqm.closing:
}
}(fpr.ctx, fpr.k)
}
Expand Down Expand Up @@ -402,7 +409,7 @@ func (pqm *ProviderQueryManager) run() {
case nextMessage := <-pqm.providerQueryMessages:
nextMessage.debugMessage()
nextMessage.handle(pqm)
case <-pqm.ctx.Done():
case <-pqm.closing:
return
}
}
Expand All @@ -423,7 +430,7 @@ func (rpm *receivedProviderMessage) handle(pqm *ProviderQueryManager) {
for listener := range requestStatus.listeners {
select {
case listener <- rpm.p:
case <-pqm.ctx.Done():
case <-pqm.closing:
return
}
}
Expand Down Expand Up @@ -458,12 +465,12 @@ func (npqm *newProvideQueryMessage) debugMessage() {
func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
requestStatus, ok := pqm.inProgressRequestStatuses[npqm.k]
if !ok {
ctx, cancelFn := context.WithCancel(pqm.ctx)
ctx, cancelFn := context.WithCancel(context.Background())
span := trace.SpanFromContext(npqm.ctx)
span.AddEvent("NewQuery", trace.WithAttributes(attribute.Stringer("cid", npqm.k)))
ctx = trace.ContextWithSpan(ctx, span)

// Use context derived from pqm.ctx here, and not the context from the
// Use context derived from background here, and not the context from the
// request (npqm.ctx), because this inProgressRequestStatus applies to
// all in-progress requests for the CID (npqm.k).
//
Expand All @@ -486,7 +493,7 @@ func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
k: npqm.k,
ctx: ctx,
}:
case <-pqm.ctx.Done():
case <-pqm.closing:
return
}
} else {
Expand All @@ -502,7 +509,7 @@ func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
providersSoFar: requestStatus.providersSoFar,
incoming: inProgressChan,
}:
case <-pqm.ctx.Done():
case <-pqm.closing:
}
}

Expand Down
72 changes: 35 additions & 37 deletions routing/providerquerymanager/providerquerymanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ func TestNormalSimultaneousFetch(t *testing.T) {
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
providerQueryManager := mustNotErr(New(fpd, fpn))
defer providerQueryManager.Close()
keys := random.Cids(2)

sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
sessionCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[1], 0)
Expand Down Expand Up @@ -111,11 +111,11 @@ func TestDedupingProviderRequests(t *testing.T) {
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
providerQueryManager := mustNotErr(New(fpd, fpn))
defer providerQueryManager.Close()
key := random.Cids(1)[0]

sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
sessionCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
Expand Down Expand Up @@ -151,12 +151,13 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) {
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
providerQueryManager := mustNotErr(New(fpd, fpn))
defer providerQueryManager.Close()

key := random.Cids(1)[0]

// first session will cancel before done
ctx := context.Background()
firstSessionCtx, firstCancel := context.WithTimeout(ctx, 3*time.Millisecond)
defer firstCancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key, 0)
Expand Down Expand Up @@ -195,14 +196,13 @@ func TestCancelManagerExitsGracefully(t *testing.T) {
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
managerCtx, managerCancel := context.WithTimeout(ctx, 5*time.Millisecond)
defer managerCancel()
providerQueryManager := mustNotErr(New(managerCtx, fpd, fpn))
providerQueryManager := mustNotErr(New(fpd, fpn))
defer providerQueryManager.Close()
time.AfterFunc(5*time.Millisecond, providerQueryManager.Close)

key := random.Cids(1)[0]

sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
sessionCtx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
Expand Down Expand Up @@ -232,12 +232,12 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) {
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
providerQueryManager := mustNotErr(New(fpd, fpn))
defer providerQueryManager.Close()

key := random.Cids(1)[0]

sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
sessionCtx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
Expand Down Expand Up @@ -266,13 +266,11 @@ func TestRateLimitingRequests(t *testing.T) {
peersFound: peers,
delay: 5 * time.Millisecond,
}
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxInProcessRequests(maxInProcessRequests)))
providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxInProcessRequests(maxInProcessRequests)))
defer providerQueryManager.Close()

keys := random.Cids(maxInProcessRequests + 1)
sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
sessionCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var requestChannels []<-chan peer.AddrInfo
for i := 0; i < maxInProcessRequests+1; i++ {
Expand Down Expand Up @@ -307,11 +305,11 @@ func TestUnlimitedRequests(t *testing.T) {
peersFound: peers,
delay: 5 * time.Millisecond,
}
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxInProcessRequests(0)))
providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxInProcessRequests(0)))
defer providerQueryManager.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
keys := random.Cids(inProcessRequests)
sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
Expand Down Expand Up @@ -346,11 +344,11 @@ func TestFindProviderTimeout(t *testing.T) {
peersFound: peers,
delay: 10 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxTimeout(2*time.Millisecond)))
providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxTimeout(2*time.Millisecond)))
defer providerQueryManager.Close()
keys := random.Cids(1)

sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
sessionCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
var firstPeersReceived []peer.AddrInfo
Expand All @@ -369,11 +367,11 @@ func TestFindProviderPreCanceled(t *testing.T) {
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxTimeout(100*time.Millisecond)))
providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxTimeout(100*time.Millisecond)))
defer providerQueryManager.Close()
keys := random.Cids(1)

sessionCtx, cancel := context.WithCancel(ctx)
sessionCtx, cancel := context.WithCancel(context.Background())
cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
if firstRequestChan == nil {
Expand All @@ -393,11 +391,11 @@ func TestCancelFindProvidersAfterCompletion(t *testing.T) {
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxTimeout(100*time.Millisecond)))
providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxTimeout(100*time.Millisecond)))
defer providerQueryManager.Close()
keys := random.Cids(1)

sessionCtx, cancel := context.WithCancel(ctx)
sessionCtx, cancel := context.WithCancel(context.Background())
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
<-firstRequestChan // wait for everything to start.
time.Sleep(10 * time.Millisecond) // wait for the incoming providers to stop.
Expand Down Expand Up @@ -425,11 +423,11 @@ func TestLimitedProviders(t *testing.T) {
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxProviders(max), WithMaxTimeout(100*time.Millisecond)))
providerQueryManager := mustNotErr(New(fpd, fpn, WithMaxProviders(max), WithMaxTimeout(100*time.Millisecond)))
defer providerQueryManager.Close()
keys := random.Cids(1)

providersChan := providerQueryManager.FindProvidersAsync(ctx, keys[0], 0)
providersChan := providerQueryManager.FindProvidersAsync(context.Background(), keys[0], 0)
total := 0
for range providersChan {
total++
Expand Down

0 comments on commit ef25808

Please sign in to comment.