Skip to content

Commit

Permalink
better initial intercept and more test
Browse files Browse the repository at this point in the history
  • Loading branch information
DB Tsai committed May 8, 2015
1 parent 5c31824 commit 0806002
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,23 @@ class LogisticRegression
val initialWeightsWithIntercept =
Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)

// TODO: Compute the initial intercept based on the histogram.
if ($(fitIntercept)) initialWeightsWithIntercept.toArray(numFeatures) = 1.0
if ($(fitIntercept)) {
/**
* For binary logistic regression, when we initialize the weights as zeros,
* it will converge faster if we initialize the intercept such that
* it follows the distribution of the labels.
*
* {{{
* P(0) = 1 / (1 + \exp(b)), and
* P(1) = \exp(b) / (1 + \exp(b))
* }}}, hence
* {{{
* b = \log{P(1) / P(0)} = \log{count_1 / count_0}
* }}}
*/
initialWeightsWithIntercept.toArray(numFeatures)
= Math.log(histogram(1).toDouble / histogram(0).toDouble)
}

val states = optimizer.iterations(new CachedDiffFunction(costFun),
initialWeightsWithIntercept.toBreeze.toDenseVector)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(model.intercept ~== interceptR relTol 1E-2)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
assert(model.weights(3) ~== weightsR(3) relTol 2E-2)
}

test("binary logistic regression without intercept with L1 regularization") {
Expand Down Expand Up @@ -423,10 +423,10 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
val interceptR = 0.57734851
val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796)

assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
assert(model.intercept ~== interceptR relTol 1E-2)
assert(model.weights(0) ~== weightsR(0) relTol 1E-2)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
}

Expand Down Expand Up @@ -462,4 +462,66 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
}

test("binary logistic regression with intercept with strong L1 regularization") {
val trainer = (new LogisticRegression).setFitIntercept(true)
.setElasticNetParam(1.0).setRegParam(6.0)
val model = trainer.fit(binaryDataset)

val histogram = binaryDataset.map { case Row(label: Double, features: Vector) => label }
.treeAggregate(new MultiClassSummarizer)(
seqOp = (c, v) => (c, v) match {
case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label)
},
combOp = (c1, c2) => (c1, c2) match {
case (classSummarizer1: MultiClassSummarizer, classSummarizer2: MultiClassSummarizer) =>
classSummarizer1.merge(classSummarizer2)
}).histogram

/**
* For binary logistic regression with strong L1 regularization, all the weights will be zeros.
* As a result,
* {{{
* P(0) = 1 / (1 + \exp(b)), and
* P(1) = \exp(b) / (1 + \exp(b))
* }}}, hence
* {{{
* b = \log{P(1) / P(0)} = \log{count_1 / count_0}
* }}}
*/
val interceptTheory = Math.log(histogram(1).toDouble / histogram(0).toDouble)
val weightsTheory = Array(0.0, 0.0, 0.0, 0.0)

assert(model.intercept ~== interceptTheory relTol 1E-3)
assert(model.weights(0) ~== weightsTheory(0) absTol 1E-6)
assert(model.weights(1) ~== weightsTheory(1) absTol 1E-6)
assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6)
assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6)

/**
* Using the following R code to load the data and train the model using glmnet package.
*
* > library("glmnet")
* > data <- read.csv("path", header=FALSE)
* > label = factor(data$V1)
* > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
* > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0))
* > weights
* 5 x 1 sparse Matrix of class "dgCMatrix"
* s0
* (Intercept) -0.2480643
* data.V2 0.0000000
* data.V3 .
* data.V4 .
* data.V5 .
*/
val interceptR = -0.248065
val weightsR = Array(0.0, 0.0, 0.0, 0.0)

assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) absTol 1E-6)
assert(model.weights(1) ~== weightsR(1) absTol 1E-6)
assert(model.weights(2) ~== weightsR(2) absTol 1E-6)
assert(model.weights(3) ~== weightsR(3) absTol 1E-6)
}
}

0 comments on commit 0806002

Please sign in to comment.