[MLlib] [SPARK-2885] DIMSUM: All-pairs similarity #1778
Changes from all commits
@@ -19,17 +19,21 @@ package org.apache.spark.mllib.linalg.distributed

import java.util.Arrays

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV}
import breeze.linalg.{svd => brzSvd, axpy => brzAxpy}
import scala.collection.mutable.ListBuffer

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy,
  svd => brzSvd}
import breeze.numerics.{sqrt => brzSqrt}
import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.storage.StorageLevel

Review comment: remove extra empty line
Reply: Removed extra empty line

/**
@@ -411,6 +415,165 @@ class RowMatrix(
    new RowMatrix(AB, nRows, B.numCols)
  }

  /**
   * Compute all cosine similarities between columns of this matrix using the brute-force
   * approach of computing normalized dot products.
   *
   * @return An n x n sparse upper-triangular matrix of cosine similarities between
   *         columns of this matrix.
   */
  def columnSimilarities(): CoordinateMatrix = {
    columnSimilarities(0.0)
  }

  /**
   * Compute similarities between columns of this matrix using a sampling approach.
   *
   * The threshold parameter is a trade-off knob between estimate quality and computational cost.
   *
   * Setting a threshold of 0 guarantees deterministic correct results, but comes at exactly
   * the same cost as the brute-force approach. Setting the threshold to positive values
   * incurs strictly less computational cost than the brute-force approach, however the
   * similarities computed will be estimates.
   *
   * The sampling guarantees relative-error correctness for those pairs of columns that have
   * similarity greater than the given similarity threshold.
   *
   * To describe the guarantee, we set some notation:
   * Let A be the smallest in magnitude non-zero element of this matrix.
   * Let B be the largest in magnitude non-zero element of this matrix.
   * Let L be the maximum number of non-zeros per row.
   *
   * For example, for {0,1} matrices: A=B=1.
   * Another example, for the Netflix matrix: A=1, B=5.
   *
   * For those column pairs that are above the threshold,
   * the computed similarity is correct to within 20% relative error with probability
   * at least 1 - (0.981)^10/B^
   *
   * The shuffle size is bounded by the *smaller* of the following two expressions:
   *
   * O(n log(n) L / (threshold * A))
   * O(m L^2^)
   *
   * The latter is the cost of the brute-force approach, so for non-zero thresholds,
   * the cost is always cheaper than the brute-force approach.
   *
   * @param threshold Set to 0 for deterministic guaranteed correctness.
   *                  Similarities above this threshold are estimated
   *                  with the cost vs estimate quality trade-off described above.
   * @return An n x n sparse upper-triangular matrix of cosine similarities
   *         between columns of this matrix.
   */
  def columnSimilarities(threshold: Double): CoordinateMatrix = {
    require(threshold >= 0, s"Threshold cannot be negative: $threshold")

    if (threshold > 1) {
      logWarning(s"Threshold is greater than 1: $threshold " +
        "Computation will be more efficient with promoted sparsity, " +
        "however there is no correctness guarantee.")
    }

    val gamma = if (threshold < 1e-6) {
      Double.PositiveInfinity
    } else {
      10 * math.log(numCols()) / threshold
    }

    columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma)
  }
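
Aside (not part of the patch): a minimal usage sketch of the two new overloads as they might be called from user code. The input rows, the threshold value, and the existing SparkContext `sc` are assumptions made for illustration only.

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix}

// Hypothetical 4 x 3 input; each element of the RDD is one row of the matrix.
// `sc` is assumed to be an existing SparkContext.
val rows = sc.parallelize(Seq(
  Vectors.dense(0.0, 1.0, 2.0),
  Vectors.dense(3.0, 4.0, 5.0),
  Vectors.dense(6.0, 7.0, 8.0),
  Vectors.dense(9.0, 0.0, 1.0)))
val mat = new RowMatrix(rows)

// Brute-force (exact) cosine similarities between columns.
val exact = mat.columnSimilarities()

// Sampling-based estimates: only pairs whose similarity exceeds the
// threshold carry the relative-error guarantee described in the scaladoc.
val approx = mat.columnSimilarities(threshold = 0.1)

approx.entries.collect().foreach { case MatrixEntry(i, j, sim) =>
  println(s"columns ($i, $j): $sim")
}
```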

  /**
   * Find all similar columns using the DIMSUM sampling algorithm, described in two papers
   *
   * http://arxiv.org/abs/1206.2082
   * http://arxiv.org/abs/1304.1467
   *
   * @param colMags A vector of column magnitudes
   * @param gamma The oversampling parameter. For provable results, set to 10 * log(n) / s,
   *              where s is the smallest similarity score to be estimated,
   *              and n is the number of columns
   * @return An n x n sparse upper-triangular matrix of cosine similarities
   *         between columns of this matrix.
   */
  private[mllib] def columnSimilaritiesDIMSUM(
      colMags: Array[Double],
      gamma: Double): CoordinateMatrix = {
    require(gamma > 1.0, s"Oversampling should be greater than 1: $gamma")
    require(colMags.size == this.numCols(), "Number of magnitudes didn't match column dimension")
    val sg = math.sqrt(gamma) // sqrt(gamma) used many times

    // Don't divide by zero for those columns with zero magnitude
    val colMagsCorrected = colMags.map(x => if (x == 0) 1.0 else x)

    val sc = rows.context
    val pBV = sc.broadcast(colMagsCorrected.map(c => sg / c))
    val qBV = sc.broadcast(colMagsCorrected.map(c => math.min(sg, c)))

    val sims = rows.mapPartitionsWithIndex { (indx, iter) =>
      val p = pBV.value
      val q = qBV.value

      val rand = new XORShiftRandom(indx)
      val scaled = new Array[Double](p.size)
      iter.flatMap { row =>
        val buf = new ListBuffer[((Int, Int), Double)]()
        row match {
          case sv: SparseVector =>
            val nnz = sv.indices.size
            var k = 0
            while (k < nnz) {
              scaled(k) = sv.values(k) / q(sv.indices(k))
              k += 1
            }
            k = 0
            while (k < nnz) {
              val i = sv.indices(k)
              val iVal = scaled(k)
              if (iVal != 0 && rand.nextDouble() < p(i)) {
                var l = k + 1
                while (l < nnz) {
                  val j = sv.indices(l)
                  val jVal = scaled(l)
                  if (jVal != 0 && rand.nextDouble() < p(j)) {
                    buf += (((i, j), iVal * jVal))
                  }
                  l += 1
                }
              }
              k += 1
            }
          case dv: DenseVector =>
            val n = dv.values.size
            var i = 0
            while (i < n) {
              scaled(i) = dv.values(i) / q(i)
              i += 1
            }
            i = 0
            while (i < n) {
              val iVal = scaled(i)
              if (iVal != 0 && rand.nextDouble() < p(i)) {
                var j = i + 1
                while (j < n) {
                  val jVal = scaled(j)
                  if (jVal != 0 && rand.nextDouble() < p(j)) {
                    buf += (((i, j), iVal * jVal))
                  }
                  j += 1
                }
              }
              i += 1
            }
        }
        buf
      }
    }.reduceByKey(_ + _).map { case ((i, j), sim) =>
      MatrixEntry(i.toLong, j.toLong, sim)
    }
    new CoordinateMatrix(sims, numCols(), numCols())
  }

  private[mllib] override def toBreeze(): BDM[Double] = {
    val m = numRows().toInt
    val n = numCols().toInt
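
Aside (not in the diff): a small, self-contained sketch of the per-column quantities the map side of columnSimilaritiesDIMSUM works with. Each column j is kept with probability proportional to sqrt(gamma) / ||c_j||, and its entries are pre-divided by min(sqrt(gamma), ||c_j||). The column magnitudes and threshold below are made-up values.

```scala
// Made-up values for illustration only.
val colMags = Array(math.sqrt(126), math.sqrt(66), math.sqrt(94))
val threshold = 0.5
val numCols = colMags.length

// Oversampling parameter, as in columnSimilarities(threshold): gamma = 10 * log(n) / threshold.
val gamma = 10 * math.log(numCols) / threshold
val sg = math.sqrt(gamma)

// Per-column sampling probability and scaling divisor, mirroring pBV and qBV above.
val p = colMags.map(c => sg / c)          // heavier columns are sampled with lower probability
val q = colMags.map(c => math.min(sg, c)) // entries of heavy columns are scaled down by sqrt(gamma)

println(s"gamma = $gamma")
println("p = " + p.mkString(", "))
println("q = " + q.mkString(", "))
```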
@@ -53,4 +53,14 @@ trait MultivariateStatisticalSummary {
   * Minimum value of each column.
   */
  def min: Vector

  /**
   * Euclidean magnitude of each column
   */
  def normL2: Vector

  /**
   * L1 norm of each column
   */
  def normL1: Vector

Review comment: For general vectors and matrices, L1 and L2 norms are widely accepted as summaries of a vector and are standard linear algebra: http://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm

}
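
Aside (not in the diff): with these two additions, per-column norms come out of the existing column-summary path. A brief usage sketch, assuming `mat` is a `RowMatrix` as in the earlier example:

```scala
// `mat` is assumed to be a RowMatrix (see the earlier sketch).
val summary = mat.computeColumnSummaryStatistics()
println(summary.normL2) // Euclidean magnitude of each column
println(summary.normL1) // L1 norm of each column
```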
@@ -95,6 +95,40 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
    }
  }

  test("similar columns") {
    val colMags = Vectors.dense(Math.sqrt(126), Math.sqrt(66), Math.sqrt(94))
    val expected = BDM(
      (0.0, 54.0, 72.0),
      (0.0, 0.0, 78.0),
      (0.0, 0.0, 0.0))

    for (i <- 0 until n; j <- 0 until n) {
      expected(i, j) /= (colMags(i) * colMags(j))
    }

    for (mat <- Seq(denseMat, sparseMat)) {
      val G = mat.columnSimilarities(0.11).toBreeze()
      for (i <- 0 until n; j <- 0 until n) {
        if (expected(i, j) > 0) {
          val actual = expected(i, j)
          val estimate = G(i, j)
          assert(math.abs(actual - estimate) / actual < 0.2,
            s"Similarities not close enough: $actual vs $estimate")
        }
      }
    }

    for (mat <- Seq(denseMat, sparseMat)) {
      val G = mat.columnSimilarities()
      assert(closeToZero(G.toBreeze() - expected))
    }

    for (mat <- Seq(denseMat, sparseMat)) {
      val G = mat.columnSimilaritiesDIMSUM(colMags.toArray, 150.0)
      assert(closeToZero(G.toBreeze() - expected))
    }
  }

Review comment: There is no test for obtaining partial similar pairs. Do you mind adding one?
Reply: Added test for partial similar pairs.
Review comment: Does the test output only a subset of column pairs?
Reply: The test is estimating some column pairs, i.e. I looked at the output and checked that some pairs don't have their similarity perfectly computed.
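
Aside (not part of the suite): a quick sanity check of the `expected` values in the test above. The raw entries are column dot products, which the loop divides by the product of column magnitudes:

```scala
val colMags = Array(math.sqrt(126), math.sqrt(66), math.sqrt(94))
val dots = Seq((0, 1, 54.0), (0, 2, 72.0), (1, 2, 78.0))
dots.foreach { case (i, j, d) =>
  println(f"cos($i, $j) = ${d / (colMags(i) * colMags(j))}%.3f")
}
// cos(0, 1) ≈ 0.592, cos(0, 2) ≈ 0.662, cos(1, 2) ≈ 0.990
```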
test("svd of a full-rank matrix") { | ||
for (mat <- Seq(denseMat, sparseMat)) { | ||
for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { | ||
|
@@ -190,6 +224,9 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { | |
assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch") | ||
assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch") | ||
assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "column mismatch.") | ||
assert(summary.normL2 === Vectors.dense(Math.sqrt(126), Math.sqrt(66), Math.sqrt(94)), | ||
"magnitude mismatch.") | ||
assert(summary.normL1 === Vectors.dense(18.0, 12.0, 16.0), "L1 norm mismatch") | ||
} | ||
} | ||
} | ||
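
Aside (not part of the suite): the new norm assertions can be checked by hand. Assuming the suite's 4 x 3 test matrix has rows (0,1,2), (3,4,5), (6,7,8), (9,0,1) — an assumption, though it is consistent with the nnz, max, and min assertions above — the column norms follow directly:

```scala
// Assumed test data (not stated in this diff).
val rows = Seq(
  Array(0.0, 1.0, 2.0),
  Array(3.0, 4.0, 5.0),
  Array(6.0, 7.0, 8.0),
  Array(9.0, 0.0, 1.0))
val l1 = (0 until 3).map(j => rows.map(r => math.abs(r(j))).sum)         // 18.0, 12.0, 16.0
val l2 = (0 until 3).map(j => math.sqrt(rows.map(r => r(j) * r(j)).sum)) // sqrt(126), sqrt(66), sqrt(94)
```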

Review comment: separate scala imports from java ones
Reply: Separated