diff --git a/autobot-backend/knowledge/memory_graph/__init__.py b/autobot-backend/knowledge/memory_graph/__init__.py new file mode 100644 index 000000000..e90f29953 --- /dev/null +++ b/autobot-backend/knowledge/memory_graph/__init__.py @@ -0,0 +1,62 @@ +# Copyright (c) mrveiss. All rights reserved. +""" +knowledge.memory_graph — Redis DB 1 (knowledge) memory graph foundation layer. + +Provides: +- Semantic search query processor and hybrid scorer (query_processor.py, hybrid_scorer.py) +- Schema constants and index creation (schema.py) +- Entity / relation CRUD and BFS traversal (graph_store.py) + +Public API re-exported from this package so callers can write: + from knowledge.memory_graph import MemoryGraphQueryProcessor, create_entity +""" + +from .hybrid_scorer import HybridScorer, SearchResult # noqa: F401 +from .query_processor import MemoryGraphQueryProcessor, QueryIntent # noqa: F401 + +# Graph store symbols — added by PR #3608 / issue-3385 +try: + from .graph_store import ( # noqa: F401 + create_entity, + create_relation, + get_entity, + get_incoming_relations, + get_outgoing_relations, + traverse_relations, + ) + from .schema import ( # noqa: F401 + ENTITY_KEY_PREFIX, + ENTITY_TYPES, + FULLTEXT_INDEX_NAME, + PRIMARY_INDEX_NAME, + RELATION_TYPES, + RELATIONS_IN_PREFIX, + RELATIONS_OUT_PREFIX, + ensure_indexes, + ) +except ImportError: + pass # graph store not yet merged; safe to skip + +__all__ = [ + # query processor / hybrid scorer + "HybridScorer", + "MemoryGraphQueryProcessor", + "QueryIntent", + "SearchResult", + # schema + "ENTITY_KEY_PREFIX", + "ENTITY_TYPES", + "FULLTEXT_INDEX_NAME", + "PRIMARY_INDEX_NAME", + "RELATION_TYPES", + "RELATIONS_IN_PREFIX", + "RELATIONS_OUT_PREFIX", + "ensure_indexes", + # graph_store + "create_entity", + "create_relation", + "get_entity", + "get_incoming_relations", + "get_outgoing_relations", + "traverse_relations", +] diff --git a/autobot-backend/knowledge/memory_graph/hybrid_scorer.py 
b/autobot-backend/knowledge/memory_graph/hybrid_scorer.py new file mode 100644 index 000000000..f0677988a --- /dev/null +++ b/autobot-backend/knowledge/memory_graph/hybrid_scorer.py @@ -0,0 +1,335 @@ +# Copyright (c) mrveiss. All rights reserved. +# AutoBot - AI-Powered Automation Platform +""" +Memory Graph Hybrid Scorer — Phase 2. + +Issue #3384: Hybrid scoring combining semantic similarity (cosine) and +keyword relevance (BM25-style TF-IDF) for memory-graph entity search. + +Scoring formula: + score = SEMANTIC_WEIGHT * cosine_similarity(q_embed, entity_embed) + + KEYWORD_WEIGHT * bm25_score(keywords, entity_text) + +Both components are normalised to [0.0, 1.0] before combination so that +neither dominates purely because of scale differences. + +Entity embeddings are retrieved from Redis (key ``mg:entity_embed:``) +where they are stored by the indexer that accompanies Phase 1 of issue #3385. +When an embedding is absent the semantic score defaults to 0.0 and only the +keyword score contributes. 
+""" + +import json +import logging +import math +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Scoring weights +# --------------------------------------------------------------------------- + +SEMANTIC_WEIGHT: float = 0.6 +KEYWORD_WEIGHT: float = 0.4 + +# BM25 parameters +_BM25_K1: float = 1.2 +_BM25_B: float = 0.75 + +# Redis key prefix where entity embeddings are stored by the indexer +_ENTITY_EMBED_KEY_PREFIX = "mg:entity_embed:" + + +# --------------------------------------------------------------------------- +# Data structure +# --------------------------------------------------------------------------- + + +@dataclass +class SearchResult: + """A ranked result from the memory-graph hybrid search.""" + + entity: Dict[str, Any] + score: float # combined hybrid score [0.0, 1.0] + semantic_score: float # cosine similarity [0.0, 1.0] + keyword_score: float # BM25 keyword relevance [0.0, 1.0] + matched_keywords: List[str] = field(default_factory=list) + explanation: str = "" + + +# --------------------------------------------------------------------------- +# HybridScorer +# --------------------------------------------------------------------------- + + +class HybridScorer: + """ + Combines cosine-similarity (semantic) and BM25 (keyword) scores. + + Issue #3384 — Phase 2. + + Designed to be instantiated once and reused across requests. + All methods are synchronous except ``score_and_rank`` which fetches + entity embeddings from Redis asynchronously. + """ + + def __init__(self, redis_client=None) -> None: + """ + Args: + redis_client: Optional async Redis client for embedding lookups. + When None a client is acquired lazily on first use. 
+ """ + self._redis = redis_client + + async def _get_redis(self): + """Lazily obtain the async Redis client (cached on self._redis).""" + if self._redis is None: + from autobot_shared.redis_client import get_redis_client + self._redis = await get_redis_client(async_client=True, database="knowledge") + return self._redis + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def score_and_rank( + self, + query: str, + intent: Any, # QueryIntent — avoid circular import; duck-typed + candidates: List[Dict[str, Any]], + query_embedding: Optional[List[float]], + limit: int = 10, + ) -> List["SearchResult"]: + """ + Score all candidates and return the top *limit* results. + + Args: + query: Original user query (for explanation text). + intent: QueryIntent produced by the query processor. + candidates: Entity dicts returned by Redis FT.SEARCH. + query_embedding: Pre-computed query embedding, or None. + limit: Maximum results to return. + + Returns: + Sorted list of SearchResult (highest score first). 
+ """ + keywords = getattr(intent, "keywords", []) + results: List[SearchResult] = [] + redis = await self._get_redis() + + for entity in candidates: + entity_text = _entity_to_text(entity) + entity_id = entity.get("id", "") + + # Semantic score + entity_embedding = await _fetch_entity_embedding(entity_id, redis) + sem_score = ( + cosine_similarity(query_embedding, entity_embedding) + if query_embedding and entity_embedding + else 0.0 + ) + + # Keyword (BM25) score + kw_score, matched = self.bm25_score(keywords, entity_text) + + # Combined score + combined = SEMANTIC_WEIGHT * sem_score + KEYWORD_WEIGHT * kw_score + + results.append( + SearchResult( + entity=entity, + score=round(combined, 4), + semantic_score=round(sem_score, 4), + keyword_score=round(kw_score, 4), + matched_keywords=matched, + explanation=_build_explanation( + sem_score, kw_score, combined, matched + ), + ) + ) + + results.sort(key=lambda r: r.score, reverse=True) + return results[:limit] + + # ------------------------------------------------------------------ + # Cosine similarity + # ------------------------------------------------------------------ + + @staticmethod + def cosine_similarity( + vec_a: List[float], vec_b: List[float] + ) -> float: + """Return cosine similarity in [0.0, 1.0] between two vectors.""" + return cosine_similarity(vec_a, vec_b) + + # ------------------------------------------------------------------ + # BM25 keyword scoring + # ------------------------------------------------------------------ + + def bm25_score( + self, + keywords: List[str], + document: str, + avg_doc_length: float = 150.0, + ) -> tuple: # (float, List[str]) + """ + Compute a normalised BM25 score for *keywords* against *document*. + + The score is normalised to [0.0, 1.0] by dividing by the maximum + possible per-term score (k1 + 1) times the number of keywords, + so results are directly comparable across documents of varying length. 
+ + Args: + keywords: List of query term strings (already lower-cased). + document: Full entity text to score against. + avg_doc_length: Average document length in tokens used for BM25 + length normalisation (default: 150 tokens). + + Returns: + Tuple of (normalised_score, matched_keyword_list). + """ + if not keywords or not document: + return 0.0, [] + + doc_lower = document.lower() + doc_tokens = doc_lower.split() + doc_length = len(doc_tokens) + + raw_score = 0.0 + matched: List[str] = [] + + for term in keywords: + tf = doc_tokens.count(term) + if tf == 0: + continue + + matched.append(term) + # BM25 TF component (IDF set to 1 — single-document context) + tf_component = (tf * (_BM25_K1 + 1)) / ( + tf + _BM25_K1 * (1 - _BM25_B + _BM25_B * doc_length / avg_doc_length) + ) + raw_score += tf_component + + if not matched: + return 0.0, [] + + # Normalise: max per-term score is k1+1 when tf dominates + max_possible = (_BM25_K1 + 1) * len(keywords) + normalised = min(raw_score / max_possible, 1.0) + return round(normalised, 4), matched + + +# --------------------------------------------------------------------------- +# Module-level helpers (exported for direct use / testing) +# --------------------------------------------------------------------------- + + +def cosine_similarity( + vec_a: Optional[List[float]], vec_b: Optional[List[float]] +) -> float: + """ + Compute cosine similarity between two float vectors. + + Returns a value in [0.0, 1.0]. Returns 0.0 when either vector is + None, empty, or has zero magnitude. 
+ """ + if not vec_a or not vec_b: + return 0.0 + + if len(vec_a) != len(vec_b): + logger.warning( + "cosine_similarity: dimension mismatch %d vs %d", + len(vec_a), + len(vec_b), + ) + return 0.0 + + dot = sum(a * b for a, b in zip(vec_a, vec_b)) + mag_a = math.sqrt(sum(a * a for a in vec_a)) + mag_b = math.sqrt(sum(b * b for b in vec_b)) + + if mag_a == 0.0 or mag_b == 0.0: + return 0.0 + + # Clamp to [0, 1] — cosine can be slightly outside due to float rounding + return max(0.0, min(1.0, dot / (mag_a * mag_b))) + + +def _entity_to_text(entity: Dict[str, Any]) -> str: + """ + Build a weighted text representation of an entity for BM25 scoring. + + Applies the same weights used for embedding generation: + name × 3 (weight 0.3) + type × 1 (weight 0.1) + observations × 6 (weight 0.6) + """ + name = entity.get("name", "") + entity_type = entity.get("type", "") + observations = entity.get("observations", []) + + if isinstance(observations, str): + # May arrive as a JSON string when parsed from Redis hash fields + try: + observations = json.loads(observations) + except (ValueError, json.JSONDecodeError): + observations = [observations] + + obs_text = " ".join(str(o) for o in observations) + + parts = [ + f"{name} " * 3, + entity_type, + f"{obs_text} " * 6, + ] + return " ".join(p.strip() for p in parts if p.strip()) + + +def _build_explanation( + sem_score: float, + kw_score: float, + combined: float, + matched_keywords: List[str], +) -> str: + """Produce a human-readable explanation of the hybrid score.""" + kw_part = ( + f"keywords matched: {', '.join(matched_keywords)}" + if matched_keywords + else "no keyword matches" + ) + return ( + f"score={combined:.2f} " + f"(semantic={sem_score:.2f} × {SEMANTIC_WEIGHT}, " + f"keyword={kw_score:.2f} × {KEYWORD_WEIGHT}); " + f"{kw_part}" + ) + + +async def _fetch_entity_embedding( + entity_id: str, + redis_client: Any, +) -> Optional[List[float]]: + """ + Retrieve a pre-computed entity embedding from Redis. 
+ + Returns None if the embedding has not yet been indexed so that the + caller can degrade gracefully to keyword-only scoring. + + Args: + entity_id: Entity UUID (without key prefix). + redis_client: Async Redis client — caller supplies to avoid N connections. + """ + if not entity_id: + return None + + try: + key = f"{_ENTITY_EMBED_KEY_PREFIX}{entity_id}" + raw = await redis_client.get(key) + if raw: + return json.loads(raw) + except Exception as exc: + logger.debug("Failed to fetch entity embedding for %s: %s", entity_id, exc) + + return None diff --git a/autobot-backend/knowledge/memory_graph/query_processor.py b/autobot-backend/knowledge/memory_graph/query_processor.py new file mode 100644 index 000000000..da24f4718 --- /dev/null +++ b/autobot-backend/knowledge/memory_graph/query_processor.py @@ -0,0 +1,650 @@ +# Copyright (c) mrveiss. All rights reserved. +# AutoBot - AI-Powered Automation Platform +""" +Memory Graph Query Processor — Phase 1 & 2. + +Issue #3384: Core infrastructure for semantic search over memory-graph entities. 
+ +5-stage pipeline: + Stage 1 Intent extraction (pattern matching — time, entity type, status) + Stage 2 Filter generation (Redis FT.SEARCH query string) + Stage 3 Query embedding (via existing npu_client fallback) + Stage 4 Hybrid search (Redis candidates + vector ranking) + Stage 5 Result ranking (HybridScorer — semantic + BM25) + +Redis key layout (read-only; written by autobot_memory_graph): + memory:entity: JSON — entity document + FT index name memory_entity_idx +""" + +import asyncio +import hashlib +import json +import logging +import re +import time +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional, Tuple + +from autobot_shared.redis_client import get_redis_client + +from knowledge.memory_graph.hybrid_scorer import HybridScorer, SearchResult + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_EMBEDDING_CACHE_TTL_SECONDS = 3600 # 1 hour L2 cache for query embeddings +_EMBEDDING_CACHE_KEY_PREFIX = "mg:embed:" +_RESULT_CACHE_TTL_SECONDS = 300 # 5 minutes L2 cache for search results +_RESULT_CACHE_KEY_PREFIX = "mg:search:" + +_DEFAULT_CANDIDATE_LIMIT = 50 # Redis candidates before vector ranking +_DEFAULT_RESULT_LIMIT = 10 + +_FT_INDEX = "memory_entity_idx" +_ENTITY_KEY_PREFIX = "memory:entity:" + +# Intent pattern tables +_TIME_PATTERNS: List[Tuple[str, Any]] = [ + (r"\btoday\b", lambda: {"start": datetime.now(tz=timezone.utc).date()}), + ( + r"\byesterday\b", + lambda: {"start": (datetime.now(tz=timezone.utc) - timedelta(days=1)).date()}, + ), + ( + r"\bthis week\b", + lambda: { + "start": (datetime.now(tz=timezone.utc) - timedelta(days=datetime.now(tz=timezone.utc).weekday())).date() + }, + ), + ( + r"\blast (\d+) days?\b", + lambda m: { + "start": (datetime.now(tz=timezone.utc) - 
timedelta(days=int(m.group(1)))).date() + }, + ), + ( + r"\bthis month\b", + lambda: {"start": datetime.now(tz=timezone.utc).replace(day=1).date()}, + ), +] + +_ENTITY_TYPE_PATTERNS: List[Tuple[str, List[str]]] = [ + (r"\bbugs?\b", ["bug_fix"]), + (r"\bfix(es)?\b", ["bug_fix"]), + (r"\bfeatures?\b", ["feature"]), + (r"\bdecisions?\b", ["decision"]), + (r"\btasks?\b", ["task"]), + (r"\bconversations?\b", ["conversation"]), +] + +_STATUS_PATTERNS: List[Tuple[str, str]] = [ + (r"\b(worked on|completed|finished|fixed)\b", "completed"), + (r"\b(started|began|working on|in progress)\b", "in_progress"), + (r"\b(planned|todo|pending)\b", "pending"), + (r"\bactive\b", "active"), +] + +# Stopwords to strip before keyword extraction +_STOPWORDS = frozenset( + { + "a", + "an", + "the", + "and", + "or", + "but", + "in", + "on", + "at", + "to", + "for", + "of", + "with", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "do", + "does", + "did", + "will", + "would", + "could", + "should", + "may", + "might", + "we", + "i", + "me", + "us", + "you", + "he", + "she", + "it", + "they", + "what", + "which", + "who", + "whom", + "this", + "that", + "these", + "those", + "show", + "find", + "get", + "tell", + "give", + "list", + "me", + } +) + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclass +class QueryIntent: + """Structured intent extracted from a natural language query.""" + + entity_types: List[str] = field(default_factory=list) + time_range: Optional[Dict[str, Any]] = None + status_filter: Optional[str] = None + keywords: List[str] = field(default_factory=list) + semantic_query: str = "" + + +# --------------------------------------------------------------------------- +# MemoryGraphQueryProcessor +# --------------------------------------------------------------------------- + + +class 
MemoryGraphQueryProcessor: + """ + Natural language query processor for memory-graph entities. + + Issue #3384: Phase 1 (core infrastructure) + Phase 2 (hybrid scoring). + + Usage:: + + processor = MemoryGraphQueryProcessor() + results = await processor.process_query("What bugs did we fix today?") + + The processor is stateless except for the injected Redis client, so a + single shared instance per application process is safe. + """ + + def __init__( + self, + redis_client=None, + candidate_limit: int = _DEFAULT_CANDIDATE_LIMIT, + ) -> None: + """ + Initialise processor. + + Args: + redis_client: Async aioredis client pointing at the *knowledge* + database. When None the processor lazily fetches one via + ``get_redis_client``. + candidate_limit: Maximum entities pulled from Redis before vector + ranking (keeps latency bounded). + """ + self._redis = redis_client + self._candidate_limit = candidate_limit + self._scorer = HybridScorer() + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def process_query( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + limit: int = _DEFAULT_RESULT_LIMIT, + ) -> List[SearchResult]: + """ + Run the 5-stage hybrid search pipeline. + + Args: + query: Natural language query string. + filters: Optional caller-supplied structured filters that are + merged with intent-extracted filters. + limit: Maximum results to return. + + Returns: + Ranked list of SearchResult objects (highest score first). 
+ """ + query = (query or "").strip() + if not query: + logger.warning("process_query called with empty query") + return [] + + t_start = time.monotonic() + + # Stage 1: intent extraction + intent = self._extract_intent(query) + if filters: + _merge_filters(intent, filters) + + # Stage 2: build Redis query string + redis_query = _build_redis_query(intent) + + # Stages 3 & 4 run concurrently + embedding_task = asyncio.create_task( + self._get_query_embedding(intent.semantic_query or query) + ) + candidates_task = asyncio.create_task( + self._fetch_candidates(redis_query, self._candidate_limit) + ) + + query_embedding, candidates = await asyncio.gather( + embedding_task, candidates_task + ) + + if not candidates: + logger.info("No candidates found for query: %s", query) + return [] + + # Stage 5: hybrid score + rank + results = await self._scorer.score_and_rank( + query=query, + intent=intent, + candidates=candidates, + query_embedding=query_embedding, + limit=limit, + ) + + elapsed_ms = (time.monotonic() - t_start) * 1000 + logger.info( + "process_query completed in %.1f ms, %d/%d results returned", + elapsed_ms, + len(results), + len(candidates), + ) + return results + + async def get_entity(self, entity_id: str) -> Optional[Dict[str, Any]]: + """ + Retrieve a single entity document by its UUID. + + Args: + entity_id: UUID string (without the ``memory:entity:`` prefix). + + Returns: + Entity dict, or None if not found. + """ + redis = await self._get_redis() + key = f"{_ENTITY_KEY_PREFIX}{entity_id}" + try: + doc = await redis.json().get(key) + return doc + except Exception as exc: + logger.warning("get_entity failed for %s: %s", entity_id, exc) + return None + + async def get_entity_by_name(self, name: str) -> Optional[Dict[str, Any]]: + """ + Retrieve the first entity matching *name* exactly. + + Args: + name: Exact entity name string. + + Returns: + Entity dict, or None if not found. 
+ """ + redis = await self._get_redis() + # Escape RediSearch special chars in the name + safe_name = re.sub(r"([-@!(){}[\]/\\^$*?.,|:;])", r"\\\1", name) + ft_query = f"@name:({safe_name})" + try: + raw = await redis.execute_command( + "FT.SEARCH", + _FT_INDEX, + ft_query, + "LIMIT", + "0", + "1", + ) + entities = _parse_ft_results(raw) + return entities[0] if entities else None + except Exception as exc: + logger.warning("get_entity_by_name failed for %s: %s", name, exc) + return None + + async def get_related_entities( + self, + entity_name: str, + relation_type: Optional[str] = None, + limit: int = 20, + ) -> List[Dict[str, Any]]: + """ + Retrieve entities related to the given entity via the relations index. + + Args: + entity_name: Name of the source entity. + relation_type: If given, filter to this relation type only. + limit: Maximum relations to follow. + + Returns: + List of related entity dicts. + """ + redis = await self._get_redis() + # Relations are keyed by UUID; resolve name → UUID via FT.SEARCH first + # Key prefix matches schema.py RELATIONS_OUT_PREFIX = "memory:relations:out:" + source = await self.get_entity_by_name(entity_name) + if not source: + return [] + rel_key = f"memory:relations:out:{source['id']}" + try: + raw_json = await redis.json().get(rel_key) + if not raw_json: + return [] + doc = raw_json if isinstance(raw_json, dict) else {} + relations: List[Dict[str, Any]] = doc.get("relations", []) + except Exception as exc: + logger.warning( + "Failed to read relations for %s: %s", entity_name, exc + ) + return [] + + if relation_type: + relations = [r for r in relations if r.get("type") == relation_type] + + related: List[Dict[str, Any]] = [] + for rel in relations[:limit]: + target_name = rel.get("to", "") + if not target_name: + continue + entity = await self.get_entity_by_name(target_name) + if entity: + related.append(entity) + + return related + + # ------------------------------------------------------------------ + # Stage 1: Intent 
extraction + # ------------------------------------------------------------------ + + def _extract_intent(self, query: str) -> QueryIntent: + """Extract structured intent from a natural language query.""" + query_lower = query.lower() + intent = QueryIntent(semantic_query=query) + + # Time filters — try each pattern; handlers accept an optional match + intent.time_range = _extract_time_range(query_lower) + + # Entity type filters + for pattern, types in _ENTITY_TYPE_PATTERNS: + if re.search(pattern, query_lower): + for t in types: + if t not in intent.entity_types: + intent.entity_types.append(t) + + # Status filters + for pattern, status in _STATUS_PATTERNS: + if re.search(pattern, query_lower): + intent.status_filter = status + break + + # Keywords for hybrid scoring + intent.keywords = self._extract_keywords(query_lower) + + # Derive a clean semantic query by removing stop words + semantic_terms = [k for k in query_lower.split() if k not in _STOPWORDS] + intent.semantic_query = " ".join(semantic_terms) if semantic_terms else query + + return intent + + def _extract_keywords(self, text: str) -> List[str]: + """Extract non-stopword alphabetic tokens from text.""" + tokens = re.findall(r"[a-z]+", text.lower()) + return [t for t in tokens if t not in _STOPWORDS and len(t) > 2] + + # ------------------------------------------------------------------ + # Stage 3: Embedding + # ------------------------------------------------------------------ + + async def _get_query_embedding( + self, semantic_query: str + ) -> Optional[List[float]]: + """ + Return an embedding vector for *semantic_query*. + + Checks the Redis L2 embedding cache first; generates via the NPU + client fallback on a miss. 
+ """ + if not semantic_query: + return None + + cache_key = _embedding_cache_key(semantic_query) + redis = await self._get_redis() + + # L2 cache read + try: + cached = await redis.get(cache_key) + if cached: + return json.loads(cached) + except Exception as exc: + logger.debug("Embedding cache read failed: %s", exc) + + # Generate via NPU worker / Ollama fallback + embedding = await _generate_embedding(semantic_query) + + # L2 cache write (best-effort) + if embedding: + try: + await redis.setex( + cache_key, + _EMBEDDING_CACHE_TTL_SECONDS, + json.dumps(embedding), + ) + except Exception as exc: + logger.debug("Embedding cache write failed: %s", exc) + + return embedding + + # ------------------------------------------------------------------ + # Stage 4: Redis candidate retrieval + # ------------------------------------------------------------------ + + async def _fetch_candidates( + self, redis_query: str, limit: int + ) -> List[Dict[str, Any]]: + """ + Execute FT.SEARCH to retrieve entity candidates. + + Falls back to a scan-based approach if the full-text index is + unavailable (e.g., development without RediSearch). 
+ """ + redis = await self._get_redis() + try: + raw = await redis.execute_command( + "FT.SEARCH", + _FT_INDEX, + redis_query, + "LIMIT", + "0", + str(limit), + ) + return _parse_ft_results(raw) + except Exception as exc: + logger.warning( + "FT.SEARCH failed (%s), falling back to scan", exc + ) + return await self._scan_fallback(limit) + + async def _scan_fallback(self, limit: int) -> List[Dict[str, Any]]: + """Scan Redis for entity keys when FT index is unavailable.""" + redis = await self._get_redis() + entities: List[Dict[str, Any]] = [] + try: + async for key in redis.scan_iter( + match=f"{_ENTITY_KEY_PREFIX}*", count=100 + ): + if len(entities) >= limit: + break + doc = await redis.json().get(key) + if doc: + entities.append(doc) + except Exception as exc: + logger.warning("scan_fallback failed: %s", exc) + return entities + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + async def _get_redis(self): + """Lazily obtain the async Redis client.""" + if self._redis is None: + self._redis = await get_redis_client( + async_client=True, database="knowledge" + ) + return self._redis + + +# --------------------------------------------------------------------------- +# Module-level helpers +# --------------------------------------------------------------------------- + + +def _extract_time_range(query_lower: str) -> Optional[Dict[str, Any]]: + """Try each time pattern and return the first match.""" + for pattern, handler in _TIME_PATTERNS: + m = re.search(pattern, query_lower) + if m: + try: + # Handlers that need the match object accept it positionally + return handler(m) + except TypeError: + return handler() + return None + + +def _build_redis_query(intent: QueryIntent) -> str: + """ + Translate a QueryIntent into a RediSearch FT.SEARCH query string. 
+ + Returns ``*`` (match all) when no intent filters are present so that + the caller always gets candidates to vector-rank. + """ + parts: List[str] = [] + + if intent.entity_types: + type_filter = "|".join(intent.entity_types) + parts.append(f"@type:({type_filter})") + + if intent.status_filter: + parts.append(f"@status:{{{intent.status_filter}}}") + + if intent.time_range and intent.time_range.get("start"): + start_dt = intent.time_range["start"] + # created_at stored as ms-since-epoch integer + if hasattr(start_dt, "timetuple"): + ts = int( + datetime(start_dt.year, start_dt.month, start_dt.day).timestamp() + * 1000 + ) + else: + ts = int(start_dt) * 1000 + parts.append(f"@created_at:[{ts} +inf]") + + return " ".join(parts) if parts else "*" + + +def _merge_filters(intent: QueryIntent, filters: Dict[str, Any]) -> None: + """Merge caller-supplied filters into an extracted intent (in-place).""" + if "entity_types" in filters: + for t in filters["entity_types"]: + if t not in intent.entity_types: + intent.entity_types.append(t) + + if "time_range" in filters and intent.time_range is None: + intent.time_range = filters["time_range"] + + if "status" in filters and intent.status_filter is None: + intent.status_filter = filters["status"] + + +def _parse_ft_results(raw: Any) -> List[Dict[str, Any]]: + """ + Parse raw FT.SEARCH response into a list of entity dicts. + + Redis returns: [total_count, key1, [field, value, ...], key2, ...] + When RETURN is not specified every stored field is returned. 
+ """ + if not raw or not isinstance(raw, (list, tuple)) or len(raw) < 2: + return [] + + entities: List[Dict[str, Any]] = [] + # Results start at index 1; each entry is (key, field_list) pairs + i = 1 + while i < len(raw): + key = raw[i] + if isinstance(key, bytes): + key = key.decode("utf-8", errors="replace") + + fields_raw = raw[i + 1] if i + 1 < len(raw) else [] + entity: Dict[str, Any] = {} + + if isinstance(fields_raw, (list, tuple)): + # Interleaved [field, value, field, value, ...] + for j in range(0, len(fields_raw) - 1, 2): + fname = fields_raw[j] + fval = fields_raw[j + 1] + if isinstance(fname, bytes): + fname = fname.decode("utf-8", errors="replace") + if isinstance(fval, bytes): + fval = fval.decode("utf-8", errors="replace") + # Try JSON decode for list/dict fields + try: + entity[fname] = json.loads(fval) + except (TypeError, ValueError, json.JSONDecodeError): + entity[fname] = fval + + # Ensure we have at least the Redis key + if not entity: + entity["_key"] = key + + entities.append(entity) + i += 2 + + return entities + + +def _embedding_cache_key(text: str) -> str: + digest = hashlib.sha256(text.encode("utf-8")).hexdigest() + return f"{_EMBEDDING_CACHE_KEY_PREFIX}{digest}" + + +async def _generate_embedding(text: str) -> Optional[List[float]]: + """ + Generate a text embedding using the NPU worker / Ollama fallback. + + Imports lazily to avoid circular imports and to allow the module to be + imported in test environments where the service is mocked. 
+ """ + try: + from autobot_shared.ssot_config import config + from services.npu_client import generate_embedding_with_fallback + + model_name = config.get("knowledge.embedding_model", "nomic-embed-text") + return await generate_embedding_with_fallback(text, model_name=model_name) + except Exception as exc: + logger.warning("Embedding generation failed: %s", exc) + return None diff --git a/autobot-backend/knowledge/memory_graph/query_processor_test.py b/autobot-backend/knowledge/memory_graph/query_processor_test.py new file mode 100644 index 000000000..db9668a58 --- /dev/null +++ b/autobot-backend/knowledge/memory_graph/query_processor_test.py @@ -0,0 +1,675 @@ +# Copyright (c) mrveiss. All rights reserved. +# AutoBot - AI-Powered Automation Platform +""" +Unit tests for memory graph query processor and hybrid scorer. + +Issue #3384: Phase 1 & 2 tests — all Redis and embedding calls are mocked. + +Run with: + pytest autobot-backend/knowledge/memory_graph/query_processor_test.py -v +""" + +import json +import math +import sys +import types +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Stub out heavy transitive imports before the module is loaded +# --------------------------------------------------------------------------- + + +def _make_module(name: str, **attrs) -> types.ModuleType: + mod = types.ModuleType(name) + for k, v in attrs.items(): + setattr(mod, k, v) + return mod + + +# autobot_shared stubs +_redis_mod = _make_module("autobot_shared.redis_client") +_redis_mod.get_redis_client = AsyncMock(return_value=MagicMock()) + +_ssot_mod = _make_module("autobot_shared.ssot_config") +_ssot_cfg = MagicMock() +_ssot_mod.config = _ssot_cfg + +# services.npu_client stub +_npu_mod = _make_module( + "services.npu_client", + generate_embedding_with_fallback=AsyncMock(return_value=[0.1] * 768), +) + +for _name, _mod in [ + 
("autobot_shared", _make_module("autobot_shared")), + ("autobot_shared.redis_client", _redis_mod), + ("autobot_shared.ssot_config", _ssot_mod), + ("services", _make_module("services")), + ("services.npu_client", _npu_mod), +]: + sys.modules.setdefault(_name, _mod) + +# Now safe to import the modules under test +from knowledge.memory_graph.hybrid_scorer import ( # noqa: E402 + HybridScorer, + SearchResult, + _build_explanation, + _entity_to_text, + cosine_similarity, +) +from knowledge.memory_graph.query_processor import ( # noqa: E402 + MemoryGraphQueryProcessor, + QueryIntent, + _build_redis_query, + _merge_filters, + _parse_ft_results, +) + + +# =========================================================================== +# Fixtures +# =========================================================================== + + +def _make_entity( + entity_id: str = "abc123", + name: str = "System Status Bug", + entity_type: str = "BUG", + observations=None, + status: str = "completed", + created_at: int = 1_700_000_000_000, +) -> dict: + return { + "id": entity_id, + "name": name, + "type": entity_type, + "observations": observations or ["Fixed broken endpoint", "Deployed fix"], + "metadata": {"status": status}, + "created_at": created_at, + } + + +def _make_embedding(dim: int = 768, value: float = 0.1) -> list: + return [value] * dim + + +# =========================================================================== +# Tests: cosine_similarity +# =========================================================================== + + +class TestCosineSimilarity: + def test_identical_vectors_return_one(self): + vec = [1.0, 0.0, 0.0] + assert cosine_similarity(vec, vec) == pytest.approx(1.0, abs=1e-6) + + def test_orthogonal_vectors_return_zero(self): + assert cosine_similarity([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0, abs=1e-6) + + def test_opposite_vectors_clamped_to_zero(self): + # cosine of 180° is -1; we clamp to 0 + assert cosine_similarity([1.0], [-1.0]) == 0.0 + + def 
test_none_vector_returns_zero(self): + assert cosine_similarity(None, [1.0, 0.0]) == 0.0 + assert cosine_similarity([1.0, 0.0], None) == 0.0 + + def test_empty_vector_returns_zero(self): + assert cosine_similarity([], [1.0]) == 0.0 + + def test_dimension_mismatch_returns_zero(self): + assert cosine_similarity([1.0, 0.0], [1.0]) == 0.0 + + def test_zero_magnitude_returns_zero(self): + assert cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0 + + def test_partial_overlap(self): + a = [1.0, 1.0, 0.0] + b = [1.0, 0.0, 0.0] + # cos(45°) ≈ 0.707 + result = cosine_similarity(a, b) + assert 0.5 < result < 0.9 + + def test_high_dimensional_vectors(self): + dim = 768 + a = [1.0 / math.sqrt(dim)] * dim + b = [1.0 / math.sqrt(dim)] * dim + assert cosine_similarity(a, b) == pytest.approx(1.0, abs=1e-4) + + +# =========================================================================== +# Tests: HybridScorer.bm25_score +# =========================================================================== + + +class TestBM25Score: + def setup_method(self): + self.scorer = HybridScorer() + + def test_matching_keyword_returns_positive(self): + score, matched = self.scorer.bm25_score(["bug"], "we fixed a bug in the system") + assert score > 0.0 + assert "bug" in matched + + def test_no_keywords_returns_zero(self): + score, matched = self.scorer.bm25_score([], "some text") + assert score == 0.0 + assert matched == [] + + def test_empty_document_returns_zero(self): + score, matched = self.scorer.bm25_score(["bug"], "") + assert score == 0.0 + + def test_missing_keyword_not_in_matched(self): + score, matched = self.scorer.bm25_score(["xyz"], "we fixed a bug") + assert score == 0.0 + assert matched == [] + + def test_multiple_keywords_accumulate(self): + score_one, _ = self.scorer.bm25_score(["bug"], "we fixed a bug") + score_two, _ = self.scorer.bm25_score( + ["bug", "fixed"], "we fixed a bug" + ) + assert score_two >= score_one + + def test_score_normalised_to_one(self): + # Saturate: many 
occurrences of the single keyword + doc = "bug " * 200 + score, matched = self.scorer.bm25_score(["bug"], doc) + assert 0.0 <= score <= 1.0 + assert "bug" in matched + + def test_repeated_term_diminishing_returns(self): + score_one, _ = self.scorer.bm25_score(["bug"], "bug") + score_many, _ = self.scorer.bm25_score(["bug"], "bug " * 50) + # BM25 saturates — many occurrences should be less than k1+1 times one + assert score_many < 10 * score_one + + +# =========================================================================== +# Tests: _entity_to_text +# =========================================================================== + + +class TestEntityToText: + def test_name_repeated_three_times(self): + entity = {"name": "MyEntity", "type": "BUG", "observations": []} + text = _entity_to_text(entity) + assert text.lower().count("myentity") == 3 + + def test_observations_included(self): + entity = { + "name": "E", + "type": "BUG", + "observations": ["Fixed endpoint"], + } + text = _entity_to_text(entity) + assert "fixed endpoint" in text.lower() + + def test_json_string_observations_decoded(self): + entity = { + "name": "E", + "type": "BUG", + "observations": '["obs1", "obs2"]', + } + text = _entity_to_text(entity) + assert "obs1" in text.lower() + + def test_empty_entity_returns_string(self): + text = _entity_to_text({}) + assert isinstance(text, str) + + +# =========================================================================== +# Tests: QueryIntent intent extraction +# =========================================================================== + + +class TestIntentExtraction: + def setup_method(self): + redis_mock = AsyncMock() + redis_mock.json = MagicMock(return_value=AsyncMock()) + redis_mock.get = AsyncMock(return_value=None) + self.processor = MemoryGraphQueryProcessor(redis_client=redis_mock) + + def _extract(self, query: str) -> QueryIntent: + return self.processor._extract_intent(query) + + def test_today_sets_time_range(self): + intent = 
self._extract("What did we fix today?") + assert intent.time_range is not None + assert intent.time_range["start"] == datetime.now().date() + + def test_yesterday_sets_time_range(self): + intent = self._extract("Show me bugs from yesterday") + expected = (datetime.now() - timedelta(days=1)).date() + assert intent.time_range is not None + assert intent.time_range["start"] == expected + + def test_last_n_days_sets_time_range(self): + intent = self._extract("tasks from last 7 days") + expected = (datetime.now() - timedelta(days=7)).date() + assert intent.time_range is not None + assert intent.time_range["start"] == expected + + def test_bug_keyword_maps_entity_type(self): + intent = self._extract("show me bug reports") + assert "bug_fix" in intent.entity_types + + def test_fix_keyword_maps_entity_type(self): + intent = self._extract("what fixes were deployed?") + assert "bug_fix" in intent.entity_types + + def test_feature_keyword_maps_entity_type(self): + intent = self._extract("new features this week") + assert "feature" in intent.entity_types + + def test_completed_keyword_maps_status(self): + intent = self._extract("tasks we completed") + assert intent.status_filter == "completed" + + def test_no_filters_when_generic_query(self): + intent = self._extract("tell me everything") + assert intent.entity_types == [] + assert intent.time_range is None + assert intent.status_filter is None + + def test_keywords_exclude_stopwords(self): + intent = self._extract("what bugs did we fix today") + assert "what" not in intent.keywords + assert "did" not in intent.keywords + assert "bugs" in intent.keywords + + def test_semantic_query_is_non_empty(self): + intent = self._extract("What bugs did we fix?") + assert intent.semantic_query + + +# =========================================================================== +# Tests: _build_redis_query +# =========================================================================== + + +class TestBuildRedisQuery: + def _intent(self, 
**kwargs) -> QueryIntent: + i = QueryIntent() + for k, v in kwargs.items(): + setattr(i, k, v) + return i + + def test_empty_intent_returns_star(self): + assert _build_redis_query(QueryIntent()) == "*" + + def test_single_entity_type(self): + intent = self._intent(entity_types=["BUG"]) + q = _build_redis_query(intent) + assert "@type:(BUG)" in q + + def test_multiple_entity_types_pipe_separated(self): + intent = self._intent(entity_types=["BUG", "FEATURE"]) + q = _build_redis_query(intent) + assert "BUG|FEATURE" in q + + def test_status_filter(self): + intent = self._intent(status_filter="completed") + q = _build_redis_query(intent) + assert "@status:{completed}" in q + + def test_time_range_filter(self): + start = datetime.now().date() + intent = self._intent(time_range={"start": start}) + q = _build_redis_query(intent) + assert "@created_at:[" in q + assert "+inf]" in q + + def test_combined_filters(self): + intent = self._intent( + entity_types=["BUG"], + status_filter="completed", + ) + q = _build_redis_query(intent) + assert "@type:(BUG)" in q + assert "@status:{completed}" in q + + +# =========================================================================== +# Tests: _merge_filters +# =========================================================================== + + +class TestMergeFilters: + def test_merges_entity_types(self): + intent = QueryIntent(entity_types=["BUG"]) + _merge_filters(intent, {"entity_types": ["FEATURE"]}) + assert "BUG" in intent.entity_types + assert "FEATURE" in intent.entity_types + + def test_no_duplicate_entity_types(self): + intent = QueryIntent(entity_types=["BUG"]) + _merge_filters(intent, {"entity_types": ["BUG"]}) + assert intent.entity_types.count("BUG") == 1 + + def test_time_range_not_overwritten(self): + existing = {"start": datetime.now().date()} + intent = QueryIntent(time_range=existing) + _merge_filters(intent, {"time_range": {"start": "other"}}) + assert intent.time_range == existing + + def 
test_status_not_overwritten(self): + intent = QueryIntent(status_filter="completed") + _merge_filters(intent, {"status": "pending"}) + assert intent.status_filter == "completed" + + +# =========================================================================== +# Tests: _parse_ft_results +# =========================================================================== + + +class TestParseFtResults: + def test_empty_response_returns_empty(self): + assert _parse_ft_results(None) == [] + assert _parse_ft_results([]) == [] + assert _parse_ft_results([0]) == [] + + def test_single_result_with_fields(self): + raw = [ + 1, # total count + b"memory:entity:abc", + [b"name", b"Test Entity", b"type", b"BUG"], + ] + results = _parse_ft_results(raw) + assert len(results) == 1 + assert results[0]["name"] == "Test Entity" + assert results[0]["type"] == "BUG" + + def test_json_field_decoded(self): + obs_json = json.dumps(["obs1", "obs2"]) + raw = [ + 1, + b"memory:entity:x", + [b"observations", obs_json.encode("utf-8")], + ] + results = _parse_ft_results(raw) + assert results[0]["observations"] == ["obs1", "obs2"] + + def test_multiple_results(self): + raw = [ + 2, + b"memory:entity:a", + [b"name", b"Entity A"], + b"memory:entity:b", + [b"name", b"Entity B"], + ] + results = _parse_ft_results(raw) + assert len(results) == 2 + + +# =========================================================================== +# Tests: HybridScorer.score_and_rank (async) +# =========================================================================== + + +class TestScoreAndRank: + def setup_method(self): + self.scorer = HybridScorer() + + @pytest.mark.asyncio + async def test_returns_top_k_results(self): + entities = [_make_entity(entity_id=str(i), name=f"E{i}") for i in range(5)] + intent = QueryIntent(keywords=["bug", "fixed"]) + q_embed = _make_embedding() + + with patch( + "knowledge.memory_graph.hybrid_scorer._fetch_entity_embedding", + new=AsyncMock(return_value=_make_embedding()), + ): + results = 
await self.scorer.score_and_rank( + query="bug fix", + intent=intent, + candidates=entities, + query_embedding=q_embed, + limit=3, + ) + + assert len(results) == 3 + + @pytest.mark.asyncio + async def test_results_sorted_descending(self): + entities = [_make_entity(entity_id=str(i), name=f"E{i}") for i in range(4)] + intent = QueryIntent(keywords=["fixed"]) + q_embed = _make_embedding() + + with patch( + "knowledge.memory_graph.hybrid_scorer._fetch_entity_embedding", + new=AsyncMock(return_value=_make_embedding()), + ): + results = await self.scorer.score_and_rank( + "fixed", intent, entities, q_embed, limit=10 + ) + + scores = [r.score for r in results] + assert scores == sorted(scores, reverse=True) + + @pytest.mark.asyncio + async def test_missing_embedding_degrades_to_keyword(self): + entity = _make_entity(observations=["fixed the bug in endpoint"]) + intent = QueryIntent(keywords=["bug", "endpoint"]) + + with patch( + "knowledge.memory_graph.hybrid_scorer._fetch_entity_embedding", + new=AsyncMock(return_value=None), + ): + results = await self.scorer.score_and_rank( + "bug endpoint", + intent, + [entity], + query_embedding=None, + limit=5, + ) + + assert len(results) == 1 + assert results[0].semantic_score == 0.0 + assert results[0].keyword_score > 0.0 + + @pytest.mark.asyncio + async def test_empty_candidates_returns_empty(self): + results = await self.scorer.score_and_rank( + "query", QueryIntent(), [], None, 5 + ) + assert results == [] + + @pytest.mark.asyncio + async def test_score_within_bounds(self): + entity = _make_entity() + intent = QueryIntent(keywords=["fixed"]) + q_embed = _make_embedding() + + with patch( + "knowledge.memory_graph.hybrid_scorer._fetch_entity_embedding", + new=AsyncMock(return_value=_make_embedding()), + ): + results = await self.scorer.score_and_rank( + "fixed", intent, [entity], q_embed, limit=1 + ) + + r = results[0] + assert 0.0 <= r.score <= 1.0 + assert 0.0 <= r.semantic_score <= 1.0 + assert 0.0 <= r.keyword_score <= 1.0 + 
# ===========================================================================
# Tests: MemoryGraphQueryProcessor.process_query (async, integrated)
# ===========================================================================


