Skip to content

Commit a0985b4

Browse files
committed
Applying review comments - part 1
1 parent bd777ab commit a0985b4

File tree

4 files changed

+216
-86
lines changed

4 files changed

+216
-86
lines changed

redis/commands/search/commands.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from redis._parsers.helpers import pairs_to_dict
66
from redis.client import NEVER_DECODE, Pipeline
77
from redis.commands.search.hybrid_query import (
8+
CombineResultsMethod,
89
HybridCursorQuery,
910
HybridPostProcessingConfig,
1011
HybridQuery,
@@ -562,6 +563,7 @@ def search(
562563
def hybrid_search(
563564
self,
564565
query: HybridQuery,
566+
combine_method: Optional[CombineResultsMethod] = None,
565567
post_processing: Optional[HybridPostProcessingConfig] = None,
566568
params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
567569
timeout: Optional[int] = None,
@@ -573,6 +575,8 @@ def hybrid_search(
573575
Args:
574576
- **query**: HybridQuery object
575577
Contains the text and vector queries
578+
- **combine_method**: CombineResultsMethod object
579+
Contains the combine method and parameters
576580
- **post_processing**: HybridPostProcessingConfig object
577581
Contains the post processing configuration
578582
- **params_substitution**: Dict[str, Union[str, int, float, bytes]]
@@ -587,6 +591,8 @@ def hybrid_search(
587591
options = {}
588592
pieces = [HYBRID_CMD, index]
589593
pieces.extend(query.get_args())
594+
if combine_method:
595+
pieces.extend(combine_method.get_args())
590596
if post_processing:
591597
pieces.extend(post_processing.build_args())
592598
if params_substitution:
@@ -1050,6 +1056,7 @@ async def search(
10501056
async def hybrid_search(
10511057
self,
10521058
query: HybridQuery,
1059+
combine_method: Optional[CombineResultsMethod] = None,
10531060
post_processing: Optional[HybridPostProcessingConfig] = None,
10541061
params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
10551062
timeout: Optional[int] = None,
@@ -1061,6 +1068,8 @@ async def hybrid_search(
10611068
Args:
10621069
- **query**: HybridQuery object
10631070
Contains the text and vector queries
1071+
- **combine_method**: CombineResultsMethod object
1072+
Contains the combine method and parameters
10641073
- **post_processing**: HybridPostProcessingConfig object
10651074
Contains the post processing configuration
10661075
- **params_substitution**: Dict[str, Union[str, int, float, bytes]]
@@ -1075,6 +1084,8 @@ async def hybrid_search(
10751084
options = {}
10761085
pieces = [HYBRID_CMD, index]
10771086
pieces.extend(query.get_args())
1087+
if combine_method:
1088+
pieces.extend(combine_method.get_args())
10781089
if post_processing:
10791090
pieces.extend(post_processing.build_args())
10801091
if params_substitution:

redis/commands/search/hybrid_query.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
from typing import Any, Dict, List, Literal, Optional, Union
23

34
try:
@@ -35,7 +36,7 @@ def query_string(self) -> str:
3536
def scorer(self, scorer: str) -> "HybridSearchQuery":
3637
"""
3738
Scoring algorithm for text search query.
38-
Allowed values are "TFIDF" or "BM25"
39+
Allowed values are "TFIDF", "DISMAX", "DOCSCORE", "BM25", etc.
3940
"""
4041
self._scorer = scorer
4142
return self
@@ -56,12 +57,17 @@ def get_args(self) -> List[str]:
5657
return args
5758

5859

60+
class VectorSearchMethods(Enum):
61+
KNN = "KNN"
62+
RANGE = "RANGE"
63+
64+
5965
class HybridVsimQuery:
6066
def __init__(
6167
self,
6268
vector_field_name: str,
6369
vector_data: Union[bytes, str],
64-
vsim_search_method: Optional[str] = None,
70+
vsim_search_method: Optional[VectorSearchMethods] = None,
6571
vsim_search_method_params: Optional[Dict[str, Any]] = None,
6672
filter: Optional["Filter"] = None,
6773
) -> None:
@@ -70,11 +76,22 @@ def __init__(
7076
7177
Args:
7278
vector_field_name: Vector field name.
79+
7380
vector_data: Vector data for the search.
81+
7482
vsim_search_method: Search method that will be used for the vsim search.
75-
Allowed values are "KNN" or "RANGE".
83+
7684
vsim_search_method_params: Search method parameters. Use the param names
77-
for keys and the values for the values. Example: {"K": 10, "EF_RUNTIME": 100}.
85+
for keys and the values for the values.
86+
Example for KNN: {"K": 10, "EF_RUNTIME": 100}
87+
where K is mandatory and defines the number of results
88+
and EF_RUNTIME is optional and definesthe exploration factor.
89+
Example for RANGE: {"RADIUS": 10, "EPSILON": 0.1}
90+
where RADIUS is mandatory and defines the radius of the search
91+
and EPSILON is optional and defines the accuracy of the search.
92+
For both KNN and RANGE, the following parameter is optional:
93+
YIELD_SCORE_AS: The name of the field to yield the calculated score as.
94+
7895
filter: If defined, a filter will be applied on the vsim query results.
7996
"""
8097
self._vector_field = vector_field_name
@@ -95,7 +112,7 @@ def vector_data(self) -> Union[bytes, str]:
95112

96113
def vsim_method_params(
97114
self,
98-
method: str,
115+
method: VectorSearchMethods,
99116
**kwargs,
100117
) -> "HybridVsimQuery":
101118
"""
@@ -106,7 +123,7 @@ def vsim_method_params(
106123
kwargs: Search method parameters. Use the param names for keys and the
107124
values for the values. Example: {"K": 10, "EF_RUNTIME": 100}.
108125
"""
109-
vsim_method_params: List[Union[str, int]] = [method]
126+
vsim_method_params: List[Union[str, int]] = [method.value]
110127
if kwargs:
111128
vsim_method_params.append(len(kwargs.items()) * 2)
112129
for key, value in kwargs.items():
@@ -158,40 +175,44 @@ def get_args(self) -> List[str]:
158175
return args
159176

160177

178+
class CombinationMethods(Enum):
179+
RRF = "RRF"
180+
LINEAR = "LINEAR"
181+
182+
183+
class CombineResultsMethod:
184+
def __init__(self, method: CombinationMethods, **kwargs) -> None:
185+
"""
186+
Create a new combine results method object.
187+
188+
Args:
189+
method: The combine method to use - RRF or LINEAR.
190+
kwargs: Additional combine parameters.
191+
"""
192+
self._method = method
193+
self._kwargs = kwargs
194+
195+
def get_args(self) -> List[Union[str, int]]:
196+
args: List[Union[str, int]] = ["COMBINE", self._method.value]
197+
if self._kwargs:
198+
args.append(len(self._kwargs.items()) * 2)
199+
for key, value in self._kwargs.items():
200+
args.extend((key, value))
201+
return args
202+
203+
161204
class HybridPostProcessingConfig:
162205
def __init__(self) -> None:
163206
"""
164207
Create a new hybrid post processing configuration object.
165208
"""
166-
self._combine = []
167209
self._load_fields = []
168210
self._groupby = []
169211
self._apply = []
170212
self._sortby_fields = []
171213
self._filter = None
172214
self._limit = None
173215

174-
def combine(
175-
self,
176-
method: Literal["RRF", "LINEAR"],
177-
**kwargs,
178-
) -> Self:
179-
"""
180-
Add combine parameters to the query.
181-
182-
Args:
183-
method: The combine method to use - RRF or LINEAR.
184-
kwargs: Additional combine parameters.
185-
"""
186-
self._combine: List[Union[str, int]] = [method]
187-
188-
self._combine.append(len(kwargs) * 2)
189-
190-
for key, value in kwargs.items():
191-
self._combine.extend([key, value])
192-
193-
return self
194-
195216
def load(self, *fields: str) -> Self:
196217
"""
197218
Add load parameters to the query.
@@ -267,8 +288,6 @@ def limit(self, offset: int, num: int) -> Self:
267288

268289
def build_args(self) -> List[str]:
269290
args = []
270-
if self._combine:
271-
args.extend(("COMBINE", *self._combine))
272291
if self._load_fields:
273292
fields_str = " ".join(self._load_fields)
274293
fields = fields_str.split(" ")

0 commit comments

Comments
 (0)