-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathregistry.py
100 lines (74 loc) · 2.55 KB
/
registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import re
from custom_logging import ollm_logging
# Name -> class mapping for transformers-based LLMs (populated via @register_model).
MODEL_REGISTRY = {}
# Name -> class mapping for llama.cpp LLMs (populated via @register_cpp_model).
CPP_MODEL_REGISTRY = {}
def wildcard_to_regex(pattern):
    """Translate a shell-style wildcard pattern into a regex string.

    All regex metacharacters in *pattern* are escaped first, then each
    literal ``*`` (escaped as ``\\*``) is rewritten to the regex ``.*``.
    """
    return re.escape(pattern).replace(r"\*", ".*")
def is_match(pattern, target):
    """Return True if *target* matches *pattern*.

    *pattern* may be a list of patterns (any single match wins), a string
    containing ``*`` wildcards (matched case-insensitively as a regex
    search), or a plain string (case-insensitive substring test).
    """
    if isinstance(pattern, list):
        for candidate in pattern:
            if is_match(candidate, target):
                return True
        return False
    if "*" not in pattern:
        # Plain pattern: case-insensitive substring containment.
        return pattern.lower() in target.lower()
    # Wildcard pattern: escape metacharacters, turn each '*' into '.*'.
    regex = re.escape(pattern).replace(r"\*", ".*")
    return re.search(regex, target, re.IGNORECASE) is not None
def register_model(name):
    """Class decorator that records a transformers LLM class under *name*.

    First registration wins: decorating a second class with an already
    registered *name* returns the previously registered class unchanged.
    """
    def _register(cls):
        try:
            # Already registered: hand back the existing class.
            return MODEL_REGISTRY[name]
        except KeyError:
            MODEL_REGISTRY[name] = cls
            return cls
    return _register
def register_cpp_model(name):
    """Class decorator that records a llama.cpp LLM class under *name*.

    First registration wins: decorating a second class with an already
    registered *name* returns the previously registered class unchanged.
    """
    def _register(cls):
        try:
            # Already registered: hand back the existing class.
            return CPP_MODEL_REGISTRY[name]
        except KeyError:
            CPP_MODEL_REGISTRY[name] = cls
            return cls
    return _register
def load_model(name):
    """Return the transformers LLM class registered under *name*.

    Args:
        name (str): Registry key used at registration time.

    Returns:
        class or None: The registered class, or None if *name* is unknown.
    """
    # dict.get collapses the membership test + lookup into one operation
    # and returns None for missing keys, matching the original contract.
    return MODEL_REGISTRY.get(name)
def load_cpp_model(name):
    """Return the llama.cpp LLM class registered under *name*.

    Args:
        name (str): Registry key used at registration time.

    Returns:
        class or None: The registered class, or None if *name* is unknown.
    """
    # dict.get collapses the membership test + lookup into one operation
    # and returns None for missing keys, matching the original contract.
    return CPP_MODEL_REGISTRY.get(name)
def get_llm_class(ollm_model_id: str):
    """Get LLM class.

    Args:
        ollm_model_id (str): String of LLM model ID.

    Returns:
        class: LLM class whose ``include_name`` matches the model ID
            (falls back to the "default" registry entry).

    Raises:
        ValueError: If no registered class matches and no "default"
            entry exists. (The original code crashed here with an
            AttributeError on ``None.__name__``.)
    """
    llm_class = None
    # Only the values are needed (PERF102); the loop deliberately keeps
    # the LAST matching class, so later registrations take precedence.
    for model_class in MODEL_REGISTRY.values():
        if is_match(model_class.include_name, ollm_model_id):
            llm_class = model_class
    if llm_class is None:
        llm_class = MODEL_REGISTRY.get("default")
    if llm_class is None:
        raise ValueError(f"No LLM class registered for model ID: {ollm_model_id}")
    ollm_logging.debug(f"Using model class: {llm_class.__name__}")
    # NOTE: model_id is stored on the class object itself, so it is shared
    # by every user of that class — the most recent caller wins.
    llm_class.model_id = ollm_model_id
    return llm_class
def get_cpp_llm_class(cpp_ollm_model_id: str):
    """Get llama.cpp LLM class.

    Args:
        cpp_ollm_model_id (str): String of llama.cpp LLM model ID.

    Returns:
        class: llama.cpp LLM class whose ``include_name`` matches the
            model ID (falls back to the "default" registry entry).

    Raises:
        ValueError: If no registered class matches and no "default"
            entry exists. (The original code crashed here with an
            AttributeError on ``None.__name__``.)
    """
    llm_class = None
    # Only the values are needed (PERF102); the loop deliberately keeps
    # the LAST matching class, so later registrations take precedence.
    for model_class in CPP_MODEL_REGISTRY.values():
        if is_match(model_class.include_name, cpp_ollm_model_id):
            llm_class = model_class
    if llm_class is None:
        llm_class = CPP_MODEL_REGISTRY.get("default")
    if llm_class is None:
        raise ValueError(f"No llama.cpp LLM class registered for model ID: {cpp_ollm_model_id}")
    ollm_logging.debug(f"Using model class: {llm_class.__name__}")
    # NOTE: model_id is stored on the class object itself, so it is shared
    # by every user of that class — the most recent caller wins.
    llm_class.model_id = cpp_ollm_model_id
    return llm_class