From 9224a3593ff426f3a6a2b9ce890c049bbf38ca93 Mon Sep 17 00:00:00 2001
From: Scott Lundberg
Date: Mon, 4 Dec 2023 21:50:28 +0000
Subject: [PATCH] Fix prob summing calculations

---
 guidance/_cpp/byte_trie.cpp | 27 +++++++++------------------
 guidance/_cpp/main.cpp      |  6 +++---
 guidance/models/_model.py   | 46 ++++++++++++++++++++++++++++--------------
 setup.py                    |  4 ++--
 4 files changed, 46 insertions(+), 37 deletions(-)

diff --git a/guidance/_cpp/byte_trie.cpp b/guidance/_cpp/byte_trie.cpp
index 76b3b8b86..242471431 100644
--- a/guidance/_cpp/byte_trie.cpp
+++ b/guidance/_cpp/byte_trie.cpp
@@ -8,7 +8,7 @@ class ByteTrie : public std::enable_shared_from_this<ByteTrie> { // enable_share
     int match_version = -1;
     bool match = false;
     bool partial_match = false;
-    double log_prob = 0;
+    double prob = 0;
     int value = -1;
     std::map<uint8_t, std::shared_ptr<ByteTrie>> children;
 
@@ -52,7 +52,9 @@ class ByteTrie : public std::enable_shared_from_this<ByteTrie> { // enable_share
 
     void insert(const std::string& s, int value, uint pos = 0) {
         if (s.size() <= pos) {
-            this->value = value;
+            if (this->value < 0) {
+                this->value = value;
+            }
         } else {
             uint8_t first_byte = s[pos];
             if (children.find(first_byte) == children.end()) {
@@ -62,28 +64,17 @@ class ByteTrie : public std::enable_shared_from_this<ByteTrie> { // enable_share
         }
     }
 
 
-    void compute_log_probs(const std::vector<double>& log_probs) {
+    void compute_probs(const std::vector<double>& probs) {
         if (value != -1) {
-            log_prob += log_probs[value];
+            prob += probs[value];
         }
 
         if (!children.empty()) {
-            double max_log_prob = -std::numeric_limits<double>::infinity();
-            std::vector<double> child_log_probs;
-            for (auto& pair : children) {
-                pair.second->compute_log_probs(log_probs);
-                child_log_probs.push_back(pair.second->log_prob);
-                if (pair.second->log_prob > max_log_prob) {
-                    max_log_prob = pair.second->log_prob;
-                }
-            }
-
-            double sum_exp = 0;
-            for (double child_log_prob : child_log_probs) {
-                sum_exp += std::exp(child_log_prob - max_log_prob);
+            for (auto& pair : children) {
+                pair.second->compute_probs(probs);
+                prob += pair.second->prob;
             }
-            log_prob = max_log_prob + std::log(sum_exp);
         }
     }
 };
\ No newline at end of file
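
The compute_probs() rewrite above is a correctness fix, not just a rename: the old code's closing assignment, log_prob = max_log_prob + std::log(sum_exp), replaced whatever mass the node had already accumulated for its own token, so a trie node that is both a complete token and a prefix of longer tokens lost its own probability. Summing in probability space keeps both contributions. A minimal numpy sketch of the difference, using made-up probabilities:

    import numpy as np

    # A trie node that is itself a token with probability 0.2 and whose
    # two child subtrees carry total probabilities 0.3 and 0.1.
    own_prob = 0.2
    child_probs = np.array([0.3, 0.1])

    # Old behavior: the logsumexp over the children *replaced* the node's
    # log probability, silently dropping the node's own 0.2 of mass.
    old = np.exp(np.logaddexp.reduce(np.log(child_probs)))
    print(old)  # 0.4

    # New behavior: accumulate the node's own mass plus the child mass.
    new = own_prob + child_probs.sum()
    print(new)  # 0.6
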
diff --git a/guidance/_cpp/main.cpp b/guidance/_cpp/main.cpp
index 857680103..923a86974 100644
--- a/guidance/_cpp/main.cpp
+++ b/guidance/_cpp/main.cpp
@@ -5,7 +5,7 @@
 // #include
 // #include
 #include
-#include "byte_trie.cpp"
+#include "byte_trie.cpp"
 
 namespace py = pybind11;
 
@@ -28,11 +28,11 @@ PYBIND11_MODULE(cpp, m) {
             }
             return py_byte_strings;
         })
-        .def("compute_log_probs", &ByteTrie::compute_log_probs)
+        .def("compute_probs", &ByteTrie::compute_probs)
         .def_readwrite("match_version", &ByteTrie::match_version)
         .def_readwrite("match", &ByteTrie::match)
         .def_readwrite("partial_match", &ByteTrie::partial_match)
-        .def_readwrite("log_prob", &ByteTrie::log_prob)
+        .def_readwrite("prob", &ByteTrie::prob)
         .def_readwrite("value", &ByteTrie::value)
         .def_readwrite("children", &ByteTrie::children);
 }
\ No newline at end of file
diff --git a/guidance/models/_model.py b/guidance/models/_model.py
index c32bed771..635f623d1 100644
--- a/guidance/models/_model.py
+++ b/guidance/models/_model.py
@@ -79,6 +79,15 @@ def __init__(self, tokens, bos_token_id=None, eos_token_id=None, echo=True, comp
         self._token_trie = cpp.ByteTrie(tokens, np.arange(len(tokens)))
         self._token_trie.match = True
         self._token_trie.match_version = 0
+
+        # track which tokens are duplicates
+        self.duplicate_tokens = []
+        found = {}
+        for i,t in enumerate(tokens):
+            if t in found:
+                self.duplicate_tokens.append((i, found[t]))
+            else:
+                found[t] = i
 
     @property
     def default_end_patterns(self):
@@ -541,6 +550,11 @@ def _cleanup_tokens(self, token_ids, token_byte_positions):
 
         return token_ids, token_byte_positions
 
+    def _clean_duplicate_tokens(self, probs):
+        '''This moves all the probability mass from duplicate positions onto their primary index.'''
+        for i,j in self.duplicate_tokens:
+            probs[j] += probs[i]
+            probs[i] = 0
 
     def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, ensure_bos_token=True):
         assert n == 1, "Still need to add support for n > 1!"
@@ -701,10 +715,11 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                 # TODO: we should lower this step to C++ with pybind11
                 if self.compute_log_probs:
                     if torch:
-                        log_probs = torch.nn.functional.log_softmax(torch.tensor(logits), dim=-1).cpu().numpy() # note we don't adjust for temp since we consider that a sampling step, not part of the probs
+                        probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1).cpu().numpy() # note we don't adjust for temp since we consider that a sampling step, not part of the probs
                     else:
-                        log_probs = log_softmax(logits, axis=-1) # this numpy code is slower, so we don't use it if we have torch...
-                    trie.compute_log_probs(log_probs)
+                        probs = softmax(logits, axis=-1) # this numpy code is slower, so we don't use it if we have torch...
+                    self._clean_duplicate_tokens(probs)
+                    trie.compute_probs(probs)
 
                 # get the sampling order
                 grammar_temp = parser.next_byte_temperature()
@@ -714,7 +729,9 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                 else:
                     assert top_p == 1, "Still need to add support for top_p!"
                     if torch:
-                        probs = torch.nn.functional.softmax(torch.tensor(logits) / current_temp, dim=-1)
+                        logits = torch.tensor(logits)
+                        torch.div(logits, current_temp, out=logits)
+                        probs = torch.nn.functional.softmax(logits, dim=-1)
                         sampling_order = torch.multinomial(probs, len(probs)).cpu().numpy()
                     else:
                         # this numpy version allows us to drop our dependence on pytorch...but it is way slower
@@ -760,7 +777,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                             token_pos += 1
 
                             # get the parser to consume the next byte
-                            log_prob_delta = next_node.log_prob - node.log_prob
+                            log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
                             new_bytes_log_prob += log_prob_delta
                             commit_point = parser.consume_byte(next_byte, log_prob=log_prob_delta)
 
@@ -975,18 +992,19 @@ def _record_captures(initial_item, data, log_prob_data, byte_data):
             used_names.add(cname)
 
 
-# def _compute_log_probs(trie, log_probs):
+# def _compute_probs(trie, probs, found):
 #     '''Computes the log probabilities for each internal trie node.'''
 #     if trie.value is not None:
-#         trie.log_prob += log_probs[trie.value]
+#         found[trie.value] = 1
+#         trie.prob += probs[trie.value]
 
-#     if len(trie.children) > 0:
-#         child_log_probs = []
-#         for b in trie.children:
-#             child = trie.children[b]
-#             _compute_log_probs(child, log_probs)
-#             child_log_probs.append(child.log_prob)
-#         trie.log_prob = np.logaddexp.reduce(child_log_probs)
+#     if len(trie) > 0:
+#         # child_probs = []
+#         for b in trie.keys():
+#             child = trie.child(b)
+#             _compute_probs(child, probs, found)
+#             trie.prob += child.prob
+#         # trie.log_prob = np.logaddexp.reduce(child_log_probs)
 
 def _check_dominated(node, parser, match_version, next_byte_mask):
     curr_pos = parser.pos
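
The duplicate-token bookkeeping above pairs with the insert() change in byte_trie.cpp: the trie now keeps only the first token id seen for a given byte string, so _clean_duplicate_tokens folds the probability mass of every later duplicate id onto that primary id before trie.compute_probs runs. A standalone sketch of the two steps together (the toy vocabulary and probabilities are illustrative, not from the patch):

    import numpy as np

    # Toy vocabulary in which ids 1 and 3 map to the same byte string;
    # id 1 is the primary index because it occurs first, mirroring insert().
    tokens = [b"a", b"bc", b"d", b"bc"]
    duplicate_tokens = []   # (duplicate id, primary id) pairs
    found = {}
    for i, t in enumerate(tokens):
        if t in found:
            duplicate_tokens.append((i, found[t]))
        else:
            found[t] = i

    # Fold each duplicate's mass onto its primary id, as
    # _clean_duplicate_tokens does, so the trie sees the full mass.
    probs = np.array([0.1, 0.3, 0.2, 0.4])
    for i, j in duplicate_tokens:
        probs[j] += probs[i]
        probs[i] = 0

    print(duplicate_tokens)  # [(3, 1)]
    print(probs)             # [0.1 0.7 0.2 0. ]
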
diff --git a/setup.py b/setup.py
index ef4d18fc8..594c5a30f 100644
--- a/setup.py
+++ b/setup.py
@@ -44,7 +44,6 @@ def find_version(*file_paths):
         "msal",
         "requests",
         "numpy",
-        "pybind11",
         "aiohttp",
         "ordered_set",
         "pyformlang"
@@ -56,7 +55,8 @@ def find_version(*file_paths):
             'numpydoc',
             'sphinx_rtd_theme',
             'sphinx',
-            'nbsphinx'
+            'nbsphinx',
+            "pybind11"
         ],
         'test': [
             'pytest',
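
For reference, the accumulation that the C++ compute_probs performs can be written out in pure Python over plain dict nodes (a sketch under assumed dict-based nodes, not the pybind11 ByteTrie interface): every node ends up holding the total probability of all tokens whose byte strings pass through it, which is what makes np.log(next_node.prob) - np.log(node.prob) above a valid per-byte transition log probability.

    import numpy as np

    def compute_probs(node, probs):
        # Reset so the sketch can be re-run; the node's prob becomes the
        # total mass of every token whose byte string passes through it.
        node["prob"] = 0.0
        if node["value"] != -1:
            node["prob"] += probs[node["value"]]
        for child in node["children"].values():
            compute_probs(child, probs)
            node["prob"] += child["prob"]

    # Two tokens: "a" (id 0) and "ab" (id 1); the node for "a" is both a
    # complete token and the parent of the node for "ab".
    node_ab = {"value": 1, "children": {}, "prob": 0.0}
    node_a = {"value": 0, "children": {ord("b"): node_ab}, "prob": 0.0}
    root = {"value": -1, "children": {ord("a"): node_a}, "prob": 0.0}

    compute_probs(root, np.array([0.25, 0.5]))
    print(node_a["prob"])  # 0.75: its own 0.25 plus 0.5 from "ab"
    print(root["prob"])    # 0.75
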