Swap the regressor trained in RFGreen from RandomForestRegressor to GBTRegressor.

The estimator keeps the same features/label/prediction columns and the same
setMaxBins(366); setNumTrees(50) becomes setMaxIter(50), and the pipeline stage
array now references the new `gbt` value.
NOTE(review): the comment and the println messages around the estimator still
say "Random Forest" after this swap - confirm whether they should be reworded.

diff --git a/source_code/nyc-spark/src/main/scala/RFGreen.scala b/source_code/nyc-spark/src/main/scala/RFGreen.scala
index 8c35f40..1e7ab9d 100644
--- a/source_code/nyc-spark/src/main/scala/RFGreen.scala
+++ b/source_code/nyc-spark/src/main/scala/RFGreen.scala
@@ -3,7 +3,7 @@ import java.sql.Date
 import java.util.Calendar
 
 import org.apache.spark.ml.feature._
-import org.apache.spark.ml.regression.RandomForestRegressor
+import org.apache.spark.ml.regression.GBTRegressor
 import org.apache.spark.ml.{Pipeline, PipelineStage}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions._
@@ -117,14 +117,14 @@ object RFGreen {
 
     // Create our random forest model
     println("Creating Random Forest...")
-    val randomForest = new RandomForestRegressor(uid = "random_forest_regression")
+    val gbt = new GBTRegressor(uid = "gbt_regression")
       .setFeaturesCol("features_rf")
       .setLabelCol("passengers")
       .setPredictionCol("passengers_prediction")
-      .setNumTrees(50)
+      .setMaxIter(50)
       .setMaxBins(366)
 
-    val sparkPipelineEstimatorRf = new Pipeline().setStages(Array(sparkFeaturePipelineModel, randomForest))
+    val sparkPipelineEstimatorRf = new Pipeline().setStages(Array(sparkFeaturePipelineModel, gbt))
     println("Training Random Forest...")
     val sparkPipelineRf = sparkPipelineEstimatorRf.fit(dataset)
     println("Completed training Random Forest")