-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Fix infinite loop when computing parse tree for recursive nullable grammars #874
Changes from all commits
f17e434
de33cf5
5617cf2
e04d693
ca2fca9
40533bb
3f2a3ef
7c64496
867adbe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -535,12 +535,16 @@ def _record_captures_partial(self, data, log_prob_data): | |
def _record_captures_from_root(self, initial_item, data, log_prob_data): | ||
byte_data = self.bytes | ||
stack = [(initial_item, 0)] | ||
used_names = ( | ||
set() | ||
) # track which capture names have been used so self-recursive children don't overwrite their parents | ||
|
||
# track which capture names have been used so self-recursive children don't overwrite their parents | ||
used_names = set() | ||
# track which items we have seen so we don't process them multiple times, leading to infinite loops | ||
seen = set() | ||
while stack: | ||
item, byte_pos = stack.pop() | ||
if (item, byte_pos) in seen: | ||
# skip items we have already processed | ||
continue | ||
seen.add((item, byte_pos)) | ||
# terminal nodes | ||
if isinstance(item, Terminal): | ||
|
||
|
@@ -594,10 +598,14 @@ def _record_captures_from_root(self, initial_item, data, log_prob_data): | |
|
||
def _compute_parse_tree(self, initial_pos, initial_item, reversed_state_sets): | ||
stack = [(initial_pos, initial_item)] | ||
|
||
# track which items we have seen so we don't process them multiple times, leading to infinite loops | ||
seen = set() | ||
while stack: | ||
pos, item = stack.pop() | ||
|
||
if (pos, item) in seen: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here -- just need to check |
||
# skip items we have already processed | ||
continue | ||
seen.add((pos, item)) | ||
# compute the children for this item | ||
assert self._compute_children(pos, item, reversed_state_sets) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,7 @@ | |
"papermill", | ||
"pytest", | ||
"pytest-cov", | ||
"pytest-timeout", | ||
"torch", | ||
"transformers", | ||
"mypy==1.9.0", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
import pytest | ||
from unittest.mock import patch | ||
|
||
from guidance import char_set, one_or_more, select, string, zero_or_more | ||
from guidance._grammar import Byte, ByteRange | ||
from guidance._grammar import Byte, ByteRange, Select, Join | ||
from guidance._parser import EarleyCommitParser | ||
|
||
|
||
|
@@ -131,3 +134,120 @@ def test_string_utf8(): | |
parser.consume_byte(b[:1]) | ||
assert parser.valid_next_bytes() == set([Byte(b[1:])]) | ||
parser.consume_byte(b[1:]) | ||
|
||
|
||
class TestRecursiveNullableGrammars: | ||
""" | ||
Computing parse tree of recursive nullable grammars will cause an infinite | ||
loop if not handled correctly | ||
""" | ||
|
||
@pytest.mark.timeout(5) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not in love with checking for an infinite loop by introducing a timeout, but it wasn't straight-forward to catch this otherwise. Anyone have a suggestion? 5 seconds should be more than ample to compute the parse tree... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I spent some time thinking about this, and couldn't come up with anything better 🤷 |
||
def test_no_infinite_loop(self): | ||
""" | ||
A -> A | ||
A -> | ||
|
||
Loop occurs because `A -> A` is a nullable rule | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens with the JSON 'linked list' here? Doesn't that come down to something like this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question. This is pretty subtle and took me a while to understand. The grammar
is totally fine because the first rule isn't nullable by the definition I gave above because I.e.
The latter looks like
Linked lists look more like the former. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suppose the fact that the existing JSON tests passed is a good sign, but I'm still trying to work out these grammars. |
||
""" | ||
# Note that we get a different grammar if we made `A = select([''], recurse=True)` | ||
A = Select([], recursive=True) | ||
A.values = [A, ""] | ||
parser = EarleyCommitParser(A) | ||
# Test that computing the parse tree doesn't hang | ||
parser.parse_tree() | ||
|
||
@pytest.mark.timeout(5) | ||
def test_no_infinite_loop_with_terminal(self): | ||
""" | ||
A -> A B | ||
A -> | ||
B -> 'x' | ||
B -> | ||
|
||
Loop occurs because `A -> A B` is a nullable rule | ||
""" | ||
B = select(["x", ""]) | ||
A = select([B, ""], recurse=True) | ||
parser = EarleyCommitParser(A) | ||
# Test that computing the parse tree doesn't hang | ||
parser.parse_tree() | ||
|
||
@pytest.mark.timeout(5) | ||
def test_no_infinite_loop_extra_indirection(self): | ||
""" | ||
A -> A C | ||
A -> B | ||
A -> | ||
B -> A | ||
C -> 'x' | ||
|
||
Loop occurs because `A -> B`, `B -> A` are nullable rules | ||
""" | ||
C = Join(["x"]) | ||
# Initialize as nullable -- quirk in how nullability is determined in Select | ||
B = Select([""]) | ||
# Initialize as nullable -- quirk in how nullability is determined in Select | ||
A = Select([""]) | ||
B.values = [A] | ||
A.values = [A + C, B, ""] | ||
assert A.nullable | ||
assert B.nullable | ||
parser = EarleyCommitParser(A) | ||
# Test that computing the parse tree doesn't hang | ||
parser.parse_tree() | ||
|
||
@pytest.mark.timeout(5) | ||
def test_captures_from_root(self): | ||
B = select(["x", "y", ""], name="B") | ||
A = select([B, ""], recurse=True, name="A") | ||
parser = EarleyCommitParser(A) | ||
|
||
with patch.object( | ||
parser, | ||
"_record_captures_from_root", | ||
wraps=parser._record_captures_from_root, | ||
) as mock: | ||
parser.consume_byte(b"x") | ||
captures, _ = parser.get_captures() | ||
assert mock.call_count == 1 | ||
assert captures == {"B": b"x", "A": b"x"} | ||
|
||
parser.consume_byte(b"y") | ||
captures, _ = parser.get_captures() | ||
assert mock.call_count == 2 | ||
assert captures == {"B": b"y", "A": b"xy"} | ||
|
||
parser.consume_byte(b"x") | ||
captures, _ = parser.get_captures() | ||
assert mock.call_count == 3 | ||
assert captures == {"B": b"x", "A": b"xyx"} | ||
|
||
@pytest.mark.timeout(5) | ||
def test_partial_captures(self): | ||
B = select(["x", "y", ""], name="B") | ||
A = select([B, ""], recurse=True, name="A") | ||
C = A + "z" | ||
parser = EarleyCommitParser(C) | ||
|
||
with patch.object( | ||
parser, | ||
"_record_captures_partial", | ||
wraps=parser._record_captures_partial, | ||
) as mock: | ||
parser.consume_byte(b"x") | ||
captures, _ = parser.get_captures() | ||
assert mock.call_count == 1 | ||
assert captures == {"B": b"", "A": b"x"} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Quick note here: I'm not sure what the semantics of partial captures are supposed to be, but this test is consistent with the existing behavior without these nullable loops |
||
|
||
parser.consume_byte(b"y") | ||
captures, _ = parser.get_captures() | ||
assert mock.call_count == 2 | ||
assert captures == {"B": b"", "A": b"xy"} | ||
|
||
# No new call to _record_captures_partial, but make sure that the captures are updated | ||
# when finally called from root | ||
parser.consume_byte(b"z") | ||
captures, _ = parser.get_captures() | ||
assert mock.call_count == 2 # no new call | ||
assert captures == {"B": b"y", "A": b"xy"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe just check
item