From 2e04ccf4e66a56eade51c2b62d7fe9026021fbb9 Mon Sep 17 00:00:00 2001 From: nvrxq Date: Wed, 18 Dec 2024 01:21:44 +0300 Subject: [PATCH] llama_server_response_fields --- examples/server/server.cpp | 6 +++++- examples/server/utils.hpp | 27 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 436170a034fde..bc179cfb5effd 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -91,6 +91,7 @@ struct slot_params { int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit std::vector antiprompt; + std::vector requested_fields; bool timings_per_token = false; bool ignore_eos = false; @@ -205,6 +206,7 @@ struct server_task { params.n_discard = json_value(data, "n_discard", defaults.n_discard); //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.requested_fields = json_value(data, "requested_fields", std::vector()); params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); @@ -482,6 +484,7 @@ struct server_task_result_cmpl_final : server_task_result { stop_type stop = STOP_TYPE_NONE; std::vector probs_output; + std::vector requested_fields; slot_params generation_params; @@ -527,7 +530,7 @@ struct server_task_result_cmpl_final : server_task_result { if (!probs_output.empty()) { res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output); } - return res; + return requested_fields.empty() ? 
res : json_get_nested_values(requested_fields, res);
     }
 
     json to_json_oaicompat_chat() {
@@ -1960,6 +1963,7 @@ struct server_context {
         res->content = slot.generated_text;
         res->timings = slot.get_timings();
         res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
+        res->requested_fields = slot.params.requested_fields;
 
         res->truncated = slot.truncated;
         res->n_decoded = slot.n_decoded;
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 8fffe484aec12..0ac8b2cce8478 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -88,6 +88,33 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
     return false;
 }
 
+// Pick values out of `js` by '/'-separated key paths (e.g. "timings/predicted_ms").
+//
+// paths: paths to look up; each segment must name a key of a nested JSON object
+// js:    document to read from (never modified)
+//
+// Returns a flat object mapping every path that resolved to the value found there.
+// Paths that hit a missing key or a non-object are skipped; "" selects all of `js`.
+static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
+    json result = json::object();
+
+    for (const std::string & path : paths) {
+        // descend by pointer: only the selected value is copied, never the whole document
+        const json * node = &js;
+        std::istringstream stream(path);
+        std::string key;
+        while (node != nullptr && std::getline(stream, key, '/')) {
+            // find() yields end() for a missing key and for non-object values alike
+            const auto it = node->find(key);
+            node = it != node->end() ? &*it : nullptr;
+        }
+        if (node != nullptr) {
+            result[path] = *node;
+        }
+    }
+    return result;
+}
+
 /**
  * this handles 2 cases:
  * - only string, example: "string"