Skip to content

Commit

Permalink
updated nanodb
Browse files Browse the repository at this point in the history
  • Loading branch information
dusty-nv committed Sep 30, 2023
1 parent f6d8527 commit 087e48e
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 40 deletions.
14 changes: 13 additions & 1 deletion packages/vectordb/nanodb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np

from .nanodb import NanoDB, DistanceMetrics
from .server import Server

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

Expand All @@ -25,6 +26,10 @@
parser.add_argument('--test', action='store_true', help="run a search query for each item in the index")
parser.add_argument('--autosave', action='store_true', help="automatically save the database when new items are scanned")

parser.add_argument('--server', action='store_true', help="start the webserver and gradio UI")
parser.add_argument('--host', type=str, default='0.0.0.0', help="the network interface to listen on (default all)")
parser.add_argument('--port', type=int, default=7860, help="the webserver port to use")

args = parser.parse_args()

if args.scan:
Expand Down Expand Up @@ -57,9 +62,16 @@
print(f"-- testing index with k={args.k}")
db.test(k=args.k)

if args.server:
server = Server(db, host=args.host, port=args.port)
server.start()

while True:
query = input('\n> ')
query = input('\n> ').strip()

if not query:
continue

if os.path.isfile(query) or os.path.isdir(query):
db.scan(path)
elif query.lower() == 'save':
Expand Down
6 changes: 4 additions & 2 deletions packages/vectordb/nanodb/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(self, model='ViT-L/14@336px', dtype=np.float32, crop=True, model_ca

if crop:
self.preprocessor.append(transforms.CenterCrop(self.config.input_shape[0]))
print("-- image cropping enabled")

self.preprocessor.append(transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
self.preprocessor.append(transforms.ConvertImageDtype(self.config.dtype))
Expand All @@ -99,8 +100,9 @@ def __init__(self, model='ViT-L/14@336px', dtype=np.float32, crop=True, model_ca

print(self.model)

print(f'-- {self.config.name} warmup')
self.embed_image(PIL.Image.new('RGB', self.config.input_shape, (255,255,255)))
print(f"-- {self.config.name} warmup")
for i in range(2):
self.embed_image(PIL.Image.new('RGB', self.config.input_shape, (255,255,255)))
print_table(self.config)

def embed_image(self, image, return_tensors='pt', **kwargs):
Expand Down
37 changes: 22 additions & 15 deletions packages/vectordb/nanodb/nanodb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
import os
import sys
import math
import time
import tqdm
Expand All @@ -10,7 +11,7 @@

from .clip import CLIPEmbedding
from .vector_index import cudaVectorIndex, DistanceMetrics
from .utils import print_table
from .utils import print_table, tqdm_redirect_stdout

class NanoDB:
def __init__(self, path=None, model='ViT-L/14@336px', dtype=np.float32, autosave=False, **kwargs):
Expand Down Expand Up @@ -75,11 +76,14 @@ def scan(self, path, max_items=None, **kwargs):

indexes = []

for file in files:
embedding = self.embed(file, **kwargs)
index = self.index.add(embedding, sync=False)
self.metadata.insert(index, dict(path=file))
indexes.append(index)
for file in tqdm.tqdm(files, file=sys.stdout):
with tqdm_redirect_stdout():
embedding = self.embed(file, **kwargs)
index = self.index.add(embedding, sync=False)
self.metadata.insert(index, dict(path=file))
indexes.append(index)
if (len(indexes) % 1000 == 0) and self.path and self.autosave:
self.save(self.path)

time_elapsed = time.perf_counter() - time_begin
print(f"-- added {len(indexes)} items to the index in from {path} ({time_elapsed:.1f} sec, {len(indexes)/time_elapsed:.1f} items/sec)")
Expand Down Expand Up @@ -107,7 +111,7 @@ def load(self, path=None):
with open(paths['config'], 'r') as file:
config = json.load(file)
pprint.pprint(config)

with open(paths['metadata'], 'r') as file:
self.metadata = json.load(file)

Expand All @@ -122,8 +126,10 @@ def load(self, path=None):

if config['shape'][0] > self.index.reserved:
raise RuntimeError(f"{paths['vectors']} exceeds the reserve memory that the index was allocated with")


self.scans.extend(config['scans'])
vectors.shape = config['shape']

self.index.vectors.array[:vectors.shape[0]] = vectors
self.index.shape = (vectors.shape[0], self.index.shape[1])

Expand Down Expand Up @@ -205,7 +211,7 @@ def get_paths(self, path, check_exists=False, raise_exception=False):
def embed(self, data, type=None, **kwargs):
if type is None:
type = self.embedding_type(data)
print(f"-- generating embedding for {data} with type={type}")
#print(f"-- generating embedding for {data} with type={type}")

if type == 'image':
embedding = self.model.embed_image(data)
Expand Down Expand Up @@ -235,9 +241,10 @@ def embedding_type(self, data):
raise ValueError(f"couldn't find type of embedding for {type(data)}, please specify the 'type' argument")

def test(self, k):
for i in range(len(self.index)):
indexes, distances = self.index.search(self.index.vectors.array[i], k=k)
print(f"-- search results for {i} {self.metadata[i]['path']}")
for n in range(k):
print(f" * {indexes[n]} {self.metadata[indexes[n]]['path']} {'similarity' if self.index.metric == 'cosine' else 'distance'}={distances[n]}")

for i in tqdm.tqdm(range(len(self.index)), file=sys.stdout):
with tqdm_redirect_stdout():
indexes, distances = self.index.search(self.index.vectors.array[i], k=k)
print(f"-- search results for {i} {self.metadata[i]['path']}")
for n in range(k):
print(f" * {indexes[n]} {self.metadata[indexes[n]]['path']} {'similarity' if self.index.metric == 'cosine' else 'distance'}={distances[n]}")

3 changes: 2 additions & 1 deletion packages/vectordb/nanodb/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
websockets
termcolor
tabulate
gradio
gradio==3.34.0
fastapi==0.99.0
flask
tqdm
git+https://github.com/openai/CLIP
22 changes: 22 additions & 0 deletions packages/vectordb/nanodb/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#!/usr/bin/env python3
import os
import sys
import time
import tqdm
import json
import requests
import contextlib

import torch
import torchvision
import numpy as np
Expand Down Expand Up @@ -205,4 +209,22 @@ def torch_dtype(dtype):
Convert numpy.dtype or str to torch.dtype
"""
return torch_dtype_dict[str(dtype)]


# https://stackoverflow.com/a/37243211
class TQDMRedirectStdOut(object):
file = None
def __init__(self, file):
self.file = file

def write(self, x):
if len(x.rstrip()) > 0: # Avoid print() second call (useless \n)
tqdm.tqdm.write(x, file=self.file)

@contextlib.contextmanager
def tqdm_redirect_stdout():
save_stdout = sys.stdout
sys.stdout = TQDMRedirectStdOut(sys.stdout)
yield
sys.stdout = save_stdout

45 changes: 24 additions & 21 deletions packages/vectordb/nanodb/vector_index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#!/usr/bin/env python3
import os
import sys
import time
import tqdm
import torch

import numpy as np
Expand All @@ -20,7 +22,7 @@
cudaStreamSynchronize,
)

from .utils import AttrDict, torch_dtype
from .utils import AttrDict, torch_dtype, tqdm_redirect_stdout


class cudaVectorIndex:
Expand Down Expand Up @@ -208,26 +210,27 @@ def validate(self, k=4):
correct=True
metric=self.metric if self.metric != 'cosine' else 'inner_product'

for n in range(self.shape[0]):
assert(cudaKNN(
C.cast(self.vectors.ptr, C.c_void_p),
C.cast(self.vectors.ptr+n*self.shape[1]*self.dsize, C.c_void_p),
self.dsize,
self.shape[0],
1,
self.shape[1],
k,
DistanceMetrics[metric],
C.cast(self.vector_norms.ptr, C.POINTER(C.c_float)) if self.metric == 'l2' else None,
C.cast(self.distances.ptr, C.POINTER(C.c_float)),
C.cast(self.indexes.ptr, C.POINTER(C.c_longlong)),
C.cast(int(self.stream), C.c_void_p) if self.stream else None
))
self.sync()
if self.indexes.array[0][0] != n:
print(f"incorrect or duplicate index [{n}] indexes={self.indexes.array[0,:k]} distances={self.distances.array[0,:k]}")
#assert(self.indexes[0][0]==n)
correct=False
for n in tqdm.tqdm(range(self.shape[0]), file=sys.stdout):
with tqdm_redirect_stdout():
assert(cudaKNN(
C.cast(self.vectors.ptr, C.c_void_p),
C.cast(self.vectors.ptr+n*self.shape[1]*self.dsize, C.c_void_p),
self.dsize,
self.shape[0],
1,
self.shape[1],
k,
DistanceMetrics[metric],
C.cast(self.vector_norms.ptr, C.POINTER(C.c_float)) if self.metric == 'l2' else None,
C.cast(self.distances.ptr, C.POINTER(C.c_float)),
C.cast(self.indexes.ptr, C.POINTER(C.c_longlong)),
C.cast(int(self.stream), C.c_void_p) if self.stream else None
))
self.sync()
if self.indexes.array[0][0] != n:
print(f"incorrect or duplicate index [{n}] indexes={self.indexes.array[0,:k]} distances={self.distances.array[0,:k]}")
#assert(self.indexes[0][0]==n)
correct=False

return correct

Expand Down

0 comments on commit 087e48e

Please sign in to comment.