diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala index 0db867ab04a70..d01dad09834e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala @@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.collection.BitSet +import org.roaringbitmap.RoaringBitmap /** @@ -135,7 +136,7 @@ private[ml] object AltDT extends Logging { } val labelsBc = input.sparkContext.broadcast(labels) // NOTE: Labels are not sorted with features since that would require 1 copy per feature, - // rather than 1 copy per worker. This means a lot of random accesses. + // rather than 1 copy per worker. This means a lot of random accesses. // We could improve this by applying first-level sorting (by node) to labels. // Sort each column by feature values. @@ -196,23 +197,19 @@ private[ml] object AltDT extends Logging { doneLearning = currentLevel + 1 >= strategy.maxDepth || estimatedRemainingActive == 0 if (!doneLearning) { - // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right. - val aggBitVectors: Array[BitSubvector] = - collectBitVectors(partitionInfos, bestSplitsAndGains.map(_._1)) + val splits: Array[Option[Split]] = bestSplitsAndGains.map(_._1) - // Broadcast aggregated bit vectors. On each partition, update instance--node map. - val aggBitVectorsBc = input.sparkContext.broadcast(aggBitVectors) + // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right + val aggBitVector: RoaringBitmap = aggregateBitVector(partitionInfos, splits, numRows) val newPartitionInfos = partitionInfos.map { partitionInfo => - partitionInfo.update(aggBitVectorsBc.value, numNodeOffsets) + partitionInfo.update(aggBitVector, numNodeOffsets) } // TODO: remove. For some reason, this is needed to make things work. // Probably messing up somewhere above... newPartitionInfos.cache().count() partitionInfos = newPartitionInfos - // TODO: unpersist aggBitVectorsBc after action. } - currentLevel += 1 } @@ -333,42 +330,52 @@ private[ml] object AltDT extends Logging { * @param bestSplits Split for each active node, or None if that node will not be split * @return Array of bit vectors, ordered by offset ranges */ - private[impl] def collectBitVectors( + private[impl] def aggregateBitVector( partitionInfos: RDD[PartitionInfo], - bestSplits: Array[Option[Split]]): Array[BitSubvector] = { + bestSplits: Array[Option[Split]], + numRows: Int): RoaringBitmap = { val bestSplitsBc: Broadcast[Array[Option[Split]]] = partitionInfos.sparkContext.broadcast(bestSplits) - val workerBitSubvectors: RDD[Array[BitSubvector]] = partitionInfos.map { + val workerBitSubvectors: RDD[RoaringBitmap] = partitionInfos.map { case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], activeNodes: BitSet) => val localBestSplits: Array[Option[Split]] = bestSplitsBc.value // localFeatureIndex[feature index] = index into PartitionInfo.columns val localFeatureIndex: Map[Int, Int] = columns.map(_.featureIndex).zipWithIndex.toMap - activeNodes.iterator.zip(localBestSplits.iterator).flatMap { + val bitSetForNodes: Iterator[RoaringBitmap] = activeNodes.iterator + .zip(localBestSplits.iterator).flatMap { case (nodeIndexInLevel: Int, Some(split: Split)) => if (localFeatureIndex.contains(split.featureIndex)) { // This partition has the column (feature) used for this split. val fromOffset = nodeOffsets(nodeIndexInLevel) val toOffset = nodeOffsets(nodeIndexInLevel + 1) val colIndex: Int = localFeatureIndex(split.featureIndex) - Iterator(bitSubvectorFromSplit(columns(colIndex), fromOffset, toOffset, split)) + Iterator(bitVectorFromSplit(columns(colIndex), fromOffset, toOffset, split, numRows)) } else { Iterator() } case (nodeIndexInLevel: Int, None) => - // Do not create a BitSubvector when there is no split. - // This requires PartitionInfo.update to handle missing BitSubvectors. + // Do not create a bitVector when there is no split. + // PartitionInfo.update will detect that there is no + // split, by how many instances go left/right. Iterator() - }.toArray + } + if (bitSetForNodes.isEmpty) { + new RoaringBitmap() + } else { + bitSetForNodes.reduce[RoaringBitmap] { (acc, bitv) => acc.or(bitv); acc } + } + } + val aggBitVector: RoaringBitmap = workerBitSubvectors.reduce { (acc, bitv) => + acc.or(bitv) + acc } - val aggBitVectors: Array[BitSubvector] = workerBitSubvectors.reduce(BitSubvector.merge) bestSplitsBc.unpersist() - aggBitVectors + aggBitVector } /** * Choose the best split for a feature at a node. - * * TODO: Return null or None when the split is invalid, such as putting all instances on one * child node. * @@ -787,20 +794,21 @@ private[ml] object AltDT extends Logging { * second by sorted row indices within the node's rows. * bit[index in sorted array of row indices] = false for left, true for right */ - private[impl] def bitSubvectorFromSplit( + private[impl] def bitVectorFromSplit( col: FeatureVector, fromOffset: Int, toOffset: Int, - split: Split): BitSubvector = { - val nodeRowIndices = col.indices.slice(fromOffset, toOffset) - val nodeRowValues = col.values.slice(fromOffset, toOffset) - val nodeRowValuesSortedByIndices = nodeRowIndices.zip(nodeRowValues).sortBy(_._1).map(_._2) - val bitv = new BitSubvector(fromOffset, toOffset) + split: Split, + numRows: Int): RoaringBitmap = { + val nodeRowIndices = col.indices.view.slice(fromOffset, toOffset) + val nodeRowValues = col.values.view.slice(fromOffset, toOffset) + val bitv = new RoaringBitmap() var i = 0 - while (i < nodeRowValuesSortedByIndices.length) { - val value = nodeRowValuesSortedByIndices(i) + while (i < nodeRowValues.length) { + val value = nodeRowValues(i) + val idx = nodeRowIndices(i) if (!split.shouldGoLeft(value)) { - bitv.set(fromOffset + i) + bitv.add(idx) } i += 1 } @@ -833,6 +841,11 @@ private[ml] object AltDT extends Logging { activeNodes: BitSet) extends Serializable { + // pre-allocated temporary buffers that we use to sort + // instances in left and right children during update + val tempVals: Array[Double] = new Array[Double](columns(0).values.length) + val tempIndices: Array[Int] = new Array[Int](columns(0).values.length) + /** For debugging */ override def toString: String = { "PartitionInfo(" + @@ -854,82 +867,82 @@ private[ml] object AltDT extends Logging { * Update nodeOffsets, activeNodes: * Split offsets for nodes which split (which can be identified using the bit vector). * - * @param bitVectors Bit vectors encoding splits for the next level of the tree. + * @param instanceBitVector Bit vector encoding splits for the next level of the tree. * These must follow a 2-level ordering, where the first level is by node * and the second level is by row index. * bitVector(i) = false iff instance i goes to the left child. * For instances at inactive (leaf) nodes, the value can be arbitrary. - * When an active node is not split (e.g., because no good split was found), - * then the corresponding BitSubvector can be missing. * @return Updated partition info */ - def update(bitVectors: Array[BitSubvector], newNumNodeOffsets: Int): PartitionInfo = { - val newColumns = columns.map { oldCol => - val col = oldCol.deepCopy() - var curBitVecIdx = 0 + def update(instanceBitVector: RoaringBitmap, newNumNodeOffsets: Int): + PartitionInfo = { + // Create a 2-level representation of the new nodeOffsets (to be flattened). + // These 2 levels correspond to original nodes and their children (if split). + val newNodeOffsets = nodeOffsets.map(Array(_)) + + val newColumns = columns.map { col => activeNodes.iterator.foreach { nodeIdx => val from = nodeOffsets(nodeIdx) val to = nodeOffsets(nodeIdx + 1) - if (curBitVecIdx + 1 < bitVectors.length && bitVectors(curBitVecIdx).to <= from) { - // If there are no more BitVectors, curBitVecIdx stays at the last bitVector, - // which is acceptable (since it will not cover further nodes which were not split). - curBitVecIdx += 1 - } - val curBitVector = bitVectors(curBitVecIdx) - // If the current BitVector does not cover this node, then this node was not split, - // so we do not need to update its part of the column. Otherwise, we update it. - if (curBitVector.from <= from && to <= curBitVector.to) { - // Sort range [from, to) based on indices. This is required to match the bit vector - // across all workers. See [[bitSubvectorFromSplit]] for details. - val rangeIndices = col.indices.view.slice(from, to).toArray - val rangeValues = col.values.view.slice(from, to).toArray - val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1) - // Sort range [from, to) based on bit vector. - sortedRange.zipWithIndex.map { case ((idx, value), i) => - val bit = curBitVector.get(from + i) - // TODO: In-place merge, rather than general sort. - // TODO: We don't actually need to sort the categorical features using our approach. - (bit, value, idx) - }.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) => - col.values(from + i) = value - col.indices(from + i) = idx + val rangeIndices = col.indices.view.slice(from, to) + val rangeValues = col.values.view.slice(from, to) + + // If this is the very first time we split, + // we don't use rangeIndices to count the number of bits set; + // the entire bit vector will be used, so getCardinality + // will give us the same result more cheaply. + val numBitsSet = if (nodeOffsets.length == 2) instanceBitVector.getCardinality + else rangeIndices.count(instanceBitVector.contains) + + val numBitsNotSet = to - from - numBitsSet // number of instances splitting left + val oldOffset = newNodeOffsets(nodeIdx).head + + // If numBitsNotSet or numBitsSet equals 0, then this node was not split, + // so we do not need to update its part of the column. Otherwise, we update it. + if (numBitsNotSet != 0 && numBitsSet != 0) { + newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numBitsNotSet) + + // BEGIN SORTING + // We sort the [from, to) slice of col based on instance bit, then + // instance value. This is required to match the bit vector across all + // workers. All instances going "left" in the split (which are false) + // should be ordered before the instances going "right". The instanceBitVector + // gives us the bit value for each instance based on the instance's index. + // Then both [from, numBitsNotSet) and [numBitsNotSet, to) need to be sorted + // by value. + // Since the column is already sorted by value, we can compute + // this sort in a single pass over the data. We iterate from start to finish + // (which preserves the sorted order), and then copy the values + // into @tempVals and @tempIndices either: + // 1) in the [from, numBitsNotSet) range if the bit is false, or + // 2) in the [numBitsNotSet, to) range if the bit is true. + var (leftInstanceIdx, rightInstanceIdx) = (from, from + numBitsNotSet) + var idx = 0 + while (idx < rangeValues.length) { + val indexForVal = rangeIndices(idx) + val bit = instanceBitVector.contains(indexForVal) + if (bit) { + tempVals(rightInstanceIdx) = rangeValues(idx) + tempIndices(rightInstanceIdx) = indexForVal + rightInstanceIdx += 1 + } else { + tempVals(leftInstanceIdx) = rangeValues(idx) + tempIndices(leftInstanceIdx) = indexForVal + leftInstanceIdx += 1 + } + idx += 1 } - } - } - col - } + // END SORTING - // Create a 2-level representation of the new nodeOffsets (to be flattened). - // These 2 levels correspond to original nodes and their children (if split). - val newNodeOffsets = nodeOffsets.map(Array(_)) - var curBitVecIdx = 0 - activeNodes.iterator.foreach { nodeIdx => - val from = nodeOffsets(nodeIdx) - val to = nodeOffsets(nodeIdx + 1) - if (curBitVecIdx + 1 < bitVectors.length && bitVectors(curBitVecIdx).to <= from) { - // If there are no more BitVectors, curBitVecIdx stays at the last bitVector, - // which is acceptable (since it will not cover further nodes which were not split). - curBitVecIdx += 1 - } - val curBitVector = bitVectors(curBitVecIdx) - // If the current BitVector does not cover this node, then this node was not split, - // so we do not need to create a new node offset. Otherwise, we create an offset. - if (curBitVector.from <= from && to <= curBitVector.to) { - // Count number of values splitting to left vs. right - val numRight = Range(from, to).count(curBitVector.get) - val numLeft = to - from - numRight - if (numLeft != 0 && numRight != 0) { - // node is split - val oldOffset = newNodeOffsets(nodeIdx).head - newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numLeft) + // update the column values and indices + // with the corresponding indices + Array.copy(tempVals, from, col.values, from, rangeValues.length) + Array.copy(tempIndices, from, col.indices, from, rangeValues.length) } } + col } - assert(newNodeOffsets.map(_.length).sum == newNumNodeOffsets, - s"(W) newNodeOffsets total size: ${newNodeOffsets.map(_.length).sum}," + - s" newNumNodeOffsets: $newNumNodeOffsets") - // Identify the new activeNodes based on the 2-level representation of the new nodeOffsets. val newActiveNodes = new BitSet(newNumNodeOffsets - 1) var newNodeOffsetsIdx = 0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala index c4da28d402c58..03e1bfc570895 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala @@ -21,20 +21,20 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.regression.DecisionTreeRegressor import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.AltDT.{AltDTMetadata, FeatureVector, PartitionInfo} -import org.apache.spark.ml.tree.impl.TreeUtil._ -import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.collection.BitSet +import org.roaringbitmap.RoaringBitmap /** * Test suite for [[AltDT]]. */ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { - //////////////////////////////// Integration tests ////////////////////////////////// + /* * * * * * * * * * * Integration tests * * * * * * * * * * */ test("run deep example") { val data = Range(0, 3).map(x => LabeledPoint(math.pow(x, 3), Vectors.dense(x))) @@ -92,7 +92,7 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.numNodes === 5) } - //////////////////////////////// Helper classes ////////////////////////////////// + /* * * * * * * * * * * Helper classes * * * * * * * * * * */ test("FeatureVector") { val v = new FeatureVector(1, 0, Array(0.1, 0.3, 0.7), Array(1, 2, 0)) @@ -122,11 +122,12 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { // Create bitVector for splitting the 4 rows: L, R, L, R // New groups are {0, 2}, {1, 3} - val bitVector = new BitSubvector(0, numRows) - bitVector.set(1) - bitVector.set(3) + val bitVector = new RoaringBitmap() + bitVector.add(1) + bitVector.add(3) - val newInfo = info.update(Array(bitVector), newNumNodeOffsets = 3) + // for these tests, use the activeNodes for nodeSplitBitVector + val newInfo = info.update(bitVector, newNumNodeOffsets = 3) assert(newInfo.columns.length === 2) val expectedCol1a = @@ -139,12 +140,11 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { assert(newInfo.activeNodes.iterator.toSet === Set(0, 1)) // Create 2 bitVectors for splitting into: 0, 2, 1, 3 - val bv2a = new BitSubvector(0, 2) - bv2a.set(1) - val bv2b = new BitSubvector(2, 4) - bv2b.set(3) + val bitVector2 = new RoaringBitmap() + bitVector2.add(2) // 2 goes to the right + bitVector2.add(3) // 3 goes to the right - val newInfo2 = newInfo.update(Array(bv2a, bv2b), newNumNodeOffsets = 5) + val newInfo2 = newInfo.update(bitVector2, newNumNodeOffsets = 5) assert(newInfo2.columns.length === 2) val expectedCol2a = @@ -157,7 +157,7 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { assert(newInfo2.activeNodes.iterator.toSet === Set(0, 1, 2, 3)) } - //////////////////////////////// Misc ////////////////////////////////// + /* * * * * * * * * * * Misc * * * * * * * * * * */ test("numUnorderedBins") { // Note: We have duplicate bins (the inverse) for unordered features. This should be fixed! @@ -165,7 +165,7 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { assert(AltDT.numUnorderedBins(3) === 6) // 3 categories => 6 bins } - //////////////////////////////// Choosing splits ////////////////////////////////// + /* * * * * * * * * * * Choosing Splits * * * * * * * * * * */ test("computeBestSplits") { // TODO @@ -260,7 +260,7 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { val labels = Seq(0.0, 0.0, 1.0, 1.0, 1.0) val impurity = Entropy val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity) - val (split, stats) = AltDT.chooseUnorderedCategoricalSplit( + val (split, _) = AltDT.chooseUnorderedCategoricalSplit( featureIndex, values, labels, metadata, featureArity) split match { case Some(s: CategoricalSplit) => @@ -337,38 +337,38 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { assert(stats.impurityCalculator.stats === fullImpurityStatsArray) } - //////////////////////////////// Bit subvectors ////////////////////////////////// + /* * * * * * * * * * * Bit subvectors * * * * * * * * * * */ test("bitSubvectorFromSplit: 1 node") { val col = FeatureVector.fromOriginal(0, 0, Vectors.dense(0.1, 0.2, 0.4, 0.6, 0.7)) val fromOffset = 0 val toOffset = col.values.length + val numRows = toOffset val split = new ContinuousSplit(0, threshold = 0.5) - val bitv = AltDT.bitSubvectorFromSplit(col, fromOffset, toOffset, split) - assert(bitv.from === fromOffset) - assert(bitv.to === toOffset) - assert(bitv.iterator.toSet === Set(3, 4)) + val bitv = AltDT.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows) + assert(bitv.toArray.toSet === Set(3, 4)) } test("bitSubvectorFromSplit: 2 nodes") { // Initially, 1 split: (0, 2, 4) | (1, 3) val col = new FeatureVector(0, 0, Array(0.1, 0.2, 0.4, 0.6, 0.7), Array(4, 2, 0, 1, 3)) - def checkSplit(fromOffset: Int, toOffset: Int, threshold: Double, expectedRight: Set[Int]): Unit = { - val split = new ContinuousSplit(0, threshold) - val bitv = AltDT.bitSubvectorFromSplit(col, fromOffset, toOffset, split) - assert(bitv.from === fromOffset) - assert(bitv.to === toOffset) - assert(bitv.iterator.toSet === expectedRight) + def checkSplit(fromOffset: Int, toOffset: Int, threshold: Double, + expectedRight: Set[Int]): Unit = { + val split = new ContinuousSplit(0, threshold) + val numRows = col.values.length + val bitv = AltDT.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows) + assert(bitv.toArray.toSet === expectedRight) } // Left child node - checkSplit(0, 3, 0.15, Set(0, 1)) + checkSplit(0, 3, 0.05, Set(0, 2, 4)) + checkSplit(0, 3, 0.15, Set(0, 2)) checkSplit(0, 3, 0.2, Set(0)) checkSplit(0, 3, 0.5, Set()) // Right child node - checkSplit(3, 5, 0.1, Set(3, 4)) - checkSplit(3, 5, 0.65, Set(4)) + checkSplit(3, 5, 0.1, Set(1, 3)) + checkSplit(3, 5, 0.65, Set(3)) checkSplit(3, 5, 0.8, Set()) } @@ -381,30 +381,25 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes) val partitionInfos = sc.parallelize(Seq(info)) val bestSplit = new ContinuousSplit(0, threshold = 0.5) - val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit))) - assert(bitVectors.length === 1) - val bitv = bitVectors.head - assert(bitv.numBits === numRows) - assert(bitv.iterator.toArray === Array(3, 4)) + val bitVector = AltDT.aggregateBitVector(partitionInfos, Array(Some(bestSplit)), numRows) + assert(bitVector.toArray.toSet === Set(3, 4)) } test("collectBitVectors with 1 vector, with tied threshold") { val col = new FeatureVector(0, 0, - Array(-4.0,-4.0,-2.0,-2.0,-1.0,-1.0,1.0,1.0), Array(3,7,2,6,1,5,0,4)) + Array(-4.0, -4.0, -2.0, -2.0, -1.0, -1.0, 1.0, 1.0), + Array(3, 7, 2, 6, 1, 5, 0, 4)) val numRows = col.values.length val activeNodes = new BitSet(1) activeNodes.set(0) val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes) val partitionInfos = sc.parallelize(Seq(info)) val bestSplit = new ContinuousSplit(0, threshold = -2.0) - val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit))) - assert(bitVectors.length === 1) - val bitv = bitVectors.head - assert(bitv.numBits === numRows) - assert(bitv.iterator.toArray === Array(0, 1, 4, 5)) + val bitVector = AltDT.aggregateBitVector(partitionInfos, Array(Some(bestSplit)), numRows) + assert(bitVector.toArray.toSet === Set(0, 1, 4, 5)) } - //////////////////////////////// Active nodes ////////////////////////////////// + /* * * * * * * * * * * Active nodes * * * * * * * * * * */ test("computeActiveNodePeriphery") { // old periphery: 2 nodes