Skip to content

Commit

Permalink
Clean up placeholders and add token count stateful check
Browse files Browse the repository at this point in the history
  • Loading branch information
slundberg committed May 6, 2024
1 parent 7d7a96b commit bed7492
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
4 changes: 4 additions & 0 deletions guidance/_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ def __radd__(self, value):

def __getitem__(self, value):
raise StatefulException("GrammarFunctions can't access state!")

@property
def token_count(self):
raise StatefulException("GrammarFunctions can't access state!")

def match(
self,
Expand Down
51 changes: 26 additions & 25 deletions guidance/_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,37 +35,38 @@ def _decorator(f, *, stateless, cache, dedent, model):
@functools.wraps(f)
def wrapped(*args, **kwargs):

# make a stateless grammar if we can
try:

# if we have a placeholder set then we must be in a recursive definition and so we return the placeholder
placeholder = getattr(f, "_self_call_placeholder_", None)
if placeholder is not None:
return placeholder

# otherwise we call the function to generate the grammar
else:

# set a placeholder for recursive calls (only if we don't have arguments that might make caching a bad idea)
no_args = len(args) + len(kwargs) == 0
if no_args:
f._self_call_placeholder_ = Placeholder()

# call the function to get the grammar node
# if we have a placeholder set then we must be in a recursive definition and so we return the placeholder
placeholder = getattr(f, "_self_call_placeholder_", None)
if placeholder is not None:
return placeholder

# otherwise we call the function to generate the grammar
else:

# set a placeholder for recursive calls (only if we don't have arguments that might make caching a bad idea)
no_args = len(args) + len(kwargs) == 0
if no_args:
f._self_call_placeholder_ = Placeholder()

# try to trace the function to get a grammar node
node = None
try:
node = f(_null_grammar, *args, **kwargs)
if not isinstance(node, (Terminal, str)):
node.name = f.__name__

# replace all the placeholders with our generated node

# if that fails we must be stateful (which means we can't be inside a select() call)
except StatefulException:
return RawFunction(f, args, kwargs)

# clean up, replacing all the placeholders with our generated node
finally:
if no_args:
replace_grammar_node(node, f._self_call_placeholder_, node)
if node:
replace_grammar_node(node, f._self_call_placeholder_, node)
del f._self_call_placeholder_

return node

# otherwise must be stateful (which means we can't be inside a select() call)
except StatefulException:
return RawFunction(f, args, kwargs)
return node

# Remove the first argument from the wrapped function
signature = inspect.signature(f)
Expand Down

0 comments on commit bed7492

Please sign in to comment.