Skip to content

Commit

Permalink
Accept initial cluster centers in KMeans
Browse files Browse the repository at this point in the history
  • Loading branch information
FlytxtRnD committed Jun 10, 2015
1 parent 778f3ca commit 6959861
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,26 @@ class KMeans private (
this
}

// Initial cluster centers can be provided as a KMeansModel object rather than using the
// random or k-means|| initializationMode
private var initialModel: Option[KMeansModel] = None

/** Set the initial starting point, bypassing the random initialization or k-means||
* The condition (model.k == this.k) must be met; failure will result in an
* IllegalArgumentException.
*/
def setInitialModel(model: KMeansModel): this.type = {
if (model.k == k) {
initialModel = Some(model)
} else {
throw new IllegalArgumentException("mismatched cluster count (model.k != k)")
}
this
}

/** Return the user supplied initial KMeansModel, if supplied */
def getInitialModel: Option[KMeansModel] = initialModel

/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
Expand Down Expand Up @@ -193,12 +213,19 @@ class KMeans private (

val initStartTime = System.nanoTime()

val centers = if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
val centers = initialModel match {
case Some(kMeansCenters) => {
Array.tabulate(runs)(r => kMeansCenters.clusterCenters
.map(s => new VectorWithNorm(s, Vectors.norm(s, 2.0))))
}
case None => {
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
}
}
}

val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
" seconds.")
Expand Down Expand Up @@ -478,6 +505,25 @@ object KMeans {
train(data, k, maxIterations, runs, K_MEANS_PARALLEL)
}

/**
* Trains a k-means model using the given set of parameters and initial cluster centers
*
* @param data training points stored as `RDD[Vector]`
* @param k number of clusters
* @param maxIterations max number of iterations
* @param initialModel an existing set of cluster centers.
*/
def train(
data: RDD[Vector],
k: Int,
maxIterations: Int,
initialModel: KMeansModel): KMeansModel = {
new KMeans().setK(k)
.setMaxIterations(maxIterations)
.setInitialModel(initialModel)
.run(data)
}

/**
* Returns the index of the closest center to the given point, as well as the squared distance.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,34 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}

test("Initialize using given cluster centers") {
val points = Seq(
Vectors.dense(0.0, 0.0),
Vectors.dense(0.0, 0.1),
Vectors.dense(0.1, 0.0),
Vectors.dense(9.0, 0.0),
Vectors.dense(9.0, 0.2),
Vectors.dense(9.2, 0.0)
)
val rdd = sc.parallelize(points, 3)
val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1)

val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
model.save(sc, path)
val loadedModel = KMeansModel.load(sc, path)

val newModel = KMeans.train(rdd, k = 2, maxIterations = 2, initialModel = loadedModel)
val predicts = newModel.predict(rdd).collect()

assert(predicts(0) === predicts(1))
assert(predicts(0) === predicts(2))
assert(predicts(3) === predicts(4))
assert(predicts(3) === predicts(5))
assert(predicts(0) != predicts(3))
}

}

object KMeansSuite extends SparkFunSuite {
Expand Down

0 comments on commit 6959861

Please sign in to comment.