[wip] FlinkJob.fromAvro operates on Spark row instead of compiled type #894

Open · wants to merge 2 commits into base: main
9 changes: 9 additions & 0 deletions build.sbt
@@ -123,6 +123,14 @@ val VersionMatrix: Map[String, VersionDependency] = Map(
* scala 13 + spark 3.2.1: https://mvnrepository.com/artifact/org.apache.spark/spark-core_2.13/3.2.1
*/
val VersionMatrix: Map[String, VersionDependency] = Map(
"spark-avro" -> VersionDependency(
Seq(
"org.apache.spark" %% "spark-avro",
),
Some(spark2_4_0),
Some(spark3_1_1),
Some(spark3_2_1)
),
"spark-sql" -> VersionDependency(
Seq(
"org.apache.spark" %% "spark-sql",
@@ -395,6 +403,7 @@ lazy val flink = (project in file("flink"))
libraryDependencies ++= fromMatrix(scalaVersion.value,
"avro",
"spark-all/provided",
"spark-avro/provided",
"scala-parallel-collections",
"flink")
)
37 changes: 37 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/AvroFlinkSource.scala
@@ -0,0 +1,37 @@
package ai.chronon.flink

import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.scala.createTypeInformation
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.spark.sql.Row
import org.apache.spark.sql.avro.AvroBytesToSparkRow


class AvroFlinkSource(source: DataStream[Array[Byte]],
avroSchemaJson: String,
avroOptions: Map[String, String] = Map()) extends FlinkSource[Row] {

val (_, encoder) = AvroBytesToSparkRow.mapperAndEncoder(avroSchemaJson, avroOptions)


override def getDataStream(topic: String, groupName: String)(
env: StreamExecutionEnvironment,
parallelism: Int): DataStream[Row] = {
val mapper = new AvroBytesToSparkRowFunction(avroSchemaJson, avroOptions)
source.map(mapper)
}
}

class AvroBytesToSparkRowFunction(avroSchemaJson: String, avroOptions: Map[String, String]) extends RichMapFunction[Array[Byte], Row] {

@transient private var mapper: Array[Byte] => Row = _
override def open(configuration: Configuration): Unit = {
val (_mapper, _) = AvroBytesToSparkRow.mapperAndEncoder(avroSchemaJson, avroOptions)
mapper = _mapper
}

override def map(value: Array[Byte]): Row = {
mapper(value)
}
}
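For reference, a minimal wiring sketch (not part of this diff; `avroPayloads` and `schemaJson` are illustrative placeholders, and the default empty `avroOptions` is assumed), mirroring the integration test further down:

import org.apache.flink.api.scala.createTypeInformation
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment

val env = StreamExecutionEnvironment.getExecutionEnvironment
val avroPayloads: Seq[Array[Byte]] = ??? // Avro-encoded events
val schemaJson: String = ???             // writer schema as JSON

// Decode the byte stream into Spark Rows; the exposed encoder is what
// FlinkJob.fromAvro passes along so downstream operators work over Row.
val source = new AvroFlinkSource(env.fromCollection(avroPayloads), schemaJson)
val rows = source.getDataStream("events-topic", "events-group")(env, 1)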
28 changes: 19 additions & 9 deletions flink/src/main/scala/ai/chronon/flink/FlinkJob.scala
@@ -1,19 +1,13 @@
package ai.chronon.flink

import ai.chronon.aggregator.windowing.ResolutionUtils
import ai.chronon.api.{DataType}
import ai.chronon.api.DataType
import ai.chronon.api.Extensions.{GroupByOps, SourceOps}
import ai.chronon.flink.window.{
AlwaysFireOnElementTrigger,
FlinkRowAggProcessFunction,
FlinkRowAggregationFunction,
KeySelector,
TimestampedTile
}
import ai.chronon.flink.window.{AlwaysFireOnElementTrigger, FlinkRowAggProcessFunction, FlinkRowAggregationFunction, KeySelector, TimestampedTile}
import ai.chronon.online.{GroupByServingInfoParsed, SparkConversions}
import ai.chronon.online.KVStore.PutRequest
import org.apache.flink.streaming.api.scala.{DataStream, OutputTag, StreamExecutionEnvironment}
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.{Encoder, Row}
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.functions.async.RichAsyncFunction
import org.apache.flink.streaming.api.windowing.assigners.{TumblingEventTimeWindows, WindowAssigner}
@@ -194,3 +188,19 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
)
}
}

object FlinkJob {

def fromAvro(avroSource: AvroFlinkSource,
sinkFn: RichAsyncFunction[PutRequest, WriteResponse],
groupByServingInfoParsed: GroupByServingInfoParsed,
parallelism: Int): FlinkJob[Row] = {

new FlinkJob[Row](
avroSource,
sinkFn,
groupByServingInfoParsed,
avroSource.encoder,
parallelism)
}
}
@@ -0,0 +1,23 @@
package org.apache.spark.sql.avro

import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.StructType

/**
* Thin wrapper around [[AvroDataToCatalyst]], the expression that backs `from_avro`
* (https://spark.apache.org/docs/3.5.1/sql-data-sources-avro.html#to_avro-and-from_avro).
* Spark SQL doesn't register `from_avro` as a SQL function, so we invoke the underlying expression directly.
*/
object AvroBytesToSparkRow {

def mapperAndEncoder(avroSchemaJson: String, options: Map[String, String] = Map()): (Array[Byte] => Row, Encoder[Row]) = {
val catalyst = AvroDataToCatalyst(null, avroSchemaJson, options)
val sparkSchema = catalyst.dataType.asInstanceOf[StructType]
val rowEncoder = RowEncoder(sparkSchema)
val sparkRowDeser = rowEncoder.resolveAndBind().createDeserializer()
val mapper = (bytes: Array[Byte]) =>
sparkRowDeser(catalyst.nullSafeEval(bytes).asInstanceOf[InternalRow])
(mapper, rowEncoder)
}
}
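For reference, a minimal sketch of exercising mapperAndEncoder on its own (the schema literal and `payload` below are illustrative placeholders, not part of this diff):

import org.apache.spark.sql.avro.AvroBytesToSparkRow

val schemaJson =
  """{"type":"record","name":"Event","fields":[{"name":"id","type":"string"},{"name":"amount","type":"double"}]}"""

// mapper converts Avro-encoded bytes into a Spark Row; the encoder carries the StructType derived from the Avro schema
val (toRow, rowEncoder) = AvroBytesToSparkRow.mapperAndEncoder(schemaJson)
println(rowEncoder.schema)

val payload: Array[Byte] = ??? // an Avro binary record written with schemaJson
val row = toRow(payload)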
@@ -0,0 +1,54 @@
package ai.chronon.flink.test

import org.apache.avro.Schema

import java.io.ByteArrayOutputStream
import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
import org.apache.avro.io.{DatumWriter, EncoderFactory}

object AvroFlinkSourceTestUtils {

val e2EAvroSchema = """
{
"type": "record",
"name": "E2ETestEvent",
"fields": [
{ "name": "id", "type": "string" },
{ "name": "int_val", "type": "int" },
{ "name": "double_val", "type": "double" },
{ "name": "created", "type": "long" }
]
}
""".stripMargin
private val e2ETestEventSchema: Schema = new Schema.Parser().parse(e2EAvroSchema)

def toAvroBytes(event: E2ETestEvent): Array[Byte] = {
// Create a record with the parsed schema
val record = new GenericData.Record(e2ETestEventSchema)

// Populate the record fields from the case class
record.put("id", event.id)
record.put("int_val", event.int_val)
record.put("double_val", event.double_val)
record.put("created", event.created)

avroBytesFromGenericRecord(record)
}
private def avroBytesFromGenericRecord(record: GenericRecord): Array[Byte] = {
// Create a ByteArrayOutputStream to hold the serialized data
val byteStream = new ByteArrayOutputStream()

// Create a binary encoder that writes to the byteStream
val encoder = EncoderFactory.get.binaryEncoder(byteStream, null)

// Create a datum writer for the record's schema
val writer: DatumWriter[GenericRecord] = new GenericDatumWriter[GenericRecord](record.getSchema)

// Write the record data to the encoder
writer.write(record, encoder)
encoder.flush()

// Return the serialized Avro bytes
byteStream.toByteArray
}
}
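As a sanity check (not part of this diff), the bytes produced by toAvroBytes can be decoded back into a GenericRecord with the standard Avro reader:

import org.apache.avro.Schema
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
import org.apache.avro.io.DecoderFactory

val bytes = AvroFlinkSourceTestUtils.toAvroBytes(E2ETestEvent("test1", 12, 1.5, 1699366993123L))
val schema = new Schema.Parser().parse(AvroFlinkSourceTestUtils.e2EAvroSchema)
val reader = new GenericDatumReader[GenericRecord](schema)
val decoder = DecoderFactory.get.binaryDecoder(bytes, null)
val decoded = reader.read(null, decoder)
assert(decoded.get("id").toString == "test1") // Avro returns a Utf8, hence toString
assert(decoded.get("int_val") == 12)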
@@ -1,7 +1,7 @@
package ai.chronon.flink.test

import ai.chronon.flink.window.{TimestampedIR, TimestampedTile}
import ai.chronon.flink.{FlinkJob, SparkExpressionEvalFn}
import ai.chronon.flink.{AvroFlinkSource, FlinkJob, SparkExpressionEvalFn}
import ai.chronon.online.{Api, GroupByServingInfoParsed}
import ai.chronon.online.KVStore.PutRequest
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration
@@ -12,6 +12,7 @@ import org.junit.Assert.assertEquals
import org.junit.{After, Before, Test}
import org.mockito.Mockito.withSettings
import org.scalatestplus.mockito.MockitoSugar.mock
import org.apache.flink.api.scala.createTypeInformation

import scala.jdk.CollectionConverters.asScalaBufferConverter

@@ -158,4 +159,40 @@ class FlinkJobIntegrationTest {

assertEquals(expectedFinalIRsPerKey, finalIRsPerKey)
}

@Test
def testFlinkJobFromAvroEndToEnd(): Unit = {
implicit val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

val elements = Seq(
E2ETestEvent("test1", 12, 1.5, 1699366993123L),
E2ETestEvent("test2", 13, 1.6, 1699366993124L),
E2ETestEvent("test3", 14, 1.7, 1699366993125L)
).map(AvroFlinkSourceTestUtils.toAvroBytes)

val avroByteStream = env.fromCollection(elements)

val source = new AvroFlinkSource(avroByteStream, AvroFlinkSourceTestUtils.e2EAvroSchema)
val groupBy = FlinkTestUtils.makeGroupBy(Seq("id"))
val encoder = Encoders.product[E2ETestEvent]

val outputSchema = new SparkExpressionEvalFn(encoder, groupBy).getOutputSchema

val groupByServingInfoParsed =
FlinkTestUtils.makeTestGroupByServingInfoParsed(groupBy, encoder.schema, outputSchema)
val mockApi = mock[Api](withSettings().serializable())
val writerFn = new MockAsyncKVStoreWriter(Seq(true), mockApi, "testFlinkJobFromAvroEndToEnd")
val job = FlinkJob.fromAvro(source, writerFn, groupByServingInfoParsed, 2)

job.runGroupByJob(env).addSink(new CollectSink)

env.execute("FlinkJobIntegrationTest")

// collect the write responses for all the events written out
val writeEventCreatedDS = CollectSink.values.asScala
assert(writeEventCreatedDS.size == elements.size)
// check that all the writes were successful
assertEquals(writeEventCreatedDS.map(_.status), Seq(true, true, true))
}

}
Empty file added foo.txt
14 changes: 14 additions & 0 deletions publish_jar.Dockerfile
@@ -0,0 +1,14 @@
FROM houpy0829/chronon-ci:base--f87f50dc520f7a73894ae024eb78bd305d5b08e2

COPY . /workspace

WORKDIR /workspace

# Initialize conda and activate environment in the same shell
SHELL ["/bin/bash", "-c"]
ENV SBT_OPTS="-Xmx4G -Xms2G"
RUN source /opt/conda/etc/profile.d/conda.sh && \
conda init bash && \
conda activate chronon_py

ENTRYPOINT ["/bin/bash"]