Skip to content

Commit

Permalink
Name changes
Browse files Browse the repository at this point in the history
  • Loading branch information
freeman-lab committed Aug 1, 2014
1 parent c7d38a3 commit 14b801e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import org.apache.spark.streaming.dstream.DStream

/**
* :: DeveloperApi ::
* StreamingRegression implements methods for training
* a linear regression model on streaming data, and using it
* for prediction on streaming data.
* StreamingLinearAlgorithm implements methods for continuously
* training a generalized linear model on streaming data,
* and using it for prediction on streaming data.
*
* This class takes as type parameters a GeneralizedLinearModel,
* and a GeneralizedLinearAlgorithm, making it easy to extend to construct
Expand All @@ -34,7 +34,7 @@ import org.apache.spark.streaming.dstream.DStream
*
*/
@DeveloperApi
abstract class StreamingRegression[
abstract class StreamingLinearAlgorithm[
M <: GeneralizedLinearModel,
A <: GeneralizedLinearAlgorithm[M]] extends Logging {

Expand All @@ -45,7 +45,7 @@ abstract class StreamingRegression[
val algorithm: A

/** Return the latest model. */
def latest(): M = {
def latestModel(): M = {
model
}

Expand All @@ -58,8 +58,7 @@ abstract class StreamingRegression[
* @param data DStream containing labeled data
*/
def trainOn(data: DStream[LabeledPoint]) {
data.foreachRDD{
rdd =>
data.foreachRDD { rdd =>
model = algorithm.run(rdd, model.weights)
logInfo("Model updated")
logInfo("Current model: weights, %s".format(model.weights.toString))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class StreamingLinearRegressionWithSGD private (
private var numIterations: Int,
private var miniBatchFraction: Double,
private var numFeatures: Int)
extends StreamingRegression[LinearRegressionModel, LinearRegressionWithSGD] with Serializable {
extends StreamingLinearAlgorithm[LinearRegressionModel, LinearRegressionWithSGD] with Serializable {

val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
Utils.deleteRecursively(testDir)

// check accuracy of final parameter estimates
assertEqual(model.latest().intercept, 0.0, 0.1)
assertEqual(model.latest().weights(0), 10.0, 0.1)
assertEqual(model.latest().weights(1), 10.0, 0.1)
assertEqual(model.latestModel().intercept, 0.0, 0.1)
assertEqual(model.latestModel().weights(0), 10.0, 0.1)
assertEqual(model.latestModel().weights(1), 10.0, 0.1)

// check accuracy of predictions
val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17)
validatePrediction(validationData.map(row => model.latest().predict(row.features)), validationData)
validatePrediction(validationData.map(row => model.latestModel().predict(row.features)), validationData)
}

// Test that parameter estimates improve when learning Y = 10*X1 on streaming data
Expand All @@ -107,7 +107,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
Thread.sleep(batchDuration.milliseconds)
// wait an extra few seconds to make sure the update finishes before new data arrive
Thread.sleep(4000)
history.append(math.abs(model.latest().weights(0) - 10.0))
history.append(math.abs(model.latestModel().weights(0) - 10.0))
}

ssc.stop(stopSparkContext=false)
Expand Down

0 comments on commit 14b801e

Please sign in to comment.