Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
11a3850
feat: add self-query functionality
jirastorza Oct 9, 2025
f954bca
fix: modified self_query_prompt
jirastorza Oct 10, 2025
f0e66da
fix: modified self_query_prompt
jirastorza Oct 10, 2025
2e6c436
fix: code simplification
jirastorza Oct 14, 2025
3507ad5
fix: test rag
jirastorza Oct 14, 2025
238d3a1
fix: add self_query option to config and update tool calling logic.
jirastorza Oct 14, 2025
b8055da
fix: corret logger
jirastorza Oct 15, 2025
9e32790
fix: linting
jirastorza Oct 15, 2025
c8e4fa9
fix: simplify rag test.
jirastorza Oct 15, 2025
ff97cd2
fix: remove repetitive self_query call.
jirastorza Oct 15, 2025
e12ed5b
fix: move self_query to _search.py
jirastorza Oct 16, 2025
752ea2b
fix: modify test structure.
jirastorza Oct 16, 2025
b0b46a6
fix: allow list metadata values.
jirastorza Oct 16, 2025
5d575e9
fix: allow list type metadata handling.
jirastorza Oct 16, 2025
b32f070
fix: reduce MetadataValues to hashable types, modify document metadat…
jirastorza Oct 17, 2025
f937fe6
fix: adapt test.
jirastorza Oct 17, 2025
f68d1c7
fix: adapt test case to changes.
jirastorza Oct 17, 2025
ecbcae2
fix: additional test fix.
jirastorza Oct 17, 2025
fb5a01b
fix: database chunk and document metadata.
jirastorza Oct 17, 2025
15a6000
fix: update README.
jirastorza Oct 22, 2025
f20c512
Merge remote-tracking branch 'origin/main' into self-query
jirastorza Oct 28, 2025
1e10550
fix: ensure metadata is stored as proper JSON without escape characters
jirastorza Oct 29, 2025
723931d
fix: handle hex byte escape sequences in metadata filter values
jirastorza Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/raglite/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,4 @@ class RAGLiteConfig:
# Search config: you can pick any search method that returns (list[ChunkId], list[float]),
# list[Chunk], or list[ChunkSpan].
search_method: SearchMethod = field(default=_vector_search, compare=False)
self_query: bool = False
10 changes: 10 additions & 0 deletions src/raglite/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
FloatMatrix,
FloatVector,
IndexId,
MetadataValue,
PickledObject,
)

