@@ -85,26 +85,42 @@ def __init__(
8585 vectorizer = vectorizer ,
8686 routing_config = routing_config ,
8787 )
88- self ._initialize_index (redis_client , redis_url , overwrite , ** connection_kwargs )
88+ dtype = kwargs .get ("dtype" , "float32" )
89+ self ._initialize_index (
90+ redis_client , redis_url , overwrite , dtype , ** connection_kwargs
91+ )
8992
9093 def _initialize_index (
9194 self ,
9295 redis_client : Optional [Redis ] = None ,
9396 redis_url : str = "redis://localhost:6379" ,
9497 overwrite : bool = False ,
98+ dtype : str = "float32" ,
9599 ** connection_kwargs ,
96100 ):
97101 """Initialize the search index and handle Redis connection."""
98- schema = SemanticRouterIndexSchema .from_params (self .name , self .vectorizer .dims )
102+ schema = SemanticRouterIndexSchema .from_params (
103+ self .name , self .vectorizer .dims , dtype
104+ )
99105 self ._index = SearchIndex (schema = schema )
100106
101107 if redis_client :
102108 self ._index .set_client (redis_client )
103109 elif redis_url :
104110 self ._index .connect (redis_url = redis_url , ** connection_kwargs )
105111
112+ # Check for existing router index
106113 existed = self ._index .exists ()
107- self ._index .create (overwrite = overwrite )
114+ if not overwrite and existed :
115+ existing_index = SearchIndex .from_existing (
116+ self .name , redis_client = self ._index .client
117+ )
118+ if existing_index .schema != self ._index .schema :
119+ raise ValueError (
120+ f"Existing index { self .name } schema does not match the user provided schema for the semantic router. "
121+ "If you wish to overwrite the index schema, set overwrite=True during initialization."
122+ )
123+ self ._index .create (overwrite = overwrite , drop = False )
108124
109125 if not existed or overwrite :
110126 # write the routes to Redis
@@ -153,7 +169,9 @@ def _add_routes(self, routes: List[Route]):
153169 for route in routes :
154170 # embed route references as a single batch
155171 reference_vectors = self .vectorizer .embed_many (
156- [reference for reference in route .references ], as_buffer = True
172+ [reference for reference in route .references ],
173+ as_buffer = True ,
174+ dtype = self ._index .schema .fields [ROUTE_VECTOR_FIELD_NAME ].attrs .datatype , # type: ignore[union-attr]
157175 )
158176 # set route references
159177 for i , reference in enumerate (route .references ):
@@ -230,6 +248,7 @@ def _classify_route(
230248 vector_field_name = ROUTE_VECTOR_FIELD_NAME ,
231249 distance_threshold = distance_threshold ,
232250 return_fields = ["route_name" ],
251+ dtype = self ._index .schema .fields [ROUTE_VECTOR_FIELD_NAME ].attrs .datatype , # type: ignore[union-attr]
233252 )
234253
235254 aggregate_request = self ._build_aggregate_request (
@@ -282,6 +301,7 @@ def _classify_multi_route(
282301 vector_field_name = ROUTE_VECTOR_FIELD_NAME ,
283302 distance_threshold = distance_threshold ,
284303 return_fields = ["route_name" ],
304+ dtype = self ._index .schema .fields [ROUTE_VECTOR_FIELD_NAME ].attrs .datatype , # type: ignore[union-attr]
285305 )
286306 aggregate_request = self ._build_aggregate_request (
287307 vector_range_query , aggregation_method , max_k
0 commit comments