diff --git a/spark/spark-3.5/pom.xml b/spark/spark-3.5/pom.xml
index 59a8cf7c82d..d1a616944a1 100644
--- a/spark/spark-3.5/pom.xml
+++ b/spark/spark-3.5/pom.xml
@@ -48,6 +48,12 @@
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-java-root</artifactId>
+      <version>18.2.0</version>
+      <type>pom</type>
+    </dependency>
     <dependency>
       <groupId>org.apache.sedona</groupId>
       <artifactId>sedona-spark-common-${spark.compat.version}_${scala.compat.version}</artifactId>
diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala
new file mode 100644
index 00000000000..750e99ee669
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.datasources.arrow
+
+import org.apache.spark.sql.connector.write.BatchWrite
+import org.apache.spark.sql.connector.write.{DataWriterFactory, PhysicalWriteInfo}
+import org.apache.spark.sql.connector.write.WriterCommitMessage
+import org.apache.spark.sql.connector.write.LogicalWriteInfo
+import org.apache.spark.sql.catalyst.InternalRow
+
+case class ArrowBatchWrite(logicalInfo: LogicalWriteInfo) extends BatchWrite {
+  def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
+    return new ArrowDataWriterFactory(logicalInfo, info)
+  }
+
+  def commit(messages: Array[WriterCommitMessage]): Unit = {}
+
+  def abort(messages: Array[WriterCommitMessage]): Unit = {}
+
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataSource.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataSource.scala
new file mode 100644
index 00000000000..20412eec615
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataSource.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.sedona.sql.datasources.arrow + +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.jdk.CollectionConverters._ +import scala.util.Try +import org.apache.sedona.sql.datasources.geopackage.ArrowTable + +class ArrowDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def fallbackFileFormat: Class[_ <: FileFormat] = { + null + } + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + ArrowTable("", sparkSession, options, getPaths(options), None) + } + + override def shortName(): String = "arrows" +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.scala new file mode 100644 index 00000000000..ad8c631709c --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.arrow + +import org.apache.spark.sql.connector.write.LogicalWriteInfo +import org.apache.spark.sql.connector.write.PhysicalWriteInfo +import org.apache.spark.sql.connector.write.DataWriterFactory +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write.DataWriter + +case class ArrowDataWriterFactory(logicalInfo: LogicalWriteInfo, physicalInfo: PhysicalWriteInfo) + extends DataWriterFactory { + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + return new ArrowWriter(logicalInfo, physicalInfo, partitionId, taskId) + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala new file mode 100644 index 00000000000..aaedcab8f12 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.arrow + +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} +import org.apache.arrow.vector.complex.StructVector + +private[arrow] object ArrowEncoderUtils { + object Classes { + val WRAPPED_ARRAY: Class[_] = classOf[scala.collection.mutable.WrappedArray[_]] + val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]] + val MAP: Class[_] = classOf[scala.collection.Map[_, _]] + val JLIST: Class[_] = classOf[java.util.List[_]] + val JMAP: Class[_] = classOf[java.util.Map[_, _]] + } + + def isSubClass(cls: Class[_], tag: ClassTag[_]): Boolean = { + cls.isAssignableFrom(tag.runtimeClass) + } + + def unsupportedCollectionType(cls: Class[_]): Nothing = { + throw new RuntimeException(s"Unsupported collection type: $cls") + } +} + +private[arrow] object StructVectors { + def unapply(v: AnyRef): Option[(StructVector, Seq[FieldVector])] = v match { + case root: VectorSchemaRoot => Option((null, root.getFieldVectors.asScala.toSeq)) + case struct: StructVector => Option((struct, struct.getChildrenFromFields.asScala.toSeq)) + case _ => None + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala new file mode 100644 index 00000000000..f97d3d27c17 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.arrow + +import java.io.{ByteArrayOutputStream, OutputStream} +import java.lang.invoke.{MethodHandles, MethodType} +import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger} +import java.nio.channels.Channels +import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} +import java.util.{Map => JMap, Objects} + +import scala.collection.JavaConverters._ + +import com.google.protobuf.ByteString +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel} +import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer} +import org.apache.arrow.vector.util.Text + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.DefinedByConstructorParams +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ +import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} +import org.apache.spark.sql.errors.ExecutionErrors +import org.apache.spark.sql.types.Decimal + +private[arrow] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable { self => + def asJava: java.util.Iterator[E] = new java.util.Iterator[E] with AutoCloseable { + override def next() = self.next() + + override def hasNext() = self.hasNext + + override def close() = self.close() + } + + override def map[B](f: E => B): CloseableIterator[B] = { + new CloseableIterator[B] { + override def next(): B = f(self.next()) + + override def hasNext: Boolean = self.hasNext + + override def close(): Unit = self.close() + } + } +} + +/** + * Helper class for converting user objects into arrow batches. + */ +class ArrowSerializer[T]( + private[this] val enc: AgnosticEncoder[T], + private[this] val allocator: BufferAllocator, + private[this] val timeZoneId: String) { + private val (root, serializer) = ArrowSerializer.serializerFor(enc, allocator, timeZoneId) + private val vectors = root.getFieldVectors.asScala + private val unloader = new VectorUnloader(root) + private val schemaBytes = { + // Only serialize the schema once. + val bytes = new ByteArrayOutputStream() + MessageSerializer.serialize(newChannel(bytes), root.getSchema) + bytes.toByteArray + } + private var rowCount: Int = 0 + private var closed: Boolean = false + + private def newChannel(output: OutputStream): WriteChannel = { + new WriteChannel(Channels.newChannel(output)) + } + + /** + * The size of the current batch. + * + * The size computed consist of the size of the schema and the size of the arrow buffers. The + * actual batch will be larger than that because of alignment, written IPC tokens, and the + * written record batch metadata. The size of the record batch metadata is proportional to the + * complexity of the schema. + */ + def sizeInBytes: Long = { + // We need to set the row count for getBufferSize to return the actual value. + root.setRowCount(rowCount) + schemaBytes.length + vectors.map(_.getBufferSize).sum + } + + /** + * Append a record to the current batch. 
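+   *
+   * Descriptive note: the record is written at the current row index and the internal row
+   * count is advanced; no size check or flush happens here, so callers control batch
+   * boundaries via [[sizeInBytes]], [[writeBatch]], and [[reset]].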
+ */ + def append(record: T): Unit = { + serializer.write(rowCount, record) + rowCount += 1 + } + + /** + * Write the schema and the current batch in Arrow IPC stream format to the [[OutputStream]]. + */ + def writeIpcStream(output: OutputStream): Unit = { + val channel = newChannel(output) + root.setRowCount(rowCount) + val batch = unloader.getRecordBatch + try { + channel.write(schemaBytes) + MessageSerializer.serialize(channel, batch) + ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT) + } finally { + batch.close() + } + } + + /** + * Write the schema in Arrow IPC stream format to the [[OutputStream]]. + */ + def writeSchema(output: OutputStream): Unit = { + val channel = newChannel(output) + channel.write(schemaBytes) + } + + /** + * Write current batch in Arrow IPC stream format to the [[OutputStream]]. + */ + def writeBatch(output: OutputStream): Unit = { + val channel = newChannel(output) + root.setRowCount(rowCount) + val batch = unloader.getRecordBatch + try { + MessageSerializer.serialize(channel, batch) + } finally { + batch.close() + } + } + + /** + * Write the end-of-stream bytes to [[OutputStream]]. + */ + def writeEndOfStream(output: OutputStream): Unit = { + val channel = newChannel(output) + ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT) + } + + /** + * Reset the serializer. + */ + def reset(): Unit = { + rowCount = 0 + vectors.foreach(_.reset()) + } + + /** + * Close the serializer. + */ + def close(): Unit = { + root.close() + closed = true + } + + /** + * Check if the serializer has been closed. + * + * It is illegal to used the serializer after it has been closed. It will lead to errors and + * sorts of undefined behavior. + */ + def isClosed: Boolean = closed +} + +object ArrowSerializer { + import ArrowEncoderUtils._ + + /** + * Create an [[Iterator]] that converts the input [[Iterator]] of type `T` into an [[Iterator]] + * of Arrow IPC Streams. + */ + def serialize[T]( + input: Iterator[T], + enc: AgnosticEncoder[T], + allocator: BufferAllocator, + maxRecordsPerBatch: Int, + maxBatchSize: Long, + timeZoneId: String, + batchSizeCheckInterval: Int = 128): CloseableIterator[Array[Byte]] = { + assert(maxRecordsPerBatch > 0) + assert(maxBatchSize > 0) + assert(batchSizeCheckInterval > 0) + new CloseableIterator[Array[Byte]] { + private val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId) + private val bytes = new ByteArrayOutputStream + private var hasWrittenFirstBatch = false + + /** + * Periodical check to make sure we don't go over the size threshold by too much. 
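+       *
+       * Descriptive note: the check only runs every `batchSizeCheckInterval` records
+       * because [[ArrowSerializer.sizeInBytes]] sums the buffer sizes of every vector and
+       * is comparatively expensive to compute per record.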
+ */ + private def sizeOk(i: Int): Boolean = { + if (i > 0 && i % batchSizeCheckInterval == 0) { + return serializer.sizeInBytes < maxBatchSize + } + true + } + + override def hasNext: Boolean = { + (input.hasNext || !hasWrittenFirstBatch) && !serializer.isClosed + } + + override def next(): Array[Byte] = { + if (!hasNext) { + throw new NoSuchElementException() + } + serializer.reset() + bytes.reset() + var i = 0 + while (i < maxRecordsPerBatch && input.hasNext && sizeOk(i)) { + serializer.append(input.next()) + i += 1 + } + serializer.writeIpcStream(bytes) + hasWrittenFirstBatch = true + bytes.toByteArray + } + + override def close(): Unit = serializer.close() + } + } + + def serialize[T]( + input: Iterator[T], + enc: AgnosticEncoder[T], + allocator: BufferAllocator, + timeZoneId: String): ByteString = { + val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId) + try { + input.foreach(serializer.append) + val output = ByteString.newOutput() + serializer.writeIpcStream(output) + output.toByteString + } finally { + serializer.close() + } + } + + /** + * Create a (root) [[Serializer]] for [[AgnosticEncoder]] `encoder`. + * + * The serializer returned by this method is NOT thread-safe. + */ + def serializerFor[T]( + encoder: AgnosticEncoder[T], + allocator: BufferAllocator, + timeZoneId: String): (VectorSchemaRoot, Serializer) = { + val arrowSchema = + ArrowUtils.toArrowSchema(encoder.schema, timeZoneId, errorOnDuplicatedFieldNames = true) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val serializer = if (encoder.schema != encoder.dataType) { + assert(root.getSchema.getFields.size() == 1) + serializerFor(encoder, root.getVector(0)) + } else { + serializerFor(encoder, root) + } + root -> serializer + } + + // TODO throw better errors on class cast exceptions. 
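+  // Descriptive note: the match below pairs each AgnosticEncoder with the Arrow
+  // FieldVector created for its Spark type and returns a Serializer that writes one
+  // value at a given row index; composite encoders (options, arrays, maps, products,
+  // rows, beans) recurse into serializers for their child encoders and vectors.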
+ private[arrow] def serializerFor[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = { + (encoder, v) match { + case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) => + new FieldSerializer[Boolean, BitVector](v) { + override def set(index: Int, value: Boolean): Unit = + vector.setSafe(index, if (value) 1 else 0) + } + case (PrimitiveByteEncoder | BoxedByteEncoder, v: TinyIntVector) => + new FieldSerializer[Byte, TinyIntVector](v) { + override def set(index: Int, value: Byte): Unit = vector.setSafe(index, value) + } + case (PrimitiveShortEncoder | BoxedShortEncoder, v: SmallIntVector) => + new FieldSerializer[Short, SmallIntVector](v) { + override def set(index: Int, value: Short): Unit = vector.setSafe(index, value) + } + case (PrimitiveIntEncoder | BoxedIntEncoder, v: IntVector) => + new FieldSerializer[Int, IntVector](v) { + override def set(index: Int, value: Int): Unit = vector.setSafe(index, value) + } + case (PrimitiveLongEncoder | BoxedLongEncoder, v: BigIntVector) => + new FieldSerializer[Long, BigIntVector](v) { + override def set(index: Int, value: Long): Unit = vector.setSafe(index, value) + } + case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: Float4Vector) => + new FieldSerializer[Float, Float4Vector](v) { + override def set(index: Int, value: Float): Unit = vector.setSafe(index, value) + } + case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: Float8Vector) => + new FieldSerializer[Double, Float8Vector](v) { + override def set(index: Int, value: Double): Unit = vector.setSafe(index, value) + } + case (NullEncoder, v: NullVector) => + new FieldSerializer[Unit, NullVector](v) { + override def set(index: Int, value: Unit): Unit = vector.setNull(index) + } + case (StringEncoder, v: VarCharVector) => + new FieldSerializer[String, VarCharVector](v) { + override def set(index: Int, value: String): Unit = setString(v, index, value) + } + case (JavaEnumEncoder(_), v: VarCharVector) => + new FieldSerializer[Enum[_], VarCharVector](v) { + override def set(index: Int, value: Enum[_]): Unit = setString(v, index, value.name()) + } + case (ScalaEnumEncoder(_, _), v: VarCharVector) => + new FieldSerializer[Enumeration#Value, VarCharVector](v) { + override def set(index: Int, value: Enumeration#Value): Unit = + setString(v, index, value.toString) + } + case (BinaryEncoder, v: VarBinaryVector) => + new FieldSerializer[Array[Byte], VarBinaryVector](v) { + override def set(index: Int, value: Array[Byte]): Unit = vector.setSafe(index, value) + } + case (SparkDecimalEncoder(_), v: DecimalVector) => + new FieldSerializer[Decimal, DecimalVector](v) { + override def set(index: Int, value: Decimal): Unit = + setDecimal(vector, index, value.toJavaBigDecimal) + } + case (ScalaDecimalEncoder(_), v: DecimalVector) => + new FieldSerializer[BigDecimal, DecimalVector](v) { + override def set(index: Int, value: BigDecimal): Unit = + setDecimal(vector, index, value.bigDecimal) + } + case (JavaDecimalEncoder(_, false), v: DecimalVector) => + new FieldSerializer[JBigDecimal, DecimalVector](v) { + override def set(index: Int, value: JBigDecimal): Unit = + setDecimal(vector, index, value) + } + case (JavaDecimalEncoder(_, true), v: DecimalVector) => + new FieldSerializer[Any, DecimalVector](v) { + override def set(index: Int, value: Any): Unit = { + val decimal = value match { + case j: JBigDecimal => j + case d: BigDecimal => d.bigDecimal + case k: BigInt => new JBigDecimal(k.bigInteger) + case l: JBigInteger => new JBigDecimal(l) + case d: Decimal => d.toJavaBigDecimal + } + 
setDecimal(vector, index, decimal) + } + } + case (ScalaBigIntEncoder, v: DecimalVector) => + new FieldSerializer[BigInt, DecimalVector](v) { + override def set(index: Int, value: BigInt): Unit = + setDecimal(vector, index, new JBigDecimal(value.bigInteger)) + } + case (JavaBigIntEncoder, v: DecimalVector) => + new FieldSerializer[JBigInteger, DecimalVector](v) { + override def set(index: Int, value: JBigInteger): Unit = + setDecimal(vector, index, new JBigDecimal(value)) + } + case (DayTimeIntervalEncoder, v: DurationVector) => + new FieldSerializer[Duration, DurationVector](v) { + override def set(index: Int, value: Duration): Unit = + vector.setSafe(index, SparkIntervalUtils.durationToMicros(value)) + } + case (YearMonthIntervalEncoder, v: IntervalYearVector) => + new FieldSerializer[Period, IntervalYearVector](v) { + override def set(index: Int, value: Period): Unit = + vector.setSafe(index, SparkIntervalUtils.periodToMonths(value)) + } + case (DateEncoder(true) | LocalDateEncoder(true), v: DateDayVector) => + new FieldSerializer[Any, DateDayVector](v) { + override def set(index: Int, value: Any): Unit = + vector.setSafe(index, SparkDateTimeUtils.anyToDays(value)) + } + case (DateEncoder(false), v: DateDayVector) => + new FieldSerializer[java.sql.Date, DateDayVector](v) { + override def set(index: Int, value: java.sql.Date): Unit = + vector.setSafe(index, SparkDateTimeUtils.fromJavaDate(value)) + } + case (LocalDateEncoder(false), v: DateDayVector) => + new FieldSerializer[LocalDate, DateDayVector](v) { + override def set(index: Int, value: LocalDate): Unit = + vector.setSafe(index, SparkDateTimeUtils.localDateToDays(value)) + } + case (TimestampEncoder(true) | InstantEncoder(true), v: TimeStampMicroTZVector) => + new FieldSerializer[Any, TimeStampMicroTZVector](v) { + override def set(index: Int, value: Any): Unit = + vector.setSafe(index, SparkDateTimeUtils.anyToMicros(value)) + } + case (TimestampEncoder(false), v: TimeStampMicroTZVector) => + new FieldSerializer[java.sql.Timestamp, TimeStampMicroTZVector](v) { + override def set(index: Int, value: java.sql.Timestamp): Unit = + vector.setSafe(index, SparkDateTimeUtils.fromJavaTimestamp(value)) + } + case (InstantEncoder(false), v: TimeStampMicroTZVector) => + new FieldSerializer[Instant, TimeStampMicroTZVector](v) { + override def set(index: Int, value: Instant): Unit = + vector.setSafe(index, SparkDateTimeUtils.instantToMicros(value)) + } + case (LocalDateTimeEncoder, v: TimeStampMicroVector) => + new FieldSerializer[LocalDateTime, TimeStampMicroVector](v) { + override def set(index: Int, value: LocalDateTime): Unit = + vector.setSafe(index, SparkDateTimeUtils.localDateTimeToMicros(value)) + } + + case (OptionEncoder(value), v) => + new Serializer { + private[this] val delegate: Serializer = serializerFor(value, v) + override def write(index: Int, value: Any): Unit = value match { + case Some(value) => delegate.write(index, value) + case _ => delegate.write(index, null) + } + } + + case (ArrayEncoder(element, _), v: ListVector) => + val elementSerializer = serializerFor(element, v.getDataVector) + val toIterator = { array: Any => + array.asInstanceOf[Array[_]].iterator + } + new ArraySerializer(v, toIterator, elementSerializer) + + case (IterableEncoder(tag, element, _, lenient), v: ListVector) => + val elementSerializer = serializerFor(element, v.getDataVector) + val toIterator: Any => Iterator[_] = if (lenient) { + { + case i: scala.collection.Iterable[_] => i.iterator + case l: java.util.List[_] => l.iterator().asScala + case a: 
Array[_] => a.iterator + case o => unsupportedCollectionType(o.getClass) + } + } else if (isSubClass(Classes.ITERABLE, tag)) { v => + v.asInstanceOf[scala.collection.Iterable[_]].iterator + } else if (isSubClass(Classes.JLIST, tag)) { v => + v.asInstanceOf[java.util.List[_]].iterator().asScala + } else { + unsupportedCollectionType(tag.runtimeClass) + } + new ArraySerializer(v, toIterator, elementSerializer) + + case (MapEncoder(tag, key, value, _), v: MapVector) => + val structVector = v.getDataVector.asInstanceOf[StructVector] + val extractor = if (isSubClass(classOf[scala.collection.Map[_, _]], tag)) { (v: Any) => + v.asInstanceOf[scala.collection.Map[_, _]].iterator + } else if (isSubClass(classOf[JMap[_, _]], tag)) { (v: Any) => + v.asInstanceOf[JMap[Any, Any]].asScala.iterator + } else { + unsupportedCollectionType(tag.runtimeClass) + } + val structSerializer = new StructSerializer( + structVector, + new StructFieldSerializer( + extractKey, + serializerFor(key, structVector.getChild(MapVector.KEY_NAME))) :: + new StructFieldSerializer( + extractValue, + serializerFor(value, structVector.getChild(MapVector.VALUE_NAME))) :: Nil) + new ArraySerializer(v, extractor, structSerializer) + + case (ProductEncoder(tag, fields, _), StructVectors(struct, vectors)) => + if (isSubClass(classOf[Product], tag)) { + structSerializerFor(fields, struct, vectors) { (_, i) => p => + p.asInstanceOf[Product].productElement(i) + } + } else if (isSubClass(classOf[DefinedByConstructorParams], tag)) { + structSerializerFor(fields, struct, vectors) { (field, _) => + val getter = methodLookup.findVirtual( + tag.runtimeClass, + field.name, + MethodType.methodType(field.enc.clsTag.runtimeClass)) + o => getter.invoke(o) + } + } else { + unsupportedCollectionType(tag.runtimeClass) + } + + case (RowEncoder(fields), StructVectors(struct, vectors)) => + structSerializerFor(fields, struct, vectors) { (_, i) => r => r.asInstanceOf[Row].get(i) } + + case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) => + structSerializerFor(fields, struct, vectors) { (field, _) => + val getter = methodLookup.findVirtual( + tag.runtimeClass, + field.readMethod.get, + MethodType.methodType(field.enc.clsTag.runtimeClass)) + o => getter.invoke(o) + } + + case (CalendarIntervalEncoder | _: UDTEncoder[_], _) => + // throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType) + throw new RuntimeException("Unsupported data type") + + case _ => + throw new RuntimeException(s"Unsupported Encoder($encoder)/Vector($v) combination.") + } + } + + private val methodLookup = MethodHandles.lookup() + + private def setString(vector: VarCharVector, index: Int, string: String): Unit = { + val bytes = Text.encode(string) + vector.setSafe(index, bytes, 0, bytes.limit()) + } + + private def setDecimal(vector: DecimalVector, index: Int, decimal: JBigDecimal): Unit = { + val scaledDecimal = if (vector.getScale != decimal.scale()) { + decimal.setScale(vector.getScale) + } else { + decimal + } + vector.setSafe(index, scaledDecimal) + } + + private def extractKey(v: Any): Any = { + val key = v.asInstanceOf[(Any, Any)]._1 + Objects.requireNonNull(key) + key + } + + private def extractValue(v: Any): Any = { + v.asInstanceOf[(Any, Any)]._2 + } + + private def structSerializerFor( + fields: Seq[EncoderField], + struct: StructVector, + vectors: Seq[FieldVector])( + createGetter: (EncoderField, Int) => Any => Any): StructSerializer = { + require(fields.size == vectors.size) + val serializers = fields.zip(vectors).zipWithIndex.map { case ((field, 
vector), i) => + val serializer = serializerFor(field.enc, vector) + new StructFieldSerializer(createGetter(field, i), serializer) + } + new StructSerializer(struct, serializers) + } + + abstract class Serializer { + def write(index: Int, value: Any): Unit + } + + private abstract class FieldSerializer[E, V <: FieldVector](val vector: V) extends Serializer { + def set(index: Int, value: E): Unit + + override def write(index: Int, raw: Any): Unit = { + val value = raw.asInstanceOf[E] + if (value != null) { + set(index, value) + } else { + vector.setNull(index) + } + } + } + + private class ArraySerializer( + v: ListVector, + toIterator: Any => Iterator[Any], + elementSerializer: Serializer) + extends FieldSerializer[Any, ListVector](v) { + override def set(index: Int, value: Any): Unit = { + val elementStartIndex = vector.startNewValue(index) + var elementIndex = elementStartIndex + val iterator = toIterator(value) + while (iterator.hasNext) { + elementSerializer.write(elementIndex, iterator.next()) + elementIndex += 1 + } + vector.endValue(index, elementIndex - elementStartIndex) + } + } + + private class StructFieldSerializer(val extractor: Any => Any, val serializer: Serializer) { + def write(index: Int, value: Any): Unit = serializer.write(index, extractor(value)) + def writeNull(index: Int): Unit = serializer.write(index, null) + } + + private class StructSerializer( + struct: StructVector, + fieldSerializers: Seq[StructFieldSerializer]) + extends Serializer { + + override def write(index: Int, value: Any): Unit = { + if (value == null) { + if (struct != null) { + struct.setNull(index) + } + fieldSerializers.foreach(_.writeNull(index)) + } else { + if (struct != null) { + struct.setIndexDefined(index) + } + fieldSerializers.foreach(_.write(index, value)) + } + } + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowTable.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowTable.scala new file mode 100644 index 00000000000..49421a5d313 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowTable.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.hadoop.fs.FileStatus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import scala.jdk.CollectionConverters._ +import org.apache.sedona.sql.datasources.arrow.ArrowWriteBuilder + +case class ArrowTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + null + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + new ArrowWriteBuilder(info) + } + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + None + } + + override def formatName: String = { + "Arrow Stream" + } + + override def fallbackFileFormat: Class[_ <: FileFormat] = { + null + } + +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala new file mode 100644 index 00000000000..60006412873 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.arrow + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.JavaConverters._ + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.complex.MapVector +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} + +import org.apache.spark.sql.errors.ExecutionErrors +import org.apache.spark.sql.types._ + +private[arrow] object ArrowUtils { + + val rootAllocator = new RootAllocator(Long.MaxValue) + + // todo: support more types. + + /** Maps data type from Spark to Arrow. 
NOTE: timeZoneId required for TimestampTypes */ + def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = + dt match { + case BooleanType => ArrowType.Bool.INSTANCE + case ByteType => new ArrowType.Int(8, true) + case ShortType => new ArrowType.Int(8 * 2, true) + case IntegerType => new ArrowType.Int(8 * 4, true) + case LongType => new ArrowType.Int(8 * 8, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE + case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE + case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE + case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE + // TODO(paleolimbot): DecimalType.Fixed is marked private + // case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType if timeZoneId == null => + throw new IllegalStateException("Missing timezoneId where it is mandatory.") + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + case TimestampNTZType => + new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case NullType => ArrowType.Null.INSTANCE + case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) + case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) + case _ => + // throw ExecutionErrors.unsupportedDataTypeError(dt) + throw new RuntimeException("Unsupported data type") + } + + def fromArrowType(dt: ArrowType): DataType = dt match { + case ArrowType.Bool.INSTANCE => BooleanType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.SINGLE => + FloatType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.DOUBLE => + DoubleType + case ArrowType.Utf8.INSTANCE => StringType + case ArrowType.Binary.INSTANCE => BinaryType + case ArrowType.LargeUtf8.INSTANCE => StringType + case ArrowType.LargeBinary.INSTANCE => BinaryType + case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType + case ts: ArrowType.Timestamp + if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => + TimestampNTZType + case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType + case ArrowType.Null.INSTANCE => NullType + case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => + YearMonthIntervalType() + case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType() + // case _ => throw ExecutionErrors.unsupportedArrowTypeError(dt) + case _ => throw new RuntimeException("Unsupported arrow data type") + } + + /** Maps field from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType */ + def toArrowField( + name: String, + dt: DataType, + nullable: Boolean, + timeZoneId: String, + largeVarTypes: Boolean = false): Field = { + dt match { + case ArrayType(elementType, containsNull) => + val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) + new Field( + name, + fieldType, + Seq( + toArrowField("element", elementType, containsNull, timeZoneId, largeVarTypes)).asJava) + case StructType(fields) => + val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) + new Field( + name, + fieldType, + fields + .map { field => + toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes) + } + .toSeq + .asJava) + case MapType(keyType, valueType, valueContainsNull) => + val mapType = new FieldType(nullable, new ArrowType.Map(false), null) + // Note: Map Type struct can not be null, Struct Type key field can not be null + new Field( + name, + mapType, + Seq( + toArrowField( + MapVector.DATA_VECTOR_NAME, + new StructType() + .add(MapVector.KEY_NAME, keyType, nullable = false) + .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), + nullable = false, + timeZoneId, + largeVarTypes)).asJava) + case udt: UserDefinedType[_] => + toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes) + case dataType => + val fieldType = + new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes), null) + new Field(name, fieldType, Seq.empty[Field].asJava) + } + } + + def fromArrowField(field: Field): DataType = { + field.getType match { + case _: ArrowType.Map => + val elementField = field.getChildren.get(0) + val keyType = fromArrowField(elementField.getChildren.get(0)) + val valueType = fromArrowField(elementField.getChildren.get(1)) + MapType(keyType, valueType, elementField.getChildren.get(1).isNullable) + case ArrowType.List.INSTANCE => + val elementField = field.getChildren().get(0) + val elementType = fromArrowField(elementField) + ArrayType(elementType, containsNull = elementField.isNullable) + case ArrowType.Struct.INSTANCE => + val fields = field.getChildren().asScala.map { child => + val dt = fromArrowField(child) + StructField(child.getName, dt, child.isNullable) + } + StructType(fields.toArray) + case arrowType => fromArrowType(arrowType) + } + } + + /** + * Maps schema from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType in StructType + */ + def toArrowSchema( + schema: StructType, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean = false): Schema = { + new Schema(schema.map { field => + toArrowField( + field.name, + deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames), + field.nullable, + timeZoneId, + largeVarTypes) + }.asJava) + } + + def fromArrowSchema(schema: Schema): StructType = { + StructType(schema.getFields.asScala.map { field => + val dt = fromArrowField(field) + StructField(field.getName, dt, field.isNullable) + }.toArray) + } + + private def deduplicateFieldNames( + dt: DataType, + errorOnDuplicatedFieldNames: Boolean): DataType = dt match { + case udt: UserDefinedType[_] => + deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames) + case st @ StructType(fields) => + val newNames = if (st.names.toSet.size == st.names.length) { + st.names + } else { + if (errorOnDuplicatedFieldNames) { + // throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names) + throw new RuntimeException("duplicated field name in arrow struct") + } + val genNawName = st.names.groupBy(identity).map { + case (name, names) if names.length > 1 => + val i = new AtomicInteger() + name -> { () => s"${name}_${i.getAndIncrement()}" } + case (name, _) => name -> { () => name } + } + st.names.map(genNawName(_)()) + } + val newFields = + fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) => + StructField( + name, + deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), + nullable, + metadata) + } + StructType(newFields) + case ArrayType(elementType, containsNull) => + ArrayType(deduplicateFieldNames(elementType, errorOnDuplicatedFieldNames), containsNull) + case MapType(keyType, valueType, valueContainsNull) => + MapType( + deduplicateFieldNames(keyType, errorOnDuplicatedFieldNames), + deduplicateFieldNames(valueType, errorOnDuplicatedFieldNames), + valueContainsNull) + case _ => dt + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.scala new file mode 100644 index 00000000000..5efee38920c --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.arrow + +import org.apache.spark.sql.connector.write.WriteBuilder +import org.apache.spark.sql.connector.write.Write +import org.apache.spark.sql.connector.write.BatchWrite +import org.apache.spark.sql.connector.write.LogicalWriteInfo + +case class ArrowWriteBuilder(info: LogicalWriteInfo) extends WriteBuilder { + + override def buildForBatch(): BatchWrite = { + new ArrowBatchWrite(info) + } + +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala new file mode 100644 index 00000000000..ee1d84af1e2 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.arrow + +import org.apache.spark.sql.connector.write.LogicalWriteInfo +import org.apache.spark.sql.connector.write.PhysicalWriteInfo +import org.apache.spark.sql.connector.write.DataWriter +import org.apache.arrow.memory.RootAllocator; +import org.apache.spark.sql.connector.write.WriterCommitMessage +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.Row; +import java.io.ByteArrayOutputStream +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder + +case class ArrowWriter() extends DataWriter[InternalRow] { + + private var encoder: AgnosticEncoder[Row] = _ + private var rowDeserializer: ExpressionEncoder.Deserializer[Row] = _ + private var serializer: ArrowSerializer[Row] = _ + private var dummyOutput: ByteArrayOutputStream = _ + private var rowCount: Long = 0 + + def this( + logicalInfo: LogicalWriteInfo, + physicalInfo: PhysicalWriteInfo, + partitionId: Int, + taskId: Long) { + this() + dummyOutput = new ByteArrayOutputStream() + encoder = RowEncoder.encoderFor(logicalInfo.schema()) + rowDeserializer = Encoders + .row(logicalInfo.schema()) + .asInstanceOf[ExpressionEncoder[Row]] + .resolveAndBind() + .createDeserializer() + serializer = new ArrowSerializer[Row](encoder, new RootAllocator(), "UTC"); + + serializer.writeSchema(dummyOutput) + } + + def write(record: InternalRow): Unit = { + serializer.append(rowDeserializer.apply(record)) + rowCount = rowCount + 1 + if (shouldFlush()) { + flush() + } + } + + private def shouldFlush(): Boolean = { + // Can use serializer.sizeInBytes() to parameterize batch size in terms of bytes + // or just use rowCount (batches of ~1024 rows are common and 16 MB is also a common 
+ // threshold (maybe also applying a minimum row count in case we're dealing with big + // geometries). Checking sizeInBytes() should be done sparingly (expensive)). + rowCount >= 1024 + } + + private def flush(): Unit = { + serializer.writeBatch(dummyOutput) + } + + def commit(): WriterCommitMessage = { + null + } + + def abort(): Unit = {} + + def close(): Unit = { + flush() + serializer.writeEndOfStream(dummyOutput) + serializer.close() + dummyOutput.close() + } + +}
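
Usage sketch (not part of the patch): ArrowDataSource registers the short name "arrows" and this revision only wires up the write path (newScanBuilder returns null, and ArrowWriter currently serializes batches into an in-memory ByteArrayOutputStream rather than the target path). Assuming a local SparkSession with the Sedona spark-3.5 jar on the classpath, exercising the writer might look roughly like the following; the object name, DataFrame contents, and output path are placeholders.

import org.apache.spark.sql.SparkSession

object ArrowWriteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("arrow-stream-write-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Placeholder data; any schema covered by ArrowSerializer/ArrowUtils should work.
    val df = Seq((1L, "POINT (0 0)"), (2L, "POINT (1 1)")).toDF("id", "wkt")

    // "arrows" is the short name declared by ArrowDataSource. The path is a placeholder:
    // in this revision ArrowWriter still writes to its internal buffer, so no files are
    // expected at this location yet.
    df.write.format("arrows").save("/tmp/arrow-stream-out")

    spark.stop()
  }
}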