diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 23cf6bce6dcdc..03f9cbdb9d0a7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -739,7 +739,7 @@ object DecisionTree extends Serializable with Logging {
     val rightCount = rightImpurityCalculator.count
 
     // If left child or right child doesn't satisfy minimum instances per node,
-    // then this split is invalid, return invalid information gain stats
+    // then this split is invalid, return invalid information gain stats.
     if ((leftCount < metadata.minInstancesPerNode) ||
         (rightCount < metadata.minInstancesPerNode)) {
       return InformationGainStats.invalidInformationGainStats
@@ -764,6 +764,9 @@ object DecisionTree extends Serializable with Logging {
     val rightWeight = rightCount / totalCount.toDouble
 
     val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+
+    // If information gain doesn't satisfy minimum information gain,
+    // then this split is invalid, return invalid information gain stats.
     if (gain < metadata.minInfoGain) {
       return InformationGainStats.invalidInformationGainStats
     }
@@ -771,6 +774,13 @@ object DecisionTree extends Serializable with Logging {
     new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
   }
 
+  /**
+   * Calculate the predict value for the current node, given stats of any split.
+   * Note that this function is called only once for each node.
+   * @param leftImpurityCalculator left node aggregates for a split
+   * @param rightImpurityCalculator right node aggregates for a split
+   * @return predict value for the current node
+   */
   private def calculatePredict(
       leftImpurityCalculator: ImpurityCalculator,
       rightImpurityCalculator: ImpurityCalculator): Predict = {
@@ -799,6 +809,7 @@ object DecisionTree extends Serializable with Logging {
 
     logDebug("node impurity = " + nodeImpurity)
 
+    // calculate predict only once
     var predict: Option[Predict] = None
 
     // For each (feature, split), calculate the gain, and select the best (feature, split).
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 4a133e21f461a..f3e2619bd8ba0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -42,5 +42,10 @@ class InformationGainStats(
 
 private[tree] object InformationGainStats {
+  /**
+   * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
+   * denote that the current split doesn't satisfy minimum info gain or
+   * minimum number of instances per node.
+   */
   val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index da1aefb01c6ef..91ce51a3dfdbc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -68,7 +68,11 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType)
 private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
   extends Split(feature, Double.MaxValue, featureType, List())
 
 private[tree] object Split {
+  /**
+   * A [[org.apache.spark.mllib.tree.model.Split]] object to denote that
+   * we can't find a valid split that satisfies minimum info gain
+   * or minimum number of instances per node.
+   */
   val noSplit = new Split(-1, Double.MinValue, FeatureType.Continuous, List())
 }
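
Usage note (not part of the patch above): the two stopping criteria these comments document, minInstancesPerNode and minInfoGain, reach DecisionTree through its Strategy configuration. The sketch below shows how a caller might set them; it assumes the Strategy constructor from this era of MLlib accepts them as named parameters (exact parameter names and defaults can differ between Spark versions), and the maxDepth, impurity, and threshold values are illustrative only.

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.rdd.RDD

def trainWithStoppingCriteria(input: RDD[LabeledPoint]) = {
  // Any candidate split that leaves fewer than 5 instances in a child, or that
  // improves impurity by less than 0.01, is treated as invalid (internally it
  // yields InformationGainStats.invalidInformationGainStats / Split.noSplit),
  // so the node is not split further.
  val strategy = new Strategy(
    algo = Algo.Classification,
    impurity = Gini,
    maxDepth = 5,
    numClassesForClassification = 2,
    minInstancesPerNode = 5,
    minInfoGain = 0.01)
  DecisionTree.train(input, strategy)
}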