From 695986125a1521aeffce1b5a6f6fd05b480a4384 Mon Sep 17 00:00:00 2001
From: FlytxtRnD
Date: Wed, 10 Jun 2015 12:06:25 +0530
Subject: [PATCH] Accept initial cluster centers in KMeans

---
 .../spark/mllib/clustering/KMeans.scala       | 56 +++++++++++++++++--
 .../spark/mllib/clustering/KMeansSuite.scala  | 28 ++++++++++
 2 files changed, 79 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 0f8d6a399682d..b73c52b2b197d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -156,6 +156,26 @@ class KMeans private (
     this
   }
 
+  // Initial cluster centers can be provided as a KMeansModel object rather than using the
+  // random or k-means|| initializationMode
+  private var initialModel: Option[KMeansModel] = None
+
+  /**
+   * Set the initial starting point, bypassing the random initialization or k-means||.
+   * The condition (model.k == this.k) must be met; otherwise `require` throws an
+   * IllegalArgumentException.
+   *
+   * @param model model whose cluster centers are used as the starting centers
+   */
+  def setInitialModel(model: KMeansModel): this.type = {
+    require(model.k == k, "mismatched cluster count (model.k != k)")
+    initialModel = Some(model)
+    this
+  }
+
+  /** Return the user supplied initial KMeansModel, if supplied */
+  def getInitialModel: Option[KMeansModel] = initialModel
+
   /**
    * Train a K-means model on the given set of points; `data` should be cached for high
    * performance, because this is an iterative algorithm.
@@ -193,12 +213,19 @@ class KMeans private (
 
     val initStartTime = System.nanoTime()
 
-    val centers = if (initializationMode == KMeans.RANDOM) {
-      initRandom(data)
-    } else {
-      initKMeansParallel(data)
+    val centers = initialModel match {
+      case Some(kMeansCenters) =>
+        // The supplied centers seed every run. Array.fill re-evaluates its argument for each
+        // element, so each run gets its own array of centers.
+        Array.fill(runs)(kMeansCenters.clusterCenters
+          .map(s => new VectorWithNorm(s, Vectors.norm(s, 2.0))))
+      case None =>
+        if (initializationMode == KMeans.RANDOM) {
+          initRandom(data)
+        } else {
+          initKMeansParallel(data)
+        }
     }
-
     val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
     logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
       " seconds.")
@@ -478,6 +505,25 @@ object KMeans {
     train(data, k, maxIterations, runs, K_MEANS_PARALLEL)
   }
 
+  /**
+   * Trains a k-means model starting from the cluster centers of an existing model.
+   *
+   * @param data training points stored as `RDD[Vector]`
+   * @param k number of clusters; must equal `initialModel.k` or setInitialModel throws
+   * @param maxIterations max number of iterations
+   * @param initialModel an existing set of cluster centers.
+   */
+  def train(
+      data: RDD[Vector],
+      k: Int,
+      maxIterations: Int,
+      initialModel: KMeansModel): KMeansModel = {
+    new KMeans().setK(k)
+      .setMaxIterations(maxIterations)
+      .setInitialModel(initialModel)
+      .run(data)
+  }
+
   /**
    * Returns the index of the closest center to the given point, as well as the squared distance.
*/ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 0dbbd7127444f..9678687915186 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -278,6 +278,34 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("Initialize using given cluster centers") { + val points = Seq( + Vectors.dense(0.0, 0.0), + Vectors.dense(0.0, 0.1), + Vectors.dense(0.1, 0.0), + Vectors.dense(9.0, 0.0), + Vectors.dense(9.0, 0.2), + Vectors.dense(9.2, 0.0) + ) + val rdd = sc.parallelize(points, 3) + val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + model.save(sc, path) + val loadedModel = KMeansModel.load(sc, path) + + val newModel = KMeans.train(rdd, k = 2, maxIterations = 2, initialModel = loadedModel) + val predicts = newModel.predict(rdd).collect() + + assert(predicts(0) === predicts(1)) + assert(predicts(0) === predicts(2)) + assert(predicts(3) === predicts(4)) + assert(predicts(3) === predicts(5)) + assert(predicts(0) != predicts(3)) + } + } object KMeansSuite extends SparkFunSuite {