@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.Partitioner
import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{catalyst, Encoder, Row}
@@ -398,6 +399,42 @@ case class AppendColumnsWithObject(
copy(child = newChild)
}

/** Factory for constructing new `RepartitionByPartitioner` nodes. */
object RepartitionByPartitioner {
def apply[T: Encoder, K: Encoder](
keyFunc: T => K,
partitioner: Partitioner,
child: LogicalPlan): RepartitionByPartitioner = {
new RepartitionByPartitioner(
keyFunc.asInstanceOf[Any => Any],
implicitly[Encoder[T]].clsTag.runtimeClass,
implicitly[Encoder[T]].schema,
UnresolvedDeserializer(encoderFor[T].deserializer),
partitioner,
child)
}
}

/**
* Repartitions the input using a custom [[Partitioner]] by applying a key extraction function
* to each row.
*/
case class RepartitionByPartitioner(
keyFunc: Any => Any,
argumentClass: Class[_],
argumentSchema: StructType,
deserializer: Expression,
partitioner: Partitioner,
child: LogicalPlan) extends UnaryNode {

override def output: Seq[Attribute] = child.output

override def maxRows: Option[Long] = child.maxRows

override protected def withNewChildInternal(newChild: LogicalPlan): RepartitionByPartitioner =
copy(child = newChild)
}

/** Factory for constructing new `MapGroups` nodes. */
object MapGroups {
def apply[K : Encoder, T : Encoder, U : Encoder](
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.physical
import scala.annotation.tailrec
import scala.collection.mutable

import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.{Partitioner, SparkException, SparkUnsupportedOperationException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
@@ -676,6 +676,29 @@ case class ShufflePartitionIdPassThrough(
copy(expr = newChildren.head.asInstanceOf[DirectShufflePartitionID])
}

/**
* Represents a partitioning where rows are distributed using a custom [[Partitioner]].
*
* Each row is deserialized into its object form, the key extraction function is applied to
* that object, and the resulting key is passed to the partitioner to determine the target
* partition.
*/
case class CustomFunctionPartitioning(
Review thread:

Member:
Isn't #52153 enough to cover the custom partition case?

Author:
Thanks for flagging this; I wasn't aware of #52153 when I put this together. Just read through it.

It looks like repartitionById covers cases where the partitioning logic can be expressed as a column expression, which handles a lot of use cases cleanly.

The gap I was thinking about is reusing existing Partitioner implementations from RDD codebases, or cases where the logic is complex enough that encapsulating it in a testable class is preferable to inline expressions. But I can see an argument that those are niche enough that repartitionById is sufficient.

Curious whether there's appetite for supporting both patterns, or whether the consensus is that this isn't needed. Happy to close if so.

keyFunc: Any => Any,
deserializer: Expression,
partitioner: Partitioner,
outputAttributes: Seq[Attribute]) extends Partitioning {

override val numPartitions: Int = partitioner.numPartitions

// Cannot satisfy ClusteredDistribution because we don't know the semantics of the
// user-provided partitioner (e.g., it may not co-locate rows that share the same key).
override def satisfies0(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case AllTuples => numPartitions == 1
case _ => false
}
}
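
Because satisfies0 only reports UnspecifiedDistribution (and AllTuples when there is a single partition) as satisfied, an operator that requires ClusteredDistribution, such as an aggregation on the same key, will still get its own exchange planned on top of the custom one. A minimal illustrative sketch of that behavior, assuming a SparkSession named `spark` and the `repartition` overload added later in this PR:

import org.apache.spark.HashPartitioner
import spark.implicits._

val ds = spark.range(100).as[Long]
val custom = ds.repartition[Long](identity, new HashPartitioner(10))

// The aggregation below requires ClusteredDistribution on "id", which
// CustomFunctionPartitioning does not report as satisfied, so the physical plan
// is expected to contain a second exchange on top of the custom one.
custom.groupBy("id").count().explain()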

trait ShuffleSpec {
/**
* Returns the number of partitions of this shuffle spec
@@ -28,7 +28,7 @@ import scala.util.control.NonFatal

import org.apache.commons.text.StringEscapeUtils

import org.apache.spark.{sql, SparkException, TaskContext}
import org.apache.spark.{sql, Partitioner, SparkException, TaskContext}
import org.apache.spark.annotation.{DeveloperApi, Stable, Unstable}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
@@ -1549,6 +1549,27 @@ class Dataset[T] private[sql](
repartitionByExpression(Some(numPartitions), Seq(directShufflePartitionIdCol))
}

/**
* Returns a new Dataset partitioned using the specified partitioner.
*
* This is similar to `partitionBy` on key-value RDDs. The key function is applied to each
* element, and the resulting key is passed to the partitioner to determine that element's
* target partition.
*
* {{{
* // Repartition using a custom partitioner
* ds.repartition[String](_.userId, new HashPartitioner(100))
* }}}
*
* @group typedrel
* @since 4.1.0
*/
def repartition[K: Encoder](
keyFunc: T => K,
partitioner: Partitioner): Dataset[T] = withSameTypedPlan {
RepartitionByPartitioner(keyFunc, partitioner, logicalPlan)
}
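
For illustration only (not part of this diff): the kind of call site this overload targets, as raised in the review discussion earlier, where an existing Partitioner implementation, perhaps carried over from RDD code and unit-tested on its own, is reused directly. RegionPartitioner and Event are hypothetical names.

import org.apache.spark.Partitioner

// A reusable, independently testable partitioner (hypothetical).
class RegionPartitioner(regions: Seq[String]) extends Partitioner {
  private val index = regions.zipWithIndex.toMap
  // One extra partition collects unknown regions.
  override def numPartitions: Int = regions.size + 1
  override def getPartition(key: Any): Int = key match {
    case region: String => index.getOrElse(region, regions.size)
    case _ => regions.size
  }
}

case class Event(userId: String, region: String, payload: String)

// With spark.implicits._ in scope and events: Dataset[Event]:
// val byRegion = events.repartition[String](_.region, new RegionPartitioner(Seq("eu", "us", "apac")))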

protected def repartitionByRange(
numPartitions: Option[Int],
partitionExprs: Seq[Column]): Dataset[T] = {
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.CustomFunctionPartitioning
import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -1027,6 +1028,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
} else {
execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
}
case r: logical.RepartitionByPartitioner =>
val partitioning = CustomFunctionPartitioning(
r.keyFunc,
r.deserializer,
r.partitioner,
r.child.output)
ShuffleExchangeExec(partitioning, planLater(r.child), REPARTITION_BY_NUM) :: Nil
case logical.Sort(sortExprs, global, child, _) =>
execution.SortExec(sortExprs, global, planLater(child)) :: Nil
case logical.Project(projectList, child) =>
@@ -375,6 +375,7 @@ object ShuffleExchangeExec {
case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index)
}.toMap
new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), n)
case c: CustomFunctionPartitioning => c.partitioner
case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning")
// TODO: Handle BroadcastPartitioning.
}
@@ -408,6 +409,9 @@ object ShuffleExchangeExec {
// If the value is null, `InternalRow#getInt` returns 0.
val projection = UnsafeProjection.create(s.partitionIdExpression :: Nil, outputAttributes)
row => projection(row).getInt(0)
case c: CustomFunctionPartitioning =>
val getObject = ObjectOperator.deserializeRowToObject(c.deserializer, c.outputAttributes)
row => c.keyFunc(getObject(row))
case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning")
}
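
Taken together, the two additions above mean that, for CustomFunctionPartitioning, the shuffle computes a record's target partition roughly as sketched below. This is an illustrative sketch over the names used in the diff, assuming the imports already present in this file, not code taken from ShuffleExchangeExec itself:

def partitionIdFor(c: CustomFunctionPartitioning): InternalRow => Int = {
  // The user-supplied partitioner becomes the shuffle's Partitioner...
  val part: Partitioner = c.partitioner
  // ...and each InternalRow is deserialized back to the user object so the key
  // function can extract the key that the partitioner maps to a partition id.
  val getObject = ObjectOperator.deserializeRowToObject(c.deserializer, c.outputAttributes)
  row => part.getPartition(c.keyFunc(getObject(row)))
}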

@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.{HashPartitioner, Partitioner, SparkUnsupportedOperationException}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, DataFrame, Row}
import org.apache.spark.sql.AnalysisException
@@ -1584,6 +1584,96 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
// The groupBy reuses the output partitioning after DirectShufflePartitionID.
checkShuffleCount(grouped, 3)
}

test("SPARK-27853: repartition with custom Partitioner using HashPartitioner") {
val ds = spark.range(100).as[Long]
val numPartitions = 10
val partitioner = new HashPartitioner(numPartitions)

val repartitioned = ds.repartition[Long](identity, partitioner)

assert(repartitioned.rdd.getNumPartitions == numPartitions)
assert(repartitioned.count() == 100)

val result = repartitioned.withColumn("partition_id", spark_partition_id()).collect()
result.foreach { row =>
val value = row.getAs[Long]("id")
val actualPartition = row.getAs[Int]("partition_id")
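// Mirrors HashPartitioner.getPartition, which applies nonNegativeMod(key.hashCode, numPartitions).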
val expectedPartition = (value.hashCode() % numPartitions + numPartitions) % numPartitions
assert(actualPartition == expectedPartition,
s"Value $value should be in partition $expectedPartition but was in $actualPartition")
}
}

test("SPARK-27853: repartition with custom Partitioner and key extraction function") {
val ds = Seq(
("Alice", 25),
("Bob", 30),
("Charlie", 25),
("David", 30)
).toDF("name", "age").as[(String, Int)]

val partitioner = new HashPartitioner(5)
val repartitioned = ds.repartition[Int](_._2, partitioner)

assert(repartitioned.rdd.getNumPartitions == 5)
assert(repartitioned.count() == 4)

val result = repartitioned.withColumn("partition_id", spark_partition_id()).collect()
val alicePartition = result.find(_.getAs[String]("name") == "Alice")
.get.getAs[Int]("partition_id")
val charliePartition = result.find(_.getAs[String]("name") == "Charlie")
.get.getAs[Int]("partition_id")
assert(alicePartition == charliePartition)

val bobPartition = result.find(_.getAs[String]("name") == "Bob")
.get.getAs[Int]("partition_id")
val davidPartition = result.find(_.getAs[String]("name") == "David")
.get.getAs[Int]("partition_id")
assert(bobPartition == davidPartition)
}

test("SPARK-27853: repartition with user-defined Partitioner") {
class EvenOddPartitioner extends Partitioner {
override def numPartitions: Int = 2
override def getPartition(key: Any): Int = key match {
case l: Long => if (l % 2 == 0) 0 else 1
case _ => 0
}
}

val ds = spark.range(20).as[Long]
val repartitioned = ds.repartition[Long](identity, new EvenOddPartitioner())

assert(repartitioned.rdd.getNumPartitions == 2)

val result = repartitioned.withColumn("partition_id", spark_partition_id()).collect()
result.foreach { row =>
val value = row.getAs[Long]("id")
val partition = row.getAs[Int]("partition_id")
val expected = if (value % 2 == 0) 0 else 1
assert(partition == expected,
s"Value $value should be in partition $expected but was in $partition")
}
}

test("SPARK-27853: repartition with custom Partitioner introduces shuffle") {
val ds = spark.range(100).as[Long]
val repartitioned = ds.repartition[Long](identity, new HashPartitioner(10))

val shuffles = collect(repartitioned.queryExecution.executedPlan) {
case s: ShuffleExchangeLike => s
}
assert(shuffles.nonEmpty)
}

test("SPARK-27853: repartition with custom Partitioner preserves data") {
val original = (1 to 1000).map(i => (i, s"value_$i")).toDF("id", "value").as[(Int, String)]

val repartitioned = original.repartition[Int](_._1, new HashPartitioner(20))

assert(original.collect().toSet == repartitioned.collect().toSet)
}
}

// Used for unit-testing EnsureRequirements