@@ -302,12 +302,17 @@ def _classify_multi_route(
302302 aggregation_method : DistanceAggregationMethod ,
303303 ) -> List [RouteMatch ]:
304304 """Classify to multiple routes, up to max_k (int), using a vector."""
305+
306+ # Set range query distance threshold to get all possible results to be further filtered
307+ distance_threshold = max (route .distance_threshold for route in self .routes )
308+
305309 vector_range_query = RangeQuery (
306310 vector = vector ,
307311 vector_field_name = ROUTE_VECTOR_FIELD_NAME ,
308312 distance_threshold = distance_threshold ,
309313 return_fields = ["route_name" ],
310314 )
315+
311316 aggregate_request = self ._build_aggregate_request (
312317 vector_range_query , aggregation_method , max_k
313318 )
@@ -368,7 +373,9 @@ def __call__(
368373 self ,
369374 statement : Optional [str ] = None ,
370375 vector : Optional [List [float ]] = None ,
371- distance_threshold : Optional [float ] = None ,
376+ distance_threshold : Optional [
377+ float
378+ ] = None , # TODO: does this get removed if route becomes the owner of a distance threshold? Or do we apply to all?
372379 aggregation_method : Optional [DistanceAggregationMethod ] = None ,
373380 ) -> RouteMatch :
374381 """Query the semantic router with a given statement or vector.
0 commit comments