init pr
zhengruifeng committed Dec 2, 2019
1 parent 03ac1b7 commit 4667c45
Showing 3 changed files with 107 additions and 42 deletions.
mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -33,17 +33,18 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel


/**
* Common params for GaussianMixture and GaussianMixtureModel
*/
private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter with HasFeaturesCol
with HasSeed with HasPredictionCol with HasProbabilityCol with HasTol with HasAggregationDepth {
with HasSeed with HasPredictionCol with HasWeightCol with HasProbabilityCol with HasTol
with HasAggregationDepth {

/**
* Number of independent Gaussians in the mixture model. Must be greater than 1. Default: 2.
@@ -333,6 +334,10 @@ class GaussianMixture @Since("2.0.0") (
@Since("2.0.0")
def setProbabilityCol(value: String): this.type = set(probabilityCol, value)

/** @group setParam */
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

/** @group setParam */
@Since("2.0.0")
def setK(value: Int): this.type = set(k, value)
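
A minimal usage sketch of the new weight-column support (not part of the diff; the column names, k, and seed are illustrative):

import org.apache.spark.ml.clustering.GaussianMixture

// df is assumed to be a DataFrame with a "features" vector column and a "weight" double column.
val gmm = new GaussianMixture()
  .setK(3)
  .setFeaturesCol("features")
  .setWeightCol("weight")   // setter introduced by this change (3.0.0)
  .setSeed(1L)
val model = gmm.fit(df)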
@@ -362,29 +367,39 @@ class GaussianMixture @Since("2.0.0") (
override def fit(dataset: Dataset[_]): GaussianMixtureModel = instrumented { instr =>
transformSchema(dataset.schema, logging = true)

val sc = dataset.sparkSession.sparkContext
val spark = dataset.sparkSession
import spark.implicits._

val sc = spark.sparkContext
val numClusters = $(k)

val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val instances = dataset
.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map {
case Row(features: Vector) => features

val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
col($(weightCol)).cast(DoubleType)
} else {
lit(1.0)
}

val instances = dataset.select(DatasetUtils.columnToVector(dataset, $(featuresCol)), w)
.as[(Vector, Double)]
.rdd

if (handlePersistence) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
}

// Extract the number of features.
val numFeatures = instances.first().size
val numFeatures = MetadataUtils.getNumFeatures(dataset.schema($(featuresCol)))
.getOrElse(instances.first()._1.size)
require(numFeatures < GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " +
s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
s" matrix is quadratic in the number of features.")

instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol,
aggregationDepth)
instr.logParams(this, featuresCol, predictionCol, probabilityCol, weightCol, k, maxIter,
seed, tol, aggregationDepth)
instr.logNumFeatures(numFeatures)

val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians(
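
As a side note, when weightCol is unset or empty the fallback above assigns every row a weight of 1.0, so an unweighted fit is equivalent to supplying a constant weight column. A sketch of that equivalence (not part of the diff; df and the column name "w" are illustrative):

import org.apache.spark.sql.functions.lit

val unweighted = new GaussianMixture().setK(2).setSeed(7L).fit(df)
val weighted = new GaussianMixture().setK(2).setSeed(7L)
  .setWeightCol("w")
  .fit(df.withColumn("w", lit(1.0)))
// Both runs aggregate identical (features, 1.0) pairs, which is what the new
// suite test below checks for several constant weight values.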
@@ -399,18 +414,14 @@ class GaussianMixture @Since("2.0.0") (
var iter = 0
while (iter < $(maxIter) && math.abs(logLikelihood - logLikelihoodPrev) > $(tol)) {

val bcWeights = instances.sparkContext.broadcast(weights)
val bcGaussians = instances.sparkContext.broadcast(gaussians)
val bcWeights = sc.broadcast(weights)
val bcGaussians = sc.broadcast(gaussians)

// aggregate the cluster contribution for all sample points
val sums = instances.treeAggregate(
new ExpectationAggregator(numFeatures, bcWeights, bcGaussians))(
seqOp = (c, v) => (c, v) match {
case (aggregator, instance) => aggregator.add(instance)
},
combOp = (c1, c2) => (c1, c2) match {
case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
},
seqOp = (c: ExpectationAggregator, v: (Vector, Double)) => c.add(v._1, v._2),
combOp = (c1: ExpectationAggregator, c2: ExpectationAggregator) => c1.merge(c2),
depth = $(aggregationDepth))

bcWeights.destroy()
@@ -425,7 +436,7 @@ class GaussianMixture @Since("2.0.0") (
Create new distributions based on the partial assignments
(often referred to as the "M" step in literature)
*/
val sumWeights = sums.weights.sum
val sumWeights = sums.weightSum

if (shouldDistributeGaussians) {
val numPartitions = math.min(numClusters, 1024)
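
In compact form (a summary for review, not text from the patch), the weighted E/M updates these changes implement are, with responsibilities \gamma_{ik}:

\gamma_{ik} = \frac{\pi_k\,\mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_j \pi_j\,\mathcal{N}(x_i \mid \mu_j, \Sigma_j)}, \qquad
N_k = \sum_i w_i\,\gamma_{ik}, \qquad \pi_k = \frac{N_k}{\sum_i w_i},

\mu_k = \frac{1}{N_k}\sum_i w_i\,\gamma_{ik}\,x_i, \qquad
\Sigma_k = \frac{1}{N_k}\sum_i w_i\,\gamma_{ik}\,(x_i - \mu_k)(x_i - \mu_k)^{\top},

and convergence tracks the weighted log-likelihood \sum_i w_i \log \sum_k \pi_k\,\mathcal{N}(x_i \mid \mu_k, \Sigma_k); with all w_i = 1 these reduce to the previous unweighted updates.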
@@ -488,21 +499,30 @@ class GaussianMixture @Since("2.0.0") (
* we only save the upper triangular part as a dense vector (column major).
*/
private def initRandom(
instances: RDD[Vector],
instances: RDD[(Vector, Double)],
numClusters: Int,
numFeatures: Int): (Array[Double], Array[(DenseVector, DenseVector)]) = {
val samples = instances.takeSample(withReplacement = true, numClusters * numSamples, $(seed))
val weights: Array[Double] = Array.fill(numClusters)(1.0 / numClusters)
val gaussians: Array[(DenseVector, DenseVector)] = Array.tabulate(numClusters) { i =>
val slice = samples.view(i * numSamples, (i + 1) * numSamples)
val (samples, sampleWeights) = instances
.takeSample(withReplacement = true, numClusters * numSamples, $(seed))
.unzip

val weights = new Array[Double](numClusters)
val weightSum = sampleWeights.sum

val gaussians = Array.tabulate(numClusters) { i =>
val sampleSlice = samples.view(i * numSamples, (i + 1) * numSamples)
val weightSlice = sampleWeights.view(i * numSamples, (i + 1) * numSamples)
val localWeightSum = weightSlice.sum
weights(i) = localWeightSum / weightSum

val mean = {
val v = new DenseVector(new Array[Double](numFeatures))
var i = 0
while (i < numSamples) {
BLAS.axpy(1.0, slice(i), v)
i += 1
var j = 0
while (j < numSamples) {
BLAS.axpy(weightSlice(j), sampleSlice(j), v)
j += 1
}
BLAS.scal(1.0 / numSamples, v)
BLAS.scal(1.0 / localWeightSum, v)
v
}
/*
@@ -514,9 +534,13 @@ class GaussianMixture @Since("2.0.0") (
*/
val cov = {
val ss = new DenseVector(new Array[Double](numFeatures)).asBreeze
slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) ^:^ 2.0)
var j = 0
while (j < numSamples) {
ss += ((sampleSlice(j).asBreeze - mean.asBreeze) ^:^ 2.0) * weightSlice(j)
j += 1
}
val diagVec = Vectors.fromBreeze(ss)
BLAS.scal(1.0 / numSamples, diagVec)
BLAS.scal(1.0 / localWeightSum, diagVec)
val covVec = new DenseVector(Array.fill[Double](
numFeatures * (numFeatures + 1) / 2)(0.0))
diagVec.toArray.zipWithIndex.foreach { case (v: Double, i: Int) =>
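
A small illustration (not from the patch) of the packed layout referenced above, where only the upper triangle of each covariance matrix is stored column-major, the same layout BLAS.spr updates:

// For a d x d covariance, only the upper triangle is kept, as a vector of length
// d * (d + 1) / 2; for d = 3 the order is (0,0), (0,1), (1,1), (0,2), (1,2), (2,2).
// packedUpperIndex is a hypothetical helper, not part of Spark.
def packedUpperIndex(i: Int, j: Int): Int = {
  require(i <= j, "only the upper triangle is stored")
  j * (j + 1) / 2 + i
}
// e.g. the diagonal entry (i, i) sits at packed index i * (i + 3) / 2.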
@@ -621,12 +645,13 @@ private class ExpectationAggregator(

private val k: Int = bcWeights.value.length
private var totalCnt: Long = 0L
private var totalWeightSum: Double = 0.0
private var newLogLikelihood: Double = 0.0
private lazy val newWeights: Array[Double] = new Array[Double](k)
private lazy val newMeans: Array[DenseVector] = Array.fill(k)(
new DenseVector(Array.fill[Double](numFeatures)(0.0)))
private lazy val newCovs: Array[DenseVector] = Array.fill(k)(
new DenseVector(Array.fill[Double](numFeatures * (numFeatures + 1) / 2)(0.0)))
new DenseVector(new Array[Double](numFeatures * (numFeatures + 1) / 2)))

@transient private lazy val oldGaussians = {
bcGaussians.value.map { case (mean, covVec) =>
@@ -637,6 +662,8 @@

def count: Long = totalCnt

def weightSum: Double = totalWeightSum

def logLikelihood: Double = newLogLikelihood

def weights: Array[Double] = newWeights
@@ -650,9 +677,10 @@
* means and covariances for each distribution, and update the log likelihood.
*
* @param instance The instance of data point to be added.
* @param weight The instance weight.
* @return This ExpectationAggregator object.
*/
def add(instance: Vector): this.type = {
def add(instance: Vector, weight: Double): this.type = {
val localWeights = bcWeights.value
val localOldGaussians = oldGaussians

@@ -666,20 +694,21 @@
i += 1
}

newLogLikelihood += math.log(probSum)
newLogLikelihood += math.log(probSum) * weight
val localNewWeights = newWeights
val localNewMeans = newMeans
val localNewCovs = newCovs
i = 0
while (i < k) {
prob(i) /= probSum
localNewWeights(i) += prob(i)
BLAS.axpy(prob(i), instance, localNewMeans(i))
BLAS.spr(prob(i), instance, localNewCovs(i))
val w = prob(i) / probSum * weight
localNewWeights(i) += w
BLAS.axpy(w, instance, localNewMeans(i))
BLAS.spr(w, instance, localNewCovs(i))
i += 1
}

totalCnt += 1
totalWeightSum += weight
this
}
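
A worked example of add with hypothetical numbers: for k = 2, instance weight 3.0, and unnormalized densities pi_0 * N_0(x) = 0.2 and pi_1 * N_1(x) = 0.6:

// probSum = 0.2 + 0.6 = 0.8
// responsibilities  = (0.25, 0.75)
// weighted updates  = (0.25 * 3.0, 0.75 * 3.0) = (0.75, 2.25)   // added to newWeights, means, covs
// newLogLikelihood += 3.0 * math.log(0.8)
// totalWeightSum   += 3.0
// With weight = 1.0 this reduces to the previous unweighted behaviour.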

@@ -694,6 +723,7 @@
def merge(other: ExpectationAggregator): this.type = {
if (other.count != 0) {
totalCnt += other.totalCnt
totalWeightSum += other.totalWeightSum

val localThisNewWeights = this.newWeights
val localOtherNewWeights = other.newWeights
18 changes: 18 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -41,6 +41,24 @@ private[spark] object MetadataUtils {
}
}

/**
* Examine a schema to identify the number of features in a vector column.
* Returns None if the number of features is not specified.
*/
def getNumFeatures(vectorSchema: StructField): Option[Int] = {
if (vectorSchema.dataType == new VectorUDT) {
val group = AttributeGroup.fromStructField(vectorSchema)
val size = group.size
if (size >= 0) {
Some(size)
} else {
None
}
} else {
None
}
}
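
A hedged usage sketch of the new helper (not part of the patch), mirroring how fit() uses it above; the helper is private[spark], and df / "features" are illustrative names:

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util.MetadataUtils

// Some(size) only when the vector column's ML attribute metadata records a size;
// otherwise fall back to inspecting a single row, as fit() does with instances.first().
val numFeatures = MetadataUtils.getNumFeatures(df.schema("features"))
  .getOrElse(df.select("features").head.getAs[Vector](0).size)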

/**
* Examine a schema to identify categorical (Binary and Nominal) features.
*
mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.ml.stat.distribution.MultivariateGaussian
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._


class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest {
@@ -267,6 +268,18 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest {
assert(trueLikelihood ~== floatLikelihood absTol 1e-6)
}

test("GMM support instance weighting") {
val gm1 = new GaussianMixture().setK(k).setMaxIter(20).setSeed(seed)
val gm2 = new GaussianMixture().setK(k).setMaxIter(20).setSeed(seed).setWeightCol("weight")

Seq(1.0, 10.0, 100.0).foreach { w =>
val gmm1 = gm1.fit(dataset)
val ds2 = dataset.select(col("features"), lit(w).as("weight"))
val gmm2 = gm2.fit(ds2)
modelEquals(gmm1, gmm2)
}
}
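
The comparison against constant weights works because each EM iteration's parameter updates are invariant to rescaling all weights by a constant c > 0 (an editorial observation, not text from the patch); with responsibilities \gamma_{ik}:

\pi_k = \frac{\sum_i c\,w_i\,\gamma_{ik}}{\sum_i c\,w_i} = \frac{\sum_i w_i\,\gamma_{ik}}{\sum_i w_i}, \qquad
\mu_k = \frac{\sum_i c\,w_i\,\gamma_{ik}\,x_i}{\sum_i c\,w_i\,\gamma_{ik}} = \frac{\sum_i w_i\,\gamma_{ik}\,x_i}{\sum_i w_i\,\gamma_{ik}},

and likewise for the covariances, so a fit with every weight set to w should match the unweighted fit up to component ordering and numerical tolerance, which is what modelEquals checks (sorting components by weight before comparing).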

test("prediction on single instance") {
val gmm = new GaussianMixture().setSeed(123L)
val model = gmm.fit(dataset)
@@ -319,10 +332,14 @@ object GaussianMixtureSuite extends SparkFunSuite {

def modelEquals(m1: GaussianMixtureModel, m2: GaussianMixtureModel): Unit = {
assert(m1.weights.length === m2.weights.length)
val s1 = m1.weights.zip(m1.gaussians).sortBy(_._1)
val s2 = m2.weights.zip(m2.gaussians).sortBy(_._1)
for (i <- m1.weights.indices) {
assert(m1.weights(i) ~== m2.weights(i) absTol 1E-3)
assert(m1.gaussians(i).mean ~== m2.gaussians(i).mean absTol 1E-3)
assert(m1.gaussians(i).cov ~== m2.gaussians(i).cov absTol 1E-3)
val (w1, g1) = s1(i)
val (w2, g2) = s2(i)
assert(w1 ~== w2 absTol 1E-3)
assert(g1.mean ~== g2.mean absTol 1E-3)
assert(g1.cov ~== g2.cov absTol 1E-3)
}
}
}
