From e2667d40508637f05ead8cd65d4ec3d3546daf0e Mon Sep 17 00:00:00 2001 From: Reza Zadeh Date: Wed, 19 Mar 2014 15:26:59 -0700 Subject: [PATCH] assertMatrixApproximatelyEquals --- .../apache/spark/mllib/linalg/PCASuite.scala | 18 +++---- .../apache/spark/mllib/linalg/SVDSuite.scala | 47 ++++++++++--------- 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/PCASuite.scala index 664702b1852b4..5e5086b1bf73e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/PCASuite.scala @@ -52,15 +52,15 @@ class PCASuite extends FunSuite with BeforeAndAfterAll { ret } - def assertMatrixEquals(a: DoubleMatrix, b: DoubleMatrix) { + def assertMatrixApproximatelyEquals(a: DoubleMatrix, b: DoubleMatrix) { assert(a.rows == b.rows && a.columns == b.columns, "dimension mismatch: $a.rows vs $b.rows and $a.columns vs $b.columns") - val diff = DoubleMatrix.zeros(a.rows, a.columns) - Array.tabulate(a.rows, a.columns) { (i, j) => - diff.put(i, j, - Math.min(Math.abs(a.get(i, j) - b.get(i, j)), Math.abs(a.get(i, j) + b.get(i, j)))) + for (i <- 0 until a.columns) { + val aCol = a.getColumn(i) + val bCol = b.getColumn(i) + val diff = Math.min(aCol.sub(bCol).norm1, aCol.add(bCol).norm1) + assert(diff < EPSILON, "matrix mismatch: " + diff) } - assert(diff.norm1 < EPSILON, "matrix mismatch: " + diff.norm1) } test("full rank matrix pca") { @@ -78,7 +78,7 @@ class PCASuite extends FunSuite with BeforeAndAfterAll { val coeffs = new DoubleMatrix(new PCA().setK(n).compute(a)) - assertMatrixEquals(getDenseMatrix(SparseMatrix(realPCA,n,n)), coeffs) + assertMatrixApproximatelyEquals(getDenseMatrix(SparseMatrix(realPCA,n,n)), coeffs) } test("sparse matrix full rank matrix pca") { @@ -97,7 +97,7 @@ class PCASuite extends FunSuite with BeforeAndAfterAll { val coeffs = new DoubleMatrix(new PCA().setK(n).compute(a)) - assertMatrixEquals(getDenseMatrix(SparseMatrix(realPCA,n,n)), coeffs) + assertMatrixApproximatelyEquals(getDenseMatrix(SparseMatrix(realPCA,n,n)), coeffs) } test("truncated matrix pca") { @@ -117,7 +117,7 @@ class PCASuite extends FunSuite with BeforeAndAfterAll { val k = 2 val coeffs = new DoubleMatrix(new PCA().setK(k).compute(a)) - assertMatrixEquals(getDenseMatrix(SparseMatrix(realPCA,n,k)), coeffs) + assertMatrixApproximatelyEquals(getDenseMatrix(SparseMatrix(realPCA,n,k)), coeffs) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala index cd6454caa0e24..20e2b0f84be06 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala @@ -56,14 +56,15 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll { ret } - def assertMatrixEquals(a: DoubleMatrix, b: DoubleMatrix) { - assert(a.rows == b.rows && a.columns == b.columns, "dimension mismatch") - val diff = DoubleMatrix.zeros(a.rows, a.columns) - Array.tabulate(a.rows, a.columns){(i, j) => - diff.put(i, j, - Math.min(Math.abs(a.get(i, j) - b.get(i, j)), - Math.abs(a.get(i, j) + b.get(i, j)))) } - assert(diff.norm1 < EPSILON, "matrix mismatch: " + diff.norm1) + def assertMatrixApproximatelyEquals(a: DoubleMatrix, b: DoubleMatrix) { + assert(a.rows == b.rows && a.columns == b.columns, + "dimension mismatch: $a.rows vs $b.rows and $a.columns vs $b.columns") + for (i <- 0 until a.columns) { + val aCol = a.getColumn(i) + val bCol = b.getColumn(i) + val diff = Math.min(aCol.sub(bCol).norm1, aCol.add(bCol).norm1) + assert(diff < EPSILON, "matrix mismatch: " + diff) + } } test("full rank matrix svd") { @@ -89,12 +90,12 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll { // check individual decomposition - assertMatrixEquals(retu, svd(0)) - assertMatrixEquals(rets, DoubleMatrix.diag(svd(1))) - assertMatrixEquals(retv, svd(2)) + assertMatrixApproximatelyEquals(retu, svd(0)) + assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1))) + assertMatrixApproximatelyEquals(retv, svd(2)) // check multiplication guarantee - assertMatrixEquals(retu.mmul(rets).mmul(retv.transpose), denseA) + assertMatrixApproximatelyEquals(retu.mmul(rets).mmul(retv.transpose), denseA) } test("dense full rank matrix svd") { @@ -120,12 +121,12 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll { // check individual decomposition - assertMatrixEquals(retu, svd(0)) - assertMatrixEquals(rets, DoubleMatrix.diag(svd(1))) - assertMatrixEquals(retv, svd(2)) + assertMatrixApproximatelyEquals(retu, svd(0)) + assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1))) + assertMatrixApproximatelyEquals(retv, svd(2)) // check multiplication guarantee - assertMatrixEquals(retu.mmul(rets).mmul(retv.transpose), denseA) + assertMatrixApproximatelyEquals(retu.mmul(rets).mmul(retv.transpose), denseA) } test("rank one matrix svd") { @@ -153,12 +154,12 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll { val retv = getDenseMatrix(v) // check individual decomposition - assertMatrixEquals(retu, svd(0).getColumn(0)) - assertMatrixEquals(rets, DoubleMatrix.diag(svd(1).getRow(0))) - assertMatrixEquals(retv, svd(2).getColumn(0)) + assertMatrixApproximatelyEquals(retu, svd(0).getColumn(0)) + assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1).getRow(0))) + assertMatrixApproximatelyEquals(retv, svd(2).getColumn(0)) // check multiplication guarantee - assertMatrixEquals(retu.mmul(rets).mmul(retv.transpose), denseA) + assertMatrixApproximatelyEquals(retu.mmul(rets).mmul(retv.transpose), denseA) } test("truncated with k") { @@ -186,8 +187,8 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll { assert(retrank == 1, "rank returned not one") // check individual decomposition - assertMatrixEquals(retu, svd(0).getColumn(0)) - assertMatrixEquals(rets, DoubleMatrix.diag(svd(1).getRow(0))) - assertMatrixEquals(retv, svd(2).getColumn(0)) + assertMatrixApproximatelyEquals(retu, svd(0).getColumn(0)) + assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1).getRow(0))) + assertMatrixApproximatelyEquals(retv, svd(2).getColumn(0)) } }