feat: scheduling extensions #1131

Merged (12 commits, Jan 8, 2025)
5 changes: 3 additions & 2 deletions pkg/repository/prisma/scheduler_lease.go
@@ -146,8 +146,9 @@ func (d *leaseRepository) ListActiveWorkers(ctx context.Context, tenantId pgtype
for _, worker := range activeWorkers {
wId := sqlchelpers.UUIDToStr(worker.ID)
res = append(res, &repository.ListActiveWorkersResult{
- ID: worker.ID,
- Labels: workerIdsToLabels[wId],
+ ID: worker.ID,
+ MaxRuns: int(worker.MaxRuns),
+ Labels: workerIdsToLabels[wId],
})
}

5 changes: 3 additions & 2 deletions pkg/repository/scheduler.go
@@ -16,8 +16,9 @@ type SchedulerRepository interface {
}

type ListActiveWorkersResult struct {
- ID pgtype.UUID
- Labels []*dbsqlc.ListManyWorkerLabelsRow
+ ID pgtype.UUID
+ MaxRuns int
+ Labels []*dbsqlc.ListManyWorkerLabelsRow
}

type LeaseRepository interface {
86 changes: 86 additions & 0 deletions pkg/scheduling/v2/extension.go
@@ -0,0 +1,86 @@
package v2

import (
"sync"

"golang.org/x/sync/errgroup"

"github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc"
)

type PostScheduleInput struct {
Workers map[string]*WorkerCp

Slots []*SlotCp

Unassigned []*dbsqlc.QueueItem

ActionsToSlots map[string][]*SlotCp
}

type WorkerCp struct {
WorkerId string
MaxRuns int
Labels []*dbsqlc.ListManyWorkerLabelsRow
}

type SlotCp struct {
WorkerId string
Used bool
}

type SchedulerExtension interface {
SetTenants(tenants []*dbsqlc.Tenant)
PostSchedule(tenantId string, input *PostScheduleInput)
Cleanup() error
}

type Extensions struct {
mu sync.RWMutex
exts []SchedulerExtension
}

func (e *Extensions) Add(ext SchedulerExtension) {
e.mu.Lock()
defer e.mu.Unlock()

if e.exts == nil {
e.exts = make([]SchedulerExtension, 0)
}

e.exts = append(e.exts, ext)
}

func (e *Extensions) PostSchedule(tenantId string, input *PostScheduleInput) {
e.mu.RLock()
defer e.mu.RUnlock()

for _, ext := range e.exts {
f := ext.PostSchedule
go f(tenantId, input)
}
}

func (e *Extensions) Cleanup() error {
e.mu.RLock()
defer e.mu.RUnlock()

eg := errgroup.Group{}

for _, ext := range e.exts {
f := ext.Cleanup
eg.Go(f)
}

return eg.Wait()
}

func (e *Extensions) SetTenants(tenants []*dbsqlc.Tenant) {
e.mu.RLock()
defer e.mu.RUnlock()

for _, ext := range e.exts {
f := ext.SetTenants
go f(tenants)
}
}
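
Reviewer note: a minimal sketch of what an implementation of the new SchedulerExtension interface could look like, using only the types added in this file. The slotLoggingExtension name and its zerolog field are hypothetical and not part of this PR; it is shown in package v2 for brevity, but an external extension would import these types from pkg/scheduling/v2.

package v2

import (
	"github.com/rs/zerolog"

	"github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc"
)

// slotLoggingExtension is a hypothetical extension that logs a summary of each
// scheduling pass. It exists only to illustrate the SchedulerExtension contract.
type slotLoggingExtension struct {
	l *zerolog.Logger
}

func (e *slotLoggingExtension) SetTenants(tenants []*dbsqlc.Tenant) {
	e.l.Debug().Int("tenants", len(tenants)).Msg("tenant set updated")
}

func (e *slotLoggingExtension) PostSchedule(tenantId string, input *PostScheduleInput) {
	used := 0

	for _, s := range input.Slots {
		if s.Used {
			used++
		}
	}

	e.l.Debug().
		Str("tenant_id", tenantId).
		Int("workers", len(input.Workers)).
		Int("slots", len(input.Slots)).
		Int("used_slots", used).
		Int("unassigned", len(input.Unassigned)).
		Msg("post-schedule snapshot")
}

func (e *slotLoggingExtension) Cleanup() error {
	return nil
}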
13 changes: 12 additions & 1 deletion pkg/scheduling/v2/pool.go
@@ -21,6 +21,8 @@ type sharedConfig struct {

// SchedulingPool is responsible for managing a pool of tenantManagers.
type SchedulingPool struct {
Extensions *Extensions

tenants sync.Map
setMu mutex

@@ -33,6 +35,7 @@ func NewSchedulingPool(repo repository.SchedulerRepository, l *zerolog.Logger, s
resultsCh := make(chan *QueueResults, 1000)

s := &SchedulingPool{
Extensions: &Extensions{},
cf: &sharedConfig{
repo: repo,
l: l,
@@ -62,6 +65,12 @@ func (p *SchedulingPool) cleanup() {
})

p.cleanupTenants(toCleanup)

err := p.Extensions.Cleanup()

if err != nil {
p.cf.l.Error().Err(err).Msg("failed to cleanup extensions")
}
}

func (p *SchedulingPool) SetTenants(tenants []*dbsqlc.Tenant) {
@@ -103,6 +112,8 @@ func (p *SchedulingPool) SetTenants(tenants []*dbsqlc.Tenant) {
// any cleaned up tenants in the map
p.cleanupTenants(toCleanup)
}()

go p.Extensions.SetTenants(tenants)
}

func (p *SchedulingPool) cleanupTenants(toCleanup []*tenantManager) {
@@ -148,7 +159,7 @@ func (p *SchedulingPool) getTenantManager(tenantId string, storeIfNotFound bool)

if !ok {
if storeIfNotFound {
- tm = newTenantManager(p.cf, tenantId, p.resultsCh)
+ tm = newTenantManager(p.cf, tenantId, p.resultsCh, p.Extensions)
p.tenants.Store(tenantId, tm)
} else {
return nil
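
A short usage sketch of the wiring above. Here pool, logger, and tenants are placeholders for the caller's *SchedulingPool, zerolog.Logger, and tenant list, and slotLoggingExtension is the hypothetical extension from the earlier sketch. Extensions are registered once on the pool; SetTenants and PostSchedule are then fanned out on separate goroutines (fire and forget), while Cleanup errors are awaited via an errgroup and logged during pool cleanup.

// Register the extension before scheduling starts. Extensions.Add is safe to call
// concurrently because Extensions guards its slice with a sync.RWMutex.
pool.Extensions.Add(&slotLoggingExtension{l: &logger})

// Subsequent tenant updates are also broadcast to every registered extension
// in the background.
pool.SetTenants(tenants)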
95 changes: 92 additions & 3 deletions pkg/scheduling/v2/scheduler.go
@@ -39,10 +39,11 @@ type Scheduler struct {
unackedSlots map[int]*slot
unackedMu mutex

- rl *rateLimiter
+ rl *rateLimiter
+ exts *Extensions
}

- func newScheduler(cf *sharedConfig, tenantId pgtype.UUID, rl *rateLimiter) *Scheduler {
+ func newScheduler(cf *sharedConfig, tenantId pgtype.UUID, rl *rateLimiter, exts *Extensions) *Scheduler {
l := cf.l.With().Str("tenant_id", sqlchelpers.UUIDToStr(tenantId)).Logger()

return &Scheduler{
@@ -57,6 +58,7 @@ func newScheduler(cf *sharedConfig, tenantId pgtype.UUID, rl *rateLimiter) *Sche
workersMu: newMu(cf.l),
assignedCountMu: newMu(cf.l),
unackedMu: newMu(cf.l),
exts: exts,
}
}

@@ -665,6 +667,9 @@ func (s *Scheduler) tryAssign(
wg := sync.WaitGroup{}
startTotal := time.Now()

extensionResults := make([]*assignResults, 0)
extensionResultsMu := sync.Mutex{}

// process each action id in parallel
for actionId, qis := range actionIdToQueueItems {
wg.Add(1)
@@ -733,12 +738,18 @@ func (s *Scheduler) tryAssign(
s.l.Warn().Dur("duration", sinceStart).Msgf("processing batch of %d queue items took longer than 100ms", len(batchQis))
}

- resultsCh <- &assignResults{
+ r := &assignResults{
assigned: batchAssigned,
rateLimited: batchRateLimited,
unassigned: batchUnassigned,
}

extensionResultsMu.Lock()
extensionResults = append(extensionResults, r)
extensionResultsMu.Unlock()

resultsCh <- r

return nil
})

@@ -752,6 +763,10 @@ func (s *Scheduler) tryAssign(
span.End()
close(resultsCh)

extInput := s.getExtensionInput(extensionResults)

s.exts.PostSchedule(sqlchelpers.UUIDToStr(s.tenantId), extInput)

if sinceStart := time.Since(startTotal); sinceStart > 100*time.Millisecond {
s.l.Warn().Dur("duration", sinceStart).Msgf("assigning queue items took longer than 100ms")
}
@@ -760,6 +775,80 @@ func (s *Scheduler) tryAssign(
return resultsCh
}

func (s *Scheduler) getExtensionInput(results []*assignResults) *PostScheduleInput {
unassigned := make([]*dbsqlc.QueueItem, 0)

for _, res := range results {
unassigned = append(unassigned, res.unassigned...)
}

workers := s.getWorkers()

res := &PostScheduleInput{
Workers: make(map[string]*WorkerCp),
Slots: make([]*SlotCp, 0),
Unassigned: unassigned,
}

for workerId, worker := range workers {
res.Workers[workerId] = &WorkerCp{
WorkerId: workerId,
MaxRuns: worker.MaxRuns,
Labels: worker.Labels,
}
}

// NOTE: these locks are important because we must acquire locks in the same order as the replenish and tryAssignBatch
// functions. we always acquire actionsMu first and then the specific action's lock.
actionKeys := make([]string, 0, len(s.actions))

s.actionsMu.RLock()

for actionId := range s.actions {
actionKeys = append(actionKeys, actionId)
}

s.actionsMu.RUnlock()

uniqueSlots := make(map[*slot]*SlotCp)
actionsToSlots := make(map[string][]*SlotCp)

for _, actionId := range actionKeys {
s.actionsMu.RLock()
action, ok := s.actions[actionId]
s.actionsMu.RUnlock()

if !ok || action == nil {
continue
}

action.mu.RLock()
actionsToSlots[action.actionId] = make([]*SlotCp, 0, len(action.slots))

for _, slot := range action.slots {
if _, ok := uniqueSlots[slot]; ok {
continue
}

uniqueSlots[slot] = &SlotCp{
WorkerId: slot.getWorkerId(),
Used: slot.used,
}

actionsToSlots[action.actionId] = append(actionsToSlots[action.actionId], uniqueSlots[slot])
}
action.mu.RUnlock()
}

for _, slot := range uniqueSlots {
res.Slots = append(res.Slots, slot)
}

res.ActionsToSlots = actionsToSlots

return res
}

func isTimedOut(qi *dbsqlc.QueueItem) bool {
// if the current time is after the scheduleTimeoutAt, then mark this as timed out
now := time.Now().UTC()
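
To make the shape of the snapshot built by getExtensionInput concrete, here is a sketch (not part of this PR) of a helper an extension might use. Because slots are deduplicated through the uniqueSlots map before being copied into the input, iterating input.Slots counts each underlying slot exactly once; freeSlotsPerWorker is a hypothetical name.

// freeSlotsPerWorker derives per-worker free slot counts from the copied
// SlotCp values in a PostScheduleInput.
func freeSlotsPerWorker(input *PostScheduleInput) map[string]int {
	free := make(map[string]int, len(input.Workers))

	for _, s := range input.Slots {
		if !s.Used {
			free[s.WorkerId]++
		}
	}

	return free
}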
4 changes: 2 additions & 2 deletions pkg/scheduling/v2/tenant_manager.go
@@ -34,11 +34,11 @@ type tenantManager struct {
cleanup func()
}

- func newTenantManager(cf *sharedConfig, tenantId string, resultsCh chan *QueueResults) *tenantManager {
+ func newTenantManager(cf *sharedConfig, tenantId string, resultsCh chan *QueueResults, exts *Extensions) *tenantManager {
tenantIdUUID := sqlchelpers.UUIDFromStr(tenantId)

rl := newRateLimiter(cf, tenantIdUUID)
- s := newScheduler(cf, tenantIdUUID, rl)
+ s := newScheduler(cf, tenantIdUUID, rl, exts)
leaseManager, workersCh, queuesCh := newLeaseManager(cf, tenantIdUUID)

t := &tenantManager{