From f17e434d512afa502975b62c9f07777e8bd7ba0f Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Fri, 31 May 2024 18:06:27 -0700 Subject: [PATCH 1/8] First pass at a fix for #863 --- guidance/_parser.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/guidance/_parser.py b/guidance/_parser.py index 600d0cb43..331f2210e 100644 --- a/guidance/_parser.py +++ b/guidance/_parser.py @@ -594,10 +594,14 @@ def _record_captures_from_root(self, initial_item, data, log_prob_data): def _compute_parse_tree(self, initial_pos, initial_item, reversed_state_sets): stack = [(initial_pos, initial_item)] + seen = set() while stack: pos, item = stack.pop() - + if (pos, item) in seen: + # Skip items we have already processed + continue + seen.add((pos, item)) # compute the children for this item assert self._compute_children(pos, item, reversed_state_sets) From de33cf57d062bcbc51b922dd951f6b453f9d4dd2 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Sat, 1 Jun 2024 17:16:50 -0700 Subject: [PATCH 2/8] Add fix to _record_captures_from_root --- guidance/_parser.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/guidance/_parser.py b/guidance/_parser.py index 331f2210e..3ec79a180 100644 --- a/guidance/_parser.py +++ b/guidance/_parser.py @@ -538,9 +538,12 @@ def _record_captures_from_root(self, initial_item, data, log_prob_data): used_names = ( set() ) # track which capture names have been used so self-recursive children don't overwrite their parents - + seen = set() while stack: item, byte_pos = stack.pop() + if (item, byte_pos) in seen: + continue + seen.add((item, byte_pos)) # terminal nodes if isinstance(item, Terminal): From 5617cf2cf50cc7dd57c8ae0cae5f74039eb8feba Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Sat, 1 Jun 2024 17:53:37 -0700 Subject: [PATCH 3/8] Add tests (+ pytest-timeout test dependency) --- setup.py | 1 + tests/test_parser.py | 49 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 19cec3a39..b5d54b062 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ "papermill", "pytest", "pytest-cov", + "pytest-timeout", "torch", "transformers", "mypy==1.9.0", diff --git a/tests/test_parser.py b/tests/test_parser.py index 63fafbd93..451d2e821 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,5 +1,7 @@ +import pytest + from guidance import char_set, one_or_more, select, string, zero_or_more -from guidance._grammar import Byte, ByteRange +from guidance._grammar import Byte, ByteRange, Select from guidance._parser import EarleyCommitParser @@ -131,3 +133,48 @@ def test_string_utf8(): parser.consume_byte(b[:1]) assert parser.valid_next_bytes() == set([Byte(b[1:])]) parser.consume_byte(b[1:]) + + +class TestRecursiveNullableGrammars: + @pytest.mark.timeout(5) + def test_no_infinite_loop(self): + """ + A -> A + A -> + """ + # Note that we get a different grammar if we made `A = select([''], recurse=True)` + A = Select([], recursive=True) + A.values = [A, ""] + parser = EarleyCommitParser(A) + # Test that computing the parse tree doesn't hang + parser.parse_tree() + # Test that getting captures doesn't hang + parser.get_captures() + + @pytest.mark.timeout(5) + def test_no_infinite_loop_with_terminal(self): + """ + A -> A B + A -> + B -> 'x' + B -> + """ + B = select(["x", ""]) + A = select([B, ""], recurse=True) + parser = EarleyCommitParser(A) + # Test that computing the parse tree doesn't hang + parser.parse_tree() + # Test that getting captures doesn't hang + parser.get_captures() + + @pytest.mark.timeout(5) + def test_captures(self): + B = select(["x", ""], name="B") + A = select([B, ""], recurse=True, name="A") + parser = EarleyCommitParser(A) + parser.consume_byte(b"x") + captures, _ = parser.get_captures() + assert captures == {"B": b"x", "A": b"x"} + parser.consume_byte(b"x") + captures, _ = parser.get_captures() + assert captures == {"B": b"x", "A": b"xx"} From e04d693566bbdfb9126c89a3dd63376b6e69a806 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Sat, 1 Jun 2024 18:20:32 -0700 Subject: [PATCH 4/8] Test both partial and full captures --- tests/test_parser.py | 53 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/tests/test_parser.py b/tests/test_parser.py index 451d2e821..2fc9fce8e 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,4 +1,5 @@ import pytest +from unittest.mock import patch from guidance import char_set, one_or_more, select, string, zero_or_more from guidance._grammar import Byte, ByteRange, Select @@ -168,13 +169,51 @@ def test_no_infinite_loop_with_terminal(self): parser.get_captures() @pytest.mark.timeout(5) - def test_captures(self): + def test_captures_from_root(self): B = select(["x", ""], name="B") A = select([B, ""], recurse=True, name="A") parser = EarleyCommitParser(A) - parser.consume_byte(b"x") - captures, _ = parser.get_captures() - assert captures == {"B": b"x", "A": b"x"} - parser.consume_byte(b"x") - captures, _ = parser.get_captures() - assert captures == {"B": b"x", "A": b"xx"} + + with patch.object( + parser, + "_record_captures_from_root", + wraps=parser._record_captures_from_root, + ) as mock: + parser.consume_byte(b"x") + captures, _ = parser.get_captures() + assert mock.call_count == 1 + assert captures == {"B": b"x", "A": b"x"} + + parser.consume_byte(b"x") + captures, _ = parser.get_captures() + assert mock.call_count == 2 + assert captures == {"B": b"x", "A": b"xx"} + + @pytest.mark.timeout(5) + def test_partial_captures(self): + B = select(["x", ""], name="B") + A = select([B, ""], recurse=True, name="A") + C = A + "y" + parser = EarleyCommitParser(C) + + with patch.object( + parser, + "_record_captures_partial", + wraps=parser._record_captures_partial, + ) as mock: + parser.consume_byte(b"x") + captures, _ = parser.get_captures() + assert mock.call_count == 1 + assert captures == {"B": b"", "A": b"x"} + + parser.consume_byte(b"x") + captures, _ = parser.get_captures() + assert mock.call_count == 2 + assert captures == {"B": b"", "A": b"xx"} + + # No new call to _record_captures_partial, but make sure that the captures are updated + # when finally called from root + parser.consume_byte(b"y") + captures, _ = parser.get_captures() + assert mock.call_count == 2 # no new call + assert captures == {"B": b"x", "A": b"xx"} From 40533bb65bc04c4938c4d22264d6bb681a0c3a63 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Sat, 1 Jun 2024 18:25:36 -0700 Subject: [PATCH 5/8] comments --- guidance/_parser.py | 11 ++++++----- tests/test_parser.py | 4 ++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/guidance/_parser.py b/guidance/_parser.py index 3ec79a180..5974e998c 100644 --- a/guidance/_parser.py +++ b/guidance/_parser.py @@ -535,13 +535,14 @@ def _record_captures_partial(self, data, log_prob_data): def _record_captures_from_root(self, initial_item, data, log_prob_data): byte_data = self.bytes stack = [(initial_item, 0)] - used_names = ( - set() - ) # track which capture names have been used so self-recursive children don't overwrite their parents + # track which capture names have been used so self-recursive children don't overwrite their parents + used_names = set() + # track which items we have seen so we don't process them multiple times, leading to infinite loops seen = set() while stack: item, byte_pos = stack.pop() if (item, byte_pos) in seen: + # skip items we have already processed continue seen.add((item, byte_pos)) # terminal nodes @@ -597,12 +598,12 @@ def _record_captures_from_root(self, initial_item, data, log_prob_data): def _compute_parse_tree(self, initial_pos, initial_item, reversed_state_sets): stack = [(initial_pos, initial_item)] + # track which items we have seen so we don't process them multiple times, leading to infinite loops seen = set() - while stack: pos, item = stack.pop() if (pos, item) in seen: - # Skip items we have already processed + # skip items we have already processed continue seen.add((pos, item)) # compute the children for this item diff --git a/tests/test_parser.py b/tests/test_parser.py index 2fc9fce8e..848314ec1 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -137,6 +137,10 @@ def test_string_utf8(): class TestRecursiveNullableGrammars: + """ + Computing parse tree of recursive nullable grammars will cause an infinite + loop if not handled correctly + """ @pytest.mark.timeout(5) def test_no_infinite_loop(self): """ From 3f2a3ef52789d2baa6a9a0fe8eeff34c1491db85 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Sat, 1 Jun 2024 21:24:20 -0700 Subject: [PATCH 6/8] One more test to reflect Loup Vaillant's tutorial --- tests/test_parser.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/test_parser.py b/tests/test_parser.py index 848314ec1..40b23779d 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -2,7 +2,7 @@ from unittest.mock import patch from guidance import char_set, one_or_more, select, string, zero_or_more -from guidance._grammar import Byte, ByteRange, Select +from guidance._grammar import Byte, ByteRange, Select, Join from guidance._parser import EarleyCommitParser @@ -141,11 +141,14 @@ class TestRecursiveNullableGrammars: Computing parse tree of recursive nullable grammars will cause an infinite loop if not handled correctly """ + @pytest.mark.timeout(5) def test_no_infinite_loop(self): """ A -> A A -> + + Loop occurs because `A -> A` is a nullable rule """ # Note that we get a different grammar if we made `A = select([''], recurse=True)` A = Select([], recursive=True) @@ -163,6 +166,8 @@ def test_no_infinite_loop_with_terminal(self): A -> B -> 'x' B -> + + Loop occurs because `A -> A B` is a nullable rule """ B = select(["x", ""]) A = select([B, ""], recurse=True) @@ -172,6 +177,32 @@ def test_no_infinite_loop_with_terminal(self): # Test that getting captures doesn't hang parser.get_captures() + @pytest.mark.timeout(5) + def test_no_infinite_loop_extra_indirection(self): + """ + A -> A C + A -> B + A -> + B -> A + C -> 'x' + + Loop occurs because `A -> B`, `B -> A` are nullable rules + """ + C = Join(["x"]) + # Initialize as nullable -- quirk in how nullability is determined in Select + B = Select([""]) + # Initialize as nullable -- quirk in how nullability is determined in Select + A = Select([""]) + B.values = [A] + A.values = [A + C, B, ""] + assert A.nullable + assert B.nullable + parser = EarleyCommitParser(A) + # Test that computing the parse tree doesn't hang + parser.parse_tree() + # Test that getting captures doesn't hang + parser.get_captures() + @pytest.mark.timeout(5) def test_captures_from_root(self): B = select(["x", ""], name="B") From 7c644967797687c61a37147d1a3e34076302ea95 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 3 Jun 2024 15:30:55 -0700 Subject: [PATCH 7/8] Add other option for B to better track capture state --- tests/test_parser.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/test_parser.py b/tests/test_parser.py index 40b23779d..9bce37a02 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -205,7 +205,7 @@ def test_no_infinite_loop_extra_indirection(self): @pytest.mark.timeout(5) def test_captures_from_root(self): - B = select(["x", ""], name="B") + B = select(["x", "y", ""], name="B") A = select([B, ""], recurse=True, name="A") parser = EarleyCommitParser(A) @@ -219,16 +219,21 @@ def test_captures_from_root(self): assert mock.call_count == 1 assert captures == {"B": b"x", "A": b"x"} - parser.consume_byte(b"x") + parser.consume_byte(b"y") captures, _ = parser.get_captures() assert mock.call_count == 2 - assert captures == {"B": b"x", "A": b"xx"} + assert captures == {"B": b"y", "A": b"xy"} + + parser.consume_byte(b"x") + captures, _ = parser.get_captures() + assert mock.call_count == 3 + assert captures == {"B": b"x", "A": b"xyx"} @pytest.mark.timeout(5) def test_partial_captures(self): - B = select(["x", ""], name="B") + B = select(["x", "y", ""], name="B") A = select([B, ""], recurse=True, name="A") - C = A + "y" + C = A + "z" parser = EarleyCommitParser(C) with patch.object( @@ -241,14 +246,14 @@ def test_partial_captures(self): assert mock.call_count == 1 assert captures == {"B": b"", "A": b"x"} - parser.consume_byte(b"x") + parser.consume_byte(b"y") captures, _ = parser.get_captures() assert mock.call_count == 2 - assert captures == {"B": b"", "A": b"xx"} + assert captures == {"B": b"", "A": b"xy"} # No new call to _record_captures_partial, but make sure that the captures are updated # when finally called from root - parser.consume_byte(b"y") + parser.consume_byte(b"z") captures, _ = parser.get_captures() assert mock.call_count == 2 # no new call - assert captures == {"B": b"x", "A": b"xx"} + assert captures == {"B": b"y", "A": b"xy"} From 867adbe8b3d7ae9e203693db3d0024466444db7d Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 3 Jun 2024 15:36:28 -0700 Subject: [PATCH 8/8] keep parse_tree and get_captures tests separate --- tests/test_parser.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_parser.py b/tests/test_parser.py index 9bce37a02..8d9e8ef43 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -156,8 +156,6 @@ def test_no_infinite_loop(self): parser = EarleyCommitParser(A) # Test that computing the parse tree doesn't hang parser.parse_tree() - # Test that getting captures doesn't hang - parser.get_captures() @pytest.mark.timeout(5) def test_no_infinite_loop_with_terminal(self): @@ -174,8 +172,6 @@ def test_no_infinite_loop_with_terminal(self): parser = EarleyCommitParser(A) # Test that computing the parse tree doesn't hang parser.parse_tree() - # Test that getting captures doesn't hang - parser.get_captures() @pytest.mark.timeout(5) def test_no_infinite_loop_extra_indirection(self): @@ -200,8 +196,6 @@ def test_no_infinite_loop_extra_indirection(self): parser = EarleyCommitParser(A) # Test that computing the parse tree doesn't hang parser.parse_tree() - # Test that getting captures doesn't hang - parser.get_captures() @pytest.mark.timeout(5) def test_captures_from_root(self):