From 6f093c97ea452f99e91bff4428098b9a1a82daeb Mon Sep 17 00:00:00 2001
From: Dewey Dunnington
Date: Mon, 17 Mar 2025 15:49:15 -0500
Subject: [PATCH 01/12] arrow spark writer stub

---
 spark/common/pom.xml                          |  7 ++
 .../datasources/arrow/ArrowBatchWrite.java    | 50 +++++++++++
 .../arrow/ArrowDataWriterFactory.java         | 37 ++++++++
 .../sql/datasources/arrow/ArrowTable.java     | 70 +++++++++++++++
 .../datasources/arrow/ArrowTableProvider.java | 41 +++++++++
 .../datasources/arrow/ArrowWriteBuilder.java  | 37 ++++++++
 .../sql/datasources/arrow/ArrowWriter.java    | 86 +++++++++++++++++++
 7 files changed, 328 insertions(+)
 create mode 100644 spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.java
 create mode 100644 spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.java
 create mode 100644 spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTable.java
 create mode 100644 spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTableProvider.java
 create mode 100644 spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java
 create mode 100644 spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java

diff --git a/spark/common/pom.xml b/spark/common/pom.xml
index c5962ad0c84..b78ca55d4f7 100644
--- a/spark/common/pom.xml
+++ b/spark/common/pom.xml
@@ -37,6 +37,13 @@
+
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-connect-common_2.12</artifactId>
+            <version>3.5.5</version>
+        </dependency>
+
         <dependency>
             <groupId>org.apache.sedona</groupId>
             <artifactId>sedona-common</artifactId>
diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.java b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.java
new file mode 100644
index 00000000000..25d66747885
--- /dev/null
+++ b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.sedona.sql.datasources.arrow; + +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.DataWriterFactory; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import org.apache.spark.sql.connector.write.WriterCommitMessage; + +class ArrowBatchWrite implements BatchWrite { + private final LogicalWriteInfo logicalWriteInfo; + + public ArrowBatchWrite(LogicalWriteInfo info) { + this.logicalWriteInfo = info; + } + + @Override + public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + return new ArrowDataWriterFactory(logicalWriteInfo.schema()); + } + + @Override + public void commit(WriterCommitMessage[] messages) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'commit'"); + } + + @Override + public void abort(WriterCommitMessage[] messages) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'abort'"); + } +} diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.java b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.java new file mode 100644 index 00000000000..31c99a1f63b --- /dev/null +++ b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.arrow; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.write.DataWriter; +import org.apache.spark.sql.connector.write.DataWriterFactory; +import org.apache.spark.sql.types.StructType; + +class ArrowDataWriterFactory implements DataWriterFactory { + private final StructType schema; + + public ArrowDataWriterFactory(StructType schema) { + this.schema = schema; + } + + @Override + public DataWriter createWriter(int partitionId, long taskId) { + return new ArrowWriter(partitionId, taskId, schema); + } +} diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTable.java b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTable.java new file mode 100644 index 00000000000..5c6463c05b0 --- /dev/null +++ b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTable.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.arrow; + +import java.util.HashSet; +import java.util.Set; +import org.apache.spark.sql.connector.catalog.SupportsWrite; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.SupportsOverwrite; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructType; + +class ArrowTable implements SupportsWrite, SupportsOverwrite { + private Set capabilities; + private StructType schema; + + ArrowTable(StructType schema) { + this.schema = schema; + } + + @Override + public String name() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'name'"); + } + + @Override + public StructType schema() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'schema'"); + } + + @Override + public Set capabilities() { + if (capabilities == null) { + this.capabilities = new HashSet<>(); + capabilities.add(TableCapability.BATCH_WRITE); + } + return capabilities; + } + + @Override + public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { + return new ArrowWriteBuilder(info); + } + + @Override + public WriteBuilder overwrite(Filter[] filters) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'overwrite'"); + } +} diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTableProvider.java b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTableProvider.java new file mode 100644 index 00000000000..5f4d0bc1231 --- /dev/null +++ b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTableProvider.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.arrow; + +import java.util.Map; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableProvider; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +class ArrowTableProvider implements TableProvider { + + @Override + public StructType inferSchema(CaseInsensitiveStringMap options) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'inferSchema'"); + } + + @Override + public Table getTable( + StructType schema, Transform[] partitioning, Map properties) { + return new ArrowTable(schema); + } +} diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java new file mode 100644 index 00000000000..e77af25b7d0 --- /dev/null +++ b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.arrow; + +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.WriteBuilder; + +class ArrowWriteBuilder implements WriteBuilder { + + private final LogicalWriteInfo writeInfo; + + public ArrowWriteBuilder(LogicalWriteInfo writeInfo) { + this.writeInfo = writeInfo; + } + + @Override + public BatchWrite buildForBatch() { + throw new UnsupportedOperationException("Unimplemented method 'buildForBatch'"); + } +} diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java new file mode 100644 index 00000000000..e2dabc6b394 --- /dev/null +++ b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.datasources.arrow;
+
+import java.io.IOException;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder;
+import org.apache.spark.sql.catalyst.encoders.RowEncoder;
+import org.apache.spark.sql.connect.client.arrow.ArrowSerializer;
+import org.apache.spark.sql.connector.write.DataWriter;
+import org.apache.spark.sql.connector.write.WriterCommitMessage;
+import org.apache.spark.sql.types.StructType;
+
+class ArrowWriter implements DataWriter<InternalRow> {
+  private final int partitionId;
+  private final long taskId;
+  private int rowCount;
+  private AgnosticEncoder<Row> encoder;
+  private Encoder<Row> rowEncoder;
+  // https://github.com/apache/spark/blob/9353e94e50f3f73565f5f0023effd7e265c177b9/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala#L50
+  private ArrowSerializer<Row> serializer;
+
+  public ArrowWriter(int partitionId, long taskId, StructType schema) {
+    this.partitionId = partitionId;
+    this.taskId = taskId;
+    this.rowCount = 0;
+    this.encoder = RowEncoder.encoderFor(schema);
+    this.serializer = new ArrowSerializer<>(encoder, new RootAllocator(), "UTC");
+
+    // Create file, write schema
+    // Problem: ArrowSerializer() does not expose internal to write just the schema
+    // bytes.
+ } + + @Override + public void close() throws IOException { + // Close file + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'close'"); + } + + @Override + public void write(InternalRow record) throws IOException { + // Problem: serializer needs a Row but we have an InternalRow + // serializer.append(encoder.fromRow(record)); + + rowCount++; + if (rowCount > 1024) { + // Problem: writeIpcStream() writes both the schema and the batch, but + // we only want the batch + // serializer.writeIpcStream(null); + rowCount = 0; + } + } + + @Override + public WriterCommitMessage commit() throws IOException { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'commit'"); + } + + @Override + public void abort() throws IOException { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'abort'"); + } +} From 87a695c12da8f3f69f2053e2ae5a01b863cb17d4 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 17 Mar 2025 15:49:24 -0500 Subject: [PATCH 02/12] one more --- .../apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java index e77af25b7d0..19061f433c1 100644 --- a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java +++ b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java @@ -32,6 +32,6 @@ public ArrowWriteBuilder(LogicalWriteInfo writeInfo) { @Override public BatchWrite buildForBatch() { - throw new UnsupportedOperationException("Unimplemented method 'buildForBatch'"); + return new ArrowBatchWrite(writeInfo); } } From a2795f8a4329d0500d4512fdd1652d18b0eb3bf2 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 18 Mar 2025 14:39:50 -0500 Subject: [PATCH 03/12] maybe build on more than one spark/scala combo --- spark/common/pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/common/pom.xml b/spark/common/pom.xml index b78ca55d4f7..5d137dd3d06 100644 --- a/spark/common/pom.xml +++ b/spark/common/pom.xml @@ -40,8 +40,8 @@ org.apache.spark - spark-connect-common_2.12 - 3.5.5 + spark-connect-common_${scala.compat.version} + ${spark.version} From b8ef6510e2bd2d4e6b6d9a84afd7971037ae7f11 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 25 Mar 2025 11:08:27 -0500 Subject: [PATCH 04/12] maybe get building --- spark/common/pom.xml | 8 + .../datasources/arrow/ArrowEncoderUtils.scala | 51 ++ .../datasources/arrow/ArrowSerializer.scala | 572 ++++++++++++++++++ .../sql/datasources/arrow/ArrowUtils.scala | 234 +++++++ 4 files changed, 865 insertions(+) create mode 100644 spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala create mode 100644 spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala create mode 100644 spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala diff --git a/spark/common/pom.xml b/spark/common/pom.xml index 5d137dd3d06..30beeea16ba 100644 --- a/spark/common/pom.xml +++ b/spark/common/pom.xml @@ -38,6 +38,14 @@ + + org.apache.arrow + arrow-java-root + 18.2.0 + pom + + + org.apache.spark spark-connect-common_${scala.compat.version} diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala new file mode 100644 index 00000000000..aaedcab8f12 --- /dev/null +++ b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.arrow + +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} +import org.apache.arrow.vector.complex.StructVector + +private[arrow] object ArrowEncoderUtils { + object Classes { + val WRAPPED_ARRAY: Class[_] = classOf[scala.collection.mutable.WrappedArray[_]] + val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]] + val MAP: Class[_] = classOf[scala.collection.Map[_, _]] + val JLIST: Class[_] = classOf[java.util.List[_]] + val JMAP: Class[_] = classOf[java.util.Map[_, _]] + } + + def isSubClass(cls: Class[_], tag: ClassTag[_]): Boolean = { + cls.isAssignableFrom(tag.runtimeClass) + } + + def unsupportedCollectionType(cls: Class[_]): Nothing = { + throw new RuntimeException(s"Unsupported collection type: $cls") + } +} + +private[arrow] object StructVectors { + def unapply(v: AnyRef): Option[(StructVector, Seq[FieldVector])] = v match { + case root: VectorSchemaRoot => Option((null, root.getFieldVectors.asScala.toSeq)) + case struct: StructVector => Option((struct, struct.getChildrenFromFields.asScala.toSeq)) + case _ => None + } +} diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala new file mode 100644 index 00000000000..1a0e6604b33 --- /dev/null +++ b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala @@ -0,0 +1,572 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.arrow + +import java.io.{ByteArrayOutputStream, OutputStream} +import java.lang.invoke.{MethodHandles, MethodType} +import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger} +import java.nio.channels.Channels +import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} +import java.util.{Map => JMap, Objects} + +import scala.collection.JavaConverters._ + +import com.google.protobuf.ByteString +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel} +import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer} +import org.apache.arrow.vector.util.Text + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.DefinedByConstructorParams +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ +import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} +import org.apache.spark.sql.errors.ExecutionErrors +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.util.ArrowUtils + +private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable { self => + def asJava: java.util.Iterator[E] = new java.util.Iterator[E] with AutoCloseable { + override def next() = self.next() + + override def hasNext() = self.hasNext + + override def close() = self.close() + } + + override def map[B](f: E => B): CloseableIterator[B] = { + new CloseableIterator[B] { + override def next(): B = f(self.next()) + + override def hasNext: Boolean = self.hasNext + + override def close(): Unit = self.close() + } + } +} + +/** + * Helper class for converting user objects into arrow batches. + */ +class ArrowSerializer[T]( + private[this] val enc: AgnosticEncoder[T], + private[this] val allocator: BufferAllocator, + private[this] val timeZoneId: String) { + private val (root, serializer) = ArrowSerializer.serializerFor(enc, allocator, timeZoneId) + private val vectors = root.getFieldVectors.asScala + private val unloader = new VectorUnloader(root) + private val schemaBytes = { + // Only serialize the schema once. + val bytes = new ByteArrayOutputStream() + MessageSerializer.serialize(newChannel(bytes), root.getSchema) + bytes.toByteArray + } + private var rowCount: Int = 0 + private var closed: Boolean = false + + private def newChannel(output: OutputStream): WriteChannel = { + new WriteChannel(Channels.newChannel(output)) + } + + /** + * The size of the current batch. + * + * The size computed consist of the size of the schema and the size of the arrow buffers. The + * actual batch will be larger than that because of alignment, written IPC tokens, and the + * written record batch metadata. The size of the record batch metadata is proportional to the + * complexity of the schema. + */ + def sizeInBytes: Long = { + // We need to set the row count for getBufferSize to return the actual value. 
+ root.setRowCount(rowCount) + schemaBytes.length + vectors.map(_.getBufferSize).sum + } + + /** + * Append a record to the current batch. + */ + def append(record: T): Unit = { + serializer.write(rowCount, record) + rowCount += 1 + } + + /** + * Write the schema and the current batch in Arrow IPC stream format to the [[OutputStream]]. + */ + def writeIpcStream(output: OutputStream): Unit = { + val channel = newChannel(output) + root.setRowCount(rowCount) + val batch = unloader.getRecordBatch + try { + channel.write(schemaBytes) + MessageSerializer.serialize(channel, batch) + ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT) + } finally { + batch.close() + } + } + + /** + * Reset the serializer. + */ + def reset(): Unit = { + rowCount = 0 + vectors.foreach(_.reset()) + } + + /** + * Close the serializer. + */ + def close(): Unit = { + root.close() + closed = true + } + + /** + * Check if the serializer has been closed. + * + * It is illegal to used the serializer after it has been closed. It will lead to errors and + * sorts of undefined behavior. + */ + def isClosed: Boolean = closed +} + +object ArrowSerializer { + import ArrowEncoderUtils._ + + /** + * Create an [[Iterator]] that converts the input [[Iterator]] of type `T` into an [[Iterator]] + * of Arrow IPC Streams. + */ + def serialize[T]( + input: Iterator[T], + enc: AgnosticEncoder[T], + allocator: BufferAllocator, + maxRecordsPerBatch: Int, + maxBatchSize: Long, + timeZoneId: String, + batchSizeCheckInterval: Int = 128): CloseableIterator[Array[Byte]] = { + assert(maxRecordsPerBatch > 0) + assert(maxBatchSize > 0) + assert(batchSizeCheckInterval > 0) + new CloseableIterator[Array[Byte]] { + private val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId) + private val bytes = new ByteArrayOutputStream + private var hasWrittenFirstBatch = false + + /** + * Periodical check to make sure we don't go over the size threshold by too much. + */ + private def sizeOk(i: Int): Boolean = { + if (i > 0 && i % batchSizeCheckInterval == 0) { + return serializer.sizeInBytes < maxBatchSize + } + true + } + + override def hasNext: Boolean = { + (input.hasNext || !hasWrittenFirstBatch) && !serializer.isClosed + } + + override def next(): Array[Byte] = { + if (!hasNext) { + throw new NoSuchElementException() + } + serializer.reset() + bytes.reset() + var i = 0 + while (i < maxRecordsPerBatch && input.hasNext && sizeOk(i)) { + serializer.append(input.next()) + i += 1 + } + serializer.writeIpcStream(bytes) + hasWrittenFirstBatch = true + bytes.toByteArray + } + + override def close(): Unit = serializer.close() + } + } + + def serialize[T]( + input: Iterator[T], + enc: AgnosticEncoder[T], + allocator: BufferAllocator, + timeZoneId: String): ByteString = { + val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId) + try { + input.foreach(serializer.append) + val output = ByteString.newOutput() + serializer.writeIpcStream(output) + output.toByteString + } finally { + serializer.close() + } + } + + /** + * Create a (root) [[Serializer]] for [[AgnosticEncoder]] `encoder`. + * + * The serializer returned by this method is NOT thread-safe. 
+ */ + def serializerFor[T]( + encoder: AgnosticEncoder[T], + allocator: BufferAllocator, + timeZoneId: String): (VectorSchemaRoot, Serializer) = { + val arrowSchema = + ArrowUtils.toArrowSchema(encoder.schema, timeZoneId, errorOnDuplicatedFieldNames = true) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val serializer = if (encoder.schema != encoder.dataType) { + assert(root.getSchema.getFields.size() == 1) + serializerFor(encoder, root.getVector(0)) + } else { + serializerFor(encoder, root) + } + root -> serializer + } + + // TODO throw better errors on class cast exceptions. + private[arrow] def serializerFor[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = { + (encoder, v) match { + case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) => + new FieldSerializer[Boolean, BitVector](v) { + override def set(index: Int, value: Boolean): Unit = + vector.setSafe(index, if (value) 1 else 0) + } + case (PrimitiveByteEncoder | BoxedByteEncoder, v: TinyIntVector) => + new FieldSerializer[Byte, TinyIntVector](v) { + override def set(index: Int, value: Byte): Unit = vector.setSafe(index, value) + } + case (PrimitiveShortEncoder | BoxedShortEncoder, v: SmallIntVector) => + new FieldSerializer[Short, SmallIntVector](v) { + override def set(index: Int, value: Short): Unit = vector.setSafe(index, value) + } + case (PrimitiveIntEncoder | BoxedIntEncoder, v: IntVector) => + new FieldSerializer[Int, IntVector](v) { + override def set(index: Int, value: Int): Unit = vector.setSafe(index, value) + } + case (PrimitiveLongEncoder | BoxedLongEncoder, v: BigIntVector) => + new FieldSerializer[Long, BigIntVector](v) { + override def set(index: Int, value: Long): Unit = vector.setSafe(index, value) + } + case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: Float4Vector) => + new FieldSerializer[Float, Float4Vector](v) { + override def set(index: Int, value: Float): Unit = vector.setSafe(index, value) + } + case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: Float8Vector) => + new FieldSerializer[Double, Float8Vector](v) { + override def set(index: Int, value: Double): Unit = vector.setSafe(index, value) + } + case (NullEncoder, v: NullVector) => + new FieldSerializer[Unit, NullVector](v) { + override def set(index: Int, value: Unit): Unit = vector.setNull(index) + } + case (StringEncoder, v: VarCharVector) => + new FieldSerializer[String, VarCharVector](v) { + override def set(index: Int, value: String): Unit = setString(v, index, value) + } + case (JavaEnumEncoder(_), v: VarCharVector) => + new FieldSerializer[Enum[_], VarCharVector](v) { + override def set(index: Int, value: Enum[_]): Unit = setString(v, index, value.name()) + } + case (ScalaEnumEncoder(_, _), v: VarCharVector) => + new FieldSerializer[Enumeration#Value, VarCharVector](v) { + override def set(index: Int, value: Enumeration#Value): Unit = + setString(v, index, value.toString) + } + case (BinaryEncoder, v: VarBinaryVector) => + new FieldSerializer[Array[Byte], VarBinaryVector](v) { + override def set(index: Int, value: Array[Byte]): Unit = vector.setSafe(index, value) + } + case (SparkDecimalEncoder(_), v: DecimalVector) => + new FieldSerializer[Decimal, DecimalVector](v) { + override def set(index: Int, value: Decimal): Unit = + setDecimal(vector, index, value.toJavaBigDecimal) + } + case (ScalaDecimalEncoder(_), v: DecimalVector) => + new FieldSerializer[BigDecimal, DecimalVector](v) { + override def set(index: Int, value: BigDecimal): Unit = + setDecimal(vector, index, value.bigDecimal) + } + case 
(JavaDecimalEncoder(_, false), v: DecimalVector) => + new FieldSerializer[JBigDecimal, DecimalVector](v) { + override def set(index: Int, value: JBigDecimal): Unit = + setDecimal(vector, index, value) + } + case (JavaDecimalEncoder(_, true), v: DecimalVector) => + new FieldSerializer[Any, DecimalVector](v) { + override def set(index: Int, value: Any): Unit = { + val decimal = value match { + case j: JBigDecimal => j + case d: BigDecimal => d.bigDecimal + case k: BigInt => new JBigDecimal(k.bigInteger) + case l: JBigInteger => new JBigDecimal(l) + case d: Decimal => d.toJavaBigDecimal + } + setDecimal(vector, index, decimal) + } + } + case (ScalaBigIntEncoder, v: DecimalVector) => + new FieldSerializer[BigInt, DecimalVector](v) { + override def set(index: Int, value: BigInt): Unit = + setDecimal(vector, index, new JBigDecimal(value.bigInteger)) + } + case (JavaBigIntEncoder, v: DecimalVector) => + new FieldSerializer[JBigInteger, DecimalVector](v) { + override def set(index: Int, value: JBigInteger): Unit = + setDecimal(vector, index, new JBigDecimal(value)) + } + case (DayTimeIntervalEncoder, v: DurationVector) => + new FieldSerializer[Duration, DurationVector](v) { + override def set(index: Int, value: Duration): Unit = + vector.setSafe(index, SparkIntervalUtils.durationToMicros(value)) + } + case (YearMonthIntervalEncoder, v: IntervalYearVector) => + new FieldSerializer[Period, IntervalYearVector](v) { + override def set(index: Int, value: Period): Unit = + vector.setSafe(index, SparkIntervalUtils.periodToMonths(value)) + } + case (DateEncoder(true) | LocalDateEncoder(true), v: DateDayVector) => + new FieldSerializer[Any, DateDayVector](v) { + override def set(index: Int, value: Any): Unit = + vector.setSafe(index, SparkDateTimeUtils.anyToDays(value)) + } + case (DateEncoder(false), v: DateDayVector) => + new FieldSerializer[java.sql.Date, DateDayVector](v) { + override def set(index: Int, value: java.sql.Date): Unit = + vector.setSafe(index, SparkDateTimeUtils.fromJavaDate(value)) + } + case (LocalDateEncoder(false), v: DateDayVector) => + new FieldSerializer[LocalDate, DateDayVector](v) { + override def set(index: Int, value: LocalDate): Unit = + vector.setSafe(index, SparkDateTimeUtils.localDateToDays(value)) + } + case (TimestampEncoder(true) | InstantEncoder(true), v: TimeStampMicroTZVector) => + new FieldSerializer[Any, TimeStampMicroTZVector](v) { + override def set(index: Int, value: Any): Unit = + vector.setSafe(index, SparkDateTimeUtils.anyToMicros(value)) + } + case (TimestampEncoder(false), v: TimeStampMicroTZVector) => + new FieldSerializer[java.sql.Timestamp, TimeStampMicroTZVector](v) { + override def set(index: Int, value: java.sql.Timestamp): Unit = + vector.setSafe(index, SparkDateTimeUtils.fromJavaTimestamp(value)) + } + case (InstantEncoder(false), v: TimeStampMicroTZVector) => + new FieldSerializer[Instant, TimeStampMicroTZVector](v) { + override def set(index: Int, value: Instant): Unit = + vector.setSafe(index, SparkDateTimeUtils.instantToMicros(value)) + } + case (LocalDateTimeEncoder, v: TimeStampMicroVector) => + new FieldSerializer[LocalDateTime, TimeStampMicroVector](v) { + override def set(index: Int, value: LocalDateTime): Unit = + vector.setSafe(index, SparkDateTimeUtils.localDateTimeToMicros(value)) + } + + case (OptionEncoder(value), v) => + new Serializer { + private[this] val delegate: Serializer = serializerFor(value, v) + override def write(index: Int, value: Any): Unit = value match { + case Some(value) => delegate.write(index, value) + case _ => 
delegate.write(index, null) + } + } + + case (ArrayEncoder(element, _), v: ListVector) => + val elementSerializer = serializerFor(element, v.getDataVector) + val toIterator = { array: Any => + array.asInstanceOf[Array[_]].iterator + } + new ArraySerializer(v, toIterator, elementSerializer) + + case (IterableEncoder(tag, element, _, lenient), v: ListVector) => + val elementSerializer = serializerFor(element, v.getDataVector) + val toIterator: Any => Iterator[_] = if (lenient) { + { + case i: scala.collection.Iterable[_] => i.iterator + case l: java.util.List[_] => l.iterator().asScala + case a: Array[_] => a.iterator + case o => unsupportedCollectionType(o.getClass) + } + } else if (isSubClass(Classes.ITERABLE, tag)) { v => + v.asInstanceOf[scala.collection.Iterable[_]].iterator + } else if (isSubClass(Classes.JLIST, tag)) { v => + v.asInstanceOf[java.util.List[_]].iterator().asScala + } else { + unsupportedCollectionType(tag.runtimeClass) + } + new ArraySerializer(v, toIterator, elementSerializer) + + case (MapEncoder(tag, key, value, _), v: MapVector) => + val structVector = v.getDataVector.asInstanceOf[StructVector] + val extractor = if (isSubClass(classOf[scala.collection.Map[_, _]], tag)) { (v: Any) => + v.asInstanceOf[scala.collection.Map[_, _]].iterator + } else if (isSubClass(classOf[JMap[_, _]], tag)) { (v: Any) => + v.asInstanceOf[JMap[Any, Any]].asScala.iterator + } else { + unsupportedCollectionType(tag.runtimeClass) + } + val structSerializer = new StructSerializer( + structVector, + new StructFieldSerializer( + extractKey, + serializerFor(key, structVector.getChild(MapVector.KEY_NAME))) :: + new StructFieldSerializer( + extractValue, + serializerFor(value, structVector.getChild(MapVector.VALUE_NAME))) :: Nil) + new ArraySerializer(v, extractor, structSerializer) + + case (ProductEncoder(tag, fields, _), StructVectors(struct, vectors)) => + if (isSubClass(classOf[Product], tag)) { + structSerializerFor(fields, struct, vectors) { (_, i) => p => + p.asInstanceOf[Product].productElement(i) + } + } else if (isSubClass(classOf[DefinedByConstructorParams], tag)) { + structSerializerFor(fields, struct, vectors) { (field, _) => + val getter = methodLookup.findVirtual( + tag.runtimeClass, + field.name, + MethodType.methodType(field.enc.clsTag.runtimeClass)) + o => getter.invoke(o) + } + } else { + unsupportedCollectionType(tag.runtimeClass) + } + + case (RowEncoder(fields), StructVectors(struct, vectors)) => + structSerializerFor(fields, struct, vectors) { (_, i) => r => r.asInstanceOf[Row].get(i) } + + case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) => + structSerializerFor(fields, struct, vectors) { (field, _) => + val getter = methodLookup.findVirtual( + tag.runtimeClass, + field.readMethod.get, + MethodType.methodType(field.enc.clsTag.runtimeClass)) + o => getter.invoke(o) + } + + case (CalendarIntervalEncoder | _: UDTEncoder[_], _) => + // throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType) + throw new RuntimeException("Unsupported data type") + + case _ => + throw new RuntimeException(s"Unsupported Encoder($encoder)/Vector($v) combination.") + } + } + + private val methodLookup = MethodHandles.lookup() + + private def setString(vector: VarCharVector, index: Int, string: String): Unit = { + val bytes = Text.encode(string) + vector.setSafe(index, bytes, 0, bytes.limit()) + } + + private def setDecimal(vector: DecimalVector, index: Int, decimal: JBigDecimal): Unit = { + val scaledDecimal = if (vector.getScale != decimal.scale()) { + 
decimal.setScale(vector.getScale) + } else { + decimal + } + vector.setSafe(index, scaledDecimal) + } + + private def extractKey(v: Any): Any = { + val key = v.asInstanceOf[(Any, Any)]._1 + Objects.requireNonNull(key) + key + } + + private def extractValue(v: Any): Any = { + v.asInstanceOf[(Any, Any)]._2 + } + + private def structSerializerFor( + fields: Seq[EncoderField], + struct: StructVector, + vectors: Seq[FieldVector])( + createGetter: (EncoderField, Int) => Any => Any): StructSerializer = { + require(fields.size == vectors.size) + val serializers = fields.zip(vectors).zipWithIndex.map { case ((field, vector), i) => + val serializer = serializerFor(field.enc, vector) + new StructFieldSerializer(createGetter(field, i), serializer) + } + new StructSerializer(struct, serializers) + } + + abstract class Serializer { + def write(index: Int, value: Any): Unit + } + + private abstract class FieldSerializer[E, V <: FieldVector](val vector: V) extends Serializer { + def set(index: Int, value: E): Unit + + override def write(index: Int, raw: Any): Unit = { + val value = raw.asInstanceOf[E] + if (value != null) { + set(index, value) + } else { + vector.setNull(index) + } + } + } + + private class ArraySerializer( + v: ListVector, + toIterator: Any => Iterator[Any], + elementSerializer: Serializer) + extends FieldSerializer[Any, ListVector](v) { + override def set(index: Int, value: Any): Unit = { + val elementStartIndex = vector.startNewValue(index) + var elementIndex = elementStartIndex + val iterator = toIterator(value) + while (iterator.hasNext) { + elementSerializer.write(elementIndex, iterator.next()) + elementIndex += 1 + } + vector.endValue(index, elementIndex - elementStartIndex) + } + } + + private class StructFieldSerializer(val extractor: Any => Any, val serializer: Serializer) { + def write(index: Int, value: Any): Unit = serializer.write(index, extractor(value)) + def writeNull(index: Int): Unit = serializer.write(index, null) + } + + private class StructSerializer( + struct: StructVector, + fieldSerializers: Seq[StructFieldSerializer]) + extends Serializer { + + override def write(index: Int, value: Any): Unit = { + if (value == null) { + if (struct != null) { + struct.setNull(index) + } + fieldSerializers.foreach(_.writeNull(index)) + } else { + if (struct != null) { + struct.setIndexDefined(index) + } + fieldSerializers.foreach(_.write(index, value)) + } + } + } +} diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala new file mode 100644 index 00000000000..a545981312b --- /dev/null +++ b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.arrow + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.JavaConverters._ + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.complex.MapVector +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} + +import org.apache.spark.sql.errors.ExecutionErrors +import org.apache.spark.sql.types._ + +private[sql] object ArrowUtils { + + val rootAllocator = new RootAllocator(Long.MaxValue) + + // todo: support more types. + + /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ + def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = + dt match { + case BooleanType => ArrowType.Bool.INSTANCE + case ByteType => new ArrowType.Int(8, true) + case ShortType => new ArrowType.Int(8 * 2, true) + case IntegerType => new ArrowType.Int(8 * 4, true) + case LongType => new ArrowType.Int(8 * 8, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE + case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE + case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE + case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE + // TODO(paleolimbot): DecimalType.Fixed is marked private + // case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType if timeZoneId == null => + throw new IllegalStateException("Missing timezoneId where it is mandatory.") + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + case TimestampNTZType => + new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case NullType => ArrowType.Null.INSTANCE + case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) + case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) + case _ => + // throw ExecutionErrors.unsupportedDataTypeError(dt) + throw new RuntimeException("Unsupported data type") + } + + def fromArrowType(dt: ArrowType): DataType = dt match { + case ArrowType.Bool.INSTANCE => BooleanType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.SINGLE => + FloatType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.DOUBLE => + DoubleType + case ArrowType.Utf8.INSTANCE => StringType + case ArrowType.Binary.INSTANCE => BinaryType + case ArrowType.LargeUtf8.INSTANCE => StringType + case ArrowType.LargeBinary.INSTANCE => BinaryType + case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType + case ts: ArrowType.Timestamp + if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => + 
TimestampNTZType + case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType + case ArrowType.Null.INSTANCE => NullType + case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => + YearMonthIntervalType() + case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType() + // case _ => throw ExecutionErrors.unsupportedArrowTypeError(dt) + case _ => throw new RuntimeException("Unsupported arrow data type") + } + + /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */ + def toArrowField( + name: String, + dt: DataType, + nullable: Boolean, + timeZoneId: String, + largeVarTypes: Boolean = false): Field = { + dt match { + case ArrayType(elementType, containsNull) => + val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) + new Field( + name, + fieldType, + Seq( + toArrowField("element", elementType, containsNull, timeZoneId, largeVarTypes)).asJava) + case StructType(fields) => + val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) + new Field( + name, + fieldType, + fields + .map { field => + toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes) + } + .toSeq + .asJava) + case MapType(keyType, valueType, valueContainsNull) => + val mapType = new FieldType(nullable, new ArrowType.Map(false), null) + // Note: Map Type struct can not be null, Struct Type key field can not be null + new Field( + name, + mapType, + Seq( + toArrowField( + MapVector.DATA_VECTOR_NAME, + new StructType() + .add(MapVector.KEY_NAME, keyType, nullable = false) + .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), + nullable = false, + timeZoneId, + largeVarTypes)).asJava) + case udt: UserDefinedType[_] => + toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes) + case dataType => + val fieldType = + new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes), null) + new Field(name, fieldType, Seq.empty[Field].asJava) + } + } + + def fromArrowField(field: Field): DataType = { + field.getType match { + case _: ArrowType.Map => + val elementField = field.getChildren.get(0) + val keyType = fromArrowField(elementField.getChildren.get(0)) + val valueType = fromArrowField(elementField.getChildren.get(1)) + MapType(keyType, valueType, elementField.getChildren.get(1).isNullable) + case ArrowType.List.INSTANCE => + val elementField = field.getChildren().get(0) + val elementType = fromArrowField(elementField) + ArrayType(elementType, containsNull = elementField.isNullable) + case ArrowType.Struct.INSTANCE => + val fields = field.getChildren().asScala.map { child => + val dt = fromArrowField(child) + StructField(child.getName, dt, child.isNullable) + } + StructType(fields.toArray) + case arrowType => fromArrowType(arrowType) + } + } + + /** + * Maps schema from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType in StructType + */ + def toArrowSchema( + schema: StructType, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean = false): Schema = { + new Schema(schema.map { field => + toArrowField( + field.name, + deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames), + field.nullable, + timeZoneId, + largeVarTypes) + }.asJava) + } + + def fromArrowSchema(schema: Schema): StructType = { + StructType(schema.getFields.asScala.map { field => + val dt = fromArrowField(field) + StructField(field.getName, dt, field.isNullable) + }.toArray) + } + + private def deduplicateFieldNames( + dt: DataType, + errorOnDuplicatedFieldNames: Boolean): DataType = dt match { + case udt: UserDefinedType[_] => + deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames) + case st @ StructType(fields) => + val newNames = if (st.names.toSet.size == st.names.length) { + st.names + } else { + if (errorOnDuplicatedFieldNames) { + // throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names) + throw new RuntimeException("duplicated field name in arrow struct") + } + val genNawName = st.names.groupBy(identity).map { + case (name, names) if names.length > 1 => + val i = new AtomicInteger() + name -> { () => s"${name}_${i.getAndIncrement()}" } + case (name, _) => name -> { () => name } + } + st.names.map(genNawName(_)()) + } + val newFields = + fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) => + StructField( + name, + deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), + nullable, + metadata) + } + StructType(newFields) + case ArrayType(elementType, containsNull) => + ArrayType(deduplicateFieldNames(elementType, errorOnDuplicatedFieldNames), containsNull) + case MapType(keyType, valueType, valueContainsNull) => + MapType( + deduplicateFieldNames(keyType, errorOnDuplicatedFieldNames), + deduplicateFieldNames(valueType, errorOnDuplicatedFieldNames), + valueContainsNull) + case _ => dt + } +} From e0663783bed4a2b5e79fe4a6786d53ee36cb6061 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 25 Mar 2025 11:11:54 -0500 Subject: [PATCH 05/12] fix warning --- .../apache/sedona/sql/datasources/arrow/ArrowSerializer.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala index 1a0e6604b33..cd080ec3528 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala @@ -42,7 +42,6 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types.Decimal -import org.apache.spark.sql.util.ArrowUtils private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable { self => def asJava: java.util.Iterator[E] = new java.util.Iterator[E] with AutoCloseable { From 212fd87f930b7ee0f2abad4af72e2b42da7cb0a1 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 25 Mar 2025 11:23:36 -0500 Subject: [PATCH 06/12] move arrow stuff to a different directory --- .../{sql/datasources => }/arrow/ArrowEncoderUtils.scala | 2 +- .../sedona/{sql/datasources => }/arrow/ArrowSerializer.scala | 4 ++-- 
.../sedona/{sql/datasources => }/arrow/ArrowUtils.scala | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) rename spark/common/src/main/scala/org/apache/sedona/{sql/datasources => }/arrow/ArrowEncoderUtils.scala (97%) rename spark/common/src/main/scala/org/apache/sedona/{sql/datasources => }/arrow/ArrowSerializer.scala (99%) rename spark/common/src/main/scala/org/apache/sedona/{sql/datasources => }/arrow/ArrowUtils.scala (99%) diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala b/spark/common/src/main/scala/org/apache/sedona/arrow/ArrowEncoderUtils.scala similarity index 97% rename from spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala rename to spark/common/src/main/scala/org/apache/sedona/arrow/ArrowEncoderUtils.scala index aaedcab8f12..e86aa2ffe7f 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala +++ b/spark/common/src/main/scala/org/apache/sedona/arrow/ArrowEncoderUtils.scala @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.sedona.sql.datasources.arrow +package org.apache.sedona.arrow import scala.collection.JavaConverters._ import scala.reflect.ClassTag diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala b/spark/common/src/main/scala/org/apache/sedona/arrow/ArrowSerializer.scala similarity index 99% rename from spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala rename to spark/common/src/main/scala/org/apache/sedona/arrow/ArrowSerializer.scala index cd080ec3528..70019e289ac 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala +++ b/spark/common/src/main/scala/org/apache/sedona/arrow/ArrowSerializer.scala @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.sedona.sql.datasources.arrow +package org.apache.sedona.arrow import java.io.{ByteArrayOutputStream, OutputStream} import java.lang.invoke.{MethodHandles, MethodType} @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtil import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types.Decimal -private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable { self => +private[arrow] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable { self => def asJava: java.util.Iterator[E] = new java.util.Iterator[E] with AutoCloseable { override def next() = self.next() diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala b/spark/common/src/main/scala/org/apache/sedona/arrow/ArrowUtils.scala similarity index 99% rename from spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala rename to spark/common/src/main/scala/org/apache/sedona/arrow/ArrowUtils.scala index a545981312b..36cc1e727a4 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala +++ b/spark/common/src/main/scala/org/apache/sedona/arrow/ArrowUtils.scala @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ -package org.apache.sedona.sql.datasources.arrow +package org.apache.sedona.arrow import java.util.concurrent.atomic.AtomicInteger @@ -30,7 +30,7 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ -private[sql] object ArrowUtils { +private[arrow] object ArrowUtils { val rootAllocator = new RootAllocator(Long.MaxValue) From 53fcecebb6720d05044d1179927e5c2690fe69ec Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 25 Mar 2025 13:12:39 -0500 Subject: [PATCH 07/12] move ArrowSerializer to spark/spark-3.5 --- spark/common/pom.xml | 15 --------------- .../sedona/sql/datasources/arrow/ArrowWriter.java | 2 +- spark/spark-3.5/pom.xml | 6 ++++++ .../apache/sedona/arrow/ArrowEncoderUtils.scala | 0 .../org/apache/sedona/arrow/ArrowSerializer.scala | 0 .../org/apache/sedona/arrow/ArrowUtils.scala | 0 6 files changed, 7 insertions(+), 16 deletions(-) rename spark/{common => spark-3.5}/src/main/scala/org/apache/sedona/arrow/ArrowEncoderUtils.scala (100%) rename spark/{common => spark-3.5}/src/main/scala/org/apache/sedona/arrow/ArrowSerializer.scala (100%) rename spark/{common => spark-3.5}/src/main/scala/org/apache/sedona/arrow/ArrowUtils.scala (100%) diff --git a/spark/common/pom.xml b/spark/common/pom.xml index 30beeea16ba..c5962ad0c84 100644 --- a/spark/common/pom.xml +++ b/spark/common/pom.xml @@ -37,21 +37,6 @@ - - - org.apache.arrow - arrow-java-root - 18.2.0 - pom - - - - - org.apache.spark - spark-connect-common_${scala.compat.version} - ${spark.version} - - org.apache.sedona sedona-common diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java index e2dabc6b394..5ad1151ae30 100644 --- a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java +++ b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java @@ -20,12 +20,12 @@ import java.io.IOException; import org.apache.arrow.memory.RootAllocator; +import org.apache.sedona.arrow.ArrowSerializer; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder; import org.apache.spark.sql.catalyst.encoders.RowEncoder; -import org.apache.spark.sql.connect.client.arrow.ArrowSerializer; import org.apache.spark.sql.connector.write.DataWriter; import org.apache.spark.sql.connector.write.WriterCommitMessage; import org.apache.spark.sql.types.StructType; diff --git a/spark/spark-3.5/pom.xml b/spark/spark-3.5/pom.xml index 59a8cf7c82d..d1a616944a1 100644 --- a/spark/spark-3.5/pom.xml +++ b/spark/spark-3.5/pom.xml @@ -48,6 +48,12 @@ + + org.apache.arrow + arrow-java-root + 18.2.0 + pom + org.apache.sedona sedona-spark-common-${spark.compat.version}_${scala.compat.version} diff --git a/spark/common/src/main/scala/org/apache/sedona/arrow/ArrowEncoderUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowEncoderUtils.scala similarity index 100% rename from spark/common/src/main/scala/org/apache/sedona/arrow/ArrowEncoderUtils.scala rename to spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowEncoderUtils.scala diff --git a/spark/common/src/main/scala/org/apache/sedona/arrow/ArrowSerializer.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowSerializer.scala similarity index 100% rename from 
spark/common/src/main/scala/org/apache/sedona/arrow/ArrowSerializer.scala rename to spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowSerializer.scala diff --git a/spark/common/src/main/scala/org/apache/sedona/arrow/ArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowUtils.scala similarity index 100% rename from spark/common/src/main/scala/org/apache/sedona/arrow/ArrowUtils.scala rename to spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowUtils.scala From 80eb0f572ee12e1a031d3b0d873482655a725600 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 25 Mar 2025 14:33:24 -0500 Subject: [PATCH 08/12] port stubs to Scala --- .../datasources/arrow/ArrowBatchWrite.java | 50 ----------- .../sql/datasources/arrow/ArrowTable.java | 70 --------------- .../sql/datasources/arrow/ArrowWriter.java | 86 ------------------- .../datasources/arrow/ArrowBatchWrite.scala} | 31 +++---- .../datasources/arrow/ArrowDataSource.scala | 43 ++++++++++ .../arrow/ArrowDataWriterFactory.scala} | 25 +++--- .../sql/datasources/arrow/ArrowTable.scala | 61 +++++++++++++ .../arrow/ArrowWriteBuilder.scala} | 21 ++--- .../sql/datasources/arrow/ArrowWriter.scala | 43 ++++++++++ 9 files changed, 179 insertions(+), 251 deletions(-) delete mode 100644 spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.java delete mode 100644 spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTable.java delete mode 100644 spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java rename spark/{common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTableProvider.java => spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala} (51%) create mode 100644 spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataSource.scala rename spark/{common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.java => spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.scala} (56%) create mode 100644 spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowTable.scala rename spark/{common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java => spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.scala} (61%) create mode 100644 spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.java b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.java deleted file mode 100644 index 25d66747885..00000000000 --- a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.sedona.sql.datasources.arrow; - -import org.apache.spark.sql.connector.write.BatchWrite; -import org.apache.spark.sql.connector.write.DataWriterFactory; -import org.apache.spark.sql.connector.write.LogicalWriteInfo; -import org.apache.spark.sql.connector.write.PhysicalWriteInfo; -import org.apache.spark.sql.connector.write.WriterCommitMessage; - -class ArrowBatchWrite implements BatchWrite { - private final LogicalWriteInfo logicalWriteInfo; - - public ArrowBatchWrite(LogicalWriteInfo info) { - this.logicalWriteInfo = info; - } - - @Override - public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { - return new ArrowDataWriterFactory(logicalWriteInfo.schema()); - } - - @Override - public void commit(WriterCommitMessage[] messages) { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'commit'"); - } - - @Override - public void abort(WriterCommitMessage[] messages) { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'abort'"); - } -} diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTable.java b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTable.java deleted file mode 100644 index 5c6463c05b0..00000000000 --- a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTable.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package org.apache.sedona.sql.datasources.arrow; - -import java.util.HashSet; -import java.util.Set; -import org.apache.spark.sql.connector.catalog.SupportsWrite; -import org.apache.spark.sql.connector.catalog.TableCapability; -import org.apache.spark.sql.connector.write.LogicalWriteInfo; -import org.apache.spark.sql.connector.write.SupportsOverwrite; -import org.apache.spark.sql.connector.write.WriteBuilder; -import org.apache.spark.sql.sources.Filter; -import org.apache.spark.sql.types.StructType; - -class ArrowTable implements SupportsWrite, SupportsOverwrite { - private Set capabilities; - private StructType schema; - - ArrowTable(StructType schema) { - this.schema = schema; - } - - @Override - public String name() { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'name'"); - } - - @Override - public StructType schema() { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'schema'"); - } - - @Override - public Set capabilities() { - if (capabilities == null) { - this.capabilities = new HashSet<>(); - capabilities.add(TableCapability.BATCH_WRITE); - } - return capabilities; - } - - @Override - public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { - return new ArrowWriteBuilder(info); - } - - @Override - public WriteBuilder overwrite(Filter[] filters) { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'overwrite'"); - } -} diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java b/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java deleted file mode 100644 index 5ad1151ae30..00000000000 --- a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriter.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package org.apache.sedona.sql.datasources.arrow; - -import java.io.IOException; -import org.apache.arrow.memory.RootAllocator; -import org.apache.sedona.arrow.ArrowSerializer; -import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder; -import org.apache.spark.sql.catalyst.encoders.RowEncoder; -import org.apache.spark.sql.connector.write.DataWriter; -import org.apache.spark.sql.connector.write.WriterCommitMessage; -import org.apache.spark.sql.types.StructType; - -class ArrowWriter implements DataWriter { - private final int partitionId; - private final long taskId; - private int rowCount; - private AgnosticEncoder encoder; - private Encoder rowEncoder; - // https://github.com/apache/spark/blob/9353e94e50f3f73565f5f0023effd7e265c177b9/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala#L50 - private ArrowSerializer serializer; - - public ArrowWriter(int partitionId, long taskId, StructType schema) { - this.partitionId = partitionId; - this.taskId = taskId; - this.rowCount = 0; - this.encoder = RowEncoder.encoderFor(schema); - this.serializer = new ArrowSerializer(encoder, new RootAllocator(), "UTC"); - - // Create file, write schema - // Problem: ArrowSerializer() does not expose internal to write just the schema - // bytes. - } - - @Override - public void close() throws IOException { - // Close file - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'close'"); - } - - @Override - public void write(InternalRow record) throws IOException { - // Problem: serializer needs a Row but we have an InternalRow - // serializer.append(encoder.fromRow(record)); - - rowCount++; - if (rowCount > 1024) { - // Problem: writeIpcStream() writes both the schema and the batch, but - // we only want the batch - // serializer.writeIpcStream(null); - rowCount = 0; - } - } - - @Override - public WriterCommitMessage commit() throws IOException { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'commit'"); - } - - @Override - public void abort() throws IOException { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'abort'"); - } -} diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTableProvider.java b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala similarity index 51% rename from spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTableProvider.java rename to spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala index 5f4d0bc1231..4a3fc1428d4 100644 --- a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowTableProvider.java +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala @@ -16,26 +16,23 @@ * specific language governing permissions and limitations * under the License. 
*/ -package org.apache.sedona.sql.datasources.arrow; +package org.apache.sedona.sql.datasources.arrow -import java.util.Map; -import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; -import org.apache.spark.sql.connector.expressions.Transform; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.apache.spark.sql.connector.write.BatchWrite +import org.apache.spark.sql.connector.write.{DataWriterFactory, PhysicalWriteInfo} +import org.apache.spark.sql.connector.write.WriterCommitMessage +import org.apache.spark.sql.connector.write.LogicalWriteInfo +import org.apache.spark.sql.catalyst.InternalRow -class ArrowTableProvider implements TableProvider { - - @Override - public StructType inferSchema(CaseInsensitiveStringMap options) { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'inferSchema'"); +case class ArrowBatchWrite(logicalInfo: LogicalWriteInfo) extends BatchWrite { + def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { + return new ArrowDataWriterFactory(logicalInfo, info) } - @Override - public Table getTable( - StructType schema, Transform[] partitioning, Map properties) { - return new ArrowTable(schema); + def commit(messages: Array[WriterCommitMessage]): Unit = { + } + + def abort(messages: Array[WriterCommitMessage]): Unit = {} + } diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataSource.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataSource.scala new file mode 100644 index 00000000000..20412eec615 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataSource.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.arrow + +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.jdk.CollectionConverters._ +import scala.util.Try +import org.apache.sedona.sql.datasources.geopackage.ArrowTable + +class ArrowDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def fallbackFileFormat: Class[_ <: FileFormat] = { + null + } + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + ArrowTable("", sparkSession, options, getPaths(options), None) + } + + override def shortName(): String = "arrows" +} diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.java b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.scala similarity index 56% rename from spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.java rename to spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.scala index 31c99a1f63b..ad8c631709c 100644 --- a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.java +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowDataWriterFactory.scala @@ -16,22 +16,17 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.sedona.sql.datasources.arrow; +package org.apache.sedona.sql.datasources.arrow -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.connector.write.DataWriter; -import org.apache.spark.sql.connector.write.DataWriterFactory; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.connector.write.LogicalWriteInfo +import org.apache.spark.sql.connector.write.PhysicalWriteInfo +import org.apache.spark.sql.connector.write.DataWriterFactory +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write.DataWriter -class ArrowDataWriterFactory implements DataWriterFactory { - private final StructType schema; - - public ArrowDataWriterFactory(StructType schema) { - this.schema = schema; - } - - @Override - public DataWriter createWriter(int partitionId, long taskId) { - return new ArrowWriter(partitionId, taskId, schema); +case class ArrowDataWriterFactory(logicalInfo: LogicalWriteInfo, physicalInfo: PhysicalWriteInfo) + extends DataWriterFactory { + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + return new ArrowWriter(logicalInfo, physicalInfo, partitionId, taskId) } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowTable.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowTable.scala new file mode 100644 index 00000000000..49421a5d313 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowTable.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.geopackage + +import org.apache.hadoop.fs.FileStatus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import scala.jdk.CollectionConverters._ +import org.apache.sedona.sql.datasources.arrow.ArrowWriteBuilder + +case class ArrowTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + null + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + new ArrowWriteBuilder(info) + } + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + None + } + + override def formatName: String = { + "Arrow Stream" + } + + override def fallbackFileFormat: Class[_ <: FileFormat] = { + null + } + +} diff --git a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.scala similarity index 61% rename from spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java rename to spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.scala index 19061f433c1..5efee38920c 100644 --- a/spark/common/src/main/java/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.java +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriteBuilder.scala @@ -16,22 +16,17 @@ * specific language governing permissions and limitations * under the License. 
*/ -package org.apache.sedona.sql.datasources.arrow; +package org.apache.sedona.sql.datasources.arrow -import org.apache.spark.sql.connector.write.BatchWrite; -import org.apache.spark.sql.connector.write.LogicalWriteInfo; -import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.connector.write.WriteBuilder +import org.apache.spark.sql.connector.write.Write +import org.apache.spark.sql.connector.write.BatchWrite +import org.apache.spark.sql.connector.write.LogicalWriteInfo -class ArrowWriteBuilder implements WriteBuilder { +case class ArrowWriteBuilder(info: LogicalWriteInfo) extends WriteBuilder { - private final LogicalWriteInfo writeInfo; - - public ArrowWriteBuilder(LogicalWriteInfo writeInfo) { - this.writeInfo = writeInfo; + override def buildForBatch(): BatchWrite = { + new ArrowBatchWrite(info) } - @Override - public BatchWrite buildForBatch() { - return new ArrowBatchWrite(writeInfo); - } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala new file mode 100644 index 00000000000..00e46898e60 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.arrow + +import org.apache.spark.sql.connector.write.LogicalWriteInfo +import org.apache.spark.sql.connector.write.PhysicalWriteInfo +import org.apache.spark.sql.connector.write.DataWriter +import org.apache.spark.sql.connector.write.WriterCommitMessage +import org.apache.spark.sql.catalyst.InternalRow + +case class ArrowWriter( + logicalInfo: LogicalWriteInfo, + physicalInfo: PhysicalWriteInfo, + partitionId: Int, + taskId: Long) + extends DataWriter[InternalRow] { + def write(record: InternalRow): Unit = {} + + def commit(): WriterCommitMessage = { + null + } + + def abort(): Unit = {} + + def close(): Unit = {} + +} From 88aff3209f5c8c8f0dd58a05948efb836f2830fa Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 25 Mar 2025 14:39:33 -0500 Subject: [PATCH 09/12] format --- .../apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala index 4a3fc1428d4..750e99ee669 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowBatchWrite.scala @@ -29,9 +29,7 @@ case class ArrowBatchWrite(logicalInfo: LogicalWriteInfo) extends BatchWrite { return new ArrowDataWriterFactory(logicalInfo, info) } - def commit(messages: Array[WriterCommitMessage]): Unit = { - - } + def commit(messages: Array[WriterCommitMessage]): Unit = {} def abort(messages: Array[WriterCommitMessage]): Unit = {} From 8993eb489e8eb6d1b69fdd39f8f13265d012e9f4 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 25 Mar 2025 15:11:25 -0500 Subject: [PATCH 10/12] move everything to the same folder --- .../sedona/{ => sql/datasources}/arrow/ArrowEncoderUtils.scala | 2 +- .../sedona/{ => sql/datasources}/arrow/ArrowSerializer.scala | 2 +- .../apache/sedona/{ => sql/datasources}/arrow/ArrowUtils.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename spark/spark-3.5/src/main/scala/org/apache/sedona/{ => sql/datasources}/arrow/ArrowEncoderUtils.scala (97%) rename spark/spark-3.5/src/main/scala/org/apache/sedona/{ => sql/datasources}/arrow/ArrowSerializer.scala (99%) rename spark/spark-3.5/src/main/scala/org/apache/sedona/{ => sql/datasources}/arrow/ArrowUtils.scala (99%) diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowEncoderUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala similarity index 97% rename from spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowEncoderUtils.scala rename to spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala index e86aa2ffe7f..aaedcab8f12 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowEncoderUtils.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowEncoderUtils.scala @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ -package org.apache.sedona.arrow +package org.apache.sedona.sql.datasources.arrow import scala.collection.JavaConverters._ import scala.reflect.ClassTag diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowSerializer.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala similarity index 99% rename from spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowSerializer.scala rename to spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala index 70019e289ac..1500e635bae 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowSerializer.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.sedona.arrow +package org.apache.sedona.sql.datasources.arrow import java.io.{ByteArrayOutputStream, OutputStream} import java.lang.invoke.{MethodHandles, MethodType} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala similarity index 99% rename from spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowUtils.scala rename to spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala index 36cc1e727a4..60006412873 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/arrow/ArrowUtils.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowUtils.scala @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.sedona.arrow +package org.apache.sedona.sql.datasources.arrow import java.util.concurrent.atomic.AtomicInteger From 30740a9f293c281c429b3ff92ddc98c62052077e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 25 Mar 2025 15:36:08 -0500 Subject: [PATCH 11/12] construct a serializer --- .../datasources/arrow/ArrowSerializer.scala | 30 +++++++++++++++++++ .../sql/datasources/arrow/ArrowWriter.scala | 29 ++++++++++++++---- 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala index 1500e635bae..f97d3d27c17 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowSerializer.scala @@ -124,6 +124,36 @@ class ArrowSerializer[T]( } } + /** + * Write the schema in Arrow IPC stream format to the [[OutputStream]]. + */ + def writeSchema(output: OutputStream): Unit = { + val channel = newChannel(output) + channel.write(schemaBytes) + } + + /** + * Write current batch in Arrow IPC stream format to the [[OutputStream]]. + */ + def writeBatch(output: OutputStream): Unit = { + val channel = newChannel(output) + root.setRowCount(rowCount) + val batch = unloader.getRecordBatch + try { + MessageSerializer.serialize(channel, batch) + } finally { + batch.close() + } + } + + /** + * Write the end-of-stream bytes to [[OutputStream]]. + */ + def writeEndOfStream(output: OutputStream): Unit = { + val channel = newChannel(output) + ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT) + } + /** * Reset the serializer. 
*/ diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala index 00e46898e60..edc47802077 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala @@ -21,15 +21,32 @@ package org.apache.sedona.sql.datasources.arrow import org.apache.spark.sql.connector.write.LogicalWriteInfo import org.apache.spark.sql.connector.write.PhysicalWriteInfo import org.apache.spark.sql.connector.write.DataWriter +import org.apache.arrow.memory.RootAllocator; import org.apache.spark.sql.connector.write.WriterCommitMessage import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.Row; +import java.io.ByteArrayOutputStream +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder + +case class ArrowWriter() extends DataWriter[InternalRow] { + + private var encoder: AgnosticEncoder[Row] = _ + private var serializer: ArrowSerializer[Row] = _ + private var dummyOutput: ByteArrayOutputStream = _ + + def this( + logicalInfo: LogicalWriteInfo, + physicalInfo: PhysicalWriteInfo, + partitionId: Int, + taskId: Long) { + this() + dummyOutput = new ByteArrayOutputStream() + encoder = RowEncoder.encoderFor(logicalInfo.schema()) + serializer = new ArrowSerializer[Row](encoder, new RootAllocator(), "UTC"); + } -case class ArrowWriter( - logicalInfo: LogicalWriteInfo, - physicalInfo: PhysicalWriteInfo, - partitionId: Int, - taskId: Long) - extends DataWriter[InternalRow] { def write(record: InternalRow): Unit = {} def commit(): WriterCommitMessage = { From a2968a87156cf72ffe80899dec90d395bf8a6c3f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 25 Mar 2025 16:04:07 -0500 Subject: [PATCH 12/12] mock write --- .../sql/datasources/arrow/ArrowWriter.scala | 38 ++++++++++++++++++- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala index edc47802077..ee1d84af1e2 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/arrow/ArrowWriter.scala @@ -29,12 +29,16 @@ import java.io.ByteArrayOutputStream import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder case class ArrowWriter() extends DataWriter[InternalRow] { private var encoder: AgnosticEncoder[Row] = _ + private var rowDeserializer: ExpressionEncoder.Deserializer[Row] = _ private var serializer: ArrowSerializer[Row] = _ private var dummyOutput: ByteArrayOutputStream = _ + private var rowCount: Long = 0 def this( logicalInfo: LogicalWriteInfo, @@ -44,10 +48,35 @@ case class ArrowWriter() extends DataWriter[InternalRow] { this() dummyOutput = new ByteArrayOutputStream() encoder = RowEncoder.encoderFor(logicalInfo.schema()) + rowDeserializer = Encoders + .row(logicalInfo.schema()) + .asInstanceOf[ExpressionEncoder[Row]] + .resolveAndBind() + .createDeserializer() serializer = new 
ArrowSerializer[Row](encoder, new RootAllocator(), "UTC"); + + serializer.writeSchema(dummyOutput) + } + + def write(record: InternalRow): Unit = { + serializer.append(rowDeserializer.apply(record)) + rowCount = rowCount + 1 + if (shouldFlush()) { + flush() + } } - def write(record: InternalRow): Unit = {} + private def shouldFlush(): Boolean = { + // Can use serializer.sizeInBytes() to parameterize batch size in terms of bytes + // or just use rowCount (batches of ~1024 rows are common and 16 MB is also a common + // threshold (maybe also applying a minimum row count in case we're dealing with big + // geometries). Checking sizeInBytes() should be done sparingly (expensive)). + rowCount >= 1024 + } + + private def flush(): Unit = { + serializer.writeBatch(dummyOutput) + } def commit(): WriterCommitMessage = { null @@ -55,6 +84,11 @@ case class ArrowWriter() extends DataWriter[InternalRow] { def abort(): Unit = {} - def close(): Unit = {} + def close(): Unit = { + flush() + serializer.writeEndOfStream(dummyOutput) + serializer.close() + dummyOutput.close() + } }
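
Editor's note on exercising the writer: the patches above register an "arrows" short name via ArrowDataSource (patch 08) and connect ArrowWriter to ArrowSerializer (patches 11-12), so the intended entry point from Spark looks roughly like the sketch below. This is a hedged illustration only: the output path is hypothetical, and as of patch 12 the serialized schema, record batches, and end-of-stream marker still go to an in-memory ByteArrayOutputStream (dummyOutput) rather than to a file, so the save() call would not yet produce Arrow IPC bytes on disk.

    // Sketch only; assumes the Sedona spark-3.5 artifact containing
    // org.apache.sedona.sql.datasources.arrow.ArrowDataSource is on the classpath
    // and that file output is wired up in a later change.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("arrow-writer-sketch")
      .getOrCreate()

    // Any DataFrame works for the stub; a batch is flushed every 1024 rows,
    // matching the shouldFlush() threshold introduced in patch 12.
    val df = spark.range(0, 4096).toDF("id")

    df.write
      .format("arrows")          // short name registered by ArrowDataSource
      .save("/tmp/ids.arrows")   // hypothetical path; not yet written to disk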