From dfb1a7381c010002fd5dbec13251d81632e8ec6d Mon Sep 17 00:00:00 2001
From: Plamen Minev
Date: Mon, 20 Jan 2025 18:28:08 +0200
Subject: [PATCH] feat: expose option for passing a grammar to the sampler, ref #8

---
 code/ac/llama/Instance.cpp |  4 +-
 code/ac/llama/Instance.hpp |  1 +
 test/t-integration.cpp     | 89 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 93 insertions(+), 1 deletion(-)

diff --git a/code/ac/llama/Instance.cpp b/code/ac/llama/Instance.cpp
index 7eeefd2..a93e484 100644
--- a/code/ac/llama/Instance.cpp
+++ b/code/ac/llama/Instance.cpp
@@ -34,7 +34,9 @@ llama_context_params llamaFromInstanceInitParams(const Instance::InitParams& par
 
 Instance::Instance(Model& model, InitParams params)
     : m_model(model)
-    , m_sampler(model, {})
+    , m_sampler(model, {
+        .grammar = params.grammar,
+    })
     , m_lctx(llama_init_from_model(model.lmodel(), llamaFromInstanceInitParams(params)), llama_free)
 {
     if (!m_lctx) {
diff --git a/code/ac/llama/Instance.hpp b/code/ac/llama/Instance.hpp
index bc76bc3..0da5c9a 100644
--- a/code/ac/llama/Instance.hpp
+++ b/code/ac/llama/Instance.hpp
@@ -23,6 +23,7 @@ class AC_LLAMA_EXPORT Instance {
         uint32_t batchSize = 2048; // logical batch size for prompt processing (may be silently truncated to ctxSize)
         uint32_t ubatchSize = 512; // physical batch size for prompt processing (0 = batchSize)
         bool flashAttn = false; // enable flash attention
+        std::string grammar; // BNF-styled grammar
     };
 
     explicit Instance(Model& model, InitParams params);
diff --git a/test/t-integration.cpp b/test/t-integration.cpp
index 41c86b8..37c891f 100644
--- a/test/t-integration.cpp
+++ b/test/t-integration.cpp
@@ -321,3 +321,92 @@ TEST_CASE("control_vector") {
         }
     }
 }
+
+TEST_CASE("grammar") {
+    ac::llama::Model::Params iParams = {};
+    auto lmodel = ac::llama::ModelRegistry::getInstance().loadModel(Model_117m_q6_k, {}, iParams);
+    ac::llama::Model model(lmodel, iParams);
+    CHECK(!!model.lmodel());
+
+    auto& params = model.params();
+    CHECK(params.gpu);
+    CHECK_FALSE(params.vocabOnly);
+
+    CHECK(model.trainCtxLength() == 1024);
+    CHECK_FALSE(model.shouldAddBosToken());
+    CHECK_FALSE(model.hasEncoder());
+
+    SUBCASE("Numbers 6-9 only") {
+        ac::llama::Instance::InitParams iparams;
+        iparams.grammar = R""""(
+root ::= en-char+ ([ \t\n] en-char+)*
+en-char ::= digit | letter
+letter ::= [a-zA-Z]
+digit ::= [6-9]
+        )"""";
+
+        ac::llama::Instance inst(model, iparams);
+        inst.warmup(); // should be safe
+
+        auto& s = inst.startSession({});
+        std::vector tokens = model.vocab().tokenize("My name is Daniel and my age is", true, true);
+        s.setInitialPrompt(tokens);
+        std::string text;
+        for (int i = 0; i < 5; ++i) {
+            auto t = s.getToken();
+            REQUIRE(t != ac::llama::Token_Invalid);
+            text += model.vocab().tokenToString(t);
+        }
+
+        CHECK(text == "s about 9 years old");
+    }
+
+    SUBCASE("Numbers 1-5 only") {
+        ac::llama::Instance::InitParams iparams;
+        iparams.grammar = R""""(
+root ::= en-char+ ([ \t\n] en-char+)*
+en-char ::= digit | letter
+letter ::= [a-zA-Z]
+digit ::= [1-5]
+        )"""";
+
+        ac::llama::Instance inst(model, iparams);
+        inst.warmup(); // should be safe
+
+        auto& s = inst.startSession({});
+        std::vector tokens = model.vocab().tokenize("My name is Daniel and my age is", true, true);
+        s.setInitialPrompt(tokens);
+        std::string text;
+        for (int i = 0; i < 5; ++i) {
+            auto t = s.getToken();
+            REQUIRE(t != ac::llama::Token_Invalid);
+            text += model.vocab().tokenToString(t);
+        }
+
+        CHECK(text == "54 and I am an");
+    }
+
+    SUBCASE("All capital letters") {
+        ac::llama::Instance::InitParams iparams;
+        iparams.grammar = R""""(
+root ::= en-char+ ([ \t\n] en-char+)*
+en-char ::= letter
+letter ::= [A-Z]
+        )"""";
+
+        ac::llama::Instance inst(model, iparams);
+        inst.warmup(); // should be safe
+
+        auto& s = inst.startSession({});
+        std::vector tokens = model.vocab().tokenize("My name is Daniel and my age is", true, true);
+        s.setInitialPrompt(tokens);
+        std::string text;
+        for (int i = 0; i < 5; ++i) {
+            auto t = s.getToken();
+            REQUIRE(t != ac::llama::Token_Invalid);
+            text += model.vocab().tokenToString(t);
+        }
+
+        CHECK(text == "ELLIE JONES");
+    }
+}
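
Usage note (not part of the patch): below is a minimal sketch of how a caller might route a GBNF-style grammar to the sampler through the new Instance::InitParams::grammar field. It relies only on API that appears in this patch and its tests (Instance, startSession, setInitialPrompt, getToken, vocab().tokenize, tokenToString, Token_Invalid); the include paths, the helper name generateDigits, the prompt text, and the token count are illustrative assumptions.

// Hypothetical usage sketch; include paths are assumed, not verified against the repository.
#include <ac/llama/Instance.hpp>
#include <ac/llama/Model.hpp>
#include <string>

// Generate a handful of tokens constrained to digits and whitespace.
std::string generateDigits(ac::llama::Model& model) {
    ac::llama::Instance::InitParams iparams;
    // The grammar string is forwarded to the sampler by the Instance constructor
    // (see the Instance.cpp hunk above).
    iparams.grammar = R""""(
root  ::= [ \t\n]* digit+ ([ \t\n] digit+)*
digit ::= [0-9]
    )"""";

    ac::llama::Instance inst(model, iparams);
    auto& session = inst.startSession({});
    session.setInitialPrompt(model.vocab().tokenize("The answer is", true, true));

    std::string text;
    for (int i = 0; i < 4; ++i) {
        auto t = session.getToken();
        if (t == ac::llama::Token_Invalid) break; // generation ended early
        text += model.vocab().tokenToString(t);
    }
    return text;
}

As far as this patch shows, the grammar is fixed at construction time, so switching to a different grammar means constructing a new Instance, as the test subcases above do.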