@@ -1060,9 +1060,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa
10601060 docs = [
10611061 (
10621062 Document (
1063- id = str (result .EmbeddingStore . id ),
1064- page_content = result .EmbeddingStore . document ,
1065- metadata = result .EmbeddingStore . cmetadata ,
1063+ id = str (result .id ),
1064+ page_content = result .document ,
1065+ metadata = result .cmetadata ,
10661066 ),
10671067 result .distance if self .embeddings is not None else None ,
10681068 )
@@ -1395,8 +1395,16 @@ def __query_collection(
13951395 embedding : List [float ],
13961396 k : int = 4 ,
13971397 filter : Optional [Dict [str , str ]] = None ,
1398+ retrieve_embeddings : bool = False ,
13981399 ) -> Sequence [Any ]:
13991400 """Query the collection."""
1401+ columns_to_select = [
1402+ self .EmbeddingStore .id ,
1403+ self .EmbeddingStore .document ,
1404+ self .EmbeddingStore .cmetadata ,
1405+ ]
1406+ if retrieve_embeddings :
1407+ columns_to_select .append (self .EmbeddingStore .embedding )
14001408 with self ._make_sync_session () as session : # type: ignore[arg-type]
14011409 collection = self .get_collection (session )
14021410 if not collection :
@@ -1417,7 +1425,7 @@ def __query_collection(
14171425
14181426 results : List [Any ] = (
14191427 session .query (
1420- self . EmbeddingStore ,
1428+ * columns_to_select ,
14211429 self .distance_strategy (embedding ).label ("distance" ),
14221430 )
14231431 .filter (* filter_by )
@@ -1438,8 +1446,16 @@ async def __aquery_collection(
14381446 embedding : List [float ],
14391447 k : int = 4 ,
14401448 filter : Optional [Dict [str , str ]] = None ,
1449+ retrieve_embeddings : bool = False ,
14411450 ) -> Sequence [Any ]:
14421451 """Query the collection."""
1452+ columns_to_select = [
1453+ self .EmbeddingStore .id ,
1454+ self .EmbeddingStore .document ,
1455+ self .EmbeddingStore .cmetadata ,
1456+ ]
1457+ if retrieve_embeddings :
1458+ columns_to_select .append (self .EmbeddingStore .embedding )
14431459 async with self ._make_async_session () as session : # type: ignore[arg-type]
14441460 collection = await self .aget_collection (session )
14451461 if not collection :
@@ -1460,7 +1476,7 @@ async def __aquery_collection(
14601476
14611477 stmt = (
14621478 select (
1463- self . EmbeddingStore ,
1479+ * columns_to_select ,
14641480 self .distance_strategy (embedding ).label ("distance" ),
14651481 )
14661482 .filter (* filter_by )
@@ -1899,9 +1915,11 @@ def max_marginal_relevance_search_with_score_by_vector(
18991915 relevance to the query and score for each.
19001916 """
19011917 assert not self ._async_engine , "This method must be called without async_mode"
1902- results = self .__query_collection (embedding = embedding , k = fetch_k , filter = filter )
1918+ results = self .__query_collection (
1919+ embedding = embedding , k = fetch_k , filter = filter , retrieve_embeddings = True
1920+ )
19031921
1904- embedding_list = [result .EmbeddingStore . embedding for result in results ]
1922+ embedding_list = [result .embedding for result in results ]
19051923
19061924 mmr_selected = maximal_marginal_relevance (
19071925 np .array (embedding , dtype = np .float32 ),
@@ -1947,10 +1965,14 @@ async def amax_marginal_relevance_search_with_score_by_vector(
19471965 await self .__apost_init__ () # Lazy async init
19481966 async with self ._make_async_session () as session :
19491967 results = await self .__aquery_collection (
1950- session = session , embedding = embedding , k = fetch_k , filter = filter
1968+ session = session ,
1969+ embedding = embedding ,
1970+ k = fetch_k ,
1971+ filter = filter ,
1972+ retrieve_embeddings = True ,
19511973 )
19521974
1953- embedding_list = [result .EmbeddingStore . embedding for result in results ]
1975+ embedding_list = [result .embedding for result in results ]
19541976
19551977 mmr_selected = maximal_marginal_relevance (
19561978 np .array (embedding , dtype = np .float32 ),
0 commit comments