diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index e90e56af0f..bd46b13d0c 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -684,6 +684,10 @@ type GatewaySchedulingConfig struct { // 负载计算 LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` + // 负载均衡时每次窗口化读取的候选账号数量。 + CandidatePageSize int `mapstructure:"candidate_page_size"` + // 负载均衡时最多参与轮转采样的有序账号数量。 + CandidateScanLimit int `mapstructure:"candidate_scan_limit"` // 过期槽位清理周期(0 表示禁用) SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"` @@ -1411,6 +1415,8 @@ func setDefaults() { viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used") viper.SetDefault("gateway.scheduling.load_batch_enabled", true) + viper.SetDefault("gateway.scheduling.candidate_page_size", 256) + viper.SetDefault("gateway.scheduling.candidate_scan_limit", 8192) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) viper.SetDefault("gateway.scheduling.db_fallback_enabled", true) viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0) @@ -2185,6 +2191,15 @@ func (c *Config) Validate() error { if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 { return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive") } + if c.Gateway.Scheduling.CandidatePageSize <= 0 { + return fmt.Errorf("gateway.scheduling.candidate_page_size must be positive") + } + if c.Gateway.Scheduling.CandidateScanLimit <= 0 { + return fmt.Errorf("gateway.scheduling.candidate_scan_limit must be positive") + } + if c.Gateway.Scheduling.CandidateScanLimit < c.Gateway.Scheduling.CandidatePageSize { + return fmt.Errorf("gateway.scheduling.candidate_scan_limit must be >= candidate_page_size") + } if c.Gateway.Scheduling.SlotCleanupInterval < 0 { return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") } diff --git 
a/backend/internal/pkg/logger/logger.go b/backend/internal/pkg/logger/logger.go index 3fca706ec9..5c823a0b88 100644 --- a/backend/internal/pkg/logger/logger.go +++ b/backend/internal/pkg/logger/logger.go @@ -47,6 +47,7 @@ var ( sugar atomic.Pointer[zap.SugaredLogger] atomicLevel zap.AtomicLevel initOptions InitOptions + activeClosers []io.Closer currentSink atomic.Value // sinkState stdLogUndo func() bootstrapOnce sync.Once @@ -72,16 +73,18 @@ func Init(options InitOptions) error { func initLocked(options InitOptions) error { normalized := options.normalized() - zl, al, err := buildLogger(normalized) + zl, al, closers, err := buildLogger(normalized) if err != nil { return err } prev := global.Load() + prevClosers := activeClosers global.Store(zl) sugar.Store(zl.Sugar()) atomicLevel = al initOptions = normalized + activeClosers = closers bridgeSlogLocked() bridgeStdLogLocked() @@ -89,6 +92,7 @@ func initLocked(options InitOptions) error { if prev != nil { _ = prev.Sync() } + closeClosers(prevClosers) return nil } @@ -205,6 +209,19 @@ func Sync() { } } +func Close() { + mu.Lock() + defer mu.Unlock() + + if l := global.Load(); l != nil { + _ = l.Sync() + } + closeClosers(activeClosers) + activeClosers = nil + global.Store(nil) + sugar.Store(nil) +} + func bridgeStdLogLocked() { if stdLogUndo != nil { stdLogUndo() @@ -238,7 +255,7 @@ func bridgeSlogLocked() { slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog")))) } -func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) { +func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, []io.Closer, error) { level, _ := parseLevel(options.Level) atomic := zap.NewAtomicLevelAt(level) @@ -265,6 +282,7 @@ func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) { sinkCore := newSinkCore() cores := make([]zapcore.Core, 0, 3) + closers := make([]io.Closer, 0, 1) if options.Output.ToStdout { infoPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { @@ -278,7 
+296,7 @@ func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) { } if options.Output.ToFile { - fileCore, filePath, fileErr := buildFileCore(enc, atomic, options) + fileCore, fileCloser, filePath, fileErr := buildFileCore(enc, atomic, options) if fileErr != nil { _, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"日志文件输出初始化失败,降级为仅标准输出\" path=%s err=%v\n", time.Now().Format(time.RFC3339Nano), @@ -287,6 +305,9 @@ func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) { ) } else { cores = append(cores, fileCore) + if fileCloser != nil { + closers = append(closers, fileCloser) + } } } @@ -313,10 +334,10 @@ func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) { zap.String("service", options.ServiceName), zap.String("env", options.Environment), ) - return logger, atomic, nil + return logger, atomic, closers, nil } -func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, string, error) { +func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, io.Closer, string, error) { filePath := options.Output.FilePath if strings.TrimSpace(filePath) == "" { filePath = resolveLogFilePath("") @@ -324,7 +345,7 @@ func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOpti dir := filepath.Dir(filePath) if err := os.MkdirAll(dir, 0o755); err != nil { - return nil, filePath, err + return nil, nil, filePath, err } lj := &lumberjack.Logger{ Filename: filePath, @@ -334,7 +355,16 @@ func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOpti Compress: options.Rotation.Compress, LocalTime: options.Rotation.LocalTime, } - return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), filePath, nil + return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), lj, filePath, nil +} + +func closeClosers(closers []io.Closer) { + for _, closer := range closers { + if closer == nil { + continue + } + _ = 
closer.Close() + } } type sinkCore struct { diff --git a/backend/internal/pkg/logger/logger_test.go b/backend/internal/pkg/logger/logger_test.go index 74aae0613a..53044cde6b 100644 --- a/backend/internal/pkg/logger/logger_test.go +++ b/backend/internal/pkg/logger/logger_test.go @@ -5,10 +5,18 @@ import ( "io" "os" "path/filepath" + "runtime" "strings" "testing" ) +func syncTestLogger() { + if runtime.GOOS == "windows" { + return + } + Sync() +} + func TestInit_DualOutput(t *testing.T) { tmpDir := t.TempDir() logPath := filepath.Join(tmpDir, "logs", "sub2api.log") @@ -54,10 +62,11 @@ func TestInit_DualOutput(t *testing.T) { if err != nil { t.Fatalf("Init() error: %v", err) } + t.Cleanup(Close) L().Info("dual-output-info") L().Warn("dual-output-warn") - Sync() + syncTestLogger() _ = stdoutW.Close() _ = stderrW.Close() @@ -121,6 +130,7 @@ func TestInit_FileOutputFailureDowngrade(t *testing.T) { if err != nil { t.Fatalf("Init() should downgrade instead of failing, got: %v", err) } + t.Cleanup(Close) _ = stderrW.Close() stderrBytes, _ := io.ReadAll(stderrR) @@ -164,9 +174,10 @@ func TestInit_CallerShouldPointToCallsite(t *testing.T) { }); err != nil { t.Fatalf("Init() error: %v", err) } + t.Cleanup(Close) L().Info("caller-check") - Sync() + syncTestLogger() _ = stdoutW.Close() logBytes, _ := io.ReadAll(stdoutR) diff --git a/backend/internal/pkg/logger/options_test.go b/backend/internal/pkg/logger/options_test.go index 10d50d72c9..876316f8e2 100644 --- a/backend/internal/pkg/logger/options_test.go +++ b/backend/internal/pkg/logger/options_test.go @@ -95,7 +95,7 @@ func TestBuildFileCore_InvalidPathFallback(t *testing.T) { EncodeLevel: zapcore.CapitalLevelEncoder, } encoder := zapcore.NewJSONEncoder(encoderCfg) - _, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts) + _, _, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts) if err == nil { t.Fatalf("buildFileCore() expected error for invalid path") } diff --git 
a/backend/internal/pkg/logger/stdlog_bridge_test.go b/backend/internal/pkg/logger/stdlog_bridge_test.go index 4482a2ecd3..ef236a441b 100644 --- a/backend/internal/pkg/logger/stdlog_bridge_test.go +++ b/backend/internal/pkg/logger/stdlog_bridge_test.go @@ -73,11 +73,12 @@ func TestStdLogBridgeRoutesLevels(t *testing.T) { }); err != nil { t.Fatalf("Init() error: %v", err) } + t.Cleanup(Close) log.Printf("service started") log.Printf("Warning: queue full") log.Printf("Forward request failed: timeout") - Sync() + syncTestLogger() _ = stdoutW.Close() _ = stderrW.Close() @@ -135,11 +136,12 @@ func TestLegacyPrintfRoutesLevels(t *testing.T) { }); err != nil { t.Fatalf("Init() error: %v", err) } + t.Cleanup(Close) LegacyPrintf("service.test", "request started") LegacyPrintf("service.test", "Warning: queue full") LegacyPrintf("service.test", "forward failed: timeout") - Sync() + syncTestLogger() _ = stdoutW.Close() _ = stderrW.Close() diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 35b908de0b..a8bee15b31 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -504,7 +504,7 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{ status: service.StatusActive, - }) + }, 0, 0) if err != nil { return nil, err } @@ -811,12 +811,16 @@ func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupI return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{ status: service.StatusActive, schedulable: true, - }) + }, 0, 0) } func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return r.ListSchedulableByPlatformWindow(ctx, platform, 0, 0) +} + +func (r 
*accountRepository) ListSchedulableByPlatformWindow(ctx context.Context, platform string, offset, limit int) ([]service.Account, error) { now := time.Now() - accounts, err := r.client.Account.Query(). + query := r.client.Account.Query(). Where( dbaccount.PlatformEQ(platform), dbaccount.StatusEQ(service.StatusActive), @@ -826,31 +830,35 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ). - Order(dbent.Asc(dbaccount.FieldPriority)). - All(ctx) - if err != nil { - return nil, err - } - return r.accountsToService(ctx, accounts) + Order(dbent.Asc(dbaccount.FieldPriority)) + return r.queryAccountsWindow(ctx, query, offset, limit) } func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { // 单平台查询复用多平台逻辑,保持过滤条件与排序策略一致。 + return r.ListSchedulableByGroupIDAndPlatformWindow(ctx, groupID, platform, 0, 0) +} + +func (r *accountRepository) ListSchedulableByGroupIDAndPlatformWindow(ctx context.Context, groupID int64, platform string, offset, limit int) ([]service.Account, error) { return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{ status: service.StatusActive, schedulable: true, platforms: []string{platform}, - }) + }, offset, limit) } func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + return r.ListSchedulableByPlatformsWindow(ctx, platforms, 0, 0) +} + +func (r *accountRepository) ListSchedulableByPlatformsWindow(ctx context.Context, platforms []string, offset, limit int) ([]service.Account, error) { if len(platforms) == 0 { return nil, nil } // 仅返回可调度的活跃账号,并过滤处于过载/限流窗口的账号。 // 代理与分组信息统一在 accountsToService 中批量加载,避免 N+1 查询。 now := time.Now() - accounts, err := r.client.Account.Query(). 
+ query := r.client.Account.Query(). Where( dbaccount.PlatformIn(platforms...), dbaccount.StatusEQ(service.StatusActive), @@ -860,17 +868,17 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ). - Order(dbent.Asc(dbaccount.FieldPriority)). - All(ctx) - if err != nil { - return nil, err - } - return r.accountsToService(ctx, accounts) + Order(dbent.Asc(dbaccount.FieldPriority)) + return r.queryAccountsWindow(ctx, query, offset, limit) } func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return r.ListSchedulableUngroupedByPlatformWindow(ctx, platform, 0, 0) +} + +func (r *accountRepository) ListSchedulableUngroupedByPlatformWindow(ctx context.Context, platform string, offset, limit int) ([]service.Account, error) { now := time.Now() - accounts, err := r.client.Account.Query(). + query := r.client.Account.Query(). Where( dbaccount.PlatformEQ(platform), dbaccount.StatusEQ(service.StatusActive), @@ -881,20 +889,20 @@ func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Conte dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ). - Order(dbent.Asc(dbaccount.FieldPriority)). 
- All(ctx) - if err != nil { - return nil, err - } - return r.accountsToService(ctx, accounts) + Order(dbent.Asc(dbaccount.FieldPriority)) + return r.queryAccountsWindow(ctx, query, offset, limit) } func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + return r.ListSchedulableUngroupedByPlatformsWindow(ctx, platforms, 0, 0) +} + +func (r *accountRepository) ListSchedulableUngroupedByPlatformsWindow(ctx context.Context, platforms []string, offset, limit int) ([]service.Account, error) { if len(platforms) == 0 { return nil, nil } now := time.Now() - accounts, err := r.client.Account.Query(). + query := r.client.Account.Query(). Where( dbaccount.PlatformIn(platforms...), dbaccount.StatusEQ(service.StatusActive), @@ -905,12 +913,8 @@ func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Cont dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ). - Order(dbent.Asc(dbaccount.FieldPriority)). 
- All(ctx) - if err != nil { - return nil, err - } - return r.accountsToService(ctx, accounts) + Order(dbent.Asc(dbaccount.FieldPriority)) + return r.queryAccountsWindow(ctx, query, offset, limit) } func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { @@ -918,11 +922,18 @@ func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Con return nil, nil } // 复用按分组查询逻辑,保证分组优先级 + 账号优先级的排序与筛选一致。 + return r.ListSchedulableByGroupIDAndPlatformsWindow(ctx, groupID, platforms, 0, 0) +} + +func (r *accountRepository) ListSchedulableByGroupIDAndPlatformsWindow(ctx context.Context, groupID int64, platforms []string, offset, limit int) ([]service.Account, error) { + if len(platforms) == 0 { + return nil, nil + } return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{ status: service.StatusActive, schedulable: true, platforms: platforms, - }) + }, offset, limit) } func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { @@ -1363,7 +1374,7 @@ type accountGroupQueryOptions struct { platforms []string // 允许的多个平台,空切片表示不进行平台过滤 } -func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID int64, opts accountGroupQueryOptions) ([]service.Account, error) { +func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID int64, opts accountGroupQueryOptions, offset, limit int) ([]service.Account, error) { q := r.client.AccountGroup.Query(). Where(dbaccountgroup.GroupIDEQ(groupID)) @@ -1391,13 +1402,19 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in q = q.Where(dbaccountgroup.HasAccountWith(preds...)) } - groups, err := q. + q = q. Order( dbaccountgroup.ByPriority(), dbaccountgroup.ByAccountField(dbaccount.FieldPriority), ). - WithAccount(). 
- All(ctx) + WithAccount() + if offset > 0 { + q = q.Offset(offset) + } + if limit > 0 { + q = q.Limit(limit) + } + groups, err := q.All(ctx) if err != nil { return nil, err } @@ -1425,6 +1442,20 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in return r.accountsToService(ctx, accounts) } +func (r *accountRepository) queryAccountsWindow(ctx context.Context, query *dbent.AccountQuery, offset, limit int) ([]service.Account, error) { + if offset > 0 { + query = query.Offset(offset) + } + if limit > 0 { + query = query.Limit(limit) + } + accounts, err := query.All(ctx) + if err != nil { + return nil, err + } + return r.accountsToService(ctx, accounts) +} + func (r *accountRepository) accountsToService(ctx context.Context, accounts []*dbent.Account) ([]service.Account, error) { if len(accounts) == 0 { return []service.Account{}, nil diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go index 4f447e4fea..f2b0627871 100644 --- a/backend/internal/repository/scheduler_cache.go +++ b/backend/internal/repository/scheduler_cache.go @@ -87,6 +87,72 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul return accounts, true, nil } +func (c *schedulerCache) GetSnapshotWindow(ctx context.Context, bucket service.SchedulerBucket, offset, limit int) ([]*service.Account, bool, error) { + if limit <= 0 { + return c.GetSnapshot(ctx, bucket) + } + if offset < 0 { + offset = 0 + } + + readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket) + readyVal, err := c.rdb.Get(ctx, readyKey).Result() + if err == redis.Nil { + return nil, false, nil + } + if err != nil { + return nil, false, err + } + if readyVal != "1" { + return nil, false, nil + } + + activeKey := schedulerBucketKey(schedulerActivePrefix, bucket) + activeVal, err := c.rdb.Get(ctx, activeKey).Result() + if err == redis.Nil { + return nil, false, nil + } + if err != nil { + return nil, false, err + } + + 
snapshotKey := schedulerSnapshotKey(bucket, activeVal) + stop := int64(offset + limit - 1) + ids, err := c.rdb.ZRange(ctx, snapshotKey, int64(offset), stop).Result() + if err != nil { + return nil, false, err + } + if len(ids) == 0 { + if offset == 0 { + return nil, false, nil + } + return []*service.Account{}, true, nil + } + + keys := make([]string, 0, len(ids)) + for _, id := range ids { + keys = append(keys, schedulerAccountKey(id)) + } + values, err := c.rdb.MGet(ctx, keys...).Result() + if err != nil { + return nil, false, err + } + + accounts := make([]*service.Account, 0, len(values)) + for _, val := range values { + if val == nil { + return nil, false, nil + } + account, err := decodeCachedAccount(val) + if err != nil { + return nil, false, err + } + accounts = append(accounts, account) + } + + return accounts, true, nil +} + func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error { activeKey := schedulerBucketKey(schedulerActivePrefix, bucket) oldActive, _ := c.rdb.Get(ctx, activeKey).Result() diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 789888cb36..1ae7151819 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -562,41 +562,37 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( ctx context.Context, req OpenAIAccountScheduleRequest, ) (*AccountSelectionResult, int, int, float64, error) { - accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID) - if err != nil { - return nil, 0, 0, 0, err - } - if len(accounts) == 0 { - return nil, 0, 0, 0, errors.New("no available OpenAI accounts") - } - - filtered := make([]*Account, 0, len(accounts)) - loadReq := make([]AccountWithConcurrency, 0, len(accounts)) - for i := range accounts { - account := &accounts[i] + filtered, err := 
s.service.collectBoundedOpenAICandidates(ctx, req.GroupID, req.SessionHash, req.RequestedModel, func(account *Account) bool { if req.ExcludedIDs != nil { if _, excluded := req.ExcludedIDs[account.ID]; excluded { - continue + return false } } if !account.IsSchedulable() || !account.IsOpenAI() { - continue + return false } if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { - continue + return false } if !s.isAccountTransportCompatible(account, req.RequiredTransport) { - continue + return false } - filtered = append(filtered, account) + return true + }) + if err != nil { + return nil, 0, 0, 0, err + } + if len(filtered) == 0 { + return nil, 0, 0, 0, errors.New("no available OpenAI accounts") + } + + loadReq := make([]AccountWithConcurrency, 0, len(filtered)) + for _, account := range filtered { loadReq = append(loadReq, AccountWithConcurrency{ ID: account.ID, MaxConcurrency: account.EffectiveLoadFactor(), }) } - if len(filtered) == 0 { - return nil, 0, 0, 0, errors.New("no available OpenAI accounts") - } loadMap := map[int64]*AccountLoadInfo{} if s.service.concurrencyService != nil { diff --git a/backend/internal/service/openai_bounded_scheduling_test.go b/backend/internal/service/openai_bounded_scheduling_test.go new file mode 100644 index 0000000000..1d091844c2 --- /dev/null +++ b/backend/internal/service/openai_bounded_scheduling_test.go @@ -0,0 +1,246 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type openAIWindowCall struct { + offset int + limit int +} + +type windowedOpenAISnapshotCacheStub struct { + openAISnapshotCacheStub + fullCalls int + windowCalls []openAIWindowCall + miss bool +} + +func (s *windowedOpenAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { + s.fullCalls++ + if s.miss { + return nil, false, nil + } + return 
s.openAISnapshotCacheStub.GetSnapshot(ctx, bucket) +} + +func (s *windowedOpenAISnapshotCacheStub) GetSnapshotWindow(ctx context.Context, bucket SchedulerBucket, offset, limit int) ([]*Account, bool, error) { + s.windowCalls = append(s.windowCalls, openAIWindowCall{offset: offset, limit: limit}) + if s.miss { + return nil, false, nil + } + window := sliceAccountsWindow(derefAccounts(s.snapshotAccounts), offset, limit) + out := make([]*Account, 0, len(window)) + for i := range window { + account := window[i] + out = append(out, &account) + } + return out, true, nil +} + +type windowedOpenAIRepoStub struct { + stubOpenAIAccountRepo + fullCalls int + windowCalls []openAIWindowCall +} + +func (r *windowedOpenAIRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + r.fullCalls++ + return nil, errors.New("unexpected full schedulable list call") +} + +func (r *windowedOpenAIRepoStub) ListSchedulableByGroupIDAndPlatformWindow(ctx context.Context, groupID int64, platform string, offset, limit int) ([]Account, error) { + r.windowCalls = append(r.windowCalls, openAIWindowCall{offset: offset, limit: limit}) + var filtered []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + filtered = append(filtered, acc) + } + } + return sliceAccountsWindow(filtered, offset, limit), nil +} + +func (r *windowedOpenAIRepoStub) ListSchedulableByPlatformWindow(ctx context.Context, platform string, offset, limit int) ([]Account, error) { + return nil, errors.New("unexpected platform window call") +} + +func (r *windowedOpenAIRepoStub) ListSchedulableByPlatformsWindow(ctx context.Context, platforms []string, offset, limit int) ([]Account, error) { + return nil, errors.New("unexpected platforms window call") +} + +func (r *windowedOpenAIRepoStub) ListSchedulableByGroupIDAndPlatformsWindow(ctx context.Context, groupID int64, platforms []string, offset, limit int) ([]Account, error) { + return nil, 
errors.New("unexpected group platforms window call") +} + +func (r *windowedOpenAIRepoStub) ListSchedulableUngroupedByPlatformWindow(ctx context.Context, platform string, offset, limit int) ([]Account, error) { + return nil, errors.New("unexpected ungrouped platform window call") +} + +func (r *windowedOpenAIRepoStub) ListSchedulableUngroupedByPlatformsWindow(ctx context.Context, platforms []string, offset, limit int) ([]Account, error) { + return nil, errors.New("unexpected ungrouped platforms window call") +} + +func TestOpenAIGatewayService_SelectAccountWithLoadAwareness_UsesWindowedSnapshotCandidates(t *testing.T) { + groupID := int64(42) + accounts := []*Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 10}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 10}, + {ID: 3, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 10}, + {ID: 4, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 10}, + {ID: 5, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 10}, + {ID: 6, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 10}, + } + accountsByID := make(map[int64]*Account, len(accounts)) + loadMap := make(map[int64]*AccountLoadInfo, len(accounts)) + for _, account := range accounts { + cloned := *account + accountsByID[account.ID] = &cloned + loadMap[account.ID] = &AccountLoadInfo{AccountID: account.ID, LoadRate: 90} + } + + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = true + cfg.Gateway.Scheduling.CandidatePageSize = 2 + cfg.Gateway.Scheduling.CandidateScanLimit = 6 + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 3 
+ cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second + cfg.Gateway.Scheduling.FallbackWaitTimeout = 30 * time.Second + cfg.Gateway.Scheduling.FallbackMaxWaiting = 100 + + snapshotCache := &windowedOpenAISnapshotCacheStub{ + openAISnapshotCacheStub: openAISnapshotCacheStub{ + snapshotAccounts: accounts, + accountsByID: accountsByID, + }, + } + snapshotSvc := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{ + cache: &stubGatewayCache{}, + cfg: cfg, + schedulerSnapshot: snapshotSvc, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{loadMap: loadMap, skipDefaultLoad: true}), + } + + const sessionHash = "bounded-window-seed" + pageSize, _, startPage := svc.openAICandidateWindowPlan(&groupID, sessionHash, "") + startOffset := startPage * pageSize + expectedID := int64(startOffset + 2) + loadMap[expectedID] = &AccountLoadInfo{AccountID: expectedID, LoadRate: 0} + + selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "", nil) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, expectedID, selection.Account.ID) + require.Equal(t, 0, snapshotCache.fullCalls) + require.Equal(t, []openAIWindowCall{{offset: startOffset, limit: pageSize}}, snapshotCache.windowCalls) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWindowedSnapshotCandidates(t *testing.T) { + groupID := int64(77) + accounts := []*Account{ + {ID: 11, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 10}, + {ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 10}, + {ID: 13, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 10}, + {ID: 14, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: 
true, Concurrency: 1, Priority: 10}, + {ID: 15, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 10}, + {ID: 16, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 10}, + } + accountsByID := make(map[int64]*Account, len(accounts)) + loadMap := make(map[int64]*AccountLoadInfo, len(accounts)) + for _, account := range accounts { + cloned := *account + accountsByID[account.ID] = &cloned + loadMap[account.ID] = &AccountLoadInfo{AccountID: account.ID, LoadRate: 80} + } + + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = true + cfg.Gateway.Scheduling.CandidatePageSize = 2 + cfg.Gateway.Scheduling.CandidateScanLimit = 6 + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 3 + cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second + cfg.Gateway.Scheduling.FallbackWaitTimeout = 30 * time.Second + cfg.Gateway.Scheduling.FallbackMaxWaiting = 100 + cfg.Gateway.OpenAIWS.LBTopK = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + + snapshotCache := &windowedOpenAISnapshotCacheStub{ + openAISnapshotCacheStub: openAISnapshotCacheStub{ + snapshotAccounts: accounts, + accountsByID: accountsByID, + }, + } + snapshotSvc := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{ + cache: &stubGatewayCache{}, + cfg: cfg, + schedulerSnapshot: snapshotSvc, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{loadMap: loadMap, skipDefaultLoad: true}), + } + + const sessionHash = "bounded-scheduler-seed" + pageSize, _, startPage := svc.openAICandidateWindowPlan(&groupID, sessionHash, "") + startOffset := startPage * pageSize + expectedID := int64(startOffset + 12) 
+ loadMap[expectedID] = &AccountLoadInfo{AccountID: expectedID, LoadRate: 0} + + selection, decision, err := svc.SelectAccountWithScheduler( + context.Background(), + &groupID, + "", + sessionHash, + "", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, expectedID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, pageSize, decision.CandidateCount) + require.Equal(t, 0, snapshotCache.fullCalls) + require.Equal(t, []openAIWindowCall{{offset: startOffset, limit: pageSize}}, snapshotCache.windowCalls) +} + +func TestSchedulerSnapshotService_ListSchedulableAccountsWindow_UsesWindowedDBFallback(t *testing.T) { + groupID := int64(123) + repo := &windowedOpenAIRepoStub{ + stubOpenAIAccountRepo: stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2}, + {ID: 3, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 3}, + {ID: 4, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 4}, + }, + }, + } + snapshotCache := &windowedOpenAISnapshotCacheStub{miss: true} + cfg := &config.Config{} + cfg.Gateway.Scheduling.DbFallbackEnabled = true + service := &SchedulerSnapshotService{ + cache: snapshotCache, + accountRepo: repo, + cfg: cfg, + } + + accounts, useMixed, err := service.ListSchedulableAccountsWindow(context.Background(), &groupID, PlatformOpenAI, false, 2, 2) + require.NoError(t, err) + require.False(t, useMixed) + require.Len(t, accounts, 2) + require.Equal(t, []int64{3, 4}, []int64{accounts[0].ID, accounts[1].ID}) + 
require.Equal(t, 0, repo.fullCalls) + require.Equal(t, []openAIWindowCall{{offset: 2, limit: 2}}, repo.windowCalls) + require.Equal(t, []openAIWindowCall{{offset: 2, limit: 2}}, snapshotCache.windowCalls) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 4e96cf0597..c761f6f5c5 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1325,14 +1325,6 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex }, nil } - accounts, err := s.listSchedulableAccounts(ctx, groupID) - if err != nil { - return nil, err - } - if len(accounts) == 0 { - return nil, ErrNoAvailableAccounts - } - isExcluded := func(accountID int64) bool { if excludedIDs == nil { return false @@ -1381,24 +1373,24 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } // ============ Layer 2: Load-aware selection ============ - candidates := make([]*Account, 0, len(accounts)) - for i := range accounts { - acc := &accounts[i] + candidates, err := s.collectBoundedOpenAICandidates(ctx, groupID, sessionHash, requestedModel, func(acc *Account) bool { if isExcluded(acc.ID) { - continue + return false } // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); // re-check schedulability here so recently rate-limited/overloaded accounts // are not selected again before the bucket is rebuilt. 
if !acc.IsSchedulable() { - continue + return false } if requestedModel != "" && !acc.IsModelSupported(requestedModel) { - continue + return false } - candidates = append(candidates, acc) + return true + }) + if err != nil { + return nil, err } - if len(candidates) == 0 { return nil, ErrNoAvailableAccounts } @@ -1530,6 +1522,121 @@ func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, grou return accounts, nil } +func (s *OpenAIGatewayService) listSchedulableAccountsWindow(ctx context.Context, groupID *int64, offset, limit int) ([]Account, error) { + if limit <= 0 { + return s.listSchedulableAccounts(ctx, groupID) + } + if s.schedulerSnapshot != nil { + accounts, _, err := s.schedulerSnapshot.ListSchedulableAccountsWindow(ctx, groupID, PlatformOpenAI, false, offset, limit) + return accounts, err + } + if loader, ok := s.accountRepo.(schedulableAccountWindowLoader); ok { + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + return loader.ListSchedulableByPlatformWindow(ctx, PlatformOpenAI, offset, limit) + } + if groupID != nil { + return loader.ListSchedulableByGroupIDAndPlatformWindow(ctx, *groupID, PlatformOpenAI, offset, limit) + } + return loader.ListSchedulableUngroupedByPlatformWindow(ctx, PlatformOpenAI, offset, limit) + } + accounts, err := s.listSchedulableAccounts(ctx, groupID) + if err != nil { + return nil, err + } + return sliceAccountsWindow(accounts, offset, limit), nil +} + +func (s *OpenAIGatewayService) schedulingCandidatePageSize() int { + cfg := s.schedulingConfig() + if cfg.CandidatePageSize > 0 { + return cfg.CandidatePageSize + } + return 256 +} + +func (s *OpenAIGatewayService) schedulingCandidateScanLimit() int { + cfg := s.schedulingConfig() + if cfg.CandidateScanLimit > 0 { + return cfg.CandidateScanLimit + } + return 8192 +} + +func (s *OpenAIGatewayService) openAICandidateWindowPlan(groupID *int64, sessionHash string, requestedModel string) (pageSize int, pageCount int, startPage int) { + pageSize = 
s.schedulingCandidatePageSize() + scanLimit := s.schedulingCandidateScanLimit() + if pageSize <= 0 { + pageSize = 256 + } + if scanLimit < pageSize { + scanLimit = pageSize + } + pageCount = (scanLimit + pageSize - 1) / pageSize + if pageCount <= 1 { + return pageSize, 1, 0 + } + + seed := sessionHash + if seed == "" { + seed = fmt.Sprintf("%d:%s:%d", derefGroupID(groupID), requestedModel, time.Now().UnixNano()/int64(time.Millisecond)) + } else { + seed = fmt.Sprintf("%d:%s:%s", derefGroupID(groupID), requestedModel, sessionHash) + } + startPage = int(xxhash.Sum64String(seed) % uint64(pageCount)) + return pageSize, pageCount, startPage +} + +func (s *OpenAIGatewayService) collectBoundedOpenAICandidates( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + include func(*Account) bool, +) ([]*Account, error) { + pageSize, pageCount, startPage := s.openAICandidateWindowPlan(groupID, sessionHash, requestedModel) + candidates := make([]*Account, 0, pageSize) + visitedPages := make(map[int]struct{}, pageCount) + currentPage := startPage + wrappedToHead := currentPage == 0 + + for len(candidates) < pageSize { + if _, seen := visitedPages[currentPage]; seen { + break + } + visitedPages[currentPage] = struct{}{} + offset := currentPage * pageSize + window, err := s.listSchedulableAccountsWindow(ctx, groupID, offset, pageSize) + if err != nil { + return nil, err + } + if len(window) == 0 { + if currentPage != 0 && !wrappedToHead { + currentPage = 0 + wrappedToHead = true + continue + } + break + } + for i := range window { + account := &window[i] + if !include(account) { + continue + } + candidates = append(candidates, account) + if len(candidates) >= pageSize { + break + } + } + currentPage++ + if currentPage >= pageCount { + currentPage = 0 + wrappedToHead = true + } + } + + return candidates, nil +} + func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { 
if s.concurrencyService == nil { return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil @@ -1587,6 +1694,8 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig FallbackWaitTimeout: 30 * time.Second, FallbackMaxWaiting: 100, LoadBatchEnabled: true, + CandidatePageSize: 256, + CandidateScanLimit: 8192, SlotCleanupInterval: 30 * time.Second, } } diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 4c9540f115..98d09bdca3 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -131,6 +131,50 @@ func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context, return accounts, useMixed, nil } +func (s *SchedulerSnapshotService) ListSchedulableAccountsWindow(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool, offset, limit int) ([]Account, bool, error) { + useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform + mode := s.resolveMode(platform, hasForcePlatform) + bucket := s.bucketFor(groupID, platform, mode) + + if limit <= 0 { + return s.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + } + if offset < 0 { + offset = 0 + } + + if s.cache != nil { + if windowCache, ok := s.cache.(schedulerWindowCache); ok { + cached, hit, err := windowCache.GetSnapshotWindow(ctx, bucket, offset, limit) + if err != nil { + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] cache window read failed: bucket=%s offset=%d limit=%d err=%v", bucket.String(), offset, limit, err) + } else if hit { + return derefAccounts(cached), useMixed, nil + } + } else { + cached, hit, err := s.cache.GetSnapshot(ctx, bucket) + if err != nil { + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err) + } else if hit { + return 
sliceAccountsWindow(derefAccounts(cached), offset, limit), useMixed, nil + } + } + } + + if err := s.guardFallback(ctx); err != nil { + return nil, useMixed, err + } + + fallbackCtx, cancel := s.withFallbackTimeout(ctx) + defer cancel() + + accounts, err := s.loadAccountsWindowFromDB(fallbackCtx, bucket, useMixed, offset, limit) + if err != nil { + return nil, useMixed, err + } + return accounts, useMixed, nil +} + func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int64) (*Account, error) { if accountID <= 0 { return nil, nil @@ -591,6 +635,10 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc } func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucket SchedulerBucket, useMixed bool) ([]Account, error) { + return s.loadAccountsWindowFromDB(ctx, bucket, useMixed, 0, 0) +} + +func (s *SchedulerSnapshotService) loadAccountsWindowFromDB(ctx context.Context, bucket SchedulerBucket, useMixed bool, offset, limit int) ([]Account, error) { if s.accountRepo == nil { return nil, ErrSchedulerCacheNotReady } @@ -603,7 +651,15 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke platforms := []string{bucket.Platform, PlatformAntigravity} var accounts []Account var err error - if groupID > 0 { + if loader, ok := s.accountRepo.(schedulableAccountWindowLoader); ok { + if groupID > 0 { + accounts, err = loader.ListSchedulableByGroupIDAndPlatformsWindow(ctx, groupID, platforms, offset, limit) + } else if s.isRunModeSimple() { + accounts, err = loader.ListSchedulableByPlatformsWindow(ctx, platforms, offset, limit) + } else { + accounts, err = loader.ListSchedulableUngroupedByPlatformsWindow(ctx, platforms, offset, limit) + } + } else if groupID > 0 { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms) } else if s.isRunModeSimple() { accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) @@ -613,6 +669,9 @@ func (s 
*SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke if err != nil { return nil, err } + if _, ok := s.accountRepo.(schedulableAccountWindowLoader); !ok { + accounts = sliceAccountsWindow(accounts, offset, limit) + } filtered := make([]Account, 0, len(accounts)) for _, acc := range accounts { if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { @@ -623,13 +682,25 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke return filtered, nil } + if loader, ok := s.accountRepo.(schedulableAccountWindowLoader); ok { + if groupID > 0 { + return loader.ListSchedulableByGroupIDAndPlatformWindow(ctx, groupID, bucket.Platform, offset, limit) + } + if s.isRunModeSimple() { + return loader.ListSchedulableByPlatformWindow(ctx, bucket.Platform, offset, limit) + } + return loader.ListSchedulableUngroupedByPlatformWindow(ctx, bucket.Platform, offset, limit) + } if groupID > 0 { - return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform) + accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform) + return sliceAccountsWindow(accounts, offset, limit), err } if s.isRunModeSimple() { - return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) + accounts, err := s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) + return sliceAccountsWindow(accounts, offset, limit), err } - return s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, bucket.Platform) + accounts, err := s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, bucket.Platform) + return sliceAccountsWindow(accounts, offset, limit), err } func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket { diff --git a/backend/internal/service/scheduler_windowing.go b/backend/internal/service/scheduler_windowing.go new file mode 100644 index 0000000000..e17413eb13 --- /dev/null +++ 
b/backend/internal/service/scheduler_windowing.go @@ -0,0 +1,40 @@ +package service + +import "context" + +type schedulerWindowCache interface { + GetSnapshotWindow(ctx context.Context, bucket SchedulerBucket, offset, limit int) ([]*Account, bool, error) +} + +type schedulableAccountWindowLoader interface { + ListSchedulableByPlatformWindow(ctx context.Context, platform string, offset, limit int) ([]Account, error) + ListSchedulableByGroupIDAndPlatformWindow(ctx context.Context, groupID int64, platform string, offset, limit int) ([]Account, error) + ListSchedulableByPlatformsWindow(ctx context.Context, platforms []string, offset, limit int) ([]Account, error) + ListSchedulableByGroupIDAndPlatformsWindow(ctx context.Context, groupID int64, platforms []string, offset, limit int) ([]Account, error) + ListSchedulableUngroupedByPlatformWindow(ctx context.Context, platform string, offset, limit int) ([]Account, error) + ListSchedulableUngroupedByPlatformsWindow(ctx context.Context, platforms []string, offset, limit int) ([]Account, error) +} + +func sliceAccountsWindow(accounts []Account, offset, limit int) []Account { + if len(accounts) == 0 { + return []Account{} + } + if offset < 0 { + offset = 0 + } + if offset >= len(accounts) { + return []Account{} + } + if limit <= 0 { + out := make([]Account, len(accounts)-offset) + copy(out, accounts[offset:]) + return out + } + end := offset + limit + if end > len(accounts) { + end = len(accounts) + } + out := make([]Account, end-offset) + copy(out, accounts[offset:end]) + return out +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 2058ced170..1bac4c7857 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -347,6 +347,12 @@ gateway: # Enable batch load calculation for scheduling # 启用调度批量负载计算 load_batch_enabled: true + # Load-aware scheduling page size when sampling large account pools + # 大账号池负载调度时每次窗口化读取的候选数量 + candidate_page_size: 256 + # Maximum ordered accounts considered 
during rotating candidate sampling (rounded up to a whole number of candidate_page_size pages) + # 轮转候选采样时最多参与扫描的有序账号数量(向上取整为整页) + candidate_scan_limit: 8192 # Slot cleanup interval (duration) # 并发槽位清理周期(时间段) slot_cleanup_interval: 30s