From ef1021a1d2efea0839efe0409b01af3770db6204 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 11 Nov 2019 21:37:18 -0800 Subject: [PATCH 1/3] [SPARK-29808][ML][PYTHON] StopWordsRemover should support multi-cols --- .../spark/ml/feature/StopWordsRemover.scala | 66 +++++++-- .../ml/feature/StopWordsRemoverSuite.scala | 133 +++++++++++++++++- python/pyspark/ml/feature.py | 39 ++++- 3 files changed, 221 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index f95e03ae6c822..fc1a2d0540c3c 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -22,15 +22,20 @@ import java.util.Locale import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.{ArrayType, StringType, StructType} +import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} /** * A feature transformer that filters out stop words from input. * + * Since 3.0.0, + * `StopWordsRemover` can filters out multiple columns at once by setting the `inputCols` + * parameter. Note that when both the `inputCol` and `inputCols` parameters are set, an Exception + * will be thrown. + * * @note null values from input array are preserved unless adding null to stopWords * explicitly. * @@ -38,7 +43,8 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructType} */ @Since("1.5.0") class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { + extends Transformer with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols + with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("stopWords")) @@ -51,6 +57,14 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("3.0.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("3.0.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + /** * The words to be filtered out. * Default: English stop words @@ -121,6 +135,15 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String } } + /** Returns the input and output column names corresponding in pair. */ + private[feature] def getInOutCols(): (Array[String], Array[String]) = { + if (isSet(inputCol)) { + (Array($(inputCol)), Array($(outputCol))) + } else { + ($(inputCols), $(outputCols)) + } + } + setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false, locale -> getDefaultOrUS.toString) @@ -142,16 +165,41 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String terms.filter(s => !lowerStopWords.contains(toLower(s))) } } - val metadata = outputSchema($(outputCol)).metadata - dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) + + val (inputColNames, outputColNames) = getInOutCols() + val ouputCols = inputColNames.map { inputColName => + t(col(inputColName)) + } + val ouputMetadata = outputColNames.map(outputSchema(_).metadata) + dataset.withColumns(outputColNames, ouputCols, ouputMetadata) } @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.sameType(ArrayType(StringType)), "Input type must be " + - s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.") - SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), + Seq(outputCols)) + + if (isSet(inputCols)) { + require(getInputCols.length == getOutputCols.length, + s"StopWordsRemover $this has mismatched Params " + + s"for multi-column transform. Params (inputCols, outputCols) should have " + + s"equal lengths, but they have different lengths: " + + s"(${getInputCols.length}, ${getOutputCols.length}).") + } + + val (inputColNames, outputColNames) = getInOutCols() + + var outputFields = schema.fields + inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) => + require(!schema.fieldNames.contains(outputColName), + s"Output Column $outputColName already exists.") + val inputType = schema(inputColName).dataType + require(inputType.sameType(ArrayType(StringType)), "Input type must be " + + s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.") + val outputField = StructField(outputColName, inputType, schema(inputColName).nullable) + outputFields :+= outputField + } + StructType(outputFields) } @Since("1.5.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 6d0b83e85733e..3e4ccbee404cb 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import java.util.Locale +import org.apache.spark.ml.Pipeline import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} @@ -181,12 +182,19 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { } test("read/write") { - val t = new StopWordsRemover() + val t1 = new StopWordsRemover() .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setStopWords(Array("the", "a")) .setCaseSensitive(true) - testDefaultReadWrite(t) + testDefaultReadWrite(t1) + + val t2 = new StopWordsRemover() + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("result1", "result2", "result3")) + .setStopWords(Array("the", "a")) + .setCaseSensitive(true) + testDefaultReadWrite(t2) } test("StopWordsRemover output column already exists") { @@ -199,7 +207,7 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { testTransformerByInterceptingException[(Array[String], Array[String])]( dataSet, remover, - s"requirement failed: Column $outputCol already exists.", + s"requirement failed: Output Column $outputCol already exists.", "expected") } @@ -217,4 +225,123 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { Locale.setDefault(oldDefault) } } + + test("Multiple Columns: StopWordsRemover default") { + val remover = new StopWordsRemover() + .setInputCols(Array("raw1", "raw2")) + .setOutputCols(Array("filtered1", "filtered2")) + val df = Seq( + (Seq("test", "test"), Seq("test1", "test2"), Seq("test", "test"), Seq("test1", "test2")), + (Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")), + (Seq("a", "the", "an"), Seq("the", "an"), Seq(), Seq()), + (Seq("A", "The", "AN"), Seq("A", "The"), Seq(), Seq()), + (Seq(null), Seq(null), Seq(null), Seq(null)), + (Seq(), Seq(), Seq(), Seq()) + ).toDF("raw1", "raw2", "expected1", "expected2") + + remover.transform(df) + .select("filtered1", "expected1", "filtered2", "expected2") + .collect().foreach { + case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) => + assert(r1 === e1, + s"The result value is not correct after bucketing. Expected $e1 but found $r1") + assert(r2 === e2, + s"The result value is not correct after bucketing. Expected $e2 but found $r2") + } + } + + test("Multiple Columns: StopWordsRemover with particular stop words list") { + val stopWords = Array("test", "a", "an", "the") + val remover = new StopWordsRemover() + .setInputCols(Array("raw1", "raw2")) + .setOutputCols(Array("filtered1", "filtered2")) + .setStopWords(stopWords) + val df = Seq( + (Seq("test", "test"), Seq("test1", "test2"), Seq(), Seq("test1", "test2")), + (Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")), + (Seq("a", "the", "an"), Seq("a", "the", "test1"), Seq(), Seq("test1")), + (Seq("A", "The", "AN"), Seq("A", "The", "AN"), Seq(), Seq()), + (Seq(null), Seq(null), Seq(null), Seq(null)), + (Seq(), Seq(), Seq(), Seq()) + ).toDF("raw1", "raw2", "expected1", "expected2") + + remover.transform(df) + .select("filtered1", "expected1", "filtered2", "expected2") + .collect().foreach { + case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) => + assert(r1 === e1, + s"The result value is not correct after bucketing. Expected $e1 but found $r1") + assert(r2 === e2, + s"The result value is not correct after bucketing. Expected $e2 but found $r2") + } + } + + test("Compare single/multiple column(s) StopWordsRemover in pipeline") { + val df = Seq( + (Seq("test", "test"), Seq("test1", "test2")), + (Seq("a", "b", "c", "d"), Seq("a", "b")), + (Seq("a", "the", "an"), Seq("a", "the", "test1")), + (Seq("A", "The", "AN"), Seq("A", "The", "AN")), + (Seq(null), Seq(null)), + (Seq(), Seq()) + ).toDF("input1", "input2") + + val multiColsRemover = new StopWordsRemover() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("output1", "output2")) + + val plForMultiCols = new Pipeline() + .setStages(Array(multiColsRemover)) + .fit(df) + + val removerForCol1 = new StopWordsRemover() + .setInputCol("input1") + .setOutputCol("output1") + val removerForCol2 = new StopWordsRemover() + .setInputCol("input2") + .setOutputCol("output2") + + val plForSingleCol = new Pipeline() + .setStages(Array(removerForCol1, removerForCol2)) + .fit(df) + + val resultForSingleCol = plForSingleCol.transform(df) + .select("output1", "output2") + .collect() + val resultForMultiCols = plForMultiCols.transform(df) + .select("output1", "output2") + .collect() + + resultForSingleCol.zip(resultForMultiCols).foreach { + case (rowForSingle, rowForMultiCols) => + assert(rowForSingle === rowForMultiCols) + } + } + + test("Multiple Columns: Mismatched sizes of inputCols/outputCols") { + val remover = new StopWordsRemover() + .setInputCols(Array("input1")) + .setOutputCols(Array("result1", "result2")) + val df = Seq( + (Seq("A"), Seq("A")), + (Seq("The", "the"), Seq("The")) + ).toDF("input1", "input2") + intercept[IllegalArgumentException] { + remover.transform(df).count() + } + } + + test("Multiple Columns: Set both of inputCol/inputCols") { + val remover = new StopWordsRemover() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setInputCol("input1") + val df = Seq( + (Seq("A"), Seq("A")), + (Seq("The", "the"), Seq("The")) + ).toDF("input1", "input2") + intercept[IllegalArgumentException] { + remover.transform(df).count() + } + } } diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 9513b0caecb9a..f6e531302317b 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3774,9 +3774,13 @@ def setOutputCol(self, value): return self._set(outputCol=value) -class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols, + JavaMLReadable, JavaMLWritable): """ A feature transformer that filters out stop words from input. + Since 3.0.0, :py:class:`StopWordsRemover` can filter out multiple columns at once by setting + the :py:attr:`inputCols` parameter. Note that when both the :py:attr:`inputCol` and + :py:attr:`inputCols` parameters are set, an Exception will be thrown. .. note:: null values from input array are preserved unless adding null to stopWords explicitly. @@ -3795,6 +3799,17 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl True >>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive() True + >>> df2 = spark.createDataFrame([(["a", "b", "c"], ["a", "b"])], ["text1", "text2"]) + >>> remover2 = StopWordsRemover(stopWords=["b"]) + >>> remover2.setInputCols(["text1", "text2"]).setOutputCols(["words1", "words2"]) + StopWordsRemover... + >>> remover2.transform(df2).show() + +---------+------+------+------+ + | text1| text2|words1|words2| + +---------+------+------+------+ + |[a, b, c]|[a, b]|[a, c]| [a]| + +---------+------+------+------+ + ... .. versionadded:: 1.6.0 """ @@ -3808,10 +3823,10 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl @keyword_only def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, - locale=None): + locale=None, inputCols=None, outputCols=None): """ __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \ - locale=None) + locale=None, inputCols=None, outputCols=None) """ super(StopWordsRemover, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", @@ -3824,10 +3839,10 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive= @keyword_only @since("1.6.0") def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, - locale=None): + locale=None, inputCols=None, outputCols=None): """ setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \ - locale=None) + locale=None, inputCols=None, outputCols=None) Sets params for this StopWordRemover. """ kwargs = self._input_kwargs @@ -3887,6 +3902,20 @@ def setOutputCol(self, value): """ return self._set(outputCol=value) + @since("3.0.0") + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + @since("3.0.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + @staticmethod @since("2.0.0") def loadDefaultStopWords(language): From 0d2f624a5ce7a7cbc9ec9bc35cc08f8d6fcd98b4 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 11 Nov 2019 21:53:36 -0800 Subject: [PATCH 2/3] minor fix --- .../scala/org/apache/spark/ml/feature/StopWordsRemover.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index fc1a2d0540c3c..1e1ae5eedc78e 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructTyp * A feature transformer that filters out stop words from input. * * Since 3.0.0, - * `StopWordsRemover` can filters out multiple columns at once by setting the `inputCols` + * `StopWordsRemover` can filter out multiple columns at once by setting the `inputCols` * parameter. Note that when both the `inputCol` and `inputCols` parameters are set, an Exception * will be thrown. * @@ -188,7 +188,6 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String } val (inputColNames, outputColNames) = getInOutCols() - var outputFields = schema.fields inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) => require(!schema.fieldNames.contains(outputColName), From fb082d77600956f86dcd860e90eac4e12176229c Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 12 Nov 2019 09:54:09 -0800 Subject: [PATCH 3/3] address comments --- .../spark/ml/feature/StopWordsRemover.scala | 21 ++++++++----------- .../ml/feature/StopWordsRemoverSuite.scala | 20 +++++++++--------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 1e1ae5eedc78e..5377eed9b18b2 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -31,10 +31,9 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructTyp /** * A feature transformer that filters out stop words from input. * - * Since 3.0.0, - * `StopWordsRemover` can filter out multiple columns at once by setting the `inputCols` - * parameter. Note that when both the `inputCol` and `inputCols` parameters are set, an Exception - * will be thrown. + * Since 3.0.0, `StopWordsRemover` can filter out multiple columns at once by setting the + * `inputCols` parameter. Note that when both the `inputCol` and `inputCols` parameters are set, + * an Exception will be thrown. * * @note null values from input array are preserved unless adding null to stopWords * explicitly. @@ -182,23 +181,21 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String if (isSet(inputCols)) { require(getInputCols.length == getOutputCols.length, s"StopWordsRemover $this has mismatched Params " + - s"for multi-column transform. Params (inputCols, outputCols) should have " + - s"equal lengths, but they have different lengths: " + + s"for multi-column transform. Params ($inputCols, $outputCols) should have " + + "equal lengths, but they have different lengths: " + s"(${getInputCols.length}, ${getOutputCols.length}).") } val (inputColNames, outputColNames) = getInOutCols() - var outputFields = schema.fields - inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) => - require(!schema.fieldNames.contains(outputColName), + val newCols = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => + require(!schema.fieldNames.contains(outputColName), s"Output Column $outputColName already exists.") val inputType = schema(inputColName).dataType require(inputType.sameType(ArrayType(StringType)), "Input type must be " + s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.") - val outputField = StructField(outputColName, inputType, schema(inputColName).nullable) - outputFields :+= outputField + StructField(outputColName, inputType, schema(inputColName).nullable) } - StructType(outputFields) + StructType(schema.fields ++ newCols) } @Since("1.5.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 3e4ccbee404cb..c142f83e05956 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -242,11 +242,11 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { remover.transform(df) .select("filtered1", "expected1", "filtered2", "expected2") .collect().foreach { - case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) => - assert(r1 === e1, - s"The result value is not correct after bucketing. Expected $e1 but found $r1") - assert(r2 === e2, - s"The result value is not correct after bucketing. Expected $e2 but found $r2") + case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) => + assert(r1 === e1, + s"The result value is not correct after bucketing. Expected $e1 but found $r1") + assert(r2 === e2, + s"The result value is not correct after bucketing. Expected $e2 but found $r2") } } @@ -268,11 +268,11 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { remover.transform(df) .select("filtered1", "expected1", "filtered2", "expected2") .collect().foreach { - case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) => - assert(r1 === e1, - s"The result value is not correct after bucketing. Expected $e1 but found $r1") - assert(r2 === e2, - s"The result value is not correct after bucketing. Expected $e2 but found $r2") + case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) => + assert(r1 === e1, + s"The result value is not correct after bucketing. Expected $e1 but found $r1") + assert(r2 === e2, + s"The result value is not correct after bucketing. Expected $e2 but found $r2") } }