From dab76c92cc63072d9495ba87f2f3f3a4872d4f57 Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Mon, 23 Dec 2024 00:21:40 +0000 Subject: [PATCH] llama-run : include temperature option (#10899) This commit updates the `examples/run/README.md` file to include a new option for setting the temperature and updates the `run.cpp` file to parse this option. Signed-off-by: Eric Curtin --- examples/run/README.md | 2 + examples/run/run.cpp | 111 +++++++++++++++++++++++++++-------------- 2 files changed, 75 insertions(+), 38 deletions(-) diff --git a/examples/run/README.md b/examples/run/README.md index 874293516f4b6..a0680544120b9 100644 --- a/examples/run/README.md +++ b/examples/run/README.md @@ -19,6 +19,8 @@ Options: Context size (default: 2048) -n, --ngl Number of GPU layers (default: 0) + --temp + Temperature (default: 0.8) -v, --verbose, --log-verbose Set verbosity level to infinity (i.e. log all messages, useful for debugging) -h, --help diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 03da54ca3b2ef..f89d041c44ac1 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -55,29 +55,51 @@ static int printe(const char * fmt, ...) { class Opt { public: int init(int argc, const char ** argv) { + ctx_params = llama_context_default_params(); + model_params = llama_model_default_params(); + context_size_default = ctx_params.n_batch; + ngl_default = model_params.n_gpu_layers; + common_params_sampling sampling; + temperature_default = sampling.temp; + + if (argc < 2) { + printe("Error: No arguments provided.\n"); + print_help(); + return 1; + } + // Parse arguments if (parse(argc, argv)) { printe("Error: Failed to parse arguments.\n"); - help(); + print_help(); return 1; } // If help is requested, show help and exit - if (help_) { - help(); + if (help) { + print_help(); return 2; } + ctx_params.n_batch = context_size >= 0 ? context_size : context_size_default; + model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default; + temperature = temperature >= 0 ? temperature : temperature_default; + return 0; // Success } + llama_context_params ctx_params; + llama_model_params model_params; std::string model_; - std::string user_; - int context_size_ = -1, ngl_ = -1; - bool verbose_ = false; + std::string user; + int context_size = -1, ngl = -1; + float temperature = -1; + bool verbose = false; private: - bool help_ = false; + int context_size_default = -1, ngl_default = -1; + float temperature_default = -1; + bool help = false; bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) { return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0; @@ -89,6 +111,17 @@ class Opt { } option_value = std::atoi(argv[++i]); + + return 0; + } + + int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) { + if (i + 1 >= argc) { + return 1; + } + + option_value = std::atof(argv[++i]); + return 0; } @@ -96,18 +129,22 @@ class Opt { bool options_parsing = true; for (int i = 1, positional_args_i = 0; i < argc; ++i) { if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) { - if (handle_option_with_value(argc, argv, i, context_size_) == 1) { + if (handle_option_with_value(argc, argv, i, context_size) == 1) { return 1; } } else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) { - if (handle_option_with_value(argc, argv, i, ngl_) == 1) { + if (handle_option_with_value(argc, argv, i, ngl) == 1) { + return 1; + } + } else if (options_parsing && strcmp(argv[i], "--temp") == 0) { + if (handle_option_with_value(argc, argv, i, temperature) == 1) { return 1; } } else if (options_parsing && (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) { - verbose_ = true; + verbose = true; } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) { - help_ = true; + help = true; return 0; } else if (options_parsing && strcmp(argv[i], "--") == 0) { options_parsing = false; @@ -120,16 +157,16 @@ class Opt { model_ = argv[i]; } else if (positional_args_i == 1) { ++positional_args_i; - user_ = argv[i]; + user = argv[i]; } else { - user_ += " " + std::string(argv[i]); + user += " " + std::string(argv[i]); } } return 0; } - void help() const { + void print_help() const { printf( "Description:\n" " Runs a llm\n" @@ -142,6 +179,8 @@ class Opt { " Context size (default: %d)\n" " -n, --ngl \n" " Number of GPU layers (default: %d)\n" + " --temp \n" + " Temperature (default: %.1f)\n" " -v, --verbose, --log-verbose\n" " Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n" " -h, --help\n" @@ -170,7 +209,7 @@ class Opt { " llama-run file://some-file3.gguf\n" " llama-run --ngl 999 some-file4.gguf\n" " llama-run --ngl 999 some-file5.gguf Hello World\n", - llama_context_default_params().n_batch, llama_model_default_params().n_gpu_layers); + context_size_default, ngl_default, temperature_default); } }; @@ -495,12 +534,12 @@ class LlamaData { return 1; } - context = initialize_context(model, opt.context_size_); + context = initialize_context(model, opt); if (!context) { return 1; } - sampler = initialize_sampler(); + sampler = initialize_sampler(opt); return 0; } @@ -619,14 +658,12 @@ class LlamaData { // Initializes the model and returns a unique pointer to it llama_model_ptr initialize_model(Opt & opt) { ggml_backend_load_all(); - llama_model_params model_params = llama_model_default_params(); - model_params.n_gpu_layers = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers; resolve_model(opt.model_); printe( "\r%*s" "\rLoading model", get_terminal_width(), " "); - llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params)); + llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), opt.model_params)); if (!model) { printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str()); } @@ -636,10 +673,8 @@ class LlamaData { } // Initializes the context with the specified parameters - llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) { - llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = ctx_params.n_batch = n_ctx >= 0 ? n_ctx : ctx_params.n_batch; - llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params)); + llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) { + llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params)); if (!context) { printe("%s: error: failed to create the llama_context\n", __func__); } @@ -648,10 +683,10 @@ class LlamaData { } // Initializes and configures the sampler - llama_sampler_ptr initialize_sampler() { + llama_sampler_ptr initialize_sampler(const Opt & opt) { llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params())); llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1)); - llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f)); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature)); llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); return sampler; @@ -798,9 +833,9 @@ static int apply_chat_template_with_error_handling(LlamaData & llama_data, const } // Helper function to handle user input -static int handle_user_input(std::string & user_input, const std::string & user_) { - if (!user_.empty()) { - user_input = user_; +static int handle_user_input(std::string & user_input, const std::string & user) { + if (!user.empty()) { + user_input = user; return 0; // No need for interactive input } @@ -832,17 +867,17 @@ static bool is_stdout_a_terminal() { } // Function to tokenize the prompt -static int chat_loop(LlamaData & llama_data, const std::string & user_) { +static int chat_loop(LlamaData & llama_data, const std::string & user) { int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); static const bool stdout_a_terminal = is_stdout_a_terminal(); while (true) { // Get user input std::string user_input; - while (handle_user_input(user_input, user_)) { + while (handle_user_input(user_input, user)) { } - add_message("user", user_.empty() ? user_input : user_, llama_data); + add_message("user", user.empty() ? user_input : user, llama_data); int new_len; if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) { return 1; @@ -854,7 +889,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) { return 1; } - if (!user_.empty()) { + if (!user.empty()) { break; } @@ -869,7 +904,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) { static void log_callback(const enum ggml_log_level level, const char * text, void * p) { const Opt * opt = static_cast(p); - if (opt->verbose_ || level == GGML_LOG_LEVEL_ERROR) { + if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) { printe("%s", text); } } @@ -890,11 +925,11 @@ int main(int argc, const char ** argv) { } if (!is_stdin_a_terminal()) { - if (!opt.user_.empty()) { - opt.user_ += "\n\n"; + if (!opt.user.empty()) { + opt.user += "\n\n"; } - opt.user_ += read_pipe_data(); + opt.user += read_pipe_data(); } llama_log_set(log_callback, &opt); @@ -903,7 +938,7 @@ int main(int argc, const char ** argv) { return 1; } - if (chat_loop(llama_data, opt.user_)) { + if (chat_loop(llama_data, opt.user)) { return 1; }