Skip to content

Commit

Permalink
9/17 comments addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
brkyvz committed Sep 18, 2014
1 parent 7af2f83 commit 4b7dbec
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 130 deletions.
82 changes: 21 additions & 61 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,8 @@ private[mllib] object BLAS extends Serializable with Logging {

/**
* C := alpha * A * B + beta * C
* @param transA specify whether to use matrix A, or the transpose of matrix A. Should be "N" or
* "n" to use A, and "T" or "t" to use the transpose of A.
* @param transB specify whether to use matrix B, or the transpose of matrix B. Should be "N" or
* "n" to use B, and "T" or "t" to use the transpose of B.
* @param transA whether to use the transpose of matrix A (true), or A itself (false).
* @param transB whether to use the transpose of matrix B (true), or B itself (false).
* @param alpha a scalar to scale the multiplication A * B.
* @param A the matrix A that will be left multiplied to B. Size of m x k.
* @param B the matrix B that will be left multiplied by A. Size of k x n.
Expand All @@ -231,7 +229,7 @@ private[mllib] object BLAS extends Serializable with Logging {
beta: Double,
C: DenseMatrix): Unit = {
if (alpha == 0.0) {
logWarning("gemm: alpha is equal to 0. Returning C.")
logDebug("gemm: alpha is equal to 0. Returning C.")
} else {
A match {
case sparse: SparseMatrix =>
Expand Down Expand Up @@ -319,7 +317,7 @@ private[mllib] object BLAS extends Serializable with Logging {
// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
if (transA){
var colCounterForB = 0
if (!transB){ // Expensive to put the check inside the loop
if (!transB) { // Expensive to put the check inside the loop
while (colCounterForB < nB) {
var rowCounterForA = 0
val Cstart = colCounterForB * mA
Expand Down Expand Up @@ -360,21 +358,22 @@ private[mllib] object BLAS extends Serializable with Logging {
} else {
// Scale matrix first if `beta` is not equal to 0.0
if (beta != 0.0){
nativeBLAS.dscal(C.values.length, beta, C.values, 1)
f2jBLAS.dscal(C.values.length, beta, C.values, 1)
}
// Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
// B, and added to C.
var colCounterForB = 0 // the column to be updated in C
if (!transB) { // Expensive to put the check inside the loop
while (colCounterForB < nB) {
var colCounterForA = 0 // The column of A to multiply with the row of B
while (colCounterForA < kA){
val Bstart = colCounterForB * kB
val Cstart = colCounterForB * mA
while (colCounterForA < kA) {
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val Bval = B(colCounterForA, colCounterForB)
val Cstart = colCounterForB * mA
val Bval = B.values(Bstart + colCounterForA) * alpha
while (i < indEnd){
C.values(Cstart + Arows(i)) += Avals(i) * Bval * alpha
C.values(Cstart + Arows(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
Expand All @@ -384,13 +383,13 @@ private[mllib] object BLAS extends Serializable with Logging {
} else {
while (colCounterForB < nB) {
var colCounterForA = 0 // The column of A to multiply with the row of B
val Cstart = colCounterForB * mA
while (colCounterForA < kA){
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val Bval = B(colCounterForB, colCounterForA)
val Cstart = colCounterForB * mA
val Bval = B(colCounterForB, colCounterForA) * alpha
while (i < indEnd){
C.values(Cstart + Arows(i)) += Avals(i) * Bval * alpha
C.values(Cstart + Arows(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
Expand All @@ -403,8 +402,7 @@ private[mllib] object BLAS extends Serializable with Logging {

/**
* y := alpha * A * x + beta * y
* @param trans specify whether to use matrix A, or the transpose of matrix A. Should be "N" or
* "n" to use A, and "T" or "t" to use the transpose of A.
* @param trans whether to use the transpose of matrix A (true), or A itself (false).
* @param alpha a scalar to scale the multiplication A * x.
* @param A the matrix A that will be left multiplied to x. Size of m x n.
* @param x the vector x that will be left multiplied by A. Size of n x 1.
Expand All @@ -427,7 +425,7 @@ private[mllib] object BLAS extends Serializable with Logging {
require(mA == y.size,
s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}")
if (alpha == 0.0) {
logWarning("gemv: alpha is equal to 0. Returning y.")
logDebug("gemv: alpha is equal to 0. Returning y.")
} else {
A match {
case sparse: SparseMatrix =>
Expand Down Expand Up @@ -458,47 +456,6 @@ private[mllib] object BLAS extends Serializable with Logging {
gemv(false, alpha, A, x, beta, y)
}

/**
* y := alpha * A * x
*
* @param trans specify whether to use matrix A, or the transpose of matrix A. Should be "N" or
* "n" to use A, and "T" or "t" to use the transpose of A.
* @param alpha a scalar to scale the multiplication A * x.
* @param A the matrix A that will be left multiplied to x. Size of m x n.
* @param x the vector x that will be left multiplied by A. Size of n x 1.
*
* @return `DenseVector` y, the result of the matrix-vector multiplication. Size of m x 1.
*/
def gemv(
trans: Boolean,
alpha: Double,
A: Matrix,
x: DenseVector): DenseVector = {
val m = if(!trans) A.numRows else A.numCols

val y: DenseVector = new DenseVector(Array.fill(m)(0.0))
gemv(trans, alpha, A, x, 0.0, y)

y
}

/**
* y := alpha * A * x
*
* @param alpha a scalar to scale the multiplication A * x.
* @param A the matrix A that will be left multiplied to x. Size of m x n.
* @param x the vector x that will be left multiplied by A. Size of n x 1.
*
* @return `DenseVector` y, the result of the matrix-vector multiplication. Size of m x 1.
*/
def gemv(
alpha: Double,
A: Matrix,
x: DenseVector): DenseVector = {
gemv(false, alpha, A, x)
}


/**
* y := alpha * A * x + beta * y
* For `DenseMatrix` A.
Expand Down Expand Up @@ -539,8 +496,9 @@ private[mllib] object BLAS extends Serializable with Logging {
var rowCounter = 0
while (rowCounter < mA){
var i = Arows(rowCounter)
val indEnd = Arows(rowCounter + 1)
var sum = 0.0
while(i < Arows(rowCounter + 1)){
while(i < indEnd){
sum += Avals(i) * x.values(Acols(i))
i += 1
}
Expand All @@ -556,9 +514,11 @@ private[mllib] object BLAS extends Serializable with Logging {
var colCounterForA = 0
while (colCounterForA < nA){
var i = Acols(colCounterForA)
while (i < Acols(colCounterForA + 1)){
val indEnd = Acols(colCounterForA + 1)
val xVal = x.values(colCounterForA) * alpha
while (i < indEnd){
val rowIndex = Arows(i)
y.values(rowIndex) += Avals(i) * x.values(colCounterForA) * alpha
y.values(rowIndex) += Avals(i) * xVal
i += 1
}
colCounterForA += 1
Expand Down
63 changes: 28 additions & 35 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
package org.apache.spark.mllib.linalg

import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM}

import org.apache.spark.util.random.XORShiftRandom

import java.util.Arrays

/**
* Trait for a local matrix.
*/
Expand All @@ -37,40 +40,45 @@ sealed trait Matrix extends Serializable {
/** Converts to a breeze matrix. */
private[mllib] def toBreeze: BM[Double]

/** Gets the i-th element in the array backing the matrix. */
private[mllib] def apply(i: Int): Double

/** Gets the (i, j)-th element. */
private[mllib] def apply(i: Int, j: Int): Double

/** Return the index for the (i, j)-th element in the backing array. */
private[mllib] def index(i: Int, j: Int): Int

/** Update element at (i, j) */
private[mllib] def update(i: Int, j: Int, v: Double)
private[mllib] def update(i: Int, j: Int, v: Double): Unit

/** Get a deep copy of the matrix. */
def copy: Matrix

/** Convenience method for `Matrix`-`DenseMatrix` multiplication. */
def times(y: DenseMatrix): DenseMatrix = {
def multiply(y: DenseMatrix): DenseMatrix = {
val C: DenseMatrix = Matrices.zeros(numRows, y.numCols).asInstanceOf[DenseMatrix]
BLAS.gemm(false, false, 1.0, this, y, 0.0, C)
C
}

/** Convenience method for `Matrix`-`DenseVector` multiplication. */
def times(y: DenseVector): DenseVector = BLAS.gemv(1.0, this, y)
def multiply(y: DenseVector): DenseVector = {
val output = new DenseVector(new Array[Double](numRows))
BLAS.gemv(1.0, this, y, 0.0, output)
output
}

/** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */
def transposeTimes(y: DenseMatrix): DenseMatrix = {
def transposeMultiply(y: DenseMatrix): DenseMatrix = {
val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix]
BLAS.gemm(true, false, 1.0, this, y, 0.0, C)
C
}

/** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */
def transposeTimes(y: DenseVector): DenseVector = BLAS.gemv(true, 1.0, this, y)
def transposeMultiply(y: DenseVector): DenseVector = {
val output = new DenseVector(new Array[Double](numCols))
BLAS.gemv(true, 1.0, this, y, 0.0, output)
output
}

/** A human readable representation of the matrix */
override def toString: String = toBreeze.toString()
Expand Down Expand Up @@ -106,7 +114,7 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])

private[mllib] def index(i: Int, j: Int): Int = i + numRows * j

private[mllib] def update(i: Int, j: Int, v: Double){
private[mllib] def update(i: Int, j: Int, v: Double): Unit = {
values(index(i, j)) = v
}

Expand All @@ -128,7 +136,8 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
* @param numRows number of rows
* @param numCols number of columns
* @param colPtrs the index corresponding to the start of a new column
* @param rowIndices the row index of the entry
* @param rowIndices the row index of the entry. They must be in strictly increasing order for each
* column
* @param values non-zero matrix entries in column major
*/
class SparseMatrix(
Expand All @@ -145,7 +154,7 @@ class SparseMatrix(
s"numCols: $numCols")

override def toArray: Array[Double] = {
val arr = Array.fill(numRows * numCols)(0.0)
val arr = new Array[Double](numRows * numCols)
var j = 0
while (j < numCols) {
var i = colPtrs(j)
Expand All @@ -164,35 +173,19 @@ class SparseMatrix(
private[mllib] def toBreeze: BM[Double] =
new BSM[Double](values, numRows, numCols, colPtrs, rowIndices)

private[mllib] def apply(i: Int): Double = values(i)

private[mllib] def apply(i: Int, j: Int): Double = {
val ind = index(i, j)
if (ind == -1) 0.0 else values(ind)
if (ind < 0) 0.0 else values(ind)
}

private[mllib] def index(i: Int, j: Int): Int = {
var regionStart = colPtrs(j)
var regionEnd = colPtrs(j + 1)
while (regionStart <= regionEnd) {
val mid = (regionStart + regionEnd) / 2
if (rowIndices(mid) == i){
return mid
} else if (regionStart == regionEnd) {
return -1
} else if (rowIndices(mid) > i) {
regionEnd = mid
} else {
regionStart = mid
}
}
-1
Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i)
}

private[mllib] def update(i: Int, j: Int, v: Double){
private[mllib] def update(i: Int, j: Int, v: Double): Unit = {
val ind = index(i, j)
if (ind == -1){
throw new IllegalArgumentException("The given row and column indices correspond to a zero " +
throw new NoSuchElementException("The given row and column indices correspond to a zero " +
"value. Only non-zero elements in Sparse Matrices can be updated.")
} else {
values(index(i, j)) = v
Expand Down Expand Up @@ -223,17 +216,17 @@ object Matrices {
*
* @param numRows number of rows
* @param numCols number of columns
* @param colPointers the index corresponding to the start of a new column
* @param colPtrs the index corresponding to the start of a new column
* @param rowIndices the row index of the entry
* @param values non-zero matrix entries in column major
*/
def sparse(
numRows: Int,
numCols: Int,
colPointers: Array[Int],
colPtrs: Array[Int],
rowIndices: Array[Int],
values: Array[Double]): Matrix = {
new SparseMatrix(numRows, numCols, colPointers, rowIndices, values)
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values)
}

/**
Expand Down Expand Up @@ -262,7 +255,7 @@ object Matrices {
* @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros
*/
def zeros(numRows: Int, numCols: Int): Matrix =
new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(0.0))
new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols))

/**
* Generate a `DenseMatrix` consisting of ones.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ class BLASSuite extends FunSuite {
val B = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0))
val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0))

assert(dA times B ~== expected absTol 1e-15)
assert(sA times B ~== expected absTol 1e-15)
assert(dA multiply B ~== expected absTol 1e-15)
assert(sA multiply B ~== expected absTol 1e-15)

val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0))
val C2 = C1.copy
Expand Down Expand Up @@ -170,8 +170,8 @@ class BLASSuite extends FunSuite {
val sAT =
new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))

assert(dAT transposeTimes B ~== expected absTol 1e-15)
assert(sAT transposeTimes B ~== expected absTol 1e-15)
assert(dAT transposeMultiply B ~== expected absTol 1e-15)
assert(sAT transposeMultiply B ~== expected absTol 1e-15)

gemm(true, false, 1.0, dAT, B, 2.0, C5)
gemm(true, false, 1.0, sAT, B, 2.0, C6)
Expand All @@ -181,7 +181,6 @@ class BLASSuite extends FunSuite {
assert(C6 ~== expected2 absTol 1e-15)
assert(C7 ~== expected3 absTol 1e-15)
assert(C8 ~== expected3 absTol 1e-15)

}

test("gemv") {
Expand All @@ -193,8 +192,8 @@ class BLASSuite extends FunSuite {
val x = new DenseVector(Array(1.0, 2.0, 3.0))
val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))

assert(dA times x ~== expected absTol 1e-15)
assert(sA times x ~== expected absTol 1e-15)
assert(dA multiply x ~== expected absTol 1e-15)
assert(sA multiply x ~== expected absTol 1e-15)

val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
val y2 = y1.copy
Expand Down Expand Up @@ -226,8 +225,8 @@ class BLASSuite extends FunSuite {
val sAT =
new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))

assert(dAT transposeTimes x ~== expected absTol 1e-15)
assert(sAT transposeTimes x ~== expected absTol 1e-15)
assert(dAT transposeMultiply x ~== expected absTol 1e-15)
assert(sAT transposeMultiply x ~== expected absTol 1e-15)

gemv(true, 1.0, dAT, x, 2.0, y5)
gemv(true, 1.0, sAT, x, 2.0, y6)
Expand Down
Loading

0 comments on commit 4b7dbec

Please sign in to comment.