Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions README-zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -462,30 +462,22 @@ async def initialize_rag():

* 如果您想使用 Hugging Face 模型,只需要按如下方式设置 LightRAG:

参见 `lightrag_hf_demo.py`
参见`lightrag_hf_demo.py`, `lightrag_sentence_transformers_demo.py`等示例代码。

```python
from functools import partial
from transformers import AutoTokenizer, AutoModel

# Pre-load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# 使用 Hugging Face 模型初始化 LightRAG
from sentence_transformers import SentenceTransformer

rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=hf_model_complete, # 使用 Hugging Face 模型进行文本生成
llm_model_name='meta-llama/Llama-3.1-8B-Instruct', # Hugging Face 的模型名称
# 使用 Hugging Face 嵌入函数
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=2048,
model_name="sentence-transformers/all-MiniLM-L6-v2",
func=partial(
hf_embed.func, # 使用 .func 访问底层未封装的函数
tokenizer=tokenizer,
embed_model=embed_model
func=lambda texts: sentence_transformers_embed(
texts,
model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
)
),
)
Expand Down Expand Up @@ -762,6 +754,7 @@ rag = LightRAG(
* **Cohere / vLLM**: `cohere_rerank`
* **Jina AI**: `jina_rerank`
* **阿里云**: `ali_rerank`
* **Sentence Transformers**: `sentence_transformers_rerank`

您可以将其中一个函数注入到 LightRAG 对象的 `rerank_model_func` 属性中。这将使 LightRAG 的查询函数能够使用注入的函数对检索到的文本块进行重新排序。详细用法请参考 `examples/rerank_example.py` 文件。

Expand Down
23 changes: 8 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -462,30 +462,22 @@ async def initialize_rag():

* If you want to use Hugging Face models, you only need to set LightRAG as follows:

See `lightrag_hf_demo.py`
See `lightrag_hf_demo.py` & `lightrag_sentence_transformers_demo.py` for complete examples.

```python
from functools import partial
from transformers import AutoTokenizer, AutoModel

# Pre-load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Initialize LightRAG with Hugging Face model
from sentence_transformers import SentenceTransformer

rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=hf_model_complete, # Use Hugging Face model for text generation
llm_model_name='meta-llama/Llama-3.1-8B-Instruct', # Model name from Hugging Face
# Use Hugging Face embedding function
# Use Hugging Face Sentence Transformers embedding function
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=2048,
model_name="sentence-transformers/all-MiniLM-L6-v2",
func=partial(
hf_embed.func, # Use .func to access the unwrapped function
tokenizer=tokenizer,
embed_model=embed_model
func=lambda texts: sentence_transformers_embed(
texts,
model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
)
),
)
Expand Down Expand Up @@ -763,6 +755,7 @@ To enhance retrieval quality, documents can be re-ranked based on a more effecti
* **Cohere / vLLM**: `cohere_rerank`
* **Jina AI**: `jina_rerank`
* **Aliyun**: `ali_rerank`
* **Sentence Transformers**: `sentence_transformers_rerank`

You can inject one of these functions into the `rerank_model_func` attribute of the LightRAG object. This will enable LightRAG's query function to re-order retrieved text blocks using the injected function. For detailed usage, please refer to the `examples/rerank_example.py` file.

Expand Down
75 changes: 75 additions & 0 deletions examples/unofficial-sample/lightrag_sentence_transformers_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import os

from lightrag import LightRAG, QueryParam
from lightrag.llm.hf import hf_model_complete
from lightrag.llm.sentence_transformers import sentence_transformers_embed
from lightrag.utils import EmbeddingFunc
from sentence_transformers import SentenceTransformer

import asyncio
import nest_asyncio

nest_asyncio.apply()

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)


async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=hf_model_complete,
llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=512,
func=lambda texts: sentence_transformers_embed(
texts,
model=SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2"),
),
),
)

await rag.initialize_storages() # Auto-initializes pipeline_status
return rag


def main():
rag = asyncio.run(initialize_rag())

with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())

# Perform naive search
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)

# Perform local search
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)

# Perform global search
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="global")
)
)

# Perform hybrid search
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
)
)


if __name__ == "__main__":
main()
25 changes: 24 additions & 1 deletion lightrag/api/lightrag_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,23 @@ async def optimized_embedding_function(texts, embedding_dim=None):
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
elif binding == "sentence_transformers":
from lightrag.llm.sentence_transformers import (
sentence_transformers_embed,
)

actual_func = (
sentence_transformers_embed.func
if isinstance(sentence_transformers_embed, EmbeddingFunc)
else sentence_transformers_embed
)
kwargs = {
"texts": texts,
"embedding_dim": embedding_dim,
}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed

Expand Down Expand Up @@ -976,13 +993,19 @@ async def bedrock_model_complete(
# Configure rerank function based on args.rerank_bindingparameter
rerank_model_func = None
if args.rerank_binding != "null":
from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank
from lightrag.rerank import (
cohere_rerank,
jina_rerank,
ali_rerank,
sentence_transformers_rerank,
)

# Map rerank binding to corresponding function
rerank_functions = {
"cohere": cohere_rerank,
"jina": jina_rerank,
"aliyun": ali_rerank,
"sentence_transformers": sentence_transformers_rerank,
}

# Select the appropriate rerank function based on binding
Expand Down
32 changes: 32 additions & 0 deletions lightrag/llm/sentence_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pipmaster as pm # Pipmaster for dynamic library install

if not pm.is_installed("sentence_transformers"):
pm.install("sentence_transformers")
if not pm.is_installed("numpy"):
pm.install("numpy")

import numpy as np
from lightrag.utils import EmbeddingFunc
from sentence_transformers import SentenceTransformer


async def sentence_transformers_embed(
texts: list[str], model: SentenceTransformer, embedding_dim: int | None = None
) -> np.ndarray:
async def inner_encode(
texts: list[str], model: SentenceTransformer, embedding_dim: int = 1024
):
return model.encode(
texts,
truncate_dim=embedding_dim,
convert_to_numpy=True,
convert_to_tensor=False,
show_progress_bar=False,
)

embedding_func = EmbeddingFunc(
embedding_dim=embedding_dim or model.get_sentence_embedding_dimension() or 1024,
func=inner_encode,
max_token_size=model.get_max_seq_length(),
)
return await embedding_func(texts, model=model)
49 changes: 49 additions & 0 deletions lightrag/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,40 @@ async def ali_rerank(
)


async def sentence_transformers_rerank(
query: str,
documents: List[str],
top_n: Optional[int] = None,
api_key: Optional[str] = None,
model: str = "BAAI/bge-reranker-v2-m3",
base_url: Optional[str] = None,
extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Rerank documents using CrossEncoder from Sentence Transformers.

Args:
query: The search query
documents: List of strings to rerank
top_n: Number of top results to return
api_key: Unused
model: rerank model name
base_url: Unused
extra_body: Unused

Returns:
List of dictionary of ["index": int, "relevance_score": float]
"""
from sentence_transformers import CrossEncoder

cross_encoder = CrossEncoder(model)
rankings = cross_encoder.rank(query=query, documents=documents, top_k=top_n)
return [
{"index": result["corpus_id"], "relevance_score": result["score"]}
for result in rankings
]


"""Please run this test as a module:
python -m lightrag.rerank
"""
Expand Down Expand Up @@ -574,4 +608,19 @@ async def main():
except Exception as e:
print(f"Aliyun Error: {e}")

# Test Sentence Transformers rerank
try:
print("\n=== Sentence Transformers Rerank ===")
result = await sentence_transformers_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Sentence Transformers Error: {e}")

asyncio.run(main())
Loading