diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 9d2611e0ed8e9..e4f90f3171b0b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg
 
 import java.util.{Arrays, Random}
 
-import scala.collection.mutable.{ArrayBuffer, ArrayBuilder, Map}
+import scala.collection.mutable.{ArrayBuffer, ArrayBuilder, Map => MutableMap}
 
 import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
 
@@ -408,8 +408,16 @@ object SparseMatrix {
     require(density >= 0.0 && density <= 1.0, "density must be a double in the range " +
       s"0.0 <= d <= 1.0. Currently, density: $density")
     val length = math.ceil(numRows * numCols * density).toInt
-    val entries = Map[(Int, Int), Double]()
+    val entries = MutableMap[(Int, Int), Double]()
     var i = 0
+    if (density == 0.0) {
+      return new SparseMatrix(numRows, numCols, new Array[Int](numCols + 1),
+        Array[Int](), Array[Double]())
+    } else if (density == 1.0) {
+      // Dense CSC layout: column j starts at j * numRows; row indices restart at 0 per column.
+      return new SparseMatrix(numRows, numCols, Array.tabulate(numCols + 1)(_ * numRows),
+        Array.tabulate(numRows * numCols)(_ % numRows), Array.fill(numRows * numCols)(method(rng)))
+    }
     // Expected number of iterations is less than 1.5 * length
     if (density < 0.34) {
       while (i < length) {
@@ -424,23 +432,18 @@ object SparseMatrix {
       }
     } else { // selection - rejection method
       var j = 0
-      val triesPerCol = math.ceil(length * 1.0 / numCols).toInt
       val pool = numRows * numCols
       // loop over columns so that the sort in fromCOO requires less sorting
       while (i < length && j < numCols) {
-        var k = 0
-        val leftFromPool = (numCols - j) * numRows
-        while (k < triesPerCol) {
-          if (rng.nextDouble() < 1.0 * (length - i) / (pool - leftFromPool)) {
-            var rowIndex = rng.nextInt(numRows)
-            val colIndex = j
-            while (entries.contains((rowIndex, colIndex))) {
-              rowIndex = rng.nextInt(numRows)
-            }
-            entries += (rowIndex, colIndex) -> method(rng)
+        var passedInPool = j * numRows
+        var r = 0
+        while (i < length && r < numRows) {
+          if (rng.nextDouble() < 1.0 * (length - i) / (pool - passedInPool)) {
+            entries += (r, j) -> method(rng)
             i += 1
           }
-          k += 1
+          r += 1
+          passedInPool += 1
         }
         j += 1
       }