From a4be981c0476bb613c660b70a370f671c8b3ffee Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 29 May 2018 23:26:39 -0700 Subject: [PATCH] [SPARK-24331][SPARKR][SQL] Adding arrays_overlap, array_repeat, map_entries to SparkR ## What changes were proposed in this pull request? The PR adds functions `arrays_overlap`, `array_repeat`, `map_entries` to SparkR. ## How was this patch tested? Tests added into R/pkg/tests/fulltests/test_sparkSQL.R ## Examples ### arrays_overlap ``` df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)), list(list(1L, 2L), list(3L, 4L)), list(list(1L, NA), list(3L, 4L)))) collect(select(df, arrays_overlap(df[[1]], df[[2]]))) ``` ``` arrays_overlap(_1, _2) 1 TRUE 2 FALSE 3 NA ``` ### array_repeat ``` df <- createDataFrame(list(list("a", 3L), list("b", 2L))) collect(select(df, array_repeat(df[[1]], df[[2]]))) ``` ``` array_repeat(_1, _2) 1 a, a, a 2 b, b ``` ``` collect(select(df, array_repeat(df[[1]], 2L))) ``` ``` array_repeat(_1, 2) 1 a, a 2 b, b ``` ### map_entries ``` df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) collect(select(df, map_entries(df$map))) ``` ``` map_entries(map) 1 x, 1, y, 2 ``` Author: Marek Novotny Closes #21434 from mn-mikke/SPARK-24331. --- R/pkg/NAMESPACE | 3 ++ R/pkg/R/DataFrame.R | 2 + R/pkg/R/functions.R | 58 ++++++++++++++++++++++++--- R/pkg/R/generics.R | 12 ++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 22 +++++++++- 5 files changed, 91 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c575fe255f57a..73a33af4dd48b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -204,7 +204,9 @@ exportMethods("%<=>%", "array_max", "array_min", "array_position", + "array_repeat", "array_sort", + "arrays_overlap", "asc", "ascii", "asin", @@ -302,6 +304,7 @@ exportMethods("%<=>%", "lower", "lpad", "ltrim", + "map_entries", "map_keys", "map_values", "max", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a1c9495b0795e..70eb7a874b75c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2297,6 +2297,8 @@ setMethod("rename", setClassUnion("characterOrColumn", c("character", "Column")) +setClassUnion("numericOrColumn", c("numeric", "Column")) + #' Arrange Rows by Variables #' #' Sort a SparkDataFrame by the specified column(s). diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index fcb3521f901ea..abc91aeeb4825 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,6 +189,7 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param y Column to compute on. #' @param value A value to compute on. #' \itemize{ #' \item \code{array_contains}: a value to be checked if contained in the column. @@ -207,7 +208,7 @@ NULL #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) -#' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) #' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -216,11 +217,10 @@ NULL #' head(select(tmp, sort_array(tmp$v1))) #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) -#' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3))) +#' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3))) #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) -#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp)) -#' head(select(tmp4, concat(tmp4$v4, tmp4$v5))) +#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) +#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) #' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} NULL @@ -3048,6 +3048,26 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times +#' given by \code{count}. +#' +#' @param count a Column or constant determining the number of repetitions. +#' @rdname column_collection_functions +#' @aliases array_repeat array_repeat,Column,numericOrColumn-method +#' @note array_repeat since 2.4.0 +setMethod("array_repeat", + signature(x = "Column", count = "numericOrColumn"), + function(x, count) { + if (class(count) == "Column") { + count <- count@jc + } else { + count <- as.integer(count) + } + jc <- callJStatic("org.apache.spark.sql.functions", "array_repeat", x@jc, count) + column(jc) + }) + #' @details #' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array #' must be orderable. NA elements will be placed at the end of the returned array. @@ -3062,6 +3082,21 @@ setMethod("array_sort", column(jc) }) +#' @details +#' \code{arrays_overlap}: Returns true if the input arrays have at least one non-null element in +#' common. If not and both arrays are non-empty and any of them contains a null, it returns null. +#' It returns false otherwise. +#' +#' @rdname column_collection_functions +#' @aliases arrays_overlap arrays_overlap,Column-method +#' @note arrays_overlap since 2.4.0 +setMethod("arrays_overlap", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_overlap", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{flatten}: Creates a single array from an array of arrays. #' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -3076,6 +3111,19 @@ setMethod("flatten", column(jc) }) +#' @details +#' \code{map_entries}: Returns an unordered array of all entries in the given map. +#' +#' @rdname column_collection_functions +#' @aliases map_entries map_entries,Column-method +#' @note map_entries since 2.4.0 +setMethod("map_entries", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_entries", x@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 3ea181157b644..8894cb1c5b92f 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -769,10 +769,18 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -1034,6 +1042,10 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 13b55ac6e6e3c..16c1fd5a065eb 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1503,6 +1503,21 @@ test_that("column functions", { result <- collect(select(df2, reverse(df2[[1]])))[[1]] expect_equal(result, "cba") + # Test array_repeat() + df <- createDataFrame(list(list("a", 3L), list("b", 2L))) + result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list("a", "a", "a"), list("b", "b"))) + + result <- collect(select(df, array_repeat(df[[1]], 2L)))[[1]] + expect_equal(result, list(list("a", "a"), list("b", "b"))) + + # Test arrays_overlap() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)), + list(list(1L, 2L), list(3L, 4L)), + list(list(1L, NA), list(3L, 4L)))) + result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]] + expect_equal(result, c(TRUE, FALSE, NA)) + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) @@ -1531,8 +1546,13 @@ test_that("column functions", { result <- collect(select(df, flatten(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) - # Test map_keys(), map_values() and element_at() + # Test map_entries(), map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) + result <- collect(select(df, map_entries(df$map)))[[1]] + expected_entries <- list(listToStruct(list(key = "x", value = 1)), + listToStruct(list(key = "y", value = 2))) + expect_equal(result, list(expected_entries)) + result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y")))