-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathindex_manager.py
79 lines (59 loc) · 2.83 KB
/
index_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import Union, List, Dict
from babydragon.models.embedders.cohere import CohereEmbedder
from babydragon.memory.indexes.memory_index import MemoryIndex
from babydragon.memory.indexes.python_index import PythonIndex
import pkg_resources
import os
import importlib
# TODO: group these index-building helpers into a class.
def create_wiki_index():
    """Return the "wiki" MemoryIndex, creating it if the loaded one is empty.

    A batched, Cohere-embedded ``MemoryIndex`` named ``"wiki"`` is constructed
    with ``load=True`` first. If it holds no values, a new index is built with
    ``MemoryIndex.from_hf_dataset`` from the
    ``Cohere/wikipedia-22-12-simple-embeddings`` dataset. That build uses the
    ``title`` and ``text`` columns and reads embeddings from the ``emb`` column.
    """
    hf_dataset = "Cohere/wikipedia-22-12-simple-embeddings"
    wiki_index = MemoryIndex(name="wiki", load=True, is_batched=True, embedder=CohereEmbedder)  # pyright: ignore
    if len(wiki_index.values) > 0:
        # Something was loaded from storage; reuse it as-is.
        return wiki_index

    print("Index not found, creating new index")
    return MemoryIndex.from_hf_dataset(hf_dataset, ["title", "text"], embeddings_column="emb", name="wiki", is_batched=True, embedder=CohereEmbedder)  # pyright: ignore
def check_load(label: str) -> bool:
    """Return whether a saved index for *label* is present on disk.

    The index is looked up in ``storage/<label>/`` relative to the current
    working directory. It counts as present only if both
    ``<label>_index.faiss`` and ``<label>_values.json`` exist there.

    Note: when the directory exists, "Loading index from ..." is printed
    before the files are checked. This happens even though nothing is loaded
    here and the check may still return ``False``.
    """
    store_dir = os.path.join("storage", label)
    if not os.path.exists(store_dir):
        return False

    print(f"Loading index from {store_dir}")
    # Checked in order (FAISS index first, then values), short-circuiting.
    required_files = (f"{label}_index.faiss", f"{label}_values.json")
    return all(os.path.exists(os.path.join(store_dir, fname)) for fname in required_files)
def create_local_python_index(label: str) -> Union[MemoryIndex, PythonIndex, bool]:
    """Build a ``PythonIndex`` over the source directory of the importable module *label*.

    The module is imported with ``importlib.import_module``. Its directory is
    taken from ``__file__``, and the index is created as
    ``f"{label}_index_parallel"`` with class-level filtering, 16 workers,
    ``load=True`` and ``backup=False``.

    Args:
        label: Importable module or package name (e.g. ``"babydragon"``).

    Returns:
        The constructed ``PythonIndex``. Returns ``False`` if the module cannot be
        imported, has no ``__file__`` (e.g. a namespace package or a builtin
        module), or index construction raises. Failures are printed, never raised.
    """
    # TODO: a disabled fast path used to return
    # MemoryIndex(name=label, load=True, ...) when check_load(label) was true.
    try:
        module = importlib.import_module(label)
        # Namespace packages have __file__ set to None and builtin modules have
        # no __file__ at all. Either way there is no source directory to index.
        # Return False explicitly rather than falling through to an implicit None.
        module_file = getattr(module, "__file__", None)
        if module_file is None:
            print(f"Module {label!r} has no __file__; cannot locate its source directory")
            return False
        # NOTE(review): for a single-file module this is the directory that
        # contains it (e.g. site-packages), not just the module -- confirm intended.
        source_dir = os.path.dirname(os.path.abspath(module_file))
        return PythonIndex(
            source_dir,
            name=f"{label}_index_parallel",
            minify_code=False,
            load=True,
            max_workers=16,
            backup=False,
            filter="class",
        )
    except Exception as e:
        # Deliberately best-effort: report the failure and signal it with False
        # so callers can drop this index instead of aborting.
        print(str(e))
        return False
def retrieve_index(requested_index: List) -> Dict[str, Union[MemoryIndex, PythonIndex]]:
    """Build the requested indexes. The ``babydragon`` package index is always included.

    Args:
        requested_index: Request objects exposing ``source``, ``category`` and
            ``label`` attributes. These are presumably a request model defined
            elsewhere; confirm against callers. Only entries with
            ``source == "local"`` are handled:

            * ``category == "python"`` builds a Python index for the
              importable package ``label``.
            * ``category == "cohere"`` with ``label == "wiki"`` uses
              ``create_wiki_index()``.

            All other entries are ignored. ``None`` is treated as an empty list.

    Returns:
        Mapping from label to index. Any index whose builder reported failure
        (``False``, or ``None``) is omitted.
    """
    # Build the always-present babydragon index exactly once. A later request
    # with the same label replaces it, as before.
    indexes = {'babydragon': create_local_python_index('babydragon')}
    for request in requested_index or ():
        if request.source != "local":
            continue
        if request.category == 'python':
            indexes[request.label] = create_local_python_index(request.label)
        elif request.category == 'cohere' and request.label == "wiki":
            indexes[request.label] = create_wiki_index()
    # Always drop failed builds, including when nothing was requested. Use
    # identity checks so the index classes' __eq__ is never consulted.
    built = {label: index for label, index in indexes.items() if index is not None and index is not False}
    return built  # pyright: ignore