Implement the KMeans API for spark.ml Pipelines in Scala

yu-iskw committed Jul 1, 2015
1 parent 1ce6428 commit 63ad785
Showing 3 changed files with 276 additions and 3 deletions.
197 changes: 197 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -0,0 +1,197 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.clustering

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.util.Utils


/**
* Common params for KMeans and KMeansModel
*/
private[clustering] trait KMeansParams
extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol {
/**
* Param for the number of clusters to create.
* @group param
*/
val k = new Param[Int](this, "k", "number of clusters to create")

/** @group getParam */
def getK: Int = $(k)

/**
* Param for the number of runs of the algorithm to execute in parallel.
* @group param
*/
val runs = new Param[Int](this, "runs", "number of runs of the algorithm to execute in parallel")

/** @group getParam */
def getRuns: Int = $(runs)

/**
* Param for the distance threshold within which we consider centers to have converged.
* @group param
*/
val epsilon = new Param[Double](this, "epsilon", "distance threshold")

/** @group getParam */
def getEpsilon: Double = $(epsilon)

/**
* Param for the initialization algorithm.
* @group param
*/
val initializationMode = new Param[String](this, "initializationMode", "initialization algorithm")

/** @group getParam */
def getInitializationMode: String = $(initializationMode)

/**
* Param for the number of steps for k-means initialization mode.
* @group param
*/
val initializationSteps =
new Param[Int](this, "initializationSteps", "number of steps for k-means||")

/** @group getParam */
def getInitializationSteps: Int = $(initializationSteps)

/**
* Validates and transforms the input schema.
* @param schema input schema
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
}

/**
* :: Experimental ::
* Model fitted by KMeans.
*
* @param paramMap a parameter map for fitting.
* @param parentModel a model trained by spark.mllib.clustering.KMeans.
*/
@Experimental
class KMeansModel private[ml] (
override val uid: String,
val paramMap: ParamMap,
val parentModel: mllib.clustering.KMeansModel
) extends Model[KMeansModel] with KMeansParams {
/**
* Transforms the input dataset.
*/
override def transform(dataset: DataFrame): DataFrame = {
dataset.select(
dataset("*"),
callUDF(predict _, IntegerType, col($(featuresCol))).as($(predictionCol))
)
}

/**
* Derives the output schema from the input schema.
*/
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

def predict(features: Vector): Int = parentModel.predict(features)

def clusterCenters: Array[Vector] = parentModel.clusterCenters
}

/**
* :: Experimental ::
* KMeans API for spark.ml Pipelines.
*/
@Experimental
class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams {
setK(2)
setMaxIter(20)
setRuns(1)
setInitializationMode(MLlibKMeans.K_MEANS_PARALLEL)
setInitializationSteps(5)
setEpsilon(1e-4)
setSeed(Utils.random.nextLong())

def this() = this(Identifiable.randomUID("kmeans"))

/** @group setParam */
def setFeaturesCol(value: String): this.type = set(featuresCol, value)

/** @group setParam */
def setK(value: Int): this.type = set(k, value)

/** @group setParam */
def setInitializationMode(value: String): this.type = {
mllib.clustering.KMeans.validateInitializationMode(value)
set(initializationMode, value)
}

/** @group setParam */
def setInitializationSteps(value: Int): this.type = {
require(value > 0, "Number of initialization steps must be positive")
set(initializationSteps, value)
}

/** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)

/** @group setParam */
def setRuns(value: Int): this.type = set(runs, value)

/** @group setParam */
def setEpsilon(value: Double): this.type = set(epsilon, value)

/** @group setParam */
def setSeed(value: Long): this.type = set(seed, value)


override def fit(dataset: DataFrame): KMeansModel = {
val map = this.extractParamMap()
val rdd = dataset.select(col(map(featuresCol))).map { case Row(point: Vector) => point }

// Forward all estimator params to the underlying spark.mllib implementation.
val algo = new mllib.clustering.KMeans()
.setK(map(k))
.setInitializationMode(map(initializationMode))
.setInitializationSteps(map(initializationSteps))
.setMaxIterations(map(maxIter))
.setEpsilon(map(epsilon))
.setRuns(map(runs))
.setSeed(map(seed))
val parentModel = algo.run(rdd)
new KMeansModel(uid, map, parentModel)
}

override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
}
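
The new estimator is meant to drop straight into spark.ml Pipelines. Below is a minimal usage sketch; the toy data, the sqlContext value, and the column names are illustrative assumptions, not part of this commit:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Toy input: a DataFrame with a vector column named "features".
val training = sqlContext.createDataFrame(Seq(
  Tuple1(Vectors.dense(0.0, 0.0)),
  Tuple1(Vectors.dense(1.0, 1.0)),
  Tuple1(Vectors.dense(9.0, 8.0)),
  Tuple1(Vectors.dense(8.0, 9.0))
)).toDF("features")

// KMeans is an Estimator, so it composes with other Pipeline stages.
val kmeans = new KMeans().setK(2).setSeed(1L)
val pipeline = new Pipeline().setStages(Array(kmeans))
val model = pipeline.fit(training)
val clustered = model.transform(training) // appends a "prediction" column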

11 changes: 8 additions & 3 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -85,9 +85,7 @@ class KMeans private (
* (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
*/
def setInitializationMode(initializationMode: String): this.type = {
if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) {
throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode)
}
KMeans.validateInitializationMode(initializationMode)
this.initializationMode = initializationMode
this
}
@@ -521,6 +519,13 @@ object KMeans {
v2: VectorWithNorm): Double = {
MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
}

private[spark] def validateInitializationMode(initializationMode: String): Boolean = {
if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) {
throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode)
}
true
}
}

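With the check factored out into validateInitializationMode, the new spark.ml setter reuses it and fails fast. A small sketch of the resulting behavior; the valid mode strings "random" and "k-means||" are the existing constants KMeans.RANDOM and KMeans.K_MEANS_PARALLEL:

import org.apache.spark.ml.clustering.KMeans

// Valid modes are accepted at the time the param is set.
new KMeans().setInitializationMode("k-means||")
new KMeans().setInitializationMode("random")

// An invalid mode now throws IllegalArgumentException in the setter,
// instead of surfacing later when fit() runs.
new KMeans().setInitializationMode("bogus")
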
71 changes: 71 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.clustering

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}

private[clustering] case class TestRow(features: Vector)

object KMeansSuite {
def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
val sc = sql.sparkContext
val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
.map(v => new TestRow(v))
sql.createDataFrame(rdd)
}
}

class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {

val k = 5
@transient var dataset: DataFrame = _

override def beforeAll(): Unit = {
super.beforeAll()

dataset = KMeansSuite.generateKMeansData(sqlContext, 1000, 3, k)
}

test("default parameters") {
val kmeans = new KMeans()

assert(kmeans.getK === 2)
assert(kmeans.getFeaturesCol === "features")
assert(kmeans.getMaxIter === 20)
assert(kmeans.getRuns === 1)
assert(kmeans.getInitializationMode === MLlibKMeans.K_MEANS_PARALLEL)
assert(kmeans.getInitializationSteps === 5)
assert(kmeans.getEpsilon === 1e-4)
}

test("fit & transform") {
val kmeans = new KMeans().setK(k)
val model = kmeans.fit(dataset)
assert(model.clusterCenters.length === k)

val transformed = model.transform(dataset)
assert(transformed.columns === Array("features", "prediction"))
val clusters = transformed.select("prediction")
.map(row => row.apply(0)).distinct().collect().toSet
assert(clusters === Set(0, 1, 2, 3, 4))
}
}
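
Besides the DataFrame-level transform exercised above, the fitted KMeansModel also exposes predict and clusterCenters directly. A quick sketch reusing the suite's dataset and k; the query point is made up but matches the data's three dimensions:

import org.apache.spark.mllib.linalg.Vectors

val model = new KMeans().setK(k).fit(dataset)

// Assign a single vector to its nearest learned center.
val cluster: Int = model.predict(Vectors.dense(3.0, 3.0, 3.0))

// Inspect the learned centers themselves.
model.clusterCenters.foreach(println)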
