diff --git a/.env.example b/.env.example index 55c3adb52a..bd7beda1c9 100644 --- a/.env.example +++ b/.env.example @@ -4,7 +4,7 @@ DATABASE_POOL_SIZE=10 # LLM Provider # LLM_BACKEND=nearai # default -# Possible values: nearai, ollama, openai_compatible, openai, anthropic, tinfoil +# Possible values: nearai, ollama, openai_compatible, openai, anthropic, github_copilot, tinfoil # LLM_REQUEST_TIMEOUT_SECS=120 # Increase for local LLMs (Ollama, vLLM, LM Studio) # === Anthropic Direct === @@ -24,6 +24,17 @@ DATABASE_POOL_SIZE=10 # LLM_USE_CODEX_AUTH=true # CODEX_AUTH_PATH=~/.codex/auth.json +# === GitHub Copilot === +# Uses the OAuth token from your Copilot IDE sign-in (for example +# ~/.config/github-copilot/apps.json on Linux/macOS), or run `ironclaw onboard` +# and choose the GitHub device login flow. +# LLM_BACKEND=github_copilot +# GITHUB_COPILOT_TOKEN=gho_... +# GITHUB_COPILOT_MODEL=gpt-4o +# IronClaw injects standard VS Code Copilot headers automatically. +# Optional advanced headers for custom overrides: +# GITHUB_COPILOT_EXTRA_HEADERS=Copilot-Integration-Id:vscode-chat + # === NEAR AI (Chat Completions API) === # Two auth modes: # 1. Session token (default): Uses browser OAuth (GitHub/Google) on first run. 
diff --git a/FEATURE_PARITY.md b/FEATURE_PARITY.md index db4ab92a4c..4422957f54 100644 --- a/FEATURE_PARITY.md +++ b/FEATURE_PARITY.md @@ -242,6 +242,7 @@ This document tracks feature parity between IronClaw (Rust implementation) and O | OpenRouter | ✅ | ✅ | - | Via OpenAI-compatible provider (RigAdapter) | | Tinfoil | ❌ | ✅ | - | Private inference provider (IronClaw-only) | | OpenAI-compatible | ❌ | ✅ | - | Generic OpenAI-compatible endpoint (RigAdapter) | +| GitHub Copilot | ✅ | ✅ | - | Dedicated provider with OAuth token exchange (`GithubCopilotProvider`) | | Ollama (local) | ✅ | ✅ | - | via `rig::providers::ollama` (full support) | | Perplexity | ✅ | ❌ | P3 | Freshness parameter for web_search | | MiniMax | ✅ | ❌ | P3 | Regional endpoint selection | diff --git a/README.md b/README.md index 9684ee4de6..8550a023af 100644 --- a/README.md +++ b/README.md @@ -168,7 +168,7 @@ written to `~/.ironclaw/.env` so they are available before the database connects IronClaw defaults to NEAR AI but supports many LLM providers out of the box. Built-in providers include **Anthropic**, **OpenAI**, **Google Gemini**, **MiniMax**, -**Mistral**, and **Ollama** (local). OpenAI-compatible services like **OpenRouter** +**Mistral**, **GitHub Copilot**, and **Ollama** (local). OpenAI-compatible services like **OpenRouter** (300+ models), **Together AI**, **Fireworks AI**, and self-hosted servers (**vLLM**, **LiteLLM**) are also supported. 
diff --git a/README.zh-CN.md b/README.zh-CN.md index 3402382227..cc5c4865fd 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -164,7 +164,7 @@ ironclaw onboard ### 替代 LLM 提供商 IronClaw 默认使用 NEAR AI,但开箱即用地支持多种 LLM 提供商。 -内置提供商包括 **Anthropic**、**OpenAI**、**Google Gemini**、**MiniMax**、**Mistral** 和 **Ollama**(本地部署)。同时也支持 OpenAI 兼容服务,如 **OpenRouter**(300+ 模型)、**Together AI**、**Fireworks AI** 以及自托管服务器(**vLLM**、**LiteLLM**)。 +内置提供商包括 **Anthropic**、**OpenAI**、**Google Gemini**、**MiniMax**、**Mistral**、**GitHub Copilot** 和 **Ollama**(本地部署)。同时也支持 OpenAI 兼容服务,如 **OpenRouter**(300+ 模型)、**Together AI**、**Fireworks AI** 以及自托管服务器(**vLLM**、**LiteLLM**)。 在向导中选择你的提供商,或直接设置环境变量: diff --git a/docs/LLM_PROVIDERS.md b/docs/LLM_PROVIDERS.md index a581a56b35..aa83ffdb66 100644 --- a/docs/LLM_PROVIDERS.md +++ b/docs/LLM_PROVIDERS.md @@ -17,6 +17,7 @@ configurations. | Yandex AI Studio | `yandex` | `YANDEX_API_KEY` | YandexGPT models | | MiniMax | `minimax` | `MINIMAX_API_KEY` | MiniMax-M2.5 models | | Cloudflare Workers AI | `cloudflare` | `CLOUDFLARE_API_KEY` | Access to Workers AI | +| GitHub Copilot | `github_copilot` | `GITHUB_COPILOT_TOKEN` | Multiple models | | Ollama | `ollama` | No | Local inference | | AWS Bedrock | `bedrock` | AWS credentials | Native Converse API | | OpenRouter | `openai_compatible` | `LLM_API_KEY` | 300+ models | @@ -61,6 +62,34 @@ Popular models: `gpt-4o`, `gpt-4o-mini`, `o3-mini` --- +## GitHub Copilot + +GitHub Copilot exposes a chat endpoint at +`https://api.githubcopilot.com`. IronClaw uses that endpoint directly through the +built-in `github_copilot` provider. + +```env +LLM_BACKEND=github_copilot +GITHUB_COPILOT_TOKEN=gho_... +GITHUB_COPILOT_MODEL=gpt-4o +# Optional advanced headers if your setup needs them: +# GITHUB_COPILOT_EXTRA_HEADERS=Copilot-Integration-Id:vscode-chat +``` + +`ironclaw onboard` can acquire this token for you using GitHub device login. 
If you +already signed into Copilot through VS Code or a JetBrains IDE, you can also reuse +the `oauth_token` stored in `~/.config/github-copilot/apps.json`. If you prefer, +`LLM_BACKEND=github-copilot` also works as an alias. + +Popular models vary by subscription, but `gpt-4o` is a safe default. IronClaw keeps +model entry manual for this provider because GitHub Copilot model listing may require +extra integration headers on some clients. IronClaw automatically injects the standard +VS Code identity headers (`User-Agent`, `Editor-Version`, `Editor-Plugin-Version`, +`Copilot-Integration-Id`) and lets you override them with +`GITHUB_COPILOT_EXTRA_HEADERS`. + +--- + ## Ollama (local) Install Ollama from [ollama.com](https://ollama.com), pull a model, then: diff --git a/providers.json b/providers.json index 12723a6f84..8991ae666f 100644 --- a/providers.json +++ b/providers.json @@ -77,6 +77,29 @@ "can_list_models": false } }, + { + "id": "github_copilot", + "aliases": [ + "github-copilot", + "githubcopilot", + "copilot" + ], + "protocol": "github_copilot", + "default_base_url": "https://api.githubcopilot.com", + "api_key_env": "GITHUB_COPILOT_TOKEN", + "api_key_required": true, + "model_env": "GITHUB_COPILOT_MODEL", + "default_model": "gpt-4o", + "extra_headers_env": "GITHUB_COPILOT_EXTRA_HEADERS", + "description": "GitHub Copilot Chat API (OAuth token from IDE sign-in)", + "setup": { + "kind": "api_key", + "secret_name": "llm_github_copilot_token", + "key_url": "https://docs.github.com/en/copilot", + "display_name": "GitHub Copilot", + "can_list_models": false + } + }, { "id": "tinfoil", "aliases": [], diff --git a/src/config/llm.rs b/src/config/llm.rs index 4ad2439928..e0b68647f8 100644 --- a/src/config/llm.rs +++ b/src/config/llm.rs @@ -325,6 +325,14 @@ impl LlmConfig { } else { Vec::new() }; + let extra_headers = if canonical_id == "github_copilot" { + merge_extra_headers( + crate::llm::github_copilot_auth::default_headers(), + extra_headers, + ) + } else { + 
extra_headers + }; // Resolve OAuth token (Anthropic-specific: `claude login` flow). // Only check for OAuth token when the provider is actually Anthropic. @@ -409,6 +417,26 @@ fn parse_extra_headers(val: &str) -> Result, ConfigError> Ok(headers) } +fn merge_extra_headers( + defaults: Vec<(String, String)>, + overrides: Vec<(String, String)>, +) -> Vec<(String, String)> { + let mut merged = Vec::new(); + let mut positions = std::collections::HashMap::::new(); + + for (key, value) in defaults.into_iter().chain(overrides) { + let normalized = key.to_ascii_lowercase(); + if let Some(existing_index) = positions.get(&normalized).copied() { + merged[existing_index] = (key, value); + } else { + positions.insert(normalized, merged.len()); + merged.push((key, value)); + } + } + + merged +} + /// Get the default session file path (~/.ironclaw/session.json). pub fn default_session_path() -> PathBuf { ironclaw_base_dir().join("session.json") @@ -540,6 +568,29 @@ mod tests { ); } + #[test] + fn merge_extra_headers_prefers_overrides_case_insensitively() { + let merged = merge_extra_headers( + vec![ + ("User-Agent".to_string(), "default-agent".to_string()), + ("X-Test".to_string(), "default".to_string()), + ], + vec![ + ("user-agent".to_string(), "override-agent".to_string()), + ("X-Extra".to_string(), "present".to_string()), + ], + ); + + assert_eq!( + merged, + vec![ + ("user-agent".to_string(), "override-agent".to_string()), + ("X-Test".to_string(), "default".to_string()), + ("X-Extra".to_string(), "present".to_string()), + ] + ); + } + /// Clear all ollama-related env vars. fn clear_ollama_env() { // SAFETY: Only called under ENV_MUTEX in tests. @@ -692,6 +743,54 @@ mod tests { assert_eq!(provider.protocol, ProviderProtocol::OpenAiCompletions); } + #[test] + fn registry_provider_resolves_github_copilot_alias() { + let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); + // SAFETY: Under ENV_MUTEX. 
+ unsafe { + std::env::set_var("LLM_BACKEND", "github-copilot"); + std::env::set_var("GITHUB_COPILOT_TOKEN", "gho_test_token"); + std::env::set_var( + "GITHUB_COPILOT_EXTRA_HEADERS", + "Copilot-Integration-Id:custom-chat,X-Test:enabled", + ); + } + + let settings = Settings::default(); + + let cfg = LlmConfig::resolve(&settings).expect("resolve should succeed"); + assert_eq!(cfg.backend, "github_copilot"); + let provider = cfg.provider.expect("provider config should be present"); + assert_eq!(provider.provider_id, "github_copilot"); + assert_eq!(provider.base_url, "https://api.githubcopilot.com"); + assert_eq!(provider.model, "gpt-4o"); + assert!( + provider + .extra_headers + .iter() + .any(|(key, value)| { key == "Copilot-Integration-Id" && value == "custom-chat" }) + ); + assert!( + provider + .extra_headers + .iter() + .any(|(key, value)| key == "User-Agent" && value == "GitHubCopilotChat/0.26.7") + ); + assert!( + provider + .extra_headers + .iter() + .any(|(key, value)| key == "X-Test" && value == "enabled") + ); + + // SAFETY: Under ENV_MUTEX. 
+ unsafe { + std::env::remove_var("LLM_BACKEND"); + std::env::remove_var("GITHUB_COPILOT_TOKEN"); + std::env::remove_var("GITHUB_COPILOT_EXTRA_HEADERS"); + } + } + #[test] fn nearai_backend_has_no_registry_provider() { let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); diff --git a/src/llm/CLAUDE.md b/src/llm/CLAUDE.md index 38d6901058..9e8fc2e2ec 100644 --- a/src/llm/CLAUDE.md +++ b/src/llm/CLAUDE.md @@ -34,6 +34,7 @@ Set via `LLM_BACKEND` env var: | `nearai` (default) | NEAR AI Chat Completions | `NEARAI_SESSION_TOKEN` or `NEARAI_API_KEY` | | `openai` | OpenAI | `OPENAI_API_KEY` | | `anthropic` | Anthropic | `ANTHROPIC_API_KEY` | +| `github_copilot` | GitHub Copilot Chat API | `GITHUB_COPILOT_TOKEN`, `GITHUB_COPILOT_MODEL` | | `ollama` | Ollama local | `OLLAMA_BASE_URL` | | `openai_compatible` | Any OpenAI-compatible endpoint | `LLM_BASE_URL`, `LLM_API_KEY`, `LLM_MODEL` | | `tinfoil` | Tinfoil TEE inference | `TINFOIL_API_KEY`, `TINFOIL_MODEL` | @@ -56,6 +57,22 @@ Uses the native Converse API via `aws-sdk-bedrockruntime` (`bedrock.rs`). Requir - `BEDROCK_MODEL` — Required model ID (e.g., `anthropic.claude-opus-4-6-v1`) - `BEDROCK_CROSS_REGION` — Optional cross-region inference prefix (`us`, `eu`, `apac`, `global`) +## GitHub Copilot Provider Notes + +`github_copilot` is a registry-declared provider backed by a dedicated +`GithubCopilotProvider` (OpenAI Chat Completions format plus Copilot token exchange). It defaults to `https://api.githubcopilot.com` and expects a +GitHub Copilot OAuth token in `GITHUB_COPILOT_TOKEN` (for example the `oauth_token` +stored by your IDE sign-in flow in `~/.config/github-copilot/apps.json`). The setup +wizard also supports GitHub device login using the VS Code Copilot client ID and then +stores the resulting token in the encrypted secrets store. + +Manual model entry is used in the setup wizard (`can_list_models = false`) because +GitHub Copilot model discovery can require extra integration headers on some clients. 
+IronClaw injects the standard VS Code identity headers automatically: +`User-Agent`, `Editor-Version`, `Editor-Plugin-Version`, and +`Copilot-Integration-Id`. Advanced users can still override or append headers via +`GITHUB_COPILOT_EXTRA_HEADERS`. + ## NEAR AI Provider Gotchas **Dual auth modes:** diff --git a/src/llm/github_copilot.rs b/src/llm/github_copilot.rs new file mode 100644 index 0000000000..15decc78a5 --- /dev/null +++ b/src/llm/github_copilot.rs @@ -0,0 +1,655 @@ +//! GitHub Copilot provider (direct HTTP with token exchange). +//! +//! The GitHub Copilot API at `api.githubcopilot.com` speaks OpenAI Chat +//! Completions format but requires a two-step authentication flow: +//! 1. A long-lived GitHub OAuth token (from device login or IDE sign-in) +//! 2. A short-lived Copilot session token (exchanged via GitHub API) +//! +//! The standard OpenAI rig-core client sends `Authorization: Bearer ` +//! with the raw OAuth token, which gets rejected with "Authorization header +//! is badly formatted". This provider handles the token exchange transparently. + +use std::collections::HashSet; +use std::sync::Arc; + +use async_trait::async_trait; +use reqwest::Client; +use rust_decimal::Decimal; +use secrecy::ExposeSecret; +use serde::{Deserialize, Serialize}; + +use crate::llm::config::RegistryProviderConfig; +use crate::llm::costs; +use crate::llm::error::LlmError; +use crate::llm::github_copilot_auth::CopilotTokenManager; +use crate::llm::provider::{ + ChatMessage, CompletionRequest, CompletionResponse, ContentPart, FinishReason, LlmProvider, + Role, ToolCall, ToolCompletionRequest, ToolCompletionResponse, + strip_unsupported_completion_params, strip_unsupported_tool_params, +}; + +/// GitHub Copilot provider with automatic token exchange. 
+pub struct GithubCopilotProvider { + client: Client, + token_manager: Arc, + model: String, + base_url: String, + active_model: std::sync::RwLock, + extra_headers: Vec<(String, String)>, + /// Parameter names that this provider does not support. + unsupported_params: HashSet, +} + +impl GithubCopilotProvider { + pub fn new(config: &RegistryProviderConfig) -> Result { + let oauth_token = config + .api_key + .as_ref() + .map(|k| k.expose_secret().to_string()) + .ok_or_else(|| { + tracing::error!("No API key configured for github_copilot — check GITHUB_COPILOT_TOKEN env var or secrets store"); + LlmError::AuthFailed { + provider: "github_copilot".to_string(), + } + })?; + + let client = Client::builder() + .timeout(std::time::Duration::from_secs(120)) + .build() + .map_err(|e| LlmError::RequestFailed { + provider: "github_copilot".to_string(), + reason: format!("Failed to build HTTP client: {e}"), + })?; + + let token_manager = Arc::new(CopilotTokenManager::new(client.clone(), oauth_token)); + + let base_url = if config.base_url.is_empty() { + "https://api.githubcopilot.com".to_string() + } else { + config.base_url.clone() + }; + + let active_model = std::sync::RwLock::new(config.model.clone()); + let unsupported_params: HashSet = + config.unsupported_params.iter().cloned().collect(); + + Ok(Self { + client, + token_manager, + model: config.model.clone(), + base_url, + active_model, + extra_headers: config.extra_headers.clone(), + unsupported_params, + }) + } + + fn api_url(&self) -> String { + let base = self.base_url.trim_end_matches('/'); + format!("{base}/chat/completions") + } + + /// Strip unsupported fields from a `CompletionRequest` in place. + fn strip_unsupported_completion_params(&self, req: &mut CompletionRequest) { + strip_unsupported_completion_params(&self.unsupported_params, req); + } + + /// Strip unsupported fields from a `ToolCompletionRequest` in place. 
+ fn strip_unsupported_tool_params(&self, req: &mut ToolCompletionRequest) { + strip_unsupported_tool_params(&self.unsupported_params, req); + } + + async fn send_request Deserialize<'de>>( + &self, + body: &impl Serialize, + ) -> Result { + let url = self.api_url(); + let token = self.token_manager.get_token().await.map_err(|e| { + tracing::warn!(error = %e, "Copilot: token exchange failed"); + LlmError::AuthFailed { + provider: "github_copilot".to_string(), + } + })?; + + let mut request = self + .client + .post(&url) + .bearer_auth(token.expose_secret()) + .header("Content-Type", "application/json"); + + // Inject Copilot identity headers + for (key, value) in &self.extra_headers { + request = request.header(key.as_str(), value.as_str()); + } + + let response = request.json(body).send().await.map_err(|e| { + tracing::warn!(error = %e, "Copilot: HTTP request failed"); + LlmError::RequestFailed { + provider: "github_copilot".to_string(), + reason: e.to_string(), + } + })?; + + let status = response.status(); + + if !status.is_success() { + let retry_after = response + .headers() + .get("retry-after") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.parse::().ok()) + .map(std::time::Duration::from_secs); + + let response_text = response + .text() + .await + .unwrap_or_else(|e| format!("(failed to read error body: {e})")); + + tracing::warn!( + status = %status, + body = %crate::agent::truncate_for_preview(&response_text, 256), + "Copilot: API error response" + ); + + if status.as_u16() == 401 { + // Invalidate the cached session token so the next attempt + // performs a fresh token exchange. 
+ tracing::warn!( + "Copilot: 401 Unauthorized — invalidating cached session token for retry" + ); + self.token_manager.invalidate().await; + return Err(LlmError::AuthFailed { + provider: "github_copilot".to_string(), + }); + } + if status.as_u16() == 429 { + tracing::warn!(retry_after = ?retry_after, "Copilot: rate limited"); + return Err(LlmError::RateLimited { + provider: "github_copilot".to_string(), + retry_after, + }); + } + let truncated = crate::agent::truncate_for_preview(&response_text, 512); + return Err(LlmError::RequestFailed { + provider: "github_copilot".to_string(), + reason: format!("HTTP {status}: {truncated}"), + }); + } + + let response_text = response.text().await.map_err(|e| LlmError::RequestFailed { + provider: "github_copilot".to_string(), + reason: format!("Failed to read response body: {e}"), + })?; + + serde_json::from_str(&response_text).map_err(|e| { + let truncated = crate::agent::truncate_for_preview(&response_text, 512); + tracing::warn!( + error = %e, + body = %truncated, + "Copilot: failed to parse response JSON" + ); + LlmError::InvalidResponse { + provider: "github_copilot".to_string(), + reason: format!("JSON parse error: {e}. 
Raw: {truncated}"), + } + }) + } +} + +#[async_trait] +impl LlmProvider for GithubCopilotProvider { + async fn complete(&self, mut req: CompletionRequest) -> Result { + let model = req.model.take().unwrap_or_else(|| self.active_model_name()); + self.strip_unsupported_completion_params(&mut req); + let messages = convert_messages(req.messages); + + let request = OpenAiRequest { + model, + messages, + max_tokens: req.max_tokens, + temperature: req.temperature, + tools: None, + tool_choice: None, + }; + + let response: OpenAiResponse = self.send_request(&request).await?; + let choice = + response + .choices + .into_iter() + .next() + .ok_or_else(|| LlmError::InvalidResponse { + provider: "github_copilot".to_string(), + reason: "No choices in response".to_string(), + })?; + + let (content, _tool_calls) = extract_choice_content(&choice); + + let finish_reason = match choice.finish_reason.as_deref() { + Some("stop") => FinishReason::Stop, + Some("length") => FinishReason::Length, + Some("tool_calls") => FinishReason::ToolUse, + Some("content_filter") => FinishReason::ContentFilter, + _ => FinishReason::Unknown, + }; + + Ok(CompletionResponse { + content: content.unwrap_or_default(), + finish_reason, + input_tokens: response + .usage + .as_ref() + .map(|u| u.prompt_tokens) + .unwrap_or(0), + output_tokens: response + .usage + .as_ref() + .map(|u| u.completion_tokens) + .unwrap_or(0), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }) + } + + async fn complete_with_tools( + &self, + mut req: ToolCompletionRequest, + ) -> Result { + let model = req.model.take().unwrap_or_else(|| self.active_model_name()); + self.strip_unsupported_tool_params(&mut req); + let messages = convert_messages(req.messages); + + let tools: Vec = req + .tools + .into_iter() + .map(|t| OpenAiTool { + tool_type: "function".to_string(), + function: OpenAiFunction { + name: t.name, + description: t.description, + parameters: t.parameters, + }, + }) + .collect(); + + let tool_choice = 
req.tool_choice.map(|tc| match tc.as_str() { + "auto" | "required" | "none" => serde_json::Value::String(tc), + specific => serde_json::json!({ + "type": "function", + "function": {"name": specific} + }), + }); + + let request = OpenAiRequest { + model, + messages, + max_tokens: req.max_tokens, + temperature: req.temperature, + tools: if tools.is_empty() { None } else { Some(tools) }, + tool_choice, + }; + + let response: OpenAiResponse = self.send_request(&request).await?; + let choice = + response + .choices + .into_iter() + .next() + .ok_or_else(|| LlmError::InvalidResponse { + provider: "github_copilot".to_string(), + reason: "No choices in response".to_string(), + })?; + + let (content, tool_calls) = extract_choice_content(&choice); + + let finish_reason = match choice.finish_reason.as_deref() { + Some("stop") => FinishReason::Stop, + Some("length") => FinishReason::Length, + Some("tool_calls") => FinishReason::ToolUse, + Some("content_filter") => FinishReason::ContentFilter, + _ => { + if !tool_calls.is_empty() { + FinishReason::ToolUse + } else { + FinishReason::Unknown + } + } + }; + + Ok(ToolCompletionResponse { + content, + tool_calls, + finish_reason, + input_tokens: response + .usage + .as_ref() + .map(|u| u.prompt_tokens) + .unwrap_or(0), + output_tokens: response + .usage + .as_ref() + .map(|u| u.completion_tokens) + .unwrap_or(0), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }) + } + + fn model_name(&self) -> &str { + &self.model + } + + fn cost_per_token(&self) -> (Decimal, Decimal) { + let model = self.active_model_name(); + costs::model_cost(&model).unwrap_or_else(costs::default_cost) + } + + fn active_model_name(&self) -> String { + match self.active_model.read() { + Ok(guard) => guard.clone(), + Err(poisoned) => poisoned.into_inner().clone(), + } + } + + fn set_model(&self, model: &str) -> Result<(), LlmError> { + match self.active_model.write() { + Ok(mut guard) => { + *guard = model.to_string(); + } + Err(poisoned) => { + 
*poisoned.into_inner() = model.to_string(); + } + } + Ok(()) + } +} + +// --- OpenAI Chat Completions API types --- + +#[derive(Debug, Serialize)] +struct OpenAiRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, +} + +#[derive(Debug, Serialize)] +struct OpenAiMessage { + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, +} + +/// OpenAI content can be a plain string or an array of parts (for multimodal). +#[derive(Debug, Serialize)] +#[serde(untagged)] +enum OpenAiContent { + Text(String), + Parts(Vec), +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type")] +enum OpenAiContentPart { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image_url")] + ImageUrl { image_url: OpenAiImageUrl }, +} + +#[derive(Debug, Serialize)] +struct OpenAiImageUrl { + url: String, +} + +#[derive(Debug, Serialize)] +struct OpenAiToolCall { + id: String, + #[serde(rename = "type")] + call_type: String, + function: OpenAiToolCallFunction, +} + +#[derive(Debug, Serialize)] +struct OpenAiToolCallFunction { + name: String, + arguments: String, +} + +#[derive(Debug, Serialize)] +struct OpenAiTool { + #[serde(rename = "type")] + tool_type: String, + function: OpenAiFunction, +} + +#[derive(Debug, Serialize)] +struct OpenAiFunction { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Debug, Deserialize)] +struct OpenAiResponse { + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, 
Deserialize)] +struct OpenAiChoice { + message: OpenAiResponseMessage, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct OpenAiResponseMessage { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + +#[derive(Debug, Deserialize)] +struct OpenAiResponseToolCall { + id: String, + function: OpenAiResponseFunction, +} + +#[derive(Debug, Deserialize)] +struct OpenAiResponseFunction { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct OpenAiUsage { + #[serde(default)] + prompt_tokens: u32, + #[serde(default)] + completion_tokens: u32, +} + +/// Convert IronClaw messages to OpenAI Chat Completions format. +fn convert_messages(messages: Vec) -> Vec { + messages + .into_iter() + .map(|msg| match msg.role { + Role::System => OpenAiMessage { + role: "system".to_string(), + content: Some(OpenAiContent::Text(msg.content)), + tool_calls: None, + tool_call_id: None, + name: None, + }, + Role::User => { + let content = if msg.content_parts.is_empty() { + Some(OpenAiContent::Text(msg.content)) + } else { + let mut parts = vec![OpenAiContentPart::Text { text: msg.content }]; + for part in msg.content_parts { + if let ContentPart::ImageUrl { image_url } = part { + parts.push(OpenAiContentPart::ImageUrl { + image_url: OpenAiImageUrl { url: image_url.url }, + }); + } + } + Some(OpenAiContent::Parts(parts)) + }; + OpenAiMessage { + role: "user".to_string(), + content, + tool_calls: None, + tool_call_id: None, + name: None, + } + } + Role::Assistant => { + let tool_calls = msg.tool_calls.map(|calls| { + calls + .into_iter() + .map(|tc| OpenAiToolCall { + id: tc.id, + call_type: "function".to_string(), + function: OpenAiToolCallFunction { + name: tc.name, + arguments: tc.arguments.to_string(), + }, + }) + .collect() + }); + let content = if msg.content.is_empty() { + None + } else { + Some(OpenAiContent::Text(msg.content)) + }; + OpenAiMessage { + role: "assistant".to_string(), + content, 
+ tool_calls, + tool_call_id: None, + name: None, + } + } + Role::Tool => OpenAiMessage { + role: "tool".to_string(), + content: Some(OpenAiContent::Text(msg.content)), + tool_calls: None, + tool_call_id: msg.tool_call_id, + name: msg.name, + }, + }) + .collect() +} + +/// Extract text and tool calls from an OpenAI response choice. +fn extract_choice_content(choice: &OpenAiChoice) -> (Option, Vec) { + let content = choice.message.content.clone(); + let tool_calls = choice + .message + .tool_calls + .as_ref() + .map(|calls| { + calls + .iter() + .map(|tc| ToolCall { + id: tc.id.clone(), + name: tc.function.name.clone(), + arguments: serde_json::from_str(&tc.function.arguments) + .unwrap_or(serde_json::Value::Object(serde_json::Map::new())), + }) + .collect() + }) + .unwrap_or_default(); + + (content, tool_calls) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_convert_messages_basic() { + let messages = vec![ + ChatMessage::system("You are helpful."), + ChatMessage::user("Hello"), + ChatMessage::assistant("Hi there!"), + ]; + let converted = convert_messages(messages); + assert_eq!(converted.len(), 3); + assert_eq!(converted[0].role, "system"); + assert_eq!(converted[1].role, "user"); + assert_eq!(converted[2].role, "assistant"); + } + + #[test] + fn test_convert_messages_tool_calls() { + let tool_calls = vec![ToolCall { + id: "call_1".to_string(), + name: "search".to_string(), + arguments: serde_json::json!({"q": "test"}), + }]; + let messages = vec![ + ChatMessage::user("Search"), + ChatMessage::assistant_with_tool_calls(Some("Searching...".to_string()), tool_calls), + ChatMessage::tool_result("call_1", "search", "found it"), + ]; + let converted = convert_messages(messages); + assert_eq!(converted.len(), 3); + assert!(converted[1].tool_calls.is_some()); + assert_eq!(converted[2].role, "tool"); + assert_eq!(converted[2].tool_call_id, Some("call_1".to_string())); + } + + #[test] + fn test_extract_choice_text_only() { + let choice = OpenAiChoice 
{ + message: OpenAiResponseMessage { + content: Some("Hello!".to_string()), + tool_calls: None, + }, + finish_reason: Some("stop".to_string()), + }; + let (content, tool_calls) = extract_choice_content(&choice); + assert_eq!(content, Some("Hello!".to_string())); + assert!(tool_calls.is_empty()); + } + + #[test] + fn test_extract_choice_with_tool_calls() { + let choice = OpenAiChoice { + message: OpenAiResponseMessage { + content: Some("Let me search.".to_string()), + tool_calls: Some(vec![OpenAiResponseToolCall { + id: "call_1".to_string(), + function: OpenAiResponseFunction { + name: "search".to_string(), + arguments: r#"{"q":"test"}"#.to_string(), + }, + }]), + }, + finish_reason: Some("tool_calls".to_string()), + }; + let (content, tool_calls) = extract_choice_content(&choice); + assert_eq!(content, Some("Let me search.".to_string())); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].name, "search"); + assert_eq!(tool_calls[0].arguments["q"], "test"); + } +} diff --git a/src/llm/github_copilot_auth.rs b/src/llm/github_copilot_auth.rs new file mode 100644 index 0000000000..34182df5f1 --- /dev/null +++ b/src/llm/github_copilot_auth.rs @@ -0,0 +1,746 @@ +use std::time::Duration; + +use secrecy::{ExposeSecret, SecretString}; +use serde::Deserialize; +use tokio::sync::RwLock; + +// ─── Risk: hardcoded VS Code Copilot identity ─────────────────────────────── +// +// The client ID and editor identity headers below are extracted from the +// VS Code Copilot Chat extension. This is the *only* publicly documented +// way to access the Copilot completions API with a personal GitHub token. +// +// **Known risks:** +// • GitHub may rotate or revoke this client ID at any time, which would +// break authentication for all IronClaw users until the constant is +// updated and a new release is shipped. +// • Using another product's client ID may violate GitHub's Terms of +// Service. 
Maintainers should seek explicit guidance from GitHub +// before shipping this to a wide audience. +// • The editor version strings (`vscode/1.99.3`, `copilot-chat/0.26.7`) +// will become stale and could eventually be rejected by the API. +// +// **Mitigation:** If GitHub publishes an official Copilot API client ID or +// an OAuth app registration flow for third-party tools, migrate to it +// immediately. +// ───────────────────────────────────────────────────────────────────────────── +pub const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98"; +pub const GITHUB_COPILOT_SCOPE: &str = "read:user"; +pub const GITHUB_COPILOT_DEVICE_CODE_URL: &str = "https://github.com/login/device/code"; +pub const GITHUB_COPILOT_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token"; +pub const GITHUB_COPILOT_MODELS_URL: &str = "https://api.githubcopilot.com/models"; +pub const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token"; +pub const GITHUB_COPILOT_USER_AGENT: &str = "GitHubCopilotChat/0.26.7"; +pub const GITHUB_COPILOT_EDITOR_VERSION: &str = "vscode/1.99.3"; +pub const GITHUB_COPILOT_EDITOR_PLUGIN_VERSION: &str = "copilot-chat/0.26.7"; +pub const GITHUB_COPILOT_INTEGRATION_ID: &str = "vscode-chat"; + +/// Buffer before token expiry to trigger a refresh (5 minutes). 
+const TOKEN_REFRESH_BUFFER_SECS: u64 = 300; + +#[derive(Debug, Clone, Deserialize)] +pub struct DeviceCodeResponse { + pub device_code: String, + pub user_code: String, + pub verification_uri: String, + pub expires_in: u64, + #[serde(default = "default_poll_interval_secs")] + pub interval: u64, +} + +#[derive(Debug, Clone, Deserialize)] +struct AccessTokenResponse { + access_token: Option, + error: Option, + error_description: Option, +} + +#[derive(Debug, thiserror::Error)] +pub enum GithubCopilotAuthError { + #[error("failed to start device login: {0}")] + DeviceCodeRequest(String), + #[error("failed to poll device login: {0}")] + TokenPolling(String), + #[error("device login was denied")] + AccessDenied, + #[error("device login expired before authorization completed")] + Expired, + #[error("github copilot token validation failed: {0}")] + Validation(String), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DevicePollingStatus { + Pending, + SlowDown, + Authorized(String), +} + +pub fn default_headers() -> Vec<(String, String)> { + vec![ + ( + "User-Agent".to_string(), + GITHUB_COPILOT_USER_AGENT.to_string(), + ), + ( + "Editor-Version".to_string(), + GITHUB_COPILOT_EDITOR_VERSION.to_string(), + ), + ( + "Editor-Plugin-Version".to_string(), + GITHUB_COPILOT_EDITOR_PLUGIN_VERSION.to_string(), + ), + ( + "Copilot-Integration-Id".to_string(), + GITHUB_COPILOT_INTEGRATION_ID.to_string(), + ), + ] +} + +pub fn default_poll_interval_secs() -> u64 { + 5 +} + +pub async fn request_device_code( + client: &reqwest::Client, +) -> Result { + let response = client + .post(GITHUB_COPILOT_DEVICE_CODE_URL) + .header(reqwest::header::ACCEPT, "application/json") + .header(reqwest::header::USER_AGENT, GITHUB_COPILOT_USER_AGENT) + .form(&[ + ("client_id", GITHUB_COPILOT_CLIENT_ID), + ("scope", GITHUB_COPILOT_SCOPE), + ]) + .send() + .await + .map_err(|e| { + tracing::warn!( + error = %e, + is_timeout = e.is_timeout(), + is_connect = e.is_connect(), + url = 
%GITHUB_COPILOT_DEVICE_CODE_URL,
+                "Copilot: device code request failed"
+            );
+            GithubCopilotAuthError::DeviceCodeRequest(format_reqwest_error(&e))
+        })?;
+
+    if !response.status().is_success() {
+        let status = response.status();
+        let body = response.text().await.unwrap_or_default();
+        tracing::warn!(
+            status = %status,
+            body = %truncate_for_error(&body),
+            "Copilot: device code endpoint returned error"
+        );
+        return Err(GithubCopilotAuthError::DeviceCodeRequest(format!(
+            "HTTP {status}: {}",
+            truncate_for_error(&body)
+        )));
+    }
+
+    let device = response
+        .json::<DeviceCodeResponse>()
+        .await
+        .map_err(|e| GithubCopilotAuthError::DeviceCodeRequest(e.to_string()))?;
+
+    Ok(device)
+}
+
+pub async fn poll_for_access_token(
+    client: &reqwest::Client,
+    device_code: &str,
+) -> Result<DevicePollingStatus, GithubCopilotAuthError> {
+    let response = client
+        .post(GITHUB_COPILOT_ACCESS_TOKEN_URL)
+        .header(reqwest::header::ACCEPT, "application/json")
+        .header(reqwest::header::USER_AGENT, GITHUB_COPILOT_USER_AGENT)
+        .form(&[
+            ("client_id", GITHUB_COPILOT_CLIENT_ID),
+            ("device_code", device_code),
+            ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
+        ])
+        .send()
+        .await
+        .map_err(|e| {
+            tracing::warn!(
+                error = %e,
+                is_timeout = e.is_timeout(),
+                is_connect = e.is_connect(),
+                url = %GITHUB_COPILOT_ACCESS_TOKEN_URL,
+                "Copilot: poll request failed"
+            );
+            GithubCopilotAuthError::TokenPolling(format_reqwest_error(&e))
+        })?;
+
+    if !response.status().is_success() {
+        let status = response.status();
+        let body = response.text().await.unwrap_or_default();
+        tracing::warn!(
+            status = %status,
+            body = %truncate_for_error(&body),
+            "Copilot: poll endpoint returned error"
+        );
+        return Err(GithubCopilotAuthError::TokenPolling(format!(
+            "HTTP {status}: {}",
+            truncate_for_error(&body)
+        )));
+    }
+
+    let body = response
+        .json::<AccessTokenResponse>()
+        .await
+        .map_err(|e| GithubCopilotAuthError::TokenPolling(e.to_string()))?;
+
+    if let Some(token) = body.access_token {
+        return Ok(DevicePollingStatus::Authorized(token));
+    }
+
+    match body.error.as_deref() {
+        Some("authorization_pending") | None => Ok(DevicePollingStatus::Pending),
+        Some("slow_down") => {
+            tracing::debug!("Copilot: GitHub requested slow_down, increasing poll interval");
+            Ok(DevicePollingStatus::SlowDown)
+        }
+        Some("access_denied") => {
+            tracing::warn!("Copilot: device login was denied by user");
+            Err(GithubCopilotAuthError::AccessDenied)
+        }
+        Some("expired_token") => {
+            tracing::warn!("Copilot: device code expired before authorization");
+            Err(GithubCopilotAuthError::Expired)
+        }
+        Some(other) => {
+            let desc = body
+                .error_description
+                .filter(|description| !description.is_empty())
+                .unwrap_or_else(|| other.to_string());
+            tracing::warn!(error = %other, description = %desc, "Copilot: unexpected poll error");
+            Err(GithubCopilotAuthError::TokenPolling(desc))
+        }
+    }
+}
+
+/// Maximum consecutive transient poll failures before giving up.
+const MAX_POLL_FAILURES: u32 = 5;
+
+pub async fn wait_for_device_login(
+    client: &reqwest::Client,
+    device: &DeviceCodeResponse,
+) -> Result<String, GithubCopilotAuthError> {
+    let expires_at = std::time::Instant::now()
+        .checked_add(Duration::from_secs(device.expires_in))
+        .ok_or(GithubCopilotAuthError::Expired)?;
+    let mut poll_interval = device.interval.max(1);
+    let mut consecutive_failures: u32 = 0;
+
+    loop {
+        if std::time::Instant::now() >= expires_at {
+            tracing::warn!("Copilot: device login expired");
+            return Err(GithubCopilotAuthError::Expired);
+        }
+
+        tokio::time::sleep(Duration::from_secs(poll_interval)).await;
+
+        match poll_for_access_token(client, &device.device_code).await {
+            Ok(DevicePollingStatus::Pending) => {
+                consecutive_failures = 0;
+            }
+            Ok(DevicePollingStatus::SlowDown) => {
+                consecutive_failures = 0;
+                poll_interval = poll_interval.saturating_add(5);
+            }
+            Ok(DevicePollingStatus::Authorized(token)) => {
+                return Ok(token);
+            }
+            // Definitive failures — propagate immediately
+            Err(GithubCopilotAuthError::AccessDenied) => {
+                return
Err(GithubCopilotAuthError::AccessDenied); + } + Err(GithubCopilotAuthError::Expired) => { + return Err(GithubCopilotAuthError::Expired); + } + // Transient failures — retry with backoff + Err(e) => { + consecutive_failures += 1; + tracing::warn!( + error = %e, + attempt = consecutive_failures, + max = MAX_POLL_FAILURES, + "Copilot: transient poll failure, will retry" + ); + if consecutive_failures >= MAX_POLL_FAILURES { + tracing::error!( + error = %e, + "Copilot: too many consecutive poll failures, giving up" + ); + return Err(e); + } + // Back off on transient errors + poll_interval = (poll_interval + 2).min(30); + } + } + } +} + +/// Validate a GitHub OAuth token by performing the Copilot token exchange. +/// +/// This exchanges the raw OAuth token for a Copilot session token (proving the +/// token is valid and the user has Copilot access), then verifies the session +/// token works against the models endpoint. +pub async fn validate_token( + client: &reqwest::Client, + token: &str, +) -> Result<(), GithubCopilotAuthError> { + // Step 1: Exchange the OAuth token for a Copilot session token. + // This validates both that the OAuth token is valid and that the user + // has an active Copilot subscription. + let session = exchange_copilot_token(client, token).await?; + // Step 2: Verify the session token works against the models endpoint. 
+    let mut request = client
+        .get(GITHUB_COPILOT_MODELS_URL)
+        .bearer_auth(&session.token)
+        .timeout(Duration::from_secs(15));
+
+    for (key, value) in default_headers() {
+        request = request.header(&key, value);
+    }
+
+    let response = request.send().await.map_err(|e| {
+        tracing::warn!(
+            error = %e,
+            is_timeout = e.is_timeout(),
+            is_connect = e.is_connect(),
+            "Copilot: models endpoint request failed"
+        );
+        GithubCopilotAuthError::Validation(format_reqwest_error(&e))
+    })?;
+
+    if response.status().is_success() {
+        return Ok(());
+    }
+
+    let status = response.status();
+    let body = response.text().await.unwrap_or_default();
+    tracing::warn!(
+        status = %status,
+        body = %truncate_for_error(&body),
+        "Copilot: models endpoint returned error during validation"
+    );
+    Err(GithubCopilotAuthError::Validation(format!(
+        "HTTP {status}: {}",
+        truncate_for_error(&body)
+    )))
+}
+
+/// Response from the Copilot token exchange endpoint.
+///
+/// The `token` field is an HMAC-signed session token (not a JWT) used as
+/// `Authorization: Bearer <token>` for requests to `api.githubcopilot.com`.
+#[derive(Debug, Clone, Deserialize)]
+pub struct CopilotTokenResponse {
+    /// The Copilot session token (HMAC-signed, not a JWT).
+    pub token: String,
+    /// Unix timestamp (seconds) when this token expires.
+    pub expires_at: u64,
+}
+
+/// Exchange a GitHub OAuth token for a Copilot API session token.
+///
+/// Calls `GET https://api.github.com/copilot_internal/v2/token` with the
+/// GitHub OAuth token in `Authorization: token <oauth-token>` format.
+/// Returns a short-lived session token for `api.githubcopilot.com`.
+pub async fn exchange_copilot_token(
+    client: &reqwest::Client,
+    oauth_token: &str,
+) -> Result<CopilotTokenResponse, GithubCopilotAuthError> {
+    let mut request = client
+        .get(GITHUB_COPILOT_TOKEN_URL)
+        .header(reqwest::header::ACCEPT, "application/json")
+        // GitHub Copilot uses `token` auth scheme, not `Bearer`
+        .header(
+            reqwest::header::AUTHORIZATION,
+            format!("token {oauth_token}"),
+        )
+        .timeout(Duration::from_secs(15));
+
+    for (key, value) in default_headers() {
+        request = request.header(&key, value);
+    }
+
+    let response = request.send().await.map_err(|e| {
+        tracing::warn!(
+            error = %e,
+            is_timeout = e.is_timeout(),
+            is_connect = e.is_connect(),
+            "Copilot: token exchange HTTP request failed"
+        );
+        GithubCopilotAuthError::Validation(format_reqwest_error(&e))
+    })?;
+
+    if !response.status().is_success() {
+        let status = response.status();
+        let body = response.text().await.unwrap_or_default();
+        tracing::warn!(
+            status = %status,
+            body = %truncate_for_error(&body),
+            "Copilot: token exchange endpoint returned error"
+        );
+        return Err(GithubCopilotAuthError::Validation(format!(
+            "Copilot token exchange failed: HTTP {status}: {}",
+            truncate_for_error(&body)
+        )));
+    }
+
+    let token_response = response.json::<CopilotTokenResponse>().await.map_err(|e| {
+        tracing::warn!(error = %e, "Copilot: failed to parse token exchange response");
+        GithubCopilotAuthError::Validation(e.to_string())
+    })?;
+
+    Ok(token_response)
+}
+
+/// Manages a cached Copilot API session token with automatic refresh.
+///
+/// The GitHub Copilot API requires a two-step authentication:
+/// 1. A long-lived GitHub OAuth token (from device login or IDE sign-in)
+/// 2. A short-lived Copilot session token (exchanged via `/copilot_internal/v2/token`)
+///
+/// This manager caches the session token and refreshes it automatically
+/// before it expires (with a 5-minute buffer).
+pub struct CopilotTokenManager {
+    client: reqwest::Client,
+    oauth_token: SecretString,
+    cached: RwLock<Option<CachedCopilotToken>>,
+}
+
+#[derive(Clone)]
+struct CachedCopilotToken {
+    token: SecretString,
+    expires_at: u64,
+}
+
+fn unix_now() -> u64 {
+    std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .unwrap_or_default()
+        .as_secs()
+}
+
+impl CopilotTokenManager {
+    /// Create a new token manager with the given GitHub OAuth token.
+    pub fn new(client: reqwest::Client, oauth_token: String) -> Self {
+        Self {
+            client,
+            oauth_token: SecretString::from(oauth_token),
+            cached: RwLock::new(None),
+        }
+    }
+
+    /// Get a valid Copilot session token, refreshing if needed.
+    ///
+    /// Returns the cached token if it has more than 5 minutes remaining,
+    /// otherwise exchanges the OAuth token for a fresh session token.
+    pub async fn get_token(&self) -> Result<SecretString, GithubCopilotAuthError> {
+        // Fast path: check if cached token is still valid under read lock.
+        {
+            let guard = self.cached.read().await;
+            if let Some(ref cached) = *guard {
+                let now = unix_now();
+                if cached.expires_at > now + TOKEN_REFRESH_BUFFER_SECS {
+                    return Ok(cached.token.clone());
+                }
+                tracing::debug!(
+                    expires_at = cached.expires_at,
+                    now = now,
+                    "Copilot: cached session token expired or expiring soon, refreshing"
+                );
+            }
+        }
+
+        // Slow path: acquire write lock and re-check (another caller may have
+        // already refreshed while we waited for the lock).
+ let mut guard = self.cached.write().await; + if let Some(ref cached) = *guard { + let now = unix_now(); + if cached.expires_at > now + TOKEN_REFRESH_BUFFER_SECS { + return Ok(cached.token.clone()); + } + } + + let response = + exchange_copilot_token(&self.client, self.oauth_token.expose_secret()).await?; + let token = SecretString::from(response.token); + + let expires_at = response.expires_at; + *guard = Some(CachedCopilotToken { + token: token.clone(), + expires_at, + }); + + tracing::debug!( + expires_at = expires_at, + "Copilot session token refreshed" + ); + + Ok(token) + } + + /// Invalidate the cached session token. + /// + /// Called when the API returns 401, so the next `get_token()` call + /// will perform a fresh token exchange instead of reusing the stale token. + pub async fn invalidate(&self) { + let mut guard = self.cached.write().await; + *guard = None; + tracing::debug!("Copilot session token invalidated"); + } +} + +fn truncate_for_error(body: &str) -> String { + const LIMIT: usize = 200; + if body.len() <= LIMIT { + return body.to_string(); + } + + let mut end = LIMIT; + while end > 0 && !body.is_char_boundary(end) { + end -= 1; + } + format!("{}...", &body[..end]) +} + +/// Format a reqwest error with its full causal chain for debugging. +/// +/// `reqwest::Error::to_string()` often just says "error sending request" +/// without the underlying cause (timeout, DNS, TLS, connection refused). +/// This walks the `source()` chain to surface the real problem. 
+fn format_reqwest_error(e: &reqwest::Error) -> String { + use std::error::Error; + let mut msg = e.to_string(); + let mut source = e.source(); + while let Some(cause) = source { + msg.push_str(&format!(": {cause}")); + source = cause.source(); + } + msg +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_headers_include_required_identity_headers() { + let headers = default_headers(); + assert!(headers.iter().any(|(key, value)| { + key == "Copilot-Integration-Id" && value == GITHUB_COPILOT_INTEGRATION_ID + })); + assert!( + headers + .iter() + .any(|(key, value)| key == "Editor-Version" + && value == GITHUB_COPILOT_EDITOR_VERSION) + ); + assert!( + headers + .iter() + .any(|(key, value)| key == "User-Agent" && value == GITHUB_COPILOT_USER_AGENT) + ); + } + + #[test] + fn truncate_for_error_preserves_utf8_boundaries() { + let long = "日本語".repeat(100); + let truncated = truncate_for_error(&long); + assert!(truncated.ends_with("...")); + assert!(truncated.is_char_boundary(truncated.len() - 3)); + } + + #[test] + fn truncate_for_error_short_strings_unchanged() { + let short = "hello"; + assert_eq!(truncate_for_error(short), "hello"); + } + + // --- poll_for_access_token response parsing --- + + fn parse_access_token_body(json: &str) -> AccessTokenResponse { + serde_json::from_str(json).expect("valid JSON") + } + + #[test] + fn parse_authorization_pending_response() { + let body: AccessTokenResponse = + parse_access_token_body(r#"{"error": "authorization_pending"}"#); + assert!(body.access_token.is_none()); + assert_eq!(body.error.as_deref(), Some("authorization_pending")); + } + + #[test] + fn parse_slow_down_response() { + let body: AccessTokenResponse = parse_access_token_body(r#"{"error": "slow_down"}"#); + assert_eq!(body.error.as_deref(), Some("slow_down")); + } + + #[test] + fn parse_access_denied_response() { + let body: AccessTokenResponse = parse_access_token_body(r#"{"error": "access_denied"}"#); + assert_eq!(body.error.as_deref(), 
Some("access_denied")); + } + + #[test] + fn parse_expired_token_response() { + let body: AccessTokenResponse = parse_access_token_body(r#"{"error": "expired_token"}"#); + assert_eq!(body.error.as_deref(), Some("expired_token")); + } + + #[test] + fn parse_successful_token_response() { + let body: AccessTokenResponse = + parse_access_token_body(r#"{"access_token": "ghu_abc123"}"#); + assert_eq!(body.access_token.as_deref(), Some("ghu_abc123")); + assert!(body.error.is_none()); + } + + #[test] + fn parse_error_with_description() { + let body: AccessTokenResponse = parse_access_token_body( + r#"{"error": "bad_verification_code", "error_description": "The code has expired"}"#, + ); + assert_eq!(body.error.as_deref(), Some("bad_verification_code")); + assert_eq!( + body.error_description.as_deref(), + Some("The code has expired") + ); + } + + #[test] + fn parse_device_code_response_with_defaults() { + let json = r#"{ + "device_code": "dc_123", + "user_code": "ABCD-1234", + "verification_uri": "https://github.com/login/device", + "expires_in": 900 + }"#; + let resp: DeviceCodeResponse = serde_json::from_str(json).expect("valid JSON"); + assert_eq!(resp.device_code, "dc_123"); + assert_eq!(resp.user_code, "ABCD-1234"); + assert_eq!(resp.interval, 5); // default_poll_interval_secs + assert_eq!(resp.expires_in, 900); + } + + #[test] + fn parse_device_code_response_with_custom_interval() { + let json = r#"{ + "device_code": "dc_456", + "user_code": "EFGH-5678", + "verification_uri": "https://github.com/login/device", + "expires_in": 600, + "interval": 10 + }"#; + let resp: DeviceCodeResponse = serde_json::from_str(json).expect("valid JSON"); + assert_eq!(resp.interval, 10); + } + + // --- CopilotTokenManager --- + + #[tokio::test] + async fn token_manager_caches_token_and_returns_same_value() { + // Pre-populate the cache with a token that expires far in the future. 
+ let client = reqwest::Client::new(); + let manager = CopilotTokenManager::new(client, "unused_oauth".to_string()); + + let far_future = unix_now() + 3600; + { + let mut guard = manager.cached.write().await; + *guard = Some(CachedCopilotToken { + token: SecretString::from("cached_session_token".to_string()), + expires_at: far_future, + }); + } + + let token = manager.get_token().await.expect("should return cached"); + assert_eq!(token.expose_secret(), "cached_session_token"); + + // A second call should return the same cached token. + let token2 = manager.get_token().await.expect("should return cached"); + assert_eq!(token2.expose_secret(), "cached_session_token"); + } + + #[tokio::test] + async fn token_manager_invalidation_clears_cache() { + let client = reqwest::Client::new(); + let manager = CopilotTokenManager::new(client, "unused_oauth".to_string()); + + let far_future = unix_now() + 3600; + { + let mut guard = manager.cached.write().await; + *guard = Some(CachedCopilotToken { + token: SecretString::from("old_token".to_string()), + expires_at: far_future, + }); + } + + manager.invalidate().await; + + let guard = manager.cached.read().await; + assert!(guard.is_none(), "cache should be empty after invalidation"); + } + + #[tokio::test] + async fn token_manager_expired_token_triggers_refresh_path() { + let client = reqwest::Client::new(); + let manager = CopilotTokenManager::new(client, "unused_oauth".to_string()); + + // Set a token that is already expired (expires_at in the past). + { + let mut guard = manager.cached.write().await; + *guard = Some(CachedCopilotToken { + token: SecretString::from("stale_token".to_string()), + expires_at: 1, // way in the past + }); + } + + // get_token will try the slow path (token exchange) which will fail + // because we have no real server, but this proves the cached stale + // token is NOT returned. 
+ let result = manager.get_token().await; + assert!( + result.is_err(), + "expired cached token should trigger exchange, which fails without a server" + ); + } + + #[tokio::test] + async fn token_manager_within_buffer_triggers_refresh() { + let client = reqwest::Client::new(); + let manager = CopilotTokenManager::new(client, "unused_oauth".to_string()); + + // Set a token that expires within the refresh buffer window. + let expires_soon = unix_now() + TOKEN_REFRESH_BUFFER_SECS - 10; + { + let mut guard = manager.cached.write().await; + *guard = Some(CachedCopilotToken { + token: SecretString::from("expiring_soon".to_string()), + expires_at: expires_soon, + }); + } + + let result = manager.get_token().await; + assert!( + result.is_err(), + "token within buffer should trigger exchange" + ); + } + + // --- CopilotTokenResponse parsing --- + + #[test] + fn parse_copilot_token_response() { + let json = r#"{"token": "tid=abc;exp=999;sku=123;sig=xyz", "expires_at": 1700000000}"#; + let resp: CopilotTokenResponse = serde_json::from_str(json).expect("valid JSON"); + assert!(resp.token.starts_with("tid=")); + assert_eq!(resp.expires_at, 1700000000); + } +} diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 51309bf37d..3a75d41421 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -18,6 +18,8 @@ pub mod config; pub mod costs; pub mod error; pub mod failover; +mod github_copilot; +pub(crate) mod github_copilot_auth; mod nearai_chat; pub mod oauth_helpers; mod provider; @@ -153,6 +155,16 @@ fn create_registry_provider( ProviderProtocol::OpenAiCompletions => create_openai_compat_from_registry(config), ProviderProtocol::Anthropic => create_anthropic_from_registry(config), ProviderProtocol::Ollama => create_ollama_from_registry(config), + ProviderProtocol::GithubCopilot => { + let provider = github_copilot::GithubCopilotProvider::new(config)?; + tracing::debug!( + provider = %config.provider_id, + model = %config.model, + base_url = %config.base_url, + "Using GitHub Copilot 
provider (token exchange)"
+            );
+            Ok(Arc::new(provider))
+        }
     }
 }
