diff --git a/build.sbt b/build.sbt index 48e4bca5b..3c6281684 100644 --- a/build.sbt +++ b/build.sbt @@ -42,8 +42,9 @@ lazy val commonSettings = Seq( testScalastyle := (Test / scalastyle).toTask("").value, Test / test := ((Test / test) dependsOn testScalastyle).value) +// running `scalafmtAll` includes all subprojects under root lazy val root = (project in file(".")) - .aggregate(flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication) + .aggregate(flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication, integtest) .disablePlugins(AssemblyPlugin) .settings(name := "flint", publish / skip := true) @@ -159,7 +160,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) // Test assembly package with integration test. lazy val integtest = (project in file("integ-test")) - .dependsOn(flintSparkIntegration % "test->test", pplSparkIntegration % "test->test" ) + .dependsOn(flintSparkIntegration % "test->test", pplSparkIntegration % "test->test", sparkSqlApplication % "test->test") .settings( commonSettings, name := "integ-test", @@ -175,7 +176,9 @@ lazy val integtest = (project in file("integ-test")) "org.opensearch.client" % "opensearch-java" % "2.6.0" % "test" exclude ("com.fasterxml.jackson.core", "jackson-databind")), libraryDependencies ++= deps(sparkVersion), - Test / fullClasspath ++= Seq((flintSparkIntegration / assembly).value, (pplSparkIntegration / assembly).value)) + Test / fullClasspath ++= Seq((flintSparkIntegration / assembly).value, (pplSparkIntegration / assembly).value, + (sparkSqlApplication / assembly).value + )) lazy val standaloneCosmetic = project .settings( diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index c1c5491ed..410d896d2 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -77,6 +77,8 @@ public class FlintOptions implements Serializable { public static final int DEFAULT_SOCKET_TIMEOUT_MILLIS = 60000; + public static final int DEFAULT_INACTIVITY_LIMIT_MILLIS = 10 * 60 * 1000; + public FlintOptions(Map options) { this.options = options; this.retryOptions = new FlintRetryOptions(options); diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java new file mode 100644 index 000000000..349e5c126 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.storage; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.search.ClearScrollRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.Strings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.flint.core.FlintOptions; +import org.opensearch.flint.core.IRestHighLevelClient; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.util.Optional; +import java.util.logging.Level; +import 
java.util.logging.Logger; + +/** + * {@link OpenSearchReader} using search. https://opensearch.org/docs/latest/api-reference/search/ + */ +public class OpenSearchQueryReader extends OpenSearchReader { + + private static final Logger LOG = Logger.getLogger(OpenSearchQueryReader.class.getName()); + + public OpenSearchQueryReader(IRestHighLevelClient client, String indexName, SearchSourceBuilder searchSourceBuilder) { + super(client, new SearchRequest().indices(indexName).source(searchSourceBuilder)); + } + + /** + * search. + */ + Optional search(SearchRequest request) throws IOException { + return Optional.of(client.search(request, RequestOptions.DEFAULT)); + } + + /** + * nothing to clean + */ + void clean() throws IOException {} +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java index c70d327fe..e2e831bd0 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java @@ -5,6 +5,7 @@ package org.opensearch.flint.core.storage; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.flint.core.IRestHighLevelClient; @@ -48,6 +49,13 @@ public OpenSearchReader(IRestHighLevelClient client, SearchRequest searchRequest iterator = searchHits.iterator(); } return iterator.hasNext(); + } catch (OpenSearchStatusException e) { + // e.g., org.opensearch.OpenSearchStatusException: OpenSearch exception [type=index_not_found_exception, reason=no such index [query_results2]] + if (e.getMessage() != null && (e.getMessage().contains("index_not_found_exception"))) { + return false; + } else { + throw e; + } } catch (IOException e) { // todo. log error. 
throw new RuntimeException(e); diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java index 7fab8c346..94760fc37 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java @@ -57,13 +57,20 @@ public void testGetDimensionsFromSystemEnv() throws NoSuchFieldException, Illega Field field = classOfMap.getDeclaredField("m"); field.setAccessible(true); Map writeableEnvironmentVariables = (Map)field.get(System.getenv()); - writeableEnvironmentVariables.put("TEST_VAR", "dummy1"); - writeableEnvironmentVariables.put("SERVERLESS_EMR_JOB_ID", "dummy2"); - Dimension result1 = DimensionUtils.constructDimension("TEST_VAR", parts); - assertEquals("TEST_VAR", result1.getName()); - assertEquals("dummy1", result1.getValue()); - Dimension result2 = DimensionUtils.constructDimension("jobId", parts); - assertEquals("jobId", result2.getName()); - assertEquals("dummy2", result2.getValue()); + try { + writeableEnvironmentVariables.put("TEST_VAR", "dummy1"); + writeableEnvironmentVariables.put("SERVERLESS_EMR_JOB_ID", "dummy2"); + Dimension result1 = DimensionUtils.constructDimension("TEST_VAR", parts); + assertEquals("TEST_VAR", result1.getName()); + assertEquals("dummy1", result1.getValue()); + Dimension result2 = DimensionUtils.constructDimension("jobId", parts); + assertEquals("jobId", result2.getName()); + assertEquals("dummy2", result2.getValue()); + } finally { + // since system environment is shared by other tests. Make sure to remove them before exiting. + writeableEnvironmentVariables.remove("SERVERLESS_EMR_JOB_ID"); + writeableEnvironmentVariables.remove("TEST_VAR"); + } + } } diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index fd998d46d..359994c56 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -146,7 +146,30 @@ object FlintSparkConf { .datasourceOption() .doc("socket duration in milliseconds") .createWithDefault(String.valueOf(FlintOptions.DEFAULT_SOCKET_TIMEOUT_MILLIS)) - + val DATA_SOURCE_NAME = + FlintConfig(s"spark.flint.datasource.name") + .doc("data source name") + .createOptional() + val JOB_TYPE = + FlintConfig(s"spark.flint.job.type") + .doc("Flint job type. 
Including interactive and streaming") + .createWithDefault("interactive") + val SESSION_ID = + FlintConfig(s"spark.flint.job.sessionId") + .doc("Flint session id") + .createOptional() + val REQUEST_INDEX = + FlintConfig(s"spark.flint.job.requestIndex") + .doc("Request index") + .createOptional() + val EXCLUDE_JOB_IDS = + FlintConfig(s"spark.flint.deployment.excludeJobs") + .doc("Exclude job ids") + .createOptional() + val REPL_INACTIVITY_TIMEOUT_MILLIS = + FlintConfig(s"spark.flint.job.inactivityLimitMillis") + .doc("inactivity timeout") + .createWithDefault(String.valueOf(FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS)) } /** @@ -196,11 +219,18 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable CUSTOM_AWS_CREDENTIALS_PROVIDER, USERNAME, PASSWORD, - SOCKET_TIMEOUT_MILLIS) + SOCKET_TIMEOUT_MILLIS, + JOB_TYPE, + REPL_INACTIVITY_TIMEOUT_MILLIS) .map(conf => (conf.optionKey, conf.readFrom(reader))) .toMap - val optionsWithoutDefault = Seq(RETRYABLE_EXCEPTION_CLASS_NAMES) + val optionsWithoutDefault = Seq( + RETRYABLE_EXCEPTION_CLASS_NAMES, + DATA_SOURCE_NAME, + SESSION_ID, + REQUEST_INDEX, + EXCLUDE_JOB_IDS) .map(conf => (conf.optionKey, conf.readFrom(reader))) .flatMap { case (_, None) => None diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala index 5af70b793..9911a3b6c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala @@ -25,7 +25,14 @@ class FlintInstance( val lastUpdateTime: Long, val jobStartTime: Long = 0, val excludedJobIds: Seq[String] = Seq.empty[String], - val error: Option[String] = None) {} + val error: Option[String] = None) { + override def toString: String = { + val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]") + val errorStr = error.getOrElse("None") + s"FlintInstance(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " + + s"lastUpdateTime=$lastUpdateTime, jobStartTime=$jobStartTime, excludedJobIds=$excludedJobIdsStr, error=$errorStr)" + } +} object FlintInstance { diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala new file mode 100644 index 000000000..f070ef3ab --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -0,0 +1,260 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.{Duration, MINUTES} +import scala.util.{Failure, Success} +import scala.util.control.Breaks.{break, breakable} + +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.spark.FlintSparkSuite +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName +import org.scalatest.matchers.must.Matchers.{defined, have} +import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} + +import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE +import org.apache.spark.sql.util.MockEnvironment +import org.apache.spark.util.ThreadUtils + +class FlintJobITSuite extends FlintSparkSuite with JobTest { + + /** Test table and index name */ + private val testTable = 
"spark_catalog.default.skipping_sql_test" + private val testIndex = getSkippingIndexName(testTable) + val resultIndex = "query_results2" + val appId = "00feq82b752mbt0p" + val dataSourceName = "my_glue1" + var osClient: OSClient = _ + val threadLocalFuture = new ThreadLocal[Future[Unit]]() + + override def beforeAll(): Unit = { + super.beforeAll() + // initialized after the container is started + osClient = new OSClient(new FlintOptions(openSearchOptions.asJava)) + createPartitionedMultiRowTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + + deleteTestIndex(testIndex) + + waitJobStop(threadLocalFuture.get()) + + threadLocalFuture.remove() + } + + def waitJobStop(future: Future[Unit]): Unit = { + try { + val activeJob = spark.streams.active.find(_.name == testIndex) + if (activeJob.isDefined) { + activeJob.get.stop() + } + ThreadUtils.awaitResult(future, Duration(1, MINUTES)) + } catch { + case e: Exception => + e.printStackTrace() + assert(false, "failure waiting for job to finish") + } + } + + def startJob(query: String, jobRunId: String): Future[Unit] = { + val prefix = "flint-job-test" + val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1) + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + + val futureResult = Future { + val job = + JobOperator(spark, query, dataSourceName, resultIndex, true) + job.envinromentProvider = new MockEnvironment( + Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) + + job.start() + } + futureResult.onComplete { + case Success(result) => logInfo(s"Success result: $result") + case Failure(ex) => + ex.printStackTrace() + assert(false, s"An error has occurred: ${ex.getMessage}") + } + futureResult + } + + test("create skipping index with auto refresh") { + val query = + s""" + | CREATE SKIPPING INDEX ON $testTable + | ( + | year PARTITION, + | name VALUE_SET, + | age MIN_MAX + | ) + | WITH (auto_refresh = true) + | """.stripMargin + val queryStartTime = System.currentTimeMillis() + val jobRunId = "00ff4o3b5091080q" + threadLocalFuture.set(startJob(query, jobRunId)) + + val validation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + + assert(result.status == "SUCCESS", s"expected status is SUCCESS, but got ${result.status}") + assert(result.error.isEmpty, s"we don't expect error, but got ${result.error}") + assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}") + + commonAssert(result, jobRunId, query, queryStartTime) + true + } + pollForResultAndAssert(validation, jobRunId) + + val activeJob = spark.streams.active.find(_.name == testIndex) + activeJob shouldBe defined + failAfter(streamingTimeout) { + activeJob.get.processAllAvailable() + } + val indexData = spark.read.format(FLINT_DATASOURCE).load(testIndex) + flint.describeIndex(testIndex) shouldBe defined + indexData.count() shouldBe 2 + } + + test("create skipping index with non-existent table") { + val query = + s""" + | CREATE SKIPPING INDEX ON testTable + | ( + | year PARTITION, + | name VALUE_SET, + | age MIN_MAX + | ) + | WITH (auto_refresh = true) + | """.stripMargin + val queryStartTime = System.currentTimeMillis() + val jobRunId = "00ff4o3b5091080r" + threadLocalFuture.set(startJob(query, jobRunId)) + + val validation: REPLResult => Boolean = result => 
{ + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + + assert(result.status == "FAILED", s"expected status is FAILED, but got ${result.status}") + assert(!result.error.isEmpty, s"we expect error, but got ${result.error}") + commonAssert(result, jobRunId, query, queryStartTime) + true + } + pollForResultAndAssert(validation, jobRunId) + } + + test("describe skipping index") { + flint + .skippingIndex() + .onTable(testTable) + .addPartitions("year") + .addValueSet("name") + .addMinMax("age") + .create() + + val queryStartTime = System.currentTimeMillis() + val jobRunId = "00ff4o3b5091080s" + val query = s"DESC SKIPPING INDEX ON $testTable" + threadLocalFuture.set(startJob(query, jobRunId)) + + val validation: REPLResult => Boolean = result => { + assert( + result.results.size == 3, + s"expected result size is 3, but got ${result.results.size}") + val expectedResult0 = + "{'indexed_col_name':'year','data_type':'int','skip_type':'PARTITION'}" + assert( + result.results(0) == expectedResult0, + s"expected result is $expectedResult0, but got ${result.results(0)}") + val expectedResult1 = + "{'indexed_col_name':'name','data_type':'string','skip_type':'VALUE_SET'}" + assert( + result.results(1) == expectedResult1, + s"expected result is $expectedResult1, but got ${result.results(1)}") + val expectedResult2 = "{'indexed_col_name':'age','data_type':'int','skip_type':'MIN_MAX'}" + assert( + result.results(2) == expectedResult2, + s"expected result is $expectedResult2, but got ${result.results(2)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'indexed_col_name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected 0th field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected 1st field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'skip_type','data_type':'string'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected 2nd field is $expectedSecondSchema, but got ${result.schemas(2)}") + + assert(result.status == "SUCCESS", s"expected status is SUCCESS, but got ${result.status}") + assert(result.error.isEmpty, s"we don't expect error, but got ${result.error}") + + commonAssert(result, jobRunId, query, queryStartTime) + true + } + pollForResultAndAssert(validation, jobRunId) + } + + def commonAssert( + result: REPLResult, + jobRunId: String, + query: String, + queryStartTime: Long): Unit = { + assert( + result.jobRunId == jobRunId, + s"expected jobRunId is $jobRunId, but got ${result.jobRunId}") + assert( + result.applicationId == appId, + s"expected applicationId is $appId, but got ${result.applicationId}") + assert( + result.dataSourceName == dataSourceName, + s"expected data source is $dataSourceName, but got ${result.dataSourceName}") + val actualQueryText = normalizeString(result.queryText) + val expectedQueryText = normalizeString(query) + assert( + actualQueryText == expectedQueryText, + s"expected query is $expectedQueryText, but got $actualQueryText") + assert(result.sessionId.isEmpty, s"we don't expect session id, but got ${result.sessionId}") + assert(
result.updateTime > queryStartTime, + s"expect that update time is ${result.updateTime} later than query start time $queryStartTime, but it is not") + assert( + result.queryRunTime > 0, + s"expected query run time is positive, but got ${result.queryRunTime}") + assert( + result.queryRunTime < System.currentTimeMillis() - queryStartTime, + s"expected query run time ${result.queryRunTime} should be less than ${System + .currentTimeMillis() - queryStartTime}, but it is not") + assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}") + } + + def pollForResultAndAssert(expected: REPLResult => Boolean, jobId: String): Unit = { + pollForResultAndAssert( + osClient, + expected, + "jobRunId", + jobId, + streamingTimeout.toMillis, + resultIndex) + } +} diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala new file mode 100644 index 000000000..9a2afc71e --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -0,0 +1,573 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import java.util.UUID + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.{Duration, MINUTES} +import scala.util.{Failure, Success, Try} +import scala.util.control.Breaks.{break, breakable} + +import org.opensearch.OpenSearchStatusException +import org.opensearch.flint.OpenSearchSuite +import org.opensearch.flint.app.{FlintCommand, FlintInstance} +import org.opensearch.flint.core.{FlintClient, FlintOptions} +import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} +import org.opensearch.search.sort.SortOrder + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID} +import org.apache.spark.sql.util.MockEnvironment +import org.apache.spark.util.ThreadUtils + +class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { + + var flintClient: FlintClient = _ + var osClient: OSClient = _ + var updater: OpenSearchUpdater = _ + val requestIndex = "flint_ql_sessions" + val resultIndex = "query_results2" + val jobRunId = "00ff4o3b5091080q" + val appId = "00feq82b752mbt0p" + val dataSourceName = "my_glue1" + val sessionId = "10" + val requestIndexMapping = + """ { + | "properties": { + | "applicationId": { + | "type": "keyword" + | }, + | "dataSourceName": { + | "type": "keyword" + | }, + | "error": { + | "type": "text" + | }, + | "excludeJobIds": { + | "type": "text", + | "fields": { + | "keyword": { + | "type": "keyword", + | "ignore_above": 256 + | } + | } + | }, + | "if_primary_term": { + | "type": "long" + | }, + | "if_seq_no": { + | "type": "long" + | }, + | "jobId": { + | "type": "keyword" + | }, + | "jobStartTime": { + | "type": "long" + | }, + | "lang": { + | "type": "keyword" + | }, + | "lastUpdateTime": { + | "type": "date", + | "format": "strict_date_time||epoch_millis" + | }, + | "query": { + | "type": "text" + | }, + | "queryId": { + | "type": "text", + | "fields": { + | "keyword": { + | "type": "keyword", + | "ignore_above": 256 + | } + | } + | }, + | "sessionId": { + | "type": "keyword" + | }, + | "state": { + | "type": "keyword" + | }, + | "statementId": { + | "type": 
"keyword" + | }, + | "submitTime": { + | "type": "date", + | "format": "strict_date_time||epoch_millis" + | }, + | "type": { + | "type": "keyword" + | } + | } + | } + |""".stripMargin + val testTable = dataSourceName + ".default.flint_sql_test" + + // use a thread-local variable to store and manage the future in beforeEach and afterEach + val threadLocalFuture = new ThreadLocal[Future[Unit]]() + + override def beforeAll(): Unit = { + super.beforeAll() + + flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)); + osClient = new OSClient(new FlintOptions(openSearchOptions.asJava)) + updater = new OpenSearchUpdater( + requestIndex, + new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava))) + + } + + override def afterEach(): Unit = { + flintClient.deleteIndex(requestIndex) + super.afterEach() + } + + def createSession(jobId: String, excludeJobId: String): Unit = { + val docs = Seq(s"""{ + | "state": "running", + | "lastUpdateTime": 1698796582978, + | "applicationId": "00fd777k3k3ls20p", + | "error": "", + | "sessionId": ${sessionId}, + | "jobId": \"${jobId}\", + | "type": "session", + | "excludeJobIds": [\"${excludeJobId}\"] + |}""".stripMargin) + index(requestIndex, oneNodeSetting, requestIndexMapping, docs) + } + + def startREPL(): Future[Unit] = { + val prefix = "flint-repl-test" + val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1) + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + + val futureResult = Future { + // SparkConf's constructor creates a SparkConf that loads defaults from system properties and the classpath. + // Read SparkConf.getSystemProperties + System.setProperty(DATA_SOURCE_NAME.key, "my_glue1") + System.setProperty(JOB_TYPE.key, "interactive") + System.setProperty(SESSION_ID.key, sessionId) + System.setProperty(REQUEST_INDEX.key, requestIndex) + System.setProperty(EXCLUDE_JOB_IDS.key, "00fer5qo32fa080q") + System.setProperty(REPL_INACTIVITY_TIMEOUT_MILLIS.key, "5000") + System.setProperty( + s"spark.sql.catalog.my_glue1", + "org.opensearch.sql.FlintDelegatingSessionCatalog") + System.setProperty("spark.master", "local") + System.setProperty(HOST_ENDPOINT.key, openSearchHost) + System.setProperty(HOST_PORT.key, String.valueOf(openSearchPort)) + System.setProperty(REFRESH_POLICY.key, "true") + + FlintREPL.envinromentProvider = new MockEnvironment( + Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) + FlintREPL.enableHiveSupport = false + FlintREPL.terminateJVM = false + FlintREPL.main(Array("select 1", resultIndex)) + } + futureResult.onComplete { + case Success(result) => logInfo(s"Success result: $result") + case Failure(ex) => + ex.printStackTrace() + assert(false, s"An error has occurred: ${ex.getMessage}") + } + futureResult + } + + def waitREPLStop(future: Future[Unit]): Unit = { + try { + ThreadUtils.awaitResult(future, Duration(1, MINUTES)) + } catch { + case e: Exception => + e.printStackTrace() + assert(false, "failure waiting for REPL to finish") + } + } + + def submitQuery(query: String, queryId: String): String = { + submitQuery(query, queryId, System.currentTimeMillis()) + } + + def submitQuery(query: String, queryId: String, submitTime: Long): String = { + val statementId = UUID.randomUUID().toString + + updater.upsert( + statementId, + s"""{ + | "sessionId": "${sessionId}", + | "query": "${query}", + | "applicationId": "00fd775baqpu4g0p", + | "state": "waiting", + | "submitTime": $submitTime, + | "type": "statement", + | 
"statementId": "${statementId}", + | "queryId": "${queryId}", + | "dataSourceName": "${dataSourceName}" + |}""".stripMargin) + statementId + } + + test("sanity") { + try { + createSession(jobRunId, "") + threadLocalFuture.set(startREPL()) + + val createStatement = + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\\t' + | ) + |""".stripMargin + submitQuery(s"${makeJsonCompliant(createStatement)}", "99") + + val insertStatement = + s""" + | INSERT INTO $testTable + | VALUES ('Hello', 30) + | """.stripMargin + submitQuery(s"${makeJsonCompliant(insertStatement)}", "100") + + val selectQueryId = "101" + val selectQueryStartTime = System.currentTimeMillis() + val selectQuery = s"SELECT name, age FROM $testTable".stripMargin + val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId) + + val describeStatement = s"DESC $testTable".stripMargin + val descQueryId = "102" + val descStartTime = System.currentTimeMillis() + val descStatementId = submitQuery(s"${makeJsonCompliant(describeStatement)}", descQueryId) + + val showTableStatement = + s"SHOW TABLES IN " + dataSourceName + ".default LIKE 'flint_sql_test'" + val showQueryId = "103" + val showStartTime = System.currentTimeMillis() + val showTableStatementId = + submitQuery(s"${makeJsonCompliant(showTableStatement)}", showQueryId) + + val wrongSelectQueryId = "104" + val wrongSelectQueryStartTime = System.currentTimeMillis() + val wrongSelectQuery = s"SELECT name, age FROM testTable".stripMargin + val wrongSelectStatementId = + submitQuery(s"${makeJsonCompliant(wrongSelectQuery)}", wrongSelectQueryId) + + val lateSelectQueryId = "105" + val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin + // submitted from last year. 
We won't pick it up + val lateSelectStatementId = + submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L) + + // clean up + val dropStatement = + s"""DROP TABLE $testTable""".stripMargin + submitQuery(s"${makeJsonCompliant(dropStatement)}", "999") + + val selectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = "{'name':'Hello','age':30}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 2, + s"expected schema size is 2, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime) + successValidation(result) + true + } + pollForResultAndAssert(selectQueryValidation, selectQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + selectStatementId), + s"Fail to verify for $selectStatementId.") + + val descValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 2, + s"expected result size is 2, but got ${result.results.size}") + val expectedResult0 = "{'col_name':'name','data_type':'string'}" + assert( + result.results(0).equals(expectedResult0), + s"expected result is $expectedResult0, but got ${result.results(0)}") + val expectedResult1 = "{'col_name':'age','data_type':'int'}" + assert( + result.results(1).equals(expectedResult1), + s"expected result is $expectedResult1, but got ${result.results(1)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'col_name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'comment','data_type':'string'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, descQueryId, describeStatement, descStartTime) + successValidation(result) + true + } + pollForResultAndAssert(descValidation, descQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + descStatementId), + s"Fail to verify for $descStatementId.") + + val showValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = + "{'namespace':'default','tableName':'flint_sql_test','isTemporary':false}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + 
result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'namespace','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'tableName','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'isTemporary','data_type':'boolean'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, showQueryId, showTableStatement, showStartTime) + successValidation(result) + true + } + pollForResultAndAssert(showValidation, showQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + showTableStatementId), + s"Fail to verify for $showTableStatementId.") + + val wrongSelectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + commonValidation(result, wrongSelectQueryId, wrongSelectQuery, wrongSelectQueryStartTime) + failureValidation(result) + true + } + pollForResultAndAssert(wrongSelectQueryValidation, wrongSelectQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "failed" + }, + wrongSelectStatementId), + s"Fail to verify for $wrongSelectStatementId.") + + // expect time out as this statement should not be picked up + assert( + awaitConditionForStatementOrTimeout( + statement => { + statement.state != "waiting" + }, + lateSelectStatementId), + s"Fail to verify for $lateSelectStatementId.") + } catch { + case e: Exception => + logError("Unexpected exception", e) + assert(false, "Unexpected exception") + } finally { + waitREPLStop(threadLocalFuture.get()) + threadLocalFuture.remove() + + // shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT. + } + } + + /** + * JSON does not support raw newlines (\n) in string values. All newlines must be escaped or + * removed when inside a JSON string. The same goes for tab characters, which should be + * represented as \\t. + * + * Here, I replace the newlines with spaces and escape tab characters that is being included in + * the JSON. 
+ * + * @param sqlQuery + * @return + */ + def makeJsonCompliant(sqlQuery: String): String = { + sqlQuery.replaceAll("\n", " ").replaceAll("\t", "\\\\t") + } + + def commonValidation( + result: REPLResult, + expectedQueryId: String, + expectedStatement: String, + queryStartTime: Long): Unit = { + assert( + result.jobRunId.equals(jobRunId), + s"expected job id is $jobRunId, but got ${result.jobRunId}") + assert( + result.applicationId.equals(appId), + s"expected app id is $appId, but got ${result.applicationId}") + assert( + result.dataSourceName.equals(dataSourceName), + s"expected data source is $dataSourceName, but got ${result.dataSourceName}") + assert( + result.queryId.equals(expectedQueryId), + s"expected query id is $expectedQueryId, but got ${result.queryId}") + assert( + result.queryText.equals(expectedStatement), + s"expected query is $expectedStatement, but got ${result.queryText}") + assert( + result.sessionId.equals(sessionId), + s"expected session id is $sessionId, but got ${result.sessionId}") + assert( + result.updateTime > queryStartTime, + s"expect that update time is ${result.updateTime} later than query start time $queryStartTime, but it is not") + assert( + result.queryRunTime > 0, + s"expected query run time is positive, but got ${result.queryRunTime}") + assert( + result.queryRunTime < System.currentTimeMillis() - queryStartTime, + s"expected query run time ${result.queryRunTime} should be less than ${System + .currentTimeMillis() - queryStartTime}, but it is not") + } + + def successValidation(result: REPLResult): Unit = { + assert( + result.status.equals("SUCCESS"), + s"expected status is SUCCESS, but got ${result.status}") + assert(result.error.isEmpty, s"we don't expect error, but got ${result.error}") + } + + def failureValidation(result: REPLResult): Unit = { + assert(result.status.equals("FAILED"), s"expected status is FAILED, but got ${result.status}") + assert(!result.error.isEmpty, s"we expect error, but got nothing") + } + + def pollForResultAndAssert(expected: REPLResult => Boolean, queryId: String): Unit = { + pollForResultAndAssert(osClient, expected, "queryId", queryId, 60000, resultIndex) + } + + /** + * Repeatedly polls a resource until a specified condition is met or a timeout occurs. + * + * This method continuously checks a resource for a specific condition. If the condition is met + * within the timeout period, the polling stops. If the timeout period is exceeded without the + * condition being met, an assertion error is thrown. + * + * @param osClient + * The OSClient used to poll the resource. + * @param condition + * A function that takes an instance of type T and returns a Boolean. This function defines + * the condition to be met. + * @param id + * The unique identifier of the resource to be polled. + * @param timeoutMillis + * The maximum amount of time (in milliseconds) to wait for the condition to be met. + * @param index + * The index in which the resource resides. + * @param deserialize + * A function that deserializes a String into an instance of type T. + * @param logType + * A descriptive string for logging purposes, indicating the type of resource being polled. + * @return + * whether timeout happened + * @throws OpenSearchStatusException + * if there's an issue fetching the resource. 
+ */ + def awaitConditionOrTimeout[T]( + osClient: OSClient, + expected: T => Boolean, + id: String, + timeoutMillis: Long, + index: String, + deserialize: String => T, + logType: String): Boolean = { + val getResponse = osClient.getDoc(index, id) + val startTime = System.currentTimeMillis() + breakable { + while (System.currentTimeMillis() - startTime < timeoutMillis) { + logInfo(s"Check $logType for $id") + try { + if (getResponse.isExists()) { + val instance = deserialize(getResponse.getSourceAsString) + logInfo(s"$logType $id: $instance") + if (expected(instance)) { + break + } + } + } catch { + case e: OpenSearchStatusException => logError(s"Exception while fetching $logType", e) + } + Thread.sleep(2000) // 2 seconds + } + } + System.currentTimeMillis() - startTime >= timeoutMillis + } + + def awaitConditionForStatementOrTimeout( + expected: FlintCommand => Boolean, + statementId: String): Boolean = { + awaitConditionOrTimeout[FlintCommand]( + osClient, + expected, + statementId, + 10000, + requestIndex, + FlintCommand.deserialize, + "statement") + } + + def awaitConditionForSessionOrTimeout( + expected: FlintInstance => Boolean, + sessionId: String): Boolean = { + awaitConditionOrTimeout[FlintInstance]( + osClient, + expected, + sessionId, + 10000, + requestIndex, + FlintInstance.deserialize, + "session") + } +} diff --git a/integ-test/src/test/scala/org/apache/spark/sql/JobTest.scala b/integ-test/src/test/scala/org/apache/spark/sql/JobTest.scala new file mode 100644 index 000000000..563997b7f --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/JobTest.scala @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success} +import scala.util.control.Breaks._ + +import org.opensearch.OpenSearchStatusException +import org.opensearch.flint.OpenSearchSuite +import org.opensearch.flint.core.FlintOptions +import org.opensearch.search.sort.SortOrder + +import org.apache.spark.internal.Logging + +/** + * We use a self-type annotation (self: OpenSearchSuite =>) to specify that it must be mixed into + * a class that also mixes in OpenSearchSuite. 
This way, JobTest can still use the + * openSearchOptions field. + */ +trait JobTest extends Logging { self: OpenSearchSuite => + + def pollForResultAndAssert( + osClient: OSClient, + expected: REPLResult => Boolean, + idField: String, + idValue: String, + timeoutMillis: Long, + resultIndex: String): Unit = { + val query = + s"""{ + | "bool": { + | "must": [ + | { + | "term": { + | "$idField": "$idValue" + | } + | } + | ] + | } + |}""".stripMargin + val resultReader = osClient.createQueryReader(resultIndex, query, "updateTime", SortOrder.ASC) + + val startTime = System.currentTimeMillis() + breakable { + while (System.currentTimeMillis() - startTime < timeoutMillis) { + logInfo(s"Check result for $idValue") + try { + if (resultReader.hasNext()) { + REPLResult.deserialize(resultReader.next()) match { + case Success(replResult) => + logInfo(s"repl result: $replResult") + assert(expected(replResult), s"{$query} failed.") + case Failure(exception) => + assert(false, "Failed to deserialize: " + exception.getMessage) + } + break + } + } catch { + case e: OpenSearchStatusException => logError("Exception while querying for result", e) + } + + Thread.sleep(2000) // 2 seconds + } + if (System.currentTimeMillis() - startTime >= timeoutMillis) { + assert( + false, + s"Timeout occurred after $timeoutMillis milliseconds waiting for query result.") + } + } + } + + /** + * Used to preprocess multi-line queries before comparing them, as serialized and deserialized + * queries might have different whitespace characters. + * @param s + * input + * @return + * normalized input by replacing all spaces, tabs, and newlines with single spaces. + */ + def normalizeString(s: String): String = { + // \\s+ is a regular expression that matches one or more whitespace characters, including spaces, tabs, and newlines.
+ s.replaceAll("\\s+", " ") + } // Replace each run of whitespace characters with a single space +} diff --git a/integ-test/src/test/scala/org/apache/spark/sql/REPLResult.scala b/integ-test/src/test/scala/org/apache/spark/sql/REPLResult.scala new file mode 100644 index 000000000..34dc2595c --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/REPLResult.scala @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.util.Try + +import org.json4s.{DefaultFormats, Formats} +import org.json4s.native.JsonMethods.parse + +class REPLResult( + val results: Seq[String], + val schemas: Seq[String], + val jobRunId: String, + val applicationId: String, + val dataSourceName: String, + val status: String, + val error: String, + val queryId: String, + val queryText: String, + val sessionId: String, + val updateTime: Long, + val queryRunTime: Long) { + override def toString: String = { + s"REPLResult(results=$results, schemas=$schemas, jobRunId=$jobRunId, applicationId=$applicationId, " + + s"dataSourceName=$dataSourceName, status=$status, error=$error, queryId=$queryId, queryText=$queryText, " + + s"sessionId=$sessionId, updateTime=$updateTime, queryRunTime=$queryRunTime)" + } +} + +object REPLResult { + implicit val formats: Formats = DefaultFormats + + def deserialize(jsonString: String): Try[REPLResult] = Try { + val json = parse(jsonString) + + new REPLResult( + results = (json \ "result").extract[Seq[String]], + schemas = (json \ "schema").extract[Seq[String]], + jobRunId = (json \ "jobRunId").extract[String], + applicationId = (json \ "applicationId").extract[String], + dataSourceName = (json \ "dataSourceName").extract[String], + status = (json \ "status").extract[String], + error = (json \ "error").extract[String], + queryId = (json \ "queryId").extract[String], + queryText = (json \ "queryText").extract[String], + sessionId = (json \ "sessionId").extract[String], + updateTime = (json \ "updateTime").extract[Long], + queryRunTime = (json \ "queryRunTime").extract[Long]) + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 7af1c2639..4ab3a983b 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -51,6 +51,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit protected def deleteTestIndex(testIndexNames: String*): Unit = { testIndexNames.foreach(testIndex => { + /** * Todo, if state is not valid, will throw IllegalStateException. Should check flint * .isRefresh before cleanup resource.
Current solution, (1) try to delete flint index, (2) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala index 61564546e..575f09362 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala @@ -142,7 +142,8 @@ class FlintSparkPPLCorrelationITSuite assert( thrown.getMessage === "Correlation command was called with `fields` attribute having different elements from the 'mapping' attributes ") } - test("create failing ppl correlation query with no scope - due to mismatch fields to mappings test") { + test( + "create failing ppl correlation query with no scope - due to mismatch fields to mappings test") { val thrown = intercept[IllegalStateException] { val frame = sql(s""" | source = $testTable1, $testTable2| correlate exact fields(name, country) mapping($testTable1.name = $testTable2.name) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index 62ff50fb6..32c1baa0a 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -331,6 +331,7 @@ class FlintSparkPPLFiltersITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + /** * | age_span | country | average_age | * |:---------|:--------|:------------| diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 750e228ef..df0bf5c4e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -52,9 +52,13 @@ object FlintJob extends Logging with FlintJobExecutor { * Without this setup, Spark would not recognize names in the format `my_glue1.default`. 
*/ conf.set("spark.sql.defaultCatalog", dataSource) - val jobOperator = - JobOperator(conf, query, dataSource, resultIndex, wait.equalsIgnoreCase("streaming")) + JobOperator( + createSparkSession(conf), + query, + dataSource, + resultIndex, + wait.equalsIgnoreCase("streaming")) jobOperator.start() } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index a44e70401..4a3c03d9b 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -18,11 +18,12 @@ import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createIndex, currentTimeProvider, executeQuery, getFailedData, getFormattedData, isSuperset, logError, logInfo, processQueryException, writeDataFrameToOpensearch} +import org.apache.spark.sql.FlintREPL.envinromentProvider import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType} -import org.apache.spark.sql.util.{DefaultThreadPoolFactory, RealTimeProvider, ThreadPoolFactory, TimeProvider} +import org.apache.spark.sql.util.{DefaultThreadPoolFactory, EnvironmentProvider, RealEnvironment, RealTimeProvider, ThreadPoolFactory, TimeProvider} import org.apache.spark.util.ThreadUtils trait FlintJobExecutor { @@ -30,6 +31,8 @@ trait FlintJobExecutor { var currentTimeProvider: TimeProvider = new RealTimeProvider() var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() + var envinromentProvider: EnvironmentProvider = new RealEnvironment() + var enableHiveSupport: Boolean = true // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, val resultIndexMapping = @@ -87,7 +90,11 @@ trait FlintJobExecutor { } def createSparkSession(conf: SparkConf): SparkSession = { - SparkSession.builder().config(conf).enableHiveSupport().getOrCreate() + val builder = SparkSession.builder().config(conf) + if (enableHiveSupport) { + builder.enableHiveSupport() + } + builder.getOrCreate() } private def writeData(resultData: DataFrame, resultIndex: String): Unit = { @@ -177,8 +184,8 @@ trait FlintJobExecutor { ( resultToSave, resultSchemaToSave, - sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"), - sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), dataSource, "SUCCESS", "", @@ -226,8 +233,8 @@ trait FlintJobExecutor { ( null, null, - sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"), - sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), dataSource, "FAILED", error, @@ -310,7 +317,8 @@ trait FlintJobExecutor { } } catch { case e: IllegalStateException - if e.getCause().getMessage().contains("index_not_found_exception") => + if e.getCause != null && + 
e.getCause.getMessage.contains("index_not_found_exception") => createIndex(osClient, resultIndex, resultIndexMapping) case e: InterruptedException => val error = s"Interrupted by the main thread: ${e.getMessage}" diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 6c3fd957d..2a63653e3 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -18,11 +18,15 @@ import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings import org.opensearch.flint.app.{FlintCommand, FlintInstance} import org.opensearch.flint.app.FlintInstance.formats +import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} +import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.createSparkSession import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.sql.flint.config.FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait} import org.apache.spark.util.ThreadUtils @@ -42,13 +46,16 @@ object FlintREPL extends Logging with FlintJobExecutor { private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 10 * 60 * 1000 private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 val INITIAL_DELAY_MILLIS = 3000L val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L + @volatile var earlyExitFlag: Boolean = false + // terminate the JVM when a non-daemon thread is present before exiting + var terminateJVM = true + def updateSessionIndex(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = { updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand)) } @@ -61,7 +68,7 @@ object FlintREPL extends Logging with FlintJobExecutor { // init SparkContext val conf: SparkConf = createSparkConf() - val dataSource = conf.get("spark.flint.datasource.name", "unknown") + val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "unknown") // https://github.com/opensearch-project/opensearch-spark/issues/138 /* * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, @@ -71,33 +78,36 @@ * Without this setup, Spark would not recognize names in the format `my_glue1.default`. */ conf.set("spark.sql.defaultCatalog", dataSource) - val wait = conf.get("spark.flint.job.type", "continue") + val wait = conf.get(FlintSparkConf.JOB_TYPE.key, "continue") if (wait.equalsIgnoreCase("streaming")) { logInfo(s"""streaming query ${query}""") val jobOperator = - JobOperator(conf, query, dataSource, resultIndex, true) + JobOperator(createSparkSession(conf), query, dataSource, resultIndex, true) jobOperator.start() } else { // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found.
- val sessionIndex: Option[String] = Option(conf.get("spark.flint.job.requestIndex", null)) - val sessionId: Option[String] = Option(conf.get("spark.flint.job.sessionId", null)) + val sessionIndex: Option[String] = Option(conf.get(FlintSparkConf.REQUEST_INDEX.key, null)) + val sessionId: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) if (sessionIndex.isEmpty) { - throw new IllegalArgumentException("spark.flint.job.requestIndex is not set") + throw new IllegalArgumentException(FlintSparkConf.REQUEST_INDEX.key + " is not set") } if (sessionId.isEmpty) { - throw new IllegalArgumentException("spark.flint.job.sessionId is not set") + throw new IllegalArgumentException(FlintSparkConf.SESSION_ID.key + " is not set") } val spark = createSparkSession(conf) val osClient = new OSClient(FlintSparkConf().flintOptions()) - val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown") - val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val applicationId = + envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = - conf.getLong("spark.flint.job.inactivityLimitMillis", DEFAULT_INACTIVITY_LIMIT_MILLIS) + conf.getLong( + FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS.key, + FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS) val queryExecutionTimeoutSecs: Duration = Duration( conf.getLong( "spark.flint.job.queryExecutionTimeoutSec", @@ -136,6 +146,7 @@ object FlintREPL extends Logging with FlintJobExecutor { applicationId, flintSessionIndexUpdater, jobStartTime)) { + earlyExitFlag = true return } @@ -151,7 +162,6 @@ object FlintREPL extends Logging with FlintJobExecutor { queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis) - exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { queryLoop(commandContext) } @@ -177,12 +187,12 @@ object FlintREPL extends Logging with FlintJobExecutor { // Check for non-daemon threads that may prevent the driver from shutting down. // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, // which may be due to unresolved bugs in dependencies or threads not being properly shut down. - if (threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { + if (terminateJVM && threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { logInfo("A non-daemon thread in the driver is seen.") // Exit the JVM to prevent resource leaks and potential emr-s job hung. // A zero status code is used for a graceful shutdown without indicating an error. // If exiting with non-zero status, emr-s job will fail. - // This is a part of the fault tolerance mechanism to handle such scenarios gracefully. 
+ // This is a part of the fault tolerance mechanism to handle such scenarios gracefully System.exit(0) } } @@ -232,7 +242,7 @@ object FlintREPL extends Logging with FlintJobExecutor { applicationId: String, flintSessionIndexUpdater: OpenSearchUpdater, jobStartTime: Long): Boolean = { - val confExcludeJobsOpt = conf.getOption("spark.flint.deployment.excludeJobs") + val confExcludeJobsOpt = conf.getOption(FlintSparkConf.EXCLUDE_JOB_IDS.key) confExcludeJobsOpt match { case None => @@ -505,6 +515,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } if (!canPickNextStatementResult) { + earlyExitFlag = true canProceed = false } else if (!flintReader.hasNext) { canProceed = false @@ -559,9 +570,6 @@ object FlintREPL extends Logging with FlintJobExecutor { osClient: OSClient): Unit = { try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) - // todo. it is migration plan to handle https://github - // .com/opensearch-project/sql/issues/2436. Remove sleep after issue fixed in plugin. - Thread.sleep(2000) if (flintCommand.isRunning() || flintCommand.isWaiting()) { // we have set failed state in exception handling flintCommand.complete() @@ -814,7 +822,7 @@ object FlintREPL extends Logging with FlintJobExecutor { | } |}""".stripMargin - val flintReader = osClient.createReader(sessionIndex, dsl, "submitTime") + val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) flintReader } @@ -838,7 +846,15 @@ object FlintREPL extends Logging with FlintJobExecutor { } val state = Option(source.get("state")).map(_.asInstanceOf[String]) - if (state.isDefined && state.get != "dead" && state.get != "fail") { + // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, + // it indicates that the control plane has already initiated a new session to handle remaining requests for the + // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new + // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, + // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption + // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure + // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate + // processing. 
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
index a702d2c64..a2edbe98e 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -21,14 +21,13 @@ import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.util.ThreadUtils

 case class JobOperator(
-    sparkConf: SparkConf,
+    spark: SparkSession,
     query: String,
     dataSource: String,
     resultIndex: String,
     streaming: Boolean)
     extends Logging
     with FlintJobExecutor {
-  private val spark = createSparkSession(sparkConf)

   // jvm shutdown hook
   sys.addShutdownHook(stop())
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala
index e2e44bddd..cd784e704 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala
@@ -10,6 +10,7 @@ import java.util.ArrayList
 import java.util.Locale

 import org.opensearch.action.get.{GetRequest, GetResponse}
+import org.opensearch.action.search.{SearchRequest, SearchResponse}
 import org.opensearch.client.{RequestOptions, RestHighLevelClient}
 import org.opensearch.client.indices.{CreateIndexRequest, GetIndexRequest, GetIndexResponse}
 import org.opensearch.client.indices.CreateIndexRequest
@@ -18,7 +19,7 @@ import org.opensearch.common.settings.Settings
 import org.opensearch.common.xcontent.{NamedXContentRegistry, XContentParser, XContentType}
 import org.opensearch.common.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS
 import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions}
-import org.opensearch.flint.core.storage.{FlintReader, OpenSearchScrollReader, OpenSearchUpdater}
+import org.opensearch.flint.core.storage.{FlintReader, OpenSearchQueryReader, OpenSearchScrollReader, OpenSearchUpdater}
 import org.opensearch.index.query.{AbstractQueryBuilder, MatchAllQueryBuilder, QueryBuilder}
 import org.opensearch.plugins.SearchPlugin
 import org.opensearch.search.SearchModule
@@ -117,14 +118,14 @@ class OSClient(val flintOptions: FlintOptions) extends Logging {
           String.format(
             Locale.ROOT,
             "Failed to retrieve doc %s from index %s",
-            osIndexName,
-            id),
+            id,
+            osIndexName),
           e)
       }
     }
   }

-  def createReader(indexName: String, query: String, sort: String): FlintReader = try {
+  def createScrollReader(indexName: String, query: String, sort: String): FlintReader = try {
     var queryBuilder: QueryBuilder = new MatchAllQueryBuilder
     if (!Strings.isNullOrEmpty(query)) {
       val parser =
@@ -152,4 +153,24 @@ class OSClient(val flintOptions: FlintOptions) extends Logging {
       }
     }
   }
+
+  def createQueryReader(
+      indexName: String,
+      query: String,
+      sort: String,
+      sortOrder: SortOrder): FlintReader = try {
+    var queryBuilder: QueryBuilder = new MatchAllQueryBuilder
+    if (!Strings.isNullOrEmpty(query)) {
+      val parser =
+        XContentType.JSON.xContent.createParser(xContentRegistry, IGNORE_DEPRECATIONS, query)
+      queryBuilder = AbstractQueryBuilder.parseInnerQueryBuilder(parser)
+    }
+    new OpenSearchQueryReader(
+      flintClient.createClient(),
+      indexName,
+      new SearchSourceBuilder().query(queryBuilder).sort(sort, sortOrder))
+  } catch {
+    case e: IOException =>
+      throw new RuntimeException(e)
+  }
 }
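With this rename, OSClient exposes two explicit factories: createScrollReader (the former createReader, scroll-based) and the new createQueryReader (a one-shot search with a caller-chosen sort order, backed by OpenSearchQueryReader). A minimal usage sketch modeled on the FlintREPL call site above; the index name and query DSL are placeholder values:

    import org.apache.spark.sql.OSClient
    import org.apache.spark.sql.flint.config.FlintSparkConf
    import org.opensearch.search.sort.SortOrder

    val osClient = new OSClient(FlintSparkConf().flintOptions())
    val sessionIndex = "query_execution_request" // placeholder index name
    val dsl = """{"query": {"match_all": {}}}""" // placeholder query body

    // One-shot search ordered by submit time, as FlintREPL now does:
    val reader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC)
    while (reader.hasNext) {
      val hit = reader.next() // assumes FlintReader's next(); only hasNext is visible in this patch
      // process hit ...
    }

    // Scroll-based reads keep their previous behavior under the explicit name:
    val scrollReader = osClient.createScrollReader(sessionIndex, dsl, "submitTime")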
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/EnvironmentProvider.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/EnvironmentProvider.scala
new file mode 100644
index 000000000..5b1c4e2df
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/EnvironmentProvider.scala
@@ -0,0 +1,24 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql.util
+
+/**
+ * Trait defining an interface for fetching environment variables.
+ */
+trait EnvironmentProvider {
+
+  /**
+   * Retrieves the value of an environment variable.
+   *
+   * @param name
+   *   The name of the environment variable.
+   * @param default
+   *   The default value to return if the environment variable is not set.
+   * @return
+   *   The value of the environment variable if it exists, otherwise the default value.
+   */
+  def getEnvVar(name: String, default: String): String
+}
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealEnvironment.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealEnvironment.scala
new file mode 100644
index 000000000..bf5eafce5
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealEnvironment.scala
@@ -0,0 +1,27 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql.util
+
+/**
+ * An implementation of `EnvironmentProvider` that fetches actual environment variables from the
+ * system.
+ */
+class RealEnvironment extends EnvironmentProvider {
+
+  /**
+   * Retrieves the value of an environment variable from the system or returns a default value if
+   * not present.
+   *
+   * @param name
+   *   The name of the environment variable.
+   * @param default
+   *   The default value to return if the environment variable is not set in the system.
+   * @return
+   *   The value of the environment variable if it exists in the system, otherwise the default
+   *   value.
+   */
+  def getEnvVar(name: String, default: String): String = sys.env.getOrElse(name, default)
+}
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index c3d027102..3e9d408e6 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -15,6 +15,7 @@ import scala.concurrent.duration._
 import scala.concurrent.duration.{Duration, MINUTES}
 import scala.reflect.runtime.universe.TypeTag

+import org.mockito.ArgumentMatchers.{eq => eqTo, _}
 import org.mockito.ArgumentMatchersSugar
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
@@ -22,12 +23,13 @@ import org.mockito.stubbing.Answer
 import org.opensearch.action.get.GetResponse
 import org.opensearch.flint.app.FlintCommand
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater}
+import org.opensearch.search.sort.SortOrder
 import org.scalatestplus.mockito.MockitoSugar

 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.types.{ArrayType, LongType, NullType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField, StructType}
 import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait}
 import org.apache.spark.util.ThreadUtils
@@ -411,7 +413,8 @@ class FlintREPLTest
       new ConnectException(
         "Timeout connecting to [search-foo-1-bar.eu-west-1.es.amazonaws.com:443]"))
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     when(mockReader.hasNext).thenThrow(exception)

     val maxRetries = 1
@@ -686,7 +689,8 @@ class FlintREPLTest
   test("queryLoop continue until inactivity limit is reached") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     when(mockReader.hasNext).thenReturn(false)

     val resultIndex = "testResultIndex"
@@ -736,7 +740,8 @@ class FlintREPLTest
   test("queryLoop should stop when canPickUpNextStatement is false") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     when(mockReader.hasNext).thenReturn(true)

     val resultIndex = "testResultIndex"
@@ -790,7 +795,8 @@ class FlintREPLTest
   test("queryLoop should properly shut down the thread pool after execution") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     when(mockReader.hasNext).thenReturn(false)

     val resultIndex = "testResultIndex"
@@ -838,7 +844,8 @@ class FlintREPLTest
   test("queryLoop handle exceptions within the loop gracefully") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     // Simulate an exception thrown when hasNext is called
     when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception"))

@@ -889,7 +896,8 @@ class FlintREPLTest
   test("queryLoop should correctly update loop control variables") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     val getResponse = mock[GetResponse]
     when(osClient.getDoc(*, *)).thenReturn(getResponse)
     when(getResponse.isExists()).thenReturn(false)
@@ -958,7 +966,8 @@ class FlintREPLTest
   test("queryLoop should execute loop without processing any commands") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     val getResponse = mock[GetResponse]
     when(osClient.getDoc(*, *)).thenReturn(getResponse)
     when(getResponse.isExists()).thenReturn(false)
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockEnvironment.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockEnvironment.scala
new file mode 100644
index 000000000..b6f3e3c97
--- /dev/null
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockEnvironment.scala
@@ -0,0 +1,30 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql.util
+
+/**
+ * A mock implementation of `EnvironmentProvider` for use in tests, where environment variables
+ * can be predefined.
+ *
+ * @param inputMap
+ *   A map representing the environment variables (name -> value).
+ */
+class MockEnvironment(inputMap: Map[String, String]) extends EnvironmentProvider {
+
+  /**
+   * Retrieves the value of an environment variable from the input map or returns a default value
+   * if not present.
+   *
+   * @param name
+   *   The name of the environment variable.
+   * @param default
+   *   The default value to return if the environment variable is not set in the input map.
+   * @return
+   *   The value of the environment variable from the input map if it exists, otherwise the
+   *   default value.
+   */
+  def getEnvVar(name: String, default: String): String = inputMap.getOrElse(name, default)
+}
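Taken together, EnvironmentProvider, RealEnvironment, and MockEnvironment give FlintREPL an injectable seam for the EMR-S identifiers instead of reading sys.env directly (see the envinromentProvider.getEnvVar calls in the FlintREPL diff above). A minimal sketch of how the two implementations are swapped; the literal values below are illustrative only:

    import org.apache.spark.sql.util.{EnvironmentProvider, MockEnvironment, RealEnvironment}

    // Production: resolve identifiers from the real process environment.
    val realEnv: EnvironmentProvider = new RealEnvironment
    val jobId = realEnv.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown")

    // Tests: predefine variables instead of mutating the JVM environment.
    val testEnv: EnvironmentProvider =
      new MockEnvironment(Map("SERVERLESS_EMR_JOB_ID" -> "test-job-id"))
    assert(testEnv.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") == "test-job-id")
    assert(testEnv.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") == "unknown")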