-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[SPARK-7241] Pearson correlation for DataFrames #5858
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,32 @@ import org.apache.spark.sql.execution.stat._ | |
@Experimental | ||
final class DataFrameStatFunctions private[sql](df: DataFrame) { | ||
|
||
/** | ||
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson | ||
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in | ||
* MLlib's Statistics. | ||
* | ||
* @param col1 the name of the column | ||
* @param col2 the name of the column to calculate the correlation against | ||
* @return The Pearson Correlation Coefficient as a Double. | ||
*/ | ||
def corr(col1: String, col2: String, method: String): Double = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto. Accept |
||
require(method == "pearson", "Currently only the calculation of the Pearson Correlation " + | ||
"coefficient is supported.") | ||
StatFunctions.pearsonCorrelation(df, Seq(col1, col2)) | ||
} | ||
|
||
/** | ||
* Calculates the Pearson Correlation Coefficient of two columns of a DataFrame. | ||
* | ||
* @param col1 the name of the column | ||
* @param col2 the name of the column to calculate the correlation against | ||
* @return The Pearson Correlation Coefficient as a Double. | ||
*/ | ||
def corr(col1: String, col2: String): Double = { | ||
corr(col1, col2, "pearson") | ||
} | ||
|
||
/** | ||
* Finding frequent items for columns, possibly with false positives. Using the | ||
* frequent element count algorithm described in | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,43 +23,51 @@ import org.apache.spark.sql.types.{DoubleType, NumericType} | |
|
||
private[sql] object StatFunctions { | ||
|
||
/** Calculate the Pearson Correlation Coefficient for the given columns */ | ||
private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { | ||
val counts = collectStatisticalData(df, cols) | ||
counts.Ck / math.sqrt(counts.MkX * counts.MkY) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be the sample correlation as well. In the unit tests, please provide R commands that compute the correlation and the result, and verify that we output the same value. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is, isn't it? The n - 1's cancel. I tested with SciPy's pearsonr method.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, they canceled each other. Could you add a non-trivial test to Scala? Now it only has |
||
} | ||
|
||
/** Helper class to simplify tracking and merging counts. */ | ||
private class CovarianceCounter extends Serializable { | ||
var xAvg = 0.0 | ||
var yAvg = 0.0 | ||
var Ck = 0.0 | ||
var count = 0L | ||
var xAvg = 0.0 // the mean of all examples seen so far in col1 | ||
var yAvg = 0.0 // the mean of all examples seen so far in col2 | ||
var Ck = 0.0 // the co-moment after k examples | ||
var MkX = 0.0 // sum of squares of differences from the (current) mean for col1 | ||
var MkY = 0.0 // sum of squares of differences from the (current) mean for col2 | ||
var count = 0L // count of observed examples | ||
// add an example to the calculation | ||
def add(x: Double, y: Double): this.type = { | ||
val oldX = xAvg | ||
val deltaX = x - xAvg | ||
val deltaY = y - yAvg | ||
count += 1 | ||
xAvg += (x - xAvg) / count | ||
yAvg += (y - yAvg) / count | ||
Ck += (y - yAvg) * (x - oldX) | ||
xAvg += deltaX / count | ||
yAvg += deltaY / count | ||
Ck += deltaX * (y - yAvg) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Umm, we need to use the updated There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I didn't see |
||
MkX += deltaX * (x - xAvg) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
MkY += deltaY * (y - yAvg) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here |
||
this | ||
} | ||
// merge counters from other partitions. Formula can be found at: | ||
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance | ||
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance | ||
def merge(other: CovarianceCounter): this.type = { | ||
val totalCount = count + other.count | ||
Ck += other.Ck + | ||
(xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count | ||
val deltaX = xAvg - other.xAvg | ||
val deltaY = yAvg - other.yAvg | ||
Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count | ||
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount | ||
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount | ||
MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count | ||
MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count | ||
count = totalCount | ||
this | ||
} | ||
// return the sample covariance for the observed examples | ||
def cov: Double = Ck / (count - 1) | ||
} | ||
|
||
/** | ||
* Calculate the covariance of two numerical columns of a DataFrame. | ||
* @param df The DataFrame | ||
* @param cols the column names | ||
* @return the covariance of the two columns. | ||
*/ | ||
private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { | ||
private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = { | ||
require(cols.length == 2, "Currently cov supports calculating the covariance " + | ||
"between two columns.") | ||
cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => | ||
|
@@ -68,13 +76,23 @@ private[sql] object StatFunctions { | |
s"with dataType ${data.get.dataType} not supported.") | ||
} | ||
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) | ||
val counts = df.select(columns:_*).rdd.aggregate(new CovarianceCounter)( | ||
df.select(columns:_*).rdd.aggregate(new CovarianceCounter)( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. space after |
||
seqOp = (counter, row) => { | ||
counter.add(row.getDouble(0), row.getDouble(1)) | ||
}, | ||
combOp = (baseCounter, other) => { | ||
baseCounter.merge(other) | ||
}) | ||
}) | ||
} | ||
|
||
/** | ||
* Calculate the covariance of two numerical columns of a DataFrame. | ||
* @param df The DataFrame | ||
* @param cols the column names | ||
* @return the covariance of the two columns. | ||
*/ | ||
private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { | ||
val counts = collectStatisticalData(df, cols) | ||
counts.cov | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -187,6 +187,13 @@ public void testFrequentItems() { | |
Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); | ||
} | ||
|
||
@Test | ||
public void testCorrelation() { | ||
DataFrame df = context.table("testData2"); | ||
Double pearsonCorr = df.stat().corr("a", "b", "pearson"); | ||
Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The numerical error should be close to machine precision. So let's change |
||
} | ||
|
||
@Test | ||
public void testCovariance() { | ||
DataFrame df = context.table("testData2"); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,10 +30,10 @@ class DataFrameStatSuite extends FunSuite { | |
def toLetter(i: Int): String = (i + 97).toChar.toString | ||
|
||
test("Frequent Items") { | ||
val rows = Array.tabulate(1000) { i => | ||
val rows = Seq.tabulate(1000) { i => | ||
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0) | ||
} | ||
val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles") | ||
val df = rows.toDF("numbers", "letters", "negDoubles") | ||
|
||
val results = df.stat.freqItems(Array("numbers", "letters"), 0.1) | ||
val items = results.collect().head | ||
|
@@ -43,12 +43,18 @@ class DataFrameStatSuite extends FunSuite { | |
val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1) | ||
val items2 = singleColResults.collect().head | ||
items2.getSeq[Double](0) should contain (-1.0) | ||
} | ||
|
||
test("pearson correlation") { | ||
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") | ||
val corr1 = df.stat.corr("a", "b", "pearson") | ||
assert(math.abs(corr1 - 1.0) < 1e-6) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, |
||
val corr2 = df.stat.corr("a", "c", "pearson") | ||
assert(math.abs(corr2 + 1.0) < 1e-6) | ||
} | ||
|
||
test("covariance") { | ||
val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i))) | ||
val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters") | ||
val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters") | ||
|
||
val results = df.stat.cov("singles", "doubles") | ||
assert(math.abs(results - 55.0 / 3) < 1e-6) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accept `Column` as well?

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If @rxin is okay with it, I can add those in a follow up PR for all the methods that we added.