diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index d78620896dc15..642612c4f0191 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -40,7 +40,7 @@ private[clustering] trait KMeansParams
    * Set the number of clusters to create (k). Default: 2.
    * @group param
    */
-  val k = new Param[Int](this, "k", "number of clusters to create")
+  val k = new Param[Int](this, "k", "number of clusters to create", (x: Int) => x > 1)
 
   /** @group getParam */
   def getK: Int = $(k)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 6f3becb82f26f..a1a5c10b1dadf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -58,6 +58,12 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(kmeans.getEpsilon === 1e-4)
   }
 
+  test("parameters validation") {
+    intercept[IllegalArgumentException] {
+      new KMeans().setK(1)
+    }
+  }
+
   test("fit & transform") {
     val predictionColName = "kmeans_prediction"
     val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName)