diff --git a/cascadeflow/__init__.py b/cascadeflow/__init__.py index aabe6191..b9bc7682 100644 --- a/cascadeflow/__init__.py +++ b/cascadeflow/__init__.py @@ -239,6 +239,21 @@ get_tool_risk_routing, ) +# NEW: Harness API scaffold (V2 core branch) +from .harness import ( + HarnessConfig, + HarnessInitReport, + HarnessRunContext, + init, + reset, + run, + agent as harness_agent, + get_harness_config, + get_current_run, + get_harness_callback_manager, + set_harness_callback_manager, +) + # ==================== MAIN AGENT & RESULT ==================== @@ -381,6 +396,18 @@ "ToolRiskClassification", # NEW: v0.8.0 - Classification result "ToolRiskClassifier", # NEW: v0.8.0 - Tool risk classifier "get_tool_risk_routing", # NEW: v0.8.0 - Routing by risk level + # ===== HARNESS API (V2 scaffold) ===== + "HarnessConfig", + "HarnessInitReport", + "HarnessRunContext", + "init", + "reset", + "run", + "harness_agent", + "get_harness_config", + "get_current_run", + "get_harness_callback_manager", + "set_harness_callback_manager", # ===== PROVIDERS ===== "ModelResponse", "BaseProvider", diff --git a/cascadeflow/harness/__init__.py b/cascadeflow/harness/__init__.py new file mode 100644 index 00000000..74c07219 --- /dev/null +++ b/cascadeflow/harness/__init__.py @@ -0,0 +1,38 @@ +""" +Core harness API scaffold for V2 planning work. + +This module provides a minimal, backward-compatible surface: +- init(): global harness settings (opt-in) +- run(): scoped run context for budget/trace accounting +- agent(): decorator for attaching policy metadata + +The implementation intentionally avoids modifying existing CascadeAgent behavior. 
+""" + +from .api import ( + HarnessConfig, + HarnessInitReport, + HarnessRunContext, + agent, + get_harness_callback_manager, + get_current_run, + get_harness_config, + init, + reset, + run, + set_harness_callback_manager, +) + +__all__ = [ + "HarnessConfig", + "HarnessInitReport", + "HarnessRunContext", + "init", + "run", + "agent", + "get_current_run", + "get_harness_callback_manager", + "get_harness_config", + "set_harness_callback_manager", + "reset", +] diff --git a/cascadeflow/harness/api.py b/cascadeflow/harness/api.py new file mode 100644 index 00000000..6221164e --- /dev/null +++ b/cascadeflow/harness/api.py @@ -0,0 +1,628 @@ +from __future__ import annotations + +import inspect +import json +import logging +import os +import time +from contextvars import ContextVar, Token +from dataclasses import dataclass, field +from functools import wraps +from importlib.util import find_spec +from pathlib import Path +from typing import Any, Callable, Literal, Optional, TypeVar, cast +from uuid import uuid4 + +logger = logging.getLogger("cascadeflow.harness") + +HarnessMode = Literal["off", "observe", "enforce"] + + +@dataclass +class HarnessConfig: + mode: HarnessMode = "off" + verbose: bool = False + budget: Optional[float] = None + max_tool_calls: Optional[int] = None + max_latency_ms: Optional[float] = None + max_energy: Optional[float] = None + kpi_targets: Optional[dict[str, float]] = None + kpi_weights: Optional[dict[str, float]] = None + compliance: Optional[str] = None + + +@dataclass +class HarnessInitReport: + mode: HarnessMode + instrumented: list[str] + detected_but_not_instrumented: list[str] + config_sources: dict[str, str] + + +@dataclass +class HarnessRunContext: + run_id: str = field(default_factory=lambda: uuid4().hex[:12]) + _started_monotonic: float = field(default_factory=time.monotonic, init=False, repr=False) + started_at_ms: float = field(default_factory=lambda: time.time() * 1000) + ended_at_ms: Optional[float] = None + duration_ms: 
Optional[float] = None + mode: HarnessMode = "off" + budget_max: Optional[float] = None + tool_calls_max: Optional[int] = None + latency_max_ms: Optional[float] = None + energy_max: Optional[float] = None + kpi_targets: Optional[dict[str, float]] = None + kpi_weights: Optional[dict[str, float]] = None + compliance: Optional[str] = None + + cost: float = 0.0 + savings: float = 0.0 + tool_calls: int = 0 + step_count: int = 0 + latency_used_ms: float = 0.0 + energy_used: float = 0.0 + budget_remaining: Optional[float] = None + model_used: Optional[str] = None + last_action: str = "allow" + draft_accepted: Optional[bool] = None + _trace: list[dict[str, Any]] = field(default_factory=list) + _token: Optional[Token[Optional[HarnessRunContext]]] = field( + default=None, init=False, repr=False + ) + + def __post_init__(self) -> None: + if self.budget_max is not None and self.budget_remaining is None: + self.budget_remaining = self.budget_max + + def __enter__(self) -> HarnessRunContext: + self._token = _current_run.set(self) + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + self.ended_at_ms = time.time() * 1000 + self.duration_ms = max(0.0, (time.monotonic() - self._started_monotonic) * 1000.0) + self._log_summary() + if self._token is not None: + _current_run.reset(self._token) + self._token = None + + async def __aenter__(self) -> HarnessRunContext: + return self.__enter__() + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + self.__exit__(exc_type, exc, tb) + + def trace(self) -> list[dict[str, Any]]: + return list(self._trace) + + def summary(self) -> dict[str, Any]: + return { + "run_id": self.run_id, + "mode": self.mode, + "step_count": self.step_count, + "tool_calls": self.tool_calls, + "cost": self.cost, + "savings": self.savings, + "latency_used_ms": self.latency_used_ms, + "energy_used": self.energy_used, + "budget_max": self.budget_max, + "budget_remaining": self.budget_remaining, + "last_action": 
self.last_action, + "model_used": self.model_used, + "duration_ms": self.duration_ms, + } + + def _log_summary(self) -> None: + if self.mode == "off" or self.step_count <= 0: + return + logger.info( + ( + "harness run summary run_id=%s mode=%s steps=%d tool_calls=%d " + "cost=%.6f latency_ms=%.2f energy=%.4f last_action=%s model=%s " + "budget_remaining=%s" + ), + self.run_id, + self.mode, + self.step_count, + self.tool_calls, + self.cost, + self.latency_used_ms, + self.energy_used, + self.last_action, + self.model_used, + self.budget_remaining, + ) + + def record( + self, + action: str, + reason: str, + model: Optional[str] = None, + *, + applied: Optional[bool] = None, + decision_mode: Optional[str] = None, + ) -> None: + safe_action = _sanitize_trace_value(action, max_length=_MAX_ACTION_LEN) + if not safe_action: + logger.warning("record() called with empty action, defaulting to 'allow'") + safe_action = "allow" + safe_reason = _sanitize_trace_value(reason, max_length=_MAX_REASON_LEN) or "unspecified" + safe_model = ( + _sanitize_trace_value(model, max_length=_MAX_MODEL_LEN) if model is not None else None + ) + + self.last_action = safe_action + self.model_used = safe_model + entry: dict[str, Any] = { + "action": safe_action, + "reason": safe_reason, + "model": safe_model, + "run_id": self.run_id, + "mode": self.mode, + "step": self.step_count, + "timestamp_ms": time.time() * 1000, + "tool_calls_total": self.tool_calls, + "cost_total": self.cost, + "latency_used_ms": self.latency_used_ms, + "energy_used": self.energy_used, + "budget_state": { + "max": self.budget_max, + "remaining": self.budget_remaining, + }, + } + if applied is not None: + entry["applied"] = applied + if decision_mode is not None: + entry["decision_mode"] = decision_mode + self._trace.append(entry) + _emit_harness_decision(entry) + + +_harness_config: HarnessConfig = HarnessConfig() +_current_run: ContextVar[Optional[HarnessRunContext]] = ContextVar( + "cascadeflow_harness_run", default=None 
+) +_is_instrumented: bool = False +_harness_callback_manager: Any = None +_UNSET = object() + + +def _validate_mode(mode: str) -> HarnessMode: + if mode not in {"off", "observe", "enforce"}: + raise ValueError("mode must be one of: off, observe, enforce") + return cast(HarnessMode, mode) + + +def _detect_sdks() -> dict[str, bool]: + return { + "openai": find_spec("openai") is not None, + "anthropic": find_spec("anthropic") is not None, + } + + +def get_harness_config() -> HarnessConfig: + return HarnessConfig(**_harness_config.__dict__) + + +def get_current_run() -> Optional[HarnessRunContext]: + return _current_run.get() + + +def get_harness_callback_manager() -> Any: + return _harness_callback_manager + + +def set_harness_callback_manager(callback_manager: Any) -> None: + global _harness_callback_manager + _harness_callback_manager = callback_manager + + +def reset() -> None: + """ + Reset harness global state and unpatch instrumented clients. + + Intended for tests and controlled shutdown paths. + """ + + global _harness_config + global _is_instrumented + global _harness_callback_manager + global _cached_cascade_decision_event + + from cascadeflow.harness.instrument import unpatch_anthropic, unpatch_openai + + unpatch_openai() + unpatch_anthropic() + _harness_config = HarnessConfig() + _is_instrumented = False + _harness_callback_manager = None + _cached_cascade_decision_event = None + _current_run.set(None) + + +_MAX_ACTION_LEN = 64 +_MAX_REASON_LEN = 160 +_MAX_MODEL_LEN = 128 +_MAX_ENV_JSON_LEN = 4096 + + +def _sanitize_trace_value(value: Any, *, max_length: int) -> Optional[str]: + if value is None: + return None + text = str(value).replace("\n", " ").replace("\r", " ").strip() + text = "".join(c for c in text if c.isprintable()) + if len(text) > max_length: + text = text[: max_length - 3] + "..." 
+ return text or None + + +_cached_cascade_decision_event: Any = None + + +def _emit_harness_decision(entry: dict[str, Any]) -> None: + global _cached_cascade_decision_event + + manager = get_harness_callback_manager() + if manager is None: + return + + trigger = getattr(manager, "trigger", None) + if not callable(trigger): + logger.debug("harness callback manager has no trigger() method") + return + + if _cached_cascade_decision_event is None: + try: + from cascadeflow.telemetry.callbacks import CallbackEvent + + _cached_cascade_decision_event = CallbackEvent.CASCADE_DECISION + except Exception: + logger.debug("telemetry callbacks unavailable for harness decision emit", exc_info=True) + return + + try: + trigger( + _cached_cascade_decision_event, + query="[harness]", + data=dict(entry), + workflow="harness", + ) + except Exception: + logger.debug("failed to emit harness decision callback", exc_info=True) + + +def _parse_bool(raw: str) -> bool: + normalized = raw.strip().lower() + return normalized in {"1", "true", "yes", "on"} + + +def _parse_float(raw: str) -> float: + return float(raw.strip()) + + +def _parse_int(raw: str) -> int: + return int(raw.strip()) + + +def _parse_json_dict(raw: str) -> dict[str, float]: + if len(raw) > _MAX_ENV_JSON_LEN: + raise ValueError( + f"JSON config exceeds {_MAX_ENV_JSON_LEN} characters for harness env var" + ) + value = json.loads(raw) + if not isinstance(value, dict): + raise ValueError("expected JSON object") + parsed: dict[str, float] = {} + for key, item in value.items(): + parsed[str(key)] = float(item) + return parsed + + +def _read_env_config() -> dict[str, Any]: + env_config: dict[str, Any] = {} + + mode = os.getenv("CASCADEFLOW_HARNESS_MODE") or os.getenv("CASCADEFLOW_MODE") + if mode: + env_config["mode"] = mode + + verbose = os.getenv("CASCADEFLOW_HARNESS_VERBOSE") + if verbose is not None: + env_config["verbose"] = _parse_bool(verbose) + + budget = os.getenv("CASCADEFLOW_HARNESS_BUDGET") or 
os.getenv("CASCADEFLOW_BUDGET") + if budget is not None: + env_config["budget"] = _parse_float(budget) + + max_tool_calls = os.getenv("CASCADEFLOW_HARNESS_MAX_TOOL_CALLS") + if max_tool_calls is not None: + env_config["max_tool_calls"] = _parse_int(max_tool_calls) + + max_latency_ms = os.getenv("CASCADEFLOW_HARNESS_MAX_LATENCY_MS") + if max_latency_ms is not None: + env_config["max_latency_ms"] = _parse_float(max_latency_ms) + + max_energy = os.getenv("CASCADEFLOW_HARNESS_MAX_ENERGY") + if max_energy is not None: + env_config["max_energy"] = _parse_float(max_energy) + + compliance = os.getenv("CASCADEFLOW_HARNESS_COMPLIANCE") + if compliance is not None: + env_config["compliance"] = compliance + + kpi_targets = os.getenv("CASCADEFLOW_HARNESS_KPI_TARGETS") + if kpi_targets is not None: + env_config["kpi_targets"] = _parse_json_dict(kpi_targets) + + kpi_weights = os.getenv("CASCADEFLOW_HARNESS_KPI_WEIGHTS") + if kpi_weights is not None: + env_config["kpi_weights"] = _parse_json_dict(kpi_weights) + + return env_config + + +def _read_file_config() -> tuple[dict[str, Any], Optional[str]]: + """ + Read harness config from CASCADEFLOW_CONFIG path or default config discovery. 
+ """ + + config_path: Optional[str] = os.getenv("CASCADEFLOW_CONFIG") + loaded_path: Optional[str] = None + + try: + from cascadeflow.config_loader import find_config, load_config + except Exception: + logger.debug("config_loader unavailable while reading harness config", exc_info=True) + return {}, None + + try: + if config_path: + loaded_path = str(Path(config_path)) + raw = load_config(config_path) + else: + discovered = find_config() + if not discovered: + return {}, None + loaded_path = str(discovered) + raw = load_config(discovered) + except Exception: + logger.warning("failed to load harness config file", exc_info=True) + return {}, None + + if not isinstance(raw, dict): + return {}, loaded_path + + harness_block = raw.get("harness") + if isinstance(harness_block, dict): + return dict(harness_block), loaded_path + + # Fallback: allow top-level harness keys. + keys = { + "mode", + "verbose", + "budget", + "max_tool_calls", + "max_latency_ms", + "max_energy", + "kpi_targets", + "kpi_weights", + "compliance", + } + fallback = {k: v for k, v in raw.items() if k in keys} + return fallback, loaded_path + + +def _resolve_value( + name: str, + explicit: Any, + env_config: dict[str, Any], + file_config: dict[str, Any], + default: Any, + sources: dict[str, str], +) -> Any: + if explicit is not _UNSET: + sources[name] = "code" + return explicit + if name in env_config: + sources[name] = "env" + return env_config[name] + if name in file_config: + sources[name] = "file" + return file_config[name] + sources[name] = "default" + return default + + +def init( + *, + mode: HarnessMode | object = _UNSET, + verbose: bool | object = _UNSET, + budget: Optional[float] | object = _UNSET, + max_tool_calls: Optional[int] | object = _UNSET, + max_latency_ms: Optional[float] | object = _UNSET, + max_energy: Optional[float] | object = _UNSET, + kpi_targets: Optional[dict[str, float]] | object = _UNSET, + kpi_weights: Optional[dict[str, float]] | object = _UNSET, + compliance: 
Optional[str] | object = _UNSET, + callback_manager: Any | object = _UNSET, +) -> HarnessInitReport: + """ + Initialize global harness settings and instrument detected SDK clients. + """ + + global _harness_config + global _is_instrumented + + env_config = _read_env_config() + file_config, file_path = _read_file_config() + sources: dict[str, str] = {} + + resolved_mode = _resolve_value("mode", mode, env_config, file_config, "off", sources) + resolved_verbose = _resolve_value("verbose", verbose, env_config, file_config, False, sources) + resolved_budget = _resolve_value("budget", budget, env_config, file_config, None, sources) + resolved_max_tool_calls = _resolve_value( + "max_tool_calls", max_tool_calls, env_config, file_config, None, sources + ) + resolved_max_latency_ms = _resolve_value( + "max_latency_ms", max_latency_ms, env_config, file_config, None, sources + ) + resolved_max_energy = _resolve_value( + "max_energy", max_energy, env_config, file_config, None, sources + ) + resolved_kpi_targets = _resolve_value( + "kpi_targets", kpi_targets, env_config, file_config, None, sources + ) + resolved_kpi_weights = _resolve_value( + "kpi_weights", kpi_weights, env_config, file_config, None, sources + ) + resolved_compliance = _resolve_value( + "compliance", compliance, env_config, file_config, None, sources + ) + if callback_manager is not _UNSET: + set_harness_callback_manager(callback_manager) + sources["callback_manager"] = "code" + + validated_mode = _validate_mode(str(resolved_mode)) + _harness_config = HarnessConfig( + mode=validated_mode, + verbose=bool(resolved_verbose), + budget=cast(Optional[float], resolved_budget), + max_tool_calls=cast(Optional[int], resolved_max_tool_calls), + max_latency_ms=cast(Optional[float], resolved_max_latency_ms), + max_energy=cast(Optional[float], resolved_max_energy), + kpi_targets=cast(Optional[dict[str, float]], resolved_kpi_targets), + kpi_weights=cast(Optional[dict[str, float]], resolved_kpi_weights), + 
compliance=cast(Optional[str], resolved_compliance), + ) + + sdk_presence = _detect_sdks() + instrumented: list[str] = [] + detected_but_not_instrumented: list[str] = [] + + if validated_mode != "off" and sdk_presence["openai"]: + from cascadeflow.harness.instrument import patch_openai + + if patch_openai(): + instrumented.append("openai") + else: + detected_but_not_instrumented.append("openai") + + if validated_mode != "off" and sdk_presence["anthropic"]: + from cascadeflow.harness.instrument import patch_anthropic + + if patch_anthropic(): + instrumented.append("anthropic") + else: + detected_but_not_instrumented.append("anthropic") + + if validated_mode == "off": + from cascadeflow.harness.instrument import ( + is_anthropic_patched, + is_openai_patched, + unpatch_anthropic, + unpatch_openai, + ) + + if is_openai_patched(): + unpatch_openai() + if is_anthropic_patched(): + unpatch_anthropic() + + if _is_instrumented: + logger.debug("harness init called again; instrumentation remains idempotent") + _is_instrumented = True + + logger.info("harness init mode=%s instrumented=%s", validated_mode, instrumented) + if detected_but_not_instrumented: + logger.info( + "harness detected but not instrumented=%s", + detected_but_not_instrumented, + ) + if file_path: + logger.debug("harness loaded config file=%s", file_path) + + return HarnessInitReport( + mode=validated_mode, + instrumented=instrumented, + detected_but_not_instrumented=detected_but_not_instrumented, + config_sources=sources, + ) + + +def run( + *, + budget: Optional[float] = None, + max_tool_calls: Optional[int] = None, + max_latency_ms: Optional[float] = None, + max_energy: Optional[float] = None, + kpi_targets: Optional[dict[str, float]] = None, + kpi_weights: Optional[dict[str, float]] = None, + compliance: Optional[str] = None, +) -> HarnessRunContext: + """ + Create a scoped run context. + + Scope-level values override global init defaults for the scope only. 
+ """ + + config = get_harness_config() + resolved_budget = budget if budget is not None else config.budget + resolved_tool_calls = max_tool_calls if max_tool_calls is not None else config.max_tool_calls + resolved_latency = max_latency_ms if max_latency_ms is not None else config.max_latency_ms + resolved_energy = max_energy if max_energy is not None else config.max_energy + resolved_kpi_targets = kpi_targets if kpi_targets is not None else config.kpi_targets + resolved_kpi_weights = kpi_weights if kpi_weights is not None else config.kpi_weights + resolved_compliance = compliance if compliance is not None else config.compliance + + return HarnessRunContext( + mode=config.mode, + budget_max=resolved_budget, + tool_calls_max=resolved_tool_calls, + latency_max_ms=resolved_latency, + energy_max=resolved_energy, + kpi_targets=resolved_kpi_targets, + kpi_weights=resolved_kpi_weights, + compliance=resolved_compliance, + ) + + +F = TypeVar("F", bound=Callable[..., Any]) + + +def agent( + *, + budget: Optional[float] = None, + kpi_targets: Optional[dict[str, float]] = None, + kpi_weights: Optional[dict[str, float]] = None, + compliance: Optional[str] = None, +) -> Callable[[F], F]: + """ + Attach policy metadata to an agent function without changing behavior. 
+ """ + + metadata = { + "budget": budget, + "kpi_targets": kpi_targets, + "kpi_weights": kpi_weights, + "compliance": compliance, + } + + def decorator(func: F) -> F: + func.__cascadeflow_agent_policy__ = metadata # type: ignore[attr-defined] + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + return await func(*args, **kwargs) + + async_wrapper.__cascadeflow_agent_policy__ = metadata # type: ignore[attr-defined] + return cast(F, async_wrapper) + + @wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + return func(*args, **kwargs) + + sync_wrapper.__cascadeflow_agent_policy__ = metadata # type: ignore[attr-defined] + return cast(F, sync_wrapper) + + return decorator diff --git a/cascadeflow/harness/instrument.py b/cascadeflow/harness/instrument.py new file mode 100644 index 00000000..4b08b9f6 --- /dev/null +++ b/cascadeflow/harness/instrument.py @@ -0,0 +1,1297 @@ +"""Python SDK auto-instrumentation for cascadeflow harness. + +Patches OpenAI and Anthropic SDK request methods to intercept LLM calls for +observe/enforce modes. + +This module is called internally by ``cascadeflow.harness.init()``. Users +should not call patch/unpatch helpers directly. + +Implementation notes: + - Patching is class-level (all current and future client instances). + - Patching is idempotent (safe to call multiple times). + - ``unpatch_openai()`` restores the original methods exactly. + - Streaming responses are wrapped to capture usage after completion. + - ``with_raw_response`` is NOT patched in V2 (known limitation). 
+""" + +from __future__ import annotations + +import functools +import logging +import time +from dataclasses import dataclass +from typing import Any + +from cascadeflow.harness.pricing import ( + DEFAULT_ENERGY_COEFFICIENT as _DEFAULT_ENERGY_COEFFICIENT, +) +from cascadeflow.harness.pricing import ( + ENERGY_COEFFICIENTS as _ENERGY_COEFFICIENTS, +) +from cascadeflow.harness.pricing import ( + OPENAI_MODEL_POOL as _PRICING_MODELS, +) +from cascadeflow.harness.pricing import ( + estimate_cost as _estimate_cost_shared, +) +from cascadeflow.harness.pricing import ( + estimate_energy as _estimate_energy_shared, +) +from cascadeflow.harness.pricing import ( + model_total_price as _model_total_price_shared, +) + +logger = logging.getLogger("cascadeflow.harness.instrument") + +# --------------------------------------------------------------------------- +# Module-level state for idempotent patch/unpatch +# --------------------------------------------------------------------------- + +_openai_patched: bool = False +_original_sync_create: Any = None +_original_async_create: Any = None +_anthropic_patched: bool = False +_original_anthropic_sync_create: Any = None +_original_anthropic_async_create: Any = None + +_MODEL_TOTAL_COSTS: dict[str, float] = { + name: _model_total_price_shared(name) for name in _PRICING_MODELS +} +_CHEAPEST_MODEL: str = min(_MODEL_TOTAL_COSTS, key=_MODEL_TOTAL_COSTS.get) +_MIN_TOTAL_COST: float = min(_MODEL_TOTAL_COSTS.values()) +_MAX_TOTAL_COST: float = max(_MODEL_TOTAL_COSTS.values()) + +_OPENAI_ENERGY_COEFFS: dict[str, float] = { + name: _ENERGY_COEFFICIENTS.get(name, _DEFAULT_ENERGY_COEFFICIENT) for name in _PRICING_MODELS +} +_LOWEST_ENERGY_MODEL: str = min(_OPENAI_ENERGY_COEFFS, key=_OPENAI_ENERGY_COEFFS.get) +_MIN_ENERGY_COEFF: float = min(_OPENAI_ENERGY_COEFFS.values()) +_MAX_ENERGY_COEFF: float = max(_OPENAI_ENERGY_COEFFS.values()) + +# Relative priors used by KPI-weighted soft-control scoring. 
+# These are deterministic heuristics based on internal benchmark runs and +# intended as defaults until provider-specific online scoring is wired in. +_QUALITY_PRIORS: dict[str, float] = { + "gpt-4o": 0.90, + "gpt-4o-mini": 0.75, + "gpt-5-mini": 0.86, + "gpt-4-turbo": 0.88, + "gpt-4": 0.87, + "gpt-3.5-turbo": 0.65, + "o1": 0.95, + "o1-mini": 0.82, + "o3-mini": 0.80, +} +_LATENCY_PRIORS: dict[str, float] = { + "gpt-4o": 0.72, + "gpt-4o-mini": 0.93, + "gpt-5-mini": 0.84, + "gpt-4-turbo": 0.66, + "gpt-4": 0.52, + "gpt-3.5-turbo": 1.00, + "o1": 0.40, + "o1-mini": 0.60, + "o3-mini": 0.78, +} +_LATENCY_CANDIDATES: tuple[str, ...] = tuple( + name for name in _PRICING_MODELS if name in _LATENCY_PRIORS +) +_FASTEST_MODEL: str | None = ( + max(_LATENCY_CANDIDATES, key=lambda name: _LATENCY_PRIORS[name]) + if _LATENCY_CANDIDATES + else None +) + +# OpenAI-model allowlists used by the current OpenAI harness instrumentation. +# Future provider instrumentation should provide provider-specific allowlists. +_COMPLIANCE_MODEL_ALLOWLISTS: dict[str, set[str]] = { + "gdpr": {"gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"}, + "hipaa": {"gpt-4o", "gpt-4o-mini"}, + "pci": {"gpt-4o-mini", "gpt-3.5-turbo"}, + "strict": {"gpt-4o"}, +} + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _ensure_stream_usage(kwargs: dict[str, Any]) -> dict[str, Any]: + """Inject ``stream_options.include_usage=True`` for streaming requests. + + OpenAI only sends usage data in the final stream chunk when this option + is set. Without it the harness would record zero cost for every + streaming call. 
+ """ + if not kwargs.get("stream", False): + return kwargs + stream_options = kwargs.get("stream_options") or {} + if not stream_options.get("include_usage"): + stream_options = {**stream_options, "include_usage": True} + kwargs = {**kwargs, "stream_options": stream_options} + return kwargs + + +def _estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float: + """Estimate cost in USD from model name and token counts.""" + return _estimate_cost_shared(model, prompt_tokens, completion_tokens) + + +def _estimate_energy(model: str, prompt_tokens: int, completion_tokens: int) -> float: + """Estimate energy units (deterministic proxy, not live carbon).""" + return _estimate_energy_shared(model, prompt_tokens, completion_tokens) + + +def _count_tool_calls_in_openai_response(response: Any) -> int: + """Count tool calls in a non-streaming ChatCompletion response.""" + choices = getattr(response, "choices", None) + if not choices: + return 0 + message = getattr(choices[0], "message", None) + if message is None: + return 0 + tool_calls = getattr(message, "tool_calls", None) + if tool_calls is None: + return 0 + return len(tool_calls) + + +def _extract_openai_usage(response: Any) -> tuple[int, int]: + """Extract (prompt_tokens, completion_tokens) from a response.""" + usage = getattr(response, "usage", None) + if usage is None: + return 0, 0 + return ( + getattr(usage, "prompt_tokens", 0) or 0, + getattr(usage, "completion_tokens", 0) or 0, + ) + + +def _extract_anthropic_usage(response: Any) -> tuple[int, int]: + """Extract (input_tokens, output_tokens) from an Anthropic response.""" + usage = getattr(response, "usage", None) + if usage is None: + return 0, 0 + return ( + getattr(usage, "input_tokens", 0) or 0, + getattr(usage, "output_tokens", 0) or 0, + ) + + +def _count_tool_calls_in_anthropic_response(response: Any) -> int: + """Count Anthropic ``tool_use`` blocks in a non-streaming response.""" + content = getattr(response, "content", None) + if not 
content: + return 0 + count = 0 + for block in content: + if getattr(block, "type", None) == "tool_use": + count += 1 + return count + + +def _model_total_cost(model: str) -> float: + return _MODEL_TOTAL_COSTS.get(model, _model_total_price_shared(model)) + + +def _select_cheaper_model(current_model: str) -> str: + if _model_total_cost(_CHEAPEST_MODEL) < _model_total_cost(current_model): + return _CHEAPEST_MODEL + return current_model + + +def _select_faster_model(current_model: str) -> str: + if _FASTEST_MODEL is None: + return current_model + current_latency = _LATENCY_PRIORS.get(current_model, 0.7) + if _LATENCY_PRIORS[_FASTEST_MODEL] > current_latency: + return _FASTEST_MODEL + return current_model + + +def _select_lower_energy_model(current_model: str) -> str: + if _ENERGY_COEFFICIENTS.get( + _LOWEST_ENERGY_MODEL, _DEFAULT_ENERGY_COEFFICIENT + ) < _ENERGY_COEFFICIENTS.get( + current_model, + _DEFAULT_ENERGY_COEFFICIENT, + ): + return _LOWEST_ENERGY_MODEL + return current_model + + +def _normalize_weights(weights: dict[str, float]) -> dict[str, float]: + normalized = { + key: float(value) + for key, value in weights.items() + if key in {"cost", "quality", "latency", "energy"} and float(value) > 0 + } + total = sum(normalized.values()) + if total <= 0: + return {} + return {key: value / total for key, value in normalized.items()} + + +def _cost_utility(model: str) -> float: + model_cost = _model_total_cost(model) + if _MAX_TOTAL_COST == _MIN_TOTAL_COST: + return 1.0 + return (_MAX_TOTAL_COST - model_cost) / (_MAX_TOTAL_COST - _MIN_TOTAL_COST) + + +def _energy_utility(model: str) -> float: + coeff = _ENERGY_COEFFICIENTS.get(model, _DEFAULT_ENERGY_COEFFICIENT) + if _MAX_ENERGY_COEFF == _MIN_ENERGY_COEFF: + return 1.0 + return (_MAX_ENERGY_COEFF - coeff) / (_MAX_ENERGY_COEFF - _MIN_ENERGY_COEFF) + + +def _kpi_score_with_normalized(model: str, normalized: dict[str, float]) -> float: + if not normalized: + return 0.0 + quality = _QUALITY_PRIORS.get(model, 0.7) + 
latency = _LATENCY_PRIORS.get(model, 0.7) + cost = _cost_utility(model) + energy = _energy_utility(model) + return ( + (normalized.get("quality", 0.0) * quality) + + (normalized.get("latency", 0.0) * latency) + + (normalized.get("cost", 0.0) * cost) + + (normalized.get("energy", 0.0) * energy) + ) + + +def _kpi_score(model: str, weights: dict[str, float]) -> float: + normalized = _normalize_weights(weights) + return _kpi_score_with_normalized(model, normalized) + + +def _select_kpi_weighted_model(current_model: str, weights: dict[str, float]) -> str: + normalized = _normalize_weights(weights) + if not normalized: + return current_model + best_model = current_model + best_score = _kpi_score_with_normalized(current_model, normalized) + for candidate in _PRICING_MODELS: + score = _kpi_score_with_normalized(candidate, normalized) + if score > best_score: + best_model = candidate + best_score = score + return best_model + + +def _compliance_allowlist(compliance: str | None) -> set[str] | None: + if not compliance: + return None + return _COMPLIANCE_MODEL_ALLOWLISTS.get(compliance.strip().lower()) + + +def _select_compliant_model(current_model: str, compliance: str) -> str | None: + allowlist = _compliance_allowlist(compliance) + if not allowlist: + return current_model + if current_model in allowlist: + return current_model + available = [name for name in _PRICING_MODELS if name in allowlist] + if not available: + return None + return min(available, key=_model_total_cost) + + +@dataclass(frozen=True) +class _PreCallDecision: + action: str + reason: str + target_model: str + + +def _evaluate_pre_call_decision(ctx: Any, model: str, has_tools: bool) -> _PreCallDecision: + if ctx.budget_max is not None and ctx.cost >= ctx.budget_max: + return _PreCallDecision(action="stop", reason="budget_exceeded", target_model=model) + + if has_tools and ctx.tool_calls_max is not None and ctx.tool_calls >= ctx.tool_calls_max: + return _PreCallDecision( + action="deny_tool", 
reason="max_tool_calls_reached", target_model=model + ) + + compliance = getattr(ctx, "compliance", None) + if compliance: + compliant_model = _select_compliant_model(model, str(compliance)) + if compliant_model is None: + if has_tools: + return _PreCallDecision( + action="deny_tool", + reason="compliance_no_approved_tool_path", + target_model=model, + ) + return _PreCallDecision( + action="stop", reason="compliance_no_approved_model", target_model=model + ) + if compliant_model != model: + return _PreCallDecision( + action="switch_model", + reason="compliance_model_policy", + target_model=compliant_model, + ) + if str(compliance).strip().lower() == "strict" and has_tools: + return _PreCallDecision( + action="deny_tool", + reason="compliance_tool_restriction", + target_model=model, + ) + + if ctx.latency_max_ms is not None and ctx.latency_used_ms >= ctx.latency_max_ms: + faster_model = _select_faster_model(model) + if faster_model != model: + return _PreCallDecision( + action="switch_model", + reason="latency_limit_exceeded", + target_model=faster_model, + ) + return _PreCallDecision(action="stop", reason="latency_limit_exceeded", target_model=model) + + if ctx.energy_max is not None and ctx.energy_used >= ctx.energy_max: + lower_energy_model = _select_lower_energy_model(model) + if lower_energy_model != model: + return _PreCallDecision( + action="switch_model", + reason="energy_limit_exceeded", + target_model=lower_energy_model, + ) + return _PreCallDecision(action="stop", reason="energy_limit_exceeded", target_model=model) + + if ( + ctx.budget_max is not None + and ctx.budget_max > 0 + and ctx.budget_remaining is not None + and (ctx.budget_remaining / ctx.budget_max) < 0.2 + ): + cheaper_model = _select_cheaper_model(model) + if cheaper_model != model: + return _PreCallDecision( + action="switch_model", + reason="budget_pressure", + target_model=cheaper_model, + ) + + kpi_weights = getattr(ctx, "kpi_weights", None) + if isinstance(kpi_weights, dict) and 
kpi_weights: + weighted_model = _select_kpi_weighted_model(model, kpi_weights) + if weighted_model != model: + return _PreCallDecision( + action="switch_model", + reason="kpi_weight_optimization", + target_model=weighted_model, + ) + + return _PreCallDecision(action="allow", reason=ctx.mode, target_model=model) + + +def _raise_stop_error(ctx: Any, reason: str) -> None: + from cascadeflow.schema.exceptions import BudgetExceededError, HarnessStopError + + if reason == "budget_exceeded": + remaining = 0.0 + if ctx.budget_max is not None: + remaining = ctx.budget_max - ctx.cost + raise BudgetExceededError( + f"Budget exhausted: spent ${ctx.cost:.4f} of ${ctx.budget_max or 0.0:.4f} max", + remaining=remaining, + ) + raise HarnessStopError(f"cascadeflow harness stop: {reason}", reason=reason) + + +def _resolve_pre_call_decision( + ctx: Any, + mode: str, + model: str, + kwargs: dict[str, Any], +) -> tuple[dict[str, Any], str, str, str, str, bool]: + decision = _evaluate_pre_call_decision(ctx, model, has_tools=bool(kwargs.get("tools"))) + action = decision.action + reason = decision.reason + target_model = decision.target_model + applied = action == "allow" + + if mode == "enforce": + if action == "stop": + ctx.record( + action="stop", + reason=reason, + model=model, + applied=True, + decision_mode=mode, + ) + _raise_stop_error(ctx, reason) + + if action == "switch_model" and target_model != model: + kwargs = {**kwargs, "model": target_model} + model = target_model + applied = True + elif action == "switch_model": + applied = False + + if action == "deny_tool": + if kwargs.get("tools"): + kwargs = {**kwargs, "tools": []} + applied = True + else: + applied = False + elif action != "allow": + logger.debug( + "harness observe decision: action=%s reason=%s model=%s target=%s", + action, + reason, + model, + target_model, + ) + applied = False + + return kwargs, model, action, reason, target_model, applied + + +def _update_context( + ctx: Any, + model: str, + prompt_tokens: 
int, + completion_tokens: int, + tool_call_count: int, + elapsed_ms: float, + *, + action: str = "allow", + action_reason: str | None = None, + action_model: str | None = None, + applied: bool | None = None, + decision_mode: str | None = None, +) -> None: + """Update a HarnessRunContext with call metrics.""" + cost = _estimate_cost(model, prompt_tokens, completion_tokens) + energy = _estimate_energy(model, prompt_tokens, completion_tokens) + + ctx.cost += cost + ctx.step_count += 1 + ctx.latency_used_ms += elapsed_ms + ctx.energy_used += energy + ctx.tool_calls += tool_call_count + + if ctx.budget_max is not None: + ctx.budget_remaining = ctx.budget_max - ctx.cost + + if applied is None: + applied = action == "allow" + if decision_mode is None: + decision_mode = ctx.mode + + if action == "allow": + ctx.record( + action="allow", + reason=ctx.mode, + model=model, + applied=applied, + decision_mode=decision_mode, + ) + return + + ctx.record( + action=action, + reason=action_reason or ctx.mode, + model=action_model or model, + applied=applied, + decision_mode=decision_mode, + ) + + +# --------------------------------------------------------------------------- +# Stream wrappers +# --------------------------------------------------------------------------- + + +class _InstrumentedStreamBase: + """Shared stream-wrapper logic for sync and async OpenAI streams.""" + + __slots__ = ( + "_stream", + "_ctx", + "_model", + "_start_time", + "_pre_action", + "_pre_reason", + "_pre_model", + "_pre_applied", + "_decision_mode", + "_usage", + "_tool_call_count", + "_finalized", + ) + + def __init__( + self, + stream: Any, + ctx: Any, + model: str, + start_time: float, + pre_action: str = "allow", + pre_reason: str = "observe", + pre_model: str | None = None, + pre_applied: bool = True, + decision_mode: str = "observe", + ) -> None: + self._stream = stream + self._ctx = ctx + self._model = model + self._start_time = start_time + self._pre_action = pre_action + self._pre_reason = 
pre_reason + self._pre_model = pre_model or model + self._pre_applied = pre_applied + self._decision_mode = decision_mode + self._usage: Any = None + self._tool_call_count: int = 0 + self._finalized: bool = False + + def close(self) -> None: + self._finalize() + if hasattr(self._stream, "close"): + self._stream.close() + + @property + def response(self) -> Any: + return getattr(self._stream, "response", None) + + def _inspect_chunk(self, chunk: Any) -> None: + usage = getattr(chunk, "usage", None) + if usage is not None: + self._usage = usage + + choices = getattr(chunk, "choices", []) + if choices: + delta = getattr(choices[0], "delta", None) + if delta: + tool_calls = getattr(delta, "tool_calls", None) + if tool_calls: + for tc in tool_calls: + # A new tool call has an ``id``; subsequent deltas for + # the same call only have ``index``. + if getattr(tc, "id", None): + self._tool_call_count += 1 + + def _finalize(self) -> None: + if self._finalized: + return + self._finalized = True + + if self._ctx is None: + return + + elapsed_ms = (time.monotonic() - self._start_time) * 1000 + prompt_tokens = 0 + completion_tokens = 0 + if self._usage: + prompt_tokens = getattr(self._usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(self._usage, "completion_tokens", 0) or 0 + + _update_context( + self._ctx, + self._model, + prompt_tokens, + completion_tokens, + self._tool_call_count, + elapsed_ms, + action=self._pre_action, + action_reason=self._pre_reason, + action_model=self._pre_model, + applied=self._pre_applied, + decision_mode=self._decision_mode, + ) + + +class _InstrumentedStream(_InstrumentedStreamBase): + """Wraps an OpenAI sync ``Stream`` and tracks usage at stream end.""" + + __slots__ = () + + def __iter__(self) -> _InstrumentedStream: + return self + + def __next__(self) -> Any: + try: + chunk = next(self._stream) + self._inspect_chunk(chunk) + return chunk + except StopIteration: + self._finalize() + raise + except Exception: + self._finalize() + 
raise + + def __enter__(self) -> _InstrumentedStream: + if hasattr(self._stream, "__enter__"): + self._stream.__enter__() + return self + + def __exit__(self, *args: Any) -> bool: + self._finalize() + if hasattr(self._stream, "__exit__"): + return self._stream.__exit__(*args) # type: ignore[no-any-return] + return False + + +class _InstrumentedAsyncStream(_InstrumentedStreamBase): + """Wraps an OpenAI async ``AsyncStream`` and tracks usage at stream end.""" + + __slots__ = () + + def __aiter__(self) -> _InstrumentedAsyncStream: + return self + + async def __anext__(self) -> Any: + try: + chunk = await self._stream.__anext__() + self._inspect_chunk(chunk) + return chunk + except StopAsyncIteration: + self._finalize() + raise + except Exception: + self._finalize() + raise + + async def __aenter__(self) -> _InstrumentedAsyncStream: + if hasattr(self._stream, "__aenter__"): + await self._stream.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> bool: + self._finalize() + if hasattr(self._stream, "__aexit__"): + return await self._stream.__aexit__(*args) # type: ignore[no-any-return] + return False + + +class _InstrumentedAnthropicStreamBase: + """Shared stream-wrapper logic for sync and async Anthropic streams.""" + + __slots__ = ( + "_stream", + "_ctx", + "_model", + "_start_time", + "_pre_action", + "_pre_reason", + "_pre_model", + "_pre_applied", + "_decision_mode", + "_input_tokens", + "_output_tokens", + "_tool_call_count", + "_finalized", + ) + + def __init__( + self, + stream: Any, + ctx: Any, + model: str, + start_time: float, + pre_action: str = "allow", + pre_reason: str = "observe", + pre_model: str | None = None, + pre_applied: bool = True, + decision_mode: str = "observe", + ) -> None: + self._stream = stream + self._ctx = ctx + self._model = model + self._start_time = start_time + self._pre_action = pre_action + self._pre_reason = pre_reason + self._pre_model = pre_model or model + self._pre_applied = pre_applied + self._decision_mode 
= decision_mode + self._input_tokens: int = 0 + self._output_tokens: int = 0 + self._tool_call_count: int = 0 + self._finalized: bool = False + + def close(self) -> None: + self._finalize() + if hasattr(self._stream, "close"): + self._stream.close() + + def _inspect_event(self, event: Any) -> None: + event_type = getattr(event, "type", None) + + if event_type == "message_start": + message = getattr(event, "message", None) + usage = getattr(message, "usage", None) + if usage is not None: + input_tokens = getattr(usage, "input_tokens", None) + output_tokens = getattr(usage, "output_tokens", None) + if isinstance(input_tokens, (int, float)): + self._input_tokens = int(input_tokens) if input_tokens > 0 else 0 + if isinstance(output_tokens, (int, float)): + self._output_tokens = int(output_tokens) if output_tokens > 0 else 0 + return + + usage = getattr(event, "usage", None) + if usage is not None: + input_tokens = getattr(usage, "input_tokens", None) + output_tokens = getattr(usage, "output_tokens", None) + if isinstance(input_tokens, (int, float)) and input_tokens > 0: + self._input_tokens = int(input_tokens) + if isinstance(output_tokens, (int, float)): + self._output_tokens = int(output_tokens) if output_tokens > 0 else 0 + + if event_type == "content_block_start": + content_block = getattr(event, "content_block", None) + block_type = getattr(content_block, "type", None) + if block_type in {"tool_use", "server_tool_use"}: + self._tool_call_count += 1 + + def _finalize(self) -> None: + if self._finalized: + return + self._finalized = True + + if self._ctx is None: + return + + elapsed_ms = (time.monotonic() - self._start_time) * 1000 + _update_context( + self._ctx, + self._model, + self._input_tokens, + self._output_tokens, + self._tool_call_count, + elapsed_ms, + action=self._pre_action, + action_reason=self._pre_reason, + action_model=self._pre_model, + applied=self._pre_applied, + decision_mode=self._decision_mode, + ) + + +class 
_InstrumentedAnthropicStream(_InstrumentedAnthropicStreamBase): + """Wraps an Anthropic sync stream and tracks usage at stream end.""" + + __slots__ = () + + def __iter__(self) -> _InstrumentedAnthropicStream: + return self + + def __next__(self) -> Any: + try: + event = next(self._stream) + self._inspect_event(event) + return event + except StopIteration: + self._finalize() + raise + except Exception: + self._finalize() + raise + + def __enter__(self) -> _InstrumentedAnthropicStream: + if hasattr(self._stream, "__enter__"): + self._stream.__enter__() + return self + + def __exit__(self, *args: Any) -> bool: + self._finalize() + if hasattr(self._stream, "__exit__"): + return self._stream.__exit__(*args) # type: ignore[no-any-return] + return False + + +class _InstrumentedAnthropicAsyncStream(_InstrumentedAnthropicStreamBase): + """Wraps an Anthropic async stream and tracks usage at stream end.""" + + __slots__ = () + + def __aiter__(self) -> _InstrumentedAnthropicAsyncStream: + return self + + async def __anext__(self) -> Any: + try: + event = await self._stream.__anext__() + self._inspect_event(event) + return event + except StopAsyncIteration: + self._finalize() + raise + except Exception: + self._finalize() + raise + + async def __aenter__(self) -> _InstrumentedAnthropicAsyncStream: + if hasattr(self._stream, "__aenter__"): + await self._stream.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> bool: + self._finalize() + if hasattr(self._stream, "__aexit__"): + return await self._stream.__aexit__(*args) # type: ignore[no-any-return] + return False + + +# --------------------------------------------------------------------------- +# Wrapper factories +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _CallInterceptionState: + kwargs: dict[str, Any] + model: str + pre_action: str + pre_reason: str + pre_model: str + pre_applied: bool + is_stream: bool + start_time: float + + +def 
_prepare_call_interception( + *, + ctx: Any, + mode: str, + kwargs: dict[str, Any], +) -> _CallInterceptionState: + model: str = kwargs.get("model", "unknown") + pre_action = "allow" + pre_reason = mode + pre_model = model + pre_applied = True + + if ctx: + kwargs, model, pre_action, pre_reason, pre_model, pre_applied = _resolve_pre_call_decision( + ctx, + mode, + model, + kwargs, + ) + + is_stream: bool = bool(kwargs.get("stream", False)) + kwargs = _ensure_stream_usage(kwargs) + + return _CallInterceptionState( + kwargs=kwargs, + model=model, + pre_action=pre_action, + pre_reason=pre_reason, + pre_model=pre_model, + pre_applied=pre_applied, + is_stream=is_stream, + start_time=time.monotonic(), + ) + + +def _finalize_interception( + *, + ctx: Any, + mode: str, + state: _CallInterceptionState, + response: Any, + stream_wrapper: type[_InstrumentedStream] | type[_InstrumentedAsyncStream], +) -> Any: + if state.is_stream and ctx: + return stream_wrapper( + response, + ctx, + state.model, + state.start_time, + state.pre_action, + state.pre_reason, + state.pre_model, + state.pre_applied, + mode, + ) + + if (not state.is_stream) and ctx: + elapsed_ms = (time.monotonic() - state.start_time) * 1000 + prompt_tokens, completion_tokens = _extract_openai_usage(response) + tool_call_count = _count_tool_calls_in_openai_response(response) + _update_context( + ctx, + state.model, + prompt_tokens, + completion_tokens, + tool_call_count, + elapsed_ms, + action=state.pre_action, + action_reason=state.pre_reason, + action_model=state.pre_model, + applied=state.pre_applied, + decision_mode=mode, + ) + else: + logger.debug( + "harness %s: model=%s (no active run scope, metrics not tracked)", + mode, + state.model, + ) + + return response + + +def _make_patched_create(original_fn: Any) -> Any: + """Create a patched version of ``Completions.create``.""" + + @functools.wraps(original_fn) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + from cascadeflow.harness.api import 
get_current_run, get_harness_config + + config = get_harness_config() + ctx = get_current_run() + mode = ctx.mode if ctx else config.mode + + if mode == "off": + return original_fn(self, *args, **kwargs) + + state = _prepare_call_interception(ctx=ctx, mode=mode, kwargs=kwargs) + + logger.debug( + "harness intercept: model=%s stream=%s mode=%s", + state.model, + state.is_stream, + mode, + ) + + response = original_fn(self, *args, **state.kwargs) + + return _finalize_interception( + ctx=ctx, + mode=mode, + state=state, + response=response, + stream_wrapper=_InstrumentedStream, + ) + + return wrapper + + +def _make_patched_async_create(original_fn: Any) -> Any: + """Create a patched version of ``AsyncCompletions.create``.""" + + @functools.wraps(original_fn) + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + from cascadeflow.harness.api import get_current_run, get_harness_config + + config = get_harness_config() + ctx = get_current_run() + mode = ctx.mode if ctx else config.mode + + if mode == "off": + return await original_fn(self, *args, **kwargs) + + state = _prepare_call_interception(ctx=ctx, mode=mode, kwargs=kwargs) + + logger.debug( + "harness intercept async: model=%s stream=%s mode=%s", + state.model, + state.is_stream, + mode, + ) + + response = await original_fn(self, *args, **state.kwargs) + + return _finalize_interception( + ctx=ctx, + mode=mode, + state=state, + response=response, + stream_wrapper=_InstrumentedAsyncStream, + ) + + return wrapper + + +def _make_patched_anthropic_create(original_fn: Any) -> Any: + """Create a patched version of ``anthropic.Messages.create``.""" + + @functools.wraps(original_fn) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + from cascadeflow.harness.api import get_current_run, get_harness_config + + config = get_harness_config() + ctx = get_current_run() + mode = ctx.mode if ctx else config.mode + + if mode == "off": + return original_fn(self, *args, **kwargs) + + model: str = 
kwargs.get("model", "unknown") + pre_action = "allow" + pre_reason = mode + pre_model = model + pre_applied = True + + if ctx: + kwargs, model, pre_action, pre_reason, pre_model, pre_applied = ( + _resolve_pre_call_decision( + ctx, + mode, + model, + kwargs, + ) + ) + + is_stream = bool(kwargs.get("stream", False)) + start_time = time.monotonic() + response = original_fn(self, *args, **kwargs) + + if not ctx: + logger.debug( + "harness %s (anthropic): model=%s (no active run scope, metrics not tracked)", + mode, + model, + ) + return response + + if is_stream: + return _InstrumentedAnthropicStream( + response, + ctx, + model, + start_time, + pre_action, + pre_reason, + pre_model, + pre_applied, + mode, + ) + + elapsed_ms = (time.monotonic() - start_time) * 1000 + input_tokens, output_tokens = _extract_anthropic_usage(response) + tool_call_count = _count_tool_calls_in_anthropic_response(response) + _update_context( + ctx, + model, + input_tokens, + output_tokens, + tool_call_count, + elapsed_ms, + action=pre_action, + action_reason=pre_reason, + action_model=pre_model, + applied=pre_applied, + decision_mode=mode, + ) + return response + + return wrapper + + +def _make_patched_anthropic_async_create(original_fn: Any) -> Any: + """Create a patched version of ``anthropic.AsyncMessages.create``.""" + + @functools.wraps(original_fn) + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + from cascadeflow.harness.api import get_current_run, get_harness_config + + config = get_harness_config() + ctx = get_current_run() + mode = ctx.mode if ctx else config.mode + + if mode == "off": + return await original_fn(self, *args, **kwargs) + + model: str = kwargs.get("model", "unknown") + pre_action = "allow" + pre_reason = mode + pre_model = model + pre_applied = True + + if ctx: + kwargs, model, pre_action, pre_reason, pre_model, pre_applied = ( + _resolve_pre_call_decision( + ctx, + mode, + model, + kwargs, + ) + ) + + is_stream = bool(kwargs.get("stream", False)) + 
start_time = time.monotonic() + response = await original_fn(self, *args, **kwargs) + + if not ctx: + logger.debug( + "harness %s async (anthropic): model=%s (no active run scope, metrics not tracked)", + mode, + model, + ) + return response + + if is_stream: + return _InstrumentedAnthropicAsyncStream( + response, + ctx, + model, + start_time, + pre_action, + pre_reason, + pre_model, + pre_applied, + mode, + ) + + elapsed_ms = (time.monotonic() - start_time) * 1000 + input_tokens, output_tokens = _extract_anthropic_usage(response) + tool_call_count = _count_tool_calls_in_anthropic_response(response) + _update_context( + ctx, + model, + input_tokens, + output_tokens, + tool_call_count, + elapsed_ms, + action=pre_action, + action_reason=pre_reason, + action_model=pre_model, + applied=pre_applied, + decision_mode=mode, + ) + return response + + return wrapper + + +# --------------------------------------------------------------------------- +# Public API (called by cascadeflow.harness.api) +# --------------------------------------------------------------------------- + + +def patch_openai() -> bool: + """Patch the OpenAI Python client for harness instrumentation. + + Returns ``True`` if patching succeeded, ``False`` if openai is not + installed. Idempotent: safe to call multiple times. 
+ """ + global _openai_patched, _original_sync_create, _original_async_create + + if _openai_patched: + logger.debug("openai already patched, skipping") + return True + + try: + from openai.resources.chat.completions import AsyncCompletions, Completions + except ImportError: + logger.debug("openai package not available, skipping instrumentation") + return False + + _original_sync_create = Completions.create + _original_async_create = AsyncCompletions.create + + Completions.create = _make_patched_create(_original_sync_create) # type: ignore[assignment] + AsyncCompletions.create = _make_patched_async_create( # type: ignore[assignment] + _original_async_create, + ) + + _openai_patched = True + logger.info("openai client instrumented (sync + async)") + return True + + +def patch_anthropic() -> bool: + """Patch the Anthropic Python client for harness instrumentation. + + Returns ``True`` if patching succeeded, ``False`` if anthropic is not + installed. Idempotent: safe to call multiple times. + """ + global _anthropic_patched, _original_anthropic_sync_create, _original_anthropic_async_create + + if _anthropic_patched: + logger.debug("anthropic already patched, skipping") + return True + + try: + from anthropic.resources.messages import AsyncMessages, Messages + except ImportError: + logger.debug("anthropic package not available, skipping instrumentation") + return False + + _original_anthropic_sync_create = Messages.create + _original_anthropic_async_create = AsyncMessages.create + + Messages.create = _make_patched_anthropic_create(_original_anthropic_sync_create) # type: ignore[assignment] + AsyncMessages.create = _make_patched_anthropic_async_create( # type: ignore[assignment] + _original_anthropic_async_create, + ) + + _anthropic_patched = True + logger.info("anthropic client instrumented (sync + async)") + return True + + +def unpatch_openai() -> None: + """Restore original OpenAI client methods. + + Safe to call even if not patched. Used by ``reset()`` and tests. 
+ """ + global _openai_patched, _original_sync_create, _original_async_create + + if not _openai_patched: + return + + try: + from openai.resources.chat.completions import AsyncCompletions, Completions + except ImportError: + _openai_patched = False + return + + if _original_sync_create is not None: + Completions.create = _original_sync_create # type: ignore[assignment] + if _original_async_create is not None: + AsyncCompletions.create = _original_async_create # type: ignore[assignment] + + _original_sync_create = None + _original_async_create = None + _openai_patched = False + logger.info("openai client unpatched") + + +def unpatch_anthropic() -> None: + """Restore original Anthropic client methods. + + Safe to call even if not patched. Used by ``reset()`` and tests. + """ + global _anthropic_patched, _original_anthropic_sync_create, _original_anthropic_async_create + + if not _anthropic_patched: + return + + try: + from anthropic.resources.messages import AsyncMessages, Messages + except ImportError: + _anthropic_patched = False + return + + if _original_anthropic_sync_create is not None: + Messages.create = _original_anthropic_sync_create # type: ignore[assignment] + if _original_anthropic_async_create is not None: + AsyncMessages.create = _original_anthropic_async_create # type: ignore[assignment] + + _original_anthropic_sync_create = None + _original_anthropic_async_create = None + _anthropic_patched = False + logger.info("anthropic client unpatched") + + +def is_openai_patched() -> bool: + """Return whether the OpenAI client is currently patched.""" + return _openai_patched + + +def is_anthropic_patched() -> bool: + """Return whether the Anthropic client is currently patched.""" + return _anthropic_patched + + +def is_patched() -> bool: + """Return whether any supported Python SDK is currently patched.""" + return _openai_patched or _anthropic_patched diff --git a/cascadeflow/harness/pricing.py b/cascadeflow/harness/pricing.py new file mode 100644 index 
00000000..bd86323e --- /dev/null +++ b/cascadeflow/harness/pricing.py @@ -0,0 +1,78 @@ +"""Shared harness pricing and energy profiles. + +This module centralizes model-cost and energy-estimation defaults used by +harness integrations (OpenAI auto-instrumentation, OpenAI Agents SDK, CrewAI). +""" + +from __future__ import annotations + +from typing import Final + +# USD per 1M tokens (input, output). +PRICING_USD_PER_M: Final[dict[str, tuple[float, float]]] = { + # OpenAI + "gpt-4o": (2.50, 10.00), + "gpt-4o-mini": (0.15, 0.60), + "gpt-5": (1.25, 10.00), + "gpt-5-mini": (0.20, 0.80), + "gpt-4-turbo": (10.00, 30.00), + "gpt-4": (30.00, 60.00), + "gpt-3.5-turbo": (0.50, 1.50), + "o1": (15.00, 60.00), + "o1-mini": (3.00, 12.00), + "o3-mini": (1.10, 4.40), + # Anthropic aliases used by CrewAI model names. + "claude-sonnet-4": (3.00, 15.00), + "claude-haiku-3.5": (1.00, 5.00), + "claude-opus-4.5": (5.00, 25.00), +} +DEFAULT_PRICING_USD_PER_M: Final[tuple[float, float]] = (2.50, 10.00) + +# Deterministic proxy coefficients for energy tracking. +ENERGY_COEFFICIENTS: Final[dict[str, float]] = { + "gpt-4o": 1.0, + "gpt-4o-mini": 0.3, + "gpt-5": 1.2, + "gpt-5-mini": 0.35, + "gpt-4-turbo": 1.5, + "gpt-4": 1.5, + "gpt-3.5-turbo": 0.2, + "o1": 2.0, + "o1-mini": 0.8, + "o3-mini": 0.5, +} +DEFAULT_ENERGY_COEFFICIENT: Final[float] = 1.0 +ENERGY_OUTPUT_WEIGHT: Final[float] = 1.5 + +# Explicit pools keep provider/model-switching logic constrained even though the +# pricing table is shared across integrations. 
+OPENAI_MODEL_POOL: Final[tuple[str, ...]] = ( + "gpt-4o", + "gpt-4o-mini", + "gpt-5", + "gpt-5-mini", + "gpt-4-turbo", + "gpt-4", + "gpt-3.5-turbo", + "o1", + "o1-mini", + "o3-mini", +) + + +def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float: + """Estimate USD cost from token usage.""" + in_price, out_price = PRICING_USD_PER_M.get(model, DEFAULT_PRICING_USD_PER_M) + return (input_tokens / 1_000_000.0) * in_price + (output_tokens / 1_000_000.0) * out_price + + +def estimate_energy(model: str, input_tokens: int, output_tokens: int) -> float: + """Estimate deterministic proxy energy units.""" + coefficient = ENERGY_COEFFICIENTS.get(model, DEFAULT_ENERGY_COEFFICIENT) + return coefficient * (input_tokens + (output_tokens * ENERGY_OUTPUT_WEIGHT)) + + +def model_total_price(model: str) -> float: + """Return total (input + output) price per 1M tokens.""" + in_price, out_price = PRICING_USD_PER_M.get(model, DEFAULT_PRICING_USD_PER_M) + return in_price + out_price diff --git a/cascadeflow/integrations/__init__.py b/cascadeflow/integrations/__init__.py index 9a0dfa4d..33552773 100644 --- a/cascadeflow/integrations/__init__.py +++ b/cascadeflow/integrations/__init__.py @@ -90,6 +90,25 @@ extract_token_usage = None MODEL_PRICING = None +# Try to import OpenAI Agents SDK integration +try: + from .openai_agents import ( + OPENAI_AGENTS_SDK_AVAILABLE, + CascadeFlowModelProvider, + OpenAIAgentsIntegrationConfig, + create_openai_agents_provider, + is_openai_agents_sdk_available, + ) + + OPENAI_AGENTS_AVAILABLE = OPENAI_AGENTS_SDK_AVAILABLE +except ImportError: + OPENAI_AGENTS_AVAILABLE = False + OPENAI_AGENTS_SDK_AVAILABLE = False + CascadeFlowModelProvider = None + OpenAIAgentsIntegrationConfig = None + create_openai_agents_provider = None + is_openai_agents_sdk_available = None + # OpenClaw integration helpers (no external deps) try: from .openclaw import ( @@ -146,6 +165,26 @@ PaygenticUsageReporter = None PaygenticProxyService = None +# Try to import 
CrewAI integration +try: + from .crewai import ( + CREWAI_AVAILABLE, + CrewAIHarnessConfig, + enable as crewai_enable, + disable as crewai_disable, + is_available as crewai_is_available, + is_enabled as crewai_is_enabled, + get_config as crewai_get_config, + ) +except ImportError: + CREWAI_AVAILABLE = False + CrewAIHarnessConfig = None + crewai_enable = None + crewai_disable = None + crewai_is_available = None + crewai_is_enabled = None + crewai_get_config = None + __all__ = [] if LITELLM_AVAILABLE: @@ -209,6 +248,17 @@ ] ) +if OPENAI_AGENTS_AVAILABLE: + __all__.extend( + [ + "OPENAI_AGENTS_SDK_AVAILABLE", + "CascadeFlowModelProvider", + "OpenAIAgentsIntegrationConfig", + "create_openai_agents_provider", + "is_openai_agents_sdk_available", + ] + ) + if PAYGENTIC_AVAILABLE: __all__.extend( [ @@ -222,13 +272,28 @@ ] ) +if CREWAI_AVAILABLE: + __all__.extend( + [ + "CREWAI_AVAILABLE", + "CrewAIHarnessConfig", + "crewai_enable", + "crewai_disable", + "crewai_is_available", + "crewai_is_enabled", + "crewai_get_config", + ] + ) + # Integration capabilities INTEGRATION_CAPABILITIES = { "litellm": LITELLM_AVAILABLE, "opentelemetry": OPENTELEMETRY_AVAILABLE, "langchain": LANGCHAIN_AVAILABLE, + "openai_agents": OPENAI_AGENTS_AVAILABLE, "openclaw": OPENCLAW_AVAILABLE, "paygentic": PAYGENTIC_AVAILABLE, + "crewai": CREWAI_AVAILABLE, } @@ -250,6 +315,8 @@ def get_integration_info(): "litellm_available": LITELLM_AVAILABLE, "opentelemetry_available": OPENTELEMETRY_AVAILABLE, "langchain_available": LANGCHAIN_AVAILABLE, + "openai_agents_available": OPENAI_AGENTS_AVAILABLE, "openclaw_available": OPENCLAW_AVAILABLE, "paygentic_available": PAYGENTIC_AVAILABLE, + "crewai_available": CREWAI_AVAILABLE, } diff --git a/cascadeflow/integrations/crewai.py b/cascadeflow/integrations/crewai.py new file mode 100644 index 00000000..604ae600 --- /dev/null +++ b/cascadeflow/integrations/crewai.py @@ -0,0 +1,307 @@ +"""CrewAI harness integration for cascadeflow. 
+ +Uses CrewAI's native ``llm_hooks`` system (v1.5+) to intercept all LLM calls +inside Crew executions, feeding metrics into ``cascadeflow.harness`` run +contexts. + +This module is optional — ``pip install cascadeflow[crewai]`` pulls in the +crewai dependency. When crewai is not installed the public helpers return +gracefully and ``CREWAI_AVAILABLE`` is ``False``. + +Integration surface: + - ``enable()``: register before/after LLM-call hooks globally + - ``disable()``: unregister hooks and clean up + - ``CrewAIHarnessConfig``: optional knobs (fail_open, enable_budget_gate) +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass +from importlib.util import find_spec +from typing import Any, Optional + +from cascadeflow.harness.pricing import estimate_cost as _estimate_shared_cost +from cascadeflow.harness.pricing import estimate_energy as _estimate_shared_energy + +logger = logging.getLogger("cascadeflow.integrations.crewai") + +CREWAI_AVAILABLE = find_spec("crewai") is not None + + +def _estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float: + return _estimate_shared_cost(model, prompt_tokens, completion_tokens) + + +def _estimate_energy(model: str, prompt_tokens: int, completion_tokens: int) -> float: + return _estimate_shared_energy(model, prompt_tokens, completion_tokens) + + +def _extract_message_content(message: Any) -> str: + """Extract content text from a CrewAI message (dict or object). + + CrewAI hooks pass messages as dicts (``{"role": "...", "content": "..."}``) + but we also handle object-style messages defensively. 
+ """ + if isinstance(message, dict): + return str(message.get("content", "") or "") + return str(getattr(message, "content", "") or "") + + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +@dataclass +class CrewAIHarnessConfig: + """Runtime configuration for the CrewAI harness integration. + + fail_open: + If ``True`` (default), errors inside hooks never break the CrewAI + execution — they are logged and swallowed. + enable_budget_gate: + If ``True`` (default), a ``before_llm_call`` hook blocks calls when + the harness run budget is exhausted (enforce mode only). + """ + + fail_open: bool = True + enable_budget_gate: bool = True + + +# --------------------------------------------------------------------------- +# Module-level state +# --------------------------------------------------------------------------- + +_config: CrewAIHarnessConfig = CrewAIHarnessConfig() +_hooks_registered: bool = False +_before_hook_ref: Any = None +_after_hook_ref: Any = None +# Track call start times per thread via a dict keyed by id(context) +_call_start_times: dict[int, float] = {} + + +# --------------------------------------------------------------------------- +# Hook implementations +# --------------------------------------------------------------------------- + + +def _extract_model_name(context: Any) -> str: + """Best-effort extraction of the model name from a LLMCallHookContext.""" + llm = getattr(context, "llm", None) + if llm is None: + return "unknown" + # CrewAI LLM objects have a .model attribute + model = getattr(llm, "model", None) + if isinstance(model, str): + # Strip provider prefix like "openai/gpt-4o" → "gpt-4o" + if "/" in model: + return model.rsplit("/", 1)[-1] + return model + return "unknown" + + +def _before_llm_call_hook(context: Any) -> Optional[bool]: + """Harness before-LLM-call hook registered with CrewAI. 
+ + - In enforce mode with budget gate: blocks calls when budget exhausted. + - Tracks call start time for latency measurement. + - Returns ``None`` (allow) or ``False`` (block). + """ + try: + from cascadeflow.harness.api import get_current_run + + ctx = get_current_run() + if ctx is None: + return None + + # Budget gate in enforce mode — check BEFORE recording start time + # so blocked calls don't leak entries in _call_start_times. + if ( + _config.enable_budget_gate + and ctx.mode == "enforce" + and ctx.budget_max is not None + and ctx.cost >= ctx.budget_max + ): + logger.warning( + "crewai hook: blocking LLM call — budget exhausted " "(spent $%.4f of $%.4f max)", + ctx.cost, + ctx.budget_max, + ) + ctx.record(action="stop", reason="budget_exhausted", model=_extract_model_name(context)) + return False + + # Record start time for latency tracking (only for allowed calls) + _call_start_times[id(context)] = time.monotonic() + + return None + except Exception: + if _config.fail_open: + logger.debug("crewai before_llm_call hook error (fail_open)", exc_info=True) + return None + raise + + +def _after_llm_call_hook(context: Any) -> Optional[str]: + """Harness after-LLM-call hook registered with CrewAI. + + Updates the active HarnessRunContext with: + - cost (estimated from model + response length) + - latency + - energy estimate + - step count + - trace record + + Returns ``None`` (keep original response). + """ + try: + from cascadeflow.harness.api import get_current_run + + ctx = get_current_run() + if ctx is None: + return None + + model = _extract_model_name(context) + response = getattr(context, "response", None) or "" + + # Estimate tokens from text (rough: 1 token ≈ 4 chars). + # CrewAI hooks don't expose raw token counts, so we approximate. + # Messages are typically dicts ({"role": "...", "content": "..."}). 
+ messages = getattr(context, "messages", []) + prompt_chars = sum(len(_extract_message_content(m)) for m in messages) + completion_chars = len(str(response)) + prompt_tokens = max(prompt_chars // 4, 1) + completion_tokens = max(completion_chars // 4, 1) + + cost = _estimate_cost(model, prompt_tokens, completion_tokens) + energy = _estimate_energy(model, prompt_tokens, completion_tokens) + + # Latency + start_time = _call_start_times.pop(id(context), None) + elapsed_ms = (time.monotonic() - start_time) * 1000 if start_time else 0.0 + + ctx.cost += cost + ctx.step_count += 1 + ctx.latency_used_ms += elapsed_ms + ctx.energy_used += energy + + if ctx.budget_max is not None: + ctx.budget_remaining = ctx.budget_max - ctx.cost + + ctx.model_used = model + ctx.record(action="allow", reason=ctx.mode, model=model) + + logger.debug( + "crewai hook: tracked call model=%s cost=$%.6f latency=%.0fms", + model, + cost, + elapsed_ms, + ) + + return None + except Exception: + if _config.fail_open: + logger.debug("crewai after_llm_call hook error (fail_open)", exc_info=True) + return None + raise + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def is_available() -> bool: + """Return whether the crewai package is installed.""" + return CREWAI_AVAILABLE + + +def is_enabled() -> bool: + """Return whether harness hooks are currently registered with CrewAI.""" + return _hooks_registered + + +def enable(config: Optional[CrewAIHarnessConfig] = None) -> bool: + """Register cascadeflow harness hooks with CrewAI's global hook system. + + Idempotent: safe to call multiple times. + + Args: + config: Optional configuration overrides. + + Returns: + ``True`` if hooks were registered, ``False`` if crewai is not + installed. 
+ """ + global _config, _hooks_registered, _before_hook_ref, _after_hook_ref + + if _hooks_registered: + logger.debug("crewai harness hooks already registered") + return True + + if not CREWAI_AVAILABLE: + logger.debug("crewai not installed, skipping hook registration") + return False + + if config is not None: + _config = config + + try: + from crewai.hooks import ( # noqa: I001 + register_after_llm_call_hook, + register_before_llm_call_hook, + ) + except ImportError: + logger.warning( + "crewai is installed but hooks module not available " "(requires crewai>=1.5); skipping" + ) + return False + + _before_hook_ref = _before_llm_call_hook + _after_hook_ref = _after_llm_call_hook + + register_before_llm_call_hook(_before_hook_ref) + register_after_llm_call_hook(_after_hook_ref) + + _hooks_registered = True + logger.info("crewai harness hooks registered (before + after llm call)") + return True + + +def disable() -> None: + """Unregister cascadeflow harness hooks from CrewAI. + + Safe to call even if not enabled. 
+ """ + global _hooks_registered, _before_hook_ref, _after_hook_ref + + if not _hooks_registered: + return + + try: + from crewai.hooks import ( # noqa: I001 + unregister_after_llm_call_hook, + unregister_before_llm_call_hook, + ) + + if _before_hook_ref is not None: + unregister_before_llm_call_hook(_before_hook_ref) + if _after_hook_ref is not None: + unregister_after_llm_call_hook(_after_hook_ref) + except ImportError: + pass + + _before_hook_ref = None + _after_hook_ref = None + _hooks_registered = False + _call_start_times.clear() + logger.info("crewai harness hooks unregistered") + + +def get_config() -> CrewAIHarnessConfig: + """Return a copy of the current configuration.""" + return CrewAIHarnessConfig( + fail_open=_config.fail_open, + enable_budget_gate=_config.enable_budget_gate, + ) diff --git a/cascadeflow/integrations/langchain/__init__.py b/cascadeflow/integrations/langchain/__init__.py index 45c6ea2f..7b3f9551 100644 --- a/cascadeflow/integrations/langchain/__init__.py +++ b/cascadeflow/integrations/langchain/__init__.py @@ -54,6 +54,14 @@ CascadeFlowCallbackHandler, get_cascade_callback, ) +from .harness_callback import ( + HarnessAwareCascadeFlowCallbackHandler, + get_harness_callback, +) +from .harness_state import ( + apply_langgraph_state, + extract_langgraph_state, +) __all__ = [ # Main classes @@ -93,4 +101,8 @@ # LangChain callback handlers "CascadeFlowCallbackHandler", "get_cascade_callback", + "HarnessAwareCascadeFlowCallbackHandler", + "get_harness_callback", + "extract_langgraph_state", + "apply_langgraph_state", ] diff --git a/cascadeflow/integrations/langchain/harness_callback.py b/cascadeflow/integrations/langchain/harness_callback.py new file mode 100644 index 00000000..01f08d8c --- /dev/null +++ b/cascadeflow/integrations/langchain/harness_callback.py @@ -0,0 +1,248 @@ +"""Harness-aware callbacks for LangChain/LangGraph integration. 
+ +Enforce-mode limitations (LangChain callback architecture): + - ``stop`` (budget/latency/energy exceeded): fully enforced — raises + BudgetExceededError or HarnessStopError from ``on_llm_start``. + - ``deny_tool`` (tool-call cap): fully enforced at the tool level via + ``on_tool_start`` — raises HarnessStopError before tool execution. + - ``switch_model``: **observe-only** — LangChain dispatches the LLM call + before ``on_llm_start`` returns, so the callback cannot redirect to a + different model. The decision is recorded with ``applied=False``. + - ``deny_tool`` at LLM level (pre-call decision): **observe-only** — the + callback cannot strip tools from an already-dispatched LLM request. + The decision is recorded with ``applied=False``. +""" + +from __future__ import annotations + +import logging +import time +from contextlib import contextmanager +from typing import Any, Optional + +from cascadeflow.harness import get_current_run +from cascadeflow.harness.pricing import estimate_cost, estimate_energy +from cascadeflow.schema.exceptions import HarnessStopError + +from .harness_state import apply_langgraph_state, extract_langgraph_state +from .langchain_callbacks import CascadeFlowCallbackHandler +from .utils import extract_token_usage + +logger = logging.getLogger("cascadeflow.harness.langchain") + + +class HarnessAwareCascadeFlowCallbackHandler(CascadeFlowCallbackHandler): + """LangChain callback that bridges native lifecycle events into HarnessRunContext. + + See module docstring for enforce-mode limitations on ``switch_model`` + and LLM-level ``deny_tool``. 
+ """ + + def __init__(self, *, fail_open: bool = True): + super().__init__() + self.fail_open = fail_open + self._llm_started_at: Optional[float] = None + self._pre_action: str = "allow" + self._pre_reason: str = "allow" + self._pre_model: Optional[str] = None + self._pre_recorded: bool = False + + def _handle_harness_error(self, error: Exception) -> None: + if self.fail_open: + logger.exception("langchain harness callback failed (fail-open)", exc_info=error) + return + raise error + + def _sync_state(self, payload: dict[str, Any]) -> None: + run_ctx = get_current_run() + if run_ctx is None: + return + state = extract_langgraph_state(payload) + if state: + apply_langgraph_state(run_ctx, state) + + def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None: + super().on_llm_start(serialized=serialized, prompts=prompts, **kwargs) + self._llm_started_at = time.monotonic() + self._pre_action = "allow" + self._pre_reason = "allow" + self._pre_model = self.current_model + self._pre_recorded = False + + try: + self._sync_state(kwargs) + + run_ctx = get_current_run() + if run_ctx is None: + return + + model_name = self.current_model or "unknown" + invocation_params = kwargs.get("invocation_params") + has_tools = False + if isinstance(invocation_params, dict): + has_tools = bool(invocation_params.get("tools")) + if not has_tools: + has_tools = bool(kwargs.get("tools")) + + from cascadeflow.harness.instrument import ( + _evaluate_pre_call_decision, + _raise_stop_error, + ) # noqa: I001 + + decision = _evaluate_pre_call_decision(run_ctx, model_name, has_tools=has_tools) + self._pre_action = decision.action + self._pre_reason = decision.reason + self._pre_model = decision.target_model + + if run_ctx.mode == "observe": + if decision.action != "allow": + run_ctx.record( + action=decision.action, + reason=decision.reason, + model=decision.target_model, + applied=False, + decision_mode="observe", + ) + self._pre_recorded = True + return + + if 
run_ctx.mode != "enforce": + return + + if decision.action == "stop": + run_ctx.record( + action="stop", + reason=decision.reason, + model=model_name, + applied=True, + decision_mode="enforce", + ) + self._pre_recorded = True + _raise_stop_error(run_ctx, decision.reason) + + if decision.action == "switch_model": + run_ctx.record( + action="switch_model", + reason=decision.reason, + model=decision.target_model, + applied=False, + decision_mode="enforce", + ) + self._pre_recorded = True + + if decision.action == "deny_tool" and has_tools: + run_ctx.record( + action="deny_tool", + reason=decision.reason, + model=model_name, + applied=False, + decision_mode="enforce", + ) + self._pre_recorded = True + + except Exception as exc: + self._handle_harness_error(exc) + + def on_llm_end(self, response: Any, **kwargs: Any) -> None: + super().on_llm_end(response=response, **kwargs) + + try: + self._sync_state(kwargs) + run_ctx = get_current_run() + if run_ctx is None: + return + + model_name = self.current_model + if not model_name and getattr(response, "llm_output", None): + model_name = response.llm_output.get("model_name") + model_name = model_name or "unknown" + + token_usage = extract_token_usage(response) + prompt_tokens = int(token_usage["input"]) + completion_tokens = int(token_usage["output"]) + elapsed_ms = 0.0 + if self._llm_started_at is not None: + elapsed_ms = (time.monotonic() - self._llm_started_at) * 1000.0 + + run_ctx.step_count += 1 + run_ctx.cost += estimate_cost(model_name, prompt_tokens, completion_tokens) + run_ctx.energy_used += estimate_energy(model_name, prompt_tokens, completion_tokens) + run_ctx.latency_used_ms += elapsed_ms + + if run_ctx.budget_max is not None: + run_ctx.budget_remaining = run_ctx.budget_max - run_ctx.cost + + if self._pre_action == "allow": + run_ctx.record( + action="allow", + reason="langchain_step", + model=model_name, + applied=True, + decision_mode=run_ctx.mode, + ) + elif not self._pre_recorded: + run_ctx.record( + 
action=self._pre_action, + reason=self._pre_reason, + model=self._pre_model or model_name, + applied=False, + decision_mode=run_ctx.mode, + ) + + except Exception as exc: + self._handle_harness_error(exc) + finally: + self._llm_started_at = None + self._pre_action = "allow" + self._pre_reason = "allow" + self._pre_model = None + self._pre_recorded = False + + def on_tool_start(self, serialized: dict[str, Any], input_str: str, **kwargs: Any) -> Any: + try: + self._sync_state(kwargs) + run_ctx = get_current_run() + if run_ctx is None: + return None + if run_ctx.tool_calls_max is None: + return None + + if run_ctx.tool_calls >= run_ctx.tool_calls_max: + if run_ctx.mode == "observe": + run_ctx.record( + action="deny_tool", + reason="max_tool_calls_reached", + model=self.current_model, + applied=False, + decision_mode="observe", + ) + return None + if run_ctx.mode == "enforce": + run_ctx.record( + action="deny_tool", + reason="max_tool_calls_reached", + model=self.current_model, + applied=True, + decision_mode="enforce", + ) + raise HarnessStopError( + "cascadeflow harness deny_tool: max tool calls reached", + reason="max_tool_calls_reached", + ) + + # Track executed tools (not predicted tool calls in LLM output). 
+ run_ctx.tool_calls += 1 + return None + except Exception as exc: + self._handle_harness_error(exc) + return None + + +@contextmanager +def get_harness_callback(*, fail_open: bool = True): + """Context manager that yields a harness-aware LangChain callback handler.""" + callback = HarnessAwareCascadeFlowCallbackHandler(fail_open=fail_open) + yield callback + + +__all__ = ["HarnessAwareCascadeFlowCallbackHandler", "get_harness_callback"] diff --git a/cascadeflow/integrations/langchain/harness_state.py b/cascadeflow/integrations/langchain/harness_state.py new file mode 100644 index 00000000..b4b40da5 --- /dev/null +++ b/cascadeflow/integrations/langchain/harness_state.py @@ -0,0 +1,124 @@ +"""LangGraph/LangChain state extraction helpers for harness integration.""" + +from __future__ import annotations + +from typing import Any, Mapping, Optional + + +def _as_int(value: Any) -> Optional[int]: + try: + if value is None: + return None + return int(value) + except (TypeError, ValueError): + return None + + +def _as_float(value: Any) -> Optional[float]: + try: + if value is None: + return None + return float(value) + except (TypeError, ValueError): + return None + + +def _extract_candidate_state(source: Any) -> Optional[Mapping[str, Any]]: + """Extract a named state container from a mapping. + + Only returns state from explicitly named keys (langgraph_state, graph_state, + state). Returns None when no named key matches — avoids treating arbitrary + kwargs as harness state. 
+ """ + if not isinstance(source, Mapping): + return None + + for key in ("langgraph_state", "graph_state", "state"): + candidate = source.get(key) + if isinstance(candidate, Mapping): + return candidate + + return None + + +def extract_langgraph_state(payload: Any) -> dict[str, Any]: + """Extract normalized harness-relevant fields from LangGraph-style state payloads.""" + + candidates: list[Mapping[str, Any]] = [] + root = _extract_candidate_state(payload) + if root is not None: + candidates.append(root) + + if isinstance(payload, Mapping): + metadata = payload.get("metadata") + if isinstance(metadata, Mapping): + state_from_metadata = _extract_candidate_state(metadata) + if state_from_metadata is not None: + candidates.append(state_from_metadata) + + configurable = payload.get("configurable") + if isinstance(configurable, Mapping): + state_from_configurable = _extract_candidate_state(configurable) + if state_from_configurable is not None: + candidates.append(state_from_configurable) + + merged: dict[str, Any] = {} + for source in candidates: + if "agent_id" in source and isinstance(source.get("agent_id"), str): + merged["agent_id"] = source["agent_id"] + if "model" in source and isinstance(source.get("model"), str): + merged["model_used"] = source["model"] + if "model_used" in source and isinstance(source.get("model_used"), str): + merged["model_used"] = source["model_used"] + + step_count = _as_int(source.get("step_count", source.get("step"))) + if step_count is not None: + merged["step_count"] = step_count + + tool_calls = _as_int(source.get("tool_calls")) + if tool_calls is not None: + merged["tool_calls"] = tool_calls + + budget_remaining = _as_float(source.get("budget_remaining")) + if budget_remaining is not None: + merged["budget_remaining"] = budget_remaining + + latency_used_ms = _as_float(source.get("latency_used_ms", source.get("latency_ms"))) + if latency_used_ms is not None: + merged["latency_used_ms"] = latency_used_ms + + energy_used = 
_as_float(source.get("energy_used", source.get("energy"))) + if energy_used is not None: + merged["energy_used"] = energy_used + + return merged + + +def apply_langgraph_state(run_ctx: Any, state: Mapping[str, Any]) -> None: + """Apply extracted state fields onto an active HarnessRunContext.""" + if run_ctx is None or not isinstance(state, Mapping): + return + + step_count = _as_int(state.get("step_count")) + if step_count is not None and step_count > getattr(run_ctx, "step_count", 0): + run_ctx.step_count = step_count + + tool_calls = _as_int(state.get("tool_calls")) + if tool_calls is not None and tool_calls > getattr(run_ctx, "tool_calls", 0): + run_ctx.tool_calls = tool_calls + + latency_used_ms = _as_float(state.get("latency_used_ms")) + if latency_used_ms is not None and latency_used_ms > getattr(run_ctx, "latency_used_ms", 0.0): + run_ctx.latency_used_ms = latency_used_ms + + energy_used = _as_float(state.get("energy_used")) + if energy_used is not None and energy_used > getattr(run_ctx, "energy_used", 0.0): + run_ctx.energy_used = energy_used + + budget_remaining = _as_float(state.get("budget_remaining")) + if budget_remaining is not None: + run_ctx.budget_remaining = budget_remaining + + model_used = state.get("model_used") + if isinstance(model_used, str) and model_used: + run_ctx.model_used = model_used diff --git a/cascadeflow/integrations/langchain/tests/test_langchain_harness_callback.py b/cascadeflow/integrations/langchain/tests/test_langchain_harness_callback.py new file mode 100644 index 00000000..9ba062e5 --- /dev/null +++ b/cascadeflow/integrations/langchain/tests/test_langchain_harness_callback.py @@ -0,0 +1,213 @@ +"""Tests for harness-aware LangChain callback integration.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration, LLMResult + +from cascadeflow.harness import init, reset, run +from cascadeflow.integrations.langchain.harness_callback 
import ( + HarnessAwareCascadeFlowCallbackHandler, +) +from cascadeflow.integrations.langchain.harness_state import ( + apply_langgraph_state, + extract_langgraph_state, +) +from cascadeflow.integrations.langchain.utils import extract_tool_calls +from cascadeflow.schema.exceptions import BudgetExceededError, HarnessStopError + + +@pytest.fixture(autouse=True) +def _reset_harness_state() -> None: + reset() + + +def _llm_result(model_name: str, prompt_tokens: int, completion_tokens: int) -> LLMResult: + generation = ChatGeneration(message=AIMessage(content="ok"), generation_info={}) + return LLMResult( + generations=[[generation]], + llm_output={ + "model_name": model_name, + "token_usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + }, + ) + + +def test_harness_callback_updates_active_run_metrics() -> None: + init(mode="observe", budget=1.0) + handler = HarnessAwareCascadeFlowCallbackHandler() + + with run(budget=1.0) as ctx: + handler.on_llm_start( + serialized={}, + prompts=["hello"], + invocation_params={"model": "gpt-4o-mini"}, + ) + handler.on_llm_end(_llm_result("gpt-4o-mini", 120, 80)) + + assert ctx.step_count == 1 + assert ctx.cost > 0 + assert ctx.energy_used > 0 + assert ctx.budget_remaining is not None + assert ctx.budget_remaining < 1.0 + assert ctx.last_action == "allow" + assert ctx.model_used == "gpt-4o-mini" + + +def test_harness_callback_enforce_raises_when_budget_exhausted() -> None: + init(mode="enforce", budget=0.1) + handler = HarnessAwareCascadeFlowCallbackHandler(fail_open=False) + + with run(budget=0.1) as ctx: + ctx.cost = 0.1 + ctx.budget_remaining = 0.0 + + with pytest.raises(BudgetExceededError): + handler.on_llm_start( + serialized={}, + prompts=["hello"], + invocation_params={"model": "gpt-4o-mini"}, + ) + + trace = ctx.trace() + assert trace + assert trace[-1]["action"] == "stop" + assert trace[-1]["reason"] == "budget_exceeded" + assert 
trace[-1]["applied"] is True + + +def test_harness_callback_observe_records_non_applied_decisions() -> None: + init(mode="observe", budget=1.0) + handler = HarnessAwareCascadeFlowCallbackHandler() + + with run(budget=1.0) as ctx: + ctx.cost = 0.9 + ctx.budget_remaining = 0.1 + + handler.on_llm_start( + serialized={}, + prompts=["hello"], + invocation_params={"model": "gpt-4o", "tools": [{"name": "lookup"}]}, + ) + + trace = ctx.trace() + assert trace + assert trace[-1]["action"] in {"switch_model", "deny_tool"} + assert trace[-1]["applied"] is False + assert trace[-1]["decision_mode"] == "observe" + + +def test_harness_callback_enforce_denies_tool_when_limit_reached() -> None: + init(mode="enforce", max_tool_calls=0, budget=1.0) + handler = HarnessAwareCascadeFlowCallbackHandler(fail_open=False) + + with run(max_tool_calls=0, budget=1.0) as ctx: + with pytest.raises(HarnessStopError, match="max tool calls"): + handler.on_tool_start(serialized={"name": "search"}, input_str="query") + + trace = ctx.trace() + assert trace + assert trace[-1]["action"] == "deny_tool" + assert trace[-1]["applied"] is True + assert trace[-1]["decision_mode"] == "enforce" + + +def test_on_llm_end_no_run_context_is_safe() -> None: + handler = HarnessAwareCascadeFlowCallbackHandler() + handler.on_llm_start( + serialized={}, + prompts=["hello"], + invocation_params={"model": "gpt-4o-mini"}, + ) + handler.on_llm_end(_llm_result("gpt-4o-mini", 10, 5)) + + +def test_on_tool_start_no_run_context_is_safe() -> None: + handler = HarnessAwareCascadeFlowCallbackHandler() + handler.on_tool_start(serialized={"name": "search"}, input_str="query") + + +def test_extract_state_ignores_plain_kwargs() -> None: + """Kwargs without a named state key should not leak into state.""" + state = extract_langgraph_state({"model": "gpt-4o", "invocation_params": {"tools": []}}) + assert state == {} + + +def test_tool_deny_uses_run_ctx_tool_calls() -> None: + """Tool gating should use run_ctx.tool_calls, not a local 
counter.""" + init(mode="enforce", max_tool_calls=2, budget=1.0) + handler = HarnessAwareCascadeFlowCallbackHandler(fail_open=False) + + with run(max_tool_calls=2, budget=1.0) as ctx: + # Simulate tool calls already counted by on_llm_end or other integrations + ctx.tool_calls = 2 + + with pytest.raises(HarnessStopError, match="max tool calls"): + handler.on_tool_start(serialized={"name": "search"}, input_str="query") + + +def test_tool_start_counts_executions_and_blocks_after_limit() -> None: + init(mode="enforce", max_tool_calls=1, budget=1.0) + handler = HarnessAwareCascadeFlowCallbackHandler(fail_open=False) + + with run(max_tool_calls=1, budget=1.0) as ctx: + assert ctx.tool_calls == 0 + assert handler.on_tool_start(serialized={"name": "search"}, input_str="first") is None + assert ctx.tool_calls == 1 + + with pytest.raises(HarnessStopError, match="max tool calls"): + handler.on_tool_start(serialized={"name": "search"}, input_str="second") + + assert ctx.tool_calls == 1 + trace = ctx.trace() + assert trace[-1]["action"] == "deny_tool" + assert trace[-1]["applied"] is True + + +def test_extract_tool_calls_supports_llm_result_nested_generations() -> None: + generation = ChatGeneration( + message=AIMessage( + content="", tool_calls=[{"name": "search", "args": {"q": "x"}, "id": "t1"}] + ), + generation_info={}, + ) + llm_result = LLMResult(generations=[[generation]], llm_output={"model_name": "gpt-4o-mini"}) + tool_calls = extract_tool_calls(llm_result) + assert len(tool_calls) == 1 + assert tool_calls[0]["name"] == "search" + + +def test_extract_and_apply_langgraph_state() -> None: + state = extract_langgraph_state( + { + "metadata": { + "langgraph_state": { + "step": 4, + "tool_calls": 3, + "budget_remaining": 0.42, + "latency_ms": 130.0, + "energy": 77.0, + "model": "gpt-4o-mini", + } + } + } + ) + + assert state["step_count"] == 4 + assert state["tool_calls"] == 3 + assert state["model_used"] == "gpt-4o-mini" + + init(mode="observe", budget=1.0) + with 
run(budget=1.0) as ctx: + apply_langgraph_state(ctx, state) + assert ctx.step_count == 4 + assert ctx.tool_calls == 3 + assert ctx.budget_remaining == pytest.approx(0.42) + assert ctx.latency_used_ms == pytest.approx(130.0) + assert ctx.energy_used == pytest.approx(77.0) + assert ctx.model_used == "gpt-4o-mini" diff --git a/cascadeflow/integrations/langchain/tests/test_langchain_integration_features.py b/cascadeflow/integrations/langchain/tests/test_langchain_integration_features.py index fdbcff1d..0f051519 100644 --- a/cascadeflow/integrations/langchain/tests/test_langchain_integration_features.py +++ b/cascadeflow/integrations/langchain/tests/test_langchain_integration_features.py @@ -4,7 +4,11 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.outputs import ChatGeneration, ChatResult +from cascadeflow.harness import init, reset, run from cascadeflow.integrations.langchain import CascadeFlow +from cascadeflow.integrations.langchain.harness_callback import ( + HarnessAwareCascadeFlowCallbackHandler, +) class MockSequenceChatModel(BaseChatModel): @@ -116,3 +120,38 @@ def test_domain_policy_direct_to_verifier_skips_drafter() -> None: assert drafter.calls == 0 assert verifier.calls == 1 assert result.llm_output["cascade"]["routing_reason"] == "domain_policy_direct" + + +def test_wrapper_only_auto_adds_harness_callback_inside_active_run_scope() -> None: + reset() + init(mode="observe") + drafter = MockSequenceChatModel("draft") + verifier = MockSequenceChatModel("verify") + cascade = CascadeFlow(drafter=drafter, verifier=verifier, enable_pre_router=False) + + outside_callbacks = cascade._resolve_callbacks([]) + assert not any( + isinstance(cb, HarnessAwareCascadeFlowCallbackHandler) for cb in outside_callbacks + ) + + with run(): + inside_callbacks = cascade._resolve_callbacks([]) + assert any( + isinstance(cb, HarnessAwareCascadeFlowCallbackHandler) for cb in inside_callbacks + ) + + +def 
test_wrapper_does_not_duplicate_harness_callback() -> None: + reset() + init(mode="observe") + drafter = MockSequenceChatModel("draft") + verifier = MockSequenceChatModel("verify") + cascade = CascadeFlow(drafter=drafter, verifier=verifier, enable_pre_router=False) + existing = HarnessAwareCascadeFlowCallbackHandler() + + with run(): + callbacks = cascade._resolve_callbacks([existing]) + assert ( + len([cb for cb in callbacks if isinstance(cb, HarnessAwareCascadeFlowCallbackHandler)]) + == 1 + ) diff --git a/cascadeflow/integrations/langchain/utils.py b/cascadeflow/integrations/langchain/utils.py index fe47a353..04f3e4a5 100644 --- a/cascadeflow/integrations/langchain/utils.py +++ b/cascadeflow/integrations/langchain/utils.py @@ -195,6 +195,10 @@ def extract_tool_calls(response: Any) -> list[dict[str, Any]]: msg = None if hasattr(response, "generations") and response.generations: generation = response.generations[0] + # LLMResult.generations is often list[list[Generation]], while ChatResult + # uses list[Generation]. Support both shapes. 
+ if isinstance(generation, list) and generation: + generation = generation[0] msg = getattr(generation, "message", None) else: msg = getattr(response, "message", None) or response diff --git a/cascadeflow/integrations/langchain/wrapper.py b/cascadeflow/integrations/langchain/wrapper.py index ed6d554b..f108d60f 100644 --- a/cascadeflow/integrations/langchain/wrapper.py +++ b/cascadeflow/integrations/langchain/wrapper.py @@ -169,6 +169,35 @@ def _split_runnable_config( model_kwargs[key] = value return model_kwargs, config + def _resolve_callbacks(self, raw_callbacks: Any) -> list[Any]: + if raw_callbacks is None: + callbacks: list[Any] = [] + elif isinstance(raw_callbacks, list): + callbacks = list(raw_callbacks) + elif isinstance(raw_callbacks, tuple): + callbacks = list(raw_callbacks) + else: + callbacks = [raw_callbacks] + + try: + from cascadeflow.harness import get_current_run, get_harness_config + + harness_config = get_harness_config() + run_ctx = get_current_run() + if harness_config.mode == "off" or run_ctx is None or run_ctx.mode == "off": + return callbacks + + from .harness_callback import HarnessAwareCascadeFlowCallbackHandler + + if any(isinstance(cb, HarnessAwareCascadeFlowCallbackHandler) for cb in callbacks): + return callbacks + + callbacks.append(HarnessAwareCascadeFlowCallbackHandler()) + return callbacks + except Exception: + # Preserve existing behavior for users who do not enable harness flows. 
+ return callbacks + def _generate( self, messages: list[BaseMessage], @@ -202,7 +231,7 @@ def _generate( merged_kwargs["stop"] = stop # Extract callbacks before filtering (need to pass them explicitly to nested models) - callbacks = merged_kwargs.get("callbacks", []) + callbacks = self._resolve_callbacks(merged_kwargs.get("callbacks", [])) existing_tags = merged_kwargs.get("tags", []) or [] base_tags = existing_tags + ["cascadeflow"] if existing_tags else ["cascadeflow"] @@ -599,7 +628,7 @@ async def _agenerate( merged_kwargs["stop"] = stop # Extract callbacks before filtering (need to pass them explicitly to nested models) - callbacks = merged_kwargs.get("callbacks", []) + callbacks = self._resolve_callbacks(merged_kwargs.get("callbacks", [])) existing_tags = merged_kwargs.get("tags", []) or [] base_tags = existing_tags + ["cascadeflow"] if existing_tags else ["cascadeflow"] @@ -1001,7 +1030,7 @@ def _stream( stream_kwargs, base_config = self._split_runnable_config(merged_kwargs) base_tags = (base_config.get("tags") or []) + ["cascadeflow"] existing_metadata = base_config.get("metadata", {}) or {} - callbacks = base_config.get("callbacks", []) + callbacks = self._resolve_callbacks(base_config.get("callbacks", [])) resolved_domain = self._resolve_domain(messages, existing_metadata) effective_quality_threshold = self._effective_quality_threshold(resolved_domain) force_verifier_for_domain = self._domain_forces_verifier(resolved_domain) @@ -1324,7 +1353,7 @@ async def _astream( stream_kwargs, base_config = self._split_runnable_config(merged_kwargs) base_tags = (base_config.get("tags") or []) + ["cascadeflow"] existing_metadata = base_config.get("metadata", {}) or {} - callbacks = base_config.get("callbacks", []) + callbacks = self._resolve_callbacks(base_config.get("callbacks", [])) safe_kwargs = { k: v for k, v in stream_kwargs.items() diff --git a/cascadeflow/integrations/openai_agents.py b/cascadeflow/integrations/openai_agents.py new file mode 100644 index 
00000000..cbce9b96 --- /dev/null +++ b/cascadeflow/integrations/openai_agents.py @@ -0,0 +1,487 @@ +""" +OpenAI Agents SDK integration for cascadeflow harness. + +This module provides an opt-in ModelProvider implementation that applies +cascadeflow harness decisions (model switching, tool gating, run accounting) +inside OpenAI Agents SDK execution. +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass +from importlib.util import find_spec +from typing import TYPE_CHECKING, Any, AsyncIterator, Optional + +from cascadeflow.harness import get_current_run +from cascadeflow.harness.pricing import ( + OPENAI_MODEL_POOL, +) +from cascadeflow.harness.pricing import ( + estimate_cost as _estimate_shared_cost, +) +from cascadeflow.harness.pricing import ( + estimate_energy as _estimate_shared_energy, +) +from cascadeflow.harness.pricing import ( + model_total_price as _shared_model_total_price, +) +from cascadeflow.schema.exceptions import BudgetExceededError + +logger = logging.getLogger("cascadeflow.harness.openai_agents") + +OPENAI_AGENTS_SDK_AVAILABLE = find_spec("agents") is not None + +if TYPE_CHECKING: + from agents.items import ModelResponse + from agents.model_settings import ModelSettings + from agents.models.interface import Model, ModelProvider, ModelTracing + from agents.tool import Tool + from openai.types.responses.response_prompt_param import ResponsePromptParam +else: + Model = object + ModelProvider = object + ModelSettings = Any + ModelTracing = Any + ModelResponse = Any + Tool = Any + ResponsePromptParam = Any + + +@dataclass +class OpenAIAgentsIntegrationConfig: + """ + Runtime behavior for the OpenAI Agents integration. + + model_candidates: + Optional ordered list of candidate models used when harness decides + to switch models under pressure (for example low remaining budget). 
+ enable_tool_gating: + If enabled, removes tools from a model call when the run already + exceeded tool-call caps in enforce mode. + fail_open: + If True, harness-side integration errors never break the agent call. + """ + + model_candidates: Optional[list[str]] = None + enable_tool_gating: bool = True + fail_open: bool = True + + +def _estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float: + return _estimate_shared_cost(model, input_tokens, output_tokens) + + +def _estimate_energy(model: str, input_tokens: int, output_tokens: int) -> float: + return _estimate_shared_energy(model, input_tokens, output_tokens) + + +def _total_model_price(model: str) -> float: + return _shared_model_total_price(model) + + +def _extract_usage_tokens(usage: Any) -> tuple[int, int]: + if usage is None: + return 0, 0 + + input_tokens = getattr(usage, "input_tokens", None) + output_tokens = getattr(usage, "output_tokens", None) + + if input_tokens is None: + input_tokens = getattr(usage, "prompt_tokens", 0) + if output_tokens is None: + output_tokens = getattr(usage, "completion_tokens", 0) + + return int(input_tokens or 0), int(output_tokens or 0) + + +def _count_tool_calls(output_items: Any) -> int: + if not output_items: + return 0 + + count = 0 + for item in output_items: + item_type = None + if isinstance(item, dict): + item_type = item.get("type") + else: + item_type = getattr(item, "type", None) + + if item_type in {"function_call", "tool_call"}: + count += 1 + + return count + + +def _safe_record(action: str, reason: str, model: Optional[str]) -> None: + run = get_current_run() + if run is None: + return + run.record(action=action, reason=reason, model=model) + + +def _apply_run_metrics( + *, + model_name: str, + response: Any, + elapsed_ms: float, + pre_action: str, + allow_reason: str, +) -> None: + run = get_current_run() + if run is None: + return + + usage = getattr(response, "usage", None) if response is not None else None + input_tokens, output_tokens = 
_extract_usage_tokens(usage) + tool_calls = _count_tool_calls(getattr(response, "output", None)) if response is not None else 0 + + run.step_count += 1 + run.latency_used_ms += elapsed_ms + run.energy_used += _estimate_energy(model_name, input_tokens, output_tokens) + run.cost += _estimate_cost(model_name, input_tokens, output_tokens) + run.tool_calls += tool_calls + + if run.budget_max is not None: + run.budget_remaining = run.budget_max - run.cost + + if pre_action == "deny_tool": + run.last_action = "deny_tool" + run.model_used = model_name + else: + run.record("allow", allow_reason, model_name) + + if run.mode == "enforce" and run.budget_remaining is not None and run.budget_remaining <= 0: + logger.info("openai-agents step exhausted budget; next step will be blocked") + + +class CascadeFlowModelProvider(ModelProvider): # type: ignore[misc] + """ + OpenAI Agents SDK ModelProvider with cascadeflow harness awareness. + + Works as an integration layer only. It is opt-in and never enabled by + default for existing cascadeflow users. + """ + + def __init__( + self, + *, + base_provider: Optional[Any] = None, + config: Optional[OpenAIAgentsIntegrationConfig] = None, + ) -> None: + self._config = config or OpenAIAgentsIntegrationConfig() + self._base_provider = base_provider or self._create_default_provider() + + def _create_default_provider(self) -> Any: + if not OPENAI_AGENTS_SDK_AVAILABLE: + raise ImportError( + "OpenAI Agents SDK not installed. Install with `pip install cascadeflow[openai-agents]`." + ) + + # Local import keeps this integration optional for users who don't + # install the extra. 
+ from agents.models.openai_provider import OpenAIProvider + + return OpenAIProvider() + + def _initial_model_candidate(self, requested_model: Optional[str]) -> str: + if requested_model: + return requested_model + if self._config.model_candidates: + return self._config.model_candidates[0] + return "gpt-4o-mini" + + def _resolve_model(self, requested_model: Optional[str]) -> str: + candidate = self._initial_model_candidate(requested_model) + + run = get_current_run() + if run is None: + return candidate + if run.mode != "enforce": + return candidate + + if run.budget_remaining is not None and run.budget_remaining <= 0: + run.record("stop", "budget_exceeded", candidate) + raise BudgetExceededError( + "cascadeflow harness budget exceeded", + remaining=run.budget_remaining, + ) + + if not self._config.model_candidates or run.budget_max is None or run.budget_max <= 0: + return candidate + + if run.budget_remaining is None: + return candidate + + # Under budget pressure, switch to the cheapest configured candidate. 
+ if run.budget_remaining / run.budget_max < 0.2: + compatible_candidates = [ + name for name in self._config.model_candidates if name in OPENAI_MODEL_POOL + ] + candidates = compatible_candidates or self._config.model_candidates + cheapest = min( + candidates, + key=_total_model_price, + ) + if cheapest != candidate: + run.record("switch_model", "budget_pressure", cheapest) + return cheapest + + return candidate + + def get_model(self, model_name: str | None) -> Model: + fallback_model = self._initial_model_candidate(model_name) + selected_model = fallback_model + + try: + selected_model = self._resolve_model(model_name) + except BudgetExceededError: + raise + except Exception: + if not self._config.fail_open: + raise + logger.exception( + "openai-agents model resolution failed; falling back to requested model (fail-open)" + ) + selected_model = fallback_model + + try: + base_model = self._base_provider.get_model(selected_model) + except Exception: + if not self._config.fail_open: + raise + logger.exception( + "openai-agents provider.get_model failed; retrying with fallback model (fail-open)" + ) + selected_model = fallback_model + base_model = self._base_provider.get_model(selected_model) + + return _CascadeFlowWrappedModel( + base_model=base_model, + model_name=selected_model, + config=self._config, + ) + + async def aclose(self) -> None: + close = getattr(self._base_provider, "aclose", None) + if close is None: + return + await close() + + +class _CascadeFlowWrappedModel(Model): # type: ignore[misc] + def __init__( + self, + *, + base_model: Any, + model_name: str, + config: OpenAIAgentsIntegrationConfig, + ) -> None: + self._base_model = base_model + self._model_name = model_name + self._config = config + + def _gate_tools(self, tools: list[Tool]) -> tuple[list[Tool], str]: + run = get_current_run() + if run is None: + return tools, "allow" + if run.mode != "enforce" or not self._config.enable_tool_gating: + return tools, "allow" + if run.tool_calls_max is 
None: + return tools, "allow" + if run.tool_calls < run.tool_calls_max: + return tools, "allow" + if not tools: + return tools, "allow" + + run.record("deny_tool", "max_tool_calls_reached", self._model_name) + return [], "deny_tool" + + def _update_run_metrics( + self, + *, + response: Any, + elapsed_ms: float, + pre_action: str, + ) -> None: + _apply_run_metrics( + model_name=self._model_name, + response=response, + elapsed_ms=elapsed_ms, + pre_action=pre_action, + allow_reason="openai_agents_step", + ) + + async def get_response( + self, + system_instructions: str | None, + input: str | list[Any], # noqa: A002 - required by OpenAI Agents SDK Model interface + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Any | None, + handoffs: list[Any], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> ModelResponse: + gated_tools, pre_action = self._gate_tools(tools) + started_at = time.monotonic() + + response = await self._base_model.get_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=gated_tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + ) + + elapsed_ms = (time.monotonic() - started_at) * 1000.0 + + try: + self._update_run_metrics( + response=response, elapsed_ms=elapsed_ms, pre_action=pre_action + ) + except Exception: + if self._config.fail_open: + logger.exception("openai-agents harness metric update failed (fail-open)") + else: + raise + + return response + + def stream_response( + self, + system_instructions: str | None, + input: str | list[Any], # noqa: A002 - required by OpenAI Agents SDK Model interface + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Any | None, + handoffs: list[Any], + tracing: ModelTracing, + *, + previous_response_id: str | 
None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> AsyncIterator[Any]: + gated_tools, pre_action = self._gate_tools(tools) + started_at = time.monotonic() + + stream = self._base_model.stream_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=gated_tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + ) + return _CascadeFlowStreamWrapper( + stream=stream, + model_name=self._model_name, + started_at=started_at, + pre_action=pre_action, + fail_open=self._config.fail_open, + ) + + +class _CascadeFlowStreamWrapper: + def __init__( + self, + *, + stream: AsyncIterator[Any], + model_name: str, + started_at: float, + pre_action: str, + fail_open: bool, + ) -> None: + self._stream = stream + self._model_name = model_name + self._started_at = started_at + self._pre_action = pre_action + self._fail_open = fail_open + self._finalized = False + self._last_response = None + + def __aiter__(self) -> _CascadeFlowStreamWrapper: + return self + + async def __anext__(self) -> Any: + try: + event = await self._stream.__anext__() + except StopAsyncIteration: + await self._finalize() + raise + + response = getattr(event, "response", None) + if response is not None: + self._last_response = response + return event + + async def _finalize(self) -> None: + if self._finalized: + return + self._finalized = True + + run = get_current_run() + if run is None: + return + + elapsed_ms = (time.monotonic() - self._started_at) * 1000.0 + response = self._last_response + + try: + _apply_run_metrics( + model_name=self._model_name, + response=response, + elapsed_ms=elapsed_ms, + pre_action=self._pre_action, + allow_reason="openai_agents_stream_step", + ) + except Exception: + if self._fail_open: + logger.exception("openai-agents stream metric update failed (fail-open)") + return + raise + + 
+def create_openai_agents_provider( + *, + model_candidates: Optional[list[str]] = None, + enable_tool_gating: bool = True, + fail_open: bool = True, +) -> CascadeFlowModelProvider: + """ + Convenience factory for OpenAI Agents SDK integration. + """ + + return CascadeFlowModelProvider( + config=OpenAIAgentsIntegrationConfig( + model_candidates=model_candidates, + enable_tool_gating=enable_tool_gating, + fail_open=fail_open, + ) + ) + + +def is_openai_agents_sdk_available() -> bool: + return OPENAI_AGENTS_SDK_AVAILABLE + + +__all__ = [ + "OPENAI_AGENTS_SDK_AVAILABLE", + "OpenAIAgentsIntegrationConfig", + "CascadeFlowModelProvider", + "create_openai_agents_provider", + "is_openai_agents_sdk_available", +] diff --git a/cascadeflow/schema/exceptions.py b/cascadeflow/schema/exceptions.py index 90a8bd80..b36a9d81 100644 --- a/cascadeflow/schema/exceptions.py +++ b/cascadeflow/schema/exceptions.py @@ -12,6 +12,7 @@ │ └── TimeoutError ├── ModelError ├── BudgetExceededError + ├── HarnessStopError ├── RateLimitError ├── QualityThresholdError ├── RoutingError @@ -137,6 +138,14 @@ def __init__(self, message: str, remaining: float = 0.0): self.remaining = remaining +class HarnessStopError(cascadeflowError): + """Harness enforcement stop for non-budget hard limits.""" + + def __init__(self, message: str, reason: str): + super().__init__(message) + self.reason = reason + + class RateLimitError(cascadeflowError): """Rate limit exceeded.""" @@ -202,6 +211,7 @@ def __init__( "TimeoutError", "ModelError", "BudgetExceededError", + "HarnessStopError", "RateLimitError", "QualityThresholdError", "RoutingError", diff --git a/docs/README.md b/docs/README.md index 1972c55f..b9cedf66 100644 --- a/docs/README.md +++ b/docs/README.md @@ -20,6 +20,7 @@ Welcome to cascadeflow documentation! 
🌊 - [Tools](guides/tools.md) - Function calling and tool usage with cascades - [Agentic Patterns (Python)](guides/agentic-python.md) - Tool loops and multi-agent orchestration in Python - [Agentic Patterns (TypeScript)](guides/agentic-typescript.md) - Tool loops, multi-agent orchestration, and message best practices +- [Harness Telemetry & Privacy](guides/harness_telemetry_privacy.md) - Decision traces, callbacks, and privacy-safe observability - [Cost Tracking](guides/cost_tracking.md) - Track and analyze API costs across queries - [Proxy Routing](guides/proxy.md) - Route requests through provider-aware proxy plans @@ -35,10 +36,12 @@ Welcome to cascadeflow documentation! 🌊 - [Custom Validation](guides/custom_validation.md) - Implement custom quality validators - [Edge Device Deployment](guides/edge_device.md) - Deploy cascades on edge devices (Jetson, etc.) - [Browser/Edge Runtime](guides/browser_cascading.md) - Run cascades in browser or edge environments +- [Agent Intelligence V2/V2.1 Plan](strategy/agent-intelligence-v2-plan.md) - Unified strategic and execution plan for in-process agent intelligence harness delivery ### Integrations - [n8n Integration](guides/n8n_integration.md) - Use cascadeflow in n8n workflows - [Paygentic Integration](guides/paygentic_integration.md) - Usage metering and billing lifecycle helpers (opt-in) +- [OpenAI Agents SDK Integration](guides/openai_agents_integration.md) - Harness-aware model provider for existing OpenAI Agents apps ## 📚 Examples diff --git a/docs/guides/harness_telemetry_privacy.md b/docs/guides/harness_telemetry_privacy.md new file mode 100644 index 00000000..01e75402 --- /dev/null +++ b/docs/guides/harness_telemetry_privacy.md @@ -0,0 +1,59 @@ +# Harness Telemetry and Privacy + +Use this guide when you want harness observability without leaking user content. 
+ +## What the Harness Records + +Each `run.trace()` decision entry includes: + +- `action`, `reason`, `model` +- `run_id`, `mode`, `step`, `timestamp_ms` +- `cost_total`, `latency_used_ms`, `energy_used`, `tool_calls_total` +- `budget_state` (`max`, `remaining`) +- `applied`, `decision_mode` (when available) + +The trace is scoped to the current `run()` context. + +## What the Harness Does Not Record + +By default, harness decision traces do not include: + +- raw prompts or user messages +- model response text +- tool argument payloads + +This keeps decision telemetry focused on policy/routing state instead of request content. + +## Callback Emission (Optional) + +If you provide a callback manager, each harness decision emits `CallbackEvent.CASCADE_DECISION`. + +```python +from cascadeflow import init, run +from cascadeflow.telemetry.callbacks import CallbackEvent, CallbackManager + +manager = CallbackManager() + +def on_decision(event): + print(event.data["action"], event.data["model"]) + +manager.register(CallbackEvent.CASCADE_DECISION, on_decision) + +init(mode="observe", callback_manager=manager) + +with run(budget=1.0) as r: + ... +``` + +The emitted callback uses `query="[harness]"` and `workflow="harness"` to avoid passing user prompt content. + +## Per-Run Summary Logging + +When a scoped run exits (and recorded at least one step), the harness logs a summary on logger `cascadeflow.harness`: + +- run id, mode, steps, tool calls +- cost/latency/energy totals +- last action/model +- remaining budget + +Use standard Python logging controls to direct this to your existing log sink. diff --git a/docs/guides/openai_agents_integration.md b/docs/guides/openai_agents_integration.md new file mode 100644 index 00000000..2db6b8b7 --- /dev/null +++ b/docs/guides/openai_agents_integration.md @@ -0,0 +1,73 @@ +# OpenAI Agents SDK Integration + +Use cascadeflow as an explicit, opt-in `ModelProvider` integration for the OpenAI Agents SDK. 
+ +## Design Principles + +- Integration-only: nothing is enabled by default +- Works with existing Agents SDK apps +- Harness behavior is controlled by `cascadeflow.init(...)` and `cascadeflow.run(...)` +- Fail-open integration path: harness integration errors should not break agent execution + +## Install + +```bash +pip install "cascadeflow[openai,openai-agents]" +``` + +## Quickstart + +```python +import asyncio + +from agents import Agent, RunConfig, Runner +from cascadeflow import init, run +from cascadeflow.integrations.openai_agents import ( + CascadeFlowModelProvider, + OpenAIAgentsIntegrationConfig, +) + + +async def main() -> None: + # Global harness defaults. + init(mode="enforce", budget=1.0, max_tool_calls=6) + + provider = CascadeFlowModelProvider( + config=OpenAIAgentsIntegrationConfig( + model_candidates=["gpt-4o", "gpt-4o-mini"], + enable_tool_gating=True, + ) + ) + + agent = Agent( + name="SupportAgent", + instructions="Answer support questions clearly and concisely.", + model="gpt-4o", + ) + + run_config = RunConfig(model_provider=provider) + + # Scoped run accounting for a single user task. + with run(budget=0.5, max_tool_calls=3) as session: + result = await Runner.run(agent, "Reset my account password", run_config=run_config) + print(result.final_output) + print(session.trace()) + + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## What This Integration Adds + +- Harness-aware model switching under budget pressure +- Tool gating when enforce-mode limits are reached +- Run metrics on `cascadeflow.run()` context: + - `cost`, `budget_remaining`, `step_count`, `tool_calls`, `latency_used_ms`, `energy_used` +- Full action trace through `run.trace()` + +## Notes + +- This is a Python integration for OpenAI Agents SDK. +- The SDK remains optional and is only installed via the `openai-agents` extra. +- Existing non-Agents users are unaffected. 
diff --git a/docs/strategy/agent-intelligence-v2-plan.md b/docs/strategy/agent-intelligence-v2-plan.md new file mode 100644 index 00000000..117968a8 --- /dev/null +++ b/docs/strategy/agent-intelligence-v2-plan.md @@ -0,0 +1,1082 @@ +# Agent Intelligence V2 Plan + +Last updated: February 25, 2026 +Status: Planning (no implementation in this document) +Supersedes: agent-intelligence-v1-plan.md + +## 1. Objective + +Make cascadeflow the default **in-process agent intelligence harness** for teams that need runtime control over cost, latency, quality, risk, budget, energy, and business KPIs. + +Not a proxy. Not a hosted dependency. A local-first infrastructure layer that can influence agent decisions during execution. + +### 1.1 Winning Criteria + +This plan is successful only if all three pillars are achieved: + +1. **Low-friction install** + - Time-to-first-value under 15 minutes + - Existing apps can activate in 1-3 lines + - Explicit opt-in, no breaking changes for current users +2. **In-loop business KPI control** + - Policies can influence step-level decisions and tool usage at runtime + - Hard constraints and soft KPI preferences both supported + - Decisions are explainable (`why` + `what action`) +3. **Reproducible benchmark superiority on realistic workflows** + - Better or equal quality vs baseline while improving cost/latency + - Results reproducible with pinned configs, prompts, models, and scripts + - Agentic benchmarks include tool loops and multi-step workflows (not only static QA) + +## 2. Product Thesis (Grounded) + +Most routers and gateways optimize at request boundaries. 
The bigger opportunity is inside agent execution: + +- Per-step model decisions based on agent state +- Per-tool-call gating based on remaining budget +- Runtime-aware stop/continue/escalate actions +- Business KPI injection during agent loops +- Learning from outcomes to improve future routing + +This is the moat: **in-process harness for agent decisions**, not external provider routing. + +### What Competitors Already Do (and Why That Is Not Enough) + +- External routers/gateways already do strong request-level routing, fallback, and policy checks. +- Agent frameworks already expose hook systems and guardrails. + +The remaining gap is **cross-framework, local-first, step-level optimization with shared policy semantics**: +- one policy model across different agent stacks, +- one observability model across direct SDK + frameworks, +- one enforcement model across tool loops and sub-agent calls. + +### Why External Proxies Stay Structurally Limited + +A proxy sees: `POST /v1/chat/completions { model, messages, tools }`. + +cascadeflow's harness sees: agent state, step count, budget consumed, tool call history, error context, quality scores on intermediate results, domain, complexity, conversation depth, and any user-defined business context. + +This information asymmetry is structural and permanent. Replicating in-process agent state awareness from an external proxy requires fundamental architectural changes — not a feature addition. + +## 3. Target Users and Segments + +- Startups shipping AI agents in existing products +- Platform teams standardizing agent behavior across products and tenants +- Individual developers are supported, but V2 optimization is for teams with production constraints + +Primary constraints (hard): +- Max cost, max latency, max tool calls, risk/compliance gates, max energy + +Secondary constraints (soft): +- Weighted KPI preferences that influence model/tool decisions when hard limits are not violated + +## 4. 
V2/V2.1 Release Contract (Single Plan) + +This document contains both releases in one plan with explicit boundaries: + +| Area | V2 (Python-first) | V2.1 | +|---|---|---| +| Core harness API (`init`, `run`, `@agent`) | Python | TypeScript parity | +| Auto-instrumentation | OpenAI Python client | Anthropic Python + OpenAI/Anthropic TS clients | +| Integrations | OpenAI Agents SDK, CrewAI, LangChain (Python) + regression checks for existing integrations | TS integration parity + deeper framework convergence | +| Policy semantics | Defined and validated in Python | Same semantics validated in TS parity fixtures | +| Launch target | Production-ready Python harness + reproducible benchmarks | Cross-language parity release | + +## 5. V2 Product Definition + +V2 ships an **agent harness** as an optional, integration-first intelligence layer: + +- Not enabled by default +- No cloud dependency required +- Works in existing apps/agents with minimal code changes (target: 1-3 lines) +- Default behavior remains unchanged unless explicitly enabled +- All framework-specific integrations are separate packages (not bundled with core) + +### Harness Modes + +- `off`: No harness evaluation (default for all existing users) +- `observe`: Evaluate + emit decisions, no behavior change (safe production rollout) +- `enforce`: Apply harness actions at runtime + +### Recommended Rollout for Users + +1. Start with `observe` in production +2. Validate traces + false positives + overhead +3. Enable `enforce` for selected tenants/channels + +## 5.1 Low-Friction DX Contract (Must-Haves) + +- Explicit activation only: no hidden patching. +- Existing code path preserved if harness is `off`. +- If auto-instrumentation is not safe in a runtime, users can use explicit adapter hooks (fallback mode). +- Quickstarts prioritize existing applications first, greenfield second. 
+ +## 5.2 DX Philosophy + +### Principle: Invisible infrastructure, not wrappers + +The gold standard DX is Sentry, DataDog, OpenTelemetry — you activate it, your existing code doesn't change. + +cascadeflow targets this with **auto-instrumentation where safe**, plus **framework-native hooks** in optional integration packages. + +> **Note**: The APIs shown below (`cascadeflow.init()`, `cascadeflow.run()`, `@cascadeflow.agent()`) are the **target V2 API design**. They do not exist today. Current API is `CascadeAgent(models).run(query)`. Building these APIs is the V2 deliverable. + +### Tier 1: Zero-change activation (core, target API) + +```python +import cascadeflow + +cascadeflow.init(mode="observe") +# Every openai call in your app is now observed. +# No code changes. No wrappers. +# Example startup diagnostics: +# [cascadeflow] instrumented: openai +# [cascadeflow] detected but not instrumented in V2: anthropic (planned V2.1) + +cascadeflow.init(mode="enforce") +# Now actively cascading, routing, and enforcing budgets. +``` + +How it works: `init()` patches LLM client libraries at the call level. This is the same proven pattern used by Sentry, DataDog APM, and OpenTelemetry auto-instrumentation. + +V2 scope: `openai` Python client patching only. `anthropic` client patching follows in V2.1. Auto-instrumentation covers code that calls the `openai` SDK directly. Frameworks that abstract over the SDK (LangChain's `ChatOpenAI`, CrewAI via LiteLLM) require their respective integration packages for full coverage. + +### Tier 2: Agent-scoped harness (core, target API) + +```python +async with cascadeflow.run(budget=0.50, max_tool_calls=10) as run: + # Your existing agent code + result = await my_agent.invoke({"task": "Fix the login bug"}) + + print(run.cost) # $0.12 + print(run.savings) # 68% + print(run.tool_calls) # 4 of 10 budget used +``` + +A context manager scopes budget tracking and harness decisions to an agent run. No restructuring of agent code required. 
+ +### Tier 3: Decorated agent with KPIs (core, target API) + +```python +import openai + +@cascadeflow.agent( + budget=0.50, + kpi_targets={"quality_min": 0.90, "latency_ms_max": 3000}, + kpi_weights={"cost": 0.4, "quality": 0.3, "latency": 0.2, "energy": 0.1}, + compliance="gdpr", +) +async def customer_support_agent(task: str): + client = openai.AsyncOpenAI() + response = await client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": task}], + ) + return response.choices[0].message.content +``` + +A decorator adds metadata. The function body doesn't change. + +### Tier 4: Framework-specific deep integration (integration packages) + +```python +# Requires separate install — not bundled with core. +# These extras do not exist in pyproject.toml today and must be added in Phase D. + +# pip install cascadeflow[langchain] +from cascadeflow.integrations.langchain import CascadeFlowCallbackHandler + +# pip install cascadeflow[openai-agents] +from cascadeflow.integrations.openai_agents import CascadeFlowModelProvider + +# pip install cascadeflow[crewai] +from cascadeflow.integrations.crewai import CascadeFlowHooks +``` + +Framework-specific packages provide deeper integration (state extraction, middleware hooks, framework-native telemetry). These are optional — Tier 1-3 work without them for code that calls the `openai` SDK directly. 
+ +### TypeScript Equivalent + +```typescript +import { cascadeflow } from '@cascadeflow/core'; + +// Tier 1: Auto-instrument +cascadeflow.init({ mode: 'enforce' }); + +// Tier 2: Scoped run +const result = await cascadeflow.run({ budget: 0.50 }, async (run) => { + return await myAgent.invoke({ task: 'Fix the login bug' }); +}); + +// Tier 4: Framework packages +// npm install @cascadeflow/langchain +// npm install @cascadeflow/openai-agents +// npm install @cascadeflow/vercel-ai (already exists) +// npm install @cascadeflow/n8n (already exists) +``` + +## 5.3 DX Execution Contracts (Required) + +These contracts remove ambiguity for production teams: + +1. **`init()` instrumentation diagnostics** + - `init()` emits a startup summary of what was instrumented and what was detected but not instrumented in the current version. + - V2 example: OpenAI instrumented, Anthropic detected-but-not-instrumented warning. +2. **`init()` + `run()` scope composition** + - `init()` defines global defaults for calls outside any scoped run. + - `run()` creates an isolated child scope. + - Inside a `run()` scope, run-level settings override global defaults for that scope only. + - Nested `run()` scopes are isolated; inner scope does not mutate outer scope. +3. **Existing `CascadeAgent` migration behavior** + - `cascadeflow.init()` does not rewrite `CascadeAgent`'s core cascade behavior. + - `CascadeAgent` can execute inside `cascadeflow.run()` to contribute to run-level budget/trace accounting. +4. **Configuration precedence** + - Effective config resolution order: explicit code kwargs > environment variables > config file (`cascadeflow.yaml` / JSON) > library defaults. + - `init()` without kwargs may resolve from env/file for platform deployments. + +## 6. 
Scope (V2) + +### In Scope + +- Harness engine in core (init, run context, decorator, action evaluation) +- Auto-instrumentation of `openai` Python client library (V2 scope; `anthropic` client and TS parity in V2.1) +- Harness modes: `off | observe | enforce` +- Action vocabulary: `allow | switch_model | deny_tool | stop` +- Config precedence support for harness init (code kwargs > env vars > config file > defaults) +- Hard controls: max cost, max latency, max tool calls, risk gates, max energy +- Soft controls: weighted KPI preferences +- Step-level and tool-level harness hooks +- Energy dimension (optional, in core) +- Parity fixtures/spec for TS implementation in V2.1 (Python implementation ships in V2) +- Integration packages (separate install, not bundled with core): + - `cascadeflow[openai-agents]` — OpenAI Agents SDK (NEW — extra must be added to pyproject.toml) + - `cascadeflow[crewai]` — CrewAI via LLM hooks (NEW — extra must be added to pyproject.toml) + - `cascadeflow[langchain]` — LangChain/LangGraph (EXISTS as code, extra must be added to pyproject.toml) + - Existing integrations verified: Vercel AI SDK, n8n +- Named benchmark suite with acceptance gates + +### Out of Scope (V2) + +- Hosted control plane / Studio (future product) +- Mandatory migration for existing users +- Autonomous learning loop with remote training (future phase) +- Speculative agent execution (future phase) +- Carbon API integration (future; energy estimate is V2, live carbon data is not) +- MCP tool call interception (future phase) +- Google ADK integration (on demand) + +## 7. 
Non-Negotiable Constraints + +- Backward compatible: existing users see zero behavior change +- Opt-in only: `off` by default +- No default latency regression for non-harness users +- Harness decision overhead target: **<5ms p95** +- Cascade execution overhead: documented and expected (extra LLM call for verification) +- Preserve existing DX simplicity for non-harness users +- Framework integrations are never auto-installed with core +- Auto-instrumentation is explicit (`cascadeflow.init()`) — never hidden + +## 8. Architecture + +### 8.1 Package Boundaries + +``` +cascadeflow (core) +├── cascadeflow.harness # Harness engine (NEW) +│ ├── init() # Auto-instrumentation entry point +│ ├── run() # Context manager for scoped runs +│ ├── agent() # Decorator for KPI-annotated agents +│ ├── actions # allow, switch_model, deny_tool, stop +│ ├── context # HarnessContext (runtime state) +│ └── instrument # LLM client patching (openai, anthropic) +├── cascadeflow.rules # Rule engine (EXISTS, extended) +├── cascadeflow.quality # Quality validation (EXISTS) +├── cascadeflow.routing # Routing (EXISTS) +├── cascadeflow.core.cascade # Speculative cascade (EXISTS) +├── cascadeflow.telemetry # Cost tracking + metrics (EXISTS) +└── cascadeflow.providers # LLM providers (EXISTS) + +cascadeflow[openai-agents] # Integration package (NEW) +├── CascadeFlowModelProvider # OpenAI Agents SDK ModelProvider +├── tool_guard # Tool call gating via Agents SDK hooks +└── trace_adapter # Map Agents SDK traces to harness context + +cascadeflow[crewai] # Integration package (NEW) +├── CascadeFlowHooks # CrewAI LLM call hooks +├── crew_context # Extract crew/agent/task state +└── step_callback # Budget tracking per crew step + +cascadeflow[langchain] # Integration package (EXISTS, extended) +├── CascadeFlow(BaseChatModel) # Existing LangChain wrapper +├── harness_callback # NEW: LangGraph middleware for harness +└── state_extractor # NEW: Extract LangGraph state for context +``` + +### 8.2 Core Harness 
Layer + +Extend current rule context with runtime/loop state: + +```python +@dataclass +class HarnessContext: + # Identification + agent_id: Optional[str] = None + run_id: str = field(default_factory=lambda: uuid4().hex[:12]) + + # Budget tracking (hard controls) + budget_max: Optional[float] = None + budget_used: float = 0.0 + tool_calls_max: Optional[int] = None + tool_calls_used: int = 0 + latency_max_ms: Optional[float] = None + latency_used_ms: float = 0.0 + energy_max: Optional[float] = None + energy_used: float = 0.0 + + # Agent state + step_count: int = 0 + tool_history: list[str] = field(default_factory=list) + error_history: list[str] = field(default_factory=list) + prior_actions: list[str] = field(default_factory=list) + cascade_active: bool = False + draft_model: Optional[str] = None + verifier_model: Optional[str] = None + draft_accepted: Optional[bool] = None + + # Soft controls (KPI weights, sum to 1.0) + kpi_weights: Optional[dict[str, float]] = None + + # Compliance + compliance_tags: list[str] = field(default_factory=list) + + # Harness mode + mode: Literal["off", "observe", "enforce"] = "off" +``` + +### 8.3 Harness Action Surface + +Actions the harness can take: + +| Action | Description | When | +|---|---|---| +| `allow` | Proceed normally (default) | Hard limits not violated | +| `switch_model` | Use a different model for this call | Cost/quality/latency optimization | +| `deny_tool` | Block a tool call | Budget exhausted, risk gate, compliance | +| `stop` | Terminate the agent run | Hard budget exceeded, safety gate | + +These actions are evaluated at three hook points: + +- **Pre-LLM-call**: Before each model invocation (model selection, budget check) +- **Pre-tool-call**: Before each tool execution (tool gating, budget check) +- **Post-LLM-call**: After each model response (quality validation, state update) + +In `observe` mode: actions are computed and logged but not applied. +In `enforce` mode: actions are computed, logged, and applied. 
+ +### 8.3.1 `switch_model` Resolution Path + +`switch_model` is not a simple fallback list. It uses existing cascadeflow intelligence: + +1. Rule constraints (tenant/channel/KPI/tier/workflow context) +2. Complexity + domain signals +3. Model capability and safety constraints (tool support, risk/compliance requirements) +4. Cost/latency/quality scoring over remaining candidate models + +The selected model and reason are always included in the decision trace. + +### 8.3.2 `deny_tool` Contract (Default) + +Default behavior in V2: + +1. **Prevention path (preferred):** if a tool is disallowed before model execution, the tool is removed/blocked from the callable set for that step. +2. **Interception path:** if a disallowed tool call is still emitted, return a synthetic structured tool result: + - `{"error":"tool_denied","reason":"budget_exceeded","action":"deny_tool"}` +3. Continue the loop with the denial result in context so the agent can recover or stop. + +Integrations may map this to framework-native interruption semantics, but the default contract remains structured and non-crashing. + +### 8.4 Auto-Instrumentation Layer + +Core patches LLM client libraries to intercept calls: + +```python +# V2 scope — core auto-instrumentation: +# - openai (Python) — already an optional dep in pyproject.toml + +# V2.1 scope: +# - anthropic (Python) — already an optional dep +# - openai (TypeScript) — in @cascadeflow/core + +# Supported via integration packages (separate install): +# - litellm (existing integration module; optional dependency) +# - langchain ChatModels (via cascadeflow[langchain]) +# - crewai LLM (via cascadeflow[crewai]) +``` + +The patch intercepts `create()` / `acreate()` calls and: +1. Reads the current `HarnessContext` (from context manager or `contextvars`, not thread-local) +2. Evaluates harness rules (complexity, domain, budget state) +3. In `observe`: logs the decision, passes through unchanged +4. 
In `enforce`: applies action (switch model, cascade, deny) +5. Updates context (cost, latency, step count) + +Implementation contract: +- Patch registration is idempotent (multiple `init()` calls are safe). +- Scoped runs use isolated contextvar state (including nested runs). +- A clean unpatch/reset path exists for tests and controlled shutdown. + +### 8.5 Integration Layer + +Ship as optional integration packages, same pattern as existing integrations: + +- Explicit install (`pip install cascadeflow[crewai]`) +- Explicit enable/config +- No hidden activation from core install +- Try/except imports with `AVAILABLE` flags +- Graceful degradation when not installed + +Each integration provides: +1. **State extraction**: Pull agent/framework state into `HarnessContext` +2. **Native hooks**: Use the framework's own extension points (not custom wrappers) +3. **Telemetry bridge**: Map framework traces to harness telemetry + +| Integration | Framework Extension Point | What It Adds | +|---|---|---| +| `openai-agents` | `ModelProvider` at `Runner.run` level | Model routing, tool gating | +| `crewai` | `llm_hooks` (native CrewAI feature) | LLM call interception, crew state | +| `langchain` | `BaseChatModel` (existing) + LangGraph middleware | State extraction, callbacks | +| `vercel-ai` | Existing `@cascadeflow/vercel-ai` | Extend with harness config | +| `n8n` | Existing `@cascadeflow/n8n-nodes-cascadeflow` | Extend with harness node params | + +## 9. 
Hard vs Soft Controls + +### 9.0 KPI Input Schema + +To avoid ambiguity, harness KPI config is split into two explicit inputs: + +- `kpi_targets`: absolute goals/limits (for example `quality_min`, `latency_ms_max`) +- `kpi_weights`: optimization preferences used for scoring when hard limits are not violated + +### Hard Controls (enforced when enabled) + +| Control | Config | Action on Violation | +|---|---|---| +| Max cost per run | `budget=0.50` | `switch_model` (downgrade) or `stop` | +| Max tool calls | `max_tool_calls=10` | `deny_tool` | +| Max latency per run | `max_latency_ms=5000` | `switch_model` (faster) or `stop` | +| Risk/compliance gate | `compliance="gdpr"` | Route to compliant model or `deny_tool` | +| Max energy estimate | `max_energy=0.01` | `switch_model` (lighter) or `stop` | + +### Soft Controls (influence, don't enforce) + +Weighted KPI preferences that influence model/tool decisions when hard limits are not violated: + +```python +cascadeflow.init( + mode="enforce", + kpi_weights={ + "cost": 0.4, # 40% weight on cost optimization + "quality": 0.3, # 30% weight on quality + "latency": 0.2, # 20% weight on latency + "energy": 0.1, # 10% weight on energy efficiency + } +) +``` + +Soft controls affect model scoring in the cascade routing decision. They do not trigger `deny_tool` or `stop`. + +### 9.1 Prompt Caching Strategy + +Prompt caching is complementary to cascading and budget enforcement. + +V2: +- Capture cache-related usage signals where available (e.g., cached tokens) in telemetry. +- Expose cache metrics in traces and benchmark artifacts. +- Do not make cache-hit optimization a hard routing objective yet. + +V2.1: +- Optional cache-aware scoring bias for compatible providers/models. +- Validate that cache-aware routing improves net economics without quality regressions. 
+ +### 9.2 Energy Estimation Specification (V2) + +V2 uses a deterministic proxy estimate (not real-time grid carbon): + +- `energy_units = model_coefficient * (input_tokens + output_tokens * output_weight)` +- `model_coefficient` comes from a versioned local mapping (fallback to default when unknown). +- `output_weight` defaults to >1 to reflect higher generation compute cost. + +This keeps energy scoring deterministic, reproducible, and local-first. Live carbon-intensity routing remains post-V2. + +## 10. TS/Python Parity Requirements + +Parity means same core semantics, not necessarily identical APIs. + +V2 ships Python first. TS parity is a V2.1 deliverable (Phase F). Parity fixtures are written in V2 Phase A as the TS implementation spec. + +Target parity (V2.1): +- Same harness modes: `off | observe | enforce` +- Same action vocabulary: `allow | switch_model | deny_tool | stop` +- Same `HarnessContext` fields for budget/latency/energy/tool-depth +- Same fallback behavior when harness is disabled +- Same hook points: pre-LLM-call, pre-tool-call, post-LLM-call +- Comparable telemetry fields for analysis +- Shared parity test fixtures (written in V2, validated in V2.1) + +## 11. Framework Integrations (V2) + +### 11.1 OpenAI Agents SDK (`cascadeflow[openai-agents]`) + +Required as official integration coverage in V2. 
+ +Integration approach: +- Use `ModelProvider` at `Runner.run` level (framework's native extension) +- NOT a custom wrapper around the SDK +- Harness evaluates at each agent step via the model provider +- Tool gating via tool-call inspection in model responses + +Minimum capabilities: +- Harness runs in `observe` and `enforce` modes +- Tool-call gating (deny on harness action) +- Model recommendation/switch based on harness decision +- Budget tracking across multi-step agent runs +- No hard dependency forced onto all cascadeflow users + +### 11.2 CrewAI (`cascadeflow[crewai]`) + +Integration approach: +- Use CrewAI's native `llm_hooks` (before/after LLM calls) +- Extract crew/agent/task state into `HarnessContext` +- Budget tracking via `step_callback` + +### 11.3 LangChain/LangGraph (`cascadeflow[langchain]`) + +Integration approach: +- Extend existing `CascadeFlow(BaseChatModel)` wrapper +- Add LangGraph-specific middleware for state extraction +- Add harness-aware callback handler +- Preserve existing DX for current LangChain users + +### 11.4 Existing Integrations + +Verify and extend (no breaking changes): +- `@cascadeflow/vercel-ai`: Add harness config pass-through +- `@cascadeflow/n8n-nodes-cascadeflow`: Add harness mode parameter to nodes +- `cascadeflow.integrations.litellm`: Verify harness compatibility +- `cascadeflow.integrations.openclaw`: Verify harness compatibility + +## 12. Transparency and Debugging + +Auto-instrumentation must not be magic. 
Every harness decision is visible: + +- `cascadeflow.init(mode="observe")`: Logs every decision (what it *would* do) +- `cascadeflow.init(mode="enforce", verbose=True)`: Rich console output showing cascade path +- Harness metadata is accessible via two paths depending on usage mode: + - **Library mode** (in-process): Metadata on `HarnessContext` / `run` object — `run.last_action`, `run.model_used`, `run.draft_accepted`, `run.budget_remaining`, `run.run_id` + - **Proxy mode** (HTTP gateway): `x-cascadeflow-*` response headers (existing proxy behavior, unchanged) +- `run.trace()` returns full decision log for a scoped run +- Harness decisions are emitted via existing `CallbackManager` events +- All decisions include: action taken, reason, model used, budget state, run_id for correlation + +Default logging destination: +- Logger name: `cascadeflow.harness` +- `DEBUG`: per-step decisions and action reasons +- `INFO`: per-run summaries in `run()` scope +- `verbose=True`: adds rich console rendering on top of logger output (does not replace structured logging) + +### 12.1 Run Object Surface (V2 Target API) + +```python +run.cost # float: total cost in scoped run +run.savings # float: savings percentage vs selected baseline +run.tool_calls # int: tool calls used +run.budget_remaining # float|None: remaining budget if configured +run.model_used # str|None: most recent selected model +run.last_action # str: allow|switch_model|deny_tool|stop +run.draft_accepted # bool|None: draft acceptance for last cascade decision +run.run_id # str: correlation id +run.trace() # list[dict]: full decision timeline +``` + +## 13. Benchmark and Validation Plan + +Use live API runs and keep comparability with prior benchmark set. Winning claims require reproducible, public methodology. 
+ +### 13.1 Benchmark Families + +- Baseline language/reasoning: MT-Bench, TruthfulQA +- Code correctness: HumanEval, SWE-bench Lite slices +- Classification/structured output: Banking77 +- Tool use and agent loops: BFCL-style tool/function scenarios + internal loop tests +- Product realism: customer-support and multi-agent delegation scenarios already aligned with cascadeflow usage + +### 13.2 Realistic Workflow Suite (Required) + +Each benchmark run must include at least these workload types: + +- Existing app integration flow (OpenAI SDK direct calls) +- Existing agent framework flow (OpenAI Agents SDK, LangChain/LangGraph, CrewAI) +- Tool-heavy flow (5+ loop steps, mixed tool success/failure) +- Budget-constrained flow (mid-run budget pressure) +- Risk/compliance-constrained flow (policy escalation and tool deny paths) + +### 13.3 Reproducibility Protocol (Non-Negotiable) + +- Pin exact git SHA, benchmark script version, model names, and provider endpoints. +- Store raw per-case outputs (JSON/JSONL), not only aggregate summaries. +- Record both quality metrics and economics metrics per case: + - accepted/rejected, + - draft acceptance, + - total cost, + - latency, + - selected model path, + - policy action path. +- Publish confidence intervals and sample sizes for reported improvements. +- Re-run on at least two separate days before public claims. + +### 13.4 Superiority Criteria (Grounded) + +To claim “winning” in go-to-market material: + +- Quality: non-inferior to baseline on core tasks with agreed margin. +- Cost: statistically significant reduction on realistic agent workflows. +- Latency: no material regression for non-harness users; harness overhead p95 <5ms. +- Policy safety: false-positive enforcement rate under agreed threshold. +- DX: time-to-first-value within target and successful quickstart completion by external testers. + +### 13.5 Launch Gates + +- Observe mode must be behavior-identical to baseline (output parity checks). 
+- Enforce mode must show measurable value on at least three realistic workflow families. +- Benchmark scripts and result artifacts must be executable by third parties with documented setup. + +## 14. Competitive Positioning + +### 14.1 Ecosystem Baseline Capabilities + +- Provider/model fallback and load balancing +- Request-level cost optimization (model selection) +- Cross-provider unified API access +- Low integration friction (URL change) +- Framework middleware/hooks, guardrails, and tracing + +### 14.2 What Remains Unresolved Across These Tools + +- No shared cross-framework policy semantics for business KPIs. +- Limited consistent in-loop controls across model/tool/agent-step decisions. +- Weak portability of optimization behavior across direct SDK use and framework use. +- Economic claims are often hard to reproduce end-to-end on realistic workflows. + +### 14.3 Remaining Gap We Target + +- Cross-framework policy semantics (one control model across stacks). +- In-loop optimization that combines cost, latency, quality, risk, and business KPIs. +- Local-first deployment without mandatory cloud control plane. +- Reproducible economic + quality gains under realistic agent workflows. 
+ +### 14.4 Positioning Against Current Market + +| Category | Examples | Their Strength | cascadeflow Differentiator | +|---|---|---|---| +| Budget-only enforcement | AgentBudget, custom budget middleware | Fast setup for spend caps and loop stops | Multi-dimensional optimization: cost + quality + latency + KPI + energy + cascade validation | +| Proxy cost-control + observability | Helicone, similar gateway observability stacks | Fast request-level analytics/caching/rules without code-level harness changes | In-process agent-state decisions and step/tool-level policy enforcement inside loops | +| External routers/gateways | OpenRouter, Portkey, NotDiamond | Provider/routing control at API boundary | In-loop action control with agent state and policy context | +| Framework-native orchestration | OpenAI Agents SDK, LangGraph, CrewAI | Rich framework-specific hooks and orchestration | Cross-framework policy layer + unified KPI semantics | +| Single-provider optimization | Provider-native routing features | Tight provider integration and defaults | Multi-provider, user-economics-first optimization | + +## 15. Risks and Mitigations + +- **Risk**: Over-complex harness UX + Mitigation: Default `off`, `observe` before `enforce`, 1-3 lines to activate. Progressive complexity. + +- **Risk**: Auto-instrumentation surprises (patching library internals) + Mitigation: Explicit `init()` required. Never hidden. `observe` mode first. Verbose logging available. Metadata on every response. + +- **Risk**: "Always verifier" behavior in sensitive benchmarks + Mitigation: Explicit harness reasons + scenario tests + calibrated hard/soft boundaries. + +- **Risk**: TS/Python drift + Mitigation: Shared parity fixtures and decision test cases. + +- **Risk**: Integration sprawl + Mitigation: One harness core, thin adapters per integration. Auto-instrumentation plus explicit adapter mode for hard runtimes. + +- **Risk**: Framework API instability (breaking changes in LangGraph, CrewAI, etc.) 
+ Mitigation: Integrations are thin adapters (<500 lines). Core harness works via LLM client patching regardless of framework changes. + +- **Risk**: LangChain/OpenAI build competing harness features + Mitigation: Ship fast, position as complementary (not competing), framework-agnostic is the moat. LangChain's Deep Agents is LangChain-only. cascadeflow works with everything. + +- **Risk**: LLM provider builds internal routing (GPT-5 internal router) + Mitigation: Provider routing is single-provider and optimizes for provider economics. cascadeflow is multi-provider and optimizes for user economics/KPIs. Re-evaluate this risk quarterly with a documented competitive capability review. + +- **Risk**: Harness decision overhead exceeds target + Mitigation: Rule evaluation is CPU-only (no network calls). Benchmark continuously. Degrade gracefully (skip harness if overhead budget exceeded). + +- **Risk**: Low-friction promise fails in real teams + Mitigation: Track time-to-first-value in external pilot tests; gate launch on quickstart completion metrics. + +- **Risk**: Benchmark claims are not trusted externally + Mitigation: Publish reproducibility protocol, scripts, and raw artifacts for independent reruns. + +## 16. 
Release Plan (Phased) + +### Phase A: Harness Core Definition (2-3 weeks) + +- Finalize `HarnessContext` schema +- Finalize action vocabulary and hook points +- Define `off | observe | enforce` mode behavior +- Write parity fixtures (Python first, TS fixtures as spec — TS implementation in V2.1) +- Design auto-instrumentation for `openai` Python client +- Add new extras to `pyproject.toml`: `langchain`, `openai-agents`, `crewai` + +Exit criteria: +- Schema frozen +- Python parity fixture tests green +- Auto-instrumentation prototype patching `openai` Python client +- pyproject.toml extras defined (even if integration code is not yet complete) + +### Phase B: Observe Mode (3-4 weeks) + +- Implement `cascadeflow.init(mode="observe")` (NEW top-level API) +- Auto-instrument `openai` Python client (sync + async + streaming + tool calling) +- Emit startup instrumentation diagnostics (instrumented vs detected-but-not-instrumented SDKs) +- Implement `cascadeflow.run()` context manager +- Emit decision traces via `CallbackManager` +- Integrate with existing `RuleEngine` (extended with `HarnessContext`) +- Harness metadata on `HarnessContext` / `run` object + +Note: Auto-instrumentation of `openai` client is the highest-risk engineering task. Patching async streaming, tool calling, retries, and `with_raw_response` requires exhaustive edge-case testing. 
+ +Exit criteria: +- `observe` mode produces zero behavior change (validated by benchmark) +- Decision traces are accurate and complete +- Overhead within <5ms p95 target +- All existing tests still pass (backward compatibility) +- Edge cases validated: streaming, async, tool calling, parallel tool calls, retries + +### Phase C: Enforce Mode (3-4 weeks) + +- Activate `switch_model`, `deny_tool`, `stop` actions +- Implement hard controls (budget, tool-call cap, latency, energy) +- Implement soft controls (KPI-weighted model scoring) +- Add safety fallbacks (graceful degradation on harness error) +- Implement `@cascadeflow.agent()` decorator + +Exit criteria: +- Enforced behavior matches harness intent +- No critical regressions in benchmark suite +- Hard controls reliably enforced (100% of violations caught) +- Harness errors never crash the agent (fail-open) + +### Phase D: Integration Packages (3-5 weeks, parallelizable with Phase C) + +- `cascadeflow[openai-agents]`: ModelProvider + tool gating +- `cascadeflow[crewai]`: LLM hooks + crew state extraction +- `cascadeflow[langchain]`: Extend existing with harness callbacks +- Verify existing integrations: Vercel AI SDK, n8n, LiteLLM, OpenClaw +- Docs + quickstarts + examples for each integration + +Exit criteria: +- Install and quickstart verified end-to-end for each integration +- CI and integration tests green +- Each integration <500 lines of framework-specific code + +### Phase E: Benchmarks + Public Launch (2-3 weeks) + +- Run full benchmark suite (baseline + agentic + harness scenarios) +- Publish reproducible benchmark results +- Write launch content (blog post, integration cookbooks) +- Go/No-Go checklist validated + +Exit criteria: +- All acceptance gates met +- Benchmark results published and reproducible +- DX quickstart works for existing app/agent users with 1-3 lines of code + +### Total V2 Timeline (Python): 14-18 weeks + +This is the realistic timeline for Python-first delivery with one primary 
contributor. Phases C and D can overlap (integration packages start once enforce mode core is stable). + +### V2 Success Scorecard (Must Pass Before Launch) + +- **Low-friction install** + - 80%+ of pilot users complete quickstart without maintainer help. + - Median time-to-first-value under 15 minutes. +- **In-loop KPI control** + - Policy actions (`switch_model`, `deny_tool`, `stop`) triggered and logged correctly in scenario tests. + - Observe→enforce rollout shows no unexpected behavior in pilot tenants. +- **Benchmark superiority** + - Quality non-inferior vs baseline on agreed benchmark set. + - Statistically significant cost reduction on realistic agent workflows. + - Harness overhead p95 under 5ms for decision path. + +### Phase F: TypeScript Parity (V2.1, post-V2 launch) + +- Port `cascadeflow.init()` / `run()` to `@cascadeflow/core` +- Auto-instrument `openai` TypeScript client (OpenAI Node SDK) +- Port `HarnessContext`, action evaluation, harness modes +- TS parity fixture tests green +- Extend `@cascadeflow/vercel-ai` and `@cascadeflow/n8n` with harness support + +Estimated: 6-8 weeks after V2 Python launch. + +### Phase G: Anthropic Client Instrumentation (V2.1) + +- Auto-instrument `anthropic` Python client +- Auto-instrument `@anthropic-ai/sdk` TypeScript client +- Validate with Claude-based agent workflows + +Estimated: 3-4 weeks (can parallel with Phase F). + +### 16.1 Parallel Branch Workboard (Tick-Off) + +Use this section as the single coordination board for parallel execution. + +Branching model: +- Keep `main` always releasable. +- Use one integration branch for this program: `feature/agent-intelligence-v2-integration`. +- Contributors build on short-lived feature branches and merge to the integration branch first. +- Merge to `main` only after integration branch CI + benchmark gates are green. 

Claim checklist (one owner per branch at a time):
- [x] `feat/v2-core-harness-api` — Owner: `@codex` — PR: `TBD` — Status: `completed`
- [ ] `feat/v2-openai-auto-instrumentation` — Owner: `@claude` — PR: `TBD` — Status: `in-progress`
- [x] `feat/v2-enforce-actions` — Owner: `@codex` — PR: `TBD` — Status: `completed (ready for PR)`
- [ ] `feat/v2-openai-agents-integration` — Owner: `@codex` — PR: `TBD` — Status: `in-progress`
- [ ] `feat/v2-crewai-integration` — Owner: `@` — PR: `#` — Status: `claimed/in-progress/review/merged`
- [x] `feat/v2-langchain-harness-extension` — Owner: `@codex` — PR: `TBD` — Status: `completed`
- [ ] `feat/v2-dx-docs-quickstarts` — Owner: `@` — PR: `#` — Status: `claimed/in-progress/review/merged`
- [x] `feat/v2-bench-repro-pipeline` — Owner: `@codex` — PR: `#163` — Status: `completed (merged to integration branch)`
- [x] `feat/v2-security-privacy-telemetry` — Owner: `@codex` — PR: `#162` — Status: `completed (merged to integration branch)`

Merge gates per feature branch:
- [ ] Unit/integration tests green for touched scope
- [ ] Docs/examples updated for any API or behavior change
- [ ] Backward compatibility verified (`off` mode unchanged)
- [ ] Bench impact assessed (if runtime behavior changed)

Integration-branch promotion gates:
- [ ] Core + integration CI green
- [ ] Full benchmark suite rerun with reproducibility artifacts
- [ ] Quickstart verification for existing app and framework paths
- [ ] Go/No-Go checklist in Section 18 satisfied before merging to `main`

### 16.2 V2.1 Parallel Execution Split

To enable parallel work without merge collisions, split V2.1 into Python and TS tracks:

- `feat/v2.1-anthropic-python-auto-instrumentation` (completed in this branch)
  - Scope: `cascadeflow/harness/*`, Python harness tests, Python docs notes
  - Deliverables: Anthropic Python auto-instrumentation, validation for `init()/run()` harness path
- `feat/v2.1-ts-harness-api-parity` (completed and merged
into this branch scope) + - Scope: `packages/core/*`, TS parity fixtures, TS docs notes + - Deliverables: `@cascadeflow/core` exports parity (`init()/run()`), TS fixture parity validation + +Parallel-safe rule: +- Python track does not touch `packages/core/*` +- TS track does not touch `cascadeflow/harness/*` + +## 17. Future Phases (Post-V2, Not in Scope) + +For roadmap visibility. These inform V2 telemetry design but are not V2 deliverables. + +### Future: Speculative Agent Execution +- Extend speculative cascade from model-level to agent-step-level +- Speculative next-step execution with cheap models, rollback on validation failure +- Selective verification (not every step needs verification) +- Validated by: Sherlock (Microsoft, 2025), Speculative Actions (2025) + +### Future: Adaptive Learning Engine +- Contextual bandit routing (replace/augment static rules) +- Per-agent, per-task performance tracking +- Online learning from outcomes, no offline training needed +- Cold-start with aggregated anonymous routing telemetry (opt-in) +- Validated by: EMNLP 2025 bandit routing papers, BATS (Google, 2025) + +### Future: cascadeflow Studio (Cloud Product) +- Dashboard: real-time visualization of all dimensions +- Fleet suggestions: auto-recommend optimal model combinations +- Learning flywheel: shared (anonymized) routing data improves routing for all users +- A/B testing: compare routing strategies in production +- Custom KPI builder: visual interface for defining business dimensions +- V2 telemetry fields are designed to support Studio without breaking changes + +### Future: MCP Integration +- Intercept MCP tool calls (not just function-calling) +- Apply harness logic to MCP server interactions +- Track MCP server latency/reliability as routing dimensions + +### Future: Additional Dimensions +- Carbon-aware routing with live grid carbon intensity data +- Data residency / compliance-aware model selection +- Custom business KPI plugins (user-defined scoring functions) 
+ +## 18. Go/No-Go Checklist + +Go when all are true (V2 Python launch): + +- [ ] Harness layer is opt-in and backward compatible +- [ ] `cascadeflow.init()` auto-instruments `openai` Python client +- [ ] `observe` mode produces zero behavior change (benchmark-validated) +- [ ] `enforce` mode actions work correctly (switch_model, deny_tool, stop) +- [ ] Harness decision overhead <5ms p95 +- [ ] Python parity fixture tests pass +- [ ] Core + integration CI green +- [ ] Benchmark comparison acceptable vs latest baseline +- [ ] OpenAI Agents SDK integration documented and validated +- [ ] CrewAI integration documented and validated +- [ ] LangChain integration extended and validated +- [ ] Existing integrations (Vercel AI, n8n) verified compatible (no regressions) +- [ ] DX quickstart works for existing app/agent users with 1-3 lines of code change +- [ ] External pilot median time-to-first-value <15 minutes +- [ ] Public benchmark results ready for launch +- [ ] Benchmark scripts + raw artifacts are reproducible by third parties +- [ ] pyproject.toml extras (`openai-agents`, `crewai`, `langchain`) defined and installable + +V2.1 Go/No-Go (TS parity + anthropic): +- [x] TS parity fixtures pass +- [x] `@cascadeflow/core` exports `cascadeflow.init()` and `cascadeflow.run()` +- [x] `anthropic` Python client auto-instrumentation validated +- [x] `@anthropic-ai/sdk` TS client auto-instrumentation validated + +## 19. 
Academic Validation + +Key research validating this approach (all 2025): + +- **BATS: Budget-Aware Tool-Use** (Google DeepMind) — Budget tracker module achieves comparable accuracy with 40% fewer tool calls, 31% cost reduction +- **Sherlock** (Microsoft) — Speculative execution + selective verification for agent workflows: 18% accuracy gain, 48% latency reduction +- **Speculative Actions** — Generalizes speculative decoding to agent actions (tool calls, API calls): up to 55% next-action prediction accuracy +- **Faster Cascades via Speculative Decoding** (Google, ICLR 2025) — Validates cascadeflow's core cascade architecture as a first-class research direction +- **GAR: Carbon-Aware Routing** — Treats carbon as first-class routing objective alongside accuracy and latency SLOs +- **Adaptive LLM Routing Under Budget Constraints** (EMNLP 2025) — Bandit-based routing outperforms static rules with no offline training +- **LLM Routing with Dueling Feedback** — Contextual dueling bandits for model routing with cost-performance balancing + +## 20. Documentation Strategy: Human + Agent Readable + +### Platform + +All public documentation hosted on [Mintlify](https://www.mintlify.com/). Mintlify provides structured, searchable, API-aware docs with built-in support for code examples, OpenAPI specs, and MDX. + +### Why This Matters Strategically + +In 2026+, agents increasingly select their own tooling. When a coding agent (Cursor, Claude Code, Devin, OpenHands) is asked to "build an agent that optimizes costs," it will search for and evaluate libraries. If cascadeflow's documentation is the clearest, most structured, and most machine-parseable option, agents will recommend and integrate cascadeflow by default. 
+ +**The documentation is a product surface for two audiences: human developers and AI agents.** + +Documentation must directly support the three winning pillars: +- fast activation (low-friction install), +- clear policy control patterns (in-loop KPI influence), +- transparent evidence (reproducible benchmark claims). + +### Design Principles + +**For humans:** +- 5-minute quickstart with copy-paste code (Tier 1: `init()` → instant savings) +- Progressive disclosure: simple first, advanced when ready +- Integration cookbooks per framework (LangGraph, OpenAI Agents SDK, CrewAI) +- Visual architecture diagrams +- Benchmark results with reproducible methodology +- Changelog and migration guides + +**For agents (LLM-readable):** +- Structured, consistent markdown with clear headings and hierarchy +- Every API surface documented with typed signatures, parameter descriptions, return types, and concrete examples +- `llms.txt` and/or `llms-full.txt` at docs root — a single-file summary optimized for LLM context windows (see llmstxt.org convention) +- OpenAPI spec for the proxy gateway endpoints +- Decision trees in explicit if/then format: "If you need cost optimization for an agent → use `cascadeflow.init()`. If you need budget enforcement → use `cascadeflow.run(budget=X)`." 
+- Comparison pages with structured tables: cascadeflow vs OpenRouter, vs NotDiamond, vs Portkey — in a format agents can parse and reason about +- Code examples that are complete and runnable (no `...` ellipsis, no "fill in your code here") +- Explicit "when to use" and "when NOT to use" sections — agents make better recommendations when they understand boundaries +- Capability matrix: what works with which framework, which harness mode, which language + +### Documentation Structure + +``` +docs/ +├── quickstart.mdx # 5-minute setup, Tier 1 (init) +├── concepts/ +│ ├── how-cascading-works.mdx # Core cascade explained +│ ├── harness-modes.mdx # off / observe / enforce +│ ├── hard-soft-controls.mdx # Budget, KPIs, compliance +│ └── actions.mdx # allow, switch_model, deny_tool, stop +├── guides/ +│ ├── budget-enforcement.mdx # Tier 2: run() context manager +│ ├── kpi-injection.mdx # Tier 3: @agent() decorator +│ ├── observe-first.mdx # Production rollout guide +│ └── benchmarks.mdx # How to reproduce our benchmarks +├── integrations/ +│ ├── openai-agents-sdk.mdx # Cookbook + full example +│ ├── crewai.mdx # Cookbook + full example +│ ├── langchain.mdx # Cookbook + full example +│ ├── vercel-ai-sdk.mdx # Existing, extended +│ ├── n8n.mdx # Existing, extended +│ └── litellm.mdx # Existing +├── api-reference/ +│ ├── python/ # Auto-generated from docstrings +│ └── typescript/ # Auto-generated from TSDoc +├── comparisons/ +│ ├── vs-openrouter.mdx # Structured comparison +│ ├── vs-notdiamond.mdx +│ ├── vs-portkey.mdx +│ └── vs-litellm.mdx +├── llms.txt # Single-file LLM-optimized summary +└── llms-full.txt # Complete API reference for agent context +``` + +### `llms.txt` Specification + +A concise, structured file at the docs root that gives any LLM/agent everything it needs to evaluate and use cascadeflow: + +``` +# cascadeflow + +> Agent intelligence harness for cost, quality, latency, and business KPI optimization. +> In-process library (not a proxy). 
Works inside agent loops with full state awareness. + +## Install +pip install cascadeflow + +## Quickstart (3 lines) +import cascadeflow +cascadeflow.init(mode="enforce") +# All openai SDK calls are now cascaded. 40-80% cost savings. + +## When to use cascadeflow +- You run AI agents (LangGraph, CrewAI, OpenAI Agents SDK, or custom) +- You want to reduce LLM costs without changing agent code +- You need budget enforcement across multi-step agent runs +- You need to inject business KPIs into agent decisions + +## When NOT to use cascadeflow +- Single one-off LLM calls (overhead not justified) +- You only use one model and don't want routing + +## Key APIs +- cascadeflow.init(mode) — activate harness globally +- cascadeflow.run(budget, max_tool_calls) — scoped agent run with budget +- @cascadeflow.agent(budget, kpis) — annotate agent functions + +## Integrations +- pip install cascadeflow[langchain] +- pip install cascadeflow[openai-agents] +- pip install cascadeflow[crewai] + +## Docs: https://docs.cascadeflow.ai +``` + +### Timeline + +Documentation is not a post-launch task. It ships with each phase: + +- Phase A: `llms.txt`, concepts pages, API reference stubs +- Phase B: Quickstart (observe mode), `llms-full.txt` +- Phase C: Budget enforcement guide, KPI injection guide +- Phase D: Integration cookbooks (one per framework) +- Phase E: Comparison pages, benchmark results, launch blog post + +## 21. 
Document Owners + +- Product strategy: cascadeflow maintainers +- Technical design owner: core/runtime maintainers +- Integration owners: per package maintainer (same pattern as existing integrations) +- Documentation: maintained alongside code — every PR that changes API must update docs diff --git a/examples/integrations/README.md b/examples/integrations/README.md index f8728e21..e7e7906a 100644 --- a/examples/integrations/README.md +++ b/examples/integrations/README.md @@ -5,6 +5,7 @@ This directory contains production-ready integration examples for cascadeflow wi ## 📋 Table of Contents - [LiteLLM Integration](#-litellm-integration) - Access 10+ providers with automatic cost tracking +- [OpenAI Agents SDK Integration](#-openai-agents-sdk-integration) - Harness-aware ModelProvider for existing agent apps - [Paygentic Integration](#-paygentic-integration) - Usage event reporting and billing lifecycle helpers - [Local Providers](#-local-providers-setup) - Ollama and vLLM configuration examples - [OpenTelemetry & Grafana](#-opentelemetry--grafana) - Production observability and metrics @@ -138,6 +139,27 @@ export HF_TOKEN="..." --- +## 🤖 OpenAI Agents SDK Integration + +**File:** [`openai_agents_harness.py`](openai_agents_harness.py) + +Use cascadeflow as an explicit `ModelProvider` integration in the OpenAI Agents SDK. 
+ +### Quick Start + +```bash +pip install "cascadeflow[openai,openai-agents]" +python examples/integrations/openai_agents_harness.py +``` + +### What It Shows + +- Harness-aware model switching with candidate models +- Tool gating when enforce-mode caps are reached +- Run-scoped metrics and trace inspection via `cascadeflow.run(...)` + +--- + ## 💳 Paygentic Integration **File:** [`paygentic_usage.py`](paygentic_usage.py) diff --git a/examples/integrations/openai_agents_harness.py b/examples/integrations/openai_agents_harness.py new file mode 100644 index 00000000..ac9d6c68 --- /dev/null +++ b/examples/integrations/openai_agents_harness.py @@ -0,0 +1,64 @@ +""" +OpenAI Agents SDK + cascadeflow harness integration example. + +Run: + pip install "cascadeflow[openai,openai-agents]" + python examples/integrations/openai_agents_harness.py +""" + +from __future__ import annotations + +import asyncio + + +async def main() -> None: + try: + from agents import Agent, RunConfig, Runner + except ImportError as exc: + raise SystemExit( + "OpenAI Agents SDK is not installed. 
" + 'Install with: pip install "cascadeflow[openai,openai-agents]"' + ) from exc + + from cascadeflow import init, run + from cascadeflow.integrations.openai_agents import ( + CascadeFlowModelProvider, + OpenAIAgentsIntegrationConfig, + ) + + init(mode="observe", budget=1.0, max_tool_calls=5) + + provider = CascadeFlowModelProvider( + config=OpenAIAgentsIntegrationConfig( + model_candidates=["gpt-4o", "gpt-4o-mini"], + enable_tool_gating=True, + ) + ) + + agent = Agent( + name="RouteAwareAgent", + instructions="Respond clearly and include a short reasoning summary.", + model="gpt-4o", + ) + + run_config = RunConfig(model_provider=provider) + + with run(budget=0.5, max_tool_calls=3) as session: + result = await Runner.run( + agent, "Summarize why model routing helps agent budgets.", run_config=run_config + ) + + print("=== Result ===") + print(result.final_output) + print("\n=== Harness Metrics ===") + print(f"Cost: ${session.cost:.6f}") + print(f"Remaining budget: {session.budget_remaining}") + print(f"Steps: {session.step_count}") + print(f"Tool calls: {session.tool_calls}") + print("\n=== Decision Trace ===") + for event in session.trace(): + print(event) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/packages/core/README.md b/packages/core/README.md index a0918d78..3188df91 100644 --- a/packages/core/README.md +++ b/packages/core/README.md @@ -33,6 +33,23 @@ pnpm add @cascadeflow/core yarn add @cascadeflow/core ``` +## Harness Quick Start (V2.1) + +```typescript +import { cascadeflow } from '@cascadeflow/core'; + +// 1) Turn on in-process harness decisions + SDK auto-instrumentation +cascadeflow.init({ mode: 'enforce', budget: 0.5 }); + +// 2) Scope one run (global defaults are inherited) +const result = await cascadeflow.run({ maxToolCalls: 8 }, async (run) => { + // Any OpenAI / Anthropic SDK calls made here are evaluated by the harness. 
+ return { runId: run.runId }; +}); + +console.log(result); +``` + ## Quick Start ### Recommended Setup (Claude Haiku + GPT-5) diff --git a/packages/core/src/__tests__/harness.test.ts b/packages/core/src/__tests__/harness.test.ts new file mode 100644 index 00000000..bad03376 --- /dev/null +++ b/packages/core/src/__tests__/harness.test.ts @@ -0,0 +1,232 @@ +import { afterEach, describe, expect, it } from 'vitest'; + +import { + BudgetExceededError, + cascadeflow, + getCurrentRun, + getHarnessConfig, + init, + reset, + run, +} from '../harness'; +import { + __resetInstrumentationLoadersForTest, + __resetInstrumentationStateForTest, + __setInstrumentationLoadersForTest, + isAnthropicPatched, + isOpenAIPatched, +} from '../harness-instrument'; + +class FakeOpenAICompletions { + constructor(private readonly calls: Array>) {} + + create(request: Record): Promise> { + this.calls.push({ ...request }); + return Promise.resolve({ + usage: { + prompt_tokens: 100, + completion_tokens: 25, + }, + choices: [ + { + message: { + tool_calls: [{ id: 'tool_1', type: 'function' }], + }, + }, + ], + }); + } +} + +class FakeAnthropicMessages { + constructor(private readonly calls: Array>) {} + + create(request: Record): Promise> { + this.calls.push({ ...request }); + return Promise.resolve({ + usage: { + input_tokens: 120, + output_tokens: 40, + }, + content: [ + { type: 'text', text: 'hello' }, + { type: 'tool_use', id: 'tool_1', name: 'search', input: { q: 'x' } }, + ], + }); + } +} + +afterEach(() => { + reset(); + __resetInstrumentationStateForTest(); + __resetInstrumentationLoadersForTest(); +}); + +describe('harness API (TypeScript parity)', () => { + it('exposes cascadeflow init/run object API', async () => { + expect(typeof cascadeflow.init).toBe('function'); + expect(typeof cascadeflow.run).toBe('function'); + + init({ mode: 'observe' }); + const value = await cascadeflow.run(async (scope) => { + expect(scope.mode).toBe('observe'); + expect(getCurrentRun()).toBe(scope); + 
return 42; + }); + + expect(value).toBe(42); + expect(getCurrentRun()).toBeNull(); + }); + + it('honors code > env precedence and preserves nested scope isolation', async () => { + const previousMode = process.env.CASCADEFLOW_HARNESS_MODE; + process.env.CASCADEFLOW_HARNESS_MODE = 'observe'; + + init(); + expect(getHarnessConfig().mode).toBe('observe'); + + init({ mode: 'enforce' }); + expect(getHarnessConfig().mode).toBe('enforce'); + + await run({ budget: 1.0 }, async (outer) => { + outer.cost = 0.1; + expect(outer.budgetMax).toBe(1.0); + expect(getCurrentRun()).toBe(outer); + + await run({ budget: 0.25 }, async (inner) => { + expect(getCurrentRun()).toBe(inner); + expect(inner.budgetMax).toBe(0.25); + inner.cost = 0.2; + }); + + expect(getCurrentRun()).toBe(outer); + expect(outer.budgetMax).toBe(1.0); + expect(outer.cost).toBe(0.1); + }); + + if (previousMode == null) { + delete process.env.CASCADEFLOW_HARNESS_MODE; + } else { + process.env.CASCADEFLOW_HARNESS_MODE = previousMode; + } + }); + + it('auto-instruments OpenAI and enforces switch_model decisions', async () => { + const openaiCalls: Array> = []; + + __setInstrumentationLoadersForTest({ + openai: () => ({ + Completions: FakeOpenAICompletions, + }), + anthropic: () => null, + }); + + init({ mode: 'enforce' }); + expect(isOpenAIPatched()).toBe(true); + + await run({ kpiWeights: { cost: 1 } }, async (scope) => { + const client = new FakeOpenAICompletions(openaiCalls); + await client.create({ + model: 'gpt-4o', + messages: [{ role: 'user', content: 'hi' }], + }); + + expect(scope.stepCount).toBe(1); + expect(scope.cost).toBeGreaterThan(0); + expect(scope.toolCalls).toBe(1); + + const trace = scope.trace(); + expect(trace).toHaveLength(1); + expect(trace[0]?.action).toBe('switch_model'); + expect(trace[0]?.applied).toBe(true); + expect(trace[0]?.decisionMode).toBe('enforce'); + }); + + expect(openaiCalls).toHaveLength(1); + expect(openaiCalls[0]?.model).not.toBe('gpt-4o'); + }); + + it('observe mode logs 
non-allow decisions without mutating request', async () => { + const openaiCalls: Array> = []; + + __setInstrumentationLoadersForTest({ + openai: () => ({ + Completions: FakeOpenAICompletions, + }), + anthropic: () => null, + }); + + init({ mode: 'observe' }); + + await run({ kpiWeights: { cost: 1 } }, async (scope) => { + const client = new FakeOpenAICompletions(openaiCalls); + await client.create({ + model: 'gpt-4o', + messages: [{ role: 'user', content: 'hi' }], + }); + + const trace = scope.trace(); + expect(trace).toHaveLength(1); + expect(trace[0]?.action).toBe('switch_model'); + expect(trace[0]?.applied).toBe(false); + expect(trace[0]?.decisionMode).toBe('observe'); + }); + + expect(openaiCalls).toHaveLength(1); + expect(openaiCalls[0]?.model).toBe('gpt-4o'); + }); + + it('enforce mode stops calls when budget is exhausted', async () => { + const openaiCalls: Array> = []; + + __setInstrumentationLoadersForTest({ + openai: () => ({ + Completions: FakeOpenAICompletions, + }), + anthropic: () => null, + }); + + init({ mode: 'enforce' }); + + await expect( + run({ budget: 0 }, async () => { + const client = new FakeOpenAICompletions(openaiCalls); + await client.create({ + model: 'gpt-4o', + messages: [{ role: 'user', content: 'hi' }], + }); + }), + ).rejects.toBeInstanceOf(BudgetExceededError); + + expect(openaiCalls).toHaveLength(0); + }); + + it('auto-instruments Anthropic and tracks usage/tool calls', async () => { + const anthropicCalls: Array> = []; + + __setInstrumentationLoadersForTest({ + openai: () => null, + anthropic: () => ({ + Messages: FakeAnthropicMessages, + }), + }); + + init({ mode: 'enforce' }); + expect(isAnthropicPatched()).toBe(true); + + await run(async (scope) => { + const client = new FakeAnthropicMessages(anthropicCalls); + await client.create({ + model: 'claude-sonnet-4-5-20250929', + messages: [{ role: 'user', content: 'hello' }], + }); + + expect(scope.stepCount).toBe(1); + expect(scope.toolCalls).toBe(1); + 
expect(scope.cost).toBeGreaterThan(0); + expect(scope.trace()[0]?.action).toBe('allow'); + }); + + expect(anthropicCalls).toHaveLength(1); + }); +}); diff --git a/packages/core/src/harness-instrument.ts b/packages/core/src/harness-instrument.ts new file mode 100644 index 00000000..901af4ae --- /dev/null +++ b/packages/core/src/harness-instrument.ts @@ -0,0 +1,746 @@ +type Action = 'allow' | 'switch_model' | 'deny_tool' | 'stop'; + +type CreateFunction = (this: any, ...args: any[]) => any; + +type OpenAIModuleLike = { + Completions?: { + prototype?: { + create?: CreateFunction; + }; + }; +}; + +type AnthropicModuleLike = { + Messages?: { + prototype?: { + create?: CreateFunction; + }; + }; +}; + +type Pricing = { input: number; output: number }; + +type PreCallDecision = { + action: Action; + reason: string; + targetModel: string; +}; + +type HarnessRuntime = { + getCurrentRun: () => HarnessRunContextLike | null; + getHarnessMode: () => HarnessModeLike; + createBudgetExceededError: (message: string, remaining?: number) => Error; + createHarnessStopError: (message: string, reason?: string) => Error; +}; + +type HarnessModeLike = 'off' | 'observe' | 'enforce'; + +type HarnessRunContextLike = { + mode: HarnessModeLike; + cost: number; + stepCount: number; + toolCalls: number; + latencyUsedMs: number; + energyUsed: number; + budgetMax?: number; + budgetRemaining?: number; + toolCallsMax?: number; + latencyMaxMs?: number; + energyMax?: number; + compliance?: string; + kpiWeights?: Record; + record: ( + action: string, + reason: string, + model?: string, + options?: { + applied?: boolean; + decisionMode?: HarnessModeLike; + }, + ) => void; +}; + +const MODEL_PRICING_PER_MILLION: Record = { + // OpenAI + 'gpt-5': { input: 1.25, output: 10.0 }, + 'gpt-5-mini': { input: 0.25, output: 2.0 }, + 'gpt-5-nano': { input: 0.05, output: 0.4 }, + 'gpt-4o': { input: 2.5, output: 10.0 }, + 'gpt-4o-mini': { input: 0.15, output: 0.6 }, + 'o1': { input: 15.0, output: 60.0 }, + 'o1-mini': 
{ input: 3.0, output: 12.0 }, + 'o3-mini': { input: 1.0, output: 5.0 }, + + // Anthropic + 'claude-opus-4-5-20251101': { input: 15.0, output: 75.0 }, + 'claude-opus-4-20250514': { input: 15.0, output: 75.0 }, + 'claude-sonnet-4-5-20250929': { input: 3.0, output: 15.0 }, + 'claude-sonnet-4-20250514': { input: 3.0, output: 15.0 }, + 'claude-haiku-4-5-20251001': { input: 1.0, output: 5.0 }, + 'claude-3-5-haiku-20241022': { input: 1.0, output: 5.0 }, +}; + +const ENERGY_COEFFICIENTS: Record = { + 'gpt-5': 1.15, + 'gpt-5-mini': 0.72, + 'gpt-5-nano': 0.45, + 'gpt-4o': 1.0, + 'gpt-4o-mini': 0.55, + 'o1': 1.25, + 'o1-mini': 0.85, + 'o3-mini': 0.75, + 'claude-opus-4-5-20251101': 1.2, + 'claude-opus-4-20250514': 1.15, + 'claude-sonnet-4-5-20250929': 0.95, + 'claude-sonnet-4-20250514': 0.92, + 'claude-haiku-4-5-20251001': 0.7, + 'claude-3-5-haiku-20241022': 0.68, +}; + +const LATENCY_PRIORS: Record = { + 'gpt-5': 0.45, + 'gpt-5-mini': 0.72, + 'gpt-5-nano': 0.9, + 'gpt-4o': 0.58, + 'gpt-4o-mini': 0.82, + 'o1': 0.35, + 'o1-mini': 0.62, + 'o3-mini': 0.7, + 'claude-opus-4-5-20251101': 0.4, + 'claude-opus-4-20250514': 0.44, + 'claude-sonnet-4-5-20250929': 0.6, + 'claude-sonnet-4-20250514': 0.63, + 'claude-haiku-4-5-20251001': 0.85, + 'claude-3-5-haiku-20241022': 0.86, +}; + +const QUALITY_PRIORS: Record = { + 'gpt-5': 0.95, + 'gpt-5-mini': 0.86, + 'gpt-5-nano': 0.74, + 'gpt-4o': 0.9, + 'gpt-4o-mini': 0.82, + 'o1': 0.93, + 'o1-mini': 0.84, + 'o3-mini': 0.86, + 'claude-opus-4-5-20251101': 0.94, + 'claude-opus-4-20250514': 0.92, + 'claude-sonnet-4-5-20250929': 0.9, + 'claude-sonnet-4-20250514': 0.88, + 'claude-haiku-4-5-20251001': 0.82, + 'claude-3-5-haiku-20241022': 0.8, +}; + +const COMPLIANCE_ALLOWLISTS: Record> = { + strict: new Set(['gpt-4o', 'gpt-4o-mini', 'claude-sonnet-4-5-20250929', 'claude-haiku-4-5-20251001']), + regulated: new Set(['gpt-4o', 'claude-sonnet-4-5-20250929']), +}; + +const DEFAULT_ENERGY_COEFFICIENT = 0.9; +const DEFAULT_OUTPUT_WEIGHT = 1.5; + +const 
PRICING_MODELS = Object.keys(MODEL_PRICING_PER_MILLION); + +let openAIPatched = false; +let anthropicPatched = false; + +let originalOpenAICreate: CreateFunction | null = null; +let originalAnthropicCreate: CreateFunction | null = null; +let patchedOpenAIClass: { prototype?: { create?: CreateFunction } } | null = null; +let patchedAnthropicClass: { prototype?: { create?: CreateFunction } } | null = null; + +const defaultOpenAILoader = (): OpenAIModuleLike | null => { + try { + // eslint-disable-next-line @typescript-eslint/no-var-requires + return require('openai/resources/chat/completions') as OpenAIModuleLike; + } catch { + return null; + } +}; + +const defaultAnthropicLoader = (): AnthropicModuleLike | null => { + try { + // eslint-disable-next-line @typescript-eslint/no-var-requires + return require('@anthropic-ai/sdk/resources/messages') as AnthropicModuleLike; + } catch { + return null; + } +}; + +let loadOpenAIModule = defaultOpenAILoader; +let loadAnthropicModule = defaultAnthropicLoader; +let harnessRuntimeBindings: HarnessRuntime | null = null; + +function getHarnessRuntime(): HarnessRuntime { + if (!harnessRuntimeBindings) { + throw new Error('Harness runtime bindings not configured'); + } + return harnessRuntimeBindings; +} + +export function setHarnessRuntimeBindingsForInstrumentation(bindings: HarnessRuntime): void { + harnessRuntimeBindings = bindings; +} + +function nowMonotonicMs(): number { + // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition + if (typeof globalThis !== 'undefined' && (globalThis as any).performance?.now) { + return (globalThis as any).performance.now() as number; + } + + // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition + if (typeof process !== 'undefined' && process.hrtime?.bigint) { + return Number(process.hrtime.bigint()) / 1_000_000; + } + + return Date.now(); +} + +function normalizeModelName(model: string): string { + return model.trim().toLowerCase(); +} + +function 
estimateCost(model: string, promptTokens: number, completionTokens: number): number { + const price = MODEL_PRICING_PER_MILLION[normalizeModelName(model)]; + if (!price) { + return 0; + } + + return (promptTokens / 1_000_000) * price.input + (completionTokens / 1_000_000) * price.output; +} + +function estimateEnergy(model: string, promptTokens: number, completionTokens: number): number { + const coefficient = ENERGY_COEFFICIENTS[normalizeModelName(model)] ?? DEFAULT_ENERGY_COEFFICIENT; + return coefficient * (promptTokens + completionTokens * DEFAULT_OUTPUT_WEIGHT) / 1000; +} + +function modelTotalCost(model: string): number { + const price = MODEL_PRICING_PER_MILLION[normalizeModelName(model)]; + if (!price) { + return Number.POSITIVE_INFINITY; + } + return price.input + price.output; +} + +function selectCheaperModel(currentModel: string): string { + const currentCost = modelTotalCost(currentModel); + let bestModel = currentModel; + let bestCost = currentCost; + + for (const candidate of PRICING_MODELS) { + const candidateCost = modelTotalCost(candidate); + if (candidateCost < bestCost) { + bestModel = candidate; + bestCost = candidateCost; + } + } + + return bestModel; +} + +function selectLowerEnergyModel(currentModel: string): string { + const currentCoeff = ENERGY_COEFFICIENTS[normalizeModelName(currentModel)] ?? DEFAULT_ENERGY_COEFFICIENT; + let bestModel = currentModel; + let bestCoeff = currentCoeff; + + for (const candidate of PRICING_MODELS) { + const coeff = ENERGY_COEFFICIENTS[candidate] ?? DEFAULT_ENERGY_COEFFICIENT; + if (coeff < bestCoeff) { + bestModel = candidate; + bestCoeff = coeff; + } + } + + return bestModel; +} + +function selectFasterModel(currentModel: string): string { + const currentLatency = LATENCY_PRIORS[normalizeModelName(currentModel)] ?? 0.7; + let bestModel = currentModel; + let bestLatency = currentLatency; + + for (const candidate of PRICING_MODELS) { + const score = LATENCY_PRIORS[candidate] ?? 
0.7; + if (score > bestLatency) { + bestModel = candidate; + bestLatency = score; + } + } + + return bestModel; +} + +function normalizeWeights(weights: Record): Record { + const normalized: Record = {}; + let total = 0; + + for (const [key, value] of Object.entries(weights)) { + if (!Number.isFinite(value) || value <= 0) { + continue; + } + normalized[key] = value; + total += value; + } + + if (total <= 0) { + return {}; + } + + for (const key of Object.keys(normalized)) { + normalized[key] /= total; + } + + return normalized; +} + +function costUtility(model: string): number { + const costs = PRICING_MODELS.map(modelTotalCost).filter(Number.isFinite); + const min = Math.min(...costs); + const max = Math.max(...costs); + const current = modelTotalCost(model); + + if (!Number.isFinite(current) || max === min) { + return 0.5; + } + + return (max - current) / (max - min); +} + +function energyUtility(model: string): number { + const coeffs = PRICING_MODELS.map((name) => ENERGY_COEFFICIENTS[name] ?? DEFAULT_ENERGY_COEFFICIENT); + const min = Math.min(...coeffs); + const max = Math.max(...coeffs); + const current = ENERGY_COEFFICIENTS[normalizeModelName(model)] ?? DEFAULT_ENERGY_COEFFICIENT; + + if (max === min) { + return 0.5; + } + + return (max - current) / (max - min); +} + +function kpiScore(model: string, weights: Record): number { + const normalized = normalizeWeights(weights); + if (Object.keys(normalized).length === 0) { + return 0; + } + + const key = normalizeModelName(model); + const quality = QUALITY_PRIORS[key] ?? 0.7; + const latency = LATENCY_PRIORS[key] ?? 0.7; + const cost = costUtility(key); + const energy = energyUtility(key); + + return ( + (normalized.quality ?? 0) * quality + + (normalized.latency ?? 0) * latency + + (normalized.cost ?? 0) * cost + + (normalized.energy ?? 
0) * energy + ); +} + +function selectKPIWeightedModel(currentModel: string, weights: Record): string { + const normalized = normalizeWeights(weights); + if (Object.keys(normalized).length === 0) { + return currentModel; + } + + let bestModel = currentModel; + let bestScore = kpiScore(currentModel, normalized); + + for (const candidate of PRICING_MODELS) { + const score = kpiScore(candidate, normalized); + if (score > bestScore) { + bestModel = candidate; + bestScore = score; + } + } + + return bestModel; +} + +function extractOpenAIUsage(response: any): [number, number] { + const usage = response?.usage; + if (!usage || typeof usage !== 'object') { + return [0, 0]; + } + const promptTokens = Number(usage.prompt_tokens ?? usage.input_tokens ?? 0); + const completionTokens = Number(usage.completion_tokens ?? usage.output_tokens ?? 0); + return [ + Number.isFinite(promptTokens) ? promptTokens : 0, + Number.isFinite(completionTokens) ? completionTokens : 0, + ]; +} + +function extractAnthropicUsage(response: any): [number, number] { + const usage = response?.usage; + if (!usage || typeof usage !== 'object') { + return [0, 0]; + } + + const inputTokens = Number(usage.input_tokens ?? usage.prompt_tokens ?? 0); + const outputTokens = Number(usage.output_tokens ?? usage.completion_tokens ?? 0); + return [ + Number.isFinite(inputTokens) ? inputTokens : 0, + Number.isFinite(outputTokens) ? 
outputTokens : 0, + ]; +} + +function countOpenAIToolCalls(response: any): number { + const toolCalls = response?.choices?.[0]?.message?.tool_calls; + if (!Array.isArray(toolCalls)) { + return 0; + } + return toolCalls.length; +} + +function countAnthropicToolCalls(response: any): number { + const content = response?.content; + if (!Array.isArray(content)) { + return 0; + } + return content.filter((item: any) => item?.type === 'tool_use').length; +} + +function evaluatePreCallDecision(ctx: HarnessRunContextLike, model: string, hasTools: boolean): PreCallDecision { + if (ctx.budgetMax != null && ctx.cost >= ctx.budgetMax) { + return { action: 'stop', reason: 'budget_exceeded', targetModel: model }; + } + + if (hasTools && ctx.toolCallsMax != null && ctx.toolCalls >= ctx.toolCallsMax) { + return { action: 'deny_tool', reason: 'max_tool_calls_reached', targetModel: model }; + } + + if (ctx.compliance) { + const profile = COMPLIANCE_ALLOWLISTS[ctx.compliance.trim().toLowerCase()]; + if (profile) { + const normalized = normalizeModelName(model); + if (!profile.has(normalized)) { + const next = PRICING_MODELS.find((candidate) => profile.has(candidate)); + if (next) { + return { action: 'switch_model', reason: 'compliance_model_policy', targetModel: next }; + } + return { + action: hasTools ? 'deny_tool' : 'stop', + reason: hasTools ? 
'compliance_no_approved_tool_path' : 'compliance_no_approved_model', + targetModel: model, + }; + } + if (ctx.compliance.trim().toLowerCase() === 'strict' && hasTools) { + return { action: 'deny_tool', reason: 'compliance_tool_restriction', targetModel: model }; + } + } + } + + if (ctx.latencyMaxMs != null && ctx.latencyUsedMs >= ctx.latencyMaxMs) { + const faster = selectFasterModel(model); + if (normalizeModelName(faster) !== normalizeModelName(model)) { + return { action: 'switch_model', reason: 'latency_limit_exceeded', targetModel: faster }; + } + return { action: 'stop', reason: 'latency_limit_exceeded', targetModel: model }; + } + + if (ctx.energyMax != null && ctx.energyUsed >= ctx.energyMax) { + const lower = selectLowerEnergyModel(model); + if (normalizeModelName(lower) !== normalizeModelName(model)) { + return { action: 'switch_model', reason: 'energy_limit_exceeded', targetModel: lower }; + } + return { action: 'stop', reason: 'energy_limit_exceeded', targetModel: model }; + } + + if ( + ctx.budgetMax != null + && ctx.budgetMax > 0 + && ctx.budgetRemaining != null + && (ctx.budgetRemaining / ctx.budgetMax) < 0.2 + ) { + const cheaper = selectCheaperModel(model); + if (normalizeModelName(cheaper) !== normalizeModelName(model)) { + return { action: 'switch_model', reason: 'budget_pressure', targetModel: cheaper }; + } + } + + if (ctx.kpiWeights && Object.keys(ctx.kpiWeights).length > 0) { + const candidate = selectKPIWeightedModel(model, ctx.kpiWeights); + if (normalizeModelName(candidate) !== normalizeModelName(model)) { + return { action: 'switch_model', reason: 'kpi_weight_optimization', targetModel: candidate }; + } + } + + return { action: 'allow', reason: ctx.mode, targetModel: model }; +} + +function raiseStopError(ctx: HarnessRunContextLike, reason: string): never { + const runtime = getHarnessRuntime(); + if (reason === 'budget_exceeded') { + const remaining = Math.max(0, (ctx.budgetMax ?? 
0) - ctx.cost); + throw runtime.createBudgetExceededError( + `Budget exhausted: spent $${ctx.cost.toFixed(4)} of $${(ctx.budgetMax ?? 0).toFixed(4)} max`, + remaining, + ); + } + + throw runtime.createHarnessStopError(`cascadeflow harness stop: ${reason}`, reason); +} + +function updateContext( + ctx: HarnessRunContextLike, + mode: HarnessModeLike, + model: string, + promptTokens: number, + completionTokens: number, + toolCalls: number, + elapsedMs: number, + decision: PreCallDecision, + applied: boolean, +): void { + const cost = estimateCost(model, promptTokens, completionTokens); + const energy = estimateEnergy(model, promptTokens, completionTokens); + + ctx.cost += cost; + ctx.stepCount += 1; + ctx.toolCalls += toolCalls; + ctx.latencyUsedMs += elapsedMs; + ctx.energyUsed += energy; + + if (ctx.budgetMax != null) { + ctx.budgetRemaining = ctx.budgetMax - ctx.cost; + } + + ctx.record(decision.action, decision.reason, decision.targetModel, { + applied, + decisionMode: mode, + }); +} + +function isThenable(value: any): value is Promise { + return Boolean(value) && typeof value.then === 'function'; +} + +function makePatchedCreate(provider: 'openai' | 'anthropic', original: CreateFunction): CreateFunction { + return function patchedCreate(this: any, ...args: any[]): any { + const runtime = getHarnessRuntime(); + const activeRun = runtime.getCurrentRun(); + const mode = activeRun?.mode ?? runtime.getHarnessMode(); + + if (mode === 'off') { + return original.apply(this, args); + } + + const firstArg = args[0]; + const request = firstArg && typeof firstArg === 'object' ? { ...firstArg } : {}; + const model = typeof request.model === 'string' ? request.model : 'unknown'; + const hasTools = Array.isArray(request.tools) && request.tools.length > 0; + + const decision = activeRun ? 
evaluatePreCallDecision(activeRun, model, hasTools) : { + action: 'allow' as const, + reason: mode, + targetModel: model, + }; + + let applied = decision.action === 'allow'; + let effectiveModel = model; + + if (activeRun && mode === 'enforce') { + if (decision.action === 'stop') { + activeRun.record('stop', decision.reason, model, { + applied: true, + decisionMode: mode, + }); + raiseStopError(activeRun, decision.reason); + } + + if (decision.action === 'switch_model') { + if (normalizeModelName(decision.targetModel) !== normalizeModelName(model)) { + request.model = decision.targetModel; + effectiveModel = decision.targetModel; + applied = true; + } else { + applied = false; + } + } + + if (decision.action === 'deny_tool') { + if (Array.isArray(request.tools) && request.tools.length > 0) { + request.tools = []; + applied = true; + } else { + applied = false; + } + } + } else if (decision.action !== 'allow') { + applied = false; + } + + const interceptedArgs = firstArg && typeof firstArg === 'object' + ? 
[request, ...args.slice(1)] + : args; + + const isStream = Boolean(request.stream); + const startedAt = nowMonotonicMs(); + const result = original.apply(this, interceptedArgs); + + if (!activeRun) { + return result; + } + + const finalize = (response: any): any => { + const elapsedMs = Math.max(0, nowMonotonicMs() - startedAt); + + let promptTokens = 0; + let completionTokens = 0; + let toolCallCount = 0; + + if (!isStream) { + if (provider === 'openai') { + [promptTokens, completionTokens] = extractOpenAIUsage(response); + toolCallCount = countOpenAIToolCalls(response); + } else { + [promptTokens, completionTokens] = extractAnthropicUsage(response); + toolCallCount = countAnthropicToolCalls(response); + } + } + + updateContext( + activeRun, + mode, + effectiveModel, + promptTokens, + completionTokens, + toolCallCount, + elapsedMs, + decision, + applied, + ); + + return response; + }; + + if (isThenable(result)) { + result + .then((response) => { + finalize(response); + }) + .catch(() => { + // fail-open: harness instrumentation errors must not crash user flow. 
+ }); + return result; + } + + return finalize(result); + }; +} + +export function detectOpenAIInstrumentationTarget(): boolean { + const module = loadOpenAIModule(); + return Boolean(module?.Completions?.prototype?.create); +} + +export function detectAnthropicInstrumentationTarget(): boolean { + const module = loadAnthropicModule(); + return Boolean(module?.Messages?.prototype?.create); +} + +export function patchOpenAI(): boolean { + if (openAIPatched) { + return true; + } + + const module = loadOpenAIModule(); + const cls = module?.Completions; + const prototype = cls?.prototype; + const create = prototype?.create; + + if (!cls || !prototype || typeof create !== 'function') { + return false; + } + + originalOpenAICreate = create; + patchedOpenAIClass = cls; + prototype.create = makePatchedCreate('openai', create); + openAIPatched = true; + return true; +} + +export function patchAnthropic(): boolean { + if (anthropicPatched) { + return true; + } + + const module = loadAnthropicModule(); + const cls = module?.Messages; + const prototype = cls?.prototype; + const create = prototype?.create; + + if (!cls || !prototype || typeof create !== 'function') { + return false; + } + + originalAnthropicCreate = create; + patchedAnthropicClass = cls; + prototype.create = makePatchedCreate('anthropic', create); + anthropicPatched = true; + return true; +} + +export function unpatchOpenAI(): void { + if (!openAIPatched) { + return; + } + + if (patchedOpenAIClass?.prototype && originalOpenAICreate) { + patchedOpenAIClass.prototype.create = originalOpenAICreate; + } + + openAIPatched = false; + originalOpenAICreate = null; + patchedOpenAIClass = null; +} + +export function unpatchAnthropic(): void { + if (!anthropicPatched) { + return; + } + + if (patchedAnthropicClass?.prototype && originalAnthropicCreate) { + patchedAnthropicClass.prototype.create = originalAnthropicCreate; + } + + anthropicPatched = false; + originalAnthropicCreate = null; + patchedAnthropicClass = null; +} + 
+export function isOpenAIPatched(): boolean {
+  return openAIPatched;
+}
+
+export function isAnthropicPatched(): boolean {
+  return anthropicPatched;
+}
+
+export function isPatched(): boolean { // true when either SDK is currently instrumented
+  return openAIPatched || anthropicPatched;
+}
+
+export function __setInstrumentationLoadersForTest(loaders: {
+  openai?: () => OpenAIModuleLike | null;
+  anthropic?: () => AnthropicModuleLike | null;
+}): void {
+  if (loaders.openai) {
+    loadOpenAIModule = loaders.openai;
+  }
+  if (loaders.anthropic) {
+    loadAnthropicModule = loaders.anthropic;
+  }
+}
+
+export function __resetInstrumentationLoadersForTest(): void {
+  loadOpenAIModule = defaultOpenAILoader;
+  loadAnthropicModule = defaultAnthropicLoader;
+}
+
+export function __resetInstrumentationStateForTest(): void {
+  unpatchOpenAI();
+  unpatchAnthropic();
+}
diff --git a/packages/core/src/harness.ts b/packages/core/src/harness.ts
new file mode 100644
index 00000000..3815360e
--- /dev/null
+++ b/packages/core/src/harness.ts
@@ -0,0 +1,754 @@
+import {
+  __resetInstrumentationStateForTest,
+  detectAnthropicInstrumentationTarget,
+  detectOpenAIInstrumentationTarget,
+  patchAnthropic,
+  patchOpenAI,
+  setHarnessRuntimeBindingsForInstrumentation,
+  unpatchAnthropic,
+  unpatchOpenAI,
+} from './harness-instrument';
+
+export type HarnessMode = 'off' | 'observe' | 'enforce';
+
+export type HarnessConfig = {
+  mode: HarnessMode;
+  verbose: boolean;
+  budget?: number;
+  maxToolCalls?: number;
+  maxLatencyMs?: number;
+  maxEnergy?: number;
+  kpiTargets?: Record<string, number>; // FIX: stripped generic restored
+  kpiWeights?: Record<string, number>; // FIX: stripped generic restored
+  compliance?: string;
+};
+
+export type HarnessInitOptions = Partial<HarnessConfig>; // FIX: stripped generic restored
+
+export type HarnessRunOptions = {
+  budget?: number;
+  maxToolCalls?: number;
+  maxLatencyMs?: number;
+  maxEnergy?: number;
+  kpiTargets?: Record<string, number>; // FIX: stripped generic restored
+  kpiWeights?: Record<string, number>; // FIX: stripped generic restored
+  compliance?: string;
+};
+
+export type HarnessInitReport = {
+  mode: HarnessMode;
+  instrumented: string[];
+  detectedButNotInstrumented: string[];
+  configSources: Record<string, string>; // FIX: stripped generic restored
+};
+
+export type
 HarnessRecordOptions = {
+  applied?: boolean;
+  decisionMode?: HarnessMode;
+};
+
+export type HarnessTraceEntry = {
+  action: string;
+  reason: string;
+  model?: string;
+  runId: string;
+  mode: HarnessMode;
+  step: number;
+  timestampMs: number;
+  toolCallsTotal: number; // running totals at the time the entry was recorded
+  costTotal: number;
+  latencyUsedMs: number;
+  energyUsed: number;
+  budgetState: {
+    max?: number;
+    remaining?: number;
+  };
+  applied?: boolean;
+  decisionMode?: HarnessMode;
+};
+
+export type HarnessRunSummary = {
+  runId: string;
+  mode: HarnessMode;
+  stepCount: number;
+  toolCalls: number;
+  cost: number;
+  savings: number;
+  latencyUsedMs: number;
+  energyUsed: number;
+  budgetMax?: number;
+  budgetRemaining?: number;
+  lastAction: string;
+  modelUsed?: string;
+  durationMs?: number;
+};
+
+export class HarnessStopError extends Error { // base class for harness-initiated aborts
+  reason: string;
+
+  constructor(message: string, reason = 'stop') {
+    super(message);
+    this.name = 'HarnessStopError';
+    this.reason = reason;
+  }
+}
+
+export class BudgetExceededError extends HarnessStopError {
+  remaining: number;
+
+  constructor(message: string, remaining = 0) {
+    super(message, 'budget_exceeded');
+    this.name = 'BudgetExceededError';
+    this.remaining = remaining;
+  }
+}
+
+function randomRunId(): string { // 12 base-36 chars; not cryptographically strong
+  return Math.random().toString(36).slice(2, 14);
+}
+
+function nowMonotonicMs(): number { // best monotonic clock available; falls back to wall clock
+  // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
+  if (typeof globalThis !== 'undefined' && (globalThis as any).performance?.now) {
+    return (globalThis as any).performance.now() as number;
+  }
+
+  // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
+  if (typeof process !== 'undefined' && process.hrtime?.bigint) {
+    return Number(process.hrtime.bigint()) / 1_000_000;
+  }
+
+  return Date.now();
+}
+
+const MAX_ACTION_LEN = 64;
+const MAX_REASON_LEN = 160;
+const MAX_MODEL_LEN = 128;
+
+function sanitizeTraceValue(value: unknown, maxLength: number): string | undefined {
+  if (value == null) {
return undefined; + } + + const text = String(value).replace(/\r?\n/g, ' ').trim(); + if (!text) { + return undefined; + } + + if (text.length <= maxLength) { + return text; + } + + return `${text.slice(0, Math.max(0, maxLength - 3))}...`; +} + +export class HarnessRunContext { + runId: string; + startedAtMs: number; + endedAtMs?: number; + durationMs?: number; + + mode: HarnessMode; + budgetMax?: number; + toolCallsMax?: number; + latencyMaxMs?: number; + energyMax?: number; + kpiTargets?: Record; + kpiWeights?: Record; + compliance?: string; + + cost = 0; + savings = 0; + toolCalls = 0; + stepCount = 0; + latencyUsedMs = 0; + energyUsed = 0; + verbose = false; + budgetRemaining?: number; + modelUsed?: string; + lastAction = 'allow'; + draftAccepted?: boolean; + + private readonly _startedMonotonic: number; + private readonly _trace: HarnessTraceEntry[] = []; + private _finalized = false; + + constructor(config: { + mode: HarnessMode; + budgetMax?: number; + toolCallsMax?: number; + latencyMaxMs?: number; + energyMax?: number; + kpiTargets?: Record; + kpiWeights?: Record; + compliance?: string; + verbose?: boolean; + }) { + this.runId = randomRunId(); + this.startedAtMs = Date.now(); + this._startedMonotonic = nowMonotonicMs(); + + this.mode = config.mode; + this.budgetMax = config.budgetMax; + this.toolCallsMax = config.toolCallsMax; + this.latencyMaxMs = config.latencyMaxMs; + this.energyMax = config.energyMax; + this.kpiTargets = config.kpiTargets; + this.kpiWeights = config.kpiWeights; + this.compliance = config.compliance; + this.verbose = Boolean(config.verbose); + + if (config.budgetMax != null) { + this.budgetRemaining = config.budgetMax; + } + } + + finish(): void { + if (this._finalized) { + return; + } + + this._finalized = true; + this.endedAtMs = Date.now(); + this.durationMs = Math.max(0, nowMonotonicMs() - this._startedMonotonic); + + if (this.verbose && this.mode !== 'off' && this.stepCount > 0) { + // Keep logging cheap and controlled. 
+      // eslint-disable-next-line no-console
+      console.info(
+        '[cascadeflow.harness] run summary',
+        {
+          runId: this.runId,
+          mode: this.mode,
+          steps: this.stepCount,
+          toolCalls: this.toolCalls,
+          cost: this.cost,
+          latencyMs: this.latencyUsedMs,
+          energy: this.energyUsed,
+          lastAction: this.lastAction,
+          model: this.modelUsed,
+          budgetRemaining: this.budgetRemaining,
+          durationMs: this.durationMs,
+        },
+      );
+    }
+  }
+
+  record(action: string, reason: string, model?: string, options: HarnessRecordOptions = {}): void { // append one sanitized trace entry
+    let safeAction = sanitizeTraceValue(action, MAX_ACTION_LEN);
+    if (!safeAction) {
+      safeAction = 'allow';
+    }
+
+    const safeReason = sanitizeTraceValue(reason, MAX_REASON_LEN) ?? 'unspecified';
+    const safeModel = sanitizeTraceValue(model, MAX_MODEL_LEN);
+
+    this.lastAction = safeAction;
+    this.modelUsed = safeModel; // NOTE(review): clears modelUsed when model is omitted — confirm that reset is intended
+
+    const entry: HarnessTraceEntry = {
+      action: safeAction,
+      reason: safeReason,
+      model: safeModel,
+      runId: this.runId,
+      mode: this.mode,
+      step: this.stepCount,
+      timestampMs: Date.now(),
+      toolCallsTotal: this.toolCalls,
+      costTotal: this.cost,
+      latencyUsedMs: this.latencyUsedMs,
+      energyUsed: this.energyUsed,
+      budgetState: {
+        max: this.budgetMax,
+        remaining: this.budgetRemaining,
+      },
+    };
+
+    if (options.applied != null) {
+      entry.applied = options.applied;
+    }
+
+    if (options.decisionMode != null) {
+      entry.decisionMode = options.decisionMode;
+    }
+
+    this._trace.push(entry);
+  }
+
+  trace(): HarnessTraceEntry[] { // defensive copy of the trace
+    return [...this._trace];
+  }
+
+  summary(): HarnessRunSummary {
+    return {
+      runId: this.runId,
+      mode: this.mode,
+      stepCount: this.stepCount,
+      toolCalls: this.toolCalls,
+      cost: this.cost,
+      savings: this.savings,
+      latencyUsedMs: this.latencyUsedMs,
+      energyUsed: this.energyUsed,
+      budgetMax: this.budgetMax,
+      budgetRemaining: this.budgetRemaining,
+      lastAction: this.lastAction,
+      modelUsed: this.modelUsed,
+      durationMs: this.durationMs,
+    };
+  }
+}
+
+type ConfigSource = 'code' | 'env' | 'file' | 'default';
+ +type ConfigWithSources = { + config: HarnessConfig; + sources: Record; +}; + +let _harnessConfig: HarnessConfig = { + mode: 'off', + verbose: false, +}; + +let _isInstrumented = false; +let fallbackCurrentRun: HarnessRunContext | null = null; + +let asyncLocalStorageInstance: { run: (store: HarnessRunContext, callback: () => Promise) => Promise; getStore: () => HarnessRunContext | undefined } | null = null; + +function getAsyncLocalStorage(): typeof asyncLocalStorageInstance { + if (asyncLocalStorageInstance) { + return asyncLocalStorageInstance; + } + + try { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const mod = require('node:async_hooks') as { + AsyncLocalStorage: new () => { run: (store: T, callback: () => Promise) => Promise; getStore: () => T | undefined }; + }; + + asyncLocalStorageInstance = new mod.AsyncLocalStorage(); + } catch { + asyncLocalStorageInstance = null; + } + + return asyncLocalStorageInstance; +} + +function parseBoolean(raw: string): boolean { + const normalized = raw.trim().toLowerCase(); + return normalized === '1' || normalized === 'true' || normalized === 'yes' || normalized === 'on'; +} + +function parseNumber(raw: string): number { + const value = Number(raw); + if (!Number.isFinite(value)) { + throw new Error(`Invalid numeric value: ${raw}`); + } + return value; +} + +function parseJSONMap(raw: string): Record { + const parsed = JSON.parse(raw); + if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) { + throw new Error('Expected object'); + } + + const result: Record = {}; + for (const [key, value] of Object.entries(parsed as Record)) { + result[String(key)] = Number(value); + } + return result; +} + +function normalizeMode(mode: unknown): HarnessMode { + if (mode === 'off' || mode === 'observe' || mode === 'enforce') { + return mode; + } + + throw new Error('mode must be one of: off, observe, enforce'); +} + +function normalizeConfigRecord(raw: Record): HarnessInitOptions { + const out: 
HarnessInitOptions = {}; + + const mode = raw.mode ?? raw.harness_mode; + if (typeof mode === 'string') { + out.mode = normalizeMode(mode); + } + + const verbose = raw.verbose ?? raw.harness_verbose; + if (typeof verbose === 'boolean') { + out.verbose = verbose; + } + + const budget = raw.budget ?? raw.max_budget; + if (typeof budget === 'number') { + out.budget = budget; + } + + const maxToolCalls = raw.maxToolCalls ?? raw.max_tool_calls; + if (typeof maxToolCalls === 'number') { + out.maxToolCalls = maxToolCalls; + } + + const maxLatencyMs = raw.maxLatencyMs ?? raw.max_latency_ms; + if (typeof maxLatencyMs === 'number') { + out.maxLatencyMs = maxLatencyMs; + } + + const maxEnergy = raw.maxEnergy ?? raw.max_energy; + if (typeof maxEnergy === 'number') { + out.maxEnergy = maxEnergy; + } + + const kpiTargets = raw.kpiTargets ?? raw.kpi_targets; + if (kpiTargets && typeof kpiTargets === 'object' && !Array.isArray(kpiTargets)) { + out.kpiTargets = kpiTargets as Record; + } + + const kpiWeights = raw.kpiWeights ?? raw.kpi_weights; + if (kpiWeights && typeof kpiWeights === 'object' && !Array.isArray(kpiWeights)) { + out.kpiWeights = kpiWeights as Record; + } + + const compliance = raw.compliance; + if (typeof compliance === 'string') { + out.compliance = compliance; + } + + return out; +} + +function readEnvConfig(): HarnessInitOptions { + // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition + if (typeof process === 'undefined' || !process.env) { + return {}; + } + + const env = process.env; + const config: HarnessInitOptions = {}; + + const mode = env.CASCADEFLOW_HARNESS_MODE ?? env.CASCADEFLOW_MODE; + if (mode) { + config.mode = normalizeMode(mode); + } + + if (env.CASCADEFLOW_HARNESS_VERBOSE != null) { + config.verbose = parseBoolean(env.CASCADEFLOW_HARNESS_VERBOSE); + } + + const budget = env.CASCADEFLOW_HARNESS_BUDGET ?? 
env.CASCADEFLOW_BUDGET; + if (budget != null) { + config.budget = parseNumber(budget); + } + + if (env.CASCADEFLOW_HARNESS_MAX_TOOL_CALLS != null) { + config.maxToolCalls = parseNumber(env.CASCADEFLOW_HARNESS_MAX_TOOL_CALLS); + } + + if (env.CASCADEFLOW_HARNESS_MAX_LATENCY_MS != null) { + config.maxLatencyMs = parseNumber(env.CASCADEFLOW_HARNESS_MAX_LATENCY_MS); + } + + if (env.CASCADEFLOW_HARNESS_MAX_ENERGY != null) { + config.maxEnergy = parseNumber(env.CASCADEFLOW_HARNESS_MAX_ENERGY); + } + + if (env.CASCADEFLOW_HARNESS_KPI_TARGETS != null) { + config.kpiTargets = parseJSONMap(env.CASCADEFLOW_HARNESS_KPI_TARGETS); + } + + if (env.CASCADEFLOW_HARNESS_KPI_WEIGHTS != null) { + config.kpiWeights = parseJSONMap(env.CASCADEFLOW_HARNESS_KPI_WEIGHTS); + } + + if (env.CASCADEFLOW_HARNESS_COMPLIANCE != null) { + config.compliance = env.CASCADEFLOW_HARNESS_COMPLIANCE; + } + + return config; +} + +function readFileConfig(): HarnessInitOptions { + // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition + if (typeof process === 'undefined' || !process.cwd) { + return {}; + } + + try { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const fs = require('node:fs') as typeof import('node:fs'); + // eslint-disable-next-line @typescript-eslint/no-var-requires + const path = require('node:path') as typeof import('node:path'); + + const configuredPath = process.env.CASCADEFLOW_CONFIG; + const candidates = configuredPath + ? [configuredPath] + : ['cascadeflow.json', 'cascadeflow.config.json']; + + for (const candidate of candidates) { + const full = path.isAbsolute(candidate) ? candidate : path.join(process.cwd(), candidate); + if (!fs.existsSync(full)) { + continue; + } + + const content = fs.readFileSync(full, 'utf8'); + const parsed = JSON.parse(content) as Record; + const harnessBlock = ( + parsed.harness && typeof parsed.harness === 'object' && !Array.isArray(parsed.harness) + ) + ? 
(parsed.harness as Record) + : parsed; + + return normalizeConfigRecord(harnessBlock); + } + } catch { + return {}; + } + + return {}; +} + +function resolveConfig(options: HarnessInitOptions): ConfigWithSources { + const env = readEnvConfig(); + const file = readFileConfig(); + const sources: Record = {}; + + const resolve = ( + key: keyof HarnessConfig, + explicit: T | undefined, + envValue: T | undefined, + fileValue: T | undefined, + defaultValue: T, + ): T => { + if (explicit !== undefined) { + sources[key] = 'code'; + return explicit; + } + if (envValue !== undefined) { + sources[key] = 'env'; + return envValue; + } + if (fileValue !== undefined) { + sources[key] = 'file'; + return fileValue; + } + sources[key] = 'default'; + return defaultValue; + }; + + const mode = resolve('mode', options.mode, env.mode, file.mode, 'off'); + const verbose = resolve('verbose', options.verbose, env.verbose, file.verbose, false); + const budget = resolve('budget', options.budget, env.budget, file.budget, undefined); + const maxToolCalls = resolve( + 'maxToolCalls', + options.maxToolCalls, + env.maxToolCalls, + file.maxToolCalls, + undefined, + ); + const maxLatencyMs = resolve( + 'maxLatencyMs', + options.maxLatencyMs, + env.maxLatencyMs, + file.maxLatencyMs, + undefined, + ); + const maxEnergy = resolve('maxEnergy', options.maxEnergy, env.maxEnergy, file.maxEnergy, undefined); + const kpiTargets = resolve( + 'kpiTargets', + options.kpiTargets, + env.kpiTargets, + file.kpiTargets, + undefined, + ); + const kpiWeights = resolve( + 'kpiWeights', + options.kpiWeights, + env.kpiWeights, + file.kpiWeights, + undefined, + ); + const compliance = resolve( + 'compliance', + options.compliance, + env.compliance, + file.compliance, + undefined, + ); + + return { + config: { + mode, + verbose, + budget, + maxToolCalls, + maxLatencyMs, + maxEnergy, + kpiTargets, + kpiWeights, + compliance, + }, + sources, + }; +} + +export function getHarnessConfig(): HarnessConfig { + return { 
..._harnessConfig }; +} + +export function getCurrentRun(): HarnessRunContext | null { + const als = getAsyncLocalStorage(); + if (als) { + return als.getStore() ?? null; + } + + return fallbackCurrentRun; +} + +export function reset(): void { + unpatchOpenAI(); + unpatchAnthropic(); + __resetInstrumentationStateForTest(); + + _harnessConfig = { mode: 'off', verbose: false }; + _isInstrumented = false; + fallbackCurrentRun = null; +} + +export function init(options: HarnessInitOptions = {}): HarnessInitReport { + const { config, sources } = resolveConfig(options); + config.mode = normalizeMode(config.mode); + + _harnessConfig = config; + + const instrumented: string[] = []; + const detectedButNotInstrumented: string[] = []; + + const openaiDetected = detectOpenAIInstrumentationTarget(); + const anthropicDetected = detectAnthropicInstrumentationTarget(); + + if (config.mode !== 'off' && openaiDetected) { + if (patchOpenAI()) { + instrumented.push('openai'); + } else { + detectedButNotInstrumented.push('openai'); + } + } + + if (config.mode !== 'off' && anthropicDetected) { + if (patchAnthropic()) { + instrumented.push('anthropic'); + } else { + detectedButNotInstrumented.push('anthropic'); + } + } + + if (config.mode === 'off') { + unpatchOpenAI(); + unpatchAnthropic(); + } + + _isInstrumented = true; + + if (config.verbose) { + // eslint-disable-next-line no-console + console.info('[cascadeflow.harness] init', { + mode: config.mode, + instrumented, + detectedButNotInstrumented, + }); + } + + return { + mode: config.mode, + instrumented, + detectedButNotInstrumented, + configSources: sources, + }; +} + +type RunCallback = (run: HarnessRunContext) => Promise | T; + +async function executeScopedRun(runContext: HarnessRunContext, fn: RunCallback): Promise { + try { + return await fn(runContext); + } finally { + runContext.finish(); + } +} + +export async function run(callback: RunCallback): Promise; +export async function run(options: HarnessRunOptions, callback: 
RunCallback): Promise; +export async function run( + optionsOrCallback: HarnessRunOptions | RunCallback, + callback?: RunCallback, +): Promise { + const options = typeof optionsOrCallback === 'function' ? {} : optionsOrCallback; + const cb = (typeof optionsOrCallback === 'function' ? optionsOrCallback : callback) as RunCallback | undefined; + + if (!cb) { + throw new Error('run() requires a callback: run(options?, async (run) => { ... })'); + } + + const cfg = getHarnessConfig(); + const runContext = new HarnessRunContext({ + mode: cfg.mode, + budgetMax: options.budget ?? cfg.budget, + toolCallsMax: options.maxToolCalls ?? cfg.maxToolCalls, + latencyMaxMs: options.maxLatencyMs ?? cfg.maxLatencyMs, + energyMax: options.maxEnergy ?? cfg.maxEnergy, + kpiTargets: options.kpiTargets ?? cfg.kpiTargets, + kpiWeights: options.kpiWeights ?? cfg.kpiWeights, + compliance: options.compliance ?? cfg.compliance, + verbose: cfg.verbose, + }); + + const als = getAsyncLocalStorage(); + if (als) { + return als.run(runContext, async () => executeScopedRun(runContext, cb)) as Promise; + } + + const previous = fallbackCurrentRun; + fallbackCurrentRun = runContext; + try { + return await executeScopedRun(runContext, cb); + } finally { + fallbackCurrentRun = previous; + } +} + +export function agent(policy: HarnessRunOptions): any>(fn: T) => T { + return any>(fn: T): T => { + const wrapped = ((...args: any[]) => fn(...args)) as T; + (wrapped as any).__cascadeflow_agent_policy__ = { + budget: policy.budget, + kpiTargets: policy.kpiTargets, + kpiWeights: policy.kpiWeights, + compliance: policy.compliance, + }; + return wrapped; + }; +} + +setHarnessRuntimeBindingsForInstrumentation({ + getCurrentRun, + getHarnessMode: () => getHarnessConfig().mode, + createBudgetExceededError: (message: string, remaining?: number) => + new BudgetExceededError(message, remaining), + createHarnessStopError: (message: string, reason?: string) => + new HarnessStopError(message, reason), +}); + +export const 
 cascadeflow = {
+  init,
+  run,
+  agent,
+  reset,
+  getHarnessConfig,
+  getCurrentRun,
+};
+
+export function isHarnessInstrumented(): boolean { // true once init() has run (any mode)
+  return _isInstrumented;
+}
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index 29819183..c919f67e 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -42,6 +42,31 @@ export {
   DEFAULT_CASCADE_CONFIG,
 } from './config';
 
+// Harness API (v2.1+)
+export type {
+  HarnessMode,
+  HarnessConfig,
+  HarnessInitOptions,
+  HarnessRunOptions,
+  HarnessInitReport,
+  HarnessRecordOptions,
+  HarnessTraceEntry,
+  HarnessRunSummary,
+} from './harness';
+export {
+  HarnessRunContext,
+  HarnessStopError,
+  BudgetExceededError,
+  init,
+  run,
+  agent as harnessAgent,
+  reset as resetHarness,
+  getHarnessConfig,
+  getCurrentRun,
+  isHarnessInstrumented,
+  cascadeflow,
+} from './harness';
+
 // Results
 export type { CascadeResult } from './result';
 export { resultToObject } from './result';
diff --git a/pyproject.toml b/pyproject.toml
index 0d488faa..2bbd3082 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,6 +92,28 @@ semantic = [
 # OpenClaw integration (auto-enables FastEmbed for semantic routing)
 openclaw = ["fastembed>=0.7.0"]
 
+# CrewAI harness integration (opt-in)
+crewai = ["crewai>=1.5.0"]
+
+# OpenAI Agents SDK integration (opt-in)
+openai-agents = [
+    "openai-agents>=0.8.4; python_version < '3.10'",
+    "openai-agents>=0.9.0; python_version >= '3.10'",
+]
+
+# LangChain harness integration (opt-in)
+langchain = [
+    "langchain>=0.3.0",
+    "langchain-core>=0.3.0",
+]
+
+# LangGraph state extraction (opt-in, adds langgraph on top of langchain)
+langgraph = [
+    "langchain>=0.3.0",
+    "langchain-core>=0.3.0",
+    "langgraph>=0.2.0",
+]
+
 # Development tools (includes rich for terminal output)
 dev = [
     "pytest>=7.4.0",
diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py
index 185d1589..f012e698 100644
--- a/tests/benchmarks/__init__.py
+++ b/tests/benchmarks/__init__.py
@@
 -15,6 +15,22 @@
     DOMAIN_CONFIGS,
 )
 
+# Reproducibility pipeline
+from .repro import ReproMetadata, collect_repro_metadata, metadata_to_dict
+from .baseline import (
+    BaselineArtifact,
+    BenchmarkDelta,
+    ComparisonReport,
+    GoNoGoResult,
+    save_baseline,
+    load_baseline,
+    compare_to_baseline,
+    check_go_nogo,
+)
+from .harness_overhead import OverheadReport, measure_harness_overhead
+from .observe_validation import ObserveValidationResult, validate_observe_mode
+from .artifact import bundle_artifact
+
 __all__ = [
     # Base classes
     "Benchmark",
@@ -39,4 +55,21 @@
     "DRAFTER_MODELS",
     "VERIFIER_MODELS",
     "DOMAIN_CONFIGS",
+    # Reproducibility pipeline
+    "ReproMetadata",
+    "collect_repro_metadata",
+    "metadata_to_dict",
+    "BaselineArtifact",
+    "BenchmarkDelta",
+    "ComparisonReport",
+    "GoNoGoResult",
+    "save_baseline",
+    "load_baseline",
+    "compare_to_baseline",
+    "check_go_nogo",
+    "OverheadReport",
+    "measure_harness_overhead",
+    "ObserveValidationResult",
+    "validate_observe_mode",
+    "bundle_artifact",
 ]
diff --git a/tests/benchmarks/artifact.py b/tests/benchmarks/artifact.py
new file mode 100644
index 00000000..fde0f616
--- /dev/null
+++ b/tests/benchmarks/artifact.py
@@ -0,0 +1,117 @@
+"""Artifact bundler.
+
+Writes a single JSON artifact containing benchmark results, reproducibility
+metadata, harness-overhead measurements, observe-mode validation, and an
+optional baseline comparison. Also generates a ``REPRODUCE.md`` with exact
+pip-install + run commands.
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Optional
+
+SCHEMA_VERSION = "1.0.0"  # bump when the bundle layout changes
+
+
+def bundle_artifact(
+    *,
+    results: dict[str, Any],
+    metadata: dict[str, Any],
+    overhead: dict[str, Any],
+    observe: dict[str, Any],
+    comparison: Optional[dict[str, Any]] = None,
+    output_dir: Path | str = ".",
+    run_id: str = "unknown",
+) -> Path:
+    """Write the full artifact bundle as ``artifact_<run_id>.json``.
+
+    Returns the path to the written file.
+ """ + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + bundle: dict[str, Any] = { + "schema_version": SCHEMA_VERSION, + "metadata": metadata, + "results": results, + "harness_overhead": overhead, + "observe_validation": observe, + } + if comparison is not None: + bundle["baseline_comparison"] = comparison + + artifact_path = output_dir / f"artifact_{run_id}.json" + artifact_path.write_text(json.dumps(bundle, indent=2)) + + _write_reproduce_md(output_dir, metadata) + + return artifact_path + + +# --------------------------------------------------------------------------- +# REPRODUCE.md generation +# --------------------------------------------------------------------------- + +_REPRODUCE_TEMPLATE = """\ +# Reproducing this benchmark run + +## Environment + +- **Git SHA:** {git_sha} +- **Python:** {python_version} +- **Platform:** {platform} +- **cascadeflow:** {cascadeflow_version} +- **Profile:** {profile} +- **Harness mode:** {harness_mode} + +## Steps + +```bash +# 1. Clone and checkout the exact commit +git clone https://github.com/lemony-ai/cascadeflow.git +cd cascadeflow +git checkout {git_sha} + +# 2. Create a virtual environment +python -m venv .venv && source .venv/bin/activate + +# 3. Install dependencies +pip install -e ".[dev]" + +# 4. Set API keys +export OPENAI_API_KEY="" +export ANTHROPIC_API_KEY="" + +# 5. 
Run the benchmark suite +python -m tests.benchmarks.run_all --profile {profile} --with-repro +``` + +## Package versions at time of run + +{package_table} +""" + + +def _write_reproduce_md(output_dir: Path, metadata: dict[str, Any]) -> Path: + packages = metadata.get("package_versions", {}) + rows = [f"| {name} | {ver} |" for name, ver in sorted(packages.items())] + table = ( + "| Package | Version |\n|---------|----------|\n" + "\n".join(rows) if rows else "_none_" + ) + + content = _REPRODUCE_TEMPLATE.format( + git_sha=metadata.get("git_sha", "unknown"), + python_version=metadata.get("python_version", "unknown"), + platform=metadata.get("platform", "unknown"), + cascadeflow_version=metadata.get("cascadeflow_version", "unknown"), + profile=metadata.get("profile", "smoke"), + harness_mode=metadata.get("harness_mode", "off"), + package_table=table, + ) + + path = output_dir / "REPRODUCE.md" + path.write_text(content) + return path diff --git a/tests/benchmarks/baseline.py b/tests/benchmarks/baseline.py new file mode 100644 index 00000000..85dc7770 --- /dev/null +++ b/tests/benchmarks/baseline.py @@ -0,0 +1,186 @@ +"""Baseline management and Go/No-Go gate evaluation. + +Saves and loads benchmark baselines as JSON, computes per-benchmark deltas, +and evaluates the four V2 Go/No-Go criteria. 
+""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional + +# --------------------------------------------------------------------------- +# Tolerance constants (module-level so tests can inspect / override) +# --------------------------------------------------------------------------- + +ACCURACY_REGRESSION_TOLERANCE: float = 2.0 # percentage points +SAVINGS_REGRESSION_TOLERANCE: float = 5.0 # percentage points + + +# --------------------------------------------------------------------------- +# Data classes +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class BaselineArtifact: + """A persisted baseline snapshot.""" + + metadata: dict[str, Any] + results: dict[str, Any] + + +@dataclass(frozen=True) +class BenchmarkDelta: + """Per-benchmark comparison between current and baseline.""" + + benchmark: str + accuracy_delta: float # positive = improvement + savings_delta: float + accept_rate_delta: float + latency_delta_ms: float # positive = slower + accuracy_regressed: bool + savings_regressed: bool + + +@dataclass(frozen=True) +class ComparisonReport: + """Aggregated comparison across all benchmarks.""" + + deltas: list[BenchmarkDelta] + any_accuracy_regression: bool + any_savings_regression: bool + + +@dataclass(frozen=True) +class GoNoGoResult: + """Result of the four-gate V2 readiness check.""" + + observe_zero_change: bool + overhead_under_5ms: bool + no_accuracy_regression: bool + no_savings_regression: bool + overall: bool + details: dict[str, Any] = field(default_factory=dict) + + +# --------------------------------------------------------------------------- +# Persistence +# --------------------------------------------------------------------------- + + +def save_baseline( + results: dict[str, Any], + metadata: dict[str, Any], + path: Path, +) -> Path: + """Write *results* + *metadata* as a 
baseline JSON file.""" + + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + payload = {"metadata": metadata, "results": results} + path.write_text(json.dumps(payload, indent=2)) + return path + + +def load_baseline(path: Path) -> BaselineArtifact: + """Load a previously saved baseline.""" + + raw = json.loads(Path(path).read_text()) + return BaselineArtifact(metadata=raw["metadata"], results=raw["results"]) + + +# --------------------------------------------------------------------------- +# Comparison +# --------------------------------------------------------------------------- + +_RESULT_KEYS = ("accuracy", "savings_pct", "accept_rate", "avg_latency_ms") + + +def compare_to_baseline( + current: dict[str, Any], + baseline: dict[str, Any], +) -> ComparisonReport: + """Compute per-benchmark deltas between *current* and *baseline*.""" + + deltas: list[BenchmarkDelta] = [] + + all_benchmarks = set(current) | set(baseline) + for name in sorted(all_benchmarks): + cur = current.get(name, {}) + base = baseline.get(name, {}) + if not cur or not base: + continue + + acc_delta = cur.get("accuracy", 0.0) - base.get("accuracy", 0.0) + sav_delta = cur.get("savings_pct", 0.0) - base.get("savings_pct", 0.0) + ar_delta = cur.get("accept_rate", 0.0) - base.get("accept_rate", 0.0) + lat_delta = cur.get("avg_latency_ms", 0.0) - base.get("avg_latency_ms", 0.0) + + deltas.append( + BenchmarkDelta( + benchmark=name, + accuracy_delta=acc_delta, + savings_delta=sav_delta, + accept_rate_delta=ar_delta, + latency_delta_ms=lat_delta, + accuracy_regressed=acc_delta < -ACCURACY_REGRESSION_TOLERANCE, + savings_regressed=sav_delta < -SAVINGS_REGRESSION_TOLERANCE, + ) + ) + + return ComparisonReport( + deltas=deltas, + any_accuracy_regression=any(d.accuracy_regressed for d in deltas), + any_savings_regression=any(d.savings_regressed for d in deltas), + ) + + +# --------------------------------------------------------------------------- +# Go / No-Go +# 
 ---------------------------------------------------------------------------
+
+
+def check_go_nogo(
+    comparison: Optional[ComparisonReport],
+    overhead_p95_us: float,
+    observe_all_passed: bool,
+) -> GoNoGoResult:
+    """Evaluate the four V2 readiness gates.
+
+    Args:
+        comparison: Baseline comparison (may be ``None`` when no baseline exists).
+        overhead_p95_us: Harness decision overhead p95 in *microseconds*.
+        observe_all_passed: Whether observe-mode validation passed all cases.
+
+    Returns:
+        A ``GoNoGoResult`` with individual gate flags and ``overall``.
+    """
+
+    observe_ok = observe_all_passed
+    overhead_ok = overhead_p95_us < 5_000.0  # 5 ms = 5 000 us
+
+    if comparison is not None:
+        acc_ok = not comparison.any_accuracy_regression
+        sav_ok = not comparison.any_savings_regression
+    else:
+        # No baseline → cannot fail these gates.
+        acc_ok = True
+        sav_ok = True
+
+    overall = observe_ok and overhead_ok and acc_ok and sav_ok
+
+    return GoNoGoResult(
+        observe_zero_change=observe_ok,
+        overhead_under_5ms=overhead_ok,
+        no_accuracy_regression=acc_ok,
+        no_savings_regression=sav_ok,
+        overall=overall,
+        details={
+            "overhead_p95_us": overhead_p95_us,
+            "accuracy_tolerance": ACCURACY_REGRESSION_TOLERANCE,
+            "savings_tolerance": SAVINGS_REGRESSION_TOLERANCE,
+        },
+    )
diff --git a/tests/benchmarks/harness_overhead.py b/tests/benchmarks/harness_overhead.py
new file mode 100644
index 00000000..926d026a
--- /dev/null
+++ b/tests/benchmarks/harness_overhead.py
@@ -0,0 +1,116 @@
+"""Harness decision-path overhead measurement.
+
+Measures the CPU time of ``_evaluate_pre_call_decision`` across a variety of
+model / budget / latency / energy / KPI states. No network calls are made.
+"""
+
+from __future__ import annotations
+
+import itertools
+import statistics
+import time
+from dataclasses import dataclass
+from typing import Any
+
+from cascadeflow.harness.api import HarnessRunContext
+from cascadeflow.harness.instrument import _evaluate_pre_call_decision
+from cascadeflow.harness.pricing import OPENAI_MODEL_POOL
+
+
+@dataclass(frozen=True)
+class OverheadReport:
+    """Summary of decision-path latency measurements."""
+
+    iterations: int
+    p50_us: float
+    p95_us: float
+    p99_us: float
+    mean_us: float
+    max_us: float
+    p95_under_5ms: bool  # the 5 ms Go/No-Go gate, expressed in µs
+
+
+# Representative context configurations that exercise different code paths.
+_BUDGET_STATES: list[dict[str, Any]] = [
+    {},  # no budget
+    {"budget_max": 10.0, "cost": 0.0},  # plenty of budget
+    {"budget_max": 1.0, "cost": 0.85},  # budget pressure (<20 % remaining)
+    {"budget_max": 1.0, "cost": 1.0},  # budget exhausted
+]
+
+_LATENCY_STATES: list[dict[str, Any]] = [
+    {},
+    {"latency_max_ms": 5000.0, "latency_used_ms": 0.0},
+    {"latency_max_ms": 5000.0, "latency_used_ms": 5500.0},  # over limit
+]
+
+_ENERGY_STATES: list[dict[str, Any]] = [
+    {},
+    {"energy_max": 100.0, "energy_used": 0.0},
+    {"energy_max": 100.0, "energy_used": 110.0},  # over limit
+]
+
+_TOOL_FLAGS: list[bool] = [False, True]
+
+
+def _build_ctx(**overrides: Any) -> HarnessRunContext:
+    """Create a lightweight HarnessRunContext for overhead testing."""
+
+    return HarnessRunContext(mode="enforce", **overrides)
+
+
+def measure_harness_overhead(iterations: int = 1000) -> OverheadReport:
+    """Run *iterations* calls to ``_evaluate_pre_call_decision`` and report timing.
+
+    The function cycles through ``OPENAI_MODEL_POOL`` models and various budget /
+    latency / energy / tool states to exercise all decision branches. Timing is
+    captured with ``time.perf_counter_ns`` for nanosecond resolution.
+    """
+
+    # Build a round-robin of (model, budget, latency, energy, has_tools) combos.
+    combos = list(
+        itertools.product(
+            OPENAI_MODEL_POOL,
+            _BUDGET_STATES,
+            _LATENCY_STATES,
+            _ENERGY_STATES,
+            _TOOL_FLAGS,
+        )
+    )
+    combo_cycle = itertools.cycle(combos)
+
+    timings_ns: list[int] = []
+
+    for _ in range(iterations):
+        model, budget, latency, energy, has_tools = next(combo_cycle)
+        ctx = _build_ctx(**budget, **latency, **energy)  # fresh context per call so state never accumulates
+
+        t0 = time.perf_counter_ns()
+        _evaluate_pre_call_decision(ctx, model, has_tools)
+        t1 = time.perf_counter_ns()
+
+        timings_ns.append(t1 - t0)
+
+    timings_us = [t / 1_000.0 for t in timings_ns]
+    timings_us.sort()
+
+    def _percentile(data: list[float], pct: float) -> float:
+        idx = int(len(data) * pct / 100.0)  # NOTE(review): rank truncates down; standard nearest-rank would use ceil(...)-1 — confirm intent
+        idx = min(idx, len(data) - 1)
+        return data[idx]
+
+    p50 = _percentile(timings_us, 50)
+    p95 = _percentile(timings_us, 95)
+    p99 = _percentile(timings_us, 99)
+    mean = statistics.mean(timings_us)
+    max_val = timings_us[-1]
+
+    return OverheadReport(
+        iterations=iterations,
+        p50_us=p50,
+        p95_us=p95,
+        p99_us=p99,
+        mean_us=mean,
+        max_us=max_val,
+        p95_under_5ms=p95 < 5_000.0,
+    )
diff --git a/tests/benchmarks/observe_validation.py b/tests/benchmarks/observe_validation.py
new file mode 100644
index 00000000..cd6bf1a9
--- /dev/null
+++ b/tests/benchmarks/observe_validation.py
@@ -0,0 +1,174 @@
+"""Observe-mode zero-change proof.
+
+For each synthetic scenario we run ``_prepare_call_interception`` in both
+``"off"`` and ``"observe"`` modes and assert the resulting kwargs are
+identical — proving that observe mode produces zero behavior change.
+""" + +from __future__ import annotations + +import copy +from dataclasses import dataclass, field +from typing import Any + +from cascadeflow.harness.api import HarnessRunContext +from cascadeflow.harness.instrument import ( + _CallInterceptionState, + _prepare_call_interception, +) + + +@dataclass +class ObserveValidationResult: + """Outcome of the full observe-mode validation suite.""" + + total_cases: int + passed: int + failed: int + failures: list[str] + all_passed: bool + + +# --------------------------------------------------------------------------- +# Synthetic cases +# --------------------------------------------------------------------------- + +_COMPARE_KEYS = ("model", "messages", "tools", "stream") + + +def _simple_chat_kwargs() -> dict[str, Any]: + return { + "model": "gpt-4o", + "messages": [{"role": "user", "content": "Hello"}], + "stream": False, + } + + +def _chat_with_tools_kwargs() -> dict[str, Any]: + return { + "model": "gpt-4o", + "messages": [{"role": "user", "content": "What is the weather?"}], + "tools": [ + { + "type": "function", + "function": {"name": "get_weather", "parameters": {}}, + } + ], + "stream": False, + } + + +def _budget_exceeded_kwargs() -> dict[str, Any]: + return _simple_chat_kwargs() + + +def _tool_limit_kwargs() -> dict[str, Any]: + return _chat_with_tools_kwargs() + + +def _compliance_kwargs() -> dict[str, Any]: + return { + "model": "gpt-4o", + "messages": [{"role": "user", "content": "Draft a contract"}], + "stream": False, + } + + +def _kpi_weighted_kwargs() -> dict[str, Any]: + return { + "model": "gpt-4o", + "messages": [{"role": "user", "content": "Summarise this document"}], + "stream": False, + } + + +# Each case: (label, kwargs_factory, ctx_overrides_for_constraints) +_CASES: list[tuple[str, Any, dict[str, Any]]] = [ + ("simple_chat", _simple_chat_kwargs, {}), + ("chat_with_tools", _chat_with_tools_kwargs, {}), + ( + "budget_exceeded", + _budget_exceeded_kwargs, + {"budget_max": 1.0, "cost": 1.0}, + ), + ( 
+ "tool_limit_reached", + _tool_limit_kwargs, + {"tool_calls_max": 5, "tool_calls": 5}, + ), + ( + "compliance_constraint", + _compliance_kwargs, + {"compliance": "gdpr"}, + ), + ( + "kpi_weighted", + _kpi_weighted_kwargs, + {"kpi_weights": {"cost": 0.7, "quality": 0.3}}, + ), +] + + +def _run_single_case( + label: str, + kwargs_factory: Any, + ctx_overrides: dict[str, Any], +) -> str | None: + """Run one case. Returns an error string on failure, ``None`` on success.""" + + # --- reference: mode="off" --- + ref_kwargs = kwargs_factory() + ref_ctx = HarnessRunContext(mode="off", **ctx_overrides) + ref_state: _CallInterceptionState = _prepare_call_interception( + ctx=ref_ctx, + mode="off", + kwargs=copy.deepcopy(ref_kwargs), + ) + + # --- observed: mode="observe" --- + obs_kwargs = kwargs_factory() + obs_ctx = HarnessRunContext(mode="observe", **ctx_overrides) + obs_state: _CallInterceptionState = _prepare_call_interception( + ctx=obs_ctx, + mode="observe", + kwargs=copy.deepcopy(obs_kwargs), + ) + + # 1. kwargs identity — the observable behaviour MUST be the same. + for key in _COMPARE_KEYS: + ref_val = ref_state.kwargs.get(key) + obs_val = obs_state.kwargs.get(key) + if ref_val != obs_val: + return f"{label}: kwargs['{key}'] differs — off={ref_val!r} observe={obs_val!r}" + + # 2. The model sent to the API must not change. + if ref_state.model != obs_state.model: + return f"{label}: model differs — off={ref_state.model} observe={obs_state.model}" + + # 3. For constrained cases the harness SHOULD have evaluated a decision + # (pre_action may be "stop", "switch_model", etc.). Confirm the + # observe path at least recorded the same action as the off path. 
+ if ctx_overrides and obs_state.pre_action == "allow" and ref_state.pre_action != "allow": + return f"{label}: observe lost decision — expected {ref_state.pre_action}, got allow" + + return None + + +def validate_observe_mode() -> ObserveValidationResult: + """Run all synthetic cases and return the validation result.""" + + failures: list[str] = [] + for label, factory, overrides in _CASES: + err = _run_single_case(label, factory, overrides) + if err is not None: + failures.append(err) + + total = len(_CASES) + failed = len(failures) + return ObserveValidationResult( + total_cases=total, + passed=total - failed, + failed=failed, + failures=failures, + all_passed=failed == 0, + ) diff --git a/tests/benchmarks/repro.py b/tests/benchmarks/repro.py new file mode 100644 index 00000000..2472c245 --- /dev/null +++ b/tests/benchmarks/repro.py @@ -0,0 +1,132 @@ +"""Reproducibility metadata collector. + +Captures a full environment fingerprint so benchmark runs can be reproduced +by third parties. The metadata is embedded in every artifact bundle. +""" + +from __future__ import annotations + +import platform +import subprocess +import sys +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any + + +@dataclass(frozen=True) +class ReproMetadata: + """Immutable snapshot of the environment at benchmark start.""" + + git_sha: str + git_dirty: bool + python_version: str + platform: str + cascadeflow_version: str + profile: str + drafter_model: str + verifier_model: str + baseline_model: str + harness_mode: str + package_versions: dict[str, str] + run_id: str + timestamp_utc: str + + +# Packages whose versions we record for reproducibility. 
+_TRACKED_PACKAGES = ( + "cascadeflow", + "openai", + "anthropic", + "httpx", + "pydantic", + "tiktoken", +) + + +def _git_sha() -> str: + try: + return ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], + stderr=subprocess.DEVNULL, + ) + .decode() + .strip() + ) + except Exception: + return "unknown" + + +def _git_dirty() -> bool: + try: + out = ( + subprocess.check_output( + ["git", "status", "--porcelain"], + stderr=subprocess.DEVNULL, + ) + .decode() + .strip() + ) + return bool(out) + except Exception: + return False + + +def _package_version(name: str) -> str: + try: + from importlib.metadata import version + + return version(name) + except Exception: + return "not installed" + + +def collect_repro_metadata( + *, + profile: str = "smoke", + drafter_model: str = "", + verifier_model: str = "", + baseline_model: str = "", + harness_mode: str = "off", +) -> ReproMetadata: + """Capture a reproducibility fingerprint of the current environment.""" + + versions = {pkg: _package_version(pkg) for pkg in _TRACKED_PACKAGES} + + return ReproMetadata( + git_sha=_git_sha(), + git_dirty=_git_dirty(), + python_version=sys.version, + platform=platform.platform(), + cascadeflow_version=versions.get("cascadeflow", "not installed"), + profile=profile, + drafter_model=drafter_model, + verifier_model=verifier_model, + baseline_model=baseline_model, + harness_mode=harness_mode, + package_versions=versions, + run_id=uuid.uuid4().hex[:12], + timestamp_utc=datetime.now(timezone.utc).isoformat(), + ) + + +def metadata_to_dict(meta: ReproMetadata) -> dict[str, Any]: + """Convert *meta* to a JSON-serializable dict.""" + + return { + "git_sha": meta.git_sha, + "git_dirty": meta.git_dirty, + "python_version": meta.python_version, + "platform": meta.platform, + "cascadeflow_version": meta.cascadeflow_version, + "profile": meta.profile, + "drafter_model": meta.drafter_model, + "verifier_model": meta.verifier_model, + "baseline_model": meta.baseline_model, + "harness_mode": 
meta.harness_mode, + "package_versions": dict(meta.package_versions), + "run_id": meta.run_id, + "timestamp_utc": meta.timestamp_utc, + } diff --git a/tests/benchmarks/run_all.py b/tests/benchmarks/run_all.py index 9153d6c8..739c0342 100644 --- a/tests/benchmarks/run_all.py +++ b/tests/benchmarks/run_all.py @@ -21,7 +21,12 @@ from pathlib import Path from typing import Any +from .artifact import bundle_artifact +from .baseline import compare_to_baseline, load_baseline +from .harness_overhead import measure_harness_overhead +from .observe_validation import validate_observe_mode from .reporter import BenchmarkReporter +from .repro import collect_repro_metadata, metadata_to_dict PROFILE_PRESETS: dict[str, dict[str, Any]] = { @@ -432,6 +437,24 @@ async def main(): action="store_true", help="Skip provider comparison benchmark (expensive; enabled by default in standard/overnight/full)", ) + parser.add_argument( + "--baseline", + type=str, + default=None, + help="Path to a baseline JSON file for regression comparison", + ) + parser.add_argument( + "--harness-mode", + type=str, + choices=["off", "observe"], + default="off", + help="Harness mode for overhead and observe-validation measurements", + ) + parser.add_argument( + "--with-repro", + action="store_true", + help="Collect reproducibility metadata and write an artifact bundle", + ) args = parser.parse_args() @@ -460,40 +483,100 @@ async def main(): f.write(comparison_table) print(f"\n✅ Markdown report: {md_path}") + # Convert summaries to dicts for JSON serialization (used by --with-repro too). 
+ json_results: dict[str, Any] = {} + for name, summary in results.items(): + if summary is None: + json_results[name] = None + elif summary == "completed": + json_results[name] = "completed" + elif isinstance(summary, dict): + json_results[name] = summary + else: + json_results[name] = { + "dataset_name": summary.dataset_name, + "total_tests": summary.total_tests, + "successful_tests": summary.successful_tests, + "failed_tests": summary.failed_tests, + "accuracy": summary.accuracy, + "drafter_accepted": summary.drafter_accepted, + "acceptance_rate_pct": summary.acceptance_rate_pct, + "escalation_rate_pct": summary.escalation_rate_pct, + "total_cost": summary.total_cost, + "total_baseline_cost": summary.total_baseline_cost, + "total_savings": summary.total_savings, + "avg_savings_pct": summary.avg_savings_pct, + "avg_latency_ms": summary.avg_latency_ms, + "drafter_accuracy": summary.drafter_accuracy, + "verifier_accuracy": summary.verifier_accuracy, + } + if "json" in formats: json_path = output_dir / "results.json" - # Convert summaries to dicts for JSON serialization - json_results = {} - for name, summary in results.items(): - if summary is None: - json_results[name] = None - elif summary == "completed": - json_results[name] = "completed" - elif isinstance(summary, dict): - json_results[name] = summary - else: - json_results[name] = { - "dataset_name": summary.dataset_name, - "total_tests": summary.total_tests, - "successful_tests": summary.successful_tests, - "failed_tests": summary.failed_tests, - "accuracy": summary.accuracy, - "drafter_accepted": summary.drafter_accepted, - "acceptance_rate_pct": summary.acceptance_rate_pct, - "escalation_rate_pct": summary.escalation_rate_pct, - "total_cost": summary.total_cost, - "total_baseline_cost": summary.total_baseline_cost, - "total_savings": summary.total_savings, - "avg_savings_pct": summary.avg_savings_pct, - "avg_latency_ms": summary.avg_latency_ms, - "drafter_accuracy": summary.drafter_accuracy, - 
"verifier_accuracy": summary.verifier_accuracy, - } - with open(json_path, "w") as f: json.dump(json_results, f, indent=2) print(f"✅ JSON results: {json_path}") + # ---- Reproducibility artifact (opt-in) ---- + if args.with_repro: + meta = collect_repro_metadata( + profile=args.profile, + harness_mode=args.harness_mode, + ) + meta_dict = metadata_to_dict(meta) + + overhead_report = measure_harness_overhead() + overhead_dict = { + "iterations": overhead_report.iterations, + "p50_us": overhead_report.p50_us, + "p95_us": overhead_report.p95_us, + "p99_us": overhead_report.p99_us, + "mean_us": overhead_report.mean_us, + "max_us": overhead_report.max_us, + "p95_under_5ms": overhead_report.p95_under_5ms, + } + + observe_result = validate_observe_mode() + observe_dict = { + "total_cases": observe_result.total_cases, + "passed": observe_result.passed, + "failed": observe_result.failed, + "failures": observe_result.failures, + "all_passed": observe_result.all_passed, + } + + comparison_data = None + if args.baseline: + baseline_artifact = load_baseline(args.baseline) + report = compare_to_baseline(json_results, baseline_artifact.results) + comparison_data = { + "deltas": [ + { + "benchmark": d.benchmark, + "accuracy_delta": d.accuracy_delta, + "savings_delta": d.savings_delta, + "accept_rate_delta": d.accept_rate_delta, + "latency_delta_ms": d.latency_delta_ms, + "accuracy_regressed": d.accuracy_regressed, + "savings_regressed": d.savings_regressed, + } + for d in report.deltas + ], + "any_accuracy_regression": report.any_accuracy_regression, + "any_savings_regression": report.any_savings_regression, + } + + artifact_path = bundle_artifact( + results=json_results, + metadata=meta_dict, + overhead=overhead_dict, + observe=observe_dict, + comparison=comparison_data, + output_dir=output_dir, + run_id=meta.run_id, + ) + print(f"✅ Artifact bundle: {artifact_path}") + print("\n" + "=" * 80) print("BENCHMARK SUITE COMPLETED") print("=" * 80 + "\n") diff --git 
a/tests/test_bench_repro_pipeline.py b/tests/test_bench_repro_pipeline.py new file mode 100644 index 00000000..d598e398 --- /dev/null +++ b/tests/test_bench_repro_pipeline.py @@ -0,0 +1,329 @@ +"""Tests for the benchmark reproducibility pipeline. + +All 15 tests use mocks — no live API calls. +""" + +from __future__ import annotations + +import json +from unittest.mock import patch + +import pytest + +from tests.benchmarks.repro import ReproMetadata, collect_repro_metadata, metadata_to_dict +from tests.benchmarks.baseline import ( + BaselineArtifact, + BenchmarkDelta, + ComparisonReport, + GoNoGoResult, + save_baseline, + load_baseline, + compare_to_baseline, + check_go_nogo, +) +from tests.benchmarks.harness_overhead import OverheadReport, measure_harness_overhead +from tests.benchmarks.observe_validation import ( + ObserveValidationResult, + validate_observe_mode, +) +from tests.benchmarks.artifact import SCHEMA_VERSION, bundle_artifact + + +# ── Fixtures ────────────────────────────────────────────────────────────── + + +@pytest.fixture +def sample_results() -> dict: + """Minimal results dict matching the existing results JSON format.""" + return { + "MMLU": { + "accuracy": 100.0, + "accept_rate": 30.0, + "savings_pct": 10.68, + "effective_savings_pct": 10.68, + "drafter_accuracy": 100.0, + "verifier_accuracy": 100.0, + "total_cascade_cost": 0.017848, + "total_baseline_cost": 0.019983, + "avg_latency_ms": 1765.95, + "n_total": 40, + "n_correct": 40, + "n_accepted": 12, + }, + "TruthfulQA": { + "accuracy": 86.67, + "accept_rate": 73.33, + "savings_pct": 43.45, + "effective_savings_pct": 30.26, + "drafter_accuracy": 90.91, + "verifier_accuracy": 75.0, + "total_cascade_cost": 0.022044, + "total_baseline_cost": 0.038982, + "avg_latency_ms": 4164.96, + "n_total": 15, + "n_correct": 13, + "n_accepted": 11, + }, + } + + +@pytest.fixture +def sample_metadata() -> dict: + return metadata_to_dict(collect_repro_metadata(profile="smoke", harness_mode="off")) + + +# ── 1-2: 
ReproMetadata ─────────────────────────────────────────────────── + + +def test_collect_repro_metadata(): + """All fields populated.""" + meta = collect_repro_metadata( + profile="standard", + drafter_model="gpt-4o-mini", + verifier_model="gpt-4o", + baseline_model="gpt-4o", + harness_mode="observe", + ) + assert isinstance(meta, ReproMetadata) + assert meta.git_sha # non-empty + assert isinstance(meta.git_dirty, bool) + assert meta.python_version + assert meta.platform + assert meta.cascadeflow_version + assert meta.profile == "standard" + assert meta.drafter_model == "gpt-4o-mini" + assert meta.verifier_model == "gpt-4o" + assert meta.baseline_model == "gpt-4o" + assert meta.harness_mode == "observe" + assert isinstance(meta.package_versions, dict) + assert meta.run_id + assert meta.timestamp_utc + + +def test_metadata_round_trip(): + """JSON serializable and round-trips.""" + meta = collect_repro_metadata() + d = metadata_to_dict(meta) + raw = json.dumps(d) + loaded = json.loads(raw) + assert loaded["run_id"] == meta.run_id + assert loaded["python_version"] == meta.python_version + assert loaded["package_versions"] == d["package_versions"] + + +# ── 3-6: Baseline persistence + comparison ─────────────────────────────── + + +def test_save_load_baseline(tmp_path, sample_results, sample_metadata): + """Save then load produces identical data.""" + path = tmp_path / "baselines" / "test.json" + save_baseline(sample_results, sample_metadata, path) + loaded = load_baseline(path) + assert isinstance(loaded, BaselineArtifact) + assert loaded.results == sample_results + assert loaded.metadata == sample_metadata + + +def test_compare_no_regression(sample_results): + """Identical results → zero deltas, no regressions.""" + report = compare_to_baseline(sample_results, sample_results) + assert isinstance(report, ComparisonReport) + assert not report.any_accuracy_regression + assert not report.any_savings_regression + for d in report.deltas: + assert d.accuracy_delta == 0.0 + 
assert d.savings_delta == 0.0 + assert not d.accuracy_regressed + assert not d.savings_regressed + + +def test_compare_with_regression(sample_results): + """Accuracy drop flagged as regression.""" + worse = { + name: {**vals, "accuracy": vals["accuracy"] - 5.0} for name, vals in sample_results.items() + } + report = compare_to_baseline(worse, sample_results) + assert report.any_accuracy_regression + for d in report.deltas: + assert d.accuracy_delta == pytest.approx(-5.0) + assert d.accuracy_regressed + + +def test_compare_with_improvement(sample_results): + """Savings increase flagged (but not as regression).""" + better = { + name: {**vals, "savings_pct": vals["savings_pct"] + 10.0} + for name, vals in sample_results.items() + } + report = compare_to_baseline(better, sample_results) + assert not report.any_savings_regression + for d in report.deltas: + assert d.savings_delta == pytest.approx(10.0) + assert not d.savings_regressed + + +# ── 7-9: Go/No-Go ──────────────────────────────────────────────────────── + + +def test_go_nogo_all_pass(sample_results): + """All criteria met → overall=True.""" + report = compare_to_baseline(sample_results, sample_results) + result = check_go_nogo(report, overhead_p95_us=500.0, observe_all_passed=True) + assert isinstance(result, GoNoGoResult) + assert result.observe_zero_change is True + assert result.overhead_under_5ms is True + assert result.no_accuracy_regression is True + assert result.no_savings_regression is True + assert result.overall is True + + +def test_go_nogo_overhead_fail(sample_results): + """p95 >5 ms → overall=False.""" + report = compare_to_baseline(sample_results, sample_results) + result = check_go_nogo(report, overhead_p95_us=6_000.0, observe_all_passed=True) + assert result.overhead_under_5ms is False + assert result.overall is False + + +def test_go_nogo_observe_fail(sample_results): + """Observe mismatch → overall=False.""" + report = compare_to_baseline(sample_results, sample_results) + result = 
check_go_nogo(report, overhead_p95_us=500.0, observe_all_passed=False) + assert result.observe_zero_change is False + assert result.overall is False + + +# ── 10: Harness overhead ───────────────────────────────────────────────── + + +def test_harness_overhead_measurement(): + """100 iterations, values >0, p95 < p99 <= max.""" + report = measure_harness_overhead(iterations=100) + assert isinstance(report, OverheadReport) + assert report.iterations == 100 + assert report.p50_us > 0 + assert report.p95_us > 0 + assert report.p99_us > 0 + assert report.mean_us > 0 + assert report.max_us > 0 + assert report.p95_us <= report.p99_us + assert report.p99_us <= report.max_us + + +# ── 11-12: Observe validation ──────────────────────────────────────────── + + +def test_observe_validation_all_pass(): + """Default cases all pass.""" + result = validate_observe_mode() + assert isinstance(result, ObserveValidationResult) + assert result.all_passed, f"Failures: {result.failures}" + assert result.total_cases > 0 + assert result.failed == 0 + + +def test_observe_validation_detects_mutation(): + """Simulated observe-mode bug: if observe mode mutated kwargs, detection would occur. + + We patch _prepare_call_interception to inject a mutation when mode="observe", + proving the validator catches it. + """ + from cascadeflow.harness.instrument import _prepare_call_interception as _real + + def _mutating_intercept(*, ctx, mode, kwargs): + state = _real(ctx=ctx, mode=mode, kwargs=kwargs) + if mode == "observe": + # Simulate a bug: observe mode switches the model. 
+ mutated_kwargs = {**state.kwargs, "model": "MUTATED"} + from cascadeflow.harness.instrument import _CallInterceptionState + + return _CallInterceptionState( + kwargs=mutated_kwargs, + model="MUTATED", + pre_action=state.pre_action, + pre_reason=state.pre_reason, + pre_model=state.pre_model, + pre_applied=state.pre_applied, + is_stream=state.is_stream, + start_time=state.start_time, + ) + return state + + with patch( + "tests.benchmarks.observe_validation._prepare_call_interception", + side_effect=_mutating_intercept, + ): + result = validate_observe_mode() + + assert not result.all_passed + assert result.failed > 0 + assert any("model" in f or "MUTATED" in f for f in result.failures) + + +# ── 13-15: Artifact bundle ─────────────────────────────────────────────── + + +def _make_bundle(tmp_path, sample_results, sample_metadata) -> dict: + overhead = { + "iterations": 100, + "p50_us": 5.0, + "p95_us": 12.0, + "p99_us": 20.0, + "mean_us": 8.0, + "max_us": 25.0, + "p95_under_5ms": True, + } + observe = { + "total_cases": 6, + "passed": 6, + "failed": 0, + "failures": [], + "all_passed": True, + } + path = bundle_artifact( + results=sample_results, + metadata=sample_metadata, + overhead=overhead, + observe=observe, + output_dir=tmp_path, + run_id="test123", + ) + return json.loads(path.read_text()) + + +def test_artifact_bundle_format(tmp_path, sample_results, sample_metadata): + """Has schema_version, metadata, results keys.""" + bundle = _make_bundle(tmp_path, sample_results, sample_metadata) + assert "schema_version" in bundle + assert "metadata" in bundle + assert "results" in bundle + assert "harness_overhead" in bundle + assert "observe_validation" in bundle + + +def test_artifact_schema_version(tmp_path, sample_results, sample_metadata): + """Matches '1.0.0'.""" + bundle = _make_bundle(tmp_path, sample_results, sample_metadata) + assert bundle["schema_version"] == "1.0.0" + assert bundle["schema_version"] == SCHEMA_VERSION + + +def 
test_artifact_results_compatible(tmp_path, sample_results, sample_metadata): + """Result keys match existing format.""" + bundle = _make_bundle(tmp_path, sample_results, sample_metadata) + expected_keys = { + "accuracy", + "accept_rate", + "savings_pct", + "effective_savings_pct", + "drafter_accuracy", + "verifier_accuracy", + "total_cascade_cost", + "total_baseline_cost", + "avg_latency_ms", + "n_total", + "n_correct", + "n_accepted", + } + for name, bench in bundle["results"].items(): + assert expected_keys == set(bench.keys()), f"{name} keys mismatch: {set(bench.keys())}" diff --git a/tests/test_crewai_integration.py b/tests/test_crewai_integration.py new file mode 100644 index 00000000..c17498b4 --- /dev/null +++ b/tests/test_crewai_integration.py @@ -0,0 +1,510 @@ +"""Tests for cascadeflow.integrations.crewai harness integration. + +crewai is not installed in test environments, so we mock the hooks module +and test the integration logic directly against HarnessRunContext. +""" + +from __future__ import annotations + +import types +from unittest.mock import patch + +import pytest + +from cascadeflow.harness import init, reset, run + +# Import the module directly — it does not require crewai at import time +# (CREWAI_AVAILABLE will be False, but all functions/classes are still defined). 
+import cascadeflow.integrations.crewai as crewai_mod + + +@pytest.fixture(autouse=True) +def _reset_crewai_state(): + """Reset harness and crewai module state before every test.""" + reset() + crewai_mod._hooks_registered = False + crewai_mod._before_hook_ref = None + crewai_mod._after_hook_ref = None + crewai_mod._config = crewai_mod.CrewAIHarnessConfig() + crewai_mod._call_start_times.clear() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class FakeLLM: + """Minimal stand-in for a CrewAI LLM object.""" + + def __init__(self, model: str = "gpt-4o"): + self.model = model + + +class FakeHookContext: + """Minimal stand-in for crewai's LLMCallHookContext.""" + + def __init__( + self, + *, + llm: FakeLLM | None = None, + messages: list | None = None, + response: str | None = None, + ): + self.llm = llm or FakeLLM() + self.messages = messages or [] + self.response = response + + +def _make_fake_hooks_module(): + """Build a fake crewai.hooks module with recording registration helpers.""" + mod = types.ModuleType("crewai.hooks") + mod._before_hooks = [] + mod._after_hooks = [] + mod.register_before_llm_call_hook = lambda fn: mod._before_hooks.append(fn) + mod.register_after_llm_call_hook = lambda fn: mod._after_hooks.append(fn) + mod.unregister_before_llm_call_hook = lambda fn: ( + mod._before_hooks.remove(fn) if fn in mod._before_hooks else None + ) + mod.unregister_after_llm_call_hook = lambda fn: ( + mod._after_hooks.remove(fn) if fn in mod._after_hooks else None + ) + return mod + + +# --------------------------------------------------------------------------- +# _extract_message_content +# --------------------------------------------------------------------------- + + +class TestExtractMessageContent: + def test_dict_message(self): + msg = {"role": "user", "content": "Hello world"} + assert crewai_mod._extract_message_content(msg) == 
"Hello world"

    def test_dict_message_missing_content(self):
        msg = {"role": "system"}
        assert crewai_mod._extract_message_content(msg) == ""

    def test_dict_message_none_content(self):
        msg = {"role": "assistant", "content": None}
        assert crewai_mod._extract_message_content(msg) == ""

    def test_object_message(self):
        class Msg:
            content = "from object"

        assert crewai_mod._extract_message_content(Msg()) == "from object"

    def test_object_message_no_content(self):
        assert crewai_mod._extract_message_content(object()) == ""


# ---------------------------------------------------------------------------
# _extract_model_name
# ---------------------------------------------------------------------------


class TestExtractModelName:
    """_extract_model_name: pull the model id off ctx.llm, strip any
    "provider/" prefix, and fall back to "unknown" on anything malformed."""

    def test_extracts_plain_model(self):
        ctx = FakeHookContext(llm=FakeLLM("gpt-4o"))
        assert crewai_mod._extract_model_name(ctx) == "gpt-4o"

    def test_strips_provider_prefix(self):
        ctx = FakeHookContext(llm=FakeLLM("openai/gpt-4o-mini"))
        assert crewai_mod._extract_model_name(ctx) == "gpt-4o-mini"

    def test_no_llm_returns_unknown(self):
        ctx = FakeHookContext()
        ctx.llm = None
        assert crewai_mod._extract_model_name(ctx) == "unknown"

    def test_no_model_attr_returns_unknown(self):
        ctx = FakeHookContext()
        ctx.llm = object()  # no .model attribute
        assert crewai_mod._extract_model_name(ctx) == "unknown"

    def test_non_string_model_returns_unknown(self):
        ctx = FakeHookContext()
        ctx.llm = FakeLLM("gpt-4o")
        ctx.llm.model = 42  # not a string
        assert crewai_mod._extract_model_name(ctx) == "unknown"


# ---------------------------------------------------------------------------
# Cost / energy estimation
# ---------------------------------------------------------------------------


class TestEstimation:
    """Per-token cost/energy estimators, including the default used for
    models absent from the pricing table."""

    def test_estimate_cost_known_model(self):
        cost = crewai_mod._estimate_cost("gpt-4o-mini", 1_000_000, 1_000_000)
        assert cost == pytest.approx(0.15 + 0.60)

    def test_estimate_cost_unknown_model_uses_default(self):
        cost = crewai_mod._estimate_cost("unknown-model", 1_000_000, 0)
        assert cost == pytest.approx(2.50)

    def test_estimate_energy_known_model(self):
        energy = crewai_mod._estimate_energy("gpt-4o", 100, 100)
        # coeff=1.0, output_weight=1.5
        assert energy == pytest.approx(1.0 * (100 + 100 * 1.5))

    def test_estimate_energy_unknown_model(self):
        energy = crewai_mod._estimate_energy("unknown-model", 100, 100)
        assert energy == pytest.approx(1.0 * (100 + 100 * 1.5))


# ---------------------------------------------------------------------------
# before_llm_call_hook
# ---------------------------------------------------------------------------


class TestBeforeHook:
    """Pre-call gate: returns None to allow, False to block (enforce mode
    with budget exhausted); records call start times for latency tracking."""

    def test_no_run_context_returns_none(self):
        ctx = FakeHookContext()
        result = crewai_mod._before_llm_call_hook(ctx)
        assert result is None

    def test_observe_mode_allows(self):
        init(mode="observe", budget=0.001)
        with run(budget=0.001) as run_ctx:
            run_ctx.cost = 0.002  # over budget
            hook_ctx = FakeHookContext()
            result = crewai_mod._before_llm_call_hook(hook_ctx)
            # observe mode never blocks
            assert result is None

    def test_enforce_blocks_when_budget_exhausted(self):
        init(mode="enforce", budget=0.001)
        with run(budget=0.001) as run_ctx:
            run_ctx.cost = 0.001  # exactly at budget
            hook_ctx = FakeHookContext(llm=FakeLLM("gpt-4o"))
            result = crewai_mod._before_llm_call_hook(hook_ctx)
            assert result is False
            assert run_ctx.last_action == "stop"
            trace = run_ctx.trace()
            assert trace[-1]["reason"] == "budget_exhausted"

    def test_enforce_blocked_call_does_not_leak_start_time(self):
        """Blocked calls must not leave stale entries in _call_start_times."""
        init(mode="enforce", budget=0.001)
        with run(budget=0.001) as run_ctx:
            run_ctx.cost = 0.001
            hook_ctx = FakeHookContext(llm=FakeLLM("gpt-4o"))
            crewai_mod._before_llm_call_hook(hook_ctx)
            assert id(hook_ctx) not in crewai_mod._call_start_times

    def test_enforce_allows_when_under_budget(self):
        init(mode="enforce", budget=1.0)
        with run(budget=1.0) as run_ctx:
            run_ctx.cost = 0.5
            hook_ctx = FakeHookContext()
            result = crewai_mod._before_llm_call_hook(hook_ctx)
            assert result is None

    def test_records_start_time(self):
        init(mode="observe")
        with run():
            hook_ctx = FakeHookContext()
            crewai_mod._before_llm_call_hook(hook_ctx)
            assert id(hook_ctx) in crewai_mod._call_start_times

    def test_budget_gate_disabled_in_config(self):
        crewai_mod._config = crewai_mod.CrewAIHarnessConfig(enable_budget_gate=False)
        init(mode="enforce", budget=0.001)
        with run(budget=0.001) as run_ctx:
            run_ctx.cost = 0.002
            hook_ctx = FakeHookContext()
            result = crewai_mod._before_llm_call_hook(hook_ctx)
            assert result is None  # gate disabled, not blocked

    def test_fail_open_swallows_errors(self):
        crewai_mod._config = crewai_mod.CrewAIHarnessConfig(fail_open=True)
        init(mode="enforce")
        with run():
            hook_ctx = FakeHookContext()
            with patch(
                "cascadeflow.harness.api.get_current_run",
                side_effect=RuntimeError("boom"),
            ):
                result = crewai_mod._before_llm_call_hook(hook_ctx)
                assert result is None  # fail_open returns None

    def test_fail_closed_raises_errors(self):
        crewai_mod._config = crewai_mod.CrewAIHarnessConfig(fail_open=False)
        init(mode="enforce")
        with run():
            hook_ctx = FakeHookContext()
            with patch(
                "cascadeflow.harness.api.get_current_run",
                side_effect=RuntimeError("boom"),
            ):
                with pytest.raises(RuntimeError, match="boom"):
                    crewai_mod._before_llm_call_hook(hook_ctx)


# ---------------------------------------------------------------------------
# after_llm_call_hook
# ---------------------------------------------------------------------------


class TestAfterHook:
    """Post-call hook: folds cost/energy/latency/tool metrics into the
    active run context using char-count token estimation."""

    def test_no_run_context_returns_none(self):
        ctx = FakeHookContext(response="hello")
        result = crewai_mod._after_llm_call_hook(ctx)
        assert result is None

    def test_updates_run_metrics_with_dict_messages(self):
        """CrewAI passes messages as dicts — verify cost is nonzero."""
        init(mode="observe")
        with run(budget=1.0) as run_ctx:
            hook_ctx = FakeHookContext(
                llm=FakeLLM("gpt-4o-mini"),
                messages=[{"role": "user", "content": "What is 2+2?"}],
                response="The answer is 4.",
            )
            crewai_mod._call_start_times[id(hook_ctx)] = __import__("time").monotonic() - 0.1

            crewai_mod._after_llm_call_hook(hook_ctx)

            assert run_ctx.step_count == 1
            assert run_ctx.cost > 0
            assert run_ctx.energy_used > 0
            assert run_ctx.latency_used_ms > 0
            assert run_ctx.model_used == "gpt-4o-mini"
            assert run_ctx.last_action == "allow"

    def test_updates_run_metrics_with_object_messages(self):
        """Also support object-style messages (defensive)."""
        init(mode="observe")

        class ObjMsg:
            content = "What is 2+2?"

        with run(budget=1.0) as run_ctx:
            hook_ctx = FakeHookContext(
                llm=FakeLLM("gpt-4o-mini"),
                messages=[ObjMsg()],
                response="The answer is 4.",
            )
            crewai_mod._after_llm_call_hook(hook_ctx)
            assert run_ctx.cost > 0

    def test_updates_budget_remaining(self):
        init(mode="enforce", budget=1.0)
        with run(budget=1.0) as run_ctx:
            hook_ctx = FakeHookContext(
                llm=FakeLLM("gpt-4o"),
                messages=[{"role": "user", "content": "test"}],
                response="response",
            )
            crewai_mod._after_llm_call_hook(hook_ctx)
            assert run_ctx.budget_remaining is not None
            assert run_ctx.budget_remaining == pytest.approx(1.0 - run_ctx.cost)

    def test_trace_records_mode(self):
        init(mode="enforce")
        with run() as run_ctx:
            hook_ctx = FakeHookContext(
                llm=FakeLLM("gpt-4o"),
                messages=[{"role": "user", "content": "test"}],
                response="done",
            )
            crewai_mod._after_llm_call_hook(hook_ctx)
            trace = run_ctx.trace()
            assert len(trace) == 1
            assert trace[0]["reason"] == "enforce"
            assert trace[0]["model"] == "gpt-4o"

    def test_no_start_time_records_zero_latency(self):
        init(mode="observe")
        with run() as run_ctx:
            hook_ctx = FakeHookContext(
                llm=FakeLLM("gpt-4o"),
                messages=[],
                response="ok",
            )
            # Don't set start time
            crewai_mod._after_llm_call_hook(hook_ctx)
            assert run_ctx.latency_used_ms == 0.0

    def test_token_estimation_from_dict_messages(self):
        """Verify token estimation works with dict messages (real CrewAI shape)."""
        init(mode="observe")
        with run() as run_ctx:
            # 400 chars in messages → 100 prompt tokens
            # 80 chars in response → 20 completion tokens
            messages = [{"role": "user", "content": "x" * 400}]
            hook_ctx = FakeHookContext(
                llm=FakeLLM("gpt-4o"),
                messages=messages,
                response="y" * 80,
            )
            crewai_mod._after_llm_call_hook(hook_ctx)
            # gpt-4o: $2.50/1M in, $10.00/1M out
            expected_cost = (100 / 1_000_000) * 2.50 + (20 / 1_000_000) * 10.00
            assert run_ctx.cost == pytest.approx(expected_cost)

    def test_fail_open_swallows_errors(self):
        crewai_mod._config = crewai_mod.CrewAIHarnessConfig(fail_open=True)
        init(mode="observe")
        with run():
            hook_ctx = FakeHookContext(response="ok")
            with patch(
                "cascadeflow.harness.api.get_current_run",
                side_effect=RuntimeError("boom"),
            ):
                result = crewai_mod._after_llm_call_hook(hook_ctx)
                assert result is None


# ---------------------------------------------------------------------------
# enable / disable lifecycle
# ---------------------------------------------------------------------------


class TestEnableDisable:
    """enable()/disable() lifecycle against a fake crewai.hooks module
    injected via sys.modules (crewai itself is not installed in CI)."""

    def test_enable_returns_false_when_crewai_not_available(self):
        with patch.object(crewai_mod, "CREWAI_AVAILABLE", False):
            result = crewai_mod.enable()
            assert result is False
            assert not crewai_mod.is_enabled()

    def test_enable_registers_hooks(self, monkeypatch):
        fake_hooks = _make_fake_hooks_module()
        monkeypatch.setattr(crewai_mod, "CREWAI_AVAILABLE", True)

        import sys

        monkeypatch.setitem(sys.modules, "crewai.hooks", fake_hooks)

        result = crewai_mod.enable()
        assert result is True
        assert crewai_mod.is_enabled()
        assert len(fake_hooks._before_hooks) == 1
        assert len(fake_hooks._after_hooks) == 1

    def test_enable_is_idempotent(self, monkeypatch):
        fake_hooks = _make_fake_hooks_module()
        monkeypatch.setattr(crewai_mod, "CREWAI_AVAILABLE", True)

        import sys

        monkeypatch.setitem(sys.modules, "crewai.hooks", fake_hooks)

        crewai_mod.enable()
        crewai_mod.enable()  # second call
        assert len(fake_hooks._before_hooks) == 1  # still just one

    def test_enable_applies_config(self, monkeypatch):
        fake_hooks = _make_fake_hooks_module()
        monkeypatch.setattr(crewai_mod, "CREWAI_AVAILABLE", True)

        import sys

        monkeypatch.setitem(sys.modules, "crewai.hooks", fake_hooks)

        custom_config = crewai_mod.CrewAIHarnessConfig(fail_open=False, enable_budget_gate=False)
        crewai_mod.enable(config=custom_config)

        cfg = crewai_mod.get_config()
        assert cfg.fail_open is False
        assert cfg.enable_budget_gate is False

    def test_disable_unregisters_hooks(self, monkeypatch):
        fake_hooks = _make_fake_hooks_module()
        monkeypatch.setattr(crewai_mod, "CREWAI_AVAILABLE", True)

        import sys

        monkeypatch.setitem(sys.modules, "crewai.hooks", fake_hooks)

        crewai_mod.enable()
        assert crewai_mod.is_enabled()
        assert len(fake_hooks._before_hooks) == 1

        crewai_mod.disable()
        assert not crewai_mod.is_enabled()
        assert len(fake_hooks._before_hooks) == 0
        assert len(fake_hooks._after_hooks) == 0

    def test_disable_when_not_enabled_is_safe(self):
        crewai_mod.disable()  # should not raise
        assert not crewai_mod.is_enabled()

    def test_disable_clears_call_start_times(self, monkeypatch):
        fake_hooks = _make_fake_hooks_module()
        monkeypatch.setattr(crewai_mod, "CREWAI_AVAILABLE", True)

        import sys

        monkeypatch.setitem(sys.modules, "crewai.hooks", fake_hooks)

        crewai_mod.enable()
        crewai_mod._call_start_times[123] = 1.0
        crewai_mod.disable()
        assert len(crewai_mod._call_start_times) == 0

    def test_enable_returns_false_for_old_crewai(self, monkeypatch):
        """When crewai is installed but lacks hooks module (< v1.5)."""
        monkeypatch.setattr(crewai_mod, "CREWAI_AVAILABLE", True)

        import sys

        # Remove crewai.hooks from modules so import fails
        monkeypatch.delitem(sys.modules, "crewai.hooks", raising=False)

        original_import = (
            __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__
        )

        def fake_import(name, *args, **kwargs):
            if name == "crewai.hooks":
                raise ImportError("no hooks")
            return original_import(name, *args, **kwargs)

        monkeypatch.setattr("builtins.__import__", fake_import)

        result = crewai_mod.enable()
        assert result is False


# ---------------------------------------------------------------------------
# Public API helpers
# ---------------------------------------------------------------------------


class TestPublicAPI:
    """Small read-only accessors: is_available, is_enabled, get_config."""

    def test_is_available_reflects_module_flag(self):
        # crewai is not installed in test env
        assert crewai_mod.is_available() == crewai_mod.CREWAI_AVAILABLE

    def test_is_enabled_default_false(self):
        assert crewai_mod.is_enabled() is False

    def test_get_config_returns_copy(self):
        cfg = crewai_mod.get_config()
        assert isinstance(cfg, crewai_mod.CrewAIHarnessConfig)
        assert cfg.fail_open is True
        assert cfg.enable_budget_gate is True
        # Modifying the copy doesn't affect the module state
        cfg.fail_open = False
        assert crewai_mod.get_config().fail_open is True


# ---------------------------------------------------------------------------
# CrewAIHarnessConfig
# ---------------------------------------------------------------------------


class TestConfig:
    """CrewAIHarnessConfig dataclass defaults and overrides."""

    def test_defaults(self):
        cfg = crewai_mod.CrewAIHarnessConfig()
        assert cfg.fail_open is True
        assert cfg.enable_budget_gate is True

    def test_custom_values(self):
        cfg = crewai_mod.CrewAIHarnessConfig(fail_open=False, enable_budget_gate=False)
        assert cfg.fail_open is False
        assert cfg.enable_budget_gate is False
diff --git a/tests/test_harness_api.py
b/tests/test_harness_api.py
new file mode 100644
index 00000000..850255ba
--- /dev/null
+++ b/tests/test_harness_api.py
@@ -0,0 +1,476 @@
import sys

import pytest

import cascadeflow
import cascadeflow.harness.api as harness_api
from cascadeflow.harness import agent, get_current_run, get_harness_config, init, reset, run
from cascadeflow.telemetry.callbacks import CallbackEvent, CallbackManager


# Harness state is module-global; reset before every test for isolation.
def setup_function() -> None:
    reset()


# --- init(): explicit code config -----------------------------------------


def test_init_sets_mode_and_returns_report():
    report = init(mode="observe", budget=1.5, max_tool_calls=7)

    cfg = get_harness_config()
    assert cfg.mode == "observe"
    assert cfg.budget == 1.5
    assert cfg.max_tool_calls == 7
    assert report.mode == "observe"
    assert isinstance(report.instrumented, list)
    assert isinstance(report.detected_but_not_instrumented, list)
    assert report.config_sources["mode"] == "code"


def test_init_rejects_invalid_mode():
    with pytest.raises(ValueError):
        init(mode="invalid")  # type: ignore[arg-type]


def test_init_idempotent_logs(monkeypatch, caplog):
    monkeypatch.setattr(harness_api, "find_spec", lambda _: None)
    with caplog.at_level("DEBUG", logger="cascadeflow.harness"):
        init(mode="observe")
        init(mode="observe")
    assert any("idempotent" in rec.message for rec in caplog.records)


# --- init(): environment variable config ----------------------------------


def test_env_aliases_and_false_bool(monkeypatch):
    # Short aliases (CASCADEFLOW_MODE/BUDGET) and falsy bool spelling "off".
    monkeypatch.setenv("CASCADEFLOW_MODE", "observe")
    monkeypatch.setenv("CASCADEFLOW_BUDGET", "0.33")
    monkeypatch.setenv("CASCADEFLOW_HARNESS_VERBOSE", "off")
    monkeypatch.setenv("CASCADEFLOW_HARNESS_MAX_TOOL_CALLS", "4")
    monkeypatch.setenv("CASCADEFLOW_HARNESS_MAX_LATENCY_MS", "1200")
    monkeypatch.setenv("CASCADEFLOW_HARNESS_MAX_ENERGY", "0.01")
    monkeypatch.setenv("CASCADEFLOW_HARNESS_COMPLIANCE", "gdpr")

    report = init()
    cfg = get_harness_config()

    assert report.mode == "observe"
    assert cfg.mode == "observe"
    assert cfg.budget == 0.33
    assert cfg.verbose is False
    assert cfg.max_tool_calls == 4
    assert cfg.max_latency_ms == 1200
    assert cfg.max_energy == 0.01
    assert cfg.compliance == "gdpr"


def test_init_invalid_json_env_raises(monkeypatch):
    # Valid JSON but wrong shape (list, not object) must be rejected.
    monkeypatch.setenv("CASCADEFLOW_HARNESS_KPI_WEIGHTS", "[1,2,3]")
    with pytest.raises(ValueError):
        init()


def test_init_non_numeric_env_raises(monkeypatch):
    monkeypatch.setenv("CASCADEFLOW_HARNESS_BUDGET", "abc")
    with pytest.raises(ValueError):
        init()


# --- run(): defaults, overrides, nesting ----------------------------------


def test_run_uses_global_defaults_and_overrides():
    init(
        mode="enforce",
        budget=2.0,
        max_tool_calls=5,
        kpi_targets={"quality_min": 0.9},
        kpi_weights={"cost": 0.7, "quality": 0.3},
        compliance="gdpr",
    )

    default_ctx = run()
    assert default_ctx.mode == "enforce"
    assert default_ctx.budget_max == 2.0
    assert default_ctx.tool_calls_max == 5
    assert default_ctx.budget_remaining == 2.0
    assert default_ctx.kpi_targets == {"quality_min": 0.9}
    assert default_ctx.kpi_weights == {"cost": 0.7, "quality": 0.3}
    assert default_ctx.compliance == "gdpr"

    override_ctx = run(
        budget=0.5,
        max_tool_calls=3,
        kpi_weights={"quality": 1.0},
        compliance="strict",
    )
    assert override_ctx.budget_max == 0.5
    assert override_ctx.tool_calls_max == 3
    assert override_ctx.budget_remaining == 0.5
    # kpi_targets not overridden, so the global default is kept.
    assert override_ctx.kpi_targets == {"quality_min": 0.9}
    assert override_ctx.kpi_weights == {"quality": 1.0}
    assert override_ctx.compliance == "strict"


def test_run_without_enter_exit_is_safe():
    ctx = run()
    ctx.__exit__(None, None, None)


@pytest.mark.asyncio
async def test_nested_run_context_is_isolated():
    init(mode="enforce", budget=1.0)

    async with run(budget=0.7) as outer:
        assert get_current_run() is outer
        assert outer.budget_max == 0.7

        async with run(budget=0.2) as inner:
            assert get_current_run() is inner
            assert inner.budget_max == 0.2

        assert get_current_run() is outer

    assert get_current_run() is None


def test_sync_run_context_isolated():
    init(mode="enforce", budget=1.0)
    with run(budget=0.6) as outer:
        assert get_current_run() is outer
        with run(budget=0.1) as inner:
            assert get_current_run() is inner
            assert inner.budget_max == 0.1
        assert get_current_run() is outer
    assert get_current_run() is None


# --- @agent decorator -----------------------------------------------------


def test_agent_decorator_keeps_sync_behavior_and_attaches_metadata():
    @agent(
        budget=0.9,
        kpi_targets={"quality_min": 0.9},
        kpi_weights={"cost": 0.5, "quality": 0.5},
        compliance="gdpr",
    )
    def fn(x: int) -> int:
        return x + 1

    assert fn(2) == 3
    policy = fn.__cascadeflow_agent_policy__
    assert policy["budget"] == 0.9
    assert policy["kpi_targets"] == {"quality_min": 0.9}
    assert policy["compliance"] == "gdpr"


def test_agent_decorator_preserves_function_metadata():
    @agent(budget=0.5)
    def fn(x: int) -> int:
        """sample doc"""
        return x

    assert fn.__name__ == "fn"
    assert fn.__doc__ == "sample doc"
    assert fn.__annotations__ == {"x": int, "return": int}


@pytest.mark.asyncio
async def test_agent_decorator_keeps_async_behavior_and_attaches_metadata():
    @agent(budget=0.4, kpi_weights={"cost": 1.0})
    async def fn(x: int) -> int:
        return x * 2

    assert await fn(4) == 8
    policy = fn.__cascadeflow_agent_policy__
    assert policy["budget"] == 0.4
    assert policy["kpi_weights"] == {"cost": 1.0}


def test_top_level_exports_exist():
    assert callable(cascadeflow.init)
    assert callable(cascadeflow.reset)
    assert callable(cascadeflow.run)
    assert callable(cascadeflow.harness_agent)
    # cascadeflow.agent stays the legacy agent module, not the decorator.
    assert hasattr(cascadeflow.agent, "PROVIDER_REGISTRY")
    assert callable(cascadeflow.get_harness_callback_manager)
    assert callable(cascadeflow.set_harness_callback_manager)
    report = cascadeflow.init(mode="off")
    assert report.mode == "off"


def test_run_record_and_trace_copy():
    ctx = run(budget=1.0)
    ctx.record(action="switch_model", reason="cost_pressure", model="gpt-4o-mini")
    trace_a = ctx.trace()
    trace_b = ctx.trace()
    assert trace_a == trace_b
    assert trace_a[0]["action"] == "switch_model"
    assert "budget_state" in trace_a[0]
    assert trace_a[0]["budget_state"]["max"] == 1.0
    # trace() returns a defensive copy — mutating it must not affect state.
    trace_a.append({"action": "mutated"})
    assert len(ctx.trace()) == 1


def test_init_reads_from_env(monkeypatch):
    monkeypatch.setenv("CASCADEFLOW_HARNESS_MODE", "observe")
    monkeypatch.setenv("CASCADEFLOW_HARNESS_BUDGET", "0.25")
    monkeypatch.setenv("CASCADEFLOW_HARNESS_KPI_TARGETS", '{"quality_min": 0.9}')
    monkeypatch.setenv("CASCADEFLOW_HARNESS_KPI_WEIGHTS", '{"cost": 1.0}')

    report = init()
    cfg = get_harness_config()

    assert report.mode == "observe"
    assert cfg.mode == "observe"
    assert cfg.budget == 0.25
    assert cfg.kpi_targets == {"quality_min": 0.9}
    assert cfg.kpi_weights == {"cost": 1.0}
    assert report.config_sources["mode"] == "env"
    assert report.config_sources["budget"] == "env"


def test_init_rejects_oversized_env_json(monkeypatch):
    monkeypatch.setenv("CASCADEFLOW_HARNESS_KPI_TARGETS", "x" * 5000)
    with pytest.raises(ValueError, match="JSON config exceeds"):
        init()


# --- init(): config file --------------------------------------------------


def test_init_reads_from_config_file(tmp_path, monkeypatch):
    config = tmp_path / "cascadeflow.json"
    config.write_text(
        '{"harness":{"mode":"observe","budget":0.75,"max_tool_calls":11,"kpi_targets":{"quality_min":0.9}}}'
    )
    monkeypatch.setenv("CASCADEFLOW_CONFIG", str(config))

    report = init()
    cfg = get_harness_config()

    assert cfg.mode == "observe"
    assert cfg.budget == 0.75
    assert cfg.max_tool_calls == 11
    assert cfg.kpi_targets == {"quality_min": 0.9}
    assert report.config_sources["mode"] == "file"
    assert report.config_sources["budget"] == "file"


def test_init_reads_top_level_config_file_keys(tmp_path, monkeypatch):
    # Keys may live at the top level instead of under "harness".
    config = tmp_path / "cascadeflow.json"
    config.write_text('{"mode":"observe","budget":0.4,"max_tool_calls":2}')
    monkeypatch.setenv("CASCADEFLOW_CONFIG", str(config))

    report = init()
    cfg = get_harness_config()

    assert cfg.mode == "observe"
    assert cfg.budget == 0.4
    assert cfg.max_tool_calls == 2
    assert report.config_sources["mode"] == "file"


def test_init_non_dict_config_file_ignored(tmp_path, monkeypatch):
    config = tmp_path / "cascadeflow.json"
    config.write_text('["not-a-dict"]')
    monkeypatch.setenv("CASCADEFLOW_CONFIG", str(config))

    report = init()
    cfg = get_harness_config()

    assert cfg.mode == "off"
    assert cfg.budget is None
    assert report.config_sources["mode"] == "default"


def test_init_file_loader_exception_falls_back_defaults(monkeypatch):
    import cascadeflow.config_loader as cl

    monkeypatch.setattr(cl, "find_config", lambda: "broken.json")

    def _raise(_path):
        raise RuntimeError("boom")

    monkeypatch.setattr(cl, "load_config", _raise)

    report = init()
    cfg = get_harness_config()
    assert cfg.mode == "off"
    assert report.config_sources["mode"] == "default"


def test_init_config_loader_import_failure_falls_back(monkeypatch):
    # A bogus module object makes attribute access on the loader fail.
    monkeypatch.setitem(sys.modules, "cascadeflow.config_loader", object())
    report = init(mode="observe")
    assert report.mode == "observe"
    assert report.config_sources["mode"] == "code"


def test_precedence_code_over_env_over_file(tmp_path, monkeypatch):
    config = tmp_path / "cascadeflow.json"
    config.write_text('{"harness":{"mode":"off","budget":9.9}}')
    monkeypatch.setenv("CASCADEFLOW_CONFIG", str(config))
    monkeypatch.setenv("CASCADEFLOW_HARNESS_MODE", "observe")
    monkeypatch.setenv("CASCADEFLOW_HARNESS_BUDGET", "0.5")

    # env overrides file
    report_env = init()
    cfg_env = get_harness_config()
    assert cfg_env.mode == "observe"
    assert cfg_env.budget == 0.5
    assert report_env.config_sources["mode"] == "env"
    assert report_env.config_sources["budget"] == "env"

    # code overrides env
    report_code = init(mode="enforce", budget=0.2)
    cfg_code = get_harness_config()
    assert cfg_code.mode == "enforce"
    assert cfg_code.budget == 0.2
    assert report_code.config_sources["mode"] == "code"
    assert report_code.config_sources["budget"] == "code"


def test_reset_clears_state():
    init(mode="enforce", budget=0.9)
    with run() as ctx:
        assert get_current_run() is ctx
    reset()
    cfg = get_harness_config()
    assert cfg.mode == "off"
    assert cfg.budget is None
    assert get_current_run() is None


# --- init(): SDK detection / auto-instrumentation reporting ---------------


def test_init_without_detected_sdks(monkeypatch):
    monkeypatch.setattr(harness_api, "find_spec", lambda _: None)
    report = init(mode="observe")
    assert report.instrumented == []
    assert report.detected_but_not_instrumented == []


def test_init_reports_openai_instrumented_when_patch_succeeds(monkeypatch):
    monkeypatch.setattr(
        harness_api,
        "find_spec",
        lambda name: object() if name == "openai" else None,
    )

    import cascadeflow.harness.instrument as instrument

    monkeypatch.setattr(instrument, "patch_openai", lambda: True)
    report = init(mode="observe")
    assert report.instrumented == ["openai"]


def test_init_reports_anthropic_instrumented_when_patch_succeeds(monkeypatch):
    monkeypatch.setattr(
        harness_api,
        "find_spec",
        lambda name: object() if name == "anthropic" else None,
    )

    import cascadeflow.harness.instrument as instrument

    monkeypatch.setattr(instrument, "patch_anthropic", lambda: True)
    report = init(mode="observe")
    assert report.instrumented == ["anthropic"]


def test_init_reports_anthropic_detected_not_instrumented_on_patch_failure(monkeypatch):
    monkeypatch.setattr(
        harness_api,
        "find_spec",
        lambda name: object() if name == "anthropic" else None,
    )

    import cascadeflow.harness.instrument as instrument

    monkeypatch.setattr(instrument, "patch_anthropic", lambda: False)
    report = init(mode="observe")
    assert report.instrumented == []
    assert report.detected_but_not_instrumented == ["anthropic"]


# --- run context: summary, logging, callbacks, trace hygiene --------------


def test_run_summary_populates_on_context_exit():
    init(mode="observe")
    with run(budget=1.5) as ctx:
        ctx.step_count = 2
        ctx.tool_calls = 1
        ctx.cost = 0.42
        ctx.latency_used_ms = 123.0
        ctx.energy_used = 33.0
        ctx.budget_remaining = 1.08
        ctx.last_action = "allow"
        ctx.model_used = "gpt-4o-mini"

    summary = ctx.summary()
    assert summary["run_id"] == ctx.run_id
    assert summary["step_count"] == 2
    assert summary["budget_remaining"] == pytest.approx(1.08)
    assert summary["duration_ms"] is not None
    assert summary["duration_ms"] >= 0.0
    assert ctx.duration_ms is not None
    assert ctx.duration_ms >= 0.0


def test_run_context_logs_summary(caplog):
    init(mode="observe")
    with caplog.at_level("INFO", logger="cascadeflow.harness"):
        with run(budget=1.0) as ctx:
            ctx.step_count = 1
            ctx.cost = 0.01
            ctx.model_used = "gpt-4o-mini"

    assert any("harness run summary" in rec.message for rec in caplog.records)


def test_record_emits_cascade_decision_callback():
    manager = CallbackManager()
    received = []

    def _on_decision(data):
        received.append(data)

    manager.register(CallbackEvent.CASCADE_DECISION, _on_decision)
    report = init(mode="observe", callback_manager=manager)
    assert report.config_sources["callback_manager"] == "code"

    with run(budget=1.0) as ctx:
        ctx.step_count = 1
        ctx.record(action="switch_model", reason="budget_pressure", model="gpt-4o-mini")

    assert len(received) == 1
    event = received[0]
    assert event.event == CallbackEvent.CASCADE_DECISION
    assert event.query == "[harness]"
    assert event.workflow == "harness"
    assert event.data["action"] == "switch_model"
    assert event.data["run_id"] == ctx.run_id


def test_record_sanitizes_trace_values():
    ctx = run()
    ctx.record(
        action="allow\nnewline",
        reason="a" * 400,
        model="model\r\nname",
    )
    entry = ctx.trace()[0]
    assert "\n" not in entry["action"]
    assert "\r" not in entry["model"]
    assert len(entry["reason"]) <= 160


def test_record_sanitizes_non_printable_values():
    ctx = run()
    ctx.record(action="allow\x00", reason="ok\x1f", model="gpt-4o-mini\x07")
    entry = ctx.trace()[0]
    assert "\x00" not in entry["action"]
    assert "\x1f" not in entry["reason"]
    assert "\x07" not in entry["model"]


def test_record_without_callback_manager_is_noop():
    init(mode="observe")
    with run(budget=1.0) as ctx:
        ctx.record(action="allow", reason="test", model="gpt-4o-mini")
        assert len(ctx.trace()) == 1


def test_record_empty_action_warns_and_defaults(caplog):
    init(mode="observe")
    with caplog.at_level("WARNING", logger="cascadeflow.harness"):
        with run(budget=1.0) as ctx:
            ctx.record(action="", reason="test", model="gpt-4o-mini")
    entry = ctx.trace()[0]
    assert entry["action"] == "allow"
    assert any("empty action" in rec.message for rec in caplog.records)
diff --git a/tests/test_harness_instrument.py b/tests/test_harness_instrument.py
new file mode 100644
index 00000000..55e71837
--- /dev/null
+++ b/tests/test_harness_instrument.py
@@ -0,0 +1,1502 @@
"""Tests for cascadeflow.harness.instrument — OpenAI + Anthropic auto-instrumentation."""

from __future__ import annotations

from importlib.util import find_spec
import time
from typing import Optional
from unittest.mock import AsyncMock, MagicMock

import pytest

pytest.importorskip("openai", reason="openai package required for instrumentation tests")

from cascadeflow.harness import init, reset, run
from cascadeflow.harness.instrument import (
    _InstrumentedAnthropicAsyncStream,
    _InstrumentedAnthropicStream,
    _InstrumentedAsyncStream,
    _InstrumentedStream,
    _count_tool_calls_in_anthropic_response,
    _estimate_cost,
    _estimate_energy,
    _extract_anthropic_usage,
    _make_patched_anthropic_async_create,
    _make_patched_anthropic_create,
    _make_patched_async_create,
    _make_patched_create,
    is_anthropic_patched,
    is_openai_patched,
    is_patched,
    patch_anthropic,
    patch_openai,
    unpatch_anthropic,
    unpatch_openai,
)


# Patching mutates openai/anthropic classes globally; always reset around
# every test so one test's patch never leaks into the next.
@pytest.fixture(autouse=True)
def _reset_harness() -> None:
    reset()
    yield  # type: ignore[misc]
    reset()


# ---------------------------------------------------------------------------
# Mock helpers
#
--------------------------------------------------------------------------- + + +def _mock_usage(prompt_tokens: int = 100, completion_tokens: int = 50) -> MagicMock: + u = MagicMock() + u.prompt_tokens = prompt_tokens + u.completion_tokens = completion_tokens + return u + + +def _mock_completion( + prompt_tokens: int = 100, + completion_tokens: int = 50, + tool_calls: Optional[list] = None, +) -> MagicMock: + msg = MagicMock() + msg.tool_calls = tool_calls + choice = MagicMock() + choice.message = msg + resp = MagicMock() + resp.usage = _mock_usage(prompt_tokens, completion_tokens) + resp.choices = [choice] + return resp + + +def _mock_tool_call(tc_id: str) -> MagicMock: + tc = MagicMock() + tc.id = tc_id + return tc + + +def _mock_stream_chunk( + content: str = "hi", + usage: Optional[MagicMock] = None, + tool_calls: Optional[list] = None, +) -> MagicMock: + delta = MagicMock() + delta.content = content + delta.tool_calls = tool_calls + choice = MagicMock() + choice.delta = delta + chunk = MagicMock() + chunk.choices = [choice] + chunk.usage = usage + return chunk + + +# --------------------------------------------------------------------------- +# Patch lifecycle +# --------------------------------------------------------------------------- + + +class TestPatchLifecycle: + def test_patch_and_unpatch(self) -> None: + assert not is_openai_patched() + result = patch_openai() + assert result is True + assert is_openai_patched() + unpatch_openai() + assert not is_openai_patched() + + def test_idempotent_patching(self) -> None: + patch_openai() + patch_openai() + assert is_openai_patched() + unpatch_openai() + assert not is_openai_patched() + + def test_unpatch_without_prior_patch(self) -> None: + unpatch_openai() # should not raise + + def test_init_observe_patches(self) -> None: + report = init(mode="observe") + assert "openai" in report.instrumented + assert is_openai_patched() + + def test_init_enforce_patches(self) -> None: + report = init(mode="enforce") + assert 
"openai" in report.instrumented + assert is_openai_patched() + + def test_init_off_does_not_patch(self) -> None: + init(mode="off") + assert not is_patched() + + def test_reset_unpatches(self) -> None: + init(mode="observe") + assert is_openai_patched() + reset() + assert not is_patched() + + def test_class_method_actually_replaced(self) -> None: + from openai.resources.chat.completions import Completions + + original = Completions.create + patch_openai() + assert Completions.create is not original + unpatch_openai() + assert Completions.create is original + + def test_patch_and_unpatch_anthropic(self) -> None: + if find_spec("anthropic") is None: + pytest.skip("anthropic package not available") + assert not is_anthropic_patched() + result = patch_anthropic() + assert result is True + assert is_anthropic_patched() + unpatch_anthropic() + assert not is_anthropic_patched() + + def test_anthropic_class_method_actually_replaced(self) -> None: + if find_spec("anthropic") is None: + pytest.skip("anthropic package not available") + from anthropic.resources.messages import Messages + + original = Messages.create + patch_anthropic() + assert Messages.create is not original + unpatch_anthropic() + assert Messages.create is original + + +# --------------------------------------------------------------------------- +# Sync wrapper +# --------------------------------------------------------------------------- + + +class TestSyncWrapper: + def test_observe_passes_through_response(self) -> None: + init(mode="observe") + mock_resp = _mock_completion() + original = MagicMock(return_value=mock_resp) + wrapper = _make_patched_create(original) + + with run(budget=1.0) as ctx: + result = wrapper(MagicMock(), model="gpt-4o-mini") + + assert result is mock_resp + original.assert_called_once() + + def test_observe_tracks_cost(self) -> None: + init(mode="observe") + mock_resp = _mock_completion(prompt_tokens=1_000_000, completion_tokens=1_000_000) + original = 
MagicMock(return_value=mock_resp) + wrapper = _make_patched_create(original) + + with run(budget=10.0) as ctx: + wrapper(MagicMock(), model="gpt-4o-mini") + + # gpt-4o-mini: $0.15/1M in + $0.60/1M out = $0.75 + assert ctx.cost == pytest.approx(0.75, abs=0.01) + + def test_observe_tracks_step_count(self) -> None: + init(mode="observe") + mock_resp = _mock_completion() + original = MagicMock(return_value=mock_resp) + wrapper = _make_patched_create(original) + + with run(budget=1.0) as ctx: + wrapper(MagicMock(), model="gpt-4o-mini") + wrapper(MagicMock(), model="gpt-4o-mini") + + assert ctx.step_count == 2 + + def test_observe_tracks_tool_calls(self) -> None: + init(mode="observe") + tc1 = _mock_tool_call("tc_1") + tc2 = _mock_tool_call("tc_2") + mock_resp = _mock_completion(tool_calls=[tc1, tc2]) + original = MagicMock(return_value=mock_resp) + wrapper = _make_patched_create(original) + + with run(budget=1.0) as ctx: + wrapper(MagicMock(), model="gpt-4o") + + assert ctx.tool_calls == 2 + + def test_observe_tracks_energy(self) -> None: + init(mode="observe") + mock_resp = _mock_completion(prompt_tokens=1000, completion_tokens=500) + original = MagicMock(return_value=mock_resp) + wrapper = _make_patched_create(original) + + with run(budget=1.0) as ctx: + wrapper(MagicMock(), model="gpt-4o-mini") + + # gpt-4o-mini coefficient=0.3, output_weight=1.5 + # energy = 0.3 * (1000 + 500 * 1.5) = 0.3 * 1750 = 525.0 + assert ctx.energy_used == pytest.approx(525.0) + + def test_observe_tracks_latency(self) -> None: + init(mode="observe") + mock_resp = _mock_completion() + original = MagicMock(return_value=mock_resp) + wrapper = _make_patched_create(original) + + with run(budget=1.0) as ctx: + wrapper(MagicMock(), model="gpt-4o-mini") + + assert ctx.latency_used_ms > 0 + + def test_budget_remaining_decreases(self) -> None: + init(mode="observe") + mock_resp = _mock_completion(prompt_tokens=1_000_000, completion_tokens=1_000_000) + original = MagicMock(return_value=mock_resp) + 
wrapper = _make_patched_create(original) + + with run(budget=10.0) as ctx: + wrapper(MagicMock(), model="gpt-4o-mini") + + assert ctx.budget_remaining is not None + assert ctx.budget_remaining < 10.0 + assert ctx.budget_remaining == pytest.approx(10.0 - 0.75, abs=0.01) + + def test_model_used_and_trace(self) -> None: + init(mode="observe") + mock_resp = _mock_completion() + original = MagicMock(return_value=mock_resp) + wrapper = _make_patched_create(original) + + with run(budget=1.0) as ctx: + wrapper(MagicMock(), model="gpt-4o") + + assert ctx.model_used == "gpt-4o" + trace = ctx.trace() + assert len(trace) == 1 + assert trace[0]["action"] == "allow" + assert trace[0]["reason"] == "observe" + assert trace[0]["model"] == "gpt-4o" + assert trace[0]["applied"] is True + assert trace[0]["decision_mode"] == "observe" + + def test_off_mode_passthrough_no_tracking(self) -> None: + init(mode="off") + mock_resp = _mock_completion() + original = MagicMock(return_value=mock_resp) + wrapper = _make_patched_create(original) + + with run() as ctx: + result = wrapper(MagicMock(), model="gpt-4o") + + assert result is mock_resp + assert ctx.cost == 0.0 + assert ctx.step_count == 0 + + def test_no_run_scope_logs_but_does_not_track(self) -> None: + init(mode="observe") + mock_resp = _mock_completion() + original = MagicMock(return_value=mock_resp) + wrapper = _make_patched_create(original) + + # Call outside any run() scope + result = wrapper(MagicMock(), model="gpt-4o") + assert result is mock_resp + + def test_multiple_calls_accumulate(self) -> None: + init(mode="observe") + mock_resp = _mock_completion(prompt_tokens=1_000_000, completion_tokens=1_000_000) + original = MagicMock(return_value=mock_resp) + wrapper = _make_patched_create(original) + + with run(budget=10.0) as ctx: + wrapper(MagicMock(), model="gpt-4o-mini") + wrapper(MagicMock(), model="gpt-4o-mini") + + assert ctx.cost == pytest.approx(1.50, abs=0.01) + assert ctx.step_count == 2 + assert len(ctx.trace()) == 2 + + 
# ---------------------------------------------------------------------------
# Async wrapper
# ---------------------------------------------------------------------------


class TestAsyncWrapper:
    """Async OpenAI-style wrapper: same tracking semantics as the sync path."""

    @pytest.mark.asyncio
    async def test_observe_passes_through_response(self) -> None:
        init(mode="observe")
        mock_resp = _mock_completion()
        original = AsyncMock(return_value=mock_resp)
        wrapper = _make_patched_async_create(original)

        async with run(budget=1.0) as ctx:
            result = await wrapper(MagicMock(), model="gpt-4o")

        assert result is mock_resp

    @pytest.mark.asyncio
    async def test_observe_tracks_cost(self) -> None:
        init(mode="observe")
        mock_resp = _mock_completion(prompt_tokens=1_000_000, completion_tokens=1_000_000)
        original = AsyncMock(return_value=mock_resp)
        wrapper = _make_patched_async_create(original)

        async with run(budget=10.0) as ctx:
            await wrapper(MagicMock(), model="gpt-4o-mini")

        # gpt-4o-mini at 1M/1M tokens = $0.75 (matches the sync-path expectation).
        assert ctx.cost == pytest.approx(0.75, abs=0.01)
        assert ctx.step_count == 1

    @pytest.mark.asyncio
    async def test_off_mode_passthrough(self) -> None:
        init(mode="off")
        mock_resp = _mock_completion()
        original = AsyncMock(return_value=mock_resp)
        wrapper = _make_patched_async_create(original)

        async with run() as ctx:
            result = await wrapper(MagicMock(), model="gpt-4o")

        assert result is mock_resp
        assert ctx.cost == 0.0


# ---------------------------------------------------------------------------
# Sync stream wrapper
# ---------------------------------------------------------------------------


class TestSyncStreamWrapper:
    """_InstrumentedStream: passthrough iteration plus deferred usage accounting."""

    def test_stream_yields_all_chunks(self) -> None:
        # The wrapper must yield the provider's chunk objects untouched.
        init(mode="observe")
        chunk1 = _mock_stream_chunk("Hello")
        chunk2 = _mock_stream_chunk(" world", usage=_mock_usage(100, 50))
        mock_stream = iter([chunk1, chunk2])

        with run(budget=1.0) as ctx:
            wrapped = _InstrumentedStream(mock_stream, ctx, "gpt-4o-mini", time.monotonic())
            chunks = list(wrapped)

        assert len(chunks) == 2
        assert chunks[0] is chunk1
        assert chunks[1] is chunk2

    def test_stream_tracks_cost_after_consumption(self) -> None:
        # Usage arrives in the final chunk; cost is booked once the stream drains.
        init(mode="observe")
        chunk1 = _mock_stream_chunk("Hello")
        chunk2 = _mock_stream_chunk(" world", usage=_mock_usage(1_000_000, 1_000_000))
        mock_stream = iter([chunk1, chunk2])

        with run(budget=10.0) as ctx:
            wrapped = _InstrumentedStream(mock_stream, ctx, "gpt-4o-mini", time.monotonic())
            list(wrapped)

        assert ctx.cost == pytest.approx(0.75, abs=0.01)
        assert ctx.step_count == 1

    def test_stream_tracks_tool_calls(self) -> None:
        init(mode="observe")
        tc = _mock_tool_call("tc_1")
        chunk1 = _mock_stream_chunk("", tool_calls=[tc])
        chunk2 = _mock_stream_chunk("", usage=_mock_usage(100, 50))
        mock_stream = iter([chunk1, chunk2])

        with run(budget=1.0) as ctx:
            wrapped = _InstrumentedStream(mock_stream, ctx, "gpt-4o", time.monotonic())
            list(wrapped)

        assert ctx.tool_calls == 1

    def test_stream_context_manager(self) -> None:
        # The wrapper must proxy the inner stream's context-manager protocol.
        init(mode="observe")
        chunk1 = _mock_stream_chunk("data", usage=_mock_usage(100, 50))
        mock_inner = MagicMock()
        mock_inner.__iter__ = MagicMock(return_value=iter([chunk1]))
        mock_inner.__next__ = MagicMock(side_effect=[chunk1, StopIteration])
        mock_inner.__enter__ = MagicMock(return_value=mock_inner)
        mock_inner.__exit__ = MagicMock(return_value=False)

        with run(budget=1.0) as ctx:
            with _InstrumentedStream(mock_inner, ctx, "gpt-4o-mini", time.monotonic()) as stream:
                for _ in stream:
                    pass

        assert ctx.step_count == 1

    def test_stream_finalize_is_idempotent(self) -> None:
        init(mode="observe")
        chunk1 = _mock_stream_chunk("data", usage=_mock_usage(100, 50))
        mock_stream = iter([chunk1])

        with run(budget=1.0) as ctx:
            wrapped = _InstrumentedStream(mock_stream, ctx, "gpt-4o-mini", time.monotonic())
            list(wrapped)
            # Force finalize again
            wrapped._finalize()

        assert ctx.step_count == 1  # Should not double-count

    def test_stream_finalizes_on_iteration_error(self) -> None:
        # Even when the underlying stream raises mid-iteration, the usage seen
        # so far must still be booked (finalize on error path).
        init(mode="observe")
        chunk1 = _mock_stream_chunk("data", usage=_mock_usage(100, 50))

        class _FailingStream:
            # Yields one chunk, then raises — simulates a broken provider stream.
            def __init__(self) -> None:
                self._done = False

            def __iter__(self):
                return self

            def __next__(self):
                if not self._done:
                    self._done = True
                    return chunk1
                raise RuntimeError("stream failed")

        with run(budget=1.0) as ctx:
            wrapped = _InstrumentedStream(_FailingStream(), ctx, "gpt-4o-mini", time.monotonic())
            with pytest.raises(RuntimeError, match="stream failed"):
                list(wrapped)

        assert ctx.step_count == 1
        assert ctx.cost > 0

    def test_stream_wrapper_via_patched_create(self) -> None:
        """Verify that stream=True in the wrapper returns an _InstrumentedStream."""
        init(mode="observe")
        chunk = _mock_stream_chunk("hi", usage=_mock_usage(50, 25))
        mock_stream = iter([chunk])
        original = MagicMock(return_value=mock_stream)
        wrapper = _make_patched_create(original)

        with run(budget=1.0) as ctx:
            result = wrapper(MagicMock(), model="gpt-4o-mini", stream=True)
            assert isinstance(result, _InstrumentedStream)
            list(result)

        assert ctx.step_count == 1


# ---------------------------------------------------------------------------
# Async stream wrapper
# ---------------------------------------------------------------------------


class TestAsyncStreamWrapper:
    """_InstrumentedAsyncStream: async mirror of the sync stream wrapper."""

    @pytest.mark.asyncio
    async def test_async_stream_yields_all_chunks(self) -> None:
        init(mode="observe")
        chunk1 = _mock_stream_chunk("Hello")
        chunk2 = _mock_stream_chunk(" world", usage=_mock_usage(100, 50))

        async def _async_iter():
            yield chunk1
            yield chunk2

        mock_stream = _async_iter()

        async with run(budget=1.0) as ctx:
            wrapped = _InstrumentedAsyncStream(mock_stream, ctx, "gpt-4o-mini", time.monotonic())
            chunks = [c async for c in wrapped]

        assert len(chunks) == 2
        assert ctx.cost > 0
        assert ctx.step_count == 1

    @pytest.mark.asyncio
    async def test_async_stream_via_patched_create(self) -> None:
        """Verify that stream=True in async wrapper returns an _InstrumentedAsyncStream."""
        init(mode="observe")
        chunk = _mock_stream_chunk("hi", usage=_mock_usage(50, 25))

        async def _async_iter():
            yield chunk

        mock_stream = _async_iter()
        original = AsyncMock(return_value=mock_stream)
        wrapper = _make_patched_async_create(original)

        async with run(budget=1.0) as ctx:
            result = await wrapper(MagicMock(), model="gpt-4o-mini", stream=True)
            assert isinstance(result, _InstrumentedAsyncStream)
            _ = [c async for c in result]

        assert ctx.step_count == 1

    @pytest.mark.asyncio
    async def test_async_stream_finalizes_on_iteration_error(self) -> None:
        # Error mid-stream must not lose the usage already observed.
        init(mode="observe")
        chunk1 = _mock_stream_chunk("data", usage=_mock_usage(100, 50))

        async def _failing_iter():
            yield chunk1
            raise RuntimeError("async stream failed")

        async with run(budget=1.0) as ctx:
            wrapped = _InstrumentedAsyncStream(_failing_iter(), ctx, "gpt-4o-mini", time.monotonic())
            with pytest.raises(RuntimeError, match="async stream failed"):
                async for _ in wrapped:
                    pass

        assert ctx.step_count == 1
        assert ctx.cost > 0


# ---------------------------------------------------------------------------
# Cost and energy estimation
# ---------------------------------------------------------------------------


class TestEstimation:
    """Pure-function checks for _estimate_cost / _estimate_energy pricing tables."""

    def test_cost_known_model(self) -> None:
        # gpt-4o-mini pricing: $0.15/1M input + $0.60/1M output.
        cost = _estimate_cost("gpt-4o-mini", 1_000_000, 1_000_000)
        assert cost == pytest.approx(0.15 + 0.60)

    def test_cost_unknown_model_uses_default(self) -> None:
        cost = _estimate_cost("my-custom-model", 1_000_000, 1_000_000)
        # default pricing: $2.50/$10.00
        assert cost == pytest.approx(2.50 + 10.00)

    def test_cost_zero_tokens(self) -> None:
        cost = _estimate_cost("gpt-4o", 0, 0)
        assert cost == 0.0

    def test_energy_known_model(self) -> None:
        energy = _estimate_energy("gpt-4o-mini", 1000, 500)
        # coeff=0.3, output_weight=1.5
        # energy = 0.3 * (1000 + 500 * 1.5) = 0.3 * 1750 = 525.0
        assert energy == pytest.approx(525.0)

    def test_energy_unknown_model_uses_default(self) -> None:
        energy = _estimate_energy("custom-model", 1000, 500)
        # default coeff=1.0
        # energy = 1.0 * (1000 + 500 * 1.5) = 1750.0
        assert energy == pytest.approx(1750.0)


# ---------------------------------------------------------------------------
# Nested run isolation
# ---------------------------------------------------------------------------


class TestNestedRuns:
    """Nested run() scopes: the inner scope must not leak cost into the outer one."""

    def test_inner_run_does_not_affect_outer(self) -> None:
        init(mode="observe")
        mock_resp = _mock_completion(prompt_tokens=1_000_000, completion_tokens=1_000_000)
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(budget=10.0) as outer:
            wrapper(MagicMock(), model="gpt-4o-mini")  # $0.75 to outer
            outer_cost_before_inner = outer.cost

            with run(budget=5.0) as inner:
                wrapper(MagicMock(), model="gpt-4o-mini")  # $0.75 to inner

            # Outer cost should be unchanged after inner scope exits
            assert outer.cost == pytest.approx(outer_cost_before_inner)
            assert inner.cost == pytest.approx(0.75, abs=0.01)


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


class TestEdgeCases:
    """Defensive handling of responses missing usage/choices data."""

    def test_response_without_usage(self) -> None:
        # usage=None: step is counted but no cost can be attributed.
        init(mode="observe")
        mock_resp = MagicMock()
        mock_resp.usage = None
        mock_resp.choices = []
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(budget=1.0) as ctx:
            wrapper(MagicMock(), model="gpt-4o")

        assert ctx.cost == 0.0
        assert ctx.step_count == 1

    def test_response_without_choices(self) -> None:
        # Empty choices: no tool calls can be counted, but usage still costs.
        init(mode="observe")
        mock_resp = MagicMock()
        mock_resp.usage = _mock_usage(100, 50)
        mock_resp.choices = []
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(budget=1.0) as ctx:
            wrapper(MagicMock(), model="gpt-4o")

        assert ctx.tool_calls == 0
        assert ctx.cost > 0

    def test_stream_without_usage_in_any_chunk(self) -> None:
        init(mode="observe")
        chunk1 = _mock_stream_chunk("Hello")
        chunk2 = _mock_stream_chunk(" world")
        mock_stream = iter([chunk1, chunk2])

        with run(budget=1.0) as ctx:
            wrapped = _InstrumentedStream(mock_stream, ctx, "gpt-4o-mini", time.monotonic())
            list(wrapped)

        assert ctx.cost == 0.0  # No usage data available
        assert ctx.step_count == 1  # Step still counted


# ---------------------------------------------------------------------------
# Fix: init(mode="off") unpatches previously patched client
# ---------------------------------------------------------------------------


class TestInitOffUnpatches:
    """Regression tests: switching back to mode="off" removes the monkey-patch."""

    def test_init_off_after_observe_unpatches(self) -> None:
        init(mode="observe")
        assert is_patched()
        init(mode="off")
        assert not is_patched()

    def test_init_off_when_not_patched_is_safe(self) -> None:
        # Unpatching when nothing was patched must be a no-op, not an error.
        init(mode="off")
        assert not is_patched()


# ---------------------------------------------------------------------------
# Fix: enforce mode — budget gate and correct trace reason
# ---------------------------------------------------------------------------


class TestEnforceMode:
    """Enforce mode gates calls on budget; observe mode only records decisions."""

    def test_enforce_trace_records_enforce_reason(self) -> None:
        init(mode="enforce")
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(budget=10.0) as ctx:
            wrapper(MagicMock(), model="gpt-4o")

        trace = ctx.trace()
        assert trace[0]["reason"] == "enforce"

    def test_observe_trace_records_observe_reason(self) -> None:
        init(mode="observe")
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(budget=10.0) as ctx:
            wrapper(MagicMock(), model="gpt-4o")

        trace = ctx.trace()
        assert trace[0]["reason"] == "observe"

    def test_enforce_raises_on_budget_exhausted(self) -> None:
        from cascadeflow.schema.exceptions import BudgetExceededError

        init(mode="enforce")
        mock_resp = _mock_completion(prompt_tokens=1_000_000, completion_tokens=1_000_000)
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(budget=0.001) as ctx:
            # First call uses the tiny budget
            wrapper(MagicMock(), model="gpt-4o")
            # Second call should raise — budget exhausted
            with pytest.raises(BudgetExceededError):
                wrapper(MagicMock(), model="gpt-4o")

    def test_observe_does_not_raise_on_budget_exhausted(self) -> None:
        init(mode="observe")
        mock_resp = _mock_completion(prompt_tokens=1_000_000, completion_tokens=1_000_000)
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(budget=0.001) as ctx:
            wrapper(MagicMock(), model="gpt-4o")
            # Second call should NOT raise — observe mode is permissive
            wrapper(MagicMock(), model="gpt-4o")

        assert ctx.cost > ctx.budget_max  # type: ignore[operator]
        trace = ctx.trace()
        # Observe mode records the stop decision without applying it.
        assert trace[-1]["action"] == "stop"
        assert trace[-1]["reason"] == "budget_exceeded"
        assert trace[-1]["applied"] is False
        assert trace[-1]["decision_mode"] == "observe"

    @pytest.mark.asyncio
    async def test_enforce_raises_on_budget_exhausted_async(self) -> None:
        from cascadeflow.schema.exceptions import BudgetExceededError

        init(mode="enforce")
        mock_resp = _mock_completion(prompt_tokens=1_000_000, completion_tokens=1_000_000)
        original = AsyncMock(return_value=mock_resp)
        wrapper = _make_patched_async_create(original)

        async with run(budget=0.001) as ctx:
            await wrapper(MagicMock(), model="gpt-4o")
            with pytest.raises(BudgetExceededError):
                await wrapper(MagicMock(), model="gpt-4o")


# ---------------------------------------------------------------------------
# Enforce actions: switch_model, deny_tool, stop
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------


class TestEnforceActions:
    """Decision actions (switch_model / deny_tool / stop): enforce applies them,
    observe only records them in the trace with applied=False."""

    def test_enforce_switches_model_under_budget_pressure(self) -> None:
        init(mode="enforce")
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(budget=1.0) as ctx:
            # Simulate 85% of the budget already spent to trigger the downgrade.
            ctx.cost = 0.85
            ctx.budget_remaining = 0.15
            wrapper(MagicMock(), model="gpt-4o")

        # Under budget pressure, enforce rewrites the model to the cheaper one.
        assert original.call_args[1]["model"] == "gpt-4o-mini"
        trace = ctx.trace()
        assert trace[0]["action"] == "switch_model"
        assert trace[0]["reason"] == "budget_pressure"
        assert trace[0]["applied"] is True
        assert trace[0]["decision_mode"] == "enforce"

    def test_observe_computes_switch_model_but_does_not_apply(self) -> None:
        init(mode="observe")
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(budget=1.0) as ctx:
            ctx.cost = 0.85
            ctx.budget_remaining = 0.15
            wrapper(MagicMock(), model="gpt-4o")

        # Observe keeps the caller's model, but the trace shows what WOULD happen.
        assert original.call_args[1]["model"] == "gpt-4o"
        trace = ctx.trace()
        assert trace[0]["action"] == "switch_model"
        assert trace[0]["reason"] == "budget_pressure"
        assert trace[0]["model"] == "gpt-4o-mini"
        assert trace[0]["applied"] is False
        assert trace[0]["decision_mode"] == "observe"

    def test_enforce_denies_tools_when_cap_reached(self) -> None:
        init(mode="enforce", max_tool_calls=0)
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(max_tool_calls=0) as ctx:
            wrapper(
                MagicMock(),
                model="gpt-4o",
                tools=[{"type": "function", "function": {"name": "t1"}}],
            )

        # Enforce strips the tools list when the tool-call cap is reached.
        assert original.call_args[1]["tools"] == []
        trace = ctx.trace()
        assert trace[0]["action"] == "deny_tool"
        assert trace[0]["reason"] == "max_tool_calls_reached"
        assert trace[0]["applied"] is True
        assert trace[0]["decision_mode"] == "enforce"

    def test_observe_logs_deny_tool_but_keeps_tools(self) -> None:
        init(mode="observe", max_tool_calls=0)
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        tools = [{"type": "function", "function": {"name": "t1"}}]
        with run(max_tool_calls=0) as ctx:
            wrapper(MagicMock(), model="gpt-4o", tools=tools)

        assert original.call_args[1]["tools"] == tools
        trace = ctx.trace()
        assert trace[0]["action"] == "deny_tool"
        assert trace[0]["reason"] == "max_tool_calls_reached"
        assert trace[0]["applied"] is False
        assert trace[0]["decision_mode"] == "observe"

    def test_enforce_stops_when_latency_limit_exceeded_at_fastest_model(self) -> None:
        from cascadeflow.schema.exceptions import HarnessStopError

        init(mode="enforce")
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(max_latency_ms=1.0) as ctx:
            # Pre-load more latency than the limit allows.
            ctx.latency_used_ms = 5.0
            with pytest.raises(HarnessStopError, match="latency_limit_exceeded"):
                wrapper(MagicMock(), model="gpt-3.5-turbo")

        # The stop decision must block the provider call entirely.
        original.assert_not_called()
        trace = ctx.trace()
        assert trace[0]["action"] == "stop"
        assert trace[0]["reason"] == "latency_limit_exceeded"
        assert trace[0]["applied"] is True
        assert trace[0]["decision_mode"] == "enforce"

    def test_enforce_stops_when_energy_limit_exceeded_at_lowest_energy_model(self) -> None:
        from cascadeflow.schema.exceptions import HarnessStopError

        init(mode="enforce")
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(max_energy=1.0) as ctx:
            ctx.energy_used = 5.0
            with pytest.raises(HarnessStopError, match="energy_limit_exceeded"):
                wrapper(MagicMock(), model="gpt-3.5-turbo")

        original.assert_not_called()
        trace = ctx.trace()
        assert trace[0]["action"] == "stop"
        assert trace[0]["reason"] == "energy_limit_exceeded"
        assert trace[0]["applied"] is True
        assert trace[0]["decision_mode"] == "enforce"

    @pytest.mark.asyncio
    async def test_async_enforce_denies_tools_when_cap_reached(self) -> None:
        init(mode="enforce", max_tool_calls=0)
        mock_resp = _mock_completion()
        original = AsyncMock(return_value=mock_resp)
        wrapper = _make_patched_async_create(original)

        async with run(max_tool_calls=0) as ctx:
            await wrapper(
                MagicMock(),
                model="gpt-4o",
                tools=[{"type": "function", "function": {"name": "t1"}}],
            )

        assert original.call_args[1]["tools"] == []
        trace = ctx.trace()
        assert trace[0]["action"] == "deny_tool"
        assert trace[0]["reason"] == "max_tool_calls_reached"
        assert trace[0]["applied"] is True
        assert trace[0]["decision_mode"] == "enforce"

    def test_enforce_switches_model_for_compliance_policy(self) -> None:
        # compliance="strict" upgrades gpt-4o-mini to gpt-4o per policy.
        init(mode="enforce", compliance="strict")
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run() as ctx:
            wrapper(MagicMock(), model="gpt-4o-mini")

        assert original.call_args[1]["model"] == "gpt-4o"
        trace = ctx.trace()
        assert trace[0]["action"] == "switch_model"
        assert trace[0]["reason"] == "compliance_model_policy"
        assert trace[0]["applied"] is True
        assert trace[0]["decision_mode"] == "enforce"

    def test_enforce_denies_tool_for_strict_compliance(self) -> None:
        init(mode="enforce", compliance="strict")
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run() as ctx:
            wrapper(
                MagicMock(),
                model="gpt-4o",
                tools=[{"type": "function", "function": {"name": "t1"}}],
            )

        assert original.call_args[1]["tools"] == []
        trace = ctx.trace()
        assert trace[0]["action"] == "deny_tool"
        assert trace[0]["reason"] == "compliance_tool_restriction"
        assert trace[0]["applied"] is True
        assert trace[0]["decision_mode"] == "enforce"

    def test_observe_logs_compliance_switch_without_applying(self) -> None:
        init(mode="observe", compliance="strict")
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run() as ctx:
            wrapper(MagicMock(), model="gpt-4o-mini")

        assert original.call_args[1]["model"] == "gpt-4o-mini"
        trace = ctx.trace()
        assert trace[0]["action"] == "switch_model"
        assert trace[0]["reason"] == "compliance_model_policy"
        assert trace[0]["model"] == "gpt-4o"
        assert trace[0]["applied"] is False
        assert trace[0]["decision_mode"] == "observe"

    def test_enforce_switches_model_using_kpi_weights(self) -> None:
        # A pure quality weighting promotes to the highest-quality model ("o1").
        init(mode="enforce", kpi_weights={"quality": 1.0})
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run() as ctx:
            wrapper(MagicMock(), model="gpt-3.5-turbo")

        assert original.call_args[1]["model"] == "o1"
        trace = ctx.trace()
        assert trace[0]["action"] == "switch_model"
        assert trace[0]["reason"] == "kpi_weight_optimization"
        assert trace[0]["applied"] is True
        assert trace[0]["decision_mode"] == "enforce"

    def test_observe_logs_kpi_switch_without_applying(self) -> None:
        init(mode="observe", kpi_weights={"quality": 1.0})
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run() as ctx:
            wrapper(MagicMock(), model="gpt-3.5-turbo")

        assert original.call_args[1]["model"] == "gpt-3.5-turbo"
        trace = ctx.trace()
        assert trace[0]["action"] == "switch_model"
        assert trace[0]["reason"] == "kpi_weight_optimization"
        assert trace[0]["model"] == "o1"
        assert trace[0]["applied"] is False
        assert trace[0]["decision_mode"] == "observe"


# ---------------------------------------------------------------------------
# Fix: stream_options.include_usage auto-injection
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------


class TestStreamUsageInjection:
    """Streaming calls get stream_options={"include_usage": True} injected so
    the final chunk carries usage data; non-stream calls are left untouched."""

    def test_stream_injects_include_usage(self) -> None:
        init(mode="observe")
        mock_stream = iter([_mock_stream_chunk("hi", usage=_mock_usage(50, 25))])
        original = MagicMock(return_value=mock_stream)
        wrapper = _make_patched_create(original)

        with run(budget=1.0) as ctx:
            result = wrapper(MagicMock(), model="gpt-4o-mini", stream=True)
            list(result)

        # Check the original was called with stream_options injected
        call_kwargs = original.call_args[1]
        assert call_kwargs.get("stream_options", {}).get("include_usage") is True

    def test_stream_preserves_existing_stream_options(self) -> None:
        # A caller-supplied stream_options dict must survive the injection.
        init(mode="observe")
        mock_stream = iter([_mock_stream_chunk("hi", usage=_mock_usage(50, 25))])
        original = MagicMock(return_value=mock_stream)
        wrapper = _make_patched_create(original)

        with run(budget=1.0) as ctx:
            result = wrapper(
                MagicMock(),
                model="gpt-4o-mini",
                stream=True,
                stream_options={"include_usage": True},
            )
            list(result)

        call_kwargs = original.call_args[1]
        assert call_kwargs["stream_options"]["include_usage"] is True

    def test_non_stream_does_not_inject_stream_options(self) -> None:
        init(mode="observe")
        mock_resp = _mock_completion()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_create(original)

        with run(budget=1.0) as ctx:
            wrapper(MagicMock(), model="gpt-4o-mini")

        call_kwargs = original.call_args[1]
        assert "stream_options" not in call_kwargs


# ===========================================================================
# Anthropic instrumentation tests
# ===========================================================================


def _mock_anthropic_usage(
    input_tokens: Optional[int] = 100,
    output_tokens: Optional[int] = 50,
) -> MagicMock:
    """Build a mock Anthropic usage object (input_tokens/output_tokens attrs)."""
    u = MagicMock()
    u.input_tokens = input_tokens
    u.output_tokens = output_tokens
    return u


def _mock_anthropic_response(
    input_tokens: int = 100,
    output_tokens: int = 50,
    content: Optional[list] = None,
) -> MagicMock:
    """Build a mock Anthropic Message with usage and (optionally) content blocks."""
    resp = MagicMock()
    resp.usage = _mock_anthropic_usage(input_tokens, output_tokens)
    resp.content = content or []
    return resp


def _mock_tool_use_block() -> MagicMock:
    """Content block with type="tool_use" (counts as one tool call)."""
    block = MagicMock()
    block.type = "tool_use"
    return block


def _mock_text_block() -> MagicMock:
    """Content block with type="text" (never counts as a tool call)."""
    block = MagicMock()
    block.type = "text"
    return block


def _mock_anthropic_message_start_event(
    input_tokens: int = 100,
    output_tokens: int = 0,
) -> MagicMock:
    """Stream event carrying initial usage on event.message.usage."""
    event = MagicMock()
    event.type = "message_start"
    event.message = MagicMock()
    event.message.usage = _mock_anthropic_usage(input_tokens, output_tokens)
    return event


def _mock_anthropic_message_delta_event(
    output_tokens: int = 50,
) -> MagicMock:
    """Stream event carrying incremental output usage on event.usage."""
    event = MagicMock()
    event.type = "message_delta"
    # input_tokens=None: delta events only report output tokens.
    event.usage = _mock_anthropic_usage(None, output_tokens)
    return event


def _mock_anthropic_content_block_start_event(
    block_type: str = "tool_use",
) -> MagicMock:
    """Stream event announcing a new content block of the given type."""
    event = MagicMock()
    event.type = "content_block_start"
    event.content_block = MagicMock()
    event.content_block.type = block_type
    return event


def _mock_anthropic_message_stop_event() -> MagicMock:
    """Terminal stream event with no usage payload."""
    event = MagicMock()
    event.type = "message_stop"
    event.usage = None
    return event


# ---------------------------------------------------------------------------
# Anthropic usage extraction
# ---------------------------------------------------------------------------


class TestAnthropicUsageExtraction:
    """_extract_anthropic_usage returns (input_tokens, output_tokens), 0s on None."""

    def test_extract_usage(self) -> None:
        resp = _mock_anthropic_response(input_tokens=200, output_tokens=100)
        inp, out = _extract_anthropic_usage(resp)
        assert inp == 200
        assert out == 100

    def test_extract_usage_none(self) -> None:
        resp = MagicMock()
        resp.usage = None
        inp, out = _extract_anthropic_usage(resp)
        assert inp == 0
        assert out == 0


# ---------------------------------------------------------------------------
# Anthropic tool call counting
# ---------------------------------------------------------------------------


class TestAnthropicToolCallCounting:
    """_count_tool_calls_in_anthropic_response counts type=="tool_use" blocks."""

    def test_counts_tool_use_blocks(self) -> None:
        resp = _mock_anthropic_response(
            content=[_mock_text_block(), _mock_tool_use_block(), _mock_tool_use_block()]
        )
        assert _count_tool_calls_in_anthropic_response(resp) == 2

    def test_no_content(self) -> None:
        resp = MagicMock()
        resp.content = None
        assert _count_tool_calls_in_anthropic_response(resp) == 0

    def test_empty_content(self) -> None:
        resp = _mock_anthropic_response(content=[])
        assert _count_tool_calls_in_anthropic_response(resp) == 0

    def test_text_only(self) -> None:
        resp = _mock_anthropic_response(content=[_mock_text_block()])
        assert _count_tool_calls_in_anthropic_response(resp) == 0


# ---------------------------------------------------------------------------
# Anthropic sync wrapper
# ---------------------------------------------------------------------------


class TestAnthropicSyncWrapper:
    """Sync Anthropic messages.create wrapper: cost/step/tool/energy/latency tracking."""

    def test_observe_passes_through_response(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=1.0) as ctx:
            result = wrapper(MagicMock(), model="claude-sonnet-4")

        assert result is mock_resp
        original.assert_called_once()

    def test_observe_tracks_cost(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response(input_tokens=1_000_000, output_tokens=1_000_000)
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=100.0) as ctx:
            wrapper(MagicMock(), model="claude-sonnet-4")

        # claude-sonnet-4: $3.00/1M in + $15.00/1M out = $18.00
        assert ctx.cost == pytest.approx(18.0, abs=0.01)

    def test_observe_tracks_step_count(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=1.0) as ctx:
            wrapper(MagicMock(), model="claude-sonnet-4")
            wrapper(MagicMock(), model="claude-sonnet-4")

        assert ctx.step_count == 2

    def test_observe_tracks_tool_calls(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response(
            content=[_mock_tool_use_block(), _mock_tool_use_block()]
        )
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=1.0) as ctx:
            wrapper(MagicMock(), model="claude-sonnet-4")

        assert ctx.tool_calls == 2

    def test_observe_tracks_energy(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response(input_tokens=1000, output_tokens=500)
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=1.0) as ctx:
            wrapper(MagicMock(), model="claude-sonnet-4")

        # claude-sonnet-4 uses default coefficient=1.0, output_weight=1.5
        # energy = 1.0 * (1000 + 500 * 1.5) = 1750.0
        assert ctx.energy_used == pytest.approx(1750.0)

    def test_observe_tracks_latency(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=1.0) as ctx:
            wrapper(MagicMock(), model="claude-sonnet-4")

        assert ctx.latency_used_ms > 0

    def test_budget_remaining_decreases(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response(input_tokens=1_000_000, output_tokens=1_000_000)
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=100.0) as ctx:
            wrapper(MagicMock(), model="claude-sonnet-4")

        assert ctx.budget_remaining is not None
        assert ctx.budget_remaining == pytest.approx(100.0 - 18.0, abs=0.01)

    def test_trace_records_model_and_mode(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=1.0) as ctx:
            wrapper(MagicMock(), model="claude-sonnet-4")

        trace = ctx.trace()
        assert len(trace) == 1
        assert trace[0]["action"] == "allow"
        assert trace[0]["reason"] == "observe"
        assert trace[0]["model"] == "claude-sonnet-4"

    def test_off_mode_passthrough_no_tracking(self) -> None:
        init(mode="off")
        mock_resp = _mock_anthropic_response()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run() as ctx:
            result = wrapper(MagicMock(), model="claude-sonnet-4")

        assert result is mock_resp
        assert ctx.cost == 0.0
        assert ctx.step_count == 0

    def test_no_run_scope_returns_response(self) -> None:
        # No run() scope: response passes through untouched.
        init(mode="observe")
        mock_resp = _mock_anthropic_response()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        result = wrapper(MagicMock(), model="claude-sonnet-4")
        assert result is mock_resp

    def test_stream_tracks_usage_and_tool_calls(self) -> None:
        # Usage is split across message_start (input) and message_delta (output).
        init(mode="observe")
        mock_stream = iter(
            [
                _mock_anthropic_message_start_event(input_tokens=1_000_000),
                _mock_anthropic_content_block_start_event("tool_use"),
                _mock_anthropic_message_delta_event(output_tokens=1_000_000),
                _mock_anthropic_message_stop_event(),
            ]
        )
        original = MagicMock(return_value=mock_stream)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=1.0) as ctx:
            result = wrapper(MagicMock(), model="claude-sonnet-4", stream=True)
            assert isinstance(result, _InstrumentedAnthropicStream)
            list(result)

        assert ctx.cost == pytest.approx(18.0, abs=0.01)
        assert ctx.step_count == 1
        assert ctx.tool_calls == 1

    def test_stream_finalizes_on_iteration_error(self) -> None:
        # Usage already seen must be booked even if the stream later raises.
        init(mode="observe")

        class _FailingAnthropicStream:
            # Yields one message_start event, then raises.
            def __init__(self) -> None:
                self._done = False

            def __iter__(self):
                return self

            def __next__(self):
                if not self._done:
                    self._done = True
                    return _mock_anthropic_message_start_event(input_tokens=1_000_000)
                raise RuntimeError("anthropic stream failed")

        original = MagicMock(return_value=_FailingAnthropicStream())
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=1.0) as ctx:
            result = wrapper(MagicMock(), model="claude-sonnet-4", stream=True)
            assert isinstance(result, _InstrumentedAnthropicStream)
            with pytest.raises(RuntimeError, match="anthropic stream failed"):
                list(result)

        assert ctx.step_count == 1
        assert ctx.cost > 0

    def test_multiple_calls_accumulate(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response(input_tokens=1_000_000, output_tokens=1_000_000)
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=100.0) as ctx:
            wrapper(MagicMock(), model="claude-sonnet-4")
            wrapper(MagicMock(), model="claude-sonnet-4")

        assert ctx.cost == pytest.approx(36.0, abs=0.01)
        assert ctx.step_count == 2


# ---------------------------------------------------------------------------
# Anthropic async wrapper
# ---------------------------------------------------------------------------


class TestAnthropicAsyncWrapper:
    """Async Anthropic wrapper tests.

    NOTE(review): unlike TestAsyncWrapper above, these async tests carry no
    @pytest.mark.asyncio decorator — they will be skipped (or error) unless
    pytest-asyncio runs with asyncio_mode = "auto". Confirm the pytest config.
    """

    async def test_observe_passes_through_response(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response()
        original = AsyncMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_async_create(original)

        async with run(budget=1.0) as ctx:
            result = await wrapper(MagicMock(), model="claude-sonnet-4")

        assert result is mock_resp

    async def test_observe_tracks_cost(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response(input_tokens=1_000_000, output_tokens=1_000_000)
        original = AsyncMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_async_create(original)

        async with run(budget=100.0) as ctx:
            await wrapper(MagicMock(), model="claude-sonnet-4")

        assert ctx.cost == pytest.approx(18.0, abs=0.01)
        assert ctx.step_count == 1

    async def test_off_mode_passthrough(self) -> None:
        init(mode="off")
        mock_resp = _mock_anthropic_response()
        original = AsyncMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_async_create(original)

        async with run() as ctx:
            result = await wrapper(MagicMock(), model="claude-sonnet-4")

        assert result is mock_resp
        assert ctx.cost == 0.0

    async def test_stream_tracks_usage_and_tool_calls(self) -> None:
        init(mode="observe")

        async def _event_stream():
            yield _mock_anthropic_message_start_event(input_tokens=1_000_000)
            yield _mock_anthropic_content_block_start_event("tool_use")
            yield _mock_anthropic_message_delta_event(output_tokens=1_000_000)
            yield _mock_anthropic_message_stop_event()

        original = AsyncMock(return_value=_event_stream())
        wrapper = _make_patched_anthropic_async_create(original)

        async with run(budget=1.0) as ctx:
            result = await wrapper(MagicMock(), model="claude-sonnet-4", stream=True)
            assert isinstance(result, _InstrumentedAnthropicAsyncStream)
            async for _ in result:
                pass

        assert ctx.cost == pytest.approx(18.0, abs=0.01)
        assert ctx.step_count == 1
        assert ctx.tool_calls == 1

    async def test_stream_finalizes_on_iteration_error(self) -> None:
        init(mode="observe")

        async def _failing_event_stream():
            yield _mock_anthropic_message_start_event(input_tokens=1_000_000)
            raise RuntimeError("anthropic async stream failed")

        original = AsyncMock(return_value=_failing_event_stream())
        wrapper = _make_patched_anthropic_async_create(original)

        async with run(budget=1.0) as ctx:
            result = await wrapper(MagicMock(), model="claude-sonnet-4", stream=True)
            assert isinstance(result, _InstrumentedAnthropicAsyncStream)
            with pytest.raises(RuntimeError, match="anthropic async stream failed"):
                async for _ in result:
                    pass

        assert ctx.step_count == 1
        assert ctx.cost > 0


# ---------------------------------------------------------------------------
# Anthropic enforce mode
# ---------------------------------------------------------------------------


class TestAnthropicEnforceMode:
    """Enforce-mode budget gating for the Anthropic wrapper (mirrors OpenAI)."""

    def test_enforce_trace_records_enforce_reason(self) -> None:
        init(mode="enforce")
        mock_resp = _mock_anthropic_response()
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=100.0) as ctx:
            wrapper(MagicMock(), model="claude-sonnet-4")

        trace = ctx.trace()
        assert trace[0]["reason"] == "enforce"

    def test_enforce_raises_on_budget_exhausted(self) -> None:
        from cascadeflow.schema.exceptions import BudgetExceededError

        init(mode="enforce")
        mock_resp = _mock_anthropic_response(input_tokens=1_000_000, output_tokens=1_000_000)
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=0.001) as ctx:
            wrapper(MagicMock(), model="claude-sonnet-4")
            with pytest.raises(BudgetExceededError):
                wrapper(MagicMock(), model="claude-sonnet-4")

    def test_observe_does_not_raise_on_budget_exhausted(self) -> None:
        init(mode="observe")
        mock_resp = _mock_anthropic_response(input_tokens=1_000_000, output_tokens=1_000_000)
        original = MagicMock(return_value=mock_resp)
        wrapper = _make_patched_anthropic_create(original)

        with run(budget=0.001) as ctx:
            wrapper(MagicMock(), model="claude-sonnet-4")
            wrapper(MagicMock(), model="claude-sonnet-4")

        assert ctx.cost > ctx.budget_max

    # NOTE(review): async test below also lacks @pytest.mark.asyncio — see class note.
    async def test_async_enforce_raises_on_budget_exhausted(self) -> None:
        from cascadeflow.schema.exceptions import BudgetExceededError

        init(mode="enforce")
mock_resp = _mock_anthropic_response(input_tokens=1_000_000, output_tokens=1_000_000) + original = AsyncMock(return_value=mock_resp) + wrapper = _make_patched_anthropic_async_create(original) + + async with run(budget=0.001) as ctx: + await wrapper(MagicMock(), model="claude-sonnet-4") + with pytest.raises(BudgetExceededError): + await wrapper(MagicMock(), model="claude-sonnet-4") + + +# --------------------------------------------------------------------------- +# Anthropic init() integration +# --------------------------------------------------------------------------- + + +class TestAnthropicInitIntegration: + def test_init_observe_patches_anthropic(self) -> None: + if find_spec("anthropic") is None: + pytest.skip("anthropic package not available") + report = init(mode="observe") + assert "anthropic" in report.instrumented + assert is_anthropic_patched() + + def test_init_off_unpatches_anthropic(self) -> None: + if find_spec("anthropic") is None: + pytest.skip("anthropic package not available") + init(mode="observe") + assert is_anthropic_patched() + init(mode="off") + assert not is_anthropic_patched() + + def test_reset_unpatches_anthropic(self) -> None: + if find_spec("anthropic") is None: + pytest.skip("anthropic package not available") + init(mode="observe") + assert is_anthropic_patched() + reset() + assert not is_anthropic_patched() diff --git a/tests/test_harness_shared_pricing.py b/tests/test_harness_shared_pricing.py new file mode 100644 index 00000000..a26398f3 --- /dev/null +++ b/tests/test_harness_shared_pricing.py @@ -0,0 +1,70 @@ +"""Tests for shared harness pricing/energy profiles.""" + +from __future__ import annotations + +import pytest + +import cascadeflow.harness.instrument as instrument_mod +import cascadeflow.integrations.crewai as crewai_mod +import cascadeflow.integrations.openai_agents as openai_agents_mod +from cascadeflow.harness.pricing import ( + OPENAI_MODEL_POOL, + estimate_cost, + estimate_energy, + model_total_price, +) + + +def 
test_shared_estimate_cost_known_models() -> None: + assert estimate_cost("gpt-4o-mini", 1_000_000, 1_000_000) == pytest.approx(0.75) + assert estimate_cost("gpt-5", 1_000_000, 1_000_000) == pytest.approx(11.25) + assert estimate_cost("claude-sonnet-4", 1_000_000, 1_000_000) == pytest.approx(18.0) + + +def test_shared_estimate_energy_defaults_for_unknown() -> None: + # default coeff=1.0, output weight=1.5 + assert estimate_energy("unknown-model", 100, 100) == pytest.approx(250.0) + + +def test_openai_pool_is_openai_only() -> None: + assert "gpt-4o" in OPENAI_MODEL_POOL + assert "gpt-5" in OPENAI_MODEL_POOL + assert "claude-sonnet-4" not in OPENAI_MODEL_POOL + + +def test_integration_estimators_use_shared_profiles() -> None: + model = "gpt-5-mini" + input_tokens = 12_345 + output_tokens = 6_789 + + shared_cost = estimate_cost(model, input_tokens, output_tokens) + shared_energy = estimate_energy(model, input_tokens, output_tokens) + + assert instrument_mod._estimate_cost(model, input_tokens, output_tokens) == pytest.approx( + shared_cost + ) + assert crewai_mod._estimate_cost(model, input_tokens, output_tokens) == pytest.approx( + shared_cost + ) + assert openai_agents_mod._estimate_cost(model, input_tokens, output_tokens) == pytest.approx( + shared_cost + ) + + assert instrument_mod._estimate_energy(model, input_tokens, output_tokens) == pytest.approx( + shared_energy + ) + assert crewai_mod._estimate_energy(model, input_tokens, output_tokens) == pytest.approx( + shared_energy + ) + assert openai_agents_mod._estimate_energy(model, input_tokens, output_tokens) == pytest.approx( + shared_energy + ) + + +def test_openai_agents_total_price_uses_shared_profiles() -> None: + assert openai_agents_mod._total_model_price("gpt-5") == pytest.approx( + model_total_price("gpt-5") + ) + assert openai_agents_mod._total_model_price("gpt-4o-mini") == pytest.approx( + model_total_price("gpt-4o-mini") + ) diff --git a/tests/test_openai_agents_integration.py 
b/tests/test_openai_agents_integration.py new file mode 100644 index 00000000..b2644036 --- /dev/null +++ b/tests/test_openai_agents_integration.py @@ -0,0 +1,207 @@ +import pytest + +from cascadeflow.harness import init, reset, run +import cascadeflow.integrations.openai_agents as openai_agents_integration +from cascadeflow.integrations.openai_agents import ( + CascadeFlowModelProvider, + OpenAIAgentsIntegrationConfig, +) +from cascadeflow.schema.exceptions import BudgetExceededError + + +def setup_function() -> None: + reset() + + +def test_requires_sdk_for_default_provider(monkeypatch): + monkeypatch.setattr(openai_agents_integration, "OPENAI_AGENTS_SDK_AVAILABLE", False) + with pytest.raises(ImportError): + CascadeFlowModelProvider() + + +class _FakeUsage: + def __init__(self, input_tokens: int, output_tokens: int) -> None: + self.input_tokens = input_tokens + self.output_tokens = output_tokens + + +class _FakeResponse: + def __init__(self, input_tokens: int = 0, output_tokens: int = 0, output=None) -> None: + self.usage = _FakeUsage(input_tokens=input_tokens, output_tokens=output_tokens) + self.output = output or [] + + +class _FakeEvent: + def __init__(self, response=None) -> None: + self.response = response + + +class _FakeAsyncStream: + def __init__(self, events) -> None: + self._events = list(events) + self._index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self._index >= len(self._events): + raise StopAsyncIteration + event = self._events[self._index] + self._index += 1 + return event + + +class _FakeModel: + def __init__(self, response: _FakeResponse, stream_events=None) -> None: + self._response = response + self._stream_events = stream_events or [] + self.last_kwargs = None + + async def get_response(self, **kwargs): + self.last_kwargs = kwargs + return self._response + + def stream_response(self, **kwargs): + self.last_kwargs = kwargs + return _FakeAsyncStream(self._stream_events) + + +class _FakeBaseProvider: + def 
__init__(self, model: _FakeModel) -> None: + self._model = model + self.requested_models = [] + + def get_model(self, model_name): + self.requested_models.append(model_name) + return self._model + + +def _response_call_kwargs(): + return { + "system_instructions": None, + "input": "hello", + "model_settings": None, + "tools": [], + "output_schema": None, + "handoffs": [], + "tracing": None, + "previous_response_id": None, + "conversation_id": None, + "prompt": None, + } + + +@pytest.mark.asyncio +async def test_metrics_updated_from_get_response(): + init(mode="observe", budget=2.0) + + output = [{"type": "function_call", "name": "lookup"}] + response = _FakeResponse(input_tokens=200, output_tokens=100, output=output) + model = _FakeModel(response=response) + provider = CascadeFlowModelProvider(base_provider=_FakeBaseProvider(model)) + + wrapped = provider.get_model("gpt-4o") + + with run(budget=2.0) as ctx: + await wrapped.get_response(**_response_call_kwargs()) + assert model.last_kwargs is not None + assert model.last_kwargs["input"] == "hello" + assert ctx.step_count == 1 + assert ctx.tool_calls == 1 + assert ctx.cost > 0 + assert ctx.energy_used > 0 + assert ctx.budget_remaining is not None + assert ctx.budget_remaining < 2.0 + assert ctx.model_used == "gpt-4o" + + +@pytest.mark.asyncio +async def test_tool_gating_enforced_when_limit_reached(): + init(mode="enforce", max_tool_calls=0, budget=1.0) + + response = _FakeResponse(input_tokens=10, output_tokens=5) + model = _FakeModel(response=response) + provider = CascadeFlowModelProvider(base_provider=_FakeBaseProvider(model)) + wrapped = provider.get_model("gpt-4o-mini") + + kwargs = _response_call_kwargs() + kwargs["tools"] = [{"name": "lookup"}] + + with run(max_tool_calls=0, budget=1.0) as ctx: + await wrapped.get_response(**kwargs) + assert model.last_kwargs is not None + assert model.last_kwargs["tools"] == [] + assert ctx.last_action == "deny_tool" + + +def 
test_switches_to_cheapest_candidate_under_budget_pressure(): + init(mode="enforce", budget=1.0) + + response = _FakeResponse() + model = _FakeModel(response=response) + base_provider = _FakeBaseProvider(model) + config = OpenAIAgentsIntegrationConfig(model_candidates=["gpt-4o", "gpt-4o-mini"]) + provider = CascadeFlowModelProvider(base_provider=base_provider, config=config) + + with run(budget=1.0) as ctx: + ctx.cost = 0.9 + ctx.budget_remaining = 0.1 + provider.get_model("gpt-4o") + assert base_provider.requested_models[-1] == "gpt-4o-mini" + assert ctx.last_action == "switch_model" + + +def test_budget_exceeded_raises_cascadeflow_budget_error(): + init(mode="enforce", budget=1.0) + + response = _FakeResponse() + model = _FakeModel(response=response) + provider = CascadeFlowModelProvider(base_provider=_FakeBaseProvider(model)) + + with run(budget=1.0) as ctx: + ctx.budget_remaining = 0.0 + with pytest.raises(BudgetExceededError): + provider.get_model("gpt-4o-mini") + + +def test_fail_open_falls_back_when_model_resolution_errors(monkeypatch): + response = _FakeResponse() + model = _FakeModel(response=response) + base_provider = _FakeBaseProvider(model) + provider = CascadeFlowModelProvider(base_provider=base_provider) + + def _boom(_: object) -> str: + raise ValueError("resolution failed") + + monkeypatch.setattr(provider, "_resolve_model", _boom) + wrapped = provider.get_model("gpt-4o") + + assert wrapped is not None + assert base_provider.requested_models[-1] == "gpt-4o" + + +@pytest.mark.asyncio +async def test_stream_response_updates_metrics(): + init(mode="observe", budget=3.0) + + final_response = _FakeResponse( + input_tokens=120, + output_tokens=60, + output=[{"type": "function_call", "name": "tool_a"}], + ) + stream_events = [_FakeEvent(response=final_response)] + model = _FakeModel(response=final_response, stream_events=stream_events) + provider = CascadeFlowModelProvider(base_provider=_FakeBaseProvider(model)) + wrapped = 
provider.get_model("gpt-4o-mini") + + with run(budget=3.0) as ctx: + async for _ in wrapped.stream_response(**_response_call_kwargs()): + pass + + assert model.last_kwargs is not None + assert model.last_kwargs["input"] == "hello" + assert ctx.step_count == 1 + assert ctx.tool_calls == 1 + assert ctx.cost > 0 + assert ctx.model_used == "gpt-4o-mini"