From a48717f14710ce553635f13263fbe594b0748007 Mon Sep 17 00:00:00 2001
From: Scott Lundberg
Date: Tue, 5 Dec 2023 00:05:38 +0000
Subject: [PATCH] fix log prob computation

---
 guidance/_cpp/byte_trie.cpp |  5 ++++-
 guidance/models/_model.py   | 27 ++++++++++++++++++---------
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/guidance/_cpp/byte_trie.cpp b/guidance/_cpp/byte_trie.cpp
index 460e8b630..fc3e85947 100644
--- a/guidance/_cpp/byte_trie.cpp
+++ b/guidance/_cpp/byte_trie.cpp
@@ -8,7 +8,7 @@ class ByteTrie : public std::enable_shared_from_this<ByteTrie> { // enable_share
     int match_version = -1;
     bool match = false;
     bool partial_match = false;
-    double prob = 1;
+    double prob = 0;
     int value = -1;
     std::map<char, std::shared_ptr<ByteTrie>> children;
 
@@ -64,7 +64,10 @@ class ByteTrie : public std::enable_shared_from_this<ByteTrie> { // enable_share
         }
     }
 
+    // we could save a lot of work if we assume the top node has prob 1.0 and then only explore the subtree we care about
     void compute_probs(const std::vector<float>& probs) {
+        prob = 0.0;
+
         if (value != -1) {
             prob += probs[value];
         }
diff --git a/guidance/models/_model.py b/guidance/models/_model.py
index 635f623d1..c86183210 100644
--- a/guidance/models/_model.py
+++ b/guidance/models/_model.py
@@ -413,7 +413,7 @@ def _run_stateless(lm, stateless_function, temperature=0.0, top_p=1.0, n=1):
     delayed_bytes = b""
     # last_is_generated = False
 
-    for new_bytes, is_generated, new_bytes_log_prob, capture_groups, capture_group_log_probs, new_token_count in gen_obj:
+    for new_bytes, is_generated, new_bytes_prob, capture_groups, capture_group_log_probs, new_token_count in gen_obj:
         # convert the bytes to a string (delaying if we don't yet have a valid unicode string)
         lm.token_count += new_token_count
         new_bytes = delayed_bytes + new_bytes
@@ -427,7 +427,7 @@ def _run_stateless(lm, stateless_function, temperature=0.0, top_p=1.0, n=1):
         if len(new_bytes) > 0:
             generated_value += new_text
             if is_generated:
-                lm += f"<||_html:<span style='background-color: rgba({165*(1-new_bytes_log_prob) + 0}, {165*new_bytes_log_prob + 0}, 0, {0.15}); border-radius: 3px;' title='{new_bytes_log_prob}'>_||>"
+                lm += f"<||_html:<span style='background-color: rgba({165*(1-new_bytes_prob) + 0}, {165*new_bytes_prob + 0}, 0, {0.15}); border-radius: 3px;' title='{new_bytes_prob}'>_||>"
             lm += new_text
             if is_generated:
                 lm += "<||_html:</span>_||>"
@@ -604,6 +604,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
             retry_token_gen = False
             trie = self._token_trie
             trie.match_version += 1 # this invalidates all the match caches from the previous token
+            # trie.prob = 0.0 # need to reset when we reset the match_version
             while True:
                 next_byte_mask = parser.next_byte_mask()
                 next_byte_mask_sum = next_byte_mask.sum()
@@ -630,6 +631,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                     # mark this trie node with an up-to-date match flag (may save work later)
                     node = trie.child(byte)
                     node.match_version = self._token_trie.match_version
+                    # node.prob = 0.0 # reset when we reset the match_version
                     node.match = next_byte_mask[byte[0]]
 
                     # see if we found a match
@@ -689,7 +691,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                 if is_forced:
                     sampled_token_ind = trie.value
                     sampled_token = self.tokens[sampled_token_ind]
-                    new_bytes_log_prob = 0.0
+                    new_bytes_prob = 1.0
                     was_forced = True
 
                 # we are at the end of the grammar
@@ -700,7 +702,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                     if trie != self._token_trie:
                         sampled_token_ind = trie.value
                         sampled_token = self.tokens[sampled_token_ind]
-                        new_bytes_log_prob = 0.0
+                        new_bytes_prob = 1.0
 
                 # otherwise we need to compute the logits and sample a valid token
                 else:
@@ -746,7 +748,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                         # make sure the parse is backed up to the position we want to start checking from
                         # TODO: make this account for shared prefixes with the last token
                         parser.pos = forced_pos
-                        new_bytes_log_prob = 0.0
+                        new_bytes_prob = 1.0
 
                         # make sure it matches any forced prefix
                         if start_pos < forced_pos and not sampled_token.startswith(parser.bytes[start_pos:forced_pos]):
@@ -777,8 +779,15 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                     token_pos += 1
 
                     # get the parser to consume the next byte
-                    log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
-                    new_bytes_log_prob += log_prob_delta
+                    if next_node.prob < 1e-8:
+                        if node.prob < 1e-8:
+                            log_prob_delta = 0
+                        else:
+                            log_prob_delta = -20
+                    else:
+                        log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
+                    # log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
+                    new_bytes_prob = node.prob
                     commit_point = parser.consume_byte(next_byte, log_prob=log_prob_delta)
 
                     # if we are at a hidden commit point then we need to hide the bytes that match that node
@@ -857,7 +866,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                         # we have no valid log prob data if we didn't compute it
                         if not self.compute_log_probs:
                             captured_log_prob_data = {k: None for k in captured_data}
-                        yield new_bytes[hidden_count:], not is_forced, new_bytes_log_prob, captured_data, captured_log_prob_data, token_count - last_token_count
+                        yield new_bytes[hidden_count:], not is_forced, new_bytes_prob, captured_data, captured_log_prob_data, token_count - last_token_count
                         last_token_count = token_count
                         break # we are done!
                     else:
@@ -866,7 +875,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                     # yeild the snippet of text created by the next token
                     out = new_bytes[hidden_count:]
                     if len(out) > 0:
-                        yield out, not is_forced, new_bytes_log_prob, {}, {}, token_count - last_token_count # note that we don't capture groups until a complete parse right now...
+                        yield out, not is_forced, new_bytes_prob, {}, {}, token_count - last_token_count # note that we don't capture groups until a complete parse right now...
                         last_token_count = token_count
                         hidden_count = 0
                         token_count += 1 # note we only update this for tokens that emit non-hidden content
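
Review notes (illustrative sketches, outside the patch itself):

The C++ change makes compute_probs() safe to call repeatedly: each node's
prob now starts at 0 and is reset at the top of every call, so recomputing
the trie for a new token distribution no longer accumulates stale mass
from the previous token. A minimal Python sketch of the same recursion
(the ByteTrieNode class here is hypothetical, not the real pybind11
binding):

    class ByteTrieNode:
        def __init__(self):
            self.value = -1      # token id if a token ends at this node, else -1
            self.prob = 0.0      # probability mass of all tokens in this subtree
            self.children = {}   # maps next byte -> ByteTrieNode

        def compute_probs(self, probs):
            # reset first: this mirrors the patch's `prob = 0.0;` fix
            self.prob = 0.0
            if self.value != -1:
                self.prob += probs[self.value]
            for child in self.children.values():
                child.compute_probs(probs)
                self.prob += child.prob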
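On the Python side, the per-byte delta passed to parser.consume_byte() is
now guarded against log(0): when a subtree carries (numerically) zero mass
the delta is floored at -20, and new_bytes_prob reports the matched node's
probability directly instead of accumulating a running log prob. A
standalone sketch of that guard, with the 1e-8 threshold and -20 floor
taken straight from the patch (byte_log_prob_delta is a hypothetical
helper name, not a function in the repo):

    import numpy as np

    def byte_log_prob_delta(next_node_prob, node_prob, eps=1e-8, floor=-20.0):
        # log P(next byte | current node) = log(next_node.prob) - log(node.prob),
        # clamped so zero-probability subtrees don't produce -inf
        if next_node_prob < eps:
            return 0.0 if node_prob < eps else floor
        return float(np.log(next_node_prob) - np.log(node_prob))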