From 3328ac718ca07ae077de0092a4a194772a1b73b8 Mon Sep 17 00:00:00 2001 From: Herbert Li Date: Sun, 9 Dec 2018 14:33:13 -0500 Subject: [PATCH] change to lr (replace GBTRegressor with LinearRegression) --- source_code/nyc-spark/src/main/scala/RFGreen.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/source_code/nyc-spark/src/main/scala/RFGreen.scala b/source_code/nyc-spark/src/main/scala/RFGreen.scala index 1e7ab9d..bdad20a 100644 --- a/source_code/nyc-spark/src/main/scala/RFGreen.scala +++ b/source_code/nyc-spark/src/main/scala/RFGreen.scala @@ -3,7 +3,7 @@ import java.sql.Date import java.util.Calendar import org.apache.spark.ml.feature._ -import org.apache.spark.ml.regression.GBTRegressor +import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.{Pipeline, PipelineStage} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions._ @@ -117,12 +117,13 @@ object RFGreen { // Create our random forest model println("Creating Random Forest...") - val gbt = new GBTRegressor(uid = "gbt_regression") + val gbt = new LinearRegression(uid = "linear_regression") .setFeaturesCol("features_rf") .setLabelCol("passengers") .setPredictionCol("passengers_prediction") - .setMaxIter(50) - .setMaxBins(366) + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) val sparkPipelineEstimatorRf = new Pipeline().setStages(Array(sparkFeaturePipelineModel, gbt)) println("Training Random Forest...")