From 4a668f3a7cf7cc79cb788f1e1c1c536b91983a31 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Sat, 7 Sep 2024 09:32:08 -0700 Subject: [PATCH] update tests Signed-off-by: YANGDB --- .../ppl/FlintSparkPPLPatternsITSuite.scala | 102 +++++------------- 1 file changed, 29 insertions(+), 73 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLPatternsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLPatternsITSuite.scala index d6d98887c..38e4f5544 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLPatternsITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLPatternsITSuite.scala @@ -124,98 +124,54 @@ class FlintSparkPPLPatternsITSuite } test("test patterns email expressions and top count_host ") { - val frame = sql("source=spark_catalog.default.flint_ppl_test | patterns new_field='dot_com' pattern='[0-9]' | top 1 dot_com") + val frame = sql("source=spark_catalog.default.flint_ppl_test | patterns new_field='dot_com' pattern='(.com|.net|.org)' email | stats count() by dot_com ") // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array(Row(2L, ".com")) + val expectedResults: Array[Row] = Array( + Row(1L, "charlie@domain"), + Row(1L, "david@anotherdomain"), + Row(1L, "hank@demonstration"), + Row(1L, "alice@example"), + Row(1L, "frank@sample"), + Row(1L, "grace@demo"), + Row(1L, "jack@sample"), + Row(1L, "eve@examples"), + Row(1L, "ivy@examples"), + Row(1L, "bob@test")) // Sort both the results and the expected results implicit val rowOrdering: Ordering[Row] = Ordering.by(r => (r.getLong(0), r.getString(1))) assert(results.sorted.sameElements(expectedResults.sorted)) + // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - - val emailAttribute = UnresolvedAttribute("email") - val 
hostAttribute = UnresolvedAttribute("host") + val messageAttribute = UnresolvedAttribute("email") + val noNumbersAttribute = UnresolvedAttribute("dot_com") val hostExpression = Alias( - RegExpExtract( - emailAttribute, + RegExpReplace( + messageAttribute, Literal( - ".*(\\.com)$"), - Literal(1)), - "host")() + "(.com|.net|.org)"), + Literal("")), + "dot_com")() - val sortedPlan = Sort( - Seq( - SortOrder( - Alias( - UnresolvedFunction(Seq("COUNT"), Seq(hostAttribute), isDistinct = false), - "count_host")(), - Descending, - NullsLast, - Seq.empty)), - global = true, + // Define the corrected expected plan + val expectedPlan = Project( + Seq(UnresolvedStar(None)), // Matches the '*' in the Project Aggregate( - Seq(hostAttribute), + Seq(Alias(noNumbersAttribute, "dot_com")()), // Group by 'dot_com' Seq( Alias( - UnresolvedFunction(Seq("COUNT"), Seq(hostAttribute), isDistinct = false), - "count_host")(), - hostAttribute), + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "count()")(), + Alias(noNumbersAttribute, "dot_com")()), Project( - Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + Seq(messageAttribute, hostExpression, UnresolvedStar(None)), UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))))) - // Define the corrected expected plan - val expectedPlan = Project( - Seq(UnresolvedStar(None)), // Matches the '*' in the Project - GlobalLimit(Literal(1), LocalLimit(Literal(1), sortedPlan))) + // Compare the logical plans comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } - - test("test patterns address expressions with 2 fields identifies ") { - val frame = sql(s""" - | source= $testTable | grok street_address '%{NUMBER} %{GREEDYDATA:address}' | fields address - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Pine St, San Francisco"), - Row("Maple St, New York"), 
- Row("Spruce St, Miami"), - Row("Main St, Seattle"), - Row("Cedar St, Austin"), - Row("Birch St, Chicago"), - Row("Ash St, Seattle"), - Row("Oak St, Boston"), - Row("Fir St, Denver"), - Row("Elm St, Portland")) - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - - val street_addressAttribute = UnresolvedAttribute("street_address") - val addressAttribute = UnresolvedAttribute("address") - val addressExpression = Alias( - RegExpExtract( - street_addressAttribute, - Literal( - "(?(?:(?(?[+-]?(?:(?:[0-9]+(?:\\.[0-9]+)?)|(?:\\.[0-9]+)))))) (?.*)"), - Literal("3")), - "address")() - val expectedPlan = Project( - Seq(addressAttribute), - Project( - Seq(street_addressAttribute, addressExpression, UnresolvedStar(None)), - UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))) - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - }