Skip to content

Commit

Permalink
[SPARK-3418] Fixed style issues and added documentation for methods
Browse files Browse the repository at this point in the history
  • Loading branch information
brkyvz committed Sep 6, 2014
1 parent 41b2da3 commit 56d7c85
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 115 deletions.
143 changes: 77 additions & 66 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ package org.apache.spark.mllib.linalg
import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS}
import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}

import org.apache.commons.lang.StringEscapeUtils.escapeJava

/**
* BLAS routines for MLlib's vectors and matrices.
*/
Expand Down Expand Up @@ -223,14 +221,13 @@ private[mllib] object BLAS extends Serializable {
* @param C the resulting matrix C. Size of m x n.
*/
def gemm(
transA: String,
transB: String,
alpha: Double,
A: Matrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix) {

transA: String,
transB: String,
alpha: Double,
A: Matrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix) {
var mA: Int = A.numRows
var nB: Int = B.numCols
var kA: Int = A.numCols
Expand All @@ -241,15 +238,13 @@ private[mllib] object BLAS extends Serializable {
kA = A.numRows
}
require(transA == "T" || transA == "t" || transA == "N" || transA == "n",
s"Invalid argument used for transA: $transA. " +
escapeJava("Must be \"N\", \"n\", \"T\", or \"t\""))
s"""Invalid argument used for transA: $transA. Must be \"N\", \"n\", \"T\", or \"t\"""")
if (transB == "T" || transB=="t"){
nB = B.numRows
kB = B.numCols
}
require(transB == "T" || transB == "t" || transB == "N" || transB == "n",
s"Invalid argument used for transB: $transB. " +
escapeJava("Must be \"N\", \"n\", \"T\", or \"t\""))
s"""Invalid argument used for transB: $transB. Must be \"N\", \"n\", \"T\", or \"t\"""")

