From 2e16414758137ca0c685b9e4ffb835905be5e91d Mon Sep 17 00:00:00 2001 From: lukasz-soszynski-eliatra <110241464+lukasz-soszynski-eliatra@users.noreply.github.com> Date: Tue, 8 Oct 2024 22:16:27 +0200 Subject: [PATCH] Fillnull command introduced (#723) * Fillnull command introduced. Signed-off-by: Lukasz Soszynski * Introduced more tests for the Fillnull command, and code preparation for the review. Signed-off-by: Lukasz Soszynski * New syntax applied to the fillnull command Signed-off-by: Lukasz Soszynski --------- Signed-off-by: Lukasz Soszynski --- docs/ppl-lang/PPL-Example-Commands.md | 8 + docs/ppl-lang/README.md | 2 + .../flint/spark/FlintSparkSuite.scala | 23 ++ .../ppl/FlintSparkPPLFillnullITSuite.scala | 303 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 4 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 24 ++ .../sql/ast/AbstractNodeVisitor.java | 4 + .../org/opensearch/sql/ast/tree/FillNull.java | 83 +++++ .../sql/ppl/CatalystQueryPlanVisitor.java | 39 ++- .../opensearch/sql/ppl/parser/AstBuilder.java | 37 +++ ...anFillnullCommandTranslatorTestSuite.scala | 240 ++++++++++++++ 11 files changed, 766 insertions(+), 1 deletion(-) create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFillnullCommandTranslatorTestSuite.scala diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 6a552df24..28c4e0a01 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -88,6 +88,14 @@ Assumptions: `a`, `b`, `c` are existing fields in `table` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))` +#### Fillnull +Assumptions: `a`, `b`, `c`, `d`, `e` are existing fields in `table` +- `source = table | fillnull with 0 in a` +- `source = table | fillnull with 'N/A' in a, b, c` +- `source = table | fillnull with concat(a, b) in c, d` +- `source = table | fillnull using a = 101` +- `source = table | fillnull using a = 101, b = 102` +- `source = table | fillnull using a = concat(b, c), d = 2 * pi() * e` ```sql source = table | eval e = eval status_category = diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index 16ff636f7..2ddceca0a 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -27,6 +27,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`dedup command `](ppl-dedup-command.md) - [`describe command`](PPL-Example-Commands.md/#describe) + + - [`fillnull command`](ppl-fillnull-command.md) - [`eval command`](ppl-eval-command.md) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 78abf7ff2..1ecf48d28 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -619,4 +619,27 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | (6, 403, '/home', '2023-10-01 10:25:00') | """.stripMargin) } + + protected def createNullableTableHttpLog(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + |( + | id INT, + | status_code INT, + | request_path STRING, + | timestamp STRING + |) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES (1, 200, '/home', null), + | (2, null, '/about', '2023-10-01 10:05:00'), + | (3, null, '/contact', '2023-10-01 10:10:00'), + | (4, 301, null, '2023-10-01 10:15:00'), + | (5, 200, null, '2023-10-01 10:20:00'), + | (6, 403, '/home', null) + | """.stripMargin) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala new file mode 100644 index 000000000..4788aa23f --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala @@ -0,0 +1,303 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Expression, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, LogicalPlan, Project, Sort} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLFillnullITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNullableTableHttpLog(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test fillnull with one null replacement value and one column") { + val frame = sql(s""" + | source = $testTable | fillnull with 0 in status_code + | """.stripMargin) + + assert(frame.columns.sameElements(Array("id", "request_path", "timestamp", "status_code"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, "/home", null, 200), + Row(2, "/about", "2023-10-01 10:05:00", 0), + Row(3, "/contact", "2023-10-01 10:10:00", 0), + Row(4, null, "2023-10-01 10:15:00", 301), + Row(5, null, "2023-10-01 10:20:00", 200), + Row(6, "/home", null, 403)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val expectedPlan = fillNullExpectedPlan(Seq(("status_code", Literal(0)))) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test fillnull with various null replacement values and one column") { + val frame = sql(s""" + | source = $testTable | fillnull using status_code=101 + | """.stripMargin) + + assert(frame.columns.sameElements(Array("id", "request_path", "timestamp", "status_code"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, "/home", null, 200), + Row(2, "/about", "2023-10-01 10:05:00", 101), + Row(3, "/contact", "2023-10-01 10:10:00", 101), + Row(4, null, "2023-10-01 10:15:00", 301), + Row(5, null, "2023-10-01 10:20:00", 200), + Row(6, "/home", null, 403)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val expectedPlan = fillNullExpectedPlan(Seq(("status_code", Literal(101)))) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test fillnull with one null replacement value and two columns") { + val frame = sql(s""" + | source = $testTable | fillnull with concat('??', '?') in request_path, timestamp | fields id, request_path, timestamp + | """.stripMargin) + + assert(frame.columns.sameElements(Array("id", "request_path", "timestamp"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, "/home", "???"), + Row(2, "/about", "2023-10-01 10:05:00"), + Row(3, "/contact", "2023-10-01 10:10:00"), + Row(4, "???", "2023-10-01 10:15:00"), + Row(5, "???", "2023-10-01 10:20:00"), + Row(6, "/home", "???")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val fillNullPlan = fillNullExpectedPlan( + Seq( + ( + "request_path", + UnresolvedFunction("concat", Seq(Literal("??"), Literal("?")), isDistinct = false)), + ( + "timestamp", + UnresolvedFunction("concat", Seq(Literal("??"), Literal("?")), isDistinct = false))), + addDefaultProject = false) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("request_path"), + UnresolvedAttribute("timestamp")), + fillNullPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test fillnull with various null replacement values and two columns") { + val frame = sql(s""" + | source = $testTable | fillnull using request_path=upper('/not_found'), timestamp='*' | fields id, request_path, timestamp + | """.stripMargin) + + assert(frame.columns.sameElements(Array("id", "request_path", "timestamp"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, "/home", "*"), + Row(2, "/about", "2023-10-01 10:05:00"), + Row(3, "/contact", "2023-10-01 10:10:00"), + Row(4, "/NOT_FOUND", "2023-10-01 10:15:00"), + Row(5, "/NOT_FOUND", "2023-10-01 10:20:00"), + Row(6, "/home", "*")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val fillNullPlan = fillNullExpectedPlan( + Seq( + ( + "request_path", + UnresolvedFunction("upper", Seq(Literal("/not_found")), isDistinct = false)), + ("timestamp", Literal("*"))), + addDefaultProject = false) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("request_path"), + UnresolvedAttribute("timestamp")), + fillNullPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test fillnull with one null replacement value and stats and sort command") { + val frame = sql(s""" + | source = $testTable | fillnull with 500 in status_code + | | stats count(status_code) by status_code, request_path + | | sort request_path, status_code + | """.stripMargin) + + assert(frame.columns.sameElements(Array("count(status_code)", "status_code", "request_path"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, 200, null), + Row(1, 301, null), + Row(1, 500, "/about"), + Row(1, 500, "/contact"), + Row(1, 200, "/home"), + Row(1, 403, "/home")) + // Compare the results + assert(results.sameElements(expectedResults)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val fillNullPlan = + fillNullExpectedPlan(Seq(("status_code", Literal(500))), addDefaultProject = false) + val aggregateExpressions = + Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "count(status_code)")(), + Alias(UnresolvedAttribute("status_code"), "status_code")(), + Alias(UnresolvedAttribute("request_path"), "request_path")()) + val aggregatePlan = Aggregate( + Seq( + Alias(UnresolvedAttribute("status_code"), "status_code")(), + Alias(UnresolvedAttribute("request_path"), "request_path")()), + aggregateExpressions, + fillNullPlan) + val sortPlan = Sort( + Seq( + SortOrder(UnresolvedAttribute("request_path"), Ascending), + SortOrder(UnresolvedAttribute("status_code"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(seq(UnresolvedStar(None)), sortPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test fillnull with various null replacement value and stats and sort command") { + val frame = sql(s""" + | source = $testTable | fillnull using status_code = 500, request_path = '/home' + | | stats count(status_code) by status_code, request_path + | | sort request_path, status_code + | """.stripMargin) + + assert(frame.columns.sameElements(Array("count(status_code)", "status_code", "request_path"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, 500, "/about"), + Row(1, 500, "/contact"), + Row(2, 200, "/home"), + Row(1, 301, "/home"), + Row(1, 403, "/home")) + // Compare the results + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val fillNullPlan = fillNullExpectedPlan( + Seq(("status_code", Literal(500)), ("request_path", Literal("/home"))), + addDefaultProject = false) + val aggregateExpressions = + Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "count(status_code)")(), + Alias(UnresolvedAttribute("status_code"), "status_code")(), + Alias(UnresolvedAttribute("request_path"), "request_path")()) + val aggregatePlan = Aggregate( + Seq( + Alias(UnresolvedAttribute("status_code"), "status_code")(), + Alias(UnresolvedAttribute("request_path"), "request_path")()), + aggregateExpressions, + fillNullPlan) + val sortPlan = Sort( + Seq( + SortOrder(UnresolvedAttribute("request_path"), Ascending), + SortOrder(UnresolvedAttribute("status_code"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(seq(UnresolvedStar(None)), sortPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test fillnull with one null replacement value and missing columns") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | fillnull with '!!!' in + | """.stripMargin)) + + assert(ex.getMessage().contains("Syntax error ")) + } + + test("test fillnull with various null replacement values and missing columns") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | fillnull using + | """.stripMargin)) + + assert(ex.getMessage().contains("Syntax error ")) + } + + private def fillNullExpectedPlan( + nullReplacements: Seq[(String, Expression)], + addDefaultProject: Boolean = true): LogicalPlan = { + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val renameProjectList = UnresolvedStar(None) +: nullReplacements.map { + case (nullableColumn, nullReplacement) => + Alias( + UnresolvedFunction( + "coalesce", + Seq(UnresolvedAttribute(nullableColumn), nullReplacement), + isDistinct = false), + nullableColumn)() + } + val renameProject = Project(renameProjectList, table) + val droppedColumns = + nullReplacements.map(_._1).map(columnName => UnresolvedAttribute(columnName)) + val dropSourceColumn = DataFrameDropColumns(droppedColumns, renameProject) + if (addDefaultProject) { + Project(seq(UnresolvedStar(None)), dropSourceColumn) + } else { + dropSourceColumn + } + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 64eaf415d..dd43007f4 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -35,6 +35,7 @@ NEW_FIELD: 'NEW_FIELD'; KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; +FILLNULL: 'FILLNULL'; //Native JOIN KEYWORDS JOIN: 'JOIN'; @@ -72,6 +73,9 @@ INDEX: 'INDEX'; D: 'D'; DESC: 'DESC'; DATASOURCES: 'DATASOURCES'; +VALUE: 'VALUE'; +USING: 'USING'; +WITH: 'WITH'; // CLAUSE KEYWORDS SORTBY: 'SORTBY'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 06b3166f0..fb1c79bd2 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -51,6 +51,7 @@ commands | patternsCommand | lookupCommand | renameCommand + | fillnullCommand ; searchCommand @@ -184,6 +185,29 @@ lookupPair : inputField = fieldExpression (AS outputField = fieldExpression)? ; +fillnullCommand + : FILLNULL (fillNullWithTheSameValue + | fillNullWithFieldVariousValues) + ; + + fillNullWithTheSameValue + : WITH nullReplacement IN nullableField (COMMA nullableField)* + ; + + fillNullWithFieldVariousValues + : USING nullableField EQUAL nullReplacement (COMMA nullableField EQUAL nullReplacement)* + ; + + + nullableField + : fieldExpression + ; + + nullReplacement + : expression + ; + + kmeansCommand : KMEANS (kmeansParameter)* ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 76f9479f4..e42306965 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -56,6 +56,7 @@ import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TableFunction; import org.opensearch.sql.ast.tree.Values; +import org.opensearch.sql.ast.tree.*; /** AST nodes visitor Defines the traverse path. */ public abstract class AbstractNodeVisitor { @@ -293,4 +294,7 @@ public T visitExplain(Explain node, C context) { public T visitInSubquery(InSubquery node, C context) { return visitChildren(node, context); } + public T visitFillNull(FillNull fillNull, C context) { + return visitChildren(fillNull, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java new file mode 100644 index 000000000..19bfea668 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java @@ -0,0 +1,83 @@ +package org.opensearch.sql.ast.tree; + +import lombok.Getter; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +@RequiredArgsConstructor +public class FillNull extends UnresolvedPlan { + + @Getter + @RequiredArgsConstructor + public static class NullableFieldFill { + @NonNull + private final Field nullableFieldReference; + @NonNull + private final UnresolvedExpression replaceNullWithMe; + } + + public interface ContainNullableFieldFill { + List getNullFieldFill(); + + static ContainNullableFieldFill ofVariousValue(List replacements) { + return new VariousValueNullFill(replacements); + } + + static ContainNullableFieldFill ofSameValue(UnresolvedExpression replaceNullWithMe, List nullableFieldReferences) { + return new SameValueNullFill(replaceNullWithMe, nullableFieldReferences); + } + } + + private static class SameValueNullFill implements ContainNullableFieldFill { + @Getter(onMethod_ = @Override) + private final List nullFieldFill; + + public SameValueNullFill(UnresolvedExpression replaceNullWithMe, List nullableFieldReferences) { + Objects.requireNonNull(replaceNullWithMe, "Null replacement is required"); + this.nullFieldFill = Objects.requireNonNull(nullableFieldReferences, "Nullable field reference is required") + .stream() + .map(nullableReference -> new NullableFieldFill(nullableReference, replaceNullWithMe)) + .collect(Collectors.toList()); + } + } + + @RequiredArgsConstructor + private static class VariousValueNullFill implements ContainNullableFieldFill { + @NonNull + @Getter(onMethod_ = @Override) + private final List nullFieldFill; + } + + private UnresolvedPlan child; + + @NonNull + private final ContainNullableFieldFill containNullableFieldFill; + + public List getNullableFieldFills() { + return containNullableFieldFill.getNullFieldFill(); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? List.of() : List.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFillNull(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index bd1785c85..e6ab083ee 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -7,6 +7,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier; import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Ascending$; @@ -60,6 +61,7 @@ import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Join; @@ -84,6 +86,7 @@ import scala.Option; import scala.Option$; import scala.Tuple2; +import scala.collection.IterableLike; import scala.collection.Seq; import java.util.*; @@ -388,6 +391,37 @@ public LogicalPlan visitHead(Head node, CatalystPlanContext context) { node.getSize(), DataTypes.IntegerType), p)); } + @Override + public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) { + fillNull.getChild().get(0).accept(this, context); + List aliases = new ArrayList<>(); + for(FillNull.NullableFieldFill nullableFieldFill : fillNull.getNullableFieldFills()) { + Field field = nullableFieldFill.getNullableFieldReference(); + UnresolvedExpression replaceNullWithMe = nullableFieldFill.getReplaceNullWithMe(); + Function coalesce = new Function("coalesce", of(field, replaceNullWithMe)); + String fieldName = field.getField().toString(); + Alias alias = new Alias(fieldName, coalesce); + aliases.add(alias); + } + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + // ((Alias) expressionList.get(0)).child().children().head() + List toDrop = visitExpressionList(aliases, context).stream() + .map(org.apache.spark.sql.catalyst.expressions.Alias.class::cast) + .map(org.apache.spark.sql.catalyst.expressions.Alias::child) // coalesce + .map(UnresolvedFunction.class::cast)// coalesce + .map(UnresolvedFunction::children) // Seq of coalesce arguments + .map(IterableLike::head) // first function argument which is source field + .collect(Collectors.toList()); + Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + LogicalPlan resultWithoutDuplicatedColumns = context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(toDrop), logicalPlan)); + return Objects.requireNonNull(resultWithoutDuplicatedColumns, "FillNull operation failed"); + } + private void visitFieldList(List fieldList, CatalystPlanContext context) { fieldList.forEach(field -> visitExpression(field, context)); } @@ -694,7 +728,10 @@ public Expression visitIsEmpty(IsEmpty node, CatalystPlanContext context) { return expression; } - + @Override + public Expression visitFillNull(FillNull fillNull, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : FillNull"); + } @Override public Expression visitInterval(Interval node, CatalystPlanContext context) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 8ab370c7f..8673b1582 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -11,6 +11,8 @@ import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; import org.opensearch.flint.spark.ppl.OpenSearchPPLParser; +import org.opensearch.flint.spark.ppl.OpenSearchPPLParser.FillNullWithFieldVariousValuesContext; +import org.opensearch.flint.spark.ppl.OpenSearchPPLParser.FillNullWithTheSameValueContext; import org.opensearch.flint.spark.ppl.OpenSearchPPLParserBaseVisitor; import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; @@ -28,6 +30,9 @@ import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.*; +import org.opensearch.sql.ast.tree.FillNull.NullableFieldFill; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Correlation; import org.opensearch.sql.ast.tree.Dedupe; @@ -57,9 +62,12 @@ import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.IntStream; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; +import static org.opensearch.sql.ast.tree.FillNull.ContainNullableFieldFill.ofSameValue; +import static org.opensearch.sql.ast.tree.FillNull.ContainNullableFieldFill.ofVariousValue; /** Class of building the AST. Refines the visit path and build the AST nodes */ @@ -498,6 +506,35 @@ public UnresolvedPlan visitKmeansCommand(OpenSearchPPLParser.KmeansCommandContex return new Kmeans(builder.build()); } + @Override + public UnresolvedPlan visitFillnullCommand(OpenSearchPPLParser.FillnullCommandContext ctx) { + // ctx contain result of parsing fillnull command. Lets transform it to UnresolvedPlan which is FillNull + FillNullWithTheSameValueContext sameValueContext = ctx.fillNullWithTheSameValue(); + FillNullWithFieldVariousValuesContext variousValuesContext = ctx.fillNullWithFieldVariousValues(); + if (sameValueContext != null) { + // todo consider using expression instead of Literal + UnresolvedExpression replaceNullWithMe = internalVisitExpression(sameValueContext.nullReplacement().expression()); + List fieldsToReplace = sameValueContext.nullableField() + .stream() + .map(this::internalVisitExpression) + .map(Field.class::cast) + .collect(Collectors.toList()); + return new FillNull(ofSameValue(replaceNullWithMe, fieldsToReplace)); + } else if (variousValuesContext != null) { + List nullableFieldFills = IntStream.range(0, variousValuesContext.nullableField().size()) + .mapToObj(index -> { + variousValuesContext.nullableField(index); + UnresolvedExpression replaceNullWithMe = internalVisitExpression(variousValuesContext.nullReplacement(index).expression()); + Field nullableFieldReference = (Field) internalVisitExpression(variousValuesContext.nullableField(index)); + return new NullableFieldFill(nullableFieldReference, replaceNullWithMe); + }) + .collect(Collectors.toList()); + return new FillNull(ofVariousValue(nullableFieldFills)); + } else { + throw new SyntaxCheckException("Invalid fillnull command"); + } + } + /** AD command. */ @Override public UnresolvedPlan visitAdCommand(OpenSearchPPLParser.AdCommandContext ctx) { diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFillnullCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFillnullCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..9f38465da --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFillnullCommandTranslatorTestSuite.scala @@ -0,0 +1,240 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Project} + +class PPLLogicalPlanFillnullCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test fillnull with one null replacement value and one column") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | fillnull with 'null replacement value' in column_name"), + context) + + val relation = UnresolvedRelation(Seq("relation")) + + val renameProjectList: Seq[NamedExpression] = + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "coalesce", + Seq(UnresolvedAttribute("column_name"), Literal("null replacement value")), + isDistinct = false), + "column_name")()) + val renameProject = Project(renameProjectList, relation) + + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("column_name")), renameProject) + + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test fillnull with one null replacement value, one column and function invocation") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | fillnull with upper(another_field) in column_name"), + context) + + val relation = UnresolvedRelation(Seq("relation")) + + val renameProjectList: Seq[NamedExpression] = + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "coalesce", + Seq( + UnresolvedAttribute("column_name"), + UnresolvedFunction( + "upper", + Seq(UnresolvedAttribute("another_field")), + isDistinct = false)), + isDistinct = false), + "column_name")()) + val renameProject = Project(renameProjectList, relation) + + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("column_name")), renameProject) + + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test fillnull with one null replacement value and multiple column") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | fillnull with 'another null replacement value' in column_name_one, column_name_two, column_name_three"), + context) + + val relation = UnresolvedRelation(Seq("relation")) + + val renameProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "coalesce", + Seq(UnresolvedAttribute("column_name_one"), Literal("another null replacement value")), + isDistinct = false), + "column_name_one")(), + Alias( + UnresolvedFunction( + "coalesce", + Seq(UnresolvedAttribute("column_name_two"), Literal("another null replacement value")), + isDistinct = false), + "column_name_two")(), + Alias( + UnresolvedFunction( + "coalesce", + Seq( + UnresolvedAttribute("column_name_three"), + Literal("another null replacement value")), + isDistinct = false), + "column_name_three")()) + val renameProject = Project(renameProjectList, relation) + + val dropSourceColumn = DataFrameDropColumns( + Seq( + UnresolvedAttribute("column_name_one"), + UnresolvedAttribute("column_name_two"), + UnresolvedAttribute("column_name_three")), + renameProject) + + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test fillnull with possibly various null replacement value and one column") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | fillnull using column_name='null replacement value'"), + context) + + val relation = UnresolvedRelation(Seq("relation")) + + val renameProjectList: Seq[NamedExpression] = + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "coalesce", + Seq(UnresolvedAttribute("column_name"), Literal("null replacement value")), + isDistinct = false), + "column_name")()) + val renameProject = Project(renameProjectList, relation) + + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("column_name")), renameProject) + + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test( + "test fillnull with possibly various null replacement value, one column and function invocation") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | fillnull using column_name=concat('missing value for', id)"), + context) + + val relation = UnresolvedRelation(Seq("relation")) + + val renameProjectList: Seq[NamedExpression] = + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "coalesce", + Seq( + UnresolvedAttribute("column_name"), + UnresolvedFunction( + "concat", + Seq(Literal("missing value for"), UnresolvedAttribute("id")), + isDistinct = false)), + isDistinct = false), + "column_name")()) + val renameProject = Project(renameProjectList, relation) + + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("column_name")), renameProject) + + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test fillnull with possibly various null replacement value and three columns") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | fillnull using column_name_1='null replacement value 1', column_name_2='null replacement value 2', column_name_3='null replacement value 3'"), + context) + + val relation = UnresolvedRelation(Seq("relation")) + + val renameProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "coalesce", + Seq(UnresolvedAttribute("column_name_1"), Literal("null replacement value 1")), + isDistinct = false), + "column_name_1")(), + Alias( + UnresolvedFunction( + "coalesce", + Seq(UnresolvedAttribute("column_name_2"), Literal("null replacement value 2")), + isDistinct = false), + "column_name_2")(), + Alias( + UnresolvedFunction( + "coalesce", + Seq(UnresolvedAttribute("column_name_3"), Literal("null replacement value 3")), + isDistinct = false), + "column_name_3")()) + val renameProject = Project(renameProjectList, relation) + + val dropSourceColumn = DataFrameDropColumns( + Seq( + UnresolvedAttribute("column_name_1"), + UnresolvedAttribute("column_name_2"), + UnresolvedAttribute("column_name_3")), + renameProject) + + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +}