diff --git a/components/backend/cmd/sync_flags.go b/components/backend/cmd/sync_flags.go index e96e83edf..6d453d196 100644 --- a/components/backend/cmd/sync_flags.go +++ b/components/backend/cmd/sync_flags.go @@ -48,11 +48,17 @@ type FlagsConfig struct { } // FlagsFromManifest converts a model manifest into FlagSpecs. -// Skips the default model and unavailable models. +// Skips default models (global and per-provider) and unavailable models. func FlagsFromManifest(manifest *types.ModelManifest) []FlagSpec { + // Build set of all default model IDs (global + per-provider) + defaults := map[string]bool{manifest.DefaultModel: true} + for _, id := range manifest.ProviderDefaults { + defaults[id] = true + } + var specs []FlagSpec for _, model := range manifest.Models { - if model.ID == manifest.DefaultModel { + if defaults[model.ID] { continue } if !model.Available { diff --git a/components/backend/cmd/sync_flags_test.go b/components/backend/cmd/sync_flags_test.go index fa92238d1..2d2ddb61c 100644 --- a/components/backend/cmd/sync_flags_test.go +++ b/components/backend/cmd/sync_flags_test.go @@ -66,22 +66,44 @@ func TestParseManifestPath(t *testing.T) { func TestFlagsFromManifest_SkipsDefaultAndUnavailable(t *testing.T) { manifest := &types.ModelManifest{ DefaultModel: "claude-sonnet-4-5", + ProviderDefaults: map[string]string{ + "anthropic": "claude-sonnet-4-5", + "google": "gemini-2.5-flash", + }, Models: []types.ModelEntry{ - {ID: "claude-sonnet-4-5", Label: "Sonnet 4.5", Available: true}, - {ID: "claude-opus-4-6", Label: "Opus 4.6", Available: true}, - {ID: "claude-opus-4-1", Label: "Opus 4.1", Available: false}, + {ID: "claude-sonnet-4-5", Label: "Sonnet 4.5", Provider: "anthropic", Available: true}, + {ID: "claude-opus-4-6", Label: "Opus 4.6", Provider: "anthropic", Available: true}, + {ID: "claude-opus-4-1", Label: "Opus 4.1", Provider: "anthropic", Available: false}, + {ID: "gemini-2.5-flash", Label: "Gemini 2.5 Flash", Provider: "google", Available: true}, + {ID: "gemini-2.5-pro", Label: "Gemini 2.5 Pro", Provider: "google", Available: true}, }, } flags := FlagsFromManifest(manifest) - if len(flags) != 1 { - t.Fatalf("expected 1 flag, got %d: %v", len(flags), flags) + + // Should skip: claude-sonnet-4-5 (global default + anthropic default), + // gemini-2.5-flash (google default), + // claude-opus-4-1 (unavailable) + // Should include: claude-opus-4-6, gemini-2.5-pro + if len(flags) != 2 { + t.Fatalf("expected 2 flags, got %d: %v", len(flags), flags) + } + + names := map[string]bool{} + for _, f := range flags { + names[f.Name] = true + } + if !names["model.claude-opus-4-6.enabled"] { + t.Error("expected model.claude-opus-4-6.enabled") + } + if !names["model.gemini-2.5-pro.enabled"] { + t.Error("expected model.gemini-2.5-pro.enabled") } - if flags[0].Name != "model.claude-opus-4-6.enabled" { - t.Errorf("expected model.claude-opus-4-6.enabled, got %s", flags[0].Name) + if names["model.claude-sonnet-4-5.enabled"] { + t.Error("global default should be skipped") } - if len(flags[0].Tags) != 1 || flags[0].Tags[0].Type != "scope" || flags[0].Tags[0].Value != "workspace" { - t.Errorf("expected scope:workspace tag, got %v", flags[0].Tags) + if names["model.gemini-2.5-flash.enabled"] { + t.Error("provider default should be skipped") } } diff --git a/components/backend/handlers/models.go b/components/backend/handlers/models.go index d6bb15ab4..a8d40e24c 100644 --- a/components/backend/handlers/models.go +++ b/components/backend/handlers/models.go @@ -54,14 +54,15 @@ func ListModelsForProject(c *gin.Context) { ctx := c.Request.Context() namespace := sanitizeParam(c.Param("projectName")) + providerFilter := sanitizeParam(c.Query("provider")) manifest, err := LoadManifest(ManifestPath()) if err != nil { - log.Printf("WARNING: failed to load model manifest: %v", err) + log.Printf("WARNING: failed to load model manifest from disk: %v", err) manifest = cachedManifest.Load() if manifest == nil { - log.Printf("WARNING: no cached manifest available, using hardcoded defaults") - c.JSON(http.StatusOK, defaultModelsResponse()) + log.Printf("ERROR: no model manifest available (file unreadable, no cache)") + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Model manifest unavailable"}) return } } else { @@ -76,13 +77,27 @@ func ListModelsForProject(c *gin.Context) { // Continue without overrides } - var models []types.Model + // Resolve which model ID is the "default" for this request. + // When filtering by provider, use the provider-specific default. + effectiveDefault := manifest.DefaultModel + if providerFilter != "" { + if pd, ok := manifest.ProviderDefaults[providerFilter]; ok { + effectiveDefault = pd + } + } + + models := make([]types.Model, 0) for _, entry := range manifest.Models { if !entry.Available { continue } - isDefault := entry.ID == manifest.DefaultModel + // Filter by provider if specified + if providerFilter != "" && entry.Provider != providerFilter { + continue + } + + isDefault := entry.ID == effectiveDefault flagName := fmt.Sprintf("model.%s.enabled", entry.ID) // Default model is always included @@ -103,15 +118,15 @@ func ListModelsForProject(c *gin.Context) { } } + responseDefault := effectiveDefault if len(models) == 0 { - log.Printf("WARNING: no models passed filtering, using defaults") - c.JSON(http.StatusOK, defaultModelsResponse()) - return + log.Printf("WARNING: no models passed filtering for provider=%q in namespace %s", providerFilter, namespace) + responseDefault = "" } c.JSON(http.StatusOK, types.ListModelsResponse{ Models: models, - DefaultModel: manifest.DefaultModel, + DefaultModel: responseDefault, }) } @@ -143,94 +158,68 @@ func LoadManifest(path string) (*types.ModelManifest, error) { return &manifest, nil } -// isModelAvailable checks if a model is available for session creation. -// -// Validation strategy: -// 1. Check the agent registry — if the model is declared in the selected -// runner's model list, it's valid. This is the primary check for all runners. -// 2. For models also in the models.json manifest (Claude models), additionally -// check feature-flag gating and workspace overrides. -// 3. If the model is not found in either source, reject it. -func isModelAvailable(ctx context.Context, k8sClient kubernetes.Interface, modelID, runnerTypeID, namespace string) bool { +// isModelAvailable checks if a model is available for session creation in the +// given workspace namespace. All models (Claude and Gemini) are validated +// against models.json. Returns true if the model exists, is available, and +// is enabled (checking workspace overrides first, then Unleash). +// When requiredProvider is non-empty, the model's provider must match +// (prevents using a Gemini model with a Claude runner, for example). +// The default model always returns true. Fails open when no manifest has +// ever been loaded (cold start). +func isModelAvailable(ctx context.Context, k8sClient kubernetes.Interface, modelID, requiredProvider, namespace string) bool { if modelID == "" { return true // Empty model will use default } - // 1. Check agent registry — runner-specific model validation - rt, err := GetRuntime(runnerTypeID) - if err == nil && len(rt.Models) > 0 { - found := false - for _, m := range rt.Models { - if m.Value == modelID { - found = true - break - } - } - if !found { - log.Printf("Model %q not in runner %q model list, rejecting", modelID, runnerTypeID) - return false - } - // Model is in the runner's list — now check if it also needs - // feature-flag gating via the manifest (applies to Claude models). - } - - // 2. Check models.json manifest for feature-flag gating (if applicable) manifest, err := LoadManifest(ManifestPath()) if err != nil { - log.Printf("WARNING: failed to load model manifest: %v", err) + log.Printf("WARNING: failed to load model manifest for validation: %v", err) manifest = cachedManifest.Load() + if manifest == nil { + // When we know the runner's provider, reject unknown models rather + // than allowing a cross-provider mismatch through to the runner. + // Fail-open only when both manifest and registry are unavailable + // (requiredProvider == "") to avoid blocking cold starts. + if requiredProvider != "" { + log.Printf("WARNING: no manifest available, rejecting model %q (provider=%q)", modelID, requiredProvider) + return false + } + log.Printf("WARNING: no manifest or registry available, allowing model %q", modelID) + return true + } } else { cachedManifest.Store(manifest) } - if manifest != nil { - // Default model is always available - if modelID == manifest.DefaultModel { - return true - } - for _, entry := range manifest.Models { - if entry.ID == modelID { - if !entry.Available { - return false - } - flagName := fmt.Sprintf("model.%s.enabled", entry.ID) - overrides, oErr := getWorkspaceOverrides(ctx, k8sClient, namespace) - if oErr != nil { - log.Printf("WARNING: failed to read workspace overrides for %s: %v", namespace, oErr) + for _, entry := range manifest.Models { + if entry.ID == modelID { + if !entry.Available { + return false + } + // Provider mismatch check applies to ALL models, including defaults + if requiredProvider != "" && entry.Provider != requiredProvider { + log.Printf("Model %q has provider %q but runner requires %q", modelID, entry.Provider, requiredProvider) + return false + } + // Default models (global and per-provider) are always enabled + // (skip feature flag check) but must still pass provider matching above + if modelID == manifest.DefaultModel { + return true + } + for provider, pd := range manifest.ProviderDefaults { + if modelID == pd && (requiredProvider == "" || provider == requiredProvider) { + return true } - return isModelEnabledWithOverrides(flagName, overrides) } + flagName := fmt.Sprintf("model.%s.enabled", entry.ID) + overrides, oErr := getWorkspaceOverrides(ctx, k8sClient, namespace) + if oErr != nil { + log.Printf("WARNING: failed to read workspace overrides for %s: %v", namespace, oErr) + } + return isModelEnabledWithOverrides(flagName, overrides) } } - // 3. If we validated via registry in step 1 (found=true), allow it. - // Models not in the manifest skip feature-flag gating (e.g., Gemini models). - if rt != nil && len(rt.Models) > 0 { - return true // Already validated in step 1 - } - - // No manifest loaded and no registry available — fail-open on cold start - if manifest == nil { - log.Printf("WARNING: no manifest or registry available, allowing model %q", modelID) - return true - } - - log.Printf("WARNING: model %q not found in manifest or agent registry, rejecting", modelID) + log.Printf("WARNING: model %q not found in manifest, rejecting", modelID) return false } - -// defaultModelsResponse returns a hardcoded ListModelsResponse as a fallback -// when the model manifest file is unavailable or malformed. -// Keep in sync with components/manifests/base/models.json (available: true entries). -func defaultModelsResponse() types.ListModelsResponse { - return types.ListModelsResponse{ - DefaultModel: "claude-sonnet-4-5", - Models: []types.Model{ - {ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5", Provider: "anthropic", IsDefault: true}, - {ID: "claude-sonnet-4-6", Label: "Claude Sonnet 4.6", Provider: "anthropic", IsDefault: false}, - {ID: "claude-opus-4-6", Label: "Claude Opus 4.6", Provider: "anthropic", IsDefault: false}, - {ID: "claude-opus-4-5", Label: "Claude Opus 4.5", Provider: "anthropic", IsDefault: false}, - {ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5", Provider: "anthropic", IsDefault: false}, - }, - } -} diff --git a/components/backend/handlers/models_test.go b/components/backend/handlers/models_test.go index 65e1076b3..7a6eb7576 100644 --- a/components/backend/handlers/models_test.go +++ b/components/backend/handlers/models_test.go @@ -36,13 +36,19 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant ) validManifestObj := types.ModelManifest{ - Version: 1, + Version: 2, DefaultModel: "claude-sonnet-4-5", + ProviderDefaults: map[string]string{ + "anthropic": "claude-sonnet-4-5", + "google": "gemini-2.5-flash", + }, Models: []types.ModelEntry{ {ID: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5", VertexID: "claude-sonnet-4-5@20250929", Provider: "anthropic", Available: true}, {ID: "claude-opus-4-6", Label: "Claude Opus 4.6", VertexID: "claude-opus-4-6@default", Provider: "anthropic", Available: true}, {ID: "claude-opus-4-5", Label: "Claude Opus 4.5", VertexID: "claude-opus-4-5@20251101", Provider: "anthropic", Available: true}, {ID: "claude-haiku-4-5", Label: "Claude Haiku 4.5", VertexID: "claude-haiku-4-5@20251001", Provider: "anthropic", Available: true}, + {ID: "gemini-2.5-flash", Label: "Gemini 2.5 Flash", VertexID: "gemini-2.5-flash", Provider: "google", Available: true}, + {ID: "gemini-2.5-pro", Label: "Gemini 2.5 Pro", VertexID: "gemini-2.5-pro", Provider: "google", Available: true}, }, } @@ -130,8 +136,8 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant var resp types.ListModelsResponse err := json.Unmarshal(httpTestUtils.GetResponseRecorder().Body.Bytes(), &resp) Expect(err).NotTo(HaveOccurred()) - // With no Unleash configured, IsModelEnabled returns true, so all 4 models pass - Expect(resp.Models).To(HaveLen(4)) + // With no Unleash configured, IsModelEnabled returns true, so all 6 models pass + Expect(resp.Models).To(HaveLen(6)) Expect(resp.DefaultModel).To(Equal("claude-sonnet-4-5")) }) @@ -224,9 +230,9 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant err := json.Unmarshal(httpTestUtils.GetResponseRecorder().Body.Bytes(), &resp) Expect(err).NotTo(HaveOccurred()) - // opus-4-6 excluded by override; the other 3 should still be present - // (default model + 2 non-default models via Unleash fallback which returns true when not configured) - Expect(resp.Models).To(HaveLen(3)) + // opus-4-6 excluded by override; the other 5 should still be present + // (default model + 4 non-default models via Unleash fallback which returns true when not configured) + Expect(resp.Models).To(HaveLen(5)) ids := make([]string, len(resp.Models)) for i, m := range resp.Models { ids[i] = m.ID @@ -234,6 +240,8 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant Expect(ids).To(ContainElement("claude-sonnet-4-5")) Expect(ids).To(ContainElement("claude-opus-4-5")) Expect(ids).To(ContainElement("claude-haiku-4-5")) + Expect(ids).To(ContainElement("gemini-2.5-flash")) + Expect(ids).To(ContainElement("gemini-2.5-pro")) Expect(ids).NotTo(ContainElement("claude-opus-4-6")) }) @@ -284,15 +292,15 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant var resp types.ListModelsResponse err = json.Unmarshal(httpTestUtils.GetResponseRecorder().Body.Bytes(), &resp) Expect(err).NotTo(HaveOccurred()) - Expect(resp.Models).To(HaveLen(3)) + Expect(resp.Models).To(HaveLen(5)) for _, m := range resp.Models { Expect(m.ID).NotTo(Equal("claude-opus-4-6")) } }) - It("should return hardcoded defaults when manifest file is missing and no cache", func() { - logger.Log("Testing ListModelsForProject fallback when manifest file missing and no cache") + It("should return 503 when manifest file is missing and no cache", func() { + logger.Log("Testing ListModelsForProject returns 503 when manifest unavailable") os.Setenv("MODELS_MANIFEST_PATH", filepath.Join(GinkgoT().TempDir(), "nonexistent.json")) fakeClient := setupK8sWithOverrides() setupAuth(fakeClient) @@ -300,13 +308,7 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant ginCtx := createAuthenticatedContext("test-project") ListModelsForProject(ginCtx) - httpTestUtils.AssertHTTPStatus(http.StatusOK) - - var resp types.ListModelsResponse - err := json.Unmarshal(httpTestUtils.GetResponseRecorder().Body.Bytes(), &resp) - Expect(err).NotTo(HaveOccurred()) - Expect(resp.Models).To(HaveLen(5)) - Expect(resp.DefaultModel).To(Equal("claude-sonnet-4-5")) + httpTestUtils.AssertHTTPStatus(http.StatusServiceUnavailable) }) It("should use cached manifest when file becomes unavailable", func() { @@ -333,14 +335,13 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant var resp types.ListModelsResponse err := json.Unmarshal(httpTestUtils.GetResponseRecorder().Body.Bytes(), &resp) Expect(err).NotTo(HaveOccurred()) - // Cached manifest has 4 models (not 5 hardcoded defaults), - // and they go through flag filtering - Expect(resp.Models).To(HaveLen(4)) + // Cached manifest has 6 models and they go through flag filtering + Expect(resp.Models).To(HaveLen(6)) Expect(resp.DefaultModel).To(Equal("claude-sonnet-4-5")) }) - It("should return hardcoded defaults when JSON is malformed and no cache", func() { - logger.Log("Testing ListModelsForProject fallback with malformed JSON and no cache") + It("should return 503 when JSON is malformed and no cache", func() { + logger.Log("Testing ListModelsForProject returns 503 with malformed JSON and no cache") writeManifestFile("{invalid json") fakeClient := setupK8sWithOverrides() setupAuth(fakeClient) @@ -348,14 +349,79 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant ginCtx := createAuthenticatedContext("test-project") ListModelsForProject(ginCtx) + httpTestUtils.AssertHTTPStatus(http.StatusServiceUnavailable) + }) + }) + + Context("ListModelsForProject with provider filter", func() { + It("should return only anthropic models when provider=anthropic", func() { + logger.Log("Testing provider filter for anthropic") + writeManifestFile(validManifest) + fakeClient := setupK8sWithOverrides() + setupAuth(fakeClient) + + ginCtx := httpTestUtils.CreateTestGinContext("GET", "/api/projects/test-project/models?provider=anthropic", nil) + httpTestUtils.SetAuthHeader("test-token") + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + ginCtx.Request.URL.RawQuery = "provider=anthropic" + + ListModelsForProject(ginCtx) + httpTestUtils.AssertHTTPStatus(http.StatusOK) var resp types.ListModelsResponse err := json.Unmarshal(httpTestUtils.GetResponseRecorder().Body.Bytes(), &resp) Expect(err).NotTo(HaveOccurred()) - Expect(resp.Models).To(HaveLen(5)) + + for _, m := range resp.Models { + Expect(m.Provider).To(Equal("anthropic"), "All models should be anthropic") + } + Expect(resp.Models).To(HaveLen(4)) Expect(resp.DefaultModel).To(Equal("claude-sonnet-4-5")) }) + + It("should return only google models when provider=google", func() { + logger.Log("Testing provider filter for google") + writeManifestFile(validManifest) + fakeClient := setupK8sWithOverrides() + setupAuth(fakeClient) + + ginCtx := httpTestUtils.CreateTestGinContext("GET", "/api/projects/test-project/models?provider=google", nil) + httpTestUtils.SetAuthHeader("test-token") + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + ginCtx.Request.URL.RawQuery = "provider=google" + + ListModelsForProject(ginCtx) + + httpTestUtils.AssertHTTPStatus(http.StatusOK) + + var resp types.ListModelsResponse + err := json.Unmarshal(httpTestUtils.GetResponseRecorder().Body.Bytes(), &resp) + Expect(err).NotTo(HaveOccurred()) + + for _, m := range resp.Models { + Expect(m.Provider).To(Equal("google"), "All models should be google") + } + Expect(resp.Models).To(HaveLen(2)) + Expect(resp.DefaultModel).To(Equal("gemini-2.5-flash")) + }) + + It("should return all models when no provider filter", func() { + logger.Log("Testing no provider filter returns all models") + writeManifestFile(validManifest) + fakeClient := setupK8sWithOverrides() + setupAuth(fakeClient) + + ginCtx := createAuthenticatedContext("test-project") + ListModelsForProject(ginCtx) + + httpTestUtils.AssertHTTPStatus(http.StatusOK) + + var resp types.ListModelsResponse + err := json.Unmarshal(httpTestUtils.GetResponseRecorder().Body.Bytes(), &resp) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.Models).To(HaveLen(6)) + }) }) Context("isModelEnabledWithOverrides", func() { @@ -392,9 +458,11 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant manifest, err := LoadManifest(path) Expect(err).NotTo(HaveOccurred()) - Expect(manifest.Version).To(Equal(1)) + Expect(manifest.Version).To(Equal(2)) Expect(manifest.DefaultModel).To(Equal("claude-sonnet-4-5")) - Expect(manifest.Models).To(HaveLen(4)) + Expect(manifest.ProviderDefaults).To(HaveLen(2)) + Expect(manifest.ProviderDefaults["google"]).To(Equal("gemini-2.5-flash")) + Expect(manifest.Models).To(HaveLen(6)) }) It("should return error when file is missing", func() { @@ -415,7 +483,7 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant Context("isModelAvailable", func() { It("should return true for empty model ID", func() { logger.Log("Testing isModelAvailable with empty model ID") - result := isModelAvailable(context.Background(), K8sClient, "", "claude-agent-sdk", "test-ns") + result := isModelAvailable(context.Background(), K8sClient, "", "", "test-ns") Expect(result).To(BeTrue()) }) @@ -424,7 +492,7 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant writeManifestFile(validManifest) setupK8sWithOverrides() - result := isModelAvailable(context.Background(), K8sClient, "claude-sonnet-4-5", "claude-agent-sdk", "test-ns") + result := isModelAvailable(context.Background(), K8sClient, "claude-sonnet-4-5", "", "test-ns") Expect(result).To(BeTrue()) }) @@ -433,7 +501,7 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant writeManifestFile(validManifest) setupK8sWithOverrides() - result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "claude-agent-sdk", "test-ns") + result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "", "test-ns") Expect(result).To(BeTrue()) }) @@ -450,7 +518,7 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant writeManifestFile(string(manifestBytes)) setupK8sWithOverrides() - result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "claude-agent-sdk", "test-ns") + result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "", "test-ns") Expect(result).To(BeFalse()) }) @@ -459,18 +527,26 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant writeManifestFile(validManifest) setupK8sWithOverrides() - result := isModelAvailable(context.Background(), K8sClient, "nonexistent-model", "claude-agent-sdk", "test-ns") + result := isModelAvailable(context.Background(), K8sClient, "nonexistent-model", "", "test-ns") Expect(result).To(BeFalse()) }) - It("should fail-open when manifest file is missing", func() { - logger.Log("Testing isModelAvailable fail-open when manifest file missing") + It("should fail-open when manifest is missing and no provider required", func() { + logger.Log("Testing isModelAvailable fail-open when manifest missing and requiredProvider empty") os.Setenv("MODELS_MANIFEST_PATH", filepath.Join(GinkgoT().TempDir(), "nonexistent.json")) - result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "claude-agent-sdk", "test-ns") + result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "", "test-ns") Expect(result).To(BeTrue()) }) + It("should reject when manifest is missing but provider is required", func() { + logger.Log("Testing isModelAvailable rejects when manifest missing and requiredProvider set") + os.Setenv("MODELS_MANIFEST_PATH", filepath.Join(GinkgoT().TempDir(), "nonexistent.json")) + + result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "anthropic", "test-ns") + Expect(result).To(BeFalse()) + }) + It("should return false when workspace override disables the model", func() { logger.Log("Testing isModelAvailable respects workspace override=false") writeManifestFile(validManifest) @@ -485,7 +561,7 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant } setupK8sWithOverrides(overrideCM) - result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "claude-agent-sdk", "test-project") + result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "", "test-project") Expect(result).To(BeFalse()) }) @@ -503,7 +579,35 @@ var _ = Describe("Models Handler", Label(test_constants.LabelUnit, test_constant } setupK8sWithOverrides(overrideCM) - result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "claude-agent-sdk", "test-project") + result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "", "test-project") + Expect(result).To(BeTrue()) + }) + + It("should reject provider-default model when provider does not match requiredProvider", func() { + logger.Log("Testing isModelAvailable rejects provider-default with wrong provider") + writeManifestFile(validManifest) + setupK8sWithOverrides() + + // gemini-2.5-flash is the google provider default — should be rejected for anthropic runner + result := isModelAvailable(context.Background(), K8sClient, "gemini-2.5-flash", "anthropic", "test-ns") + Expect(result).To(BeFalse()) + }) + + It("should reject model when provider does not match requiredProvider", func() { + logger.Log("Testing isModelAvailable rejects provider mismatch") + writeManifestFile(validManifest) + setupK8sWithOverrides() + + result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "google", "test-ns") + Expect(result).To(BeFalse()) + }) + + It("should accept model when provider matches requiredProvider", func() { + logger.Log("Testing isModelAvailable accepts matching provider") + writeManifestFile(validManifest) + setupK8sWithOverrides() + + result := isModelAvailable(context.Background(), K8sClient, "claude-opus-4-6", "anthropic", "test-ns") Expect(result).To(BeTrue()) }) }) diff --git a/components/backend/handlers/runner_types.go b/components/backend/handlers/runner_types.go index 626167cf4..c4875918a 100644 --- a/components/backend/handlers/runner_types.go +++ b/components/backend/handlers/runner_types.go @@ -23,16 +23,15 @@ const DefaultRunnerPort = 8001 // Keep both in sync when modifying the schema. // It is the single source of truth for runtime configuration. type AgentRuntimeSpec struct { - ID string `json:"id"` - DisplayName string `json:"displayName"` - Description string `json:"description"` - Framework string `json:"framework"` - Container ContainerSpec `json:"container"` - Sandbox SandboxSpec `json:"sandbox"` - Auth AuthSpec `json:"auth"` - DefaultModel string `json:"defaultModel"` - Models []ModelOption `json:"models"` - FeatureGate string `json:"featureGate"` + ID string `json:"id"` + DisplayName string `json:"displayName"` + Description string `json:"description"` + Framework string `json:"framework"` + Container ContainerSpec `json:"container"` + Sandbox SandboxSpec `json:"sandbox"` + Auth AuthSpec `json:"auth"` + Provider string `json:"provider"` + FeatureGate string `json:"featureGate"` } // ContainerSpec defines the runner container configuration. @@ -72,23 +71,16 @@ type AuthSpec struct { VertexSupported bool `json:"vertexSupported"` } -// ModelOption represents a model choice within a runner type. -type ModelOption struct { - Value string `json:"value"` - Label string `json:"label"` -} - // RunnerTypeResponse is the public API shape returned to the frontend. // FeatureGate is intentionally excluded — gated runners are already filtered // out by the handler, so the frontend never needs to see the gate name. type RunnerTypeResponse struct { - ID string `json:"id"` - DisplayName string `json:"displayName"` - Description string `json:"description"` - Framework string `json:"framework"` - DefaultModel string `json:"defaultModel"` - Models []ModelOption `json:"models"` - Auth AuthSpec `json:"auth"` + ID string `json:"id"` + DisplayName string `json:"displayName"` + Description string `json:"description"` + Framework string `json:"framework"` + Provider string `json:"provider"` + Auth AuthSpec `json:"auth"` } // In-memory cache for the agent registry (ConfigMap content changes rarely). @@ -221,9 +213,20 @@ func isRunnerEnabled(runnerID string) bool { return FeatureEnabled(rt.FeatureGate) } -// GetRunnerTypes handles GET /api/runner-types and returns the list of available runner types. -// Runners gated by feature flags are filtered out. -func GetRunnerTypes(c *gin.Context) { +// isRunnerEnabledWithOverrides checks workspace ConfigMap overrides first, +// then falls back to the Unleash SDK for global state. +func isRunnerEnabledWithOverrides(flagName string, overrides map[string]string) bool { + if overrides != nil { + if val, exists := overrides[flagName]; exists { + return val == "true" + } + } + return FeatureEnabled(flagName) +} + +// GetRunnerTypesGlobal handles GET /api/runner-types (no auth, no workspace overrides). +// Used by admin pages that need to list all runner types regardless of workspace. +func GetRunnerTypesGlobal(c *gin.Context) { entries, err := loadAgentRegistry() if err != nil { log.Printf("Failed to load agent registry: %v", err) @@ -233,19 +236,61 @@ func GetRunnerTypes(c *gin.Context) { resp := make([]RunnerTypeResponse, 0, len(entries)) for _, e := range entries { - // Check feature gate directly instead of calling isRunnerEnabled (which - // re-loads the registry per entry — N+1 pattern). if e.FeatureGate != "" && !FeatureEnabled(e.FeatureGate) { continue } resp = append(resp, RunnerTypeResponse{ - ID: e.ID, - DisplayName: e.DisplayName, - Description: e.Description, - Framework: e.Framework, - DefaultModel: e.DefaultModel, - Models: e.Models, - Auth: e.Auth, + ID: e.ID, + DisplayName: e.DisplayName, + Description: e.Description, + Framework: e.Framework, + Provider: e.Provider, + Auth: e.Auth, + }) + } + + c.JSON(http.StatusOK, resp) +} + +// GetRunnerTypes handles GET /api/projects/:projectName/runner-types and returns +// the list of available runner types. Runners gated by feature flags are filtered +// out, respecting workspace-scoped overrides in the feature-flag-overrides ConfigMap. +func GetRunnerTypes(c *gin.Context) { + reqK8s, _ := GetK8sClientsForRequest(c) + if reqK8s == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "User token required"}) + c.Abort() + return + } + + ctx := c.Request.Context() + namespace := sanitizeParam(c.Param("projectName")) + + entries, err := loadAgentRegistry() + if err != nil { + log.Printf("Failed to load agent registry: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to load runner types"}) + return + } + + // Load workspace overrides for feature gate evaluation + overrides, err := getWorkspaceOverrides(ctx, reqK8s, namespace) + if err != nil { + log.Printf("WARNING: failed to read workspace overrides for runner types in %s: %v", namespace, err) + } + + resp := make([]RunnerTypeResponse, 0, len(entries)) + for _, e := range entries { + if e.FeatureGate != "" && !isRunnerEnabledWithOverrides(e.FeatureGate, overrides) { + continue + } + resp = append(resp, RunnerTypeResponse{ + ID: e.ID, + DisplayName: e.DisplayName, + Description: e.Description, + Framework: e.Framework, + Provider: e.Provider, + Auth: e.Auth, }) } diff --git a/components/backend/handlers/runner_types_test.go b/components/backend/handlers/runner_types_test.go index 32b10e682..ba6b91380 100644 --- a/components/backend/handlers/runner_types_test.go +++ b/components/backend/handlers/runner_types_test.go @@ -10,6 +10,9 @@ import ( "time" "github.com/gin-gonic/gin" + "k8s.io/apimachinery/pkg/runtime" + dynamicfake "k8s.io/client-go/dynamic/fake" + "k8s.io/client-go/kubernetes/fake" ) // sampleRegistryJSON returns a test agent registry JSON with 2 runtimes. @@ -20,6 +23,7 @@ func sampleRegistryJSON() string { DisplayName: "Claude Code", Description: "Anthropic Claude with full coding capabilities", Framework: "claude-agent-sdk", + Provider: "anthropic", Container: ContainerSpec{ Image: "quay.io/ambient_code/ambient_runner:latest", Port: 8001, @@ -42,11 +46,6 @@ func sampleRegistryJSON() string { SecretKeyLogic: "any", VertexSupported: true, }, - DefaultModel: "claude-sonnet-4-5", - Models: []ModelOption{ - {Value: "claude-sonnet-4-5", Label: "Claude Sonnet 4.5"}, - {Value: "claude-opus-4-6", Label: "Claude Opus 4.6"}, - }, FeatureGate: "", }, { @@ -54,6 +53,7 @@ func sampleRegistryJSON() string { DisplayName: "Gemini CLI", Description: "Google Gemini coding agent", Framework: "gemini-cli", + Provider: "google", Container: ContainerSpec{ Image: "quay.io/ambient_code/ambient_runner:latest", Port: 9090, @@ -72,10 +72,6 @@ func sampleRegistryJSON() string { SecretKeyLogic: "any", VertexSupported: true, }, - DefaultModel: "gemini-2.5-flash", - Models: []ModelOption{ - {Value: "gemini-2.5-flash", Label: "Gemini 2.5 Flash"}, - }, FeatureGate: "runner.gemini-cli.enabled", }, } @@ -88,13 +84,21 @@ func sampleRegistryJSON() string { func setupRegistryForTest(t *testing.T) { t.Helper() - // Write test registry JSON to a temp file and point AGENT_REGISTRY_PATH at it - tmpDir := t.TempDir() - tmpFile := filepath.Join(tmpDir, "agent-registry.json") - if err := os.WriteFile(tmpFile, []byte(sampleRegistryJSON()), 0644); err != nil { + // Write registry JSON to a temp file and point env var to it + dir := t.TempDir() + path := filepath.Join(dir, "agent-registry.json") + if err := os.WriteFile(path, []byte(sampleRegistryJSON()), 0644); err != nil { t.Fatalf("Failed to write test registry: %v", err) } - t.Setenv("AGENT_REGISTRY_PATH", tmpFile) + t.Setenv("AGENT_REGISTRY_PATH", path) + + // Set up fake K8s clients for auth and workspace overrides + K8sClientMw = fake.NewSimpleClientset() + DynamicClient = dynamicfake.NewSimpleDynamicClient(runtime.NewScheme()) + + if Namespace == "" { + Namespace = "test-ns" + } // Clear the in-memory cache registryCacheMu.Lock() @@ -146,12 +150,12 @@ func TestGetRuntime_FullFields(t *testing.T) { t.Fatalf("GetRuntime failed: %v", err) } - // Framework if rt.Framework != "claude-agent-sdk" { t.Errorf("Framework: expected 'claude-agent-sdk', got %q", rt.Framework) } - - // Auth + if rt.Provider != "anthropic" { + t.Errorf("Provider: expected 'anthropic', got %q", rt.Provider) + } if len(rt.Auth.RequiredSecretKeys) != 1 || rt.Auth.RequiredSecretKeys[0] != "ANTHROPIC_API_KEY" { t.Errorf("Auth.RequiredSecretKeys: expected [ANTHROPIC_API_KEY], got %v", rt.Auth.RequiredSecretKeys) } @@ -161,19 +165,9 @@ func TestGetRuntime_FullFields(t *testing.T) { if !rt.Auth.VertexSupported { t.Error("Auth.VertexSupported: expected true") } - - // FeatureGate if rt.FeatureGate != "" { t.Errorf("FeatureGate: expected empty string, got %q", rt.FeatureGate) } - - // Models - if len(rt.Models) != 2 { - t.Errorf("Expected 2 models, got %d", len(rt.Models)) - } - if rt.DefaultModel != "claude-sonnet-4-5" { - t.Errorf("DefaultModel: expected 'claude-sonnet-4-5', got %q", rt.DefaultModel) - } } func TestGetRuntime_GeminiFields(t *testing.T) { @@ -187,6 +181,9 @@ func TestGetRuntime_GeminiFields(t *testing.T) { if rt.Framework != "gemini-cli" { t.Errorf("Framework: expected 'gemini-cli', got %q", rt.Framework) } + if rt.Provider != "google" { + t.Errorf("Provider: expected 'google', got %q", rt.Provider) + } if rt.FeatureGate != "runner.gemini-cli.enabled" { t.Errorf("FeatureGate: expected 'runner.gemini-cli.enabled', got %q", rt.FeatureGate) } @@ -272,18 +269,16 @@ func TestGetContainerEnvVars_UnknownFallback(t *testing.T) { // --- GetRunnerTypes handler test --- -func TestGetRunnerTypes_ReturnsFullFields(t *testing.T) { +func TestGetRunnerTypes_ReturnsProvider(t *testing.T) { setupRegistryForTest(t) - // Without Unleash initialized, FeatureEnabled returns false. - // isRunnerEnabled returns true for runtimes with empty featureGate, - // and false for runtimes with a non-empty featureGate. - // In our test data: claude (featureGate="") -> enabled, gemini (featureGate="runner.gemini-cli.enabled") -> disabled. - gin.SetMode(gin.TestMode) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest(http.MethodGet, "/api/runner-types", nil) + req := httptest.NewRequest(http.MethodGet, "/api/projects/test-project/runner-types", nil) + req.Header.Set("Authorization", "Bearer test-token") + c.Request = req + c.Params = gin.Params{{Key: "projectName", Value: "test-project"}} GetRunnerTypes(c) @@ -297,7 +292,6 @@ func TestGetRunnerTypes_ReturnsFullFields(t *testing.T) { } // Only claude-agent-sdk should be returned (empty featureGate = always enabled) - // gemini-cli has featureGate="runner.gemini-cli.enabled" which is disabled without Unleash if len(resp) != 1 { t.Fatalf("Expected 1 runner type (only ungated), got %d", len(resp)) } @@ -306,36 +300,27 @@ func TestGetRunnerTypes_ReturnsFullFields(t *testing.T) { if claude.ID != "claude-agent-sdk" { t.Fatalf("Expected claude-agent-sdk, got %q", claude.ID) } - - // Verify full AgentRuntimeSpec fields are in the response + if claude.Provider != "anthropic" { + t.Errorf("Provider: expected 'anthropic', got %q", claude.Provider) + } if claude.Framework != "claude-agent-sdk" { t.Errorf("Framework: expected 'claude-agent-sdk', got %q", claude.Framework) } if claude.Auth.SecretKeyLogic != "any" { t.Errorf("Auth.SecretKeyLogic: expected 'any', got %q", claude.Auth.SecretKeyLogic) } - if claude.Auth.VertexSupported != true { - t.Error("Auth.VertexSupported: expected true") - } - if len(claude.Auth.RequiredSecretKeys) != 1 || claude.Auth.RequiredSecretKeys[0] != "ANTHROPIC_API_KEY" { - t.Errorf("Auth.RequiredSecretKeys: expected [ANTHROPIC_API_KEY], got %v", claude.Auth.RequiredSecretKeys) - } - if claude.DefaultModel != "claude-sonnet-4-5" { - t.Errorf("DefaultModel: expected 'claude-sonnet-4-5', got %q", claude.DefaultModel) - } - if len(claude.Models) != 2 { - t.Errorf("Expected 2 models, got %d", len(claude.Models)) - } } func TestGetRunnerTypes_GatedRunnersFiltered(t *testing.T) { setupRegistryForTest(t) - // Without Unleash, gated runners should be filtered out gin.SetMode(gin.TestMode) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest(http.MethodGet, "/api/runner-types", nil) + req := httptest.NewRequest(http.MethodGet, "/api/projects/test-project/runner-types", nil) + req.Header.Set("Authorization", "Bearer test-token") + c.Request = req + c.Params = gin.Params{{Key: "projectName", Value: "test-project"}} GetRunnerTypes(c) @@ -352,7 +337,6 @@ func TestGetRunnerTypes_GatedRunnersFiltered(t *testing.T) { func TestIsRunnerEnabled_EmptyGate(t *testing.T) { setupRegistryForTest(t) - // Runtimes with empty featureGate should always be enabled if !isRunnerEnabled("claude-agent-sdk") { t.Error("claude-agent-sdk with empty featureGate should be enabled") } @@ -361,8 +345,88 @@ func TestIsRunnerEnabled_EmptyGate(t *testing.T) { func TestIsRunnerEnabled_NonEmptyGate_Disabled(t *testing.T) { setupRegistryForTest(t) - // Without Unleash, non-empty featureGate should be disabled if isRunnerEnabled("gemini-cli") { t.Error("gemini-cli should be disabled when Unleash is not configured") } } + +// --- isRunnerEnabledWithOverrides tests --- + +func TestIsRunnerEnabledWithOverrides_OverrideTrue(t *testing.T) { + overrides := map[string]string{"runner.gemini-cli.enabled": "true"} + if !isRunnerEnabledWithOverrides("runner.gemini-cli.enabled", overrides) { + t.Error("expected enabled when override is true") + } +} + +func TestIsRunnerEnabledWithOverrides_OverrideFalse(t *testing.T) { + overrides := map[string]string{"runner.gemini-cli.enabled": "false"} + if isRunnerEnabledWithOverrides("runner.gemini-cli.enabled", overrides) { + t.Error("expected disabled when override is false") + } +} + +func TestIsRunnerEnabledWithOverrides_NoOverrideFallsThrough(t *testing.T) { + overrides := map[string]string{"other.flag": "true"} + // Without Unleash configured, FeatureEnabled returns false + if isRunnerEnabledWithOverrides("runner.gemini-cli.enabled", overrides) { + t.Error("expected disabled when no override and Unleash not configured") + } +} + +func TestIsRunnerEnabledWithOverrides_NilOverrides(t *testing.T) { + // Without Unleash configured, FeatureEnabled returns false + if isRunnerEnabledWithOverrides("runner.gemini-cli.enabled", nil) { + t.Error("expected disabled with nil overrides and Unleash not configured") + } +} + +// --- GetRunnerTypesGlobal tests --- + +func TestGetRunnerTypesGlobal_ReturnsUngatedRunners(t *testing.T) { + setupRegistryForTest(t) + + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/api/runner-types", nil) + + GetRunnerTypesGlobal(c) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp []RunnerTypeResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + // Only ungated runners returned (gemini-cli gated, disabled without Unleash) + if len(resp) != 1 { + t.Fatalf("Expected 1 runner type, got %d", len(resp)) + } + if resp[0].ID != "claude-agent-sdk" { + t.Errorf("Expected claude-agent-sdk, got %q", resp[0].ID) + } + if resp[0].Provider != "anthropic" { + t.Errorf("Expected provider anthropic, got %q", resp[0].Provider) + } +} + +func TestGetRunnerTypesGlobal_NoAuthRequired(t *testing.T) { + setupRegistryForTest(t) + + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + // No auth header + c.Request = httptest.NewRequest(http.MethodGet, "/api/runner-types", nil) + + GetRunnerTypesGlobal(c) + + // Should succeed without auth + if w.Code != http.StatusOK { + t.Fatalf("Expected 200 without auth, got %d", w.Code) + } +} diff --git a/components/backend/handlers/sessions.go b/components/backend/handlers/sessions.go index 19e1071ed..aa93a53fb 100644 --- a/components/backend/handlers/sessions.go +++ b/components/backend/handlers/sessions.go @@ -624,8 +624,18 @@ func CreateSession(c *gin.Context) { } } - // Validate that the requested model is available for this runner type - if llmSettings.Model != "" && !isModelAvailable(c.Request.Context(), reqK8s, llmSettings.Model, runnerTypeID, project) { + // Validate model availability with provider matching. + // If the runner type is found in the registry, enforce that the model's + // provider matches the runner's provider. If the registry is unavailable + // (e.g., ConfigMap not mounted), skip provider matching but still validate + // the model against the manifest. + runnerProvider := "" + if rt, rtErr := GetRuntime(runnerTypeID); rtErr == nil { + runnerProvider = rt.Provider + } else { + log.Printf("WARNING: could not resolve runner type %q from registry: %v", runnerTypeID, rtErr) + } + if llmSettings.Model != "" && !isModelAvailable(c.Request.Context(), reqK8s, llmSettings.Model, runnerProvider, project) { c.JSON(http.StatusBadRequest, gin.H{"error": "Model is not available for this runner type"}) return } diff --git a/components/backend/routes.go b/components/backend/routes.go index 177d8e4db..955e4fe57 100644 --- a/components/backend/routes.go +++ b/components/backend/routes.go @@ -13,13 +13,15 @@ func registerRoutes(r *gin.Engine) { { // Public endpoints (no auth required) api.GET("/workflows/ootb", handlers.ListOOTBWorkflows) - api.GET("/runner-types", handlers.GetRunnerTypes) + // Global runner-types endpoint (no workspace overrides — for admin pages) + api.GET("/runner-types", handlers.GetRunnerTypesGlobal) api.POST("/projects/:projectName/agentic-sessions/:sessionName/github/token", handlers.MintSessionGitHubToken) projectGroup := api.Group("/projects/:projectName", handlers.ValidateProjectContext()) { projectGroup.GET("/models", handlers.ListModelsForProject) + projectGroup.GET("/runner-types", handlers.GetRunnerTypes) projectGroup.GET("/access", handlers.AccessCheck) projectGroup.GET("/integration-status", handlers.GetProjectIntegrationStatus) projectGroup.GET("/users/forks", handlers.ListUserForks) diff --git a/components/backend/types/models.go b/components/backend/types/models.go index f01caf124..3f38c527d 100644 --- a/components/backend/types/models.go +++ b/components/backend/types/models.go @@ -20,9 +20,10 @@ type ModelEntry struct { // ModelManifest represents the top-level model manifest structure. type ModelManifest struct { - Version int `json:"version"` - DefaultModel string `json:"defaultModel"` - Models []ModelEntry `json:"models"` + Version int `json:"version"` + DefaultModel string `json:"defaultModel"` + ProviderDefaults map[string]string `json:"providerDefaults,omitempty"` + Models []ModelEntry `json:"models"` } // ListModelsResponse is the API response for the models endpoint. diff --git a/components/frontend/src/app/admin/runtimes/page.tsx b/components/frontend/src/app/admin/runtimes/page.tsx index 22dbc1576..86f1c5921 100644 --- a/components/frontend/src/app/admin/runtimes/page.tsx +++ b/components/frontend/src/app/admin/runtimes/page.tsx @@ -19,7 +19,7 @@ import { Skeleton } from "@/components/ui/skeleton"; import { Breadcrumbs } from "@/components/breadcrumbs"; import { PageHeader } from "@/components/page-header"; import { EmptyState } from "@/components/empty-state"; -import { useRunnerTypes } from "@/services/queries/use-runner-types"; +import { useRunnerTypesGlobal } from "@/services/queries/use-runner-types"; import type { RunnerType } from "@/services/api/runner-types"; function RuntimeStatusBadge() { @@ -40,7 +40,7 @@ function RuntimeDetailPanel({ runtime }: { runtime: RunnerType }) {
Required keys: - {(runtime.auth?.requiredSecretKeys ?? runtime.requiredSecretKeys ?? []).join(", ") || "None"} + {(runtime.auth?.requiredSecretKeys ?? []).join(", ") || "None"}
@@ -56,17 +56,10 @@ function RuntimeDetailPanel({ runtime }: { runtime: RunnerType }) {
-

Models

-
- {runtime.models.map((model) => ( - - {model.label} - - ))} -
-
- Default: {runtime.defaultModel} -
+

Provider

+ + {runtime.provider} +
@@ -95,7 +88,7 @@ function LoadingSkeleton() { } export default function AdminRuntimesPage() { - const { data: runtimes, isLoading, isError, error, refetch } = useRunnerTypes(); + const { data: runtimes, isLoading, isError, error, refetch } = useRunnerTypesGlobal(); const [expandedId, setExpandedId] = useState(null); const toggleExpanded = (id: string) => { @@ -174,7 +167,7 @@ export default function AdminRuntimesPage() { Runtime Description - Models + Provider Status @@ -235,7 +228,7 @@ function RuntimeRow({ {runtime.description || "\u2014"} - {runtime.models.length} + {runtime.provider} diff --git a/components/frontend/src/app/api/projects/[name]/models/route.ts b/components/frontend/src/app/api/projects/[name]/models/route.ts index 12edfa494..c1799a7cf 100644 --- a/components/frontend/src/app/api/projects/[name]/models/route.ts +++ b/components/frontend/src/app/api/projects/[name]/models/route.ts @@ -2,8 +2,9 @@ import { BACKEND_URL } from "@/lib/config"; import { buildForwardHeadersAsync } from "@/lib/auth"; /** - * GET /api/projects/:projectName/models - * Proxies to backend to list available models with workspace overrides + * GET /api/projects/:projectName/models?provider=... + * Proxies to backend to list available models with workspace overrides. + * Optional provider query param filters by model provider. */ export async function GET( request: Request, @@ -13,8 +14,12 @@ export async function GET( const { name: projectName } = await params; const headers = await buildForwardHeadersAsync(request); + const url = new URL(request.url); + const provider = url.searchParams.get("provider"); + const backendParams = provider ? `?provider=${encodeURIComponent(provider)}` : ""; + const response = await fetch( - `${BACKEND_URL}/projects/${encodeURIComponent(projectName)}/models`, + `${BACKEND_URL}/projects/${encodeURIComponent(projectName)}/models${backendParams}`, { headers } ); diff --git a/components/frontend/src/app/api/projects/[name]/runner-types/route.ts b/components/frontend/src/app/api/projects/[name]/runner-types/route.ts new file mode 100644 index 000000000..85a719fed --- /dev/null +++ b/components/frontend/src/app/api/projects/[name]/runner-types/route.ts @@ -0,0 +1,34 @@ +import { BACKEND_URL } from "@/lib/config"; +import { buildForwardHeadersAsync } from "@/lib/auth"; + +/** + * GET /api/projects/:projectName/runner-types + * Proxies to backend to list available runner types with workspace overrides. + */ +export async function GET( + request: Request, + { params }: { params: Promise<{ name: string }> } +) { + try { + const { name: projectName } = await params; + const headers = await buildForwardHeadersAsync(request); + + const response = await fetch( + `${BACKEND_URL}/projects/${encodeURIComponent(projectName)}/runner-types`, + { headers } + ); + + const data = await response.text(); + + return new Response(data, { + status: response.status, + headers: { "Content-Type": "application/json" }, + }); + } catch (error) { + console.error("Failed to fetch runner types:", error); + return Response.json( + { error: "Failed to fetch runner types" }, + { status: 500 } + ); + } +} diff --git a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx index 17bf5f1b4..fa8b91594 100644 --- a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx +++ b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx @@ -234,7 +234,7 @@ export default function ProjectSessionDetailPage({ // Fetch runner capabilities and derive agent display name const { data: capabilities } = useCapabilities(projectName, sessionName, phase === "Running"); - const { data: runnerTypes } = useRunnerTypes(); + const { data: runnerTypes } = useRunnerTypes(projectName); const agentName = useMemo(() => { if (capabilities?.framework && runnerTypes) { const matched = runnerTypes.find((rt) => rt.id === capabilities.framework); diff --git a/components/frontend/src/components/create-session-dialog.tsx b/components/frontend/src/components/create-session-dialog.tsx index 92329c0f5..a45930fdb 100644 --- a/components/frontend/src/components/create-session-dialog.tsx +++ b/components/frontend/src/components/create-session-dialog.tsx @@ -42,14 +42,8 @@ import { useIntegrationsStatus } from "@/services/queries/use-integrations"; import { useModels } from "@/services/queries/use-models"; import { errorToast } from "@/hooks/use-toast"; -// Keep in sync with components/manifests/base/models.json (available: true entries). -const fallbackModels = [ - { value: "claude-sonnet-4-5", label: "Claude Sonnet 4.5" }, - { value: "claude-sonnet-4-6", label: "Claude Sonnet 4.6" }, - { value: "claude-opus-4-6", label: "Claude Opus 4.6" }, - { value: "claude-opus-4-5", label: "Claude Opus 4.5" }, - { value: "claude-haiku-4-5", label: "Claude Haiku 4.5" }, -]; +// Static default used for form initialization before the API responds. +const DEFAULT_MODEL = "claude-sonnet-4-5"; const formSchema = z.object({ displayName: z.string().max(50).optional(), @@ -76,16 +70,9 @@ export function CreateSessionDialog({ const [open, setOpen] = useState(false); const router = useRouter(); const createSessionMutation = useCreateSession(); - const { data: runnerTypes, isLoading: runnerTypesLoading, isError: runnerTypesError, refetch: refetchRunnerTypes } = useRunnerTypes(); - - const { data: modelsData, isLoading: modelsLoading } = useModels(projectName, open); + const { data: runnerTypes, isLoading: runnerTypesLoading, isError: runnerTypesError, refetch: refetchRunnerTypes } = useRunnerTypes(projectName); const { data: integrationsStatus } = useIntegrationsStatus(); - const models = modelsData - ? modelsData.models.map((m) => ({ value: m.id, label: m.label })) - : fallbackModels; - const defaultModel = modelsData?.defaultModel ?? "claude-sonnet-4-5"; - const githubConfigured = integrationsStatus?.github?.active != null; const gitlabConfigured = integrationsStatus?.gitlab?.connected ?? false; const atlassianConfigured = integrationsStatus?.jira?.connected ?? false; @@ -96,35 +83,44 @@ export function CreateSessionDialog({ defaultValues: { displayName: "", runnerType: DEFAULT_RUNNER_TYPE_ID, - model: defaultModel, + model: DEFAULT_MODEL, temperature: 0.7, maxTokens: 4000, timeout: 300, }, }); - useEffect(() => { - if (modelsData?.defaultModel && !form.formState.dirtyFields.model) { - form.setValue("model", modelsData.defaultModel, { shouldDirty: false }); - } - }, [modelsData?.defaultModel, form]); - const selectedRunnerType = form.watch("runnerType"); - // Derive the available models from the selected runner type const selectedRunner = useMemo( () => runnerTypes?.find((rt) => rt.id === selectedRunnerType), [runnerTypes, selectedRunnerType] ); - const availableModels = selectedRunner?.models ?? models; + + // Fetch models filtered by the selected runner's provider. + // models.json is the single source of truth — no hardcoded fallback lists. + // Wait for runner types to load so we know the provider before fetching. + const { data: modelsData, isLoading: modelsLoading, isError: modelsError } = useModels( + projectName, open && !runnerTypesLoading && !runnerTypesError, selectedRunner?.provider + ); + + const models = modelsData + ? modelsData.models.map((m) => ({ value: m.id, label: m.label })) + : []; + + // Update form model when API response arrives or provider changes + useEffect(() => { + if (modelsData?.defaultModel && !form.formState.dirtyFields.model) { + form.setValue("model", modelsData.defaultModel, { shouldDirty: false }); + } + }, [modelsData?.defaultModel, form]); const handleRunnerTypeChange = (value: string, onChange: (v: string) => void) => { onChange(value); - const runner = runnerTypes?.find((rt) => rt.id === value); - if (runner) { - // Reset model to the runner's default - form.setValue("model", runner.defaultModel); - } + // Model list will refetch via useModels when provider changes. + // resetField clears both value AND dirty state so the useEffect + // above will set the new provider's default model. + form.resetField("model", { defaultValue: "" }); }; const onSubmit = async (values: FormValues) => { @@ -276,11 +272,17 @@ export function CreateSessionDialog({ - {availableModels.map((m) => ( - - {m.label} - - ))} + {models.length === 0 && !modelsLoading ? ( +
+ No models available for this runner +
+ ) : ( + models.map((m) => ( + + {m.label} + + )) + )}
@@ -433,7 +435,7 @@ export function CreateSessionDialog({ > Cancel -