Skip to content

Commit

Permalink
[SPARK-3418] New code review comments addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
brkyvz committed Sep 18, 2014
1 parent f35a161 commit 421045f
Showing 1 changed file with 13 additions and 21 deletions.
34 changes: 13 additions & 21 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ private[mllib] object BLAS extends Serializable {
A: Matrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix) {
C: DenseMatrix): Unit = {
gemm(false, false, alpha, A, B, beta, C)
}

Expand All @@ -267,7 +267,7 @@ private[mllib] object BLAS extends Serializable {
A: DenseMatrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix) {
C: DenseMatrix): Unit = {
val mA: Int = if (!transA) A.numRows else A.numCols
val nB: Int = if (!transB) B.numCols else B.numRows
val kA: Int = if (!transA) A.numCols else A.numRows
Expand Down Expand Up @@ -317,16 +317,17 @@ private[mllib] object BLAS extends Serializable {
while (colCounterForB < nB) {
var rowCounterForA = 0
val Cstart = colCounterForB * mA
val Bstart = colCounterForB * kA
while (rowCounterForA < mA) {
var i = Arows(rowCounterForA)
val indEnd = Arows(rowCounterForA + 1)
val Bstart = colCounterForB * kA
var sum = 0.0
while (i < indEnd) {
sum += Avals(i) * B(Bstart + Acols(i))
sum += Avals(i) * B.values(Bstart + Acols(i))
i += 1
}
C.values(rowCounterForA + Cstart) = beta * C.values(rowCounterForA + Cstart) + sum
val Cindex = Cstart + rowCounterForA
C.values(Cindex) = beta * C.values(Cindex) + sum
rowCounterForA += 1
}
colCounterForB += 1
Expand All @@ -343,7 +344,8 @@ private[mllib] object BLAS extends Serializable {
sum += Avals(i) * B(colCounterForB, Acols(i))
i += 1
}
C.values(rowCounter + Cstart) = beta * C.values(rowCounter + Cstart) + sum
val Cindex = Cstart + rowCounter
C.values(Cindex) = beta * C.values(Cindex) + sum
rowCounter += 1
}
colCounterForB += 1
Expand All @@ -352,12 +354,7 @@ private[mllib] object BLAS extends Serializable {
} else {
// Scale matrix first if `beta` is not equal to 0.0
if (beta != 0.0){
var i = 0
val Clength = C.numCols * C.numRows
while ( i < Clength) {
C.values(i) *= beta
i += 1
}
nativeBLAS.dscal(C.values.length, beta, C.values, 1)
}
// Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
// B, and added to C.
Expand All @@ -371,7 +368,7 @@ private[mllib] object BLAS extends Serializable {
val Bval = B(colCounterForA, colCounterForB)
val Cstart = colCounterForB * mA
while (i < indEnd){
C.values(Arows(i) + Cstart) += Avals(i) * Bval
C.values(Cstart + Arows(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
Expand Down Expand Up @@ -503,7 +500,7 @@ private[mllib] object BLAS extends Serializable {
A: DenseMatrix,
x: DenseVector,
beta: Double,
y: DenseVector) {
y: DenseVector): Unit = {
val tStrA = if (!trans) "N" else "T"
nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta,
y.values, 1)
Expand All @@ -519,7 +516,7 @@ private[mllib] object BLAS extends Serializable {
A: SparseMatrix,
x: DenseVector,
beta: Double,
y: DenseVector) {
y: DenseVector): Unit = {

val mA: Int = if(!trans) A.numRows else A.numCols
val nA: Int = if(!trans) A.numCols else A.numRows
Expand All @@ -544,12 +541,7 @@ private[mllib] object BLAS extends Serializable {
} else {
// Scale vector first if `beta` is not equal to 0.0
if (beta != 0.0){
var i = 0
val yLength = y.size
while (i < yLength) {
y.values(i) *= beta
i += 1
}
scal(beta, y)
}
// Perform matrix-vector multiplication and add to y
var colCounterForA = 0
Expand Down

0 comments on commit 421045f

Please sign in to comment.