Skip to content

Commit

Permalink
Merge pull request #416 from aurelio-labs/james/func-schema-bug-pinecone
Browse files Browse the repository at this point in the history
fix: bug in synced dynamic routes
  • Loading branch information
jamescalam committed Sep 6, 2024
2 parents 555290d + be86fe0 commit 38eddff
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 65 deletions.
116 changes: 63 additions & 53 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,31 +499,16 @@ def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False)
return all_vector_ids, metadata

def get_routes(self) -> List[Tuple]:
"""
Gets a list of route and utterance objects currently stored in the index, including additional metadata.
"""Gets a list of route and utterance objects currently stored in the
index, including additional metadata.
Returns:
List[Tuple]: A list of tuples, each containing route, utterance, function schema and additional metadata.
:return: A list of tuples, each containing route, utterance, function
schema and additional metadata.
:rtype: List[Tuple]
"""
_, metadata = self._get_all(include_metadata=True)
route_tuples = [
(
data.get("sr_route", ""),
data.get("sr_utterance", ""),
(
json.loads(data["sr_function_schema"])
if data.get("sr_function_schema", "")
else {}
),
{
key: value
for key, value in data.items()
if key not in ["sr_route", "sr_utterance", "sr_function_schema"]
},
)
for data in metadata
]
return route_tuples # type: ignore
route_tuples = parse_route_info(metadata=metadata)
return route_tuples

def delete(self, route_name: str):
route_vec_ids = self._get_route_ids(route_name=route_name)
Expand Down Expand Up @@ -553,8 +538,7 @@ def query(
route_filter: Optional[List[str]] = None,
**kwargs: Any,
) -> Tuple[np.ndarray, List[str]]:
"""
Search the index for the query vector and return the top_k results.
"""Search the index for the query vector and return the top_k results.
:param vector: The query vector to search for.
:type vector: np.ndarray
Expand Down Expand Up @@ -633,11 +617,11 @@ async def aquery(
return np.array(scores), route_names

async def aget_routes(self) -> list[tuple]:
"""
Asynchronously get a list of route and utterance objects currently stored in the index.
"""Asynchronously get a list of route and utterance objects currently
stored in the index.
Returns:
List[Tuple]: A list of (route_name, utterance) objects.
:return: A list of (route_name, utterance) objects.
:rtype: List[Tuple]
"""
if self.async_client is None or self.host is None:
raise ValueError("Async client or host are not initialized.")
Expand Down Expand Up @@ -703,8 +687,15 @@ async def _async_describe_index(self, name: str):
async def _async_get_all(
self, prefix: Optional[str] = None, include_metadata: bool = False
) -> tuple[list[str], list[dict]]:
"""
Retrieves all vector IDs from the Pinecone index using pagination asynchronously.
"""Retrieves all vector IDs from the Pinecone index using pagination
asynchronously.
:param prefix: The prefix to filter the vectors by.
:type prefix: Optional[str]
:param include_metadata: Whether to include metadata in the response.
:type include_metadata: bool
:return: A tuple containing a list of vector IDs and a list of metadata dictionaries.
:rtype: tuple[list[str], list[dict]]
"""
if self.index is None:
raise ValueError("Index is None, could not retrieve vector IDs.")
Expand Down Expand Up @@ -754,8 +745,13 @@ async def _async_get_all(
return all_vector_ids, metadata

async def _async_fetch_metadata(self, vector_id: str) -> dict:
"""
Fetch metadata for a single vector ID asynchronously using the async_client.
"""Fetch metadata for a single vector ID asynchronously using the
async_client.
:param vector_id: The ID of the vector to fetch metadata for.
:type vector_id: str
:return: A dictionary containing the metadata for the vector.
:rtype: dict
"""
url = f"https://{self.host}/vectors/fetch"

Expand Down Expand Up @@ -786,31 +782,45 @@ async def _async_fetch_metadata(self, vector_id: str) -> dict:
)

