From dca17db82c9a518f4b77d766a3476b9c5d9acffb Mon Sep 17 00:00:00 2001
From: Scott Lundberg
Date: Tue, 5 Dec 2023 21:51:07 +0000
Subject: [PATCH] unconstrained log_probs recording

---
 guidance/_grammar.py           |  7 +++--
 guidance/library/_gen.py       |  2 +-
 guidance/models/_model.py      | 55 ++++++++++++++++++++++++----------
 tests/models/test_llama_cpp.py |  6 ++++
 4 files changed, 50 insertions(+), 20 deletions(-)

diff --git a/guidance/_grammar.py b/guidance/_grammar.py
index dd896f422..e181997ee 100644
--- a/guidance/_grammar.py
+++ b/guidance/_grammar.py
@@ -503,10 +503,9 @@ def select(options, name=None, list_append=False, recurse=False, skip_checks=Fal
         options[i] = str(value)
 
     # set up list append var saving if requested
-    name = "__LIST_APPEND:" + name if list_append else name
+    if list_append:
+        name = "__LIST_APPEND:" + name
 
-    # if name is None:
-    #     name = _find_name() + "_" + StatelessFunction._new_name()
     if recurse:
         node = Select([], capture_name=name, recursive=True)
         node.values = [node + v for v in options if v != ""] + options
@@ -527,6 +526,8 @@ def byte_range(low, high):
 #     return value
 
 def capture(value, name):
+    # if log_probs:
+    #     name += ":__LOG_PROBS"
     if not (isinstance(value, Join) and len(value.values) == 1): # don't double wrap
         value = Join([value]) # this ensures we capture what we want, and not something surprisingly self_recursive
     value.capture_name = name
diff --git a/guidance/library/_gen.py b/guidance/library/_gen.py
index 4a9398476..c1186984e 100644
--- a/guidance/library/_gen.py
+++ b/guidance/library/_gen.py
@@ -19,7 +19,7 @@
 @guidance(stateless=lambda *args, **kwargs: kwargs.get("tools", None) is None) # TODO: uncomment this once we get temperature stateless
 def gen(lm, name=None, *, max_tokens=1000, list_append=False, regex=None,
         tools=None, hide_tool_call=False, stop=None, stop_regex=None, suffix="", n=1, temperature=0.0, top_p=1.0,
-        logprobs=None, stream_tokens=None, save_stop_text=False, **llm_kwargs):
+        stream_tokens=None, save_stop_text=False, **llm_kwargs):
     """ TODO: document this
 
     tools is a list of guidance.Tool or python functions (which will be converted to guidance.Tool)
diff --git a/guidance/models/_model.py b/guidance/models/_model.py
index c86183210..a2dedad83 100644
--- a/guidance/models/_model.py
+++ b/guidance/models/_model.py
@@ -69,6 +69,7 @@ def __init__(self, tokens, bos_token_id=None, eos_token_id=None, echo=True, comp
 
         # private attributes
         self._variables = {} # these are the state variables stored with the model
+        self._variables_log_probs = {} # these are the log probs of the state variables
        self._cache_state = {} # mutable caching state used to save computation
         self._state = "" # the current bytes that represent the state of the model
         self._event_queue = None # TODO: these are for streaming results in code, but that needs implemented
@@ -138,6 +139,7 @@ def copy(self):
 
         # then copy a few things we need deeper copies of
         new_lm._variables = self._variables.copy()
+        new_lm._variables_log_probs = self._variables_log_probs.copy()
         new_lm.opened_blocks = self.opened_blocks.copy()
 
         # create a new clean event queue # TODO: can we delete this now?
@@ -188,6 +190,7 @@ def reset(self, clear_variables=True):
         self._state = self._state[:0]
         if clear_variables:
             self._variables = {}
+            self._variables_log_probs = {}
         return self
 
     def _repr_html_(self):
@@ -346,6 +349,19 @@ def remove(self, key):
             copy = self
         return copy
 
+    def log_prob(self, key, default=None):
+        '''Return the log prob of a variable, or a default value if the variable is not present.
+
+        Parameters
+        ----------
+        key : str
+            The name of the variable.
+        default : any
+            The value to return if the variable is not currently set.
+        '''
+        # TODO: support calling without a key to get the log prob of the whole model
+        return self._variables_log_probs.get(key, default)
+
     def get_cache(self):
         return self.engine.cache
 
@@ -440,7 +456,7 @@ def _run_stateless(lm, stateless_function, temperature=0.0, top_p=1.0, n=1):
 
                 # see if we are in a list_append mode
                 if isinstance(v, list):
-                    for inner_v in v:
+                    for i, inner_v in enumerate(v):
                         # convert to a string if possible
                         # TODO: will need to not just always do this once we support images etc.
                         try:
@@ -450,7 +466,9 @@
 
                         if k not in lm or not isinstance(lm._variables[k], list):
                             lm._variables[k] = []
+                            lm._variables_log_probs[k] = []
                         lm._variables[k].append(inner_v)
+                        lm._variables_log_probs[k].append(capture_group_log_probs[k][i])
 
                 # ...or standard assignment mode
                 else:
@@ -461,6 +479,7 @@
                     except UnicodeDecodeError:
                         pass
                     lm._variables[k] = v
+                    lm._variables_log_probs[k] = capture_group_log_probs[k]
 
         # if len(capture_groups) > 0:
        #     for k in capture_groups:
@@ -714,14 +733,17 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                 logits = self._get_logits(token_ids, parser.bytes[start_pos:forced_pos])
 
                 # if requested we compute the log probabilities so we can track the probabilities of each node
-                # TODO: we should lower this step to C++ with pybind11
                 if self.compute_log_probs:
                     if torch:
-                        probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1).cpu().numpy() # note we don't adjust for temp since we consider that a sampling step, not part of the probs
+                        probs_torch = torch.nn.functional.softmax(torch.tensor(logits), dim=-1)
+                        probs = probs_torch.cpu().numpy() # note we don't adjust for temp since we consider that a sampling step, not part of the probs
                     else:
                         probs = softmax(logits, axis=-1) # this numpy code is slower, so we don't use it if we have torch...
                     self._clean_duplicate_tokens(probs)
-                    trie.compute_probs(probs)
+                    trie.compute_probs(probs) # C++ impl
+                else:
+                    probs_torch = None
+                    probs = None
 
                 # get the sampling order
                 grammar_temp = parser.next_byte_temperature()
@@ -731,13 +753,15 @@
                 else:
                     assert top_p == 1, "Still need to add support for top_p!"
                     if torch:
-                        logits = torch.tensor(logits)
-                        torch.div(logits, current_temp, out=logits)
-                        probs = torch.nn.functional.softmax(logits, dim=-1)
-                        sampling_order = torch.multinomial(probs, len(probs)).cpu().numpy()
+                        if probs_torch is None:
+                            logits = torch.tensor(logits)
+                            torch.div(logits, current_temp, out=logits)
+                            probs_torch = torch.nn.functional.softmax(logits, dim=-1)
+                        sampling_order = torch.multinomial(probs_torch, len(probs_torch)).cpu().numpy()
                     else:
                         # this numpy version allows us to drop our dependence on pytorch...but it is way slower
-                        probs = softmax(logits / current_temp, axis=-1)
+                        if probs is None:
+                            probs = softmax(logits / current_temp, axis=-1)
                         probs += 1e-10 # ensure we have no zero probs that mess up numpy
                         probs /= np.sum(probs)
                         sampling_order = np.random.choice(len(probs), size=len(probs), p=probs, replace=False) # the 1e-10 is to ensure we have no zero probs, which numpy does not like
@@ -773,10 +797,6 @@
 
                     # advance or fail according to the (now up-to-date) match cache
                     if next_node.match:
-
-                        # mark that we accepted this byte
-                        node = next_node
-                        token_pos += 1
 
                         # get the parser to consume the next byte
                         if next_node.prob < 1e-8:
@@ -787,8 +807,12 @@
                         else:
                             log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
                         # log_prob_delta = np.log(next_node.prob) - np.log(node.prob)
-                        new_bytes_prob = node.prob
+                        new_bytes_prob = next_node.prob
                         commit_point = parser.consume_byte(next_byte, log_prob=log_prob_delta)
+
+                        # mark that we accepted this byte
+                        node = next_node
+                        token_pos += 1
 
                         # if we are at a hidden commit point then we need to hide the bytes that match that node
                         if commit_point is not None and commit_point.node.hidden:
@@ -797,6 +821,7 @@
                             # TODO: build a whole parse tree under this commit_point node so we can record child node captures
                             if commit_point.node.capture_name:
                                 captured_data[commit_point.node.capture_name] = parser.bytes[commit_point.start:]
+                                captured_log_prob_data[commit_point.node.capture_name] = commit_point.log_prob
 
                             # This takes the item and commits to it as part of the parse and then shrinks it to zero width
                             # in other words this hides the item
@@ -864,8 +889,6 @@
                     _record_captures(parse_tree, captured_data, captured_log_prob_data, parser.bytes)
 
                     # we have no valid log prob data if we didn't compute it
-                    if not self.compute_log_probs:
-                        captured_log_prob_data = {k: None for k in captured_data}
                     yield new_bytes[hidden_count:], not is_forced, new_bytes_prob, captured_data, captured_log_prob_data, token_count - last_token_count
                     last_token_count = token_count
                     break # we are done!
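
Note on the capture log probs above: each accepted byte contributes log(next_node.prob) - log(node.prob) through parser.consume_byte, so the per-byte conditional log probs telescope into the log prob of the full captured span. A minimal illustrative sketch of that accumulation follows (path_probs is a made-up stand-in for the cumulative node probabilities that trie.compute_probs fills in; this is not guidance's actual parser API):

    import numpy as np

    # cumulative probs of successive trie nodes along one accepted path,
    # i.e. P(prefix of length k) for k = 0..3
    path_probs = [1.0, 0.5, 0.2, 0.1]

    # per-byte conditional log probs, as passed to consume_byte(log_prob=...)
    deltas = [np.log(b) - np.log(a) for a, b in zip(path_probs, path_probs[1:])]

    # the chain rule telescopes: the sum of deltas is log P(whole span)
    assert np.isclose(sum(deltas), np.log(0.1))
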
diff --git a/tests/models/test_llama_cpp.py b/tests/models/test_llama_cpp.py
index c7e4a464d..4dfb405a1 100644
--- a/tests/models/test_llama_cpp.py
+++ b/tests/models/test_llama_cpp.py
@@ -1,3 +1,4 @@
+import numpy as np
 import guidance
 from guidance import select, gen
 from ..utils import get_model
@@ -7,6 +8,11 @@ def test_llama_cpp_gen():
     lm = lm + "this is a test" + gen("test", max_tokens=10)
     assert len(str(lm)) > len("this is a test")
 
+def test_llama_cpp_gen_log_probs():
+    lm = get_model("llama_cpp:")
+    lm = lm + "this is a test" + gen("test", max_tokens=1)
+    assert 1 >= np.exp(lm.log_prob("test")) >= 0
+
 def test_llama_cpp_recursion_error():
     lm = get_model("llama_cpp:")
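
Note on usage: the new public surface is Model.log_prob(key, default=None), which reads from _variables_log_probs the same way lm[key] reads from _variables. A sketch of how a caller might use it, mirroring the new test (the model path is hypothetical, and we assume the compute_log_probs flag is accepted by the model constructor as suggested by the engine __init__ above):

    import numpy as np
    from guidance import models, gen

    # hypothetical local model path; any llama.cpp-backed model behaves the same
    lm = models.LlamaCpp("path/to/model.gguf", compute_log_probs=True)
    lm = lm + "this is a test" + gen("test", max_tokens=5)

    lp = lm.log_prob("test")  # None (the default) if no log prob was recorded
    if lp is not None:
        print("P(test) =", np.exp(lp))  # a probability in [0, 1], as the new test asserts
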