diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index 1aeb89a6c..bb8f697ec 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -27,6 +27,7 @@ If you get integration test failures with error message "Previous attempts to fi The `aws-integration` folder contains tests for cloud server providers. For instance, test against AWS OpenSearch domain, configure the following settings. The client will use the default credential provider to access the AWS OpenSearch domain. ``` export AWS_OPENSEARCH_HOST=search-xxx.us-west-2.on.aws +export AWS_OPENSEARCH_SERVERLESS_HOST=xxx.us-west-2.aoss.amazonaws.com export AWS_REGION=us-west-2 export AWS_EMRS_APPID=xxx export AWS_EMRS_EXECUTION_ROLE=xxx diff --git a/README.md b/README.md index f9568838e..b586ca44b 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,8 @@ Version compatibility: | 0.2.0 | 11+ | 3.3.1 | 2.12.14 | 2.6+ | | 0.3.0 | 11+ | 3.3.2 | 2.12.14 | 2.13+ | | 0.4.0 | 11+ | 3.3.2 | 2.12.14 | 2.13+ | -| 0.5.0 | 11+ | 3.3.2 | 2.12.14 | 2.13+ | +| 0.5.0 | 11+ | 3.5.1 | 2.12.14 | 2.13+ | +| 0.6.0 | 11+ | 3.5.1 | 2.12.14 | 2.13+ | ## Flint Extension Usage @@ -42,31 +43,36 @@ spark-sql --conf "spark.sql.extensions=org.opensearch.flint.spark.FlintPPLSparkE ### Running With both Extension ``` -spark-sql --conf "spark.sql.extensions='org.opensearch.flint.spark.FlintPPLSparkExtensions, org.opensearch.flint.spark.FlintSparkExtensions'" +spark-sql --conf "spark.sql.extensions=org.opensearch.flint.spark.FlintPPLSparkExtensions,org.opensearch.flint.spark.FlintSparkExtensions" ``` ## Build -To build and run this application with Spark, you can run: +To build and run this application with Spark, you can run (requires Java 11): ``` sbt clean standaloneCosmetic/publishM2 ``` -then add org.opensearch:opensearch-spark_2.12 when run spark application, for example, +then add org.opensearch:opensearch-spark-standalone_2.12 when run spark application, for example, ``` -bin/spark-shell --packages "org.opensearch:opensearch-spark_2.12:0.5.0-SNAPSHOT" +bin/spark-shell --packages "org.opensearch:opensearch-spark-standalone_2.12:0.6.0-SNAPSHOT" \ + --conf "spark.sql.extensions=org.opensearch.flint.spark.FlintSparkExtensions" \ + --conf "spark.sql.catalog.dev=org.apache.spark.opensearch.catalog.OpenSearchCatalog" ``` ### PPL Build & Run -To build and run this PPL in Spark, you can run: +To build and run this PPL in Spark, you can run (requires Java 11): ``` sbt clean sparkPPLCosmetic/publishM2 ``` -then add org.opensearch:opensearch-spark_2.12 when run spark application, for example, +then add org.opensearch:opensearch-spark-ppl_2.12 when run spark application, for example, ``` -bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.5.0-SNAPSHOT" +bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.6.0-SNAPSHOT" \ + --conf "spark.sql.extensions=org.opensearch.flint.spark.FlintPPLSparkExtensions" \ + --conf "spark.sql.catalog.dev=org.apache.spark.opensearch.catalog.OpenSearchCatalog" + ``` ## Code of Conduct diff --git a/build.sbt b/build.sbt index 6f8237aac..542086f2e 100644 --- a/build.sbt +++ b/build.sbt @@ -5,10 +5,10 @@ import Dependencies._ lazy val scala212 = "2.12.14" -lazy val sparkVersion = "3.3.2" -// Spark jackson version. Spark jackson-module-scala strictly check the jackson-databind version hould compatbile +lazy val sparkVersion = "3.5.1" +// Spark jackson version. 
Spark jackson-module-scala strictly check the jackson-databind version should compatible // https://github.com/FasterXML/jackson-module-scala/blob/2.18/src/main/scala/com/fasterxml/jackson/module/scala/JacksonModule.scala#L59 -lazy val jacksonVersion = "2.13.4" +lazy val jacksonVersion = "2.15.2" // The transitive opensearch jackson-databind dependency version should align with Spark jackson databind dependency version. // Issue: https://github.com/opensearch-project/opensearch-spark/issues/442 @@ -20,7 +20,7 @@ val sparkMinorVersion = sparkVersion.split("\\.").take(2).mkString(".") ThisBuild / organization := "org.opensearch" -ThisBuild / version := "0.5.0-SNAPSHOT" +ThisBuild / version := "0.6.0-SNAPSHOT" ThisBuild / scalaVersion := scala212 @@ -210,10 +210,6 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) val oldStrategy = (assembly / assemblyMergeStrategy).value oldStrategy(x) }, - assembly / assemblyExcludedJars := { - val cp = (assembly / fullClasspath).value - cp filter { file => file.data.getName.contains("LogsConnectorSpark")} - }, assembly / test := (Test / test).value) lazy val IntegrationTest = config("it") extend Test diff --git a/docs/PPL-Correlation-command.md b/docs/PPL-Correlation-command.md index f7ef3e266..2e8507a14 100644 --- a/docs/PPL-Correlation-command.md +++ b/docs/PPL-Correlation-command.md @@ -1,5 +1,8 @@ ## PPL Correlation Command +> This is an experimental command - it may be removed in future versions + + ## Overview In the past year OpenSearch Observability & security teams have been busy with many aspects of improving data monitoring and visibility. @@ -262,6 +265,8 @@ The new correlation command is actually a ‘hidden’ join command therefore th Catalyst engine will optimize this query according to the most efficient join ordering. +> This is an experimental command - it may be removed in future versions + * * * ## Appendix diff --git a/docs/PPL-on-Spark.md b/docs/PPL-on-Spark.md index dd7c40710..7e7dbde5d 100644 --- a/docs/PPL-on-Spark.md +++ b/docs/PPL-on-Spark.md @@ -34,7 +34,7 @@ sbt clean sparkPPLCosmetic/publishM2 ``` then add org.opensearch:opensearch-spark_2.12 when run spark application, for example, ``` -bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.5.0-SNAPSHOT" +bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.6.0-SNAPSHOT" ``` ### PPL Extension Usage @@ -46,7 +46,7 @@ spark-sql --conf "spark.sql.extensions=org.opensearch.flint.spark.FlintPPLSparkE ``` ### Running With both Flint & PPL Extensions -In order to make use of both flint and ppl extension, one can simply add both jars (`org.opensearch:opensearch-spark-ppl_2.12:0.5.0-SNAPSHOT`,`org.opensearch:opensearch-spark_2.12:0.5.0-SNAPSHOT`) to the cluster's +In order to make use of both flint and ppl extension, one can simply add both jars (`org.opensearch:opensearch-spark-ppl_2.12:0.6.0-SNAPSHOT`,`org.opensearch:opensearch-spark_2.12:0.6.0-SNAPSHOT`) to the cluster's classpath. Next need to configure both extensions : diff --git a/docs/index.md b/docs/index.md index 981ed16f0..3bfa0b468 100644 --- a/docs/index.md +++ b/docs/index.md @@ -60,7 +60,7 @@ Currently, Flint metadata is only static configuration without version control a ```json { - "version": "0.5.0", + "version": "0.6.0", "name": "...", "kind": "skipping", "source": "...", @@ -521,12 +521,15 @@ In the index mapping, the `_meta` and `properties`field stores meta and schema i - `spark.datasource.flint.auth.username`: basic auth username. 
- `spark.datasource.flint.auth.password`: basic auth password. - `spark.datasource.flint.region`: default is us-west-2. only been used when auth=sigv4 -- `spark.datasource.flint.customAWSCredentialsProvider`: default is empty. +- `spark.datasource.flint.customAWSCredentialsProvider`: default is empty. +- `spark.datasource.flint.customFlintMetadataLogServiceClass`: default is empty. +- `spark.datasource.flint.customFlintIndexMetadataServiceClass`: default is empty. - `spark.datasource.flint.write.id_name`: no default value. - `spark.datasource.flint.ignore.id_column` : default value is true. - `spark.datasource.flint.write.batch_size`: "The number of documents written to Flint in a single batch request. Default value is Integer.MAX_VALUE. - `spark.datasource.flint.write.batch_bytes`: The approximately amount of data in bytes written to Flint in a single batch request. The actual data write to OpenSearch may more than it. Default value is 1mb. The writing process checks after each document whether the total number of documents (docCount) has reached batch_size or the buffer size has surpassed batch_bytes. If either condition is met, the current batch is flushed and the document count resets to zero. - `spark.datasource.flint.write.refresh_policy`: default value is false. valid values [NONE(false), IMMEDIATE(true), WAIT_UNTIL(wait_for)] +- `spark.datasource.flint.write.bulkRequestRateLimitPerNode`: [Experimental] Rate limit(request/sec) for bulk request per worker node. Only accept integer value. To reduce the traffic less than 1 req/sec, batch_bytes or batch_size should be reduced. Default value is 0, which disables rate limit. - `spark.datasource.flint.read.scroll_size`: default value is 100. - `spark.datasource.flint.read.scroll_duration`: default value is 5 minutes. scroll context keep alive duration. - `spark.datasource.flint.retry.max_retries`: max retries on failed HTTP request. default value is 3. Use 0 to disable retry. @@ -689,7 +692,7 @@ For now, only single or conjunct conditions (conditions connected by AND) in WHE ### AWS EMR Spark Integration - Using execution role Flint use [DefaultAWSCredentialsProviderChain](https://docs.aws.amazon.com/AWSJavaSDK/latest/javadoc/com/amazonaws/auth/DefaultAWSCredentialsProviderChain.html). When running in EMR Spark, Flint use executionRole credentials ``` ---conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.5.0-SNAPSHOT \ +--conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.6.0-SNAPSHOT \ --conf spark.jars.repositories=https://aws.oss.sonatype.org/content/repositories/snapshots \ --conf spark.emr-serverless.driverEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ --conf spark.executorEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ @@ -731,7 +734,7 @@ Flint use [DefaultAWSCredentialsProviderChain](https://docs.aws.amazon.com/AWSJa ``` 3. Set the spark.datasource.flint.customAWSCredentialsProvider property with value as com.amazonaws.emr.AssumeRoleAWSCredentialsProvider. Set the environment variable ASSUME_ROLE_CREDENTIALS_ROLE_ARN with the ARN value of CrossAccountRoleB. 
``` ---conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.5.0-SNAPSHOT \ +--conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.6.0-SNAPSHOT \ --conf spark.jars.repositories=https://aws.oss.sonatype.org/content/repositories/snapshots \ --conf spark.emr-serverless.driverEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ --conf spark.executorEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintVersion.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/FlintVersion.scala similarity index 79% rename from flint-core/src/main/scala/org/opensearch/flint/core/FlintVersion.scala rename to flint-commons/src/main/scala/org/opensearch/flint/common/FlintVersion.scala index 909d76ce5..1203ea7ef 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintVersion.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/FlintVersion.scala @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.core +package org.opensearch.flint.common /** * Flint version. @@ -19,6 +19,7 @@ object FlintVersion { val V_0_3_0: FlintVersion = FlintVersion("0.3.0") val V_0_4_0: FlintVersion = FlintVersion("0.4.0") val V_0_5_0: FlintVersion = FlintVersion("0.5.0") + val V_0_6_0: FlintVersion = FlintVersion("0.6.0") - def current(): FlintVersion = V_0_5_0 + def current(): FlintVersion = V_0_6_0 } diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/FlintIndexMetadataService.java b/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/FlintIndexMetadataService.java new file mode 100644 index 000000000..b990998a9 --- /dev/null +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/FlintIndexMetadataService.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.common.metadata; + +import java.util.Map; + +/** + * Flint index metadata service provides API for index metadata related operations on a Flint index + * regardless of underlying storage. + *
+ * Custom implementations of this interface are expected to provide a public constructor with + * the signature {@code public MyCustomService(SparkConf sparkConf)} to be instantiated by + * the FlintIndexMetadataServiceBuilder. + */ +public interface FlintIndexMetadataService { + + /** + * Retrieve metadata for a Flint index. + * + * @param indexName index name + * @return index metadata + */ + FlintMetadata getIndexMetadata(String indexName); + + /** + * Retrieve all metadata for Flint index whose name matches the given pattern. + * + * @param indexNamePattern index name pattern + * @return map where the keys are the matched index names, and the values are + * corresponding index metadata + */ + Map getAllIndexMetadata(String... indexNamePattern); + + /** + * Update metadata for a Flint index. + * + * @param indexName index name + * @param metadata index metadata to update + */ + void updateIndexMetadata(String indexName, FlintMetadata metadata); + + /** + * Delete metadata for a Flint index. + * + * @param indexName index name + */ + void deleteIndexMetadata(String indexName); +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintMetadata.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/FlintMetadata.scala similarity index 52% rename from flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintMetadata.scala rename to flint-commons/src/main/scala/org/opensearch/flint/common/metadata/FlintMetadata.scala index e4e94cc8c..219a0a831 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintMetadata.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/FlintMetadata.scala @@ -3,14 +3,13 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.core.metadata +package org.opensearch.flint.common.metadata import java.util +import org.opensearch.flint.common.FlintVersion +import org.opensearch.flint.common.FlintVersion.current import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry -import org.opensearch.flint.core.FlintVersion -import org.opensearch.flint.core.FlintVersion.current -import org.opensearch.flint.core.metadata.FlintJsonHelper._ /** * Flint metadata follows Flint index specification and defines metadata for a Flint index @@ -35,7 +34,11 @@ case class FlintMetadata( schema: util.Map[String, AnyRef] = new util.HashMap[String, AnyRef], /** Optional latest metadata log entry id */ latestId: Option[String] = None, - /** Optional latest metadata log entry */ + /** + * Optional latest metadata log entry. TODO: remove? This was added for SHOW command to be + * fetched during get(All)IndexMetadata. Now describeIndex uses metadata log service to fetch + * log entry after get(All)IndexMetadata so this doesn't need to be part of FlintMetadata. + */ latestLogEntry: Option[FlintMetadataLogEntry] = None, /** Optional Flint index settings. TODO: move elsewhere? */ indexSettings: Option[String]) { @@ -44,124 +47,10 @@ case class FlintMetadata( require(name != null, "name is required") require(kind != null, "kind is required") require(source != null, "source is required") - - /** - * Generate JSON content as index metadata. 
- * - * @return - * JSON content - */ - def getContent: String = { - try { - buildJson(builder => { - // Add _meta field - objectField(builder, "_meta") { - builder - .field("version", version.version) - .field("name", name) - .field("kind", kind) - .field("source", source) - .field("indexedColumns", indexedColumns) - - if (latestId.isDefined) { - builder.field("latestId", latestId.get) - } - optionalObjectField(builder, "options", options) - optionalObjectField(builder, "properties", properties) - } - - // Add properties (schema) field - builder.field("properties", schema) - }) - } catch { - case e: Exception => - throw new IllegalStateException("Failed to jsonify Flint metadata", e) - } - } } object FlintMetadata { - /** - * Construct Flint metadata with JSON content, index settings, and latest log entry. - * - * @param content - * JSON content - * @param settings - * index settings - * @param latestLogEntry - * latest metadata log entry - * @return - * Flint metadata - */ - def apply( - content: String, - settings: String, - latestLogEntry: FlintMetadataLogEntry): FlintMetadata = { - val metadata = FlintMetadata(content, settings) - metadata.copy(latestLogEntry = Option(latestLogEntry)) - } - - /** - * Construct Flint metadata with JSON content and index settings. - * - * @param content - * JSON content - * @param settings - * index settings - * @return - * Flint metadata - */ - def apply(content: String, settings: String): FlintMetadata = { - val metadata = FlintMetadata(content) - metadata.copy(indexSettings = Option(settings)) - } - - /** - * Parse the given JSON content and construct Flint metadata class. - * - * @param content - * JSON content - * @return - * Flint metadata - */ - def apply(content: String): FlintMetadata = { - try { - val builder = new FlintMetadata.Builder() - parseJson(content) { (parser, fieldName) => - { - fieldName match { - case "_meta" => - parseObjectField(parser) { (parser, innerFieldName) => - { - innerFieldName match { - case "version" => builder.version(FlintVersion.apply(parser.text())) - case "name" => builder.name(parser.text()) - case "kind" => builder.kind(parser.text()) - case "source" => builder.source(parser.text()) - case "indexedColumns" => - parseArrayField(parser) { - builder.addIndexedColumn(parser.map()) - } - case "options" => builder.options(parser.map()) - case "properties" => builder.properties(parser.map()) - case _ => // Handle other fields as needed - } - } - } - case "properties" => - builder.schema(parser.map()) - case _ => // Ignore other fields, for instance, dynamic. 
- } - } - } - builder.build() - } catch { - case e: Exception => - throw new IllegalStateException("Failed to parse metadata JSON", e) - } - } - def builder(): FlintMetadata.Builder = new Builder /** @@ -231,16 +120,6 @@ object FlintMetadata { this } - def schema(schema: String): this.type = { - parseJson(schema) { (parser, fieldName) => - fieldName match { - case "properties" => this.schema = parser.map() - case _ => // do nothing - } - } - this - } - def latestLogEntry(entry: FlintMetadataLogEntry): this.type = { this.latestId = Option(entry.id) this.latestLogEntry = Option(entry) diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala index a7391ed6a..982b7df23 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala @@ -43,8 +43,8 @@ case class FlintMetadataLogEntry( state: IndexState, entryVersion: JMap[String, Any], error: String, - storageContext: JMap[String, Any]) = { - this(id, createTime, state, entryVersion.asScala.toMap, error, storageContext.asScala.toMap) + properties: JMap[String, Any]) = { + this(id, createTime, state, entryVersion.asScala.toMap, error, properties.asScala.toMap) } def this( @@ -53,8 +53,8 @@ case class FlintMetadataLogEntry( state: IndexState, entryVersion: JMap[String, Any], error: String, - storageContext: Map[String, Any]) = { - this(id, createTime, state, entryVersion.asScala.toMap, error, storageContext) + properties: Map[String, Any]) = { + this(id, createTime, state, entryVersion.asScala.toMap, error, properties) } } diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/ContextualDataStore.scala similarity index 93% rename from flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala rename to flint-commons/src/main/scala/org/opensearch/flint/common/model/ContextualDataStore.scala index 109bf654a..408216fad 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/ContextualDataStore.scala @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.data +package org.opensearch.flint.common.model /** * Provides a mutable map to store and retrieve contextual data using key-value pairs. 
diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala similarity index 78% rename from flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala rename to flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala index dbe73e9a5..bc8b38d9a 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala @@ -3,7 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.data +package org.opensearch.flint.common.model + +import java.util.Locale import org.json4s.{Formats, NoTypeHints} import org.json4s.JsonAST.JString @@ -14,6 +16,7 @@ object StatementStates { val RUNNING = "running" val SUCCESS = "success" val FAILED = "failed" + val TIMEOUT = "timeout" val WAITING = "waiting" } @@ -50,10 +53,15 @@ class FlintStatement( def running(): Unit = state = StatementStates.RUNNING def complete(): Unit = state = StatementStates.SUCCESS def fail(): Unit = state = StatementStates.FAILED - def isRunning: Boolean = state == StatementStates.RUNNING - def isComplete: Boolean = state == StatementStates.SUCCESS - def isFailed: Boolean = state == StatementStates.FAILED - def isWaiting: Boolean = state == StatementStates.WAITING + def timeout(): Unit = state = StatementStates.TIMEOUT + + def isRunning: Boolean = state.equalsIgnoreCase(StatementStates.RUNNING) + + def isComplete: Boolean = state.equalsIgnoreCase(StatementStates.SUCCESS) + + def isFailed: Boolean = state.equalsIgnoreCase(StatementStates.FAILED) + + def isWaiting: Boolean = state.equalsIgnoreCase(StatementStates.WAITING) // Does not include context, which could contain sensitive information. 
override def toString: String = @@ -66,7 +74,7 @@ object FlintStatement { def deserialize(statement: String): FlintStatement = { val meta = parse(statement) - val state = (meta \ "state").extract[String] + val state = (meta \ "state").extract[String].toLowerCase(Locale.ROOT) val query = (meta \ "query").extract[String] val statementId = (meta \ "statementId").extract[String] val queryId = (meta \ "queryId").extract[String] @@ -82,6 +90,8 @@ object FlintStatement { def serialize(flintStatement: FlintStatement): String = { // we only need to modify state and error Serialization.write( - Map("state" -> flintStatement.state, "error" -> flintStatement.error.getOrElse(""))) + Map( + "state" -> flintStatement.state.toLowerCase(Locale.ROOT), + "error" -> flintStatement.error.getOrElse(""))) } } diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala similarity index 90% rename from flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala rename to flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala index c5eaee4f1..9acdeab5f 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala @@ -3,9 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.data +package org.opensearch.flint.common.model -import java.util.{Map => JavaMap} +import java.util.{Locale, Map => JavaMap} import scala.collection.JavaConverters._ @@ -16,9 +16,8 @@ import org.json4s.native.Serialization object SessionStates { val RUNNING = "running" - val COMPLETE = "complete" - val FAILED = "failed" - val WAITING = "waiting" + val DEAD = "dead" + val FAIL = "fail" } /** @@ -56,10 +55,15 @@ class InteractiveSession( extends ContextualDataStore { context = sessionContext // Initialize the context from the constructor - def isRunning: Boolean = state == SessionStates.RUNNING - def isComplete: Boolean = state == SessionStates.COMPLETE - def isFailed: Boolean = state == SessionStates.FAILED - def isWaiting: Boolean = state == SessionStates.WAITING + def running(): Unit = state = SessionStates.RUNNING + def complete(): Unit = state = SessionStates.DEAD + def fail(): Unit = state = SessionStates.FAIL + + def isRunning: Boolean = state.equalsIgnoreCase(SessionStates.RUNNING) + + def isComplete: Boolean = state.equalsIgnoreCase(SessionStates.DEAD) + + def isFail: Boolean = state.equalsIgnoreCase(SessionStates.FAIL) override def toString: String = { val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]") @@ -77,7 +81,7 @@ object InteractiveSession { def deserialize(job: String): InteractiveSession = { val meta = parse(job) val applicationId = (meta \ "applicationId").extract[String] - val state = (meta \ "state").extract[String] + val state = (meta \ "state").extract[String].toLowerCase(Locale.ROOT) val jobId = (meta \ "jobId").extract[String] val sessionId = (meta \ "sessionId").extract[String] val lastUpdateTime = (meta \ "lastUpdateTime").extract[Long] @@ -116,7 +120,7 @@ object InteractiveSession { val scalaSource = source.asScala val applicationId = scalaSource("applicationId").asInstanceOf[String] - val state = scalaSource("state").asInstanceOf[String] + val state = scalaSource("state").asInstanceOf[String].toLowerCase(Locale.ROOT) val jobId = scalaSource("jobId").asInstanceOf[String] val sessionId = 
scalaSource("sessionId").asInstanceOf[String] val lastUpdateTime = scalaSource("lastUpdateTime").asInstanceOf[Long] @@ -178,7 +182,7 @@ object InteractiveSession { "sessionId" -> job.sessionId, "error" -> job.error.getOrElse(""), "applicationId" -> job.applicationId, - "state" -> job.state, + "state" -> job.state.toLowerCase(Locale.ROOT), // update last update time "lastUpdateTime" -> currentTime, // Convert a Seq[String] into a comma-separated string, such as "id1,id2". diff --git a/flint-commons/src/test/scala/org/opensearch/flint/common/metadata/FlintMetadataSuite.scala b/flint-commons/src/test/scala/org/opensearch/flint/common/metadata/FlintMetadataSuite.scala new file mode 100644 index 000000000..ea69c69ef --- /dev/null +++ b/flint-commons/src/test/scala/org/opensearch/flint/common/metadata/FlintMetadataSuite.scala @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.common.metadata + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.opensearch.flint.common.FlintVersion.current +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class FlintMetadataSuite extends AnyFlatSpec with Matchers { + "builder" should "build FlintMetadata with provided fields" in { + val builder = new FlintMetadata.Builder + builder.name("test_index") + builder.kind("test_kind") + builder.source("test_source_table") + builder.addIndexedColumn(Map[String, AnyRef]("test_field" -> "spark_type").asJava) + builder.schema(Map[String, AnyRef]("test_field" -> Map("type" -> "os_type").asJava).asJava) + + val metadata = builder.build() + + metadata.version shouldBe current() + metadata.name shouldBe "test_index" + metadata.kind shouldBe "test_kind" + metadata.source shouldBe "test_source_table" + metadata.indexedColumns shouldBe Array(Map("test_field" -> "spark_type").asJava) + metadata.schema shouldBe Map("test_field" -> Map("type" -> "os_type").asJava).asJava + } +} diff --git a/flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala b/flint-commons/src/test/scala/org/opensearch/flint/common/model/InteractiveSessionTest.scala similarity index 97% rename from flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala rename to flint-commons/src/test/scala/org/opensearch/flint/common/model/InteractiveSessionTest.scala index f69fe70b4..5f6b1fdc1 100644 --- a/flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala +++ b/flint-commons/src/test/scala/org/opensearch/flint/common/model/InteractiveSessionTest.scala @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.data +package org.opensearch.flint.common.model import java.util.{HashMap => JavaHashMap} @@ -21,7 +21,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers { instance.applicationId shouldBe "app-123" instance.jobId shouldBe "job-456" instance.sessionId shouldBe "session-789" - instance.state shouldBe "RUNNING" + instance.state shouldBe "running" instance.lastUpdateTime shouldBe 1620000000000L instance.jobStartTime shouldBe 1620000001000L instance.excludedJobIds should contain allOf ("job-101", "job-202") @@ -44,7 +44,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers { json should include(""""applicationId":"app-123"""") json should not include (""""jobId":"job-456"""") json should include(""""sessionId":"session-789"""") - json should include(""""state":"RUNNING"""") + 
json should include(""""state":"running"""") json should include(s""""lastUpdateTime":$currentTime""") json should include( """"excludeJobIds":"job-101,job-202"""" @@ -149,7 +149,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers { instance.applicationId shouldBe "app-123" instance.jobId shouldBe "job-456" instance.sessionId shouldBe "session-789" - instance.state shouldBe "RUNNING" + instance.state shouldBe "running" instance.lastUpdateTime shouldBe 1620000000000L instance.jobStartTime shouldBe 0L // Default or expected value for missing jobStartTime instance.excludedJobIds should contain allOf ("job-101", "job-202") diff --git a/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java b/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java index f26a6c158..28fad39a0 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java @@ -37,6 +37,7 @@ import org.opensearch.client.transport.rest_client.RestClientTransport; import java.io.IOException; +import org.opensearch.flint.core.storage.BulkRequestRateLimiter; import static org.opensearch.flint.core.metrics.MetricConstants.OS_READ_OP_METRIC_PREFIX; import static org.opensearch.flint.core.metrics.MetricConstants.OS_WRITE_OP_METRIC_PREFIX; @@ -47,6 +48,7 @@ */ public class RestHighLevelClientWrapper implements IRestHighLevelClient { private final RestHighLevelClient client; + private final BulkRequestRateLimiter rateLimiter; private final static JacksonJsonpMapper JACKSON_MAPPER = new JacksonJsonpMapper(); @@ -55,13 +57,21 @@ public class RestHighLevelClientWrapper implements IRestHighLevelClient { * * @param client the RestHighLevelClient instance to wrap */ - public RestHighLevelClientWrapper(RestHighLevelClient client) { + public RestHighLevelClientWrapper(RestHighLevelClient client, BulkRequestRateLimiter rateLimiter) { this.client = client; + this.rateLimiter = rateLimiter; } @Override public BulkResponse bulk(BulkRequest bulkRequest, RequestOptions options) throws IOException { - return execute(OS_WRITE_OP_METRIC_PREFIX, () -> client.bulk(bulkRequest, options)); + return execute(OS_WRITE_OP_METRIC_PREFIX, () -> { + try { + rateLimiter.acquirePermit(); + return client.bulk(bulkRequest, options); + } catch (InterruptedException e) { + throw new RuntimeException("rateLimiter.acquirePermit was interrupted.", e); + } + }); } @Override diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java index 6e3e90916..c3e00a067 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java @@ -10,9 +10,9 @@ import java.util.logging.Level; import java.util.logging.Logger; -import org.apache.commons.lang.StringUtils; import com.amazonaws.services.cloudwatch.model.Dimension; +import org.apache.commons.lang3.StringUtils; import org.apache.spark.SparkEnv; /** @@ -124,4 +124,4 @@ private static Dimension getEnvironmentVariableDimension(String envVarName, Stri private static Dimension getDefaultDimension(String dimensionName) { return getEnvironmentVariableDimension(dimensionName, dimensionName); } -} \ No newline at end of file +} diff --git 
a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java index f4d456899..a5ea190c5 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java @@ -46,7 +46,6 @@ import java.util.stream.Collectors; import java.util.stream.LongStream; import java.util.stream.Stream; -import org.apache.commons.lang.StringUtils; import org.apache.spark.metrics.sink.CloudWatchSink.DimensionNameGroups; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java index 1a3775f0b..29b5f6de9 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java @@ -5,9 +5,7 @@ package org.opensearch.flint.core; -import java.util.Map; - -import org.opensearch.flint.core.metadata.FlintMetadata; +import org.opensearch.flint.common.metadata.FlintMetadata; import org.opensearch.flint.core.storage.FlintWriter; /** @@ -32,31 +30,6 @@ public interface FlintClient { */ boolean exists(String indexName); - /** - * Retrieve all metadata for Flint index whose name matches the given pattern. - * - * @param indexNamePattern index name pattern - * @return map where the keys are the matched index names, and the values are - * corresponding index metadata - */ - Map getAllIndexMetadata(String... indexNamePattern); - - /** - * Retrieve metadata in a Flint index. - * - * @param indexName index name - * @return index metadata - */ - FlintMetadata getIndexMetadata(String indexName); - - /** - * Update a Flint index with the metadata given. - * - * @param indexName index name - * @param metadata index metadata - */ - void updateIndex(String indexName, FlintMetadata metadata); - /** * Delete a Flint index. 
* diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index 2678a8f67..6c3c02b9f 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -95,12 +95,17 @@ public class FlintOptions implements Serializable { public static final String DEFAULT_BATCH_BYTES = "1mb"; - public static final String CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS = "spark.datasource.flint.customFlintMetadataLogServiceClass"; + public static final String CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS = "customFlintMetadataLogServiceClass"; + + public static final String CUSTOM_FLINT_INDEX_METADATA_SERVICE_CLASS = "customFlintIndexMetadataServiceClass"; public static final String SUPPORT_SHARD = "read.support_shard"; public static final String DEFAULT_SUPPORT_SHARD = "true"; + public static final String BULK_REQUEST_RATE_LIMIT_PER_NODE = "bulkRequestRateLimitPerNode"; + public static final String DEFAULT_BULK_REQUEST_RATE_LIMIT_PER_NODE = "0"; + public FlintOptions(Map options) { this.options = options; this.retryOptions = new FlintRetryOptions(options); @@ -186,6 +191,10 @@ public String getCustomFlintMetadataLogServiceClass() { return options.getOrDefault(CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS, ""); } + public String getCustomFlintIndexMetadataServiceClass() { + return options.getOrDefault(CUSTOM_FLINT_INDEX_METADATA_SERVICE_CLASS, ""); + } + /** * FIXME, This is workaround for AWS OpenSearch Serverless (AOSS). AOSS does not support shard * operation, but shard info is exposed in index settings. Remove this setting when AOSS fix @@ -197,4 +206,8 @@ public boolean supportShard() { return options.getOrDefault(SUPPORT_SHARD, DEFAULT_SUPPORT_SHARD).equalsIgnoreCase( DEFAULT_SUPPORT_SHARD); } + + public long getBulkRequestRateLimitPerNode() { + return Long.parseLong(options.getOrDefault(BULK_REQUEST_RATE_LIMIT_PER_NODE, DEFAULT_BULK_REQUEST_RATE_LIMIT_PER_NODE)); + } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/MetaData.scala b/flint-core/src/main/scala/org/opensearch/flint/core/MetaData.scala index 98b7f8960..1ada5f2e3 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/MetaData.scala +++ b/flint-core/src/main/scala/org/opensearch/flint/core/MetaData.scala @@ -5,8 +5,6 @@ package org.opensearch.flint.core -import org.opensearch.flint.core.metadata.FlintMetadata - /** * OpenSearch Table metadata. 
* @@ -18,11 +16,3 @@ import org.opensearch.flint.core.metadata.FlintMetadata * setting */ case class MetaData(name: String, properties: String, setting: String) - -object MetaData { - def apply(name: String, flintMetadata: FlintMetadata): MetaData = { - val properties = flintMetadata.getContent - val setting = flintMetadata.indexSettings.getOrElse("") - MetaData(name, properties, setting) - } -} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintIndexMetadataServiceBuilder.java b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintIndexMetadataServiceBuilder.java new file mode 100644 index 000000000..d6f88135f --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintIndexMetadataServiceBuilder.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metadata; + +import java.lang.reflect.Constructor; +import org.opensearch.flint.common.metadata.FlintIndexMetadataService; +import org.opensearch.flint.core.FlintOptions; +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService; + +/** + * {@link FlintIndexMetadataService} builder. + *
+ * Custom implementations of {@link FlintIndexMetadataService} are expected to provide a public + * constructor with no arguments to be instantiated by this builder. + */ +public class FlintIndexMetadataServiceBuilder { + public static FlintIndexMetadataService build(FlintOptions options) { + String className = options.getCustomFlintIndexMetadataServiceClass(); + if (className.isEmpty()) { + return new FlintOpenSearchIndexMetadataService(options); + } + + // Attempts to instantiate Flint index metadata service using reflection + try { + Class flintIndexMetadataServiceClass = Class.forName(className); + Constructor constructor = flintIndexMetadataServiceClass.getConstructor(); + return (FlintIndexMetadataService) constructor.newInstance(); + } catch (Exception e) { + throw new RuntimeException("Failed to instantiate FlintIndexMetadataService: " + className, e); + } + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogServiceBuilder.java b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogServiceBuilder.java index 9ec4ac2c4..fc89eea9f 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogServiceBuilder.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogServiceBuilder.java @@ -6,7 +6,6 @@ package org.opensearch.flint.core.metadata.log; import java.lang.reflect.Constructor; -import org.apache.spark.SparkConf; import org.opensearch.flint.common.metadata.log.FlintMetadataLogService; import org.opensearch.flint.core.FlintOptions; import org.opensearch.flint.core.storage.FlintOpenSearchMetadataLogService; @@ -15,21 +14,20 @@ * {@link FlintMetadataLogService} builder. *
* Custom implementations of {@link FlintMetadataLogService} are expected to provide a public - * constructor with the signature {@code public MyCustomService(SparkConf sparkConf)} to be - * instantiated by this builder. + * constructor with no arguments to be instantiated by this builder. */ public class FlintMetadataLogServiceBuilder { - public static FlintMetadataLogService build(FlintOptions options, SparkConf sparkConf) { + public static FlintMetadataLogService build(FlintOptions options) { String className = options.getCustomFlintMetadataLogServiceClass(); if (className.isEmpty()) { return new FlintOpenSearchMetadataLogService(options); } - // Attempts to instantiate Flint metadata log service with sparkConf using reflection + // Attempts to instantiate Flint metadata log service using reflection try { Class flintMetadataLogServiceClass = Class.forName(className); - Constructor constructor = flintMetadataLogServiceClass.getConstructor(SparkConf.class); - return (FlintMetadataLogService) constructor.newInstance(sparkConf); + Constructor constructor = flintMetadataLogServiceClass.getConstructor(); + return (FlintMetadataLogService) constructor.newInstance(); } catch (Exception e) { throw new RuntimeException("Failed to instantiate FlintMetadataLogService: " + className, e); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiter.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiter.java new file mode 100644 index 000000000..af298cc8f --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiter.java @@ -0,0 +1,30 @@ +package org.opensearch.flint.core.storage; + +import dev.failsafe.RateLimiter; +import java.time.Duration; +import java.util.logging.Logger; +import org.opensearch.flint.core.FlintOptions; + +public class BulkRequestRateLimiter { + private static final Logger LOG = Logger.getLogger(BulkRequestRateLimiter.class.getName()); + private RateLimiter rateLimiter; + + public BulkRequestRateLimiter(FlintOptions flintOptions) { + long bulkRequestRateLimitPerNode = flintOptions.getBulkRequestRateLimitPerNode(); + if (bulkRequestRateLimitPerNode > 0) { + LOG.info("Setting rate limit for bulk request to " + bulkRequestRateLimitPerNode + "/sec"); + this.rateLimiter = RateLimiter.smoothBuilder( + flintOptions.getBulkRequestRateLimitPerNode(), + Duration.ofSeconds(1)).build(); + } else { + LOG.info("Rate limit for bulk request was not set."); + } + } + + // Wait so it won't exceed rate limit. Does nothing if rate limit is not set. + public void acquirePermit() throws InterruptedException { + if (rateLimiter != null) { + this.rateLimiter.acquirePermit(); + } + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolder.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolder.java new file mode 100644 index 000000000..0453c70c8 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolder.java @@ -0,0 +1,22 @@ +package org.opensearch.flint.core.storage; + +import org.opensearch.flint.core.FlintOptions; + +/** + * Hold shared instance of BulkRequestRateLimiter. This class is introduced to make + * BulkRequestRateLimiter testable and share single instance. 
+ */ +public class BulkRequestRateLimiterHolder { + + private static BulkRequestRateLimiter instance; + + private BulkRequestRateLimiterHolder() {} + + public synchronized static BulkRequestRateLimiter getBulkRequestRateLimiter( + FlintOptions flintOptions) { + if (instance == null) { + instance = new BulkRequestRateLimiter(flintOptions); + } + return instance; + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index 1a7c976c2..affcd0e36 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -9,25 +9,15 @@ import org.opensearch.client.RequestOptions; import org.opensearch.client.indices.CreateIndexRequest; import org.opensearch.client.indices.GetIndexRequest; -import org.opensearch.client.indices.GetIndexResponse; -import org.opensearch.client.indices.PutMappingRequest; -import org.opensearch.cluster.metadata.MappingMetadata; -import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.flint.common.metadata.FlintMetadata; import org.opensearch.flint.core.FlintClient; import org.opensearch.flint.core.FlintOptions; import org.opensearch.flint.core.IRestHighLevelClient; -import org.opensearch.flint.core.metadata.FlintMetadata; import scala.Option; import java.io.IOException; -import java.util.Arrays; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import java.util.Set; import java.util.logging.Logger; -import java.util.stream.Collectors; /** * Flint client implementation for OpenSearch storage. @@ -36,13 +26,6 @@ public class FlintOpenSearchClient implements FlintClient { private static final Logger LOG = Logger.getLogger(FlintOpenSearchClient.class.getName()); - /** - * Invalid index name characters to percent-encode, - * excluding '*' because it's reserved for pattern matching. - */ - private final static Set INVALID_INDEX_NAME_CHARS = - Set.of(' ', ',', ':', '"', '+', '/', '\\', '|', '?', '#', '>', '<'); - private final FlintOptions options; public FlintOpenSearchClient(FlintOptions options) { @@ -52,7 +35,7 @@ public FlintOpenSearchClient(FlintOptions options) { @Override public void createIndex(String indexName, FlintMetadata metadata) { LOG.info("Creating Flint index " + indexName + " with metadata " + metadata); - createIndex(indexName, metadata.getContent(), metadata.indexSettings()); + createIndex(indexName, FlintOpenSearchIndexMetadataService.serialize(metadata, false), metadata.indexSettings()); } protected void createIndex(String indexName, String mapping, Option settings) { @@ -81,58 +64,6 @@ public boolean exists(String indexName) { } } - @Override - public Map getAllIndexMetadata(String... 
indexNamePattern) { - LOG.info("Fetching all Flint index metadata for pattern " + String.join(",", indexNamePattern)); - String[] indexNames = - Arrays.stream(indexNamePattern).map(this::sanitizeIndexName).toArray(String[]::new); - try (IRestHighLevelClient client = createClient()) { - GetIndexRequest request = new GetIndexRequest(indexNames); - GetIndexResponse response = client.getIndex(request, RequestOptions.DEFAULT); - - return Arrays.stream(response.getIndices()) - .collect(Collectors.toMap( - index -> index, - index -> FlintMetadata.apply( - response.getMappings().get(index).source().toString(), - response.getSettings().get(index).toString() - ) - )); - } catch (Exception e) { - throw new IllegalStateException("Failed to get Flint index metadata for " + - String.join(",", indexNames), e); - } - } - - @Override - public FlintMetadata getIndexMetadata(String indexName) { - LOG.info("Fetching Flint index metadata for " + indexName); - String osIndexName = sanitizeIndexName(indexName); - try (IRestHighLevelClient client = createClient()) { - GetIndexRequest request = new GetIndexRequest(osIndexName); - GetIndexResponse response = client.getIndex(request, RequestOptions.DEFAULT); - - MappingMetadata mapping = response.getMappings().get(osIndexName); - Settings settings = response.getSettings().get(osIndexName); - return FlintMetadata.apply(mapping.source().string(), settings.toString()); - } catch (Exception e) { - throw new IllegalStateException("Failed to get Flint index metadata for " + osIndexName, e); - } - } - - @Override - public void updateIndex(String indexName, FlintMetadata metadata) { - LOG.info("Updating Flint index " + indexName + " with metadata " + metadata); - String osIndexName = sanitizeIndexName(indexName); - try (IRestHighLevelClient client = createClient()) { - PutMappingRequest request = new PutMappingRequest(osIndexName); - request.source(metadata.getContent(), XContentType.JSON); - client.updateIndexMapping(request, RequestOptions.DEFAULT); - } catch (Exception e) { - throw new IllegalStateException("Failed to update Flint index " + osIndexName, e); - } - } - @Override public void deleteIndex(String indexName) { LOG.info("Deleting Flint index " + indexName); @@ -157,40 +88,7 @@ public IRestHighLevelClient createClient() { return OpenSearchClientUtils.createClient(options); } - /* - * Because OpenSearch requires all lowercase letters in index name, we have to - * lowercase all letters in the given Flint index name. - */ - private String toLowercase(String indexName) { - Objects.requireNonNull(indexName); - - return indexName.toLowerCase(Locale.ROOT); - } - - /* - * Percent-encode invalid OpenSearch index name characters. - */ - private String percentEncode(String indexName) { - Objects.requireNonNull(indexName); - - StringBuilder builder = new StringBuilder(indexName.length()); - for (char ch : indexName.toCharArray()) { - if (INVALID_INDEX_NAME_CHARS.contains(ch)) { - builder.append(String.format("%%%02X", (int) ch)); - } else { - builder.append(ch); - } - } - return builder.toString(); - } - - /* - * Sanitize index name to comply with OpenSearch index name restrictions. 
- */ private String sanitizeIndexName(String indexName) { - Objects.requireNonNull(indexName); - - String encoded = percentEncode(indexName); - return toLowercase(encoded); + return OpenSearchClientUtils.sanitizeIndexName(indexName); } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchIndexMetadataService.scala b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchIndexMetadataService.scala new file mode 100644 index 000000000..fad2f1b63 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchIndexMetadataService.scala @@ -0,0 +1,203 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.storage + +import java.util + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.opensearch.client.RequestOptions +import org.opensearch.client.indices.{GetIndexRequest, GetIndexResponse, PutMappingRequest} +import org.opensearch.common.xcontent.XContentType +import org.opensearch.flint.common.FlintVersion +import org.opensearch.flint.common.metadata.{FlintIndexMetadataService, FlintMetadata} +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.core.IRestHighLevelClient +import org.opensearch.flint.core.metadata.FlintJsonHelper._ + +import org.apache.spark.internal.Logging + +class FlintOpenSearchIndexMetadataService(options: FlintOptions) + extends FlintIndexMetadataService + with Logging { + + override def getIndexMetadata(indexName: String): FlintMetadata = { + logInfo(s"Fetching Flint index metadata for $indexName") + val osIndexName = OpenSearchClientUtils.sanitizeIndexName(indexName) + var client: IRestHighLevelClient = null + try { + client = OpenSearchClientUtils.createClient(options) + val request = new GetIndexRequest(osIndexName) + val response = client.getIndex(request, RequestOptions.DEFAULT) + val mapping = response.getMappings.get(osIndexName) + val settings = response.getSettings.get(osIndexName) + FlintOpenSearchIndexMetadataService.deserialize(mapping.source.string, settings.toString) + } catch { + case e: Exception => + throw new IllegalStateException( + "Failed to get Flint index metadata for " + osIndexName, + e) + } finally + if (client != null) { + client.close() + } + } + + override def getAllIndexMetadata(indexNamePattern: String*): util.Map[String, FlintMetadata] = { + logInfo(s"Fetching all Flint index metadata for pattern ${indexNamePattern.mkString(",")}"); + val indexNames = indexNamePattern.map(OpenSearchClientUtils.sanitizeIndexName) + var client: IRestHighLevelClient = null + try { + client = OpenSearchClientUtils.createClient(options) + val request = new GetIndexRequest(indexNames: _*) + val response: GetIndexResponse = client.getIndex(request, RequestOptions.DEFAULT) + + response.getIndices + .map(index => + index -> FlintOpenSearchIndexMetadataService.deserialize( + response.getMappings.get(index).source().string(), + response.getSettings.get(index).toString)) + .toMap + .asJava + } catch { + case e: Exception => + throw new IllegalStateException( + s"Failed to get Flint index metadata for ${indexNames.mkString(",")}", + e) + } finally + if (client != null) { + client.close() + } + } + + override def updateIndexMetadata(indexName: String, metadata: FlintMetadata): Unit = { + logInfo(s"Updating Flint index $indexName with metadata $metadata"); + val osIndexName = OpenSearchClientUtils.sanitizeIndexName(indexName) + var client: IRestHighLevelClient = null 
+ try { + client = OpenSearchClientUtils.createClient(options) + val request = new PutMappingRequest(osIndexName) + request.source(FlintOpenSearchIndexMetadataService.serialize(metadata), XContentType.JSON) + client.updateIndexMapping(request, RequestOptions.DEFAULT) + } catch { + case e: Exception => + throw new IllegalStateException(s"Failed to update Flint index $osIndexName", e) + } finally + if (client != null) { + client.close() + } + } + + // Do nothing. For OpenSearch, deleting the index will also delete its metadata + override def deleteIndexMetadata(indexName: String): Unit = {} +} + +object FlintOpenSearchIndexMetadataService { + + def serialize(metadata: FlintMetadata): String = { + serialize(metadata, true) + } + + /** + * Generate JSON content as index metadata. + * + * @param metadata + * Flint index metadata + * @param includeSpec + * Whether to include _meta field in the JSON content for Flint index specification + * @return + * JSON content + */ + def serialize(metadata: FlintMetadata, includeSpec: Boolean): String = { + try { + buildJson(builder => { + if (includeSpec) { + // Add _meta field + objectField(builder, "_meta") { + builder + .field("version", metadata.version.version) + .field("name", metadata.name) + .field("kind", metadata.kind) + .field("source", metadata.source) + .field("indexedColumns", metadata.indexedColumns) + + if (metadata.latestId.isDefined) { + builder.field("latestId", metadata.latestId.get) + } + optionalObjectField(builder, "options", metadata.options) + optionalObjectField(builder, "properties", metadata.properties) + } + } + + // Add properties (schema) field + builder.field("properties", metadata.schema) + }) + } catch { + case e: Exception => + throw new IllegalStateException("Failed to jsonify Flint metadata", e) + } + } + + /** + * Construct Flint metadata with JSON content and index settings. + * + * @param content + * JSON content + * @param settings + * index settings + * @return + * Flint metadata + */ + def deserialize(content: String, settings: String): FlintMetadata = { + val metadata = deserialize(content) + metadata.copy(indexSettings = Option(settings)) + } + + /** + * Parse the given JSON content and construct Flint metadata class. + * + * @param content + * JSON content + * @return + * Flint metadata + */ + def deserialize(content: String): FlintMetadata = { + try { + val builder = new FlintMetadata.Builder() + parseJson(content) { (parser, fieldName) => + { + fieldName match { + case "_meta" => + parseObjectField(parser) { (parser, innerFieldName) => + { + innerFieldName match { + case "version" => builder.version(FlintVersion.apply(parser.text())) + case "name" => builder.name(parser.text()) + case "kind" => builder.kind(parser.text()) + case "source" => builder.source(parser.text()) + case "indexedColumns" => + parseArrayField(parser) { + builder.addIndexedColumn(parser.map()) + } + case "options" => builder.options(parser.map()) + case "properties" => builder.properties(parser.map()) + case _ => // Handle other fields as needed + } + } + } + case "properties" => + builder.schema(parser.map()) + case _ => // Ignore other fields, for instance, dynamic. 
+          }
+        }
+      }
+      builder.build()
+    } catch {
+      case e: Exception =>
+        throw new IllegalStateException("Failed to parse metadata JSON", e)
+    }
+  }
+}
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java
index 7944de5ae..8c327b664 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java
@@ -40,7 +40,7 @@
  * - entryVersion:
  *   - seqNo (Long): OpenSearch sequence number
  *   - primaryTerm (Long): OpenSearch primary term
- * - storageContext:
+ * - properties:
  *   - dataSourceName (String): OpenSearch data source associated
  */
 public class FlintOpenSearchMetadataLog implements FlintMetadataLog<FlintMetadataLogEntry> {
@@ -67,7 +67,8 @@ public FlintOpenSearchMetadataLog(FlintOptions options, String flintIndexName, String metadataLogIndexName) {
     this.options = options;
     this.metadataLogIndexName = metadataLogIndexName;
     this.dataSourceName = options.getDataSourceName();
-    this.latestId = Base64.getEncoder().encodeToString(flintIndexName.getBytes());
+    String osIndexName = OpenSearchClientUtils.sanitizeIndexName(flintIndexName);
+    this.latestId = Base64.getEncoder().encodeToString(osIndexName.getBytes());
   }
 
   @Override
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java
index 21241d7ab..0f80d07c9 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java
@@ -8,6 +8,9 @@
 import com.amazonaws.auth.AWSCredentialsProvider;
 import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
 import java.lang.reflect.Constructor;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicReference;
 import org.apache.http.HttpHost;
 import org.apache.http.auth.AuthScope;
@@ -35,6 +38,13 @@ public class OpenSearchClientUtils {
    */
   public final static String META_LOG_NAME_PREFIX = ".query_execution_request";
 
+  /**
+   * Invalid index name characters to percent-encode,
+   * excluding '*' because it's reserved for pattern matching.
+   */
+  private final static Set<Character> INVALID_INDEX_NAME_CHARS =
+      Set.of(' ', ',', ':', '"', '+', '/', '\\', '|', '?', '#', '>', '<');
+
   /**
    * Used in IT.
    */
@@ -58,7 +68,45 @@ public static RestHighLevelClient createRestHighLevelClient(FlintOptions options) {
   }
 
   public static IRestHighLevelClient createClient(FlintOptions options) {
-    return new RestHighLevelClientWrapper(createRestHighLevelClient(options));
+    return new RestHighLevelClientWrapper(createRestHighLevelClient(options),
+        BulkRequestRateLimiterHolder.getBulkRequestRateLimiter(options));
+  }
+
+  /**
+   * Sanitize index name to comply with OpenSearch index name restrictions.
+   */
+  public static String sanitizeIndexName(String indexName) {
+    Objects.requireNonNull(indexName);
+
+    String encoded = percentEncode(indexName);
+    return toLowercase(encoded);
+  }
+
+  /**
+   * Because OpenSearch requires all lowercase letters in index names, we have to
+   * lowercase all letters in the given Flint index name.
+   */
+  private static String toLowercase(String indexName) {
+    Objects.requireNonNull(indexName);
+
+    return indexName.toLowerCase(Locale.ROOT);
+  }
+
+  /**
+   * Percent-encode invalid OpenSearch index name characters.
+   */
+  private static String percentEncode(String indexName) {
+    Objects.requireNonNull(indexName);
+
+    StringBuilder builder = new StringBuilder(indexName.length());
+    for (char ch : indexName.toCharArray()) {
+      if (INVALID_INDEX_NAME_CHARS.contains(ch)) {
+        builder.append(String.format("%%%02X", (int) ch));
+      } else {
+        builder.append(ch);
+      }
+    }
+    return builder.toString();
   }
 
   private static RestClientBuilder configureSigV4Auth(RestClientBuilder restClientBuilder, FlintOptions options) {
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/table/OpenSearchCluster.java b/flint-core/src/main/scala/org/opensearch/flint/core/table/OpenSearchCluster.java
new file mode 100644
index 000000000..2736177cc
--- /dev/null
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/table/OpenSearchCluster.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.core.table;
+
+import org.opensearch.client.RequestOptions;
+import org.opensearch.client.indices.GetIndexRequest;
+import org.opensearch.client.indices.GetIndexResponse;
+import org.opensearch.flint.core.FlintOptions;
+import org.opensearch.flint.core.IRestHighLevelClient;
+import org.opensearch.flint.core.MetaData;
+import org.opensearch.flint.core.storage.OpenSearchClientUtils;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+public class OpenSearchCluster {
+
+  private static final Logger LOG = Logger.getLogger(OpenSearchCluster.class.getName());
+
+  /**
+   * Creates a list of OpenSearchIndexTable instances for indices in the OpenSearch domain.
+   *
+   * @param indexName
+   *   supports (1) a single index name, (2) a wildcard index name, (3) comma-separated index names.
+   * @param options
+   *   The options for Flint.
+   * @return
+   *   A list of OpenSearchIndexTable instances.
+   */
+  public static List<OpenSearchIndexTable> apply(String indexName, FlintOptions options) {
+    return getAllOpenSearchTableMetadata(options, indexName.split(","))
+        .stream()
+        .map(metadata -> new OpenSearchIndexTable(metadata, options))
+        .collect(Collectors.toList());
+  }
+
+  /**
+   * Retrieve all metadata for OpenSearch tables whose names match the given pattern.
+   *
+   * @param options The options for Flint.
+   * @param indexNamePattern index name pattern
+   * @return list of OpenSearch table metadata
+   */
+  public static List<MetaData> getAllOpenSearchTableMetadata(FlintOptions options, String...
indexNamePattern) { + LOG.info("Fetching all OpenSearch table metadata for pattern " + String.join(",", indexNamePattern)); + String[] indexNames = + Arrays.stream(indexNamePattern).map(OpenSearchClientUtils::sanitizeIndexName).toArray(String[]::new); + try (IRestHighLevelClient client = OpenSearchClientUtils.createClient(options)) { + GetIndexRequest request = new GetIndexRequest(indexNames); + GetIndexResponse response = client.getIndex(request, RequestOptions.DEFAULT); + + return Arrays.stream(response.getIndices()) + .map(index -> new MetaData( + index, + response.getMappings().get(index).source().string(), + response.getSettings().get(index).toString())) + .collect(Collectors.toList()); + } catch (Exception e) { + throw new IllegalStateException("Failed to get OpenSearch table metadata for " + + String.join(",", indexNames), e); + } + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/table/OpenSearchIndexTable.scala b/flint-core/src/main/scala/org/opensearch/flint/core/table/OpenSearchIndexTable.scala index 783163687..57c770eb8 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/table/OpenSearchIndexTable.scala +++ b/flint-core/src/main/scala/org/opensearch/flint/core/table/OpenSearchIndexTable.scala @@ -5,8 +5,6 @@ package org.opensearch.flint.core.table -import scala.collection.JavaConverters._ - import org.json4s.{Formats, NoTypeHints} import org.json4s.JsonAST.JString import org.json4s.jackson.JsonMethods @@ -146,28 +144,3 @@ object OpenSearchIndexTable { */ val maxSplitSizeBytes = 10 * 1024 * 1024 } - -object OpenSearchCluster { - - /** - * Creates list of OpenSearchIndexTable instance of indices in OpenSearch domain. - * - * @param indexName - * tableName support (1) single index name. (2) wildcard index name. (3) comma sep index name. - * @param options - * The options for Flint. - * @return - * An list of OpenSearchIndexTable instance. 
- */ - def apply(indexName: String, options: FlintOptions): Seq[OpenSearchIndexTable] = { - val client = FlintClientBuilder.build(options) - client - .getAllIndexMetadata(indexName.split(","): _*) - .asScala - .toMap - .map(entry => { - new OpenSearchIndexTable(MetaData.apply(entry._1, entry._2), options) - }) - .toSeq - } -} diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolderTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolderTest.java new file mode 100644 index 000000000..f2f160973 --- /dev/null +++ b/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolderTest.java @@ -0,0 +1,20 @@ +package org.opensearch.flint.core.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; +import org.opensearch.flint.core.FlintOptions; + +class BulkRequestRateLimiterHolderTest { + FlintOptions flintOptions = new FlintOptions(ImmutableMap.of()); + @Test + public void getBulkRequestRateLimiter() { + BulkRequestRateLimiter instance0 = BulkRequestRateLimiterHolder.getBulkRequestRateLimiter(flintOptions); + BulkRequestRateLimiter instance1 = BulkRequestRateLimiterHolder.getBulkRequestRateLimiter(flintOptions); + + assertNotNull(instance0); + assertEquals(instance0, instance1); + } +} \ No newline at end of file diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterTest.java new file mode 100644 index 000000000..d86f06d24 --- /dev/null +++ b/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterTest.java @@ -0,0 +1,46 @@ +package org.opensearch.flint.core.storage; + + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; +import org.opensearch.flint.core.FlintOptions; + +class BulkRequestRateLimiterTest { + FlintOptions flintOptionsWithRateLimit = new FlintOptions(ImmutableMap.of(FlintOptions.BULK_REQUEST_RATE_LIMIT_PER_NODE, "1")); + FlintOptions flintOptionsWithoutRateLimit = new FlintOptions(ImmutableMap.of(FlintOptions.BULK_REQUEST_RATE_LIMIT_PER_NODE, "0")); + + @Test + void acquirePermitWithRateConfig() throws Exception { + BulkRequestRateLimiter limiter = new BulkRequestRateLimiter(flintOptionsWithRateLimit); + + assertTrue(timer(() -> { + limiter.acquirePermit(); + limiter.acquirePermit(); + }) >= 1000); + } + + @Test + void acquirePermitWithoutRateConfig() throws Exception { + BulkRequestRateLimiter limiter = new BulkRequestRateLimiter(flintOptionsWithoutRateLimit); + + assertTrue(timer(() -> { + limiter.acquirePermit(); + limiter.acquirePermit(); + }) < 100); + } + + private interface Procedure { + void run() throws Exception; + } + + private long timer(Procedure procedure) throws Exception { + long start = System.currentTimeMillis(); + procedure.run(); + long end = System.currentTimeMillis(); + return end - start; + } +} diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/metadata/FlintMetadataSuite.scala b/flint-core/src/test/scala/org/opensearch/flint/core/storage/FlintOpenSearchIndexMetadataServiceSuite.scala similarity 
index 65% rename from flint-core/src/test/scala/org/opensearch/flint/core/metadata/FlintMetadataSuite.scala rename to flint-core/src/test/scala/org/opensearch/flint/core/storage/FlintOpenSearchIndexMetadataServiceSuite.scala index 01b4e266c..f1ed09531 100644 --- a/flint-core/src/test/scala/org/opensearch/flint/core/metadata/FlintMetadataSuite.scala +++ b/flint-core/src/test/scala/org/opensearch/flint/core/storage/FlintOpenSearchIndexMetadataServiceSuite.scala @@ -3,16 +3,17 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.core.metadata +package org.opensearch.flint.core.storage import scala.collection.JavaConverters.mapAsJavaMapConverter import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson -import org.opensearch.flint.core.FlintVersion.current +import org.opensearch.flint.common.FlintVersion.current +import org.opensearch.flint.common.metadata.FlintMetadata import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -class FlintMetadataSuite extends AnyFlatSpec with Matchers { +class FlintOpenSearchIndexMetadataServiceSuite extends AnyFlatSpec with Matchers { /** Test Flint index meta JSON string */ val testMetadataJson: String = s""" @@ -60,6 +61,16 @@ class FlintMetadataSuite extends AnyFlatSpec with Matchers { | } |""".stripMargin + val testNoSpec: String = s""" + | { + | "properties": { + | "test_field": { + | "type": "os_type" + | } + | } + | } + |""".stripMargin + val testIndexSettingsJson: String = """ | { "number_of_shards": 3 } @@ -67,7 +78,8 @@ class FlintMetadataSuite extends AnyFlatSpec with Matchers { "constructor" should "deserialize the given JSON and assign parsed value to field" in { Seq(testMetadataJson, testDynamic).foreach(mapping => { - val metadata = FlintMetadata(mapping, testIndexSettingsJson) + val metadata = + FlintOpenSearchIndexMetadataService.deserialize(mapping, testIndexSettingsJson) metadata.version shouldBe current() metadata.name shouldBe "test_index" metadata.kind shouldBe "test_kind" @@ -77,15 +89,27 @@ class FlintMetadataSuite extends AnyFlatSpec with Matchers { }) } - "getContent" should "serialize all fields to JSON" in { + "serialize" should "serialize all fields to JSON" in { + val builder = new FlintMetadata.Builder + builder.name("test_index") + builder.kind("test_kind") + builder.source("test_source_table") + builder.addIndexedColumn(Map[String, AnyRef]("test_field" -> "spark_type").asJava) + builder.schema(Map[String, AnyRef]("test_field" -> Map("type" -> "os_type").asJava).asJava) + + val metadata = builder.build() + FlintOpenSearchIndexMetadataService.serialize(metadata) should matchJson(testMetadataJson) + } + + "serialize without spec" should "serialize all fields to JSON without adding _meta field" in { val builder = new FlintMetadata.Builder builder.name("test_index") builder.kind("test_kind") builder.source("test_source_table") - builder.addIndexedColumn(Map[String, AnyRef]("test_field" -> "spark_type").asJava); - builder.schema("""{"properties": {"test_field": {"type": "os_type"}}}""") + builder.addIndexedColumn(Map[String, AnyRef]("test_field" -> "spark_type").asJava) + builder.schema(Map[String, AnyRef]("test_field" -> Map("type" -> "os_type").asJava).asJava) val metadata = builder.build() - metadata.getContent should matchJson(testMetadataJson) + FlintOpenSearchIndexMetadataService.serialize(metadata, false) should matchJson(testNoSpec) } } diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/storage/OpenSearchClientUtilsSuite.scala 
b/flint-core/src/test/scala/org/opensearch/flint/core/storage/OpenSearchClientUtilsSuite.scala new file mode 100644 index 000000000..abcf9edf8 --- /dev/null +++ b/flint-core/src/test/scala/org/opensearch/flint/core/storage/OpenSearchClientUtilsSuite.scala @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.storage + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class OpenSearchClientUtilsSuite extends AnyFlatSpec with Matchers { + + "sanitizeIndexName" should "percent-encode invalid OpenSearch index name characters and lowercase all characters" in { + val indexName = "TEST :\"+/\\|?#><" + val sanitizedIndexName = OpenSearchClientUtils.sanitizeIndexName(indexName) + sanitizedIndexName shouldBe "test%20%3a%22%2b%2f%5c%7c%3f%23%3e%3c" + } +} diff --git a/flint-spark-integration/lib/LogsConnectorSpark-1.0.jar b/flint-spark-integration/lib/LogsConnectorSpark-1.0.jar deleted file mode 100644 index 0aa50bbb2..000000000 Binary files a/flint-spark-integration/lib/LogsConnectorSpark-1.0.jar and /dev/null differ diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintReadOnlyTable.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintReadOnlyTable.scala index ed6902841..e9f6f5ea1 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintReadOnlyTable.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintReadOnlyTable.scala @@ -7,6 +7,8 @@ package org.apache.spark.sql.flint import java.util +import scala.collection.JavaConverters._ + import org.opensearch.flint.core.table.OpenSearchCluster import org.apache.spark.sql.SparkSession @@ -39,7 +41,7 @@ class FlintReadOnlyTable( lazy val name: String = flintSparkConf.tableName() lazy val tables: Seq[org.opensearch.flint.core.Table] = - OpenSearchCluster.apply(name, flintSparkConf.flintOptions()) + OpenSearchCluster.apply(name, flintSparkConf.flintOptions()).asScala lazy val resolvedTablesSchema: StructType = tables.headOption .map(tbl => FlintDataType.deserialize(tbl.schema().asJson())) diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala index 201e7c748..9b9f70be0 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala @@ -5,7 +5,8 @@ package org.apache.spark.sql.flint -import org.apache.spark.internal.Logging +import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain + import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan} import org.apache.spark.sql.flint.config.FlintSparkConf @@ -17,8 +18,7 @@ case class FlintScan( options: FlintSparkConf, pushedPredicates: Array[Predicate]) extends Scan - with Batch - with Logging { + with Batch { override def readSchema(): StructType = schema @@ -44,10 +44,13 @@ case class FlintScan( * Print pushedPredicates when explain(mode="extended"). Learn from SPARK JDBCScan. 
*/ override def description(): String = { - super.description() + ", PushedPredicates: " + seqToString(pushedPredicates) + super.description() + ", PushedPredicates: " + pushedPredicates + .map { + case p if p.name().equalsIgnoreCase(BloomFilterMightContain.NAME) => p.name() + case p => p.toString() + } + .mkString("[", ", ", "]") } - - private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]") } /** diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala index 0c6f7d700..82a570b2f 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala @@ -5,6 +5,8 @@ package org.apache.spark.sql.flint +import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain + import org.apache.spark.internal.Logging import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownV2Filters} @@ -34,4 +36,5 @@ case class FlintScanBuilder( } override def pushedPredicates(): Array[Predicate] = pushedPredicate + .filterNot(_.name().equalsIgnoreCase(BloomFilterMightContain.NAME)) } diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala index 25b4db940..8691de3d0 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala @@ -48,4 +48,6 @@ case class FlintWrite( override def toBatch: BatchWrite = this override def toStreaming: StreamingWrite = this + + override def useCommitCoordinator(): Boolean = false } diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index dc110afb9..1d12d004e 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -136,6 +136,12 @@ object FlintSparkConf { .doc("max retries on failed HTTP request, 0 means retry is disabled, default is 3") .createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_MAX_RETRIES)) + val BULK_REQUEST_RATE_LIMIT_PER_NODE = + FlintConfig(s"spark.datasource.flint.${FlintOptions.BULK_REQUEST_RATE_LIMIT_PER_NODE}") + .datasourceOption() + .doc("[Experimental] Rate limit (requests/sec) for bulk request per worker node. 
Rate won't be limited by default") + .createWithDefault(FlintOptions.DEFAULT_BULK_REQUEST_RATE_LIMIT_PER_NODE) + val RETRYABLE_HTTP_STATUS_CODES = FlintConfig(s"spark.datasource.flint.${FlintRetryOptions.RETRYABLE_HTTP_STATUS_CODES}") .datasourceOption() @@ -187,10 +193,16 @@ object FlintSparkConf { .doc("data source name") .createOptional() val CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS = - FlintConfig(FlintOptions.CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS) + FlintConfig(s"spark.datasource.flint.${FlintOptions.CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS}") .datasourceOption() .doc("custom Flint metadata log service class") .createOptional() + val CUSTOM_FLINT_INDEX_METADATA_SERVICE_CLASS = + FlintConfig( + s"spark.datasource.flint.${FlintOptions.CUSTOM_FLINT_INDEX_METADATA_SERVICE_CLASS}") + .datasourceOption() + .doc("custom Flint index metadata service class") + .createOptional() val QUERY = FlintConfig("spark.flint.job.query") .doc("Flint query for batch and streaming job") @@ -275,6 +287,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable AUTH, MAX_RETRIES, RETRYABLE_HTTP_STATUS_CODES, + BULK_REQUEST_RATE_LIMIT_PER_NODE, REGION, CUSTOM_AWS_CREDENTIALS_PROVIDER, SERVICE_NAME, @@ -291,6 +304,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable RETRYABLE_EXCEPTION_CLASS_NAMES, DATA_SOURCE_NAME, CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS, + CUSTOM_FLINT_INDEX_METADATA_SERVICE_CLASS, SESSION_ID, REQUEST_INDEX, METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/json/FlintJacksonParser.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/json/FlintJacksonParser.scala index 31db1909f..a9e9122fb 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/json/FlintJacksonParser.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/json/FlintJacksonParser.scala @@ -118,7 +118,8 @@ class FlintJacksonParser( array.toArray[InternalRow](schema) } case START_ARRAY => - throw QueryExecutionErrors.cannotParseJsonArraysAsStructsError() + throw QueryExecutionErrors.cannotParseJsonArraysAsStructsError( + parser.currentToken().asString()) } } @@ -420,17 +421,17 @@ class FlintJacksonParser( case VALUE_STRING if parser.getTextLength < 1 && allowEmptyString => dataType match { case FloatType | DoubleType | TimestampType | DateType => - throw QueryExecutionErrors.failToParseEmptyStringForDataTypeError(dataType) + throw QueryExecutionErrors.emptyJsonFieldValueError(dataType) case _ => null } case VALUE_STRING if parser.getTextLength < 1 => - throw QueryExecutionErrors.failToParseEmptyStringForDataTypeError(dataType) + throw QueryExecutionErrors.emptyJsonFieldValueError(dataType) case token => // We cannot parse this token based on the given data type. So, we throw a // RuntimeException and this exception will be caught by `parse` method. - throw QueryExecutionErrors.failToParseValueForDataTypeError(parser, token, dataType) + throw QueryExecutionErrors.cannotParseJSONFieldError(parser, token, dataType) } /** @@ -537,7 +538,7 @@ class FlintJacksonParser( // JSON parser currently doesn't support partial results for corrupted records. // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. 
- throw BadRecordException(() => recordLiteral(record), () => None, e) + throw BadRecordException(() => recordLiteral(record), cause = e) case e: CharConversionException if options.encoding.isEmpty => val msg = """JSON parser cannot handle a character in its input. @@ -545,11 +546,11 @@ class FlintJacksonParser( |""".stripMargin + e.getMessage val wrappedCharException = new CharConversionException(msg) wrappedCharException.initCause(e) - throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException) + throw BadRecordException(() => recordLiteral(record), cause = wrappedCharException) case PartialResultException(row, cause) => throw BadRecordException( record = () => recordLiteral(record), - partialResult = () => Some(row), + partialResults = () => Array(row), cause) } } diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala index 06c92882b..d71bc5d12 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala @@ -9,6 +9,7 @@ import java.util.concurrent.ScheduledExecutorService import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog._ +import org.apache.spark.sql.internal.SQLConf.DEFAULT_CATALOG import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.{ShutdownHookManager, ThreadUtils} @@ -72,14 +73,8 @@ package object flint { def qualifyTableName(spark: SparkSession, tableName: String): String = { val (catalog, ident) = parseTableName(spark, tableName) - // Tricky that our Flint delegate catalog's name has to be spark_catalog - // so we have to find its actual name in CatalogManager - val catalogMgr = spark.sessionState.catalogManager - val catalogName = - catalogMgr - .listCatalogs(Some("*")) - .find(catalogMgr.catalog(_) == catalog) - .getOrElse(catalog.name()) + // more reading at https://github.com/opensearch-project/opensearch-spark/issues/319. + val catalogName = resolveCatalogName(spark, catalog) s"$catalogName.${ident.namespace.mkString(".")}.${ident.name}" } @@ -134,4 +129,41 @@ package object flint { def findField(rootField: StructType, fieldName: String): Option[StructField] = { rootField.findNestedField(fieldName.split('.')).map(_._2) } + + /** + * Resolve catalog name. spark.sql.defaultCatalog name is returned if catalog.name is + * spark_catalog otherwise, catalog.name is returned. + * @see + * issue319 + * + * @param spark + * Spark Session + * @param catalog + * Spark Catalog + * @return + * catalog name. + */ + def resolveCatalogName(spark: SparkSession, catalog: CatalogPlugin): String = { + + /** + * Check if the provided catalog is a session catalog. + */ + if (CatalogV2Util.isSessionCatalog(catalog)) { + val defaultCatalog = spark.conf.get(DEFAULT_CATALOG) + if (spark.sessionState.catalogManager.isCatalogRegistered(defaultCatalog)) { + defaultCatalog + } else { + + /** + * It may happen when spark.sql.defaultCatalog is configured, but there's no + * implementation. 
For instance, spark.sql.defaultCatalog = "unknown" + */ + throw new IllegalStateException(s"Unknown catalog name: $defaultCatalog") + } + } else { + // Return the name for non-session catalogs + catalog.name() + } + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index fe2f68333..3eb36010e 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -9,12 +9,13 @@ import scala.collection.JavaConverters._ import org.json4s.{Formats, NoTypeHints} import org.json4s.native.Serialization +import org.opensearch.flint.common.metadata.{FlintIndexMetadataService, FlintMetadata} import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState._ import org.opensearch.flint.common.metadata.log.FlintMetadataLogService import org.opensearch.flint.common.metadata.log.OptimisticTransaction import org.opensearch.flint.common.metadata.log.OptimisticTransaction.NO_LOG_ENTRY import org.opensearch.flint.core.{FlintClient, FlintClientBuilder} -import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.core.metadata.FlintIndexMetadataServiceBuilder import org.opensearch.flint.core.metadata.log.FlintMetadataLogServiceBuilder import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName._ @@ -47,10 +48,12 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w /** Flint client for low-level index operation */ private val flintClient: FlintClient = FlintClientBuilder.build(flintSparkConf.flintOptions()) + private val flintIndexMetadataService: FlintIndexMetadataService = { + FlintIndexMetadataServiceBuilder.build(flintSparkConf.flintOptions()) + } + override protected val flintMetadataLogService: FlintMetadataLogService = { - FlintMetadataLogServiceBuilder.build( - flintSparkConf.flintOptions(), - spark.sparkContext.getConf) + FlintMetadataLogServiceBuilder.build(flintSparkConf.flintOptions()) } /** Required by json4s parse function */ @@ -58,7 +61,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w /** Flint Spark index monitor */ val flintIndexMonitor: FlintSparkIndexMonitor = - new FlintSparkIndexMonitor(spark, flintMetadataLogService) + new FlintSparkIndexMonitor(spark, flintClient, flintMetadataLogService) /** * Create index builder for creating index with fluent API. 
@@ -114,9 +117,12 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w .commit(latest => if (latest == null) { // in case transaction capability is disabled flintClient.createIndex(indexName, metadata) + flintIndexMetadataService.updateIndexMetadata(indexName, metadata) } else { logInfo(s"Creating index with metadata log entry ID ${latest.id}") flintClient.createIndex(indexName, metadata.copy(latestId = Some(latest.id))) + flintIndexMetadataService + .updateIndexMetadata(indexName, metadata.copy(latestId = Some(latest.id))) }) } } @@ -163,7 +169,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w def describeIndexes(indexNamePattern: String): Seq[FlintSparkIndex] = { logInfo(s"Describing indexes with pattern $indexNamePattern") if (flintClient.exists(indexNamePattern)) { - flintClient + flintIndexMetadataService .getAllIndexMetadata(indexNamePattern) .asScala .map { case (indexName, metadata) => @@ -187,7 +193,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w def describeIndex(indexName: String): Option[FlintSparkIndex] = { logInfo(s"Describing index name $indexName") if (flintClient.exists(indexName)) { - val metadata = flintClient.getIndexMetadata(indexName) + val metadata = flintIndexMetadataService.getIndexMetadata(indexName) val metadataWithEntry = attachLatestLogEntry(indexName, metadata) FlintSparkIndexFactory.create(metadataWithEntry) } else { @@ -267,6 +273,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w .finalLog(_ => NO_LOG_ENTRY) .commit(_ => { flintClient.deleteIndex(indexName) + flintIndexMetadataService.deleteIndexMetadata(indexName) true }) } else { @@ -428,7 +435,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w .transientLog(latest => latest.copy(state = UPDATING)) .finalLog(latest => latest.copy(state = ACTIVE)) .commit(_ => { - flintClient.updateIndex(indexName, index.metadata) + flintIndexMetadataService.updateIndexMetadata(indexName, index.metadata) logInfo("Update index options complete") flintIndexMonitor.stopMonitor(indexName) stopRefreshingJob(indexName) @@ -453,7 +460,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w latest.copy(state = REFRESHING) }) .commit(_ => { - flintClient.updateIndex(indexName, index.metadata) + flintIndexMetadataService.updateIndexMetadata(indexName, index.metadata) logInfo("Update index options complete") indexRefresh.start(spark, flintSparkConf) }) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index 34c2ae452..44ea5188f 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -5,10 +5,11 @@ package org.opensearch.flint.spark -import scala.collection.JavaConverters.mapAsJavaMapConverter +import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConverter} +import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry -import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.core.metadata.FlintJsonHelper._ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.flint.datatype.FlintDataType @@ -176,4 +177,18 @@ 
object FlintSparkIndex { val structType = StructType.fromDDL(catalogDDL) FlintDataType.serialize(structType) } + + def generateSchema(allFieldTypes: Map[String, String]): Map[String, AnyRef] = { + val schemaJson = generateSchemaJSON(allFieldTypes) + var schemaMap: Map[String, AnyRef] = Map.empty + + parseJson(schemaJson) { (parser, fieldName) => + fieldName match { + case "properties" => schemaMap = parser.map().asScala.toMap + case _ => // do nothing + } + } + + schemaMap + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala index aa3c23360..6c34e00e1 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala @@ -9,7 +9,7 @@ import java.util.Collections import scala.collection.JavaConverters.mapAsScalaMapConverter -import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE import org.opensearch.flint.spark.mv.FlintSparkMaterializedView diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala index 29acaea6b..2eb99ef34 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala @@ -18,6 +18,7 @@ import dev.failsafe.event.ExecutionAttemptedEvent import dev.failsafe.function.CheckedRunnable import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState.{FAILED, REFRESHING} import org.opensearch.flint.common.metadata.log.FlintMetadataLogService +import org.opensearch.flint.core.FlintClient import org.opensearch.flint.core.logging.ExceptionMessages.extractRootCause import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} @@ -31,11 +32,14 @@ import org.apache.spark.sql.flint.newDaemonThreadPoolScheduledExecutor * * @param spark * Spark session + * @param flintClient + * Flint client * @param flintMetadataLogService * Flint metadata log service */ class FlintSparkIndexMonitor( spark: SparkSession, + flintClient: FlintClient, flintMetadataLogService: FlintMetadataLogService) extends Logging { @@ -158,6 +162,11 @@ class FlintSparkIndexMonitor( if (isStreamingJobActive(indexName)) { logInfo("Streaming job is still active") flintMetadataLogService.recordHeartbeat(indexName) + + if (!flintClient.exists(indexName)) { + logWarning("Streaming job is active but data is deleted") + stopStreamingJobAndMonitor(indexName) + } } else { logError("Streaming job is not active. 
Cancelling monitor task") stopMonitor(indexName) @@ -172,10 +181,7 @@ class FlintSparkIndexMonitor( // Stop streaming job and its monitor if max retry limit reached if (errorCnt >= MAX_ERROR_COUNT) { - logInfo(s"Terminating streaming job and index monitor for $indexName") - stopStreamingJob(indexName) - stopMonitor(indexName) - logInfo(s"Streaming job and index monitor terminated") + stopStreamingJobAndMonitor(indexName) } } } @@ -184,6 +190,13 @@ class FlintSparkIndexMonitor( private def isStreamingJobActive(indexName: String): Boolean = spark.streams.active.exists(_.name == indexName) + private def stopStreamingJobAndMonitor(indexName: String): Unit = { + logInfo(s"Terminating streaming job and index monitor for $indexName") + stopStreamingJob(indexName) + stopMonitor(indexName) + logInfo(s"Streaming job and index monitor terminated") + } + private def stopStreamingJob(indexName: String): Unit = { val job = spark.streams.active.find(_.name == indexName) if (job.isDefined) { diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala index 0fade2ee7..8748bf874 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala @@ -7,10 +7,10 @@ package org.opensearch.flint.spark.covering import scala.collection.JavaConverters.mapAsJavaMapConverter +import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry -import org.opensearch.flint.core.metadata.FlintMetadata import org.opensearch.flint.spark._ -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder, quotedTableName} +import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchema, metadataBuilder, quotedTableName} import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE} @@ -53,13 +53,13 @@ case class FlintSparkCoveringIndex( Map[String, AnyRef]("columnName" -> colName, "columnType" -> colType).asJava }.toArray } - val schemaJson = generateSchemaJSON(indexedColumns) + val schema = generateSchema(indexedColumns).asJava val builder = metadataBuilder(this) .name(indexName) .source(tableName) .indexedColumns(indexColumnMaps) - .schema(schemaJson) + .schema(schema) // Add optional index properties filterCondition.map(builder.addProperty("filterCondition", _)) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index 48dfee50a..caa75be75 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -10,10 +10,10 @@ import java.util.Locale import scala.collection.JavaConverters.mapAsJavaMapConverter import scala.collection.convert.ImplicitConversions.`map AsScala` +import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry -import 
org.opensearch.flint.core.metadata.FlintMetadata import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions} -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder, StreamingRefresh} +import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchema, metadataBuilder, StreamingRefresh} import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.function.TumbleFunction import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE} @@ -59,13 +59,13 @@ case class FlintSparkMaterializedView( outputSchema.map { case (colName, colType) => Map[String, AnyRef]("columnName" -> colName, "columnType" -> colType).asJava }.toArray - val schemaJson = generateSchemaJSON(outputSchema) + val schema = generateSchema(outputSchema).asJava metadataBuilder(this) .name(mvName) .source(query) .indexedColumns(indexColumnMaps) - .schema(schemaJson) + .schema(schema) .build() } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala index 8ce458055..3c14eb00d 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala @@ -5,19 +5,15 @@ package org.opensearch.flint.spark.skipping -import com.amazon.awslogsdataaccesslayer.connectors.spark.LogsTable import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState.DELETED import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex} -import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE} +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE} -import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or, Predicate} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.flint.qualifyTableName /** @@ -62,46 +58,6 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan] } else { filter } - case filter @ Filter( - condition: Predicate, - relation @ DataSourceV2Relation(table, _, Some(catalog), Some(identifier), _)) - if hasNoDisjunction(condition) && - // Check if query plan already rewritten - table.isInstanceOf[LogsTable] && !table.asInstanceOf[LogsTable].hasFileIndexScan() => - val index = flint.describeIndex(getIndexName(catalog, identifier)) - if (isActiveSkippingIndex(index)) { - val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex] - val indexFilter = rewriteToIndexFilter(skippingIndex, condition) - /* - * Replace original LogsTable with a new one with file index scan: - * Filter(a=b) - * |- DataSourceV2Relation(A) - * |- LogsTable <== replaced with a new LogsTable with file index scan - */ - if 
(indexFilter.isDefined) { - val indexScan = flint.queryIndex(skippingIndex.name()) - val selectFileIndexScan = - // Non hybrid scan - // TODO: refactor common logic with file-based skipping index - indexScan - .filter(new Column(indexFilter.get)) - .select(FILE_PATH_COLUMN) - - // Construct LogsTable with file index scan - // It will build scan operator using log file ids collected from file index scan - val logsTable = table.asInstanceOf[LogsTable] - val newTable = new LogsTable( - logsTable.schema(), - logsTable.options(), - selectFileIndexScan, - logsTable.processedFields()) - filter.copy(child = relation.copy(table = newTable)) - } else { - filter - } - } else { - filter - } } private def getIndexName(table: CatalogTable): String = { @@ -112,11 +68,6 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan] getSkippingIndexName(qualifiedTableName) } - private def getIndexName(catalog: CatalogPlugin, identifier: Identifier): String = { - val qualifiedTableName = s"${catalog.name}.${identifier}" - getSkippingIndexName(qualifiedTableName) - } - private def hasNoDisjunction(condition: Expression): Boolean = { condition.collectFirst { case Or(_, _) => true diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndex.scala index bd7abcfb3..dc7875ccf 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndex.scala @@ -10,7 +10,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.FILE_PATH_COL import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory} +import org.apache.spark.sql.execution.datasources.{FileIndex, FileStatusWithMetadata, PartitionDirectory} import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.functions.isnull import org.apache.spark.sql.types.StructType @@ -96,7 +96,7 @@ case class FlintSparkSkippingFileIndex( .toSet } - private def isFileNotSkipped(selectedFiles: Set[String], f: FileStatus) = { + private def isFileNotSkipped(selectedFiles: Set[String], f: FileStatusWithMetadata) = { selectedFiles.contains(f.getPath.toUri.toString) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala index da73ea01e..b6f21e455 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala @@ -7,8 +7,8 @@ package org.opensearch.flint.spark.skipping import scala.collection.JavaConverters.mapAsJavaMapConverter +import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry -import org.opensearch.flint.core.metadata.FlintMetadata import org.opensearch.flint.spark._ import org.opensearch.flint.spark.FlintSparkIndex._ import org.opensearch.flint.spark.FlintSparkIndexOptions.empty @@ -65,13 +65,13 @@ case class FlintSparkSkippingIndex( indexedColumns 
.flatMap(_.outputSchema()) .toMap + (FILE_PATH_COLUMN -> "string") - val schemaJson = generateSchemaJSON(fieldTypes) + val schema = generateSchema(fieldTypes).asJava metadataBuilder(this) .name(name()) .source(tableName) .indexedColumns(indexColumnMaps) - .schema(schemaJson) + .schema(schema) .build() } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala index de2ea772d..fa9b23951 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala @@ -11,7 +11,10 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKi import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GetStructField} +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.StringType /** * Skipping index strategy that defines skipping data structure building and reading logic. @@ -115,6 +118,17 @@ object FlintSparkSkippingStrategy { Seq(attr.name) case GetStructField(child, _, Some(name)) => extractColumnName(child) :+ name + /** + * Since Spark 3.4 add read-side padding, char_col = "sample char" became + * (staticinvoke(class org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils, + * StringType, readSidePadding, char_col#47, 20, true, false, true) = sample char ) + * + * When create skipping index, Spark did write-side padding. So read-side push down can be + * ignored. 
More reading, https://issues.apache.org/jira/browse/SPARK-40697 + */ + case StaticInvoke(staticObject, StringType, "readSidePadding", arguments, _, _, _, _) + if classOf[CharVarcharCodegenUtils].isAssignableFrom(staticObject) => + extractColumnName(arguments.head) case _ => Seq.empty } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala index 653abbd7d..c7ba66c9c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala @@ -8,6 +8,7 @@ package org.opensearch.flint.spark.skipping.bloomfilter import java.io.ByteArrayInputStream import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter +import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.NAME import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow @@ -40,7 +41,7 @@ case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpre override def dataType: DataType = BooleanType - override def symbol: String = "BLOOM_FILTER_MIGHT_CONTAIN" + override def symbol: String = NAME override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { @@ -109,6 +110,8 @@ case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpre object BloomFilterMightContain { + val NAME = "BLOOM_FILTER_MIGHT_CONTAIN" + /** * Generate bloom filter might contain function given the bloom filter column and value. 
* diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/sql/FlintCatalogSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/sql/FlintCatalogSuite.scala new file mode 100644 index 000000000..3c75cf541 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/apache/spark/sql/FlintCatalogSuite.scala @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.mockito.Mockito.when +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin} +import org.apache.spark.sql.flint.resolveCatalogName +import org.apache.spark.sql.internal.{SessionState, SQLConf} +import org.apache.spark.sql.internal.SQLConf.DEFAULT_CATALOG + +class FlintCatalogSuite extends SparkFunSuite with MockitoSugar { + + test("resolveCatalogName returns default catalog name for session catalog") { + assertCatalog() + .withCatalogName("spark_catalog") + .withDefaultCatalog("glue") + .registerCatalog("glue") + .shouldResolveCatalogName("glue") + } + + test("resolveCatalogName returns default catalog name for spark_catalog") { + assertCatalog() + .withCatalogName("spark_catalog") + .withDefaultCatalog("spark_catalog") + .registerCatalog("spark_catalog") + .shouldResolveCatalogName("spark_catalog") + } + + test("resolveCatalogName should return catalog name for non-session catalogs") { + assertCatalog() + .withCatalogName("custom_catalog") + .withDefaultCatalog("custom_catalog") + .registerCatalog("custom_catalog") + .shouldResolveCatalogName("custom_catalog") + } + + test( + "resolveCatalogName should throw RuntimeException when default catalog is not registered") { + assertCatalog() + .withCatalogName("spark_catalog") + .withDefaultCatalog("glue") + .registerCatalog("unknown") + .shouldThrowException() + } + + private def assertCatalog(): AssertionHelper = { + new AssertionHelper + } + + private class AssertionHelper { + private val spark = mock[SparkSession] + private val catalog = mock[CatalogPlugin] + private val sessionState = mock[SessionState] + private val catalogManager = mock[CatalogManager] + + def withCatalogName(catalogName: String): AssertionHelper = { + when(catalog.name()).thenReturn(catalogName) + this + } + + def withDefaultCatalog(catalogName: String): AssertionHelper = { + val conf = new SQLConf + conf.setConf(DEFAULT_CATALOG, catalogName) + when(spark.conf).thenReturn(new RuntimeConfig(conf)) + this + } + + def registerCatalog(catalogName: String): AssertionHelper = { + when(spark.sessionState).thenReturn(sessionState) + when(sessionState.catalogManager).thenReturn(catalogManager) + when(catalogManager.isCatalogRegistered(catalogName)).thenReturn(true) + this + } + + def shouldResolveCatalogName(expectedCatalogName: String): Unit = { + assert(resolveCatalogName(spark, catalog) == expectedCatalogName) + } + + def shouldThrowException(): Unit = { + assertThrows[IllegalStateException] { + resolveCatalogName(spark, catalog) + } + } + } +} diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala index ca8349cd5..1a164a9f2 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala +++ 
b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala @@ -9,6 +9,7 @@ import java.util.Optional import scala.collection.JavaConverters._ +import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.http.FlintRetryOptions._ import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper @@ -63,6 +64,16 @@ class FlintSparkConfSuite extends FlintSuite { retryOptions.getRetryableExceptionClassNames.get() shouldBe "java.net.ConnectException" } + test("test bulkRequestRateLimitPerNode default value") { + val options = FlintSparkConf().flintOptions() + options.getBulkRequestRateLimitPerNode shouldBe 0 + } + + test("test specified bulkRequestRateLimitPerNode") { + val options = FlintSparkConf(Map("bulkRequestRateLimitPerNode" -> "5").asJava).flintOptions() + options.getBulkRequestRateLimitPerNode shouldBe 5 + } + test("test metadata access AWS credentials provider option") { withSparkConf("spark.metadata.accessAWSCredentialsProvider") { spark.conf.set( diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala index a590eccb1..2c5518778 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala @@ -7,13 +7,14 @@ package org.opensearch.flint.spark.covering import scala.collection.JavaConverters._ -import org.mockito.ArgumentMatchers.any +import org.mockito.ArgumentMatchers.{any, eq => mockitoEq} import org.mockito.Mockito.{mockStatic, when, RETURNS_DEEP_STUBS} import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState.{ACTIVE, DELETED, IndexState} -import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions, IRestHighLevelClient} -import org.opensearch.flint.core.storage.OpenSearchClientUtils +import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions, MetaData} +import org.opensearch.flint.core.table.OpenSearchCluster import org.opensearch.flint.spark.FlintSpark +import org.opensearch.flint.spark.FlintSparkIndex.generateSchemaJSON import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName import org.scalatest.matchers.{Matcher, MatchResult} import org.scalatest.matchers.should.Matchers @@ -34,9 +35,8 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers { private val clientBuilder = mockStatic(classOf[FlintClientBuilder]) private val client = mock[FlintClient](RETURNS_DEEP_STUBS) - /** Mock IRestHighLevelClient to avoid looking for real OpenSearch cluster */ - private val clientUtils = mockStatic(classOf[OpenSearchClientUtils]) - private val openSearchClient = mock[IRestHighLevelClient](RETURNS_DEEP_STUBS) + /** Mock OpenSearchCluster to avoid looking for real OpenSearch cluster */ + private val openSearchCluster = mockStatic(classOf[OpenSearchCluster]) /** Mock FlintSpark which is required by the rule. Deep stub required to replace spark val. 
*/ private val flint = mock[FlintSpark](RETURNS_DEEP_STUBS) @@ -59,16 +59,17 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers { clientBuilder .when(() => FlintClientBuilder.build(any(classOf[FlintOptions]))) .thenReturn(client) - when(flint.spark).thenReturn(spark) // Mock static - clientUtils - .when(() => OpenSearchClientUtils.createClient(any(classOf[FlintOptions]))) - .thenReturn(openSearchClient) + openSearchCluster + .when(() => OpenSearchCluster.apply(any(classOf[String]), any(classOf[FlintOptions]))) + .thenCallRealMethod() + when(flint.spark).thenReturn(spark) } override protected def afterAll(): Unit = { sql(s"DROP TABLE $testTable") clientBuilder.close() + openSearchCluster.close() super.afterAll() } @@ -274,8 +275,18 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers { }) indexes.foreach { index => - when(client.getAllIndexMetadata(index.name())) - .thenReturn(Map.apply(index.name() -> index.metadata()).asJava) + { + openSearchCluster + .when(() => + OpenSearchCluster + .getAllOpenSearchTableMetadata(any(classOf[FlintOptions]), mockitoEq(index.name))) + .thenReturn( + Seq( + MetaData( + index.name, + generateSchemaJSON(index.indexedColumns), + index.metadata.indexSettings.getOrElse(""))).asJava) + } } rule.apply(plan) } @@ -284,8 +295,8 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers { Matcher { (plan: LogicalPlan) => val result = plan.exists { case LogicalRelation(_, _, Some(table), _) => - // Table name in logical relation doesn't have catalog name - table.qualifiedName == expectedTableName.split('.').drop(1).mkString(".") + // Since Spark 3.4, Table name in logical relation have catalog name + table.qualifiedName == expectedTableName case _ => false } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala index c099a1a86..f03116de9 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala @@ -7,112 +7,105 @@ package org.opensearch.flint.spark.skipping import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ -import org.mockito.invocation.InvocationOnMock import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState.{DELETED, IndexState, REFRESHING} -import org.opensearch.flint.spark.FlintSpark -import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE} +import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions} +import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndexOptions} import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind import org.scalatest.matchers.{Matcher, MatchResult} import org.scalatest.matchers.should.Matchers import org.scalatestplus.mockito.MockitoSugar.mock -import org.apache.spark.SparkFunSuite +import org.apache.spark.FlintSuite import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.expressions.{And, 
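The updated assertion above reflects that, since Spark 3.4, the table reported by a V1 LogicalRelation is fully qualified with its catalog name. A small sketch of the before/after comparison:

```scala
// Illustration of the qualified-name change the assertion above reacts to; the exact
// Spark internals are not reproduced here, only the string handling.
object QualifiedNameSketch extends App {
  val expectedTableName = "spark_catalog.default.test"

  // Old assertion (pre Spark 3.4 behaviour): strip the catalog before comparing.
  println(expectedTableName.split('.').drop(1).mkString(".")) // default.test

  // New assertion: compare the fully qualified name as-is.
  println(expectedTableName) // spark_catalog.default.test
}
```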
AttributeReference, EqualTo, Expression, ExprId, Literal, Or} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, SubqueryAlias} -import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions.col -import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers { +class ApplyFlintSparkSkippingIndexSuite extends FlintSuite with Matchers { /** Test table and index */ private val testTable = "spark_catalog.default.apply_skipping_index_test" - private val testIndex = getSkippingIndexName(testTable) - private val testSchema = StructType( - Seq( - StructField("name", StringType, nullable = false), - StructField("age", IntegerType, nullable = false), - StructField("address", StringType, nullable = false))) - - /** Resolved column reference used in filtering condition */ - private val nameCol = - AttributeReference("name", StringType, nullable = false)(exprId = ExprId(1)) - private val ageCol = - AttributeReference("age", IntegerType, nullable = false)(exprId = ExprId(2)) - private val addressCol = - AttributeReference("address", StringType, nullable = false)(exprId = ExprId(3)) + + // Mock FlintClient to avoid looking for real OpenSearch cluster + private val clientBuilder = mockStatic(classOf[FlintClientBuilder]) + private val client = mock[FlintClient](RETURNS_DEEP_STUBS) + + /** Mock FlintSpark which is required by the rule */ + private val flint = mock[FlintSpark] + + override protected def beforeAll(): Unit = { + super.beforeAll() + sql(s"CREATE TABLE $testTable (name STRING, age INT, address STRING) USING JSON") + + // Mock static create method in FlintClientBuilder used by Flint data source + clientBuilder + .when(() => FlintClientBuilder.build(any(classOf[FlintOptions]))) + .thenReturn(client) + when(flint.spark).thenReturn(spark) + } + + override protected def afterAll(): Unit = { + sql(s"DROP TABLE $testTable") + clientBuilder.close() + super.afterAll() + } test("should not rewrite query if no skipping index") { assertFlintQueryRewriter() - .withSourceTable(testTable, testSchema) - .withFilter(EqualTo(nameCol, Literal("hello"))) + .withQuery(s"SELECT * FROM $testTable WHERE name = 'hello'") .withNoSkippingIndex() .shouldNotRewrite() } test("should not rewrite query if filter condition is disjunction") { assertFlintQueryRewriter() - .withSourceTable(testTable, testSchema) - .withFilter(Or(EqualTo(nameCol, Literal("hello")), EqualTo(ageCol, Literal(30)))) - .withSkippingIndex(testIndex, REFRESHING, "name", "age") + .withQuery(s"SELECT * FROM $testTable WHERE name = 'hello' or age = 30") + .withSkippingIndex(REFRESHING, "name", "age") .shouldNotRewrite() } test("should not rewrite query if filter condition contains disjunction") { assertFlintQueryRewriter() - .withSourceTable(testTable, testSchema) - .withFilter( - And( - Or(EqualTo(nameCol, Literal("hello")), EqualTo(ageCol, Literal(30))), - EqualTo(ageCol, Literal(30)))) - .withSkippingIndex(testIndex, REFRESHING, "name", "age") + .withQuery( + s"SELECT * FROM 
$testTable WHERE (name = 'hello' or age = 30) and address = 'Seattle'") + .withSkippingIndex(REFRESHING, "name", "age") .shouldNotRewrite() } test("should rewrite query with skipping index") { assertFlintQueryRewriter() - .withSourceTable(testTable, testSchema) - .withFilter(EqualTo(nameCol, Literal("hello"))) - .withSkippingIndex(testIndex, REFRESHING, "name") + .withQuery(s"SELECT * FROM $testTable WHERE name = 'hello'") + .withSkippingIndex(REFRESHING, "name") .shouldPushDownAfterRewrite(col("name") === "hello") } test("should not rewrite query with deleted skipping index") { assertFlintQueryRewriter() - .withSourceTable(testTable, testSchema) - .withFilter(EqualTo(nameCol, Literal("hello"))) - .withSkippingIndex(testIndex, DELETED, "name") + .withQuery(s"SELECT * FROM $testTable WHERE name = 'hello'") + .withSkippingIndex(DELETED, "name") .shouldNotRewrite() } test("should only push down filter condition with indexed column") { assertFlintQueryRewriter() - .withSourceTable(testTable, testSchema) - .withFilter(And(EqualTo(nameCol, Literal("hello")), EqualTo(ageCol, Literal(30)))) - .withSkippingIndex(testIndex, REFRESHING, "name") + .withQuery(s"SELECT * FROM $testTable WHERE name = 'hello' and age = 30") + .withSkippingIndex(REFRESHING, "name") .shouldPushDownAfterRewrite(col("name") === "hello") } test("should push down all filter conditions with indexed column") { assertFlintQueryRewriter() - .withSourceTable(testTable, testSchema) - .withFilter(And(EqualTo(nameCol, Literal("hello")), EqualTo(ageCol, Literal(30)))) - .withSkippingIndex(testIndex, REFRESHING, "name", "age") + .withQuery(s"SELECT * FROM $testTable WHERE name = 'hello' and age = 30") + .withSkippingIndex(REFRESHING, "name", "age") .shouldPushDownAfterRewrite(col("name") === "hello" && col("age") === 30) assertFlintQueryRewriter() - .withSourceTable(testTable, testSchema) - .withFilter( - And( - EqualTo(nameCol, Literal("hello")), - And(EqualTo(ageCol, Literal(30)), EqualTo(addressCol, Literal("Seattle"))))) - .withSkippingIndex(testIndex, REFRESHING, "name", "age", "address") + .withQuery( + s"SELECT * FROM $testTable WHERE name = 'hello' and (age = 30 and address = 'Seattle')") + .withSkippingIndex(REFRESHING, "name", "age", "address") .shouldPushDownAfterRewrite( col("name") === "hello" && col("age") === 30 && col("address") === "Seattle") } @@ -122,46 +115,27 @@ class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers { } private class AssertionHelper { - private val flint = { - val mockFlint = mock[FlintSpark](RETURNS_DEEP_STUBS) - when(mockFlint.spark.sessionState.catalogManager.currentCatalog.name()) - .thenReturn("spark_catalog") - mockFlint - } private val rule = new ApplyFlintSparkSkippingIndex(flint) - private var relation: LogicalRelation = _ private var plan: LogicalPlan = _ - def withSourceTable(fullname: String, schema: StructType): AssertionHelper = { - val table = CatalogTable( - identifier = TableIdentifier(fullname.split('.')(1), Some(fullname.split('.')(0))), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty, - schema = null) - relation = LogicalRelation(mockBaseRelation(schema), table) - this - } - - def withFilter(condition: Expression): AssertionHelper = { - val filter = Filter(condition, relation) - val project = Project(Seq(), filter) - plan = SubqueryAlias("alb_logs", project) + def withQuery(query: String): AssertionHelper = { + this.plan = sql(query).queryExecution.optimizedPlan this } - def withSkippingIndex( - indexName: String, - indexState: 
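The reworked ApplyFlintSparkSkippingIndexSuite stops hand-building logical plans and instead runs real SQL, asserting on `sql(query).queryExecution.optimizedPlan`. A self-contained sketch of that testing style against a throwaway local table is shown below; table and column names are illustrative, and the Flint rewrite rule itself is not applied here.

```scala
// Standalone sketch of the "assert on the optimized plan" style used above.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.Filter

object OptimizedPlanAssertionSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("plan-sketch").getOrCreate()

  spark.sql("CREATE TABLE people (name STRING, age INT) USING PARQUET")
  spark.sql("INSERT INTO people VALUES ('hello', 30), ('world', 40)")

  // Same idea as AssertionHelper.withQuery: run real SQL and grab the optimized plan.
  val plan = spark.sql("SELECT * FROM people WHERE name = 'hello'").queryExecution.optimizedPlan

  // The real suite pattern-matches for the Flint skipping file index; here we only
  // check that the filter survives optimization and could be rewritten by such a rule.
  assert(plan.collect { case f: Filter => f }.nonEmpty, s"expected a Filter node in:\n$plan")

  spark.sql("DROP TABLE people")
  spark.stop()
}
```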
IndexState, - indexCols: String*): AssertionHelper = { - val skippingIndex = mock[FlintSparkSkippingIndex] - when(skippingIndex.kind).thenReturn(SKIPPING_INDEX_TYPE) - when(skippingIndex.name()).thenReturn(indexName) - when(skippingIndex.indexedColumns).thenReturn(indexCols.map(FakeSkippingStrategy)) - - // Mock index log entry with the given state - val logEntry = mock[FlintMetadataLogEntry] - when(logEntry.state).thenReturn(indexState) - when(skippingIndex.latestLogEntry).thenReturn(Some(logEntry)) + def withSkippingIndex(indexState: IndexState, indexCols: String*): AssertionHelper = { + val skippingIndex = new FlintSparkSkippingIndex( + tableName = testTable, + indexedColumns = indexCols.map(FakeSkippingStrategy), + options = FlintSparkIndexOptions.empty, + latestLogEntry = Some( + new FlintMetadataLogEntry( + "id", + 0L, + indexState, + Map.empty[String, Any], + "", + Map.empty[String, Any]))) when(flint.describeIndex(any())).thenReturn(Some(skippingIndex)) this @@ -181,23 +155,6 @@ class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers { } } - private def mockBaseRelation(schema: StructType): BaseRelation = { - val fileIndex = mock[FileIndex] - val baseRelation: HadoopFsRelation = mock[HadoopFsRelation] - when(baseRelation.location).thenReturn(fileIndex) - when(baseRelation.schema).thenReturn(schema) - - // Mock baseRelation.copy(location = FlintFileIndex) - doAnswer((invocation: InvocationOnMock) => { - val location = invocation.getArgument[FileIndex](0) - val relationCopy: HadoopFsRelation = mock[HadoopFsRelation] - when(relationCopy.location).thenReturn(location) - relationCopy - }).when(baseRelation).copy(any(), any(), any(), any(), any(), any())(any()) - - baseRelation - } - private def pushDownFilterToIndexScan(expect: Column): Matcher[LogicalPlan] = { Matcher { (plan: LogicalPlan) => val useFlintSparkSkippingFileIndex = plan.exists { diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndexSuite.scala index d2ef72158..4b707841c 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndexSuite.scala @@ -16,7 +16,7 @@ import org.apache.spark.FlintSuite import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Literal, Predicate} -import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory} +import org.apache.spark.sql.execution.datasources.{FileIndex, FileStatusWithMetadata, PartitionDirectory} import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ @@ -118,7 +118,8 @@ class FlintSparkSkippingFileIndexSuite extends FlintSuite with Matchers { private def mockPartitions(partitions: Map[String, Seq[String]]): Seq[PartitionDirectory] = { partitions.map { case (partitionName, filePaths) => - val files = filePaths.map(path => new FileStatus(0, false, 0, 0, 0, new Path(path))) + val files = filePaths.map(path => + FileStatusWithMetadata(new FileStatus(0, false, 0, 0, 0, new Path(path)))) PartitionDirectory(InternalRow(Literal(partitionName)), files) }.toSeq } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala 
b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala index 6772eb8f3..1b332660e 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala @@ -11,7 +11,8 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.json4s.native.JsonMethods.parse import org.mockito.Mockito.when -import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.common.metadata.FlintMetadata +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE} import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind @@ -263,7 +264,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { } private def schemaShouldMatch(metadata: FlintMetadata, expected: String): Unit = { - val actual = parse(metadata.getContent) \ "properties" + val actual = parse(FlintOpenSearchIndexMetadataService.serialize(metadata)) \ "properties" assert(actual == parse(expected)) } } diff --git a/integ-test/src/aws-integration/scala/org/opensearch/flint/spark/aws/AWSEmrServerlessAccessTestSuite.scala b/integ-test/src/aws-integration/scala/org/opensearch/flint/spark/aws/AWSEmrServerlessAccessTestSuite.scala index 67e036d28..2e599c418 100644 --- a/integ-test/src/aws-integration/scala/org/opensearch/flint/spark/aws/AWSEmrServerlessAccessTestSuite.scala +++ b/integ-test/src/aws-integration/scala/org/opensearch/flint/spark/aws/AWSEmrServerlessAccessTestSuite.scala @@ -5,13 +5,14 @@ package org.opensearch.flint.spark.aws +import java.io.File import java.time.LocalDateTime import scala.concurrent.duration.DurationInt -import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder +import com.amazonaws.services.emrserverless.{AWSEMRServerless, AWSEMRServerlessClientBuilder} import com.amazonaws.services.emrserverless.model.{GetJobRunRequest, JobDriver, SparkSubmit, StartJobRunRequest} -import com.amazonaws.services.s3.AmazonS3ClientBuilder +import com.amazonaws.services.s3.{AmazonS3, AmazonS3ClientBuilder} import org.scalatest.BeforeAndAfter import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -19,12 +20,13 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.internal.Logging class AWSEmrServerlessAccessTestSuite - extends AnyFlatSpec + extends AnyFlatSpec with BeforeAndAfter with Matchers with Logging { lazy val testHost: String = System.getenv("AWS_OPENSEARCH_HOST") + lazy val testServerlessHost: String = System.getenv("AWS_OPENSEARCH_SERVERLESS_HOST") lazy val testPort: Int = -1 lazy val testRegion: String = System.getenv("AWS_REGION") lazy val testScheme: String = "https" @@ -36,53 +38,38 @@ class AWSEmrServerlessAccessTestSuite lazy val testS3CodePrefix: String = System.getenv("AWS_S3_CODE_PREFIX") lazy val testResultIndex: String = System.getenv("AWS_OPENSEARCH_RESULT_INDEX") - "EMR Serverless job" should "run successfully" in { + "EMR Serverless job with AOS" should "run successfully" in { val s3Client = AmazonS3ClientBuilder.standard().withRegion(testRegion).build() val emrServerless = AWSEMRServerlessClientBuilder.standard().withRegion(testRegion).build() - val appJarPath = - 
sys.props.getOrElse("appJar", throw new IllegalArgumentException("appJar not set")) - val extensionJarPath = sys.props.getOrElse( - "extensionJar", - throw new IllegalArgumentException("extensionJar not set")) - val pplJarPath = - sys.props.getOrElse("pplJar", throw new IllegalArgumentException("pplJar not set")) + uploadJarsToS3(s3Client) - s3Client.putObject( - testS3CodeBucket, - s"$testS3CodePrefix/sql-job.jar", - new java.io.File(appJarPath)) - s3Client.putObject( - testS3CodeBucket, - s"$testS3CodePrefix/extension.jar", - new java.io.File(extensionJarPath)) - s3Client.putObject( - testS3CodeBucket, - s"$testS3CodePrefix/ppl.jar", - new java.io.File(pplJarPath)) + val jobRunRequest = startJobRun("SELECT 1", testHost, "es") - val jobRunRequest = new StartJobRunRequest() - .withApplicationId(testAppId) - .withExecutionRoleArn(testExecutionRole) - .withName(s"integration-${LocalDateTime.now()}") - .withJobDriver(new JobDriver() - .withSparkSubmit(new SparkSubmit() - .withEntryPoint(s"s3://$testS3CodeBucket/$testS3CodePrefix/sql-job.jar") - .withEntryPointArguments(testResultIndex) - .withSparkSubmitParameters(s"--class org.apache.spark.sql.FlintJob --jars " + - s"s3://$testS3CodeBucket/$testS3CodePrefix/extension.jar," + - s"s3://$testS3CodeBucket/$testS3CodePrefix/ppl.jar " + - s"--conf spark.datasource.flint.host=$testHost " + - s"--conf spark.datasource.flint.port=-1 " + - s"--conf spark.datasource.flint.scheme=$testScheme " + - s"--conf spark.datasource.flint.auth=$testAuth " + - s"--conf spark.sql.catalog.glue=org.opensearch.sql.FlintDelegatingSessionCatalog " + - s"--conf spark.flint.datasource.name=glue " + - s"""--conf spark.flint.job.query="SELECT 1" """ + - s"--conf spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory"))) + val jobRunResponse = emrServerless.startJobRun(jobRunRequest) + + verifyJobSucceed(emrServerless, jobRunResponse.getJobRunId) + } + + "EMR Serverless job with AOSS" should "run successfully" in { + val s3Client = AmazonS3ClientBuilder.standard().withRegion(testRegion).build() + val emrServerless = AWSEMRServerlessClientBuilder.standard().withRegion(testRegion).build() + + uploadJarsToS3(s3Client) + + val jobRunRequest = startJobRun( + "SELECT 1", + testServerlessHost, + "aoss", + conf("spark.datasource.flint.write.refresh_policy", "false") + ) val jobRunResponse = emrServerless.startJobRun(jobRunRequest) + verifyJobSucceed(emrServerless, jobRunResponse.getJobRunId) + } + + private def verifyJobSucceed(emrServerless: AWSEMRServerless, jobRunId: String): Unit = { val startTime = System.currentTimeMillis() val timeout = 5.minutes.toMillis var jobState = "STARTING" @@ -92,11 +79,72 @@ class AWSEmrServerlessAccessTestSuite Thread.sleep(30000) val request = new GetJobRunRequest() .withApplicationId(testAppId) - .withJobRunId(jobRunResponse.getJobRunId) + .withJobRunId(jobRunId) jobState = emrServerless.getJobRun(request).getJobRun.getState logInfo(s"Current job state: $jobState at ${System.currentTimeMillis()}") } - jobState shouldBe "SUCCESS" } + + private def startJobRun(query: String, host: String, authServiceName: String, additionalParams: String*) = { + new StartJobRunRequest() + .withApplicationId(testAppId) + .withExecutionRoleArn(testExecutionRole) + .withName(s"integration-${authServiceName}-${LocalDateTime.now()}") + .withJobDriver(new JobDriver() + .withSparkSubmit(new SparkSubmit() + .withEntryPoint(s"s3://$testS3CodeBucket/$testS3CodePrefix/sql-job.jar") + 
.withEntryPointArguments(testResultIndex) + .withSparkSubmitParameters( + join( + clazz("org.apache.spark.sql.FlintJob"), + jars(s"s3://$testS3CodeBucket/$testS3CodePrefix/extension.jar", s"s3://$testS3CodeBucket/$testS3CodePrefix/ppl.jar"), + conf("spark.datasource.flint.host", host), + conf("spark.datasource.flint.port", s"$testPort"), + conf("spark.datasource.flint.scheme", testScheme), + conf("spark.datasource.flint.auth", testAuth), + conf("spark.datasource.flint.auth.servicename", authServiceName), + conf("spark.sql.catalog.glue", "org.opensearch.sql.FlintDelegatingSessionCatalog"), + conf("spark.flint.datasource.name", "glue"), + conf("spark.flint.job.query", quote(query)), + conf("spark.hadoop.hive.metastore.client.factory.class", "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory"), + join(additionalParams: _*) + ) + ) + ) + ) + } + + private def join(params: String*): String = params.mkString(" ") + + private def clazz(clazz: String): String = s"--class $clazz" + + private def jars(jars: String*): String = s"--jars ${jars.mkString(",")}" + + private def quote(str: String): String = "\"" + str + "\"" + + private def conf(name: String, value: String): String = s"--conf $name=$value" + + private def uploadJarsToS3(s3Client: AmazonS3) = { + val appJarPath = + sys.props.getOrElse("appJar", throw new IllegalArgumentException("appJar not set")) + val extensionJarPath = sys.props.getOrElse( + "extensionJar", + throw new IllegalArgumentException("extensionJar not set")) + val pplJarPath = + sys.props.getOrElse("pplJar", throw new IllegalArgumentException("pplJar not set")) + + s3Client.putObject( + testS3CodeBucket, + s"$testS3CodePrefix/sql-job.jar", + new File(appJarPath)) + s3Client.putObject( + testS3CodeBucket, + s"$testS3CodePrefix/extension.jar", + new File(extensionJarPath)) + s3Client.putObject( + testS3CodeBucket, + s"$testS3CodePrefix/ppl.jar", + new File(pplJarPath)) + } } diff --git a/integ-test/src/integration/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala index 0bccf787b..fd5c5bf8f 100644 --- a/integ-test/src/integration/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala @@ -104,7 +104,9 @@ class FlintDataSourceV2ITSuite val df2 = df.filter($"aText".contains("second")) checkFiltersRemoved(df2) - checkPushedInfo(df2, "PushedPredicates: [aText IS NOT NULL, aText LIKE '%second%']") + checkPushedInfo( + df2, + "PushedPredicates: [aText IS NOT NULL, aText LIKE '%second%' ESCAPE '\\']") checkAnswer(df2, Row(2, "b", "i am second")) val df3 = @@ -117,7 +119,7 @@ class FlintDataSourceV2ITSuite checkFiltersRemoved(df4) checkPushedInfo( df4, - "PushedPredicates: [aInt IS NOT NULL, aText IS NOT NULL, aInt > 1, aText LIKE '%second%']") + "PushedPredicates: [aInt IS NOT NULL, aText IS NOT NULL, aInt > 1, aText LIKE '%second%' ESCAPE '\\']") checkAnswer(df4, Row(2, "b", "i am second")) } } diff --git a/integ-test/src/integration/scala/org/apache/spark/opensearch/catalog/OpenSearchCatalogITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/opensearch/catalog/OpenSearchCatalogITSuite.scala index ea5988577..69900677c 100644 --- a/integ-test/src/integration/scala/org/apache/spark/opensearch/catalog/OpenSearchCatalogITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/opensearch/catalog/OpenSearchCatalogITSuite.scala @@ -22,6 +22,7 @@ class OpenSearchCatalogITSuite 
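The refactored AWSEmrServerlessAccessTestSuite composes its spark-submit parameters from the small `join`/`clazz`/`jars`/`conf`/`quote` helpers above instead of one long concatenated string. A standalone sketch of how they combine, with placeholder bucket, prefix, and host values:

```scala
// Standalone sketch of the parameter-builder helpers introduced above;
// bucket, prefix and host values are placeholders, not real resources.
object SparkSubmitParamsSketch extends App {
  def join(params: String*): String = params.mkString(" ")
  def clazz(name: String): String = s"--class $name"
  def jars(paths: String*): String = s"--jars ${paths.mkString(",")}"
  def conf(name: String, value: String): String = s"--conf $name=$value"
  def quote(str: String): String = "\"" + str + "\""

  val params = join(
    clazz("org.apache.spark.sql.FlintJob"),
    jars("s3://my-bucket/code/extension.jar", "s3://my-bucket/code/ppl.jar"),
    conf("spark.datasource.flint.host", "search-example.us-west-2.aoss.amazonaws.com"),
    conf("spark.datasource.flint.auth.servicename", "aoss"),
    conf("spark.flint.job.query", quote("SELECT 1")))

  println(params) // --class org.apache.spark.sql.FlintJob --jars ... --conf ...
}
```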
extends OpenSearchCatalogSuite { } } + // FIXME https://github.com/opensearch-project/opensearch-spark/issues/529 test("Describe single index as table") { val indexName = "t0001" withIndexName(indexName) { @@ -29,16 +30,13 @@ class OpenSearchCatalogITSuite extends OpenSearchCatalogSuite { val df = spark.sql(s""" DESC ${catalogName}.default.$indexName""") - assert(df.count() == 6) + assert(df.count() == 3) checkAnswer( df, Seq( - Row("# Partitioning", "", ""), - Row("", "", ""), - Row("Not partitioned", "", ""), - Row("accountId", "string", ""), - Row("eventName", "string", ""), - Row("eventSource", "string", ""))) + Row("accountId", "string", null), + Row("eventName", "string", null), + Row("eventSource", "string", null))) } } diff --git a/integ-test/src/integration/scala/org/apache/spark/opensearch/table/OpenSearchCatalogSuite.scala b/integ-test/src/integration/scala/org/apache/spark/opensearch/table/OpenSearchCatalogSuite.scala index 21323cca4..832642088 100644 --- a/integ-test/src/integration/scala/org/apache/spark/opensearch/table/OpenSearchCatalogSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/opensearch/table/OpenSearchCatalogSuite.scala @@ -8,7 +8,7 @@ package org.apache.spark.opensearch.table import org.opensearch.flint.spark.FlintSparkSuite trait OpenSearchCatalogSuite extends FlintSparkSuite { - val catalogName = "dev" + override lazy val catalogName = "dev" override def beforeAll(): Unit = { super.beforeAll() diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala index 921db792a..1d86a6589 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -15,12 +15,13 @@ import scala.util.control.Breaks.{break, breakable} import org.opensearch.OpenSearchStatusException import org.opensearch.flint.OpenSearchSuite +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} import org.opensearch.flint.core.{FlintClient, FlintOptions} import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} -import org.opensearch.flint.data.{FlintStatement, InteractiveSession} import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID} import org.apache.spark.sql.util.MockEnvironment import org.apache.spark.util.ThreadUtils @@ -120,7 +121,6 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { updater = new OpenSearchUpdater( requestIndex, new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava))) - } override def afterEach(): Unit = { @@ -130,19 +130,20 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { def createSession(jobId: String, excludeJobId: String): Unit = { val docs = Seq(s"""{ - | "state": "running", - | "lastUpdateTime": 1698796582978, - | "applicationId": "00fd777k3k3ls20p", - | "error": "", - | "sessionId": ${sessionId}, - | "jobId": \"${jobId}\", - | "type": "session", - | "excludeJobIds": [\"${excludeJobId}\"] - |}""".stripMargin) + | "state": "running", + | "lastUpdateTime": 1698796582978, + | 
"applicationId": "00fd777k3k3ls20p", + | "error": "", + | "sessionId": ${sessionId}, + | "jobId": \"${jobId}\", + | "type": "session", + | "excludeJobIds": [\"${excludeJobId}\"] + |}""".stripMargin) index(requestIndex, oneNodeSetting, requestIndexMapping, docs) } - def startREPL(): Future[Unit] = { + def startREPL(queryLoopExecutionFrequency: Long = DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + : Future[Unit] = { val prefix = "flint-repl-test" val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -164,6 +165,10 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { System.setProperty(HOST_PORT.key, String.valueOf(openSearchPort)) System.setProperty(REFRESH_POLICY.key, "true") + System.setProperty( + "spark.flint.job.queryLoopExecutionFrequency", + queryLoopExecutionFrequency.toString) + FlintREPL.envinromentProvider = new MockEnvironment( Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) FlintREPL.enableHiveSupport = false @@ -266,7 +271,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin // submitted from last year. We won't pick it up val lateSelectStatementId = - submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L) + submitQuery(s"${makeJsonCompliant(lateSelectQuery)}", lateSelectQueryId, 1672101970000L) // clean up val dropStatement = @@ -485,6 +490,99 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { } } + test("query loop should exit with inactivity timeout due to large query loop freq") { + try { + createSession(jobRunId, "") + threadLocalFuture.set(startREPL(5000L)) + val createStatement = + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\\t' + | ) + |""".stripMargin + submitQuery(s"${makeJsonCompliant(createStatement)}", "119") + + val insertStatement = + s""" + | INSERT INTO $testTable + | VALUES ('Hello', 30) + | """.stripMargin + submitQuery(s"${makeJsonCompliant(insertStatement)}", "120") + + val selectQueryId = "121" + val selectQueryStartTime = System.currentTimeMillis() + val selectQuery = s"SELECT name, age FROM $testTable".stripMargin + val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId) + + val lateSelectQueryId = "122" + val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin + // old query + val lateSelectStatementId = + submitQuery(s"${makeJsonCompliant(lateSelectQuery)}", lateSelectQueryId, 1672101970000L) + + // clean up + val dropStatement = + s"""DROP TABLE $testTable""".stripMargin + submitQuery(s"${makeJsonCompliant(dropStatement)}", "999") + + val selectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = "{'name':'Hello','age':30}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 2, + s"expected schema size is 2, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val 
expectedFirstSchema = "{'column_name':'age','data_type':'integer'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime) + successValidation(result) + true + } + pollForResultAndAssert(selectQueryValidation, selectQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + selectStatementId), + s"Fail to verify for $selectStatementId.") + + assert( + awaitConditionForStatementOrTimeout( + statement => { + statement.state != "waiting" + }, + lateSelectStatementId), + s"Fail to verify for $lateSelectStatementId.") + } catch { + case e: Exception => + logError("Unexpected exception", e) + assert(false, "Unexpected exception") + } finally { + waitREPLStop(threadLocalFuture.get()) + threadLocalFuture.remove() + + // shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT. + } + } + /** * JSON does not support raw newlines (\n) in string values. All newlines must be escaped or * removed when inside a JSON string. The same goes for tab characters, which should be diff --git a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala index eadc5031a..9aeba7512 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala @@ -19,8 +19,7 @@ import org.opensearch.flint.core.storage.FlintOpenSearchMetadataLogService import org.opensearch.index.seqno.SequenceNumbers.{UNASSIGNED_PRIMARY_TERM, UNASSIGNED_SEQ_NO} import org.scalatest.matchers.should.Matchers -import org.apache.spark.SparkConf -import org.apache.spark.sql.flint.config.FlintSparkConf.{CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS, DATA_SOURCE_NAME} +import org.apache.spark.sql.flint.config.FlintSparkConf.DATA_SOURCE_NAME class FlintMetadataLogITSuite extends OpenSearchTransactionSuite with Matchers { @@ -46,18 +45,19 @@ class FlintMetadataLogITSuite extends OpenSearchTransactionSuite with Matchers { test("should build metadata log service") { val customOptions = - openSearchOptions + (CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS.key -> "org.opensearch.flint.core.TestMetadataLogService") + openSearchOptions + (FlintOptions.CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS -> "org.opensearch.flint.core.TestMetadataLogService") val customFlintOptions = new FlintOptions(customOptions.asJava) val customFlintMetadataLogService = - FlintMetadataLogServiceBuilder.build(customFlintOptions, sparkConf) + FlintMetadataLogServiceBuilder.build(customFlintOptions) customFlintMetadataLogService shouldBe a[TestMetadataLogService] } test("should fail to build metadata log service if class name doesn't exist") { - val options = openSearchOptions + (CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS.key -> "dummy") + val options = + openSearchOptions + (FlintOptions.CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS -> "dummy") val flintOptions = new FlintOptions(options.asJava) the[RuntimeException] thrownBy { - FlintMetadataLogServiceBuilder.build(flintOptions, sparkConf) + FlintMetadataLogServiceBuilder.build(flintOptions) } } @@ -118,7 +118,7 @@ class FlintMetadataLogITSuite extends OpenSearchTransactionSuite with Matchers { } } -case class TestMetadataLogService(sparkConf: 
SparkConf) extends FlintMetadataLogService { +class TestMetadataLogService extends FlintMetadataLogService { override def startTransaction[T]( indexName: String, forceInit: Boolean): OptimisticTransaction[T] = { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala index 53188fb5a..2dc6016b2 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala @@ -10,23 +10,20 @@ import scala.collection.JavaConverters._ import org.json4s.{Formats, NoTypeHints} import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization -import org.mockito.Mockito.when import org.opensearch.flint.OpenSearchSuite -import org.opensearch.flint.core.metadata.FlintMetadata -import org.opensearch.flint.core.storage.FlintOpenSearchClient +import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintOpenSearchIndexMetadataService} import org.opensearch.flint.core.table.OpenSearchCluster import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with Matchers { - lazy val options = new FlintOptions(openSearchOptions.asJava) - /** Lazy initialize after container started. */ - lazy val flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)) + lazy val options = new FlintOptions(openSearchOptions.asJava) + lazy val flintClient = new FlintOpenSearchClient(options) + lazy val flintIndexMetadataService = new FlintOpenSearchIndexMetadataService(options) behavior of "Flint OpenSearch client" @@ -45,98 +42,39 @@ class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with M | } |""".stripMargin - val metadata = mock[FlintMetadata] - when(metadata.getContent).thenReturn(content) - when(metadata.indexSettings).thenReturn(None) + val metadata = FlintOpenSearchIndexMetadataService.deserialize(content) flintClient.createIndex(indexName, metadata) + flintIndexMetadataService.updateIndexMetadata(indexName, metadata) flintClient.exists(indexName) shouldBe true - flintClient.getIndexMetadata(indexName).kind shouldBe "test_kind" + flintIndexMetadataService.getIndexMetadata(indexName).kind shouldBe "test_kind" } it should "create index with settings" in { val indexName = "flint_test_with_settings" val indexSettings = "{\"number_of_shards\": 3,\"number_of_replicas\": 2}" - val metadata = mock[FlintMetadata] - when(metadata.getContent).thenReturn("{}") - when(metadata.indexSettings).thenReturn(Some(indexSettings)) + val metadata = FlintOpenSearchIndexMetadataService.deserialize("{}", indexSettings) flintClient.createIndex(indexName, metadata) flintClient.exists(indexName) shouldBe true // OS uses full setting name ("index" prefix) and store as string implicit val formats: Formats = Serialization.formats(NoTypeHints) - val settings = parse(flintClient.getIndexMetadata(indexName).indexSettings.get) + val settings = parse(flintIndexMetadataService.getIndexMetadata(indexName).indexSettings.get) (settings \ "index.number_of_shards").extract[String] shouldBe "3" (settings \ "index.number_of_replicas").extract[String] shouldBe "2" } - it should "update index successfully" in 
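The updated tests call `FlintMetadataLogServiceBuilder.build(flintOptions)` without a SparkConf, and TestMetadataLogService now has a no-argument constructor. The sketch below shows the kind of reflection-based construction those expectations imply, surfacing failures as RuntimeException as the negative test requires; the actual builder may differ in detail.

```scala
// Hedged sketch of no-arg, class-name-driven service construction.
object ServiceBuilderSketch {
  def build[T](className: String, expectedType: Class[T]): T =
    try {
      expectedType.cast(Class.forName(className).getDeclaredConstructor().newInstance())
    } catch {
      case e: Exception =>
        throw new RuntimeException(s"Failed to instantiate $className", e)
    }

  def main(args: Array[String]): Unit = {
    val list = build("java.util.ArrayList", classOf[java.util.AbstractList[_]])
    println(list.getClass.getName) // java.util.ArrayList
  }
}
```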
{ - val indexName = "test_update" - val content = - """ { - | "_meta": { - | "kind": "test_kind" - | }, - | "properties": { - | "age": { - | "type": "integer" - | } - | } - | } - |""".stripMargin - - val metadata = mock[FlintMetadata] - when(metadata.getContent).thenReturn(content) - when(metadata.indexSettings).thenReturn(None) - flintClient.createIndex(indexName, metadata) - - val newContent = - """ { - | "_meta": { - | "kind": "test_kind", - | "name": "test_name" - | }, - | "properties": { - | "age": { - | "type": "integer" - | } - | } - | } - |""".stripMargin - - val newMetadata = mock[FlintMetadata] - when(newMetadata.getContent).thenReturn(newContent) - when(newMetadata.indexSettings).thenReturn(None) - flintClient.updateIndex(indexName, newMetadata) - - flintClient.exists(indexName) shouldBe true - flintClient.getIndexMetadata(indexName).kind shouldBe "test_kind" - flintClient.getIndexMetadata(indexName).name shouldBe "test_name" - } - - it should "get all index metadata with the given index name pattern" in { - val metadata = mock[FlintMetadata] - when(metadata.getContent).thenReturn("{}") - when(metadata.indexSettings).thenReturn(None) - flintClient.createIndex("flint_test_1_index", metadata) - flintClient.createIndex("flint_test_2_index", metadata) - - val allMetadata = flintClient.getAllIndexMetadata("flint_*_index") - allMetadata should have size 2 - allMetadata.values.forEach(metadata => metadata.getContent should not be empty) - allMetadata.values.forEach(metadata => metadata.indexSettings should not be empty) - } - it should "convert index name to all lowercase" in { val indexName = "flint_ELB_logs_index" flintClient.createIndex( indexName, - FlintMetadata("""{"properties": {"test": { "type": "integer" } } }""")) + FlintOpenSearchIndexMetadataService.deserialize( + """{"properties": {"test": { "type": "integer" } } }""")) flintClient.exists(indexName) shouldBe true - flintClient.getIndexMetadata(indexName) should not be null - flintClient.getAllIndexMetadata("flint_ELB_*") should not be empty + flintIndexMetadataService.getIndexMetadata(indexName) should not be null + flintIndexMetadataService.getAllIndexMetadata("flint_ELB_*") should not be empty // Read write test val writer = flintClient.createWriter(indexName) @@ -156,11 +94,12 @@ class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with M val indexName = "test :\"+/\\|?#><" flintClient.createIndex( indexName, - FlintMetadata("""{"properties": {"test": { "type": "integer" } } }""")) + FlintOpenSearchIndexMetadataService.deserialize( + """{"properties": {"test": { "type": "integer" } } }""")) flintClient.exists(indexName) shouldBe true - flintClient.getIndexMetadata(indexName) should not be null - flintClient.getAllIndexMetadata("test *") should not be empty + flintIndexMetadataService.getIndexMetadata(indexName) should not be null + flintIndexMetadataService.getAllIndexMetadata("test *") should not be empty // Read write test val writer = flintClient.createWriter(indexName) @@ -268,6 +207,6 @@ class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with M } def createTable(indexName: String, options: FlintOptions): Table = { - OpenSearchCluster.apply(indexName, options).head + OpenSearchCluster.apply(indexName, options).asScala.head } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintOpenSearchIndexMetadataServiceITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintOpenSearchIndexMetadataServiceITSuite.scala new file mode 
100644 index 000000000..c5bd75951 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintOpenSearchIndexMetadataServiceITSuite.scala @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core + +import java.util + +import scala.collection.JavaConverters._ + +import org.opensearch.flint.OpenSearchSuite +import org.opensearch.flint.common.metadata.{FlintIndexMetadataService, FlintMetadata} +import org.opensearch.flint.core.metadata.FlintIndexMetadataServiceBuilder +import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintOpenSearchIndexMetadataService} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class FlintOpenSearchIndexMetadataServiceITSuite + extends AnyFlatSpec + with OpenSearchSuite + with Matchers { + + /** Lazy initialize after container started. */ + lazy val options = new FlintOptions(openSearchOptions.asJava) + lazy val flintClient = new FlintOpenSearchClient(options) + lazy val flintIndexMetadataService = new FlintOpenSearchIndexMetadataService(options) + + behavior of "Flint index metadata service builder" + + it should "build index metadata service" in { + val customOptions = + openSearchOptions + (FlintOptions.CUSTOM_FLINT_INDEX_METADATA_SERVICE_CLASS -> "org.opensearch.flint.core.TestIndexMetadataService") + val customFlintOptions = new FlintOptions(customOptions.asJava) + val customFlintIndexMetadataService = + FlintIndexMetadataServiceBuilder.build(customFlintOptions) + customFlintIndexMetadataService shouldBe a[TestIndexMetadataService] + } + + it should "fail to build index metadata service if class name doesn't exist" in { + val options = + openSearchOptions + (FlintOptions.CUSTOM_FLINT_INDEX_METADATA_SERVICE_CLASS -> "dummy") + val flintOptions = new FlintOptions(options.asJava) + the[RuntimeException] thrownBy { + FlintIndexMetadataServiceBuilder.build(flintOptions) + } + } + + behavior of "Flint OpenSearch index metadata service" + + it should "get all index metadata with the given index name pattern" in { + val metadata = FlintOpenSearchIndexMetadataService.deserialize("{}") + flintClient.createIndex("flint_test_1_index", metadata) + flintClient.createIndex("flint_test_2_index", metadata) + + val allMetadata = flintIndexMetadataService.getAllIndexMetadata("flint_*_index") + allMetadata should have size 2 + allMetadata.values.forEach(metadata => + FlintOpenSearchIndexMetadataService.serialize(metadata) should not be empty) + allMetadata.values.forEach(metadata => metadata.indexSettings should not be empty) + } + + it should "update index metadata successfully" in { + val indexName = "test_update" + val content = + """ { + | "_meta": { + | "kind": "test_kind" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + + val metadata = FlintOpenSearchIndexMetadataService.deserialize(content) + flintClient.createIndex(indexName, metadata) + + flintIndexMetadataService.getIndexMetadata(indexName).kind shouldBe empty + + flintIndexMetadataService.updateIndexMetadata(indexName, metadata) + + flintIndexMetadataService.getIndexMetadata(indexName).kind shouldBe "test_kind" + flintIndexMetadataService.getIndexMetadata(indexName).name shouldBe empty + + val newContent = + """ { + | "_meta": { + | "kind": "test_kind", + | "name": "test_name" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + + val newMetadata = 
FlintOpenSearchIndexMetadataService.deserialize(newContent) + flintIndexMetadataService.updateIndexMetadata(indexName, newMetadata) + + flintIndexMetadataService.getIndexMetadata(indexName).kind shouldBe "test_kind" + flintIndexMetadataService.getIndexMetadata(indexName).name shouldBe "test_name" + } +} + +class TestIndexMetadataService extends FlintIndexMetadataService { + override def getIndexMetadata(indexName: String): FlintMetadata = { + null + } + + override def getAllIndexMetadata(indexNamePattern: String*): util.Map[String, FlintMetadata] = { + null + } + + override def updateIndexMetadata(indexName: String, metadata: FlintMetadata): Unit = {} + + override def deleteIndexMetadata(indexName: String): Unit = {} +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala index fa7f75b81..1c4da53fb 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala @@ -10,8 +10,8 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.action.get.{GetRequest, GetResponse} import org.opensearch.client.RequestOptions import org.opensearch.flint.OpenSearchTransactionSuite +import org.opensearch.flint.common.model.InteractiveSession import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater} -import org.opensearch.flint.data.InteractiveSession import org.scalatest.matchers.should.Matchers class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala index 31b5c14b1..9c91a129e 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala @@ -8,7 +8,8 @@ package org.opensearch.flint.spark import java.util.Base64 import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson -import org.opensearch.flint.core.FlintVersion.current +import org.opensearch.flint.common.FlintVersion.current +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName import org.scalatest.matchers.must.Matchers.defined import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper @@ -47,7 +48,7 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { val index = flint.describeIndex(testFlintIndex) index shouldBe defined - index.get.metadata().getContent should matchJson(s"""{ + FlintOpenSearchIndexMetadataService.serialize(index.get.metadata()) should matchJson(s"""{ | "_meta": { | "version": "${current()}", | "name": "name_and_age", diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index ffd956b1c..235cab4d2 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -12,7 +12,7 @@ import 
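TestIndexMetadataService above stubs the new FlintIndexMetadataService seam with null returns. Where a unit test needs working behaviour without an OpenSearch cluster, an in-memory variant along these lines could be used; the method signatures mirror the overrides in this diff, while the wildcard handling is an assumption.

```scala
// Hedged, in-memory variant of the TestIndexMetadataService stub shown above.
import java.util

import scala.collection.JavaConverters._
import scala.collection.concurrent.TrieMap

import org.opensearch.flint.common.metadata.{FlintIndexMetadataService, FlintMetadata}

class InMemoryIndexMetadataService extends FlintIndexMetadataService {
  private val store = TrieMap.empty[String, FlintMetadata]

  override def getIndexMetadata(indexName: String): FlintMetadata =
    store.getOrElse(indexName, null)

  override def getAllIndexMetadata(indexNamePattern: String*): util.Map[String, FlintMetadata] = {
    // Naive '*' wildcard handling for illustration only.
    val regexes = indexNamePattern.map(p => p.replace("*", ".*").r)
    store
      .filter { case (name, _) => regexes.exists(_.pattern.matcher(name).matches()) }
      .toMap
      .asJava
  }

  override def updateIndexMetadata(indexName: String, metadata: FlintMetadata): Unit =
    store.put(indexName, metadata)

  override def deleteIndexMetadata(indexName: String): Unit =
    store.remove(indexName)
}
```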
org.json4s.{Formats, NoTypeHints} import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization import org.opensearch.flint.core.FlintOptions -import org.opensearch.flint.core.storage.FlintOpenSearchClient +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName import org.scalatest.matchers.must.Matchers.defined @@ -94,10 +94,12 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { |""".stripMargin) // Check if the index setting option is set to OS index setting - val flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)) + val flintIndexMetadataService = + new FlintOpenSearchIndexMetadataService(new FlintOptions(openSearchOptions.asJava)) implicit val formats: Formats = Serialization.formats(NoTypeHints) - val settings = parse(flintClient.getIndexMetadata(testFlintIndex).indexSettings.get) + val settings = + parse(flintIndexMetadataService.getIndexMetadata(testFlintIndex).indexSettings.get) (settings \ "index.number_of_shards").extract[String] shouldBe "2" (settings \ "index.number_of_replicas").extract[String] shouldBe "3" } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala index ad5029fcb..8d8311b11 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala @@ -9,9 +9,11 @@ import java.util.Base64 import java.util.concurrent.TimeUnit import scala.collection.JavaConverters.mapAsJavaMapConverter +import scala.concurrent.duration.{DurationInt, FiniteDuration} import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{doAnswer, spy} +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest import org.opensearch.client.RequestOptions import org.opensearch.flint.OpenSearchTransactionSuite @@ -148,25 +150,38 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc spark.streams.active.exists(_.name == testFlintIndex) shouldBe false } + test("monitor task and streaming job should terminate if data index is deleted") { + val task = FlintSparkIndexMonitor.indexMonitorTracker(testFlintIndex) + openSearchClient + .indices() + .delete(new DeleteIndexRequest(testFlintIndex), RequestOptions.DEFAULT) + + // Wait for index monitor execution and assert + waitForMonitorTaskRun() + task.isCancelled shouldBe true + spark.streams.active.exists(_.name == testFlintIndex) shouldBe false + + // Assert index state is still refreshing + val latestLog = latestLogEntry(testLatestId) + latestLog should contain("state" -> "refreshing") + } + test("await monitor terminated without exception should stay refreshing state") { // Setup a timer to terminate the streaming job - new Thread(() => { - Thread.sleep(3000L) + asyncAfter(3.seconds) { spark.streams.active.find(_.name == testFlintIndex).get.stop() - }).start() + } // Await until streaming job terminated flint.flintIndexMonitor.awaitMonitor() - // Assert index state is active now + // Assert index state is still refreshing val latestLog = latestLogEntry(testLatestId) latestLog should 
contain("state" -> "refreshing") } test("await monitor terminated with exception should update index state to failed with error") { - new Thread(() => { - Thread.sleep(3000L) - + asyncAfter(3.seconds) { // Set Flint index readonly to simulate streaming job exception val settings = Map("index.blocks.write" -> true) val request = new UpdateSettingsRequest(testFlintIndex).settings(settings.asJava) @@ -178,7 +193,7 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc | PARTITION (year=2023, month=6) | VALUES ('Test', 35, 'Vancouver') | """.stripMargin) - }).start() + } // Await until streaming job terminated flint.flintIndexMonitor.awaitMonitor() @@ -204,6 +219,19 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc latestLog should not contain "error" } + private def async(block: => Unit): Unit = { + new Thread(() => { + block + }).start() + } + + private def asyncAfter(delay: FiniteDuration)(block: => Unit): Unit = { + new Thread(() => { + Thread.sleep(delay.toMillis) + block + }).start() + } + private def getLatestTimestamp: (Long, Long) = { val latest = latestLogEntry(testLatestId) (latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long]) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkIndexSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkIndexSqlITSuite.scala index a5744271f..ce1cbb2ea 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkIndexSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkIndexSqlITSuite.scala @@ -217,4 +217,32 @@ class FlintSparkIndexSqlITSuite extends FlintSparkSuite with Matchers { deleteTestIndex(testSkippingFlintIndex) } } + + test("show flint index with special characters") { + val testCoveringIndexSpecial = "test :\"+/\\|?#><" + val testCoveringFlintIndexSpecial = + FlintSparkCoveringIndex.getFlintIndexName(testCoveringIndexSpecial, testTableQualifiedName) + + flint + .coveringIndex() + .name(testCoveringIndexSpecial) + .onTable(testTableQualifiedName) + .addIndexColumns("name", "age") + .options(FlintSparkIndexOptions(Map(AUTO_REFRESH.toString -> "true"))) + .create() + flint.refreshIndex(testCoveringFlintIndexSpecial) + + checkAnswer( + sql(s"SHOW FLINT INDEX IN spark_catalog"), + Seq( + Row( + testCoveringFlintIndexSpecial, + "covering", + "default", + testTableName, + testCoveringIndexSpecial, + true, + "refreshing"))) + deleteTestIndex(testCoveringFlintIndexSpecial) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala index f824aab73..605975af6 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala @@ -9,7 +9,8 @@ import java.sql.Timestamp import java.util.Base64 import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson -import org.opensearch.flint.core.FlintVersion.current +import org.opensearch.flint.common.FlintVersion.current +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName import org.scalatest.matchers.must.Matchers.defined import 
org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper @@ -59,7 +60,7 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { val index = flint.describeIndex(testFlintIndex) index shouldBe defined - index.get.metadata().getContent should matchJson(s""" + FlintOpenSearchIndexMetadataService.serialize(index.get.metadata()) should matchJson(s""" | { | "_meta": { | "version": "${current()}", diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index fc4cdbeac..66d6e0779 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -14,7 +14,7 @@ import org.json4s.{Formats, NoTypeHints} import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization import org.opensearch.flint.core.FlintOptions -import org.opensearch.flint.core.storage.FlintOpenSearchClient +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName import org.scalatest.matchers.must.Matchers.{defined, have} import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} @@ -25,8 +25,8 @@ import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { /** Test table, MV, index name and query */ - private val testTable = "spark_catalog.default.mv_test" - private val testMvName = "spark_catalog.default.mv_test_metrics" + private val testTable = s"$catalogName.default.mv_test" + private val testMvName = s"$catalogName.default.mv_test_metrics" private val testFlintIndex = getFlintIndexName(testMvName) private val testQuery = s""" @@ -152,10 +152,12 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { |""".stripMargin) // Check if the index setting option is set to OS index setting - val flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)) + val flintIndexMetadataService = + new FlintOpenSearchIndexMetadataService(new FlintOptions(openSearchOptions.asJava)) implicit val formats: Formats = Serialization.formats(NoTypeHints) - val settings = parse(flintClient.getIndexMetadata(testFlintIndex).indexSettings.get) + val settings = + parse(flintIndexMetadataService.getIndexMetadata(testFlintIndex).indexSettings.get) (settings \ "index.number_of_shards").extract[String] shouldBe "3" (settings \ "index.number_of_replicas").extract[String] shouldBe "2" } @@ -218,7 +220,11 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { } test("issue 112, https://github.com/opensearch-project/opensearch-spark/issues/112") { - val tableName = "spark_catalog.default.issue112" + if (tableType.equalsIgnoreCase("iceberg")) { + cancel + } + + val tableName = s"$catalogName.default.issue112" createTableIssue112(tableName) sql(s""" |CREATE MATERIALIZED VIEW $testMvName AS @@ -261,14 +267,14 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { test("create materialized view with quoted name and column name") { val testQuotedQuery = - """ SELECT + s""" SELECT | window.start AS `start.time`, | COUNT(*) AS `count` - | FROM `spark_catalog`.`default`.`mv_test` + | FROM `$catalogName`.`default`.`mv_test` | GROUP BY TUMBLE(`time`, 
'10 Minutes')""".stripMargin.trim sql(s""" - | CREATE MATERIALIZED VIEW `spark_catalog`.`default`.`mv_test_metrics` + | CREATE MATERIALIZED VIEW `$catalogName`.`default`.`mv_test_metrics` | AS $testQuotedQuery |""".stripMargin) @@ -303,34 +309,34 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { test("show all materialized views in catalog and database") { // Show in catalog - flint.materializedView().name("spark_catalog.default.mv1").query(testQuery).create() - checkAnswer(sql(s"SHOW MATERIALIZED VIEW IN spark_catalog"), Seq(Row("mv1"))) + flint.materializedView().name(s"$catalogName.default.mv1").query(testQuery).create() + checkAnswer(sql(s"SHOW MATERIALIZED VIEW IN $catalogName"), Seq(Row("mv1"))) // Show in catalog.database - flint.materializedView().name("spark_catalog.default.mv2").query(testQuery).create() + flint.materializedView().name(s"$catalogName.default.mv2").query(testQuery).create() checkAnswer( - sql(s"SHOW MATERIALIZED VIEW IN spark_catalog.default"), + sql(s"SHOW MATERIALIZED VIEW IN $catalogName.default"), Seq(Row("mv1"), Row("mv2"))) - checkAnswer(sql(s"SHOW MATERIALIZED VIEW IN spark_catalog.other"), Seq.empty) + checkAnswer(sql(s"SHOW MATERIALIZED VIEW IN $catalogName.other"), Seq.empty) deleteTestIndex( - getFlintIndexName("spark_catalog.default.mv1"), - getFlintIndexName("spark_catalog.default.mv2")) + getFlintIndexName(s"$catalogName.default.mv1"), + getFlintIndexName(s"$catalogName.default.mv2")) } test("show materialized view in database with the same prefix") { - flint.materializedView().name("spark_catalog.default.mv1").query(testQuery).create() - flint.materializedView().name("spark_catalog.default_test.mv2").query(testQuery).create() - checkAnswer(sql(s"SHOW MATERIALIZED VIEW IN spark_catalog.default"), Seq(Row("mv1"))) + flint.materializedView().name(s"$catalogName.default.mv1").query(testQuery).create() + flint.materializedView().name(s"$catalogName.default_test.mv2").query(testQuery).create() + checkAnswer(sql(s"SHOW MATERIALIZED VIEW IN $catalogName.default"), Seq(Row("mv1"))) deleteTestIndex( - getFlintIndexName("spark_catalog.default.mv1"), - getFlintIndexName("spark_catalog.default_test.mv2")) + getFlintIndexName(s"$catalogName.default.mv1"), + getFlintIndexName(s"$catalogName.default_test.mv2")) } test("should return emtpy when show materialized views in empty database") { - checkAnswer(sql(s"SHOW MATERIALIZED VIEW IN spark_catalog.other"), Seq.empty) + checkAnswer(sql(s"SHOW MATERIALIZED VIEW IN $catalogName.other"), Seq.empty) } test("describe materialized view") { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index 66e777dea..968f09345 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -10,7 +10,8 @@ import java.util.Base64 import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson import org.json4s.native.JsonMethods._ import org.opensearch.client.RequestOptions -import org.opensearch.flint.core.FlintVersion.current +import org.opensearch.flint.common.FlintVersion.current +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN import org.opensearch.flint.spark.skipping.FlintSparkSkippingFileIndex import 
org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName @@ -60,7 +61,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { val index = flint.describeIndex(testIndex) index shouldBe defined - index.get.metadata().getContent should matchJson(s"""{ + FlintOpenSearchIndexMetadataService.serialize(index.get.metadata()) should matchJson(s"""{ | "_meta": { | "name": "flint_spark_catalog_default_skipping_test_skipping_index", | "version": "${current()}", @@ -155,7 +156,11 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { val index = flint.describeIndex(testIndex) index shouldBe defined val optionJson = - compact(render(parse(index.get.metadata().getContent) \ "_meta" \ "options")) + compact( + render( + parse( + FlintOpenSearchIndexMetadataService.serialize( + index.get.metadata())) \ "_meta" \ "options")) optionJson should matchJson(s""" | { | "auto_refresh": "true", @@ -644,7 +649,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { val index = flint.describeIndex(testIndex) index shouldBe defined - index.get.metadata().getContent should matchJson(s"""{ + FlintOpenSearchIndexMetadataService.serialize(index.get.metadata()) should matchJson(s"""{ | "_meta": { | "name": "flint_spark_catalog_default_data_type_table_skipping_index", | "version": "${current()}", diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala index af497eb2b..ff114b8e2 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala @@ -13,7 +13,7 @@ import org.json4s.{Formats, NoTypeHints} import org.json4s.native.JsonMethods.{compact, parse, render} import org.json4s.native.Serialization import org.opensearch.flint.core.FlintOptions -import org.opensearch.flint.core.storage.FlintOpenSearchClient +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName import org.scalatest.matchers.must.Matchers.defined import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} @@ -25,8 +25,8 @@ import org.apache.spark.sql.flint.config.FlintSparkConf.CHECKPOINT_MANDATORY class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite with ExplainSuiteHelper { /** Test table and index name */ - private val testTable = "spark_catalog.default.skipping_sql_test" - private val testIndex = getSkippingIndexName(testTable) + protected val testTable = s"$catalogName.default.skipping_sql_test" + protected val testIndex = getSkippingIndexName(testTable) override def beforeEach(): Unit = { super.beforeAll() @@ -150,7 +150,8 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite with ExplainSuit |""".stripMargin)).foreach { case (query, expectedParamJson) => test(s"create skipping index with bloom filter parameters $expectedParamJson") { sql(query) - val metadata = flint.describeIndex(testIndex).get.metadata().getContent + val metadata = FlintOpenSearchIndexMetadataService.serialize( + flint.describeIndex(testIndex).get.metadata()) val parameters = compact(render(parse(metadata) \\ "parameters")) parameters should matchJson(expectedParamJson) } @@ -187,10 +188,11 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite with 
ExplainSuit |""".stripMargin) // Check if the index setting option is set to OS index setting - val flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)) + val flintIndexMetadataService = + new FlintOpenSearchIndexMetadataService(new FlintOptions(openSearchOptions.asJava)) implicit val formats: Formats = Serialization.formats(NoTypeHints) - val settings = parse(flintClient.getIndexMetadata(testIndex).indexSettings.get) + val settings = parse(flintIndexMetadataService.getIndexMetadata(testIndex).indexSettings.get) (settings \ "index.number_of_shards").extract[String] shouldBe "3" (settings \ "index.number_of_replicas").extract[String] shouldBe "2" } @@ -201,7 +203,9 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite with ExplainSuit "`struct_col`.`field1`.`subfield` VALUE_SET, `struct_col`.`field2` MIN_MAX").foreach { columnSkipTypes => test(s"build skipping index for nested field $columnSkipTypes") { - val testTable = "spark_catalog.default.nested_field_table" + assume(tableType != "iceberg", "ignore iceberg skipping index query rewrite test") + + val testTable = s"$catalogName.default.nested_field_table" val testIndex = getSkippingIndexName(testTable) withTable(testTable) { createStructTable(testTable) @@ -339,7 +343,7 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite with ExplainSuit test("create skipping index with quoted table and column name") { sql(s""" - | CREATE SKIPPING INDEX ON `spark_catalog`.`default`.`skipping_sql_test` + | CREATE SKIPPING INDEX ON `$catalogName`.`default`.`skipping_sql_test` | ( | `year` PARTITION, | `name` VALUE_SET, @@ -385,17 +389,26 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite with ExplainSuit sql("USE sample") // Create index without database name specified - sql("CREATE TABLE test1 (name STRING) USING CSV") + sql(s"CREATE TABLE test1 (name STRING) USING $tableType") sql("CREATE SKIPPING INDEX ON test1 (name VALUE_SET)") // Create index with database name specified - sql("CREATE TABLE test2 (name STRING) USING CSV") + sql(s"CREATE TABLE test2 (name STRING) USING $tableType") sql("CREATE SKIPPING INDEX ON sample.test2 (name VALUE_SET)") try { - flint.describeIndex("flint_spark_catalog_sample_test1_skipping_index") shouldBe defined - flint.describeIndex("flint_spark_catalog_sample_test2_skipping_index") shouldBe defined + flint.describeIndex(s"flint_${catalogName}_sample_test1_skipping_index") shouldBe defined + flint.describeIndex(s"flint_${catalogName}_sample_test2_skipping_index") shouldBe defined } finally { + + /** + * TODO: REMOVE DROP TABLE when iceberg support CASCADE. More reading at + * https://github.com/apache/iceberg/pull/7275. 
+ */ + if (tableType.equalsIgnoreCase("iceberg")) { + sql("DROP TABLE test1") + sql("DROP TABLE test2") + } sql("DROP DATABASE sample CASCADE") } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index bc5fe7999..f631a7e19 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -35,6 +35,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit lazy protected val flint: FlintSpark = new FlintSpark(spark) lazy protected val tableType: String = "CSV" lazy protected val tableOptions: String = "OPTIONS (header 'false', delimiter '\t')" + lazy protected val catalogName: String = "spark_catalog" override protected def sparkConf: SparkConf = { val conf = super.sparkConf @@ -69,7 +70,8 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit flint.deleteIndex(testIndex) flint.vacuumIndex(testIndex) } catch { - case _: IllegalStateException => + // Forcefully delete index data and log entry in case of any errors, such as version conflict + case _: Exception => if (openSearchClient .indices() .exists(new GetIndexRequest(testIndex), RequestOptions.DEFAULT)) { @@ -308,6 +310,40 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | """.stripMargin) } + protected def createDuplicationNullableTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | id INT, + | name STRING, + | category STRING + | ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES (1, "A", "X"), + | (2, "A", "Y"), + | (3, "A", "Y"), + | (4, "B", "Z"), + | (5, "B", "Z"), + | (6, "B", "Z"), + | (7, "C", "X"), + | (8, null, "Y"), + | (9, "D", "Z"), + | (10, "E", null), + | (11, "A", "X"), + | (12, "A", "Y"), + | (13, null, "X"), + | (14, "B", null), + | (15, "B", "Y"), + | (16, null, "Z"), + | (17, "C", "X"), + | (18, null, null) + | """.stripMargin) + } + protected def createTimeSeriesTable(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala index c42822f71..f2ed92adc 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala @@ -8,6 +8,7 @@ package org.opensearch.flint.spark import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson import org.json4s.native.JsonMethods._ import org.opensearch.client.RequestOptions +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName import org.opensearch.index.query.QueryBuilders import org.opensearch.index.reindex.DeleteByQueryRequest @@ -57,7 +58,11 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { // Verify index after update val indexFinal = flint.describeIndex(testIndex).get val optionJson = - compact(render(parse(indexFinal.metadata().getContent) \ "_meta" \ "options")) + compact( + render( + parse( + FlintOpenSearchIndexMetadataService.serialize( + indexFinal.metadata())) \ "_meta" \ 
"options")) optionJson should matchJson(s""" | { | "auto_refresh": "true", diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkWindowingFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkWindowingFunctionITSuite.scala index 4cba9099c..79e70655b 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkWindowingFunctionITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkWindowingFunctionITSuite.scala @@ -7,11 +7,12 @@ package org.opensearch.flint.spark import java.sql.Timestamp +import org.scalatest.Ignore import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.FlintSuite import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} class FlintSparkWindowingFunctionITSuite extends QueryTest with FlintSuite { @@ -26,8 +27,21 @@ class FlintSparkWindowingFunctionITSuite extends QueryTest with FlintSuite { val resultDF = inputDF.selectExpr("TUMBLE(timestamp, '10 minutes')") - resultDF.schema shouldBe StructType.fromDDL( - "window struct NOT NULL") + // Since Spark 3.4. https://issues.apache.org/jira/browse/SPARK-40821 + val expected = + StructType(StructType.fromDDL("window struct NOT NULL").map { + case StructField(name, dataType: StructType, nullable, _) if name == "window" => + StructField( + name, + dataType, + nullable, + metadata = new MetadataBuilder() + .putBoolean("spark.timeWindow", true) + .build()) + case other => other + }) + + resultDF.schema shouldBe expected checkAnswer( resultDF, Seq( diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergSuite.scala index 2ae0d157a..e7ce5316b 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergSuite.scala @@ -22,13 +22,20 @@ trait FlintSparkIcebergSuite extends FlintSparkSuite { // You can also override tableOptions if Iceberg requires different options override lazy protected val tableOptions: String = "" + override lazy protected val catalogName: String = "local" + // Override the sparkConf method to include Iceberg-specific configurations override protected def sparkConf: SparkConf = { val conf = super.sparkConf // Set Iceberg-specific Spark configurations .set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog") - .set("spark.sql.catalog.spark_catalog.type", "hadoop") - .set("spark.sql.catalog.spark_catalog.warehouse", s"spark-warehouse/${suiteName}") + .set("spark.sql.catalog.spark_catalog.type", "hive") + .set(s"spark.sql.catalog.$catalogName", "org.apache.iceberg.spark.SparkCatalog") + .set(s"spark.sql.catalog.$catalogName.type", "hadoop") + // Required by IT(create skipping index on table without database name) + .set(s"spark.sql.catalog.$catalogName.default-namespace", "default") + .set(s"spark.sql.catalog.$catalogName.warehouse", s"spark-warehouse/${suiteName}") + .set(s"spark.sql.defaultCatalog", s"$catalogName") .set( "spark.sql.extensions", List( diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala 
b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala index b3abf8438..1e80c94b4 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, Divide, Floor, Literal, Multiply, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, Divide, EqualTo, Floor, Literal, Multiply, Not, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -262,4 +262,93 @@ class FlintSparkPPLAggregationWithSpanITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + + /** + * | age_span | age_stddev_samp | + * |:---------|-------------------:| + * | 20 | 3.5355339059327378 | + */ + test( + "create ppl age sample stddev by span of interval of 10 years query with country filter test ") { + val frame = sql(s""" + | source = $testTable | where country != 'USA' | stats stddev_samp(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(3.5355339059327378d, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val countryField = UnresolvedAttribute("country") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false), + "stddev_samp(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val filterExpr = Not(EqualTo(countryField, Literal("USA"))) + val filterPlan = Filter(filterExpr, table) + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + /** + * | age_span | age_stddev_pop | + * |:---------|---------------:| + * | 20 | 2.5 | + * | 30 | 0 | + */ + test( + "create ppl age population stddev by span of interval of 10 years query with state filter test ") { + val frame = sql(s""" + | source = $testTable | where state != 'California' | stats stddev_pop(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2.5d, 20L), Row(0d, 30L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical 
plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val stateField = UnresolvedAttribute("state") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false), + "stddev_pop(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val filterExpr = Not(EqualTo(stateField, Literal("California"))) + val filterPlan = Filter(filterExpr, table) + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala index 3bc227e7d..4f9d4c64e 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala @@ -380,4 +380,239 @@ class FlintSparkPPLAggregationsITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + + test("create ppl age sample stddev") { + val frame = sql(s""" + | source = $testTable| stats stddev_samp(age) + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(22.86737122335374d)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false), + "stddev_samp(age)")() + val aggregatePlan = + Aggregate(Seq.empty, Seq(aggregateExpressions), table) + val expectedPlan = Project(star, aggregatePlan) + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl age sample stddev group by country query test with sort") { + val frame = sql(s""" + | source = $testTable | stats stddev_samp(age) by country | sort country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row(3.5355339059327378d, "Canada"), Row(28.284271247461902d, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, 
but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false), + "stddev_samp(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl age sample stddev group by country with state filter query test") { + val frame = sql(s""" + | source = $testTable | where state != 'Ontario' | stats stddev_samp(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(null, "Canada"), Row(28.284271247461902d, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val filterExpr = Not(EqualTo(stateField, Literal("Ontario"))) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false), + "stddev_samp(age)")() + val productAlias = Alias(countryField, "country")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl age population stddev") { + val frame = sql(s""" + | source = $testTable| stats stddev_pop(age) + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(19.803724397193573d)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") 
+ val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false), + "stddev_pop(age)")() + val aggregatePlan = + Aggregate(Seq.empty, Seq(aggregateExpressions), table) + val expectedPlan = Project(star, aggregatePlan) + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl age population stddev group by country query test with sort") { + val frame = sql(s""" + | source = $testTable | stats stddev_pop(age) by country | sort country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2.5d, "Canada"), Row(20d, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false), + "stddev_pop(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl age population stddev group by country with state filter query test") { + val frame = sql(s""" + | source = $testTable | where state != 'Ontario' | stats stddev_pop(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(0d, "Canada"), Row(20d, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val filterExpr = Not(EqualTo(stateField, Literal("Ontario"))) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), 
Seq(ageField), isDistinct = false), + "stddev_pop(age)")() + val productAlias = Alias(countryField, "country")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala index fc77b7156..7d51e123d 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala @@ -6,9 +6,11 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar, UnresolvedTableOrView} import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.command.DescribeTableCommand import org.apache.spark.sql.streaming.StreamTest class FlintSparkPPLBasicITSuite @@ -36,6 +38,51 @@ class FlintSparkPPLBasicITSuite } } + test("describe (extended) table query test") { + val testTableQuoted = "`spark_catalog`.`default`.`flint_ppl_test`" + Seq(testTable, testTableQuoted).foreach { table => + val frame = sql(s""" + describe flint_ppl_test + """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("name", "string", null), + Row("age", "int", null), + Row("state", "string", null), + Row("country", "string", null), + Row("year", "int", null), + Row("month", "int", null), + Row("# Partition Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("year", "int", null), + Row("month", "int", null)) + + // Convert actual results to a Set for quick lookup + val resultsSet: Set[Row] = results.toSet + // Check that each expected row is present in the actual results + expectedResults.foreach { expectedRow => + assert( + resultsSet.contains(expectedRow), + s"Expected row $expectedRow not found in results") + } + // Retrieve the logical plan + val logicalPlan: LogicalPlan = + frame.queryExecution.commandExecuted.asInstanceOf[CommandResult].commandLogicalPlan + // Define the expected logical plan + val expectedPlan: LogicalPlan = + DescribeTableCommand( + TableIdentifier("flint_ppl_test"), + Map.empty[String, String], + isExtended = true, + output = DescribeRelation.getOutputAttrs) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + } + test("create ppl simple query test") { val testTableQuoted = "`spark_catalog`.`default`.`flint_ppl_test`" Seq(testTable, testTableQuoted).foreach { table => @@ -208,7 +255,7 @@ class FlintSparkPPLBasicITSuite val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan) - val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan); + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) // Compare the two plans 
assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala new file mode 100644 index 000000000..06c90527d --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala @@ -0,0 +1,310 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{And, IsNotNull, IsNull, Or} +import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Filter, LogicalPlan, Project, Union} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLDedupITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createDuplicationNullableTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test dedupe 1 name") { + val frame = sql(s""" + | source = $testTable | dedup 1 name | fields name + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("A"), Row("B"), Row("C"), Row("D"), Row("E")) + implicit val oneColRowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq(UnresolvedAttribute("name")) + val dedupKeys = Seq(UnresolvedAttribute("name")) + val filter = Filter(IsNotNull(UnresolvedAttribute("name")), table) + val expectedPlan = Project(fieldsProjectList, Deduplicate(dedupKeys, filter)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test dedupe 1 name, category") { + val frame = sql(s""" + | source = $testTable | dedup 1 name, category | fields name, category + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "Y"), + Row("B", "Z"), + Row("C", "X"), + Row("D", "Z"), + Row("B", "Y")) + implicit val twoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => (row.getAs(0), row.getAs(1))) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("category")) + val dedupKeys = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("category")) + val filter = Filter( + And(IsNotNull(UnresolvedAttribute("name")), IsNotNull(UnresolvedAttribute("category"))), + table) + val expectedPlan = Project(fieldsProjectList, Deduplicate(dedupKeys, filter)) + comparePlans(logicalPlan, expectedPlan, 
checkAnalysis = false) + } + + test("test dedupe 1 name KEEPEMPTY=true") { + val frame = sql(s""" + | source = $testTable | dedup 1 name KEEPEMPTY=true | fields name, category + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("B", "Z"), + Row("C", "X"), + Row("D", "Z"), + Row("E", null), + Row(null, "Y"), + Row(null, "X"), + Row(null, "Z"), + Row(null, null)) + implicit val nullableTwoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => { + val value0 = row.getAs[String](0) + val value1 = row.getAs[String](1) + ( + if (value0 == null) String.valueOf(Int.MaxValue) else value0, + if (value1 == null) String.valueOf(Int.MaxValue) else value1) + }) + assert( + results.sorted + .map(_.getAs[String](0)) + .sameElements(expectedResults.sorted.map(_.getAs[String](0)))) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val fieldsProjectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("category")) + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val isNotNullFilter = + Filter(IsNotNull(UnresolvedAttribute("name")), table) + val deduplicate = Deduplicate(Seq(UnresolvedAttribute("name")), isNotNullFilter) + val isNullFilter = Filter(IsNull(UnresolvedAttribute("name")), table) + val union = Union(deduplicate, isNullFilter) + val expectedPlan = Project(fieldsProjectList, union) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test dedupe 1 name, category KEEPEMPTY=true") { + val frame = sql(s""" + | source = $testTable | dedup 1 name, category KEEPEMPTY=true | fields name, category + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "Y"), + Row("B", "Z"), + Row("C", "X"), + Row("D", "Z"), + Row("B", "Y"), + Row(null, "Y"), + Row("E", null), + Row(null, "X"), + Row("B", null), + Row(null, "Z"), + Row(null, null)) + implicit val nullableTwoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => { + val value0 = row.getAs[String](0) + val value1 = row.getAs[String](1) + ( + if (value0 == null) String.valueOf(Int.MaxValue) else value0, + if (value1 == null) String.valueOf(Int.MaxValue) else value1) + }) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("category")) + val isNotNullFilter = Filter( + And(IsNotNull(UnresolvedAttribute("name")), IsNotNull(UnresolvedAttribute("category"))), + table) + val deduplicate = Deduplicate( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("category")), + isNotNullFilter) + val isNullFilter = Filter( + Or(IsNull(UnresolvedAttribute("name")), IsNull(UnresolvedAttribute("category"))), + table) + val union = Union(deduplicate, isNullFilter) + val expectedPlan = Project(fieldsProjectList, union) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test 1 name CONSECUTIVE=true") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable | dedup 1 name CONSECUTIVE=true | fields name + | """.stripMargin)) + assert(ex.getMessage.contains("Consecutive deduplication is not supported")) + } + + test("test 1 name 
KEEPEMPTY=true CONSECUTIVE=true") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable | dedup 1 name KEEPEMPTY=true CONSECUTIVE=true | fields name + | """.stripMargin)) + assert(ex.getMessage.contains("Consecutive deduplication is not supported")) + } + + ignore("test dedupe 2 name") { + val frame = sql(s""" + | source = $testTable| dedup 2 name | fields name + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = + Array(Row("A"), Row("A"), Row("B"), Row("B"), Row("C"), Row("C"), Row("D"), Row("E")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + ignore("test dedupe 2 name, category") { + val frame = sql(s""" + | source = $testTable| dedup 2 name, category | fields name, category + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "X"), + Row("A", "Y"), + Row("A", "Y"), + Row("B", "Y"), + Row("B", "Z"), + Row("B", "Z"), + Row("C", "X"), + Row("C", "X"), + Row("D", "Z")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](row => { + val value = row.getAs[String](0) + if (value == null) String.valueOf(Int.MaxValue) else value + }) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + ignore("test dedupe 2 name KEEPEMPTY=true") { + val frame = sql(s""" + | source = $testTable| dedup 2 name KEEPEMPTY=true | fields name, category + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "Y"), + Row("B", "Z"), + Row("B", "Z"), + Row("C", "X"), + Row("C", "X"), + Row("D", "Z"), + Row("E", null), + Row(null, "Y"), + Row(null, "X"), + Row(null, "Z"), + Row(null, null)) + implicit val nullableTwoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => { + val value0 = row.getAs[String](0) + val value1 = row.getAs[String](1) + ( + if (value0 == null) String.valueOf(Int.MaxValue) else value0, + if (value1 == null) String.valueOf(Int.MaxValue) else value1) + }) + assert( + results.sorted + .map(_.getAs[String](0)) + .sameElements(expectedResults.sorted.map(_.getAs[String](0)))) + } + + ignore("test dedupe 2 name, category KEEPEMPTY=true") { + val frame = sql(s""" + | source = $testTable| dedup 2 name, category KEEPEMPTY=true | fields name, category + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "X"), + Row("A", "Y"), + Row("A", "Y"), + Row("B", "Y"), + Row("B", "Z"), + Row("B", "Z"), + Row("C", "X"), + Row("C", "X"), + Row("D", "Z"), + Row(null, "Y"), + Row("E", null), + Row(null, "X"), + Row("B", null), + Row(null, "Z"), + Row(null, null)) + implicit val nullableTwoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => { + val value0 = row.getAs[String](0) + val value1 = row.getAs[String](1) + ( + if (value0 == null) String.valueOf(Int.MaxValue) else value0, + if (value1 == null) String.valueOf(Int.MaxValue) else value1) + }) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test 2 name CONSECUTIVE=true") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable | dedup 2 name CONSECUTIVE=true | fields 
name + | """.stripMargin)) + assert(ex.getMessage.contains("Consecutive deduplication is not supported")) + } + + test("test 2 name KEEPEMPTY=true CONSECUTIVE=true") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable | dedup 2 name KEEPEMPTY=true CONSECUTIVE=true | fields name + | """.stripMargin)) + assert(ex.getMessage.contains("Consecutive deduplication is not supported")) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala index 407c2cb3b..ea77ff990 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala @@ -171,14 +171,14 @@ class FlintSparkPPLEvalITSuite val ex = intercept[AnalysisException](sql(s""" | source = $testTable | eval age = 40 | eval name = upper(name) | sort name | fields name, age, state | """.stripMargin)) - assert(ex.getMessage().contains("Reference 'name' is ambiguous")) + assert(ex.getMessage().contains("Reference `name` is ambiguous")) } test("test overriding existing fields: throw exception when specify the new field in where") { val ex = intercept[AnalysisException](sql(s""" | source = $testTable | eval age = abs(age) | where age < 50 | """.stripMargin)) - assert(ex.getMessage().contains("Reference 'age' is ambiguous")) + assert(ex.getMessage().contains("Reference `age` is ambiguous")) } test( @@ -186,7 +186,7 @@ class FlintSparkPPLEvalITSuite val ex = intercept[AnalysisException](sql(s""" | source = $testTable | eval age = abs(age) | stats avg(age) | """.stripMargin)) - assert(ex.getMessage().contains("Reference 'age' is ambiguous")) + assert(ex.getMessage().contains("Reference `age` is ambiguous")) } test( @@ -194,7 +194,7 @@ class FlintSparkPPLEvalITSuite val ex = intercept[AnalysisException](sql(s""" | source = $testTable | eval country = upper(country) | stats avg(age) by country | """.stripMargin)) - assert(ex.getMessage().contains("Reference 'country' is ambiguous")) + assert(ex.getMessage().contains("Reference `country` is ambiguous")) } test("test override existing fields: the eval field doesn't appear in fields command") { @@ -480,12 +480,7 @@ class FlintSparkPPLEvalITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } - // +--------------------------------+ - // | Below tests are not supported | - // +--------------------------------+ - // Todo: Upgrading spark version to 3.4.0 and above could fix this test. 
- // https://issues.apache.org/jira/browse/SPARK-27561 - ignore("test lateral eval expressions references - SPARK-27561 required") { + test("test lateral eval expressions references") { val frame = sql(s""" | source = $testTable | eval col1 = 1, col2 = col1 | fields name, age, col2 | """.stripMargin) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/sessioncatalog/FlintSessionCatalogSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/sessioncatalog/FlintSessionCatalogSuite.scala new file mode 100644 index 000000000..fb3d4bbda --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/sessioncatalog/FlintSessionCatalogSuite.scala @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.sessioncatalog + +import org.opensearch.flint.spark.FlintSparkSuite + +import org.apache.spark.SparkConf + +/** + * Test with FlintDelegatingSessionCatalog. + */ +trait FlintSessionCatalogSuite extends FlintSparkSuite { + // Override catalog name + override lazy protected val catalogName: String = "mycatalog" + + override protected def sparkConf: SparkConf = { + val conf = super.sparkConf + .set("spark.sql.catalog.mycatalog", "org.opensearch.sql.FlintDelegatingSessionCatalog") + .set("spark.sql.defaultCatalog", catalogName) + conf + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/sessioncatalog/FlintSparkSessionCatalogMaterializedViewITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/sessioncatalog/FlintSparkSessionCatalogMaterializedViewITSuite.scala new file mode 100644 index 000000000..25567ba4e --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/sessioncatalog/FlintSparkSessionCatalogMaterializedViewITSuite.scala @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.sessioncatalog + +import org.opensearch.flint.spark.FlintSparkMaterializedViewSqlITSuite + +/** + * Test MaterializedView with FlintDelegatingSessionCatalog. + */ +class FlintSparkSessionCatalogMaterializedViewITSuite + extends FlintSparkMaterializedViewSqlITSuite + with FlintSessionCatalogSuite {} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/sessioncatalog/FlintSparkSessionCatalogSkippingIndexITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/sessioncatalog/FlintSparkSessionCatalogSkippingIndexITSuite.scala new file mode 100644 index 000000000..7b29d5883 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/sessioncatalog/FlintSparkSessionCatalogSkippingIndexITSuite.scala @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.sessioncatalog + +import org.opensearch.flint.spark.FlintSparkSkippingIndexSqlITSuite + +/** + * Test Skipping Index with FlintDelegatingSessionCatalog. + */ +class FlintSparkSessionCatalogSkippingIndexITSuite + extends FlintSparkSkippingIndexSqlITSuite + with FlintSessionCatalogSuite {} diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index 1538f43be..67cecd48d 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -221,8 +221,11 @@ Next tasks ahead will resolve this: This section describes the next steps planned for enabling additional commands and gamer translation. 
-#### Supported -The next samples of PPL queries are currently supported: +#### Example PPL Queries +See the next samples of PPL queries : + +**Describe** + - `describe table` This command is equal to the `DESCRIBE EXTENDED table` SQL command **Fields** - `source = table` @@ -249,7 +252,7 @@ Assumptions: `a`, `b`, `c` are existing fields in `table` - `source = table | eval n = now() | eval t = unix_timestamp(a) | fields n,t` - `source = table | eval f = a | where f > 1 | sort f | fields a,b,c | head 5` - `source = table | eval f = a * 2 | eval h = f * 2 | fields a,f,h` - - `source = table | eval f = a * 2, h = f * 2 | fields a,f,h` (Spark 3.4.0+ required) + - `source = table | eval f = a * 2, h = f * 2 | fields a,f,h` - `source = table | eval f = a * 2, h = b | stats avg(f) by h` Limitation: Overriding existing field is unsupported, following queries throw exceptions with "Reference 'a' is ambiguous" @@ -262,6 +265,8 @@ Limitation: Overriding existing field is unsupported, following queries throw ex - `source = table | where a < 50 | stats avg(c) ` - `source = table | stats max(c) by b` - `source = table | stats count(c) by b | head 5` + - `source = table | stats stddev_samp(c)` + - `source = table | stats stddev_pop(c)` **Aggregations With Span** - `source = table | stats count(a) by span(a, 10) as a_span` @@ -272,31 +277,31 @@ Limitation: Overriding existing field is unsupported, following queries throw ex - `source = table | stats sum(productsAmount) by span(transactionDate, 1d) as age_date | sort age_date` - `source = table | stats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId` -> For additional details, review [FlintSparkPPLTimeWindowITSuite.scala](../integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTimeWindowITSuite.scala) +**Dedup** + +- `source = table | dedup a | fields a,b,c` +- `source = table | dedup a,b | fields a,b,c` +- `source = table | dedup a keepempty=true | fields a,b,c` +- `source = table | dedup a,b keepempty=true | fields a,b,c` +- `source = table | dedup 1 a | fields a,b,c` +- `source = table | dedup 1 a,b | fields a,b,c` +- `source = table | dedup 1 a keepempty=true | fields a,b,c` +- `source = table | dedup 1 a,b keepempty=true | fields a,b,c` +- `source = table | dedup 1 a consecutive=true| fields a,b,c` (Unsupported) +- `source = table | dedup 2 a | fields a,b,c` (Unsupported) -#### Supported Commands: - - `search` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/search.rst) - - `where` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/where.rst) - - `fields` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/fields.rst) - - `eval` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/eval.rst) - - `head` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/head.rst) - - `stats` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/stats.rst) (supports AVG, COUNT, DISTINCT_COUNT, MAX, MIN and SUM aggregation functions) - - `sort` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/sort.rst) - - `correlation` - [See details](../docs/PPL-Correlation-command.md) -> For additional details, review [Integration Tests](../integ-test/src/test/scala/org/opensearch/flint/spark/) - +For additional details on PPL commands - view [PPL Commands 
Docs](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/index.rst) + --- -#### Planned Support +For additional details on Spark PPL commands project, see [PPL Project](https://github.com/orgs/opensearch-project/projects/214/views/2) +For additional details on Spark PPL commands support campaign, see [PPL Commands Campaign](https://github.com/opensearch-project/opensearch-spark/issues/408) + +#### Experimental Commands: + - `correlation` - [See details](../docs/PPL-Correlation-command.md) - - support the `explain` command to return the explained PPL query logical plan and expected execution plan +> This is an experimental command - it may be removed in future versions - - attend [sort](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/sort.rst) partially supported, missing capability to sort by alias field (span like or aggregation) - - attend `alias` - partially supported, missing capability to sort by / group-by alias field name - - add [conditions](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/condition.rst) support - - add [top](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/top.rst) support - - add [cast](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/conversion.rst) support - - add [math](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/math.rst) support - - add [deduplicate](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/dedup.rst) support \ No newline at end of file + \ No newline at end of file diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 2d0986890..6f56550c9 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -27,6 +27,7 @@ queryStatement // commands pplCommands : searchCommand + | describeCommand ; commands @@ -34,6 +35,7 @@ commands | correlateCommand | fieldsCommand | statsCommand + | dedupCommand | sortCommand | headCommand | evalCommand @@ -232,6 +234,8 @@ statsFunctionName | SUM | MIN | MAX + | STDDEV_SAMP + | STDDEV_POP ; takeAggFunction diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java new file mode 100644 index 000000000..5fd237bcb --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +/** + * Extend Relation to describe the table itself + */ +public class DescribeRelation extends Relation{ + public DescribeRelation(UnresolvedExpression tableName) { + super(tableName); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 2f601e56f..f361f53d1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -5,6 +5,7 @@ package org.opensearch.sql.ppl; +import org.apache.spark.sql.catalyst.TableIdentifier; import 
org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; @@ -14,8 +15,14 @@ import org.apache.spark.sql.catalyst.expressions.RegExpExtract; import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; +import org.apache.spark.sql.catalyst.plans.logical.Deduplicate; +import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; import org.apache.spark.sql.catalyst.plans.logical.Limit; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.execution.command.DescribeTableCommand; +import org.apache.spark.sql.catalyst.plans.logical.Union; +import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.opensearch.sql.ast.AbstractNodeVisitor; @@ -47,6 +54,7 @@ import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Correlation; import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; @@ -61,6 +69,7 @@ import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.SortUtils; import scala.Option; +import scala.Option$; import scala.collection.Seq; import java.util.ArrayList; @@ -109,6 +118,26 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { @Override public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { + if (node instanceof DescribeRelation) { + TableIdentifier identifier; + if (node.getTableQualifiedName().getParts().size() == 1) { + identifier = new TableIdentifier(node.getTableQualifiedName().getParts().get(0)); + } else if (node.getTableQualifiedName().getParts().size() == 2) { + identifier = new TableIdentifier( + node.getTableQualifiedName().getParts().get(1), + Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0))); + } else { + throw new IllegalArgumentException("Invalid table name: " + node.getTableQualifiedName() + + " Syntax: [ database_name. 
] table_name"); + } + return context.with( + new DescribeTableCommand( + identifier, + scala.collection.immutable.Map$.MODULE$.empty(), + true, + DescribeRelation$.MODULE$.getOutputAttrs())); + } + //regular sql algebraic relations node.getTableName().forEach(t -> // Resolving the qualifiedName which is composed of a datasource.schema.table context.with(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false)) @@ -300,7 +329,105 @@ public LogicalPlan visitWindowFunction(WindowFunction node, CatalystPlanContext @Override public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : dedupe "); + node.getChild().get(0).accept(this, context); + List options = node.getOptions(); + Integer allowedDuplication = (Integer) options.get(0).getValue().getValue(); + Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue(); + Boolean consecutive = (Boolean) options.get(2).getValue().getValue(); + if (allowedDuplication <= 0) { + throw new IllegalArgumentException("Number of duplicate events must be greater than 0"); + } + if (consecutive) { + // Spark is not able to remove only consecutive events + throw new UnsupportedOperationException("Consecutive deduplication is not supported"); + } + visitFieldList(node.getFields(), context); + // Columns to deduplicate + Seq dedupFields + = context.retainAllNamedParseExpressions(e -> (org.apache.spark.sql.catalyst.expressions.Attribute) e); + // Although we can also use the Window operator to translate this as allowedDuplication > 1 did, + // adding Aggregate operator could achieve better performance. + if (allowedDuplication == 1) { + if (keepEmpty) { + // Union + // :- Deduplicate ['a, 'b] + // : +- Filter (isnotnull('a) AND isnotnull('b) + // : +- Project + // : +- UnresolvedRelation + // +- Filter (isnull('a) OR isnull('a)) + // +- Project + // +- UnresolvedRelation + + context.apply(p -> { + Expression isNullExpr = buildIsNullFilterExpression(node, context); + LogicalPlan right = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNullExpr, p); + + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); + LogicalPlan left = + new Deduplicate(dedupFields, + new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); + return new Union(seq(left, right), false, false); + }); + return context.getPlan(); + } else { + // Deduplicate ['a, 'b] + // +- Filter (isnotnull('a) AND isnotnull('b)) + // +- Project + // +- UnresolvedRelation + + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); + // Todo DeduplicateWithinWatermark in streaming dataset? 
+ return context.apply(p -> new Deduplicate(dedupFields, p)); + } + } else { + // TODO + throw new UnsupportedOperationException("Number of duplicate events greater than 1 is not supported"); + } + } + + private Expression buildIsNotNullFilterExpression(Dedupe node, CatalystPlanContext context) { + visitFieldList(node.getFields(), context); + Seq isNotNullExpressions = + context.retainAllNamedParseExpressions( + org.apache.spark.sql.catalyst.expressions.IsNotNull$.MODULE$::apply); + + Expression isNotNullExpr; + if (isNotNullExpressions.size() == 1) { + isNotNullExpr = isNotNullExpressions.apply(0); + } else { + isNotNullExpr = isNotNullExpressions.reduce( + new scala.Function2() { + @Override + public Expression apply(Expression e1, Expression e2) { + return new org.apache.spark.sql.catalyst.expressions.And(e1, e2); + } + } + ); + } + return isNotNullExpr; + } + + private Expression buildIsNullFilterExpression(Dedupe node, CatalystPlanContext context) { + visitFieldList(node.getFields(), context); + Seq isNullExpressions = + context.retainAllNamedParseExpressions( + org.apache.spark.sql.catalyst.expressions.IsNull$.MODULE$::apply); + + Expression isNullExpr; + if (isNullExpressions.size() == 1) { + isNullExpr = isNullExpressions.apply(0); + } else { + isNullExpr = isNullExpressions.reduce( + new scala.Function2() { + @Override + public Expression apply(Expression e1, Expression e2) { + return new org.apache.spark.sql.catalyst.expressions.Or(e1, e2); + } + } + ); + } + return isNullExpr; } /** diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 9973f4676..e94d4e0f4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -28,6 +28,7 @@ import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Correlation; import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; @@ -97,7 +98,7 @@ public UnresolvedPlan visitDescribeCommand(OpenSearchPPLParser.DescribeCommandCo final Relation table = (Relation) visitTableSourceClause(ctx.tableSourceClause()); QualifiedName tableQualifiedName = table.getTableQualifiedName(); ArrayList parts = new ArrayList<>(tableQualifiedName.getParts()); - return new Relation(new QualifiedName(parts)); + return new DescribeRelation(new QualifiedName(parts)); } /** Where command. 
*/ diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java index 23ca992d9..7792dbecd 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java @@ -15,6 +15,7 @@ import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; +import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.UnresolvedPlan; @@ -78,11 +79,13 @@ public Object build() { } } - private UnresolvedPlan addSelectAll(UnresolvedPlan plan) { - if ((plan instanceof Project) && !((Project) plan).isExcluded()) { - return plan; - } else { - return new Project(ImmutableList.of(AllFields.of())).attach(plan); + private UnresolvedPlan addSelectAll(UnresolvedPlan plan) { + if ((plan instanceof Project) && !((Project) plan).isExcluded()) { + return plan; + } else if (plan instanceof DescribeRelation) { + return plan; + } else { + return new Project(ImmutableList.of(AllFields.of())).attach(plan); + } } - } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java index e15324cc0..eba60248d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java @@ -35,6 +35,10 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction return new UnresolvedFunction(seq("COUNT"), seq(arg),false, empty(),false); case SUM: return new UnresolvedFunction(seq("SUM"), seq(arg),false, empty(),false); + case STDDEV_POP: + return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + case STDDEV_SAMP: + return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala index 332dabc95..51618d487 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.types.{DataType, StructType} class FlintSparkPPLParser(sparkParser: ParserInterface) extends ParserInterface { /** OpenSearch (PPL) AST builder. 
*/ - private val planTrnasormer = new CatalystQueryPlanVisitor() + private val planTransformer = new CatalystQueryPlanVisitor() private val pplParser = new PPLSyntaxParser() @@ -55,7 +55,7 @@ class FlintSparkPPLParser(sparkParser: ParserInterface) extends ParserInterface try { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext - planTrnasormer.visit(plan(pplParser, sqlText, false), context) + planTransformer.visit(plan(pplParser, sqlText, false), context) context.getPlan } catch { // Fall back to Spark parse plan logic if flint cannot parse diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index ba634cc1c..61190294b 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -372,4 +372,226 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test price sample stddev group by product sorted") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_samp(price) by product | sort product", + false), + context) + val star = Seq(UnresolvedStar(None)) + val priceField = UnresolvedAttribute("price") + val productField = UnresolvedAttribute("product") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(priceField), isDistinct = false), + "stddev_samp(price)")() + val productAlias = Alias(productField, "product")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(productField, Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test price sample stddev with alias and filter") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table category = 'vegetable' | stats stddev_samp(price) as dev_samp", + false), + context) + val star = Seq(UnresolvedStar(None)) + val categoryField = UnresolvedAttribute("category") + val priceField = UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(priceField), isDistinct = false), + "dev_samp")()) + val filterExpr = EqualTo(categoryField, Literal("vegetable")) + val filterPlan = Filter(filterExpr, tableRelation) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test age sample stddev by span of interval of 5 years query with sort ") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_samp(age) by span(age, 5) as age_span | sort age", + false), + context) + // Define the expected logical plan 
+ val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false), + "stddev_samp(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(5))), Literal(5)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test number of flights sample stddev by airport with alias and limit") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_samp(no_of_flights) as dev_samp_flights by airport | head 10", + false), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val numberOfFlightsField = UnresolvedAttribute("no_of_flights") + val airportField = UnresolvedAttribute("airport") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(airportField, "airport")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(numberOfFlightsField), isDistinct = false), + "dev_samp_flights")() + val airportAlias = Alias(airportField, "airport")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, airportAlias), tableRelation) + val planWithLimit = GlobalLimit(Literal(10), LocalLimit(Literal(10), aggregatePlan)) + val expectedPlan = Project(star, planWithLimit) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test price population stddev group by product sorted") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_pop(price) by product | sort product", + false), + context) + val star = Seq(UnresolvedStar(None)) + val priceField = UnresolvedAttribute("price") + val productField = UnresolvedAttribute("product") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(priceField), isDistinct = false), + "stddev_pop(price)")() + val productAlias = Alias(productField, "product")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(productField, Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test price population stddev with alias and filter") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table category = 'vegetable' | stats stddev_pop(price) as dev_pop", + false), + context) + val star = Seq(UnresolvedStar(None)) + val categoryField = UnresolvedAttribute("category") + val priceField = UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(priceField), isDistinct = false), + "dev_pop")()) + val filterExpr = EqualTo(categoryField, Literal("vegetable")) + val filterPlan = 
Filter(filterExpr, tableRelation) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test age population stddev by span of interval of 5 years query with sort ") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_pop(age) by span(age, 5) as age_span | sort age", + false), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false), + "stddev_pop(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(5))), Literal(5)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test number of flights population stddev by airport with alias and limit") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_pop(no_of_flights) as dev_pop_flights by airport | head 50", + false), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val numberOfFlightsField = UnresolvedAttribute("no_of_flights") + val airportField = UnresolvedAttribute("airport") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(airportField, "airport")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(numberOfFlightsField), isDistinct = false), + "dev_pop_flights")() + val airportAlias = Alias(airportField, "airport")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, airportAlias), tableRelation) + val planWithLimit = GlobalLimit(Literal(50), LocalLimit(Literal(50), aggregatePlan)) + val expectedPlan = Project(star, planWithLimit) + + comparePlans(expectedPlan, logPlan, false) + } + } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 5b94ca092..36dc014f7 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -10,10 +10,12 @@ import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.command.DescribeTableCommand 
class PPLLogicalPlanBasicQueriesTranslatorTestSuite extends SparkFunSuite @@ -24,6 +26,40 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite private val planTransformer = new CatalystQueryPlanVisitor() private val pplParser = new PPLSyntaxParser() + test("test error describe clause") { + val context = new CatalystPlanContext + val thrown = intercept[IllegalArgumentException] { + planTransformer.visit(plan(pplParser, "describe t.b.c.d", false), context) + } + + assert( + thrown.getMessage === "Invalid table name: t.b.c.d Syntax: [ database_name. ] table_name") + } + + test("test simple describe clause") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "describe t", false), context) + + val expectedPlan = DescribeTableCommand( + TableIdentifier("t"), + Map.empty[String, String], + isExtended = true, + output = DescribeRelation.getOutputAttrs) + comparePlans(expectedPlan, logPlan, false) + } + + test("test FQN table describe table clause") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "describe catalog.t", false), context) + + val expectedPlan = DescribeTableCommand( + TableIdentifier("t", Option("catalog")), + Map.empty[String, String].empty, + isExtended = true, + output = DescribeRelation.getOutputAttrs) + comparePlans(expectedPlan, logPlan, false) + } + test("test simple search with only one table and no explicit fields (defaults to all fields)") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala new file mode 100644 index 000000000..34cfcbd90 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala @@ -0,0 +1,290 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{And, IsNotNull, IsNull, NamedExpression, Or} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Filter, Project, Union} + +class PPLLogicalPlanDedupTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test dedup a") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=table | dedup a | fields a", false), context) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("a")) + val filter = Filter(IsNotNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val deduplicate = Deduplicate(Seq(UnresolvedAttribute("a")), filter) + val expectedPlan = Project(projectList, deduplicate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test dedup a, b, c") { + val 
context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=table | dedup a, b, c | fields a, b, c", false), + context) + + val projectList: Seq[NamedExpression] = + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")) + val filter = Filter( + And( + And(IsNotNull(UnresolvedAttribute("a")), IsNotNull(UnresolvedAttribute("b"))), + IsNotNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val deduplicate = Deduplicate( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")), + filter) + val expectedPlan = Project(projectList, deduplicate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test dedup a keepempty=true") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=table | dedup a keepempty=true | fields a", false), + context) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("a")) + val isNotNullFilter = + Filter(IsNotNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val deduplicate = Deduplicate(Seq(UnresolvedAttribute("a")), isNotNullFilter) + val isNullFilter = Filter(IsNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val union = Union(deduplicate, isNullFilter) + val expectedPlan = Project(projectList, union) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test dedup a, b, c keepempty=true") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=table | dedup a, b, c keepempty=true | fields a, b, c", false), + context) + + val projectList: Seq[NamedExpression] = + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")) + val isNotNullFilter = Filter( + And( + And(IsNotNull(UnresolvedAttribute("a")), IsNotNull(UnresolvedAttribute("b"))), + IsNotNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val deduplicate = Deduplicate( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")), + isNotNullFilter) + val isNullFilter = Filter( + Or( + Or(IsNull(UnresolvedAttribute("a")), IsNull(UnresolvedAttribute("b"))), + IsNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val union = Union(deduplicate, isNullFilter) + val expectedPlan = Project(projectList, union) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test dedup a consecutive=true") { + val context = new CatalystPlanContext + val ex = intercept[UnsupportedOperationException] { + planTransformer.visit( + plan(pplParser, "source=table | dedup a consecutive=true | fields a", false), + context) + } + assert(ex.getMessage === "Consecutive deduplication is not supported") + } + + test("test dedup a keepempty=true consecutive=true") { + val context = new CatalystPlanContext + val ex = intercept[UnsupportedOperationException] { + planTransformer.visit( + plan( + pplParser, + "source=table | dedup a keepempty=true consecutive=true | fields a", + false), + context) + } + assert(ex.getMessage === "Consecutive deduplication is not supported") + } + + test("test dedup 1 a") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=table | dedup 1 a | fields a", false), + context) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("a")) + val filter = Filter(IsNotNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val deduplicate 
= Deduplicate(Seq(UnresolvedAttribute("a")), filter) + val expectedPlan = Project(projectList, deduplicate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test dedup 1 a, b, c") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=table | dedup 1 a, b, c | fields a, b, c", false), + context) + + val projectList: Seq[NamedExpression] = + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")) + val filter = Filter( + And( + And(IsNotNull(UnresolvedAttribute("a")), IsNotNull(UnresolvedAttribute("b"))), + IsNotNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val deduplicate = Deduplicate( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")), + filter) + val expectedPlan = Project(projectList, deduplicate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test dedup 1 a keepempty=true") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=table | dedup 1 a keepempty=true | fields a", false), + context) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("a")) + val isNotNullFilter = + Filter(IsNotNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val deduplicate = Deduplicate(Seq(UnresolvedAttribute("a")), isNotNullFilter) + val isNullFilter = Filter(IsNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val union = Union(deduplicate, isNullFilter) + val expectedPlan = Project(projectList, union) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test dedup 1 a, b, c keepempty=true") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=table | dedup 1 a, b, c keepempty=true | fields a, b, c", false), + context) + + val projectList: Seq[NamedExpression] = + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")) + val isNotNullFilter = Filter( + And( + And(IsNotNull(UnresolvedAttribute("a")), IsNotNull(UnresolvedAttribute("b"))), + IsNotNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val deduplicate = Deduplicate( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")), + isNotNullFilter) + val isNullFilter = Filter( + Or( + Or(IsNull(UnresolvedAttribute("a")), IsNull(UnresolvedAttribute("b"))), + IsNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val union = Union(deduplicate, isNullFilter) + val expectedPlan = Project(projectList, union) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test dedup 1 a consecutive=true") { + val context = new CatalystPlanContext + val ex = intercept[UnsupportedOperationException] { + planTransformer.visit( + plan(pplParser, "source=table | dedup 1 a consecutive=true | fields a", false), + context) + } + assert(ex.getMessage === "Consecutive deduplication is not supported") + } + + test("test dedup 1 a keepempty=true consecutive=true") { + val context = new CatalystPlanContext + val ex = intercept[UnsupportedOperationException] { + planTransformer.visit( + plan( + pplParser, + "source=table | dedup 1 a keepempty=true consecutive=true | fields a", + false), + context) + } + assert(ex.getMessage === "Consecutive deduplication is not supported") + } + + test("test dedup 0") { + val context = new CatalystPlanContext + val ex = intercept[IllegalArgumentException] { + 
planTransformer.visit( + plan(pplParser, "source=table | dedup 0 a | fields a", false), + context) + } + assert(ex.getMessage === "Number of duplicate events must be greater than 0") + } + + // Todo + ignore("test dedup 2 a") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=table | dedup 2 a | fields a", false), + context) + + } + + // Todo + ignore("test dedup 2 a, b, c") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=table | dedup 2 a, b, c | fields a, b, c", false), + context) + + } + + // Todo + ignore("test dedup 2 a keepempty=true") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=table | dedup 2 a keepempty=true | fields a", false), + context) + + } + + // Todo + ignore("test dedup 2 a, b, c keepempty=true") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=table | dedup 2 a, b, c keepempty=true | fields a, b, c", false), + context) + + } + + test("test dedup 2 a consecutive=true") { + val context = new CatalystPlanContext + val ex = intercept[UnsupportedOperationException] { + planTransformer.visit( + plan(pplParser, "source=table | dedup 2 a consecutive=true | fields a | fields a", false), + context) + } + assert(ex.getMessage === "Consecutive deduplication is not supported") + } + + test("test dedup 2 a keepempty=true consecutive=true") { + val context = new CatalystPlanContext + val ex = intercept[UnsupportedOperationException] { + planTransformer.visit( + plan( + pplParser, + "source=table | dedup 2 a keepempty=true consecutive=true | fields a", + false), + context) + } + assert(ex.getMessage === "Consecutive deduplication is not supported") + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index fe2fa5212..048f69ced 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -21,4 +21,5 @@ case class CommandContext( jobId: String, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, - queryWaitTimeMillis: Long) + queryWaitTimeMillis: Long, + queryLoopExecutionFrequency: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 00f023694..37801a9e8 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -96,6 +96,24 @@ trait FlintJobExecutor { } }""".stripMargin + // Define the data schema + val schema = StructType( + Seq( + StructField("result", ArrayType(StringType, containsNull = true), nullable = true), + StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), + StructField("jobRunId", StringType, nullable = true), + StructField("applicationId", StringType, nullable = true), + StructField("dataSourceName", StringType, nullable = true), + StructField("status", StringType, nullable = true), + StructField("error", StringType, nullable = true), + StructField("queryId", StringType, nullable = true), + StructField("queryText", StringType, nullable = true), + StructField("sessionId", StringType, nullable = true), + 
StructField("jobType", StringType, nullable = true), + // number is not nullable + StructField("updateTime", LongType, nullable = false), + StructField("queryRunTime", LongType, nullable = true))) + def createSparkConf(): SparkConf = { val conf = new SparkConf().setAppName(getClass.getSimpleName) @@ -129,11 +147,14 @@ trait FlintJobExecutor { builder.getOrCreate() } - private def writeData(resultData: DataFrame, resultIndex: String): Unit = { + private def writeData( + resultData: DataFrame, + resultIndex: String, + refreshPolicy: String): Unit = { try { resultData.write .format("flint") - .option(REFRESH_POLICY.optionKey, "wait_for") + .option(REFRESH_POLICY.optionKey, refreshPolicy) .mode("append") .save(resultIndex) IRestHighLevelClient.recordOperationSuccess( @@ -160,11 +181,12 @@ trait FlintJobExecutor { resultData: DataFrame, resultIndex: String, osClient: OSClient): Unit = { + val refreshPolicy = osClient.flintOptions.getRefreshPolicy; if (osClient.doesIndexExist(resultIndex)) { - writeData(resultData, resultIndex) + writeData(resultData, resultIndex, refreshPolicy) } else { createResultIndex(osClient, resultIndex, resultIndexMapping) - writeData(resultData, resultIndex) + writeData(resultData, resultIndex, refreshPolicy) } } @@ -199,24 +221,6 @@ trait FlintJobExecutor { StructField("column_name", StringType, nullable = false), StructField("data_type", StringType, nullable = false)))) - // Define the data schema - val schema = StructType( - Seq( - StructField("result", ArrayType(StringType, containsNull = true), nullable = true), - StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("jobRunId", StringType, nullable = true), - StructField("applicationId", StringType, nullable = true), - StructField("dataSourceName", StringType, nullable = true), - StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true), - StructField("queryId", StringType, nullable = true), - StructField("queryText", StringType, nullable = true), - StructField("sessionId", StringType, nullable = true), - StructField("jobType", StringType, nullable = true), - // number is not nullable - StructField("updateTime", LongType, nullable = false), - StructField("queryRunTime", LongType, nullable = true))) - val resultToSave = result.toJSON.collect.toList .map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")) @@ -249,35 +253,17 @@ trait FlintJobExecutor { spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*) } - def getFailedData( + def constructErrorDF( spark: SparkSession, dataSource: String, + status: String, error: String, queryId: String, - query: String, + queryText: String, sessionId: String, - startTime: Long, - timeProvider: TimeProvider): DataFrame = { - - // Define the data schema - val schema = StructType( - Seq( - StructField("result", ArrayType(StringType, containsNull = true), nullable = true), - StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("jobRunId", StringType, nullable = true), - StructField("applicationId", StringType, nullable = true), - StructField("dataSourceName", StringType, nullable = true), - StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true), - StructField("queryId", StringType, nullable = true), - StructField("queryText", StringType, nullable = true), - StructField("sessionId", StringType, nullable = true), - StructField("jobType", StringType, nullable = true), - // number is not nullable - 
StructField("updateTime", LongType, nullable = false), - StructField("queryRunTime", LongType, nullable = true))) + startTime: Long): DataFrame = { - val endTime = timeProvider.currentEpochMillis() + val updateTime = currentTimeProvider.currentEpochMillis() // Create the data rows val rows = Seq( @@ -287,14 +273,14 @@ trait FlintJobExecutor { envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), dataSource, - "FAILED", + status.toUpperCase(Locale.ROOT), error, queryId, - query, + queryText, sessionId, spark.conf.get(FlintSparkConf.JOB_TYPE.key), - endTime, - endTime - startTime)) + updateTime, + updateTime - startTime)) // Create the DataFrame for data spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 8cad8844b..e6b8b11ce 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -18,21 +18,32 @@ import com.codahale.metrics.Timer import org.json4s.native.Serialization import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} +import org.opensearch.flint.common.model.InteractiveSession.formats import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} -import org.opensearch.flint.data.{FlintStatement, InteractiveSession} -import org.opensearch.flint.data.InteractiveSession.formats import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.FlintREPLConfConstants._ import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.util.ThreadUtils +object FlintREPLConfConstants { + val HEARTBEAT_INTERVAL_MILLIS = 60000L + val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) + val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) + val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 + val DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY = 100L + val INITIAL_DELAY_MILLIS = 3000L + val EARLY_TERMINATION_CHECK_FREQUENCY = 60000L +} + /** * Spark SQL Application entrypoint * @@ -48,13 +59,6 @@ import org.apache.spark.util.ThreadUtils */ object FlintREPL extends Logging with FlintJobExecutor { - private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) - private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) - private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 - val INITIAL_DELAY_MILLIS = 3000L - val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L - @volatile var earlyExitFlag: Boolean = false def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = { @@ -134,7 +138,10 @@ object FlintREPL extends Logging with FlintJobExecutor { SECONDS) val queryWaitTimeoutMillis: Long = conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) - + 
val queryLoopExecutionFrequency: Long = + conf.getLong( + "spark.flint.job.queryLoopExecutionFrequency", + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) @@ -199,7 +206,8 @@ object FlintREPL extends Logging with FlintJobExecutor { jobId, queryExecutionTimeoutSecs, inactivityLimitMillis, - queryWaitTimeoutMillis) + queryWaitTimeoutMillis, + queryLoopExecutionFrequency) exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { queryLoop(commandContext) } @@ -342,7 +350,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } def queryLoop(commandContext: CommandContext): Unit = { - // 1 thread for updating heart beat + // 1 thread for async query execution val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -392,7 +400,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintReader.close() } - Thread.sleep(100) + Thread.sleep(commandContext.queryLoopExecutionFrequency) } } finally { if (threadPool != null) { @@ -448,7 +456,7 @@ object FlintREPL extends Logging with FlintJobExecutor { .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId) - if (flintInstance.state.equals("fail")) { + if (flintInstance.isFail) { recordSessionFailed(sessionTimerContext) } } @@ -522,15 +530,15 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime: Long): DataFrame = { flintStatement.fail() flintStatement.error = Some(error) - super.getFailedData( + super.constructErrorDF( spark, dataSource, + flintStatement.state, error, flintStatement.queryId, flintStatement.query, sessionId, - startTime, - currentTimeProvider) + startTime) } def processQueryException(ex: Exception, flintStatement: FlintStatement): String = { @@ -555,8 +563,8 @@ object FlintREPL extends Logging with FlintJobExecutor { while (canProceed) { val currentTime = currentTimeProvider.currentEpochMillis() - // Only call canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed - if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) { + // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed + if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) { canPickNextStatementResult = canPickNextStatement(sessionId, jobId, osClient, sessionIndex) lastCanPickCheckTime = currentTime @@ -646,7 +654,7 @@ object FlintREPL extends Logging with FlintJobExecutor { error: String, flintStatement: FlintStatement, sessionId: String, - startTime: Long): Option[DataFrame] = { + startTime: Long): DataFrame = { /* * https://tinyurl.com/2ezs5xj9 * @@ -660,14 +668,17 @@ object FlintREPL extends Logging with FlintJobExecutor { * actions that require the computation of results that need to be collected or stored. 
*/ spark.sparkContext.cancelJobGroup(flintStatement.queryId) - Some( - handleCommandFailureAndGetFailedData( - spark, - dataSource, - error, - flintStatement, - sessionId, - startTime)) + flintStatement.timeout() + flintStatement.error = Some(error) + super.constructErrorDF( + spark, + dataSource, + flintStatement.state, + error, + flintStatement.queryId, + flintStatement.query, + sessionId, + startTime) } def executeAndHandle( @@ -694,7 +705,7 @@ object FlintREPL extends Logging with FlintJobExecutor { case e: TimeoutException => val error = s"Executing ${flintStatement.query} timed out" CustomLogging.logError(error, e) - handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime) + Some(handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime)) case e: Exception => val error = processQueryException(e, flintStatement) Some( @@ -753,8 +764,14 @@ object FlintREPL extends Logging with FlintJobExecutor { case e: TimeoutException => val error = s"Getting the mapping of index $resultIndex timed out" CustomLogging.logError(error, e) - dataToWrite = - handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime) + dataToWrite = Some( + handleCommandTimeout( + spark, + dataSource, + error, + flintStatement, + sessionId, + startTime)) case NonFatal(e) => val error = s"An unexpected error occurred: ${e.getMessage}" CustomLogging.logError(error, e) @@ -933,7 +950,7 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId: String, sessionTimerContext: Timer.Context): Unit = { val flintInstance = InteractiveSession.deserializeFromMap(source) - flintInstance.state = "dead" + flintInstance.complete() flintSessionIndexUpdater.updateIf( sessionId, InteractiveSession.serializeWithoutJobId( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index f315dc836..c079b3e96 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -57,7 +57,7 @@ case class JobOperator( dataToWrite = Some(mappingCheckResult match { case Right(_) => data case Left(error) => - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider) + constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime) }) exceptionThrown = false } catch { @@ -65,11 +65,11 @@ case class JobOperator( val error = s"Getting the mapping of index $resultIndex timed out" logError(error, e) dataToWrite = Some( - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) + constructErrorDF(spark, dataSource, "TIMEOUT", error, "", query, "", startTime)) case e: Exception => val error = processQueryException(e) dataToWrite = Some( - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) + constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime)) } finally { cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient) } diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index d8ddcb665..9c193fc9a 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -17,19 +17,21 @@ 
import scala.reflect.runtime.universe.TypeTag import com.amazonaws.services.glue.model.AccessDeniedException import com.codahale.metrics.Timer -import org.mockito.ArgumentMatchersSugar -import org.mockito.Mockito._ +import org.mockito.{ArgumentMatchersSugar, Mockito} +import org.mockito.Mockito.{atLeastOnce, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse +import org.opensearch.flint.common.model.FlintStatement import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} -import org.opensearch.flint.data.FlintStatement import org.opensearch.search.sort.SortOrder +import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.SparkListenerApplicationEnd import org.apache.spark.sql.FlintREPL.PreShutdownListener +import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.trees.Origin @@ -228,7 +230,7 @@ class FlintREPLTest verify(flintSessionIndexUpdater).updateIf(*, *, *, *) } - test("Test getFailedData method") { + test("Test super.constructErrorDF should construct dataframe properly") { // Define expected dataframe val dataSourceName = "myGlueS3" val expectedSchema = StructType( @@ -286,7 +288,7 @@ class FlintREPLTest "20", currentTime - queryRunTime) assertEqualDataframe(expected, result) - assert("failed" == flintStatement.state) + assert(flintStatement.isFailed) assert(error == flintStatement.error.get) } finally { spark.close() @@ -490,7 +492,7 @@ class FlintREPLTest assert(result == expectedError) } - test("handleGeneralException should handle MetaException with AccessDeniedException properly") { + test("processQueryException should handle MetaException with AccessDeniedException properly") { val mockFlintCommand = mock[FlintStatement] // Simulate the root cause being MetaException @@ -599,7 +601,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), 60, - 60) + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) intercept[RuntimeException] { FlintREPL.exponentialBackoffRetry(maxRetries, 2.seconds) { @@ -617,7 +620,6 @@ class FlintREPLTest test("executeAndHandle should handle TimeoutException properly") { val mockSparkSession = mock[SparkSession] - val mockFlintStatement = mock[FlintStatement] val mockConf = mock[RuntimeConfig] when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) @@ -630,9 +632,8 @@ class FlintREPLTest val sessionId = "someSessionId" val startTime = System.currentTimeMillis() val expectedDataFrame = mock[DataFrame] - - when(mockFlintStatement.query).thenReturn("SELECT 1") - when(mockFlintStatement.submitTime).thenReturn(Instant.now().toEpochMilli()) + val flintStatement = + new FlintStatement("running", "select 1", "30", "10", Instant.now().toEpochMilli(), None) // When the `sql` method is called, execute the custom Answer that introduces a delay when(mockSparkSession.sql(any[String])).thenAnswer(new Answer[DataFrame] { override def answer(invocation: InvocationOnMock): DataFrame = { @@ -653,7 +654,7 @@ class FlintREPLTest val result = FlintREPL.executeAndHandle( mockSparkSession, - mockFlintStatement, + flintStatement, dataSource, 
sessionId, executionContext, @@ -664,6 +665,8 @@ class FlintREPLTest verify(mockSparkSession, times(1)).sql(any[String]) verify(sparkContext, times(1)).cancelJobGroup(any[String]) + assert("timeout" == flintStatement.state) + assert(s"Executing ${flintStatement.query} timed out" == flintStatement.error.get) result should not be None } finally threadPool.shutdown() } @@ -880,7 +883,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), shortInactivityLimit, - 60) + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) // Mock processCommands to always allow loop continuation val getResponse = mock[GetResponse] @@ -930,7 +934,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), longInactivityLimit, - 60) + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) // Mocking canPickNextStatement to return false when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { @@ -986,7 +991,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), inactivityLimit, - 60) + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) try { // Mocking ThreadUtils to track the shutdown call @@ -1036,7 +1042,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), inactivityLimit, - 60) + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) try { // Mocking ThreadUtils to track the shutdown call @@ -1117,7 +1124,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), inactivityLimit, - 60) + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) val startTime = Instant.now().toEpochMilli() @@ -1131,58 +1139,70 @@ class FlintREPLTest verify(osClient, times(1)).getIndexMetadata(*) } - test("queryLoop should execute loop without processing any commands") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(false) + val testCases = Table( + ("inactivityLimit", "queryLoopExecutionFrequency"), + (5000, 100L), // 5 seconds, 100 ms + (100, 300L) // 100 ms, 300 ms + ) - // Configure mockReader to always return false, indicating no commands to process - when(mockReader.hasNext).thenReturn(false) - - val resultIndex = "testResultIndex" - val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" - val sessionId = "testSessionId" - val jobId = "testJobId" + test( + "queryLoop should execute loop without processing any commands for different inactivity limits and frequencies") { + forAll(testCases) { (inactivityLimit, queryLoopExecutionFrequency) => + val mockReader = mock[FlintReader] + val osClient = mock[OSClient] + when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + val getResponse = mock[GetResponse] + when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(false) + when(mockReader.hasNext).thenReturn(false) + + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val jobId = "testJobId" + + // Create a SparkSession for testing + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val inactivityLimit = 5000 // 5 seconds + val flintSessionIndexUpdater = mock[OpenSearchUpdater] - // Create a SparkSession for testing - val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + val commandContext = 
CommandContext( + spark, + dataSource, + resultIndex, + sessionId, + flintSessionIndexUpdater, + osClient, + sessionIndex, + jobId, + Duration(10, MINUTES), + inactivityLimit, + 60, + queryLoopExecutionFrequency) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val startTime = Instant.now().toEpochMilli() - val commandContext = CommandContext( - spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, - jobId, - Duration(10, MINUTES), - inactivityLimit, - 60) + // Running the queryLoop + FlintREPL.queryLoop(commandContext) - val startTime = Instant.now().toEpochMilli() + val endTime = Instant.now().toEpochMilli() - // Running the queryLoop - FlintREPL.queryLoop(commandContext) + val elapsedTime = endTime - startTime - val endTime = Instant.now().toEpochMilli() + // Assert that the loop ran for at least the duration of the inactivity limit + assert(elapsedTime >= inactivityLimit) - // Assert that the loop ran for at least the duration of the inactivity limit - assert(endTime - startTime >= inactivityLimit) + // Verify query execution frequency + val expectedCalls = Math.ceil(elapsedTime.toDouble / queryLoopExecutionFrequency).toInt + verify(mockReader, Mockito.atMost(expectedCalls)).hasNext - // Verify that no command was actually processed - verify(mockReader, never()).next() + // Verify that no command was actually processed + verify(mockReader, never()).next() - // Stop the SparkSession - spark.stop() + // Stop the SparkSession + spark.stop() + } } } diff --git a/spark-sql-application/src/test/scala/org/opensearch/sql/FlintDelegatingSessionCatalogTest.scala b/spark-sql-application/src/test/scala/org/opensearch/sql/FlintDelegatingSessionCatalogTest.scala index f6be0b1c3..fc8cb4f4a 100644 --- a/spark-sql-application/src/test/scala/org/opensearch/sql/FlintDelegatingSessionCatalogTest.scala +++ b/spark-sql-application/src/test/scala/org/opensearch/sql/FlintDelegatingSessionCatalogTest.scala @@ -45,7 +45,10 @@ class FlintDelegatingSessionCatalogTest extends QueryTest with SharedSparkSessio test("query without catalog name") { sql("use mycatalog") - assert(sql("SHOW CATALOGS").collect === Array(Row("mycatalog"))) + // Since Spark 3.4.0. https://issues.apache.org/jira/browse/SPARK-40055, listCatalogs should + // also return spark_catalog even spark_catalog implementation is defaultSessionCatalog + assert( + sql("SHOW CATALOGS").collect.toSet === Array(Row("mycatalog"), Row("spark_catalog")).toSet) checkAnswer(sql(s"SELECT name, age FROM $testTableWithoutCatalog"), Seq(Row("Hello", 30))) }
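
The dedup translation added in `CatalystQueryPlanVisitor.visitDedupe` maps `keepempty=true` to a union of the deduplicated non-null rows and the rows whose dedup keys are null. Below is a minimal sketch of the resulting Catalyst plan shape, built from the same unresolved-plan blocks the new `PPLLogicalPlanDedupTranslatorTestSuite` asserts against; the table and field names are illustrative.

```scala
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.{IsNotNull, IsNull}
import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Filter, LogicalPlan, Project, Union}

object DedupKeepEmptyPlanSketch {
  // Expected plan shape for: source = table | dedup a keepempty=true | fields a
  def expectedPlan: LogicalPlan = {
    val relation = UnresolvedRelation(Seq("table"))
    // rows whose dedup key is present: filter to non-null, then deduplicate on the key
    val kept = Deduplicate(
      Seq(UnresolvedAttribute("a")),
      Filter(IsNotNull(UnresolvedAttribute("a")), relation))
    // keepempty=true also retains the rows whose dedup key is null
    val empties = Filter(IsNull(UnresolvedAttribute("a")), relation)
    Project(Seq(UnresolvedAttribute("a")), Union(kept, empties))
  }
}
```

Without `keepempty=true`, the plan collapses to the `Deduplicate` over the not-null filter, which is what the non-keepempty test cases assert.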
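With `STDDEV_SAMP` and `STDDEV_POP` added to the grammar and to `AggregatorTranslator`, the new aggregations delegate to Spark's built-in functions of the same name. The following is a hedged usage sketch, assuming the PPL extension is registered on the session; the `products` temp view and its columns are illustrative.

```scala
import org.apache.spark.sql.SparkSession

object StddevPplSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("ppl-stddev-sketch")
      .config("spark.sql.extensions", "org.opensearch.flint.spark.FlintPPLSparkExtensions")
      .getOrCreate()

    // Illustrative data; any table or view with a numeric column works the same way.
    spark.range(0, 10)
      .selectExpr("id AS price", "id % 2 AS product")
      .createOrReplaceTempView("products")

    // stddev_samp is rewritten to Spark's STDDEV_SAMP aggregate (stddev_pop works analogously).
    spark.sql("source = products | stats stddev_samp(price) by product").show()

    spark.stop()
  }
}
```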
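The `describe` command is now parsed into a `DescribeRelation` AST node and translated to Spark's `DescribeTableCommand` with extended output, as the new basic-query tests show. A minimal sketch of the equivalent command for a two-part name follows; the database and table names are illustrative.

```scala
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation
import org.apache.spark.sql.execution.command.DescribeTableCommand

object DescribePlanSketch {
  // Equivalent of the command produced for: describe mydb.mytable
  val plan: DescribeTableCommand = DescribeTableCommand(
    TableIdentifier("mytable", Some("mydb")),
    Map.empty[String, String],
    isExtended = true,
    output = DescribeRelation.getOutputAttrs)
}
```

Names with more than two parts are rejected with an `IllegalArgumentException`, matching the error-path test in `PPLLogicalPlanBasicQueriesTranslatorTestSuite`.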
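FlintREPL's query loop now sleeps for a configurable interval, `spark.flint.job.queryLoopExecutionFrequency`, defaulting to the 100 ms constant in `FlintREPLConfConstants`, instead of a hard-coded 100 ms. A sketch of overriding it at submission time; the 500 ms value is an illustrative choice, not a recommendation.

```scala
import org.apache.spark.SparkConf

object ReplPollIntervalSketch {
  // How long queryLoop waits between polls for new statements, in milliseconds.
  val conf: SparkConf = new SparkConf()
    .setAppName("flint-repl")
    .set("spark.flint.job.queryLoopExecutionFrequency", "500")
}
```

The value is read in `FlintREPL.main` and carried through `CommandContext.queryLoopExecutionFrequency` into the loop's sleep call.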
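The catalog test change reflects SPARK-40055: from Spark 3.4.0 on, `SHOW CATALOGS` also lists `spark_catalog` even when a delegating custom catalog is current, so the assertion compares sets. A small sketch of collecting catalog names the same way; the `mycatalog` name is illustrative.

```scala
import org.apache.spark.sql.SparkSession

object ShowCatalogsSketch {
  // Returns e.g. Set("mycatalog", "spark_catalog") on Spark 3.4.0+
  def catalogNames(spark: SparkSession): Set[String] =
    spark.sql("SHOW CATALOGS").collect().map(_.getString(0)).toSet
}
```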