diff --git a/guidance/_grammar.py b/guidance/_grammar.py index e85c46bb0..14ba5dd56 100644 --- a/guidance/_grammar.py +++ b/guidance/_grammar.py @@ -170,6 +170,10 @@ def __radd__(self, value): def __getitem__(self, value): raise StatefulException("GrammarFunctions can't access state!") + + @property + def token_count(self): + raise StatefulException("GrammarFunctions can't access state!") def match( self, diff --git a/guidance/_guidance.py b/guidance/_guidance.py index 510cac9e0..1fd6ddff9 100644 --- a/guidance/_guidance.py +++ b/guidance/_guidance.py @@ -2,7 +2,7 @@ import inspect from . import models -from ._grammar import Placeholder, RawFunction, Terminal, replace_grammar_node, string +from ._grammar import Placeholder, RawFunction, Terminal, replace_grammar_node, string, StatefulException from ._utils import strip_multiline_string_indents @@ -35,39 +35,38 @@ def _decorator(f, *, stateless, cache, dedent, model): @functools.wraps(f) def wrapped(*args, **kwargs): - # make a stateless grammar if we can - if stateless is True or ( - callable(stateless) and stateless(*args, **kwargs) - ): + # if we have a placeholder set then we must be in a recursive definition and so we return the placeholder + placeholder = getattr(f, "_self_call_placeholder_", None) + if placeholder is not None: + return placeholder - # if we have a placeholder set then we must be in a recursive definition and so we return the placeholder - placeholder = getattr(f, "_self_call_placeholder_", None) - if placeholder is not None: - return placeholder - - # otherwise we call the function to generate the grammar - else: + # otherwise we call the function to generate the grammar + else: - # set a placeholder for recursive calls (only if we don't have arguments that might make caching a bad idea) - no_args = len(args) + len(kwargs) == 0 - if no_args: - f._self_call_placeholder_ = Placeholder() + # set a placeholder for recursive calls (only if we don't have arguments that might make caching a bad idea) + no_args = len(args) + len(kwargs) == 0 + if no_args: + f._self_call_placeholder_ = Placeholder() - # call the function to get the grammar node + # try to trace the function to get a grammar node + node = None + try: node = f(_null_grammar, *args, **kwargs) if not isinstance(node, (Terminal, str)): node.name = f.__name__ - - # replace all the placeholders with our generated node + + # if that fails we must be stateful (which means we can't be inside a select() call) + except StatefulException: + return RawFunction(f, args, kwargs) + + # clean up, replacing all the placeholders with our generated node + finally: if no_args: - replace_grammar_node(node, f._self_call_placeholder_, node) + if node: + replace_grammar_node(node, f._self_call_placeholder_, node) del f._self_call_placeholder_ - return node - - # otherwise must be stateful (which means we can't be inside a select() call) - else: - return RawFunction(f, args, kwargs) + return node # Remove the first argument from the wrapped function signature = inspect.signature(f)