Skip to content

Commit

Permalink
Keep vectors sparse in Java when reading LabeledPoints
Browse files Browse the repository at this point in the history
  • Loading branch information
mateiz committed Apr 15, 2014
1 parent eaee759 commit 0e7a3d8
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ class PythonMLLibAPI extends Serializable {
dataBytesJRDD: JavaRDD[Array[Byte]],
initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => {
val x = deserializeDoubleVector(xBytes).toArray // TODO: deal with sparse vectors here!
LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
val x = deserializeDoubleVector(xBytes)
LabeledPoint(x(0), x.slice(1, x.size))
})
val initialWeights = deserializeDoubleVector(initialWeightsBA)
val model = trainFunc(data, initialWeights)
Expand Down Expand Up @@ -300,8 +300,8 @@ class PythonMLLibAPI extends Serializable {
dataBytesJRDD: JavaRDD[Array[Byte]],
lambda: Double): java.util.List[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => {
val x = deserializeDoubleVector(xBytes).toArray // TODO: make this efficient for sparse vecs
LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
val x = deserializeDoubleVector(xBytes)
LabeledPoint(x(0), x.slice(1, x.size))
})
val model = NaiveBayes.train(data, lambda)
val ret = new java.util.LinkedList[java.lang.Object]()
Expand Down
49 changes: 46 additions & 3 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ trait Vector extends Serializable {
* @param i index
*/
private[mllib] def apply(i: Int): Double = toBreeze(i)

private[mllib] def slice(start: Int, end: Int): Vector
}

/**
Expand Down Expand Up @@ -130,9 +132,11 @@ object Vectors {
private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = {
breezeVector match {
case v: BDV[Double] =>
require(v.offset == 0, s"Do not support non-zero offset ${v.offset}.")
require(v.stride == 1, s"Do not support stride other than 1, but got ${v.stride}.")
new DenseVector(v.data)
if (v.offset == 0 && v.stride == 1) {
new DenseVector(v.data)
} else {
new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one
}
case v: BSV[Double] =>
new SparseVector(v.length, v.index, v.data)
case v: BV[_] =>
Expand All @@ -155,6 +159,10 @@ class DenseVector(val values: Array[Double]) extends Vector {
private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)

override def apply(i: Int) = values(i)

private[mllib] override def slice(start: Int, end: Int): Vector = {
new DenseVector(values.slice(start, end))
}
}

/**
Expand Down Expand Up @@ -185,4 +193,39 @@ class SparseVector(
}

private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)

override def apply(pos: Int): Double = {
// A more efficient apply() than creating a new Breeze vector
var i = 0
while (i < indices.length) {
if (indices(i) == pos) {
return values(i)
} else if (indices(i) > pos) {
return 0.0
}
i += 1
}
0.0
}

private[mllib] override def slice(start: Int, end: Int): Vector = {
require(start <= end, s"invalid range: ${start} to ${end}")
require(start >= 0, s"invalid range: ${start} to ${end}")
require(end <= size, s"invalid range: ${start} to ${end}")
// Figure out the range of indices that fall within the given bounds
var i = 0
var indexRangeStart = 0
var indexRangeEnd = 0
while (i < indices.length && indices(i) < start) {
i += 1
}
indexRangeStart = i
while (i < indices.length && indices(i) < end) {
i += 1
}
indexRangeEnd = i
val newIndices = indices.slice(indexRangeStart, indexRangeEnd).map(_ - start)
val newValues = values.slice(indexRangeStart, indexRangeEnd)
new SparseVector(end - start, newIndices, newValues)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,42 @@ class VectorsSuite extends FunSuite {
assert(v.## != another.##)
}
}

test("indexing dense vectors") {
val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0)
assert(vec(0) === 1.0)
assert(vec(3) === 4.0)
}

test("indexing sparse vectors") {
val vec = Vectors.sparse(7, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0))
assert(vec(0) === 1.0)
assert(vec(1) === 0.0)
assert(vec(2) === 2.0)
assert(vec(3) === 0.0)
assert(vec(6) === 4.0)
val vec2 = Vectors.sparse(8, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0))
assert(vec2(6) === 4.0)
assert(vec2(7) === 0.0)
}

test("slicing dense vectors") {
val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0)
val slice = vec.slice(1, 3)
assert(slice === Vectors.dense(2.0, 3.0))
assert(slice.isInstanceOf[DenseVector], "slice was not DenseVector")
}

test("slicing sparse vectors") {
val vec = Vectors.sparse(7, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0))
val slice = vec.slice(1, 5)
assert(slice === Vectors.sparse(4, Array(1,3), Array(2.0, 3.0)))
assert(slice.isInstanceOf[SparseVector], "slice was not SparseVector")
val slice2 = vec.slice(1, 2)
assert(slice2 === Vectors.sparse(1, Array(), Array()))
assert(slice2.isInstanceOf[SparseVector], "slice was not SparseVector")
val slice3 = vec.slice(6, 7)
assert(slice3 === Vectors.sparse(1, Array(0), Array(4.0)))
assert(slice3.isInstanceOf[SparseVector], "slice was not SparseVector")
}
}

0 comments on commit 0e7a3d8

Please sign in to comment.