From 087e48e7f5674129a75915b38c2a470c2b128e9d Mon Sep 17 00:00:00 2001 From: Dustin Franklin Date: Fri, 29 Sep 2023 23:26:07 -0400 Subject: [PATCH] updated nanodb --- packages/vectordb/nanodb/__main__.py | 14 ++++++- packages/vectordb/nanodb/clip.py | 6 ++- packages/vectordb/nanodb/nanodb.py | 37 +++++++++++-------- packages/vectordb/nanodb/requirements.txt | 3 +- packages/vectordb/nanodb/utils.py | 22 +++++++++++ packages/vectordb/nanodb/vector_index.py | 45 ++++++++++++----------- 6 files changed, 87 insertions(+), 40 deletions(-) diff --git a/packages/vectordb/nanodb/__main__.py b/packages/vectordb/nanodb/__main__.py index 6b6d09242..810809b9d 100644 --- a/packages/vectordb/nanodb/__main__.py +++ b/packages/vectordb/nanodb/__main__.py @@ -6,6 +6,7 @@ import numpy as np from .nanodb import NanoDB, DistanceMetrics +from .server import Server parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -25,6 +26,10 @@ parser.add_argument('--test', action='store_true', help="run a search query for each item in the index") parser.add_argument('--autosave', action='store_true', help="automatically save the database when new items are scanned") +parser.add_argument('--server', action='store_true', help="start the webserver and gradio UI") +parser.add_argument('--host', type=str, default='0.0.0.0', help="the network interface to listen on (default all)") +parser.add_argument('--port', type=int, default=7860, help="the webserver port to use") + args = parser.parse_args() if args.scan: @@ -57,9 +62,16 @@ print(f"-- testing index with k={args.k}") db.test(k=args.k) +if args.server: + server = Server(db, host=args.host, port=args.port) + server.start() + while True: - query = input('\n> ') + query = input('\n> ').strip() + if not query: + continue + if os.path.isfile(query) or os.path.isdir(query): db.scan(path) elif query.lower() == 'save': diff --git a/packages/vectordb/nanodb/clip.py b/packages/vectordb/nanodb/clip.py index e50884d4c..8050b5102 100644 --- a/packages/vectordb/nanodb/clip.py +++ b/packages/vectordb/nanodb/clip.py @@ -91,6 +91,7 @@ def __init__(self, model='ViT-L/14@336px', dtype=np.float32, crop=True, model_ca if crop: self.preprocessor.append(transforms.CenterCrop(self.config.input_shape[0])) + print("-- image cropping enabled") self.preprocessor.append(transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))) self.preprocessor.append(transforms.ConvertImageDtype(self.config.dtype)) @@ -99,8 +100,9 @@ def __init__(self, model='ViT-L/14@336px', dtype=np.float32, crop=True, model_ca print(self.model) - print(f'-- {self.config.name} warmup') - self.embed_image(PIL.Image.new('RGB', self.config.input_shape, (255,255,255))) + print(f"-- {self.config.name} warmup") + for i in range(2): + self.embed_image(PIL.Image.new('RGB', self.config.input_shape, (255,255,255))) print_table(self.config) def embed_image(self, image, return_tensors='pt', **kwargs): diff --git a/packages/vectordb/nanodb/nanodb.py b/packages/vectordb/nanodb/nanodb.py index b352c943d..c45384e70 100644 --- a/packages/vectordb/nanodb/nanodb.py +++ b/packages/vectordb/nanodb/nanodb.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import os +import sys import math import time import tqdm @@ -10,7 +11,7 @@ from .clip import CLIPEmbedding from .vector_index import cudaVectorIndex, DistanceMetrics -from .utils import print_table +from .utils import print_table, tqdm_redirect_stdout class NanoDB: def __init__(self, path=None, model='ViT-L/14@336px', dtype=np.float32, autosave=False, **kwargs): @@ -75,11 +76,14 @@ def scan(self, path, max_items=None, **kwargs): indexes = [] - for file in files: - embedding = self.embed(file, **kwargs) - index = self.index.add(embedding, sync=False) - self.metadata.insert(index, dict(path=file)) - indexes.append(index) + for file in tqdm.tqdm(files, file=sys.stdout): + with tqdm_redirect_stdout(): + embedding = self.embed(file, **kwargs) + index = self.index.add(embedding, sync=False) + self.metadata.insert(index, dict(path=file)) + indexes.append(index) + if (len(indexes) % 1000 == 0) and self.path and self.autosave: + self.save(self.path) time_elapsed = time.perf_counter() - time_begin print(f"-- added {len(indexes)} items to the index in from {path} ({time_elapsed:.1f} sec, {len(indexes)/time_elapsed:.1f} items/sec)") @@ -107,7 +111,7 @@ def load(self, path=None): with open(paths['config'], 'r') as file: config = json.load(file) pprint.pprint(config) - + with open(paths['metadata'], 'r') as file: self.metadata = json.load(file) @@ -122,8 +126,10 @@ def load(self, path=None): if config['shape'][0] > self.index.reserved: raise RuntimeError(f"{paths['vectors']} exceeds the reserve memory that the index was allocated with") - + + self.scans.extend(config['scans']) vectors.shape = config['shape'] + self.index.vectors.array[:vectors.shape[0]] = vectors self.index.shape = (vectors.shape[0], self.index.shape[1]) @@ -205,7 +211,7 @@ def get_paths(self, path, check_exists=False, raise_exception=False): def embed(self, data, type=None, **kwargs): if type is None: type = self.embedding_type(data) - print(f"-- generating embedding for {data} with type={type}") + #print(f"-- generating embedding for {data} with type={type}") if type == 'image': embedding = self.model.embed_image(data) @@ -235,9 +241,10 @@ def embedding_type(self, data): raise ValueError(f"couldn't find type of embedding for {type(data)}, please specify the 'type' argument") def test(self, k): - for i in range(len(self.index)): - indexes, distances = self.index.search(self.index.vectors.array[i], k=k) - print(f"-- search results for {i} {self.metadata[i]['path']}") - for n in range(k): - print(f" * {indexes[n]} {self.metadata[indexes[n]]['path']} {'similarity' if self.index.metric == 'cosine' else 'distance'}={distances[n]}") - \ No newline at end of file + for i in tqdm.tqdm(range(len(self.index)), file=sys.stdout): + with tqdm_redirect_stdout(): + indexes, distances = self.index.search(self.index.vectors.array[i], k=k) + print(f"-- search results for {i} {self.metadata[i]['path']}") + for n in range(k): + print(f" * {indexes[n]} {self.metadata[indexes[n]]['path']} {'similarity' if self.index.metric == 'cosine' else 'distance'}={distances[n]}") + \ No newline at end of file diff --git a/packages/vectordb/nanodb/requirements.txt b/packages/vectordb/nanodb/requirements.txt index 7e7f608cf..20965e055 100644 --- a/packages/vectordb/nanodb/requirements.txt +++ b/packages/vectordb/nanodb/requirements.txt @@ -1,7 +1,8 @@ websockets termcolor tabulate -gradio +gradio==3.34.0 +fastapi==0.99.0 flask tqdm git+https://github.com/openai/CLIP \ No newline at end of file diff --git a/packages/vectordb/nanodb/utils.py b/packages/vectordb/nanodb/utils.py index 48bd730a3..11befc697 100644 --- a/packages/vectordb/nanodb/utils.py +++ b/packages/vectordb/nanodb/utils.py @@ -1,8 +1,12 @@ #!/usr/bin/env python3 import os +import sys import time +import tqdm import json import requests +import contextlib + import torch import torchvision import numpy as np @@ -205,4 +209,22 @@ def torch_dtype(dtype): Convert numpy.dtype or str to torch.dtype """ return torch_dtype_dict[str(dtype)] + + +# https://stackoverflow.com/a/37243211 +class TQDMRedirectStdOut(object): + file = None + def __init__(self, file): + self.file = file + + def write(self, x): + if len(x.rstrip()) > 0: # Avoid print() second call (useless \n) + tqdm.tqdm.write(x, file=self.file) + +@contextlib.contextmanager +def tqdm_redirect_stdout(): + save_stdout = sys.stdout + sys.stdout = TQDMRedirectStdOut(sys.stdout) + yield + sys.stdout = save_stdout \ No newline at end of file diff --git a/packages/vectordb/nanodb/vector_index.py b/packages/vectordb/nanodb/vector_index.py index 666fda416..70f9d18ec 100644 --- a/packages/vectordb/nanodb/vector_index.py +++ b/packages/vectordb/nanodb/vector_index.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 import os +import sys import time +import tqdm import torch import numpy as np @@ -20,7 +22,7 @@ cudaStreamSynchronize, ) -from .utils import AttrDict, torch_dtype +from .utils import AttrDict, torch_dtype, tqdm_redirect_stdout class cudaVectorIndex: @@ -208,26 +210,27 @@ def validate(self, k=4): correct=True metric=self.metric if self.metric != 'cosine' else 'inner_product' - for n in range(self.shape[0]): - assert(cudaKNN( - C.cast(self.vectors.ptr, C.c_void_p), - C.cast(self.vectors.ptr+n*self.shape[1]*self.dsize, C.c_void_p), - self.dsize, - self.shape[0], - 1, - self.shape[1], - k, - DistanceMetrics[metric], - C.cast(self.vector_norms.ptr, C.POINTER(C.c_float)) if self.metric == 'l2' else None, - C.cast(self.distances.ptr, C.POINTER(C.c_float)), - C.cast(self.indexes.ptr, C.POINTER(C.c_longlong)), - C.cast(int(self.stream), C.c_void_p) if self.stream else None - )) - self.sync() - if self.indexes.array[0][0] != n: - print(f"incorrect or duplicate index [{n}] indexes={self.indexes.array[0,:k]} distances={self.distances.array[0,:k]}") - #assert(self.indexes[0][0]==n) - correct=False + for n in tqdm.tqdm(range(self.shape[0]), file=sys.stdout): + with tqdm_redirect_stdout(): + assert(cudaKNN( + C.cast(self.vectors.ptr, C.c_void_p), + C.cast(self.vectors.ptr+n*self.shape[1]*self.dsize, C.c_void_p), + self.dsize, + self.shape[0], + 1, + self.shape[1], + k, + DistanceMetrics[metric], + C.cast(self.vector_norms.ptr, C.POINTER(C.c_float)) if self.metric == 'l2' else None, + C.cast(self.distances.ptr, C.POINTER(C.c_float)), + C.cast(self.indexes.ptr, C.POINTER(C.c_longlong)), + C.cast(int(self.stream), C.c_void_p) if self.stream else None + )) + self.sync() + if self.indexes.array[0][0] != n: + print(f"incorrect or duplicate index [{n}] indexes={self.indexes.array[0,:k]} distances={self.distances.array[0,:k]}") + #assert(self.indexes[0][0]==n) + correct=False return correct