# NOTE(review): open question from the author - should users be allowed to
# create their own output-formatter instances?
class ReadableOutput:
    """Interface for objects that turn a model's chunk list into printable text."""

    def feed(self, state_list):
        """Parse model output and return the text that should be printed.

        Args:
            state_list: a list of two-item entries where ``[0]`` is the text
                chunk and ``[1]`` is an rgba color tuple (if the text was
                generated) or ``None`` (if the text was inserted by us).

        Returns:
            Something printable. Implementations must return a value;
            otherwise the caller ends up printing ``None``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('"feed" must be implemented!')


class ReadableOutputCLIStream(ReadableOutput):
    """Incremental formatter that colors chunks with ANSI escape codes.

    The instance remembers how many entries of the chunk list it has already
    rendered, so each ``feed`` call returns only the text added since the
    previous call.
    """

    # 24-bit foreground color selector and attribute reset (ANSI SGR codes).
    _FG_RGB = '\033[38;2;{};{};{}m'
    _RESET = '\033[0m'

    def __init__(self):
        self._cur_chunk = 0  # number of state_list entries already rendered
        super().__init__()

    def feed(self, state_list):
        """Return the not-yet-rendered tail of ``state_list`` as one string.

        Entries whose color is ``None`` are emitted verbatim. Entries with a
        color are wrapped in a foreground-color escape built from the first
        three color components (rounded to integers) followed by a reset.

        Args:
            state_list: list of ``(text, color)`` entries, see
                ``ReadableOutput.feed``.

        Returns:
            The newly rendered text, or ``""`` when there is nothing new.
        """
        # A list shorter than the cursor cannot be the one rendered so far
        # (it was replaced or cleared). Start over rather than silently
        # dropping every chunk until the new list outgrows the old cursor.
        if self._cur_chunk > len(state_list):
            self._cur_chunk = 0

        pieces = []
        for text, color in state_list[self._cur_chunk:]:
            if color is None:
                pieces.append(text)
                continue
            pieces.append(self._FG_RGB.format(*[round(x) for x in color[:3]]))
            pieces.append(text)
            pieces.append(self._RESET)

        self._cur_chunk = len(state_list)
        # Join once instead of growing a string with repeated ``+=``.
        return "".join(pieces)
True except ImportError: ipython_is_imported = False + notebook_mode = False +else: + ipython_is_imported = True + _ipython = get_ipython() + notebook_mode = ( + _ipython is not None + and "IPKernelApp" in _ipython.config + ) + try: import torch @@ -39,7 +44,7 @@ ) from .. import _cpp as cpp from ._guidance_engine_metrics import GuidanceEngineMetrics -from .._utils import softmax, CaptureEvents +from .._utils import softmax, CaptureEvents, ReadableOutputCLIStream from .._parser import EarleyCommitParser, Parser from .._grammar import ( GrammarFunction, @@ -857,11 +862,13 @@ def __init__(self, engine, echo=True, **kwargs): self._variables = {} # these are the state variables stored with the model self._variables_log_probs = {} # these are the state variables stored with the model self._cache_state = {} # mutable caching state used to save computation + self._state_list = [] self._state = "" # the current bytes that represent the state of the model self._event_queue = None # TODO: these are for streaming results in code, but that needs implemented self._event_parent = None self._last_display = 0 # used to track the last display call to enable throttling self._last_event_stream = 0 # used to track the last event streaming call to enable throttling + self._state_dict_parser = ReadableOutputCLIStream() # used to parse the state for cli display @property def active_role_end(self): @@ -975,11 +982,11 @@ def _update_display(self, throttle=True): else: self._last_display = curr_time - if ipython_is_imported: + if notebook_mode: clear_output(wait=True) display(HTML(self._html())) else: - pprint(self._state) + print(self._state_dict_parser.feed(self._state_list), end='', flush=True) def reset(self, clear_variables=True): """This resets the state of the model object. @@ -995,6 +1002,7 @@ def reset(self, clear_variables=True): self._variables_log_probs = {} return self + # Is this used anywhere? 
def test_readable_output():
    """Check model text and the CLI formatter that renders it.

    The model assertion keeps the original check that a select() over three
    options yields one of the expected strings. The formatter assertions
    exercise ReadableOutputCLIStream.feed directly: plain chunks pass through
    untouched, colored chunks are wrapped in ANSI color/reset codes, and a
    second call with no new chunks returns an empty string.
    """
    # Local import keeps this test self-contained; the module-level import
    # line of this file is not part of the visible hunk.
    from guidance._utils import ReadableOutputCLIStream

    model = models.Mock()
    model += "Not colored " + select(options=["colored", "coloblue", "cologreen"])
    assert str(model) in ["Not colored colored", "Not colored coloblue", "Not colored cologreen"]

    formatter = ReadableOutputCLIStream()
    chunks = [["Not colored ", None], ["colored", (0, 165, 0, 0.15)]]
    assert formatter.feed(chunks) == "Not colored \033[38;2;0;165;0mcolored\033[0m"
    # Nothing new was appended, so nothing should be re-emitted.
    assert formatter.feed(chunks) == ""