diff --git a/README-zh.md b/README-zh.md index 12ab7b62b8..04577b4e7b 100644 --- a/README-zh.md +++ b/README-zh.md @@ -462,17 +462,12 @@ async def initialize_rag(): * 如果您想使用 Hugging Face 模型,只需要按如下方式设置 LightRAG: -参见 `lightrag_hf_demo.py` +参见`lightrag_hf_demo.py`, `lightrag_sentence_transformers_demo.py`等示例代码。 ```python -from functools import partial -from transformers import AutoTokenizer, AutoModel - -# Pre-load tokenizer and model -tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") -embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") - # 使用 Hugging Face 模型初始化 LightRAG +from sentence_transformers import SentenceTransformer + rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=hf_model_complete, # 使用 Hugging Face 模型进行文本生成 @@ -480,12 +475,9 @@ rag = LightRAG( # 使用 Hugging Face 嵌入函数 embedding_func=EmbeddingFunc( embedding_dim=384, - max_token_size=2048, - model_name="sentence-transformers/all-MiniLM-L6-v2", - func=partial( - hf_embed.func, # 使用 .func 访问底层未封装的函数 - tokenizer=tokenizer, - embed_model=embed_model + func=lambda texts: sentence_transformers_embed( + texts, + model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") ) ), ) @@ -762,6 +754,7 @@ rag = LightRAG( * **Cohere / vLLM**: `cohere_rerank` * **Jina AI**: `jina_rerank` * **阿里云**: `ali_rerank` +* **Sentence Transformers**: `sentence_transformers_rerank` 您可以将其中一个函数注入到 LightRAG 对象的 `rerank_model_func` 属性中。这将使 LightRAG 的查询函数能够使用注入的函数对检索到的文本块进行重新排序。详细用法请参考 `examples/rerank_example.py` 文件。 diff --git a/README.md b/README.md index c10c692786..65345cc325 100644 --- a/README.md +++ b/README.md @@ -462,30 +462,22 @@ async def initialize_rag(): * If you want to use Hugging Face models, you only need to set LightRAG as follows: -See `lightrag_hf_demo.py` +See `lightrag_hf_demo.py` & `lightrag_sentence_transformers_demo.py` for complete examples. ```python -from functools import partial -from transformers import AutoTokenizer, AutoModel - -# Pre-load tokenizer and model -tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") -embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") - # Initialize LightRAG with Hugging Face model +from sentence_transformers import SentenceTransformer + rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=hf_model_complete, # Use Hugging Face model for text generation llm_model_name='meta-llama/Llama-3.1-8B-Instruct', # Model name from Hugging Face - # Use Hugging Face embedding function + # Use Hugging Face Sentence Transformers embedding function embedding_func=EmbeddingFunc( embedding_dim=384, - max_token_size=2048, - model_name="sentence-transformers/all-MiniLM-L6-v2", - func=partial( - hf_embed.func, # Use .func to access the unwrapped function - tokenizer=tokenizer, - embed_model=embed_model + func=lambda texts: sentence_transformers_embed( + texts, + model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") ) ), ) @@ -763,6 +755,7 @@ To enhance retrieval quality, documents can be re-ranked based on a more effecti * **Cohere / vLLM**: `cohere_rerank` * **Jina AI**: `jina_rerank` * **Aliyun**: `ali_rerank` +* **Sentence Transformers**: `sentence_transformers_rerank` You can inject one of these functions into the `rerank_model_func` attribute of the LightRAG object. This will enable LightRAG's query function to re-order retrieved text blocks using the injected function. For detailed usage, please refer to the `examples/rerank_example.py` file. diff --git a/examples/unofficial-sample/lightrag_sentence_transformers_demo.py b/examples/unofficial-sample/lightrag_sentence_transformers_demo.py new file mode 100644 index 0000000000..9971a7f88f --- /dev/null +++ b/examples/unofficial-sample/lightrag_sentence_transformers_demo.py @@ -0,0 +1,75 @@ +import os + +from lightrag import LightRAG, QueryParam +from lightrag.llm.hf import hf_model_complete +from lightrag.llm.sentence_transformers import sentence_transformers_embed +from lightrag.utils import EmbeddingFunc +from sentence_transformers import SentenceTransformer + +import asyncio +import nest_asyncio + +nest_asyncio.apply() + +WORKING_DIR = "./dickens" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + + +async def initialize_rag(): + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=hf_model_complete, + llm_model_name="meta-llama/Llama-3.1-8B-Instruct", + embedding_func=EmbeddingFunc( + embedding_dim=384, + max_token_size=512, + func=lambda texts: sentence_transformers_embed( + texts, + model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2"), + ), + ), + ) + + await rag.initialize_storages() # Auto-initializes pipeline_status + return rag + + +def main(): + rag = asyncio.run(initialize_rag()) + + with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + + # Perform naive search + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) + ) + + # Perform local search + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) + + # Perform global search + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="global") + ) + ) + + # Perform hybrid search + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="hybrid") + ) + ) + + +if __name__ == "__main__": + main() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 137a5335c6..03070655e5 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -814,6 +814,23 @@ async def optimized_embedding_function(texts, embedding_dim=None): if model: kwargs["model"] = model return await actual_func(**kwargs) + elif binding == "sentence_transformers": + from lightrag.llm.sentence_transformers import ( + sentence_transformers_embed, + ) + + actual_func = ( + sentence_transformers_embed.func + if isinstance(sentence_transformers_embed, EmbeddingFunc) + else sentence_transformers_embed + ) + kwargs = { + "texts": texts, + "embedding_dim": embedding_dim, + } + if model: + kwargs["model"] = model + return await actual_func(**kwargs) elif binding == "gemini": from lightrag.llm.gemini import gemini_embed @@ -976,13 +993,19 @@ async def bedrock_model_complete( # Configure rerank function based on args.rerank_bindingparameter rerank_model_func = None if args.rerank_binding != "null": - from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank + from lightrag.rerank import ( + cohere_rerank, + jina_rerank, + ali_rerank, + sentence_transformers_rerank, + ) # Map rerank binding to corresponding function rerank_functions = { "cohere": cohere_rerank, "jina": jina_rerank, "aliyun": ali_rerank, + "sentence_transformers": sentence_transformers_rerank, } # Select the appropriate rerank function based on binding diff --git a/lightrag/llm/sentence_transformers.py b/lightrag/llm/sentence_transformers.py new file mode 100644 index 0000000000..229c02e9c7 --- /dev/null +++ b/lightrag/llm/sentence_transformers.py @@ -0,0 +1,32 @@ +import pipmaster as pm # Pipmaster for dynamic library install + +if not pm.is_installed("sentence_transformers"): + pm.install("sentence_transformers") +if not pm.is_installed("numpy"): + pm.install("numpy") + +import numpy as np +from lightrag.utils import EmbeddingFunc +from sentence_transformers import SentenceTransformer + + +async def sentence_transformers_embed( + texts: list[str], model: SentenceTransformer, embedding_dim: int | None = None +) -> np.ndarray: + async def inner_encode( + texts: list[str], model: SentenceTransformer, embedding_dim: int = 1024 + ): + return model.encode( + texts, + truncate_dim=embedding_dim, + convert_to_numpy=True, + convert_to_tensor=False, + show_progress_bar=False, + ) + + embedding_func = EmbeddingFunc( + embedding_dim=embedding_dim or model.get_sentence_embedding_dimension() or 1024, + func=inner_encode, + max_token_size=model.get_max_seq_length(), + ) + return await embedding_func(texts, model=model) diff --git a/lightrag/rerank.py b/lightrag/rerank.py index 12950fe615..582fa29bbc 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -513,6 +513,40 @@ async def ali_rerank( ) +async def sentence_transformers_rerank( + query: str, + documents: List[str], + top_n: Optional[int] = None, + api_key: Optional[str] = None, + model: str = "BAAI/bge-reranker-v2-m3", + base_url: Optional[str] = None, + extra_body: Optional[Dict[str, Any]] = None, +) -> List[Dict[str, Any]]: + """ + Rerank documents using CrossEncoder from Sentence Transformers. + + Args: + query: The search query + documents: List of strings to rerank + top_n: Number of top results to return + api_key: Unused + model: rerank model name + base_url: Unused + extra_body: Unused + + Returns: + List of dictionary of ["index": int, "relevance_score": float] + """ + from sentence_transformers import CrossEncoder + + cross_encoder = CrossEncoder(model) + rankings = cross_encoder.rank(query=query, documents=documents, top_k=top_n) + return [ + {"index": result["corpus_id"], "relevance_score": result["score"]} + for result in rankings + ] + + """Please run this test as a module: python -m lightrag.rerank """ @@ -574,4 +608,19 @@ async def main(): except Exception as e: print(f"Aliyun Error: {e}") + # Test Sentence Transformers rerank + try: + print("\n=== Sentence Transformers Rerank ===") + result = await sentence_transformers_rerank( + query=query, + documents=docs, + top_n=2, + ) + print("Results:") + for item in result: + print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}") + print(f"Document: {docs[item['index']]}") + except Exception as e: + print(f"Sentence Transformers Error: {e}") + asyncio.run(main())