2020import org .apache .lucene .store .Directory ;
2121import org .apache .lucene .tests .index .RandomIndexWriter ;
2222import org .apache .lucene .util .BytesRef ;
23+ import org .opensearch .action .OriginalIndices ;
24+ import org .opensearch .action .search .SearchShardTask ;
25+ import org .opensearch .action .support .StreamSearchChannelListener ;
2326import org .opensearch .common .settings .Settings ;
2427import org .opensearch .common .util .MockBigArrays ;
2528import org .opensearch .common .util .MockPageCacheRecycler ;
2629import org .opensearch .core .common .breaker .CircuitBreaker ;
30+ import org .opensearch .core .index .shard .ShardId ;
2731import org .opensearch .core .indices .breaker .NoneCircuitBreakerService ;
32+ import org .opensearch .core .transport .TransportResponse ;
2833import org .opensearch .index .mapper .KeywordFieldMapper ;
2934import org .opensearch .index .mapper .MappedFieldType ;
3035import org .opensearch .index .mapper .NumberFieldMapper ;
36+ import org .opensearch .search .SearchShardTarget ;
37+ import org .opensearch .search .aggregations .Aggregator ;
3138import org .opensearch .search .aggregations .AggregatorTestCase ;
3239import org .opensearch .search .aggregations .BucketOrder ;
3340import org .opensearch .search .aggregations .InternalAggregation ;
41+ import org .opensearch .search .aggregations .InternalAggregations ;
3442import org .opensearch .search .aggregations .MultiBucketConsumerService ;
3543import org .opensearch .search .aggregations .metrics .Avg ;
3644import org .opensearch .search .aggregations .metrics .AvgAggregationBuilder ;
4351import org .opensearch .search .aggregations .metrics .ValueCount ;
4452import org .opensearch .search .aggregations .metrics .ValueCountAggregationBuilder ;
4553import org .opensearch .search .aggregations .pipeline .PipelineAggregator .PipelineTree ;
54+ import org .opensearch .search .fetch .FetchSearchResult ;
55+ import org .opensearch .search .fetch .QueryFetchSearchResult ;
56+ import org .opensearch .search .internal .ContextIndexSearcher ;
57+ import org .opensearch .search .internal .SearchContext ;
58+ import org .opensearch .search .profile .Timer ;
59+ import org .opensearch .search .profile .aggregation .AggregationProfileBreakdown ;
60+ import org .opensearch .search .profile .aggregation .AggregationProfiler ;
61+ import org .opensearch .search .profile .aggregation .ProfilingAggregator ;
62+ import org .opensearch .search .query .QuerySearchResult ;
63+ import org .opensearch .search .streaming .FlushMode ;
4664import org .opensearch .search .streaming .Streamable ;
4765import org .opensearch .search .streaming .StreamingCostMetrics ;
4866
5573import java .util .function .BiConsumer ;
5674
5775import static org .opensearch .test .InternalAggregationTestCase .DEFAULT_MAX_BUCKETS ;
76+ import static org .hamcrest .Matchers .empty ;
5877import static org .hamcrest .Matchers .equalTo ;
5978import static org .hamcrest .Matchers .instanceOf ;
6079import static org .hamcrest .Matchers .lessThan ;
6180import static org .hamcrest .Matchers .lessThanOrEqualTo ;
81+ import static org .hamcrest .Matchers .not ;
6282import static org .hamcrest .Matchers .notNullValue ;
83+ import static org .mockito .ArgumentMatchers .any ;
84+ import static org .mockito .ArgumentMatchers .anyBoolean ;
85+ import static org .mockito .Mockito .doAnswer ;
86+ import static org .mockito .Mockito .mock ;
87+ import static org .mockito .Mockito .when ;
6388
6489public class StreamStringTermsAggregatorTests extends AggregatorTestCase {
6590 public void testBuildAggregationsBatchDirectBucketCreation () throws Exception {
@@ -343,14 +368,156 @@ public void testBuildAggregationsBatchWithCountOrder() throws Exception {
343368 }
344369 }
345370
346- public void testBuildAggregationsBatchReset () throws Exception {
371+ public void testBuildAggregationsWithContextSearcherNoProfile () throws Exception {
372+ doAggOverManySegments (false );
373+ }
374+
375+ public void testBuildAggregationsWithContextSearcherProfile () throws Exception {
376+ doAggOverManySegments (true );
377+ }
378+
379+ private void doAggOverManySegments (boolean profile ) throws IOException {
347380 try (Directory directory = newDirectory ()) {
348381 try (RandomIndexWriter indexWriter = new RandomIndexWriter (random (), directory )) {
382+ boolean isSegmented = false ;
383+ for (int i = 0 ; i < 3 ; i ++) {
384+ Document document = new Document ();
385+ document .add (new SortedSetDocValuesField ("field" , new BytesRef ("common" )));
386+ indexWriter .addDocument (document );
387+ if (rarely ()) {
388+ indexWriter .flush ();
389+ isSegmented = true ;
390+ }
391+ }
392+ indexWriter .flush ();
393+ for (int i = 0 ; i < 2 ; i ++) {
394+ Document document = new Document ();
395+ document .add (new SortedSetDocValuesField ("field" , new BytesRef ("medium" )));
396+ indexWriter .addDocument (document );
397+ if (rarely ()) {
398+ indexWriter .flush ();
399+ isSegmented = true ;
400+ }
401+ }
402+
403+ if (!isSegmented ) {
404+ indexWriter .flush ();
405+ }
406+
349407 Document document = new Document ();
350- document .add (new SortedSetDocValuesField ("field" , new BytesRef ("test " )));
408+ document .add (new SortedSetDocValuesField ("field" , new BytesRef ("rare " )));
351409 indexWriter .addDocument (document );
352410
353411 try (IndexReader indexReader = maybeWrapReaderEs (indexWriter .getReader ())) {
412+ IndexSearcher indexSearcher = newIndexSearcher (indexReader );
413+ SearchContext searchContext = createSearchContext (
414+ indexSearcher ,
415+ createIndexSettings (),
416+ null ,
417+ new MultiBucketConsumerService .MultiBucketConsumer (
418+ MultiBucketConsumerService .DEFAULT_MAX_BUCKETS ,
419+ new NoneCircuitBreakerService ().getBreaker (CircuitBreaker .REQUEST )
420+ ),
421+ new NumberFieldMapper .NumberFieldType ("test" , NumberFieldMapper .NumberType .INTEGER )
422+ );
423+ when (searchContext .isStreamSearch ()).thenReturn (true );
424+ when (searchContext .getFlushMode ()).thenReturn (FlushMode .PER_SEGMENT );
425+ SearchShardTarget searchShardTarget = new SearchShardTarget (
426+ "node_1" ,
427+ new ShardId ("foo" , "_na_" , 1 ),
428+ null ,
429+ OriginalIndices .NONE
430+ );
431+ when (searchContext .shardTarget ()).thenReturn (searchShardTarget );
432+ SearchShardTask task = new SearchShardTask (0 , "n/a" , "n/a" , "test-kind" , null , null );
433+ searchContext .setTask (task );
434+ when (searchContext .queryResult ()).thenReturn (new QuerySearchResult ());
435+ when (searchContext .fetchResult ()).thenReturn (new FetchSearchResult ());
436+ StreamSearchChannelListener listenerMock = mock (StreamSearchChannelListener .class );
437+ final List <InternalAggregations > perSegAggs = new ArrayList <>();
438+ when (searchContext .getStreamChannelListener ()).thenReturn (listenerMock );
439+ doAnswer ((invok ) -> {
440+ QuerySearchResult querySearchResult = ((QueryFetchSearchResult ) invok .getArgument (0 , TransportResponse .class ))
441+ .queryResult ();
442+ InternalAggregations internalAggregations = querySearchResult .aggregations ().expand ();
443+ perSegAggs .add (internalAggregations );
444+ return null ;
445+ }).when (listenerMock ).onStreamResponse (any (), anyBoolean ());
446+ ContextIndexSearcher contextIndexSearcher = searchContext .searcher ();
447+
448+ MappedFieldType fieldType = new KeywordFieldMapper .KeywordFieldType ("field" );
449+
450+ TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder ("test" ).field ("field" )
451+ .order (BucketOrder .count (false ));
452+
453+ Aggregator aggregator = createStreamAggregator (
454+ null ,
455+ aggregationBuilder ,
456+ indexSearcher ,
457+ createIndexSettings (),
458+ new MultiBucketConsumerService .MultiBucketConsumer (
459+ DEFAULT_MAX_BUCKETS ,
460+ new NoneCircuitBreakerService ().getBreaker (CircuitBreaker .REQUEST )
461+ ),
462+ fieldType
463+ );
464+
465+ if (profile ) {
466+ aggregator = wrapByProfilingAgg (aggregator );
467+ }
468+
469+ aggregator .preCollection ();
470+
471+ contextIndexSearcher .search (new MatchAllDocsQuery (), aggregator );
472+ aggregator .postCollection ();
473+
474+ InternalAggregation .ReduceContext ctx = InternalAggregation .ReduceContext .forFinalReduction (
475+ new MockBigArrays (new MockPageCacheRecycler (Settings .EMPTY ), new NoneCircuitBreakerService ()),
476+ getMockScriptService (),
477+ b -> {},
478+ PipelineTree .EMPTY
479+ );
480+
481+ assertThat (perSegAggs , not (empty ()));
482+ InternalAggregations summary = InternalAggregations .reduce (perSegAggs , ctx );
483+
484+ StringTerms result = summary .get ("test" );
485+
486+ assertThat (result , notNullValue ());
487+ assertThat (result .getBuckets ().size (), equalTo (3 ));
488+
489+ List <StringTerms .Bucket > buckets = result .getBuckets ();
490+ assertThat (buckets .get (0 ).getKeyAsString (), equalTo ("common" ));
491+ assertThat (buckets .get (0 ).getDocCount (), equalTo (3L ));
492+ assertThat (buckets .get (1 ).getKeyAsString (), equalTo ("medium" ));
493+ assertThat (buckets .get (1 ).getDocCount (), equalTo (2L ));
494+ assertThat (buckets .get (2 ).getKeyAsString (), equalTo ("rare" ));
495+ assertThat (buckets .get (2 ).getDocCount (), equalTo (1L ));
496+ }
497+ }
498+ }
499+ }
500+
501+ private static Aggregator wrapByProfilingAgg (Aggregator aggregator ) throws IOException {
502+ AggregationProfiler aggregationProfiler = mock (AggregationProfiler .class );
503+ AggregationProfileBreakdown aggregationProfileBreakdown = mock (AggregationProfileBreakdown .class );
504+ when (aggregationProfileBreakdown .getTimer (any ())).thenReturn (mock (Timer .class ));
505+ when (aggregationProfiler .getQueryBreakdown (any ())).thenReturn (aggregationProfileBreakdown );
506+ aggregator = new ProfilingAggregator (aggregator , aggregationProfiler );
507+ return aggregator ;
508+ }
509+
510+ public void testBuildAggregationsBatchReset () throws Exception {
511+ try (Directory directory = newDirectory ()) {
512+ try (IndexWriter indexWriter = new IndexWriter (directory , new IndexWriterConfig ())) {
513+ Document document = new Document ();
514+ document .add (new SortedSetDocValuesField ("field" , new BytesRef ("test" )));
515+ indexWriter .addDocument (document );
516+ document = new Document ();
517+ document .add (new SortedSetDocValuesField ("field" , new BytesRef ("best" )));
518+ indexWriter .addDocument (document );
519+
520+ try (IndexReader indexReader = maybeWrapReaderEs (DirectoryReader .open (indexWriter ))) {
354521 IndexSearcher indexSearcher = newIndexSearcher (indexReader );
355522 MappedFieldType fieldType = new KeywordFieldMapper .KeywordFieldType ("field" );
356523
@@ -374,7 +541,7 @@ public void testBuildAggregationsBatchReset() throws Exception {
374541 aggregator .postCollection ();
375542
376543 StringTerms firstResult = (StringTerms ) aggregator .buildAggregations (new long [] { 0 })[0 ];
377- assertThat (firstResult .getBuckets ().size (), equalTo (1 ));
544+ assertThat (firstResult .getBuckets ().size (), equalTo (2 ));
378545
379546 aggregator .doReset ();
380547
@@ -384,7 +551,7 @@ public void testBuildAggregationsBatchReset() throws Exception {
384551 aggregator .postCollection ();
385552
386553 StringTerms secondResult = (StringTerms ) aggregator .buildAggregations (new long [] { 0 })[0 ];
387- assertThat (secondResult .getBuckets ().size (), equalTo (1 ));
554+ assertThat (secondResult .getBuckets ().size (), equalTo (2 ));
388555 assertThat (secondResult .getBuckets ().get (0 ).getDocCount (), equalTo (1L ));
389556 }
390557 }
@@ -431,7 +598,7 @@ public void testMultipleBatches() throws Exception {
431598
432599 public void testSubAggregationWithMax () throws Exception {
433600 try (Directory directory = newDirectory ()) {
434- try (RandomIndexWriter indexWriter = new RandomIndexWriter ( random (), directory )) {
601+ try (IndexWriter indexWriter = new IndexWriter ( directory , new IndexWriterConfig () )) {
435602 Document document = new Document ();
436603 document .add (new SortedSetDocValuesField ("category" , new BytesRef ("electronics" )));
437604 document .add (new NumericDocValuesField ("price" , 100 ));
@@ -447,7 +614,7 @@ public void testSubAggregationWithMax() throws Exception {
447614 document .add (new NumericDocValuesField ("price" , 50 ));
448615 indexWriter .addDocument (document );
449616
450- try (IndexReader indexReader = maybeWrapReaderEs (indexWriter . getReader ( ))) {
617+ try (IndexReader indexReader = maybeWrapReaderEs (DirectoryReader . open ( indexWriter ))) {
451618 IndexSearcher indexSearcher = newIndexSearcher (indexReader );
452619 MappedFieldType categoryFieldType = new KeywordFieldMapper .KeywordFieldType ("category" );
453620 MappedFieldType priceFieldType = new NumberFieldMapper .NumberFieldType ("price" , NumberFieldMapper .NumberType .LONG );
@@ -1167,6 +1334,41 @@ public void testReduceSingleAggregation() throws Exception {
11671334 }
11681335 }
11691336
1337+ public void testThrowOnManySegments () throws Exception {
1338+ try (Directory directory = newDirectory ()) {
1339+ try (IndexWriter indexWriter = new IndexWriter (directory , new IndexWriterConfig ())) {
1340+ for (int i = 0 ; i < atLeast (2 ); i ++) {
1341+ Document doc = new Document ();
1342+ doc .add (new SortedSetDocValuesField ("category" , new BytesRef ("electronics" )));
1343+ indexWriter .addDocument (doc );
1344+ indexWriter .commit ();
1345+ }
1346+ try (IndexReader reader = maybeWrapReaderEs (DirectoryReader .open (indexWriter ))) {
1347+ IndexSearcher searcher = newIndexSearcher (reader );
1348+ MappedFieldType fieldType = new KeywordFieldMapper .KeywordFieldType ("category" );
1349+ TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder ("categories" ).field ("category" )
1350+ .order (BucketOrder .count (false )); // Order by count descending
1351+
1352+ StreamStringTermsAggregator aggregator = createStreamAggregator (
1353+ null ,
1354+ aggregationBuilder ,
1355+ searcher ,
1356+ createIndexSettings (),
1357+ new MultiBucketConsumerService .MultiBucketConsumer (
1358+ DEFAULT_MAX_BUCKETS ,
1359+ new NoneCircuitBreakerService ().getBreaker (CircuitBreaker .REQUEST )
1360+ ),
1361+ fieldType
1362+ );
1363+
1364+ // Execute the aggregator
1365+ aggregator .preCollection ();
1366+ assertThrows (IllegalStateException .class , () -> { searcher .search (new MatchAllDocsQuery (), aggregator ); });
1367+ }
1368+ }
1369+ }
1370+ }
1371+
11701372 private InternalAggregation buildInternalStreamingAggregation (
11711373 TermsAggregationBuilder builder ,
11721374 MappedFieldType fieldType1 ,
0 commit comments