Commit 0199099

Commit message: num2
1 parent 5d4aefc · commit 0199099

File tree: 8 files changed, +203 -271 lines changed

xinference/core/worker.py

Lines changed: 10 additions & 4 deletions
@@ -727,10 +727,15 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
         if "model_src" in model_json:
             # Simple flat format with model_src at top level
             flattened_list = flatten_model_src(model_json)
-            converted_model_json = flattened_list[0] if flattened_list else model_json
-        elif "model_specs" in model_json and isinstance(model_json["model_specs"], list):
+            converted_model_json = (
+                flattened_list[0] if flattened_list else model_json
+            )
+        elif "model_specs" in model_json and isinstance(
+            model_json["model_specs"], list
+        ):
             # LLM/embedding/rerank format with model_specs
             from ..model.utils import flatten_quantizations
+
             converted_model_json = model_json.copy()
 
             # Process all model_specs using flatten_quantizations - exactly like builtin models
@@ -747,7 +752,9 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
             # Use all flattened specs like builtin models
             if flattened_specs:
                 converted_model_json["model_specs"] = flattened_specs
-                logger.info(f"Processed {len(flattened_specs)} model specifications for {model_name}")
+                logger.info(
+                    f"Processed {len(flattened_specs)} model specifications for {model_name}"
+                )
         else:
             # Already flattened format, use as-is
             converted_model_json = model_json
@@ -857,7 +864,6 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
                     f"Failed to register model '{model_spec.model_name}': {str(e)}"
                 )
 
-
     @log_async(logger=logger)
     async def update_model_type(self, model_type: str):
         """

xinference/model/audio/__init__.py

Lines changed: 20 additions & 47 deletions
@@ -27,50 +27,6 @@
 logger = logging.getLogger(__name__)
 
 
-def convert_audio_model_format(model_json: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Convert audio model hub JSON format to Xinference expected format.
-    """
-    logger.debug(
-        f"convert_audio_model_format called for: {model_json.get('model_name', 'Unknown')}"
-    )
-
-    # Apply conversion logic to handle null model_id and other issues
-    if model_json.get("model_id") is None and "model_src" in model_json:
-        model_src = model_json["model_src"]
-        # Extract model_id from available sources
-        if "huggingface" in model_src and "model_id" in model_src["huggingface"]:
-            model_json["model_id"] = model_src["huggingface"]["model_id"]
-        elif "modelscope" in model_src and "model_id" in model_src["modelscope"]:
-            model_json["model_id"] = model_src["modelscope"]["model_id"]
-
-        # Extract model_revision if available
-        if model_json.get("model_revision") is None:
-            if (
-                "huggingface" in model_src
-                and "model_revision" in model_src["huggingface"]
-            ):
-                model_json["model_revision"] = model_src["huggingface"][
-                    "model_revision"
-                ]
-            elif (
-                "modelscope" in model_src
-                and "model_revision" in model_src["modelscope"]
-            ):
-                model_json["model_revision"] = model_src["modelscope"]["model_revision"]
-
-    # Ensure required fields for audio models
-    if "version" not in model_json:
-        model_json["version"] = 2
-    if "model_lang" not in model_json:
-        model_json["model_lang"] = [
-            "en",
-            "zh",
-        ]  # Audio models often support multiple languages
-
-    return model_json
-
-
 from .core import (
     AUDIO_MODEL_DESCRIPTIONS,
     AudioModelFamilyV2,
@@ -110,13 +66,30 @@ def register_custom_model():
 
 
 def register_builtin_model():
-    from ..utils import load_complete_builtin_models
+    # Use unified loading function with flatten_model_src + audio-specific defaults
+    from ..utils import flatten_model_src, load_complete_builtin_models
+
+    def convert_audio_with_flatten(model_json):
+        flattened_list = flatten_model_src(model_json)
+        if not flattened_list:
+            return model_json
+
+        result = flattened_list[0]
+
+        # Add required defaults for audio models
+        if "multilingual" not in result:
+            result["multilingual"] = True
+        if "model_lang" not in result:
+            result["model_lang"] = ["en", "zh"]
+        if "version" not in result:
+            result["version"] = 2
+
+        return result
 
-    # Use unified loading function
     loaded_count = load_complete_builtin_models(
         model_type="audio",
         builtin_registry=BUILTIN_AUDIO_MODELS,
-        convert_format_func=convert_audio_model_format,
+        convert_format_func=convert_audio_with_flatten,
        model_class=AudioModelFamilyV2,
     )
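
convert_audio_with_flatten leans on flatten_model_src from xinference/model/utils.py. The sketch below shows the behaviour the change appears to assume -- one flattened dict per hub listed under model_src, with hub-specific fields such as model_id promoted to the top level -- and is an illustration of that assumption, not the real helper.

# Assumed behaviour, for illustration only -- not the actual flatten_model_src.
from typing import Any, Dict, List


def flatten_model_src_sketch(model_json: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Produce one flat dict per hub listed under model_src."""
    flattened = []
    for hub, src in model_json.get("model_src", {}).items():
        entry = {k: v for k, v in model_json.items() if k != "model_src"}
        entry["model_hub"] = hub
        entry.update(src)  # e.g. model_id / model_revision for that hub
        flattened.append(entry)
    return flattened

Under that reading, flattened_list[0] keeps the first hub's entry, and the audio-specific defaults (multilingual, model_lang, version) are filled in only when the hub JSON omits them.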

xinference/model/embedding/__init__.py

Lines changed: 77 additions & 42 deletions
@@ -24,44 +24,6 @@
 logger = logging.getLogger(__name__)
 
 
-def convert_embedding_model_format(model_json: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Convert embedding model hub JSON format to Xinference expected format.
-    """
-    logger.debug(
-        f"convert_embedding_model_format called for: {model_json.get('model_name', 'Unknown')}"
-    )
-
-    # Ensure required fields for embedding models
-    converted = model_json.copy()
-
-    # Add missing required fields based on EmbeddingModelFamilyV2 requirements
-    if "version" not in converted:
-        converted["version"] = 2
-    if "model_lang" not in converted:
-        converted["model_lang"] = ["en"]
-
-    # Handle model_specs
-    if "model_specs" not in converted or not converted["model_specs"]:
-        converted["model_specs"] = [
-            {
-                "model_format": "pytorch",
-                "model_size_in_billions": None,
-                "quantization": "none",
-                "model_hub": "huggingface",
-            }
-        ]
-    else:
-        # Ensure each spec has required fields
-        for spec in converted["model_specs"]:
-            if "quantization" not in spec:
-                spec["quantization"] = "none"
-            if "model_hub" not in spec:
-                spec["model_hub"] = "huggingface"
-
-    return converted
-
-
 from .core import (
     EMBEDDING_MODEL_DESCRIPTIONS,
     EmbeddingModelFamilyV2,
@@ -108,17 +70,90 @@ def register_custom_model():
 
 
 def register_builtin_model():
-    from ..utils import load_complete_builtin_models
+    # Use unified loading function with flatten_quantizations for embedding models
+    from ..utils import flatten_quantizations, load_complete_builtin_models
     from .embed_family import BUILTIN_EMBEDDING_MODELS
 
-    # Use unified loading function
+    def convert_embedding_with_quantizations(model_json):
+        if "model_specs" not in model_json:
+            return model_json
+
+        # Process each model_spec with flatten_quantizations (like builtin embedding loading)
+        result = model_json.copy()
+        flattened_specs = []
+        for spec in result["model_specs"]:
+            if "model_src" in spec:
+                flattened_specs.extend(flatten_quantizations(spec))
+            else:
+                flattened_specs.append(spec)
+        result["model_specs"] = flattened_specs
+
+        return result
+
     loaded_count = load_complete_builtin_models(
         model_type="embedding",
-        builtin_registry=BUILTIN_EMBEDDING_MODELS,
-        convert_format_func=convert_embedding_model_format,
+        builtin_registry={},  # Temporarily use empty dict, we handle it manually
+        convert_format_func=convert_embedding_with_quantizations,
         model_class=EmbeddingModelFamilyV2,
     )
 
+    # Manually handle embedding's special registration logic
+    if loaded_count > 0:
+        from ...constants import XINFERENCE_MODEL_DIR
+        from ..custom import RegistryManager
+
+        registry = RegistryManager.get_registry("embedding")
+        existing_model_names = {
+            spec.model_name for spec in registry.get_custom_models()
+        }
+
+        builtin_embedding_dir = os.path.join(
+            XINFERENCE_MODEL_DIR, "v2", "builtin", "embedding"
+        )
+        complete_json_path = os.path.join(
+            builtin_embedding_dir, "embedding_models.json"
+        )
+
+        if os.path.exists(complete_json_path):
+            with codecs.open(complete_json_path, encoding="utf-8") as fd:
+                model_data = json.load(fd)
+
+            models_to_register = []
+            if isinstance(model_data, list):
+                models_to_register = model_data
+            elif isinstance(model_data, dict):
+                if "model_name" in model_data:
+                    models_to_register = [model_data]
+                else:
+                    for key, value in model_data.items():
+                        if isinstance(value, dict) and "model_name" in value:
+                            models_to_register.append(value)
+
+            for model_data in models_to_register:
+                try:
+                    from ..utils import flatten_quantizations
+
+                    converted_data = model_data.copy()
+                    if "model_specs" in converted_data:
+                        flattened_specs = []
+                        for spec in converted_data["model_specs"]:
+                            if "model_src" in spec:
+                                flattened_specs.extend(flatten_quantizations(spec))
+                            else:
+                                flattened_specs.append(spec)
+                        converted_data["model_specs"] = flattened_specs
+                    builtin_embedding_family = EmbeddingModelFamilyV2.parse_obj(
+                        converted_data
+                    )
+
+                    if builtin_embedding_family.model_name not in existing_model_names:
+                        register_embedding(builtin_embedding_family, persist=False)
+                        existing_model_names.add(builtin_embedding_family.model_name)
+                except Exception as e:
+                    warnings.warn(
+                        f"Error parsing model {model_data.get('model_name', 'Unknown')}: {e}"
+                    )
+
     logger.info(
         f"Successfully loaded {loaded_count} embedding models from complete JSON"
     )
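
Both convert_embedding_with_quantizations and the manual registration loop expand specs through flatten_quantizations. The snippet below is a hypothetical input/output pair showing the intent -- one flattened spec per quantization -- not the helper's real code, and field names beyond model_format, model_src and quantization are assumptions.

# Hypothetical example -- the "quantizations" key and the exact output fields
# are assumptions for illustration, not the documented schema.
spec = {
    "model_format": "ggufv2",
    "model_src": {
        "huggingface": {
            "model_id": "org/some-embedding",
            "quantizations": ["q4_0", "q8_0"],
        }
    },
}

# flatten_quantizations(spec) would then be expected to yield something like:
# [
#     {"model_format": "ggufv2", "model_hub": "huggingface",
#      "model_id": "org/some-embedding", "quantization": "q4_0"},
#     {"model_format": "ggufv2", "model_hub": "huggingface",
#      "model_id": "org/some-embedding", "quantization": "q8_0"},
# ]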

xinference/model/image/__init__.py

Lines changed: 7 additions & 63 deletions
@@ -24,67 +24,6 @@
 logger = logging.getLogger(__name__)
 
 
-def convert_image_model_format(model_json: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Convert image model hub JSON format to Xinference expected format.
-    """
-    logger.debug(
-        f"convert_image_model_format called for: {model_json.get('model_name', 'Unknown')}"
-    )
-
-    # Ensure required fields for image models
-    converted = model_json.copy()
-
-    # Add missing required fields
-    if "version" not in converted:
-        converted["version"] = 2
-    if "model_lang" not in converted:
-        converted["model_lang"] = ["en"]
-
-    # Handle missing model_id and model_revision
-    if converted.get("model_id") is None and "model_src" in converted:
-        model_src = converted["model_src"]
-        # Extract model_id from available sources
-        if "huggingface" in model_src and "model_id" in model_src["huggingface"]:
-            converted["model_id"] = model_src["huggingface"]["model_id"]
-        elif "modelscope" in model_src and "model_id" in model_src["modelscope"]:
-            converted["model_id"] = model_src["modelscope"]["model_id"]
-
-    if converted.get("model_revision") is None and "model_src" in converted:
-        model_src = converted["model_src"]
-        # Extract model_revision if available
-        if "huggingface" in model_src and "model_revision" in model_src["huggingface"]:
-            converted["model_revision"] = model_src["huggingface"]["model_revision"]
-        elif "modelscope" in model_src and "model_revision" in model_src["modelscope"]:
-            converted["model_revision"] = model_src["modelscope"]["model_revision"]
-
-    # Set defaults if still missing
-    if converted.get("model_id") is None:
-        converted["model_id"] = converted.get("model_name", "unknown")
-    if converted.get("model_revision") is None:
-        converted["model_revision"] = "main"
-
-    # Handle model_specs
-    if "model_specs" not in converted or not converted["model_specs"]:
-        converted["model_specs"] = [
-            {
-                "model_format": "pytorch",
-                "model_size_in_billions": None,
-                "quantization": "none",
-                "model_hub": "huggingface",
-            }
-        ]
-    else:
-        # Ensure each spec has required fields
-        for spec in converted["model_specs"]:
-            if "quantization" not in spec:
-                spec["quantization"] = "none"
-            if "model_hub" not in spec:
-                spec["model_hub"] = "huggingface"
-
-    return converted
-
-
 from .core import (
     BUILTIN_IMAGE_MODELS,
     IMAGE_MODEL_DESCRIPTIONS,
@@ -158,8 +97,13 @@ def register_builtin_model():
         # Register all models from the complete JSON
         for model_data in models_to_register:
            try:
-                # Convert format if needed
-                converted_data = convert_image_model_format(model_data)
+                # Convert format using flatten_model_src
+                from ..utils import flatten_model_src
+
+                flattened_list = flatten_model_src(model_data)
+                converted_data = (
+                    flattened_list[0] if flattened_list else model_data
+                )
                 builtin_image_family = ImageModelFamilyV2.parse_obj(
                     converted_data
                 )

xinference/model/llm/__init__.py

Lines changed: 13 additions & 8 deletions
@@ -23,8 +23,6 @@
 logger = logging.getLogger(__name__)
 
 
-
-
 from .core import (
     LLM,
     LLM_VERSION_INFOS,
@@ -137,10 +135,9 @@ def register_custom_model():
 
 
 def register_builtin_model():
-    from ..utils import load_complete_builtin_models
-
     # Use unified loading function with flatten_quantizations for LLM
-    from ..utils import flatten_quantizations
+    from ..utils import flatten_quantizations, load_complete_builtin_models
+
     def convert_llm_with_quantizations(model_json):
         if "model_specs" not in model_json:
             return model_json
@@ -194,9 +191,17 @@ def convert_llm_with_quantizations(model_json):
 
         for model_data in models_to_register:
             try:
-                from ..utils import flatten_model_src
-                flattened_list = flatten_model_src(model_data)
-                converted_data = flattened_list[0] if flattened_list else model_data
+                from ..utils import flatten_quantizations
+
+                converted_data = model_data.copy()
+                if "model_specs" in converted_data:
+                    flattened_specs = []
+                    for spec in converted_data["model_specs"]:
+                        if "model_src" in spec:
+                            flattened_specs.extend(flatten_quantizations(spec))
+                        else:
+                            flattened_specs.append(spec)
+                    converted_data["model_specs"] = flattened_specs
                 builtin_llm_family = LLMFamilyV2.parse_obj(converted_data)
 
                 if builtin_llm_family.model_name not in existing_model_names:
