diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/FusionStrategy.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/FusionStrategy.java index c4cfe53ae..db32c5c52 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/FusionStrategy.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/FusionStrategy.java @@ -25,4 +25,9 @@ public interface FusionStrategy { @SuppressWarnings("unchecked") List fuseResults(int topK, List... resultLists); + @SuppressWarnings("unchecked") + default List fuseResultsWithWeights(int topK, List weights, List... resultLists) { + return fuseResults(topK, resultLists); + } + } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/impl/WeightedAverageStrategy.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/impl/WeightedAverageStrategy.java index b6f6248fb..deedf5109 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/impl/WeightedAverageStrategy.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/impl/WeightedAverageStrategy.java @@ -17,15 +17,148 @@ import com.alibaba.cloud.ai.dataagent.service.hybrid.fusion.FusionStrategy; import org.springframework.ai.document.Document; +import org.springframework.util.StringUtils; +import java.util.Comparator; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.IntStream; public class WeightedAverageStrategy implements FusionStrategy { @SuppressWarnings("unchecked") @Override public List fuseResults(int topK, List... resultLists) { - throw new UnsupportedOperationException("Not implemented"); + return fuseResultsWithWeights(topK, List.of(), resultLists); + } + + @SuppressWarnings("unchecked") + @Override + public List fuseResultsWithWeights(int topK, List weights, List... resultLists) { + if (topK <= 0 || resultLists == null || resultLists.length == 0) { + return List.of(); + } + + Map fusedDocuments = new LinkedHashMap<>(); + int sequence = 0; + + for (int listIndex = 0; listIndex < resultLists.length; listIndex++) { + List resultList = resultLists[listIndex]; + if (resultList == null || resultList.isEmpty()) { + continue; + } + + double listWeight = resolveWeight(weights, listIndex); + if (listWeight == 0.0) { + continue; + } + + List normalizedScores = normalizeScores(resultList); + for (int documentIndex = 0; documentIndex < resultList.size(); documentIndex++) { + Document document = resultList.get(documentIndex); + if (document == null) { + continue; + } + + String documentId = getDocumentId(document); + ScoredDocument scoredDocument = fusedDocuments.get(documentId); + if (scoredDocument == null) { + scoredDocument = new ScoredDocument(document, sequence); + fusedDocuments.put(documentId, scoredDocument); + } + scoredDocument.addScore(normalizedScores.get(documentIndex) * listWeight); + sequence++; + } + } + + return fusedDocuments.values() + .stream() + .sorted(Comparator.comparingDouble(ScoredDocument::getScore) + .reversed() + .thenComparingInt(ScoredDocument::getFirstSeen)) + .limit(topK) + .map(ScoredDocument::getDocument) + .collect(Collectors.toList()); + } + + private double resolveWeight(List weights, int index) { + if (weights == null || index >= weights.size()) { + return 1.0; + } + Double weight = weights.get(index); + if (weight == null || !Double.isFinite(weight)) { + return 1.0; + } + return Math.max(0.0, weight); + } + + private List normalizeScores(List resultList) { + List rawScores = IntStream.range(0, resultList.size()) + .mapToObj(index -> resolveScore(resultList.get(index), index, resultList.size())) + .collect(Collectors.toList()); + + double minScore = rawScores.stream().mapToDouble(Double::doubleValue).min().orElse(0.0); + double maxScore = rawScores.stream().mapToDouble(Double::doubleValue).max().orElse(0.0); + if (Double.compare(maxScore, minScore) == 0) { + return rawScores.stream().map(this::clampToUnitInterval).collect(Collectors.toList()); + } + + return rawScores.stream().map(score -> (score - minScore) / (maxScore - minScore)).collect(Collectors.toList()); + } + + private double resolveScore(Document document, int index, int resultListSize) { + if (document != null && document.getScore() != null && Double.isFinite(document.getScore())) { + return document.getScore(); + } + return (resultListSize - index) / (double) resultListSize; + } + + private double clampToUnitInterval(double score) { + return Math.max(0.0, Math.min(1.0, score)); + } + + private String getDocumentId(Document document) { + if (StringUtils.hasText(document.getId())) { + return document.getId(); + } + if (StringUtils.hasText(document.getText())) { + return String.valueOf(document.getText().hashCode()); + } + return String.valueOf(Objects.hashCode(document.getMetadata())); + } + + private static final class ScoredDocument { + + private final Document document; + + private final int firstSeen; + + private double score; + + private ScoredDocument(Document document, int firstSeen) { + this.document = document; + this.firstSeen = firstSeen; + } + + private void addScore(double score) { + this.score += score; + } + + private Document getDocument() { + return this.document; + } + + private int getFirstSeen() { + return this.firstSeen; + } + + private double getScore() { + return this.score; + } + } } diff --git a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/AbstractHybridRetrievalStrategy.java b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/AbstractHybridRetrievalStrategy.java index c6c346dce..91fffc782 100644 --- a/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/AbstractHybridRetrievalStrategy.java +++ b/data-agent-management/src/main/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/AbstractHybridRetrievalStrategy.java @@ -22,6 +22,7 @@ import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; +import java.util.Arrays; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -77,7 +78,8 @@ public List retrieve(HybridSearchRequest request) { List keywordResults = keywordSearchFuture.get(); // 融合结果 - List finalDocuments = fusionStrategy.fuseResults(request.getTopK(), vectorResults, + List finalDocuments = fusionStrategy.fuseResultsWithWeights(request.getTopK(), + Arrays.asList(request.getVectorWeight(), request.getKeywordWeight()), vectorResults, keywordResults); log.debug("Fusion completed. Found {} documents", finalDocuments.size()); return finalDocuments; diff --git a/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/impl/WeightedAverageStrategyTest.java b/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/impl/WeightedAverageStrategyTest.java index 3fb1beffd..e597a07bc 100644 --- a/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/impl/WeightedAverageStrategyTest.java +++ b/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/fusion/impl/WeightedAverageStrategyTest.java @@ -18,6 +18,8 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.document.Document; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import static org.junit.jupiter.api.Assertions.*; @@ -25,12 +27,85 @@ class WeightedAverageStrategyTest { @Test - void testFuseResults_throwsUnsupportedOperationException() { + void fuseResults_withNullInput_returnsEmptyList() { WeightedAverageStrategy strategy = new WeightedAverageStrategy(); - List list1 = List.of(new Document("doc1")); - List list2 = List.of(new Document("doc2")); - assertThrows(UnsupportedOperationException.class, () -> strategy.fuseResults(5, list1, list2)); + List result = strategy.fuseResults(10, (List[]) null); + + assertTrue(result.isEmpty()); + } + + @Test + void fuseResults_withTopKLessThanOne_returnsEmptyList() { + WeightedAverageStrategy strategy = new WeightedAverageStrategy(); + Document doc = new Document("id1", "content1", Collections.emptyMap()); + + List result = strategy.fuseResults(0, List.of(doc)); + + assertTrue(result.isEmpty()); + } + + @Test + void fuseResults_withSingleList_keepsRankOrder() { + WeightedAverageStrategy strategy = new WeightedAverageStrategy(); + Document doc1 = new Document("id1", "content1", Collections.emptyMap()); + Document doc2 = new Document("id2", "content2", Collections.emptyMap()); + + List result = strategy.fuseResults(10, Arrays.asList(doc1, doc2)); + + assertEquals(2, result.size()); + assertEquals("id1", result.get(0).getId()); + assertEquals("id2", result.get(1).getId()); + } + + @Test + void fuseResults_withTopKLimit_returnsLimitedResults() { + WeightedAverageStrategy strategy = new WeightedAverageStrategy(); + Document doc1 = new Document("id1", "content1", Collections.emptyMap()); + Document doc2 = new Document("id2", "content2", Collections.emptyMap()); + Document doc3 = new Document("id3", "content3", Collections.emptyMap()); + + List result = strategy.fuseResults(2, Arrays.asList(doc1, doc2, doc3)); + + assertEquals(2, result.size()); + } + + @Test + void fuseResults_withDuplicatesAcrossLists_mergesScores() { + WeightedAverageStrategy strategy = new WeightedAverageStrategy(); + Document docA = new Document("a", "content a", Collections.emptyMap()); + Document docB = new Document("b", "content b", Collections.emptyMap()); + Document docC = new Document("c", "content c", Collections.emptyMap()); + Document docACopy = new Document("a", "content a copy", Collections.emptyMap()); + + List result = strategy.fuseResults(10, Arrays.asList(docB, docA, docC), List.of(docACopy)); + + assertEquals(3, result.size()); + assertEquals("a", result.get(0).getId()); + } + + @Test + void fuseResults_withWeights_prefersHigherWeightedSource() { + WeightedAverageStrategy strategy = new WeightedAverageStrategy(); + Document vectorDoc = new Document("vector", "vector content", Collections.emptyMap()); + Document keywordDoc = new Document("keyword", "keyword content", Collections.emptyMap()); + + List result = strategy.fuseResultsWithWeights(10, Arrays.asList(0.2, 0.8), List.of(vectorDoc), + List.of(keywordDoc)); + + assertEquals(2, result.size()); + assertEquals("keyword", result.get(0).getId()); + } + + @Test + void fuseResults_withNullListAndNullDocument_skipsInvalidEntries() { + WeightedAverageStrategy strategy = new WeightedAverageStrategy(); + Document doc = new Document("id1", "content1", Collections.emptyMap()); + + List result = strategy.fuseResults(10, Arrays.asList(doc, null), null); + + assertEquals(1, result.size()); + assertEquals("id1", result.get(0).getId()); } } diff --git a/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/AbstractHybridRetrievalStrategyTest.java b/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/AbstractHybridRetrievalStrategyTest.java index e7943a941..75272405d 100644 --- a/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/AbstractHybridRetrievalStrategyTest.java +++ b/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/AbstractHybridRetrievalStrategyTest.java @@ -26,12 +26,14 @@ import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; +import java.util.Arrays; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -60,7 +62,7 @@ void testRetrieve_combinesVectorAndKeywordResults() { Document fusedDoc = new Document("fused result"); when(vectorStore.similaritySearch(any(SearchRequest.class))).thenReturn(List.of(vectorDoc)); - when(fusionStrategy.fuseResults(anyInt(), any(), any())).thenReturn(List.of(fusedDoc)); + when(fusionStrategy.fuseResultsWithWeights(anyInt(), anyList(), any(), any())).thenReturn(List.of(fusedDoc)); strategy.setKeywordResults(List.of(keywordDoc)); @@ -75,7 +77,7 @@ void testRetrieve_combinesVectorAndKeywordResults() { @Test void testRetrieve_emptyResults() { when(vectorStore.similaritySearch(any(SearchRequest.class))).thenReturn(List.of()); - when(fusionStrategy.fuseResults(anyInt(), any(), any())).thenReturn(List.of()); + when(fusionStrategy.fuseResultsWithWeights(anyInt(), anyList(), any(), any())).thenReturn(List.of()); strategy.setKeywordResults(List.of()); @@ -86,6 +88,28 @@ void testRetrieve_emptyResults() { assertTrue(results.isEmpty()); } + @Test + void testRetrieve_passesConfiguredFusionWeights() { + Document vectorDoc = new Document("vector result"); + Document keywordDoc = new Document("keyword result"); + + when(vectorStore.similaritySearch(any(SearchRequest.class))).thenReturn(List.of(vectorDoc)); + when(fusionStrategy.fuseResultsWithWeights(anyInt(), anyList(), any(), any())).thenReturn(List.of(vectorDoc)); + + strategy.setKeywordResults(List.of(keywordDoc)); + + HybridSearchRequest request = HybridSearchRequest.builder() + .query("test") + .topK(5) + .vectorWeight(0.7) + .keywordWeight(0.3) + .build(); + + strategy.retrieve(request); + + verify(fusionStrategy).fuseResultsWithWeights(eq(5), eq(Arrays.asList(0.7, 0.3)), any(), any()); + } + static class TestHybridRetrievalStrategy extends AbstractHybridRetrievalStrategy { private List keywordResults = List.of(); diff --git a/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/impl/DefaultHybridRetrievalStrategyTest.java b/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/impl/DefaultHybridRetrievalStrategyTest.java index 443005851..51420e91f 100644 --- a/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/impl/DefaultHybridRetrievalStrategyTest.java +++ b/data-agent-management/src/test/java/com/alibaba/cloud/ai/dataagent/service/hybrid/retrieval/impl/DefaultHybridRetrievalStrategyTest.java @@ -26,6 +26,7 @@ import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.ExecutorService; @@ -71,14 +72,14 @@ void retrieve_returnsVectorDocsOnly() { HybridSearchRequest request = HybridSearchRequest.builder().query("test query").topK(5).build(); when(vectorStore.similaritySearch(any(SearchRequest.class))).thenReturn(vectorResults); - when(fusionStrategy.fuseResults(eq(5), any(), any())).thenReturn(vectorResults); + when(fusionStrategy.fuseResultsWithWeights(eq(5), anyList(), any(), any())).thenReturn(vectorResults); List result = strategy.retrieve(request); assertEquals(1, result.size()); assertEquals("v1", result.get(0).getId()); verify(vectorStore).similaritySearch(any(SearchRequest.class)); - verify(fusionStrategy).fuseResults(eq(5), any(), any()); + verify(fusionStrategy).fuseResultsWithWeights(eq(5), eq(Arrays.asList(0.5, 0.5)), any(), any()); } } diff --git a/pom.xml b/pom.xml index 9f06c6086..d86d3430a 100644 --- a/pom.xml +++ b/pom.xml @@ -47,6 +47,7 @@ 3.11.0 + 3.2.5 3.1.0 3.5.3 @@ -373,6 +374,11 @@ + + org.apache.maven.plugins + maven-surefire-plugin + ${maven-surefire-plugin.version} +