Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengruifeng committed Dec 25, 2019
1 parent f22d22a commit 53efa70
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,10 @@ class GaussianMixture @Since("2.0.0") (
val spark = dataset.sparkSession
import spark.implicits._

val sc = spark.sparkContext
val numClusters = $(k)
val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
require(numFeatures < GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " +
s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
s" matrix is quadratic in the number of features.")

val handlePersistence = dataset.storageLevel == StorageLevel.NONE

Expand All @@ -401,12 +403,8 @@ class GaussianMixture @Since("2.0.0") (
instances.persist(StorageLevel.MEMORY_AND_DISK)
}

// Extract the number of features.
val numFeatures = MetadataUtils.getNumFeatures(dataset.schema($(featuresCol)))
.getOrElse(instances.first()._1.size)
require(numFeatures < GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " +
s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
s" matrix is quadratic in the number of features.")
val sc = spark.sparkContext
val numClusters = $(k)

instr.logPipelineStage(this)
instr.logDataset(dataset)
Expand Down
18 changes: 0 additions & 18 deletions mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,6 @@ private[spark] object MetadataUtils {
}
}

/**
* Examine a schema to identify the number of features in a vector column.
* Returns None if the number of features is not specified.
*/
def getNumFeatures(vectorSchema: StructField): Option[Int] = {
if (vectorSchema.dataType == new VectorUDT) {
val group = AttributeGroup.fromStructField(vectorSchema)
val size = group.size
if (size >= 0) {
Some(size)
} else {
None
}
} else {
None
}
}

/**
* Examine a schema to identify categorical (Binary and Nominal) features.
*
Expand Down

0 comments on commit 53efa70

Please sign in to comment.