[SPARK-29095][ML] add extractInstances
### What changes were proposed in this pull request?
Add common `extractInstances` methods to `PredictorParams`/`ClassifierParams` that extract the label, weight (if any), and features columns from a dataset into an `RDD[Instance]`.

### Why are the changes needed?
More and more ML algorithms support instance weighting; a shared extraction helper keeps their implementations simple and consistent.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Existing test suites.
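For a concrete picture of the boilerplate being factored out, here is a minimal, self-contained sketch (the `ExtractInstancesSketch` object and its local `Instance` case class are illustrative stand-ins — Spark's own `org.apache.spark.ml.feature.Instance` is `private[ml]` and not usable from user code):

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

// Stand-in for org.apache.spark.ml.feature.Instance, which is private[ml].
case class Instance(label: Double, weight: Double, features: Vector)

object ExtractInstancesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (1.0, 2.0, Vectors.dense(0.5, 1.5)),
      (0.0, 1.0, Vectors.dense(2.5, 3.5))
    ).toDF("label", "weight", "features")

    // The boilerplate each algorithm previously repeated: pick the weight
    // column (or a constant 1.0), select the three columns, and map rows
    // into a strongly typed RDD.
    val w = if (df.columns.contains("weight")) col("weight").cast(DoubleType) else lit(1.0)
    val instances = df.select(col("label").cast(DoubleType), w, col("features")).rdd.map {
      case Row(label: Double, weight: Double, features: Vector) =>
        Instance(label, weight, features)
    }
    instances.collect().foreach(println)
    spark.stop()
  }
}
```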

Closes #25802 from zhengruifeng/add_extractInstances.

Authored-by: zhengruifeng <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
zhengruifeng authored and srowen committed Sep 24, 2019
1 parent 7c02c14 commit fff2e84
Showing 9 changed files with 78 additions and 66 deletions.
35 changes: 34 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml
 
 import org.apache.spark.annotation.{DeveloperApi, Since}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -62,6 +62,39 @@ private[ml] trait PredictorParams extends Params
     }
     SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
   }
+
+  /**
+   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
+   * and put it in an RDD with strong types.
+   */
+  protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
+    val w = this match {
+      case p: HasWeightCol =>
+        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
+          col($(p.weightCol)).cast(DoubleType)
+        } else {
+          lit(1.0)
+        }
+    }
+
+    dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      case Row(label: Double, weight: Double, features: Vector) =>
+        Instance(label, weight, features)
+    }
+  }
+
+  /**
+   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
+   * and put it in an RDD with strong types.
+   * Validate the output instances with the given function.
+   */
+  protected def extractInstances(dataset: Dataset[_],
+      validateInstance: Instance => Unit): RDD[Instance] = {
+    extractInstances(dataset).map { instance =>
+      validateInstance(instance)
+      instance
+    }
+  }
 }
 
 /**
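One detail worth noting in the first method: the weight expression is resolved once, before the `select` — the configured weight column cast to double when it is set and non-empty, else a constant `lit(1.0)`. A standalone sketch of that fallback, with `weightColName` as a hypothetical stand-in for the `HasWeightCol` param:

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

// Sketch only: mirror the weight fallback in extractInstances — use the
// configured weight column (cast to double) when present and non-empty,
// otherwise weight every row by a constant 1.0.
def weightColumn(weightColName: Option[String]): Column =
  weightColName.filter(_.nonEmpty)
    .map(name => col(name).cast(DoubleType))
    .getOrElse(lit(1.0))
```

The label column gets the same `DoubleType` cast, so integer label columns work without callers converting them first.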
mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
@@ -42,6 +42,22 @@ private[spark] trait ClassifierParams
     val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
     SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
   }
+
+  /**
+   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
+   * and put it in an RDD with strong types.
+   * Validates the label on the classifier is a valid integer in the range [0, numClasses).
+   */
+  protected def extractInstances(dataset: Dataset[_],
+      numClasses: Int): RDD[Instance] = {
+    val validateInstance = (instance: Instance) => {
+      val label = instance.label
+      require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
+        s" dataset with invalid label $label. Labels must be integers in range" +
+        s" [0, $numClasses).")
+    }
+    extractInstances(dataset, validateInstance)
+  }
 }
 
 /**
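The `require` above is the entire label contract for classifiers: labels must be integral doubles in `[0, numClasses)`. A REPL-style sketch of the same check in isolation (the `requireValidLabel` name is hypothetical):

```scala
// Sketch of the classifier label check: integral, nonnegative, below numClasses.
def requireValidLabel(label: Double, numClasses: Int): Unit =
  require(label.toLong == label && label >= 0 && label < numClasses,
    s"Classifier was given dataset with invalid label $label. " +
      s"Labels must be integers in range [0, $numClasses).")

requireValidLabel(2.0, 3)           // passes
// requireValidLabel(2.5, 3)        // would throw: 2.5 is not integral
// requireValidLabel(3.0, 3)        // would throw: outside [0, 3)
// requireValidLabel(Double.NaN, 3) // would throw: NaN.toLong != NaN
```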
mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -22,7 +22,7 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit, udf}
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions.{col, udf}
 
 /**
  * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -116,23 +115,16 @@ class DecisionTreeClassifier @Since("1.4.0") (
       dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
-    val categoricalFeatures: Map[Int, Int] =
-      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = getNumClasses(dataset)
+    val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+    val numClasses = getNumClasses(dataset)
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
        ".train() called with non-matching numClasses and thresholds.length." +
        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
     validateNumClasses(numClasses)
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          validateLabel(label, numClasses)
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset, numClasses)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
     instr.logNumClasses(numClasses)
     instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
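Compared to the old inline loop, the per-row `validateLabel` call now happens inside `extractInstances(dataset, numClasses)`. Since the validating overload is just a `map` over the base extraction, no extra Spark stage is introduced — a generic sketch of that composition:

```scala
import org.apache.spark.rdd.RDD

// Sketch: thread a side-effecting check through the same narrow map that
// builds the RDD; chained maps stay within one stage.
def withValidation[T](rdd: RDD[T])(validate: T => Unit): RDD[T] =
  rdd.map { x => validate(x); x }
```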
mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -36,9 +36,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
 
 /** Params for linear SVM Classifier. */
 private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
@@ -161,12 +159,7 @@ class LinearSVC @Since("2.2.0") (
   override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
 
   override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr =>
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset)
 
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -40,9 +40,8 @@ import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils
@@ -492,12 +491,7 @@ class LogisticRegression @Since("1.2.0") (
   protected[spark] def train(
       dataset: Dataset[_],
       handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset)
 
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
 import org.apache.spark.ml.feature.OneHotEncoderModel
-import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -21,14 +21,15 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.functions.col
 
 /**
  * Params for Naive Bayes Classifiers.
@@ -137,35 +138,30 @@ class NaiveBayes @Since("1.5.0") (
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
 
-    val modelTypeValue = $(modelType)
-    val requireValues: Vector => Unit = {
-      modelTypeValue match {
-        case Multinomial =>
-          requireNonnegativeValues
-        case Bernoulli =>
-          requireZeroOneBernoulliValues
-        case _ =>
-          // This should never happen.
-          throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
-      }
-    }
+    val validateInstance = $(modelType) match {
+      case Multinomial =>
+        (instance: Instance) => requireNonnegativeValues(instance.features)
+      case Bernoulli =>
+        (instance: Instance) => requireZeroOneBernoulliValues(instance.features)
+      case _ =>
+        // This should never happen.
+        throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
+    }
 
     instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
       probabilityCol, modelType, smoothing, thresholds)
 
     val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
     instr.logNumFeatures(numFeatures)
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
 
     // Aggregates term frequencies per label.
     // TODO: Calling aggregateByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
-    val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
-      .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
-      }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
+    val aggregated = extractInstances(dataset, validateInstance).map { instance =>
+      (instance.label, (instance.weight, instance.features))
+    }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
      seqOp = {
        case ((weightSum, featureSum, count), (weight, features)) =>
-          requireValues(features)
          BLAS.axpy(weight, features, featureSum)
          (weightSum + weight, featureSum, count + 1)
      },
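The refactored NaiveBayes selects the per-instance check once, keyed on `modelType`, instead of re-matching (via `requireValues`) inside the aggregation's `seqOp`. A stand-in sketch of that dispatch — Spark's `requireNonnegativeValues` and `requireZeroOneBernoulliValues` are private, so simple per-value checks stand in here:

```scala
// Sketch only: pick the feature-value check once per training run, then apply
// it per instance, mirroring the validateInstance dispatch above.
def featureCheck(modelType: String): Double => Unit = modelType match {
  case "multinomial" =>
    v => require(v >= 0.0, s"Naive Bayes requires nonnegative values but found $v.")
  case "bernoulli" =>
    v => require(v == 0.0 || v == 1.0, s"Bernoulli requires 0/1 values but found $v.")
  case other =>
    // Mirrors the "should never happen" guard above.
    throw new IllegalArgumentException(s"Invalid modelType: $other.")
}
```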
mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
 
 
 /**
@@ -118,12 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
       dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset)
     val strategy = getOldStrategy(categoricalFeatures)
 
     instr.logPipelineStage(this)
mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -43,7 +43,6 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel}
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -320,13 +319,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String)
   override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr =>
     // Extract the number of features before deciding optimization solver.
     val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
 
-    val instances: RDD[Instance] = dataset.select(
-      col($(labelCol)), w, col($(featuresCol))).rdd.map {
-      case Row(label: Double, weight: Double, features: Vector) =>
-        Instance(label, weight, features)
-    }
+    val instances = extractInstances(dataset)
 
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
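As the description says, none of this is user-facing: weighting is still configured through the public `setWeightCol`, which `extractInstances` resolves during `fit`. For example, with the standard public API:

```scala
import org.apache.spark.ml.regression.LinearRegression

// Public API, unchanged by this commit. The weight column named here is what
// extractInstances resolves during training; when no weight column is set,
// every instance is weighted 1.0.
val lr = new LinearRegression()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setWeightCol("weight")
// val model = lr.fit(trainingDf)  // trainingDf: DataFrame with these columns
```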
