
Commit

Added several Java-friendly APIs + unit tests: NaiveBayes, GaussianMixture, LDA, StreamingKMeans, Statistics.corr, params
jkbradley committed Jun 1, 2015
1 parent 3c01568 commit fe6dcfe
Showing 14 changed files with 276 additions and 16 deletions.
20 changes: 9 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -69,14 +69,10 @@ class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
}
}

-/**
- * Creates a param pair with the given value (for Java).
- */
+/** Creates a param pair with the given value (for Java). */
def w(value: T): ParamPair[T] = this -> value

-/**
- * Creates a param pair with the given value (for Scala).
- */
+/** Creates a param pair with the given value (for Scala). */
def ->(value: T): ParamPair[T] = ParamPair(this, value)

override final def toString: String = s"${parent}__$name"
@@ -190,6 +186,7 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean)

def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

+/** Creates a param pair with the given value (for Java). */
override def w(value: Double): ParamPair[Double] = super.w(value)
}

@@ -209,6 +206,7 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean)

def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

+/** Creates a param pair with the given value (for Java). */
override def w(value: Int): ParamPair[Int] = super.w(value)
}

@@ -228,6 +226,7 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean)

def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

+/** Creates a param pair with the given value (for Java). */
override def w(value: Float): ParamPair[Float] = super.w(value)
}

@@ -247,6 +246,7 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean)

def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

+/** Creates a param pair with the given value (for Java). */
override def w(value: Long): ParamPair[Long] = super.w(value)
}

@@ -260,6 +260,7 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isValid

def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

+/** Creates a param pair with the given value (for Java). */
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}

@@ -274,8 +275,6 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)

-override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
-
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
}
@@ -291,10 +290,9 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean)
def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)

-override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value)
-
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
-def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray)
+def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] =
+  w(value.asScala.map(_.asInstanceOf[Double]).toArray)
}

/**
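The `w()` alias exists because Scala's `param -> value` syntax for building a `ParamPair` is not expressible in Java. A minimal sketch of what the Java call site looks like, using `LogisticRegression` from spark.ml (standard API, not part of this diff):

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.param.ParamMap;

public class ParamPairSketch {
  public static void main(String[] args) {
    LogisticRegression lr = new LogisticRegression();
    // Each w(value) call builds the same typed ParamPair that `param -> value`
    // builds in Scala, so the pairs can feed a ParamMap or a varargs setter.
    ParamMap paramMap = new ParamMap()
      .put(lr.maxIter().w(20))
      .put(lr.regParam().w(0.1));
    System.out.println(paramMap);
  }
}

The `DoubleArrayParam` change above follows the same pattern one level down: Java callers hold a `java.util.List<java.lang.Double>` of boxed values, so the new overload unboxes into the `Array[Double]` the Scala side expects.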
mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.{Logging, SparkContext, SparkException}
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala

@@ -22,6 +22,7 @@ import scala.collection.mutable.IndexedSeq
import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV}

import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
@@ -188,7 +189,10 @@ class GaussianMixture private (
new GaussianMixtureModel(weights, gaussians)
}

+/** Java-friendly version of [[run()]] */
+def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd)
+
/** Average of dense breeze vectors */
private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
val v = BDV.zeros[Double](x(0).length)
x.foreach(xi => v += xi)
mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

@@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
@@ -46,7 +47,7 @@ import org.apache.spark.sql.{SQLContext, Row}
@Experimental
class GaussianMixtureModel(
val weights: Array[Double],
-  val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{
+  val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {

require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")

@@ -65,6 +66,10 @@ class GaussianMixtureModel(
responsibilityMatrix.map(r => r.indexOf(r.max))
}

+/** Java-friendly version of [[predict()]] */
+def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
+  predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]

/**
* Given the input vectors, return the membership value of each vector
* to all mixture components.
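The two wrappers above give Java callers a complete round trip: `run` unwraps a `JavaRDD<Vector>` to the underlying `RDD`, and `predict` re-wraps the resulting `RDD[Int]` as a `JavaRDD<java.lang.Integer>` (the `asInstanceOf` cast only reinterprets the element type, since the `Int` results are boxed at runtime anyway). A minimal sketch, assuming a local `JavaSparkContext`:

import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.GaussianMixture;
import org.apache.spark.mllib.clustering.GaussianMixtureModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public class GmmSketch {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "GmmSketch");
    JavaRDD<Vector> data = sc.parallelize(Arrays.asList(
      Vectors.dense(1.0, 2.0), Vectors.dense(1.1, 2.1), Vectors.dense(9.0, 8.0)));
    // run() now accepts the JavaRDD directly; no .rdd()/ClassTag plumbing.
    GaussianMixtureModel model = new GaussianMixture().setK(2).run(data);
    JavaRDD<Integer> clusters = model.predict(data);  // boxed cluster indices
    System.out.println(clusters.collect());
    sc.stop();
  }
}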
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala

@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}

import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
import org.apache.spark.rdd.RDD
@@ -345,6 +346,12 @@ class DistributedLDAModel private (
}
}

+/** Java-friendly version of [[topicDistributions]] */
+def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = {
+  new JavaPairRDD[java.lang.Long, Vector](
+    topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
+}

// TODO:
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???

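`topicDistributions` is an `RDD[(Long, Vector)]`, which Java sees as the unwieldy `RDD<Tuple2<Object, Vector>>`; wrapping it in a `JavaPairRDD` with a boxed `Long` key restores a usable signature. A hedged sketch of the call site, assuming a `DistributedLDAModel` trained elsewhere (e.g. as in JavaLDASuite below):

import scala.Tuple2;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.mllib.clustering.DistributedLDAModel;
import org.apache.spark.mllib.linalg.Vector;

public class TopicDistSketch {
  // Print the topic mixture inferred for a few training documents.
  static void inspect(DistributedLDAModel model) {
    JavaPairRDD<Long, Vector> topicDists = model.javaTopicDistributions();
    for (Tuple2<Long, Vector> doc : topicDists.take(3)) {
      System.out.println("doc " + doc._1() + " -> " + doc._2());
    }
  }
}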
mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala

@@ -21,8 +21,10 @@ import scala.reflect.ClassTag

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaSparkContext._
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream}
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -234,6 +236,9 @@ class StreamingKMeans(
}
}

+/** Java-friendly version of `trainOn`. */
+def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream)

/**
* Use the clustering model to make predictions on batches of data from a DStream.
*
@@ -245,6 +250,11 @@
data.map(model.predict)
}

+/** Java-friendly version of `predictOn`. */
+def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = {
+  JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]])
+}

/**
* Use the model to make predictions on the values of a DStream and carry over its keys.
*
@@ -257,6 +267,14 @@
data.mapValues(model.predict)
}

+/** Java-friendly version of `predictOnValues`. */
+def predictOnValues[K](
+    data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = {
+  implicit val tag = fakeClassTag[K]
+  JavaPairDStream.fromPairDStream(
+    predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]])
+}

/** Check whether cluster centers have been initialized. */
private[this] def assertInitialized(): Unit = {
if (model.clusterCenters == null) {
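A sketch of how the three wrappers compose from Java. Only the `trainOn`/`predictOnValues` calls are APIs from this commit; the input streams (`training`, `testPoints`) are assumed to be built elsewhere (socket, queue, Kafka, ...):

import org.apache.spark.mllib.clustering.StreamingKMeans;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;

public class StreamingKMeansSketch {
  static void wire(JavaDStream<Vector> training,
                   JavaPairDStream<String, Vector> testPoints) {
    StreamingKMeans model = new StreamingKMeans()
      .setK(2)
      .setDecayFactor(1.0)
      .setRandomCenters(3, 0.0, 42L);  // dim 3, initial weight 0.0, fixed seed

    model.trainOn(training);           // update cluster centers on each batch
    JavaPairDStream<String, Integer> assignments =
      model.predictOnValues(testPoints);  // keys (e.g. point ids) pass through
    assignments.print();
  }
}

The `fakeClassTag[K]` in `predictOnValues` mirrors the trick used throughout Spark's Java API: Java callers cannot supply a `ClassTag`, and for wrapping purposes an erased one suffices.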
mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala

@@ -18,6 +18,7 @@
package org.apache.spark.mllib.stat

import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg.{Matrix, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -80,6 +81,10 @@ object Statistics {
*/
def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)

+/** Java-friendly version of [[corr()]] */
+def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double =
+  corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]])

/**
* Compute the correlation for the input RDDs using the specified method.
* Methods currently supported: `pearson` (default), `spearman`.
@@ -96,6 +101,9 @@
*/
def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)

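+/** Java-friendly version of [[corr()]] */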
+def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double =
+  corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method)

/**
* Conduct Pearson's chi-squared goodness of fit test of the observed data against the
* expected distribution.
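`corr` is the same story in miniature: Scala wants `RDD[Double]`, Java holds boxed `java.lang.Double` values, and after erasure the two element layouts coincide, so the cast suffices. A minimal sketch, assuming a local `JavaSparkContext`:

import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.stat.Statistics;

public class CorrSketch {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "CorrSketch");
    JavaRDD<Double> x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
    JavaRDD<Double> y = sc.parallelize(Arrays.asList(2.1, 3.9, 6.0, 8.2));
    double pearson = Statistics.corr(x, y);              // default: "pearson"
    double spearman = Statistics.corr(x, y, "spearman"); // explicit method
    System.out.println("pearson=" + pearson + ", spearman=" + spearman);
    sc.stop();
  }
}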
mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java

@@ -50,6 +50,7 @@ public void testParams() {
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
Assert.assertEquals(testParams.getMyStringParam(), "a");
+Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0}, 0.0);
}

@Test
mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java

@@ -72,15 +72,31 @@ public JavaTestParams setMyStringParam(String value) {
set(myStringParam_, value); return this;
}

+private DoubleArrayParam myDoubleArrayParam_;
+public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
+
+public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }
+
+public JavaTestParams setMyDoubleArrayParam(double[] value) {
+  set(myDoubleArrayParam_, value); return this;
+}

private void init() {
-myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
+myIntParam_ = new IntParam(this, "myIntParam", "this is an int param",
+  ParamValidators.gt(0));
myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param",
ParamValidators.inRange(0.0, 1.0));
List<String> validStrings = Lists.newArrayList("a", "b");
myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
ParamValidators.inArray(validStrings));
+myDoubleArrayParam_ =
+  new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");

-setDefault(myIntParam_, 1);
+setDefault(myIntParam_.w(1));
-setDefault(myDoubleParam_, 0.5);
+setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
+setDefault(myDoubleArrayParam_, new double[]{1.0, 2.0});
+setDefault(myDoubleArrayParam_.w(new double[]{1.0, 2.0}));
}
}
mllib/src/test/java/org/apache/spark/{ml => mllib}/classification/JavaStreamingLogisticRegressionSuite.java

@@ -15,7 +15,7 @@
* limitations under the License.
*/

-package org.apache.spark.ml.classification;
+package org.apache.spark.mllib.classification;

import java.io.Serializable;
import java.util.List;
@@ -28,7 +28,6 @@
import org.junit.Test;

import org.apache.spark.SparkConf;
-import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.clustering;

import java.io.Serializable;
import java.util.List;

import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import static org.junit.Assert.assertEquals;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public class JavaGaussianMixtureSuite implements Serializable {
private transient JavaSparkContext sc;

@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaGaussianMixture");
}

@After
public void tearDown() {
sc.stop();
sc = null;
}

@Test
public void runGaussianMixture() {
List<Vector> points = Lists.newArrayList(
Vectors.dense(1.0, 2.0, 6.0),
Vectors.dense(1.0, 3.0, 0.0),
Vectors.dense(1.0, 4.0, 6.0)
);

JavaRDD<Vector> data = sc.parallelize(points, 2);
GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
.run(data);
assertEquals(model.gaussians().length, 2);
JavaRDD<Integer> predictions = model.predict(data);
predictions.first();
}
}
mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java

@@ -107,6 +107,10 @@ public void distributedLDAModel() {
// Check: log probabilities
assert(model.logLikelihood() < 0.0);
assert(model.logPrior() < 0.0);

+// Check: topic distributions
+JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
+assertEquals(topicDistributions.count(), corpus.count());
}

@Test