Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-51379][ML] Move treeAggregate's final aggregation from driver to executor #50142

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,12 @@ private[ml] class WeightedLeastSquares(
instr.logWarning("regParam is zero, which might cause numerical instability and overfitting.")
}

val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_), depth)
val summary = instances.treeAggregate[Aggregator](
zeroValue = new Aggregator,
seqOp = (agg: Aggregator, x: Instance) => agg.add(x),
combOp = (agg1: Aggregator, agg2: Aggregator) => agg1.merge(agg2),
depth = depth,
finalAggregateOnExecutor = true)
summary.validate()
instr.logInfo(log"Number of instances: ${MDC(COUNT, summary.count)}.")
val k = if (fitIntercept) summary.k + 1 else summary.k
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ private[ml] class RDDLossFunction[
val thisAgg = getAggregator(bcCoefficients)
val seqOp = (agg: Agg, x: T) => agg.add(x)
val combOp = (agg1: Agg, agg2: Agg) => agg1.merge(agg2)
val newAgg = instances.treeAggregate(thisAgg)(seqOp, combOp, aggregationDepth)
val newAgg = instances.treeAggregate(thisAgg, seqOp, combOp, aggregationDepth, true)
val gradient = newAgg.gradient
val regLoss = regularization.map { regFun =>
val (regLoss, regGradient) = regFun.calculate(Vectors.fromBreeze(coefficients))
Expand Down
12 changes: 7 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,16 @@ object Summarizer extends Logging {
aggregationDepth: Int = 2,
requested: Seq[String] = Seq("mean", "std", "count")) = {
instances.treeAggregate(
(Summarizer.createSummarizerBuffer(requested: _*),
Summarizer.createSummarizerBuffer("mean", "std", "count")))(
zeroValue = (Summarizer.createSummarizerBuffer(requested: _*),
Summarizer.createSummarizerBuffer("mean", "std", "count")),
seqOp = (c: (SummarizerBuffer, SummarizerBuffer), instance: Instance) =>
(c._1.add(instance.features, instance.weight),
c._2.add(Vectors.dense(instance.label), instance.weight)),
combOp = (c1: (SummarizerBuffer, SummarizerBuffer),
c2: (SummarizerBuffer, SummarizerBuffer)) =>
(c1._1.merge(c2._1), c1._2.merge(c2._2)),
depth = aggregationDepth
depth = aggregationDepth,
finalAggregateOnExecutor = true
)
}

Expand All @@ -235,13 +236,14 @@ object Summarizer extends Logging {
aggregationDepth: Int = 2,
requested: Seq[String] = Seq("mean", "std", "count")) = {
instances.treeAggregate(
(Summarizer.createSummarizerBuffer(requested: _*), new MultiClassSummarizer))(
zeroValue = (Summarizer.createSummarizerBuffer(requested: _*), new MultiClassSummarizer),
seqOp = (c: (SummarizerBuffer, MultiClassSummarizer), instance: Instance) =>
(c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)),
combOp = (c1: (SummarizerBuffer, MultiClassSummarizer),
c2: (SummarizerBuffer, MultiClassSummarizer)) =>
(c1._1.merge(c2._1), c1._2.merge(c2._2)),
depth = aggregationDepth
depth = aggregationDepth,
finalAggregateOnExecutor = true
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,12 @@ class GaussianMixture private (
val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)

// aggregate the cluster contribution for all sample points
val sums = breezeData.treeAggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)
val sums = breezeData.treeAggregate[ExpectationSum](
zeroValue = ExpectationSum.zero(k, d),
seqOp = (agg: ExpectationSum, v: BV[Double]) => compute.value(agg, v),
combOp = (agg1: ExpectationSum, agg2: ExpectationSum) => agg1 += agg2,
depth = 2,
finalAggregateOnExecutor = true)

// Create new distributions based on the partial assignments
// (often referred to as the "M" step in literature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
}

val (statsSum: BDM[Double], logphatOption: Option[BDV[Double]], nonEmptyDocsN: Long) = stats
.treeAggregate((null.asInstanceOf[BDM[Double]], logphatPartOptionBase(), 0L))(
elementWiseSum, elementWiseSum
.treeAggregate((null.asInstanceOf[BDM[Double]], logphatPartOptionBase(), 0L),
elementWiseSum, elementWiseSum, 2, true
)

expElogbetaBc.destroy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double]
* and labels on one pass.
*/
private val summary: MultilabelSummarizer = {
predictionAndLabels
.treeAggregate(new MultilabelSummarizer)(
(summary, sample) => summary.add(sample._1, sample._2),
(sum1, sum2) => sum1.merge(sum2)
)
predictionAndLabels.treeAggregate[MultilabelSummarizer](
zeroValue = new MultilabelSummarizer,
seqOp = (summary: MultilabelSummarizer,
sample: (Array[Double], Array[Double])) => summary.add(sample._1, sample._2),
combOp = (sum1: MultilabelSummarizer, sum2: MultilabelSummarizer) => sum1.merge(sum2),
depth = 2,
finalAggregateOnExecutor = true)
}


Expand Down
14 changes: 9 additions & 5 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,15 @@ class IDF @Since("1.2.0") (@Since("1.2.0") val minDocFreq: Int) {
*/
@Since("1.1.0")
def fit(dataset: RDD[Vector]): IDFModel = {
val (idf: Vector, docFreq: Array[Long], numDocs: Long) = dataset.treeAggregate(
new IDF.DocumentFrequencyAggregator(minDocFreq = minDocFreq))(
seqOp = (df, v) => df.add(v),
combOp = (df1, df2) => df1.merge(df2)
).idf()
val (idf: Vector, docFreq: Array[Long], numDocs: Long) = dataset
.treeAggregate[IDF.DocumentFrequencyAggregator](
zeroValue = new IDF.DocumentFrequencyAggregator(minDocFreq = minDocFreq),
seqOp = (df: IDF.DocumentFrequencyAggregator, v: Vector) => df.add(v),
combOp = (df1: IDF.DocumentFrequencyAggregator,
df2: IDF.DocumentFrequencyAggregator) => df1.merge(df2),
depth = 2,
finalAggregateOnExecutor = true
).idf()
new IDFModel(idf, docFreq, numDocs)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ class RowMatrix @Since("1.0.0") (
private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = {
val n = numCols().toInt
val vbr = rows.context.broadcast(v)
rows.treeAggregate(null.asInstanceOf[BDV[Double]])(
seqOp = (U, r) => {
rows.treeAggregate[BDV[Double]](
zeroValue = null.asInstanceOf[BDV[Double]],
seqOp = (U: BDV[Double], r: Vector) => {
val rBrz = r.asBreeze
val a = rBrz.dot(vbr.value)
val theU =
Expand All @@ -109,7 +110,8 @@ class RowMatrix @Since("1.0.0") (
s"Do not support vector operation from type ${rBrz.getClass.getName}.")
}
theU
}, combOp = (U1, U2) => {
},
combOp = (U1: BDV[Double], U2: BDV[Double]) => {
if (U1 == null) {
U2
} else if (U2 == null) {
Expand All @@ -118,7 +120,10 @@ class RowMatrix @Since("1.0.0") (
U1 += U2
U1
}
})
},
depth = 2,
finalAggregateOnExecutor = true
)
}

/**
Expand All @@ -136,8 +141,9 @@ class RowMatrix @Since("1.0.0") (
val gramianSizeInBytes = nt * 8L

// Compute the upper triangular part of the gram matrix.
val GU = rows.treeAggregate(null.asInstanceOf[BDV[Double]])(
seqOp = (maybeU, v) => {
val GU = rows.treeAggregate[BDV[Double]](
zeroValue = null.asInstanceOf[BDV[Double]],
seqOp = (maybeU: BDV[Double], v: Vector) => {
val U =
if (maybeU == null) {
new BDV[Double](nt)
Expand All @@ -146,15 +152,17 @@ class RowMatrix @Since("1.0.0") (
}
BLAS.spr(1.0, v, U.data)
U
}, combOp = (U1, U2) =>
},
combOp = (U1: BDV[Double], U2: BDV[Double]) =>
if (U1 == null) {
U2
} else if (U2 == null) {
U1
} else {
U1 += U2
},
depth = getTreeAggregateIdealDepth(gramianSizeInBytes)
depth = getTreeAggregateIdealDepth(gramianSizeInBytes),
finalAggregateOnExecutor = true
)

RowMatrix.triuToFull(n, GU.data)
Expand All @@ -168,8 +176,9 @@ class RowMatrix @Since("1.0.0") (
// This succeeds when n <= 65535, which is checked above
val nt = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2))

val MU = rows.treeAggregate(null.asInstanceOf[BDV[Double]])(
seqOp = (maybeU, v) => {
val MU = rows.treeAggregate[BDV[Double]](
zeroValue = null.asInstanceOf[BDV[Double]],
seqOp = (maybeU: BDV[Double], v: Vector) => {
val U =
if (maybeU == null) {
new BDV[Double](nt)
Expand All @@ -188,14 +197,17 @@ class RowMatrix @Since("1.0.0") (

BLAS.spr(1.0, new DenseVector(na), U.data)
U
}, combOp = (U1, U2) =>
},
combOp = (U1: BDV[Double], U2: BDV[Double]) =>
if (U1 == null) {
U2
} else if (U2 == null) {
U1
} else {
U1 += U2
}
},
depth = 2,
finalAggregateOnExecutor = true
)

bc.destroy()
Expand Down Expand Up @@ -533,9 +545,13 @@ class RowMatrix @Since("1.0.0") (
*/
@Since("1.0.0")
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2))
val summary = rows.treeAggregate[MultivariateOnlineSummarizer](
zeroValue = new MultivariateOnlineSummarizer,
seqOp = (aggregator: MultivariateOnlineSummarizer, data: Vector) => aggregator.add(data),
combOp = (aggregator1: MultivariateOnlineSummarizer,
aggregator2: MultivariateOnlineSummarizer) => aggregator1.merge(aggregator2),
depth = 2,
finalAggregateOnExecutor = true)
updateNumRows(summary.count)
summary
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,9 @@ object GradientDescent extends Logging {
// Sample a subset (fraction miniBatchFraction) of the total data
// compute and sum up the subgradients on this subset (this is one map-reduce)
val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i)
.treeAggregate((null.asInstanceOf[BDV[Double]], 0.0, 0L))(
seqOp = (c, v) => {
.treeAggregate[(BDV[Double], Double, Long)](
zeroValue = (null.asInstanceOf[BDV[Double]], 0.0, 0L),
seqOp = (c: (BDV[Double], Double, Long), v: (Double, Vector)) => {
// c: (grad, loss, count), v: (label, features)
val vec =
if (c._1 == null) {
Expand All @@ -256,7 +257,7 @@ object GradientDescent extends Logging {
val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(vec))
(vec, c._2 + l, c._3 + 1)
},
combOp = (c1, c2) => {
combOp = (c1: (BDV[Double], Double, Long), c2: (BDV[Double], Double, Long)) => {
// c: (grad, loss, count)
val vec =
if (c1._1 == null) {
Expand All @@ -268,7 +269,9 @@ object GradientDescent extends Logging {
c1._1
}
(vec, c1._2 + c2._2, c1._3 + c2._3)
})
},
depth = 2,
finalAggregateOnExecutor = true)
bcWeights.destroy()

if (miniBatchSize > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ object LBFGS extends Logging {
}

val zeroSparseVector = Vectors.sparse(n, Seq.empty)
val (gradientSum, lossSum) = data.treeAggregate((zeroSparseVector, 0.0))(seqOp, combOp)
val (gradientSum, lossSum) = data.treeAggregate(
(zeroSparseVector, 0.0), seqOp, combOp, 2, true)

// broadcasted model is not needed anymore
bcW.destroy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,13 @@ object Statistics {
* @return [[SummarizerBuffer]] object containing column-wise summary statistics.
*/
private[mllib] def colStats(X: RDD[(Vector, Double)], requested: Seq[String]) = {
X.treeAggregate(Summarizer.createSummarizerBuffer(requested: _*))(
seqOp = { case (c, (v, w)) => c.add(v.nonZeroIterator, v.size, w) },
combOp = { case (c1, c2) => c1.merge(c2) },
depth = 2
X.treeAggregate[SummarizerBuffer](
zeroValue = Summarizer.createSummarizerBuffer(requested: _*),
seqOp = (c: SummarizerBuffer,
vw: (Vector, Double)) => c.add(vw._1.nonZeroIterator, vw._1.size, vw._2),
combOp = (c1: SummarizerBuffer, c2: SummarizerBuffer) => c1.merge(c2),
depth = 2,
finalAggregateOnExecutor = true
)
}

Expand Down