diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py
index f8fec8d3a9..1200183740 100644
--- a/comps/cores/proto/api_protocol.py
+++ b/comps/cores/proto/api_protocol.py
@@ -267,6 +267,9 @@ class ChatCompletionRequest(BaseModel):
     # define
     request_type: Literal["chat"] = "chat"
+
+    # key index name
+    key_index_name: Optional[str] = None
 
 
 class DocSumChatCompletionRequest(ChatCompletionRequest):
diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py
index 5a86d3c90a..6071f8282a 100644
--- a/comps/cores/proto/docarray.py
+++ b/comps/cores/proto/docarray.py
@@ -102,7 +102,7 @@ class EmbedDoc(BaseDoc):
     lambda_mult: float = 0.5
     score_threshold: float = 0.2
     constraints: Optional[Union[Dict[str, Any], List[Dict[str, Any]], None]] = None
-
+    index_name: Optional[str] = None
 
 class EmbedMultimodalDoc(EmbedDoc):
     # extend EmbedDoc with these attributes
@@ -225,6 +225,7 @@ class LLMParams(BaseDoc):
     repetition_penalty: float = 1.03
     stream: bool = True
     language: str = "auto"  # can be "en", "zh"
+    key_index_name: Optional[str] = None
 
     chat_template: Optional[str] = Field(
         default=None,
diff --git a/comps/dataprep/src/integrations/redis.py b/comps/dataprep/src/integrations/redis.py
index a181013bcd..a62f7f20a3 100644
--- a/comps/dataprep/src/integrations/redis.py
+++ b/comps/dataprep/src/integrations/redis.py
@@ -99,12 +99,12 @@ def format_redis_conn_from_env():
 REDIS_URL = format_redis_conn_from_env()
 redis_pool = redis.ConnectionPool.from_url(REDIS_URL)
 
-
-def check_index_existance(client):
+def check_index_existance(client, index_name: str = KEY_INDEX_NAME):
     if logflag:
         logger.info(f"[ check index existence ] checking {client}")
     try:
         results = client.search("*")
+
         if logflag:
             logger.info(f"[ check index existence ] index of client exists: {client}")
         return results
@@ -120,6 +120,8 @@ def create_index(client, index_name: str = KEY_INDEX_NAME):
     try:
         definition = IndexDefinition(index_type=IndexType.HASH, prefix=["file:"])
         client.create_index((TextField("file_name"), TextField("key_ids")), definition=definition)
+        # client.create_index((TextField("index_name"), TextField("file_name"), TextField("key_ids")), definition=definition)
+
         if logflag:
             logger.info(f"[ create index ] index {index_name} successfully created")
     except Exception as e:
@@ -131,14 +133,14 @@ def create_index(client, index_name: str = KEY_INDEX_NAME):
 
 def store_by_id(client, key, value):
     if logflag:
-        logger.info(f"[ store by id ] storing ids of {key}")
+        logger.info(f"[ store by id ] storing ids of {client.index_name + '_' + key}")
     try:
-        client.add_document(doc_id="file:" + key, file_name=key, key_ids=value)
+        client.add_document(doc_id="file:" + client.index_name + '_' + key, file_name=client.index_name + '_' + key, key_ids=value)
         if logflag:
-            logger.info(f"[ store by id ] store document success. id: file:{key}")
+            logger.info(f"[ store by id ] store document success. id: file:{client.index_name + '_' + key}")
     except Exception as e:
         if logflag:
-            logger.info(f"[ store by id ] fail to store document file:{key}: {e}")
+            logger.info(f"[ store by id ] fail to store document file:{client.index_name + '_' + key}: {e}")
         return False
     return True
 
@@ -184,8 +186,9 @@ def delete_by_id(client, id):
 
 
 def ingest_chunks_to_redis(file_name: str, chunks: List):
+    KEY_INDEX_NAME = os.getenv("KEY_INDEX_NAME", "file-keys")
     if logflag:
-        logger.info(f"[ redis ingest chunks ] file name: {file_name}")
+        logger.info(f"[ redis ingest chunks ] file name: '{file_name}' to '{KEY_INDEX_NAME}' index.")
     # Create vectorstore
     if TEI_EMBEDDING_ENDPOINT:
         if not HUGGINGFACEHUB_API_TOKEN:
@@ -223,23 +226,26 @@ def ingest_chunks_to_redis(file_name: str, chunks: List):
         _, keys = Redis.from_texts_return_keys(
             texts=batch_texts,
             embedding=embedder,
-            index_name=INDEX_NAME,
+            index_name=KEY_INDEX_NAME,
             redis_url=REDIS_URL,
         )
+        keys = [k.replace(KEY_INDEX_NAME, KEY_INDEX_NAME + '_' + file_name) for k in keys]
         if logflag:
             logger.info(f"[ redis ingest chunks ] keys: {keys}")
         file_ids.extend(keys)
         if logflag:
             logger.info(f"[ redis ingest chunks ] Processed batch {i//batch_size + 1}/{(num_chunks-1)//batch_size + 1}")
-
+
     # store file_ids into index file-keys
     r = redis.Redis(connection_pool=redis_pool)
     client = r.ft(KEY_INDEX_NAME)
+
     if not check_index_existance(client):
-        assert create_index(client)
-
+        assert create_index(client, index_name=KEY_INDEX_NAME)
+
     try:
         assert store_by_id(client, key=file_name, value="#".join(file_ids))
+
     except Exception as e:
         if logflag:
             logger.info(f"[ redis ingest chunks ] {e}. Fail to store chunks of file {file_name}.")
@@ -288,6 +294,7 @@ def ingest_data_to_redis(doc_path: DocPath):
         logger.info(f"[ redis ingest data ] Done preprocessing. Created {len(chunks)} chunks of the given file.")
 
     file_name = doc_path.path.split("/")[-1]
+
     return ingest_chunks_to_redis(file_name, chunks)
 
 
@@ -364,7 +371,13 @@ async def ingest_files(
         if logflag:
             logger.info(f"[ redis ingest ] files:{files}")
             logger.info(f"[ redis ingest ] link_list:{link_list}")
-
+
+        KEY_INDEX_NAME = os.getenv("KEY_INDEX_NAME", "file-keys")
+        if KEY_INDEX_NAME != "file-keys":
+            logger.info(f"KEY_INDEX_NAME: {KEY_INDEX_NAME} is different from the default one. Setting up the parameters.")
+            self.data_index_client = self.client.ft(INDEX_NAME)
+            self.key_index_client = self.client.ft(KEY_INDEX_NAME)
+
         if files:
             if not isinstance(files, list):
                 files = [files]
@@ -372,26 +385,28 @@ async def ingest_files(
 
             for file in files:
                 encode_file = encode_filename(file.filename)
-                doc_id = "file:" + encode_file
+                doc_id = "file:" + KEY_INDEX_NAME + '_' + encode_file
                 if logflag:
                     logger.info(f"[ redis ingest ] processing file {doc_id}")
 
-                # check whether the file already exists
-                key_ids = None
-                try:
-                    key_ids = search_by_id(self.key_index_client, doc_id).key_ids
-                    if logflag:
-                        logger.info(f"[ redis ingest] File {file.filename} already exists.")
-                except Exception as e:
-                    logger.info(f"[ redis ingest] File {file.filename} does not exist.")
-                if key_ids:
-                    raise HTTPException(
-                        status_code=400,
-                        detail=f"Uploaded file {file.filename} already exists. Please change file name.",
-                    )
+                if KEY_INDEX_NAME in self.get_list_of_indices():
+                    # check whether the file already exists
+                    key_ids = None
+                    try:
+                        key_ids = search_by_id(self.key_index_client, doc_id).key_ids
+                        if logflag:
+                            logger.info(f"[ redis ingest] File '{file.filename}' already exists in '{KEY_INDEX_NAME}' index.")
+                    except Exception as e:
+                        logger.info(f"[ redis ingest] File {file.filename} does not exist.")
+                    if key_ids:
+                        raise HTTPException(
+                            status_code=400,
+                            detail=f"Uploaded file '{file.filename}' already exists in '{KEY_INDEX_NAME}' index. Please change file name or 'index_name'.",
+                        )
 
                 save_path = upload_folder + encode_file
                 await save_content_to_local_disk(save_path, file)
+
                 ingest_data_to_redis(
                     DocPath(
                         path=save_path,
@@ -451,7 +466,7 @@ async def ingest_files(
 
         raise HTTPException(status_code=400, detail="Must provide either a file or a string list.")
 
-    async def get_files(self):
+    async def get_files(self, key_index_name=KEY_INDEX_NAME):
         """Get file structure from redis database in the format of
         {
             "name": "File Name",
@@ -459,7 +474,9 @@
             "type": "File",
             "parent": "",
         }"""
-
+
+        if key_index_name is None:
+            key_index_name = KEY_INDEX_NAME
         if logflag:
             logger.info("[ redis get ] start to get file structure")
 
@@ -467,20 +484,20 @@
         file_list = []
 
         # check index existence
-        res = check_index_existance(self.key_index_client)
+        res = key_index_name in self.get_list_of_indices()
         if not res:
             if logflag:
-                logger.info(f"[ redis get ] index {KEY_INDEX_NAME} does not exist")
+                logger.info(f"[ redis get ] index {key_index_name} does not exist")
             return file_list
 
         while True:
             response = self.client.execute_command(
-                "FT.SEARCH", KEY_INDEX_NAME, "*", "LIMIT", offset, offset + SEARCH_BATCH_SIZE
+                "FT.SEARCH", key_index_name, "*", "LIMIT", offset, offset + SEARCH_BATCH_SIZE
             )
             # no doc retrieved
             if len(response) < 2:
                 break
-            file_list = format_search_results(response, file_list)
+            file_list = format_search_results(response, key_index_name, file_list)
             offset += SEARCH_BATCH_SIZE
             # last batch
             if (len(response) - 1) // 2 < SEARCH_BATCH_SIZE:
@@ -608,3 +625,17 @@ async def delete_files(self, file_path: str = Body(..., embed=True)):
         if logflag:
             logger.info(f"[ redis delete ] Delete folder {file_path} is not supported for now.")
         raise HTTPException(status_code=404, detail=f"Delete folder {file_path} is not supported for now.")
+
+    def get_list_of_indices(self):
+        """
+        Retrieves a list of all indices from the Redis client.
+
+        Returns:
+            A list of index names as strings.
+        """
+        # Execute the command to list all indices
+        indices = self.client.execute_command('FT._LIST')
+        # Decode each index name from bytes to string
+        indices_list = [item.decode('utf-8') for item in indices]
+        return indices_list
+
\ No newline at end of file
diff --git a/comps/dataprep/src/integrations/redis_multimodal.py b/comps/dataprep/src/integrations/redis_multimodal.py
index 6ae4d185bc..916968b349 100644
--- a/comps/dataprep/src/integrations/redis_multimodal.py
+++ b/comps/dataprep/src/integrations/redis_multimodal.py
@@ -7,6 +7,8 @@
 import shutil
 import time
 import uuid
+import redis
+
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, Optional, Type, Union
@@ -105,6 +107,8 @@ def format_redis_conn_from_env():
 logger = CustomLogger("opea_dataprep_redis_multimodal")
 logflag = os.getenv("LOGFLAG", False)
 
+redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=os.getenv("REDIS_PASSWORD", None))
+
 
 class MultimodalRedis(Redis):
     """Redis vector database to process multimodal data."""
@@ -192,6 +196,7 @@ def from_text_image_pairs_return_keys(
             if images
             else instance.add_text(texts, metadatas, keys=keys)
         )
+
         return instance, keys
 
     def add_text_image_pairs(
@@ -329,6 +334,10 @@ class OpeaMultimodalRedisDataprep(OpeaComponent):
 
     def __init__(self, name: str, description: str, config: dict = None):
         super().__init__(name, ServiceType.DATAPREP.name.lower(), description, config)
+
+        print(">>>>>>>>>>>>>>>>>>>> OpeaMultimodalRedisDataprep - __init__ name:", name)
+        print(">>>>>>>>>>>>>>>>>>>> OpeaMultimodalRedisDataprep - __init__ description:", description)
+
         self.device = "cpu"
         self.upload_folder = "./uploaded_files/"
         # Load embeddings model
@@ -454,6 +463,8 @@ def ingest_multimodal(self, filename, data_folder, embeddings, is_pdf=False):
         path_to_frames = os.path.join(data_folder, "frames")
 
         annotation = load_json_file(annotation_file_path)
+
+        print(">>>>>>>>>>>>>>>>>>>> ingest_multimodal - annotation:", annotation)
 
         # prepare data to ingest
         if is_pdf:
@@ -464,7 +475,13 @@
             text_list, image_list, metadatas = self.prepare_data_and_metadata_from_annotation(
                 annotation, path_to_frames, filename
             )
-
+
+        INDEX_NAME = os.getenv("INDEX_NAME", "mm-rag-redis")
+
+        print(">>>>>>>>>>>>>>>>>>>> ingest_multimodal - filename:", filename)
+        print(">>>>>>>>>>>>>>>>>>>> ingest_multimodal - text_list:", len(text_list))
+        print()
+
         MultimodalRedis.from_text_image_pairs_return_keys(
             texts=[f"From {filename}. " + text for text in text_list],
             images=image_list,
@@ -474,7 +491,48 @@
             index_schema=INDEX_SCHEMA,
             redis_url=REDIS_URL,
         )
+
+    def get_list_of_indices(self, redis_client=redis_client):
+        """
+        Retrieves a list of all indices from the Redis client.
+        Args:
+            redis_client: The Redis client instance to use for executing commands.
+
+        Returns:
+            A list of index names as strings.
+        """
+        # Execute the command to list all indices
+        indices = redis_client.execute_command('FT._LIST')
+        # Decode each index name from bytes to string and strip any surrounding single quotes
+        indices_list = [item.decode('utf-8').strip("'") for item in indices]
+        print(">>>>>>>>>>>>>>>>>>>> redis mm - get_list_of_indices ", indices_list)
+        return indices_list
+
+    def get_items_of_index(self, index_name=INDEX_NAME, redis_client=redis_client):
+        """
+        Retrieves items from a specific index in Redis.
+
+        Args:
+            index_name: The name of the index to search.
+            redis_client: The Redis client instance to use for executing commands.
+
+        Returns:
+            A sorted list of items from the specified index.
+        """
+        # Execute the command to search for all items in the specified index
+        results = redis_client.execute_command(f'FT.SEARCH {index_name} {"*"} LIMIT 0 100')
+        list_of_items = []
+        # Iterate through the results
+        for r in results:
+            if isinstance(r, list):
+                # Extract and decode the item where 'source_video' is found in the value
+                list_of_items.append(
+                    [r[i+1].decode('utf-8') for i, v in enumerate(r) if 'source_video' in str(v)][0]
+                )
+        # Return the sorted list of items
+        return sorted(list_of_items)
+
     def drop_index(self, index_name, redis_url=REDIS_URL):
         logger.info(f"dropping index {index_name}")
         try:
@@ -648,6 +706,7 @@ async def ingest_generate_captions(self, files: List[UploadFile] = File(None)):
             raise HTTPException(status_code=400, detail="Must provide at least one file.")
 
     async def ingest_files(self, files: Optional[Union[UploadFile, List[UploadFile]]] = File(None)):
+
         if files:
             accepted_media_formats = [".mp4", ".png", ".jpg", ".jpeg", ".gif", ".pdf"]
             # Create a lookup dictionary containing all media files
@@ -706,6 +765,7 @@ async def ingest_files(self, files: Optional[Union[UploadFile, List[UploadFile]]
                     uploaded_files_map[file_name] = media_file_name
 
                     if file_extension == ".pdf":
+
                         # Set up location to store pdf images and text, reusing "frames" and "annotations" from video
                         output_dir = os.path.join(self.upload_folder, media_dir_name)
                         os.makedirs(output_dir, exist_ok=True)
@@ -745,7 +805,7 @@ async def ingest_files(self, files: Optional[Union[UploadFile, List[UploadFile]]
                                     "sub_video_id": image_idx,
                                 }
                             )
-
+
                         with open(os.path.join(output_dir, "annotations.json"), "w") as f:
                             json.dump(annotations, f)
 
@@ -753,6 +813,7 @@ async def ingest_files(self, files: Optional[Union[UploadFile, List[UploadFile]]
                         self.ingest_multimodal(
                             file_name, os.path.join(self.upload_folder, media_dir_name), self.embeddings, is_pdf=True
                         )
+
                     else:
                         # Save caption file in upload directory
                         caption_file_extension = os.path.splitext(matched_files[media_file][1].filename)[1]
@@ -792,7 +853,7 @@ async def ingest_files(self, files: Optional[Union[UploadFile, List[UploadFile]]
 
     async def get_files(self):
         """Returns list of names of uploaded videos saved on the server."""
-
+
         if not Path(self.upload_folder).exists():
             logger.info("No file uploaded, return empty list.")
             return []
diff --git a/comps/dataprep/src/opea_dataprep_loader.py b/comps/dataprep/src/opea_dataprep_loader.py
index 8ec1042f8d..41e94a06ab 100644
--- a/comps/dataprep/src/opea_dataprep_loader.py
+++ b/comps/dataprep/src/opea_dataprep_loader.py
@@ -31,6 +31,11 @@ async def delete_files(self, *args, **kwargs):
         if logflag:
             logger.info("[ dataprep loader ] delete files")
         return await self.component.delete_files(*args, **kwargs)
+
+    async def get_list_of_indices(self, *args, **kwargs):
+        if logflag:
+            logger.info("[ dataprep loader ] get indices")
+        return self.component.get_list_of_indices(*args, **kwargs)
 
 
 class OpeaDataprepMultiModalLoader(OpeaComponentLoader):
diff --git a/comps/dataprep/src/opea_dataprep_microservice.py b/comps/dataprep/src/opea_dataprep_microservice.py
index 7dda2879d4..1c5b8b3aca 100644
--- a/comps/dataprep/src/opea_dataprep_microservice.py
+++ b/comps/dataprep/src/opea_dataprep_microservice.py
@@ -7,16 +7,16 @@
 from typing import List, Optional, Union
 
 from fastapi import Body, File, Form, UploadFile
-from integrations.elasticsearch import OpeaElasticSearchDataprep
-from integrations.milvus import OpeaMilvusDataprep
-from integrations.neo4j_llamaindex import OpeaNeo4jLlamaIndexDataprep
-from integrations.opensearch import OpeaOpenSearchDataprep
-from integrations.pgvect import OpeaPgvectorDataprep
-from integrations.pipecone import OpeaPineConeDataprep
-from integrations.qdrant import OpeaQdrantDataprep
-from integrations.redis import OpeaRedisDataprep
-from integrations.vdms import OpeaVdmsDataprep
-from opea_dataprep_loader import OpeaDataprepLoader
+from comps.dataprep.src.integrations.elasticsearch import OpeaElasticSearchDataprep
+from comps.dataprep.src.integrations.milvus import OpeaMilvusDataprep
+from comps.dataprep.src.integrations.neo4j_llamaindex import OpeaNeo4jLlamaIndexDataprep
+from comps.dataprep.src.integrations.opensearch import OpeaOpenSearchDataprep
+from comps.dataprep.src.integrations.pgvect import OpeaPgvectorDataprep
+from comps.dataprep.src.integrations.pipecone import OpeaPineConeDataprep
+from comps.dataprep.src.integrations.qdrant import OpeaQdrantDataprep
+from comps.dataprep.src.integrations.redis import OpeaRedisDataprep
+from comps.dataprep.src.integrations.vdms import OpeaVdmsDataprep
+from comps.dataprep.src.opea_dataprep_loader import OpeaDataprepLoader
 
 from comps import (
     CustomLogger,
@@ -55,8 +55,13 @@ async def ingest_files(
     chunk_overlap: int = Form(100),
     process_table: bool = Form(False),
     table_strategy: str = Form("fast"),
+    key_index_name: Optional[str] = Form(None),
 ):
     start = time.time()
+
+    if key_index_name:
+        # Set the KEY_INDEX_NAME environment variable
+        os.environ['KEY_INDEX_NAME'] = key_index_name
 
     if logflag:
         logger.info(f"[ ingest ] files:{files}")
@@ -84,7 +89,7 @@ async def ingest_files(
     port=5000,
 )
 @register_statistics(names=["opea_service@dataprep"])
-async def get_files():
+async def get_files(key_index_name: Optional[str] = File(None)):
     start = time.time()
 
     if logflag:
@@ -92,7 +97,7 @@ async def get_files():
 
     try:
         # Use the loader to invoke the component
-        response = await loader.get_files()
+        response = await loader.get_files(key_index_name)
 
         # Log the result if logging is enabled
         if logflag:
             logger.info(f"[ get ] ingested files: {response}")
@@ -131,6 +136,34 @@ async def delete_files(file_path: str = Body(..., embed=True)):
         logger.error(f"Error during dataprep delete invocation: {e}")
         raise
 
+
+@register_microservice(
+    name="opea_service@dataprep",
+    service_type=ServiceType.DATAPREP,
+    endpoint="/v1/dataprep/indices",
+    host="0.0.0.0",
+    port=5000,
+)
+@register_statistics(names=["opea_service@dataprep"])
+async def get_list_of_indices():
+    start = time.time()
+    if logflag:
+        logger.info("[ get ] start to get list of indices.")
+
+    try:
+        # Use the loader to invoke the component
+        response = await loader.get_list_of_indices()
+
+        # Log the result if logging is enabled
+        if logflag:
+            logger.info(f"[ get ] list of indices: {response}")
+
+        # Record statistics
+        statistics_dict["opea_service@dataprep"].append_latency(time.time() - start, None)
+
+        return response
+    except Exception as e:
+        logger.error(f"Error during dataprep get list of indices: {e}")
+        raise
 
 if __name__ == "__main__":
     logger.info("OPEA Dataprep Microservice is starting...")
diff --git a/comps/dataprep/src/opea_dataprep_multimodal_microservice.py b/comps/dataprep/src/opea_dataprep_multimodal_microservice.py
index 9fbb562a17..ce2cad6c2f 100644
--- a/comps/dataprep/src/opea_dataprep_multimodal_microservice.py
+++ b/comps/dataprep/src/opea_dataprep_multimodal_microservice.py
@@ -7,9 +7,9 @@
 from typing import List, Optional, Union
 
 from fastapi import Body, File, UploadFile
-from integrations.redis_multimodal import OpeaMultimodalRedisDataprep
-from integrations.vdms_multimodal import OpeaMultimodalVdmsDataprep
-from opea_dataprep_loader import OpeaDataprepMultiModalLoader
+from comps.dataprep.src.integrations.redis_multimodal import OpeaMultimodalRedisDataprep
+from comps.dataprep.src.integrations.vdms_multimodal import OpeaMultimodalVdmsDataprep
+from comps.dataprep.src.opea_dataprep_loader import OpeaDataprepMultiModalLoader
 
 from comps import (
     CustomLogger,
@@ -41,9 +41,15 @@
     port=5000,
 )
 @register_statistics(names=["opea_service@dataprep_multimodal"])
-async def ingest_files(files: Optional[Union[UploadFile, List[UploadFile]]] = File(None)):
+async def ingest_files(files: Optional[Union[UploadFile, List[UploadFile]]] = File(None),
+                       index_name: Optional[str] = File(None)
+                       ):
     start = time.time()
-
+
+    if index_name:
+        # Set the INDEX_NAME environment variable
+        os.environ['INDEX_NAME'] = index_name
+
     if logflag:
         logger.info(f"[ ingest ] files:{files}")
 
@@ -257,6 +263,67 @@ async def delete_files(file_path: str = Body(..., embed=True)):
         logger.error(f"Error during dataprep delete invocation: {e}")
         raise
 
+
+@register_microservice(
+    name="opea_service@dataprep_multimodal",
+    service_type=ServiceType.DATAPREP,
+    endpoint="/v1/dataprep/indices",
+    host="0.0.0.0",
+    port=5000,
+)
+@register_statistics(names=["opea_service@dataprep_multimodal"])
+async def get_list_of_indices():
+    start = time.time()
+
+    if logflag:
+        logger.info("[ get ] start to get list of indices.")
+
+    try:
+        # Use the loader to invoke the component
+        response = await loader.get_list_of_indices()
+
+        # Log the result if logging is enabled
+        if logflag:
+            logger.info(f"[ get ] list of indices: {response}")
+
+        # Record statistics
+        statistics_dict["opea_service@dataprep_multimodal"].append_latency(time.time() - start, None)
+
+        return response
+    except Exception as e:
+        logger.error(f"Error during dataprep get list of indices: {e}")
+        raise
+
+
+@register_microservice(
+    name="opea_service@dataprep_multimodal",
+    service_type=ServiceType.DATAPREP,
+    endpoint="/v1/dataprep/items_of_index",
+    host="0.0.0.0",
+    port=5000,
+)
+@register_statistics(names=["opea_service@dataprep_multimodal"])
+async def get_items_of_index(index_name: Optional[str] = File(None)):
+    start = time.time()
+
+    if logflag:
+        logger.info(f"[ get ] start to get items of index:{index_name}.")
+
+    try:
+        # Use the loader to invoke the component
+        response = await loader.get_items_of_index(index_name)
+
+        # Log the result if logging is enabled
+        if logflag:
+            logger.info(f"[ get ] items of index: {response}")
+
+        # Record statistics
+        statistics_dict["opea_service@dataprep_multimodal"].append_latency(time.time() - start, None)
+
+        return response
+    except Exception as e:
+        logger.error(f"Error during dataprep get items of index: {e}")
+        raise
+
 
 if __name__ == "__main__":
     logger.info("OPEA Dataprep Multimodal Microservice is starting...")
diff --git a/comps/dataprep/src/utils.py b/comps/dataprep/src/utils.py
index f657fd2cee..c18ff5ad47 100644
--- a/comps/dataprep/src/utils.py
+++ b/comps/dataprep/src/utils.py
@@ -800,14 +800,14 @@ def get_file_structure(root_path: str, parent_path: str = "") -> List[Dict[str,
     return result
 
 
-def format_search_results(response, file_list: list):
+def format_search_results(response, key_index_name, file_list: list):
     for i in range(1, len(response), 2):
-        file_name = response[i].decode()[5:]
+        file_name = response[i].decode()[4:]
         file_dict = {
             "name": decode_filename(file_name),
"id": decode_filename(file_name), "type": "File", - "parent": "", + "parent": key_index_name, } file_list.append(file_dict) return file_list diff --git a/comps/retrievers/src/integrations/redis.py b/comps/retrievers/src/integrations/redis.py index f71a0ae5f2..d38ae0db66 100644 --- a/comps/retrievers/src/integrations/redis.py +++ b/comps/retrievers/src/integrations/redis.py @@ -56,16 +56,16 @@ def __init__(self, name: str, description: str, config: dict = None): if not health_status: logger.error("OpeaRedisRetriever health check failed.") - def _initialize_client(self) -> Redis: + def _initialize_client(self, index_name = INDEX_NAME) -> Redis: """Initializes the redis client.""" try: if BRIDGE_TOWER_EMBEDDING: logger.info(f"generate multimodal redis instance with {BRIDGE_TOWER_EMBEDDING}") client = Redis( - embedding=self.embeddings, index_name=INDEX_NAME, index_schema=INDEX_SCHEMA, redis_url=REDIS_URL + embedding=self.embeddings, index_name=index_name, index_schema=INDEX_SCHEMA, redis_url=REDIS_URL ) else: - client = Redis(embedding=self.embeddings, index_name=INDEX_NAME, redis_url=REDIS_URL) + client = Redis(embedding=self.embeddings, index_name=index_name, redis_url=REDIS_URL) return client except Exception as e: logger.error(f"fail to initialize redis client: {e}") @@ -100,6 +100,9 @@ async def invoke( """ if logflag: logger.info(input) + + if input.index_name: + self.client = self._initialize_client(index_name=input.index_name) # check if the Redis index has data if self.client.client.keys() == []: @@ -140,7 +143,5 @@ async def invoke( else: raise ValueError(f"{input.search_type} not valid") - if logflag: - logger.info(search_res) - return search_res + return search_res \ No newline at end of file diff --git a/tests/retrievers/test_retrievers_redis.sh b/tests/retrievers/test_retrievers_redis.sh index aa2bbe61fc..a962a2835c 100644 --- a/tests/retrievers/test_retrievers_redis.sh +++ b/tests/retrievers/test_retrievers_redis.sh @@ -18,7 +18,9 @@ service_name_mm="retriever-redis-multimodal" function build_docker_images() { cd $WORKPATH - docker build --no-cache -t ${REGISTRY:-opea}/retriever:${TAG:-latest} --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/retrievers/src/Dockerfile . + # docker build --no-cache -t ${REGISTRY:-opea}/retriever:${TAG:-latest} --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/retrievers/src/Dockerfile . + docker build -t ${REGISTRY:-opea}/retriever:${TAG:-latest} --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/retrievers/src/Dockerfile . + if [ $? -ne 0 ]; then echo "opea/retriever built fail" exit 1 @@ -137,7 +139,7 @@ function stop_docker() { function main() { - stop_docker + # stop_docker build_docker_images # test text retriever @@ -152,9 +154,9 @@ function main() { validate_microservice "$test_embedding_multi" "$service_name_mm" validate_mm_microservice "$test_embedding_multi" "$service_name_mm" - # clean env - stop_docker - echo y | docker system prune + # # clean env + # stop_docker + # echo y | docker system prune }