separate Model from StandardScaler algorithm
mengxr committed Aug 6, 2014
1 parent 89f3486 commit 48a0fff
Showing 2 changed files with 43 additions and 41 deletions.
mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala

@@ -35,38 +35,47 @@ import org.apache.spark.rdd.RDD
  * @param withStd True by default. Scales the data to unit standard deviation.
  */
 @Experimental
-class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer {
+class StandardScaler(withMean: Boolean, withStd: Boolean) {
 
   def this() = this(false, true)
 
   require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.")
 
-  private var mean: BV[Double] = _
-  private var factor: BV[Double] = _
-
   /**
    * Computes the mean and variance and stores as a model to be used for later scaling.
    *
    * @param data The data used to compute the mean and variance to build the transformation model.
-   * @return This StandardScalar object.
+   * @return a StandardScalerModel
    */
-  def fit(data: RDD[Vector]): this.type = {
+  def fit(data: RDD[Vector]): StandardScalerModel = {
     val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
 
-    mean = summary.mean.toBreeze
-    factor = summary.variance.toBreeze
-    require(mean.length == factor.length)
+    val mean = summary.mean.toBreeze
+    val factor = summary.variance.toBreeze
+    require(mean.size == factor.size)
 
     var i = 0
-    while (i < factor.length) {
+    while (i < factor.size) {
       factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0
       i += 1
     }
 
-    this
+    new StandardScalerModel(withMean, withStd, mean, factor)
   }
+}
+
+/**
+ * :: Experimental ::
+ * Represents a StandardScaler model that can transform vectors.
+ */
+@Experimental
+class StandardScalerModel private[mllib] (
+    val withMean: Boolean,
+    val withStd: Boolean,
+    val mean: BV[Double],
+    val factor: BV[Double]) extends VectorTransformer {
 
   /**
    * Applies standardization transformation on a vector.
@@ -81,7 +90,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer {
         "Haven't learned column summary statistics yet. Call fit first.")
     }
 
-    require(vector.size == mean.length)
+    require(vector.size == mean.size)
 
     if (withMean) {
       vector.toBreeze match {
@@ -115,5 +124,4 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer {
       vector
     }
   }
-
 }
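With this change, fit no longer mutates the StandardScaler in place and returns this; it produces an immutable StandardScalerModel that holds the learned mean and scaling factor and performs the actual transformation. A minimal sketch of the resulting usage pattern, assuming a live SparkContext named sc and the MLlib API of this era (the sample data below is hypothetical):

import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

// Build the estimator: center to zero mean and scale to unit standard deviation.
val scaler = new StandardScaler(withMean = true, withStd = true)

// Hypothetical dense training data; sc is an existing SparkContext.
val dataRDD = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0, 3.0),
  Vectors.dense(3.0, 0.0, 1.0),
  Vectors.dense(2.0, 1.0, 2.0)))

// fit() now returns a StandardScalerModel instead of this.type.
val model = scaler.fit(dataRDD)

// The model, not the estimator, applies the transformation.
val scaledRDD = model.transform(dataRDD)
val scaledVector = model.transform(Vectors.dense(1.0, 1.0, 1.0))

Because the learned statistics live only in the returned model, transform can no longer be called before fit, which is why the IllegalStateException test disappears from the suite below.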
mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala

@@ -50,23 +50,17 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext {
     val standardizer2 = new StandardScaler()
     val standardizer3 = new StandardScaler(withMean = true, withStd = false)
 
-    withClue("Using a standardizer before fitting the model should throw exception.") {
-      intercept[IllegalStateException] {
-        data.map(standardizer1.transform)
-      }
-    }
-
-    standardizer1.fit(dataRDD)
-    standardizer2.fit(dataRDD)
-    standardizer3.fit(dataRDD)
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
 
-    val data1 = data.map(standardizer1.transform)
-    val data2 = data.map(standardizer2.transform)
-    val data3 = data.map(standardizer3.transform)
+    val data1 = data.map(model1.transform)
+    val data2 = data.map(model2.transform)
+    val data3 = data.map(model3.transform)
 
-    val data1RDD = standardizer1.transform(dataRDD)
-    val data2RDD = standardizer2.transform(dataRDD)
-    val data3RDD = standardizer3.transform(dataRDD)
+    val data1RDD = model1.transform(dataRDD)
+    val data2RDD = model2.transform(dataRDD)
+    val data3RDD = model3.transform(dataRDD)
 
     val summary = computeSummary(dataRDD)
     val summary1 = computeSummary(data1RDD)
@@ -129,25 +123,25 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext {
     val standardizer2 = new StandardScaler()
     val standardizer3 = new StandardScaler(withMean = true, withStd = false)
 
-    standardizer1.fit(dataRDD)
-    standardizer2.fit(dataRDD)
-    standardizer3.fit(dataRDD)
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
 
-    val data2 = data.map(standardizer2.transform)
+    val data2 = data.map(model2.transform)
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(standardizer1.transform)
+        data.map(model1.transform)
       }
     }
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(standardizer3.transform)
+        data.map(model3.transform)
      }
     }
 
-    val data2RDD = standardizer2.transform(dataRDD)
+    val data2RDD = model2.transform(dataRDD)
 
     val summary2 = computeSummary(data2RDD)
 
@@ -181,13 +175,13 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext {
     val standardizer2 = new StandardScaler(withMean = true, withStd = false)
     val standardizer3 = new StandardScaler(withMean = false, withStd = true)
 
-    standardizer1.fit(dataRDD)
-    standardizer2.fit(dataRDD)
-    standardizer3.fit(dataRDD)
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
 
-    val data1 = data.map(standardizer1.transform)
-    val data2 = data.map(standardizer2.transform)
-    val data3 = data.map(standardizer3.transform)
+    val data1 = data.map(model1.transform)
+    val data2 = data.map(model2.transform)
+    val data3 = data.map(model3.transform)
 
     assert(data1.forall(_.toArray.forall(_ == 0.0)),
       "The variance is zero, so the transformed result should be 0.0")
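One behavior the suite pins down is unchanged by the refactoring: centering (withMean = true) cannot be applied to sparse input, because subtracting the mean would densify the vector. A minimal sketch of that restriction, with hypothetical data and the same sc assumption as above:

import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

// Hypothetical sparse vectors: size 3, two non-zero entries each.
val sparseData = sc.parallelize(Seq(
  Vectors.sparse(3, Seq((0, 1.0), (2, 2.0))),
  Vectors.sparse(3, Seq((1, 3.0), (2, 1.0)))))

// Fitting on sparse data is fine; the summarizer accepts sparse vectors.
val model = new StandardScaler(withMean = true, withStd = false).fit(sparseData)

// Throws IllegalArgumentException: mean subtraction would produce a dense result.
// model.transform(Vectors.sparse(3, Seq((0, 1.0))))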