diff --git a/README.md b/README.md index e8ec710c..60c838ba 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,8 @@ different: embeddings_model_dir=args.model_dir, num_workers=args.workers, vector_store_type=args.vector_store_type, + exclude_embed_metadata=args.exclude_metadata, + exclude_llm_metadata=args.exclude_llm_metadata, ) # Load and embed the documents, this method can be called multiple times diff --git a/src/lightspeed_rag_content/document_processor.py b/src/lightspeed_rag_content/document_processor.py index 6ff00ccd..640c57ba 100644 --- a/src/lightspeed_rag_content/document_processor.py +++ b/src/lightspeed_rag_content/document_processor.py @@ -19,6 +19,7 @@ import os import tempfile import time +import types from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union @@ -76,6 +77,8 @@ def __init__(self, config: _Config): if config.doc_type == "markdown": Settings.node_parser = MarkdownNodeParser() + self.original_model_copy = TextNode.model_copy + @staticmethod def _got_whitespace(text: str) -> bool: """Indicate if the parameter string contains whitespace.""" @@ -102,6 +105,15 @@ def _split_and_filter(cls, docs: list[Document]) -> list[TextNode]: valid_nodes = cls._filter_out_invalid_nodes(nodes) return valid_nodes + @staticmethod + def _remove_metadata( + metadata: dict[str, Any], remove: Optional[list[str]] + ) -> dict[str, Any]: + """Return a metadata dictionary without some keys.""" + if not remove: + return metadata.copy() + return {key: value for key, value in metadata.items() if key not in remove} + class _LlamaIndexDB(_BaseDB): def __init__(self, config: _Config): @@ -141,6 +153,7 @@ def add_docs(self, docs: list[Document]) -> None: """Add documents to the list of documents to save.""" valid_nodes = self._split_and_filter(docs) self._good_nodes.extend(valid_nodes) + self.exclude_metadata(self._good_nodes) def save( self, index: str, output_dir: str, embedded_files: int, exec_time: int @@ -183,8 +196,56 @@ def 
_save_metadata( ) as file: file.write(json.dumps(metadata)) + def _model_copy_excluding_llm_metadata( + self, node: TextNode, *args: Any, **kwargs: Any + ) -> TextNode: + """Replace node's model_copy to remove metadata.""" + res = self.original_model_copy(node, *args, **kwargs) + res.metadata = self._remove_metadata( + res.metadata, node.excluded_llm_metadata_keys + ) + return res + + def exclude_metadata(self, documents: list[TextNode]) -> None: + """Exclude metadata from documents. + + By default llama-index already excludes the following keys: + "file_name", "file_type", "file_size", "creation_date", + "last_modified_date", and "last_accessed_date". + + This method adds more metadata keys to be excluded from embedding + calculations and from the metadata returned to the LLM. + """ + for doc in documents: + doc.excluded_embed_metadata_keys = self.config.exclude_embed_metadata + + # Llama index stores all the metadata and expects that on retrieval + # the `doc.excluded_llm_metadata_keys` is used to manually fix the + # dict, or call `doc.get_content(metadata_mode=MetadataMode.LLM)` + # to get a string representation of the contents + LLM_metadata + # We don't want that, we only want to store the LLM metadata, so + # we replace the model_copy that happens *after* the embedding has + # already happened[1] when adding nodes to index[2], allowing + # different embedding and LLM metadata. 
+ # [1]: https://github.com/run-llama/llama_index/blob/6bf76bf1ca5c70e479a3adb3327cf896cbcd869f/llama-index-core/llama_index/core/indices/vector_store/base.py#L145 # pylint: disable=line-too-long # noqa: E501 + # [2]: https://github.com/run-llama/llama_index/blob/6bf76bf1ca5c70e479a3adb3327cf896cbcd869f/llama-index-core/llama_index/core/indices/vector_store/base.py#L231 # pylint: disable=line-too-long # noqa: E501 + doc.excluded_llm_metadata_keys = self.config.exclude_llm_metadata + # Override `model_copy` so we don't store excluded metadata, cannot + # use `doc.model_copy = ` or `setattr(doc, "model_copy" ...)` + # because pydantic has custom setattr code that rejects it. + object.__setattr__( + doc, + "model_copy", + types.MethodType(self._model_copy_excluding_llm_metadata, doc), + ) + class _LlamaStackDB(_BaseDB): + # Templates for manual creation of embeddings + EMBEDDING_METADATA_SEPARATOR = "\n" + EMBEDDING_METADATA_TEMPLATE = "{key}: {value}" + EMBEDDING_TEMPLATE = "{metadata_str}\n\n{content}" + # Lllama-stack faiss vector-db uses IndexFlatL2 (it's hardcoded for now) TEMPLATE = """version: 2 image_name: ollama @@ -349,16 +410,26 @@ def add_docs(self, docs: list[Document]) -> None: "chunk_id": node.id_, "source": node.metadata.get("docs_url", node.metadata["title"]), } + embed_metadata = self._remove_metadata( + node.metadata, self.config.exclude_embed_metadata + ) + llm_metadata = self._remove_metadata( + node.metadata, self.config.exclude_llm_metadata + ) + # Add document_id to node's metadata because llama-stack needs it + llm_metadata["document_id"] = node.ref_doc_id self.documents.append( { "content": node.text, "mime_type": "text/plain", - "metadata": node.metadata, + "metadata": llm_metadata, "chunk_metadata": chunk_metadata, + "embed_metadata": embed_metadata, # internal to this script } ) else: + LOG.warning("Llama-stack automatic mode doesn't use metadata for Embedding") self.documents.extend( self.document_class( document_id=doc.doc_id, @@ -369,6 
+440,41 @@ def add_docs(self, docs: list[Document]) -> None: for doc in docs ) + def _calculate_embeddings( + self, client: Any, documents: list[dict[str, Any]] + ) -> None: + """Calculate the embeddings with metadata. + + This method is necessary because llama-stack doesn't use the metadata + to calculate embeddings, so we need to calculate the embeddings + ourselves. + + In the `add_docs` method the `embed_metadata` was added as a key just + to have it here, but we will use and remove it now because llama-stack + doesn't accept this parameter. + + Embeddings are stored in the `embedding` key that llama-stack expects. + """ + for doc in documents: + # Build a str with the metadata + embed_metadata = doc.pop("embed_metadata") + metadata_str = self.EMBEDDING_METADATA_SEPARATOR.join( + self.EMBEDDING_METADATA_TEMPLATE.format(key=key, value=str(value)) + for key, value in embed_metadata.items() + ) + + # We'll embed the chunk contents with the metadata + data = self.EMBEDDING_TEMPLATE.format( + content=doc["content"], metadata_str=metadata_str + ) + + embedding = client.inference.embeddings( + model_id=self.config.model_name, + contents=[data], + ) + + doc["embedding"] = embedding.embeddings[0] + def save( self, index: str, @@ -387,6 +493,7 @@ def save( client = self._start_llama_stack(cfg_file) try: if self.config.manual_chunking: + self._calculate_embeddings(client, self.documents) client.vector_io.insert(vector_db_id=index, chunks=self.documents) else: client.tool_runtime.rag_tool.insert( @@ -413,6 +520,8 @@ def __init__( table_name: Optional[str] = None, manual_chunking: bool = True, doc_type: str = "text", + exclude_embed_metadata: Optional[list[str]] = None, + exclude_llm_metadata: Optional[list[str]] = None, ): """Initialize instance.""" if vector_store_type == "postgres" and not table_name: @@ -429,6 +538,8 @@ def __init__( table_name=table_name, manual_chunking=manual_chunking, doc_type=doc_type, + exclude_embed_metadata=exclude_embed_metadata, + 
exclude_llm_metadata=exclude_llm_metadata, ) self._check_config(self.config) diff --git a/src/lightspeed_rag_content/utils.py b/src/lightspeed_rag_content/utils.py index e0b7feb8..93cce582 100644 --- a/src/lightspeed_rag_content/utils.py +++ b/src/lightspeed_rag_content/utils.py @@ -15,6 +15,15 @@ """Utilities for rag-content modules.""" import argparse +DEFAULT_METADATA_EXCLUSION = [ + "file_name", + "file_type", + "file_size", + "creation_date", + "last_modified_date", + "last_accessed_date", +] + def get_common_arg_parser() -> argparse.ArgumentParser: """Provide common CLI arguments to document processing scripts.""" @@ -39,9 +48,16 @@ def get_common_arg_parser() -> argparse.ArgumentParser: "-em", "--exclude-metadata", nargs="+", - default=None, + default=DEFAULT_METADATA_EXCLUSION, help="Metadata to be excluded during embedding", ) + parser.add_argument( + "-elm", + "--exclude-llm-metadata", + nargs="+", + default=DEFAULT_METADATA_EXCLUSION, + help="Metadata to be excluded on the DB", + ) parser.add_argument("-o", "--output", help="Vector DB output folder") parser.add_argument("-i", "--index", help="Product index") parser.add_argument( diff --git a/tests/test_document_processor.py b/tests/test_document_processor.py index 5567013c..d0a4b131 100644 --- a/tests/test_document_processor.py +++ b/tests/test_document_processor.py @@ -81,6 +81,8 @@ def test_init_default(self, mock_processor): manual_chunking=True, table_name=None, vector_store_type="faiss", + exclude_embed_metadata=None, + exclude_llm_metadata=None, ) assert expected_params == doc_processor.config._Config__attributes assert doc_processor._num_embedded_files == 0 @@ -106,6 +108,8 @@ def test_init_llama_index(self, vector_store_type, mock_processor): manual_chunking=True, table_name=None, vector_store_type=vector_store_type, + exclude_embed_metadata=None, + exclude_llm_metadata=None, ) if vector_store_type == "postgres": params["table_name"] = "table_name" @@ -135,6 +139,8 @@ def
test_init_llama_stack(self, vector_store_type, mock_processor): embedding_dimension=None, # Not calculated because class is mocked manual_chunking=True, table_name=None, + exclude_embed_metadata=None, + exclude_llm_metadata=None, ) assert params == doc_processor.config._Config__attributes assert doc_processor._num_embedded_files == 0 @@ -196,3 +202,44 @@ def test_save(self, mock_processor): doc_processor = document_processor.DocumentProcessor(**mock_processor["params"]) doc_processor.save(mock.sentinel.index, mock.sentinel.output_dir) + + +class TestBaseDB: + """Test cases for the _BaseDB class in document_processor module.""" + + def test__remove_metadata(self): + """Test that _remove_metadata removes specified keys from metadata dictionary.""" + metadata = { + "file_path": "/path/to/file", + "url": "https://example.com", + "title": "Test Document", + "file_name": "test.txt", + } + keys_to_remove = ["file_path", "file_name"] + + result = document_processor._BaseDB._remove_metadata(metadata, keys_to_remove) + + assert "file_path" not in result + assert "file_name" not in result + assert "url" in result + assert "title" in result + assert result["url"] == "https://example.com" + assert result["title"] == "Test Document" + + def test__remove_metadata_empty_list(self): + """Test that _remove_metadata returns original metadata when no keys to remove.""" + metadata = {"key1": "value1", "key2": "value2"} + + result = document_processor._BaseDB._remove_metadata(metadata, []) + + assert result == metadata + + def test__remove_metadata_nonexistent_keys(self): + """Test that _remove_metadata handles nonexistent keys gracefully.""" + metadata = {"key1": "value1", "key2": "value2"} + + result = document_processor._BaseDB._remove_metadata( + metadata, ["nonexistent_key"] + ) + + assert result == metadata diff --git a/tests/test_document_processor_llama_index.py b/tests/test_document_processor_llama_index.py index 846c3f12..943e315d 100644 --- 
a/tests/test_document_processor_llama_index.py +++ b/tests/test_document_processor_llama_index.py @@ -243,3 +243,185 @@ def test_invalid_vector_store_type(self, doc_processor): doc_processor["num_workers"], "nonexisting", ) + + def test_exclude_metadata_sets_keys(self, mocker): + """Test that exclude_metadata sets excluded_embed_metadata_keys and excluded_llm_metadata_keys.""" + mocker.patch.object( + document_processor, "HuggingFaceEmbedding", new=RagMockEmbedding + ) + mocker.patch("os.path.exists", return_value=True) + + exclude_embed = ["file_path", "url"] + exclude_llm = ["file_path", "url_reachable"] + + processor = document_processor.DocumentProcessor( + 380, + 0, + "sentence-transformers/all-mpnet-base-v2", + Path("./embeddings_model"), + 10, + exclude_embed_metadata=exclude_embed, + exclude_llm_metadata=exclude_llm, + ) + + # Create mock nodes + node1 = mock.Mock(spec=TextNode) + node1.excluded_embed_metadata_keys = None + node1.excluded_llm_metadata_keys = None + node1.metadata = {"file_path": "/path", "url": "https://example.com"} + + node2 = mock.Mock(spec=TextNode) + node2.excluded_embed_metadata_keys = None + node2.excluded_llm_metadata_keys = None + node2.metadata = {"file_path": "/path2", "title": "Test"} + + nodes = [node1, node2] + + processor.db.exclude_metadata(nodes) + + assert node1.excluded_embed_metadata_keys == exclude_embed + assert node1.excluded_llm_metadata_keys == exclude_llm + assert node2.excluded_embed_metadata_keys == exclude_embed + assert node2.excluded_llm_metadata_keys == exclude_llm + + def test_exclude_metadata_overrides_model_copy(self, mocker): + """Test that exclude_metadata overrides model_copy method on nodes.""" + mocker.patch.object( + document_processor, "HuggingFaceEmbedding", new=RagMockEmbedding + ) + mocker.patch("os.path.exists", return_value=True) + + exclude_embed = ["file_path"] + exclude_llm = ["file_path", "url_reachable"] + + processor = document_processor.DocumentProcessor( + 380, + 0, + 
"sentence-transformers/all-mpnet-base-v2", + Path("./embeddings_model"), + 10, + exclude_embed_metadata=exclude_embed, + exclude_llm_metadata=exclude_llm, + ) + + # Create a real TextNode to test model_copy override + node = TextNode( + text="Test content", + metadata={ + "file_path": "/path/to/file", + "url_reachable": True, + "title": "Test Document", + }, + ) + + original_model_copy = node.model_copy + + processor.db.exclude_metadata([node]) + + # Verify model_copy was overridden + assert node.model_copy != original_model_copy + # Verify it's a bound method + assert hasattr(node.model_copy, "__self__") + + def test__model_copy_excluding_llm_metadata(self, mocker): + """Test that _model_copy_excluding_llm_metadata removes excluded metadata.""" + mocker.patch.object( + document_processor, "HuggingFaceEmbedding", new=RagMockEmbedding + ) + mocker.patch("os.path.exists", return_value=True) + + exclude_llm = ["file_path", "url_reachable"] + + processor = document_processor.DocumentProcessor( + 380, + 0, + "sentence-transformers/all-mpnet-base-v2", + Path("./embeddings_model"), + 10, + exclude_llm_metadata=exclude_llm, + ) + + node = TextNode( + text="Test content", + metadata={ + "file_path": "/path/to/file", + "url_reachable": True, + "title": "Test Document", + "url": "https://example.com", + }, + ) + node.excluded_llm_metadata_keys = exclude_llm + + # Call the method directly + result = processor.db._model_copy_excluding_llm_metadata(node) + + # Verify excluded keys are removed + assert "file_path" not in result.metadata + assert "url_reachable" not in result.metadata + # Verify non-excluded keys remain + assert "title" in result.metadata + assert "url" in result.metadata + assert result.metadata["title"] == "Test Document" + assert result.metadata["url"] == "https://example.com" + # Verify text is preserved + assert result.text == "Test content" + + def test_add_docs_calls_exclude_metadata(self, mocker, doc_processor): + """Test that add_docs calls 
exclude_metadata on nodes.""" + mock_exclude = mocker.patch.object( + doc_processor["processor"].db, "exclude_metadata" + ) + mock_split = mocker.patch.object( + doc_processor["processor"].db, + "_split_and_filter", + return_value=[mock.Mock(spec=TextNode), mock.Mock(spec=TextNode)], + ) + + docs = [Document(text="doc1"), Document(text="doc2")] + doc_processor["processor"].db.add_docs(docs) + + mock_split.assert_called_once_with(docs) + # exclude_metadata should be called with the good nodes + mock_exclude.assert_called_once() + assert len(mock_exclude.call_args[0][0]) == 2 + + def test_exclude_metadata_defaults_to_none(self, mocker): + """Test that exclude_metadata parameters default to None when not provided.""" + mocker.patch.object( + document_processor, "HuggingFaceEmbedding", new=RagMockEmbedding + ) + mocker.patch("os.path.exists", return_value=True) + + processor = document_processor.DocumentProcessor( + 380, + 0, + "sentence-transformers/all-mpnet-base-v2", + Path("./embeddings_model"), + 10, + ) + + assert processor.config.exclude_embed_metadata is None + assert processor.config.exclude_llm_metadata is None + + def test_exclude_metadata_with_custom_values(self, mocker): + """Test that exclude_metadata parameters are stored correctly when provided.""" + mocker.patch.object( + document_processor, "HuggingFaceEmbedding", new=RagMockEmbedding + ) + mocker.patch("os.path.exists", return_value=True) + + exclude_embed = ["custom_key1", "custom_key2"] + exclude_llm = ["custom_key3"] + + processor = document_processor.DocumentProcessor( + 380, + 0, + "sentence-transformers/all-mpnet-base-v2", + Path("./embeddings_model"), + 10, + exclude_embed_metadata=exclude_embed, + exclude_llm_metadata=exclude_llm, + ) + + assert processor.config.exclude_embed_metadata == exclude_embed + assert processor.config.exclude_llm_metadata == exclude_llm diff --git a/tests/test_document_processor_llama_stack.py b/tests/test_document_processor_llama_stack.py index 0c5da8bb..8c03d572 
100644 --- a/tests/test_document_processor_llama_stack.py +++ b/tests/test_document_processor_llama_stack.py @@ -83,11 +83,10 @@ def llama_stack_processor(mocker): st = mocker.patch.object(document_processor, "SentenceTransformer") st.return_value.get_sentence_embedding_dimension.return_value = 768 mocker.patch("os.path.exists", return_value=False) - - mocker.patch.object( - document_processor.Settings.text_splitter.__class__, - "get_nodes_from_documents", - ) + # Mock tiktoken to prevent network calls during initialization + mock_encoding = mocker.Mock() + mocker.patch("tiktoken.get_encoding", return_value=mock_encoding) + mocker.patch("tiktoken.encoding_for_model", return_value=mock_encoding) model_name = "sentence-transformers/all-mpnet-base-v2" config = document_processor._Config( @@ -99,6 +98,8 @@ def llama_stack_processor(mocker): embedding_dimension=None, manual_chunking=True, doc_type="text", + exclude_embed_metadata=[], + exclude_llm_metadata=[], ) return {"config": config, "model_name": model_name} @@ -131,7 +132,14 @@ def test_init_model_path(self, mocker, llama_stack_processor): document_processor.tempfile, "TemporaryDirectory" ) temp_dir.return_value.name = "temp_dir" - exists_mock = mocker.patch("os.path.exists", return_value=True) + + # Mock exists to return True for embeddings_model_dir, False for tiktoken cache + def exists_side_effect(path): + if "embeddings_model" in str(path): + return True + return False + + exists_mock = mocker.patch("os.path.exists", side_effect=exists_side_effect) realpath_mock = mocker.patch("os.path.realpath") config = llama_stack_processor["config"] @@ -139,7 +147,10 @@ def test_init_model_path(self, mocker, llama_stack_processor): doc = document_processor._LlamaStackDB(config) assert doc.config == config - exists_mock.assert_called_once_with(config.embeddings_model_dir) + # Check that exists was called with embeddings_model_dir + assert any( + "embeddings_model" in str(call) for call in exists_mock.call_args_list + ) 
realpath_mock.assert_called_once_with(config.embeddings_model_dir) assert doc.model_name_or_dir == realpath_mock.return_value assert doc.config.embedding_dimension == 768 @@ -307,6 +318,11 @@ def test_add_docs_manual_chunking(self, mocker, llama_stack_processor): "chunk_id": 3, "source": "https://redhat.com/1", }, + "embed_metadata": { + "document_id": 1, + "title": "title1", + "docs_url": "https://redhat.com/1", + }, }, { "content": "2", @@ -321,6 +337,11 @@ def test_add_docs_manual_chunking(self, mocker, llama_stack_processor): "chunk_id": 6, "source": "https://redhat.com/2", }, + "embed_metadata": { + "document_id": 2, + "title": "title2", + "docs_url": "https://redhat.com/2", + }, }, ] assert doc.documents == expect @@ -362,11 +383,21 @@ def test_add_docs_auto_chunking(self, mocker, llama_stack_processor): def _test_save(self, mocker, config): """Helper function to set up and verify save functionality.""" doc = document_processor._LlamaStackDB(config) - doc.documents = mock.sentinel.documents + doc.documents = [ + { + "content": "test", + "mime_type": "text/plain", + "embed_metadata": {"title": "test"}, + "metadata": {"title": "test"}, + "chunk_metadata": {"document_id": 1, "chunk_id": 1}, + } + ] write_cfg = mocker.patch.object(doc, "write_yaml_config") client = mocker.patch.object(doc, "_start_llama_stack") - client.inspect.version.return_value = "0.2.15" + mock_embeddings_response = mocker.Mock() + mock_embeddings_response.embeddings = [[0.1] * 768] + client.return_value.inference.embeddings.return_value = mock_embeddings_response realpath = mocker.patch( "os.path.realpath", return_value="/cwd/out_dir/vector_store.db" ) @@ -387,17 +418,214 @@ def _test_save(self, mocker, config): def test_save_manual_chunking(self, mocker, llama_stack_processor): """Test saving documents with manual chunking workflow.""" client = self._test_save(mocker, llama_stack_processor["config"]) - client.vector_io.insert.assert_called_once_with( - vector_db_id=mock.sentinel.index, 
chunks=mock.sentinel.documents - ) + client.vector_io.insert.assert_called_once() + call_args = client.vector_io.insert.call_args + assert call_args.kwargs["vector_db_id"] == mock.sentinel.index + assert "chunks" in call_args.kwargs + assert len(call_args.kwargs["chunks"]) == 1 def test_save_auto_chunking(self, mocker, llama_stack_processor): """Test saving documents with automatic chunking workflow.""" config = llama_stack_processor["config"] config.manual_chunking = False client = self._test_save(mocker, config) - client.tool_runtime.rag_tool.insert.assert_called_once_with( - documents=mock.sentinel.documents, - vector_db_id=mock.sentinel.index, - chunk_size_in_tokens=380, - ) + client.tool_runtime.rag_tool.insert.assert_called_once() + call_args = client.tool_runtime.rag_tool.insert.call_args + assert call_args.kwargs["vector_db_id"] == mock.sentinel.index + assert "documents" in call_args.kwargs + assert len(call_args.kwargs["documents"]) == 1 + assert call_args.kwargs["chunk_size_in_tokens"] == 380 + + def test_calculate_embeddings(self, mocker, llama_stack_processor): + """Test _calculate_embeddings method formats metadata and calculates embeddings.""" + doc = document_processor._LlamaStackDB(llama_stack_processor["config"]) + client = mocker.Mock() + mock_embedding_response = mocker.Mock() + mock_embedding_response.embeddings = [[0.5] * 768] + client.inference.embeddings.return_value = mock_embedding_response + + documents = [ + { + "content": "test content", + "embed_metadata": { + "title": "Test Title", + "docs_url": "https://example.com", + }, + } + ] + + doc._calculate_embeddings(client, documents) + + # Verify embed_metadata was removed + assert "embed_metadata" not in documents[0] + # Verify embedding was added + assert "embedding" in documents[0] + assert documents[0]["embedding"] == [0.5] * 768 + # Verify client.inference.embeddings was called with correct data + client.inference.embeddings.assert_called_once() + call_args = 
client.inference.embeddings.call_args + assert call_args.kwargs["model_id"] == llama_stack_processor["model_name"] + assert len(call_args.kwargs["contents"]) == 1 + # Verify the formatted data includes metadata and content + formatted_data = call_args.kwargs["contents"][0] + assert "title: Test Title" in formatted_data + assert "docs_url: https://example.com" in formatted_data + assert "test content" in formatted_data + assert formatted_data.startswith("title: Test Title") + + def test_calculate_embeddings_empty_metadata(self, mocker, llama_stack_processor): + """Test _calculate_embeddings with empty metadata.""" + doc = document_processor._LlamaStackDB(llama_stack_processor["config"]) + client = mocker.Mock() + mock_embedding_response = mocker.Mock() + mock_embedding_response.embeddings = [[0.3] * 768] + client.inference.embeddings.return_value = mock_embedding_response + + documents = [ + { + "content": "test content", + "embed_metadata": {}, + } + ] + + doc._calculate_embeddings(client, documents) + + # Verify embedding was added + assert "embedding" in documents[0] + # Verify the formatted data only contains content (empty metadata_str) + call_args = client.inference.embeddings.call_args + formatted_data = call_args.kwargs["contents"][0] + assert formatted_data == "\n\ntest content" + + def test_add_docs_exclude_embed_metadata(self, mocker, llama_stack_processor): + """Test that exclude_embed_metadata removes keys from embed_metadata.""" + config = llama_stack_processor["config"] + config.exclude_embed_metadata = ["docs_url", "title"] + doc = document_processor._LlamaStackDB(config) + nodes = [ + mocker.Mock( + spec=TextNode, + ref_doc_id=1, + id_=3, + text="test", + metadata={ + "title": "Test Title", + "docs_url": "https://example.com", + "author": "Test Author", + }, + ) + ] + mocker.patch.object(doc, "_split_and_filter", return_value=nodes) + + doc.add_docs([mocker.Mock()]) + + assert len(doc.documents) == 1 + # embed_metadata should not contain excluded 
keys + assert "docs_url" not in doc.documents[0]["embed_metadata"] + assert "title" not in doc.documents[0]["embed_metadata"] + assert "author" in doc.documents[0]["embed_metadata"] + # llm_metadata should still contain all keys (except those in exclude_llm_metadata) + assert "title" in doc.documents[0]["metadata"] + assert "docs_url" in doc.documents[0]["metadata"] + assert "author" in doc.documents[0]["metadata"] + + def test_add_docs_exclude_llm_metadata(self, mocker, llama_stack_processor): + """Test that exclude_llm_metadata removes keys from llm_metadata.""" + config = llama_stack_processor["config"] + config.exclude_llm_metadata = ["docs_url", "author"] + doc = document_processor._LlamaStackDB(config) + nodes = [ + mocker.Mock( + spec=TextNode, + ref_doc_id=1, + id_=3, + text="test", + metadata={ + "title": "Test Title", + "docs_url": "https://example.com", + "author": "Test Author", + }, + ) + ] + mocker.patch.object(doc, "_split_and_filter", return_value=nodes) + + doc.add_docs([mocker.Mock()]) + + assert len(doc.documents) == 1 + # llm_metadata should not contain excluded keys + assert "docs_url" not in doc.documents[0]["metadata"] + assert "author" not in doc.documents[0]["metadata"] + assert "title" in doc.documents[0]["metadata"] + # embed_metadata should still contain all keys (except those in exclude_embed_metadata) + assert "title" in doc.documents[0]["embed_metadata"] + assert "docs_url" in doc.documents[0]["embed_metadata"] + assert "author" in doc.documents[0]["embed_metadata"] + + def test_add_docs_exclude_both_metadata(self, mocker, llama_stack_processor): + """Test that both exclude_embed_metadata and exclude_llm_metadata work together.""" + config = llama_stack_processor["config"] + config.exclude_embed_metadata = ["docs_url"] + config.exclude_llm_metadata = ["author"] + doc = document_processor._LlamaStackDB(config) + nodes = [ + mocker.Mock( + spec=TextNode, + ref_doc_id=1, + id_=3, + text="test", + metadata={ + "title": "Test Title", + 
"docs_url": "https://example.com", + "author": "Test Author", + }, + ) + ] + mocker.patch.object(doc, "_split_and_filter", return_value=nodes) + + doc.add_docs([mocker.Mock()]) + + assert len(doc.documents) == 1 + # embed_metadata should exclude docs_url + assert "docs_url" not in doc.documents[0]["embed_metadata"] + assert "title" in doc.documents[0]["embed_metadata"] + assert "author" in doc.documents[0]["embed_metadata"] + # llm_metadata should exclude author + assert "author" not in doc.documents[0]["metadata"] + assert "title" in doc.documents[0]["metadata"] + assert "docs_url" in doc.documents[0]["metadata"] + + def test_calculate_embeddings_multiple_documents( + self, mocker, llama_stack_processor + ): + """Test _calculate_embeddings with multiple documents.""" + doc = document_processor._LlamaStackDB(llama_stack_processor["config"]) + client = mocker.Mock() + mock_embedding_response = mocker.Mock() + # Return different embeddings for each call + mock_embedding_response.embeddings = [[0.1] * 768] + client.inference.embeddings.side_effect = [ + mocker.Mock(embeddings=[[0.1] * 768]), + mocker.Mock(embeddings=[[0.2] * 768]), + ] + + documents = [ + { + "content": "content 1", + "embed_metadata": {"title": "Title 1"}, + }, + { + "content": "content 2", + "embed_metadata": {"title": "Title 2"}, + }, + ] + + doc._calculate_embeddings(client, documents) + + # Verify both documents have embeddings + assert documents[0]["embedding"] == [0.1] * 768 + assert documents[1]["embedding"] == [0.2] * 768 + # Verify client was called twice + assert client.inference.embeddings.call_count == 2 + # Verify embed_metadata was removed from both + assert "embed_metadata" not in documents[0] + assert "embed_metadata" not in documents[1]