Fix prob summing calculations
slundberg committed Dec 4, 2023
1 parent c1a896b commit 9224a35
Showing 4 changed files with 45 additions and 37 deletions.
26 changes: 8 additions & 18 deletions guidance/_cpp/byte_trie.cpp
@@ -8,7 +8,7 @@ class ByteTrie : public std::enable_shared_from_this<ByteTrie> { // enable_share
     int match_version = -1;
     bool match = false;
     bool partial_match = false;
-    double log_prob = 0;
+    double prob = 0;
     int value = -1;
     std::map<char, std::shared_ptr<ByteTrie>> children;

@@ -52,7 +52,9 @@ class ByteTrie : public std::enable_shared_from_this<ByteTrie> { // enable_share

     void insert(const std::string& s, int value, uint pos = 0) {
         if (s.size() <= pos) {
-            this->value = value;
+            if (this->value < 0) {
+                this->value = value;
+            }
         } else {
             uint8_t first_byte = s[pos];
             if (children.find(first_byte) == children.end()) {
@@ -62,28 +64,16 @@ class ByteTrie : public std::enable_shared_from_this<ByteTrie> { // enable_share
         }
     }

-    void compute_log_probs(const std::vector<double>& log_probs) {
+    void compute_probs(const std::vector<double>& probs) {
         if (value != -1) {
-            log_prob += log_probs[value];
+            prob += probs[value];
         }

         if (!children.empty()) {
-            double max_log_prob = -std::numeric_limits<double>::infinity();
-            std::vector<double> child_log_probs;
-
             for (auto& pair : children) {
-                pair.second->compute_log_probs(log_probs);
-                child_log_probs.push_back(pair.second->log_prob);
-                if (pair.second->log_prob > max_log_prob) {
-                    max_log_prob = pair.second->log_prob;
-                }
+                pair.second->compute_probs(probs);
+                prob += pair.second->prob;
             }
-
-            double sum_exp = 0;
-            for (double child_log_prob : child_log_probs) {
-                sum_exp += std::exp(child_log_prob - max_log_prob);
-            }
-            log_prob = max_log_prob + std::log(sum_exp);
         }
     }
 };
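The net effect of the byte_trie.cpp change: each trie node now accumulates raw probability mass instead of log probabilities, so a node's prob is its own token's probability plus the sum of its children's, and the old log-sum-exp reduction disappears. A minimal Python sketch of the same recursion, using a toy stand-in class rather than the real pybind11 ByteTrie:

    import numpy as np

    class Node:
        '''Toy stand-in for ByteTrie: value is a token id, or -1 for internal nodes.'''
        def __init__(self, value=-1):
            self.value = value
            self.prob = 0.0
            self.children = {}

    def compute_probs(node, probs):
        # a node with a token id contributes that token's probability...
        if node.value != -1:
            node.prob += probs[node.value]
        # ...and every node also absorbs the total mass of its subtree
        for child in node.children.values():
            compute_probs(child, probs)
            node.prob += child.prob

    # tokens b"a" (id 0) and b"ab" (id 1): the "a" node carries both
    root = Node()
    a = root.children[ord("a")] = Node(value=0)
    a.children[ord("b")] = Node(value=1)
    compute_probs(root, np.array([0.25, 0.75]))
    print(a.prob)  # 1.0 -- the mass of every token starting with b"a"

Summing raw probabilities is what the rest of the commit expects: _model.py below reads node.prob directly and only takes a log at the point where it needs a per-byte log-prob delta.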
6 changes: 3 additions & 3 deletions guidance/_cpp/main.cpp
@@ -5,7 +5,7 @@
 // #include <pybind11/functional.h>
 // #include <pybind11/chrono.h>
 #include <any>
-#include "byte_trie.cpp"
+#include "byte_trie.cpp"

 namespace py = pybind11;

@@ -28,11 +28,11 @@ PYBIND11_MODULE(cpp, m) {
             }
             return py_byte_strings;
         })
-        .def("compute_log_probs", &ByteTrie::compute_log_probs)
+        .def("compute_probs", &ByteTrie::compute_probs)
         .def_readwrite("match_version", &ByteTrie::match_version)
         .def_readwrite("match", &ByteTrie::match)
         .def_readwrite("partial_match", &ByteTrie::partial_match)
-        .def_readwrite("log_prob", &ByteTrie::log_prob)
+        .def_readwrite("prob", &ByteTrie::prob)
         .def_readwrite("value", &ByteTrie::value)
         .def_readwrite("children", &ByteTrie::children);
 }
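For orientation, a minimal sketch of how the renamed bindings are exercised from Python, assuming the compiled extension is importable (the import path and token values here are illustrative; the constructor call mirrors the one in guidance/models/_model.py):

    import numpy as np
    from guidance import cpp  # assumed import path for the compiled pybind11 module

    tokens = [b"a", b"ab", b"b"]
    trie = cpp.ByteTrie(tokens, np.arange(len(tokens)))
    probs = np.array([0.2, 0.5, 0.3])
    trie.compute_probs(probs)  # was trie.compute_log_probs(log_probs)
    print(trie.prob)           # was trie.log_prob; the root's mass sums to 1.0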
46 changes: 32 additions & 14 deletions guidance/models/_model.py
@@ -79,6 +79,15 @@ def __init__(self, tokens, bos_token_id=None, eos_token_id=None, echo=True, comp
         self._token_trie = cpp.ByteTrie(tokens, np.arange(len(tokens)))
         self._token_trie.match = True
         self._token_trie.match_version = 0
+
+        # track which tokens are duplicates
+        self.duplicate_tokens = []
+        found = {}
+        for i,t in enumerate(tokens):
+            if t in found:
+                self.duplicate_tokens.append((i, found[t]))
+            else:
+                found[t] = i

     @property
     def default_end_patterns(self):
@@ -541,6 +550,11 @@ def _cleanup_tokens(self, token_ids, token_byte_positions):

         return token_ids, token_byte_positions

+    def _clean_duplicate_tokens(self, probs):
+        '''This moves all the probability mass from duplicate positions onto their primary index.'''
+        for i,j in self.duplicate_tokens:
+            probs[j] += probs[i]
+            probs[i] = 0

     def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, ensure_bos_token=True):
         assert n == 1, "Still need to add support for n > 1!"
@@ -701,10 +715,11 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                 # TODO: we should lower this step to C++ with pybind11
                 if self.compute_log_probs:
                     if torch:
-                        log_probs = torch.nn.functional.log_softmax(torch.tensor(logits), dim=-1).cpu().numpy() # note we don't adjust for temp since we consider that a sampling step, not part of the probs
+                        probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1).cpu().numpy() # note we don't adjust for temp since we consider that a sampling step, not part of the probs
                     else:
-                        log_probs = log_softmax(logits, axis=-1) # this numpy code is slower, so we don't use it if we have torch...
-                    trie.compute_log_probs(log_probs)
+                        probs = softmax(logits, axis=-1) # this numpy code is slower, so we don't use it if we have torch...
+                    self._clean_duplicate_tokens(probs)
+                    trie.compute_probs(probs)

                 # get the sampling order
                 grammar_temp = parser.next_byte_temperature()
@@ -714,7 +729,9 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                 else:
                     assert top_p == 1, "Still need to add support for top_p!"
                     if torch:
-                        probs = torch.nn.functional.softmax(torch.tensor(logits) / current_temp, dim=-1)
+                        logits = torch.tensor(logits)
+                        torch.div(logits, current_temp, out=logits)
+                        probs = torch.nn.functional.softmax(logits, dim=-1)
                         sampling_order = torch.multinomial(probs, len(probs)).cpu().numpy()
                     else:
                         # this numpy version allows us to drop our dependence on pytorch...but it is way slower
@@ -760,7 +777,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                         token_pos += 1

                     # get the parser to consume the next byte
-                    log_prob_delta = next_node.log_prob - node.log_prob
+                    log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
                     new_bytes_log_prob += log_prob_delta
                     commit_point = parser.consume_byte(next_byte, log_prob=log_prob_delta)

@@ -975,18 +992,19 @@ def _record_captures(initial_item, data, log_prob_data, byte_data):

     used_names.add(cname)

-# def _compute_log_probs(trie, log_probs):
+# def _compute_probs(trie, probs, found):
 #     '''Computes the log probabilities for each internal trie node.'''
 #     if trie.value is not None:
-#         trie.log_prob += log_probs[trie.value]
+#         found[trie.value] = 1
+#         trie.prob += probs[trie.value]

-#     if len(trie.children) > 0:
-#         child_log_probs = []
-#         for b in trie.children:
-#             child = trie.children[b]
-#             _compute_log_probs(child, log_probs)
-#             child_log_probs.append(child.log_prob)
-#         trie.log_prob = np.logaddexp.reduce(child_log_probs)
+#     if len(trie) > 0:
+#         # child_probs = []
+#         for b in trie.keys():
+#             child = trie.child(b)
+#             _compute_probs(child, probs, found)
+#             trie.prob += child.prob
+#             # trie.log_prob = np.logaddexp.reduce(child_log_probs)

 def _check_dominated(node, parser, match_version, next_byte_mask):
     curr_pos = parser.pos
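Two details of the _model.py side are easy to miss. First, many tokenizers expose several token ids with identical byte strings; insert() in byte_trie.cpp now keeps only the first id per byte string, so _clean_duplicate_tokens folds the probability of every later duplicate onto that primary id before trie.compute_probs runs, ensuring no mass is silently dropped. A toy numpy example (hypothetical token list and probabilities) of the consolidation:

    import numpy as np

    tokens = [b"a", b"b", b"a"]   # id 2 is a duplicate of id 0
    duplicate_tokens = [(2, 0)]   # (duplicate index, primary index), as built in __init__
    probs = np.array([0.2, 0.3, 0.5])

    for i, j in duplicate_tokens:  # mirrors _clean_duplicate_tokens
        probs[j] += probs[i]
        probs[i] = 0

    print(probs)  # [0.7 0.3 0. ] -- all mass for b"a" sits on its primary id

Second, since nodes now store raw probabilities, the per-byte delta becomes np.log(next_node.prob) - np.log(node.prob), i.e. the log of the conditional probability of extending the current byte prefix.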
4 changes: 2 additions & 2 deletions setup.py
@@ -44,7 +44,6 @@ def find_version(*file_paths):
         "msal",
         "requests",
         "numpy",
-        "pybind11",
         "aiohttp",
         "ordered_set",
         "pyformlang"
@@ -56,7 +55,8 @@ def find_version(*file_paths):
         'numpydoc',
         'sphinx_rtd_theme',
         'sphinx',
-        'nbsphinx'
+        'nbsphinx',
+        "pybind11"
     ],
     'test': [
         'pytest',
