diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
index d8663b888..78abf7ff2 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
@@ -596,4 +596,27 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
          | )
          |)""".stripMargin)
   }
+
+  protected def createTableHttpLog(testTable: String): Unit = {
+    sql(s"""
+         | CREATE TABLE $testTable
+         |(
+         |  id INT,
+         |  status_code INT,
+         |  request_path STRING,
+         |  timestamp STRING
+         |)
+         | USING $tableType $tableOptions
+         |""".stripMargin)
+
+    sql(s"""
+         | INSERT INTO $testTable
+         | VALUES (1, 200, '/home', '2023-10-01 10:00:00'),
+         | (2, null, '/about', '2023-10-01 10:05:00'),
+         | (3, 500, '/contact', '2023-10-01 10:10:00'),
+         | (4, 301, '/home', '2023-10-01 10:15:00'),
+         | (5, 200, '/services', '2023-10-01 10:20:00'),
+         | (6, 403, '/home', '2023-10-01 10:25:00')
+         | """.stripMargin)
+  }
 }
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala
index ea77ff990..596626698 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala
@@ -9,7 +9,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
 
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, LessThan, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort}
 import org.apache.spark.sql.streaming.StreamTest
 
@@ -21,12 +21,14 @@ class FlintSparkPPLEvalITSuite
 
   /** Test table and index name */
   private val testTable = "spark_catalog.default.flint_ppl_test"
+  private val testTableHttpLog = "spark_catalog.default.flint_ppl_test_http_log"
 
   override def beforeAll(): Unit = {
     super.beforeAll()
     // Create test table
     createPartitionedStateCountryTable(testTable)
+    createTableHttpLog(testTableHttpLog)
   }
 
   protected override def afterEach(): Unit = {
@@ -504,7 +506,134 @@ class FlintSparkPPLEvalITSuite
     comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
   }
 
+  test("eval case function") {
+    val frame = sql(s"""
+         | source = $testTableHttpLog |
+         | eval status_category =
+         | case(status_code >= 200 AND status_code < 300, 'Success',
+         | status_code >= 300 AND status_code < 400, 'Redirection',
+         | status_code >= 400 AND status_code < 500, 'Client Error',
+         | status_code >= 500, 'Server Error'
+         | else concat('Incorrect HTTP status code for request ', request_path)
+         | )
+         | """.stripMargin)
+
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] =
+      Array(
+        Row(1, 200, "/home", "2023-10-01 10:00:00", "Success"),
+        Row(
+          2,
+          null,
+          "/about",
+          "2023-10-01 10:05:00",
+          "Incorrect HTTP status code for request /about"),
+        Row(3, 500, "/contact", "2023-10-01 10:10:00", "Server Error"),
+        Row(4, 301, "/home", "2023-10-01 10:15:00", "Redirection"),
+        Row(5, 200, "/services", "2023-10-01 10:20:00", "Success"),
+        Row(6, 403, "/home", "2023-10-01 10:25:00", "Client Error"))
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getInt(0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+    val expectedColumns =
+      Array[String]("id", "status_code", "request_path", "timestamp", "status_category")
+    assert(frame.columns.sameElements(expectedColumns))
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_http_log"))
+    val conditionValueSequence = Seq(
+      (greaterOrEqualAndLessThan("status_code", 200, 300), Literal("Success")),
+      (greaterOrEqualAndLessThan("status_code", 300, 400), Literal("Redirection")),
+      (greaterOrEqualAndLessThan("status_code", 400, 500), Literal("Client Error")),
+      (
+        EqualTo(
+          Literal(true),
+          GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(500))),
+        Literal("Server Error")))
+    val elseValue = UnresolvedFunction(
+      "concat",
+      Seq(
+        Literal("Incorrect HTTP status code for request "),
+        UnresolvedAttribute("request_path")),
+      isDistinct = false)
+    val caseFunction = CaseWhen(conditionValueSequence, elseValue)
+    val aliasStatusCategory = Alias(caseFunction, "status_category")()
+    val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory)
+    val evalProject = Project(evalProjectList, table)
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject)
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  test("eval case function in complex pipeline") {
+    val frame = sql(s"""
+         | source = $testTableHttpLog
+         | | where ispresent(status_code)
+         | | eval status_category =
+         | case(status_code >= 200 AND status_code < 300, 'Success',
+         | status_code >= 300 AND status_code < 400, 'Redirection',
+         | status_code >= 400 AND status_code < 500, 'Client Error',
+         | status_code >= 500, 'Server Error'
+         | else 'Unknown'
+         | )
+         | | stats count() by status_category
+         | """.stripMargin)
+
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] =
+      Array(
+        Row(1L, "Redirection"),
+        Row(1L, "Client Error"),
+        Row(1L, "Server Error"),
+        Row(2L, "Success"))
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getString(1))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+    val expectedColumns = Array[String]("count()", "status_category")
+    assert(frame.columns.sameElements(expectedColumns))
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_http_log"))
+    val filter = Filter(
+      UnresolvedFunction(
+        "isnotnull",
+        Seq(UnresolvedAttribute("status_code")),
+        isDistinct = false),
+      table)
+    val conditionValueSequence = Seq(
+      (greaterOrEqualAndLessThan("status_code", 200, 300), Literal("Success")),
+      (greaterOrEqualAndLessThan("status_code", 300, 400), Literal("Redirection")),
+      (greaterOrEqualAndLessThan("status_code", 400, 500), Literal("Client Error")),
+      (
+        EqualTo(
+          Literal(true),
+          GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(500))),
+        Literal("Server Error")))
+    val elseValue = Literal("Unknown")
+    val caseFunction = CaseWhen(conditionValueSequence, elseValue)
+    val aliasStatusCategory = Alias(caseFunction, "status_category")()
+    val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory)
+    val evalProject = Project(evalProjectList, filter)
+    val aggregation = Aggregate(
+      Seq(Alias(UnresolvedAttribute("status_category"), "status_category")()),
+      Seq(
+        Alias(
+          UnresolvedFunction("COUNT", Seq(UnresolvedStar(None)), isDistinct = false),
+          "count()")(),
+        Alias(UnresolvedAttribute("status_category"), "status_category")()),
+      evalProject)
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregation)
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  private def greaterOrEqualAndLessThan(fieldName: String, min: Int, max: Int) = {
+    val and = And(
+      GreaterThanOrEqual(UnresolvedAttribute(fieldName), Literal(min)),
+      LessThan(UnresolvedAttribute(fieldName), Literal(max)))
+    EqualTo(Literal(true), and)
+  }
+
+  // TODO: excluded fields are not supported yet
   ignore("test single eval expression with excluded fields") {
     val frame = sql(s"""
        | source = $testTable | eval new_field = "New Field" | fields - age
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala
index 9a21bb45a..14ef7ccc4 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala
@@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl
 
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.streaming.StreamTest
 
@@ -19,11 +19,13 @@ class FlintSparkPPLFiltersITSuite
 
   /** Test table and index name */
   private val testTable = "spark_catalog.default.flint_ppl_test"
+  private val duplicationTable = "spark_catalog.default.flint_ppl_test_duplication_table"
 
   override def beforeAll(): Unit = {
     super.beforeAll()
     // Create test table
     createPartitionedStateCountryTable(testTable)
+    createDuplicationNullableTable(duplicationTable)
   }
 
   protected override def afterEach(): Unit = {
@@ -348,4 +350,107 @@ class FlintSparkPPLFiltersITSuite
     assert(compareByString(expectedPlan) === compareByString(logicalPlan))
   }
 
+  test("case function used as filter") {
+    val frame = sql(s"""
+         | source = $testTable case(country = 'USA', 'The United States of America' else 'Other country') = 'The United States of America'
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(
+      Row("Jake", 70, "California", "USA", 2023, 4),
+      Row("Hello", 30, "New York", "USA", 2023, 4))
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+    val conditionValueSequence = Seq(
+      (
+        EqualTo(Literal(true), EqualTo(UnresolvedAttribute("country"), Literal("USA"))),
+        Literal("The United States of America")))
+    val elseValue = Literal("Other country")
+    val caseFunction = CaseWhen(conditionValueSequence, elseValue)
+    val filterExpr = EqualTo(caseFunction, Literal("The United States of America"))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, filterPlan)
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
+  }
+
+  test("case function used as filter in complex pipeline") {
+    val frame = sql(s"""
+         | source = $duplicationTable
+         | | eval factor = case(id > 15, id - 14, isnull(name), id - 7, id < 3, id + 1 else 1)
+         | | where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even'
+         | | stats count() by factor
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect() // count(), factor
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(Row(1L, 4), Row(1L, 6), Row(2L, 2))
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](1))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val table =
+      UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_duplication_table"))
+
+    // case function used in eval command
+    val conditionValueEval = Seq(
+      (
+        EqualTo(Literal(true), GreaterThan(UnresolvedAttribute("id"), Literal(15))),
+        UnresolvedFunction("-", Seq(UnresolvedAttribute("id"), Literal(14)), isDistinct = false)),
+      (
+        EqualTo(
+          Literal(true),
+          UnresolvedFunction("isnull", Seq(UnresolvedAttribute("name")), isDistinct = false)),
+        UnresolvedFunction("-", Seq(UnresolvedAttribute("id"), Literal(7)), isDistinct = false)),
+      (
+        EqualTo(Literal(true), LessThan(UnresolvedAttribute("id"), Literal(3))),
+        UnresolvedFunction("+", Seq(UnresolvedAttribute("id"), Literal(1)), isDistinct = false)))
+    val aliasCaseFactor = Alias(CaseWhen(conditionValueEval, Literal(1)), "factor")()
+    val evalProject = Project(Seq(UnresolvedStar(None), aliasCaseFactor), table)
+
+    // case function used in where clause
+    val conditionValueWhere = Seq(
+      (
+        EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(2))),
+        Literal("even")),
+      (
+        EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(4))),
+        Literal("even")),
+      (
+        EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(6))),
+        Literal("even")),
+      (
+        EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(8))),
+        Literal("even")))
+    val caseFunctionWhere = CaseWhen(conditionValueWhere, Literal("odd"))
+    val filterPlan = Filter(EqualTo(caseFunctionWhere, Literal("even")), evalProject)
+
+    val aggregation = Aggregate(
+      Seq(Alias(UnresolvedAttribute("factor"), "factor")()),
+      Seq(
+        Alias(
+          UnresolvedFunction("COUNT", Seq(UnresolvedStar(None)), isDistinct = false),
+          "count()")(),
+        Alias(UnresolvedAttribute("factor"), "factor")()),
+      filterPlan)
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregation)
+    // Compare the two plans
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
 }
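A note for reviewers on the plan shape asserted in the tests above: every `case(...)` condition is wrapped in `EqualTo(Literal(true), ...)` because the AST builder (see `visitCaseExpr` further down) models the searched case as `Case(Literal(true), whens, elseValue)`, so each branch compares the literal `true` against its condition. A minimal standalone sketch (not part of the patch) of the resulting Catalyst expression:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, EqualTo, GreaterThanOrEqual, LessThan, Literal}

object CaseConditionShapeSketch extends App {
  // PPL: case(status_code >= 200 AND status_code < 300, 'Success' else 'Other')
  // Each branch condition is compared against the literal `true` when lowered.
  val branch = EqualTo(
    Literal(true),
    And(
      GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(200)),
      LessThan(UnresolvedAttribute("status_code"), Literal(300))))
  val caseWhen = CaseWhen(Seq(branch -> Literal("Success")), Literal("Other"))
  // Prints roughly: CASE WHEN (true = (... AND ...)) THEN 'Success' ELSE 'Other' END
  println(caseWhen.sql)
}
```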
diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md
index d0ef05fce..8f26df20f 100644
--- a/ppl-spark-integration/README.md
+++ b/ppl-spark-integration/README.md
@@ -252,6 +252,29 @@ _- **Limitation: new field added by eval command with a function cannot be dropp
 - `source = table | where ispresent(b)`
 - `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3`
 - `source = table | where isempty(a)`
+- `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'`
+-
+  ```
+  source = table | eval status_category =
+    case(a >= 200 AND a < 300, 'Success',
+         a >= 300 AND a < 400, 'Redirection',
+         a >= 400 AND a < 500, 'Client Error',
+         a >= 500, 'Server Error'
+         else 'Incorrect HTTP status code')
+  | where case(a >= 200 AND a < 300, 'Success',
+               a >= 300 AND a < 400, 'Redirection',
+               a >= 400 AND a < 500, 'Client Error',
+               a >= 500, 'Server Error'
+               else 'Incorrect HTTP status code'
+         ) = 'Incorrect HTTP status code'
+  ```
+-
+  ```
+  source = table
+  | eval factor = case(a > 15, a - 14, isnull(b), a - 7, a < 3, a + 1 else 1)
+  | where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even'
+  | stats count() by factor
+  ```
 
 **Filters With Logical Conditions**
 - `source = table | where c = 'test' AND a = 1 | fields a,b,c`
@@ -272,6 +295,31 @@ Assumptions: `a`, `b`, `c` are existing fields in `table`
 - `source = table | eval f = ispresent(a)`
 - `source = table | eval r = coalesce(a, b, c) | fields r`
 - `source = table | eval e = isempty(a) | fields e`
+- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')`
+- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')`
+- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))`
+-
+  ```
+  source = table | eval status_category =
+    case(a >= 200 AND a < 300, 'Success',
+         a >= 300 AND a < 400, 'Redirection',
+         a >= 400 AND a < 500, 'Client Error',
+         a >= 500, 'Server Error'
+         else 'Unknown'
+    )
+  ```
+-
+  ```
+  source = table | where ispresent(a) |
+    eval status_category =
+      case(a >= 200 AND a < 300, 'Success',
+           a >= 300 AND a < 400, 'Redirection',
+           a >= 400 AND a < 500, 'Client Error',
+           a >= 500, 'Server Error'
+           else 'Incorrect HTTP status code'
+      )
+  | stats count() by status_category
+  ```
 
 Limitation: Overriding existing field is unsupported, following queries throw exceptions with "Reference 'a' is ambiguous"
 - `source = table | eval a = 10 | fields a,b,c`
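For readers mapping the README examples onto Spark semantics: PPL's `case` behaves like a searched `CASE WHEN` — conditions are tried in order, the first match wins, and an omitted `else` yields `NULL` (the AST builder defaults the else branch to a null literal, as shown below). A hypothetical Spark SQL equivalent of the status-code example, with the `http_log` table and a running `SparkSession` assumed purely for illustration:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("case-equivalence").getOrCreate()

// Equivalent searched CASE over an assumed http_log table with a status_code column.
val categorized = spark.sql("""
  SELECT *,
         CASE WHEN status_code >= 200 AND status_code < 300 THEN 'Success'
              WHEN status_code >= 300 AND status_code < 400 THEN 'Redirection'
              WHEN status_code >= 400 AND status_code < 500 THEN 'Client Error'
              WHEN status_code >= 500 THEN 'Server Error'
              ELSE 'Incorrect HTTP status code'
         END AS status_category
  FROM http_log
""")
categorized.show()
```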
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
index 834dd3081..9dfb482ff 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
@@ -99,6 +99,7 @@ APPEND: 'APPEND';
 
 // COMPARISON FUNCTION KEYWORDS
 CASE: 'CASE';
+ELSE: 'ELSE';
 IN: 'IN';
 
 // LOGICAL KEYWORDS
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
index 9be7ebb5b..24ebf56c1 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
@@ -334,6 +334,7 @@ valueExpression
     | left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic
     | primaryExpression # valueExpressionDefault
     | positionFunction # positionFunctionCall
+    | caseFunction # caseExpr
     | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr
     ;
 
@@ -355,6 +356,10 @@ booleanExpression
     : ISEMPTY LT_PRTHS functionArg RT_PRTHS
     ;
 
+caseFunction
+    : CASE LT_PRTHS logicalExpression COMMA valueExpression (COMMA logicalExpression COMMA valueExpression)* (ELSE valueExpression)? RT_PRTHS
+    ;
+
 relevanceExpression
     : singleFieldRelevanceFunction
     | multiFieldRelevanceFunction
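The `caseFunction` rule above guarantees a simple arity invariant that the AST builder below exploits: n `logicalExpression` conditions pair with n `valueExpression` results, and the optional `ELSE` contributes exactly one extra `valueExpression`. A tiny sketch of that else-detection logic:

```scala
// Sketch of the arity invariant implied by the grammar: with an ELSE branch,
// the valueExpression count exceeds the logicalExpression count by exactly one.
def hasElseBranch(logicalExprCount: Int, valueExprCount: Int): Boolean =
  valueExprCount > logicalExprCount

assert(hasElseBranch(2, 3))  // case(c1, v1, c2, v2 else v3)
assert(!hasElseBranch(2, 2)) // case(c1, v1, c2, v2) -> else defaults to NULL
```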
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
index 23dcfb1e9..aea7bbb1d 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
@@ -176,6 +176,8 @@ public T visitIsEmpty(IsEmpty node, C context) {
     return visitChildren(node, context);
   }
 
+  // TODO: add visitCase
+
   public T visitWindowFunction(WindowFunction node, C context) {
     return visitChildren(node, context);
   }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index e0b7f0aae..1bb628497 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -684,6 +684,8 @@ public Expression visitKmeans(Kmeans node, CatalystPlanContext context) {
 
     @Override
     public Expression visitCase(Case node, CatalystPlanContext context) {
+        Stack<Expression> initialNameExpressions = new Stack<>();
+        initialNameExpressions.addAll(context.getNamedParseExpressions());
         analyze(node.getElseClause(), context);
         Expression elseValue = context.getNamedParseExpressions().pop();
         List<Tuple2<Expression, Expression>> whens = new ArrayList<>();
@@ -706,6 +708,7 @@ public Expression visitCase(Case node, CatalystPlanContext context) {
             }
             context.retainAllNamedParseExpressions(e -> e);
         }
+        context.setNamedParseExpressions(initialNameExpressions);
         return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue)));
     }
 
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
index 160e6c367..2706d85e5 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
@@ -44,6 +44,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
 import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL;
@@ -213,6 +214,25 @@ public UnresolvedExpression visitBooleanFunctionCall(OpenSearchPPLParser.Boolean
         ctx.functionArgs().functionArg());
   }
 
+  @Override
+  public UnresolvedExpression visitCaseExpr(OpenSearchPPLParser.CaseExprContext ctx) {
+    List<When> whens = IntStream.range(0, ctx.caseFunction().logicalExpression().size())
+        .mapToObj(index -> {
+          OpenSearchPPLParser.LogicalExpressionContext logicalExpressionContext = ctx.caseFunction().logicalExpression(index);
+          OpenSearchPPLParser.ValueExpressionContext valueExpressionContext = ctx.caseFunction().valueExpression(index);
+          UnresolvedExpression condition = visit(logicalExpressionContext);
+          UnresolvedExpression result = visit(valueExpressionContext);
+          return new When(condition, result);
+        })
+        .collect(Collectors.toList());
+    UnresolvedExpression elseValue = new Literal(null, DataType.NULL);
+    if (ctx.caseFunction().valueExpression().size() > ctx.caseFunction().logicalExpression().size()) {
+      // else value is present
+      elseValue = visit(ctx.caseFunction().valueExpression(ctx.caseFunction().valueExpression().size() - 1));
+    }
+    return new Case(new Literal(true, DataType.BOOLEAN), whens, elseValue);
+  }
+
   @Override
   public UnresolvedExpression visitIsEmptyExpression(OpenSearchPPLParser.IsEmptyExpressionContext ctx) {
     Function trimFunction = new Function(TRIM.getName().getFunctionName(), Collections.singletonList(this.visitFunctionArg(ctx.functionArg())));
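Taken together, the pieces wire the feature end to end: the lexer and parser accept `case(... else ...)`, `visitCaseExpr` builds a `Case` AST node with `Literal(true)` as the compared value, and `visitCase` lowers it to Catalyst's `CaseWhen`. A hypothetical end-to-end session sketch follows; the extension class name and the table are assumptions borrowed from the integration tests, not part of this patch:

```scala
// Hypothetical usage sketch; extension class name and table are assumptions.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("ppl-case-demo")
  .config("spark.sql.extensions", "org.opensearch.flint.spark.FlintPPLSparkExtensions")
  .getOrCreate()

// Parsed as PPL by the extension; the case(...) is lowered to a CaseWhen.
spark.sql("""
  source = spark_catalog.default.flint_ppl_test_http_log
  | eval status_category = case(status_code >= 500, 'Server Error' else 'OK')
  | stats count() by status_category
""").show()
```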