diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 0fe30a3e7040b..486a1bb16af93 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.mllib.point.WeightedLabeledPoint
 
 /**
  * :: Experimental ::
@@ -47,13 +48,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
    */
   def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
 
+    // Convert from the standard instance format to the weighted input format for tree training.
+    val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
+
     // Cache input RDD for speedup during multiple passes.
-    input.cache()
+    weightedInput.cache()
 
     logDebug("algo = " + strategy.algo)
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(weightedInput, strategy)
     val numBins = bins(0).length
     logDebug("numBins = " + numBins)
 
@@ -70,7 +74,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // dummy value for top node (updated during first split calculation)
     val nodes = new Array[Node](maxNumNodes)
     // num features
-    val numFeatures = input.take(1)(0).features.size
+    val numFeatures = weightedInput.take(1)(0).features.size
 
     // Calculate level for single group construction
 
@@ -109,8 +113,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       logDebug("#####################################")
 
       // Find best split for all nodes at a level.
-      val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
-        level, filters, splits, bins, maxLevelForSingleGroup)
+      val splitsStatsForLevel = DecisionTree.findBestSplits(weightedInput, parentImpurities,
+        strategy, level, filters, splits, bins, maxLevelForSingleGroup)
 
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
         // Extract info for nodes at the current level.
@@ -291,7 +295,7 @@ object DecisionTree extends Serializable with Logging {
    * @return array of splits with best splits for all nodes at a given level.
    */
   protected[tree] def findBestSplits(
-      input: RDD[LabeledPoint],
+      input: RDD[WeightedLabeledPoint],
       parentImpurities: Array[Double],
       strategy: Strategy,
       level: Int,
@@ -339,7 +343,7 @@ object DecisionTree extends Serializable with Logging {
    * @return array of splits with best splits for all nodes at a given level.
    */
   private def findBestSplitsPerGroup(
-      input: RDD[LabeledPoint],
+      input: RDD[WeightedLabeledPoint],
       parentImpurities: Array[Double],
       strategy: Strategy,
       level: Int,
@@ -399,7 +403,7 @@ object DecisionTree extends Serializable with Logging {
      * Find whether the sample is valid input for the current node, i.e., whether it passes through
      * all the filters for the current node.
      */
-    def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
+    def isSampleValid(parentFilters: List[Filter], labeledPoint: WeightedLabeledPoint): Boolean = {
       // leaf
       if ((level > 0) & (parentFilters.length == 0)) {
         return false
@@ -438,7 +442,7 @@ object DecisionTree extends Serializable with Logging {
      */
     def findBin(
         featureIndex: Int,
-        labeledPoint: LabeledPoint,
+        labeledPoint: WeightedLabeledPoint,
         isFeatureContinuous: Boolean): Int = {
       val binForFeatures = bins(featureIndex)
       val feature = labeledPoint.features(featureIndex)
@@ -509,7 +513,7 @@ object DecisionTree extends Serializable with Logging {
      * where b_ij is an integer between 0 and numBins - 1.
      * Invalid sample is denoted by noting bin for feature 1 as -1.
      */
-    def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
+    def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = {
       // Calculate bin index and label per feature per node.
       val arr = new Array[Double](1 + (numFeatures * numNodes))
       arr(0) = labeledPoint.label
@@ -982,7 +986,7 @@ object DecisionTree extends Serializable with Logging {
    * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
    */
   protected[tree] def findSplitsBins(
-      input: RDD[LabeledPoint],
+      input: RDD[WeightedLabeledPoint],
       strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
     val count = input.count()
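For readers of this patch: the new WeightedLabeledPoint class referenced above is added elsewhere in the same change set (under org.apache.spark.mllib.point) and its definition does not appear in this diff. The conversion in train only compiles if the class accepts a (label, features) argument list, which suggests a per-instance weight with a default value. A minimal sketch under that assumption follows; the field order and the default of 1.0 are guesses, not the committed code:

package org.apache.spark.mllib.point

import org.apache.spark.mllib.linalg.Vector

/**
 * Sketch of a labeled data point carrying an instance weight, mirroring
 * [[org.apache.spark.mllib.regression.LabeledPoint]] with one extra field.
 * The assumed default weight of 1.0 keeps the conversion in
 * DecisionTree.train weight-neutral: every instance counts equally.
 */
case class WeightedLabeledPoint(label: Double, features: Vector, weight: Double = 1.0)

With a default weight along these lines, the public train(input: RDD[LabeledPoint]) signature keeps its old semantics for existing callers, while weighted algorithms (e.g., boosting) can construct WeightedLabeledPoints with non-uniform weights directly.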