diff --git a/README.md b/README.md index d627611e..da07c459 100644 --- a/README.md +++ b/README.md @@ -145,8 +145,15 @@ export LLM_API_KEY="your-api-key" # Optional export LLM_API_BASE="your-api-base-url" # if using a local model, e.g. Ollama, LMStudio export PERPLEXITY_API_KEY="your-api-key" # for search capabilities + +# Proxy Configuration (optional) +export STRIX_PROXY_ALL="socks5://proxy.example.com:1080" # Proxy for all traffic +export STRIX_PROXY_TOOLS="http://proxy.example.com:8080" # Proxy for tool traffic only +export STRIX_PROXY_LLM="https://proxy.example.com:8080" # Proxy for LLM traffic only ``` +**Proxy Support**: Strix supports both HTTP and SOCKS5 proxies for routing traffic through corporate networks, WAF allow-lists, or SSH tunnels. Configure separate proxies for tool traffic and LLM requests, or use `STRIX_PROXY_ALL` for unified routing. + [📚 View supported AI models](https://docs.litellm.ai/docs/providers) ### 🤖 Headless Mode diff --git a/pyproject.toml b/pyproject.toml index 974087a0..1c3bd0e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,9 @@ textual = "^4.0.0" xmltodict = "^0.13.0" pyte = "^0.8.1" requests = "^2.32.0" +requests-socks = "^2.0.0" # SOCKS proxy support for requests libtmux = "^0.46.2" +httpx-socks = "^0.9.1" # SOCKS proxy support for httpx [tool.poetry.group.dev.dependencies] # Type checking and static analysis diff --git a/strix/interface/main.py b/strix/interface/main.py index 063dc10d..1e411112 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -32,6 +32,7 @@ process_pull_line, validate_llm_response, ) +from strix.proxy_config import configure_global_proxies from strix.runtime.docker_runtime import STRIX_IMAGE from strix.telemetry.tracer import get_global_tracer @@ -68,6 +69,34 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915 if not os.getenv("PERPLEXITY_API_KEY"): missing_optional_vars.append("PERPLEXITY_API_KEY") + # Validate proxy configuration + try: + configure_global_proxies() + except ValueError as e: + error_text = Text() + error_text.append("❌ ", style="bold red") + error_text.append("INVALID PROXY CONFIGURATION", style="bold red") + error_text.append("\n\n", style="white") + error_text.append(str(e), style="white") + error_text.append("\n\nSupported proxy formats:\n", style="white") + error_text.append("• http://proxy.example.com:8080\n", style="dim white") + error_text.append("• https://proxy.example.com:8080\n", style="dim white") + error_text.append("• socks5://proxy.example.com:1080\n", style="dim white") + error_text.append("• socks5h://proxy.example.com:1080\n", style="dim white") + + panel = Panel( + error_text, + title="[bold red]🛡️ STRIX CONFIGURATION ERROR", + title_align="center", + border_style="red", + padding=(1, 2), + ) + + console.print("\n") + console.print(panel) + console.print() + sys.exit(1) + if missing_required_vars: error_text = Text() error_text.append("❌ ", style="bold red") @@ -123,6 +152,25 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915 style="white", ) + # Add proxy configuration documentation + proxy_configured = any([ + os.getenv("STRIX_PROXY_ALL"), + os.getenv("STRIX_PROXY_TOOLS"), + os.getenv("STRIX_PROXY_LLM") + ]) + + if proxy_configured or missing_optional_vars: + error_text.append("\nProxy configuration (optional):\n", style="white") + error_text.append("• ", style="white") + error_text.append("STRIX_PROXY_ALL", style="bold cyan") + error_text.append(" - Proxy for all traffic (tools and LLM)\n", style="white") + error_text.append("• ", style="white") + error_text.append("STRIX_PROXY_TOOLS", style="bold cyan") + error_text.append(" - Proxy for tool traffic only\n", style="white") + error_text.append("• ", style="white") + error_text.append("STRIX_PROXY_LLM", style="bold cyan") + error_text.append(" - Proxy for LLM traffic only\n", style="white") + error_text.append("\nExample setup:\n", style="white") error_text.append("export STRIX_LLM='openai/gpt-5'\n", style="dim white") @@ -147,6 +195,19 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915 "export PERPLEXITY_API_KEY='your-perplexity-key-here'\n", style="dim white" ) + # Add proxy examples if any proxy is configured + proxy_configured = any([ + os.getenv("STRIX_PROXY_ALL"), + os.getenv("STRIX_PROXY_TOOLS"), + os.getenv("STRIX_PROXY_LLM") + ]) + + if proxy_configured: + error_text.append("\nProxy examples:\n", style="white") + error_text.append("export STRIX_PROXY_ALL='socks5://proxy.example.com:1080'\n", style="dim white") + error_text.append("export STRIX_PROXY_TOOLS='http://proxy.example.com:8080'\n", style="dim white") + error_text.append("export STRIX_PROXY_LLM='https://llm-proxy.example.com:8080'\n", style="dim white") + panel = Panel( error_text, title="[bold red]🛡️ STRIX CONFIGURATION ERROR", diff --git a/strix/proxy_config.py b/strix/proxy_config.py new file mode 100644 index 00000000..eb312f69 --- /dev/null +++ b/strix/proxy_config.py @@ -0,0 +1,154 @@ +""" +Proxy configuration module for Strix. + +This module handles upstream proxy configuration for both tool traffic and LLM traffic. +Supports both SOCKS5 and HTTP proxies as requested in: +https://github.com/usestrix/strix/issues/19 +""" + +import os +from dataclasses import dataclass +from typing import Any +from urllib.parse import urlparse + + +@dataclass +class ProxyConfig: + """Configuration for upstream proxies.""" + + tools_proxy: str | None = None + llm_proxy: str | None = None + all_proxy: str | None = None + + def __post_init__(self) -> None: + """Validate proxy configurations.""" + for proxy_name, proxy_url in [ + ("STRIX_PROXY_TOOLS", self.tools_proxy), + ("STRIX_PROXY_LLM", self.llm_proxy), + ("STRIX_PROXY_ALL", self.all_proxy), + ]: + if proxy_url: + self._validate_proxy_url(proxy_url, proxy_name) + + def _validate_proxy_url(self, proxy_url: str, env_var_name: str) -> None: + """Validate proxy URL format.""" + try: + parsed = urlparse(proxy_url) + if parsed.scheme not in ["http", "https", "socks5", "socks5h"]: + raise ValueError( + f"Invalid proxy scheme in {env_var_name}: {parsed.scheme}. " + "Supported schemes: http, https, socks5, socks5h" + ) + if not parsed.hostname: + raise ValueError(f"Missing hostname in {env_var_name}: {proxy_url}") + if not parsed.port: + raise ValueError(f"Missing port in {env_var_name}: {proxy_url}") + except Exception as e: + raise ValueError(f"Invalid proxy URL in {env_var_name}: {proxy_url}") from e + + def get_tools_proxy(self) -> str | None: + """Get proxy configuration for tools traffic.""" + return self.tools_proxy or self.all_proxy + + def get_llm_proxy(self) -> str | None: + """Get proxy configuration for LLM traffic.""" + return self.llm_proxy or self.all_proxy + + def get_requests_proxies(self, proxy_type: str = "tools") -> dict[str, str] | None: + """ + Get proxy configuration in requests library format. + + Args: + proxy_type: Either 'tools' or 'llm' to determine which proxy to use. + + Returns: + Dictionary with 'http' and 'https' keys, or None if no proxy configured. + """ + proxy_url = self.get_tools_proxy() if proxy_type == "tools" else self.get_llm_proxy() + if not proxy_url: + return None + + return {"http": proxy_url, "https": proxy_url} + + def get_httpx_proxies(self, proxy_type: str = "tools") -> dict[str, str] | None: + """ + Get proxy configuration in httpx library format. + + Args: + proxy_type: Either 'tools' or 'llm' to determine which proxy to use. + + Returns: + Dictionary with protocol keys, or None if no proxy configured. + + Note: + For SOCKS proxies with httpx, we need to use httpx-socks library + and create AsyncProxyTransport instead of simple URL strings. + """ + proxy_url = self.get_tools_proxy() if proxy_type == "tools" else self.get_llm_proxy() + if not proxy_url: + return None + + # For httpx, we can return the same format as requests for HTTP proxies + # SOCKS proxies need special handling with httpx-socks + parsed = urlparse(proxy_url) + if parsed.scheme in ["socks5", "socks5h"]: + # We'll handle SOCKS in the calling code using httpx-socks + return {"_socks_proxy": proxy_url} + else: + # HTTP/HTTPS proxies work the same as requests + return {"http://": proxy_url, "https://": proxy_url} + + def get_litellm_proxy_env(self) -> dict[str, str]: + """ + Get environment variables for litellm proxy configuration. + + Returns: + Dictionary of environment variables to set for litellm. + """ + env_vars = {} + llm_proxy = self.get_llm_proxy() + + if llm_proxy: + # litellm supports standard proxy environment variables + env_vars["HTTP_PROXY"] = llm_proxy + env_vars["HTTPS_PROXY"] = llm_proxy + + return env_vars + + +def load_proxy_config() -> ProxyConfig: + """Load proxy configuration from environment variables.""" + return ProxyConfig( + tools_proxy=os.getenv("STRIX_PROXY_TOOLS"), + llm_proxy=os.getenv("STRIX_PROXY_LLM"), + all_proxy=os.getenv("STRIX_PROXY_ALL"), + ) + + +def configure_global_proxies() -> ProxyConfig: + """ + Configure global proxy settings and return the configuration. + + This function should be called early in the application startup + to ensure proxy settings are applied globally. + """ + config = load_proxy_config() + + # Set environment variables for litellm if LLM proxy is configured + llm_proxy_env = config.get_litellm_proxy_env() + for key, value in llm_proxy_env.items(): + os.environ[key] = value + + return config + + +# Global proxy configuration instance +_global_proxy_config: ProxyConfig | None = None + + +def get_proxy_config() -> ProxyConfig: + """Get the global proxy configuration instance.""" + global _global_proxy_config # noqa: PLW0603 + if _global_proxy_config is None: + _global_proxy_config = configure_global_proxies() + return _global_proxy_config \ No newline at end of file diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index 32cc6252..c32f34ad 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -11,6 +11,8 @@ from docker.errors import DockerException, ImageNotFound, NotFound from docker.models.containers import Container +from strix.proxy_config import get_proxy_config + from .runtime import AbstractRuntime, SandboxInfo @@ -340,18 +342,40 @@ async def _register_agent_with_tool_server( ) -> None: import httpx - try: - async with httpx.AsyncClient(trust_env=False) as client: - response = await client.post( - f"{api_url}/register_agent", - params={"agent_id": agent_id}, - headers={"Authorization": f"Bearer {token}"}, - timeout=30, - ) - response.raise_for_status() - logger.info(f"Registered agent {agent_id} with tool server") - except (httpx.RequestError, httpx.HTTPStatusError) as e: - logger.warning(f"Failed to register agent {agent_id}: {e}") + proxy_config = get_proxy_config() + proxies = proxy_config.get_httpx_proxies("tools") + + # Handle SOCKS proxies with httpx-socks + if proxies and "_socks_proxy" in proxies: + from httpx_socks import AsyncProxyTransport + + socks_url = proxies["_socks_proxy"] + transport = AsyncProxyTransport.from_url(socks_url) + try: + async with httpx.AsyncClient(transport=transport, trust_env=False) as client: + response = await client.post( + f"{api_url}/register_agent", + params={"agent_id": agent_id}, + headers={"Authorization": f"Bearer {token}"}, + timeout=30, + ) + response.raise_for_status() + logger.info(f"Registered agent {agent_id} with tool server") + except (httpx.RequestError, httpx.HTTPStatusError) as e: + logger.warning(f"Failed to register agent {agent_id}: {e}") + else: + try: + async with httpx.AsyncClient(trust_env=False, proxies=proxies) as client: + response = await client.post( + f"{api_url}/register_agent", + params={"agent_id": agent_id}, + headers={"Authorization": f"Bearer {token}"}, + timeout=30, + ) + response.raise_for_status() + logger.info(f"Registered agent {agent_id} with tool server") + except (httpx.RequestError, httpx.HTTPStatusError) as e: + logger.warning(f"Failed to register agent {agent_id}: {e}") async def get_sandbox_url(self, container_id: str, port: int) -> str: try: diff --git a/strix/tools/executor.py b/strix/tools/executor.py index 6dd1b04d..e6e300dc 100644 --- a/strix/tools/executor.py +++ b/strix/tools/executor.py @@ -4,6 +4,8 @@ import httpx +from strix.proxy_config import get_proxy_config + if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false": from strix.runtime import get_runtime @@ -62,22 +64,50 @@ async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: A "Content-Type": "application/json", } - async with httpx.AsyncClient(trust_env=False) as client: - try: - response = await client.post( - request_url, json=request_data, headers=headers, timeout=None - ) - response.raise_for_status() - response_data = response.json() - if response_data.get("error"): - raise RuntimeError(f"Sandbox execution error: {response_data['error']}") - return response_data.get("result") - except httpx.HTTPStatusError as e: - if e.response.status_code == 401: - raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e - raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e - except httpx.RequestError as e: - raise RuntimeError(f"Request error calling tool server: {e}") from e + proxy_config = get_proxy_config() + proxies = proxy_config.get_httpx_proxies("tools") + + # Handle SOCKS proxies with httpx-socks + if proxies and "_socks_proxy" in proxies: + from httpx_socks import AsyncProxyTransport + from urllib.parse import urlparse + + socks_url = proxies["_socks_proxy"] + parsed = urlparse(socks_url) + transport = AsyncProxyTransport.from_url(socks_url) + async with httpx.AsyncClient(transport=transport, trust_env=False) as client: + try: + response = await client.post( + request_url, json=request_data, headers=headers, timeout=None + ) + response.raise_for_status() + response_data = response.json() + if response_data.get("error"): + raise RuntimeError(f"Sandbox execution error: {response_data['error']}") + return response_data.get("result") + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e + raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e + except httpx.RequestError as e: + raise RuntimeError(f"Request error calling tool server: {e}") from e + else: + async with httpx.AsyncClient(trust_env=False, proxies=proxies) as client: + try: + response = await client.post( + request_url, json=request_data, headers=headers, timeout=None + ) + response.raise_for_status() + response_data = response.json() + if response_data.get("error"): + raise RuntimeError(f"Sandbox execution error: {response_data['error']}") + return response_data.get("result") + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e + raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e + except httpx.RequestError as e: + raise RuntimeError(f"Request error calling tool server: {e}") from e async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwargs: Any) -> Any: diff --git a/strix/tools/proxy/proxy_manager.py b/strix/tools/proxy/proxy_manager.py index e02d85b7..1ca15b23 100644 --- a/strix/tools/proxy/proxy_manager.py +++ b/strix/tools/proxy/proxy_manager.py @@ -11,6 +11,8 @@ from gql.transport.requests import RequestsHTTPTransport from requests.exceptions import ProxyError, RequestException, Timeout +from strix.proxy_config import get_proxy_config + if TYPE_CHECKING: from collections.abc import Callable diff --git a/strix/tools/web_search/web_search_actions.py b/strix/tools/web_search/web_search_actions.py index 52f00a97..bc224641 100644 --- a/strix/tools/web_search/web_search_actions.py +++ b/strix/tools/web_search/web_search_actions.py @@ -3,6 +3,7 @@ import requests +from strix.proxy_config import get_proxy_config from strix.tools.registry import register_tool @@ -53,7 +54,13 @@ def web_search(query: str) -> dict[str, Any]: ], } - response = requests.post(url, headers=headers, json=payload, timeout=300) + response = requests.post( + url, + headers=headers, + json=payload, + timeout=300, + proxies=get_proxy_config().get_requests_proxies("tools") + ) response.raise_for_status() response_data = response.json()