diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index a578eb01..64312b3f 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -11,6 +11,7 @@ from semantic_router.index.base import BaseIndex from semantic_router.utils.logger import logger +from semantic_router.route import Route def clean_route_name(route_name: str) -> str: @@ -203,6 +204,28 @@ async def _init_async_index(self, force_create: bool = False): def _sync_index(self, local_routes: dict): remote_routes = self.get_routes() + if not local_routes["routes"]: + if self.sync != "remote": + raise ValueError( + "Local routes must be provided to sync the index if the sync setting is not 'remote'." + ) + else: + if not remote_routes: + raise ValueError("No routes found in the index.") + if ( + (self.sync in ["remote", "merge-force-remote"] and not remote_routes) + or ( + self.sync in ["error", "local", "merge-force-local"] + and not local_routes["routes"] + ) + or ( + self.sync == "merge" + and not remote_routes + and not local_routes["routes"] + ) + ): + raise ValueError("No routes found in the index.") + remote_dict: dict = {route: set() for route, _ in remote_routes} for route, utterance in remote_routes: remote_dict[route].add(utterance) @@ -215,6 +238,7 @@ def _sync_index(self, local_routes: dict): routes_to_add = [] routes_to_delete = [] + layer_routes = {} for route in all_routes: local_utterances = local_dict.get(route, set()) @@ -226,8 +250,11 @@ def _sync_index(self, local_routes: dict): f"Synchronization error: Differences found in route '{route}'" ) utterances_to_include: set = set() + layer_routes[route] = list(local_utterances) elif self.sync == "remote": utterances_to_include = set() + if remote_utterances: + layer_routes[route] = list(remote_utterances) elif self.sync == "local": utterances_to_include = local_utterances - remote_utterances routes_to_delete.extend( @@ -237,11 +264,16 @@ def _sync_index(self, local_routes: dict): if utterance not in local_utterances ] ) + layer_routes[route] = list(local_utterances) elif self.sync == "merge-force-remote": if route in local_dict and route not in remote_dict: utterances_to_include = local_utterances + if local_utterances: + layer_routes[route] = list(local_utterances) else: utterances_to_include = set() + if remote_utterances: + layer_routes[route] = list(remote_utterances) elif self.sync == "merge-force-local": if route in local_dict: utterances_to_include = local_utterances - remote_utterances @@ -252,10 +284,15 @@ def _sync_index(self, local_routes: dict): if utterance not in local_utterances ] ) + if local_utterances: + layer_routes[route] = local_utterances else: utterances_to_include = set() + if remote_utterances: + layer_routes[route] = list(remote_utterances) elif self.sync == "merge": utterances_to_include = local_utterances - remote_utterances + layer_routes[route] = list(remote_utterances.union(local_utterances)) else: raise ValueError("Invalid sync mode specified") @@ -272,7 +309,7 @@ def _sync_index(self, local_routes: dict): ] ) - return routes_to_add, routes_to_delete + return routes_to_add, routes_to_delete, layer_routes def _batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records.""" @@ -308,8 +345,8 @@ def _add_and_sync( routes: List[str], utterances: List[str], batch_size: int = 100, - ): - """Add vectors to Pinecone in batches.""" + ) -> List[Route]: + """Add vectors to Pinecone in batches and return the overall updated list of Route objects.""" if self.index is None: self.dimensions = self.dimensions or len(embeddings[0]) self.index = self._init_index(force_create=True) @@ -320,7 +357,15 @@ def _add_and_sync( "embeddings": embeddings, } if self.sync is not None: - data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes) + data_to_upsert, data_to_delete, layer_routes_dict = self._sync_index( + local_routes=local_routes + ) + + layer_routes = [ + Route(name=route, utterances=layer_routes_dict[route]) + for route in layer_routes_dict.keys() + ] + routes_to_delete: dict = {} for route, utterance in data_to_delete: routes_to_delete.setdefault(route, []).append(utterance) @@ -335,6 +380,7 @@ def _add_and_sync( ] if ids_to_delete and self.index: self.index.delete(ids=ids_to_delete) + else: data_to_upsert = [ (vector, route, utterance) @@ -350,6 +396,8 @@ def _add_and_sync( batch = vectors_to_upsert[i : i + batch_size] self._batch_upsert(batch) + return layer_routes + def _get_route_ids(self, route_name: str): clean_route = clean_route_name(route_name) ids, _ = self._get_all(prefix=f"{clean_route}#") diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 5c2d7228..20a325b7 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -220,6 +220,19 @@ def __init__( if len(self.routes) > 0: # initialize index now self._add_routes(routes=self.routes) + elif self.index.sync in ["merge", "remote", "merge-force-remote"]: + dummy_embedding = self.encoder(["dummy"]) + + layer_routes = self.index._add_and_sync( + embeddings=dummy_embedding, + routes=[], + utterances=[], + ) + self._set_layer_routes(layer_routes) + else: + raise ValueError( + "No routes provided for RouteLayer. Please provide routes or set sync to 'remote' if you want to use only remote routes." + ) def check_for_matching_routes(self, top_class: str) -> Optional[Route]: matching_routes = [route for route in self.routes if route.name == top_class] @@ -380,6 +393,14 @@ def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool: ) return self._pass_threshold(scores, threshold) + def _set_layer_routes(self, new_routes: List[Route]): + """ + Set and override the current routes with a new list of routes. + + :param new_routes: List of Route objects to set as the current routes. + """ + self.routes = new_routes + def __str__(self): return ( f"RouteLayer(encoder={self.encoder}, " @@ -466,11 +487,12 @@ def _add_routes(self, routes: List[Route]): # create route array route_names = [route.name for route in routes for _ in route.utterances] # add everything to the index - self.index._add_and_sync( + layer_routes = self.index._add_and_sync( embeddings=embedded_utterances, routes=route_names, utterances=all_utterances, ) + self._set_layer_routes(layer_routes) def _encode(self, text: str) -> Any: """Given some text, encode it."""