Skip to content

Commit

Permalink
Add support for the Rare & Top PPL commands
Browse files Browse the repository at this point in the history
top [N] <field-list> [by-clause]

N: optional. The maximum number of results to return. Default: 10.
field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------

rare <field-list> [by-clause]

field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------
commands:
 - opensearch-project#461
 - opensearch-project#536
Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Aug 14, 2024
1 parent c733d3c commit 1664dc9
Show file tree
Hide file tree
Showing 4 changed files with 402 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.{QueryTest, Row}

/**
 * Integration tests for the PPL `rare` command plus basic projection, sort
 * and head (limit) queries. Each test validates both the rows returned by
 * the query and the unresolved Catalyst logical plan produced by the
 * PPL-to-Spark translation.
 */
class FlintSparkPPLTopAndRareITSuite
    extends QueryTest
    with LogicalPlanTestUtils
    with FlintPPLSuite
    with StreamTest {

  /** Fully-qualified name of the table created in beforeAll() and queried by every test. */
  private val testTable = "spark_catalog.default.flint_ppl_test"

  override def beforeAll(): Unit = {
    super.beforeAll()

    // Create the partitioned, multi-row address table the queries run against.
    createPartitionedMultiRowAddressTable(testTable)
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Stop any still-active streaming jobs so state does not leak between tests.
    spark.streams.active.foreach { job =>
      job.stop()
      job.awaitTermination()
    }
  }

  test("create ppl rare address field query test") {
    // Fixed: the original query text ended with a stray double-quote
    // (`rare address"`), which was sent to the PPL parser as part of the
    // command and would make the query malformed.
    val frame = sql(s"""
         | source = $testTable| rare address
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    assert(results.length == 2)

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    // Define the expected logical plan.
    // NOTE(review): this models `rare` as a plain LIMIT 2 over the relation;
    // confirm this matches the intended aggregation-based translation of the
    // rare command.
    val limitPlan: LogicalPlan =
      Limit(Literal(2), UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
    val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan)

    // Compare the two plans
    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
  }

  test("create ppl simple query with head (limit) and sorted test") {
    val frame = sql(s"""
         | source = $testTable| sort name | head 2
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    assert(results.length == 2)

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical

    // `sort name` translates to a global ascending Sort over the relation.
    val sortedPlan: LogicalPlan =
      Sort(
        Seq(SortOrder(UnresolvedAttribute("name"), Ascending)),
        global = true,
        UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))

    // Define the expected logical plan: `head 2` becomes Limit(2) above the sort.
    val expectedPlan: LogicalPlan =
      Project(Seq(UnresolvedStar(None)), Limit(Literal(2), sortedPlan))

    // Compare the two plans
    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
  }

  test("create ppl simple query two with fields result test") {
    val frame = sql(s"""
         | source = $testTable| fields name, age
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    // Define the expected results
    val expectedResults: Array[Row] =
      Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20))
    // Compare the results, order-insensitively, by sorting both sides on name.
    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
    assert(results.sorted.sameElements(expectedResults.sorted))

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    // Define the expected logical plan: `fields` becomes a simple Project.
    val expectedPlan: LogicalPlan = Project(
      Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
      UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
    // Compare the two plans
    assert(expectedPlan === logicalPlan)
  }

  test("create ppl simple sorted query two with fields result test sorted") {
    val frame = sql(s"""
         | source = $testTable| sort age | fields name, age
         | """.stripMargin)

    // Retrieve the results; row order matters here because of the sort.
    val results: Array[Row] = frame.collect()
    // Define the expected results
    val expectedResults: Array[Row] =
      Array(Row("Jane", 20), Row("John", 25), Row("Hello", 30), Row("Jake", 70))
    assert(results === expectedResults)

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical

    val sortedPlan: LogicalPlan =
      Sort(
        Seq(SortOrder(UnresolvedAttribute("age"), Ascending)),
        global = true,
        UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))

    // Define the expected logical plan: Project(name, age) above the Sort.
    val expectedPlan: LogicalPlan =
      Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), sortedPlan)

    // Compare the two plans
    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
  }

  test("create ppl simple query two with fields and head (limit) test") {
    val frame = sql(s"""
         | source = $testTable| fields name, age | head 1
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    assert(results.length == 1)

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    val project = Project(
      Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
      UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
    // Define the expected logical plan: Limit(1) above the projection.
    val limitPlan: LogicalPlan = Limit(Literal(1), project)
    val expectedPlan: LogicalPlan = Project(Seq(UnresolvedStar(None)), limitPlan)
    // Compare the two plans
    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
  }

  test("create ppl simple query two with fields and head (limit) with sorting test") {
    // Run the same query twice: once with plain identifiers and once with
    // backtick-quoted identifiers, which must resolve identically.
    Seq(("name, age", "age"), ("`name`, `age`", "`age`")).foreach {
      case (selectFields, sortField) =>
        val frame = sql(s"""
             | source = $testTable| fields $selectFields | head 1 | sort $sortField
             | """.stripMargin)

        // Retrieve the results
        val results: Array[Row] = frame.collect()
        assert(results.length == 1)

        // Retrieve the logical plan
        val logicalPlan: LogicalPlan = frame.queryExecution.logical
        val project = Project(
          Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
          UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
        // Define the expected logical plan: Sort sits above Limit because the
        // PPL pipeline applies `sort` after `head`.
        val limitPlan: LogicalPlan = Limit(Literal(1), project)
        val sortedPlan: LogicalPlan =
          Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan)

        val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan)
        // Compare the two plans
        assert(compareByString(expectedPlan) === compareByString(logicalPlan))
    }
  }
}
2 changes: 2 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ commands
| dedupCommand
| sortCommand
| headCommand
| topCommand
| rareCommand
| evalCommand
;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.antlr.v4.runtime.tree.ParseTree;
import org.opensearch.flint.spark.ppl.OpenSearchPPLParser;
import org.opensearch.flint.spark.ppl.OpenSearchPPLParserBaseVisitor;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.Field;
Expand Down Expand Up @@ -236,12 +237,6 @@ private List<Field> getFieldList(OpenSearchPPLParser.FieldListContext ctx) {
.collect(Collectors.toList());
}

/**
 * Rare command: not implemented — unconditionally rejects the query.
 *
 * @param ctx parsed {@code rare} command context
 * @return never returns normally
 * @throws RuntimeException always, signalling the command is unsupported
 */
@Override
public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) {
  final String unsupportedMessage = "Rare Command is not supported ";
  throw new RuntimeException(unsupportedMessage);
}

