From 5f8fc998c8c6b29553bff2146bc64166072f8b6a Mon Sep 17 00:00:00 2001 From: 0x376h <0x376h@gmail.com> Date: Wed, 18 Oct 2023 01:09:57 +0800 Subject: [PATCH] feat: add method get_by_id to check id exists (#63) * add num_docs add num_docs * Update test_inmemory_vectordb.py * Update test_hnswlib_vectordb.py * add * add * Update test_inmemory_vectordb.py * change method name * commit some miss files * blank spaces change * Update test_inmemory_vectordb.py * Update test_inmemory_vectordb.py --- tests/unit/test_hnswlib_vectordb.py | 13 +++++++++++-- tests/unit/test_inmemory_vectordb.py | 13 +++++++++++-- vectordb/db/base.py | 8 ++++++++ vectordb/db/executors/hnsw_indexer.py | 5 ++++- vectordb/db/executors/inmemory_exact_indexer.py | 5 ++++- 5 files changed, 38 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_hnswlib_vectordb.py b/tests/unit/test_hnswlib_vectordb.py index fe6aa62..cc45e14 100644 --- a/tests/unit/test_hnswlib_vectordb.py +++ b/tests/unit/test_hnswlib_vectordb.py @@ -175,5 +175,14 @@ def test_hnswlib_num_dos(tmpdir): db = HNSWVectorDB[MyDoc](workspace=str(tmpdir)) doc_list = [MyDoc(text=f'toy doc {i}', embedding=np.random.rand(128)) for i in range(1000)] db.index(inputs=DocList[MyDoc](doc_list)) - x=db.num_docs() - assert x['num_docs']==1000 + x = db.num_docs() + assert x['num_docs'] == 1000 + +def test_hnswlib_query_id(tmpdir): + db = HNSWVectorDB[MyDoc](workspace=str(tmpdir)) + doc_list = [MyDoc(id='test_1',text=f'test', embedding=np.random.rand(128)) ] + db.index(inputs=DocList[MyDoc](doc_list)) + queryobjtest1 = db.get_by_id('test_1') + queryobjtest2 = db.get_by_id('test_2') + assert queryobjtest2 is None + assert queryobjtest1.id == 'test_1' diff --git a/tests/unit/test_inmemory_vectordb.py b/tests/unit/test_inmemory_vectordb.py index cc812e4..77b8be6 100644 --- a/tests/unit/test_inmemory_vectordb.py +++ b/tests/unit/test_inmemory_vectordb.py @@ -177,5 +177,14 @@ def test_inmemory_num_dos(tmpdir): db = 
InMemoryExactNNVectorDB[MyDoc](workspace=str(tmpdir)) doc_list = [MyDoc(text=f'toy doc {i}', embedding=np.random.rand(128)) for i in range(1000)] db.index(inputs=DocList[MyDoc](doc_list)) - x=db.num_docs() - assert x['num_docs']==1000 + x = db.num_docs() + assert x['num_docs'] == 1000 + +def test_inmemory_query_id(tmpdir): + db = InMemoryExactNNVectorDB[MyDoc](workspace=str(tmpdir)) + doc_list = [MyDoc(id='test_1',text=f'test', embedding=np.random.rand(128)) ] + db.index(inputs=DocList[MyDoc](doc_list)) + queryobjtest1 = db.get_by_id('test_1') + queryobjtest2 = db.get_by_id('test_2') + assert queryobjtest2 is None + assert queryobjtest1.id == 'test_1' diff --git a/vectordb/db/base.py b/vectordb/db/base.py index 8338889..6cbeeba 100644 --- a/vectordb/db/base.py +++ b/vectordb/db/base.py @@ -229,6 +229,14 @@ async def _deploy(): def num_docs(self, **kwargs): return self._executor.num_docs() + + def get_by_id(self,info_id, **kwargs): + ret = None + try: + ret = self._executor.get_by_id(info_id) + except KeyError: + pass + return ret @pass_kwargs_as_params @unify_input_output diff --git a/vectordb/db/executors/hnsw_indexer.py b/vectordb/db/executors/hnsw_indexer.py index c6447bd..6c8c1fc 100644 --- a/vectordb/db/executors/hnsw_indexer.py +++ b/vectordb/db/executors/hnsw_indexer.py @@ -106,7 +106,10 @@ async def async_update(self, docs, *args, **kwargs): def num_docs(self, **kwargs): return {'num_docs': self._indexer.num_docs()} - + + def get_by_id(self,info_id,**kwargs): + return self._indexer[info_id] + def snapshot(self, snapshot_dir): # TODO: Maybe copy the work_dir to workspace if `handle` is False raise NotImplementedError('Act as not implemented') diff --git a/vectordb/db/executors/inmemory_exact_indexer.py b/vectordb/db/executors/inmemory_exact_indexer.py index aec7ab8..3af8a2b 100644 --- a/vectordb/db/executors/inmemory_exact_indexer.py +++ b/vectordb/db/executors/inmemory_exact_indexer.py @@ -72,7 +72,10 @@ def update(self, docs, *args, **kwargs): def 
num_docs(self, *args, **kwargs): return {'num_docs': self._indexer.num_docs()} - + + def get_by_id(self,info_id,**kwargs): + return self._indexer[info_id] + def snapshot(self, snapshot_dir): snapshot_file = f'{snapshot_dir}/index.bin' self._indexer.persist(snapshot_file)