Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c1a2e78
feat: change sentence splitter to sat-1l-sm
jirastorza Oct 15, 2025
34720cc
fix: _clip function, so context window is never exceeded.
jirastorza Oct 21, 2025
b280285
Merge remote-tracking branch 'origin/main' into fix/context_window_ex…
jirastorza Oct 21, 2025
6a39e23
fix: set sat-3l-sm as sentence splitter by default.
jirastorza Oct 21, 2025
c843600
fix: change warning for logging to avoid test case fail.
jirastorza Oct 21, 2025
f2b52e0
fix: adapt retrieve_context to avoid context exceeding.
jirastorza Oct 21, 2025
d795ae0
fix: proportional tool call available tokens.
jirastorza Oct 23, 2025
a304a0b
fix: add a small buffer to each tool call token limit..
jirastorza Oct 23, 2025
b24b4b1
fix: check in _clip that the last user query is not being clipped.
jirastorza Oct 23, 2025
26f3d30
fix: modify _clip, if user query fits, include it.
jirastorza Oct 23, 2025
7665d6b
fix: limit tool tokens proportionally to the toekns of their retrieve…
jirastorza Oct 24, 2025
1916957
fix: limit tool tokens proportionally to the toekns of their retrieve…
jirastorza Oct 24, 2025
7ac6b9c
fix: change chunkspan token count logic.
jirastorza Oct 27, 2025
d5ed59c
fix: increase CONTEXT_BUFFER.
jirastorza Oct 27, 2025
106ee51
fix: calculate buffer for _limit_chunk_spans
jirastorza Oct 27, 2025
ec3467b
fix: change README and add_context calls.
jirastorza Oct 27, 2025
13a41da
fix: add tool_calls to context buffer.
jirastorza Oct 27, 2025
56bef88
fix: change sentence splitter back to sat-1l-sm.
jirastorza Oct 27, 2025
549a37d
fix: manual rag test returns empty retrieved chunks due to exceeded c…
jirastorza Oct 27, 2025
96917a0
fix: rerun tests
jirastorza Oct 27, 2025
8a74a3c
fix: change expected sentences in test_split_sentences in accordace t…
jirastorza Oct 27, 2025
e55f472
fix: include system_message in _clip.
jirastorza Oct 28, 2025
613c304
fix: simplify function _limit_chunkspans
jirastorza Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ chunk_spans = retrieve_context(query=user_prompt, num_chunks=5, config=my_config

# Append a RAG instruction based on the user prompt and context to the message history
messages = [] # Or start with an existing message history
messages.append(add_context(user_prompt=user_prompt, context=chunk_spans))
messages.append(add_context(user_prompt=user_prompt, context=chunk_spans, config=my_config))

# Stream the RAG response and append it to the message history
stream = rag(messages, config=my_config)
Expand Down Expand Up @@ -285,7 +285,7 @@ chunk_spans = retrieve_chunk_spans(chunks_reranked, config=my_config)
from raglite import add_context

messages = [] # Or start with an existing message history
messages.append(add_context(user_prompt=user_prompt, context=chunk_spans))
messages.append(add_context(user_prompt=user_prompt, context=chunk_spans, config=my_config))

# Stream the RAG response and append it to the message history
from raglite import rag
Expand Down
2 changes: 1 addition & 1 deletion src/raglite/_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def answer_evals(
contexts: list[list[str]] = []
for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
chunk_spans = retrieve_context(query=eval_.question, config=config)
messages = [add_context(user_prompt=eval_.question, context=chunk_spans)]
messages = [add_context(user_prompt=eval_.question, context=chunk_spans, config=config)]
response = rag(messages, config=config)
answer = "".join(response)
answers.append(answer)
Expand Down
2 changes: 1 addition & 1 deletion src/raglite/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def create_mcp_server(server_name: str, *, config: RAGLiteConfig) -> FastMCP[Any
def kb(query: Query) -> str:
"""Answer a question with information from the knowledge base."""
chunk_spans = retrieve_context(query, config=config)
rag_instruction = add_context(query, chunk_spans)
rag_instruction = add_context(query, chunk_spans, config)
return rag_instruction["content"]

@mcp.tool()
Expand Down
188 changes: 166 additions & 22 deletions src/raglite/_rag.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Retrieval-augmented generation."""

import json
from collections.abc import AsyncIterator, Callable, Iterator
import logging
import warnings
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from typing import Any

import numpy as np
Expand All @@ -19,6 +21,8 @@
from raglite._search import retrieve_chunk_spans
from raglite._typing import MetadataFilter

logger = logging.getLogger(__name__)

# The default RAG instruction template follows Anthropic's best practices [1].
# [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
RAG_INSTRUCTION_TEMPLATE = """
Expand Down Expand Up @@ -61,9 +65,121 @@ def retrieve_context(
return chunk_spans


def _count_tokens(item: str) -> int:
"""Estimate the number of tokens in an item."""
return len(item) // 3


def _get_last_message_idx(messages: list[dict[str, str]], role: str) -> int | None:
"""Get the index of the last message with a specified role."""
# Find the last index of a message with the specified role
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") == role:
return i
return None


def _calculate_buffer_tokens(
messages: list[dict[str, str]] | None,
roles: list[str],
user_prompt: str | None,
template: str,
) -> int:
"""Calculate the number of tokens used by other messages."""
# Calculate already used tokens (buffer)
buffer = 0
# Triggered when using tool calls
if messages:
# Count tokens in the last user, system and tool call messages
for role in roles:
idx = _get_last_message_idx(messages, role)
if idx is not None:
buffer += _count_tokens(json.dumps(messages[idx]))
return buffer
# Triggered when using add_context
if user_prompt:
return _count_tokens(template.format(context="", user_prompt=user_prompt))
return 0


def _cutoff_idx(token_counts: list[int], max_tokens: int, *, reverse: bool = False) -> int:
"""Find the cutoff index in token counts to fit within max tokens."""
counts = token_counts[::-1] if reverse else token_counts
cum_tokens = np.cumsum(counts)
cutoff_idx = int(np.searchsorted(cum_tokens, max_tokens, side="right"))
return len(token_counts) - cutoff_idx if reverse else cutoff_idx


def _get_token_counts(items: Sequence[str | ChunkSpan | Mapping[str, str]]) -> list[int]:
"""Compute token counts for a list of items."""
return [
_count_tokens(item.to_xml())
if isinstance(item, ChunkSpan)
else _count_tokens(json.dumps(item, ensure_ascii=False))
if isinstance(item, dict)
else _count_tokens(item)
if isinstance(item, str)
else 0
for item in items
]


def _limit_chunkspans(
tool_chunk_spans: dict[str, list[ChunkSpan]],
config: RAGLiteConfig,
*,
messages: list[dict[str, str]] | None = None,
user_prompt: str | None = None,
template: str = RAG_INSTRUCTION_TEMPLATE,
) -> dict[str, list[ChunkSpan]]:
"""Limit chunk spans to fit within the context window."""
# Calculate already used tokens (buffer)
buffer = _calculate_buffer_tokens(
messages, ["user", "system", "assistant"], user_prompt, template
)
# Determine max tokens available for context
max_tokens = get_context_size(config) - buffer
# Compute token counts for all chunk spans per tool
tool_tokens_list: dict[str, list[int]] = {}
tool_total_tokens: dict[str, int] = {}
total_tokens = 0
for tool_id, chunk_spans in tool_chunk_spans.items():
tokens_list = _get_token_counts(chunk_spans)
tool_tokens_list[tool_id] = tokens_list
tool_total = sum(tokens_list)
tool_total_tokens[tool_id] = tool_total
total_tokens += tool_total
# Early exit if we're already under the limit
if total_tokens <= max_tokens:
return tool_chunk_spans
# Allocate tokens proportionally and truncate
total_chunk_spans = sum(len(spans) for spans in tool_chunk_spans.values())
limited_tool_chunk_spans: dict[str, list[ChunkSpan]] = {}
for tool_id, chunk_spans in tool_chunk_spans.items():
if not chunk_spans:
limited_tool_chunk_spans[tool_id] = []
continue
# Proportional allocation
tool_max_tokens = max_tokens * tool_total_tokens[tool_id] // total_tokens
# Find cutoff point using cumulative sum
cutoff_idx = _cutoff_idx(tool_tokens_list[tool_id], tool_max_tokens)
limited_tool_chunk_spans[tool_id] = chunk_spans[:cutoff_idx]
# Log warning if chunks were dropped
new_total_chunk_spans = sum(len(spans) for spans in limited_tool_chunk_spans.values())
if new_total_chunk_spans < total_chunk_spans:
logger.warning(
"RAG context was limited to %d out of %d chunks due to context window size. "
"Consider using a model with a bigger context window or reducing the number of retrieved chunks.",
new_total_chunk_spans,
total_chunk_spans,
)
return limited_tool_chunk_spans


def add_context(
user_prompt: str,
context: list[ChunkSpan],
config: RAGLiteConfig,
*,
rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE,
) -> dict[str, str]:
Expand All @@ -73,11 +189,13 @@ def add_context(

[1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
"""
# Limit context to fit within the context window.
limited_context = _limit_chunkspans({"temp": context}, config, user_prompt=user_prompt)["temp"]
message = {
"role": "user",
"content": rag_instruction_template.format(
context="\n".join(
chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(context)
chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(limited_context)
),
user_prompt=user_prompt.strip(),
),
Expand All @@ -87,9 +205,31 @@ def add_context(

def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str]]:
"""Left clip a messages array to avoid hitting the context limit."""
cum_tokens = np.cumsum([len(message.get("content") or "") // 3 for message in messages][::-1])
first_message = -np.searchsorted(cum_tokens, max_tokens)
return messages[first_message:]
token_counts = _get_token_counts(messages)
cutoff_idx = _cutoff_idx(token_counts, max_tokens, reverse=True)
idx_user = _get_last_message_idx(messages, "user")
if cutoff_idx == 0 or (idx_user is not None and idx_user < cutoff_idx):
warnings.warn(
(
f"Context window of {max_tokens} tokens exceeded."
"Consider using a model with a bigger context window or reducing the number of retrieved chunks."
),
stacklevel=2,
)
# Try to include both last system and user messages if they fit together.
# If not, include just user if it fits, else return empty.
idx_system = _get_last_message_idx(messages, "system")
if (
idx_user is not None
and idx_system is not None
and idx_system < idx_user
and token_counts[idx_user] + token_counts[idx_system] <= max_tokens
):
return [messages[idx_system], messages[idx_user]]
if idx_user is not None and token_counts[idx_user] <= max_tokens:
return [messages[idx_user]]
return []
return messages[cutoff_idx:]


def _get_tools(
Expand Down Expand Up @@ -145,31 +285,35 @@ def _run_tools(
tool_calls: list[ChatCompletionMessageToolCall],
on_retrieval: Callable[[list[ChunkSpan]], None] | None,
config: RAGLiteConfig,
*,
messages: list[dict[str, str]] | None,
) -> list[dict[str, Any]]:
"""Run tools to search the knowledge base for RAG context."""
tool_chunk_spans: dict[str, list[ChunkSpan]] = {}
tool_messages: list[dict[str, Any]] = []
for tool_call in tool_calls:
if tool_call.function.name == "search_knowledge_base":
kwargs = json.loads(tool_call.function.arguments)
kwargs["config"] = config
chunk_spans = retrieve_context(**kwargs)
tool_messages.append(
{
"role": "tool",
"content": '{{"documents": [{elements}]}}'.format(
elements=", ".join(
chunk_span.to_json(index=i + 1)
for i, chunk_span in enumerate(chunk_spans)
)
),
"tool_call_id": tool_call.id,
}
)
if chunk_spans and callable(on_retrieval):
on_retrieval(chunk_spans)
tool_chunk_spans[tool_call.id] = retrieve_context(**kwargs)
else:
error_message = f"Unknown function `{tool_call.function.name}`."
raise ValueError(error_message)
tool_chunk_spans = _limit_chunkspans(tool_chunk_spans, config, messages=messages)
for tool_id, chunk_spans in tool_chunk_spans.items():
tool_messages.append(
{
"role": "tool",
"content": '{{"documents": [{elements}]}}'.format(
elements=", ".join(
chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans)
)
),
"tool_call_id": tool_id,
}
)
if chunk_spans and callable(on_retrieval):
on_retrieval(chunk_spans)
return tool_messages


Expand Down Expand Up @@ -202,7 +346,7 @@ def rag(
# Add the tool call request to the message array.
messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
# Run the tool calls to retrieve the RAG context and append the output to the message array.
messages.extend(_run_tools(tool_calls, on_retrieval, config))
messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages))
# Stream the assistant response.
chunks = []
stream = completion(model=config.llm, messages=_clip(messages, max_tokens), stream=True)
Expand Down Expand Up @@ -245,7 +389,7 @@ async def async_rag(
messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
# Run the tool calls to retrieve the RAG context and append the output to the message array.
# TODO: Make this async.
messages.extend(_run_tools(tool_calls, on_retrieval, config))
messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages))
# Asynchronously stream the assistant response.
chunks = []
async_stream = await acompletion(
Expand Down
2 changes: 1 addition & 1 deletion src/raglite/_split_sentences.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
@cache
def _load_sat() -> tuple[SaT, dict[str, Any]]:
"""Load a Segment any Text (SaT) model."""
sat = SaT("sat-3l-sm") # This model makes the best trade-off between speed and accuracy.
sat = SaT("sat-1l-sm") # This model makes the best trade-off between speed and accuracy.
sat_kwargs = {"stride": 128, "block_size": 256, "weighting": "hat"}
return sat, sat_kwargs

Expand Down
5 changes: 3 additions & 2 deletions tests/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None:
# Answer a question with manual RAG.
user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?"
chunk_spans = retrieve_context(query=user_prompt, config=raglite_test_config)
messages = [add_context(user_prompt, context=chunk_spans)]
messages = [add_context(user_prompt, context=chunk_spans, config=raglite_test_config)]
stream = rag(messages, config=raglite_test_config)
answer = ""
for update in stream:
Expand All @@ -42,7 +42,8 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None:
# Verify that RAG context was retrieved automatically.
assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"]
assert json.loads(messages[-2]["content"])
assert chunk_spans
if not raglite_test_config.llm.startswith("llama-cpp-python"):
assert chunk_spans
assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)


Expand Down
3 changes: 1 addition & 2 deletions tests/test_split_sentences.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def test_split_sentences() -> None:
"They suggest rather that, as has\nalready been shown to the first order of small quantities, the same laws of\nelectrodynamics and optics will be valid for all frames of reference for which the\nequations of mechanics hold good.1 ",
"We will raise this conjecture (the purport\nof which will hereafter be called the “Principle of Relativity”) to the status\n\nof a postulate, and also introduce another postulate, which is only apparently\nirreconcilable with the former, namely, that light is always propagated in empty\nspace with a definite velocity c which is independent of the state of motion of the\nemitting body. ",
"These two postulates suffice for the attainment of a simple and\nconsistent theory of the electrodynamics of moving bodies based on Maxwell’s\ntheory for stationary bodies. ", # noqa: RUF001
"The introduction of a “luminiferous ether” will\nprove to be superfluous inasmuch as the view here to be developed will not\nrequire an “absolutely stationary space” provided with special properties, nor\n1",
"The preceding memoir by Lorentz was not at this time known to the author.\n\n",
"The introduction of a “luminiferous ether” will\nprove to be superfluous inasmuch as the view here to be developed will not\nrequire an “absolutely stationary space” provided with special properties, nor\n1The preceding memoir by Lorentz was not at this time known to the author.\n\n",
"assign a velocity-vector to a point of the empty space in which electromagnetic\nprocesses take place.\n\n",
"The theory to be developed is based—like all electrodynamics—on the kine-\nmatics of the rigid body, since the assertions of any such theory have to do\nwith the relationships between rigid bodies (systems of co-ordinates), clocks,\nand electromagnetic processes. ",
"Insufficient consideration of this circumstance\nlies at the root of the difficulties which the electrodynamics of moving bodies\nat present encounters.\n\n",
Expand Down
Loading