Skip to content

Commit 31d0ff1

Browse files
authored
server / ranking : add sorting and management of top_n (ggml-org#16403)
* server / ranking : add sorting and management of top_n * Keep it backward compatible: if no top_n is given, all results are returned. Here is a script to run some tests: ```script URL=${1:-http://127.0.0.1:8181} curl "$URL/v1/rerank" -H "Content-Type: application/json" \ -d '{ "model": "M", "query": "What is the recipe to make bread ?", "return_text" : true, "texts" : true, "top_n": 6, "documents": [ "voici la recette pour faire du pain, il faut de la farine de l eau et du levain et du sel", "it is a bear", "bread recipe : floor, water, yest, salt", "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.", "here is the ingedients to bake bread : 500g floor, 350g water, 120g fresh refresh yest, 15g salt", "recipe to make cookies : floor, eggs, water, chocolat", "here is the recipe to make bread : 500g floor, 350g water, 120g fresh refresh yest, 15g salt", "il fait tres beau aujourd hui", "je n ai pas faim, je ne veux pas manger", "je suis a paris" ] }' | jq ``` * use resize() instead of for(...) 
* simplify top_n init, since there is no need to return an error result. To test: ./tests.sh unit/test_rerank.py -v -x ==================================================== test session starts ===================================================== platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.6.0 -- /home/yann/dev/yann/llama.cpp/tools/server/tests/test/bin/python3 cachedir: .pytest_cache rootdir: /home/yann/dev/yann/llama.cpp/tools/server/tests configfile: pytest.ini plugins: anyio-4.11.0 collected 8 items unit/test_rerank.py::test_rerank PASSED [ 12%] unit/test_rerank.py::test_rerank_tei_format PASSED [ 25%] unit/test_rerank.py::test_invalid_rerank_req[documents0] PASSED [ 37%] unit/test_rerank.py::test_invalid_rerank_req[None] PASSED [ 50%] unit/test_rerank.py::test_invalid_rerank_req[123] PASSED [ 62%] unit/test_rerank.py::test_invalid_rerank_req[documents3] PASSED [ 75%] unit/test_rerank.py::test_rerank_usage[Machine learning is-A machine-Learning is-19] PASSED [ 87%] unit/test_rerank.py::test_rerank_usage[Which city?-Machine learning is -Paris, capitale de la-26] PASSED [100%] ===================================================== 8 passed in 4.31s ====================================================== * add rerank top_n unit test. Here is the result: ./tests.sh unit/test_rerank.py -v -x =================================================================== test session starts =================================================================== platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.6.0 -- /home/yann/dev/yann/llama.cpp/tools/server/tests/test/bin/python3 cachedir: .pytest_cache rootdir: /home/yann/dev/yann/llama.cpp/tools/server/tests configfile: pytest.ini plugins: anyio-4.11.0 collected 16 items unit/test_rerank.py::test_rerank PASSED [ 6%] unit/test_rerank.py::test_rerank_tei_format PASSED [ 12%] unit/test_rerank.py::test_invalid_rerank_req[documents0] PASSED [ 18%] unit/test_rerank.py::test_invalid_rerank_req[None] PASSED [ 25%] 
unit/test_rerank.py::test_invalid_rerank_req[123] PASSED [ 31%] unit/test_rerank.py::test_invalid_rerank_req[documents3] PASSED [ 37%] unit/test_rerank.py::test_rerank_usage[Machine learning is-A machine-Learning is-19] PASSED [ 43%] unit/test_rerank.py::test_rerank_usage[Which city?-Machine learning is -Paris, capitale de la-26] PASSED [ 50%] unit/test_rerank.py::test_rerank_top_n[None-4] PASSED [ 56%] unit/test_rerank.py::test_rerank_top_n[2-2] PASSED [ 62%] unit/test_rerank.py::test_rerank_top_n[4-4] PASSED [ 68%] unit/test_rerank.py::test_rerank_top_n[99-4] PASSED [ 75%] unit/test_rerank.py::test_rerank_tei_top_n[None-4] PASSED [ 81%] unit/test_rerank.py::test_rerank_tei_top_n[2-2] PASSED [ 87%] unit/test_rerank.py::test_rerank_tei_top_n[4-4] PASSED [ 93%] unit/test_rerank.py::test_rerank_tei_top_n[99-4] PASSED [100%] =================================================================== 16 passed in 8.84s =================================================================== * editor config check fix
1 parent 97870e6 commit 31d0ff1

File tree

3 files changed

+81
-48
lines changed

3 files changed

+81
-48
lines changed

tools/server/server.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5401,15 +5401,6 @@ int main(int argc, char ** argv) {
54015401

54025402
const json body = json::parse(req.body);
54035403

5404-
// TODO: implement
5405-
//int top_n = 1;
5406-
//if (body.count("top_n") != 1) {
5407-
// top_n = body.at("top_n");
5408-
//} else {
5409-
// res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
5410-
// return;
5411-
//}
5412-
54135404
// if true, use TEI API format, otherwise use Jina API format
54145405
// Jina: https://jina.ai/reranker/
54155406
// TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
@@ -5434,6 +5425,8 @@ int main(int argc, char ** argv) {
54345425
return;
54355426
}
54365427

5428+
int top_n = json_value(body, "top_n", (int)documents.size());
5429+
54375430
// create and queue the task
54385431
json responses = json::array();
54395432
bool error = false;
@@ -5474,7 +5467,8 @@ int main(int argc, char ** argv) {
54745467
body,
54755468
responses,
54765469
is_tei_format,
5477-
documents);
5470+
documents,
5471+
top_n);
54785472

54795473
res_ok(res, root);
54805474
};

