Skip to content

Commit 400d15f

Browse files
adds integration tests to TextQuery and HybridQuery for text weights
1 parent de56f84 commit 400d15f

File tree

2 files changed

+138
-0
lines changed

2 files changed

+138
-0
lines changed

tests/integration/test_aggregation.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,80 @@ def test_hybrid_query_with_text_filter(index):
317317
assert "research" not in result[text_field].lower()
318318

319319

320+
@pytest.mark.parametrize("scorer", ["BM25", "BM25STD", "TFIDF", "TFIDF.DOCNORM"])
321+
def test_hybrid_query_word_weights(index, scorer):
322+
text = "a medical professional with expertise in lung cancers"
323+
text_field = "description"
324+
vector = [0.1, 0.1, 0.5]
325+
vector_field = "user_embedding"
326+
return_fields = ["description"]
327+
328+
weights = {"medical": 3.4, "cancers": 5}
329+
330+
# test we can run a query with text weights
331+
weighted_query = HybridQuery(
332+
text=text,
333+
text_field_name=text_field,
334+
vector=vector,
335+
vector_field_name=vector_field,
336+
return_fields=return_fields,
337+
text_scorer=scorer,
338+
text_weights=weights,
339+
)
340+
341+
weighted_results = index.query(weighted_query)
342+
assert len(weighted_results) == 7
343+
344+
# test that weights do change the scores on results
345+
unweighted_query = HybridQuery(
346+
text=text,
347+
text_field_name=text_field,
348+
vector=vector,
349+
vector_field_name=vector_field,
350+
return_fields=return_fields,
351+
text_scorer=scorer,
352+
text_weights={},
353+
)
354+
355+
unweighted_results = index.query(unweighted_query)
356+
357+
for weighted, unweighted in zip(weighted_results, unweighted_results):
358+
for word in weights:
359+
if word in weighted["description"] or word in unweighted["description"]:
360+
assert float(weighted["text_score"]) > float(unweighted["text_score"])
361+
362+
# test that weights do change the document score and order of results
363+
weights = {"medical": 5, "cancers": 3.4} # switch the weights
364+
weighted_query = HybridQuery(
365+
text=text,
366+
text_field_name=text_field,
367+
vector=vector,
368+
vector_field_name=vector_field,
369+
return_fields=return_fields,
370+
text_scorer=scorer,
371+
text_weights=weights,
372+
)
373+
374+
weighted_results = index.query(weighted_query)
375+
assert weighted_results != unweighted_results
376+
377+
# test assigning weights on construction is equivalent to setting them on the query object
378+
new_query = HybridQuery(
379+
text=text,
380+
text_field_name=text_field,
381+
vector=vector,
382+
vector_field_name=vector_field,
383+
return_fields=return_fields,
384+
text_scorer=scorer,
385+
text_weights=None,
386+
)
387+
388+
new_query.set_text_weights(weights)
389+
390+
new_weighted_results = index.query(new_query)
391+
assert new_weighted_results == weighted_results
392+
393+
320394
def test_multivector_query(index):
321395
skip_if_redis_version_below(index.client, "7.2.0")
322396

tests/integration/test_query.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,70 @@ def test_text_query_with_text_filter(index):
888888
assert "research" not in result[text_field]
889889

890890

891+
@pytest.mark.parametrize("scorer", ["BM25", "BM25STD", "TFIDF", "TFIDF.DOCNORM"])
892+
def test_text_query_word_weights(index, scorer):
893+
text = "a medical professional with expertise in lung cancers"
894+
text_field = "description"
895+
return_fields = ["description"]
896+
897+
weights = {"medical": 3.4, "cancers": 5}
898+
899+
# test we can run a query with text weights
900+
weighted_query = TextQuery(
901+
text=text,
902+
text_field_name=text_field,
903+
return_fields=return_fields,
904+
text_scorer=scorer,
905+
text_weights=weights,
906+
)
907+
908+
weighted_results = index.query(weighted_query)
909+
assert len(weighted_results) == 4
910+
911+
# test that weights do change the scores on results
912+
unweighted_query = TextQuery(
913+
text=text,
914+
text_field_name=text_field,
915+
return_fields=return_fields,
916+
text_scorer=scorer,
917+
text_weights={},
918+
)
919+
920+
unweighted_results = index.query(unweighted_query)
921+
922+
for weighted, unweighted in zip(weighted_results, unweighted_results):
923+
for word in weights:
924+
if word in weighted["description"] or word in unweighted["description"]:
925+
assert weighted["score"] > unweighted["score"]
926+
927+
# test that weights do change the document score and order of results
928+
weights = {"medical": 5, "cancers": 3.4} # switch the weights
929+
weighted_query = TextQuery(
930+
text=text,
931+
text_field_name=text_field,
932+
return_fields=return_fields,
933+
text_scorer=scorer,
934+
text_weights=weights,
935+
)
936+
937+
weighted_results = index.query(weighted_query)
938+
assert weighted_results != unweighted_results
939+
940+
# test assigning weights on construction is equivalent to setting them on the query object
941+
new_query = TextQuery(
942+
text=text,
943+
text_field_name=text_field,
944+
return_fields=return_fields,
945+
text_scorer=scorer,
946+
text_weights=None,
947+
)
948+
949+
new_query.set_text_weights(weights)
950+
951+
new_weighted_results = index.query(new_query)
952+
assert new_weighted_results == weighted_results
953+
954+
891955
def test_vector_query_with_ef_runtime(index, vector_query, sample_data):
892956
"""
893957
Integration test: Verify that setting EF_RUNTIME on a VectorQuery works correctly.

0 commit comments

Comments
 (0)