diff --git a/config/config.go b/config/config.go index 38da06b..a7c34ba 100644 --- a/config/config.go +++ b/config/config.go @@ -80,6 +80,12 @@ type UpdateConfig struct { type SearchConfig struct { Boost BoostConfig `yaml:"boost"` Hybrid HybridConfig `yaml:"hybrid"` + Dedup DedupConfig `yaml:"dedup"` +} + +// DedupConfig controls file-level deduplication of search results. +type DedupConfig struct { + Enabled bool `yaml:"enabled"` } type HybridConfig struct { @@ -300,6 +306,9 @@ func DefaultConfig() *Config { RPGMaxDirtyFilesPerBatch: DefaultWatchRPGMaxDirtyFilesPerBatch, }, Search: SearchConfig{ + Dedup: DedupConfig{ + Enabled: true, + }, Hybrid: HybridConfig{ Enabled: false, K: 60, diff --git a/search/dedup.go b/search/dedup.go new file mode 100644 index 0000000..a7ed93d --- /dev/null +++ b/search/dedup.go @@ -0,0 +1,17 @@ +package search + +import "github.com/yoanbernabeu/grepai/store" + +// DeduplicateByFile keeps only the highest-scoring chunk per file path. +func DeduplicateByFile(results []store.SearchResult) []store.SearchResult { + seen := make(map[string]bool, len(results)) + deduped := make([]store.SearchResult, 0, len(results)) + for _, r := range results { + if seen[r.Chunk.FilePath] { + continue + } + seen[r.Chunk.FilePath] = true + deduped = append(deduped, r) + } + return deduped +} diff --git a/search/dedup_test.go b/search/dedup_test.go new file mode 100644 index 0000000..76c437d --- /dev/null +++ b/search/dedup_test.go @@ -0,0 +1,62 @@ +package search + +import ( + "testing" + + "github.com/yoanbernabeu/grepai/store" +) + +func TestDeduplicateByFile(t *testing.T) { + results := []store.SearchResult{ + {Chunk: store.Chunk{ID: "a_0", FilePath: "a.go"}, Score: 0.9}, + {Chunk: store.Chunk{ID: "b_0", FilePath: "b.go"}, Score: 0.8}, + {Chunk: store.Chunk{ID: "a_1", FilePath: "a.go"}, Score: 0.7}, + {Chunk: store.Chunk{ID: "c_0", FilePath: "c.go"}, Score: 0.6}, + {Chunk: store.Chunk{ID: "b_1", FilePath: "b.go"}, Score: 0.5}, + } + + deduped := DeduplicateByFile(results) + + if len(deduped) != 3 { + t.Fatalf("expected 3 results, got %d", len(deduped)) + } + + expected := []struct { + id string + score float32 + }{ + {"a_0", 0.9}, + {"b_0", 0.8}, + {"c_0", 0.6}, + } + + for i, want := range expected { + if deduped[i].Chunk.ID != want.id { + t.Errorf("result[%d]: expected ID %q, got %q", i, want.id, deduped[i].Chunk.ID) + } + if deduped[i].Score != want.score { + t.Errorf("result[%d]: expected score %v, got %v", i, want.score, deduped[i].Score) + } + } +} + +func TestDeduplicateByFile_Empty(t *testing.T) { + deduped := DeduplicateByFile(nil) + if len(deduped) != 0 { + t.Fatalf("expected 0 results, got %d", len(deduped)) + } +} + +func TestDeduplicateByFile_AllUnique(t *testing.T) { + results := []store.SearchResult{ + {Chunk: store.Chunk{ID: "a_0", FilePath: "a.go"}, Score: 0.9}, + {Chunk: store.Chunk{ID: "b_0", FilePath: "b.go"}, Score: 0.8}, + {Chunk: store.Chunk{ID: "c_0", FilePath: "c.go"}, Score: 0.7}, + } + + deduped := DeduplicateByFile(results) + + if len(deduped) != 3 { + t.Fatalf("expected 3 results, got %d", len(deduped)) + } +} diff --git a/search/search.go b/search/search.go index 5776e78..7bd5e87 100644 --- a/search/search.go +++ b/search/search.go @@ -13,6 +13,7 @@ type Searcher struct { embedder embedder.Embedder boostCfg config.BoostConfig hybridCfg config.HybridConfig + dedupCfg config.DedupConfig } func NewSearcher(st store.VectorStore, emb embedder.Embedder, searchCfg config.SearchConfig) *Searcher { @@ -21,26 +22,27 @@ func NewSearcher(st store.VectorStore, emb embedder.Embedder, searchCfg config.S embedder: emb, boostCfg: searchCfg.Boost, hybridCfg: searchCfg.Hybrid, + dedupCfg: searchCfg.Dedup, } } func (s *Searcher) Search(ctx context.Context, query string, limit int, pathPrefix string) ([]store.SearchResult, error) { - // Embed the query queryVector, err := s.embedder.Embed(ctx, query) if err != nil { return nil, err } - // Fetch more results to allow re-ranking - fetchLimit := limit * 2 + fetchMultiplier := 2 + if s.dedupCfg.Enabled { + fetchMultiplier = 4 + } + fetchLimit := limit * fetchMultiplier var results []store.SearchResult if s.hybridCfg.Enabled { - // Hybrid search: combine vector + text search with RRF results, err = s.hybridSearch(ctx, query, queryVector, fetchLimit, pathPrefix) } else { - // Vector-only search results, err = s.store.Search(ctx, queryVector, fetchLimit, store.SearchOptions{PathPrefix: pathPrefix}) } @@ -48,10 +50,12 @@ func (s *Searcher) Search(ctx context.Context, query string, limit int, pathPref return nil, err } - // Apply structural boosting results = ApplyBoost(results, s.boostCfg) - // Trim to requested limit + if s.dedupCfg.Enabled { + results = DeduplicateByFile(results) + } + if len(results) > limit { results = results[:limit] } @@ -61,13 +65,11 @@ func (s *Searcher) Search(ctx context.Context, query string, limit int, pathPref // hybridSearch combines vector search and text search using RRF. func (s *Searcher) hybridSearch(ctx context.Context, query string, queryVector []float32, limit int, pathPrefix string) ([]store.SearchResult, error) { - // Vector search vectorResults, err := s.store.Search(ctx, queryVector, limit, store.SearchOptions{PathPrefix: pathPrefix}) if err != nil { return nil, err } - // Text search (get all chunks first) allChunks, err := s.store.GetAllChunks(ctx) if err != nil { return nil, err @@ -75,10 +77,9 @@ func (s *Searcher) hybridSearch(ctx context.Context, query string, queryVector [ textResults := TextSearch(ctx, allChunks, query, limit, pathPrefix) - // Combine with RRF k := s.hybridCfg.K if k <= 0 { - k = 60 // default + k = 60 } return ReciprocalRankFusion(k, limit, vectorResults, textResults), nil