Merged
CHANGELOG.md (1 change: 1 addition & 0 deletions)
@@ -73,6 +73,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Migrate usages of deprecated `Operations#union` from Lucene ([#19397](https://github.com/opensearch-project/OpenSearch/pull/19397))
- Delegate primitive write methods with ByteSizeCachingDirectory wrapped IndexOutput ([#19432](https://github.com/opensearch-project/OpenSearch/pull/19432))
- Bump opensearch-protobufs dependency to 0.18.0 and update transport-grpc module compatibility ([#19447](https://github.com/opensearch-project/OpenSearch/issues/19447))
- `StreamStringTermsAggregator` now rejects collecting a second segment without a prior `reset()`, and `ProfilingAggregator` propagates `reset()` to the underlying collector ([#19416](https://github.com/opensearch-project/OpenSearch/pull/19416))
- Bump opensearch-protobufs dependency to 0.19.0 ([#19453](https://github.com/opensearch-project/OpenSearch/issues/19453))

### Fixed
StreamStringTermsAggregator.java
@@ -49,6 +49,7 @@ public class StreamStringTermsAggregator extends AbstractStringTermsAggregator i
protected int segmentsWithSingleValuedOrds = 0;
protected int segmentsWithMultiValuedOrds = 0;
protected final ResultStrategy<?, ?> resultStrategy;
private boolean leafCollectorCreated = false;

public StreamStringTermsAggregator(
String name,
@@ -74,6 +75,7 @@ public void doReset() {
super.doReset();
valueCount = 0;
sortedDocValuesPerBatch = null;
this.leafCollectorCreated = false;
}

@Override
@@ -88,6 +90,13 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
if (this.leafCollectorCreated) {
throw new IllegalStateException(
"Calling " + StreamStringTermsAggregator.class.getSimpleName() + " for the second segment: " + ctx
);
} else {
this.leafCollectorCreated = true;
}
this.sortedDocValuesPerBatch = valuesSource.ordinalsValues(ctx);
this.valueCount = sortedDocValuesPerBatch.getValueCount(); // for streaming case, the value count is reset to per batch
// cardinality
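The check above pins down the contract for streaming collection: a `StreamStringTermsAggregator` handles exactly one segment per batch, and the caller must reset it before moving to the next segment. Below is a minimal sketch of that per-segment loop under stated assumptions — the `emit` sink is a hypothetical stand-in for the streaming channel, and `NO_OP_COLLECTOR` is used as the sub-collector purely for illustration; this is not the actual streaming driver.

```java
import java.io.IOException;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.LeafBucketCollector;

void streamPerSegment(IndexSearcher searcher, StreamStringTermsAggregator aggregator) throws IOException {
    for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) {
        // Throws IllegalStateException if the previous batch was never reset.
        LeafBucketCollector collector = aggregator.getLeafCollector(leaf, LeafBucketCollector.NO_OP_COLLECTOR);
        // ... the query collects this segment's matching docs into `collector` ...
        InternalAggregation batch = aggregator.buildAggregations(new long[] { 0 })[0];
        emit(batch);          // hypothetical sink: stream the per-segment result out
        aggregator.doReset(); // clears valueCount, the per-batch doc values, and the guard flag
    }
}
```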
ProfilingAggregator.java
@@ -129,6 +129,12 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOExce
}
}

@Override
public void reset() {
delegate.reset();
super.reset();
}

@Override
public void preCollection() throws IOException {
this.profileBreakdown = profiler.getQueryBreakdown(delegate);
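Without the override above, `reset()` stopped at the profiling wrapper: the delegate `StreamStringTermsAggregator` kept `leafCollectorCreated = true`, so a profiled streaming query tripped the new guard on its second segment. A hedged illustration of the failure mode and the fix — constructor usage mirrors the test change below, and the comments describe the change's intent rather than verified internals:

```java
// With profiling enabled, the streaming driver only ever sees the wrapper.
Aggregator profiled = new ProfilingAggregator(streamAggregator, profiler);

profiled.reset();
// Before this change: the wrapper reset only its own state, leaving the
// delegate's leafCollectorCreated flag set, so the next getLeafCollector(...)
// threw IllegalStateException.
// After this change: reset() first forwards to delegate.reset(), so the
// wrapped aggregator is ready for the next per-segment batch.
```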
StreamStringTermsAggregatorTests.java
@@ -20,17 +20,25 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.util.BytesRef;
import org.opensearch.action.OriginalIndices;
import org.opensearch.action.search.SearchShardTask;
import org.opensearch.action.support.StreamSearchChannelListener;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.MockBigArrays;
import org.opensearch.common.util.MockPageCacheRecycler;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.index.mapper.KeywordFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorTestCase;
import org.opensearch.search.aggregations.BucketOrder;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.search.aggregations.metrics.Avg;
import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder;
@@ -43,6 +51,16 @@
import org.opensearch.search.aggregations.metrics.ValueCount;
import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.fetch.QueryFetchSearchResult;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.profile.Timer;
import org.opensearch.search.profile.aggregation.AggregationProfileBreakdown;
import org.opensearch.search.profile.aggregation.AggregationProfiler;
import org.opensearch.search.profile.aggregation.ProfilingAggregator;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.streaming.FlushMode;
import org.opensearch.search.streaming.Streamable;
import org.opensearch.search.streaming.StreamingCostMetrics;

@@ -55,11 +73,18 @@
import java.util.function.BiConsumer;

import static org.opensearch.test.InternalAggregationTestCase.DEFAULT_MAX_BUCKETS;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class StreamStringTermsAggregatorTests extends AggregatorTestCase {
public void testBuildAggregationsBatchDirectBucketCreation() throws Exception {
@@ -343,14 +368,156 @@ public void testBuildAggregationsBatchWithCountOrder() throws Exception {
}
}

public void testBuildAggregationsBatchReset() throws Exception {
public void testBuildAggregationsWithContextSearcherNoProfile() throws Exception {
doAggOverManySegments(false);
}

public void testBuildAggregationsWithContextSearcherProfile() throws Exception {
doAggOverManySegments(true);
}

private void doAggOverManySegments(boolean profile) throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
boolean isSegmented = false;
for (int i = 0; i < 3; i++) {
Document document = new Document();
document.add(new SortedSetDocValuesField("field", new BytesRef("common")));
indexWriter.addDocument(document);
if (rarely()) {
indexWriter.flush();
isSegmented = true;
}
}
indexWriter.flush();
for (int i = 0; i < 2; i++) {
Document document = new Document();
document.add(new SortedSetDocValuesField("field", new BytesRef("medium")));
indexWriter.addDocument(document);
if (rarely()) {
indexWriter.flush();
isSegmented = true;
}
}

if (!isSegmented) {
indexWriter.flush();
}

Document document = new Document();
document.add(new SortedSetDocValuesField("field", new BytesRef("test")));
document.add(new SortedSetDocValuesField("field", new BytesRef("rare")));
indexWriter.addDocument(document);

try (IndexReader indexReader = maybeWrapReaderEs(indexWriter.getReader())) {
IndexSearcher indexSearcher = newIndexSearcher(indexReader);
SearchContext searchContext = createSearchContext(
indexSearcher,
createIndexSettings(),
null,
new MultiBucketConsumerService.MultiBucketConsumer(
MultiBucketConsumerService.DEFAULT_MAX_BUCKETS,
new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST)
),
new NumberFieldMapper.NumberFieldType("test", NumberFieldMapper.NumberType.INTEGER)
);
when(searchContext.isStreamSearch()).thenReturn(true);
when(searchContext.getFlushMode()).thenReturn(FlushMode.PER_SEGMENT);
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node_1",
new ShardId("foo", "_na_", 1),
null,
OriginalIndices.NONE
);
when(searchContext.shardTarget()).thenReturn(searchShardTarget);
SearchShardTask task = new SearchShardTask(0, "n/a", "n/a", "test-kind", null, null);
searchContext.setTask(task);
when(searchContext.queryResult()).thenReturn(new QuerySearchResult());
when(searchContext.fetchResult()).thenReturn(new FetchSearchResult());
StreamSearchChannelListener listenerMock = mock(StreamSearchChannelListener.class);
final List<InternalAggregations> perSegAggs = new ArrayList<>();
when(searchContext.getStreamChannelListener()).thenReturn(listenerMock);
doAnswer((invok) -> {
QuerySearchResult querySearchResult = ((QueryFetchSearchResult) invok.getArgument(0, TransportResponse.class))
.queryResult();
InternalAggregations internalAggregations = querySearchResult.aggregations().expand();
perSegAggs.add(internalAggregations);
return null;
}).when(listenerMock).onStreamResponse(any(), anyBoolean());
ContextIndexSearcher contextIndexSearcher = searchContext.searcher();

MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("field");

TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder("test").field("field")
.order(BucketOrder.count(false));

Aggregator aggregator = createStreamAggregator(
null,
aggregationBuilder,
indexSearcher,
createIndexSettings(),
new MultiBucketConsumerService.MultiBucketConsumer(
DEFAULT_MAX_BUCKETS,
new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST)
),
fieldType
);

if (profile) {
aggregator = wrapByProfilingAgg(aggregator);
}

aggregator.preCollection();

contextIndexSearcher.search(new MatchAllDocsQuery(), aggregator);
aggregator.postCollection();

InternalAggregation.ReduceContext ctx = InternalAggregation.ReduceContext.forFinalReduction(
new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()),
getMockScriptService(),
b -> {},
PipelineTree.EMPTY
);

assertThat(perSegAggs, not(empty()));
InternalAggregations summary = InternalAggregations.reduce(perSegAggs, ctx);

StringTerms result = summary.get("test");

assertThat(result, notNullValue());
assertThat(result.getBuckets().size(), equalTo(3));

List<StringTerms.Bucket> buckets = result.getBuckets();
assertThat(buckets.get(0).getKeyAsString(), equalTo("common"));
assertThat(buckets.get(0).getDocCount(), equalTo(3L));
assertThat(buckets.get(1).getKeyAsString(), equalTo("medium"));
assertThat(buckets.get(1).getDocCount(), equalTo(2L));
assertThat(buckets.get(2).getKeyAsString(), equalTo("rare"));
assertThat(buckets.get(2).getDocCount(), equalTo(1L));
}
}
}
}

private static Aggregator wrapByProfilingAgg(Aggregator aggregator) throws IOException {
AggregationProfiler aggregationProfiler = mock(AggregationProfiler.class);
AggregationProfileBreakdown aggregationProfileBreakdown = mock(AggregationProfileBreakdown.class);
when(aggregationProfileBreakdown.getTimer(any())).thenReturn(mock(Timer.class));
when(aggregationProfiler.getQueryBreakdown(any())).thenReturn(aggregationProfileBreakdown);
aggregator = new ProfilingAggregator(aggregator, aggregationProfiler);
return aggregator;
}

public void testBuildAggregationsBatchReset() throws Exception {
try (Directory directory = newDirectory()) {
try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig())) {
Document document = new Document();
document.add(new SortedSetDocValuesField("field", new BytesRef("test")));
indexWriter.addDocument(document);
document = new Document();
document.add(new SortedSetDocValuesField("field", new BytesRef("best")));
indexWriter.addDocument(document);

try (IndexReader indexReader = maybeWrapReaderEs(DirectoryReader.open(indexWriter))) {
IndexSearcher indexSearcher = newIndexSearcher(indexReader);
MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("field");

@@ -374,7 +541,7 @@ public void testBuildAggregationsBatchReset() throws Exception {
aggregator.postCollection();

StringTerms firstResult = (StringTerms) aggregator.buildAggregations(new long[] { 0 })[0];
assertThat(firstResult.getBuckets().size(), equalTo(1));
assertThat(firstResult.getBuckets().size(), equalTo(2));

aggregator.doReset();

@@ -384,7 +551,7 @@
aggregator.postCollection();

StringTerms secondResult = (StringTerms) aggregator.buildAggregations(new long[] { 0 })[0];
assertThat(secondResult.getBuckets().size(), equalTo(1));
assertThat(secondResult.getBuckets().size(), equalTo(2));
assertThat(secondResult.getBuckets().get(0).getDocCount(), equalTo(1L));
}
}
@@ -431,7 +598,7 @@ public void testMultipleBatches() throws Exception {

public void testSubAggregationWithMax() throws Exception {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig())) {
Document document = new Document();
document.add(new SortedSetDocValuesField("category", new BytesRef("electronics")));
document.add(new NumericDocValuesField("price", 100));
@@ -447,7 +614,7 @@
document.add(new NumericDocValuesField("price", 50));
indexWriter.addDocument(document);

try (IndexReader indexReader = maybeWrapReaderEs(indexWriter.getReader())) {
try (IndexReader indexReader = maybeWrapReaderEs(DirectoryReader.open(indexWriter))) {
IndexSearcher indexSearcher = newIndexSearcher(indexReader);
MappedFieldType categoryFieldType = new KeywordFieldMapper.KeywordFieldType("category");
MappedFieldType priceFieldType = new NumberFieldMapper.NumberFieldType("price", NumberFieldMapper.NumberType.LONG);
@@ -1167,6 +1334,41 @@ public void testReduceSingleAggregation() throws Exception {
}
}

public void testThrowOnManySegments() throws Exception {
try (Directory directory = newDirectory()) {
try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig())) {
for (int i = 0; i < atLeast(2); i++) {
Document doc = new Document();
doc.add(new SortedSetDocValuesField("category", new BytesRef("electronics")));
indexWriter.addDocument(doc);
indexWriter.commit();
}
try (IndexReader reader = maybeWrapReaderEs(DirectoryReader.open(indexWriter))) {
IndexSearcher searcher = newIndexSearcher(reader);
MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("category");
TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder("categories").field("category")
.order(BucketOrder.count(false)); // Order by count descending

StreamStringTermsAggregator aggregator = createStreamAggregator(
null,
aggregationBuilder,
searcher,
createIndexSettings(),
new MultiBucketConsumerService.MultiBucketConsumer(
DEFAULT_MAX_BUCKETS,
new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST)
),
fieldType
);

// Execute the aggregator
aggregator.preCollection();
assertThrows(IllegalStateException.class, () -> { searcher.search(new MatchAllDocsQuery(), aggregator); });
}
}
}
}

private InternalAggregation buildInternalStreamingAggregation(
TermsAggregationBuilder builder,
MappedFieldType fieldType1,