From 835edda6fc5433f53ceec53b0adbaa897522a2ee Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 18 Nov 2025 12:18:56 +0100 Subject: [PATCH 1/2] Add embeddings & reranking via Sentence Transformers --- README-zh.md | 8 +- README.md | 10 +-- .../lightrag_sentence_transformers_demo.py | 75 +++++++++++++++++++ lightrag/api/lightrag_server.py | 19 ++++- lightrag/llm/sentence_transformers.py | 26 +++++++ lightrag/rerank.py | 49 ++++++++++++ 6 files changed, 177 insertions(+), 10 deletions(-) create mode 100644 examples/unofficial-sample/lightrag_sentence_transformers_demo.py create mode 100644 lightrag/llm/sentence_transformers.py diff --git a/README-zh.md b/README-zh.md index 57eb9e4ab9..3ad4bc7b8c 100644 --- a/README-zh.md +++ b/README-zh.md @@ -453,7 +453,7 @@ async def initialize_rag(): * 如果您想使用Hugging Face模型,只需要按如下方式设置LightRAG: -参见`lightrag_hf_demo.py` +参见`lightrag_hf_demo.py`, `lightrag_sentence_transformers_demo.py`等示例代码。 ```python # 使用Hugging Face模型初始化LightRAG @@ -464,10 +464,9 @@ rag = LightRAG( # 使用Hugging Face嵌入函数 embedding_func=EmbeddingFunc( embedding_dim=384, - func=lambda texts: hf_embed( + func=lambda texts: sentence_transformers_embed( texts, - tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"), - embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") + model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") ) ), ) @@ -635,6 +634,7 @@ if __name__ == "__main__": * **Cohere / vLLM**: `cohere_rerank` * **Jina AI**: `jina_rerank` * **Aliyun阿里云**: `ali_rerank` +* **Sentence Transformers**: `sentence_transformers_rerank` 您可以将这些函数之一注入到LightRAG对象的`rerank_model_func`属性中。这将使LightRAG的查询功能能够使用注入的函数对检索到的文本块进行重新排序。有关详细用法,请参阅`examples/rerank_example.py`文件。 diff --git a/README.md b/README.md index 9b3e3c703c..cfe9801189 100644 --- a/README.md +++ b/README.md @@ -449,7 +449,7 @@ async def initialize_rag(): * If you want to use Hugging Face models, you only need to set LightRAG as follows: -See 
`lightrag_hf_demo.py` +See `lightrag_hf_demo.py` & `lightrag_sentence_transformers_demo.py` for complete examples. ```python # Initialize LightRAG with Hugging Face model @@ -457,13 +457,12 @@ rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=hf_model_complete, # Use Hugging Face model for text generation llm_model_name='meta-llama/Llama-3.1-8B-Instruct', # Model name from Hugging Face - # Use Hugging Face embedding function + # Use Hugging Face Sentence Transformers embedding function embedding_func=EmbeddingFunc( embedding_dim=384, - func=lambda texts: hf_embed( + func=lambda texts: sentence_transformers_embed( texts, - tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"), - embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") + model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") ) ), ) @@ -633,6 +632,7 @@ To enhance retrieval quality, documents can be re-ranked based on a more effecti * **Cohere / vLLM**: `cohere_rerank` * **Jina AI**: `jina_rerank` * **Aliyun**: `ali_rerank` +* **Sentence Transformers**: `sentence_transformers_rerank` You can inject one of these functions into the `rerank_model_func` attribute of the LightRAG object. This will enable LightRAG's query function to re-order retrieved text blocks using the injected function. For detailed usage, please refer to the `examples/rerank_example.py` file. 
diff --git a/examples/unofficial-sample/lightrag_sentence_transformers_demo.py b/examples/unofficial-sample/lightrag_sentence_transformers_demo.py new file mode 100644 index 0000000000..9971a7f88f --- /dev/null +++ b/examples/unofficial-sample/lightrag_sentence_transformers_demo.py @@ -0,0 +1,75 @@ +import os + +from lightrag import LightRAG, QueryParam +from lightrag.llm.hf import hf_model_complete +from lightrag.llm.sentence_transformers import sentence_transformers_embed +from lightrag.utils import EmbeddingFunc +from sentence_transformers import SentenceTransformer + +import asyncio +import nest_asyncio + +nest_asyncio.apply() + +WORKING_DIR = "./dickens" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + + +async def initialize_rag(): + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=hf_model_complete, + llm_model_name="meta-llama/Llama-3.1-8B-Instruct", + embedding_func=EmbeddingFunc( + embedding_dim=384, + max_token_size=512, + func=lambda texts: sentence_transformers_embed( + texts, + model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2"), + ), + ), + ) + + await rag.initialize_storages() # Auto-initializes pipeline_status + return rag + + +def main(): + rag = asyncio.run(initialize_rag()) + + with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + + # Perform naive search + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) + ) + + # Perform local search + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) + + # Perform global search + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="global") + ) + ) + + # Perform hybrid search + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="hybrid") + ) + ) + + +if __name__ == "__main__": + main() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 
b29e39b2eb..a09f45a7d6 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -781,6 +781,17 @@ async def optimized_embedding_function(texts, embedding_dim=None): base_url=host, api_key=api_key, ) + elif binding == "sentence_transformers": + from lightrag.llm.sentence_transformers import ( + sentence_transformers_embed, + ) + + actual_func = ( + sentence_transformers_embed.func + if isinstance(sentence_transformers_embed, EmbeddingFunc) + else sentence_transformers_embed + ) + return await actual_func(texts, embedding_dim=embedding_dim) elif binding == "gemini": from lightrag.llm.gemini import gemini_embed @@ -932,13 +943,19 @@ async def bedrock_model_complete( # Configure rerank function based on args.rerank_bindingparameter rerank_model_func = None if args.rerank_binding != "null": - from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank + from lightrag.rerank import ( + cohere_rerank, + jina_rerank, + ali_rerank, + sentence_transformers_rerank, + ) # Map rerank binding to corresponding function rerank_functions = { "cohere": cohere_rerank, "jina": jina_rerank, "aliyun": ali_rerank, + "sentence_transformers": sentence_transformers_rerank, } # Select the appropriate rerank function based on binding diff --git a/lightrag/llm/sentence_transformers.py b/lightrag/llm/sentence_transformers.py new file mode 100644 index 0000000000..b5858a10eb --- /dev/null +++ b/lightrag/llm/sentence_transformers.py @@ -0,0 +1,26 @@ +import pipmaster as pm # Pipmaster for dynamic library install + +if not pm.is_installed("sentence_transformers"): + pm.install("sentence_transformers") +if not pm.is_installed("numpy"): + pm.install("numpy") + +import numpy as np +from lightrag.utils import EmbeddingFunc +from sentence_transformers import SentenceTransformer + + +async def sentence_transformers_embed( + texts: list[str], model: SentenceTransformer +) -> np.ndarray: + async def inner_encode(texts: list[str], model: SentenceTransformer, 
embedding_dim: int = 1024):
+        return model.encode(
+            texts,
+            truncate_dim=embedding_dim,
+            convert_to_numpy=True,
+            convert_to_tensor=False,
+            show_progress_bar=False,
+        )
+    
+    embedding_func = EmbeddingFunc(embedding_dim=model.get_sentence_embedding_dimension(), func=inner_encode, max_token_size=model.get_max_seq_length())
+    return await embedding_func(texts, model=model)
diff --git a/lightrag/rerank.py b/lightrag/rerank.py
index 35551f5a04..cf508a8a74 100644
--- a/lightrag/rerank.py
+++ b/lightrag/rerank.py
@@ -290,6 +290,40 @@ async def ali_rerank(
     )
 
 
+async def sentence_transformers_rerank(
+    query: str,
+    documents: List[str],
+    top_n: Optional[int] = None,
+    api_key: Optional[str] = None,
+    model: str = "BAAI/bge-reranker-v2-m3",
+    base_url: Optional[str] = None,
+    extra_body: Optional[Dict[str, Any]] = None,
+) -> List[Dict[str, Any]]:
+    """
+    Rerank documents using CrossEncoder from Sentence Transformers.
+
+    Args:
+        query: The search query
+        documents: List of strings to rerank
+        top_n: Number of top results to return
+        api_key: Unused
+        model: CrossEncoder rerank model name
+        base_url: Unused
+        extra_body: Unused
+
+    Returns:
+        List of dicts with keys "index" (int) and "relevance_score" (float)
+    """
+    from sentence_transformers import CrossEncoder
+
+    cross_encoder = CrossEncoder(model)
+    rankings = cross_encoder.rank(query=query, documents=documents, top_k=top_n)
+    return [
+        {"index": result["corpus_id"], "relevance_score": result["score"]}
+        for result in rankings
+    ]
+
+
 """Please run this test as a module:
    python -m lightrag.rerank
 """
@@ -350,5 +384,20 @@ async def main():
             print(f"Document: {docs[item['index']]}")
     except Exception as e:
         print(f"Aliyun Error: {e}")
+    
+    # Test Sentence Transformers rerank
+    try:
+        print("\n=== Sentence Transformers Rerank ===")
+        result = await sentence_transformers_rerank(
+            query=query,
+            documents=docs,
+            top_n=2,
+        )
+        print("Results:")
+        for item in result:
+            print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
+            
print(f"Document: {docs[item['index']]}") + except Exception as e: + print(f"Sentence Transformers Error: {e}") asyncio.run(main()) From 23e7ffbe1c71d8dd26fe06cdf191c8308e9e088c Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 18 Nov 2025 12:19:55 +0100 Subject: [PATCH 2/2] Reformat --- lightrag/llm/sentence_transformers.py | 12 +++++++++--- lightrag/rerank.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/lightrag/llm/sentence_transformers.py b/lightrag/llm/sentence_transformers.py index b5858a10eb..8a8ef5cdb1 100644 --- a/lightrag/llm/sentence_transformers.py +++ b/lightrag/llm/sentence_transformers.py @@ -13,7 +13,9 @@ async def sentence_transformers_embed( texts: list[str], model: SentenceTransformer ) -> np.ndarray: - async def inner_encode(texts: list[str], model: SentenceTransformer, embedding_dim: int = 1024): + async def inner_encode( + texts: list[str], model: SentenceTransformer, embedding_dim: int = 1024 + ): return model.encode( texts, truncate_dim=embedding_dim, @@ -21,6 +23,10 @@ async def inner_encode(texts: list[str], model: SentenceTransformer, embedding_d convert_to_tensor=False, show_progress_bar=False, ) - - embedding_func = EmbeddingFunc(embedding_dim=model.get_sentence_embedding_dimension(), func=inner_encode, max_token_size=model.get_max_seq_length()) + + embedding_func = EmbeddingFunc( + embedding_dim=model.get_sentence_embedding_dimension(), + func=inner_encode, + max_token_size=model.get_max_seq_length(), + ) return await embedding_func(texts, model=model) diff --git a/lightrag/rerank.py b/lightrag/rerank.py index cf508a8a74..4d5ba9328b 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -384,7 +384,7 @@ async def main(): print(f"Document: {docs[item['index']]}") except Exception as e: print(f"Aliyun Error: {e}") - + # Test Sentence Transformers rerank try: print("\n=== Sentence Transformers Rerank ===")