Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,9 @@ public interface FusionStrategy {
@SuppressWarnings("unchecked")
List<Document> fuseResults(int topK, List<Document>... resultLists);

@SuppressWarnings("unchecked")
default List<Document> fuseResultsWithWeights(int topK, List<Double> weights, List<Document>... resultLists) {
return fuseResults(topK, resultLists);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,148 @@

import com.alibaba.cloud.ai.dataagent.service.hybrid.fusion.FusionStrategy;
import org.springframework.ai.document.Document;
import org.springframework.util.StringUtils;

import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class WeightedAverageStrategy implements FusionStrategy {

@SuppressWarnings("unchecked")
@Override
public List<Document> fuseResults(int topK, List<Document>... resultLists) {
throw new UnsupportedOperationException("Not implemented");
return fuseResultsWithWeights(topK, List.of(), resultLists);
}

@SuppressWarnings("unchecked")
@Override
public List<Document> fuseResultsWithWeights(int topK, List<Double> weights, List<Document>... resultLists) {
if (topK <= 0 || resultLists == null || resultLists.length == 0) {
return List.of();
}

Map<String, ScoredDocument> fusedDocuments = new LinkedHashMap<>();
int sequence = 0;

for (int listIndex = 0; listIndex < resultLists.length; listIndex++) {
List<Document> resultList = resultLists[listIndex];
if (resultList == null || resultList.isEmpty()) {
continue;
}

double listWeight = resolveWeight(weights, listIndex);
if (listWeight == 0.0) {
continue;
}

List<Double> normalizedScores = normalizeScores(resultList);
for (int documentIndex = 0; documentIndex < resultList.size(); documentIndex++) {
Document document = resultList.get(documentIndex);
if (document == null) {
continue;
}

String documentId = getDocumentId(document);
ScoredDocument scoredDocument = fusedDocuments.get(documentId);
if (scoredDocument == null) {
scoredDocument = new ScoredDocument(document, sequence);
fusedDocuments.put(documentId, scoredDocument);
}
scoredDocument.addScore(normalizedScores.get(documentIndex) * listWeight);
sequence++;
}
}

return fusedDocuments.values()
.stream()
.sorted(Comparator.comparingDouble(ScoredDocument::getScore)
.reversed()
.thenComparingInt(ScoredDocument::getFirstSeen))
.limit(topK)
.map(ScoredDocument::getDocument)
.collect(Collectors.toList());
}

private double resolveWeight(List<Double> weights, int index) {
if (weights == null || index >= weights.size()) {
return 1.0;
}
Double weight = weights.get(index);
if (weight == null || !Double.isFinite(weight)) {
return 1.0;
}
return Math.max(0.0, weight);
}

private List<Double> normalizeScores(List<Document> resultList) {
List<Double> rawScores = IntStream.range(0, resultList.size())
.mapToObj(index -> resolveScore(resultList.get(index), index, resultList.size()))
.collect(Collectors.toList());

double minScore = rawScores.stream().mapToDouble(Double::doubleValue).min().orElse(0.0);
double maxScore = rawScores.stream().mapToDouble(Double::doubleValue).max().orElse(0.0);
if (Double.compare(maxScore, minScore) == 0) {
return rawScores.stream().map(this::clampToUnitInterval).collect(Collectors.toList());
}

return rawScores.stream().map(score -> (score - minScore) / (maxScore - minScore)).collect(Collectors.toList());
}

private double resolveScore(Document document, int index, int resultListSize) {
if (document != null && document.getScore() != null && Double.isFinite(document.getScore())) {
return document.getScore();
}
return (resultListSize - index) / (double) resultListSize;
}

private double clampToUnitInterval(double score) {
return Math.max(0.0, Math.min(1.0, score));
}

private String getDocumentId(Document document) {
if (StringUtils.hasText(document.getId())) {
return document.getId();
}
if (StringUtils.hasText(document.getText())) {
return String.valueOf(document.getText().hashCode());
}
return String.valueOf(Objects.hashCode(document.getMetadata()));
}

private static final class ScoredDocument {

private final Document document;

private final int firstSeen;

private double score;

private ScoredDocument(Document document, int firstSeen) {
this.document = document;
this.firstSeen = firstSeen;
}

private void addScore(double score) {
this.score += score;
}

private Document getDocument() {
return this.document;
}

private int getFirstSeen() {
return this.firstSeen;
}

private double getScore() {
return this.score;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -77,7 +78,8 @@ public List<Document> retrieve(HybridSearchRequest request) {
List<Document> keywordResults = keywordSearchFuture.get();

// 融合结果
List<Document> finalDocuments = fusionStrategy.fuseResults(request.getTopK(), vectorResults,
List<Document> finalDocuments = fusionStrategy.fuseResultsWithWeights(request.getTopK(),
Arrays.asList(request.getVectorWeight(), request.getKeywordWeight()), vectorResults,
keywordResults);
log.debug("Fusion completed. Found {} documents", finalDocuments.size());
return finalDocuments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,94 @@
import org.junit.jupiter.api.Test;
import org.springframework.ai.document.Document;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

class WeightedAverageStrategyTest {

@Test
void testFuseResults_throwsUnsupportedOperationException() {
void fuseResults_withNullInput_returnsEmptyList() {
WeightedAverageStrategy strategy = new WeightedAverageStrategy();
List<Document> list1 = List.of(new Document("doc1"));
List<Document> list2 = List.of(new Document("doc2"));

assertThrows(UnsupportedOperationException.class, () -> strategy.fuseResults(5, list1, list2));
List<Document> result = strategy.fuseResults(10, (List<Document>[]) null);

assertTrue(result.isEmpty());
}

@Test
void fuseResults_withTopKLessThanOne_returnsEmptyList() {
WeightedAverageStrategy strategy = new WeightedAverageStrategy();
Document doc = new Document("id1", "content1", Collections.emptyMap());

List<Document> result = strategy.fuseResults(0, List.of(doc));

assertTrue(result.isEmpty());
}

@Test
void fuseResults_withSingleList_keepsRankOrder() {
WeightedAverageStrategy strategy = new WeightedAverageStrategy();
Document doc1 = new Document("id1", "content1", Collections.emptyMap());
Document doc2 = new Document("id2", "content2", Collections.emptyMap());

List<Document> result = strategy.fuseResults(10, Arrays.asList(doc1, doc2));

assertEquals(2, result.size());
assertEquals("id1", result.get(0).getId());
assertEquals("id2", result.get(1).getId());
}

@Test
void fuseResults_withTopKLimit_returnsLimitedResults() {
WeightedAverageStrategy strategy = new WeightedAverageStrategy();
Document doc1 = new Document("id1", "content1", Collections.emptyMap());
Document doc2 = new Document("id2", "content2", Collections.emptyMap());
Document doc3 = new Document("id3", "content3", Collections.emptyMap());

List<Document> result = strategy.fuseResults(2, Arrays.asList(doc1, doc2, doc3));

assertEquals(2, result.size());
}

@Test
void fuseResults_withDuplicatesAcrossLists_mergesScores() {
WeightedAverageStrategy strategy = new WeightedAverageStrategy();
Document docA = new Document("a", "content a", Collections.emptyMap());
Document docB = new Document("b", "content b", Collections.emptyMap());
Document docC = new Document("c", "content c", Collections.emptyMap());
Document docACopy = new Document("a", "content a copy", Collections.emptyMap());

List<Document> result = strategy.fuseResults(10, Arrays.asList(docB, docA, docC), List.of(docACopy));

assertEquals(3, result.size());
assertEquals("a", result.get(0).getId());
}

@Test
void fuseResults_withWeights_prefersHigherWeightedSource() {
WeightedAverageStrategy strategy = new WeightedAverageStrategy();
Document vectorDoc = new Document("vector", "vector content", Collections.emptyMap());
Document keywordDoc = new Document("keyword", "keyword content", Collections.emptyMap());

List<Document> result = strategy.fuseResultsWithWeights(10, Arrays.asList(0.2, 0.8), List.of(vectorDoc),
List.of(keywordDoc));

assertEquals(2, result.size());
assertEquals("keyword", result.get(0).getId());
}

@Test
void fuseResults_withNullListAndNullDocument_skipsInvalidEntries() {
WeightedAverageStrategy strategy = new WeightedAverageStrategy();
Document doc = new Document("id1", "content1", Collections.emptyMap());

List<Document> result = strategy.fuseResults(10, Arrays.asList(doc, null), null);

assertEquals(1, result.size());
assertEquals("id1", result.get(0).getId());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -60,7 +62,7 @@ void testRetrieve_combinesVectorAndKeywordResults() {
Document fusedDoc = new Document("fused result");

when(vectorStore.similaritySearch(any(SearchRequest.class))).thenReturn(List.of(vectorDoc));
when(fusionStrategy.fuseResults(anyInt(), any(), any())).thenReturn(List.of(fusedDoc));
when(fusionStrategy.fuseResultsWithWeights(anyInt(), anyList(), any(), any())).thenReturn(List.of(fusedDoc));

strategy.setKeywordResults(List.of(keywordDoc));

Expand All @@ -75,7 +77,7 @@ void testRetrieve_combinesVectorAndKeywordResults() {
@Test
void testRetrieve_emptyResults() {
when(vectorStore.similaritySearch(any(SearchRequest.class))).thenReturn(List.of());
when(fusionStrategy.fuseResults(anyInt(), any(), any())).thenReturn(List.of());
when(fusionStrategy.fuseResultsWithWeights(anyInt(), anyList(), any(), any())).thenReturn(List.of());

strategy.setKeywordResults(List.of());

Expand All @@ -86,6 +88,28 @@ void testRetrieve_emptyResults() {
assertTrue(results.isEmpty());
}

@Test
void testRetrieve_passesConfiguredFusionWeights() {
Document vectorDoc = new Document("vector result");
Document keywordDoc = new Document("keyword result");

when(vectorStore.similaritySearch(any(SearchRequest.class))).thenReturn(List.of(vectorDoc));
when(fusionStrategy.fuseResultsWithWeights(anyInt(), anyList(), any(), any())).thenReturn(List.of(vectorDoc));

strategy.setKeywordResults(List.of(keywordDoc));

HybridSearchRequest request = HybridSearchRequest.builder()
.query("test")
.topK(5)
.vectorWeight(0.7)
.keywordWeight(0.3)
.build();

strategy.retrieve(request);

verify(fusionStrategy).fuseResultsWithWeights(eq(5), eq(Arrays.asList(0.7, 0.3)), any(), any());
}

static class TestHybridRetrievalStrategy extends AbstractHybridRetrievalStrategy {

private List<Document> keywordResults = List.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -71,14 +72,14 @@ void retrieve_returnsVectorDocsOnly() {
HybridSearchRequest request = HybridSearchRequest.builder().query("test query").topK(5).build();

when(vectorStore.similaritySearch(any(SearchRequest.class))).thenReturn(vectorResults);
when(fusionStrategy.fuseResults(eq(5), any(), any())).thenReturn(vectorResults);
when(fusionStrategy.fuseResultsWithWeights(eq(5), anyList(), any(), any())).thenReturn(vectorResults);

List<Document> result = strategy.retrieve(request);

assertEquals(1, result.size());
assertEquals("v1", result.get(0).getId());
verify(vectorStore).similaritySearch(any(SearchRequest.class));
verify(fusionStrategy).fuseResults(eq(5), any(), any());
verify(fusionStrategy).fuseResultsWithWeights(eq(5), eq(Arrays.asList(0.5, 0.5)), any(), any());
}

}
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

<!-- Maven Compiler Plugin -->
<maven-compiler-plugin.version>3.11.0</maven-compiler-plugin.version>
<maven-surefire-plugin.version>3.2.5</maven-surefire-plugin.version>
<!-- spotless version-->
<spotless-maven-plugin.version>3.1.0</spotless-maven-plugin.version>
<docker-java.version>3.5.3</docker-java.version>
Expand Down Expand Up @@ -373,6 +374,11 @@
</compilerArgs>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>${maven-surefire-plugin.version}</version>
</plugin>
</plugins>
</build>

Expand Down
Loading