move mean and variance to model
mengxr committed Aug 7, 2014
1 parent 48a0fff commit 40d863b
Showing 1 changed file with 28 additions and 24 deletions.
@@ -17,8 +17,9 @@
 
 package org.apache.spark.mllib.feature
 
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
+import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
 
+import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.rdd.RDDFunctions._
@@ -35,11 +36,13 @@ import org.apache.spark.rdd.RDD
  * @param withStd True by default. Scales the data to unit standard deviation.
  */
 @Experimental
-class StandardScaler(withMean: Boolean, withStd: Boolean) {
+class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
 
   def this() = this(false, true)
 
-  require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.")
+  if (!(withMean || withStd)) {
+    logWarning("Both withMean and withStd are false. The model does nothing.")
+  }
 
   /**
    * Computes the mean and variance and stores as a model to be used for later scaling.
@@ -48,34 +51,41 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) {
    * @return a StandardScalarModel
    */
   def fit(data: RDD[Vector]): StandardScalerModel = {
+    // TODO: skip computation if both withMean and withStd are false
     val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
-
-    val mean = summary.mean.toBreeze
-    val factor = summary.variance.toBreeze
-    require(mean.size == factor.size)
-
-    var i = 0
-    while (i < factor.size) {
-      factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0
-      i += 1
-    }
-
-    new StandardScalerModel(withMean, withStd, mean, factor)
+    new StandardScalerModel(withMean, withStd, summary.mean, summary.variance)
   }
 }
 
 /**
  * :: Experimental ::
  * Represents a StandardScaler model that can transform vectors.
+ *
+ * @param withMean whether to center the data before scaling
+ * @param withStd whether to scale the data to have unit standard deviation
+ * @param mean column mean values
+ * @param variance column variance values
  */
 @Experimental
 class StandardScalerModel private[mllib] (
     val withMean: Boolean,
     val withStd: Boolean,
-    val mean: BV[Double],
-    val factor: BV[Double]) extends VectorTransformer {
+    val mean: Vector,
+    val variance: Vector) extends VectorTransformer {
 
+  require(mean.size == variance.size)
+
+  private lazy val factor: BDV[Double] = {
+    val f = BDV.zeros[Double](variance.size)
+    var i = 0
+    while (i < f.size) {
+      f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0
+      i += 1
+    }
+    f
+  }
+
   /**
    * Applies standardization transformation on a vector.
@@ -85,13 +95,7 @@ class StandardScalerModel private[mllib] (
    * for the column with zero variance.
    */
   override def transform(vector: Vector): Vector = {
-    if (mean == null || factor == null) {
-      throw new IllegalStateException(
-        "Haven't learned column summary statistics yet. Call fit first.")
-    }
-
-    require(vector.size == mean.size)
-
+    require(mean.size == vector.size)
     if (withMean) {
       vector.toBreeze match {
         case dv: BDV[Double] =>
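For context, a minimal usage sketch of the API after this change. It assumes the MLlib interfaces shown in the diff and an existing SparkContext sc; the input data and variable names are illustrative, not part of the commit.

import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Toy input: an RDD of dense feature vectors.
val data: RDD[Vector] = sc.parallelize(Seq(
  Vectors.dense(1.0, 10.0, 100.0),
  Vectors.dense(2.0, 20.0, 200.0),
  Vectors.dense(3.0, 30.0, 300.0)))

// fit() aggregates column statistics with MultivariateOnlineSummarizer.
val scaler = new StandardScaler(withMean = true, withStd = true)
val model = scaler.fit(data)

// After this commit the summary statistics live on the model as Vectors ...
println(model.mean)      // column means
println(model.variance)  // column variances

// ... and the 1 / sqrt(variance) scaling factor is derived lazily inside the
// model, so transform() standardizes without recomputing the statistics.
val standardized = model.transform(Vectors.dense(2.0, 20.0, 200.0))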
