diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 00d0b18c27a8d..af1258f5e7d50 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -208,13 +208,11 @@ class RowMatrix(
     val nt: Int = n * (n + 1) / 2
 
     // Compute the upper triangular part of the gram matrix.
-    val GU = rows.aggregate(new BDV[Double](new Array[Double](nt)))(
+    val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
       seqOp = (U, v) => {
         RowMatrix.dspr(1.0, v, U.data)
         U
-      },
-      combOp = (U1, U2) => U1 += U2
-    )
+      }, combOp = (U1, U2) => U1 += U2, 2)
 
     RowMatrix.triuToFull(n, GU.data)
   }
@@ -309,10 +307,11 @@ class RowMatrix(
         s"We need at least $mem bytes of memory.")
     }
 
-    val (m, mean) = rows.aggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
+    val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
       seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze),
-      combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
-    )
+      combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) =>
+        (s1._1 + s2._1, s1._2 += s2._2),
+      2)
 
     updateNumRows(m)
 
@@ -371,10 +370,10 @@ class RowMatrix(
    */
  def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
     val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
-    val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
-      (aggregator, data) => aggregator.add(data),
-      (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
-    )
+    val summary = rows.map(_.toBreeze).treeAggregate[ColumnStatisticsAggregator](zeroValue)(
+      seqOp = (aggregator, data) => aggregator.add(data),
+      combOp = (aggregator1, aggregator2) => aggregator1.merge(aggregator2),
+      2)
     updateNumRows(summary.count)
     summary
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 7030eeabe400a..e146044cd81a8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -175,14 +175,14 @@ object GradientDescent extends Logging {
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
       val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+        .treeAggregate((BDV.zeros[Double](weights.size), 0.0))(
           seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
             val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad))
             (grad, loss + l)
           },
           combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
             (grad1 += grad2, loss1 + loss2)
-          })
+          }, 2)
 
       /**
        * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 8f187c9df5102..937d83ad73b0e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -198,7 +198,7 @@ object LBFGS extends Logging {
       val localData = data
       val localGradient = gradient
 
-      val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
+      val (gradientSum, lossSum) = localData.treeAggregate((BDV.zeros[Double](weights.size), 0.0))(
         seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
           val l = localGradient.compute(
             features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
@@ -206,7 +206,7 @@ object LBFGS extends Logging {
         },
         combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
           (grad1 += grad2, loss1 + loss2)
-        })
+        }, 2)
 
       /**
        * regVal is sum of weight squares if it's L2 updater;
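--
Note on the change: every hunk above swaps aggregate for treeAggregate with a trailing depth argument of 2. With plain aggregate, every partition's partial result is shipped to the driver, which applies combOp once per partition; with treeAggregate, partial results are first merged on the executors in a tree of the given depth, so the driver only merges the last level. Below is a minimal standalone sketch of that difference, not part of the patch: the object name, the local[4] master, and the data sizes are illustrative, and depending on the Spark version treeAggregate is either defined on RDD itself or brought in by an MLlib implicit such as org.apache.spark.mllib.rdd.RDDFunctions._.

    import org.apache.spark.{SparkConf, SparkContext}

    object TreeAggregateExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("treeAggregateDemo").setMaster("local[4]")
        val sc = new SparkContext(conf)
        val data = sc.parallelize(1L to 1000000L, numSlices = 64)

        // aggregate: all 64 partition results go straight to the driver,
        // which applies combOp 64 times by itself.
        val flatSum = data.aggregate(0L)(_ + _, _ + _)

        // treeAggregate: partition results are first combined on executors in
        // a tree of the given depth (2 here, matching the patch), so the
        // driver receives and merges far fewer partial results.
        val treeSum = data.treeAggregate(0L)(_ + _, _ + _, 2)

        assert(flatSum == treeSum) // same answer; only the reduction topology differs
        sc.stop()
      }
    }

For the small accumulators in this sketch the two calls perform about the same; the tree shape pays off for the dense gradient and Gram-matrix buffers aggregated in RowMatrix, GradientDescent, and LBFGS, where merging every partition's buffer on the driver is the bottleneck.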