async def _async_get_routes(self) -> List[Tuple]:
"""
Asynchronously gets a list of route and utterance objects currently stored in the index, including additional metadata.
"""Asynchronously gets a list of route and utterance objects currently
stored in the index, including additional metadata.
Returns:
List[Tuple]: A list of tuples, each containing route, utterance, function schema and additional metadata.
:return: A list of tuples, each containing route, utterance, function
schema and additional metadata.
:rtype: List[Tuple]
"""
_, metadata = await self._async_get_all(include_metadata=True)
route_info = [
(
data.get("sr_route", ""),
data.get("sr_utterance", ""),
(
json.loads(data["sr_function_schema"])
if data.get("sr_function_schema", "")
else {}
),
{
key: value
for key, value in data.items()
if key not in ["sr_route", "sr_utterance", "sr_function_schema"]
},
)
for data in metadata
]
route_info = parse_route_info(metadata=metadata)
return route_info # type: ignore

def __len__(self):
return self.index.describe_index_stats()["total_vector_count"]


def parse_route_info(metadata: List[Dict[str, Any]]) -> List[Tuple]:
"""Parses metadata from Pinecone index to extract route, utterance, function
schema and additional metadata.
:param metadata: List of metadata dictionaries.
:type metadata: List[Dict[str, Any]]
:return: A list of tuples, each containing route, utterance, function schema and additional metadata.
:rtype: List[Tuple]
"""
route_info = []
for record in metadata:
sr_route = record.get("sr_route", "")
sr_utterance = record.get("sr_utterance", "")
sr_function_schema = json.loads(record.get("sr_function_schema", "{}"))
if sr_function_schema == {}:
sr_function_schema = None

additional_metadata = {
key: value
for key, value in record.items()
if key not in ["sr_route", "sr_utterance", "sr_function_schema"]
}
# TODO: Not a fan of tuple packing here
route_info.append(
(sr_route, sr_utterance, sr_function_schema, additional_metadata)
)
return route_info
20 changes: 12 additions & 8 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,15 +544,19 @@ def _add_and_sync_routes(self, routes: List[Route]):
)

# Update local route layer state
self.routes = [
Route(
name=route,
utterances=data.get("utterances", []),
function_schemas=[data.get("function_schemas", None)],
metadata=data.get("metadata", {}),
self.routes = []
for route, data in layer_routes_dict.items():
function_schemas = data.get("function_schemas", None)
if function_schemas is not None:
function_schemas = [function_schemas]
self.routes.append(
Route(
name=route,
utterances=data.get("utterances", []),
function_schemas=function_schemas,
metadata=data.get("metadata", {}),
)
)
for route, data in layer_routes_dict.items()
]

def _extract_routes_details(
self, routes: List[Route], include_metadata: bool = False
Expand Down
5 changes: 1 addition & 4 deletions semantic_router/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __call__(
raise ValueError("OpenAI client is not initialized.")
try:
tools: Union[List[Dict[str, Any]], NotGiven] = (
function_schemas if function_schemas is not None else NOT_GIVEN
function_schemas if function_schemas else NOT_GIVEN
)

completion = self.client.chat.completions.create(
Expand Down Expand Up @@ -184,8 +184,6 @@ def extract_function_inputs(
raise Exception("No output generated for extract function input")
output = output.replace("'", '"')
function_inputs = json.loads(output)
logger.info(f"Function inputs: {function_inputs}")
logger.info(f"function_schemas: {function_schemas}")
if not self._is_valid_inputs(function_inputs, function_schemas):
raise ValueError("Invalid inputs")
return function_inputs
Expand All @@ -203,7 +201,6 @@ async def async_extract_function_inputs(
raise Exception("No output generated for extract function input")
output = output.replace("'", '"')
function_inputs = json.loads(output)
logger.info(f"OpenAI => Function Inputs: {function_inputs}")
if not self._is_valid_inputs(function_inputs, function_schemas):
raise ValueError("Invalid inputs")
return function_inputs
Expand Down

0 comments on commit 38eddff

Please sign in to comment.