@@ -317,6 +317,80 @@ def test_hybrid_query_with_text_filter(index):
317317        assert  "research"  not  in result [text_field ].lower ()
318318
319319
@pytest.mark.parametrize("scorer", ["BM25", "BM25STD", "TFIDF", "TFIDF.DOCNORM"])
def test_hybrid_query_word_weights(index, scorer):
    """Exercise per-word text weights on ``HybridQuery`` for each text scorer.

    The test checks four things, in order:

    1. a query built with ``text_weights`` runs and returns 7 rows;
    2. for rows whose description mentions a weighted word, the weighted
       ``text_score`` is strictly greater than the unweighted one;
    3. results for a second set of weights differ from the unweighted results;
    4. calling ``set_text_weights`` after construction yields the same results
       as passing ``text_weights`` to the constructor.

    Args:
        index: object whose ``query`` method runs the query (presumably a
            pytest fixture backed by a populated search index; confirm
            against conftest).
        scorer: name of the text scorer forwarded as ``text_scorer``.
    """
    query_text = "a medical professional with expertise in lung cancers"
    description_field = "description"
    embedding = [0.1, 0.1, 0.5]
    embedding_field = "user_embedding"
    wanted_fields = ["description"]

    def make_query(text_weights):
        # Every query in this test is identical apart from its text weights.
        return HybridQuery(
            text=query_text,
            text_field_name=description_field,
            vector=embedding,
            vector_field_name=embedding_field,
            return_fields=wanted_fields,
            text_scorer=scorer,
            text_weights=text_weights,
        )

    initial_weights = {"medical": 3.4, "cancers": 5}

    # A query that carries text weights must run and return the expected rows.
    weighted_results = index.query(make_query(initial_weights))
    assert len(weighted_results) == 7

    # Baseline: the same query with an empty weights mapping.
    unweighted_results = index.query(make_query({}))

    # Rows are paired by position; the score comparison only applies when
    # either row of the pair mentions one of the weighted words.
    for boosted, baseline in zip(weighted_results, unweighted_results):
        mentions_weighted_word = any(
            word in boosted["description"] or word in baseline["description"]
            for word in initial_weights
        )
        if mentions_weighted_word:
            assert float(boosted["text_score"]) > float(baseline["text_score"])

    # Swapping the two weights must still give results that differ from the
    # unweighted baseline.
    swapped_weights = {"medical": 5, "cancers": 3.4}
    weighted_results = index.query(make_query(swapped_weights))
    assert weighted_results != unweighted_results

    # Setting weights after construction must match setting them up front.
    deferred_query = make_query(None)
    deferred_query.set_text_weights(swapped_weights)

    deferred_results = index.query(deferred_query)
    assert deferred_results == weighted_results
392+ 
393+ 
320394def  test_multivector_query (index ):
321395    skip_if_redis_version_below (index .client , "7.2.0" )
322396
0 commit comments