diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 282aec09235fd..15b78b221394d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -18,49 +18,92 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param._
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
+/**
+ * A one-hot encoder that maps a column of label indices to a column of binary vectors, with
+ * at most a single one-value. By default, the binary vector has an element for each category, so
+ * with 5 categories, an input value of 2.0 would map to an output vector of
+ * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so
+ * the output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
+ * of 0.0 would map to a vector of all zeros. Omitting the first category enables the vector
+ * columns to be independent.
+ */
 @AlphaComponent
-class OneHotEncoder(labelNames: Seq[String], includeFirst: Boolean = true) extends Transformer
-  with HasInputCol {
+class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
+  with HasInputCol with HasOutputCol {
 
-  /** @group setParam */
-  def setInputCol(value: String): this.type = set(inputCol, value)
+  /**
+   * Whether to include a component in the encoded vectors for the first category; defaults to true.
+   * @group param
+   */
+  final val includeFirst: Param[Boolean] =
+    new Param[Boolean](this, "includeFirst", "include first category")
+  setDefault(includeFirst -> true)
 
-  private def outputColName(index: Int): String = {
-    s"${get(inputCol)}_${labelNames(index)}"
-  }
+  /**
+   * The names of the categories. Used to identify them in the attributes of the output column.
+   * This is a required parameter.
+   * @group param
+   */
+  final val labelNames: Param[Array[String]] =
+    new Param[Array[String]](this, "labelNames", "categorical label names")
 
-  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
-    val map = this.paramMap ++ paramMap
+  /** @group setParam */
+  def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
 
-    val startIndex = if (includeFirst) 0 else 1
-    val cols = (startIndex until labelNames.length).map { index =>
-      val colEncoder = udf { label: Double => if (index == label) 1.0 else 0.0 }
-      colEncoder(dataset(map(inputCol))).as(outputColName(index))
-    }
+  /** @group setParam */
+  def setLabelNames(value: Array[String]): this.type = set(labelNames, value)
 
-    dataset.select(Array(col("*")) ++ cols: _*)
-  }
+  /** @group setParam */
+  override def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  override def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
-    checkInputColumn(schema, map(inputCol), StringType)
+    val map = extractParamMap(paramMap)
+    SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType)
     val inputFields = schema.fields
-    val startIndex = if (includeFirst) 0 else 1
-    val fields = (startIndex until labelNames.length).map { index =>
-      val colName = outputColName(index)
-      require(inputFields.forall(_.name != colName),
-        s"Output column $colName already exists.")
-      NominalAttribute.defaultAttr.withName(colName).toStructField()
-    }
-
-    val outputFields = inputFields ++ fields
+    val outputColName = map(outputCol)
+    require(inputFields.forall(_.name != outputColName),
+      s"Output column $outputColName already exists.")
+    require(map.contains(labelNames), "OneHotEncoder missing category names")
+    val categories = map(labelNames)
+    val attrValues = (if (map(includeFirst)) categories else categories.drop(1)).toArray
+    val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
+    val outputFields = inputFields :+ attr.toStructField()
     StructType(outputFields)
   }
+
+  protected def createTransformFunc(paramMap: ParamMap): (Double) => Vector = {
+    val map = extractParamMap(paramMap)
+    val first = map(includeFirst)
+    val vecLen = if (first) map(labelNames).length else map(labelNames).length - 1
+    val oneValue = Array(1.0)
+    val emptyValues = Array[Double]()
+    val emptyIndices = Array[Int]()
+    label: Double => {
+      val values = if (first || label != 0.0) oneValue else emptyValues
+      val indices = if (first) {
+        Array(label.toInt)
+      } else if (label != 0.0) {
+        Array(label.toInt - 1)
+      } else {
+        emptyIndices
+      }
+      Vectors.sparse(vecLen, indices, values)
+    }
+  }
+
+  /**
+   * Returns the data type of the output column.
+   */
+  protected def outputDataType: DataType = new VectorUDT
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 3dd3f14dddf6e..6b76843a44612 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.spark.ml.attribute.{NominalAttribute, Attribute}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
 import org.scalatest.FunSuite
 
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.ml.attribute.{NominalAttribute, Attribute}
 
 class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
   private var sqlContext: SQLContext = _
@@ -31,7 +33,7 @@
     sqlContext = new SQLContext(sc)
   }
 
-  test("OneHotEncoder") {
+  def stringIndexed(): (DataFrame, NominalAttribute) = {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
     val df = sqlContext.createDataFrame(data).toDF("id", "label")
     val indexer = new StringIndexer()
@@ -41,14 +43,20 @@
     val transformed = indexer.transform(df)
     val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
       .asInstanceOf[NominalAttribute]
-    assert(attr.values.get === Array("a", "c", "b"))
+    (transformed, attr)
+  }
 
-    val encoder = new OneHotEncoder(attr.values.get)
+  test("OneHotEncoder includeFirst = true") {
+    val (transformed, attr) = stringIndexed()
+    val encoder = new OneHotEncoder()
+      .setLabelNames(attr.values.get)
       .setInputCol("labelIndex")
+      .setOutputCol("labelVec")
     val encoded = encoder.transform(transformed)
 
-    val output = encoded.select("id", "labelIndex_a", "labelIndex_c", "labelIndex_b").map { r =>
-      (r.getInt(0), r.getDouble(1), r.getDouble(2), r.getDouble(3))
+    val output = encoded.select("id", "labelVec").map { r =>
+      val vec = r.get(1).asInstanceOf[Vector]
+      (r.getInt(0), vec(0), vec(1), vec(2))
     }.collect().toSet
     // a -> 0, b -> 2, c -> 1
     val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
@@ -56,4 +64,23 @@
     assert(output === expected)
   }
 
+  test("OneHotEncoder includeFirst = false") {
+    val (transformed, attr) = stringIndexed()
+    val encoder = new OneHotEncoder()
+      .setIncludeFirst(false)
+      .setLabelNames(attr.values.get)
+      .setInputCol("labelIndex")
+      .setOutputCol("labelVec")
+    val encoded = encoder.transform(transformed)
+
+    val output = encoded.select("id", "labelVec").map { r =>
+      val vec = r.get(1).asInstanceOf[Vector]
+      (r.getInt(0), vec(0), vec(1))
+    }.collect().toSet
+    // a -> 0, b -> 2, c -> 1
+    val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0),
+      (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0))
+    assert(output === expected)
+  }
+
 }
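
For reviewers, here is a minimal end-to-end sketch of the new single-output-column API, assembled from the test suite above. It assumes a `SparkContext` named `sc` and a `SQLContext` named `sqlContext` are already in scope, and uses only the methods introduced in this patch (`setLabelNames`, `setIncludeFirst`, `setInputCol`, `setOutputCol`):

```scala
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

val df = sqlContext.createDataFrame(
  sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)).toDF("id", "label")

// Index the string labels first: the encoder consumes a Double-typed index column.
val indexed = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("labelIndex")
  .fit(df)
  .transform(df)

// Recover the category names from the indexed column's nominal attribute;
// they become the required labelNames parameter of the encoder.
val attr = Attribute.fromStructField(indexed.schema("labelIndex"))
  .asInstanceOf[NominalAttribute]

// Encode each index as a sparse vector in a single output column.
val encoded = new OneHotEncoder()
  .setLabelNames(attr.values.get)
  .setIncludeFirst(false) // drop the first category so the components are independent
  .setInputCol("labelIndex")
  .setOutputCol("labelVec")
  .transform(indexed)
```

Note the design change this illustrates: instead of fanning out one Double column per category with generated names like `labelIndex_a`, the encoder now writes a single `Vector` column whose categories are carried as `NominalAttribute` values in the output schema.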