From 5cef7c3d4bb0c5208d262fc3ffb7d7083724de1c Mon Sep 17 00:00:00 2001 From: wxiwnd <40122078+wxiwnd@users.noreply.github.com> Date: Fri, 21 Jun 2024 15:34:17 +0800 Subject: [PATCH] FEAT: Add Rerank model token input/output usage (#1657) --- xinference/api/restful_api.py | 2 ++ xinference/client/restful/restful_client.py | 4 +++ xinference/core/model.py | 2 ++ xinference/model/rerank/core.py | 36 +++++++++++++++++--- xinference/model/rerank/tests/test_rerank.py | 13 +++++++ xinference/types.py | 28 +++++++++++++++ 6 files changed, 81 insertions(+), 4 deletions(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index bc4f338613..03a77f341f 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -109,6 +109,7 @@ class RerankRequest(BaseModel): documents: List[str] top_n: Optional[int] = None return_documents: Optional[bool] = False + return_len: Optional[bool] = False max_chunks_per_doc: Optional[int] = None @@ -1116,6 +1117,7 @@ async def rerank(self, request: Request) -> Response: top_n=body.top_n, max_chunks_per_doc=body.max_chunks_per_doc, return_documents=body.return_documents, + return_len=body.return_len, **kwargs, ) return Response(scores, media_type="application/json") diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index 2230bab848..10d9ae8231 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -135,6 +135,7 @@ def rerank( top_n: Optional[int] = None, max_chunks_per_doc: Optional[int] = None, return_documents: Optional[bool] = None, + return_len: Optional[bool] = None, **kwargs, ): """ @@ -152,6 +153,8 @@ def rerank( The maximum number of chunks derived from a document return_documents: bool if return documents + return_len: bool + if return tokens len Returns ------- Scores @@ -170,6 +173,7 @@ def rerank( "top_n": top_n, "max_chunks_per_doc": max_chunks_per_doc, "return_documents": 
return_documents, + "return_len": return_len, } request_body.update(kwargs) response = requests.post(url, json=request_body, headers=self.auth_headers) diff --git a/xinference/core/model.py b/xinference/core/model.py index fae1b1811a..94c9f16cec 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -543,6 +543,7 @@ async def rerank( top_n: Optional[int], max_chunks_per_doc: Optional[int], return_documents: Optional[bool], + return_len: Optional[bool], *args, **kwargs, ): @@ -554,6 +555,7 @@ async def rerank( top_n, max_chunks_per_doc, return_documents, + return_len, *args, **kwargs, ) diff --git a/xinference/model/rerank/core.py b/xinference/model/rerank/core.py index f1ec7e2b29..2f7924ebcd 100644 --- a/xinference/model/rerank/core.py +++ b/xinference/model/rerank/core.py @@ -23,7 +23,7 @@ from ...constants import XINFERENCE_CACHE_DIR from ...device_utils import empty_cache -from ...types import Document, DocumentObj, Rerank +from ...types import Document, DocumentObj, Rerank, RerankTokens from ..core import CacheableModelSpec, ModelDescription from ..utils import is_model_cached @@ -121,11 +121,17 @@ def __init__( if model_spec.type == "unknown": model_spec.type = self._auto_detect_type(model_path) + @staticmethod + def _get_tokenizer(model_path): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + return tokenizer + @staticmethod def _auto_detect_type(model_path): """This method may not be stable due to the fact that the tokenizer name may be changed. 
Therefore, we only use this method for unknown model types.""" - from transformers import AutoTokenizer type_mapper = { "LlamaTokenizerFast": "LLM-based layerwise", @@ -133,7 +139,7 @@ def _auto_detect_type(model_path): "XLMRobertaTokenizerFast": "normal", } - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + tokenizer = RerankModel._get_tokenizer(model_path) rerank_type = type_mapper.get(type(tokenizer).__name__) if rerank_type is None: logger.warning( @@ -186,6 +192,7 @@ def rerank( top_n: Optional[int], max_chunks_per_doc: Optional[int], return_documents: Optional[bool], + return_len: Optional[bool], **kwargs, ) -> Rerank: self._counter += 1 @@ -224,7 +231,28 @@ def rerank( ) for arg in sim_scores_argsort ] - return Rerank(id=str(uuid.uuid1()), results=docs) + if return_len: + tokenizer = self._get_tokenizer(self._model_path) + input_len = sum([len(tokenizer.tokenize(t)) for t in documents]) + + # Rerank Model output is just score or documents + # while return_documents = True + output_len = input_len + + # api_version, billed_units, warnings + # is for Cohere API compatibility, set to None + metadata = { + "api_version": None, + "billed_units": None, + "tokens": ( + RerankTokens(input_tokens=input_len, output_tokens=output_len) + if return_len + else None + ), + "warnings": None, + } + + return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata) def get_cache_dir(model_spec: RerankModelSpec): diff --git a/xinference/model/rerank/tests/test_rerank.py b/xinference/model/rerank/tests/test_rerank.py index aa6ac848bd..7314e5b964 100644 --- a/xinference/model/rerank/tests/test_rerank.py +++ b/xinference/model/rerank/tests/test_rerank.py @@ -54,6 +54,19 @@ def test_restful_api(model_name, setup): assert scores["results"][0]["index"] == 0 assert scores["results"][0]["document"] == corpus[0] + scores = model.rerank(corpus, query, return_len=True) + assert ( + scores["meta"]["tokens"]["input_tokens"] + == 
scores["meta"]["tokens"]["output_tokens"] ) + + print(scores) + + scores = model.rerank(corpus, query) + assert scores["meta"]["tokens"] is None + + print(scores) + kwargs = { "invalid": "invalid", } diff --git a/xinference/types.py b/xinference/types.py index fda35750fa..3ea098c60e 100644 --- a/xinference/types.py +++ b/xinference/types.py @@ -80,9 +80,37 @@ class DocumentObj(TypedDict): document: Optional[Document] +# Cohere API compatibility +class ApiVersion(TypedDict): + version: str + is_deprecated: bool + is_experimental: bool + + +# Cohere API compatibility +class BilledUnit(TypedDict): + input_tokens: int + output_tokens: int + search_units: int + classifications: int + + +class RerankTokens(TypedDict): + input_tokens: int + output_tokens: int + + +class Meta(TypedDict): + api_version: Optional[ApiVersion] + billed_units: Optional[BilledUnit] + tokens: Optional[RerankTokens] + warnings: Optional[List[str]] + + class Rerank(TypedDict): id: str results: List[DocumentObj] + meta: Meta class CompletionLogprobs(TypedDict):