diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLookupITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLookupITSuite.scala new file mode 100644 index 000000000..2815f6031 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLookupITSuite.scala @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.{QueryTest, Row} + +class FlintSparkPPLLookupITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + private val lookupTable = "spark_catalog.default.flint_ppl_test_lookup" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + createOccupationTable(lookupTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create ppl simple query test") { + val frame = sql(s""" + | source = $testTable | where age > 20 | lookup flint_ppl_test_lookup name + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + + assert(results.length == 3) + + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA", 2023, 4, "Jake", "Engineer", "England", 100000, 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", "USA", 70000, 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val expectedPlan: LogicalPlan = + Project( + Seq(UnresolvedStar(None)), + Join( + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")), + JoinType.apply("left"), + Option.empty, + JoinHint.NONE + ) + //UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + ) + // Compare the two plans + assert(expectedPlan === logicalPlan) + } +} + + diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index b1c988b28..61370b679 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -65,6 +65,7 @@ NUM: 'NUM'; // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; +APPENDONLY: 'APPENDONLY'; DEDUP_SPLITVALUES: 'DEDUP_SPLITVALUES'; PARTITIONS: 'PARTITIONS'; ALLNUM: 'ALLNUM'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 4b4e64c1a..4169dfe0a 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -38,6 +38,7 @@ commands | renameCommand | statsCommand | dedupCommand + | lookupCommand | sortCommand | evalCommand | headCommand @@ -107,6 +108,18 @@ dedupCommand : DEDUP (number = integerLiteral)? fieldList (KEEPEMPTY EQUAL keepempty = booleanLiteral)? (CONSECUTIVE EQUAL consecutive = booleanLiteral)? ; +matchFieldWithOptAs + : orignalMatchField = fieldExpression (AS asMatchField = fieldExpression)? + ; + +copyFieldWithOptAs + : orignalCopyField = fieldExpression (AS asCopyField = fieldExpression)? + ; + +lookupCommand + : LOOKUP tableSource matchFieldWithOptAs (COMMA matchFieldWithOptAs)* (APPENDONLY EQUAL appendonly = booleanLiteral)? (copyFieldWithOptAs (COMMA copyFieldWithOptAs)*)* + ; + sortCommand : SORT sortbyClause ; @@ -848,6 +861,7 @@ keywordsCanBeId | RENAME | STATS | DEDUP + | LOOKUP | SORT | EVAL | HEAD diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index e3d0c6a2b..1614e374c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -43,6 +43,7 @@ import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; +import org.opensearch.sql.ast.tree.Lookup; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -208,6 +209,10 @@ public T visitDedupe(Dedupe node, C context) { return visitChildren(node, context); } + public T visitLookup(Lookup node, C context) { + return visitChildren(node, context); + } + public T visitHead(Head node, C context) { return visitChildren(node, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Lookup.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Lookup.java new file mode 100644 index 000000000..06b3370a9 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Lookup.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Map; + +import java.util.List; + +/** AST node represent Lookup operation. */ + +public class Lookup extends UnresolvedPlan { + private UnresolvedPlan child; + private final String tableName; + private final List matchFieldList; + private final List options; + private final List copyFieldList; + + public Lookup(UnresolvedPlan child, String tableName, List matchFieldList, List options, List copyFieldList) { + this.child = child; + this.tableName = tableName; + this.matchFieldList = matchFieldList; + this.options = options; + this.copyFieldList = copyFieldList; + } + + public Lookup(String tableName, List matchFieldList, List options, List copyFieldList) { + this.tableName = tableName; + this.matchFieldList = matchFieldList; + this.options = options; + this.copyFieldList = copyFieldList; + } + + @Override + public Lookup attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + public String getTableName() { + return tableName; + } + + public List getMatchFieldList() { + return matchFieldList; + } + + public List getOptions() { + return options; + } + + public List getCopyFieldList() { + return copyFieldList; + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitLookup(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 6d14db328..40609dd86 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -7,17 +7,23 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.expressions.EqualTo; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.catalyst.plans.JoinType; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.Join; +import org.apache.spark.sql.catalyst.plans.logical.JoinHint; import org.apache.spark.sql.catalyst.plans.logical.Limit; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.AllFields; @@ -32,6 +38,7 @@ import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.Map; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; @@ -49,6 +56,7 @@ import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.Lookup; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; import org.opensearch.sql.ast.tree.Relation; @@ -56,9 +64,12 @@ import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.SortUtils; +import org.sparkproject.guava.collect.Iterables; import scala.Option; import scala.collection.Seq; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -256,6 +267,93 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { throw new IllegalStateException("Not Supported operation : dedupe "); } + @Override + public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { + Node root = node.getChild().get(0); + + while(!root.getChild().isEmpty()) { + root = root.getChild().get(0); + } + + org.opensearch.sql.ast.tree.Relation source = (org.opensearch.sql.ast.tree.Relation) root; + + node.getChild().get(0).accept(this, context); + + //TODO: not sure how to implement appendonly + Boolean appendonly = (Boolean) node.getOptions().get(0).getValue().getValue(); + + LogicalPlan lookupRelation = new UnresolvedRelation(seq(of(node.getTableName())), CaseInsensitiveStringMap.empty(), false); + org.apache.spark.sql.catalyst.plans.logical.Project lookupProject; + + List lookupRelationFields = buildLookupTableFieldList(node, context); + if (! lookupRelationFields.isEmpty()) { + lookupProject = new org.apache.spark.sql.catalyst.plans.logical.Project(seq(lookupRelationFields), lookupRelation); + } else { + lookupProject = new org.apache.spark.sql.catalyst.plans.logical.Project(seq(of(new UnresolvedStar(Option.empty()))), lookupRelation); + } + + //TODO: use node.getCopyFieldList() to prefilter the right logical plan + //and return only the fields listed there. rename fields when requested + + Expression joinCondition = buildLookupTableJoinCondition(node.getMatchFieldList(), source.getTableQualifiedName().toString(), node.getTableName(), context); + + return context.apply(p -> new Join( + + p, //original query (left) + + lookupProject, //lookup query (right) + + JoinType.apply("left"), //https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-join.html + + Option.apply(joinCondition), //which fields to join + + JoinHint.NONE() //TODO: check, https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-hints.html#join-hints-types + )); + } + + private List buildLookupTableFieldList(Lookup node, CatalystPlanContext context) { + if (node.getCopyFieldList().isEmpty()) { + return Collections.emptyList(); + } else { + //todo should we also append fields used to match records - node.getMatchFieldList()? + List copyFields = node.getCopyFieldList().stream() + .map(copyField -> (NamedExpression) expressionAnalyzer.visitAlias(copyField, context)) + .collect(Collectors.toList()); + return copyFields; + } + } + + private org.opensearch.sql.ast.expression.Field prefixField(List prefixParts, UnresolvedExpression field) { + org.opensearch.sql.ast.expression.Field in = (org.opensearch.sql.ast.expression.Field) field; + org.opensearch.sql.ast.expression.QualifiedName inq = (org.opensearch.sql.ast.expression.QualifiedName) in.getField(); + Iterable finalParts = Iterables.concat(prefixParts, inq.getParts()); + return new org.opensearch.sql.ast.expression.Field(new org.opensearch.sql.ast.expression.QualifiedName(finalParts), in.getFieldArgs()); + } + + private Expression buildLookupTableJoinCondition(List fieldMap, String sourceTableName, String lookupTableName, CatalystPlanContext context) { + int size = fieldMap.size(); + + List allEqlExpressions = new ArrayList<>(size); + + for (Map map : fieldMap) { + + //todo do we need to run prefixField? match fields are anyway handled as qualifiedName? +// Expression origin = visitExpression(prefixField(of(sourceTableName.split("\\.")),map.getOrigin()), context); +// Expression target = visitExpression(prefixField(of(lookupTableName.split("\\.")),map.getTarget()), context); + Expression origin = visitExpression(map.getOrigin(), context); + Expression target = visitExpression(map.getTarget(), context); + + + //important + context.retainAllNamedParseExpressions(e -> e); + + Expression eql = new EqualTo(origin, target); + allEqlExpressions.add(eql); + } + + return allEqlExpressions.stream().reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); + } + /** * Expression Analyzer. */ diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index a810ea180..61be26287 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -32,6 +32,7 @@ import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.Lookup; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -43,7 +44,6 @@ import org.opensearch.sql.ppl.utils.ArgumentFactory; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -195,6 +195,47 @@ public UnresolvedPlan visitDedupCommand(OpenSearchPPLParser.DedupCommandContext return new Dedupe(ArgumentFactory.getArgumentList(ctx), getFieldList(ctx.fieldList())); } + /** Lookup command */ + @Override + public UnresolvedPlan visitLookupCommand(OpenSearchPPLParser.LookupCommandContext ctx) { + ArgumentFactory.getArgumentList(ctx); + ctx.tableSource(); + ctx.copyFieldWithOptAs(); + ctx.matchFieldWithOptAs(); + return new Lookup( + ctx.tableSource().tableQualifiedName().getText(), + ctx.matchFieldWithOptAs().stream() + .map( + ct -> + new Map( + evaluateFieldExpressionContext(ct.orignalMatchField), + evaluateFieldExpressionContext(ct.asMatchField, ct.orignalMatchField))) + .collect(Collectors.toList()), + ArgumentFactory.getArgumentList(ctx), + ctx.copyFieldWithOptAs().stream() + .map( + ct -> + new Alias( + ct.asCopyField == null? ct.orignalCopyField.getText() : ct.asCopyField.getText(), + evaluateFieldExpressionContext(ct.orignalCopyField))) + .collect(Collectors.toList())); + } + + private UnresolvedExpression evaluateFieldExpressionContext( + OpenSearchPPLParser.FieldExpressionContext f) { + return internalVisitExpression(f); + } + + private UnresolvedExpression evaluateFieldExpressionContext( + OpenSearchPPLParser.FieldExpressionContext f0, + OpenSearchPPLParser.FieldExpressionContext f1) { + if (f0 == null) { + return internalVisitExpression(f1); + } else { + return internalVisitExpression(f0); + } + } + /** Head command visitor. */ @Override public UnresolvedPlan visitHeadCommand(OpenSearchPPLParser.HeadCommandContext ctx) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index 43f696bcd..6c106f5fb 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -72,6 +72,13 @@ public static List getArgumentList(OpenSearchPPLParser.DedupCommandCon : new Argument("consecutive", new Literal(false, DataType.BOOLEAN))); } + public static List getArgumentList(OpenSearchPPLParser.LookupCommandContext ctx) { + return Arrays.asList( + ctx.appendonly != null + ? new Argument("appendonly", getArgumentValue(ctx.appendonly)) + : new Argument("appendonly", new Literal(false, DataType.BOOLEAN))); + } + /** * Get list of {@link Argument}. * diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLookupTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLookupTranslatorTestSuite.scala new file mode 100644 index 000000000..5959d2193 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLookupTranslatorTestSuite.scala @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, GreaterThanOrEqual, Literal, Multiply, SortOrder, TimeWindow} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +class PPLLogicalPlanLookupTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test lookup ") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source = table | lookup a b,c as d, e as f,g as b, j appendonly=true q,w as z ", false), context) + val star = Seq(UnresolvedStar(None)) + + val priceField = UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + val aggregateExpressions = Seq( + Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation) + val expectedPlan = Project(star, aggregatePlan) + + //scalastyle:off + println("### plan ###\n"+compareByString(logPlan)+"\n#########"); + + assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + } +}