Skip to content
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
11a3850
feat: add self-query functionality
jirastorza Oct 9, 2025
f954bca
fix: modified self_query_prompt
jirastorza Oct 10, 2025
f0e66da
fix: modified self_query_prompt
jirastorza Oct 10, 2025
2e6c436
fix: code simplification
jirastorza Oct 14, 2025
3507ad5
fix: test rag
jirastorza Oct 14, 2025
238d3a1
fix: add self_query option to config and update tool calling logic.
jirastorza Oct 14, 2025
b8055da
fix: correct logger
jirastorza Oct 15, 2025
9e32790
fix: linting
jirastorza Oct 15, 2025
c8e4fa9
fix: simplify rag test.
jirastorza Oct 15, 2025
ff97cd2
fix: remove repetitive self_query call.
jirastorza Oct 15, 2025
e12ed5b
fix: move self_query to _search.py
jirastorza Oct 16, 2025
752ea2b
fix: modify test structure.
jirastorza Oct 16, 2025
b0b46a6
fix: allow list metadata values.
jirastorza Oct 16, 2025
5d575e9
fix: allow list type metadata handling.
jirastorza Oct 16, 2025
b32f070
fix: reduce MetadataValues to hashable types, modify document metadat…
jirastorza Oct 17, 2025
f937fe6
fix: adapt test.
jirastorza Oct 17, 2025
f68d1c7
fix: adapt test case to changes.
jirastorza Oct 17, 2025
ecbcae2
fix: additional test fix.
jirastorza Oct 17, 2025
fb5a01b
fix: database chunk and document metadata.
jirastorza Oct 17, 2025
15a6000
fix: update README.
jirastorza Oct 22, 2025
f20c512
Merge remote-tracking branch 'origin/main' into self-query
jirastorza Oct 28, 2025
1e10550
fix: ensure metadata is stored as proper JSON without escape characters
jirastorza Oct 29, 2025
723931d
fix: handle hex byte escape sequences in metadata filter values
jirastorza Oct 29, 2025
775cae3
fix: sanitize LLM metadata output to remove NULs and decode escaped c…
jirastorza Oct 30, 2025
1e3cb2d
docs: clarify comment explaining why LLM output is cleaned after extr…
jirastorza Oct 30, 2025
ed8558e
fix: remove metadata filter decoding
jirastorza Oct 30, 2025
5586c08
fix: decode escaped Unicode sequences in metadata_filter
jirastorza Oct 30, 2025
f83e57a
fix: encode query with ensure_ascii for consistent Unicode handling i…
jirastorza Oct 30, 2025
dc4e62a
feat: use ID-based metadata mapping for more reliable self-query extr…
jirastorza Oct 30, 2025
1e2c9a0
feat: use ID-based metadata mapping for more reliable self-query extr…
jirastorza Oct 30, 2025
f8225f5
fix: update self_query template for small model extraction
jirastorza Oct 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/raglite/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,4 @@ class RAGLiteConfig:
# Search config: you can pick any search method that returns (list[ChunkId], list[float]),
# list[Chunk], or list[ChunkSpan].
search_method: SearchMethod = field(default=_vector_search, compare=False)
self_query: bool = False
32 changes: 28 additions & 4 deletions src/raglite/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,20 @@
FloatMatrix,
FloatVector,
IndexId,
MetadataValue,
PickledObject,
)

MetadataJSON = JSON().with_variant(JSONB(), "postgresql")


def _adapt_metadata(metadata: Any) -> dict[str, MetadataValue | list[MetadataValue]]:
"""Adapt metadata to the format expected by the database."""
if not metadata:
return {}
return {k: v if isinstance(v, list) else [v] for k, v in metadata.items()}


def hash_bytes(data: bytes, max_len: int = 16) -> str:
    """Return the first `max_len` hexadecimal characters of the SHA-256 digest of `data`."""
    # The hash is flagged as not used for security purposes.
    digest = sha256(data, usedforsecurity=False)
    hex_digest = digest.hexdigest()
    return hex_digest[:max_len]
