Expose fastforward and backtrack llguidance flags #1075

Merged: 2 commits, Nov 11, 2024
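Before the file-by-file diff, a minimal usage sketch of what this PR exposes, mirroring the new test at the bottom of the diff. The model name "gpt2" is borrowed from that test, and the one-line description of the flags is a rough summary rather than wording from the PR:

```python
from guidance import models, select

# Both new flags default to True; pass False to opt out.
# Roughly: fast-forward (ff) tokens let the engine append grammar-forced
# tokens without sampling them one at a time, and backtracking lets it
# rewind tokens it has already committed.
lm_noff = models.Transformers("gpt2", enable_ff_tokens=False)
lm_noff += "Example text: " + select(
    ["Lorem ipsum dolor sit amet", "Duis aute irure dolor "], name="choice"
)

lm_ff = models.Transformers("gpt2")  # defaults: both flags enabled
lm_ff += "Example text: " + select(
    ["Lorem ipsum dolor sit amet", "Duis aute irure dolor "], name="choice"
)

# With fast-forwarding enabled, the select() should cost far fewer engine
# output tokens, which is what test_gpt2_fastforward below asserts.
print(lm_noff.engine.metrics.engine_output_tokens)
print(lm_ff.engine.metrics.engine_output_tokens)
```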
4 changes: 4 additions & 0 deletions guidance/_parser.py
@@ -34,6 +34,8 @@ def __init__(
tokenizer: Tokenizer,
prompt: bytes = b"",
ensure_bos_token: bool = True,
enable_backtrack: bool = True,
enable_ff_tokens: bool = True,
):
if isinstance(grammar, GrammarFunction):
# we can't have a terminal as the root
@@ -50,6 +52,8 @@ def __init__(
self.ll_interpreter = llguidance.LLInterpreter(
self.ll_tokenizer,
serialized_grammar,
enable_backtrack,
enable_ff_tokens,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
self._threadpool = ThreadPoolExecutor(max_workers=1)
2 changes: 1 addition & 1 deletion guidance/models/_grammarless.py
@@ -171,7 +171,7 @@ def __init__(
Using a transformers, tiktoken, or guidance.GrammarlessTokenizer directly will solve this issue."
)
# build the Engine
super().__init__(tokenizer=tokenizer, compute_log_probs=compute_log_probs)
super().__init__(tokenizer=tokenizer, compute_log_probs=compute_log_probs, enable_backtrack=False, enable_ff_tokens=False)

# build a prefix tree of the tokens
self._token_trie = cpp.ByteTrie(
44 changes: 26 additions & 18 deletions guidance/models/_model.py
@@ -30,6 +30,7 @@
from .._utils import softmax, CaptureEvents
from .._parser import TokenParser
from .._grammar import (
Function, # for da types, just for you Hudson <3
GrammarFunction,
string,
_call_pool,
@@ -63,31 +64,29 @@ class Engine:
Server so a single server can serve many clients' model objects through a single Engine object.
"""

def __init__(self, tokenizer: Tokenizer, compute_log_probs=False):
def __init__(self, tokenizer: Tokenizer, compute_log_probs=False, enable_backtrack=True, enable_ff_tokens=True):
self.tokenizer = tokenizer
self.compute_log_probs = compute_log_probs
self._enable_backtrack = enable_backtrack
self._enable_ff_tokens = enable_ff_tokens
self.metrics = GuidanceEngineMetrics()

# These need to be properties because once an Engine is started, you can't change their behavior.
@property
def enable_backtrack(self):
return self._enable_backtrack

@property
def enable_ff_tokens(self):
return self._enable_ff_tokens

def get_chat_template(self): # TODO [HN]: Add more logic here...should we instantiate class here? do we even need to?
return self.tokenizer.chat_template() # Instantiate the class before returning to client for now

def reset_metrics(self):
self.metrics = GuidanceEngineMetrics()

def start(self, prompt, grammar, ensure_bos_token=True) -> TokenParser:
"""Start processing parser state executed through the grammar.

Parameters
----------
prompt : str or Parser
This is represents the current state of a guidance parser that will be extended
using the passed grammar. If a string is given then we assume the previous parser
state is just a fixed string prompt, if a full Parser is given then we extend that
parser by appending the new grammar to the parser's current grammar and then
inferencing the model. (TODO: implement full parser extension support)
grammar: Grammar
This is the grammar we are extending the prompt with.
"""
# def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, ensure_bos_token=True):
# assert n == 1, "Still need to add support for n > 1!"

@@ -113,10 +112,17 @@ def start(self, prompt, grammar, ensure_bos_token=True) -> TokenParser:
grammar=grammar,
tokenizer=self.tokenizer,
prompt=prompt,
ensure_bos_token=ensure_bos_token
ensure_bos_token=ensure_bos_token,
enable_backtrack=self.enable_backtrack,
enable_ff_tokens=self.enable_ff_tokens,
)

def __call__(self, prompt, grammar, ensure_bos_token=True) -> Iterator[EngineCallResponse]:
def __call__(
self,
prompt: Union[str, TokenParser],
grammar: Function,
ensure_bos_token: bool = True,
) -> Iterator[EngineCallResponse]:
"""Main entry point for the inference-parser loop. Yields EngineCallResponse objects as
the parser advances through the grammar.

@@ -128,8 +134,10 @@ def __call__(self, prompt, grammar, ensure_bos_token=True) -> Iterator[EngineCallResponse]:
state is just a fixed string prompt, if a full Parser is given then we extend that
parser by appending the new grammar to the parser's current grammar and then
inferencing the model. (TODO: implement full parser extension support)
grammar: Grammar
This is the grammar we are extending the prompt with.
grammar: Function
Grammar (RawFunction or GrammarFunction) used to extend the prompt.
ensure_bos_token: bool
Ensures that the prompt ends with the BOS token.
"""
parser = self.start(prompt, grammar, ensure_bos_token)

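A short sketch of the Engine-side behavior added above, again borrowing the Transformers wrapper and "gpt2" from the tests: the flags are fixed when the engine is constructed and are only surfaced afterwards through the new read-only properties.

```python
from guidance import models

lm = models.Transformers("gpt2", enable_backtrack=False)

# The wrapped Engine exposes the flags as read-only properties.
assert lm.engine.enable_backtrack is False
assert lm.engine.enable_ff_tokens is True  # default preserved

# There are deliberately no setters: once the engine is running, flipping the
# flags would disagree with the llguidance interpreter it has already created,
# so assigning to them raises AttributeError.
# lm.engine.enable_backtrack = True
```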
8 changes: 7 additions & 1 deletion guidance/models/llama_cpp/_llama_cpp.py
@@ -100,7 +100,7 @@ def encode(self, byte_string: bytes) -> list[int]:
class LlamaCppEngine(Engine):
"""The core class that runs inference using llama.cpp."""

def __init__(self, model, compute_log_probs, chat_template=None, **kwargs):
def __init__(self, model, compute_log_probs, chat_template=None, enable_backtrack=True, enable_ff_tokens=True, **kwargs):
if not is_llama_cpp:
raise Exception(
"Please install llama-cpp-python with `pip install llama-cpp-python` in order to use guidance.models.LlamaCpp!"
@@ -149,6 +149,8 @@ def __init__(self, model, compute_log_probs, chat_template=None, **kwargs):
super().__init__(
LlamaCppTokenizer(self.model_obj, chat_template=chat_template),
compute_log_probs=compute_log_probs,
enable_backtrack=enable_backtrack,
enable_ff_tokens=enable_ff_tokens,
)

self._n_vocab = len(self.tokenizer.tokens)
@@ -226,6 +228,8 @@ def __init__(
compute_log_probs=False,
api_key=None,
chat_template=None,
enable_backtrack=True,
enable_ff_tokens=True,
**llama_cpp_kwargs,
):
"""Build a new LlamaCpp model object that represents a model in a given state."""
@@ -237,6 +241,8 @@
model,
compute_log_probs=compute_log_probs,
chat_template=chat_template,
enable_backtrack=enable_backtrack,
enable_ff_tokens=enable_ff_tokens,
**llama_cpp_kwargs,
)

8 changes: 7 additions & 1 deletion guidance/models/transformers/_transformers.py
@@ -371,7 +371,7 @@ def recode(self, tokens: Sequence[int]) -> list[int]:


class TransformersEngine(Engine):
def __init__(self, model, tokenizer, compute_log_probs: bool, chat_template=None, **kwargs):
def __init__(self, model, tokenizer, compute_log_probs: bool, chat_template=None, enable_backtrack=True, enable_ff_tokens=True, **kwargs):
# fill in default model value
if model is None:
model = os.environ.get("TRANSFORMERS_MODEL", None)
@@ -413,6 +413,8 @@ def __init__(self, model, tokenizer, compute_log_probs: bool, chat_template=None
super().__init__(
my_tokenizer,
compute_log_probs=compute_log_probs,
enable_backtrack=enable_backtrack,
enable_ff_tokens=enable_ff_tokens
)

def _model(self, model, **kwargs):
@@ -524,6 +526,8 @@ def __init__(
echo=True,
compute_log_probs=False,
chat_template=None,
enable_backtrack=True,
enable_ff_tokens=True,
**kwargs,
):
"""Build a new Transformers model object that represents a model in a given state."""
@@ -533,6 +537,8 @@
tokenizer,
compute_log_probs,
chat_template=chat_template,
enable_backtrack=enable_backtrack,
enable_ff_tokens=enable_ff_tokens,
**kwargs,
),
echo=echo,
34 changes: 33 additions & 1 deletion tests/model_specific/test_transformers.py
@@ -1,6 +1,6 @@
import pytest

from guidance import gen, select, models, assistant, system, user
from guidance import gen, select, models, assistant, system, user, guidance

from ..utils import get_model

@@ -27,6 +27,38 @@ def test_gpt2():

assert len(str(lm)) > len("this is a test")

def test_gpt2_fastforward(): # TODO [HN]: figure out how all the get_model and fixture stuff works
@guidance
def ff_prompt(lm):
big_opts = [
"Lorem ipsum dolor sit amet",
"Duis aute irure dolor "
]
lm += "Example text: " + select(big_opts, name="choice")
return lm

# We should have significantly less output tokens in the fast-forwarded version (1 output)

gpt2_noff = models.Transformers("gpt2", enable_ff_tokens=False)
gpt2_noff += ff_prompt()
noff_count = gpt2_noff.engine.metrics.engine_output_tokens

gpt2_nobt = models.Transformers("gpt2", enable_backtrack=False)
gpt2_nobt += ff_prompt()
nobt_count = gpt2_nobt.engine.metrics.engine_output_tokens

gpt2_ff = models.Transformers("gpt2")
gpt2_ff += ff_prompt()
ff_count = gpt2_ff.engine.metrics.engine_output_tokens

assert nobt_count == 3
assert ff_count == 3
assert noff_count > ff_count






def test_recursion_error():
"""This checks for an infinite recursion error resulting from a terminal node at the root of a grammar."""
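To run only the new test from this PR, something like pytest tests/model_specific/test_transformers.py::test_gpt2_fastforward should work, assuming the gpt2 weights can be downloaded locally.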