[SPARK-29967][ML][PYTHON] KMeans support instance weighting #26739
Changes from 4 commits
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.util.Instrumentation
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.axpy
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -209,11 +209,14 @@ class KMeans private (
    */
   @Since("0.8.0")
   def run(data: RDD[Vector]): KMeansModel = {
-    run(data, None)
+    val instances: RDD[(Vector, Double)] = data.map {
+      case (point) => (point, 1.0)
+    }
+    runWithWeight(instances, None)
   }
 
-  private[spark] def run(
-      data: RDD[Vector],
+  private[spark] def runWithWeight(
+      data: RDD[(Vector, Double)],
       instr: Option[Instrumentation]): KMeansModel = {
 
     if (data.getStorageLevel == StorageLevel.NONE) {
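For context (not part of the diff): the public `run` now just assigns every point a weight of 1.0 and delegates to the weighted path. A minimal sketch of what weighted input looks like, assuming a live `SparkContext` named `sc`; the weights here are hypothetical and would normally come from the ML layer's weight column:

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Three points; the last one should count three times as much.
val points: RDD[Vector] = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0),
  Vectors.dense(1.0, 1.0),
  Vectors.dense(9.0, 8.0)
))

// What run() now does internally: default every weight to 1.0.
val unweighted: RDD[(Vector, Double)] = points.map(point => (point, 1.0))

// What a weighted caller would build instead (weights chosen for illustration).
val weighted: RDD[(Vector, Double)] = points.zipWithIndex().map {
  case (p, i) => (p, if (i == 2L) 3.0 else 1.0)
}
```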
@@ -222,12 +225,15 @@ class KMeans private (
     }
 
     // Compute squared norms and cache them.
-    val norms = data.map(Vectors.norm(_, 2.0))
-    val zippedData = data.zip(norms).map { case (v, norm) =>
-      new VectorWithNorm(v, norm)
+    val norms = data.map { case (v, _) =>
+      Vectors.norm(v, 2.0)
     }
+
+    val zippedData = data.zip(norms).map { case ((v, w), norm) =>
+      (new VectorWithNorm(v, norm), w)
+    }
     zippedData.persist()
-    val model = runAlgorithm(zippedData, instr)
+    val model = runAlgorithmWithWeight(zippedData, instr)
     zippedData.unpersist()
 
     // Warn at the end of the run as well, for increased visibility.
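As a local (non-Spark) illustration of the zipping above, the weight simply rides along while each vector is paired with its precomputed L2 norm; the internal `VectorWithNorm` wrapper is replaced by a plain tuple here since it is a package-private helper:

```scala
import org.apache.spark.mllib.linalg.Vectors

// Stand-in for the RDD[(Vector, Double)] of (point, weight) pairs.
val data = Seq(
  (Vectors.dense(3.0, 4.0), 1.0),
  (Vectors.dense(0.0, 2.0), 2.5)
)

// Same shape as `norms` in the diff: one L2 norm per point, weights ignored.
val norms = data.map { case (v, _) => Vectors.norm(v, 2.0) }   // Seq(5.0, 2.0)

// Same shape as `zippedData`: ((vector, norm), weight), with the tuple
// (vector, norm) standing in for VectorWithNorm.
val zipped = data.zip(norms).map { case ((v, w), norm) => ((v, norm), w) }
```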
@@ -241,8 +247,8 @@ class KMeans private (
   /**
    * Implementation of K-Means algorithm.
    */
-  private def runAlgorithm(
-      data: RDD[VectorWithNorm],
+  private def runAlgorithmWithWeight(
+      data: RDD[(VectorWithNorm, Double)],
       instr: Option[Instrumentation]): KMeansModel = {
 
     val sc = data.sparkContext
@@ -251,14 +257,17 @@ class KMeans private (
 
     val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)
 
+    val dataVectorWithNorm = data.map(d => d._1)
+    val weights = data.map(d => d._2)
+
     val centers = initialModel match {
       case Some(kMeansCenters) =>
         kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
       case None =>
         if (initializationMode == KMeans.RANDOM) {
-          initRandom(data)
+          initRandom(dataVectorWithNorm)
         } else {
-          initKMeansParallel(data, distanceMeasureInstance)
+          initKMeansParallel(dataVectorWithNorm, distanceMeasureInstance)
         }
     }
     val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
@@ -278,32 +287,38 @@ class KMeans private (
       val bcCenters = sc.broadcast(centers)
 
       // Find the new centers
-      val collected = data.mapPartitions { points =>
+      val collected = data.mapPartitions { pointsAndWeights =>
         val thisCenters = bcCenters.value
         val dims = thisCenters.head.vector.size
 
         val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
-        val counts = Array.fill(thisCenters.length)(0L)
 
-        points.foreach { point =>
+        // clusterWeightSum is needed to calculate cluster center
+        // cluster center =
+        //     sample1 * weight1/clusterWeightSum + sample2 * weight2/clusterWeightSum + ...
+        val clusterWeightSum = Array.ofDim[Double](thisCenters.length)
+
+        pointsAndWeights.foreach { case (point, weight) =>
           val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
-          costAccum.add(cost)
-          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
-          counts(bestCenter) += 1
+          costAccum.add(cost * weight)
+          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter), weight)
+          clusterWeightSum(bestCenter) += weight
         }
 
-        counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
-      }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
+        clusterWeightSum.indices.filter(clusterWeightSum(_) > 0)
+          .map(j => (j, (sums(j), clusterWeightSum(j)))).iterator
+      }.reduceByKey { case ((sum1, clusterWeightSum1), (sum2, clusterWeightSum2)) =>
         axpy(1.0, sum2, sum1)
-        (sum1, count1 + count2)
+        (sum1, clusterWeightSum1 + clusterWeightSum2)
       }.collectAsMap()
 
       if (iteration == 0) {
-        instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
+        instr.foreach(_.logNumExamples(data.count()))
Review comments on this line:
- Nit, what about using a …
- Updated. Thanks!
+        instr.foreach(_.logSumOfWeights(collected.values.map(_._2).sum))
       }
 
Review comments on this line:
- Don't have …
- Can you just log the sum of weights? It keeps the same info in the unweighted case, and it's still sort of meaningful as 'number of examples' in the weighted case.
- 1, I guess we need to add a new var …
- I guess maybe leave the code this way for now and open a separate PR later on to add method …
- I am OK to add new …
- Updated. Thanks!
-      val newCenters = collected.mapValues { case (sum, count) =>
-        distanceMeasureInstance.centroid(sum, count)
+      val newCenters = collected.mapValues { case (sum, weightSum) =>
+        distanceMeasureInstance.centroid(sum, weightSum)
       }
 
       bcCenters.destroy()
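To make the `clusterWeightSum` comment in the hunk above concrete: each new center is the weighted mean sum(w_i * x_i) / sum(w_i), which collapses to the ordinary per-count mean when every weight is 1.0. A tiny worked example in plain Scala (not part of the diff; the fold mirrors what `updateClusterSum` is expected to do with the weight):

```scala
// Toy cluster: two 2-D points assigned to the same center, with weights 3.0 and 1.0.
val pts = Seq(
  (Array(1.0, 2.0), 3.0),
  (Array(5.0, 6.0), 1.0)
)

// Weighted running sum per dimension, like sums(bestCenter) in the diff.
val sum = pts.foldLeft(Array(0.0, 0.0)) { case (acc, (p, w)) =>
  acc.zip(p).map { case (a, x) => a + w * x }
}                                                        // Array(8.0, 12.0)

// Total weight of the cluster, like clusterWeightSum(bestCenter).
val weightSum = pts.map(_._2).sum                        // 4.0

// centroid(sum, weightSum): scale the weighted sum by 1 / total weight.
val center = sum.map(_ / weightSum)                      // Array(2.0, 3.0)
```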
Review comments:
- Is the above def centroid(sum: Vector, count: Long): VectorWithNorm still needed?
- Yes. It is still used by BisectingKMeans.
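For readers following that thread, the count-based and weight-based centroids are conceptually the same operation: divide the accumulated sum by the total count or the total weight. A hedged sketch, with names and return types chosen for illustration (the real `DistanceMeasure.centroid` overloads may differ, e.g. by returning `VectorWithNorm`):

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Illustrative only: the weight-based centroid scales the sum by 1 / totalWeight.
def centroidByWeight(sum: Vector, weightSum: Double): Vector =
  Vectors.dense(sum.toArray.map(_ / weightSum))

// The existing count-based form (still used by BisectingKMeans) is the
// special case where every instance contributed a weight of exactly 1.0.
def centroidByCount(sum: Vector, count: Long): Vector =
  centroidByWeight(sum, count.toDouble)
```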