diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 8e517252aafa6..e88caf3a66ea1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -235,12 +235,10 @@ private[mllib] object BLAS extends Serializable { var nB: Int = B.numCols var kA: Int = A.numCols var kB: Int = B.numRows - var transposeA: Boolean = false - var transposeB: Boolean = false + if (transA == "T" || transA=="t"){ mA = A.numCols kA = A.numRows - transposeA = true } require(transA == "T" || transA == "t" || transA == "N" || transA == "n", s"Invalid argument used for transA: $transA. " + @@ -248,7 +246,6 @@ private[mllib] object BLAS extends Serializable { if (transB == "T" || transB=="t"){ nB = B.numRows kB = B.numCols - transposeB = true } require(transB == "T" || transB == "t" || transB == "N" || transB == "n", s"Invalid argument used for transB: $transB. " + @@ -261,7 +258,7 @@ private[mllib] object BLAS extends Serializable { A match { case sparse: SparseMatrix => - gemm(transA, transB, alpha, sparse, B, beta, C, transposeA, transposeB, mA, kA, nB) + gemm(transA, transB, alpha, sparse, B, beta, C, mA, kA, nB) case dense: DenseMatrix => gemm(transA, transB, alpha, dense, B, beta, C, mA, kA, nB) case _ => @@ -370,19 +367,19 @@ private[mllib] object BLAS extends Serializable { * For `SparseMatrix` A. */ private def gemm( - transA: String, - transB: String, - alpha: Double, - A: SparseMatrix, - B: DenseMatrix, - beta: Double, - C: DenseMatrix, - transposeA: Boolean, - transposeB: Boolean, - mA: Int, - kA: Int, - nB: Int) { + transA: String, + transB: String, + alpha: Double, + A: SparseMatrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix, + mA: Int, + kA: Int, + nB: Int) { + val transposeA = A.numCols == mA + val transposeB = B.numRows == nB val Avals = A.toArray val Arows = if (!transposeA) A.rowIndices else A.colIndices @@ -446,7 +443,6 @@ private[mllib] object BLAS extends Serializable { colCounterForB += 1 } - elementCount += 1 } colCounterForA += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 14d9aea4768e9..b2d1320e940c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -78,7 +78,7 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) private[mllib] override def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) private[mllib] override def apply(i: Int): Double = values(i) - private[mllib] override def apply(r: Int, c: Int): Double = values(index(r,c)) + private[mllib] override def apply(r: Int, c: Int): Double = values(index(r, c)) private[mllib] def index(r: Int, c: Int): Int = r + numRows * c @@ -98,32 +98,32 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) */ object DenseMatrix { - def zeros(rows: Int, cols: Int) = new DenseMatrix(rows, cols, Array.fill(rows*cols)(0.0)) + def zeros(rows: Int, cols: Int) = new DenseMatrix(rows, cols, Array.fill(rows * cols)(0.0)) - def ones(rows: Int, cols: Int) = new DenseMatrix(rows, cols, Array.fill(rows*cols)(1.0)) + def ones(rows: Int, cols: Int) = new DenseMatrix(rows, cols, Array.fill(rows * cols)(1.0)) def eye(n: Int) = { val identity = DenseMatrix.zeros(n,n) for (i <- 0 until n){ - identity.update(i,i,1.0) + identity.update(i, i, 1.0) } identity } def rand(rows: Int, cols: Int) = { val rand = new scala.util.Random - new DenseMatrix(rows,cols, Array.fill(rows*cols)(rand.nextDouble())) + new DenseMatrix(rows,cols, Array.fill(rows * cols)(rand.nextDouble())) } def randn(rows: Int, cols: Int) = { val rand = new scala.util.Random - new DenseMatrix(rows,cols, Array.fill(rows*cols)(rand.nextGaussian())) + new DenseMatrix(rows,cols, Array.fill(rows * cols)(rand.nextGaussian())) } def diag(values: Array[Double]) = { val n = values.length val matrix = DenseMatrix.eye(n) - for (i <- 0 until n) matrix.update(i,i,values(i)) + for (i <- 0 until n) matrix.update(i, i, values(i)) matrix } } @@ -169,7 +169,7 @@ class SparseMatrix(val numRows: Int, private[mllib] def index(r: Int, c: Int): Int = { val regionStart = colIndices(c) - val regionEnd = colIndices(c+1) + val regionEnd = colIndices(c + 1) val region = rowIndices.slice(regionStart, regionEnd) if (region.contains(r)){ region.indexOf(r) + regionStart @@ -181,7 +181,7 @@ class SparseMatrix(val numRows: Int, // TODO(Burak): Maybe convert to Breeze to update zero entries? I can't think of any MLlib // TODO: algorithm that would use mutable Sparse Matrices private[mllib] override def update(r: Int, c: Int, v: Double){ - val ind = index(r,c) + val ind = index(r, c) if (ind == -1){ throw new IllegalArgumentException("The given row and column indices correspond to a zero " + "value. Sparse Matrices are currently immutable.")