tools/server/tests/unit/test_rerank.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,45 @@ def test_rerank_usage(query, doc1, doc2, n_tokens):
102102
assert res.status_code == 200
103103
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
104104
assert res.body['usage']['prompt_tokens'] == n_tokens
105+
106+
107+
@pytest.mark.parametrize("top_n,expected_len", [
108+
(None, len(TEST_DOCUMENTS)), # no top_n parameter
109+
(2, 2),
110+
(4, 4),
111+
(99, len(TEST_DOCUMENTS)), # higher than available docs
112+
])
113+
def test_rerank_top_n(top_n, expected_len):
114+
global server
115+
server.start()
116+
data = {
117+
"query": "Machine learning is",
118+
"documents": TEST_DOCUMENTS,
119+
}
120+
if top_n is not None:
121+
data["top_n"] = top_n
122+
123+
res = server.make_request("POST", "/rerank", data=data)
124+
assert res.status_code == 200
125+
assert len(res.body["results"]) == expected_len
126+
127+
128+
@pytest.mark.parametrize("top_n,expected_len", [
129+
(None, len(TEST_DOCUMENTS)), # no top_n parameter
130+
(2, 2),
131+
(4, 4),
132+
(99, len(TEST_DOCUMENTS)), # higher than available docs
133+
])
134+
def test_rerank_tei_top_n(top_n, expected_len):
135+
global server
136+
server.start()
137+
data = {
138+
"query": "Machine learning is",
139+
"texts": TEST_DOCUMENTS,
140+
}
141+
if top_n is not None:
142+
data["top_n"] = top_n
143+
144+
res = server.make_request("POST", "/rerank", data=data)
145+
assert res.status_code == 200
146+
assert len(res.body) == expected_len

tools/server/utils.hpp

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -849,47 +849,44 @@ static json format_response_rerank(
849849
const json & request,
850850
const json & ranks,
851851
bool is_tei_format,
852-
std::vector<std::string> & texts) {
853-
json res;
854-
if (is_tei_format) {
855-
// TEI response format
856-
res = json::array();
857-
bool return_text = json_value(request, "return_text", false);
858-
for (const auto & rank : ranks) {
859-
int index = json_value(rank, "index", 0);
860-
json elem = json{
861-
{"index", index},
862-
{"score", json_value(rank, "score", 0.0)},
863-
};
864-
if (return_text) {
865-
elem["text"] = std::move(texts[index]);
866-
}
867-
res.push_back(elem);
868-
}
869-
} else {
870-
// Jina response format
871-
json results = json::array();
872-
int32_t n_tokens = 0;
873-
for (const auto & rank : ranks) {
874-
results.push_back(json{
875-
{"index", json_value(rank, "index", 0)},
876-
{"relevance_score", json_value(rank, "score", 0.0)},
877-
});
878-
879-
n_tokens += json_value(rank, "tokens_evaluated", 0);
880-
}
881-
882-
res = json{
883-
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
884-
{"object", "list"},
885-
{"usage", json{
886-
{"prompt_tokens", n_tokens},
887-
{"total_tokens", n_tokens}
888-
}},
889-
{"results", results}
852+
std::vector<std::string> & texts,
853+
int top_n) {
854+
int32_t n_tokens = 0;
855+
bool return_text = is_tei_format && json_value(request, "return_text", false);
856+
std::vector<json> elements; // Temporary vector to hold unsorted elements
857+
std::string score_label = is_tei_format ? "score" : "relevance_score";
858+
for (const auto & rank : ranks) {
859+
int index = json_value(rank, "index", 0);
860+
json elem = json{
861+
{"index", index},
862+
{score_label, json_value(rank, "score", 0.0)},
890863
};
864+
n_tokens += json_value(rank, "tokens_evaluated", 0);
865+
if (return_text) {
866+
elem["text"] = std::move(texts[index]);
867+
}
868+
elements.push_back(elem);
891869
}
892870

871+
std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) {
872+
return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0);
873+
});
874+
875+
elements.resize(std::min(top_n, (int)elements.size()));
876+
json results = elements;
877+
878+
if (is_tei_format) return results;
879+
880+
json res = json{
881+
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
882+
{"object", "list"},
883+
{"usage", json{
884+
{"prompt_tokens", n_tokens},
885+
{"total_tokens", n_tokens}
886+
}},
887+
{"results", results}
888+
};
889+
893890
return res;
894891
}
895892

0 commit comments

Comments
 (0)