diff --git a/CHANGELOG.md b/CHANGELOG.md
index 47a135e8046fd..dfb518e21f743 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -73,6 +73,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Migrate usages of deprecated `Operations#union` from Lucene ([#19397](https://github.com/opensearch-project/OpenSearch/pull/19397))
 - Delegate primitive write methods with ByteSizeCachingDirectory wrapped IndexOutput ([#19432](https://github.com/opensearch-project/OpenSearch/pull/19432))
 - Bump opensearch-protobufs dependency to 0.18.0 and update transport-grpc module compatibility ([#19447](https://github.com/opensearch-project/OpenSearch/issues/19447))
+- StreamStringTermsAggregator rejects collecting a second segment without a prior reset(); ProfilingAggregator propagates reset() to the underlying aggregator ([#19416](https://github.com/opensearch-project/OpenSearch/pull/19416))
 - Bump opensearch-protobufs dependency to 0.19.0 ([#19453](https://github.com/opensearch-project/OpenSearch/issues/19453))
 
 ### Fixed
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregator.java
index 9e5aa23d214ee..cb95b0efc2e75 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregator.java
@@ -49,6 +49,7 @@ public class StreamStringTermsAggregator extends AbstractStringTermsAggregator i
     protected int segmentsWithSingleValuedOrds = 0;
     protected int segmentsWithMultiValuedOrds = 0;
     protected final ResultStrategy resultStrategy;
+    private boolean leafCollectorCreated = false;
 
     public StreamStringTermsAggregator(
         String name,
@@ -74,6 +75,7 @@ public void doReset() {
         super.doReset();
         valueCount = 0;
         sortedDocValuesPerBatch = null;
+        this.leafCollectorCreated = false;
     }
 
     @Override
@@ -88,6 +90,13 @@ public InternalAggregation buildEmptyAggregation() {
 
     @Override
    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
+        if (this.leafCollectorCreated) {
+            throw new IllegalStateException(
+                "Calling " + StreamStringTermsAggregator.class.getSimpleName() + " for a second segment without reset(): " + ctx
+            );
+        } else {
+            this.leafCollectorCreated = true;
+        }
         this.sortedDocValuesPerBatch = valuesSource.ordinalsValues(ctx);
         this.valueCount = sortedDocValuesPerBatch.getValueCount(); // for streaming case, the value count is reset to per batch
         // cardinality
diff --git a/server/src/main/java/org/opensearch/search/profile/aggregation/ProfilingAggregator.java b/server/src/main/java/org/opensearch/search/profile/aggregation/ProfilingAggregator.java
index b8004181f2ec5..aaf0a093f848d 100644
--- a/server/src/main/java/org/opensearch/search/profile/aggregation/ProfilingAggregator.java
+++ b/server/src/main/java/org/opensearch/search/profile/aggregation/ProfilingAggregator.java
@@ -129,6 +129,12 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException {
         }
     }
 
+    @Override
+    public void reset() {
+        delegate.reset();
+        super.reset();
+    }
+
     @Override
     public void preCollection() throws IOException {
         this.profileBreakdown = profiler.getQueryBreakdown(delegate);
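The two production fixes above are a pair: StreamStringTermsAggregator now enforces a one-segment-per-reset() contract, and ProfilingAggregator must forward reset() to its delegate so that contract survives profiling. A minimal, self-contained Java sketch of the pattern (SegmentCollector, StreamingCollector, ProfilingDecorator, and StreamingDriver are simplified stand-ins, not the actual OpenSearch classes):

    // Simplified stand-ins, not the real OpenSearch implementation.
    interface SegmentCollector {
        void collect(String segment);

        void reset();
    }

    // Mirrors the guard added to StreamStringTermsAggregator#getLeafCollector.
    class StreamingCollector implements SegmentCollector {
        private boolean leafCollectorCreated = false;

        @Override
        public void collect(String segment) {
            if (leafCollectorCreated) {
                throw new IllegalStateException("second segment without reset(): " + segment);
            }
            leafCollectorCreated = true;
            // ... per-segment collection work ...
        }

        @Override
        public void reset() {
            leafCollectorCreated = false; // mirrors doReset()
        }
    }

    // Mirrors the ProfilingAggregator fix: a decorator that swallows reset()
    // leaves its delegate's guard armed, so the next segment throws.
    class ProfilingDecorator implements SegmentCollector {
        private final SegmentCollector delegate;

        ProfilingDecorator(SegmentCollector delegate) {
            this.delegate = delegate;
        }

        @Override
        public void collect(String segment) {
            delegate.collect(segment); // the real profiler also records timings here
        }

        @Override
        public void reset() {
            delegate.reset(); // conceptually the line this PR adds
        }
    }

    class StreamingDriver {
        public static void main(String[] args) {
            SegmentCollector collector = new ProfilingDecorator(new StreamingCollector());
            for (String segment : new String[] { "seg0", "seg1" }) {
                collector.collect(segment); // stream the per-segment result
                collector.reset(); // then reset before the next segment
            }
        }
    }

Without the reset() override in the decorator, the second loop iteration throws the IllegalStateException, which is exactly the failure mode the profiling fix addresses.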
diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregatorTests.java
index 3b64b7aa7b7e7..112a47527d0f0 100644
--- a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregatorTests.java
+++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregatorTests.java
@@ -20,17 +20,25 @@
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.apache.lucene.util.BytesRef;
+import org.opensearch.action.OriginalIndices;
+import org.opensearch.action.search.SearchShardTask;
+import org.opensearch.action.support.StreamSearchChannelListener;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.util.MockBigArrays;
 import org.opensearch.common.util.MockPageCacheRecycler;
 import org.opensearch.core.common.breaker.CircuitBreaker;
+import org.opensearch.core.index.shard.ShardId;
 import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
+import org.opensearch.core.transport.TransportResponse;
 import org.opensearch.index.mapper.KeywordFieldMapper;
 import org.opensearch.index.mapper.MappedFieldType;
 import org.opensearch.index.mapper.NumberFieldMapper;
+import org.opensearch.search.SearchShardTarget;
+import org.opensearch.search.aggregations.Aggregator;
 import org.opensearch.search.aggregations.AggregatorTestCase;
 import org.opensearch.search.aggregations.BucketOrder;
 import org.opensearch.search.aggregations.InternalAggregation;
+import org.opensearch.search.aggregations.InternalAggregations;
 import org.opensearch.search.aggregations.MultiBucketConsumerService;
 import org.opensearch.search.aggregations.metrics.Avg;
 import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder;
@@ -43,6 +51,16 @@
 import org.opensearch.search.aggregations.metrics.ValueCount;
 import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder;
 import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
+import org.opensearch.search.fetch.FetchSearchResult;
+import org.opensearch.search.fetch.QueryFetchSearchResult;
+import org.opensearch.search.internal.ContextIndexSearcher;
+import org.opensearch.search.internal.SearchContext;
+import org.opensearch.search.profile.Timer;
+import org.opensearch.search.profile.aggregation.AggregationProfileBreakdown;
+import org.opensearch.search.profile.aggregation.AggregationProfiler;
+import org.opensearch.search.profile.aggregation.ProfilingAggregator;
+import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.streaming.FlushMode;
 import org.opensearch.search.streaming.Streamable;
 import org.opensearch.search.streaming.StreamingCostMetrics;
 
@@ -55,11 +73,18 @@
 import java.util.function.BiConsumer;
 
 import static org.opensearch.test.InternalAggregationTestCase.DEFAULT_MAX_BUCKETS;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.notNullValue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class StreamStringTermsAggregatorTests extends AggregatorTestCase {
     public void testBuildAggregationsBatchDirectBucketCreation() throws Exception {
@@ -343,14 +368,156 @@ public void testBuildAggregationsBatchWithCountOrder() throws Exception {
         }
     }
 
-    public void testBuildAggregationsBatchReset() throws Exception {
+    public void testBuildAggregationsWithContextSearcherNoProfile() throws Exception {
+        doAggOverManySegments(false);
+    }
+
+    public void testBuildAggregationsWithContextSearcherProfile() throws Exception {
+        doAggOverManySegments(true);
+    }
+
+    private void doAggOverManySegments(boolean profile) throws IOException {
         try (Directory directory = newDirectory()) {
             try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
+                boolean isSegmented = false;
+                for (int i = 0; i < 3; i++) {
+                    Document document = new Document();
+                    document.add(new SortedSetDocValuesField("field", new BytesRef("common")));
+                    indexWriter.addDocument(document);
+                    if (rarely()) {
+                        indexWriter.flush();
+                        isSegmented = true;
+                    }
+                }
+                indexWriter.flush();
+                for (int i = 0; i < 2; i++) {
+                    Document document = new Document();
+                    document.add(new SortedSetDocValuesField("field", new BytesRef("medium")));
+                    indexWriter.addDocument(document);
+                    if (rarely()) {
+                        indexWriter.flush();
+                        isSegmented = true;
+                    }
+                }
+
+                if (!isSegmented) {
+                    indexWriter.flush();
+                }
+
                 Document document = new Document();
-                document.add(new SortedSetDocValuesField("field", new BytesRef("test")));
+                document.add(new SortedSetDocValuesField("field", new BytesRef("rare")));
                 indexWriter.addDocument(document);
 
                 try (IndexReader indexReader = maybeWrapReaderEs(indexWriter.getReader())) {
+                    IndexSearcher indexSearcher = newIndexSearcher(indexReader);
+                    SearchContext searchContext = createSearchContext(
+                        indexSearcher,
+                        createIndexSettings(),
+                        null,
+                        new MultiBucketConsumerService.MultiBucketConsumer(
+                            MultiBucketConsumerService.DEFAULT_MAX_BUCKETS,
+                            new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST)
+                        ),
+                        new NumberFieldMapper.NumberFieldType("test", NumberFieldMapper.NumberType.INTEGER)
+                    );
+                    when(searchContext.isStreamSearch()).thenReturn(true);
+                    when(searchContext.getFlushMode()).thenReturn(FlushMode.PER_SEGMENT);
+                    SearchShardTarget searchShardTarget = new SearchShardTarget(
+                        "node_1",
+                        new ShardId("foo", "_na_", 1),
+                        null,
+                        OriginalIndices.NONE
+                    );
+                    when(searchContext.shardTarget()).thenReturn(searchShardTarget);
+                    SearchShardTask task = new SearchShardTask(0, "n/a", "n/a", "test-kind", null, null);
+                    searchContext.setTask(task);
+                    when(searchContext.queryResult()).thenReturn(new QuerySearchResult());
+                    when(searchContext.fetchResult()).thenReturn(new FetchSearchResult());
+                    StreamSearchChannelListener listenerMock = mock(StreamSearchChannelListener.class);
+                    final List<InternalAggregations> perSegAggs = new ArrayList<>();
+                    when(searchContext.getStreamChannelListener()).thenReturn(listenerMock);
+                    doAnswer(invocation -> {
+                        QuerySearchResult querySearchResult = ((QueryFetchSearchResult) invocation.getArgument(0, TransportResponse.class))
+                            .queryResult();
+                        InternalAggregations internalAggregations = querySearchResult.aggregations().expand();
+                        perSegAggs.add(internalAggregations);
+                        return null;
+                    }).when(listenerMock).onStreamResponse(any(), anyBoolean());
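The doAnswer stub above is the standard Mockito idiom for capturing callback arguments: every per-segment QueryFetchSearchResult handed to the stream listener is unwrapped and stashed in perSegAggs for the final reduce. A self-contained sketch of the same idiom, using hypothetical names (ResultListener, onResult) rather than the OpenSearch listener API:

    import static org.mockito.ArgumentMatchers.any;
    import static org.mockito.Mockito.doAnswer;
    import static org.mockito.Mockito.mock;

    import java.util.ArrayList;
    import java.util.List;

    interface ResultListener {
        void onResult(String result); // hypothetical callback, not an OpenSearch API
    }

    class CaptureIdiomDemo {
        public static void main(String[] args) {
            List<String> captured = new ArrayList<>();
            ResultListener listener = mock(ResultListener.class);
            // Record every argument the code under test hands to the listener.
            doAnswer(invocation -> {
                captured.add(invocation.getArgument(0, String.class));
                return null; // void callback
            }).when(listener).onResult(any());

            listener.onResult("partial for segment 0");
            listener.onResult("partial for segment 1");
            System.out.println(captured); // plays the role of perSegAggs above
        }
    }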
+                    ContextIndexSearcher contextIndexSearcher = searchContext.searcher();
+
+                    MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("field");
+
+                    TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder("test").field("field")
+                        .order(BucketOrder.count(false));
+
+                    Aggregator aggregator = createStreamAggregator(
+                        null,
+                        aggregationBuilder,
+                        indexSearcher,
+                        createIndexSettings(),
+                        new MultiBucketConsumerService.MultiBucketConsumer(
+                            DEFAULT_MAX_BUCKETS,
+                            new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST)
+                        ),
+                        fieldType
+                    );
+
+                    if (profile) {
+                        aggregator = wrapByProfilingAgg(aggregator);
+                    }
+
+                    aggregator.preCollection();
+
+                    contextIndexSearcher.search(new MatchAllDocsQuery(), aggregator);
+                    aggregator.postCollection();
+
+                    InternalAggregation.ReduceContext ctx = InternalAggregation.ReduceContext.forFinalReduction(
+                        new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()),
+                        getMockScriptService(),
+                        b -> {},
+                        PipelineTree.EMPTY
+                    );
+
+                    assertThat(perSegAggs, not(empty()));
+                    InternalAggregations summary = InternalAggregations.reduce(perSegAggs, ctx);
+
+                    StringTerms result = summary.get("test");
+
+                    assertThat(result, notNullValue());
+                    assertThat(result.getBuckets().size(), equalTo(3));
+
+                    List<StringTerms.Bucket> buckets = result.getBuckets();
+                    assertThat(buckets.get(0).getKeyAsString(), equalTo("common"));
+                    assertThat(buckets.get(0).getDocCount(), equalTo(3L));
+                    assertThat(buckets.get(1).getKeyAsString(), equalTo("medium"));
+                    assertThat(buckets.get(1).getDocCount(), equalTo(2L));
+                    assertThat(buckets.get(2).getKeyAsString(), equalTo("rare"));
+                    assertThat(buckets.get(2).getDocCount(), equalTo(1L));
+                }
+            }
+        }
+    }
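The assertions above hinge on the reduce step: InternalAggregations.reduce merges the per-segment partials into global bucket counts. A toy stand-in for that merge, using plain maps instead of InternalAggregations and the common/medium/rare layout the test indexes (the segment split shown is illustrative, since the test randomizes flushing):

    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    class ReduceDemo {
        // Merge per-segment term counts the way the final reduce merges
        // per-segment partial aggregations.
        static Map<String, Long> reduce(List<Map<String, Long>> partials) {
            Map<String, Long> merged = new HashMap<>();
            for (Map<String, Long> partial : partials) {
                partial.forEach((term, count) -> merged.merge(term, count, Long::sum));
            }
            return merged;
        }

        public static void main(String[] args) {
            Map<String, Long> segment0 = Map.of("common", 3L, "medium", 2L);
            Map<String, Long> segment1 = Map.of("rare", 1L);
            // Prints the merged totals: common=3, medium=2, rare=1
            System.out.println(reduce(List.of(segment0, segment1)));
        }
    }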
+
+    private static Aggregator wrapByProfilingAgg(Aggregator aggregator) throws IOException {
+        AggregationProfiler aggregationProfiler = mock(AggregationProfiler.class);
+        AggregationProfileBreakdown aggregationProfileBreakdown = mock(AggregationProfileBreakdown.class);
+        when(aggregationProfileBreakdown.getTimer(any())).thenReturn(mock(Timer.class));
+        when(aggregationProfiler.getQueryBreakdown(any())).thenReturn(aggregationProfileBreakdown);
+        aggregator = new ProfilingAggregator(aggregator, aggregationProfiler);
+        return aggregator;
+    }
+
+    public void testBuildAggregationsBatchReset() throws Exception {
+        try (Directory directory = newDirectory()) {
+            try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig())) {
+                Document document = new Document();
+                document.add(new SortedSetDocValuesField("field", new BytesRef("test")));
+                indexWriter.addDocument(document);
+                document = new Document();
+                document.add(new SortedSetDocValuesField("field", new BytesRef("best")));
+                indexWriter.addDocument(document);
+
+                try (IndexReader indexReader = maybeWrapReaderEs(DirectoryReader.open(indexWriter))) {
                     IndexSearcher indexSearcher = newIndexSearcher(indexReader);
                     MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("field");
@@ -374,7 +541,7 @@ public void testBuildAggregationsBatchReset() throws Exception {
                 aggregator.postCollection();
 
                 StringTerms firstResult = (StringTerms) aggregator.buildAggregations(new long[] { 0 })[0];
-                assertThat(firstResult.getBuckets().size(), equalTo(1));
+                assertThat(firstResult.getBuckets().size(), equalTo(2));
 
                 aggregator.doReset();
 
@@ -384,7 +551,7 @@ public void testBuildAggregationsBatchReset() throws Exception {
                 aggregator.postCollection();
 
                 StringTerms secondResult = (StringTerms) aggregator.buildAggregations(new long[] { 0 })[0];
-                assertThat(secondResult.getBuckets().size(), equalTo(1));
+                assertThat(secondResult.getBuckets().size(), equalTo(2));
                 assertThat(secondResult.getBuckets().get(0).getDocCount(), equalTo(1L));
             }
         }
@@ -431,7 +598,7 @@ public void testMultipleBatches() throws Exception {
 
     public void testSubAggregationWithMax() throws Exception {
         try (Directory directory = newDirectory()) {
-            try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
+            try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig())) {
                 Document document = new Document();
                 document.add(new SortedSetDocValuesField("category", new BytesRef("electronics")));
                 document.add(new NumericDocValuesField("price", 100));
                 indexWriter.addDocument(document);
@@ -447,7 +614,7 @@ public void testSubAggregationWithMax() throws Exception {
                 document.add(new NumericDocValuesField("price", 50));
                 indexWriter.addDocument(document);
 
-                try (IndexReader indexReader = maybeWrapReaderEs(indexWriter.getReader())) {
+                try (IndexReader indexReader = maybeWrapReaderEs(DirectoryReader.open(indexWriter))) {
                     IndexSearcher indexSearcher = newIndexSearcher(indexReader);
                     MappedFieldType categoryFieldType = new KeywordFieldMapper.KeywordFieldType("category");
                     MappedFieldType priceFieldType = new NumberFieldMapper.NumberFieldType("price", NumberFieldMapper.NumberType.LONG);
@@ -1167,6 +1334,41 @@ public void testReduceSingleAggregation() throws Exception {
         }
     }
 
+    public void testThrowOnManySegments() throws Exception {
+        try (Directory directory = newDirectory()) {
+            try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig())) {
+                for (int i = 0, numSegments = atLeast(2); i < numSegments; i++) { // evaluate atLeast(2) once for a stable bound
+                    Document doc = new Document();
+                    doc.add(new SortedSetDocValuesField("category", new BytesRef("electronics")));
+                    indexWriter.addDocument(doc);
+                    indexWriter.commit();
+                }
+                try (IndexReader reader = maybeWrapReaderEs(DirectoryReader.open(indexWriter))) {
+                    IndexSearcher searcher = newIndexSearcher(reader);
+                    MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("category");
+                    TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder("categories").field("category")
+                        .order(BucketOrder.count(false)); // order by count descending
+
+                    StreamStringTermsAggregator aggregator = createStreamAggregator(
+                        null,
+                        aggregationBuilder,
+                        searcher,
+                        createIndexSettings(),
+                        new MultiBucketConsumerService.MultiBucketConsumer(
+                            DEFAULT_MAX_BUCKETS,
+                            new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST)
+                        ),
+                        fieldType
+                    );
+
+                    // Collecting a second segment without reset() must throw
+                    aggregator.preCollection();
+                    assertThrows(IllegalStateException.class, () -> { searcher.search(new MatchAllDocsQuery(), aggregator); });
+                }
+            }
+        }
+    }
+
     private InternalAggregation buildInternalStreamingAggregation(
         TermsAggregationBuilder builder,
         MappedFieldType fieldType1,