Expand Down Expand Up @@ -112,7 +120,6 @@ def from_path(
Document
A document.
"""
# Extract metadata.
metadata = {
"filename": doc_path.name,
"uri": id,
Expand All @@ -122,6 +129,8 @@ def from_path(
"modified": doc_path.stat().st_mtime,
**kwargs,
}
# Ensure all metadata values are lists
metadata = _adapt_metadata(metadata)
# Create the document instance.
return Document(
id=id if id is not None else hash_bytes(doc_path.read_bytes()),
Expand Down Expand Up @@ -164,14 +173,16 @@ def from_text(
first_line = content.strip().split("\n", 1)[0].strip()
if len(first_line) > 80: # noqa: PLR2004
first_line = f"{first_line[:80]}..."
# Extract metadata.

metadata = {
"filename": filename or first_line,
"uri": id,
"url": url,
"size": len(content.encode()),
**kwargs,
}
# Ensure all metadata values are lists
metadata = _adapt_metadata(metadata)
# Create the document instance.
return Document(
id=id if id is not None else hash_bytes(content.encode()),
Expand Down Expand Up @@ -213,7 +224,9 @@ def from_body(
index=index,
headings=Chunk.truncate_headings(headings, body),
body=body,
metadata_={"filename": document.filename, "url": document.url, **kwargs},
metadata_=_adapt_metadata(
{"filename": document.filename, "url": document.url, **kwargs}
),
)

@staticmethod
Expand Down Expand Up @@ -438,6 +451,15 @@ def get(id_: str = "default", *, config: RAGLiteConfig | None = None) -> dict[st
return metadata


class Metadata(SQLModel, table=True):
    """Table holding, for each metadata field name, the list of values recorded for it."""

    __tablename__ = "metadata"

    # Name of the metadata field; it is the primary key, so each field has exactly one row.
    name: str = Field(..., primary_key=True)
    # Values recorded for this field, persisted in a single JSON column.
    # NOTE(review): this uses plain JSON, whereas this module also defines MetadataJSON (JSONB on
    # PostgreSQL) -- confirm whether the difference is intentional.
    values: list[MetadataValue] = Field(
        default_factory=list,
        sa_column=Column(JSON),
    )


class Eval(SQLModel, table=True):
"""A RAG evaluation example."""

Expand Down Expand Up @@ -467,14 +489,16 @@ def from_chunks(
"""Create a chunk from Markdown."""
document_id = contexts[0].document_id
chunk_ids = [context.id for context in contexts]
# Ensure all metadata values from kwargs are lists
processed_kwargs = _adapt_metadata(kwargs)
return Eval(
id=hash_bytes(f"{document_id}-{chunk_ids}-{question}".encode()),
document_id=document_id,
chunk_ids=chunk_ids,
question=question,
contexts=[str(context) for context in contexts],
ground_truth=ground_truth,
metadata_=kwargs,
metadata_=processed_kwargs,
)


Expand Down
69 changes: 68 additions & 1 deletion src/raglite/_insert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Index documents."""

from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import nullcontext
from functools import partial
Expand All @@ -8,15 +9,79 @@
from filelock import FileLock
from sqlalchemy import text
from sqlalchemy.engine import make_url
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Session, col, select
from tqdm.auto import tqdm

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, ChunkEmbedding, Document, create_database_engine
from raglite._database import (
Chunk,
ChunkEmbedding,
Document,
Metadata,
create_database_engine,
)
from raglite._embed import embed_strings, embed_strings_without_late_chunking, embedding_type
from raglite._split_chunklets import split_chunklets
from raglite._split_chunks import split_chunks
from raglite._split_sentences import split_sentences
from raglite._typing import MetadataValue

METADATA_EXCLUDED_FIELDS = ["filename", "uri", "url", "size", "created", "modified"]


def _get_database_metadata(
    session: Session | None = None,
    config: RAGLiteConfig | None = None,
) -> Sequence[Metadata]:
    """Fetch all metadata records from the database.

    Parameters
    ----------
    session
        An open session to run the query on. When None, a temporary session is created from
        `config` and closed again before returning.
    config
        The configuration used to create the database engine when no session is supplied.
        Defaults to `RAGLiteConfig()`. Ignored when `session` is given.

    Returns
    -------
    Sequence[Metadata]
        All rows of the metadata table.
    """
    statement = select(Metadata)
    # Test for None explicitly: reusing the caller's session must depend on whether one was
    # passed, not on the truthiness of the session object.
    if session is not None:
        return session.exec(statement).all()
    with Session(create_database_engine(config or RAGLiteConfig())) as new_session:
        return new_session.exec(statement).all()


def _aggregate_metadata_from_documents(
    documents: list[Document],
    metadata_excluded_fields: list[str] = METADATA_EXCLUDED_FIELDS,
) -> dict[str, set[MetadataValue]]:
    """Collect, per metadata field, the distinct values found across `documents`.

    Parameters
    ----------
    documents
        The documents whose `metadata_` mappings are scanned.
    metadata_excluded_fields
        Field names to leave out of the result.

    Returns
    -------
    dict[str, set[MetadataValue]]
        A mapping from each non-excluded field name to the set of values seen for it.
    """
    aggregated: dict[str, set[MetadataValue]] = {}
    for document in documents:
        for field_name, field_value in document.metadata_.items():
            if field_name in metadata_excluded_fields:
                continue
            bucket = aggregated.setdefault(field_name, set())
            # A list contributes each of its items; any other value counts as a single item.
            if isinstance(field_value, list):
                bucket.update(field_value)
            else:
                bucket.add(field_value)
    return aggregated


def _update_metadata_from_documents(
    session: Session,
    documents: list[Document],
) -> None:
    """Merge the metadata values of `documents` into the metadata table.

    For every non-excluded metadata field found in the documents, values that are not stored yet
    are appended to the field's existing record; a new record is created for fields that have
    none. Records are added to `session`, but nothing is committed here.

    Parameters
    ----------
    session
        The session used to read the existing records and to stage the changes.
    documents
        The documents whose metadata should be reflected in the metadata table.
    """
    if not documents:
        return
    metadata = _aggregate_metadata_from_documents(documents=documents)
    if not metadata:
        # No field to record, so there is no need to query the existing records.
        return

    def _order_key(value: Any) -> tuple[str, str]:
        # Set iteration order can differ between runs (string hashing is randomised). Sorting by
        # type name and repr gives a reproducible order that also works for mixed-type values.
        return (type(value).__name__, repr(value))

    existing_metadata = {record.name: record for record in _get_database_metadata(session=session)}
    for key, values in metadata.items():
        record = existing_metadata.get(key)
        if record is None:
            # Add a record for a field that has not been seen before.
            session.add(Metadata(name=key, values=sorted(values, key=_order_key)))
            continue
        # Update an existing record with the values it does not contain yet.
        values_to_add = set(values) - set(record.values)
        if values_to_add:
            record.values.extend(sorted(values_to_add, key=_order_key))
            flag_modified(record, "values")  # Notify SQLAlchemy of the in-place change.
            session.add(record)


def _create_chunk_records(
Expand Down Expand Up @@ -171,6 +236,8 @@ def insert_documents( # noqa: C901
session.expunge_all() # Release memory of flushed changes.
num_unflushed_embeddings = 0
pbar.update()
# Update metadata table.
_update_metadata_from_documents(documents=documents, session=session)
session.commit()
if engine.dialect.name == "duckdb":
# DuckDB does not automatically update its keyword search index [1], so we do it
Expand Down
73 changes: 71 additions & 2 deletions src/raglite/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

import contextlib
import json
import logging
import re
import string
from collections import defaultdict
from itertools import groupby
from typing import Any
from typing import Any, ClassVar

import numpy as np
from langdetect import LangDetectException, detect
from pydantic import BaseModel, Field, create_model
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import joinedload
from sqlmodel import Session, and_, col, func, or_, select, text
Expand All @@ -20,10 +22,15 @@
ChunkEmbedding,
ChunkSpan,
IndexMetadata,
_adapt_metadata,
create_database_engine,
)
from raglite._embed import embed_strings
from raglite._typing import BasicSearchMethod, ChunkId, FloatVector, MetadataFilter
from raglite._extract import extract_with_llm
from raglite._insert import _get_database_metadata
from raglite._typing import BasicSearchMethod, ChunkId, FloatVector, MetadataFilter, MetadataValue

logger = logging.getLogger(__name__)


def vector_search(
Expand All @@ -37,6 +44,12 @@ def vector_search(
"""Search chunks using ANN vector search."""
# Read the config.
config = config or RAGLiteConfig()
# Normalize metadata filter values to lists.
metadata_filter = _adapt_metadata(metadata_filter)
# If self_query is enabled, extract metadata filters from the query.
if config.self_query and isinstance(query, str):
self_query_filter = _self_query(query, config=config)
metadata_filter = {**self_query_filter, **(metadata_filter or {})}
# Embed the query.
query_embedding = (
embed_strings([query], config=config)[0, :] if isinstance(query, str) else np.ravel(query)
Expand Down Expand Up @@ -150,6 +163,12 @@ def keyword_search(
"""Search chunks using BM25 keyword search."""
# Read the config.
config = config or RAGLiteConfig()
# Normalize metadata filter values to lists.
metadata_filter = _adapt_metadata(metadata_filter)
# If self_query is enabled, extract metadata filters from the query.
if config.self_query and isinstance(query, str):
self_query_filter = _self_query(query, config=config)
metadata_filter = {**self_query_filter, **(metadata_filter or {})}
# Connect to the database.
with Session(create_database_engine(config)) as session:
dialect = session.get_bind().dialect.name
Expand Down Expand Up @@ -412,3 +431,53 @@ def search_and_rerank_chunk_spans( # noqa: PLR0913
chunks = rerank_chunks(query, chunk_ids, config=config)[:num_results]
chunk_spans = retrieve_chunk_spans(chunks, neighbors=neighbors, config=config)
return chunk_spans


# System prompt for self-querying: tells the LLM to fill in a metadata field only when the query
# explicitly mentions one of that field's allowed values, and to leave the field unset otherwise.
SELF_QUERY_PROMPT = """
You extract metadata filters from user queries to help search a knowledge base.

Rules:
- Only populate a field when the query explicitly and unambiguously mentions a specific allowed value for that field
- If the query is general, ambiguous, or doesn't mention a field, leave it as None
- Do not infer values from common knowledge, popularity, or context from other fields
""".strip()


def _self_query(
    query: str,
    *,
    system_prompt: str = SELF_QUERY_PROMPT,
    config: RAGLiteConfig | None = None,
) -> MetadataFilter:
    """Extract metadata filters from a natural language query.

    A Pydantic model with one optional list field per stored metadata name is built on the fly,
    and the LLM is asked to populate it from the query. Only values that are actually stored for
    a field are kept, so a value invented by the LLM cannot end up in the filter.

    Parameters
    ----------
    query
        The user query to extract metadata filters from.
    system_prompt
        The prompt attached to the dynamic model as its `system_prompt` class variable
        (presumably read by `extract_with_llm` -- confirm there).
    config
        The RAGLite configuration. Defaults to `RAGLiteConfig()`.

    Returns
    -------
    MetadataFilter
        A mapping from metadata field name to the non-empty list of stored values extracted
        for it. Empty when no metadata is stored, when extraction raises a ValueError, or when
        no stored value was extracted.
    """
    config = config or RAGLiteConfig()
    # Retrieve the available metadata from the database.
    metadata_records = _get_database_metadata(config=config)
    if not metadata_records:
        return {}
    # Create a dynamic Pydantic model for the metadata filter.
    field_definitions: dict[str, Any] = {"system_prompt": (ClassVar[str], system_prompt)}
    allowed_values: dict[str, list[MetadataValue]] = {}
    for record in metadata_records:
        allowed_values[record.name] = list(record.values or [])
        description = f"Allowed values are: {json.dumps(record.values)}"
        field_definitions[record.name] = (
            list[MetadataValue] | None,
            Field(default=None, description=description),
        )
    metadata_filter_model = create_model(
        "MetadataFilterModel", **field_definitions, __base__=BaseModel
    )
    # Ask the LLM to fill in the model; a failed extraction simply means "no filter".
    try:
        result = extract_with_llm(
            return_type=metadata_filter_model,
            user_prompt=query,
            config=config,
            temperature=0,
        )
    except ValueError as e:
        logger.debug("Failed to extract metadata filter: %s", e)
        return {}
    # The allowed values are only a hint in the field descriptions, so enforce them here: drop
    # values that are not stored for the field, and drop fields left without any value.
    metadata_filter: dict[str, list[MetadataValue]] = {}
    for name, extracted_values in result.model_dump(exclude_none=True).items():
        known_values = [
            value for value in extracted_values if value in allowed_values.get(name, [])
        ]
        if known_values:
            metadata_filter[name] = known_values
    return metadata_filter
4 changes: 2 additions & 2 deletions src/raglite/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@

DistanceMetric = Literal["cosine", "dot", "l1", "l2"]

MetadataValue = str | int | float | bool | list[str] | list[int] | list[float] | list[bool]
MetadataFilter = Mapping[str, MetadataValue]
MetadataValue = str | int | float | bool # | list[str] | list[int] | list[float] | list[bool]
MetadataFilter = Mapping[str, list[MetadataValue] | MetadataValue]

FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]]
Expand Down
13 changes: 13 additions & 0 deletions tests/test_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, Document, create_database_engine
from raglite._insert import _get_database_metadata
from raglite._markdown import document_to_markdown


Expand Down Expand Up @@ -43,3 +44,15 @@ def test_insert(raglite_test_config: RAGLiteConfig) -> None:
doc = document_to_markdown(doc_path)
doc = doc.replace("\n", "").strip()
assert restored_document == doc, "Restored document does not match the original input."
# Verify that the document metadata matches.
metadata = _get_database_metadata(session)
assert len(metadata) > 0, "No metadata found for the document"
# Check that the metadata values match the original document metadata.
for meta in metadata:
assert meta.name in document.metadata_, (
f"Metadata {meta.name} not found in document metadata"
)
for value in document.metadata_[meta.name]:
assert value in meta.values, (
f"Metadata value '{value}' for '{meta.name}' not found in database metadata"
)
17 changes: 17 additions & 0 deletions tests/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,20 @@ def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None:
# Verify that no RAG context was retrieved.
assert [message["role"] for message in messages] == ["user", "assistant"]
assert not chunk_spans


def test_retrieve_context_self_query(raglite_test_config: RAGLiteConfig) -> None:
    """Test retrieve_context with self_query functionality."""
    from dataclasses import replace

    # Enable self-querying on a copy of the shared test config.
    new_config = replace(raglite_test_config, self_query=True)
    query = "What does Albert Einstein's paper say about time dilation?"
    chunk_spans = retrieve_context(query=query, num_chunks=5, config=new_config)
    # Without this check the assertions below would pass vacuously on an empty result.
    assert chunk_spans, "Expected at least one chunk span to be retrieved"
    assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
    # Every retrieved chunk span must come from a document matching the extracted filters.
    for chunk_span in chunk_spans:
        assert chunk_span.document.metadata_.get("type") == ["Paper"], (
            f"Expected type='Paper', got {chunk_span.document.metadata_.get('type')}"
        )
        assert chunk_span.document.metadata_.get("author") == ["Albert Einstein"], (
            f"Expected author='Albert Einstein', got {chunk_span.document.metadata_.get('author')}"
        )
Loading