Skip to content

Commit 072a2ac

Browse files
Add qdrant native engine (#271)
* Add httpx engine for qdrant * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4d52abd commit 072a2ac

File tree

9 files changed

+695
-4
lines changed

9 files changed

+695
-4
lines changed

engine/base_client/upload.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ def upload(
3434
parallel = self.upload_params.get("parallel", 1)
3535
batch_size = self.upload_params.get("batch_size", 64)
3636

37-
self.init_client(
38-
self.host, distance, self.connection_params, self.upload_params
39-
)
40-
4137
if parallel == 1:
38+
# Initialize client in parent process for serial uploads
39+
self.init_client(
40+
self.host, distance, self.connection_params, self.upload_params
41+
)
4242
for batch in iter_batches(tqdm.tqdm(records), batch_size):
4343
latencies.append(self._upload_batch(batch))
4444
else:
@@ -59,6 +59,10 @@ def upload(
5959
iter_batches(tqdm.tqdm(records), batch_size),
6060
)
6161
)
62+
# Initialize client in parent process for post-upload operations
63+
self.init_client(
64+
self.host, distance, self.connection_params, self.upload_params
65+
)
6266

6367
upload_time = time.perf_counter() - start
6468

engine/clients/client_factory.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@
2424
PgVectorUploader,
2525
)
2626
from engine.clients.qdrant import QdrantConfigurator, QdrantSearcher, QdrantUploader
27+
from engine.clients.qdrant_native import (
28+
QdrantNativeConfigurator,
29+
QdrantNativeSearcher,
30+
QdrantNativeUploader,
31+
)
2732
from engine.clients.redis import RedisConfigurator, RedisSearcher, RedisUploader
2833
from engine.clients.weaviate import (
2934
WeaviateConfigurator,
@@ -33,6 +38,7 @@
3338

3439
ENGINE_CONFIGURATORS = {
3540
"qdrant": QdrantConfigurator,
41+
"qdrant_native": QdrantNativeConfigurator,
3642
"weaviate": WeaviateConfigurator,
3743
"milvus": MilvusConfigurator,
3844
"elasticsearch": ElasticConfigurator,
@@ -43,6 +49,7 @@
4349

4450
ENGINE_UPLOADERS = {
4551
"qdrant": QdrantUploader,
52+
"qdrant_native": QdrantNativeUploader,
4653
"weaviate": WeaviateUploader,
4754
"milvus": MilvusUploader,
4855
"elasticsearch": ElasticUploader,
@@ -53,6 +60,7 @@
5360

5461
ENGINE_SEARCHERS = {
5562
"qdrant": QdrantSearcher,
63+
"qdrant_native": QdrantNativeSearcher,
5664
"weaviate": WeaviateSearcher,
5765
"milvus": MilvusSearcher,
5866
"elasticsearch": ElasticSearcher,
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Public interface of the qdrant_native engine client package."""

from .configure import QdrantNativeConfigurator
from .search import QdrantNativeSearcher
from .upload import QdrantNativeUploader

# Names re-exported for `from engine.clients.qdrant_native import *`.
__all__ = [
    "QdrantNativeConfigurator",
    "QdrantNativeUploader",
    "QdrantNativeSearcher",
]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import os

# Both settings come from the environment so a benchmark run can target a
# different Qdrant deployment without code changes. The API key defaults to
# None, meaning "no authentication header is sent".
QDRANT_COLLECTION_NAME = os.environ.get("QDRANT_COLLECTION_NAME", "benchmark")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import httpx
2+
3+
from benchmark.dataset import Dataset
4+
from engine.base_client.configure import BaseConfigurator
5+
from engine.base_client.distances import Distance
6+
from engine.clients.qdrant_native.config import QDRANT_API_KEY, QDRANT_COLLECTION_NAME
7+
8+
9+
class QdrantNativeConfigurator(BaseConfigurator):
    """Configure a Qdrant collection through the raw REST API using httpx,
    bypassing the official Python client."""

    SPARSE_VECTOR_SUPPORT = True
    # Internal distance identifiers -> Qdrant REST API distance names.
    DISTANCE_MAPPING = {
        Distance.L2: "Euclid",
        Distance.COSINE: "Cosine",
        Distance.DOT: "Dot",
    }
    # Dataset schema field types -> Qdrant payload index schema types.
    INDEX_TYPE_MAPPING = {
        "int": "integer",
        "keyword": "keyword",
        "text": "text",
        "float": "float",
        "geo": "geo",
    }

    def __init__(self, host, collection_params: dict, connection_params: dict):
        super().__init__(host, collection_params, connection_params)

        # Qdrant's REST interface listens on port 6333 by default.
        self.host = f"http://{host.rstrip('/')}:6333"
        self.connection_params = connection_params

        self.headers = {"Content-Type": "application/json"}
        if QDRANT_API_KEY:
            self.headers["api-key"] = QDRANT_API_KEY

        timeout = connection_params.get("timeout", 30)
        self.client = httpx.Client(
            headers=self.headers,
            timeout=httpx.Timeout(timeout=timeout),
        )

    def clean(self):
        """Delete the benchmark collection if it exists."""
        url = f"{self.host}/collections/{QDRANT_COLLECTION_NAME}"
        response = self.client.delete(url)
        # 404 is ok if collection doesn't exist
        if response.status_code not in [200, 404]:
            response.raise_for_status()

    def recreate(self, dataset: Dataset, collection_params):
        """Create the collection and one payload index per schema field.

        Raises:
            ValueError: if ``payload_index_params`` names a field that is not
                part of the dataset schema.
            httpx.HTTPStatusError: if collection creation fails.

        NOTE(fix): operates on a shallow copy of ``self.collection_params``.
        The previous implementation popped ``vectors_config`` and
        ``payload_index_params`` out of the shared dict, so a second call to
        ``recreate`` silently lost those settings.
        """
        url = f"{self.host}/collections/{QDRANT_COLLECTION_NAME}"
        params = dict(self.collection_params)  # do not mutate shared state

        # Build vectors configuration
        if dataset.config.type == "sparse":
            vectors_config = {}
            sparse_vectors_config = {
                "sparse": {
                    "index": {
                        "on_disk": False,
                    }
                }
            }
        else:
            is_vectors_on_disk = params.pop("vectors_config", {}).get(
                "on_disk", False
            )
            vectors_config = {
                "size": dataset.config.vector_size,
                # NOTE(review): unmapped distances produce "distance": null and
                # will be rejected by the server — confirm upstream validation.
                "distance": self.DISTANCE_MAPPING.get(dataset.config.distance),
                "on_disk": is_vectors_on_disk,
            }
            sparse_vectors_config = None

        payload_index_params = params.pop("payload_index_params", {})
        if not set(payload_index_params.keys()).issubset(dataset.config.schema.keys()):
            raise ValueError("payload_index_params are not found in dataset schema")

        # Disable index building during upload by default; benchmarks trigger
        # optimization explicitly after upload.
        optimizers_config = dict(params.get("optimizers_config", {}))
        optimizers_config.setdefault("max_optimization_threads", 0)
        params["optimizers_config"] = optimizers_config

        # Build the collection creation payload
        payload = {}
        if vectors_config:
            payload["vectors"] = vectors_config
        if sparse_vectors_config:
            payload["sparse_vectors"] = sparse_vectors_config

        # Pass any remaining collection parameters straight through.
        for key, value in params.items():
            payload[key] = value

        response = self.client.put(url, json=payload)
        response.raise_for_status()

        for field_name, field_type in dataset.config.schema.items():
            self._create_payload_index(field_name, field_type, payload_index_params)

    def _create_payload_index(
        self, field_name: str, field_type: str, payload_index_params: dict
    ):
        """Create a payload index for a single field via PUT .../index."""
        url = f"{self.host}/collections/{QDRANT_COLLECTION_NAME}/index"

        # Build the field schema based on type
        if field_type in ["keyword", "uuid"]:
            field_schema = {
                # NOTE(review): "uuid" is absent from INDEX_TYPE_MAPPING, so it
                # falls back to "keyword" here — confirm this is intentional.
                "type": self.INDEX_TYPE_MAPPING.get(field_type, "keyword"),
            }

            # Add optional parameters if provided
            params = payload_index_params.get(field_name, {})
            if "is_tenant" in params and params["is_tenant"] is not None:
                field_schema["is_tenant"] = params["is_tenant"]
            if "on_disk" in params and params["on_disk"] is not None:
                field_schema["on_disk"] = params["on_disk"]
        else:
            # For other types, just use the type string
            field_schema = self.INDEX_TYPE_MAPPING.get(field_type, field_type)

        payload = {
            "field_name": field_name,
            "field_schema": field_schema,
        }

        response = self.client.put(url, json=payload)
        response.raise_for_status()

    def delete_client(self):
        """Close the underlying HTTP client, if one was created."""
        if hasattr(self, "client") and self.client is not None:
            self.client.close()
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import Any, List, Optional
2+
3+
from engine.base_client.parser import BaseConditionParser, FieldValue
4+
5+
6+
class QdrantNativeConditionParser(BaseConditionParser):
    """
    Translates the benchmark's internal filter representation into the JSON
    filter format of the Qdrant REST API. Every condition is emitted as a
    plain dictionary — no Pydantic models are involved.
    """

    def build_condition(
        self, and_subfilters: Optional[List[Any]], or_subfilters: Optional[List[Any]]
    ) -> Optional[Any]:
        """Combine subfilters: AND clauses -> ``must``, OR clauses -> ``should``."""
        combined = {}
        if and_subfilters:
            combined["must"] = and_subfilters
        if or_subfilters:
            combined["should"] = or_subfilters
        return combined or None

    def build_exact_match_filter(self, field_name: str, value: FieldValue) -> Any:
        """Condition matching points whose *field_name* equals *value*."""
        return {"key": field_name, "match": {"value": value}}

    def build_range_filter(
        self,
        field_name: str,
        lt: Optional[FieldValue],
        gt: Optional[FieldValue],
        lte: Optional[FieldValue],
        gte: Optional[FieldValue],
    ) -> Any:
        """Range condition; only the bounds that were supplied are emitted."""
        bounds = {
            op: bound
            for op, bound in (("lt", lt), ("gt", gt), ("lte", lte), ("gte", gte))
            if bound is not None
        }
        return {"key": field_name, "range": bounds}

    def build_geo_filter(
        self, field_name: str, lat: float, lon: float, radius: float
    ) -> Any:
        """Condition matching points within *radius* of (*lat*, *lon*)."""
        return {
            "key": field_name,
            "geo_radius": {
                "center": {"lon": lon, "lat": lat},
                "radius": radius,
            },
        }
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from typing import List, Tuple
2+
3+
import httpx
4+
5+
from dataset_reader.base_reader import Query
6+
from engine.base_client.search import BaseSearcher
7+
from engine.clients.qdrant_native.config import QDRANT_API_KEY, QDRANT_COLLECTION_NAME
8+
from engine.clients.qdrant_native.parser import QdrantNativeConditionParser
9+
10+
11+
class QdrantNativeSearcher(BaseSearcher):
    """Searcher that queries Qdrant directly over its REST API using httpx."""

    # Class-level state shared by all search calls; populated by init_client.
    search_params = {}
    client: httpx.Client = None
    parser = QdrantNativeConditionParser()
    host = None
    headers = {}

    @classmethod
    def init_client(cls, host, distance, connection_params: dict, search_params: dict):
        """Create the shared httpx client and remember the search settings."""
        cls.host = f"http://{host.rstrip('/')}:6333"
        cls.search_params = search_params

        cls.headers = {"Content-Type": "application/json"}
        if QDRANT_API_KEY:
            cls.headers["api-key"] = QDRANT_API_KEY

        # Writes get a 5x longer timeout than the other phases because large
        # query payloads can take a while to transmit.
        base_timeout = connection_params.get("timeout", 30)
        cls.client = httpx.Client(
            headers=cls.headers,
            timeout=httpx.Timeout(
                connect=base_timeout,
                read=base_timeout,
                write=base_timeout * 5,  # 5x longer for writes
                pool=base_timeout,
            ),
            limits=httpx.Limits(max_connections=None, max_keepalive_connections=0),
        )

    @classmethod
    def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
        """Run one query via POST /collections/<name>/points/query and return
        (point id, score) pairs."""
        endpoint = f"{cls.host}/collections/{QDRANT_COLLECTION_NAME}/points/query"

        sparse = query.sparse_vector
        if sparse is None:
            vector = query.vector
        else:
            # Cast numpy scalars to native Python types so json can encode them.
            vector = {
                "indices": [int(i) for i in sparse.indices],
                "values": [float(v) for v in sparse.values],
            }

        body = {"query": vector, "limit": top}
        if sparse is not None:
            # Sparse queries must name the sparse vector storage explicitly.
            body["using"] = "sparse"

        conditions = cls.parser.parse(query.meta_conditions)
        if conditions:
            body["filter"] = conditions

        config = cls.search_params.get("config", {})
        if config:
            body["params"] = config

        prefetch_params = cls.search_params.get("prefetch")
        if prefetch_params:
            # Prefetch reuses the same query vector with extra parameters.
            body["prefetch"] = [{**prefetch_params, "query": vector}]

        body["with_payload"] = cls.search_params.get("with_payload", False)

        try:
            response = cls.client.post(endpoint, json=body)
            response.raise_for_status()
            hits = response.json()["result"]["points"]
            return [(hit["id"], hit["score"]) for hit in hits]
        except Exception as ex:
            print(f"Something went wrong during search: {ex}")
            raise ex

    @classmethod
    def delete_client(cls):
        """Close and discard the shared HTTP client."""
        if cls.client is not None:
            cls.client.close()
            cls.client = None

0 commit comments

Comments
 (0)