Skip to content

Commit

Permalink
Addressing comments in the PR. Now uses views to slice sub-arrays. Tw…
Browse files Browse the repository at this point in the history
…o pre-allocated buffers instead of one.
  • Loading branch information
fabuzaid21 committed Jan 12, 2016
1 parent 045dab2 commit 402b80b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 44 deletions.
74 changes: 35 additions & 39 deletions mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -800,8 +800,8 @@ private[ml] object AltDT extends Logging {
toOffset: Int,
split: Split,
numRows: Int): RoaringBitmap = {
val nodeRowIndices = col.indices.slice(fromOffset, toOffset)
val nodeRowValues = col.values.slice(fromOffset, toOffset)
val nodeRowIndices = col.indices.view.slice(fromOffset, toOffset)
val nodeRowValues = col.values.view.slice(fromOffset, toOffset)
val bitv = new RoaringBitmap()
var i = 0
while (i < nodeRowValues.length) {
Expand Down Expand Up @@ -841,9 +841,10 @@ private[ml] object AltDT extends Logging {
activeNodes: BitSet)
extends Serializable {

// pre-allocated temporary buffer that we use to sort
// pre-allocated temporary buffers that we use to sort
// instances in left and right children during update
val tempValsIndices: Array[(Double, Int)] = new Array[(Double, Int)](columns(0).values.length)
val tempVals: Array[Double] = new Array[Double](columns(0).values.length)
val tempIndices: Array[Int] = new Array[Int](columns(0).values.length)

/** For debugging */
override def toString: String = {
Expand Down Expand Up @@ -880,54 +881,53 @@ private[ml] object AltDT extends Logging {
val newNodeOffsets = nodeOffsets.map(Array(_))

val newColumns = columns.map { col =>
val iter = activeNodes.iterator
while (iter.hasNext) {
val nodeIdx = iter.next()
activeNodes.iterator.foreach { nodeIdx =>
val from = nodeOffsets(nodeIdx)
val to = nodeOffsets(nodeIdx + 1)
val rangeIndices = col.indices.slice(from, to)
val rangeValues = col.values.slice(from, to)
val rangeIndices = col.indices.view.slice(from, to)
val rangeValues = col.values.view.slice(from, to)

// if this is the very first time we split
// we don't have to use the indices to figure
// out which bits are turned on
// If this is the very first time we split,
// we don't use rangeIndices to count the number of bits set;
// the entire bit vector will be used, so getCardinality
// will give us the same result more cheaply.
val numBitsSet = if (nodeOffsets.length == 2) instanceBitVector.getCardinality
else rangeIndices.count(instanceBitVector.contains)
val numBitsNotSet = to - from - numBitsSet

val numBitsNotSet = to - from - numBitsSet // number of instances splitting left
val oldOffset = newNodeOffsets(nodeIdx).head
// numBitsNotSet == number of instances going to the left
// which is how big the offset should be
// if numBitsNotSet == 0, then this node was not split,

// If numBitsNotSet or numBitsSet equals 0, then this node was not split,
// so we do not need to update its part of the column. Otherwise, we update it.
if (numBitsNotSet == 0 || numBitsSet == 0) {
newNodeOffsets(nodeIdx) = Array(oldOffset)
} else {
if (numBitsNotSet != 0 && numBitsSet != 0) {
newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numBitsNotSet)
// Sort range [from, to) based on split, then value. This is required to match
// the bit vector across all workers. See [[bitVectorFromSplit]] for details.
// Within [from, to), we will have all "left child" instances (those that are false),
// then all "right child" instances. Then, within each child, we sort by value, so
// we can compute the best split for the next iteration. The corresponding index for
// an instance is used to look up the split value ("left" or "right") in the
// instanceBitVector, which is ordered by index.

// BEGIN SORTING
// between [from, numBitsNotSet) and [numBitsNotSet, to)
// the columns need to be sorted by value. Since @rangeValues
// has already been sorted by value, we iterate from beginning to end
// We sort the [from, to) slice of col based on instance bit, then
// instance value. This is required to match the bit vector across all
// workers. All instances going "left" in the split (which are false)
// should be ordered before the instances going "right". The instanceBitVector
// gives us the bit value for each instance based on the instance's index.
// Then both [from, numBitsNotSet) and [numBitsNotSet, to) need to be sorted
// by value.
// Since the column is already sorted by value, we can compute
// this sort in a single pass over the data. We iterate from start to finish
// (which preserves the sorted order), and then copy the values
// into a temporary buffer either 1) in the [from, numBitsNotSet) range
// or 2) in the [numBitsNotSet, to) range.
// into @tempVals and @tempIndices either:
// 1) in the [from, numBitsNotSet) range if the bit is false, or
// 2) in the [numBitsNotSet, to) range if the bit is true.
var (leftInstanceIdx, rightInstanceIdx) = (from, from + numBitsNotSet)
var idx = 0
while (idx < rangeValues.length) {
val indexForVal = rangeIndices(idx)
val bit = instanceBitVector.contains(indexForVal)
if (bit) {
tempValsIndices(rightInstanceIdx) = (rangeValues(idx), indexForVal)
tempVals(rightInstanceIdx) = rangeValues(idx)
tempIndices(rightInstanceIdx) = indexForVal
rightInstanceIdx += 1
} else {
tempValsIndices(leftInstanceIdx) = (rangeValues(idx), indexForVal)
tempVals(leftInstanceIdx) = rangeValues(idx)
tempIndices(leftInstanceIdx) = indexForVal
leftInstanceIdx += 1
}
idx += 1
Expand All @@ -936,12 +936,8 @@ private[ml] object AltDT extends Logging {

// update the column values and indices
// with the corresponding indices
var i = 0
while (i < rangeValues.length) {
col.values(from + i) = tempValsIndices(from + i)._1
col.indices(from + i) = tempValsIndices(from + i)._2
i += 1
}
Array.copy(tempVals, from, col.values, from, rangeValues.length)
Array.copy(tempIndices, from, col.indices, from, rangeValues.length)
}
}
col
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,11 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
val col = new FeatureVector(0, 0, Array(0.1, 0.2, 0.4, 0.6, 0.7),
Array(4, 2, 0, 1, 3))
def checkSplit(fromOffset: Int, toOffset: Int, threshold: Double,
expectedRight: Set[Int]): Unit = {
val split = new ContinuousSplit(0, threshold)
val numRows = col.values.length
val bitv = AltDT.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows)
assert(bitv.toArray.toSet === expectedRight)
expectedRight: Set[Int]): Unit = {
val split = new ContinuousSplit(0, threshold)
val numRows = col.values.length
val bitv = AltDT.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows)
assert(bitv.toArray.toSet === expectedRight)
}
// Left child node
checkSplit(0, 3, 0.05, Set(0, 2, 4))
Expand Down

0 comments on commit 402b80b

Please sign in to comment.