Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ different:
embeddings_model_dir=args.model_dir,
num_workers=args.workers,
vector_store_type=args.vector_store_type,
exclude_embed_metadata=args.exclude_metadata,
exclude_llm_metadata=args.exclude_llm_metadata,
)

# Load and embed the documents, this method can be called multiple times
Expand Down
113 changes: 112 additions & 1 deletion src/lightspeed_rag_content/document_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import tempfile
import time
import types
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

Expand Down Expand Up @@ -76,6 +77,8 @@ def __init__(self, config: _Config):
if config.doc_type == "markdown":
Settings.node_parser = MarkdownNodeParser()

self.original_model_copy = TextNode.model_copy

@staticmethod
def _got_whitespace(text: str) -> bool:
"""Indicate if the parameter string contains whitespace."""
Expand All @@ -102,6 +105,15 @@ def _split_and_filter(cls, docs: list[Document]) -> list[TextNode]:
valid_nodes = cls._filter_out_invalid_nodes(nodes)
return valid_nodes

@staticmethod
def _remove_metadata(
metadata: dict[str, Any], remove: Optional[list[str]]
) -> dict[str, Any]:
"""Return a metadata dictionary without some keys."""
if not remove:
return metadata.copy()
return {key: value for key, value in metadata.items() if key not in remove}


class _LlamaIndexDB(_BaseDB):
def __init__(self, config: _Config):
Expand Down Expand Up @@ -141,6 +153,7 @@ def add_docs(self, docs: list[Document]) -> None:
"""Add documents to the list of documents to save."""
valid_nodes = self._split_and_filter(docs)
self._good_nodes.extend(valid_nodes)
self.exclude_metadata(self._good_nodes)

