Merged
26 changes: 26 additions & 0 deletions src/llama_stack/core/datatypes.py
@@ -18,6 +18,7 @@
    StorageConfig,
)
from llama_stack.log import LoggingConfig
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
from llama_stack_api import (
    Api,
    Benchmark,
@@ -349,6 +350,27 @@ class QualifiedModel(BaseModel):
    model_id: str


class RewriteQueryParams(BaseModel):
"""Parameters for query rewriting/expansion."""

model: QualifiedModel | None = Field(
default=None,
description="LLM model for query rewriting/expansion in vector search.",
)
prompt: str = Field(
default=DEFAULT_QUERY_REWRITE_PROMPT,
description="Prompt template for query rewriting. Use {query} as placeholder for the original query.",
)
max_tokens: int = Field(
default=100,
description="Maximum number of tokens for query expansion responses.",
)
temperature: float = Field(
default=0.3,
description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
)


class VectorStoresConfig(BaseModel):
"""Configuration for vector stores in the stack."""

Expand All @@ -360,6 +382,10 @@ class VectorStoresConfig(BaseModel):
default=None,
description="Default embedding model configuration for vector stores.",
)
rewrite_query_params: RewriteQueryParams | None = Field(
default=None,
description="Parameters for query rewriting/expansion. None disables query rewriting.",
)


class SafetyConfig(BaseModel):
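For orientation, a minimal sketch of how the new config models compose. This assumes QualifiedModel takes only provider_id and model_id (as its usage in this diff suggests) and that VectorStoresConfig has no other required fields hidden by the fold above; the model identifiers are placeholders, not names from this PR.

from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams, VectorStoresConfig

# Placeholder identifiers -- substitute a provider/model registered in your stack.
rewrite_params = RewriteQueryParams(
    model=QualifiedModel(provider_id="example-provider", model_id="example-llm"),
    # prompt, max_tokens, and temperature fall back to the defaults defined above
)

# Leaving rewrite_query_params as None (the default) disables query rewriting entirely.
config = VectorStoresConfig(rewrite_query_params=rewrite_params)
print(config.rewrite_query_params.prompt[:40])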
17 changes: 16 additions & 1 deletion src/llama_stack/core/resolver.py
@@ -199,6 +199,13 @@ def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str,
            )
        }

        # Add inference as an optional dependency for vector_io to enable query rewriting
        optional_deps = []
        deps_list = [info.routing_table_api.value]
        if info.router_api == Api.vector_io:
            optional_deps = [Api.inference]
            deps_list.append(Api.inference.value)

        specs[info.router_api.value] = {
            "__builtin__": ProviderWithSpec(
                provider_id="__autorouted__",
@@ -209,7 +216,8 @@
                    module="llama_stack.core.routers",
                    routing_table_api=info.routing_table_api,
                    api_dependencies=[info.routing_table_api],
                    deps__=([info.routing_table_api.value]),
                    optional_api_dependencies=optional_deps,
                    deps__=deps_list,
                ),
            )
        }
@@ -315,6 +323,13 @@ async def instantiate_providers(
        api = Api(api_str)
        impls[api] = impl

    # Post-instantiation: Inject VectorIORouter into VectorStoresRoutingTable
    if Api.vector_io in impls and Api.vector_stores in impls:
        vector_io_router = impls[Api.vector_io]
        vector_stores_routing_table = impls[Api.vector_stores]
        if hasattr(vector_stores_routing_table, "vector_io_router"):
            vector_stores_routing_table.vector_io_router = vector_io_router

    return impls


1 change: 1 addition & 0 deletions src/llama_stack/core/routers/__init__.py
@@ -87,6 +87,7 @@ async def get_auto_router_impl(
        api_to_dep_impl["store"] = inference_store
    elif api == Api.vector_io:
        api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
        api_to_dep_impl["inference_api"] = deps.get(Api.inference)
    elif api == Api.safety:
        api_to_dep_impl["safety_config"] = run_config.safety

60 changes: 57 additions & 3 deletions src/llama_stack/core/routers/vector_io.py
@@ -16,12 +16,15 @@
    Chunk,
    HealthResponse,
    HealthStatus,
    Inference,
    InterleavedContent,
    ModelNotFoundError,
    ModelType,
    ModelTypeError,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
    OpenAICreateVectorStoreRequestWithExtraBody,
    OpenAIUserMessageParam,
    QueryChunksResponse,
    RoutingTable,
    SearchRankingOptions,
@@ -51,10 +54,11 @@ def __init__(
        self,
        routing_table: RoutingTable,
        vector_stores_config: VectorStoresConfig | None = None,
        inference_api: Inference | None = None,
    ) -> None:
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table
        self.vector_stores_config = vector_stores_config
        self.inference_api = inference_api

    async def initialize(self) -> None:
        logger.debug("VectorIORouter.initialize")
@@ -64,6 +68,46 @@ async def shutdown(self) -> None:
        logger.debug("VectorIORouter.shutdown")
        pass

    async def _rewrite_query_for_search(self, query: str) -> str:
        """Rewrite a search query using the configured LLM model for better retrieval results."""
        if (
            not self.vector_stores_config
            or not self.vector_stores_config.rewrite_query_params
            or not self.vector_stores_config.rewrite_query_params.model
        ):
            logger.warning(
                "User is trying to use vector_store query rewriting, but it is not configured. Please configure rewrite_query_params.model in vector_stores config."
            )
            raise ValueError("Query rewriting is not available")

        if not self.inference_api:
            logger.warning("Query rewriting requires inference API but it is not available")
            raise ValueError("Query rewriting is not available")

        model = self.vector_stores_config.rewrite_query_params.model
        model_id = f"{model.provider_id}/{model.model_id}"

        prompt = self.vector_stores_config.rewrite_query_params.prompt.format(query=query)

        request = OpenAIChatCompletionRequestWithExtraBody(
            model=model_id,
            messages=[OpenAIUserMessageParam(role="user", content=prompt)],
            max_tokens=self.vector_stores_config.rewrite_query_params.max_tokens or 100,
            temperature=self.vector_stores_config.rewrite_query_params.temperature or 0.3,
        )

        try:
            response = await self.inference_api.openai_chat_completion(request)
            content = response.choices[0].message.content
            if content is None:
                logger.error(f"LLM returned None content for query rewriting. Model: {model_id}")
                raise RuntimeError("Query rewrite failed due to an internal error")
            rewritten_query: str = content.strip()
            return rewritten_query
        except Exception as e:
            logger.error(f"Query rewrite failed with LLM call error. Model: {model_id}, Error: {e}")
            raise RuntimeError("Query rewrite failed due to an internal error") from e

    async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int:
        """Get the embedding dimension for a specific embedding model."""
        all_models = await self.routing_table.get_all_with_type("model")
@@ -292,14 +336,24 @@ async def openai_search_vector_store(
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")

        # Handle query rewriting at the router level
        search_query = query
        if rewrite_query:
            if isinstance(query, list):
                original_query = " ".join(query)
            else:
                original_query = query
            search_query = await self._rewrite_query_for_search(original_query)

        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            query=search_query,
            filters=filters,
            max_num_results=max_num_results,
            ranking_options=ranking_options,
            rewrite_query=rewrite_query,
            rewrite_query=False,  # Already handled at router level
Review comment from the PR author: i agree with handling this at the router level, and probably there's no way to avoid it, but i at least want to state for the record that someone outside of us looking at this code in isolation may find it confusing...but maybe it'll just be an LLM that pulls the router into the context. 🤷 🥲

            search_mode=search_mode,
        )

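To make the router-level flow concrete, here is a self-contained sketch of the rewrite-then-search sequence the router code above implements. The inference and provider calls are stubs (not real Llama Stack APIs), and "provider-id/model-id" is a placeholder; only the default prompt constant is imported from this PR.

import asyncio

from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT

# Stand-ins for the Inference API and the vector store provider used by the real router.
async def stub_chat_completion(model: str, prompt: str, max_tokens: int, temperature: float) -> str:
    return "query rewriting expansion synonyms related terms"

async def stub_provider_search(vector_store_id: str, query: str) -> list[str]:
    return [f"chunk matched against: {query}"]

async def search_with_rewrite(vector_store_id: str, query: str, rewrite_query: bool) -> list[str]:
    # Mirrors the router behavior above: rewrite once at the router level,
    # then search with the rewritten text and pass rewrite_query=False downstream.
    search_query = query
    if rewrite_query:
        prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query=query)
        search_query = (await stub_chat_completion("provider-id/model-id", prompt, 100, 0.3)).strip()
    return await stub_provider_search(vector_store_id, search_query)

print(asyncio.run(search_with_rewrite("vs_example", "query rewriting", rewrite_query=True)))

Keeping the rewrite in one place is why the router forwards rewrite_query=False to providers, as discussed in the review comment above.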
23 changes: 23 additions & 0 deletions src/llama_stack/core/routing_tables/vector_stores.py
@@ -40,6 +40,15 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
    Only provides internal routing functionality for VectorIORouter.
    """

    def __init__(
        self,
        impls_by_provider_id: dict[str, Any],
        dist_registry: Any,
        policy: list[Any],
    ) -> None:
        super().__init__(impls_by_provider_id, dist_registry, policy)
        self.vector_io_router = None  # Will be set post-instantiation

    # Internal methods only - no public API exposure

    async def register_vector_store(
@@ -133,6 +142,20 @@ async def openai_search_vector_store(
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        await self.assert_action_allowed("read", "vector_store", vector_store_id)

        # Delegate to VectorIORouter if available (which handles query rewriting)
        if self.vector_io_router is not None:
            return await self.vector_io_router.openai_search_vector_store(
                vector_store_id=vector_store_id,
                query=query,
                filters=filters,
                max_num_results=max_num_results,
                ranking_options=ranking_options,
                rewrite_query=rewrite_query,
                search_mode=search_mode,
            )

        # Fallback to direct provider call if VectorIORouter not available
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_search_vector_store(
            vector_store_id=vector_store_id,
64 changes: 50 additions & 14 deletions src/llama_stack/core/stack.py
@@ -16,7 +16,7 @@
from pydantic import BaseModel

from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackConfig, VectorStoresConfig
from llama_stack.core.datatypes import Provider, QualifiedModel, SafetyConfig, StackConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@@ -221,35 +221,71 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
    if vector_stores_config is None:
        return

    default_embedding_model = vector_stores_config.default_embedding_model
    if default_embedding_model is None:
        return
    # Validate default embedding model
    if vector_stores_config.default_embedding_model is not None:
        await _validate_embedding_model(vector_stores_config.default_embedding_model, impls)

    # Validate rewrite query params
    if vector_stores_config.rewrite_query_params:
        if vector_stores_config.rewrite_query_params.model:
            await _validate_rewrite_query_model(vector_stores_config.rewrite_query_params.model, impls)
        if "{query}" not in vector_stores_config.rewrite_query_params.prompt:
            raise ValueError("'{query}' placeholder is required in the prompt template")


    provider_id = default_embedding_model.provider_id
    model_id = default_embedding_model.model_id
    default_model_id = f"{provider_id}/{model_id}"
async def _validate_embedding_model(embedding_model: QualifiedModel, impls: dict[Api, Any]) -> None:
    """Validate that an embedding model exists and has required metadata."""
    provider_id = embedding_model.provider_id
    model_id = embedding_model.model_id
    model_identifier = f"{provider_id}/{model_id}"

    if Api.models not in impls:
        raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
        raise ValueError(f"Models API is not available but vector_stores config requires model '{model_identifier}'")

    models_impl = impls[Api.models]
    response = await models_impl.list_models()
    models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}

    default_model = models_list.get(default_model_id)
    if default_model is None:
        raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
    model = models_list.get(model_identifier)
    if model is None:
        raise ValueError(
            f"Embedding model '{model_identifier}' not found. Available embedding models: {list(models_list.keys())}"
        )

    embedding_dimension = default_model.metadata.get("embedding_dimension")
    embedding_dimension = model.metadata.get("embedding_dimension")
    if embedding_dimension is None:
        raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
        raise ValueError(f"Embedding model '{model_identifier}' is missing 'embedding_dimension' in metadata")

    try:
        int(embedding_dimension)
    except ValueError as err:
        raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err

    logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
    logger.debug(f"Validated embedding model: {model_identifier} (dimension: {embedding_dimension})")


async def _validate_rewrite_query_model(rewrite_query_model: QualifiedModel, impls: dict[Api, Any]) -> None:
    """Validate that a rewrite query model exists and is accessible."""
    provider_id = rewrite_query_model.provider_id
    model_id = rewrite_query_model.model_id
    model_identifier = f"{provider_id}/{model_id}"

    if Api.models not in impls:
        raise ValueError(
            f"Models API is not available but vector_stores config requires rewrite query model '{model_identifier}'"
        )

    models_impl = impls[Api.models]
    response = await models_impl.list_models()
    llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}

    model = llm_models_list.get(model_identifier)
    if model is None:
        raise ValueError(
            f"Rewrite query model '{model_identifier}' not found. Available LLM models: {list(llm_models_list.keys())}"
        )

    logger.debug(f"Validated rewrite query model: {model_identifier}")


async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
4 changes: 4 additions & 0 deletions src/llama_stack/providers/utils/memory/__init__.py
@@ -3,3 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .constants import DEFAULT_QUERY_REWRITE_PROMPT

__all__ = ["DEFAULT_QUERY_REWRITE_PROMPT"]
8 changes: 8 additions & 0 deletions src/llama_stack/providers/utils/memory/constants.py
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Default prompt template for query rewriting in vector search
DEFAULT_QUERY_REWRITE_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
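For reference, this is how the template expands once the router substitutes the {query} placeholder; the query string is an arbitrary example and the expected output is shown in comments.

from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT

# The router fills {query} before sending the prompt to the configured LLM.
prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query="how do I enable vector store query rewriting?")
print(prompt)
# Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:
#
# how do I enable vector store query rewriting?
#
# Improved query: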
(additional file changed in this PR; its name is not shown in this view)
@@ -592,7 +592,11 @@ async def openai_search_vector_store(
            str | None
        ) = "vector",  # Using str instead of Literal due to OpenAPI schema generator limitations
    ) -> VectorStoreSearchResponsePage:
        """Search for chunks in a vector store."""
        """Search for chunks in a vector store.

        Note: Query rewriting is handled at the router level, not here.
        The rewrite_query parameter is kept for API compatibility but is ignored.
        """
        max_num_results = max_num_results or 10

        # Validate search_mode