Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ class RegisterModelRequest(BaseModel):
persist: bool


class AddModelRequest(BaseModel):
model_type: str
model_json: Dict[str, Any]


class UpdateModelRequest(BaseModel):
model_type: str


class BuildGradioInterfaceRequest(BaseModel):
model_type: str
model_name: str
Expand Down Expand Up @@ -900,6 +909,26 @@ async def internal_exception_handler(request: Request, exc: Exception):
else None
),
)
self._router.add_api_route(
"/v1/models/add",
self.add_model,
methods=["POST"],
dependencies=(
[Security(self._auth_service, scopes=["models:add"])]
if self.is_authenticated()
else None
),
)
self._router.add_api_route(
"/v1/models/update_type",
self.update_model_type,
methods=["POST"],
dependencies=(
[Security(self._auth_service, scopes=["models:add"])]
if self.is_authenticated()
else None
),
)
self._router.add_api_route(
"/v1/cache/models",
self.list_cached_models,
Expand Down Expand Up @@ -3123,25 +3152,114 @@ async def unregister_model(self, model_type: str, model_name: str) -> JSONRespon
raise HTTPException(status_code=500, detail=str(e))
return JSONResponse(content=None)

async def add_model(self, request: Request) -> JSONResponse:
try:

# Parse request
raw_json = await request.json()

if "model_type" in raw_json and "model_json" in raw_json:
body = AddModelRequest.parse_obj(raw_json)
model_type = body.model_type
model_json = body.model_json
else:
model_json = raw_json

# Priority 1: Check if model_type is explicitly provided in the JSON
if "model_type" in model_json:
model_type = model_json["model_type"]
logger.info(
f"[DEBUG] Using explicit model_type from JSON: {model_type}"
)
else:
# model_type is required in the JSON when using unwrapped format
logger.error(
f"[DEBUG] model_type not provided in JSON, this is required"
)
raise HTTPException(
status_code=400,
detail="model_type is required in the model JSON. Supported types: LLM, embedding, audio, image, video, rerank",
)

supervisor_ref = await self._get_supervisor_ref()

# Call supervisor
await supervisor_ref.add_model(model_type, model_json)

except ValueError as re:
logger.error(f"ValueError in add_model API: {re}", exc_info=True)
logger.error(f"ValueError details: {type(re).__name__}: {re}")
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(f"Unexpected error in add_model API: {e}", exc_info=True)
logger.error(f"Error details: {type(e).__name__}: {e}")
import traceback

logger.error(f"Full traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(e))

return JSONResponse(
content={"message": f"Model added successfully for type: {model_type}"}
)

async def update_model_type(self, request: Request) -> JSONResponse:
try:
# Parse request
raw_json = await request.json()

body = UpdateModelRequest.parse_obj(raw_json)
model_type = body.model_type

# Get supervisor reference
supervisor_ref = await self._get_supervisor_ref()

await supervisor_ref.update_model_type(model_type)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.


except ValueError as re:
logger.error(f"ValueError in update_model_type API: {re}", exc_info=True)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(
f"Unexpected error in update_model_type API: {e}", exc_info=True
)
raise HTTPException(status_code=500, detail=str(e))

return JSONResponse(
content={
"message": f"Model configurations updated successfully for type: {model_type}"
}
)

async def list_model_registrations(
self, model_type: str, detailed: bool = Query(False)
) -> JSONResponse:
try:

data = await (await self._get_supervisor_ref()).list_model_registrations(
model_type, detailed=detailed
)

# Remove duplicate model names.
model_names = set()
final_data = []
for item in data:
if item["model_name"] not in model_names:
model_names.add(item["model_name"])
final_data.append(item)

return JSONResponse(content=final_data)
except ValueError as re:
logger.error(
f"ValueError in list_model_registrations: {re}",
exc_info=True,
)
logger.error(re, exc_info=True)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(
f"Unexpected error in list_model_registrations: {e}",
exc_info=True,
)
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

Expand Down
Loading
Loading