require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
Expand All @@ -258,9 +253,9 @@ private[mllib] object BLAS extends Serializable {

A match {
case sparse: SparseMatrix =>
gemm(transA, transB, alpha, sparse, B, beta, C, mA, kA, nB)
gemm(transA, transB, alpha, sparse, B, beta, C)
case dense: DenseMatrix =>
gemm(transA, transB, alpha, dense, B, beta, C, mA, kA, nB)
gemm(transA, transB, alpha, dense, B, beta, C)
case _ =>
throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.")
}
Expand All @@ -276,12 +271,11 @@ private[mllib] object BLAS extends Serializable {
* @param C the resulting matrix C. Size of m x n.
*/
def gemm(
alpha: Double,
A: Matrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix) {

alpha: Double,
A: Matrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix) {
gemm("N", "N", alpha, A, B, beta, C)
}

Expand All @@ -298,12 +292,12 @@ private[mllib] object BLAS extends Serializable {
*
* @return The resulting matrix C. Size of m x n.
*/
def gemm(transA: String,
transB: String,
alpha: Double,
A: Matrix,
B: DenseMatrix) : DenseMatrix = {

def gemm(
transA: String,
transB: String,
alpha: Double,
A: Matrix,
B: DenseMatrix) : DenseMatrix = {
var mA: Int = A.numRows
var nB: Int = B.numCols
var kA: Int = A.numCols
Expand All @@ -319,7 +313,6 @@ private[mllib] object BLAS extends Serializable {
}

val C: DenseMatrix = DenseMatrix.zeros(mA, nB)

gemm(transA, transB, alpha, A, B, 0.0, C)

C
Expand All @@ -338,7 +331,6 @@ private[mllib] object BLAS extends Serializable {
alpha: Double,
A: Matrix,
B: DenseMatrix) : DenseMatrix = {

gemm("N", "N", alpha, A, B)
}

Expand All @@ -353,10 +345,20 @@ private[mllib] object BLAS extends Serializable {
A: DenseMatrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix,
mA: Int,
kA: Int,
nB: Int) {
C: DenseMatrix) {
var mA: Int = A.numRows
var nB: Int = B.numCols
var kA: Int = A.numCols
var kB: Int = B.numRows

if (transA == "T" || transA=="t"){
mA = A.numCols
kA = A.numRows
}
if (transB == "T" || transB=="t"){
nB = B.numRows
kB = B.numCols
}

nativeBLAS.dgemm(transA,transB, mA, nB, kA, alpha, A.toArray, A.numRows, B.toArray, B.numRows,
beta, C.toArray, C.numRows)
Expand All @@ -373,17 +375,25 @@ private[mllib] object BLAS extends Serializable {
A: SparseMatrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix,
mA: Int,
kA: Int,
nB: Int) {

val transposeA = A.numCols == mA
val transposeB = B.numRows == nB
C: DenseMatrix) {
var transposeA = false
var transposeB = false
var mA: Int = A.numRows
var nB: Int = B.numCols
var kA: Int = A.numCols

if (transA == "T" || transA=="t"){
mA = A.numCols
kA = A.numRows
transposeA = true
}
if (transB == "T" || transB=="t"){
nB = B.numRows
transposeB = true
}
val Avals = A.toArray
val Arows = if (!transposeA) A.rowIndices else A.colIndices
val Acols = if (!transposeA) A.colIndices else A.rowIndices
val Arows = if (!transposeA) A.rowIndices else A.colPointers
val Acols = if (!transposeA) A.colPointers else A.rowIndices

// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
if (transposeA){
Expand Down Expand Up @@ -412,8 +422,8 @@ private[mllib] object BLAS extends Serializable {
// Scale matrix first if `beta` is not equal to 0.0
if (beta != 0.0){
var i = 0
val CLength = C.numCols * C.numRows
while ( i < CLength) {
val Clength = C.numCols * C.numRows
while ( i < Clength) {
C.values(i) *= beta
i += 1
}
Expand Down Expand Up @@ -478,16 +488,15 @@ private[mllib] object BLAS extends Serializable {
transposeA = true
}
require(trans == "T" || trans == "t" || trans == "N" || trans == "n",
s"Invalid argument used for trans: $trans. " +
escapeJava("Must be \"N\", \"n\", \"T\", or \"t\""))
s"""Invalid argument used for trans: $trans. Must be \"N\", \"n\", \"T\", or \"t\"""")

require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx")
require(mA == y.size,
s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}")

A match {
case sparse: SparseMatrix =>
gemv(trans, alpha, sparse, x, beta, y, mA, nA, transposeA)
gemv(trans, alpha, sparse, x, beta, y)
case dense: DenseMatrix =>
gemv(trans, alpha, dense, x, beta, y)
case _ =>
Expand All @@ -505,12 +514,11 @@ private[mllib] object BLAS extends Serializable {
* @param y the resulting vector y. Size of m x 1.
*/
def gemv(
alpha: Double,
A: Matrix,
x: DenseVector,
beta: Double,
y: DenseVector) {

alpha: Double,
A: Matrix,
x: DenseVector,
beta: Double,
y: DenseVector) {
gemv("N", alpha, A, x, beta, y)
}

Expand All @@ -530,7 +538,6 @@ private[mllib] object BLAS extends Serializable {
alpha: Double,
A: Matrix,
x: DenseVector): DenseVector = {

val m = if(trans == "N" || trans == "n") A.numRows else A.numCols

val y: DenseVector = new DenseVector(Array.fill(m)(0.0))
Expand All @@ -549,10 +556,9 @@ private[mllib] object BLAS extends Serializable {
* @return `DenseVector` y, the result of the matrix-vector multiplication. Size of m x 1.
*/
def gemv(
alpha: Double,
A: Matrix,
x: DenseVector): DenseVector = {

alpha: Double,
A: Matrix,
x: DenseVector): DenseVector = {
gemv("N", alpha, A, x)
}

Expand All @@ -568,7 +574,6 @@ private[mllib] object BLAS extends Serializable {
x: DenseVector,
beta: Double,
y: DenseVector) {

nativeBLAS.dgemv(trans, A.numRows, A.numCols, alpha, A.toArray, A.numRows, x.toArray, 1, beta,
y.toArray, 1)
}
Expand All @@ -583,14 +588,20 @@ private[mllib] object BLAS extends Serializable {
A: SparseMatrix,
x: DenseVector,
beta: Double,
y: DenseVector,
mA: Int,
nA: Int,
transposeA: Boolean) {
y: DenseVector) {

var mA: Int = A.numRows
var nA: Int = A.numCols
var transposeA: Boolean = false
if (trans == "T" || trans=="t"){
mA = A.numCols
nA = A.numRows
transposeA = true
}

val Avals = A.toArray
val Arows = if (!transposeA) A.rowIndices else A.colIndices
val Acols = if (!transposeA) A.colIndices else A.rowIndices
val Arows = if (!transposeA) A.rowIndices else A.colPointers
val Acols = if (!transposeA) A.colPointers else A.rowIndices

// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
if (transposeA){
Expand Down
Loading

0 comments on commit 56d7c85

Please sign in to comment.