From da0056b688cb17c78c108696f0a59d73b8d37ad8 Mon Sep 17 00:00:00 2001 From: Sandesh Kumar Date: Mon, 12 Aug 2024 19:11:21 -0700 Subject: [PATCH] Star stree request/response changes Signed-off-by: Sandesh Kumar --- .../opensearch/common/util/FeatureFlags.java | 2 +- .../index/query/QueryShardContext.java | 70 +++++ .../org/opensearch/search/SearchService.java | 55 +++- .../aggregations/AggregatorFactories.java | 6 +- .../aggregations/AggregatorFactory.java | 4 + .../metrics/NumericMetricsAggregator.java | 17 + .../aggregations/metrics/SumAggregator.java | 44 ++- .../metrics/SumAggregatorFactory.java | 2 +- .../aggregations/support/ValuesSource.java | 4 + .../ValuesSourceAggregatorFactory.java | 12 + .../startree/OriginalOrStarTreeQuery.java | 81 +++++ .../search/startree/StarTreeFilter.java | 294 ++++++++++++++++++ .../search/startree/StarTreeQuery.java | 123 ++++++++ .../search/startree/package-info.java | 9 + 14 files changed, 712 insertions(+), 11 deletions(-) create mode 100644 server/src/main/java/org/opensearch/search/startree/OriginalOrStarTreeQuery.java create mode 100644 server/src/main/java/org/opensearch/search/startree/StarTreeFilter.java create mode 100644 server/src/main/java/org/opensearch/search/startree/StarTreeQuery.java create mode 100644 server/src/main/java/org/opensearch/search/startree/package-info.java diff --git a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java index ceb2559a0e16c..e4f5e46950b5f 100644 --- a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java +++ b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java @@ -105,7 +105,7 @@ public class FeatureFlags { * aggregations. */ public static final String STAR_TREE_INDEX = "opensearch.experimental.feature.composite_index.star_tree.enabled"; - public static final Setting STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, false, Property.NodeScope); + public static final Setting STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, true, Property.NodeScope); private static final List> ALL_FEATURE_FLAG_SETTINGS = List.of( REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING, diff --git a/server/src/main/java/org/opensearch/index/query/QueryShardContext.java b/server/src/main/java/org/opensearch/index/query/QueryShardContext.java index 91313092d8d28..262937e165c8d 100644 --- a/server/src/main/java/org/opensearch/index/query/QueryShardContext.java +++ b/server/src/main/java/org/opensearch/index/query/QueryShardContext.java @@ -56,7 +56,11 @@ import org.opensearch.index.IndexSortConfig; import org.opensearch.index.analysis.IndexAnalyzers; import org.opensearch.index.cache.bitset.BitsetFilterCache; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; +import org.opensearch.index.compositeindex.datacube.Metric; +import org.opensearch.index.compositeindex.datacube.MetricStat; import org.opensearch.index.fielddata.IndexFieldData; +import org.opensearch.index.mapper.CompositeDataCubeFieldType; import org.opensearch.index.mapper.ContentPath; import org.opensearch.index.mapper.DerivedFieldResolver; import org.opensearch.index.mapper.DerivedFieldResolverFactory; @@ -73,12 +77,17 @@ import org.opensearch.script.ScriptContext; import org.opensearch.script.ScriptFactory; import org.opensearch.script.ScriptService; +import org.opensearch.search.aggregations.AggregatorFactory; +import org.opensearch.search.aggregations.metrics.SumAggregatorFactory; import org.opensearch.search.aggregations.support.AggregationUsageService; import org.opensearch.search.aggregations.support.ValuesSourceRegistry; import org.opensearch.search.lookup.SearchLookup; +import org.opensearch.search.startree.OriginalOrStarTreeQuery; +import org.opensearch.search.startree.StarTreeQuery; import org.opensearch.transport.RemoteClusterAware; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -89,6 +98,7 @@ import java.util.function.LongSupplier; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Collectors; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; @@ -522,6 +532,66 @@ private ParsedQuery toQuery(QueryBuilder queryBuilder, CheckedFunction>> predicateMap = getStarTreePredicates(queryBuilder); + StarTreeQuery starTreeQuery = new StarTreeQuery(starTree, predicateMap, null); + OriginalOrStarTreeQuery originalOrStarTreeQuery = new OriginalOrStarTreeQuery(starTreeQuery, query); + return new ParsedQuery(originalOrStarTreeQuery); + } + + /** + * Parse query body to star-tree predicates + * @param queryBuilder + * @return + */ + private Map>> getStarTreePredicates(QueryBuilder queryBuilder) { + // Assuming the following variables have been initialized: + Map>> predicateMap = new HashMap<>(); + + // Check if the query builder is an instance of TermQueryBuilder + if (queryBuilder instanceof TermQueryBuilder) { + TermQueryBuilder tq = (TermQueryBuilder) queryBuilder; + String field = tq.fieldName(); + long inputQueryVal = Long.parseLong(tq.value().toString()); + + // Get or create the list of predicates for the given field + List> predicates = predicateMap.getOrDefault(field, new ArrayList<>()); + + // Create a predicate to match the input query value + Predicate predicate = dimVal -> dimVal == inputQueryVal; + predicates.add(predicate); + + // Put the predicates list back into the map + predicateMap.put(field, predicates); + } else { + throw new IllegalArgumentException("The query is not a term query"); + } + return predicateMap; + + } + + public boolean validateStarTreeMetricSuport(CompositeDataCubeFieldType compositeIndexFieldInfo, AggregatorFactory aggregatorFactory) { + String field = null; + Map> supportedMetrics = compositeIndexFieldInfo.getMetrics() + .stream() + .collect(Collectors.toMap(Metric::getField, Metric::getMetrics)); + + // Existing support only for MetricAggregators without sub-aggregations + if (aggregatorFactory.getSubFactories().getFactories().length != 0) { + return false; + } + + // TODO: increment supported aggregation type + if (aggregatorFactory instanceof SumAggregatorFactory) { + field = ((SumAggregatorFactory) aggregatorFactory).getField(); + if (supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.SUM)) { + return true; + } + } + + return false; + } + public Index index() { return indexSettings.getIndex(); } diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index a53a7198c366f..bfb1376f0ff02 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -77,12 +77,16 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.IndexService; import org.opensearch.index.IndexSettings; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; import org.opensearch.index.engine.Engine; +import org.opensearch.index.mapper.CompositeDataCubeFieldType; import org.opensearch.index.mapper.DerivedFieldResolver; import org.opensearch.index.mapper.DerivedFieldResolverFactory; +import org.opensearch.index.mapper.StarTreeMapper; import org.opensearch.index.query.InnerHitContextBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchNoneQueryBuilder; +import org.opensearch.index.query.ParsedQuery; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; @@ -97,11 +101,13 @@ import org.opensearch.script.ScriptService; import org.opensearch.search.aggregations.AggregationInitializationException; import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.AggregatorFactory; import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.InternalAggregation.ReduceContext; import org.opensearch.search.aggregations.MultiBucketConsumerService; import org.opensearch.search.aggregations.SearchContextAggregations; import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree; +import org.opensearch.search.aggregations.support.ValuesSourceAggregatorFactory; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.collapse.CollapseContext; import org.opensearch.search.dfs.DfsPhase; @@ -1314,6 +1320,11 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc context.evaluateRequestShouldUseConcurrentSearch(); return; } + // Can be marked false for majority cases for which star-tree cannot be used + // Will save checking the criteria later and we can have a limit on what search requests are supported + // As we increment the cases where star-tree can be used, this can be set back to true + boolean canUseStarTree = context.mapperService().isCompositeIndexPresent(); + SearchShardTarget shardTarget = context.shardTarget(); QueryShardContext queryShardContext = context.getQueryShardContext(); context.from(source.from()); @@ -1339,9 +1350,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc if (source.sorts() != null) { try { Optional optionalSort = SortBuilder.buildSort(source.sorts(), context.getQueryShardContext()); - if (optionalSort.isPresent()) { - context.sort(optionalSort.get()); - } + optionalSort.ifPresent(context::sort); } catch (IOException e) { throw new SearchException(shardTarget, "failed to create sort elements", e); } @@ -1496,6 +1505,46 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc if (source.profile()) { context.setProfilers(new Profilers(context.searcher(), context.shouldUseConcurrentSearch())); } + + if (canUseStarTree) { + try { + setStarTreeQuery(context, queryShardContext, source); + logger.info("using star tree"); + } catch (IOException e) { + logger.info("not using star tree"); + } + } + } + + private boolean setStarTreeQuery(SearchContext context, QueryShardContext queryShardContext, SearchSourceBuilder source) + throws IOException { + + if (source.aggregations() == null) { + return false; + } + + // TODO: Support for multiple startrees + CompositeDataCubeFieldType compositeMappedFieldType = (StarTreeMapper.StarTreeFieldType) context.mapperService() + .getCompositeFieldTypes() + .iterator() + .next(); + CompositeIndexFieldInfo starTree = new CompositeIndexFieldInfo( + compositeMappedFieldType.name(), + compositeMappedFieldType.getCompositeIndexType() + ); + + ParsedQuery parsedQuery = queryShardContext.toStarTreeQuery(starTree, source.query(), context.query()); + AggregatorFactory aggregatorFactory = context.aggregations().factories().getFactories()[0]; + if (!(aggregatorFactory instanceof ValuesSourceAggregatorFactory + && aggregatorFactory.getSubFactories().getFactories().length == 0)) { + return false; + } + + if (queryShardContext.validateStarTreeMetricSuport(compositeMappedFieldType, aggregatorFactory)) { + context.parsedQuery(parsedQuery); + } + + return true; } /** diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java index eeb0c606694b0..dfcb245ef3656 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java @@ -255,7 +255,7 @@ public static Builder builder() { return new Builder(); } - private AggregatorFactories(AggregatorFactory[] factories) { + public AggregatorFactories(AggregatorFactory[] factories) { this.factories = factories; } @@ -661,4 +661,8 @@ public PipelineTree buildPipelineTree() { return new PipelineTree(subTrees, aggregators); } } + + public AggregatorFactory[] getFactories() { + return factories; + } } diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactory.java index 6cc3a78fb1e36..86fbb46a9ad3c 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactory.java @@ -127,4 +127,8 @@ protected boolean supportsConcurrentSegmentSearch() { public boolean evaluateChildFactories() { return factories.allFactoriesSupportConcurrentSearch(); } + + public AggregatorFactories getSubFactories() { + return factories; + } } diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/NumericMetricsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/NumericMetricsAggregator.java index f90e5a092385f..9802e3fdae50d 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/NumericMetricsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/NumericMetricsAggregator.java @@ -31,13 +31,20 @@ package org.opensearch.search.aggregations.metrics; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentReader; +import org.opensearch.common.lucene.Lucene; import org.opensearch.common.util.Comparators; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; +import org.opensearch.index.codec.composite.CompositeIndexReader; +import org.opensearch.index.codec.composite.datacube.startree.StarTreeValues; import org.opensearch.search.aggregations.Aggregator; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.sort.SortOrder; import java.io.IOException; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; /** * Base class to aggregate all docs into a single numeric metric value. @@ -107,4 +114,14 @@ public BucketComparator bucketComparator(String key, SortOrder order) { return (lhs, rhs) -> Comparators.compareDiscardNaN(metric(key, lhs), metric(key, rhs), order == SortOrder.ASC); } } + + protected StarTreeValues getStarTreeValues(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException { + SegmentReader reader = Lucene.segmentReader(ctx.reader()); + if (!(reader.getDocValuesReader() instanceof CompositeIndexReader)) return null; + CompositeIndexReader starTreeDocValuesReader = (CompositeIndexReader) reader.getDocValuesReader(); + StarTreeValues values = (StarTreeValues) starTreeDocValuesReader.getCompositeIndexValues(starTree); + final AtomicReference aggrVal = new AtomicReference<>(null); + + return values; + } } diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java index 4b8e882cd69bc..aa71224588c2c 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java @@ -36,6 +36,8 @@ import org.opensearch.common.lease.Releasables; import org.opensearch.common.util.BigArrays; import org.opensearch.common.util.DoubleArray; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; +import org.opensearch.index.codec.composite.datacube.startree.StarTreeValues; import org.opensearch.index.fielddata.SortedNumericDoubleValues; import org.opensearch.search.DocValueFormat; import org.opensearch.search.aggregations.Aggregator; @@ -45,6 +47,8 @@ import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.startree.OriginalOrStarTreeQuery; +import org.opensearch.search.startree.StarTreeQuery; import java.io.IOException; import java.util.Map; @@ -56,13 +60,13 @@ */ public class SumAggregator extends NumericMetricsAggregator.SingleValue { - private final ValuesSource.Numeric valuesSource; - private final DocValueFormat format; + protected final ValuesSource.Numeric valuesSource; + protected final DocValueFormat format; - private DoubleArray sums; - private DoubleArray compensations; + protected DoubleArray sums; + protected DoubleArray compensations; - SumAggregator( + public SumAggregator( String name, ValuesSourceConfig valuesSourceConfig, SearchContext context, @@ -86,6 +90,14 @@ public ScoreMode scoreMode() { @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { + if (context.query() instanceof OriginalOrStarTreeQuery && ((OriginalOrStarTreeQuery) context.query()).isStarTreeUsed()) { + StarTreeQuery starTreeQuery = ((OriginalOrStarTreeQuery) context.query()).getStarTreeQuery(); + return getStarTreeLeafCollector(ctx, sub, starTreeQuery.getStarTree()); + } + return getDefaultLeafCollector(ctx, sub); + } + + private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { if (valuesSource == null) { return LeafBucketCollector.NO_OP_COLLECTOR; } @@ -118,6 +130,28 @@ public void collect(int doc, long bucket) throws IOException { }; } + private LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) + throws IOException { + final BigArrays bigArrays = context.bigArrays(); + final CompensatedSum kahanSummation = new CompensatedSum(0, 0); + + StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree); + + // + String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName(); + + return new LeafBucketCollectorBase(sub, starTreeValues) { + @Override + public void collect(int doc, long bucket) throws IOException { + // TODO: Fix the response for collecting star tree sum + sums = bigArrays.grow(sums, bucket + 1); + compensations = bigArrays.grow(compensations, bucket + 1); + compensations.set(bucket, kahanSummation.delta()); + sums.set(bucket, kahanSummation.value()); + } + }; + } + @Override public double metric(long owningBucketOrd) { if (valuesSource == null || owningBucketOrd >= sums.size()) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregatorFactory.java index ef9b93920ba18..e0cd44f2672a8 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregatorFactory.java @@ -52,7 +52,7 @@ * * @opensearch.internal */ -class SumAggregatorFactory extends ValuesSourceAggregatorFactory { +public class SumAggregatorFactory extends ValuesSourceAggregatorFactory { SumAggregatorFactory( String name, diff --git a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java index 1f4dd429e094e..5732d545cb2d2 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java +++ b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java @@ -625,6 +625,10 @@ public SortedNumericDocValues longValues(LeafReaderContext context) { public SortedNumericDoubleValues doubleValues(LeafReaderContext context) { return indexFieldData.load(context).getDoubleValues(); } + + public String getIndexFieldName() { + return indexFieldData.getFieldName(); + } } /** diff --git a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java index 69a4a5d8b6703..b19e466b081f9 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java @@ -102,4 +102,16 @@ protected abstract Aggregator doCreateInternal( public String getStatsSubtype() { return config.valueSourceType().typeName(); } + + public String getField() { + return config.fieldContext().field(); + } + + public String getAggregationName() { + return name; + } + + public ValuesSourceConfig getConfig() { + return config; + } } diff --git a/server/src/main/java/org/opensearch/search/startree/OriginalOrStarTreeQuery.java b/server/src/main/java/org/opensearch/search/startree/OriginalOrStarTreeQuery.java new file mode 100644 index 0000000000000..bc8ef51bfb537 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/startree/OriginalOrStarTreeQuery.java @@ -0,0 +1,81 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.startree; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Accountable; + +import java.io.IOException; + +/** + * Preserves star-tree queries which can be used along with original query + * Decides which star-tree query to use (or not) based on cost factors + */ +public class OriginalOrStarTreeQuery extends Query implements Accountable { + + private final StarTreeQuery starTreeQuery; + private final Query originalQuery; + private boolean starTreeQueryUsed; + + public OriginalOrStarTreeQuery(StarTreeQuery starTreeQuery, Query originalQuery) { + this.starTreeQuery = starTreeQuery; + this.originalQuery = originalQuery; + this.starTreeQueryUsed = false; + } + + @Override + public String toString(String s) { + return ""; + } + + @Override + public void visit(QueryVisitor queryVisitor) { + + } + + @Override + public boolean equals(Object o) { + return false; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public long ramBytesUsed() { + return 0; + } + + public boolean isStarTreeUsed() { + return starTreeQueryUsed; + } + + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + if (searcher.getIndexReader().hasDeletions() == false) { + this.starTreeQueryUsed = true; + return this.starTreeQuery.createWeight(searcher, scoreMode, boost); + } else { + return this.originalQuery.createWeight(searcher, scoreMode, boost); + } + } + + public Query getOriginalQuery() { + return originalQuery; + } + + public StarTreeQuery getStarTreeQuery() { + return starTreeQuery; + } +} diff --git a/server/src/main/java/org/opensearch/search/startree/StarTreeFilter.java b/server/src/main/java/org/opensearch/search/startree/StarTreeFilter.java new file mode 100644 index 0000000000000..a3a000703a3b8 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/startree/StarTreeFilter.java @@ -0,0 +1,294 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.startree; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.DocIdSetBuilder; +import org.opensearch.index.codec.composite.datacube.startree.StarTreeValues; +import org.opensearch.index.compositeindex.datacube.startree.node.StarTreeNode; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.function.Predicate; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +// TODO: Code inspired from a POC - need to review this once +/** Filter operator for star tree data structure. */ +public class StarTreeFilter { + private static final Logger logger = LogManager.getLogger(StarTreeFilter.class); + + /** Helper class to wrap the result from traversing the star tree. */ + static class StarTreeResult { + final DocIdSetBuilder _matchedDocIds; + final Set _remainingPredicateColumns; + final int numOfMatchedDocs; + final int maxMatchedDoc; + + StarTreeResult(DocIdSetBuilder matchedDocIds, Set remainingPredicateColumns, int numOfMatchedDocs, int maxMatchedDoc) { + _matchedDocIds = matchedDocIds; + _remainingPredicateColumns = remainingPredicateColumns; + this.numOfMatchedDocs = numOfMatchedDocs; + this.maxMatchedDoc = maxMatchedDoc; + } + } + + private final StarTreeNode starTreeRoot; + + Map>> _predicateEvaluators; + // private final List _groupByColumns; + + DocIdSetBuilder docsWithField; + + DocIdSetBuilder.BulkAdder adder; + Map dimValueMap; + + public StarTreeFilter( + StarTreeValues starTreeAggrStructure, + Map>> predicateEvaluators, + List groupByColumns + ) { + // This filter operator does not support AND/OR/NOT operations. + starTreeRoot = starTreeAggrStructure.getRoot(); + dimValueMap = starTreeAggrStructure.getDimensionDocValuesIteratorMap(); + _predicateEvaluators = predicateEvaluators != null ? predicateEvaluators : Collections.emptyMap(); + // _groupByColumns = groupByColumns != null ? groupByColumns : Collections.emptyList(); + + // TODO : this should be the maximum number of doc values + docsWithField = new DocIdSetBuilder(Integer.MAX_VALUE); + } + + /** + *
    + *
  • First go over the star tree and try to match as many dimensions as possible + *
  • For the remaining columns, use doc values indexes to match them + *
+ */ + public DocIdSetIterator getStarTreeResult() throws IOException { + StarTreeResult starTreeResult = traverseStarTree(); + List andIterators = new ArrayList<>(); + andIterators.add(starTreeResult._matchedDocIds.build().iterator()); + DocIdSetIterator docIdSetIterator = andIterators.get(0); + // No matches, return + if (starTreeResult.maxMatchedDoc == -1) { + return docIdSetIterator; + } + int docCount = 0; + for (String remainingPredicateColumn : starTreeResult._remainingPredicateColumns) { + // TODO : set to max value of doc values + logger.info("remainingPredicateColumn : {}, maxMatchedDoc : {} ", remainingPredicateColumn, starTreeResult.maxMatchedDoc); + DocIdSetBuilder builder = new DocIdSetBuilder(starTreeResult.maxMatchedDoc + 1); + List> compositePredicateEvaluators = _predicateEvaluators.get(remainingPredicateColumn); + SortedNumericDocValues ndv = (SortedNumericDocValues) this.dimValueMap.get(remainingPredicateColumn); + List docIds = new ArrayList<>(); + while (docIdSetIterator.nextDoc() != NO_MORE_DOCS) { + docCount++; + int docID = docIdSetIterator.docID(); + if (ndv.advanceExact(docID)) { + final int valuesCount = ndv.docValueCount(); + long value = ndv.nextValue(); + for (Predicate compositePredicateEvaluator : compositePredicateEvaluators) { + // TODO : this might be expensive as its done against all doc values docs + if (compositePredicateEvaluator.test(value)) { + docIds.add(docID); + for (int i = 0; i < valuesCount - 1; i++) { + while (docIdSetIterator.nextDoc() != NO_MORE_DOCS) { + docIds.add(docIdSetIterator.docID()); + } + } + break; + } + } + } + } + DocIdSetBuilder.BulkAdder adder = builder.grow(docIds.size()); + for (int docID : docIds) { + adder.add(docID); + } + docIdSetIterator = builder.build().iterator(); + } + return docIdSetIterator; + } + + /** + * Helper method to traverse the star tree, get matching documents and keep track of all the + * predicate dimensions that are not matched. + */ + private StarTreeResult traverseStarTree() throws IOException { + Set globalRemainingPredicateColumns = null; + + StarTreeNode starTree = starTreeRoot; + + List dimensionNames = new ArrayList<>(dimValueMap.keySet()); + + // Track whether we have found a leaf node added to the queue. If we have found a leaf node, and + // traversed to the + // level of the leave node, we can set globalRemainingPredicateColumns if not already set + // because we know the leaf + // node won't split further on other predicate columns. + boolean foundLeafNode = starTree.isLeaf(); + + // Use BFS to traverse the star tree + Queue queue = new ArrayDeque<>(); + queue.add(starTree); + int currentDimensionId = -1; + Set remainingPredicateColumns = new HashSet<>(_predicateEvaluators.keySet()); + // Set remainingGroupByColumns = new HashSet<>(_groupByColumns); + if (foundLeafNode) { + globalRemainingPredicateColumns = new HashSet<>(remainingPredicateColumns); + } + + int matchedDocsCountInStarTree = 0; + int maxDocNum = -1; + + StarTreeNode starTreeNode; + List docIds = new ArrayList<>(); + while ((starTreeNode = queue.poll()) != null) { + int dimensionId = starTreeNode.getDimensionId(); + if (dimensionId > currentDimensionId) { + // Previous level finished + String dimension = dimensionNames.get(dimensionId); + remainingPredicateColumns.remove(dimension); + // remainingGroupByColumns.remove(dimension); + if (foundLeafNode && globalRemainingPredicateColumns == null) { + globalRemainingPredicateColumns = new HashSet<>(remainingPredicateColumns); + } + currentDimensionId = dimensionId; + } + + // If all predicate columns columns are matched, we can use aggregated document + if (remainingPredicateColumns.isEmpty()) { + int docId = starTreeNode.getAggregatedDocId(); + docIds.add(docId); + matchedDocsCountInStarTree++; + maxDocNum = Math.max(docId, maxDocNum); + continue; + } + + // For leaf node, because we haven't exhausted all predicate columns and group-by columns, we + // cannot use + // the aggregated document. Add the range of documents for this node to the bitmap, and keep + // track of the + // remaining predicate columns for this node + if (starTreeNode.isLeaf()) { + for (long i = starTreeNode.getStartDocId(); i < starTreeNode.getEndDocId(); i++) { + docIds.add((int) i); + matchedDocsCountInStarTree++; + maxDocNum = Math.max((int) i, maxDocNum); + } + continue; + } + + // For non-leaf node, proceed to next level + String childDimension = dimensionNames.get(dimensionId + 1); + + // Only read star-node when the dimension is not in the global remaining predicate columns or + // group-by columns + // because we cannot use star-node in such cases + StarTreeNode starNode = null; + if ((globalRemainingPredicateColumns == null || !globalRemainingPredicateColumns.contains(childDimension))) { + starNode = starTreeNode.getChildForDimensionValue(StarTreeNode.ALL); + } + + if (remainingPredicateColumns.contains(childDimension)) { + // Have predicates on the next level, add matching nodes to the queue + + // Calculate the matching dictionary ids for the child dimension + int numChildren = starTreeNode.getNumChildren(); + + // If number of matching dictionary ids is large, use scan instead of binary search + + Iterator childrenIterator = starTreeNode.getChildrenIterator(); + + // When the star-node exists, and the number of matching doc ids is more than or equal to + // the + // number of non-star child nodes, check if all the child nodes match the predicate, and use + // the star-node if so + if (starNode != null) { + List matchingChildNodes = new ArrayList<>(); + boolean findLeafChildNode = false; + while (childrenIterator.hasNext()) { + StarTreeNode childNode = childrenIterator.next(); + List> predicates = _predicateEvaluators.get(childDimension); + for (Predicate predicate : predicates) { + long val = childNode.getDimensionValue(); + if (predicate.test(val)) { + matchingChildNodes.add(childNode); + findLeafChildNode |= childNode.isLeaf(); + break; + } + } + } + if (matchingChildNodes.size() == numChildren - 1) { + // All the child nodes (except for the star-node) match the predicate, use the star-node + queue.add(starNode); + foundLeafNode |= starNode.isLeaf(); + } else { + // Some child nodes do not match the predicate, use the matching child nodes + queue.addAll(matchingChildNodes); + foundLeafNode |= findLeafChildNode; + } + } else { + // Cannot use the star-node, use the matching child nodes + while (childrenIterator.hasNext()) { + StarTreeNode childNode = childrenIterator.next(); + List> predicates = _predicateEvaluators.get(childDimension); + for (Predicate predicate : predicates) { + if (predicate.test(childNode.getDimensionValue())) { + queue.add(childNode); + foundLeafNode |= childNode.isLeaf(); + break; + } + } + } + } + } else { + // No predicate on the next level + + if (starNode != null) { + // Star-node exists, use it + queue.add(starNode); + foundLeafNode |= starNode.isLeaf(); + } else { + // Star-node does not exist or cannot be used, add all non-star nodes to the queue + Iterator childrenIterator = starTreeNode.getChildrenIterator(); + while (childrenIterator.hasNext()) { + StarTreeNode childNode = childrenIterator.next(); + if (childNode.getDimensionValue() != StarTreeNode.ALL) { + queue.add(childNode); + foundLeafNode |= childNode.isLeaf(); + } + } + } + } + } + + adder = docsWithField.grow(docIds.size()); + for (int id : docIds) { + adder.add(id); + } + return new StarTreeResult( + docsWithField, + globalRemainingPredicateColumns != null ? globalRemainingPredicateColumns : Collections.emptySet(), + matchedDocsCountInStarTree, + maxDocNum + ); + } +} diff --git a/server/src/main/java/org/opensearch/search/startree/StarTreeQuery.java b/server/src/main/java/org/opensearch/search/startree/StarTreeQuery.java new file mode 100644 index 0000000000000..94b293335e3cc --- /dev/null +++ b/server/src/main/java/org/opensearch/search/startree/StarTreeQuery.java @@ -0,0 +1,123 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.startree; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.search.ConstantScoreScorer; +import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Accountable; +import org.opensearch.common.lucene.Lucene; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; +import org.opensearch.index.codec.composite.CompositeIndexReader; +import org.opensearch.index.codec.composite.datacube.startree.StarTreeValues; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +/** Query class for querying star tree data structure */ +public class StarTreeQuery extends Query implements Accountable { + + /** + * Star tree field info + * This is used to get the star tree data structure + */ + CompositeIndexFieldInfo starTree; + + /** + * Map of field name to a list of predicates to be applied on that field + * This is used to filter the data based on the predicates + */ + Map>> compositePredicateMap; + + /** + * Set of field names to be used for grouping the results + * This is used to group the data based on the fields + */ + List groupByColumns; + + public StarTreeQuery( + CompositeIndexFieldInfo starTree, + Map>> compositePredicateMap, + List groupByColumns + ) { + this.starTree = starTree; + this.compositePredicateMap = compositePredicateMap; + this.groupByColumns = groupByColumns; + } + + @Override + public String toString(String field) { + return null; + } + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + + @Override + public boolean equals(Object obj) { + return sameClassAs(obj); + } + + @Override + public int hashCode() { + return classHash(); + } + + @Override + public long ramBytesUsed() { + return 0; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return new ConstantScoreWeight(this, boost) { + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + SegmentReader reader = Lucene.segmentReader(context.reader()); + + // We get the 'CompositeIndexReader' instance so that we can get StarTreeValues + if (!(reader.getDocValuesReader() instanceof CompositeIndexReader)) return null; + + CompositeIndexReader starTreeDocValuesReader = (CompositeIndexReader) reader.getDocValuesReader(); + List compositeIndexFields = starTreeDocValuesReader.getCompositeIndexFields(); + StarTreeValues starTreeValues = null; + if (compositeIndexFields != null && !compositeIndexFields.isEmpty()) { + starTreeValues = (StarTreeValues) starTreeDocValuesReader.getCompositeIndexValues(starTree); + } else { + return null; + } + + StarTreeFilter filter = new StarTreeFilter(starTreeValues, compositePredicateMap, groupByColumns); + DocIdSetIterator result = filter.getStarTreeResult(); + return new ConstantScoreScorer(this, score(), scoreMode, result); + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + }; + } + + public CompositeIndexFieldInfo getStarTree() { + return starTree; + } +} diff --git a/server/src/main/java/org/opensearch/search/startree/package-info.java b/server/src/main/java/org/opensearch/search/startree/package-info.java new file mode 100644 index 0000000000000..cb0802988f1f9 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/startree/package-info.java @@ -0,0 +1,9 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.startree;