diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 91726e44..5199421b 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -499,31 +499,16 @@ def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False) return all_vector_ids, metadata def get_routes(self) -> List[Tuple]: - """ - Gets a list of route and utterance objects currently stored in the index, including additional metadata. + """Gets a list of route and utterance objects currently stored in the + index, including additional metadata. - Returns: - List[Tuple]: A list of tuples, each containing route, utterance, function schema and additional metadata. + :return: A list of tuples, each containing route, utterance, function + schema and additional metadata. + :rtype: List[Tuple] """ _, metadata = self._get_all(include_metadata=True) - route_tuples = [ - ( - data.get("sr_route", ""), - data.get("sr_utterance", ""), - ( - json.loads(data["sr_function_schema"]) - if data.get("sr_function_schema", "") - else {} - ), - { - key: value - for key, value in data.items() - if key not in ["sr_route", "sr_utterance", "sr_function_schema"] - }, - ) - for data in metadata - ] - return route_tuples # type: ignore + route_tuples = parse_route_info(metadata=metadata) + return route_tuples def delete(self, route_name: str): route_vec_ids = self._get_route_ids(route_name=route_name) @@ -553,8 +538,7 @@ def query( route_filter: Optional[List[str]] = None, **kwargs: Any, ) -> Tuple[np.ndarray, List[str]]: - """ - Search the index for the query vector and return the top_k results. + """Search the index for the query vector and return the top_k results. :param vector: The query vector to search for. :type vector: np.ndarray @@ -633,11 +617,11 @@ async def aquery( return np.array(scores), route_names async def aget_routes(self) -> list[tuple]: - """ - Asynchronously get a list of route and utterance objects currently stored in the index. + """Asynchronously get a list of route and utterance objects currently + stored in the index. - Returns: - List[Tuple]: A list of (route_name, utterance) objects. + :return: A list of (route_name, utterance) objects. + :rtype: List[Tuple] """ if self.async_client is None or self.host is None: raise ValueError("Async client or host are not initialized.") @@ -703,8 +687,15 @@ async def _async_describe_index(self, name: str): async def _async_get_all( self, prefix: Optional[str] = None, include_metadata: bool = False ) -> tuple[list[str], list[dict]]: - """ - Retrieves all vector IDs from the Pinecone index using pagination asynchronously. + """Retrieves all vector IDs from the Pinecone index using pagination + asynchronously. + + :param prefix: The prefix to filter the vectors by. + :type prefix: Optional[str] + :param include_metadata: Whether to include metadata in the response. + :type include_metadata: bool + :return: A tuple containing a list of vector IDs and a list of metadata dictionaries. + :rtype: tuple[list[str], list[dict]] """ if self.index is None: raise ValueError("Index is None, could not retrieve vector IDs.") @@ -754,8 +745,13 @@ async def _async_get_all( return all_vector_ids, metadata async def _async_fetch_metadata(self, vector_id: str) -> dict: - """ - Fetch metadata for a single vector ID asynchronously using the async_client. + """Fetch metadata for a single vector ID asynchronously using the + async_client. + + :param vector_id: The ID of the vector to fetch metadata for. + :type vector_id: str + :return: A dictionary containing the metadata for the vector. + :rtype: dict """ url = f"https://{self.host}/vectors/fetch" @@ -786,31 +782,45 @@ async def _async_fetch_metadata(self, vector_id: str) -> dict: ) async def _async_get_routes(self) -> List[Tuple]: - """ - Asynchronously gets a list of route and utterance objects currently stored in the index, including additional metadata. + """Asynchronously gets a list of route and utterance objects currently + stored in the index, including additional metadata. - Returns: - List[Tuple]: A list of tuples, each containing route, utterance, function schema and additional metadata. + :return: A list of tuples, each containing route, utterance, function + schema and additional metadata. + :rtype: List[Tuple] """ _, metadata = await self._async_get_all(include_metadata=True) - route_info = [ - ( - data.get("sr_route", ""), - data.get("sr_utterance", ""), - ( - json.loads(data["sr_function_schema"]) - if data.get("sr_function_schema", "") - else {} - ), - { - key: value - for key, value in data.items() - if key not in ["sr_route", "sr_utterance", "sr_function_schema"] - }, - ) - for data in metadata - ] + route_info = parse_route_info(metadata=metadata) return route_info # type: ignore def __len__(self): return self.index.describe_index_stats()["total_vector_count"] + + +def parse_route_info(metadata: List[Dict[str, Any]]) -> List[Tuple]: + """Parses metadata from Pinecone index to extract route, utterance, function + schema and additional metadata. + + :param metadata: List of metadata dictionaries. + :type metadata: List[Dict[str, Any]] + :return: A list of tuples, each containing route, utterance, function schema and additional metadata. + :rtype: List[Tuple] + """ + route_info = [] + for record in metadata: + sr_route = record.get("sr_route", "") + sr_utterance = record.get("sr_utterance", "") + sr_function_schema = json.loads(record.get("sr_function_schema", "{}")) + if sr_function_schema == {}: + sr_function_schema = None + + additional_metadata = { + key: value + for key, value in record.items() + if key not in ["sr_route", "sr_utterance", "sr_function_schema"] + } + # TODO: Not a fan of tuple packing here + route_info.append( + (sr_route, sr_utterance, sr_function_schema, additional_metadata) + ) + return route_info diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 0bf7d99f..42179459 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -544,15 +544,19 @@ def _add_and_sync_routes(self, routes: List[Route]): ) # Update local route layer state - self.routes = [ - Route( - name=route, - utterances=data.get("utterances", []), - function_schemas=[data.get("function_schemas", None)], - metadata=data.get("metadata", {}), + self.routes = [] + for route, data in layer_routes_dict.items(): + function_schemas = data.get("function_schemas", None) + if function_schemas is not None: + function_schemas = [function_schemas] + self.routes.append( + Route( + name=route, + utterances=data.get("utterances", []), + function_schemas=function_schemas, + metadata=data.get("metadata", {}), + ) ) - for route, data in layer_routes_dict.items() - ] def _extract_routes_details( self, routes: List[Route], include_metadata: bool = False diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index f22f409e..dfff8096 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -92,7 +92,7 @@ def __call__( raise ValueError("OpenAI client is not initialized.") try: tools: Union[List[Dict[str, Any]], NotGiven] = ( - function_schemas if function_schemas is not None else NOT_GIVEN + function_schemas if function_schemas else NOT_GIVEN ) completion = self.client.chat.completions.create( @@ -184,8 +184,6 @@ def extract_function_inputs( raise Exception("No output generated for extract function input") output = output.replace("'", '"') function_inputs = json.loads(output) - logger.info(f"Function inputs: {function_inputs}") - logger.info(f"function_schemas: {function_schemas}") if not self._is_valid_inputs(function_inputs, function_schemas): raise ValueError("Invalid inputs") return function_inputs @@ -203,7 +201,6 @@ async def async_extract_function_inputs( raise Exception("No output generated for extract function input") output = output.replace("'", '"') function_inputs = json.loads(output) - logger.info(f"OpenAI => Function Inputs: {function_inputs}") if not self._is_valid_inputs(function_inputs, function_schemas): raise ValueError("Invalid inputs") return function_inputs