diff --git a/src/any_llm/gateway/config.py b/src/any_llm/gateway/config.py index 1ab7546c..8948ed9c 100644 --- a/src/any_llm/gateway/config.py +++ b/src/any_llm/gateway/config.py @@ -1,10 +1,11 @@ +import json import os from pathlib import Path from typing import Any import yaml -from pydantic import BaseModel, Field -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic import BaseModel, Field, field_validator +from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict API_KEY_HEADER = "X-AnyLLM-Key" @@ -17,12 +18,22 @@ class PricingConfig(BaseModel): class GatewayConfig(BaseSettings): - """Gateway configuration with support for YAML files and environment variables.""" + """Gateway configuration with support for YAML files and environment variables. + + All configuration parameters can be set via environment variables with GATEWAY_ prefix: + - Simple values: GATEWAY_HOST, GATEWAY_PORT, GATEWAY_DATABASE_URL, etc. + - Boolean values: GATEWAY_AUTO_MIGRATE=true/false + - Complex structures (JSON): GATEWAY_PROVIDERS='{"openai": {"api_key": "sk-..."}}' + - Complex structures (JSON): GATEWAY_PRICING='{"openai:gpt-4": {"input_price_per_million": 30, "output_price_per_million": 60}}' + + Environment variables take precedence over YAML config values. 
+ """ model_config = SettingsConfigDict( env_prefix="GATEWAY_", env_file=".env", case_sensitive=False, + env_nested_delimiter="__", ) database_url: str = Field( @@ -44,25 +55,93 @@ class GatewayConfig(BaseSettings): description="Pre-configured model USD pricing (model_key -> {input_price_per_million, output_price_per_million})", ) + @field_validator("providers", mode="before") + @classmethod + def parse_providers(cls, v: Any) -> dict[str, dict[str, Any]]: + """Parse providers from a JSON string, pass a dict through as-is, and replace any other value with {}.""" + if isinstance(v, str): + try: + parsed = json.loads(v) + except json.JSONDecodeError as e: + msg = f"Invalid JSON in GATEWAY_PROVIDERS: {e}" + raise ValueError(msg) from e + else: + if not isinstance(parsed, dict): + msg = "GATEWAY_PROVIDERS must be a JSON object" + raise ValueError(msg) + return parsed + return v if isinstance(v, dict) else {} + + @field_validator("pricing", mode="before") + @classmethod + def parse_pricing(cls, v: Any) -> dict[str, dict[str, float]]: + """Parse pricing from a JSON string, pass a dict through as-is, and replace any other value with {}.""" + if isinstance(v, str): + try: + parsed = json.loads(v) + except json.JSONDecodeError as e: + msg = f"Invalid JSON in GATEWAY_PRICING: {e}" + raise ValueError(msg) from e + else: + if not isinstance(parsed, dict): + msg = "GATEWAY_PRICING must be a JSON object" + raise ValueError(msg) + return parsed + return v if isinstance(v, dict) else {} + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + """Customize settings source precedence. + + Order (highest to lowest priority): + 1. Environment variables + 2. Init settings (from YAML config file) + 3. .env file + 4. 
Secrets directory + """ + return env_settings, init_settings, dotenv_settings, file_secret_settings + def load_config(config_path: str | None = None) -> GatewayConfig: """Load configuration from file and environment variables. + Environment variables take precedence over YAML config values. + All config parameters support GATEWAY_ prefixed env vars. + Args: config_path: Optional path to YAML config file Returns: GatewayConfig instance with merged configuration + Example: + # Using environment variables only (no config file needed): + export GATEWAY_HOST="0.0.0.0" + export GATEWAY_PORT=8000 + export GATEWAY_DATABASE_URL="postgresql://..." + export GATEWAY_MASTER_KEY="your-secret-key" + export GATEWAY_PROVIDERS='{"openai": {"api_key": "sk-..."}}' + export GATEWAY_PRICING='{"openai:gpt-4": {"input_price_per_million": 30, "output_price_per_million": 60}}' + """ config_dict: dict[str, Any] = {} + # Load from YAML file if provided if config_path and Path(config_path).exists(): with open(config_path, encoding="utf-8") as f: yaml_config = yaml.safe_load(f) if yaml_config: config_dict = _resolve_env_vars(yaml_config) + # GatewayConfig (BaseSettings) will automatically load environment variables + # and they will take precedence over the config_dict values return GatewayConfig(**config_dict) diff --git a/tests/unit/test_gateway_config.py b/tests/unit/test_gateway_config.py new file mode 100644 index 00000000..80c607ad --- /dev/null +++ b/tests/unit/test_gateway_config.py @@ -0,0 +1,242 @@ +"""Tests for gateway configuration with environment variables.""" + +import json +import os +from tempfile import NamedTemporaryFile + +import pytest + +from any_llm.gateway.config import GatewayConfig, load_config + + +class TestGatewayConfigEnvVars: + """Test that all config parameters can be set via environment variables.""" + + def test_simple_string_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test simple string configuration via environment variables.""" + 
monkeypatch.setenv("GATEWAY_HOST", "127.0.0.1") + monkeypatch.setenv("GATEWAY_DATABASE_URL", "postgresql://test:test@localhost/test") + monkeypatch.setenv("GATEWAY_MASTER_KEY", "test-master-key") + + config = GatewayConfig() + + assert config.host == "127.0.0.1" + assert config.database_url == "postgresql://test:test@localhost/test" + assert config.master_key == "test-master-key" + + def test_integer_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test integer configuration via environment variables.""" + monkeypatch.setenv("GATEWAY_PORT", "9000") + + config = GatewayConfig() + + assert config.port == 9000 + + def test_boolean_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test boolean configuration via environment variables.""" + monkeypatch.setenv("GATEWAY_AUTO_MIGRATE", "false") + + config = GatewayConfig() + + assert config.auto_migrate is False + + monkeypatch.setenv("GATEWAY_AUTO_MIGRATE", "true") + config = GatewayConfig() + + assert config.auto_migrate is True + + def test_providers_json_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test providers configuration via JSON environment variable.""" + providers_json = json.dumps({ + "openai": { + "api_key": "sk-test-key", + "api_base": "https://api.openai.com/v1", + }, + "anthropic": { + "api_key": "sk-ant-test", + }, + }) + monkeypatch.setenv("GATEWAY_PROVIDERS", providers_json) + + config = GatewayConfig() + + assert "openai" in config.providers + assert config.providers["openai"]["api_key"] == "sk-test-key" + assert config.providers["openai"]["api_base"] == "https://api.openai.com/v1" + assert "anthropic" in config.providers + assert config.providers["anthropic"]["api_key"] == "sk-ant-test" + + def test_pricing_json_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test pricing configuration via JSON environment variable.""" + pricing_json = json.dumps({ + "openai:gpt-4": { + "input_price_per_million": 30.0, + "output_price_per_million": 60.0, + }, + 
"anthropic:claude-3-opus": { + "input_price_per_million": 15.0, + "output_price_per_million": 75.0, + }, + }) + monkeypatch.setenv("GATEWAY_PRICING", pricing_json) + + config = GatewayConfig() + + assert "openai:gpt-4" in config.pricing + assert config.pricing["openai:gpt-4"].input_price_per_million == 30.0 + assert config.pricing["openai:gpt-4"].output_price_per_million == 60.0 + assert "anthropic:claude-3-opus" in config.pricing + assert config.pricing["anthropic:claude-3-opus"].input_price_per_million == 15.0 + assert config.pricing["anthropic:claude-3-opus"].output_price_per_million == 75.0 + + def test_invalid_providers_json_raises_error(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that invalid JSON in GATEWAY_PROVIDERS raises an error.""" + from pydantic_settings import SettingsError + + monkeypatch.setenv("GATEWAY_PROVIDERS", "{invalid json") + + with pytest.raises(SettingsError, match='error parsing value for field "providers"'): + GatewayConfig() + + def test_invalid_pricing_json_raises_error(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that invalid JSON in GATEWAY_PRICING raises an error.""" + from pydantic_settings import SettingsError + + monkeypatch.setenv("GATEWAY_PRICING", "{invalid json") + + with pytest.raises(SettingsError, match='error parsing value for field "pricing"'): + GatewayConfig() + + def test_providers_as_list_becomes_empty_dict(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that array JSON in GATEWAY_PROVIDERS is converted to empty dict.""" + monkeypatch.setenv("GATEWAY_PROVIDERS", '["array", "not", "object"]') + + config = GatewayConfig() + + # Pydantic-settings will parse the JSON array, but our validator will return empty dict + assert config.providers == {} + + def test_pricing_as_list_becomes_empty_dict(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that array JSON in GATEWAY_PRICING is converted to empty dict.""" + monkeypatch.setenv("GATEWAY_PRICING", '["array", "not", "object"]') 
+ + config = GatewayConfig() + + # Pydantic-settings will parse the JSON array, but our validator will return empty dict + assert config.pricing == {} + + def test_all_config_params_via_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that all configuration parameters can be set via environment variables.""" + monkeypatch.setenv("GATEWAY_DATABASE_URL", "postgresql://env:env@localhost/env") + monkeypatch.setenv("GATEWAY_AUTO_MIGRATE", "false") + monkeypatch.setenv("GATEWAY_HOST", "192.168.1.1") + monkeypatch.setenv("GATEWAY_PORT", "7000") + monkeypatch.setenv("GATEWAY_MASTER_KEY", "env-master-key") + monkeypatch.setenv("GATEWAY_PROVIDERS", '{"test": {"key": "value"}}') + monkeypatch.setenv( + "GATEWAY_PRICING", '{"test:model": {"input_price_per_million": 1.0, "output_price_per_million": 2.0}}' + ) + + config = GatewayConfig() + + assert config.database_url == "postgresql://env:env@localhost/env" + assert config.auto_migrate is False + assert config.host == "192.168.1.1" + assert config.port == 7000 + assert config.master_key == "env-master-key" + assert config.providers == {"test": {"key": "value"}} + assert "test:model" in config.pricing + assert config.pricing["test:model"].input_price_per_million == 1.0 + + +class TestLoadConfigPrecedence: + """Test that environment variables take precedence over YAML config.""" + + def test_env_vars_override_yaml_config(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that environment variables override YAML configuration.""" + yaml_content = """ +database_url: "postgresql://yaml:yaml@localhost/yaml" +host: "0.0.0.0" +port: 8000 +master_key: "yaml-master-key" +providers: + openai: + api_key: "yaml-openai-key" +pricing: + openai:gpt-4: + input_price_per_million: 10.0 + output_price_per_million: 20.0 +""" + + with NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(yaml_content) + config_path = f.name + + try: + # Set environment variables that should override YAML + 
monkeypatch.setenv("GATEWAY_DATABASE_URL", "postgresql://env:env@localhost/env") + monkeypatch.setenv("GATEWAY_PORT", "9000") + monkeypatch.setenv("GATEWAY_MASTER_KEY", "env-master-key") + monkeypatch.setenv("GATEWAY_PROVIDERS", '{"anthropic": {"api_key": "env-anthropic-key"}}') + + config = load_config(config_path) + + # Environment variables should take precedence + assert config.database_url == "postgresql://env:env@localhost/env" + assert config.port == 9000 + assert config.master_key == "env-master-key" + assert "anthropic" in config.providers + assert config.providers["anthropic"]["api_key"] == "env-anthropic-key" + + # YAML value should be used when no env var is set + assert config.host == "0.0.0.0" # noqa: S104 + finally: + os.unlink(config_path) + + def test_load_config_without_file(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that load_config works without a config file (env vars only).""" + monkeypatch.setenv("GATEWAY_HOST", "10.0.0.1") + monkeypatch.setenv("GATEWAY_PORT", "5000") + monkeypatch.setenv("GATEWAY_DATABASE_URL", "postgresql://test:test@localhost/test") + monkeypatch.setenv("GATEWAY_MASTER_KEY", "test-key") + + config = load_config(None) + + assert config.host == "10.0.0.1" + assert config.port == 5000 + assert config.database_url == "postgresql://test:test@localhost/test" + assert config.master_key == "test-key" + + def test_defaults_used_when_no_config_or_env(self) -> None: + """Test that default values are used when no config file or env vars are set.""" + config = GatewayConfig() + + assert config.host == "0.0.0.0" # noqa: S104 + assert config.port == 8000 + assert config.database_url == "postgresql://postgres:postgres@localhost:5432/any_llm_gateway" + assert config.auto_migrate is True + assert config.master_key is None + assert config.providers == {} + assert config.pricing == {} + + def test_yaml_env_var_substitution_still_works(self) -> None: + """Test that ${VAR} substitution in YAML still works.""" + yaml_content = """ 
+database_url: "postgresql://postgres:postgres@localhost/db" +master_key: "${GATEWAY_MASTER_KEY}" +""" + + with NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(yaml_content) + config_path = f.name + + try: + os.environ["GATEWAY_MASTER_KEY"] = "substituted-key" + + config = load_config(config_path) + + assert config.master_key == "substituted-key" + finally: + os.unlink(config_path) + if "GATEWAY_MASTER_KEY" in os.environ: + del os.environ["GATEWAY_MASTER_KEY"]