diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 34d14373b9027..00634c1a70c26 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -39,6 +39,7 @@ exportMethods("arrange",
"describe",
"dim",
"distinct",
+ "dropDuplicates",
"dropna",
"dtypes",
"except",
@@ -271,15 +272,15 @@ export("as.DataFrame",
"createExternalTable",
"dropTempTable",
"jsonFile",
- "read.json",
"loadDF",
"parquetFile",
"read.df",
+ "read.json",
"read.parquet",
"read.text",
"sql",
"str",
- "table",
+ "tableToDF",
"tableNames",
"tables",
"uncacheTable")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 35695b9df1974..629c1ce2eddc1 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1645,6 +1645,36 @@ setMethod("where",
filter(x, condition)
})
+#' dropDuplicates
+#'
+#' Returns a new DataFrame with duplicate rows removed, considering only
+#' the given subset of columns. If no column names are supplied, all columns
+#' are considered.
+#'
+#' @param x A DataFrame.
+#' @param colNames A character vector of column names to use for duplicate checking.
+#' @return A DataFrame with duplicate rows removed.
+#' @family DataFrame functions
+#' @rdname dropduplicates
+#' @name dropDuplicates
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- read.json(sqlContext, path)
+#' dropDuplicates(df)
+#' dropDuplicates(df, c("col1", "col2"))
+#' }
+setMethod("dropDuplicates",
+ signature(x = "DataFrame"),
+ function(x, colNames = columns(x)) {
+ stopifnot(class(colNames) == "character")
+
+ sdf <- callJMethod(x@sdf, "dropDuplicates", as.list(colNames))
+ dataFrame(sdf)
+ })
+
#' Join
#'
#' Join two DataFrames based on the given join expression.
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 99679b4a774d3..16a2578678cd3 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -352,6 +352,8 @@ sql <- function(sqlContext, sqlQuery) {
#' @param sqlContext SQLContext to use
#' @param tableName The SparkSQL Table to convert to a DataFrame.
#' @return DataFrame
+#' @rdname tableToDF
+#' @name tableToDF
#' @export
#' @examples
#'\dontrun{
@@ -360,15 +362,14 @@ sql <- function(sqlContext, sqlQuery) {
#' path <- "path/to/file.json"
#' df <- read.json(sqlContext, path)
#' registerTempTable(df, "table")
-#' new_df <- table(sqlContext, "table")
+#' new_df <- tableToDF(sqlContext, "table")
#' }
-table <- function(sqlContext, tableName) {
+tableToDF <- function(sqlContext, tableName) {
sdf <- callJMethod(sqlContext, "table", tableName)
dataFrame(sdf)
}
-
#' Tables
#'
#' Returns a DataFrame containing names of tables in the given database.
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 860329988f97c..d616266ead41b 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -428,6 +428,13 @@ setGeneric("corr", function(x, ...) {standardGeneric("corr") })
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
+#' @rdname dropduplicates
+#' @export
+setGeneric("dropDuplicates",
+ function(x, colNames = columns(x)) {
+ standardGeneric("dropDuplicates")
+ })
+
#' @rdname nafunctions
#' @export
setGeneric("dropna",
diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R
index 1707e314beff5..3b14a497b487a 100644
--- a/R/pkg/inst/tests/testthat/test_context.R
+++ b/R/pkg/inst/tests/testthat/test_context.R
@@ -17,6 +17,30 @@
context("test functions in sparkR.R")
+test_that("Check masked functions", {
+ # Check that we are not masking any new function from base, stats, testthat unexpectedly
+ masked <- conflicts(detail = TRUE)$`package:SparkR`
+ expect_true("describe" %in% masked) # only when with testthat..
+ func <- lapply(masked, function(x) { capture.output(showMethods(x))[[1]] })
+ funcSparkROrEmpty <- grepl("\\(package SparkR\\)$|^$", func)
+ maskedBySparkR <- masked[funcSparkROrEmpty]
+ namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var",
+ "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset",
+ "summary", "transform")
+ expect_equal(length(maskedBySparkR), length(namesOfMasked))
+ expect_equal(sort(maskedBySparkR), sort(namesOfMasked))
+ # The above are the functions reported as masked when `library(SparkR)` is called.
+ # Note that many of these methods are still callable without the base:: or stats:: prefix.
+ # There should be a test for each of these, except for the following, which are currently "broken":
+ funcHasAny <- unlist(lapply(masked, function(x) {
+ any(grepl("=\"ANY\"", capture.output(showMethods(x)[-1])))
+ }))
+ maskedCompletely <- masked[!funcHasAny]
+ namesOfMaskedCompletely <- c("cov", "filter", "sample")
+ expect_equal(length(maskedCompletely), length(namesOfMaskedCompletely))
+ expect_equal(sort(maskedCompletely), sort(namesOfMaskedCompletely))
+})
+
test_that("repeatedly starting and stopping SparkR", {
for (i in 1:4) {
sc <- sparkR.init()
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 67ecdbc522d23..14d40d5066e78 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -335,7 +335,6 @@ writeLines(mockLinesMapType, mapTypeJsonPath)
test_that("Collect DataFrame with complex types", {
# ArrayType
df <- read.json(sqlContext, complexTypeJsonPath)
-
ldf <- collect(df)
expect_equal(nrow(ldf), 3)
expect_equal(ncol(ldf), 3)
@@ -490,19 +489,15 @@ test_that("insertInto() on a registered table", {
unlink(parquetPath2)
})
-test_that("table() returns a new DataFrame", {
+test_that("tableToDF() returns a new DataFrame", {
df <- read.json(sqlContext, jsonPath)
registerTempTable(df, "table1")
- tabledf <- table(sqlContext, "table1")
+ tabledf <- tableToDF(sqlContext, "table1")
expect_is(tabledf, "DataFrame")
expect_equal(count(tabledf), 3)
+ tabledf2 <- tableToDF(sqlContext, "table1")
+ expect_equal(count(tabledf2), 3)
dropTempTable(sqlContext, "table1")
-
- # nolint start
- # Test base::table is working
- #a <- letters[1:3]
- #expect_equal(class(table(a, sample(a))), "table")
- # nolint end
})
test_that("toRDD() returns an RRDD", {
@@ -734,7 +729,7 @@ test_that("head() and first() return the correct data", {
expect_equal(ncol(testFirst), 2)
})
-test_that("distinct() and unique on DataFrames", {
+test_that("distinct(), unique() and dropDuplicates() on DataFrames", {
lines <- c("{\"name\":\"Michael\"}",
"{\"name\":\"Andy\", \"age\":30}",
"{\"name\":\"Justin\", \"age\":19}",
@@ -750,6 +745,42 @@ test_that("distinct() and unique on DataFrames", {
uniques2 <- unique(df)
expect_is(uniques2, "DataFrame")
expect_equal(count(uniques2), 3)
+
+ # Test dropDuplicates()
+ df <- createDataFrame(
+ sqlContext,
+ list(
+ list(2, 1, 2), list(1, 1, 1),
+ list(1, 2, 1), list(2, 1, 2),
+ list(2, 2, 2), list(2, 2, 1),
+ list(2, 1, 1), list(1, 1, 2),
+ list(1, 2, 2), list(1, 2, 1)),
+ schema = c("key", "value1", "value2"))
+ result <- collect(dropDuplicates(df))
+ expected <- rbind.data.frame(
+ c(1, 1, 1), c(1, 1, 2), c(1, 2, 1),
+ c(1, 2, 2), c(2, 1, 1), c(2, 1, 2),
+ c(2, 2, 1), c(2, 2, 2))
+ names(expected) <- c("key", "value1", "value2")
+ expect_equivalent(
+ result[order(result$key, result$value1, result$value2),],
+ expected)
+
+ result <- collect(dropDuplicates(df, c("key", "value1")))
+ expected <- rbind.data.frame(
+ c(1, 1, 1), c(1, 2, 1), c(2, 1, 2), c(2, 2, 2))
+ names(expected) <- c("key", "value1", "value2")
+ expect_equivalent(
+ result[order(result$key, result$value1, result$value2),],
+ expected)
+
+ result <- collect(dropDuplicates(df, "key"))
+ expected <- rbind.data.frame(
+ c(1, 1, 1), c(2, 1, 2))
+ names(expected) <- c("key", "value1", "value2")
+ expect_equivalent(
+ result[order(result$key, result$value1, result$value2),],
+ expected)
})
test_that("sample on a DataFrame", {
diff --git a/docs/sparkr.md b/docs/sparkr.md
index ea81532c611e2..73e38b8c70f01 100644
--- a/docs/sparkr.md
+++ b/docs/sparkr.md
@@ -375,13 +375,6 @@ The following functions are masked by the SparkR package:
sample in package:base |
base::sample(x, size, replace = FALSE, prob = NULL) |
-
- table in package:base |
- base::table(...,
- exclude = if (useNA == "no") c(NA, NaN),
- useNA = c("no", "ifany", "always"),
- dnn = list.names(...), deparse.level = 1)
|
-
Since part of SparkR is modeled on the `dplyr` package, certain functions in SparkR share the same names with those in `dplyr`. Depending on the load order of the two packages, some functions from the package loaded first are masked by those in the package loaded after. In such case, prefix such calls with the package name, for instance, `SparkR::cume_dist(x)` or `dplyr::cume_dist(x)`.
@@ -394,3 +387,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma
## Upgrading From SparkR 1.5.x to 1.6
- Before Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.6.0 to `error` to match the Scala API.
+
+## Upgrading From SparkR 1.6.x to 2.0
+
+ - The method `table` has been removed and replaced by `tableToDF`.
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 9ea639dc4f960..4eb17bfdcca90 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -394,7 +394,6 @@ def test_fit_maximize_metric(self):
if __name__ == "__main__":
- from pyspark.ml.tests import *
if xmlrunner:
unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
else:
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index ea7d297cba2ae..32ed48e10388e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -77,24 +77,21 @@
pass
ser = PickleSerializer()
+sc = SparkContext('local[4]', "MLlib tests")
class MLlibTestCase(unittest.TestCase):
def setUp(self):
- self.sc = SparkContext('local[4]', "MLlib tests")
-
- def tearDown(self):
- self.sc.stop()
+ self.sc = sc
class MLLibStreamingTestCase(unittest.TestCase):
def setUp(self):
- self.sc = SparkContext('local[4]', "MLlib tests")
+ self.sc = sc
self.ssc = StreamingContext(self.sc, 1.0)
def tearDown(self):
self.ssc.stop(False)
- self.sc.stop()
@staticmethod
def _eventually(condition, timeout=30.0, catch_assertions=False):
@@ -1169,7 +1166,7 @@ def test_predictOn_model(self):
clusterWeights=[1.0, 1.0, 1.0, 1.0])
predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]]
- predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data]
+ predict_data = [sc.parallelize(batch, 1) for batch in predict_data]
predict_stream = self.ssc.queueStream(predict_data)
predict_val = stkm.predictOn(predict_stream)
@@ -1200,7 +1197,7 @@ def test_trainOn_predictOn(self):
# classification based in the initial model would have been 0
# proving that the model is updated.
batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]]
- batches = [self.sc.parallelize(batch) for batch in batches]
+ batches = [sc.parallelize(batch) for batch in batches]
input_stream = self.ssc.queueStream(batches)
predict_results = []
@@ -1233,7 +1230,7 @@ def test_dim(self):
self.assertEqual(len(point.features), 3)
linear_data = LinearDataGenerator.generateLinearRDD(
- sc=self.sc, nexamples=6, nfeatures=2, eps=0.1,
+ sc=sc, nexamples=6, nfeatures=2, eps=0.1,
nParts=2, intercept=0.0).collect()
self.assertEqual(len(linear_data), 6)
for point in linear_data:
@@ -1409,7 +1406,7 @@ def test_parameter_accuracy(self):
for i in range(10):
batch = LinearDataGenerator.generateLinearInput(
0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1)
- batches.append(self.sc.parallelize(batch))
+ batches.append(sc.parallelize(batch))
input_stream = self.ssc.queueStream(batches)
slr.trainOn(input_stream)
@@ -1433,7 +1430,7 @@ def test_parameter_convergence(self):
for i in range(10):
batch = LinearDataGenerator.generateLinearInput(
0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
- batches.append(self.sc.parallelize(batch))
+ batches.append(sc.parallelize(batch))
model_weights = []
input_stream = self.ssc.queueStream(batches)
@@ -1466,7 +1463,7 @@ def test_prediction(self):
0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0],
100, 42 + i, 0.1)
batches.append(
- self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features)))
+ sc.parallelize(batch).map(lambda lp: (lp.label, lp.features)))
input_stream = self.ssc.queueStream(batches)
output_stream = slr.predictOnValues(input_stream)
@@ -1497,7 +1494,7 @@ def test_train_prediction(self):
for i in range(10):
batch = LinearDataGenerator.generateLinearInput(
0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
- batches.append(self.sc.parallelize(batch))
+ batches.append(sc.parallelize(batch))
predict_batches = [
b.map(lambda lp: (lp.label, lp.features)) for b in batches]
@@ -1583,7 +1580,6 @@ def test_als_ratings_id_long_error(self):
if __name__ == "__main__":
- from pyspark.mllib.tests import *
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
if xmlrunner:
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ae8620274dd20..c03cb9338ae68 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1259,7 +1259,6 @@ def test_collect_functions(self):
if __name__ == "__main__":
- from pyspark.sql.tests import *
if xmlrunner:
unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
else:
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 24b812615cbb4..86b05d9fd2424 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -1635,7 +1635,6 @@ def search_kinesis_asl_assembly_jar():
are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1'
if __name__ == "__main__":
- from pyspark.streaming.tests import *
kafka_assembly_jar = search_kafka_assembly_jar()
flume_assembly_jar = search_flume_assembly_jar()
mqtt_assembly_jar = search_mqtt_assembly_jar()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 23720502a82c8..5bd94476597ab 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -2008,7 +2008,6 @@ def test_statcounter_array(self):
if __name__ == "__main__":
- from pyspark.tests import *
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
if not _have_numpy:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9257fba60e36c..d4b4bc88b3f2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -297,7 +297,7 @@ class Analyzer(
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
object ResolveRelations extends Rule[LogicalPlan] {
- def getTable(u: UnresolvedRelation): LogicalPlan = {
+ private def getTable(u: UnresolvedRelation): LogicalPlan = {
try {
catalog.lookupRelation(u.tableIdentifier, u.alias)
} catch {
@@ -1165,7 +1165,7 @@ class Analyzer(
* scoping information for attributes and can be removed once analysis is complete.
*/
object EliminateSubQueries extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case Subquery(_, child) => child
}
}
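Switching `EliminateSubQueries` from `transform` to `transformUp` matters because the rule now runs in a single-pass batch (see the new `Finish Analysis` batch in Optimizer.scala below): `transform` gives no directionality guarantee and currently resolves to a pre-order traversal, so replacing an outer `Subquery` with its child does not revisit a nested `Subquery` inside that child, whereas a post-order traversal collapses arbitrary nesting in one application. A minimal sketch of the one-pass behaviour, assuming a REPL with catalyst on the classpath and reusing only constructs that also appear in the new test suite:

```scala
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Subquery}

val input  = LocalRelation('a.int, 'b.int)
val nested = Subquery("c", Subquery("b", Subquery("a", input)))

// transformUp removes the innermost Subquery first, so a single application
// of the rule already yields the bare LocalRelation.
val collapsed = EliminateSubQueries(nested)  // == input
```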
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index cc3371c08fac4..04643f0274bd4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -35,11 +35,16 @@ import org.apache.spark.sql.types._
*/
abstract class Optimizer extends RuleExecutor[LogicalPlan] {
def batches: Seq[Batch] = {
- // SubQueries are only needed for analysis and can be removed before execution.
- Batch("Remove SubQueries", FixedPoint(100),
- EliminateSubQueries) ::
- Batch("Compute Current Time", Once,
+ // Technically some of the rules in Finish Analysis are not optimizer rules and belong more
+ // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime).
+ // However, because we also use the analyzer to canonicalize queries (for view definitions),
+ // we do not eliminate subqueries or compute current time in the analyzer.
+ Batch("Finish Analysis", Once,
+ EliminateSubQueries,
ComputeCurrentTime) ::
+ //////////////////////////////////////////////////////////////////////////////////////////
+ // Optimizer rules start here
+ //////////////////////////////////////////////////////////////////////////////////////////
Batch("Aggregate", FixedPoint(100),
ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) ::
@@ -57,7 +62,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
ProjectCollapsing,
CombineFilters,
CombineLimits,
- // Constant folding
+ // Constant folding and strength reduction
NullPropagation,
OptimizeIn,
ConstantFolding,
@@ -635,6 +640,24 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
+
+ case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
+ // If there are branches that are always false, remove them.
+ // If there are no more branches left, just use the else value.
+ // Note that these two are handled together here in a single case statement because
+ // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
+ val newBranches = branches.filter(_._1 != FalseLiteral)
+ if (newBranches.isEmpty) {
+ elseValue.getOrElse(Literal.create(null, e.dataType))
+ } else {
+ e.copy(branches = newBranches)
+ }
+
+ case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) =>
+ // If the first branch is a true literal, remove the entire CaseWhen and use the value
+ // from that branch. Note that CaseWhen.branches should never be empty, so the
+ // headOption (rather than head) used in the guard is just an extra (and unnecessary) safeguard.
+ branches.head._2
}
}
}
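Together, the two new `CaseWhen` cases let the optimizer eliminate branches whose conditions fold to literals, with the fixed-point batch taking care of applying them in sequence. A rough end-to-end sketch, assuming a REPL with catalyst on the classpath; the one-off `Simplify` executor below is an assumption introduced for illustration, mirroring the style of the test suites added later in this patch:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

// One-off executor that runs only SimplifyConditionals, to a fixed point.
object Simplify extends RuleExecutor[LogicalPlan] {
  val batches = Batch("SimplifyConditionals", FixedPoint(10), SimplifyConditionals) :: Nil
}

// CASE WHEN false THEN 20 WHEN true THEN 5 ELSE 30 END
val caseExpr = CaseWhen((FalseLiteral, Literal(20)) :: (TrueLiteral, Literal(5)) :: Nil, Some(Literal(30)))
val plan     = Project(Alias(caseExpr, "out")() :: Nil, OneRowRelation)

// The first iteration drops the always-false branch; the next collapses the
// now-leading always-true branch, so the projected expression becomes Literal(5).
val simplified = Simplify.execute(plan)
```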
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index d74f3ef2ffba6..57e1a3c9eb226 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -244,6 +244,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* When `rule` does not apply to a given node it is left unchanged.
* Users should not expect a specific directionality. If a specific directionality is needed,
* transformDown or transformUp should be used.
+ *
* @param rule the function use to transform this nodes children
*/
def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = {
@@ -253,6 +254,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
/**
* Returns a copy of this node where `rule` has been recursively applied to it and all of its
* children (pre-order). When `rule` does not apply to a given node it is left unchanged.
+ *
* @param rule the function used to transform this nodes children
*/
def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
@@ -268,6 +270,26 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
}
+ /**
+ * Returns a copy of this node where `rule` has been recursively applied first to all of its
+ * children and then itself (post-order). When `rule` does not apply to a given node, it is left
+ * unchanged.
+ *
+ * @param rule the function used to transform this node's children
+ */
+ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
+ val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r))
+ if (this fastEquals afterRuleOnChildren) {
+ CurrentOrigin.withOrigin(origin) {
+ rule.applyOrElse(this, identity[BaseType])
+ }
+ } else {
+ CurrentOrigin.withOrigin(origin) {
+ rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
+ }
+ }
+ }
+
/**
* Returns a copy of this node where `rule` has been recursively applied to all the children of
* this node. When `rule` does not apply to a given node it is left unchanged.
@@ -332,25 +354,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
if (changed) makeCopy(newArgs) else this
}
- /**
- * Returns a copy of this node where `rule` has been recursively applied first to all of its
- * children and then itself (post-order). When `rule` does not apply to a given node, it is left
- * unchanged.
- * @param rule the function use to transform this nodes children
- */
- def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
- val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r))
- if (this fastEquals afterRuleOnChildren) {
- CurrentOrigin.withOrigin(origin) {
- rule.applyOrElse(this, identity[BaseType])
- }
- } else {
- CurrentOrigin.withOrigin(origin) {
- rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
- }
- }
- }
-
/**
* Args to the constructor that should be copied, but not transformed.
* These are appended to the transformed args automatically by makeCopy
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubQueriesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubQueriesSuite.scala
new file mode 100644
index 0000000000000..e0d430052fb55
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubQueriesSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+
+class EliminateSubQueriesSuite extends PlanTest with PredicateHelper {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = Batch("EliminateSubQueries", Once, EliminateSubQueries) :: Nil
+ }
+
+ private def assertEquivalent(e1: Expression, e2: Expression): Unit = {
+ val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
+ val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
+ comparePlans(actual, correctAnswer)
+ }
+
+ private def afterOptimization(plan: LogicalPlan): LogicalPlan = {
+ Optimize.execute(analysis.SimpleAnalyzer.execute(plan))
+ }
+
+ test("eliminate top level subquery") {
+ val input = LocalRelation('a.int, 'b.int)
+ val query = Subquery("a", input)
+ comparePlans(afterOptimization(query), input)
+ }
+
+ test("eliminate mid-tree subquery") {
+ val input = LocalRelation('a.int, 'b.int)
+ val query = Filter(TrueLiteral, Subquery("a", input))
+ comparePlans(
+ afterOptimization(query),
+ Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
+ }
+
+ test("eliminate multiple subqueries") {
+ val input = LocalRelation('a.int, 'b.int)
+ val query = Filter(TrueLiteral, Subquery("c", Subquery("b", Subquery("a", input))))
+ comparePlans(
+ afterOptimization(query),
+ Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
+ }
+
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index 8e5d7ef3c9d49..d436b627f6bd2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.IntegerType
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
@@ -37,6 +38,10 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
comparePlans(actual, correctAnswer)
}
+ private val trueBranch = (TrueLiteral, Literal(5))
+ private val normalBranch = (NonFoldableLiteral(true), Literal(10))
+ private val unreachableBranch = (FalseLiteral, Literal(20))
+
test("simplify if") {
assertEquivalent(
If(TrueLiteral, Literal(10), Literal(20)),
@@ -47,4 +52,36 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
Literal(20))
}
+ test("remove unreachable branches") {
+ // i.e. removing branches whose conditions are always false
+ assertEquivalent(
+ CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
+ CaseWhen(normalBranch :: Nil, None))
+ }
+
+ test("remove entire CaseWhen if only the else branch is reachable") {
+ assertEquivalent(
+ CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
+ Literal(30))
+
+ assertEquivalent(
+ CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
+ Literal.create(null, IntegerType))
+ }
+
+ test("remove entire CaseWhen if the first branch is always true") {
+ assertEquivalent(
+ CaseWhen(trueBranch :: normalBranch :: Nil, None),
+ Literal(5))
+
+ // Test branch elimination and simplification in combination
+ assertEquivalent(
+ CaseWhen(unreachableBranch :: unreachableBranch :: trueBranch :: normalBranch :: Nil, None),
+ Literal(5))
+
+ // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
+ assertEquivalent(
+ CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
+ CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala
index 127c9728da2d1..676a3d3bca9f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala
@@ -19,7 +19,10 @@ package org.apache.spark.sql.execution.datasources.csv
import java.nio.charset.Charset
+import org.apache.hadoop.io.compress._
+
import org.apache.spark.Logging
+import org.apache.spark.util.Utils
private[sql] case class CSVParameters(@transient parameters: Map[String, String]) extends Logging {
@@ -35,7 +38,7 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String]
private def getBool(paramName: String, default: Boolean = false): Boolean = {
val param = parameters.getOrElse(paramName, default.toString)
- if (param.toLowerCase() == "true") {
+ if (param.toLowerCase == "true") {
true
} else if (param.toLowerCase == "false") {
false
@@ -73,6 +76,11 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String]
val nullValue = parameters.getOrElse("nullValue", "")
+ val compressionCodec: Option[String] = {
+ val name = parameters.get("compression").orElse(parameters.get("codec"))
+ name.map(CSVCompressionCodecs.getCodecClassName)
+ }
+
val maxColumns = 20480
val maxCharsPerColumn = 100000
@@ -85,7 +93,6 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String]
}
private[csv] object ParseModes {
-
val PERMISSIVE_MODE = "PERMISSIVE"
val DROP_MALFORMED_MODE = "DROPMALFORMED"
val FAIL_FAST_MODE = "FAILFAST"
@@ -107,3 +114,28 @@ private[csv] object ParseModes {
true // We default to permissive if the mode string is not valid
}
}
+
+private[csv] object CSVCompressionCodecs {
+ private val shortCompressionCodecNames = Map(
+ "bzip2" -> classOf[BZip2Codec].getName,
+ "gzip" -> classOf[GzipCodec].getName,
+ "lz4" -> classOf[Lz4Codec].getName,
+ "snappy" -> classOf[SnappyCodec].getName)
+
+ /**
+ * Return the fully qualified class name for the given compression codec name.
+ * If the given name is already a class name, return it unchanged.
+ */
+ def getCodecClassName(name: String): String = {
+ val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name)
+ try {
+ // Validate the codec name
+ Utils.classForName(codecName)
+ codecName
+ } catch {
+ case e: ClassNotFoundException =>
+ throw new IllegalArgumentException(s"Codec [$codecName] " +
+ s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.")
+ }
+ }
+}
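From the user's side the new codec handling is just a writer option: the resolved class name is pushed into the Hadoop job configuration in CSVRelation (next file), and either `compression` or `codec` is accepted as the key, with short names resolved case-insensitively. A hedged usage sketch from spark-shell (where `sqlContext` is predefined); the paths are hypothetical:

```scala
// Write gzip-compressed CSV; "gzip" resolves to org.apache.hadoop.io.compress.GzipCodec.
val cars = sqlContext.read.format("csv").option("header", "true").load("/tmp/cars.csv")
cars.coalesce(1).write
  .format("csv")
  .option("header", "true")
  .option("compression", "gzip")  // .option("codec", "gzip") works as well
  .save("/tmp/cars_gz")

// Reading back needs no option: Hadoop's TextInputFormat picks the codec from the .gz extension.
val carsCopy = sqlContext.read.format("csv").option("header", "true").load("/tmp/cars_gz")
```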
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 53818853ffb3b..1502501c3b89e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -24,6 +24,7 @@ import scala.util.control.NonFatal
import com.google.common.base.Objects
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
+import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.hadoop.mapreduce.RecordWriter
@@ -99,6 +100,15 @@ private[csv] class CSVRelation(
}
override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+ val conf = job.getConfiguration
+ params.compressionCodec.foreach { codec =>
+ conf.set("mapreduce.output.fileoutputformat.compress", "true")
+ conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
+ conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
+ conf.set("mapreduce.map.output.compress", "true")
+ conf.set("mapreduce.map.output.compress.codec", codec)
+ }
+
new CSVOutputWriterFactory(params)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 071b5ef56d58b..a79566b1f3658 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -349,4 +349,30 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null"))
assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null))
}
+
+ test("save csv with compression codec option") {
+ withTempDir { dir =>
+ val csvDir = new File(dir, "csv").getCanonicalPath
+ val cars = sqlContext.read
+ .format("csv")
+ .option("header", "true")
+ .load(testFile(carsFile))
+
+ cars.coalesce(1).write
+ .format("csv")
+ .option("header", "true")
+ .option("compression", "gZiP")
+ .save(csvDir)
+
+ val compressedFiles = new File(csvDir).listFiles()
+ assert(compressedFiles.exists(_.getName.endsWith(".gz")))
+
+ val carsCopy = sqlContext.read
+ .format("csv")
+ .option("header", "true")
+ .load(csvDir)
+
+ verifyCars(carsCopy, withHeader = true)
+ }
+ }
}