diff --git a/build.sbt b/build.sbt index 2c192b15d..a3ee68ed1 100644 --- a/build.sbt +++ b/build.sbt @@ -123,6 +123,14 @@ git.versionProperty := { * scala 13 + spark 3.2.1: https://mvnrepository.com/artifact/org.apache.spark/spark-core_2.13/3.2.1 */ val VersionMatrix: Map[String, VersionDependency] = Map( + "spark-avro" -> VersionDependency( + Seq( + "org.apache.spark" %% "spark-avro", + ), + Some(spark2_4_0), + Some(spark3_1_1), + Some(spark3_2_1) + ), "spark-sql" -> VersionDependency( Seq( "org.apache.spark" %% "spark-sql", @@ -395,6 +403,7 @@ lazy val flink = (project in file("flink")) libraryDependencies ++= fromMatrix(scalaVersion.value, "avro", "spark-all/provided", + "spark-avro/provided", "scala-parallel-collections", "flink") ) diff --git a/flink/src/main/scala/ai/chronon/flink/AvroFlinkSource.scala b/flink/src/main/scala/ai/chronon/flink/AvroFlinkSource.scala new file mode 100644 index 000000000..4901f5e58 --- /dev/null +++ b/flink/src/main/scala/ai/chronon/flink/AvroFlinkSource.scala @@ -0,0 +1,37 @@ +package ai.chronon.flink + +import org.apache.flink.api.common.functions.RichMapFunction +import org.apache.flink.api.scala.createTypeInformation +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} +import org.apache.spark.sql.Row +import org.apache.spark.sql.avro.AvroBytesToSparkRow + + +class AvroFlinkSource(source: DataStream[Array[Byte]], + avroSchemaJson: String, + avroOptions: Map[String, String]) extends FlinkSource[Row] { + + val (_, encoder) = AvroBytesToSparkRow.mapperAndEncoder(avroSchemaJson, avroOptions) + + + override def getDataStream(topic: String, groupName: String)( + env: StreamExecutionEnvironment, + parallelism: Int): DataStream[Row] = { + val mapper = new AvroBytesToSparkRowFunction(avroSchemaJson, avroOptions) + source.map(mapper) + } +} + +class AvroBytesToSparkRowFunction(avroSchemaJson: String, avroOptions: Map[String, String]) extends RichMapFunction[Array[Byte], Row] { + + @transient private var mapper: Array[Byte] => Row = _ + override def open(configuration: Configuration): Unit = { + val (_mapper, _) = AvroBytesToSparkRow.mapperAndEncoder(avroSchemaJson, avroOptions) + mapper = _mapper + } + + override def map(value: Array[Byte]): Row = { + mapper(value) + } +} \ No newline at end of file diff --git a/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala b/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala index 25b7f0039..d25b54d7e 100644 --- a/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala +++ b/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala @@ -1,19 +1,13 @@ package ai.chronon.flink import ai.chronon.aggregator.windowing.ResolutionUtils -import ai.chronon.api.{DataType} +import ai.chronon.api.DataType import ai.chronon.api.Extensions.{GroupByOps, SourceOps} -import ai.chronon.flink.window.{ - AlwaysFireOnElementTrigger, - FlinkRowAggProcessFunction, - FlinkRowAggregationFunction, - KeySelector, - TimestampedTile -} +import ai.chronon.flink.window.{AlwaysFireOnElementTrigger, FlinkRowAggProcessFunction, FlinkRowAggregationFunction, KeySelector, TimestampedTile} import ai.chronon.online.{GroupByServingInfoParsed, SparkConversions} import ai.chronon.online.KVStore.PutRequest import org.apache.flink.streaming.api.scala.{DataStream, OutputTag, StreamExecutionEnvironment} -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.{Encoder, Row} import org.apache.flink.api.scala._ import org.apache.flink.streaming.api.functions.async.RichAsyncFunction import org.apache.flink.streaming.api.windowing.assigners.{TumblingEventTimeWindows, WindowAssigner} @@ -194,3 +188,19 @@ class FlinkJob[T](eventSrc: FlinkSource[T], ) } } + +object FlinkJob { + + def fromAvro(avroSource: AvroFlinkSource, + sinkFn: RichAsyncFunction[PutRequest, WriteResponse], + groupByServingInfoParsed: GroupByServingInfoParsed, + parallelism: Int): FlinkJob[Row] = { + + new FlinkJob[Row]( + avroSource, + sinkFn: RichAsyncFunction[PutRequest, WriteResponse], + groupByServingInfoParsed: GroupByServingInfoParsed, + avroSource.encoder, + parallelism: Int) + } +} \ No newline at end of file diff --git a/flink/src/main/scala/org/apache/spark/sql/avro/AvroBytesToSparkRow.scala b/flink/src/main/scala/org/apache/spark/sql/avro/AvroBytesToSparkRow.scala new file mode 100644 index 000000000..3e6105881 --- /dev/null +++ b/flink/src/main/scala/org/apache/spark/sql/avro/AvroBytesToSparkRow.scala @@ -0,0 +1,23 @@ +package org.apache.spark.sql.avro + +import org.apache.spark.sql.{Encoder, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.types.StructType + +/** + * Thin wrapper to [[AvroDataToCatalyst]] that backs `from_avro` https://spark.apache.org/docs/3.5.1/sql-data-sources-avro.html#to_avro-and-from_avro + * SparkSQL doesn't have this registered, so instead we use the underlying functionality instead. + */ +object AvroBytesToSparkRow { + + def mapperAndEncoder(avroSchemaJson: String, options: Map[String, String] = Map()): (Array[Byte] => Row, Encoder[Row]) = { + val catalyst = AvroDataToCatalyst(null, avroSchemaJson, options) + val sparkSchema = catalyst.dataType.asInstanceOf[StructType] + val rowEncoder = RowEncoder(sparkSchema) + val sparkRowDeser = RowEncoder(sparkSchema).resolveAndBind().createDeserializer() + val mapper = (bytes: Array[Byte]) => + sparkRowDeser(catalyst.nullSafeEval(bytes).asInstanceOf[InternalRow]) + (mapper, rowEncoder) + } +} diff --git a/flink/src/test/scala/ai/chronon/flink/test/AvroFlinkSourceTestUtils.scala b/flink/src/test/scala/ai/chronon/flink/test/AvroFlinkSourceTestUtils.scala new file mode 100644 index 000000000..f72a2ed21 --- /dev/null +++ b/flink/src/test/scala/ai/chronon/flink/test/AvroFlinkSourceTestUtils.scala @@ -0,0 +1,54 @@ +package ai.chronon.flink.test + +import org.apache.avro.Schema + +import java.io.ByteArrayOutputStream +import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} +import org.apache.avro.io.{DatumWriter, EncoderFactory} + +object AvroFlinkSourceTestUtils { + + val e2EAvroSchema = """ +{ + "type": "record", + "name": "E2ETestEvent", + "fields": [ + { "name": "id", "type": "string" }, + { "name": "int_val", "type": "int" }, + { "name": "double_val", "type": "double" }, + { "name": "created", "type": "long" } + ] +} +""".stripMargin + private val e2ETestEventSchema: Schema = new Schema.Parser().parse(e2EAvroSchema) + + def toAvroBytes(event: E2ETestEvent): Array[Byte] = { + // Create a record with the parsed schema + val record = new GenericData.Record(e2ETestEventSchema) + + // Populate the record fields from the case class + record.put("id", event.id) + record.put("int_val", event.int_val) + record.put("double_val", event.double_val) + record.put("created", event.created) + + avroBytesFromGenericRecord(record) + } + private def avroBytesFromGenericRecord(record: GenericRecord): Array[Byte] = { + // Create a ByteArrayOutputStream to hold the serialized data + val byteStream = new ByteArrayOutputStream() + + // Create a binary encoder that writes to the byteStream + val encoder = EncoderFactory.get.binaryEncoder(byteStream, null) + + // Create a datum writer for the record's schema + val writer: DatumWriter[GenericRecord] = new GenericDatumWriter[GenericRecord](record.getSchema) + + // Write the record data to the encoder + writer.write(record, encoder) + encoder.flush() + + // Return the serialized Avro bytes + byteStream.toByteArray + } +} diff --git a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala index 83f4bd55d..20e5e507e 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala @@ -1,7 +1,7 @@ package ai.chronon.flink.test import ai.chronon.flink.window.{TimestampedIR, TimestampedTile} -import ai.chronon.flink.{FlinkJob, SparkExpressionEvalFn} +import ai.chronon.flink.{AvroFlinkSource, FlinkJob, SparkExpressionEvalFn} import ai.chronon.online.{Api, GroupByServingInfoParsed} import ai.chronon.online.KVStore.PutRequest import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration @@ -12,6 +12,7 @@ import org.junit.Assert.assertEquals import org.junit.{After, Before, Test} import org.mockito.Mockito.withSettings import org.scalatestplus.mockito.MockitoSugar.mock +import org.apache.flink.api.scala.createTypeInformation import scala.jdk.CollectionConverters.asScalaBufferConverter @@ -158,4 +159,40 @@ class FlinkJobIntegrationTest { assertEquals(expectedFinalIRsPerKey, finalIRsPerKey) } + + @Test + def testFlinkJobFromAvroEndToEnd(): Unit = { + implicit val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment + + val elements = Seq( + E2ETestEvent("test1", 12, 1.5, 1699366993123L), + E2ETestEvent("test2", 13, 1.6, 1699366993124L), + E2ETestEvent("test3", 14, 1.7, 1699366993125L) + ).map(AvroFlinkSourceTestUtils.toAvroBytes) + + val avroByteStream = env.fromCollection(elements) + + val source = new AvroFlinkSource(avroByteStream, AvroFlinkSourceTestUtils.e2EAvroSchema) + val groupBy = FlinkTestUtils.makeGroupBy(Seq("id")) + val encoder = Encoders.product[E2ETestEvent] + + val outputSchema = new SparkExpressionEvalFn(encoder, groupBy).getOutputSchema + + val groupByServingInfoParsed = + FlinkTestUtils.makeTestGroupByServingInfoParsed(groupBy, encoder.schema, outputSchema) + val mockApi = mock[Api](withSettings().serializable()) + val writerFn = new MockAsyncKVStoreWriter(Seq(true), mockApi, "testFlinkJobFromAvroEndToEnd") + val job = FlinkJob.fromAvro(source, writerFn, groupByServingInfoParsed, 2) + + job.runGroupByJob(env).addSink(new CollectSink) + + env.execute("FlinkJobIntegrationTest") + + // capture the datastream of the 'created' timestamps of all the written out events + val writeEventCreatedDS = CollectSink.values.asScala + assert(writeEventCreatedDS.size == elements.size) + // check that all the writes were successful + assertEquals(writeEventCreatedDS.map(_.status), Seq(true, true, true)) + } + } diff --git a/foo.txt b/foo.txt new file mode 100644 index 000000000..e69de29bb diff --git a/publish_jar.Dockerfile b/publish_jar.Dockerfile new file mode 100644 index 000000000..761005a5d --- /dev/null +++ b/publish_jar.Dockerfile @@ -0,0 +1,14 @@ +FROM houpy0829/chronon-ci:base--f87f50dc520f7a73894ae024eb78bd305d5b08e2 + +COPY . /workspace + +WORKDIR /workspace + +# Initialize conda and activate environment in the same shell +SHELL ["/bin/bash", "-c"] +ENV SBT_OPTS="-Xmx4G -Xms2G" +RUN source /opt/conda/etc/profile.d/conda.sh && \ + conda init bash && \ + conda activate chronon_py + +ENTRYPOINT ["/bin/bash"]% \ No newline at end of file