Skip to content

Commit

Permalink
Top & Rare PPL commands support (#568) (#584)
Browse files Browse the repository at this point in the history
* Adding support for Rare & Top PPL

top [N] <field-list> [by-clause]

N: number of results to return. Default: 10
field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------

rare <field-list> [by-clause]

field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------
commands:
 - #461
 - #536


* Adding support for Rare & Top PPL

top [N] <field-list> [by-clause]

N: number of results to return. Default: 10
field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------

rare <field-list> [by-clause]

field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------
commands:
 - #461
 - #536


* Adding support for Rare & Top PPL

top [N] <field-list> [by-clause]

N: number of results to return. Default: 10
field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------

rare <field-list> [by-clause]

field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------
commands:
 - #461
 - #536


* Adding support for Rare & Top PPL

top [N] <field-list> [by-clause]

N: number of results to return. Default: 10
field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------

rare <field-list> [by-clause]

field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------
commands:
 - #461
 - #536


* Adding support for Rare & Top PPL

top [N] <field-list> [by-clause]

N: number of results to return. Default: 10
field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------

rare <field-list> [by-clause]

field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------
commands:
 - #461
 - #536


* Adding support for Rare & Top PPL

top [N] <field-list> [by-clause]

N: number of results to return. Default: 10
field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------

rare <field-list> [by-clause]

field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------
commands:
 - #461
 - #536


* update scala fmt style




* add additional support for `rare` & `top` commands options



* add additional support for `rare` & `top` commands options including top N ...



* update scalafmtAll style format



* remove unrelated agg test




---------


(cherry picked from commit 4af03c2)

Signed-off-by: YANGDB <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent 601c5a4 commit fad5fc6
Show file tree
Hide file tree
Showing 8 changed files with 553 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLTopAndRareITSuite
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
createPartitionedMultiRowAddressTable(testTable)
}

protected override def afterEach(): Unit = {
super.afterEach()
// Stop all streaming jobs if any
spark.streams.active.foreach { job =>
job.stop()
job.awaitTermination()
}
}

test("create ppl rare address field query test") {
val frame = sql(s"""
| source = $testTable| rare address
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 3)

val expectedRow = Row(1, "Vancouver")
assert(
results.head == expectedRow,
s"Expected least frequent result to be $expectedRow, but got ${results.head}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val addressField = UnresolvedAttribute("address")
val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))

val aggregateExpressions = Seq(
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
"count(address)")(),
addressField)
val aggregatePlan =
Aggregate(
Seq(addressField),
aggregateExpressions,
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("address"), Descending)),
global = true,
aggregatePlan)
val expectedPlan = Project(projectList, sortedPlan)
comparePlans(expectedPlan, logicalPlan, false)
}

test("create ppl rare address by age field query test") {
val frame = sql(s"""
| source = $testTable| rare address by age
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 5)

val expectedRow = Row(1, "Vancouver", 60)
assert(
results.head == expectedRow,
s"Expected least frequent result to be $expectedRow, but got ${results.head}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val addressField = UnresolvedAttribute("address")
val ageField = UnresolvedAttribute("age")
val ageAlias = Alias(ageField, "age")()

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))

val countExpr = Alias(
UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
"count(address)")()

val aggregateExpressions = Seq(countExpr, addressField, ageAlias)
val aggregatePlan =
Aggregate(
Seq(addressField, ageAlias),
aggregateExpressions,
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))

val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("address"), Descending)),
global = true,
aggregatePlan)

val expectedPlan = Project(projectList, sortedPlan)
comparePlans(expectedPlan, logicalPlan, false)
}

test("create ppl top address field query test") {
val frame = sql(s"""
| source = $testTable| top address
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 3)

val expectedRows = Set(Row(2, "Portland"), Row(2, "Seattle"))
val actualRows = results.take(2).toSet

// Compare the sets
assert(
actualRows == expectedRows,
s"The first two results do not match the expected rows. Expected: $expectedRows, Actual: $actualRows")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val addressField = UnresolvedAttribute("address")
val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))

val aggregateExpressions = Seq(
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
"count(address)")(),
addressField)
val aggregatePlan =
Aggregate(
Seq(addressField),
aggregateExpressions,
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("address"), Ascending)),
global = true,
aggregatePlan)
val expectedPlan = Project(projectList, sortedPlan)
comparePlans(expectedPlan, logicalPlan, false)
}

test("create ppl top 3 countries by occupation field query test") {
val newTestTable = "spark_catalog.default.new_flint_ppl_test"
createOccupationTable(newTestTable)

val frame = sql(s"""
| source = $newTestTable| top 3 country by occupation
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 3)

val expectedRows = Set(
Row(1, "Canada", "Doctor"),
Row(1, "Canada", "Scientist"),
Row(1, "Canada", "Unemployed"))
val actualRows = results.take(3).toSet

// Compare the sets
assert(
actualRows == expectedRows,
s"The first two results do not match the expected rows. Expected: $expectedRows, Actual: $actualRows")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical

val countryField = UnresolvedAttribute("country")
val occupationField = UnresolvedAttribute("occupation")
val occupationFieldAlias = Alias(occupationField, "occupation")()

val countExpr = Alias(
UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false),
"count(country)")()
val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias)
val aggregatePlan =
Aggregate(
Seq(countryField, occupationFieldAlias),
aggregateExpressions,
UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test")))

val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
global = true,
aggregatePlan)

val planWithLimit =
GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan))
val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
comparePlans(expectedPlan, logicalPlan, false)
}
}
12 changes: 10 additions & 2 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,6 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
- `source = table | stats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId`

**Dedup**

- `source = table | dedup a | fields a,b,c`
- `source = table | dedup a,b | fields a,b,c`
- `source = table | dedup a keepempty=true | fields a,b,c`
Expand All @@ -290,8 +289,17 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
- `source = table | dedup 1 a consecutive=true| fields a,b,c` (Unsupported)
- `source = table | dedup 2 a | fields a,b,c` (Unsupported)

**Rare**
- `source=accounts | rare gender`
- `source=accounts | rare age by gender`

**Top**
- `source=accounts | top gender`
- `source=accounts | top 1 gender`
- `source=accounts | top 1 age by gender`


For additional details on PPL commands - view [PPL Commands Docs](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/index.rst)
> For additional details on PPL commands - view [PPL Commands Docs](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/index.rst)
---

Expand Down
2 changes: 2 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ commands
| dedupCommand
| sortCommand
| headCommand
| topCommand
| rareCommand
| evalCommand
;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.Collections;
import java.util.List;

/** Logical plan node of Rare (Aggregation) command, the interface for building aggregation actions in queries. */
public class RareAggregation extends Aggregation {
/** Aggregation Constructor without span and argument. */
public RareAggregation(
List<UnresolvedExpression> aggExprList,
List<UnresolvedExpression> sortExprList,
List<UnresolvedExpression> groupExprList) {
super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.Collections;
import java.util.List;
import java.util.Optional;

/** Logical plan node of Top (Aggregation) command, the interface for building aggregation actions in queries. */
public class TopAggregation extends Aggregation {
private final Optional<Literal> results;

/** Aggregation Constructor without span and argument. */
public TopAggregation(
Optional<Literal> results,
List<UnresolvedExpression> aggExprList,
List<UnresolvedExpression> sortExprList,
List<UnresolvedExpression> groupExprList) {
super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList());
this.results = results;
}

public Optional<Literal> getResults() {
return results;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.Ascending$;
import org.apache.spark.sql.catalyst.expressions.Descending$;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
Expand Down Expand Up @@ -59,9 +62,11 @@
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.RareAggregation;
import org.opensearch.sql.ast.tree.RareTopN;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.TopAggregation;
import org.opensearch.sql.ppl.utils.AggregatorTranslator;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
Expand Down Expand Up @@ -176,20 +181,39 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
node.getChild().get(0).accept(this, context);
List<Expression> aggsExpList = visitExpressionList(node.getAggExprList(), context);
List<Expression> groupExpList = visitExpressionList(node.getGroupExprList(), context);

if (!groupExpList.isEmpty()) {
//add group by fields to context
context.getGroupingParseExpressions().addAll(groupExpList);
}

UnresolvedExpression span = node.getSpan();
if (!Objects.isNull(span)) {
span.accept(this, context);
//add span's group alias field (most recent added expression)
context.getGroupingParseExpressions().add(context.getNamedParseExpressions().peek());
}
// build the aggregation logical step
return extractedAggregation(context);
LogicalPlan logicalPlan = extractedAggregation(context);

// set sort direction according to command type (`rare` is Asc, `top` is Desc, default to Asc)
List<SortDirection> sortDirections = new ArrayList<>();
sortDirections.add(node instanceof RareAggregation ? Descending$.MODULE$ : Ascending$.MODULE$);

if (!node.getSortExprList().isEmpty()) {
visitExpressionList(node.getSortExprList(), context);
Seq<SortOrder> sortElements = context.retainAllNamedParseExpressions(exp ->
new SortOrder(exp,
sortDirections.get(0),
sortDirections.get(0).defaultNullOrdering(),
seq(new ArrayList<Expression>())));
context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan));
}
//visit TopAggregation results limit
if((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) {
context.apply(p ->(LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p));
}
return logicalPlan;
}

private static LogicalPlan extractedAggregation(CatalystPlanContext context) {
Expand Down
Loading

0 comments on commit fad5fc6

Please sign in to comment.