diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 8d17459eb4b8b..d779e602545cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -382,8 +382,10 @@ class GaussianMixture @Since("2.0.0") ( val spark = dataset.sparkSession import spark.implicits._ - val sc = spark.sparkContext - val numClusters = $(k) + val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol)) + require(numFeatures < GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " + + s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" + + s" matrix is quadratic in the number of features.") val handlePersistence = dataset.storageLevel == StorageLevel.NONE @@ -401,12 +403,8 @@ class GaussianMixture @Since("2.0.0") ( instances.persist(StorageLevel.MEMORY_AND_DISK) } - // Extract the number of features. - val numFeatures = MetadataUtils.getNumFeatures(dataset.schema($(featuresCol))) - .getOrElse(instances.first()._1.size) - require(numFeatures < GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " + - s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" + - s" matrix is quadratic in the number of features.") + val sc = spark.sparkContext + val numClusters = $(k) instr.logPipelineStage(this) instr.logDataset(dataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index e20163977755c..6db0408e8d2b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -71,24 +71,6 @@ private[spark] object MetadataUtils { } } - /** - * Examine a schema to identify the number of features in a vector column. - * Returns None if the number of features is not specified. - */ - def getNumFeatures(vectorSchema: StructField): Option[Int] = { - if (vectorSchema.dataType == new VectorUDT) { - val group = AttributeGroup.fromStructField(vectorSchema) - val size = group.size - if (size >= 0) { - Some(size) - } else { - None - } - } else { - None - } - } - /** * Examine a schema to identify categorical (Binary and Nominal) features. *