Skip to content

Commit

Permalink
feat: get matched antiprompt to prepare data for backtracking, ref #3
Browse files Browse the repository at this point in the history
  • Loading branch information
pminev committed Jan 17, 2025
1 parent 060bb7a commit 2a5e664
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 38 deletions.
3 changes: 2 additions & 1 deletion ac-local-plugin/code/LocalLlama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ class LlamaInstance final : public Instance {
}

auto tokenStr = model.vocab().tokenToString(t);
if (antiprompt.feedGeneratedText(tokenStr)) {
auto matchedAntiPrompt = antiprompt.feedGeneratedText(tokenStr);
if (!matchedAntiPrompt.empty()) {
break;
}

Expand Down
11 changes: 7 additions & 4 deletions code/ac/llama/AntipromptManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@ void AntipromptManager::addAntiprompt(std::string_view antiprompt) {
m_antiprompts.push_back(std::string(antiprompt));
}

bool AntipromptManager::feedGeneratedText(std::string_view text) {
std::string AntipromptManager::feedGeneratedText(std::string_view text) {
for (auto& ap : m_antiprompts) {
if (ap.feedText(text)) {
int found = ap.feedText(text);
if (found > 0) {
reset();
return true;
return found == 0 ?
ap.getString():
ap.getString() + std::string(text.substr(found, text.length()));
}
}

return false;
return {};
}

void AntipromptManager::reset() {
Expand Down
2 changes: 1 addition & 1 deletion code/ac/llama/AntipromptManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class AC_LLAMA_EXPORT AntipromptManager {
void addAntiprompt(std::string_view antiprompt);

// feed each antiprompt with the text
bool feedGeneratedText(std::string_view text);
std::string feedGeneratedText(std::string_view text);

// reset the state of all antiprompts
void reset();
Expand Down
8 changes: 4 additions & 4 deletions code/ac/llama/IncrementalStringFinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ IncrementalStringFinder::IncrementalStringFinder(std::string searchStr)
, m_currentPos(0)
{}

bool IncrementalStringFinder::feedText(std::string_view text) {
int IncrementalStringFinder::feedText(std::string_view text) {
if (m_searchStr.length() == 0) {
return false;
return -1;
}

uint32_t promptPos = 0;
Expand All @@ -32,10 +32,10 @@ bool IncrementalStringFinder::feedText(std::string_view text) {

if (m_currentPos == m_searchStr.length()) {
m_currentPos = 0;
return true;
return promptPos;
}

return false;
return -1;
}

void IncrementalStringFinder::reset() {
Expand Down
9 changes: 7 additions & 2 deletions code/ac/llama/IncrementalStringFinder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@
namespace ac::llama {
class AC_LLAMA_EXPORT IncrementalStringFinder {
public:
IncrementalStringFinder(std::string searchStr = "");
IncrementalStringFinder(std::string searchStr);

// incremental search for `m_str` in `text`
bool feedText(std::string_view text);
// returns -1 if the search string was not found
// returns >=0 the search string was found and the count of matched characters of last feed
int feedText(std::string_view text);

// reset the `currentPos`
void reset();

// return the string that was searched for
const std::string& getString() const { return m_searchStr; }

private:
std::string m_searchStr;
uint16_t m_currentPos;
Expand Down
4 changes: 2 additions & 2 deletions example/e-gui.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ class UModel {

auto tokenStr = m_vocab.tokenToString(token);
m_text += tokenStr;

if (m_antiprompt.feedGeneratedText(tokenStr)) {
auto matchedAntiPrompt = m_antiprompt.feedGeneratedText(tokenStr);
if (!matchedAntiPrompt.empty()) {
m_numTokens = 0;
return;
}
Expand Down
48 changes: 24 additions & 24 deletions test/t-Antiprompt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,82 +9,82 @@
TEST_CASE("incremental finder - empty") {
// by default string finder has empty search string
// and will always return false
ac::llama::IncrementalStringFinder f;
CHECK(f.feedText("") == false);
CHECK(f.feedText("empty") == false);
ac::llama::IncrementalStringFinder f("");
CHECK(f.feedText("") == -1);
CHECK(f.feedText("empty") == -1);

f = ac::llama::IncrementalStringFinder("demo");
// empty feed
CHECK(f.feedText("") == false);
CHECK(f.feedText("") == -1);
}

TEST_CASE("incremental finder - partial match") {
ac::llama::IncrementalStringFinder f("demo");
CHECK_FALSE(f.feedText("de"));
CHECK(f.feedText("mo"));
CHECK(f.feedText("de") == -1);
CHECK(f.feedText("mo") == 2);

f = ac::llama::IncrementalStringFinder("the");
// no match
CHECK_FALSE(f.feedText("empty"));
CHECK(f.feedText("empty") == -1);

// complex partial match
CHECK_FALSE(f.feedText("emptyth")); // last 2 are 'th'
CHECK(f.feedText("ehooooo")); // + 'e' from the start
CHECK(f.feedText("emptyth") == -1); // last 2 are 'th'
CHECK(f.feedText("ehooooo") == 1); // + 'e' from the start
}

TEST_CASE("incremental finder - substring") {
ac::llama::IncrementalStringFinder f("demo");
// complex substring
CHECK_FALSE(f.feedText("dede")); // will find only 2
CHECK(f.feedText("demo2")); // has the contaning string
CHECK(f.feedText("dede") == -1); // will find only 2
CHECK(f.feedText("demo2") == 4); // has the contaning string
}

TEST_CASE("incremental finder - case sensitivity") {
// case sensitivity
ac::llama::IncrementalStringFinder f("The");
CHECK_FALSE(f.feedText("the"));
CHECK_FALSE(f.feedText("the") == 3);
}

TEST_CASE("antiprompt manager - empty") {
ac::llama::AntipromptManager am;
am.addAntiprompt("");
CHECK_FALSE(am.feedGeneratedText("empty"));
CHECK(am.feedGeneratedText("empty").empty());

am.addAntiprompt("user:");
CHECK_FALSE(am.feedGeneratedText(""));
CHECK(am.feedGeneratedText("").empty());
}

TEST_CASE("antiprompt manager - detect") {
ac::llama::AntipromptManager am;
am.addAntiprompt("exit");
am.addAntiprompt("quit");
CHECK_FALSE(am.feedGeneratedText("please continue"));
CHECK(am.feedGeneratedText("please exit!"));
CHECK(am.feedGeneratedText("please quit now!"));
CHECK(am.feedGeneratedText("please continue").empty());
CHECK(am.feedGeneratedText("please exit!") == "exit!");
CHECK(am.feedGeneratedText("please quit now!") == "quit now!");
}

TEST_CASE("antiprompt manager - incremental feed") {
ac::llama::AntipromptManager am;
am.addAntiprompt("downstream");
am.addAntiprompt("shutdown");

CHECK_FALSE(am.feedGeneratedText("shut")); // Partial match, so false
CHECK(am.feedGeneratedText("down")); // Completes the match, so true
CHECK(am.feedGeneratedText("shut").empty()); // Partial match, so false
CHECK(am.feedGeneratedText("down") == "shutdown"); // Completes the match, so true

CHECK_FALSE(am.feedGeneratedText("stream")); // state should be reset after match
CHECK(am.feedGeneratedText("stream").empty()); // state should be reset after match
}

TEST_CASE("antiprompt manager - reset/clear") {
ac::llama::AntipromptManager am;
am.addAntiprompt("cancel");

CHECK_FALSE(am.feedGeneratedText("cance")); // Partial match, so false
CHECK(am.feedGeneratedText("cance").empty()); // Partial match, so false
am.reset(); // Reset the manager's antiprompts state
CHECK(am.feedGeneratedText("cancel")); // Should match, since the state was reset
CHECK(am.feedGeneratedText("cancel") == "cancel"); // Should match, since the state was reset

am.clear(); // Clear the manager's antiprompts
CHECK_FALSE(am.feedGeneratedText("cancel")); // Should match, since the state was reset
CHECK(am.feedGeneratedText("cancel").empty()); // Should match, since the prompts are gone

am.addAntiprompt("cancel");// add the antiprompt again
CHECK(am.feedGeneratedText("cancel"));
CHECK(am.feedGeneratedText("cancel!") == "cancel!");
}

0 comments on commit 2a5e664

Please sign in to comment.