class ByteTrie:
    """A python implementation mirroring the C++ ByteTrie class.

    Each node maps single-byte keys (length-1 slices of the inserted
    sequences) to child nodes. A node at which an inserted sequence ends
    stores that sequence's value in ``value``; ``-1`` means no sequence
    ends at this node.
    """

    def __init__(self, byte_strings=None, values=None, parent=None):
        """Create a node and optionally populate the trie beneath it.

        Args:
            byte_strings: Optional iterable of sliceable sequences to insert.
            values: Optional indexable of values, looked up by the position
                of each entry in ``byte_strings``. When omitted, every
                sequence is inserted with the value ``0``.
            parent: The node this one hangs from (``None`` for a root).
        """
        self._parent = parent
        # Matching bookkeeping fields; no method of this class reads or
        # updates them -- presumably they are maintained by callers.
        self.match_version = -1
        self.match = False
        self.partial_match = False
        self.prob = 0.0
        self.value = -1
        self.children = {}

        if byte_strings is not None:
            if values is None:
                for s in byte_strings:
                    self.insert(s, 0)
            else:
                # Index lookup (rather than zip) so that a too-short
                # ``values`` still raises instead of silently truncating.
                for i, s in enumerate(byte_strings):
                    self.insert(s, values[i])

    def keys(self):
        """Return a view of the keys that lead to this node's children."""
        return self.children.keys()

    def has_child(self, byte):
        """Return True if ``byte`` leads to a direct child of this node."""
        return byte in self.children

    def child(self, byte):
        """Return the child reached via ``byte``.

        Raises:
            KeyError: If this node has no such child.
        """
        return self.children[byte]

    def parent(self):
        """Return the parent node given at construction (may be ``None``)."""
        return self._parent

    def size(self):
        """Return the number of direct children of this node."""
        return len(self.children)

    def __len__(self):
        """Return ``self.size()``."""
        return self.size()

    def insert(self, s, value, pos=0):
        """Insert ``s[pos:]`` below this node and record ``value`` at its end.

        The walk is iterative so that long sequences cannot exhaust the
        interpreter's recursion limit.

        Args:
            s: The sequence to insert; it is consumed one length-1 slice at
                a time.
            value: Value stored on the terminal node. It is only stored when
                that node's current value is negative, so the first
                non-negative value inserted for a given sequence is kept.
            pos: Index in ``s`` at which insertion starts.
        """
        node = self
        for i in range(pos, len(s)):
            # Slice instead of indexing so bytes keys stay bytes (not ints).
            key = s[i:i + 1]
            nxt = node.children.get(key)
            if nxt is None:
                nxt = ByteTrie(parent=node)
                node.children[key] = nxt
            node = nxt

        if node.value < 0:
            node.value = value

    def compute_probs(self, probs):
        """Set ``prob`` on this node and on every node beneath it.

        A node's ``prob`` is ``probs[value]`` (when ``value`` is not ``-1``)
        plus the ``prob`` of each of its children.

        Args:
            probs: Indexable giving a probability for each stored value.
        """
        # Gather the subtree parent-first without recursion, so that deep
        # tries cannot exhaust the interpreter's recursion limit.
        ordered = []
        pending = [self]
        while pending:
            node = pending.pop()
            ordered.append(node)
            pending.extend(node.children.values())

        # Walking that list backwards visits every descendant before its
        # ancestors, so each child's prob is final when its parent sums it.
        for node in reversed(ordered):
            total = 0.0
            if node.value != -1:
                total += probs[node.value]
            for sub in node.children.values():
                total += sub.prob
            node.prob = total
- # ''' - # if state_pos is None: - # state_pos = self.state_set_pos - # earliest_pos = 10000000000 - # for item in self.state_sets[state_pos]: - - # if item.pos > 0: - # # check for hidden nodes - # if item.node.hidden and item.start < earliest_pos: - # earliest_pos = item.start - - # # check for nodes that are not hidden but end with a hidden terminal (terminal won't be in our state set by themselves, so we need this check) - # else: - # last_value = item.values[item.pos-1] - # if isinstance(last_value, Terminal) and last_value.hidden and state_pos - len(last_value) < earliest_pos: - # earliest_pos = state_pos - len(last_value) - - # return earliest_pos - def matched(self): '''Checks if the parser has completely matched the grammar.''' if self.shadow_pos != self.state_set_pos: diff --git a/guidance/models/_model.py b/guidance/models/_model.py index a2dedad83..1b366118e 100644 --- a/guidance/models/_model.py +++ b/guidance/models/_model.py @@ -14,7 +14,7 @@ import time import numpy as np import logging -from .. import cpp +from .. import cpp as cpp from .._utils import ByteTrie, log_softmax, softmax from .._parser import EarleyCommitParser from .._grammar import StatelessFunction, string, _call_pool, _tag_pattern, Null, replace_model_variables, unreplace_model_variables, select, Terminal @@ -869,7 +869,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e if parser.matched(): break # if we already have a full match we don't try more tokens we just give up as soon as the model deviates from the grammar - # if we just collpased a hidden commit point then we start over looking for a new token + # if we just collapsed a hidden commit point then we start over looking for a new token if retry_token_gen: continue