diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index 84c7b18d80..61bd97b6e5 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -198,6 +198,15 @@ class RegisterModelRequest(BaseModel):
     persist: bool
 
 
+class AddModelRequest(BaseModel):
+    model_type: str
+    model_json: Dict[str, Any]
+
+
+class UpdateModelRequest(BaseModel):
+    model_type: str
+
+
 class BuildGradioInterfaceRequest(BaseModel):
     model_type: str
     model_name: str
@@ -900,6 +909,26 @@ async def internal_exception_handler(request: Request, exc: Exception):
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/models/add",
+            self.add_model,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:add"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/v1/models/update_type",
+            self.update_model_type,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:add"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/cache/models",
             self.list_cached_models,
@@ -3123,13 +3152,93 @@ async def unregister_model(self, model_type: str, model_name: str) -> JSONRespon
             raise HTTPException(status_code=500, detail=str(e))
         return JSONResponse(content=None)
 
+    async def add_model(self, request: Request) -> JSONResponse:
+        try:
+            # Parse request: accept both the wrapped form
+            # {"model_type": ..., "model_json": ...} and a bare model JSON
+            raw_json = await request.json()
+
+            if "model_type" in raw_json and "model_json" in raw_json:
+                body = AddModelRequest.parse_obj(raw_json)
+                model_type = body.model_type
+                model_json = body.model_json
+            else:
+                model_json = raw_json
+
+                # model_type is required inside the JSON when using the bare format
+                if "model_type" in model_json:
+                    model_type = model_json["model_type"]
+                    logger.info(f"Using explicit model_type from JSON: {model_type}")
+                else:
+                    logger.error("model_type not provided in model JSON")
+                    raise HTTPException(
+                        status_code=400,
+                        detail="model_type is required in the model JSON. "
+                        "Supported types: LLM, embedding, audio, image, video, rerank",
+                    )
+
+            supervisor_ref = await self._get_supervisor_ref()
+
+            # Call supervisor
+            await supervisor_ref.add_model(model_type, model_json)
+
+        except ValueError as ve:
+            logger.error(f"ValueError in add_model API: {ve}", exc_info=True)
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            # exc_info=True already captures the full traceback
+            logger.error(f"Unexpected error in add_model API: {e}", exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
+        return JSONResponse(
+            content={"message": f"Model added successfully for type: {model_type}"}
+        )
+
+    async def update_model_type(self, request: Request) -> JSONResponse:
+        try:
+            # Parse request
+            raw_json = await request.json()
+            body = UpdateModelRequest.parse_obj(raw_json)
+            model_type = body.model_type
+
+            # Get supervisor reference
+            supervisor_ref = await self._get_supervisor_ref()
+
+            await supervisor_ref.update_model_type(model_type)
+
+        except ValueError as ve:
+            logger.error(f"ValueError in update_model_type API: {ve}", exc_info=True)
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(
+                f"Unexpected error in update_model_type API: {e}", exc_info=True
+            )
+            raise HTTPException(status_code=500, detail=str(e))
+
+        return JSONResponse(
+            content={
+                "message": f"Model configurations updated successfully for type: {model_type}"
+            }
+        )
+
     async def list_model_registrations(
         self, model_type: str, detailed: bool = Query(False)
     ) -> JSONResponse:
         try:
             data = await (await self._get_supervisor_ref()).list_model_registrations(
                 model_type, detailed=detailed
             )
+            # Remove duplicate model names.
             model_names = set()
             final_data = []
@@ -3137,11 +3246,20 @@ async def list_model_registrations(
                 if item["model_name"] not in model_names:
                     model_names.add(item["model_name"])
                     final_data.append(item)
+
             return JSONResponse(content=final_data)
         except ValueError as re:
-            logger.error(re, exc_info=True)
+            logger.error(
+                f"ValueError in list_model_registrations: {re}",
+                exc_info=True,
+            )
             raise HTTPException(status_code=400, detail=str(re))
         except Exception as e:
-            logger.error(e, exc_info=True)
+            logger.error(
+                f"Unexpected error in list_model_registrations: {e}",
+                exc_info=True,
+            )
             raise HTTPException(status_code=500, detail=str(e))
diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py
index 1ed96cd703..7afed56a1a 100644
--- a/xinference/core/supervisor.py
+++ b/xinference/core/supervisor.py
@@ -217,6 +217,13 @@ async def __post_create__(self):
             register_rerank,
             unregister_rerank,
         )
+        from ..model.video import (
+            CustomVideoModelFamilyV2,
+            generate_video_description,
+            get_video_model_descriptions,
+            register_video,
+            unregister_video,
+        )
 
         self._custom_register_type_to_cls: Dict[str, Tuple] = {  # type: ignore
             "LLM": (
@@ -249,6 +256,12 @@ async def __post_create__(self):
                 unregister_audio,
                 generate_audio_description,
             ),
+            "video": (
+                CustomVideoModelFamilyV2,
+                register_video,
+                unregister_video,
+                generate_video_description,
+            ),
             "flexible": (
                 FlexibleModelSpec,
                 register_flexible_model,
@@ -264,6 +277,7 @@ async def __post_create__(self):
         model_version_infos.update(get_rerank_model_descriptions())
         model_version_infos.update(get_image_model_descriptions())
         model_version_infos.update(get_audio_model_descriptions())
+        model_version_infos.update(get_video_model_descriptions())
         model_version_infos.update(get_flexible_model_descriptions())
         await self._cache_tracker_ref.record_model_version(
             model_version_infos, self.address
@@ -427,6 +441,51 @@ def _get_spec_dicts(
         )
         return specs, list(download_hubs)
 
+    def _is_model_from_builtin_dir(self, model_name: str, model_type: str) -> bool:
+        """
+        Check if a model comes from the builtin directory (added via
+        update_model_type) or from the custom directory (true user custom models).
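+
+        A rough usage sketch (model name and type here are hypothetical):
+
+            # True when <XINFERENCE_MODEL_DIR>/v2/builtin/llm/my-model.json exists,
+            # or when "my-model" is listed in llm_models.json in that directory.
+            if self._is_model_from_builtin_dir("my-model", "llm"):
+                ...  # treat as editor-defined builtin rather than user-defined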
+ """ + import os + + from xinference.constants import XINFERENCE_MODEL_DIR + + # Check builtin directory (update_model_type models) + builtin_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", model_type.lower() + ) + builtin_file = os.path.join(builtin_dir, f"{model_name}.json") + + if os.path.exists(builtin_file): + return True + + # Also check unified JSON file for models added via update_model_type + unified_json = os.path.join(builtin_dir, f"{model_type.lower()}_models.json") + if os.path.exists(unified_json): + import json + + try: + with open(unified_json, "r", encoding="utf-8") as f: + data = json.load(f) + + # Check if model_name exists in this JSON file + if isinstance(data, list): + return any(model.get("model_name") == model_name for model in data) + elif isinstance(data, dict): + if data.get("model_name") == model_name: + return True + else: + # Check dict values + return any( + isinstance(value, dict) + and value.get("model_name") == model_name + for value in data.values() + ) + except Exception: + pass + + return False + async def _to_llm_reg( self, llm_family: "LLMFamilyV2", is_builtin: bool ) -> Dict[str, Any]: @@ -613,29 +672,66 @@ def sort_helper(item): if not self.is_local_deployment(): workers = list(self._worker_address_to_worker.values()) for worker in workers: - ret.extend(await worker.list_model_registrations(model_type, detailed)) + worker_data = await worker.list_model_registrations( + model_type, detailed + ) + ret.extend(worker_data) - if model_type == "LLM": - from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families + if model_type.upper() == "LLM": + from ..model.llm import ( + BUILTIN_LLM_FAMILIES, + get_registered_llm_families, + register_builtin_model, + ) + register_builtin_model() + + # 1. Hardcoded built-in models for family in BUILTIN_LLM_FAMILIES: if detailed: - ret.append(await self._to_llm_reg(family, True)) + reg_data = await self._to_llm_reg(family, True) + ret.append(reg_data) else: ret.append({"model_name": family.model_name, "is_builtin": True}) - for family in get_user_defined_llm_families(): + # 2. Registered models (user-defined + editor-defined) + registered_families = get_registered_llm_families() + builtin_names = {family.model_name for family in BUILTIN_LLM_FAMILIES} + + for family in registered_families: + # If model is not in hardcoded list, it might be editor-defined, need to check source + if family.model_name not in builtin_names: + # Check if it comes from builtin directory (added via update_model_type) + if self._is_model_from_builtin_dir(family.model_name, "llm"): + # This is an editor-defined model, should be marked as builtin=True + if detailed: + reg_data = await self._to_llm_reg(family, True) + ret.append(reg_data) + else: + ret.append( + {"model_name": family.model_name, "is_builtin": True} + ) + continue + + # True user-defined model, mark as builtin=False if detailed: - ret.append(await self._to_llm_reg(family, False)) + reg_data = await self._to_llm_reg(family, False) + ret.append(reg_data) else: ret.append({"model_name": family.model_name, "is_builtin": False}) - ret.sort(key=sort_helper) + ret.sort(key=sort_helper) return ret elif model_type == "embedding": - from ..model.embedding import BUILTIN_EMBEDDING_MODELS - from ..model.embedding.custom import get_user_defined_embeddings + from ..model.embedding import ( + BUILTIN_EMBEDDING_MODELS, + register_builtin_model, + ) + from ..model.embedding.custom import get_registered_embeddings + + register_builtin_model() + # 1. 
Hardcoded built-in models for model_name, family in BUILTIN_EMBEDDING_MODELS.items(): if detailed: ret.append( @@ -644,7 +740,34 @@ def sort_helper(item): else: ret.append({"model_name": model_name, "is_builtin": True}) - for model_spec in get_user_defined_embeddings(): + # 2. Registered models (user-defined + editor-defined) + registered_models = get_registered_embeddings() + builtin_names = set(BUILTIN_EMBEDDING_MODELS.keys()) + + for model_spec in registered_models: + # If model is not in hardcoded list, it might be editor-defined, need to check source + if model_spec.model_name not in builtin_names: + # Check if it comes from builtin directory (added via update_model_type) + if self._is_model_from_builtin_dir( + model_spec.model_name, "embedding" + ): + # This is an editor-defined model, should be marked as builtin=True + if detailed: + ret.append( + await self._to_embedding_model_reg( + model_spec, is_builtin=True + ) + ) + else: + ret.append( + { + "model_name": model_spec.model_name, + "is_builtin": True, + } + ) + continue + + # True user-defined model, mark as builtin=False if detailed: ret.append( await self._to_embedding_model_reg(model_spec, is_builtin=False) @@ -657,9 +780,12 @@ def sort_helper(item): ret.sort(key=sort_helper) return ret elif model_type == "image": - from ..model.image import BUILTIN_IMAGE_MODELS - from ..model.image.custom import get_user_defined_images + from ..model.image import BUILTIN_IMAGE_MODELS, register_builtin_model + from ..model.image.custom import get_registered_images + + register_builtin_model() + # 1. Hardcoded built-in models for model_name, families in BUILTIN_IMAGE_MODELS.items(): if detailed: family = [x for x in families if x.model_hub == "huggingface"][0] @@ -669,7 +795,32 @@ def sort_helper(item): else: ret.append({"model_name": model_name, "is_builtin": True}) - for model_spec in get_user_defined_images(): + # 2. Registered models (user-defined + editor-defined) + registered_models = get_registered_images() + builtin_names = set(BUILTIN_IMAGE_MODELS.keys()) + + for model_spec in registered_models: + # If model is not in hardcoded list, it might be editor-defined, need to check source + if model_spec.model_name not in builtin_names: + # Check if it comes from builtin directory (added via update_model_type) + if self._is_model_from_builtin_dir(model_spec.model_name, "image"): + # This is an editor-defined model, should be marked as builtin=True + if detailed: + ret.append( + await self._to_image_model_reg( + model_spec, is_builtin=True + ) + ) + else: + ret.append( + { + "model_name": model_spec.model_name, + "is_builtin": True, + } + ) + continue + + # True user-defined model, mark as builtin=False if detailed: ret.append( await self._to_image_model_reg(model_spec, is_builtin=False) @@ -682,9 +833,12 @@ def sort_helper(item): ret.sort(key=sort_helper) return ret elif model_type == "audio": - from ..model.audio import BUILTIN_AUDIO_MODELS - from ..model.audio.custom import get_user_defined_audios + from ..model.audio import BUILTIN_AUDIO_MODELS, register_builtin_model + from ..model.audio.custom import get_registered_audios + + register_builtin_model() + # 1. Hardcoded built-in models for model_name, families in BUILTIN_AUDIO_MODELS.items(): if detailed: family = [x for x in families if x.model_hub == "huggingface"][0] @@ -694,7 +848,32 @@ def sort_helper(item): else: ret.append({"model_name": model_name, "is_builtin": True}) - for model_spec in get_user_defined_audios(): + # 2. 
Registered models (user-defined + editor-defined) + registered_models = get_registered_audios() + builtin_names = set(BUILTIN_AUDIO_MODELS.keys()) + + for model_spec in registered_models: + # If model is not in hardcoded list, it might be editor-defined, need to check source + if model_spec.model_name not in builtin_names: + # Check if it comes from builtin directory (added via update_model_type) + if self._is_model_from_builtin_dir(model_spec.model_name, "audio"): + # This is an editor-defined model, should be marked as builtin=True + if detailed: + ret.append( + await self._to_audio_model_reg( + model_spec, is_builtin=True + ) + ) + else: + ret.append( + { + "model_name": model_spec.model_name, + "is_builtin": True, + } + ) + continue + + # True user-defined model, mark as builtin=False if detailed: ret.append( await self._to_audio_model_reg(model_spec, is_builtin=False) @@ -707,8 +886,12 @@ def sort_helper(item): ret.sort(key=sort_helper) return ret elif model_type == "video": - from ..model.video import BUILTIN_VIDEO_MODELS + from ..model.video import BUILTIN_VIDEO_MODELS, register_builtin_model + from ..model.video.custom import get_registered_videos + + register_builtin_model() + # 1. Hardcoded built-in models for model_name, families in BUILTIN_VIDEO_MODELS.items(): if detailed: family = [x for x in families if x.model_hub == "huggingface"][0] @@ -718,19 +901,81 @@ def sort_helper(item): else: ret.append({"model_name": model_name, "is_builtin": True}) + # 2. Registered models (user-defined + editor-defined) + registered_models = get_registered_videos() + builtin_names = set(BUILTIN_VIDEO_MODELS.keys()) + + for model_spec in registered_models: + # If model is not in hardcoded list, it might be editor-defined, need to check source + if model_spec.model_name not in builtin_names: + # Check if it comes from builtin directory (added via update_model_type) + if self._is_model_from_builtin_dir(model_spec.model_name, "video"): + # This is an editor-defined model, should be marked as builtin=True + if detailed: + ret.append( + await self._to_video_model_reg( + model_spec, is_builtin=True + ) + ) + else: + ret.append( + { + "model_name": model_spec.model_name, + "is_builtin": True, + } + ) + continue + + # True user-defined model, mark as builtin=False + if detailed: + ret.append( + await self._to_video_model_reg(model_spec, is_builtin=False) + ) + else: + ret.append( + {"model_name": model_spec.model_name, "is_builtin": False} + ) ret.sort(key=sort_helper) return ret elif model_type == "rerank": - from ..model.rerank import BUILTIN_RERANK_MODELS - from ..model.rerank.custom import get_user_defined_reranks + from ..model.rerank import BUILTIN_RERANK_MODELS, register_builtin_model + from ..model.rerank.custom import get_registered_reranks + + register_builtin_model() + # 1. Hardcoded built-in models for model_name, family in BUILTIN_RERANK_MODELS.items(): if detailed: ret.append(await self._to_rerank_model_reg(family, is_builtin=True)) else: ret.append({"model_name": model_name, "is_builtin": True}) - for model_spec in get_user_defined_reranks(): + # 2. 
Registered models (user-defined + editor-defined) + registered_models = get_registered_reranks() + builtin_names = set(BUILTIN_RERANK_MODELS.keys()) + + for model_spec in registered_models: + # If model is not in hardcoded list, it might be editor-defined, need to check source + if model_spec.model_name not in builtin_names: + # Check if it comes from builtin directory (added via update_model_type) + if self._is_model_from_builtin_dir(model_spec.model_name, "rerank"): + # This is an editor-defined model, should be marked as builtin=True + if detailed: + ret.append( + await self._to_rerank_model_reg( + model_spec, is_builtin=True + ) + ) + else: + ret.append( + { + "model_name": model_spec.model_name, + "is_builtin": True, + } + ) + continue + + # True user-defined model, mark as builtin=False if detailed: ret.append( await self._to_rerank_model_reg(model_spec, is_builtin=False) @@ -748,13 +993,31 @@ def sort_helper(item): ret = [] for model_spec in get_flexible_models(): + from ..model.cache_manager import CacheManager + + cache_manager = CacheManager(model_spec) + is_persisted_model = False + if hasattr(cache_manager, "_v2_custom_dir_prefix"): + import os + + potential_persist_path = os.path.join( + cache_manager._v2_custom_dir_prefix, + "flexible", + f"{model_spec.model_name}.json", + ) + is_persisted_model = os.path.exists(potential_persist_path) + + is_builtin = is_persisted_model # Treat persisted models as built-in + if detailed: ret.append( - await self._to_flexible_model_reg(model_spec, is_builtin=False) + await self._to_flexible_model_reg( + model_spec, is_builtin=is_builtin + ) ) else: ret.append( - {"model_name": model_spec.model_name, "is_builtin": False} + {"model_name": model_spec.model_name, "is_builtin": is_builtin} ) ret.sort(key=sort_helper) @@ -772,27 +1035,27 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: if f is not None: return f - if model_type == "LLM": - from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families + if model_type.upper() == "LLM": + from ..model.llm import BUILTIN_LLM_FAMILIES, get_registered_llm_families - for f in BUILTIN_LLM_FAMILIES + get_user_defined_llm_families(): + for f in BUILTIN_LLM_FAMILIES + get_registered_llm_families(): if f.model_name == model_name: return f raise ValueError(f"Model {model_name} not found") elif model_type == "embedding": from ..model.embedding import BUILTIN_EMBEDDING_MODELS - from ..model.embedding.custom import get_user_defined_embeddings + from ..model.embedding.custom import get_registered_embeddings for f in ( - list(BUILTIN_EMBEDDING_MODELS.values()) + get_user_defined_embeddings() + list(BUILTIN_EMBEDDING_MODELS.values()) + get_registered_embeddings() ): if f.model_name == model_name: return f raise ValueError(f"Model {model_name} not found") elif model_type == "image": from ..model.image import BUILTIN_IMAGE_MODELS - from ..model.image.custom import get_user_defined_images + from ..model.image.custom import get_registered_images if model_name in BUILTIN_IMAGE_MODELS: return [ @@ -801,13 +1064,13 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: if x.model_hub == "huggingface" ][0] else: - for f in get_user_defined_images(): + for f in get_registered_images(): if f.model_name == model_name: return f raise ValueError(f"Model {model_name} not found") elif model_type == "audio": from ..model.audio import BUILTIN_AUDIO_MODELS - from ..model.audio.custom import get_user_defined_audios + from ..model.audio.custom import 
get_registered_audios if model_name in BUILTIN_AUDIO_MODELS: return [ @@ -816,15 +1079,15 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: if x.model_hub == "huggingface" ][0] else: - for f in get_user_defined_audios(): + for f in get_registered_audios(): if f.model_name == model_name: return f raise ValueError(f"Model {model_name} not found") elif model_type == "rerank": from ..model.rerank import BUILTIN_RERANK_MODELS - from ..model.rerank.custom import get_user_defined_reranks + from ..model.rerank.custom import get_registered_reranks - for f in list(BUILTIN_RERANK_MODELS.values()) + get_user_defined_reranks(): + for f in list(BUILTIN_RERANK_MODELS.values()) + get_registered_reranks(): if f.model_name == model_name: return f raise ValueError(f"Model {model_name} not found") @@ -837,6 +1100,7 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: raise ValueError(f"Model {model_name} not found") elif model_type == "video": from ..model.video import BUILTIN_VIDEO_MODELS + from ..model.video.custom import get_registered_videos if model_name in BUILTIN_VIDEO_MODELS: return [ @@ -844,6 +1108,10 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: for x in BUILTIN_VIDEO_MODELS[model_name] if x.model_hub == "huggingface" ][0] + else: + for f in get_registered_videos(): + if f.model_name == model_name: + return f raise ValueError(f"Model {model_name} not found") else: raise ValueError(f"Unsupported model type: {model_type}") @@ -932,6 +1200,35 @@ async def register_model( else: raise ValueError(f"Unsupported model type: {model_type}") + @log_async(logger=logger) + async def add_model(self, model_type: str, model_json: Dict[str, Any]): + """ + Add a new model by forwarding the request to all workers. + + Args: + model_type: Type of model (LLM, embedding, image, etc.) + model_json: JSON configuration for the model + """ + + try: + # Forward the add_model request to all workers + tasks = [] + for worker_address, worker_ref in self._worker_address_to_worker.items(): + tasks.append(worker_ref.add_model(model_type, model_json)) + + # Wait for all workers to complete the operation + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + else: + logger.warning(f"No workers available to forward add_model request") + + except Exception as e: + logger.error( + f"Error during add_model forwarding: {str(e)}", + exc_info=True, + ) + raise ValueError(f"Failed to add model: {str(e)}") + async def _sync_register_model( self, model_type: str, model: str, persist: bool, model_name: str ): @@ -956,6 +1253,37 @@ async def _sync_register_model( logger.warning(f"finish unregister model: {model} for {name}") raise e + @log_async(logger=logger) + async def update_model_type(self, model_type: str): + """ + Update model configurations for a specific model type by forwarding + the request to all workers. + + Args: + model_type: Type of model (LLM, embedding, image, etc.) 
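+
+        A sketch of the expected call flow, assuming at least one worker is
+        connected (the actor handle name below is illustrative):
+
+            # each worker re-downloads and reloads the builtin JSON for the type
+            await supervisor_ref.update_model_type("embedding")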
+        """
+
+        try:
+            # Forward the update_model_type request to all workers
+            tasks = []
+            for worker_address, worker_ref in self._worker_address_to_worker.items():
+                tasks.append(worker_ref.update_model_type(model_type))
+
+            # Wait for all workers to complete the operation
+            if tasks:
+                await asyncio.gather(*tasks, return_exceptions=True)
+            else:
+                logger.warning(
+                    "No workers available to forward update_model_type request"
+                )
+
+        except Exception as e:
+            logger.error(
+                f"Error during update_model_type forwarding: {str(e)}",
+                exc_info=True,
+            )
+            raise ValueError(f"Failed to update model type: {str(e)}")
+
     @log_async(logger=logger)
     async def unregister_model(self, model_type: str, model_name: str):
         if model_type in self._custom_register_type_to_cls:
diff --git a/xinference/core/worker.py b/xinference/core/worker.py
index 3a211b19e3..d5d41765e5 100644
--- a/xinference/core/worker.py
+++ b/xinference/core/worker.py
@@ -272,6 +272,12 @@ async def __post_create__(self):
             register_rerank,
             unregister_rerank,
         )
+        from ..model.video import (
+            CustomVideoModelFamilyV2,
+            generate_video_description,
+            register_video,
+            unregister_video,
+        )
 
         self._custom_register_type_to_cls: Dict[str, Tuple] = {  # type: ignore
             "LLM": (
@@ -310,6 +316,12 @@ async def __post_create__(self):
                 unregister_flexible_model,
                 generate_flexible_model_description,
             ),
+            "video": (
+                CustomVideoModelFamilyV2,
+                register_video,
+                unregister_video,
+                generate_video_description,
+            ),
         }
 
         logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
@@ -652,6 +664,536 @@ async def unregister_model(self, model_type: str, model_name: str):
         else:
             raise ValueError(f"Unsupported model type: {model_type}")
 
+    @log_async(logger=logger)
+    async def add_model(self, model_type: str, model_json: Dict[str, Any]):
+        """
+        Add a new model by parsing the provided JSON and registering it.
+
+        Args:
+            model_type: Type of model (LLM, embedding, image, etc.)
+            model_json: JSON configuration for the model
+        """
+        # Validate model type (with case normalization)
+        supported_types = list(self._custom_register_type_to_cls.keys())
+        normalized_model_type = model_type
+        if model_type.lower() == "llm" and "LLM" in supported_types:
+            normalized_model_type = "LLM"
+        elif model_type.lower() == "llm" and "llm" in supported_types:
+            normalized_model_type = "llm"
+
+        if normalized_model_type not in self._custom_register_type_to_cls:
+            logger.error(
+                f"Unsupported model type: {normalized_model_type} (original: {model_type})"
+            )
+            raise ValueError(
+                f"Unsupported model type '{model_type}'. "
+                f"Supported types are: {', '.join(supported_types)}"
+            )
+
+        # Use normalized model type for the rest of the function
+        model_type = normalized_model_type
+
+        # Get the appropriate model class and register function
+        (
+            model_spec_cls,
+            register_fn,
+            unregister_fn,
+            generate_fn,
+        ) = self._custom_register_type_to_cls[model_type]
+
+        # Validate required fields (only model_name is required)
+        required_fields = ["model_name"]
+        for field in required_fields:
+            if field not in model_json:
+                logger.error(f"Missing required field: {field}")
+                raise ValueError(f"Missing required field: {field}")
+
+        # Validate model name format
+        from ..model.utils import is_valid_model_name
+
+        model_name = model_json["model_name"]
+        if not is_valid_model_name(model_name):
+            logger.error(f"Invalid model name format: {model_name}")
+            raise ValueError(f"Invalid model name format: {model_name}")
+
+        # Convert model hub JSON format to the Xinference expected format
+        try:
+            converted_model_json = self._convert_model_json_format(model_json)
+        except Exception as e:
+            logger.error(f"Format conversion failed: {str(e)}", exc_info=True)
+            raise ValueError(f"Failed to convert model JSON format: {str(e)}")
+
+        # Parse the JSON into the appropriate model spec
+        try:
+            model_spec = model_spec_cls.parse_obj(converted_model_json)
+        except Exception as e:
+            logger.error(f"Model spec parsing failed: {str(e)}", exc_info=True)
+            raise ValueError(f"Invalid model JSON format: {str(e)}")
+
+        # Check if the model already exists
+        try:
+            existing_models = await self.list_model_registrations(model_type)
+            existing_model = None
+            for model in existing_models:
+                if model["model_name"] == model_spec.model_name:
+                    existing_model = model
+                    break
+
+            if existing_model is not None:
+                logger.error(f"Model already exists: {model_spec.model_name}")
+                raise ValueError(
+                    f"Model '{model_spec.model_name}' already exists for type '{model_type}'. "
+                    f"Please choose a different model name or remove the existing model first."
+                )
+
+        except ValueError as e:
+            if "not found" in str(e):
+                # Model doesn't exist, we can proceed
+                pass
+            else:
+                # Re-raise validation errors
+                logger.error(f"Validation error during model check: {str(e)}")
+                raise e
+        except Exception as ex:
+            logger.error(
+                f"Unexpected error during model check: {str(ex)}",
+                exc_info=True,
+            )
+            raise ValueError(f"Failed to validate model registration: {str(ex)}")
+
+        try:
+            # Store the model using the same logic as update_model_type for consistency
+            import json
+
+            from ..constants import XINFERENCE_MODEL_DIR
+
+            model_type_lower = model_type.lower()
+            builtin_dir = os.path.join(
+                XINFERENCE_MODEL_DIR, "v2", "builtin", model_type_lower
+            )
+
+            # Ensure directory exists
+            os.makedirs(builtin_dir, exist_ok=True)
+
+            # Save each model as a separate JSON file, following the
+            # CacheManager.register_builtin_model pattern
+            model_dict = model_spec.dict()
+            model_file_path = os.path.join(
+                builtin_dir, f"{model_spec.model_name}.json"
+            )
+
+            # Warn if the file already exists, but continue with registration
+            if os.path.exists(model_file_path):
+                logger.warning(
+                    f"Model {model_spec.model_name} already exists at {model_file_path}"
+                )
+
+            with open(model_file_path, "w", encoding="utf-8") as f:
+                json.dump(model_dict, f, indent=2, ensure_ascii=False)
+
+            # Register in the model registry without persisting to avoid
+            # duplicate storage
+            register_fn(model_spec, persist=False)
+
+            # Record model version
+            version_info = generate_fn(model_spec)
+            await self._cache_tracker_ref.record_model_version(
+                version_info, self.address
+            )
+
+            logger.info(
+                f"Successfully added model '{model_spec.model_name}' (type: {model_type})"
+            )
+
+        except ValueError as e:
+            # Validation errors - no cleanup needed, the model was never registered
+            logger.error(f"ValueError during registration: {str(e)}")
+            raise e
+        except Exception as e:
+            # Unexpected errors - attempt cleanup
+            logger.error(
+                f"Unexpected error during registration: {str(e)}",
+                exc_info=True,
+            )
+            try:
+                unregister_fn(model_spec.model_name, raise_error=False)
+            except Exception as cleanup_error:
+                logger.warning(f"Cleanup failed: {cleanup_error}")
+            raise ValueError(
+                f"Failed to register model '{model_spec.model_name}': {str(e)}"
+            )
+
+    def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Convert model hub JSON format to Xinference expected format.
+
+        The input format uses a nested 'model_src' structure, but Xinference
+        expects flattened fields at the spec level.
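+
+        A minimal sketch of the flattening, with hypothetical values:
+
+            {"model_name": "m", "model_src": {"huggingface": {"model_id": "org/m"}}}
+            # becomes: {"model_name": "m", "model_id": "org/m", "model_hub": "huggingface", ...}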
+ + For LLM/embedding/rerank models: uses model_specs structure + For image/audio/video models: uses flat structure with direct fields + """ + # Determine if this is an image/audio/video model (flat structure) or LLM/embedding/rerank (model_specs structure) + flat_model_types = ["image", "audio", "video"] + model_type = None + + # Try to determine model type from context or model_ability + if "model_ability" in model_json: + abilities = model_json["model_ability"] + if isinstance(abilities, list): + if ( + "text2img" in abilities + or "image2image" in abilities + or "ocr" in abilities + ): + model_type = "image" + elif "auto-speech" in abilities or "text-to-speech" in abilities: + model_type = "audio" + elif "text-to-video" in abilities: + model_type = "video" + + if model_type in flat_model_types: + # Handle image/audio/video models with flat structure + + if "model_src" in model_json: + model_src = model_json["model_src"] + + # Extract fields from model_src to top level + if "huggingface" in model_src: + hf_data = model_src["huggingface"] + + if "model_id" in hf_data and model_json.get("model_id") is None: + model_json["model_id"] = hf_data["model_id"] + + if ( + "model_revision" in hf_data + and model_json.get("model_revision") is None + ): + model_json["model_revision"] = hf_data["model_revision"] + + # Remove model_src field as it's not needed in the final format + del model_json["model_src"] + + # Set required defaults for image models + if model_json.get("model_hub") is None: + model_json["model_hub"] = "huggingface" + + # Add null fields for completeness based on builtin image model structure + null_fields = [ + "cache_config", + "controlnet", + "gguf_model_id", + "gguf_quantizations", + "gguf_model_file_name_template", + "lightning_model_id", + "lightning_versions", + "lightning_model_file_name_template", + "model_uri", + ] + for field in null_fields: + if field not in model_json: + model_json[field] = None + + # Add empty dict fields if missing + dict_fields = ["default_model_config", "default_generate_config"] + for field in dict_fields: + if field not in model_json: + model_json[field] = {} + + else: + # Handle LLM/embedding/rerank models with model_specs structure + + # Handle model_specs - if multiple formats are provided, select the first one + if "model_specs" in model_json and isinstance( + model_json["model_specs"], list + ): + if len(model_json["model_specs"]) > 1: + # For add_model, we'll use the first spec as the primary one + # The other specs will be ignored for this registration + model_json["model_specs"] = [model_json["model_specs"][0]] + + # Fix missing quantization field for pytorch/mlx specs + spec = model_json["model_specs"][0] + if "quantization" not in spec: + model_format = spec.get("model_format", "") + if model_format in ["pytorch", "gptq", "awq", "fp8", "bnb"]: + # Extract quantization from model_src if available + if "model_src" in spec and "huggingface" in spec["model_src"]: + quantizations = spec["model_src"]["huggingface"].get( + "quantizations", [] + ) + if quantizations: + spec["quantization"] = quantizations[ + 0 + ] # Use first quantization + else: + spec["quantization"] = "none" # Default quantization + else: + spec["quantization"] = "none" # Default quantization + elif model_format == "mlx": + # Extract quantization from model_src if available + if "model_src" in spec and "huggingface" in spec["model_src"]: + quantizations = spec["model_src"]["huggingface"].get( + "quantizations", [] + ) + if quantizations: + spec["quantization"] = 
quantizations[ + 0 + ] # Use first quantization + else: + spec["quantization"] = "4bit" # Default for MLX + else: + spec["quantization"] = "4bit" # Default for MLX + elif model_format == "ggufv2": + # GGUF models need to extract quantization from filename template + if "model_file_name_template" in spec: + template = spec["model_file_name_template"] + if "{quantization}" not in template: + # Try to extract from model_id or set default + spec["quantization"] = ( + "Q4_K_M" # Common GGUF quantization + ) + else: + spec["quantization"] = "Q4_K_M" # Default for GGUF + + # Add missing required fields for LLM-style specs + if "model_hub" not in spec: + spec["model_hub"] = "huggingface" + + if "model_id" not in spec: + if "model_src" in spec and "huggingface" in spec["model_src"]: + spec["model_id"] = spec["model_src"]["huggingface"]["model_id"] + + if "model_revision" not in spec: + if "model_src" in spec and "huggingface" in spec["model_src"]: + spec["model_revision"] = spec["model_src"]["huggingface"][ + "model_revision" + ] + + # Remove model_src from spec as it's not needed in the final format + if "model_src" in spec: + del spec["model_src"] + + # Handle legacy top-level model_src for backward compatibility + if model_json.get("model_id") is None and "model_src" in model_json: + model_src = model_json["model_src"] + + if "huggingface" in model_src and "model_id" in model_src["huggingface"]: + model_json["model_id"] = model_src["huggingface"]["model_id"] + elif "modelscope" in model_src and "model_id" in model_src["modelscope"]: + model_json["model_id"] = model_src["modelscope"]["model_id"] + + if model_json.get("model_revision") is None: + if ( + "huggingface" in model_src + and "model_revision" in model_src["huggingface"] + ): + model_json["model_revision"] = model_src["huggingface"][ + "model_revision" + ] + elif ( + "modelscope" in model_src + and "model_revision" in model_src["modelscope"] + ): + model_json["model_revision"] = model_src["modelscope"][ + "model_revision" + ] + + # Remove top-level model_src field as it's not needed in the final format + del model_json["model_src"] + + return model_json + + @log_async(logger=logger) + async def update_model_type(self, model_type: str): + """ + Update model configurations for a specific model type by downloading + the latest JSON from the remote API and storing it locally. + + Args: + model_type: Type of model (LLM, embedding, image, etc.) + """ + import json + + import requests + + supported_types = list(self._custom_register_type_to_cls.keys()) + + normalized_for_validation = model_type + if model_type.lower() == "llm" and "LLM" in supported_types: + normalized_for_validation = "LLM" + elif model_type.lower() == "llm" and "llm" in supported_types: + normalized_for_validation = "llm" + + if normalized_for_validation not in supported_types: + logger.error(f"Unsupported model type: {normalized_for_validation}") + raise ValueError( + f"Unsupported model type '{model_type}'. 
" + f"Supported types are: {', '.join(supported_types)}" + ) + + # Construct the URL to download JSON + url = f"https://model.xinference.io/api/models/download?model_type={model_type.lower()}" + + try: + # Download JSON from remote API + response = requests.get(url, timeout=30) + response.raise_for_status() + + # Parse JSON response + model_data = response.json() + + # Store the JSON data using CacheManager as built-in models + await self._store_complete_model_configurations(model_type, model_data) + + # Dynamically reload built-in models to make them immediately available + try: + if model_type.lower() == "llm": + from ..model.llm import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "embedding": + from ..model.embedding import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "audio": + from ..model.audio import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "image": + from ..model.image import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "rerank": + from ..model.rerank import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "video": + from ..model.video import register_builtin_model + + register_builtin_model() + else: + logger.warning( + f"No dynamic loading available for model type: {model_type}" + ) + except Exception as reload_error: + logger.error( + f"Error reloading built-in models: {reload_error}", + exc_info=True, + ) + # Don't fail the update if reload fails, just log the error + + except requests.exceptions.RequestException as e: + logger.error(f"Network error downloading model configurations: {e}") + raise ValueError(f"Failed to download model configurations: {str(e)}") + except json.JSONDecodeError as e: + logger.error(f"JSON decode error: {e}") + raise ValueError(f"Invalid JSON response from remote API: {str(e)}") + except Exception as e: + logger.error( + f"Unexpected error during model update: {e}", + exc_info=True, + ) + raise ValueError(f"Failed to update model configurations: {str(e)}") + + async def _store_model_configurations(self, model_type: str, model_data): + """ + Store model configurations as separate JSON files (one per model). + This follows the same pattern as CacheManager.register_builtin_model. 
+ + Args: + model_type: Type of model (as provided by user, e.g., "llm") + model_data: JSON data containing model configurations (can be single dict or list) + """ + import json + + from ..constants import XINFERENCE_MODEL_DIR + + try: + # Ensure model_data is a list for consistent processing + if isinstance(model_data, dict): + models_to_store = [model_data] + elif isinstance(model_data, list): + models_to_store = model_data + else: + raise ValueError(f"Invalid model_data type: {type(model_data)}") + + model_type_lower = model_type.lower() + builtin_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", model_type_lower + ) + + # Ensure directory exists + os.makedirs(builtin_dir, exist_ok=True) + + # Store each model as a separate JSON file + for model_dict in models_to_store: + if not isinstance(model_dict, dict): + logger.warning(f"Skipping invalid model data: {model_dict}") + continue + + model_name = model_dict.get("model_name") + if not model_name: + logger.warning(f"Skipping model without model_name: {model_dict}") + continue + + # Create file path using model name (same as CacheManager pattern) + json_file_path = os.path.join(builtin_dir, f"{model_name}.json") + + # Store the model as a separate JSON file + with open(json_file_path, "w", encoding="utf-8") as f: + json.dump(model_dict, f, indent=2, ensure_ascii=False) + + except Exception as e: + logger.error( + f"Error storing model configurations: {str(e)}", + exc_info=True, + ) + raise ValueError(f"Failed to store model configurations: {str(e)}") + + async def _store_complete_model_configurations(self, model_type: str, model_data): + """ + Store complete model configurations as a unified JSON file. + This is used by update_model_type to preserve the original JSON structure. + + Args: + model_type: Type of model (as provided by user, e.g., "llm") + model_data: JSON data containing model configurations (complete array) + """ + import json + + from ..constants import XINFERENCE_MODEL_DIR + + try: + model_type_lower = model_type.lower() + + # Use the unified JSON file path (same as original update_model_type logic) + builtin_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", model_type_lower + ) + json_file_path = os.path.join( + builtin_dir, f"{model_type_lower}_models.json" + ) + + # Ensure directory exists + os.makedirs(builtin_dir, exist_ok=True) + + # Store the complete JSON file (preserving original structure) + with open(json_file_path, "w", encoding="utf-8") as f: + json.dump(model_data, f, indent=2, ensure_ascii=False) + + except Exception as e: + logger.error( + f"Error storing complete model configurations: {str(e)}", + exc_info=True, + ) + raise ValueError(f"Failed to store complete model configurations: {str(e)}") + @log_async(logger=logger) async def list_model_registrations( self, model_type: str, detailed: bool = False @@ -661,41 +1203,41 @@ def sort_helper(item): return item.get("model_name").lower() if model_type == "LLM": - from ..model.llm import get_user_defined_llm_families + from ..model.llm import get_registered_llm_families ret = [] - for family in get_user_defined_llm_families(): + for family in get_registered_llm_families(): ret.append({"model_name": family.model_name, "is_builtin": False}) ret.sort(key=sort_helper) return ret elif model_type == "embedding": - from ..model.embedding.custom import get_user_defined_embeddings + from ..model.embedding.custom import get_registered_embeddings ret = [] - for model_spec in get_user_defined_embeddings(): + for model_spec in 
get_registered_embeddings(): ret.append({"model_name": model_spec.model_name, "is_builtin": False}) ret.sort(key=sort_helper) return ret elif model_type == "image": - from ..model.image.custom import get_user_defined_images + from ..model.image.custom import get_registered_images ret = [] - for model_spec in get_user_defined_images(): + for model_spec in get_registered_images(): ret.append({"model_name": model_spec.model_name, "is_builtin": False}) ret.sort(key=sort_helper) return ret elif model_type == "audio": - from ..model.audio.custom import get_user_defined_audios + from ..model.audio.custom import get_registered_audios ret = [] - for model_spec in get_user_defined_audios(): + for model_spec in get_registered_audios(): ret.append({"model_name": model_spec.model_name, "is_builtin": False}) ret.sort(key=sort_helper) @@ -703,11 +1245,11 @@ def sort_helper(item): elif model_type == "video": return [] elif model_type == "rerank": - from ..model.rerank.custom import get_user_defined_reranks + from ..model.rerank.custom import get_registered_reranks ret = [] - for model_spec in get_user_defined_reranks(): + for model_spec in get_registered_reranks(): ret.append({"model_name": model_spec.model_name, "is_builtin": False}) ret.sort(key=sort_helper) @@ -728,35 +1270,35 @@ def sort_helper(item): @log_sync(logger=logger) async def get_model_registration(self, model_type: str, model_name: str) -> Any: if model_type == "LLM": - from ..model.llm import get_user_defined_llm_families + from ..model.llm import get_registered_llm_families - for f in get_user_defined_llm_families(): + for f in get_registered_llm_families(): if f.model_name == model_name: return f elif model_type == "embedding": - from ..model.embedding.custom import get_user_defined_embeddings + from ..model.embedding.custom import get_registered_embeddings - for f in get_user_defined_embeddings(): + for f in get_registered_embeddings(): if f.model_name == model_name: return f elif model_type == "image": - from ..model.image.custom import get_user_defined_images + from ..model.image.custom import get_registered_images - for f in get_user_defined_images(): + for f in get_registered_images(): if f.model_name == model_name: return f elif model_type == "audio": - from ..model.audio.custom import get_user_defined_audios + from ..model.audio.custom import get_registered_audios - for f in get_user_defined_audios(): + for f in get_registered_audios(): if f.model_name == model_name: return f elif model_type == "video": return None elif model_type == "rerank": - from ..model.rerank.custom import get_user_defined_reranks + from ..model.rerank.custom import get_registered_reranks - for f in get_user_defined_reranks(): + for f in get_registered_reranks(): if f.model_name == model_name: return f return None diff --git a/xinference/model/audio/__init__.py b/xinference/model/audio/__init__.py index 9465771917..775a61aab7 100644 --- a/xinference/model/audio/__init__.py +++ b/xinference/model/audio/__init__.py @@ -14,14 +14,63 @@ import codecs import json +import logging import os import platform import sys import warnings -from typing import Dict, List +from typing import Any, Dict, List from ...constants import XINFERENCE_MODEL_DIR from ..utils import flatten_model_src + +logger = logging.getLogger(__name__) + + +def convert_audio_model_format(model_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert audio model hub JSON format to Xinference expected format. 
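+
+    A minimal before/after sketch (input values are hypothetical):
+
+        {"model_name": "a", "model_id": None,
+         "model_src": {"huggingface": {"model_id": "org/a"}}}
+        # becomes: {"model_name": "a", "model_id": "org/a", "version": 2, ...}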
+ """ + logger.debug( + f"convert_audio_model_format called for: {model_json.get('model_name', 'Unknown')}" + ) + + # Apply conversion logic to handle null model_id and other issues + if model_json.get("model_id") is None and "model_src" in model_json: + model_src = model_json["model_src"] + # Extract model_id from available sources + if "huggingface" in model_src and "model_id" in model_src["huggingface"]: + model_json["model_id"] = model_src["huggingface"]["model_id"] + elif "modelscope" in model_src and "model_id" in model_src["modelscope"]: + model_json["model_id"] = model_src["modelscope"]["model_id"] + + # Extract model_revision if available + if model_json.get("model_revision") is None: + if ( + "huggingface" in model_src + and "model_revision" in model_src["huggingface"] + ): + model_json["model_revision"] = model_src["huggingface"][ + "model_revision" + ] + elif ( + "modelscope" in model_src + and "model_revision" in model_src["modelscope"] + ): + model_json["model_revision"] = model_src["modelscope"]["model_revision"] + + # Ensure required fields for audio models + if "version" not in model_json: + model_json["version"] = 2 + if "model_lang" not in model_json: + model_json["model_lang"] = [ + "en", + "zh", + ] # Audio models often support multiple languages + + return model_json + + from .core import ( AUDIO_MODEL_DESCRIPTIONS, AudioModelFamilyV2, @@ -30,7 +79,7 @@ ) from .custom import ( CustomAudioModelFamilyV2, - get_user_defined_audios, + get_registered_audios, register_audio, unregister_audio, ) @@ -60,6 +109,20 @@ def register_custom_model(): warnings.warn(f"{user_defined_audio_dir}/{f} has error, {e}") +def register_builtin_model(): + from ..utils import load_complete_builtin_models + + # Use unified loading function + loaded_count = load_complete_builtin_models( + model_type="audio", + builtin_registry=BUILTIN_AUDIO_MODELS, + convert_format_func=convert_audio_model_format, + model_class=AudioModelFamilyV2, + ) + + logger.info(f"Successfully loaded {loaded_count} audio models from complete JSON") + + def _need_filter(spec: dict): if (sys.platform != "darwin" or platform.processor() != "arm") and spec.get( "engine", "" @@ -80,7 +143,7 @@ def _install(): register_custom_model() # register model description - for ud_audio in get_user_defined_audios(): + for ud_audio in get_registered_audios(): AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(ud_audio)) diff --git a/xinference/model/audio/core.py b/xinference/model/audio/core.py index e2d147fa38..666bf45c1d 100644 --- a/xinference/model/audio/core.py +++ b/xinference/model/audio/core.py @@ -100,9 +100,9 @@ def match_audio( ) -> AudioModelFamilyV2: from ..utils import download_from_modelscope from . import BUILTIN_AUDIO_MODELS - from .custom import get_user_defined_audios + from .custom import get_registered_audios - for model_spec in get_user_defined_audios(): + for model_spec in get_registered_audios(): if model_spec.model_name == model_name: return model_spec diff --git a/xinference/model/audio/custom.py b/xinference/model/audio/custom.py index 8024078481..b38fc1378e 100644 --- a/xinference/model/audio/custom.py +++ b/xinference/model/audio/custom.py @@ -83,7 +83,11 @@ def __init__(self): self.builtin_models = list(BUILTIN_AUDIO_MODELS.keys()) -def get_user_defined_audios() -> List[CustomAudioModelFamilyV2]: +def get_registered_audios() -> List[CustomAudioModelFamilyV2]: + """ + Get all audio families registered in the registry (both user-defined and editor-defined). + This excludes hardcoded builtin models. 
+ """ from ..custom import RegistryManager registry = RegistryManager.get_registry("audio") diff --git a/xinference/model/cache_manager.py b/xinference/model/cache_manager.py index ae9a9f1bfd..e4b74e2177 100644 --- a/xinference/model/cache_manager.py +++ b/xinference/model/cache_manager.py @@ -16,8 +16,12 @@ def __init__(self, model_family: "CacheableModelSpec"): self._model_family = model_family self._v2_cache_dir_prefix = os.path.join(XINFERENCE_CACHE_DIR, "v2") self._v2_custom_dir_prefix = os.path.join(XINFERENCE_MODEL_DIR, "v2") + self._v2_builtin_dir_prefix = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin" + ) os.makedirs(self._v2_cache_dir_prefix, exist_ok=True) os.makedirs(self._v2_custom_dir_prefix, exist_ok=True) + os.makedirs(self._v2_builtin_dir_prefix, exist_ok=True) self._cache_dir = os.path.join( self._v2_cache_dir_prefix, self._model_family.model_name.replace(".", "_") ) @@ -109,9 +113,21 @@ def cache(self) -> str: return self._cache() def register_custom_model(self, model_type: str): + model_type_dir = model_type.lower() persist_path = os.path.join( self._v2_custom_dir_prefix, - model_type, + model_type_dir, + f"{self._model_family.model_name}.json", + ) + os.makedirs(os.path.dirname(persist_path), exist_ok=True) + with open(persist_path, mode="w") as fd: + fd.write(self._model_family.json()) + + def register_builtin_model(self, model_type: str): + model_type_dir = model_type.lower() + persist_path = os.path.join( + self._v2_builtin_dir_prefix, + model_type_dir, f"{self._model_family.model_name}.json", ) os.makedirs(os.path.dirname(persist_path), exist_ok=True) @@ -119,9 +135,10 @@ def register_custom_model(self, model_type: str): fd.write(self._model_family.json()) def unregister_custom_model(self, model_type: str): + model_type_dir = model_type.lower() persist_path = os.path.join( self._v2_custom_dir_prefix, - model_type, + model_type_dir, f"{self._model_family.model_name}.json", ) if os.path.exists(persist_path): diff --git a/xinference/model/custom.py b/xinference/model/custom.py index f08a09dfea..a1adee9aea 100644 --- a/xinference/model/custom.py +++ b/xinference/model/custom.py @@ -118,6 +118,7 @@ def get_registry(cls, model_type: str) -> ModelRegistry: from .image.custom import ImageModelRegistry from .llm.custom import LLMModelRegistry from .rerank.custom import RerankModelRegistry + from .video.custom import VideoModelRegistry if model_type not in cls._instances: if model_type == "rerank": @@ -126,6 +127,8 @@ def get_registry(cls, model_type: str) -> ModelRegistry: cls._instances[model_type] = ImageModelRegistry() elif model_type == "audio": cls._instances[model_type] = AudioModelRegistry() + elif model_type == "video": + cls._instances[model_type] = VideoModelRegistry() elif model_type == "llm": cls._instances[model_type] = LLMModelRegistry() elif model_type == "flexible": diff --git a/xinference/model/embedding/__init__.py b/xinference/model/embedding/__init__.py index f1e822e112..602781436d 100644 --- a/xinference/model/embedding/__init__.py +++ b/xinference/model/embedding/__init__.py @@ -14,11 +14,54 @@ import codecs import json +import logging import os import warnings from typing import Any, Dict, List from ..utils import flatten_quantizations + +logger = logging.getLogger(__name__) + + +def convert_embedding_model_format(model_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert embedding model hub JSON format to Xinference expected format. 
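+
+    A sketch of the defaults filled in below (input is hypothetical):
+
+        {"model_name": "e"}
+        # becomes: {"model_name": "e", "version": 2, "model_lang": ["en"],
+        #           "model_specs": [{"model_format": "pytorch", ...}]}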
+ """ + logger.debug( + f"convert_embedding_model_format called for: {model_json.get('model_name', 'Unknown')}" + ) + + # Ensure required fields for embedding models + converted = model_json.copy() + + # Add missing required fields based on EmbeddingModelFamilyV2 requirements + if "version" not in converted: + converted["version"] = 2 + if "model_lang" not in converted: + converted["model_lang"] = ["en"] + + # Handle model_specs + if "model_specs" not in converted or not converted["model_specs"]: + converted["model_specs"] = [ + { + "model_format": "pytorch", + "model_size_in_billions": None, + "quantization": "none", + "model_hub": "huggingface", + } + ] + else: + # Ensure each spec has required fields + for spec in converted["model_specs"]: + if "quantization" not in spec: + spec["quantization"] = "none" + if "model_hub" not in spec: + spec["model_hub"] = "huggingface" + + return converted + + from .core import ( EMBEDDING_MODEL_DESCRIPTIONS, EmbeddingModelFamilyV2, @@ -27,7 +70,7 @@ ) from .custom import ( CustomEmbeddingModelFamilyV2, - get_user_defined_embeddings, + get_registered_embeddings, register_embedding, unregister_embedding, ) @@ -64,6 +107,23 @@ def register_custom_model(): warnings.warn(f"{user_defined_embedding_dir}/{f} has error, {e}") +def register_builtin_model(): + from ..utils import load_complete_builtin_models + from .embed_family import BUILTIN_EMBEDDING_MODELS + + # Use unified loading function + loaded_count = load_complete_builtin_models( + model_type="embedding", + builtin_registry=BUILTIN_EMBEDDING_MODELS, + convert_format_func=convert_embedding_model_format, + model_class=EmbeddingModelFamilyV2, + ) + + logger.info( + f"Successfully loaded {loaded_count} embedding models from complete JSON" + ) + + def check_format_with_engine(model_format, engine): if model_format in ["ggufv2"] and engine not in ["llama.cpp"]: return False @@ -151,7 +211,7 @@ def _install(): register_custom_model() # register model description - for ud_embedding in get_user_defined_embeddings(): + for ud_embedding in get_registered_embeddings(): EMBEDDING_MODEL_DESCRIPTIONS.update( generate_embedding_description(ud_embedding) ) diff --git a/xinference/model/embedding/custom.py b/xinference/model/embedding/custom.py index 180d2f690a..2e889d5e0a 100644 --- a/xinference/model/embedding/custom.py +++ b/xinference/model/embedding/custom.py @@ -69,7 +69,11 @@ def remove_ud_model_files(self, model_family: "CustomEmbeddingModelFamilyV2"): cache_manager.unregister_custom_model(self.model_type) -def get_user_defined_embeddings() -> List[EmbeddingModelFamilyV2]: +def get_registered_embeddings() -> List[EmbeddingModelFamilyV2]: + """ + Get all embedding families registered in the registry (both user-defined and editor-defined). + This excludes hardcoded builtin models. 
+ """ from ..custom import RegistryManager registry = RegistryManager.get_registry("embedding") diff --git a/xinference/model/embedding/embed_family.py b/xinference/model/embedding/embed_family.py index a572d7cb68..60c2682792 100644 --- a/xinference/model/embedding/embed_family.py +++ b/xinference/model/embedding/embed_family.py @@ -37,14 +37,14 @@ def match_embedding( ] = None, ) -> "EmbeddingModelFamilyV2": from ..utils import download_from_modelscope - from .custom import get_user_defined_embeddings + from .custom import get_registered_embeddings target_family = None if model_name in BUILTIN_EMBEDDING_MODELS: target_family = BUILTIN_EMBEDDING_MODELS[model_name] else: - for model_family in get_user_defined_embeddings(): + for model_family in get_registered_embeddings(): if model_name == model_family.model_name: target_family = model_family break diff --git a/xinference/model/flexible/launchers/__init__.py b/xinference/model/flexible/launchers/__init__.py index f8de4cd8d4..09138b5b2a 100644 --- a/xinference/model/flexible/launchers/__init__.py +++ b/xinference/model/flexible/launchers/__init__.py @@ -11,8 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from .image_process_launcher import launcher as image_process -from .modelscope_launcher import launcher as modelscope -from .transformers_launcher import launcher as transformers -from .yolo_launcher import launcher as yolo diff --git a/xinference/model/image/__init__.py b/xinference/model/image/__init__.py index 14230ea41c..b6e8321275 100644 --- a/xinference/model/image/__init__.py +++ b/xinference/model/image/__init__.py @@ -14,10 +14,77 @@ import codecs import json +import logging import os import warnings +from typing import Any, Dict from ..utils import flatten_model_src + +logger = logging.getLogger(__name__) + + +def convert_image_model_format(model_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert image model hub JSON format to Xinference expected format. 
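+
+    A sketch of the fallbacks applied below (input is hypothetical):
+
+        {"model_name": "i", "model_id": None, "model_revision": None}
+        # becomes: {"model_name": "i", "model_id": "i", "model_revision": "main", ...}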
+ """ + logger.debug( + f"convert_image_model_format called for: {model_json.get('model_name', 'Unknown')}" + ) + + # Ensure required fields for image models + converted = model_json.copy() + + # Add missing required fields + if "version" not in converted: + converted["version"] = 2 + if "model_lang" not in converted: + converted["model_lang"] = ["en"] + + # Handle missing model_id and model_revision + if converted.get("model_id") is None and "model_src" in converted: + model_src = converted["model_src"] + # Extract model_id from available sources + if "huggingface" in model_src and "model_id" in model_src["huggingface"]: + converted["model_id"] = model_src["huggingface"]["model_id"] + elif "modelscope" in model_src and "model_id" in model_src["modelscope"]: + converted["model_id"] = model_src["modelscope"]["model_id"] + + if converted.get("model_revision") is None and "model_src" in converted: + model_src = converted["model_src"] + # Extract model_revision if available + if "huggingface" in model_src and "model_revision" in model_src["huggingface"]: + converted["model_revision"] = model_src["huggingface"]["model_revision"] + elif "modelscope" in model_src and "model_revision" in model_src["modelscope"]: + converted["model_revision"] = model_src["modelscope"]["model_revision"] + + # Set defaults if still missing + if converted.get("model_id") is None: + converted["model_id"] = converted.get("model_name", "unknown") + if converted.get("model_revision") is None: + converted["model_revision"] = "main" + + # Handle model_specs + if "model_specs" not in converted or not converted["model_specs"]: + converted["model_specs"] = [ + { + "model_format": "pytorch", + "model_size_in_billions": None, + "quantization": "none", + "model_hub": "huggingface", + } + ] + else: + # Ensure each spec has required fields + for spec in converted["model_specs"]: + if "quantization" not in spec: + spec["quantization"] = "none" + if "model_hub" not in spec: + spec["model_hub"] = "huggingface" + + return converted + + from .core import ( BUILTIN_IMAGE_MODELS, IMAGE_MODEL_DESCRIPTIONS, @@ -27,7 +94,7 @@ ) from .custom import ( CustomImageModelFamilyV2, - get_user_defined_images, + get_registered_images, register_image, unregister_image, ) @@ -55,9 +122,117 @@ def register_custom_model(): warnings.warn(f"{user_defined_image_dir}/{f} has error, {e}") +def register_builtin_model(): + import json + + from ...constants import XINFERENCE_MODEL_DIR + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("image") + existing_model_names = {spec.model_name for spec in registry.get_custom_models()} + + builtin_image_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "image") + if os.path.isdir(builtin_image_dir): + # First, try to load from the complete JSON file + complete_json_path = os.path.join(builtin_image_dir, "image_models.json") + if os.path.exists(complete_json_path): + try: + with codecs.open(complete_json_path, encoding="utf-8") as fd: + model_data = json.load(fd) + + # Handle different formats + models_to_register = [] + if isinstance(model_data, list): + # Multiple models in a list + models_to_register = model_data + elif isinstance(model_data, dict): + # Single model + if "model_name" in model_data: + models_to_register = [model_data] + else: + # Models dict - extract models + for key, value in model_data.items(): + if isinstance(value, dict) and "model_name" in value: + models_to_register.append(value) + + # Register all models from the complete JSON + for model_data in 
models_to_register: + try: + # Convert format if needed + converted_data = convert_image_model_format(model_data) + builtin_image_family = ImageModelFamilyV2.parse_obj( + converted_data + ) + + # Only register if model doesn't already exist + if builtin_image_family.model_name not in existing_model_names: + # Add to BUILTIN_IMAGE_MODELS directly for proper builtin registration + if ( + builtin_image_family.model_name + not in BUILTIN_IMAGE_MODELS + ): + BUILTIN_IMAGE_MODELS[ + builtin_image_family.model_name + ] = [] + BUILTIN_IMAGE_MODELS[ + builtin_image_family.model_name + ].append(builtin_image_family) + # Update model descriptions for the new builtin model + IMAGE_MODEL_DESCRIPTIONS.update( + generate_image_description(builtin_image_family) + ) + existing_model_names.add(builtin_image_family.model_name) + except Exception as e: + warnings.warn( + f"Error parsing image model {model_data.get('model_name', 'Unknown')}: {e}" + ) + + logger.info( + f"Successfully registered {len(models_to_register)} image models from complete JSON" + ) + + except Exception as e: + warnings.warn( + f"Error loading complete JSON file {complete_json_path}: {e}" + ) + # Fall back to individual files if complete JSON loading fails + + # Fall back: load individual JSON files (backward compatibility) + individual_files = [ + f + for f in os.listdir(builtin_image_dir) + if f.endswith(".json") and f != "image_models.json" + ] + for f in individual_files: + try: + with codecs.open( + os.path.join(builtin_image_dir, f), encoding="utf-8" + ) as fd: + builtin_image_family = ImageModelFamilyV2.parse_obj(json.load(fd)) + + # Only register if model doesn't already exist + if builtin_image_family.model_name not in existing_model_names: + # Add to BUILTIN_IMAGE_MODELS directly for proper builtin registration + if builtin_image_family.model_name not in BUILTIN_IMAGE_MODELS: + BUILTIN_IMAGE_MODELS[builtin_image_family.model_name] = [] + BUILTIN_IMAGE_MODELS[builtin_image_family.model_name].append( + builtin_image_family + ) + # Update model descriptions for the new builtin model + IMAGE_MODEL_DESCRIPTIONS.update( + generate_image_description(builtin_image_family) + ) + existing_model_names.add(builtin_image_family.model_name) + except Exception as e: + warnings.warn(f"{builtin_image_dir}/{f} has error, {e}") + + def _install(): load_model_family_from_json("model_spec.json", BUILTIN_IMAGE_MODELS) + # Load models from complete JSON file (from update_model_type) + register_builtin_model() + # register model description for model_name, model_specs in BUILTIN_IMAGE_MODELS.items(): model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0] @@ -65,7 +240,7 @@ def _install(): register_custom_model() - for ud_image in get_user_defined_images(): + for ud_image in get_registered_images(): IMAGE_MODEL_DESCRIPTIONS.update(generate_image_description(ud_image)) diff --git a/xinference/model/image/core.py b/xinference/model/image/core.py index b4baa09bcd..46be16945f 100644 --- a/xinference/model/image/core.py +++ b/xinference/model/image/core.py @@ -121,9 +121,9 @@ def match_diffusion( ) -> ImageModelFamilyV2: from ..utils import download_from_modelscope from . 
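To make the `model_src` handling above concrete, a small sketch (the hub entry is invented for illustration):

    raw = {
        "model_name": "my-image-model",
        "model_id": None,
        "model_src": {"modelscope": {"model_id": "org/my-image-model"}},
    }

    converted = convert_image_model_format(raw)
    # model_id falls back to the modelscope source; with no revision anywhere,
    # model_revision defaults to "main".
    assert converted["model_id"] == "org/my-image-model"
    assert converted["model_revision"] == "main"
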
diff --git a/xinference/model/image/core.py b/xinference/model/image/core.py
index b4baa09bcd..46be16945f 100644
--- a/xinference/model/image/core.py
+++ b/xinference/model/image/core.py
@@ -121,9 +121,9 @@ def match_diffusion(
 ) -> ImageModelFamilyV2:
     from ..utils import download_from_modelscope
     from . import BUILTIN_IMAGE_MODELS
-    from .custom import get_user_defined_images
+    from .custom import get_registered_images
 
-    for model_spec in get_user_defined_images():
+    for model_spec in get_registered_images():
         if model_spec.model_name == model_name:
             return model_spec
diff --git a/xinference/model/image/custom.py b/xinference/model/image/custom.py
index 3e3e2a81b9..a8c75433b4 100644
--- a/xinference/model/image/custom.py
+++ b/xinference/model/image/custom.py
@@ -43,7 +43,11 @@ def __init__(self):
         self.builtin_models = list(BUILTIN_IMAGE_MODELS.keys())
 
 
-def get_user_defined_images() -> List[ImageModelFamilyV2]:
+def get_registered_images() -> List[ImageModelFamilyV2]:
+    """
+    Get all image families registered in the registry (both user-defined and editor-defined).
+    This excludes hardcoded builtin models.
+    """
     from ..custom import RegistryManager
 
     registry = RegistryManager.get_registry("image")
diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py
index a4c4704ce4..425d4e7513 100644
--- a/xinference/model/llm/__init__.py
+++ b/xinference/model/llm/__init__.py
@@ -13,17 +13,79 @@
 # limitations under the License.
 import codecs
 import json
+import logging
 import os
 import warnings
+from typing import Any, Dict
 
 from ..utils import flatten_quantizations
+
+logger = logging.getLogger(__name__)
+
+
+def convert_model_json_format(model_json: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Convert model hub JSON format to Xinference expected format.
+
+    This is a standalone version of the conversion logic from supervisor.py.
+    """
+    logger.debug(
+        f"convert_model_json_format called for: {model_json.get('model_name', 'Unknown')}"
+    )
+
+    # If model_specs is missing, provide a default minimal spec
+    if "model_specs" not in model_json or not model_json["model_specs"]:
+        logger.debug("model_specs missing or empty, creating default spec")
+        return {
+            **model_json,
+            "version": 2,  # Add missing required field
+            "model_lang": ["en"],  # Add missing required field
+            "model_specs": [
+                {
+                    "model_format": "pytorch",
+                    "model_size_in_billions": None,
+                    "quantization": "none",
+                    "model_file_name_template": "model.bin",
+                    "model_hub": "huggingface",
+                }
+            ],
+        }
+
+    converted = model_json.copy()
+    converted_specs = []
+
+    # Ensure required top-level fields
+    if "version" not in converted:
+        converted["version"] = 2
+    if "model_lang" not in converted:
+        converted["model_lang"] = ["en"]
+
+    for spec in model_json["model_specs"]:
+        model_format = spec.get("model_format", "pytorch")
+        model_size = spec.get("model_size_in_billions")
+
+        # Ensure required fields
+        converted_spec = spec.copy()
+        if "quantization" not in converted_spec:
+            converted_spec["quantization"] = "none"
+        if "model_file_name_template" not in converted_spec:
+            converted_spec["model_file_name_template"] = "model.bin"
+        if "model_hub" not in converted_spec:
+            converted_spec["model_hub"] = "huggingface"
+
+        converted_specs.append(converted_spec)
+
+    converted["model_specs"] = converted_specs
+    return converted
+
+
 from .core import (
     LLM,
     LLM_VERSION_INFOS,
     generate_llm_version_info,
     get_llm_version_infos,
 )
-from .custom import get_user_defined_llm_families, register_llm, unregister_llm
+from .custom import get_registered_llm_families, register_llm, unregister_llm
 from .llm_family import (
     BUILTIN_LLM_FAMILIES,
     BUILTIN_LLM_MODEL_CHAT_FAMILIES,
@@ -128,6 +190,61 @@ def register_custom_model():
             warnings.warn(f"{user_defined_llm_dir}/{f} has error, {e}")
 
 
+def register_builtin_model():
+    from ..utils import load_complete_builtin_models
+
+    # Use the unified loading function, but LLM needs special handling
+    loaded_count = load_complete_builtin_models(
+        model_type="llm",
+        builtin_registry={},  # pass a throwaway dict; registration is handled manually below
+        convert_format_func=convert_model_json_format,
+        model_class=LLMFamilyV2,
+    )
+
+    # Manually handle LLM's special registration logic
+    if loaded_count > 0:
+        from ...constants import XINFERENCE_MODEL_DIR
+        from ..custom import RegistryManager
+
+        registry = RegistryManager.get_registry("llm")
+        existing_model_names = {
+            spec.model_name for spec in registry.get_custom_models()
+        }
+
+        builtin_llm_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "llm")
+        complete_json_path = os.path.join(builtin_llm_dir, "llm_models.json")
+
+        if os.path.exists(complete_json_path):
+            with codecs.open(complete_json_path, encoding="utf-8") as fd:
+                model_data = json.load(fd)
+
+            models_to_register = []
+            if isinstance(model_data, list):
+                models_to_register = model_data
+            elif isinstance(model_data, dict):
+                if "model_name" in model_data:
+                    models_to_register = [model_data]
+                else:
+                    for key, value in model_data.items():
+                        if isinstance(value, dict) and "model_name" in value:
+                            models_to_register.append(value)
+
+            for model_data in models_to_register:
+                try:
+                    converted_data = convert_model_json_format(model_data)
+                    builtin_llm_family = LLMFamilyV2.parse_obj(converted_data)
+
+                    if builtin_llm_family.model_name not in existing_model_names:
+                        register_llm(builtin_llm_family, persist=False)
+                        existing_model_names.add(builtin_llm_family.model_name)
+                except Exception as e:
+                    warnings.warn(
+                        f"Error parsing model {model_data.get('model_name', 'Unknown')}: {e}"
+                    )
+
+    logger.info(f"Successfully loaded {loaded_count} llm models from complete JSON")
+
+
 def load_model_family_from_json(json_filename, target_families):
     json_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), json_filename)
     for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")):
@@ -210,5 +327,5 @@ def _install():
     register_custom_model()
 
     # register model description
-    for ud_llm in get_user_defined_llm_families():
+    for ud_llm in get_registered_llm_families():
         LLM_VERSION_INFOS.update(generate_llm_version_info(ud_llm))
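Note the asymmetry with the other modalities: for LLMs, `load_complete_builtin_models` is invoked only to validate and count the entries (it fills a throwaway dict), and the same `llm_models.json` is then parsed a second time so that each family can go through `register_llm`, which routes into the RegistryManager-backed registry rather than a plain module-level dict. Roughly, the special-casing boils down to this (a sketch; `parsed_families` is a hypothetical name for the families parsed from `llm_models.json`):

    for family in parsed_families:
        if family.model_name not in existing_model_names:
            # registry-backed registration instead of appending to a builtin dict
            register_llm(family, persist=False)
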
""" - from .custom import get_user_defined_llm_families + from .custom import get_registered_llm_families - user_defined_llm_families = get_user_defined_llm_families() + user_defined_llm_families = get_registered_llm_families() def _match_quantization(q: Union[str, None], quant: str): # Currently, the quantization name could include both uppercase and lowercase letters, diff --git a/xinference/model/rerank/__init__.py b/xinference/model/rerank/__init__.py index 36334cb9fc..0ce236fd13 100644 --- a/xinference/model/rerank/__init__.py +++ b/xinference/model/rerank/__init__.py @@ -14,12 +14,55 @@ import codecs import json +import logging import os import warnings from typing import Any, Dict, List from ...constants import XINFERENCE_MODEL_DIR from ..utils import flatten_quantizations + +logger = logging.getLogger(__name__) + + +def convert_rerank_model_format(model_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert rerank model hub JSON format to Xinference expected format. + """ + logger.debug( + f"convert_rerank_model_format called for: {model_json.get('model_name', 'Unknown')}" + ) + + # Ensure required fields for rerank models + converted = model_json.copy() + + # Add missing required fields + if "version" not in converted: + converted["version"] = 2 + if "model_lang" not in converted: + converted["model_lang"] = ["en"] + + # Handle model_specs + if "model_specs" not in converted or not converted["model_specs"]: + converted["model_specs"] = [ + { + "model_format": "pytorch", + "model_size_in_billions": None, + "quantization": "none", + "model_hub": "huggingface", + } + ] + else: + # Ensure each spec has required fields + for spec in converted["model_specs"]: + if "quantization" not in spec: + spec["quantization"] = "none" + if "model_hub" not in spec: + spec["model_hub"] = "huggingface" + + return converted + + from .core import ( RERANK_MODEL_DESCRIPTIONS, RerankModelFamilyV2, @@ -28,7 +71,7 @@ ) from .custom import ( CustomRerankModelFamilyV2, - get_user_defined_reranks, + get_registered_reranks, register_rerank, unregister_rerank, ) @@ -63,6 +106,20 @@ def register_custom_model(): warnings.warn(f"{user_defined_rerank_dir}/{f} has error, {e}") +def register_builtin_model(): + from ..utils import load_complete_builtin_models + + # Use unified loading function + loaded_count = load_complete_builtin_models( + model_type="rerank", + builtin_registry=BUILTIN_RERANK_MODELS, + convert_format_func=convert_rerank_model_format, + model_class=RerankModelFamilyV2, + ) + + logger.info(f"Successfully loaded {loaded_count} rerank models from complete JSON") + + def generate_engine_config_by_model_name(model_family: "RerankModelFamilyV2"): model_name = model_family.model_name engines: Dict[str, List[Dict[str, Any]]] = RERANK_ENGINES.get( @@ -127,5 +184,5 @@ def _install(): register_custom_model() # register model description - for ud_rerank in get_user_defined_reranks(): + for ud_rerank in get_registered_reranks(): RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank)) diff --git a/xinference/model/rerank/custom.py b/xinference/model/rerank/custom.py index c09fdd40be..1e22dfaf54 100644 --- a/xinference/model/rerank/custom.py +++ b/xinference/model/rerank/custom.py @@ -67,7 +67,11 @@ def remove_ud_model_files(self, model_family: "CustomRerankModelFamilyV2"): cache_manager.unregister_custom_model(self.model_type) -def get_user_defined_reranks() -> List[CustomRerankModelFamilyV2]: +def get_registered_reranks() -> List[CustomRerankModelFamilyV2]: + """ + Get all rerank families 
registered in the registry (both user-defined and editor-defined). + This excludes hardcoded builtin models. + """ from ..custom import RegistryManager registry = RegistryManager.get_registry("rerank") diff --git a/xinference/model/rerank/rerank_family.py b/xinference/model/rerank/rerank_family.py index 62639d06cf..1cbcc681d9 100644 --- a/xinference/model/rerank/rerank_family.py +++ b/xinference/model/rerank/rerank_family.py @@ -36,14 +36,14 @@ def match_rerank( ] = None, ) -> "RerankModelFamilyV2": from ..utils import download_from_modelscope - from .custom import get_user_defined_reranks + from .custom import get_registered_reranks target_family = None if model_name in BUILTIN_RERANK_MODELS: target_family = BUILTIN_RERANK_MODELS[model_name] else: - for model_family in get_user_defined_reranks(): + for model_family in get_registered_reranks(): if model_name == model_family.model_name: target_family = model_family break diff --git a/xinference/model/utils.py b/xinference/model/utils.py index ea5dec74d5..83163619d3 100644 --- a/xinference/model/utils.py +++ b/xinference/model/utils.py @@ -709,3 +709,92 @@ def _wrapper(self, *args, **kwargs): return _async_wrapper else: return _wrapper + + +def load_complete_builtin_models( + model_type: str, builtin_registry: dict, convert_format_func=None, model_class=None +): + """ + Load complete JSON files for built-in models in a unified way. + + Args: + model_type: Model type (llm, embedding, audio, image, video, rerank) + builtin_registry: Built-in model registry dictionary + convert_format_func: Format conversion function (optional) + model_class: Model class (optional) + + Returns: + int: Number of successfully loaded models + """ + import codecs + import json + import logging + + from ..constants import XINFERENCE_MODEL_DIR + + logger = logging.getLogger(__name__) + + builtin_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", model_type) + complete_json_path = os.path.join(builtin_dir, f"{model_type}_models.json") + + if not os.path.exists(complete_json_path): + logger.debug(f"Complete JSON file not found: {complete_json_path}") + return 0 + + try: + with codecs.open(complete_json_path, encoding="utf-8") as fd: + model_data = json.load(fd) + + models_to_register = [] + if isinstance(model_data, list): + models_to_register = model_data + elif isinstance(model_data, dict): + if "model_name" in model_data: + models_to_register = [model_data] + else: + for key, value in model_data.items(): + if isinstance(value, dict) and "model_name" in value: + models_to_register.append(value) + + loaded_count = 0 + for data in models_to_register: + try: + # Apply format conversion function (if provided) + if convert_format_func: + data = convert_format_func(data) + + # Create model instance (if model class is provided) + if model_class: + model = model_class.parse_obj(data) + model_name = model.model_name + else: + model_name = data.get("model_name", "unknown") + model = data + + # Add to registry based on model type + if model_type in ["audio", "image", "video", "llm"]: + # These model types use list structure: dict[model_name] = [model1, model2, ...] 
+ if model_name not in builtin_registry: + builtin_registry[model_name] = [model] + else: + builtin_registry[model_name].append(model) + else: + # embedding, rerank use single model structure: dict[model_name] = model + builtin_registry[model_name] = model + + loaded_count += 1 + logger.info(f"Loaded {model_type} builtin model: {model_name}") + + except Exception as e: + logger.warning( + f"Failed to load {model_type} model {data.get('model_name', 'Unknown')}: {e}" + ) + + logger.info( + f"Successfully loaded {loaded_count} {model_type} models from complete JSON" + ) + return loaded_count + + except Exception as e: + logger.error(f"Failed to load complete JSON {complete_json_path}: {e}") + return 0 diff --git a/xinference/model/video/__init__.py b/xinference/model/video/__init__.py index 5002fcc039..71e22490a0 100644 --- a/xinference/model/video/__init__.py +++ b/xinference/model/video/__init__.py @@ -14,9 +14,77 @@ import codecs import json +import logging import os +import warnings +from typing import Any, Dict from ..utils import flatten_model_src + +logger = logging.getLogger(__name__) + + +def convert_video_model_format(model_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert video model hub JSON format to Xinference expected format. + """ + logger.debug( + f"convert_video_model_format called for: {model_json.get('model_name', 'Unknown')}" + ) + + # Ensure required fields for video models + converted = model_json.copy() + + # Add missing required fields + if "version" not in converted: + converted["version"] = 2 + if "model_lang" not in converted: + converted["model_lang"] = ["en"] + + # Handle missing model_id and model_revision + if converted.get("model_id") is None and "model_src" in converted: + model_src = converted["model_src"] + # Extract model_id from available sources + if "huggingface" in model_src and "model_id" in model_src["huggingface"]: + converted["model_id"] = model_src["huggingface"]["model_id"] + elif "modelscope" in model_src and "model_id" in model_src["modelscope"]: + converted["model_id"] = model_src["modelscope"]["model_id"] + + if converted.get("model_revision") is None and "model_src" in converted: + model_src = converted["model_src"] + # Extract model_revision if available + if "huggingface" in model_src and "model_revision" in model_src["huggingface"]: + converted["model_revision"] = model_src["huggingface"]["model_revision"] + elif "modelscope" in model_src and "model_revision" in model_src["modelscope"]: + converted["model_revision"] = model_src["modelscope"]["model_revision"] + + # Set defaults if still missing + if converted.get("model_id") is None: + converted["model_id"] = converted.get("model_name", "unknown") + if converted.get("model_revision") is None: + converted["model_revision"] = "main" + + # Handle model_specs + if "model_specs" not in converted or not converted["model_specs"]: + converted["model_specs"] = [ + { + "model_format": "pytorch", + "model_size_in_billions": None, + "quantization": "none", + "model_hub": "huggingface", + } + ] + else: + # Ensure each spec has required fields + for spec in converted["model_specs"]: + if "quantization" not in spec: + spec["quantization"] = "none" + if "model_hub" not in spec: + spec["model_hub"] = "huggingface" + + return converted + + from .core import ( BUILTIN_VIDEO_MODELS, VIDEO_MODEL_DESCRIPTIONS, @@ -24,11 +92,61 @@ generate_video_description, get_video_model_descriptions, ) +from .custom import ( + CustomVideoModelFamilyV2, + get_registered_videos, + register_video, + unregister_video, +) 
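Per the branches above, the `{model_type}_models.json` file this helper consumes may take any of three shapes, all normalized to the same list (field values here are schematic):

    [{"model_name": "a"}, {"model_name": "b"}]                      # a list of model dicts
    {"model_name": "a"}                                             # a single model dict
    {"a": {"model_name": "a"}, "b": {"model_name": "b"}}            # a name-keyed map of model dicts

For the list-structured registries (audio, image, video, llm), each loaded family is appended under its name; the embedding and rerank registries map each name to a single family object.
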
diff --git a/xinference/model/video/__init__.py b/xinference/model/video/__init__.py
index 5002fcc039..71e22490a0 100644
--- a/xinference/model/video/__init__.py
+++ b/xinference/model/video/__init__.py
@@ -14,9 +14,77 @@
 
 import codecs
 import json
+import logging
 import os
+import warnings
+from typing import Any, Dict
 
 from ..utils import flatten_model_src
+
+logger = logging.getLogger(__name__)
+
+
+def convert_video_model_format(model_json: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Convert video model hub JSON format to Xinference expected format.
+    """
+    logger.debug(
+        f"convert_video_model_format called for: {model_json.get('model_name', 'Unknown')}"
+    )
+
+    # Ensure required fields for video models
+    converted = model_json.copy()
+
+    # Add missing required fields
+    if "version" not in converted:
+        converted["version"] = 2
+    if "model_lang" not in converted:
+        converted["model_lang"] = ["en"]
+
+    # Handle missing model_id and model_revision
+    if converted.get("model_id") is None and "model_src" in converted:
+        model_src = converted["model_src"]
+        # Extract model_id from available sources
+        if "huggingface" in model_src and "model_id" in model_src["huggingface"]:
+            converted["model_id"] = model_src["huggingface"]["model_id"]
+        elif "modelscope" in model_src and "model_id" in model_src["modelscope"]:
+            converted["model_id"] = model_src["modelscope"]["model_id"]
+
+    if converted.get("model_revision") is None and "model_src" in converted:
+        model_src = converted["model_src"]
+        # Extract model_revision if available
+        if "huggingface" in model_src and "model_revision" in model_src["huggingface"]:
+            converted["model_revision"] = model_src["huggingface"]["model_revision"]
+        elif "modelscope" in model_src and "model_revision" in model_src["modelscope"]:
+            converted["model_revision"] = model_src["modelscope"]["model_revision"]
+
+    # Set defaults if still missing
+    if converted.get("model_id") is None:
+        converted["model_id"] = converted.get("model_name", "unknown")
+    if converted.get("model_revision") is None:
+        converted["model_revision"] = "main"
+
+    # Handle model_specs
+    if "model_specs" not in converted or not converted["model_specs"]:
+        converted["model_specs"] = [
+            {
+                "model_format": "pytorch",
+                "model_size_in_billions": None,
+                "quantization": "none",
+                "model_hub": "huggingface",
+            }
+        ]
+    else:
+        # Ensure each spec has required fields
+        for spec in converted["model_specs"]:
+            if "quantization" not in spec:
+                spec["quantization"] = "none"
+            if "model_hub" not in spec:
+                spec["model_hub"] = "huggingface"
+
+    return converted
+
+
 from .core import (
     BUILTIN_VIDEO_MODELS,
     VIDEO_MODEL_DESCRIPTIONS,
@@ -24,11 +92,61 @@
     generate_video_description,
     get_video_model_descriptions,
 )
+from .custom import (
+    CustomVideoModelFamilyV2,
+    get_registered_videos,
+    register_video,
+    unregister_video,
+)
+
+
+def register_custom_model():
+    from ...constants import XINFERENCE_MODEL_DIR
+    from ..custom import migrate_from_v1_to_v2
+
+    # migrate from v1 to v2 first
+    migrate_from_v1_to_v2("video", CustomVideoModelFamilyV2)
+
+    user_defined_video_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "video")
+    if os.path.isdir(user_defined_video_dir):
+        for f in os.listdir(user_defined_video_dir):
+            try:
+                with codecs.open(
+                    os.path.join(user_defined_video_dir, f), encoding="utf-8"
+                ) as fd:
+                    user_defined_video_family = CustomVideoModelFamilyV2.parse_obj(
+                        json.load(fd)
+                    )
+                register_video(user_defined_video_family, persist=False)
+            except Exception as e:
+                warnings.warn(f"{user_defined_video_dir}/{f} has error, {e}")
+
+
+def register_builtin_model():
+    """
+    Dynamically load built-in video models from the builtin/video directory.
+    This function is called every time the model list is requested,
+    ensuring real-time updates without a server restart.
+    """
+    from ..utils import load_complete_builtin_models
+
+    # Use unified loading function
+    loaded_count = load_complete_builtin_models(
+        model_type="video",
+        builtin_registry=BUILTIN_VIDEO_MODELS,
+        convert_format_func=convert_video_model_format,
+        model_class=VideoModelFamilyV2,
+    )
+
+    logger.info(f"Successfully loaded {loaded_count} video models from complete JSON")
 
 
 def _install():
     load_model_family_from_json("model_spec.json", BUILTIN_VIDEO_MODELS)
 
+    # Load models from complete JSON file (from update_model_type)
+    register_builtin_model()
+
     # register model description
     for model_name, model_specs in BUILTIN_VIDEO_MODELS.items():
         model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
diff --git a/xinference/model/video/custom.py b/xinference/model/video/custom.py
new file mode 100644
index 0000000000..841917c42e
--- /dev/null
+++ b/xinference/model/video/custom.py
@@ -0,0 +1,74 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, List, Optional
+
+from ..._compat import (
+    Literal,
+)
+from ..custom import ModelRegistry
+from .core import VideoModelFamilyV2
+
+logger = logging.getLogger(__name__)
+
+
+class CustomVideoModelFamilyV2(VideoModelFamilyV2):
+    version: Literal[2] = 2
+    model_id: Optional[str]  # type: ignore
+    model_revision: Optional[str]  # type: ignore
+    model_uri: Optional[str]
+
+
+if TYPE_CHECKING:
+    from typing import TypeVar
+
+    _T = TypeVar("_T", bound="CustomVideoModelFamilyV2")
+
+
+class VideoModelRegistry(ModelRegistry):
+    model_type = "video"
+
+    def __init__(self):
+        super().__init__()
+
+    def get_user_defined_models(self) -> List["CustomVideoModelFamilyV2"]:
+        return self.get_custom_models()
+
+
+video_registry = VideoModelRegistry()
+
+
+def register_video(model_spec: CustomVideoModelFamilyV2, persist: bool = True):
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("video")
+    registry.register(model_spec, persist)
+
+
+def unregister_video(model_name: str, raise_error: bool = True):
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("video")
+    registry.unregister(model_name, raise_error)
+
+
+def get_registered_videos() -> List[CustomVideoModelFamilyV2]:
+    """
+    Get all video families registered in the registry (both user-defined and editor-defined).
+    This excludes hardcoded builtin models.
+    """
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("video")
+    return registry.get_custom_models()
diff --git a/xinference/ui/web/ui/src/locales/en.json b/xinference/ui/web/ui/src/locales/en.json
index a12662732d..437fb45a5c 100644
--- a/xinference/ui/web/ui/src/locales/en.json
+++ b/xinference/ui/web/ui/src/locales/en.json
@@ -124,7 +124,23 @@
     "featured": "featured",
     "all": "all",
     "cancelledSuccessfully": "Cancelled Successfully!",
-    "mustBeUnique": "{{key}} must be unique"
+    "mustBeUnique": "{{key}} must be unique",
+    "addModel": "Add Model",
+    "addModelDialog": {
+      "introPrefix": "To add a model, please go to the",
+      "platformLinkText": "Xinference Model Hub",
+      "introSuffix": "and fill in the corresponding model name.",
+      "modelName": "Model Name",
+      "modelName.tip": "Please enter the model name",
+      "placeholder": "e.g. qwen3 (case-sensitive)"
+    },
+    "update": "Update",
+    "error": {
+      "name_not_matched": "No exact model name match found (case-sensitive)",
+      "downloadFailed": "Download failed",
+      "requestFailed": "Request failed",
+      "json_parse_error": "Failed to parse JSON"
+    }
   },
 
   "runningModels": {
diff --git a/xinference/ui/web/ui/src/locales/ja.json b/xinference/ui/web/ui/src/locales/ja.json
index dc1636bfd3..e4075f9e1d 100644
--- a/xinference/ui/web/ui/src/locales/ja.json
+++ b/xinference/ui/web/ui/src/locales/ja.json
@@ -124,7 +124,23 @@
     "featured": "おすすめとお気に入り",
     "all": "すべて",
     "cancelledSuccessfully": "正常にキャンセルされました!",
-    "mustBeUnique": "{{key}} は一意でなければなりません"
+    "mustBeUnique": "{{key}} は一意でなければなりません",
+    "addModel": "モデルを追加",
+    "addModelDialog": {
+      "introPrefix": "モデルを追加するには、",
+      "platformLinkText": "Xinference モデルセンター",
+      "introSuffix": "で対応するモデル名を入力してください。",
+      "modelName": "モデル名",
+      "modelName.tip": "モデル名を入力してください",
+      "placeholder": "例:qwen3(大文字と小文字を区別します)"
+    },
+    "update": "更新",
+    "error": {
+      "name_not_matched": "完全に一致するモデル名が見つかりません(大文字と小文字を区別します)",
+      "downloadFailed": "ダウンロードに失敗しました",
+      "requestFailed": "リクエストに失敗しました",
+      "json_parse_error": "JSON の解析に失敗しました"
+    }
   },
 
   "runningModels": {
diff --git a/xinference/ui/web/ui/src/locales/ko.json b/xinference/ui/web/ui/src/locales/ko.json
index 17ad7626a6..36fd0cd0c2 100644
--- a/xinference/ui/web/ui/src/locales/ko.json
+++ b/xinference/ui/web/ui/src/locales/ko.json
@@ -124,7 +124,23 @@
     "featured": "추천 및 즐겨찾기",
     "all": "모두",
     "cancelledSuccessfully": "성공적으로 취소되었습니다!",
-    "mustBeUnique": "{{key}} 는 고유해야 합니다"
+    "mustBeUnique": "{{key}} 는 고유해야 합니다",
+    "addModel": "모델 추가",
+    "addModelDialog": {
+      "introPrefix": "모델을 추가하려면",
+      "platformLinkText": "Xinference 모델 센터",
+      "introSuffix": "에서 해당 모델 이름을 입력하세요.",
+      "modelName": "모델 이름",
+      "modelName.tip": "모델 이름을 입력하세요",
+      "placeholder": "예: qwen3 (대소문자를 구분합니다)"
+    },
+    "update": "업데이트",
+    "error": {
+      "name_not_matched": "완전히 일치하는 모델 이름을 찾을 수 없습니다(대소문자 구분)",
+      "downloadFailed": "다운로드 실패",
+      "requestFailed": "요청 실패",
+      "json_parse_error": "JSON 구문 분석에 실패했습니다"
+    }
   },
 
   "runningModels": {
diff --git a/xinference/ui/web/ui/src/locales/zh.json b/xinference/ui/web/ui/src/locales/zh.json
index 36daec1756..3a0a1d7a19 100644
--- a/xinference/ui/web/ui/src/locales/zh.json
+++ b/xinference/ui/web/ui/src/locales/zh.json
@@ -124,7 +124,23 @@
     "featured": "推荐和收藏",
     "all": "全部",
     "cancelledSuccessfully": "取消成功!",
-    "mustBeUnique": "{{key}} 必须唯一"
+    "mustBeUnique": "{{key}} 必须唯一",
+    "addModel": "添加模型",
+    "addModelDialog": {
+      "introPrefix": "添加模型需基于",
+      "platformLinkText": "Xinference 模型中心",
+      "introSuffix": ",填写模型对应的名称",
+      "modelName": "模型名称",
+      "modelName.tip": "请输入模型名称",
+      "placeholder": "例如:qwen3(需大小写完全匹配)"
+    },
+    "update": "更新",
+    "error": {
+      "name_not_matched": "未找到完全匹配的模型名称(需大小写一致)",
+      "downloadFailed": "下载失败",
+      "requestFailed": "请求失败",
+      "json_parse_error": "JSON 解析失败"
+    }
   },
 
   "runningModels": {
diff --git a/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js b/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js
index cba7bf9a65..623a122b6d 100644
--- a/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js
+++ b/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js
@@ -10,9 +10,11 @@ import {
   Select,
 } from '@mui/material'
 import React, {
+  forwardRef,
   useCallback,
   useContext,
   useEffect,
+  useImperativeHandle,
   useRef,
   useState,
 } from 'react'
@@ -28,494 +30,507 @@ import ModelCard from './modelCard'
 
 // Toggle pagination globally for this page. Set to false to disable pagination and load all items.
 const ENABLE_PAGINATION = false
 
-const LaunchModelComponent = ({ modelType, gpuAvailable, featureModels }) => {
-  const { isCallingApi, setIsCallingApi, endPoint } = useContext(ApiContext)
-  const { isUpdatingModel } = useContext(ApiContext)
-  const { setErrorMsg } = useContext(ApiContext)
-  const [cookie] = useCookies(['token'])
-
-  const [registrationData, setRegistrationData] = useState([])
-  // States used for filtering
-  const [searchTerm, setSearchTerm] = useState('')
-  const [status, setStatus] = useState('')
-  const [statusArr, setStatusArr] = useState([])
-  const [collectionArr, setCollectionArr] = useState([])
-  const [filterArr, setFilterArr] = useState([])
-  const { t } = useTranslation()
-  const [modelListType, setModelListType] = useState('featured')
-  const [modelAbilityData, setModelAbilityData] = useState({
-    type: modelType,
-    modelAbility: '',
-    options: [],
-  })
-  const [selectedModel, setSelectedModel] = useState(null)
-  const [isOpenLaunchModelDrawer, setIsOpenLaunchModelDrawer] = useState(false)
-
-  // Pagination status
-  const [displayedData, setDisplayedData] = useState([])
-  const [currentPage, setCurrentPage] = useState(1)
-  const [hasMore, setHasMore] = useState(true)
-  const itemsPerPage = 20
-  const loaderRef = useRef(null)
-
-  const filter = useCallback(
-    (registration) => {
-      if (searchTerm !== '') {
-        if (!registration || typeof searchTerm !== 'string') return false
-        const modelName = registration.model_name
-          ? registration.model_name.toLowerCase()
-          : ''
-        const modelDescription = registration.model_description
-          ? registration.model_description.toLowerCase()
-          : ''
+const LaunchModelComponent = forwardRef(
+  ({ modelType, gpuAvailable, featureModels }, ref) => {
+    const { isCallingApi, setIsCallingApi, endPoint } = useContext(ApiContext)
+    const { isUpdatingModel } = useContext(ApiContext)
+    const { setErrorMsg } = useContext(ApiContext)
+    const [cookie] = useCookies(['token'])
+
+    const [registrationData, setRegistrationData] = useState([])
+    // States used for filtering
+    const [searchTerm, setSearchTerm] = useState('')
+    const [status, setStatus] = useState('')
+    const [statusArr, setStatusArr] = useState([])
+    const [collectionArr, setCollectionArr] = useState([])
+    const [filterArr, setFilterArr] = useState([])
+    const { t } = useTranslation()
+    const [modelListType, setModelListType] = useState('featured')
+    const [modelAbilityData, setModelAbilityData] = useState({
+      type: modelType,
+      modelAbility: '',
+      options: [],
+    })
+    const [selectedModel, setSelectedModel] = useState(null)
+    const [isOpenLaunchModelDrawer, setIsOpenLaunchModelDrawer] =
+      useState(false)
+
+    // Pagination status
+    const [displayedData, setDisplayedData] = useState([])
+    const [currentPage, setCurrentPage] = useState(1)
+    const [hasMore, setHasMore] = useState(true)
+    const itemsPerPage = 20
+    const loaderRef = useRef(null)
+
+    const filter = useCallback(
+      (registration) => {
+        if (searchTerm !== '') {
+          if (!registration || typeof searchTerm !== 'string') return false
+          const modelName = registration.model_name
+            ? registration.model_name.toLowerCase()
+            : ''
+          const modelDescription = registration.model_description
+            ? registration.model_description.toLowerCase()
+            : ''
+
+          if (
+            !modelName.includes(searchTerm.toLowerCase()) &&
+            !modelDescription.includes(searchTerm.toLowerCase())
+          ) {
+            return false
+          }
+        }
 
-      if (
-        !modelName.includes(searchTerm.toLowerCase()) &&
-        !modelDescription.includes(searchTerm.toLowerCase())
-      ) {
-        return false
-      }
-      }
+        if (modelListType === 'featured') {
+          if (
+            featureModels.length &&
+            !featureModels.includes(registration.model_name) &&
+            !collectionArr?.includes(registration.model_name)
+          ) {
+            return false
+          }
+        }
 
-      if (modelListType === 'featured') {
-        if (
-          featureModels.length &&
-          !featureModels.includes(registration.model_name) &&
-          !collectionArr?.includes(registration.model_name)
-        ) {
-          return false
-        }
-      }
+        if (
+          modelAbilityData.modelAbility &&
+          ((Array.isArray(registration.model_ability) &&
+            registration.model_ability.indexOf(modelAbilityData.modelAbility) <
+              0) ||
+            (typeof registration.model_ability === 'string' &&
+              registration.model_ability !== modelAbilityData.modelAbility))
+        )
+          return false
 
-      if (
-        modelAbilityData.modelAbility &&
-        ((Array.isArray(registration.model_ability) &&
-          registration.model_ability.indexOf(modelAbilityData.modelAbility) <
-            0) ||
-          (typeof registration.model_ability === 'string' &&
-            registration.model_ability !== modelAbilityData.modelAbility))
-      )
-        return false
-
-      if (statusArr.length === 1) {
-        if (statusArr[0] === 'cached') {
-          const judge =
-            registration.model_specs?.some((spec) => filterCache(spec)) ||
-            registration?.cache_status
-          return judge
-        } else {
-          return collectionArr?.includes(registration.model_name)
-        }
-      } else if (statusArr.length > 1) {
-        const judge =
-          registration.model_specs?.some((spec) => filterCache(spec)) ||
-          registration?.cache_status
-        return judge && collectionArr?.includes(registration.model_name)
-      }
+        if (statusArr.length === 1) {
+          if (statusArr[0] === 'cached') {
+            const judge =
+              registration.model_specs?.some((spec) => filterCache(spec)) ||
+              registration?.cache_status
+            return judge
+          } else {
+            return collectionArr?.includes(registration.model_name)
+          }
+        } else if (statusArr.length > 1) {
+          const judge =
+            registration.model_specs?.some((spec) => filterCache(spec)) ||
+            registration?.cache_status
+          return judge && collectionArr?.includes(registration.model_name)
+        }
 
-      return true
-    },
-    [
-      searchTerm,
-      modelListType,
-      featureModels,
-      collectionArr,
-      modelAbilityData.modelAbility,
-      statusArr,
-    ]
-  )
+        return true
+      },
+      [
+        searchTerm,
+        modelListType,
+        featureModels,
+        collectionArr,
+        modelAbilityData.modelAbility,
+        statusArr,
+      ]
+    )
 
-  const filterCache = useCallback((spec) => {
-    if (Array.isArray(spec.cache_status)) {
-      return spec.cache_status?.some((cs) => cs)
-    } else {
-      return spec.cache_status === true
-    }
-  }, [])
+    const filterCache = useCallback((spec) => {
+      if (Array.isArray(spec.cache_status)) {
+        return spec.cache_status?.some((cs) => cs)
+      } else {
+        return spec.cache_status === true
+      }
+    }, [])
 
-  function getUniqueModelAbilities(arr) {
-    const uniqueAbilities = new Set()
-
-    arr.forEach((item) => {
-      if (Array.isArray(item.model_ability)) {
-        item.model_ability.forEach((ability) => {
-          uniqueAbilities.add(ability)
-        })
-      }
-    })
-
-    return Array.from(uniqueAbilities)
-  }
+    function getUniqueModelAbilities(arr) {
+      const uniqueAbilities = new Set()
+
+      arr.forEach((item) => {
+        if (Array.isArray(item.model_ability)) {
+          item.model_ability.forEach((ability) => {
+            uniqueAbilities.add(ability)
+          })
+        }
+      })
+
+      return Array.from(uniqueAbilities)
+    }
 
-  const update = () => {
-    if (
-      isCallingApi ||
-      isUpdatingModel ||
-      (cookie.token !== 'no_auth' && !sessionStorage.getItem('token'))
-    )
-      return
-
-    try {
-      setIsCallingApi(true)
-
-      fetchWrapper
-        .get(`/v1/model_registrations/${modelType}?detailed=true`)
-        .then((data) => {
-          const builtinRegistrations = data.filter((v) => v.is_builtin)
-          setModelAbilityData({
-            ...modelAbilityData,
-            options: getUniqueModelAbilities(builtinRegistrations),
-          })
-          setRegistrationData(builtinRegistrations)
-          const collectionData = JSON.parse(
-            localStorage.getItem('collectionArr')
-          )
-          setCollectionArr(collectionData)
-
-          // Reset pagination status
-          setCurrentPage(1)
-          setHasMore(true)
-        })
-        .catch((error) => {
-          console.error('Error:', error)
-          if (error.response.status !== 403 && error.response.status !== 401) {
-            setErrorMsg(error.message)
-          }
-        })
-    } catch (error) {
-      console.error('Error:', error)
-    } finally {
-      setIsCallingApi(false)
-    }
-  }
+    const update = () => {
+      if (
+        isCallingApi ||
+        isUpdatingModel ||
+        (cookie.token !== 'no_auth' && !sessionStorage.getItem('token'))
+      )
+        return
+
+      try {
+        setIsCallingApi(true)
+
+        fetchWrapper
+          .get(`/v1/model_registrations/${modelType}?detailed=true`)
+          .then((data) => {
+            const builtinRegistrations = data.filter((v) => v.is_builtin)
+            setModelAbilityData({
+              ...modelAbilityData,
+              options: getUniqueModelAbilities(builtinRegistrations),
+            })
+            setRegistrationData(builtinRegistrations)
+            const collectionData = JSON.parse(
+              localStorage.getItem('collectionArr')
+            )
+            setCollectionArr(collectionData)
+
+            // Reset pagination status
+            setCurrentPage(1)
+            setHasMore(true)
+          })
+          .catch((error) => {
+            console.error('Error:', error)
+            if (
+              error.response.status !== 403 &&
+              error.response.status !== 401
+            ) {
+              setErrorMsg(error.message)
+            }
+          })
+      } catch (error) {
+        console.error('Error:', error)
+      } finally {
+        setIsCallingApi(false)
+      }
+    }
+
+    // Expose update() so parents holding the forwarded ref can refresh the list;
+    // without this, the ref callbacks wired up in index.js would never resolve update().
+    useImperativeHandle(ref, () => ({
+      update,
+    }))
 
-  useEffect(() => {
-    update()
-  }, [cookie.token])
+    useEffect(() => {
+      update()
+    }, [cookie.token])
 
-  // Update pagination data
-  const updateDisplayedData = useCallback(() => {
-    const filteredData = registrationData.filter((registration) =>
-      filter(registration)
-    )
-
-    const sortedData = [...filteredData].sort((a, b) => {
-      if (modelListType === 'featured') {
-        const indexA = featureModels.indexOf(a.model_name)
-        const indexB = featureModels.indexOf(b.model_name)
-        return (
-          (indexA !== -1 ? indexA : Infinity) -
-          (indexB !== -1 ? indexB : Infinity)
-        )
-      }
-      return 0
-    })
-
-    // If pagination is disabled, show all data at once
-    if (!ENABLE_PAGINATION) {
-      setDisplayedData(sortedData)
-      setHasMore(false)
-      return
-    }
-
-    const startIndex = (currentPage - 1) * itemsPerPage
-    const endIndex = currentPage * itemsPerPage
-    const newData = sortedData.slice(startIndex, endIndex)
-
-    if (currentPage === 1) {
-      setDisplayedData(newData)
-    } else {
-      setDisplayedData((prev) => [...prev, ...newData])
-    }
-    setHasMore(endIndex < sortedData.length)
-  }, [
-    registrationData,
-    filter,
-    modelListType,
-    featureModels,
-    currentPage,
-    itemsPerPage,
-  ])
+    // Update pagination data
+    const updateDisplayedData = useCallback(() => {
+      const filteredData = registrationData.filter((registration) =>
+        filter(registration)
+      )
+
+      const sortedData = [...filteredData].sort((a, b) => {
+        if (modelListType === 'featured') {
+          const indexA = featureModels.indexOf(a.model_name)
+          const indexB = featureModels.indexOf(b.model_name)
+          return (
+            (indexA !== -1 ? indexA : Infinity) -
+            (indexB !== -1 ? indexB : Infinity)
+          )
+        }
+        return 0
+      })
+
+      // If pagination is disabled, show all data at once
+      if (!ENABLE_PAGINATION) {
+        setDisplayedData(sortedData)
+        setHasMore(false)
+        return
+      }
+
+      const startIndex = (currentPage - 1) * itemsPerPage
+      const endIndex = currentPage * itemsPerPage
+      const newData = sortedData.slice(startIndex, endIndex)
+
+      if (currentPage === 1) {
+        setDisplayedData(newData)
+      } else {
+        setDisplayedData((prev) => [...prev, ...newData])
+      }
+      setHasMore(endIndex < sortedData.length)
+    }, [
+      registrationData,
+      filter,
+      modelListType,
+      featureModels,
+      currentPage,
+      itemsPerPage,
+    ])
 
-  useEffect(() => {
-    updateDisplayedData()
-  }, [updateDisplayedData])
-
-  // Reset pagination when filters change
-  useEffect(() => {
-    setCurrentPage(1)
-    setHasMore(true)
-  }, [searchTerm, modelAbilityData.modelAbility, status, modelListType])
-
-  // Infinite scroll observer
-  useEffect(() => {
-    if (!ENABLE_PAGINATION) return
-
-    const observer = new IntersectionObserver(
-      (entries) => {
-        if (entries[0].isIntersecting && hasMore && !isCallingApi) {
-          setCurrentPage((prev) => prev + 1)
-        }
-      },
-      { threshold: 1.0 }
-    )
-
-    if (loaderRef.current) {
-      observer.observe(loaderRef.current)
-    }
-
-    return () => {
-      if (loaderRef.current) {
-        observer.unobserve(loaderRef.current)
-      }
-    }
-  }, [hasMore, isCallingApi, currentPage])
+    useEffect(() => {
+      updateDisplayedData()
+    }, [updateDisplayedData])
+
+    // Reset pagination when filters change
+    useEffect(() => {
+      setCurrentPage(1)
+      setHasMore(true)
+    }, [searchTerm, modelAbilityData.modelAbility, status, modelListType])
+
+    // Infinite scroll observer
+    useEffect(() => {
+      if (!ENABLE_PAGINATION) return
+
+      const observer = new IntersectionObserver(
+        (entries) => {
+          if (entries[0].isIntersecting && hasMore && !isCallingApi) {
+            setCurrentPage((prev) => prev + 1)
+          }
+        },
+        { threshold: 1.0 }
+      )
+
+      if (loaderRef.current) {
+        observer.observe(loaderRef.current)
+      }
+
+      return () => {
+        if (loaderRef.current) {
+          observer.unobserve(loaderRef.current)
+        }
+      }
+    }, [hasMore, isCallingApi, currentPage])
 
-  const getCollectionArr = (data) => {
-    setCollectionArr(data)
-  }
+    const getCollectionArr = (data) => {
+      setCollectionArr(data)
+    }
 
-  const handleChangeFilter = (type, value) => {
-    const typeMap = {
-      modelAbility: {
-        setter: (value) => {
-          setModelAbilityData({
-            ...modelAbilityData,
-            modelAbility: value,
-          })
-        },
-        filterArr: modelAbilityData.options,
-      },
-      status: { setter: setStatus, filterArr: [] },
-    }
-
-    const { setter, filterArr: excludeArr } = typeMap[type] || {}
-    if (!setter) return
-
-    setter(value)
-
-    const updatedFilterArr = Array.from(
-      new Set([
-        ...filterArr.filter((item) => !excludeArr.includes(item)),
-        value,
-      ])
-    )
-
-    setFilterArr(updatedFilterArr)
-
-    if (type === 'status') {
-      setStatusArr(
-        updatedFilterArr.filter(
-          (item) => ![...modelAbilityData.options].includes(item)
-        )
-      )
-    }
-
-    // Reset pagination status
-    setDisplayedData([])
-    setCurrentPage(1)
-    setHasMore(true)
-  }
+    const handleChangeFilter = (type, value) => {
+      const typeMap = {
+        modelAbility: {
+          setter: (value) => {
+            setModelAbilityData({
+              ...modelAbilityData,
+              modelAbility: value,
+            })
+          },
+          filterArr: modelAbilityData.options,
+        },
+        status: { setter: setStatus, filterArr: [] },
+      }
+
+      const { setter, filterArr: excludeArr } = typeMap[type] || {}
+      if (!setter) return
+
+      setter(value)
+
+      const updatedFilterArr = Array.from(
+        new Set([
+          ...filterArr.filter((item) => !excludeArr.includes(item)),
+          value,
+        ])
+      )
+
+      setFilterArr(updatedFilterArr)
+
+      if (type === 'status') {
+        setStatusArr(
+          updatedFilterArr.filter(
+            (item) => ![...modelAbilityData.options].includes(item)
+          )
+        )
+      }
+
+      // Reset pagination status
+      setDisplayedData([])
+      setCurrentPage(1)
+      setHasMore(true)
+    }
 
-  const handleDeleteChip = (item) => {
-    setFilterArr(
-      filterArr.filter((subItem) => {
-        return subItem !== item
-      })
-    )
-    if (item === modelAbilityData.modelAbility) {
-      setModelAbilityData({
-        ...modelAbilityData,
-        modelAbility: '',
-      })
-    } else {
-      setStatusArr(
-        statusArr.filter((subItem) => {
-          return subItem !== item
-        })
-      )
-      if (item === status) setStatus('')
-    }
-
-    // Reset pagination status
-    setCurrentPage(1)
-    setHasMore(true)
-  }
+    const handleDeleteChip = (item) => {
+      setFilterArr(
+        filterArr.filter((subItem) => {
+          return subItem !== item
+        })
+      )
+      if (item === modelAbilityData.modelAbility) {
+        setModelAbilityData({
+          ...modelAbilityData,
+          modelAbility: '',
+        })
+      } else {
+        setStatusArr(
+          statusArr.filter((subItem) => {
+            return subItem !== item
+          })
+        )
+        if (item === status) setStatus('')
+      }
+
+      // Reset pagination status
+      setCurrentPage(1)
+      setHasMore(true)
+    }
 
-  const handleModelType = (newModelType) => {
-    if (newModelType !== null) {
-      setModelListType(newModelType)
-
-      // Reset pagination status
-      setDisplayedData([])
-      setCurrentPage(1)
-      setHasMore(true)
-    }
-  }
+    const handleModelType = (newModelType) => {
+      if (newModelType !== null) {
+        setModelListType(newModelType)
+
+        // Reset pagination status
+        setDisplayedData([])
+        setCurrentPage(1)
+        setHasMore(true)
+      }
+    }
 
-  function getLabel(item) {
-    const translation = t(`launchModel.${item}`)
-    return translation === `launchModel.${item}` ? item : translation
-  }
+    function getLabel(item) {
+      const translation = t(`launchModel.${item}`)
+      return translation === `launchModel.${item}` ? item : translation
+    }
 
-  return (
+    return (
+      
+        {
+            const hasAbility = modelAbilityData.options.length > 0
+            const hasFeature = featureModels.length > 0
+
+            const baseColumns = hasAbility ? ['200px', '150px'] : ['200px']
+            const altColumns = hasAbility ? ['150px', '150px'] : ['150px']
+
+            const columns = hasFeature
+              ? [...baseColumns, '150px', '1fr']
+              : [...altColumns, '1fr']
+
+            return columns.join(' ')
+          })(),
+          columnGap: '20px',
+          margin: '30px 2rem',
+          alignItems: 'center',
+        }}
+      >
+        {featureModels.length > 0 && (
+          
+            
+          
+        )}
+        {modelAbilityData.options.length > 0 && (
+          
+            {t('launchModel.modelAbility')}
+            
+          
+        )}
+        
+          {t('launchModel.status')}
+          
+        
+        
+          {
+              setSearchTerm(e.target.value)
+            }}
+            size="small"
+            hotkey="Enter"
+            t={t}
+          />
+        
+      
+      
+        {filterArr.map((item, index) => (
+          handleDeleteChip(item)}
+          />
+        ))}
+      
+      
+        {displayedData.map((filteredRegistration) => (
+          {
+              setSelectedModel(filteredRegistration)
+              setIsOpenLaunchModelDrawer(true)
+            }}
+          />
+        ))}
+      
+
+      
+        {ENABLE_PAGINATION && hasMore && !isCallingApi && (
+          
+        )}
+      
+
+      {selectedModel && (
+        <LaunchModelDrawer
+          modelData={selectedModel}
+          gpuAvailable={gpuAvailable}
+          open={isOpenLaunchModelDrawer}
+          onClose={() => setIsOpenLaunchModelDrawer(false)}
+        />
+      )}
-    {ENABLE_PAGINATION && hasMore && !isCallingApi && (
-      
-    )}
-
-    {selectedModel && (
-      setIsOpenLaunchModelDrawer(false)}
-      />
-    )}
-  )
-}
+    )
+  }
+)
+
+LaunchModelComponent.displayName = 'LaunchModelComponent'
 
 export default LaunchModelComponent
diff --git a/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js b/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js
new file mode 100644
index 0000000000..b258a2bdb5
--- /dev/null
+++ b/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js
@@ -0,0 +1,192 @@
+import {
+  Button,
+  Dialog,
+  DialogActions,
+  DialogContent,
+  DialogTitle,
+  TextField,
+} from '@mui/material'
+import React, { useContext, useState } from 'react'
+import { useTranslation } from 'react-i18next'
+
+import { ApiContext } from '../../../components/apiContext'
+
+const API_BASE_URL = 'https://model.xinference.io'
+
+const AddModelDialog = ({ open, onClose, onUpdateList }) => {
+  const { t } = useTranslation()
+  const [modelName, setModelName] = useState('')
+  const [loading, setLoading] = useState(false)
+  const { endPoint, setErrorMsg } = useContext(ApiContext)
+
+  const searchModelByName = async (name) => {
+    try {
+      const url = `${API_BASE_URL}/api/models?order=featured&query=${encodeURIComponent(
+        name
+      )}&page=1&pageSize=5`
+      const res = await fetch(url, { method: 'GET' })
+      const rawText = await res.text().catch(() => '')
+      if (!res.ok) {
+        setErrorMsg(rawText || `HTTP ${res.status}`)
+        return null
+      }
+      try {
+        const data = JSON.parse(rawText)
+        const items = data?.data || []
+        const exact = items.find((it) => it?.model_name === name)
+        if (!exact) {
+          setErrorMsg(t('launchModel.error.name_not_matched'))
+          return null
+        }
+        const id = exact?.id
+        const modelType = exact?.model_type
+        if (!id || !modelType) {
+          setErrorMsg(t('launchModel.error.downloadFailed'))
+          return null
+        }
+        return { id, modelType }
+      } catch {
+        setErrorMsg(rawText || t('launchModel.error.json_parse_error'))
+        return null
+      }
+    } catch (err) {
+      console.error(err)
+      setErrorMsg(err.message || t('launchModel.error.requestFailed'))
+      return null
+    }
+  }
+
+  const fetchModelJson = async (modelId) => {
+    try {
+      const res = await fetch(
+        `${API_BASE_URL}/api/models/download?model_id=${encodeURIComponent(
+          modelId
+        )}`,
+        { method: 'GET' }
+      )
+      const rawText = await res.text().catch(() => '')
+      if (!res.ok) {
+        setErrorMsg(rawText || `HTTP ${res.status}`)
+        return null
+      }
+      try {
+        const data = JSON.parse(rawText)
+        return data
+      } catch {
+        setErrorMsg(rawText || t('launchModel.error.json_parse_error'))
+        return null
+      }
+    } catch (err) {
+      console.error(err)
+      setErrorMsg(err.message || t('launchModel.error.requestFailed'))
+      return null
+    }
+  }
+
+  const addToLocal = async (modelType, modelJson) => {
+    try {
+      const res = await fetch(endPoint + '/v1/models/add', {
+        method: 'POST',
+        headers: { 'Content-Type': 'application/json' },
+        body: JSON.stringify({ model_type: modelType, model_json: modelJson }),
+      })
+      const rawText = await res.text().catch(() => '')
+      if (!res.ok) {
+        setErrorMsg(rawText || `HTTP ${res.status}`)
+        return
+      }
+      onClose(`/launch_model/${modelType}`)
+      onUpdateList(modelType)
+    } catch (error) {
+      console.error('Error:', error)
+      if (error?.response?.status !== 403) {
+        setErrorMsg(error.message)
+      }
+    }
+  }
+
+  const handleFormSubmit = async (e) => {
+    e.preventDefault()
+    const name = modelName?.trim()
+    if (!name) {
+      setErrorMsg(t('launchModel.addModelDialog.modelName.tip'))
+      return
+    }
+    setLoading(true)
+    setErrorMsg('')
+    try {
+      const found = await searchModelByName(name)
+      if (!found) return
+      const { id, modelType } = found
+
+      const modelJson = await fetchModelJson(id)
+      if (!modelJson) return
+
+      await addToLocal(modelType, modelJson)
+    } finally {
+      setLoading(false)
+    }
+  }
+
+  return (
+    <Dialog open={open} onClose={() => onClose()} width={500}>
+      <DialogTitle>{t('launchModel.addModel')}</DialogTitle>
+      <DialogContent
+        
+          {t('launchModel.addModelDialog.introPrefix')}{' '}
+          
+            {t('launchModel.addModelDialog.platformLinkText')}
+          
+          {t('launchModel.addModelDialog.introSuffix')}
+        
+        
+          {
+              setModelName(e.target.value)
+            }}
+            disabled={loading}
+          />
+        
+      
+      
+        
+        
+      
+    
+  )
+}
+
+export default AddModelDialog
diff --git a/xinference/ui/web/ui/src/scenes/launch_model/index.js b/xinference/ui/web/ui/src/scenes/launch_model/index.js
index 24f886a80d..4ac6cff612 100644
--- a/xinference/ui/web/ui/src/scenes/launch_model/index.js
+++ b/xinference/ui/web/ui/src/scenes/launch_model/index.js
@@ -1,6 +1,7 @@
-import { TabContext, TabList, TabPanel } from '@mui/lab'
-import { Box, Tab } from '@mui/material'
-import React, { useContext, useEffect, useState } from 'react'
+import Add from '@mui/icons-material/Add'
+import { LoadingButton, TabContext, TabList, TabPanel } from '@mui/lab'
+import { Box, Button, MenuItem, Select, Tab } from '@mui/material'
+import React, { useContext, useEffect, useRef, useState } from 'react'
 import { useCookies } from 'react-cookie'
 import { useTranslation } from 'react-i18next'
 import { useNavigate } from 'react-router-dom'
@@ -11,6 +12,7 @@ import fetchWrapper from '../../components/fetchWrapper'
 import SuccessMessageSnackBar from '../../components/successMessageSnackBar'
 import Title from '../../components/Title'
 import { isValidBearerToken } from '../../components/utils'
+import AddModelDialog from './components/addModelDialog'
 import { featureModels } from './data/data'
 import LaunchCustom from './launchCustom'
 import LaunchModelComponent from './LaunchModel'
@@ -22,13 +24,17 @@ const LaunchModel = () => {
       : '/launch_model/llm'
   )
   const [gpuAvailable, setGPUAvailable] = useState(-1)
+  const [open, setOpen] = useState(false)
+  const [modelType, setModelType] = useState('llm')
+  const [loading, setLoading] = useState(false)
 
   const { setErrorMsg } = useContext(ApiContext)
   const [cookie] = useCookies(['token'])
   const navigate = useNavigate()
   const { t } = useTranslation()
+  const LaunchModelRefs = useRef({})
 
-  const handleTabChange = (event, newValue) => {
+  const handleTabChange = (newValue) => {
     setValue(newValue)
     navigate(newValue)
     sessionStorage.setItem('modelType', newValue)
@@ -59,14 +65,56 @@ const LaunchModel = () => {
     }
   }, [cookie.token])
 
+  const updateList = (modelType) => {
+    LaunchModelRefs.current[modelType]?.update()
+  }
+
+  const handleClose = (value) => {
+    setOpen(false)
+    if (value) {
+      handleTabChange(value)
+    }
+  }
+
+  const updateModels = () => {
+    setLoading(true)
+    fetchWrapper
+      .post('/v1/models/update_type', { model_type: modelType })
+      .then(() => {
+        handleTabChange(`/launch_model/${modelType}`)
+        updateList(modelType)
+      })
+      .catch((error) => {
+        console.error('Error:', error)
+        if (error.response.status !== 403 && error.response.status !== 401) {
+          setErrorMsg(error.message)
+        }
+      })
+      .finally(() => {
+        setLoading(false)
+      })
+  }
+
   return (
       <ErrorMessageSnackBar />
       <SuccessMessageSnackBar />
       <TabContext value={value}>
-        <Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
-          <TabList value={value} onChange={handleTabChange} aria-label="tabs">
+        <Box
+          sx={{
+            borderBottom: 1,
+            borderColor: 'divider',
+            display: 'flex',
+            justifyContent: 'space-between',
+            alignItems: 'center',
+          }}
+        >
+          <TabList
+            value={value}
+            onChange={(_, value) => handleTabChange(value)}
+            aria-label="tabs"
+          >
            <Tab label={t('model.languageModels')} value="/launch_model/llm" />
            <Tab
              label={t('model.embeddingModels')}
@@ -81,6 +129,53 @@ const LaunchModel = () => {
              value="/launch_model/custom/llm"
            />
          </TabList>
+          <Box
+            sx={{
+              display: 'flex',
+              alignItems: 'center',
+              gap: '10px',
+            }}
+          >
+            <Box sx={{ display: 'flex', gap: 0 }}>
+              <Select
+                value={modelType}
+                onChange={(e) => setModelType(e.target.value)}
+                size="small"
+                sx={{
+                  borderTopRightRadius: 0,
+                  borderBottomRightRadius: 0,
+                  minWidth: 100,
+                }}
+              >
+                <MenuItem value="llm">LLM</MenuItem>
+                <MenuItem value="embedding">Embedding</MenuItem>
+                <MenuItem value="rerank">Rerank</MenuItem>
+                <MenuItem value="image">Image</MenuItem>
+                <MenuItem value="audio">Audio</MenuItem>
+                <MenuItem value="video">Video</MenuItem>
+              </Select>
+
+              <LoadingButton
+                variant="contained"
+                onClick={updateModels}
+                loading={loading}
+                sx={{
+                  borderTopLeftRadius: 0,
+                  borderBottomLeftRadius: 0,
+                  whiteSpace: 'nowrap',
+                }}
+              >
+                {t('launchModel.update')}
+              </LoadingButton>
+            </Box>
+            <Button
+              variant="outlined"
+              startIcon={<Add />}
+              onClick={() => setOpen(true)}
+            >
+              {t('launchModel.addModel')}
+            </Button>
+          </Box>
        </Box>
        <TabPanel value="/launch_model/llm" sx={{ padding: 0 }}>
          <LaunchModelComponent
@@ -89,6 +184,7 @@ const LaunchModel = () => {
             featureModels={
               featureModels.find((item) => item.type === 'llm').feature_models
             }
+            ref={(ref) => (LaunchModelRefs.current.llm = ref)}
           />
         </TabPanel>
         <TabPanel value="/launch_model/embedding" sx={{ padding: 0 }}>
@@ -99,6 +195,7 @@ const LaunchModel = () => {
               featureModels.find((item) => item.type === 'embedding')
                 .feature_models
             }
+            ref={(ref) => (LaunchModelRefs.current.embedding = ref)}
           />
         </TabPanel>
         <TabPanel value="/launch_model/rerank" sx={{ padding: 0 }}>
@@ -109,6 +206,7 @@ const LaunchModel = () => {
               featureModels.find((item) => item.type === 'rerank')
                 .feature_models
             }
+            ref={(ref) => (LaunchModelRefs.current.rerank = ref)}
           />
         </TabPanel>
         <TabPanel value="/launch_model/image" sx={{ padding: 0 }}>
@@ -118,6 +216,7 @@ const LaunchModel = () => {
             featureModels={
               featureModels.find((item) => item.type === 'image').feature_models
             }
+            ref={(ref) => (LaunchModelRefs.current.image = ref)}
           />
         </TabPanel>
         <TabPanel value="/launch_model/audio" sx={{ padding: 0 }}>
@@ -127,6 +226,7 @@ const LaunchModel = () => {
             featureModels={
               featureModels.find((item) => item.type === 'audio').feature_models
             }
+            ref={(ref) => (LaunchModelRefs.current.audio = ref)}
           />
         </TabPanel>
         <TabPanel value="/launch_model/video" sx={{ padding: 0 }}>
@@ -136,12 +236,20 @@ const LaunchModel = () => {
             featureModels={
               featureModels.find((item) => item.type === 'video').feature_models
             }
+            ref={(ref) => (LaunchModelRefs.current.video = ref)}
           />
         </TabPanel>
         <TabPanel value="/launch_model/custom/llm" sx={{ padding: 0 }}>
           <LaunchCustom gpuAvailable={gpuAvailable} />
         </TabPanel>
       </TabContext>
+      {open && (
+        <AddModelDialog
+          onUpdateList={updateList}
+          open={open}
+          onClose={handleClose}
+        />
+      )}
     </Box>
   )
 }
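The dialog's three-step flow (search the hub for an exact name match, download that model's JSON, POST it to the local server), as well as the Update button's round trip, can be exercised outside the browser. A rough Python equivalent, assuming a local Xinference server on the default port 9997; the hub endpoints and request shapes mirror those used in addModelDialog.js and index.js above:

    import requests

    HUB = "https://model.xinference.io"
    LOCAL = "http://127.0.0.1:9997"

    # 1. Search the hub; the match must be exact and case-sensitive.
    name = "qwen3"
    items = requests.get(
        f"{HUB}/api/models",
        params={"order": "featured", "query": name, "page": 1, "pageSize": 5},
    ).json()["data"]
    exact = next(it for it in items if it["model_name"] == name)

    # 2. Download the model's JSON definition.
    model_json = requests.get(
        f"{HUB}/api/models/download", params={"model_id": exact["id"]}
    ).json()

    # 3. Register it with the local server.
    requests.post(
        f"{LOCAL}/v1/models/add",
        json={"model_type": exact["model_type"], "model_json": model_json},
    ).raise_for_status()

    # Or refresh all builtin definitions of one type, then re-read the list
    # the UI renders from.
    requests.post(
        f"{LOCAL}/v1/models/update_type", json={"model_type": "llm"}
    ).raise_for_status()
    models = requests.get(
        f"{LOCAL}/v1/model_registrations/llm", params={"detailed": "true"}
    ).json()
    print(len(models), "llm registrations")
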