4 changes: 2 additions & 2 deletions pyproject.toml
@@ -144,8 +144,8 @@ target-version = "py310"
docstring-code-format = true

[tool.ruff.lint]
select = ["A", "ASYNC", "B", "BLE", "C4", "C90", "D", "DTZ", "E", "EM", "ERA", "F", "FBT", "FLY", "FURB", "G", "I", "ICN", "INP", "INT", "ISC", "LOG", "N", "NPY", "PERF", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "Q", "RET", "RSE", "RUF", "S", "SIM", "SLF", "SLOT", "T10", "T20", "TCH", "TID", "TRY", "UP", "W", "YTT"]
ignore = ["D203", "D213", "E501", "RET504", "RUF002", "RUF022", "S101", "S307", "TC004"]
select = ["ALL"]
ignore = ["CPY", "FIX", "ARG001", "COM812", "D203", "D213", "E501", "PD008", "PD009", "RET504", "S101", "TD003"]
unfixable = ["ERA001", "F401", "F841", "T201", "T203"]

[tool.ruff.lint.flake8-tidy-imports]
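Note on this change: `select = ["ALL"]` opts into every Ruff rule family and then carves out project-wide exceptions in `ignore`, which is why the rest of this diff adds targeted `# noqa` comments and `TODO(author)` tags. The snippet below is a minimal, hypothetical module (not part of this PR) sketching how the broader selection interacts with those inline suppressions.

```python
from typing import Any


def handle(payload: Any) -> None:  # noqa: ANN401  # ANN401 is now selected; `Any` is intentional here.
    # TODO(lsorber): validate `payload`. The author tag satisfies TD002, and the
    # missing issue link (TD003) is ignored project-wide by the config above.
    print(payload)  # T201 is still reported; listing it under `unfixable` only blocks the autofix.
```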
2 changes: 1 addition & 1 deletion src/raglite/__init__.py
@@ -17,7 +17,7 @@
vector_search,
)

-__all__ = [
+__all__ = [ # noqa: RUF022
# Config
"RAGLiteConfig",
# Insert
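The `# noqa: RUF022` added here suppresses Ruff's "unsorted `__all__`" rule so the export list can stay grouped by feature with section comments, as in this hedged sketch (the second member name is illustrative, not necessarily one of the package's actual exports):

```python
__all__ = [  # noqa: RUF022  # Keep the feature-grouped ordering below.
    # Config
    "RAGLiteConfig",
    # Insert
    "insert_documents",  # hypothetical member, shown only to illustrate the grouping
]
```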
23 changes: 14 additions & 9 deletions src/raglite/_bench.py
@@ -94,7 +94,7 @@ def __init__(
insert_variant: str | None = None,
search_variant: str | None = None,
config: RAGLiteConfig | None = None,
-):
+) -> None:
super().__init__(
dataset,
num_results=num_results,
@@ -145,7 +145,7 @@ def __init__(
num_results: int = 10,
insert_variant: str | None = None,
search_variant: str | None = None,
-):
+) -> None:
super().__init__(
dataset,
num_results=num_results,
@@ -156,7 +156,7 @@
self.embedder_dim = 3072
self.persist_path = self.cwd / self.insert_id

-def insert_documents(self, max_workers: int | None = None) -> None:
+def insert_documents(self, max_workers: int | None = None) -> None: # noqa: ARG002
# Adapted from https://docs.llamaindex.ai/en/stable/examples/vector_stores/FaissIndexDemo/.
import faiss
from llama_index.core import Document, StorageContext, VectorStoreIndex
@@ -178,14 +178,15 @@ def insert_documents(self, max_workers: int | None = None) -> None:
index.storage_context.persist(persist_dir=self.persist_path)

@cached_property
-def index(self) -> Any:
+def index(self) -> Any: # noqa: ANN401
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.persist_path.as_posix())
storage_context = StorageContext.from_defaults(
-vector_store=vector_store, persist_dir=self.persist_path.as_posix()
+vector_store=vector_store,
+persist_dir=self.persist_path.as_posix(),
)
embed_model = OpenAIEmbedding(model=self.embedder, dimensions=self.embedder_dim)
index = load_index_from_storage(storage_context, embed_model=embed_model)
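For context on the reformatted call above, here is a rough sketch of the persist-then-reload round trip this evaluator relies on, assuming LlamaIndex with the FAISS vector store extra is installed, an index was previously persisted to `persist_dir`, and `OPENAI_API_KEY` is set; the path, model name, and top-k value are illustrative.

```python
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

persist_dir = "benchmarks/llamaindex-faiss"  # hypothetical path
vector_store = FaissVectorStore.from_persist_dir(persist_dir=persist_dir)
storage_context = StorageContext.from_defaults(
    vector_store=vector_store,
    persist_dir=persist_dir,
)
# Model name assumed for illustration; the evaluator passes self.embedder with dimensions=3072.
embed_model = OpenAIEmbedding(model="text-embedding-3-large", dimensions=3072)
index = load_index_from_storage(storage_context, embed_model=embed_model)
retriever = index.as_retriever(similarity_top_k=10)
```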
@@ -215,7 +216,7 @@ def __init__(
num_results: int = 10,
insert_variant: str | None = None,
search_variant: str | None = None,
-):
+) -> None:
super().__init__(
dataset,
num_results=num_results,
@@ -227,7 +228,7 @@
)

@cached_property
-def client(self) -> Any:
+def client(self) -> Any: # noqa: ANN401
import openai

return openai.OpenAI()
@@ -269,7 +270,9 @@ def insert_documents(self, max_workers: int | None = None) -> None:
files.append(temp_file.open("rb"))
if len(files) == max_files_per_batch or (i == self.dataset.docs_count() - 1):
self.client.vector_stores.file_batches.upload_and_poll(
-vector_store_id=vector_store.id, files=files, max_concurrency=max_workers
+vector_store_id=vector_store.id,
+files=files,
+max_concurrency=max_workers,
)
for f in files:
f.close()
@@ -283,7 +286,9 @@ def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[Sc
if not self.vector_store_id:
return []
response = self.client.vector_stores.search(
-vector_store_id=self.vector_store_id, query=query, max_num_results=2 * num_results
+vector_store_id=self.vector_store_id,
+query=query,
+max_num_results=2 * num_results,
)
scored_docs = [
ScoredDoc(
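The two reformatted calls above belong to the OpenAI vector store evaluator. A loose end-to-end sketch of that flow follows, assuming a recent `openai` SDK where vector stores live at `client.vector_stores` (as this file uses) and `OPENAI_API_KEY` is set; the store name, file names, concurrency value, and the result fields in the final loop are assumptions.

```python
from pathlib import Path

import openai

client = openai.OpenAI()
vector_store = client.vector_stores.create(name="bench-example")  # name is a placeholder
files = [Path(p).open("rb") for p in ("doc1.md", "doc2.md")]  # hypothetical documents
try:
    client.vector_stores.file_batches.upload_and_poll(
        vector_store_id=vector_store.id,
        files=files,
        max_concurrency=4,
    )
finally:
    for f in files:
        f.close()
response = client.vector_stores.search(
    vector_store_id=vector_store.id,
    query="example query",
    max_num_results=20,  # the evaluator asks for 2 * num_results
)
for result in response.data:  # result fields assumed from the SDK's search response
    print(result.file_id, result.score)
```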
6 changes: 4 additions & 2 deletions src/raglite/_chainlit.py
@@ -39,7 +39,7 @@ async def start_chat() -> None:
TextInput(id="llm", label="LLM", initial=config.llm),
TextInput(id="embedder", label="Embedder", initial=config.embedder),
Switch(id="vector_search_query_adapter", label="Query adapter", initial=True),
-]
+],
).send()
await update_config(settings)

@@ -95,7 +95,7 @@ async def handle_message(user_message: cl.Message) -> None:
messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call]
messages.append({"role": "user", "content": user_prompt})
async for token in async_rag(
-messages, on_retrieval=lambda x: chunk_spans.extend(x), config=config
+messages,
+on_retrieval=lambda x: chunk_spans.extend(x),
+config=config,
):
await assistant_message.stream_token(token)
# Append RAG sources, if any.
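Since the call above is the heart of the Chainlit handler, here is a hedged, framework-free sketch of consuming the same streaming API, assuming `async_rag` is importable from the top-level `raglite` package (the import path is an assumption) and accepts the arguments shown in this diff.

```python
import asyncio

from raglite import RAGLiteConfig, async_rag  # top-level export assumed


async def answer(question: str) -> str:
    config = RAGLiteConfig()
    chunk_spans: list = []  # filled by the on_retrieval callback
    messages = [{"role": "user", "content": question}]
    tokens: list[str] = []
    async for token in async_rag(
        messages,
        on_retrieval=lambda spans: chunk_spans.extend(spans),
        config=config,
    ):
        tokens.append(token)
    return "".join(tokens)


if __name__ == "__main__":
    print(asyncio.run(answer("What does RAGLite do?")))
```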
55 changes: 38 additions & 17 deletions src/raglite/_chatml_function_calling.py
@@ -98,9 +98,9 @@ def _convert_chunks_to_completion(
{
"text": text,
"index": 0,
"logprobs": logprobs, # TODO: Improve accumulation of logprobs
"logprobs": logprobs, # TODO(lsorber): Improve accumulation of logprobs
"finish_reason": finish_reason, # type: ignore[typeddict-item]
-}
+},
],
}
# Add usage section if present in the chunks
@@ -131,7 +131,8 @@ def _stream_tool_calls(
prompt += f"functions.{tool_name}:\n"
try:
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
json.dumps(tool["function"]["parameters"]),
verbose=llama.verbose,
)
except Exception as e:
warnings.warn(
Expand All @@ -140,7 +141,8 @@ def _stream_tool_calls(
stacklevel=2,
)
grammar = llama_grammar.LlamaGrammar.from_string(
-llama_grammar.JSON_GBNF, verbose=llama.verbose
+llama_grammar.JSON_GBNF,
+verbose=llama.verbose,
)
completion_or_chunks = llama.create_completion(
prompt=prompt,
@@ -182,7 +184,8 @@ def _stream_tool_calls(
"stop": [*completion_kwargs["stop"], ":", "</function_calls>"],
"max_tokens": None,
"grammar": llama_grammar.LlamaGrammar.from_string(
-follow_up_gbnf_tool_grammar, verbose=llama.verbose
+follow_up_gbnf_tool_grammar,
+verbose=llama.verbose,
),
},
),
@@ -253,7 +256,7 @@ def chatml_function_calling_with_streaming(
grammar: Optional[llama.LlamaGrammar] = None, # type: ignore[name-defined]
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
-**kwargs: Any,
+**kwargs: Any, # noqa: ANN401
) -> Union[
llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
@@ -381,7 +384,10 @@ def chatml_function_calling_with_streaming(
or len(tools) == 0
):
prompt = template_renderer.render(
-messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
+messages=messages,
+tools=[],
+tool_calls=None,
+add_generation_prompt=True,
)
return llama_chat_format._convert_completion_to_chat( # noqa: SLF001
llama.create_completion(
@@ -404,7 +410,10 @@
assert tools
function_names = " | ".join([f'''"functions.{t["function"]["name"]}:"''' for t in tools])
prompt = template_renderer.render(
-messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True
+messages=messages,
+tools=tools,
+tool_calls=True,
+add_generation_prompt=True,
)
initial_gbnf_tool_grammar = (
(
@@ -429,7 +438,8 @@
"stream": False,
"max_tokens": None,
"grammar": llama_grammar.LlamaGrammar.from_string(
-initial_gbnf_tool_grammar, verbose=llama.verbose
+initial_gbnf_tool_grammar,
+verbose=llama.verbose,
),
},
),
@@ -449,7 +459,10 @@
# Case 2 step 2A: Respond with a message
if tool_name is None:
prompt = template_renderer.render(
-messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
+messages=messages,
+tools=[],
+tool_calls=None,
+add_generation_prompt=True,
)
prompt += think
return llama_chat_format._convert_completion_to_chat( # noqa: SLF001
@@ -469,7 +482,12 @@
prompt += "<function_calls>\n"
if stream:
return _stream_tool_calls(
-llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar
+llama,
+prompt,
+tools,
+tool_name,
+completion_kwargs,
+follow_up_gbnf_tool_grammar,
)
tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
completions: List[llama_types.CreateCompletionResponse] = []
@@ -479,7 +497,8 @@
prompt += f"functions.{tool_name}:\n"
try:
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
json.dumps(tool["function"]["parameters"]),
verbose=llama.verbose,
)
except Exception as e:
warnings.warn(
@@ -488,7 +507,8 @@
stacklevel=2,
)
grammar = llama_grammar.LlamaGrammar.from_string(
-llama_grammar.JSON_GBNF, verbose=llama.verbose
+llama_grammar.JSON_GBNF,
+verbose=llama.verbose,
)
completion_or_chunks = llama.create_completion(
prompt=prompt,
@@ -515,7 +535,8 @@
"stop": [*completion_kwargs["stop"], ":", "</function_calls>"], # type: ignore[misc]
"max_tokens": None,
"grammar": llama_grammar.LlamaGrammar.from_string(
-follow_up_gbnf_tool_grammar, verbose=llama.verbose
+follow_up_gbnf_tool_grammar,
+verbose=llama.verbose,
),
},
),
@@ -533,7 +554,7 @@
"finish_reason": "tool_calls",
"index": 0,
"logprobs": _convert_text_completion_logprobs_to_chat(
completion["choices"][0]["logprobs"]
completion["choices"][0]["logprobs"],
),
"message": {
"role": "assistant",
@@ -548,11 +569,11 @@
},
}
for i, (tool_name, completion) in enumerate(
-zip(completions_tool_name, completions, strict=True)
+zip(completions_tool_name, completions, strict=True),
)
],
},
-}
+},
],
"usage": {
"completion_tokens": sum(
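Most of the reformatting in this file touches the same grammar-fallback pattern, so here is a self-contained restatement of it for reference, assuming `llama-cpp-python` is installed; the helper name is illustrative and not part of the module.

```python
import json
import warnings
from typing import Any

from llama_cpp import Llama, llama_grammar


def grammar_for_tool(llm: Llama, tool: dict[str, Any]) -> llama_grammar.LlamaGrammar:
    """Constrain tool-call output to the tool's JSON schema, falling back to generic JSON."""
    try:
        return llama_grammar.LlamaGrammar.from_json_schema(
            json.dumps(tool["function"]["parameters"]),
            verbose=llm.verbose,
        )
    except Exception as e:  # noqa: BLE001
        warnings.warn(
            f"Failed to parse tool parameters as a JSON schema: {e}. Falling back to JSON_GBNF.",
            stacklevel=2,
        )
        return llama_grammar.LlamaGrammar.from_string(
            llama_grammar.JSON_GBNF,
            verbose=llm.verbose,
        )
```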
25 changes: 18 additions & 7 deletions src/raglite/_cli.py
@@ -14,7 +14,7 @@ class RAGLiteCLIConfig(BaseSettings):
"""RAGLite CLI config."""

model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
env_prefix="RAGLITE_", env_file=".env", extra="allow"
env_prefix="RAGLITE_",
env_file=".env",
extra="allow",
)

mcp_server_name: str = "RAGLite"
@@ -67,7 +69,7 @@ def install_mcp_server(
claude_config_path = get_claude_config_path()
if not claude_config_path:
typer.echo(
"Please download the Claude desktop app from https://claude.ai/download before installing an MCP server."
"Please download the Claude desktop app from https://claude.ai/download before installing an MCP server.",
)
return
claude_config_filepath = claude_config_path / "claude_desktop_config.json"
@@ -88,7 +90,7 @@
"--python",
"3.11",
"--with",
"numpy<2.0.0", # TODO: Remove this constraint when uv no longer needs it to solve the environment.
"numpy<2.0.0", # TODO(lsorber): Remove this constraint when uv no longer needs it to solve the environment.
"raglite",
"mcp",
"run",
@@ -112,7 +114,7 @@
from raglite._mcp import create_mcp_server

config = RAGLiteConfig(
db_url=ctx.obj["db_url"], llm=ctx.obj["llm"], embedder=ctx.obj["embedder"]
db_url=ctx.obj["db_url"],
llm=ctx.obj["llm"],
embedder=ctx.obj["embedder"],
)
mcp = create_mcp_server(server_name, config=config)
mcp.run()
@@ -122,7 +126,10 @@
def bench(
ctx: typer.Context,
dataset_name: str = typer.Option(
"nano-beir/hotpotqa", "--dataset", "-d", help="Dataset to use from https://ir-datasets.com/"
"nano-beir/hotpotqa",
"--dataset",
"-d",
help="Dataset to use from https://ir-datasets.com/",
),
measure: str = typer.Option(
"AP@10",
@@ -157,7 +164,9 @@
)
dataset = ir_datasets.load(dataset_name)
evaluator = RAGLiteEvaluator(
dataset, insert_variant=f"single-vector-{chunk_max_size // 4}t", config=config
dataset,
insert_variant=f"single-vector-{chunk_max_size // 4}t",
config=config,
)
index.append("RAGLite (single-vector)")
results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))
@@ -170,7 +179,9 @@
)
dataset = ir_datasets.load(dataset_name)
evaluator = RAGLiteEvaluator(
dataset, insert_variant=f"multi-vector-{chunk_max_size // 4}t", config=config
dataset,
insert_variant=f"multi-vector-{chunk_max_size // 4}t",
config=config,
)
index.append("RAGLite (multi-vector)")
results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))
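Relatedly, the `SettingsConfigDict` reformatted at the top of this file drives how the CLI reads its options. Below is a minimal sketch of that behavior under standard pydantic-settings semantics (the field and its default are hypothetical): `RAGLITE_DB_URL` in the shell or in `.env` populates `db_url`, and additional `RAGLITE_*` variables are accepted rather than rejected because of `extra="allow"`.

```python
from typing import ClassVar

from pydantic_settings import BaseSettings, SettingsConfigDict


class ExampleCLIConfig(BaseSettings):
    model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
        env_prefix="RAGLITE_",
        env_file=".env",
        extra="allow",
    )

    db_url: str = "sqlite:///raglite.db"  # hypothetical default for illustration


# e.g. RAGLITE_DB_URL=duckdb:///bench.db python -m app  ->  ExampleCLIConfig().db_url == "duckdb:///bench.db"
```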
11 changes: 7 additions & 4 deletions src/raglite/_config.py
@@ -24,9 +24,9 @@


# Lazily load the default search method to avoid circular imports.
-# TODO: Replace with search_and_rerank_chunk_spans after benchmarking.
+# TODO(lsorber): Replace with search_and_rerank_chunk_spans after benchmarking.
def _vector_search(
-query: str, *, num_results: int = 8, config: "RAGLiteConfig | None" = None
+query: str,
+*,
+num_results: int = 8,
+config: "RAGLiteConfig | None" = None,
) -> tuple[list[ChunkId], list[float]]:
from raglite._search import vector_search

@@ -45,7 +48,7 @@ class RAGLiteConfig:
"llama-cpp-python/unsloth/Qwen3-8B-GGUF/*Q4_K_M.gguf@8192"
if llama_supports_gpu_offload()
else "llama-cpp-python/unsloth/Qwen3-4B-GGUF/*Q4_K_M.gguf@8192"
-)
+),
)
llm_max_tries: int = 4
# Embedder config used for indexing.
@@ -54,7 +57,7 @@
"llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf@512"
if llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 4 # noqa: PLR2004
else "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@512"
-)
+),
)
embedder_normalize: bool = True
# Chunk config used to partition documents into chunks.
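The trailing commas added above close out defaults that pick a model at instantiation time. The collapsed context suggests a `field(default_factory=...)` call around each ternary; a sketch of that pattern follows, assuming `llama_supports_gpu_offload` from `llama-cpp-python` and reusing the model strings visible in this hunk, with an illustrative surrounding dataclass.

```python
from dataclasses import dataclass, field

from llama_cpp import llama_supports_gpu_offload


@dataclass(frozen=True)
class ExampleConfig:
    # Prefer the larger model only when GPU offload is available.
    llm: str = field(
        default_factory=lambda: (
            "llama-cpp-python/unsloth/Qwen3-8B-GGUF/*Q4_K_M.gguf@8192"
            if llama_supports_gpu_offload()
            else "llama-cpp-python/unsloth/Qwen3-4B-GGUF/*Q4_K_M.gguf@8192"
        ),
    )


# ExampleConfig().llm resolves to one of the two strings above when the config is instantiated.
```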