Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
c1a2e78
feat: change sentence splitter to sat-1l-sm
jirastorza Oct 15, 2025
34720cc
fix: _clip function, so context window is never exceeded.
jirastorza Oct 21, 2025
b280285
Merge remote-tracking branch 'origin/main' into fix/context_window_ex…
jirastorza Oct 21, 2025
6a39e23
fix: set sat-3l-sm as sentence splitter by default.
jirastorza Oct 21, 2025
c843600
fix: change warning for logging to avoid test case fail.
jirastorza Oct 21, 2025
f2b52e0
fix: adapt retrieve_context to avoid context exceeding.
jirastorza Oct 21, 2025
d795ae0
fix: proportional tool call available tokens.
jirastorza Oct 23, 2025
a304a0b
fix: add a small buffer to each tool call token limit..
jirastorza Oct 23, 2025
b24b4b1
fix: check in _clip that the last user query is not being clipped.
jirastorza Oct 23, 2025
26f3d30
fix: modify _clip, if user query fits, include it.
jirastorza Oct 23, 2025
7665d6b
fix: limit tool tokens proportionally to the toekns of their retrieve…
jirastorza Oct 24, 2025
1916957
fix: limit tool tokens proportionally to the toekns of their retrieve…
jirastorza Oct 24, 2025
7ac6b9c
fix: change chunkspan token count logic.
jirastorza Oct 27, 2025
d5ed59c
fix: increase CONTEXT_BUFFER.
jirastorza Oct 27, 2025
106ee51
fix: calculate buffer for _limit_chunk_spans
jirastorza Oct 27, 2025
ec3467b
fix: change README and add_context calls.
jirastorza Oct 27, 2025
13a41da
fix: add tool_calls to context buffer.
jirastorza Oct 27, 2025
70c5bbe
fix: manual rag test returns empty retrieved chunks due to exceeded c…
jirastorza Oct 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/raglite/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,4 @@ class RAGLiteConfig:
# Search config: you can pick any search method that returns (list[ChunkId], list[float]),
# list[Chunk], or list[ChunkSpan].
search_method: SearchMethod = field(default=_vector_search, compare=False)
_num_queries: int = 1
29 changes: 29 additions & 0 deletions src/raglite/_rag.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Retrieval-augmented generation."""

import json
import logging
import warnings
from collections.abc import AsyncIterator, Callable, Iterator
from dataclasses import replace
from typing import Any

import numpy as np
Expand All @@ -19,6 +22,8 @@
from raglite._search import retrieve_chunk_spans
from raglite._typing import MetadataFilter

logger = logging.getLogger(__name__)

# The default RAG instruction template follows Anthropic's best practices [1].
# [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
RAG_INSTRUCTION_TEMPLATE = """
Expand Down Expand Up @@ -58,9 +63,23 @@ def retrieve_context(
chunk_spans = retrieve_chunk_spans(results, config=config) # type: ignore[arg-type]
elif all(isinstance(result, ChunkSpan) for result in results):
chunk_spans = results # type: ignore[assignment]
chunk_spans = limit_chunkspans(chunk_spans, config)
return chunk_spans


def limit_chunkspans(chunk_spans: list[ChunkSpan], config: RAGLiteConfig) -> list[ChunkSpan]:
max_tokens = get_context_size(config) // config._num_queries - 300 // config._num_queries # noqa: SLF001
cum_tokens = np.cumsum([len(chunk_span.to_json()) // 3 for chunk_span in chunk_spans])
first_chunk = np.searchsorted(cum_tokens, max_tokens)
if first_chunk < len(chunk_spans):
logger.warning(
"Retrieved chunks exceed context window. "
"Truncating to %d chunk(s). Consider reducing the number of retrieved chunks or using a model with bigger context window.",
first_chunk // len(chunk_spans),
)
return chunk_spans[:first_chunk]


def add_context(
user_prompt: str,
context: list[ChunkSpan],
Expand Down Expand Up @@ -89,6 +108,15 @@ def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str
"""Left clip a messages array to avoid hitting the context limit."""
cum_tokens = np.cumsum([len(message.get("content") or "") // 3 for message in messages][::-1])
first_message = -np.searchsorted(cum_tokens, max_tokens)
if first_message == 0 and cum_tokens[-1] > max_tokens:
warnings.warn(
(
f"Context window of {max_tokens} tokens exceeded even after clipping all previous messages."
"Consider using a model with a bigger context window or reducing the number of retrieved chunks."
),
stacklevel=2,
)
return []
return messages[first_message:]


Expand Down Expand Up @@ -147,6 +175,7 @@ def _run_tools(
config: RAGLiteConfig,
) -> list[dict[str, Any]]:
"""Run tools to search the knowledge base for RAG context."""
config = replace(config, _num_queries=len(tool_calls))
tool_messages: list[dict[str, Any]] = []
for tool_call in tool_calls:
if tool_call.function.name == "search_knowledge_base":
Expand Down
Loading