Skip to content

Commit

Permalink
Revert "[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is …
Browse files Browse the repository at this point in the history
…not None"

This reverts commit fcd013c.

Author: Yin Huai <[email protected]>

Closes #10632 from yhuai/pythonStyle.

(cherry picked from commit e5cde7a)
Signed-off-by: Yin Huai <[email protected]>
  • Loading branch information
yhuai committed Jan 7, 2016
1 parent bc39775 commit d491464
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 13 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
if initialModel.k != k:
raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s"
% (initialModel.k, k))
initialModelWeights = list(initialModel.weights)
initialModelWeights = initialModel.weights
initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
weight, mu, sigma = callMLlibFunc("trainGaussianMixture", rdd.map(_convert_to_vector), k,
Expand Down
12 changes: 0 additions & 12 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,18 +310,6 @@ def test_gmm_deterministic(self):
for c1, c2 in zip(clusters1.weights, clusters2.weights):
self.assertEquals(round(c1, 7), round(c2, 7))

def test_gmm_with_initial_model(self):
from pyspark.mllib.clustering import GaussianMixture
data = self.sc.parallelize([
(-10, -5), (-9, -4), (10, 5), (9, 4)
])

gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
maxIterations=10, seed=63)
gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
maxIterations=10, seed=63, initialModel=gmm1)
self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0)

def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\
Expand Down

0 comments on commit d491464

Please sign in to comment.