From 0e7a3d8599d6eb677e734cd3fadc27d6942a40f9 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 6 Apr 2014 19:25:07 -0700 Subject: [PATCH] Keep vectors sparse in Java when reading LabeledPoints --- .../mllib/api/python/PythonMLLibAPI.scala | 8 +-- .../apache/spark/mllib/linalg/Vectors.scala | 49 +++++++++++++++++-- .../spark/mllib/linalg/VectorsSuite.scala | 38 ++++++++++++++ 3 files changed, 88 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e8d11870fdc02..02633431f569c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -171,8 +171,8 @@ class PythonMLLibAPI extends Serializable { dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => { - val x = deserializeDoubleVector(xBytes).toArray // TODO: deal with sparse vectors here! - LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length))) + val x = deserializeDoubleVector(xBytes) + LabeledPoint(x(0), x.slice(1, x.size)) }) val initialWeights = deserializeDoubleVector(initialWeightsBA) val model = trainFunc(data, initialWeights) @@ -300,8 +300,8 @@ class PythonMLLibAPI extends Serializable { dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double): java.util.List[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => { - val x = deserializeDoubleVector(xBytes).toArray // TODO: make this efficient for sparse vecs - LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length))) + val x = deserializeDoubleVector(xBytes) + LabeledPoint(x(0), x.slice(1, x.size)) }) val model = NaiveBayes.train(data, lambda) val ret = new java.util.LinkedList[java.lang.Object]() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 99a849f1c66b1..6cbe599c1d0db 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -60,6 +60,8 @@ trait Vector extends Serializable { * @param i index */ private[mllib] def apply(i: Int): Double = toBreeze(i) + + private[mllib] def slice(start: Int, end: Int): Vector } /** @@ -130,9 +132,11 @@ object Vectors { private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => - require(v.offset == 0, s"Do not support non-zero offset ${v.offset}.") - require(v.stride == 1, s"Do not support stride other than 1, but got ${v.stride}.") - new DenseVector(v.data) + if (v.offset == 0 && v.stride == 1) { + new DenseVector(v.data) + } else { + new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one + } case v: BSV[Double] => new SparseVector(v.length, v.index, v.data) case v: BV[_] => @@ -155,6 +159,10 @@ class DenseVector(val values: Array[Double]) extends Vector { private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int) = values(i) + + private[mllib] override def slice(start: Int, end: Int): Vector = { + new DenseVector(values.slice(start, end)) + } } /** @@ -185,4 +193,39 @@ class SparseVector( } private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + + override def apply(pos: Int): Double = { + // A more efficient apply() than creating a new Breeze vector + var i = 0 + while (i < indices.length) { + if (indices(i) == pos) { + return values(i) + } else if (indices(i) > pos) { + return 0.0 + } + i += 1 + } + 0.0 + } + + private[mllib] override def slice(start: Int, end: Int): Vector = { + require(start <= end, s"invalid range: ${start} to ${end}") + require(start >= 0, s"invalid range: ${start} to ${end}") + require(end <= size, s"invalid range: ${start} to ${end}") + // Figure out the range of indices that fall within the given bounds + var i = 0 + var indexRangeStart = 0 + var indexRangeEnd = 0 + while (i < indices.length && indices(i) < start) { + i += 1 + } + indexRangeStart = i + while (i < indices.length && indices(i) < end) { + i += 1 + } + indexRangeEnd = i + val newIndices = indices.slice(indexRangeStart, indexRangeEnd).map(_ - start) + val newValues = values.slice(indexRangeStart, indexRangeEnd) + new SparseVector(end - start, newIndices, newValues) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 8a200310e0bb1..098a99e098dc3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -82,4 +82,42 @@ class VectorsSuite extends FunSuite { assert(v.## != another.##) } } + + test("indexing dense vectors") { + val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0) + assert(vec(0) === 1.0) + assert(vec(3) === 4.0) + } + + test("indexing sparse vectors") { + val vec = Vectors.sparse(7, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0)) + assert(vec(0) === 1.0) + assert(vec(1) === 0.0) + assert(vec(2) === 2.0) + assert(vec(3) === 0.0) + assert(vec(6) === 4.0) + val vec2 = Vectors.sparse(8, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0)) + assert(vec2(6) === 4.0) + assert(vec2(7) === 0.0) + } + + test("slicing dense vectors") { + val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0) + val slice = vec.slice(1, 3) + assert(slice === Vectors.dense(2.0, 3.0)) + assert(slice.isInstanceOf[DenseVector], "slice was not DenseVector") + } + + test("slicing sparse vectors") { + val vec = Vectors.sparse(7, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0)) + val slice = vec.slice(1, 5) + assert(slice === Vectors.sparse(4, Array(1,3), Array(2.0, 3.0))) + assert(slice.isInstanceOf[SparseVector], "slice was not SparseVector") + val slice2 = vec.slice(1, 2) + assert(slice2 === Vectors.sparse(1, Array(), Array())) + assert(slice2.isInstanceOf[SparseVector], "slice was not SparseVector") + val slice3 = vec.slice(6, 7) + assert(slice3 === Vectors.sparse(1, Array(0), Array(4.0))) + assert(slice3.isInstanceOf[SparseVector], "slice was not SparseVector") + } }