From 6c3744ec81931567d50f47075cd45b3e8906997c Mon Sep 17 00:00:00 2001 From: Rupal Mahajan Date: Fri, 23 Jun 2023 10:11:51 -0700 Subject: [PATCH 1/4] Add Spark application (#1723) * Initial spark application draft Signed-off-by: Rupal Mahajan * Remove temp table Signed-off-by: Rupal Mahajan * Add license header Signed-off-by: Rupal Mahajan * Add scalastyle-config and update readme Signed-off-by: Rupal Mahajan * Fix datatype for result and schema Signed-off-by: Rupal Mahajan * Add test Signed-off-by: Rupal Mahajan * Simplify code using toJSON.collect.toList Signed-off-by: Rupal Mahajan * Add example in readme Signed-off-by: Rupal Mahajan * Fix triple quotes issue Signed-off-by: Rupal Mahajan * Update method name and description Signed-off-by: Rupal Mahajan --------- Signed-off-by: Rupal Mahajan --- spark-sql-application/.gitignore | 14 +++ spark-sql-application/README.md | 107 ++++++++++++++++++ spark-sql-application/build.sbt | 28 +++++ .../project/build.properties | 1 + spark-sql-application/project/plugins.sbt | 6 + spark-sql-application/scalastyle-config.xml | 106 +++++++++++++++++ .../scala/org/opensearch/sql/SQLJob.scala | 98 ++++++++++++++++ .../scala/org/opensearch/sql/SQLJobTest.scala | 54 +++++++++ 8 files changed, 414 insertions(+) create mode 100644 spark-sql-application/.gitignore create mode 100644 spark-sql-application/README.md create mode 100644 spark-sql-application/build.sbt create mode 100644 spark-sql-application/project/build.properties create mode 100644 spark-sql-application/project/plugins.sbt create mode 100644 spark-sql-application/scalastyle-config.xml create mode 100644 spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala create mode 100644 spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala diff --git a/spark-sql-application/.gitignore b/spark-sql-application/.gitignore new file mode 100644 index 0000000000..ec13a702be --- /dev/null +++ b/spark-sql-application/.gitignore @@ -0,0 +1,14 @@ +# Compiled output +target/ +project/target/ + +# sbt-specific files +.sbtserver +.sbt/ +.bsp/ + +# Miscellaneous +.DS_Store +*.class +*.log +*.zip \ No newline at end of file diff --git a/spark-sql-application/README.md b/spark-sql-application/README.md new file mode 100644 index 0000000000..b0505282ab --- /dev/null +++ b/spark-sql-application/README.md @@ -0,0 +1,107 @@ +# Spark SQL Application + +This application execute sql query and store the result in OpenSearch index in following format +``` +"stepId":"", +"schema": "json blob", +"result": "json blob" +``` + +## Prerequisites + ++ Spark 3.3.1 ++ Scala 2.12.15 ++ flint-spark-integration + +## Usage + +To use this application, you can run Spark with Flint extension: + +``` +./bin/spark-submit \ + --class org.opensearch.sql.SQLJob \ + --jars \ + sql-job.jar \ + \ + \ + \ + \ + \ + \ + \ +``` + +## Result Specifications + +Following example shows how the result is written to OpenSearch index after query execution. + +Let's assume sql query result is +``` ++------+------+ +|Letter|Number| ++------+------+ +|A |1 | +|B |2 | +|C |3 | ++------+------+ +``` +OpenSearch index document will look like +```json +{ + "_index" : ".query_execution_result", + "_id" : "A2WOsYgBMUoqCqlDJHrn", + "_score" : 1.0, + "_source" : { + "result" : [ + "{'Letter':'A','Number':1}", + "{'Letter':'B','Number':2}", + "{'Letter':'C','Number':3}" + ], + "schema" : [ + "{'column_name':'Letter','data_type':'string'}", + "{'column_name':'Number','data_type':'integer'}" + ], + "stepId" : "s-JZSB1139WIVU" + } +} +``` + +## Build + +To build and run this application with Spark, you can run: + +``` +sbt clean publishLocal +``` + +## Test + +To run tests, you can use: + +``` +sbt test +``` + +## Scalastyle + +To check code with scalastyle, you can run: + +``` +sbt scalastyle +``` + +## Code of Conduct + +This project has adopted an [Open Source Code of Conduct](../CODE_OF_CONDUCT.md). + +## Security + +If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. + +## License + +See the [LICENSE](../LICENSE.txt) file for our project's licensing. We will ask you to confirm the licensing of your contribution. + +## Copyright + +Copyright OpenSearch Contributors. See [NOTICE](../NOTICE) for details. \ No newline at end of file diff --git a/spark-sql-application/build.sbt b/spark-sql-application/build.sbt new file mode 100644 index 0000000000..79d69a30d1 --- /dev/null +++ b/spark-sql-application/build.sbt @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +name := "sql-job" + +version := "1.0" + +scalaVersion := "2.12.15" + +val sparkVersion = "3.3.2" + +mainClass := Some("org.opensearch.sql.SQLJob") + +artifactName := { (sv: ScalaVersion, module: ModuleID, artifact: Artifact) => + "sql-job.jar" +} + +resolvers ++= Seq( + ("apache-snapshots" at "http://repository.apache.org/snapshots/").withAllowInsecureProtocol(true) +) + +libraryDependencies ++= Seq( + "org.apache.spark" %% "spark-core" % sparkVersion % "provided", + "org.apache.spark" %% "spark-sql" % sparkVersion % "provided", + "org.scalatest" %% "scalatest" % "3.2.15" % Test +) diff --git a/spark-sql-application/project/build.properties b/spark-sql-application/project/build.properties new file mode 100644 index 0000000000..46e43a97ed --- /dev/null +++ b/spark-sql-application/project/build.properties @@ -0,0 +1 @@ +sbt.version=1.8.2 diff --git a/spark-sql-application/project/plugins.sbt b/spark-sql-application/project/plugins.sbt new file mode 100644 index 0000000000..4d14ba6c10 --- /dev/null +++ b/spark-sql-application/project/plugins.sbt @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") \ No newline at end of file diff --git a/spark-sql-application/scalastyle-config.xml b/spark-sql-application/scalastyle-config.xml new file mode 100644 index 0000000000..37b1978cd7 --- /dev/null +++ b/spark-sql-application/scalastyle-config.xml @@ -0,0 +1,106 @@ + + Scalastyle standard configuration + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala new file mode 100644 index 0000000000..f2dd0c869c --- /dev/null +++ b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql + +import org.apache.spark.sql.{DataFrame, SparkSession, Row} +import org.apache.spark.sql.types._ + +/** + * Spark SQL Application entrypoint + * + * @param args(0) + * sql query + * @param args(1) + * opensearch index name + * @param args(2-6) + * opensearch connection values required for flint-integration jar. host, port, scheme, auth, region respectively. + * @return + * write sql query result to given opensearch index + */ +object SQLJob { + def main(args: Array[String]) { + // Get the SQL query and Opensearch Config from the command line arguments + val query = args(0) + val index = args(1) + val host = args(2) + val port = args(3) + val scheme = args(4) + val auth = args(5) + val region = args(6) + + // Create a SparkSession + val spark = SparkSession.builder().appName("SQLJob").getOrCreate() + + try { + // Execute SQL query + val result: DataFrame = spark.sql(query) + + // Get Data + val data = getFormattedData(result, spark) + + // Write data to OpenSearch index + val aos = Map( + "host" -> host, + "port" -> port, + "scheme" -> scheme, + "auth" -> auth, + "region" -> region) + + data.write + .format("flint") + .options(aos) + .mode("append") + .save(index) + + } finally { + // Stop SparkSession + spark.stop() + } + } + + /** + * Create a new formatted dataframe with json result, json schema and EMR_STEP_ID. + * + * @param result + * sql query result dataframe + * @param spark + * spark session + * @return + * dataframe with result, schema and emr step id + */ + def getFormattedData(result: DataFrame, spark: SparkSession): DataFrame = { + // Create the schema dataframe + val schemaRows = result.schema.fields.map { field => + Row(field.name, field.dataType.typeName) + } + val resultSchema = spark.createDataFrame(spark.sparkContext.parallelize(schemaRows), StructType(Seq( + StructField("column_name", StringType, nullable = false), + StructField("data_type", StringType, nullable = false)))) + + // Define the data schema + val schema = StructType(Seq( + StructField("result", ArrayType(StringType, containsNull = true), nullable = true), + StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), + StructField("stepId", StringType, nullable = true))) + + // Create the data rows + val rows = Seq(( + result.toJSON.collect.toList.map(_.replaceAll("\"", "'")), + resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")), + sys.env.getOrElse("EMR_STEP_ID", ""))) + + // Create the DataFrame for data + spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*) + } +} diff --git a/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala b/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala new file mode 100644 index 0000000000..2cdb06d6ca --- /dev/null +++ b/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql + +import org.scalatest.funsuite.AnyFunSuite +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} + + +class SQLJobTest extends AnyFunSuite{ + + val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() + + // Define input dataframe + val inputSchema = StructType(Seq( + StructField("Letter", StringType, nullable = false), + StructField("Number", IntegerType, nullable = false) + )) + val inputRows = Seq( + Row("A", 1), + Row("B", 2), + Row("C", 3) + ) + val input: DataFrame = spark.createDataFrame(spark.sparkContext.parallelize(inputRows), inputSchema) + + test("Test getFormattedData method") { + // Define expected dataframe + val expectedSchema = StructType(Seq( + StructField("result", ArrayType(StringType, containsNull = true), nullable = true), + StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), + StructField("stepId", StringType, nullable = true) + )) + val expectedRows = Seq( + Row( + Array("{'Letter':'A','Number':1}","{'Letter':'B','Number':2}", "{'Letter':'C','Number':3}"), + Array("{'column_name':'Letter','data_type':'string'}", "{'column_name':'Number','data_type':'integer'}"), + "" + ) + ) + val expected: DataFrame = spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) + + // Compare the result + val result = SQLJob.getFormattedData(input, spark) + assertEqualDataframe(expected, result) + } + + def assertEqualDataframe(expected: DataFrame, result: DataFrame): Unit ={ + assert(expected.schema === result.schema) + assert(expected.collect() === result.collect()) + } +} From dc4e468a4849d6f2fd28b15d1b926d8cce808a9f Mon Sep 17 00:00:00 2001 From: Max Ksyunz Date: Mon, 26 Jun 2023 11:32:47 -0700 Subject: [PATCH 2/4] Update sqlite-jdbc to 3.41.2.2 to address CVE-2023-32697 (#1667) * Update sqlite-jdbc to 3.41.2.2 to address CVE-2023-32697 Signed-off-by: MaxKsyunz * Don't check column names on H2 results for correctness tests as described in https://github.com/opensearch-project/sql/pull/1667#issuecomment-1603659136. Signed-off-by: Yury-Fridlyand * Address PR review comment. Signed-off-by: Yury-Fridlyand --------- Signed-off-by: MaxKsyunz Signed-off-by: Yury-Fridlyand Co-authored-by: Yury-Fridlyand --- integ-test/build.gradle | 2 +- .../runner/resultset/DBResult.java | 24 +++++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 2c1a066481..fc97fff9a4 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -101,7 +101,7 @@ dependencies { testRuntimeOnly('org.junit.jupiter:junit-jupiter-engine:5.6.2') testImplementation group: 'com.h2database', name: 'h2', version: '2.1.214' - testImplementation group: 'org.xerial', name: 'sqlite-jdbc', version: '3.32.3.3' + testImplementation group: 'org.xerial', name: 'sqlite-jdbc', version: '3.41.2.2' testImplementation group: 'com.google.code.gson', name: 'gson', version: '2.8.9' // Needed for BWC tests diff --git a/integ-test/src/test/java/org/opensearch/sql/correctness/runner/resultset/DBResult.java b/integ-test/src/test/java/org/opensearch/sql/correctness/runner/resultset/DBResult.java index 0899a6e2c4..eb522b008d 100644 --- a/integ-test/src/test/java/org/opensearch/sql/correctness/runner/resultset/DBResult.java +++ b/integ-test/src/test/java/org/opensearch/sql/correctness/runner/resultset/DBResult.java @@ -12,8 +12,9 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.Set; -import lombok.EqualsAndHashCode; +import java.util.stream.Collectors; import lombok.Getter; import lombok.ToString; import org.json.JSONPropertyName; @@ -24,7 +25,6 @@ * query with SELECT columns or just *, order of column and row may matter or not. So the internal data structure of this * class is passed in from outside either list or set, hash map or linked hash map etc. */ -@EqualsAndHashCode(exclude = "databaseName") @ToString public class DBResult { @@ -191,4 +191,24 @@ private static > List sort(Collection collection) return list; } + public boolean equals(final Object o) { + if (o == this) { + return true; + } + if (!(o instanceof DBResult)) { + return false; + } + final DBResult other = (DBResult) o; + // H2 calculates the value before setting column name + // for example, for query "select 1 + 1" it returns a column named "2" instead of "1 + 1" + boolean skipColumnNameCheck = databaseName.equalsIgnoreCase("h2") || other.databaseName.equalsIgnoreCase("h2"); + if (!skipColumnNameCheck && !schema.equals(other.schema)) { + return false; + } + if (skipColumnNameCheck && !schema.stream().map(Type::getType).collect(Collectors.toList()) + .equals(other.schema.stream().map(Type::getType).collect(Collectors.toList()))) { + return false; + } + return dataRows.equals(other.dataRows); + } } From 9fbcf11258964616d2f5056420cc83afedd71613 Mon Sep 17 00:00:00 2001 From: Forest Vey Date: Tue, 27 Jun 2023 11:36:20 -0700 Subject: [PATCH 3/4] Support Array and ExprValue Parsing With Inner Hits (#1737) * Add support for Array and ExprValue Parsing With Inner Hits Signed-off-by: forestmvey * Adding schema validation for IT test, and another UT for nested arrays. Signed-off-by: forestmvey * Making handleAggregationResponse a private function. Signed-off-by: forestmvey --------- Signed-off-by: forestmvey --- .../sql/legacy/ObjectFieldSelectIT.java | 1 - .../java/org/opensearch/sql/sql/NestedIT.java | 35 +++ .../nested_simple_index_mapping.json | 8 + .../src/test/resources/nested_simple.json | 10 +- .../sql/opensearch/data/utils/Content.java | 25 ++ .../opensearch/data/utils/ObjectContent.java | 26 ++ .../data/utils/OpenSearchJsonContent.java | 30 +- .../value/OpenSearchExprValueFactory.java | 129 ++++++-- .../response/OpenSearchResponse.java | 149 ++++++---- .../storage/script/core/ExpressionScript.java | 6 +- .../client/OpenSearchNodeClientTest.java | 2 +- .../client/OpenSearchRestClientTest.java | 3 +- .../value/OpenSearchExprValueFactoryTest.java | 280 +++++++++++++++++- .../response/OpenSearchResponseTest.java | 21 +- 14 files changed, 613 insertions(+), 112 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/ObjectFieldSelectIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/ObjectFieldSelectIT.java index b1db21a2ff..ce781123d6 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/ObjectFieldSelectIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/ObjectFieldSelectIT.java @@ -14,7 +14,6 @@ import org.json.JSONArray; import org.json.JSONObject; -import org.junit.Assume; import org.junit.Test; import org.opensearch.sql.legacy.utils.StringUtils; diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java index 6cb7b7580b..80886fe779 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java @@ -6,6 +6,7 @@ package org.opensearch.sql.sql; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_MULTI_NESTED_TYPE; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_NESTED_SIMPLE; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_NESTED_TYPE; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_NESTED_WITH_NULLS; @@ -31,6 +32,7 @@ public void init() throws IOException { loadIndex(Index.NESTED_WITHOUT_ARRAYS); loadIndex(Index.EMPLOYEE_NESTED); loadIndex(Index.NESTED_WITH_NULLS); + loadIndex(Index.NESTED_SIMPLE); } @Test @@ -366,4 +368,37 @@ public void test_nested_in_where_as_predicate_expression_with_relevance_query() assertEquals(1, result.getInt("total")); verifyDataRows(result, rows(10, "a")); } + + @Test + public void nested_function_with_date_types_as_object_arrays_within_arrays_test() { + String query = "SELECT nested(address.moveInDate) FROM " + TEST_INDEX_NESTED_SIMPLE; + JSONObject result = executeJdbcRequest(query); + + assertEquals(11, result.getInt("total")); + verifySchema(result, + schema("nested(address.moveInDate)", null, "object") + ); + verifyDataRows(result, + rows(new JSONObject(Map.of("dateAndTime","1984-04-12 09:07:42"))), + rows(new JSONArray( + List.of( + Map.of("dateAndTime", "2023-05-03 08:07:42"), + Map.of("dateAndTime", "2001-11-11 04:07:44")) + ) + ), + rows(new JSONObject(Map.of("dateAndTime", "1966-03-19 03:04:55"))), + rows(new JSONObject(Map.of("dateAndTime", "2011-06-01 01:01:42"))), + rows(new JSONObject(Map.of("dateAndTime", "1901-08-11 04:03:33"))), + rows(new JSONObject(Map.of("dateAndTime", "2023-05-03 08:07:42"))), + rows(new JSONObject(Map.of("dateAndTime", "2001-11-11 04:07:44"))), + rows(new JSONObject(Map.of("dateAndTime", "1977-07-13 09:04:41"))), + rows(new JSONObject(Map.of("dateAndTime", "1933-12-12 05:05:45"))), + rows(new JSONObject(Map.of("dateAndTime", "1909-06-17 01:04:21"))), + rows(new JSONArray( + List.of( + Map.of("dateAndTime", "2001-11-11 04:07:44")) + ) + ) + ); + } } diff --git a/integ-test/src/test/resources/indexDefinitions/nested_simple_index_mapping.json b/integ-test/src/test/resources/indexDefinitions/nested_simple_index_mapping.json index 2ebc8a50de..7e521cdd44 100644 --- a/integ-test/src/test/resources/indexDefinitions/nested_simple_index_mapping.json +++ b/integ-test/src/test/resources/indexDefinitions/nested_simple_index_mapping.json @@ -21,6 +21,14 @@ "ignore_above": 256 } } + }, + "moveInDate" : { + "properties": { + "dateAndTime": { + "type": "date", + "format": "basic_date_time" + } + } } } }, diff --git a/integ-test/src/test/resources/nested_simple.json b/integ-test/src/test/resources/nested_simple.json index d42cc667df..f3cb1a5ebe 100644 --- a/integ-test/src/test/resources/nested_simple.json +++ b/integ-test/src/test/resources/nested_simple.json @@ -1,10 +1,10 @@ {"index":{"_id":"1"}} -{"name":"abbas","age":24,"address":[{"city":"New york city","state":"NY"},{"city":"bellevue","state":"WA"},{"city":"seattle","state":"WA"},{"city":"chicago","state":"IL"}]} +{"name":"abbas","age":24,"address":[{"city":"New york city","state":"NY","moveInDate":{"dateAndTime":"19840412T090742.000Z"}},{"city":"bellevue","state":"WA","moveInDate":[{"dateAndTime":"20230503T080742.000Z"},{"dateAndTime":"20011111T040744.000Z"}]},{"city":"seattle","state":"WA","moveInDate":{"dateAndTime":"19660319T030455.000Z"}},{"city":"chicago","state":"IL","moveInDate":{"dateAndTime":"20110601T010142.000Z"}}]} {"index":{"_id":"2"}} -{"name":"chen","age":32,"address":[{"city":"Miami","state":"Florida"},{"city":"los angeles","state":"CA"}]} +{"name":"chen","age":32,"address":[{"city":"Miami","state":"Florida","moveInDate":{"dateAndTime":"19010811T040333.000Z"}},{"city":"los angeles","state":"CA","moveInDate":{"dateAndTime":"20230503T080742.000Z"}}]} {"index":{"_id":"3"}} -{"name":"peng","age":26,"address":[{"city":"san diego","state":"CA"},{"city":"austin","state":"TX"}]} +{"name":"peng","age":26,"address":[{"city":"san diego","state":"CA","moveInDate":{"dateAndTime":"20011111T040744.000Z"}},{"city":"austin","state":"TX","moveInDate":{"dateAndTime":"19770713T090441.000Z"}}]} {"index":{"_id":"4"}} -{"name":"andy","age":19,"id":4,"address":[{"city":"houston","state":"TX"}]} +{"name":"andy","age":19,"id":4,"address":[{"city":"houston","state":"TX","moveInDate":{"dateAndTime":"19331212T050545.000Z"}}]} {"index":{"_id":"5"}} -{"name":"david","age":25,"address":[{"city":"raleigh","state":"NC"},{"city":"charlotte","state":"SC"}]} +{"name":"david","age":25,"address":[{"city":"raleigh","state":"NC","moveInDate":{"dateAndTime":"19090617T010421.000Z"}},{"city":"charlotte","state":"SC","moveInDate":[{"dateAndTime":"20011111T040744.000Z"}]}]} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/Content.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/Content.java index 94cd9d93ca..992689a186 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/Content.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/Content.java @@ -29,11 +29,36 @@ public interface Content { */ boolean isNumber(); + /** + * Is float value. + */ + boolean isFloat(); + + /** + * Is double value. + */ + boolean isDouble(); + + /** + * Is long value. + */ + boolean isLong(); + + /** + * Is boolean value. + */ + boolean isBoolean(); + /** * Is string value. */ boolean isString(); + /** + * Is array value. + */ + boolean isArray(); + /** * Get integer value. */ diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/ObjectContent.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/ObjectContent.java index 15e2e959a4..e8875d19ba 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/ObjectContent.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/ObjectContent.java @@ -6,6 +6,7 @@ package org.opensearch.sql.opensearch.data.utils; +import com.fasterxml.jackson.databind.node.ArrayNode; import java.util.AbstractMap; import java.util.Iterator; import java.util.List; @@ -103,6 +104,31 @@ public boolean isNumber() { return value instanceof Number; } + @Override + public boolean isFloat() { + return value instanceof Float; + } + + @Override + public boolean isDouble() { + return value instanceof Double; + } + + @Override + public boolean isLong() { + return value instanceof Long; + } + + @Override + public boolean isBoolean() { + return value instanceof Boolean; + } + + @Override + public boolean isArray() { + return value instanceof ArrayNode; + } + @Override public boolean isString() { return value instanceof String; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java index 13a1fbf6a4..61da7c3b74 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java @@ -88,11 +88,36 @@ public boolean isNumber() { return value().isNumber(); } + @Override + public boolean isLong() { + return value().isLong(); + } + + @Override + public boolean isFloat() { + return value().isFloat(); + } + + @Override + public boolean isDouble() { + return value().isDouble(); + } + @Override public boolean isString() { return value().isTextual(); } + @Override + public boolean isBoolean() { + return value().isBoolean(); + } + + @Override + public boolean isArray() { + return value().isArray(); + } + @Override public Object objectValue() { return value(); @@ -126,11 +151,10 @@ public Pair geoValue() { } /** - * Return the first element if is OpenSearch Array. - * https://www.elastic.co/guide/en/elasticsearch/reference/current/array.html. + * Getter for value. If value is array the whole array is returned. */ private JsonNode value() { - return value.isArray() ? value.get(0) : value; + return value; } /** diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java index 1ff5af7304..abad197bd4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java @@ -6,8 +6,14 @@ package org.opensearch.sql.opensearch.data.value; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; +import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; import static org.opensearch.sql.data.type.ExprCoreType.DATE; import static org.opensearch.sql.data.type.ExprCoreType.DATETIME; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.FLOAT; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.LONG; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; import static org.opensearch.sql.data.type.ExprCoreType.TIME; @@ -18,10 +24,10 @@ import static org.opensearch.sql.utils.DateTimeUtils.UTC_ZONE_ID; import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterators; import java.time.Instant; import java.time.LocalDate; import java.time.LocalTime; @@ -55,8 +61,11 @@ import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.opensearch.data.type.OpenSearchBinaryType; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.type.OpenSearchDateType; +import org.opensearch.sql.opensearch.data.type.OpenSearchGeoPointType; +import org.opensearch.sql.opensearch.data.type.OpenSearchIpType; import org.opensearch.sql.opensearch.data.utils.Content; import org.opensearch.sql.opensearch.data.utils.ObjectContent; import org.opensearch.sql.opensearch.data.utils.OpenSearchJsonContent; @@ -149,10 +158,10 @@ public OpenSearchExprValueFactory(Map typeMapping) { * { "employ.id", "INTEGER" } * { "employ.state", "STRING" } */ - public ExprValue construct(String jsonString) { + public ExprValue construct(String jsonString, boolean supportArrays) { try { return parse(new OpenSearchJsonContent(OBJECT_MAPPER.readTree(jsonString)), TOP_PATH, - Optional.of(STRUCT)); + Optional.of(STRUCT), supportArrays); } catch (JsonProcessingException e) { throw new IllegalStateException(String.format("invalid json: %s.", jsonString), e); } @@ -167,21 +176,27 @@ public ExprValue construct(String jsonString) { * @param value value object * @return ExprValue */ - public ExprValue construct(String field, Object value) { - return parse(new ObjectContent(value), field, type(field)); + public ExprValue construct(String field, Object value, boolean supportArrays) { + return parse(new ObjectContent(value), field, type(field), supportArrays); } - private ExprValue parse(Content content, String field, Optional fieldType) { + private ExprValue parse( + Content content, + String field, + Optional fieldType, + boolean supportArrays + ) { if (content.isNull() || !fieldType.isPresent()) { return ExprNullValue.of(); } ExprType type = fieldType.get(); - if (type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.Object)) - || type == STRUCT) { - return parseStruct(content, field); - } else if (type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested))) { - return parseArray(content, field); + if (type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested)) + || content.isArray()) { + return parseArray(content, field, type, supportArrays); + } else if (type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.Object)) + || type == STRUCT) { + return parseStruct(content, field, supportArrays); } else { if (typeActionMap.containsKey(type)) { return typeActionMap.get(type).apply(content, type); @@ -338,39 +353,97 @@ private ExprValue createOpenSearchDateType(Content value, ExprType type) { return new ExprTimestampValue((Instant) value.objectValue()); } - private ExprValue parseStruct(Content content, String prefix) { + /** + * Parse struct content. + * @param content Content to parse. + * @param prefix Prefix for Level of object depth to parse. + * @param supportArrays Parsing the whole array if array is type nested. + * @return Value parsed from content. + */ + private ExprValue parseStruct(Content content, String prefix, boolean supportArrays) { LinkedHashMap result = new LinkedHashMap<>(); content.map().forEachRemaining(entry -> result.put(entry.getKey(), parse(entry.getValue(), makeField(prefix, entry.getKey()), - type(makeField(prefix, entry.getKey()))))); + type(makeField(prefix, entry.getKey())), supportArrays))); return new ExprTupleValue(result); } /** - * Todo. ARRAY is not completely supported now. In OpenSearch, there is no dedicated array type. - * docs - * The similar data type is nested, but it can only allow a list of objects. + * Parse array content. Can also parse nested which isn't necessarily an array. + * @param content Content to parse. + * @param prefix Prefix for Level of object depth to parse. + * @param type Type of content parsing. + * @param supportArrays Parsing the whole array if array is type nested. + * @return Value parsed from content. */ - private ExprValue parseArray(Content content, String prefix) { + private ExprValue parseArray( + Content content, + String prefix, + ExprType type, + boolean supportArrays + ) { List result = new ArrayList<>(); - // ExprCoreType.ARRAY does not indicate inner elements type. - if (Iterators.size(content.array()) == 1 && content.objectValue() instanceof JsonNode) { - result.add(parse(content, prefix, Optional.of(STRUCT))); + + // ARRAY is mapped to nested but can take the json structure of an Object. + if (content.objectValue() instanceof ObjectNode) { + result.add(parseStruct(content, prefix, supportArrays)); + // non-object type arrays are only supported when parsing inner_hits of OS response. + } else if ( + !(type instanceof OpenSearchDataType + && ((OpenSearchDataType) type).getExprType().equals(ARRAY)) + && !supportArrays) { + return parseInnerArrayValue(content.array().next(), prefix, type, supportArrays); } else { content.array().forEachRemaining(v -> { - // ExprCoreType.ARRAY does not indicate inner elements type. OpenSearch nested will be an - // array of structs, otherwise parseArray currently only supports array of strings. - if (v.isString()) { - result.add(parse(v, prefix, Optional.of(OpenSearchDataType.of(STRING)))); - } else { - result.add(parse(v, prefix, Optional.of(STRUCT))); - } + result.add(parseInnerArrayValue(v, prefix, type, supportArrays)); }); } return new ExprCollectionValue(result); } + /** + * Parse inner array value. Can be object type and recurse continues. + * @param content Array index being parsed. + * @param prefix Prefix for value. + * @param type Type of inner array value. + * @param supportArrays Parsing the whole array if array is type nested. + * @return Inner array value. + */ + private ExprValue parseInnerArrayValue( + Content content, + String prefix, + ExprType type, + boolean supportArrays + ) { + if (type instanceof OpenSearchIpType + || type instanceof OpenSearchBinaryType + || type instanceof OpenSearchDateType + || type instanceof OpenSearchGeoPointType) { + return parse(content, prefix, Optional.of(type), supportArrays); + } else if (content.isString()) { + return parse(content, prefix, Optional.of(OpenSearchDataType.of(STRING)), supportArrays); + } else if (content.isLong()) { + return parse(content, prefix, Optional.of(OpenSearchDataType.of(LONG)), supportArrays); + } else if (content.isFloat()) { + return parse(content, prefix, Optional.of(OpenSearchDataType.of(FLOAT)), supportArrays); + } else if (content.isDouble()) { + return parse(content, prefix, Optional.of(OpenSearchDataType.of(DOUBLE)), supportArrays); + } else if (content.isNumber()) { + return parse(content, prefix, Optional.of(OpenSearchDataType.of(INTEGER)), supportArrays); + } else if (content.isBoolean()) { + return parse(content, prefix, Optional.of(OpenSearchDataType.of(BOOLEAN)), supportArrays); + } else { + return parse(content, prefix, Optional.of(STRUCT), supportArrays); + } + } + + /** + * Make complete path string for field. + * @param path Path of field. + * @param field Field to append to path. + * @return Field appended to path level. + */ private String makeField(String path, String field) { return path.equalsIgnoreCase(TOP_PATH) ? field : String.join(".", path, field); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java index 733fad6203..973624d19a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java @@ -15,15 +15,15 @@ import com.google.common.collect.ImmutableMap; import java.util.Arrays; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; import lombok.EqualsAndHashCode; import lombok.ToString; import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.text.Text; +import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.Aggregations; import org.opensearch.sql.data.model.ExprFloatValue; @@ -107,61 +107,108 @@ public boolean isAggregationResponse() { */ public Iterator iterator() { if (isAggregationResponse()) { - return exprValueFactory.getParser().parse(aggregations).stream().map(entry -> { - ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); - for (Map.Entry value : entry.entrySet()) { - builder.put(value.getKey(), exprValueFactory.construct(value.getKey(), value.getValue())); - } - return (ExprValue) ExprTupleValue.fromExprValueMap(builder.build()); - }).iterator(); + return handleAggregationResponse(); } else { - List metaDataFieldSet = includes.stream() - .filter(include -> METADATAFIELD_TYPE_MAP.containsKey(include)) - .collect(Collectors.toList()); - ExprFloatValue maxScore = Float.isNaN(hits.getMaxScore()) - ? null : new ExprFloatValue(hits.getMaxScore()); return Arrays.stream(hits.getHits()) .map(hit -> { - String source = hit.getSourceAsString(); - ExprValue docData = exprValueFactory.construct(source); - ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); - if (hit.getInnerHits() == null || hit.getInnerHits().isEmpty()) { - builder.putAll(docData.tupleValue()); - } else { - Map rowSource = hit.getSourceAsMap(); - builder.putAll(ExprValueUtils.tupleValue(rowSource).tupleValue()); - } - - metaDataFieldSet.forEach(metaDataField -> { - if (metaDataField.equals(METADATA_FIELD_INDEX)) { - builder.put(METADATA_FIELD_INDEX, new ExprStringValue(hit.getIndex())); - } else if (metaDataField.equals(METADATA_FIELD_ID)) { - builder.put(METADATA_FIELD_ID, new ExprStringValue(hit.getId())); - } else if (metaDataField.equals(METADATA_FIELD_SCORE)) { - if (!Float.isNaN(hit.getScore())) { - builder.put(METADATA_FIELD_SCORE, new ExprFloatValue(hit.getScore())); - } - } else if (metaDataField.equals(METADATA_FIELD_MAXSCORE)) { - if (maxScore != null) { - builder.put(METADATA_FIELD_MAXSCORE, maxScore); - } - } else { // if (metaDataField.equals(METADATA_FIELD_SORT)) { - builder.put(METADATA_FIELD_SORT, new ExprLongValue(hit.getSeqNo())); - } - }); - - if (!hit.getHighlightFields().isEmpty()) { - var hlBuilder = ImmutableMap.builder(); - for (var es : hit.getHighlightFields().entrySet()) { - hlBuilder.put(es.getKey(), ExprValueUtils.collectionValue( - Arrays.stream(es.getValue().fragments()).map( - t -> (t.toString())).collect(Collectors.toList()))); - } - builder.put("_highlight", ExprTupleValue.fromExprValueMap(hlBuilder.build())); - } + addParsedHitsToBuilder(builder, hit); + addMetaDataFieldsToBuilder(builder, hit); + addHighlightsToBuilder(builder, hit); return (ExprValue) ExprTupleValue.fromExprValueMap(builder.build()); }).iterator(); } } + + /** + * Parse response for all hits to add to builder. Inner_hits supports arrays of objects + * with nested type. + * @param builder builder to build values from response. + * @param hit Search hit from response. + */ + private void addParsedHitsToBuilder( + ImmutableMap.Builder builder, + SearchHit hit + ) { + builder.putAll( + exprValueFactory.construct( + hit.getSourceAsString(), + !(hit.getInnerHits() == null || hit.getInnerHits().isEmpty()) + ).tupleValue()); + } + + /** + * If highlight fields are present in response add the fields to the builder. + * @param builder builder to build values from response. + * @param hit Search hit from response. + */ + private void addHighlightsToBuilder( + ImmutableMap.Builder builder, + SearchHit hit + ) { + if (!hit.getHighlightFields().isEmpty()) { + var hlBuilder = ImmutableMap.builder(); + for (var es : hit.getHighlightFields().entrySet()) { + hlBuilder.put(es.getKey(), ExprValueUtils.collectionValue( + Arrays.stream(es.getValue().fragments()).map( + Text::toString).collect(Collectors.toList()))); + } + builder.put("_highlight", ExprTupleValue.fromExprValueMap(hlBuilder.build())); + } + } + + /** + * Add metadata fields to builder from response. + * @param builder builder to build values from response. + * @param hit Search hit from response. + */ + private void addMetaDataFieldsToBuilder( + ImmutableMap.Builder builder, + SearchHit hit + ) { + List metaDataFieldSet = includes.stream() + .filter(METADATAFIELD_TYPE_MAP::containsKey) + .collect(Collectors.toList()); + ExprFloatValue maxScore = Float.isNaN(hits.getMaxScore()) + ? null : new ExprFloatValue(hits.getMaxScore()); + + metaDataFieldSet.forEach(metaDataField -> { + if (metaDataField.equals(METADATA_FIELD_INDEX)) { + builder.put(METADATA_FIELD_INDEX, new ExprStringValue(hit.getIndex())); + } else if (metaDataField.equals(METADATA_FIELD_ID)) { + builder.put(METADATA_FIELD_ID, new ExprStringValue(hit.getId())); + } else if (metaDataField.equals(METADATA_FIELD_SCORE)) { + if (!Float.isNaN(hit.getScore())) { + builder.put(METADATA_FIELD_SCORE, new ExprFloatValue(hit.getScore())); + } + } else if (metaDataField.equals(METADATA_FIELD_MAXSCORE)) { + if (maxScore != null) { + builder.put(METADATA_FIELD_MAXSCORE, maxScore); + } + } else { // if (metaDataField.equals(METADATA_FIELD_SORT)) { + builder.put(METADATA_FIELD_SORT, new ExprLongValue(hit.getSeqNo())); + } + }); + } + + /** + * Handle an aggregation response. + * @return Parsed and built return values from response. + */ + private Iterator handleAggregationResponse() { + return exprValueFactory.getParser().parse(aggregations).stream().map(entry -> { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + for (Map.Entry value : entry.entrySet()) { + builder.put( + value.getKey(), + exprValueFactory.construct( + value.getKey(), + value.getValue(), + false + ) + ); + } + return (ExprValue) ExprTupleValue.fromExprValueMap(builder.build()); + }).iterator(); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/core/ExpressionScript.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/core/ExpressionScript.java index b327b73b86..9bdb15d63a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/core/ExpressionScript.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/core/ExpressionScript.java @@ -118,7 +118,11 @@ private Environment buildValueEnv( Map valueEnv = new HashMap<>(); for (ReferenceExpression field : fields) { String fieldName = field.getAttr(); - ExprValue exprValue = valueFactory.construct(fieldName, getDocValue(field, docProvider)); + ExprValue exprValue = valueFactory.construct( + fieldName, + getDocValue(field, docProvider), + false + ); valueEnv.put(field, exprValue); } // Encapsulate map data structure into anonymous Environment class diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java index b378fae297..9417a1de1d 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java @@ -313,7 +313,7 @@ void search() { 1.0F)); when(searchHit.getSourceAsString()).thenReturn("{\"id\", 1}"); when(searchHit.getInnerHits()).thenReturn(null); - when(factory.construct(any())).thenReturn(exprTupleValue); + when(factory.construct(any(), anyBoolean())).thenReturn(exprTupleValue); // Mock second scroll request followed SearchResponse scrollResponse = mock(SearchResponse.class); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java index 2958fa1100..b521c6605c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java @@ -13,6 +13,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -297,7 +298,7 @@ void search() throws IOException { 1.0F)); when(searchHit.getSourceAsString()).thenReturn("{\"id\", 1}"); when(searchHit.getInnerHits()).thenReturn(null); - when(factory.construct(any())).thenReturn(exprTupleValue); + when(factory.construct(any(), anyBoolean())).thenReturn(exprTupleValue); // Mock second scroll request followed SearchResponse scrollResponse = mock(SearchResponse.class); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java index 81ac39ede0..a7e3531e8b 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java @@ -12,6 +12,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.data.model.ExprValueUtils.booleanValue; import static org.opensearch.sql.data.model.ExprValueUtils.byteValue; +import static org.opensearch.sql.data.model.ExprValueUtils.collectionValue; import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; import static org.opensearch.sql.data.model.ExprValueUtils.floatValue; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; @@ -37,12 +38,12 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.time.Instant; import java.time.LocalDate; import java.time.LocalTime; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import lombok.EqualsAndHashCode; import lombok.ToString; @@ -50,7 +51,6 @@ import org.opensearch.sql.data.model.ExprCollectionValue; import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprDatetimeValue; -import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprTimeValue; import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.data.model.ExprTupleValue; @@ -92,6 +92,17 @@ class OpenSearchExprValueFactoryTest { .put("arrayV", OpenSearchDataType.of(ARRAY)) .put("arrayV.info", OpenSearchDataType.of(STRING)) .put("arrayV.author", OpenSearchDataType.of(STRING)) + .put("deepNestedV", OpenSearchDataType.of( + OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested)) + ) + .put("deepNestedV.year", OpenSearchDataType.of( + OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested)) + ) + .put("deepNestedV.year.timeV", OpenSearchDateType.of(TIME)) + .put("nestedV", OpenSearchDataType.of( + OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested)) + ) + .put("nestedV.count", OpenSearchDataType.of(INTEGER)) .put("textV", OpenSearchDataType.of(OpenSearchDataType.MappingType.Text)) .put("textKeywordV", OpenSearchTextType.of(Map.of("words", OpenSearchDataType.of(OpenSearchDataType.MappingType.Keyword)))) @@ -381,7 +392,7 @@ public void constructDateFromUnsupportedFormat_ThrowIllegalArgumentException() { @Test public void constructArray() { assertEquals( - new ExprCollectionValue(ImmutableList.of(new ExprTupleValue( + new ExprCollectionValue(List.of(new ExprTupleValue( new LinkedHashMap() { { put("info", stringValue("zz")); @@ -390,22 +401,252 @@ public void constructArray() { }))), tupleValue("{\"arrayV\":[{\"info\":\"zz\",\"author\":\"au\"}]}").get("arrayV")); assertEquals( - new ExprCollectionValue(ImmutableList.of(new ExprTupleValue( + new ExprCollectionValue(List.of(new ExprTupleValue( new LinkedHashMap() { { put("info", stringValue("zz")); put("author", stringValue("au")); } }))), - constructFromObject("arrayV", ImmutableList.of( + constructFromObject("arrayV", List.of( ImmutableMap.of("info", "zz", "author", "au")))); } @Test public void constructArrayOfStrings() { assertEquals(new ExprCollectionValue( - ImmutableList.of(new ExprStringValue("zz"), new ExprStringValue("au"))), - constructFromObject("arrayV", ImmutableList.of("zz", "au"))); + List.of(stringValue("zz"), stringValue("au"))), + constructFromObject("arrayV", List.of("zz", "au"))); + } + + @Test + public void constructNestedArraysOfStrings() { + assertEquals( + new ExprCollectionValue( + List.of( + collectionValue( + List.of("zz", "au") + ), + collectionValue( + List.of("ss") + ) + ) + ), + tupleValueWithArraySupport( + "{\"stringV\":[" + + "[\"zz\", \"au\"]," + + "[\"ss\"]" + + "]}" + ).get("stringV")); + } + + @Test + public void constructNestedArraysOfStringsReturnsFirstIndex() { + assertEquals( + stringValue("zz"), + tupleValue( + "{\"stringV\":[" + + "[\"zz\", \"au\"]," + + "[\"ss\"]" + + "]}" + ).get("stringV")); + } + + @Test + public void constructMultiNestedArraysOfStringsReturnsFirstIndex() { + assertEquals( + stringValue("z"), + tupleValue( + "{\"stringV\":" + + "[\"z\"," + + "[\"s\"]," + + "[\"zz\", \"au\"]" + + "]}" + ).get("stringV")); + } + + @Test + public void constructArrayOfInts() { + assertEquals(new ExprCollectionValue( + List.of(integerValue(1), integerValue(2))), + constructFromObject("arrayV", List.of(1, 2))); + } + + @Test + public void constructArrayOfShorts() { + // Shorts are treated same as integer + assertEquals(new ExprCollectionValue( + List.of(shortValue((short)3), shortValue((short)4))), + constructFromObject("arrayV", List.of(3, 4))); + } + + @Test + public void constructArrayOfLongs() { + assertEquals(new ExprCollectionValue( + List.of(longValue(123456789L), longValue(987654321L))), + constructFromObject("arrayV", List.of(123456789L, 987654321L))); + } + + @Test + public void constructArrayOfFloats() { + assertEquals(new ExprCollectionValue( + List.of(floatValue(3.14f), floatValue(4.13f))), + constructFromObject("arrayV", List.of(3.14f, 4.13f))); + } + + @Test + public void constructArrayOfDoubles() { + assertEquals(new ExprCollectionValue( + List.of(doubleValue(9.1928374756D), doubleValue(4.987654321D))), + constructFromObject("arrayV", List.of(9.1928374756D, 4.987654321D))); + } + + @Test + public void constructArrayOfBooleans() { + assertEquals(new ExprCollectionValue( + List.of(booleanValue(true), booleanValue(false))), + constructFromObject("arrayV", List.of(true, false))); + } + + @Test + public void constructNestedObjectArrayNode() { + assertEquals(collectionValue( + List.of( + Map.of("count", 1), + Map.of("count", 2) + )), + tupleValueWithArraySupport("{\"nestedV\":[{\"count\":1},{\"count\":2}]}") + .get("nestedV")); + } + + @Test + public void constructNestedObjectArrayOfObjectArraysNode() { + assertEquals( + collectionValue( + List.of( + Map.of("year", + List.of( + Map.of("timeV", new ExprTimeValue("09:07:42")), + Map.of("timeV", new ExprTimeValue("09:07:42")) + ) + ), + Map.of("year", + List.of( + Map.of("timeV", new ExprTimeValue("09:07:42")), + Map.of("timeV", new ExprTimeValue("09:07:42")) + ) + ) + ) + ), + tupleValueWithArraySupport( + "{\"deepNestedV\":" + + "[" + + "{\"year\":" + + "[" + + "{\"timeV\":\"09:07:42\"}," + + "{\"timeV\":\"09:07:42\"}" + + "]" + + "}," + + "{\"year\":" + + "[" + + "{\"timeV\":\"09:07:42\"}," + + "{\"timeV\":\"09:07:42\"}" + + "]" + + "}" + + "]" + + "}") + .get("deepNestedV")); + } + + @Test + public void constructNestedArrayNode() { + assertEquals(collectionValue( + List.of( + 1969, + 2011 + )), + tupleValueWithArraySupport("{\"nestedV\":[1969,2011]}") + .get("nestedV")); + } + + @Test + public void constructNestedObjectNode() { + assertEquals(collectionValue( + List.of( + Map.of("count", 1969) + )), + tupleValue("{\"nestedV\":{\"count\":1969}}") + .get("nestedV")); + } + + @Test + public void constructArrayOfGeoPoints() { + assertEquals(new ExprCollectionValue( + List.of( + new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), + new OpenSearchExprGeoPointValue(-33.6123556, 66.287449)) + ), + tupleValueWithArraySupport( + "{\"geoV\":[" + + "{\"lat\":42.60355556,\"lon\":-97.25263889}," + + "{\"lat\":-33.6123556,\"lon\":66.287449}" + + "]}" + ).get("geoV") + ); + } + + @Test + public void constructArrayOfIPsReturnsFirstIndex() { + assertEquals( + new OpenSearchExprIpValue("192.168.0.1"), + tupleValue("{\"ipV\":[\"192.168.0.1\",\"192.168.0.2\"]}") + .get("ipV") + ); + } + + @Test + public void constructBinaryArrayReturnsFirstIndex() { + assertEquals( + new OpenSearchExprBinaryValue("U29tZSBiaWsdfsdfgYmxvYg=="), + tupleValue("{\"binaryV\":[\"U29tZSBiaWsdfsdfgYmxvYg==\",\"U987yuhjjiy8jhk9vY+98jjdf\"]}") + .get("binaryV") + ); + } + + @Test + public void constructArrayOfCustomEpochMillisReturnsFirstIndex() { + assertEquals( + new ExprDatetimeValue("2015-01-01 12:10:30"), + tupleValue("{\"customAndEpochMillisV\":[\"2015-01-01 12:10:30\",\"1999-11-09 01:09:44\"]}") + .get("customAndEpochMillisV") + ); + } + + @Test + public void constructArrayOfDateStringsReturnsFirstIndex() { + assertEquals( + new ExprDateValue("1984-04-12"), + tupleValue("{\"dateStringV\":[\"1984-04-12\",\"2033-05-03\"]}") + .get("dateStringV") + ); + } + + @Test + public void constructArrayOfTimeStringsReturnsFirstIndex() { + assertEquals( + new ExprTimeValue("12:10:30"), + tupleValue("{\"timeStringV\":[\"12:10:30.000Z\",\"18:33:55.000Z\"]}") + .get("timeStringV") + ); + } + + @Test + public void constructArrayOfEpochMillis() { + assertEquals( + new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), + tupleValue("{\"dateOrEpochMillisV\":[\"1420070400001\",\"1454251113333\"]}") + .get("dateOrEpochMillisV") + ); } @Test @@ -517,13 +758,19 @@ public void noTypeFoundForMapping() { @Test public void constructUnsupportedTypeThrowException() { OpenSearchExprValueFactory exprValueFactory = - new OpenSearchExprValueFactory(ImmutableMap.of("type", new TestType())); + new OpenSearchExprValueFactory(Map.of("type", new TestType())); IllegalStateException exception = - assertThrows(IllegalStateException.class, () -> exprValueFactory.construct("{\"type\":1}")); + assertThrows( + IllegalStateException.class, + () -> exprValueFactory.construct("{\"type\":1}", false) + ); assertEquals("Unsupported type: TEST_TYPE for value: 1.", exception.getMessage()); exception = - assertThrows(IllegalStateException.class, () -> exprValueFactory.construct("type", 1)); + assertThrows( + IllegalStateException.class, + () -> exprValueFactory.construct("type", 1, false) + ); assertEquals( "Unsupported type: TEST_TYPE for value: 1.", exception.getMessage()); @@ -553,12 +800,21 @@ public void factoryMappingsAreExtendableWithoutOverWrite() } public Map tupleValue(String jsonString) { - final ExprValue construct = exprValueFactory.construct(jsonString); + final ExprValue construct = exprValueFactory.construct(jsonString, false); + return construct.tupleValue(); + } + + public Map tupleValueWithArraySupport(String jsonString) { + final ExprValue construct = exprValueFactory.construct(jsonString, true); return construct.tupleValue(); } private ExprValue constructFromObject(String fieldName, Object value) { - return exprValueFactory.construct(fieldName, value); + return exprValueFactory.construct(fieldName, value, false); + } + + private ExprValue constructFromObjectWithArraySupport(String fieldName, Object value) { + return exprValueFactory.construct(fieldName, value, true); } @EqualsAndHashCode(callSuper = false) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java index 079a82b783..05e5d80c39 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java @@ -12,7 +12,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; @@ -114,7 +118,8 @@ void iterator() { when(searchHit2.getSourceAsString()).thenReturn("{\"id1\", 2}"); when(searchHit1.getInnerHits()).thenReturn(null); when(searchHit2.getInnerHits()).thenReturn(null); - when(factory.construct(any())).thenReturn(exprTupleValue1).thenReturn(exprTupleValue2); + when(factory.construct(any(), anyBoolean())) + .thenReturn(exprTupleValue1).thenReturn(exprTupleValue2); int i = 0; for (ExprValue hit : new OpenSearchResponse(searchResponse, factory, List.of("id1"))) { @@ -149,7 +154,7 @@ void iterator_metafields() { when(searchHit1.getScore()).thenReturn(3.75F); when(searchHit1.getSeqNo()).thenReturn(123456L); - when(factory.construct(any())).thenReturn(exprTupleHit); + when(factory.construct(any(), anyBoolean())).thenReturn(exprTupleHit); ExprTupleValue exprTupleResponse = ExprTupleValue.fromExprValueMap(ImmutableMap.of( "id1", new ExprIntegerValue(1), @@ -187,7 +192,7 @@ void iterator_metafields_withoutIncludes() { when(searchHit1.getSourceAsString()).thenReturn("{\"id1\", 1}"); - when(factory.construct(any())).thenReturn(exprTupleHit); + when(factory.construct(any(), anyBoolean())).thenReturn(exprTupleHit); List includes = List.of("id1"); ExprTupleValue exprTupleResponse = ExprTupleValue.fromExprValueMap(ImmutableMap.of( @@ -224,7 +229,7 @@ void iterator_metafields_scoreNaN() { when(searchHit1.getScore()).thenReturn(Float.NaN); when(searchHit1.getSeqNo()).thenReturn(123456L); - when(factory.construct(any())).thenReturn(exprTupleHit); + when(factory.construct(any(), anyBoolean())).thenReturn(exprTupleHit); List includes = List.of("id1", "_index", "_id", "_sort", "_score", "_maxscore"); ExprTupleValue exprTupleResponse = ExprTupleValue.fromExprValueMap(ImmutableMap.of( @@ -252,8 +257,6 @@ void iterator_with_inner_hits() { new SearchHit[] {searchHit1}, new TotalHits(2L, TotalHits.Relation.EQUAL_TO), 1.0F)); - when(searchHit1.getSourceAsString()).thenReturn("{\"id1\", 1}"); - when(searchHit1.getSourceAsMap()).thenReturn(Map.of("id1", 1)); when(searchHit1.getInnerHits()).thenReturn( Map.of( "innerHit", @@ -262,7 +265,7 @@ void iterator_with_inner_hits() { new TotalHits(2L, TotalHits.Relation.EQUAL_TO), 1.0F))); - when(factory.construct(any())).thenReturn(exprTupleValue1); + when(factory.construct(any(), anyBoolean())).thenReturn(exprTupleValue1); for (ExprValue hit : new OpenSearchResponse(searchResponse, factory, includes)) { assertEquals(exprTupleValue1, hit); @@ -293,7 +296,7 @@ void aggregation_iterator() { .thenReturn(Arrays.asList(ImmutableMap.of("id1", 1), ImmutableMap.of("id2", 2))); when(searchResponse.getAggregations()).thenReturn(aggregations); when(factory.getParser()).thenReturn(parser); - when(factory.construct(anyString(), any())) + when(factory.construct(anyString(), anyInt(), anyBoolean())) .thenReturn(new ExprIntegerValue(1)) .thenReturn(new ExprIntegerValue(2)); @@ -329,7 +332,7 @@ void highlight_iterator() { 1.0F)); when(searchHit1.getHighlightFields()).thenReturn(highlightMap); - when(factory.construct(any())).thenReturn(resultTuple); + when(factory.construct(any(), anyBoolean())).thenReturn(resultTuple); for (ExprValue resultHit : new OpenSearchResponse(searchResponse, factory, includes)) { var expected = ExprValueUtils.collectionValue( From 3302ec8b6fc411378f97cdfaf0f81b48a0bccb37 Mon Sep 17 00:00:00 2001 From: Forest Vey Date: Tue, 27 Jun 2023 13:58:48 -0700 Subject: [PATCH 4/4] Add Support for Nested Function in Order By Clause (#1789) * Add Support for Nested Function in Order By Clause (#280) * Adding order by clause support for nested function. Signed-off-by: forestmvey * Adding test coverage for nested in ORDER BY clause. Signed-off-by: forestmvey * Added nested function validation to NestedAnalyzer. Signed-off-by: forestmvey --------- Signed-off-by: forestmvey * Adding semantic check for missing arguments in function and unit test. Signed-off-by: forestmvey --------- Signed-off-by: forestmvey --- .../sql/analysis/NestedAnalyzer.java | 15 +++- .../opensearch/sql/analysis/AnalyzerTest.java | 7 ++ docs/user/dql/functions.rst | 11 +++ .../java/org/opensearch/sql/sql/NestedIT.java | 34 +++++++++ .../scan/OpenSearchIndexScanBuilder.java | 10 ++- .../script/filter/lucene/LuceneQuery.java | 7 +- .../storage/script/sort/SortQueryBuilder.java | 39 ++++++++++ .../OpenSearchIndexScanOptimizationTest.java | 72 +++++++++++++++++++ .../script/sort/SortQueryBuilderTest.java | 70 ++++++++++++++++++ .../sql/sql/parser/AstSortBuilder.java | 14 ---- .../sql/sql/antlr/SQLSyntaxParserTest.java | 23 ++++-- .../sql/sql/parser/AstBuilderTest.java | 15 ---- 12 files changed, 278 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java index 756c1f20b3..4e3939bb14 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java @@ -17,6 +17,8 @@ import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.function.BuiltinFunctionName; @@ -105,7 +107,18 @@ private void validateArgs(List args) { * @param field : Nested field to generate path of. * @return : Path of field derived from last level of nesting. */ - private ReferenceExpression generatePath(String field) { + public static ReferenceExpression generatePath(String field) { return new ReferenceExpression(field.substring(0, field.lastIndexOf(".")), STRING); } + + /** + * Check if supplied expression is a nested function. + * @param expr Expression checking if is nested function. + * @return True if expression is a nested function. + */ + public static Boolean isNestedFunction(Expression expr) { + return (expr instanceof FunctionExpression + && ((FunctionExpression) expr).getFunctionName().getFunctionName() + .equalsIgnoreCase(BuiltinFunctionName.NESTED.name())); + } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 59edde6f86..6d83ee53a8 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -9,9 +9,11 @@ import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; +import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction; import static org.opensearch.sql.ast.dsl.AstDSL.aggregate; import static org.opensearch.sql.ast.dsl.AstDSL.alias; import static org.opensearch.sql.ast.dsl.AstDSL.argument; @@ -39,6 +41,7 @@ import static org.opensearch.sql.data.type.ExprCoreType.LONG; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.utils.MLCommonsConstants.ACTION; import static org.opensearch.sql.utils.MLCommonsConstants.ALGO; import static org.opensearch.sql.utils.MLCommonsConstants.ASYNC; @@ -574,6 +577,10 @@ public void project_nested_field_arg() { function("nested", qualifiedName("message", "info")), null) ) ); + + assertTrue(isNestedFunction(DSL.nested(DSL.ref("message.info", STRING)))); + assertFalse(isNestedFunction(DSL.literal("fieldA"))); + assertFalse(isNestedFunction(DSL.match(DSL.namedArgument("field", literal("message"))))); } @Test diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index fa37dc7778..cef87624a5 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -4469,6 +4469,17 @@ Example with ``field`` and ``path`` parameters in the SELECT and WHERE clause:: | b | +---------------------------------+ +Example with ``field`` and ``path`` parameters in the SELECT and ORDER BY clause:: + + os> SELECT nested(message.info, message) FROM nested ORDER BY nested(message.info, message) DESC; + fetched rows / total rows = 2/2 + +---------------------------------+ + | nested(message.info, message) | + |---------------------------------| + | b | + | a | + +---------------------------------+ + System Functions ================ diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java index 80886fe779..69b54cfc4f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java @@ -189,6 +189,40 @@ public void nested_function_with_order_by_clause() { rows("zz")); } + @Test + public void nested_function_with_order_by_clause_desc() { + String query = + "SELECT nested(message.info) FROM " + TEST_INDEX_NESTED_TYPE + + " ORDER BY nested(message.info, message) DESC"; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifyDataRows(result, + rows("zz"), + rows("c"), + rows("c"), + rows("a"), + rows("b"), + rows("a")); + } + + @Test + public void nested_function_and_field_with_order_by_clause() { + String query = + "SELECT nested(message.info), myNum FROM " + TEST_INDEX_NESTED_TYPE + + " ORDER BY nested(message.info, message), myNum"; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifyDataRows(result, + rows("a", 1), + rows("c", 4), + rows("a", 4), + rows("b", 2), + rows("c", 3), + rows("zz", new JSONArray(List.of(3, 4)))); + } + // Nested function in GROUP BY clause is not yet implemented for JDBC format. This test ensures // that the V2 engine falls back to legacy implementation. // TODO Fix the test when NESTED is supported in GROUP BY in the V2 engine. diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java index 3a0d06d079..edcbedc7a7 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java @@ -5,6 +5,8 @@ package org.opensearch.sql.opensearch.storage.scan; +import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction; + import java.util.function.Function; import lombok.EqualsAndHashCode; import org.opensearch.sql.expression.ReferenceExpression; @@ -113,9 +115,15 @@ public boolean pushDownNested(LogicalNested nested) { return delegate.pushDownNested(nested); } + /** + * Valid if sorting is only by fields. + * @param sort Logical sort + * @return True if sorting by fields only + */ private boolean sortByFieldsOnly(LogicalSort sort) { return sort.getSortList().stream() - .map(sortItem -> sortItem.getRight() instanceof ReferenceExpression) + .map(sortItem -> sortItem.getRight() instanceof ReferenceExpression + || isNestedFunction(sortItem.getRight())) .reduce(true, Boolean::logicalAnd); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java index 4dcfec125e..a45c535383 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java @@ -6,6 +6,8 @@ package org.opensearch.sql.opensearch.storage.script.filter.lucene; +import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction; + import com.google.common.collect.ImmutableMap; import java.util.Map; import java.util.function.Function; @@ -62,10 +64,7 @@ public boolean canSupport(FunctionExpression func) { * @return return true if function has supported nested function expression. */ public boolean isNestedPredicate(FunctionExpression func) { - return ((func.getArguments().get(0) instanceof FunctionExpression - && ((FunctionExpression)func.getArguments().get(0)) - .getFunctionName().getFunctionName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) - ); + return isNestedFunction(func.getArguments().get(0)); } /** diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilder.java index 1415fc22c6..62c923832c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilder.java @@ -6,15 +6,21 @@ package org.opensearch.sql.opensearch.storage.script.sort; +import static org.opensearch.sql.analysis.NestedAnalyzer.generatePath; +import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction; + import com.google.common.collect.ImmutableMap; import java.util.Map; import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.NestedSortBuilder; import org.opensearch.search.sort.SortBuilder; import org.opensearch.search.sort.SortBuilders; import org.opensearch.search.sort.SortOrder; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; /** @@ -53,11 +59,44 @@ public SortBuilder build(Expression expression, Sort.SortOption option) { return SortBuilders.scoreSort().order(sortOrderMap.get(option.getSortOrder())); } return fieldBuild((ReferenceExpression) expression, option); + } else if (isNestedFunction(expression)) { + + validateNestedArgs((FunctionExpression) expression); + String orderByName = ((FunctionExpression)expression).getArguments().get(0).toString(); + // Generate path if argument not supplied in function. + ReferenceExpression path = ((FunctionExpression)expression).getArguments().size() == 2 + ? (ReferenceExpression) ((FunctionExpression)expression).getArguments().get(1) + : generatePath(orderByName); + return SortBuilders.fieldSort(orderByName) + .order(sortOrderMap.get(option.getSortOrder())) + .setNestedSort(new NestedSortBuilder(path.toString())); } else { throw new IllegalStateException("unsupported expression " + expression.getClass()); } } + /** + * Validate semantics for arguments in nested function. + * @param nestedFunc Nested function expression. + */ + private void validateNestedArgs(FunctionExpression nestedFunc) { + if (nestedFunc.getArguments().size() < 1 || nestedFunc.getArguments().size() > 2) { + throw new IllegalArgumentException( + "nested function supports 2 parameters (field, path) or 1 parameter (field)" + ); + } + + for (Expression arg : nestedFunc.getArguments()) { + if (!(arg instanceof ReferenceExpression)) { + throw new IllegalArgumentException( + String.format("Illegal nested field name: %s", + arg.toString() + ) + ); + } + } + } + private FieldSortBuilder fieldBuild(ReferenceExpression ref, Sort.SortOption option) { return SortBuilders.fieldSort( OpenSearchTextType.convertTextToKeyword(ref.getAttr(), ref.type())) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java index d5283cecb7..e045bae3e3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java @@ -19,6 +19,7 @@ import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.LONG; import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; @@ -58,6 +59,7 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.sort.NestedSortBuilder; import org.opensearch.search.sort.SortBuilder; import org.opensearch.search.sort.SortBuilders; import org.opensearch.search.sort.SortOrder; @@ -574,6 +576,76 @@ void only_one_project_should_be_push() { ); } + @Test + void test_nested_sort_filter_push_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder( + withFilterPushedDown(QueryBuilders.termQuery("intV", 1)), + withSortPushedDown( + SortBuilders.fieldSort("message.info") + .order(SortOrder.ASC) + .setNestedSort(new NestedSortBuilder("message")))), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ), + project( + sort( + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) + ), + Pair.of( + SortOption.DEFAULT_ASC, DSL.nested(DSL.ref("message.info", STRING)) + ) + ), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ) + ); + } + + @Test + void test_function_expression_sort_returns_optimized_logical_sort() { + // Invalid use case coverage OpenSearchIndexScanBuilder::sortByFieldsOnly returns false + assertEqualsAfterOptimization( + sort( + indexScanBuilder(), + Pair.of( + SortOption.DEFAULT_ASC, + DSL.match(DSL.namedArgument("field", literal("message"))) + ) + ), + sort( + relation("schema", table), + Pair.of( + SortOption.DEFAULT_ASC, + DSL.match(DSL.namedArgument("field", literal("message")) + ) + ) + ) + ); + } + + @Test + void test_non_field_sort_returns_optimized_logical_sort() { + // Invalid use case coverage OpenSearchIndexScanBuilder::sortByFieldsOnly returns false + assertEqualsAfterOptimization( + sort( + indexScanBuilder(), + Pair.of( + SortOption.DEFAULT_ASC, + DSL.literal("field") + ) + ), + sort( + relation("schema", table), + Pair.of( + SortOption.DEFAULT_ASC, + DSL.literal("field") + ) + ) + ); + } + @Test void sort_with_expression_cannot_merge_with_relation() { assertEqualsAfterOptimization( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java index df6cfae78f..e84ed14e43 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java @@ -10,11 +10,14 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; import org.hamcrest.Matchers; import org.junit.jupiter.api.Test; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.data.model.ExprShortValue; import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.LiteralExpression; class SortQueryBuilderTest { @@ -25,6 +28,73 @@ void build_sortbuilder_from_reference() { assertNotNull(sortQueryBuilder.build(DSL.ref("intV", INTEGER), Sort.SortOption.DEFAULT_ASC)); } + @Test + void build_sortbuilder_from_nested_function() { + assertNotNull( + sortQueryBuilder.build( + DSL.nested(DSL.ref("message.info", STRING)), + Sort.SortOption.DEFAULT_ASC + ) + ); + } + + @Test + void build_sortbuilder_from_nested_function_with_path_param() { + assertNotNull( + sortQueryBuilder.build( + DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING)), + Sort.SortOption.DEFAULT_ASC + ) + ); + } + + @Test + void nested_with_too_many_args_throws_exception() { + assertThrows( + IllegalArgumentException.class, + () -> sortQueryBuilder.build( + DSL.nested( + DSL.ref("message.info", STRING), + DSL.ref("message", STRING), + DSL.ref("message", STRING) + ), + Sort.SortOption.DEFAULT_ASC + ) + ); + } + + @Test + void nested_with_too_few_args_throws_exception() { + assertThrows( + IllegalArgumentException.class, + () -> sortQueryBuilder.build( + DSL.nested(), + Sort.SortOption.DEFAULT_ASC + ) + ); + } + + @Test + void nested_with_invalid_arg_type_throws_exception() { + assertThrows( + IllegalArgumentException.class, + () -> sortQueryBuilder.build( + DSL.nested( + DSL.literal(1) + ), + Sort.SortOption.DEFAULT_ASC + ) + ); + } + + @Test + void build_sortbuilder_from_expression_should_throw_exception() { + final IllegalStateException exception = + assertThrows(IllegalStateException.class, () -> sortQueryBuilder.build( + new LiteralExpression(new ExprShortValue(1)), Sort.SortOption.DEFAULT_ASC)); + assertThat(exception.getMessage(), Matchers.containsString("unsupported expression")); + } + @Test void build_sortbuilder_from_function_should_throw_exception() { final IllegalStateException exception = diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstSortBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstSortBuilder.java index 993bc10615..1b872dce54 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstSortBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstSortBuilder.java @@ -17,17 +17,12 @@ import lombok.RequiredArgsConstructor; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Field; -import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.Sort.NullOrder; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.ast.tree.Sort.SortOrder; import org.opensearch.sql.ast.tree.UnresolvedPlan; -import org.opensearch.sql.common.antlr.SyntaxCheckException; -import org.opensearch.sql.exception.SemanticCheckException; -import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParserBaseVisitor; import org.opensearch.sql.sql.parser.context.QuerySpecification; @@ -53,15 +48,6 @@ private List createSortFields() { List items = querySpec.getOrderByItems(); List options = querySpec.getOrderByOptions(); for (int i = 0; i < items.size(); i++) { - // TODO remove me when Nested function is supported in ORDER BY clause. - if (items.get(i) instanceof Function - && ((Function)items.get(i)).getFuncName().equalsIgnoreCase( - BuiltinFunctionName.NESTED.name()) - ) { - throw new SyntaxCheckException( - "Falling back to legacy engine. Nested function is not supported in ORDER BY clause." - ); - } fields.add( new Field( querySpec.replaceIfAliasOrOrdinal(items.get(i)), diff --git a/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java b/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java index 39fe8811b5..6d43daa60f 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java @@ -638,13 +638,28 @@ public void can_parse_wildcard_query_relevance_function() { @Test public void can_parse_nested_function() { assertNotNull( - parser.parse("SELECT NESTED(FIELD.DAYOFWEEK) FROM TEST")); + parser.parse("SELECT NESTED(PATH.INNER_FIELD) FROM TEST")); assertNotNull( - parser.parse("SELECT NESTED('FIELD.DAYOFWEEK') FROM TEST")); + parser.parse("SELECT NESTED('PATH.INNER_FIELD') FROM TEST")); assertNotNull( - parser.parse("SELECT SUM(NESTED(FIELD.SUBFIELD)) FROM TEST")); + parser.parse("SELECT SUM(NESTED(PATH.INNER_FIELD)) FROM TEST")); assertNotNull( - parser.parse("SELECT NESTED(FIELD.DAYOFWEEK, PATH) FROM TEST")); + parser.parse("SELECT NESTED(PATH.INNER_FIELD, PATH) FROM TEST")); + assertNotNull( + parser.parse( + "SELECT * FROM TEST WHERE NESTED(PATH.INNER_FIELDS) = 'A'" + ) + ); + assertNotNull( + parser.parse( + "SELECT * FROM TEST WHERE NESTED(PATH.INNER_FIELDS, PATH) = 'A'" + ) + ); + assertNotNull( + parser.parse( + "SELECT FIELD FROM TEST ORDER BY nested(PATH.INNER_FIELD, PATH)" + ) + ); } @Test diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index 0c69909334..e017bd8cd6 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -462,21 +462,6 @@ public void can_build_order_by_null_option() { buildAST("SELECT name FROM test ORDER BY name NULLS LAST")); } - /** - * Ensure Nested function falls back to legacy engine when used in an ORDER BY clause. - * TODO Remove this test when support is added. - */ - @Test - public void nested_in_order_by_clause_throws_exception() { - SyntaxCheckException exception = assertThrows(SyntaxCheckException.class, - () -> buildAST("SELECT * FROM test ORDER BY nested(message.info)") - ); - - assertEquals( - "Falling back to legacy engine. Nested function is not supported in ORDER BY clause.", - exception.getMessage()); - } - /** * Ensure Nested function falls back to legacy engine when used in an HAVING clause. * TODO Remove this test when support is added.