Skip to content

Commit b0790c0

Browse files
committed
WIP
1 parent 95e93c3 commit b0790c0

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

redisvl/extensions/router/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Route(BaseModel):
1616
"""List of reference phrases for the route."""
1717
metadata: Dict[str, str] = Field(default={})
1818
"""Metadata associated with the route."""
19-
distance_threshold: Optional[float] = Field(default=None)
19+
distance_threshold: Optional[float] = Field(default=0.5)
2020
"""Distance threshold for matching the route."""
2121

2222
@validator("name")
@@ -63,7 +63,7 @@ class DistanceAggregationMethod(Enum):
6363
class RoutingConfig(BaseModel):
6464
"""Configuration for routing behavior."""
6565

66-
distance_threshold: float = Field(default=0.5)
66+
# distance_threshold: float = Field(default=0.5)
6767
"""The threshold for semantic distance."""
6868
max_k: int = Field(default=1)
6969
"""The maximum number of top matches to return."""

redisvl/extensions/router/semantic.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,12 +302,17 @@ def _classify_multi_route(
302302
aggregation_method: DistanceAggregationMethod,
303303
) -> List[RouteMatch]:
304304
"""Classify to multiple routes, up to max_k (int), using a vector."""
305+
306+
# Set range query distance threshold to get all possible results to be further filtered
307+
distance_threshold = max(route.distance_threshold for route in self.routes)
308+
305309
vector_range_query = RangeQuery(
306310
vector=vector,
307311
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
308312
distance_threshold=distance_threshold,
309313
return_fields=["route_name"],
310314
)
315+
311316
aggregate_request = self._build_aggregate_request(
312317
vector_range_query, aggregation_method, max_k
313318
)
@@ -368,7 +373,9 @@ def __call__(
368373
self,
369374
statement: Optional[str] = None,
370375
vector: Optional[List[float]] = None,
371-
distance_threshold: Optional[float] = None,
376+
distance_threshold: Optional[
377+
float
378+
] = None, # TODO: does this get removed if route becomes the owner of a distance threshold? Or do we apply to all?
372379
aggregation_method: Optional[DistanceAggregationMethod] = None,
373380
) -> RouteMatch:
374381
"""Query the semantic router with a given statement or vector.

0 commit comments

Comments
 (0)