29 changes: 29 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2981,6 +2981,35 @@ abstract class Dataset[T] extends Serializable {
*/
def repartitionById(numPartitions: Int, partitionIdExpr: Column): Dataset[T]

/**
* Proactively optimizes the partition count of this Dataset based on its estimated size.
*
* == Best Practice: Use on Ingest ==
* This method is best used immediately after reading a dataset to ensure the initial
* parallelism matches the data size. This prevents "small file" issues (too many small
* partitions) or "giant partition" issues (too few, oversized partitions) before heavy
* transformations begin.
*
* {{{
* val raw = spark.read.parquet("...")
* val optimized = raw.optimizePartitions() // Perfect start for transformations
* optimized.filter(...).groupBy(...)
* }}}
*
* == Warning: Avoid Immediately Before Partitioned Writes ==
* This method uses round-robin partitioning (a full shuffle) to balance partition sizes.
* If used immediately before writing to a partitioned table (e.g., `write.partitionBy("city")`),
* it may degrade performance by breaking data locality, causing the writer to create
* many small files across directories.
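*
* A minimal sketch of the preferred pattern when a partitioned write follows (the path and
* column name are illustrative only):
* {{{
* // Avoid: a balancing shuffle right before a partitioned write
* df.optimizePartitions().write.partitionBy("city").parquet("/out")
* // Prefer: cluster by the write partition column so each task writes fewer files
* df.repartition($"city").write.partitionBy("city").parquet("/out")
* }}}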
*
* @param targetMB The target partition size in Megabytes. Defaults to 128MB.
* @group typedrel
* @since 4.2.0
*/
def optimizePartitions(targetMB: Int = 128): Dataset[T] = {
throw new UnsupportedOperationException("This method is implemented in " +
"the concrete Dataset classes")
}

/**
* Returns a new Dataset that has exactly `numPartitions` partitions, when the fewer partitions
* are requested. If a larger number of partitions is requested, it will stay at the current
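To make the sizing behaviour concrete, here is a hedged usage sketch of the new API (the path, column names, and the ~1 GiB input size are assumptions, not part of the patch): with the default `targetMB = 128`, a ~1 GiB input settles at ceil(1024 / 128) = 8 partitions.

```scala
import org.apache.spark.sql.functions.col

// Hypothetical ingest example; sizes, path and columns are illustrative only.
val raw = spark.read.parquet("/data/events")   // assume roughly 1 GiB of input
val optimized = raw.optimizePartitions()       // ceil(1024 MB / 128 MB) = 8 partitions
optimized.filter(col("status") === "ok").groupBy(col("city")).count()
```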
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OptimizePartitionsCommand, Repartition}
import org.apache.spark.sql.catalyst.rules.Rule

/**
* Proactively optimizes the partition count of a Dataset based on its estimated size.
* This rule transforms the custom OptimizePartitionsCommand into standard Spark operations.
*/
object OptimizePartitionsRule extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
case OptimizePartitionsCommand(child, targetMB, currentPartitions) =>

val targetBytes = targetMB.toLong * 1024L * 1024L

// Get the estimated size from Catalyst Statistics
val sizeInBytes = child.stats.sizeInBytes

// Calculate Optimal Partition Count (N)
val count = math.ceil(sizeInBytes.toDouble / targetBytes).toInt
val calculatedN: Int = if (count <= 1) 1 else count

// Smart Switch: Coalesce vs Repartition
if (calculatedN < currentPartitions) {
// DOWNSCALING: Use Coalesce (shuffle = false)
Repartition(calculatedN, shuffle = false, child)
} else if (calculatedN > currentPartitions) {
// UPSCALING: Use Repartition (shuffle = true)
Repartition(calculatedN, shuffle = true, child)
} else {
// OPTIMAL
child
}
}
}
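The heart of the rule is the size-to-count arithmetic plus the coalesce-versus-repartition switch. Below is a standalone sketch of that logic; the function name `plannedPartitioning` is made up here for illustration and is not part of the patch.

```scala
// Illustrative restatement of the rule's arithmetic, outside Catalyst.
// Returns the target partition count and whether a shuffle is required.
def plannedPartitioning(sizeInBytes: BigInt, targetMB: Int, currentPartitions: Int): (Int, Boolean) = {
  val targetBytes = targetMB.toLong * 1024L * 1024L
  val calculatedN = math.max(1, math.ceil(sizeInBytes.toDouble / targetBytes).toInt)
  // Growing the partition count needs a shuffle (Repartition); shrinking can
  // use Coalesce without one; an equal count leaves the plan untouched.
  (calculatedN, calculatedN > currentPartitions)
}

plannedPartitioning(BigInt(1L) << 30, 128, 200) // (8, false): coalesce 200 -> 8
plannedPartitioning(BigInt(1L) << 30, 128, 2)   // (8, true): repartition 2 -> 8
```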
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.Attribute

/**
* A logical command that hints to the optimizer that we want to
* automatically repartition the data based on statistics.
*/
case class OptimizePartitionsCommand(child: LogicalPlan,
targetMB: Int,
currentPartitions: Int) extends UnaryNode {

override def output: Seq[Attribute] = child.output

override protected def withNewChildInternal(newChild: LogicalPlan): OptimizePartitionsCommand =
copy(child = newChild)
}
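As a hedged sketch of how this node flows through planning (it assumes a live `SparkSession`; the numbers are illustrative), the hint should be visible in the analyzed plan and rewritten away by `OptimizePartitionsRule` before execution:

```scala
// Illustration only: inspect how the hint node is planned and then rewritten.
val df = spark.range(100000).repartition(50).optimizePartitions()

// The analyzed plan still carries the hint node with the captured partition count.
println(df.queryExecution.analyzed)

// After optimization the hint is gone, replaced by a Repartition/Coalesce node.
println(df.queryExecution.optimizedPlan)
```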
@@ -1562,6 +1562,14 @@ class Dataset[T] private[sql](
}
}

/** @inheritdoc */
override def optimizePartitions(targetMB: Int): Dataset[T] = {
// Capturing the current partition count forces physical planning of this Dataset's RDD.
val currentPartitions = rdd.getNumPartitions

withTypedPlan {
OptimizePartitionsCommand(logicalPlan, targetMB, currentPartitions)
}
}

/** @inheritdoc */
def coalesce(numPartitions: Int): Dataset[T] = withSameTypedPlan {
Repartition(numPartitions, shuffle = false, logicalPlan)
@@ -96,7 +96,8 @@ class SparkOptimizer(
ConstantFolding,
EliminateLimits),
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*),
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)))
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition),
Batch("Optimizer Partitions", Once, OptimizePartitionsRule)))

override def nonExcludableRules: Seq[String] = super.nonExcludableRules ++
Seq(
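Since the rule is registered as an ordinary batch and, at least in the visible part of this diff, is not added to `nonExcludableRules`, users should be able to switch the rewrite off per session with the existing exclusion config; a hedged sketch:

```scala
// Assumption: OptimizePartitionsRule is excludable because it is not listed
// in nonExcludableRules; if so, it can be disabled for the current session.
spark.conf.set("spark.sql.optimizer.excludedRules",
  "org.apache.spark.sql.catalyst.optimizer.OptimizePartitionsRule")
```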
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSparkSession

class OptimizePartitionsSuite extends SparkFunSuite with SharedSparkSession {

test("TEST 1: Small Data Compaction (Coalesce)") {
val initialDF = spark.range(10000).repartition(100)
val optimizedDF = initialDF.optimizePartitions()
assert(optimizedDF.rdd.getNumPartitions == 1,
s"Expected 1 partition, got ${optimizedDF.rdd.getNumPartitions}.")
}

test("TEST 2: Scaling Up (Large Data Repartition)") {
val initialDF = spark.range(500000).repartition(1)
// initialDF is roughly 4MB in size.
// Passing a 2MB target partition size should grow the partition count from 1 to 2.
val optimizedDF = initialDF.optimizePartitions(2)
// We expect the partition count to increase to 2 so that each partition is about 2MB.
assert(optimizedDF.rdd.getNumPartitions == 2,
s"Expected scaling up, got ${optimizedDF.rdd.getNumPartitions}.")
}
}
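A possible follow-up case, not part of this diff, would cover the rule's third branch, where the calculated count already matches the current count. A hedged sketch reusing TEST 2's ~4MB size assumption:

```scala
test("TEST 3: Already optimal (no-op)") {
  // Reusing TEST 2's assumption that this range is roughly 4MB: with a 2MB
  // target and 2 existing partitions, the calculated count should stay at 2.
  val initialDF = spark.range(500000).repartition(2)
  val optimizedDF = initialDF.optimizePartitions(2)
  assert(optimizedDF.rdd.getNumPartitions == 2,
    s"Expected no change, got ${optimizedDF.rdd.getNumPartitions}.")
}
```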