unconstrained log_probs recording
slundberg committed Dec 5, 2023
1 parent 8e64b00 commit dca17db
Showing 4 changed files with 50 additions and 20 deletions.
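At a high level, this commit records log probabilities for captured variables alongside their values (in a new _variables_log_probs dict on the model) and exposes them through a new Model.log_prob accessor. A minimal usage sketch, modeled on the new llama_cpp test at the bottom of this diff; the LlamaCpp constructor call and model path are assumptions for illustration, not part of this commit:

import numpy as np
import guidance
from guidance import gen

# load a backend; the test below uses the llama_cpp backend via a test helper
lm = guidance.models.LlamaCpp("<path-to-model.gguf>")  # hypothetical path

# generate into a named capture
lm = lm + "this is a test" + gen("test", max_tokens=1)

# the captured text is still read with lm["test"]; the new accessor returns
# its recorded log probability (or a default if nothing was recorded)
print(lm["test"], lm.log_prob("test"))
assert 0.0 <= np.exp(lm.log_prob("test")) <= 1.0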
7 changes: 4 additions & 3 deletions guidance/_grammar.py
@@ -503,10 +503,9 @@ def select(options, name=None, list_append=False, recurse=False, skip_checks=Fal
options[i] = str(value)

# set up list append var saving if requested
name = "__LIST_APPEND:" + name if list_append else name
if list_append:
name = "__LIST_APPEND:" + name

# if name is None:
# name = _find_name() + "_" + StatelessFunction._new_name()
if recurse:
node = Select([], capture_name=name, recursive=True)
node.values = [node + v for v in options if v != ""] + options
@@ -527,6 +526,8 @@ def byte_range(low, high):
# return value

def capture(value, name):
# if log_probs:
# name += ":__LOG_PROBS"
if not (isinstance(value, Join) and len(value.values) == 1): # don't double wrap
value = Join([value]) # this ensures we capture what we want, and not something surprisingly self_recursive
value.capture_name = name
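As a point of reference for the naming logic above: capture() tags a grammar node with a capture name by wrapping it in a one-element Join, and once that grammar runs against a model the captured text lands in lm[name] while its log prob is read back with lm.log_prob(name). A small sketch, assuming select and capture are imported straight from this module:

from guidance._grammar import select, capture

# capture() wraps the select in a single-element Join (unless it already is one)
# so the capture boundary is unambiguous, then sets capture_name on that wrapper
answer = capture(select(["yes", "no"]), name="answer")

# equivalently, select() can attach the name itself; list_append=True would
# prefix it with "__LIST_APPEND:" so repeated captures accumulate into a list
answer = select(["yes", "no"], name="answer")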
2 changes: 1 addition & 1 deletion guidance/library/_gen.py
@@ -19,7 +19,7 @@
@guidance(stateless=lambda *args, **kwargs: kwargs.get("tools", None) is None) # TODO: uncomment this once we get temperature stateless
def gen(lm, name=None, *, max_tokens=1000, list_append=False, regex=None,
tools=None, hide_tool_call=False, stop=None, stop_regex=None, suffix="", n=1, temperature=0.0, top_p=1.0,
logprobs=None, stream_tokens=None, save_stop_text=False, **llm_kwargs):
stream_tokens=None, save_stop_text=False, **llm_kwargs):
"""
TODO: document this
tools is a list of guidance.Tool or python functions (which will be converted to guidance.Tool)
55 changes: 39 additions & 16 deletions guidance/models/_model.py
@@ -69,6 +69,7 @@ def __init__(self, tokens, bos_token_id=None, eos_token_id=None, echo=True, comp

# private attributes
self._variables = {} # these are the state variables stored with the model
self._variables_log_probs = {} # log probs for the state variables stored with the model
self._cache_state = {} # mutable caching state used to save computation
self._state = "" # the current bytes that represent the state of the model
self._event_queue = None # TODO: these are for streaming results in code, but that needs to be implemented
@@ -138,6 +139,7 @@ def copy(self):

# then copy a few things we need deeper copies of
new_lm._variables = self._variables.copy()
new_lm._variables_log_probs = self._variables_log_probs.copy()
new_lm.opened_blocks = self.opened_blocks.copy()

# create a new clean event queue # TODO: can we delete this now?
@@ -188,6 +190,7 @@ def reset(self, clear_variables=True):
self._state = self._state[:0]
if clear_variables:
self._variables = {}
self._variables_log_probs = {}
return self

def _repr_html_(self):
@@ -346,6 +349,19 @@ def remove(self, key):
copy = self
return copy

def log_prob(self, key, default=None):
'''Return the log prob of a variable, or a default value if the variable is not present.
Parameters
----------
key : str
The name of the variable.
default : any
The value to return if the variable is not currently set.
'''
# TODO: support calling without a key to get the log prob of the whole model
return self._variables_log_probs.get(key, default)

def get_cache(self):
return self.engine.cache

@@ -440,7 +456,7 @@ def _run_stateless(lm, stateless_function, temperature=0.0, top_p=1.0, n=1):

# see if we are in a list_append mode
if isinstance(v, list):
for inner_v in v:
for i,inner_v in enumerate(v):
# convert to a string if possible
# TODO: will need to not just always do this once we support images etc.
try:
@@ -450,7 +466,9 @@ def _run_stateless(lm, stateless_function, temperature=0.0, top_p=1.0, n=1):

if k not in lm or not isinstance(lm._variables[k], list):
lm._variables[k] = []
lm._variables_log_probs[k] = []
lm._variables[k].append(inner_v)
lm._variables_log_probs[k].append(capture_group_log_probs[k][i])

# ...or standard assignment mode
else:
@@ -461,6 +479,7 @@ def _run_stateless(lm, stateless_function, temperature=0.0, top_p=1.0, n=1):
except UnicodeDecodeError:
pass
lm._variables[k] = v
lm._variables_log_probs[k] = capture_group_log_probs[k]

# if len(capture_groups) > 0:
# for k in capture_groups:
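For list_append captures, the two dictionaries above are kept in lockstep: every value appended to lm._variables[k] gets a matching entry appended to lm._variables_log_probs[k]. A hedged end-to-end sketch of what that enables; the backend, prompt, and max_tokens values are illustrative placeholders, not from this diff:

import guidance
from guidance import gen

lm = guidance.models.LlamaCpp("<path-to-model.gguf>")  # hypothetical path
lm = lm + "Here are three words:"

for _ in range(3):
    # list_append=True routes the capture through the "__LIST_APPEND:" naming
    # convention, so successive captures accumulate instead of overwriting
    lm = lm + " " + gen("word", list_append=True, max_tokens=2)

print(lm["word"])           # list of three captured strings
print(lm.log_prob("word"))  # parallel list of their recorded log probs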
@@ -714,14 +733,17 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
logits = self._get_logits(token_ids, parser.bytes[start_pos:forced_pos])

# if requested we compute the log probabilities so we can track the probabilities of each node
# TODO: we should lower this step to C++ with pybind11
if self.compute_log_probs:
if torch:
probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1).cpu().numpy() # note we don't adjust for temp since we consider that a sampling step, not part of the probs
probs_torch = torch.nn.functional.softmax(torch.tensor(logits), dim=-1)
probs = probs_torch.cpu().numpy() # note we don't adjust for temp since we consider that a sampling step, not part of the probs
else:
probs = softmax(logits, axis=-1) # this numpy code is slower, so we don't use it if we have torch...
self._clean_duplicate_tokens(probs)
trie.compute_probs(probs)
trie.compute_probs(probs) # C++ impl
else:
probs_torch = None
probs = None

# get the sampling order
grammar_temp = parser.next_byte_temperature()
@@ -731,13 +753,15 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
else:
assert top_p == 1, "Still need to add support for top_p!"
if torch:
logits = torch.tensor(logits)
torch.div(logits, current_temp, out=logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
sampling_order = torch.multinomial(probs, len(probs)).cpu().numpy()
if probs_torch is None:
logits = torch.tensor(logits)
torch.div(logits, current_temp, out=logits)
probs_torch = torch.nn.functional.softmax(logits, dim=-1)
sampling_order = torch.multinomial(probs_torch, len(probs_torch)).cpu().numpy()
else:
# this numpy version allows us to drop our dependence on pytorch...but it is way slower
probs = softmax(logits / current_temp, axis=-1)
if probs is None:
probs = softmax(logits / current_temp, axis=-1)
probs += 1e-10 # ensure we have no zero probs that mess up numpy
probs /= np.sum(probs)
sampling_order = np.random.choice(len(probs), size=len(probs), p=probs, replace=False) # the 1e-10 is to ensure we have no zero probs, which numpy does not like
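Two distributions are at play in the block above: the probabilities recorded for log-prob tracking come from the raw logits (temperature is treated as part of sampling, not of the probabilities), while the sampling order is drawn from the temperature-scaled distribution as a full weighted permutation. A small numpy-only sketch of that split, with toy logits and scipy's softmax standing in for the helper used above:

import numpy as np
from scipy.special import softmax  # stand-in for the softmax helper used above

logits = np.array([2.0, 1.0, 0.1])   # toy logits for three candidate tokens
current_temp = 0.7

# recorded probabilities: no temperature adjustment
probs_recorded = softmax(logits)

# sampling distribution: temperature-scaled, with a tiny floor so that
# np.random.choice never sees an exact zero probability
probs_sampling = softmax(logits / current_temp) + 1e-10
probs_sampling /= probs_sampling.sum()

# a full permutation of token indices, weighted by probability: likelier
# tokens tend to appear earlier in the sampling order
sampling_order = np.random.choice(len(probs_sampling), size=len(probs_sampling),
                                  p=probs_sampling, replace=False)
print(probs_recorded, sampling_order)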
@@ -773,10 +797,6 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e

# advance or fail according to the (now up-to-date) match cache
if next_node.match:

# mark that we accepted this byte
node = next_node
token_pos += 1

# get the parser to consume the next byte
if next_node.prob < 1e-8:
@@ -787,8 +807,12 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
else:
log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
# log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
new_bytes_prob = node.prob
new_bytes_prob = next_node.prob
commit_point = parser.consume_byte(next_byte, log_prob=log_prob_delta)

# mark that we accepted this byte
node = next_node
token_pos += 1

# if we are at a hidden commit point then we need to hide the bytes that match that node
if commit_point is not None and commit_point.node.hidden:
@@ -797,6 +821,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
# TODO: build a whole parse tree under this commit_point node so we can record child node captures
if commit_point.node.capture_name:
captured_data[commit_point.node.capture_name] = parser.bytes[commit_point.start:]
captured_log_prob_data[commit_point.node.capture_name] = commit_point.log_prob

# This takes the item and commits to it as part of the parse and then shrinks it to zero width
# in other words this hides the item
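The log_prob_delta fed to consume_byte above is a conditional log probability: node.prob is the probability mass consistent with the bytes accepted so far, next_node.prob is the mass after accepting one more byte, so summing the deltas along an accepted path telescopes into the log probability of the whole span (which is what ends up in captured_log_prob_data). A toy check of that telescoping; the probability values are made up:

import numpy as np

# probability mass at successive trie nodes along one accepted byte path (made up)
node_probs = [1.0, 0.6, 0.3, 0.12]

# per-byte conditional log probs, as in: log(next_node.prob) - log(node.prob)
deltas = [np.log(b) - np.log(a) for a, b in zip(node_probs, node_probs[1:])]

# they sum to the log probability of the full span
assert np.isclose(sum(deltas), np.log(node_probs[-1]) - np.log(node_probs[0]))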
Expand Down Expand Up @@ -864,8 +889,6 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
_record_captures(parse_tree, captured_data, captured_log_prob_data, parser.bytes)

# we have no valid log prob data if we didn't compute it
if not self.compute_log_probs:
captured_log_prob_data = {k: None for k in captured_data}
yield new_bytes[hidden_count:], not is_forced, new_bytes_prob, captured_data, captured_log_prob_data, token_count - last_token_count
last_token_count = token_count
break # we are done!
6 changes: 6 additions & 0 deletions tests/models/test_llama_cpp.py
@@ -1,3 +1,4 @@
import numpy as np
import guidance
from guidance import select, gen
from ..utils import get_model
@@ -7,6 +8,11 @@ def test_llama_cpp_gen():
lm = lm + "this is a test" + gen("test", max_tokens=10)
assert len(str(lm)) > len("this is a test")

def test_llama_cpp_gen_log_probs():
lm = get_model("llama_cpp:")
lm = lm + "this is a test" + gen("test", max_tokens=1)
assert 1 >= np.exp(lm.log_prob("test")) >= 0

def test_llama_cpp_recursion_error():
lm = get_model("llama_cpp:")

