Skip to content

Commit

Permalink
Merge branch 'main' into patch-2
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam authored Jul 12, 2024
2 parents 2d7d0f7 + 8f5aad3 commit d5b43f5
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "semantic-router"
version = "0.0.50"
version = "0.0.51"
description = "Super fast semantic router for AI decision making"
authors = [
"James Briggs <[email protected]>",
Expand Down
33 changes: 32 additions & 1 deletion semantic_router/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,32 @@ class BaseIndex(BaseModel):
utterances: Optional[np.ndarray] = None
dimensions: Union[int, None] = None
type: str = "base"
sync: Union[str, None] = None

def add(
self, embeddings: List[List[float]], routes: List[str], utterances: List[Any]
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[Any],
):
"""
Add embeddings to the index.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def _add_and_sync(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[Any],
):
"""
Add embeddings to the index and manage index syncing if necessary.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def delete(self, route_name: str):
"""
Deletes route by route name.
Expand Down Expand Up @@ -74,5 +90,20 @@ def delete_index(self):
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def _sync_index(self, local_routes: dict):
"""
Synchronize the local index with the remote index based on the specified mode.
Modes:
- "error": Raise an error if local and remote are not synchronized.
- "remote": Take remote as the source of truth and update local to align.
- "local": Take local as the source of truth and update remote to align.
- "merge-force-remote": Merge both local and remote taking only remote routes utterances when a route with same route name is present both locally and remotely.
- "merge-force-local": Merge both local and remote taking only local routes utterances when a route with same route name is present both locally and remotely.
- "merge": Merge both local and remote, merging also local and remote utterances when a route with same route name is present both locally and remotely.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")

class Config:
arbitrary_types_allowed = True
16 changes: 15 additions & 1 deletion semantic_router/index/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from semantic_router.index.base import BaseIndex
from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.utils.logger import logger


class LocalIndex(BaseIndex):
Expand All @@ -21,7 +22,10 @@ class Config:
arbitrary_types_allowed = True

def add(
self, embeddings: List[List[float]], routes: List[str], utterances: List[str]
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
):
embeds = np.array(embeddings) # type: ignore
routes_arr = np.array(routes)
Expand All @@ -38,6 +42,16 @@ def add(
self.routes = np.concatenate([self.routes, routes_arr])
self.utterances = np.concatenate([self.utterances, utterances_arr])

def _add_and_sync(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
):
if self.sync is not None:
logger.warning("Sync add is not implemented for LocalIndex.")
self.add(embeddings, routes, utterances)

def get_routes(self) -> List[Tuple]:
"""
Gets a list of route and utterance objects currently stored in the index.
Expand Down
148 changes: 146 additions & 2 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
host: str = "",
namespace: Optional[str] = "",
base_url: Optional[str] = "https://api.pinecone.io",
sync: str = "local",
):
super().__init__()
self.index_name = index_name
Expand All @@ -77,6 +78,7 @@ def __init__(
self.type = "pinecone"
self.api_key = api_key or os.getenv("PINECONE_API_KEY")
self.base_url = base_url
self.sync = sync

if self.api_key is None:
raise ValueError("Pinecone API key is required.")
Expand Down Expand Up @@ -195,6 +197,79 @@ async def _init_async_index(self, force_create: bool = False):
logger.warning("Index could not be initialized.")
self.host = index_stats["host"] if index_stats else None

def _sync_index(self, local_routes: dict):
remote_routes = self.get_routes()
remote_dict: dict = {route: set() for route, _ in remote_routes}
for route, utterance in remote_routes:
remote_dict[route].add(utterance)

local_dict: dict = {route: set() for route in local_routes["routes"]}
for route, utterance in zip(local_routes["routes"], local_routes["utterances"]):
local_dict[route].add(utterance)

all_routes = set(remote_dict.keys()).union(local_dict.keys())

routes_to_add = []
routes_to_delete = []

for route in all_routes:
local_utterances = local_dict.get(route, set())
remote_utterances = remote_dict.get(route, set())

if self.sync == "error":
if local_utterances != remote_utterances:
raise ValueError(
f"Synchronization error: Differences found in route '{route}'"
)
utterances_to_include: set = set()
elif self.sync == "remote":
utterances_to_include = set()
elif self.sync == "local":
utterances_to_include = local_utterances - remote_utterances
routes_to_delete.extend(
[
(route, utterance)
for utterance in remote_utterances
if utterance not in local_utterances
]
)
elif self.sync == "merge-force-remote":
if route in local_dict and route not in remote_dict:
utterances_to_include = local_utterances
else:
utterances_to_include = set()
elif self.sync == "merge-force-local":
if route in local_dict:
utterances_to_include = local_utterances - remote_utterances
routes_to_delete.extend(
[
(route, utterance)
for utterance in remote_utterances
if utterance not in local_utterances
]
)
else:
utterances_to_include = set()
elif self.sync == "merge":
utterances_to_include = local_utterances - remote_utterances
else:
raise ValueError("Invalid sync mode specified")

for utterance in utterances_to_include:
indices = [
i
for i, x in enumerate(local_routes["utterances"])
if x == utterance and local_routes["routes"][i] == route
]
routes_to_add.extend(
[
(local_routes["embeddings"][idx], route, utterance)
for idx in indices
]
)

return routes_to_add, routes_to_delete

def _batch_upsert(self, batch: List[Dict]):
"""Helper method for upserting a single batch of records."""
if self.index is not None:
Expand Down Expand Up @@ -223,11 +298,73 @@ def add(
batch = vectors_to_upsert[i : i + batch_size]
self._batch_upsert(batch)

def _add_and_sync(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
batch_size: int = 100,
):
"""Add vectors to Pinecone in batches."""
if self.index is None:
self.dimensions = self.dimensions or len(embeddings[0])
self.index = self._init_index(force_create=True)

local_routes = {
"routes": routes,
"utterances": utterances,
"embeddings": embeddings,
}
if self.sync is not None:
data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)
routes_to_delete: dict = {}
for route, utterance in data_to_delete:
routes_to_delete.setdefault(route, []).append(utterance)

for route, utterances in routes_to_delete.items():
remote_routes = self._get_routes_with_ids(route_name=route)
ids_to_delete = [
r["id"]
for r in remote_routes
if (r["route"], r["utterance"])
in zip([route] * len(utterances), utterances)
]
if ids_to_delete and self.index:
self.index.delete(ids=ids_to_delete)
else:
data_to_upsert = [
(vector, route, utterance)
for vector, route, utterance in zip(embeddings, routes, utterances)
]

vectors_to_upsert = [
PineconeRecord(values=vector, route=route, utterance=utterance).to_dict()
for vector, route, utterance in data_to_upsert
]

for i in range(0, len(vectors_to_upsert), batch_size):
batch = vectors_to_upsert[i : i + batch_size]
self._batch_upsert(batch)

def _get_route_ids(self, route_name: str):
clean_route = clean_route_name(route_name)
ids, _ = self._get_all(prefix=f"{clean_route}#")
return ids

def _get_routes_with_ids(self, route_name: str):
clean_route = clean_route_name(route_name)
ids, metadata = self._get_all(prefix=f"{clean_route}#", include_metadata=True)
route_tuples = []
for id, data in zip(ids, metadata):
route_tuples.append(
{
"id": id,
"route": data["sr_route"],
"utterance": data["sr_utterance"],
}
)
return route_tuples

def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
"""
Retrieves all vector IDs from the Pinecone index using pagination.
Expand Down Expand Up @@ -267,9 +404,16 @@ def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False)

# if we need metadata, we fetch it
if include_metadata:
res_meta = self.index.fetch(ids=vector_ids, namespace=self.namespace)
for id in vector_ids:
res_meta = (
self.index.fetch(ids=[id], namespace=self.namespace)
if self.index
else {}
)
metadata.extend(
[x["metadata"] for x in res_meta["vectors"].values()]
)
# extract metadata only
metadata.extend([x["metadata"] for x in res_meta["vectors"].values()])

# Check if there's a next page token; if not, break the loop
next_page_token = response_data.get("pagination", {}).get("next")
Expand Down
11 changes: 11 additions & 0 deletions semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,17 @@ def _init_collection(self) -> None:
**self.config,
)

def _add_and_sync(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
):
if self.sync is not None:
logger.warning("Sync add is not implemented for QdrantIndex")
self.add(embeddings, routes, utterances, batch_size)

def add(
self,
embeddings: List[List[float]],
Expand Down
2 changes: 1 addition & 1 deletion semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def _add_routes(self, routes: List[Route]):
# create route array
route_names = [route.name for route in routes for _ in route.utterances]
# add everything to the index
self.index.add(
self.index._add_and_sync(
embeddings=embedded_utterances,
routes=route_names,
utterances=all_utterances,
Expand Down

0 comments on commit d5b43f5

Please sign in to comment.