FEAT: Add Rerank model token input/output usage (#1657)
wxiwnd authored Jun 21, 2024
1 parent 21b5ab2 commit 5cef7c3
Showing 6 changed files with 81 additions and 4 deletions.
2 changes: 2 additions & 0 deletions xinference/api/restful_api.py
@@ -109,6 +109,7 @@ class RerankRequest(BaseModel):
     documents: List[str]
     top_n: Optional[int] = None
     return_documents: Optional[bool] = False
+    return_len: Optional[bool] = False
     max_chunks_per_doc: Optional[int] = None


@@ -1116,6 +1117,7 @@ async def rerank(self, request: Request) -> Response:
             top_n=body.top_n,
             max_chunks_per_doc=body.max_chunks_per_doc,
             return_documents=body.return_documents,
+            return_len=body.return_len,
             **kwargs,
         )
         return Response(scores, media_type="application/json")
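With the new `return_len` field on `RerankRequest`, token usage can be requested over plain HTTP. A minimal sketch of such a call, assuming a local Xinference server on the default port and a hypothetical model UID `my-rerank-model` (the `model` and `query` fields belong to the request model but are outside the excerpt above):

import requests

# Sketch only: host, port, and model UID are assumptions for illustration.
payload = {
    "model": "my-rerank-model",  # hypothetical model UID
    "query": "A man is eating pasta.",
    "documents": [
        "A man is eating food.",
        "A man is riding a horse.",
    ],
    "return_len": True,  # ask the server to include token usage in the response
}
resp = requests.post("http://localhost:9997/v1/rerank", json=payload)
resp.raise_for_status()
print(resp.json()["meta"]["tokens"])  # {"input_tokens": ..., "output_tokens": ...}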
4 changes: 4 additions & 0 deletions xinference/client/restful/restful_client.py
@@ -135,6 +135,7 @@ def rerank(
         top_n: Optional[int] = None,
         max_chunks_per_doc: Optional[int] = None,
         return_documents: Optional[bool] = None,
+        return_len: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -152,6 +153,8 @@
             The maximum number of chunks derived from a document
         return_documents: bool
             if return documents
+        return_len: bool
+            whether to return the token counts
         Returns
         -------
         Scores
@@ -170,6 +173,7 @@
             "top_n": top_n,
             "max_chunks_per_doc": max_chunks_per_doc,
             "return_documents": return_documents,
+            "return_len": return_len,
         }
         request_body.update(kwargs)
         response = requests.post(url, json=request_body, headers=self.auth_headers)
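On the client side, the flag is just another keyword argument that is forwarded into the request body. A usage sketch, assuming a running server and an already launched rerank model; the endpoint address and model UID are placeholders:

from xinference.client import RESTfulClient

client = RESTfulClient("http://localhost:9997")  # placeholder address
model = client.get_model("my-rerank-model")  # hypothetical model UID

result = model.rerank(
    ["A man is eating food.", "A man is riding a horse."],
    "A man is eating pasta.",
    return_len=True,
)
tokens = result["meta"]["tokens"]
print(tokens["input_tokens"], tokens["output_tokens"])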
2 changes: 2 additions & 0 deletions xinference/core/model.py
@@ -543,6 +543,7 @@ async def rerank(
         top_n: Optional[int],
         max_chunks_per_doc: Optional[int],
         return_documents: Optional[bool],
+        return_len: Optional[bool],
         *args,
         **kwargs,
     ):
@@ -554,6 +555,7 @@
             top_n,
             max_chunks_per_doc,
             return_documents,
+            return_len,
             *args,
             **kwargs,
         )
36 changes: 32 additions & 4 deletions xinference/model/rerank/core.py
@@ -23,7 +23,7 @@

 from ...constants import XINFERENCE_CACHE_DIR
 from ...device_utils import empty_cache
-from ...types import Document, DocumentObj, Rerank
+from ...types import Document, DocumentObj, Rerank, RerankTokens
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import is_model_cached

@@ -121,19 +121,25 @@ def __init__(
         if model_spec.type == "unknown":
             model_spec.type = self._auto_detect_type(model_path)
 
+    @staticmethod
+    def _get_tokenizer(model_path):
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        return tokenizer
+
     @staticmethod
     def _auto_detect_type(model_path):
         """This method may not be stable due to the fact that the tokenizer name may be changed.
         Therefore, we only use this method for unknown model types."""
-        from transformers import AutoTokenizer
 
         type_mapper = {
             "LlamaTokenizerFast": "LLM-based layerwise",
             "GemmaTokenizerFast": "LLM-based",
             "XLMRobertaTokenizerFast": "normal",
         }
 
-        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        tokenizer = RerankModel._get_tokenizer(model_path)
         rerank_type = type_mapper.get(type(tokenizer).__name__)
         if rerank_type is None:
             logger.warning(
@@ -186,6 +192,7 @@ def rerank(
         top_n: Optional[int],
         max_chunks_per_doc: Optional[int],
         return_documents: Optional[bool],
+        return_len: Optional[bool],
         **kwargs,
     ) -> Rerank:
         self._counter += 1
@@ -224,7 +231,28 @@
             )
             for arg in sim_scores_argsort
         ]
-        return Rerank(id=str(uuid.uuid1()), results=docs)
+        if return_len:
+            tokenizer = self._get_tokenizer(self._model_path)
+            input_len = sum([len(tokenizer.tokenize(t)) for t in documents])
+
+            # The rerank model outputs only scores (or the documents themselves
+            # when return_documents=True), so no new tokens are generated.
+            output_len = input_len
+
+        # api_version, billed_units, and warnings are kept for Cohere API
+        # compatibility and are set to None.
+        metadata = {
+            "api_version": None,
+            "billed_units": None,
+            "tokens": (
+                RerankTokens(input_tokens=input_len, output_tokens=output_len)
+                if return_len
+                else None
+            ),
+            "warnings": None,
+        }
+
+        return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)
 
 
 def get_cache_dir(model_spec: RerankModelSpec):
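The usage numbers come straight from the tokenizer: `input_len` is the token count summed over the raw documents (note the query itself is not included), and because reranking only scores its inputs, `output_tokens` is reported as the same value. A standalone sketch of that computation, using `BAAI/bge-reranker-base` as an example checkpoint (any Hugging Face tokenizer behaves the same way):

from transformers import AutoTokenizer

# Sketch of the counting done in RerankModel.rerank(); the checkpoint choice
# is an assumption for illustration.
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-base")
documents = ["A man is eating food.", "A man is riding a horse."]

# Token count summed over all documents, mirroring input_len above.
input_len = sum(len(tokenizer.tokenize(t)) for t in documents)
output_len = input_len  # reranking generates no new tokens
print(input_len, output_len)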
13 changes: 13 additions & 0 deletions xinference/model/rerank/tests/test_rerank.py
@@ -54,6 +54,19 @@ def test_restful_api(model_name, setup):
assert scores["results"][0]["index"] == 0
assert scores["results"][0]["document"] == corpus[0]

scores = model.rerank(corpus, query, return_len=True)
assert (
scores["meta"]["tokens"]["input_tokens"]
== scores["meta"]["tokens"]["output_tokens"]
)

print(scores)

scores = model.rerank(corpus, query)
assert scores["meta"]["tokens"] == None

print(scores)

kwargs = {
"invalid": "invalid",
}
28 changes: 28 additions & 0 deletions xinference/types.py
@@ -80,9 +80,37 @@ class DocumentObj(TypedDict):
     document: Optional[Document]
 
 
+# Cohere API compatibility
+class ApiVersion(TypedDict):
+    version: str
+    is_deprecated: bool
+    is_experimental: bool
+
+
+# Cohere API compatibility
+class BilledUnit(TypedDict):
+    input_tokens: int
+    output_tokens: int
+    search_units: int
+    classifications: int
+
+
+class RerankTokens(TypedDict):
+    input_tokens: int
+    output_tokens: int
+
+
+class Meta(TypedDict):
+    api_version: Optional[ApiVersion]
+    billed_units: Optional[BilledUnit]
+    tokens: Optional[RerankTokens]
+    warnings: Optional[List[str]]
+
+
 class Rerank(TypedDict):
     id: str
     results: List[DocumentObj]
+    meta: Meta
 
 
 class CompletionLogprobs(TypedDict):
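Putting the new TypedDicts together, a rerank response with `return_len=True` nests as follows. A sketch with made-up values that type-checks against the definitions above; it mirrors the structure `RerankModel.rerank()` builds:

from xinference.types import Meta, Rerank, RerankTokens

meta: Meta = {
    "api_version": None,   # Cohere-compatibility field, unused here
    "billed_units": None,  # Cohere-compatibility field, unused here
    "tokens": RerankTokens(input_tokens=12, output_tokens=12),
    "warnings": None,
}
response: Rerank = {
    "id": "00000000-0000-0000-0000-000000000000",  # placeholder for uuid.uuid1()
    "results": [],  # would hold DocumentObj entries in a real response
    "meta": meta,
}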
