diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 073b7542da4c2..5331e0e855157 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -77,4 +77,39 @@ class MatricesSuite extends FunSuite { assert(!sparseMat.toArray.eq(sparseCopy.toArray)) } + + test("matrix indexing and updating") { + val m = 3 + val n = 2 + val allValues = Array(0.0, 1.0, 2.0, 3.0, 4.0, 0.0) + + val denseMat = new DenseMatrix(m, n, allValues) + + assert(denseMat(0, 1) == 3.0) + assert(denseMat(0, 1) == denseMat.values(3)) + assert(denseMat(0, 1) == denseMat(3)) + assert(denseMat(0, 0) == 0.0) + + denseMat.update(0, 0, 10.0) + assert(denseMat(0, 0) == 10.0) + assert(denseMat.values(0) == 10.0) + + val sparseValues = Array(1.0, 2.0, 3.0, 4.0) + val colIndices = Array(0, 2, 4) + val rowIndices = Array(1, 2, 0, 1) + val sparseMat = new SparseMatrix(m, n, colIndices, rowIndices, sparseValues) + + assert(sparseMat(0, 1) == 3.0) + assert(sparseMat(0, 1) == denseMat.values(2)) + assert(sparseMat(0, 1) == denseMat(2)) + assert(sparseMat(0, 0) == 0.0) + + intercept[IllegalArgumentException] { + sparseMat.update(0, 0, 10.0) + } + + sparseMat.update(0, 1, 10.0) + assert(sparseMat(0, 1) == 10.0) + assert(sparseMat.values(2) == 10.0) + } }