diff --git a/.gitignore b/.gitignore
index 18bf11a..3cccf15 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,5 +7,7 @@ venv
storage
dist
docs
-.aico
-!.aico/project.json
\ No newline at end of file
+.aico/*
+!.aico/project.json
+.coverage
+coverage.xml
diff --git a/README.md b/README.md
index b64fcde..2263fbf 100644
--- a/README.md
+++ b/README.md
@@ -34,10 +34,11 @@ It works as a drop-in replacement for OpenAI's API, allowing you to switch betwe
- [Basic Group Definition](#basic-group-definition)
- [Group-based Access Control](#group-based-access-control)
- [Connection Restrictions](#connection-restrictions)
- - [Custom API Key Validation](#custom-api-key-validation)
+ - [Virtual API Key Validation](#virtual-api-key-validation)
- [Advanced Usage](#%EF%B8%8F-advanced-usage)
- [Dynamic Model Routing](#dynamic-model-routing)
- [Load Balancing Example](#load-balancing-example)
+- [Debugging](#-debugging)
- [Contributing](#-contributing)
- [License](#-license)
@@ -138,7 +139,7 @@ port = 8000 # Port to listen on
dev_autoreload = false # Enable for development
# API key validation function (optional)
-check_api_key = "lm_proxy.core.check_api_key"
+api_key_check = "lm_proxy.api_key_check.check_api_key_in_config"
# LLM Provider Connections
[connections]
@@ -204,7 +205,7 @@ api_key = "env:OPENAI_API_KEY"
At runtime, LM-Proxy automatically retrieves the value of the target variable
(OPENAI_API_KEY) from your operating system’s environment or from a .env file, if present.
-### .env Files
+### .env Files
By default, LM-Proxy looks for a `.env` file in the current working directory
and loads environment variables from it.
@@ -410,7 +411,37 @@ This allows fine-grained control over which users can access which AI providers,
- Implementing usage quotas per group
- Billing and cost allocation by user group
-### Custom API Key Validation
+### Virtual API Key Validation
+
+#### Overview
+
+LM-Proxy includes two built-in methods for validating Virtual API keys:
+ - `lm_proxy.api_key_check.check_api_key_in_config` - verifies API keys against those defined in the config file; used by default
+ - `lm_proxy.api_key_check.CheckAPIKeyWithRequest` - validates API keys via an external HTTP service
+
+The API key check method can be configured using the `api_key_check` configuration key.
+Its value can be either a reference to a Python function in the format `my_module.sub_module1.sub_module2.fn_name`,
+or an object containing parameters for a class-based validator.
+
+When using a Python (`.py`) config, the validator can be passed directly as a callable, as shown in the sketch below.
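+
+The following is a minimal, illustrative sketch (the connection details, model name,
+and `my_api_key_check` validator are placeholders, not part of the distribution):
+
+```python
+import os
+
+from lm_proxy.config import Config, Group
+
+
+def my_api_key_check(api_key: str) -> str | None:
+    # Hypothetical validator: accept a single hard-coded key for the "default" group.
+    return "default" if api_key == "my-secret-key" else None
+
+
+config = Config(
+    host="127.0.0.1",
+    port=8000,
+    api_key_check=my_api_key_check,  # the validator is passed directly as a callable
+    connections={
+        "openai": {
+            "api_type": "open_ai",
+            "api_base": "https://api.openai.com/v1/",
+            "api_key": os.environ.get("OPENAI_API_KEY", ""),
+        },
+    },
+    routing={"*": "openai.gpt-4o-mini"},
+    groups={"default": Group(connections="*")},
+)
+```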
+
+#### Example: External API Key Validation via Keycloak / OpenID Connect
+
+This example shows how to validate API keys against an external service (e.g., Keycloak):
+
+```toml
+[api_key_check]
+class = "lm_proxy.api_key_check.CheckAPIKeyWithRequest"
+method = "POST"
+url = "http://keycloak:8080/realms/master/protocol/openid-connect/userinfo"
+response_as_user_info = true # interpret response JSON as user info object for further processing / logging
+use_cache = true # requires the 'cachetools' package when enabled: pip install cachetools
+cache_ttl = 60 # Cache duration in seconds
+
+[api_key_check.headers]
+Authorization = "Bearer {api_key}"
+```
+#### Custom API Key Validation / Extending Functionality
For more advanced authentication needs,
you can implement a custom validator function:
@@ -437,7 +468,7 @@ def validate_api_key(api_key: str) -> str | None:
Then reference it in your config:
```toml
-check_api_key = "my_validators.validate_api_key"
+api_key_check = "my_validators.validate_api_key"
```
> **NOTE**
> In this case, the `api_keys` lists in groups are ignored, and the custom function is responsible for all validation logic.
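+
+A validator may also return a `(group, user_info)` tuple instead of a plain group name;
+the dictionary is attached to the request context and made available to loggers.
+A minimal sketch (the in-memory token table is purely illustrative):
+
+```python
+# my_validators.py: hypothetical token table (in practice, a database or auth service)
+USERS_BY_TOKEN = {
+    "sk-alice": {"group": "team_a", "user_id": "alice"},
+    "sk-bob": {"group": "team_b", "user_id": "bob"},
+}
+
+
+def validate_api_key(api_key: str) -> tuple[str, dict] | None:
+    user = USERS_BY_TOKEN.get(api_key)
+    if user is None:
+        return None
+    # The second tuple element becomes `user_info` on the request context.
+    return user["group"], user
+```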
@@ -457,12 +488,39 @@ The routing section allows flexible pattern matching with wildcards:
"custom*" = "local.llama-7b" # Map any "custom*" to a specific local model
"*" = "openai.gpt-3.5-turbo" # Default fallback for unmatched models
```
+Keys are model name patterns (with `*` wildcard support), and values are connection/model mappings.
+Connection names reference those defined in the `[connections]` section.
+
### Load Balancing Example
- [Simple load-balancer configuration](https://github.com/Nayjest/lm-proxy/blob/main/examples/load_balancer_config.py)
This example demonstrates how to set up a load balancer that randomly
distributes requests across multiple language model servers using the lm_proxy.
+## 🔍 Debugging
+
+### Overview
+When **debugging mode** is enabled,
+LM-Proxy provides detailed logging information to help diagnose issues:
+- Stack traces for exceptions are shown in the console
+- Logging level is set to DEBUG instead of INFO
+
+> **Warning** ⚠️
+> Never enable debugging mode in production environments, as it may expose sensitive information in application logs.
+
+### Enabling Debugging Mode
+To enable debugging, set the `LM_PROXY_DEBUG` environment variable to a truthy value (e.g., "1", "true", "yes").
+> **Tip** 💡
+> Environment variables can also be defined in a `.env` file.
+
+Alternatively, you can enable or disable debugging via command-line arguments:
+- `--debug` to enable debugging
+- `--no-debug` to disable debugging
+
+> **Note** ℹ️
+> CLI arguments override environment variable settings.
+
+
## 🤝 Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
diff --git a/config.toml b/config.toml
index 87e1af3..e9707dc 100644
--- a/config.toml
+++ b/config.toml
@@ -1,11 +1,11 @@
-# This is a lm-proxy configuration example
+# This is an lm-proxy configuration example
host="0.0.0.0"
port=8000
# dev_autoreload=true
# Validates a Client API key against configured groups and returns the matching group.
-check_api_key="lm_proxy.core.check_api_key"
+api_key_check="lm_proxy.api_key_check.check_api_key_in_config"
model_listing_mode = "as_is"
@@ -40,7 +40,7 @@ api_keys = [
[[loggers]]
class = 'lm_proxy.loggers.BaseLogger'
[loggers.log_writer]
-class = 'lm_proxy.loggers.log_writers.JsonLogWriter'
+class = 'lm_proxy.loggers.JsonLogWriter'
file_name = 'storage/json.log'
[loggers.entry_transformer]
class = 'lm_proxy.loggers.LogEntryTransformer'
diff --git a/coverage.svg b/coverage.svg
index eb83f3f..d79e240 100644
--- a/coverage.svg
+++ b/coverage.svg
@@ -9,13 +9,13 @@
-
+
coverage
coverage
- 59%
- 59%
+ 65%
+ 65%
diff --git a/lm_proxy/__main__.py b/lm_proxy/__main__.py
index 9bbc6a3..9837177 100644
--- a/lm_proxy/__main__.py
+++ b/lm_proxy/__main__.py
@@ -1,3 +1,4 @@
+"""Provides the CLI entry point when the package is executed as a Python module."""
from .app import cli_app
diff --git a/lm_proxy/api_key_check/__init__.py b/lm_proxy/api_key_check/__init__.py
new file mode 100644
index 0000000..09c35b4
--- /dev/null
+++ b/lm_proxy/api_key_check/__init__.py
@@ -0,0 +1,6 @@
+"""Collection of built-in API-key checkers for usage in the configuration."""
+from .in_config import check_api_key_in_config
+from .with_request import CheckAPIKeyWithRequest
+
+
+__all__ = ["check_api_key_in_config", "CheckAPIKeyWithRequest"]
diff --git a/lm_proxy/api_key_check/in_config.py b/lm_proxy/api_key_check/in_config.py
new file mode 100644
index 0000000..d75a0f3
--- /dev/null
+++ b/lm_proxy/api_key_check/in_config.py
@@ -0,0 +1,18 @@
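+"""API-key check that validates Virtual API keys against groups defined in the configuration."""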
+from typing import Optional
+from ..bootstrap import env
+
+
+def check_api_key_in_config(api_key: Optional[str]) -> Optional[str]:
+ """
+ Validates a Client API key against configured groups and returns the matching group name.
+
+ Args:
+ api_key (Optional[str]): The Virtual / Client API key to validate.
+ Returns:
+ Optional[str]: The group name if the API key is valid and found in a group,
+ None otherwise.
+ """
+ for group_name, group in env.config.groups.items():
+ if api_key in group.api_keys:
+ return group_name
+ return None
diff --git a/lm_proxy/api_key_check/with_request.py b/lm_proxy/api_key_check/with_request.py
new file mode 100644
index 0000000..260f889
--- /dev/null
+++ b/lm_proxy/api_key_check/with_request.py
@@ -0,0 +1,64 @@
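+"""API-key check that validates Virtual API keys via an HTTP request to an external service."""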
+from typing import Optional
+from dataclasses import dataclass, field
+import requests
+
+from ..config import TApiKeyCheckFunc
+
+
+@dataclass(slots=True)
+class CheckAPIKeyWithRequest:
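+    """
+    Validates Virtual API keys by sending an HTTP request to an external service
+    (e.g. an OpenID Connect userinfo endpoint), optionally caching results.
+    """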
+ url: str = field()
+ method: str = field(default="get")
+ headers: dict = field(default_factory=dict)
+ response_as_user_info: bool = field(default=False)
+ group_field: Optional[str] = field(default=None)
+ default_group: str = field(default="default")
+ key_placeholder: str = field(default="{api_key}")
+ use_cache: bool = field(default=False)
+ cache_size: int = field(default=1024 * 16)
+ cache_ttl: int = field(default=60 * 5) # 5 minutes
+ timeout: int = field(default=5) # seconds
+ _func: TApiKeyCheckFunc = field(init=False, repr=False)
+
+ def __post_init__(self):
+ def check_func(api_key: str) -> Optional[tuple[str, dict]]:
+ try:
+ url = self.url.replace(self.key_placeholder, api_key)
+ headers = {
+ k: str(v).replace(self.key_placeholder, api_key)
+ for k, v in self.headers.items()
+ }
+ response = requests.request(
+ method=self.method,
+ url=url,
+ headers=headers,
+ timeout=self.timeout
+ )
+ response.raise_for_status()
+ group = self.default_group
+ user_info = None
+ if self.response_as_user_info:
+ user_info = response.json()
+ if self.group_field:
+ group = user_info.get(self.group_field, self.default_group)
+ return group, user_info
+ except requests.exceptions.RequestException:
+ return None
+
+ if self.use_cache:
+ try:
+ import cachetools # pylint: disable=import-outside-toplevel
+ except ImportError as e:
+ raise ImportError(
+ "Missing optional dependency 'cachetools'. "
+ "Using 'lm_proxy.api_key_check.CheckAPIKeyWithRequest' with 'use_cache = true' "
+                    "requires installing the 'cachetools' package. "
+                    "\nPlease install it with the following command: 'pip install cachetools'"
+ ) from e
+ cache = cachetools.TTLCache(maxsize=self.cache_size, ttl=self.cache_ttl)
+ self._func = cachetools.cached(cache)(check_func)
+ else:
+ self._func = check_func
+
+ def __call__(self, api_key: str) -> Optional[tuple[str, dict]]:
+ return self._func(api_key)
diff --git a/lm_proxy/app.py b/lm_proxy/app.py
index 80628f8..67162cf 100644
--- a/lm_proxy/app.py
+++ b/lm_proxy/app.py
@@ -1,3 +1,6 @@
+"""
+LM-Proxy Application Entrypoint
+"""
import logging
from typing import Optional
from fastapi import FastAPI
@@ -6,12 +9,11 @@
from .bootstrap import env, bootstrap
from .core import chat_completions
-from .models import models
+from .models_endpoint import models
cli_app = typer.Typer()
-# run-server is a default command of cli-app
@cli_app.callback(invoke_without_command=True)
def run_server(
config: Optional[str] = typer.Option(None, help="Path to the configuration file"),
@@ -26,6 +28,9 @@ def run_server(
help="Set the .env file to load ENV vars from",
),
):
+ """
+    Default CLI command: runs the LM-Proxy web server.
+ """
try:
bootstrap(config=config or "config.toml", env_file=env_file, debug=debug)
uvicorn.run(
@@ -38,12 +43,14 @@ def run_server(
except Exception as e:
if env.debug:
raise
- else:
- logging.error(e)
- raise typer.Exit(code=1)
+ logging.error(e)
+ raise typer.Exit(code=1)
def web_app():
+ """
+ Entrypoint for ASGI server
+ """
app = FastAPI(
title="LM-Proxy", description="OpenAI-compatible proxy server for LLM inference"
)
diff --git a/lm_proxy/base_types.py b/lm_proxy/base_types.py
new file mode 100644
index 0000000..4a79745
--- /dev/null
+++ b/lm_proxy/base_types.py
@@ -0,0 +1,55 @@
+"""Base types used in LM-Proxy."""
+import uuid
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import List, Optional, TYPE_CHECKING
+
+import microcore as mc
+from pydantic import BaseModel
+
+if TYPE_CHECKING:
+ from .config import Group
+
+
+class ChatCompletionRequest(BaseModel):
+ """
+ Request model for chat/completions endpoint.
+ """
+ model: str
+ messages: List[mc.Msg]
+ stream: Optional[bool] = None
+ max_tokens: Optional[int] = None
+ temperature: Optional[float] = None
+ top_p: Optional[float] = None
+ n: Optional[int] = None
+ stop: Optional[List[str]] = None
+ presence_penalty: Optional[float] = None
+ frequency_penalty: Optional[float] = None
+ user: Optional[str] = None
+
+
+@dataclass
+class RequestContext:
+ """
+ Stores information about a single LLM request/response cycle for usage in middleware.
+ """
+ id: Optional[str] = field(default_factory=lambda: str(uuid.uuid4()))
+ request: Optional[ChatCompletionRequest] = field(default=None)
+ response: Optional[mc.LLMResponse] = field(default=None)
+ error: Optional[Exception] = field(default=None)
+ group: Optional["Group"] = field(default=None)
+ connection: Optional[str] = field(default=None)
+ model: Optional[str] = field(default=None)
+ api_key_id: Optional[str] = field(default=None)
+ remote_addr: Optional[str] = field(default=None)
+ created_at: Optional[datetime] = field(default_factory=datetime.now)
+ duration: Optional[float] = field(default=None)
+ user_info: Optional[dict] = field(default=None)
+ extra: dict = field(default_factory=dict)
+
+ def to_dict(self) -> dict:
+ """Export as dictionary."""
+ data = self.__dict__.copy()
+ if self.request:
+ data["request"] = self.request.model_dump(mode="json")
+ return data
diff --git a/lm_proxy/bootstrap.py b/lm_proxy/bootstrap.py
index 92a2f6d..a91bc9b 100644
--- a/lm_proxy/bootstrap.py
+++ b/lm_proxy/bootstrap.py
@@ -1,8 +1,9 @@
+"""Initialization and bootstrapping."""
import sys
import logging
import inspect
from datetime import datetime
-
+from typing import TYPE_CHECKING
import microcore as mc
from microcore import ui
@@ -10,9 +11,14 @@
from dotenv import load_dotenv
from .config import Config
+from .utils import resolve_instance_or_callable
+
+if TYPE_CHECKING:
+ from .loggers import TLogger
def setup_logging(log_level: int = logging.INFO):
+ """Setup logging format and level."""
class CustomFormatter(logging.Formatter):
def format(self, record):
dt = datetime.fromtimestamp(record.created).strftime("%H:%M:%S")
@@ -31,20 +37,33 @@ def format(self, record):
class Env:
+ """Runtime environment singleton."""
config: Config
connections: dict[str, mc.types.LLMAsyncFunctionType]
debug: bool
+ components: dict
+ loggers: list["TLogger"]
+
+ def _init_components(self):
+ self.components = dict()
+ for name, component_data in self.config.components.items():
+ self.components[name] = resolve_instance_or_callable(component_data)
+ logging.info(f"Loaded component '{name}'")
@staticmethod
def init(config: Config | str, debug: bool = False):
env.debug = debug
- if isinstance(config, Config):
- env.config = config
- elif isinstance(config, str):
- env.config = Config.load(config)
- else:
- raise ValueError("config must be a string (file path) or Config instance")
+ if not isinstance(config, Config):
+ if isinstance(config, str):
+ config = Config.load(config)
+ else:
+ raise ValueError("config must be a string (file path) or Config instance")
+ env.config = config
+
+ env._init_components()
+
+ env.loggers = [resolve_instance_or_callable(logger) for logger in env.config.loggers]
# initialize connections
env.connections = dict()
@@ -70,17 +89,19 @@ def init(config: Config | str, debug: bool = False):
def bootstrap(config: str | Config = "config.toml", env_file: str = ".env", debug=None):
+ """Bootstraps the LM-Proxy environment."""
+ def log_bootstrap():
+ cfg_val = 'dynamic' if isinstance(config, Config) else ui.blue(config)
+ cfg_line = f"\n - Config{ui.gray('......')}[ {cfg_val} ]"
+ env_line = f"\n - Env. File{ui.gray('...')}[ {ui.blue(env_file)} ]" if env_file else ""
+ dbg_line = f"\n - Debug{ui.gray('.......')}[ {ui.yellow('On')} ]" if debug else ""
+ logging.info(f"Bootstrapping {ui.magenta('LM-Proxy')}...{cfg_line}{env_line}{dbg_line}")
+
if env_file:
load_dotenv(env_file, override=True)
if debug is None:
debug = "--debug" in sys.argv or get_bool_from_env("LM_PROXY_DEBUG", False)
setup_logging(logging.DEBUG if debug else logging.INFO)
mc.logging.LoggingConfig.OUTPUT_METHOD = logging.info
- logging.info(
- f"Bootstrapping {ui.yellow('lm_proxy')}: "
- f"config_file={'dynamic' if isinstance(config, Config) else ui.blue(config)}"
- f"{' debug=on' if debug else ''}"
- f"{' env_file=' + ui.blue(env_file) if env_file else ''}"
- f"..."
- )
+ log_bootstrap()
Env.init(config, debug=debug)
diff --git a/lm_proxy/config.py b/lm_proxy/config.py
index 42e5b83..aae67aa 100644
--- a/lm_proxy/config.py
+++ b/lm_proxy/config.py
@@ -5,14 +5,13 @@
import os
from enum import StrEnum
-from typing import Union, Callable
-import tomllib
-import importlib.util
+from typing import Union, Callable, Dict, Optional
+from importlib.metadata import entry_points
from pydantic import BaseModel, Field, ConfigDict
-from microcore.utils import resolve_callable
-from .utils import resolve_instance_or_callable
+from .utils import resolve_instance_or_callable, replace_env_strings_recursive
+from .loggers import TLogger
class ModelListingMode(StrEnum):
@@ -41,10 +40,17 @@ def allows_connecting_to(self, connection_name: str) -> bool:
return connection_name in allowed
+TApiKeyCheckResult = Optional[Union[str, tuple[str, dict]]]
+TApiKeyCheckFunc = Callable[[str | None], TApiKeyCheckResult]
+
+
class Config(BaseModel):
"""Main configuration model matching config.toml structure."""
- model_config = ConfigDict(extra="forbid")
+ model_config = ConfigDict(
+ extra="forbid",
+ arbitrary_types_allowed=True,
+ )
enabled: bool = True
host: str = "0.0.0.0"
port: int = 8000
@@ -56,9 +62,12 @@ class Config(BaseModel):
)
routing: dict[str, str] = Field(default_factory=dict)
""" model_name_pattern* => connection_name.< model | * >, example: {"gpt-*": "oai.*"} """
- groups: dict[str, Group] = Field(default_factory=dict)
- check_api_key: Union[str, Callable] = Field(default="lm_proxy.core.check_api_key")
- loggers: list[Union[str, Callable, dict]] = Field(default_factory=list)
+ groups: dict[str, Group] = Field(default_factory=lambda: {"default": Group()})
+ api_key_check: Union[str, TApiKeyCheckFunc, dict] = Field(
+ default="lm_proxy.api_key_check.check_api_key_in_config",
+ description="Function to check Virtual API keys",
+ )
+ loggers: list[Union[str, dict, TLogger]] = Field(default_factory=list)
encryption_key: str = Field(
default="Eclipse",
description="Key for encrypting sensitive data (must be explicitly set)",
@@ -67,19 +76,30 @@ class Config(BaseModel):
default=ModelListingMode.AS_IS,
description="How to handle wildcard models in /v1/models endpoint",
)
+ components: dict[str, Union[str, Callable, dict]] = Field(default_factory=dict)
def __init__(self, **data):
super().__init__(**data)
- self.check_api_key = resolve_callable(self.check_api_key)
- self.loggers = [resolve_instance_or_callable(logger) for logger in self.loggers]
- if not self.groups:
- # Default group with no restrictions
- self.groups = {"default": Group()}
+ self.api_key_check = resolve_instance_or_callable(
+ self.api_key_check,
+            debug_name="api_key_check",
+ )
+
+ @staticmethod
+ def _load_raw(config_path: str = "config.toml") -> Union["Config", Dict]:
+ config_ext = os.path.splitext(config_path)[1].lower().lstrip(".")
+ for entry_point in entry_points(group="config.loaders"):
+ if config_ext == entry_point.name:
+ loader = entry_point.load()
+ config_data = loader(config_path)
+ return config_data
+
+ raise ValueError(f"No loader found for configuration file extension: {config_ext}")
@staticmethod
def load(config_path: str = "config.toml") -> "Config":
"""
- Load configuration from a TOML file.
+ Load configuration from a TOML or Python file.
Args:
config_path: Path to the config.toml file
@@ -87,22 +107,10 @@ def load(config_path: str = "config.toml") -> "Config":
Returns:
Config object with parsed configuration
"""
- if config_path.endswith(".py"):
- spec = importlib.util.spec_from_file_location("config_module", config_path)
- config_module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(config_module)
- return config_module.config
- elif config_path.endswith(".toml"):
- with open(config_path, "rb") as f:
- config_data = tomllib.load(f)
- else:
- raise ValueError(f"Unsupported configuration file extension: {config_path}")
-
- # Process environment variables in api_key fields
- for conn_name, conn_config in config_data.get("connections", {}).items():
- for key, value in conn_config.items():
- if isinstance(value, str) and value.startswith("env:"):
- env_var = value.split(":", 1)[1]
- conn_config[key] = os.environ.get(env_var, "")
-
- return Config(**config_data)
+ config = Config._load_raw(config_path)
+ if isinstance(config, dict):
+ config = replace_env_strings_recursive(config)
+ config = Config(**config)
+ elif not isinstance(config, Config):
+ raise TypeError("Loaded configuration must be a dict or Config instance")
+ return config
diff --git a/lm_proxy/config_loaders/__init__.py b/lm_proxy/config_loaders/__init__.py
new file mode 100644
index 0000000..ea2fdc8
--- /dev/null
+++ b/lm_proxy/config_loaders/__init__.py
@@ -0,0 +1,12 @@
+"""Built-in configuration loaders for different file formats."""
+from .python import load_python_config
+from .toml import load_toml_config
+from .yaml import load_yaml_config
+from .json import load_json_config
+
+__all__ = [
+ "load_python_config",
+ "load_toml_config",
+ "load_yaml_config",
+ "load_json_config",
+]
diff --git a/lm_proxy/config_loaders/json.py b/lm_proxy/config_loaders/json.py
new file mode 100644
index 0000000..70b27b6
--- /dev/null
+++ b/lm_proxy/config_loaders/json.py
@@ -0,0 +1,8 @@
+"""JSON configuration loader."""
+import json
+
+
+def load_json_config(config_path: str) -> dict:
+ """Loads configuration from a JSON file."""
+ with open(config_path, "r", encoding="utf-8") as f:
+ return json.load(f)
diff --git a/lm_proxy/config_loaders/python.py b/lm_proxy/config_loaders/python.py
new file mode 100644
index 0000000..3a92c0c
--- /dev/null
+++ b/lm_proxy/config_loaders/python.py
@@ -0,0 +1,11 @@
+"""Loader for Python configuration files."""
+import importlib.util
+from ..config import Config
+
+
+def load_python_config(config_path: str) -> Config:
+ """Load configuration from a Python file."""
+ spec = importlib.util.spec_from_file_location("config_module", config_path)
+ config_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(config_module)
+ return config_module.config
diff --git a/lm_proxy/config_loaders/toml.py b/lm_proxy/config_loaders/toml.py
new file mode 100644
index 0000000..d9202d4
--- /dev/null
+++ b/lm_proxy/config_loaders/toml.py
@@ -0,0 +1,8 @@
+"""TOML configuration loader."""
+import tomllib
+
+
+def load_toml_config(config_path: str) -> dict:
+ """Loads configuration from a TOML file."""
+ with open(config_path, "rb") as f:
+ return tomllib.load(f)
diff --git a/lm_proxy/config_loaders/yaml.py b/lm_proxy/config_loaders/yaml.py
new file mode 100644
index 0000000..4473b97
--- /dev/null
+++ b/lm_proxy/config_loaders/yaml.py
@@ -0,0 +1,16 @@
+"""YAML configuration loader."""
+
+
+def load_yaml_config(config_path: str) -> dict:
+ """Loads a YAML configuration file and returns its contents as a dictionary."""
+ try:
+ import yaml # pylint: disable=import-outside-toplevel
+ except ImportError as e:
+ raise ImportError(
+ "Missing optional dependency 'PyYAML'. "
+        "To use YAML configuration files with LM-Proxy, "
+ "please install it with the following command: 'pip install pyyaml'."
+ ) from e
+
+ with open(config_path, "r", encoding="utf-8") as f:
+ return yaml.safe_load(f)
diff --git a/lm_proxy/core.py b/lm_proxy/core.py
index dcd45ef..6ac5fc9 100644
--- a/lm_proxy/core.py
+++ b/lm_proxy/core.py
@@ -1,3 +1,4 @@
+"""Core LM-Proxy logic"""
import asyncio
import fnmatch
import json
@@ -5,35 +6,19 @@
import secrets
import time
import hashlib
-from typing import List, Optional
+from datetime import datetime
+from typing import Optional
-import microcore as mc
from fastapi import HTTPException
-from lm_proxy.loggers import LogEntry
-from pydantic import BaseModel
from starlette.requests import Request
from starlette.responses import JSONResponse, Response, StreamingResponse
+from .base_types import ChatCompletionRequest, RequestContext
from .bootstrap import env
from .config import Config
-from .loggers import log_non_blocking
from .utils import get_client_ip
-class ChatCompletionRequest(BaseModel):
- model: str
- messages: List[mc.Msg]
- stream: Optional[bool] = None
- max_tokens: Optional[int] = None
- temperature: Optional[float] = None
- top_p: Optional[float] = None
- n: Optional[int] = None
- stop: Optional[List[str]] = None
- presence_penalty: Optional[float] = None
- frequency_penalty: Optional[float] = None
- user: Optional[str] = None
-
-
def parse_routing_rule(rule: str, config: Config) -> tuple[str, str]:
"""
Parses a routing rule in the format 'connection.model' or 'connection.*'.
@@ -79,7 +64,7 @@ def resolve_connection_and_model(
async def process_stream(
- async_llm_func, request: ChatCompletionRequest, llm_params, log_entry: LogEntry
+ async_llm_func, request: ChatCompletionRequest, llm_params, log_entry: RequestContext
):
prompt = request.messages
queue = asyncio.Queue()
@@ -173,7 +158,7 @@ def api_key_id(api_key: Optional[str]) -> str | None:
).hexdigest()
-async def check(request: Request) -> tuple[str, str]:
+async def check(request: Request) -> tuple[str, str, dict]:
"""
API key and service availability check for endpoints.
Args:
@@ -196,7 +181,13 @@ async def check(request: Request) -> tuple[str, str]:
},
)
api_key = read_api_key(request)
- group: str | bool | None = (env.config.check_api_key)(api_key)
+ result = (env.config.api_key_check)(api_key)
+ if isinstance(result, tuple):
+ group, user_info = result
+ else:
+ group: str | bool | None = result
+ user_info = dict()
+
if not group:
raise HTTPException(
status_code=403,
@@ -210,7 +201,7 @@ async def check(request: Request) -> tuple[str, str]:
}
},
)
- return group, api_key
+ return group, api_key, user_info
async def chat_completions(
@@ -220,17 +211,19 @@ async def chat_completions(
Endpoint for chat completions that mimics OpenAI's API structure.
Streams the response from the LLM using microcore.
"""
- group, api_key = await check(raw_request)
+ group, api_key, user_info = await check(raw_request)
llm_params = request.model_dump(exclude={"messages"}, exclude_none=True)
connection, llm_params["model"] = resolve_connection_and_model(
env.config, llm_params.get("model", "default_model")
)
- log_entry = LogEntry(
+ log_entry = RequestContext(
request=request,
api_key_id=api_key_id(api_key),
group=group if isinstance(group, str) else None,
remote_addr=get_client_ip(raw_request),
connection=connection,
+ model=llm_params["model"],
+ user_info=user_info,
)
logging.debug(
"Resolved routing for [%s] --> connection: %s, model: %s",
@@ -282,3 +275,26 @@ async def chat_completions(
]
}
)
+
+
+async def log(log_entry: RequestContext):
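+    """Pass the request context to every configured logger (sync and async)."""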
+ if log_entry.duration is None and log_entry.created_at:
+ log_entry.duration = (datetime.now() - log_entry.created_at).total_seconds()
+ for handler in env.loggers:
+        # Async loggers are scheduled as separate tasks; sync loggers run inline
+        # (still non-blocking for the request, since log() itself runs as a task).
+ if asyncio.iscoroutinefunction(handler):
+ asyncio.create_task(handler(log_entry))
+ else:
+ try:
+ handler(log_entry)
+ except Exception as e:
+ logging.error("Error in logger handler: %s", e)
+ raise e
+
+
+async def log_non_blocking(
+ log_entry: RequestContext,
+) -> Optional[asyncio.Task]:
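+    """
+    Schedule log() as a background task so request handling is not blocked.
+    Returns the created task, or None if no loggers are configured.
+    """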
+ if env.loggers:
+ task = asyncio.create_task(log(log_entry))
+ return task
diff --git a/lm_proxy/loggers.py b/lm_proxy/loggers.py
new file mode 100644
index 0000000..7bb25cf
--- /dev/null
+++ b/lm_proxy/loggers.py
@@ -0,0 +1,83 @@
+"""LLM Request logging."""
+import abc
+import json
+import os
+from dataclasses import dataclass, field
+from typing import Union, Callable
+
+from .base_types import RequestContext
+from .utils import CustomJsonEncoder, resolve_instance_or_callable, resolve_obj_path
+
+
+class AbstractLogEntryTransformer(abc.ABC):
+ """Transform RequestContext into a dictionary of logged attributes."""
+ @abc.abstractmethod
+ def __call__(self, request_context: RequestContext) -> dict:
+ raise NotImplementedError()
+
+
+class AbstractLogWriter(abc.ABC):
+ """Writes the logged data to a destination."""
+ @abc.abstractmethod
+ def __call__(self, logged_data: dict):
+ raise NotImplementedError()
+
+
+class LogEntryTransformer(AbstractLogEntryTransformer):
+ """Transforms RequestContext into a dictionary of logged attributes"""
+ def __init__(self, **kwargs):
+ self.mapping = kwargs
+
+ def __call__(self, request_context: RequestContext) -> dict:
+ result = {}
+ for key, path in self.mapping.items():
+ result[key] = resolve_obj_path(request_context, path)
+ return result
+
+
+@dataclass
+class BaseLogger:
+ """Base LLM request logger."""
+ log_writer: AbstractLogWriter | str | dict
+ entry_transformer: AbstractLogEntryTransformer | str | dict = field(default=None)
+
+ def __post_init__(self):
+ self.entry_transformer = resolve_instance_or_callable(
+ self.entry_transformer,
+ debug_name="logging..entry_transformer",
+ )
+ self.log_writer = resolve_instance_or_callable(
+ self.log_writer,
+ debug_name="logging..log_writer",
+ )
+
+ def _transform(self, request_context: RequestContext) -> dict:
+ return (
+ self.entry_transformer(request_context)
+ if self.entry_transformer
+ else request_context.to_dict()
+ )
+
+ def __call__(self, request_context: RequestContext):
+ self.log_writer(self._transform(request_context))
+
+
+@dataclass
+class JsonLogWriter(AbstractLogWriter):
+ """Writes logged data to a JSON file."""
+ file_name: str
+
+ def __post_init__(self):
+ dir_path = os.path.dirname(self.file_name)
+ if dir_path:
+ os.makedirs(dir_path, exist_ok=True)
+ # Create the file if it doesn't exist
+ with open(self.file_name, "a", encoding="utf-8"):
+ pass
+
+ def __call__(self, logged_data: dict):
+ with open(self.file_name, "a", encoding="utf-8") as f:
+ f.write(json.dumps(logged_data, cls=CustomJsonEncoder) + "\n")
+
+
+TLogger = Union[BaseLogger, Callable[[RequestContext], None]]
diff --git a/lm_proxy/loggers/__init__.py b/lm_proxy/loggers/__init__.py
deleted file mode 100644
index 3d8ab1c..0000000
--- a/lm_proxy/loggers/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from .base_logger import BaseLogger, LogEntryTransformer
-from .log_writers import JsonLogWriter
-from .core import LogEntry, log_non_blocking
-
-__all__ = [
- "BaseLogger",
- "LogEntryTransformer",
- "JsonLogWriter",
- "LogEntry",
- "log_non_blocking",
-]
diff --git a/lm_proxy/loggers/base_logger.py b/lm_proxy/loggers/base_logger.py
deleted file mode 100644
index 55241a0..0000000
--- a/lm_proxy/loggers/base_logger.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import abc
-from dataclasses import dataclass, field
-
-from lm_proxy.utils import resolve_instance_or_callable
-
-from ..utils import resolve_obj_path
-from .core import LogEntry
-
-
-class AbstractLogEntryTransformer(abc.ABC):
- @abc.abstractmethod
- def __call__(self, log_entry: LogEntry) -> dict:
- raise NotImplementedError()
-
-
-class LogEntryTransformer(AbstractLogEntryTransformer):
- def __init__(self, **kwargs):
- self.mapping = kwargs
-
- def __call__(self, log_entry: LogEntry) -> dict:
- result = {}
- for key, path in self.mapping.items():
- result[key] = resolve_obj_path(log_entry, path)
- return result
-
-
-class AbstractLogWriter(abc.ABC):
- @abc.abstractmethod
- def __call__(self, logged_data: dict) -> dict:
- raise NotImplementedError()
-
-
-@dataclass
-class BaseLogger:
- log_writer: AbstractLogWriter | str | dict
- entry_transformer: AbstractLogEntryTransformer | str | dict = field(default=None)
-
- def __post_init__(self):
- self.entry_transformer = resolve_instance_or_callable(
- self.entry_transformer,
- debug_name="logging..entry_transformer",
- )
- self.log_writer = resolve_instance_or_callable(
- self.log_writer,
- debug_name="logging..log_writer",
- )
-
- def _transform(self, log_entry: LogEntry) -> dict:
- return (
- self.entry_transformer(log_entry)
- if self.entry_transformer
- else log_entry.to_dict()
- )
-
- def __call__(self, log_entry: LogEntry):
- self.log_writer(self._transform(log_entry))
diff --git a/lm_proxy/loggers/core.py b/lm_proxy/loggers/core.py
deleted file mode 100644
index 95854fd..0000000
--- a/lm_proxy/loggers/core.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import asyncio
-import logging
-from typing import Optional, TYPE_CHECKING
-from dataclasses import dataclass, field
-from datetime import datetime
-
-import microcore as mc
-from ..bootstrap import env
-
-if TYPE_CHECKING:
- from lm_proxy.core import ChatCompletionRequest, Group
-
-
-@dataclass
-class LogEntry:
- request: "ChatCompletionRequest" = field()
- response: Optional[mc.LLMResponse] = field(default=None)
- error: Optional[Exception] = field(default=None)
- group: "Group" = field(default=None)
- connection: str = field(default=None)
- api_key_id: Optional[str] = field(default=None)
- remote_addr: Optional[str] = field(default=None)
- created_at: Optional[datetime] = field(default_factory=datetime.now)
- duration: Optional[float] = field(default=None)
-
- def to_dict(self) -> dict:
- data = self.__dict__.copy()
- if self.request:
- data["request"] = self.request.model_dump(mode="json")
- return data
-
-
-async def log(log_entry: LogEntry):
- if log_entry.duration is None and log_entry.created_at:
- log_entry.duration = (datetime.now() - log_entry.created_at).total_seconds()
- for handler in env.config.loggers:
- # check if it is async, then run both sync and async loggers in non-blocking way (sync too)
- if asyncio.iscoroutinefunction(handler):
- asyncio.create_task(handler(log_entry))
- else:
- try:
- handler(log_entry)
- except Exception as e:
- logging.error("Error in logger handler: %s", e)
- raise e
-
-
-async def log_non_blocking(
- log_entry: LogEntry,
-) -> Optional[asyncio.Task]:
- if env.config.loggers:
- task = asyncio.create_task(log(log_entry))
- return task
diff --git a/lm_proxy/loggers/log_writers.py b/lm_proxy/loggers/log_writers.py
deleted file mode 100644
index 2b12df3..0000000
--- a/lm_proxy/loggers/log_writers.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import os
-import json
-from dataclasses import dataclass
-
-from .base_logger import AbstractLogWriter
-from ..utils import CustomJsonEncoder
-
-
-@dataclass
-class JsonLogWriter(AbstractLogWriter):
-
- file_name: str
-
- def __post_init__(self):
- dir_path = os.path.dirname(self.file_name)
- if dir_path:
- os.makedirs(dir_path, exist_ok=True)
- # Create the file if it doesn't exist
- with open(self.file_name, "a", encoding="utf-8"):
- pass
-
- def __call__(self, logged_data: dict):
- with open(self.file_name, "a", encoding="utf-8") as f:
- f.write(json.dumps(logged_data, cls=CustomJsonEncoder) + "\n")
diff --git a/lm_proxy/models.py b/lm_proxy/models_endpoint.py
similarity index 81%
rename from lm_proxy/models.py
rename to lm_proxy/models_endpoint.py
index 27f6a3a..70272d5 100644
--- a/lm_proxy/models.py
+++ b/lm_proxy/models_endpoint.py
@@ -14,9 +14,9 @@ async def models(request: Request) -> JSONResponse:
"""
Lists available models based on routing rules and group permissions.
"""
- group_name, api_key = await check(request)
+ group_name, api_key, user_info = await check(request)
group: Group = env.config.groups[group_name]
- models = list()
+ models = []
for model_pattern, route in env.config.routing.items():
connection_name, _ = parse_routing_rule(route, env.config)
if group.allows_connecting_to(connection_name):
@@ -28,11 +28,10 @@ async def models(request: Request) -> JSONResponse:
== ModelListingMode.IGNORE_WILDCARDS
):
continue
- else:
- raise NotImplementedError(
- f"'{env.config.model_listing_mode}' model listing mode "
- f"is not implemented yet"
- )
+ raise NotImplementedError(
+ f"'{env.config.model_listing_mode}' model listing mode "
+ f"is not implemented yet"
+ )
models.append(
dict(
id=model_pattern,
diff --git a/lm_proxy/utils.py b/lm_proxy/utils.py
index cbbc7e1..7b1e29e 100644
--- a/lm_proxy/utils.py
+++ b/lm_proxy/utils.py
@@ -1,6 +1,9 @@
+"""Common usage utility functions."""
+import os
import json
import inspect
-from typing import Union, Callable
+import logging
+from typing import Any, Callable, Union
from datetime import datetime, date, time
from microcore.utils import resolve_callable
@@ -21,42 +24,48 @@ def resolve_obj_path(obj, path: str, default=None):
def resolve_instance_or_callable(
- item: Union[str, Callable, dict], class_key: str = "class", debug_name: str = None
+ item: Union[str, Callable, dict, object],
+ class_key: str = "class",
+ debug_name: str = None,
+ allow_types: list[type] = None,
) -> Callable:
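+    """
+    Resolve a configuration value (dotted-path string, dict with a 'class' key,
+    callable, or instance of an allowed type) into a ready-to-use object.
+    """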
- if not item:
+ if item is None or item == "":
return None
if isinstance(item, dict):
if class_key not in item:
raise ValueError(
f"'{class_key}' key is missing in {debug_name or 'item'} config: {item}"
)
- class_name = item.pop(class_key)
+ args = dict(item)
+ class_name = args.pop(class_key)
constructor = resolve_callable(class_name)
- return constructor(**item)
+ return constructor(**args)
if isinstance(item, str):
fn = resolve_callable(item)
return fn() if inspect.isclass(fn) else fn
if callable(item):
return item() if inspect.isclass(item) else item
+ if allow_types and any(isinstance(item, t) for t in allow_types):
+ return item
else:
raise ValueError(f"Invalid {debug_name or 'item'} config: {item}")
class CustomJsonEncoder(json.JSONEncoder):
- def default(self, obj):
- if isinstance(obj, datetime):
- return obj.isoformat()
- elif isinstance(obj, date):
- return obj.isoformat()
- elif isinstance(obj, time):
- return obj.isoformat()
- elif hasattr(obj, "__dict__"):
- return obj.__dict__
- elif hasattr(obj, "model_dump"):
- return obj.model_dump()
- elif hasattr(obj, "dict"):
- return obj.dict()
- return super().default(obj)
+ def default(self, o):
+ if isinstance(o, datetime):
+ return o.isoformat()
+ if isinstance(o, date):
+ return o.isoformat()
+ if isinstance(o, time):
+ return o.isoformat()
+ if hasattr(o, "__dict__"):
+ return o.__dict__
+ if hasattr(o, "model_dump"):
+ return o.model_dump()
+ if hasattr(o, "dict"):
+ return o.dict()
+ return super().default(o)
def get_client_ip(request: Request) -> str:
@@ -71,3 +80,22 @@ def get_client_ip(request: Request) -> str:
# Fallback to direct client
return request.client.host if request.client else "unknown"
+
+
+def replace_env_strings_recursive(data: Any) -> Any:
+ """
+ Recursively traverses dicts and lists, replacing all string values
+ that start with 'env:' with the corresponding environment variable.
+ For example, a string "env:VAR_NAME" will be replaced by the value of the
+ environment variable "VAR_NAME".
+ """
+ if isinstance(data, dict):
+ return {k: replace_env_strings_recursive(v) for k, v in data.items()}
+ if isinstance(data, list):
+ return [replace_env_strings_recursive(i) for i in data]
+ if isinstance(data, str) and data.startswith("env:"):
+ env_var_name = data[4:]
+ if env_var_name not in os.environ:
+ logging.warning(f"Environment variable '{env_var_name}' not found")
+ return os.environ.get(env_var_name, "")
+ return data
diff --git a/poetry.lock b/poetry.lock
index b3c4b28..edb7ba5 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2007,4 +2007,4 @@ standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)
[metadata]
lock-version = "2.1"
python-versions = ">=3.11,<4"
-content-hash = "20eaf54bdd65e34c48db3cdd92557f8bdf0f51f1e77cdd80ea2b689636b08f7f"
+content-hash = "384cc8da11306adeca8a6ec7e6800fa2ad4c47196e0d50cd9dc523e29d3ab7f1"
diff --git a/pyproject.toml b/pyproject.toml
index c9eff74..28703f5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "lm-proxy"
-version = "1.1.0"
+version = "2.0.0"
description = "\"LM-Proxy\" is OpenAI-compatible http proxy server for inferencing various LLMs capable of working with Google, Anthropic, OpenAI APIs, local PyTorch inference, etc."
readme = "README.md"
keywords = ["llm", "large language models", "ai", "gpt", "openai", "proxy", "http", "proxy-server"]
@@ -18,6 +18,7 @@ dependencies = [
"fastapi~=0.116.1",
"uvicorn>=0.22.0",
"typer>=0.16.1",
+ "requests~=2.32.3",
]
requires-python = ">=3.11,<4"
@@ -33,6 +34,13 @@ license = { file = "LICENSE" }
[project.urls]
"Source Code" = "https://github.com/Nayjest/lm-proxy"
+[project.entry-points."config.loaders"]
+toml = "lm_proxy.config_loaders:load_toml_config"
+py = "lm_proxy.config_loaders:load_python_config"
+yml = "lm_proxy.config_loaders:load_yaml_config"
+yaml = "lm_proxy.config_loaders:load_yaml_config"
+json = "lm_proxy.config_loaders:load_json_config"
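+# Additional config-file loaders can be registered under the same entry-point group,
+# keyed by file extension, e.g. a hypothetical: ini = "my_pkg.loaders:load_ini_config"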
+
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
@@ -51,3 +59,6 @@ lm-proxy = "lm_proxy.app:cli_app"
[tool.pytest.ini_options]
asyncio_mode = "auto"
+testpaths = [
+ "tests",
+]
\ No newline at end of file
diff --git a/tests/configs/__init__.py b/tests/configs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/configs/config_fn.py b/tests/configs/config_fn.py
index fe9be9e..86812e3 100644
--- a/tests/configs/config_fn.py
+++ b/tests/configs/config_fn.py
@@ -8,7 +8,7 @@
from lm_proxy.config import Config, Group # noqa
-def check_api_key(api_key: str) -> str | None:
+def custom_api_key_check(api_key: str) -> str | None:
return "default" if api_key == "py-test" else None
@@ -20,7 +20,7 @@ def check_api_key(api_key: str) -> str | None:
config = Config(
port=8123,
host="127.0.0.1",
- check_api_key=check_api_key,
+ api_key_check=custom_api_key_check,
connections={"py_oai": mc.env().llm_async_function},
routing={"*": "py_oai.gpt-3.5-turbo", "my-gpt": "py_oai.gpt-3.5-turbo"},
groups={"default": Group(connections="*")},
diff --git a/tests/configs/test_config.json b/tests/configs/test_config.json
new file mode 100644
index 0000000..555c04e
--- /dev/null
+++ b/tests/configs/test_config.json
@@ -0,0 +1,30 @@
+{
+ "host": "127.0.0.1",
+ "port": 8787,
+ "connections": {
+ "test_openai": {
+ "api_type": "open_ai",
+ "api_base": "https://api.openai.com/v1/",
+ "api_key": "env:OPENAI_API_KEY"
+ },
+ "test_google": {
+ "api_type": "google_ai_studio",
+ "api_key": "env:GOOGLE_API_KEY"
+ },
+ "test_anthropic": {
+ "api_type": "anthropic",
+ "api_key": "env:ANTHROPIC_API_KEY"
+ }
+ },
+ "routing": {
+ "gpt*": "test_openai.*",
+ "claude*": "test_anthropic.*",
+ "gemini*": "test_google.*",
+ "*": "test_openai.gpt-5"
+ },
+ "groups": {
+ "default": {
+ "api_keys": []
+ }
+ }
+}
\ No newline at end of file
diff --git a/tests/configs/test_config.toml b/tests/configs/test_config.toml
index 7efebe0..13591b3 100644
--- a/tests/configs/test_config.toml
+++ b/tests/configs/test_config.toml
@@ -1,6 +1,5 @@
host = "127.0.0.1"
port = 8787
-check_api_key = "tests.conftest.check_api_function"
[connections]
[connections.test_openai]
diff --git a/tests/configs/test_config.yml b/tests/configs/test_config.yml
new file mode 100644
index 0000000..7c4bf38
--- /dev/null
+++ b/tests/configs/test_config.yml
@@ -0,0 +1,24 @@
+host: "127.0.0.1"
+port: 8787
+
+connections:
+ test_openai:
+ api_type: "open_ai"
+ api_base: "https://api.openai.com/v1/"
+ api_key: "env:OPENAI_API_KEY"
+ test_google:
+ api_type: "google_ai_studio"
+ api_key: "env:GOOGLE_API_KEY"
+ test_anthropic:
+ api_type: "anthropic"
+ api_key: "env:ANTHROPIC_API_KEY"
+
+routing:
+ "gpt*": "test_openai.*"
+ "claude*": "test_anthropic.*"
+ "gemini*": "test_google.*"
+ "*": "test_openai.gpt-5"
+
+groups:
+ default:
+ api_keys: []
diff --git a/tests/test_config_loaders.py b/tests/test_config_loaders.py
new file mode 100644
index 0000000..cf4e88f
--- /dev/null
+++ b/tests/test_config_loaders.py
@@ -0,0 +1,26 @@
+import os
+from pathlib import Path
+
+import dotenv
+import pytest
+
+from lm_proxy.config import Config
+
+
+def test_config_loaders():
+ root = Path(__file__).resolve().parent
+ dotenv.load_dotenv(root.parent / ".env.template", override=True)
+ oai_key = os.getenv("OPENAI_API_KEY")
+ toml = Config.load(root / "configs" / "test_config.toml")
+ json = Config.load(root / "configs" / "test_config.json")
+ yaml = Config.load(root / "configs" / "test_config.yml")
+
+ assert json.model_dump() == yaml.model_dump() == toml.model_dump()
+ assert json.connections["test_openai"]["api_key"] == oai_key
+
+ py = Config.load(root / "configs" / "config_fn.py")
+ assert isinstance(py, Config)
+
+ # Expect an error for unsupported format
+ with pytest.raises(ValueError):
+ Config.load(root / "configs" / "test_config.xyz")
diff --git a/tests/test_integration.py b/tests/test_integration.py
index df81c65..1740167 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -3,17 +3,17 @@
from tests.conftest import ServerFixture
-def configure_mc(cfg: ServerFixture):
+def configure_mc_to_use_local_proxy(cfg: ServerFixture):
mc.configure(
LLM_API_TYPE="openai",
LLM_API_BASE=f"http://127.0.0.1:{cfg.port}/v1", # Test server port
LLM_API_KEY=cfg.api_key, # Not used but required
- MODEL=cfg.model, # Will be routed according to test_config.toml
+ MODEL=cfg.model,
)
def test_france_capital_query(server_config_fn: ServerFixture):
- configure_mc(server_config_fn)
+ configure_mc_to_use_local_proxy(server_config_fn)
response = mc.llm("What is the capital of France?\n (!) Respond with 1 word.")
assert (
"paris" in response.lower().strip()
@@ -33,6 +33,7 @@ def test_direct_api_call(server_config_fn: ServerFixture):
"Content-Type": "application/json",
"authorization": f"bearer {cfg.api_key}",
},
+ timeout=120,
)
assert (
@@ -51,7 +52,7 @@ def test_direct_api_call(server_config_fn: ServerFixture):
def test_streaming_response(server_config_fn: ServerFixture):
- configure_mc(server_config_fn)
+ configure_mc_to_use_local_proxy(server_config_fn)
collected_text = []
mc.llm(
"Count from 1 to 5, each number as english word (one, two, ...) on a new line",
diff --git a/tests/test_loggers.py b/tests/test_loggers.py
index a95c3ce..32945a9 100644
--- a/tests/test_loggers.py
+++ b/tests/test_loggers.py
@@ -2,9 +2,8 @@
import microcore as mc
-from lm_proxy.core import ChatCompletionRequest
-from lm_proxy.loggers import LogEntry
-from lm_proxy.loggers.core import log_non_blocking
+from lm_proxy.core import log_non_blocking
+from lm_proxy.base_types import ChatCompletionRequest, RequestContext
from lm_proxy.config import Config
from lm_proxy.bootstrap import bootstrap
from lm_proxy.utils import CustomJsonEncoder
@@ -31,7 +30,7 @@ async def test_custom_config():
messages=[{"role": "user", "content": "Test request message"}],
)
response = mc.LLMResponse("Test response message", dict(prompt=request.messages))
- task = await log_non_blocking(LogEntry(request=request, response=response))
+ task = await log_non_blocking(RequestContext(request=request, response=response))
if task:
await task
assert len(logs) == 1
@@ -60,10 +59,10 @@ async def test_json(tmp_path):
messages=[{"role": "user", "content": "Test request message"}],
)
response = mc.LLMResponse("Test response message", dict(prompt=request.messages))
- task = await log_non_blocking(LogEntry(request=request, response=response))
+ task = await log_non_blocking(RequestContext(request=request, response=response))
if task:
await task
- task = await log_non_blocking(LogEntry(request=request, response=response))
+ task = await log_non_blocking(RequestContext(request=request, response=response))
if task:
await task
with open(tmp_path / "json_log.log", "r") as f:
diff --git a/tests/test_models.py b/tests/test_models_endpoint.py
similarity index 93%
rename from tests/test_models.py
rename to tests/test_models_endpoint.py
index a80a407..41eb617 100644
--- a/tests/test_models.py
+++ b/tests/test_models_endpoint.py
@@ -3,9 +3,8 @@
from starlette.requests import Request
from lm_proxy.config import Config, ModelListingMode
-from lm_proxy.bootstrap import bootstrap
-from lm_proxy.models import models
-from lm_proxy.bootstrap import env
+from lm_proxy.bootstrap import bootstrap, env
+from lm_proxy.models_endpoint import models
async def test_models_endpoint():
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..7ae730b
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,58 @@
+import os
+import logging
+
+import pytest
+from lm_proxy.utils import resolve_instance_or_callable, replace_env_strings_recursive
+
+
+def test_resolve_instance_or_callable():
+ assert resolve_instance_or_callable(None) is None
+
+ obj1, obj2 = object(), object()
+ ins = resolve_instance_or_callable(obj1, allow_types=[object])
+ assert ins is obj1 and ins is not obj2
+
+ with pytest.raises(ValueError):
+ resolve_instance_or_callable(123)
+
+ with pytest.raises(ValueError):
+ resolve_instance_or_callable([])
+
+ with pytest.raises(ValueError):
+ resolve_instance_or_callable({})
+
+ assert resolve_instance_or_callable(lambda: 42)() == 42
+
+ class MyClass:
+ def __init__(self, value=0):
+ self.value = value
+
+ res = resolve_instance_or_callable(lambda: MyClass(10), allow_types=[MyClass])
+ assert not isinstance(res, MyClass) and res().value == 10
+
+ ins = resolve_instance_or_callable(MyClass(20), allow_types=[MyClass])
+ assert isinstance(ins, MyClass) and ins.value == 20
+ assert resolve_instance_or_callable(
+ "lm_proxy.utils.resolve_instance_or_callable"
+ ) is resolve_instance_or_callable
+
+ ins = resolve_instance_or_callable({
+ 'class': 'lm_proxy.loggers.JsonLogWriter',
+ 'file_name': 'test.log'
+ })
+ assert ins.__class__.__name__ == 'JsonLogWriter' and ins.file_name == 'test.log'
+
+
+def test_replace_env_strings_recursive(caplog):
+ os.environ['TEST_VAR1'] = 'env_value1'
+ os.environ['TEST_VAR2'] = 'env_value2'
+ assert replace_env_strings_recursive("env:TEST_VAR1") == 'env_value1'
+
+ caplog.set_level(logging.WARNING)
+ assert replace_env_strings_recursive("env:NON_EXIST") == ''
+ assert len(caplog.records) == 1
+
+ assert replace_env_strings_recursive([["env:TEST_VAR1"]]) == [['env_value1']]
+ assert replace_env_strings_recursive(
+ {"data": {"field": "env:TEST_VAR1"}}
+ ) == {"data": {"field": "env_value1"}}