18 commits
c1a2e78
feat: change sentence splitter to sat-1l-sm
jirastorza Oct 15, 2025
34720cc
fix: _clip function, so context window is never exceeded.
jirastorza Oct 21, 2025
b280285
Merge remote-tracking branch 'origin/main' into fix/context_window_ex…
jirastorza Oct 21, 2025
6a39e23
fix: set sat-3l-sm as sentence splitter by default.
jirastorza Oct 21, 2025
c843600
fix: change warning for logging to avoid test case fail.
jirastorza Oct 21, 2025
f2b52e0
fix: adapt retrieve_context to avoid context exceeding.
jirastorza Oct 21, 2025
d795ae0
fix: proportional tool call available tokens.
jirastorza Oct 23, 2025
a304a0b
fix: add a small buffer to each tool call token limit.
jirastorza Oct 23, 2025
b24b4b1
fix: check in _clip that the last user query is not being clipped.
jirastorza Oct 23, 2025
26f3d30
fix: modify _clip, if user query fits, include it.
jirastorza Oct 23, 2025
7665d6b
fix: limit tool tokens proportionally to the tokens of their retrieve…
jirastorza Oct 24, 2025
1916957
fix: limit tool tokens proportionally to the tokens of their retrieve…
jirastorza Oct 24, 2025
7ac6b9c
fix: change chunkspan token count logic.
jirastorza Oct 27, 2025
d5ed59c
fix: increase CONTEXT_BUFFER.
jirastorza Oct 27, 2025
106ee51
fix: calculate buffer for _limit_chunk_spans
jirastorza Oct 27, 2025
ec3467b
fix: change README and add_context calls.
jirastorza Oct 27, 2025
13a41da
fix: add tool_calls to context buffer.
jirastorza Oct 27, 2025
70c5bbe
fix: manual rag test returns empty retrieved chunks due to exceeded c…
jirastorza Oct 27, 2025
4 changes: 2 additions & 2 deletions README.md
@@ -230,7 +230,7 @@ chunk_spans = retrieve_context(query=user_prompt, num_chunks=5, config=my_config

# Append a RAG instruction based on the user prompt and context to the message history
messages = [] # Or start with an existing message history
messages.append(add_context(user_prompt=user_prompt, context=chunk_spans))
messages.append(add_context(user_prompt=user_prompt, context=chunk_spans, config=my_config))

# Stream the RAG response and append it to the message history
stream = rag(messages, config=my_config)
@@ -285,7 +285,7 @@ chunk_spans = retrieve_chunk_spans(chunks_reranked, config=my_config)
from raglite import add_context

messages = [] # Or start with an existing message history
messages.append(add_context(user_prompt=user_prompt, context=chunk_spans))
messages.append(add_context(user_prompt=user_prompt, context=chunk_spans, config=my_config))

# Stream the RAG response and append it to the message history
from raglite import rag
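Both README snippets now pass the config into add_context so the assembled context can be sized to the model's context window. A minimal end-to-end sketch of the updated manual RAG flow; the RAGLiteConfig values below are assumed example values, not part of this PR:

from raglite import RAGLiteConfig, add_context, rag, retrieve_context

my_config = RAGLiteConfig(db_url="sqlite:///raglite.db", llm="gpt-4o-mini")  # Assumed example values.
user_prompt = "How does Einstein define 'simultaneous events'?"
chunk_spans = retrieve_context(query=user_prompt, num_chunks=5, config=my_config)
messages = []  # Or start with an existing message history.
messages.append(add_context(user_prompt=user_prompt, context=chunk_spans, config=my_config))
for update in rag(messages, config=my_config):
    print(update, end="")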
2 changes: 1 addition & 1 deletion src/raglite/_eval.py
@@ -212,7 +212,7 @@ def answer_evals(
contexts: list[list[str]] = []
for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
chunk_spans = retrieve_context(query=eval_.question, config=config)
messages = [add_context(user_prompt=eval_.question, context=chunk_spans)]
messages = [add_context(user_prompt=eval_.question, context=chunk_spans, config=config)]
response = rag(messages, config=config)
answer = "".join(response)
answers.append(answer)
2 changes: 1 addition & 1 deletion src/raglite/_mcp.py
@@ -27,7 +27,7 @@ def create_mcp_server(server_name: str, *, config: RAGLiteConfig) -> FastMCP[Any
def kb(query: Query) -> str:
"""Answer a question with information from the knowledge base."""
chunk_spans = retrieve_context(query, config=config)
rag_instruction = add_context(query, chunk_spans)
rag_instruction = add_context(query, chunk_spans, config)
return rag_instruction["content"]

@mcp.tool()
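In the MCP server, the kb tool now forwards the config into add_context as well. A minimal usage sketch; the private import path and the FastMCP run() entry point are assumptions not shown in this diff:

from raglite import RAGLiteConfig
from raglite._mcp import create_mcp_server  # Assumed import path, based on the file location above.

config = RAGLiteConfig(db_url="sqlite:///raglite.db", llm="gpt-4o-mini")  # Assumed example values.
mcp = create_mcp_server("raglite", config=config)  # Signature as shown in the hunk above.
mcp.run()  # Assumed FastMCP entry point; the kb tool now passes config through to add_context.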
137 changes: 118 additions & 19 deletions src/raglite/_rag.py
@@ -1,6 +1,8 @@
"""Retrieval-augmented generation."""

import json
import logging
import warnings
from collections.abc import AsyncIterator, Callable, Iterator
from typing import Any

@@ -19,6 +21,8 @@
from raglite._search import retrieve_chunk_spans
from raglite._typing import MetadataFilter

logger = logging.getLogger(__name__)

# The default RAG instruction template follows Anthropic's best practices [1].
# [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
RAG_INSTRUCTION_TEMPLATE = """
@@ -61,9 +65,84 @@ def retrieve_context(
return chunk_spans


def _count_tokens(item: str) -> int:
"""Estimate the number of tokens in an item."""
return len(item) // 3


def _get_last_message_idx(messages: list[dict[str, str]], role: str) -> int | None:
"""Get the index of the last message with a specified role."""
return next(
(-i for i, m in enumerate(reversed(messages), 1) if m.get("role") == role),
None,
) # Last message index


def _limit_chunkspans(
tool_chunk_spans: dict[str, list[ChunkSpan]],
config: RAGLiteConfig,
*,
messages: list[dict[str, str]] | None = None,
user_prompt: str | None = None,
template: str = RAG_INSTRUCTION_TEMPLATE,
) -> dict[str, list[ChunkSpan]]:
"""Limit chunk spans to fit within the context window."""
# Calculate already used tokens (buffer)
buffer = 0
# Triggered when using tool calls
if messages:
# Count tokens in the last user, system, and assistant (tool call request) messages
for role in ("user", "system", "assistant"):
idx = _get_last_message_idx(messages, role)
if idx is not None:
buffer += _count_tokens(json.dumps(messages[idx]))
# Triggered when using add_context
elif user_prompt:
buffer = _count_tokens(template.format(context="", user_prompt=user_prompt))
# Determine max tokens available for context
max_tokens = get_context_size(config) - buffer
# Compute token counts for all chunk spans per tool
tool_tokens_list: dict[str, list[int]] = {}
tool_total_tokens: dict[str, int] = {}
total_tokens = 0
for tool_id, chunk_spans in tool_chunk_spans.items():
tokens_list = [_count_tokens(chunk_span.to_xml()) for chunk_span in chunk_spans]
tool_tokens_list[tool_id] = tokens_list
tool_total = sum(tokens_list)
tool_total_tokens[tool_id] = tool_total
total_tokens += tool_total
# Early exit if we're already under the limit
if total_tokens <= max_tokens:
return tool_chunk_spans
# Allocate tokens proportionally and truncate
total_chunk_spans = sum(len(spans) for spans in tool_chunk_spans.values())
limited_tool_chunk_spans: dict[str, list[ChunkSpan]] = {}
for tool_id, chunk_spans in tool_chunk_spans.items():
if not chunk_spans:
limited_tool_chunk_spans[tool_id] = []
continue
# Proportional allocation
tool_max_tokens = max_tokens * tool_total_tokens[tool_id] // total_tokens
# Find cutoff point using cumulative sum
cum_tokens = np.cumsum(tool_tokens_list[tool_id])
cutoff_idx = np.searchsorted(cum_tokens, tool_max_tokens, side="right")
limited_tool_chunk_spans[tool_id] = chunk_spans[:cutoff_idx]
# Log warning if chunks were dropped
new_total_chunk_spans = sum(len(spans) for spans in limited_tool_chunk_spans.values())
if new_total_chunk_spans < total_chunk_spans:
logger.warning(
"RAG context was limited to %d out of %d chunks due to context window size. "
"Consider using a model with a bigger context window or reducing the number of retrieved chunks.",
new_total_chunk_spans,
total_chunk_spans,
)
return limited_tool_chunk_spans


def add_context(
user_prompt: str,
context: list[ChunkSpan],
config: RAGLiteConfig,
*,
rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE,
) -> dict[str, str]:
@@ -73,11 +152,12 @@ def add_context(

[1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
"""
limited_context = _limit_chunkspans({"temp": context}, config, user_prompt=user_prompt)["temp"]
message = {
"role": "user",
"content": rag_instruction_template.format(
context="\n".join(
chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(context)
chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(limited_context)
),
user_prompt=user_prompt.strip(),
),
@@ -87,8 +167,23 @@

def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str]]:
"""Left clip a messages array to avoid hitting the context limit."""
cum_tokens = np.cumsum([len(message.get("content") or "") // 3 for message in messages][::-1])
cum_tokens = np.cumsum([_count_tokens(json.dumps(message)) for message in messages][::-1])
first_message = -np.searchsorted(cum_tokens, max_tokens)
idx = _get_last_message_idx(messages, "user")
if first_message == 0 or (
idx is not None and idx < first_message
): # No message fits or last user message (user query) would be clipped
warnings.warn(
(
f"Context window of {max_tokens} tokens exceeded."
"Consider using a model with a bigger context window or reducing the number of retrieved chunks."
),
stacklevel=2,
)
# Return only the last user message if it fits.
if idx is not None and _count_tokens(json.dumps(messages[idx])) <= max_tokens:
return [messages[idx]]
return []
return messages[first_message:]


@@ -145,31 +240,35 @@ def _run_tools(
tool_calls: list[ChatCompletionMessageToolCall],
on_retrieval: Callable[[list[ChunkSpan]], None] | None,
config: RAGLiteConfig,
*,
messages: list[dict[str, str]] | None,
) -> list[dict[str, Any]]:
"""Run tools to search the knowledge base for RAG context."""
tool_chunk_spans: dict[str, list[ChunkSpan]] = {}
tool_messages: list[dict[str, Any]] = []
for tool_call in tool_calls:
if tool_call.function.name == "search_knowledge_base":
kwargs = json.loads(tool_call.function.arguments)
kwargs["config"] = config
chunk_spans = retrieve_context(**kwargs)
tool_messages.append(
{
"role": "tool",
"content": '{{"documents": [{elements}]}}'.format(
elements=", ".join(
chunk_span.to_json(index=i + 1)
for i, chunk_span in enumerate(chunk_spans)
)
),
"tool_call_id": tool_call.id,
}
)
if chunk_spans and callable(on_retrieval):
on_retrieval(chunk_spans)
tool_chunk_spans[tool_call.id] = retrieve_context(**kwargs)
else:
error_message = f"Unknown function `{tool_call.function.name}`."
raise ValueError(error_message)
tool_chunk_spans = _limit_chunkspans(tool_chunk_spans, config, messages=messages)
for tool_id, chunk_spans in tool_chunk_spans.items():
tool_messages.append(
{
"role": "tool",
"content": '{{"documents": [{elements}]}}'.format(
elements=", ".join(
chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans)
)
),
"tool_call_id": tool_id,
}
)
if chunk_spans and callable(on_retrieval):
on_retrieval(chunk_spans)
return tool_messages


@@ -202,7 +301,7 @@ def rag(
# Add the tool call request to the message array.
messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
# Run the tool calls to retrieve the RAG context and append the output to the message array.
messages.extend(_run_tools(tool_calls, on_retrieval, config))
messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages))
# Stream the assistant response.
chunks = []
stream = completion(model=config.llm, messages=_clip(messages, max_tokens), stream=True)
@@ -245,7 +344,7 @@ async def async_rag(
messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
# Run the tool calls to retrieve the RAG context and append the output to the message array.
# TODO: Make this async.
messages.extend(_run_tools(tool_calls, on_retrieval, config))
messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages))
# Asynchronously stream the assistant response.
chunks = []
async_stream = await acompletion(
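The core of this file's change is _limit_chunkspans: each tool call's chunk spans get a share of the remaining context window proportional to their token count, and the list is cut at a cumulative-sum cutoff. A standalone sketch of that budgeting idea with made-up token counts (illustrative only, not raglite code):

import numpy as np

def limit_proportionally(tool_tokens: dict[str, list[int]], max_tokens: int) -> dict[str, int]:
    """Return how many chunk spans to keep per tool call, mirroring the proportional cutoff."""
    totals = {tool: sum(tokens) for tool, tokens in tool_tokens.items()}
    grand_total = sum(totals.values())
    if grand_total <= max_tokens:  # Everything fits: keep all chunk spans.
        return {tool: len(tokens) for tool, tokens in tool_tokens.items()}
    keep = {}
    for tool, tokens in tool_tokens.items():
        budget = max_tokens * totals[tool] // grand_total  # Proportional share of the window.
        keep[tool] = int(np.searchsorted(np.cumsum(tokens), budget, side="right"))
    return keep

# Two tool calls competing for a 1000-token budget (token counts are made up):
print(limit_proportionally({"tool_a": [400, 400, 400], "tool_b": [300, 300]}, max_tokens=1000))
# {'tool_a': 1, 'tool_b': 1}  -> tool_a's budget is ~666 tokens, tool_b's ~333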
5 changes: 3 additions & 2 deletions tests/test_rag.py
@@ -16,7 +16,7 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None:
# Answer a question with manual RAG.
user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?"
chunk_spans = retrieve_context(query=user_prompt, config=raglite_test_config)
messages = [add_context(user_prompt, context=chunk_spans)]
messages = [add_context(user_prompt, context=chunk_spans, config=raglite_test_config)]
stream = rag(messages, config=raglite_test_config)
answer = ""
for update in stream:
@@ -42,7 +42,8 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None:
# Verify that RAG context was retrieved automatically.
assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"]
assert json.loads(messages[-2]["content"])
assert chunk_spans
if not raglite_test_config.llm.startswith("llama-cpp-python"):
assert chunk_spans
assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)


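The other half of the fix is _clip, which now counts tokens over the JSON-serialized messages and keeps the last user query from being clipped away. A simplified standalone sketch of just the left-clipping step, using the same rough 3-characters-per-token estimate (illustrative only; it omits the user-query guard that _clip adds):

import json
import numpy as np

def clip_left(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str]]:
    """Keep the longest suffix of messages that fits within max_tokens."""
    token_counts = [len(json.dumps(m)) // 3 for m in messages]  # Rough ~3 characters per token.
    cum_from_end = np.cumsum(token_counts[::-1])  # Tokens used by the last k messages.
    keep = int(np.searchsorted(cum_from_end, max_tokens, side="right"))  # How many fit from the end.
    return messages[len(messages) - keep:]

# A small made-up history clipped to a 100-token budget:
history = [
    {"role": "system", "content": "x" * 300},
    {"role": "user", "content": "y" * 300},
    {"role": "assistant", "content": "z" * 60},
    {"role": "user", "content": "w" * 60},
]
print([m["role"] for m in clip_left(history, max_tokens=100)])  # ['assistant', 'user']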