def save(
self, index: str, output_dir: str, embedded_files: int, exec_time: int
Expand Down Expand Up @@ -183,8 +196,56 @@ def _save_metadata(
) as file:
file.write(json.dumps(metadata))

def _model_copy_excluding_llm_metadata(
    self, node: TextNode, *args: Any, **kwargs: Any
) -> TextNode:
    """Copy *node* via the saved original ``model_copy``, stripping metadata.

    Used as a bound replacement for ``TextNode.model_copy`` so that the
    copy handed to the vector store no longer carries the keys listed in
    ``node.excluded_llm_metadata_keys``.
    """
    copied = self.original_model_copy(node, *args, **kwargs)
    stripped = self._remove_metadata(
        copied.metadata, node.excluded_llm_metadata_keys
    )
    copied.metadata = stripped
    return copied

def exclude_metadata(self, documents: list[TextNode]) -> None:
    """Mark metadata keys on *documents* to be excluded.

    By default llama-index already excludes the following keys:
    "file_name", "file_type", "file_size", "creation_date",
    "last_modified_date", and "last_accessed_date".

    This method adds further keys to be excluded from embedding
    calculations and from the metadata returned to the LLM.
    """
    embed_excluded = self.config.exclude_embed_metadata
    llm_excluded = self.config.exclude_llm_metadata

    for node in documents:
        node.excluded_embed_metadata_keys = embed_excluded

        # Llama-index stores ALL metadata and only applies
        # `excluded_llm_metadata_keys` on retrieval (either manually or
        # via `node.get_content(metadata_mode=MetadataMode.LLM)`).
        # We want the excluded keys gone from storage itself, so we
        # intercept the model_copy that runs *after* embedding[1] when
        # nodes are added to the index[2] — this lets embedding metadata
        # and stored LLM metadata differ.
        # [1]: https://github.com/run-llama/llama_index/blob/6bf76bf1ca5c70e479a3adb3327cf896cbcd869f/llama-index-core/llama_index/core/indices/vector_store/base.py#L145 # pylint: disable=line-too-long # noqa: E501
        # [2]: https://github.com/run-llama/llama_index/blob/6bf76bf1ca5c70e479a3adb3327cf896cbcd869f/llama-index-core/llama_index/core/indices/vector_store/base.py#L231 # pylint: disable=line-too-long # noqa: E501
        node.excluded_llm_metadata_keys = llm_excluded

        # Override `model_copy` per-instance so excluded metadata is not
        # stored. Plain assignment (`node.model_copy = ...` or `setattr`)
        # is rejected by pydantic's custom __setattr__, hence the
        # object.__setattr__ bypass.
        object.__setattr__(
            node,
            "model_copy",
            types.MethodType(self._model_copy_excluding_llm_metadata, node),
        )


class _LlamaStackDB(_BaseDB):
# Templates for manual creation of embeddings
EMBEDDING_METADATA_SEPARATOR = "\n"
EMBEDDING_METADATA_TEMPLATE = "{key}: {value}"
EMBEDDING_TEMPLATE = "{metadata_str}\n\n{content}"

# Llama-stack faiss vector-db uses IndexFlatL2 (it's hardcoded for now)
TEMPLATE = """version: 2
image_name: ollama
Expand Down Expand Up @@ -349,16 +410,26 @@ def add_docs(self, docs: list[Document]) -> None:
"chunk_id": node.id_,
"source": node.metadata.get("docs_url", node.metadata["title"]),
}
embed_metadata = self._remove_metadata(
node.metadata, self.config.exclude_embed_metadata
)
llm_metadata = self._remove_metadata(
node.metadata, self.config.exclude_llm_metadata
)
# Add document_id to node's metadata because llama-stack needs it
llm_metadata["document_id"] = node.ref_doc_id
self.documents.append(
{
"content": node.text,
"mime_type": "text/plain",
"metadata": node.metadata,
"metadata": llm_metadata,
"chunk_metadata": chunk_metadata,
"embed_metadata": embed_metadata, # internal to this script
}
)

else:
LOG.warning("Llama-stack automatic mode doesn't use metadata for Embedding")
self.documents.extend(
self.document_class(
document_id=doc.doc_id,
Expand All @@ -369,6 +440,41 @@ def add_docs(self, docs: list[Document]) -> None:
for doc in docs
)

def _calculate_embeddings(
self, client: Any, documents: list[dict[str, Any]]
) -> None:
"""Calculate the embeddings with metadata.

This method is necessary because llama-stack doesn't use the metadata
to calculate embeddings, so we need to calculate the embeddings
ourselves.

In the `add_docs` method the `embed_metadata` was added as a key just
to have it here, but we will use and remove it now because llama-stack
doesn't accept this parameter.

Embeddings are stored in the `embedding` key that llama-stack expects.
"""
for doc in documents:
# Build a str with the metadata
embed_metadata = doc.pop("embed_metadata")
metadata_str = self.EMBEDDING_METADATA_SEPARATOR.join(
self.EMBEDDING_METADATA_TEMPLATE.format(key=key, value=str(value))
for key, value in embed_metadata.items()
)

# We'll embed the chunk contents with the metadata
data = self.EMBEDDING_TEMPLATE.format(
content=doc["content"], metadata_str=metadata_str
)

embedding = client.inference.embeddings(
model_id=self.config.model_name,
contents=[data],
)

doc["embedding"] = embedding.embeddings[0]

def save(
self,
index: str,
Expand All @@ -387,6 +493,7 @@ def save(
client = self._start_llama_stack(cfg_file)
try:
if self.config.manual_chunking:
self._calculate_embeddings(client, self.documents)
client.vector_io.insert(vector_db_id=index, chunks=self.documents)
else:
client.tool_runtime.rag_tool.insert(
Expand All @@ -413,6 +520,8 @@ def __init__(
table_name: Optional[str] = None,
manual_chunking: bool = True,
doc_type: str = "text",
exclude_embed_metadata: Optional[list[str]] = None,
exclude_llm_metadata: Optional[list[str]] = None,
):
"""Initialize instance."""
if vector_store_type == "postgres" and not table_name:
Expand All @@ -429,6 +538,8 @@ def __init__(
table_name=table_name,
manual_chunking=manual_chunking,
doc_type=doc_type,
exclude_embed_metadata=exclude_embed_metadata,
exclude_llm_metadata=exclude_llm_metadata,
)

self._check_config(self.config)
Expand Down
18 changes: 17 additions & 1 deletion src/lightspeed_rag_content/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
"""Utilities for rag-content modules."""
import argparse

# Metadata keys llama-index already excludes from embeddings by default;
# shared here so related CLI options can default to the same list.
DEFAULT_METADATA_EXCLUSION = [
    "file_name",
    "file_type",
    "file_size",
    "creation_date",
    "last_modified_date",
    "last_accessed_date",
]

# Backward-compatible alias: the name originally shipped with a spelling
# mistake ("EXCLUSSION"); keep it so existing importers do not break.
DEFAULT_METADATA_EXCLUSSION = DEFAULT_METADATA_EXCLUSION


def get_common_arg_parser() -> argparse.ArgumentParser:
"""Provide common CLI arguments to document processing scripts."""
Expand All @@ -39,9 +48,16 @@ def get_common_arg_parser() -> argparse.ArgumentParser:
"-em",
"--exclude-metadata",
nargs="+",
default=None,
default=DEFAULT_METADATA_EXCLUSSION,
help="Metadata to be excluded during embedding",
)
parser.add_argument(
"-elm",
"--exclude-llm-metadata",
nargs="+",
default=DEFAULT_METADATA_EXCLUSSION,
help="Metadata to be excluded on the DB",
)
parser.add_argument("-o", "--output", help="Vector DB output folder")
parser.add_argument("-i", "--index", help="Product index")
parser.add_argument(
Expand Down
47 changes: 47 additions & 0 deletions tests/test_document_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def test_init_default(self, mock_processor):
manual_chunking=True,
table_name=None,
vector_store_type="faiss",
exclude_embed_metadata=None,
exclude_llm_metadata=None,
)
assert expected_params == doc_processor.config._Config__attributes
assert doc_processor._num_embedded_files == 0
Expand All @@ -106,6 +108,8 @@ def test_init_llama_index(self, vector_store_type, mock_processor):
manual_chunking=True,
table_name=None,
vector_store_type=vector_store_type,
exclude_embed_metadata=None,
exclude_llm_metadata=None,
)
if vector_store_type == "postgres":
params["table_name"] = "table_name"
Expand Down Expand Up @@ -135,6 +139,8 @@ def test_init_llama_stack(self, vector_store_type, mock_processor):
embedding_dimension=None, # Not calculated because class is mocked
manual_chunking=True,
table_name=None,
exclude_embed_metadata=None,
exclude_llm_metadata=None,
)
assert params == doc_processor.config._Config__attributes
assert doc_processor._num_embedded_files == 0
Expand Down Expand Up @@ -196,3 +202,44 @@ def test_save(self, mock_processor):
doc_processor = document_processor.DocumentProcessor(**mock_processor["params"])

doc_processor.save(mock.sentinel.index, mock.sentinel.output_dir)


class TestBaseDB:
    """Test cases for the _BaseDB class in document_processor module."""

    def test__remove_metadata(self):
        """Test that _remove_metadata removes specified keys from metadata dictionary."""
        source = {
            "file_path": "/path/to/file",
            "url": "https://example.com",
            "title": "Test Document",
            "file_name": "test.txt",
        }

        filtered = document_processor._BaseDB._remove_metadata(
            source, ["file_path", "file_name"]
        )

        for removed_key in ("file_path", "file_name"):
            assert removed_key not in filtered
        assert filtered["url"] == "https://example.com"
        assert filtered["title"] == "Test Document"

    def test__remove_metadata_empty_list(self):
        """Test that _remove_metadata returns original metadata when no keys to remove."""
        source = {"key1": "value1", "key2": "value2"}

        assert document_processor._BaseDB._remove_metadata(source, []) == source

    def test__remove_metadata_nonexistent_keys(self):
        """Test that _remove_metadata handles nonexistent keys gracefully."""
        source = {"key1": "value1", "key2": "value2"}

        filtered = document_processor._BaseDB._remove_metadata(
            source, ["nonexistent_key"]
        )

        assert filtered == source
Loading
Loading