Some changes and tests
MechCoder committed Jun 18, 2015
1 parent d8b066a commit 4b1481f
Showing 2 changed files with 78 additions and 5 deletions.
43 changes: 38 additions & 5 deletions python/pyspark/mllib/clustering.py
@@ -28,9 +28,10 @@
from pyspark import RDD
from pyspark import SparkContext
from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector
from pyspark.mllib.stat.distribution import MultivariateGaussian
from pyspark.mllib.util import Saveable, Loader, inherit_doc
from pyspark.streaming import DStream

__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']

@@ -269,14 +270,46 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
class StreamingKMeansModel(KMeansModel):
"""
.. note:: Experimental
>>> initCenters, initWeights = [[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0]
>>> stkm = StreamingKMeansModel(initCenters, initWeights)
>>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1],
... [0.9, 0.9], [1.1, 1.1]])
>>> stkm = stkm.update(data, 1.0, "batches")
>>> stkm.centers
array([[ 0., 0.],
[ 1., 1.]])
>>> stkm.predict([-0.1, -0.1]) == stkm.predict([0.1, 0.1]) == 0
True
>>> stkm.predict([0.9, 0.9]) == stkm.predict([1.1, 1.1]) == 1
True
>>> stkm.getClusterWeights
[3.0, 3.0]
>>> decayFactor = 0.0
>>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])])
>>> stkm = stkm.update(data, decayFactor, "batches")
>>> stkm.centers
array([[ 0.2, 0.2],
[ 1.5, 1.5]])
>>> stkm.getClusterWeights
[1.0, 1.0]
>>> stkm.predict([0.2, 0.2])
0
>>> stkm.predict([1.5, 1.5])
1
"""
def __init__(self, clusterCenters, clusterWeights):
super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
self._clusterWeights = list(clusterWeights)

@property
def getClusterWeights(self):
return self._clusterWeights

def update(self, data, decayFactor, timeUnit):
if not isinstance(data, RDD):
raise TypeError("data should be of a RDD, got %s." % type(data))
data = data.map(_convert_to_vector)
decayFactor = float(decayFactor)
if timeUnit not in ["batches", "points"]:
raise ValueError(
@@ -306,7 +339,7 @@ def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"):
def _validate(self, dstream):
if self.model is None:
raise ValueError(
"Initial centers should be set either by setInitialCenters ")
"Initial centers should be set either by setInitialCenters "
"or setRandomCenters.")
if not isinstance(dstream, DStream):
raise TypeError(
@@ -342,18 +375,18 @@ def trainOn(self, dstream):

def update(_, rdd):
if rdd:
self.model = self.model.update(rdd)
self.model = self.model.update(rdd, self._decayFactor, self._timeUnit)

dstream.foreachRDD(update)
return self

def predictOn(self, dstream):
self._validate(dstream)
dstream.map(model.predict)
return dstream.map(self.model.predict)

def predictOnValues(self, dstream):
self._validate(dstream)
dstream.mapValues(model.predict)
return dstream.mapValues(self.model.predict)


def _test():
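
Note on the update rule: the Python update() above only converts its arguments and then delegates to the JVM implementation through callMLlibFunc, so the decay math itself is not visible in this diff. As a rough, self-contained sketch of the forgetful update rule that MLlib's StreamingKMeans documents (and that the doctest numbers above are consistent with), the helper below is illustrative only and not part of the patch:

# Illustrative sketch (not part of this commit): the forgetful update rule
# applied on the JVM side for one batch of points.
import numpy as np

def streaming_kmeans_update(centers, weights, points, decayFactor, timeUnit="batches"):
    centers = [np.asarray(c, dtype=float) for c in centers]
    weights = [float(w) for w in weights]
    points = [np.asarray(p, dtype=float) for p in points]

    # With timeUnit="points" the decay is applied per point rather than per batch.
    discount = decayFactor if timeUnit == "batches" else decayFactor ** len(points)

    # Assign each point to its closest center and accumulate per-cluster sums.
    sums = [np.zeros_like(c) for c in centers]
    counts = [0.0] * len(centers)
    for p in points:
        j = int(np.argmin([np.sum((p - c) ** 2) for c in centers]))
        sums[j] += p
        counts[j] += 1.0

    # c_{t+1} = (c_t * n_t * a + sum(x)) / (n_t * a + m_t);  n_{t+1} = n_t * a + m_t
    for j in range(len(centers)):
        n, m = weights[j] * discount, counts[j]
        if n + m > 0:
            centers[j] = (centers[j] * n + sums[j]) / (n + m)
        weights[j] = n + m
    return centers, weights

With decayFactor=0.0, as in the second doctest update, the old weights are discounted to zero and the new batch completely overrides the previous state, which is exactly what the [0.2, 0.2] / [1.5, 1.5] centers and the [1.0, 1.0] weights show.
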
40 changes: 40 additions & 0 deletions python/pyspark/mllib/tests.py
@@ -38,6 +38,7 @@

from pyspark import SparkContext
from pyspark.mllib.common import _to_java_object_rdd
from pyspark.mllib.clustering import StreamingKMeans
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
from pyspark.mllib.regression import LabeledPoint
@@ -48,6 +49,7 @@
from pyspark.mllib.feature import StandardScaler
from pyspark.mllib.feature import ElementwiseProduct
from pyspark.serializers import PickleSerializer
from pyspark.streaming import StreamingContext
from pyspark.sql import SQLContext

_have_scipy = False
@@ -863,6 +865,44 @@ def test_model_transform(self):
eprod.transform(sparsevec), SparseVector(3, [0], [3]))


class StreamingKMeansTest(MLlibTestCase):
def test_model_params(self):
stkm = StreamingKMeans()
stkm.setK(5).setDecayFactor(0.0)
self.assertEqual(stkm._k, 5)
self.assertEqual(stkm._decayFactor, 0.0)

# Model not set yet.
self.assertIsNone(stkm.model)
self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0])

stkm.setInitialCenters([[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0])
self.assertEqual(stkm.model.centers, [[0.0, 0.0], [1.0, 1.0]])
self.assertEqual(stkm.model.getClusterWeights, [1.0, 1.0])

def test_model(self):
stkm = StreamingKMeans()
initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]]
weights = [1.0, 1.0, 1.0, 1.0]
stkm.setInitialCenters(initCenters, weights)

offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]]
batches = []

for offset in offsets:
batches.append([[offset[0] + center[0], offset[1] + center[1]]
for center in initCenters])

batches = [self.sc.parallelize(batch, 1) for batch in batches]
ssc = StreamingContext(self.sc, 2.0)
input_stream = ssc.queueStream(batches)
stkm.trainOn(input_stream)
ssc.start()
finalModel = stkm.model
self.assertEqual(finalModel.centers, initCenters)
# self.assertEqual(finalModel.getClusterWeights, [5.0, 5.0, 5.0, 5.0])


if __name__ == "__main__":
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
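
For context, here is a minimal end-to-end sketch of how the new classes are meant to be wired to a StreamingContext, in the same spirit as the queueStream-based test above; the application name, queue contents, and batch interval are illustrative choices, not taken from the patch:

# Minimal usage sketch (illustrative; mirrors the queueStream-based test above).
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.mllib.clustering import StreamingKMeans

sc = SparkContext(appName="StreamingKMeansSketch")
ssc = StreamingContext(sc, 2.0)  # two-second batches, as in the test

model = StreamingKMeans(k=2, decayFactor=1.0, timeUnit="batches")
model.setInitialCenters([[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0])

# Any DStream of vectors works; queueStream keeps the example self-contained.
training_stream = ssc.queueStream([sc.parallelize([[0.1, 0.1], [0.9, 0.9]], 1)])

model.trainOn(training_stream)                  # updates model.model after each batch
predictions = model.predictOn(training_stream)  # DStream of cluster indices
predictions.pprint()

ssc.start()
ssc.awaitTermination(10)
ssc.stop()

Because trainOn and predictOn only register operations on the DStream, nothing runs until ssc.start() kicks off the streaming job.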
