diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 1682ca91bf832..0130b3e255f0d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -147,24 +147,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
 
       // Fit models in a Future for training in parallel
-      val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
-        Future[Model[_]] {
+      val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
+        Future[Double] {
           val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
-
           if (collectSubModelsParam) {
             subModels.get(splitIndex)(paramIndex) = model
           }
-          model
-        } (executionContext)
-      }
-
-      // Unpersist training data only when all models have trained
-      Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
-        .onComplete { _ => trainingDataset.unpersist() } (executionContext)
-
-      // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up
-      val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
-        modelFuture.map { model =>
           // TODO: duplicate evaluator to take extra params from input
           val metric = eval.evaluate(model.transform(validationDataset, paramMap))
           logDebug(s"Got metric $metric for model trained with $paramMap.")
@@ -174,6 +162,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
 
       // Wait for metrics to be calculated before unpersisting validation dataset
       val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
+      trainingDataset.unpersist()
      validationDataset.unpersist()
       foldMetrics
     }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits
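
For context, not part of the patch: the diff collapses the previous two-stage chain (a Future per fitted model, then a dependent Future per metric) into a single Future[Double] that fits and evaluates in one step. A consequence is that trainingDataset can no longer be unpersisted via Future.sequence(...).onComplete once training finishes; instead both datasets are unpersisted only after every fold metric has been awaited. Below is a minimal, self-contained Scala sketch of the post-patch shape, with hypothetical fit/evaluate stand-ins in place of Spark's Estimator and Evaluator.

import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

// Minimal sketch (not Spark code): one Future per param setting performs
// both fit and evaluate, mirroring the post-patch CrossValidator shape.
object FoldMetricSketch {
  def main(args: Array[String]): Unit = {
    // Bounded pool, analogous to CrossValidator's executionContext.
    val executionContext =
      ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(2))

    // Hypothetical stand-ins for epm / est.fit / eval.evaluate in the diff.
    val epm = Seq(0.1, 0.5, 1.0)                       // pretend param maps
    def fit(p: Double): Double = p * 2.0               // pretend "model"
    def evaluate(model: Double): Double = 1.0 / model  // pretend metric

    val foldMetricFutures = epm.map { paramMap =>
      Future[Double] {
        val model = fit(paramMap)  // training happens inside the future
        evaluate(model)            // ...and evaluation does too
      }(executionContext)
    }

    // Block until every metric is in, and only then release shared
    // resources; this is the point where the patched code unpersists
    // trainingDataset and validationDataset.
    val foldMetrics = foldMetricFutures.map(Await.result(_, Duration.Inf))
    println(s"fold metrics: $foldMetrics")

    executionContext.shutdown()
  }
}

The trade-off visible in the sketch is that the shared training data stays cached until the slowest fit-plus-evaluate future completes, in exchange for a simpler single-stage pipeline with no model futures to keep alive between stages.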