diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000000..559a2249ed
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,27 @@
+version: 2
+
+updates:
+
+ # Go dependencies (entire repo)
+ - package-ecosystem: "gomod"
+ directory: "/"
+ schedule:
+ interval: "weekly"
+ labels:
+ - "dependencies"
+ - "go"
+
+ # Frontend dependencies
+ - package-ecosystem: "npm"
+ directory: "/web/frontend"
+ schedule:
+ interval: "weekly"
+ labels:
+ - "dependencies"
+ - "frontend"
+
+ # GitHub Actions
+ - package-ecosystem: "github-actions"
+ directory: "/"
+ schedule:
+ interval: "weekly"
\ No newline at end of file
diff --git a/README.fr.md b/README.fr.md
index d5fe873bf6..49a02fb771 100644
--- a/README.fr.md
+++ b/README.fr.md
@@ -991,6 +991,7 @@ Cette conception permet également le **support multi-agent** avec une sélectio
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obtenir Clé](https://www.byteplus.com/) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obtenir une clé](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obtenir un Token](https://modelscope.cn/my/tokens) |
+| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obtenir Clé](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth uniquement |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
diff --git a/README.ja.md b/README.ja.md
index 7fff46d13e..c0d27de4f4 100644
--- a/README.ja.md
+++ b/README.ja.md
@@ -935,6 +935,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [キーを取得](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [キーを取得](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [トークンを取得](https://modelscope.cn/my/tokens) |
+| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [キーを取得](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | カスタム | OAuthのみ |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
diff --git a/README.md b/README.md
index e64daf0e4f..159ac706f4 100644
--- a/README.md
+++ b/README.md
@@ -1006,6 +1006,7 @@ The subagent has access to tools (message, web_search, etc.) and can communicate
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
| `cerebras` | LLM (Cerebras direct) | [cerebras.ai](https://cerebras.ai) |
| `vivgrid` | LLM (Vivgrid direct) | [vivgrid.com](https://vivgrid.com) |
+| `azure` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
### Model Configuration (model_list)
@@ -1042,6 +1043,7 @@ This design also enables **multi-agent support** with flexible provider selectio
| **Vivgrid** | `vivgrid/` | `https://api.vivgrid.com/v1` | OpenAI | [Get Key](https://vivgrid.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Get Key](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Get Token](https://modelscope.cn/my/tokens) |
+| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Get Key](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
diff --git a/README.pt-br.md b/README.pt-br.md
index 3fe24d7eaf..56946139b0 100644
--- a/README.pt-br.md
+++ b/README.pt-br.md
@@ -987,6 +987,7 @@ Este design também possibilita o **suporte multi-agent** com seleção flexíve
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obter Chave](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obter Chave](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obter Token](https://modelscope.cn/my/tokens) |
+| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obter Chave](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Custom | Apenas OAuth |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
diff --git a/README.vi.md b/README.vi.md
index 3ee0209f6c..a542d6507f 100644
--- a/README.vi.md
+++ b/README.vi.md
@@ -956,6 +956,7 @@ Thiết kế này cũng cho phép **hỗ trợ đa tác nhân** với lựa ch
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Lấy Khóa](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Lấy Key](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Lấy Token](https://modelscope.cn/my/tokens) |
+| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Lấy Khóa](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Tùy chỉnh | Chỉ OAuth |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
diff --git a/README.zh.md b/README.zh.md
index 66d7c5f7cc..9877ef9f4e 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -528,6 +528,7 @@ Agent 读取 HEARTBEAT.md
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [获取密钥](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [获取密钥](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [获取 Token](https://modelscope.cn/my/tokens) |
+| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [获取密钥](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | 自定义 | 仅 OAuth |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
diff --git a/config/config.example.json b/config/config.example.json
index 094aa46df2..1c11cd42a9 100644
--- a/config/config.example.json
+++ b/config/config.example.json
@@ -53,6 +53,12 @@
"api_key": "your-modelscope-access-token",
"api_base": "https://api-inference.modelscope.cn/v1"
},
+ {
+ "model_name": "azure-gpt5",
+ "model": "azure/my-gpt5-deployment",
+ "api_key": "your-azure-api-key",
+ "api_base": "https://your-resource.openai.azure.com"
+ },
{
"model_name": "loadbalanced-gpt-5.4",
"model": "openai/gpt-5.4",
diff --git a/docs/design/steering-spec.md b/docs/design/steering-spec.md
new file mode 100644
index 0000000000..0951bf864e
--- /dev/null
+++ b/docs/design/steering-spec.md
@@ -0,0 +1,306 @@
+# Steering — Implementation Specification
+
+## Problem
+
+When the agent is running (executing a chain of tool calls), the user has no way to redirect it. They must wait for the full cycle to complete before sending a new message. This creates a poor experience when the agent takes a wrong direction — the user watches it waste time on tools that are no longer relevant.
+
+## Solution
+
+Steering introduces a **message queue** that external callers can push into at any time. The agent loop polls this queue at well-defined checkpoints. When a steering message is found, the agent:
+
+1. Stops executing further tools in the current batch
+2. Injects the user's message into the conversation context
+3. Calls the LLM again with the updated context
+
+The user's intent reaches the model **as soon as the current tool finishes**, not after the entire turn completes.
+
+## Architecture Overview
+
+```mermaid
+graph TD
+ subgraph External Callers
+ TG[Telegram]
+ DC[Discord]
+ SL[Slack]
+ end
+
+ subgraph AgentLoop
+ BUS[MessageBus]
+ DRAIN[drainBusToSteering goroutine]
+ SQ[steeringQueue]
+ RLI[runLLMIteration]
+ TE[Tool Execution Loop]
+ LLM[LLM Call]
+ end
+
+ TG -->|PublishInbound| BUS
+ DC -->|PublishInbound| BUS
+ SL -->|PublishInbound| BUS
+
+ BUS -->|ConsumeInbound while busy| DRAIN
+ DRAIN -->|Steer| SQ
+
+ RLI -->|1. initial poll| SQ
+ TE -->|2. poll after each tool| SQ
+
+ SQ -->|pendingMessages| RLI
+ RLI -->|inject into context| LLM
+```
+
+### Bus drain mechanism
+
+Channels (Telegram, Discord, etc.) publish messages to the `MessageBus` via `PublishInbound`. Without additional wiring, these messages would sit in the bus buffer until the current `processMessage` finishes — meaning steering would never work for real users.
+
+The solution: when `Run()` starts processing a message, it spawns a **drain goroutine** (`drainBusToSteering`) that keeps consuming from the bus and calling `Steer()`. When `processMessage` returns, the drain is canceled and normal consumption resumes.
+
+```mermaid
+sequenceDiagram
+ participant Bus
+ participant Run
+ participant Drain
+ participant AgentLoop
+
+ Run->>Bus: ConsumeInbound() → msg
+ Run->>Drain: spawn drainBusToSteering(ctx)
+ Run->>Run: processMessage(msg)
+
+ Note over Drain: running concurrently
+
+ Bus-->>Drain: ConsumeInbound() → newMsg
+ Drain->>AgentLoop: al.transcribeAudioInMessage(ctx, newMsg)
+ Drain->>AgentLoop: Steer(providers.Message{Content: newMsg.Content})
+
+ Run->>Run: processMessage returns
+ Run->>Drain: cancel context
+ Note over Drain: exits
+```
+
+## Data Structures
+
+### steeringQueue
+
+A thread-safe FIFO queue, private to the `agent` package.
+
+| Field | Type | Description |
+|-------|------|-------------|
+| `mu` | `sync.Mutex` | Protects all access to `queue` and `mode` |
+| `queue` | `[]providers.Message` | Pending steering messages |
+| `mode` | `SteeringMode` | Dequeue strategy |
+
+**Methods:**
+
+| Method | Description |
+|--------|-------------|
+| `push(msg) error` | Appends a message to the queue. Returns an error if the queue is full (`MaxQueueSize`) |
+| `dequeue() []Message` | Removes and returns messages according to `mode`. Returns `nil` if empty |
+| `len() int` | Returns the current queue length |
+| `setMode(mode)` | Updates the dequeue strategy |
+| `getMode() SteeringMode` | Returns the current mode |
+
+### SteeringMode
+
+| Value | Constant | Behavior |
+|-------|----------|----------|
+| `"one-at-a-time"` | `SteeringOneAtATime` | `dequeue()` returns only the **first** message. Remaining messages stay in the queue for subsequent polls. |
+| `"all"` | `SteeringAll` | `dequeue()` drains the **entire** queue and returns all messages at once. |
+
+Default: `"one-at-a-time"`.
+
+### processOptions extension
+
+A new field was added to `processOptions`:
+
+| Field | Type | Description |
+|-------|------|-------------|
+| `SkipInitialSteeringPoll` | `bool` | When `true`, the initial steering poll at loop start is skipped. Used by `Continue()` to avoid double-dequeuing. |
+
+## Public API on AgentLoop
+
+| Method | Signature | Description |
+|--------|-----------|-------------|
+| `Steer` | `Steer(msg providers.Message) error` | Enqueues a steering message. Returns an error if the queue is full or not initialized. Thread-safe, can be called from any goroutine. |
+| `SteeringMode` | `SteeringMode() SteeringMode` | Returns the current dequeue mode. |
+| `SetSteeringMode` | `SetSteeringMode(mode SteeringMode)` | Changes the dequeue mode at runtime. |
+| `Continue` | `Continue(ctx, sessionKey, channel, chatID) (string, error)` | Resumes an idle agent using pending steering messages. Returns `""` if queue is empty. |
+
+## Integration into the Agent Loop
+
+### Where steering is wired
+
+The steering queue lives as a field on `AgentLoop`:
+
+```
+AgentLoop
+ ├── bus
+ ├── cfg
+ ├── registry
+ ├── steering *steeringQueue ← new
+ ├── ...
+```
+
+It is initialized in `NewAgentLoop` from `cfg.Agents.Defaults.SteeringMode`.
+
+### Detailed flow through runLLMIteration
+
+```mermaid
+sequenceDiagram
+ participant User
+ participant AgentLoop
+ participant runLLMIteration
+ participant ToolExecution
+ participant LLM
+
+ User->>AgentLoop: Steer(message)
+ Note over AgentLoop: steeringQueue.push(message)
+
+ Note over runLLMIteration: ── iteration starts ──
+
+    runLLMIteration->>AgentLoop: dequeueSteeringMessages()<br/>[initial poll]
+ AgentLoop-->>runLLMIteration: [] (empty, or messages)
+
+ alt pendingMessages not empty
+        runLLMIteration->>runLLMIteration: inject into messages[]<br/>save to session
+ end
+
+ runLLMIteration->>LLM: Chat(messages, tools)
+ LLM-->>runLLMIteration: response with toolCalls[0..N]
+
+ loop for each tool call (sequential)
+ ToolExecution->>ToolExecution: execute tool[i]
+        ToolExecution->>ToolExecution: process result,<br/>append to messages[]
+
+ ToolExecution->>AgentLoop: dequeueSteeringMessages()
+ AgentLoop-->>ToolExecution: steeringMessages
+
+ alt steering found
+ opt remaining tools > 0
+                Note over ToolExecution: Mark tool[i+1..N-1] as<br/>"Skipped due to queued user message."
+ end
+ Note over ToolExecution: steeringAfterTools = steeringMessages
+ Note over ToolExecution: break out of tool loop
+ end
+ end
+
+ alt steeringAfterTools not empty
+ ToolExecution-->>runLLMIteration: pendingMessages = steeringAfterTools
+        Note over runLLMIteration: next iteration will inject<br/>these before calling LLM
+ end
+
+ Note over runLLMIteration: ── loop back to iteration start ──
+```
+
+### Polling checkpoints
+
+| # | Location | When | Purpose |
+|---|----------|------|---------|
+| 1 | Top of `runLLMIteration`, before first LLM call | Once, at loop entry | Catch messages enqueued while the agent was still setting up context |
+| 2 | After every tool completes (including the first and the last) | Immediately after each tool's result is processed | Interrupt the batch as early as possible — if steering is found and there are remaining tools, they are all skipped |
+
+### What happens to skipped tools
+
+When steering interrupts a tool batch after tool `[i]` completes, all tools from `[i+1]` to `[N-1]` are **not executed**. Instead, a tool result message is generated for each:
+
+```json
+{
+ "role": "tool",
+ "content": "Skipped due to queued user message.",
+ "tool_call_id": ""
+}
+```
+
+These results are:
+- Appended to the conversation `messages[]`
+- Saved to the session via `AddFullMessage`
+
+This ensures the LLM knows which of its requested actions were not performed.
+
+### Loop condition change
+
+The iteration loop condition was changed from:
+
+```go
+for iteration < agent.MaxIterations
+```
+
+to:
+
+```go
+for iteration < agent.MaxIterations || len(pendingMessages) > 0
+```
+
+This allows **one extra iteration** when steering arrives right at the max iteration boundary, ensuring the steering message is always processed.
+
+### Tool execution: parallel → sequential
+
+**Before steering:** all tool calls in a batch were executed in parallel using `sync.WaitGroup`.
+
+**After steering:** tool calls execute **sequentially**. This is required because steering must be polled between individual tool completions. A parallel execution model would not allow interrupting mid-batch.
+
+> **Trade-off:** This introduces latency when the LLM requests multiple independent tools in a single turn. In practice, most batches contain 1-2 tools, so the impact is minimal. The benefit of being able to interrupt outweighs the cost.
+
+### Why skip remaining tools (instead of letting them finish)
+
+Two strategies were considered when a steering message is detected mid-batch:
+
+1. **Skip remaining tools** (chosen) — stop executing, mark the rest as skipped, inject steering
+2. **Finish all tools, then inject** — let everything run, append steering afterwards
+
+Strategy 2 was rejected for three reasons:
+
+**Irreversible side effects.** Tools can send emails, write files, spawn subagents, or call external APIs. If the user says "stop" or "change direction", those actions have already happened and cannot be undone.
+
+| Tool batch | Steering | Skip (1) | Finish (2) |
+|---|---|---|---|
+| `[search, send_email]` | "don't send it" | Email not sent | Email sent |
+| `[query, write_file, spawn]` | "wrong database" | Only query runs | File + subagent wasted |
+| `[fetch₁, fetch₂, fetch₃, write]` | topic change | 1 fetch | 3 fetches + write, all discarded |
+
+**Wasted latency.** Tools like web fetches and API calls take seconds each. In a 3-tool batch averaging 3-4s per tool, the user would wait 10+ seconds for work that gets thrown away.
+
+**The LLM retains full awareness.** Skipped tools receive an explicit `"Skipped due to queued user message."` result, so the model knows what was not done and can decide whether to re-execute with the new context or take a different path.
+
+## The Continue() method
+
+`Continue` handles the case where the agent is **idle** (its last message was from the assistant) and the user has enqueued steering messages in the meantime.
+
+```mermaid
+flowchart TD
+ A[Continue called] --> B{dequeueSteeringMessages}
+ B -->|empty| C["return ('', nil)"]
+ B -->|messages found| D[Combine message contents]
+    D --> E["runAgentLoop with<br/>SkipInitialSteeringPoll: true"]
+ E --> F[Return response]
+```
+
+**Why `SkipInitialSteeringPoll: true`?** Because `Continue` already dequeued the messages itself. Without this flag, `runLLMIteration` would poll again at the start and find nothing (the queue is already empty), or worse, double-process if new messages arrived in the meantime.
+
+## Configuration
+
+```json
+{
+ "agents": {
+ "defaults": {
+ "steering_mode": "one-at-a-time"
+ }
+ }
+}
+```
+
+| Field | Type | Default | Env var |
+|-------|------|---------|---------|
+| `steering_mode` | `string` | `"one-at-a-time"` | `PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE` |
+
+
+## Design decisions and trade-offs
+
+| Decision | Rationale |
+|----------|-----------|
+| Sequential tool execution | Required for per-tool steering polls. Parallel execution cannot be interrupted mid-batch. |
+| Polling-based (not channel/signal) | Keeps the implementation simple. No need for `select` or signal channels. The polling cost is negligible (mutex lock + slice length check). |
+| `one-at-a-time` as default | Gives the model a chance to react to each steering message individually. More predictable behavior than dumping all messages at once. |
+| Skipped tools get explicit error results | The LLM protocol requires a tool result for every tool call in the assistant message. Omitting them would cause API errors. The skip message also informs the model about what was not done. |
+| `Continue()` uses `SkipInitialSteeringPoll` | Prevents race conditions and double-dequeuing when resuming an idle agent. |
+| Queue stored on `AgentLoop`, not `AgentInstance` | Steering is a loop-level concern (it affects the iteration flow), not a per-agent concern. All agents share the same steering queue since `processMessage` is sequential. |
+| Bus drain goroutine in `Run()` | Channels (Telegram, Discord, etc.) publish to the bus via `PublishInbound`. Without the drain, messages would queue in the bus channel buffer and only be consumed after `processMessage` returns — defeating the purpose of steering. The drain goroutine bridges the gap by consuming new bus messages and calling `Steer()` while the agent is busy. |
+| Audio transcription before steering | The drain goroutine calls `al.transcribeAudioInMessage(ctx, msg)` before steering, so voice messages are converted to text before the agent sees them. If transcription fails, the error is silently discarded and the original message is steered as-is. |
+| `MaxQueueSize = 10` | Prevents unbounded memory growth if a user sends many messages while the agent is busy. Excess messages are dropped with a warning. |
diff --git a/docs/steering.md b/docs/steering.md
new file mode 100644
index 0000000000..ad08f84250
--- /dev/null
+++ b/docs/steering.md
@@ -0,0 +1,166 @@
+# Steering
+
+Steering allows injecting messages into an already-running agent loop, interrupting it between tool calls without waiting for the entire cycle to complete.
+
+## How it works
+
+When the agent is executing a sequence of tool calls (e.g. the model requested 3 tools in a single turn), steering checks the queue **after each tool** completes. If it finds queued messages:
+
+1. The remaining tools are **skipped** and receive `"Skipped due to queued user message."` as their result
+2. The steering messages are **injected into the conversation context**
+3. The model is called again with the updated context, including the user's steering message
+
+```
+User ──► Steer("change approach")
+ │
+Agent Loop ▼
+ ├─ tool[0] ✔ (executed)
+ ├─ [polling] → steering found!
+ ├─ tool[1] ✘ (skipped)
+ ├─ tool[2] ✘ (skipped)
+ └─ new LLM turn with steering message
+```
+
+## Configuration
+
+In `config.json`, under `agents.defaults`:
+
+```json
+{
+ "agents": {
+ "defaults": {
+ "steering_mode": "one-at-a-time"
+ }
+ }
+}
+```
+
+### Modes
+
+| Value | Behavior |
+|-------|----------|
+| `"one-at-a-time"` | **(default)** Dequeues only one message per polling cycle. If there are 3 messages in the queue, they are processed one at a time across 3 successive iterations. |
+| `"all"` | Drains the entire queue in a single poll. All pending messages are injected into the context together. |
+
+The environment variable `PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE` can be used as an alternative.
+
+## Go API
+
+### Steer — Send a steering message
+
+```go
+err := agentLoop.Steer(providers.Message{
+ Role: "user",
+ Content: "change direction, focus on X instead",
+})
+if err != nil {
+ // Queue is full (MaxQueueSize=10) or not initialized
+}
+```
+
+The message is enqueued in a thread-safe manner. Returns an error if the queue is full or not initialized. It will be picked up at the next polling point (after the current tool finishes).
+
+### SteeringMode / SetSteeringMode
+
+```go
+// Read the current mode
+mode := agentLoop.SteeringMode() // SteeringOneAtATime | SteeringAll
+
+// Change it at runtime
+agentLoop.SetSteeringMode(agent.SteeringAll)
+```
+
+### Continue — Resume an idle agent
+
+When the agent is idle (it has finished processing and its last message was from the assistant), `Continue` checks if there are steering messages in the queue and uses them to start a new cycle:
+
+```go
+response, err := agentLoop.Continue(ctx, sessionKey, channel, chatID)
+if err != nil {
+ // Error (e.g. "no default agent available")
+}
+if response == "" {
+ // No steering messages in queue, the agent stays idle
+}
+```
+
+`Continue` internally uses `SkipInitialSteeringPoll: true` to avoid double-dequeuing the same messages (since it already extracted them and passes them directly as input).
+
+## Polling points in the loop
+
+Steering is checked at **two points** in the agent cycle:
+
+1. **At loop start** — before the first LLM call, to catch messages enqueued during setup
+2. **After every tool completes** — including the first and the last. If steering is found and there are remaining tools, they are all skipped immediately
+
+## Why remaining tools are skipped
+
+When a steering message is detected, all remaining tools in the batch are skipped rather than executed. The alternative — let all tools finish and inject the steering message afterwards — was considered and rejected. Here is why.
+
+### Preventing unwanted side effects
+
+Tools can have **irreversible side effects**. If the user says "no, wait" while the agent is mid-batch, executing the remaining tools means those side effects happen anyway:
+
+| Tool batch | Steering message | With skip | Without skip |
+|---|---|---|---|
+| `[web_search, send_email]` | "don't send it" | Email **not** sent | Email sent, damage done |
+| `[query_db, write_file, spawn_agent]` | "use another database" | Only the query runs | File written + subagent spawned, all wasted |
+| `[search₁, search₂, search₃, write_file]` | user changes topic entirely | 1 search | 3 searches + file write, all irrelevant |
+
+### Avoiding wasted time
+
+Tools that take seconds (web fetches, API calls, database queries) would all run to completion before the agent sees the user's correction. In a batch of 3 tools each taking 3-4 seconds, that's 10+ seconds of work that will be discarded.
+
+With skipping, the agent reacts as soon as the current tool finishes — typically within a few seconds instead of waiting for the entire batch.
+
+### The LLM gets full context
+
+Skipped tools receive an explicit error result (`"Skipped due to queued user message."`), so the model knows exactly which actions were not performed. It can then decide whether to re-execute them with the new context, or take a different path entirely.
+
+### Trade-off: sequential execution
+
+Skipping requires tools to run **sequentially** (the previous implementation ran them in parallel). This introduces latency when the LLM requests multiple independent tools in a single turn. In practice, most batches contain 1-2 tools, so the impact is minimal compared to the benefit of being able to stop unwanted actions.
+
+## Skipped tool result format
+
+When steering interrupts a batch, each tool that was not executed receives a `tool` result with:
+
+```
+Content: "Skipped due to queued user message."
+```
+
+This is saved to the session via `AddFullMessage` and sent to the model, so it is aware that some requested actions were not performed.
+
+## Full flow example
+
+```
+1. User: "search for info on X, write a file, and send me a message"
+
+2. LLM responds with 3 tool calls: [web_search, write_file, message]
+
+3. web_search is executed → result saved
+
+4. [polling] → User called Steer("no, search for Y instead")
+
+5. write_file is skipped → "Skipped due to queued user message."
+ message is skipped → "Skipped due to queued user message."
+
+6. Message "search for Y instead" injected into context
+
+7. LLM receives the full updated context and responds accordingly
+```
+
+## Automatic bus drain
+
+When the agent loop (`Run()`) starts processing a message, it spawns a background goroutine that keeps consuming new inbound messages from the bus. These messages are automatically redirected into the steering queue via `Steer()`. This means:
+
+- Users on any channel (Telegram, Discord, etc.) don't need to do anything special — their messages are automatically captured as steering when the agent is busy
+- Audio messages are transcribed before being steered, so the agent receives text. If transcription fails, the original (non-transcribed) message is steered as-is
+- When `processMessage` finishes, the drain goroutine is canceled and normal message consumption resumes
+
+## Notes
+
+- Steering **does not interrupt** a tool that is currently executing. It waits for the current tool to finish, then checks the queue.
+- With `one-at-a-time` mode, if multiple messages are enqueued rapidly, they will be processed one per iteration. This gives the model the opportunity to react to each message individually.
+- With `all` mode, all pending messages are combined into a single injection. Useful when you want the agent to receive all the context at once.
+- The steering queue has a maximum capacity of 10 messages (`MaxQueueSize`). `Steer()` returns an error when the queue is full. In the bus drain path, the error is logged as a warning and the message is effectively dropped.
diff --git a/pkg/agent/eventbus_mock.go b/pkg/agent/eventbus_mock.go
new file mode 100644
index 0000000000..c9641092be
--- /dev/null
+++ b/pkg/agent/eventbus_mock.go
@@ -0,0 +1,12 @@
+package agent
+
+import "fmt"
+
+// MockEventBus - for POC
+var MockEventBus = struct {
+ Emit func(event any)
+}{
+ Emit: func(event any) {
+ fmt.Printf("[Mock EventBus] %T %+v\n", event, event)
+ },
+}
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index f20a56b9c4..21516e7de9 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -48,6 +48,7 @@ type AgentLoop struct {
transcriber voice.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
+ steering *steeringQueue
mu sync.RWMutex
// Track active requests for safe provider cleanup
activeRequests sync.WaitGroup
@@ -55,15 +56,16 @@ type AgentLoop struct {
// processOptions configures how a message is processed
type processOptions struct {
- SessionKey string // Session identifier for history/context
- Channel string // Target channel for tool execution
- ChatID string // Target chat ID for tool execution
- UserMessage string // User message content (may include prefix)
- Media []string // media:// refs from inbound message
- DefaultResponse string // Response when LLM returns empty
- EnableSummary bool // Whether to trigger summarization
- SendResponse bool // Whether to send response via bus
- NoHistory bool // If true, don't load session history (for heartbeat)
+ SessionKey string // Session identifier for history/context
+ Channel string // Target channel for tool execution
+ ChatID string // Target chat ID for tool execution
+ UserMessage string // User message content (may include prefix)
+ Media []string // media:// refs from inbound message
+ DefaultResponse string // Response when LLM returns empty
+ EnableSummary bool // Whether to trigger summarization
+ SendResponse bool // Whether to send response via bus
+ NoHistory bool // If true, don't load session history (for heartbeat)
+ SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
}
const (
@@ -105,6 +107,7 @@ func NewAgentLoop(
summarizing: sync.Map{},
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
+ steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
return al
@@ -257,6 +260,13 @@ func (al *AgentLoop) Run(ctx context.Context) error {
continue
}
+ // Start a goroutine that drains the bus while processMessage is
+ // running. Any inbound messages that arrive during processing are
+ // redirected into the steering queue so the agent loop can pick
+ // them up between tool calls.
+ drainCtx, drainCancel := context.WithCancel(ctx)
+ go al.drainBusToSteering(drainCtx)
+
// Process message
func() {
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
@@ -272,6 +282,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// }
// }()
+ defer drainCancel()
+
response, err := al.processMessage(ctx, msg)
if err != nil {
response = fmt.Sprintf("Error processing message: %v", err)
@@ -318,6 +330,39 @@ func (al *AgentLoop) Run(ctx context.Context) error {
return nil
}
+// drainBusToSteering continuously consumes inbound messages and redirects
+// them into the steering queue. It runs in a goroutine while processMessage
+// is active and stops when drainCtx is canceled (i.e., processMessage returns).
+func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
+ for {
+ msg, ok := al.bus.ConsumeInbound(ctx)
+ if !ok {
+ return
+ }
+
+ // Transcribe audio if needed before steering, so the agent sees text.
+ msg, _ = al.transcribeAudioInMessage(ctx, msg)
+
+ logger.InfoCF("agent", "Redirecting inbound message to steering queue",
+ map[string]any{
+ "channel": msg.Channel,
+ "sender_id": msg.SenderID,
+ "content_len": len(msg.Content),
+ })
+
+ if err := al.Steer(providers.Message{
+ Role: "user",
+ Content: msg.Content,
+ }); err != nil {
+ logger.WarnCF("agent", "Failed to steer message, will be lost",
+ map[string]any{
+ "error": err.Error(),
+ "channel": msg.Channel,
+ })
+ }
+ }
+}
+
func (al *AgentLoop) Stop() {
al.running.Store(false)
}
@@ -999,6 +1044,16 @@ func (al *AgentLoop) runLLMIteration(
) (string, int, error) {
iteration := 0
var finalContent string
+ var pendingMessages []providers.Message
+
+ // Poll for steering messages at loop start (in case the user typed while
+ // the agent was setting up), unless the caller already provided initial
+ // steering messages (e.g. Continue).
+ if !opts.SkipInitialSteeringPoll {
+ if msgs := al.dequeueSteeringMessages(); len(msgs) > 0 {
+ pendingMessages = msgs
+ }
+ }
// Determine effective model tier for this conversation turn.
// selectCandidates evaluates routing once and the decision is sticky for
@@ -1006,9 +1061,25 @@ func (al *AgentLoop) runLLMIteration(
// tool chain doesn't switch models mid-way through.
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
- for iteration < agent.MaxIterations {
+ for iteration < agent.MaxIterations || len(pendingMessages) > 0 {
iteration++
+ // Inject pending steering messages into the conversation context
+ // before the next LLM call.
+ if len(pendingMessages) > 0 {
+ for _, pm := range pendingMessages {
+ messages = append(messages, pm)
+ agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content)
+ logger.InfoCF("agent", "Injected steering message into context",
+ map[string]any{
+ "agent_id": agent.ID,
+ "iteration": iteration,
+ "content_len": len(pm.Content),
+ })
+ }
+ pendingMessages = nil
+ }
+
logger.DebugCF("agent", "LLM iteration",
map[string]any{
"agent_id": agent.ID,
@@ -1251,107 +1322,83 @@ func (al *AgentLoop) runLLMIteration(
// Save assistant message with tool calls to session
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
- // Execute tool calls in parallel
- type indexedAgentResult struct {
- result *tools.ToolResult
- tc providers.ToolCall
- }
-
- agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
- var wg sync.WaitGroup
+ // Execute tool calls sequentially. After each tool completes, check
+ // for steering messages. If any are found, skip remaining tools.
+ var steeringAfterTools []providers.Message
for i, tc := range normalizedToolCalls {
- agentResults[i].tc = tc
-
- wg.Add(1)
- go func(idx int, tc providers.ToolCall) {
- defer wg.Done()
+ argsJSON, _ := json.Marshal(tc.Arguments)
+ argsPreview := utils.Truncate(string(argsJSON), 200)
+ logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
+ map[string]any{
+ "agent_id": agent.ID,
+ "tool": tc.Name,
+ "iteration": iteration,
+ })
- argsJSON, _ := json.Marshal(tc.Arguments)
- argsPreview := utils.Truncate(string(argsJSON), 200)
- logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
- map[string]any{
- "agent_id": agent.ID,
- "tool": tc.Name,
- "iteration": iteration,
+ // Create async callback for tools that implement AsyncExecutor.
+ asyncCallback := func(_ context.Context, result *tools.ToolResult) {
+ if !result.Silent && result.ForUser != "" {
+ outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer outCancel()
+ _ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
+ Channel: opts.Channel,
+ ChatID: opts.ChatID,
+ Content: result.ForUser,
})
+ }
- // Create async callback for tools that implement AsyncExecutor.
- // When the background work completes, this publishes the result
- // as an inbound system message so processSystemMessage routes it
- // back to the user via the normal agent loop.
- asyncCallback := func(_ context.Context, result *tools.ToolResult) {
- // Send ForUser content directly to the user (immediate feedback),
- // mirroring the synchronous tool execution path.
- if !result.Silent && result.ForUser != "" {
- outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer outCancel()
- _ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
- Channel: opts.Channel,
- ChatID: opts.ChatID,
- Content: result.ForUser,
- })
- }
-
- // Determine content for the agent loop (ForLLM or error).
- content := result.ForLLM
- if content == "" && result.Err != nil {
- content = result.Err.Error()
- }
- if content == "" {
- return
- }
-
- logger.InfoCF("agent", "Async tool completed, publishing result",
- map[string]any{
- "tool": tc.Name,
- "content_len": len(content),
- "channel": opts.Channel,
- })
+ content := result.ForLLM
+ if content == "" && result.Err != nil {
+ content = result.Err.Error()
+ }
+ if content == "" {
+ return
+ }
- pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer pubCancel()
- _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
- Channel: "system",
- SenderID: fmt.Sprintf("async:%s", tc.Name),
- ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
- Content: content,
+ logger.InfoCF("agent", "Async tool completed, publishing result",
+ map[string]any{
+ "tool": tc.Name,
+ "content_len": len(content),
+ "channel": opts.Channel,
})
- }
- toolResult := agent.Tools.ExecuteWithContext(
- ctx,
- tc.Name,
- tc.Arguments,
- opts.Channel,
- opts.ChatID,
- asyncCallback,
- )
- agentResults[idx].result = toolResult
- }(i, tc)
- }
- wg.Wait()
+ pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer pubCancel()
+ _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
+ Channel: "system",
+ SenderID: fmt.Sprintf("async:%s", tc.Name),
+ ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
+ Content: content,
+ })
+ }
+
+ toolResult := agent.Tools.ExecuteWithContext(
+ ctx,
+ tc.Name,
+ tc.Arguments,
+ opts.Channel,
+ opts.ChatID,
+ asyncCallback,
+ )
- // Process results in original order (send to user, save to session)
- for _, r := range agentResults {
- // Send ForUser content to user immediately if not Silent
- if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
+ // Process tool result
+ if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
- Content: r.result.ForUser,
+ Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
- "tool": r.tc.Name,
- "content_len": len(r.result.ForUser),
+ "tool": tc.Name,
+ "content_len": len(toolResult.ForUser),
})
}
- // If tool returned media refs, publish them as outbound media
- if len(r.result.Media) > 0 {
- parts := make([]bus.MediaPart, 0, len(r.result.Media))
- for _, ref := range r.result.Media {
+ if len(toolResult.Media) > 0 {
+ parts := make([]bus.MediaPart, 0, len(toolResult.Media))
+ for _, ref := range toolResult.Media {
part := bus.MediaPart{Ref: ref}
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
@@ -1369,21 +1416,55 @@ func (al *AgentLoop) runLLMIteration(
})
}
- // Determine content for LLM based on tool result
- contentForLLM := r.result.ForLLM
- if contentForLLM == "" && r.result.Err != nil {
- contentForLLM = r.result.Err.Error()
+ contentForLLM := toolResult.ForLLM
+ if contentForLLM == "" && toolResult.Err != nil {
+ contentForLLM = toolResult.Err.Error()
}
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
- ToolCallID: r.tc.ID,
+ ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
-
- // Save tool result message to session
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
+
+ // After EVERY tool (including the first and last), check for
+ // steering messages. If found and there are remaining tools,
+ // skip them all.
+ if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
+ remaining := len(normalizedToolCalls) - i - 1
+ if remaining > 0 {
+ logger.InfoCF("agent", "Steering interrupt: skipping remaining tools",
+ map[string]any{
+ "agent_id": agent.ID,
+ "completed": i + 1,
+ "skipped": remaining,
+ "total_tools": len(normalizedToolCalls),
+ "steering_count": len(steerMsgs),
+ })
+
+ // Mark remaining tool calls as skipped
+ for j := i + 1; j < len(normalizedToolCalls); j++ {
+ skippedTC := normalizedToolCalls[j]
+ toolResultMsg := providers.Message{
+ Role: "tool",
+ Content: "Skipped due to queued user message.",
+ ToolCallID: skippedTC.ID,
+ }
+ messages = append(messages, toolResultMsg)
+ agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
+ }
+ }
+ steeringAfterTools = steerMsgs
+ break
+ }
+ }
+
+ // If steering messages were captured during tool execution, they
+ // become pendingMessages for the next iteration of the inner loop.
+ if len(steeringAfterTools) > 0 {
+ pendingMessages = steeringAfterTools
}
// Tick down TTL of discovered tools after processing tool results.
diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go
new file mode 100644
index 0000000000..8c7c79c160
--- /dev/null
+++ b/pkg/agent/steering.go
@@ -0,0 +1,188 @@
+package agent
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// SteeringMode controls how queued steering messages are dequeued.
+type SteeringMode string
+
+const (
+ // SteeringOneAtATime dequeues only the first queued message per poll.
+ SteeringOneAtATime SteeringMode = "one-at-a-time"
+ // SteeringAll drains the entire queue in a single poll.
+ SteeringAll SteeringMode = "all"
+ // MaxQueueSize is the maximum number of messages the steering queue may hold.
+ MaxQueueSize = 10
+)
+
+// parseSteeringMode normalizes a config string into a SteeringMode.
+func parseSteeringMode(s string) SteeringMode {
+ switch s {
+ case "all":
+ return SteeringAll
+ default:
+ return SteeringOneAtATime
+ }
+}
+
+// steeringQueue is a thread-safe queue of user messages that can be injected
+// into a running agent loop to interrupt it between tool calls.
+type steeringQueue struct {
+ mu sync.Mutex
+ queue []providers.Message
+ mode SteeringMode
+}
+
+func newSteeringQueue(mode SteeringMode) *steeringQueue {
+ return &steeringQueue{
+ mode: mode,
+ }
+}
+
+// push enqueues a steering message.
+func (sq *steeringQueue) push(msg providers.Message) error {
+ sq.mu.Lock()
+ defer sq.mu.Unlock()
+ if len(sq.queue) >= MaxQueueSize {
+ return fmt.Errorf("steering queue is full")
+ }
+ sq.queue = append(sq.queue, msg)
+ return nil
+}
+
+// dequeue removes and returns pending steering messages according to the
+// configured mode. Returns nil when the queue is empty.
+func (sq *steeringQueue) dequeue() []providers.Message {
+ sq.mu.Lock()
+ defer sq.mu.Unlock()
+
+ if len(sq.queue) == 0 {
+ return nil
+ }
+
+ switch sq.mode {
+ case SteeringAll:
+ msgs := sq.queue
+ sq.queue = nil
+ return msgs
+ default: // one-at-a-time
+ msg := sq.queue[0]
+ sq.queue[0] = providers.Message{} // Clear reference for GC
+ sq.queue = sq.queue[1:]
+ return []providers.Message{msg}
+ }
+}
+
+// len returns the number of queued messages.
+func (sq *steeringQueue) len() int {
+ sq.mu.Lock()
+ defer sq.mu.Unlock()
+ return len(sq.queue)
+}
+
+// setMode updates the steering mode.
+func (sq *steeringQueue) setMode(mode SteeringMode) {
+ sq.mu.Lock()
+ defer sq.mu.Unlock()
+ sq.mode = mode
+}
+
+// getMode returns the current steering mode.
+func (sq *steeringQueue) getMode() SteeringMode {
+ sq.mu.Lock()
+ defer sq.mu.Unlock()
+ return sq.mode
+}
+
+// --- AgentLoop steering API ---
+
+// Steer enqueues a user message to be injected into the currently running
+// agent loop. The message will be picked up after the current tool finishes
+// executing, causing any remaining tool calls in the batch to be skipped.
+func (al *AgentLoop) Steer(msg providers.Message) error {
+ if al.steering == nil {
+ return fmt.Errorf("steering queue is not initialized")
+ }
+ if err := al.steering.push(msg); err != nil {
+ logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
+ "error": err.Error(),
+ "role": msg.Role,
+ })
+ return err
+ }
+ logger.DebugCF("agent", "Steering message enqueued", map[string]any{
+ "role": msg.Role,
+ "content_len": len(msg.Content),
+ "queue_len": al.steering.len(),
+ })
+
+ return nil
+}
+
+// SteeringMode returns the current steering mode.
+func (al *AgentLoop) SteeringMode() SteeringMode {
+ if al.steering == nil {
+ return SteeringOneAtATime
+ }
+ return al.steering.getMode()
+}
+
+// SetSteeringMode updates the steering mode.
+func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
+ if al.steering == nil {
+ return
+ }
+ al.steering.setMode(mode)
+}
+
+// dequeueSteeringMessages is the internal method called by the agent loop
+// to poll for steering messages. Returns nil when no messages are pending.
+func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
+ if al.steering == nil {
+ return nil
+ }
+ return al.steering.dequeue()
+}
+
+// Continue resumes an idle agent by dequeuing any pending steering messages
+// and running them through the agent loop. This is used when the agent's last
+// message was from the assistant (i.e., it has stopped processing) and the
+// user has since enqueued steering messages.
+//
+// If no steering messages are pending, it returns an empty string.
+func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
+ steeringMsgs := al.dequeueSteeringMessages()
+ if len(steeringMsgs) == 0 {
+ return "", nil
+ }
+
+ agent := al.GetRegistry().GetDefaultAgent()
+ if agent == nil {
+ return "", fmt.Errorf("no default agent available")
+ }
+
+ // Build a combined user message from the steering messages.
+ var contents []string
+ for _, msg := range steeringMsgs {
+ contents = append(contents, msg.Content)
+ }
+ combinedContent := strings.Join(contents, "\n")
+
+ return al.runAgentLoop(ctx, agent, processOptions{
+ SessionKey: sessionKey,
+ Channel: channel,
+ ChatID: chatID,
+ UserMessage: combinedContent,
+ DefaultResponse: defaultResponse,
+ EnableSummary: true,
+ SendResponse: false,
+ SkipInitialSteeringPoll: true,
+ })
+}
diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go
new file mode 100644
index 0000000000..e8cdb23449
--- /dev/null
+++ b/pkg/agent/steering_test.go
@@ -0,0 +1,744 @@
+package agent
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/tools"
+)
+
+// --- steeringQueue unit tests ---
+
+func TestSteeringQueue_PushDequeue_OneAtATime(t *testing.T) {
+ sq := newSteeringQueue(SteeringOneAtATime)
+
+ sq.push(providers.Message{Role: "user", Content: "msg1"})
+ sq.push(providers.Message{Role: "user", Content: "msg2"})
+ sq.push(providers.Message{Role: "user", Content: "msg3"})
+
+ if sq.len() != 3 {
+ t.Fatalf("expected 3 messages, got %d", sq.len())
+ }
+
+ msgs := sq.dequeue()
+ if len(msgs) != 1 {
+ t.Fatalf("expected 1 message in one-at-a-time mode, got %d", len(msgs))
+ }
+ if msgs[0].Content != "msg1" {
+ t.Fatalf("expected 'msg1', got %q", msgs[0].Content)
+ }
+ if sq.len() != 2 {
+ t.Fatalf("expected 2 remaining, got %d", sq.len())
+ }
+
+ msgs = sq.dequeue()
+ if len(msgs) != 1 || msgs[0].Content != "msg2" {
+ t.Fatalf("expected 'msg2', got %v", msgs)
+ }
+
+ msgs = sq.dequeue()
+ if len(msgs) != 1 || msgs[0].Content != "msg3" {
+ t.Fatalf("expected 'msg3', got %v", msgs)
+ }
+
+ msgs = sq.dequeue()
+ if msgs != nil {
+ t.Fatalf("expected nil from empty queue, got %v", msgs)
+ }
+}
+
+func TestSteeringQueue_PushDequeue_All(t *testing.T) {
+ sq := newSteeringQueue(SteeringAll)
+
+ sq.push(providers.Message{Role: "user", Content: "msg1"})
+ sq.push(providers.Message{Role: "user", Content: "msg2"})
+ sq.push(providers.Message{Role: "user", Content: "msg3"})
+
+ msgs := sq.dequeue()
+ if len(msgs) != 3 {
+ t.Fatalf("expected 3 messages in all mode, got %d", len(msgs))
+ }
+ if msgs[0].Content != "msg1" || msgs[1].Content != "msg2" || msgs[2].Content != "msg3" {
+ t.Fatalf("unexpected messages: %v", msgs)
+ }
+
+ if sq.len() != 0 {
+ t.Fatalf("expected 0 remaining, got %d", sq.len())
+ }
+
+ msgs = sq.dequeue()
+ if msgs != nil {
+ t.Fatalf("expected nil from empty queue, got %v", msgs)
+ }
+}
+
+func TestSteeringQueue_EmptyDequeue(t *testing.T) {
+ sq := newSteeringQueue(SteeringOneAtATime)
+ if msgs := sq.dequeue(); msgs != nil {
+ t.Fatalf("expected nil, got %v", msgs)
+ }
+}
+
+func TestSteeringQueue_SetMode(t *testing.T) {
+ sq := newSteeringQueue(SteeringOneAtATime)
+ if sq.getMode() != SteeringOneAtATime {
+ t.Fatalf("expected one-at-a-time, got %v", sq.getMode())
+ }
+
+ sq.setMode(SteeringAll)
+ if sq.getMode() != SteeringAll {
+ t.Fatalf("expected all, got %v", sq.getMode())
+ }
+
+ // Push two messages and verify all-mode drains them
+ sq.push(providers.Message{Role: "user", Content: "a"})
+ sq.push(providers.Message{Role: "user", Content: "b"})
+
+ msgs := sq.dequeue()
+ if len(msgs) != 2 {
+ t.Fatalf("expected 2 messages after mode switch, got %d", len(msgs))
+ }
+}
+
+func TestSteeringQueue_ConcurrentAccess(t *testing.T) {
+ sq := newSteeringQueue(SteeringOneAtATime)
+
+ var wg sync.WaitGroup
+ const n = MaxQueueSize
+
+ // Push from multiple goroutines
+ for i := 0; i < n; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)})
+ }(i)
+ }
+ wg.Wait()
+
+ if sq.len() != n {
+ t.Fatalf("expected %d messages, got %d", n, sq.len())
+ }
+
+ // Drain from multiple goroutines
+ var drained int
+ var mu sync.Mutex
+ for i := 0; i < n; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if msgs := sq.dequeue(); len(msgs) > 0 {
+ mu.Lock()
+ drained += len(msgs)
+ mu.Unlock()
+ }
+ }()
+ }
+ wg.Wait()
+
+ if drained != n {
+ t.Fatalf("expected to drain %d messages, got %d", n, drained)
+ }
+}
+
+func TestSteeringQueue_Overflow(t *testing.T) {
+ sq := newSteeringQueue(SteeringOneAtATime)
+
+ // Fill the queue up to its maximum capacity
+ for i := 0; i < MaxQueueSize; i++ {
+ err := sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)})
+ if err != nil {
+ t.Fatalf("unexpected error pushing message %d: %v", i, err)
+ }
+ }
+
+ // Sanity check: ensure the queue is actually full
+ if sq.len() != MaxQueueSize {
+ t.Fatalf("expected queue length %d, got %d", MaxQueueSize, sq.len())
+ }
+
+ // Attempt to push one more message, which MUST fail
+ err := sq.push(providers.Message{Role: "user", Content: "overflow_msg"})
+
+ // Assert the error happened and is the exact one we expect
+ if err == nil {
+ t.Fatal("expected an error when pushing to a full queue, but got nil")
+ }
+
+ expectedErr := "steering queue is full"
+ if err.Error() != expectedErr {
+ t.Errorf("expected error message %q, got %q", expectedErr, err.Error())
+ }
+}
+
+func TestParseSteeringMode(t *testing.T) {
+ tests := []struct {
+ input string
+ expected SteeringMode
+ }{
+ {"", SteeringOneAtATime},
+ {"one-at-a-time", SteeringOneAtATime},
+ {"all", SteeringAll},
+ {"unknown", SteeringOneAtATime},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ if got := parseSteeringMode(tt.input); got != tt.expected {
+ t.Fatalf("parseSteeringMode(%q) = %v, want %v", tt.input, got, tt.expected)
+ }
+ })
+ }
+}
+
+// --- AgentLoop steering integration tests ---
+
+func TestAgentLoop_Steer_Enqueues(t *testing.T) {
+ al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
+ defer cleanup()
+
+ if cfg == nil {
+ t.Fatal("expected config to be initialized")
+ }
+ if msgBus == nil {
+ t.Fatal("expected message bus to be initialized")
+ }
+ if provider == nil {
+ t.Fatal("expected provider to be initialized")
+ }
+
+ al.Steer(providers.Message{Role: "user", Content: "interrupt me"})
+
+ if al.steering.len() != 1 {
+ t.Fatalf("expected 1 steering message, got %d", al.steering.len())
+ }
+
+ msgs := al.dequeueSteeringMessages()
+ if len(msgs) != 1 || msgs[0].Content != "interrupt me" {
+ t.Fatalf("unexpected dequeued message: %v", msgs)
+ }
+}
+
+func TestAgentLoop_SteeringMode_GetSet(t *testing.T) {
+ al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
+ defer cleanup()
+
+ if cfg == nil {
+ t.Fatal("expected config to be initialized")
+ }
+ if msgBus == nil {
+ t.Fatal("expected message bus to be initialized")
+ }
+ if provider == nil {
+ t.Fatal("expected provider to be initialized")
+ }
+
+ if al.SteeringMode() != SteeringOneAtATime {
+ t.Fatalf("expected default mode one-at-a-time, got %v", al.SteeringMode())
+ }
+
+ al.SetSteeringMode(SteeringAll)
+ if al.SteeringMode() != SteeringAll {
+ t.Fatalf("expected all mode, got %v", al.SteeringMode())
+ }
+}
+
+func TestAgentLoop_SteeringMode_ConfiguredFromConfig(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ SteeringMode: "all",
+ },
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ provider := &mockProvider{}
+ al := NewAgentLoop(cfg, msgBus, provider)
+
+ if al.SteeringMode() != SteeringAll {
+ t.Fatalf("expected 'all' mode from config, got %v", al.SteeringMode())
+ }
+}
+
+func TestAgentLoop_Continue_NoMessages(t *testing.T) {
+ al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
+ defer cleanup()
+
+ if cfg == nil {
+ t.Fatal("expected config to be initialized")
+ }
+ if msgBus == nil {
+ t.Fatal("expected message bus to be initialized")
+ }
+ if provider == nil {
+ t.Fatal("expected provider to be initialized")
+ }
+
+ resp, err := al.Continue(context.Background(), "test-session", "test", "chat1")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if resp != "" {
+ t.Fatalf("expected empty response for no steering messages, got %q", resp)
+ }
+}
+
+func TestAgentLoop_Continue_WithMessages(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ provider := &simpleMockProvider{response: "continued response"}
+ al := NewAgentLoop(cfg, msgBus, provider)
+
+ al.Steer(providers.Message{Role: "user", Content: "new direction"})
+
+ resp, err := al.Continue(context.Background(), "test-session", "test", "chat1")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if resp != "continued response" {
+ t.Fatalf("expected 'continued response', got %q", resp)
+ }
+}
+
+// slowTool simulates a tool that takes some time to execute.
+type slowTool struct {
+ name string
+ duration time.Duration
+ execCh chan struct{} // closed when Execute starts
+}
+
+func (t *slowTool) Name() string { return t.name }
+func (t *slowTool) Description() string { return "slow tool for testing" }
+func (t *slowTool) Parameters() map[string]any {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{},
+ }
+}
+
+func (t *slowTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
+ if t.execCh != nil {
+ close(t.execCh)
+ }
+ time.Sleep(t.duration)
+ return tools.SilentResult(fmt.Sprintf("executed %s", t.name))
+}
+
+// toolCallProvider returns an LLM response with tool calls on the first call,
+// then a direct response on subsequent calls.
+type toolCallProvider struct {
+ mu sync.Mutex
+ calls int
+ toolCalls []providers.ToolCall
+ finalResp string
+}
+
+func (m *toolCallProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.calls++
+
+ if m.calls == 1 && len(m.toolCalls) > 0 {
+ return &providers.LLMResponse{
+ Content: "",
+ ToolCalls: m.toolCalls,
+ }, nil
+ }
+
+ return &providers.LLMResponse{
+ Content: m.finalResp,
+ ToolCalls: []providers.ToolCall{},
+ }, nil
+}
+
+func (m *toolCallProvider) GetDefaultModel() string {
+ return "tool-call-mock"
+}
+
+func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ tool1ExecCh := make(chan struct{})
+ tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
+ tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
+
+ provider := &toolCallProvider{
+ toolCalls: []providers.ToolCall{
+ {
+ ID: "call_1",
+ Type: "function",
+ Name: "tool_one",
+ Function: &providers.FunctionCall{
+ Name: "tool_one",
+ Arguments: "{}",
+ },
+ Arguments: map[string]any{},
+ },
+ {
+ ID: "call_2",
+ Type: "function",
+ Name: "tool_two",
+ Function: &providers.FunctionCall{
+ Name: "tool_two",
+ Arguments: "{}",
+ },
+ Arguments: map[string]any{},
+ },
+ },
+ finalResp: "steered response",
+ }
+
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, provider)
+ al.RegisterTool(tool1)
+ al.RegisterTool(tool2)
+
+ // Start processing in a goroutine
+ type result struct {
+ resp string
+ err error
+ }
+ resultCh := make(chan result, 1)
+
+ go func() {
+ resp, err := al.ProcessDirectWithChannel(
+ context.Background(),
+ "do something",
+ "test-session",
+ "test",
+ "chat1",
+ )
+ resultCh <- result{resp, err}
+ }()
+
+ // Wait for tool_one to start executing, then enqueue a steering message
+ select {
+ case <-tool1ExecCh:
+ // tool_one has started executing
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout waiting for tool_one to start")
+ }
+
+ al.Steer(providers.Message{Role: "user", Content: "change course"})
+
+ // Get the result
+ select {
+ case r := <-resultCh:
+ if r.err != nil {
+ t.Fatalf("unexpected error: %v", r.err)
+ }
+ if r.resp != "steered response" {
+ t.Fatalf("expected 'steered response', got %q", r.resp)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for agent loop to complete")
+ }
+
+ // The provider should have been called twice:
+ // 1. first call returned tool calls
+ // 2. second call (after steering) returned the final response
+ provider.mu.Lock()
+ calls := provider.calls
+ provider.mu.Unlock()
+ if calls != 2 {
+ t.Fatalf("expected 2 provider calls, got %d", calls)
+ }
+}
+
+func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ // Provider that captures messages it receives
+ var capturedMessages []providers.Message
+ var capMu sync.Mutex
+ provider := &capturingMockProvider{
+ response: "ack",
+ captureFn: func(msgs []providers.Message) {
+ capMu.Lock()
+ capturedMessages = make([]providers.Message, len(msgs))
+ copy(capturedMessages, msgs)
+ capMu.Unlock()
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, provider)
+
+ // Enqueue a steering message before processing starts
+ al.Steer(providers.Message{Role: "user", Content: "pre-enqueued steering"})
+
+ // Process a normal message - the initial steering poll should inject the steering message
+ _, err = al.ProcessDirectWithChannel(
+ context.Background(),
+ "initial message",
+ "test-session",
+ "test",
+ "chat1",
+ )
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // The steering message should have been injected into the conversation
+ capMu.Lock()
+ msgs := capturedMessages
+ capMu.Unlock()
+
+ // Look for the steering message in the captured messages
+ found := false
+ for _, m := range msgs {
+ if m.Content == "pre-enqueued steering" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Fatal("expected steering message to be injected into conversation context")
+ }
+}
+
+// capturingMockProvider captures messages sent to Chat for inspection.
+type capturingMockProvider struct {
+ response string
+ calls int
+ captureFn func([]providers.Message)
+}
+
+func (m *capturingMockProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
+ m.calls++
+ if m.captureFn != nil {
+ m.captureFn(messages)
+ }
+ return &providers.LLMResponse{
+ Content: m.response,
+ ToolCalls: []providers.ToolCall{},
+ }, nil
+}
+
+func (m *capturingMockProvider) GetDefaultModel() string {
+ return "capturing-mock"
+}
+
+func TestAgentLoop_Steering_SkippedToolsHaveErrorResults(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ execCh := make(chan struct{})
+ tool1 := &slowTool{name: "slow_tool", duration: 50 * time.Millisecond, execCh: execCh}
+ tool2 := &slowTool{name: "skipped_tool", duration: 50 * time.Millisecond}
+
+ // Provider that captures messages on the second call (after tools)
+ var secondCallMessages []providers.Message
+ var capMu sync.Mutex
+ callCount := 0
+
+ provider := &toolCallProvider{
+ toolCalls: []providers.ToolCall{
+ {
+ ID: "call_1",
+ Type: "function",
+ Name: "slow_tool",
+ Function: &providers.FunctionCall{
+ Name: "slow_tool",
+ Arguments: "{}",
+ },
+ Arguments: map[string]any{},
+ },
+ {
+ ID: "call_2",
+ Type: "function",
+ Name: "skipped_tool",
+ Function: &providers.FunctionCall{
+ Name: "skipped_tool",
+ Arguments: "{}",
+ },
+ Arguments: map[string]any{},
+ },
+ },
+ finalResp: "done",
+ }
+
+ // Wrap provider to capture messages on second call
+ wrappedProvider := &wrappingProvider{
+ inner: provider,
+ onChat: func(msgs []providers.Message) {
+ capMu.Lock()
+ callCount++
+ if callCount >= 2 {
+ secondCallMessages = make([]providers.Message, len(msgs))
+ copy(secondCallMessages, msgs)
+ }
+ capMu.Unlock()
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, wrappedProvider)
+ al.RegisterTool(tool1)
+ al.RegisterTool(tool2)
+
+ resultCh := make(chan string, 1)
+ go func() {
+ resp, _ := al.ProcessDirectWithChannel(
+ context.Background(), "go", "test-session", "test", "chat1",
+ )
+ resultCh <- resp
+ }()
+
+ <-execCh
+ al.Steer(providers.Message{Role: "user", Content: "interrupt!"})
+
+ select {
+ case <-resultCh:
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout")
+ }
+
+ // Check that the skipped tool result message is in the conversation
+ capMu.Lock()
+ msgs := secondCallMessages
+ capMu.Unlock()
+
+ foundSkipped := false
+ for _, m := range msgs {
+ if m.Role == "tool" && m.ToolCallID == "call_2" && m.Content == "Skipped due to queued user message." {
+ foundSkipped = true
+ break
+ }
+ }
+ if !foundSkipped {
+ // Log what we actually got
+ for i, m := range msgs {
+ t.Logf("msg[%d]: role=%s toolCallID=%s content=%s", i, m.Role, m.ToolCallID, truncate(m.Content, 80))
+ }
+ t.Fatal("expected skipped tool result for call_2")
+ }
+}
+
+func truncate(s string, n int) string {
+ if len(s) <= n {
+ return s
+ }
+ return s[:n] + "..."
+}
+
+// wrappingProvider wraps another provider to hook into Chat calls.
+type wrappingProvider struct {
+ inner providers.LLMProvider
+ onChat func([]providers.Message)
+}
+
+func (w *wrappingProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
+ if w.onChat != nil {
+ w.onChat(messages)
+ }
+ return w.inner.Chat(ctx, messages, tools, model, opts)
+}
+
+func (w *wrappingProvider) GetDefaultModel() string {
+ return w.inner.GetDefaultModel()
+}
+
+// Keep the json package referenced so test tool-call argument
+// serialization compiles even if individual tests are skipped.
+func init() {
+	// Intentionally a no-op; no runtime setup is required here.
+	_ = json.Marshal
+}
diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go
new file mode 100644
index 0000000000..ab7d60957b
--- /dev/null
+++ b/pkg/agent/subturn.go
@@ -0,0 +1,309 @@
+package agent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/session"
+ "github.com/sipeed/picoclaw/pkg/tools"
+)
+
+// ====================== Config & Constants ======================
+const maxSubTurnDepth = 3
+
+var (
+ ErrDepthLimitExceeded = errors.New("sub-turn depth limit exceeded")
+ ErrInvalidSubTurnConfig = errors.New("invalid sub-turn config")
+)
+
+// ====================== SubTurn Config ======================
+type SubTurnConfig struct {
+ Model string
+ Tools []tools.Tool
+ SystemPrompt string
+ MaxTokens int
+ // Can be extended with temperature, topP, etc.
+}
+
+// ====================== Sub-turn Events (Aligned with EventBus) ======================
+type SubTurnSpawnEvent struct {
+ ParentID string
+ ChildID string
+ Config SubTurnConfig
+}
+
+type SubTurnEndEvent struct {
+ ChildID string
+ Result *tools.ToolResult
+ Err error
+}
+
+type SubTurnResultDeliveredEvent struct {
+ ParentID string
+ ChildID string
+ Result *tools.ToolResult
+}
+
+type SubTurnOrphanResultEvent struct {
+ ParentID string
+ ChildID string
+ Result *tools.ToolResult
+}
+
+// ====================== turnState (Simplified, reusable with existing structs) ======================
+type turnState struct {
+ ctx context.Context
+ cancelFunc context.CancelFunc // Used to cancel all children when this turn finishes
+ turnID string
+ parentTurnID string
+ depth int
+ childTurnIDs []string
+ pendingResults chan *tools.ToolResult
+ session session.SessionStore
+ mu sync.Mutex
+ isFinished bool // Marks if the parent Turn has ended
+}
+
+// ====================== Helper Functions ======================
+var globalTurnCounter int64
+
+func generateTurnID() string {
+ return fmt.Sprintf("subturn-%d", atomic.AddInt64(&globalTurnCounter, 1))
+}
+
+func newTurnState(ctx context.Context, id string, parent *turnState) *turnState {
+ turnCtx, cancel := context.WithCancel(ctx)
+ return &turnState{
+ ctx: turnCtx,
+ cancelFunc: cancel,
+ turnID: id,
+ parentTurnID: parent.turnID,
+ depth: parent.depth + 1,
+ session: newEphemeralSession(parent.session),
+ // NOTE: In this PoC, I use a fixed-size channel (16).
+ // Under high concurrency or long-running sub-turns, this might fill up and cause
+ // intermediate results to be discarded in deliverSubTurnResult.
+ // For production, consider an unbounded queue or a blocking strategy with backpressure.
+ pendingResults: make(chan *tools.ToolResult, 16),
+ }
+}
+
+// Finish marks the turn as finished and cancels its context, aborting any running sub-turns.
+func (ts *turnState) Finish() {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.isFinished = true
+ if ts.cancelFunc != nil {
+ ts.cancelFunc()
+ }
+}
+
+// ephemeralSessionStore is a pure in-memory SessionStore for SubTurns.
+// It never writes to disk, keeping sub-turn history isolated from the parent session.
+type ephemeralSessionStore struct {
+ mu sync.Mutex
+ history []providers.Message
+ summary string
+}
+
+func (e *ephemeralSessionStore) AddMessage(sessionKey, role, content string) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.history = append(e.history, providers.Message{Role: role, Content: content})
+}
+
+func (e *ephemeralSessionStore) AddFullMessage(sessionKey string, msg providers.Message) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.history = append(e.history, msg)
+}
+
+func (e *ephemeralSessionStore) GetHistory(key string) []providers.Message {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ out := make([]providers.Message, len(e.history))
+ copy(out, e.history)
+ return out
+}
+
+func (e *ephemeralSessionStore) GetSummary(key string) string {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.summary
+}
+
+func (e *ephemeralSessionStore) SetSummary(key, summary string) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.summary = summary
+}
+
+func (e *ephemeralSessionStore) SetHistory(key string, history []providers.Message) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.history = make([]providers.Message, len(history))
+ copy(e.history, history)
+}
+
+func (e *ephemeralSessionStore) TruncateHistory(key string, keepLast int) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if len(e.history) > keepLast {
+ e.history = e.history[len(e.history)-keepLast:]
+ }
+}
+
+func (e *ephemeralSessionStore) Save(key string) error { return nil }
+func (e *ephemeralSessionStore) Close() error { return nil }
+
+func newEphemeralSession(_ session.SessionStore) session.SessionStore {
+ return &ephemeralSessionStore{}
+}
+
+// ====================== Core Function: spawnSubTurn ======================
+func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg SubTurnConfig) (result *tools.ToolResult, err error) {
+ // 1. Depth limit check
+ if parentTS.depth >= maxSubTurnDepth {
+ return nil, ErrDepthLimitExceeded
+ }
+
+ // 2. Config validation
+ if cfg.Model == "" {
+ return nil, ErrInvalidSubTurnConfig
+ }
+
+ // Create a sub-context for the child turn to support cancellation
+ childCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ // 3. Create child Turn state
+ childID := generateTurnID()
+ childTS := newTurnState(childCtx, childID, parentTS)
+
+ // 4. Establish parent-child relationship (thread-safe)
+ parentTS.mu.Lock()
+ parentTS.childTurnIDs = append(parentTS.childTurnIDs, childID)
+ parentTS.mu.Unlock()
+
+ // 5. Emit Spawn event (currently using Mock, will be replaced by real EventBus)
+ MockEventBus.Emit(SubTurnSpawnEvent{
+ ParentID: parentTS.turnID,
+ ChildID: childID,
+ Config: cfg,
+ })
+
+ // 6. Defer emitting End event, and recover from panics to ensure it's always fired
+ defer func() {
+ if r := recover(); r != nil {
+ err = fmt.Errorf("subturn panicked: %v", r)
+ }
+
+ MockEventBus.Emit(SubTurnEndEvent{
+ ChildID: childID,
+ Result: result,
+ Err: err,
+ })
+ }()
+
+ // 7. Execute sub-turn via the real agent loop.
+ // Build a child AgentInstance from SubTurnConfig, inheriting defaults from the parent agent.
+ result, err = runTurn(childCtx, al, childTS, cfg)
+
+ // 8. Deliver result back to parent Turn
+ deliverSubTurnResult(parentTS, childID, result)
+
+ return result, err
+}
+
+// ====================== Result Delivery ======================
+func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.ToolResult) {
+ parentTS.mu.Lock()
+ defer parentTS.mu.Unlock()
+
+ // Emit ResultDelivered event
+ MockEventBus.Emit(SubTurnResultDeliveredEvent{
+ ParentID: parentTS.turnID,
+ ChildID: childID,
+ Result: result,
+ })
+
+ if !parentTS.isFinished {
+ // Parent Turn is still running → Place in pending queue (handled automatically by parent loop in next round)
+ select {
+ case parentTS.pendingResults <- result:
+ default:
+ fmt.Println("[SubTurn] warning: pendingResults channel full")
+ }
+ return
+ }
+
+ // Parent Turn has ended
+ // emit an OrphanResultEvent so the system/UI can handle this late arrival.
+ if result != nil {
+ MockEventBus.Emit(SubTurnOrphanResultEvent{
+ ParentID: parentTS.turnID,
+ ChildID: childID,
+ Result: result,
+ })
+ }
+}
+
+// runTurn builds a temporary AgentInstance from SubTurnConfig and delegates to
+// the real agent loop. The child's ephemeral session is used for history so it
+// never pollutes the parent session.
+func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfig) (*tools.ToolResult, error) {
+ // Derive candidates from the requested model using the parent loop's provider.
+ defaultProvider := al.GetConfig().Agents.Defaults.Provider
+ candidates := providers.ResolveCandidates(
+ providers.ModelConfig{Primary: cfg.Model},
+ defaultProvider,
+ )
+
+ // Build a minimal AgentInstance for this sub-turn.
+ // It reuses the parent loop's provider and config, but gets its own
+ // ephemeral session store and tool registry.
+ toolRegistry := tools.NewToolRegistry()
+ for _, t := range cfg.Tools {
+ toolRegistry.Register(t)
+ }
+
+ parentAgent := al.GetRegistry().GetDefaultAgent()
+ childAgent := &AgentInstance{
+ ID: ts.turnID,
+ Model: cfg.Model,
+ MaxIterations: parentAgent.MaxIterations,
+ MaxTokens: cfg.MaxTokens,
+ Temperature: parentAgent.Temperature,
+ ThinkingLevel: parentAgent.ThinkingLevel,
+ ContextWindow: cfg.MaxTokens,
+ SummarizeMessageThreshold: parentAgent.SummarizeMessageThreshold,
+ SummarizeTokenPercent: parentAgent.SummarizeTokenPercent,
+ Provider: parentAgent.Provider,
+ Sessions: ts.session,
+ ContextBuilder: parentAgent.ContextBuilder,
+ Tools: toolRegistry,
+ Candidates: candidates,
+ }
+ if childAgent.MaxTokens == 0 {
+ childAgent.MaxTokens = parentAgent.MaxTokens
+ childAgent.ContextWindow = parentAgent.ContextWindow
+ }
+
+ finalContent, err := al.runAgentLoop(ctx, childAgent, processOptions{
+ SessionKey: ts.turnID,
+ UserMessage: cfg.SystemPrompt,
+ DefaultResponse: "",
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &tools.ToolResult{ForLLM: finalContent}, nil
+}
+
+// ====================== Other Types ======================
diff --git a/pkg/agent/subturn_test.go b/pkg/agent/subturn_test.go
new file mode 100644
index 0000000000..97dfc0130c
--- /dev/null
+++ b/pkg/agent/subturn_test.go
@@ -0,0 +1,254 @@
+package agent
+
+import (
+ "context"
+ "reflect"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/tools"
+)
+
+// ====================== Test Helper: Event Collector ======================
+type eventCollector struct {
+ events []any
+}
+
+func (c *eventCollector) collect(e any) {
+ c.events = append(c.events, e)
+}
+
+func (c *eventCollector) hasEventOfType(typ any) bool {
+ targetType := reflect.TypeOf(typ)
+ for _, e := range c.events {
+ if reflect.TypeOf(e) == targetType {
+ return true
+ }
+ }
+ return false
+}
+
+func (c *eventCollector) countOfType(typ any) int {
+ targetType := reflect.TypeOf(typ)
+ count := 0
+ for _, e := range c.events {
+ if reflect.TypeOf(e) == targetType {
+ count++
+ }
+ }
+ return count
+}
+
+// ====================== Main Test Function ======================
+func TestSpawnSubTurn(t *testing.T) {
+ tests := []struct {
+ name string
+ parentDepth int
+ config SubTurnConfig
+ wantErr error
+ wantSpawn bool
+ wantEnd bool
+ wantDepthFail bool
+ }{
+ {
+ name: "Basic success path - Single layer sub-turn",
+ parentDepth: 0,
+ config: SubTurnConfig{
+ Model: "gpt-4o-mini",
+ Tools: []tools.Tool{}, // At least one tool
+ },
+ wantErr: nil,
+ wantSpawn: true,
+ wantEnd: true,
+ },
+ {
+ name: "Nested 2 layers - Normal",
+ parentDepth: 1,
+ config: SubTurnConfig{
+ Model: "gpt-4o-mini",
+ Tools: []tools.Tool{},
+ },
+ wantErr: nil,
+ wantSpawn: true,
+ wantEnd: true,
+ },
+ {
+ name: "Depth limit triggered - 4th layer fails",
+ parentDepth: 3,
+ config: SubTurnConfig{
+ Model: "gpt-4o-mini",
+ Tools: []tools.Tool{},
+ },
+ wantErr: ErrDepthLimitExceeded,
+ wantSpawn: false,
+ wantEnd: false,
+ wantDepthFail: true,
+ },
+ {
+ name: "Invalid config - Empty Model",
+ parentDepth: 0,
+ config: SubTurnConfig{
+ Model: "",
+ Tools: []tools.Tool{},
+ },
+ wantErr: ErrInvalidSubTurnConfig,
+ wantSpawn: false,
+ wantEnd: false,
+ },
+ }
+
+ al, _, _, _, cleanup := newTestAgentLoop(t)
+ defer cleanup()
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Prepare parent Turn
+ parent := &turnState{
+ ctx: context.Background(),
+ turnID: "parent-1",
+ depth: tt.parentDepth,
+ childTurnIDs: []string{},
+ pendingResults: make(chan *tools.ToolResult, 10),
+ session: &ephemeralSessionStore{},
+ }
+
+ // Replace mock with test collector
+ collector := &eventCollector{}
+ originalEmit := MockEventBus.Emit
+ MockEventBus.Emit = collector.collect
+ defer func() { MockEventBus.Emit = originalEmit }()
+
+ // Execute spawnSubTurn
+ result, err := spawnSubTurn(context.Background(), al, parent, tt.config)
+
+ // Assert errors
+ if tt.wantErr != nil {
+ if err == nil || err != tt.wantErr {
+ t.Errorf("expected error %v, got %v", tt.wantErr, err)
+ }
+ return
+ }
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ // Verify result
+ if result == nil {
+ t.Error("expected non-nil result")
+ }
+
+ // Verify event emission
+ if tt.wantSpawn {
+ if !collector.hasEventOfType(SubTurnSpawnEvent{}) {
+ t.Error("SubTurnSpawnEvent not emitted")
+ }
+ }
+ if tt.wantEnd {
+ if !collector.hasEventOfType(SubTurnEndEvent{}) {
+ t.Error("SubTurnEndEvent not emitted")
+ }
+ }
+
+ // Verify turn tree
+ if len(parent.childTurnIDs) == 0 && !tt.wantDepthFail {
+ t.Error("child Turn not added to parent.childTurnIDs")
+ }
+
+ // Verify result delivery (pendingResults or history)
+ if len(parent.pendingResults) > 0 || len(parent.session.GetHistory("")) > 0 {
+ // Result delivered via at least one path
+ } else {
+ t.Error("child result not delivered")
+ }
+ })
+ }
+}
+
+// ====================== Extra Independent Test: Ephemeral Session Isolation ======================
+func TestSpawnSubTurn_EphemeralSessionIsolation(t *testing.T) {
+ al, _, _, _, cleanup := newTestAgentLoop(t)
+ defer cleanup()
+
+ parentSession := &ephemeralSessionStore{}
+ parentSession.AddMessage("", "user", "parent msg")
+ parent := &turnState{
+ ctx: context.Background(),
+ turnID: "parent-1",
+ depth: 0,
+ session: parentSession,
+ }
+
+ cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}}
+
+ // Record main session length before execution
+ originalLen := len(parent.session.GetHistory(""))
+
+ _, _ = spawnSubTurn(context.Background(), al, parent, cfg)
+
+ // After sub-turn ends, main session must remain unchanged
+ if len(parent.session.GetHistory("")) != originalLen {
+ t.Error("ephemeral session polluted the main session")
+ }
+}
+
+// ====================== Extra Independent Test: Result Delivery Path ======================
+func TestSpawnSubTurn_ResultDelivery(t *testing.T) {
+ al, _, _, _, cleanup := newTestAgentLoop(t)
+ defer cleanup()
+
+ parent := &turnState{
+ ctx: context.Background(),
+ turnID: "parent-1",
+ depth: 0,
+ pendingResults: make(chan *tools.ToolResult, 1),
+ session: &ephemeralSessionStore{},
+ }
+
+ cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}}
+
+ _, _ = spawnSubTurn(context.Background(), al, parent, cfg)
+
+ // Check if pendingResults received the result
+ select {
+ case res := <-parent.pendingResults:
+ if res == nil {
+ t.Error("received nil result in pendingResults")
+ }
+ default:
+ t.Error("result did not enter pendingResults")
+ }
+}
+
+// ====================== Extra Independent Test: Orphan Result Routing ======================
+func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) {
+ parentCtx, cancelParent := context.WithCancel(context.Background())
+ parent := &turnState{
+ ctx: parentCtx,
+ cancelFunc: cancelParent,
+ turnID: "parent-1",
+ depth: 0,
+ pendingResults: make(chan *tools.ToolResult, 1),
+ session: &ephemeralSessionStore{},
+ }
+
+ collector := &eventCollector{}
+ originalEmit := MockEventBus.Emit
+ MockEventBus.Emit = collector.collect
+ defer func() { MockEventBus.Emit = originalEmit }()
+
+ // Simulate parent finishing before child delivers result
+ parent.Finish()
+
+ // Call deliverSubTurnResult directly to simulate a delayed child
+ deliverSubTurnResult(parent, "delayed-child", &tools.ToolResult{ForLLM: "late result"})
+
+ // Verify Orphan event is emitted
+ if !collector.hasEventOfType(SubTurnOrphanResultEvent{}) {
+ t.Error("SubTurnOrphanResultEvent not emitted for finished parent")
+ }
+
+ // Verify history is NOT polluted
+ if len(parent.session.GetHistory("")) != 0 {
+ t.Error("Parent history was polluted by orphan result")
+ }
+}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 1903412248..a8b8f337fa 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -234,6 +234,7 @@ type AgentDefaults struct {
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Routing *RoutingConfig `json:"routing,omitempty"`
+ SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
}
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go
index 189af0a845..caa09b0e2e 100644
--- a/pkg/config/defaults.go
+++ b/pkg/config/defaults.go
@@ -35,6 +35,7 @@ func DefaultConfig() *Config {
MaxToolIterations: 50,
SummarizeMessageThreshold: 20,
SummarizeTokenPercent: 75,
+ SteeringMode: "one-at-a-time",
},
},
Bindings: []AgentBinding{},
@@ -384,6 +385,15 @@ func DefaultConfig() *Config {
APIBase: "http://localhost:8000/v1",
APIKey: "",
},
+
+ // Azure OpenAI - https://portal.azure.com
+ // model_name is a user-friendly alias; the model field's path after "azure/" is your deployment name
+ {
+ ModelName: "azure-gpt5",
+ Model: "azure/my-gpt5-deployment",
+ APIBase: "https://your-resource.openai.azure.com",
+ APIKey: "",
+ },
},
Gateway: GatewayConfig{
Host: "127.0.0.1",
diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go
index 302613f338..4204cc192f 100644
--- a/pkg/logger/logger.go
+++ b/pkg/logger/logger.go
@@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"runtime"
+ "strconv"
"strings"
"sync"
@@ -45,6 +46,9 @@ func init() {
consoleWriter := zerolog.ConsoleWriter{
Out: os.Stdout,
TimeFormat: "15:04:05", // TODO: make it configurable???
+
+ // Custom formatter to handle multiline strings and JSON objects
+ FormatFieldValue: formatFieldValue,
}
logger = zerolog.New(consoleWriter).With().Timestamp().Logger()
@@ -52,6 +56,37 @@ func init() {
})
}
+func formatFieldValue(i any) string {
+ var s string
+
+ switch val := i.(type) {
+ case string:
+ s = val
+ case []byte:
+ s = string(val)
+ default:
+ return fmt.Sprintf("%v", i)
+ }
+
+ if unquoted, err := strconv.Unquote(s); err == nil {
+ s = unquoted
+ }
+
+ if strings.Contains(s, "\n") {
+ return fmt.Sprintf("\n%s", s)
+ }
+
+ if strings.Contains(s, " ") {
+ if (strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}")) ||
+ (strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]")) {
+ return s
+ }
+ return fmt.Sprintf("%q", s)
+ }
+
+ return s
+}
+
func SetLevel(level LogLevel) {
mu.Lock()
defer mu.Unlock()
@@ -163,10 +198,7 @@ func logMessage(level LogLevel, component string, message string, fields map[str
event.Str("caller", fmt.Sprintf(" %s:%d (%s)", callerFile, callerLine, callerFunc))
}
- for k, v := range fields {
- event.Interface(k, v)
- }
-
+ appendFields(event, fields)
event.Msg(message)
// Also log to file if enabled
@@ -176,9 +208,8 @@ func logMessage(level LogLevel, component string, message string, fields map[str
if component != "" {
fileEvent.Str("component", component)
}
- for k, v := range fields {
- fileEvent.Interface(k, v)
- }
+
+ appendFields(fileEvent, fields)
fileEvent.Msg(message)
}
@@ -187,6 +218,26 @@ func logMessage(level LogLevel, component string, message string, fields map[str
}
}
+func appendFields(event *zerolog.Event, fields map[string]any) {
+ for k, v := range fields {
+ // Type switch to avoid double JSON serialization of strings
+ switch val := v.(type) {
+ case string:
+ event.Str(k, val)
+ case int:
+ event.Int(k, val)
+ case int64:
+ event.Int64(k, val)
+ case float64:
+ event.Float64(k, val)
+ case bool:
+ event.Bool(k, val)
+ default:
+ event.Interface(k, v) // Fallback for struct, slice and maps
+ }
+ }
+}
+
func Debug(message string) {
logMessage(DEBUG, "", message, nil)
}
diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go
index 8170a618ba..31b40484cb 100644
--- a/pkg/logger/logger_test.go
+++ b/pkg/logger/logger_test.go
@@ -141,3 +141,114 @@ func TestLoggerHelperFunctions(t *testing.T) {
Debugf("test from %v", "Debugf")
WarnF("Warning with fields", map[string]any{"key": "value"})
}
+
+func TestFormatFieldValue(t *testing.T) {
+ tests := []struct {
+ name string
+ input any
+ expected string
+ }{
+ // Basic types test (default case of the switch)
+ {
+ name: "Integer Type",
+ input: 42,
+ expected: "42",
+ },
+ {
+ name: "Boolean Type",
+ input: true,
+ expected: "true",
+ },
+ {
+ name: "Unsupported Struct Type",
+ input: struct{ A int }{A: 1},
+ expected: "{1}",
+ },
+
+ // Simple strings and byte slices test
+ {
+ name: "Simple string without spaces",
+ input: "simple_value",
+ expected: "simple_value",
+ },
+ {
+ name: "Simple byte slice",
+ input: []byte("byte_value"),
+ expected: "byte_value",
+ },
+
+ // Unquoting test (strconv.Unquote)
+ {
+ name: "Quoted string",
+ input: `"quoted_value"`,
+ expected: "quoted_value",
+ },
+
+ // Strings with newline (\n) test
+ {
+ name: "String with newline",
+ input: "line1\nline2",
+ expected: "\nline1\nline2",
+ },
+ {
+ name: "Quoted string with newline (Unquote -> newline)",
+ input: `"line1\nline2"`, // Escaped \n that Unquote will resolve
+ expected: "\nline1\nline2",
+ },
+
+ // Strings with spaces test (which should be quoted)
+ {
+ name: "String with spaces",
+ input: "hello world",
+ expected: `"hello world"`,
+ },
+ {
+ name: "Quoted string with spaces (Unquote -> has spaces -> Re-quote)",
+ input: `"hello world"`,
+ expected: `"hello world"`,
+ },
+
+ // JSON formats test (strings with spaces that start/end with brackets)
+ {
+ name: "Valid JSON object",
+ input: `{"key": "value"}`,
+ expected: `{"key": "value"}`,
+ },
+ {
+ name: "Valid JSON array",
+ input: `[1, 2, "three"]`,
+ expected: `[1, 2, "three"]`,
+ },
+ {
+ name: "Fake JSON (starts with { but doesn't end with })",
+ input: `{"key": "value"`, // Missing closing bracket, has spaces
+ expected: `"{\"key\": \"value\""`,
+ },
+ {
+ name: "Empty JSON (object)",
+ input: `{ }`,
+ expected: `{ }`,
+ },
+
+ // 7. Edge Cases
+ {
+ name: "Empty string",
+ input: "",
+ expected: "",
+ },
+ {
+ name: "Whitespace only string",
+ input: " ",
+ expected: `" "`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ actual := formatFieldValue(tt.input)
+ if actual != tt.expected {
+ t.Errorf("formatFieldValue() = %q, expected %q", actual, tt.expected)
+ }
+ })
+ }
+}
diff --git a/pkg/providers/azure/provider.go b/pkg/providers/azure/provider.go
new file mode 100644
index 0000000000..e0ddbbde44
--- /dev/null
+++ b/pkg/providers/azure/provider.go
@@ -0,0 +1,150 @@
+package azure
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/providers/common"
+ "github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
+)
+
+type (
+ LLMResponse = protocoltypes.LLMResponse
+ Message = protocoltypes.Message
+ ToolDefinition = protocoltypes.ToolDefinition
+)
+
+const (
+ // azureAPIVersion is the Azure OpenAI API version used for all requests.
+ azureAPIVersion = "2024-10-21"
+ defaultRequestTimeout = common.DefaultRequestTimeout
+)
+
+// Provider implements the LLM provider interface for Azure OpenAI endpoints.
+// It handles Azure-specific authentication (api-key header), URL construction
+// (deployment-based), and request body formatting (max_completion_tokens, no model field).
+type Provider struct {
+ apiKey string
+ apiBase string
+ httpClient *http.Client
+}
+
+// Option configures the Azure Provider.
+type Option func(*Provider)
+
+// WithRequestTimeout sets the HTTP request timeout.
+func WithRequestTimeout(timeout time.Duration) Option {
+ return func(p *Provider) {
+ if timeout > 0 {
+ p.httpClient.Timeout = timeout
+ }
+ }
+}
+
+// NewProvider creates a new Azure OpenAI provider.
+func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
+ p := &Provider{
+ apiKey: apiKey,
+ apiBase: strings.TrimRight(apiBase, "/"),
+ httpClient: common.NewHTTPClient(proxy),
+ }
+
+ for _, opt := range opts {
+ if opt != nil {
+ opt(p)
+ }
+ }
+
+ return p
+}
+
+// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom request timeout in seconds.
+func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
+ return NewProvider(
+ apiKey, apiBase, proxy,
+ WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
+ )
+}
+
+// Chat sends a chat completion request to the Azure OpenAI endpoint.
+// The model parameter is used as the Azure deployment name in the URL.
+func (p *Provider) Chat(
+ ctx context.Context,
+ messages []Message,
+ tools []ToolDefinition,
+ model string,
+ options map[string]any,
+) (*LLMResponse, error) {
+ if p.apiBase == "" {
+ return nil, fmt.Errorf("Azure API base not configured")
+ }
+
+ // model is the deployment name for Azure OpenAI
+ deployment := model
+
+ // Build Azure-specific URL safely using url.JoinPath and query encoding
+ // to prevent path traversal or query injection via deployment names.
+ base, err := url.JoinPath(p.apiBase, "openai/deployments", url.PathEscape(deployment), "chat/completions")
+ if err != nil {
+ return nil, fmt.Errorf("failed to build Azure request URL: %w", err)
+ }
+ requestURL := base + "?api-version=" + azureAPIVersion
+
+ // Build request body — no "model" field (Azure infers from deployment URL)
+ requestBody := map[string]any{
+ "messages": common.SerializeMessages(messages),
+ }
+
+ if len(tools) > 0 {
+ requestBody["tools"] = tools
+ requestBody["tool_choice"] = "auto"
+ }
+
+ // Azure OpenAI always uses max_completion_tokens
+ if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
+ requestBody["max_completion_tokens"] = maxTokens
+ }
+
+ if temperature, ok := common.AsFloat(options["temperature"]); ok {
+ requestBody["temperature"] = temperature
+ }
+
+ jsonData, err := json.Marshal(requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(jsonData))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ // Azure uses api-key header instead of Authorization: Bearer
+ req.Header.Set("Content-Type", "application/json")
+ if p.apiKey != "" {
+ req.Header.Set("Api-Key", p.apiKey)
+ }
+
+ resp, err := p.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to send request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, common.HandleErrorResponse(resp, p.apiBase)
+ }
+
+ return common.ReadAndParseResponse(resp, p.apiBase)
+}
+
+// GetDefaultModel returns an empty string as Azure deployments are user-configured.
+func (p *Provider) GetDefaultModel() string {
+ return ""
+}
diff --git a/pkg/providers/azure/provider_test.go b/pkg/providers/azure/provider_test.go
new file mode 100644
index 0000000000..531b812965
--- /dev/null
+++ b/pkg/providers/azure/provider_test.go
@@ -0,0 +1,232 @@
+package azure
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+)
+
+// writeValidResponse writes a minimal valid Azure OpenAI chat completion response.
+func writeValidResponse(w http.ResponseWriter) {
+ resp := map[string]any{
+ "choices": []map[string]any{
+ {
+ "message": map[string]any{"content": "ok"},
+ "finish_reason": "stop",
+ },
+ },
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(resp)
+}
+
+func TestProviderChat_AzureURLConstruction(t *testing.T) {
+ var capturedPath string
+ var capturedAPIVersion string
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ capturedPath = r.URL.Path
+ capturedAPIVersion = r.URL.Query().Get("api-version")
+ writeValidResponse(w)
+ }))
+ defer server.Close()
+
+ p := NewProvider("test-key", server.URL, "")
+ _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil)
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+
+ wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions"
+ if capturedPath != wantPath {
+ t.Errorf("URL path = %q, want %q", capturedPath, wantPath)
+ }
+ if capturedAPIVersion != azureAPIVersion {
+ t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion)
+ }
+}
+
+func TestProviderChat_AzureAuthHeader(t *testing.T) {
+ var capturedAPIKey string
+ var capturedAuth string
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ capturedAPIKey = r.Header.Get("Api-Key")
+ capturedAuth = r.Header.Get("Authorization")
+ writeValidResponse(w)
+ }))
+ defer server.Close()
+
+ p := NewProvider("test-azure-key", server.URL, "")
+ _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+
+ if capturedAPIKey != "test-azure-key" {
+ t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key")
+ }
+ if capturedAuth != "" {
+ t.Errorf("Authorization header should be empty, got %q", capturedAuth)
+ }
+}
+
+func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
+ var requestBody map[string]any
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewDecoder(r.Body).Decode(&requestBody)
+ writeValidResponse(w)
+ }))
+ defer server.Close()
+
+ p := NewProvider("test-key", server.URL, "")
+ _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+
+ if _, exists := requestBody["model"]; exists {
+ t.Error("request body should not contain 'model' field for Azure OpenAI")
+ }
+}
+
+func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
+ var requestBody map[string]any
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewDecoder(r.Body).Decode(&requestBody)
+ writeValidResponse(w)
+ }))
+ defer server.Close()
+
+ p := NewProvider("test-key", server.URL, "")
+ _, err := p.Chat(
+ t.Context(),
+ []Message{{Role: "user", Content: "hi"}},
+ nil,
+ "deployment",
+ map[string]any{"max_tokens": 2048},
+ )
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+
+ if _, exists := requestBody["max_completion_tokens"]; !exists {
+ t.Error("request body should contain 'max_completion_tokens'")
+ }
+ if _, exists := requestBody["max_tokens"]; exists {
+ t.Error("request body should not contain 'max_tokens'")
+ }
+}
+
+func TestProviderChat_AzureHTTPError(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
+ }))
+ defer server.Close()
+
+ p := NewProvider("bad-key", server.URL, "")
+ _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+}
+
+func TestProviderChat_AzureParseToolCalls(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := map[string]any{
+ "choices": []map[string]any{
+ {
+ "message": map[string]any{
+ "content": "",
+ "tool_calls": []map[string]any{
+ {
+ "id": "call_1",
+ "type": "function",
+ "function": map[string]any{
+ "name": "get_weather",
+ "arguments": `{"city":"Seattle"}`,
+ },
+ },
+ },
+ },
+ "finish_reason": "tool_calls",
+ },
+ },
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer server.Close()
+
+ p := NewProvider("test-key", server.URL, "")
+ out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+
+ if len(out.ToolCalls) != 1 {
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
+ }
+ if out.ToolCalls[0].Name != "get_weather" {
+ t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
+ }
+}
+
+func TestProvider_AzureEmptyAPIBase(t *testing.T) {
+ p := NewProvider("test-key", "", "")
+ _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
+ if err == nil {
+ t.Fatal("expected error for empty API base")
+ }
+}
+
+func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
+ p := NewProvider("test-key", "https://example.com", "")
+ if p.httpClient.Timeout != defaultRequestTimeout {
+ t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
+ }
+}
+
+func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
+ p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second))
+ if p.httpClient.Timeout != 300*time.Second {
+ t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
+ }
+}
+
+func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
+ p := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
+ if p.httpClient.Timeout != 180*time.Second {
+ t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second)
+ }
+}
+
+func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
+ var capturedPath string
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ capturedPath = r.URL.RawPath // use RawPath to see percent-encoding
+ if capturedPath == "" {
+ capturedPath = r.URL.Path
+ }
+ writeValidResponse(w)
+ }))
+ defer server.Close()
+
+ p := NewProvider("test-key", server.URL, "")
+
+ // Deployment name with characters that could cause path injection
+ _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil)
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+
+ // The slash and special chars in the deployment name must be escaped, not treated as path separators
+ if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" {
+ t.Fatal("deployment name was interpolated without escaping — path injection possible")
+ }
+}
diff --git a/pkg/providers/common/common.go b/pkg/providers/common/common.go
new file mode 100644
index 0000000000..23680a1bf9
--- /dev/null
+++ b/pkg/providers/common/common.go
@@ -0,0 +1,380 @@
+// PicoClaw - Ultra-lightweight personal AI agent
+// License: MIT
+//
+// Copyright (c) 2026 PicoClaw contributors
+
+// Package common provides shared utilities used by multiple LLM provider
+// implementations (openai_compat, azure, etc.).
+package common
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
+)
+
+// Re-export protocol types used across providers.
+type (
+ ToolCall = protocoltypes.ToolCall
+ FunctionCall = protocoltypes.FunctionCall
+ LLMResponse = protocoltypes.LLMResponse
+ UsageInfo = protocoltypes.UsageInfo
+ Message = protocoltypes.Message
+ ToolDefinition = protocoltypes.ToolDefinition
+ ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
+ ExtraContent = protocoltypes.ExtraContent
+ GoogleExtra = protocoltypes.GoogleExtra
+ ReasoningDetail = protocoltypes.ReasoningDetail
+)
+
+const DefaultRequestTimeout = 120 * time.Second
+
+// NewHTTPClient creates an *http.Client with an optional proxy and the default timeout.
+func NewHTTPClient(proxy string) *http.Client {
+ client := &http.Client{
+ Timeout: DefaultRequestTimeout,
+ }
+ if proxy != "" {
+ parsed, err := url.Parse(proxy)
+ if err == nil {
+ // Preserve http.DefaultTransport settings (TLS, HTTP/2, timeouts, etc.)
+ if base, ok := http.DefaultTransport.(*http.Transport); ok {
+ tr := base.Clone()
+ tr.Proxy = http.ProxyURL(parsed)
+ client.Transport = tr
+ } else {
+ // Fallback: minimal transport if DefaultTransport is not *http.Transport.
+ client.Transport = &http.Transport{
+ Proxy: http.ProxyURL(parsed),
+ }
+ }
+ } else {
+ log.Printf("common: invalid proxy URL %q: %v", proxy, err)
+ }
+ }
+ return client
+}
+
+// --- Message serialization ---
+
+// openaiMessage is the wire-format message for OpenAI-compatible APIs.
+// It mirrors protocoltypes.Message but omits SystemParts, which is an
+// internal field that would be unknown to third-party endpoints.
+type openaiMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ ReasoningContent string `json:"reasoning_content,omitempty"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
+}
+
+// SerializeMessages converts internal Message structs to the OpenAI wire format.
+// - Strips SystemParts (unknown to third-party endpoints)
+// - Converts messages with Media to multipart content format (text + image_url parts)
+// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
+func SerializeMessages(messages []Message) []any {
+ out := make([]any, 0, len(messages))
+ for _, m := range messages {
+ if len(m.Media) == 0 {
+ out = append(out, openaiMessage{
+ Role: m.Role,
+ Content: m.Content,
+ ReasoningContent: m.ReasoningContent,
+ ToolCalls: m.ToolCalls,
+ ToolCallID: m.ToolCallID,
+ })
+ continue
+ }
+
+ // Multipart content format for messages with media
+ parts := make([]map[string]any, 0, 1+len(m.Media))
+ if m.Content != "" {
+ parts = append(parts, map[string]any{
+ "type": "text",
+ "text": m.Content,
+ })
+ }
+ for _, mediaURL := range m.Media {
+ if strings.HasPrefix(mediaURL, "data:image/") {
+ parts = append(parts, map[string]any{
+ "type": "image_url",
+ "image_url": map[string]any{
+ "url": mediaURL,
+ },
+ })
+ }
+ }
+
+ msg := map[string]any{
+ "role": m.Role,
+ "content": parts,
+ }
+ if m.ToolCallID != "" {
+ msg["tool_call_id"] = m.ToolCallID
+ }
+ if len(m.ToolCalls) > 0 {
+ msg["tool_calls"] = m.ToolCalls
+ }
+ if m.ReasoningContent != "" {
+ msg["reasoning_content"] = m.ReasoningContent
+ }
+ out = append(out, msg)
+ }
+ return out
+}
+
+// --- Response parsing ---
+
+// ParseResponse parses a JSON chat completion response body into an LLMResponse.
+func ParseResponse(body io.Reader) (*LLMResponse, error) {
+ var apiResponse struct {
+ Choices []struct {
+ Message struct {
+ Content string `json:"content"`
+ ReasoningContent string `json:"reasoning_content"`
+ Reasoning string `json:"reasoning"`
+ ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
+ ToolCalls []struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Function *struct {
+ Name string `json:"name"`
+ Arguments json.RawMessage `json:"arguments"`
+ } `json:"function"`
+ ExtraContent *struct {
+ Google *struct {
+ ThoughtSignature string `json:"thought_signature"`
+ } `json:"google"`
+ } `json:"extra_content"`
+ } `json:"tool_calls"`
+ } `json:"message"`
+ FinishReason string `json:"finish_reason"`
+ } `json:"choices"`
+ Usage *UsageInfo `json:"usage"`
+ }
+
+ if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
+ return nil, fmt.Errorf("failed to decode response: %w", err)
+ }
+
+ if len(apiResponse.Choices) == 0 {
+ return &LLMResponse{
+ Content: "",
+ FinishReason: "stop",
+ }, nil
+ }
+
+ choice := apiResponse.Choices[0]
+ toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
+ for _, tc := range choice.Message.ToolCalls {
+ arguments := make(map[string]any)
+ name := ""
+
+ // Extract thought_signature from Gemini/Google-specific extra content
+ thoughtSignature := ""
+ if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
+ thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
+ }
+
+ if tc.Function != nil {
+ name = tc.Function.Name
+ arguments = DecodeToolCallArguments(tc.Function.Arguments, name)
+ }
+
+ toolCall := ToolCall{
+ ID: tc.ID,
+ Name: name,
+ Arguments: arguments,
+ ThoughtSignature: thoughtSignature,
+ }
+
+ if thoughtSignature != "" {
+ toolCall.ExtraContent = &ExtraContent{
+ Google: &GoogleExtra{
+ ThoughtSignature: thoughtSignature,
+ },
+ }
+ }
+
+ toolCalls = append(toolCalls, toolCall)
+ }
+
+ return &LLMResponse{
+ Content: choice.Message.Content,
+ ReasoningContent: choice.Message.ReasoningContent,
+ Reasoning: choice.Message.Reasoning,
+ ReasoningDetails: choice.Message.ReasoningDetails,
+ ToolCalls: toolCalls,
+ FinishReason: choice.FinishReason,
+ Usage: apiResponse.Usage,
+ }, nil
+}
+
+// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
+func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
+ arguments := make(map[string]any)
+ raw = bytes.TrimSpace(raw)
+ if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
+ return arguments
+ }
+
+ var decoded any
+ if err := json.Unmarshal(raw, &decoded); err != nil {
+ log.Printf("common: failed to decode tool call arguments payload for %q: %v", name, err)
+ arguments["raw"] = string(raw)
+ return arguments
+ }
+
+ switch v := decoded.(type) {
+ case string:
+ if strings.TrimSpace(v) == "" {
+ return arguments
+ }
+ if err := json.Unmarshal([]byte(v), &arguments); err != nil {
+ log.Printf("common: failed to decode tool call arguments for %q: %v", name, err)
+ arguments["raw"] = v
+ }
+ return arguments
+ case map[string]any:
+ return v
+ default:
+ log.Printf("common: unsupported tool call arguments type for %q: %T", name, decoded)
+ arguments["raw"] = string(raw)
+ return arguments
+ }
+}
+
+// --- HTTP response helpers ---
+
+// HandleErrorResponse reads a non-200 response body and returns an appropriate error.
+func HandleErrorResponse(resp *http.Response, apiBase string) error {
+ contentType := resp.Header.Get("Content-Type")
+ body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
+ if readErr != nil {
+ return fmt.Errorf("failed to read response: %w", readErr)
+ }
+ if LooksLikeHTML(body, contentType) {
+ return WrapHTMLResponseError(resp.StatusCode, body, contentType, apiBase)
+ }
+ return fmt.Errorf(
+ "API request failed:\n Status: %d\n Body: %s",
+ resp.StatusCode,
+ ResponsePreview(body, 128),
+ )
+}
+
+// ReadAndParseResponse peeks at the response body to detect HTML errors,
+// then parses the JSON response into an LLMResponse.
+func ReadAndParseResponse(resp *http.Response, apiBase string) (*LLMResponse, error) {
+ contentType := resp.Header.Get("Content-Type")
+ reader := bufio.NewReader(resp.Body)
+ prefix, err := reader.Peek(256)
+ if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
+ return nil, fmt.Errorf("failed to inspect response: %w", err)
+ }
+ if LooksLikeHTML(prefix, contentType) {
+ return nil, WrapHTMLResponseError(resp.StatusCode, prefix, contentType, apiBase)
+ }
+ out, err := ParseResponse(reader)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse JSON response: %w", err)
+ }
+ return out, nil
+}
+
+// LooksLikeHTML checks if the response body appears to be HTML.
+func LooksLikeHTML(body []byte, contentType string) bool {
+ contentType = strings.ToLower(strings.TrimSpace(contentType))
+ if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
+ return true
+ }
+ prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
+ return bytes.HasPrefix(prefix, []byte(""
+ }
+ if len(trimmed) <= maxLen {
+ return string(trimmed)
+ }
+ return string(trimmed[:maxLen]) + "..."
+}
+
+func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
+ i := 0
+ for i < len(body) {
+ switch body[i] {
+ case ' ', '\t', '\n', '\r', '\f', '\v':
+ i++
+ default:
+ end := i + maxLen
+ if end > len(body) {
+ end = len(body)
+ }
+ return body[i:end]
+ }
+ }
+ return nil
+}
+
+// --- Numeric helpers ---
+
+// AsInt converts various numeric types to int.
+func AsInt(v any) (int, bool) {
+ switch val := v.(type) {
+ case int:
+ return val, true
+ case int64:
+ return int(val), true
+ case float64:
+ return int(val), true
+ case float32:
+ return int(val), true
+ default:
+ return 0, false
+ }
+}
+
+// AsFloat converts various numeric types to float64.
+func AsFloat(v any) (float64, bool) {
+ switch val := v.(type) {
+ case float64:
+ return val, true
+ case float32:
+ return float64(val), true
+ case int:
+ return float64(val), true
+ case int64:
+ return float64(val), true
+ default:
+ return 0, false
+ }
+}
diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go
new file mode 100644
index 0000000000..bb7e7434d0
--- /dev/null
+++ b/pkg/providers/common/common_test.go
@@ -0,0 +1,558 @@
+package common
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
+)
+
+// --- NewHTTPClient tests ---
+
+func TestNewHTTPClient_DefaultTimeout(t *testing.T) {
+ client := NewHTTPClient("")
+ if client.Timeout != DefaultRequestTimeout {
+ t.Errorf("timeout = %v, want %v", client.Timeout, DefaultRequestTimeout)
+ }
+}
+
+func TestNewHTTPClient_WithProxy(t *testing.T) {
+ client := NewHTTPClient("http://127.0.0.1:8080")
+ transport, ok := client.Transport.(*http.Transport)
+ if !ok || transport == nil {
+ t.Fatalf("expected http.Transport with proxy, got %T", client.Transport)
+ }
+ req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
+ gotProxy, err := transport.Proxy(req)
+ if err != nil {
+ t.Fatalf("proxy function error: %v", err)
+ }
+ if gotProxy == nil || gotProxy.String() != "http://127.0.0.1:8080" {
+ t.Errorf("proxy = %v, want http://127.0.0.1:8080", gotProxy)
+ }
+}
+
+func TestNewHTTPClient_NoProxy(t *testing.T) {
+ client := NewHTTPClient("")
+ if client.Transport != nil {
+ t.Errorf("expected nil transport without proxy, got %T", client.Transport)
+ }
+}
+
+func TestNewHTTPClient_InvalidProxy(t *testing.T) {
+ // Should not panic, just log and return client without proxy
+ client := NewHTTPClient("://bad-url")
+ if client == nil {
+ t.Fatal("expected non-nil client even with invalid proxy")
+ }
+}
+
+// --- SerializeMessages tests ---
+
+func TestSerializeMessages_PlainText(t *testing.T) {
+ messages := []Message{
+ {Role: "user", Content: "hello"},
+ {Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
+ }
+ result := SerializeMessages(messages)
+
+ data, _ := json.Marshal(result)
+ var msgs []map[string]any
+ json.Unmarshal(data, &msgs)
+
+ if msgs[0]["content"] != "hello" {
+ t.Errorf("expected plain string content, got %v", msgs[0]["content"])
+ }
+ if msgs[1]["reasoning_content"] != "thinking..." {
+ t.Errorf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
+ }
+}
+
+func TestSerializeMessages_WithMedia(t *testing.T) {
+ messages := []Message{
+ {Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
+ }
+ result := SerializeMessages(messages)
+
+ data, _ := json.Marshal(result)
+ var msgs []map[string]any
+ json.Unmarshal(data, &msgs)
+
+ content, ok := msgs[0]["content"].([]any)
+ if !ok {
+ t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
+ }
+ if len(content) != 2 {
+ t.Fatalf("expected 2 content parts, got %d", len(content))
+ }
+}
+
+func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
+ messages := []Message{
+ {Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
+ }
+ result := SerializeMessages(messages)
+
+ data, _ := json.Marshal(result)
+ var msgs []map[string]any
+ json.Unmarshal(data, &msgs)
+
+ if msgs[0]["tool_call_id"] != "call_1" {
+ t.Errorf("tool_call_id not preserved, got %v", msgs[0]["tool_call_id"])
+ }
+}
+
+func TestSerializeMessages_StripsSystemParts(t *testing.T) {
+ messages := []Message{
+ {
+ Role: "system",
+ Content: "you are helpful",
+ SystemParts: []protocoltypes.ContentBlock{
+ {Type: "text", Text: "you are helpful"},
+ },
+ },
+ }
+ result := SerializeMessages(messages)
+
+ data, _ := json.Marshal(result)
+ if strings.Contains(string(data), "system_parts") {
+ t.Error("system_parts should not appear in serialized output")
+ }
+}
+
+// --- ParseResponse tests ---
+
+func TestParseResponse_BasicContent(t *testing.T) {
+ body := `{"choices":[{"message":{"content":"hello world"},"finish_reason":"stop"}]}`
+ out, err := ParseResponse(strings.NewReader(body))
+ if err != nil {
+ t.Fatalf("ParseResponse() error = %v", err)
+ }
+ if out.Content != "hello world" {
+ t.Errorf("Content = %q, want %q", out.Content, "hello world")
+ }
+ if out.FinishReason != "stop" {
+ t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
+ }
+}
+
+func TestParseResponse_EmptyChoices(t *testing.T) {
+ body := `{"choices":[]}`
+ out, err := ParseResponse(strings.NewReader(body))
+ if err != nil {
+ t.Fatalf("ParseResponse() error = %v", err)
+ }
+ if out.Content != "" {
+ t.Errorf("Content = %q, want empty", out.Content)
+ }
+ if out.FinishReason != "stop" {
+ t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
+ }
+}
+
+func TestParseResponse_WithToolCalls(t *testing.T) {
+ body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"SF\"}"}}]},"finish_reason":"tool_calls"}]}`
+ out, err := ParseResponse(strings.NewReader(body))
+ if err != nil {
+ t.Fatalf("ParseResponse() error = %v", err)
+ }
+ if len(out.ToolCalls) != 1 {
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
+ }
+ if out.ToolCalls[0].Name != "get_weather" {
+ t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
+ }
+ if out.ToolCalls[0].Arguments["city"] != "SF" {
+ t.Errorf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
+ }
+}
+
+func TestParseResponse_WithUsage(t *testing.T) {
+ body := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}`
+ out, err := ParseResponse(strings.NewReader(body))
+ if err != nil {
+ t.Fatalf("ParseResponse() error = %v", err)
+ }
+ if out.Usage == nil {
+ t.Fatal("Usage is nil")
+ }
+ if out.Usage.PromptTokens != 10 {
+ t.Errorf("PromptTokens = %d, want 10", out.Usage.PromptTokens)
+ }
+}
+
+func TestParseResponse_WithReasoningContent(t *testing.T) {
+ body := `{"choices":[{"message":{"content":"2","reasoning_content":"Let me think... 1+1=2"},"finish_reason":"stop"}]}`
+ out, err := ParseResponse(strings.NewReader(body))
+ if err != nil {
+ t.Fatalf("ParseResponse() error = %v", err)
+ }
+ if out.ReasoningContent != "Let me think... 1+1=2" {
+ t.Errorf("ReasoningContent = %q, want %q", out.ReasoningContent, "Let me think... 1+1=2")
+ }
+}
+
+func TestParseResponse_InvalidJSON(t *testing.T) {
+ _, err := ParseResponse(strings.NewReader("not json"))
+ if err == nil {
+ t.Fatal("expected error for invalid JSON")
+ }
+}
+
+// --- DecodeToolCallArguments tests ---
+
+func TestDecodeToolCallArguments_ObjectJSON(t *testing.T) {
+ raw := json.RawMessage(`{"city":"Seattle","units":"metric"}`)
+ args := DecodeToolCallArguments(raw, "test")
+ if args["city"] != "Seattle" {
+ t.Errorf("city = %v, want Seattle", args["city"])
+ }
+ if args["units"] != "metric" {
+ t.Errorf("units = %v, want metric", args["units"])
+ }
+}
+
+func TestDecodeToolCallArguments_StringJSON(t *testing.T) {
+ raw := json.RawMessage(`"{\"city\":\"SF\"}"`)
+ args := DecodeToolCallArguments(raw, "test")
+ if args["city"] != "SF" {
+ t.Errorf("city = %v, want SF", args["city"])
+ }
+}
+
+func TestDecodeToolCallArguments_EmptyInput(t *testing.T) {
+ args := DecodeToolCallArguments(nil, "test")
+ if len(args) != 0 {
+ t.Errorf("expected empty map, got %v", args)
+ }
+}
+
+func TestDecodeToolCallArguments_NullInput(t *testing.T) {
+ args := DecodeToolCallArguments(json.RawMessage(`null`), "test")
+ if len(args) != 0 {
+ t.Errorf("expected empty map, got %v", args)
+ }
+}
+
+func TestDecodeToolCallArguments_InvalidJSON(t *testing.T) {
+ args := DecodeToolCallArguments(json.RawMessage(`not-json`), "test")
+ if _, ok := args["raw"]; !ok {
+ t.Error("expected 'raw' fallback key for invalid JSON")
+ }
+}
+
+func TestDecodeToolCallArguments_EmptyStringJSON(t *testing.T) {
+ args := DecodeToolCallArguments(json.RawMessage(`" "`), "test")
+ if len(args) != 0 {
+ t.Errorf("expected empty map for whitespace string, got %v", args)
+ }
+}
+
+// --- HandleErrorResponse tests ---
+
+func TestHandleErrorResponse_JSONError(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ w.Write([]byte(`{"error":"bad request"}`))
+ }))
+ defer server.Close()
+
+ resp, err := http.Get(server.URL)
+ if err != nil {
+ t.Fatalf("http.Get() error = %v", err)
+ }
+ defer resp.Body.Close()
+ err = HandleErrorResponse(resp, server.URL)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "400") {
+ t.Errorf("error should contain status code, got %v", err)
+ }
+ if strings.Contains(err.Error(), "HTML") {
+ t.Errorf("should not mention HTML for JSON error, got %v", err)
+ }
+}
+
+func TestHandleErrorResponse_HTMLError(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/html")
+ w.WriteHeader(http.StatusBadGateway)
+ w.Write([]byte("bad gateway"))
+ }))
+ defer server.Close()
+
+ resp, err := http.Get(server.URL)
+ if err != nil {
+ t.Fatalf("http.Get() error = %v", err)
+ }
+ defer resp.Body.Close()
+ err = HandleErrorResponse(resp, server.URL)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "HTML instead of JSON") {
+ t.Errorf("expected HTML error message, got %v", err)
+ }
+}
+
+// --- ReadAndParseResponse tests ---
+
+func TestReadAndParseResponse_ValidJSON(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
+ }))
+ defer server.Close()
+
+ resp, err := http.Get(server.URL)
+ if err != nil {
+ t.Fatalf("http.Get() error = %v", err)
+ }
+ defer resp.Body.Close()
+ out, err := ReadAndParseResponse(resp, server.URL)
+ if err != nil {
+ t.Fatalf("ReadAndParseResponse() error = %v", err)
+ }
+ if out.Content != "ok" {
+ t.Errorf("Content = %q, want %q", out.Content, "ok")
+ }
+}
+
+func TestReadAndParseResponse_HTMLResponse(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/html")
+ w.Write([]byte("login page"))
+ }))
+ defer server.Close()
+
+ resp, err := http.Get(server.URL)
+ if err != nil {
+ t.Fatalf("http.Get() error = %v", err)
+ }
+ defer resp.Body.Close()
+ _, err = ReadAndParseResponse(resp, server.URL)
+ if err == nil {
+ t.Fatal("expected error for HTML response")
+ }
+ if !strings.Contains(err.Error(), "HTML instead of JSON") {
+ t.Errorf("expected HTML error, got %v", err)
+ }
+}
+
+// --- LooksLikeHTML tests ---
+
+func TestLooksLikeHTML_ContentTypeHTML(t *testing.T) {
+ if !LooksLikeHTML(nil, "text/html; charset=utf-8") {
+ t.Error("expected true for text/html content type")
+ }
+}
+
+func TestLooksLikeHTML_ContentTypeXHTML(t *testing.T) {
+ if !LooksLikeHTML(nil, "application/xhtml+xml") {
+ t.Error("expected true for xhtml content type")
+ }
+}
+
+func TestLooksLikeHTML_BodyPrefix(t *testing.T) {
+ tests := []struct {
+ name string
+ body string
+ }{
+ {"doctype", ""},
+ {"html tag", ""},
+ {"head tag", ""},
+ {"body tag", "content"},
+ {"whitespace before", " \n\t"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if !LooksLikeHTML([]byte(tt.body), "application/json") {
+ t.Errorf("expected true for body %q", tt.body)
+ }
+ })
+ }
+}
+
+func TestLooksLikeHTML_NotHTML(t *testing.T) {
+ if LooksLikeHTML([]byte(`{"error":"bad"}`), "application/json") {
+ t.Error("expected false for JSON body")
+ }
+}
+
+// --- ResponsePreview tests ---
+
+func TestResponsePreview_Short(t *testing.T) {
+ got := ResponsePreview([]byte("hello"), 128)
+ if got != "hello" {
+ t.Errorf("got %q, want %q", got, "hello")
+ }
+}
+
+func TestResponsePreview_Truncated(t *testing.T) {
+ body := strings.Repeat("a", 200)
+ got := ResponsePreview([]byte(body), 128)
+ if len(got) != 131 { // 128 + "..."
+ t.Errorf("len = %d, want 131", len(got))
+ }
+ if !strings.HasSuffix(got, "...") {
+ t.Error("expected ... suffix")
+ }
+}
+
+func TestResponsePreview_Empty(t *testing.T) {
+ got := ResponsePreview([]byte(""), 128)
+ if got != "" {
+ t.Errorf("got %q, want %q", got, "")
+ }
+}
+
+func TestResponsePreview_Whitespace(t *testing.T) {
+ got := ResponsePreview([]byte(" \n\t "), 128)
+ if got != "" {
+ t.Errorf("got %q, want %q for whitespace-only body", got, "")
+ }
+}
+
+// --- AsInt tests ---
+
+func TestAsInt(t *testing.T) {
+ tests := []struct {
+ name string
+ val any
+ want int
+ ok bool
+ }{
+ {"int", 42, 42, true},
+ {"int64", int64(99), 99, true},
+ {"float64", float64(512), 512, true},
+ {"float32", float32(256), 256, true},
+ {"string", "nope", 0, false},
+ {"nil", nil, 0, false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, ok := AsInt(tt.val)
+ if ok != tt.ok || got != tt.want {
+ t.Errorf("AsInt(%v) = (%d, %v), want (%d, %v)", tt.val, got, ok, tt.want, tt.ok)
+ }
+ })
+ }
+}
+
+// --- AsFloat tests ---
+
+func TestAsFloat(t *testing.T) {
+ tests := []struct {
+ name string
+ val any
+ want float64
+ ok bool
+ }{
+ {"float64", float64(0.7), 0.7, true},
+ {"float32", float32(0.5), float64(float32(0.5)), true},
+ {"int", 1, 1.0, true},
+ {"int64", int64(100), 100.0, true},
+ {"string", "nope", 0, false},
+ {"nil", nil, 0, false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, ok := AsFloat(tt.val)
+ if ok != tt.ok || got != tt.want {
+ t.Errorf("AsFloat(%v) = (%f, %v), want (%f, %v)", tt.val, got, ok, tt.want, tt.ok)
+ }
+ })
+ }
+}
+
+// --- WrapHTMLResponseError tests ---
+
+func TestWrapHTMLResponseError(t *testing.T) {
+ err := WrapHTMLResponseError(502, []byte("bad"), "text/html", "https://api.example.com")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ msg := err.Error()
+ if !strings.Contains(msg, "502") {
+ t.Errorf("expected status code in error, got %v", msg)
+ }
+ if !strings.Contains(msg, "https://api.example.com") {
+ t.Errorf("expected api base in error, got %v", msg)
+ }
+ if !strings.Contains(msg, "HTML instead of JSON") {
+ t.Errorf("expected HTML mention in error, got %v", msg)
+ }
+}
+
+// --- HandleErrorResponse with read failure ---
+
+func TestHandleErrorResponse_EmptyBody(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusInternalServerError)
+ // empty body
+ }))
+ defer server.Close()
+
+ resp, err := http.Get(server.URL)
+ if err != nil {
+ t.Fatalf("http.Get() error = %v", err)
+ }
+ defer resp.Body.Close()
+ err = HandleErrorResponse(resp, server.URL)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "500") {
+ t.Errorf("expected status code, got %v", err)
+ }
+}
+
+// --- ReadAndParseResponse with invalid JSON ---
+
+func TestReadAndParseResponse_InvalidJSON(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.Write([]byte("not valid json"))
+ }))
+ defer server.Close()
+
+ resp, err := http.Get(server.URL)
+ if err != nil {
+ t.Fatalf("http.Get() error = %v", err)
+ }
+ defer resp.Body.Close()
+ _, err = ReadAndParseResponse(resp, server.URL)
+ if err == nil {
+ t.Fatal("expected error for invalid JSON")
+ }
+}
+
+// --- ParseResponse with thought_signature (Google/Gemini) ---
+
+func TestParseResponse_WithThoughtSignature(t *testing.T) {
+ body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"google":{"thought_signature":"sig123"}}}]},"finish_reason":"tool_calls"}]}`
+ out, err := ParseResponse(strings.NewReader(body))
+ if err != nil {
+ t.Fatalf("ParseResponse() error = %v", err)
+ }
+ if len(out.ToolCalls) != 1 {
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
+ }
+ if out.ToolCalls[0].ThoughtSignature != "sig123" {
+ t.Errorf("ThoughtSignature = %q, want %q", out.ToolCalls[0].ThoughtSignature, "sig123")
+ }
+ if out.ToolCalls[0].ExtraContent == nil || out.ToolCalls[0].ExtraContent.Google == nil {
+ t.Fatal("ExtraContent.Google is nil")
+ }
+ if out.ToolCalls[0].ExtraContent.Google.ThoughtSignature != "sig123" {
+ t.Errorf("ExtraContent.Google.ThoughtSignature = %q, want %q",
+ out.ToolCalls[0].ExtraContent.Google.ThoughtSignature, "sig123")
+ }
+}
diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go
index e99e07bc26..b7567f9fcf 100644
--- a/pkg/providers/factory_provider.go
+++ b/pkg/providers/factory_provider.go
@@ -11,6 +11,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
+ "github.com/sipeed/picoclaw/pkg/providers/azure"
)
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
@@ -94,6 +95,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.RequestTimeout,
), modelID, nil
+ case "azure", "azure-openai":
+ // Azure OpenAI uses deployment-based URLs, api-key header auth,
+ // and always sends max_completion_tokens.
+ if cfg.APIKey == "" {
+ return nil, "", fmt.Errorf("api_key is required for azure protocol")
+ }
+ if cfg.APIBase == "" {
+ return nil, "", fmt.Errorf(
+ "api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
+ )
+ }
+ return azure.NewProviderWithTimeout(
+ cfg.APIKey,
+ cfg.APIBase,
+ cfg.Proxy,
+ cfg.RequestTimeout,
+ ), modelID, nil
+
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go
index 00676ebf98..b678a7eb61 100644
--- a/pkg/providers/factory_provider_test.go
+++ b/pkg/providers/factory_provider_test.go
@@ -64,6 +64,12 @@ func TestExtractProtocol(t *testing.T) {
wantProtocol: "nvidia",
wantModelID: "meta/llama-3.1-8b",
},
+ {
+ name: "azure with prefix",
+ model: "azure/my-gpt5-deployment",
+ wantProtocol: "azure",
+ wantModelID: "my-gpt5-deployment",
+ },
}
for _, tt := range tests {
@@ -371,3 +377,69 @@ func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
t.Fatalf("Chat() error = %q, want timeout-related error", errMsg)
}
}
+
+func TestCreateProviderFromConfig_Azure(t *testing.T) {
+ cfg := &config.ModelConfig{
+ ModelName: "azure-gpt5",
+ Model: "azure/my-gpt5-deployment",
+ APIKey: "test-azure-key",
+ APIBase: "https://my-resource.openai.azure.com",
+ }
+
+ provider, modelID, err := CreateProviderFromConfig(cfg)
+ if err != nil {
+ t.Fatalf("CreateProviderFromConfig() error = %v", err)
+ }
+ if provider == nil {
+ t.Fatal("CreateProviderFromConfig() returned nil provider")
+ }
+ if modelID != "my-gpt5-deployment" {
+ t.Errorf("modelID = %q, want %q", modelID, "my-gpt5-deployment")
+ }
+}
+
+func TestCreateProviderFromConfig_AzureOpenAIAlias(t *testing.T) {
+ cfg := &config.ModelConfig{
+ ModelName: "azure-gpt4",
+ Model: "azure-openai/my-deployment",
+ APIKey: "test-azure-key",
+ APIBase: "https://my-resource.openai.azure.com",
+ }
+
+ provider, modelID, err := CreateProviderFromConfig(cfg)
+ if err != nil {
+ t.Fatalf("CreateProviderFromConfig() error = %v", err)
+ }
+ if provider == nil {
+ t.Fatal("CreateProviderFromConfig() returned nil provider")
+ }
+ if modelID != "my-deployment" {
+ t.Errorf("modelID = %q, want %q", modelID, "my-deployment")
+ }
+}
+
+func TestCreateProviderFromConfig_AzureMissingAPIKey(t *testing.T) {
+ cfg := &config.ModelConfig{
+ ModelName: "azure-gpt5",
+ Model: "azure/my-gpt5-deployment",
+ APIBase: "https://my-resource.openai.azure.com",
+ }
+
+ _, _, err := CreateProviderFromConfig(cfg)
+ if err == nil {
+ t.Fatal("CreateProviderFromConfig() expected error for missing API key")
+ }
+}
+
+func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) {
+ cfg := &config.ModelConfig{
+ ModelName: "azure-gpt5",
+ Model: "azure/my-gpt5-deployment",
+ APIKey: "test-azure-key",
+ }
+
+ _, _, err := CreateProviderFromConfig(cfg)
+ if err == nil {
+ t.Fatal("CreateProviderFromConfig() expected error for missing API base")
+ }
+}
diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go
index f97bf3acd5..fb2abaa5c2 100644
--- a/pkg/providers/openai_compat/provider.go
+++ b/pkg/providers/openai_compat/provider.go
@@ -1,18 +1,16 @@
package openai_compat
import (
- "bufio"
"bytes"
"context"
"encoding/json"
"fmt"
- "io"
- "log"
"net/http"
"net/url"
"strings"
"time"
+ "github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -38,7 +36,7 @@ type Provider struct {
type Option func(*Provider)
-const defaultRequestTimeout = 120 * time.Second
+const defaultRequestTimeout = common.DefaultRequestTimeout
func WithMaxTokensField(maxTokensField string) Option {
return func(p *Provider) {
@@ -55,25 +53,10 @@ func WithRequestTimeout(timeout time.Duration) Option {
}
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
- client := &http.Client{
- Timeout: defaultRequestTimeout,
- }
-
- if proxy != "" {
- parsed, err := url.Parse(proxy)
- if err == nil {
- client.Transport = &http.Transport{
- Proxy: http.ProxyURL(parsed),
- }
- } else {
- log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
- }
- }
-
p := &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
- httpClient: client,
+ httpClient: common.NewHTTPClient(proxy),
}
for _, opt := range opts {
@@ -117,7 +100,7 @@ func (p *Provider) Chat(
requestBody := map[string]any{
"model": model,
- "messages": serializeMessages(messages),
+ "messages": common.SerializeMessages(messages),
}
if len(tools) > 0 {
@@ -125,7 +108,7 @@ func (p *Provider) Chat(
requestBody["tool_choice"] = "auto"
}
- if maxTokens, ok := asInt(options["max_tokens"]); ok {
+ if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
// Use configured maxTokensField if specified, otherwise fallback to model-based detection
fieldName := p.maxTokensField
if fieldName == "" {
@@ -141,7 +124,7 @@ func (p *Provider) Chat(
requestBody[fieldName] = maxTokens
}
- if temperature, ok := asFloat(options["temperature"]); ok {
+ if temperature, ok := common.AsFloat(options["temperature"]); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1.
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
@@ -185,275 +168,11 @@ func (p *Provider) Chat(
}
defer resp.Body.Close()
- contentType := resp.Header.Get("Content-Type")
-
- // Non-200: read a prefix to tell HTML error page apart from JSON error body.
if resp.StatusCode != http.StatusOK {
- body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
- if readErr != nil {
- return nil, fmt.Errorf("failed to read response: %w", readErr)
- }
- if looksLikeHTML(body, contentType) {
- return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase)
- }
- return nil, fmt.Errorf(
- "API request failed:\n Status: %d\n Body: %s",
- resp.StatusCode,
- responsePreview(body, 128),
- )
- }
-
- // Peek without consuming so the full stream reaches the JSON decoder.
- reader := bufio.NewReader(resp.Body)
- prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort
- if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
- return nil, fmt.Errorf("failed to inspect response: %w", err)
- }
- if looksLikeHTML(prefix, contentType) {
- return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase)
- }
-
- out, err := parseResponse(reader)
- if err != nil {
- return nil, fmt.Errorf("failed to parse JSON response: %w", err)
+ return nil, common.HandleErrorResponse(resp, p.apiBase)
}
- return out, nil
-}
-
-func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
- respPreview := responsePreview(body, 128)
- return fmt.Errorf(
- "API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
- apiBase,
- contentType,
- statusCode,
- respPreview,
- )
-}
-
-func looksLikeHTML(body []byte, contentType string) bool {
- contentType = strings.ToLower(strings.TrimSpace(contentType))
- if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
- return true
- }
- prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
- return bytes.HasPrefix(prefix, []byte(" len(body) {
- end = len(body)
- }
- return body[i:end]
- }
- }
- return nil
-}
-
-func responsePreview(body []byte, maxLen int) string {
- trimmed := bytes.TrimSpace(body)
- if len(trimmed) == 0 {
- return ""
- }
- if len(trimmed) <= maxLen {
- return string(trimmed)
- }
- return string(trimmed[:maxLen]) + "..."
-}
-
-func parseResponse(body io.Reader) (*LLMResponse, error) {
- var apiResponse struct {
- Choices []struct {
- Message struct {
- Content string `json:"content"`
- ReasoningContent string `json:"reasoning_content"`
- Reasoning string `json:"reasoning"`
- ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
- ToolCalls []struct {
- ID string `json:"id"`
- Type string `json:"type"`
- Function *struct {
- Name string `json:"name"`
- Arguments json.RawMessage `json:"arguments"`
- } `json:"function"`
- ExtraContent *struct {
- Google *struct {
- ThoughtSignature string `json:"thought_signature"`
- } `json:"google"`
- } `json:"extra_content"`
- } `json:"tool_calls"`
- } `json:"message"`
- FinishReason string `json:"finish_reason"`
- } `json:"choices"`
- Usage *UsageInfo `json:"usage"`
- }
-
- if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
- return nil, fmt.Errorf("failed to decode response: %w", err)
- }
-
- if len(apiResponse.Choices) == 0 {
- return &LLMResponse{
- Content: "",
- FinishReason: "stop",
- }, nil
- }
-
- choice := apiResponse.Choices[0]
- toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
- for _, tc := range choice.Message.ToolCalls {
- arguments := make(map[string]any)
- name := ""
-
- // Extract thought_signature from Gemini/Google-specific extra content
- thoughtSignature := ""
- if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
- thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
- }
-
- if tc.Function != nil {
- name = tc.Function.Name
- arguments = decodeToolCallArguments(tc.Function.Arguments, name)
- }
-
- // Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence
- toolCall := ToolCall{
- ID: tc.ID,
- Name: name,
- Arguments: arguments,
- ThoughtSignature: thoughtSignature,
- }
-
- if thoughtSignature != "" {
- toolCall.ExtraContent = &ExtraContent{
- Google: &GoogleExtra{
- ThoughtSignature: thoughtSignature,
- },
- }
- }
-
- toolCalls = append(toolCalls, toolCall)
- }
-
- return &LLMResponse{
- Content: choice.Message.Content,
- ReasoningContent: choice.Message.ReasoningContent,
- Reasoning: choice.Message.Reasoning,
- ReasoningDetails: choice.Message.ReasoningDetails,
- ToolCalls: toolCalls,
- FinishReason: choice.FinishReason,
- Usage: apiResponse.Usage,
- }, nil
-}
-
-func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
- arguments := make(map[string]any)
- raw = bytes.TrimSpace(raw)
- if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
- return arguments
- }
-
- var decoded any
- if err := json.Unmarshal(raw, &decoded); err != nil {
- log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err)
- arguments["raw"] = string(raw)
- return arguments
- }
-
- switch v := decoded.(type) {
- case string:
- if strings.TrimSpace(v) == "" {
- return arguments
- }
- if err := json.Unmarshal([]byte(v), &arguments); err != nil {
- log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
- arguments["raw"] = v
- }
- return arguments
- case map[string]any:
- return v
- default:
- log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded)
- arguments["raw"] = string(raw)
- return arguments
- }
-}
-
-// openaiMessage is the wire-format message for OpenAI-compatible APIs.
-// It mirrors protocoltypes.Message but omits SystemParts, which is an
-// internal field that would be unknown to third-party endpoints.
-type openaiMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
- ReasoningContent string `json:"reasoning_content,omitempty"`
- ToolCalls []ToolCall `json:"tool_calls,omitempty"`
- ToolCallID string `json:"tool_call_id,omitempty"`
-}
-
-// serializeMessages converts internal Message structs to the OpenAI wire format.
-// - Strips SystemParts (unknown to third-party endpoints)
-// - Converts messages with Media to multipart content format (text + image_url parts)
-// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
-func serializeMessages(messages []Message) []any {
- out := make([]any, 0, len(messages))
- for _, m := range messages {
- if len(m.Media) == 0 {
- out = append(out, openaiMessage{
- Role: m.Role,
- Content: m.Content,
- ReasoningContent: m.ReasoningContent,
- ToolCalls: m.ToolCalls,
- ToolCallID: m.ToolCallID,
- })
- continue
- }
-
- // Multipart content format for messages with media
- parts := make([]map[string]any, 0, 1+len(m.Media))
- if m.Content != "" {
- parts = append(parts, map[string]any{
- "type": "text",
- "text": m.Content,
- })
- }
- for _, mediaURL := range m.Media {
- if strings.HasPrefix(mediaURL, "data:image/") {
- parts = append(parts, map[string]any{
- "type": "image_url",
- "image_url": map[string]any{
- "url": mediaURL,
- },
- })
- }
- }
-
- msg := map[string]any{
- "role": m.Role,
- "content": parts,
- }
- if m.ToolCallID != "" {
- msg["tool_call_id"] = m.ToolCallID
- }
- if len(m.ToolCalls) > 0 {
- msg["tool_calls"] = m.ToolCalls
- }
- if m.ReasoningContent != "" {
- msg["reasoning_content"] = m.ReasoningContent
- }
- out = append(out, msg)
- }
- return out
+ return common.ReadAndParseResponse(resp, p.apiBase)
}
func normalizeModel(model, apiBase string) string {
@@ -476,36 +195,6 @@ func normalizeModel(model, apiBase string) string {
}
}
-func asInt(v any) (int, bool) {
- switch val := v.(type) {
- case int:
- return val, true
- case int64:
- return int(val), true
- case float64:
- return int(val), true
- case float32:
- return int(val), true
- default:
- return 0, false
- }
-}
-
-func asFloat(v any) (float64, bool) {
- switch val := v.(type) {
- case float64:
- return val, true
- case float32:
- return float64(val), true
- case int:
- return float64(val), true
- case int64:
- return float64(val), true
- default:
- return 0, false
- }
-}
-
// supportsPromptCacheKey reports whether the given API base is known to
// support the prompt_cache_key request field. Currently only OpenAI's own
// API and Azure OpenAI support this. All other OpenAI-compatible providers
diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go
index 41f278a1b1..ed9747f9d7 100644
--- a/pkg/providers/openai_compat/provider_test.go
+++ b/pkg/providers/openai_compat/provider_test.go
@@ -12,6 +12,7 @@ import (
"testing"
"time"
+ "github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -648,7 +649,7 @@ func TestSerializeMessages_PlainText(t *testing.T) {
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
}
- result := serializeMessages(messages)
+ result := common.SerializeMessages(messages)
data, err := json.Marshal(result)
if err != nil {
@@ -670,7 +671,7 @@ func TestSerializeMessages_WithMedia(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
}
- result := serializeMessages(messages)
+ result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
@@ -703,7 +704,7 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
}
- result := serializeMessages(messages)
+ result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
@@ -833,7 +834,7 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) {
},
},
}
- result := serializeMessages(messages)
+ result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
raw := string(data)