FEAT: create_embedding add field model_replica #2779

Merged
merged 15 commits into from
Feb 28, 2025
1 change: 1 addition & 0 deletions xinference/api/restful_api.py
@@ -1330,6 +1330,7 @@ async def create_embedding(self, request: Request) -> Response:
raise HTTPException(status_code=500, detail=str(e))

try:
+ kwargs["model_uid"] = model_uid
embedding = await model.create_embedding(body.input, **kwargs)
return Response(embedding, media_type="application/json")
except Exception as e:
3 changes: 3 additions & 0 deletions xinference/core/tests/test_restful_api.py
@@ -356,6 +356,9 @@ def test_restful_api_for_embedding(setup):

assert "embedding" in embedding_res["data"][0]
assert len(embedding_res["data"][0]["embedding"]) == model_spec.dimensions
+ assert "model_replica" in embedding_res
+ assert embedding_res["model_replica"] is not None
+ assert embedding_res["model"] == payload["model"]

# test multiple
payload = {
3 changes: 2 additions & 1 deletion xinference/model/embedding/core.py
@@ -693,7 +693,8 @@ def base64_to_image(base64_str: str) -> Image.Image:
if not is_bge_m3_flag_model and not kwargs.get("return_sparse")
else "dict"
),
- model=self._model_uid,
+ model=kwargs.get("model_uid"),  # type: ignore
+ model_replica=self._model_uid,
data=embedding_list,
usage=usage,
)
1 change: 1 addition & 0 deletions xinference/types.py
@@ -78,6 +78,7 @@ class EmbeddingData(TypedDict):
class Embedding(TypedDict):
object: Literal["list"]
model: str
+ model_replica: str
data: List[EmbeddingData]
usage: EmbeddingUsage
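
Taken together, the hunks in this PR mean the response's `model` field now carries the UID the caller requested (passed down through `kwargs["model_uid"]`), while the new `model_replica` field carries the UID the model instance itself holds. A minimal standalone sketch of that shape follows. It does not import xinference: `build_embedding` is a hypothetical helper, the fields of `EmbeddingData` and `EmbeddingUsage` are assumed for illustration, and the replica UID string is made up rather than taken from Xinference's real naming scheme.

```python
from typing import List, Literal, TypedDict


class EmbeddingData(TypedDict):
    # Assumed fields, for illustration only.
    index: int
    object: str
    embedding: List[float]


class EmbeddingUsage(TypedDict):
    # Assumed fields, for illustration only.
    prompt_tokens: int
    total_tokens: int


class Embedding(TypedDict):
    # Mirrors the TypedDict in this PR, including the new field.
    object: Literal["list"]
    model: str
    model_replica: str
    data: List[EmbeddingData]
    usage: EmbeddingUsage


def build_embedding(
    replica_uid: str, vectors: List[List[float]], **kwargs
) -> Embedding:
    """Hypothetical helper: assemble a response the way the core.py hunk does."""
    data = [
        EmbeddingData(index=i, object="embedding", embedding=vec)
        for i, vec in enumerate(vectors)
    ]
    return Embedding(
        object="list",
        # UID requested by the caller, forwarded by the REST layer.
        model=kwargs.get("model_uid"),  # type: ignore
        # UID held by the model instance that served the request.
        model_replica=replica_uid,
        data=data,
        # Token counts are placeholders in this sketch.
        usage=EmbeddingUsage(prompt_tokens=0, total_tokens=0),
    )


# Made-up UIDs: the caller asks for "my-embedding", a replica answers.
result = build_embedding("my-embedding-0", [[0.1, 0.2]], model_uid="my-embedding")
```

Note that if `model_uid` is not supplied in `kwargs`, `model` comes back as `None` in this sketch, which is presumably why the diff carries a `# type: ignore` on that line.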
