[SPARK-29808][ML][PYTHON] StopWordsRemover should support multi-cols #26480

Closed · wants to merge 3 commits · changes shown from 2 commits
mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -22,23 +22,29 @@ import java.util.Locale
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

/**
* A feature transformer that filters out stop words from input.
*
* Since 3.0.0,
Member commented:

I don't feel strongly, but you could remove this.

Contributor Author @huaxingao replied (Nov 12, 2019):
Sorry, I accidentally broke the line, but I prefer to keep it. When other features added multi-column support, "since xxx" was added to the doc, so I'm just trying to be consistent with them.

* `StopWordsRemover` can filter out multiple columns at once by setting the `inputCols`
* parameter. Note that when both the `inputCol` and `inputCols` parameters are set, an Exception
* will be thrown.
*
* @note null values from input array are preserved unless adding null to stopWords
* explicitly.
*
* @see <a href="http://en.wikipedia.org/wiki/Stop_words">Stop words (Wikipedia)</a>
*/
@Since("1.5.0")
class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String)
extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
extends Transformer with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols
with DefaultParamsWritable {

@Since("1.5.0")
def this() = this(Identifiable.randomUID("stopWords"))
@@ -51,6 +57,14 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
@Since("1.5.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("3.0.0")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
@Since("3.0.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

Contributor Author commented:

I am debating whether I should add stopWordsArray/caseSensitiveArray/localeArray. It seems to me that users will use the same set of stopWords for all columns, so there's no need to add those.
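
(For reference: per-column stop-word lists are already expressible with the existing single-column API by chaining removers in a Pipeline. A minimal sketch, with hypothetical column names:)

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.StopWordsRemover

    // Each remover carries its own stop-word list (column names are hypothetical).
    val enRemover = new StopWordsRemover()
      .setInputCol("textEn")
      .setOutputCol("cleanEn")
      .setStopWords(StopWordsRemover.loadDefaultStopWords("english"))
    val frRemover = new StopWordsRemover()
      .setInputCol("textFr")
      .setOutputCol("cleanFr")
      .setStopWords(StopWordsRemover.loadDefaultStopWords("french"))

    // Chaining the stages applies a different list to each column.
    val pipeline = new Pipeline().setStages(Array(enRemover, frRemover))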

/**
* The words to be filtered out.
* Default: English stop words
@@ -121,6 +135,15 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
}
}

/** Returns the input and output column names, corresponding pairwise. */
private[feature] def getInOutCols(): (Array[String], Array[String]) = {
if (isSet(inputCol)) {
(Array($(inputCol)), Array($(outputCol)))
} else {
($(inputCols), $(outputCols))
}
}

setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
caseSensitive -> false, locale -> getDefaultOrUS.toString)

@@ -142,16 +165,40 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
terms.filter(s => !lowerStopWords.contains(toLower(s)))
}
}
val metadata = outputSchema($(outputCol)).metadata
dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))

val (inputColNames, outputColNames) = getInOutCols()
val outputCols = inputColNames.map { inputColName =>
t(col(inputColName))
}
val outputMetadata = outputColNames.map(outputSchema(_).metadata)
dataset.withColumns(outputColNames, outputCols, outputMetadata)
}
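
(The multi-column path builds one filtering UDF and applies it to every resolved input column in a single withColumns call. A minimal usage sketch; the DataFrame df and its array<string> columns raw1/raw2 are assumed:)

    val remover = new StopWordsRemover()
      .setInputCols(Array("raw1", "raw2"))
      .setOutputCols(Array("filtered1", "filtered2"))

    // Appends filtered1 and filtered2; the same stop-word list is applied to each input.
    val cleaned = remover.transform(df)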

@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
val inputType = schema($(inputCol)).dataType
require(inputType.sameType(ArrayType(StringType)), "Input type must be " +
s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol),
Seq(outputCols))

if (isSet(inputCols)) {
require(getInputCols.length == getOutputCols.length,
s"StopWordsRemover $this has mismatched Params " +
s"for multi-column transform. Params (inputCols, outputCols) should have " +
Member commented:

Nit: you don't need interpolation on these two lines.
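
(Concretely, the two fragments carry no ${} substitutions, so plain literals suffice:)

    "for multi-column transform. Params (inputCols, outputCols) should have " +
    "equal lengths, but they have different lengths: " +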

s"equal lengths, but they have different lengths: " +
s"(${getInputCols.length}, ${getOutputCols.length}).")
}

val (inputColNames, outputColNames) = getInOutCols()
var outputFields = schema.fields
Member commented:

It will hardly matter unless the number of cols is large, but would it be just as easy, and a little faster, to .map the .zip below to the new output fields and then append them to schema.fields once?
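
(A sketch of that refactor, reusing the names from the diff below:)

    val newFields = inputColNames.zip(outputColNames).map {
      case (inputColName, outputColName) =>
        require(!schema.fieldNames.contains(outputColName),
          s"Output Column $outputColName already exists.")
        val inputType = schema(inputColName).dataType
        require(inputType.sameType(ArrayType(StringType)), "Input type must be " +
          s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
        StructField(outputColName, inputType, schema(inputColName).nullable)
    }
    StructType(schema.fields ++ newFields)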

inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) =>
require(!schema.fieldNames.contains(outputColName),
s"Output Column $outputColName already exists.")
val inputType = schema(inputColName).dataType
require(inputType.sameType(ArrayType(StringType)), "Input type must be " +
s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
val outputField = StructField(outputColName, inputType, schema(inputColName).nullable)
outputFields :+= outputField
}
StructType(outputFields)
}

@Since("1.5.0")
mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature

import java.util.Locale

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.{DataFrame, Row}

@@ -181,12 +182,19 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
}

test("read/write") {
val t = new StopWordsRemover()
val t1 = new StopWordsRemover()
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setStopWords(Array("the", "a"))
.setCaseSensitive(true)
testDefaultReadWrite(t)
testDefaultReadWrite(t1)

val t2 = new StopWordsRemover()
.setInputCols(Array("input1", "input2", "input3"))
.setOutputCols(Array("result1", "result2", "result3"))
.setStopWords(Array("the", "a"))
.setCaseSensitive(true)
testDefaultReadWrite(t2)
}

test("StopWordsRemover output column already exists") {
@@ -199,7 +207,7 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
testTransformerByInterceptingException[(Array[String], Array[String])](
dataSet,
remover,
s"requirement failed: Column $outputCol already exists.",
s"requirement failed: Output Column $outputCol already exists.",
"expected")
}

@@ -217,4 +225,123 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
Locale.setDefault(oldDefault)
}
}

test("Multiple Columns: StopWordsRemover default") {
val remover = new StopWordsRemover()
.setInputCols(Array("raw1", "raw2"))
.setOutputCols(Array("filtered1", "filtered2"))
val df = Seq(
(Seq("test", "test"), Seq("test1", "test2"), Seq("test", "test"), Seq("test1", "test2")),
(Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
(Seq("a", "the", "an"), Seq("the", "an"), Seq(), Seq()),
(Seq("A", "The", "AN"), Seq("A", "The"), Seq(), Seq()),
(Seq(null), Seq(null), Seq(null), Seq(null)),
(Seq(), Seq(), Seq(), Seq())
).toDF("raw1", "raw2", "expected1", "expected2")

remover.transform(df)
.select("filtered1", "expected1", "filtered2", "expected2")
.collect().foreach {
case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) =>
Member commented:

Small nit: indent this more
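
(I.e., nest the case arm under the foreach:)

    .collect().foreach {
      case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) =>
        assert(r1 === e1,
          s"The result value is not correct after stop words removal. Expected $e1 but found $r1")
    }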

assert(r1 === e1,
s"The result value is not correct after stop words removal. Expected $e1 but found $r1")
assert(r2 === e2,
s"The result value is not correct after stop words removal. Expected $e2 but found $r2")
}
}

test("Multiple Columns: StopWordsRemover with particular stop words list") {
val stopWords = Array("test", "a", "an", "the")
val remover = new StopWordsRemover()
.setInputCols(Array("raw1", "raw2"))
.setOutputCols(Array("filtered1", "filtered2"))
.setStopWords(stopWords)
val df = Seq(
(Seq("test", "test"), Seq("test1", "test2"), Seq(), Seq("test1", "test2")),
(Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
(Seq("a", "the", "an"), Seq("a", "the", "test1"), Seq(), Seq("test1")),
(Seq("A", "The", "AN"), Seq("A", "The", "AN"), Seq(), Seq()),
(Seq(null), Seq(null), Seq(null), Seq(null)),
(Seq(), Seq(), Seq(), Seq())
).toDF("raw1", "raw2", "expected1", "expected2")

remover.transform(df)
.select("filtered1", "expected1", "filtered2", "expected2")
.collect().foreach {
case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) =>
assert(r1 === e1,
s"The result value is not correct after stop words removal. Expected $e1 but found $r1")
assert(r2 === e2,
s"The result value is not correct after stop words removal. Expected $e2 but found $r2")
}
}

test("Compare single/multiple column(s) StopWordsRemover in pipeline") {
val df = Seq(
(Seq("test", "test"), Seq("test1", "test2")),
(Seq("a", "b", "c", "d"), Seq("a", "b")),
(Seq("a", "the", "an"), Seq("a", "the", "test1")),
(Seq("A", "The", "AN"), Seq("A", "The", "AN")),
(Seq(null), Seq(null)),
(Seq(), Seq())
).toDF("input1", "input2")

val multiColsRemover = new StopWordsRemover()
.setInputCols(Array("input1", "input2"))
.setOutputCols(Array("output1", "output2"))

val plForMultiCols = new Pipeline()
.setStages(Array(multiColsRemover))
.fit(df)

val removerForCol1 = new StopWordsRemover()
.setInputCol("input1")
.setOutputCol("output1")
val removerForCol2 = new StopWordsRemover()
.setInputCol("input2")
.setOutputCol("output2")

val plForSingleCol = new Pipeline()
.setStages(Array(removerForCol1, removerForCol2))
.fit(df)

val resultForSingleCol = plForSingleCol.transform(df)
.select("output1", "output2")
.collect()
val resultForMultiCols = plForMultiCols.transform(df)
.select("output1", "output2")
.collect()

resultForSingleCol.zip(resultForMultiCols).foreach {
case (rowForSingle, rowForMultiCols) =>
assert(rowForSingle === rowForMultiCols)
}
}

test("Multiple Columns: Mismatched sizes of inputCols/outputCols") {
val remover = new StopWordsRemover()
.setInputCols(Array("input1"))
.setOutputCols(Array("result1", "result2"))
val df = Seq(
(Seq("A"), Seq("A")),
(Seq("The", "the"), Seq("The"))
).toDF("input1", "input2")
intercept[IllegalArgumentException] {
remover.transform(df).count()
}
}

test("Multiple Columns: Set both of inputCol/inputCols") {
val remover = new StopWordsRemover()
.setInputCols(Array("input1", "input2"))
.setOutputCols(Array("result1", "result2"))
.setInputCol("input1")
val df = Seq(
(Seq("A"), Seq("A")),
(Seq("The", "the"), Seq("The"))
).toDF("input1", "input2")
intercept[IllegalArgumentException] {
remover.transform(df).count()
}
}
}
39 changes: 34 additions & 5 deletions python/pyspark/ml/feature.py
@@ -3774,9 +3774,13 @@ def setOutputCol(self, value):
return self._set(outputCol=value)


class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
JavaMLReadable, JavaMLWritable):
"""
A feature transformer that filters out stop words from input.
Since 3.0.0, :py:class:`StopWordsRemover` can filter out multiple columns at once by setting
the :py:attr:`inputCols` parameter. Note that when both the :py:attr:`inputCol` and
:py:attr:`inputCols` parameters are set, an Exception will be thrown.

.. note:: null values from input array are preserved unless adding null to stopWords explicitly.

@@ -3795,6 +3799,17 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
True
>>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive()
True
>>> df2 = spark.createDataFrame([(["a", "b", "c"], ["a", "b"])], ["text1", "text2"])
>>> remover2 = StopWordsRemover(stopWords=["b"])
>>> remover2.setInputCols(["text1", "text2"]).setOutputCols(["words1", "words2"])
StopWordsRemover...
>>> remover2.transform(df2).show()
+---------+------+------+------+
| text1| text2|words1|words2|
+---------+------+------+------+
|[a, b, c]|[a, b]|[a, c]| [a]|
+---------+------+------+------+
...

.. versionadded:: 1.6.0
"""
@@ -3808,10 +3823,10 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl

@keyword_only
def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
locale=None):
locale=None, inputCols=None, outputCols=None):
"""
__init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
locale=None)
locale=None, inputCols=None, outputCols=None)
"""
super(StopWordsRemover, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
@@ -3824,10 +3839,10 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=
@keyword_only
@since("1.6.0")
def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
locale=None):
locale=None, inputCols=None, outputCols=None):
"""
setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
locale=None)
locale=None, inputCols=None, outputCols=None)
Sets params for this StopWordsRemover.
"""
kwargs = self._input_kwargs
@@ -3887,6 +3902,20 @@ def setOutputCol(self, value):
"""
return self._set(outputCol=value)

@since("3.0.0")
def setInputCols(self, value):
"""
Sets the value of :py:attr:`inputCols`.
"""
return self._set(inputCols=value)

@since("3.0.0")
def setOutputCols(self, value):
"""
Sets the value of :py:attr:`outputCols`.
"""
return self._set(outputCols=value)

@staticmethod
@since("2.0.0")
def loadDefaultStopWords(language):