Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions ai_council/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .utils.config import AICouncilConfig, load_config
from .utils.logging import configure_logging, get_logger
from .factory import AICouncilFactory
from .sanitization import SanitizationFilter


class AICouncil:
Expand Down Expand Up @@ -66,7 +67,17 @@ def __init__(self, config_path: Optional[Path] = None):

# Initialize orchestration layer
self.orchestration_layer: OrchestrationLayer = self.factory.create_orchestration_layer()


# Initialize sanitization filter (runs before prompt construction)
sanitization_config = (
config_path.parent / "sanitization_filters.yaml"
if config_path is not None
else None
)
self.sanitization_filter: SanitizationFilter = SanitizationFilter.from_config(
config_path=sanitization_config
)

self.logger.info("AI Council application initialized successfully")

async def _execute_with_timeout(
Expand Down Expand Up @@ -114,23 +125,54 @@ async def _execute_with_timeout(
)

async def process_request(
self,
user_input: str,
execution_mode: ExecutionMode = ExecutionMode.BALANCED
self,
user_input: str,
execution_mode: ExecutionMode = ExecutionMode.BALANCED,
*,
session_id: str = "anonymous",
) -> FinalResponse:
"""
Process a user request through the AI Council system.


The Sanitization Filter runs FIRST, before any prompt construction
or orchestration. Injection attempts are rejected immediately.

Args:
user_input: The user's request as a string
user_input: The user's request as a string
execution_mode: The execution mode to use (fast, balanced, best_quality)

session_id: Per-session key used for rate-limit tracking.

Returns:
FinalResponse: The final processed response
"""
self.logger.info("Processing request in", extra={"value": execution_mode.value})
self.logger.debug("User input", extra={"user_input": user_input[:200]})


# ── Stage 0: Sanitization Filter ─────────────────────────────────
filter_result = self.sanitization_filter.check(
user_input, source_key=session_id
)
if not filter_result.is_safe:
self.logger.warning(
"Request blocked by SanitizationFilter",
extra={
"session_id": session_id,
"filter": filter_result.filter_name,
"severity": filter_result.severity.value if filter_result.severity else None,
"rule": filter_result.triggered_rule,
},
)
return FinalResponse(
content="",
overall_confidence=0.0,
success=False,
error_message=(
"Unsafe input detected. Request blocked due to potential prompt injection."
),
error_type="prompt_injection",
)
# ─────────────────────────────────────────────────────────────────

return await self._execute_with_timeout(user_input, execution_mode)

async def estimate_cost(self, user_input: str, execution_mode: ExecutionMode = ExecutionMode.BALANCED) -> Dict[str, Any]:
Expand Down
60 changes: 60 additions & 0 deletions ai_council/query_pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
Cost-Optimized Query Processing System for AI Council.

Pipeline (left-to-right):

User Input
→ QueryCache (short-circuit on cache hit)
→ EmbeddingEngine (dense vector representation)
→ VectorStore (top-k nearest-neighbour search)
→ TopicClassifier (topic label + context chunks)
→ SmartQueryDecomposer (sub-queries + dependency graph)
→ ModelRouter (cheap / mid / expensive tier)
→ TokenOptimizer (prompt compression + RAG cherry-pick)
→ Execution (parallel, via existing orchestration)
→ ResponseAggregator (merge + CostReport)
→ QueryCache.store()
→ PipelineResult

Public API::

from ai_council.query_pipeline import QueryPipeline, PipelineConfig

pipeline = QueryPipeline.build()
result = await pipeline.process("Explain quicksort and give Python code")
print(result.cost_report)
"""

from .config import PipelineConfig
from .embeddings import EmbeddingEngine
from .vector_store import VectorStore, SearchResult
from .topic_classifier import TopicClassifier, ClassificationResult
from .query_decomposer import SmartQueryDecomposer, DecompositionResult, SubQuery
from .model_router import ModelRouter, RoutingDecision, ModelTier
from .token_optimizer import TokenOptimizer, OptimizedPrompt
from .cache import QueryCache, CacheStats
from .pipeline import QueryPipeline, PipelineResult, CostReport

# Explicit public API of the package: mirrors the imports above.  Anything
# not listed here is considered internal and may change without notice.
__all__ = [
    # top-level pipeline
    "QueryPipeline",
    "PipelineResult",
    "CostReport",
    "PipelineConfig",
    # individual components (composable)
    "EmbeddingEngine",
    "VectorStore",
    "SearchResult",
    "TopicClassifier",
    "ClassificationResult",
    "SmartQueryDecomposer",
    "DecompositionResult",
    "SubQuery",
    "ModelRouter",
    "RoutingDecision",
    "ModelTier",
    "TokenOptimizer",
    "OptimizedPrompt",
    "QueryCache",
    "CacheStats",
]
212 changes: 212 additions & 0 deletions ai_council/query_pipeline/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
"""QueryCache — two-level LRU cache for query results.

Level 1: In-memory ``OrderedDict`` LRU (always available).
Level 2: ``diskcache`` persistence (optional, activated when installed).

Cache keys are SHA-256 hashes of the normalised query text, so the cache
is resilient to minor whitespace/punctuation variations.
"""

from __future__ import annotations

import hashlib
import logging
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


def _normalise(query: str) -> str:
"""Normalise a query for cache key generation."""
return " ".join(query.lower().split())


def _make_key(query: str) -> str:
return hashlib.sha256(_normalise(query).encode()).hexdigest()


# ─────────────────────────────────────────────────────────────────────────────
# Data classes
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class CachedResponse:
    """One cached query result together with the metadata needed for expiry."""

    query_key: str   # SHA-256 key the entry is stored under
    result: Any      # cached payload; opaque to the cache itself
    stored_at: float = field(default_factory=time.time)  # wall-clock store time
    ttl_seconds: int = 3600  # lifetime before the entry goes stale
    hit_count: int = 0       # memory-level hits served so far

    def is_expired(self) -> bool:
        """Return ``True`` once the entry's age strictly exceeds its TTL."""
        age = time.time() - self.stored_at
        return age > self.ttl_seconds


@dataclass
class CacheStats:
    """Running counters for a :class:`QueryCache`.

    Attributes:
        hits: Lookups answered from memory or disk.
        misses: Lookups that found no live (unexpired) entry.
        evictions: Entries dropped through LRU pressure or TTL expiry.
        size: Current in-memory entry count (refreshed by ``QueryCache.stats``).
    """

    hits: int = 0
    misses: int = 0
    evictions: int = 0
    size: int = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups that hit; 0.0 before any lookup has occurred."""
        total = self.hits + self.misses
        return self.hits / total if total else 0.0

    @property
    def miss_rate(self) -> float:
        """Fraction of lookups that missed; 0.0 before any lookup has occurred.

        Fix: previously computed as ``1.0 - hit_rate`` unconditionally, which
        reported a 100% miss rate for a cache that had never been queried.
        Now derived from the counters directly, mirroring ``hit_rate``.
        """
        total = self.hits + self.misses
        return self.misses / total if total else 0.0


# ─────────────────────────────────────────────────────────────────────────────
# QueryCache
# ─────────────────────────────────────────────────────────────────────────────

class QueryCache:
    """Two-level LRU query cache.

    Level 1 is an in-memory ``OrderedDict`` used as an LRU (always on).
    Level 2 is an optional ``diskcache`` store that survives restarts.
    Keys are SHA-256 hashes of the normalised query text (see ``_make_key``).

    Args:
        max_memory_entries: Maximum entries in the in-memory LRU.
        ttl_seconds: Default time-to-live for cached entries.
        persist: Enable diskcache persistence (requires ``diskcache``).
        persist_path: Path for the diskcache directory.

    Example::

        cache = QueryCache(max_memory_entries=256, ttl_seconds=3600)
        cache.store("What is quicksort?", {"answer": "..."})
        hit = cache.lookup("What is quicksort?")
        assert hit is not None
    """

    def __init__(
        self,
        max_memory_entries: int = 512,
        ttl_seconds: int = 3600,
        persist: bool = False,
        persist_path: str = "~/.ai_council/cache/query_pipeline",
    ):
        self._max = max_memory_entries  # L1 LRU capacity (entry count)
        self._ttl = ttl_seconds  # default TTL applied by store()/promotion
        self._mem: OrderedDict[str, CachedResponse] = OrderedDict()  # L1 LRU
        self._stats = CacheStats()
        self._disk: Optional[Any] = None  # L2 handle (diskcache.Cache) or None

        if persist:
            # May still end up None if diskcache is missing or init fails.
            self._disk = self._init_disk(persist_path)

    # ── Disk cache init ───────────────────────────────────────────────────────

    @staticmethod
    def _init_disk(path: str) -> Optional[Any]:
        """Open a ``diskcache.Cache`` at *path*, or return ``None`` on failure.

        Any failure (missing dependency, unwritable directory, ...) is
        deliberately swallowed so the cache degrades to memory-only mode
        instead of breaking the pipeline.
        """
        try:
            import diskcache  # type: ignore
            import os
            resolved = os.path.expanduser(path)
            # 256 MiB size cap keeps the on-disk footprint bounded.
            dc = diskcache.Cache(resolved, size_limit=256 * 1024 * 1024)
            logger.info("QueryCache: diskcache persisted to '%s'.", resolved)
            return dc
        except ImportError:
            logger.warning("QueryCache: diskcache not installed; memory-only mode.")
            return None
        except Exception as exc:
            logger.warning("QueryCache: failed to init diskcache (%s); memory-only mode.", exc)
            return None

    # ── Public API ────────────────────────────────────────────────────────────

    def lookup(self, query: str) -> Optional[Any]:
        """Return the cached result for *query*, or ``None`` on a miss/expiry."""
        key = _make_key(query)

        # Level 1: memory
        if key in self._mem:
            entry = self._mem[key]
            if entry.is_expired():
                # TTL elapsed: drop the stale entry (counted as an eviction)
                # and fall through to the disk tier below.
                del self._mem[key]
                self._stats.evictions += 1
            else:
                self._mem.move_to_end(key)  # mark most-recently-used
                entry.hit_count += 1
                self._stats.hits += 1
                logger.debug("QueryCache HIT (memory) for key=%s...", key[:12])
                return entry.result

        # Level 2: disk
        if self._disk is not None:
            try:
                data = self._disk.get(key)
                if data is not None:
                    # Promote to memory
                    # NOTE(review): promotion restarts the clock with the
                    # default TTL, not the entry's remaining disk TTL — confirm
                    # this is intended.
                    self._mem_store(key, data, self._ttl)
                    self._stats.hits += 1
                    logger.debug("QueryCache HIT (disk) for key=%s...", key[:12])
                    return data
            except Exception as exc:
                # Disk-tier errors degrade to a miss rather than propagating.
                logger.warning("QueryCache disk lookup failed: %s", exc)

        self._stats.misses += 1
        return None

    def store(self, query: str, result: Any, ttl: Optional[int] = None) -> None:
        """Cache *result* under *query* for *ttl* seconds.

        Writes to both tiers; falls back to the default TTL when *ttl* is
        ``None``. Disk write failures are logged and otherwise ignored.
        """
        key = _make_key(query)
        effective_ttl = ttl if ttl is not None else self._ttl

        self._mem_store(key, result, effective_ttl)

        if self._disk is not None:
            try:
                self._disk.set(key, result, expire=effective_ttl)
            except Exception as exc:
                logger.warning("QueryCache disk store failed: %s", exc)

        logger.debug("QueryCache stored key=%s... (ttl=%ds)", key[:12], effective_ttl)

    def invalidate(self, query: str) -> bool:
        """Remove a single entry. Returns True if it existed."""
        key = _make_key(query)
        found = False
        if key in self._mem:
            del self._mem[key]
            found = True
        if self._disk is not None:
            try:
                # delete() returns truthy when the key existed on disk.
                found = self._disk.delete(key) or found
            except Exception:
                # Best-effort: a disk error must not fail invalidation.
                pass
        return found

    def clear(self) -> None:
        """Clear all cached entries (memory + disk)."""
        self._mem.clear()
        if self._disk is not None:
            try:
                self._disk.clear()
            except Exception:
                # Best-effort: memory tier is already cleared.
                pass
        logger.info("QueryCache cleared.")

    def stats(self) -> CacheStats:
        """Return running counters, with ``size`` refreshed to the L1 count."""
        self._stats.size = len(self._mem)
        return self._stats

    # ── Internals ─────────────────────────────────────────────────────────────

    def _mem_store(self, key: str, result: Any, ttl: int) -> None:
        """Insert or overwrite *key* in the L1 LRU, evicting the oldest if full."""
        if key in self._mem:
            self._mem.move_to_end(key)
        else:
            if len(self._mem) >= self._max:
                # Evict LRU
                evicted_key = next(iter(self._mem))
                del self._mem[evicted_key]
                self._stats.evictions += 1
        # Overwrite unconditionally so re-stores refresh both value and TTL.
        self._mem[key] = CachedResponse(
            query_key=key,
            result=result,
            ttl_seconds=ttl,
        )
Loading
Loading