diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 07423b612c30..1d5ab0b94230 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import org.apache.spark.Partitioner
 import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.{catalyst, Encoder, Row}
@@ -398,6 +399,42 @@ case class AppendColumnsWithObject(
     copy(child = newChild)
 }
 
+/** Factory for constructing new `RepartitionByPartitioner` nodes. */
+object RepartitionByPartitioner {
+  def apply[T: Encoder, K: Encoder](
+      keyFunc: T => K,
+      partitioner: Partitioner,
+      child: LogicalPlan): RepartitionByPartitioner = {
+    new RepartitionByPartitioner(
+      keyFunc.asInstanceOf[Any => Any],
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
+      UnresolvedDeserializer(encoderFor[T].deserializer),
+      partitioner,
+      child)
+  }
+}
+
+/**
+ * Repartitions the input using a custom [[Partitioner]] by applying a key extraction function
+ * to each row.
+ */
+case class RepartitionByPartitioner(
+    keyFunc: Any => Any,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
+    deserializer: Expression,
+    partitioner: Partitioner,
+    child: LogicalPlan) extends UnaryNode {
+
+  override def output: Seq[Attribute] = child.output
+
+  override def maxRows: Option[Long] = child.maxRows
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): RepartitionByPartitioner =
+    copy(child = newChild)
+}
+
 /** Factory for constructing new `MapGroups` nodes. */
 object MapGroups {
   def apply[K : Encoder, T : Encoder, U : Encoder](
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index b0fa4f889cda..6658fa12a020 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.physical
 import scala.annotation.tailrec
 import scala.collection.mutable
 
-import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
+import org.apache.spark.{Partitioner, SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
@@ -676,6 +676,29 @@ case class ShufflePartitionIdPassThrough(
     copy(expr = newChildren.head.asInstanceOf[DirectShufflePartitionID])
 }
 
+/**
+ * Represents a partitioning where rows are distributed using a custom [[Partitioner]].
+ *
+ * Each row is deserialized and the key extraction function is applied to the resulting object
+ * to produce a key, which is then passed to the partitioner to determine the target partition.
+ */
+case class CustomFunctionPartitioning(
+    keyFunc: Any => Any,
+    deserializer: Expression,
+    partitioner: Partitioner,
+    outputAttributes: Seq[Attribute]) extends Partitioning {
+
+  override val numPartitions: Int = partitioner.numPartitions
+
+  // Cannot satisfy ClusteredDistribution because we don't know the semantics of the
+  // user-provided partitioner (e.g., it may not guarantee co-location of same keys).
+  override def satisfies0(required: Distribution): Boolean = required match {
+    case UnspecifiedDistribution => true
+    case AllTuples => numPartitions == 1
+    case _ => false
+  }
+}
+
 trait ShuffleSpec {
   /**
    * Returns the number of partitions of this shuffle spec
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index d02b63b49ca5..739c04d15cde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -28,7 +28,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.text.StringEscapeUtils
 
-import org.apache.spark.{sql, SparkException, TaskContext}
+import org.apache.spark.{sql, Partitioner, SparkException, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
@@ -1549,6 +1549,27 @@ class Dataset[T] private[sql](
     repartitionByExpression(Some(numPartitions), Seq(directShufflePartitionIdCol))
   }
 
+  /**
+   * Returns a new Dataset partitioned using the specified partitioner.
+   *
+   * This is similar to RDD's `partitionBy` method. The key extraction function is applied to each
+   * element to extract the key, which is then passed to the partitioner to determine the target
+   * partition.
+   *
+   * {{{
+   *   // Repartition using a custom partitioner
+   *   ds.repartition[String](_.userId, new HashPartitioner(100))
+   * }}}
+   *
+   * @group typedrel
+   * @since 4.1.0
+   */
+  def repartition[K: Encoder](
+      keyFunc: T => K,
+      partitioner: Partitioner): Dataset[T] = withSameTypedPlan {
+    RepartitionByPartitioner(keyFunc, partitioner, logicalPlan)
+  }
+
   protected def repartitionByRange(
       numPartitions: Option[Int],
       partitionExprs: Seq[Column]): Dataset[T] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5efad83bcba7..99eab0de3eae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.physical.CustomFunctionPartitioning
 import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -1027,6 +1028,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         } else {
           execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
         }
+      case r: logical.RepartitionByPartitioner =>
+        val partitioning = CustomFunctionPartitioning(
+          r.keyFunc,
+          r.deserializer,
+          r.partitioner,
+          r.child.output)
+        ShuffleExchangeExec(partitioning, planLater(r.child), REPARTITION_BY_NUM) :: Nil
       case logical.Sort(sortExprs, global, child, _) =>
         execution.SortExec(sortExprs, global, planLater(child)) :: Nil
       case logical.Project(projectList, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index f052bd906880..38a2eb1034dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -375,6 +375,7 @@ object ShuffleExchangeExec {
           case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index)
         }.toMap
         new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), n)
+      case c: CustomFunctionPartitioning => c.partitioner
       case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning")
       // TODO: Handle BroadcastPartitioning.
     }
@@ -408,6 +409,9 @@ object ShuffleExchangeExec {
         // If the value is null, `InternalRow#getInt` returns 0.
         val projection = UnsafeProjection.create(s.partitionIdExpression :: Nil, outputAttributes)
         row => projection(row).getInt(0)
+      case c: CustomFunctionPartitioning =>
+        val getObject = ObjectOperator.deserializeRowToObject(c.deserializer, c.outputAttributes)
+        row => c.keyFunc(getObject(row))
       case _ =>
         throw SparkException.internalError(s"Exchange not implemented for $newPartitioning")
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index b926cc192bd6..5609c9561d9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.SparkUnsupportedOperationException
+import org.apache.spark.{HashPartitioner, Partitioner, SparkUnsupportedOperationException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, DataFrame, Row}
 import org.apache.spark.sql.AnalysisException
@@ -1584,6 +1584,96 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
     // The groupBy reuse the output partitioning after DirectShufflePartitionID.
     checkShuffleCount(grouped, 3)
   }
+
+  test("SPARK-27853: repartition with custom Partitioner using HashPartitioner") {
+    val ds = spark.range(100).as[Long]
+    val numPartitions = 10
+    val partitioner = new HashPartitioner(numPartitions)
+
+    val repartitioned = ds.repartition[Long](identity, partitioner)
+
+    assert(repartitioned.rdd.getNumPartitions == numPartitions)
+    assert(repartitioned.count() == 100)
+
+    val result = repartitioned.withColumn("partition_id", spark_partition_id()).collect()
+    result.foreach { row =>
+      val value = row.getAs[Long]("id")
+      val actualPartition = row.getAs[Int]("partition_id")
+      val expectedPartition = (value.hashCode() % numPartitions + numPartitions) % numPartitions
+      assert(actualPartition == expectedPartition,
+        s"Value $value should be in partition $expectedPartition but was in $actualPartition")
+    }
+  }
+
+  test("SPARK-27853: repartition with custom Partitioner and key extraction function") {
+    val ds = Seq(
+      ("Alice", 25),
+      ("Bob", 30),
+      ("Charlie", 25),
+      ("David", 30)
+    ).toDF("name", "age").as[(String, Int)]
+
+    val partitioner = new HashPartitioner(5)
+    val repartitioned = ds.repartition[Int](_._2, partitioner)
+
+    assert(repartitioned.rdd.getNumPartitions == 5)
+    assert(repartitioned.count() == 4)
+
+    val result = repartitioned.withColumn("partition_id", spark_partition_id()).collect()
+    val alicePartition = result.find(_.getAs[String]("name") == "Alice")
+      .get.getAs[Int]("partition_id")
+    val charliePartition = result.find(_.getAs[String]("name") == "Charlie")
+      .get.getAs[Int]("partition_id")
+    assert(alicePartition == charliePartition)
+
+    val bobPartition = result.find(_.getAs[String]("name") == "Bob")
+      .get.getAs[Int]("partition_id")
+    val davidPartition = result.find(_.getAs[String]("name") == "David")
+      .get.getAs[Int]("partition_id")
+    assert(bobPartition == davidPartition)
+  }
+
+  test("SPARK-27853: repartition with user-defined Partitioner") {
+    class EvenOddPartitioner extends Partitioner {
+      override def numPartitions: Int = 2
+      override def getPartition(key: Any): Int = key match {
+        case l: Long => if (l % 2 == 0) 0 else 1
+        case _ => 0
+      }
+    }
+
+    val ds = spark.range(20).as[Long]
+    val repartitioned = ds.repartition[Long](identity, new EvenOddPartitioner())
+
+    assert(repartitioned.rdd.getNumPartitions == 2)
+
+    val result = repartitioned.withColumn("partition_id", spark_partition_id()).collect()
+    result.foreach { row =>
+      val value = row.getAs[Long]("id")
+      val partition = row.getAs[Int]("partition_id")
+      val expected = if (value % 2 == 0) 0 else 1
+      assert(partition == expected,
+        s"Value $value should be in partition $expected but was in $partition")
+    }
+  }
+
+  test("SPARK-27853: repartition with custom Partitioner introduces shuffle") {
+    val ds = spark.range(100).as[Long]
+    val repartitioned = ds.repartition[Long](identity, new HashPartitioner(10))
+
+    val shuffles = collect(repartitioned.queryExecution.executedPlan) {
+      case s: ShuffleExchangeLike => s
+    }
+    assert(shuffles.nonEmpty)
+  }
+
+  test("SPARK-27853: repartition with custom Partitioner preserves data") {
+    val original = (1 to 1000).map(i => (i, s"value_$i")).toDF("id", "value").as[(Int, String)]
+
+    val repartitioned = original.repartition[Int](_._1, new HashPartitioner(20))
+
+    assert(original.collect().toSet == repartitioned.collect().toSet)
+  }
 }
 
 // Used for unit-testing EnsureRequirements
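
Usage sketch for the new API: a minimal illustration, not part of the patch, written in the style of the PlannerSuite tests above. It assumes a classic SparkSession named `spark` is in scope (as in those tests, since the `repartition(keyFunc, partitioner)` overload is added to `org.apache.spark.sql.classic.Dataset`); the `EvenOddPartitioner` is copied from the test above, and the keyed-tuple data and partition count of 8 are illustrative assumptions only.

import org.apache.spark.{HashPartitioner, Partitioner}
import org.apache.spark.sql.functions.spark_partition_id

// Assumes `spark` (a classic SparkSession) is already in scope, as in the tests.
import spark.implicits._

// Even keys go to partition 0, odd keys to partition 1 (same as the test above).
class EvenOddPartitioner extends Partitioner {
  override def numPartitions: Int = 2
  override def getPartition(key: Any): Int = key match {
    case l: Long => if (l % 2 == 0) 0 else 1
    case _ => 0
  }
}

// Identity key extraction: partition directly on the element value.
val evenOdd = spark.range(0, 100).as[Long]
  .repartition[Long](identity, new EvenOddPartitioner())

// Field-based key extraction: co-locate rows that share a key by hashing the
// first tuple element with a stock HashPartitioner.
val byUser = Seq(("alice", 1L), ("bob", 2L), ("alice", 3L)).toDS()
  .repartition[String](_._1, new HashPartitioner(8))

// Inspect where each row landed.
evenOdd.withColumn("partition_id", spark_partition_id()).show()
byUser.withColumn("partition_id", spark_partition_id()).show()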