Skip to content

Commit

Permalink
feat(option): to try several special token names
Browse files (browse the repository at this point in the history)
  • Loading branch information
grencez committed Jul 27, 2024
1 parent 54030b1 commit bda5832
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 17 deletions.
14 changes: 12 additions & 2 deletions example/prompt/assistant_chatml/setting.sxpb
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,18 @@
;(bos_token_alias "<|im_start|>")
;(eos_token_alias "<|im_end|>")
(special_tokens (())
(() (name "<|im_start|>"))
(() (name "<|im_end|>"))
(()
(alias "<|im_start|>")
(candidates (())
"<|im_start|>"
"<start_of_turn>" ; For Gemma models.
))
(()
(alias "<|im_end|>")
(candidates (())
"<|im_end|>"
"<end_of_turn>" ; For Gemma models.
))
)
)

Expand Down
15 changes: 12 additions & 3 deletions example/prompt/assistant_gemma/setting.sxpb
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,22 @@
)
(substitution
(special_tokens (())
(() (name "<start_of_turn>"))
(() (name "<end_of_turn>"))
(()
(alias "<start_of_turn>")
(candidates (())
"<start_of_turn>"
"<|im_start|>" ; For ChatML models.
))
(()
(alias "<end_of_turn>")
(candidates (())
"<end_of_turn>"
"<|im_end|>" ; For ChatML models.
))
)
)

(x_priming "priming.txt")
(x_rolling "rolling.txt")
(o_rolling "../../../bld/example/prompt/assistant_gemma.txt")

; No starting space.
Expand Down
5 changes: 2 additions & 3 deletions example/prompt/assistant_mistral/setting.sxpb
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
(substitution
(eos_token_alias "</s>")
(special_tokens (())
(() (name "[INST]"))
(() (name "[/INST]"))
(() (alias "[INST]"))
(() (alias "[/INST]"))
)
)

(x_priming "priming.txt")
(x_rolling "rolling.txt")
(o_rolling "../../../bld/example/prompt/assistant_mistral.txt")

Expand Down
14 changes: 9 additions & 5 deletions src/chat/chat_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,18 @@ int main(int argc, char** argv)
vocabulary.assign_substitution(
opt.eos_token_alias, vocabulary.eos_token_id());
}
for (const auto& name : opt.special_token_names) {
Vocabulary::Token_id token_id = vocabulary.tokenize_special(name);
if (token_id < static_cast<Vocabulary::Token_id>(vocabulary.cardinality())) {
vocabulary.assign_substitution(name, token_id);
for (const auto& special : opt.special_tokens) {
Vocabulary::Token_id token_id = Vocabulary::null_token_id;
for (const auto& name : special.candidates) {
token_id = vocabulary.tokenize_special(name);
if (token_id != Vocabulary::null_token_id) {break;}
}
if (token_id != Vocabulary::null_token_id) {
vocabulary.assign_substitution(special.alias, token_id);
}
else {
exstatus = 65;
fildesh_log_errorf("Unknown special token: %s", name.c_str());
fildesh_log_errorf("Unknown special token: %s", special.alias.c_str());
}
}
chat_disp.out_ = open_FildeshOF("/dev/stdout");
Expand Down
15 changes: 13 additions & 2 deletions src/chat/opt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,19 @@ slurp_sxpb_options_close_FildeshX(
if (!nullish_FildeshSxpbIT(sub_it)) {
for (sub_it = first_at_FildeshSxpb(sxpb, sub_it); !nullish_FildeshSxpbIT(sub_it);
sub_it = next_at_FildeshSxpb(sxpb, sub_it)) {
if (lone_subfield_at_FildeshSxpb_to_str(&s, sxpb, sub_it, "name")) {
opt.special_token_names.push_back(s);
auto& special = opt.special_tokens.emplace_back();
lone_subfield_at_FildeshSxpb_to_cc_string(&special.alias, sxpb, sub_it, "alias");
assert(!special.alias.empty());
FildeshSxpbIT candidate_it = lookup_subfield_at_FildeshSxpb(sxpb, sub_it, "candidates");
if (nullish_FildeshSxpbIT(candidate_it)) {
special.candidates.push_back(special.alias);
}
else {
for (candidate_it = first_at_FildeshSxpb(sxpb, candidate_it);
!nullish_FildeshSxpbIT(candidate_it);
candidate_it = next_at_FildeshSxpb(sxpb, candidate_it)) {
special.candidates.push_back(str_value_at_FildeshSxpb(sxpb, candidate_it));
}
}
}
}
Expand Down
7 changes: 6 additions & 1 deletion src/chat/opt.hh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ struct ChatMessageOpt {
std::string given_suffix;
};

// One entry of the `special_tokens` substitution option: an alias used in
// prompt text plus the model token names that may stand behind it.
struct SpecialToken {
// Name registered with Vocabulary::assign_substitution() once a candidate
// resolves to a token id; also the name reported in the
// "Unknown special token" error when no candidate resolves.
std::string alias;
// Token names tried in order with Vocabulary::tokenize_special(); the first
// one that does not yield Vocabulary::null_token_id is used. When the
// config has no `candidates` field, the option parser stores `alias` here
// as the only candidate.
std::vector<std::string> candidates;
};

struct ChatOptions {

std::string protagonist;
Expand All @@ -26,7 +31,7 @@ struct ChatOptions {
std::string confidant_alias;
std::string bos_token_alias;
std::string eos_token_alias;
std::vector<std::string> special_token_names;
std::vector<SpecialToken> special_tokens;
std::vector<ChatMessageOpt> message_opts;
std::string model_filename;
std::string lora_filename;
Expand Down
4 changes: 3 additions & 1 deletion src/chat/opt_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ static FildeshSxprotoField chat_prefixes_manyof[] = {
{"m", FILL_FildeshSxprotoField_MESSAGE(chat_prefixes_m_message)},
};
static FildeshSxprotoField special_token_message[] = {
{"name", FILL_FildeshSxprotoField_STRING(1, INT_MAX)},
{"alias", FILL_FildeshSxprotoField_STRING(1, INT_MAX)},
{"name", FILL_DEFAULT_FildeshSxprotoField_ALIAS},
{"candidates", FILL_DEFAULT_FildeshSxprotoField_STRINGS},
};
static FildeshSxprotoField substitution_message[] = {
{"protagonist_alias", FILL_FildeshSxprotoField_STRING(1, INT_MAX)},
Expand Down
1 change: 1 addition & 0 deletions src/language/vocabulary.hh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace rendezllama {
class Vocabulary {
public:
typedef int Token_id;
static const Token_id null_token_id = -1;

public:
explicit Vocabulary(const llama_model* model);
Expand Down

0 comments on commit bda5832

Please sign in to comment.