class TestProcessQuery:
    """Integrated process_query tests with Redis and embeddings mocked."""

    def _make_processor(self, candidates=None, embedding=None):
        """Build a processor wired to mocks plus the two patches to apply.

        Returns (processor, embed_patch, entity_embed_patch); callers enter
        the patches as context managers around the call under test.
        """
        redis_mock = AsyncMock()
        redis_mock.get = AsyncMock(return_value=None)
        redis_mock.setex = AsyncMock()
        # FT.SEARCH replies are simulated from the candidate entity dicts.
        redis_mock.execute_command = AsyncMock(
            return_value=_build_raw_ft_response(candidates or [])
        )

        processor = MemoryGraphQueryProcessor(redis_client=redis_mock)

        embed_patch = patch(
            "knowledge.memory_graph.query_processor._generate_embedding",
            new=AsyncMock(return_value=embedding or _make_embedding()),
        )
        entity_embed_patch = patch(
            "knowledge.memory_graph.hybrid_scorer._fetch_entity_embedding",
            new=AsyncMock(return_value=_make_embedding()),
        )
        return processor, embed_patch, entity_embed_patch

    @pytest.mark.asyncio
    async def test_empty_query_returns_empty(self):
        processor, ep, eep = self._make_processor()
        with ep, eep:
            results = await processor.process_query("")
        assert results == []

    @pytest.mark.asyncio
    async def test_returns_list_of_search_results(self):
        candidates = [_make_entity(entity_id="1", name="Bug Fix")]
        processor, ep, eep = self._make_processor(candidates=candidates)
        with ep, eep:
            results = await processor.process_query("bug fix")
        assert isinstance(results, list)
        for r in results:
            assert isinstance(r, SearchResult)

    @pytest.mark.asyncio
    async def test_no_candidates_returns_empty(self):
        processor, ep, eep = self._make_processor(candidates=[])
        with ep, eep:
            results = await processor.process_query("some query")
        assert results == []

    @pytest.mark.asyncio
    async def test_respects_limit(self):
        candidates = [
            _make_entity(entity_id=str(i), name=f"E{i}") for i in range(10)
        ]
        processor, ep, eep = self._make_processor(candidates=candidates)
        with ep, eep:
            results = await processor.process_query("fix bugs", limit=3)
        assert len(results) <= 3

    @pytest.mark.asyncio
    async def test_caller_filters_merged(self):
        # Wrap _build_redis_query to capture the intent it receives and
        # verify the caller-supplied entity_types filter was merged in.
        processor, ep, eep = self._make_processor(candidates=[])
        intent_captured = []

        original_build = __import__(
            "knowledge.memory_graph.query_processor",
            fromlist=["_build_redis_query"],
        )._build_redis_query

        def capturing_build(intent):
            intent_captured.append(intent)
            return original_build(intent)

        with ep, eep, patch(
            "knowledge.memory_graph.query_processor._build_redis_query",
            side_effect=capturing_build,
        ):
            await processor.process_query(
                "query", filters={"entity_types": ["FEATURE"]}
            )

        assert intent_captured
        assert "FEATURE" in intent_captured[0].entity_types