@Override
public UnresolvedPlan visitGrokCommand(OpenSearchPPLParser.GrokCommandContext ctx) {
UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field);
Expand Down Expand Up @@ -278,13 +273,42 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo
/** Top command. */
// NOTE(review): the method body is empty in this diff view. As written this
// cannot compile (a non-void method must return a value), so the actual
// implementation is presumably collapsed/hidden in the rendered diff —
// confirm against the full file before relying on this hunk.
@Override
public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) {

}

/** Rare command. */
// Builds a count-per-field aggregation for each field in the rare command's
// field list, grouped by the optional by-clause fields.
@Override
public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) {
// One alias per rare'd field: count(<field>) AS <field>.
ImmutableList.Builder<UnresolvedExpression> aggListBuilder = new ImmutableList.Builder<>();
ctx.fieldList().fieldExpression().forEach(field -> {
UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field));
String name = field.qualifiedName().getText();
Alias alias = new Alias(name, aggExpression);
aggListBuilder.add(alias);
});
// Optional by-clause fields become the aggregation's group-by list.
List<UnresolvedExpression> groupList =
ctx.byClause() == null ? emptyList() : getGroupByList(ctx.byClause());
// NOTE(review): everything from this `return` to the final
// `return aggregation;` cannot coexist in one method — statements after a
// return are unreachable and a Java compile error. This hunk appears to be
// a diff-rendering artifact interleaving removed and added lines; confirm
// against the full file.
// NOTE(review): `RareTopN.CommandType.TOP` inside visitRareCommand looks
// like a copy/paste error — the rare visitor should presumably use the
// RARE command type. Verify.
return new RareTopN(
RareTopN.CommandType.TOP,
ArgumentFactory.getArgumentList(ctx),
getFieldList(ctx.fieldList()),
groupList);
// NOTE(review): the result of this Optional chain is discarded — it is
// computed and never assigned, so it has no effect. Either assign it
// (it duplicates `groupList` above with aliased expressions) or delete it.
Optional.ofNullable(ctx.byClause())
.map(OpenSearchPPLParser.ByClauseContext::fieldList)
.map(
expr ->
expr.fieldExpression().stream()
.map(
groupCtx ->
(UnresolvedExpression)
new Alias(
getTextInQuery(groupCtx),
internalVisitExpression(groupCtx)))
.collect(Collectors.toList()))
.orElse(emptyList());

// Final AST node: aggregation over the count aliases, grouped by groupList,
// carrying the command's parsed arguments (e.g. the N limit).
Aggregation aggregation =
new Aggregation(
aggListBuilder.build(),
emptyList(),
groupList,
null,
ArgumentFactory.getArgumentList(ctx));
return aggregation;
}

/** From clause. */
Expand Down
Loading

0 comments on commit 1664dc9

Please sign in to comment.