When should you use Spark? Normally we just use pandas. However, pandas gets extremely slow as datasets become larger.
Before jumping to Spark, be clear about your purpose. Suppose you just want to process several files of ~2 GB each, and they can be processed line by line without further calculations. Then a plain line-by-line approach may be the better choice, because pandas becomes slow when the input/output volume is large.
If you need to perform more calculations, or your tables reach terabyte size or larger, consider using Spark!
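Here is a minimal sketch of the line-by-line approach in plain Python (no pandas, no Spark). The file layout is a hypothetical example: an origin column, an air_time column, and 'NA' for missing values. The function keeps only running sums and counts, so memory use does not grow with the file size.

```python
import csv
import io
from collections import defaultdict


def mean_by_key(lines, key_col, value_col):
    """Stream CSV rows one at a time and return {key: mean(value)}.

    Only running sums and counts are kept in memory.
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    for row in csv.DictReader(lines):
        value = row[value_col]
        if value in ("", "NA"):  # skip missing values
            continue
        totals[row[key_col]] += float(value)
        counts[row[key_col]] += 1
    return {key: totals[key] / counts[key] for key in totals}


# io.StringIO stands in for open("flights.csv", newline="")
sample = io.StringIO(
    "origin,air_time\n"
    "SEA,120\n"
    "PDX,60\n"
    "SEA,NA\n"
    "SEA,180\n"
)
print(mean_by_key(sample, "origin", "air_time"))  # {'SEA': 150.0, 'PDX': 60.0}
```

For a real file, you would pass the open file object instead of the in-memory sample.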
Spark is a technology for parallel computing on clusters. I like to think of it as pandas on clusters (a very inaccurate analogy). The Resilient Distributed Dataset (RDD) is the basic building block in Spark. The DataFrame is built on top of RDDs, with built-in optimizations for table operations. I like to think of it as a pandas DataFrame (again, just an analogy).
For details, see the official PySpark API reference.
To use Spark from Python, you need a SparkContext object. You can think of it as a connection to the control server (master) of your cluster. On top of it you create a SparkSession, which you can think of as an interface to this connection. In Spark 2.0 and later, SparkSession.builder.getOrCreate() creates the underlying SparkContext for you.
## Practice creating a SparkSession
## (getOrCreate() returns the existing session if one is already running)
# Import SparkSession from pyspark.sql
from pyspark.sql import SparkSession
# Create my_spark
my_spark = SparkSession.builder.getOrCreate()
# Print
print(my_spark)
From now on, spark refers to a SparkSession, not a SparkContext (just a naming convention). Your SparkSession has an attribute called catalog, which lists all the data registered inside the session on your cluster. It offers several methods for getting information. For example:
# Print the tables in the catalog
print(spark.catalog.listTables())
You can query tables in a Spark session just like tables in a SQL database, using spark.sql():
# Assume a table of flights is shown in your catalog
query = "FROM flights SELECT * LIMIT 10"
# Get the first 10 rows of flights, flights10 will be a DataFrame
flights10 = spark.sql(query)
# Show the results
flights10.show()
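You can convert Spark DataFrames to pandas DataFrames and work on them locally. Note that .toPandas() collects all rows to the driver machine, so the result must fit in its memory: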
# Example query
query = "SELECT origin, dest, COUNT(*) as N FROM flights GROUP BY origin, dest"
# Run the query
flight_counts = spark.sql(query)
# Convert the results to a pandas DataFrame
pd_counts = flight_counts.toPandas()
# Print the head of pd_counts
print(pd_counts.head())
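You can also go the other way: convert a pandas DataFrame to a Spark DataFrame, and from then on it is processed on the cluster. Transformations are lazy, so the actual work only runs once you call an action such as .show() or .count().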
# Import pandas and NumPy
import numpy as np
import pandas as pd
# Create pd_temp
pd_temp = pd.DataFrame(np.random.random(10))
# Create spark_temp from pd_temp
spark_temp = spark.createDataFrame(pd_temp)
# Examine the tables in the catalog
print(spark.catalog.listTables())
# Register spark_temp in the catalog as a temporary view named "temp". createOrReplaceTempView avoids duplicate names: if a view called "temp" already exists, it is replaced.
spark_temp.createOrReplaceTempView("temp")
# Examine the tables in the catalog again
print(spark.catalog.listTables())
You can also skip pandas and load data directly into a Spark DataFrame:
# Take an example
file_path = "/usr/local/share/datasets/airports.csv"
# Read in the airports data
airports = spark.read.csv(file_path,header=True)
# Show the data
airports.show()
From here on, we focus on functionality that is specific to Spark. Set aside your pandas intuitions, because many of them do not carry over.
In Spark, a column is an object of type Column, and you can refer to one with df.colName.
Unlike a pandas DataFrame, a Spark DataFrame is immutable, meaning you cannot change columns in place. Methods such as .withColumn() return a new DataFrame, so you reassign the result.
To add a new column computed from an existing column, do something like this: df = df.withColumn("newCol", df.oldCol + 1)
To replace an existing column, reuse its name: df = df.withColumn("oldCol", df.oldCol + 1)
# Create the DataFrame flights
flights = spark.table("flights")
# Show the head
flights.show()
# Add duration_hrs
flights = flights.withColumn("duration_hrs", flights.air_time/60.)
# Rename column A to B
flights = flights.withColumnRenamed("A","B")
The .filter() method accepts either a string of SQL code or a Column of boolean values. Both give the same result.
# Filter flights by passing a string
long_flights1 = flights.filter("distance > 1000")
# Filter flights by passing a column of boolean values
long_flights2 = flights.filter(flights.distance > 1000)
# Print the data to check they're equal
long_flights1.show()
long_flights2.show()
You can think of .filter() as the WHERE clause of 'SELECT * FROM table WHERE <your SQL filtering condition>'.
The DataFrame .select() method corresponds to SQL's SELECT statement. It accepts column names as strings or Column objects, including computed columns such as avg_speed below. .selectExpr() is essentially equivalent to .select(), but it takes SQL expressions as strings.
# Select the first set of columns
selected1 = flights.select('tailnum','origin','dest')
# Select the second set of columns
temp = flights.select(flights.origin, flights.dest, flights.carrier)
# Define first filter
filterA = flights.origin == "SEA"
# Define second filter
filterB = flights.dest == "PDX"
# Filter the data, first by filterA then by filterB
selected2 = temp.filter(filterA).filter(filterB)
# Define avg_speed
avg_speed = (flights.distance/(flights.air_time/60)).alias("avg_speed")
# Select the correct columns
speed1 = flights.select("origin", "dest", "tailnum", avg_speed)
# Create the same table using a SQL expression
speed2 = flights.selectExpr("origin", "dest", "tailnum", "distance/(air_time/60) as avg_speed")
Aggregating means summarizing a group of data, for example with min(), max(), or count().
The .groupBy() method of a DataFrame creates an object of type pyspark.sql.GroupedData. Passing column names to .groupBy() is similar to using GROUP BY in SQL. Calling .groupBy() with no arguments treats the whole DataFrame as a single group.
# Find the shortest flight from PDX in terms of distance
flights.filter(flights.origin=='PDX').groupBy().min('distance').show()
# Find the longest flight from SEA in terms of air time
flights.filter(flights.origin=='SEA').groupBy().max('air_time').show()
# Average air time of Delta flights from SEA
flights.filter(flights.carrier=="DL").filter(flights.origin=="SEA").groupBy().avg('air_time').show()
# Total hours in the air
flights.withColumn("duration_hrs", flights.air_time/60).groupBy().sum("duration_hrs").show()
# Group by tailnum
by_plane = flights.groupBy("tailnum")
# Number of flights each plane made
by_plane.count().show()
# Group by origin
by_origin = flights.groupBy("origin")
# Average duration of flights from PDX and SEA
by_origin.avg("air_time").show()
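To make the semantics concrete, here is a plain-Python sketch (not Spark code) of what groupBy(...).avg(...), groupBy(...).count() and a no-argument groupBy() compute, using a few hypothetical rows. Spark produces the same kind of result, but it spreads the per-group work across the cluster.

```python
from collections import defaultdict

# Hypothetical rows standing in for the flights DataFrame
flights_rows = [
    {"origin": "SEA", "tailnum": "N1", "air_time": 100},
    {"origin": "PDX", "tailnum": "N2", "air_time": 50},
    {"origin": "SEA", "tailnum": "N1", "air_time": 200},
]


def group_avg(rows, key, value):
    """Plain-Python analogue of df.groupBy(key).avg(value)."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[key]].append(row[value])
    return {k: sum(vals) / len(vals) for k, vals in groups.items()}


def group_count(rows, key):
    """Plain-Python analogue of df.groupBy(key).count()."""
    counts = defaultdict(int)
    for row in rows:
        counts[row[key]] += 1
    return dict(counts)


print(group_avg(flights_rows, "origin", "air_time"))  # {'SEA': 150.0, 'PDX': 50.0}
print(group_count(flights_rows, "tailnum"))           # {'N1': 2, 'N2': 1}
# groupBy() with no key means one group covering every row, e.g. .min('air_time'):
print(min(row["air_time"] for row in flights_rows))   # 50
```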
GroupedData objects have another useful method, .agg(), which lets you pass an aggregate column expression built with any of the aggregate functions in the pyspark.sql.functions submodule. This submodule has many useful functions.
Example:
# Import pyspark.sql.functions as F
import pyspark.sql.functions as F
# Group by month and dest
by_month_dest = flights.groupBy('month','dest')
# Average departure delay by month and destination
by_month_dest.avg("dep_delay").show()
# Standard deviation of departure delay
by_month_dest.agg(F.stddev("dep_delay")).show()
# Examine the data (.show() prints the rows itself and returns None, so no print() is needed)
airports.show()
# Rename the faa column
airports = airports.withColumnRenamed("faa","dest")
# Join the DataFrames
flights_with_airports = flights.join(airports,'dest',how='leftouter')
# Examine the new DataFrame
flights_with_airports.show()
Before we get to the pipeline, you should understand the two main kinds of objects in Spark's machine-learning library (pyspark.ml): Estimators and Transformers. An Estimator has a .fit() method that takes a DataFrame and returns a model, which is itself a Transformer. A Transformer has a .transform() method that takes a DataFrame and returns a new DataFrame. Spark modeling relies mainly on these two classes.
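The fit/transform contract can be illustrated with a toy, plain-Python sketch. MeanImputer and MeanImputerModel are hypothetical classes that only mimic the pattern on lists of dicts; they are not part of Spark. The Estimator learns a parameter in fit() and returns a model. The model's transform() returns new data rather than modifying its input.

```python
class MeanImputerModel:
    """Transformer: holds the fitted mean and fills in missing values."""

    def __init__(self, column, mean):
        self.column = column
        self.mean = mean

    def transform(self, rows):
        # Build new rows instead of modifying the input (like immutable DataFrames)
        return [
            {**row, self.column: self.mean if row[self.column] is None else row[self.column]}
            for row in rows
        ]


class MeanImputer:
    """Estimator: learns a parameter (the mean) from data and returns a model."""

    def __init__(self, column):
        self.column = column

    def fit(self, rows):
        values = [row[self.column] for row in rows if row[self.column] is not None]
        return MeanImputerModel(self.column, sum(values) / len(values))


rows = [{"delay": 10}, {"delay": None}, {"delay": 30}]
model = MeanImputer("delay").fit(rows)  # Estimator.fit() returns a model
print(model.transform(rows))            # [{'delay': 10}, {'delay': 20.0}, {'delay': 30}]
```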
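First, preprocess the data. Spark's modeling algorithms require numerical data, so the relevant columns are cast to numeric types. The code below assumes a planes DataFrame has been loaded alongside flights.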
# Rename year column
planes = planes.withColumnRenamed("year", "plane_year")
# Join the DataFrames
model_data = flights.join(planes, on='tailnum', how="leftouter")
# Cast the columns to integers
model_data = model_data.withColumn("arr_delay", model_data.arr_delay.cast('integer'))
model_data = model_data.withColumn("air_time", model_data.air_time.cast('integer'))
model_data = model_data.withColumn("month", model_data.month.cast('integer'))
model_data = model_data.withColumn("plane_year", model_data.plane_year.cast('integer'))
# Create the column plane_age
model_data = model_data.withColumn("plane_age", model_data.year - model_data.plane_year)
#Create is_late
model_data = model_data.withColumn("is_late", model_data.arr_delay > 0)
# Convert to an integer
model_data = model_data.withColumn("label", model_data.is_late.cast('integer'))
# Remove missing values
model_data = model_data.filter("arr_delay is not NULL and dep_delay is not NULL and air_time is not NULL and plane_year is not NULL")
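Next, process the categorical data. StringIndexer maps each distinct string to a numeric index, and OneHotEncoder turns that index into a binary vector.
# Import the feature transformers
from pyspark.ml.feature import StringIndexer, OneHotEncoder
# Create a StringIndexer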
carr_indexer = StringIndexer(inputCol="carrier", outputCol="carrier_index")
# Create a OneHotEncoder
carr_encoder = OneHotEncoder(inputCol="carrier_index", outputCol="carrier_fact")
# Create a StringIndexer
dest_indexer = StringIndexer(inputCol="dest", outputCol="dest_index")
# Create a OneHotEncoder
dest_encoder = OneHotEncoder(inputCol="dest_index", outputCol="dest_fact")
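The next step is to combine all the feature columns into a single vector column, because Spark's models expect all features in one column (named features by default).
# Import and make a VectorAssembler
from pyspark.ml.feature import VectorAssembler
vec_assembler = VectorAssembler(inputCols=["month", "air_time", "carrier_fact", "dest_fact", "plane_age"], outputCol="features")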
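Conceptually, VectorAssembler concatenates its input columns into one feature vector, in the order given. A one-hot column contributes one element per slot. Here is a plain-Python sketch of that flattening on a hypothetical row. Spark's real output is a DenseVector or SparseVector, not a list.

```python
def assemble(row, input_cols):
    """Flatten scalar and vector (list) columns into one feature list, in input_cols order."""
    features = []
    for col in input_cols:
        value = row[col]
        if isinstance(value, (list, tuple)):
            features.extend(float(v) for v in value)  # vector column: append every element
        else:
            features.append(float(value))            # scalar column: append one element
    return features


row = {"month": 5, "air_time": 120, "carrier_fact": [0, 1, 0], "plane_age": 12}
print(assemble(row, ["month", "air_time", "carrier_fact", "plane_age"]))
# [5.0, 120.0, 0.0, 1.0, 0.0, 12.0]
```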
Next, use Pipeline to combine all Transformers and Estimators.
# Import Pipeline
from pyspark.ml import Pipeline
# Make the pipeline
flights_pipe = Pipeline(stages=[dest_indexer, dest_encoder, carr_indexer, carr_encoder, vec_assembler])
Split the data only after these transformations. StringIndexer assigns indices based on the data it is fitted on, so fitting it separately on different subsets can give the same string different indices.
# Fit and transform the data
piped_data = flights_pipe.fit(model_data).transform(model_data)
# Split the data into training and test sets
training, test = piped_data.randomSplit([.6, .4])
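Here is why the order matters. The index a string receives depends on how often it occurs in the data the indexer is fitted on. This plain-Python sketch imitates frequency-based indexing (it is not Spark's code, and the carrier lists are hypothetical). It shows "DL" receiving a different index on each subset.

```python
from collections import Counter


def fit_string_index(values):
    """Frequency-descending index: most frequent string -> 0 (ties broken alphabetically)."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda s: (-counts[s], s))
    return {s: i for i, s in enumerate(ordered)}


subset_a = ["DL", "DL", "AA", "UA"]
subset_b = ["AA", "AA", "DL", "UA"]
print(fit_string_index(subset_a))  # {'DL': 0, 'AA': 1, 'UA': 2}
print(fit_string_index(subset_b))  # {'AA': 0, 'DL': 1, 'UA': 2}
```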
OK, finally the fun part!
# Import LogisticRegression
from pyspark.ml.classification import LogisticRegression
# Create a LogisticRegression Estimator
lr = LogisticRegression()
# Import the evaluation submodule
import pyspark.ml.evaluation as evals
# Create a BinaryClassificationEvaluator
evaluator = evals.BinaryClassificationEvaluator(metricName="areaUnderROC")
# Import the tuning submodule
import pyspark.ml.tuning as tune
# Create the parameter grid
grid = tune.ParamGridBuilder()
# Add the hyperparameters
import numpy as np
grid = grid.addGrid(lr.regParam, np.arange(0, .1, .01))  ## regularization strength
grid = grid.addGrid(lr.elasticNetParam, [0, 1])  ## mixing: 0 = Ridge (L2), 1 = Lasso (L1)
# Build the grid
grid = grid.build()
# Create the CrossValidator
cv = tune.CrossValidator(estimator=lr,
estimatorParamMaps=grid,
evaluator=evaluator
)
# Fit cross validation models
models = cv.fit(training)
# Extract the best model
best_lr = models.bestModel
# Use the model to predict the test set
test_results = best_lr.transform(test)
# Evaluate the predictions
print(evaluator.evaluate(test_results))
# Import the PySpark module
from pyspark.sql import SparkSession
# Create SparkSession object
# local[*] runs Spark locally on all available cores; to connect to a remote cluster,
# use .master('spark://<IP address | DNS name>:<port>') instead
spark = SparkSession.builder \
    .master('local[*]') \
    .appName('test') \
    .getOrCreate()
# What version of Spark?
print(spark.version)
# Stop the SparkSession and its underlying SparkContext
spark.stop()
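Selected DataFrame methods:
- count() - number of rows
- show() - display the first rows
- printSchema() - show column names and types
Selected attributes:
- dtypes - list of (column name, type) pairs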
Reading data from CSV: cars = spark.read.csv("cars.csv", header=True)
Optional arguments:
- header - is the first row a header?
- sep - field separator
- schema - explicit column data types
- inferSchema - deduce column data types from the data?
- nullValue - placeholder for missing data
With only header=True, every column is read as a string. A better call is cars = spark.read.csv("cars.csv", header=True, inferSchema=True, nullValue='NA')
(It is always good to define explicitly how missing values are encoded.)
Check cars.dtypes to see the data type of each column. If missing data is stored as the string 'NA' and nullValue is not set, a numeric column containing 'NA' will be wrongly interpreted as a string column.
When inference still gets the types wrong, specify the schema by hand.
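Here is a simplified plain-Python sketch of the inference problem. It is not Spark's actual inference logic, which handles more types. A single 'NA' cell turns an otherwise integer column into strings, unless 'NA' is declared as the missing-value placeholder. That is the role of nullValue='NA'.

```python
def infer_column_type(values, null_value=None):
    """Guess 'int', 'double' or 'string' for a column of raw CSV strings.

    Cells equal to null_value are treated as missing and ignored.
    """
    seen = "int"
    for raw in values:
        if null_value is not None and raw == null_value:
            continue  # missing value: does not affect the type
        try:
            int(raw)
            continue
        except ValueError:
            pass
        try:
            float(raw)
            seen = "double"
        except ValueError:
            return "string"  # one non-numeric cell makes the whole column a string
    return seen


cyl = ["4", "6", "NA", "8"]
print(infer_column_type(cyl))                   # string
print(infer_column_type(cyl, null_value="NA"))  # int
```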
# Read data from CSV file
flights = spark.read.csv('flights.csv',
sep=',',
header=True,
inferSchema=True,
nullValue='NA')
# Get number of records
print("The data contain %d records." % flights.count())
# View the first five records
flights.show(5)
# Check column data types
flights.dtypes
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
# Specify column names and types
schema = StructType([
StructField("id", IntegerType()),
StructField("text", StringType()),
StructField("label", IntegerType())
])
# Load data from a delimited file
sms = spark.read.csv('sms.csv', sep=';', header=False, schema=schema)
# Print schema of DataFrame
sms.printSchema()
Data selection
# Either drop the columns you don't want
cars = cars.drop('maker','model')
# ... or select the columns you do want
cars = cars.select("origin", 'type', 'cyl')
# Filter out missing values
## count the records with missing cylinders
cars.filter('cyl is NULL').count()
## drop records with missing values in the cylinders column
cars = cars.filter('cyl IS NOT NULL')
## drop records with any missing data
cars = cars.dropna()
Index categorical data.
from pyspark.ml.feature import StringIndexer
# Create an indexer that maps the strings in 'type' to numeric indices
indexer = StringIndexer(inputCol='type', outputCol='type_idx')
# Fit the indexer to the data (learns the string-to-index mapping)
indexer = indexer.fit(cars)
# Create column with index values
cars = indexer.transform(cars)
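By default (stringOrderType='frequencyDesc'), the most frequent string gets index 0 and the least frequent string gets the largest index. Use the stringOrderType parameter to change the order. The other options are 'frequencyAsc', 'alphabetDesc' and 'alphabetAsc'.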
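The four orderings can be imitated in plain Python. This sketch is not Spark's implementation. It uses hypothetical car types, breaks frequency ties alphabetically, and returns the string-to-index mapping for each option.

```python
from collections import Counter


def string_index(values, order="frequencyDesc"):
    """Map each distinct string to an index for one of the four ordering options."""
    counts = Counter(values)
    if order == "frequencyDesc":
        ordered = sorted(counts, key=lambda s: (-counts[s], s))
    elif order == "frequencyAsc":
        ordered = sorted(counts, key=lambda s: (counts[s], s))
    elif order == "alphabetAsc":
        ordered = sorted(counts)
    elif order == "alphabetDesc":
        ordered = sorted(counts, reverse=True)
    else:
        raise ValueError(f"unknown order: {order}")
    return {s: i for i, s in enumerate(ordered)}


types = ["Midsize", "Small", "Midsize", "Large", "Midsize", "Small"]
print(string_index(types))                 # {'Midsize': 0, 'Small': 1, 'Large': 2}
print(string_index(types, "alphabetAsc"))  # {'Large': 0, 'Midsize': 1, 'Small': 2}
```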
Suppose you have loaded a flights dataset with columns such as flight, delay and mile:
# Remove the 'flight' column
flights = flights.drop('flight')
# Number of records with missing 'delay' values
flights.filter('delay IS NULL').count()
# Remove records with missing 'delay' values
flights = flights.filter('delay IS NOT NULL')
# Remove records with missing values in any column and get the number of remaining rows
flights = flights.dropna()
print(flights.count())
# Import Spark's round function (note: this shadows Python's built-in round)
from pyspark.sql.functions import round
# Convert 'mile' to 'km' and drop 'mile' column
flights_km = flights.withColumn('km', round(flights.mile * 1.60934, 0)) \
.drop('mile')
# Create 'label' column indicating whether flight delayed (1) or not (0)
flights_km = flights_km.withColumn('label', (flights_km.delay>=15).cast('integer'))
# Check first five records
flights_km.show(5)
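# Split into training and testing sets in an 80:20 ratio (17 is the random seed).
# Use flights_km, which holds the 'label' column; this also assumes a 'features'
# vector column has been assembled with VectorAssembler (step not shown).
flights_train, flights_test = flights_km.randomSplit([0.8, 0.2], seed=17)
# Check that the training set has around 80% of the records
training_ratio = flights_train.count() / flights_km.count()
print(training_ratio)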
# Import the Decision Tree Classifier class
from pyspark.ml.classification import DecisionTreeClassifier
# Create a classifier object and fit to the training data
tree = DecisionTreeClassifier()
tree_model = tree.fit(flights_train)
# Create predictions for the testing data and take a look at the predictions
prediction = tree_model.transform(flights_test)
prediction.select('label', 'prediction', 'probability').show(5, False)
# Create a confusion matrix
prediction.groupBy('label', 'prediction').count().show()
# Calculate the elements of the confusion matrix
TN = prediction.filter('prediction = 0 AND label = prediction').count()
TP = prediction.filter('prediction=1 AND label = prediction').count()
FN = prediction.filter('prediction=0 AND label!=prediction').count()
FP = prediction.filter('prediction=1 AND label!=prediction').count()
# Accuracy measures the proportion of correct predictions
accuracy = (TN+TP)/(FN+FP+TN+TP)
print(accuracy)
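Here are the same counts and ratios in plain Python, using hypothetical labels and predictions, so the formulas can be checked by hand. Precision and recall are included as well, since they are built from the same four counts.

```python
def confusion_counts(labels, predictions):
    """Count TP, TN, FP, FN for binary labels/predictions coded as 0/1."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if p == 1 and y == 1)
    tn = sum(1 for y, p in pairs if p == 0 and y == 0)
    fp = sum(1 for y, p in pairs if p == 1 and y == 0)
    fn = sum(1 for y, p in pairs if p == 0 and y == 1)
    return tp, tn, fp, fn


labels      = [1, 0, 1, 1, 0, 0, 1, 0]
predictions = [1, 1, 0, 1, 1, 0, 1, 0]
tp, tn, fp, fn = confusion_counts(labels, predictions)
accuracy = (tp + tn) / (tp + tn + fp + fn)  # share of correct predictions
precision = tp / (tp + fp)                  # of predicted delays, how many were real
recall = tp / (tp + fn)                     # of real delays, how many were caught
print(tp, tn, fp, fn)               # 3 2 2 1
print(accuracy, precision, recall)  # 0.625 0.6 0.75
```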
# Import the logistic regression class
from pyspark.ml.classification import LogisticRegression
# Create a classifier object and train on training data
logistic = LogisticRegression().fit(flights_train)
# Create predictions for the testing data and show confusion matrix
prediction = logistic.transform(flights_test)
prediction.groupBy('label', 'prediction').count().show()
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator
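# Calculate precision and recall, recomputing the counts for the logistic predictions
TP = prediction.filter('prediction = 1 AND label = prediction').count()
FP = prediction.filter('prediction = 1 AND label != prediction').count()
FN = prediction.filter('prediction = 0 AND label != prediction').count()
precision = TP / (TP + FP)
recall = TP / (TP + FN)
print('precision = {:.2f}\nrecall = {:.2f}'.format(precision, recall))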
# Find weighted precision
multi_evaluator = MulticlassClassificationEvaluator()
weighted_precision = multi_evaluator.evaluate(prediction, {multi_evaluator.metricName: "weightedPrecision"})
# Find AUC
binary_evaluator = BinaryClassificationEvaluator()
auc = binary_evaluator.evaluate(prediction, {binary_evaluator.metricName:"areaUnderROC"})
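AUC can be read as the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one. Below is a plain-Python sketch of that pairwise definition using hypothetical scores. Spark computes the ROC curve from the rawPrediction column, so its numbers can differ slightly in edge cases, but the quantity is the same idea.

```python
def auc_from_scores(labels, scores):
    """Area under the ROC curve via pairwise comparison (Mann-Whitney U).

    Counts, over all (positive, negative) pairs, how often the positive gets the
    higher score; ties count as half.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


labels = [1, 1, 0, 1, 0, 0]
scores = [0.9, 0.7, 0.6, 0.4, 0.3, 0.1]  # predicted probability of class 1
print(auc_from_scores(labels, scores))   # 8/9, about 0.889
```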
# Import the necessary functions
from pyspark.sql.functions import regexp_replace
from pyspark.ml.feature import Tokenizer
# Replace punctuation and digits with spaces
wrangled = sms.withColumn('text', regexp_replace(sms.text, '[_():;,.!?\\-]', ' '))
wrangled = wrangled.withColumn('text', regexp_replace(wrangled.text, '[0-9]', ' '))
# Merge multiple spaces
wrangled = wrangled.withColumn('text', regexp_replace(wrangled.text, ' +', ' '))
# Split the text into words
wrangled = Tokenizer(inputCol='text', outputCol='words').transform(wrangled)
wrangled.show(4, truncate=False)
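For comparison, here is roughly the same cleaning and tokenization in plain Python with the re module, applied to a hypothetical message. Spark's Tokenizer lowercases the text and splits it on whitespace. Python's str.split() also drops empty tokens, so edge cases such as leading spaces may differ.

```python
import re


def clean_and_tokenize(text):
    """Plain-Python analogue of the regexp_replace + Tokenizer steps."""
    text = re.sub(r"[_():;,.!?\-]", " ", text)  # punctuation -> space
    text = re.sub(r"[0-9]", " ", text)          # digits -> space
    text = re.sub(r" +", " ", text)             # collapse runs of spaces
    return text.lower().split()                 # lowercase, split on whitespace


print(clean_and_tokenize("Call 0800 NOW!!! Win a prize, today."))
# ['call', 'now', 'win', 'a', 'prize', 'today']
```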
from pyspark.ml.feature import StopWordsRemover, HashingTF, IDF
# Remove stop words (applied to wrangled, which has the 'words' column)
wrangled = StopWordsRemover(inputCol='words', outputCol='terms')\
.transform(wrangled)
# Apply the hashing trick
wrangled = HashingTF(inputCol='terms', outputCol='hash', numFeatures=1024)\
.transform(wrangled)
# Convert hashed symbols to TF-IDF
tf_idf = IDF(inputCol='hash', outputCol='features')\
.fit(wrangled).transform(wrangled)
tf_idf.select('terms', 'features').show(4, truncate=False)
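Here is a plain-Python sketch of the hashing trick followed by TF-IDF weighting, applied to three hypothetical tokenized messages. It swaps in zlib.crc32 for the MurmurHash3 hash that Spark's HashingTF uses, so bucket numbers differ from Spark's. The mechanics are the same, though. Terms are hashed into a fixed number of buckets, and different terms can collide. Each count is then multiplied by the smoothed inverse document frequency log((m + 1) / (df + 1)) that Spark's IDF documents, which is zero for a bucket present in every document.

```python
import math
import zlib
from collections import Counter


def hashing_tf(terms, num_features):
    """Hashing trick: count each term in bucket hash(term) % num_features.

    zlib.crc32 is a stand-in hash; Spark's HashingTF uses MurmurHash3.
    """
    return dict(Counter(zlib.crc32(t.encode("utf-8")) % num_features for t in terms))


def idf_weights(docs_tf, num_features):
    """Smoothed IDF per bucket: log((m + 1) / (df + 1)), m = number of documents."""
    m = len(docs_tf)
    df = Counter(index for tf in docs_tf for index in tf)  # documents containing each bucket
    return [math.log((m + 1) / (df[i] + 1)) for i in range(num_features)]


def tf_idf(docs_terms, num_features=1024):
    """Return one sparse {bucket: tf * idf} dict per document."""
    docs_tf = [hashing_tf(terms, num_features) for terms in docs_terms]
    idf = idf_weights(docs_tf, num_features)
    return [{i: count * idf[i] for i, count in tf.items()} for tf in docs_tf]


docs = [["free", "prize", "call"], ["call", "home", "later"], ["free", "entry", "free"]]
for vector in tf_idf(docs):
    print(vector)
```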
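Now the text is represented as numerical features, so it can be used to train a model. The code below assumes the tf_idf DataFrame has been renamed to sms (sms = tf_idf), so that sms holds both the features and label columns.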
# Split the data into training and testing sets
sms_train, sms_test = sms.randomSplit([0.8,0.2], 13)
# Fit a Logistic Regression model to the training data
logistic = LogisticRegression(regParam=0.2).fit(sms_train)
# Make predictions on the testing data
prediction = logistic.transform(sms_test)
# Create a confusion matrix, comparing predictions to known labels
prediction.groupBy('label', 'prediction').count().show()
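Returning to categorical features: it is not sensible to use the index values from StringIndexer in numerical calculations (index 2 is not "twice" index 1). That is why we need one-hot encoding.
# Import the one hot encoder class
# (OneHotEncoderEstimator is the Spark 2.3/2.4 name; in Spark 3.0+ the class is called OneHotEncoder)
from pyspark.ml.feature import OneHotEncoderEstimator
# Create an instance of the one hot encoder
onehot = OneHotEncoderEstimator(inputCols=['org_idx'], outputCols=['org_dummy'])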
# Apply the one hot encoder to the flights data
onehot = onehot.fit(flights)
flights_onehot = onehot.transform(flights)
# Check the results
flights_onehot.select('org', 'org_idx', 'org_dummy').distinct().sort('org_idx').show()
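Here is a dense, plain-Python sketch of what the encoder produces for three hypothetical org indices. It mirrors the encoder's default dropLast=True, under which the last category is encoded as an all-zero vector. This avoids perfectly collinear columns. Spark itself stores the result as a SparseVector.

```python
def one_hot(index, num_categories, drop_last=True):
    """One-hot encode a category index.

    With drop_last=True the vector has num_categories - 1 slots and the
    last category is encoded as all zeros.
    """
    size = num_categories - 1 if drop_last else num_categories
    vector = [0.0] * size
    if index < size:
        vector[index] = 1.0
    return vector


# Three airports indexed 0, 1, 2
for idx in range(3):
    print(idx, one_hot(idx, 3))
# 0 [1.0, 0.0]
# 1 [0.0, 1.0]
# 2 [0.0, 0.0]
```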
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
# Create a regression object and train on training data
regression = LinearRegression(labelCol='duration').fit(flights_train)
# Create predictions for the testing data and take a look at the predictions
predictions = regression.transform(flights_test)
predictions.select('duration', 'prediction').show(5, False)
# Calculate the RMSE
RegressionEvaluator(labelCol='duration').evaluate(predictions)
# Intercept (average minutes on ground)
inter = regression.intercept
print(inter)
# Coefficients
coefs = regression.coefficients
print(coefs)
# Average minutes per km (assumes km is the first feature in the assembled vector)
minutes_per_km = coefs[0]
print(minutes_per_km)
# Average speed in km per hour
avg_speed = 60/(minutes_per_km)
print(avg_speed)
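Bucketing converts continuous values into discrete categories. With Bucketizer, bucket i holds values in [splits[i], splits[i+1]). The last bucket also includes its upper bound.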
from pyspark.ml.feature import Bucketizer, OneHotEncoderEstimator
# Create buckets at 3 hour intervals through the day
buckets = Bucketizer(splits=[3*i for i in range(0,9)], inputCol = 'depart', outputCol = 'depart_bucket')
# Bucket the departure times
bucketed = buckets.transform(flights)
bucketed.select('depart','depart_bucket').show(5)
# Create a one-hot encoder
onehot = OneHotEncoderEstimator(inputCols=['depart_bucket'],outputCols=['depart_dummy'])
# One-hot encode the bucketed departure times
flights_onehot = onehot.fit(bucketed).transform(bucketed)
flights_onehot.select('depart','depart_bucket','depart_dummy').show(5)
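The bucket assignment rule can be sketched in plain Python with bisect, using the same splits as above (0, 3, ..., 24). This imitates Bucketizer's rule but is not its code. Spark returns the bucket index as a double and raises an error for values outside the splits.

```python
import bisect


def bucketize(value, splits):
    """Return the bucket index for value.

    Bucket i covers [splits[i], splits[i + 1]); the last bucket also includes
    its upper bound. Values outside [splits[0], splits[-1]] raise an error.
    """
    if value < splits[0] or value > splits[-1]:
        raise ValueError(f"{value} outside [{splits[0]}, {splits[-1]}]")
    if value == splits[-1]:
        return len(splits) - 2
    return bisect.bisect_right(splits, value) - 1


splits = [3 * i for i in range(0, 9)]  # [0, 3, 6, ..., 24]
print([bucketize(d, splits) for d in [0.0, 2.99, 3.0, 13.5, 24.0]])  # [0, 0, 1, 4, 7]
```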
After this, the one-hot encoded buckets can be used as features, like any other discrete feature.
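Regularization adds a penalty on the coefficients to the loss function. Lasso penalizes the absolute values of the coefficients (L1 penalty). Ridge penalizes the squares of the coefficients (L2 penalty).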
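In Spark's LinearRegression, the two penalties are blended. regParam plays the role of λ and elasticNetParam plays the role of α, so elasticNetParam=1 gives Lasso and elasticNetParam=0 gives Ridge. Here is a sketch of the elastic-net objective in that parameterization, with w the coefficients, b the intercept and n training rows:

$$
\min_{w,\,b}\; \frac{1}{2n}\sum_{i=1}^{n}\left(w^{\top}x_i + b - y_i\right)^2 \;+\; \lambda\left(\alpha\,\lVert w\rVert_1 + \frac{1-\alpha}{2}\,\lVert w\rVert_2^2\right)
$$

The L1 term can push coefficients exactly to zero, which is why the Lasso example counts zero coefficients. The L2 term only shrinks them.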
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
# Fit a Lasso model to the training data: elasticNetParam is α (α = 1 means a pure L1 penalty) and regParam is the penalty strength
regression = LinearRegression(labelCol='duration', regParam=1, elasticNetParam=1).fit(flights_train)
# Calculate the RMSE on testing data
rmse = RegressionEvaluator(labelCol='duration').evaluate(regression.transform(flights_test))
print("The test RMSE is", rmse)
# Look at the model coefficients
coeffs = regression.coefficients
print(coeffs)
# Number of zero coefficients
zero_coeff = sum([beta==0 for beta in regression.coefficients])
print("Number of coefficients equal to 0:", zero_coeff)
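# Import the feature transformers
from pyspark.ml.feature import StringIndexer, OneHotEncoderEstimator, VectorAssembler
# Convert categorical strings to index values
indexer = StringIndexer(inputCol='org', outputCol='org_idx')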
# One-hot encode index values
onehot = OneHotEncoderEstimator(
inputCols = ['org_idx','dow'],
outputCols = ['org_dummy','dow_dummy']
)
# Assemble predictors into a single column
assembler = VectorAssembler(inputCols=['km','org_dummy','dow_dummy'], outputCol='features')
# A linear regression object
regression = LinearRegression(labelCol='duration')
# Import class for creating a pipeline
from pyspark.ml import Pipeline
# Construct a pipeline
pipeline = Pipeline(stages=[indexer,onehot,assembler,regression])
# Train the pipeline on the training data
pipeline = pipeline.fit(flights_train)
# Make predictions on the testing data
predictions = pipeline.transform(flights_test)
SMS spam pipeline:
from pyspark.ml.feature import Tokenizer, StopWordsRemover, HashingTF, IDF
# Break text into tokens at non-word characters
tokenizer = Tokenizer(inputCol='text', outputCol='words')
# Remove stop words
remover = StopWordsRemover(inputCol=tokenizer.getOutputCol(), outputCol='terms')
# Apply the hashing trick and transform to TF-IDF
hasher = HashingTF(inputCol=remover.getOutputCol(), outputCol="hash")
idf = IDF(inputCol=hasher.getOutputCol(), outputCol="features")
# Create a logistic regression object and add everything to a pipeline
logistic = LogisticRegression()
pipeline = Pipeline(stages=[tokenizer, remover, hasher, idf, logistic])
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
# Create an empty parameter grid
params = ParamGridBuilder().build()
# Create objects for building and evaluating a regression model
regression = LinearRegression(labelCol='duration')
evaluator = RegressionEvaluator(labelCol='duration')
# Create a cross validator
cv = CrossValidator(estimator=regression, estimatorParamMaps=params, evaluator=evaluator, numFolds=5 )
# Train and test model on multiple folds of the training data
cv = cv.fit(flights_train)
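The idea behind CrossValidator can be sketched in plain Python. Split the rows into k folds, fit on k - 1 of them, score on the held-out fold, and average the k scores. The "model" here is a hypothetical baseline that predicts the training mean, scored with RMSE. Unlike this sketch, Spark assigns rows to folds at random and repeats the loop for every parameter combination. It then refits the best combination on the whole training set.

```python
import math


def k_fold_indices(n_rows, k):
    """Yield (train_idx, valid_idx) pairs; each row is used for validation exactly once."""
    fold_sizes = [n_rows // k + (1 if i < n_rows % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        valid = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_rows))
        yield train, valid
        start += size


def cross_validate_mean_model(y, k):
    """Average validation RMSE of a 'predict the training mean' model over k folds."""
    scores = []
    for train, valid in k_fold_indices(len(y), k):
        mean = sum(y[i] for i in train) / len(train)                 # "fit" on k - 1 folds
        mse = sum((y[i] - mean) ** 2 for i in valid) / len(valid)  # score on held-out fold
        scores.append(math.sqrt(mse))
    return sum(scores) / len(scores)


durations = [60, 75, 90, 120, 150, 45, 80, 100, 110, 95]  # hypothetical flight durations
print(cross_validate_mean_model(durations, k=5))
```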
# Create an indexer for the org field
indexer = StringIndexer(inputCol='org', outputCol='org_idx')
# Create a one-hot encoder for the indexed org field
onehot = OneHotEncoderEstimator(inputCols=['org_idx'], outputCols=['org_dummy'])
# Assemble the km and one-hot encoded fields
assembler = VectorAssembler(inputCols=['km','org_dummy'], outputCol='features')
# Create a pipeline and cross-validator.
pipeline = Pipeline(stages=[indexer, onehot, assembler, regression])
cv = CrossValidator(estimator=pipeline,
estimatorParamMaps=params,
evaluator=evaluator)
# Create parameter grid
params = ParamGridBuilder()
# Add grids for two parameters
params = params.addGrid(regression.regParam,[0.01,0.1,1.0,10.0])\
.addGrid(regression.elasticNetParam, [0.0,0.5,1.0])
# Build the parameter grid
params = params.build()
print('Number of models to be tested: ', len(params))
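# Create cross-validator
cv = CrossValidator(estimator=pipeline, estimatorParamMaps=params, evaluator=evaluator, numFolds=5)
# Train the cross-validator (one model per parameter combination per fold)
cv = cv.fit(flights_train)
# Get the best model from cross validation
best_model = cv.bestModel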
# Look at the stages in the best model
print(best_model.stages)
# Get the parameters for the LinearRegression object in the best model
best_model.stages[3].extractParamMap()
# Generate predictions on testing data using the best model then calculate RMSE
predictions = best_model.transform(flights_test)
evaluator.evaluate(predictions)
SMS spam optimization
hasher = HashingTF()
logistic = LogisticRegression()
# Create parameter grid
params = ParamGridBuilder()
# Add grid for hashing trick parameters
params = params.addGrid(hasher.numFeatures, [1024,4096,16384]) \
.addGrid(hasher.binary, [True,False])
# Add grid for logistic regression parameters
params = params.addGrid(logistic.regParam, [0.01,0.1,1.0,10.0]) \
.addGrid(logistic.elasticNetParam, [0.0,0.5,1.0])
# Build parameter grid
params = params.build()
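ParamGridBuilder expands the grid into every combination of the listed values, so the number of candidate models is the product of the list lengths. Here is a plain-Python sketch of that expansion. The dictionary keys are plain labels standing in for Spark Param objects.

```python
from itertools import product

grid = {
    "numFeatures": [1024, 4096, 16384],
    "binary": [True, False],
    "regParam": [0.01, 0.1, 1.0, 10.0],
    "elasticNetParam": [0.0, 0.5, 1.0],
}

# Every combination of values is one candidate model (a Cartesian product)
param_maps = [dict(zip(grid, values)) for values in product(*grid.values())]
print(len(param_maps))  # 72
print(param_maps[0])    # {'numFeatures': 1024, 'binary': True, 'regParam': 0.01, 'elasticNetParam': 0.0}

# With numFolds=5, cross-validation fits 72 * 5 = 360 models, plus one final refit
print(len(param_maps) * 5)  # 360
```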
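Combine Models
Random forests train many trees independently (so they can be trained in parallel), each on a random sample of the rows and features, and then combine their predictions. Gradient-boosted trees train trees in series, with each new tree trying to correct the errors of the ensemble built so far.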
GBT:
# Import the classes required
from pyspark.ml.classification import DecisionTreeClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
# Create model objects and train on training data
tree = DecisionTreeClassifier().fit(flights_train)
gbt = GBTClassifier().fit(flights_train)
# Compare AUC on testing data
evaluator = BinaryClassificationEvaluator()
evaluator.evaluate(tree.transform(flights_test))
evaluator.evaluate(gbt.transform(flights_test))
# Find the number of trees and the relative importance of features
print(len(gbt.trees))
print(gbt.featureImportances)
RF:
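# Import the classes required
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import BinaryClassificationEvaluator
# Create a random forest classifier
forest = RandomForestClassifier()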
# Create a parameter grid
params = ParamGridBuilder() \
.addGrid(forest.featureSubsetStrategy, ['all', 'onethird', 'sqrt', 'log2']) \
.addGrid(forest.maxDepth, [2, 5, 10]) \
.build()
# Create a binary classification evaluator
evaluator = BinaryClassificationEvaluator()
# Create a cross-validator and train it on the training data
cv = CrossValidator(estimator=forest, estimatorParamMaps=params, evaluator=evaluator, numFolds=5)
cv = cv.fit(flights_train)
# Average AUC for each parameter combination in grid
avg_auc = cv.avgMetrics
# Average AUC for the best model
best_model_auc = max(avg_auc)
# What's the optimal parameter value?
opt_max_depth = cv.bestModel.explainParam('maxDepth')
opt_feat_substrat = cv.bestModel.explainParam('featureSubsetStrategy')
# AUC for best model on testing data
best_auc = evaluator.evaluate(cv.transform(flights_test))