Skip to content
This repository has been archived by the owner on Nov 19, 2020. It is now read-only.

Commit

Permalink
ErrorBasedPruning issues
Browse files: browse the repository at this point in the history
  • Loading branch information
YakanMS committed May 25, 2016
1 parent 347dd92 commit 4af4897
Showing 1 changed file with 19 additions and 49 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -169,8 +169,6 @@ public double Run()
return computeError();
}



private double computeError()
{
return new ZeroOneLoss(outputs) { Mean = true }.Loss(tree.Decide(inputs));
Expand All @@ -185,7 +183,7 @@ private double computeError()
private bool compute(DecisionNode node)
{
int[] indices = subsets[node].ToArray();
int[] subset = outputs.Submatrix(indices);
int[] outputSubset = outputs.Submatrix(indices);

if (indices.Length == 0)
{
Expand All @@ -195,41 +193,27 @@ private bool compute(DecisionNode node)
node.Branches = null;
node.Output = null;

foreach (var child in node)
subsets[child].Clear();

for (int i = 0; i < inputs.Length; i++)
trackDecisions(node, inputs[i], i);

return true;
}

int size = indices.Length;
int mostCommon = subset.Mode();
DecisionNode maxChild = getMaxChild(node);

double replace = Double.PositiveInfinity;
if (maxChild != null)
{
replace = computeErrorReplacingSubtrees(node, maxChild);
replace = upperBound(replace, size);
}

double baselineError = computeError();
baselineError = upperBound(baselineError, size);

double baseline = computeErrorSubtree(indices);
double prune = computeErrorWithoutSubtree(node, mostCommon);


baseline = upperBound(baseline, size);
prune = upperBound(prune, size);
int mostCommon = outputSubset.Mode();
double pruneError = computeErrorWithoutSubtree(node, mostCommon);
pruneError = upperBound(pruneError, size);

DecisionNode maxChild = getMaxChild(node);
double replaceError = computeErrorReplacingSubtrees(node, maxChild);
replaceError = upperBound(replaceError, size);

bool changed = false;

if (Math.Abs(prune - baseline) < limit ||
Math.Abs(replace - baseline) < limit)
if (Math.Abs(pruneError - baselineError) < limit ||
Math.Abs(replaceError - baselineError) < limit)
{
if (replace < prune)
if (replaceError < pruneError)
{
// We should replace the subtree with its maximum child
node.Branches = maxChild.Branches;
Expand All @@ -249,8 +233,9 @@ private bool compute(DecisionNode node)
foreach (var child in node)
subsets[child].Clear();

for (int i = 0; i < inputs.Length; i++)
trackDecisions(node, inputs[i], i);
double[][] inputSubset = inputs.Submatrix(indices);
for (int i = 0; i < inputSubset.Length; i++)
trackDecisions(node, inputSubset[i], i);
}

return changed;
Expand Down Expand Up @@ -294,33 +279,18 @@ private double computeErrorReplacingSubtrees(DecisionNode tree, DecisionNode chi
return error;
}

/// <summary>
///   Computes the fraction of the given sample positions at which
///   <c>outputs</c> and <c>actual</c> hold different values.
/// </summary>
/// <param name="indices">Positions to compare in <c>outputs</c> and <c>actual</c>.</param>
/// <returns>
///   The number of mismatching positions divided by <c>indices.Length</c>.
/// </returns>
private double computeErrorSubtree(int[] indices)
{
    int mismatches = 0;

    for (int k = 0; k < indices.Length; k++)
    {
        int sample = indices[k];
        if (outputs[sample] != actual[sample])
            mismatches++;
    }

    // Promote the denominator to double so the ratio is not truncated.
    double total = indices.Length;
    return mismatches / total;
}

private DecisionNode getMaxChild(DecisionNode tree)
{
DecisionNode max = null;
int maxCount = 0;

foreach (var child in tree.Branches)
{
if (child.Branches != null)
var list = subsets[child];
if (list.Count > maxCount)
{
foreach (var node in child.Branches)
{
var list = subsets[node];
if (list.Count > maxCount)
{
max = node;
maxCount = list.Count;
}
}
max = child;
maxCount = list.Count;
}
}

Expand Down

0 comments on commit 4af4897

Please sign in to comment.