diff --git a/packages/app/src/components/header.tsx b/packages/app/src/components/header.tsx index 9a620c90342..55fd5ff5ca3 100644 --- a/packages/app/src/components/header.tsx +++ b/packages/app/src/components/header.tsx @@ -59,8 +59,8 @@ export function Header(props: { when={layout.projects.list().length > 0 && params.dir} fallback={ } > @@ -121,6 +121,12 @@ export function Header(props: {
+ {/* Theme and Font first - desktop only */} + + {/* Review toggle - requires session */} + {/* Terminal toggle - always visible on desktop */} + {/* Share - requires session and share enabled */} -
) diff --git a/packages/app/src/components/welcome-screen.tsx b/packages/app/src/components/welcome-screen.tsx new file mode 100644 index 00000000000..cc2172edd2d --- /dev/null +++ b/packages/app/src/components/welcome-screen.tsx @@ -0,0 +1,240 @@ +import { createEffect, createMemo, createSignal, onCleanup, Show, For } from "solid-js" +import { createStore, reconcile } from "solid-js/store" +import { AsciiLogo } from "@opencode-ai/ui/logo" +import { Button } from "@opencode-ai/ui/button" +import { TextField } from "@opencode-ai/ui/text-field" +import { Icon } from "@opencode-ai/ui/icon" +import { normalizeServerUrl, serverDisplayName, useServer } from "@/context/server" +import { usePlatform } from "@/context/platform" +import { createOpencodeClient } from "@opencode-ai/sdk/v2/client" +import { isHostedEnvironment, hasUrlQueryParam, getUrlQueryParam } from "@/utils/hosted" + +type ServerStatus = { healthy: boolean; version?: string } + +async function checkHealth(url: string, fetch?: typeof globalThis.fetch): Promise { + const sdk = createOpencodeClient({ + baseUrl: url, + fetch, + signal: AbortSignal.timeout(3000), + }) + return sdk.global + .health() + .then((x) => ({ healthy: x.data?.healthy === true, version: x.data?.version })) + .catch(() => ({ healthy: false })) +} + +export interface WelcomeScreenProps { + attemptedUrl?: string + onRetry?: () => void +} + +export function WelcomeScreen(props: WelcomeScreenProps) { + const server = useServer() + const platform = usePlatform() + const [store, setStore] = createStore({ + url: "", + connecting: false, + error: "", + status: {} as Record, + }) + + const urlOverride = getUrlQueryParam() + const isLocalhost = () => { + const url = props.attemptedUrl || "" + return url.includes("localhost") || url.includes("127.0.0.1") + } + + const items = createMemo(() => { + const list = server.list + return list.filter((x) => x !== props.attemptedUrl) + }) + + async function refreshHealth() { + const results: Record = {} + await Promise.all( + items().map(async (url) => { + results[url] = await checkHealth(url, platform.fetch) + }), + ) + setStore("status", reconcile(results)) + } + + createEffect(() => { + if (items().length === 0) return + refreshHealth() + const interval = setInterval(refreshHealth, 10_000) + onCleanup(() => clearInterval(interval)) + }) + + async function handleConnect(url: string, persist = false) { + const normalized = normalizeServerUrl(url) + if (!normalized) return + + setStore("connecting", true) + setStore("error", "") + + const result = await checkHealth(normalized, platform.fetch) + setStore("connecting", false) + + if (!result.healthy) { + setStore("error", "Could not connect to server") + return + } + + if (persist) { + server.add(normalized) + } else { + server.setActive(normalized) + } + props.onRetry?.() + } + + async function handleSubmit(e: SubmitEvent) { + e.preventDefault() + const value = normalizeServerUrl(store.url) + if (!value) return + await handleConnect(value, true) + } + + return ( +
+
+ + +
+

Welcome to Shuvcode

+

+ {urlOverride + ? `Could not connect to the server at ${urlOverride}` + : "Connect to a Shuvcode server to get started"} +

+
+ + {/* Local Server Section */} +
+
+ +

Local Server

+
+ + +
+

Start a local server by running:

+ shuvcode +

or

+ npx shuvcode +
+
+ + +
+ + {/* Remote Server Section */} +
+
+ +

Remote Server

+
+ +
+
+
+ { + setStore("url", v) + setStore("error", "") + }} + validationState={store.error ? "invalid" : "valid"} + error={store.error} + /> +
+ +
+
+ +

+ Note: Connecting to a remote server means trusting that server with your data. +

+
+ + {/* Saved Servers Section */} + 0}> +
+

Saved Servers

+
+ + {(url) => ( + + )} + +
+
+
+ + {/* Troubleshooting Section */} + +
+ Troubleshooting +
+

+ Server not running: Make sure you have a Shuvcode server running locally or accessible + remotely. +

+

+ CORS blocked: The server must allow requests from{" "} + {location.origin}. Local servers automatically allow + this domain. +

+

+ Mixed content: If connecting to an http:// server from this{" "} + https:// page, your browser may block the connection. Use https:// for remote + servers. +

+
+
+
+ + +

Version: {platform.version}

+
+
+
+ ) +} diff --git a/packages/app/src/context/global-sync.tsx b/packages/app/src/context/global-sync.tsx index cb7bf9cf737..59704abda2b 100644 --- a/packages/app/src/context/global-sync.tsx +++ b/packages/app/src/context/global-sync.tsx @@ -25,9 +25,12 @@ import { Binary } from "@opencode-ai/util/binary" import { retry } from "@opencode-ai/util/retry" import { useGlobalSDK } from "./global-sdk" import { ErrorPage, type InitError } from "../pages/error" +import { WelcomeScreen } from "../components/welcome-screen" import { batch, createContext, useContext, onMount, type ParentProps, Switch, Match } from "solid-js" import { showToast } from "@opencode-ai/ui/toast" import { getFilename } from "@opencode-ai/util/path" +import { isHostedEnvironment } from "@/utils/hosted" +import { useServer } from "./server" type State = { ready: boolean @@ -66,17 +69,24 @@ type State = { } } +type ConnectionState = "connecting" | "ready" | "needs_config" | "error" + function createGlobalSync() { const globalSDK = useGlobalSDK() + const server = useServer() const [globalStore, setGlobalStore] = createStore<{ + connectionState: ConnectionState ready: boolean error?: InitError + attemptedUrl?: string path: Path project: Project[] provider: ProviderListResponse provider_auth: ProviderAuthResponse }>({ + connectionState: "connecting", ready: false, + attemptedUrl: undefined, path: { state: "", config: "", worktree: "", directory: "", home: "" }, project: [], provider: { all: [], connected: [], default: {} }, @@ -402,16 +412,46 @@ function createGlobalSync() { } }) + /** + * Probes the server health with a short timeout (2 seconds). + * Used for initial connection to provide quick feedback. + */ + async function probeHealth( + url: string, + healthFn: () => Promise<{ data?: { healthy?: boolean } }>, + ): Promise<{ healthy: boolean }> { + try { + const controller = new AbortController() + const timeoutId = setTimeout(() => controller.abort(), 2000) + + const result = await healthFn() + clearTimeout(timeoutId) + + return { healthy: result.data?.healthy === true } + } catch { + return { healthy: false } + } + } + async function bootstrap() { - const health = await globalSDK.client.global - .health() - .then((x) => x.data) - .catch(() => undefined) - if (!health?.healthy) { + setGlobalStore("connectionState", "connecting") + setGlobalStore("attemptedUrl", globalSDK.url) + + // Use a short timeout for the health probe (2 seconds) + const probeResult = await probeHealth(globalSDK.url, () => globalSDK.client.global.health()) + + if (!probeResult.healthy) { + // For hosted environments, show the welcome/configuration screen + if (isHostedEnvironment()) { + setGlobalStore("connectionState", "needs_config") + return + } + // For non-hosted environments, show the error page setGlobalStore( "error", new Error(`Could not connect to server. Is there a server running at \`${globalSDK.url}\`?`), ) + setGlobalStore("connectionState", "error") return } @@ -452,8 +492,14 @@ function createGlobalSync() { }), ), ]) - .then(() => setGlobalStore("ready", true)) - .catch((e) => setGlobalStore("error", e)) + .then(() => { + setGlobalStore("ready", true) + setGlobalStore("connectionState", "ready") + }) + .catch((e) => { + setGlobalStore("error", e) + setGlobalStore("connectionState", "error") + }) } onMount(() => { @@ -468,6 +514,12 @@ function createGlobalSync() { get error() { return globalStore.error }, + get connectionState() { + return globalStore.connectionState + }, + get attemptedUrl() { + return globalStore.attemptedUrl + }, child, bootstrap, project: { @@ -482,10 +534,18 @@ export function GlobalSyncProvider(props: ParentProps) { const value = createGlobalSync() return ( - + +
+
Connecting to server...
+
+
+ + value.bootstrap()} /> + + - + {props.children}
diff --git a/packages/app/src/context/layout.tsx b/packages/app/src/context/layout.tsx index f6a5adeb42a..6cceae188f0 100644 --- a/packages/app/src/context/layout.tsx +++ b/packages/app/src/context/layout.tsx @@ -1,5 +1,5 @@ import { createStore, produce } from "solid-js/store" -import { batch, createEffect, createMemo, onMount } from "solid-js" +import { batch, createEffect, createMemo, onCleanup, onMount } from "solid-js" import { createSimpleContext } from "@opencode-ai/ui/context" import { useGlobalSync } from "./global-sync" import { useGlobalSDK } from "./global-sdk" @@ -10,6 +10,12 @@ import { applyTheme, DEFAULT_THEME_ID } from "@/theme/apply-theme" import { applyFontWithLoad } from "@/fonts/apply-font" import { getFontById, FONTS } from "@/fonts/font-definitions" +export const REVIEW_PANE = { + DEFAULT_WIDTH: 450, + MIN_WIDTH: 200, + MAX_WIDTH_RATIO: 0.33, +} as const + const AVATAR_COLOR_KEYS = ["pink", "mint", "orange", "purple", "cyan", "lime"] as const export type AvatarColorKey = (typeof AVATAR_COLOR_KEYS)[number] @@ -57,7 +63,7 @@ export const { use: useLayout, provider: LayoutProvider } = createSimpleContext( review: { opened: false, state: "pane" as "pane" | "tab", - width: 450, + width: REVIEW_PANE.DEFAULT_WIDTH as number, }, session: { width: 600, @@ -127,12 +133,41 @@ export const { use: useLayout, provider: LayoutProvider } = createSimpleContext( const enriched = createMemo(() => server.projects.list().flatMap(enrich)) const list = createMemo(() => enriched().flatMap(colorize)) + // Helper function to calculate minimum session width (to enforce max review pane width) + const getMinSessionWidth = () => { + if (typeof window === "undefined") return 320 + const maxReviewWidth = window.innerWidth * REVIEW_PANE.MAX_WIDTH_RATIO + return Math.max(REVIEW_PANE.MIN_WIDTH, window.innerWidth - maxReviewWidth) + } + + // Clamp session width to enforce review pane constraints + const clampSessionWidth = () => { + const minSessionWidth = getMinSessionWidth() + if (store.session?.width && store.session.width < minSessionWidth) { + setStore("session", "width", minSessionWidth) + } + } + onMount(() => { + // Load project sessions Promise.all( server.projects.list().map((project) => { return globalSync.project.loadSessions(project.worktree) }), ) + + // Normalize persisted review state (ensure opened defaults to false for old/missing state) + if (store.review === undefined || store.review.opened === undefined) { + setStore("review", "opened", false) + } + + // Clamp session width on initial load + clampSessionWidth() + + // Re-clamp on window resize + const handleResize = () => clampSessionWidth() + window.addEventListener("resize", handleResize) + onCleanup(() => window.removeEventListener("resize", handleResize)) }) createEffect(() => { @@ -226,10 +261,13 @@ export const { use: useLayout, provider: LayoutProvider } = createSimpleContext( session: { width: createMemo(() => store.session?.width ?? 600), resize(width: number) { + // Enforce minimum session width to limit review pane to MAX_WIDTH_RATIO + const minSessionWidth = getMinSessionWidth() + const clampedWidth = Math.max(minSessionWidth, width) if (!store.session) { - setStore("session", { width }) + setStore("session", { width: clampedWidth }) } else { - setStore("session", "width", width) + setStore("session", "width", clampedWidth) } }, }, diff --git a/packages/app/src/context/server.tsx b/packages/app/src/context/server.tsx index c77b027ec7d..18e59bb3bc4 100644 --- a/packages/app/src/context/server.tsx +++ b/packages/app/src/context/server.tsx @@ -39,10 +39,11 @@ export const { use: useServer, provider: ServerProvider } = createSimpleContext( const platform = usePlatform() const [store, setStore, _, ready] = persisted( - "server.v3", + "server.v4", createStore({ list: [] as string[], projects: {} as Record, + active: "" as string, // Persist the last active server }), ) @@ -51,7 +52,10 @@ export const { use: useServer, provider: ServerProvider } = createSimpleContext( function setActive(input: string) { const url = normalizeServerUrl(input) if (!url) return - setActiveRaw(url) + batch(() => { + setActiveRaw(url) + setStore("active", url) // Persist active server + }) } function add(input: string) { @@ -60,7 +64,10 @@ export const { use: useServer, provider: ServerProvider } = createSimpleContext( const fallback = normalizeServerUrl(props.defaultUrl) if (fallback && url === fallback) { - setActiveRaw(url) + batch(() => { + setActiveRaw(url) + setStore("active", url) + }) return } @@ -69,6 +76,7 @@ export const { use: useServer, provider: ServerProvider } = createSimpleContext( setStore("list", store.list.length, url) } setActiveRaw(url) + setStore("active", url) }) } @@ -82,13 +90,17 @@ export const { use: useServer, provider: ServerProvider } = createSimpleContext( batch(() => { setStore("list", list) setActiveRaw(next) + setStore("active", next) }) } + // Initialize active server from persisted state or default createEffect(() => { if (!ready()) return if (active()) return - const url = normalizeServerUrl(props.defaultUrl) + // Priority: persisted active > default URL + const persistedActive = store.active ? normalizeServerUrl(store.active) : undefined + const url = persistedActive || normalizeServerUrl(props.defaultUrl) if (!url) return setActiveRaw(url) }) diff --git a/packages/app/src/pages/session.tsx b/packages/app/src/pages/session.tsx index c45830bea07..deabb6385d3 100644 --- a/packages/app/src/pages/session.tsx +++ b/packages/app/src/pages/session.tsx @@ -42,7 +42,7 @@ import type { DragEvent } from "@thisbeyond/solid-dnd" import type { JSX } from "solid-js" import { useSync } from "@/context/sync" import { useTerminal, type LocalPTY } from "@/context/terminal" -import { useLayout } from "@/context/layout" +import { useLayout, REVIEW_PANE } from "@/context/layout" import { getDirectory, getFilename } from "@opencode-ai/util/path" import { Terminal } from "@/components/terminal" import { checksum } from "@opencode-ai/util/encode" @@ -773,6 +773,7 @@ export default function Page() { ) } + const hasReviewContent = createMemo(() => diffs().length > 0 || tabs().all().length > 0) const showTabs = createMemo(() => layout.review.opened()) const tabsValue = createMemo(() => tabs().active() ?? "review") @@ -879,7 +880,7 @@ export default function Page() { direction="horizontal" size={layout.session.width()} min={320} - max={window.innerWidth * 0.7} + max={window.innerWidth - REVIEW_PANE.MIN_WIDTH} onResize={layout.session.resize} /> @@ -941,83 +942,96 @@ export default function Page() { - - -
- setStore("diffSplit", (x) => !x)} - > - {store.diffSplit ? "Inline" : "Split"} - - } - /> + +
+ +
No files to review
+
Changes will appear here
+
-
-
- - {(tab) => { - const [file] = createResource( - () => tab, - async (tab) => { - if (tab.startsWith("file://")) { - return local.file.node(tab.replace("file://", "")) - } - return undefined - }, - ) - return ( - - - {(content) => { - const f = file()! - const isPreviewableImage = - content.encoding === "base64" && - content.mimeType?.startsWith("image/") && - content.mimeType !== "image/svg+xml" - return ( - - -
- {f.path} -
-
- -
- -
-
-
- ) + } + > + + +
+ - - ) - }} - + diffs={diffs()} + split={store.diffSplit} + actions={ + + } + /> +
+
+
+ + {(tab) => { + const [file] = createResource( + () => tab, + async (tab) => { + if (tab.startsWith("file://")) { + return local.file.node(tab.replace("file://", "")) + } + return undefined + }, + ) + return ( + + + {(content) => { + const f = file()! + const isPreviewableImage = + content.encoding === "base64" && + content.mimeType?.startsWith("image/") && + content.mimeType !== "image/svg+xml" + return ( + + +
+ {f.path} +
+
+ +
+ +
+
+
+ ) + }} +
+
+ ) + }} +
+
diff --git a/packages/app/src/utils/hosted.ts b/packages/app/src/utils/hosted.ts new file mode 100644 index 00000000000..94c4c9b9c91 --- /dev/null +++ b/packages/app/src/utils/hosted.ts @@ -0,0 +1,25 @@ +/** + * Checks if the app is running in a hosted environment (app.shuv.ai or app.opencode.ai). + * In hosted environments, users need to configure their server connection. + */ +export function isHostedEnvironment(): boolean { + if (typeof window === "undefined") return false + return location.hostname.includes("opencode.ai") || location.hostname.includes("shuv.ai") +} + +/** + * Checks if a ?url= query parameter was provided in the URL. + * This indicates the user is trying to connect to a specific server. + */ +export function hasUrlQueryParam(): boolean { + if (typeof window === "undefined") return false + return new URLSearchParams(document.location.search).has("url") +} + +/** + * Gets the ?url= query parameter value if present. + */ +export function getUrlQueryParam(): string | null { + if (typeof window === "undefined") return null + return new URLSearchParams(document.location.search).get("url") +} diff --git a/packages/app/test/hosted.test.ts b/packages/app/test/hosted.test.ts new file mode 100644 index 00000000000..09d1620416e --- /dev/null +++ b/packages/app/test/hosted.test.ts @@ -0,0 +1,132 @@ +import { describe, expect, test, beforeEach, afterEach } from "bun:test" + +// Note: These tests require the happy-dom environment set up via bunfig.toml + +describe("hosted.ts utilities", () => { + let originalHostname: string + let originalSearch: string + + beforeEach(() => { + originalHostname = window.location.hostname + originalSearch = window.location.search + }) + + afterEach(() => { + // Reset location properties (happy-dom allows this) + Object.defineProperty(window.location, "hostname", { + value: originalHostname, + writable: true, + }) + Object.defineProperty(window.location, "search", { + value: originalSearch, + writable: true, + }) + }) + + describe("isHostedEnvironment", () => { + test("returns true for opencode.ai domains", async () => { + Object.defineProperty(window.location, "hostname", { + value: "app.opencode.ai", + writable: true, + }) + + // Dynamic import to get fresh evaluation + const { isHostedEnvironment } = await import("../src/utils/hosted") + expect(isHostedEnvironment()).toBe(true) + }) + + test("returns true for shuv.ai domains", async () => { + Object.defineProperty(window.location, "hostname", { + value: "app.shuv.ai", + writable: true, + }) + + const { isHostedEnvironment } = await import("../src/utils/hosted") + expect(isHostedEnvironment()).toBe(true) + }) + + test("returns false for localhost", async () => { + Object.defineProperty(window.location, "hostname", { + value: "localhost", + writable: true, + }) + + const { isHostedEnvironment } = await import("../src/utils/hosted") + expect(isHostedEnvironment()).toBe(false) + }) + + test("returns false for other domains", async () => { + Object.defineProperty(window.location, "hostname", { + value: "example.com", + writable: true, + }) + + const { isHostedEnvironment } = await import("../src/utils/hosted") + expect(isHostedEnvironment()).toBe(false) + }) + }) + + describe("hasUrlQueryParam", () => { + test("returns true when ?url= parameter exists", async () => { + Object.defineProperty(window.location, "search", { + value: "?url=http://localhost:4096", + writable: true, + }) + + const { hasUrlQueryParam } = await import("../src/utils/hosted") + expect(hasUrlQueryParam()).toBe(true) + }) + + test("returns false when no ?url= parameter", async () => { + Object.defineProperty(window.location, "search", { + value: "", + writable: true, + }) + + const { hasUrlQueryParam } = await import("../src/utils/hosted") + expect(hasUrlQueryParam()).toBe(false) + }) + + test("returns false when other parameters exist but not ?url=", async () => { + Object.defineProperty(window.location, "search", { + value: "?foo=bar&baz=qux", + writable: true, + }) + + const { hasUrlQueryParam } = await import("../src/utils/hosted") + expect(hasUrlQueryParam()).toBe(false) + }) + }) + + describe("getUrlQueryParam", () => { + test("returns the URL value when present", async () => { + Object.defineProperty(window.location, "search", { + value: "?url=http://localhost:4096", + writable: true, + }) + + const { getUrlQueryParam } = await import("../src/utils/hosted") + expect(getUrlQueryParam()).toBe("http://localhost:4096") + }) + + test("returns null when not present", async () => { + Object.defineProperty(window.location, "search", { + value: "", + writable: true, + }) + + const { getUrlQueryParam } = await import("../src/utils/hosted") + expect(getUrlQueryParam()).toBeNull() + }) + + test("handles URL-encoded values", async () => { + Object.defineProperty(window.location, "search", { + value: "?url=https%3A%2F%2Fmy-server.example.com%3A8080", + writable: true, + }) + + const { getUrlQueryParam } = await import("../src/utils/hosted") + expect(getUrlQueryParam()).toBe("https://my-server.example.com:8080") + }) + }) +}) diff --git a/packages/opencode/src/server/cors.ts b/packages/opencode/src/server/cors.ts new file mode 100644 index 00000000000..b791520668c --- /dev/null +++ b/packages/opencode/src/server/cors.ts @@ -0,0 +1,27 @@ +/** + * Checks if the given origin is allowed by the CORS policy. + * @param origin - The origin header value from the request + * @returns The origin string if allowed, undefined otherwise + */ +export function isOriginAllowed(origin: string | undefined): string | undefined { + if (!origin) return undefined + + // localhost (http only, any port) + if (origin.startsWith("http://localhost:")) return origin + if (origin.startsWith("http://127.0.0.1:")) return origin + + // Tauri desktop origins + if (origin === "tauri://localhost" || origin === "http://tauri.localhost") return origin + + // *.opencode.ai (https only) + if (/^https:\/\/([a-z0-9-]+\.)*opencode\.ai$/.test(origin)) { + return origin + } + + // *.shuv.ai (https only) - fork's hosted domain + if (/^https:\/\/([a-z0-9-]+\.)*shuv\.ai$/.test(origin)) { + return origin + } + + return undefined +} diff --git a/packages/opencode/src/server/server.ts b/packages/opencode/src/server/server.ts index 63ef1b647f9..9b755f5150c 100644 --- a/packages/opencode/src/server/server.ts +++ b/packages/opencode/src/server/server.ts @@ -1,3 +1,4 @@ +import { isOriginAllowed } from "./cors" import { BusEvent } from "@/bus/bus-event" import { Bus } from "@/bus" import { GlobalBus } from "@/bus/global" @@ -113,19 +114,7 @@ export namespace Server { }) .use( cors({ - origin(input) { - if (!input) return - - if (input.startsWith("http://localhost:")) return input - if (input.startsWith("http://127.0.0.1:")) return input - if (input === "tauri://localhost" || input === "http://tauri.localhost") return input - - // *.opencode.ai (https only, adjust if needed) - if (/^https:\/\/([a-z0-9-]+\.)*opencode\.ai$/.test(input)) { - return input - } - return - }, + origin: isOriginAllowed, }), ) .get( diff --git a/packages/opencode/src/session/message-v2.ts b/packages/opencode/src/session/message-v2.ts index bb78ae64ce6..171ab6937e5 100644 --- a/packages/opencode/src/session/message-v2.ts +++ b/packages/opencode/src/session/message-v2.ts @@ -161,6 +161,19 @@ export namespace MessageV2 { description: z.string(), agent: z.string(), command: z.string().optional(), + model: z + .object({ + providerID: z.string(), + modelID: z.string(), + }) + .optional(), + parentAgent: z.string().optional(), + parentModel: z + .object({ + providerID: z.string(), + modelID: z.string(), + }) + .optional(), }) export type SubtaskPart = z.infer diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index 40c44f2d07f..fdb0e76ebb5 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -90,6 +90,7 @@ export namespace SessionPrompt { noReply: z.boolean().optional(), tools: z.record(z.string(), z.boolean()).optional(), system: z.string().optional(), + variant: z.string().optional(), parts: z.array( z.discriminatedUnion("type", [ MessageV2.TextPart.omit({ @@ -283,139 +284,149 @@ export namespace SessionPrompt { }) const model = await Provider.getModel(lastUser.model.providerID, lastUser.model.modelID) - const task = tasks.pop() - // pending subtask + const subtasks = tasks.filter((t): t is Extract => t.type === "subtask") + const otherTasks = tasks.filter((t) => t.type !== "subtask") + tasks.length = 0 + tasks.push(...otherTasks) + + // pending subtasks // TODO: centralize "invoke tool" logic - if (task?.type === "subtask") { + if (subtasks.length > 0) { const taskTool = await TaskTool.init() - const assistantMessage = (await Session.updateMessage({ - id: Identifier.ascending("message"), - role: "assistant", - parentID: lastUser.id, - sessionID, - mode: task.agent, - agent: task.agent, - path: { - cwd: Instance.directory, - root: Instance.worktree, - }, - cost: 0, - tokens: { - input: 0, - output: 0, - reasoning: 0, - cache: { read: 0, write: 0 }, - }, - modelID: model.id, - providerID: model.providerID, - time: { - created: Date.now(), - }, - })) as MessageV2.Assistant - let part = (await Session.updatePart({ - id: Identifier.ascending("part"), - messageID: assistantMessage.id, - sessionID: assistantMessage.sessionID, - type: "tool", - callID: ulid(), - tool: TaskTool.id, - state: { - status: "running", - input: { - prompt: task.prompt, - description: task.description, - subagent_type: task.agent, - command: task.command, + + const executeSubtask = async (task: (typeof subtasks)[0]) => { + const assistantMessage = (await Session.updateMessage({ + id: Identifier.ascending("message"), + role: "assistant", + parentID: lastUser.id, + sessionID, + mode: task.agent, + agent: task.agent, + path: { + cwd: Instance.directory, + root: Instance.worktree, }, + cost: 0, + tokens: { + input: 0, + output: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: model.id, + providerID: model.providerID, time: { - start: Date.now(), + created: Date.now(), }, - }, - })) as MessageV2.ToolPart - const taskArgs = { - prompt: task.prompt, - description: task.description, - subagent_type: task.agent, - command: task.command, - } - await Plugin.trigger( - "tool.execute.before", - { - tool: "task", - sessionID, - callID: part.id, - }, - { args: taskArgs }, - ) - let executionError: Error | undefined - const result = await taskTool - .execute(taskArgs, { - agent: task.agent, + })) as MessageV2.Assistant + let part = (await Session.updatePart({ + id: Identifier.ascending("part"), messageID: assistantMessage.id, - sessionID: sessionID, - abort, - async metadata(input) { - await Session.updatePart({ - ...part, - type: "tool", - state: { - ...part.state, - ...input, - }, - } satisfies MessageV2.ToolPart) - }, - }) - .catch((error) => { - executionError = error - log.error("subtask execution failed", { error, agent: task.agent, description: task.description }) - return undefined - }) - await Plugin.trigger( - "tool.execute.after", - { - tool: "task", - sessionID, - callID: part.id, - }, - result, - ) - assistantMessage.finish = "tool-calls" - assistantMessage.time.completed = Date.now() - await Session.updateMessage(assistantMessage) - if (result && part.state.status === "running") { - await Session.updatePart({ - ...part, + sessionID: assistantMessage.sessionID, + type: "tool", + callID: ulid(), + tool: TaskTool.id, state: { - status: "completed", - input: part.state.input, - title: result.title, - metadata: result.metadata, - output: result.output, - attachments: result.attachments, + status: "running", + input: { + prompt: task.prompt, + description: task.description, + subagent_type: task.agent, + command: task.command, + }, time: { - ...part.state.time, - end: Date.now(), + start: Date.now(), }, }, - } satisfies MessageV2.ToolPart) - } - if (!result) { - await Session.updatePart({ - ...part, - state: { - status: "error", - error: executionError ? `Tool execution failed: ${executionError.message}` : "Tool execution failed", - time: { - start: part.state.status === "running" ? part.state.time.start : Date.now(), - end: Date.now(), + })) as MessageV2.ToolPart + const taskArgs = { + prompt: task.prompt, + description: task.description, + subagent_type: task.agent, + command: task.command, + } + await Plugin.trigger( + "tool.execute.before", + { + tool: "task", + sessionID, + callID: part.id, + }, + { args: taskArgs }, + ) + let executionError: Error | undefined + const result = await taskTool + .execute(taskArgs, { + agent: task.agent, + messageID: assistantMessage.id, + sessionID: sessionID, + abort, + extra: { model: task.model }, + async metadata(input) { + await Session.updatePart({ + ...part, + type: "tool", + state: { + ...part.state, + ...input, + }, + } satisfies MessageV2.ToolPart) }, - metadata: part.metadata, - input: part.state.input, + }) + .catch((error) => { + executionError = error + log.error("subtask execution failed", { error, agent: task.agent, description: task.description }) + return undefined + }) + await Plugin.trigger( + "tool.execute.after", + { + tool: "task", + sessionID, + callID: part.id, }, - } satisfies MessageV2.ToolPart) + result, + ) + assistantMessage.finish = "tool-calls" + assistantMessage.time.completed = Date.now() + await Session.updateMessage(assistantMessage) + if (result && part.state.status === "running") { + await Session.updatePart({ + ...part, + state: { + status: "completed", + input: part.state.input, + title: result.title, + metadata: result.metadata, + output: result.output, + attachments: result.attachments, + time: { + ...part.state.time, + end: Date.now(), + }, + }, + } satisfies MessageV2.ToolPart) + } + if (!result) { + await Session.updatePart({ + ...part, + state: { + status: "error", + error: executionError ? `Tool execution failed: ${executionError.message}` : "Tool execution failed", + time: { + start: part.state.status === "running" ? part.state.time.start : Date.now(), + end: Date.now(), + }, + metadata: part.metadata, + input: part.state.input, + }, + } satisfies MessageV2.ToolPart) + } } + await Promise.all(subtasks.map(executeSubtask)) + // Add synthetic user message to prevent certain reasoning models from erroring // If we create assistant messages w/ out user ones following mid loop thinking signatures // will be missing and it can cause errors for models like gemini for example @@ -426,8 +437,8 @@ export namespace SessionPrompt { time: { created: Date.now(), }, - agent: lastUser.agent, - model: lastUser.model, + agent: subtasks[0]?.parentAgent ?? lastUser.agent, + model: subtasks[0]?.parentModel ?? lastUser.model, } await Session.updateMessage(summaryUserMsg) await Session.updatePart({ @@ -442,6 +453,8 @@ export namespace SessionPrompt { continue } + const task = otherTasks.pop() + // pending compaction if (task?.type === "compaction") { const result = await SessionCompaction.process({ @@ -564,9 +577,8 @@ export namespace SessionPrompt { async function lastModel(sessionID: string) { for await (const item of MessageV2.stream(sessionID)) { - if (item.info.role === "user" && item.info.model) return item.info.model + if (item.info.role === "user" && item.info.model?.modelID) return item.info.model } - return Provider.defaultModel() } async function resolveTools(input: { @@ -607,7 +619,7 @@ export namespace SessionPrompt { abort: options.abortSignal!, messageID: input.processor.message.id, callID: options.toolCallId, - extra: { model: input.model }, + extra: {}, agent: input.agent.name, metadata: async (val) => { const match = input.processor.partFromToolCall(options.toolCallId) @@ -725,8 +737,9 @@ export namespace SessionPrompt { }, tools: input.tools, agent: agent.name, - model: input.model ?? agent.model ?? (await lastModel(input.sessionID)), + model: input.model ?? agent.model ?? (await lastModel(input.sessionID)) ?? (await Provider.defaultModel()), system: input.system, + variant: input.variant, } const parts = await Promise.all( @@ -1053,7 +1066,7 @@ export namespace SessionPrompt { SessionRevert.cleanup(session) } const agent = await Agent.get(input.agent) - const model = input.model ?? agent.model ?? (await lastModel(input.sessionID)) + const model = input.model ?? agent.model ?? (await lastModel(input.sessionID)) ?? (await Provider.defaultModel()) const userMsg: MessageV2.User = { id: Identifier.ascending("message"), sessionID: input.sessionID, @@ -1271,6 +1284,7 @@ export namespace SessionPrompt { model: z.string().optional(), arguments: z.string(), command: z.string(), + variant: z.string().optional(), }) export type CommandInput = z.infer const bashRegex = /!`([^`]+)`/g @@ -1395,6 +1409,7 @@ export namespace SessionPrompt { } template = template.trim() + const sessionModel = await lastModel(input.sessionID) const model = await (async () => { if (command.model) { return Provider.parseModel(command.model) @@ -1406,7 +1421,7 @@ export namespace SessionPrompt { } } if (input.model) return Provider.parseModel(input.model) - return await lastModel(input.sessionID) + return sessionModel ?? (await Provider.defaultModel()) })() try { @@ -1423,6 +1438,8 @@ export namespace SessionPrompt { throw e } const agent = await Agent.get(agentName) + const parentAgent = input.agent ?? "build" + const parentModel = input.model ? Provider.parseModel(input.model) : sessionModel const parts = (agent.mode === "subagent" && command.subtask !== false) || command.subtask === true @@ -1432,18 +1449,32 @@ export namespace SessionPrompt { agent: agent.name, description: command.description ?? "", command: input.command, + model: { providerID: model.providerID, modelID: model.modelID }, + parentAgent, + parentModel, // TODO: how can we make task tool accept a more complex input? prompt: await resolvePromptParts(template).then((x) => x.find((y) => y.type === "text")?.text ?? ""), }, ] : await resolvePromptParts(template) + await Plugin.trigger( + "command.execute.before", + { + command: input.command, + sessionID: input.sessionID, + arguments: input.arguments, + }, + { parts }, + ) + const result = (await prompt({ sessionID: input.sessionID, messageID: input.messageID, model, agent: agentName, parts, + variant: input.variant, })) as MessageV2.WithParts Bus.publish(Command.Event.Executed, { diff --git a/packages/opencode/src/tool/task.ts b/packages/opencode/src/tool/task.ts index bc7958889a6..a895b6a6946 100644 --- a/packages/opencode/src/tool/task.ts +++ b/packages/opencode/src/tool/task.ts @@ -10,6 +10,7 @@ import { SessionPrompt } from "../session/prompt" import { iife } from "@/util/iife" import { defer } from "@/util/defer" import { Config } from "../config/config" +import { Provider } from "../provider/provider" export { DESCRIPTION as TASK_DESCRIPTION } @@ -78,10 +79,11 @@ export const TaskTool = Tool.define("task", async () => { }) }) - const model = agent.model ?? { - modelID: msg.info.modelID, - providerID: msg.info.providerID, - } + const defaultModel = await Provider.defaultModel() + const extraModel = ctx.extra?.model?.modelID ? ctx.extra.model : undefined + const agentModel = agent.model?.modelID ? agent.model : undefined + const msgModel = msg.info.modelID ? { modelID: msg.info.modelID, providerID: msg.info.providerID } : undefined + const model = extraModel ?? agentModel ?? msgModel ?? defaultModel function cancel() { SessionPrompt.cancel(session.id) diff --git a/packages/opencode/test/server/cors.test.ts b/packages/opencode/test/server/cors.test.ts new file mode 100644 index 00000000000..4b96c260af8 --- /dev/null +++ b/packages/opencode/test/server/cors.test.ts @@ -0,0 +1,62 @@ +import { describe, expect, test } from "bun:test" +import { isOriginAllowed } from "../../src/server/cors" + +describe("server.cors", () => { + describe("isOriginAllowed", () => { + test("should return undefined for undefined input", () => { + expect(isOriginAllowed(undefined)).toBeUndefined() + }) + + test("should return undefined for empty string", () => { + expect(isOriginAllowed("")).toBeUndefined() + }) + + test("should allow localhost with any port", () => { + expect(isOriginAllowed("http://localhost:3000")).toBe("http://localhost:3000") + expect(isOriginAllowed("http://localhost:4096")).toBe("http://localhost:4096") + expect(isOriginAllowed("http://localhost:8080")).toBe("http://localhost:8080") + }) + + test("should allow 127.0.0.1 with any port", () => { + expect(isOriginAllowed("http://127.0.0.1:3000")).toBe("http://127.0.0.1:3000") + expect(isOriginAllowed("http://127.0.0.1:4096")).toBe("http://127.0.0.1:4096") + }) + + test("should allow Tauri origins", () => { + expect(isOriginAllowed("tauri://localhost")).toBe("tauri://localhost") + expect(isOriginAllowed("http://tauri.localhost")).toBe("http://tauri.localhost") + }) + + test("should allow *.opencode.ai origins (https only)", () => { + expect(isOriginAllowed("https://opencode.ai")).toBe("https://opencode.ai") + expect(isOriginAllowed("https://app.opencode.ai")).toBe("https://app.opencode.ai") + expect(isOriginAllowed("https://foo.opencode.ai")).toBe("https://foo.opencode.ai") + expect(isOriginAllowed("https://dev.app.opencode.ai")).toBe("https://dev.app.opencode.ai") + }) + + test("should allow *.shuv.ai origins (https only)", () => { + expect(isOriginAllowed("https://shuv.ai")).toBe("https://shuv.ai") + expect(isOriginAllowed("https://app.shuv.ai")).toBe("https://app.shuv.ai") + expect(isOriginAllowed("https://foo.shuv.ai")).toBe("https://foo.shuv.ai") + expect(isOriginAllowed("https://dev.app.shuv.ai")).toBe("https://dev.app.shuv.ai") + }) + + test("should deny http:// for opencode.ai and shuv.ai domains", () => { + expect(isOriginAllowed("http://opencode.ai")).toBeUndefined() + expect(isOriginAllowed("http://app.opencode.ai")).toBeUndefined() + expect(isOriginAllowed("http://shuv.ai")).toBeUndefined() + expect(isOriginAllowed("http://app.shuv.ai")).toBeUndefined() + }) + + test("should deny other domains", () => { + expect(isOriginAllowed("https://evil.com")).toBeUndefined() + expect(isOriginAllowed("https://example.com")).toBeUndefined() + expect(isOriginAllowed("https://fakeopencode.ai")).toBeUndefined() + expect(isOriginAllowed("https://fakeshuv.ai")).toBeUndefined() + }) + + test("should deny https localhost (not typical)", () => { + expect(isOriginAllowed("https://localhost:3000")).toBeUndefined() + }) + }) +}) diff --git a/packages/plugin/src/index.ts b/packages/plugin/src/index.ts index 26368f14611..5a4c93b29c8 100644 --- a/packages/plugin/src/index.ts +++ b/packages/plugin/src/index.ts @@ -165,6 +165,10 @@ export interface Hooks { output: { temperature: number; topP: number; topK: number; options: Record }, ) => Promise "permission.ask"?: (input: Permission, output: { status: "ask" | "deny" | "allow" }) => Promise + "command.execute.before"?: ( + input: { command: string; sessionID: string; arguments: string }, + output: { parts: Part[] }, + ) => Promise "tool.execute.before"?: ( input: { tool: string; sessionID: string; callID: string }, output: { args: any }, diff --git a/packages/sdk/js/src/v2/gen/types.gen.ts b/packages/sdk/js/src/v2/gen/types.gen.ts index 8a854c0a51f..dc4dfeae80c 100644 --- a/packages/sdk/js/src/v2/gen/types.gen.ts +++ b/packages/sdk/js/src/v2/gen/types.gen.ts @@ -422,6 +422,15 @@ export type Part = description: string agent: string command?: string + model?: { + providerID: string + modelID: string + } + parentAgent?: string + parentModel?: { + providerID: string + modelID: string + } } | ReasoningPart | FilePart @@ -1863,6 +1872,15 @@ export type SubtaskPartInput = { description: string agent: string command?: string + model?: { + providerID: string + modelID: string + } + parentAgent?: string + parentModel?: { + providerID: string + modelID: string + } } export type Command = { @@ -1871,7 +1889,7 @@ export type Command = { agent?: string model?: string template: string - type: "template" | "plugin" + type?: "template" | "plugin" subtask?: boolean sessionOnly?: boolean aliases?: Array