diff --git a/hindsight-api-slim/hindsight_api/api/http.py b/hindsight-api-slim/hindsight_api/api/http.py index 4580ebd50..00d119c1d 100644 --- a/hindsight-api-slim/hindsight_api/api/http.py +++ b/hindsight-api-slim/hindsight_api/api/http.py @@ -169,6 +169,15 @@ class RecallRequest(BaseModel): "Each group is a leaf {tags, match} or compound {and: [...]}, {or: [...]}, {not: ...}.", ) + @field_validator("query") + @classmethod + def validate_query_not_empty(cls, v: str) -> str: + from ..engine.search.retrieval import tokenize_query + + if not tokenize_query(v): + raise ValueError("query must contain at least one word character after normalization") + return v + @model_validator(mode="after") def validate_tags_exclusive(self) -> "RecallRequest": if self.tags is not None and self.tag_groups is not None: diff --git a/hindsight-api-slim/hindsight_api/engine/search/retrieval.py b/hindsight-api-slim/hindsight_api/engine/search/retrieval.py index e4c43e5de..403855d1f 100644 --- a/hindsight-api-slim/hindsight_api/engine/search/retrieval.py +++ b/hindsight-api-slim/hindsight_api/engine/search/retrieval.py @@ -10,6 +10,7 @@ import asyncio import logging +import re from dataclasses import dataclass, field from datetime import UTC, datetime from typing import Optional @@ -26,6 +27,15 @@ logger = logging.getLogger(__name__) +def tokenize_query(query_text: str) -> list[str]: + """Normalize query text and split into BM25 tokens. + + Strips punctuation, lowercases, and splits on whitespace. + Returns an empty list when the query contains no word characters. + """ + return re.sub(r"[^\w\s]", " ", query_text.lower()).split() + + @dataclass class ParallelRetrievalResult: """Result from parallel retrieval across all methods.""" @@ -129,12 +139,9 @@ async def retrieve_semantic_bm25_combined( Returns: Dict mapping fact_type -> (semantic_results, bm25_results) """ - import re - result_dict: dict[str, tuple[list[RetrievalResult], list[RetrievalResult]]] = {ft: ([], []) for ft in fact_types} - sanitized_text = re.sub(r"[^\w\s]", " ", query_text.lower()) - tokens = [token for token in sanitized_text.split() if token] + tokens = tokenize_query(query_text) # Over-fetch for HNSW approximation; semantic results trimmed to limit in Python. hnsw_fetch = max(limit * 5, 100) @@ -148,11 +155,15 @@ async def retrieve_semantic_bm25_combined( # --- Parameter layout --- # $1 = query_emb_str (semantic arms) # $2 = bank_id - # $3 = limit (BM25 LIMIT; semantic uses inlined hnsw_fetch literal) - # $4 = bm25_text (only when tokens present) - # $N = tags (N=4 when no tokens, N=5 when tokens present) - # $M+ = tag_groups params (one per leaf, starting after tags param) - tags_param_idx = 5 if tokens else 4 + # When tokens present: + # $3 = limit (BM25 LIMIT; semantic uses inlined hnsw_fetch literal) + # $4 = bm25_text + # $5 = tags (if present) + # $6+ = tag_groups params (one per leaf) + # When no tokens ($3 is skipped — not included in params to avoid type inference gap): + # $3 = tags (if present) + # $4+ = tag_groups params (one per leaf) + tags_param_idx = 5 if tokens else 3 tags_clause = build_tags_where_clause_simple(tags, tags_param_idx, match=tags_match) # tag_groups params start immediately after the tags param slot @@ -222,9 +233,10 @@ async def retrieve_semantic_bm25_combined( query = "\nUNION ALL\n".join(arms) - params: list = [query_emb_str, bank_id, limit] + params: list = [query_emb_str, bank_id] if tokens: - params.append(bm25_text_param) + params.append(limit) # $3: BM25 LIMIT (only referenced when tokens are present) + params.append(bm25_text_param) # $4 if tags: params.append(tags) params.extend(groups_params)