Skip to content

Commit

Permalink
Fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Vits-99 committed Jul 16, 2024
1 parent a59e7d1 commit 09c3a74
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 5 deletions.
56 changes: 52 additions & 4 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from semantic_router.index.base import BaseIndex
from semantic_router.utils.logger import logger
from semantic_router.route import Route


def clean_route_name(route_name: str) -> str:
Expand Down Expand Up @@ -203,6 +204,28 @@ async def _init_async_index(self, force_create: bool = False):

def _sync_index(self, local_routes: dict):
remote_routes = self.get_routes()
if not local_routes["routes"]:
if self.sync != "remote":
raise ValueError(
"Local routes must be provided to sync the index if the sync setting is not 'remote'."
)
else:
if not remote_routes:
raise ValueError("No routes found in the index.")
if (
(self.sync in ["remote", "merge-force-remote"] and not remote_routes)
or (
self.sync in ["error", "local", "merge-force-local"]
and not local_routes["routes"]
)
or (
self.sync == "merge"
and not remote_routes
and not local_routes["routes"]
)
):
raise ValueError("No routes found in the index.")

remote_dict: dict = {route: set() for route, _ in remote_routes}
for route, utterance in remote_routes:
remote_dict[route].add(utterance)
Expand All @@ -215,6 +238,7 @@ def _sync_index(self, local_routes: dict):

routes_to_add = []
routes_to_delete = []
layer_routes = {}

for route in all_routes:
local_utterances = local_dict.get(route, set())
Expand All @@ -226,8 +250,11 @@ def _sync_index(self, local_routes: dict):
f"Synchronization error: Differences found in route '{route}'"
)
utterances_to_include: set = set()
layer_routes[route] = list(local_utterances)
elif self.sync == "remote":
utterances_to_include = set()
if remote_utterances:
layer_routes[route] = list(remote_utterances)
elif self.sync == "local":
utterances_to_include = local_utterances - remote_utterances
routes_to_delete.extend(
Expand All @@ -237,11 +264,16 @@ def _sync_index(self, local_routes: dict):
if utterance not in local_utterances
]
)
layer_routes[route] = list(local_utterances)
elif self.sync == "merge-force-remote":
if route in local_dict and route not in remote_dict:
utterances_to_include = local_utterances
if local_utterances:
layer_routes[route] = list(local_utterances)
else:
utterances_to_include = set()
if remote_utterances:
layer_routes[route] = list(remote_utterances)
elif self.sync == "merge-force-local":
if route in local_dict:
utterances_to_include = local_utterances - remote_utterances
Expand All @@ -252,10 +284,15 @@ def _sync_index(self, local_routes: dict):
if utterance not in local_utterances
]
)
if local_utterances:
layer_routes[route] = local_utterances
else:
utterances_to_include = set()
if remote_utterances:
layer_routes[route] = list(remote_utterances)
elif self.sync == "merge":
utterances_to_include = local_utterances - remote_utterances
layer_routes[route] = list(remote_utterances.union(local_utterances))
else:
raise ValueError("Invalid sync mode specified")

Expand All @@ -272,7 +309,7 @@ def _sync_index(self, local_routes: dict):
]
)

return routes_to_add, routes_to_delete
return routes_to_add, routes_to_delete, layer_routes

def _batch_upsert(self, batch: List[Dict]):
"""Helper method for upserting a single batch of records."""
Expand Down Expand Up @@ -308,8 +345,8 @@ def _add_and_sync(
routes: List[str],
utterances: List[str],
batch_size: int = 100,
):
"""Add vectors to Pinecone in batches."""
) -> List[Route]:
"""Add vectors to Pinecone in batches and return the overall updated list of Route objects."""
if self.index is None:
self.dimensions = self.dimensions or len(embeddings[0])
self.index = self._init_index(force_create=True)
Expand All @@ -320,7 +357,15 @@ def _add_and_sync(
"embeddings": embeddings,
}
if self.sync is not None:
data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)
data_to_upsert, data_to_delete, layer_routes_dict = self._sync_index(
local_routes=local_routes
)

layer_routes = [
Route(name=route, utterances=layer_routes_dict[route])
for route in layer_routes_dict.keys()
]

routes_to_delete: dict = {}
for route, utterance in data_to_delete:
routes_to_delete.setdefault(route, []).append(utterance)
Expand All @@ -335,6 +380,7 @@ def _add_and_sync(
]
if ids_to_delete and self.index:
self.index.delete(ids=ids_to_delete)

else:
data_to_upsert = [
(vector, route, utterance)
Expand All @@ -350,6 +396,8 @@ def _add_and_sync(
batch = vectors_to_upsert[i : i + batch_size]
self._batch_upsert(batch)

return layer_routes

def _get_route_ids(self, route_name: str):
clean_route = clean_route_name(route_name)
ids, _ = self._get_all(prefix=f"{clean_route}#")
Expand Down
24 changes: 23 additions & 1 deletion semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,19 @@ def __init__(
if len(self.routes) > 0:
# initialize index now
self._add_routes(routes=self.routes)
elif self.index.sync in ["merge", "remote", "merge-force-remote"]:
dummy_embedding = self.encoder(["dummy"])

layer_routes = self.index._add_and_sync(
embeddings=dummy_embedding,
routes=[],
utterances=[],
)
self._set_layer_routes(layer_routes)
else:
raise ValueError(
"No routes provided for RouteLayer. Please provide routes or set sync to 'remote' if you want to use only remote routes."
)

def check_for_matching_routes(self, top_class: str) -> Optional[Route]:
matching_routes = [route for route in self.routes if route.name == top_class]
Expand Down Expand Up @@ -380,6 +393,14 @@ def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool:
)
return self._pass_threshold(scores, threshold)

def _set_layer_routes(self, new_routes: List[Route]):
"""
Set and override the current routes with a new list of routes.
:param new_routes: List of Route objects to set as the current routes.
"""
self.routes = new_routes

def __str__(self):
return (
f"RouteLayer(encoder={self.encoder}, "
Expand Down Expand Up @@ -466,11 +487,12 @@ def _add_routes(self, routes: List[Route]):
# create route array
route_names = [route.name for route in routes for _ in route.utterances]
# add everything to the index
self.index._add_and_sync(
layer_routes = self.index._add_and_sync(
embeddings=embedded_utterances,
routes=route_names,
utterances=all_utterances,
)
self._set_layer_routes(layer_routes)

def _encode(self, text: str) -> Any:
"""Given some text, encode it."""
Expand Down

0 comments on commit 09c3a74

Please sign in to comment.