@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg._
@@ -346,8 +347,9 @@ class LogisticRegression @Since("1.2.0") (
       val regParamL1 = $(elasticNetParam) * $(regParam)
       val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
 
+      val bcFeaturesStd = instances.context.broadcast(featuresStd)
       val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
-        $(standardization), featuresStd, featuresMean, regParamL2)
+        $(standardization), bcFeaturesStd, regParamL2)
 
       val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
         new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -442,6 +444,7 @@ class LogisticRegression @Since("1.2.0") (
         rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
         i += 1
       }
+      bcFeaturesStd.destroy(blocking = false)
 
       if ($(fitIntercept)) {
         (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
@@ -938,11 +941,15 @@ class BinaryLogisticRegressionSummary private[classification] (
  * Two LogisticAggregator can be merged together to have a summary of loss and gradient of
  * the corresponding joint dataset.
  *
+ * @param bcCoefficients The broadcast coefficients corresponding to the features.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
  * @param numClasses the number of possible outcomes for k classes classification problem in
  *                   Multinomial Logistic Regression.
  * @param fitIntercept Whether to fit an intercept term.
  */
 private class LogisticAggregator(
+    val bcCoefficients: Broadcast[Vector],
+    val bcFeaturesStd: Broadcast[Array[Double]],
     private val numFeatures: Int,
     numClasses: Int,
     fitIntercept: Boolean) extends Serializable {
@@ -958,29 +965,26 @@ private class LogisticAggregator(
    * of the objective function.
    *
    * @param instance The instance of data point to be added.
-   * @param coefficients The coefficients corresponding to the features.
-   * @param featuresStd The standard deviation values of the features.
    * @return This LogisticAggregator object.
    */
-  def add(
-      instance: Instance,
-      coefficients: Vector,
-      featuresStd: Array[Double]): this.type = {
+  def add(instance: Instance): this.type = {
     instance match { case Instance(label, weight, features) =>
       require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
         s" Expecting $numFeatures but got ${features.size}.")
       require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
 
       if (weight == 0.0) return this
 
-      val coefficientsArray = coefficients match {
+      val coefficientsArray = bcCoefficients.value match {
         case dv: DenseVector => dv.values
         case _ =>
           throw new IllegalArgumentException(
-            s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
+            "coefficients only supports dense vector" +
+              s" but got type ${bcCoefficients.value.getClass}.")
       }
       val localGradientSumArray = gradientSumArray
 
+      val featuresStd = bcFeaturesStd.value
       numClasses match {
         case 2 =>
           // For Binary Logistic Regression.
@@ -1077,24 +1081,23 @@ private class LogisticCostFun(
     numClasses: Int,
     fitIntercept: Boolean,
     standardization: Boolean,
-    featuresStd: Array[Double],
-    featuresMean: Array[Double],
+    bcFeaturesStd: Broadcast[Array[Double]],
     regParamL2: Double) extends DiffFunction[BDV[Double]] {
 
+  val featuresStd = bcFeaturesStd.value
+
   override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
     val numFeatures = featuresStd.length
     val coeffs = Vectors.fromBreeze(coefficients)
+    val bcCoeffs = instances.context.broadcast(coeffs)
     val n = coeffs.size
-    val localFeaturesStd = featuresStd
-
 
     val logisticAggregator = {
-      val seqOp = (c: LogisticAggregator, instance: Instance) =>
-        c.add(instance, coeffs, localFeaturesStd)
+      val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
       val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
 
       instances.treeAggregate(
-        new LogisticAggregator(numFeatures, numClasses, fitIntercept)
+        new LogisticAggregator(bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept)
       )(seqOp, combOp)
     }
 
@@ -1134,6 +1137,7 @@ private class LogisticCostFun(
       }
       0.5 * regParamL2 * sum
     }
+    bcCoeffs.destroy(blocking = false)
 
     (logisticAggregator.loss + regVal, new BDV(totalGradientArray))
   }
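
Note: the hunks above replace per-task closure capture of featuresStd and of the coefficient vector with explicit broadcast variables, which are destroyed once the optimizer and the final unscaling loop no longer need them. The standalone sketch below is not part of this commit; the object name, toy data, and local[*] master are invented for illustration. It shows the same broadcast-and-destroy pattern against the public Spark Core API, assuming a plain RDD job:

// Illustrative only: mirrors the broadcast-and-destroy pattern from the diff above.
import org.apache.spark.{SparkConf, SparkContext}

object BroadcastStdSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("broadcast-std-sketch").setMaster("local[*]"))
    try {
      // Stand-in for the per-feature standard deviations computed by the summarizer.
      val featuresStd = Array(1.0, 2.0, 0.5)
      // Ship the array once per executor instead of once per task closure.
      val bcFeaturesStd = sc.broadcast(featuresStd)

      val instances = sc.parallelize(Seq(
        Array(1.0, 4.0, 2.0),
        Array(2.0, 6.0, 1.0)))

      // Tasks read the broadcast value; the large array is not captured in the closure.
      val scaledSums = instances.map { row =>
        val std = bcFeaturesStd.value
        row.indices.map(i => if (std(i) != 0.0) row(i) / std(i) else 0.0).sum
      }.collect()

      println(scaledSums.mkString(", "))

      // Release the broadcast blocks once no further jobs need them. The commit calls the
      // Spark-internal destroy(blocking = false); user code only sees the no-arg destroy().
      bcFeaturesStd.destroy()
    } finally {
      sc.stop()
    }
  }
}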