Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Oct 21, 2014
1 parent 0871576 commit 1eac767
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/mllib/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def corr(x, y=None, method=None):
ser = PickleSerializer()
return ser.loads(str(bytes)).toArray()
else:
jx = _to_java_object_rdd(x)
jy = _to_java_object_rdd(y)
jx = _to_java_object_rdd(x.map(float))
jy = _to_java_object_rdd(y.map(float))
return sc._jvm.PythonMLLibAPI().corr(jx, jy, method)


Expand Down
13 changes: 10 additions & 3 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,19 @@ def test_regression(self):

class StatTests(PySparkTestCase):
# SPARK-4023
def test_col_with_random_rdd(self):
def test_col_with_different_rdds(self):
# numpy
data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
summary = Statistics.colStats(data)
self.assertEqual(1000, summary.count())
mean = summary.mean()
self.assertTrue(all(abs(v) < 0.1 for v in mean))
# array
data = self.sc.parallelize([range(10)] * 10)
summary = Statistics.colStats(data)
self.assertEqual(10, summary.count())
# array
data = self.sc.parallelize([pyarray.array("d", range(10))] * 10)
summary = Statistics.colStats(data)
self.assertEqual(10, summary.count())


@unittest.skipIf(not _have_scipy, "SciPy not installed")
Expand Down

0 comments on commit 1eac767

Please sign in to comment.