Removed the miniBatch in LBFGS.
DB Tsai committed Apr 30, 2014
1 parent 1ba6a33 commit 9cc6cf9
Showing 2 changed files with 16 additions and 39 deletions.
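For orientation, a minimal sketch of how a call site changes with this commit. The argument order is taken from the diff below; the surrounding values (data, gradient, updater, and so on) are assumed to be in scope:

```scala
// Before this commit: the mini-batch fraction was a required argument.
// val (weights, losses) = LBFGS.runMiniBatchLBFGS(
//   data, gradient, updater, numCorrections, convergenceTol,
//   maxNumIterations, regParam, miniBatchFraction, initialWeights)

// After this commit: the method is renamed and the fraction argument is
// gone, so every iteration uses the full data set.
val (weights, losses) = LBFGS.runLBFGS(
  data, gradient, updater, numCorrections, convergenceTol,
  maxNumIterations, regParam, initialWeights)
```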
@@ -42,7 +42,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
private var convergenceTol = 1E-4
private var maxNumIterations = 100
private var regParam = 0.0
- private var miniBatchFraction = 1.0

/**
* Set the number of corrections used in the LBFGS update. Default 10.
@@ -57,14 +56,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
this
}

- /**
-  * Set fraction of data to be used for each L-BFGS iteration. Default 1.0.
-  */
- def setMiniBatchFraction(fraction: Double): this.type = {
-   this.miniBatchFraction = fraction
-   this
- }

/**
* Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
* Smaller value will lead to higher accuracy with the cost of more iterations.
@@ -110,15 +101,14 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
}

override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
- val (weights, _) = LBFGS.runMiniBatchLBFGS(
+ val (weights, _) = LBFGS.runLBFGS(
data,
gradient,
updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
- miniBatchFraction,
initialWeights)
weights
}
@@ -132,10 +122,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
@DeveloperApi
object LBFGS extends Logging {
/**
- * Run Limited-memory BFGS (L-BFGS) in parallel using mini batches.
- * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data
- * in order to compute a gradient estimate.
- * Sampling, and averaging the subgradients over this subset is performed using one standard
+ * Run Limited-memory BFGS (L-BFGS) in parallel.
+ * Averaging the subgradients over different partitions is performed using one standard
* spark map-reduce in each iteration.
*
* @param data - Input data for L-BFGS. RDD of the set of data examples, each of
@@ -147,31 +135,27 @@ object LBFGS extends Logging {
* @param convergenceTol - The convergence tolerance of iterations for L-BFGS
* @param maxNumIterations - Maximal number of iterations that L-BFGS can be run.
* @param regParam - Regularization parameter
- * @param miniBatchFraction - Fraction of the input data set that should be used for
- *                            one iteration of L-BFGS. Default value 1.0.
*
* @return A tuple containing two elements. The first element is a column matrix containing
* weights for every feature, and the second element is an array containing the loss
* computed for every iteration.
*/
- def runMiniBatchLBFGS(
+ def runLBFGS(
data: RDD[(Double, Vector)],
gradient: Gradient,
updater: Updater,
numCorrections: Int,
convergenceTol: Double,
maxNumIterations: Int,
regParam: Double,
- miniBatchFraction: Double,
initialWeights: Vector): (Vector, Array[Double]) = {

val lossHistory = new ArrayBuffer[Double](maxNumIterations)

val numExamples = data.count()
- val miniBatchSize = numExamples * miniBatchFraction

val costFun =
- new CostFun(data, gradient, updater, regParam, miniBatchFraction, miniBatchSize)
+ new CostFun(data, gradient, updater, regParam, numExamples)

val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
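The loop that drives the optimizer is collapsed out of this diff. As a hedged sketch only, assuming Breeze's `breeze.optimize.LBFGS#iterations` API and its `CachedDiffFunction` wrapper (neither is shown in the hunks above), the states would be consumed roughly like this:

```scala
import breeze.optimize.CachedDiffFunction

// Sketch: iterate Breeze's L-BFGS to convergence, recording each loss.
// `lbfgs`, `costFun`, `initialWeights`, and `lossHistory` are the values
// defined in the surrounding code; `toBreeze` is assumed available here.
val states = lbfgs.iterations(
  new CachedDiffFunction(costFun),
  initialWeights.toBreeze.toDenseVector)
var state = states.next()
while (states.hasNext) {
  lossHistory.append(state.value)
  state = states.next()
}
```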

@@ -190,7 +174,7 @@ object LBFGS extends Logging {
lossHistory.append(state.value)
val weights = Vectors.fromBreeze(state.x)

- logInfo("LBFGS.runMiniBatchLBFGS finished. Last 10 losses %s".format(
+ logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
lossHistory.takeRight(10).mkString(", ")))

(weights, lossHistory.toArray)
@@ -205,8 +189,7 @@ object LBFGS extends Logging {
gradient: Gradient,
updater: Updater,
regParam: Double,
- miniBatchFraction: Double,
- miniBatchSize: Double) extends DiffFunction[BDV[Double]] {
+ numExamples: Long) extends DiffFunction[BDV[Double]] {

private var i = 0

@@ -215,8 +198,7 @@ object LBFGS extends Logging {
val localData = data
val localGradient = gradient

- val (gradientSum, lossSum) = localData.sample(false, miniBatchFraction, 42 + i)
-   .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+ val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
val l = localGradient.compute(
features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
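The matching `combOp` half of this `aggregate` call is collapsed out of view. A plausible reconstruction, inferred from the `seqOp` shown above rather than quoted from this diff, merges per-partition partials by summing the gradient vectors in place and adding the losses:

```scala
import breeze.linalg.{DenseVector => BDV}

// Assumed shape of the elided combOp: merge two (gradientSum, lossSum)
// partial results coming from different partitions.
val combOp = (c1: (BDV[Double], Double), c2: (BDV[Double], Double)) =>
  (c1, c2) match {
    case ((grad1, loss1), (grad2, loss2)) =>
      (grad1 += grad2, loss1 + loss2) // in-place vector add, scalar add
  }
```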
@@ -234,7 +216,7 @@ object LBFGS extends Logging {
Vectors.fromBreeze(weights),
Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2

- val loss = lossSum / miniBatchSize + regVal
+ val loss = lossSum / numExamples + regVal
/**
* It will return the gradient part of regularization using updater.
*
@@ -256,8 +238,8 @@ object LBFGS extends Logging {
Vectors.fromBreeze(weights),
Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze

- // gradientTotal = gradientSum / miniBatchSize + gradientTotal
- axpy(1.0 / miniBatchSize, gradientSum, gradientTotal)
+ // gradientTotal = gradientSum / numExamples + gradientTotal
+ axpy(1.0 / numExamples, gradientSum, gradientTotal)

i += 1
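For readers unfamiliar with `axpy`: `breeze.linalg.axpy(a, x, y)` computes `y += a * x` in place, so the new line above averages the summed gradient over all examples and folds it into the regularization gradient. A toy illustration with assumed values:

```scala
import breeze.linalg.{axpy, DenseVector => BDV}

val numExamples = 4L              // assumed toy value
val gradientSum = BDV(4.0, 8.0)   // stand-in for the aggregated gradient
val gradientTotal = BDV(0.1, 0.2) // stand-in for the regularization gradient

// gradientTotal += (1.0 / numExamples) * gradientSum
axpy(1.0 / numExamples, gradientSum, gradientTotal)
// gradientTotal is now DenseVector(1.1, 2.2)
```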

@@ -59,15 +59,14 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
val convergenceTol = 1e-12
val maxNumIterations = 10

- val (_, loss) = LBFGS.runMiniBatchLBFGS(
+ val (_, loss) = LBFGS.runLBFGS(
dataRDD,
gradient,
simpleUpdater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
- miniBatchFrac,
initialWeightsWithIntercept)

// Since the cost function is convex, the loss is guaranteed to be monotonically decreasing
@@ -104,15 +103,14 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
val convergenceTol = 1e-12
val maxNumIterations = 10

- val (weightLBFGS, lossLBFGS) = LBFGS.runMiniBatchLBFGS(
+ val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS(
dataRDD,
gradient,
squaredL2Updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
- miniBatchFrac,
initialWeightsWithIntercept)

val numGDIterations = 50
@@ -150,47 +148,44 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
val maxNumIterations = 8
var convergenceTol = 0.0

- val (_, lossLBFGS1) = LBFGS.runMiniBatchLBFGS(
+ val (_, lossLBFGS1) = LBFGS.runLBFGS(
dataRDD,
gradient,
squaredL2Updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
- miniBatchFrac,
initialWeightsWithIntercept)

// Note that the first loss is computed with initial weights,
// so the total number of losses will be the number of iterations + 1
assert(lossLBFGS1.length == 9)

convergenceTol = 0.1
- val (_, lossLBFGS2) = LBFGS.runMiniBatchLBFGS(
+ val (_, lossLBFGS2) = LBFGS.runLBFGS(
dataRDD,
gradient,
squaredL2Updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
- miniBatchFrac,
initialWeightsWithIntercept)

// Based on observation, lossLBFGS2 runs 3 iterations, which is not theoretically guaranteed.
assert(lossLBFGS2.length == 4)
assert((lossLBFGS2(2) - lossLBFGS2(3)) / lossLBFGS2(2) < convergenceTol)

convergenceTol = 0.01
- val (_, lossLBFGS3) = LBFGS.runMiniBatchLBFGS(
+ val (_, lossLBFGS3) = LBFGS.runLBFGS(
dataRDD,
gradient,
squaredL2Updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
- miniBatchFrac,
initialWeightsWithIntercept)

// With smaller convergenceTol, it takes more steps.
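The assertion in the `convergenceTol = 0.1` case above encodes the stopping rule these tests exercise: iteration stops once the relative improvement in loss falls below the tolerance. A minimal sketch of that criterion, stated as an assumption about the optimizer's behavior rather than a quote of Breeze's implementation:

```scala
// Relative-improvement test matching the assertion
// (lossLBFGS2(2) - lossLBFGS2(3)) / lossLBFGS2(2) < convergenceTol.
def converged(prevLoss: Double, currLoss: Double, tol: Double): Boolean =
  (prevLoss - currLoss) / prevLoss < tol
```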
