Skip to content

Commit fc7c324

Browse files
committed
Adding YIELD_SCORE_AS tests and clearing up support for the keyword
1 parent 421d7da commit fc7c324

File tree

3 files changed

+261
-15
lines changed

3 files changed

+261
-15
lines changed

redis/commands/search/hybrid_query.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@ def __init__(
1414
self,
1515
query_string: str,
1616
scorer: Optional[str] = None,
17-
yield_score_as: Optional[
18-
str
19-
] = None, ## TODO check if this will be supported or it should be removed!
17+
yield_score_as: Optional[str] = None,
2018
) -> None:
2119
"""
2220
Create a new hybrid search query object.
@@ -42,13 +40,18 @@ def scorer(self, scorer: str) -> "HybridSearchQuery":
4240
self._scorer = scorer
4341
return self
4442

43+
def yield_score_as(self, alias: str) -> "HybridSearchQuery":
44+
"""
45+
Yield the score as a field.
46+
"""
47+
self._yield_score_as = alias
48+
return self
49+
4550
def get_args(self) -> List[str]:
4651
args = ["SEARCH", self._query_string]
4752
if self._scorer:
4853
args.extend(("SCORER", self._scorer))
49-
if (
50-
self._yield_score_as
51-
): # TODO check if this will be supported or it should be removed!
54+
if self._yield_score_as:
5255
args.extend(("YIELD_SCORE_AS", self._yield_score_as))
5356
return args
5457

@@ -109,7 +112,7 @@ def vsim_method_params(
109112
for key, value in kwargs.items():
110113
vsim_method_params.extend((key, value))
111114
self._vsim_method_params = vsim_method_params
112-
print(self._vsim_method_params)
115+
113116
return self
114117

115118
def filter(self, flt: "HybridFilter") -> "HybridVsimQuery":
@@ -171,17 +174,13 @@ def __init__(self) -> None:
171174
def combine(
172175
self,
173176
method: Literal["RRF", "LINEAR"],
174-
yield_score_as: Optional[
175-
str
176-
] = None, # TODO check if this will be supported or it should be removed!
177177
**kwargs,
178178
) -> Self:
179179
"""
180180
Add combine parameters to the query.
181181
182182
Args:
183183
method: The combine method to use - RRF or LINEAR.
184-
yield_score_as: Optional field name to yield the score as.
185184
kwargs: Additional combine parameters.
186185
"""
187186
self._combine: List[Union[str, int]] = [method]
@@ -191,10 +190,6 @@ def combine(
191190
for key, value in kwargs.items():
192191
self._combine.extend([key, value])
193192

194-
if (
195-
yield_score_as
196-
): # TODO check if this will be supported or it should be removed!
197-
self._combine.extend(["YIELD_SCORE_AS", yield_score_as])
198193
return self
199194

200195
def load(self, *fields: str) -> Self:

tests/test_asyncio/test_search.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2605,6 +2605,70 @@ async def test_hybrid_search_query_with_vsim_range(self, decoded_r):
26052605
assert res["warnings"] == []
26062606
assert res["execution_time"] > 0
26072607

2608+
@pytest.mark.redismod
2609+
@skip_if_server_version_lt("8.3.224")
2610+
async def test_hybrid_search_query_with_combine_all_score_aliases(self, decoded_r):
2611+
# Create index and add data
2612+
await self._create_hybrid_search_index(decoded_r)
2613+
await self._add_data_for_hybrid_search(
2614+
decoded_r, items_sets=1, use_random_str_data=True
2615+
)
2616+
2617+
search_query = HybridSearchQuery("shoes")
2618+
search_query.yield_score_as("search_score")
2619+
2620+
vsim_query = HybridVsimQuery(
2621+
vector_field_name="@embedding-hnsw",
2622+
vector_data="abcd1234efgh5678",
2623+
vsim_search_method="KNN",
2624+
vsim_search_method_params={
2625+
"K": 3,
2626+
"EF_RUNTIME": 1,
2627+
"YIELD_SCORE_AS": "vsim_score",
2628+
},
2629+
)
2630+
2631+
hybrid_query = HybridQuery(search_query, vsim_query)
2632+
2633+
posprocessing_config = HybridPostProcessingConfig()
2634+
posprocessing_config.combine(
2635+
"LINEAR", ALPHA=0.5, BETA=0.5, YIELD_SCORE_AS="combined_score"
2636+
)
2637+
2638+
res = await decoded_r.ft().hybrid_search(
2639+
query=hybrid_query, post_processing=posprocessing_config, timeout=10
2640+
)
2641+
2642+
if is_resp2_connection(decoded_r):
2643+
assert len(res.results) > 0
2644+
assert res.warnings == []
2645+
for item in res.results:
2646+
assert item["combined_score"] is not None
2647+
assert "__score" not in item
2648+
if item["__key"] in [b"item:0", b"item:1", b"item:4"]:
2649+
assert item["search_score"] is not None
2650+
else:
2651+
assert "search_score" not in item
2652+
if item["__key"] in [b"item:0", b"item:1", b"item:2"]:
2653+
assert item["vsim_score"] is not None
2654+
else:
2655+
assert "vsim_score" not in item
2656+
2657+
else:
2658+
assert len(res["results"]) > 0
2659+
assert res["warnings"] == []
2660+
for item in res["results"]:
2661+
assert item["combined_score"] is not None
2662+
assert "__score" not in item
2663+
if item["__key"] in ["item:0", "item:1", "item:4"]:
2664+
assert item["search_score"] is not None
2665+
else:
2666+
assert "search_score" not in item
2667+
if item["__key"] in ["item:0", "item:1", "item:2"]:
2668+
assert item["vsim_score"] is not None
2669+
else:
2670+
assert "vsim_score" not in item
2671+
26082672
@pytest.mark.redismod
26092673
@skip_if_server_version_lt("8.3.224")
26102674
async def test_hybrid_search_query_with_combine(self, decoded_r):

tests/test_search.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4281,6 +4281,193 @@ def test_hybrid_search_query_with_vsim_filter(self, client):
42814281
assert item["price"] in ["15", "16"]
42824282
assert item["size"] in ["10", "11"]
42834283

4284+
@pytest.mark.redismod
4285+
@skip_if_server_version_lt("8.3.224")
4286+
def test_hybrid_search_query_with_search_score_aliases(self, client):
4287+
# Create index and add data
4288+
self._create_hybrid_search_index(client)
4289+
self._add_data_for_hybrid_search(client, items_sets=1, use_random_str_data=True)
4290+
4291+
search_query = HybridSearchQuery("shoes")
4292+
search_query.yield_score_as("search_score")
4293+
4294+
vsim_query = HybridVsimQuery(
4295+
vector_field_name="@embedding",
4296+
vector_data="abcd1234efgh5678",
4297+
)
4298+
4299+
hybrid_query = HybridQuery(search_query, vsim_query)
4300+
4301+
res = client.ft().hybrid_search(query=hybrid_query, timeout=10)
4302+
4303+
if is_resp2_connection(client):
4304+
assert len(res.results) > 0
4305+
assert res.warnings == []
4306+
for item in res.results:
4307+
if item["__key"] in [b"item:0", b"item:1", b"item:4"]:
4308+
assert item["search_score"] is not None
4309+
assert item["__score"] is not None
4310+
else:
4311+
assert "search_score" not in item
4312+
assert item["__score"] is not None
4313+
4314+
else:
4315+
assert len(res["results"]) > 0
4316+
assert res["warnings"] == []
4317+
for item in res["results"]:
4318+
if item["__key"] in ["item:0", "item:1", "item:4"]:
4319+
assert item["search_score"] is not None
4320+
assert item["__score"] is not None
4321+
else:
4322+
assert "search_score" not in item
4323+
assert item["__score"] is not None
4324+
4325+
@pytest.mark.redismod
4326+
@skip_if_server_version_lt("8.3.224")
4327+
def test_hybrid_search_query_with_vsim_score_aliases(self, client):
4328+
# Create index and add data
4329+
self._create_hybrid_search_index(client)
4330+
self._add_data_for_hybrid_search(client, items_sets=1, use_random_str_data=True)
4331+
4332+
search_query = HybridSearchQuery("shoes")
4333+
4334+
vsim_query = HybridVsimQuery(
4335+
vector_field_name="@embedding-hnsw",
4336+
vector_data="abcd1234efgh5678",
4337+
vsim_search_method="KNN",
4338+
vsim_search_method_params={
4339+
"K": 3,
4340+
"EF_RUNTIME": 1,
4341+
"YIELD_SCORE_AS": "vsim_score",
4342+
},
4343+
)
4344+
4345+
hybrid_query = HybridQuery(search_query, vsim_query)
4346+
4347+
res = client.ft().hybrid_search(query=hybrid_query, timeout=10)
4348+
4349+
if is_resp2_connection(client):
4350+
assert len(res.results) > 0
4351+
assert res.warnings == []
4352+
for item in res.results:
4353+
if item["__key"] in [b"item:0", b"item:1", b"item:2"]:
4354+
assert item["vsim_score"] is not None
4355+
assert item["__score"] is not None
4356+
else:
4357+
assert "vsim_score" not in item
4358+
assert item["__score"] is not None
4359+
4360+
else:
4361+
assert len(res["results"]) > 0
4362+
assert res["warnings"] == []
4363+
for item in res["results"]:
4364+
if item["__key"] in ["item:0", "item:1", "item:2"]:
4365+
assert item["vsim_score"] is not None
4366+
assert item["__score"] is not None
4367+
else:
4368+
assert "vsim_score" not in item
4369+
assert item["__score"] is not None
4370+
4371+
@pytest.mark.redismod
4372+
@skip_if_server_version_lt("8.3.224")
4373+
def test_hybrid_search_query_with_combine_score_aliases(self, client):
4374+
# Create index and add data
4375+
self._create_hybrid_search_index(client)
4376+
self._add_data_for_hybrid_search(client, items_sets=1, use_random_str_data=True)
4377+
4378+
search_query = HybridSearchQuery("shoes")
4379+
4380+
vsim_query = HybridVsimQuery(
4381+
vector_field_name="@embedding-hnsw", vector_data="abcd1234efgh5678"
4382+
)
4383+
4384+
hybrid_query = HybridQuery(search_query, vsim_query)
4385+
4386+
posprocessing_config = HybridPostProcessingConfig()
4387+
posprocessing_config.combine(
4388+
"LINEAR", ALPHA=0.5, BETA=0.5, YIELD_SCORE_AS="combined_score"
4389+
)
4390+
4391+
res = client.ft().hybrid_search(
4392+
query=hybrid_query, post_processing=posprocessing_config, timeout=10
4393+
)
4394+
4395+
if is_resp2_connection(client):
4396+
assert len(res.results) > 0
4397+
assert res.warnings == []
4398+
for item in res.results:
4399+
assert item["combined_score"] is not None
4400+
assert "__score" not in item
4401+
4402+
else:
4403+
assert len(res["results"]) > 0
4404+
assert res["warnings"] == []
4405+
for item in res["results"]:
4406+
assert item["combined_score"] is not None
4407+
assert "__score" not in item
4408+
4409+
@pytest.mark.redismod
4410+
@skip_if_server_version_lt("8.3.224")
4411+
def test_hybrid_search_query_with_combine_all_score_aliases(self, client):
4412+
# Create index and add data
4413+
self._create_hybrid_search_index(client)
4414+
self._add_data_for_hybrid_search(client, items_sets=1, use_random_str_data=True)
4415+
4416+
search_query = HybridSearchQuery("shoes")
4417+
search_query.yield_score_as("search_score")
4418+
4419+
vsim_query = HybridVsimQuery(
4420+
vector_field_name="@embedding-hnsw",
4421+
vector_data="abcd1234efgh5678",
4422+
vsim_search_method="KNN",
4423+
vsim_search_method_params={
4424+
"K": 3,
4425+
"EF_RUNTIME": 1,
4426+
"YIELD_SCORE_AS": "vsim_score",
4427+
},
4428+
)
4429+
4430+
hybrid_query = HybridQuery(search_query, vsim_query)
4431+
4432+
posprocessing_config = HybridPostProcessingConfig()
4433+
posprocessing_config.combine(
4434+
"LINEAR", ALPHA=0.5, BETA=0.5, YIELD_SCORE_AS="combined_score"
4435+
)
4436+
4437+
res = client.ft().hybrid_search(
4438+
query=hybrid_query, post_processing=posprocessing_config, timeout=10
4439+
)
4440+
4441+
if is_resp2_connection(client):
4442+
assert len(res.results) > 0
4443+
assert res.warnings == []
4444+
for item in res.results:
4445+
assert item["combined_score"] is not None
4446+
assert "__score" not in item
4447+
if item["__key"] in [b"item:0", b"item:1", b"item:4"]:
4448+
assert item["search_score"] is not None
4449+
else:
4450+
assert "search_score" not in item
4451+
if item["__key"] in [b"item:0", b"item:1", b"item:2"]:
4452+
assert item["vsim_score"] is not None
4453+
else:
4454+
assert "vsim_score" not in item
4455+
4456+
else:
4457+
assert len(res["results"]) > 0
4458+
assert res["warnings"] == []
4459+
for item in res["results"]:
4460+
assert item["combined_score"] is not None
4461+
assert "__score" not in item
4462+
if item["__key"] in ["item:0", "item:1", "item:4"]:
4463+
assert item["search_score"] is not None
4464+
else:
4465+
assert "search_score" not in item
4466+
if item["__key"] in ["item:0", "item:1", "item:2"]:
4467+
assert item["vsim_score"] is not None
4468+
else:
4469+
assert "vsim_score" not in item
4470+
42844471
@pytest.mark.redismod
42854472
@skip_if_server_version_lt("8.3.224")
42864473
def test_hybrid_search_query_with_vsim_knn(self, client):

0 commit comments

Comments
 (0)