diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index b4466ff40937f..0149938fa719c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -226,17 +226,10 @@ object GradientBoostedTrees extends Logging { logDebug("error of gbt = " + loss.computeError(partialModel, input)) if (validate) { - // Stop training early if - // 1. Reduction in error is less than the validationTol or - // 2. If the error increases, that is if the model is overfit. + // Record the best model if the reduction in error is more than the validationTol. // We want the model returned corresponding to the best validation error. val currentValidateError = loss.computeError(partialModel, validationInput) - if (bestValidateError - currentValidateError < validationTol) { - return new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, - baseLearners.slice(0, bestM), - baseLearnerWeights.slice(0, bestM)) - } else if (currentValidateError < bestValidateError) { + if (currentValidateError < bestValidateError - validationTol) { bestValidateError = currentValidateError bestM = m + 1 } @@ -251,9 +244,15 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - - new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) + if (validate) { + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, + baseLearners.slice(0, bestM), + baseLearnerWeights.slice(0, bestM)) + } else { + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) + } } }