3 changes: 3 additions & 0 deletions pkg/engine/internal/scheduler/wire/wire_local.go
@@ -15,6 +15,9 @@ var (
// LocalWorker is the address of the local worker when using the
// [Local] listener.
LocalWorker net.Addr = localAddr("worker")

// LocalWorker2 is another address of the local worker.
LocalWorker2 net.Addr = localAddr("worker2")
Review comment (Member) on lines +19 to +20:

Ideally, the package API shouldn't be leaking concepts used for tests.

Workers are allowed to dial themselves (worker connecting to worker), so you likely don't need to introduce another local address here.
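
A minimal sketch of what the var block could look like if this suggestion is taken, i.e. LocalWorker2 is dropped and tests reuse the existing LocalWorker address for both ends of a worker-to-worker connection (the block's other entries are elided here):

```go
var (
	// LocalWorker is the address of the local worker when using the
	// [Local] listener. A test that needs a worker-to-worker connection
	// can use this same address for both sides.
	LocalWorker net.Addr = localAddr("worker")
)
```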

)

type localAddr string
25 changes: 18 additions & 7 deletions pkg/engine/internal/worker/thread.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sync"
"time"

"github.com/go-kit/log"
@@ -42,11 +43,21 @@ type thread struct {
Logger log.Logger

Ready chan<- readyRequest

stopped chan struct{}
stopOnce sync.Once
}

func (t *thread) Stop() {
t.stopOnce.Do(func() {
close(t.stopped)
})
}

// Run starts the thread. Run will request and run tasks in a loop until the
// context is canceled.
func (t *thread) Run(ctx context.Context) error {
// thread is stopped. Run does not stop when a job fails; it logs the error
// and continues accepting other jobs.
func (t *thread) Run() {
Review comment (Member):

You can simplify this a bit by continuing to pass the context into Run, but not attaching it to readyRequest.Context. That way <-ctx.Done() acts the way t.stopped does now, and you don't need to add the additional Stop method.
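
A minimal sketch of that suggestion, assuming the thread struct, readyRequest, and Ready channel from this diff. The NextTask label, task execution, and error handling are elided, and readyResponse is a hypothetical name since the element type of respCh is not shown here:

```go
// Run requests and runs tasks in a loop until ctx is canceled. With ctx
// playing the role of t.stopped, the Stop method and stopped channel go away.
func (t *thread) Run(ctx context.Context) {
	for {
		level.Debug(t.Logger).Log("msg", "requesting task")

		// Hypothetical: the element type and buffering of respCh are not
		// visible in this diff.
		respCh := make(chan readyResponse, 1)

		// Task contexts stay bound to the scheduler rather than to this
		// thread, as in the current version of the change.
		req := readyRequest{
			Context:  context.Background(),
			Response: respCh,
		}

		// Send our request, or bail out if the worker is shutting down.
		select {
		case <-ctx.Done():
			return
		case t.Ready <- req:
		}

		// Wait for a task assignment.
		select {
		case <-ctx.Done():
			// Tasks written to respCh after this point are dropped, same
			// caveat as the current version.
			return

		case resp := <-respCh:
			if resp.Error != nil {
				// Error handling elided; unchanged from the diff.
				continue
			}
			// Running the assigned task is elided.
		}
	}
}
```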

NextTask:
for {
level.Debug(t.Logger).Log("msg", "requesting task")
@@ -59,20 +70,20 @@
// ensures that the context of tasks written to respCh are bound to the
// lifetime of the thread, but can also be canceled by the scheduler.
req := readyRequest{
Context: ctx,
Context: context.Background(),
Review comment (Member):

readyRequest.Context is always context.Background now, so I think we can remove the field from the struct now
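
If the field is removed, readyRequest might shrink to something like the sketch below; taskResponse is a hypothetical stand-in, since the channel's element type and direction are not visible in this diff:

```go
// Hypothetical sketch: readyRequest without the Context field, now that
// every caller passes context.Background(). taskResponse stands in for the
// real response type, which is not shown in this diff.
type readyRequest struct {
	Response chan<- taskResponse
}
```

Call sites would then construct the request as `readyRequest{Response: respCh}`.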

Response: respCh,
}

// Send our request.
select {
case <-ctx.Done():
return nil
case <-t.stopped:
return
case t.Ready <- req:
}

// Wait for a task assignment.
select {
case <-ctx.Done():
case <-t.stopped:
// TODO(rfratto): This will silently drop tasks written to respCh.
// But since Run only exits when the worker is exiting, this should
// be handled gracefully by the scheduler (it will detect the
@@ -81,7 +92,7 @@ NextTask:
// If, in the future, we dynamically change the number of threads,
// we'll want a mechanism to gracefully handle this so the writer to
// respCh knows that the task was dropped.
return nil
return

case resp := <-respCh:
if resp.Error != nil {
76 changes: 58 additions & 18 deletions pkg/engine/internal/worker/worker.go
@@ -162,57 +162,97 @@ func (w *Worker) Service() services.Service {

// run starts the worker, running until the provided context is canceled.
func (w *Worker) run(ctx context.Context) error {
g, ctx := errgroup.WithContext(ctx)
threadsGroup := &sync.WaitGroup{}
threads := make([]*thread, w.numThreads)

// Spin up worker threads.
for i := range w.numThreads {
t := &thread{
threads[i] = &thread{
BatchSize: w.config.BatchSize,
Logger: log.With(w.logger, "thread", i),
Bucket: w.config.Bucket,

Ready: w.readyCh,
Ready: w.readyCh,
stopped: make(chan struct{}),
}

g.Go(func() error { return t.Run(ctx) })
threadsGroup.Go(func() { threads[i].Run() })
}

g.Go(func() error { return w.runAcceptLoop(ctx) })
// Spin up the listener for peer connections
peerConnectionsCtx, peerConnectionsCancel := context.WithCancel(context.Background())
defer peerConnectionsCancel()
listenerCtx, listenerCancel := context.WithCancel(context.Background())
defer listenerCancel()
Review comment (Member) on lines +182 to +186:

As discussed offline: the most important part of a graceful shutdown is that we eventually stop accepting new tasks and wait for current tasks to finish.

We actually don't want to stop accepting new connections, since we may be running a task that depends on a yet-to-be-scheduled scan task, and we need to receive those results to compute our own results correctly.
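
One possible ordering that follows this advice, sketched against the cancel functions and groups introduced in this diff (only the tail of run is shown): stop handing out new tasks first, drain the running ones, and tear down the listener and peer connections only at the very end so in-flight tasks can still receive results from other workers.

```go
	// Wait for shutdown.
	<-ctx.Done()

	// 1. Stop requesting new tasks, but let running tasks finish.
	for _, t := range threads {
		t.Stop()
	}
	threadsGroup.Wait()

	// 2. Stop the scheduler loop once no new local tasks will be started.
	schedulerCancel()
	if err := schedulerGroup.Wait(); err != nil {
		return err
	}

	// 3. With no tasks left running, it is now safe to stop accepting new
	// connections and to close the remaining peer connections.
	listenerCancel()
	peerConnectionsCancel()

	return nil
```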


go func() {
w.runAcceptLoop(listenerCtx, peerConnectionsCtx)
}()

// Spin up the scheduler loop
schedulerCtx, schedulerCancel := context.WithCancel(context.Background())
defer schedulerCancel()

schedulerGroup := errgroup.Group{}
if w.config.SchedulerLookupAddress != "" {
disc, err := newSchedulerLookup(w.logger, w.config.SchedulerLookupAddress, w.config.SchedulerLookupInterval)
if err != nil {
return fmt.Errorf("creating scheduler lookup: %w", err)
}
g.Go(func() error {
return disc.Run(ctx, func(ctx context.Context, addr net.Addr) {
schedulerGroup.Go(func() error {
return disc.Run(schedulerCtx, func(ctx context.Context, addr net.Addr) {
_ = w.schedulerLoop(ctx, addr)
})
})
}

if w.config.SchedulerAddress != nil {
level.Info(w.logger).Log("msg", "directly connecting to scheduler", "scheduler_addr", w.config.SchedulerAddress)
g.Go(func() error { return w.schedulerLoop(ctx, w.config.SchedulerAddress) })
schedulerGroup.Go(func() error { return w.schedulerLoop(schedulerCtx, w.config.SchedulerAddress) })
}

// Wait for shutdown
<-ctx.Done()

// Stop accepting new connections from peers.
listenerCancel()

// Signal all worker threads to stop. They will stop asking for new tasks but continue processing their current jobs.
for _, t := range threads {
t.Stop()
}
// Wait for all worker threads to finish their current jobs.
threadsGroup.Wait()

return g.Wait()
// Stop scheduler loop
schedulerCancel()

// Wait for scheduler loop to finish
err := schedulerGroup.Wait()
if err != nil {
return err
}

// Close all peer connections
peerConnectionsCancel()

return nil
}

// runAcceptLoop handles incoming connections from peers. Incoming connections
// are exclusively used to receive task results from other workers, or between
// threads within this worker.
func (w *Worker) runAcceptLoop(ctx context.Context) error {
func (w *Worker) runAcceptLoop(listenerCtx, peerConnectionsCtx context.Context) {
for {
conn, err := w.listener.Accept(ctx)
if err != nil && ctx.Err() != nil {
return nil
conn, err := w.listener.Accept(listenerCtx)
if err != nil && listenerCtx.Err() != nil {
return
} else if err != nil {
level.Warn(w.logger).Log("msg", "failed to accept connection", "err", err)
continue
}

go w.handleConn(ctx, conn)
go w.handleConn(peerConnectionsCtx, conn)
}
}

@@ -365,13 +405,13 @@ func (w *Worker) handleSchedulerConn(ctx context.Context, logger log.Logger, con
return handleAssignment(peer, msg)

case wire.TaskCancelMessage:
return w.handleCancelMessage(ctx, msg)
return w.handleCancelMessage(msg)

case wire.StreamBindMessage:
return w.handleBindMessage(ctx, msg)

case wire.StreamStatusMessage:
return w.handleStreamStatusMessage(ctx, msg)
return w.handleStreamStatusMessage(msg)

default:
level.Warn(logger).Log("msg", "unsupported message type", "type", reflect.TypeOf(msg).String())
@@ -526,7 +566,7 @@ func (w *Worker) newJob(ctx context.Context, scheduler *wire.Peer, logger log.Lo
return job, nil
}

func (w *Worker) handleCancelMessage(_ context.Context, msg wire.TaskCancelMessage) error {
func (w *Worker) handleCancelMessage(msg wire.TaskCancelMessage) error {
w.resourcesMut.RLock()
job, found := w.jobs[msg.ID]
w.resourcesMut.RUnlock()
@@ -550,7 +590,7 @@ func (w *Worker) handleBindMessage(ctx context.Context, msg wire.StreamBindMessa
return sink.Bind(ctx, msg.Receiver)
}

func (w *Worker) handleStreamStatusMessage(_ context.Context, msg wire.StreamStatusMessage) error {
func (w *Worker) handleStreamStatusMessage(msg wire.StreamStatusMessage) error {
w.resourcesMut.RLock()
source, found := w.sources[msg.StreamID]
w.resourcesMut.RUnlock()