4 changes: 2 additions & 2 deletions comps/cores/mega/gateway.py
@@ -773,7 +773,7 @@ def __init__(self, megaservice, host="0.0.0.0", port=8889):
            host,
            port,
            str(MegaServiceEndpoint.RETRIEVALTOOL),
-           Union[TextDoc, EmbeddingRequest, ChatCompletionRequest],
+           Union[TextDoc, ChatCompletionRequest],
            Union[RerankedDoc, LLMParamsDoc],
        )

@@ -789,7 +789,7 @@ def parser_input(data, TypeClass, key):

        data = await request.json()
        query = None
-       for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
+       for key, TypeClass in zip(["text", "messages"], [TextDoc, ChatCompletionRequest]):
            query, chat_request = parser_input(data, TypeClass, key)
            if query is not None:
                break
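For reference, a minimal sketch of how the narrowed parsing loop behaves. `TextDoc` and `ChatCompletionRequest` below are simplified pydantic stand-ins, and `parser_input` is a hypothetical reimplementation of the helper shown in the hunk above, not the gateway's actual code; the point is that a bare embedding-style payload keyed on `"input"` no longer matches once `EmbeddingRequest` is dropped.

```python
from typing import Optional, Tuple, Type

from pydantic import BaseModel, ValidationError


# Simplified stand-ins for the real docarray/api_protocol models.
class TextDoc(BaseModel):
    text: str


class ChatCompletionRequest(BaseModel):
    messages: list


def parser_input(data: dict, TypeClass: Type[BaseModel], key: str) -> Tuple[Optional[object], Optional[BaseModel]]:
    # Try to validate the payload as TypeClass; on success, pull the query out of `key`.
    try:
        parsed = TypeClass(**data)
    except ValidationError:
        return None, None
    return data.get(key), parsed


def parse_request(data: dict):
    # Mirrors the loop in the diff: only "text"/TextDoc and "messages"/ChatCompletionRequest
    # are tried now that EmbeddingRequest has been removed from the zip.
    for key, TypeClass in zip(["text", "messages"], [TextDoc, ChatCompletionRequest]):
        query, chat_request = parser_input(data, TypeClass, key)
        if query is not None:
            return query, chat_request
    raise ValueError(f"Unknown request format: {data}")


print(parse_request({"text": "What is OPEA?"}))                          # matches TextDoc
print(parse_request({"messages": [{"role": "user", "content": "hi"}]}))  # matches ChatCompletionRequest
```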
4 changes: 2 additions & 2 deletions comps/reranks/tei/reranking_tei.py
@@ -41,8 +41,8 @@
    endpoint="/v1/reranking",
    host="0.0.0.0",
    port=8000,
-   input_datatype=SearchedDoc,
-   output_datatype=LLMParamsDoc,
+   input_datatype=Union[SearchedDoc, RerankingRequest, ChatCompletionRequest],
+   output_datatype=Union[LLMParamsDoc, RerankingResponse, ChatCompletionRequest],
)
@register_statistics(names=["opea_service@reranking_tei"])
async def reranking(
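Widening `input_datatype`/`output_datatype` means the handler now has to branch on which request shape actually arrived. A rough sketch of that dispatch is below; the dataclasses are simplified stand-ins and their field names are assumptions for illustration, not the exact protos from `comps.cores.proto`.

```python
from dataclasses import dataclass, field
from typing import List, Tuple, Union


# Simplified stand-ins; field names here are illustrative assumptions.
@dataclass
class SearchedDoc:
    initial_query: str
    retrieved_docs: List[str] = field(default_factory=list)


@dataclass
class RerankingRequest:
    input: str
    retrieved_docs: List[str] = field(default_factory=list)


@dataclass
class ChatCompletionRequest:
    messages: str
    documents: List[str] = field(default_factory=list)


def extract_query_and_docs(
    request: Union[SearchedDoc, RerankingRequest, ChatCompletionRequest],
) -> Tuple[str, List[str]]:
    # A Union input datatype implies per-type branching before calling the
    # TEI rerank endpoint; the output type is then chosen to match the caller.
    if isinstance(request, SearchedDoc):
        return request.initial_query, request.retrieved_docs
    if isinstance(request, RerankingRequest):
        return request.input, request.retrieved_docs
    return request.messages, request.documents
```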
16 changes: 15 additions & 1 deletion comps/retrievers/redis/langchain/retriever_redis.py
@@ -23,6 +23,7 @@
)
from comps.cores.proto.api_protocol import (
    ChatCompletionRequest,
+   EmbeddingResponse,
    RetrievalRequest,
    RetrievalResponse,
    RetrievalResponseData,
@@ -54,12 +55,25 @@ async def retrieve(
    else:
        if isinstance(input, EmbedDoc):
            query = input.text
+           embedding_data_input = input.embedding
        else:
            # for RetrievalRequest, ChatCompletionRequest
            query = input.input
+           if isinstance(input.embedding, EmbeddingResponse):
+               embeddings = input.embedding.data
+               embedding_data_input = []
+               for emb in embeddings:
+                   # each emb is EmbeddingResponseData
+                   # print("Embedding data: ", emb.embedding)
+                   # print("Embedding data length: ", len(emb.embedding))
+                   embedding_data_input.append(emb.embedding)
+               # print("All Embedding data length: ", len(embedding_data_input))
+           else:
+               embedding_data_input = input.embedding
+
        # if the Redis index has data, perform the search
        if input.search_type == "similarity":
-           search_res = await vector_db.asimilarity_search_by_vector(embedding=input.embedding, k=input.k)
+           search_res = await vector_db.asimilarity_search_by_vector(embedding=embedding_data_input, k=input.k)
        elif input.search_type == "similarity_distance_threshold":
            if input.distance_threshold is None:
                raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever")
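The core of this hunk is normalizing the incoming embedding: an OpenAI-style `EmbeddingResponse` is unpacked into plain vectors before the Redis similarity search, while an already-raw embedding is passed through unchanged. Below is a minimal sketch of that normalization, using simplified stand-ins for the `api_protocol` models; only the `data`/`embedding` fields referenced in the diff are modeled, and the helper name is made up for illustration.

```python
from dataclasses import dataclass
from typing import List, Union


# Simplified stand-ins for the api_protocol models; real ones live in
# comps.cores.proto.api_protocol and carry more fields.
@dataclass
class EmbeddingResponseData:
    embedding: List[float]


@dataclass
class EmbeddingResponse:
    data: List[EmbeddingResponseData]


def normalize_embedding(embedding: Union[EmbeddingResponse, List[float]]) -> Union[List[float], List[List[float]]]:
    # Mirrors the new branch in retrieve(): an EmbeddingResponse is unpacked
    # into its raw vectors, anything else is passed through as-is.
    if isinstance(embedding, EmbeddingResponse):
        return [item.embedding for item in embedding.data]
    return embedding


# Example: a single-query response containing one vector.
resp = EmbeddingResponse(data=[EmbeddingResponseData(embedding=[0.1, 0.2, 0.3])])
assert normalize_embedding(resp) == [[0.1, 0.2, 0.3]]
assert normalize_embedding([0.1, 0.2, 0.3]) == [0.1, 0.2, 0.3]
```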