Skip to content

Commit

Permalink
make StringIndexerModel silent if input column does not exist
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Jun 2, 2015
1 parent ad06727 commit e112394
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ class StringIndexerModel private[ml] (
def setOutputCol(value: String): this.type = set(outputCol, value)

override def transform(dataset: DataFrame): DataFrame = {
if (!dataset.schema.fieldNames.contains($(inputCol))) {
logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
"Skip StringIndexerModel.")
return dataset
}

val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
Expand All @@ -128,6 +134,11 @@ class StringIndexerModel private[ml] (
}

/**
 * Computes the schema produced by this model for the given input schema.
 *
 * When the configured input column is present, the schema is delegated to
 * `validateAndTransformSchema`. When it is absent, the schema is returned
 * unchanged so that the model acts as a no-op stage.
 *
 * @param schema schema of the incoming dataset
 * @return the transformed schema, or `schema` itself if the input column is missing
 */
override def transformSchema(schema: StructType): StructType = {
  // Only validate when the input column is actually there. An unconditional
  // validateAndTransformSchema call ahead of this check would have its result
  // discarded and (presumably, since it validates the input column - confirm
  // against its definition) fail on exactly the schemas the else-branch is
  // meant to let through.
  if (schema.fieldNames.contains($(inputCol))) {
    validateAndTransformSchema(schema)
  } else {
    // If the input column does not exist during transformation, we skip StringIndexerModel.
    schema
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.scalatest.FunSuite

import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.mllib.util.MLlibTestSparkContext

class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {

test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
Expand Down Expand Up @@ -60,4 +61,12 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
}

test("StringIndexerModel should keep silent if the input column does not exist.") {
  // Input frame built from a plain numeric range; the test relies on it not
  // having a "label" column (the model's configured input column).
  val frame = sqlContext.range(0L, 10L)

  // Model wired to read "label" and write "labelIndex".
  val model = new StringIndexerModel("indexer", Array("a", "b", "c"))
    .setInputCol("label")
    .setOutputCol("labelIndex")

  // With the input column missing, transform is expected to hand back the very
  // same DataFrame instance (reference equality), not a modified copy.
  val transformed = model.transform(frame)
  assert(transformed.eq(frame))
}
}

0 comments on commit e112394

Please sign in to comment.