diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 8df706fe62feb..2a7fdb08fc244 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -115,10 +115,7 @@ class KMeansModel private[ml] ( } override def transform(dataset: DataFrame): DataFrame = { - dataset.select( - dataset("*"), - callUDF(predict _, IntegerType, col($(featuresCol))).as($(predictionCol)) - ) + dataset.withColumn($(predictionCol), callUDF(predict _, IntegerType, col($(featuresCol)))) } override def transformSchema(schema: StructType): StructType = {