Implementation of case function. (opensearch-project#695)
* Implementation of case function.

Signed-off-by: Lukasz Soszynski <[email protected]>

* Additional tests and documentation related to the case function.

Signed-off-by: Lukasz Soszynski <[email protected]>

---------

Signed-off-by: Lukasz Soszynski <[email protected]>
lukasz-soszynski-eliatra authored Sep 27, 2024
1 parent 38ca314 commit c0924cc
Showing 9 changed files with 338 additions and 2 deletions.
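At a glance: the new PPL `case(cond1, value1, cond2, value2, ... else default)` expression is lowered to Spark's `CaseWhen`, with each branch condition wrapped as `EqualTo(Literal(true), cond)`, the shape asserted by the integration tests below. A minimal sketch of that lowering (illustrative only, not code from this commit):

```
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, EqualTo, GreaterThanOrEqual, Literal}

// PPL: case(status_code >= 500, 'Server Error' else 'Unknown')
val branches = Seq(
  (EqualTo(
     Literal(true),
     GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(500))),
   Literal("Server Error")))
val lowered = CaseWhen(branches, Literal("Unknown"))
```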
@@ -596,4 +596,27 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| )
|)""".stripMargin)
}

protected def createTableHttpLog(testTable: String): Unit = {
sql(s"""
| CREATE TABLE $testTable
|(
| id INT,
| status_code INT,
| request_path STRING,
| timestamp STRING
|)
| USING $tableType $tableOptions
|""".stripMargin)

sql(s"""
| INSERT INTO $testTable
| VALUES (1, 200, '/home', '2023-10-01 10:00:00'),
| (2, null, '/about', '2023-10-01 10:05:00'),
| (3, 500, '/contact', '2023-10-01 10:10:00'),
| (4, 301, '/home', '2023-10-01 10:15:00'),
| (5, 200, '/services', '2023-10-01 10:20:00'),
| (6, 403, '/home', '2023-10-01 10:25:00')
| """.stripMargin)
}
}
@@ -9,7 +9,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort}
import org.apache.spark.sql.streaming.StreamTest

@@ -21,12 +21,14 @@ class FlintSparkPPLEvalITSuite

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"
private val testTableHttpLog = "spark_catalog.default.flint_ppl_test_http_log"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
createPartitionedStateCountryTable(testTable)
createTableHttpLog(testTableHttpLog)
}

protected override def afterEach(): Unit = {
@@ -504,7 +506,134 @@ class FlintSparkPPLEvalITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("eval case function") {
val frame = sql(s"""
| source = $testTableHttpLog |
| eval status_category =
| case(status_code >= 200 AND status_code < 300, 'Success',
| status_code >= 300 AND status_code < 400, 'Redirection',
| status_code >= 400 AND status_code < 500, 'Client Error',
| status_code >= 500, 'Server Error'
| else concat('Incorrect HTTP status code for request ', request_path)
| )
| """.stripMargin)

val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row(1, 200, "/home", "2023-10-01 10:00:00", "Success"),
Row(
2,
null,
"/about",
"2023-10-01 10:05:00",
"Incorrect HTTP status code for request /about"),
Row(3, 500, "/contact", "2023-10-01 10:10:00", "Server Error"),
Row(4, 301, "/home", "2023-10-01 10:15:00", "Redirection"),
Row(5, 200, "/services", "2023-10-01 10:20:00", "Success"),
Row(6, 403, "/home", "2023-10-01 10:25:00", "Client Error"))
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getInt(0))
assert(results.sorted.sameElements(expectedResults.sorted))
val expectedColumns =
Array[String]("id", "status_code", "request_path", "timestamp", "status_category")
assert(frame.columns.sameElements(expectedColumns))

val logicalPlan: LogicalPlan = frame.queryExecution.logical

val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_http_log"))
val conditionValueSequence = Seq(
(greaterOrEqualAndLessThan("status_code", 200, 300), Literal("Success")),
(greaterOrEqualAndLessThan("status_code", 300, 400), Literal("Redirection")),
(greaterOrEqualAndLessThan("status_code", 400, 500), Literal("Client Error")),
(
EqualTo(
Literal(true),
GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(500))),
Literal("Server Error")))
val elseValue = UnresolvedFunction(
"concat",
Seq(
Literal("Incorrect HTTP status code for request "),
UnresolvedAttribute("request_path")),
isDistinct = false)
val caseFunction = CaseWhen(conditionValueSequence, elseValue)
val aliasStatusCategory = Alias(caseFunction, "status_category")()
val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory)
val evalProject = Project(evalProjectList, table)
val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("eval case function in complex pipeline") {
val frame = sql(s"""
| source = $testTableHttpLog
| | where ispresent(status_code)
| | eval status_category =
| case(status_code >= 200 AND status_code < 300, 'Success',
| status_code >= 300 AND status_code < 400, 'Redirection',
| status_code >= 400 AND status_code < 500, 'Client Error',
| status_code >= 500, 'Server Error'
| else 'Unknown'
| )
| | stats count() by status_category
| """.stripMargin)

val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row(1L, "Redirection"),
Row(1L, "Client Error"),
Row(1L, "Server Error"),
Row(2L, "Success"))
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getString(1))
assert(results.sorted.sameElements(expectedResults.sorted))
val expectedColumns = Array[String]("count()", "status_category")
assert(frame.columns.sameElements(expectedColumns))

val logicalPlan: LogicalPlan = frame.queryExecution.logical

val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_http_log"))
val filter = Filter(
UnresolvedFunction(
"isnotnull",
Seq(UnresolvedAttribute("status_code")),
isDistinct = false),
table)
val conditionValueSequence = Seq(
(greaterOrEqualAndLessThan("status_code", 200, 300), Literal("Success")),
(greaterOrEqualAndLessThan("status_code", 300, 400), Literal("Redirection")),
(greaterOrEqualAndLessThan("status_code", 400, 500), Literal("Client Error")),
(
EqualTo(
Literal(true),
GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(500))),
Literal("Server Error")))
val elseValue = Literal("Unknown")
val caseFunction = CaseWhen(conditionValueSequence, elseValue)
val aliasStatusCategory = Alias(caseFunction, "status_category")()
val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory)
val evalProject = Project(evalProjectList, filter)
val aggregation = Aggregate(
Seq(Alias(UnresolvedAttribute("status_category"), "status_category")()),
Seq(
Alias(
UnresolvedFunction("COUNT", Seq(UnresolvedStar(None)), isDistinct = false),
"count()")(),
Alias(UnresolvedAttribute("status_category"), "status_category")()),
evalProject)
val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregation)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

private def greaterOrEqualAndLessThan(fieldName: String, min: Int, max: Int) = {
  // The PPL planner wraps each case() branch condition in EqualTo(Literal(true), ...),
  // so the expected plans are built the same way here.
  val and = And(
    GreaterThanOrEqual(UnresolvedAttribute(fieldName), Literal(min)),
    LessThan(UnresolvedAttribute(fieldName), Literal(max)))
  EqualTo(Literal(true), and)
}

// TODO: excluded fields are not supported yet

ignore("test single eval expression with excluded fields") {
val frame = sql(s"""
| source = $testTable | eval new_field = "New Field" | fields - age
@@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

@@ -19,11 +19,13 @@ class FlintSparkPPLFiltersITSuite

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"
private val duplicationTable = "spark_catalog.default.flint_ppl_test_duplication_table"

override def beforeAll(): Unit = {
super.beforeAll()
// Create test table
createPartitionedStateCountryTable(testTable)
createDuplicationNullableTable(duplicationTable)
}

protected override def afterEach(): Unit = {
@@ -348,4 +350,107 @@ class FlintSparkPPLFiltersITSuite
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("case function used as filter") {
val frame = sql(s"""
| source = $testTable case(country = 'USA', 'The United States of America' else 'Other country') = 'The United States of America'
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("Jake", 70, "California", "USA", 2023, 4),
Row("Hello", 30, "New York", "USA", 2023, 4))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))

assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val conditionValueSequence = Seq(
(
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("country"), Literal("USA"))),
Literal("The United States of America")))
val elseValue = Literal("Other country")
val caseFunction = CaseWhen(conditionValueSequence, elseValue)
val filterExpr = EqualTo(caseFunction, Literal("The United States of America"))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("case function used as filter complex filter") {
val frame = sql(s"""
| source = $duplicationTable
| | eval factor = case(id > 15, id - 14, isnull(name), id - 7, id < 3, id + 1 else 1)
| | where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even'
| | stats count() by factor
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect() // count(), factor
// Define the expected results
val expectedResults: Array[Row] = Array(Row(1, 4), Row(1, 6), Row(2, 2))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val table =
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_duplication_table"))

// case function used in eval command
val conditionValueEval = Seq(
(
EqualTo(Literal(true), GreaterThan(UnresolvedAttribute("id"), Literal(15))),
UnresolvedFunction("-", Seq(UnresolvedAttribute("id"), Literal(14)), isDistinct = false)),
(
EqualTo(
Literal(true),
UnresolvedFunction("isnull", Seq(UnresolvedAttribute("name")), isDistinct = false)),
UnresolvedFunction("-", Seq(UnresolvedAttribute("id"), Literal(7)), isDistinct = false)),
(
EqualTo(Literal(true), LessThan(UnresolvedAttribute("id"), Literal(3))),
UnresolvedFunction("+", Seq(UnresolvedAttribute("id"), Literal(1)), isDistinct = false)))
val aliasCaseFactor = Alias(CaseWhen(conditionValueEval, Literal(1)), "factor")()
val evalProject = Project(Seq(UnresolvedStar(None), aliasCaseFactor), table)

// case in where clause
val conditionValueWhere = Seq(
(
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(2))),
Literal("even")),
(
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(4))),
Literal("even")),
(
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(6))),
Literal("even")),
(
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(8))),
Literal("even")))
val caseFunctionWhere = CaseWhen(conditionValueWhere, Literal("odd"))
val filterPlan = Filter(EqualTo(caseFunctionWhere, Literal("even")), evalProject)

val aggregation = Aggregate(
Seq(Alias(UnresolvedAttribute("factor"), "factor")()),
Seq(
Alias(
UnresolvedFunction("COUNT", Seq(UnresolvedStar(None)), isDistinct = false),
"count()")(),
Alias(UnresolvedAttribute("factor"), "factor")()),
filterPlan)
val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregation)
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}
48 changes: 48 additions & 0 deletions ppl-spark-integration/README.md
@@ -252,6 +252,29 @@ _- **Limitation: new field added by eval command with a function cannot be dropp
- `source = table | where ispresent(b)`
- `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3`
- `source = table | where isempty(a)`
- `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'`
-
```
source = table | eval status_category =
case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Incorrect HTTP status code')
| where case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Incorrect HTTP status code'
) = 'Incorrect HTTP status code'
```
-
```
source = table
| eval factor = case(a > 15, a - 14, isnull(b), a - 7, a < 3, a + 1 else 1)
| where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even'
| stats count() by factor
```
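
As the filter tests in this commit illustrate, a `case` used in `where` becomes a Catalyst `Filter` whose predicate compares the lowered `CaseWhen` against the expected literal. A sketch of that plan shape, assuming the same `EqualTo(Literal(true), ...)` condition wrapping used throughout the tests (`child` stands in for the source relation):

```
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}

// PPL: source = t | where case(country = 'USA', 'US' else 'Other') = 'US'
def caseFilter(child: LogicalPlan): Filter = {
  val caseExpr = CaseWhen(
    Seq((
      EqualTo(Literal(true), EqualTo(UnresolvedAttribute("country"), Literal("USA"))),
      Literal("US"))),
    Literal("Other"))
  Filter(EqualTo(caseExpr, Literal("US")), child)
}
```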

**Filters With Logical Conditions**
- `source = table | where c = 'test' AND a = 1 | fields a,b,c`
@@ -272,6 +295,31 @@ Assumptions: `a`, `b`, `c` are existing fields in `table`
- `source = table | eval f = ispresent(a)`
- `source = table | eval r = coalesce(a, b, c) | fields r`
- `source = table | eval e = isempty(a) | fields e`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))`
-
```
source = table | eval status_category =
case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Unknown'
)
```
-
```
source = table | where ispresent(a) |
eval status_category =
case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Incorrect HTTP status code'
)
| stats count() by status_category
```
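
Note: when the evaluated field is `null` (as in the `(2, null, '/about', ...)` row used by the integration tests), no branch condition matches and `case` falls through to the `else` value. If no `else` branch is given, the result should be `null`, in line with Spark's `CaseWhen` semantics.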

Limitation: Overriding existing field is unsupported, following queries throw exceptions with "Reference 'a' is ambiguous"
- `source = table | eval a = 10 | fields a,b,c`
@@ -99,6 +99,7 @@ APPEND: 'APPEND';

// COMPARISON FUNCTION KEYWORDS
CASE: 'CASE';
ELSE: 'ELSE';
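// ELSE backs the optional default branch in case(cond, value ... else default)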
IN: 'IN';

// LOGICAL KEYWORDS