diff --git a/README.md b/README.md
index 72febc99..1f501009 100644
--- a/README.md
+++ b/README.md
@@ -136,6 +136,17 @@ my_config = RAGLiteConfig(
 )
 ```
 
+Self-query is also supported, allowing the LLM to automatically generate and apply metadata filters to refine search results based on the user's input. To enable self-query, set `self_query=True` in your `RAGLiteConfig`:
+
+```python
+my_config = RAGLiteConfig(
+    db_url="duckdb:///raglite.db",
+    llm="gpt-4o-mini",
+    embedder="text-embedding-3-large",
+    self_query=True,  # Enable self-query
+)
+```
+
 ### 2. Inserting documents
 
 > [!TIP]
diff --git a/src/raglite/_config.py b/src/raglite/_config.py
index 1c55a531..de715c68 100644
--- a/src/raglite/_config.py
+++ b/src/raglite/_config.py
@@ -80,3 +80,4 @@ class RAGLiteConfig:
     # Search config: you can pick any search method that returns (list[ChunkId], list[float]),
     # list[Chunk], or list[ChunkSpan].
     search_method: SearchMethod = field(default=_vector_search, compare=False)
+    self_query: bool = False
diff --git a/src/raglite/_database.py b/src/raglite/_database.py
index d89c585b..aa5cb803 100644
--- a/src/raglite/_database.py
+++ b/src/raglite/_database.py
@@ -41,12 +41,20 @@
     FloatMatrix,
     FloatVector,
     IndexId,
+    MetadataValue,
     PickledObject,
 )
 
 MetadataJSON = JSON().with_variant(JSONB(), "postgresql")
 
 
+def _adapt_metadata(metadata: Any) -> dict[str, MetadataValue | list[MetadataValue]]:
+    """Adapt metadata to the format expected by the database."""
+    if not metadata:
+        return {}
+    return {k: v if isinstance(v, list) else [v] for k, v in metadata.items()}
+
+
 def hash_bytes(data: bytes, max_len: int = 16) -> str:
     """Hash bytes to a hexadecimal string."""
     return sha256(data, usedforsecurity=False).hexdigest()[:max_len]
@@ -120,7 +128,6 @@ def from_path(
         Document
             A document.
         """
-        # Extract metadata.
         metadata = {
             "filename": doc_path.name,
             "uri": id,
@@ -130,6 +137,8 @@
             "modified": doc_path.stat().st_mtime,
             **kwargs,
         }
+        # Ensure all metadata values are lists
+        metadata = _adapt_metadata(metadata)
         # Create the document instance.
         return Document(
             id=id if id is not None else hash_bytes(doc_path.read_bytes()),
@@ -172,7 +181,7 @@ def from_text(
         first_line = content.strip().split("\n", 1)[0].strip()
         if len(first_line) > 80:  # noqa: PLR2004
             first_line = f"{first_line[:80]}..."
-        # Extract metadata.
+
         metadata = {
             "filename": filename or first_line,
             "uri": id,
@@ -180,6 +189,8 @@
             "size": len(content.encode()),
             **kwargs,
         }
+        # Ensure all metadata values are lists
+        metadata = _adapt_metadata(metadata)
         # Create the document instance.
         return Document(
             id=id if id is not None else hash_bytes(content.encode()),
@@ -224,7 +235,9 @@ def from_body(
             index=index,
             headings=Chunk.truncate_headings(headings, body),
             body=body,
-            metadata_={"filename": document.filename, "url": document.url, **kwargs},
+            metadata_=_adapt_metadata(
+                {"filename": document.filename, "url": document.url, **kwargs}
+            ),
         )
 
     @staticmethod
@@ -449,6 +462,17 @@ def get(id_: str = "default", *, config: RAGLiteConfig | None = None) -> dict[st
         return metadata
 
 
+class Metadata(SQLModel, table=True):
+    """A table for metadata values, linked to field names."""
+
+    __tablename__ = "metadata"
+
+    name: str = Field(..., primary_key=True)
+    values: list[MetadataValue] = Field(
+        default_factory=list, sa_column=Column("metadata", MetadataJSON)
+    )
+
+
 class Eval(SQLModel, table=True):
     """A RAG evaluation example."""
 
@@ -478,6 +502,8 @@ def from_chunks(
         """Create a chunk from Markdown."""
         document_id = contexts[0].document_id
         chunk_ids = [context.id for context in contexts]
+        # Ensure all metadata values from kwargs are lists
+        processed_kwargs = _adapt_metadata(kwargs)
         return Eval(
             id=hash_bytes(f"{document_id}-{chunk_ids}-{question}".encode()),
             document_id=document_id,
@@ -485,7 +511,7 @@
             question=question,
             contexts=[str(context) for context in contexts],
             ground_truth=ground_truth,
-            metadata_=kwargs,
+            metadata_=processed_kwargs,
         )
 
 
diff --git a/src/raglite/_insert.py b/src/raglite/_insert.py
index bd63c29e..2f72ddd1 100644
--- a/src/raglite/_insert.py
+++ b/src/raglite/_insert.py
@@ -1,5 +1,6 @@
 """Index documents."""
 
+from collections.abc import Sequence
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import nullcontext
 from functools import partial
@@ -8,15 +9,79 @@
 from filelock import FileLock
 from sqlalchemy import text
 from sqlalchemy.engine import make_url
+from sqlalchemy.orm.attributes import flag_modified
 from sqlmodel import Session, col, select
 from tqdm.auto import tqdm
 
 from raglite._config import RAGLiteConfig
-from raglite._database import Chunk, ChunkEmbedding, Document, create_database_engine
+from raglite._database import (
+    Chunk,
+    ChunkEmbedding,
+    Document,
+    Metadata,
+    create_database_engine,
+)
 from raglite._embed import embed_strings, embed_strings_without_late_chunking, embedding_type
 from raglite._split_chunklets import split_chunklets
 from raglite._split_chunks import split_chunks
 from raglite._split_sentences import split_sentences
+from raglite._typing import MetadataValue
+
+METADATA_EXCLUDED_FIELDS = ["filename", "uri", "url", "size", "created", "modified"]
+
+
+def _get_database_metadata(
+    session: Session | None = None,
+    config: RAGLiteConfig | None = None,
+) -> Sequence[Metadata]:
+    """Fetch all metadata records from the database."""
+    if session:
+        return session.exec(select(Metadata)).all()
+    with Session(create_database_engine(config or RAGLiteConfig())) as new_session:
+        return new_session.exec(select(Metadata)).all()
+
+
+def _aggregate_metadata_from_documents(
+    documents: list[Document],
+    metadata_excluded_fields: list[str] = METADATA_EXCLUDED_FIELDS,
+) -> dict[str, set[MetadataValue]]:
+    """Aggregate metadata values from all documents."""
+    metadata: dict[str, set[MetadataValue]] = {}
+    for doc in documents:
+        for key, value in doc.metadata_.items():
+            if key in metadata_excluded_fields:
+                continue
+            if key not in metadata:
+                metadata[key] = set()
+            if isinstance(value, list):
+                metadata[key].update(value)
+            else:
+                metadata[key].add(value)
+    return metadata
+
+
+def _update_metadata_from_documents(
+    session: Session,
+    documents: list[Document],
+) -> None:
+    """Update or add metadata records."""
+    if not documents:
+        return
+    metadata = _aggregate_metadata_from_documents(documents=documents)
+    existing_metadata = {record.name: record for record in _get_database_metadata(session=session)}
+    # Update or add metadata records.
+    for key, values in metadata.items():
+        # Update
+        if key in existing_metadata:
+            result = existing_metadata[key]
+            values_to_add = set(values) - set(result.values)
+            if values_to_add:
+                result.values.extend(values_to_add)
+                flag_modified(result, "values")  # Notify SQLAlchemy of the change
+                session.add(result)
+        # Add
+        else:
+            session.add(Metadata(name=key, values=list(values)))
 
 
 def _create_chunk_records(
@@ -171,6 +236,8 @@
                 session.expunge_all()  # Release memory of flushed changes.
                 num_unflushed_embeddings = 0
             pbar.update()
+        # Update metadata table.
+        _update_metadata_from_documents(documents=documents, session=session)
         session.commit()
     if engine.dialect.name == "duckdb":
         # DuckDB does not automatically update its keyword search index [1], so we do it
diff --git a/src/raglite/_search.py b/src/raglite/_search.py
index cfeaa50f..c3306cfe 100644
--- a/src/raglite/_search.py
+++ b/src/raglite/_search.py
@@ -2,14 +2,16 @@
 
 import contextlib
 import json
+import logging
 import re
 import string
 from collections import defaultdict
 from itertools import groupby
-from typing import Any
+from typing import Any, ClassVar
 
 import numpy as np
 from langdetect import LangDetectException, detect
+from pydantic import BaseModel, Field, create_model
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.orm import joinedload
 from sqlmodel import Session, and_, col, func, or_, select, text
@@ -20,10 +22,15 @@
     ChunkEmbedding,
     ChunkSpan,
     IndexMetadata,
+    _adapt_metadata,
     create_database_engine,
 )
 from raglite._embed import embed_strings
-from raglite._typing import BasicSearchMethod, ChunkId, FloatVector, MetadataFilter
+from raglite._extract import extract_with_llm
+from raglite._insert import _get_database_metadata
+from raglite._typing import BasicSearchMethod, ChunkId, FloatVector, MetadataFilter, MetadataValue
+
+logger = logging.getLogger(__name__)
 
 
 def vector_search(
@@ -37,6 +44,12 @@
     """Search chunks using ANN vector search."""
     # Read the config.
     config = config or RAGLiteConfig()
+    # Normalize metadata filter values to lists.
+    metadata_filter = _adapt_metadata(metadata_filter)
+    # If self_query is enabled, extract metadata filters from the query.
+    if config.self_query and isinstance(query, str):
+        self_query_filter = _self_query(query, config=config)
+        metadata_filter = {**self_query_filter, **(metadata_filter or {})}
     # Embed the query.
     query_embedding = (
         embed_strings([query], config=config)[0, :] if isinstance(query, str) else np.ravel(query)
     )
@@ -150,6 +163,12 @@
     """Search chunks using BM25 keyword search."""
     # Read the config.
     config = config or RAGLiteConfig()
+    # Normalize metadata filter values to lists.
+    metadata_filter = _adapt_metadata(metadata_filter)
+    # If self_query is enabled, extract metadata filters from the query.
+    if config.self_query and isinstance(query, str):
+        self_query_filter = _self_query(query, config=config)
+        metadata_filter = {**self_query_filter, **(metadata_filter or {})}
     # Connect to the database.
     with Session(create_database_engine(config)) as session:
         dialect = session.get_bind().dialect.name
@@ -412,3 +431,83 @@
     chunks = rerank_chunks(query, chunk_ids, config=config)[:num_results]
     chunk_spans = retrieve_chunk_spans(chunks, neighbors=neighbors, config=config)
     return chunk_spans
+
+
+SELF_QUERY_PROMPT = """
+You are an assistant that extracts metadata filters from user queries to help search a knowledge base.
+
+Instructions:
+1. For each metadata field, only populate it if the query explicitly and unambiguously mentions a specific allowed value.
+2. If the query is general, ambiguous, or does not mention a field, set it to None.
+3. Do NOT infer values from common knowledge or context.
+4. For each field, return ONLY the numeric ID(s) from the allowed options below. Do NOT return labels or text.
+5. Output your answer as a JSON object with field names as keys and lists of IDs or None as values.
+
+Example:
+Allowed options:
+- category: {0: "Technology", 1: "Health", 2: "Finance"}
+- region: {0: "Europe", 1: "Asia", 2: "Americas"}
+
+Query: "Show me the latest news in Technology from Asia."
+Output:
+{"category": [0], "region": [1]}
+""".strip()
+
+
+def _self_query(
+    query: str,
+    *,
+    system_prompt: str = SELF_QUERY_PROMPT,
+    config: RAGLiteConfig | None = None,
+) -> MetadataFilter:
+    """Extract metadata filters from a natural language query."""
+    config = config or RAGLiteConfig()
+    # Retrieve the available metadata from the database.
+    metadata_records = _get_database_metadata(config=config)
+    if not metadata_records:
+        return {}
+    # Create a dynamic Pydantic model for the metadata filter.
+    field_ids_mapping: dict[str, dict[int, MetadataValue]] = {}
+    field_definitions: dict[str, Any] = {}
+    field_definitions["system_prompt"] = (ClassVar[str], system_prompt)
+    # Note:
+    # The LLM tends to return escaped Unicode or hexadecimal strings when asked to output
+    # labels directly. By assigning each allowed metadata value a numeric ID and asking
+    # the model to return only IDs, we avoid encoding issues and reliably map results
+    # back to their actual metadata values afterward.
+    for record in metadata_records:
+        field_ids_mapping[record.name] = dict(enumerate(record.values))
+        # Expose the allowed ID-to-value options to the LLM through the field description.
+        description = (
+            "Return ONLY IDs from this set (use IDs, not labels). "
+            f"Allowed options: {field_ids_mapping[record.name]}"
+        )
+        field_definitions[record.name] = (
+            list[int] | None,
+            Field(default=None, description=description),
+        )
+    metadata_filter_model = create_model(
+        "MetadataFilterModel", **field_definitions, __base__=BaseModel
+    )
+    # Extract the metadata filter (as field IDs) with the LLM.
+    try:
+        result = extract_with_llm(
+            return_type=metadata_filter_model,
+            user_prompt=query,
+            config=config,
+            temperature=0,
+        )
+    except ValueError as e:
+        logger.debug("Failed to extract metadata filter: %s", e)
+        return {}
+    else:
+        metadata_filter = result.model_dump(exclude_none=True)
+        # Convert from field IDs to actual metadata values.
+        for field, value_ids in metadata_filter.items():
+            if field in field_ids_mapping:
+                metadata_filter[field] = [
+                    field_ids_mapping[field].get(value_id)
+                    for value_id in value_ids
+                    if value_id in field_ids_mapping[field]
+                ]
+    return metadata_filter
diff --git a/src/raglite/_typing.py b/src/raglite/_typing.py
index d2a07e9e..21fe0333 100644
--- a/src/raglite/_typing.py
+++ b/src/raglite/_typing.py
@@ -24,8 +24,8 @@
 
 DistanceMetric = Literal["cosine", "dot", "l1", "l2"]
 
-MetadataValue = str | int | float | bool | list[str] | list[int] | list[float] | list[bool]
-MetadataFilter = Mapping[str, MetadataValue]
+MetadataValue = str | int | float | bool  # | list[str] | list[int] | list[float] | list[bool]
+MetadataFilter = Mapping[str, list[MetadataValue] | MetadataValue]
 
 FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
 FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]]
diff --git a/tests/test_insert.py b/tests/test_insert.py
index acc9246b..68882837 100644
--- a/tests/test_insert.py
+++ b/tests/test_insert.py
@@ -7,6 +7,7 @@
 from raglite._config import RAGLiteConfig
 from raglite._database import Chunk, Document, create_database_engine
+from raglite._insert import _get_database_metadata
 from raglite._markdown import document_to_markdown
 
 
@@ -43,3 +44,15 @@ def test_insert(raglite_test_config: RAGLiteConfig) -> None:
     doc = document_to_markdown(doc_path)
     doc = doc.replace("\n", "").strip()
     assert restored_document == doc, "Restored document does not match the original input."
+    # Verify that the document metadata matches.
+    metadata = _get_database_metadata(session)
+    assert len(metadata) > 0, "No metadata found for the document"
+    # Check that the metadata values match the original document metadata.
+    for meta in metadata:
+        assert meta.name in document.metadata_, (
+            f"Metadata {meta.name} not found in document metadata"
+        )
+        for value in document.metadata_[meta.name]:
+            assert value in meta.values, (
+                f"Metadata value '{value}' for '{meta.name}' not found in database metadata"
+            )
diff --git a/tests/test_rag.py b/tests/test_rag.py
index 151cd96d..623721f6 100644
--- a/tests/test_rag.py
+++ b/tests/test_rag.py
@@ -60,3 +60,20 @@ def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None:
     # Verify that no RAG context was retrieved.
     assert [message["role"] for message in messages] == ["user", "assistant"]
     assert not chunk_spans
+
+
+def test_retrieve_context_self_query(raglite_test_config: RAGLiteConfig) -> None:
+    """Test retrieve_context with self_query functionality."""
+    from dataclasses import replace
+
+    new_config = replace(raglite_test_config, self_query=True)
+    query = "What does Albert Einstein's paper say about time dilation?"
+    chunk_spans = retrieve_context(query=query, num_chunks=5, config=new_config)
+    assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
+    for chunk_span in chunk_spans:
+        assert chunk_span.document.metadata_.get("type") == ["Paper"], (
+            f"Expected type='Paper', got {chunk_span.document.metadata_.get('type')}"
+        )
+        assert chunk_span.document.metadata_.get("author") == ["Albert Einstein"], (
+            f"Expected author='Albert Einstein', got {chunk_span.document.metadata_.get('author')}"
+        )
diff --git a/tests/test_search.py b/tests/test_search.py
index abf7e817..cadb88bd 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -1,5 +1,7 @@
 """Test RAGLite's search functionality."""
 
+from typing import Any
+
 import pytest
 
 from raglite import (
@@ -12,7 +14,8 @@
     vector_search,
 )
 from raglite._database import Chunk, ChunkSpan
-from raglite._typing import BasicSearchMethod
+from raglite._search import _self_query
+from raglite._typing import BasicSearchMethod, MetadataFilter
 
 
 @pytest.fixture(
@@ -88,7 +91,7 @@ def test_search_metadata_filter(
     """Test searching with metadata filtering that should return results."""
     query = "What does it mean for two events to be simultaneous?"
     num_results = 5
-    metadata_filter = {"type": "Paper", "topic": "Physics"}
+    metadata_filter: MetadataFilter = {"type": "Paper", "topic": "Physics"}
 
     # Verify basic properties
     chunk_ids, scores = search_method(
@@ -104,15 +107,15 @@
     chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
     assert all(isinstance(chunk, Chunk) for chunk in chunks)
     for chunk in chunks:
-        assert chunk.metadata_.get("type") == "Paper", (
+        assert chunk.metadata_.get("type") == ["Paper"], (
             f"Expected type='Paper', got {chunk.metadata_.get('type')}"
         )
-        assert chunk.metadata_.get("topic") == "Physics", (
+        assert chunk.metadata_.get("topic") == ["Physics"], (
             f"Expected topic='Physics', got {chunk.metadata_.get('topic')}"
         )
 
     # Test filtering for a different topic that should return no results
-    metadata_filter_empty = {"type": "Paper", "topic": "Mathematics"}
+    metadata_filter_empty: MetadataFilter = {"type": "Paper", "topic": "Mathematics"}
     chunk_ids_empty, scores_empty = search_method(
         query,
         num_results=num_results,
@@ -122,3 +125,19 @@
     assert len(chunk_ids_empty) == len(scores_empty) == 0, (
         "Expected no results when filtering for Mathematics papers"
     )
+
+
+def test_self_query(raglite_test_config: RAGLiteConfig) -> None:
+    """Test self-query functionality that extracts metadata filters from queries."""
+    # Test 1: Query that should extract "Physics" from topic field
+    query1 = "I want to learn about the topic Physics."
+    expected_topic = ["Physics"]
+    actual_filter1 = _self_query(query1, config=raglite_test_config)
+    assert actual_filter1.get("topic") == expected_topic, (
+        f"Expected topic '{expected_topic}', got {actual_filter1.get('topic')}"
+    )
+    # Test 2: Query with non-existent metadata values should return empty filter
+    query2 = "What is the price of a Bugatti Chiron?"
+    expected_filter2: dict[str, Any] = {}
+    actual_filter2 = _self_query(query2, config=raglite_test_config)
+    assert actual_filter2 == expected_filter2, f"Expected {expected_filter2}, got {actual_filter2}"
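
For reviewers, here is a minimal end-to-end sketch of how the pieces in this diff fit together: extra document keyword arguments become metadata that is aggregated into the new `metadata` table at insert time, and `self_query=True` lets the search path derive a filter from the query via `_self_query`. The sketch assumes only the API exercised in the README and tests above (`RAGLiteConfig`, `Document.from_path`, `insert_documents`, `retrieve_context`); the file path, model names, and the exact filter the LLM returns are illustrative, not guaranteed.

```python
# Illustrative sketch only: it mirrors the API used in the README and tests above;
# the PDF path and the exact filter extracted by the LLM are hypothetical.
from pathlib import Path

from raglite import Document, RAGLiteConfig, insert_documents, retrieve_context

config = RAGLiteConfig(
    db_url="duckdb:///raglite.db",
    llm="gpt-4o-mini",
    embedder="text-embedding-3-large",
    self_query=True,  # Let the LLM derive metadata filters from the query.
)

# Extra keyword arguments become document metadata. At insert time, their values are
# normalized to lists and aggregated into the new `metadata` table, which defines the
# allowed values that self-query may return for each field.
document = Document.from_path(
    Path("special_relativity.pdf"),  # Hypothetical file.
    author="Albert Einstein",
    type="Paper",
)
insert_documents([document], config=config)

# With self_query enabled, vector and keyword search call _self_query() to map the query
# onto the allowed metadata values (e.g. {"author": ["Albert Einstein"], "type": ["Paper"]})
# and merge the result with any explicit metadata_filter before retrieving chunks.
chunk_spans = retrieve_context(
    query="What does Albert Einstein's paper say about time dilation?",
    num_chunks=5,
    config=config,
)
```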