---
layout: global
title: Examples
type: page singular
navigation:
  weight: 4
  show: true
---

# Apache Spark examples

These examples give a quick overview of the Spark API. Spark is built on the concept of distributed datasets, which contain arbitrary Java or Python objects. You create a dataset from external data, then apply parallel operations to it. The building block of the Spark API is its RDD API. In the RDD API, there are two types of operations: *transformations*, which define a new dataset based on previous ones, and *actions*, which kick off a job to execute on a cluster. On top of Spark's RDD API, high-level APIs are provided, e.g. the DataFrame API and the Machine Learning API. These high-level APIs provide a concise way to conduct certain data operations. On this page, we show examples using the RDD API as well as examples using the high-level APIs.
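To make the distinction concrete, here is a minimal sketch assuming an existing `SparkContext` named `sc`, as in the examples below. The transformation alone schedules no work; the action is what triggers a job on the cluster.

{% highlight python %}
rdd = sc.parallelize([1, 2, 3, 4])          # distribute a local collection

squares = rdd.map(lambda x: x * x)          # transformation: lazy, returns a new RDD
total = squares.reduce(lambda a, b: a + b)  # action: runs a job and returns 30
{% endhighlight %}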

## RDD API examples

### Word count

In this example, we use a few transformations to build a dataset of (String, Int) pairs called `counts` and then save it to a file.

{% highlight python %}
text_file = sc.textFile("hdfs://...")
counts = text_file.flatMap(lambda line: line.split(" ")) \
             .map(lambda word: (word, 1)) \
             .reduceByKey(lambda a, b: a + b)
counts.saveAsTextFile("hdfs://...")
{% endhighlight %}
{% highlight scala %}
val textFile = sc.textFile("hdfs://...")
val counts = textFile.flatMap(line => line.split(" "))
                 .map(word => (word, 1))
                 .reduceByKey(_ + _)
counts.saveAsTextFile("hdfs://...")
{% endhighlight %}
{% highlight java %}
JavaRDD<String> textFile = sc.textFile("hdfs://...");
JavaPairRDD<String, Integer> counts = textFile
    .flatMap(s -> Arrays.asList(s.split(" ")).iterator())
    .mapToPair(word -> new Tuple2<>(word, 1))
    .reduceByKey((a, b) -> a + b);
counts.saveAsTextFile("hdfs://...");
{% endhighlight %}

### Pi estimation

Spark can also be used for compute-intensive tasks. This code estimates π by "throwing darts" at a circle. We pick random points in the unit square ((0, 0) to (1, 1)) and see how many fall in the unit circle. The fraction of points inside the circle should be π / 4, so we use this to get our estimate.

{% highlight python %}
def inside(p):
    x, y = random.random(), random.random()
    return x*x + y*y < 1

count = sc.parallelize(range(0, NUM_SAMPLES)) \
            .filter(inside).count()
print("Pi is roughly %f" % (4.0 * count / NUM_SAMPLES))
{% endhighlight %}

{% highlight scala %}
val count = sc.parallelize(1 to NUM_SAMPLES).filter { _ =>
  val x = math.random
  val y = math.random
  x*x + y*y < 1
}.count()
println(s"Pi is roughly ${4.0 * count / NUM_SAMPLES}")
{% endhighlight %}
{% highlight java %}
List<Integer> l = new ArrayList<>(NUM_SAMPLES);
for (int i = 0; i < NUM_SAMPLES; i++) {
  l.add(i);
}

long count = sc.parallelize(l).filter(i -> {
  double x = Math.random();
  double y = Math.random();
  return x*x + y*y < 1;
}).count();
System.out.println("Pi is roughly " + 4.0 * count / NUM_SAMPLES);
{% endhighlight %}

## DataFrame API examples

In Spark, a DataFrame is a distributed collection of data organized into named columns. Users can use the DataFrame API to perform various relational operations on both external data sources and Spark's built-in distributed collections without providing specific procedures for processing data. Also, programs based on the DataFrame API are automatically optimized by Spark's built-in optimizer, Catalyst.
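As a quick way to see Catalyst at work, `explain()` prints the plan Spark derives for a query. A minimal sketch, assuming an existing DataFrame `df` that has an `age` column (the names here are illustrative):

{% highlight python %}
# Build a query; nothing runs yet.
adults = df.filter(df["age"] >= 18).select("age")

# Catalyst rewrites the query before execution; explain() prints the
# physical plan it selected.
adults.explain()
{% endhighlight %}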

### Text search

In this example, we search through the error messages in a log file.

{% highlight python %}
textFile = sc.textFile("hdfs://...")

# Creates a DataFrame having a single column named "line"
df = textFile.map(lambda r: Row(r)).toDF(["line"])
errors = df.filter(col("line").like("%ERROR%"))
# Counts all the errors
errors.count()
# Counts errors mentioning MySQL
errors.filter(col("line").like("%MySQL%")).count()
# Fetches the MySQL errors as an array of strings
errors.filter(col("line").like("%MySQL%")).collect()
{% endhighlight %}

{% highlight scala %}
val textFile = sc.textFile("hdfs://...")

// Creates a DataFrame having a single column named "line"
val df = textFile.toDF("line")
val errors = df.filter(col("line").like("%ERROR%"))
// Counts all the errors
errors.count()
// Counts errors mentioning MySQL
errors.filter(col("line").like("%MySQL%")).count()
// Fetches the MySQL errors as an array of strings
errors.filter(col("line").like("%MySQL%")).collect()
{% endhighlight %}

{% highlight java %}
// Creates a DataFrame having a single column named "line"
JavaRDD<String> textFile = sc.textFile("hdfs://...");
JavaRDD<Row> rowRDD = textFile.map(RowFactory::create);
List<StructField> fields = Arrays.asList(
  DataTypes.createStructField("line", DataTypes.StringType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame df = sqlContext.createDataFrame(rowRDD, schema);

DataFrame errors = df.filter(col("line").like("%ERROR%"));
// Counts all the errors
errors.count();
// Counts errors mentioning MySQL
errors.filter(col("line").like("%MySQL%")).count();
// Fetches the MySQL errors as an array of strings
errors.filter(col("line").like("%MySQL%")).collect();
{% endhighlight %}

### Simple data operations

In this example, we read a table stored in a database and calculate the number of people for every age. Finally, we save the calculated result to S3 in JSON format. The example uses a simple MySQL table "people" with two columns, "name" and "age".

{% highlight python %}
# Creates a DataFrame based on a table named "people"
# stored in a MySQL database.
url = \
  "jdbc:mysql://yourIP:yourPort/test?user=yourUsername;password=yourPassword"
df = sqlContext \
  .read \
  .format("jdbc") \
  .option("url", url) \
  .option("dbtable", "people") \
  .load()

# Looks at the schema of this DataFrame.
df.printSchema()

# Counts people by age
countsByAge = df.groupBy("age").count()
countsByAge.show()

# Saves countsByAge to S3 in the JSON format.
countsByAge.write.format("json").save("s3a://...")
{% endhighlight %}

{% highlight scala %}
// Creates a DataFrame based on a table named "people"
// stored in a MySQL database.
val url =
  "jdbc:mysql://yourIP:yourPort/test?user=yourUsername;password=yourPassword"
val df = sqlContext
  .read
  .format("jdbc")
  .option("url", url)
  .option("dbtable", "people")
  .load()

// Looks at the schema of this DataFrame.
df.printSchema()

// Counts people by age
val countsByAge = df.groupBy("age").count()
countsByAge.show()

// Saves countsByAge to S3 in the JSON format.
countsByAge.write.format("json").save("s3a://...")
{% endhighlight %}

{% highlight java %}
// Creates a DataFrame based on a table named "people"
// stored in a MySQL database.
String url =
  "jdbc:mysql://yourIP:yourPort/test?user=yourUsername;password=yourPassword";
DataFrame df = sqlContext
  .read()
  .format("jdbc")
  .option("url", url)
  .option("dbtable", "people")
  .load();

// Looks at the schema of this DataFrame.
df.printSchema();

// Counts people by age
DataFrame countsByAge = df.groupBy("age").count();
countsByAge.show();

// Saves countsByAge to S3 in the JSON format.
countsByAge.write().format("json").save("s3a://...");
{% endhighlight %}

## Machine learning example

MLlib, Spark's Machine Learning (ML) library, provides many distributed ML algorithms. These algorithms cover tasks including feature extraction, classification, regression, clustering, and recommendation. MLlib also provides tools such as ML Pipelines for building workflows, CrossValidator for tuning parameters, and model persistence for saving and loading models.
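For instance, several stages can be chained into a single `Pipeline` and the fitted model saved for later reuse. A minimal sketch, assuming a training DataFrame `training` with `"text"` and `"label"` columns (the data and column names are illustrative):

{% highlight python %}
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import HashingTF, Tokenizer

# Chain text tokenization, feature hashing, and a classifier into one workflow.
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol="words", outputCol="features")
lr = LogisticRegression(maxIter=10)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

# Fitting the pipeline fits every stage in order.
model = pipeline.fit(training)

# Model persistence: save the fitted pipeline for later reuse.
model.save("hdfs://...")
{% endhighlight %}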

### Prediction with logistic regression

In this example, we take a dataset of labels and feature vectors and learn to predict the labels from the feature vectors using the Logistic Regression algorithm.

{% highlight python %}
# Every record of this DataFrame contains the label and
# features represented by a vector.
df = sqlContext.createDataFrame(data, ["label", "features"])

# Set parameters for the algorithm.
# Here, we limit the number of iterations to 10.
lr = LogisticRegression(maxIter=10)

# Fit the model to the data.
model = lr.fit(df)

# Given a dataset, predict each point's label, and show the results.
model.transform(df).show()
{% endhighlight %}

{% highlight scala %}
// Every record of this DataFrame contains the label and
// features represented by a vector.
val df = sqlContext.createDataFrame(data).toDF("label", "features")

// Set parameters for the algorithm.
// Here, we limit the number of iterations to 10.
val lr = new LogisticRegression().setMaxIter(10)

// Fit the model to the data.
val model = lr.fit(df)

// Inspect the model: get the feature weights.
val weights = model.weights

// Given a dataset, predict each point's label, and show the results.
model.transform(df).show()
{% endhighlight %}

{% highlight java %}
// Every record of this DataFrame contains the label and
// features represented by a vector.
StructType schema = new StructType(new StructField[]{
  new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
  new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
DataFrame df = jsql.createDataFrame(data, schema);

// Set parameters for the algorithm.
// Here, we limit the number of iterations to 10.
LogisticRegression lr = new LogisticRegression().setMaxIter(10);

// Fit the model to the data.
LogisticRegressionModel model = lr.fit(df);

// Inspect the model: get the feature weights.
Vector weights = model.weights();

// Given a dataset, predict each point's label, and show the results.
model.transform(df).show();
{% endhighlight %}

## Additional examples

Many additional examples are distributed with Spark: