Skip to content

Commit

Permalink
Python mirror of c++ byte trie
Browse files Browse the repository at this point in the history
  • Loading branch information
slundberg committed Dec 5, 2023
1 parent dca17db commit bcc64ce
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 27 deletions.
1 change: 1 addition & 0 deletions guidance/_cpp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .byte_trie import ByteTrie
2 changes: 1 addition & 1 deletion guidance/_cpp/byte_trie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class ByteTrie : public std::enable_shared_from_this<ByteTrie> { // enable_share
bool partial_match = false;
double prob = 0;
int value = -1;
std::map<char, std::shared_ptr<ByteTrie>> children;
std::unordered_map<char, std::shared_ptr<ByteTrie>> children;

ByteTrie(std::vector<std::string> byte_strings) {
for (size_t i = 0; i < byte_strings.size(); ++i) {
Expand Down
59 changes: 59 additions & 0 deletions guidance/_cpp/byte_trie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@


class ByteTrie:
"""A python implementation mirroring the C++ ByteTrie class."""
def __init__(self, byte_strings=None, values=None, parent=None):
self._parent = parent
self.match_version = -1
self.match = False
self.partial_match = False
self.prob = 0
self.value = -1
self.children = {}

if byte_strings is not None:
if values is None:
for s in byte_strings:
self.insert(s, 0)
else:
for i,s in enumerate(byte_strings):
self.insert(s, values[i])

def keys(self):
return self.children.keys()

def has_child(self, byte):
return byte in self.children

def child(self, byte):
return self.children[byte]

def parent(self):
return self._parent

def size(self):
return len(self.children)
def __len__(self):
return self.size()

def insert(self, s, value, pos=0):
if len(s) <= pos:
if self.value < 0:
self.value = value
else:
first_byte = s[pos:pos+1]
if first_byte not in self.children:
self.children[first_byte] = ByteTrie(parent=self)
self.children[first_byte].insert(s, value, pos + 1)

def compute_probs(self, probs):
self.prob = 0.0

if self.value != -1:
self.prob += probs[self.value]

if self.children:
for k in self.children:
child = self.children[k]
child.compute_probs(probs)
self.prob += child.prob
2 changes: 1 addition & 1 deletion guidance/_cpp/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ PYBIND11_MODULE(cpp, m) {
.def("has_child", &ByteTrie::has_child)
.def("child", &ByteTrie::child)
.def("parent", &ByteTrie::parent)
.def("__len__", &ByteTrie::size)
.def("__len__", &ByteTrie::size)
.def("keys", [](const ByteTrie& self) {
auto byte_strings = self.keys();
py::list py_byte_strings;
Expand Down
23 changes: 0 additions & 23 deletions guidance/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,29 +191,6 @@ def earliest_hidden_start(self, state_pos=None):
earliest_pos = min(earliest_pos, item.hidden_start)
return earliest_pos

# def earliest_hidden_start(self, state_pos=None):
# '''The earliest that a hidden node might match.

# This is useful because it tells us which bytes may end being hidden.
# '''
# if state_pos is None:
# state_pos = self.state_set_pos
# earliest_pos = 10000000000
# for item in self.state_sets[state_pos]:

# if item.pos > 0:
# # check for hidden nodes
# if item.node.hidden and item.start < earliest_pos:
# earliest_pos = item.start

# # check for nodes that are not hidden but end with a hidden terminal (terminal won't be in our state set by themselves, so we need this check)
# else:
# last_value = item.values[item.pos-1]
# if isinstance(last_value, Terminal) and last_value.hidden and state_pos - len(last_value) < earliest_pos:
# earliest_pos = state_pos - len(last_value)

# return earliest_pos

def matched(self):
'''Checks if the parser has completely matched the grammar.'''
if self.shadow_pos != self.state_set_pos:
Expand Down
4 changes: 2 additions & 2 deletions guidance/models/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import time
import numpy as np
import logging
from .. import cpp
from .. import cpp as cpp
from .._utils import ByteTrie, log_softmax, softmax
from .._parser import EarleyCommitParser
from .._grammar import StatelessFunction, string, _call_pool, _tag_pattern, Null, replace_model_variables, unreplace_model_variables, select, Terminal
Expand Down Expand Up @@ -869,7 +869,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
if parser.matched():
break # if we already have a full match we don't try more tokens we just give up as soon as the model deviates from the grammar

# if we just collpased a hidden commit point then we start over looking for a new token
# if we just collapased a hidden commit point then we start over looking for a new token
if retry_token_gen:
continue

Expand Down

0 comments on commit bcc64ce

Please sign in to comment.