Streaming KMeans [MLLIB][SPARK-3254] #2942
@@ -34,7 +34,7 @@ a given dataset, the algorithm returns the best clustering result).
* *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
* *epsilon* determines the distance threshold within which we consider k-means to have converged.

-## Examples
+### Examples

<div class="codetabs">
<div data-lang="scala" markdown="1">
@@ -153,3 +153,75 @@ provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
section of the Spark
Quick Start guide. Be sure to also include *spark-mllib* to your build file as
a dependency.

## Streaming clustering

When data arrive in a stream, we may want to estimate clusters dynamically, updating them as new
data arrive. MLlib provides support for streaming k-means clustering, with parameters to control the
decay (or "forgetfulness") of the estimates. The algorithm uses a generalization of the mini-batch
k-means update rule. For each batch of data, we assign all points to their nearest cluster, compute
new cluster centers, and then update each cluster using:
`\begin{equation}
c_{t+1} = \frac{c_t n_t \alpha + x_t m_t}{n_t \alpha + m_t}
\end{equation}`
`\begin{equation}
n_{t+1} = n_t + m_t
\end{equation}`

where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned to
the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` is the
number of points added to the cluster in the current batch. The decay factor `$\alpha$` can be used
to ignore the past: with `$\alpha$=1` all data will be used from the beginning; with `$\alpha$=0`
only the most recent data will be used. This is analogous to an exponentially-weighted moving
average.
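To make the update rule concrete, here is a minimal sketch of it for a single cluster in plain
Scala. This is not the MLlib API; the names (`ClusterState`, `updateCluster`) are illustrative only.

{% highlight scala %}
// State for one cluster: its current center c_t and the weight n_t
// (number of points assigned to it so far).
case class ClusterState(center: Array[Double], count: Double)

// Apply one batch update: batchMean is x_t, batchCount is m_t,
// alpha is the decay factor.
def updateCluster(
    state: ClusterState,
    batchMean: Array[Double],
    batchCount: Double,
    alpha: Double): ClusterState = {
  val discounted = state.count * alpha       // n_t * alpha
  val total = discounted + batchCount        // n_t * alpha + m_t
  // c_{t+1} = (c_t * n_t * alpha + x_t * m_t) / (n_t * alpha + m_t)
  val newCenter = state.center.zip(batchMean).map { case (c, x) =>
    (c * discounted + x * batchCount) / total
  }
  // n_{t+1} = n_t + m_t
  ClusterState(newCenter, state.count + batchCount)
}
{% endhighlight %}

For instance, with `alpha = 0` the new center is exactly the batch mean (only the most recent data
counts), while with `alpha = 1` it is the weighted mean of the old center and the batch mean (all
data counts equally).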

### Examples

This example shows how to estimate clusters on streaming data.

<div class="codetabs">

<div data-lang="scala" markdown="1">

First we import the necessary classes.

{% highlight scala %}

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.clustering.StreamingKMeans

{% endhighlight %}
Then we make an input stream of vectors for training, as well as one for testing.
We assume a `StreamingContext` `ssc` has already been created; see the
[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing)
for more info. For this example, we use vector data.
{% highlight scala %}

val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse)
val testData = ssc.textFileStream("/testing/data/dir").map(Vectors.parse)

{% endhighlight %}
We create a model with random clusters and specify the number of clusters to find.

{% highlight scala %}
val numDimensions = 3
val numClusters = 2
val model = new StreamingKMeans()
  .setK(numClusters)
  .setDecayFactor(1.0)
  .setRandomCenters(numDimensions)

{% endhighlight %}
Now register the streams for training and testing and start the job, printing
the predicted cluster assignments on new data points as they arrive.

{% highlight scala %}
model.trainOn(trainingData)
model.predictOn(testData).print()

ssc.start()
ssc.awaitTermination()

{% endhighlight %}
As you add new text files with data, the cluster centers will update. Each data point
should be formatted as `[x1, x2, x3]`. Anytime a text file is placed in
`/training/data/dir` the model will update; anytime a text file is placed in
`/testing/data/dir` you will see predictions.
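For instance, a file dropped into `/training/data/dir` might contain lines like the following
(hypothetical values, in the bracketed format that `Vectors.parse` accepts):

{% highlight text %}
[0.0, 0.1, 0.2]
[0.1, 0.0, 0.1]
[9.8, 10.1, 10.3]
[10.2, 9.9, 10.1]
{% endhighlight %}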
</div>

</div>

@@ -0,0 +1,75 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

/**
 * Estimate clusters on one stream of data and make predictions
 * on another stream, where the data streams arrive as text files
 * into two different directories.
 *
 * The rows of the text files must be vector data in the form
 * `[x1,x2,x3,...,xn]`
 * where n is the number of dimensions. n must be the same for train and test.
 *
 * Usage: StreamingKMeans <trainingDir> <testDir> <batchDuration> <numClusters> <numDimensions>
 *
 * To run on your local machine using the two directories `trainingDir` and `testDir`,
 * with updates every 5 seconds, 2 dimensions per data point, and 3 clusters, call:
 *    $ bin/run-example \
 *        org.apache.spark.examples.mllib.StreamingKMeans trainingDir testDir 5 3 2
 *
 * As you add text files to `trainingDir` the clusters will continuously update.
 * Anytime you add text files to `testDir`, you'll see predicted labels using the current model.
 */
object StreamingKMeans {

  def main(args: Array[String]) {

    if (args.length != 5) {
      System.err.println(
        "Usage: StreamingKMeans " +
          "<trainingDir> <testDir> <batchDuration> <numClusters> <numDimensions>")
      System.exit(1)
    }

    val conf = new SparkConf().setMaster("local").setAppName("StreamingKMeans")
    val ssc = new StreamingContext(conf, Seconds(args(2).toLong))

    val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
    val testData = ssc.textFileStream(args(1)).map(Vectors.parse)

    val model = new StreamingKMeans()
      .setK(args(3).toInt)
      .setDecayFactor(1.0)
      .setRandomCenters(args(4).toInt)

    model.trainOn(trainingData)
    model.predictOn(testData).print()

    ssc.start()
    ssc.awaitTermination()

  }

}