Skip to content

Commit

Permalink
[SPARK-26228][MLLIB] OOM issue encountered when computing Gramian matrix
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

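Avoid memory problems in closure cleaning when handling large Gramians (>= 16K rows/cols) by passing `null` as the `treeAggregate` zeroValue and allocating the dense accumulator lazily inside `seqOp`; `combOp` is updated to tolerate a `null` operand on either side.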

## How was this patch tested?

Existing tests.
Note that it's hard to test the case that triggers this issue, as it would require a large amount of memory and take a while to run. I confirmed locally that computing a 16K x 16K Gramian failed before this change, even with plenty of driver memory, and no longer fails upfront after this change.

Closes apache#23600 from srowen/SPARK-26228.

Authored-by: Sean Owen <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
  • Loading branch information
srowen authored and jackylee-ch committed Feb 18, 2019
1 parent 00f1ace commit ef9a4d3
Showing 1 changed file with 34 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,25 @@ class RowMatrix @Since("1.0.0") (
val nt = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2))

// Compute the upper triangular part of the gram matrix.
val GU = rows.treeAggregate(new BDV[Double](nt))(
seqOp = (U, v) => {
val GU = rows.treeAggregate(null.asInstanceOf[BDV[Double]])(
seqOp = (maybeU, v) => {
val U =
if (maybeU == null) {
new BDV[Double](nt)
} else {
maybeU
}
BLAS.spr(1.0, v, U.data)
U
}, combOp = (U1, U2) => U1 += U2)
}, combOp = (U1, U2) =>
if (U1 == null) {
U2
} else if (U2 == null) {
U1
} else {
U1 += U2
}
)

RowMatrix.triuToFull(n, GU.data)
}
Expand All @@ -136,8 +150,14 @@ class RowMatrix @Since("1.0.0") (
// This succeeds when n <= 65535, which is checked above
val nt = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2))

val MU = rows.treeAggregate(new BDV[Double](nt))(
seqOp = (U, v) => {
val MU = rows.treeAggregate(null.asInstanceOf[BDV[Double]])(
seqOp = (maybeU, v) => {
val U =
if (maybeU == null) {
new BDV[Double](nt)
} else {
maybeU
}

val n = v.size
val na = Array.ofDim[Double](n)
Expand All @@ -150,7 +170,15 @@ class RowMatrix @Since("1.0.0") (

BLAS.spr(1.0, new DenseVector(na), U.data)
U
}, combOp = (U1, U2) => U1 += U2)
}, combOp = (U1, U2) =>
if (U1 == null) {
U2
} else if (U2 == null) {
U1
} else {
U1 += U2
}
)

bc.destroy()

Expand Down

0 comments on commit ef9a4d3

Please sign in to comment.