[SPARK-10064] [ML] Parallelize decision tree bin split calculations
Reimplement `DecisionTree.findSplitsBins` via `RDD` to parallelize bin calculation.

With large feature spaces, the current implementation is very slow. This change limits the features that are distributed (or collected) to just the continuous features, and performs the split calculations in parallel. It completes on a real multi-terabyte dataset in less than a minute instead of multiple hours.

Author: Nathan Howell <[email protected]>

Closes #8246 from NathanHowell/SPARK-10064.
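
For illustration, here is a minimal standalone sketch of the pattern this change adopts: key each continuous feature's sampled values by feature index, group them per feature, compute split candidates for every feature in parallel, and collect only the small per-feature result back to the driver. The `computeSplits` helper and the toy input below are hypothetical stand-ins for MLlib's quantile-based `findSplitsForContinuousFeature` and the sampled `LabeledPoint` data; this is a sketch of the approach, not the committed code.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

object ParallelSplitSketch {
  // Hypothetical stand-in for MLlib's quantile-based split finder: it just
  // returns midpoints between consecutive distinct values (illustration only).
  def computeSplits(values: Iterable[Double], maxBins: Int): Array[Double] = {
    val sorted = values.toArray.distinct.sorted
    sorted.sliding(2)
      .collect { case Array(a, b) => (a + b) / 2.0 }
      .take(maxBins - 1)
      .toArray
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("split-sketch").setMaster("local[*]"))

    // Toy dataset: each row is a dense feature vector, all features continuous.
    val rows: RDD[Array[Double]] = sc.parallelize(Seq(
      Array(1.0, 10.0), Array(2.0, 20.0), Array(3.0, 30.0), Array(4.0, 40.0)))
    val continuousFeatures = Seq(0, 1)

    // Cap parallelism at the number of continuous features so no empty tasks are scheduled.
    val numPartitions = math.min(continuousFeatures.length, rows.partitions.length)

    val splitsByFeature = rows
      .flatMap(row => continuousFeatures.map(idx => (idx, row(idx)))) // key values by feature index
      .groupByKey(numPartitions)                                      // one group per continuous feature
      .mapValues(values => computeSplits(values, maxBins = 32))       // split search runs per feature, in parallel
      .collectAsMap()                                                 // only the split thresholds reach the driver

    splitsByFeature.toSeq.sortBy(_._1).foreach { case (i, s) =>
      println(s"feature $i -> splits ${s.mkString(", ")}")
    }
    sc.stop()
  }
}
```

Note how the driver no longer collects the sampled rows themselves, as the old `sampledInput` path did with `.collect()`; only the per-feature split arrays come back, and capping the partition count avoids scheduling tasks that would have no feature to process.
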
Nathan Howell authored and jkbradley committed Oct 8, 2015
1 parent 075a0b6 commit 1bc435a
Showing 4 changed files with 97 additions and 95 deletions.
164 changes: 86 additions & 78 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuilder

import org.apache.spark.Logging
import org.apache.spark.annotation.{Experimental, Since}
@@ -643,8 +642,8 @@ object DecisionTree extends Serializable with Logging {

val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
.map { case (nodeIndex, aggStats) =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
nodeToFeatures(nodeIndex)
}

// find best split for each node
@@ -976,8 +975,8 @@ object DecisionTree extends Serializable with Logging {
val numFeatures = metadata.numFeatures

// Sample the input only if there are continuous features.
val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
val sampledInput = if (hasContinuousFeatures) {
val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
val sampledInput = if (continuousFeatures.nonEmpty) {
// Calculate the number of samples for approximate quantile calculation.
val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
val fraction = if (requiredSamples < metadata.numExamples) {
@@ -986,88 +985,97 @@
1.0
}
logDebug("fraction of data used for calculating quantiles = " + fraction)
input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt())
} else {
new Array[LabeledPoint](0)
input.sparkContext.emptyRDD[LabeledPoint]
}

metadata.quantileStrategy match {
case Sort =>
val splits = new Array[Array[Split]](numFeatures)
val bins = new Array[Array[Bin]](numFeatures)

// Find all splits.
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
if (metadata.isContinuous(featureIndex)) {
val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
val featureSplits = findSplitsForContinuousFeature(featureSamples,
metadata, featureIndex)

val numSplits = featureSplits.length
val numBins = numSplits + 1
logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
splits(featureIndex) = new Array[Split](numSplits)
bins(featureIndex) = new Array[Bin](numBins)

var splitIndex = 0
while (splitIndex < numSplits) {
val threshold = featureSplits(splitIndex)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, threshold, Continuous, List())
splitIndex += 1
}
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue)

splitIndex = 1
while (splitIndex < numSplits) {
bins(featureIndex)(splitIndex) =
new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
Continuous, Double.MinValue)
splitIndex += 1
}
bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
} else {
val numSplits = metadata.numSplits(featureIndex)
val numBins = metadata.numBins(featureIndex)
// Categorical feature
val featureArity = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
// Unordered features
// 2^(maxFeatureValue - 1) - 1 combinations
splits(featureIndex) = new Array[Split](numSplits)
var splitIndex = 0
while (splitIndex < numSplits) {
val categories: List[Double] =
extractMultiClassCategories(splitIndex + 1, featureArity)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, Double.MinValue, Categorical, categories)
splitIndex += 1
}
} else {
// Ordered features
// Bins correspond to feature values, so we do not need to compute splits or bins
// beforehand. Splits are constructed as needed during training.
splits(featureIndex) = new Array[Split](0)
}
// For ordered features, bins correspond to feature values.
// For unordered categorical features, there is no need to construct the bins.
// since there is a one-to-one correspondence between the splits and the bins.
bins(featureIndex) = new Array[Bin](0)
}
featureIndex += 1
}
(splits, bins)
findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
case MinMax =>
throw new UnsupportedOperationException("minmax not supported yet.")
case ApproxHist =>
throw new UnsupportedOperationException("approximate histogram not supported yet.")
}
}

private def findSplitsBinsBySorting(
input: RDD[LabeledPoint],
metadata: DecisionTreeMetadata,
continuousFeatures: IndexedSeq[Int]): (Array[Array[Split]], Array[Array[Bin]]) = {
def findSplits(
featureIndex: Int,
featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
val splits = {
val featureSplits = findSplitsForContinuousFeature(
featureSamples.toArray,
metadata,
featureIndex)
logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")

featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil))
}

val bins = {
val lowSplit = new DummyLowSplit(featureIndex, Continuous)
val highSplit = new DummyHighSplit(featureIndex, Continuous)

// tack the dummy splits on either side of the computed splits
val allSplits = lowSplit +: splits.toSeq :+ highSplit

// slide across the split points pairwise to allocate the bins
allSplits.sliding(2).map {
case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue)
}.toArray
}

(featureIndex, (splits, bins))
}

val continuousSplits = {
// reduce the parallelism for split computations when there are less
// continuous features than input partitions. this prevents tasks from
// being spun up that will definitely do no work.
val numPartitions = math.min(continuousFeatures.length, input.partitions.length)

input
.flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
.groupByKey(numPartitions)
.map { case (k, v) => findSplits(k, v) }
.collectAsMap()
}

val numFeatures = metadata.numFeatures
val (splits, bins) = Range(0, numFeatures).unzip {
case i if metadata.isContinuous(i) =>
val (split, bin) = continuousSplits(i)
metadata.setNumSplits(i, split.length)
(split, bin)

case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
// Unordered features
// 2^(maxFeatureValue - 1) - 1 combinations
val featureArity = metadata.featureArity(i)
val split = Range(0, metadata.numSplits(i)).map { splitIndex =>
val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
new Split(i, Double.MinValue, Categorical, categories)
}

// For unordered categorical features, there is no need to construct the bins.
// since there is a one-to-one correspondence between the splits and the bins.
(split.toArray, Array.empty[Bin])

case i if metadata.isCategorical(i) =>
// Ordered features
// Bins correspond to feature values, so we do not need to compute splits or bins
// beforehand. Splits are constructed as needed during training.
(Array.empty[Split], Array.empty[Bin])
}

(splits.toArray, bins.toArray)
}

/**
* Nested method to extract list of eligible categories given an index. It extracts the
* position of ones in a binary representation of the input. If binary
@@ -1131,7 +1139,7 @@ object DecisionTree extends Serializable with Logging {
logDebug("stride = " + stride)

// iterate `valueCount` to find splits
val splitsBuilder = ArrayBuilder.make[Double]
val splitsBuilder = Array.newBuilder[Double]
var index = 1
// currentCount: sum of counts of values that have been visited
var currentCount = valueCounts(0)._2
@@ -1163,8 +1171,8 @@ object DecisionTree extends Serializable with Logging {
assert(splits.length > 0,
s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
" Please remove this feature and then try again.")
// set number of splits accordingly
metadata.setNumSplits(featureIndex, splits.length)

// the split metadata must be updated on the driver

splits
}
@@ -108,21 +108,21 @@ private[spark] class NodeIdCache(

prevNodeIdsForInstances = nodeIdsForInstances
nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
dataPoint => {
case (point, node) => {
var treeId = 0
while (treeId < nodeIdUpdaters.length) {
val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null)
val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null)
if (nodeIdUpdater != null) {
val newNodeIndex = nodeIdUpdater.updateNodeIndex(
binnedFeatures = dataPoint._1.datum.binnedFeatures,
binnedFeatures = point.datum.binnedFeatures,
bins = bins)
dataPoint._2(treeId) = newNodeIndex
node(treeId) = newNodeIndex
}

treeId += 1
}

dataPoint._2
node
}
}

@@ -138,7 +138,7 @@ private[spark] class NodeIdCache(
while (checkpointQueue.size > 1 && canDelete) {
// We can delete the oldest checkpoint iff
// the next checkpoint actually exists in the file system.
if (checkpointQueue.get(1).get.getCheckpointFile != None) {
if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
val old = checkpointQueue.dequeue()

// Since the old checkpoint is not deleted by Spark,
@@ -159,11 +159,11 @@ private[spark] class NodeIdCache(
* Call this after training is finished to delete any remaining checkpoints.
*/
def deleteAllCheckpoints(): Unit = {
while (checkpointQueue.size > 0) {
while (checkpointQueue.nonEmpty) {
val old = checkpointQueue.dequeue()
if (old.getCheckpointFile != None) {
for (checkpointFile <- old.getCheckpointFile) {
val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
fs.delete(new Path(old.getCheckpointFile.get), true)
fs.delete(new Path(checkpointFile), true)
}
}
if (prevNodeIdsForInstances != null) {
@@ -135,8 +135,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 3)
assert(fakeMetadata.numSplits(0) === 3)
assert(fakeMetadata.numBins(0) === 4)
// check returned splits are distinct
assert(splits.distinct.length === splits.length)
}
@@ -151,8 +149,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 2)
assert(fakeMetadata.numSplits(0) === 2)
assert(fakeMetadata.numBins(0) === 3)
assert(splits(0) === 2.0)
assert(splits(1) === 3.0)
}
@@ -167,8 +163,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 1)
assert(fakeMetadata.numSplits(0) === 1)
assert(fakeMetadata.numBins(0) === 2)
assert(splits(0) === 1.0)
}
}
@@ -69,8 +69,8 @@ object EnsembleTestHelper {
required: Double,
metricName: String = "mse") {
val predictions = input.map(x => model.predict(x.features))
val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) =>
label - prediction
val errors = predictions.zip(input).map { case (prediction, point) =>
point.label - prediction
}
val metric = metricName match {
case "mse" =>