Skip to content
This repository has been archived by the owner on Nov 19, 2020. It is now read-only.

Commit

Permalink
Merge pull request #238 from YaronK/development
Browse files Browse the repository at this point in the history
Merging pull request for ErrorBasedPruning issues: 

Fixes GH-234, GH-235, GH-236. GH-237
  • Loading branch information
cesarsouza committed Jun 1, 2016
2 parents 347dd92 + 38e11e7 commit 254ca3a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,6 @@ public double Run()
return computeError();
}



private double computeError()
{
return new ZeroOneLoss(outputs) { Mean = true }.Loss(tree.Decide(inputs));
Expand All @@ -185,7 +183,7 @@ private double computeError()
private bool compute(DecisionNode node)
{
int[] indices = subsets[node].ToArray();
int[] subset = outputs.Submatrix(indices);
int[] outputSubset = outputs.Submatrix(indices);

if (indices.Length == 0)
{
Expand All @@ -195,41 +193,27 @@ private bool compute(DecisionNode node)
node.Branches = null;
node.Output = null;

foreach (var child in node)
subsets[child].Clear();

for (int i = 0; i < inputs.Length; i++)
trackDecisions(node, inputs[i], i);

return true;
}

int size = indices.Length;
int mostCommon = subset.Mode();
DecisionNode maxChild = getMaxChild(node);

double replace = Double.PositiveInfinity;
if (maxChild != null)
{
replace = computeErrorReplacingSubtrees(node, maxChild);
replace = upperBound(replace, size);
}

double baselineError = computeError();
baselineError = upperBound(baselineError, size);

double baseline = computeErrorSubtree(indices);
double prune = computeErrorWithoutSubtree(node, mostCommon);


baseline = upperBound(baseline, size);
prune = upperBound(prune, size);
int mostCommon = outputSubset.Mode();
double pruneError = computeErrorWithoutSubtree(node, mostCommon);
pruneError = upperBound(pruneError, size);

DecisionNode maxChild = getMaxChild(node);
double replaceError = computeErrorReplacingSubtrees(node, maxChild);
replaceError = upperBound(replaceError, size);

bool changed = false;

if (Math.Abs(prune - baseline) < limit ||
Math.Abs(replace - baseline) < limit)
if (Math.Abs(pruneError - baselineError) < limit ||
Math.Abs(replaceError - baselineError) < limit)
{
if (replace < prune)
if (replaceError < pruneError)
{
// We should replace the subtree with its maximum child
node.Branches = maxChild.Branches;
Expand All @@ -249,8 +233,9 @@ private bool compute(DecisionNode node)
foreach (var child in node)
subsets[child].Clear();

for (int i = 0; i < inputs.Length; i++)
trackDecisions(node, inputs[i], i);
double[][] inputSubset = inputs.Submatrix(indices);
for (int i = 0; i < inputSubset.Length; i++)
trackDecisions(node, inputSubset[i], i);
}

return changed;
Expand Down Expand Up @@ -294,33 +279,18 @@ private double computeErrorReplacingSubtrees(DecisionNode tree, DecisionNode chi
return error;
}

private double computeErrorSubtree(int[] indices)
{
int error = 0;
foreach (int i in indices)
if (outputs[i] != actual[i]) error++;

return error / (double)indices.Length;
}

private DecisionNode getMaxChild(DecisionNode tree)
{
DecisionNode max = null;
int maxCount = 0;

foreach (var child in tree.Branches)
{
if (child.Branches != null)
var list = subsets[child];
if (list.Count > maxCount)
{
foreach (var node in child.Branches)
{
var list = subsets[node];
if (list.Count > maxCount)
{
max = node;
maxCount = list.Count;
}
}
max = child;
maxCount = list.Count;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ namespace Accord.Tests.MachineLearning
using Accord.Statistics.Filters;
using Accord.Math;
using Accord.Tests.MachineLearning.Properties;
using Accord.MachineLearning.DecisionTrees.Learning;
using Accord.MachineLearning.DecisionTrees.Learning;


[TestFixture]
public class ErrorBasedpruningTest
{
Expand Down Expand Up @@ -84,9 +84,9 @@ public void RunTest()
foreach (var node in tree)
nodeCount2++;

Assert.AreEqual(0.25459770114942532, error);
Assert.AreEqual(0.28922413793103446, error);
Assert.AreEqual(447, nodeCount);
Assert.AreEqual(193, nodeCount2);
Assert.AreEqual(424, nodeCount2);
}


Expand All @@ -109,14 +109,14 @@ public void RunTest3()
actual[i] = nodeCount2;
}

double[] expected = { 447, 193, 145, 140, 124, 117, 109, 103, 95, 87 };
double[] expected = { 447, 424, 410, 402, 376, 362, 354, 348, 336, 322 };

for (int i = 0; i < actual.Length; i++)
Assert.AreEqual(expected[i], actual[i]);
}

private static void repeat(double[][] inputs, int[] outputs,
DecisionTree tree, int training, double threshold,
private static void repeat(double[][] inputs, int[] outputs,
DecisionTree tree, int training, double threshold,
out int nodeCount2)
{
int nodeCount = 0;
Expand Down

0 comments on commit 254ca3a

Please sign in to comment.