Skip to content
1 change: 1 addition & 0 deletions comps/cores/proto/api_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ class RetrievalRequest(BaseModel):
class RetrievalRequestArangoDB(RetrievalRequest):
graph_name: str | None = None
search_start: str | None = None # "node", "edge", "chunk"
search_type: str | None = None # "vector", "hybrid"
num_centroids: int | None = None
distance_strategy: str | None = None # # "COSINE", "EUCLIDEAN_DISTANCE"
use_approx_search: bool | None = None
Expand Down
26 changes: 14 additions & 12 deletions comps/dataprep/src/integrations/arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,18 +370,20 @@ async def ingest_files(self, input: Union[DataprepRequest, ArangoDBDataprepReque
chunk_overlap = input.chunk_overlap
process_table = input.process_table
table_strategy = input.table_strategy
graph_name = getattr(input, "graph_name", ARANGO_GRAPH_NAME)
insert_async = getattr(input, "insert_async", ARANGO_INSERT_ASYNC)
insert_batch_size = getattr(input, "insert_batch_size", ARANGO_BATCH_SIZE)
embed_nodes = getattr(input, "embed_nodes", EMBED_NODES)
embed_edges = getattr(input, "embed_edges", EMBED_EDGES)
embed_chunks = getattr(input, "embed_chunks", EMBED_CHUNKS)
allowed_node_types = getattr(input, "allowed_node_types", ALLOWED_NODE_TYPES)
allowed_edge_types = getattr(input, "allowed_edge_types", ALLOWED_EDGE_TYPES)
node_properties = getattr(input, "node_properties", NODE_PROPERTIES)
edge_properties = getattr(input, "edge_properties", EDGE_PROPERTIES)
text_capitalization_strategy = getattr(input, "text_capitalization_strategy", TEXT_CAPITALIZATION_STRATEGY)
include_chunks = getattr(input, "include_chunks", INCLUDE_CHUNKS)
graph_name = getattr(input, "graph_name", ARANGO_GRAPH_NAME) or ARANGO_GRAPH_NAME
insert_async = getattr(input, "insert_async", ARANGO_INSERT_ASYNC) or ARANGO_INSERT_ASYNC
insert_batch_size = getattr(input, "insert_batch_size", ARANGO_BATCH_SIZE) or ARANGO_BATCH_SIZE
embed_nodes = getattr(input, "embed_nodes", EMBED_NODES) or EMBED_NODES
embed_edges = getattr(input, "embed_edges", EMBED_EDGES) or EMBED_EDGES
embed_chunks = getattr(input, "embed_chunks", EMBED_CHUNKS) or EMBED_CHUNKS
allowed_node_types = getattr(input, "allowed_node_types", ALLOWED_NODE_TYPES) or ALLOWED_NODE_TYPES
allowed_edge_types = getattr(input, "allowed_edge_types", ALLOWED_EDGE_TYPES) or ALLOWED_EDGE_TYPES
node_properties = getattr(input, "node_properties", NODE_PROPERTIES) or NODE_PROPERTIES
edge_properties = getattr(input, "edge_properties", EDGE_PROPERTIES) or EDGE_PROPERTIES
text_capitalization_strategy = (
getattr(input, "text_capitalization_strategy", TEXT_CAPITALIZATION_STRATEGY) or TEXT_CAPITALIZATION_STRATEGY
)
include_chunks = getattr(input, "include_chunks", INCLUDE_CHUNKS) or INCLUDE_CHUNKS

self._initialize_llm(
allowed_node_types=allowed_node_types,
Expand Down
2 changes: 1 addition & 1 deletion comps/dataprep/src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ huggingface_hub
ipython
json-repair
langchain
langchain-arangodb
langchain-arangodb==0.0.6
langchain-community
langchain-elasticsearch
langchain-experimental
Expand Down
2 changes: 2 additions & 0 deletions comps/retrievers/src/README_arangodb.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ ArangoDB Vector configuration
- `ARANGO_USE_APPROX_SEARCH`: If set to True, the microservice will use the approximate nearest neighbor search for as part of the retrieval step. Defaults to `False`, which means the microservice will use the exact search.
- `ARANGO_NUM_CENTROIDS`: The number of centroids to use for the approximate nearest neighbor search. Defaults to `1`.
- `ARANGO_SEARCH_START`: The starting point for the search. Defaults to `node`. Other option could be `"edge"`, or `"chunk"`.
- `ARANGO_SEARCH_TYPE`: The type of search to use for the ArangoDB service. Defaults to `vector`. Other option could be `"hybrid"`, which combines Vector Search + Full Text Search via Reciprocal Rank Fusion (RRF).

ArangoDB Traversal configuration

Expand Down Expand Up @@ -165,6 +166,7 @@ class RetrievalRequest(BaseModel): ...
class RetrievalRequestArangoDB(RetrievalRequest):
graph_name: str | None = None
search_start: str | None = None # "node", "edge", "chunk"
search_type: str | None = None # "vector", "hybrid"
num_centroids: int | None = None
distance_strategy: str | None = None # # "COSINE", "EUCLIDEAN_DISTANCE"
use_approx_search: bool | None = None
Expand Down
70 changes: 46 additions & 24 deletions comps/retrievers/src/integrations/arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ARANGO_NUM_CENTROIDS,
ARANGO_PASSWORD,
ARANGO_SEARCH_START,
ARANGO_SEARCH_TYPE,
ARANGO_TRAVERSAL_ENABLED,
ARANGO_TRAVERSAL_MAX_DEPTH,
ARANGO_TRAVERSAL_MAX_RETURNED,
Expand Down Expand Up @@ -325,25 +326,34 @@ async def invoke(
# Process Input #
#################

query = getattr(input, "input", getattr(input, "text"))
input_dict = input.model_dump(exclude_none=True)
query = input_dict.get("input", input_dict.get("text"))

if not query:
if logflag:
logger.error("Query is empty.")
logger.error("Query is empty. Please provide a valid query.")

return []

embedding = input.embedding if isinstance(input.embedding, list) else None
graph_name = getattr(input, "graph_name", ARANGO_GRAPH_NAME)
search_start = getattr(input, "search_start", ARANGO_SEARCH_START)
enable_traversal = getattr(input, "enable_traversal", ARANGO_TRAVERSAL_ENABLED)
enable_summarizer = getattr(input, "enable_summarizer", SUMMARIZER_ENABLED)
distance_strategy = getattr(input, "distance_strategy", ARANGO_DISTANCE_STRATEGY)
use_approx_search = getattr(input, "use_approx_search", ARANGO_USE_APPROX_SEARCH)
num_centroids = getattr(input, "num_centroids", ARANGO_NUM_CENTROIDS)
traversal_max_depth = getattr(input, "traversal_max_depth", ARANGO_TRAVERSAL_MAX_DEPTH)
traversal_max_returned = getattr(input, "traversal_max_returned", ARANGO_TRAVERSAL_MAX_RETURNED)
traversal_score_threshold = getattr(input, "traversal_score_threshold", ARANGO_TRAVERSAL_SCORE_THRESHOLD)
traversal_query = getattr(input, "traversal_query", ARANGO_TRAVERSAL_QUERY)
graph_name = input_dict.get("graph_name", ARANGO_GRAPH_NAME)
search_start = input_dict.get("search_start", ARANGO_SEARCH_START)
search_type = input_dict.get("search_type", ARANGO_SEARCH_TYPE)
enable_traversal = input_dict.get("enable_traversal", ARANGO_TRAVERSAL_ENABLED)
enable_summarizer = input_dict.get("enable_summarizer", SUMMARIZER_ENABLED)
distance_strategy = input_dict.get("distance_strategy", ARANGO_DISTANCE_STRATEGY)
use_approx_search = input_dict.get("use_approx_search", ARANGO_USE_APPROX_SEARCH)
num_centroids = input_dict.get("num_centroids", ARANGO_NUM_CENTROIDS)
traversal_max_depth = input_dict.get("traversal_max_depth", ARANGO_TRAVERSAL_MAX_DEPTH)
traversal_max_returned = input_dict.get("traversal_max_returned", ARANGO_TRAVERSAL_MAX_RETURNED)
traversal_score_threshold = input_dict.get("traversal_score_threshold", ARANGO_TRAVERSAL_SCORE_THRESHOLD)
traversal_query = input_dict.get("traversal_query", ARANGO_TRAVERSAL_QUERY)

if not graph_name:
raise HTTPException(
status_code=400,
detail="Graph name is empty. Please provide a valid graph name.",
)

if search_start == "node":
collection_name = f"{graph_name}_ENTITY"
Expand Down Expand Up @@ -375,7 +385,12 @@ async def invoke(

if not (v_col_exists or e_col_exists):
if logflag:
collection_names = self.db.graph(graph_name).vertex_collections()
collection_names = set()
for e_d in self.db.graph(graph_name).edge_definitions():
collection_names.add(e_d["edge_collection"])
collection_names.update(e_d["from_vertex_collections"])
collection_names.update(e_d["to_vertex_collections"])

m = f"Collection '{collection_name}' does not exist in graph '{graph_name}'. Collections: {collection_names}"
logger.error(m)
return []
Expand Down Expand Up @@ -430,16 +445,23 @@ async def invoke(
else:
embeddings = HuggingFaceBgeEmbeddings(model_name=TEI_EMBED_MODEL)

vector_db = ArangoVector(
embedding=embeddings,
embedding_dimension=dimension,
database=self.db,
collection_name=collection_name,
embedding_field=ARANGO_EMBEDDING_FIELD,
text_field=ARANGO_TEXT_FIELD,
distance_strategy=distance_strategy,
num_centroids=num_centroids,
)
try:
vector_db = ArangoVector(
embedding=embeddings,
embedding_dimension=dimension,
database=self.db,
collection_name=collection_name,
embedding_field=ARANGO_EMBEDDING_FIELD,
text_field=ARANGO_TEXT_FIELD,
distance_strategy=distance_strategy,
num_centroids=num_centroids,
search_type=search_type,
)
except Exception as e:
if logflag:
logger.error(f"Error during ArangoVector initialization: {e}")

return []

######################
# Compute Similarity #
Expand Down
1 change: 1 addition & 0 deletions comps/retrievers/src/integrations/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def format_opensearch_conn_from_env():
ARANGO_USE_APPROX_SEARCH = os.getenv("ARANGO_USE_APPROX_SEARCH", "false").lower() == "true"
ARANGO_NUM_CENTROIDS = os.getenv("ARANGO_NUM_CENTROIDS", 1)
ARANGO_SEARCH_START = os.getenv("ARANGO_SEARCH_START", "node")
ARANGO_SEARCH_TYPE = os.getenv("ARANGO_SEARCH_TYPE", "vector")

# ArangoDB Traversal configuration
ARANGO_TRAVERSAL_ENABLED = os.getenv("ARANGO_TRAVERSAL_ENABLED", "false").lower() == "true"
Expand Down
2 changes: 1 addition & 1 deletion comps/retrievers/src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fastapi
future
graspologic
haystack-ai==2.3.1
langchain-arangodb
langchain-arangodb==0.0.6
langchain-elasticsearch
langchain-mariadb
langchain-openai
Expand Down
2 changes: 1 addition & 1 deletion tests/dataprep/test_dataprep_arango.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function start_service() {
docker compose up ${service_name} -d > ${LOG_PATH}/start_services_with_compose.log

# Debug time
sleep 1m
sleep 2m

check_healthy "dataprep-arangodb" || exit 1
}
Expand Down
18 changes: 12 additions & 6 deletions tests/retrievers/test_retrievers_arango.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ export ARANGO_URL=${ARANGO_URL:-"http://${host_ip}:8529"}
export ARANGO_USERNAME=${ARANGO_USERNAME:-"root"}
export ARANGO_PASSWORD=${ARANGO_PASSWORD:-"test"}
export ARANGO_DB_NAME=${ARANGO_DB_NAME:-"_system"}
export ARANGO_COLLECTION_NAME=${ARANGO_COLLECTION_NAME:-"test"}

function build_docker_images() {
cd $WORKPATH
Expand All @@ -43,27 +42,34 @@ function validate_microservice() {
source activate
URL="http://${host_ip}:7000/v1/retrieval"

# Create ARANGO_COLLECTION_NAME
# Create vertex collection GRAPH_ENTITY
curl -X POST --header 'accept: application/json' \
--header 'Content-Type: application/json' \
--data '{"name": "'${ARANGO_COLLECTION_NAME}'", "type": 2, "waitForSync": true}' \
--data '{"name": "GRAPH_ENTITY", "type": 2, "waitForSync": true}' \
"${ARANGO_URL}/_db/${ARANGO_DB_NAME}/_api/collection" \
-u ${ARANGO_USERNAME}:${ARANGO_PASSWORD}

# Create graph GRAPH with GRAPH_ENTITY as vertex collection
curl -X POST --header 'accept: application/json' \
--header 'Content-Type: application/json' \
--data '{"name": "GRAPH", "edgeDefinitions": [], "orphanCollections": ["GRAPH_ENTITY"]}' \
"${ARANGO_URL}/_db/${ARANGO_DB_NAME}/_api/gharial" \
-u ${ARANGO_USERNAME}:${ARANGO_PASSWORD}

# Insert data into arango: {text: "test", embedding: [0.1, 0.2, 0.3, 0.4, 0.5]}
curl -X POST --header 'accept: application/json' \
--header 'Content-Type: application/json' \
--data '{"text": "test", "embedding": [0.1, 0.2, 0.3, 0.4, 0.5]}' \
"${ARANGO_URL}/_db/${ARANGO_DB_NAME}/_api/document/${ARANGO_COLLECTION_NAME}" \
"${ARANGO_URL}/_db/${ARANGO_DB_NAME}/_api/document/GRAPH_ENTITY" \
-u ${ARANGO_USERNAME}:${ARANGO_PASSWORD}

sleep 1m

test_embedding="[0.1, 0.2, 0.3, 0.4, 0.5]"
HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X POST -d "{\"text\":\"test\",\"embedding\":${test_embedding}}" -H 'Content-Type: application/json' "$URL")
HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X POST -d "{\"input\":\"test\",\"embedding\":${test_embedding},\"graph_name\":\"GRAPH\"}" -H 'Content-Type: application/json' "$URL")
if [ "$HTTP_STATUS" -eq 200 ]; then
echo "[ retriever ] HTTP status is 200. Checking content..."
local CONTENT=$(curl -s -X POST -d "{\"text\":\"test\",\"embedding\":${test_embedding}}" -H 'Content-Type: application/json' "$URL" | tee ${LOG_PATH}/retriever.log)
local CONTENT=$(curl -s -X POST -d "{\"input\":\"test\",\"embedding\":${test_embedding},\"graph_name\":\"GRAPH\"}" -H 'Content-Type: application/json' "$URL" | tee ${LOG_PATH}/retriever.log)

if echo "$CONTENT" | grep -q "retrieved_docs"; then
echo "[ retriever ] Content is as expected."
Expand Down
Loading