Skip to content

Commit 07d6126

Browse files
Gkrumbach07, claude, and Ambient Code Bot
authored
feat: live model switching for running sessions (#1239)
<!-- acp:session_id=session-b6f81446-c636-400f-90ae-907dca607c9b source=#1090 last_action=2026-04-10T14:33:00Z retry_count=1 --> ## Summary - Add `POST /api/projects/{project}/agentic-sessions/{session}/model` endpoint that validates session is Running, updates CR (validates RBAC), proxies model switch to runner, and reverts CR if runner rejects - Add runner `/model` endpoint that guards against mid-generation switches (422), updates `LLM_MODEL` env var, emits `ambient:model_switched` ag-ui event, then calls `mark_dirty()` to rebuild adapter - Add `LiveModelSelector` dropdown in the chat input toolbar (visible only when session is Running, disabled while agent is generating) with loading/error states - Handle `ambient:model_switched` custom event to inject a timestamped confirmation message in the conversation ## Test plan - [x] Frontend: 5 LiveModelSelector unit tests (rendering, disabled states, spinner) - [x] Frontend: 1 event-handler test for `ambient:model_switched` custom event - [x] Runner: 9 endpoint tests (success, validation, mid-generation guard, env update, event emission) - [x] Backend: Go build passes - [x] Frontend: TypeScript compiles cleanly - [ ] Manual: verify model selector appears in chat toolbar for Running sessions - [ ] Manual: verify model switch while agent is idle succeeds - [ ] Manual: verify model switch while agent is generating returns 422 Closes #1090 🤖 Generated with [Claude Code](https://claude.com/claude-code) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Live model selector in the chat toolbar for Running sessions — users can switch the session's LLM at runtime. * **UX** * Shows spinner and disables controls while switching; switches blocked during active generation; failures surface via toast and append a system message about model changes. 
* **Reliability** * Safer name validation for switch requests and automatic rollback on switch failures to preserve session state. * **Tests** * Frontend and backend tests covering UI behavior and the runtime switch flow. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Ambient Code Bot <bot@ambient-code.local>
1 parent 2da4ada commit 07d6126

File tree

13 files changed

+720
-5
lines changed

13 files changed

+720
-5
lines changed

components/backend/handlers/sessions.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"net/url"
1414
"os"
1515
"path/filepath"
16+
"regexp"
1617
"sort"
1718
"strings"
1819
"sync"
@@ -1474,6 +1475,172 @@ func UpdateSessionDisplayName(c *gin.Context) {
14741475
c.JSON(http.StatusOK, session)
14751476
}
14761477

1478+
// SwitchModel switches the LLM model for a running session
1479+
// POST /api/projects/:projectName/agentic-sessions/:sessionName/model
1480+
func SwitchModel(c *gin.Context) {
1481+
project := c.GetString("project")
1482+
sessionName := c.Param("sessionName")
1483+
_, k8sDyn := GetK8sClientsForRequest(c)
1484+
if k8sDyn == nil {
1485+
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or missing token"})
1486+
c.Abort()
1487+
return
1488+
}
1489+
1490+
var req struct {
1491+
Model string `json:"model" binding:"required"`
1492+
}
1493+
if err := c.ShouldBindJSON(&req); err != nil {
1494+
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body: model is required"})
1495+
return
1496+
}
1497+
1498+
if req.Model == "" {
1499+
c.JSON(http.StatusBadRequest, gin.H{"error": "model must not be empty"})
1500+
return
1501+
}
1502+
1503+
gvr := GetAgenticSessionV1Alpha1Resource()
1504+
1505+
// Get current session
1506+
item, err := k8sDyn.Resource(gvr).Namespace(project).Get(context.TODO(), sessionName, v1.GetOptions{})
1507+
if err != nil {
1508+
if errors.IsNotFound(err) {
1509+
c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"})
1510+
return
1511+
}
1512+
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get session"})
1513+
return
1514+
}
1515+
1516+
// Ensure session is Running
1517+
if err := ensureRuntimeMutationAllowed(item); err != nil {
1518+
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
1519+
return
1520+
}
1521+
1522+
// Get current model for comparison
1523+
spec, ok := item.Object["spec"].(map[string]interface{})
1524+
if !ok {
1525+
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid session spec"})
1526+
return
1527+
}
1528+
llmSettings, _, _ := unstructured.NestedMap(spec, "llmSettings")
1529+
previousModel, _ := llmSettings["model"].(string)
1530+
1531+
// No-op if same model
1532+
if previousModel == req.Model {
1533+
session := types.AgenticSession{
1534+
APIVersion: item.GetAPIVersion(),
1535+
Kind: item.GetKind(),
1536+
}
1537+
if meta, ok := item.Object["metadata"].(map[string]interface{}); ok {
1538+
session.Metadata = meta
1539+
}
1540+
session.Spec = parseSpec(spec)
1541+
if status, ok := item.Object["status"].(map[string]interface{}); ok {
1542+
session.Status = parseStatus(status)
1543+
}
1544+
c.JSON(http.StatusOK, session)
1545+
return
1546+
}
1547+
1548+
// Update the CR first to validate RBAC (user needs update permission).
1549+
// This ensures a user with only get access cannot trigger a runner-side
1550+
// model switch without also being allowed to persist the change.
1551+
if llmSettings == nil {
1552+
llmSettings = map[string]interface{}{}
1553+
}
1554+
llmSettings["model"] = req.Model
1555+
spec["llmSettings"] = llmSettings
1556+
1557+
updated, err := k8sDyn.Resource(gvr).Namespace(project).Update(context.TODO(), item, v1.UpdateOptions{})
1558+
if err != nil {
1559+
log.Printf("Failed to update session CR %s for model switch: %v", sessionName, err)
1560+
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update session record"})
1561+
return
1562+
}
1563+
1564+
// Proxy to runner — if runner rejects (e.g., agent is mid-generation), revert the CR.
1565+
// Sanitize the CR name against a strict allowlist to prevent SSRF.
1566+
sanitizedName, err := sanitizeK8sName(item.GetName())
1567+
if err != nil {
1568+
log.Printf("Invalid session name %q for model switch: %v", item.GetName(), err)
1569+
revertModelSwitch(updated, previousModel, k8sDyn, gvr, project)
1570+
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid session name"})
1571+
return
1572+
}
1573+
sanitizedProject, err := sanitizeK8sName(project)
1574+
if err != nil {
1575+
log.Printf("Invalid project name %q for model switch: %v", project, err)
1576+
revertModelSwitch(updated, previousModel, k8sDyn, gvr, project)
1577+
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid project name"})
1578+
return
1579+
}
1580+
serviceName := getRunnerServiceName(sanitizedName)
1581+
runnerURL := fmt.Sprintf("http://%s.%s.svc.cluster.local:8001/model", serviceName, sanitizedProject)
1582+
runnerReq := map[string]string{"model": req.Model}
1583+
reqBody, _ := json.Marshal(runnerReq)
1584+
1585+
httpReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", runnerURL, bytes.NewReader(reqBody))
1586+
if err != nil {
1587+
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create runner request"})
1588+
return
1589+
}
1590+
httpReq.Header.Set("Content-Type", "application/json")
1591+
1592+
client := &http.Client{Timeout: 30 * time.Second}
1593+
resp, err := client.Do(httpReq)
1594+
if err != nil {
1595+
log.Printf("Failed to proxy model switch to runner for session %s: %v", sessionName, err)
1596+
// Revert the CR update on the server-returned object
1597+
revertModelSwitch(updated, previousModel, k8sDyn, gvr, project)
1598+
c.JSON(http.StatusBadGateway, gin.H{"error": "Failed to reach session runner"})
1599+
return
1600+
}
1601+
defer resp.Body.Close()
1602+
1603+
if resp.StatusCode != http.StatusOK {
1604+
body, _ := io.ReadAll(resp.Body)
1605+
log.Printf("Runner rejected model switch for session %s: %d %s", sessionName, resp.StatusCode, string(body))
1606+
// Revert the CR update on the server-returned object
1607+
revertModelSwitch(updated, previousModel, k8sDyn, gvr, project)
1608+
// Forward runner's status code and error
1609+
c.Data(resp.StatusCode, "application/json", body)
1610+
return
1611+
}
1612+
1613+
session := types.AgenticSession{
1614+
APIVersion: updated.GetAPIVersion(),
1615+
Kind: updated.GetKind(),
1616+
}
1617+
if meta, ok := updated.Object["metadata"].(map[string]interface{}); ok {
1618+
session.Metadata = meta
1619+
}
1620+
if s, ok := updated.Object["spec"].(map[string]interface{}); ok {
1621+
session.Spec = parseSpec(s)
1622+
}
1623+
if status, ok := updated.Object["status"].(map[string]interface{}); ok {
1624+
session.Status = parseStatus(status)
1625+
}
1626+
1627+
c.JSON(http.StatusOK, session)
1628+
}
1629+
1630+
// revertModelSwitch restores the previous model on the server-returned CR object.
1631+
// Called when the runner rejects a model switch after the CR was already updated.
1632+
func revertModelSwitch(updated *unstructured.Unstructured, previousModel string, k8sDyn dynamic.Interface, gvr schema.GroupVersionResource, namespace string) {
1633+
if updatedSpec, ok := updated.Object["spec"].(map[string]interface{}); ok {
1634+
if updatedLLM, ok := updatedSpec["llmSettings"].(map[string]interface{}); ok {
1635+
updatedLLM["model"] = previousModel
1636+
_, err := k8sDyn.Resource(gvr).Namespace(namespace).Update(context.TODO(), updated, v1.UpdateOptions{})
1637+
if err != nil {
1638+
log.Printf("Failed to revert model switch for session %s: %v", updated.GetName(), err)
1639+
}
1640+
}
1641+
}
1642+
}
1643+
14771644
// SelectWorkflow sets the active workflow for a session
14781645
// POST /api/projects/:projectName/agentic-sessions/:sessionName/workflow
14791646
func SelectWorkflow(c *gin.Context) {
@@ -1945,6 +2112,20 @@ func RemoveRepo(c *gin.Context) {
19452112
c.JSON(http.StatusOK, gin.H{"message": "Repository removed", "session": session})
19462113
}
19472114

2115+
// k8sNameRegexp matches the character rules of an RFC 1123 DNS label:
// lowercase alphanumerics and '-', starting and ending with an alphanumeric.
// Dots are not allowed, so a match is always a single label.
var k8sNameRegexp = regexp.MustCompile(`^[a-z0-9]([a-z0-9\-]*[a-z0-9])?$`)

// sanitizeK8sName validates that name is a valid RFC 1123 DNS label and
// returns it unchanged if valid, or returns an empty string and an error.
// This breaks the taint chain for static analysis (CodeQL SSRF) by proving
// the value matches a strict allowlist before it reaches any network call.
//
// The length limit is 63, the RFC 1123 maximum for a single DNS label. The
// previous limit of 253 is the DNS subdomain limit, which does not fit a
// pattern that forbids dots and would admit names that cannot form a valid
// hostname component.
func sanitizeK8sName(name string) (string, error) {
	const maxDNSLabelLen = 63
	if len(name) == 0 || len(name) > maxDNSLabelLen || !k8sNameRegexp.MatchString(name) {
		return "", fmt.Errorf("invalid Kubernetes resource name: %q", name)
	}
	return name, nil
}
2128+
19482129
// getRunnerServiceName returns the K8s Service name for a session's runner.
19492130
// The runner serves both AG-UI and content endpoints on port 8001.
19502131
func getRunnerServiceName(session string) string {

components/backend/routes.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ func registerRoutes(r *gin.Engine) {
5959
projectGroup.GET("/agentic-sessions/:sessionName/repos/status", handlers.GetReposStatus)
6060
projectGroup.DELETE("/agentic-sessions/:sessionName/repos/:repoName", handlers.RemoveRepo)
6161
projectGroup.PUT("/agentic-sessions/:sessionName/displayname", handlers.UpdateSessionDisplayName)
62+
projectGroup.POST("/agentic-sessions/:sessionName/model", handlers.SwitchModel)
6263

6364
// OAuth integration - requires user auth like all other session endpoints
6465
projectGroup.GET("/agentic-sessions/:sessionName/oauth/:provider/url", handlers.GetOAuthURL)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import { describe, it, expect, vi, beforeEach } from 'vitest';
2+
import { render, screen } from '@testing-library/react';
3+
import { LiveModelSelector } from '../live-model-selector';
4+
import type { ListModelsResponse } from '@/types/api';
5+
6+
// Fixture returned by the mocked useModels hook: three models, with Sonnet
// flagged as the default.
const mockAnthropicModels: ListModelsResponse = {
  models: [
    { id: 'claude-haiku-4-5', label: 'Claude Haiku 4.5', provider: 'anthropic', isDefault: false },
    { id: 'claude-sonnet-4-5', label: 'Claude Sonnet 4.5', provider: 'anthropic', isDefault: true },
    { id: 'claude-opus-4-6', label: 'Claude Opus 4.6', provider: 'anthropic', isDefault: false },
  ],
  defaultModel: 'claude-sonnet-4-5',
};

const mockUseModels = vi.fn(() => ({ data: mockAnthropicModels }));

// The factory only dereferences mockUseModels when the hook is called, so the
// mock function can be declared at module scope.
vi.mock('@/services/queries/use-models', () => ({
  useModels: () => mockUseModels(),
}));

describe('LiveModelSelector', () => {
  type SelectorProps = Parameters<typeof LiveModelSelector>[0];

  const defaultProps = {
    projectName: 'test-project',
    currentModel: 'claude-sonnet-4-5',
    onSelect: vi.fn(),
  };

  // Renders the selector with the default props plus any per-test overrides.
  const mount = (overrides: Partial<SelectorProps> = {}) =>
    render(<LiveModelSelector {...defaultProps} {...overrides} />);

  // Looks up the dropdown trigger, the only element queried by role here.
  const trigger = () => screen.getByRole('button') as HTMLButtonElement;

  beforeEach(() => {
    vi.clearAllMocks();
    mockUseModels.mockReturnValue({ data: mockAnthropicModels });
  });

  it('renders with current model name displayed', () => {
    mount();
    expect(trigger().textContent).toContain('Claude Sonnet 4.5');
  });

  it('renders with model id fallback when model not in list', () => {
    mount({ currentModel: 'unknown-model-id' });
    expect(trigger().textContent).toContain('unknown-model-id');
  });

  it('shows spinner when switching', () => {
    mount({ switching: true });
    expect(document.querySelector('.animate-spin')).not.toBeNull();
  });

  it('button is disabled when disabled prop is true', () => {
    mount({ disabled: true });
    expect(trigger().disabled).toBe(true);
  });

  it('button is disabled when switching prop is true', () => {
    mount({ switching: true });
    expect(trigger().disabled).toBe(true);
  });
});
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"use client";
2+
3+
import { useMemo } from "react";
4+
import { ChevronDown, Loader2 } from "lucide-react";
5+
import { Button } from "@/components/ui/button";
6+
import {
7+
DropdownMenu,
8+
DropdownMenuContent,
9+
DropdownMenuRadioGroup,
10+
DropdownMenuRadioItem,
11+
DropdownMenuTrigger,
12+
} from "@/components/ui/dropdown-menu";
13+
import { useModels } from "@/services/queries/use-models";
14+
15+
// Props for LiveModelSelector.
type LiveModelSelectorProps = {
  // Passed to useModels as its first argument.
  projectName: string;
  // Id of the model currently in use; drives the trigger label and the
  // checked radio item.
  currentModel: string;
  // Passed to useModels as its third argument.
  provider?: string;
  // Disables the trigger button.
  disabled?: boolean;
  // Shows a spinner in the trigger and disables it.
  switching?: boolean;
  // Called with the chosen model id when it differs from currentModel.
  onSelect: (model: string) => void;
};

/**
 * Dropdown for picking a model. The trigger shows the label of the current
 * model, falling back to its raw id when it is not in the fetched list. The
 * menu shows a spinner while loading, an error message on failure, the model
 * list as radio items, or an empty-state message.
 */
export function LiveModelSelector({
  projectName,
  currentModel,
  provider,
  disabled,
  switching,
  onSelect,
}: LiveModelSelectorProps) {
  const { data: modelsData, isLoading, isError } = useModels(projectName, true, provider);

  const models = useMemo(
    () => (modelsData ? modelsData.models.map((m) => ({ id: m.id, name: m.label })) : []),
    [modelsData],
  );

  const matched = models.find((m) => m.id === currentModel);
  const currentModelName = matched?.name ?? currentModel;

  // Re-selecting the active model is ignored.
  const handleValueChange = (modelId: string) => {
    if (modelId === currentModel) {
      return;
    }
    onSelect(modelId);
  };

  // Chooses the menu content in priority order: loading, error, list, empty.
  const renderMenuBody = () => {
    if (isLoading) {
      return (
        <div className="flex items-center justify-center px-2 py-4">
          <Loader2 className="h-4 w-4 animate-spin text-muted-foreground" />
        </div>
      );
    }
    if (isError) {
      return (
        <div className="px-2 py-4 text-center text-sm text-destructive">
          Failed to load models
        </div>
      );
    }
    if (models.length > 0) {
      return (
        <DropdownMenuRadioGroup value={currentModel} onValueChange={handleValueChange}>
          {models.map((model) => (
            <DropdownMenuRadioItem key={model.id} value={model.id}>
              {model.name}
            </DropdownMenuRadioItem>
          ))}
        </DropdownMenuRadioGroup>
      );
    }
    return (
      <div className="px-2 py-4 text-center text-sm text-muted-foreground">
        No models available
      </div>
    );
  };

  return (
    <DropdownMenu>
      <DropdownMenuTrigger asChild>
        <Button
          variant="ghost"
          size="sm"
          className="gap-1 text-xs text-muted-foreground hover:text-foreground h-7 px-2"
          disabled={disabled || switching}
        >
          {switching && <Loader2 className="h-3 w-3 animate-spin" />}
          <span className="truncate max-w-[160px]">
            {currentModelName}
          </span>
          <ChevronDown className="h-3 w-3 opacity-50 flex-shrink-0" />
        </Button>
      </DropdownMenuTrigger>
      <DropdownMenuContent align="end" side="top" sideOffset={4}>
        {renderMenuBody()}
      </DropdownMenuContent>
    </DropdownMenu>
  );
}

0 commit comments

Comments
 (0)