Skip to content

Commit

Permalink
assertMatrixApproximatelyEquals
Browse files Browse the repository at this point in the history
  • Loading branch information
rezazadeh committed Mar 20, 2014
1 parent 3787bb4 commit e2667d4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ class PCASuite extends FunSuite with BeforeAndAfterAll {
ret
}

def assertMatrixEquals(a: DoubleMatrix, b: DoubleMatrix) {
def assertMatrixApproximatelyEquals(a: DoubleMatrix, b: DoubleMatrix) {
assert(a.rows == b.rows && a.columns == b.columns,
"dimension mismatch: $a.rows vs $b.rows and $a.columns vs $b.columns")
val diff = DoubleMatrix.zeros(a.rows, a.columns)
Array.tabulate(a.rows, a.columns) { (i, j) =>
diff.put(i, j,
Math.min(Math.abs(a.get(i, j) - b.get(i, j)), Math.abs(a.get(i, j) + b.get(i, j))))
for (i <- 0 until a.columns) {
val aCol = a.getColumn(i)
val bCol = b.getColumn(i)
val diff = Math.min(aCol.sub(bCol).norm1, aCol.add(bCol).norm1)
assert(diff < EPSILON, "matrix mismatch: " + diff)
}
assert(diff.norm1 < EPSILON, "matrix mismatch: " + diff.norm1)
}

test("full rank matrix pca") {
Expand All @@ -78,7 +78,7 @@ class PCASuite extends FunSuite with BeforeAndAfterAll {

val coeffs = new DoubleMatrix(new PCA().setK(n).compute(a))

assertMatrixEquals(getDenseMatrix(SparseMatrix(realPCA,n,n)), coeffs)
assertMatrixApproximatelyEquals(getDenseMatrix(SparseMatrix(realPCA,n,n)), coeffs)
}

test("sparse matrix full rank matrix pca") {
Expand All @@ -97,7 +97,7 @@ class PCASuite extends FunSuite with BeforeAndAfterAll {

val coeffs = new DoubleMatrix(new PCA().setK(n).compute(a))

assertMatrixEquals(getDenseMatrix(SparseMatrix(realPCA,n,n)), coeffs)
assertMatrixApproximatelyEquals(getDenseMatrix(SparseMatrix(realPCA,n,n)), coeffs)
}

test("truncated matrix pca") {
Expand All @@ -117,7 +117,7 @@ class PCASuite extends FunSuite with BeforeAndAfterAll {
val k = 2
val coeffs = new DoubleMatrix(new PCA().setK(k).compute(a))

assertMatrixEquals(getDenseMatrix(SparseMatrix(realPCA,n,k)), coeffs)
assertMatrixApproximatelyEquals(getDenseMatrix(SparseMatrix(realPCA,n,k)), coeffs)
}
}

Expand Down
47 changes: 24 additions & 23 deletions mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,15 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll {
ret
}

def assertMatrixEquals(a: DoubleMatrix, b: DoubleMatrix) {
assert(a.rows == b.rows && a.columns == b.columns, "dimension mismatch")
val diff = DoubleMatrix.zeros(a.rows, a.columns)
Array.tabulate(a.rows, a.columns){(i, j) =>
diff.put(i, j,
Math.min(Math.abs(a.get(i, j) - b.get(i, j)),
Math.abs(a.get(i, j) + b.get(i, j)))) }
assert(diff.norm1 < EPSILON, "matrix mismatch: " + diff.norm1)
def assertMatrixApproximatelyEquals(a: DoubleMatrix, b: DoubleMatrix) {
assert(a.rows == b.rows && a.columns == b.columns,
"dimension mismatch: $a.rows vs $b.rows and $a.columns vs $b.columns")
for (i <- 0 until a.columns) {
val aCol = a.getColumn(i)
val bCol = b.getColumn(i)
val diff = Math.min(aCol.sub(bCol).norm1, aCol.add(bCol).norm1)
assert(diff < EPSILON, "matrix mismatch: " + diff)
}
}

test("full rank matrix svd") {
Expand All @@ -89,12 +90,12 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll {


// check individual decomposition
assertMatrixEquals(retu, svd(0))
assertMatrixEquals(rets, DoubleMatrix.diag(svd(1)))
assertMatrixEquals(retv, svd(2))
assertMatrixApproximatelyEquals(retu, svd(0))
assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1)))
assertMatrixApproximatelyEquals(retv, svd(2))

// check multiplication guarantee
assertMatrixEquals(retu.mmul(rets).mmul(retv.transpose), denseA)
assertMatrixApproximatelyEquals(retu.mmul(rets).mmul(retv.transpose), denseA)
}

test("dense full rank matrix svd") {
Expand All @@ -120,12 +121,12 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll {


// check individual decomposition
assertMatrixEquals(retu, svd(0))
assertMatrixEquals(rets, DoubleMatrix.diag(svd(1)))
assertMatrixEquals(retv, svd(2))
assertMatrixApproximatelyEquals(retu, svd(0))
assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1)))
assertMatrixApproximatelyEquals(retv, svd(2))

// check multiplication guarantee
assertMatrixEquals(retu.mmul(rets).mmul(retv.transpose), denseA)
assertMatrixApproximatelyEquals(retu.mmul(rets).mmul(retv.transpose), denseA)
}

test("rank one matrix svd") {
Expand Down Expand Up @@ -153,12 +154,12 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll {
val retv = getDenseMatrix(v)

// check individual decomposition
assertMatrixEquals(retu, svd(0).getColumn(0))
assertMatrixEquals(rets, DoubleMatrix.diag(svd(1).getRow(0)))
assertMatrixEquals(retv, svd(2).getColumn(0))
assertMatrixApproximatelyEquals(retu, svd(0).getColumn(0))
assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1).getRow(0)))
assertMatrixApproximatelyEquals(retv, svd(2).getColumn(0))

// check multiplication guarantee
assertMatrixEquals(retu.mmul(rets).mmul(retv.transpose), denseA)
assertMatrixApproximatelyEquals(retu.mmul(rets).mmul(retv.transpose), denseA)
}

test("truncated with k") {
Expand Down Expand Up @@ -186,8 +187,8 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll {
assert(retrank == 1, "rank returned not one")

// check individual decomposition
assertMatrixEquals(retu, svd(0).getColumn(0))
assertMatrixEquals(rets, DoubleMatrix.diag(svd(1).getRow(0)))
assertMatrixEquals(retv, svd(2).getColumn(0))
assertMatrixApproximatelyEquals(retu, svd(0).getColumn(0))
assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1).getRow(0)))
assertMatrixApproximatelyEquals(retv, svd(2).getColumn(0))
}
}

0 comments on commit e2667d4

Please sign in to comment.