diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index 0d17141f83ea..1536ed029aad 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -217,6 +217,7 @@ struct llama_client_slot
 
     bool infill = false;
     bool embedding = false;
+    bool reranker = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -1409,7 +1410,54 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
+    void send_rerank(llama_client_slot &slot, const llama_batch & batch)
+    {
+        task_result res;
+        res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
+        res.error = false;
+        res.stop = true;
+
+        float score = -1e6f; // Default score if we fail to get embeddings
+
+        if (!params.rerank)
+        {
+            LOG_WARNING("reranking disabled", {
+                {"params.rerank", params.rerank},
+            });
+        }
+        else
+        {
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+                    continue;
+                }
+
+                score = embd[0];
+            }
+        }
+
+        // Format result as JSON similar to the embedding function
+        res.result_json = json
+        {
+            {"score", score},
+            {"tokens", slot.n_prompt_tokens}
+        };
+
+        queue_results.send(res);
+    }
+
+    void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id)
     {
         task_server task;
         task.id = task_id;
@@ -1417,6 +1465,7 @@ struct llama_server_context
         task.data = std::move(data);
         task.infill_mode = infill;
         task.embedding_mode = embedding;
+        task.reranking_mode = rerank;
         task.type = TASK_TYPE_COMPLETION;
         task.multitask_id = multitask_id;
 
@@ -1548,7 +1597,7 @@ struct llama_server_context
             subtask_data["prompt"] = subtask_data["prompt"][i];
 
             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multiprompt_task.reranking_mode, multitask_id);
         }
     }
 
@@ -1587,6 +1636,7 @@ struct llama_server_context
 
                 slot->infill = task.infill_mode;
                 slot->embedding = task.embedding_mode;
+                slot->reranker = task.reranking_mode;
                 slot->task_id = task.id;
                 slot->multitask_id = task.multitask_id;
 
@@ -2030,6 +2080,14 @@ struct llama_server_context
                 continue;
             }
 
+            if (slot.reranker)
+            {
+                send_rerank(slot, batch_view);
+                slot.release();
+                slot.i_batch = -1;
+                continue;
+            }
+
             completion_token_output result;
             const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
 
diff --git a/backend/cpp/llama/utils.hpp b/backend/cpp/llama/utils.hpp
index 198b6f265957..d79b63daa170 100644
--- a/backend/cpp/llama/utils.hpp
+++ b/backend/cpp/llama/utils.hpp
@@ -61,6 +61,7 @@ struct task_server {
     json data;
     bool infill_mode = false;
    bool embedding_mode = false;
+   bool reranking_mode = false;
    int multitask_id = -1;
 };
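Not part of the patch itself, but for reviewers: a minimal standalone sketch of the scoring path that the new `send_rerank()` implements. All names here (`mock_batch`, `pooled_embedding`, `rerank_score`) are hypothetical stand-ins rather than llama.cpp symbols; the real code walks the `llama_batch`, calls `llama_get_embeddings_seq()` for the slot's sequence (falling back to `llama_get_embeddings_ith()`), and treats the first component of the returned embedding as the relevance score, keeping `-1e6` as the default when no embedding is available.

```cpp
// Standalone sketch (mock types, NOT the llama.cpp API) of the send_rerank()
// scoring logic: for each batch position with output enabled that belongs to
// the slot's sequence, prefer the pooled per-sequence embedding and fall back
// to the per-token embedding; the rerank score is its first component.
#include <cstdio>
#include <optional>
#include <vector>

struct mock_batch {
    std::vector<bool>               logits;      // was output requested for position i?
    std::vector<int>                seq_id;      // sequence id owning position i
    std::vector<std::vector<float>> token_embd;  // per-token embeddings (fallback path)
};

// Stand-in for llama_get_embeddings_seq(): pooled embedding of a sequence, if any.
static std::optional<std::vector<float>> pooled_embedding(int /*seq*/) {
    return std::vector<float>{0.87f, 0.10f, 0.20f}; // pretend the model pooled to this
}

static float rerank_score(const mock_batch & batch, int slot_seq_id) {
    float score = -1e6f; // default when no embedding can be retrieved
    for (std::size_t i = 0; i < batch.logits.size(); ++i) {
        if (!batch.logits[i] || batch.seq_id[i] != slot_seq_id) {
            continue; // skip positions without output or owned by other slots
        }
        if (auto pooled = pooled_embedding(batch.seq_id[i])) {
            score = (*pooled)[0];           // preferred: sequence-level (pooled) embedding
        } else if (!batch.token_embd[i].empty()) {
            score = batch.token_embd[i][0]; // fallback: embedding at this position
        }
    }
    return score;
}

int main() {
    mock_batch batch{{false, true}, {0, 0}, {{0.0f}, {0.42f}}};
    std::printf("rerank score = %.2f\n", rerank_score(batch, /*slot_seq_id=*/0));
    return 0;
}
```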