diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 74be91fd9e1b9..16af62da6f1e4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -683,8 +683,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, minInstancesPerNode = 4) + val strategy = new Strategy(algo = Classification, impurity = Gini, + maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2) val model = DecisionTree.train(input, strategy) assert(model.topNode.isLeaf) @@ -701,11 +701,37 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) + assert(bestSplits.length == 1) val bestInfoStats = bestSplits(0)._2 assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) } + test("don't chose split that doesn't satify min instance per node requirements") { + // if a split doesn't satisfy min instances per node requirements, + // this split is invalid, even though the information gain of split is large. + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, + maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), + numClassesForClassification = 2, minInstancesPerNode = 2) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, + new Array[Node](0), splits, bins, 10) + + assert(bestSplits.length == 1) + val bestSplit = bestSplits(0)._1 + val bestSplitStats = bestSplits(0)._1 + assert(bestSplit.feature == 1) + assert(bestSplitStats != InformationGainStats.invalidInformationGainStats) + } + test("split must satisfy min info gain requirements") { val arr = new Array[LabeledPoint](3) arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) @@ -731,7 +757,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) + assert(bestSplits.length == 1) val bestInfoStats = bestSplits(0)._2 assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) }