From 7991e15d32b4fd8a162816b273b15e691915214a Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Tue, 30 Jun 2015 09:09:48 +0900 Subject: [PATCH] Add a validation for `k` --- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 2 +- .../scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index d78620896dc15..642612c4f0191 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -40,7 +40,7 @@ private[clustering] trait KMeansParams * Set the number of clusters to create (k). Default: 2. * @group param */ - val k = new Param[Int](this, "k", "number of clusters to create") + val k = new Param[Int](this, "k", "number of clusters to create", (x: Int) => x > 1) /** @group getParam */ def getK: Int = $(k) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 6f3becb82f26f..a1a5c10b1dadf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -58,6 +58,12 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(kmeans.getEpsilon === 1e-4) } + test("parameters validation") { + intercept[IllegalArgumentException] { + new KMeans().setK(1) + } + } + test("fit & transform") { val predictionColName = "kmeans_prediction" val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName)