Skip to content

Commit

Permalink
feat: expose option for passing a grammar to the sampler, ref #8
Browse files Browse the repository at this point in the history
  • Loading branch information
pminev committed Jan 20, 2025
1 parent 060bb7a commit 947b7ce
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 1 deletion.
4 changes: 3 additions & 1 deletion code/ac/llama/Instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ llama_context_params llamaFromInstanceInitParams(const Instance::InitParams& par

Instance::Instance(Model& model, InitParams params)
: m_model(model)
, m_sampler(model, {})
, m_sampler(model, {
.grammar = params.grammar,
})
, m_lctx(llama_init_from_model(model.lmodel(), llamaFromInstanceInitParams(params)), llama_free)
{
if (!m_lctx) {
Expand Down
1 change: 1 addition & 0 deletions code/ac/llama/Instance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class AC_LLAMA_EXPORT Instance {
uint32_t batchSize = 2048; // logical batch size for prompt processing (may be silently truncated to ctxSize)
uint32_t ubatchSize = 512; // physical batch size for prompt processing (0 = batchSize)
bool flashAttn = false; // enable flash attention
std::string grammar; // BNF-styled grammar
};

explicit Instance(Model& model, InitParams params);
Expand Down
89 changes: 89 additions & 0 deletions test/t-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,92 @@ TEST_CASE("control_vector") {
}
}
}

TEST_CASE("grammar") {
ac::llama::Model::Params iParams = {};
auto lmodel = ac::llama::ModelRegistry::getInstance().loadModel(Model_117m_q6_k, {}, iParams);
ac::llama::Model model(lmodel, iParams);
CHECK(!!model.lmodel());

auto& params = model.params();
CHECK(params.gpu);
CHECK_FALSE(params.vocabOnly);

CHECK(model.trainCtxLength() == 1024);
CHECK_FALSE(model.shouldAddBosToken());
CHECK_FALSE(model.hasEncoder());

SUBCASE("Numbers 6-9 only") {
ac::llama::Instance::InitParams iparams;
iparams.grammar = R""""(
root ::= en-char+ ([ \t\n] en-char+)*
en-char ::= digit | letter
letter ::= [a-zA-Z]
digit ::= [6-9]
)"""";

ac::llama::Instance inst(model, iparams);
inst.warmup(); // should be safe

auto& s = inst.startSession({});
std::vector<ac::llama::Token> tokens = model.vocab().tokenize("My name is Daniel and my age is", true, true);
s.setInitialPrompt(tokens);
std::string text;
for (int i = 0; i < 5; ++i) {
auto t = s.getToken();
REQUIRE(t != ac::llama::Token_Invalid);
text += model.vocab().tokenToString(t);
}

CHECK(text == "s about 9 years old");
}

SUBCASE("Numbers 1-5 only") {
ac::llama::Instance::InitParams iparams;
iparams.grammar = R""""(
root ::= en-char+ ([ \t\n] en-char+)*
en-char ::= digit | letter
letter ::= [a-zA-Z]
digit ::= [1-5]
)"""";

ac::llama::Instance inst(model, iparams);
inst.warmup(); // should be safe

auto& s = inst.startSession({});
std::vector<ac::llama::Token> tokens = model.vocab().tokenize("My name is Daniel and my age is", true, true);
s.setInitialPrompt(tokens);
std::string text;
for (int i = 0; i < 5; ++i) {
auto t = s.getToken();
REQUIRE(t != ac::llama::Token_Invalid);
text += model.vocab().tokenToString(t);
}

CHECK(text == "54 and I am an");
}

SUBCASE("All capital letters") {
ac::llama::Instance::InitParams iparams;
iparams.grammar = R""""(
root ::= en-char+ ([ \t\n] en-char+)*
en-char ::= letter
letter ::= [A-Z]
)"""";

ac::llama::Instance inst(model, iparams);
inst.warmup(); // should be safe

auto& s = inst.startSession({});
std::vector<ac::llama::Token> tokens = model.vocab().tokenize("My name is Daniel and my age is", true, true);
s.setInitialPrompt(tokens);
std::string text;
for (int i = 0; i < 5; ++i) {
auto t = s.getToken();
REQUIRE(t != ac::llama::Token_Invalid);
text += model.vocab().tokenToString(t);
}

CHECK(text == "ELLIE JONES");
}
}

0 comments on commit 947b7ce

Please sign in to comment.