Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/core/search/indices.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,22 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
return field_num_docs_;
}

// Sum of per-doc field lengths. Pair with GetFieldNumDocs() to aggregate
// avg doc len across shards without going through the lossy ratio.
size_t GetFieldTotalDocsLen() const {
return field_total_docs_len_;
}

// Schema canonical identifier of this field. Set by FieldIndices after
// construction; empty when the index is built outside that path.
std::string_view field_ident() const {
return field_ident_;
}

void set_field_ident(std::string_view ident) {
field_ident_ = ident;
}

protected:
using StringList = DocumentAccessor::StringList;

Expand Down Expand Up @@ -160,6 +176,9 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
std::vector<uint32_t> field_doc_lengths_; // DocId -> sum of TF in this field
// Sum of per-doc field lengths; exposed through GetFieldTotalDocsLen().
size_t field_total_docs_len_ = 0;
size_t field_num_docs_ = 0; // Number of docs with non-empty content in this field

// Schema canonical identifier of this field; empty until set_field_ident() is
// called. Non-owning view: borrows from owning Schema's field map; the schema
// outlives this index.
std::string_view field_ident_;
};

// Index for text fields.
Expand Down
29 changes: 29 additions & 0 deletions src/core/search/scoring.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,33 @@ double ScoreDocument(ScorerFn scorer, const ScoringContext& ctx,
return score;
}

void GlobalScoringStats::Merge(const ShardScoringStats& shard) {
num_docs += shard.num_docs;
for (const auto& [field, stats] : shard.field_stats) {
auto& dst = field_stats[field];
dst.num_docs += stats.num_docs;
dst.total_docs_len += stats.total_docs_len;
}
for (const auto& [field, terms] : shard.term_stats) {
auto& dst = term_stats[field];
for (const auto& [term, count] : terms)
dst[term] += count;
}
}

double GlobalScoringStats::GetFieldAvgDocLen(std::string_view field_ident) const {
auto it = field_stats.find(field_ident);
if (it == field_stats.end() || it->second.num_docs == 0)
return 0.0;
return static_cast<double>(it->second.total_docs_len) / it->second.num_docs;
}

// Looks up the aggregated count stored for `term` under `field_ident`.
// Returns 0 when either the field or the term has no entry; the lookup never
// inserts into the maps.
size_t GlobalScoringStats::GetTermDocs(std::string_view field_ident, std::string_view term) const {
  if (auto field_it = term_stats.find(field_ident); field_it != term_stats.end()) {
    const auto& per_term = field_it->second;
    if (auto term_it = per_term.find(term); term_it != per_term.end())
      return term_it->second;
  }
  return 0;
}

} // namespace dfly::search
29 changes: 29 additions & 0 deletions src/core/search/scoring.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

#pragma once

#include <absl/container/flat_hash_map.h>

#include <algorithm>
#include <cmath>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "core/search/base.h"
Expand Down Expand Up @@ -92,4 +95,30 @@ inline double TfIdfDocNorm(const ScoringContext& ctx, const ScoringTermInfo& ter
double ScoreDocument(ScorerFn scorer, const ScoringContext& ctx,
const std::vector<ScoringTermInfo>& terms);

// Single-shard slice of the counts a scorer needs. Keys are schema canonical
// names; terms are post-synonym-resolution. Instances are folded into a
// GlobalScoringStats through GlobalScoringStats::Merge(), which sums every
// counter below.
struct ShardScoringStats {
// Per-field counters. GlobalScoringStats::GetFieldAvgDocLen() reports
// total_docs_len / num_docs as the field's average document length.
struct FieldStats {
size_t num_docs = 0;
size_t total_docs_len = 0;
};

// Shard-level document count.
size_t num_docs = 0;
// Field name -> counters for that field.
absl::flat_hash_map<std::string, FieldStats> field_stats;
// Field name -> term -> count.
// NOTE(review): presumably the number of docs containing the term — confirm
// against the code that fills this map.
absl::flat_hash_map<std::string, absl::flat_hash_map<std::string, size_t>> term_stats;
};

// Cluster-wide aggregate of ShardScoringStats. Injected into ScoringContext
// so ranking is independent of how documents are partitioned across shards.
// The data members mirror ShardScoringStats and hold the sums accumulated by
// Merge().
struct GlobalScoringStats {
// Sum of num_docs over all merged shards.
size_t num_docs = 0;
// Field name -> summed per-field counters.
absl::flat_hash_map<std::string, ShardScoringStats::FieldStats> field_stats;
// Field name -> term -> summed count.
absl::flat_hash_map<std::string, absl::flat_hash_map<std::string, size_t>> term_stats;

// Adds every counter of `shard` into this aggregate.
void Merge(const ShardScoringStats& shard);

// total_docs_len / num_docs for the field; 0.0 if the field is unknown or has
// no documents.
double GetFieldAvgDocLen(std::string_view field_ident) const;
// Aggregated count stored for (field, term); 0 if either is unknown.
size_t GetTermDocs(std::string_view field_ident, std::string_view term) const;
};

} // namespace dfly::search
Loading
Loading