Expand Down Expand Up @@ -438,6 +439,15 @@ def get(id_: str = "default", *, config: RAGLiteConfig | None = None) -> dict[st
return metadata


class Metadata(SQLModel, table=True):
"""A table for metadata values, linked to field names."""

__tablename__ = "metadata"

name: str = Field(..., primary_key=True)
values: list[MetadataValue] = Field(default_factory=list, sa_column=Column(JSON))


class Eval(SQLModel, table=True):
"""A RAG evaluation example."""

Expand Down
60 changes: 59 additions & 1 deletion src/raglite/_insert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Index documents."""

from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import nullcontext
from functools import partial
Expand All @@ -8,15 +9,70 @@
from filelock import FileLock
from sqlalchemy import text
from sqlalchemy.engine import make_url
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Session, col, select
from tqdm.auto import tqdm

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, ChunkEmbedding, Document, create_database_engine
from raglite._database import Chunk, ChunkEmbedding, Document, Metadata, create_database_engine
from raglite._embed import embed_strings, embed_strings_without_late_chunking, embedding_type
from raglite._split_chunklets import split_chunklets
from raglite._split_chunks import split_chunks
from raglite._split_sentences import split_sentences
from raglite._typing import MetadataValue

METADATA_EXCLUDED_FIELDS = ["filename", "uri", "url", "size", "created", "modified"]


def _get_database_metadata(
session: Session | None = None,
config: RAGLiteConfig | None = None,
) -> Sequence[Metadata]:
"""Fetch all metadata records from the database."""
if session:
return session.exec(select(Metadata)).all()
with Session(create_database_engine(config or RAGLiteConfig())) as new_session:
return new_session.exec(select(Metadata)).all()


def _aggregate_metadata_from_documents(
documents: list[Document],
metadata_excluded_fields: list[str] = METADATA_EXCLUDED_FIELDS,
) -> dict[str, set[MetadataValue]]:
"""Aggregate metadata values from all documents."""
metadata: dict[str, set[MetadataValue]] = {}
for doc in documents:
for key, value in doc.metadata_.items():
if key in metadata_excluded_fields:
continue
if key not in metadata:
metadata[key] = set()
metadata[key].add(value)
return metadata


def _update_metadata_from_documents(
session: Session,
documents: list[Document],
) -> None:
"""Update or add metadata records."""
if not documents:
return
metadata = _aggregate_metadata_from_documents(documents=documents)
existing_metadata = {record.name: record for record in _get_database_metadata(session=session)}
# Update or add metadata records.
for key, values in metadata.items():
# Update
if key in existing_metadata:
result = existing_metadata[key]
values_to_add = set(values) - set(result.values)
if values_to_add:
result.values.extend(values_to_add)
flag_modified(result, "values") # Notify SQLAlchemy of the change
session.add(result)
# Add
else:
session.add(Metadata(name=key, values=list(values)))


def _create_chunk_records(
Expand Down Expand Up @@ -171,6 +227,8 @@ def insert_documents( # noqa: C901
session.expunge_all() # Release memory of flushed changes.
num_unflushed_embeddings = 0
pbar.update()
# Update metadata table.
_update_metadata_from_documents(documents=documents, session=session)
session.commit()
if engine.dialect.name == "duckdb":
# DuckDB does not automatically update its keyword search index [1], so we do it
Expand Down
57 changes: 56 additions & 1 deletion src/raglite/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import json
from collections.abc import AsyncIterator, Callable, Iterator
from typing import Any
from typing import Any, ClassVar, Literal
from venv import logger

import numpy as np
from litellm import ( # type: ignore[attr-defined]
Expand All @@ -12,9 +13,12 @@
stream_chunk_builder,
supports_function_calling,
)
from pydantic import BaseModel, create_model

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, ChunkSpan
from raglite._extract import extract_with_llm
from raglite._insert import _get_database_metadata
from raglite._litellm import get_context_size
from raglite._search import retrieve_chunk_spans
from raglite._typing import MetadataFilter
Expand All @@ -36,6 +40,15 @@
{user_prompt}
""".strip()

SELF_QUERY_PROMPT = """
You extract metadata filters from user queries to help search a knowledge base.

Rules:
- Only populate a field when the query explicitly and unambiguously mentions a specific allowed value for that field
- If the query is general, ambiguous, or doesn't mention a field, leave it as None
- Do not infer values from common knowledge, popularity, or context from other fields
""".strip()


def retrieve_context(
query: str,
Expand All @@ -47,6 +60,10 @@ def retrieve_context(
"""Retrieve context for RAG."""
# Call the search method.
config = config or RAGLiteConfig()
# If self_query is enabled, extract metadata filters from the query.
if config.self_query:
self_query_filter = _self_query(query, config=config)
metadata_filter = {**self_query_filter, **(metadata_filter or {})}
results = config.search_method(
query, num_results=num_chunks, metadata_filter=metadata_filter, config=config
)
Expand Down Expand Up @@ -152,6 +169,8 @@ def _run_tools(
if tool_call.function.name == "search_knowledge_base":
kwargs = json.loads(tool_call.function.arguments)
kwargs["config"] = config
if config.self_query:
kwargs["metadata_filter"] = _self_query(**kwargs)
chunk_spans = retrieve_context(**kwargs)
tool_messages.append(
{
Expand All @@ -173,6 +192,42 @@ def _run_tools(
return tool_messages


def _self_query(
query: str,
*,
system_prompt: str = SELF_QUERY_PROMPT,
config: RAGLiteConfig | None = None,
) -> MetadataFilter:
"""Extract metadata filters from a natural language query."""
config = config or RAGLiteConfig()
# Retrieve the available metadata from the database.
metadata_records = _get_database_metadata(config=config)
if not metadata_records:
return {}
# Create dynamic Pydantic model for the metadata filter
field_definitions: dict[str, Any] = {}
field_definitions["system_prompt"] = (ClassVar[str], system_prompt)
for record in metadata_records:
field_definitions[record.name] = (Literal[tuple(record.values)] | None, None)
metadata_filter_model = create_model(
"MetadataFilterModel", **field_definitions, __base__=BaseModel
)
# Call extract_with_llm
try:
result = extract_with_llm(
return_type=metadata_filter_model,
user_prompt=query,
config=config,
temperature=0,
)
except ValueError as e:
logger.debug(f"Failed to extract metadata filter: {e}")
return {}
else:
metadata_filter = result.model_dump()
return {k: v for k, v in metadata_filter.items() if v is not None}


def rag(
messages: list[dict[str, str]],
*,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, Document, create_database_engine
from raglite._insert import _get_database_metadata
from raglite._markdown import document_to_markdown


Expand Down Expand Up @@ -43,3 +44,14 @@ def test_insert(raglite_test_config: RAGLiteConfig) -> None:
doc = document_to_markdown(doc_path)
doc = doc.replace("\n", "").strip()
assert restored_document == doc, "Restored document does not match the original input."
# Verify that the document metadata matches.
metadata = _get_database_metadata(session)
assert len(metadata) > 0, "No metadata found for the document"
# Check that the metadata values match the original document metadata.
for meta in metadata:
assert meta.name in document.metadata_, (
f"Metadata {meta.name} not found in document metadata"
)
assert document.metadata_[meta.name] in meta.values, (
f"Metadata value {document.metadata_[meta.name]} for {meta.name} not found in metadata values {meta.values}"
)
34 changes: 33 additions & 1 deletion tests/test_rag.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Test RAGLite's RAG functionality."""

import json
from typing import Any

from raglite import (
RAGLiteConfig,
add_context,
retrieve_context,
)
from raglite._database import ChunkSpan
from raglite._rag import rag
from raglite._rag import _self_query, rag


def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None:
Expand Down Expand Up @@ -60,3 +61,34 @@ def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None:
# Verify that no RAG context was retrieved.
assert [message["role"] for message in messages] == ["user", "assistant"]
assert not chunk_spans


def test_self_query(raglite_test_config: RAGLiteConfig) -> None:
"""Test self-query functionality that extracts metadata filters from queries."""
# Test 1: Query that should extract "Physics" from topic field
query1 = "Retrieve Physics papers."
expected_filter1 = {"topic": "Physics", "type": "Paper"}
actual_filter1 = _self_query(query1, config=raglite_test_config)
assert actual_filter1 == expected_filter1, f"Expected {expected_filter1}, got {actual_filter1}"
# Test 2: Query with non-existent metadata values should return empty filter
query2 = "What is the price of a Bugatti Chiron?"
expected_filter2: dict[str, Any] = {}
actual_filter2 = _self_query(query2, config=raglite_test_config)
assert actual_filter2 == expected_filter2, f"Expected {expected_filter2}, got {actual_filter2}"


def test_retrieve_context_self_query(raglite_test_config: RAGLiteConfig) -> None:
"""Test retrieve_context with self_query functionality."""
from dataclasses import replace

new_config = replace(raglite_test_config, self_query=True)
query = "What does Albert Einstein's paper say about time dilation?"
chunk_spans = retrieve_context(query=query, num_chunks=5, config=new_config)
assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
for chunk_span in chunk_spans:
assert chunk_span.document.metadata_.get("type") == "Paper", (
f"Expected type='Paper', got {chunk_span.document.metadata_.get('type')}"
)
assert chunk_span.document.metadata_.get("author") == "Albert Einstein", (
f"Expected author='Albert Einstein', got {chunk_span.document.metadata_.get('author')}"
)