diff --git a/src/core/search/indices.h b/src/core/search/indices.h index 107426d46f16..7c70e93d4499 100644 --- a/src/core/search/indices.h +++ b/src/core/search/indices.h @@ -133,6 +133,22 @@ template struct BaseStringIndex : public BaseIndex { return field_num_docs_; } + // Sum of per-doc field lengths. Pair with GetFieldNumDocs() to aggregate + // avg doc len across shards without going through the lossy ratio. + size_t GetFieldTotalDocsLen() const { + return field_total_docs_len_; + } + + // Schema canonical identifier of this field. Set by FieldIndices after + // construction; empty when the index is built outside that path. + std::string_view field_ident() const { + return field_ident_; + } + + void set_field_ident(std::string_view ident) { + field_ident_ = ident; + } + protected: using StringList = DocumentAccessor::StringList; @@ -160,6 +176,9 @@ template struct BaseStringIndex : public BaseIndex { std::vector field_doc_lengths_; // DocId -> sum of TF in this field size_t field_total_docs_len_ = 0; size_t field_num_docs_ = 0; // Number of docs with non-empty content in this field + + // Borrows from owning Schema's field map; the schema outlives this index. + std::string_view field_ident_; }; // Index for text fields. diff --git a/src/core/search/scoring.cc b/src/core/search/scoring.cc index f69f3e405fb7..e621e5b2fb0e 100644 --- a/src/core/search/scoring.cc +++ b/src/core/search/scoring.cc @@ -14,4 +14,33 @@ double ScoreDocument(ScorerFn scorer, const ScoringContext& ctx, return score; } +void GlobalScoringStats::Merge(const ShardScoringStats& shard) { + num_docs += shard.num_docs; + for (const auto& [field, stats] : shard.field_stats) { + auto& dst = field_stats[field]; + dst.num_docs += stats.num_docs; + dst.total_docs_len += stats.total_docs_len; + } + for (const auto& [field, terms] : shard.term_stats) { + auto& dst = term_stats[field]; + for (const auto& [term, count] : terms) + dst[term] += count; + } +} + +double GlobalScoringStats::GetFieldAvgDocLen(std::string_view field_ident) const { + auto it = field_stats.find(field_ident); + if (it == field_stats.end() || it->second.num_docs == 0) + return 0.0; + return static_cast(it->second.total_docs_len) / it->second.num_docs; +} + +size_t GlobalScoringStats::GetTermDocs(std::string_view field_ident, std::string_view term) const { + auto field_it = term_stats.find(field_ident); + if (field_it == term_stats.end()) + return 0; + auto term_it = field_it->second.find(term); + return term_it == field_it->second.end() ? 0 : term_it->second; +} + } // namespace dfly::search diff --git a/src/core/search/scoring.h b/src/core/search/scoring.h index ec9a2e0ada51..0aa726c40503 100644 --- a/src/core/search/scoring.h +++ b/src/core/search/scoring.h @@ -4,10 +4,13 @@ #pragma once +#include + #include #include #include #include +#include #include #include "core/search/base.h" @@ -92,4 +95,30 @@ inline double TfIdfDocNorm(const ScoringContext& ctx, const ScoringTermInfo& ter double ScoreDocument(ScorerFn scorer, const ScoringContext& ctx, const std::vector& terms); +// Single-shard slice of the counts a scorer needs. Keys are schema canonical +// names; terms are post-synonym-resolution. +struct ShardScoringStats { + struct FieldStats { + size_t num_docs = 0; + size_t total_docs_len = 0; + }; + + size_t num_docs = 0; + absl::flat_hash_map field_stats; + absl::flat_hash_map> term_stats; +}; + +// Cluster-wide aggregate of ShardScoringStats. Injected into ScoringContext +// so ranking is independent of how documents are partitioned across shards. +struct GlobalScoringStats { + size_t num_docs = 0; + absl::flat_hash_map field_stats; + absl::flat_hash_map> term_stats; + + void Merge(const ShardScoringStats& shard); + + double GetFieldAvgDocLen(std::string_view field_ident) const; + size_t GetTermDocs(std::string_view field_ident, std::string_view term) const; +}; + } // namespace dfly::search diff --git a/src/core/search/search.cc b/src/core/search/search.cc index 6334d6a9886a..36f36dc81db4 100644 --- a/src/core/search/search.cc +++ b/src/core/search/search.cc @@ -116,8 +116,16 @@ struct ProfileBuilder { struct BasicSearch { using LogicOp = AstLogicalNode::LogicOp; - BasicSearch(const FieldIndices* indices, ScorerFn scorer = nullptr) - : indices_{indices}, scorer_{scorer} { + // Cached posting list for a (TextIndex, term) pair so the AST walker + // and the scoring loop don't repeat the rax tree lookup. + struct MatchedTerm { + TextIndex* index; + std::string term; + const TextIndex::Container* container; + }; + + BasicSearch(const FieldIndices* indices, ScorerFn scorer, const GlobalScoringStats* global_stats) + : indices_{indices}, scorer_{scorer}, global_stats_{global_stats} { } void EnableProfiling() { @@ -233,38 +241,37 @@ struct BasicSearch { indices = indices_->GetAllTextIndices(); } - // Track matched terms for scoring (prefix/suffix/infix expand to multiple terms). - // Synonym shadow entries (freq=0) are resolved to their group_id for correct scoring. - if (scorer_) { - for (auto* index : indices) { - auto term_cb = [this, index](string_view term, const auto*) { + // Single trie walk: enumerate matching terms once, dispatch each container + // to AddMatchedTerm (for scoring) and to the union (for the result set). + vector sub_results; + sub_results.reserve(indices.size()); + for (auto* index : indices) { + IndexResult per_index{}; + auto term_cb = [&per_index, this, index](string_view term, const auto* container) { + if (scorer_) { std::string resolved{term}; + // Synonym shadow entries have freq=0; scoring must use the group's posting list. + const auto* score_container = container; if (auto synonyms = indices_->GetSynonyms(); synonyms) { - if (auto group_id = synonyms->GetGroupToken(resolved); group_id) + if (auto group_id = synonyms->GetGroupToken(resolved); group_id) { resolved = std::move(*group_id); + score_container = index->Matching(resolved, /*strip_whitespace=*/false); + } } - AddMatchedTerm(index, std::move(resolved)); - }; - if constexpr (T == TagType::PREFIX) - index->MatchPrefixWithTerm(node.affix, term_cb); - else if constexpr (T == TagType::SUFFIX) - index->MatchSuffixWithTerm(node.affix, term_cb); - else if constexpr (T == TagType::INFIX) - index->MatchInfixWithTerm(node.affix, term_cb); - } - } - - auto mapping = [&node, this](TextIndex* index) { + AddMatchedTerm(index, std::move(resolved), score_container); + } + Merge(IndexResult{container}, &per_index, LogicOp::OR); + }; if constexpr (T == TagType::PREFIX) - return CollectMatches(index, node.affix, &TextIndex::MatchPrefix); + index->MatchPrefixWithTerm(node.affix, term_cb); else if constexpr (T == TagType::SUFFIX) - return CollectMatches(index, node.affix, &TextIndex::MatchSuffix); + index->MatchSuffixWithTerm(node.affix, term_cb); else if constexpr (T == TagType::INFIX) - return CollectMatches(index, node.affix, &TextIndex::MatchInfix); - else - return vector{}; - }; - return UnifyResults(GetSubResults(indices, mapping), LogicOp::OR); + index->MatchInfixWithTerm(node.affix, term_cb); + sub_results.push_back(std::move(per_index)); + } + + return UnifyResults(std::move(sub_results), LogicOp::OR); } // "term": access field's text index or unify results from all text indices if no field is set @@ -281,26 +288,26 @@ struct BasicSearch { if (!active_field.empty()) { if (auto* index = GetIndex(active_field); index) { + const auto* container = index->Matching(term, strip_whitespace); if (scorer_) - AddMatchedTerm(index, term); - return IndexResult{index->Matching(term, strip_whitespace)}; + AddMatchedTerm(index, term, container); + return IndexResult{container}; } return IndexResult{}; } vector selected_indices = indices_->GetAllTextIndices(); - // Track terms for scoring - if (scorer_) { - for (auto* index : selected_indices) - AddMatchedTerm(index, term); + vector sub_results; + sub_results.reserve(selected_indices.size()); + for (auto* index : selected_indices) { + const auto* container = index->Matching(term, strip_whitespace); + if (scorer_) + AddMatchedTerm(index, term, container); + sub_results.emplace_back(IndexResult{container}); } - auto mapping = [&term, strip_whitespace](TextIndex* index) { - return index->Matching(term, strip_whitespace); - }; - - return UnifyResults(GetSubResults(selected_indices, mapping), LogicOp::OR); + return UnifyResults(std::move(sub_results), LogicOp::OR); } // [range]: access field's numeric index @@ -559,6 +566,7 @@ struct BasicSearch { struct TermCursor { TextIndex* index; size_t term_docs; + double field_avg_doc_len; // pre-resolved (global or local) TextIndex::Container::BlockListIterator it; TextIndex::Container::BlockListIterator end; }; @@ -582,17 +590,21 @@ struct BasicSearch { // Ensure sorted for cursor-based scoring sort(all_docs.begin(), all_docs.end()); - // Open cursors on posting lists for each matched term + // Open cursors on posting lists cached during the AST walk. vector cursors; cursors.reserve(matched_text_terms_.size()); - for (auto& [index, term] : matched_text_terms_) { - auto* container = index->Matching(term, /*strip_whitespace=*/false); + for (const auto& [index, term, container] : matched_text_terms_) { if (!container) continue; - cursors.push_back({index, container->Size(), container->begin(), container->end()}); + string_view field_ident = index->field_ident(); + size_t term_docs = + global_stats_ ? global_stats_->GetTermDocs(field_ident, term) : container->Size(); + double avg = global_stats_ ? global_stats_->GetFieldAvgDocLen(field_ident) + : index->GetFieldAvgDocLen(); + cursors.push_back({index, term_docs, avg, container->begin(), container->end()}); } - ScoringContext ctx{indices_->GetAllDocs().size()}; + ScoringContext ctx{global_stats_ ? global_stats_->num_docs : indices_->GetAllDocs().size()}; // Score all docs - reuse term_infos buffer across iterations vector> scored; @@ -605,7 +617,7 @@ struct BasicSearch { term_infos[t].term_freq = SeekCursor(cursors[t], doc); if (cursors[t].index) { term_infos[t].field_doc_len = cursors[t].index->GetFieldDocLength(doc); - term_infos[t].field_avg_doc_len = cursors[t].index->GetFieldAvgDocLen(); + term_infos[t].field_avg_doc_len = cursors[t].field_avg_doc_len; } } scored.emplace_back(static_cast(ScoreDocument(scorer_, ctx, term_infos)), doc); @@ -631,13 +643,14 @@ struct BasicSearch { return std::make_tuple(std::move(out), total_size, std::move(text_scores)); } - void AddMatchedTerm(TextIndex* index, string term) { + void AddMatchedTerm(TextIndex* index, string term, const TextIndex::Container* container) { if (matched_terms_set_.emplace(index, term).second) - matched_text_terms_.emplace_back(index, std::move(term)); + matched_text_terms_.push_back({index, std::move(term), container}); } const FieldIndices* indices_; ScorerFn scorer_ = nullptr; + const GlobalScoringStats* global_stats_ = nullptr; string error_; optional profile_builder_ = ProfileBuilder{}; @@ -645,10 +658,8 @@ struct BasicSearch { std::vector> knn_scores_; vector> knn_distances_; - // Tracked text terms for scoring: (TextIndex*, normalized_term) - // Deduplicated via matched_terms_set_ to avoid double-counting synonyms resolved to same - // group_id. - vector> matched_text_terms_; + // Deduped (TextIndex, normalized_term) pairs with their cached posting list. + vector matched_text_terms_; absl::flat_hash_set> matched_terms_set_; }; @@ -656,6 +667,110 @@ struct BasicSearch { #pragma GCC diagnostic pop #endif +// Walks the AST to collect per-(field, term) and per-field stats for the +// scoring phase. +struct StatsCollector { + explicit StatsCollector(const FieldIndices* indices) : indices_{indices} { + stats_.num_docs = indices_->GetAllDocs().size(); + } + + ShardScoringStats Take() && { + return std::move(stats_); + } + + void Walk(const AstNode& node, string_view active_field) { + visit([this, active_field](const auto& inner) { Visit(inner, active_field); }, node.Variant()); + } + + private: + // Catch-all for nodes that don't reference text terms; specific overloads below win. + template void Visit(const T&, string_view) { + } + + void Visit(const AstFieldNode& node, string_view) { + DCHECK(node.node); + Walk(*node.node, node.field); + } + void Visit(const AstLogicalNode& node, string_view active_field) { + for (const auto& child : node.nodes) + Walk(child, active_field); + } + void Visit(const AstNegateNode& node, string_view active_field) { + Walk(*node.node, active_field); + } + void Visit(const AstOptionalNode& node, string_view active_field) { + Walk(*node.node, active_field); + } + void Visit(const AstKnnNode& node, string_view active_field) { + Walk(*node.filter, active_field); + } + + void Visit(const AstTermNode& node, string_view active_field) { + string term = node.affix; + bool strip_whitespace = true; + if (auto* syn = indices_->GetSynonyms(); syn) { + if (auto group_id = syn->GetGroupToken(term); group_id) { + term = *group_id; + strip_whitespace = false; + } + } + for (auto* idx : SelectTextIndices(active_field)) { + const auto* container = idx->Matching(term, strip_whitespace); + Record(idx, term, container); + } + } + + template void Visit(const AstAffixNode& node, string_view active_field) { + static_assert(T != TagType::REGULAR); + for (auto* idx : SelectTextIndices(active_field)) { + auto cb = [this, idx](string_view term, const auto* container) { + string resolved{term}; + // Synonym shadow has freq=0; stats must come from the group's posting list. + const auto* effective = container; + if (auto* syn = indices_->GetSynonyms(); syn) { + if (auto group_id = syn->GetGroupToken(resolved); group_id) { + resolved = std::move(*group_id); + effective = idx->Matching(resolved, /*strip_whitespace=*/false); + } + } + Record(idx, std::move(resolved), effective); + }; + if constexpr (T == TagType::PREFIX) + idx->MatchPrefixWithTerm(node.affix, cb); + else if constexpr (T == TagType::SUFFIX) + idx->MatchSuffixWithTerm(node.affix, cb); + else if constexpr (T == TagType::INFIX) + idx->MatchInfixWithTerm(node.affix, cb); + } + } + + vector SelectTextIndices(string_view active_field) { + if (active_field.empty()) + return indices_->GetAllTextIndices(); + auto* idx = dynamic_cast(indices_->GetIndex(active_field)); + return idx ? vector{idx} : vector{}; + } + + void Record(TextIndex* idx, string term, const TextIndex::Container* container) { + string_view field_ident = idx->field_ident(); + if (field_ident.empty()) + return; + if (!seen_.emplace(idx, term).second) + return; + auto [it, inserted] = + stats_.field_stats.try_emplace(string{field_ident}, ShardScoringStats::FieldStats{}); + if (inserted) { + it->second.num_docs = idx->GetFieldNumDocs(); + it->second.total_docs_len = idx->GetFieldTotalDocsLen(); + } + stats_.term_stats[string{field_ident}][std::move(term)] = container ? container->Size() : 0; + } + + const FieldIndices* indices_; + ShardScoringStats stats_; + absl::flat_hash_set> seen_; +}; + } // namespace AstNode OptionalNumericFilter::Node(std::string field) { @@ -698,8 +813,10 @@ void FieldIndices::CreateIndices(PMR_NS::memory_resource* mr) { switch (field_info.type) { case SchemaField::TEXT: { const auto& tparams = std::get(field_info.special_params); - indices_[field_ident] = + auto idx = make_unique(mr, &options_.stopwords, synonyms_, tparams.with_suffixtrie); + idx->set_field_ident(field_ident); + indices_[field_ident] = std::move(idx); break; } case SchemaField::NUMERIC: { @@ -898,15 +1015,23 @@ bool SearchAlgorithm::Init(string_view query, const QueryParams* params, return true; } -SearchResult SearchAlgorithm::Search(const FieldIndices* index, size_t cuttoff_limit) const { +SearchResult SearchAlgorithm::Search(const FieldIndices* index, size_t cuttoff_limit, + const GlobalScoringStats* global_stats) const { DCHECK(query_); - auto bs = BasicSearch{index, scorer_}; + auto bs = BasicSearch{index, scorer_, global_stats}; if (profiling_enabled_) bs.EnableProfiling(); return bs.Search(*query_, cuttoff_limit); } +ShardScoringStats SearchAlgorithm::CollectScoringStats(const FieldIndices* index) const { + DCHECK(query_); + StatsCollector collector{index}; + collector.Walk(*query_, ""); + return std::move(collector).Take(); +} + std::optional SearchAlgorithm::GetKnnScoreSortOption() const { // HNSW KNN query if (knn_hnsw_score_sort_option_) { diff --git a/src/core/search/search.h b/src/core/search/search.h index 0578ead865ef..f3ec55b8c165 100644 --- a/src/core/search/search.h +++ b/src/core/search/search.h @@ -220,9 +220,15 @@ class SearchAlgorithm { bool Init(std::string_view query, const QueryParams* params, const OptionalFilters* filters = nullptr); - // Search on given index with predefined limit for cutting off result ids + // Search on given index with predefined limit for cutting off result ids. + // When global_stats is non-null, scorers see cluster-wide counts instead of + // values local to `index`. SearchResult Search(const FieldIndices* index, - size_t cuttoff_limit = std::numeric_limits::max()) const; + size_t cuttoff_limit = std::numeric_limits::max(), + const GlobalScoringStats* global_stats = nullptr) const; + + // This shard's contribution to GlobalScoringStats. Requires Init(). + ShardScoringStats CollectScoringStats(const FieldIndices* index) const; std::optional GetKnnScoreSortOption() const; diff --git a/src/server/search/aggregator.cc b/src/server/search/aggregator.cc index 6e2b4928b287..53c525f3ebc5 100644 --- a/src/server/search/aggregator.cc +++ b/src/server/search/aggregator.cc @@ -116,6 +116,10 @@ void Aggregator::DoSort(const SortParams& sort_params) { continue; return order == SortOrder::ASC ? *lv < *rv : *lv > *rv; } + // All explicit fields tied: break by hidden __key for cross-shard determinism. + auto lk = l.find("__key"), rk = r.find("__key"); + if (lk != l.end() && rk != r.end()) + return lk->second < rk->second; return false; }; diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index 8493c5989cb8..9fb8b8a0cfc0 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -853,16 +853,20 @@ vector ShardDocIndex::KeepTopKSorted(vector* ids, } SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params, - search::SearchAlgorithm* search_algo, - bool is_knn_prefilter) const { + search::SearchAlgorithm* search_algo, bool is_knn_prefilter, + const search::GlobalScoringStats* global_stats) const { size_t limit = params.limit_offset + params.limit_total; + // Disable BasicSearch's per-shard cutoff; we re-rank by (score, key) below. + const bool sort_by_text_score = params.scorer || params.with_scores; + // If we don't sort the documents, we don't need to copy more ids than are requested // Also for HNSW KNN search we don't cut results at the search stage. - bool can_cut = !params.sort_option && !search_algo->GetKnnScoreSortOption() && !is_knn_prefilter; + bool can_cut = !params.sort_option && !search_algo->GetKnnScoreSortOption() && + !is_knn_prefilter && !sort_by_text_score; size_t id_cutoff_limit = can_cut ? limit : numeric_limits::max(); - auto result = search_algo->Search(&*indices_, id_cutoff_limit); + auto result = search_algo->Search(&*indices_, id_cutoff_limit, global_stats); if (!result.error.empty()) return {facade::ErrorReply(std::move(result.error))}; @@ -916,6 +920,40 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa } } + // Re-rank by (score, key) so per-shard top-K matches what a global merge + // would pick. Skipped when SORTBY or KNN drives the order — those vectors + // are positionally aligned with result.ids and reordering desyncs them. + if (sort_by_text_score && sort_scores.empty() && result.knn_scores.empty() && + !result.text_scores.empty()) { + struct Scored { + float score; + std::string_view key; + search::DocId doc; + }; + std::vector entries; + entries.reserve(result.text_scores.size()); + for (const auto& [doc, score] : result.text_scores) + entries.push_back({score, key_index_.Get(doc), doc}); + + const size_t take = std::min(limit, entries.size()); + std::partial_sort(entries.begin(), entries.begin() + take, entries.end(), + [](const Scored& a, const Scored& b) { + if (a.score != b.score) + return a.score > b.score; + return a.key < b.key; + }); + + // Trim text_scores to the surviving top-K so the score map below stays small. + result.ids.clear(); + result.ids.reserve(take); + result.text_scores.clear(); + result.text_scores.reserve(take); + for (size_t i = 0; i < take; i++) { + result.ids.push_back(entries[i].doc); + result.text_scores.emplace_back(entries[i].doc, entries[i].score); + } + } + // Cut off unnecessary items result.ids.resize(min(result.ids.size(), limit)); @@ -965,10 +1003,16 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa return {result.total - expired_count, std::move(out), std::move(result.profile)}; } +search::ShardScoringStats ShardDocIndex::CollectScoringStats( + search::SearchAlgorithm* search_algo) const { + return search_algo->CollectScoringStats(&*indices_); +} + vector ShardDocIndex::SearchForAggregator( const OpArgs& op_args, const AggregateParams& params, search::SearchAlgorithm* search_algo) const { - auto search_results = search_algo->Search(&*indices_); + auto search_results = search_algo->Search(&*indices_, std::numeric_limits::max(), + params.global_scoring_stats); if (!search_results.error.empty()) return {}; @@ -1024,7 +1068,7 @@ vector ShardDocIndex::LoadDocEntriesWithScores( auto entry = LoadEntry(doc, op_args); if (!entry) continue; - auto& [_, accessor] = *entry; + auto& [key, accessor] = *entry; SearchDocData extracted_sort_indicies; extracted_sort_indicies.reserve(sort_indicies.size()); @@ -1045,6 +1089,8 @@ vector ShardDocIndex::LoadDocEntriesWithScores( if (!text_score_map.empty()) { if (auto it = text_score_map.find(doc); it != text_score_map.end()) out.back()["__score"] = static_cast(it->second); + // Hidden tie-breaker for SORTBY @__score; not added to fields_to_print. + out.back()["__key"] = string{key}; } } return out; diff --git a/src/server/search/doc_index.h b/src/server/search/doc_index.h index c6aa10f972b5..c897000c4759 100644 --- a/src/server/search/doc_index.h +++ b/src/server/search/doc_index.h @@ -198,6 +198,9 @@ struct AggregateParams { bool add_scores = false; // ADDSCORES flag search::ScorerFn scorer = nullptr; // SCORER parameter (null = not set) + + // Set only for multi-shard scoring queries; not owned. + const search::GlobalScoringStats* global_scoring_stats = nullptr; }; // Stores basic info about a document index. @@ -352,8 +355,14 @@ class ShardDocIndex { ~ShardDocIndex(); // Perform search on all indexed documents and return results. + // When global_stats is non-null, scorers see cluster-wide counts. SearchResult Search(const OpArgs& op_args, const SearchParams& params, - search::SearchAlgorithm* search_algo, bool is_knn_prefilter) const; + search::SearchAlgorithm* search_algo, bool is_knn_prefilter, + const search::GlobalScoringStats* global_stats) const; + + // This shard's contribution to a GlobalScoringStats. search_algo must be + // Init()-ed. + search::ShardScoringStats CollectScoringStats(search::SearchAlgorithm* search_algo) const; // Perform search and load requested values - note params might be interpreted differently. std::vector SearchForAggregator(const OpArgs& op_args, diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index 0e99d989cdde..2f59071b7591 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -1123,6 +1123,20 @@ void SearchReply(const SearchParams& params, } const size_t end = limit + offset; + // Re-sort union of per-shard tops by (score, key) so LIMIT picks global top-K. + const bool scoring_active = params.scorer || params.with_scores; + const bool needs_text_score_sort = + scoring_active && !knn_sort_option && (!params.sort_option || ignore_sort); + if (needs_text_score_sort) { + auto by_score_then_key = [](SerializedSearchDoc* l, SerializedSearchDoc* r) { + if (l->text_score != r->text_score) + return l->text_score > r->text_score; + return l->key < r->key; + }; + partial_sort(docs.begin(), docs.begin() + std::min(end, docs.size()), docs.end(), + by_score_then_key); + } + // Apply SORTBY if its different from the KNN sort if (params.sort_option && !ignore_sort) PartialSort(absl::MakeSpan(docs), end, params.sort_option->order, @@ -1955,17 +1969,51 @@ void CmdFtSearch(CmdArgList args, CommandContext* cmd_cntx) { const bool knn_has_prefilter = knn && knn->HasPreFilter(); bool empty_prefilter_result = true; + // Phase 1 collects per-shard counts; phase 2 scores with the global aggregate. + // Skipped for KNN/HNSW and single-shard. + const bool scoring_active = + (params->scorer || params->with_scores) && (!knn || knn_has_prefilter) && !hnsw_range; + const bool needs_global_stats = scoring_active && shard_set->size() > 1; + search::GlobalScoringStats global_scoring_stats; + const search::GlobalScoringStats* global_stats_ptr = nullptr; + + if (needs_global_stats) { + std::vector shard_stats(shard_set->size()); + cmd_cntx->tx()->Execute( + [&](Transaction* t, EngineShard* es) { + if (auto* index = es->search_indices()->GetIndex(index_name); index) + shard_stats[es->shard_id()] = index->CollectScoringStats(&search_algo); + else + index_not_found.store(true, memory_order_relaxed); + return OpStatus::OK; + }, + false); + + if (index_not_found.load(memory_order_relaxed)) { + cmd_cntx->tx()->Conclude(); // phase 1 ran with conclude=false + return cmd_cntx->SendError(string{index_name} + ": no such index"); + } + + for (auto& s : shard_stats) + global_scoring_stats.Merge(s); + global_stats_ptr = &global_scoring_stats; + } + // If the query does not contain knn component, or it is a hybrid query. // HNSW vector range has no prefilter, so skip per-shard search entirely. if ((!knn || knn_has_prefilter) && !hnsw_range) { - cmd_cntx->tx()->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { + auto search_cb = [&](Transaction* t, EngineShard* es) { if (auto* index = es->search_indices()->GetIndex(index_name); index) - docs[es->shard_id()] = - index->Search(t->GetOpArgs(es), *params, &search_algo, knn_has_prefilter); + docs[es->shard_id()] = index->Search(t->GetOpArgs(es), *params, &search_algo, + knn_has_prefilter, global_stats_ptr); else index_not_found.store(true, memory_order_relaxed); return OpStatus::OK; - }); + }; + if (needs_global_stats) + cmd_cntx->tx()->Execute(std::move(search_cb), true); + else + cmd_cntx->tx()->ScheduleSingleHop(std::move(search_cb)); if (index_not_found.load(memory_order_relaxed)) return cmd_cntx->SendError(string{index_name} + ": no such index"); @@ -2072,7 +2120,9 @@ void CmdFtProfile(CmdArgList args, CommandContext* cmd_cntx) { const ShardId shard_id = es->shard_id(); auto shard_start = absl::Now(); - search_results[shard_id] = index->Search(t->GetOpArgs(es), *params, &search_algo, false); + search_results[shard_id] = index->Search(t->GetOpArgs(es), *params, &search_algo, + /*is_knn_prefilter=*/false, + /*global_stats=*/nullptr); profile_results[shard_id] = {absl::Now() - shard_start}; return OpStatus::OK; @@ -2206,7 +2256,7 @@ void CmdFtAggregate(CmdArgList args, CommandContext* cmd_cntx) { CmdArgParser parser{args}; auto* builder = cmd_cntx->rb(); - const auto params = ParseAggregatorParams(&parser); + auto params = ParseAggregatorParams(&parser); if (SendErrorIfOccurred(params, &parser, cmd_cntx)) return; @@ -2281,7 +2331,8 @@ void CmdFtAggregate(CmdArgList args, CommandContext* cmd_cntx) { sp.limit_total = std::numeric_limits::max(); sp.return_fields.emplace(); // ids-only, skip field serialization prefilter_docs[es->shard_id()] = - index->Search(t->GetOpArgs(es), sp, &search_algo, true); + index->Search(t->GetOpArgs(es), sp, &search_algo, /*is_knn_prefilter=*/true, + /*global_stats=*/nullptr); } return OpStatus::OK; }, @@ -2328,13 +2379,37 @@ void CmdFtAggregate(CmdArgList args, CommandContext* cmd_cntx) { cmd_cntx->tx()->ScheduleSingleHop( make_load_cb(shard_docs, hnsw_range->score_alias, prefilter_text_scores)); } else { - cmd_cntx->tx()->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { + // Same global-stats phase as FT.SEARCH so SORTBY @__score is stable + // across shard counts. + const bool agg_scoring_active = params->scorer || params->add_scores; + const bool agg_needs_global_stats = agg_scoring_active && shard_set->size() > 1; + search::GlobalScoringStats agg_global_stats; + + if (agg_needs_global_stats) { + std::vector shard_stats(shard_set->size()); + cmd_cntx->tx()->Execute( + [&](Transaction* t, EngineShard* es) { + if (auto* index = es->search_indices()->GetIndex(params->index); index) + shard_stats[es->shard_id()] = index->CollectScoringStats(&search_algo); + return OpStatus::OK; + }, + false); + for (auto& s : shard_stats) + agg_global_stats.Merge(s); + params->global_scoring_stats = &agg_global_stats; + } + + auto agg_search_cb = [&](Transaction* t, EngineShard* es) { if (auto* index = es->search_indices()->GetIndex(params->index); index) { query_results[es->shard_id()] = index->SearchForAggregator(t->GetOpArgs(es), params.value(), &search_algo); } return OpStatus::OK; - }); + }; + if (agg_needs_global_stats) + cmd_cntx->tx()->Execute(std::move(agg_search_cb), true); + else + cmd_cntx->tx()->ScheduleSingleHop(std::move(agg_search_cb)); } // ResultContainer is absl::flat_hash_map diff --git a/tests/dragonfly/search_test.py b/tests/dragonfly/search_test.py index 8561dc823683..af021d356486 100644 --- a/tests/dragonfly/search_test.py +++ b/tests/dragonfly/search_test.py @@ -1295,3 +1295,236 @@ async def test_ft_aggregate_addscores(async_client: aioredis.Redis): assert float(results[0]["__score"]) >= float(results[1]["__score"]) await async_client.execute_command("FT.DROPINDEX", "agg_score_idx") + + +# Documents must be diverse enough that BM25/TFIDF rankings differ between docs; +# ~200 docs spread across 4 shards puts ~50 docs per shard, enough that local +# IDF/avgdl differ measurably from global if scoring is per-shard. +def _make_scorer_corpus(): + base_terms = [ + "alpha beta", + "alpha gamma delta", + "beta gamma delta epsilon", + "alpha epsilon", + "gamma", + "delta epsilon zeta", + "alpha zeta eta", + "alpha alpha alpha beta", + "epsilon eta theta", + "beta beta gamma theta iota", + ] + docs = [] + for i in range(200): + content = base_terms[i % len(base_terms)] + if i % 7 == 0: + content = content + " kappa lambda mu nu xi omicron" + if i % 13 == 0: + content = content + " " + content + docs.append({"key": f"sdoc:{i}", "content": content}) + return docs + + +async def _index_and_query_with_scorer(client, scorer: str, query: str, k: int): + idx_name = f"scorer_shard_idx_{scorer.lower().replace('.', '_')}" + await client.execute_command( + "FT.CREATE", idx_name, "ON", "HASH", "PREFIX", "1", "sdoc:", "SCHEMA", "content", "TEXT" + ) + + for doc in _make_scorer_corpus(): + await client.hset(doc["key"], mapping={"content": doc["content"]}) + + res = await client.execute_command( + "FT.SEARCH", + idx_name, + query, + "WITHSCORES", + "SCORER", + scorer, + "LIMIT", + "0", + str(k), + ) + + total = int(res[0]) + docs = [] + i = 1 + while i < len(res): + key = res[i].decode() if isinstance(res[i], bytes) else res[i] + score_raw = res[i + 1] + score = float(score_raw.decode() if isinstance(score_raw, bytes) else score_raw) + docs.append((key, score)) + i += 3 # skip key, score, fields + + await client.execute_command("FT.DROPINDEX", idx_name) + return total, docs + + +@pytest.mark.parametrize("scorer", ["BM25STD", "TFIDF", "TFIDF.DOCNORM"]) +async def test_scorer_consistent_across_shards(df_factory: DflyInstanceFactory, scorer: str): + """Top-K and scores must be identical regardless of proactor_threads count. + + Reproducer for the per-shard IDF/avgdl bug: when a scorer is computed using + only the docs that live on the executing shard, IDF and avgdl shift with + shard count and ranking degrades. The fix runs a stats-collection phase + across all shards before scoring so that every shard uses the same global + statistics. This test verifies that property. + """ + inst1 = df_factory.create(proactor_threads=1) + inst4 = df_factory.create(proactor_threads=4) + df_factory.start_all([inst1, inst4]) + + client1 = inst1.client() + client4 = inst4.client() + + # Query that matches a meaningful subset (most docs contain at least one of these). + query = "alpha|beta|gamma" + k = 20 + + total1, docs1 = await _index_and_query_with_scorer(client1, scorer, query, k) + total4, docs4 = await _index_and_query_with_scorer(client4, scorer, query, k) + + assert total1 == total4, f"total_hits diverged: 1-shard={total1}, 4-shard={total4}" + + keys1 = [d[0] for d in docs1] + keys4 = [d[0] for d in docs4] + assert keys1 == keys4, ( + f"top-{k} ordering differs across shard counts (scorer={scorer}).\n" + f" 1-shard: {keys1}\n" + f" 4-shard: {keys4}" + ) + + # Scores: identical inputs to the scorer => identical outputs (bit-equality + # would be ideal, but tolerate tiny FP variation from sum order). + by_key1 = dict(docs1) + by_key4 = dict(docs4) + for key in keys1: + s1, s4 = by_key1[key], by_key4[key] + denom = max(abs(s1), abs(s4), 1e-9) + assert abs(s1 - s4) / denom < 1e-4, ( + f"score for {key} differs across shard counts (scorer={scorer}): " + f"1-shard={s1}, 4-shard={s4}" + ) + + +async def _aggregate_with_scorer(client, scorer: str, query: str, k: int): + idx_name = f"agg_scorer_idx_{scorer.lower().replace('.', '_')}" + await client.execute_command( + "FT.CREATE", + idx_name, + "ON", + "HASH", + "PREFIX", + "1", + "sdoc:", + "SCHEMA", + "content", + "TEXT", + "doc_id", + "TAG", + ) + for doc in _make_scorer_corpus(): + await client.hset(doc["key"], mapping={"content": doc["content"], "doc_id": doc["key"]}) + + res = await client.execute_command( + "FT.AGGREGATE", + idx_name, + query, + "LOAD", + "1", + "@doc_id", + "SCORER", + scorer, + "ADDSCORES", + "SORTBY", + "2", + "@__score", + "DESC", + "LIMIT", + "0", + str(k), + ) + + docs = [] + for row in res[1:]: + kv = {} + for i in range(0, len(row), 2): + k_ = row[i].decode() if isinstance(row[i], bytes) else row[i] + v_ = row[i + 1].decode() if isinstance(row[i + 1], bytes) else row[i + 1] + kv[k_] = v_ + docs.append((kv["doc_id"], float(kv["__score"]))) + + await client.execute_command("FT.DROPINDEX", idx_name) + return docs + + +@pytest.mark.parametrize("scorer", ["BM25STD", "TFIDF", "TFIDF.DOCNORM"]) +async def test_aggregate_scorer_consistent_across_shards( + df_factory: DflyInstanceFactory, scorer: str +): + """FT.AGGREGATE top-K and order must be identical regardless of shard count. + + Tied scores are broken implicitly by doc key in Aggregator::DoSort. + """ + inst1 = df_factory.create(proactor_threads=1) + inst4 = df_factory.create(proactor_threads=4) + df_factory.start_all([inst1, inst4]) + + docs1 = await _aggregate_with_scorer(inst1.client(), scorer, "alpha|beta|gamma", 20) + docs4 = await _aggregate_with_scorer(inst4.client(), scorer, "alpha|beta|gamma", 20) + + assert [d[0] for d in docs1] == [d[0] for d in docs4], ( + f"FT.AGGREGATE top-K differs across shards (scorer={scorer}).\n" + f" 1-shard: {[d[0] for d in docs1]}\n 4-shard: {[d[0] for d in docs4]}" + ) + for (k1, s1), (k4, s4) in zip(docs1, docs4): + denom = max(abs(s1), abs(s4), 1e-9) + assert abs(s1 - s4) / denom < 1e-4, f"score for {k1} differs: {s1} vs {s4}" + + +async def test_sortby_with_scores_alignment(async_client: aioredis.Redis): + """Regression: SORTBY+WITHSCORES must keep score/sort_score/key consistent + after per-shard re-rank (re-rank must not desync sort_scores). + """ + await async_client.execute_command( + "FT.CREATE", + "sw_idx", + "ON", + "HASH", + "PREFIX", + "1", + "sw:", + "SCHEMA", + "name", + "TEXT", + "rank", + "NUMERIC", + "SORTABLE", + ) + for i in range(1, 13): + await async_client.hset(f"sw:{i}", mapping={"name": "alpha beta", "rank": str(100 - i)}) + + res = await async_client.execute_command( + "FT.SEARCH", "sw_idx", "alpha", "SORTBY", "rank", "DESC", "WITHSCORES", "LIMIT", "0", "5" + ) + + total = int(res[0]) + assert total == 12 + + keys, ranks = [], [] + i = 1 + while i < len(res): + key = res[i].decode() if isinstance(res[i], bytes) else res[i] + fields = res[i + 2] + kv = {} + for j in range(0, len(fields), 2): + k_ = fields[j].decode() if isinstance(fields[j], bytes) else fields[j] + v_ = fields[j + 1].decode() if isinstance(fields[j + 1], bytes) else fields[j + 1] + kv[k_] = v_ + keys.append(key) + ranks.append(int(kv["rank"])) + i += 3 + + # SORTBY rank DESC: ranks must be strictly descending. If re-rank desynced + # sort_scores from ids, ranks would be jumbled. + assert ranks == sorted(ranks, reverse=True), f"SORTBY desynced: {keys} -> {ranks}" + await async_client.execute_command("FT.DROPINDEX", "sw_idx")