From 1eac767215ad5e0967b0f1a1986f0542ee71ec43 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 21 Oct 2014 00:19:11 -0700 Subject: [PATCH] address comments --- python/pyspark/mllib/stat.py | 4 ++-- python/pyspark/mllib/tests.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 5e4767c792f4f..84baf12b906df 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -170,8 +170,8 @@ def corr(x, y=None, method=None): ser = PickleSerializer() return ser.loads(str(bytes)).toArray() else: - jx = _to_java_object_rdd(x) - jy = _to_java_object_rdd(y) + jx = _to_java_object_rdd(x.map(float)) + jy = _to_java_object_rdd(y.map(float)) return sc._jvm.PythonMLLibAPI().corr(jx, jy, method) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 6b7fcac47fd0b..d6fb87b378b4a 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -206,12 +206,19 @@ def test_regression(self): class StatTests(PySparkTestCase): # SPARK-4023 - def test_col_with_random_rdd(self): + def test_col_with_different_rdds(self): + # numpy data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) summary = Statistics.colStats(data) self.assertEqual(1000, summary.count()) - mean = summary.mean() - self.assertTrue(all(abs(v) < 0.1 for v in mean)) + # array + data = self.sc.parallelize([range(10)] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + # array + data = self.sc.parallelize([pyarray.array("d", range(10))] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) @unittest.skipIf(not _have_scipy, "SciPy not installed")