diff --git a/src/llm/registry.rs b/src/llm/registry.rs
index a36e2479e7..9e2ee7f5a8 100644
--- a/src/llm/registry.rs
+++ b/src/llm/registry.rs
@@ -37,6 +37,8 @@ pub enum ProviderProtocol {
     Anthropic,
     /// Ollama API (OpenAI-ish, no API key required).
     Ollama,
+    /// GitHub Copilot API (OpenAI-compatible with token exchange).
+    GithubCopilot,
 }
 
 /// How the setup wizard should collect credentials for this provider.
diff --git a/src/settings.rs b/src/settings.rs
index 2a5b6bbd21..ed26218b92 100644
--- a/src/settings.rs
+++ b/src/settings.rs
@@ -47,7 +47,7 @@ pub struct Settings {
     pub secrets_master_key_hex: Option<String>,
 
     // === Step 3: Inference Provider ===
-    /// LLM backend: "nearai", "anthropic", "openai", "ollama", "openai_compatible", "tinfoil", "bedrock".
+    /// LLM backend: "nearai", "anthropic", "openai", "github_copilot", "ollama", "openai_compatible", "tinfoil", "bedrock".
     #[serde(default)]
     pub llm_backend: Option<String>,
 
diff --git a/src/setup/README.md b/src/setup/README.md
index 196b910d4f..d8652b1147 100644
--- a/src/setup/README.md
+++ b/src/setup/README.md
@@ -212,6 +212,7 @@ env-var mode or skipped secrets.
 | NEAR AI Cloud | API key | `llm_nearai_api_key` | `NEARAI_API_KEY` |
 | Anthropic | API key | `anthropic_api_key` | `ANTHROPIC_API_KEY` |
 | OpenAI | API key | `openai_api_key` | `OPENAI_API_KEY` |
+| GitHub Copilot | OAuth token | `llm_github_copilot_token` | `GITHUB_COPILOT_TOKEN` |
 | Ollama | None | - | - |
 | OpenRouter | API key | `llm_openrouter_api_key` | `OPENROUTER_API_KEY` |
 | OpenAI-compatible | Optional API key | `llm_compatible_api_key` | `LLM_API_KEY` |
@@ -234,6 +235,12 @@ with its own secret name and env var. It is **not** stored as `openai_compatible
 5.
Preserve `selected_model` on a same-backend re-run; clear it only when switching to a different backend
 
+**GitHub Copilot** (`setup_github_copilot`):
+- Offers **GitHub device login** (recommended) or manual token paste
+- Device login uses the VS Code Copilot OAuth client and stores the resulting token as `llm_github_copilot_token`
+- Validates the token against `https://api.githubcopilot.com/models` before saving
+- Injects `GITHUB_COPILOT_TOKEN` into the config overlay for immediate provider use
+
 **NEAR AI** (`setup_nearai`):
 - Calls `session_manager.ensure_authenticated()` which shows the auth menu:
   - Options 1-2 (GitHub/Google): browser OAuth → **NEAR AI Chat** mode
@@ -522,7 +529,7 @@ pub struct Settings {
     pub secrets_master_key_source: KeySource, // Keychain | Env | None
 
     // Step 3: Inference
-    pub llm_backend: Option<String>, // "nearai" | "anthropic" | "openai" | "ollama" | "openai_compatible" | "bedrock"
+    pub llm_backend: Option<String>, // "nearai" | "anthropic" | "openai" | "github_copilot" | "ollama" | "openai_compatible" | "bedrock"
     pub ollama_base_url: Option<String>,
     pub openai_compatible_base_url: Option<String>,
 
diff --git a/src/setup/wizard.rs b/src/setup/wizard.rs
index 9437d8279b..44dd17f406 100644
--- a/src/setup/wizard.rs
+++ b/src/setup/wizard.rs
@@ -3,7 +3,7 @@
 //! The wizard guides users through:
 //! 1. Database connection
 //! 2. Security (secrets master key)
-//! 3. Inference provider (NEAR AI, Anthropic, OpenAI, Ollama, OpenAI-compatible)
+//! 3. Inference provider (NEAR AI, Anthropic, OpenAI, GitHub Copilot, Ollama, OpenAI-compatible)
 //! 4. Model selection
 //! 5. Embeddings
 //! 6.
Channel configuration @@ -932,6 +932,10 @@ impl SetupWizard { return self.setup_anthropic().await; } + if provider_id == "github_copilot" { + return self.setup_github_copilot().await; + } + match setup { crate::llm::registry::SetupHint::ApiKey { secret_name, @@ -1073,6 +1077,71 @@ impl SetupWizard { } } + async fn setup_github_copilot(&mut self) -> Result<(), SetupError> { + self.setup_github_copilot_device_login().await + } + + async fn setup_github_copilot_device_login(&mut self) -> Result<(), SetupError> { + self.set_llm_backend_preserving_model("github_copilot"); + + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(15)) + .build() + .map_err(|e| SetupError::Auth(format!("Failed to create HTTP client: {e}")))?; + + let device = crate::llm::github_copilot_auth::request_device_code(&client) + .await + .map_err(|e| SetupError::Auth(e.to_string()))?; + + print_info("Authorize IronClaw with GitHub Copilot in your browser."); + print_info(&format!("Verification URL: {}", device.verification_uri)); + print_info(&format!("One-time code: {}", device.user_code)); + + if let Err(e) = open::that(&device.verification_uri) { + tracing::debug!( + url = %device.verification_uri, + error = %e, + "Failed to open GitHub Copilot device login URL" + ); + print_info("Open the URL above manually if your browser did not launch."); + } else { + print_info("Opened your browser to GitHub device login."); + } + + print_info("Waiting for GitHub authorization..."); + let token = crate::llm::github_copilot_auth::wait_for_device_login(&client, &device) + .await + .map_err(|e| SetupError::Auth(e.to_string()))?; + + self.save_github_copilot_token(&client, &token).await + } + + async fn save_github_copilot_token( + &mut self, + client: &reqwest::Client, + token: &str, + ) -> Result<(), SetupError> { + crate::llm::github_copilot_auth::validate_token(client, token) + .await + .map_err(|e| SetupError::Auth(e.to_string()))?; + + if let Ok(ctx) = 
self.init_secrets_context().await { + let key = SecretString::from(token.to_string()); + ctx.save_secret("llm_github_copilot_token", &key) + .await + .map_err(|e| SetupError::Config(format!("Failed to save GitHub token: {e}")))?; + print_success("GitHub Copilot token encrypted and saved"); + } else { + print_info("Secrets not available. Set GITHUB_COPILOT_TOKEN in your environment."); + } + + crate::config::inject_single_var("GITHUB_COPILOT_TOKEN", token); + self.llm_api_key = Some(SecretString::from(token.to_string())); + + print_success("GitHub Copilot configured"); + Ok(()) + } + /// Anthropic OAuth setup: extract token from `claude login` credentials. async fn setup_anthropic_oauth(&mut self) -> Result<(), SetupError> { self.set_llm_backend_preserving_model("anthropic"); @@ -3043,6 +3112,36 @@ mod tests { ); } + #[test] + fn test_github_copilot_setup_preserves_model_for_same_backend() { + let mut wizard = SetupWizard::new(); + wizard.settings.llm_backend = Some("github_copilot".to_string()); + wizard.settings.selected_model = Some("gpt-4o".to_string()); + + wizard.set_llm_backend_preserving_model("github_copilot"); + + assert_eq!(wizard.settings.selected_model.as_deref(), Some("gpt-4o")); + assert_eq!( + wizard.settings.llm_backend.as_deref(), + Some("github_copilot") + ); + } + + #[test] + fn test_github_copilot_setup_clears_stale_model_on_switch() { + let mut wizard = SetupWizard::new(); + wizard.settings.llm_backend = Some("openai".to_string()); + wizard.settings.selected_model = Some("gpt-5".to_string()); + + wizard.set_llm_backend_preserving_model("github_copilot"); + + assert!(wizard.settings.selected_model.is_none()); + assert_eq!( + wizard.settings.llm_backend.as_deref(), + Some("github_copilot") + ); + } + #[test] fn test_is_openai_chat_model_includes_gpt5_and_filters_non_chat_variants() { assert!(is_openai_chat_model("gpt-5")); diff --git a/tests/config_round_trip.rs b/tests/config_round_trip.rs index 8351ff74fc..d35bfe16f5 100644 --- 
a/tests/config_round_trip.rs +++ b/tests/config_round_trip.rs @@ -56,6 +56,7 @@ fn bootstrap_env_round_trips_llm_backend() { for backend in &[ "nearai", "anthropic", + "github_copilot", "ollama", "openai_compatible", "tinfoil",