diff --git a/Sources/Accord.MachineLearning/DecisionTrees/Pruning/ErrorBasedPruning.cs b/Sources/Accord.MachineLearning/DecisionTrees/Pruning/ErrorBasedPruning.cs index d4452cbd0..4caaab036 100644 --- a/Sources/Accord.MachineLearning/DecisionTrees/Pruning/ErrorBasedPruning.cs +++ b/Sources/Accord.MachineLearning/DecisionTrees/Pruning/ErrorBasedPruning.cs @@ -169,8 +169,6 @@ public double Run() return computeError(); } - - private double computeError() { return new ZeroOneLoss(outputs) { Mean = true }.Loss(tree.Decide(inputs)); @@ -185,7 +183,7 @@ private double computeError() private bool compute(DecisionNode node) { int[] indices = subsets[node].ToArray(); - int[] subset = outputs.Submatrix(indices); + int[] outputSubset = outputs.Submatrix(indices); if (indices.Length == 0) { @@ -195,41 +193,27 @@ private bool compute(DecisionNode node) node.Branches = null; node.Output = null; - foreach (var child in node) - subsets[child].Clear(); - - for (int i = 0; i < inputs.Length; i++) - trackDecisions(node, inputs[i], i); - return true; } int size = indices.Length; - int mostCommon = subset.Mode(); - DecisionNode maxChild = getMaxChild(node); - - double replace = Double.PositiveInfinity; - if (maxChild != null) - { - replace = computeErrorReplacingSubtrees(node, maxChild); - replace = upperBound(replace, size); - } + double baselineError = computeError(); + baselineError = upperBound(baselineError, size); - double baseline = computeErrorSubtree(indices); - double prune = computeErrorWithoutSubtree(node, mostCommon); - - - baseline = upperBound(baseline, size); - prune = upperBound(prune, size); + int mostCommon = outputSubset.Mode(); + double pruneError = computeErrorWithoutSubtree(node, mostCommon); + pruneError = upperBound(pruneError, size); + DecisionNode maxChild = getMaxChild(node); + double replaceError = computeErrorReplacingSubtrees(node, maxChild); + replaceError = upperBound(replaceError, size); bool changed = false; - - if (Math.Abs(prune - baseline) < limit || - Math.Abs(replace - baseline) < limit) + if (Math.Abs(pruneError - baselineError) < limit || + Math.Abs(replaceError - baselineError) < limit) { - if (replace < prune) + if (replaceError < pruneError) { // We should replace the subtree with its maximum child node.Branches = maxChild.Branches; @@ -249,8 +233,9 @@ private bool compute(DecisionNode node) foreach (var child in node) subsets[child].Clear(); - for (int i = 0; i < inputs.Length; i++) - trackDecisions(node, inputs[i], i); + double[][] inputSubset = inputs.Submatrix(indices); + for (int i = 0; i < inputSubset.Length; i++) + trackDecisions(node, inputSubset[i], i); } return changed; @@ -294,15 +279,6 @@ private double computeErrorReplacingSubtrees(DecisionNode tree, DecisionNode chi return error; } - private double computeErrorSubtree(int[] indices) - { - int error = 0; - foreach (int i in indices) - if (outputs[i] != actual[i]) error++; - - return error / (double)indices.Length; - } - private DecisionNode getMaxChild(DecisionNode tree) { DecisionNode max = null; @@ -310,17 +286,11 @@ private DecisionNode getMaxChild(DecisionNode tree) foreach (var child in tree.Branches) { - if (child.Branches != null) + var list = subsets[child]; + if (list.Count > maxCount) { - foreach (var node in child.Branches) - { - var list = subsets[node]; - if (list.Count > maxCount) - { - max = node; - maxCount = list.Count; - } - } + max = child; + maxCount = list.Count; } } diff --git a/Unit Tests/Accord.Tests.MachineLearning/DecisionTrees/ErrorBasedPruningTest.cs b/Unit Tests/Accord.Tests.MachineLearning/DecisionTrees/ErrorBasedPruningTest.cs index 28cf5be8e..fd2c88a10 100644 --- a/Unit Tests/Accord.Tests.MachineLearning/DecisionTrees/ErrorBasedPruningTest.cs +++ b/Unit Tests/Accord.Tests.MachineLearning/DecisionTrees/ErrorBasedPruningTest.cs @@ -30,9 +30,9 @@ namespace Accord.Tests.MachineLearning using Accord.Statistics.Filters; using Accord.Math; using Accord.Tests.MachineLearning.Properties; - using Accord.MachineLearning.DecisionTrees.Learning; - - + using Accord.MachineLearning.DecisionTrees.Learning; + + [TestFixture] public class ErrorBasedpruningTest { @@ -84,9 +84,9 @@ public void RunTest() foreach (var node in tree) nodeCount2++; - Assert.AreEqual(0.25459770114942532, error); + Assert.AreEqual(0.28922413793103446, error); Assert.AreEqual(447, nodeCount); - Assert.AreEqual(193, nodeCount2); + Assert.AreEqual(424, nodeCount2); } @@ -109,14 +109,14 @@ public void RunTest3() actual[i] = nodeCount2; } - double[] expected = { 447, 193, 145, 140, 124, 117, 109, 103, 95, 87 }; + double[] expected = { 447, 424, 410, 402, 376, 362, 354, 348, 336, 322 }; for (int i = 0; i < actual.Length; i++) Assert.AreEqual(expected[i], actual[i]); } - private static void repeat(double[][] inputs, int[] outputs, - DecisionTree tree, int training, double threshold, + private static void repeat(double[][] inputs, int[] outputs, + DecisionTree tree, int training, double threshold, out int nodeCount2) { int nodeCount = 0;