WIP speculative #4683

Draft
Wants to merge 1 commit into base: master
227 changes: 224 additions & 3 deletions backend/cpp/llama/grpc-server.cpp
@@ -22,6 +22,7 @@
#include "backend.grpc.pb.h"
#include "utils.hpp"
#include "sampling.h"
#include "speculative.h"
// include std::regex
#include <cstddef>
#include <thread>
@@ -185,12 +186,45 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
return out;
}

struct llama_slot_params {
uint32_t seed = -1; // RNG seed
bool stream = true;
bool cache_prompt = true; // remember the prompt to avoid reprocessing it in full
bool return_tokens = false;

int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters

int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

std::vector<common_adapter_lora_info> lora;

std::vector<std::string> antiprompt;
std::vector<std::string> response_fields;
bool timings_per_token = false;
bool post_sampling_probs = false;
bool ignore_eos = false;

json input_prefix;
json input_suffix;

struct common_params_sampling sampling;
struct common_params_speculative speculative;
};


struct llama_client_slot
{
int id;
int task_id = -1;

struct slot_params params;
struct llama_slot_params params;
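// per-slot speculative decoding state: the draft speculator and the batch used to decode the drafted tokens with the target model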
common_speculative * spec = nullptr;
llama_batch batch_spec = {};


slot_state state = IDLE;
slot_command command = NONE;
@@ -283,6 +317,7 @@ struct llama_client_slot
images.clear();
}


bool has_budget(common_params &global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1)
{
@@ -454,6 +489,10 @@ struct llama_server_context
{
llama_model *model = nullptr;
llama_context *ctx = nullptr;
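// optional draft model for speculative decoding; each slot builds its own draft context from model_dft using cparams_dft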
common_init_result llama_init_dft;
llama_context * ctx_dft = nullptr;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
const llama_vocab * vocab = nullptr;

clip_ctx *clp_ctx = nullptr;
@@ -502,6 +541,7 @@
}
}


bool load_model(const common_params &params_)
{
params = params_;
@@ -545,6 +585,45 @@
add_bos_token = llama_vocab_get_add_bos(vocab);
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;

if (!params.speculative.model.empty()) {
LOG("loading draft model '%s'\n", params.speculative.model.c_str());

auto params_dft = params;

params_dft.devices = params.speculative.devices;
params_dft.model = params.speculative.model;
params_dft.n_ctx = params.speculative.n_ctx == 0 ? params.n_ctx / params.n_parallel : params.speculative.n_ctx;
params_dft.n_gpu_layers = params.speculative.n_gpu_layers;
params_dft.n_parallel = 1;

llama_init_dft = common_init_from_params(params_dft);

model_dft = llama_init_dft.model.get();

if (model_dft == nullptr) {
LOG("failed to load draft model, '%s'\n", params.speculative.model.c_str());
return false;
}

if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) {
LOG("the draft model '%s' is not compatible with the target model '%s'\n", params.speculative.model.c_str(), params.model.c_str());

return false;
}

const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());

cparams_dft = common_context_params_to_llama(params_dft);
cparams_dft.n_batch = n_ctx_dft;
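// size the draft batch to the draft context so up to a full context of tokens can be evaluated in one call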

// force F16 KV cache for the draft model for extra performance
cparams_dft.type_k = GGML_TYPE_F16;
cparams_dft.type_v = GGML_TYPE_F16;

// the context is not needed - we will create one for each slot
llama_init_dft.context.reset();
}

return true;
}

@@ -573,6 +652,22 @@ struct llama_server_context
slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;

if (model_dft) {
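// the speculative batch holds the last sampled target token plus up to n_max draft tokens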
slot.batch_spec = llama_batch_init(params.speculative.n_max + 1, 0, 1);

ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (ctx_dft == nullptr) {
LOG("%s", "failed to create draft context\n");
return;
}

slot.spec = common_speculative_init(ctx_dft);
if (slot.spec == nullptr) {
LOG("%s", "failed to create speculator\n");
return;
}
}

LOG_INFO("new slot", {
{"slot_id", slot.id},
{"n_ctx_slot", slot.n_ctx}
@@ -681,9 +776,11 @@
}

bool launch_slot_with_data(llama_client_slot* &slot, json data) {
slot_params default_params;
llama_slot_params default_params;
common_params_sampling default_sparams;


default_params.speculative = params.speculative;

slot->params.stream = json_value(data, "stream", false);
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
@@ -707,6 +804,15 @@
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);


slot->params.speculative.n_min = json_value(data, "speculative.n_min", default_params.speculative.n_min);
slot->params.speculative.n_max = json_value(data, "speculative.n_max", default_params.speculative.n_max);
slot->params.speculative.p_min = json_value(data, "speculative.p_min", default_params.speculative.p_min);

slot->params.speculative.n_min = std::min(slot->params.speculative.n_max, slot->params.speculative.n_min);
slot->params.speculative.n_min = std::max(slot->params.speculative.n_min, 2);
slot->params.speculative.n_max = std::max(slot->params.speculative.n_max, 0);

if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
// Might be better to reject the request with a 400 ?
LOG_WARNING("Max tokens to predict exceeds server configuration", {
@@ -2024,6 +2130,97 @@ struct llama_server_context
}
}

// do speculative decoding
for (auto & slot : slots) {
if (!slot.is_processing() || !(ctx_dft && params.speculative.n_max > 0)) {
continue;
}

if (slot.state != PROCESSING) {
continue;
}

// determine the max draft that fits the current slot state
int n_draft_max = slot.params.speculative.n_max;

// note: n_past is not yet increased for the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);

if (slot.n_remaining > 0) {
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
}

LOG("max possible draft: %d\n", n_draft_max);

if (n_draft_max < slot.params.speculative.n_min) {
LOG("the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);

continue;
}

llama_token id = slot.sampled;

struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
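// bound the reusable part of the draft context so there is always room for up to n_max newly drafted tokens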
params_spec.n_reuse = llama_n_ctx(ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;

llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);

// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
LOG("ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);

continue;
}

// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);

for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}

LOG("decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);

llama_decode(ctx, slot.batch_spec);

// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, draft);

slot.n_past += ids.size();
slot.n_decoded += ids.size();

slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);

llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
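// ids holds the accepted draft tokens plus the token sampled by the target model at the end; stream each of them below through the regular token pipeline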

for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;

result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, params.special);
//result.prob = 1.0f; // set later

// TODO: set result.probs

if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}

LOG("accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
}


LOG_VERBOSE("slots updated", {});
return true;
}
@@ -2296,6 +2493,30 @@ static void params_parse(const backend::ModelOptions* request,
params.cpuparams.n_threads = request->threads();
params.n_gpu_layers = request->ngpulayers();
params.n_batch = request->nbatch();
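// optional draft model for speculative decoding; leaving it empty disables drafting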
params.speculative.model = request->draftmodel();

// If options are present, parse them. Each option has the form "name:value";
// a bare "name" is treated as the boolean "true".
for (int i = 0; i < request->options_size(); i++) {
std::string opt = request->options(i);
char *optname = strtok(&opt[0], ":");
char *optval = strtok(NULL, ":");
if (optval == NULL) {
optval = (char *)"true";
}

if (!strcmp(optname, "speculative.n_gpu_layers")) {
params.speculative.n_gpu_layers = std::stoi(optval);
}
if (!strcmp(optname, "speculative.n_ctx")) {
params.speculative.n_ctx = std::stoi(optval);
}
}

if (params.speculative.n_gpu_layers == 0) {
params.speculative.n_gpu_layers = params.n_gpu_layers;
}
if (params.speculative.n_ctx == 0) {
params.speculative.n_ctx = params.n_ctx;
}
// Set params.n_parallel by environment variable (LLAMA_PARALLEL), defaults to 1
//params.n_parallel = 1;
const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");