Skip to content

Commit

Permalink
Added stats from cross validation as a val in the cross validation mo…
Browse files Browse the repository at this point in the history
…del to save them for user access
  • Loading branch information
leahmcguire committed Jun 2, 2015
1 parent ad06727 commit 3a995da
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
copyValues(new CrossValidatorModel(uid, bestModel).setParent(this))
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
}

override def transformSchema(schema: StructType): StructType = {
Expand All @@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
@Experimental
class CrossValidatorModel private[ml] (
override val uid: String,
val bestModel: Model[_])
val bestModel: Model[_],
val crossValidationMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams {

override def validateParams(): Unit = {
Expand All @@ -175,7 +176,7 @@ class CrossValidatorModel private[ml] (
}

override def copy(extra: ParamMap): CrossValidatorModel = {
val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]], crossValidationMetrics.clone())
copyValues(copied, extra)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
assert(cvModel.crossValidationMetrics.length == 4)
}

test("validateParams should check estimatorParamMaps") {
Expand Down

0 comments on commit 3a995da

Please sign in to comment.