From 9714fa547325ed7b6a8066a88957537936b233dd Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 24 Aug 2018 15:03:00 -0700 Subject: [PATCH] [SPARK-25234][SPARKR] avoid integer overflow in parallelize ## What changes were proposed in this pull request? `parallelize` uses integer multiplication to determine the split indices. It might cause integer overflow. ## How was this patch tested? unit test Closes #22225 from mengxr/SPARK-25234. Authored-by: Xiangrui Meng Signed-off-by: Xiangrui Meng --- R/pkg/R/context.R | 9 ++++----- R/pkg/tests/fulltests/test_context.R | 7 +++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 7e77ea4e002d9..f168ca76b6007 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -138,11 +138,10 @@ parallelize <- function(sc, coll, numSlices = 1) { sizeLimit <- getMaxAllocationLimit(sc) objectSize <- object.size(coll) + len <- length(coll) # For large objects we make sure the size of each slice is also smaller than sizeLimit - numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit)) - if (numSerializedSlices > length(coll)) - numSerializedSlices <- length(coll) + numSerializedSlices <- min(len, max(numSlices, ceiling(objectSize / sizeLimit))) # Generate the slice ids to put each row # For instance, for numSerializedSlices of 22, length of 50 @@ -153,8 +152,8 @@ parallelize <- function(sc, coll, numSlices = 1) { splits <- if (numSerializedSlices > 0) { unlist(lapply(0: (numSerializedSlices - 1), function(x) { # nolint start - start <- trunc((x * length(coll)) / numSerializedSlices) - end <- trunc(((x + 1) * length(coll)) / numSerializedSlices) + start <- trunc((as.numeric(x) * len) / numSerializedSlices) + end <- trunc(((as.numeric(x) + 1) * len) / numSerializedSlices) # nolint end rep(start, end - start) })) diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index f0d0a5114f89f..288a2714a554e 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -240,3 +240,10 @@ test_that("add and get file to be downloaded with Spark job on every node", { unlink(path, recursive = TRUE) sparkR.session.stop() }) + +test_that("SPARK-25234: parallelize should not have integer overflow", { + sc <- sparkR.sparkContext(master = sparkRTestMaster) + # 47000 * 47000 exceeds integer range + parallelize(sc, 1:47000, 47000) + sparkR.session.stop() +})