@@ -1061,9 +1061,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa
10611061 docs = [
10621062 (
10631063 Document (
1064- id = str (result .EmbeddingStore . id ),
1065- page_content = result .EmbeddingStore . document ,
1066- metadata = result .EmbeddingStore . cmetadata ,
1064+ id = str (result .id ),
1065+ page_content = result .document ,
1066+ metadata = result .cmetadata ,
10671067 ),
10681068 result .distance if self .embeddings is not None else None ,
10691069 )
@@ -1396,8 +1396,16 @@ def __query_collection(
13961396 embedding : List [float ],
13971397 k : int = 4 ,
13981398 filter : Optional [Dict [str , str ]] = None ,
1399+ retrieve_embeddings : bool = False ,
13991400 ) -> Sequence [Any ]:
14001401 """Query the collection."""
1402+ columns_to_select = [
1403+ self .EmbeddingStore .id ,
1404+ self .EmbeddingStore .document ,
1405+ self .EmbeddingStore .cmetadata ,
1406+ ]
1407+ if retrieve_embeddings :
1408+ columns_to_select .append (self .EmbeddingStore .embedding )
14011409 with self ._make_sync_session () as session : # type: ignore[arg-type]
14021410 collection = self .get_collection (session )
14031411 if not collection :
@@ -1418,7 +1426,7 @@ def __query_collection(
14181426
14191427 results : List [Any ] = (
14201428 session .query (
1421- self . EmbeddingStore ,
1429+ * columns_to_select ,
14221430 self .distance_strategy (embedding ).label ("distance" ),
14231431 )
14241432 .filter (* filter_by )
@@ -1439,8 +1447,16 @@ async def __aquery_collection(
14391447 embedding : List [float ],
14401448 k : int = 4 ,
14411449 filter : Optional [Dict [str , str ]] = None ,
1450+ retrieve_embeddings : bool = False ,
14421451 ) -> Sequence [Any ]:
14431452 """Query the collection."""
1453+ columns_to_select = [
1454+ self .EmbeddingStore .id ,
1455+ self .EmbeddingStore .document ,
1456+ self .EmbeddingStore .cmetadata ,
1457+ ]
1458+ if retrieve_embeddings :
1459+ columns_to_select .append (self .EmbeddingStore .embedding )
14441460 async with self ._make_async_session () as session : # type: ignore[arg-type]
14451461 collection = await self .aget_collection (session )
14461462 if not collection :
@@ -1900,9 +1916,11 @@ def max_marginal_relevance_search_with_score_by_vector(
19001916 relevance to the query and score for each.
19011917 """
19021918 assert not self ._async_engine , "This method must be called without async_mode"
1903- results = self .__query_collection (embedding = embedding , k = fetch_k , filter = filter )
1919+ results = self .__query_collection (
1920+ embedding = embedding , k = fetch_k , filter = filter , retrieve_embeddings = True
1921+ )
19041922
1905- embedding_list = [result .EmbeddingStore . embedding for result in results ]
1923+ embedding_list = [result .embedding for result in results ]
19061924
19071925 mmr_selected = maximal_marginal_relevance (
19081926 np .array (embedding , dtype = np .float32 ),
@@ -1948,10 +1966,14 @@ async def amax_marginal_relevance_search_with_score_by_vector(
19481966 await self .__apost_init__ () # Lazy async init
19491967 async with self ._make_async_session () as session :
19501968 results = await self .__aquery_collection (
1951- session = session , embedding = embedding , k = fetch_k , filter = filter
1969+ session = session ,
1970+ embedding = embedding ,
1971+ k = fetch_k ,
1972+ filter = filter ,
1973+ retrieve_embeddings = True ,
19521974 )
19531975
1954- embedding_list = [result .EmbeddingStore . embedding for result in results ]
1976+ embedding_list = [result .embedding for result in results ]
19551977
19561978 mmr_selected = maximal_marginal_relevance (
19571979 np .array (embedding , dtype = np .float32 ),
0 commit comments