From 7fd9223792c43997499a7011ecf174e85430506b Mon Sep 17 00:00:00 2001 From: YANGDB Date: Fri, 23 Aug 2024 13:46:23 -0700 Subject: [PATCH] PPL Parse command (#595) * add parse regexp command for PPL * add parse code & classes Signed-off-by: YANGDB * add parse / grok / patterns command Signed-off-by: YANGDB * update tests with more complex tests Signed-off-by: YANGDB Signed-off-by: YANGDB * scalafmtAll fixes Signed-off-by: YANGDB * fix depended top/rare issues update readme with command Signed-off-by: YANGDB --------- Signed-off-by: YANGDB --- .../flint/spark/FlintSparkSuite.scala | 69 +++++ .../spark/ppl/FlintSparkPPLParseITSuite.scala | 220 ++++++++++++++++ .../ppl/FlintSparkPPLTopAndRareITSuite.scala | 104 ++++++-- ppl-spark-integration/README.md | 8 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 3 + .../opensearch/sql/common/grok/Converter.java | 165 ++++++++++++ .../org/opensearch/sql/common/grok/Grok.java | 171 +++++++++++++ .../sql/common/grok/GrokCompiler.java | 199 +++++++++++++++ .../opensearch/sql/common/grok/GrokUtils.java | 59 +++++ .../org/opensearch/sql/common/grok/Match.java | 241 ++++++++++++++++++ .../common/grok/exception/GrokException.java | 50 ++++ .../sql/ppl/CatalystPlanContext.java | 33 ++- .../sql/ppl/CatalystQueryPlanVisitor.java | 56 +++- .../opensearch/sql/ppl/parser/AstBuilder.java | 18 +- .../sql/ppl/parser/AstExpressionBuilder.java | 2 + .../opensearch/sql/ppl/utils/ParseUtils.java | 238 +++++++++++++++++ ...LLogicalPlanParseTranslatorTestSuite.scala | 239 +++++++++++++++++ ...TopAndRareQueriesTranslatorTestSuite.scala | 92 ++++++- 18 files changed, 1914 insertions(+), 53 deletions(-) create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Converter.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Grok.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Match.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/exception/GrokException.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index e93eee790..3f843dbe4 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -100,6 +100,42 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit } } + protected def createPartitionedGrokEmailTable(testTable: String): Unit = { + spark.sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | email STRING, + | street_address STRING + | ) + | USING $tableType $tableOptions + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + val data = Seq( + ("Alice", 30, "alice@example.com", "123 Main St, Seattle", 2023, 4), + ("Bob", 55, "bob@test.org", "456 Elm St, 
Portland", 2023, 5), + ("Charlie", 65, "charlie@domain.net", "789 Pine St, San Francisco", 2023, 4), + ("David", 19, "david@anotherdomain.com", "101 Maple St, New York", 2023, 5), + ("Eve", 21, "eve@examples.com", "202 Oak St, Boston", 2023, 4), + ("Frank", 76, "frank@sample.org", "303 Cedar St, Austin", 2023, 5), + ("Grace", 41, "grace@demo.net", "404 Birch St, Chicago", 2023, 4), + ("Hank", 32, "hank@demonstration.com", "505 Spruce St, Miami", 2023, 5), + ("Ivy", 9, "ivy@examples.com", "606 Fir St, Denver", 2023, 4), + ("Jack", 12, "jack@sample.net", "707 Ash St, Seattle", 2023, 5)) + + data.foreach { case (name, age, email, street_address, year, month) => + spark.sql(s""" + | INSERT INTO $testTable + | PARTITION (year=$year, month=$month) + | VALUES ('$name', $age, '$email', '$street_address') + | """.stripMargin) + } + } protected def createPartitionedAddressTable(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable @@ -241,6 +277,39 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | """.stripMargin) } + protected def createOccupationTopRareTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | occupation STRING, + | country STRING, + | salary INT + | ) + | USING $tableType $tableOptions + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Insert data into the new table + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 'Engineer', 'England' , 100000), + | ('Hello', 'Artist', 'USA', 70000), + | ('John', 'Doctor', 'Canada', 120000), + | ('Rachel', 'Doctor', 'Canada', 220000), + | ('Henry', 'Doctor', 'Canada', 220000), + | ('David', 'Engineer', 'USA', 320000), + | ('Barty', 'Engineer', 'USA', 120000), + | ('David', 'Unemployed', 'Canada', 0), + | ('Jane', 'Scientist', 'Canada', 90000), + | ('Philip', 'Scientist', 'Canada', 190000) + | """.stripMargin) + } + protected def createHobbiesTable(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala new file mode 100644 index 000000000..388de3d31 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala @@ -0,0 +1,220 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import scala.reflect.internal.Reporter.Count + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, Descending, GreaterThan, Literal, NullsLast, RegExpExtract, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, Sort} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLParseITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedGrokEmailTable(testTable) + } 
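For orientation on what this fixture exercises: a minimal sketch of the kind of query the suite below runs against the seeded table (assuming a session with the PPL extension enabled, as `FlintPPLSuite` provides):

```scala
// Hypothetical spark-shell check: `parse` should derive a `host` column
// from the seeded email addresses via a named capture group.
val frame = spark.sql(
  "source = spark_catalog.default.flint_ppl_test | " +
    "parse email '.+@(?<host>.+)' | fields email, host")
frame.show(3, truncate = false)
```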
+ + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test parse email expressions parsing") { + val frame = sql(s""" + | source = $testTable| parse email '.+@(?.+)' | fields email, host ; + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("charlie@domain.net", "domain.net"), + Row("david@anotherdomain.com", "anotherdomain.com"), + Row("hank@demonstration.com", "demonstration.com"), + Row("alice@example.com", "example.com"), + Row("frank@sample.org", "sample.org"), + Row("grace@demo.net", "demo.net"), + Row("jack@sample.net", "sample.net"), + Row("eve@examples.com", "examples.com"), + Row("ivy@examples.com", "examples.com"), + Row("bob@test.org", "test.org")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val emailAttribute = UnresolvedAttribute("email") + val hostAttribute = UnresolvedAttribute("host") + val hostExpression = Alias( + Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))), + "host")() + val expectedPlan = Project( + Seq(emailAttribute, hostAttribute), + Project( + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))) + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("test parse email expressions parsing filter & sort by age") { + val frame = sql(s""" + | source = $testTable| parse email '.+@(?.+)' | where age > 45 | sort - age | fields age, email, host ; + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row(76, "frank@sample.org", "sample.org"), + Row(65, "charlie@domain.net", "domain.net"), + Row(55, "bob@test.org", "test.org")) + + // Compare the results + assert(results.sameElements(expectedResults)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val emailAttribute = UnresolvedAttribute("email") + val ageAttribute = UnresolvedAttribute("age") + val hostExpression = Alias( + Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))), + "host")() + + // Define the corrected expected plan + val expectedPlan = Project( + Seq(ageAttribute, emailAttribute, UnresolvedAttribute("host")), + Sort( + Seq(SortOrder(ageAttribute, Descending, NullsLast, Seq.empty)), + global = true, + Filter( + GreaterThan(ageAttribute, Literal(45)), + Project( + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))))) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test parse email expressions and group by count host ") { + val frame = sql(s""" + | source = $testTable| parse email '.+@(?.+)' | stats count() by host + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = 
Array( + Row(1L, "demonstration.com"), + Row(1L, "example.com"), + Row(1L, "domain.net"), + Row(1L, "anotherdomain.com"), + Row(1L, "sample.org"), + Row(1L, "demo.net"), + Row(1L, "sample.net"), + Row(2L, "examples.com"), + Row(1L, "test.org")) + + // Sort both the results and the expected results + implicit val rowOrdering: Ordering[Row] = Ordering.by(r => (r.getLong(0), r.getString(1))) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val emailAttribute = UnresolvedAttribute("email") + val hostAttribute = UnresolvedAttribute("host") + val hostExpression = Alias( + Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))), + "host")() + + // Define the corrected expected plan + val expectedPlan = Project( + Seq(UnresolvedStar(None)), // Matches the '*' in the Project + Aggregate( + Seq(Alias(hostAttribute, "host")()), // Group by 'host' + Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "count()")(), + Alias(hostAttribute, "host")()), + Project( + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))))) + // Compare the logical plans + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test parse email expressions and top count_host ") { + val frame = sql(s""" + | source = $testTable| parse email '.+@(?.+)' | top 1 host + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2L, "examples.com")) + + // Sort both the results and the expected results + implicit val rowOrdering: Ordering[Row] = Ordering.by(r => (r.getLong(0), r.getString(1))) + assert(results.sorted.sameElements(expectedResults.sorted)) + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val emailAttribute = UnresolvedAttribute("email") + val hostAttribute = UnresolvedAttribute("host") + val hostExpression = Alias( + Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))), + "host")() + + val sortedPlan = Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(hostAttribute), isDistinct = false), + "count_host")(), + Descending, + NullsLast, + Seq.empty)), + global = true, + Aggregate( + Seq(hostAttribute), + Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(hostAttribute), isDistinct = false), + "count_host")(), + hostAttribute), + Project( + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))))) + // Define the corrected expected plan + val expectedPlan = Project( + Seq(UnresolvedStar(None)), // Matches the '*' in the Project + GlobalLimit(Literal(1), LocalLimit(Literal(1), sortedPlan))) + // Compare the logical plans + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala index 09307aa44..f10b6e2f5 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala @@ -21,11 +21,13 @@ class 
FlintSparkPPLTopAndRareITSuite /** Test table and index name */ private val testTable = "spark_catalog.default.flint_ppl_test" + private val newTestTable = "spark_catalog.default.new_flint_ppl_test" override def beforeAll(): Unit = { super.beforeAll() - // Create test table + // Create test tables + createOccupationTopRareTable(newTestTable) createPartitionedMultiRowAddressTable(testTable) } @@ -61,7 +63,7 @@ class FlintSparkPPLTopAndRareITSuite val aggregateExpressions = Seq( Alias( UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), - "count(address)")(), + "count_address")(), addressField) val aggregatePlan = Aggregate( @@ -70,11 +72,16 @@ class FlintSparkPPLTopAndRareITSuite UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("address"), Descending)), + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Ascending)), global = true, aggregatePlan) val expectedPlan = Project(projectList, sortedPlan) - comparePlans(expectedPlan, logicalPlan, false) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } test("create ppl rare address by age field query test") { @@ -101,7 +108,7 @@ class FlintSparkPPLTopAndRareITSuite val countExpr = Alias( UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), - "count(address)")() + "count_address")() val aggregateExpressions = Seq(countExpr, addressField, ageAlias) val aggregatePlan = @@ -112,7 +119,12 @@ class FlintSparkPPLTopAndRareITSuite val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("address"), Descending)), + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Ascending)), global = true, aggregatePlan) @@ -146,7 +158,7 @@ class FlintSparkPPLTopAndRareITSuite val aggregateExpressions = Seq( Alias( UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), - "count(address)")(), + "count_address")(), addressField) val aggregatePlan = Aggregate( @@ -155,17 +167,66 @@ class FlintSparkPPLTopAndRareITSuite UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("address"), Ascending)), + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Descending)), global = true, aggregatePlan) val expectedPlan = Project(projectList, sortedPlan) - comparePlans(expectedPlan, logicalPlan, false) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } - test("create ppl top 3 countries by occupation field query test") { - val newTestTable = "spark_catalog.default.new_flint_ppl_test" - createOccupationTable(newTestTable) + test("create ppl top 3 countries query test") { + val frame = sql(s""" + | source = $newTestTable| top 3 country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + val expectedRows = Set(Row(6, "Canada"), Row(3, "USA"), Row(1, "England")) + val actualRows = results.take(3).toSet + + // Compare the sets + assert( + actualRows == expectedRows, + s"The first two results do not match the expected rows. 
Expected: $expectedRows, Actual: $actualRows")
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val countryField = UnresolvedAttribute("country")
+    val countExpr = Alias(
+      UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false),
+      "count_country")()
+    val aggregateExpressions = Seq(countExpr, countryField)
+    val aggregatePlan =
+      Aggregate(
+        Seq(countryField),
+        aggregateExpressions,
+        UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test")))
+
+    val sortedPlan: LogicalPlan =
+      Sort(
+        Seq(
+          SortOrder(
+            Alias(
+              UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false),
+              "count_country")(),
+            Descending)),
+        global = true,
+        aggregatePlan)
+    val planWithLimit =
+      GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan))
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
+  }
+
+  test("create ppl top 3 countries by occupation field query test") {
     val frame = sql(s"""
          | source = $newTestTable| top 3 country by occupation
          | """.stripMargin)
@@ -174,10 +235,8 @@ class FlintSparkPPLTopAndRareITSuite
     val results: Array[Row] = frame.collect()
     assert(results.length == 3)
 
-    val expectedRows = Set(
-      Row(1, "Canada", "Doctor"),
-      Row(1, "Canada", "Scientist"),
-      Row(1, "Canada", "Unemployed"))
+    val expectedRows =
+      Set(Row(3, "Canada", "Doctor"), Row(2, "Canada", "Scientist"), Row(2, "USA", "Engineer"))
     val actualRows = results.take(3).toSet
 
     // Compare the sets
@@ -187,14 +246,13 @@ class FlintSparkPPLTopAndRareITSuite
 
     // Retrieve the logical plan
     val logicalPlan: LogicalPlan = frame.queryExecution.logical
-    val countryField = UnresolvedAttribute("country")
     val occupationField = UnresolvedAttribute("occupation")
     val occupationFieldAlias = Alias(occupationField, "occupation")()
 
     val countExpr = Alias(
       UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false),
-      "count(country)")()
+      "count_country")()
     val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias)
     val aggregatePlan =
       Aggregate(
@@ -204,13 +262,19 @@ class FlintSparkPPLTopAndRareITSuite
 
     val sortedPlan: LogicalPlan =
       Sort(
-        Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
+        Seq(
+          SortOrder(
+            Alias(
+              UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false),
+              "count_country")(),
+            Descending)),
         global = true,
         aggregatePlan)
     val planWithLimit =
       GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan))
     val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
-    comparePlans(expectedPlan, logicalPlan, false)
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
+  }
 }
diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md
index bc8a96c52..24639e444 100644
--- a/ppl-spark-integration/README.md
+++ b/ppl-spark-integration/README.md
@@ -306,6 +306,14 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
 - `source=accounts | top 1 gender`
 - `source=accounts | top 1 age by gender`
 
+**Parse**
+- `source=accounts | parse email '.+@(?<host>.+)' | fields email, host`
+- `source=accounts | parse email '.+@(?<host>.+)' | top 1 host`
+- `source=accounts | parse email '.+@(?<host>.+)' | stats count() by host`
+- `source=accounts | parse email '.+@(?<host>.+)' | eval eval_result=1 | fields host, eval_result`
+- `source=accounts | parse email '.+@(?<host>.+)' | where age > 45 | sort - age | fields age, email, host`
+- `source=accounts | parse address '(?<streetNumber>\d+) (?<street>.+)' | where streetNumber > 500 | sort num(streetNumber) | fields streetNumber, street`
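A rough Spark equivalent of what the `parse` command compiles to (a sketch, not the exact Catalyst plan the translator builds): the named group is stripped to a plain capture group, evaluated with `regexp_extract`, wrapped in `coalesce` to tolerate nulls, and projected next to the original columns. The `accounts` table is the README's running example.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{coalesce, col, regexp_extract}

val spark = SparkSession.builder().getOrCreate()

// Approximates: source=accounts | parse email '.+@(?<host>.+)' | fields email, host
val parsed = spark
  .table("accounts")
  .withColumn("host", coalesce(regexp_extract(col("email"), ".+@(.+)", 1)))

parsed.select("email", "host").show(truncate = false)
```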
+
 
 > For additional details on PPL commands - view [PPL Commands Docs](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/index.rst)
 
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
index f4065be6d..3ce7cef7c 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
@@ -41,6 +41,9 @@ commands
   | topCommand
   | rareCommand
   | evalCommand
+  | grokCommand
+  | parseCommand
+  | patternsCommand
   ;
 
 searchCommand
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Converter.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Converter.java
new file mode 100644
index 000000000..ddd3a2bbb
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Converter.java
@@ -0,0 +1,165 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.common.grok;
+
+import java.time.Instant;
+import java.time.LocalDate;
+import java.time.LocalDateTime;
+import java.time.OffsetDateTime;
+import java.time.ZoneId;
+import java.time.ZoneOffset;
+import java.time.ZonedDateTime;
+import java.time.format.DateTimeFormatter;
+import java.time.temporal.TemporalAccessor;
+import java.util.AbstractMap;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+
+/** Converts a String argument to the right type. */
+public class Converter {
+
+  public enum Type {
+    BYTE(Byte::valueOf),
+    BOOLEAN(Boolean::valueOf),
+    SHORT(Short::valueOf),
+    INT(Integer::valueOf, "integer"),
+    LONG(Long::valueOf),
+    FLOAT(Float::valueOf),
+    DOUBLE(Double::valueOf),
+    DATETIME(new DateConverter(), "date"),
+    STRING(v -> v, "text");
+
+    public final IConverter<? extends Object> converter;
+    public final List<String> aliases;
+
+    Type(IConverter<? extends Object> converter, String... aliases) {
+      this.converter = converter;
+      this.aliases = Arrays.asList(aliases);
+    }
+  }
+
+  private static final Pattern SPLITTER = Pattern.compile("[:;]");
+
+  private static final Map<String, Type> TYPES =
+      Arrays.stream(Type.values()).collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t));
+
+  private static final Map<String, Type> TYPE_ALIASES =
+      Arrays.stream(Type.values())
+          .flatMap(
+              type ->
+                  type.aliases.stream().map(alias -> new AbstractMap.SimpleEntry<>(alias, type)))
+          .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+  private static Type getType(String key) {
+    key = key.toLowerCase();
+    Type type = TYPES.getOrDefault(key, TYPE_ALIASES.get(key));
+    if (type == null) {
+      throw new IllegalArgumentException("Invalid data type :" + key);
+    }
+    return type;
+  }
+
+  /** getConverters. */
+  public static Map<String, IConverter<? extends Object>> getConverters(
+      Collection<String> groupNames, Object... params) {
+    return groupNames.stream()
+        .filter(Converter::containsDelimiter)
+        .collect(
+            Collectors.toMap(
+                Function.identity(),
+                key -> {
+                  String[] list = splitGrokPattern(key);
+                  IConverter<? extends Object> converter = getType(list[1]).converter;
+                  if (list.length == 3) {
+                    converter = converter.newConverter(list[2], params);
+                  }
+                  return converter;
+                }));
+  }
+
+  /** getGroupTypes.
*/
+  public static Map<String, Type> getGroupTypes(Collection<String> groupNames) {
+    return groupNames.stream()
+        .filter(Converter::containsDelimiter)
+        .map(Converter::splitGrokPattern)
+        .collect(Collectors.toMap(l -> l[0], l -> getType(l[1])));
+  }
+
+  public static String extractKey(String key) {
+    return splitGrokPattern(key)[0];
+  }
+
+  private static boolean containsDelimiter(String string) {
+    return string.indexOf(':') >= 0 || string.indexOf(';') >= 0;
+  }
+
+  private static String[] splitGrokPattern(String string) {
+    return SPLITTER.split(string, 3);
+  }
+
+  interface IConverter<T> {
+
+    T convert(String value);
+
+    default IConverter<T> newConverter(String param, Object... params) {
+      return this;
+    }
+  }
+
+  static class DateConverter implements IConverter<Instant> {
+
+    private final DateTimeFormatter formatter;
+    private final ZoneId timeZone;
+
+    public DateConverter() {
+      this.formatter = DateTimeFormatter.ISO_DATE_TIME;
+      this.timeZone = ZoneOffset.UTC;
+    }
+
+    private DateConverter(DateTimeFormatter formatter, ZoneId timeZone) {
+      this.formatter = formatter;
+      this.timeZone = timeZone;
+    }
+
+    @Override
+    public Instant convert(String value) {
+      TemporalAccessor dt =
+          formatter.parseBest(
+              value.trim(),
+              ZonedDateTime::from,
+              LocalDateTime::from,
+              OffsetDateTime::from,
+              Instant::from,
+              LocalDate::from);
+      if (dt instanceof ZonedDateTime) {
+        return ((ZonedDateTime) dt).toInstant();
+      } else if (dt instanceof LocalDateTime) {
+        return ((LocalDateTime) dt).atZone(timeZone).toInstant();
+      } else if (dt instanceof OffsetDateTime) {
+        return ((OffsetDateTime) dt).atZoneSameInstant(timeZone).toInstant();
+      } else if (dt instanceof Instant) {
+        return ((Instant) dt);
+      } else if (dt instanceof LocalDate) {
+        return ((LocalDate) dt).atStartOfDay(timeZone).toInstant();
+      } else {
+        return null;
+      }
+    }
+
+    @Override
+    public DateConverter newConverter(String param, Object... params) {
+      if (!(params.length == 1 && params[0] instanceof ZoneId)) {
+        throw new IllegalArgumentException("Invalid parameters");
+      }
+      return new DateConverter(DateTimeFormatter.ofPattern(param), (ZoneId) params[0]);
+    }
+  }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Grok.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Grok.java
new file mode 100644
index 000000000..e0c37af99
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Grok.java
@@ -0,0 +1,171 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.common.grok;
+
+import org.opensearch.sql.common.grok.Converter.IConverter;
+
+import java.io.Serializable;
+import java.time.ZoneId;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * {@code Grok} parses arbitrary text and structures it.
+ * {@code Grok} is simple API that allows you to easily parse logs and other files (single line). + * With {@code Grok}, you can turn unstructured log and event data into structured data. + * + * @since 0.0.1 + */ +public class Grok implements Serializable { + /** Named regex of the originalGrokPattern. */ + private final String namedRegex; + + /** + * Map of the named regex of the originalGrokPattern with id = namedregexid and value = + * namedregex. + */ + private final Map namedRegexCollection; + + /** Original {@code Grok} pattern (expl: %{IP}). */ + private final String originalGrokPattern; + + /** Pattern of the namedRegex. */ + private final Pattern compiledNamedRegex; + + /** {@code Grok} patterns definition. */ + private final Map grokPatternDefinition; + + public final Set namedGroups; + + public final Map groupTypes; + + public final Map> converters; + + /** only use in grok discovery. */ + private String savedPattern = ""; + + /** Grok. */ + public Grok( + String pattern, + String namedRegex, + Map namedRegexCollection, + Map patternDefinitions, + ZoneId defaultTimeZone) { + this.originalGrokPattern = pattern; + this.namedRegex = namedRegex; + this.compiledNamedRegex = Pattern.compile(namedRegex); + this.namedRegexCollection = namedRegexCollection; + this.namedGroups = GrokUtils.getNameGroups(namedRegex); + this.groupTypes = Converter.getGroupTypes(namedRegexCollection.values()); + this.converters = Converter.getConverters(namedRegexCollection.values(), defaultTimeZone); + this.grokPatternDefinition = patternDefinitions; + } + + public String getSaved_pattern() { + return savedPattern; + } + + public void setSaved_pattern(String savedpattern) { + this.savedPattern = savedpattern; + } + + /** + * Get the current map of {@code Grok} pattern. + * + * @return Patterns (name, regular expression) + */ + public Map getPatterns() { + return grokPatternDefinition; + } + + /** + * Get the named regex from the {@code Grok} pattern.
+ * + * @return named regex + */ + public String getNamedRegex() { + return namedRegex; + } + + /** + * Original grok pattern used to compile to the named regex. + * + * @return String Original Grok pattern + */ + public String getOriginalGrokPattern() { + return originalGrokPattern; + } + + /** + * Get the named regex from the given id. + * + * @param id : named regex id + * @return String of the named regex + */ + public String getNamedRegexCollectionById(String id) { + return namedRegexCollection.get(id); + } + + /** + * Get the full collection of the named regex. + * + * @return named RegexCollection + */ + public Map getNamedRegexCollection() { + return namedRegexCollection; + } + + /** + * Match the given log with the named regex. And return the json representation of the + * matched element + * + * @param log : log to match + * @return map containing matches + */ + public Map capture(String log) { + Match match = match(log); + return match.capture(); + } + + /** + * Match the given list of log with the named regex and return the list of json + * representation of the matched elements. + * + * @param logs : list of log + * @return list of maps containing matches + */ + public ArrayList> capture(List logs) { + final ArrayList> matched = new ArrayList<>(); + for (String log : logs) { + matched.add(capture(log)); + } + return matched; + } + + /** + * Match the given text with the named regex {@code Grok} will extract data from the + * string and get an extence of {@link Match}. + * + * @param text : Single line of log + * @return Grok Match + */ + public Match match(CharSequence text) { + if (compiledNamedRegex == null || text == null) { + return Match.EMPTY; + } + + Matcher matcher = compiledNamedRegex.matcher(text); + if (matcher.find()) { + return new Match(text, this, matcher, matcher.start(0), matcher.end(0)); + } + + return Match.EMPTY; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java new file mode 100644 index 000000000..7d51038cd --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java @@ -0,0 +1,199 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.common.grok; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.sql.common.grok.exception.GrokException; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.Reader; +import java.io.Serializable; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.lang.String.format; + +public class GrokCompiler implements Serializable { + + // We don't want \n and commented line + private static final Pattern patternLinePattern = Pattern.compile("^([A-z0-9_]+)\\s+(.*)$"); + + /** {@code Grok} patterns definitions. */ + private final Map grokPatternDefinitions = new HashMap<>(); + + private GrokCompiler() {} + + public static GrokCompiler newInstance() { + return new GrokCompiler(); + } + + public Map getPatternDefinitions() { + return grokPatternDefinitions; + } + + /** + * Registers a new pattern definition. 
+ * + * @param name : Pattern Name + * @param pattern : Regular expression Or {@code Grok} pattern + * @throws GrokException runtime expt + */ + public void register(String name, String pattern) { + name = Objects.requireNonNull(name).trim(); + pattern = Objects.requireNonNull(pattern).trim(); + + if (!name.isEmpty() && !pattern.isEmpty()) { + grokPatternDefinitions.put(name, pattern); + } + } + + /** Registers multiple pattern definitions. */ + public void register(Map patternDefinitions) { + Objects.requireNonNull(patternDefinitions); + patternDefinitions.forEach(this::register); + } + + /** + * Registers multiple pattern definitions from a given inputStream, and decoded as a UTF-8 source. + */ + public void register(InputStream input) throws IOException { + register(input, StandardCharsets.UTF_8); + } + + /** Registers multiple pattern definitions from a given inputStream. */ + public void register(InputStream input, Charset charset) throws IOException { + try (BufferedReader in = new BufferedReader(new InputStreamReader(input, charset))) { + in.lines() + .map(patternLinePattern::matcher) + .filter(Matcher::matches) + .forEach(m -> register(m.group(1), m.group(2))); + } + } + + /** Registers multiple pattern definitions from a given Reader. */ + public void register(Reader input) throws IOException { + new BufferedReader(input) + .lines() + .map(patternLinePattern::matcher) + .filter(Matcher::matches) + .forEach(m -> register(m.group(1), m.group(2))); + } + + public void registerDefaultPatterns() { + registerPatternFromClasspath("/patterns/patterns"); + } + + public void registerPatternFromClasspath(String path) throws GrokException { + registerPatternFromClasspath(path, StandardCharsets.UTF_8); + } + + /** registerPatternFromClasspath. */ + public void registerPatternFromClasspath(String path, Charset charset) throws GrokException { + final InputStream inputStream = this.getClass().getResourceAsStream(path); + try (Reader reader = new InputStreamReader(inputStream, charset)) { + register(reader); + } catch (IOException e) { + throw new GrokException(e.getMessage(), e); + } + } + + /** Compiles a given Grok pattern and returns a Grok object which can parse the pattern. */ + public Grok compile(String pattern) throws IllegalArgumentException { + return compile(pattern, false); + } + + public Grok compile(final String pattern, boolean namedOnly) throws IllegalArgumentException { + return compile(pattern, ZoneOffset.systemDefault(), namedOnly); + } + + /** + * Compiles a given Grok pattern and returns a Grok object which can parse the pattern. + * + * @param pattern : Grok pattern (ex: %{IP}) + * @param defaultTimeZone : time zone used to parse a timestamp when it doesn't contain the time + * zone + * @param namedOnly : Whether to capture named expressions only or not (i.e. 
%{IP:ip} but not
+   *     ${IP})
+   * @return a compiled pattern
+   * @throws IllegalArgumentException when pattern definition is invalid
+   */
+  public Grok compile(final String pattern, ZoneId defaultTimeZone, boolean namedOnly)
+      throws IllegalArgumentException {
+
+    if (StringUtils.isBlank(pattern)) {
+      throw new IllegalArgumentException("{pattern} should not be empty or null");
+    }
+
+    String namedRegex = pattern;
+    int index = 0;
+    // flag for infinite recursion
+    int iterationLeft = 1000;
+    Boolean continueIteration = true;
+    Map<String, String> patternDefinitions = new HashMap<>(grokPatternDefinitions);
+
+    // output
+    Map<String, String> namedRegexCollection = new HashMap<>();
+
+    // Replace %{foo} with the regex (mostly group name regex)
+    // and then compile the regex
+    while (continueIteration) {
+      continueIteration = false;
+      if (iterationLeft <= 0) {
+        throw new IllegalArgumentException("Deep recursion pattern compilation of " + pattern);
+      }
+      iterationLeft--;
+
+      Set<String> namedGroups = GrokUtils.getNameGroups(GrokUtils.GROK_PATTERN.pattern());
+      Matcher matcher = GrokUtils.GROK_PATTERN.matcher(namedRegex);
+      // Match %{Foo:bar} -> pattern name and subname
+      // Match %{Foo=regex} -> add new regex definition
+      if (matcher.find()) {
+        continueIteration = true;
+        Map<String, String> group = GrokUtils.namedGroups(matcher, namedGroups);
+        if (group.get("definition") != null) {
+          patternDefinitions.put(group.get("pattern"), group.get("definition"));
+          group.put("name", group.get("name") + "=" + group.get("definition"));
+        }
+        int count = StringUtils.countMatches(namedRegex, "%{" + group.get("name") + "}");
+        for (int i = 0; i < count; i++) {
+          String definitionOfPattern = patternDefinitions.get(group.get("pattern"));
+          if (definitionOfPattern == null) {
+            throw new IllegalArgumentException(
+                format("No definition for key '%s' found, aborting", group.get("pattern")));
+          }
+          String replacement = String.format("(?<name%d>%s)", index, definitionOfPattern);
+          if (namedOnly && group.get("subname") == null) {
+            replacement = String.format("(?:%s)", definitionOfPattern);
+          }
+          namedRegexCollection.put(
+              "name" + index,
+              (group.get("subname") != null ? group.get("subname") : group.get("name")));
+          namedRegex =
+              StringUtils.replace(namedRegex, "%{" + group.get("name") + "}", replacement, 1);
+          index++;
+        }
+      }
+    }
+
+    if (namedRegex.isEmpty()) {
+      throw new IllegalArgumentException("Pattern not found");
+    }
+
+    return new Grok(pattern, namedRegex, namedRegexCollection, patternDefinitions, defaultTimeZone);
+  }
+}
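A quick usage sketch of the vendored grok classes above (method names are taken from this patch; the `WORD` pattern registered here is illustrative rather than one of the bundled defaults):

```scala
import org.opensearch.sql.common.grok.GrokCompiler

val compiler = GrokCompiler.newInstance()
compiler.register("WORD", "\\w+")

// %{WORD:user} and %{WORD:domain} expand to (?<name0>\w+) and (?<name1>\w+);
// namedRegexCollection then maps name0 -> user and name1 -> domain.
val grok = compiler.compile("%{WORD:user}@%{WORD:domain}", true)

// `match` is a reserved word in Scala, hence the backticks.
val captured = grok.`match`("alice@example").capture()
println(captured) // expected: {user=alice, domain=example}
```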
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java
new file mode 100644
index 000000000..4b145bbbe
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java
@@ -0,0 +1,59 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.common.grok;
+
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * {@code GrokUtils} contains a set of useful tools and methods.
+ *
+ * @since 0.0.6
+ */
+public class GrokUtils {
+
+  /** Extracts a Grok pattern like %{FOO} to FOO; also handles Grok patterns with semantics. */
+  public static final Pattern GROK_PATTERN =
+      Pattern.compile(
+          "%\\{"
+              + "(?<name>"
+              + "(?<pattern>[A-z0-9]+)"
+              + "(?::(?<subname>[A-z0-9_:;,\\-\\/\\s\\.']+))?"
+              + ")"
+              + "(?:=(?<definition>"
+              + "(?:"
+              + "(?:[^{}]+|\\.+)+"
+              + ")+"
+              + ")"
+              + ")?"
+              + "\\}");
+
+  public static final Pattern NAMED_REGEX = Pattern.compile("\\(\\?<([a-zA-Z][a-zA-Z0-9]*)>");
+
+  /** getNameGroups. */
+  public static Set<String> getNameGroups(String regex) {
+    Set<String> namedGroups = new LinkedHashSet<>();
+    Matcher matcher = NAMED_REGEX.matcher(regex);
+    while (matcher.find()) {
+      namedGroups.add(matcher.group(1));
+    }
+    return namedGroups;
+  }
+
+  /** namedGroups. */
+  public static Map<String, String> namedGroups(Matcher matcher, Set<String> groupNames) {
+    Map<String, String> namedGroups = new LinkedHashMap<>();
+    for (String groupName : groupNames) {
+      String groupValue = matcher.group(groupName);
+      namedGroups.put(groupName, groupValue);
+    }
+    return namedGroups;
+  }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Match.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Match.java
new file mode 100644
index 000000000..1c02627c6
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Match.java
@@ -0,0 +1,241 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.common.grok;
+
+import org.opensearch.sql.common.grok.Converter.IConverter;
+import org.opensearch.sql.common.grok.exception.GrokException;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Matcher;
+
+import static java.lang.String.format;
+
+/**
+ * {@code Match} is a representation in the {@code Grok} world of your log.
+ *
+ * @since 0.0.1
+ */
+public class Match {
+  private final CharSequence subject;
+  private final Grok grok;
+  private final Matcher match;
+  private final int start;
+  private final int end;
+  private boolean keepEmptyCaptures = true;
+  private Map<String, Object> capture = Collections.emptyMap();
+
+  /** Create a new {@code Match} object. */
+  public Match(CharSequence subject, Grok grok, Matcher match, int start, int end) {
+    this.subject = subject;
+    this.grok = grok;
+    this.match = match;
+    this.start = start;
+    this.end = end;
+  }
+
+  /** Create an empty grok matcher. */
+  public static final Match EMPTY = new Match("", null, null, 0, 0);
+
+  public Matcher getMatch() {
+    return match;
+  }
+
+  public int getStart() {
+    return start;
+  }
+
+  public int getEnd() {
+    return end;
+  }
+
+  /** Ignore empty captures. */
+  public void setKeepEmptyCaptures(boolean ignore) {
+    // clear any cached captures
+    if (capture.size() > 0) {
+      capture = new LinkedHashMap<>();
+    }
+    this.keepEmptyCaptures = ignore;
+  }
+
+  public boolean isKeepEmptyCaptures() {
+    return this.keepEmptyCaptures;
+  }
+
+  /**
+   * Return the single line of log.
+   *
+   * @return the single line of log
+   */
+  public CharSequence getSubject() {
+    return subject;
+  }
+
+  /**
+   * Match to the subject the regex and save the matched element into a map.
+   *
+   * <p>Multiple values for the same key are stored as a list.
+   */
+  public Map<String, Object> capture() {
+    return capture(false);
+  }
+
+  /**
+   * Private implementation of captureFlattened and capture.
+   *
+   * @param flattened will it flatten values.
+   * @return the matched elements.
+   * @throws GrokException if a key has multiple non-null values, but only if flattened is set to
+   *     true.
+   */
+  private Map<String, Object> capture(boolean flattened) throws GrokException {
+    if (match == null) {
+      return Collections.emptyMap();
+    }
+
+    if (!capture.isEmpty()) {
+      return capture;
+    }
+
+    capture = new LinkedHashMap<>();
+
+    Map<String, String> mappedw = GrokUtils.namedGroups(this.match, this.grok.namedGroups);
+
+    mappedw.forEach(
+        (key, valueString) -> {
+          String id = this.grok.getNamedRegexCollectionById(key);
+          if (id != null && !id.isEmpty()) {
+            key = id;
+          }
+
+          if ("UNWANTED".equals(key)) {
+            return;
+          }
+
+          Object value = valueString;
+          if (valueString != null) {
+            IConverter<? extends Object> converter = grok.converters.get(key);
+
+            if (converter != null) {
+              key = Converter.extractKey(key);
+              try {
+                value = converter.convert(valueString);
+              } catch (Exception e) {
+                capture.put(key + "_grokfailure", e.toString());
+              }
+
+              if (value instanceof String) {
+                value = cleanString((String) value);
+              }
+            } else {
+              value = cleanString(valueString);
+            }
+          } else if (!isKeepEmptyCaptures()) {
+            return;
+          }
+
+          if (capture.containsKey(key)) {
+            Object currentValue = capture.get(key);
+
+            if (flattened) {
+              if (currentValue == null && value != null) {
+                capture.put(key, value);
+              }
+              if (currentValue != null && value != null) {
+                throw new GrokException(
+                    format(
+                        "key '%s' has multiple non-null values, this is not allowed in flattened"
+                            + " mode, values:'%s', '%s'",
+                        key, currentValue, value));
+              }
+            } else {
+              if (currentValue instanceof List) {
+                @SuppressWarnings("unchecked")
+                List<Object> cvl = (List<Object>) currentValue;
+                cvl.add(value);
+              } else {
+                List<Object> list = new ArrayList<>();
+                list.add(currentValue);
+                list.add(value);
+                capture.put(key, list);
+              }
+            }
+          } else {
+            capture.put(key, value);
+          }
+        });
+
+    capture = Collections.unmodifiableMap(capture);
+
+    return capture;
+  }
+
+  /**
+   * Match to the subject the regex and save the matched element into a map.
+   *
+   * <p>Multiple values for the same key are flattened to one value: the sole non-null value
+   * will be captured. Should there be multiple non-null values, a RuntimeException is thrown.
+   *
+   * <p>This can be used in cases like: (foo (.*:message) bar|bar (.*:message) foo) where the
+   * regexp guarantees that only one value will be captured.
+   *
+   * <p>See also {@link #capture} which returns multiple values of the same key as a list.
+   *
+   * @return the matched elements
+   * @throws GrokException if a key has multiple non-null values.
+   */
+  public Map<String, Object> captureFlattened() throws GrokException {
+    return capture(true);
+  }
+
+  /**
+   * Remove surrounding single or double quotes from the string.
+   *
+   * @param value string to clean: "my/text"
+   * @return unquoted string: my/text
+   */
+  private String cleanString(String value) {
+    if (value == null || value.isEmpty()) {
+      return value;
+    }
+
+    char firstChar = value.charAt(0);
+    char lastChar = value.charAt(value.length() - 1);
+
+    if (firstChar == lastChar && (firstChar == '"' || firstChar == '\'')) {
+      if (value.length() <= 2) {
+        return "";
+      } else {
+        int found = 0;
+        for (int i = 1; i < value.length() - 1; i++) {
+          if (value.charAt(i) == firstChar) {
+            found++;
+          }
+        }
+        if (found == 0) {
+          return value.substring(1, value.length() - 1);
+        }
+      }
+    }
+
+    return value;
+  }
+
+  /**
+   * Utility function.
+   *
+   * @return boolean
+   */
+  public Boolean isNull() {
+    return this.match == null;
+  }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/exception/GrokException.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/exception/GrokException.java
new file mode 100644
index 000000000..0e9d6d2dd
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/exception/GrokException.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.common.grok.exception;
+
+/**
+ * Signals that a {@code Grok} exception of some sort has occurred. This class is the general class
+ * of exceptions produced by failed or interrupted Grok operations.
+ *
+ * @since 0.0.4
+ */
+public class GrokException extends RuntimeException {
+
+  private static final long serialVersionUID = 1L;
+
+  /** Creates a new GrokException. */
+  public GrokException() {
+    super();
+  }
+
+  /**
+   * Constructs a new GrokException.
+   *
+   * @param message the reason for the exception
+   * @param cause the underlying Throwable that caused this exception to be thrown.
+   */
+  public GrokException(String message, Throwable cause) {
+    super(message, cause);
+  }
+
+  /**
+   * Constructs a new GrokException.
+   *
+   * @param message the reason for the exception
+   */
+  public GrokException(String message) {
+    super(message);
+  }
+
+  /**
+   * Constructs a new GrokException.
+   *
+   * @param cause the underlying Throwable that caused this exception to be thrown.
+ */ + public GrokException(Throwable cause) { + super(cause); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 3aa579275..e262acbde 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -6,9 +6,13 @@ package org.opensearch.sql.ppl; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; +import org.apache.spark.sql.catalyst.expressions.AttributeReference; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.catalyst.plans.logical.Union; +import org.apache.spark.sql.types.Metadata; +import org.opensearch.sql.data.type.ExprType; import scala.collection.Iterator; import scala.collection.Seq; @@ -30,6 +34,10 @@ * The context used for Catalyst logical plan. */ public class CatalystPlanContext { + /** + * Catalyst relations list + **/ + private List projectedFields = new ArrayList<>(); /** * Catalyst relations list **/ @@ -61,6 +69,10 @@ public List getRelations() { return relations; } + public List getProjectedFields() { + return projectedFields; + } + public LogicalPlan getPlan() { if (this.planBranches.size() == 1) { return planBranches.peek(); @@ -89,7 +101,16 @@ public Optional popNamedParseExpressions() { public Stack getGroupingParseExpressions() { return groupingParseExpressions; } - + + /** + * define new field + * @param symbol + * @return + */ + public LogicalPlan define(Expression symbol) { + namedParseExpressions.push(symbol); + return getPlan(); + } /** * append relation to relations list * @@ -100,6 +121,16 @@ public LogicalPlan withRelation(UnresolvedRelation relation) { this.relations.add(relation); return with(relation); } + /** + * append projected fields + * + * @param projectedFields + * @return + */ + public LogicalPlan withProjectedFields(List projectedFields) { + this.projectedFields.addAll(projectedFields); + return getPlan(); + } /** * append plan with evolving plans branches diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index e78be65f7..6caaec839 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -9,23 +9,26 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.expressions.AttributeReference; import org.apache.spark.sql.catalyst.expressions.Ascending$; +import org.apache.spark.sql.catalyst.expressions.Coalesce; import org.apache.spark.sql.catalyst.expressions.Descending$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.RegExpExtract; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; 
+import org.apache.spark.sql.catalyst.expressions.StringRegexExpression; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; import org.apache.spark.sql.catalyst.plans.logical.Deduplicate; -import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; import org.apache.spark.sql.catalyst.plans.logical.Limit; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.catalyst.plans.logical.Union; -import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.AggregateFunction; @@ -45,6 +48,7 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.ParseMethod; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedExpression; @@ -61,6 +65,7 @@ import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareAggregation; import org.opensearch.sql.ast.tree.RareTopN; @@ -70,6 +75,7 @@ import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; +import org.opensearch.sql.ppl.utils.ParseUtils; import org.opensearch.sql.ppl.utils.SortUtils; import scala.Option; import scala.Option$; @@ -77,6 +83,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.function.BiFunction; @@ -84,6 +91,8 @@ import static java.util.Collections.emptyList; import static java.util.List.of; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.StringType; import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; @@ -197,7 +206,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex // set sort direction according to command type (`rare` is Asc, `top` is Desc, default to Asc) List sortDirections = new ArrayList<>(); - sortDirections.add(node instanceof RareAggregation ? Descending$.MODULE$ : Ascending$.MODULE$); + sortDirections.add(node instanceof RareAggregation ? 
Ascending$.MODULE$ : Descending$.MODULE$);
 
         if (!node.getSortExprList().isEmpty()) {
             visitExpressionList(node.getSortExprList(), context);
@@ -231,7 +240,7 @@ public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) {
     @Override
     public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
         LogicalPlan child = node.getChild().get(0).accept(this, context);
-        List<Expression> expressionList = visitExpressionList(node.getProjectList(), context);
+        context.withProjectedFields(visitExpressionList(node.getProjectList(), context));
 
         // Create a projection list from the existing expressions
         Seq<NamedExpression> projectList = seq(context.getNamedParseExpressions());
@@ -277,6 +286,45 @@ private Expression visitExpression(UnresolvedExpression expression, CatalystPlan
         return expressionAnalyzer.analyze(expression, context);
     }
 
+    @Override
+    public LogicalPlan visitParse(Parse node, CatalystPlanContext context) {
+        LogicalPlan child = node.getChild().get(0).accept(this, context);
+        Expression sourceField = visitExpression(node.getSourceField(), context);
+        ParseMethod parseMethod = node.getParseMethod();
+        java.util.Map<String, Literal> arguments = node.getArguments();
+        String pattern = (String) node.getPattern().getValue();
+        return visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context);
+    }
+
+    private LogicalPlan visitParseCommand(Parse node, Expression sourceField, ParseMethod parseMethod, Map<String, Literal> arguments, String pattern, CatalystPlanContext context) {
+        List<String> namedGroupCandidates = ParseUtils.getNamedGroupCandidates(parseMethod, pattern, arguments);
+        String cleanedPattern = ParseUtils.extractPatterns(parseMethod, pattern, namedGroupCandidates);
+        for (int i = 0; i < namedGroupCandidates.size(); i++) {
+            String group = namedGroupCandidates.get(i);
+            // first create the regExp
+            RegExpExtract regExpExtract = new RegExpExtract(sourceField,
+                    org.apache.spark.sql.catalyst.expressions.Literal.create(cleanedPattern, StringType),
+                    org.apache.spark.sql.catalyst.expressions.Literal.create(i + 1, IntegerType));
+            // next create Coalesce to handle potential null values
+            Coalesce coalesce = new Coalesce(seq(regExpExtract));
+            // next Alias the extracted fields
+            context.getNamedParseExpressions().push(
+                    org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(coalesce,
+                            group,
+                            NamedExpression.newExprId(),
+                            seq(new java.util.ArrayList<Expression>()),
+                            Option.empty(),
+                            seq(new java.util.ArrayList<String>())));
+        }
+        // Create an UnresolvedStar for all-fields projection (possible external wrapping projection that may include additional fields)
+        context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<String>>empty()));
+        // extract all fields to project with
+        Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
+        // build the plan with the projection step
+        LogicalPlan child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
+        return child;
+    }
+
     @Override
     public LogicalPlan visitEval(Eval node, CatalystPlanContext context) {
         LogicalPlan child = node.getChild().get(0).accept(this, context);
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
index 7d91bbb7a..fdb11c342 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
@@ -278,12 +278,11 @@ public UnresolvedPlan 
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
index 7d91bbb7a..fdb11c342 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
@@ -278,12 +278,11 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo
   public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) {
     ImmutableList.Builder<UnresolvedExpression> aggListBuilder = new ImmutableList.Builder<>();
     ImmutableList.Builder<UnresolvedExpression> groupListBuilder = new ImmutableList.Builder<>();
-    ImmutableList.Builder<UnresolvedExpression> sortListBuilder = new ImmutableList.Builder<>();
     ctx.fieldList().fieldExpression().forEach(field -> {
       UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field),
           Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER))));
       String name = field.qualifiedName().getText();
-      Alias alias = new Alias("count("+name+")", aggExpression);
+      Alias alias = new Alias("count_"+name, aggExpression);
       aggListBuilder.add(alias);
       // group by the `field-list` as the mandatory groupBy fields
       groupListBuilder.add(internalVisitExpression(field));
@@ -305,16 +304,12 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx)
             .collect(Collectors.toList()))
         .orElse(emptyList())
     );
-    //build the sort fields
-    ctx.fieldList().fieldExpression().forEach(field -> {
-      sortListBuilder.add(internalVisitExpression(field));
-    });
     UnresolvedExpression unresolvedPlan = (ctx.number != null ? internalVisitExpression(ctx.number) : null);
     TopAggregation aggregation =
         new TopAggregation(
             Optional.ofNullable((Literal) unresolvedPlan),
             aggListBuilder.build(),
-            sortListBuilder.build(),
+            aggListBuilder.build(),
             groupListBuilder.build());
     return aggregation;
   }
@@ -324,12 +319,11 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx)
   public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) {
     ImmutableList.Builder<UnresolvedExpression> aggListBuilder = new ImmutableList.Builder<>();
     ImmutableList.Builder<UnresolvedExpression> groupListBuilder = new ImmutableList.Builder<>();
-    ImmutableList.Builder<UnresolvedExpression> sortListBuilder = new ImmutableList.Builder<>();
     ctx.fieldList().fieldExpression().forEach(field -> {
       UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field),
           Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER))));
       String name = field.qualifiedName().getText();
-      Alias alias = new Alias("count("+name+")", aggExpression);
+      Alias alias = new Alias("count_"+name, aggExpression);
       aggListBuilder.add(alias);
       // group by the `field-list` as the mandatory groupBy fields
       groupListBuilder.add(internalVisitExpression(field));
@@ -351,14 +345,10 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct
             .collect(Collectors.toList()))
         .orElse(emptyList())
     );
-    //build the sort fields
-    ctx.fieldList().fieldExpression().forEach(field -> {
-      sortListBuilder.add(internalVisitExpression(field));
-    });
     RareAggregation aggregation =
         new RareAggregation(
             aggListBuilder.build(),
-            sortListBuilder.build(),
+            aggListBuilder.build(),
             groupListBuilder.build());
     return aggregation;
   }
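Since `top` and `rare` now sort by the generated count_<field> alias rather than by the field itself, the behavior matches this DataFrame-level sketch (data is illustrative; `rare` would use .asc in place of .desc):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.count

    object TopSketch extends App {
      val spark = SparkSession.builder().master("local[1]").appName("top-sketch").getOrCreate()
      import spark.implicits._

      val accounts = Seq(("Jake", "Engineer", "England"), ("John", "Doctor", "Canada"))
        .toDF("name", "occupation", "country")

      // `source=accounts | top 3 country by occupation`: count per group into the
      // `count_country` alias, sort by that count (desc for top, asc for rare), limit 3.
      accounts.groupBy($"country", $"occupation")
        .agg(count($"country").as("count_country"))
        .orderBy($"count_country".desc)
        .limit(3)
        .show()
    }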
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
index 352853398..aa0abe7f3 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
@@ -26,12 +26,14 @@
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.Not;
 import org.opensearch.sql.ast.expression.Or;
+import org.opensearch.sql.ast.expression.ParseMethod;
 import org.opensearch.sql.ast.expression.QualifiedName;
 import org.opensearch.sql.ast.expression.Span;
 import org.opensearch.sql.ast.expression.SpanUnit;
 import org.opensearch.sql.ast.expression.UnresolvedArgument;
 import org.opensearch.sql.ast.expression.UnresolvedExpression;
 import org.opensearch.sql.ast.expression.Xor;
+import org.opensearch.sql.ast.tree.Parse;

 import org.opensearch.sql.common.utils.StringUtils;
 import org.opensearch.sql.ppl.utils.ArgumentFactory;
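The ParseUtils class added below dispatches on ParseMethod; for REGEX, the derived fields are discovered from `(?<name>...)` named capture groups. A small self-contained sketch of that discovery step (mirrors the GROUP_PATTERN regex below):

    import java.util.regex.Pattern
    import scala.collection.mutable.ListBuffer

    object NamedGroupSketch extends App {
      val GroupPattern = Pattern.compile("\\(\\?<([a-zA-Z][a-zA-Z0-9]*)>")

      def namedGroups(pattern: String): List[String] = {
        val m = GroupPattern.matcher(pattern)
        val names = ListBuffer.empty[String]
        while (m.find()) names += m.group(1)
        names.toList
      }

      println(namedGroups(".+@(?<host>.+)"))                      // List(host)
      println(namedGroups("(?<streetNumber>\\d+) (?<street>.+)")) // List(streetNumber, street)
    }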
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java
new file mode 100644
index 000000000..54b43db0e
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java
@@ -0,0 +1,238 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.ppl.utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import org.opensearch.sql.ast.expression.Literal;
+import org.opensearch.sql.ast.expression.ParseMethod;
+import org.opensearch.sql.common.grok.Grok;
+import org.opensearch.sql.common.grok.GrokCompiler;
+import org.opensearch.sql.common.grok.Match;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+
+public class ParseUtils {
+  private static final String NEW_FIELD_KEY = "new_field";
+
+  /**
+   * Construct corresponding ParseExpression by {@link ParseMethod}.
+   *
+   * @param parseMethod method used to parse
+   * @param pattern pattern used for parsing
+   * @param identifier derived field
+   * @return {@link ParseExpression}
+   */
+  public static ParseExpression createParseExpression(
+      ParseMethod parseMethod, String pattern, String identifier) {
+    switch (parseMethod) {
+      case GROK: return new GrokExpression(pattern, identifier);
+      case PATTERNS: return new PatternsExpression(pattern, identifier);
+      default: return new RegexExpression(pattern, identifier);
+    }
+  }
+
+  /**
+   * Get list of derived fields based on parse pattern.
+   *
+   * @param pattern pattern used for parsing
+   * @return list of names of the derived fields
+   */
+  public static List<String> getNamedGroupCandidates(
+      ParseMethod parseMethod, String pattern, Map<String, Literal> arguments) {
+    switch (parseMethod) {
+      case REGEX:
+        return RegexExpression.getNamedGroupCandidates(pattern);
+      case GROK:
+        return GrokExpression.getNamedGroupCandidates(pattern);
+      default:
+        return PatternsExpression.getNamedGroupCandidates(
+            arguments.containsKey(NEW_FIELD_KEY)
+                ? (String) arguments.get(NEW_FIELD_KEY).getValue()
+                : null);
+    }
+  }
+
+  /**
+   * Extract the cleaned pattern without the additional fields.
+   *
+   * @param parseMethod method used to parse
+   * @param pattern pattern used for parsing
+   * @param columns names of the derived fields
+   * @return the pattern with the named groups reduced to plain groups
+   */
+  public static String extractPatterns(
+      ParseMethod parseMethod, String pattern, List<String> columns) {
+    switch (parseMethod) {
+      case REGEX:
+        return RegexExpression.extractPattern(pattern, columns);
+      case GROK:
+        return GrokExpression.extractPattern(pattern, columns);
+      default:
+        return PatternsExpression.extractPattern(pattern, columns);
+    }
+  }
+
+  public static abstract class ParseExpression {
+    abstract String parseValue(String value);
+  }
+
+  public static class RegexExpression extends ParseExpression {
+    private static final Pattern GROUP_PATTERN = Pattern.compile("\\(\\?<([a-zA-Z][a-zA-Z0-9]*)>");
+    private final Pattern regexPattern;
+    protected final String identifier;
+
+    public RegexExpression(String patterns, String identifier) {
+      this.regexPattern = Pattern.compile(patterns);
+      this.identifier = identifier;
+    }
+
+    /**
+     * Get list of derived fields based on parse pattern.
+     *
+     * @param pattern pattern used for parsing
+     * @return list of names of the derived fields
+     */
+    public static List<String> getNamedGroupCandidates(String pattern) {
+      ImmutableList.Builder<String> namedGroups = ImmutableList.builder();
+      Matcher m = GROUP_PATTERN.matcher(pattern);
+      while (m.find()) {
+        namedGroups.add(m.group(1));
+      }
+      return namedGroups.build();
+    }
+
+    @Override
+    public String parseValue(String value) {
+      Matcher matcher = regexPattern.matcher(value);
+      if (matcher.matches()) {
+        return matcher.group(identifier);
+      }
+      return "";
+    }
+
+    public static String extractPattern(String patterns, List<String> columns) {
+      StringBuilder result = new StringBuilder();
+      Matcher matcher = GROUP_PATTERN.matcher(patterns);
+
+      int lastEnd = 0;
+      while (matcher.find()) {
+        String groupName = matcher.group(1);
+        if (columns.contains(groupName)) {
+          result.append(patterns, lastEnd, matcher.start());
+          result.append("(");
+          lastEnd = matcher.end();
+        }
+      }
+      result.append(patterns.substring(lastEnd));
+      return result.toString();
+    }
+  }
+
+  public static class GrokExpression extends ParseExpression {
+    private static final GrokCompiler grokCompiler = GrokCompiler.newInstance();
+    private final Grok grok;
+    private final String identifier;
+
+    public GrokExpression(String pattern, String identifier) {
+      this.grok = grokCompiler.compile(pattern);
+      this.identifier = identifier;
+    }
+
+    @Override
+    public String parseValue(String value) {
+      Match grokMatch = grok.match(value);
+      Map<String, Object> capture = grokMatch.capture();
+      Object match = capture.get(identifier);
+      if (match != null) {
+        return match.toString();
+      }
+      return "";
+    }
+
+    /**
+     * Get list of derived fields based on parse pattern.
+     *
+     * @param pattern pattern used for parsing
+     * @return list of names of the derived fields
+     */
+    public static List<String> getNamedGroupCandidates(String pattern) {
+      Grok grok = grokCompiler.compile(pattern);
+      return grok.namedGroups.stream()
+          .map(grok::getNamedRegexCollectionById)
+          .filter(group -> !group.equals("UNWANTED"))
+          .collect(Collectors.toUnmodifiableList());
+    }
+
+    public static String extractPattern(String patterns, List<String> columns) {
+      //todo implement
+      return patterns;
+    }
+  }
+
+  public static class PatternsExpression extends ParseExpression {
+    public static final String DEFAULT_NEW_FIELD = "patterns_field";
+
+    private static final ImmutableSet<Character> DEFAULT_IGNORED_CHARS =
+        ImmutableSet.copyOf(
+            "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+                .chars()
+                .mapToObj(c -> (char) c)
+                .toArray(Character[]::new));
+    private final boolean useCustomPattern;
+    private Pattern pattern;
+
+    /**
+     * PatternsExpression.
+     *
+     * @param pattern pattern used for parsing
+     * @param identifier derived field
+     */
+    public PatternsExpression(String pattern, String identifier) {
+      useCustomPattern = !pattern.isEmpty();
+      if (useCustomPattern) {
+        this.pattern = Pattern.compile(pattern);
+      }
+    }
+
+    @Override
+    public String parseValue(String value) {
+      if (useCustomPattern) {
+        return pattern.matcher(value).replaceAll("");
+      }
+
+      char[] chars = value.toCharArray();
+      int pos = 0;
+      for (int i = 0; i < chars.length; i++) {
+        if (!DEFAULT_IGNORED_CHARS.contains(chars[i])) {
+          chars[pos++] = chars[i];
+        }
+      }
+      return new String(chars, 0, pos);
+    }
+
+    /**
+     * Get list of derived fields.
+     *
+     * @param identifier identifier used to generate the field name
+     * @return list of names of the derived fields
+     */
+    public static List<String> getNamedGroupCandidates(String identifier) {
+      return ImmutableList.of(Objects.requireNonNullElse(identifier, DEFAULT_NEW_FIELD));
+    }
+
+    public static String extractPattern(String patterns, List<String> columns) {
+      //todo implement
+      return patterns;
+    }
+  }
+
+}
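PatternsExpression's default path (no custom pattern) keeps only the punctuation "shape" of a value by dropping ASCII letters and digits. A rough sketch, approximating the ignored-character set above with isLetterOrDigit:

    object PatternsSketch extends App {
      def defaultPatterns(value: String): String = value.filterNot(_.isLetterOrDigit)

      println(defaultPatterns("alice@example.com")) // "@."
      println(defaultPatterns("123 Main St."))      // "  ."
    }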
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala
new file mode 100644
index 000000000..cfc3d9725
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala
@@ -0,0 +1,239 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.ScalaReflection.universe.Star
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, Descending, GreaterThan, Literal, NamedExpression, NullsFirst, NullsLast, RegExpExtract, SortOrder}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Project, Sort}
+
+class PPLLogicalPlanParseTranslatorTestSuite
+    extends SparkFunSuite
+    with PlanTest
+    with LogicalPlanTestUtils
+    with Matchers {
+
+  private val planTransformer = new CatalystQueryPlanVisitor()
+  private val pplParser = new PPLSyntaxParser()
+
+  test("test parse email & host expressions") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          "source=t | parse email '.+@(?<host>.+)' | fields email, host",
+          isExplain = false),
+        context)
+
+    val emailAttribute = UnresolvedAttribute("email")
+    val hostAttribute = UnresolvedAttribute("host")
+    val hostExpression = Alias(
+      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))),
+      "host")()
+    val expectedPlan = Project(
+      Seq(emailAttribute, hostAttribute),
+      Project(
+        Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+        UnresolvedRelation(Seq("t"))))
+    assert(compareByString(expectedPlan) === compareByString(logPlan))
+  }
+
+  test("test parse email expression") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(pplParser, "source=t | parse email '.+@(?<email>.+)' | fields email", false),
+        context)
+
+    val emailAttribute = UnresolvedAttribute("email")
+    val hostExpression = Alias(
+      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))),
+      "email")()
+    val expectedPlan = Project(
+      Seq(emailAttribute),
+      Project(
+        Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+        UnresolvedRelation(Seq("t"))))
+    assert(compareByString(expectedPlan) === compareByString(logPlan))
+  }
+
+  test("test parse email expression with filter by age and sort by age field") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          "source = t | parse email '.+@(?<host>.+)' | where age > 45 | sort - age | fields age, email, host",
+          isExplain = false),
+        context)
+
+    // Define the expected logical plan
+    val emailAttribute = UnresolvedAttribute("email")
+    val ageAttribute = UnresolvedAttribute("age")
+    val hostExpression = Alias(
+      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))),
+      "host")()
+
+    // Define the corrected expected plan
+    val expectedPlan = Project(
+      Seq(ageAttribute, emailAttribute, UnresolvedAttribute("host")),
+      Sort(
+        Seq(SortOrder(ageAttribute, Descending, NullsLast, Seq.empty)),
+        global = true,
+        Filter(
+          GreaterThan(ageAttribute, Literal(45)),
+          Project(
+            Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+            UnresolvedRelation(Seq("t"))))))
+    assert(compareByString(expectedPlan) === compareByString(logPlan))
+  }
+
+  test("test parse email expression, generate new host field and eval result") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          "source=t | parse email '.+@(?<host>.+)' | eval eval_result=1 | fields host, eval_result",
+          false),
+        context)
+
+    val emailAttribute = UnresolvedAttribute("email")
+    val hostAttribute = UnresolvedAttribute("host")
+    val evalResultAttribute = UnresolvedAttribute("eval_result")
+
+    val hostExpression = Alias(
+      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))),
+      "host")()
+
+    val evalResultExpression = Alias(Literal(1), "eval_result")()
+
+    val expectedPlan = Project(
+      Seq(hostAttribute, evalResultAttribute),
+      Project(
+        Seq(UnresolvedStar(None), evalResultExpression),
+        Project(
+          Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+          UnresolvedRelation(Seq("t")))))
+    assert(compareByString(expectedPlan) === compareByString(logPlan))
+  }
+
+  test("test parse email & host expressions including cast and sort commands") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          "source=t | parse address '(?<streetNumber>\\d+) (?<street>.+)' | where streetNumber > 500 | sort num(streetNumber) | fields streetNumber, street",
+          false),
+        context)
+
+    val addressAttribute = UnresolvedAttribute("address")
+    val streetNumberAttribute = UnresolvedAttribute("streetNumber")
+    val streetAttribute = UnresolvedAttribute("street")
+
+    val streetNumberExpression = Alias(
+      Coalesce(Seq(RegExpExtract(addressAttribute, Literal("(\\d+) (.+)"), Literal("1")))),
+      "streetNumber")()
+
+    val streetExpression = Alias(
+      Coalesce(Seq(RegExpExtract(addressAttribute, Literal("(\\d+) (.+)"), Literal("2")))),
+      "street")()
+
+    val expectedPlan = Project(
+      Seq(streetNumberAttribute, streetAttribute),
+      Sort(
+        Seq(SortOrder(streetNumberAttribute, Ascending, NullsFirst, Seq.empty)),
+        global = true,
+        Filter(
+          GreaterThan(streetNumberAttribute, Literal(500)),
+          Project(
+            Seq(addressAttribute, streetNumberExpression, streetExpression, UnresolvedStar(None)),
+            UnresolvedRelation(Seq("t"))))))
+
+    assert(compareByString(expectedPlan) === compareByString(logPlan))
+  }
+
+  test("test parse email expressions and group by count host ") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(pplParser, "source=t | parse email '.+@(?<host>.+)' | stats count() by host", false),
+        context)
+
+    val emailAttribute = UnresolvedAttribute("email")
+    val hostAttribute = UnresolvedAttribute("host")
+    val hostExpression = Alias(
+      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))),
+      "host")()
+
+    // Define the corrected expected plan
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)), // Matches the '*' in the Project
+      Aggregate(
+        Seq(Alias(hostAttribute, "host")()), // Group by 'host'
+        Seq(
+          Alias(
+            UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
+            "count()")(),
+          Alias(hostAttribute, "host")()),
+        Project(
+          Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+          UnresolvedRelation(Seq("t")))))
+
+    // Compare the logical plans
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  test("test parse email expressions and top count_host ") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(pplParser, "source=t | parse email '.+@(?<host>.+)' | top 1 host", false),
+        context)
+
+    val emailAttribute = UnresolvedAttribute("email")
+    val hostAttribute = UnresolvedAttribute("host")
+    val hostExpression = Alias(
+      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))),
+      "host")()
+
+    val sortedPlan = Sort(
+      Seq(
+        SortOrder(
+          Alias(
+            UnresolvedFunction(Seq("COUNT"), Seq(hostAttribute), isDistinct = false),
+            "count_host")(),
+          Descending,
+          NullsLast,
+          Seq.empty)),
+      global = true,
+      Aggregate(
+        Seq(hostAttribute),
+        Seq(
+          Alias(
+            UnresolvedFunction(Seq("COUNT"), Seq(hostAttribute), isDistinct = false),
+            "count_host")(),
+          hostAttribute),
+        Project(
+          Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+          UnresolvedRelation(Seq("t")))))
+    // Define the corrected expected plan
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)), // Matches the '*' in the Project
+      GlobalLimit(Literal(1), LocalLimit(Literal(1), sortedPlan)))
+    // Compare the logical plans
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+}
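All assertions in this suite share one plan shape: an inner Project carrying the source field, the aliased Coalesce(RegExpExtract(...)), and a star, wrapped by the projection of the `fields` step. A condensed restatement using the Catalyst classes imported above:

    import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
    import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, Literal, RegExpExtract}
    import org.apache.spark.sql.catalyst.plans.logical.Project

    object ExpectedPlanSketch extends App {
      val email = UnresolvedAttribute("email")
      val host =
        Alias(Coalesce(Seq(RegExpExtract(email, Literal(".+@(.+)"), Literal(1)))), "host")()
      val expected = Project(
        Seq(email, UnresolvedAttribute("host")),
        Project(Seq(email, host, UnresolvedStar(None)), UnresolvedRelation(Seq("t"))))
      println(expected.treeString)
    }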
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala
index 5bd5da28c..c6e5a7f38 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala
@@ -30,7 +30,9 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite
     // if successful build ppl logical plan and translate to catalyst logical plan
     val context = new CatalystPlanContext
     val logPlan =
-      planTransformer.visit(plan(pplParser, "source=accounts | rare address", false), context)
+      planTransformer.visit(
+        plan(pplParser, "source=accounts | rare address", isExplain = false),
+        context)

     val addressField = UnresolvedAttribute("address")
     val tableRelation = UnresolvedRelation(Seq("accounts"))
@@ -39,7 +41,7 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite
     val aggregateExpressions = Seq(
       Alias(
         UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
-        "count(address)")(),
+        "count_address")(),
       addressField)

     val aggregatePlan =
@@ -47,11 +49,16 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite

     val sortedPlan: LogicalPlan =
       Sort(
-        Seq(SortOrder(UnresolvedAttribute("address"), Descending)),
+        Seq(
+          SortOrder(
+            Alias(
+              UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
+              "count_address")(),
+            Ascending)),
         global = true,
         aggregatePlan)
     val expectedPlan = Project(projectList, sortedPlan)
-    comparePlans(expectedPlan, logPlan, false)
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
   }

   test("test simple rare command with a by field test") {
@@ -59,7 +66,7 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite
     val context = new CatalystPlanContext
     val logicalPlan =
       planTransformer.visit(
-        plan(pplParser, "source=accounts | rare address by age", false),
+        plan(pplParser, "source=accounts | rare address by age", isExplain = false),
         context)
     // Retrieve the logical plan
     // Define the expected logical plan
@@ -71,7 +78,7 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite
     val countExpr = Alias(
       UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
-      "count(address)")()
+      "count_address")()

     val aggregateExpressions = Seq(countExpr, addressField, ageAlias)
     val aggregatePlan =
@@ -82,19 +89,26 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite

     val sortedPlan: LogicalPlan =
       Sort(
-        Seq(SortOrder(UnresolvedAttribute("address"), Descending)),
+        Seq(
+          SortOrder(
+            Alias(
+              UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
+              "count_address")(),
+            Ascending)),
         global = true,
         aggregatePlan)

     val expectedPlan = Project(projectList, sortedPlan)
-    comparePlans(expectedPlan, logicalPlan, false)
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
   }

   test("test simple top command with a single field") {
     // if successful build ppl logical plan and translate to catalyst logical plan
     val context = new CatalystPlanContext
     val logPlan =
-      planTransformer.visit(plan(pplParser, "source=accounts | top address", false), context)
+      planTransformer.visit(
+        plan(pplParser, "source=accounts | top address", isExplain = false),
+        context)

     val addressField = UnresolvedAttribute("address")
     val tableRelation = UnresolvedRelation(Seq("accounts"))
UnresolvedRelation(Seq("accounts")) @@ -103,7 +117,7 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val aggregateExpressions = Seq( Alias( UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), - "count(address)")(), + "count_address")(), addressField) val aggregatePlan = @@ -111,11 +125,16 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("address"), Ascending)), + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Descending)), global = true, aggregatePlan) val expectedPlan = Project(projectList, sortedPlan) - comparePlans(expectedPlan, logPlan, false) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) } test("test simple top 1 command by age field") { @@ -132,7 +151,7 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val countExpr = Alias( UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), - "count(address)")() + "count_address")() val aggregateExpressions = Seq(countExpr, addressField, ageAlias) val aggregatePlan = Aggregate( @@ -142,7 +161,12 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("address"), Ascending)), + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Descending)), global = true, aggregatePlan) @@ -151,4 +175,44 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) comparePlans(expectedPlan, logPlan, false) } + + test("create ppl top 3 countries by occupation field query test") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=accounts | top 3 country by occupation", false), + context) + + val countryField = UnresolvedAttribute("country") + val occupationField = UnresolvedAttribute("occupation") + val occupationFieldAlias = Alias(occupationField, "occupation")() + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias) + val aggregatePlan = + Aggregate( + Seq(countryField, occupationFieldAlias), + aggregateExpressions, + UnresolvedRelation(Seq("accounts"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), + "count_country")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + }