1,362 changes: 0 additions & 1,362 deletions chromadb-demo.ipynb

This file was deleted.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -17,7 +17,8 @@ dependencies = [
"webdataset",
"flask",
"Pillow",
"faiss-gpu-cu12"
"faiss-gpu-cu12",
"gunicorn"
]

[project.optional-dependencies]
74 changes: 74 additions & 0 deletions src/bioclip_vector_db/query/README.md
@@ -0,0 +1,74 @@
# Neighborhood Server

This server provides an HTTP API for nearest-neighbor search over a partitioned FAISS index.
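
A quick way to exercise the API once the server is up is a small client script. The `/search` route and payload keys below are illustrative assumptions only; the actual endpoints are defined in `neighborhood_server.py`.

```python
import json
import urllib.request

# Hypothetical request -- substitute the real route and payload shape
# from neighborhood_server.py. Port 5001 is the documented default,
# and 512 is an assumed embedding dimension.
req = urllib.request.Request(
    "http://localhost:5001/search",
    data=json.dumps({"vector": [0.1] * 512, "k": 5}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))
```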

## Running the server

There are two ways to run the server:

1. **Directly with Python (for development)**
2. **With Gunicorn (for production)**

### 1. Running with Python

You can run the server directly using the `neighborhood_server.py` script. You must provide the configuration as command-line arguments.

**Usage:**

```bash
python src/bioclip_vector_db/query/neighborhood_server.py --index_dir <path_to_index_dir> --index_file_prefix <prefix> --leader_index <leader_index_file> --partitions <partitions> [options]
```

**Example:**

```bash
python src/bioclip_vector_db/query/neighborhood_server.py \
--index_dir /path/to/faiss_index \
--index_file_prefix local_ \
--leader_index leader.index \
--partitions "1,2,5-10" \
--nprobe 10 \
--port 5001
```

### 2. Running with Gunicorn

For production, it is recommended to use a WSGI server like Gunicorn.

**Prerequisites:**

* Install Gunicorn: `pip install gunicorn` (it is also declared as a dependency in `pyproject.toml`, so installing the project pulls it in).
* Ensure the rest of the project's dependencies are installed.

**Configuration:**

The Gunicorn server is configured using environment variables, which the `wsgi.py` module reads at import time (Gunicorn does not forward command-line arguments to the application).

**Required:**

* `INDEX_DIR`: Directory where the index files are stored.
* `INDEX_FILE_PREFIX`: The prefix of the index files (e.g., `local_`).
* `LEADER_INDEX`: The leader index file, which contains all the centroids.
* `PARTITIONS`: List of partition numbers to load (e.g., `"1,2,5-10"`); the range syntax is expanded as sketched below.
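
For reference, the range syntax expands into an explicit, sorted list of partition ids. A minimal sketch that mirrors the signature of `parse_partitions` in `neighborhood_server.py` (the real implementation may differ in details; ranges are assumed inclusive):

```python
from typing import List

def parse_partitions(partition_str: str) -> List[int]:
    """Expand a spec like "1,2,5-10" into [1, 2, 5, 6, 7, 8, 9, 10]."""
    partitions = set()
    for token in partition_str.split(","):
        if "-" in token:
            start, end = token.split("-")
            partitions.update(range(int(start), int(end) + 1))  # inclusive range
        else:
            partitions.add(int(token))
    return sorted(partitions)
```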

**Optional:**

* `NPROBE`: Number of inverted list probes (default: `1`).
* `USE_CACHE`: Enable the lazy-loading cache of local neighborhoods (default: `False`); see the cache sketch after this list.
* `PORT`: Port to run the server on (default: `5001`).
* `WORKERS`: Number of Gunicorn worker processes (default: `4`).
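
When `USE_CACHE` is enabled, local neighborhoods are loaded lazily into a bounded least-recently-used cache (the server tracks hits, misses, and evictions against `MAX_CACHE_SIZE`). A minimal sketch of that access pattern, assuming an `OrderedDict`-based cache as the `move_to_end` call in `neighborhood_server.py` suggests:

```python
from collections import OrderedDict

MAX_CACHE_SIZE = 4  # illustrative; the real limit lives in neighborhood_server.py

cache = OrderedDict()  # neighborhood_id -> loaded index

def get_neighborhood(neighborhood_id):
    if neighborhood_id in cache:
        cache.move_to_end(neighborhood_id)  # hit: mark most recently used
        return cache[neighborhood_id]
    if len(cache) >= MAX_CACHE_SIZE:
        cache.popitem(last=False)  # evict the least recently used entry
    # Stand-in for loading the partition's FAISS index from disk.
    cache[neighborhood_id] = f"index-{neighborhood_id}"
    return cache[neighborhood_id]
```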

**Running the command:**

From the `bioclip-vector-db` directory, run the following command:

```bash
export INDEX_DIR=/path/to/faiss_index
export INDEX_FILE_PREFIX=local_
export LEADER_INDEX=leader.index
export PARTITIONS="1,2,5-10"
export NPROBE=10
export PORT=5001

gunicorn --workers ${WORKERS:-4} --bind 0.0.0.0:${PORT} --chdir src bioclip_vector_db.query.wsgi:app
```
65 changes: 43 additions & 22 deletions src/bioclip_vector_db/query/neighborhood_server.py
@@ -73,7 +73,7 @@ def __init__(
nprobe (int, optional): The number of partitions to search. Defaults to 1.
metadata_db (MetadataDatabase, optional): An instance of MetadataDatabase.
Defaults to None.
use_cache (bool, optional): Whether to use caching. Defaults to False. If enabled,
use_cache (bool, optional): Whether to use caching. Defaults to False. If enabled,
lazy loading of the local neighborhoods will happen.
"""
self._index_path_pattern = index_path_pattern
@@ -82,8 +82,9 @@ def __init__(
self._use_cache = use_cache

if self._use_cache and self._nprobe > MAX_CACHE_SIZE:
raise ValueError(f"nprobe cannot be greater than MAX_CACHE_SIZE: {MAX_CACHE_SIZE}")

raise ValueError(
f"nprobe cannot be greater than MAX_CACHE_SIZE: {MAX_CACHE_SIZE}"
)

if leader_index_path is None or not os.path.exists(leader_index_path):
logger.error(f"Loading leader index from: {leader_index_path}")
@@ -139,9 +140,11 @@ def _load_with_cache(self, neighborhood_ids: List[int]):
if neighborhood_id in self._indices:
self._cache_hits += 1
self._indices.move_to_end(neighborhood_id)
logger.info(f"Neighborhood {neighborhood_id} already in cache. Moved to most recently used.")
logger.info(
f"Neighborhood {neighborhood_id} already in cache. Moved to most recently used."
)
continue

self._cache_miss += 1
if len(self._indices) >= MAX_CACHE_SIZE:
self._cache_evictions += 1
@@ -362,6 +365,34 @@ def parse_partitions(partition_str: str) -> List[int]:
return sorted(list(partitions))


def create_app(
index_dir: str,
index_file_prefix: str,
leader_index: str,
nprobe: int,
partitions_str: str,
use_cache: bool,
):
"""Creates and configures the Flask application."""
index_path_pattern = f"{index_dir}/{index_file_prefix}{{}}.index"
leader_index_path = f"{index_dir}/{leader_index}"
partitions = parse_partitions(partitions_str)

metadata_db = MetadataDatabase(index_dir)

svc = FaissIndexService(
index_path_pattern,
partitions,
leader_index_path,
nprobe=nprobe,
metadata_db=metadata_db,
use_cache=use_cache,
)

server = LocalIndexServer(service=svc)
return server._app


def __main__():
parser = argparse.ArgumentParser(description="FAISS Neighborhood Server")
parser.add_argument(
@@ -408,31 +439,21 @@ def __main__():
)
args = parser.parse_args()

index_path_pattern = f"{args.index_dir}/{args.index_file_prefix}{{}}.index"
leader_index_path = f"{args.index_dir}/{args.leader_index}"
partitions = parse_partitions(args.partitions)

metadata_db = MetadataDatabase(args.index_dir)

svc = FaissIndexService(
index_path_pattern,
partitions,
leader_index_path,
app = create_app(
index_dir=args.index_dir,
index_file_prefix=args.index_file_prefix,
leader_index=args.leader_index,
nprobe=args.nprobe,
metadata_db=metadata_db,
partitions_str=args.partitions,
use_cache=args.use_cache,
)

SERVER_HOST = "0.0.0.0"
SERVER_PORT = args.port

# 2. Initialize the server with the index service
server = LocalIndexServer(service=svc)

# 3. Run the server
print(f"Starting server at http://{SERVER_HOST}:{SERVER_PORT}")
server.run(host=SERVER_HOST, port=SERVER_PORT)
app.run(host=SERVER_HOST, port=SERVER_PORT)


if __name__ == "__main__":
__main__()
__main__()
27 changes: 27 additions & 0 deletions src/bioclip_vector_db/query/wsgi.py
@@ -0,0 +1,27 @@
import os
from bioclip_vector_db.query.neighborhood_server import create_app

# Required environment variables
INDEX_DIR = os.environ.get("INDEX_DIR")
INDEX_FILE_PREFIX = os.environ.get("INDEX_FILE_PREFIX")
LEADER_INDEX = os.environ.get("LEADER_INDEX")
PARTITIONS = os.environ.get("PARTITIONS")

# Optional environment variables
NPROBE = int(os.environ.get("NPROBE", 1))
USE_CACHE = os.environ.get("USE_CACHE", "False").lower() in ("true", "1", "t")
PORT = int(os.environ.get("PORT", 5001))

if not all([INDEX_DIR, INDEX_FILE_PREFIX, LEADER_INDEX, PARTITIONS]):
raise ValueError(
"Missing one or more required environment variables: INDEX_DIR, INDEX_FILE_PREFIX, LEADER_INDEX, PARTITIONS"
)

app = create_app(
index_dir=INDEX_DIR,
index_file_prefix=INDEX_FILE_PREFIX,
leader_index=LEADER_INDEX,
nprobe=NPROBE,
partitions_str=PARTITIONS,
use_cache=USE_CACHE,
)
4 changes: 2 additions & 2 deletions src/bioclip_vector_db/storage/faiss_utils.py
@@ -96,7 +96,7 @@ def _maybe_flush_buffers(self):
if len(self._partition_to_embedding_map[partition_id]) >= self._batch_size:
self._write_partition_to_file(partition_id)

def add_embedding(self, original_id: str, embedding: np.ndarray):
def add_embedding(self, original_id: str, embedding: np.ndarray, metadata: dict = None):
"""Adds a single embedding vector to the appropriate partition buffer.

Args:
@@ -111,7 +111,7 @@ def add_embedding(self, original_id: str, embedding: np.ndarray):
partition_id = int(partition_ids[0][0])

faiss_id = self._partition_faiss_ids[partition_id]
self._metadata_db.add_mapping(partition_id, faiss_id, original_id)
self._metadata_db.add_mapping(partition_id, faiss_id, original_id, metadata)
self._partition_faiss_ids[partition_id] += 1

self._partition_to_embedding_map[partition_id].append(embedding)
66 changes: 57 additions & 9 deletions src/bioclip_vector_db/storage/metadata_storage.py
@@ -45,11 +45,11 @@ def create_table(self):
with conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS metadata (
CREATE TABLE IF NOT EXISTS id_mapping (
partition_id INTEGER NOT NULL,
faiss_id INTEGER NOT NULL,
original_id TEXT NOT NULL,
metadata TEXT,
metadata BLOB,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (partition_id, faiss_id)
)
@@ -61,7 +61,7 @@ def create_table(self):
raise

cursor = conn.cursor()
cursor.execute("SELECT count(*) FROM metadata")
cursor.execute("SELECT count(*) FROM id_mapping")
result = cursor.fetchone()
logger.info(f"Total number of records: {result[0]}")

@@ -82,12 +82,12 @@ def add_mapping(
metadata: Optional dictionary of metadata to store as a JSON string.
"""
conn = self._get_connection()
metadata_json = json.dumps(metadata) if metadata else None
metadata_blob = json.dumps(metadata).encode("utf-8") if metadata else None
try:
with conn:
conn.execute(
"INSERT INTO metadata (partition_id, faiss_id, original_id, metadata) VALUES (?, ?, ?, ?)",
(int(partition_id), int(faiss_id), original_id, metadata_json),
"INSERT INTO id_mapping (partition_id, faiss_id, original_id, metadata) VALUES (?, ?, ?, ?)",
(int(partition_id), int(faiss_id), original_id, metadata_blob),
)
logger.debug(
f"Added mapping: partition_id={partition_id}, faiss_id={faiss_id}, original_id={original_id}"
@@ -115,7 +115,7 @@ def get_original_id(self, partition_id: int, faiss_id: int) -> Optional[str]:
try:
cursor = conn.cursor()
cursor.execute(
"SELECT original_id FROM metadata WHERE partition_id = ? AND faiss_id = ?",
"SELECT original_id FROM id_mapping WHERE partition_id = ? AND faiss_id = ?",
(int(partition_id), int(faiss_id)),
)
result = cursor.fetchone()
@@ -124,6 +124,54 @@ def get_original_id(self, partition_id: int, faiss_id: int) -> Optional[str]:
logger.error(f"Error getting original_id: {e}")
raise

def get_metadata(self, partition_id: int, faiss_id: int) -> Optional[Dict[str, Any]]:
"""
Retrieves the metadata for a given FAISS ID in a specific partition.

Args:
partition_id: The ID of the partition file.
faiss_id: The FAISS index ID.

Returns:
The metadata dictionary, or None if not found.
"""
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"SELECT original_id, metadata FROM id_mapping WHERE partition_id = ? AND faiss_id = ?",
(int(partition_id), int(faiss_id)),
)
result = cursor.fetchone()
if result and result[1]:
return json.loads(result[1].decode("utf-8"))
return None
except sqlite3.Error as e:
logger.error(f"Error getting metadata: {e}")
raise

def get_metadata_by_original_id(self, original_id: str) -> Optional[Dict[str, Any]]:
"""
Retrieves the metadata for a given original ID.

Args:
original_id: The original ID.

Returns:
The metadata dictionary, or None if not found.
"""
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"SELECT metadata FROM id_mapping WHERE original_id = ?",
(str(original_id),),
)
result = cursor.fetchone()
if result and result[0]:
return json.loads(result[0].decode("utf-8"))
return None
except sqlite3.Error as e:
logger.error(f"Error getting metadata: {e}")
raise

def batch_get_original_id(self, partition_id: int, faiss_ids: List[int]) -> Dict[int, str]:
# Implement me.
pass
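# A possible batched implementation (sketch only; assumes len(faiss_ids)
# stays within SQLite's bound-variable limit):
#
# placeholders = ",".join("?" * len(faiss_ids))
# cursor = self._get_connection().execute(
#     "SELECT faiss_id, original_id FROM id_mapping "
#     f"WHERE partition_id = ? AND faiss_id IN ({placeholders})",
#     [int(partition_id), *(int(i) for i in faiss_ids)],
# )
# return {row[0]: row[1] for row in cursor.fetchall()}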
@@ -142,7 +190,7 @@ def get_faiss_id(self, original_id: str) -> Optional[int]:
try:
cursor = conn.cursor()
cursor.execute(
"SELECT faiss_id FROM metadata WHERE original_id = ?", (original_id,)
"SELECT faiss_id FROM id_mapping WHERE original_id = ?", (original_id,)
)
result = cursor.fetchone()
return result[0] if result else None
@@ -161,7 +209,7 @@ def _reset(self):
conn = self._get_connection()
try:
with conn:
conn.execute("DROP TABLE IF EXISTS metadata")
conn.execute("DROP TABLE IF EXISTS id_mapping")
logger.info("SQLITE: Reset table successful.")
except sqlite3.Error as e:
logger.error(f"Error resetting table: {e}")
2 changes: 1 addition & 1 deletion src/bioclip_vector_db/storage/storage_impl.py
@@ -115,7 +115,7 @@ def _add_embedding_to_index(
self, id: str, embedding: List[float], metadata: Dict[str, str]
):
embedding_np = np.array([embedding]).astype("float32")
self._writer.add_embedding(id, embedding_np)
self._writer.add_embedding(id, embedding_np, metadata=metadata)

def add_embedding(self, id: str, embedding: List[float], metadata: Dict[str, str]):
if len(self._train_ids) < self._train_set_size: