diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala new file mode 100644 index 000000000..09307aa44 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala @@ -0,0 +1,216 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.command.DescribeTableCommand +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLTopAndRareITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedMultiRowAddressTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create ppl rare address field query test") { + val frame = sql(s""" + | source = $testTable| rare address + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + val expectedRow = Row(1, "Vancouver") + assert( + results.head == expectedRow, + s"Expected least frequent result to be $expectedRow, but got 
${results.head}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count(address)")(), + addressField) + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("address"), Descending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, false) + } + + test("create ppl rare address by age field query test") { + val frame = sql(s""" + | source = $testTable| rare address by age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 5) + + val expectedRow = Row(1, "Vancouver", 60) + assert( + results.head == expectedRow, + s"Expected least frequent result to be $expectedRow, but got ${results.head}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count(address)")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("address"), 
Descending)), + global = true, + aggregatePlan) + + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, false) + } + + test("create ppl top address field query test") { + val frame = sql(s""" + | source = $testTable| top address + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + val expectedRows = Set(Row(2, "Portland"), Row(2, "Seattle")) + val actualRows = results.take(2).toSet + + // Compare the sets + assert( + actualRows == expectedRows, + s"The first two results do not match the expected rows. Expected: $expectedRows, Actual: $actualRows") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count(address)")(), + addressField) + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("address"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, false) + } + + test("create ppl top 3 countries by occupation field query test") { + val newTestTable = "spark_catalog.default.new_flint_ppl_test" + createOccupationTable(newTestTable) + + val frame = sql(s""" + | source = $newTestTable| top 3 country by occupation + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + val expectedRows = Set( + Row(1, "Canada", "Doctor"), + Row(1, "Canada", "Scientist"), + Row(1, "Canada", "Unemployed")) + val actualRows = 
results.take(3).toSet
+
+    // Compare the sets
+    assert(
+      actualRows == expectedRows,
+      s"The first three results do not match the expected rows. Expected: $expectedRows, Actual: $actualRows")
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+
+    val countryField = UnresolvedAttribute("country")
+    val occupationField = UnresolvedAttribute("occupation")
+    val occupationFieldAlias = Alias(occupationField, "occupation")()
+
+    val countExpr = Alias(
+      UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false),
+      "count(country)")()
+    val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias)
+    val aggregatePlan =
+      Aggregate(
+        Seq(countryField, occupationFieldAlias),
+        aggregateExpressions,
+        UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test")))
+
+    val sortedPlan: LogicalPlan =
+      Sort(
+        Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
+        global = true,
+        aggregatePlan)
+
+    val planWithLimit =
+      GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan))
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
+    comparePlans(expectedPlan, logicalPlan, false)
+  }
+}
diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md
index 575fd6b0c..116bed931 100644
--- a/ppl-spark-integration/README.md
+++ b/ppl-spark-integration/README.md
@@ -278,7 +278,6 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
 - `source = table | stats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId`
 
 **Dedup**
-
 - `source = table | dedup a | fields a,b,c`
 - `source = table | dedup a,b | fields a,b,c`
 - `source = table | dedup a keepempty=true | fields a,b,c`
@@ -290,8 +289,17 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
 - `source = table | dedup 1 a consecutive=true| fields a,b,c` (Unsupported)
 - `source = table | dedup 2 a | fields a,b,c` (Unsupported)
 
+**Rare**
+- 
`source=accounts | rare gender` +- `source=accounts | rare age by gender` + +**Top** +- `source=accounts | top gender` +- `source=accounts | top 1 gender` +- `source=accounts | top 1 age by gender` + -For additional details on PPL commands - view [PPL Commands Docs](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/index.rst) +> For additional details on PPL commands - view [PPL Commands Docs](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/index.rst) --- diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 6f56550c9..76e65753b 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -38,6 +38,8 @@ commands | dedupCommand | sortCommand | headCommand + | topCommand + | rareCommand | evalCommand ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java new file mode 100644 index 000000000..55b2e4c43 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.Collections; +import java.util.List; + +/** Logical plan node of Rare (Aggregation) command, the interface for building aggregation actions in queries. */ +public class RareAggregation extends Aggregation { + /** Aggregation Constructor without span and argument. 
*/ + public RareAggregation( + List aggExprList, + List sortExprList, + List groupExprList) { + super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList()); + } + +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java new file mode 100644 index 000000000..451446cc3 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** Logical plan node of Top (Aggregation) command, the interface for building aggregation actions in queries. */ +public class TopAggregation extends Aggregation { + private final Optional results; + + /** Aggregation Constructor without span and argument. 
*/ + public TopAggregation( + Optional results, + List aggExprList, + List sortExprList, + List groupExprList) { + super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList()); + this.results = results; + } + + public Optional getResults() { + return results; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index ce671f085..7cddd09fd 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -9,9 +9,12 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.expressions.Ascending$; +import org.apache.spark.sql.catalyst.expressions.Descending$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; @@ -59,9 +62,11 @@ import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.RareAggregation; import org.opensearch.sql.ast.tree.RareTopN; import org.opensearch.sql.ast.tree.Relation; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.TopAggregation; import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import 
org.opensearch.sql.ppl.utils.ComparatorTransformer;
@@ -176,12 +181,11 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
         node.getChild().get(0).accept(this, context);
         List aggsExpList = visitExpressionList(node.getAggExprList(), context);
         List groupExpList = visitExpressionList(node.getGroupExprList(), context);
-
         if (!groupExpList.isEmpty()) {
             //add group by fields to context
             context.getGroupingParseExpressions().addAll(groupExpList);
         }
-
+
         UnresolvedExpression span = node.getSpan();
         if (!Objects.isNull(span)) {
             span.accept(this, context);
@@ -189,7 +193,27 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
             context.getGroupingParseExpressions().add(context.getNamedParseExpressions().peek());
         }
         // build the aggregation logical step
-        return extractedAggregation(context);
+        LogicalPlan logicalPlan = extractedAggregation(context);
+
+        // set sort direction according to command type (`rare` sorts Descending, `top` sorts Ascending)
+        List sortDirections = new ArrayList<>();
+        sortDirections.add(node instanceof RareAggregation ? 
Descending$.MODULE$ : Ascending$.MODULE$); + + if (!node.getSortExprList().isEmpty()) { + visitExpressionList(node.getSortExprList(), context); + Seq sortElements = context.retainAllNamedParseExpressions(exp -> + new SortOrder(exp, + sortDirections.get(0), + sortDirections.get(0).defaultNullOrdering(), + seq(new ArrayList()))); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan)); + } + //visit TopAggregation results limit + if((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) { + context.apply(p ->(LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( + ((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p)); + } + return logicalPlan; } private static LogicalPlan extractedAggregation(CatalystPlanContext context) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index e94d4e0f4..7d91bbb7a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -12,7 +12,9 @@ import org.antlr.v4.runtime.tree.ParseTree; import org.opensearch.flint.spark.ppl.OpenSearchPPLParser; import org.opensearch.flint.spark.ppl.OpenSearchPPLParserBaseVisitor; +import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.FieldsMapping; @@ -35,11 +37,13 @@ import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.RareAggregation; 
import org.opensearch.sql.ast.tree.RareTopN; import org.opensearch.sql.ast.tree.Relation; import org.opensearch.sql.ast.tree.Rename; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.TopAggregation; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -236,12 +240,6 @@ private List getFieldList(OpenSearchPPLParser.FieldListContext ctx) { .collect(Collectors.toList()); } - /** Rare command. */ - @Override - public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { - throw new RuntimeException("Rare Command is not supported "); - } - @Override public UnresolvedPlan visitGrokCommand(OpenSearchPPLParser.GrokCommandContext ctx) { UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field); @@ -278,13 +276,91 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo /** Top command. */ @Override public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { - List groupList = - ctx.byClause() == null ? 
emptyList() : getGroupByList(ctx.byClause()); - return new RareTopN( - RareTopN.CommandType.TOP, - ArgumentFactory.getArgumentList(ctx), - getFieldList(ctx.fieldList()), - groupList); + ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder sortListBuilder = new ImmutableList.Builder<>(); + ctx.fieldList().fieldExpression().forEach(field -> { + UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); + String name = field.qualifiedName().getText(); + Alias alias = new Alias("count("+name+")", aggExpression); + aggListBuilder.add(alias); + // group by the `field-list` as the mandatory groupBy fields + groupListBuilder.add(internalVisitExpression(field)); + }); + + // group by the `by-clause` as the optional groupBy fields + groupListBuilder.addAll( + Optional.ofNullable(ctx.byClause()) + .map(OpenSearchPPLParser.ByClauseContext::fieldList) + .map( + expr -> + expr.fieldExpression().stream() + .map( + groupCtx -> + (UnresolvedExpression) + new Alias( + getTextInQuery(groupCtx), + internalVisitExpression(groupCtx))) + .collect(Collectors.toList())) + .orElse(emptyList()) + ); + //build the sort fields + ctx.fieldList().fieldExpression().forEach(field -> { + sortListBuilder.add(internalVisitExpression(field)); + }); + UnresolvedExpression unresolvedPlan = (ctx.number != null ? internalVisitExpression(ctx.number) : null); + TopAggregation aggregation = + new TopAggregation( + Optional.ofNullable((Literal) unresolvedPlan), + aggListBuilder.build(), + sortListBuilder.build(), + groupListBuilder.build()); + return aggregation; + } + + /** Rare command. 
*/ + @Override + public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { + ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder sortListBuilder = new ImmutableList.Builder<>(); + ctx.fieldList().fieldExpression().forEach(field -> { + UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); + String name = field.qualifiedName().getText(); + Alias alias = new Alias("count("+name+")", aggExpression); + aggListBuilder.add(alias); + // group by the `field-list` as the mandatory groupBy fields + groupListBuilder.add(internalVisitExpression(field)); + }); + + // group by the `by-clause` as the optional groupBy fields + groupListBuilder.addAll( + Optional.ofNullable(ctx.byClause()) + .map(OpenSearchPPLParser.ByClauseContext::fieldList) + .map( + expr -> + expr.fieldExpression().stream() + .map( + groupCtx -> + (UnresolvedExpression) + new Alias( + getTextInQuery(groupCtx), + internalVisitExpression(groupCtx))) + .collect(Collectors.toList())) + .orElse(emptyList()) + ); + //build the sort fields + ctx.fieldList().fieldExpression().forEach(field -> { + sortListBuilder.add(internalVisitExpression(field)); + }); + RareAggregation aggregation = + new RareAggregation( + aggListBuilder.build(), + sortListBuilder.build(), + groupListBuilder.build()); + return aggregation; } /** From clause. 
*/ diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala new file mode 100644 index 000000000..5bd5da28c --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} +import org.apache.spark.sql.execution.command.DescribeTableCommand + +class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test simple rare command with a single field") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=accounts | rare address", false), context) + val addressField = UnresolvedAttribute("address") + val tableRelation = UnresolvedRelation(Seq("accounts")) + + val projectList: 
Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count(address)")(), + addressField) + + val aggregatePlan = + Aggregate(Seq(addressField), aggregateExpressions, tableRelation) + + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("address"), Descending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test simple rare command with a by field test") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logicalPlan = + planTransformer.visit( + plan(pplParser, "source=accounts | rare address by age", false), + context) + // Retrieve the logical plan + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count(address)")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("accounts"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("address"), Descending)), + global = true, + aggregatePlan) + + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, false) + } + + test("test simple top command with a single field") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=accounts | top address", false), context) + val addressField = 
UnresolvedAttribute("address") + val tableRelation = UnresolvedRelation(Seq("accounts")) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count(address)")(), + addressField) + + val aggregatePlan = + Aggregate(Seq(addressField), aggregateExpressions, tableRelation) + + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("address"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test simple top 1 command by age field") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=accounts | top 1 address by age", false), + context) + + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count(address)")() + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("accounts"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("address"), Ascending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(1), LocalLimit(Literal(1), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logPlan, false) + } +}