[SPARK-6496] [MLLIB] GeneralizedLinearAlgorithm.run(input, initialWeights) should initialize numFeatures

In GeneralizedLinearAlgorithm, ```numFeatures``` defaults to -1, so it must be set to the correct value when run() is called to train a model.
```LogisticRegressionWithLBFGS.run(input)``` works well, but calling ```LogisticRegressionWithLBFGS.run(input, initialWeights)``` to train a multiclass classification model throws an exception because numFeatures is not updated.
This PR updates numFeatures at the beginning of GeneralizedLinearAlgorithm.run(input, initialWeights) and adds a test case.
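
The pattern can be illustrated outside Spark with a minimal sketch. Everything here is a simplified stand-in and not the MLlib source: `Point`, `SketchAlgorithm`, and the length check are hypothetical, and a plain `Seq` replaces the RDD. The check assumes the weight vector holds one block of `numFeatures + 1` values per class boundary, mirroring the `(numFeatures + 1) * 2` sizing used in the test.

```scala
object NumFeaturesSketch {
  // Hypothetical stand-in for a labeled example.
  final case class Point(label: Double, features: Array[Double])

  class SketchAlgorithm {
    // Mirrors the -1 default described in the commit message.
    var numFeatures: Int = -1

    // Sketch of run(input, initialWeights): returns the feature count it used.
    def run(input: Seq[Point], initialWeights: Array[Double]): Int = {
      // The guard this commit adds: infer the feature count from the
      // first example when it has not been set yet.
      if (numFeatures < 0) {
        numFeatures = input.head.features.length
      }
      // Assumed layout: one block of (numFeatures + 1) weights per boundary.
      // Without the guard, numFeatures + 1 would be 0 here and the modulo
      // below would throw an ArithmeticException.
      val blockSize = numFeatures + 1
      require(initialWeights.length % blockSize == 0,
        s"initialWeights length must be a multiple of $blockSize")
      numFeatures
    }
  }
}
```

In this sketch, removing the guard makes the two-argument `run` fail on its first use of `numFeatures`. That is the same kind of failure the commit message describes, although the exception in MLlib itself may differ.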

Author: Yanbo Liang <[email protected]>

Closes apache#5167 from yanboliang/spark-6496 and squashes the following commits:

8131c48 [Yanbo Liang] LogisticRegressionWithLBFGS.run(input, initialWeights) should initialize numFeatures

(cherry picked from commit 10c7860)
Signed-off-by: Sean Owen <[email protected]>
yanboliang authored and srowen committed Mar 25, 2015
1 parent 8e4e2e3 commit 2be4255
Showing 2 changed files with 10 additions and 0 deletions.
@@ -196,6 +196,10 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
    */
   def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {

+    if (numFeatures < 0) {
+      numFeatures = input.map(_.features.size).first()
+    }
+
     if (input.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt performance if its"
         + " parent RDDs are also uncached.")
@@ -425,6 +425,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M

     val model = lr.run(testRDD)

+    val numFeatures = testRDD.map(_.features.size).first()
+    val initialWeights = Vectors.dense(new Array[Double]((numFeatures + 1) * 2))
+    val model2 = lr.run(testRDD, initialWeights)
+
+    LogisticRegressionSuite.checkModelsEqual(model, model2)
+
     /**
      * The following is the instruction to reproduce the model using R's glmnet package.
      *
