Add spark sql app to snapshot workflow (opensearch-project#19)
Signed-off-by: Rupal Mahajan <[email protected]>
rupal-bq authored Sep 8, 2023
1 parent 0d3fc0e commit 879a541
Showing 5 changed files with 305 additions and 2 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/snapshot-publish.yml
@@ -27,7 +27,9 @@ jobs:
           java-version: 11
 
       - name: Publish to Local Maven
-        run: sbt standaloneCosmetic/publishM2
+        run: |
+          sbt standaloneCosmetic/publishM2
+          sbt sparkSqlApplicationCosmetic/publishM2
       - uses: actions/checkout@v3
         with:
19 changes: 18 additions & 1 deletion build.sbt
@@ -43,7 +43,7 @@ lazy val commonSettings = Seq(
   Test / test := ((Test / test) dependsOn testScalastyle).value)
 
 lazy val root = (project in file("."))
-  .aggregate(flintCore, flintSparkIntegration)
+  .aggregate(flintCore, flintSparkIntegration, sparkSqlApplication)
   .disablePlugins(AssemblyPlugin)
   .settings(name := "flint", publish / skip := true)
@@ -125,6 +125,23 @@ lazy val standaloneCosmetic = project
     exportJars := true,
     Compile / packageBin := (flintSparkIntegration / assembly).value)
 
+lazy val sparkSqlApplication = (project in file("spark-sql-application"))
+  .settings(
+    commonSettings,
+    name := "sql-job",
+    scalaVersion := scala212,
+    libraryDependencies ++= Seq(
+      "org.scalatest" %% "scalatest" % "3.2.15" % "test"),
+    libraryDependencies ++= deps(sparkVersion))
+
+lazy val sparkSqlApplicationCosmetic = project
+  .settings(
+    name := "opensearch-spark-sql-application",
+    commonSettings,
+    releaseSettings,
+    exportJars := true,
+    Compile / packageBin := (sparkSqlApplication / assembly).value)
+
 lazy val releaseSettings = Seq(
   publishMavenStyle := true,
   publishArtifact := true,
109 changes: 109 additions & 0 deletions spark-sql-application/README.md
@@ -0,0 +1,109 @@
# Spark SQL Application

This application executes a SQL query and stores the result in an OpenSearch index in the following format:
```
"stepId":"<emr-step-id>",
"applicationId":"<spark-application-id>"
"schema": "json blob",
"result": "json blob"
```

## Prerequisites

+ Spark 3.3.1
+ Scala 2.12.15
+ flint-spark-integration

## Usage

To use this application, you can run Spark with the Flint extension:

```
./bin/spark-submit \
  --class org.opensearch.sql.SQLJob \
  --jars <flint-spark-integration-jar> \
  sql-job.jar \
  <spark-sql-query> \
  <opensearch-index> \
  <opensearch-host> \
  <opensearch-port> \
  <opensearch-scheme> \
  <opensearch-auth> \
  <opensearch-region>
```

## Result Specifications

The following example shows how the result is written to the OpenSearch index after query execution.

Assume the SQL query result is:
```
+------+------+
|Letter|Number|
+------+------+
|A     |1     |
|B     |2     |
|C     |3     |
+------+------+
```
The OpenSearch index document will look like:
```json
{
  "_index" : ".query_execution_result",
  "_id" : "A2WOsYgBMUoqCqlDJHrn",
  "_score" : 1.0,
  "_source" : {
    "result" : [
      "{'Letter':'A','Number':1}",
      "{'Letter':'B','Number':2}",
      "{'Letter':'C','Number':3}"
    ],
    "schema" : [
      "{'column_name':'Letter','data_type':'string'}",
      "{'column_name':'Number','data_type':'integer'}"
    ],
    "stepId" : "s-JZSB1139WIVU",
    "applicationId" : "application_1687726870985_0003"
  }
}
```

## Build

To build the application and publish it to the local Maven repository, run:

```
sbt clean sparkSqlApplicationCosmetic/publishM2
```
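
The `publishM2` task installs the assembly jar into the local Maven repository (typically `~/.m2/repository`). As a rough sketch of how another sbt build could resolve that locally published artifact — the artifact name comes from the `sparkSqlApplicationCosmetic` project above, while the organization and version shown here are assumptions and should be checked against the published POM:

```scala
// Hypothetical consumer build.sbt: organization and version are illustrative assumptions.
resolvers += Resolver.mavenLocal

libraryDependencies +=
  "org.opensearch" %% "opensearch-spark-sql-application" % "<version>"
```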

## Test

To run tests, you can use:

```
sbt test
```

## Scalastyle

To check code with scalastyle, you can run:

```
sbt scalastyle
```

## Code of Conduct

This project has adopted an [Open Source Code of Conduct](../CODE_OF_CONDUCT.md).

## Security

If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.

## License

See the [LICENSE](../LICENSE.txt) file for our project's licensing. We will ask you to confirm the licensing of your contribution.

## Copyright

Copyright OpenSearch Contributors. See [NOTICE](../NOTICE) for details.
112 changes: 112 additions & 0 deletions spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
@@ -0,0 +1,112 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types._

/**
 * Spark SQL Application entrypoint
 *
 * @param args(0)
 *   sql query
 * @param args(1)
 *   opensearch index name
 * @param args(2-6)
 *   opensearch connection values required for flint-integration jar.
 *   host, port, scheme, auth, region respectively.
 * @return
 *   write sql query result to given opensearch index
 */
object SQLJob {
  def main(args: Array[String]) {
    // Get the SQL query and Opensearch Config from the command line arguments
    val query = args(0)
    val index = args(1)
    val host = args(2)
    val port = args(3)
    val scheme = args(4)
    val auth = args(5)
    val region = args(6)

    val conf: SparkConf = new SparkConf()
      .setAppName("SQLJob")
      .set("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
      .set("spark.datasource.flint.host", host)
      .set("spark.datasource.flint.port", port)
      .set("spark.datasource.flint.scheme", scheme)
      .set("spark.datasource.flint.auth", auth)
      .set("spark.datasource.flint.region", region)

    // Create a SparkSession
    val spark = SparkSession.builder().config(conf).enableHiveSupport().getOrCreate()

    try {
      // Execute SQL query
      val result: DataFrame = spark.sql(query)

      // Get Data
      val data = getFormattedData(result, spark)

      // Write data to OpenSearch index
      val aos = Map(
        "host" -> host,
        "port" -> port,
        "scheme" -> scheme,
        "auth" -> auth,
        "region" -> region)

      data.write
        .format("flint")
        .options(aos)
        .mode("append")
        .save(index)

    } finally {
      // Stop SparkSession
      spark.stop()
    }
  }

  /**
   * Create a new formatted dataframe with json result, json schema and EMR_STEP_ID.
   *
   * @param result
   *   sql query result dataframe
   * @param spark
   *   spark session
   * @return
   *   dataframe with result, schema and emr step id
   */
  def getFormattedData(result: DataFrame, spark: SparkSession): DataFrame = {
    // Create the schema dataframe
    val schemaRows = result.schema.fields.map { field =>
      Row(field.name, field.dataType.typeName)
    }
    val resultSchema = spark.createDataFrame(spark.sparkContext.parallelize(schemaRows),
      StructType(Seq(
        StructField("column_name", StringType, nullable = false),
        StructField("data_type", StringType, nullable = false))))

    // Define the data schema
    val schema = StructType(Seq(
      StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
      StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
      StructField("stepId", StringType, nullable = true),
      StructField("applicationId", StringType, nullable = true)))

    // Create the data rows
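    // Double quotes in each JSON string are swapped for single quotes (after escaping
    // any pre-existing single quotes) to match the stored document format shown in the README.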
    val rows = Seq((
      result.toJSON.collect.toList.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")),
      resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")),
      sys.env.getOrElse("EMR_STEP_ID", "unknown"),
      spark.sparkContext.applicationId))

    // Create the DataFrame for data
    spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
  }
}
63 changes: 63 additions & 0 deletions spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala
@@ -0,0 +1,63 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.sql

import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}


class SQLJobTest extends SparkFunSuite with Matchers {

  val spark = SparkSession.builder().appName("Test").master("local").getOrCreate()

  // Define input dataframe
  val inputSchema = StructType(Seq(
    StructField("Letter", StringType, nullable = false),
    StructField("Number", IntegerType, nullable = false)
  ))
  val inputRows = Seq(
    Row("A", 1),
    Row("B", 2),
    Row("C", 3)
  )
  val input: DataFrame = spark.createDataFrame(
    spark.sparkContext.parallelize(inputRows), inputSchema)

test("Test getFormattedData method") {
// Define expected dataframe
val expectedSchema = StructType(Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("stepId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true)
))
val expectedRows = Seq(
Row(
Array("{'Letter':'A','Number':1}",
"{'Letter':'B','Number':2}",
"{'Letter':'C','Number':3}"),
Array("{'column_name':'Letter','data_type':'string'}",
"{'column_name':'Number','data_type':'integer'}"),
"unknown",
spark.sparkContext.applicationId
)
)
val expected: DataFrame = spark.createDataFrame(
spark.sparkContext.parallelize(expectedRows), expectedSchema)

// Compare the result
val result = SQLJob.getFormattedData(input, spark)
assertEqualDataframe(expected, result)
}

def assertEqualDataframe(expected: DataFrame, result: DataFrame): Unit = {
assert(expected.schema === result.schema)
assert(expected.collect() === result.collect())
}
}
