Commit

fix log prob computation
slundberg committed Dec 5, 2023
1 parent f790e98 commit a48717f
Showing 2 changed files with 22 additions and 10 deletions.
guidance/_cpp/byte_trie.cpp (5 changes: 4 additions & 1 deletion)
@@ -8,7 +8,7 @@ class ByteTrie : public std::enable_shared_from_this<ByteTrie> { // enable_share
     int match_version = -1;
     bool match = false;
     bool partial_match = false;
-    double prob = 1;
+    double prob = 0;
     int value = -1;
     std::map<char, std::shared_ptr<ByteTrie>> children;

@@ -64,7 +64,10 @@ class ByteTrie : public std::enable_shared_from_this<ByteTrie> { // enable_share
         }
     }

+    // we could save a lot of work if we assume the top node has prob 1.0 and then only explore the subtree we care about
     void compute_probs(const std::vector<double>& probs) {
+        prob = 0.0;
+
         if (value != -1) {
             prob += probs[value];
         }
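Note on the C++ change above: the initializer now starts prob at 0 and compute_probs resets it on every call, so a node's probability mass is recomputed from scratch instead of accumulating on top of the previous value (which used to start at 1). Below is a minimal Python sketch of that accumulation; the recursion over children is an assumption, since only the reset and the value != -1 branch are visible in this diff, and the class is illustrative rather than the library's C++ code.

class ByteTrieSketch:
    def __init__(self):
        self.children = {}   # next byte -> child node
        self.value = -1      # token id that terminates at this node, or -1
        self.prob = 0.0      # total probability mass of tokens in this subtree

    def compute_probs(self, probs):
        self.prob = 0.0                      # reset: do not accumulate across calls
        if self.value != -1:
            self.prob += probs[self.value]   # mass of the token ending exactly here
        for child in self.children.values():
            child.compute_probs(probs)
            self.prob += child.prob          # fold each subtree's mass into the parent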
guidance/models/_model.py (27 changes: 18 additions & 9 deletions)
@@ -413,7 +413,7 @@ def _run_stateless(lm, stateless_function, temperature=0.0, top_p=1.0, n=1):

     delayed_bytes = b""
     # last_is_generated = False
-    for new_bytes, is_generated, new_bytes_log_prob, capture_groups, capture_group_log_probs, new_token_count in gen_obj:
+    for new_bytes, is_generated, new_bytes_prob, capture_groups, capture_group_log_probs, new_token_count in gen_obj:
         # convert the bytes to a string (delaying if we don't yet have a valid unicode string)
         lm.token_count += new_token_count
         new_bytes = delayed_bytes + new_bytes
@@ -427,7 +427,7 @@ def _run_stateless(lm, stateless_function, temperature=0.0, top_p=1.0, n=1):
         if len(new_bytes) > 0:
             generated_value += new_text
             if is_generated:
-                lm += f"<||_html:<span style='background-color: rgba(0, 165, 0, {0.15 + 0.4 * (1 - np.exp(new_bytes_log_prob))}); border-radius: 3px;' title='{new_bytes_log_prob}'>_||>"
+                lm += f"<||_html:<span style='background-color: rgba({165*(1-new_bytes_prob) + 0}, {165*new_bytes_prob + 0}, 0, {0.15}); border-radius: 3px;' title='{new_bytes_prob}'>_||>"
             lm += new_text
             if is_generated:
                 lm += "<||_html:</span>_||>"
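The rendering change above means the highlight now encodes the byte-level probability in the hue: the red and green channels are interpolated by new_bytes_prob at a fixed 0.15 alpha, instead of scaling the opacity of a fixed green by the exponentiated log prob. A small sketch of that mapping (the helper is hypothetical and just restates the new f-string):

def highlight_rgba(new_bytes_prob):
    # Low-probability bytes lean red, high-probability bytes lean green,
    # with the alpha fixed at 0.15 as in the f-string above.
    red = 165 * (1 - new_bytes_prob)
    green = 165 * new_bytes_prob
    return f"rgba({red}, {green}, 0, 0.15)"

print(highlight_rgba(1.0))  # rgba(0.0, 165.0, 0, 0.15)  -> pure green
print(highlight_rgba(0.5))  # rgba(82.5, 82.5, 0, 0.15)  -> even red/green mix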
@@ -604,6 +604,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
             retry_token_gen = False
             trie = self._token_trie
             trie.match_version += 1 # this invalidates all the match caches from the previous token
+            # trie.prob = 0.0 # need to reset when we reset the match_version
             while True:
                 next_byte_mask = parser.next_byte_mask()
                 next_byte_mask_sum = next_byte_mask.sum()
@@ -630,6 +631,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                 # mark this trie node with an up-to-date match flag (may save work later)
                 node = trie.child(byte)
                 node.match_version = self._token_trie.match_version
+                # node.prob = 0.0 # reset when we reset the match_version
                 node.match = next_byte_mask[byte[0]]

                 # see if we found a match
@@ -689,7 +691,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
             if is_forced:
                 sampled_token_ind = trie.value
                 sampled_token = self.tokens[sampled_token_ind]
-                new_bytes_log_prob = 0.0
+                new_bytes_prob = 1.0
                 was_forced = True

             # we are at the end of the grammar
@@ -700,7 +702,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                 if trie != self._token_trie:
                     sampled_token_ind = trie.value
                     sampled_token = self.tokens[sampled_token_ind]
-                    new_bytes_log_prob = 0.0
+                    new_bytes_prob = 1.0

             # otherwise we need to compute the logits and sample a valid token
             else:
@@ -746,7 +748,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e

                 # make sure the parse is backed up to the position we want to start checking from TODO: make this account for shared prefixes with the last token
                 parser.pos = forced_pos
-                new_bytes_log_prob = 0.0
+                new_bytes_prob = 1.0

                 # make sure it matches any forced prefix
                 if start_pos < forced_pos and not sampled_token.startswith(parser.bytes[start_pos:forced_pos]):
@@ -777,8 +779,15 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                     token_pos += 1

                 # get the parser to consume the next byte
-                log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
-                new_bytes_log_prob += log_prob_delta
+                if next_node.prob < 1e-8:
+                    if node.prob < 1e-8:
+                        log_prob_delta = 0
+                    else:
+                        log_prob_delta = -20
+                else:
+                    log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
+                # log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
+                new_bytes_prob = node.prob
                 commit_point = parser.consume_byte(next_byte, log_prob=log_prob_delta)

                 # if we are at a hidden commit point then we need to hide the bytes that match that node
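The branch added above guards the per-byte log-probability step, which is the difference of the logs of two trie-node masses: presumably because a node's mass can now be exactly zero under the new accounting, np.log(0.0) would otherwise produce -inf (or a NaN delta when both sides are zero). Note also that the yielded quantity changes from an accumulated log probability (new_bytes_log_prob) to the parent node's probability (new_bytes_prob), which the rendering code in _run_stateless now consumes directly. A hedged sketch of the same clamping as a standalone function; the name and keyword defaults are illustrative, while the 1e-8 floor and -20 penalty come straight from the diff:

import numpy as np

def clamped_log_prob_delta(node_prob, next_node_prob, floor=1e-8, penalty=-20.0):
    # Log-probability of stepping from a trie node to one of its children,
    # guarded against log(0), mirroring the branch added in this commit.
    if next_node_prob < floor:
        if node_prob < floor:
            return 0.0       # both ~0: treat the step as neutral instead of NaN
        return penalty       # parent has mass, child has none: fixed large penalty
    return np.log(next_node_prob) - np.log(node_prob)

print(clamped_log_prob_delta(0.3, 0.0))    # -20.0 instead of -inf
print(clamped_log_prob_delta(0.3, 0.15))   # about -0.693, i.e. log(0.5)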
@@ -857,7 +866,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                 # we have no valid log prob data if we didn't compute it
                 if not self.compute_log_probs:
                     captured_log_prob_data = {k: None for k in captured_data}
-                yield new_bytes[hidden_count:], not is_forced, new_bytes_log_prob, captured_data, captured_log_prob_data, token_count - last_token_count
+                yield new_bytes[hidden_count:], not is_forced, new_bytes_prob, captured_data, captured_log_prob_data, token_count - last_token_count
                 last_token_count = token_count
                 break # we are done!
             else:
@@ -866,7 +875,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
             # yeild the snippet of text created by the next token
             out = new_bytes[hidden_count:]
             if len(out) > 0:
-                yield out, not is_forced, new_bytes_log_prob, {}, {}, token_count - last_token_count # note that we don't capture groups until a complete parse right now...
+                yield out, not is_forced, new_bytes_prob, {}, {}, token_count - last_token_count # note that we don't capture groups until a complete parse right now...
                 last_token_count = token_count
                 hidden_count = 0
                 token_count += 1 # note we only update this for tokens that emit non-hidden content
