24 | 24 | logger = logging.getLogger(__name__) |
25 | 25 |
26 | 26 |
27 | | -def convert_embedding_model_format(model_json: Dict[str, Any]) -> Dict[str, Any]: |
28 | | - """ |
29 | | - Convert embedding model hub JSON format to Xinference expected format. |
30 | | - """ |
31 | | - logger.debug( |
32 | | - f"convert_embedding_model_format called for: {model_json.get('model_name', 'Unknown')}" |
33 | | - ) |
34 | | - |
35 | | - # Ensure required fields for embedding models |
36 | | - converted = model_json.copy() |
37 | | - |
38 | | - # Add missing required fields based on EmbeddingModelFamilyV2 requirements |
39 | | - if "version" not in converted: |
40 | | - converted["version"] = 2 |
41 | | - if "model_lang" not in converted: |
42 | | - converted["model_lang"] = ["en"] |
43 | | - |
44 | | - # Handle model_specs |
45 | | - if "model_specs" not in converted or not converted["model_specs"]: |
46 | | - converted["model_specs"] = [ |
47 | | - { |
48 | | - "model_format": "pytorch", |
49 | | - "model_size_in_billions": None, |
50 | | - "quantization": "none", |
51 | | - "model_hub": "huggingface", |
52 | | - } |
53 | | - ] |
54 | | - else: |
55 | | - # Ensure each spec has required fields |
56 | | - for spec in converted["model_specs"]: |
57 | | - if "quantization" not in spec: |
58 | | - spec["quantization"] = "none" |
59 | | - if "model_hub" not in spec: |
60 | | - spec["model_hub"] = "huggingface" |
61 | | - |
62 | | - return converted |
63 | | - |
64 | | - |
65 | 27 | from .core import ( |
66 | 28 | EMBEDDING_MODEL_DESCRIPTIONS, |
67 | 29 | EmbeddingModelFamilyV2, |
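The hunk above removes the ad-hoc convert_embedding_model_format normalizer; the hunk below replaces it with a converter that expands hub-style specs via flatten_quantizations from ..utils. As a rough illustration of what that flattening step is for, here is a minimal standalone sketch; the model_src layout, the key names, and the flatten_quantizations_sketch helper are assumptions for illustration only, not the actual xinference implementation.

# Minimal sketch (not the real ..utils helper): expands one hub-style spec
# into one flat spec per advertised quantization, which is roughly the job
# flatten_quantizations performs in the hunk below.
from typing import Any, Dict, List

def flatten_quantizations_sketch(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
    # Hypothetical hub layout: a "model_src" block whose hubs each list quantizations.
    flattened: List[Dict[str, Any]] = []
    for hub, src in spec.get("model_src", {}).items():
        for quant in src.get("quantizations", ["none"]):
            flat = {k: v for k, v in spec.items() if k != "model_src"}
            flat["model_hub"] = hub
            flat["quantization"] = quant
            flattened.append(flat)
    return flattened

hub_spec = {
    "model_format": "pytorch",
    "model_src": {
        "huggingface": {"model_id": "BAAI/bge-m3", "quantizations": ["none", "4-bit"]}
    },
}
print(flatten_quantizations_sketch(hub_spec))
# Prints two flat specs, one per quantization, each tagged with its model_hub.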
@@ -108,17 +70,90 @@ def register_custom_model(): |
108 | 70 |
109 | 71 |
110 | 72 | def register_builtin_model(): |
111 | | - from ..utils import load_complete_builtin_models |
| 73 | + # Use unified loading function with flatten_quantizations for embedding models |
| 74 | + from ..utils import flatten_quantizations, load_complete_builtin_models |
112 | 75 | from .embed_family import BUILTIN_EMBEDDING_MODELS |
113 | 76 |
114 | | - # Use unified loading function |
| 77 | + def convert_embedding_with_quantizations(model_json): |
| 78 | + if "model_specs" not in model_json: |
| 79 | + return model_json |
| 80 | + |
| 81 | + # Process each model_spec with flatten_quantizations (like builtin embedding loading) |
| 82 | + result = model_json.copy() |
| 83 | + flattened_specs = [] |
| 84 | + for spec in result["model_specs"]: |
| 85 | + if "model_src" in spec: |
| 86 | + flattened_specs.extend(flatten_quantizations(spec)) |
| 87 | + else: |
| 88 | + flattened_specs.append(spec) |
| 89 | + result["model_specs"] = flattened_specs |
| 90 | + |
| 91 | + return result |
| 92 | + |
115 | 93 | loaded_count = load_complete_builtin_models( |
116 | 94 | model_type="embedding", |
117 | | - builtin_registry=BUILTIN_EMBEDDING_MODELS, |
118 | | - convert_format_func=convert_embedding_model_format, |
| 95 | + builtin_registry={},  # Temporarily use an empty dict; registration is handled manually below |
| 96 | + convert_format_func=convert_embedding_with_quantizations, |
119 | 97 | model_class=EmbeddingModelFamilyV2, |
120 | 98 | ) |
121 | 99 |
| 100 | + # Manually handle embedding's special registration logic |
| 101 | + if loaded_count > 0: |
| 102 | + from ...constants import XINFERENCE_MODEL_DIR |
| 103 | + from ..custom import RegistryManager |
| 104 | + |
| 105 | + registry = RegistryManager.get_registry("embedding") |
| 106 | + existing_model_names = { |
| 107 | + spec.model_name for spec in registry.get_custom_models() |
| 108 | + } |
| 109 | + |
| 110 | + builtin_embedding_dir = os.path.join( |
| 111 | + XINFERENCE_MODEL_DIR, "v2", "builtin", "embedding" |
| 112 | + ) |
| 113 | + complete_json_path = os.path.join( |
| 114 | + builtin_embedding_dir, "embedding_models.json" |
| 115 | + ) |
| 116 | + |
| 117 | + if os.path.exists(complete_json_path): |
| 118 | + with codecs.open(complete_json_path, encoding="utf-8") as fd: |
| 119 | + model_data = json.load(fd) |
| 120 | + |
| 121 | + models_to_register = [] |
| 122 | + if isinstance(model_data, list): |
| 123 | + models_to_register = model_data |
| 124 | + elif isinstance(model_data, dict): |
| 125 | + if "model_name" in model_data: |
| 126 | + models_to_register = [model_data] |
| 127 | + else: |
| 128 | + for key, value in model_data.items(): |
| 129 | + if isinstance(value, dict) and "model_name" in value: |
| 130 | + models_to_register.append(value) |
| 131 | + |
| 132 | + for model_data in models_to_register: |
| 133 | + try: |
| 134 | + from ..utils import flatten_quantizations |
| 135 | + |
| 136 | + converted_data = model_data.copy() |
| 137 | + if "model_specs" in converted_data: |
| 138 | + flattened_specs = [] |
| 139 | + for spec in converted_data["model_specs"]: |
| 140 | + if "model_src" in spec: |
| 141 | + flattened_specs.extend(flatten_quantizations(spec)) |
| 142 | + else: |
| 143 | + flattened_specs.append(spec) |
| 144 | + converted_data["model_specs"] = flattened_specs |
| 145 | + builtin_embedding_family = EmbeddingModelFamilyV2.parse_obj( |
| 146 | + converted_data |
| 147 | + ) |
| 148 | + |
| 149 | + if builtin_embedding_family.model_name not in existing_model_names: |
| 150 | + register_embedding(builtin_embedding_family, persist=False) |
| 151 | + existing_model_names.add(builtin_embedding_family.model_name) |
| 152 | + except Exception as e: |
| 153 | + warnings.warn( |
| 154 | + f"Error parsing model {model_data.get('model_name', 'Unknown')}: {e}" |
| 155 | + ) |
| 156 | + |
122 | 157 | logger.info( |
123 | 158 | f"Successfully loaded {loaded_count} embedding models from complete JSON" |
124 | 159 | ) |
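The registration loop above has to accept three possible shapes for embedding_models.json: a list of model dicts, a single model dict, or a mapping of model name to model dict. Below is a standalone sketch of that shape normalization, using inline sample data instead of the real file; the normalize_model_payload name is hypothetical.

# Sketch of the shape normalization performed in register_builtin_model above.
import json
from typing import Any, Dict, List

def normalize_model_payload(model_data: Any) -> List[Dict[str, Any]]:
    if isinstance(model_data, list):
        return model_data
    if isinstance(model_data, dict):
        if "model_name" in model_data:
            return [model_data]  # a single model dict
        # Otherwise: a mapping of model_name -> model dict.
        return [
            v for v in model_data.values()
            if isinstance(v, dict) and "model_name" in v
        ]
    return []

samples = [
    '[{"model_name": "a"}, {"model_name": "b"}]',             # list form
    '{"model_name": "a"}',                                    # single-model form
    '{"a": {"model_name": "a"}, "b": {"model_name": "b"}}',   # name-keyed form
]
for raw in samples:
    print(normalize_model_payload(json.loads(raw)))

Each normalized entry is then flattened, parsed with EmbeddingModelFamilyV2.parse_obj, and passed to register_embedding(..., persist=False) only if no custom model with the same name is already registered, which is the deduplication the existing_model_names set provides.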