diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 7340448b7d731..a293525854e0b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -58,6 +58,29 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(kmeans.getEpsilon === 1e-4) } + test("set parameters") { + val kmeans = new KMeans() + .setK(9) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setMaxIter(33) + .setRuns(7) + .setInitializationMode(MLlibKMeans.RANDOM) + .setInitializationSteps(3) + .setSeed(123) + .setEpsilon(1e-3) + + assert(kmeans.getK === 9) + assert(kmeans.getFeaturesCol === "test_feature") + assert(kmeans.getPredictionCol === "test_prediction") + assert(kmeans.getMaxIter === 33) + assert(kmeans.getRuns === 7) + assert(kmeans.getInitializationMode === MLlibKMeans.RANDOM) + assert(kmeans.getInitializationSteps === 3) + assert(kmeans.getSeed === 123) + assert(kmeans.getEpsilon === 1e-3) + } + test("parameters validation") { intercept[IllegalArgumentException] { new KMeans().setK(1)