diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py index ff3988b8e9..5068a81145 100644 --- a/comps/cores/proto/api_protocol.py +++ b/comps/cores/proto/api_protocol.py @@ -232,6 +232,7 @@ class RetrievalRequest(BaseModel): class RetrievalRequestArangoDB(RetrievalRequest): graph_name: str | None = None search_start: str | None = None # "node", "edge", "chunk" + search_type: str | None = None # "vector", "hybrid" num_centroids: int | None = None distance_strategy: str | None = None # # "COSINE", "EUCLIDEAN_DISTANCE" use_approx_search: bool | None = None diff --git a/comps/dataprep/src/integrations/arangodb.py b/comps/dataprep/src/integrations/arangodb.py index d75152c76a..07c44fe9f1 100644 --- a/comps/dataprep/src/integrations/arangodb.py +++ b/comps/dataprep/src/integrations/arangodb.py @@ -370,18 +370,20 @@ async def ingest_files(self, input: Union[DataprepRequest, ArangoDBDataprepReque chunk_overlap = input.chunk_overlap process_table = input.process_table table_strategy = input.table_strategy - graph_name = getattr(input, "graph_name", ARANGO_GRAPH_NAME) - insert_async = getattr(input, "insert_async", ARANGO_INSERT_ASYNC) - insert_batch_size = getattr(input, "insert_batch_size", ARANGO_BATCH_SIZE) - embed_nodes = getattr(input, "embed_nodes", EMBED_NODES) - embed_edges = getattr(input, "embed_edges", EMBED_EDGES) - embed_chunks = getattr(input, "embed_chunks", EMBED_CHUNKS) - allowed_node_types = getattr(input, "allowed_node_types", ALLOWED_NODE_TYPES) - allowed_edge_types = getattr(input, "allowed_edge_types", ALLOWED_EDGE_TYPES) - node_properties = getattr(input, "node_properties", NODE_PROPERTIES) - edge_properties = getattr(input, "edge_properties", EDGE_PROPERTIES) - text_capitalization_strategy = getattr(input, "text_capitalization_strategy", TEXT_CAPITALIZATION_STRATEGY) - include_chunks = getattr(input, "include_chunks", INCLUDE_CHUNKS) + graph_name = getattr(input, "graph_name", ARANGO_GRAPH_NAME) or ARANGO_GRAPH_NAME + insert_async = getattr(input, "insert_async", ARANGO_INSERT_ASYNC) or ARANGO_INSERT_ASYNC + insert_batch_size = getattr(input, "insert_batch_size", ARANGO_BATCH_SIZE) or ARANGO_BATCH_SIZE + embed_nodes = getattr(input, "embed_nodes", EMBED_NODES) or EMBED_NODES + embed_edges = getattr(input, "embed_edges", EMBED_EDGES) or EMBED_EDGES + embed_chunks = getattr(input, "embed_chunks", EMBED_CHUNKS) or EMBED_CHUNKS + allowed_node_types = getattr(input, "allowed_node_types", ALLOWED_NODE_TYPES) or ALLOWED_NODE_TYPES + allowed_edge_types = getattr(input, "allowed_edge_types", ALLOWED_EDGE_TYPES) or ALLOWED_EDGE_TYPES + node_properties = getattr(input, "node_properties", NODE_PROPERTIES) or NODE_PROPERTIES + edge_properties = getattr(input, "edge_properties", EDGE_PROPERTIES) or EDGE_PROPERTIES + text_capitalization_strategy = ( + getattr(input, "text_capitalization_strategy", TEXT_CAPITALIZATION_STRATEGY) or TEXT_CAPITALIZATION_STRATEGY + ) + include_chunks = getattr(input, "include_chunks", INCLUDE_CHUNKS) or INCLUDE_CHUNKS self._initialize_llm( allowed_node_types=allowed_node_types, diff --git a/comps/dataprep/src/requirements.txt b/comps/dataprep/src/requirements.txt index 9cf39a4ec5..03a3c4b382 100644 --- a/comps/dataprep/src/requirements.txt +++ b/comps/dataprep/src/requirements.txt @@ -16,7 +16,7 @@ huggingface_hub ipython json-repair langchain -langchain-arangodb +langchain-arangodb==0.0.6 langchain-community langchain-elasticsearch langchain-experimental diff --git a/comps/retrievers/src/README_arangodb.md b/comps/retrievers/src/README_arangodb.md index 2a91a8548e..f5a95fc0f4 100644 --- a/comps/retrievers/src/README_arangodb.md +++ b/comps/retrievers/src/README_arangodb.md @@ -116,6 +116,7 @@ ArangoDB Vector configuration - `ARANGO_USE_APPROX_SEARCH`: If set to True, the microservice will use the approximate nearest neighbor search for as part of the retrieval step. Defaults to `False`, which means the microservice will use the exact search. - `ARANGO_NUM_CENTROIDS`: The number of centroids to use for the approximate nearest neighbor search. Defaults to `1`. - `ARANGO_SEARCH_START`: The starting point for the search. Defaults to `node`. Other option could be `"edge"`, or `"chunk"`. +- `ARANGO_SEARCH_TYPE`: The type of search to use for the ArangoDB service. Defaults to `vector`. Other option could be `"hybrid"`, which combines Vector Search + Full Text Search via Reciprocal Rank Fusion (RRF). ArangoDB Traversal configuration @@ -165,6 +166,7 @@ class RetrievalRequest(BaseModel): ... class RetrievalRequestArangoDB(RetrievalRequest): graph_name: str | None = None search_start: str | None = None # "node", "edge", "chunk" + search_type: str | None = None # "vector", "hybrid" num_centroids: int | None = None distance_strategy: str | None = None # # "COSINE", "EUCLIDEAN_DISTANCE" use_approx_search: bool | None = None diff --git a/comps/retrievers/src/integrations/arangodb.py b/comps/retrievers/src/integrations/arangodb.py index 9905d9134e..c68432f941 100644 --- a/comps/retrievers/src/integrations/arangodb.py +++ b/comps/retrievers/src/integrations/arangodb.py @@ -23,6 +23,7 @@ ARANGO_NUM_CENTROIDS, ARANGO_PASSWORD, ARANGO_SEARCH_START, + ARANGO_SEARCH_TYPE, ARANGO_TRAVERSAL_ENABLED, ARANGO_TRAVERSAL_MAX_DEPTH, ARANGO_TRAVERSAL_MAX_RETURNED, @@ -325,25 +326,34 @@ async def invoke( # Process Input # ################# - query = getattr(input, "input", getattr(input, "text")) + input_dict = input.model_dump(exclude_none=True) + query = input_dict.get("input", input_dict.get("text")) + if not query: if logflag: - logger.error("Query is empty.") + logger.error("Query is empty. Please provide a valid query.") return [] embedding = input.embedding if isinstance(input.embedding, list) else None - graph_name = getattr(input, "graph_name", ARANGO_GRAPH_NAME) - search_start = getattr(input, "search_start", ARANGO_SEARCH_START) - enable_traversal = getattr(input, "enable_traversal", ARANGO_TRAVERSAL_ENABLED) - enable_summarizer = getattr(input, "enable_summarizer", SUMMARIZER_ENABLED) - distance_strategy = getattr(input, "distance_strategy", ARANGO_DISTANCE_STRATEGY) - use_approx_search = getattr(input, "use_approx_search", ARANGO_USE_APPROX_SEARCH) - num_centroids = getattr(input, "num_centroids", ARANGO_NUM_CENTROIDS) - traversal_max_depth = getattr(input, "traversal_max_depth", ARANGO_TRAVERSAL_MAX_DEPTH) - traversal_max_returned = getattr(input, "traversal_max_returned", ARANGO_TRAVERSAL_MAX_RETURNED) - traversal_score_threshold = getattr(input, "traversal_score_threshold", ARANGO_TRAVERSAL_SCORE_THRESHOLD) - traversal_query = getattr(input, "traversal_query", ARANGO_TRAVERSAL_QUERY) + graph_name = input_dict.get("graph_name", ARANGO_GRAPH_NAME) + search_start = input_dict.get("search_start", ARANGO_SEARCH_START) + search_type = input_dict.get("search_type", ARANGO_SEARCH_TYPE) + enable_traversal = input_dict.get("enable_traversal", ARANGO_TRAVERSAL_ENABLED) + enable_summarizer = input_dict.get("enable_summarizer", SUMMARIZER_ENABLED) + distance_strategy = input_dict.get("distance_strategy", ARANGO_DISTANCE_STRATEGY) + use_approx_search = input_dict.get("use_approx_search", ARANGO_USE_APPROX_SEARCH) + num_centroids = input_dict.get("num_centroids", ARANGO_NUM_CENTROIDS) + traversal_max_depth = input_dict.get("traversal_max_depth", ARANGO_TRAVERSAL_MAX_DEPTH) + traversal_max_returned = input_dict.get("traversal_max_returned", ARANGO_TRAVERSAL_MAX_RETURNED) + traversal_score_threshold = input_dict.get("traversal_score_threshold", ARANGO_TRAVERSAL_SCORE_THRESHOLD) + traversal_query = input_dict.get("traversal_query", ARANGO_TRAVERSAL_QUERY) + + if not graph_name: + raise HTTPException( + status_code=400, + detail="Graph name is empty. Please provide a valid graph name.", + ) if search_start == "node": collection_name = f"{graph_name}_ENTITY" @@ -375,7 +385,12 @@ async def invoke( if not (v_col_exists or e_col_exists): if logflag: - collection_names = self.db.graph(graph_name).vertex_collections() + collection_names = set() + for e_d in self.db.graph(graph_name).edge_definitions(): + collection_names.add(e_d["edge_collection"]) + collection_names.update(e_d["from_vertex_collections"]) + collection_names.update(e_d["to_vertex_collections"]) + m = f"Collection '{collection_name}' does not exist in graph '{graph_name}'. Collections: {collection_names}" logger.error(m) return [] @@ -430,16 +445,23 @@ async def invoke( else: embeddings = HuggingFaceBgeEmbeddings(model_name=TEI_EMBED_MODEL) - vector_db = ArangoVector( - embedding=embeddings, - embedding_dimension=dimension, - database=self.db, - collection_name=collection_name, - embedding_field=ARANGO_EMBEDDING_FIELD, - text_field=ARANGO_TEXT_FIELD, - distance_strategy=distance_strategy, - num_centroids=num_centroids, - ) + try: + vector_db = ArangoVector( + embedding=embeddings, + embedding_dimension=dimension, + database=self.db, + collection_name=collection_name, + embedding_field=ARANGO_EMBEDDING_FIELD, + text_field=ARANGO_TEXT_FIELD, + distance_strategy=distance_strategy, + num_centroids=num_centroids, + search_type=search_type, + ) + except Exception as e: + if logflag: + logger.error(f"Error during ArangoVector initialization: {e}") + + return [] ###################### # Compute Similarity # diff --git a/comps/retrievers/src/integrations/config.py b/comps/retrievers/src/integrations/config.py index 8514192611..85e16d42fb 100644 --- a/comps/retrievers/src/integrations/config.py +++ b/comps/retrievers/src/integrations/config.py @@ -205,6 +205,7 @@ def format_opensearch_conn_from_env(): ARANGO_USE_APPROX_SEARCH = os.getenv("ARANGO_USE_APPROX_SEARCH", "false").lower() == "true" ARANGO_NUM_CENTROIDS = os.getenv("ARANGO_NUM_CENTROIDS", 1) ARANGO_SEARCH_START = os.getenv("ARANGO_SEARCH_START", "node") +ARANGO_SEARCH_TYPE = os.getenv("ARANGO_SEARCH_TYPE", "vector") # ArangoDB Traversal configuration ARANGO_TRAVERSAL_ENABLED = os.getenv("ARANGO_TRAVERSAL_ENABLED", "false").lower() == "true" diff --git a/comps/retrievers/src/requirements.txt b/comps/retrievers/src/requirements.txt index a98da36023..b143c157c2 100644 --- a/comps/retrievers/src/requirements.txt +++ b/comps/retrievers/src/requirements.txt @@ -10,7 +10,7 @@ fastapi future graspologic haystack-ai==2.3.1 -langchain-arangodb +langchain-arangodb==0.0.6 langchain-elasticsearch langchain-mariadb langchain-openai diff --git a/tests/dataprep/test_dataprep_arango.sh b/tests/dataprep/test_dataprep_arango.sh index 86e8bdc6c9..b615dfae37 100644 --- a/tests/dataprep/test_dataprep_arango.sh +++ b/tests/dataprep/test_dataprep_arango.sh @@ -61,7 +61,7 @@ function start_service() { docker compose up ${service_name} -d > ${LOG_PATH}/start_services_with_compose.log # Debug time - sleep 1m + sleep 2m check_healthy "dataprep-arangodb" || exit 1 } diff --git a/tests/retrievers/test_retrievers_arango.sh b/tests/retrievers/test_retrievers_arango.sh index 402be6af4b..15184ce7a8 100644 --- a/tests/retrievers/test_retrievers_arango.sh +++ b/tests/retrievers/test_retrievers_arango.sh @@ -16,7 +16,6 @@ export ARANGO_URL=${ARANGO_URL:-"http://${host_ip}:8529"} export ARANGO_USERNAME=${ARANGO_USERNAME:-"root"} export ARANGO_PASSWORD=${ARANGO_PASSWORD:-"test"} export ARANGO_DB_NAME=${ARANGO_DB_NAME:-"_system"} -export ARANGO_COLLECTION_NAME=${ARANGO_COLLECTION_NAME:-"test"} function build_docker_images() { cd $WORKPATH @@ -43,27 +42,34 @@ function validate_microservice() { source activate URL="http://${host_ip}:7000/v1/retrieval" - # Create ARANGO_COLLECTION_NAME + # Create vertex collection GRAPH_ENTITY curl -X POST --header 'accept: application/json' \ --header 'Content-Type: application/json' \ - --data '{"name": "'${ARANGO_COLLECTION_NAME}'", "type": 2, "waitForSync": true}' \ + --data '{"name": "GRAPH_ENTITY", "type": 2, "waitForSync": true}' \ "${ARANGO_URL}/_db/${ARANGO_DB_NAME}/_api/collection" \ -u ${ARANGO_USERNAME}:${ARANGO_PASSWORD} + # Create graph GRAPH with GRAPH_ENTITY as vertex collection + curl -X POST --header 'accept: application/json' \ + --header 'Content-Type: application/json' \ + --data '{"name": "GRAPH", "edgeDefinitions": [], "orphanCollections": ["GRAPH_ENTITY"]}' \ + "${ARANGO_URL}/_db/${ARANGO_DB_NAME}/_api/gharial" \ + -u ${ARANGO_USERNAME}:${ARANGO_PASSWORD} + # Insert data into arango: {text: "test", embedding: [0.1, 0.2, 0.3, 0.4, 0.5]} curl -X POST --header 'accept: application/json' \ --header 'Content-Type: application/json' \ --data '{"text": "test", "embedding": [0.1, 0.2, 0.3, 0.4, 0.5]}' \ - "${ARANGO_URL}/_db/${ARANGO_DB_NAME}/_api/document/${ARANGO_COLLECTION_NAME}" \ + "${ARANGO_URL}/_db/${ARANGO_DB_NAME}/_api/document/GRAPH_ENTITY" \ -u ${ARANGO_USERNAME}:${ARANGO_PASSWORD} sleep 1m test_embedding="[0.1, 0.2, 0.3, 0.4, 0.5]" - HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X POST -d "{\"text\":\"test\",\"embedding\":${test_embedding}}" -H 'Content-Type: application/json' "$URL") + HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X POST -d "{\"input\":\"test\",\"embedding\":${test_embedding},\"graph_name\":\"GRAPH\"}" -H 'Content-Type: application/json' "$URL") if [ "$HTTP_STATUS" -eq 200 ]; then echo "[ retriever ] HTTP status is 200. Checking content..." - local CONTENT=$(curl -s -X POST -d "{\"text\":\"test\",\"embedding\":${test_embedding}}" -H 'Content-Type: application/json' "$URL" | tee ${LOG_PATH}/retriever.log) + local CONTENT=$(curl -s -X POST -d "{\"input\":\"test\",\"embedding\":${test_embedding},\"graph_name\":\"GRAPH\"}" -H 'Content-Type: application/json' "$URL" | tee ${LOG_PATH}/retriever.log) if echo "$CONTENT" | grep -q "retrieved_docs"; then echo "[ retriever ] Content is as expected."