diff --git a/spark/spark-3.5/pom.xml b/spark/spark-3.5/pom.xml
index 59a8cf7c82d..d1a616944a1 100644
--- a/spark/spark-3.5/pom.xml
+++ b/spark/spark-3.5/pom.xml
@@ -48,6 +48,12 @@
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-java-root</artifactId>
+      <version>18.2.0</version>
+      <type>pom</type>
+    </dependency>
     <dependency>
       <groupId>org.apache.sedona</groupId>
       <artifactId>sedona-spark-common-${spark.compat.version}_${scala.compat.version}</artifactId>
diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala
new file mode 100644
index 00000000000..750e99ee669
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.datasources.arrow
+
+import org.apache.spark.sql.connector.write.BatchWrite
+import org.apache.spark.sql.connector.write.{DataWriterFactory, PhysicalWriteInfo}
+import org.apache.spark.sql.connector.write.WriterCommitMessage
+import org.apache.spark.sql.connector.write.LogicalWriteInfo
+import org.apache.spark.sql.catalyst.InternalRow
+
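+// DataSource V2 write path in this PR: ArrowDataSource resolves an ArrowTable, whose
+// newWriteBuilder returns an ArrowWriteBuilder; buildForBatch produces this ArrowBatchWrite,
+// which hands out ArrowDataWriterFactory instances that create one ArrowWriter per task.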
+case class ArrowBatchWrite(logicalInfo: LogicalWriteInfo) extends BatchWrite {
+ def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
+ new ArrowDataWriterFactory(logicalInfo, info)
+ }
+
+ def commit(messages: Array[WriterCommitMessage]): Unit = {}
+
+ def abort(messages: Array[WriterCommitMessage]): Unit = {}
+
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataSource.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataSource.scala
new file mode 100644
index 00000000000..20412eec615
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataSource.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.datasources.arrow
+
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import java.util.Locale
+import scala.jdk.CollectionConverters._
+import scala.util.Try
+import org.apache.sedona.sql.datasources.arrow.ArrowTable
+
+class ArrowDataSource extends FileDataSourceV2 with DataSourceRegister {
+
+ override def fallbackFileFormat: Class[_ <: FileFormat] = {
+ null
+ }
+
+ override protected def getTable(options: CaseInsensitiveStringMap): Table = {
+ ArrowTable("", sparkSession, options, getPaths(options), None)
+ }
+
+ override def shortName(): String = "arrows"
+}
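+
+// A hedged usage sketch (illustrative only): once the write path is wired to real output,
+// the short name registered above would be used as
+//   df.write.format("arrows").save("/tmp/arrow-output")
+// The read path (fallbackFileFormat and the scan builder) is not implemented yet.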
diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.scala
new file mode 100644
index 00000000000..ad8c631709c
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.datasources.arrow
+
+import org.apache.spark.sql.connector.write.LogicalWriteInfo
+import org.apache.spark.sql.connector.write.PhysicalWriteInfo
+import org.apache.spark.sql.connector.write.DataWriterFactory
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.write.DataWriter
+
+case class ArrowDataWriterFactory(logicalInfo: LogicalWriteInfo, physicalInfo: PhysicalWriteInfo)
+ extends DataWriterFactory {
+ override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
+ new ArrowWriter(logicalInfo, physicalInfo, partitionId, taskId)
+ }
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala
new file mode 100644
index 00000000000..aaedcab8f12
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.datasources.arrow
+
+import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
+
+import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.vector.complex.StructVector
+
+private[arrow] object ArrowEncoderUtils {
+ object Classes {
+ val WRAPPED_ARRAY: Class[_] = classOf[scala.collection.mutable.WrappedArray[_]]
+ val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]]
+ val MAP: Class[_] = classOf[scala.collection.Map[_, _]]
+ val JLIST: Class[_] = classOf[java.util.List[_]]
+ val JMAP: Class[_] = classOf[java.util.Map[_, _]]
+ }
+
+ def isSubClass(cls: Class[_], tag: ClassTag[_]): Boolean = {
+ cls.isAssignableFrom(tag.runtimeClass)
+ }
+
+ def unsupportedCollectionType(cls: Class[_]): Nothing = {
+ throw new RuntimeException(s"Unsupported collection type: $cls")
+ }
+}
+
+private[arrow] object StructVectors {
+ def unapply(v: AnyRef): Option[(StructVector, Seq[FieldVector])] = v match {
+ case root: VectorSchemaRoot => Option((null, root.getFieldVectors.asScala.toSeq))
+ case struct: StructVector => Option((struct, struct.getChildrenFromFields.asScala.toSeq))
+ case _ => None
+ }
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala
new file mode 100644
index 00000000000..f97d3d27c17
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.datasources.arrow
+
+import java.io.{ByteArrayOutputStream, OutputStream}
+import java.lang.invoke.{MethodHandles, MethodType}
+import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger}
+import java.nio.channels.Channels
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
+import java.util.{Map => JMap, Objects}
+
+import scala.collection.JavaConverters._
+
+import com.google.protobuf.ByteString
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector, VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
+import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel}
+import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
+import org.apache.arrow.vector.util.Text
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.DefinedByConstructorParams
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
+import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
+import org.apache.spark.sql.errors.ExecutionErrors
+import org.apache.spark.sql.types.Decimal
+
+private[arrow] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable { self =>
+ def asJava: java.util.Iterator[E] = new java.util.Iterator[E] with AutoCloseable {
+ override def next() = self.next()
+
+ override def hasNext() = self.hasNext
+
+ override def close() = self.close()
+ }
+
+ override def map[B](f: E => B): CloseableIterator[B] = {
+ new CloseableIterator[B] {
+ override def next(): B = f(self.next())
+
+ override def hasNext: Boolean = self.hasNext
+
+ override def close(): Unit = self.close()
+ }
+ }
+}
+
+/**
+ * Helper class for converting user objects into arrow batches.
+ */
+class ArrowSerializer[T](
+ private[this] val enc: AgnosticEncoder[T],
+ private[this] val allocator: BufferAllocator,
+ private[this] val timeZoneId: String) {
+ private val (root, serializer) = ArrowSerializer.serializerFor(enc, allocator, timeZoneId)
+ private val vectors = root.getFieldVectors.asScala
+ private val unloader = new VectorUnloader(root)
+ private val schemaBytes = {
+ // Only serialize the schema once.
+ val bytes = new ByteArrayOutputStream()
+ MessageSerializer.serialize(newChannel(bytes), root.getSchema)
+ bytes.toByteArray
+ }
+ private var rowCount: Int = 0
+ private var closed: Boolean = false
+
+ private def newChannel(output: OutputStream): WriteChannel = {
+ new WriteChannel(Channels.newChannel(output))
+ }
+
+ /**
+ * The size of the current batch.
+ *
+ * The computed size consists of the size of the schema and the size of the Arrow buffers. The
+ * actual batch will be larger than that because of alignment, written IPC tokens, and the
+ * written record batch metadata. The size of the record batch metadata is proportional to the
+ * complexity of the schema.
+ */
+ def sizeInBytes: Long = {
+ // We need to set the row count for getBufferSize to return the actual value.
+ root.setRowCount(rowCount)
+ schemaBytes.length + vectors.map(_.getBufferSize).sum
+ }
+
+ /**
+ * Append a record to the current batch.
+ */
+ def append(record: T): Unit = {
+ serializer.write(rowCount, record)
+ rowCount += 1
+ }
+
+ /**
+ * Write the schema and the current batch in Arrow IPC stream format to the [[OutputStream]].
+ */
+ def writeIpcStream(output: OutputStream): Unit = {
+ val channel = newChannel(output)
+ root.setRowCount(rowCount)
+ val batch = unloader.getRecordBatch
+ try {
+ channel.write(schemaBytes)
+ MessageSerializer.serialize(channel, batch)
+ ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT)
+ } finally {
+ batch.close()
+ }
+ }
+
+ /**
+ * Write the schema in Arrow IPC stream format to the [[OutputStream]].
+ */
+ def writeSchema(output: OutputStream): Unit = {
+ val channel = newChannel(output)
+ channel.write(schemaBytes)
+ }
+
+ /**
+ * Write current batch in Arrow IPC stream format to the [[OutputStream]].
+ */
+ def writeBatch(output: OutputStream): Unit = {
+ val channel = newChannel(output)
+ root.setRowCount(rowCount)
+ val batch = unloader.getRecordBatch
+ try {
+ MessageSerializer.serialize(channel, batch)
+ } finally {
+ batch.close()
+ }
+ }
+
+ /**
+ * Write the end-of-stream bytes to [[OutputStream]].
+ */
+ def writeEndOfStream(output: OutputStream): Unit = {
+ val channel = newChannel(output)
+ ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT)
+ }
+
+ /**
+ * Reset the serializer.
+ */
+ def reset(): Unit = {
+ rowCount = 0
+ vectors.foreach(_.reset())
+ }
+
+ /**
+ * Close the serializer.
+ */
+ def close(): Unit = {
+ root.close()
+ closed = true
+ }
+
+ /**
+ * Check if the serializer has been closed.
+ *
+ * It is illegal to use the serializer after it has been closed. Doing so will lead to errors
+ * and all sorts of undefined behavior.
+ */
+ def isClosed: Boolean = closed
+}
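+
+// A minimal usage sketch of ArrowSerializer (illustrative only, not part of the connector
+// code path; ArrowWriter in this PR goes through RowEncoder/ExpressionEncoder instead):
+//   val serializer = new ArrowSerializer[Long](PrimitiveLongEncoder, new RootAllocator(), "UTC")
+//   val out = new ByteArrayOutputStream()
+//   Seq(1L, 2L, 3L).foreach(serializer.append)
+//   serializer.writeIpcStream(out) // schema + one record batch + end-of-stream marker
+//   serializer.close()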
+
+object ArrowSerializer {
+ import ArrowEncoderUtils._
+
+ /**
+ * Create an [[Iterator]] that converts the input [[Iterator]] of type `T` into an [[Iterator]]
+ * of Arrow IPC Streams.
+ */
+ def serialize[T](
+ input: Iterator[T],
+ enc: AgnosticEncoder[T],
+ allocator: BufferAllocator,
+ maxRecordsPerBatch: Int,
+ maxBatchSize: Long,
+ timeZoneId: String,
+ batchSizeCheckInterval: Int = 128): CloseableIterator[Array[Byte]] = {
+ assert(maxRecordsPerBatch > 0)
+ assert(maxBatchSize > 0)
+ assert(batchSizeCheckInterval > 0)
+ new CloseableIterator[Array[Byte]] {
+ private val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId)
+ private val bytes = new ByteArrayOutputStream
+ private var hasWrittenFirstBatch = false
+
+ /**
+ * Periodic check to make sure we don't go over the size threshold by too much.
+ */
+ private def sizeOk(i: Int): Boolean = {
+ if (i > 0 && i % batchSizeCheckInterval == 0) {
+ return serializer.sizeInBytes < maxBatchSize
+ }
+ true
+ }
+
+ override def hasNext: Boolean = {
+ (input.hasNext || !hasWrittenFirstBatch) && !serializer.isClosed
+ }
+
+ override def next(): Array[Byte] = {
+ if (!hasNext) {
+ throw new NoSuchElementException()
+ }
+ serializer.reset()
+ bytes.reset()
+ var i = 0
+ while (i < maxRecordsPerBatch && input.hasNext && sizeOk(i)) {
+ serializer.append(input.next())
+ i += 1
+ }
+ serializer.writeIpcStream(bytes)
+ hasWrittenFirstBatch = true
+ bytes.toByteArray
+ }
+
+ override def close(): Unit = serializer.close()
+ }
+ }
+
+ def serialize[T](
+ input: Iterator[T],
+ enc: AgnosticEncoder[T],
+ allocator: BufferAllocator,
+ timeZoneId: String): ByteString = {
+ val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId)
+ try {
+ input.foreach(serializer.append)
+ val output = ByteString.newOutput()
+ serializer.writeIpcStream(output)
+ output.toByteString
+ } finally {
+ serializer.close()
+ }
+ }
+
+ /**
+ * Create a (root) [[Serializer]] for [[AgnosticEncoder]] `encoder`.
+ *
+ * The serializer returned by this method is NOT thread-safe.
+ */
+ def serializerFor[T](
+ encoder: AgnosticEncoder[T],
+ allocator: BufferAllocator,
+ timeZoneId: String): (VectorSchemaRoot, Serializer) = {
+ val arrowSchema =
+ ArrowUtils.toArrowSchema(encoder.schema, timeZoneId, errorOnDuplicatedFieldNames = true)
+ val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ val serializer = if (encoder.schema != encoder.dataType) {
+ assert(root.getSchema.getFields.size() == 1)
+ serializerFor(encoder, root.getVector(0))
+ } else {
+ serializerFor(encoder, root)
+ }
+ root -> serializer
+ }
+
+ // TODO throw better errors on class cast exceptions.
+ private[arrow] def serializerFor[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = {
+ (encoder, v) match {
+ case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) =>
+ new FieldSerializer[Boolean, BitVector](v) {
+ override def set(index: Int, value: Boolean): Unit =
+ vector.setSafe(index, if (value) 1 else 0)
+ }
+ case (PrimitiveByteEncoder | BoxedByteEncoder, v: TinyIntVector) =>
+ new FieldSerializer[Byte, TinyIntVector](v) {
+ override def set(index: Int, value: Byte): Unit = vector.setSafe(index, value)
+ }
+ case (PrimitiveShortEncoder | BoxedShortEncoder, v: SmallIntVector) =>
+ new FieldSerializer[Short, SmallIntVector](v) {
+ override def set(index: Int, value: Short): Unit = vector.setSafe(index, value)
+ }
+ case (PrimitiveIntEncoder | BoxedIntEncoder, v: IntVector) =>
+ new FieldSerializer[Int, IntVector](v) {
+ override def set(index: Int, value: Int): Unit = vector.setSafe(index, value)
+ }
+ case (PrimitiveLongEncoder | BoxedLongEncoder, v: BigIntVector) =>
+ new FieldSerializer[Long, BigIntVector](v) {
+ override def set(index: Int, value: Long): Unit = vector.setSafe(index, value)
+ }
+ case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: Float4Vector) =>
+ new FieldSerializer[Float, Float4Vector](v) {
+ override def set(index: Int, value: Float): Unit = vector.setSafe(index, value)
+ }
+ case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: Float8Vector) =>
+ new FieldSerializer[Double, Float8Vector](v) {
+ override def set(index: Int, value: Double): Unit = vector.setSafe(index, value)
+ }
+ case (NullEncoder, v: NullVector) =>
+ new FieldSerializer[Unit, NullVector](v) {
+ override def set(index: Int, value: Unit): Unit = vector.setNull(index)
+ }
+ case (StringEncoder, v: VarCharVector) =>
+ new FieldSerializer[String, VarCharVector](v) {
+ override def set(index: Int, value: String): Unit = setString(v, index, value)
+ }
+ case (JavaEnumEncoder(_), v: VarCharVector) =>
+ new FieldSerializer[Enum[_], VarCharVector](v) {
+ override def set(index: Int, value: Enum[_]): Unit = setString(v, index, value.name())
+ }
+ case (ScalaEnumEncoder(_, _), v: VarCharVector) =>
+ new FieldSerializer[Enumeration#Value, VarCharVector](v) {
+ override def set(index: Int, value: Enumeration#Value): Unit =
+ setString(v, index, value.toString)
+ }
+ case (BinaryEncoder, v: VarBinaryVector) =>
+ new FieldSerializer[Array[Byte], VarBinaryVector](v) {
+ override def set(index: Int, value: Array[Byte]): Unit = vector.setSafe(index, value)
+ }
+ case (SparkDecimalEncoder(_), v: DecimalVector) =>
+ new FieldSerializer[Decimal, DecimalVector](v) {
+ override def set(index: Int, value: Decimal): Unit =
+ setDecimal(vector, index, value.toJavaBigDecimal)
+ }
+ case (ScalaDecimalEncoder(_), v: DecimalVector) =>
+ new FieldSerializer[BigDecimal, DecimalVector](v) {
+ override def set(index: Int, value: BigDecimal): Unit =
+ setDecimal(vector, index, value.bigDecimal)
+ }
+ case (JavaDecimalEncoder(_, false), v: DecimalVector) =>
+ new FieldSerializer[JBigDecimal, DecimalVector](v) {
+ override def set(index: Int, value: JBigDecimal): Unit =
+ setDecimal(vector, index, value)
+ }
+ case (JavaDecimalEncoder(_, true), v: DecimalVector) =>
+ new FieldSerializer[Any, DecimalVector](v) {
+ override def set(index: Int, value: Any): Unit = {
+ val decimal = value match {
+ case j: JBigDecimal => j
+ case d: BigDecimal => d.bigDecimal
+ case k: BigInt => new JBigDecimal(k.bigInteger)
+ case l: JBigInteger => new JBigDecimal(l)
+ case d: Decimal => d.toJavaBigDecimal
+ }
+ setDecimal(vector, index, decimal)
+ }
+ }
+ case (ScalaBigIntEncoder, v: DecimalVector) =>
+ new FieldSerializer[BigInt, DecimalVector](v) {
+ override def set(index: Int, value: BigInt): Unit =
+ setDecimal(vector, index, new JBigDecimal(value.bigInteger))
+ }
+ case (JavaBigIntEncoder, v: DecimalVector) =>
+ new FieldSerializer[JBigInteger, DecimalVector](v) {
+ override def set(index: Int, value: JBigInteger): Unit =
+ setDecimal(vector, index, new JBigDecimal(value))
+ }
+ case (DayTimeIntervalEncoder, v: DurationVector) =>
+ new FieldSerializer[Duration, DurationVector](v) {
+ override def set(index: Int, value: Duration): Unit =
+ vector.setSafe(index, SparkIntervalUtils.durationToMicros(value))
+ }
+ case (YearMonthIntervalEncoder, v: IntervalYearVector) =>
+ new FieldSerializer[Period, IntervalYearVector](v) {
+ override def set(index: Int, value: Period): Unit =
+ vector.setSafe(index, SparkIntervalUtils.periodToMonths(value))
+ }
+ case (DateEncoder(true) | LocalDateEncoder(true), v: DateDayVector) =>
+ new FieldSerializer[Any, DateDayVector](v) {
+ override def set(index: Int, value: Any): Unit =
+ vector.setSafe(index, SparkDateTimeUtils.anyToDays(value))
+ }
+ case (DateEncoder(false), v: DateDayVector) =>
+ new FieldSerializer[java.sql.Date, DateDayVector](v) {
+ override def set(index: Int, value: java.sql.Date): Unit =
+ vector.setSafe(index, SparkDateTimeUtils.fromJavaDate(value))
+ }
+ case (LocalDateEncoder(false), v: DateDayVector) =>
+ new FieldSerializer[LocalDate, DateDayVector](v) {
+ override def set(index: Int, value: LocalDate): Unit =
+ vector.setSafe(index, SparkDateTimeUtils.localDateToDays(value))
+ }
+ case (TimestampEncoder(true) | InstantEncoder(true), v: TimeStampMicroTZVector) =>
+ new FieldSerializer[Any, TimeStampMicroTZVector](v) {
+ override def set(index: Int, value: Any): Unit =
+ vector.setSafe(index, SparkDateTimeUtils.anyToMicros(value))
+ }
+ case (TimestampEncoder(false), v: TimeStampMicroTZVector) =>
+ new FieldSerializer[java.sql.Timestamp, TimeStampMicroTZVector](v) {
+ override def set(index: Int, value: java.sql.Timestamp): Unit =
+ vector.setSafe(index, SparkDateTimeUtils.fromJavaTimestamp(value))
+ }
+ case (InstantEncoder(false), v: TimeStampMicroTZVector) =>
+ new FieldSerializer[Instant, TimeStampMicroTZVector](v) {
+ override def set(index: Int, value: Instant): Unit =
+ vector.setSafe(index, SparkDateTimeUtils.instantToMicros(value))
+ }
+ case (LocalDateTimeEncoder, v: TimeStampMicroVector) =>
+ new FieldSerializer[LocalDateTime, TimeStampMicroVector](v) {
+ override def set(index: Int, value: LocalDateTime): Unit =
+ vector.setSafe(index, SparkDateTimeUtils.localDateTimeToMicros(value))
+ }
+
+ case (OptionEncoder(value), v) =>
+ new Serializer {
+ private[this] val delegate: Serializer = serializerFor(value, v)
+ override def write(index: Int, value: Any): Unit = value match {
+ case Some(value) => delegate.write(index, value)
+ case _ => delegate.write(index, null)
+ }
+ }
+
+ case (ArrayEncoder(element, _), v: ListVector) =>
+ val elementSerializer = serializerFor(element, v.getDataVector)
+ val toIterator = { array: Any =>
+ array.asInstanceOf[Array[_]].iterator
+ }
+ new ArraySerializer(v, toIterator, elementSerializer)
+
+ case (IterableEncoder(tag, element, _, lenient), v: ListVector) =>
+ val elementSerializer = serializerFor(element, v.getDataVector)
+ val toIterator: Any => Iterator[_] = if (lenient) {
+ {
+ case i: scala.collection.Iterable[_] => i.iterator
+ case l: java.util.List[_] => l.iterator().asScala
+ case a: Array[_] => a.iterator
+ case o => unsupportedCollectionType(o.getClass)
+ }
+ } else if (isSubClass(Classes.ITERABLE, tag)) { v =>
+ v.asInstanceOf[scala.collection.Iterable[_]].iterator
+ } else if (isSubClass(Classes.JLIST, tag)) { v =>
+ v.asInstanceOf[java.util.List[_]].iterator().asScala
+ } else {
+ unsupportedCollectionType(tag.runtimeClass)
+ }
+ new ArraySerializer(v, toIterator, elementSerializer)
+
+ case (MapEncoder(tag, key, value, _), v: MapVector) =>
+ val structVector = v.getDataVector.asInstanceOf[StructVector]
+ val extractor = if (isSubClass(classOf[scala.collection.Map[_, _]], tag)) { (v: Any) =>
+ v.asInstanceOf[scala.collection.Map[_, _]].iterator
+ } else if (isSubClass(classOf[JMap[_, _]], tag)) { (v: Any) =>
+ v.asInstanceOf[JMap[Any, Any]].asScala.iterator
+ } else {
+ unsupportedCollectionType(tag.runtimeClass)
+ }
+ val structSerializer = new StructSerializer(
+ structVector,
+ new StructFieldSerializer(
+ extractKey,
+ serializerFor(key, structVector.getChild(MapVector.KEY_NAME))) ::
+ new StructFieldSerializer(
+ extractValue,
+ serializerFor(value, structVector.getChild(MapVector.VALUE_NAME))) :: Nil)
+ new ArraySerializer(v, extractor, structSerializer)
+
+ case (ProductEncoder(tag, fields, _), StructVectors(struct, vectors)) =>
+ if (isSubClass(classOf[Product], tag)) {
+ structSerializerFor(fields, struct, vectors) { (_, i) => p =>
+ p.asInstanceOf[Product].productElement(i)
+ }
+ } else if (isSubClass(classOf[DefinedByConstructorParams], tag)) {
+ structSerializerFor(fields, struct, vectors) { (field, _) =>
+ val getter = methodLookup.findVirtual(
+ tag.runtimeClass,
+ field.name,
+ MethodType.methodType(field.enc.clsTag.runtimeClass))
+ o => getter.invoke(o)
+ }
+ } else {
+ unsupportedCollectionType(tag.runtimeClass)
+ }
+
+ case (RowEncoder(fields), StructVectors(struct, vectors)) =>
+ structSerializerFor(fields, struct, vectors) { (_, i) => r => r.asInstanceOf[Row].get(i) }
+
+ case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) =>
+ structSerializerFor(fields, struct, vectors) { (field, _) =>
+ val getter = methodLookup.findVirtual(
+ tag.runtimeClass,
+ field.readMethod.get,
+ MethodType.methodType(field.enc.clsTag.runtimeClass))
+ o => getter.invoke(o)
+ }
+
+ case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
+ // throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType)
+ throw new RuntimeException("Unsupported data type")
+
+ case _ =>
+ throw new RuntimeException(s"Unsupported Encoder($encoder)/Vector($v) combination.")
+ }
+ }
+
+ private val methodLookup = MethodHandles.lookup()
+
+ private def setString(vector: VarCharVector, index: Int, string: String): Unit = {
+ val bytes = Text.encode(string)
+ vector.setSafe(index, bytes, 0, bytes.limit())
+ }
+
+ private def setDecimal(vector: DecimalVector, index: Int, decimal: JBigDecimal): Unit = {
+ val scaledDecimal = if (vector.getScale != decimal.scale()) {
+ decimal.setScale(vector.getScale)
+ } else {
+ decimal
+ }
+ vector.setSafe(index, scaledDecimal)
+ }
+
+ private def extractKey(v: Any): Any = {
+ val key = v.asInstanceOf[(Any, Any)]._1
+ Objects.requireNonNull(key)
+ key
+ }
+
+ private def extractValue(v: Any): Any = {
+ v.asInstanceOf[(Any, Any)]._2
+ }
+
+ private def structSerializerFor(
+ fields: Seq[EncoderField],
+ struct: StructVector,
+ vectors: Seq[FieldVector])(
+ createGetter: (EncoderField, Int) => Any => Any): StructSerializer = {
+ require(fields.size == vectors.size)
+ val serializers = fields.zip(vectors).zipWithIndex.map { case ((field, vector), i) =>
+ val serializer = serializerFor(field.enc, vector)
+ new StructFieldSerializer(createGetter(field, i), serializer)
+ }
+ new StructSerializer(struct, serializers)
+ }
+
+ abstract class Serializer {
+ def write(index: Int, value: Any): Unit
+ }
+
+ private abstract class FieldSerializer[E, V <: FieldVector](val vector: V) extends Serializer {
+ def set(index: Int, value: E): Unit
+
+ override def write(index: Int, raw: Any): Unit = {
+ val value = raw.asInstanceOf[E]
+ if (value != null) {
+ set(index, value)
+ } else {
+ vector.setNull(index)
+ }
+ }
+ }
+
+ private class ArraySerializer(
+ v: ListVector,
+ toIterator: Any => Iterator[Any],
+ elementSerializer: Serializer)
+ extends FieldSerializer[Any, ListVector](v) {
+ override def set(index: Int, value: Any): Unit = {
+ val elementStartIndex = vector.startNewValue(index)
+ var elementIndex = elementStartIndex
+ val iterator = toIterator(value)
+ while (iterator.hasNext) {
+ elementSerializer.write(elementIndex, iterator.next())
+ elementIndex += 1
+ }
+ vector.endValue(index, elementIndex - elementStartIndex)
+ }
+ }
+
+ private class StructFieldSerializer(val extractor: Any => Any, val serializer: Serializer) {
+ def write(index: Int, value: Any): Unit = serializer.write(index, extractor(value))
+ def writeNull(index: Int): Unit = serializer.write(index, null)
+ }
+
+ private class StructSerializer(
+ struct: StructVector,
+ fieldSerializers: Seq[StructFieldSerializer])
+ extends Serializer {
+
+ override def write(index: Int, value: Any): Unit = {
+ if (value == null) {
+ if (struct != null) {
+ struct.setNull(index)
+ }
+ fieldSerializers.foreach(_.writeNull(index))
+ } else {
+ if (struct != null) {
+ struct.setIndexDefined(index)
+ }
+ fieldSerializers.foreach(_.write(index, value))
+ }
+ }
+ }
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowTable.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowTable.scala
new file mode 100644
index 00000000000..49421a5d313
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowTable.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.datasources.arrow
+
+import org.apache.hadoop.fs.FileStatus
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.execution.datasources.v2.FileTable
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import scala.jdk.CollectionConverters._
+import org.apache.sedona.sql.datasources.arrow.ArrowWriteBuilder
+
+case class ArrowTable(
+ name: String,
+ sparkSession: SparkSession,
+ options: CaseInsensitiveStringMap,
+ paths: Seq[String],
+ userSpecifiedSchema: Option[StructType])
+ extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ // Read path is not implemented yet; this table currently only supports writes.
+ null
+ }
+
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ new ArrowWriteBuilder(info)
+ }
+
+ override def inferSchema(files: Seq[FileStatus]): Option[StructType] = {
+ None
+ }
+
+ override def formatName: String = {
+ "Arrow Stream"
+ }
+
+ override def fallbackFileFormat: Class[_ <: FileFormat] = {
+ null
+ }
+
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala
new file mode 100644
index 00000000000..60006412873
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.datasources.arrow
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.complex.MapVector
+import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+
+import org.apache.spark.sql.errors.ExecutionErrors
+import org.apache.spark.sql.types._
+
+private[arrow] object ArrowUtils {
+
+ val rootAllocator = new RootAllocator(Long.MaxValue)
+
+ // todo: support more types.
+
+ /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
+ def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType =
+ dt match {
+ case BooleanType => ArrowType.Bool.INSTANCE
+ case ByteType => new ArrowType.Int(8, true)
+ case ShortType => new ArrowType.Int(8 * 2, true)
+ case IntegerType => new ArrowType.Int(8 * 4, true)
+ case LongType => new ArrowType.Int(8 * 8, true)
+ case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
+ case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
+ case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE
+ case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE
+ case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE
+ case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE
+ // TODO(paleolimbot): DecimalType.Fixed is marked private
+ // case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale)
+ case DateType => new ArrowType.Date(DateUnit.DAY)
+ case TimestampType if timeZoneId == null =>
+ throw new IllegalStateException("Missing timezoneId where it is mandatory.")
+ case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
+ case TimestampNTZType =>
+ new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
+ case NullType => ArrowType.Null.INSTANCE
+ case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
+ case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
+ case _ =>
+ // throw ExecutionErrors.unsupportedDataTypeError(dt)
+ throw new RuntimeException("Unsupported data type")
+ }
+
+ def fromArrowType(dt: ArrowType): DataType = dt match {
+ case ArrowType.Bool.INSTANCE => BooleanType
+ case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
+ case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
+ case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType
+ case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
+ case float: ArrowType.FloatingPoint
+ if float.getPrecision() == FloatingPointPrecision.SINGLE =>
+ FloatType
+ case float: ArrowType.FloatingPoint
+ if float.getPrecision() == FloatingPointPrecision.DOUBLE =>
+ DoubleType
+ case ArrowType.Utf8.INSTANCE => StringType
+ case ArrowType.Binary.INSTANCE => BinaryType
+ case ArrowType.LargeUtf8.INSTANCE => StringType
+ case ArrowType.LargeBinary.INSTANCE => BinaryType
+ case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
+ case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
+ case ts: ArrowType.Timestamp
+ if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null =>
+ TimestampNTZType
+ case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
+ case ArrowType.Null.INSTANCE => NullType
+ case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH =>
+ YearMonthIntervalType()
+ case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType()
+ // case _ => throw ExecutionErrors.unsupportedArrowTypeError(dt)
+ case _ => throw new RuntimeException("Unsupported arrow data type")
+ }
+
+ /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
+ def toArrowField(
+ name: String,
+ dt: DataType,
+ nullable: Boolean,
+ timeZoneId: String,
+ largeVarTypes: Boolean = false): Field = {
+ dt match {
+ case ArrayType(elementType, containsNull) =>
+ val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
+ new Field(
+ name,
+ fieldType,
+ Seq(
+ toArrowField("element", elementType, containsNull, timeZoneId, largeVarTypes)).asJava)
+ case StructType(fields) =>
+ val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
+ new Field(
+ name,
+ fieldType,
+ fields
+ .map { field =>
+ toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes)
+ }
+ .toSeq
+ .asJava)
+ case MapType(keyType, valueType, valueContainsNull) =>
+ val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
+ // Note: the map's entries struct cannot be null, and its key field cannot be null
+ new Field(
+ name,
+ mapType,
+ Seq(
+ toArrowField(
+ MapVector.DATA_VECTOR_NAME,
+ new StructType()
+ .add(MapVector.KEY_NAME, keyType, nullable = false)
+ .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull),
+ nullable = false,
+ timeZoneId,
+ largeVarTypes)).asJava)
+ case udt: UserDefinedType[_] =>
+ toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
+ case dataType =>
+ val fieldType =
+ new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes), null)
+ new Field(name, fieldType, Seq.empty[Field].asJava)
+ }
+ }
+
+ def fromArrowField(field: Field): DataType = {
+ field.getType match {
+ case _: ArrowType.Map =>
+ val elementField = field.getChildren.get(0)
+ val keyType = fromArrowField(elementField.getChildren.get(0))
+ val valueType = fromArrowField(elementField.getChildren.get(1))
+ MapType(keyType, valueType, elementField.getChildren.get(1).isNullable)
+ case ArrowType.List.INSTANCE =>
+ val elementField = field.getChildren().get(0)
+ val elementType = fromArrowField(elementField)
+ ArrayType(elementType, containsNull = elementField.isNullable)
+ case ArrowType.Struct.INSTANCE =>
+ val fields = field.getChildren().asScala.map { child =>
+ val dt = fromArrowField(child)
+ StructField(child.getName, dt, child.isNullable)
+ }
+ StructType(fields.toArray)
+ case arrowType => fromArrowType(arrowType)
+ }
+ }
+
+ /**
+ * Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType
+ */
+ def toArrowSchema(
+ schema: StructType,
+ timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean,
+ largeVarTypes: Boolean = false): Schema = {
+ new Schema(schema.map { field =>
+ toArrowField(
+ field.name,
+ deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames),
+ field.nullable,
+ timeZoneId,
+ largeVarTypes)
+ }.asJava)
+ }
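+
+ // Illustrative example (assumed values, not from this PR): a flat Spark schema maps field
+ // by field, e.g.
+ //   toArrowSchema(new StructType().add("id", LongType).add("name", StringType), "UTC",
+ //     errorOnDuplicatedFieldNames = true)
+ //   // -> Arrow schema: id -> Int(64, signed), name -> Utf8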
+
+ def fromArrowSchema(schema: Schema): StructType = {
+ StructType(schema.getFields.asScala.map { field =>
+ val dt = fromArrowField(field)
+ StructField(field.getName, dt, field.isNullable)
+ }.toArray)
+ }
+
+ private def deduplicateFieldNames(
+ dt: DataType,
+ errorOnDuplicatedFieldNames: Boolean): DataType = dt match {
+ case udt: UserDefinedType[_] =>
+ deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames)
+ case st @ StructType(fields) =>
+ val newNames = if (st.names.toSet.size == st.names.length) {
+ st.names
+ } else {
+ if (errorOnDuplicatedFieldNames) {
+ // throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names)
+ throw new RuntimeException("duplicated field name in arrow struct")
+ }
+ val genNewName = st.names.groupBy(identity).map {
+ case (name, names) if names.length > 1 =>
+ val i = new AtomicInteger()
+ name -> { () => s"${name}_${i.getAndIncrement()}" }
+ case (name, _) => name -> { () => name }
+ }
+ st.names.map(genNewName(_)())
+ }
+ val newFields =
+ fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) =>
+ StructField(
+ name,
+ deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames),
+ nullable,
+ metadata)
+ }
+ StructType(newFields)
+ case ArrayType(elementType, containsNull) =>
+ ArrayType(deduplicateFieldNames(elementType, errorOnDuplicatedFieldNames), containsNull)
+ case MapType(keyType, valueType, valueContainsNull) =>
+ MapType(
+ deduplicateFieldNames(keyType, errorOnDuplicatedFieldNames),
+ deduplicateFieldNames(valueType, errorOnDuplicatedFieldNames),
+ valueContainsNull)
+ case _ => dt
+ }
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.scala
new file mode 100644
index 00000000000..5efee38920c
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.datasources.arrow
+
+import org.apache.spark.sql.connector.write.WriteBuilder
+import org.apache.spark.sql.connector.write.Write
+import org.apache.spark.sql.connector.write.BatchWrite
+import org.apache.spark.sql.connector.write.LogicalWriteInfo
+
+case class ArrowWriteBuilder(info: LogicalWriteInfo) extends WriteBuilder {
+
+ override def buildForBatch(): BatchWrite = {
+ new ArrowBatchWrite(info)
+ }
+
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala
new file mode 100644
index 00000000000..ee1d84af1e2
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.datasources.arrow
+
+import org.apache.spark.sql.connector.write.LogicalWriteInfo
+import org.apache.spark.sql.connector.write.PhysicalWriteInfo
+import org.apache.spark.sql.connector.write.DataWriter
+import org.apache.arrow.memory.RootAllocator
+import org.apache.spark.sql.connector.write.WriterCommitMessage
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.Row
+import java.io.ByteArrayOutputStream
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+
+case class ArrowWriter() extends DataWriter[InternalRow] {
+
+ private var encoder: AgnosticEncoder[Row] = _
+ private var rowDeserializer: ExpressionEncoder.Deserializer[Row] = _
+ private var serializer: ArrowSerializer[Row] = _
+ private var dummyOutput: ByteArrayOutputStream = _
+ private var rowCount: Long = 0
+
+ def this(
+ logicalInfo: LogicalWriteInfo,
+ physicalInfo: PhysicalWriteInfo,
+ partitionId: Int,
+ taskId: Long) {
+ this()
+ dummyOutput = new ByteArrayOutputStream()
+ encoder = RowEncoder.encoderFor(logicalInfo.schema())
+ rowDeserializer = Encoders
+ .row(logicalInfo.schema())
+ .asInstanceOf[ExpressionEncoder[Row]]
+ .resolveAndBind()
+ .createDeserializer()
+ serializer = new ArrowSerializer[Row](encoder, new RootAllocator(), "UTC")
+
+ serializer.writeSchema(dummyOutput)
+ }
+
+ def write(record: InternalRow): Unit = {
+ serializer.append(rowDeserializer.apply(record))
+ rowCount = rowCount + 1
+ if (shouldFlush()) {
+ flush()
+ }
+ }
+
+ private def shouldFlush(): Boolean = {
+ // We could use serializer.sizeInBytes to bound batches by size instead of (or in addition
+ // to) a row count: batches of ~1024 rows and a ~16 MB threshold are both common, possibly
+ // combined with a minimum row count in case we are dealing with very large geometries.
+ // Checking sizeInBytes should be done sparingly because it is relatively expensive; a
+ // size-aware variant is sketched below this method.
+ rowCount >= 1024
+ }
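+
+ // A hedged sketch of a size-aware alternative (not wired in; the thresholds are assumptions,
+ // and sizeInBytes is only consulted every 128 rows because it is comparatively expensive):
+ //   private def shouldFlushBySize(): Boolean = {
+ //     val maxBatchBytes = 16L * 1024 * 1024
+ //     rowCount >= 1024 || (rowCount % 128 == 0 && serializer.sizeInBytes >= maxBatchBytes)
+ //   }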
+
+ private def flush(): Unit = {
+ // Reset after writing so the next batch starts fresh instead of re-serializing prior rows.
+ serializer.writeBatch(dummyOutput)
+ serializer.reset()
+ rowCount = 0
+ }
+
+ def commit(): WriterCommitMessage = {
+ null
+ }
+
+ def abort(): Unit = {}
+
+ def close(): Unit = {
+ flush()
+ serializer.writeEndOfStream(dummyOutput)
+ serializer.close()
+ dummyOutput.close()
+ }
+
+}