diff --git a/guidance/_parser.py b/guidance/_parser.py index 600d0cb43..5974e998c 100644 --- a/guidance/_parser.py +++ b/guidance/_parser.py @@ -535,12 +535,16 @@ def _record_captures_partial(self, data, log_prob_data): def _record_captures_from_root(self, initial_item, data, log_prob_data): byte_data = self.bytes stack = [(initial_item, 0)] - used_names = ( - set() - ) # track which capture names have been used so self-recursive children don't overwrite their parents - + # track which capture names have been used so self-recursive children don't overwrite their parents + used_names = set() + # track which items we have already seen so each is processed only once; reprocessing them would cause an infinite loop + seen = set() while stack: item, byte_pos = stack.pop() + if (item, byte_pos) in seen: + # skip items we have already processed + continue + seen.add((item, byte_pos)) # terminal nodes if isinstance(item, Terminal): @@ -594,10 +598,14 @@ def _record_captures_from_root(self, initial_item, data, log_prob_data): def _compute_parse_tree(self, initial_pos, initial_item, reversed_state_sets): stack = [(initial_pos, initial_item)] - + # track which items we have already seen so each is processed only once; reprocessing them would cause an infinite loop + seen = set() while stack: pos, item = stack.pop() - + if (pos, item) in seen: + # skip items we have already processed + continue + seen.add((pos, item)) # compute the children for this item assert self._compute_children(pos, item, reversed_state_sets) diff --git a/setup.py b/setup.py index 19cec3a39..b5d54b062 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ "papermill", "pytest", "pytest-cov", + "pytest-timeout", "torch", "transformers", "mypy==1.9.0", diff --git a/tests/test_parser.py b/tests/test_parser.py index 63fafbd93..8d9e8ef43 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,5 +1,8 @@ +import pytest +from unittest.mock import patch + from guidance import char_set, one_or_more, select, string, zero_or_more -from 
guidance._grammar import Byte, ByteRange +from guidance._grammar import Byte, ByteRange, Select, Join from guidance._parser import EarleyCommitParser @@ -131,3 +134,120 @@ def test_string_utf8(): parser.consume_byte(b[:1]) assert parser.valid_next_bytes() == set([Byte(b[1:])]) parser.consume_byte(b[1:]) + + +class TestRecursiveNullableGrammars: + """ + Computing parse tree of recursive nullable grammars will cause an infinite + loop if not handled correctly + """ + + @pytest.mark.timeout(5) + def test_no_infinite_loop(self): + """ + A -> A + A -> + + Loop occurs because `A -> A` is a nullable rule + """ + # Note that we get a different grammar if we made `A = select([''], recurse=True)` + A = Select([], recursive=True) + A.values = [A, ""] + parser = EarleyCommitParser(A) + # Test that computing the parse tree doesn't hang + parser.parse_tree() + + @pytest.mark.timeout(5) + def test_no_infinite_loop_with_terminal(self): + """ + A -> A B + A -> + B -> 'x' + B -> + + Loop occurs because `A -> A B` is a nullable rule + """ + B = select(["x", ""]) + A = select([B, ""], recurse=True) + parser = EarleyCommitParser(A) + # Test that computing the parse tree doesn't hang + parser.parse_tree() + + @pytest.mark.timeout(5) + def test_no_infinite_loop_extra_indirection(self): + """ + A -> A C + A -> B + A -> + B -> A + C -> 'x' + + Loop occurs because `A -> B`, `B -> A` are nullable rules + """ + C = Join(["x"]) + # Initialize as nullable -- quirk in how nullability is determined in Select + B = Select([""]) + # Initialize as nullable -- quirk in how nullability is determined in Select + A = Select([""]) + B.values = [A] + A.values = [A + C, B, ""] + assert A.nullable + assert B.nullable + parser = EarleyCommitParser(A) + # Test that computing the parse tree doesn't hang + parser.parse_tree() + + @pytest.mark.timeout(5) + def test_captures_from_root(self): + B = select(["x", "y", ""], name="B") + A = select([B, ""], recurse=True, name="A") + parser = EarleyCommitParser(A) + + 
with patch.object( + parser, + "_record_captures_from_root", + wraps=parser._record_captures_from_root, + ) as mock: + parser.consume_byte(b"x") + captures, _ = parser.get_captures() + assert mock.call_count == 1 + assert captures == {"B": b"x", "A": b"x"} + + parser.consume_byte(b"y") + captures, _ = parser.get_captures() + assert mock.call_count == 2 + assert captures == {"B": b"y", "A": b"xy"} + + parser.consume_byte(b"x") + captures, _ = parser.get_captures() + assert mock.call_count == 3 + assert captures == {"B": b"x", "A": b"xyx"} + + @pytest.mark.timeout(5) + def test_partial_captures(self): + B = select(["x", "y", ""], name="B") + A = select([B, ""], recurse=True, name="A") + C = A + "z" + parser = EarleyCommitParser(C) + + with patch.object( + parser, + "_record_captures_partial", + wraps=parser._record_captures_partial, + ) as mock: + parser.consume_byte(b"x") + captures, _ = parser.get_captures() + assert mock.call_count == 1 + assert captures == {"B": b"", "A": b"x"} + + parser.consume_byte(b"y") + captures, _ = parser.get_captures() + assert mock.call_count == 2 + assert captures == {"B": b"", "A": b"xy"} + + # No new call to _record_captures_partial, but make sure that the captures are updated + # when finally called from root + parser.consume_byte(b"z") + captures, _ = parser.get_captures() + assert mock.call_count == 2 # no new call + assert captures == {"B": b"y", "A": b"xy"}