diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..559a2249ed --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,27 @@ +version: 2 + +updates: + + # Go dependencies (entire repo) + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "weekly" + labels: + - "dependencies" + - "go" + + # Frontend dependencies + - package-ecosystem: "npm" + directory: "/web/frontend" + schedule: + interval: "weekly" + labels: + - "dependencies" + - "frontend" + + # GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" \ No newline at end of file diff --git a/README.fr.md b/README.fr.md index d5fe873bf6..49a02fb771 100644 --- a/README.fr.md +++ b/README.fr.md @@ -991,6 +991,7 @@ Cette conception permet également le **support multi-agent** avec une sélectio | **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obtenir Clé](https://www.byteplus.com/) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obtenir une clé](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obtenir un Token](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obtenir Clé](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth uniquement | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.ja.md b/README.ja.md index 7fff46d13e..c0d27de4f4 100644 --- a/README.ja.md +++ b/README.ja.md @@ -935,6 +935,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る | **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [キーを取得](https://www.byteplus.com) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [キーを取得](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | 
`https://api-inference.modelscope.cn/v1` | OpenAI | [トークンを取得](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [キーを取得](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | カスタム | OAuthのみ | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.md b/README.md index e64daf0e4f..159ac706f4 100644 --- a/README.md +++ b/README.md @@ -1006,6 +1006,7 @@ The subagent has access to tools (message, web_search, etc.) and can communicate | `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | | `cerebras` | LLM (Cerebras direct) | [cerebras.ai](https://cerebras.ai) | | `vivgrid` | LLM (Vivgrid direct) | [vivgrid.com](https://vivgrid.com) | +| `azure` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) | ### Model Configuration (model_list) @@ -1042,6 +1043,7 @@ This design also enables **multi-agent support** with flexible provider selectio | **Vivgrid** | `vivgrid/` | `https://api.vivgrid.com/v1` | OpenAI | [Get Key](https://vivgrid.com) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Get Key](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Get Token](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Get Key](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.pt-br.md b/README.pt-br.md index 3fe24d7eaf..56946139b0 100644 --- a/README.pt-br.md +++ b/README.pt-br.md @@ -987,6 +987,7 @@ Este design também possibilita o **suporte multi-agent** com seleção flexíve | **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obter 
Chave](https://www.byteplus.com) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obter Chave](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obter Token](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obter Chave](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | Custom | Apenas OAuth | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.vi.md b/README.vi.md index 3ee0209f6c..a542d6507f 100644 --- a/README.vi.md +++ b/README.vi.md @@ -956,6 +956,7 @@ Thiết kế này cũng cho phép **hỗ trợ đa tác nhân** với lựa ch | **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Lấy Khóa](https://www.byteplus.com) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Lấy Key](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Lấy Token](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Lấy Khóa](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | Tùy chỉnh | Chỉ OAuth | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.zh.md b/README.zh.md index 66d7c5f7cc..9877ef9f4e 100644 --- a/README.zh.md +++ b/README.zh.md @@ -528,6 +528,7 @@ Agent 读取 HEARTBEAT.md | **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [获取密钥](https://www.byteplus.com) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [获取密钥](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [获取 Token](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | 
`https://{resource}.openai.azure.com` | Azure | [获取密钥](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | 自定义 | 仅 OAuth | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/config/config.example.json b/config/config.example.json index 094aa46df2..1c11cd42a9 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -53,6 +53,12 @@ "api_key": "your-modelscope-access-token", "api_base": "https://api-inference.modelscope.cn/v1" }, + { + "model_name": "azure-gpt5", + "model": "azure/my-gpt5-deployment", + "api_key": "your-azure-api-key", + "api_base": "https://your-resource.openai.azure.com" + }, { "model_name": "loadbalanced-gpt-5.4", "model": "openai/gpt-5.4", diff --git a/docs/design/steering-spec.md b/docs/design/steering-spec.md new file mode 100644 index 0000000000..0951bf864e --- /dev/null +++ b/docs/design/steering-spec.md @@ -0,0 +1,306 @@ +# Steering — Implementation Specification + +## Problem + +When the agent is running (executing a chain of tool calls), the user has no way to redirect it. They must wait for the full cycle to complete before sending a new message. This creates a poor experience when the agent takes a wrong direction — the user watches it waste time on tools that are no longer relevant. + +## Solution + +Steering introduces a **message queue** that external callers can push into at any time. The agent loop polls this queue at well-defined checkpoints. When a steering message is found, the agent: + +1. Stops executing further tools in the current batch +2. Injects the user's message into the conversation context +3. Calls the LLM again with the updated context + +The user's intent reaches the model **as soon as the current tool finishes**, not after the entire turn completes. 
+ +## Architecture Overview + +```mermaid +graph TD + subgraph External Callers + TG[Telegram] + DC[Discord] + SL[Slack] + end + + subgraph AgentLoop + BUS[MessageBus] + DRAIN[drainBusToSteering goroutine] + SQ[steeringQueue] + RLI[runLLMIteration] + TE[Tool Execution Loop] + LLM[LLM Call] + end + + TG -->|PublishInbound| BUS + DC -->|PublishInbound| BUS + SL -->|PublishInbound| BUS + + BUS -->|ConsumeInbound while busy| DRAIN + DRAIN -->|Steer| SQ + + RLI -->|1. initial poll| SQ + TE -->|2. poll after each tool| SQ + + SQ -->|pendingMessages| RLI + RLI -->|inject into context| LLM +``` + +### Bus drain mechanism + +Channels (Telegram, Discord, etc.) publish messages to the `MessageBus` via `PublishInbound`. Without additional wiring, these messages would sit in the bus buffer until the current `processMessage` finishes — meaning steering would never work for real users. + +The solution: when `Run()` starts processing a message, it spawns a **drain goroutine** (`drainBusToSteering`) that keeps consuming from the bus and calling `Steer()`. When `processMessage` returns, the drain is canceled and normal consumption resumes. + +```mermaid +sequenceDiagram + participant Bus + participant Run + participant Drain + participant AgentLoop + + Run->>Bus: ConsumeInbound() → msg + Run->>Drain: spawn drainBusToSteering(ctx) + Run->>Run: processMessage(msg) + + Note over Drain: running concurrently + + Bus-->>Drain: ConsumeInbound() → newMsg + Drain->>AgentLoop: al.transcribeAudioInMessage(ctx, newMsg) + Drain->>AgentLoop: Steer(providers.Message{Content: newMsg.Content}) + + Run->>Run: processMessage returns + Run->>Drain: cancel context + Note over Drain: exits +``` + +## Data Structures + +### steeringQueue + +A thread-safe FIFO queue, private to the `agent` package. 
+ +| Field | Type | Description | +|-------|------|-------------| +| `mu` | `sync.Mutex` | Protects all access to `queue` and `mode` | +| `queue` | `[]providers.Message` | Pending steering messages | +| `mode` | `SteeringMode` | Dequeue strategy | + +**Methods:** + +| Method | Description | +|--------|-------------| +| `push(msg) error` | Appends a message to the queue. Returns an error if the queue is full (`MaxQueueSize`) | +| `dequeue() []Message` | Removes and returns messages according to `mode`. Returns `nil` if empty | +| `len() int` | Returns the current queue length | +| `setMode(mode)` | Updates the dequeue strategy | +| `getMode() SteeringMode` | Returns the current mode | + +### SteeringMode + +| Value | Constant | Behavior | +|-------|----------|----------| +| `"one-at-a-time"` | `SteeringOneAtATime` | `dequeue()` returns only the **first** message. Remaining messages stay in the queue for subsequent polls. | +| `"all"` | `SteeringAll` | `dequeue()` drains the **entire** queue and returns all messages at once. | + +Default: `"one-at-a-time"`. + +### processOptions extension + +A new field was added to `processOptions`: + +| Field | Type | Description | +|-------|------|-------------| +| `SkipInitialSteeringPoll` | `bool` | When `true`, the initial steering poll at loop start is skipped. Used by `Continue()` to avoid double-dequeuing. | + +## Public API on AgentLoop + +| Method | Signature | Description | +|--------|-----------|-------------| +| `Steer` | `Steer(msg providers.Message) error` | Enqueues a steering message. Returns an error if the queue is full or not initialized. Thread-safe, can be called from any goroutine. | +| `SteeringMode` | `SteeringMode() SteeringMode` | Returns the current dequeue mode. | +| `SetSteeringMode` | `SetSteeringMode(mode SteeringMode)` | Changes the dequeue mode at runtime. | +| `Continue` | `Continue(ctx, sessionKey, channel, chatID) (string, error)` | Resumes an idle agent using pending steering messages. 
Returns `""` if queue is empty. | + +## Integration into the Agent Loop + +### Where steering is wired + +The steering queue lives as a field on `AgentLoop`: + +``` +AgentLoop + ├── bus + ├── cfg + ├── registry + ├── steering *steeringQueue ← new + ├── ... +``` + +It is initialized in `NewAgentLoop` from `cfg.Agents.Defaults.SteeringMode`. + +### Detailed flow through runLLMIteration + +```mermaid +sequenceDiagram + participant User + participant AgentLoop + participant runLLMIteration + participant ToolExecution + participant LLM + + User->>AgentLoop: Steer(message) + Note over AgentLoop: steeringQueue.push(message) + + Note over runLLMIteration: ── iteration starts ── + + runLLMIteration->>AgentLoop: dequeueSteeringMessages()
[initial poll] + AgentLoop-->>runLLMIteration: [] (empty, or messages) + + alt pendingMessages not empty + runLLMIteration->>runLLMIteration: inject into messages[]
save to session + end + + runLLMIteration->>LLM: Chat(messages, tools) + LLM-->>runLLMIteration: response with toolCalls[0..N] + + loop for each tool call (sequential) + ToolExecution->>ToolExecution: execute tool[i] + ToolExecution->>ToolExecution: process result,
append to messages[] + + ToolExecution->>AgentLoop: dequeueSteeringMessages() + AgentLoop-->>ToolExecution: steeringMessages + + alt steering found + opt remaining tools > 0 + Note over ToolExecution: Mark tool[i+1..N-1] as
"Skipped due to queued user message." + end + Note over ToolExecution: steeringAfterTools = steeringMessages + Note over ToolExecution: break out of tool loop + end + end + + alt steeringAfterTools not empty + ToolExecution-->>runLLMIteration: pendingMessages = steeringAfterTools + Note over runLLMIteration: next iteration will inject
these before calling LLM + end + + Note over runLLMIteration: ── loop back to iteration start ── +``` + +### Polling checkpoints + +| # | Location | When | Purpose | +|---|----------|------|---------| +| 1 | Top of `runLLMIteration`, before first LLM call | Once, at loop entry | Catch messages enqueued while the agent was still setting up context | +| 2 | After every tool completes (including the first and the last) | Immediately after each tool's result is processed | Interrupt the batch as early as possible — if steering is found and there are remaining tools, they are all skipped | + +### What happens to skipped tools + +When steering interrupts a tool batch after tool `[i]` completes, all tools from `[i+1]` to `[N-1]` are **not executed**. Instead, a tool result message is generated for each: + +```json +{ + "role": "tool", + "content": "Skipped due to queued user message.", + "tool_call_id": "" +} +``` + +These results are: +- Appended to the conversation `messages[]` +- Saved to the session via `AddFullMessage` + +This ensures the LLM knows which of its requested actions were not performed. + +### Loop condition change + +The iteration loop condition was changed from: + +```go +for iteration < agent.MaxIterations +``` + +to: + +```go +for iteration < agent.MaxIterations || len(pendingMessages) > 0 +``` + +This allows **one extra iteration** when steering arrives right at the max iteration boundary, ensuring the steering message is always processed. + +### Tool execution: parallel → sequential + +**Before steering:** all tool calls in a batch were executed in parallel using `sync.WaitGroup`. + +**After steering:** tool calls execute **sequentially**. This is required because steering must be polled between individual tool completions. A parallel execution model would not allow interrupting mid-batch. + +> **Trade-off:** This introduces latency when the LLM requests multiple independent tools in a single turn. 
In practice, most batches contain 1-2 tools, so the impact is minimal. The benefit of being able to interrupt outweighs the cost. + +### Why skip remaining tools (instead of letting them finish) + +Two strategies were considered when a steering message is detected mid-batch: + +1. **Skip remaining tools** (chosen) — stop executing, mark the rest as skipped, inject steering +2. **Finish all tools, then inject** — let everything run, append steering afterwards + +Strategy 2 was rejected for three reasons: + +**Irreversible side effects.** Tools can send emails, write files, spawn subagents, or call external APIs. If the user says "stop" or "change direction", those actions have already happened and cannot be undone. + +| Tool batch | Steering | Skip (1) | Finish (2) | +|---|---|---|---| +| `[search, send_email]` | "don't send it" | Email not sent | Email sent | +| `[query, write_file, spawn]` | "wrong database" | Only query runs | File + subagent wasted | +| `[fetch₁, fetch₂, fetch₃, write]` | topic change | 1 fetch | 3 fetches + write, all discarded | + +**Wasted latency.** Tools like web fetches and API calls take seconds each. In a 3-tool batch averaging 3-4s per tool, the user would wait 10+ seconds for work that gets thrown away. + +**The LLM retains full awareness.** Skipped tools receive an explicit `"Skipped due to queued user message."` result, so the model knows what was not done and can decide whether to re-execute with the new context or take a different path. + +## The Continue() method + +`Continue` handles the case where the agent is **idle** (its last message was from the assistant) and the user has enqueued steering messages in the meantime. + +```mermaid +flowchart TD + A[Continue called] --> B{dequeueSteeringMessages} + B -->|empty| C["return ('', nil)"] + B -->|messages found| D[Combine message contents] + D --> E["runAgentLoop with
SkipInitialSteeringPoll: true"] + E --> F[Return response] +``` + +**Why `SkipInitialSteeringPoll: true`?** Because `Continue` already dequeued the messages itself. Without this flag, `runLLMIteration` would poll again at the start and find nothing (the queue is already empty), or worse, double-process if new messages arrived in the meantime. + +## Configuration + +```json +{ + "agents": { + "defaults": { + "steering_mode": "one-at-a-time" + } + } +} +``` + +| Field | Type | Default | Env var | +|-------|------|---------|---------| +| `steering_mode` | `string` | `"one-at-a-time"` | `PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE` | + + +## Design decisions and trade-offs + +| Decision | Rationale | +|----------|-----------| +| Sequential tool execution | Required for per-tool steering polls. Parallel execution cannot be interrupted mid-batch. | +| Polling-based (not channel/signal) | Keeps the implementation simple. No need for `select` or signal channels. The polling cost is negligible (mutex lock + slice length check). | +| `one-at-a-time` as default | Gives the model a chance to react to each steering message individually. More predictable behavior than dumping all messages at once. | +| Skipped tools get explicit error results | The LLM protocol requires a tool result for every tool call in the assistant message. Omitting them would cause API errors. The skip message also informs the model about what was not done. | +| `Continue()` uses `SkipInitialSteeringPoll` | Prevents race conditions and double-dequeuing when resuming an idle agent. | +| Queue stored on `AgentLoop`, not `AgentInstance` | Steering is a loop-level concern (it affects the iteration flow), not a per-agent concern. All agents share the same steering queue since `processMessage` is sequential. | +| Bus drain goroutine in `Run()` | Channels (Telegram, Discord, etc.) publish to the bus via `PublishInbound`. 
Without the drain, messages would queue in the bus channel buffer and only be consumed after `processMessage` returns — defeating the purpose of steering. The drain goroutine bridges the gap by consuming new bus messages and calling `Steer()` while the agent is busy. | +| Audio transcription before steering | The drain goroutine calls `al.transcribeAudioInMessage(ctx, msg)` before steering, so voice messages are converted to text before the agent sees them. If transcription fails, the error is silently discarded and the original message is steered as-is. | +| `MaxQueueSize = 10` | Prevents unbounded memory growth if a user sends many messages while the agent is busy. Excess messages are dropped with a warning. | diff --git a/docs/steering.md b/docs/steering.md new file mode 100644 index 0000000000..ad08f84250 --- /dev/null +++ b/docs/steering.md @@ -0,0 +1,166 @@ +# Steering + +Steering allows injecting messages into an already-running agent loop, interrupting it between tool calls without waiting for the entire cycle to complete. + +## How it works + +When the agent is executing a sequence of tool calls (e.g. the model requested 3 tools in a single turn), steering checks the queue **after each tool** completes. If it finds queued messages: + +1. The remaining tools are **skipped** and receive `"Skipped due to queued user message."` as their result +2. The steering messages are **injected into the conversation context** +3. The model is called again with the updated context, including the user's steering message + +``` +User ──► Steer("change approach") + │ +Agent Loop ▼ + ├─ tool[0] ✔ (executed) + ├─ [polling] → steering found! 
+ ├─ tool[1] ✘ (skipped) + ├─ tool[2] ✘ (skipped) + └─ new LLM turn with steering message +``` + +## Configuration + +In `config.json`, under `agents.defaults`: + +```json +{ + "agents": { + "defaults": { + "steering_mode": "one-at-a-time" + } + } +} +``` + +### Modes + +| Value | Behavior | +|-------|----------| +| `"one-at-a-time"` | **(default)** Dequeues only one message per polling cycle. If there are 3 messages in the queue, they are processed one at a time across 3 successive iterations. | +| `"all"` | Drains the entire queue in a single poll. All pending messages are injected into the context together. | + +The environment variable `PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE` can be used as an alternative. + +## Go API + +### Steer — Send a steering message + +```go +err := agentLoop.Steer(providers.Message{ + Role: "user", + Content: "change direction, focus on X instead", +}) +if err != nil { + // Queue is full (MaxQueueSize=10) or not initialized +} +``` + +The message is enqueued in a thread-safe manner. Returns an error if the queue is full or not initialized. It will be picked up at the next polling point (after the current tool finishes). + +### SteeringMode / SetSteeringMode + +```go +// Read the current mode +mode := agentLoop.SteeringMode() // SteeringOneAtATime | SteeringAll + +// Change it at runtime +agentLoop.SetSteeringMode(agent.SteeringAll) +``` + +### Continue — Resume an idle agent + +When the agent is idle (it has finished processing and its last message was from the assistant), `Continue` checks if there are steering messages in the queue and uses them to start a new cycle: + +```go +response, err := agentLoop.Continue(ctx, sessionKey, channel, chatID) +if err != nil { + // Error (e.g. 
"no default agent available") +} +if response == "" { + // No steering messages in queue, the agent stays idle +} +``` + +`Continue` internally uses `SkipInitialSteeringPoll: true` to avoid double-dequeuing the same messages (since it already extracted them and passes them directly as input). + +## Polling points in the loop + +Steering is checked at **two points** in the agent cycle: + +1. **At loop start** — before the first LLM call, to catch messages enqueued during setup +2. **After every tool completes** — including the first and the last. If steering is found and there are remaining tools, they are all skipped immediately + +## Why remaining tools are skipped + +When a steering message is detected, all remaining tools in the batch are skipped rather than executed. The alternative — let all tools finish and inject the steering message afterwards — was considered and rejected. Here is why. + +### Preventing unwanted side effects + +Tools can have **irreversible side effects**. If the user says "no, wait" while the agent is mid-batch, executing the remaining tools means those side effects happen anyway: + +| Tool batch | Steering message | With skip | Without skip | +|---|---|---|---| +| `[web_search, send_email]` | "don't send it" | Email **not** sent | Email sent, damage done | +| `[query_db, write_file, spawn_agent]` | "use another database" | Only the query runs | File written + subagent spawned, all wasted | +| `[search₁, search₂, search₃, write_file]` | user changes topic entirely | 1 search | 3 searches + file write, all irrelevant | + +### Avoiding wasted time + +Tools that take seconds (web fetches, API calls, database queries) would all run to completion before the agent sees the user's correction. In a batch of 3 tools each taking 3-4 seconds, that's 10+ seconds of work that will be discarded. + +With skipping, the agent reacts as soon as the current tool finishes — typically within a few seconds instead of waiting for the entire batch. 
+ +### The LLM gets full context + +Skipped tools receive an explicit error result (`"Skipped due to queued user message."`), so the model knows exactly which actions were not performed. It can then decide whether to re-execute them with the new context, or take a different path entirely. + +### Trade-off: sequential execution + +Skipping requires tools to run **sequentially** (the previous implementation ran them in parallel). This introduces latency when the LLM requests multiple independent tools in a single turn. In practice, most batches contain 1-2 tools, so the impact is minimal compared to the benefit of being able to stop unwanted actions. + +## Skipped tool result format + +When steering interrupts a batch, each tool that was not executed receives a `tool` result with: + +``` +Content: "Skipped due to queued user message." +``` + +This is saved to the session via `AddFullMessage` and sent to the model, so it is aware that some requested actions were not performed. + +## Full flow example + +``` +1. User: "search for info on X, write a file, and send me a message" + +2. LLM responds with 3 tool calls: [web_search, write_file, message] + +3. web_search is executed → result saved + +4. [polling] → User called Steer("no, search for Y instead") + +5. write_file is skipped → "Skipped due to queued user message." + message is skipped → "Skipped due to queued user message." + +6. Message "search for Y instead" injected into context + +7. LLM receives the full updated context and responds accordingly +``` + +## Automatic bus drain + +When the agent loop (`Run()`) starts processing a message, it spawns a background goroutine that keeps consuming new inbound messages from the bus. These messages are automatically redirected into the steering queue via `Steer()`. This means: + +- Users on any channel (Telegram, Discord, etc.) 
don't need to do anything special — their messages are automatically captured as steering when the agent is busy +- Audio messages are transcribed before being steered, so the agent receives text. If transcription fails, the original (non-transcribed) message is steered as-is +- When `processMessage` finishes, the drain goroutine is canceled and normal message consumption resumes + +## Notes + +- Steering **does not interrupt** a tool that is currently executing. It waits for the current tool to finish, then checks the queue. +- With `one-at-a-time` mode, if multiple messages are enqueued rapidly, they will be processed one per iteration. This gives the model the opportunity to react to each message individually. +- With `all` mode, all pending messages are combined into a single injection. Useful when you want the agent to receive all the context at once. +- The steering queue has a maximum capacity of 10 messages (`MaxQueueSize`). `Steer()` returns an error when the queue is full. In the bus drain path, the error is logged as a warning and the message is effectively dropped. 
diff --git a/pkg/agent/eventbus_mock.go b/pkg/agent/eventbus_mock.go new file mode 100644 index 0000000000..c9641092be --- /dev/null +++ b/pkg/agent/eventbus_mock.go @@ -0,0 +1,12 @@ +package agent + +import "fmt" + +// MockEventBus - for POC +var MockEventBus = struct { + Emit func(event any) +}{ + Emit: func(event any) { + fmt.Printf("[Mock EventBus] %T %+v\n", event, event) + }, +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f20a56b9c4..21516e7de9 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -48,6 +48,7 @@ type AgentLoop struct { transcriber voice.Transcriber cmdRegistry *commands.Registry mcp mcpRuntime + steering *steeringQueue mu sync.RWMutex // Track active requests for safe provider cleanup activeRequests sync.WaitGroup @@ -55,15 +56,16 @@ type AgentLoop struct { // processOptions configures how a message is processed type processOptions struct { - SessionKey string // Session identifier for history/context - Channel string // Target channel for tool execution - ChatID string // Target chat ID for tool execution - UserMessage string // User message content (may include prefix) - Media []string // media:// refs from inbound message - DefaultResponse string // Response when LLM returns empty - EnableSummary bool // Whether to trigger summarization - SendResponse bool // Whether to send response via bus - NoHistory bool // If true, don't load session history (for heartbeat) + SessionKey string // Session identifier for history/context + Channel string // Target channel for tool execution + ChatID string // Target chat ID for tool execution + UserMessage string // User message content (may include prefix) + Media []string // media:// refs from inbound message + DefaultResponse string // Response when LLM returns empty + EnableSummary bool // Whether to trigger summarization + SendResponse bool // Whether to send response via bus + NoHistory bool // If true, don't load session history (for heartbeat) + SkipInitialSteeringPoll bool // If 
true, skip the steering poll at loop start (used by Continue) } const ( @@ -105,6 +107,7 @@ func NewAgentLoop( summarizing: sync.Map{}, fallback: fallbackChain, cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()), + steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)), } return al @@ -257,6 +260,13 @@ func (al *AgentLoop) Run(ctx context.Context) error { continue } + // Start a goroutine that drains the bus while processMessage is + // running. Any inbound messages that arrive during processing are + // redirected into the steering queue so the agent loop can pick + // them up between tool calls. + drainCtx, drainCancel := context.WithCancel(ctx) + go al.drainBusToSteering(drainCtx) + // Process message func() { // TODO: Re-enable media cleanup after inbound media is properly consumed by the agent. @@ -272,6 +282,8 @@ func (al *AgentLoop) Run(ctx context.Context) error { // } // }() + defer drainCancel() + response, err := al.processMessage(ctx, msg) if err != nil { response = fmt.Sprintf("Error processing message: %v", err) @@ -318,6 +330,39 @@ func (al *AgentLoop) Run(ctx context.Context) error { return nil } +// drainBusToSteering continuously consumes inbound messages and redirects +// them into the steering queue. It runs in a goroutine while processMessage +// is active and stops when drainCtx is canceled (i.e., processMessage returns). +func (al *AgentLoop) drainBusToSteering(ctx context.Context) { + for { + msg, ok := al.bus.ConsumeInbound(ctx) + if !ok { + return + } + + // Transcribe audio if needed before steering, so the agent sees text. 
+ msg, _ = al.transcribeAudioInMessage(ctx, msg) + + logger.InfoCF("agent", "Redirecting inbound message to steering queue", + map[string]any{ + "channel": msg.Channel, + "sender_id": msg.SenderID, + "content_len": len(msg.Content), + }) + + if err := al.Steer(providers.Message{ + Role: "user", + Content: msg.Content, + }); err != nil { + logger.WarnCF("agent", "Failed to steer message, will be lost", + map[string]any{ + "error": err.Error(), + "channel": msg.Channel, + }) + } + } +} + func (al *AgentLoop) Stop() { al.running.Store(false) } @@ -999,6 +1044,16 @@ func (al *AgentLoop) runLLMIteration( ) (string, int, error) { iteration := 0 var finalContent string + var pendingMessages []providers.Message + + // Poll for steering messages at loop start (in case the user typed while + // the agent was setting up), unless the caller already provided initial + // steering messages (e.g. Continue). + if !opts.SkipInitialSteeringPoll { + if msgs := al.dequeueSteeringMessages(); len(msgs) > 0 { + pendingMessages = msgs + } + } // Determine effective model tier for this conversation turn. // selectCandidates evaluates routing once and the decision is sticky for @@ -1006,9 +1061,25 @@ func (al *AgentLoop) runLLMIteration( // tool chain doesn't switch models mid-way through. activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages) - for iteration < agent.MaxIterations { + for iteration < agent.MaxIterations || len(pendingMessages) > 0 { iteration++ + // Inject pending steering messages into the conversation context + // before the next LLM call. 
+ if len(pendingMessages) > 0 { + for _, pm := range pendingMessages { + messages = append(messages, pm) + agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content) + logger.InfoCF("agent", "Injected steering message into context", + map[string]any{ + "agent_id": agent.ID, + "iteration": iteration, + "content_len": len(pm.Content), + }) + } + pendingMessages = nil + } + logger.DebugCF("agent", "LLM iteration", map[string]any{ "agent_id": agent.ID, @@ -1251,107 +1322,83 @@ func (al *AgentLoop) runLLMIteration( // Save assistant message with tool calls to session agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) - // Execute tool calls in parallel - type indexedAgentResult struct { - result *tools.ToolResult - tc providers.ToolCall - } - - agentResults := make([]indexedAgentResult, len(normalizedToolCalls)) - var wg sync.WaitGroup + // Execute tool calls sequentially. After each tool completes, check + // for steering messages. If any are found, skip remaining tools. + var steeringAfterTools []providers.Message for i, tc := range normalizedToolCalls { - agentResults[i].tc = tc - - wg.Add(1) - go func(idx int, tc providers.ToolCall) { - defer wg.Done() + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]any{ + "agent_id": agent.ID, + "tool": tc.Name, + "iteration": iteration, + }) - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := utils.Truncate(string(argsJSON), 200) - logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), - map[string]any{ - "agent_id": agent.ID, - "tool": tc.Name, - "iteration": iteration, + // Create async callback for tools that implement AsyncExecutor. 
+ asyncCallback := func(_ context.Context, result *tools.ToolResult) { + if !result.Silent && result.ForUser != "" { + outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer outCancel() + _ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: result.ForUser, }) + } - // Create async callback for tools that implement AsyncExecutor. - // When the background work completes, this publishes the result - // as an inbound system message so processSystemMessage routes it - // back to the user via the normal agent loop. - asyncCallback := func(_ context.Context, result *tools.ToolResult) { - // Send ForUser content directly to the user (immediate feedback), - // mirroring the synchronous tool execution path. - if !result.Silent && result.ForUser != "" { - outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer outCancel() - _ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, - Content: result.ForUser, - }) - } - - // Determine content for the agent loop (ForLLM or error). 
- content := result.ForLLM - if content == "" && result.Err != nil { - content = result.Err.Error() - } - if content == "" { - return - } - - logger.InfoCF("agent", "Async tool completed, publishing result", - map[string]any{ - "tool": tc.Name, - "content_len": len(content), - "channel": opts.Channel, - }) + content := result.ForLLM + if content == "" && result.Err != nil { + content = result.Err.Error() + } + if content == "" { + return + } - pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer pubCancel() - _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{ - Channel: "system", - SenderID: fmt.Sprintf("async:%s", tc.Name), - ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID), - Content: content, + logger.InfoCF("agent", "Async tool completed, publishing result", + map[string]any{ + "tool": tc.Name, + "content_len": len(content), + "channel": opts.Channel, }) - } - toolResult := agent.Tools.ExecuteWithContext( - ctx, - tc.Name, - tc.Arguments, - opts.Channel, - opts.ChatID, - asyncCallback, - ) - agentResults[idx].result = toolResult - }(i, tc) - } - wg.Wait() + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{ + Channel: "system", + SenderID: fmt.Sprintf("async:%s", tc.Name), + ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID), + Content: content, + }) + } + + toolResult := agent.Tools.ExecuteWithContext( + ctx, + tc.Name, + tc.Arguments, + opts.Channel, + opts.ChatID, + asyncCallback, + ) - // Process results in original order (send to user, save to session) - for _, r := range agentResults { - // Send ForUser content to user immediately if not Silent - if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse { + // Process tool result + if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, - 
Content: r.result.ForUser, + Content: toolResult.ForUser, }) logger.DebugCF("agent", "Sent tool result to user", map[string]any{ - "tool": r.tc.Name, - "content_len": len(r.result.ForUser), + "tool": tc.Name, + "content_len": len(toolResult.ForUser), }) } - // If tool returned media refs, publish them as outbound media - if len(r.result.Media) > 0 { - parts := make([]bus.MediaPart, 0, len(r.result.Media)) - for _, ref := range r.result.Media { + if len(toolResult.Media) > 0 { + parts := make([]bus.MediaPart, 0, len(toolResult.Media)) + for _, ref := range toolResult.Media { part := bus.MediaPart{Ref: ref} if al.mediaStore != nil { if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { @@ -1369,21 +1416,55 @@ func (al *AgentLoop) runLLMIteration( }) } - // Determine content for LLM based on tool result - contentForLLM := r.result.ForLLM - if contentForLLM == "" && r.result.Err != nil { - contentForLLM = r.result.Err.Error() + contentForLLM := toolResult.ForLLM + if contentForLLM == "" && toolResult.Err != nil { + contentForLLM = toolResult.Err.Error() } toolResultMsg := providers.Message{ Role: "tool", Content: contentForLLM, - ToolCallID: r.tc.ID, + ToolCallID: tc.ID, } messages = append(messages, toolResultMsg) - - // Save tool result message to session agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + + // After EVERY tool (including the first and last), check for + // steering messages. If found and there are remaining tools, + // skip them all. 
+ if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { + remaining := len(normalizedToolCalls) - i - 1 + if remaining > 0 { + logger.InfoCF("agent", "Steering interrupt: skipping remaining tools", + map[string]any{ + "agent_id": agent.ID, + "completed": i + 1, + "skipped": remaining, + "total_tools": len(normalizedToolCalls), + "steering_count": len(steerMsgs), + }) + + // Mark remaining tool calls as skipped + for j := i + 1; j < len(normalizedToolCalls); j++ { + skippedTC := normalizedToolCalls[j] + toolResultMsg := providers.Message{ + Role: "tool", + Content: "Skipped due to queued user message.", + ToolCallID: skippedTC.ID, + } + messages = append(messages, toolResultMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + } + } + steeringAfterTools = steerMsgs + break + } + } + + // If steering messages were captured during tool execution, they + // become pendingMessages for the next iteration of the inner loop. + if len(steeringAfterTools) > 0 { + pendingMessages = steeringAfterTools } // Tick down TTL of discovered tools after processing tool results. diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go new file mode 100644 index 0000000000..8c7c79c160 --- /dev/null +++ b/pkg/agent/steering.go @@ -0,0 +1,188 @@ +package agent + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// SteeringMode controls how queued steering messages are dequeued. +type SteeringMode string + +const ( + // SteeringOneAtATime dequeues only the first queued message per poll. + SteeringOneAtATime SteeringMode = "one-at-a-time" + // SteeringAll drains the entire queue in a single poll. + SteeringAll SteeringMode = "all" + // MaxQueueSize number of possible messages in the Steering Queue + MaxQueueSize = 10 +) + +// parseSteeringMode normalizes a config string into a SteeringMode. 
+func parseSteeringMode(s string) SteeringMode { + switch s { + case "all": + return SteeringAll + default: + return SteeringOneAtATime + } +} + +// steeringQueue is a thread-safe queue of user messages that can be injected +// into a running agent loop to interrupt it between tool calls. +type steeringQueue struct { + mu sync.Mutex + queue []providers.Message + mode SteeringMode +} + +func newSteeringQueue(mode SteeringMode) *steeringQueue { + return &steeringQueue{ + mode: mode, + } +} + +// push enqueues a steering message. +func (sq *steeringQueue) push(msg providers.Message) error { + sq.mu.Lock() + defer sq.mu.Unlock() + if len(sq.queue) >= MaxQueueSize { + return fmt.Errorf("steering queue is full") + } + sq.queue = append(sq.queue, msg) + return nil +} + +// dequeue removes and returns pending steering messages according to the +// configured mode. Returns nil when the queue is empty. +func (sq *steeringQueue) dequeue() []providers.Message { + sq.mu.Lock() + defer sq.mu.Unlock() + + if len(sq.queue) == 0 { + return nil + } + + switch sq.mode { + case SteeringAll: + msgs := sq.queue + sq.queue = nil + return msgs + default: // one-at-a-time + msg := sq.queue[0] + sq.queue[0] = providers.Message{} // Clear reference for GC + sq.queue = sq.queue[1:] + return []providers.Message{msg} + } +} + +// len returns the number of queued messages. +func (sq *steeringQueue) len() int { + sq.mu.Lock() + defer sq.mu.Unlock() + return len(sq.queue) +} + +// setMode updates the steering mode. +func (sq *steeringQueue) setMode(mode SteeringMode) { + sq.mu.Lock() + defer sq.mu.Unlock() + sq.mode = mode +} + +// getMode returns the current steering mode. +func (sq *steeringQueue) getMode() SteeringMode { + sq.mu.Lock() + defer sq.mu.Unlock() + return sq.mode +} + +// --- AgentLoop steering API --- + +// Steer enqueues a user message to be injected into the currently running +// agent loop. 
The message will be picked up after the current tool finishes +// executing, causing any remaining tool calls in the batch to be skipped. +func (al *AgentLoop) Steer(msg providers.Message) error { + if al.steering == nil { + return fmt.Errorf("steering queue is not initialized") + } + if err := al.steering.push(msg); err != nil { + logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{ + "error": err.Error(), + "role": msg.Role, + }) + return err + } + logger.DebugCF("agent", "Steering message enqueued", map[string]any{ + "role": msg.Role, + "content_len": len(msg.Content), + "queue_len": al.steering.len(), + }) + + return nil +} + +// SteeringMode returns the current steering mode. +func (al *AgentLoop) SteeringMode() SteeringMode { + if al.steering == nil { + return SteeringOneAtATime + } + return al.steering.getMode() +} + +// SetSteeringMode updates the steering mode. +func (al *AgentLoop) SetSteeringMode(mode SteeringMode) { + if al.steering == nil { + return + } + al.steering.setMode(mode) +} + +// dequeueSteeringMessages is the internal method called by the agent loop +// to poll for steering messages. Returns nil when no messages are pending. +func (al *AgentLoop) dequeueSteeringMessages() []providers.Message { + if al.steering == nil { + return nil + } + return al.steering.dequeue() +} + +// Continue resumes an idle agent by dequeuing any pending steering messages +// and running them through the agent loop. This is used when the agent's last +// message was from the assistant (i.e., it has stopped processing) and the +// user has since enqueued steering messages. +// +// If no steering messages are pending, it returns an empty string. 
+func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) { + steeringMsgs := al.dequeueSteeringMessages() + if len(steeringMsgs) == 0 { + return "", nil + } + + agent := al.GetRegistry().GetDefaultAgent() + if agent == nil { + return "", fmt.Errorf("no default agent available") + } + + // Build a combined user message from the steering messages. + var contents []string + for _, msg := range steeringMsgs { + contents = append(contents, msg.Content) + } + combinedContent := strings.Join(contents, "\n") + + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: sessionKey, + Channel: channel, + ChatID: chatID, + UserMessage: combinedContent, + DefaultResponse: defaultResponse, + EnableSummary: true, + SendResponse: false, + SkipInitialSteeringPoll: true, + }) +} diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go new file mode 100644 index 0000000000..e8cdb23449 --- /dev/null +++ b/pkg/agent/steering_test.go @@ -0,0 +1,744 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "os" + "sync" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// --- steeringQueue unit tests --- + +func TestSteeringQueue_PushDequeue_OneAtATime(t *testing.T) { + sq := newSteeringQueue(SteeringOneAtATime) + + sq.push(providers.Message{Role: "user", Content: "msg1"}) + sq.push(providers.Message{Role: "user", Content: "msg2"}) + sq.push(providers.Message{Role: "user", Content: "msg3"}) + + if sq.len() != 3 { + t.Fatalf("expected 3 messages, got %d", sq.len()) + } + + msgs := sq.dequeue() + if len(msgs) != 1 { + t.Fatalf("expected 1 message in one-at-a-time mode, got %d", len(msgs)) + } + if msgs[0].Content != "msg1" { + t.Fatalf("expected 'msg1', got %q", msgs[0].Content) + } + if sq.len() != 2 { + t.Fatalf("expected 2 remaining, got %d", sq.len()) + } + + msgs 
= sq.dequeue() + if len(msgs) != 1 || msgs[0].Content != "msg2" { + t.Fatalf("expected 'msg2', got %v", msgs) + } + + msgs = sq.dequeue() + if len(msgs) != 1 || msgs[0].Content != "msg3" { + t.Fatalf("expected 'msg3', got %v", msgs) + } + + msgs = sq.dequeue() + if msgs != nil { + t.Fatalf("expected nil from empty queue, got %v", msgs) + } +} + +func TestSteeringQueue_PushDequeue_All(t *testing.T) { + sq := newSteeringQueue(SteeringAll) + + sq.push(providers.Message{Role: "user", Content: "msg1"}) + sq.push(providers.Message{Role: "user", Content: "msg2"}) + sq.push(providers.Message{Role: "user", Content: "msg3"}) + + msgs := sq.dequeue() + if len(msgs) != 3 { + t.Fatalf("expected 3 messages in all mode, got %d", len(msgs)) + } + if msgs[0].Content != "msg1" || msgs[1].Content != "msg2" || msgs[2].Content != "msg3" { + t.Fatalf("unexpected messages: %v", msgs) + } + + if sq.len() != 0 { + t.Fatalf("expected 0 remaining, got %d", sq.len()) + } + + msgs = sq.dequeue() + if msgs != nil { + t.Fatalf("expected nil from empty queue, got %v", msgs) + } +} + +func TestSteeringQueue_EmptyDequeue(t *testing.T) { + sq := newSteeringQueue(SteeringOneAtATime) + if msgs := sq.dequeue(); msgs != nil { + t.Fatalf("expected nil, got %v", msgs) + } +} + +func TestSteeringQueue_SetMode(t *testing.T) { + sq := newSteeringQueue(SteeringOneAtATime) + if sq.getMode() != SteeringOneAtATime { + t.Fatalf("expected one-at-a-time, got %v", sq.getMode()) + } + + sq.setMode(SteeringAll) + if sq.getMode() != SteeringAll { + t.Fatalf("expected all, got %v", sq.getMode()) + } + + // Push two messages and verify all-mode drains them + sq.push(providers.Message{Role: "user", Content: "a"}) + sq.push(providers.Message{Role: "user", Content: "b"}) + + msgs := sq.dequeue() + if len(msgs) != 2 { + t.Fatalf("expected 2 messages after mode switch, got %d", len(msgs)) + } +} + +func TestSteeringQueue_ConcurrentAccess(t *testing.T) { + sq := newSteeringQueue(SteeringOneAtATime) + + var wg sync.WaitGroup + 
const n = MaxQueueSize + + // Push from multiple goroutines + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)}) + }(i) + } + wg.Wait() + + if sq.len() != n { + t.Fatalf("expected %d messages, got %d", n, sq.len()) + } + + // Drain from multiple goroutines + var drained int + var mu sync.Mutex + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if msgs := sq.dequeue(); len(msgs) > 0 { + mu.Lock() + drained += len(msgs) + mu.Unlock() + } + }() + } + wg.Wait() + + if drained != n { + t.Fatalf("expected to drain %d messages, got %d", n, drained) + } +} + +func TestSteeringQueue_Overflow(t *testing.T) { + sq := newSteeringQueue(SteeringOneAtATime) + + // Fill the queue up to its maximum capacity + for i := 0; i < MaxQueueSize; i++ { + err := sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)}) + if err != nil { + t.Fatalf("unexpected error pushing message %d: %v", i, err) + } + } + + // Sanity check: ensure the queue is actually full + if sq.len() != MaxQueueSize { + t.Fatalf("expected queue length %d, got %d", MaxQueueSize, sq.len()) + } + + // Attempt to push one more message, which MUST fail + err := sq.push(providers.Message{Role: "user", Content: "overflow_msg"}) + + // Assert the error happened and is the exact one we expect + if err == nil { + t.Fatal("expected an error when pushing to a full queue, but got nil") + } + + expectedErr := "steering queue is full" + if err.Error() != expectedErr { + t.Errorf("expected error message %q, got %q", expectedErr, err.Error()) + } +} + +func TestParseSteeringMode(t *testing.T) { + tests := []struct { + input string + expected SteeringMode + }{ + {"", SteeringOneAtATime}, + {"one-at-a-time", SteeringOneAtATime}, + {"all", SteeringAll}, + {"unknown", SteeringOneAtATime}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + if got := parseSteeringMode(tt.input); 
got != tt.expected { + t.Fatalf("parseSteeringMode(%q) = %v, want %v", tt.input, got, tt.expected) + } + }) + } +} + +// --- AgentLoop steering integration tests --- + +func TestAgentLoop_Steer_Enqueues(t *testing.T) { + al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) + defer cleanup() + + if cfg == nil { + t.Fatal("expected config to be initialized") + } + if msgBus == nil { + t.Fatal("expected message bus to be initialized") + } + if provider == nil { + t.Fatal("expected provider to be initialized") + } + + al.Steer(providers.Message{Role: "user", Content: "interrupt me"}) + + if al.steering.len() != 1 { + t.Fatalf("expected 1 steering message, got %d", al.steering.len()) + } + + msgs := al.dequeueSteeringMessages() + if len(msgs) != 1 || msgs[0].Content != "interrupt me" { + t.Fatalf("unexpected dequeued message: %v", msgs) + } +} + +func TestAgentLoop_SteeringMode_GetSet(t *testing.T) { + al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) + defer cleanup() + + if cfg == nil { + t.Fatal("expected config to be initialized") + } + if msgBus == nil { + t.Fatal("expected message bus to be initialized") + } + if provider == nil { + t.Fatal("expected provider to be initialized") + } + + if al.SteeringMode() != SteeringOneAtATime { + t.Fatalf("expected default mode one-at-a-time, got %v", al.SteeringMode()) + } + + al.SetSteeringMode(SteeringAll) + if al.SteeringMode() != SteeringAll { + t.Fatalf("expected all mode, got %v", al.SteeringMode()) + } +} + +func TestAgentLoop_SteeringMode_ConfiguredFromConfig(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + SteeringMode: "all", + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := 
NewAgentLoop(cfg, msgBus, provider) + + if al.SteeringMode() != SteeringAll { + t.Fatalf("expected 'all' mode from config, got %v", al.SteeringMode()) + } +} + +func TestAgentLoop_Continue_NoMessages(t *testing.T) { + al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) + defer cleanup() + + if cfg == nil { + t.Fatal("expected config to be initialized") + } + if msgBus == nil { + t.Fatal("expected message bus to be initialized") + } + if provider == nil { + t.Fatal("expected provider to be initialized") + } + + resp, err := al.Continue(context.Background(), "test-session", "test", "chat1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp != "" { + t.Fatalf("expected empty response for no steering messages, got %q", resp) + } +} + +func TestAgentLoop_Continue_WithMessages(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "continued response"} + al := NewAgentLoop(cfg, msgBus, provider) + + al.Steer(providers.Message{Role: "user", Content: "new direction"}) + + resp, err := al.Continue(context.Background(), "test-session", "test", "chat1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp != "continued response" { + t.Fatalf("expected 'continued response', got %q", resp) + } +} + +// slowTool simulates a tool that takes some time to execute. 
+type slowTool struct { + name string + duration time.Duration + execCh chan struct{} // closed when Execute starts +} + +func (t *slowTool) Name() string { return t.name } +func (t *slowTool) Description() string { return "slow tool for testing" } +func (t *slowTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (t *slowTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + if t.execCh != nil { + close(t.execCh) + } + time.Sleep(t.duration) + return tools.SilentResult(fmt.Sprintf("executed %s", t.name)) +} + +// toolCallProvider returns an LLM response with tool calls on the first call, +// then a direct response on subsequent calls. +type toolCallProvider struct { + mu sync.Mutex + calls int + toolCalls []providers.ToolCall + finalResp string +} + +func (m *toolCallProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.calls++ + + if m.calls == 1 && len(m.toolCalls) > 0 { + return &providers.LLMResponse{ + Content: "", + ToolCalls: m.toolCalls, + }, nil + } + + return &providers.LLMResponse{ + Content: m.finalResp, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *toolCallProvider) GetDefaultModel() string { + return "tool-call-mock" +} + +func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + tool1ExecCh := make(chan struct{}) + tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh} + tool2 := 
&slowTool{name: "tool_two", duration: 50 * time.Millisecond} + + provider := &toolCallProvider{ + toolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "tool_one", + Function: &providers.FunctionCall{ + Name: "tool_one", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + { + ID: "call_2", + Type: "function", + Name: "tool_two", + Function: &providers.FunctionCall{ + Name: "tool_two", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + }, + finalResp: "steered response", + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, provider) + al.RegisterTool(tool1) + al.RegisterTool(tool2) + + // Start processing in a goroutine + type result struct { + resp string + err error + } + resultCh := make(chan result, 1) + + go func() { + resp, err := al.ProcessDirectWithChannel( + context.Background(), + "do something", + "test-session", + "test", + "chat1", + ) + resultCh <- result{resp, err} + }() + + // Wait for tool_one to start executing, then enqueue a steering message + select { + case <-tool1ExecCh: + // tool_one has started executing + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for tool_one to start") + } + + al.Steer(providers.Message{Role: "user", Content: "change course"}) + + // Get the result + select { + case r := <-resultCh: + if r.err != nil { + t.Fatalf("unexpected error: %v", r.err) + } + if r.resp != "steered response" { + t.Fatalf("expected 'steered response', got %q", r.resp) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for agent loop to complete") + } + + // The provider should have been called twice: + // 1. first call returned tool calls + // 2. 
second call (after steering) returned the final response + provider.mu.Lock() + calls := provider.calls + provider.mu.Unlock() + if calls != 2 { + t.Fatalf("expected 2 provider calls, got %d", calls) + } +} + +func TestAgentLoop_Steering_InitialPoll(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + // Provider that captures messages it receives + var capturedMessages []providers.Message + var capMu sync.Mutex + provider := &capturingMockProvider{ + response: "ack", + captureFn: func(msgs []providers.Message) { + capMu.Lock() + capturedMessages = make([]providers.Message, len(msgs)) + copy(capturedMessages, msgs) + capMu.Unlock() + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, provider) + + // Enqueue a steering message before processing starts + al.Steer(providers.Message{Role: "user", Content: "pre-enqueued steering"}) + + // Process a normal message - the initial steering poll should inject the steering message + _, err = al.ProcessDirectWithChannel( + context.Background(), + "initial message", + "test-session", + "test", + "chat1", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // The steering message should have been injected into the conversation + capMu.Lock() + msgs := capturedMessages + capMu.Unlock() + + // Look for the steering message in the captured messages + found := false + for _, m := range msgs { + if m.Content == "pre-enqueued steering" { + found = true + break + } + } + if !found { + t.Fatal("expected steering message to be injected into conversation context") + } +} + +// capturingMockProvider captures messages sent to Chat for inspection. 
+type capturingMockProvider struct { + response string + calls int + captureFn func([]providers.Message) +} + +func (m *capturingMockProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.captureFn != nil { + m.captureFn(messages) + } + return &providers.LLMResponse{ + Content: m.response, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *capturingMockProvider) GetDefaultModel() string { + return "capturing-mock" +} + +func TestAgentLoop_Steering_SkippedToolsHaveErrorResults(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + execCh := make(chan struct{}) + tool1 := &slowTool{name: "slow_tool", duration: 50 * time.Millisecond, execCh: execCh} + tool2 := &slowTool{name: "skipped_tool", duration: 50 * time.Millisecond} + + // Provider that captures messages on the second call (after tools) + var secondCallMessages []providers.Message + var capMu sync.Mutex + callCount := 0 + + provider := &toolCallProvider{ + toolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "slow_tool", + Function: &providers.FunctionCall{ + Name: "slow_tool", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + { + ID: "call_2", + Type: "function", + Name: "skipped_tool", + Function: &providers.FunctionCall{ + Name: "skipped_tool", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + }, + finalResp: "done", + } + + // Wrap provider to capture messages on second call + wrappedProvider := &wrappingProvider{ + inner: provider, + onChat: func(msgs []providers.Message) { + capMu.Lock() + 
callCount++ + if callCount >= 2 { + secondCallMessages = make([]providers.Message, len(msgs)) + copy(secondCallMessages, msgs) + } + capMu.Unlock() + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, wrappedProvider) + al.RegisterTool(tool1) + al.RegisterTool(tool2) + + resultCh := make(chan string, 1) + go func() { + resp, _ := al.ProcessDirectWithChannel( + context.Background(), "go", "test-session", "test", "chat1", + ) + resultCh <- resp + }() + + <-execCh + al.Steer(providers.Message{Role: "user", Content: "interrupt!"}) + + select { + case <-resultCh: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + + // Check that the skipped tool result message is in the conversation + capMu.Lock() + msgs := secondCallMessages + capMu.Unlock() + + foundSkipped := false + for _, m := range msgs { + if m.Role == "tool" && m.ToolCallID == "call_2" && m.Content == "Skipped due to queued user message." { + foundSkipped = true + break + } + } + if !foundSkipped { + // Log what we actually got + for i, m := range msgs { + t.Logf("msg[%d]: role=%s toolCallID=%s content=%s", i, m.Role, m.ToolCallID, truncate(m.Content, 80)) + } + t.Fatal("expected skipped tool result for call_2") + } +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +// wrappingProvider wraps another provider to hook into Chat calls. +type wrappingProvider struct { + inner providers.LLMProvider + onChat func([]providers.Message) +} + +func (w *wrappingProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + if w.onChat != nil { + w.onChat(messages) + } + return w.inner.Chat(ctx, messages, tools, model, opts) +} + +func (w *wrappingProvider) GetDefaultModel() string { + return w.inner.GetDefaultModel() +} + +// Ensure NormalizeToolCall handles our test tool calls. 
+func init() { + // This is a no-op init; we just need the tool call tests to work + // with the proper argument serialization. + _ = json.Marshal +} diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go new file mode 100644 index 0000000000..ab7d60957b --- /dev/null +++ b/pkg/agent/subturn.go @@ -0,0 +1,309 @@ +package agent + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// ====================== Config & Constants ====================== +const maxSubTurnDepth = 3 + +var ( + ErrDepthLimitExceeded = errors.New("sub-turn depth limit exceeded") + ErrInvalidSubTurnConfig = errors.New("invalid sub-turn config") +) + +// ====================== SubTurn Config ====================== +type SubTurnConfig struct { + Model string + Tools []tools.Tool + SystemPrompt string + MaxTokens int + // Can be extended with temperature, topP, etc. +} + +// ====================== Sub-turn Events (Aligned with EventBus) ====================== +type SubTurnSpawnEvent struct { + ParentID string + ChildID string + Config SubTurnConfig +} + +type SubTurnEndEvent struct { + ChildID string + Result *tools.ToolResult + Err error +} + +type SubTurnResultDeliveredEvent struct { + ParentID string + ChildID string + Result *tools.ToolResult +} + +type SubTurnOrphanResultEvent struct { + ParentID string + ChildID string + Result *tools.ToolResult +} + +// ====================== turnState (Simplified, reusable with existing structs) ====================== +type turnState struct { + ctx context.Context + cancelFunc context.CancelFunc // Used to cancel all children when this turn finishes + turnID string + parentTurnID string + depth int + childTurnIDs []string + pendingResults chan *tools.ToolResult + session session.SessionStore + mu sync.Mutex + isFinished bool // Marks if the parent Turn has ended +} + +// ====================== Helper 
Functions ====================== +var globalTurnCounter int64 + +func generateTurnID() string { + return fmt.Sprintf("subturn-%d", atomic.AddInt64(&globalTurnCounter, 1)) +} + +func newTurnState(ctx context.Context, id string, parent *turnState) *turnState { + turnCtx, cancel := context.WithCancel(ctx) + return &turnState{ + ctx: turnCtx, + cancelFunc: cancel, + turnID: id, + parentTurnID: parent.turnID, + depth: parent.depth + 1, + session: newEphemeralSession(parent.session), + // NOTE: In this PoC, I use a fixed-size channel (16). + // Under high concurrency or long-running sub-turns, this might fill up and cause + // intermediate results to be discarded in deliverSubTurnResult. + // For production, consider an unbounded queue or a blocking strategy with backpressure. + pendingResults: make(chan *tools.ToolResult, 16), + } +} + +// Finish marks the turn as finished and cancels its context, aborting any running sub-turns. +func (ts *turnState) Finish() { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.isFinished = true + if ts.cancelFunc != nil { + ts.cancelFunc() + } +} + +// ephemeralSessionStore is a pure in-memory SessionStore for SubTurns. +// It never writes to disk, keeping sub-turn history isolated from the parent session. 
+type ephemeralSessionStore struct { + mu sync.Mutex + history []providers.Message + summary string +} + +func (e *ephemeralSessionStore) AddMessage(sessionKey, role, content string) { + e.mu.Lock() + defer e.mu.Unlock() + e.history = append(e.history, providers.Message{Role: role, Content: content}) +} + +func (e *ephemeralSessionStore) AddFullMessage(sessionKey string, msg providers.Message) { + e.mu.Lock() + defer e.mu.Unlock() + e.history = append(e.history, msg) +} + +func (e *ephemeralSessionStore) GetHistory(key string) []providers.Message { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]providers.Message, len(e.history)) + copy(out, e.history) + return out +} + +func (e *ephemeralSessionStore) GetSummary(key string) string { + e.mu.Lock() + defer e.mu.Unlock() + return e.summary +} + +func (e *ephemeralSessionStore) SetSummary(key, summary string) { + e.mu.Lock() + defer e.mu.Unlock() + e.summary = summary +} + +func (e *ephemeralSessionStore) SetHistory(key string, history []providers.Message) { + e.mu.Lock() + defer e.mu.Unlock() + e.history = make([]providers.Message, len(history)) + copy(e.history, history) +} + +func (e *ephemeralSessionStore) TruncateHistory(key string, keepLast int) { + e.mu.Lock() + defer e.mu.Unlock() + if len(e.history) > keepLast { + e.history = e.history[len(e.history)-keepLast:] + } +} + +func (e *ephemeralSessionStore) Save(key string) error { return nil } +func (e *ephemeralSessionStore) Close() error { return nil } + +func newEphemeralSession(_ session.SessionStore) session.SessionStore { + return &ephemeralSessionStore{} +} + +// ====================== Core Function: spawnSubTurn ====================== +func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg SubTurnConfig) (result *tools.ToolResult, err error) { + // 1. Depth limit check + if parentTS.depth >= maxSubTurnDepth { + return nil, ErrDepthLimitExceeded + } + + // 2. 
Config validation + if cfg.Model == "" { + return nil, ErrInvalidSubTurnConfig + } + + // Create a sub-context for the child turn to support cancellation + childCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // 3. Create child Turn state + childID := generateTurnID() + childTS := newTurnState(childCtx, childID, parentTS) + + // 4. Establish parent-child relationship (thread-safe) + parentTS.mu.Lock() + parentTS.childTurnIDs = append(parentTS.childTurnIDs, childID) + parentTS.mu.Unlock() + + // 5. Emit Spawn event (currently using Mock, will be replaced by real EventBus) + MockEventBus.Emit(SubTurnSpawnEvent{ + ParentID: parentTS.turnID, + ChildID: childID, + Config: cfg, + }) + + // 6. Defer emitting End event, and recover from panics to ensure it's always fired + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("subturn panicked: %v", r) + } + + MockEventBus.Emit(SubTurnEndEvent{ + ChildID: childID, + Result: result, + Err: err, + }) + }() + + // 7. Execute sub-turn via the real agent loop. + // Build a child AgentInstance from SubTurnConfig, inheriting defaults from the parent agent. + result, err = runTurn(childCtx, al, childTS, cfg) + + // 8. 
Deliver result back to parent Turn + deliverSubTurnResult(parentTS, childID, result) + + return result, err +} + +// ====================== Result Delivery ====================== +func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.ToolResult) { + parentTS.mu.Lock() + defer parentTS.mu.Unlock() + + // Emit ResultDelivered event + MockEventBus.Emit(SubTurnResultDeliveredEvent{ + ParentID: parentTS.turnID, + ChildID: childID, + Result: result, + }) + + if !parentTS.isFinished { + // Parent Turn is still running → Place in pending queue (handled automatically by parent loop in next round) + select { + case parentTS.pendingResults <- result: + default: + fmt.Println("[SubTurn] warning: pendingResults channel full") + } + return + } + + // Parent Turn has ended + // emit an OrphanResultEvent so the system/UI can handle this late arrival. + if result != nil { + MockEventBus.Emit(SubTurnOrphanResultEvent{ + ParentID: parentTS.turnID, + ChildID: childID, + Result: result, + }) + } +} + +// runTurn builds a temporary AgentInstance from SubTurnConfig and delegates to +// the real agent loop. The child's ephemeral session is used for history so it +// never pollutes the parent session. +func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfig) (*tools.ToolResult, error) { + // Derive candidates from the requested model using the parent loop's provider. + defaultProvider := al.GetConfig().Agents.Defaults.Provider + candidates := providers.ResolveCandidates( + providers.ModelConfig{Primary: cfg.Model}, + defaultProvider, + ) + + // Build a minimal AgentInstance for this sub-turn. + // It reuses the parent loop's provider and config, but gets its own + // ephemeral session store and tool registry. 
+ toolRegistry := tools.NewToolRegistry() + for _, t := range cfg.Tools { + toolRegistry.Register(t) + } + + parentAgent := al.GetRegistry().GetDefaultAgent() + childAgent := &AgentInstance{ + ID: ts.turnID, + Model: cfg.Model, + MaxIterations: parentAgent.MaxIterations, + MaxTokens: cfg.MaxTokens, + Temperature: parentAgent.Temperature, + ThinkingLevel: parentAgent.ThinkingLevel, + ContextWindow: cfg.MaxTokens, + SummarizeMessageThreshold: parentAgent.SummarizeMessageThreshold, + SummarizeTokenPercent: parentAgent.SummarizeTokenPercent, + Provider: parentAgent.Provider, + Sessions: ts.session, + ContextBuilder: parentAgent.ContextBuilder, + Tools: toolRegistry, + Candidates: candidates, + } + if childAgent.MaxTokens == 0 { + childAgent.MaxTokens = parentAgent.MaxTokens + childAgent.ContextWindow = parentAgent.ContextWindow + } + + finalContent, err := al.runAgentLoop(ctx, childAgent, processOptions{ + SessionKey: ts.turnID, + UserMessage: cfg.SystemPrompt, + DefaultResponse: "", + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + return nil, err + } + return &tools.ToolResult{ForLLM: finalContent}, nil +} + +// ====================== Other Types ====================== diff --git a/pkg/agent/subturn_test.go b/pkg/agent/subturn_test.go new file mode 100644 index 0000000000..97dfc0130c --- /dev/null +++ b/pkg/agent/subturn_test.go @@ -0,0 +1,254 @@ +package agent + +import ( + "context" + "reflect" + "testing" + + "github.com/sipeed/picoclaw/pkg/tools" +) + +// ====================== Test Helper: Event Collector ====================== +type eventCollector struct { + events []any +} + +func (c *eventCollector) collect(e any) { + c.events = append(c.events, e) +} + +func (c *eventCollector) hasEventOfType(typ any) bool { + targetType := reflect.TypeOf(typ) + for _, e := range c.events { + if reflect.TypeOf(e) == targetType { + return true + } + } + return false +} + +func (c *eventCollector) countOfType(typ any) int { + targetType := 
reflect.TypeOf(typ) + count := 0 + for _, e := range c.events { + if reflect.TypeOf(e) == targetType { + count++ + } + } + return count +} + +// ====================== Main Test Function ====================== +func TestSpawnSubTurn(t *testing.T) { + tests := []struct { + name string + parentDepth int + config SubTurnConfig + wantErr error + wantSpawn bool + wantEnd bool + wantDepthFail bool + }{ + { + name: "Basic success path - Single layer sub-turn", + parentDepth: 0, + config: SubTurnConfig{ + Model: "gpt-4o-mini", + Tools: []tools.Tool{}, // At least one tool + }, + wantErr: nil, + wantSpawn: true, + wantEnd: true, + }, + { + name: "Nested 2 layers - Normal", + parentDepth: 1, + config: SubTurnConfig{ + Model: "gpt-4o-mini", + Tools: []tools.Tool{}, + }, + wantErr: nil, + wantSpawn: true, + wantEnd: true, + }, + { + name: "Depth limit triggered - 4th layer fails", + parentDepth: 3, + config: SubTurnConfig{ + Model: "gpt-4o-mini", + Tools: []tools.Tool{}, + }, + wantErr: ErrDepthLimitExceeded, + wantSpawn: false, + wantEnd: false, + wantDepthFail: true, + }, + { + name: "Invalid config - Empty Model", + parentDepth: 0, + config: SubTurnConfig{ + Model: "", + Tools: []tools.Tool{}, + }, + wantErr: ErrInvalidSubTurnConfig, + wantSpawn: false, + wantEnd: false, + }, + } + + al, _, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Prepare parent Turn + parent := &turnState{ + ctx: context.Background(), + turnID: "parent-1", + depth: tt.parentDepth, + childTurnIDs: []string{}, + pendingResults: make(chan *tools.ToolResult, 10), + session: &ephemeralSessionStore{}, + } + + // Replace mock with test collector + collector := &eventCollector{} + originalEmit := MockEventBus.Emit + MockEventBus.Emit = collector.collect + defer func() { MockEventBus.Emit = originalEmit }() + + // Execute spawnSubTurn + result, err := spawnSubTurn(context.Background(), al, parent, tt.config) + + // Assert errors 
+ if tt.wantErr != nil { + if err == nil || err != tt.wantErr { + t.Errorf("expected error %v, got %v", tt.wantErr, err) + } + return + } + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify result + if result == nil { + t.Error("expected non-nil result") + } + + // Verify event emission + if tt.wantSpawn { + if !collector.hasEventOfType(SubTurnSpawnEvent{}) { + t.Error("SubTurnSpawnEvent not emitted") + } + } + if tt.wantEnd { + if !collector.hasEventOfType(SubTurnEndEvent{}) { + t.Error("SubTurnEndEvent not emitted") + } + } + + // Verify turn tree + if len(parent.childTurnIDs) == 0 && !tt.wantDepthFail { + t.Error("child Turn not added to parent.childTurnIDs") + } + + // Verify result delivery (pendingResults or history) + if len(parent.pendingResults) > 0 || len(parent.session.GetHistory("")) > 0 { + // Result delivered via at least one path + } else { + t.Error("child result not delivered") + } + }) + } +} + +// ====================== Extra Independent Test: Ephemeral Session Isolation ====================== +func TestSpawnSubTurn_EphemeralSessionIsolation(t *testing.T) { + al, _, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + parentSession := &ephemeralSessionStore{} + parentSession.AddMessage("", "user", "parent msg") + parent := &turnState{ + ctx: context.Background(), + turnID: "parent-1", + depth: 0, + session: parentSession, + } + + cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}} + + // Record main session length before execution + originalLen := len(parent.session.GetHistory("")) + + _, _ = spawnSubTurn(context.Background(), al, parent, cfg) + + // After sub-turn ends, main session must remain unchanged + if len(parent.session.GetHistory("")) != originalLen { + t.Error("ephemeral session polluted the main session") + } +} + +// ====================== Extra Independent Test: Result Delivery Path ====================== +func TestSpawnSubTurn_ResultDelivery(t *testing.T) { + al, _, _, _, cleanup 
:= newTestAgentLoop(t) + defer cleanup() + + parent := &turnState{ + ctx: context.Background(), + turnID: "parent-1", + depth: 0, + pendingResults: make(chan *tools.ToolResult, 1), + session: &ephemeralSessionStore{}, + } + + cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}} + + _, _ = spawnSubTurn(context.Background(), al, parent, cfg) + + // Check if pendingResults received the result + select { + case res := <-parent.pendingResults: + if res == nil { + t.Error("received nil result in pendingResults") + } + default: + t.Error("result did not enter pendingResults") + } +} + +// ====================== Extra Independent Test: Orphan Result Routing ====================== +func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) { + parentCtx, cancelParent := context.WithCancel(context.Background()) + parent := &turnState{ + ctx: parentCtx, + cancelFunc: cancelParent, + turnID: "parent-1", + depth: 0, + pendingResults: make(chan *tools.ToolResult, 1), + session: &ephemeralSessionStore{}, + } + + collector := &eventCollector{} + originalEmit := MockEventBus.Emit + MockEventBus.Emit = collector.collect + defer func() { MockEventBus.Emit = originalEmit }() + + // Simulate parent finishing before child delivers result + parent.Finish() + + // Call deliverSubTurnResult directly to simulate a delayed child + deliverSubTurnResult(parent, "delayed-child", &tools.ToolResult{ForLLM: "late result"}) + + // Verify Orphan event is emitted + if !collector.hasEventOfType(SubTurnOrphanResultEvent{}) { + t.Error("SubTurnOrphanResultEvent not emitted for finished parent") + } + + // Verify history is NOT polluted + if len(parent.session.GetHistory("")) != 0 { + t.Error("Parent history was polluted by orphan result") + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 1903412248..a8b8f337fa 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -234,6 +234,7 @@ type AgentDefaults struct { SummarizeTokenPercent int `json:"summarize_token_percent" 
env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` Routing *RoutingConfig `json:"routing,omitempty"` + SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all" } const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 189af0a845..caa09b0e2e 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -35,6 +35,7 @@ func DefaultConfig() *Config { MaxToolIterations: 50, SummarizeMessageThreshold: 20, SummarizeTokenPercent: 75, + SteeringMode: "one-at-a-time", }, }, Bindings: []AgentBinding{}, @@ -384,6 +385,15 @@ func DefaultConfig() *Config { APIBase: "http://localhost:8000/v1", APIKey: "", }, + + // Azure OpenAI - https://portal.azure.com + // model_name is a user-friendly alias; the model field's path after "azure/" is your deployment name + { + ModelName: "azure-gpt5", + Model: "azure/my-gpt5-deployment", + APIBase: "https://your-resource.openai.azure.com", + APIKey: "", + }, }, Gateway: GatewayConfig{ Host: "127.0.0.1", diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 302613f338..4204cc192f 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "runtime" + "strconv" "strings" "sync" @@ -45,6 +46,9 @@ func init() { consoleWriter := zerolog.ConsoleWriter{ Out: os.Stdout, TimeFormat: "15:04:05", // TODO: make it configurable??? 
+
+		// Custom formatter to handle multiline strings and JSON objects
+		FormatFieldValue: formatFieldValue,
 	}
 
 	logger = zerolog.New(consoleWriter).With().Timestamp().Logger()
@@ -52,6 +56,37 @@
 	})
 }
 
+func formatFieldValue(i any) string {
+	var s string
+
+	switch val := i.(type) {
+	case string:
+		s = val
+	case []byte:
+		s = string(val)
+	default:
+		return fmt.Sprintf("%v", i)
+	}
+
+	if unquoted, err := strconv.Unquote(s); err == nil {
+		s = unquoted
+	}
+
+	if strings.Contains(s, "\n") {
+		return fmt.Sprintf("\n%s", s)
+	}
+
+	if strings.Contains(s, " ") {
+		if (strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}")) ||
+			(strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]")) {
+			return s
+		}
+		return fmt.Sprintf("%q", s)
+	}
+
+	return s
+}
+
 func SetLevel(level LogLevel) {
 	mu.Lock()
 	defer mu.Unlock()
@@ -163,10 +198,7 @@
 		event.Str("caller", fmt.Sprintf(" %s:%d (%s)", callerFile, callerLine, callerFunc))
 	}
 
-	for k, v := range fields {
-		event.Interface(k, v)
-	}
-
+	appendFields(event, fields)
 	event.Msg(message)
 
 	// Also log to file if enabled
@@ -176,9 +208,8 @@
 	if component != "" {
 		fileEvent.Str("component", component)
 	}
-	for k, v := range fields {
-		fileEvent.Interface(k, v)
-	}
+
+	appendFields(fileEvent, fields)
 	fileEvent.Msg(message)
 }
 
@@ -187,6 +218,26 @@
 	}
 }
 
+func appendFields(event *zerolog.Event, fields map[string]any) {
+	for k, v := range fields {
+		// Type switch to avoid double JSON serialization of strings
+		switch val := v.(type) {
+		case string:
+			event.Str(k, val)
+		case int:
+			event.Int(k, val)
+		case int64:
+			event.Int64(k, val)
+		case float64:
+			event.Float64(k, val)
+		case bool:
+			event.Bool(k, val)
+		default:
+			event.Interface(k, v) // Fallback for struct, slice and maps
+		}
+	}
+}
+
+func 
Debug(message string) { logMessage(DEBUG, "", message, nil) } diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go index 8170a618ba..31b40484cb 100644 --- a/pkg/logger/logger_test.go +++ b/pkg/logger/logger_test.go @@ -141,3 +141,114 @@ func TestLoggerHelperFunctions(t *testing.T) { Debugf("test from %v", "Debugf") WarnF("Warning with fields", map[string]any{"key": "value"}) } + +func TestFormatFieldValue(t *testing.T) { + tests := []struct { + name string + input any + expected string + }{ + // Basic types test (default case of the switch) + { + name: "Integer Type", + input: 42, + expected: "42", + }, + { + name: "Boolean Type", + input: true, + expected: "true", + }, + { + name: "Unsupported Struct Type", + input: struct{ A int }{A: 1}, + expected: "{1}", + }, + + // Simple strings and byte slices test + { + name: "Simple string without spaces", + input: "simple_value", + expected: "simple_value", + }, + { + name: "Simple byte slice", + input: []byte("byte_value"), + expected: "byte_value", + }, + + // Unquoting test (strconv.Unquote) + { + name: "Quoted string", + input: `"quoted_value"`, + expected: "quoted_value", + }, + + // Strings with newline (\n) test + { + name: "String with newline", + input: "line1\nline2", + expected: "\nline1\nline2", + }, + { + name: "Quoted string with newline (Unquote -> newline)", + input: `"line1\nline2"`, // Escaped \n that Unquote will resolve + expected: "\nline1\nline2", + }, + + // Strings with spaces test (which should be quoted) + { + name: "String with spaces", + input: "hello world", + expected: `"hello world"`, + }, + { + name: "Quoted string with spaces (Unquote -> has spaces -> Re-quote)", + input: `"hello world"`, + expected: `"hello world"`, + }, + + // JSON formats test (strings with spaces that start/end with brackets) + { + name: "Valid JSON object", + input: `{"key": "value"}`, + expected: `{"key": "value"}`, + }, + { + name: "Valid JSON array", + input: `[1, 2, "three"]`, + expected: `[1, 2, 
"three"]`, + }, + { + name: "Fake JSON (starts with { but doesn't end with })", + input: `{"key": "value"`, // Missing closing bracket, has spaces + expected: `"{\"key\": \"value\""`, + }, + { + name: "Empty JSON (object)", + input: `{ }`, + expected: `{ }`, + }, + + // 7. Edge Cases + { + name: "Empty string", + input: "", + expected: "", + }, + { + name: "Whitespace only string", + input: " ", + expected: `" "`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := formatFieldValue(tt.input) + if actual != tt.expected { + t.Errorf("formatFieldValue() = %q, expected %q", actual, tt.expected) + } + }) + } +} diff --git a/pkg/providers/azure/provider.go b/pkg/providers/azure/provider.go new file mode 100644 index 0000000000..e0ddbbde44 --- /dev/null +++ b/pkg/providers/azure/provider.go @@ -0,0 +1,150 @@ +package azure + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/common" + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ( + LLMResponse = protocoltypes.LLMResponse + Message = protocoltypes.Message + ToolDefinition = protocoltypes.ToolDefinition +) + +const ( + // azureAPIVersion is the Azure OpenAI API version used for all requests. + azureAPIVersion = "2024-10-21" + defaultRequestTimeout = common.DefaultRequestTimeout +) + +// Provider implements the LLM provider interface for Azure OpenAI endpoints. +// It handles Azure-specific authentication (api-key header), URL construction +// (deployment-based), and request body formatting (max_completion_tokens, no model field). +type Provider struct { + apiKey string + apiBase string + httpClient *http.Client +} + +// Option configures the Azure Provider. +type Option func(*Provider) + +// WithRequestTimeout sets the HTTP request timeout. 
+func WithRequestTimeout(timeout time.Duration) Option {
+	return func(p *Provider) {
+		if timeout > 0 {
+			p.httpClient.Timeout = timeout
+		}
+	}
+}
+
+// NewProvider creates a new Azure OpenAI provider.
+func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
+	p := &Provider{
+		apiKey:     apiKey,
+		apiBase:    strings.TrimRight(apiBase, "/"),
+		httpClient: common.NewHTTPClient(proxy),
+	}
+
+	for _, opt := range opts {
+		if opt != nil {
+			opt(p)
+		}
+	}
+
+	return p
+}
+
+// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom request timeout in seconds.
+func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
+	return NewProvider(
+		apiKey, apiBase, proxy,
+		WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
+	)
+}
+
+// Chat sends a chat completion request to the Azure OpenAI endpoint.
+// The model parameter is used as the Azure deployment name in the URL.
+func (p *Provider) Chat(
+	ctx context.Context,
+	messages []Message,
+	tools []ToolDefinition,
+	model string,
+	options map[string]any,
+) (*LLMResponse, error) {
+	if p.apiBase == "" {
+		return nil, fmt.Errorf("Azure API base not configured")
+	}
+
+	// model is the deployment name for Azure OpenAI
+	deployment := model
+
+	// Build the URL via url.JoinPath. NOTE(review): JoinPath cleans "../" segments but
+	// does NOT escape "/" inside elements, so deployment names must come from trusted config.
+ base, err := url.JoinPath(p.apiBase, "openai/deployments", deployment, "chat/completions") + if err != nil { + return nil, fmt.Errorf("failed to build Azure request URL: %w", err) + } + requestURL := base + "?api-version=" + azureAPIVersion + + // Build request body — no "model" field (Azure infers from deployment URL) + requestBody := map[string]any{ + "messages": common.SerializeMessages(messages), + } + + if len(tools) > 0 { + requestBody["tools"] = tools + requestBody["tool_choice"] = "auto" + } + + // Azure OpenAI always uses max_completion_tokens + if maxTokens, ok := common.AsInt(options["max_tokens"]); ok { + requestBody["max_completion_tokens"] = maxTokens + } + + if temperature, ok := common.AsFloat(options["temperature"]); ok { + requestBody["temperature"] = temperature + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Azure uses api-key header instead of Authorization: Bearer + req.Header.Set("Content-Type", "application/json") + if p.apiKey != "" { + req.Header.Set("Api-Key", p.apiKey) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, common.HandleErrorResponse(resp, p.apiBase) + } + + return common.ReadAndParseResponse(resp, p.apiBase) +} + +// GetDefaultModel returns an empty string as Azure deployments are user-configured. 
+func (p *Provider) GetDefaultModel() string { + return "" +} diff --git a/pkg/providers/azure/provider_test.go b/pkg/providers/azure/provider_test.go new file mode 100644 index 0000000000..531b812965 --- /dev/null +++ b/pkg/providers/azure/provider_test.go @@ -0,0 +1,232 @@ +package azure + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +// writeValidResponse writes a minimal valid Azure OpenAI chat completion response. +func writeValidResponse(w http.ResponseWriter) { + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +func TestProviderChat_AzureURLConstruction(t *testing.T) { + var capturedPath string + var capturedAPIVersion string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + capturedAPIVersion = r.URL.Query().Get("api-version") + writeValidResponse(w) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions" + if capturedPath != wantPath { + t.Errorf("URL path = %q, want %q", capturedPath, wantPath) + } + if capturedAPIVersion != azureAPIVersion { + t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion) + } +} + +func TestProviderChat_AzureAuthHeader(t *testing.T) { + var capturedAPIKey string + var capturedAuth string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAPIKey = r.Header.Get("Api-Key") + capturedAuth = r.Header.Get("Authorization") + writeValidResponse(w) + })) + defer server.Close() + + p := 
NewProvider("test-azure-key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if capturedAPIKey != "test-azure-key" { + t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key") + } + if capturedAuth != "" { + t.Errorf("Authorization header should be empty, got %q", capturedAuth) + } +} + +func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&requestBody) + writeValidResponse(w) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if _, exists := requestBody["model"]; exists { + t.Error("request body should not contain 'model' field for Azure OpenAI") + } +} + +func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&requestBody) + writeValidResponse(w) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "deployment", + map[string]any{"max_tokens": 2048}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if _, exists := requestBody["max_completion_tokens"]; !exists { + t.Error("request body should contain 'max_completion_tokens'") + } + if _, exists := requestBody["max_tokens"]; exists { + t.Error("request body should not contain 'max_tokens'") + } +} + +func TestProviderChat_AzureHTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized) + })) + defer server.Close() + + p := NewProvider("bad-key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProviderChat_AzureParseToolCalls(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{ + "content": "", + "tool_calls": []map[string]any{ + { + "id": "call_1", + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "arguments": `{"city":"Seattle"}`, + }, + }, + }, + }, + "finish_reason": "tool_calls", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].Name != "get_weather" { + t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather") + } +} + +func TestProvider_AzureEmptyAPIBase(t *testing.T) { + p := NewProvider("test-key", "", "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil) + if err == nil { + t.Fatal("expected error for empty API base") + } +} + +func TestProvider_AzureRequestTimeoutDefault(t *testing.T) { + p := NewProvider("test-key", "https://example.com", "") + if p.httpClient.Timeout != defaultRequestTimeout { + t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout) + } +} + +func TestProvider_AzureRequestTimeoutOverride(t 
*testing.T) { + p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second)) + if p.httpClient.Timeout != 300*time.Second { + t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second) + } +} + +func TestProvider_AzureNewProviderWithTimeout(t *testing.T) { + p := NewProviderWithTimeout("test-key", "https://example.com", "", 180) + if p.httpClient.Timeout != 180*time.Second { + t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second) + } +} + +func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) { + var capturedPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.RawPath // use RawPath to see percent-encoding + if capturedPath == "" { + capturedPath = r.URL.Path + } + writeValidResponse(w) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + + // Deployment name with characters that could cause path injection + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + // The slash and special chars in the deployment name must be escaped, not treated as path separators + if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" { + t.Fatal("deployment name was interpolated without escaping — path injection possible") + } +} diff --git a/pkg/providers/common/common.go b/pkg/providers/common/common.go new file mode 100644 index 0000000000..23680a1bf9 --- /dev/null +++ b/pkg/providers/common/common.go @@ -0,0 +1,380 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +// Package common provides shared utilities used by multiple LLM provider +// implementations (openai_compat, azure, etc.). 
+package common + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +// Re-export protocol types used across providers. +type ( + ToolCall = protocoltypes.ToolCall + FunctionCall = protocoltypes.FunctionCall + LLMResponse = protocoltypes.LLMResponse + UsageInfo = protocoltypes.UsageInfo + Message = protocoltypes.Message + ToolDefinition = protocoltypes.ToolDefinition + ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition + ExtraContent = protocoltypes.ExtraContent + GoogleExtra = protocoltypes.GoogleExtra + ReasoningDetail = protocoltypes.ReasoningDetail +) + +const DefaultRequestTimeout = 120 * time.Second + +// NewHTTPClient creates an *http.Client with an optional proxy and the default timeout. +func NewHTTPClient(proxy string) *http.Client { + client := &http.Client{ + Timeout: DefaultRequestTimeout, + } + if proxy != "" { + parsed, err := url.Parse(proxy) + if err == nil { + // Preserve http.DefaultTransport settings (TLS, HTTP/2, timeouts, etc.) + if base, ok := http.DefaultTransport.(*http.Transport); ok { + tr := base.Clone() + tr.Proxy = http.ProxyURL(parsed) + client.Transport = tr + } else { + // Fallback: minimal transport if DefaultTransport is not *http.Transport. + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(parsed), + } + } + } else { + log.Printf("common: invalid proxy URL %q: %v", proxy, err) + } + } + return client +} + +// --- Message serialization --- + +// openaiMessage is the wire-format message for OpenAI-compatible APIs. +// It mirrors protocoltypes.Message but omits SystemParts, which is an +// internal field that would be unknown to third-party endpoints. 
+type openaiMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +// SerializeMessages converts internal Message structs to the OpenAI wire format. +// - Strips SystemParts (unknown to third-party endpoints) +// - Converts messages with Media to multipart content format (text + image_url parts) +// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages +func SerializeMessages(messages []Message) []any { + out := make([]any, 0, len(messages)) + for _, m := range messages { + if len(m.Media) == 0 { + out = append(out, openaiMessage{ + Role: m.Role, + Content: m.Content, + ReasoningContent: m.ReasoningContent, + ToolCalls: m.ToolCalls, + ToolCallID: m.ToolCallID, + }) + continue + } + + // Multipart content format for messages with media + parts := make([]map[string]any, 0, 1+len(m.Media)) + if m.Content != "" { + parts = append(parts, map[string]any{ + "type": "text", + "text": m.Content, + }) + } + for _, mediaURL := range m.Media { + if strings.HasPrefix(mediaURL, "data:image/") { + parts = append(parts, map[string]any{ + "type": "image_url", + "image_url": map[string]any{ + "url": mediaURL, + }, + }) + } + } + + msg := map[string]any{ + "role": m.Role, + "content": parts, + } + if m.ToolCallID != "" { + msg["tool_call_id"] = m.ToolCallID + } + if len(m.ToolCalls) > 0 { + msg["tool_calls"] = m.ToolCalls + } + if m.ReasoningContent != "" { + msg["reasoning_content"] = m.ReasoningContent + } + out = append(out, msg) + } + return out +} + +// --- Response parsing --- + +// ParseResponse parses a JSON chat completion response body into an LLMResponse. 
+func ParseResponse(body io.Reader) (*LLMResponse, error) { + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` + Reasoning string `json:"reasoning"` + ReasoningDetails []ReasoningDetail `json:"reasoning_details"` + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function *struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` + } `json:"function"` + ExtraContent *struct { + Google *struct { + ThoughtSignature string `json:"thought_signature"` + } `json:"google"` + } `json:"extra_content"` + } `json:"tool_calls"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage *UsageInfo `json:"usage"` + } + + if err := json.NewDecoder(body).Decode(&apiResponse); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(apiResponse.Choices) == 0 { + return &LLMResponse{ + Content: "", + FinishReason: "stop", + }, nil + } + + choice := apiResponse.Choices[0] + toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) + for _, tc := range choice.Message.ToolCalls { + arguments := make(map[string]any) + name := "" + + // Extract thought_signature from Gemini/Google-specific extra content + thoughtSignature := "" + if tc.ExtraContent != nil && tc.ExtraContent.Google != nil { + thoughtSignature = tc.ExtraContent.Google.ThoughtSignature + } + + if tc.Function != nil { + name = tc.Function.Name + arguments = DecodeToolCallArguments(tc.Function.Arguments, name) + } + + toolCall := ToolCall{ + ID: tc.ID, + Name: name, + Arguments: arguments, + ThoughtSignature: thoughtSignature, + } + + if thoughtSignature != "" { + toolCall.ExtraContent = &ExtraContent{ + Google: &GoogleExtra{ + ThoughtSignature: thoughtSignature, + }, + } + } + + toolCalls = append(toolCalls, toolCall) + } + + return &LLMResponse{ + Content: choice.Message.Content, + 
ReasoningContent: choice.Message.ReasoningContent, + Reasoning: choice.Message.Reasoning, + ReasoningDetails: choice.Message.ReasoningDetails, + ToolCalls: toolCalls, + FinishReason: choice.FinishReason, + Usage: apiResponse.Usage, + }, nil +} + +// DecodeToolCallArguments decodes a tool call's arguments from raw JSON. +func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any { + arguments := make(map[string]any) + raw = bytes.TrimSpace(raw) + if len(raw) == 0 || bytes.Equal(raw, []byte("null")) { + return arguments + } + + var decoded any + if err := json.Unmarshal(raw, &decoded); err != nil { + log.Printf("common: failed to decode tool call arguments payload for %q: %v", name, err) + arguments["raw"] = string(raw) + return arguments + } + + switch v := decoded.(type) { + case string: + if strings.TrimSpace(v) == "" { + return arguments + } + if err := json.Unmarshal([]byte(v), &arguments); err != nil { + log.Printf("common: failed to decode tool call arguments for %q: %v", name, err) + arguments["raw"] = v + } + return arguments + case map[string]any: + return v + default: + log.Printf("common: unsupported tool call arguments type for %q: %T", name, decoded) + arguments["raw"] = string(raw) + return arguments + } +} + +// --- HTTP response helpers --- + +// HandleErrorResponse reads a non-200 response body and returns an appropriate error. 
+func HandleErrorResponse(resp *http.Response, apiBase string) error { + contentType := resp.Header.Get("Content-Type") + body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256)) + if readErr != nil { + return fmt.Errorf("failed to read response: %w", readErr) + } + if LooksLikeHTML(body, contentType) { + return WrapHTMLResponseError(resp.StatusCode, body, contentType, apiBase) + } + return fmt.Errorf( + "API request failed:\n Status: %d\n Body: %s", + resp.StatusCode, + ResponsePreview(body, 128), + ) +} + +// ReadAndParseResponse peeks at the response body to detect HTML errors, +// then parses the JSON response into an LLMResponse. +func ReadAndParseResponse(resp *http.Response, apiBase string) (*LLMResponse, error) { + contentType := resp.Header.Get("Content-Type") + reader := bufio.NewReader(resp.Body) + prefix, err := reader.Peek(256) + if err != nil && err != io.EOF && err != bufio.ErrBufferFull { + return nil, fmt.Errorf("failed to inspect response: %w", err) + } + if LooksLikeHTML(prefix, contentType) { + return nil, WrapHTMLResponseError(resp.StatusCode, prefix, contentType, apiBase) + } + out, err := ParseResponse(reader) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON response: %w", err) + } + return out, nil +} + +// LooksLikeHTML checks if the response body appears to be HTML. +func LooksLikeHTML(body []byte, contentType string) bool { + contentType = strings.ToLower(strings.TrimSpace(contentType)) + if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") { + return true + } + prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128)) + return bytes.HasPrefix(prefix, []byte("" + } + if len(trimmed) <= maxLen { + return string(trimmed) + } + return string(trimmed[:maxLen]) + "..." 
+} + +func leadingTrimmedPrefix(body []byte, maxLen int) []byte { + i := 0 + for i < len(body) { + switch body[i] { + case ' ', '\t', '\n', '\r', '\f', '\v': + i++ + default: + end := i + maxLen + if end > len(body) { + end = len(body) + } + return body[i:end] + } + } + return nil +} + +// --- Numeric helpers --- + +// AsInt converts various numeric types to int. +func AsInt(v any) (int, bool) { + switch val := v.(type) { + case int: + return val, true + case int64: + return int(val), true + case float64: + return int(val), true + case float32: + return int(val), true + default: + return 0, false + } +} + +// AsFloat converts various numeric types to float64. +func AsFloat(v any) (float64, bool) { + switch val := v.(type) { + case float64: + return val, true + case float32: + return float64(val), true + case int: + return float64(val), true + case int64: + return float64(val), true + default: + return 0, false + } +} diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go new file mode 100644 index 0000000000..bb7e7434d0 --- /dev/null +++ b/pkg/providers/common/common_test.go @@ -0,0 +1,558 @@ +package common + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +// --- NewHTTPClient tests --- + +func TestNewHTTPClient_DefaultTimeout(t *testing.T) { + client := NewHTTPClient("") + if client.Timeout != DefaultRequestTimeout { + t.Errorf("timeout = %v, want %v", client.Timeout, DefaultRequestTimeout) + } +} + +func TestNewHTTPClient_WithProxy(t *testing.T) { + client := NewHTTPClient("http://127.0.0.1:8080") + transport, ok := client.Transport.(*http.Transport) + if !ok || transport == nil { + t.Fatalf("expected http.Transport with proxy, got %T", client.Transport) + } + req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}} + gotProxy, err := transport.Proxy(req) + if err != nil { + t.Fatalf("proxy 
function error: %v", err) + } + if gotProxy == nil || gotProxy.String() != "http://127.0.0.1:8080" { + t.Errorf("proxy = %v, want http://127.0.0.1:8080", gotProxy) + } +} + +func TestNewHTTPClient_NoProxy(t *testing.T) { + client := NewHTTPClient("") + if client.Transport != nil { + t.Errorf("expected nil transport without proxy, got %T", client.Transport) + } +} + +func TestNewHTTPClient_InvalidProxy(t *testing.T) { + // Should not panic, just log and return client without proxy + client := NewHTTPClient("://bad-url") + if client == nil { + t.Fatal("expected non-nil client even with invalid proxy") + } +} + +// --- SerializeMessages tests --- + +func TestSerializeMessages_PlainText(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi", ReasoningContent: "thinking..."}, + } + result := SerializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + if msgs[0]["content"] != "hello" { + t.Errorf("expected plain string content, got %v", msgs[0]["content"]) + } + if msgs[1]["reasoning_content"] != "thinking..." 
{ + t.Errorf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"]) + } +} + +func TestSerializeMessages_WithMedia(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}}, + } + result := SerializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + content, ok := msgs[0]["content"].([]any) + if !ok { + t.Fatalf("expected array content for media message, got %T", msgs[0]["content"]) + } + if len(content) != 2 { + t.Fatalf("expected 2 content parts, got %d", len(content)) + } +} + +func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { + messages := []Message{ + {Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"}, + } + result := SerializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + if msgs[0]["tool_call_id"] != "call_1" { + t.Errorf("tool_call_id not preserved, got %v", msgs[0]["tool_call_id"]) + } +} + +func TestSerializeMessages_StripsSystemParts(t *testing.T) { + messages := []Message{ + { + Role: "system", + Content: "you are helpful", + SystemParts: []protocoltypes.ContentBlock{ + {Type: "text", Text: "you are helpful"}, + }, + }, + } + result := SerializeMessages(messages) + + data, _ := json.Marshal(result) + if strings.Contains(string(data), "system_parts") { + t.Error("system_parts should not appear in serialized output") + } +} + +// --- ParseResponse tests --- + +func TestParseResponse_BasicContent(t *testing.T) { + body := `{"choices":[{"message":{"content":"hello world"},"finish_reason":"stop"}]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if out.Content != "hello world" { + t.Errorf("Content = %q, want %q", out.Content, "hello world") + } + if out.FinishReason != "stop" { + 
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop") + } +} + +func TestParseResponse_EmptyChoices(t *testing.T) { + body := `{"choices":[]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if out.Content != "" { + t.Errorf("Content = %q, want empty", out.Content) + } + if out.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop") + } +} + +func TestParseResponse_WithToolCalls(t *testing.T) { + body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"SF\"}"}}]},"finish_reason":"tool_calls"}]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].Name != "get_weather" { + t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather") + } + if out.ToolCalls[0].Arguments["city"] != "SF" { + t.Errorf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"]) + } +} + +func TestParseResponse_WithUsage(t *testing.T) { + body := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if out.Usage == nil { + t.Fatal("Usage is nil") + } + if out.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", out.Usage.PromptTokens) + } +} + +func TestParseResponse_WithReasoningContent(t *testing.T) { + body := `{"choices":[{"message":{"content":"2","reasoning_content":"Let me think... 
1+1=2"},"finish_reason":"stop"}]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if out.ReasoningContent != "Let me think... 1+1=2" { + t.Errorf("ReasoningContent = %q, want %q", out.ReasoningContent, "Let me think... 1+1=2") + } +} + +func TestParseResponse_InvalidJSON(t *testing.T) { + _, err := ParseResponse(strings.NewReader("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +// --- DecodeToolCallArguments tests --- + +func TestDecodeToolCallArguments_ObjectJSON(t *testing.T) { + raw := json.RawMessage(`{"city":"Seattle","units":"metric"}`) + args := DecodeToolCallArguments(raw, "test") + if args["city"] != "Seattle" { + t.Errorf("city = %v, want Seattle", args["city"]) + } + if args["units"] != "metric" { + t.Errorf("units = %v, want metric", args["units"]) + } +} + +func TestDecodeToolCallArguments_StringJSON(t *testing.T) { + raw := json.RawMessage(`"{\"city\":\"SF\"}"`) + args := DecodeToolCallArguments(raw, "test") + if args["city"] != "SF" { + t.Errorf("city = %v, want SF", args["city"]) + } +} + +func TestDecodeToolCallArguments_EmptyInput(t *testing.T) { + args := DecodeToolCallArguments(nil, "test") + if len(args) != 0 { + t.Errorf("expected empty map, got %v", args) + } +} + +func TestDecodeToolCallArguments_NullInput(t *testing.T) { + args := DecodeToolCallArguments(json.RawMessage(`null`), "test") + if len(args) != 0 { + t.Errorf("expected empty map, got %v", args) + } +} + +func TestDecodeToolCallArguments_InvalidJSON(t *testing.T) { + args := DecodeToolCallArguments(json.RawMessage(`not-json`), "test") + if _, ok := args["raw"]; !ok { + t.Error("expected 'raw' fallback key for invalid JSON") + } +} + +func TestDecodeToolCallArguments_EmptyStringJSON(t *testing.T) { + args := DecodeToolCallArguments(json.RawMessage(`" "`), "test") + if len(args) != 0 { + t.Errorf("expected empty map for whitespace string, got %v", args) + } +} + +// --- 
HandleErrorResponse tests --- + +func TestHandleErrorResponse_JSONError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"bad request"}`)) + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + err = HandleErrorResponse(resp, server.URL) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "400") { + t.Errorf("error should contain status code, got %v", err) + } + if strings.Contains(err.Error(), "HTML") { + t.Errorf("should not mention HTML for JSON error, got %v", err) + } +} + +func TestHandleErrorResponse_HTMLError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusBadGateway) + w.Write([]byte("bad gateway")) + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + err = HandleErrorResponse(resp, server.URL) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "HTML instead of JSON") { + t.Errorf("expected HTML error message, got %v", err) + } +} + +// --- ReadAndParseResponse tests --- + +func TestReadAndParseResponse_ValidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`)) + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + out, err := ReadAndParseResponse(resp, server.URL) + if err 
!= nil { + t.Fatalf("ReadAndParseResponse() error = %v", err) + } + if out.Content != "ok" { + t.Errorf("Content = %q, want %q", out.Content, "ok") + } +} + +func TestReadAndParseResponse_HTMLResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Write([]byte("login page")) + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + _, err = ReadAndParseResponse(resp, server.URL) + if err == nil { + t.Fatal("expected error for HTML response") + } + if !strings.Contains(err.Error(), "HTML instead of JSON") { + t.Errorf("expected HTML error, got %v", err) + } +} + +// --- LooksLikeHTML tests --- + +func TestLooksLikeHTML_ContentTypeHTML(t *testing.T) { + if !LooksLikeHTML(nil, "text/html; charset=utf-8") { + t.Error("expected true for text/html content type") + } +} + +func TestLooksLikeHTML_ContentTypeXHTML(t *testing.T) { + if !LooksLikeHTML(nil, "application/xhtml+xml") { + t.Error("expected true for xhtml content type") + } +} + +func TestLooksLikeHTML_BodyPrefix(t *testing.T) { + tests := []struct { + name string + body string + }{ + {"doctype", ""}, + {"html tag", ""}, + {"head tag", ""}, + {"body tag", "<body>content"}, + {"whitespace before", " \n\t<!DOCTYPE html>"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !LooksLikeHTML([]byte(tt.body), "application/json") { + t.Errorf("expected true for body %q", tt.body) + } + }) + } +} + +func TestLooksLikeHTML_NotHTML(t *testing.T) { + if LooksLikeHTML([]byte(`{"error":"bad"}`), "application/json") { + t.Error("expected false for JSON body") + } +} + +// --- ResponsePreview tests --- + +func TestResponsePreview_Short(t *testing.T) { + got := ResponsePreview([]byte("hello"), 128) + if got != "hello" { + t.Errorf("got %q, want %q", got, "hello") + } +} + +func 
TestResponsePreview_Truncated(t *testing.T) { + body := strings.Repeat("a", 200) + got := ResponsePreview([]byte(body), 128) + if len(got) != 131 { // 128 + "..." + t.Errorf("len = %d, want 131", len(got)) + } + if !strings.HasSuffix(got, "...") { + t.Error("expected ... suffix") + } +} + +func TestResponsePreview_Empty(t *testing.T) { + got := ResponsePreview([]byte(""), 128) + if got != "<empty>" { + t.Errorf("got %q, want %q", got, "<empty>") + } +} + +func TestResponsePreview_Whitespace(t *testing.T) { + got := ResponsePreview([]byte(" \n\t "), 128) + if got != "<empty>" { + t.Errorf("got %q, want %q for whitespace-only body", got, "<empty>") + } +} + +// --- AsInt tests --- + +func TestAsInt(t *testing.T) { + tests := []struct { + name string + val any + want int + ok bool + }{ + {"int", 42, 42, true}, + {"int64", int64(99), 99, true}, + {"float64", float64(512), 512, true}, + {"float32", float32(256), 256, true}, + {"string", "nope", 0, false}, + {"nil", nil, 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := AsInt(tt.val) + if ok != tt.ok || got != tt.want { + t.Errorf("AsInt(%v) = (%d, %v), want (%d, %v)", tt.val, got, ok, tt.want, tt.ok) + } + }) + } +} + +// --- AsFloat tests --- + +func TestAsFloat(t *testing.T) { + tests := []struct { + name string + val any + want float64 + ok bool + }{ + {"float64", float64(0.7), 0.7, true}, + {"float32", float32(0.5), float64(float32(0.5)), true}, + {"int", 1, 1.0, true}, + {"int64", int64(100), 100.0, true}, + {"string", "nope", 0, false}, + {"nil", nil, 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := AsFloat(tt.val) + if ok != tt.ok || got != tt.want { + t.Errorf("AsFloat(%v) = (%f, %v), want (%f, %v)", tt.val, got, ok, tt.want, tt.ok) + } + }) + } +} + +// --- WrapHTMLResponseError tests --- + +func TestWrapHTMLResponseError(t *testing.T) { + err := WrapHTMLResponseError(502, []byte("<html>bad</html>"), "text/html", 
"https://api.example.com") + if err == nil { + t.Fatal("expected error") + } + msg := err.Error() + if !strings.Contains(msg, "502") { + t.Errorf("expected status code in error, got %v", msg) + } + if !strings.Contains(msg, "https://api.example.com") { + t.Errorf("expected api base in error, got %v", msg) + } + if !strings.Contains(msg, "HTML instead of JSON") { + t.Errorf("expected HTML mention in error, got %v", msg) + } +} + +// --- HandleErrorResponse with read failure --- + +func TestHandleErrorResponse_EmptyBody(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + // empty body + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + err = HandleErrorResponse(resp, server.URL) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "500") { + t.Errorf("expected status code, got %v", err) + } +} + +// --- ReadAndParseResponse with invalid JSON --- + +func TestReadAndParseResponse_InvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("not valid json")) + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + _, err = ReadAndParseResponse(resp, server.URL) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +// --- ParseResponse with thought_signature (Google/Gemini) --- + +func TestParseResponse_WithThoughtSignature(t *testing.T) { + body := 
`{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"google":{"thought_signature":"sig123"}}}]},"finish_reason":"tool_calls"}]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].ThoughtSignature != "sig123" { + t.Errorf("ThoughtSignature = %q, want %q", out.ToolCalls[0].ThoughtSignature, "sig123") + } + if out.ToolCalls[0].ExtraContent == nil || out.ToolCalls[0].ExtraContent.Google == nil { + t.Fatal("ExtraContent.Google is nil") + } + if out.ToolCalls[0].ExtraContent.Google.ThoughtSignature != "sig123" { + t.Errorf("ExtraContent.Google.ThoughtSignature = %q, want %q", + out.ToolCalls[0].ExtraContent.Google.ThoughtSignature, "sig123") + } +} diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index e99e07bc26..b7567f9fcf 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -11,6 +11,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages" + "github.com/sipeed/picoclaw/pkg/providers/azure" ) // createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store. @@ -94,6 +95,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.RequestTimeout, ), modelID, nil + case "azure", "azure-openai": + // Azure OpenAI uses deployment-based URLs, api-key header auth, + // and always sends max_completion_tokens. 
+ if cfg.APIKey == "" { + return nil, "", fmt.Errorf("api_key is required for azure protocol") + } + if cfg.APIBase == "" { + return nil, "", fmt.Errorf( + "api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)", + ) + } + return azure.NewProviderWithTimeout( + cfg.APIKey, + cfg.APIBase, + cfg.Proxy, + cfg.RequestTimeout, + ), modelID, nil + case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", "vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian", diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index 00676ebf98..b678a7eb61 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -64,6 +64,12 @@ func TestExtractProtocol(t *testing.T) { wantProtocol: "nvidia", wantModelID: "meta/llama-3.1-8b", }, + { + name: "azure with prefix", + model: "azure/my-gpt5-deployment", + wantProtocol: "azure", + wantModelID: "my-gpt5-deployment", + }, } for _, tt := range tests { @@ -371,3 +377,69 @@ func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) { t.Fatalf("Chat() error = %q, want timeout-related error", errMsg) } } + +func TestCreateProviderFromConfig_Azure(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "azure-gpt5", + Model: "azure/my-gpt5-deployment", + APIKey: "test-azure-key", + APIBase: "https://my-resource.openai.azure.com", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "my-gpt5-deployment" { + t.Errorf("modelID = %q, want %q", modelID, "my-gpt5-deployment") + } +} + +func TestCreateProviderFromConfig_AzureOpenAIAlias(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "azure-gpt4", + Model: "azure-openai/my-deployment", + APIKey: 
"test-azure-key", + APIBase: "https://my-resource.openai.azure.com", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "my-deployment" { + t.Errorf("modelID = %q, want %q", modelID, "my-deployment") + } +} + +func TestCreateProviderFromConfig_AzureMissingAPIKey(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "azure-gpt5", + Model: "azure/my-gpt5-deployment", + APIBase: "https://my-resource.openai.azure.com", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for missing API key") + } +} + +func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "azure-gpt5", + Model: "azure/my-gpt5-deployment", + APIKey: "test-azure-key", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for missing API base") + } +} diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index f97bf3acd5..fb2abaa5c2 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -1,18 +1,16 @@ package openai_compat import ( - "bufio" "bytes" "context" "encoding/json" "fmt" - "io" - "log" "net/http" "net/url" "strings" "time" + "github.com/sipeed/picoclaw/pkg/providers/common" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) @@ -38,7 +36,7 @@ type Provider struct { type Option func(*Provider) -const defaultRequestTimeout = 120 * time.Second +const defaultRequestTimeout = common.DefaultRequestTimeout func WithMaxTokensField(maxTokensField string) Option { return func(p *Provider) { @@ -55,25 +53,10 @@ func WithRequestTimeout(timeout time.Duration) Option { } func NewProvider(apiKey, apiBase, proxy string, opts 
...Option) *Provider { - client := &http.Client{ - Timeout: defaultRequestTimeout, - } - - if proxy != "" { - parsed, err := url.Parse(proxy) - if err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(parsed), - } - } else { - log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err) - } - } - p := &Provider{ apiKey: apiKey, apiBase: strings.TrimRight(apiBase, "/"), - httpClient: client, + httpClient: common.NewHTTPClient(proxy), } for _, opt := range opts { @@ -117,7 +100,7 @@ func (p *Provider) Chat( requestBody := map[string]any{ "model": model, - "messages": serializeMessages(messages), + "messages": common.SerializeMessages(messages), } if len(tools) > 0 { @@ -125,7 +108,7 @@ func (p *Provider) Chat( requestBody["tool_choice"] = "auto" } - if maxTokens, ok := asInt(options["max_tokens"]); ok { + if maxTokens, ok := common.AsInt(options["max_tokens"]); ok { // Use configured maxTokensField if specified, otherwise fallback to model-based detection fieldName := p.maxTokensField if fieldName == "" { @@ -141,7 +124,7 @@ func (p *Provider) Chat( requestBody[fieldName] = maxTokens } - if temperature, ok := asFloat(options["temperature"]); ok { + if temperature, ok := common.AsFloat(options["temperature"]); ok { lowerModel := strings.ToLower(model) // Kimi k2 models only support temperature=1. if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { @@ -185,275 +168,11 @@ func (p *Provider) Chat( } defer resp.Body.Close() - contentType := resp.Header.Get("Content-Type") - - // Non-200: read a prefix to tell HTML error page apart from JSON error body. 
if resp.StatusCode != http.StatusOK { - body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256)) - if readErr != nil { - return nil, fmt.Errorf("failed to read response: %w", readErr) - } - if looksLikeHTML(body, contentType) { - return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase) - } - return nil, fmt.Errorf( - "API request failed:\n Status: %d\n Body: %s", - resp.StatusCode, - responsePreview(body, 128), - ) - } - - // Peek without consuming so the full stream reaches the JSON decoder. - reader := bufio.NewReader(resp.Body) - prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort - if err != nil && err != io.EOF && err != bufio.ErrBufferFull { - return nil, fmt.Errorf("failed to inspect response: %w", err) - } - if looksLikeHTML(prefix, contentType) { - return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase) - } - - out, err := parseResponse(reader) - if err != nil { - return nil, fmt.Errorf("failed to parse JSON response: %w", err) + return nil, common.HandleErrorResponse(resp, p.apiBase) } - return out, nil -} - -func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error { - respPreview := responsePreview(body, 128) - return fmt.Errorf( - "API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s", - apiBase, - contentType, - statusCode, - respPreview, - ) -} - -func looksLikeHTML(body []byte, contentType string) bool { - contentType = strings.ToLower(strings.TrimSpace(contentType)) - if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") { - return true - } - prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128)) - return bytes.HasPrefix(prefix, []byte("<!doctype html")) || - bytes.HasPrefix(prefix, []byte("<html")) || - bytes.HasPrefix(prefix, []byte("<head")) || - bytes.HasPrefix(prefix, []byte("<body")) -} 
- -func leadingTrimmedPrefix(body []byte, maxLen int) []byte { - i := 0 - for i < len(body) { - switch body[i] { - case ' ', '\t', '\n', '\r', '\f', '\v': - i++ - default: - end := i + maxLen - if end > len(body) { - end = len(body) - } - return body[i:end] - } - } - return nil -} - -func responsePreview(body []byte, maxLen int) string { - trimmed := bytes.TrimSpace(body) - if len(trimmed) == 0 { - return "<empty>" - } - if len(trimmed) <= maxLen { - return string(trimmed) - } - return string(trimmed[:maxLen]) + "..." -} - -func parseResponse(body io.Reader) (*LLMResponse, error) { - var apiResponse struct { - Choices []struct { - Message struct { - Content string `json:"content"` - ReasoningContent string `json:"reasoning_content"` - Reasoning string `json:"reasoning"` - ReasoningDetails []ReasoningDetail `json:"reasoning_details"` - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function *struct { - Name string `json:"name"` - Arguments json.RawMessage `json:"arguments"` - } `json:"function"` - ExtraContent *struct { - Google *struct { - ThoughtSignature string `json:"thought_signature"` - } `json:"google"` - } `json:"extra_content"` - } `json:"tool_calls"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage *UsageInfo `json:"usage"` - } - - if err := json.NewDecoder(body).Decode(&apiResponse); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) - } - - if len(apiResponse.Choices) == 0 { - return &LLMResponse{ - Content: "", - FinishReason: "stop", - }, nil - } - - choice := apiResponse.Choices[0] - toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) - for _, tc := range choice.Message.ToolCalls { - arguments := make(map[string]any) - name := "" - - // Extract thought_signature from Gemini/Google-specific extra content - thoughtSignature := "" - if tc.ExtraContent != nil && tc.ExtraContent.Google != nil { - thoughtSignature = 
tc.ExtraContent.Google.ThoughtSignature - } - - if tc.Function != nil { - name = tc.Function.Name - arguments = decodeToolCallArguments(tc.Function.Arguments, name) - } - - // Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence - toolCall := ToolCall{ - ID: tc.ID, - Name: name, - Arguments: arguments, - ThoughtSignature: thoughtSignature, - } - - if thoughtSignature != "" { - toolCall.ExtraContent = &ExtraContent{ - Google: &GoogleExtra{ - ThoughtSignature: thoughtSignature, - }, - } - } - - toolCalls = append(toolCalls, toolCall) - } - - return &LLMResponse{ - Content: choice.Message.Content, - ReasoningContent: choice.Message.ReasoningContent, - Reasoning: choice.Message.Reasoning, - ReasoningDetails: choice.Message.ReasoningDetails, - ToolCalls: toolCalls, - FinishReason: choice.FinishReason, - Usage: apiResponse.Usage, - }, nil -} - -func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any { - arguments := make(map[string]any) - raw = bytes.TrimSpace(raw) - if len(raw) == 0 || bytes.Equal(raw, []byte("null")) { - return arguments - } - - var decoded any - if err := json.Unmarshal(raw, &decoded); err != nil { - log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err) - arguments["raw"] = string(raw) - return arguments - } - - switch v := decoded.(type) { - case string: - if strings.TrimSpace(v) == "" { - return arguments - } - if err := json.Unmarshal([]byte(v), &arguments); err != nil { - log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err) - arguments["raw"] = v - } - return arguments - case map[string]any: - return v - default: - log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded) - arguments["raw"] = string(raw) - return arguments - } -} - -// openaiMessage is the wire-format message for OpenAI-compatible APIs. 
-// It mirrors protocoltypes.Message but omits SystemParts, which is an -// internal field that would be unknown to third-party endpoints. -type openaiMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ReasoningContent string `json:"reasoning_content,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` -} - -// serializeMessages converts internal Message structs to the OpenAI wire format. -// - Strips SystemParts (unknown to third-party endpoints) -// - Converts messages with Media to multipart content format (text + image_url parts) -// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages -func serializeMessages(messages []Message) []any { - out := make([]any, 0, len(messages)) - for _, m := range messages { - if len(m.Media) == 0 { - out = append(out, openaiMessage{ - Role: m.Role, - Content: m.Content, - ReasoningContent: m.ReasoningContent, - ToolCalls: m.ToolCalls, - ToolCallID: m.ToolCallID, - }) - continue - } - - // Multipart content format for messages with media - parts := make([]map[string]any, 0, 1+len(m.Media)) - if m.Content != "" { - parts = append(parts, map[string]any{ - "type": "text", - "text": m.Content, - }) - } - for _, mediaURL := range m.Media { - if strings.HasPrefix(mediaURL, "data:image/") { - parts = append(parts, map[string]any{ - "type": "image_url", - "image_url": map[string]any{ - "url": mediaURL, - }, - }) - } - } - - msg := map[string]any{ - "role": m.Role, - "content": parts, - } - if m.ToolCallID != "" { - msg["tool_call_id"] = m.ToolCallID - } - if len(m.ToolCalls) > 0 { - msg["tool_calls"] = m.ToolCalls - } - if m.ReasoningContent != "" { - msg["reasoning_content"] = m.ReasoningContent - } - out = append(out, msg) - } - return out + return common.ReadAndParseResponse(resp, p.apiBase) } func normalizeModel(model, apiBase string) string { @@ -476,36 +195,6 @@ func normalizeModel(model, apiBase string) string { } } 
-func asInt(v any) (int, bool) { - switch val := v.(type) { - case int: - return val, true - case int64: - return int(val), true - case float64: - return int(val), true - case float32: - return int(val), true - default: - return 0, false - } -} - -func asFloat(v any) (float64, bool) { - switch val := v.(type) { - case float64: - return val, true - case float32: - return float64(val), true - case int: - return float64(val), true - case int64: - return float64(val), true - default: - return 0, false - } -} - // supportsPromptCacheKey reports whether the given API base is known to // support the prompt_cache_key request field. Currently only OpenAI's own // API and Azure OpenAI support this. All other OpenAI-compatible providers diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 41f278a1b1..ed9747f9d7 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/sipeed/picoclaw/pkg/providers/common" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) @@ -648,7 +649,7 @@ func TestSerializeMessages_PlainText(t *testing.T) { {Role: "user", Content: "hello"}, {Role: "assistant", Content: "hi", ReasoningContent: "thinking..."}, } - result := serializeMessages(messages) + result := common.SerializeMessages(messages) data, err := json.Marshal(result) if err != nil { @@ -670,7 +671,7 @@ func TestSerializeMessages_WithMedia(t *testing.T) { messages := []protocoltypes.Message{ {Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}}, } - result := serializeMessages(messages) + result := common.SerializeMessages(messages) data, _ := json.Marshal(result) var msgs []map[string]any @@ -703,7 +704,7 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { messages := []protocoltypes.Message{ {Role: "tool", Content: "image result", Media: 
[]string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"}, } - result := serializeMessages(messages) + result := common.SerializeMessages(messages) data, _ := json.Marshal(result) var msgs []map[string]any @@ -833,7 +834,7 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) { }, }, } - result := serializeMessages(messages) + result := common.SerializeMessages(messages) data, _ := json.Marshal(result) raw := string(data)