diff --git a/ac-local-plugin/code/LocalLlama.cpp b/ac-local-plugin/code/LocalLlama.cpp index cd971eb..812b266 100644 --- a/ac-local-plugin/code/LocalLlama.cpp +++ b/ac-local-plugin/code/LocalLlama.cpp @@ -170,7 +170,8 @@ class LlamaInstance final : public Instance { } auto tokenStr = model.vocab().tokenToString(t); - if (antiprompt.feedGeneratedText(tokenStr)) { + auto matchedAntiPrompt = antiprompt.feedGeneratedText(tokenStr); + if (!matchedAntiPrompt.empty()) { break; } diff --git a/code/ac/llama/AntipromptManager.cpp b/code/ac/llama/AntipromptManager.cpp index ba5359d..408f9a6 100644 --- a/code/ac/llama/AntipromptManager.cpp +++ b/code/ac/llama/AntipromptManager.cpp @@ -9,15 +9,18 @@ void AntipromptManager::addAntiprompt(std::string_view antiprompt) { m_antiprompts.push_back(std::string(antiprompt)); } -bool AntipromptManager::feedGeneratedText(std::string_view text) { +std::string AntipromptManager::feedGeneratedText(std::string_view text) { for (auto& ap : m_antiprompts) { - if (ap.feedText(text)) { + int found = ap.feedText(text); + if (found > 0) { reset(); - return true; + return found == 0 ? + ap.getString(): + ap.getString() + std::string(text.substr(found, text.length())); } } - return false; + return {}; } void AntipromptManager::reset() { diff --git a/code/ac/llama/AntipromptManager.hpp b/code/ac/llama/AntipromptManager.hpp index bd2ad96..9690c7f 100644 --- a/code/ac/llama/AntipromptManager.hpp +++ b/code/ac/llama/AntipromptManager.hpp @@ -19,7 +19,7 @@ class AC_LLAMA_EXPORT AntipromptManager { void addAntiprompt(std::string_view antiprompt); // feed each antiprompt with the text - bool feedGeneratedText(std::string_view text); + std::string feedGeneratedText(std::string_view text); // reset the state of all antiprompts void reset(); diff --git a/code/ac/llama/IncrementalStringFinder.cpp b/code/ac/llama/IncrementalStringFinder.cpp index 44c8a70..bbc47c3 100644 --- a/code/ac/llama/IncrementalStringFinder.cpp +++ b/code/ac/llama/IncrementalStringFinder.cpp @@ -10,9 +10,9 @@ IncrementalStringFinder::IncrementalStringFinder(std::string searchStr) , m_currentPos(0) {} -bool IncrementalStringFinder::feedText(std::string_view text) { +int IncrementalStringFinder::feedText(std::string_view text) { if (m_searchStr.length() == 0) { - return false; + return -1; } uint32_t promptPos = 0; @@ -32,10 +32,10 @@ bool IncrementalStringFinder::feedText(std::string_view text) { if (m_currentPos == m_searchStr.length()) { m_currentPos = 0; - return true; + return promptPos; } - return false; + return -1; } void IncrementalStringFinder::reset() { diff --git a/code/ac/llama/IncrementalStringFinder.hpp b/code/ac/llama/IncrementalStringFinder.hpp index 780adc6..689ae7a 100644 --- a/code/ac/llama/IncrementalStringFinder.hpp +++ b/code/ac/llama/IncrementalStringFinder.hpp @@ -11,14 +11,19 @@ namespace ac::llama { class AC_LLAMA_EXPORT IncrementalStringFinder { public: - IncrementalStringFinder(std::string searchStr = ""); + IncrementalStringFinder(std::string searchStr); // incremental search for `m_str` in `text` - bool feedText(std::string_view text); + // returns -1 if the search string was not found + // returns >=0 the search string was found and the count of matched characters of last feed + int feedText(std::string_view text); // reset the `currentPos` void reset(); + // return the string that was searched for + const std::string& getString() const { return m_searchStr; } + private: std::string m_searchStr; uint16_t m_currentPos; diff --git a/example/e-gui.cpp b/example/e-gui.cpp index 83deba1..7953c80 100644 --- a/example/e-gui.cpp +++ b/example/e-gui.cpp @@ -106,8 +106,8 @@ class UModel { auto tokenStr = m_vocab.tokenToString(token); m_text += tokenStr; - - if (m_antiprompt.feedGeneratedText(tokenStr)) { + auto matchedAntiPrompt = m_antiprompt.feedGeneratedText(tokenStr); + if (!matchedAntiPrompt.empty()) { m_numTokens = 0; return; } diff --git a/test/t-Antiprompt.cpp b/test/t-Antiprompt.cpp index 617f813..9c9fa13 100644 --- a/test/t-Antiprompt.cpp +++ b/test/t-Antiprompt.cpp @@ -9,58 +9,58 @@ TEST_CASE("incremental finder - empty") { // by default string finder has empty search string // and will always return false - ac::llama::IncrementalStringFinder f; - CHECK(f.feedText("") == false); - CHECK(f.feedText("empty") == false); + ac::llama::IncrementalStringFinder f(""); + CHECK(f.feedText("") == -1); + CHECK(f.feedText("empty") == -1); f = ac::llama::IncrementalStringFinder("demo"); // empty feed - CHECK(f.feedText("") == false); + CHECK(f.feedText("") == -1); } TEST_CASE("incremental finder - partial match") { ac::llama::IncrementalStringFinder f("demo"); - CHECK_FALSE(f.feedText("de")); - CHECK(f.feedText("mo")); + CHECK(f.feedText("de") == -1); + CHECK(f.feedText("mo") == 2); f = ac::llama::IncrementalStringFinder("the"); // no match - CHECK_FALSE(f.feedText("empty")); + CHECK(f.feedText("empty") == -1); // complex partial match - CHECK_FALSE(f.feedText("emptyth")); // last 2 are 'th' - CHECK(f.feedText("ehooooo")); // + 'e' from the start + CHECK(f.feedText("emptyth") == -1); // last 2 are 'th' + CHECK(f.feedText("ehooooo") == 1); // + 'e' from the start } TEST_CASE("incremental finder - substring") { ac::llama::IncrementalStringFinder f("demo"); // complex substring - CHECK_FALSE(f.feedText("dede")); // will find only 2 - CHECK(f.feedText("demo2")); // has the contaning string + CHECK(f.feedText("dede") == -1); // will find only 2 + CHECK(f.feedText("demo2") == 4); // has the contaning string } TEST_CASE("incremental finder - case sensitivity") { // case sensitivity ac::llama::IncrementalStringFinder f("The"); - CHECK_FALSE(f.feedText("the")); + CHECK_FALSE(f.feedText("the") == 3); } TEST_CASE("antiprompt manager - empty") { ac::llama::AntipromptManager am; am.addAntiprompt(""); - CHECK_FALSE(am.feedGeneratedText("empty")); + CHECK(am.feedGeneratedText("empty").empty()); am.addAntiprompt("user:"); - CHECK_FALSE(am.feedGeneratedText("")); + CHECK(am.feedGeneratedText("").empty()); } TEST_CASE("antiprompt manager - detect") { ac::llama::AntipromptManager am; am.addAntiprompt("exit"); am.addAntiprompt("quit"); - CHECK_FALSE(am.feedGeneratedText("please continue")); - CHECK(am.feedGeneratedText("please exit!")); - CHECK(am.feedGeneratedText("please quit now!")); + CHECK(am.feedGeneratedText("please continue").empty()); + CHECK(am.feedGeneratedText("please exit!") == "exit!"); + CHECK(am.feedGeneratedText("please quit now!") == "quit now!"); } TEST_CASE("antiprompt manager - incremental feed") { @@ -68,23 +68,23 @@ TEST_CASE("antiprompt manager - incremental feed") { am.addAntiprompt("downstream"); am.addAntiprompt("shutdown"); - CHECK_FALSE(am.feedGeneratedText("shut")); // Partial match, so false - CHECK(am.feedGeneratedText("down")); // Completes the match, so true + CHECK(am.feedGeneratedText("shut").empty()); // Partial match, so false + CHECK(am.feedGeneratedText("down") == "shutdown"); // Completes the match, so true - CHECK_FALSE(am.feedGeneratedText("stream")); // state should be reset after match + CHECK(am.feedGeneratedText("stream").empty()); // state should be reset after match } TEST_CASE("antiprompt manager - reset/clear") { ac::llama::AntipromptManager am; am.addAntiprompt("cancel"); - CHECK_FALSE(am.feedGeneratedText("cance")); // Partial match, so false + CHECK(am.feedGeneratedText("cance").empty()); // Partial match, so false am.reset(); // Reset the manager's antiprompts state - CHECK(am.feedGeneratedText("cancel")); // Should match, since the state was reset + CHECK(am.feedGeneratedText("cancel") == "cancel"); // Should match, since the state was reset am.clear(); // Clear the manager's antiprompts - CHECK_FALSE(am.feedGeneratedText("cancel")); // Should match, since the state was reset + CHECK(am.feedGeneratedText("cancel").empty()); // Should match, since the prompts are gone am.addAntiprompt("cancel");// add the antiprompt again - CHECK(am.feedGeneratedText("cancel")); + CHECK(am.feedGeneratedText("cancel!") == "cancel!"); }