Skip to content

Commit

Permalink
[SPARK-8563] [MLLIB] Fixed a bug so that IndexedRowMatrix.computeSVD(…
Browse files Browse the repository at this point in the history
…).U.numCols = k

I'm sorry that I made #6949 closed by mistake.
I pushed codes again.

And, I added a test code.

>
There is a bug that `U.numCols() = self.nCols` in `IndexedRowMatrix.computeSVD()`
It should have been `U.numCols() = k = svd.U.numCols()`

>
```
self = U * sigma * V.transpose
(m x n) = (m x n) * (k x k) * (k x n) //ASIS
-->
(m x n) = (m x k) * (k x k) * (k x n) //TOBE
```

Author: lee19 <[email protected]>

Closes #6953 from lee19/MLlibBugfix and squashes the following commits:

c1812a0 [lee19] [SPARK-8563] [MLlib] Used nRows instead of numRows() to reduce a burden.
4b9803b [lee19] [SPARK-8563] [MLlib] Fixed a build error.
c2ccd89 [lee19] Added a unit test that validates matrix sizes of svd for [SPARK-8563][MLlib]
8373424 [lee19] [SPARK-8563][MLlib] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k
  • Loading branch information
lee19 authored and mengxr committed Jun 30, 2015
1 parent 8c89896 commit e725262
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class IndexedRowMatrix(
val indexedRows = indices.zip(svd.U.rows).map { case (i, v) =>
IndexedRow(i, v)
}
new IndexedRowMatrix(indexedRows, nRows, nCols)
new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt)
} else {
null
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,17 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(closeToZero(U * brzDiag(s) * V.t - localA))
}

test("validate matrix sizes of svd") {
val k = 2
val A = new IndexedRowMatrix(indexedRows)
val svd = A.computeSVD(k, computeU = true)
assert(svd.U.numRows() === m)
assert(svd.U.numCols() === k)
assert(svd.s.size === k)
assert(svd.V.numRows === n)
assert(svd.V.numCols === k)
}

test("validate k in svd") {
val A = new IndexedRowMatrix(indexedRows)
intercept[IllegalArgumentException] {
Expand Down

0 comments on commit e725262

Please sign in to comment.