diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index ff02e5dd3c253..56caeac05c0c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -119,11 +119,25 @@ class RowMatrix @Since("1.0.0") ( val nt = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2)) // Compute the upper triangular part of the gram matrix. - val GU = rows.treeAggregate(new BDV[Double](nt))( - seqOp = (U, v) => { + val GU = rows.treeAggregate(null.asInstanceOf[BDV[Double]])( + seqOp = (maybeU, v) => { + val U = + if (maybeU == null) { + new BDV[Double](nt) + } else { + maybeU + } BLAS.spr(1.0, v, U.data) U - }, combOp = (U1, U2) => U1 += U2) + }, combOp = (U1, U2) => + if (U1 == null) { + U2 + } else if (U2 == null) { + U1 + } else { + U1 += U2 + } + ) RowMatrix.triuToFull(n, GU.data) } @@ -136,8 +150,14 @@ class RowMatrix @Since("1.0.0") ( // This succeeds when n <= 65535, which is checked above val nt = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2)) - val MU = rows.treeAggregate(new BDV[Double](nt))( - seqOp = (U, v) => { + val MU = rows.treeAggregate(null.asInstanceOf[BDV[Double]])( + seqOp = (maybeU, v) => { + val U = + if (maybeU == null) { + new BDV[Double](nt) + } else { + maybeU + } val n = v.size val na = Array.ofDim[Double](n) @@ -150,7 +170,15 @@ class RowMatrix @Since("1.0.0") ( BLAS.spr(1.0, new DenseVector(na), U.data) U - }, combOp = (U1, U2) => U1 += U2) + }, combOp = (U1, U2) => + if (U1 == null) { + U2 + } else if (U2 == null) { + U1 + } else { + U1 += U2 + } + ) bc.destroy()