@@ -719,9 +719,38 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
719719 logger .error (f"Invalid model name format: { model_name } " )
720720 raise ValueError (f"Invalid model name format: { model_name } " )
721721
722- # Convert model hub JSON format to Xinference expected format
722+ # Convert model hub JSON format to Xinference expected format using flatten_model_src
723723 try :
724- converted_model_json = self ._convert_model_json_format (model_json )
724+ from ..model .utils import flatten_model_src
725+
726+ # Check different model format types
727+ if "model_src" in model_json :
728+ # Simple flat format with model_src at top level
729+ flattened_list = flatten_model_src (model_json )
730+ converted_model_json = flattened_list [0 ] if flattened_list else model_json
731+ elif "model_specs" in model_json and isinstance (model_json ["model_specs" ], list ):
732+ # LLM/embedding/rerank format with model_specs
733+ from ..model .utils import flatten_quantizations
734+ converted_model_json = model_json .copy ()
735+
736+ # Process all model_specs using flatten_quantizations - exactly like builtin models
737+ flattened_specs = []
738+ for spec in converted_model_json ["model_specs" ]:
739+ if "model_src" in spec :
740+ # Use flatten_quantizations like builtin LLM loading
741+ quantized_specs = flatten_quantizations (spec )
742+ if quantized_specs :
743+ flattened_specs .extend (quantized_specs )
744+ else :
745+ flattened_specs .append (spec )
746+
747+ # Use all flattened specs like builtin models
748+ if flattened_specs :
749+ converted_model_json ["model_specs" ] = flattened_specs
750+ logger .info (f"Processed { len (flattened_specs )} model specifications for { model_name } " )
751+ else :
752+ # Already flattened format, use as-is
753+ converted_model_json = model_json
725754 except Exception as e :
726755 logger .error (f"Format conversion failed: { str (e )} " , exc_info = True )
727756 raise ValueError (f"Failed to convert model JSON format: { str (e )} " )
@@ -828,187 +857,7 @@ async def add_model(self, model_type: str, model_json: Dict[str, Any]):
828857 f"Failed to register model '{ model_spec .model_name } ': { str (e )} "
829858 )
830859
831- def _convert_model_json_format (self , model_json : Dict [str , Any ]) -> Dict [str , Any ]:
832- """
833- Convert model hub JSON format to Xinference expected format.
834-
835- The input format uses nested 'model_src' structure, but Xinference expects
836- flattened fields at the spec level.
837-
838- For LLM/embedding/rerank models: uses model_specs structure
839- For image/audio/video models: uses flat structure with direct fields
840- """
841- # Determine if this is an image/audio/video model (flat structure) or LLM/embedding/rerank (model_specs structure)
842- flat_model_types = ["image" , "audio" , "video" ]
843- model_type = None
844-
845- # Try to determine model type from context or model_ability
846- if "model_ability" in model_json :
847- abilities = model_json ["model_ability" ]
848- if isinstance (abilities , list ):
849- if (
850- "text2img" in abilities
851- or "image2image" in abilities
852- or "ocr" in abilities
853- ):
854- model_type = "image"
855- elif "auto-speech" in abilities or "text-to-speech" in abilities :
856- model_type = "audio"
857- elif "text-to-video" in abilities :
858- model_type = "video"
859-
860- if model_type in flat_model_types :
861- # Handle image/audio/video models with flat structure
862-
863- if "model_src" in model_json :
864- model_src = model_json ["model_src" ]
865-
866- # Extract fields from model_src to top level
867- if "huggingface" in model_src :
868- hf_data = model_src ["huggingface" ]
869-
870- if "model_id" in hf_data and model_json .get ("model_id" ) is None :
871- model_json ["model_id" ] = hf_data ["model_id" ]
872-
873- if (
874- "model_revision" in hf_data
875- and model_json .get ("model_revision" ) is None
876- ):
877- model_json ["model_revision" ] = hf_data ["model_revision" ]
878-
879- # Remove model_src field as it's not needed in the final format
880- del model_json ["model_src" ]
881-
882- # Set required defaults for image models
883- if model_json .get ("model_hub" ) is None :
884- model_json ["model_hub" ] = "huggingface"
885-
886- # Add null fields for completeness based on builtin image model structure
887- null_fields = [
888- "cache_config" ,
889- "controlnet" ,
890- "gguf_model_id" ,
891- "gguf_quantizations" ,
892- "gguf_model_file_name_template" ,
893- "lightning_model_id" ,
894- "lightning_versions" ,
895- "lightning_model_file_name_template" ,
896- "model_uri" ,
897- ]
898- for field in null_fields :
899- if field not in model_json :
900- model_json [field ] = None
901-
902- # Add empty dict fields if missing
903- dict_fields = ["default_model_config" , "default_generate_config" ]
904- for field in dict_fields :
905- if field not in model_json :
906- model_json [field ] = {}
907-
908- else :
909- # Handle LLM/embedding/rerank models with model_specs structure
910-
911- # Handle model_specs - if multiple formats are provided, select the first one
912- if "model_specs" in model_json and isinstance (
913- model_json ["model_specs" ], list
914- ):
915- if len (model_json ["model_specs" ]) > 1 :
916- # For add_model, we'll use the first spec as the primary one
917- # The other specs will be ignored for this registration
918- model_json ["model_specs" ] = [model_json ["model_specs" ][0 ]]
919-
920- # Fix missing quantization field for pytorch/mlx specs
921- spec = model_json ["model_specs" ][0 ]
922- if "quantization" not in spec :
923- model_format = spec .get ("model_format" , "" )
924- if model_format in ["pytorch" , "gptq" , "awq" , "fp8" , "bnb" ]:
925- # Extract quantization from model_src if available
926- if "model_src" in spec and "huggingface" in spec ["model_src" ]:
927- quantizations = spec ["model_src" ]["huggingface" ].get (
928- "quantizations" , []
929- )
930- if quantizations :
931- spec ["quantization" ] = quantizations [
932- 0
933- ] # Use first quantization
934- else :
935- spec ["quantization" ] = "none" # Default quantization
936- else :
937- spec ["quantization" ] = "none" # Default quantization
938- elif model_format == "mlx" :
939- # Extract quantization from model_src if available
940- if "model_src" in spec and "huggingface" in spec ["model_src" ]:
941- quantizations = spec ["model_src" ]["huggingface" ].get (
942- "quantizations" , []
943- )
944- if quantizations :
945- spec ["quantization" ] = quantizations [
946- 0
947- ] # Use first quantization
948- else :
949- spec ["quantization" ] = "4bit" # Default for MLX
950- else :
951- spec ["quantization" ] = "4bit" # Default for MLX
952- elif model_format == "ggufv2" :
953- # GGUF models need to extract quantization from filename template
954- if "model_file_name_template" in spec :
955- template = spec ["model_file_name_template" ]
956- if "{quantization}" not in template :
957- # Try to extract from model_id or set default
958- spec ["quantization" ] = (
959- "Q4_K_M" # Common GGUF quantization
960- )
961- else :
962- spec ["quantization" ] = "Q4_K_M" # Default for GGUF
963-
964- # Add missing required fields for LLM-style specs
965- if "model_hub" not in spec :
966- spec ["model_hub" ] = "huggingface"
967-
968- if "model_id" not in spec :
969- if "model_src" in spec and "huggingface" in spec ["model_src" ]:
970- spec ["model_id" ] = spec ["model_src" ]["huggingface" ]["model_id" ]
971-
972- if "model_revision" not in spec :
973- if "model_src" in spec and "huggingface" in spec ["model_src" ]:
974- spec ["model_revision" ] = spec ["model_src" ]["huggingface" ][
975- "model_revision"
976- ]
977-
978- # Remove model_src from spec as it's not needed in the final format
979- if "model_src" in spec :
980- del spec ["model_src" ]
981-
982- # Handle legacy top-level model_src for backward compatibility
983- if model_json .get ("model_id" ) is None and "model_src" in model_json :
984- model_src = model_json ["model_src" ]
985-
986- if "huggingface" in model_src and "model_id" in model_src ["huggingface" ]:
987- model_json ["model_id" ] = model_src ["huggingface" ]["model_id" ]
988- elif "modelscope" in model_src and "model_id" in model_src ["modelscope" ]:
989- model_json ["model_id" ] = model_src ["modelscope" ]["model_id" ]
990-
991- if model_json .get ("model_revision" ) is None :
992- if (
993- "huggingface" in model_src
994- and "model_revision" in model_src ["huggingface" ]
995- ):
996- model_json ["model_revision" ] = model_src ["huggingface" ][
997- "model_revision"
998- ]
999- elif (
1000- "modelscope" in model_src
1001- and "model_revision" in model_src ["modelscope" ]
1002- ):
1003- model_json ["model_revision" ] = model_src ["modelscope" ][
1004- "model_revision"
1005- ]
1006-
1007- # Remove top-level model_src field as it's not needed in the final format
1008- del model_json ["model_src" ]
1009-
1010- return model_json
1011-
860+
1012861 @log_async (logger = logger )
1013862 async def update_model_type (self , model_type : str ):
1014863 """
0 commit comments