Skip to content

Commit

Permalink
add case _ back
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed May 20, 2015
1 parent 3c456a8 commit 264a814
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class NaiveBayesModel private[mllib] (
theta: JIterable[JIterable[Double]]) =
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))

require(supportedModelTypes.contains(modelType), s"Invalid model type $modelType.")
require(supportedModelTypes.contains(modelType),
s"Invalid model type $modelType. Supported model types are $supportedModelTypes.")

// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
Expand All @@ -76,6 +77,9 @@ class NaiveBayesModel private[mllib] (
value - math.log(1.0 - math.exp(value))
}
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
case _ =>
// This should never happen.
throw new UnknownError(s"Invalid model type: $modelType.")
}

override def predict(testData: RDD[Vector]): RDD[Double] = {
Expand Down Expand Up @@ -103,6 +107,9 @@ class NaiveBayesModel private[mllib] (
BLAS.axpy(1.0, piVector, prob)
BLAS.axpy(1.0, negThetaSum.get, prob)
labels(prob.argmax)
case _ =>
// This should never happen.
throw new UnknownError(s"Invalid model type: $modelType.")
}
}

Expand Down Expand Up @@ -275,7 +282,7 @@ class NaiveBayes private (
*/
def setModelType(modelType: String): NaiveBayes = {
require(NaiveBayes.supportedModelTypes.contains(modelType),
s"NaiveBayes was created with an unknown ModelType: $modelType.")
s"NaiveBayes was created with an unknown model type: $modelType.")
this.modelType = modelType
this
}
Expand Down Expand Up @@ -352,6 +359,9 @@ class NaiveBayes private (
val thetaLogDenom = modelType match {
case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
case Bernoulli => math.log(n + 2.0 * lambda)
case _ =>
// This should never happen.
throw new UnknownError(s"Invalid model type: $modelType.")
}
var j = 0
while (j < numFeatures) {
Expand Down Expand Up @@ -430,7 +440,7 @@ object NaiveBayes {
*/
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
require(supportedModelTypes.contains(modelType),
s"NaiveBayes was created with an unknown ModelType: $modelType")
s"NaiveBayes was created with an unknown model type: $modelType.")
new NaiveBayes(lambda, modelType).run(input)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
validateModelFit(pi, theta, model)

val validationData = NaiveBayesSuite.generateNaiveBayesInput(
pi, theta, nPoints, 17, Bernoulli)
pi, theta, nPoints, 17, Multinomial)
val validationRDD = sc.parallelize(validationData, 2)

// Test prediction on RDD.
Expand Down

0 comments on commit 264a814

Please sign in to comment.