diff --git a/src/lightspeed_rag_content/document_processor.py b/src/lightspeed_rag_content/document_processor.py index 6ff00ccd..4ff2e01a 100644 --- a/src/lightspeed_rag_content/document_processor.py +++ b/src/lightspeed_rag_content/document_processor.py @@ -84,22 +84,35 @@ def _got_whitespace(text: str) -> bool: return True return False - @classmethod - def _filter_out_invalid_nodes(cls, nodes: list[Any]) -> list[TextNode]: + @staticmethod + def _got_non_headers(text: str) -> bool: + """Check if text has content besides markdown headers.""" + for line in text.splitlines(): + line = line.strip() + if line and not line.startswith("#"): + return True + return False + + def _valid_text_node(self, text: str) -> bool: + """Check if text node is valid: has whitespace and has content.""" + if self.config.doc_type == "markdown" and not self._got_non_headers(text): + return False + return self._got_whitespace(text) + + def _filter_out_invalid_nodes(self, nodes: list[Any]) -> list[TextNode]: """Filter out invalid nodes.""" good_nodes = [] for node in nodes: - if isinstance(node, TextNode) and cls._got_whitespace(node.text): + if isinstance(node, TextNode) and self._valid_text_node(node.text): # Exclude given metadata during embedding good_nodes.append(node) else: - LOG.debug("Skipping node without whitespace: %s", repr(node)) + LOG.debug("Skipping invalid node: %s", repr(node)) return good_nodes - @classmethod - def _split_and_filter(cls, docs: list[Document]) -> list[TextNode]: + def _split_and_filter(self, docs: list[Document]) -> list[TextNode]: nodes = Settings.text_splitter.get_nodes_from_documents(docs) - valid_nodes = cls._filter_out_invalid_nodes(nodes) + valid_nodes = self._filter_out_invalid_nodes(nodes) return valid_nodes diff --git a/tests/test_document_processor_llama_index.py b/tests/test_document_processor_llama_index.py index 846c3f12..4f371531 100644 --- a/tests/test_document_processor_llama_index.py +++ b/tests/test_document_processor_llama_index.py @@ -19,7 +19,7 @@ import pytest from llama_index.core import Document -from llama_index.core.schema import TextNode +from llama_index.core.schema import Node, TextNode from lightspeed_rag_content import document_processor from tests.conftest import RagMockEmbedding @@ -72,17 +72,107 @@ def test__got_whitespace_true(self, doc_processor): result = doc_processor["processor"].db._got_whitespace(text) assert result + def test__valid_text_node(self, doc_processor): + """Test that valid text node checks for got whitespace on non markdown chunker.""" + db = doc_processor["processor"].db + db.config.doc_type = "plain" + + with ( + mock.patch.object(db, "_got_whitespace") as mock_got_ws, + mock.patch.object(db, "_got_non_headers") as mock_got_nh, + ): + text = "NoWhitespace" + res = db._valid_text_node(text) + assert res is mock_got_ws.return_value + mock_got_ws.assert_called_once_with(text) + mock_got_nh.assert_not_called() + + def test__valid_text_node_markdown_non_headers_true(self, doc_processor): + """Test that text node is valid when markdown has non header content.""" + db = doc_processor["processor"].db + db.config.doc_type = "markdown" + + with ( + mock.patch.object(db, "_got_non_headers", return_value=True) as mock_got_nh, + mock.patch.object(db, "_got_whitespace") as mock_got_ws, + ): + text = "# Header\nActual content here" + res = db._valid_text_node(text) + assert res is mock_got_ws.return_value + mock_got_nh.assert_called_once_with(text) + mock_got_ws.assert_called_once_with(text) + + def test__valid_text_node_markdown_non_headers_false(self, doc_processor): + """Test that text node is invalid when markdown only has headers.""" + db = doc_processor["processor"].db + db.config.doc_type = "markdown" + + with ( + mock.patch.object( + db, "_got_non_headers", return_value=False + ) as mock_got_nh, + mock.patch.object(db, "_got_whitespace") as mock_got_ws, + ): + text = "# Header1\n# Header2\n\n" + res = db._valid_text_node(text) + assert res is False + mock_got_nh.assert_called_once_with(text) + mock_got_ws.assert_not_called() + + @pytest.mark.parametrize( + "text", + [ + "# Header\nSome content", # Header followed by content + "# Header1\n# Header2\nAlso here", # Multiple headers, then content + "No headers, just content", # No headers, just content + "# H\n# H2\n\tThis is non-header", # Tabs and spaces before content + ], + ) + def test__got_non_headers_with_content(self, doc_processor, text): + """Test we detect when markdown has something beside headers.""" + db = doc_processor["processor"].db + assert db._got_non_headers(text) is True + + @pytest.mark.parametrize( + "text", + [ + "# Only header", + "# Another header\n## Subheader", + "# Header with space \n", + "## \n#", + "#Header1\n#Header2\n#Header3", + "# ", # header with whitespace + " # Header with leading space", + " \n\t \n", # only whitespaces + ], + ) + def test__got_non_headers_only_headers(self, doc_processor, text): + """Test we detect when markdown only has headers.""" + db = doc_processor["processor"].db + assert db._got_non_headers(text) is False + def test__filter_out_invalid_nodes(self, doc_processor): - """Test that _filter_out_invalid_nodes only returns nodes with whitespace.""" - fake_node_0 = mock.Mock(spec=TextNode) - fake_node_1 = mock.Mock(spec=TextNode) - fake_node_0.text = "Got whitespace" - fake_node_1.text = "NoWhitespace" - - result = doc_processor["processor"].db._filter_out_invalid_nodes( - [fake_node_0, fake_node_1] - ) - assert result == [fake_node_0] + """Test that _filter_out_invalid_nodes checks for validity of text nodes.""" + fake_text_node_0 = mock.Mock(spec=TextNode, text="fake_text_node_0") + fake_text_node_1 = mock.Mock(spec=TextNode, text="fake_text_node_1") + fake_node_2 = mock.Mock(spec=Node) + + db = doc_processor["processor"].db + with mock.patch.object( + db, "_valid_text_node", side_effect=(True, False) + ) as valid_tn_mock: + result = db._filter_out_invalid_nodes( + [fake_text_node_0, fake_text_node_1, fake_node_2] + ) + + expected_calls = [ + mock.call(fake_text_node_0.text), + mock.call(fake_text_node_1.text), + ] + valid_tn_mock.assert_has_calls(expected_calls) + assert len(expected_calls) == valid_tn_mock.call_count + + assert result == [fake_text_node_0] def test__save_index(self, mocker, doc_processor): """Test that _save_index sets index ID and persists the storage context."""