Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix infinite loop when computing parse tree for recursive nullable grammars #874

Closed
wants to merge 9 commits into from
20 changes: 14 additions & 6 deletions guidance/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,12 +535,16 @@ def _record_captures_partial(self, data, log_prob_data):
def _record_captures_from_root(self, initial_item, data, log_prob_data):
byte_data = self.bytes
stack = [(initial_item, 0)]
used_names = (
set()
) # track which capture names have been used so self-recursive children don't overwrite their parents

# track which capture names have been used so self-recursive children don't overwrite their parents
used_names = set()
# track which items we have seen so we don't process them multiple times, leading to infinite loops
seen = set()
while stack:
item, byte_pos = stack.pop()
if (item, byte_pos) in seen:
Copy link
Collaborator Author

@hudson-ai hudson-ai Jun 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just check item

# skip items we have already processed
continue
seen.add((item, byte_pos))
# terminal nodes
if isinstance(item, Terminal):

Expand Down Expand Up @@ -594,10 +598,14 @@ def _record_captures_from_root(self, initial_item, data, log_prob_data):

def _compute_parse_tree(self, initial_pos, initial_item, reversed_state_sets):
stack = [(initial_pos, initial_item)]

# track which items we have seen so we don't process them multiple times, leading to infinite loops
seen = set()
while stack:
pos, item = stack.pop()

if (pos, item) in seen:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here -- just need to check item?

# skip items we have already processed
continue
seen.add((pos, item))
# compute the children for this item
assert self._compute_children(pos, item, reversed_state_sets)

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"papermill",
"pytest",
"pytest-cov",
"pytest-timeout",
"torch",
"transformers",
"mypy==1.9.0",
Expand Down
123 changes: 122 additions & 1 deletion tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest
from unittest.mock import patch

from guidance import char_set, one_or_more, select, string, zero_or_more
from guidance._grammar import Byte, ByteRange
from guidance._grammar import Byte, ByteRange, Select, Join
from guidance._parser import EarleyCommitParser


Expand Down Expand Up @@ -131,3 +134,121 @@ def test_string_utf8():
parser.consume_byte(b[:1])
assert parser.valid_next_bytes() == set([Byte(b[1:])])
parser.consume_byte(b[1:])


class TestRecursiveNullableGrammars:
"""
Computing parse tree of recursive nullable grammars will cause an infinite
loop if not handled correctly
"""

@pytest.mark.timeout(5)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in love with checking for an infinite loop by introducing a timeout, but it wasn't straight-forward to catch this otherwise. Anyone have a suggestion?

5 seconds should be more than ample to compute the parse tree...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I spent some time thinking about this, and couldn't come up with anything better 🤷

def test_no_infinite_loop(self):
"""
A -> A
A ->

Loop occurs because `A -> A` is a nullable rule
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens with the JSON 'linked list' here? Doesn't that come down to something like this?

Copy link
Collaborator Author

@hudson-ai hudson-ai Jun 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. This is pretty subtle and took me a while to understand.

The grammar

A -> A 'x'
A ->

is totally fine because the first rule isn't nullable by the definition I gave above because 'x' isn't nullable -- only A is.

I.e.

zero_or_more('x') is fine; zero_or_more(optional('x')) isn't.

The latter looks like

A -> A B
A ->
B -> 'x'
B ->

Linked lists look more like the former.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose the fact that the existing JSON tests passed is a good sign, but I'm still trying to work out these grammars.

"""
# Note that we get a different grammar if we made `A = select([''], recurse=True)`
A = Select([], recursive=True)
A.values = [A, ""]
parser = EarleyCommitParser(A)
# Test that computing the parse tree doesn't hang
parser.parse_tree()
# Test that getting captures doesn't hang
parser.get_captures()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you assert anything about the last two operations?

Also, if either of them can hang on an infinite loop, would it be worth having two tests, to identify which one hit the loop (although that may be more than needed).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reserved making asserts for some of the other tests, but I'll have a think 😄

get_captures actually calls parse_tree under the hood BUT it can also hang even if parse_tree doesn't. I can maybe put another mock in here to disentangle them a bit..?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't worry about injecting another mock. I was just wondering if there was any visible state change which could be asserted after each step?


@pytest.mark.timeout(5)
def test_no_infinite_loop_with_terminal(self):
"""
A -> A B
A ->
B -> 'x'
B ->

Loop occurs because `A -> A B` is a nullable rule
"""
B = select(["x", ""])
A = select([B, ""], recurse=True)
parser = EarleyCommitParser(A)
# Test that computing the parse tree doesn't hang
parser.parse_tree()
# Test that getting captures doesn't hang
parser.get_captures()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For all of these tests, it would be nice if there was something to assert about the state of the parser at the end.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. Will keep thinking about this!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Working on making some assertions here, I became dissatisfied with the fact that my fix doesn't actually prevent infinite parse trees like A -> A -> ...; it only prevents looping over these parse trees forever. The behavior is right in that we're still able to correctly get captures, so maybe this is fine. But marking this as a draft in the meantime while I see if I can get a bit closer to the root of the issue.


@pytest.mark.timeout(5)
def test_no_infinite_loop_extra_indirection(self):
"""
A -> A C
A -> B
A ->
B -> A
C -> 'x'

Loop occurs because `A -> B`, `B -> A` are nullable rules
"""
C = Join(["x"])
# Initialize as nullable -- quirk in how nullability is determined in Select
B = Select([""])
# Initialize as nullable -- quirk in how nullability is determined in Select
A = Select([""])
B.values = [A]
A.values = [A + C, B, ""]
assert A.nullable
assert B.nullable
parser = EarleyCommitParser(A)
# Test that computing the parse tree doesn't hang
parser.parse_tree()
# Test that getting captures doesn't hang
parser.get_captures()

@pytest.mark.timeout(5)
def test_captures_from_root(self):
B = select(["x", ""], name="B")
A = select([B, ""], recurse=True, name="A")
parser = EarleyCommitParser(A)

with patch.object(
parser,
"_record_captures_from_root",
wraps=parser._record_captures_from_root,
) as mock:
parser.consume_byte(b"x")
captures, _ = parser.get_captures()
assert mock.call_count == 1
assert captures == {"B": b"x", "A": b"x"}

parser.consume_byte(b"x")
captures, _ = parser.get_captures()
assert mock.call_count == 2
assert captures == {"B": b"x", "A": b"xx"}

@pytest.mark.timeout(5)
def test_partial_captures(self):
B = select(["x", ""], name="B")
A = select([B, ""], recurse=True, name="A")
C = A + "y"
parser = EarleyCommitParser(C)

with patch.object(
parser,
"_record_captures_partial",
wraps=parser._record_captures_partial,
) as mock:
parser.consume_byte(b"x")
captures, _ = parser.get_captures()
assert mock.call_count == 1
assert captures == {"B": b"", "A": b"x"}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick note here: I'm not sure what the semantics of partial captures are supposed to be, but this test is consistent with the existing behavior without these nullable loops


parser.consume_byte(b"x")
captures, _ = parser.get_captures()
assert mock.call_count == 2
assert captures == {"B": b"", "A": b"xx"}

# No new call to _record_captures_partial, but make sure that the captures are updated
# when finally called from root
parser.consume_byte(b"y")
captures, _ = parser.get_captures()
assert mock.call_count == 2 # no new call
assert captures == {"B": b"x", "A": b"xx"}