Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions src/lightspeed_rag_content/document_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,35 @@ def _got_whitespace(text: str) -> bool:
return True
return False

@classmethod
def _filter_out_invalid_nodes(cls, nodes: list[Any]) -> list[TextNode]:
@staticmethod
def _got_non_headers(text: str) -> bool:
"""Check if text has content besides markdown headers."""
for line in text.splitlines():
line = line.strip()
if line and not line.startswith("#"):
return True
return False

def _valid_text_node(self, text: str) -> bool:
"""Check if text node is valid: has whitespace and has content."""
if self.config.doc_type == "markdown" and not self._got_non_headers(text):
return False
return self._got_whitespace(text)

def _filter_out_invalid_nodes(self, nodes: list[Any]) -> list[TextNode]:
"""Filter out invalid nodes."""
good_nodes = []
for node in nodes:
if isinstance(node, TextNode) and cls._got_whitespace(node.text):
if isinstance(node, TextNode) and self._valid_text_node(node.text):
# Exclude given metadata during embedding
good_nodes.append(node)
else:
LOG.debug("Skipping node without whitespace: %s", repr(node))
LOG.debug("Skipping invalid node: %s", repr(node))
return good_nodes

@classmethod
def _split_and_filter(cls, docs: list[Document]) -> list[TextNode]:
def _split_and_filter(self, docs: list[Document]) -> list[TextNode]:
nodes = Settings.text_splitter.get_nodes_from_documents(docs)
valid_nodes = cls._filter_out_invalid_nodes(nodes)
valid_nodes = self._filter_out_invalid_nodes(nodes)
return valid_nodes


Expand Down
112 changes: 101 additions & 11 deletions tests/test_document_processor_llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import pytest
from llama_index.core import Document
from llama_index.core.schema import TextNode
from llama_index.core.schema import Node, TextNode

from lightspeed_rag_content import document_processor
from tests.conftest import RagMockEmbedding
Expand Down Expand Up @@ -72,17 +72,107 @@ def test__got_whitespace_true(self, doc_processor):
result = doc_processor["processor"].db._got_whitespace(text)
assert result

def test__valid_text_node(self, doc_processor):
"""Test that valid text node checks for got whitespace on non markdown chunker."""
db = doc_processor["processor"].db
db.config.doc_type = "plain"

with (
mock.patch.object(db, "_got_whitespace") as mock_got_ws,
mock.patch.object(db, "_got_non_headers") as mock_got_nh,
):
text = "NoWhitespace"
res = db._valid_text_node(text)
assert res is mock_got_ws.return_value
mock_got_ws.assert_called_once_with(text)
mock_got_nh.assert_not_called()

def test__valid_text_node_markdown_non_headers_true(self, doc_processor):
"""Test that text node is valid when markdown has non header content."""
db = doc_processor["processor"].db
db.config.doc_type = "markdown"

with (
mock.patch.object(db, "_got_non_headers", return_value=True) as mock_got_nh,
mock.patch.object(db, "_got_whitespace") as mock_got_ws,
):
text = "# Header\nActual content here"
res = db._valid_text_node(text)
assert res is mock_got_ws.return_value
mock_got_nh.assert_called_once_with(text)
mock_got_ws.assert_called_once_with(text)

def test__valid_text_node_markdown_non_headers_false(self, doc_processor):
"""Test that text node is invalid when markdown only has headers."""
db = doc_processor["processor"].db
db.config.doc_type = "markdown"

with (
mock.patch.object(
db, "_got_non_headers", return_value=False
) as mock_got_nh,
mock.patch.object(db, "_got_whitespace") as mock_got_ws,
):
text = "# Header1\n# Header2\n\n"
res = db._valid_text_node(text)
assert res is False
mock_got_nh.assert_called_once_with(text)
mock_got_ws.assert_not_called()

@pytest.mark.parametrize(
"text",
[
"# Header\nSome content", # Header followed by content
"# Header1\n# Header2\nAlso here", # Multiple headers, then content
"No headers, just content", # No headers, just content
"# H\n# H2\n\tThis is non-header", # Tabs and spaces before content
],
)
def test__got_non_headers_with_content(self, doc_processor, text):
"""Test we detect when markdown has something beside headers."""
db = doc_processor["processor"].db
assert db._got_non_headers(text) is True

@pytest.mark.parametrize(
"text",
[
"# Only header",
"# Another header\n## Subheader",
"# Header with space \n",
"## \n#",
"#Header1\n#Header2\n#Header3",
"# ", # header with whitespace
" # Header with leading space",
" \n\t \n", # only whitespaces
],
)
def test__got_non_headers_only_headers(self, doc_processor, text):
"""Test we detect when markdown only has headers."""
db = doc_processor["processor"].db
assert db._got_non_headers(text) is False

def test__filter_out_invalid_nodes(self, doc_processor):
"""Test that _filter_out_invalid_nodes only returns nodes with whitespace."""
fake_node_0 = mock.Mock(spec=TextNode)
fake_node_1 = mock.Mock(spec=TextNode)
fake_node_0.text = "Got whitespace"
fake_node_1.text = "NoWhitespace"

result = doc_processor["processor"].db._filter_out_invalid_nodes(
[fake_node_0, fake_node_1]
)
assert result == [fake_node_0]
"""Test that _filter_out_invalid_nodes checks for validity of text nodes."""
fake_text_node_0 = mock.Mock(spec=TextNode, text="fake_text_node_0")
fake_text_node_1 = mock.Mock(spec=TextNode, text="fake_text_node_1")
fake_node_2 = mock.Mock(spec=Node)

db = doc_processor["processor"].db
with mock.patch.object(
db, "_valid_text_node", side_effect=(True, False)
) as valid_tn_mock:
result = db._filter_out_invalid_nodes(
[fake_text_node_0, fake_text_node_1, fake_node_2]
)

expected_calls = [
mock.call(fake_text_node_0.text),
mock.call(fake_text_node_1.text),
]
valid_tn_mock.assert_has_calls(expected_calls)
assert len(expected_calls) == valid_tn_mock.call_count

assert result == [fake_text_node_0]

def test__save_index(self, mocker, doc_processor):
"""Test that _save_index sets index ID and persists the storage context."""
Expand Down
Loading