diff --git a/.gitignore b/.gitignore
index f571c929..094c61f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,6 +27,7 @@
 node_modules
 package-lock.json
 package.json
 test.ipynb
+test_sync.ipynb
 ```
 # docs
diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index a23f92bf..5ddb586e 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -26,6 +26,8 @@ def add(
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[Any],
+        function_schemas: Optional[List[Dict[str, Any]]] = None,
+        metadata_list: List[Dict[str, Any]] = [],
     ):
         """
         Add embeddings to the index.
@@ -109,7 +111,12 @@ def delete_index(self):
         raise NotImplementedError("This method should be implemented by subclasses.")

     def _sync_index(
-        self, local_route_names: List[str], local_utterances: List[str], dimensions: int
+        self,
+        local_route_names: List[str],
+        local_utterances: List[str],
+        local_function_schemas: List[Dict[str, Any]],
+        local_metadata: List[Dict[str, Any]],
+        dimensions: int,
     ):
         """
         Synchronize the local index with the remote index based on the specified mode.
@@ -117,9 +124,9 @@
         - "error": Raise an error if local and remote are not synchronized.
         - "remote": Take remote as the source of truth and update local to align.
         - "local": Take local as the source of truth and update remote to align.
-        - "merge-force-remote": Merge both local and remote taking only remote routes utterances when a route with same route name is present both locally and remotely.
-        - "merge-force-local": Merge both local and remote taking only local routes utterances when a route with same route name is present both locally and remotely.
-        - "merge": Merge both local and remote, merging also local and remote utterances when a route with same route name is present both locally and remotely.
+        - "merge-force-remote": Merge both local and remote taking only remote routes features when a route with same route name is present both locally and remotely.
+        - "merge-force-local": Merge both local and remote taking only local routes features when a route with same route name is present both locally and remotely.
+        - "merge": Merge both local and remote, merging also local and remote features when a route with same route name is present both locally and remotely.

         This method should be implemented by subclasses.
         """
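
A minimal usage sketch of the widened add() signature introduced above. The LocalIndex instance, embedding values, and metadata keys are illustrative assumptions, not part of the diff:

from semantic_router.index.local import LocalIndex

# Hypothetical data: two utterances for one route, each paired with a function
# schema slot and a per-route metadata dict, matching the new keyword parameters.
index = LocalIndex()
index.add(
    embeddings=[[0.1, 0.2, 0.3], [0.2, 0.1, 0.4]],
    routes=["politics", "politics"],
    utterances=[
        "isn't politics the best thing ever",
        "why don't you tell me about your political opinions",
    ],
    function_schemas=[{}, {}],
    metadata_list=[{"category": "chitchat"}, {"category": "chitchat"}],
)
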
""" diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index f2398618..09e23ffc 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -5,6 +5,7 @@ from semantic_router.index.base import BaseIndex from semantic_router.linear import similarity_matrix, top_scores from semantic_router.utils.logger import logger +from typing import Any class LocalIndex(BaseIndex): @@ -26,6 +27,8 @@ def add( embeddings: List[List[float]], routes: List[str], utterances: List[str], + function_schemas: Optional[List[Dict[str, Any]]] = None, + metadata_list: List[Dict[str, Any]] = [], ): embeds = np.array(embeddings) # type: ignore routes_arr = np.array(routes) @@ -47,7 +50,12 @@ def _remove_and_sync(self, routes_to_delete: dict): logger.warning("Sync remove is not implemented for LocalIndex.") def _sync_index( - self, local_route_names: List[str], local_utterances: List[str], dimensions: int + self, + local_route_names: List[str], + local_utterances: List[str], + local_function_schemas: List[Dict[str, Any]], + local_metadata: List[Dict[str, Any]], + dimensions: int, ): if self.sync is not None: logger.error("Sync remove is not implemented for LocalIndex.") diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 33b867be..174fb49a 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -3,7 +3,9 @@ import hashlib import os import time -from typing import Any, Dict, List, Optional, Tuple, Union +import json + +from typing import Any, Dict, List, Optional, Union, Tuple import numpy as np from pydantic.v1 import BaseModel, Field @@ -21,6 +23,8 @@ class PineconeRecord(BaseModel): values: List[float] route: str utterance: str + function_schema: str + metadata: Dict[str, Any] = {} # Additional metadata dictionary def __init__(self, **data): super().__init__(**data) @@ -28,12 +32,19 @@ def __init__(self, **data): # Use SHA-256 for a more secure hash utterance_id = hashlib.sha256(self.utterance.encode()).hexdigest() self.id = f"{clean_route}#{utterance_id}" + self.metadata.update( + { + "sr_route": self.route, + "sr_utterance": self.utterance, + "sr_function_schema": self.function_schema, + } + ) def to_dict(self): return { "id": self.id, "values": self.values, - "metadata": {"sr_route": self.route, "sr_utterance": self.utterance}, + "metadata": self.metadata, } @@ -205,21 +216,48 @@ async def _init_async_index(self, force_create: bool = False): self.host = index_stats["host"] if index_stats else None def _sync_index( - self, local_route_names: List[str], local_utterances: List[str], dimensions: int - ): + self, + local_route_names: List[str], + local_utterances_list: List[str], + local_function_schemas_list: List[Dict[str, Any]], + local_metadata_list: List[Dict[str, Any]], + dimensions: int, + ) -> Tuple[List, List, Dict]: if self.index is None: self.dimensions = self.dimensions or dimensions self.index = self._init_index(force_create=True) remote_routes = self.get_routes() - remote_dict: dict = {route: set() for route, _ in remote_routes} - for route, utterance in remote_routes: - remote_dict[route].add(utterance) - - local_dict: dict = {route: set() for route in local_route_names} - for route, utterance in zip(local_route_names, local_utterances): - local_dict[route].add(utterance) + # Create remote dictionary for storing utterances and metadata + remote_dict: Dict[str, Dict[str, Any]] = { + route: { + "utterances": set(), + "function_schemas": function_schemas, + "metadata": metadata, + } + for 
@@ -205,21 +216,48 @@ async def _init_async_index(self, force_create: bool = False):
         self.host = index_stats["host"] if index_stats else None

     def _sync_index(
-        self, local_route_names: List[str], local_utterances: List[str], dimensions: int
-    ):
+        self,
+        local_route_names: List[str],
+        local_utterances_list: List[str],
+        local_function_schemas_list: List[Dict[str, Any]],
+        local_metadata_list: List[Dict[str, Any]],
+        dimensions: int,
+    ) -> Tuple[List, List, Dict]:
         if self.index is None:
             self.dimensions = self.dimensions or dimensions
             self.index = self._init_index(force_create=True)

         remote_routes = self.get_routes()

-        remote_dict: dict = {route: set() for route, _ in remote_routes}
-        for route, utterance in remote_routes:
-            remote_dict[route].add(utterance)
-
-        local_dict: dict = {route: set() for route in local_route_names}
-        for route, utterance in zip(local_route_names, local_utterances):
-            local_dict[route].add(utterance)
+        # Create remote dictionary for storing utterances and metadata
+        remote_dict: Dict[str, Dict[str, Any]] = {
+            route: {
+                "utterances": set(),
+                "function_schemas": function_schemas,
+                "metadata": metadata,
+            }
+            for route, utterance, function_schemas, metadata in remote_routes
+        }
+        for route, utterance, function_schemas, metadata in remote_routes:
+            remote_dict[route]["utterances"].add(utterance)
+
+        # Create local dictionary for storing utterances and metadata
+        local_dict: Dict[str, Dict[str, Any]] = {}
+        for route, utterance, function_schemas, metadata in zip(
+            local_route_names,
+            local_utterances_list,
+            local_function_schemas_list,
+            local_metadata_list,
+        ):
+            if route not in local_dict:
+                local_dict[route] = {
+                    "utterances": set(),
+                    "function_schemas": function_schemas,
+                    "metadata": metadata,
+                }
+            local_dict[route]["utterances"].add(utterance)
+            local_dict[route]["function_schemas"] = function_schemas
+            local_dict[route]["metadata"] = metadata

         all_routes = set(remote_dict.keys()).union(local_dict.keys())
@@ -228,24 +266,47 @@
         layer_routes = {}

         for route in all_routes:
-            local_utterances = local_dict.get(route, set())
-            remote_utterances = remote_dict.get(route, set())
+            local_utterances = local_dict.get(route, {}).get("utterances", set())
+            remote_utterances = remote_dict.get(route, {}).get("utterances", set())
+            local_function_schemas = local_dict.get(route, {}).get(
+                "function_schemas", {}
+            )
+            remote_function_schemas = remote_dict.get(route, {}).get(
+                "function_schemas", {}
+            )
+            local_metadata = local_dict.get(route, {}).get("metadata", {})
+            remote_metadata = remote_dict.get(route, {}).get("metadata", {})

-            if not local_utterances and not remote_utterances:
-                continue
+            utterances_to_include = set()
+
+            metadata_changed = local_metadata != remote_metadata
+            function_schema_changed = local_function_schemas != remote_function_schemas

             if self.sync == "error":
-                if local_utterances != remote_utterances:
+                if (
+                    local_utterances != remote_utterances
+                    or local_function_schemas != remote_function_schemas
+                    or local_metadata != remote_metadata
+                ):
                     raise ValueError(
                         f"Synchronization error: Differences found in route '{route}'"
                     )
-                utterances_to_include: set = set()
+
                 if local_utterances:
-                    layer_routes[route] = list(local_utterances)
+                    layer_routes[route] = {
+                        "utterances": list(local_utterances),
+                        "function_schemas": local_function_schemas,
+                        "metadata": local_metadata,
+                    }
+
             elif self.sync == "remote":
-                utterances_to_include = set()
                 if remote_utterances:
-                    layer_routes[route] = list(remote_utterances)
+                    layer_routes[route] = {
+                        "utterances": list(remote_utterances),
+                        "function_schemas": remote_function_schemas,
+                        "metadata": remote_metadata,
+                    }
+
             elif self.sync == "local":
                 utterances_to_include = local_utterances - remote_utterances
                 routes_to_delete.extend(
@@ -256,16 +317,29 @@
                     ]
                 )
                 if local_utterances:
-                    layer_routes[route] = list(local_utterances)
+                    layer_routes[route] = {
+                        "utterances": list(local_utterances),
+                        "function_schemas": local_function_schemas,
+                        "metadata": local_metadata,
+                    }
+
             elif self.sync == "merge-force-remote":
                 if route in local_dict and route not in remote_dict:
-                    utterances_to_include = set(local_utterances)
+                    utterances_to_include = local_utterances
                     if local_utterances:
-                        layer_routes[route] = list(local_utterances)
+                        layer_routes[route] = {
+                            "utterances": list(local_utterances),
+                            "function_schemas": local_function_schemas,
+                            "metadata": local_metadata,
+                        }
                 else:
-                    utterances_to_include = set()
                     if remote_utterances:
-                        layer_routes[route] = list(remote_utterances)
+                        layer_routes[route] = {
+                            "utterances": list(remote_utterances),
+                            "function_schemas": remote_function_schemas,
+                            "metadata": remote_metadata,
+                        }
+
self.sync == "merge-force-local": if route in local_dict: utterances_to_include = local_utterances - remote_utterances @@ -277,22 +351,56 @@ def _sync_index( ] ) if local_utterances: - layer_routes[route] = local_utterances + layer_routes[route] = { + "utterances": list(local_utterances), + "function_schemas": local_function_schemas, + "metadata": local_metadata, + } else: - utterances_to_include = set() if remote_utterances: - layer_routes[route] = list(remote_utterances) + layer_routes[route] = { + "utterances": list(remote_utterances), + "function_schemas": remote_function_schemas, + "metadata": remote_metadata, + } + elif self.sync == "merge": utterances_to_include = local_utterances - remote_utterances if local_utterances or remote_utterances: - layer_routes[route] = list( - remote_utterances.union(local_utterances) - ) + # Here metadata are merged, with local metadata taking precedence for same keys + merged_metadata = {**remote_metadata, **local_metadata} + merged_function_schemas = { + **remote_function_schemas, + **local_function_schemas, + } + layer_routes[route] = { + "utterances": list(remote_utterances.union(local_utterances)), + "function_schemas": merged_function_schemas, + "metadata": merged_metadata, + } + else: raise ValueError("Invalid sync mode specified") - for utterance in utterances_to_include: - routes_to_add.append((route, utterance)) + # Add utterances if metadata has changed or if there are new utterances + if (metadata_changed or function_schema_changed) and self.sync in [ + "local", + "merge-force-local", + ]: + for utterance in local_utterances: + routes_to_add.append( + (route, utterance, local_function_schemas, local_metadata) + ) + if (metadata_changed or function_schema_changed) and self.sync == "merge": + for utterance in local_utterances: + routes_to_add.append( + (route, utterance, merged_function_schemas, merged_metadata) + ) + elif utterances_to_include: + for utterance in utterances_to_include: + routes_to_add.append( + (route, utterance, local_function_schemas, local_metadata) + ) return routes_to_add, routes_to_delete, layer_routes @@ -308,6 +416,8 @@ def add( embeddings: List[List[float]], routes: List[str], utterances: List[str], + function_schemas: Optional[List[Dict[str, Any]]] = None, + metadata_list: List[Dict[str, Any]] = [], batch_size: int = 100, ): """Add vectors to Pinecone in batches.""" @@ -316,8 +426,16 @@ def add( self.index = self._init_index(force_create=True) vectors_to_upsert = [ - PineconeRecord(values=vector, route=route, utterance=utterance).to_dict() - for vector, route, utterance in zip(embeddings, routes, utterances) + PineconeRecord( + values=vector, + route=route, + utterance=utterance, + function_schema=json.dumps(function_schema), + metadata=metadata, + ).to_dict() + for vector, route, utterance, function_schema, metadata in zip( + embeddings, routes, utterances, function_schemas, metadata_list # type: ignore + ) ] for i in range(0, len(vectors_to_upsert), batch_size): @@ -382,15 +500,30 @@ def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False) def get_routes(self) -> List[Tuple]: """ - Gets a list of route and utterance objects currently stored in the index. + Gets a list of route and utterance objects currently stored in the index, including additional metadata. Returns: - List[Tuple]: A list of (route_name, utterance) objects. + List[Tuple]: A list of tuples, each containing route, utterance, function schema and additional metadata. 
""" - # Get all records _, metadata = self._get_all(include_metadata=True) - route_tuples = [(x["sr_route"], x["sr_utterance"]) for x in metadata] - return route_tuples + route_tuples = [ + ( + data.get("sr_route", ""), + data.get("sr_utterance", ""), + ( + json.loads(data["sr_function_schema"]) + if data.get("sr_function_schema", "") + else {} + ), + { + key: value + for key, value in data.items() + if key not in ["sr_route", "sr_utterance", "sr_function_schema"] + }, + ) + for data in metadata + ] + return route_tuples # type: ignore def delete(self, route_name: str): route_vec_ids = self._get_route_ids(route_name=route_name) @@ -652,16 +785,32 @@ async def _async_fetch_metadata(self, vector_id: str) -> dict: response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {}) ) - async def _async_get_routes(self) -> list[tuple]: + async def _async_get_routes(self) -> List[Tuple]: """ - Gets a list of route and utterance objects currently stored in the index. + Asynchronously gets a list of route and utterance objects currently stored in the index, including additional metadata. Returns: - List[Tuple]: A list of (route_name, utterance) objects. + List[Tuple]: A list of tuples, each containing route, utterance, function schema and additional metadata. """ _, metadata = await self._async_get_all(include_metadata=True) - route_tuples = [(x["sr_route"], x["sr_utterance"]) for x in metadata] - return route_tuples + route_info = [ + ( + data.get("sr_route", ""), + data.get("sr_utterance", ""), + ( + json.loads(data["sr_function_schema"]) + if data["sr_function_schema"] + else {} + ), + { + key: value + for key, value in data.items() + if key not in ["sr_route", "sr_utterance", "sr_function_schema"] + }, + ) + for data in metadata + ] + return route_info # type: ignore def __len__(self): return self.index.describe_index_stats()["total_vector_count"] diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py index 52afdeec..ff63ec09 100644 --- a/semantic_router/index/postgres.py +++ b/semantic_router/index/postgres.py @@ -255,7 +255,12 @@ def _check_embeddings_dimensions(self) -> bool: raise ValueError("No comment found for the 'vector' column.") def add( - self, embeddings: List[List[float]], routes: List[str], utterances: List[Any] + self, + embeddings: List[List[float]], + routes: List[str], + utterances: List[str], + function_schemas: Optional[List[Dict[str, Any]]] = None, + metadata_list: List[Dict[str, Any]] = [], ) -> None: """ Adds vectors to the index. 
diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py
index 52afdeec..ff63ec09 100644
--- a/semantic_router/index/postgres.py
+++ b/semantic_router/index/postgres.py
@@ -255,7 +255,12 @@ def _check_embeddings_dimensions(self) -> bool:
         raise ValueError("No comment found for the 'vector' column.")

     def add(
-        self, embeddings: List[List[float]], routes: List[str], utterances: List[Any]
+        self,
+        embeddings: List[List[float]],
+        routes: List[str],
+        utterances: List[str],
+        function_schemas: Optional[List[Dict[str, Any]]] = None,
+        metadata_list: List[Dict[str, Any]] = [],
     ) -> None:
         """
         Adds vectors to the index.
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index 3077da33..b372c49c 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -165,7 +165,12 @@ def _remove_and_sync(self, routes_to_delete: dict):
         logger.error("Sync remove is not implemented for QdrantIndex.")

     def _sync_index(
-        self, local_route_names: List[str], local_utterances: List[str], dimensions: int
+        self,
+        local_route_names: List[str],
+        local_utterances_list: List[str],
+        local_function_schemas: List[Dict[str, Any]],
+        local_metadata_list: List[Dict[str, Any]],
+        dimensions: int,
     ):
         if self.sync is not None:
             logger.error("Sync remove is not implemented for QdrantIndex.")
@@ -175,6 +180,8 @@ def add(
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[str],
+        function_schemas: Optional[List[Dict[str, Any]]] = None,
+        metadata_list: List[Dict[str, Any]] = [],
         batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
     ):
         self.dimensions = self.dimensions or len(embeddings[0])
- """ - self.routes = new_routes - def __str__(self): return ( f"RouteLayer(encoder={self.encoder}, " @@ -423,19 +417,20 @@ def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None): return cls(encoder=encoder, routes=config.routes, index=index) def add(self, route: Route): - logger.info(f"Adding `{route.name}` route") - # create embeddings - embeds = self.encoder(route.utterances) - # if route has no score_threshold, use default - if route.score_threshold is None: - route.score_threshold = self.score_threshold - - # add routes to the index + embedded_utterances = self.encoder(route.utterances) self.index.add( - embeddings=embeds, + embeddings=embedded_utterances, routes=[route.name] * len(route.utterances), utterances=route.utterances, + function_schemas=( + route.function_schemas * len(route.utterances) + if route.function_schemas + else [{}] * len(route.utterances) + ), + metadata_list=[route.metadata if route.metadata else {}] + * len(route.utterances), ) + self.routes.append(route) def list_route_names(self) -> List[str]: @@ -476,55 +471,104 @@ def _refresh_routes(self): self.routes.append(route) def _add_routes(self, routes: List[Route]): + if not routes: + logger.warning("No routes provided to add.") + return # create embeddings for all routes - route_names, all_utterances = self._extract_routes_details(routes) - embedded_utterances = self.encoder(all_utterances) - # create route array - # add everything to the index - self.index.add( - embeddings=embedded_utterances, - routes=route_names, - utterances=all_utterances, + route_names, all_utterances, all_function_schemas, all_metadata = ( + self._extract_routes_details(routes, include_metadata=True) ) + embedded_utterances = self.encoder(all_utterances) + try: + # Batch insertion into the index + self.index.add( + embeddings=embedded_utterances, + routes=route_names, + utterances=all_utterances, + function_schemas=all_function_schemas, + metadata_list=all_metadata, + ) + except Exception as e: + logger.error(f"Failed to add routes to the index: {e}") + raise Exception("Indexing error occurred") from e def _add_and_sync_routes(self, routes: List[Route]): # create embeddings for all routes and sync at startup with remote ones based on sync setting - local_route_names, local_utterances = self._extract_routes_details(routes) + local_route_names, local_utterances, local_function_schemas, local_metadata = ( + self._extract_routes_details(routes, include_metadata=True) + ) + routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index( - local_route_names=local_route_names, - local_utterances=local_utterances, + local_route_names, + local_utterances, + local_function_schemas, + local_metadata, dimensions=len(self.encoder(["dummy"])[0]), ) - layer_routes = [ - Route(name=route, utterances=layer_routes_dict[route]) - for route in layer_routes_dict.keys() - ] + logger.info(f"Routes to add: {routes_to_add}") + logger.info(f"Routes to delete: {routes_to_delete}") + logger.info(f"Layer routes: {layer_routes_dict}") - data_to_delete: dict = {} + data_to_delete = {} # type: ignore for route, utterance in routes_to_delete: data_to_delete.setdefault(route, []).append(utterance) self.index._remove_and_sync(data_to_delete) - all_utterances_to_add = [utt for _, utt in routes_to_add] + # Prepare data for addition + if routes_to_add: + ( + route_names_to_add, + all_utterances_to_add, + function_schemas_to_add, + metadata_to_add, + ) = map(list, zip(*routes_to_add)) + else: + ( + route_names_to_add, + all_utterances_to_add, 
@@ -476,55 +471,104 @@ def _refresh_routes(self):
             self.routes.append(route)

     def _add_routes(self, routes: List[Route]):
+        if not routes:
+            logger.warning("No routes provided to add.")
+            return
         # create embeddings for all routes
-        route_names, all_utterances = self._extract_routes_details(routes)
-        embedded_utterances = self.encoder(all_utterances)
-        # create route array
-        # add everything to the index
-        self.index.add(
-            embeddings=embedded_utterances,
-            routes=route_names,
-            utterances=all_utterances,
+        route_names, all_utterances, all_function_schemas, all_metadata = (
+            self._extract_routes_details(routes, include_metadata=True)
         )
+        embedded_utterances = self.encoder(all_utterances)
+        try:
+            # Batch insertion into the index
+            self.index.add(
+                embeddings=embedded_utterances,
+                routes=route_names,
+                utterances=all_utterances,
+                function_schemas=all_function_schemas,
+                metadata_list=all_metadata,
+            )
+        except Exception as e:
+            logger.error(f"Failed to add routes to the index: {e}")
+            raise Exception("Indexing error occurred") from e

     def _add_and_sync_routes(self, routes: List[Route]):
         # create embeddings for all routes and sync at startup with remote ones based on sync setting
-        local_route_names, local_utterances = self._extract_routes_details(routes)
+        local_route_names, local_utterances, local_function_schemas, local_metadata = (
+            self._extract_routes_details(routes, include_metadata=True)
+        )
+
         routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index(
-            local_route_names=local_route_names,
-            local_utterances=local_utterances,
+            local_route_names,
+            local_utterances,
+            local_function_schemas,
+            local_metadata,
             dimensions=len(self.encoder(["dummy"])[0]),
         )
-        layer_routes = [
-            Route(name=route, utterances=layer_routes_dict[route])
-            for route in layer_routes_dict.keys()
-        ]
+        logger.info(f"Routes to add: {routes_to_add}")
+        logger.info(f"Routes to delete: {routes_to_delete}")
+        logger.info(f"Layer routes: {layer_routes_dict}")

-        data_to_delete: dict = {}
+        data_to_delete = {}  # type: ignore
         for route, utterance in routes_to_delete:
             data_to_delete.setdefault(route, []).append(utterance)
         self.index._remove_and_sync(data_to_delete)

-        all_utterances_to_add = [utt for _, utt in routes_to_add]
+        # Prepare data for addition
+        if routes_to_add:
+            (
+                route_names_to_add,
+                all_utterances_to_add,
+                function_schemas_to_add,
+                metadata_to_add,
+            ) = map(list, zip(*routes_to_add))
+        else:
+            (
+                route_names_to_add,
+                all_utterances_to_add,
+                function_schemas_to_add,
+                metadata_to_add,
+            ) = ([], [], [], [])
+
         embedded_utterances_to_add = (
             self.encoder(all_utterances_to_add) if all_utterances_to_add else []
         )
-        route_names_to_add = [route for route, _, in routes_to_add]
-
         self.index.add(
             embeddings=embedded_utterances_to_add,
             routes=route_names_to_add,
             utterances=all_utterances_to_add,
+            function_schemas=function_schemas_to_add,
+            metadata_list=metadata_to_add,
         )

-        self._set_layer_routes(layer_routes)
+        # Update local route layer state
+        self.routes = [
+            Route(
+                name=route,
+                utterances=data.get("utterances", []),
+                function_schemas=[data.get("function_schemas", None)],
+                metadata=data.get("metadata", {}),
+            )
+            for route, data in layer_routes_dict.items()
+        ]

-    def _extract_routes_details(self, routes: List[Route]) -> Tuple:
+    def _extract_routes_details(
+        self, routes: List[Route], include_metadata: bool = False
+    ) -> Tuple:
         route_names = [route.name for route in routes for _ in route.utterances]
         utterances = [utterance for route in routes for utterance in route.utterances]
-        return route_names, utterances
+        function_schemas = [
+            route.function_schemas[0] if route.function_schemas is not None else {}
+            for route in routes
+            for _ in route.utterances
+        ]
+
+        if include_metadata:
+            metadata = [route.metadata for route in routes for _ in route.utterances]
+            return route_names, utterances, function_schemas, metadata
+        return route_names, utterances, function_schemas

     def _encode(self, text: str) -> Any:
         """Given some text, encode it."""
@@ -718,11 +762,15 @@ def fit(
         remote_routes = self.index.get_routes()
         # TODO Enhance by retrieving directly the vectors instead of embedding all utterances again
-        routes = [route_tuple[0] for route_tuple in remote_routes]
-        utterances = [route_tuple[1] for route_tuple in remote_routes]
+        routes, utterances, metadata = map(list, zip(*remote_routes))
         embeddings = self.encoder(utterances)
         self.index = LocalIndex()
-        self.index.add(embeddings=embeddings, routes=routes, utterances=utterances)
+        self.index.add(
+            embeddings=embeddings,
+            routes=routes,
+            utterances=utterances,
+            metadata_list=metadata,
+        )

         # convert inputs into array
         Xq: List[List[float]] = []
diff --git a/semantic_router/route.py b/semantic_router/route.py
index 3fc3f040..41fd0bf2 100644
--- a/semantic_router/route.py
+++ b/semantic_router/route.py
@@ -50,6 +50,7 @@ class Route(BaseModel):
     function_schemas: Optional[List[Dict[str, Any]]] = None
     llm: Optional[BaseLLM] = None
     score_threshold: Optional[float] = None
+    metadata: Optional[Dict[str, Any]] = {}

     class Config:
         arbitrary_types_allowed = True
diff --git a/tests/unit/test_route.py b/tests/unit/test_route.py
index b21ffd87..fe202181 100644
--- a/tests/unit/test_route.py
+++ b/tests/unit/test_route.py
@@ -127,6 +127,7 @@ def test_to_dict(self):
             "function_schemas": None,
             "llm": None,
             "score_threshold": None,
+            "metadata": {},
         }
         assert route.to_dict() == expected_dict
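
Finally, the new metadata field on Route defaults to an empty dict and is carried through Route.to_dict(), which is what the updated unit test asserts. A small check, with the "owner" key invented purely for illustration:

from semantic_router import Route

route = Route(name="test", utterances=["utterance"])
assert route.to_dict()["metadata"] == {}

tagged = Route(name="test", utterances=["utterance"], metadata={"owner": "docs-team"})
assert tagged.to_dict()["metadata"] == {"owner": "docs-team"}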