-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-2852][MLLIB] Separate model from IDF/StandardScaler algorithms #1814
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,8 +17,9 @@ | |
|
||
package org.apache.spark.mllib.feature | ||
|
||
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} | ||
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} | ||
|
||
import org.apache.spark.Logging | ||
import org.apache.spark.annotation.Experimental | ||
import org.apache.spark.mllib.linalg.{Vector, Vectors} | ||
import org.apache.spark.mllib.rdd.RDDFunctions._ | ||
|
@@ -35,37 +36,55 @@ import org.apache.spark.rdd.RDD | |
* @param withStd True by default. Scales the data to unit standard deviation. | ||
*/ | ||
@Experimental | ||
class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { | ||
class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This class is only used for keeping the state of withMean and withStd; is it possible to move those states to the fit function by overloading, and make this an object? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current API is more consistent with others like |
||
def this() = this(false, true) | ||
|
||
require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.") | ||
|
||
private var mean: BV[Double] = _ | ||
private var factor: BV[Double] = _ | ||
if (!(withMean || withStd)) { | ||
logWarning("Both withMean and withStd are false. The model does nothing.") | ||
} | ||
|
||
/** | ||
* Computes the mean and variance and stores as a model to be used for later scaling. | ||
* | ||
* @param data The data used to compute the mean and variance to build the transformation model. | ||
* @return This StandardScaler object. | ||
* @return a StandardScalerModel | ||
*/ | ||
def fit(data: RDD[Vector]): this.type = { | ||
def fit(data: RDD[Vector]): StandardScalerModel = { | ||
// TODO: skip computation if both withMean and withStd are false | ||
val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( | ||
(aggregator, data) => aggregator.add(data), | ||
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)) | ||
new StandardScalerModel(withMean, withStd, summary.mean, summary.variance) | ||
} | ||
} | ||
|
||
mean = summary.mean.toBreeze | ||
factor = summary.variance.toBreeze | ||
require(mean.length == factor.length) | ||
/** | ||
* :: Experimental :: | ||
* Represents a StandardScaler model that can transform vectors. | ||
* | ||
* @param withMean whether to center the data before scaling | ||
* @param withStd whether to scale the data to have unit standard deviation | ||
* @param mean column mean values | ||
* @param variance column variance values | ||
*/ | ||
@Experimental | ||
class StandardScalerModel private[mllib] ( | ||
val withMean: Boolean, | ||
val withStd: Boolean, | ||
val mean: Vector, | ||
val variance: Vector) extends VectorTransformer { | ||
|
||
require(mean.size == variance.size) | ||
|
||
private lazy val factor: BDV[Double] = { | ||
val f = BDV.zeros[Double](variance.size) | ||
var i = 0 | ||
while (i < factor.length) { | ||
factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0 | ||
while (i < f.size) { | ||
f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 | ||
i += 1 | ||
} | ||
|
||
this | ||
f | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since users may want to know the variance of the training set, should we have a constructor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. done. |
||
/** | ||
|
@@ -76,13 +95,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor | |
* for the column with zero variance. | ||
*/ | ||
override def transform(vector: Vector): Vector = { | ||
if (mean == null || factor == null) { | ||
throw new IllegalStateException( | ||
"Haven't learned column summary statistics yet. Call fit first.") | ||
} | ||
|
||
require(vector.size == mean.length) | ||
|
||
require(mean.size == vector.size) | ||
if (withMean) { | ||
vector.toBreeze match { | ||
case dv: BDV[Double] => | ||
|
@@ -115,5 +128,4 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor | |
vector | ||
} | ||
} | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The following exception is used for unsupported vector in appendBias and StandardScaler, maybe we could have a global definition of this in util.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might want to use different error messages. In that case, having a util function doesn't save us much.