Skip to content

Commit

Permalink
llama_server_response_fields
Browse files Browse the repository at this point in the history
  • Loading branch information
nvrxq committed Dec 17, 2024
1 parent 081b29b commit 2e04ccf
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
6 changes: 5 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ struct slot_params {
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

std::vector<std::string> antiprompt;
std::vector<std::string> requested_fields;
bool timings_per_token = false;
bool ignore_eos = false;

Expand Down Expand Up @@ -205,6 +206,7 @@ struct server_task {
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.requested_fields = json_value(data, "requested_fields", std::vector<std::string>());

params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
Expand Down Expand Up @@ -482,6 +484,7 @@ struct server_task_result_cmpl_final : server_task_result {
stop_type stop = STOP_TYPE_NONE;

std::vector<completion_token_output> probs_output;
std::vector<std::string> requested_fields;

slot_params generation_params;

Expand Down Expand Up @@ -527,7 +530,7 @@ struct server_task_result_cmpl_final : server_task_result {
if (!probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
}
return res;
return requested_fields.empty() ? res : json_get_nested_values(requested_fields, res);
}

json to_json_oaicompat_chat() {
Expand Down Expand Up @@ -1960,6 +1963,7 @@ struct server_context {
res->content = slot.generated_text;
res->timings = slot.get_timings();
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
res->requested_fields = slot.params.requested_fields;

res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
Expand Down
27 changes: 27 additions & 0 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,33 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
return false;
}

// Extract values from `js` by slash-delimited paths ("key1/key2" -> js["key1"]["key2"]).
// Returns a flat JSON object mapping each resolvable path string to the value found
// at that path; paths that cannot be fully resolved are silently omitted.
// An empty path ("") maps to the whole document, matching std::getline's behavior
// of producing no tokens for an empty input.
static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
    json result = json::object();

    for (const std::string & path : paths) {
        // Walk the document by pointer: the original copied the whole JSON tree
        // into a local at every step, which is O(document size) per path.
        const json * current = &js;
        bool valid_path = true;

        std::istringstream stream(path);
        std::string key;
        // Consume keys as they are split; stop early once the path is known invalid.
        while (valid_path && std::getline(stream, key, '/')) {
            if (current->is_object() && current->contains(key)) {
                current = &current->at(key);
            } else {
                valid_path = false;
            }
        }
        if (valid_path) {
            result[path] = *current;
        }
    }
    return result;
}

/**
* this handles 2 cases:
* - only string, example: "string"
Expand Down

0 comments on commit 2e04ccf

Please sign in to comment.