Refactor the filter rewrite optimization #14464

Open · wants to merge 14 commits into base: main

This file was deleted.

@@ -73,11 +73,11 @@
import org.opensearch.search.aggregations.MultiBucketCollector;
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper;
import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper.AbstractDateHistogramAggregationType;
import org.opensearch.search.aggregations.bucket.missing.MissingOrder;
import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.optimization.ranges.DateHistogramAggregatorBridge;
import org.opensearch.search.optimization.ranges.OptimizationContext;
import org.opensearch.search.searchafter.SearchAfterBuilder;
import org.opensearch.search.sort.SortAndFormats;

@@ -89,13 +89,14 @@
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;

import static org.opensearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;

/**
* Main aggregator that aggregates docs from mulitple aggregations
* Main aggregator that aggregates docs from multiple aggregations
*
* @opensearch.internal
*/
@@ -118,9 +119,8 @@ public final class CompositeAggregator extends BucketsAggregator {

private boolean earlyTerminated;

private final FastFilterRewriteHelper.FastFilterContext fastFilterContext;
private LongKeyedBucketOrds bucketOrds = null;
private Rounding.Prepared preparedRounding = null;
private final OptimizationContext optimizationContext;
private LongKeyedBucketOrds bucketOrds;

CompositeAggregator(
String name,
@@ -166,56 +166,68 @@ public final class CompositeAggregator extends BucketsAggregator {
this.queue = new CompositeValuesCollectorQueue(context.bigArrays(), sources, size, rawAfterKey);
this.rawAfterKey = rawAfterKey;

fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(context);
if (!FastFilterRewriteHelper.isCompositeAggRewriteable(sourceConfigs)) {
return;
}
fastFilterContext.setAggregationType(new CompositeAggregationType());
if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
// bucketOrds is used for saving date histogram results
bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), CardinalityUpperBound.ONE);
preparedRounding = ((CompositeAggregationType) fastFilterContext.getAggregationType()).getRoundingPrepared();
fastFilterContext.buildRanges(sourceConfigs[0].fieldType());
}
}
optimizationContext = new OptimizationContext(new DateHistogramAggregatorBridge() {
private RoundingValuesSource valuesSource;
private long afterKey = -1L;

/**
* Currently the filter rewrite is only supported for date histograms
*/
public class CompositeAggregationType extends AbstractDateHistogramAggregationType {
private final RoundingValuesSource valuesSource;
private long afterKey = -1L;

public CompositeAggregationType() {
super(sourceConfigs[0].fieldType(), sourceConfigs[0].missingBucket(), sourceConfigs[0].hasScript());
this.valuesSource = (RoundingValuesSource) sourceConfigs[0].valuesSource();
if (rawAfterKey != null) {
assert rawAfterKey.size() == 1 && formats.size() == 1;
this.afterKey = formats.get(0).parseLong(rawAfterKey.get(0).toString(), false, () -> {
throw new IllegalArgumentException("now() is not supported in [after] key");
});
@Override
protected boolean canOptimize() {
if (canOptimize(sourceConfigs)) {
this.valuesSource = (RoundingValuesSource) sourceConfigs[0].valuesSource();
if (rawAfterKey != null) {
assert rawAfterKey.size() == 1 && formats.size() == 1;
this.afterKey = formats.get(0).parseLong(rawAfterKey.get(0).toString(), false, () -> {
throw new IllegalArgumentException("now() is not supported in [after] key");
});
}

// bucketOrds is used for saving the date histogram results obtained from the optimization path
bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), CardinalityUpperBound.ONE);
return true;
}
return false;
}
}

public Rounding getRounding(final long low, final long high) {
return valuesSource.getRounding();
}
@Override
protected void buildRanges() throws IOException {
buildRanges(context);
}

public Rounding.Prepared getRoundingPrepared() {
return valuesSource.getPreparedRounding();
}
protected Rounding getRounding(final long low, final long high) {
return valuesSource.getRounding();
}

@Override
protected void processAfterKey(long[] bound, long interval) {
// afterKey is the last bucket key in previous response, and the bucket key
// is the minimum of all values in the bucket, so need to add the interval
if (afterKey != -1L) {
bound[0] = afterKey + interval;
protected Rounding.Prepared getRoundingPrepared() {
return valuesSource.getPreparedRounding();
}

@Override
protected long[] processAfterKey(long[] bounds, long interval) {
// afterKey is the last bucket key in the previous response, and the bucket key
// is the minimum of all values in the bucket, so we need to add the interval
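// e.g. (hypothetical values) with a daily interval of 86,400,000 ms and
// afterKey = 1,700,000,000,000, the lower bound becomes 1,700,086,400,000,
// excluding the bucket that was already returned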
if (afterKey != -1L) {
bounds[0] = afterKey + interval;
}
return bounds;
}

@Override
protected int getSize() {
return size;
}
}

public int getSize() {
return size;
@Override
protected Function<Object, Long> bucketOrdProducer() {
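// translate a rounded bucket key coming from the filter-rewrite path into an ordinal in bucketOrds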
return (key) -> bucketOrds.add(0, getRoundingPrepared().round((long) key));
}

@Override
protected boolean segmentMatchAll(LeafReaderContext leaf) throws IOException {
return segmentMatchAll(context, leaf);
}
});
if (optimizationContext.canOptimize(parent, subAggregators.length, context)) {
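// build the filter ranges once up front; tryOptimize reuses them for each segment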
optimizationContext.prepare();
}
}

@@ -368,7 +380,7 @@ private boolean isMaybeMultivalued(LeafReaderContext context, SortField sortField)
return v2 != null && DocValues.unwrapSingleton(v2) == null;

default:
// we have no clue whether the field is multi-valued or not so we assume it is.
// we have no clue whether the field is multivalued or not so we assume it is.
return true;
}
}
@@ -551,11 +563,7 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) throws IOException {

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
boolean optimized = fastFilterContext.tryFastFilterAggregation(
ctx,
this::incrementBucketDocCount,
(key) -> bucketOrds.add(0, preparedRounding.round((long) key))
);
boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount);
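// the filter-rewrite path produced doc counts for every bucket in this segment,
// so terminate normal per-document collection for this leaf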
if (optimized) throw new CollectionTerminatedException();

finishLeaf();
@@ -709,11 +717,6 @@ private static class Entry {

@Override
public void collectDebugInfo(BiConsumer<String, Object> add) {
if (fastFilterContext.optimizedSegments > 0) {
add.accept("optimized_segments", fastFilterContext.optimizedSegments);
add.accept("unoptimized_segments", fastFilterContext.segments - fastFilterContext.optimizedSegments);
add.accept("leaf_visited", fastFilterContext.leaf);
add.accept("inner_visited", fastFilterContext.inner);
}
optimizationContext.populateDebugInfo(add);
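// presumably emits the same optimized_segments / unoptimized_segments /
// leaf_visited / inner_visited counters as the inlined block removed above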
}
}
@@ -42,7 +42,6 @@
import org.opensearch.common.util.IntArray;
import org.opensearch.common.util.LongArray;
import org.opensearch.core.common.util.ByteArray;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorFactories;
@@ -53,18 +52,18 @@
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator;
import org.opensearch.search.aggregations.bucket.DeferringBucketCollector;
import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper;
import org.opensearch.search.aggregations.bucket.MergingBucketsDeferringCollector;
import org.opensearch.search.aggregations.bucket.histogram.AutoDateHistogramAggregationBuilder.RoundingInfo;
import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.optimization.ranges.DateHistogramAggregatorBridge;
import org.opensearch.search.optimization.ranges.OptimizationContext;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.LongToIntFunction;
@@ -135,7 +134,7 @@ static AutoDateHistogramAggregator build(
protected int roundingIdx;
protected Rounding.Prepared preparedRounding;

private final FastFilterRewriteHelper.FastFilterContext fastFilterContext;
private final OptimizationContext optimizationContext;

private AutoDateHistogramAggregator(
String name,
@@ -158,52 +157,58 @@ private AutoDateHistogramAggregator(
this.roundingPreparer = roundingPreparer;
this.preparedRounding = prepareRounding(0);

fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(
context,
new AutoHistogramAggregationType(
valuesSourceConfig.fieldType(),
valuesSourceConfig.missing() != null,
valuesSourceConfig.script() != null
)
);
if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
fastFilterContext.buildRanges(Objects.requireNonNull(valuesSourceConfig.fieldType()));
}
}

private class AutoHistogramAggregationType extends FastFilterRewriteHelper.AbstractDateHistogramAggregationType {
optimizationContext = new OptimizationContext(new DateHistogramAggregatorBridge() {
@Override
protected boolean canOptimize() {
return canOptimize(valuesSourceConfig);
}

public AutoHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript) {
super(fieldType, missing, hasScript);
}
@Override
protected void buildRanges() throws IOException {
buildRanges(context);
}

@Override
protected Rounding getRounding(final long low, final long high) {
// max - min / targetBuckets = bestDuration
// find the right innerInterval this bestDuration belongs to
// since we cannot exceed targetBuckets, bestDuration should go up,
// so the right innerInterval should be an upper bound
long bestDuration = (high - low) / targetBuckets;
// reset so this function is idempotent
roundingIdx = 0;
while (roundingIdx < roundingInfos.length - 1) {
final RoundingInfo curRoundingInfo = roundingInfos[roundingIdx];
final int temp = curRoundingInfo.innerIntervals[curRoundingInfo.innerIntervals.length - 1];
// If the interval duration is covered by the maximum inner interval,
// we can start with this outer interval for creating the buckets
if (bestDuration <= temp * curRoundingInfo.roughEstimateDurationMillis) {
break;
@Override
protected Rounding getRounding(final long low, final long high) {
// (max - min) / targetBuckets = bestDuration
// find the right innerInterval this bestDuration belongs to
// since we cannot exceed targetBuckets, bestDuration should go up,
// so the right innerInterval should be an upper bound
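// e.g. (hypothetical numbers) a 10-day range with targetBuckets = 120 gives
// bestDuration = 2h, so the first rounding whose largest inner interval spans
// at least 2h is selected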
long bestDuration = (high - low) / targetBuckets;
// reset so this function is idempotent
roundingIdx = 0;
while (roundingIdx < roundingInfos.length - 1) {
final RoundingInfo curRoundingInfo = roundingInfos[roundingIdx];
final int temp = curRoundingInfo.innerIntervals[curRoundingInfo.innerIntervals.length - 1];
// If the interval duration is covered by the maximum inner interval,
// we can start with this outer interval for creating the buckets
if (bestDuration <= temp * curRoundingInfo.roughEstimateDurationMillis) {
break;
}
roundingIdx++;
}
roundingIdx++;

preparedRounding = prepareRounding(roundingIdx);
return roundingInfos[roundingIdx].rounding;
}

preparedRounding = prepareRounding(roundingIdx);
return roundingInfos[roundingIdx].rounding;
}
@Override
protected Prepared getRoundingPrepared() {
return preparedRounding;
}

@Override
protected Prepared getRoundingPrepared() {
return preparedRounding;
@Override
protected Function<Object, Long> bucketOrdProducer() {
return (key) -> getBucketOrds().add(0, preparedRounding.round((long) key));
}

@Override
protected boolean segmentMatchAll(LeafReaderContext leaf) throws IOException {
return segmentMatchAll(context, leaf);
}
});
if (optimizationContext.canOptimize(parent, subAggregators.length, context)) {
optimizationContext.prepare();
}
}

@@ -236,11 +241,7 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
return LeafBucketCollector.NO_OP_COLLECTOR;
}

boolean optimized = fastFilterContext.tryFastFilterAggregation(
ctx,
this::incrementBucketDocCount,
(key) -> getBucketOrds().add(0, preparedRounding.round((long) key))
);
boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount);
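// the whole segment was already aggregated by the filter-rewrite path; skip per-document collection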
if (optimized) throw new CollectionTerminatedException();

final SortedNumericDocValues values = valuesSource.longValues(ctx);
@@ -308,12 +309,7 @@ protected final void merge(long[] mergeMap, long newNumBuckets) {
@Override
public void collectDebugInfo(BiConsumer<String, Object> add) {
super.collectDebugInfo(add);
if (fastFilterContext.optimizedSegments > 0) {
add.accept("optimized_segments", fastFilterContext.optimizedSegments);
add.accept("unoptimized_segments", fastFilterContext.segments - fastFilterContext.optimizedSegments);
add.accept("leaf_visited", fastFilterContext.leaf);
add.accept("inner_visited", fastFilterContext.inner);
}
optimizationContext.populateDebugInfo(add);
}

/**