Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix infinite loop when computing parse tree for recursive nullable grammars #874

Closed
wants to merge 9 commits into from
20 changes: 14 additions & 6 deletions guidance/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,12 +535,16 @@ def _record_captures_partial(self, data, log_prob_data):
def _record_captures_from_root(self, initial_item, data, log_prob_data):
byte_data = self.bytes
stack = [(initial_item, 0)]
used_names = (
set()
) # track which capture names have been used so self-recursive children don't overwrite their parents

# track which capture names have been used so self-recursive children don't overwrite their parents
used_names = set()
# track which items we have seen so we don't process them multiple times, leading to infinite loops
seen = set()
while stack:
item, byte_pos = stack.pop()
if (item, byte_pos) in seen:
Copy link
Collaborator Author

@hudson-ai hudson-ai Jun 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just check item

# skip items we have already processed
continue
seen.add((item, byte_pos))
# terminal nodes
if isinstance(item, Terminal):

Expand Down Expand Up @@ -594,10 +598,14 @@ def _record_captures_from_root(self, initial_item, data, log_prob_data):

def _compute_parse_tree(self, initial_pos, initial_item, reversed_state_sets):
stack = [(initial_pos, initial_item)]

# track which items we have seen so we don't process them multiple times, leading to infinite loops
seen = set()
while stack:
pos, item = stack.pop()

if (pos, item) in seen:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here -- just need to check item?

# skip items we have already processed
continue
seen.add((pos, item))
# compute the children for this item
assert self._compute_children(pos, item, reversed_state_sets)

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"papermill",
"pytest",
"pytest-cov",
"pytest-timeout",
"torch",
"transformers",
"mypy==1.9.0",
Expand Down
122 changes: 121 additions & 1 deletion tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest
from unittest.mock import patch

from guidance import char_set, one_or_more, select, string, zero_or_more
from guidance._grammar import Byte, ByteRange
from guidance._grammar import Byte, ByteRange, Select, Join
from guidance._parser import EarleyCommitParser


Expand Down Expand Up @@ -131,3 +134,120 @@ def test_string_utf8():
parser.consume_byte(b[:1])
assert parser.valid_next_bytes() == set([Byte(b[1:])])
parser.consume_byte(b[1:])


class TestRecursiveNullableGrammars:
"""
Computing parse tree of recursive nullable grammars will cause an infinite
loop if not handled correctly
"""

@pytest.mark.timeout(5)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in love with checking for an infinite loop by introducing a timeout, but it wasn't straight-forward to catch this otherwise. Anyone have a suggestion?

5 seconds should be more than ample to compute the parse tree...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I spent some time thinking about this, and couldn't come up with anything better 🤷

def test_no_infinite_loop(self):
"""
A -> A
A ->

Loop occurs because `A -> A` is a nullable rule
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens with the JSON 'linked list' here? Doesn't that come down to something like this?

Copy link
Collaborator Author

@hudson-ai hudson-ai Jun 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. This is pretty subtle and took me a while to understand.

The grammar

A -> A 'x'
A ->

is totally fine because the first rule isn't nullable by the definition I gave above because 'x' isn't nullable -- only A is.

I.e.

zero_or_more('x') is fine; zero_or_more(optional('x')) isn't.

The latter looks like

A -> A B
A ->
B -> 'x'
B ->

Linked lists look more like the former.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose the fact that the existing JSON tests passed is a good sign, but I'm still trying to work out these grammars.

"""
# Note that we get a different grammar if we made `A = select([''], recurse=True)`
A = Select([], recursive=True)
A.values = [A, ""]
parser = EarleyCommitParser(A)
# Test that computing the parse tree doesn't hang
parser.parse_tree()

@pytest.mark.timeout(5)
def test_no_infinite_loop_with_terminal(self):
"""
A -> A B
A ->
B -> 'x'
B ->

Loop occurs because `A -> A B` is a nullable rule
"""
B = select(["x", ""])
A = select([B, ""], recurse=True)
parser = EarleyCommitParser(A)
# Test that computing the parse tree doesn't hang
parser.parse_tree()

@pytest.mark.timeout(5)
def test_no_infinite_loop_extra_indirection(self):
"""
A -> A C
A -> B
A ->
B -> A
C -> 'x'

Loop occurs because `A -> B`, `B -> A` are nullable rules
"""
C = Join(["x"])
# Initialize as nullable -- quirk in how nullability is determined in Select
B = Select([""])
# Initialize as nullable -- quirk in how nullability is determined in Select
A = Select([""])
B.values = [A]
A.values = [A + C, B, ""]
assert A.nullable
assert B.nullable
parser = EarleyCommitParser(A)
# Test that computing the parse tree doesn't hang
parser.parse_tree()

@pytest.mark.timeout(5)
def test_captures_from_root(self):
B = select(["x", "y", ""], name="B")
A = select([B, ""], recurse=True, name="A")
parser = EarleyCommitParser(A)

with patch.object(
parser,
"_record_captures_from_root",
wraps=parser._record_captures_from_root,
) as mock:
parser.consume_byte(b"x")
captures, _ = parser.get_captures()
assert mock.call_count == 1
assert captures == {"B": b"x", "A": b"x"}

parser.consume_byte(b"y")
captures, _ = parser.get_captures()
assert mock.call_count == 2
assert captures == {"B": b"y", "A": b"xy"}

parser.consume_byte(b"x")
captures, _ = parser.get_captures()
assert mock.call_count == 3
assert captures == {"B": b"x", "A": b"xyx"}

@pytest.mark.timeout(5)
def test_partial_captures(self):
B = select(["x", "y", ""], name="B")
A = select([B, ""], recurse=True, name="A")
C = A + "z"
parser = EarleyCommitParser(C)

with patch.object(
parser,
"_record_captures_partial",
wraps=parser._record_captures_partial,
) as mock:
parser.consume_byte(b"x")
captures, _ = parser.get_captures()
assert mock.call_count == 1
assert captures == {"B": b"", "A": b"x"}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick note here: I'm not sure what the semantics of partial captures are supposed to be, but this test is consistent with the existing behavior without these nullable loops


parser.consume_byte(b"y")
captures, _ = parser.get_captures()
assert mock.call_count == 2
assert captures == {"B": b"", "A": b"xy"}

# No new call to _record_captures_partial, but make sure that the captures are updated
# when finally called from root
parser.consume_byte(b"z")
captures, _ = parser.get_captures()
assert mock.call_count == 2 # no new call
assert captures == {"B": b"y", "A": b"xy"}
Loading