Move check for empty data to GradientDescent
freeman-lab committed Aug 1, 2014
1 parent 4b0a5d3 commit c7d38a3
Showing 2 changed files with 50 additions and 40 deletions.
@@ -159,45 +159,57 @@ object GradientDescent extends Logging {
     val stochasticLossHistory = new ArrayBuffer[Double](numIterations)

     val numExamples = data.count()
-    val miniBatchSize = numExamples * miniBatchFraction
-
-    // Initialize weights as a column vector
-    var weights = Vectors.dense(initialWeights.toArray)
-
-    /**
-     * For the first iteration, the regVal will be initialized as sum of weight squares
-     * if it's L2 updater; for L1 updater, the same logic is followed.
-     */
-    var regVal = updater.compute(
-      weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
-
-    for (i <- 1 to numIterations) {
-      // Sample a subset (fraction miniBatchFraction) of the total data
-      // compute and sum up the subgradients on this subset (this is one map-reduce)
-      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
-          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
-            val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad))
-            (grad, loss + l)
-          },
-          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
-            (grad1 += grad2, loss1 + loss2)
-          })
-
-      /**
-       * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
-       * and regVal is the regularization value computed in the previous iteration as well.
-       */
-      stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
-      val update = updater.compute(
-        weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam)
-      weights = update._1
-      regVal = update._2
-    }
-
-    logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
-      stochasticLossHistory.takeRight(10).mkString(", ")))
-
-    (weights, stochasticLossHistory.toArray)
+    // if no data, return initial weights to avoid NaNs
+    if (numExamples == 0) {
+
+      logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
+      (initialWeights, stochasticLossHistory.toArray)
+
+    } else {
+
+      val miniBatchSize = numExamples * miniBatchFraction
+
+      // Initialize weights as a column vector
+      var weights = Vectors.dense(initialWeights.toArray)
+
+      /**
+       * For the first iteration, the regVal will be initialized as sum of weight squares
+       * if it's L2 updater; for L1 updater, the same logic is followed.
+       */
+      var regVal = updater.compute(
+        weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
+
+      for (i <- 1 to numIterations) {
+        // Sample a subset (fraction miniBatchFraction) of the total data
+        // compute and sum up the subgradients on this subset (this is one map-reduce)
+        val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
+          .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+            seqOp = (c, v) => (c, v) match {
+              case ((grad, loss), (label, features)) =>
+                val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad))
+                (grad, loss + l)
+            },
+            combOp = (c1, c2) => (c1, c2) match {
+              case ((grad1, loss1), (grad2, loss2)) =>
+                (grad1 += grad2, loss1 + loss2)
+            })
+
+        /**
+         * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
+         * and regVal is the regularization value computed in the previous iteration as well.
+         */
+        stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
+        val update = updater.compute(
+          weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam)
+        weights = update._1
+        regVal = update._2
+      }
+
+      logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
+        stochasticLossHistory.takeRight(10).mkString(", ")))

+      (weights, stochasticLossHistory.toArray)
+    }
   }
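
Why the guard matters (context, not part of the commit): on an empty RDD, data.count() is 0, so miniBatchSize = numExamples * miniBatchFraction works out to 0.0, and both lossSum / miniBatchSize and gradientSum / miniBatchSize divide by zero, feeding NaN into the loss history and the weight update. A minimal plain-Scala sketch of that failure mode:

    object EmptyBatchNaN {
      def main(args: Array[String]): Unit = {
        val numExamples = 0L                                 // what data.count() returns on an empty RDD
        val miniBatchFraction = 0.1
        val miniBatchSize = numExamples * miniBatchFraction  // 0.0
        val lossSum = 0.0                                    // nothing sampled, nothing summed
        println(lossSum / miniBatchSize)                     // 0.0 / 0.0 => NaN
        println(1.0 - 0.1 * (lossSum / miniBatchSize))       // any update built from it is NaN too
      }
    }

Once a NaN enters the weights, every subsequent iteration preserves it, which is why the commit returns the initial weights before any division can happen.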
@@ -60,10 +60,8 @@ abstract class StreamingRegression[
   def trainOn(data: DStream[LabeledPoint]) {
     data.foreachRDD{
       rdd =>
-        if (rdd.count() > 0) {
-          model = algorithm.run(rdd, model.weights)
-          logInfo("Model updated")
-        }
+        model = algorithm.run(rdd, model.weights)
+        logInfo("Model updated")
         logInfo("Current model: weights, %s".format(model.weights.toString))
         logInfo("Current model: intercept, %s".format(model.intercept.toString))
     }
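
With the emptiness check moved into the optimizer, trainOn can pass every micro-batch straight to algorithm.run: GradientDescent now hands the incoming weights back untouched when a batch is empty, and the check protects every caller rather than just the streaming path. An illustrative stand-in in plain Scala (hypothetical names, not Spark APIs) for the resulting update loop:

    object StreamingLoopSketch {
      // Stand-in for the optimizer: the empty-input guard lives here now.
      def runMiniBatch(batch: Seq[Double], weights: Double): Double =
        if (batch.isEmpty) weights                    // no data: return the weights unchanged
        else weights - 0.1 * batch.sum / batch.size   // toy gradient step on the batch mean

      def main(args: Array[String]): Unit = {
        var weights = 0.0
        // Stand-in for foreachRDD: the middle micro-batch is empty, and that is now fine.
        for (batch <- Seq(Seq(1.0, 2.0), Seq.empty[Double], Seq(3.0))) {
          weights = runMiniBatch(batch, weights)
          println(s"weights = $weights")
        }
      }
    }

Keeping the check in one place also means the streaming trainer no longer has to decide for itself which batches are safe to train on.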
