Skip to content

Commit 3d8bfe7

Browse files
WeichenXu123 authored and yanboliang committed
[SPARK-16934][ML][MLLIB] Update LogisticCostAggregator serialization code to make it consistent with LinearRegression
## What changes were proposed in this pull request?

Update LogisticCostAggregator serialization code to make it consistent with apache#14109

## How was this patch tested?

MLlib 2.0: ![image](https://cloud.githubusercontent.com/assets/19235986/17649601/5e2a79ac-61ee-11e6-833c-3bd8b5250470.png)

After this PR: ![image](https://cloud.githubusercontent.com/assets/19235986/17649599/52b002ae-61ee-11e6-9402-9feb3439880f.png)

Author: WeichenXu <[email protected]>

Closes apache#14520 from WeichenXu123/improve_logistic_regression_costfun.
1 parent ddf0d1e commit 3d8bfe7

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

+20-16
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
2525

2626
import org.apache.spark.SparkException
2727
import org.apache.spark.annotation.{Experimental, Since}
28+
import org.apache.spark.broadcast.Broadcast
2829
import org.apache.spark.internal.Logging
2930
import org.apache.spark.ml.feature.Instance
3031
import org.apache.spark.ml.linalg._
@@ -346,8 +347,9 @@ class LogisticRegression @Since("1.2.0") (
346347
val regParamL1 = $(elasticNetParam) * $(regParam)
347348
val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
348349

350+
val bcFeaturesStd = instances.context.broadcast(featuresStd)
349351
val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
350-
$(standardization), featuresStd, featuresMean, regParamL2)
352+
$(standardization), bcFeaturesStd, regParamL2)
351353

352354
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
353355
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -442,6 +444,7 @@ class LogisticRegression @Since("1.2.0") (
442444
rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
443445
i += 1
444446
}
447+
bcFeaturesStd.destroy(blocking = false)
445448

446449
if ($(fitIntercept)) {
447450
(Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
@@ -938,11 +941,15 @@ class BinaryLogisticRegressionSummary private[classification] (
938941
* Two LogisticAggregator can be merged together to have a summary of loss and gradient of
939942
* the corresponding joint dataset.
940943
*
944+
* @param bcCoefficients The broadcast coefficients corresponding to the features.
945+
* @param bcFeaturesStd The broadcast standard deviation values of the features.
941946
* @param numClasses the number of possible outcomes for k classes classification problem in
942947
* Multinomial Logistic Regression.
943948
* @param fitIntercept Whether to fit an intercept term.
944949
*/
945950
private class LogisticAggregator(
951+
val bcCoefficients: Broadcast[Vector],
952+
val bcFeaturesStd: Broadcast[Array[Double]],
946953
private val numFeatures: Int,
947954
numClasses: Int,
948955
fitIntercept: Boolean) extends Serializable {
@@ -958,29 +965,26 @@ private class LogisticAggregator(
958965
* of the objective function.
959966
*
960967
* @param instance The instance of data point to be added.
961-
* @param coefficients The coefficients corresponding to the features.
962-
* @param featuresStd The standard deviation values of the features.
963968
* @return This LogisticAggregator object.
964969
*/
965-
def add(
966-
instance: Instance,
967-
coefficients: Vector,
968-
featuresStd: Array[Double]): this.type = {
970+
def add(instance: Instance): this.type = {
969971
instance match { case Instance(label, weight, features) =>
970972
require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
971973
s" Expecting $numFeatures but got ${features.size}.")
972974
require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
973975

974976
if (weight == 0.0) return this
975977

976-
val coefficientsArray = coefficients match {
978+
val coefficientsArray = bcCoefficients.value match {
977979
case dv: DenseVector => dv.values
978980
case _ =>
979981
throw new IllegalArgumentException(
980-
s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
982+
"coefficients only supports dense vector" +
983+
s"but got type ${bcCoefficients.value.getClass}.")
981984
}
982985
val localGradientSumArray = gradientSumArray
983986

987+
val featuresStd = bcFeaturesStd.value
984988
numClasses match {
985989
case 2 =>
986990
// For Binary Logistic Regression.
@@ -1077,24 +1081,23 @@ private class LogisticCostFun(
10771081
numClasses: Int,
10781082
fitIntercept: Boolean,
10791083
standardization: Boolean,
1080-
featuresStd: Array[Double],
1081-
featuresMean: Array[Double],
1084+
bcFeaturesStd: Broadcast[Array[Double]],
10821085
regParamL2: Double) extends DiffFunction[BDV[Double]] {
10831086

1087+
val featuresStd = bcFeaturesStd.value
1088+
10841089
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
10851090
val numFeatures = featuresStd.length
10861091
val coeffs = Vectors.fromBreeze(coefficients)
1092+
val bcCoeffs = instances.context.broadcast(coeffs)
10871093
val n = coeffs.size
1088-
val localFeaturesStd = featuresStd
1089-
10901094

10911095
val logisticAggregator = {
1092-
val seqOp = (c: LogisticAggregator, instance: Instance) =>
1093-
c.add(instance, coeffs, localFeaturesStd)
1096+
val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
10941097
val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
10951098

10961099
instances.treeAggregate(
1097-
new LogisticAggregator(numFeatures, numClasses, fitIntercept)
1100+
new LogisticAggregator(bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept)
10981101
)(seqOp, combOp)
10991102
}
11001103

@@ -1134,6 +1137,7 @@ private class LogisticCostFun(
11341137
}
11351138
0.5 * regParamL2 * sum
11361139
}
1140+
bcCoeffs.destroy(blocking = false)
11371141

11381142
(logisticAggregator.loss + regVal, new BDV(totalGradientArray))
11391143
}

0 commit comments

Comments (0)