# ===========================================================================
# Tests: get_entity / get_related_entities
# ===========================================================================


class TestEntityRetrieval:
    """Direct entity lookup and one-hop relation traversal."""

    @pytest.mark.asyncio
    async def test_get_entity_returns_doc(self):
        entity = _make_entity()
        redis_mock = AsyncMock()
        json_mock = AsyncMock()
        json_mock.get = AsyncMock(return_value=entity)
        redis_mock.json = MagicMock(return_value=json_mock)

        processor = MemoryGraphQueryProcessor(redis_client=redis_mock)
        result = await processor.get_entity("abc123")
        assert result == entity

    @pytest.mark.asyncio
    async def test_get_entity_returns_none_on_error(self):
        # Redis failures are swallowed and surfaced as None, not raised.
        redis_mock = AsyncMock()
        json_mock = AsyncMock()
        json_mock.get = AsyncMock(side_effect=Exception("Redis error"))
        redis_mock.json = MagicMock(return_value=json_mock)

        processor = MemoryGraphQueryProcessor(redis_client=redis_mock)
        result = await processor.get_entity("bad-id")
        assert result is None

    @pytest.mark.asyncio
    async def test_get_related_entities_empty_when_no_key(self):
        """Returns [] when source entity not found by name."""
        redis_mock = AsyncMock()
        # get_entity_by_name uses FT.SEARCH — return empty response
        redis_mock.execute_command = AsyncMock(return_value=[0])

        processor = MemoryGraphQueryProcessor(redis_client=redis_mock)
        results = await processor.get_related_entities("NonExistent")
        assert results == []

    @pytest.mark.asyncio
    async def test_get_related_entities_follows_relations(self):
        source_entity = _make_entity(name="Source Entity")
        related_entity = _make_entity(name="Related Entity")
        relations_doc = {"entity_id": source_entity["id"], "relations": [
            {"to": "Related Entity", "type": "relates_to"}
        ]}

        json_mock = AsyncMock()
        json_mock.get = AsyncMock(return_value=relations_doc)
        redis_mock = AsyncMock()
        redis_mock.json = MagicMock(return_value=json_mock)

        # First FT.SEARCH call → source entity; second → related entity
        redis_mock.execute_command = AsyncMock(side_effect=[
            _build_raw_ft_response([source_entity]),
            _build_raw_ft_response([related_entity]),
        ])

        processor = MemoryGraphQueryProcessor(redis_client=redis_mock)
        results = await processor.get_related_entities("Source Entity")
        assert len(results) == 1


# ===========================================================================
# Helpers
# ===========================================================================


def _build_raw_ft_response(entities: list) -> list:
    """Build a minimal FT.SEARCH-style response for a list of entity dicts."""
    # Reply shape: [total, key1, [field, value, ...], key2, [...], ...]
    # with keys and field names/values as bytes, mirroring redis-py raw mode.
    raw = [len(entities)]
    for i, entity in enumerate(entities):
        key = f"memory:entity:{entity.get('id', i)}".encode("utf-8")
        fields = []
        for fname, fval in entity.items():
            fields.append(fname.encode("utf-8"))
            # Containers are JSON-encoded, matching how documents are stored.
            if isinstance(fval, (dict, list)):
                fields.append(json.dumps(fval).encode("utf-8"))
            else:
                fields.append(str(fval).encode("utf-8"))
        raw.append(key)
        raw.append(fields)
    return raw