From bda5832eaac088ec95520f8da62103f5041d3d83 Mon Sep 17 00:00:00 2001 From: grencez Date: Sat, 27 Jul 2024 13:20:47 -0700 Subject: [PATCH] feat(option): to try several special token names --- example/prompt/assistant_chatml/setting.sxpb | 14 ++++++++++++-- example/prompt/assistant_gemma/setting.sxpb | 15 ++++++++++++--- example/prompt/assistant_mistral/setting.sxpb | 5 ++--- src/chat/chat_main.cc | 14 +++++++++----- src/chat/opt.cc | 15 +++++++++++++-- src/chat/opt.hh | 7 ++++++- src/chat/opt_schema.cc | 4 +++- src/language/vocabulary.hh | 1 + 8 files changed, 58 insertions(+), 17 deletions(-) diff --git a/example/prompt/assistant_chatml/setting.sxpb b/example/prompt/assistant_chatml/setting.sxpb index 1d36392..b2e8d5a 100644 --- a/example/prompt/assistant_chatml/setting.sxpb +++ b/example/prompt/assistant_chatml/setting.sxpb @@ -12,8 +12,18 @@ ;(bos_token_alias "<|im_start|>") ;(eos_token_alias "<|im_end|>") (special_tokens (()) - (() (name "<|im_start|>")) - (() (name "<|im_end|>")) + (() + (alias "<|im_start|>") + (candidates (()) + "<|im_start|>" + "" ; For Gemma models. + )) + (() + (alias "<|im_end|>") + (candidates (()) + "<|im_end|>" + "" ; For Gemma models. + )) ) ) diff --git a/example/prompt/assistant_gemma/setting.sxpb b/example/prompt/assistant_gemma/setting.sxpb index 2f3a684..1ee1adb 100644 --- a/example/prompt/assistant_gemma/setting.sxpb +++ b/example/prompt/assistant_gemma/setting.sxpb @@ -9,13 +9,22 @@ ) (substitution (special_tokens (()) - (() (name "")) - (() (name "")) + (() + (alias "") + (candidates (()) + "" + "<|im_start|>" ; For ChatML models. + )) + (() + (alias "") + (candidates (()) + "" + "<|im_end|>" ; For ChatML models. + )) ) ) (x_priming "priming.txt") -(x_rolling "rolling.txt") (o_rolling "../../../bld/example/prompt/assistant_gemma.txt") ; No starting space. diff --git a/example/prompt/assistant_mistral/setting.sxpb b/example/prompt/assistant_mistral/setting.sxpb index 4826ca4..ecde125 100644 --- a/example/prompt/assistant_mistral/setting.sxpb +++ b/example/prompt/assistant_mistral/setting.sxpb @@ -11,12 +11,11 @@ (substitution (eos_token_alias "") (special_tokens (()) - (() (name "[INST]")) - (() (name "[/INST]")) + (() (alias "[INST]")) + (() (alias "[/INST]")) ) ) -(x_priming "priming.txt") (x_rolling "rolling.txt") (o_rolling "../../../bld/example/prompt/assistant_mistral.txt") diff --git a/src/chat/chat_main.cc b/src/chat/chat_main.cc index e4c4ac7..a7e045e 100644 --- a/src/chat/chat_main.cc +++ b/src/chat/chat_main.cc @@ -117,14 +117,18 @@ int main(int argc, char** argv) vocabulary.assign_substitution( opt.eos_token_alias, vocabulary.eos_token_id()); } - for (const auto& name : opt.special_token_names) { - Vocabulary::Token_id token_id = vocabulary.tokenize_special(name); - if (token_id < static_cast(vocabulary.cardinality())) { - vocabulary.assign_substitution(name, token_id); + for (const auto& special : opt.special_tokens) { + Vocabulary::Token_id token_id = Vocabulary::null_token_id; + for (const auto& name : special.candidates) { + token_id = vocabulary.tokenize_special(name); + if (token_id != Vocabulary::null_token_id) {break;} + } + if (token_id != Vocabulary::null_token_id) { + vocabulary.assign_substitution(special.alias, token_id); } else { exstatus = 65; - fildesh_log_errorf("Unknown special token: %s", name.c_str()); + fildesh_log_errorf("Unknown special token: %s", special.alias.c_str()); } } chat_disp.out_ = open_FildeshOF("/dev/stdout"); diff --git a/src/chat/opt.cc b/src/chat/opt.cc index 160282c..1d0193f 100644 --- a/src/chat/opt.cc +++ b/src/chat/opt.cc @@ -514,8 +514,19 @@ slurp_sxpb_options_close_FildeshX( if (!nullish_FildeshSxpbIT(sub_it)) { for (sub_it = first_at_FildeshSxpb(sxpb, sub_it); !nullish_FildeshSxpbIT(sub_it); sub_it = next_at_FildeshSxpb(sxpb, sub_it)) { - if (lone_subfield_at_FildeshSxpb_to_str(&s, sxpb, sub_it, "name")) { - opt.special_token_names.push_back(s); + auto& special = opt.special_tokens.emplace_back(); + lone_subfield_at_FildeshSxpb_to_cc_string(&special.alias, sxpb, sub_it, "alias"); + assert(!special.alias.empty()); + FildeshSxpbIT candidate_it = lookup_subfield_at_FildeshSxpb(sxpb, sub_it, "candidates"); + if (nullish_FildeshSxpbIT(candidate_it)) { + special.candidates.push_back(special.alias); + } + else { + for (candidate_it = first_at_FildeshSxpb(sxpb, candidate_it); + !nullish_FildeshSxpbIT(candidate_it); + candidate_it = next_at_FildeshSxpb(sxpb, candidate_it)) { + special.candidates.push_back(str_value_at_FildeshSxpb(sxpb, candidate_it)); + } } } } diff --git a/src/chat/opt.hh b/src/chat/opt.hh index ba56528..3ba0080 100644 --- a/src/chat/opt.hh +++ b/src/chat/opt.hh @@ -18,6 +18,11 @@ struct ChatMessageOpt { std::string given_suffix; }; +struct SpecialToken { + std::string alias; + std::vector candidates; +}; + struct ChatOptions { std::string protagonist; @@ -26,7 +31,7 @@ struct ChatOptions { std::string confidant_alias; std::string bos_token_alias; std::string eos_token_alias; - std::vector special_token_names; + std::vector special_tokens; std::vector message_opts; std::string model_filename; std::string lora_filename; diff --git a/src/chat/opt_schema.cc b/src/chat/opt_schema.cc index 20ca537..c3de847 100644 --- a/src/chat/opt_schema.cc +++ b/src/chat/opt_schema.cc @@ -13,7 +13,9 @@ static FildeshSxprotoField chat_prefixes_manyof[] = { {"m", FILL_FildeshSxprotoField_MESSAGE(chat_prefixes_m_message)}, }; static FildeshSxprotoField special_token_message[] = { - {"name", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, + {"alias", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, + {"name", FILL_DEFAULT_FildeshSxprotoField_ALIAS}, + {"candidates", FILL_DEFAULT_FildeshSxprotoField_STRINGS}, }; static FildeshSxprotoField substitution_message[] = { {"protagonist_alias", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, diff --git a/src/language/vocabulary.hh b/src/language/vocabulary.hh index a2d24d6..35bb407 100644 --- a/src/language/vocabulary.hh +++ b/src/language/vocabulary.hh @@ -12,6 +12,7 @@ namespace rendezllama { class Vocabulary { public: typedef int Token_id; + static const Token_id null_token_id = -1; public: explicit Vocabulary(const llama_model* model);