Commit 5d4aefc

num2:llm

1 parent 8df104a · commit 5d4aefc

File tree

3 files changed: +67 −240 lines changed

xinference/core/worker.py

Lines changed: 32 additions & 183 deletions
```diff
@@ -719,9 +719,38 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
             logger.error(f"Invalid model name format: {model_name}")
             raise ValueError(f"Invalid model name format: {model_name}")
 
-        # Convert model hub JSON format to Xinference expected format
+        # Convert model hub JSON format to Xinference expected format using flatten_model_src
         try:
-            converted_model_json = self._convert_model_json_format(model_json)
+            from ..model.utils import flatten_model_src
+
+            # Check different model format types
+            if "model_src" in model_json:
+                # Simple flat format with model_src at top level
+                flattened_list = flatten_model_src(model_json)
+                converted_model_json = flattened_list[0] if flattened_list else model_json
+            elif "model_specs" in model_json and isinstance(model_json["model_specs"], list):
+                # LLM/embedding/rerank format with model_specs
+                from ..model.utils import flatten_quantizations
+                converted_model_json = model_json.copy()
+
+                # Process all model_specs using flatten_quantizations - exactly like builtin models
+                flattened_specs = []
+                for spec in converted_model_json["model_specs"]:
+                    if "model_src" in spec:
+                        # Use flatten_quantizations like builtin LLM loading
+                        quantized_specs = flatten_quantizations(spec)
+                        if quantized_specs:
+                            flattened_specs.extend(quantized_specs)
+                        else:
+                            flattened_specs.append(spec)
+
+                # Use all flattened specs like builtin models
+                if flattened_specs:
+                    converted_model_json["model_specs"] = flattened_specs
+                    logger.info(f"Processed {len(flattened_specs)} model specifications for {model_name}")
+            else:
+                # Already flattened format, use as-is
+                converted_model_json = model_json
         except Exception as e:
             logger.error(f"Format conversion failed: {str(e)}", exc_info=True)
             raise ValueError(f"Failed to convert model JSON format: {str(e)}")
```
```diff
@@ -828,187 +857,7 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
                 f"Failed to register model '{model_spec.model_name}': {str(e)}"
             )
 
-    def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Convert model hub JSON format to Xinference expected format.
-
-        The input format uses nested 'model_src' structure, but Xinference expects
-        flattened fields at the spec level.
-
-        For LLM/embedding/rerank models: uses model_specs structure
-        For image/audio/video models: uses flat structure with direct fields
-        """
-        # Determine if this is an image/audio/video model (flat structure) or LLM/embedding/rerank (model_specs structure)
-        flat_model_types = ["image", "audio", "video"]
-        model_type = None
-
-        # Try to determine model type from context or model_ability
-        if "model_ability" in model_json:
-            abilities = model_json["model_ability"]
-            if isinstance(abilities, list):
-                if (
-                    "text2img" in abilities
-                    or "image2image" in abilities
-                    or "ocr" in abilities
-                ):
-                    model_type = "image"
-                elif "auto-speech" in abilities or "text-to-speech" in abilities:
-                    model_type = "audio"
-                elif "text-to-video" in abilities:
-                    model_type = "video"
-
-        if model_type in flat_model_types:
-            # Handle image/audio/video models with flat structure
-
-            if "model_src" in model_json:
-                model_src = model_json["model_src"]
-
-                # Extract fields from model_src to top level
-                if "huggingface" in model_src:
-                    hf_data = model_src["huggingface"]
-
-                    if "model_id" in hf_data and model_json.get("model_id") is None:
-                        model_json["model_id"] = hf_data["model_id"]
-
-                    if (
-                        "model_revision" in hf_data
-                        and model_json.get("model_revision") is None
-                    ):
-                        model_json["model_revision"] = hf_data["model_revision"]
-
-                # Remove model_src field as it's not needed in the final format
-                del model_json["model_src"]
-
-            # Set required defaults for image models
-            if model_json.get("model_hub") is None:
-                model_json["model_hub"] = "huggingface"
-
-            # Add null fields for completeness based on builtin image model structure
-            null_fields = [
-                "cache_config",
-                "controlnet",
-                "gguf_model_id",
-                "gguf_quantizations",
-                "gguf_model_file_name_template",
-                "lightning_model_id",
-                "lightning_versions",
-                "lightning_model_file_name_template",
-                "model_uri",
-            ]
-            for field in null_fields:
-                if field not in model_json:
-                    model_json[field] = None
-
-            # Add empty dict fields if missing
-            dict_fields = ["default_model_config", "default_generate_config"]
-            for field in dict_fields:
-                if field not in model_json:
-                    model_json[field] = {}
-
-        else:
-            # Handle LLM/embedding/rerank models with model_specs structure
-
-            # Handle model_specs - if multiple formats are provided, select the first one
-            if "model_specs" in model_json and isinstance(
-                model_json["model_specs"], list
-            ):
-                if len(model_json["model_specs"]) > 1:
-                    # For add_model, we'll use the first spec as the primary one
-                    # The other specs will be ignored for this registration
-                    model_json["model_specs"] = [model_json["model_specs"][0]]
-
-                # Fix missing quantization field for pytorch/mlx specs
-                spec = model_json["model_specs"][0]
-                if "quantization" not in spec:
-                    model_format = spec.get("model_format", "")
-                    if model_format in ["pytorch", "gptq", "awq", "fp8", "bnb"]:
-                        # Extract quantization from model_src if available
-                        if "model_src" in spec and "huggingface" in spec["model_src"]:
-                            quantizations = spec["model_src"]["huggingface"].get(
-                                "quantizations", []
-                            )
-                            if quantizations:
-                                spec["quantization"] = quantizations[
-                                    0
-                                ]  # Use first quantization
-                            else:
-                                spec["quantization"] = "none"  # Default quantization
-                        else:
-                            spec["quantization"] = "none"  # Default quantization
-                    elif model_format == "mlx":
-                        # Extract quantization from model_src if available
-                        if "model_src" in spec and "huggingface" in spec["model_src"]:
-                            quantizations = spec["model_src"]["huggingface"].get(
-                                "quantizations", []
-                            )
-                            if quantizations:
-                                spec["quantization"] = quantizations[
-                                    0
-                                ]  # Use first quantization
-                            else:
-                                spec["quantization"] = "4bit"  # Default for MLX
-                        else:
-                            spec["quantization"] = "4bit"  # Default for MLX
-                    elif model_format == "ggufv2":
-                        # GGUF models need to extract quantization from filename template
-                        if "model_file_name_template" in spec:
-                            template = spec["model_file_name_template"]
-                            if "{quantization}" not in template:
-                                # Try to extract from model_id or set default
-                                spec["quantization"] = (
-                                    "Q4_K_M"  # Common GGUF quantization
-                                )
-                        else:
-                            spec["quantization"] = "Q4_K_M"  # Default for GGUF
-
-                # Add missing required fields for LLM-style specs
-                if "model_hub" not in spec:
-                    spec["model_hub"] = "huggingface"
-
-                if "model_id" not in spec:
-                    if "model_src" in spec and "huggingface" in spec["model_src"]:
-                        spec["model_id"] = spec["model_src"]["huggingface"]["model_id"]
-
-                if "model_revision" not in spec:
-                    if "model_src" in spec and "huggingface" in spec["model_src"]:
-                        spec["model_revision"] = spec["model_src"]["huggingface"][
-                            "model_revision"
-                        ]
-
-                # Remove model_src from spec as it's not needed in the final format
-                if "model_src" in spec:
-                    del spec["model_src"]
-
-        # Handle legacy top-level model_src for backward compatibility
-        if model_json.get("model_id") is None and "model_src" in model_json:
-            model_src = model_json["model_src"]
-
-            if "huggingface" in model_src and "model_id" in model_src["huggingface"]:
-                model_json["model_id"] = model_src["huggingface"]["model_id"]
-            elif "modelscope" in model_src and "model_id" in model_src["modelscope"]:
-                model_json["model_id"] = model_src["modelscope"]["model_id"]
-
-            if model_json.get("model_revision") is None:
-                if (
-                    "huggingface" in model_src
-                    and "model_revision" in model_src["huggingface"]
-                ):
-                    model_json["model_revision"] = model_src["huggingface"][
-                        "model_revision"
-                    ]
-                elif (
-                    "modelscope" in model_src
-                    and "model_revision" in model_src["modelscope"]
-                ):
-                    model_json["model_revision"] = model_src["modelscope"][
-                        "model_revision"
-                    ]
-
-            # Remove top-level model_src field as it's not needed in the final format
-            del model_json["model_src"]
-
-        return model_json
-
+
     @log_async(logger=logger)
     async def update_model_type(self, model_type: str):
         """
```

xinference/model/llm/__init__.py

Lines changed: 22 additions & 57 deletions
```diff
@@ -23,60 +23,6 @@
 logger = logging.getLogger(__name__)
 
 
-def convert_model_json_format(model_json: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Convert model hub JSON format to Xinference expected format.
-
-    This is a standalone version of the conversion logic from supervisor.py.
-    """
-    logger.debug(
-        f"convert_model_json_format called for: {model_json.get('model_name', 'Unknown')}"
-    )
-
-    # If model_specs is missing, provide a default minimal spec
-    if "model_specs" not in model_json or not model_json["model_specs"]:
-        logger.debug("model_specs missing or empty, creating default spec")
-        return {
-            **model_json,
-            "version": 2,  # Add missing required field
-            "model_lang": ["en"],  # Add missing required field
-            "model_specs": [
-                {
-                    "model_format": "pytorch",
-                    "model_size_in_billions": None,
-                    "quantization": "none",
-                    "model_file_name_template": "model.bin",
-                    "model_hub": "huggingface",
-                }
-            ],
-        }
-
-    converted = model_json.copy()
-    converted_specs = []
-
-    # Ensure required top-level fields
-    if "version" not in converted:
-        converted["version"] = 2
-    if "model_lang" not in converted:
-        converted["model_lang"] = ["en"]
-
-    for spec in model_json["model_specs"]:
-        model_format = spec.get("model_format", "pytorch")
-        model_size = spec.get("model_size_in_billions")
-
-        # Ensure required fields
-        converted_spec = spec.copy()
-        if "quantization" not in converted_spec:
-            converted_spec["quantization"] = "none"
-        if "model_file_name_template" not in converted_spec:
-            converted_spec["model_file_name_template"] = "model.bin"
-        if "model_hub" not in converted_spec:
-            converted_spec["model_hub"] = "huggingface"
-
-        converted_specs.append(converted_spec)
-
-    converted["model_specs"] = converted_specs
-    return converted
 
 
 from .core import (
```

```diff
@@ -193,11 +139,28 @@ def register_custom_model():
 def register_builtin_model():
     from ..utils import load_complete_builtin_models
 
-    # Use unified loading function, but LLM needs special handling
+    # Use unified loading function with flatten_quantizations for LLM
+    from ..utils import flatten_quantizations
+
+    def convert_llm_with_quantizations(model_json):
+        if "model_specs" not in model_json:
+            return model_json
+
+        # Process each model_spec with flatten_quantizations (like builtin LLM loading)
+        result = model_json.copy()
+        flattened_specs = []
+        for spec in result["model_specs"]:
+            if "model_src" in spec:
+                flattened_specs.extend(flatten_quantizations(spec))
+            else:
+                flattened_specs.append(spec)
+        result["model_specs"] = flattened_specs
+
+        return result
 
     loaded_count = load_complete_builtin_models(
         model_type="llm",
         builtin_registry={},  # Temporarily use empty dict, we handle it manually
-        convert_format_func=convert_model_json_format,
+        convert_format_func=convert_llm_with_quantizations,
         model_class=LLMFamilyV2,
     )
```
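
The helper registered as `convert_format_func` leans entirely on `flatten_quantizations` to expand each spec's quantization list. A minimal stand-in, assuming the helper produces one flat spec per entry in `model_src.<hub>.quantizations` (the real implementation lives in `xinference/model/utils.py` and fills in further hub fields):

```python
def fake_flatten_quantizations(spec: dict) -> list:
    # Stand-in for xinference.model.utils.flatten_quantizations, assuming it
    # emits one flattened spec per listed quantization; for illustration only.
    hub, src = next(iter(spec["model_src"].items()))
    return [
        {
            **{k: v for k, v in spec.items() if k != "model_src"},
            "model_hub": hub,
            "model_id": src["model_id"],
            "quantization": quantization,
        }
        for quantization in src.get("quantizations", ["none"])
    ]

spec = {
    "model_format": "pytorch",
    "model_size_in_billions": 7,
    "model_src": {
        "huggingface": {"model_id": "org/my-llm", "quantizations": ["none", "8-bit"]}
    },
}
for flat in fake_flatten_quantizations(spec):
    print(flat["quantization"], flat["model_id"])
# -> none org/my-llm
# -> 8-bit org/my-llm
```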

```diff
@@ -231,7 +194,9 @@ def register_builtin_model():
 
     for model_data in models_to_register:
         try:
-            converted_data = convert_model_json_format(model_data)
+            from ..utils import flatten_model_src
+            flattened_list = flatten_model_src(model_data)
+            converted_data = flattened_list[0] if flattened_list else model_data
             builtin_llm_family = LLMFamilyV2.parse_obj(converted_data)
 
             if builtin_llm_family.model_name not in existing_model_names:
```
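
Design note: `converted_data = flattened_list[0]` pins registration to the first record `flatten_model_src` returns; if the helper yields one record per hub listed under `model_src`, any additional hub variants are dropped at this point.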

xinference/model/utils.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -603,6 +603,19 @@ def flatten_quantizations(input_json: dict):
             if key != "quantizations":
                 record[key] = value
 
+        # Add required defaults for ggufv2 format if model_file_name_template is missing
+        if "model_format" in record and record["model_format"] == "ggufv2":
+            if "model_file_name_template" not in record:
+                # Generate default template from model_id
+                model_id = record.get("model_id", "")
+                if model_id:
+                    # Extract model name from model_id (last part after /)
+                    model_name = model_id.split("/")[-1]
+                    # Remove potential suffixes
+                    if "-GGUF" in model_name:
+                        model_name = model_name.replace("-GGUF", "")
+                    record["model_file_name_template"] = f"{model_name.lower()}-{{quantization}}.gguf"
+
         flattened.append(record)
     return flattened
```
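
A worked example of the new ggufv2 default: the block above takes the last path segment of `model_id`, strips a `-GGUF` suffix, lowercases it, and appends the `{quantization}` placeholder. A standalone sketch of that logic (the repo id below is hypothetical):

```python
def default_gguf_template(model_id: str) -> str:
    # Mirrors the fallback added above for ggufv2 specs that lack
    # model_file_name_template.
    model_name = model_id.split("/")[-1]
    if "-GGUF" in model_name:
        model_name = model_name.replace("-GGUF", "")
    return f"{model_name.lower()}-{{quantization}}.gguf"

print(default_gguf_template("org/My-LLM-7B-GGUF"))
# -> my-llm-7b-{quantization}.gguf
```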
