From 6c5bc0625fae6909cb40def15bc4bb45db6f7f4d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 6 Dec 2024 11:14:32 +0100 Subject: [PATCH] server : (refactoring) do not rely on JSON internally (#10643) * server : (refactoring) reduce usage of json internally * move all response types to struct * wip [no ci] * many fixes * add virtual function * fix index * minor style fix * add std::move * refactor handle_completions_generic * add virtual functions * remove server.hpp * clarify server_sent_event RFC specs * apply review comments * fix model_alias and completion_probabilities * small clean up * remove virtual for to_json_oai_compat() * naming oai_compat --> oaicompat * fix unwanted recursive call * update docs --- common/common.h | 2 +- examples/server/README.md | 8 +- examples/server/server.cpp | 1337 +++++++++++------ examples/server/tests/README.md | 6 + examples/server/tests/tests.sh | 4 + .../server/tests/unit/test_chat_completion.py | 33 +- examples/server/tests/unit/test_completion.py | 39 + examples/server/utils.hpp | 249 +-- 8 files changed, 983 insertions(+), 695 deletions(-) diff --git a/common/common.h b/common/common.h index 0373fd3ead49e..95d20401d2a9a 100644 --- a/common/common.h +++ b/common/common.h @@ -215,7 +215,7 @@ struct common_params { struct common_params_speculative speculative; std::string model = ""; // model path // NOLINT - std::string model_alias = "unknown"; // model alias // NOLINT + std::string model_alias = ""; // model alias // NOLINT std::string model_url = ""; // model url to download // NOLINT std::string hf_token = ""; // HF token // NOLINT std::string hf_repo = ""; // HF repo // NOLINT diff --git a/examples/server/README.md b/examples/server/README.md index b2dd7b65a990c..8dbed2626a444 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -473,9 +473,11 @@ Notice that each `probs` is an array of length `n_probs`. - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.). - `model`: The path to the model loaded with `-m` - `prompt`: The provided `prompt` -- `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token -- `stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered -- `stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided +- `stop_type`: Indicating whether the completion has stopped. Possible values are: + - `none`: Generating (not stopped) + - `eos`: Stopped because it encountered the EOS token + - `limit`: Stopped because `n_predict` tokens were generated before stop words or EOS was encountered + - `word`: Stopped due to encountering a stopping word from `stop` JSON array provided - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word) - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second` - `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 31dfd62408047..809fafa187add 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -33,8 +33,10 @@ using json = nlohmann::ordered_json; enum stop_type { - STOP_TYPE_FULL, - STOP_TYPE_PARTIAL, + STOP_TYPE_NONE, + STOP_TYPE_EOS, + STOP_TYPE_WORD, + STOP_TYPE_LIMIT, }; // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283 @@ -69,12 +71,25 @@ enum server_task_inf_type { SERVER_TASK_INF_TYPE_INFILL, }; +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error +}; + struct server_task { int id = -1; // to be filled by server_queue int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL llama_tokens prompt_tokens; server_task_type type; + + // TODO @ngxson : we should get rid of json type here json data; server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION; @@ -89,15 +104,6 @@ struct server_task { } }; -struct server_task_result { - int id = -1; - - json data; - - bool stop; - bool error; -}; - struct slot_params { bool stream = true; bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt @@ -111,9 +117,603 @@ struct slot_params { int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit std::vector antiprompt; + bool timings_per_token = false; struct common_params_sampling sampling; struct common_params_speculative speculative; + + // params only used in to_json() + int32_t n_ctx; + uint32_t seed_cur; + bool can_speculative; + + // OAI-compat fields + bool verbose = false; + bool oaicompat = false; + bool oaicompat_chat = true; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + + json to_json() { + std::vector samplers; + samplers.reserve(sampling.samplers.size()); + for (const auto & sampler : sampling.samplers) { + samplers.emplace_back(common_sampler_type_to_str(sampler)); + } + + return json { + {"n_ctx", n_ctx}, + {"n_predict", n_predict}, // Server configured n_predict + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"penalize_nl", sampling.penalize_nl}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, // User configured n_predict + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + //{"logit_bias", sampling.logit_bias}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + {"samplers", samplers}, + {"speculative", can_speculative}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + }; + } +}; + +struct result_timings { + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + + json to_json() { + return { + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + } +}; + +struct server_task_result { + int id = -1; + int id_slot = -1; + virtual bool is_error() { + // only used by server_task_result_error + return false; + } + virtual bool is_stop() { + // only used by server_task_result_cmpl_partial + return false; + } + virtual int get_index() { + return -1; + } + virtual json to_json() = 0; + virtual ~server_task_result() = default; +}; + +// using shared_ptr for polymorphism of server_task_result +using server_task_result_ptr = std::unique_ptr; + +inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} + +struct completion_token_output { + llama_token tok; + std::string text_to_send; + struct token_prob { + llama_token tok; + std::string tok_str; + float prob; + }; + std::vector probs; + + json to_json() const { + json probs_for_token = json::array(); + for (const auto & p : probs) { + probs_for_token.push_back(json { + {"tok_str", p.tok_str}, + {"prob", p.prob}, + }); + } + return probs_for_token; + } + + static json probs_vector_to_json(const std::vector & probs) { + json out = json::array(); + for (const auto & prob : probs) { + const std::string tok_str = prob.text_to_send; + out.push_back(json { + {"content", tok_str}, + {"probs", prob.to_json()}, + }); + } + return out; + } +}; + +struct server_task_result_cmpl_final : server_task_result { + int index = 0; + std::string content; + bool stream; + result_timings timings; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + int32_t has_new_line; + std::string stopping_word; + stop_type stop = STOP_TYPE_NONE; + + std::vector probs_output; + + slot_params generation_params; + + // OAI-compat fields + bool verbose = false; + bool oaicompat = false; + bool oaicompat_chat = true; // TODO: support oaicompat for non-chat + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return oaicompat ? to_json_oaicompat_chat() : to_json_non_oaicompat(); + } + + json to_json_non_oaicompat() { + json res = json { + {"index", index}, + {"content", content}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output); + } + return res; + } + + json to_json_oaicompat_chat() { + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + + json choices = json::array({json{ + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", json{ + {"content", content}, + {"role", "assistant"} + } + }}}); + + std::time_t t = std::time(0); + + json res = json { + {"choices", choices}, + {"created", t}, + {"model", oaicompat_model}, + {"object", "chat.completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } +}; + +struct server_task_result_cmpl_partial : server_task_result { + int index = 0; + std::string content; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + + stop_type stop = STOP_TYPE_NONE; + + std::vector probs_output; + result_timings timings; + + // OAI-compat fields + bool verbose = false; + bool oaicompat = false; + bool oaicompat_chat = true; // TODO: support oaicompat for non-chat + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return stop != STOP_TYPE_NONE; + } + + virtual json to_json() override { + if (oaicompat) { + return to_json_oaicompat(); + } + bool is_stop = stop != STOP_TYPE_NONE; + // non-OAI-compat JSON + json res = json { + {"index", index}, + {"content", content}, + {"stop_type", stop_type_to_str(stop)}, + {"stop", is_stop}, + {"id_slot", id_slot}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + if (!probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output); + } + if (is_stop) { + res.push_back({"truncated", truncated}); + } + return res; + } + + json to_json_oaicompat() { + bool first = n_decoded == 0; + + std::string finish_reason; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } else if (stop == STOP_TYPE_LIMIT) { + finish_reason = "length"; + } + + std::time_t t = std::time(0); + + json choices; + + if (!finish_reason.empty()) { + choices = json::array({json{{"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}}}); + } else { + if (first) { + if (content.empty()) { + choices = json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{{"role", "assistant"}}}}}); + } else { + // We have to send this as two updates to conform to openai behavior + json initial_ret = json{{"choices", json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"} + }}}})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + json second_ret = json{ + {"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"content", content}}} + }})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + return std::vector({initial_ret, second_ret}); + } + } else { + // Some idiosyncrasy in task processing logic makes several trailing calls + // with empty content, we ignore these at the calee site. + if (content.empty()) { + return std::vector({json::object()}); + } + + choices = json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", + json{ + {"content", content}, + }}, + }}); + } + } + + json ret = json { + {"choices", choices}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"} + }; + + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + if (!finish_reason.empty()) { + ret.push_back({"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}); + } + + return std::vector({ret}); + } +}; + +struct server_task_result_embd : server_task_result { + int index = 0; + std::vector embedding; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return json { + {"index", index}, + {"embedding", embedding}, + }; + } +}; + +struct server_task_result_rerank : server_task_result { + int index = 0; + float score = -1e6; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return json { + {"index", index}, + {"score", score}, + }; + } +}; + +// this function maybe used outside of server_task_result_error +static json format_error_response(const std::string & message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + } + return json { + {"code", code}, + {"message", message}, + {"type", type_str}, + }; +} + +struct server_task_result_error : server_task_result { + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + virtual bool is_error() override { + return true; + } + + virtual json to_json() override { + return format_error_response(err_msg, err_type); + } +}; + +struct server_task_result_metrics : server_task_result { + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + int32_t kv_cache_tokens_count; + int32_t kv_cache_used_cells; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + // TODO: get rid of this json object and use to_json() instead + json slots_data = json::array(); + + virtual json to_json() override { + return json { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", n_tasks_deferred }, + { "t_start", t_start }, + + { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, + { "t_tokens_generation_total", t_tokens_generation_total }, + { "n_tokens_predicted_total", n_tokens_predicted_total }, + { "t_prompt_processing_total", t_prompt_processing_total }, + + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, + { "t_prompt_processing", t_prompt_processing }, + { "n_tokens_predicted", n_tokens_predicted }, + { "t_tokens_generation", t_tokens_generation }, + + { "n_decode_total", n_decode_total }, + { "n_busy_slots_total", n_busy_slots_total }, + + { "kv_cache_tokens_count", kv_cache_tokens_count }, + { "kv_cache_used_cells", kv_cache_used_cells }, + + { "slots", slots_data }, + }; + } +}; + +struct server_task_result_slot_save_load : server_task_result { + std::string filename; + bool is_save; // true = save, false = load + + size_t n_tokens; + size_t n_bytes; + double t_ms; + + virtual json to_json() override { + if (is_save) { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", n_tokens }, + { "n_written", n_bytes }, + { "timings", { + { "save_ms", t_ms } + }}, + }; + } else { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; + } + } +}; + +struct server_task_result_slot_erase : server_task_result { + size_t n_erased; + + virtual json to_json() override { + return json { + { "id_slot", id_slot }, + { "n_erased", n_erased }, + }; + } +}; + +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override { + return json {{ "success", true }}; + } }; struct server_slot { @@ -162,15 +762,8 @@ struct server_slot { bool has_next_token = true; bool has_new_line = false; bool truncated = false; - bool stopped_eos = false; - bool stopped_word = false; - bool stopped_limit = false; - - bool timings_per_token = false; + stop_type stop; - bool oaicompat = false; - - std::string oaicompat_model; std::string stopping_word; // sampling @@ -200,9 +793,7 @@ struct server_slot { generated_text = ""; has_new_line = false; truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; + stop = STOP_TYPE_NONE; stopping_word = ""; n_past = 0; n_sent_text = 0; @@ -255,38 +846,40 @@ struct server_slot { } } - json get_formated_timings() const { - return json { - {"prompt_n", n_prompt_tokens_processed}, - {"prompt_ms", t_prompt_processing}, - {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed}, - {"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed}, - - {"predicted_n", n_decoded}, - {"predicted_ms", t_token_generation}, - {"predicted_per_token_ms", t_token_generation / n_decoded}, - {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, - }; + result_timings get_timings() const { + result_timings timings; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + return timings; } - size_t find_stopping_strings(const std::string & text, const size_t last_token_size, const stop_type type) { + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { size_t stop_pos = std::string::npos; for (const std::string & word : params.antiprompt) { size_t pos; - if (type == STOP_TYPE_FULL) { + if (is_full_stop) { const size_t tmp = word.size() + last_token_size; const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; pos = text.find(word, from_pos); } else { + // otherwise, partial stop pos = find_partial_stop_string(word, text); } if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { - if (type == STOP_TYPE_FULL) { - stopped_word = true; + if (is_full_stop) { + stop = STOP_TYPE_WORD; stopping_word = word; has_next_token = false; } @@ -511,8 +1104,8 @@ struct server_response { // for keeping track of all tasks waiting for the result std::unordered_set waiting_task_ids; - // the main result queue - std::vector queue_results; + // the main result queue (using ptr for polymorphism) + std::vector queue_results; std::mutex mutex_results; std::condition_variable condition_results; @@ -552,7 +1145,7 @@ struct server_response { } // This function blocks the thread until there is a response for one of the id_tasks - server_task_result recv(const std::unordered_set & id_tasks) { + server_task_result_ptr recv(const std::unordered_set & id_tasks) { while (true) { std::unique_lock lock(mutex_results); condition_results.wait(lock, [&]{ @@ -560,8 +1153,8 @@ struct server_response { }); for (int i = 0; i < (int) queue_results.size(); i++) { - if (id_tasks.find(queue_results[i].id) != id_tasks.end()) { - server_task_result res = queue_results[i]; + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); queue_results.erase(queue_results.begin() + i); return res; } @@ -572,21 +1165,21 @@ struct server_response { } // single-task version of recv() - server_task_result recv(int id_task) { + server_task_result_ptr recv(int id_task) { std::unordered_set id_tasks = {id_task}; return recv(id_tasks); } // Send a new result to a waiting id_task - void send(server_task_result & result) { - SRV_DBG("sending result for task id = %d\n", result.id); + void send(server_task_result_ptr && result) { + SRV_DBG("sending result for task id = %d\n", result->id); std::unique_lock lock(mutex_results); for (const auto & id_task : waiting_task_ids) { - if (result.id == id_task) { - SRV_DBG("task id = %d moved to result queue\n", result.id); + if (result->id == id_task) { + SRV_DBG("task id = %d pushed to result queue\n", result->id); - queue_results.push_back(std::move(result)); + queue_results.emplace_back(std::move(result)); condition_results.notify_all(); return; } @@ -777,7 +1370,7 @@ struct server_context { slots.push_back(slot); } - default_generation_settings_for_props = get_formated_generation(slots.front()); + default_generation_settings_for_props = slots[0].params.to_json(); default_generation_settings_for_props["seed"] = -1; // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens @@ -873,14 +1466,19 @@ struct server_context { const auto & data = task.data; if (data.count("__oaicompat") != 0) { - slot.oaicompat = true; - slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + slot.params.oaicompat = true; + slot.params.oaicompat_chat = json_value(data, "__oaicompat_chat", false); + slot.params.oaicompat_model = json_value(data, "model", model_name); + slot.params.oaicompat_cmpl_id = json_value(data, "completion_id", std::string()); } else { - slot.oaicompat = false; - slot.oaicompat_model = ""; + slot.params.oaicompat = false; } - slot.timings_per_token = json_value(data, "timings_per_token", false); + + // enabling this will output extra debug information in the HTTP responses from the server + slot.params.verbose = params_base.verbosity > 9; + slot.params.timings_per_token = json_value(data, "timings_per_token", false); slot.params.stream = json_value(data, "stream", false); slot.params.cache_prompt = json_value(data, "cache_prompt", true); @@ -1110,14 +1708,14 @@ struct server_context { const std::string str_test = slot.generated_text.substr(pos); bool send_text = true; - size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL); + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); if (stop_pos != std::string::npos) { slot.generated_text.erase( slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); pos = std::min(slot.n_sent_text, slot.generated_text.size()); } else if (slot.has_next_token) { - stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL); + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); send_text = stop_pos == std::string::npos; } @@ -1141,7 +1739,7 @@ struct server_context { // check the limits if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { - slot.stopped_limit = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); @@ -1150,7 +1748,7 @@ struct server_context { if (slot.has_new_line) { // if we have already seen a new line, we stop after a certain time limit if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { - slot.stopped_limit = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); @@ -1170,7 +1768,7 @@ struct server_context { } if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { - slot.stopped_limit = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // cut the last line @@ -1199,7 +1797,7 @@ struct server_context { // if context shift is disabled, we stop when it reaches the context limit if (slot.n_past >= slot.n_ctx) { slot.truncated = true; - slot.stopped_limit = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", @@ -1207,7 +1805,7 @@ struct server_context { } if (llama_token_is_eog(model, result.tok)) { - slot.stopped_eos = true; + slot.stop = STOP_TYPE_EOS; slot.has_next_token = false; SLT_DBG(slot, "%s", "stopped by EOS\n"); @@ -1217,7 +1815,7 @@ struct server_context { if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { slot.truncated = true; - slot.stopped_limit = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // stop prediction SLT_WRN(slot, @@ -1231,60 +1829,6 @@ struct server_context { return slot.has_next_token; // continue } - json get_formated_generation(const server_slot & slot) const { - std::vector samplers; - samplers.reserve(slot.params.sampling.samplers.size()); - for (const auto & sampler : slot.params.sampling.samplers) { - samplers.emplace_back(common_sampler_type_to_str(sampler)); - } - - return json { - {"n_ctx", slot.n_ctx}, - {"n_predict", slot.n_predict}, // Server configured n_predict - {"model", params_base.model_alias}, - {"seed", slot.params.sampling.seed}, - {"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0}, - {"temperature", slot.params.sampling.temp}, - {"dynatemp_range", slot.params.sampling.dynatemp_range}, - {"dynatemp_exponent", slot.params.sampling.dynatemp_exponent}, - {"top_k", slot.params.sampling.top_k}, - {"top_p", slot.params.sampling.top_p}, - {"min_p", slot.params.sampling.min_p}, - {"xtc_probability", slot.params.sampling.xtc_probability}, - {"xtc_threshold", slot.params.sampling.xtc_threshold}, - {"typical_p", slot.params.sampling.typ_p}, - {"repeat_last_n", slot.params.sampling.penalty_last_n}, - {"repeat_penalty", slot.params.sampling.penalty_repeat}, - {"presence_penalty", slot.params.sampling.penalty_present}, - {"frequency_penalty", slot.params.sampling.penalty_freq}, - {"dry_multiplier", slot.params.sampling.dry_multiplier}, - {"dry_base", slot.params.sampling.dry_base}, - {"dry_allowed_length", slot.params.sampling.dry_allowed_length}, - {"dry_penalty_last_n", slot.params.sampling.dry_penalty_last_n}, - {"dry_sequence_breakers", slot.params.sampling.dry_sequence_breakers}, - {"mirostat", slot.params.sampling.mirostat}, - {"mirostat_tau", slot.params.sampling.mirostat_tau}, - {"mirostat_eta", slot.params.sampling.mirostat_eta}, - {"penalize_nl", slot.params.sampling.penalize_nl}, - {"stop", slot.params.antiprompt}, - {"max_tokens", slot.params.n_predict}, // User configured n_predict - {"n_keep", slot.params.n_keep}, - {"n_discard", slot.params.n_discard}, - {"ignore_eos", slot.params.sampling.ignore_eos}, - {"stream", slot.params.stream}, - //{"logit_bias", slot.params.sampling.logit_bias}, - {"n_probs", slot.params.sampling.n_probs}, - {"min_keep", slot.params.sampling.min_keep}, - {"grammar", slot.params.sampling.grammar}, - {"samplers", samplers}, - {"speculative", slot.can_speculate()}, - {"speculative.n_max", slot.params.speculative.n_max}, - {"speculative.n_min", slot.params.speculative.n_min}, - {"speculative.p_min", slot.params.speculative.p_min}, - {"timings_per_token", slot.timings_per_token}, - }; - } - void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { send_error(task.id, error, type); } @@ -1296,28 +1840,33 @@ struct server_context { void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); - server_task_result res; - res.id = id_task; - res.stop = false; - res.error = true; - res.data = format_error_response(error, type); + auto res = std::make_unique(); + res->id = id_task; + res->err_type = type; + res->err_msg = error; - queue_results.send(res); + queue_results.send(std::move(res)); } void send_partial_response(server_slot & slot, completion_token_output tkn) { - server_task_result res; - res.id = slot.id_task; - res.error = false; - res.stop = false; - res.data = json { - {"content", tkn.text_to_send}, - {"stop", false}, - {"id_slot", slot.id}, - {"multimodal", false}, - {"index", slot.index}, - }; + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->content = tkn.text_to_send; + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + + res->stop = slot.stop; + + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_chat = slot.params.oaicompat_chat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + // populate res.probs_output if (slot.params.sampling.n_probs > 0) { const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false); const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); @@ -1325,83 +1874,74 @@ struct server_context { std::vector probs_output; if (probs_pos < probs_stop_pos) { - probs_output = std::vector( + res->probs_output = std::vector( slot.generated_token_probs.begin() + probs_pos, slot.generated_token_probs.begin() + probs_stop_pos); } - slot.n_sent_token_probs = probs_stop_pos; - - res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output); } - if (slot.oaicompat) { - res.data["oaicompat_token_ctr"] = slot.n_decoded; - res.data["model"] = slot.oaicompat_model; + // populate timings if this is final response or timings_per_token is enabled + if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + res->timings = slot.get_timings(); } - if (slot.timings_per_token) { - res.data["timings"] = slot.get_formated_timings(); + queue_results.send(std::move(res)); + } + + void send_final_response(server_slot & slot) { + if (slot.params.stream) { + // if in stream mode, send the last partial response + return send_partial_response(slot, {0, "", {}}); } - queue_results.send(res); - } + auto res = std::make_unique(); + res->id = slot.id_task; + res->id_slot = slot.id; - void send_final_response(const server_slot & slot) { - server_task_result res; - res.id = slot.id_task; - res.error = false; - res.stop = true; - res.data = json { - {"content", !slot.params.stream ? slot.generated_text : ""}, - {"id_slot", slot.id}, - {"stop", true}, - {"model", params_base.model_alias}, - {"tokens_predicted", slot.n_decoded}, - {"tokens_evaluated", slot.n_prompt_tokens}, - {"generation_settings", get_formated_generation(slot)}, - {"prompt", common_detokenize(ctx, slot.prompt_tokens)}, - {"has_new_line", slot.has_new_line}, - {"truncated", slot.truncated}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()}, - {"index", slot.index}, - }; + res->index = slot.index; + res->content = slot.generated_text; + res->timings = slot.get_timings(); + res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_tokens_cached = slot.n_past; + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_chat = slot.params.oaicompat_chat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + + // populate res.probs_output if (slot.params.sampling.n_probs > 0) { - std::vector probs; - if (!slot.params.stream && slot.stopped_word) { + if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); - probs = std::vector( + res->probs_output = std::vector( slot.generated_token_probs.begin(), slot.generated_token_probs.end() - safe_offset); } else { - probs = std::vector( + res->probs_output = std::vector( slot.generated_token_probs.begin(), slot.generated_token_probs.end()); } - - res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs); } - if (slot.oaicompat) { - res.data["oaicompat_token_ctr"] = slot.n_decoded; - res.data["model"] = slot.oaicompat_model; - } + res->generation_params = slot.params; // copy the parameters - queue_results.send(res); + queue_results.send(std::move(res)); } void send_embedding(const server_slot & slot, const llama_batch & batch) { - server_task_result res; - res.id = slot.id_task; - res.error = false; - res.stop = true; + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; const int n_embd = llama_n_embd(model); @@ -1420,32 +1960,23 @@ struct server_context { if (embd == NULL) { SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - res.data = json { - {"embedding", std::vector(n_embd, 0.0f)}, - {"index", slot.index}, - }; - + res->embedding = std::vector(n_embd, 0.0f); continue; } common_embd_normalize(embd, embd_res.data(), n_embd); - - res.data = json { - {"embedding", embd_res}, - {"index", slot.index}, - }; + res->embedding = embd_res; } SLT_DBG(slot, "%s", "sending embeddings\n"); - queue_results.send(res); + queue_results.send(std::move(res)); } void send_rerank(const server_slot & slot, const llama_batch & batch) { - server_task_result res; - res.id = slot.id_task; - res.error = false; - res.stop = true; + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; for (int i = 0; i < batch.n_tokens; ++i) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { @@ -1460,23 +1991,16 @@ struct server_context { if (embd == NULL) { SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - res.data = json { - {"index", slot.index}, - {"score", -1e6}, - }; - + res->score = -1e6; continue; } - res.data = json { - {"index", slot.index}, - {"score", embd[0]}, - }; + res->score = embd[0]; } - SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str()); + SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); - queue_results.send(res); + queue_results.send(std::move(res)); } // @@ -1567,49 +2091,54 @@ struct server_context { } // receive the results from task(s) created by create_tasks_inference - void receive_cmpl_results( + void receive_multi_results( const std::unordered_set & id_tasks, - const std::function&)> & result_handler, + const std::function&)> & result_handler, const std::function & error_handler) { - // TODO: currently, there is no way to detect the client has cancelled the request - std::vector results(id_tasks.size()); + std::vector results(id_tasks.size()); for (size_t i = 0; i < id_tasks.size(); i++) { - server_task_result result = queue_results.recv(id_tasks); + server_task_result_ptr result = queue_results.recv(id_tasks); - if (result.error) { - error_handler(result.data); + if (result->is_error()) { + error_handler(result->to_json()); cancel_tasks(id_tasks); return; } - const size_t idx = result.data["index"]; + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + const size_t idx = result->get_index(); GGML_ASSERT(idx < results.size() && "index out of range"); - - results[idx] = result; + results[idx] = std::move(result); } result_handler(results); } // receive the results from task(s) created by create_tasks_inference, in stream mode void receive_cmpl_results_stream( - const std::unordered_set & id_tasks, const - std::function & result_handler, const - std::function & error_handler) { + const std::unordered_set & id_tasks, + const std::function & result_handler, + const std::function & error_handler) { size_t n_finished = 0; while (true) { - server_task_result result = queue_results.recv(id_tasks); - if (!result_handler(result)) { + server_task_result_ptr result = queue_results.recv(id_tasks); + + if (result->is_error()) { + error_handler(result->to_json()); cancel_tasks(id_tasks); - break; + return; } - if (result.error) { - error_handler(result.data); + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + if (!result_handler(result)) { cancel_tasks(id_tasks); break; } - if (result.stop) { + if (result->is_stop()) { if (++n_finished == id_tasks.size()) { break; } @@ -1676,7 +2205,7 @@ struct server_context { int n_processing_slots = 0; for (server_slot & slot : slots) { - json slot_data = get_formated_generation(slot); + json slot_data = slot.params.to_json(); slot_data["id"] = slot.id; slot_data["id_task"] = slot.id_task; slot_data["is_processing"] = slot.is_processing(); @@ -1686,9 +2215,6 @@ struct server_context { {"has_new_line", slot.has_new_line}, {"n_remain", slot.n_remaining}, {"n_decoded", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, {"stopping_word", slot.stopping_word}, }; @@ -1702,39 +2228,33 @@ struct server_context { } SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); - server_task_result res; - res.id = task.id; - res.stop = true; - res.error = false; - res.data = { - { "idle", n_idle_slots }, - { "processing", n_processing_slots }, - { "deferred", queue_tasks.queue_tasks_deferred.size() }, - { "t_start", metrics.t_start}, - - { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total}, - { "t_tokens_generation_total", metrics.t_tokens_generation_total}, - { "n_tokens_predicted_total", metrics.n_tokens_predicted_total}, - { "t_prompt_processing_total", metrics.t_prompt_processing_total}, + auto res = std::make_unique(); + res->id = task.id; + res->n_idle_slots = n_idle_slots; + res->n_processing_slots = n_processing_slots; + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); + res->t_start = metrics.t_start; - { "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed}, - { "t_prompt_processing", metrics.t_prompt_processing}, - { "n_tokens_predicted", metrics.n_tokens_predicted}, - { "t_tokens_generation", metrics.t_tokens_generation}, + res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx); + res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx); - { "n_decode_total", metrics.n_decode_total}, - { "n_busy_slots_total", metrics.n_busy_slots_total}, + res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res->t_prompt_processing_total = metrics.t_prompt_processing_total; + res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res->t_tokens_generation_total = metrics.t_tokens_generation_total; - { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)}, - { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)}, + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res->t_prompt_processing = metrics.t_prompt_processing; + res->n_tokens_predicted = metrics.n_tokens_predicted; + res->t_tokens_generation = metrics.t_tokens_generation; - { "slots", slots_data }, - }; + res->n_decode_total = metrics.n_decode_total; + res->n_busy_slots_total = metrics.n_busy_slots_total; if (json_value(task.data, "reset_bucket", false)) { metrics.reset_bucket(); } - queue_results.send(res); + queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SLOT_SAVE: { @@ -1762,20 +2282,15 @@ struct server_context { const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_saved", token_count }, // tokens saved - { "n_written", nwrite }, // bytes written - { "timings", { - { "save_ms", t_save_ms } - } } - }; - queue_results.send(result); + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = true; + res->n_tokens = token_count; + res->n_bytes = nwrite; + res->t_ms = t_save_ms; + queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { @@ -1810,20 +2325,15 @@ struct server_context { const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", token_count }, // tokens restored - { "n_read", nread }, // bytes read - { "timings", { - { "restore_ms", t_restore_ms } - } } - }; - queue_results.send(result); + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = false; + res->n_tokens = token_count; + res->n_bytes = nread; + res->t_ms = t_restore_ms; + queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SLOT_ERASE: { @@ -1845,25 +2355,18 @@ struct server_context { llama_kv_cache_seq_rm(ctx, slot->id, -1, -1); slot->cache_tokens.clear(); - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "n_erased", n_erased } - }; - queue_results.send(result); + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->n_erased = n_erased; + queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SET_LORA: { common_lora_adapters_apply(ctx, loras); - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json{{ "success", true }}; - queue_results.send(result); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); } break; } } @@ -2298,9 +2801,11 @@ struct server_context { const auto * cur_p = common_sampler_get_candidates(slot.smpl); for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) { + auto tok_id = cur_p->data[i].id; result.probs.push_back({ - cur_p->data[i].id, - i >= cur_p->size ? 0.0f : cur_p->data[i].p, + tok_id, + tokens_to_output_formatted_string(ctx, tok_id), + i >= cur_p->size ? 0.0f : cur_p->data[i].p, }); } @@ -2452,17 +2957,9 @@ int main(int argc, char ** argv) { common_init(); - // enabling this will output extra debug information in the HTTP responses from the server - // see format_final_response_oaicompat() - const bool verbose = params.verbosity > 9; - // struct that contains llama context and inference server_context ctx_server; - if (params.model_alias == "unknown") { - params.model_alias = params.model; - } - llama_backend_init(); llama_numa_init(params.numa); @@ -2647,19 +3144,27 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.post(task, true); // high-priority task // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id); + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } + + // TODO: get rid of this dynamic_cast + auto res_metrics = dynamic_cast(result.get()); + GGML_ASSERT(res_metrics != nullptr); + // optionally return "fail_on_no_slot" error - const int n_idle_slots = result.data.at("idle"); if (req.has_param("fail_on_no_slot")) { - if (n_idle_slots == 0) { + if (res_metrics->n_idle_slots == 0) { res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); return; } } - res_ok(res, result.data.at("slots")); + res_ok(res, res_metrics->slots_data); }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { @@ -2679,73 +3184,69 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.post(task, true); // high-priority task // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id); - json data = result.data; - - const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed"); - const uint64_t t_prompt_processing = data.at("t_prompt_processing"); - - const uint64_t n_tokens_predicted = data.at("n_tokens_predicted"); - const uint64_t t_tokens_generation = data.at("t_tokens_generation"); - - const uint64_t n_decode_total = data.at("n_decode_total"); - const uint64_t n_busy_slots_total = data.at("n_busy_slots_total"); + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } - const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells"); + // TODO: get rid of this dynamic_cast + auto res_metrics = dynamic_cast(result.get()); + GGML_ASSERT(res_metrics != nullptr); // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names json all_metrics_def = json { {"counter", {{ {"name", "prompt_tokens_total"}, {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) data.at("n_prompt_tokens_processed_total")} + {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total} }, { {"name", "prompt_seconds_total"}, {"help", "Prompt process time"}, - {"value", (uint64_t) data.at("t_prompt_processing_total") / 1.e3} + {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3} }, { {"name", "tokens_predicted_total"}, {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) data.at("n_tokens_predicted_total")} + {"value", (uint64_t) res_metrics->n_tokens_predicted_total} }, { {"name", "tokens_predicted_seconds_total"}, {"help", "Predict process time"}, - {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3} + {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3} }, { {"name", "n_decode_total"}, {"help", "Total number of llama_decode() calls"}, - {"value", n_decode_total} + {"value", res_metrics->n_decode_total} }, { {"name", "n_busy_slots_per_decode"}, {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) n_busy_slots_total / (float) n_decode_total} + {"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total} }}}, {"gauge", {{ {"name", "prompt_tokens_seconds"}, {"help", "Average prompt throughput in tokens/s."}, - {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.} + {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.} },{ {"name", "predicted_tokens_seconds"}, {"help", "Average generation throughput in tokens/s."}, - {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.} + {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.} },{ {"name", "kv_cache_usage_ratio"}, {"help", "KV-cache usage. 1 means 100 percent usage."}, - {"value", 1. * kv_cache_used_cells / params.n_ctx} + {"value", 1. * res_metrics->kv_cache_used_cells / params.n_ctx} },{ {"name", "kv_cache_tokens"}, {"help", "KV-cache tokens."}, - {"value", (uint64_t) data.at("kv_cache_tokens_count")} + {"value", (uint64_t) res_metrics->kv_cache_tokens_count} },{ {"name", "requests_processing"}, {"help", "Number of request processing."}, - {"value", (uint64_t) data.at("processing")} + {"value", (uint64_t) res_metrics->n_processing_slots} },{ {"name", "requests_deferred"}, {"help", "Number of request deferred."}, - {"value", (uint64_t) data.at("deferred")} + {"value", (uint64_t) res_metrics->n_tasks_deferred} }}} }; @@ -2766,8 +3267,7 @@ int main(int argc, char ** argv) { } } - const int64_t t_start = data.at("t_start"); - res.set_header("Process-Start-Time-Unix", std::to_string(t_start)); + res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start)); res.set_content(prometheus.str(), "text/plain; version=0.0.4"); res.status = 200; // HTTP OK @@ -2793,14 +3293,15 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - server_task_result result = ctx_server.queue_results.recv(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - if (result.error) { - res_error(res, result.data); - } else { - res_ok(res, result.data); + if (result->is_error()) { + res_error(res, result->to_json()); + return; } + + res_ok(res, result->to_json()); }; const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { @@ -2823,14 +3324,16 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - server_task_result result = ctx_server.queue_results.recv(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - if (result.error) { - res_error(res, result.data); - } else { - res_ok(res, result.data); + if (result->is_error()) { + res_error(res, result->to_json()); + return; } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res_ok(res, result->to_json()); }; const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { @@ -2843,14 +3346,16 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - server_task_result result = ctx_server.queue_results.recv(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - if (result.error) { - res_error(res, result.data); - } else { - res_ok(res, result.data); + if (result->is_error()) { + res_error(res, result->to_json()); + return; } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res_ok(res, result->to_json()); }; const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { @@ -2905,12 +3410,19 @@ int main(int argc, char ** argv) { res_ok(res, {{ "success", true }}); }; - const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) { + // handle completion-like requests (completion, chat, infill) + // we can optionally provide a custom format for partial results and final results + const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok]( + server_task_inf_type inf_type, + json & data, + httplib::Response & res, + bool oai_compat = false) { if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } + data["completion_id"] = gen_chatcmplid(); std::vector tasks = ctx_server.create_tasks_inference(data, inf_type); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -2919,15 +3431,15 @@ int main(int argc, char ** argv) { const auto task_ids = server_task::get_list_id(tasks); if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { if (results.size() == 1) { // single result - res_ok(res, results[0].data); + res_ok(res, results[0]->to_json()); } else { // multiple results (multitask) json arr = json::array(); - for (const auto & res : results) { - arr.push_back(res.data); + for (auto & res : results) { + arr.push_back(res->to_json()); } res_ok(res, arr); } @@ -2937,12 +3449,26 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_ids(task_ids); } else { - const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) { - ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { - return server_sent_event(sink, "data", result.data); + const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) { + ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool { + json res_json = result->to_json(); + if (res_json.is_array()) { + for (const auto & res : res_json) { + if (!server_sent_event(sink, "data", res)) { + return false; + } + } + return true; + } else { + return server_sent_event(sink, "data", res_json); + } }, [&](const json & error_data) { server_sent_event(sink, "error", error_data); }); + if (oai_compat) { + static const std::string ev_done = "data: [DONE]\n\n"; + sink.write(ev_done.data(), ev_done.size()); + } sink.done(); return false; }; @@ -3010,61 +3536,15 @@ int main(int argc, char ** argv) { return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res); }; - // TODO: maybe merge this function with "handle_completions_generic" - const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) { + const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); - - std::vector tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(tasks); - - bool stream = json_value(data, "stream", false); - const auto task_ids = server_task::get_list_id(tasks); - const auto completion_id = gen_chatcmplid(); - - if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [&](const std::vector & results) { - // multitask is never support in chat completion, there is only one result - json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose); - res_ok(res, result_oai); - }, [&](const json & error_data) { - res_error(res, error_data); - }); - - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - } else { - const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) { - ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { - std::vector result_array = format_partial_response_oaicompat(result.data, completion_id); - for (auto & event_data : result_array) { - if (event_data.empty()) { - continue; // skip the stop token - } - if (!server_sent_event(sink, "data", event_data)) { - return false; // connection is closed - } - } - return true; // ok - }, [&](const json & error_data) { - server_sent_event(sink, "error", error_data); - }); - static const std::string ev_done = "data: [DONE]\n\n"; - sink.write(ev_done.data(), ev_done.size()); - sink.done(); - return true; - }; - - auto on_complete = [task_ids, &ctx_server] (bool) { - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } + data["__oaicompat_chat"] = true; + return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true); }; const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { @@ -3139,12 +3619,12 @@ int main(int argc, char ** argv) { const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { const json body = json::parse(req.body); - bool is_openai = false; + bool oaicompat = false; // an input prompt can be a string or a list of tokens (integer) json prompt; if (body.count("input") != 0) { - is_openai = true; + oaicompat = true; prompt = body.at("input"); } else if (body.count("content") != 0) { // with "content", we only support single prompt @@ -3165,9 +3645,10 @@ int main(int argc, char ** argv) { // get the result std::unordered_set task_ids = server_task::get_list_id(tasks); - ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { - for (const auto & res : results) { - responses.push_back(res.data); + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + for (auto & res : results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); } }, [&](const json & error_data) { res_error(res, error_data); @@ -3182,7 +3663,7 @@ int main(int argc, char ** argv) { } // write JSON response - json root = is_openai + json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : responses[0]; res_ok(res, root); @@ -3243,9 +3724,10 @@ int main(int argc, char ** argv) { // get the result std::unordered_set task_ids = server_task::get_list_id(tasks); - ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { - for (const auto & res : results) { - responses.push_back(res.data); + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + for (auto & res : results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); } }, [&](const json & error_data) { res_error(res, error_data); @@ -3301,11 +3783,16 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - server_task_result result = ctx_server.queue_results.recv(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - res_ok(res, result.data); - res.status = 200; // HTTP OK + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res_ok(res, result->to_json()); }; // diff --git a/examples/server/tests/README.md b/examples/server/tests/README.md index 2930a2e0dea0f..fa3d0a2f5ff66 100644 --- a/examples/server/tests/README.md +++ b/examples/server/tests/README.md @@ -44,4 +44,10 @@ To run with stdout/stderr display in real time (verbose output, but useful for d DEBUG=1 ./tests.sh -s -v -x ``` +Hint: You can compile and run test in single command, useful for local developement: + +```shell +cmake --build build -j --target llama-server && ./examples/server/tests/tests.sh +``` + To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html) diff --git a/examples/server/tests/tests.sh b/examples/server/tests/tests.sh index 1e285dcdac14b..1e0777de367fc 100755 --- a/examples/server/tests/tests.sh +++ b/examples/server/tests/tests.sh @@ -1,5 +1,9 @@ #!/bin/bash +# make sure we are in the right directory +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + set -eu if [ $# -lt 1 ] diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 8a439f9ef0f29..f13c6c4ca4bd3 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -12,13 +12,13 @@ def create_server(): @pytest.mark.parametrize( - "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated", + "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", [ - ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False), + (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"), ] ) -def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated): +def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): global server server.start() res = server.make_request("POST", "/chat/completions", data={ @@ -30,29 +30,27 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte ], }) assert res.status_code == 200 + assert res.body["model"] == model if model is not None else server.model_alias assert res.body["usage"]["prompt_tokens"] == n_prompt assert res.body["usage"]["completion_tokens"] == n_predicted choice = res.body["choices"][0] assert "assistant" == choice["message"]["role"] assert match_regex(re_content, choice["message"]["content"]) - if truncated: - assert choice["finish_reason"] == "length" - else: - assert choice["finish_reason"] == "stop" + assert choice["finish_reason"] == finish_reason @pytest.mark.parametrize( - "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated", + "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", [ - ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False), + ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), + ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"), ] ) -def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated): +def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): global server + server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL server.start() res = server.make_stream_request("POST", "/chat/completions", data={ - "model": model, "max_tokens": max_tokens, "messages": [ {"role": "system", "content": system_prompt}, @@ -63,16 +61,13 @@ def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, r content = "" for data in res: choice = data["choices"][0] + assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future if choice["finish_reason"] in ["stop", "length"]: assert data["usage"]["prompt_tokens"] == n_prompt assert data["usage"]["completion_tokens"] == n_predicted assert "content" not in choice["delta"] assert match_regex(re_content, content) - # FIXME: not sure why this is incorrect in stream mode - # if truncated: - # assert choice["finish_reason"] == "length" - # else: - # assert choice["finish_reason"] == "stop" + assert choice["finish_reason"] == finish_reason else: assert choice["finish_reason"] is None content += choice["delta"]["content"] @@ -93,7 +88,7 @@ def test_chat_completion_with_openai_library(): temperature=0.8, ) print(res) - assert res.choices[0].finish_reason == "stop" + assert res.choices[0].finish_reason == "length" assert res.choices[0].message.content is not None assert match_regex("(Suddenly)+", res.choices[0].message.content) diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index 2fa30dd033431..1c3aa77de5bba 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -51,6 +51,24 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp content += data["content"] +def test_completion_stream_vs_non_stream(): + global server + server.start() + res_stream = server.make_stream_request("POST", "/completion", data={ + "n_predict": 8, + "prompt": "I believe the meaning of life is", + "stream": True, + }) + res_non_stream = server.make_request("POST", "/completion", data={ + "n_predict": 8, + "prompt": "I believe the meaning of life is", + }) + content_stream = "" + for data in res_stream: + content_stream += data["content"] + assert content_stream == res_non_stream.body["content"] + + @pytest.mark.parametrize("n_slots", [1, 2]) def test_consistent_result_same_seed(n_slots: int): global server @@ -221,3 +239,24 @@ def check_slots_status(): assert len(res.body["content"]) > 10 # FIXME: the result is not deterministic when using other slot than slot 0 # assert match_regex(re_content, res.body["content"]) + + +def test_n_probs(): + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + }) + assert res.status_code == 200 + assert "completion_probabilities" in res.body + assert len(res.body["completion_probabilities"]) == 5 + for tok in res.body["completion_probabilities"]: + assert "probs" in tok + assert len(tok["probs"]) == 10 + for prob in tok["probs"]: + assert "prob" in prob + assert "tok_str" in prob + assert 0.0 <= prob["prob"] <= 1.0 diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index e4451532c9d0c..a96116ac36caa 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -20,6 +20,7 @@ #include #include #include +#include #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" @@ -40,17 +41,6 @@ using json = nlohmann::ordered_json; #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type { - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error -}; - template static T json_value(const json & body, const std::string & key, const T & default_value) { // Fallback null to default value @@ -485,48 +475,11 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx, return out; } -struct completion_token_output { - llama_token tok; - std::string text_to_send; - - struct token_prob { - llama_token tok; - float prob; - }; - - std::vector probs; -}; - -// convert a vector of completion_token_output to json -static json probs_vector_to_json(const llama_context * ctx, const std::vector & probs) { - json out = json::array(); - - for (const auto & prob : probs) { - json probs_for_token = json::array(); - - for (const auto & p : prob.probs) { - const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); - probs_for_token.push_back(json { - {"tok_str", tok_str}, - {"prob", p.prob}, - }); - } - - const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok); - out.push_back(json { - {"content", tok_str}, - {"probs", probs_for_token}, - }); - } - - return out; -} - static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) { const std::string str = std::string(event) + ": " + data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; // note: these newlines are important (not sure why though, if you know, add a comment to explain) + "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). LOG_DBG("data stream, to_send: %s", str.c_str()); @@ -604,164 +557,6 @@ static json oaicompat_completion_params_parse( return llama_params; } -static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) { - bool stopped_word = result.count("stopped_word") != 0; - bool stopped_eos = json_value(result, "stopped_eos", false); - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason = "length"; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - - json choices = - streaming ? json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}) - : json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"message", json{{"content", content}, - {"role", "assistant"}}}}}); - - std::time_t t = std::time(0); - - json res = json { - {"choices", choices}, - {"created", t}, - {"model", - json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, - {"usage", json { - {"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens} - }}, - {"id", completion_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = result; - } - - if (result.contains("completion_probabilities")) { - res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); - } - - if (result.contains("timings")) { - res.push_back({"timings", json_value(result, "timings", json::object())}); - } - - return res; -} - -// return value is vector as there is one case where we might need to generate two responses -static std::vector format_partial_response_oaicompat(const json & result, const std::string & completion_id) { - if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { - return std::vector({result}); - } - - bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; - std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); - bool stopped_limit = json_value(result, "stopped_limit", false); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - if (stopped_limit) { - finish_reason = "length"; - } - - std::time_t t = std::time(0); - - json choices; - - if (!finish_reason.empty()) { - choices = json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}); - } else { - if (first) { - if (content.empty()) { - choices = json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}}}); - } else { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{{"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"role", "assistant"} - }}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"content", content}}} - }})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - return std::vector({initial_ret, second_ret}); - } - } else { - // Some idiosyncrasy in task processing logic makes several trailing calls - // with empty content, we ignore these at the calee site. - if (content.empty()) { - return std::vector({json::object()}); - } - - choices = json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, - }}); - } - } - - json ret = json { - {"choices", choices}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"} - }; - - if (result.contains("timings")) { - ret.push_back({"timings", json_value(result, "timings", json::object())}); - } - - if (!finish_reason.empty()) { - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - ret.push_back({"usage", json { - {"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens} - }}); - } - - return std::vector({ret}); -} - static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) { json data = json::array(); int i = 0; @@ -853,43 +648,3 @@ static json format_detokenized_response(const std::string & content) { {"content", content} }; } - -static json format_error_response(const std::string & message, const enum error_type type) { - std::string type_str; - int code = 500; - switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; - } - return json { - {"code", code}, - {"message", message}, - {"type", type_str}, - }; -}