-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-7752][MLLIB] Use lowercase letters for NaiveBayes.modelType #6277
Changes from 4 commits
17bba53
3c456a8
264a814
40ae53e
711d1c6
ae5c66a
f38b662
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,21 +25,20 @@ import org.json4s.JsonDSL._ | |
import org.json4s.jackson.JsonMethods._ | ||
|
||
import org.apache.spark.{Logging, SparkContext, SparkException} | ||
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors} | ||
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} | ||
import org.apache.spark.mllib.regression.LabeledPoint | ||
import org.apache.spark.mllib.util.{Loader, Saveable} | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.sql.{DataFrame, SQLContext} | ||
|
||
|
||
/** | ||
* Model for Naive Bayes Classifiers. | ||
* | ||
* @param labels list of labels | ||
* @param pi log of class priors, whose dimension is C, number of labels | ||
* @param theta log of class conditional probabilities, whose dimension is C-by-D, | ||
* where D is number of features | ||
* @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli" | ||
* @param modelType The type of NB model to fit can be "multinomial" or "bernoulli" | ||
*/ | ||
class NaiveBayesModel private[mllib] ( | ||
val labels: Array[Double], | ||
|
@@ -48,11 +47,13 @@ class NaiveBayesModel private[mllib] ( | |
val modelType: String) | ||
extends ClassificationModel with Serializable with Saveable { | ||
|
||
import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes} | ||
|
||
private val piVector = new DenseVector(pi) | ||
private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true) | ||
private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true) | ||
|
||
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = | ||
this(labels, pi, theta, "Multinomial") | ||
this(labels, pi, theta, NaiveBayes.Multinomial) | ||
|
||
/** A Java-friendly constructor that takes three Iterable parameters. */ | ||
private[mllib] def this( | ||
|
@@ -61,12 +62,15 @@ class NaiveBayesModel private[mllib] ( | |
theta: JIterable[JIterable[Double]]) = | ||
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) | ||
|
||
require(supportedModelTypes.contains(modelType), | ||
s"Invalid model type $modelType. Supported model types are $supportedModelTypes.") | ||
|
||
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. | ||
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra | ||
// application of this condition (in predict function). | ||
private val (thetaMinusNegTheta, negThetaSum) = modelType match { | ||
case "Multinomial" => (None, None) | ||
case "Bernoulli" => | ||
case Multinomial => (None, None) | ||
case Bernoulli => | ||
val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value))) | ||
val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0}) | ||
val thetaMinusNegTheta = thetaMatrix.map { value => | ||
|
@@ -75,7 +79,7 @@ class NaiveBayesModel private[mllib] ( | |
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) | ||
case _ => | ||
// This should never happen. | ||
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") | ||
throw new UnknownError(s"Invalid model type: $modelType.") | ||
} | ||
|
||
override def predict(testData: RDD[Vector]): RDD[Double] = { | ||
|
@@ -88,15 +92,15 @@ class NaiveBayesModel private[mllib] ( | |
|
||
override def predict(testData: Vector): Double = { | ||
modelType match { | ||
case "Multinomial" => | ||
case Multinomial => | ||
val prob = thetaMatrix.multiply(testData) | ||
BLAS.axpy(1.0, piVector, prob) | ||
labels(prob.argmax) | ||
case "Bernoulli" => | ||
case Bernoulli => | ||
testData.foreachActive { (index, value) => | ||
if (value != 0.0 && value != 1.0) { | ||
throw new SparkException( | ||
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") | ||
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.") | ||
} | ||
} | ||
val prob = thetaMinusNegTheta.get.multiply(testData) | ||
|
@@ -105,7 +109,7 @@ class NaiveBayesModel private[mllib] ( | |
labels(prob.argmax) | ||
case _ => | ||
// This should never happen. | ||
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") | ||
throw new UnknownError(s"Invalid model type: $modelType.") | ||
} | ||
} | ||
|
||
|
@@ -230,16 +234,16 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { | |
s"($loadedClassName, $version). Supported:\n" + | ||
s" ($classNameV1_0, 1.0)") | ||
} | ||
assert(model.pi.size == numClasses, | ||
assert(model.pi.length == numClasses, | ||
s"NaiveBayesModel.load expected $numClasses classes," + | ||
s" but class priors vector pi had ${model.pi.size} elements") | ||
assert(model.theta.size == numClasses, | ||
s" but class priors vector pi had ${model.pi.length} elements") | ||
assert(model.theta.length == numClasses, | ||
s"NaiveBayesModel.load expected $numClasses classes," + | ||
s" but class conditionals array theta had ${model.theta.size} elements") | ||
assert(model.theta.forall(_.size == numFeatures), | ||
s" but class conditionals array theta had ${model.theta.length} elements") | ||
assert(model.theta.forall(_.length == numFeatures), | ||
s"NaiveBayesModel.load expected $numFeatures features," + | ||
s" but class conditionals array theta had elements of size:" + | ||
s" ${model.theta.map(_.size).mkString(",")}") | ||
s" ${model.theta.map(_.length).mkString(",")}") | ||
model | ||
} | ||
} | ||
|
@@ -257,9 +261,11 @@ class NaiveBayes private ( | |
private var lambda: Double, | ||
private var modelType: String) extends Serializable with Logging { | ||
|
||
def this(lambda: Double) = this(lambda, "Multinomial") | ||
import NaiveBayes.{Bernoulli, Multinomial} | ||
|
||
def this() = this(1.0, "Multinomial") | ||
def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) | ||
|
||
def this() = this(1.0, NaiveBayes.Multinomial) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
|
||
/** Set the smoothing parameter. Default: 1.0. */ | ||
def setLambda(lambda: Double): NaiveBayes = { | ||
|
@@ -272,12 +278,11 @@ class NaiveBayes private ( | |
|
||
/** | ||
* Set the model type using a string (case-sensitive). | ||
* Supported options: "Multinomial" and "Bernoulli". | ||
* (default: Multinomial) | ||
* Supported options: "multinomial" (default) and "bernoulli". | ||
*/ | ||
def setModelType(modelType:String): NaiveBayes = { | ||
def setModelType(modelType: String): NaiveBayes = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As long as you're at it, can you change this to be case-insensitive so it's a little more robust? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried that option in the first version. Making it case-insensitive is easy but it may cause unexpected errors in the user code. nb.setModelType("Bernoulli")
if (nb.getModelType() == "Bernoulli") { // It becomes lowercase here.
...
} I can keep the input untouched but we need to call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair enough |
||
require(NaiveBayes.supportedModelTypes.contains(modelType), | ||
s"NaiveBayes was created with an unknown ModelType: $modelType") | ||
s"NaiveBayes was created with an unknown model type: $modelType.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "model type" --> "modelType" to match parameter name |
||
this.modelType = modelType | ||
this | ||
} | ||
|
@@ -308,7 +313,7 @@ class NaiveBayes private ( | |
} | ||
if (!values.forall(v => v == 0.0 || v == 1.0)) { | ||
throw new SparkException( | ||
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") | ||
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") | ||
} | ||
} | ||
|
||
|
@@ -317,7 +322,7 @@ class NaiveBayes private ( | |
// TODO: similar to reduceByKeyLocally to save one stage. | ||
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)]( | ||
createCombiner = (v: Vector) => { | ||
if (modelType == "Bernoulli") { | ||
if (modelType == Bernoulli) { | ||
requireZeroOneBernoulliValues(v) | ||
} else { | ||
requireNonnegativeValues(v) | ||
|
@@ -352,11 +357,11 @@ class NaiveBayes private ( | |
labels(i) = label | ||
pi(i) = math.log(n + lambda) - piLogDenom | ||
val thetaLogDenom = modelType match { | ||
case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda) | ||
case "Bernoulli" => math.log(n + 2.0 * lambda) | ||
case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) | ||
case Bernoulli => math.log(n + 2.0 * lambda) | ||
case _ => | ||
// This should never happen. | ||
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType") | ||
throw new UnknownError(s"Invalid model type: $modelType.") | ||
} | ||
var j = 0 | ||
while (j < numFeatures) { | ||
|
@@ -375,8 +380,14 @@ class NaiveBayes private ( | |
*/ | ||
object NaiveBayes { | ||
|
||
/** String name for multinomial model type. */ | ||
private[classification] val Multinomial: String = "multinomial" | ||
|
||
/** String name for Bernoulli model type. */ | ||
private[classification] val Bernoulli: String = "bernoulli" | ||
|
||
/* Set of modelTypes that NaiveBayes supports */ | ||
private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli") | ||
private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) | ||
|
||
/** | ||
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs. | ||
|
@@ -406,7 +417,7 @@ object NaiveBayes { | |
* @param lambda The smoothing parameter | ||
*/ | ||
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { | ||
new NaiveBayes(lambda, "Multinomial").run(input) | ||
new NaiveBayes(lambda, Multinomial).run(input) | ||
} | ||
|
||
/** | ||
|
@@ -429,7 +440,7 @@ object NaiveBayes { | |
*/ | ||
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { | ||
require(supportedModelTypes.contains(modelType), | ||
s"NaiveBayes was created with an unknown ModelType: $modelType") | ||
s"NaiveBayes was created with an unknown model type: $modelType.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "model type" --> "modelType" |
||
new NaiveBayes(lambda, modelType).run(input) | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"NaiveBayes.Multinomial" --> "Multinomial"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won't compile, because
this(...)
is not in the scope ofNaiveBayes
.