Added lots of classes for new ML API:

Abstract classes for learning algorithms:
* Classifier
* Regressor
* Predictor

Traits for learning algorithms:
* HasDefaultEstimator
* IterativeEstimator
* IterativeSolver
* ProbabilisticClassificationModel
* WeakLearner

Concrete classes: learning algorithms
* AdaBoost (partly implemented)
* NaiveBayes (rough implementation)
* LinearRegression
* LogisticRegression (updated to use new abstract classes)

Concrete classes: evaluation
* ClassificationEvaluator
* RegressionEvaluator
* PredictionEvaluator

Concrete classes: other
* LabeledPoint (adding weight to the old LabeledPoint)
Showing 19 changed files with 1,001 additions and 58 deletions.
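At a glance, the intended call pattern for the new hierarchy looks roughly like this (a hypothetical sketch based only on the signatures in this diff; `training` is an assumed SchemaRDD with label and features columns, and the no-arg constructors are assumptions):

  import org.apache.spark.ml.classification.LogisticRegression
  import org.apache.spark.ml.param.ParamMap

  // An estimator (Classifier <: Predictor) is configured, then fit to a SchemaRDD.
  val lr = new LogisticRegression()
  val model = lr.fit(training, new ParamMap())  // returns a ClassificationModel

  // Classifiers carry a default metric via the HasDefaultEvaluator trait.
  val evaluator = lr.defaultEvaluator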
24 changes: 24 additions & 0 deletions
mllib/src/main/scala/org/apache/spark/ml/LabeledPoint.scala
@@ -0,0 +1,24 @@
package org.apache.spark.ml

import org.apache.spark.mllib.linalg.Vector

/**
 * Class that represents an instance (data point) for prediction tasks.
 *
 * @param label Label to predict
 * @param features List of features describing this instance
 * @param weight Instance weight
 */
case class LabeledPoint(label: Double, features: Vector, weight: Double) {

  /** Default constructor which sets instance weight to 1.0 */
  def this(label: Double, features: Vector) = this(label, features, 1.0)

  override def toString: String = {
    "(%s,%s,%s)".format(label, features, weight)
  }
}

object LabeledPoint {
  def apply(label: Double, features: Vector) = new LabeledPoint(label, features)
}
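For illustration, a weighted and an unweighted instance under this class (a minimal sketch; the feature values are made up):

  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.ml.LabeledPoint

  val weighted = LabeledPoint(1.0, Vectors.dense(0.5, -1.2), weight = 2.0)
  val unweighted = LabeledPoint(0.0, Vectors.dense(0.0, 3.1))  // companion apply fills in weight = 1.0
  println(unweighted)  // "(0.0,[0.0,3.1],1.0)"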
208 changes: 208 additions & 0 deletions
mllib/src/main/scala/org/apache/spark/ml/classification/AdaBoost.scala
@@ -0,0 +1,208 @@
package org.apache.spark.ml.classification

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.ml.LabeledPoint
import org.apache.spark.ml.evaluation.ClassificationEvaluator
import org.apache.spark.ml.param.{HasWeightCol, Param, ParamMap, HasMaxIter}
import org.apache.spark.ml.impl.estimator.{ProbabilisticClassificationModel, WeakLearner,
  IterativeEstimator, IterativeSolver}


private[classification] trait AdaBoostParams extends ClassifierParams
  with HasMaxIter with HasWeightCol {

  /** param for weak learner type */
  val weakLearner: Param[Classifier[_, _]] =
    new Param(this, "weakLearner", "weak learning algorithm")
  def getWeakLearner: Classifier[_, _] = get(weakLearner)

  /** param for weak learner param maps */
  val weakLearnerParamMap: Param[ParamMap] =
    new Param(this, "weakLearnerParamMap", "param map for the weak learner")
  def getWeakLearnerParamMap: ParamMap = get(weakLearnerParamMap)

  override def validate(paramMap: ParamMap): Unit = {
    // TODO: Check maxIter, weakLearner, weakLearnerParamMap, weightCol
    // Check: If the weak learner does not extend WeakLearner, then featuresColName should be
    // castable to FeaturesType.
  }
}


/**
 * AdaBoost
 *
 * Developer notes:
 *  - If the weak learner implements the [[WeakLearner]]
 */
class AdaBoost extends Classifier[AdaBoost, AdaBoostModel]
  with AdaBoostParams
  with IterativeEstimator[AdaBoostModel] {

  def setMaxIter(value: Int): this.type = set(maxIter, value)
  def setWeightCol(value: String): this.type = set(weightCol, value)
  def setWeakLearner(value: Classifier[_, _]): this.type = set(weakLearner, value)
  def setWeakLearnerParamMap(value: ParamMap): this.type = set(weakLearnerParamMap, value)

  /**
   * Extract LabeledPoints, using the weak learner's native feature representation if possible.
   * @param paramMap Complete paramMap (after combining with the internal paramMap)
   */
  private def extractLabeledPoints(dataset: SchemaRDD, paramMap: ParamMap): RDD[LabeledPoint] = {
    import dataset.sqlContext._
    val featuresColName = paramMap(featuresCol)
    val wl = paramMap(weakLearner)
    val featuresRDD: RDD[Vector] = wl match {
      case wlTagged: WeakLearner =>
        val wlParamMap = paramMap(weakLearnerParamMap)
        val wlFeaturesColName = wlParamMap(wl.featuresCol)
        val origFeaturesRDD = dataset.select(featuresColName.attr).as(wlFeaturesColName.attr)
        wlTagged.getNativeFeatureRDD(origFeaturesRDD, wlParamMap)
      case _ =>
        dataset.select(featuresColName.attr).map { case Row(features: Vector) => features }
    }

    val labelColName = paramMap(labelCol)
    if (paramMap.contains(weightCol)) {
      val weightColName = paramMap(weightCol)
      dataset.select(labelColName.attr, weightColName.attr)
        .zip(featuresRDD).map { case (Row(label: Double, weight: Double), features: Vector) =>
          LabeledPoint(label, features, weight)
        }
    } else {
      dataset.select(labelColName.attr)
        .zip(featuresRDD).map { case (Row(label: Double), features: Vector) =>
          LabeledPoint(label, features)
        }
    }
  }

  // From Classifier
  override def fit(dataset: SchemaRDD, paramMap: ParamMap): AdaBoostModel = {
    val map = this.paramMap ++ paramMap
    val labeledPoints: RDD[LabeledPoint] = extractLabeledPoints(dataset, map)
    train(labeledPoints, paramMap)
  }

  // From IterativeEstimator
  override private[ml] def createSolver(dataset: SchemaRDD, paramMap: ParamMap): AdaBoostSolver = {
    val map = this.paramMap ++ paramMap
    val labeledPoints: RDD[LabeledPoint] = extractLabeledPoints(dataset, map)
    new AdaBoostSolver(labeledPoints, this, map)
  }

  // From Predictor
  override def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): AdaBoostModel = {
    val map = this.paramMap ++ paramMap
    val solver = new AdaBoostSolver(dataset, this, map)
    while (solver.step()) { }
    solver.currentModel
  }
}


class AdaBoostModel private[ml] (
    val weakHypotheses: Array[ClassificationModel[_]],
    val weakHypothesisWeights: Array[Double],
    override val parent: AdaBoost,
    override val fittingParamMap: ParamMap)
  extends ClassificationModel[AdaBoostModel]
  with ProbabilisticClassificationModel
  with AdaBoostParams {

  require(weakHypotheses.size != 0)
  require(weakHypotheses.size == weakHypothesisWeights.size)

  // From Classifier.Model:
  override val numClasses: Int = weakHypotheses(0).numClasses

  require(weakHypotheses.forall(_.numClasses == numClasses))

  private val margin: Vector => Double = (features) => {
    weakHypotheses.zip(weakHypothesisWeights)
      .foldLeft(0.0) { case (total: Double, (wh: ClassificationModel[_], weight: Double)) =>
        val pred = if (wh.predict(features) == 1.0) 1.0 else -1.0
        total + weight * pred
      }
  }

  private val score: Vector => Double = (features) => {
    val m = margin(features)
    1.0 / (1.0 + math.exp(-2.0 * m))
  }

  override def predictProbabilities(features: Vector): Vector = {
    val s = score(features)
    Vectors.dense(Array(1.0 - s, s))
  }

  override def predictRaw(features: Vector): Vector = {
    val m = margin(features)
    Vectors.dense(Array(-m, m))
  }
}


private[ml] class AdaBoostSolver(
    val origData: RDD[LabeledPoint],
    val parent: AdaBoost,
    val paramMap: ParamMap) extends IterativeSolver[AdaBoostModel] {

  private val weakHypotheses = new ArrayBuffer[ClassificationModel[_]]
  private val weakHypothesisWeights = new ArrayBuffer[Double]

  private val wl: Classifier[_, _] = paramMap(parent.weakLearner)
  private val wlParamMap = paramMap(parent.weakLearnerParamMap)
  override val maxIterations: Int = paramMap(parent.maxIter)

  // TODO: Decide if this alg should cache data, or if that should be left to the user.

  // TODO: check for weights = 0
  // TODO: EDITING HERE NOW: switch to log weights
  private var logInstanceWeights: RDD[Double] = origData.map(lp => math.log(lp.weight))

  override def stepImpl(): Boolean = ??? /*{
    // Check if the weak learner takes instance weights.
    val wlDataset = wl match {
      case wlWeighted: HasWeightCol =>
        origData.zip(logInstanceWeights).map { case (lp: LabeledPoint, logWeight: Double) =>
          LabeledPoint(lp.label, lp.features, math.exp(logWeight))  // convert log-weight back to a weight
        }
      case _ =>
        // Subsample data to simulate the current instance weight distribution.
        // TODO: This needs to be done before AdaBoost is committed.
        throw new NotImplementedError(
          "AdaBoost currently requires that the weak learning algorithm accept instance weights.")
    }
    // Train the weak learning algorithm.
    val weakHypothesis: ClassificationModel[_] = wl match {
      case wlTagged: WeakLearner[_] =>
        // This lets the weak learner know that the features are in its native format.
        wlTagged.trainNative(wlDataset, wlParamMap).asInstanceOf[ClassificationModel[_]]
      case _ =>
        wl.train(wlDataset, wlParamMap).asInstanceOf[ClassificationModel[_]]
    }
    // Add the weighted weak hypothesis to the ensemble.
    // TODO: Handle instance weights.
    val predictionsAndLabels = wlDataset.map(lp => weakHypothesis.predict(lp.features))
      .zip(wlDataset.map(_.label))
    // Error rate = 1 - accuracy.
    val eps = 1.0 - ClassificationEvaluator.computeMetric(predictionsAndLabels, "accuracy")
    val alpha = 0.5 * (math.log(1.0 - eps) - math.log(eps)) // TODO: handle eps near 0
    weakHypotheses += weakHypothesis
    weakHypothesisWeights += alpha
    // Update weights.
    val newLogInstanceWeights = logInstanceWeights.zip(predictionsAndLabels).map {
      case (logWeight: Double, (pred: Double, label: Double)) =>
        ???
    }
  }*/

  override def currentModel: AdaBoostModel = {
    new AdaBoostModel(weakHypotheses.toArray, weakHypothesisWeights.toArray, parent, paramMap)
  }
}
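Taken together, the intended loop is classic binary AdaBoost: each round fits a weak classifier, weights it by alpha = 0.5 * ln((1 - eps) / eps) where eps is its error rate, and reweights the instances. The predictProbabilities mapping 1 / (1 + exp(-2 * margin)) is the standard logistic calibration of a boosted margin (Friedman, Hastie, and Tibshirani). A hypothetical configuration sketch follows; treat it as illustrative, since stepImpl() is unimplemented in this commit, and the NaiveBayes and ParamMap constructors are assumptions:

  import org.apache.spark.ml.classification.{AdaBoost, NaiveBayes}
  import org.apache.spark.ml.param.ParamMap

  // NaiveBayes stands in for any weak learner accepting instance weights (HasWeightCol).
  val boost = new AdaBoost()
    .setMaxIter(50)                      // number of boosting rounds
    .setWeakLearner(new NaiveBayes())
    .setWeakLearnerParamMap(new ParamMap())
  // val model = boost.fit(trainingDataset, new ParamMap())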
73 changes: 73 additions & 0 deletions
mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -0,0 +1,73 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.classification

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.evaluation.ClassificationEvaluator
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.ml._
import org.apache.spark.ml.impl.estimator.{HasDefaultEvaluator, PredictionModel, Predictor,
  PredictorParams}
import org.apache.spark.rdd.RDD

@AlphaComponent
private[classification] trait ClassifierParams extends PredictorParams

/**
 * Single-label binary or multiclass classification
 */
abstract class Classifier[Learner <: Classifier[Learner, M], M <: ClassificationModel[M]]
  extends Predictor[Learner, M]
  with ClassifierParams
  with HasDefaultEvaluator {

  override def defaultEvaluator: Evaluator = new ClassificationEvaluator
}


private[ml] abstract class ClassificationModel[M <: ClassificationModel[M]]
  extends PredictionModel[M] with ClassifierParams {

  def numClasses: Int

  /**
   * Predict label for the given features.  Labels are indexed {0, 1, ..., numClasses - 1}.
   * This default implementation for classification predicts the index of the maximum value
   * from [[predictRaw()]].
   */
  override def predict(features: Vector): Double = {
    predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2
  }

  /**
   * Raw prediction for each possible label
   * @return vector where element i is the raw score for label i
   */
  def predictRaw(features: Vector): Vector

  /**
   * Compute this model's accuracy on the given dataset.
   */
  def accuracy(dataset: RDD[LabeledPoint]): Double = {
    // TODO: Handle instance weights.
    val predictionsAndLabels = dataset.map(lp => predict(lp.features))
      .zip(dataset.map(_.label))
    ClassificationEvaluator.computeMetric(predictionsAndLabels, "accuracy")
  }
}
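To make the default predict implementation concrete: for a toy raw-score vector, the argmax over (score, index) pairs yields the label (purely illustrative, not part of the diff):

  // predictRaw returns one raw score per label; predict takes the argmax index.
  val raw = Array(0.1, 2.3, 0.7)                           // raw scores for labels 0, 1, 2
  val predicted: Double = raw.zipWithIndex.maxBy(_._1)._2  // (2.3, 1) wins, so predicted == 1.0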