diff --git a/common/tags/src/test/java/org/apache/spark/tags/ExtendedRedshiftTest.java b/common/tags/src/test/java/org/apache/spark/tags/ExtendedRedshiftTest.java
new file mode 100644
index 0000000000000..4d76aa3adbc81
--- /dev/null
+++ b/common/tags/src/test/java/org/apache/spark/tags/ExtendedRedshiftTest.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.spark.tags;
+
+import java.lang.annotation.*;
+
+import org.scalatest.TagAnnotation;
+
+@TagAnnotation
+@Retention(RetentionPolicy.RUNTIME)
+@Target({ElementType.METHOD, ElementType.TYPE})
+public @interface ExtendedRedshiftTest { }
diff --git a/dev/.rat-excludes b/dev/.rat-excludes
index 6be1c72bc6cfb..d9022e241b37e 100644
--- a/dev/.rat-excludes
+++ b/dev/.rat-excludes
@@ -103,3 +103,4 @@ org.apache.spark.scheduler.ExternalClusterManager
org.apache.spark.deploy.yarn.security.ServiceCredentialProvider
spark-warehouse
structured-streaming/*
+install-redshift-jdbc.sh
diff --git a/dev/install-redshift-jdbc.sh b/dev/install-redshift-jdbc.sh
new file mode 100755
index 0000000000000..8719d0f6c5632
--- /dev/null
+++ b/dev/install-redshift-jdbc.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+
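+# Downloads the Redshift JDBC driver and installs it into the local Maven repository under the
+# com.amazonaws:redshift.jdbc4 coordinates that the external/redshift modules depend on.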
+set -e
+
+SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
+SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")"
+
+cd /tmp
+
+VERSION='1.1.7.1007'
+FILENAME="RedshiftJDBC4-$VERSION.jar"
+
+wget "https://s3.amazonaws.com/redshift-downloads/drivers/$FILENAME"
+
+$SPARK_ROOT_DIR/build/mvn install:install-file \
+ -Dfile=$FILENAME \
+ -DgroupId=com.amazonaws \
+ -DartifactId=redshift.jdbc4 \
+ -Dversion=$VERSION \
+ -Dpackaging=jar
+
diff --git a/dev/run-tests.py b/dev/run-tests.py
index a9692ab0c1350..1be74fa25aac1 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -111,7 +111,8 @@ def determine_modules_to_test(changed_modules):
>>> x = [x.name for x in determine_modules_to_test([modules.sql])]
>>> x # doctest: +NORMALIZE_WHITESPACE
['sql', 'avro', 'hive', 'mllib', 'sql-kafka-0-10', 'sql-kafka-0-8', 'examples',
- 'hive-thriftserver', 'pyspark-sql', 'sparkr', 'pyspark-mllib', 'pyspark-ml']
+ 'hive-thriftserver', 'pyspark-sql', 'redshift', 'sparkr', 'pyspark-mllib',
+ 'redshift-integration-tests', 'pyspark-ml']
"""
modules_to_test = set()
for module in changed_modules:
@@ -512,6 +513,8 @@ def main():
test_env = "amplab_jenkins"
# add path for Python3 in Jenkins if we're calling from a Jenkins machine
os.environ["PATH"] = "/home/anaconda/envs/py3k/bin:" + os.environ.get("PATH")
+ # Install Redshift JDBC
+ run_cmd([os.path.join(SPARK_HOME, "dev", "install-redshift-jdbc.sh")])
else:
# else we're running locally and can use local settings
build_tool = "sbt"
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index b7fc30854f13b..1cb01128b2d91 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -168,6 +168,31 @@ def __hash__(self):
]
)
+redshift = Module(
+ name="redshift",
+ dependencies=[avro, sql],
+ source_file_regexes=[
+ "external/redshift",
+ ],
+ sbt_test_goals=[
+ "redshift/test",
+ ],
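+ # Tests tagged with ExtendedRedshiftTest are excluded unless this module is selected for testing.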
+ test_tags=[
+ "org.apache.spark.tags.ExtendedRedshiftTest"
+ ],
+)
+
+redshift_integration_tests = Module(
+ name="redshift-integration-tests",
+ dependencies=[redshift],
+ source_file_regexes=[
+ "external/redshift-integration-tests",
+ ],
+ sbt_test_goals=[
+ "redshift-integration-tests/test",
+ ],
+)
+
sql_kafka = Module(
name="sql-kafka-0-10",
dependencies=[sql],
diff --git a/external/redshift-integration-tests/pom.xml b/external/redshift-integration-tests/pom.xml
new file mode 100644
index 0000000000000..654b2354ee48b
--- /dev/null
+++ b/external/redshift-integration-tests/pom.xml
@@ -0,0 +1,144 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.11</artifactId>
+    <version>2.1.0</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <groupId>com.databricks</groupId>
+  <artifactId>spark-redshift-integration-tests_2.11</artifactId>
+  <properties>
+    <sbt.project.name>redshift-integration-tests</sbt.project.name>
+  </properties>
+  <packaging>jar</packaging>
+  <name>Spark Redshift Integration Tests</name>
+  <url>http://spark.apache.org/</url>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.databricks</groupId>
+      <artifactId>spark-avro_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.databricks</groupId>
+      <artifactId>spark-redshift_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.databricks</groupId>
+      <artifactId>spark-redshift_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>aws-java-sdk-core</artifactId>
+      <version>1.9.40</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>aws-java-sdk-s3</artifactId>
+      <version>1.9.40</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>aws-java-sdk-sts</artifactId>
+      <version>1.9.40</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.eclipsesource.minimal-json</groupId>
+      <artifactId>minimal-json</artifactId>
+      <version>0.9.4</version>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-client</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <version>${hadoop.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-hive_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>redshift.jdbc4</artifactId>
+      <version>1.1.7.1007</version>
+      <type>jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+  </build>
+</project>
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/AWSCredentialsInUriIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/AWSCredentialsInUriIntegrationSuite.scala
new file mode 100755
index 0000000000000..e36d80a2c4064
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/AWSCredentialsInUriIntegrationSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.net.URI
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.tags.ExtendedRedshiftTest
+
+/**
+ * This suite performs basic integration tests where the AWS credentials have been
+ * encoded into the tempdir URI rather than being set in the Hadoop configuration.
+ */
+@ExtendedRedshiftTest
+class AWSCredentialsInUriIntegrationSuite extends IntegrationSuiteBase {
+
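+ // Build a tempdir URI of the form s3n://<accessKey>:<secretKey>@bucket/path so that the AWS
+ // credentials travel in the URI itself rather than in the Hadoop configuration.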
+ override protected val tempDir: String = {
+ val uri = new URI(AWS_S3_SCRATCH_SPACE + randomSuffix + "/")
+ new URI(
+ uri.getScheme,
+ s"$AWS_ACCESS_KEY_ID:$AWS_SECRET_ACCESS_KEY",
+ uri.getHost,
+ uri.getPort,
+ uri.getPath,
+ uri.getQuery,
+ uri.getFragment).toString
+ }
+
+
+ // Override this method so that we do not set the credentials in sc.hadoopConf.
+ override def beforeAll(): Unit = {
+ assert(tempDir.contains("AKIA"), "tempdir did not contain AWS credentials")
+ assert(!AWS_SECRET_ACCESS_KEY.contains("/"), "AWS secret key should not contain slash")
+ sc = new SparkContext("local", getClass.getSimpleName)
+ conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None)
+ }
+
+ test("roundtrip save and load") {
+ val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1),
+ StructType(StructField("foo", IntegerType) :: Nil))
+ testRoundtripSaveAndLoad(s"roundtrip_save_and_load_$randomSuffix", df)
+ }
+}
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/ColumnMetadataSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/ColumnMetadataSuite.scala
new file mode 100755
index 0000000000000..c27d75dcf0d98
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/ColumnMetadataSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.sql.SQLException
+
+import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.types._
+import org.apache.spark.tags.ExtendedRedshiftTest
+
+/**
+ * End-to-end tests of features which depend on per-column metadata (such as comments, maxlength).
+ */
+@ExtendedRedshiftTest
+class ColumnMetadataSuite extends IntegrationSuiteBase {
+
+ test("configuring maxlength on string columns") {
+ val tableName = s"configuring_maxlength_on_string_column_$randomSuffix"
+ try {
+ val metadata = new MetadataBuilder().putLong("maxlength", 512).build()
+ val schema = StructType(
+ StructField("x", StringType, metadata = metadata) :: Nil)
+ write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 512))), schema))
+ .option("dbtable", tableName)
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 512)))
+ // This append should fail due to the string being longer than the maxlength
+ intercept[SQLException] {
+ write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 513))), schema))
+ .option("dbtable", tableName)
+ .mode(SaveMode.Append)
+ .save()
+ }
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+
+ test("configuring compression on columns") {
+ val tableName = s"configuring_compression_on_columns_$randomSuffix"
+ try {
+ val metadata = new MetadataBuilder().putString("encoding", "LZO").build()
+ val schema = StructType(
+ StructField("x", StringType, metadata = metadata) :: Nil)
+ write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 128))), schema))
+ .option("dbtable", tableName)
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 128)))
+ val encodingDF = sqlContext.read
+ .format("jdbc")
+ .option("url", jdbcUrl)
+ .option("dbtable",
+ s"""(SELECT "column", lower(encoding) FROM pg_table_def WHERE tablename='$tableName')""")
+ .load()
+ checkAnswer(encodingDF, Seq(Row("x", "lzo")))
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+
+ test("configuring comments on columns") {
+ val tableName = s"configuring_comments_on_columns_$randomSuffix"
+ try {
+ val metadata = new MetadataBuilder().putString("description", "Hello Column").build()
+ val schema = StructType(
+ StructField("x", StringType, metadata = metadata) :: Nil)
+ write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 128))), schema))
+ .option("dbtable", tableName)
+ .option("description", "Hello Table")
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 128)))
+ val tableDF = sqlContext.read
+ .format("jdbc")
+ .option("url", jdbcUrl)
+ .option("dbtable", s"(SELECT pg_catalog.obj_description('$tableName'::regclass))")
+ .load()
+ checkAnswer(tableDF, Seq(Row("Hello Table")))
+ val commentQuery =
+ s"""
+ |(SELECT c.column_name, pgd.description
+ |FROM pg_catalog.pg_statio_all_tables st
+ |INNER JOIN pg_catalog.pg_description pgd
+ | ON (pgd.objoid=st.relid)
+ |INNER JOIN information_schema.columns c
+ | ON (pgd.objsubid=c.ordinal_position AND c.table_name=st.relname)
+ |WHERE c.table_name='$tableName')
+ """.stripMargin
+ val columnDF = sqlContext.read
+ .format("jdbc")
+ .option("url", jdbcUrl)
+ .option("dbtable", commentQuery)
+ .load()
+ checkAnswer(columnDF, Seq(Row("x", "Hello Column")))
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+}
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/CrossRegionIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/CrossRegionIntegrationSuite.scala
new file mode 100755
index 0000000000000..b4fb34af86c62
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/CrossRegionIntegrationSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import com.amazonaws.auth.BasicAWSCredentials
+import com.amazonaws.services.s3.AmazonS3Client
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.tags.ExtendedRedshiftTest
+
+/**
+ * Integration tests where the Redshift cluster and the S3 bucket are in different AWS regions.
+ */
+@ExtendedRedshiftTest
+class CrossRegionIntegrationSuite extends IntegrationSuiteBase {
+
+ protected val AWS_S3_CROSS_REGION_SCRATCH_SPACE: String =
+ loadConfigFromEnv("AWS_S3_CROSS_REGION_SCRATCH_SPACE")
+ require(AWS_S3_CROSS_REGION_SCRATCH_SPACE.contains("s3n"), "must use s3n:// URL")
+
+ override protected val tempDir: String = AWS_S3_CROSS_REGION_SCRATCH_SPACE + randomSuffix + "/"
+
+ test("write") {
+ val bucketRegion = Utils.getRegionForS3Bucket(
+ tempDir,
+ new AmazonS3Client(new BasicAWSCredentials(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY))).get
+ val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1),
+ StructType(StructField("foo", IntegerType) :: Nil))
+ val tableName = s"roundtrip_save_and_load_$randomSuffix"
+ try {
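+ // The S3 bucket and the Redshift cluster live in different regions, so the bucket's region
+ // must be passed to the COPY command through the 'extracopyoptions' parameter.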
+ write(df)
+ .option("dbtable", tableName)
+ .option("extracopyoptions", s"region '$bucketRegion'")
+ .save()
+ // Check that the table exists. It appears that creating a table in one connection then
+ // immediately querying for existence from another connection may result in spurious "table
+ // doesn't exist" errors; this caused the "save with all empty partitions" test to become
+ // flaky (see #146). To work around this, add a small sleep and check again:
+ if (!DefaultJDBCWrapper.tableExists(conn, tableName)) {
+ Thread.sleep(1000)
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ }
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+}
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/DecimalIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/DecimalIntegrationSuite.scala
new file mode 100755
index 0000000000000..e3078b9457eaf
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/DecimalIntegrationSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.tags.ExtendedRedshiftTest
+
+/**
+ * Integration tests for decimal support. For a reference on Redshift's DECIMAL type, see
+ * http://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html
+ */
+@ExtendedRedshiftTest
+class DecimalIntegrationSuite extends IntegrationSuiteBase {
+
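+ /**
+ * Registers a test which creates a table with a single DECIMAL(precision, scale) column,
+ * inserts the given decimal strings, and checks that they are read back with full precision.
+ */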
+ private def testReadingDecimals(precision: Int, scale: Int, decimalStrings: Seq[String]): Unit = {
+ test(s"reading DECIMAL($precision, $scale)") {
+ val tableName = s"reading_decimal_${precision}_${scale}_$randomSuffix"
+ val expectedRows = decimalStrings.map { d =>
+ if (d == null) {
+ Row(null)
+ } else {
+ Row(Conversions.createRedshiftDecimalFormat().parse(d).asInstanceOf[java.math.BigDecimal])
+ }
+ }
+ try {
+ conn.createStatement().executeUpdate(
+ s"CREATE TABLE $tableName (x DECIMAL($precision, $scale))")
+ for (x <- decimalStrings) {
+ conn.createStatement().executeUpdate(s"INSERT INTO $tableName VALUES ($x)")
+ }
+ conn.commit()
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ val loadedDf = read.option("dbtable", tableName).load()
+ checkAnswer(loadedDf, expectedRows)
+ checkAnswer(loadedDf.selectExpr("x + 0"), expectedRows)
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+ }
+
+ testReadingDecimals(19, 0, Seq(
+ // Max and min values of DECIMAL(19, 0) column according to Redshift docs:
+ "9223372036854775807", // 2^63 - 1
+ "-9223372036854775807",
+ "0",
+ "12345678910",
+ null
+ ))
+
+ testReadingDecimals(19, 4, Seq(
+ "922337203685477.5807",
+ "-922337203685477.5807",
+ "0",
+ "1234567.8910",
+ null
+ ))
+
+ testReadingDecimals(38, 4, Seq(
+ "922337203685477.5808",
+ "9999999999999999999999999999999999.0000",
+ "-9999999999999999999999999999999999.0000",
+ "0",
+ "1234567.8910",
+ null
+ ))
+
+ test("Decimal precision is preserved when reading from query (regression test for issue #203)") {
+ withTempRedshiftTable("issue203") { tableName =>
+ try {
+ conn.createStatement().executeUpdate(s"CREATE TABLE $tableName (foo BIGINT)")
+ conn.createStatement().executeUpdate(s"INSERT INTO $tableName VALUES (91593373)")
+ conn.commit()
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ val df = read
+ .option("query", s"select foo / 1000000.0 from $tableName limit 1")
+ .load()
+ val res: Double = df.collect().toSeq.head.getDecimal(0).doubleValue()
+ assert(res === (91593373L / 1000000.0) +- 0.01)
+ assert(df.schema.fields.head.dataType === DecimalType(28, 8))
+ }
+ }
+ }
+}
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IAMIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IAMIntegrationSuite.scala
new file mode 100755
index 0000000000000..c8c298166f5ef
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IAMIntegrationSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.sql.SQLException
+
+import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.tags.ExtendedRedshiftTest
+
+/**
+ * Integration tests for configuring Redshift to access S3 using Amazon IAM roles.
+ */
+@ExtendedRedshiftTest
+class IAMIntegrationSuite extends IntegrationSuiteBase {
+
+ private val IAM_ROLE_ARN: String = loadConfigFromEnv("STS_ROLE_ARN")
+
+ test("roundtrip save and load") {
+ val tableName = s"iam_roundtrip_save_and_load$randomSuffix"
+ val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+ StructType(StructField("a", IntegerType) :: Nil))
+ try {
+ write(df)
+ .option("dbtable", tableName)
+ .option("forward_spark_s3_credentials", "false")
+ .option("aws_iam_role", IAM_ROLE_ARN)
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ val loadedDf = read
+ .option("dbtable", tableName)
+ .option("forward_spark_s3_credentials", "false")
+ .option("aws_iam_role", IAM_ROLE_ARN)
+ .load()
+ assert(loadedDf.schema.length === 1)
+ assert(loadedDf.columns === Seq("a"))
+ checkAnswer(loadedDf, Seq(Row(1)))
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+
+ test("load fails if IAM role cannot be assumed") {
+ val tableName = s"iam_load_fails_if_role_cannot_be_assumed$randomSuffix"
+ try {
+ val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+ StructType(StructField("a", IntegerType) :: Nil))
+ val err = intercept[SQLException] {
+ write(df)
+ .option("dbtable", tableName)
+ .option("forward_spark_s3_credentials", "false")
+ .option("aws_iam_role", IAM_ROLE_ARN + "-some-bogus-suffix")
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+ }
+ assert(err.getCause.getMessage.contains("is not authorized to assume IAM Role"))
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+}
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala
new file mode 100644
index 0000000000000..6c4026cab2935
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala
@@ -0,0 +1,217 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.net.URI
+import java.sql.Connection
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.s3native.NativeS3FileSystem
+import org.scalatest.{BeforeAndAfterEach, Matchers}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql._
+import org.apache.spark.sql.hive.test.TestHiveContext
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Base class for writing integration tests which run against a real Redshift cluster.
+ */
+trait IntegrationSuiteBase
+ extends QueryTest
+ with Matchers
+ with BeforeAndAfterEach {
+
+ protected def loadConfigFromEnv(envVarName: String): String = {
+ Option(System.getenv(envVarName)).getOrElse {
+ fail(s"Must set $envVarName environment variable")
+ }
+ }
+
+ // The following configurations must be set in order to run these tests. In Travis, these
+ // environment variables are set using Travis's encrypted environment variables feature:
+ // http://docs.travis-ci.com/user/environment-variables/#Encrypted-Variables
+
+ // JDBC URL listed in the AWS console (should not contain username and password).
+ protected val AWS_REDSHIFT_JDBC_URL: String = loadConfigFromEnv("AWS_REDSHIFT_JDBC_URL")
+ protected val AWS_REDSHIFT_USER: String = loadConfigFromEnv("AWS_REDSHIFT_USER")
+ protected val AWS_REDSHIFT_PASSWORD: String = loadConfigFromEnv("AWS_REDSHIFT_PASSWORD")
+ protected val AWS_ACCESS_KEY_ID: String = loadConfigFromEnv("TEST_AWS_ACCESS_KEY_ID")
+ protected val AWS_SECRET_ACCESS_KEY: String = loadConfigFromEnv("TEST_AWS_SECRET_ACCESS_KEY")
+ // Path to a directory in S3 (e.g. 's3n://bucket-name/path/to/scratch/space').
+ protected val AWS_S3_SCRATCH_SPACE: String = loadConfigFromEnv("AWS_S3_SCRATCH_SPACE")
+ require(AWS_S3_SCRATCH_SPACE.contains("s3n"), "must use s3n:// URL")
+
+ protected def jdbcUrl: String = {
+ s"$AWS_REDSHIFT_JDBC_URL?user=$AWS_REDSHIFT_USER&password=$AWS_REDSHIFT_PASSWORD"
+ }
+
+ /**
+ * Random suffix appended to table and directory names in order to avoid collisions
+ * between separate Travis builds.
+ */
+ protected val randomSuffix: String = Math.abs(Random.nextLong()).toString
+
+ protected val tempDir: String = AWS_S3_SCRATCH_SPACE + randomSuffix + "/"
+
+ /**
+ * SparkContext, SQLContext, and JDBC connection shared by the tests in this suite; they are
+ * initialized in beforeAll()/beforeEach() and torn down in afterAll().
+ */
+ protected var sc: SparkContext = _
+ protected var sqlContext: SQLContext = _
+ protected var conn: Connection = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sc = new SparkContext("local", "RedshiftSourceSuite")
+ // Bypass Hadoop's FileSystem caching mechanism so that we don't cache the credentials:
+ sc.hadoopConfiguration.setBoolean("fs.s3.impl.disable.cache", true)
+ sc.hadoopConfiguration.setBoolean("fs.s3n.impl.disable.cache", true)
+ sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
+ sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
+ conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None)
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ val conf = new Configuration(false)
+ conf.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
+ conf.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
+ // Bypass Hadoop's FileSystem caching mechanism so that we don't cache the credentials:
+ conf.setBoolean("fs.s3.impl.disable.cache", true)
+ conf.setBoolean("fs.s3n.impl.disable.cache", true)
+ conf.set("fs.s3.impl", classOf[NativeS3FileSystem].getCanonicalName)
+ conf.set("fs.s3n.impl", classOf[NativeS3FileSystem].getCanonicalName)
+ val fs = FileSystem.get(URI.create(tempDir), conf)
+ fs.delete(new Path(tempDir), true)
+ fs.close()
+ } finally {
+ try {
+ conn.close()
+ } finally {
+ try {
+ sc.stop()
+ } finally {
+ super.afterAll()
+ }
+ }
+ }
+ }
+
+ override protected def beforeEach(): Unit = {
+ super.beforeEach()
+ sqlContext = new TestHiveContext(sc, loadTestTables = false)
+ }
+
+ /**
+ * Create a new DataFrameReader using common options for reading from Redshift.
+ */
+ protected def read: DataFrameReader = {
+ sqlContext.read
+ .format("com.databricks.spark.redshift")
+ .option("url", jdbcUrl)
+ .option("tempdir", tempDir)
+ .option("forward_spark_s3_credentials", "true")
+ }
+ /**
+ * Create a new DataFrameWriter using common options for writing to Redshift.
+ */
+ protected def write(df: DataFrame): DataFrameWriter[Row] = {
+ df.write
+ .format("com.databricks.spark.redshift")
+ .option("url", jdbcUrl)
+ .option("tempdir", tempDir)
+ .option("forward_spark_s3_credentials", "true")
+ }
+
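+ /**
+ * Creates a table containing one column for each commonly-used Redshift type and inserts the
+ * rows expected by TestUtils.expectedData (including an all-null row and a Unicode string).
+ */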
+ protected def createTestDataInRedshift(tableName: String): Unit = {
+ conn.createStatement().executeUpdate(
+ s"""
+ |create table $tableName (
+ |testbyte int2,
+ |testbool boolean,
+ |testdate date,
+ |testdouble float8,
+ |testfloat float4,
+ |testint int4,
+ |testlong int8,
+ |testshort int2,
+ |teststring varchar(256),
+ |testtimestamp timestamp
+ |)
+ """.stripMargin
+ )
+ // scalastyle:off
+ conn.createStatement().executeUpdate(
+ s"""
+ |insert into $tableName values
+ |(null, null, null, null, null, null, null, null, null, null),
+ |(0, null, '2015-07-03', 0.0, -1.0, 4141214, 1239012341823719, null, 'f', '2015-07-03 00:00:00.000'),
+ |(0, false, null, -1234152.12312498, 100000.0, null, 1239012341823719, 24, '___|_123', null),
+ |(1, false, '2015-07-02', 0.0, 0.0, 42, 1239012341823719, -13, 'asdf', '2015-07-02 00:00:00.000'),
+ |(1, true, '2015-07-01', 1234152.12312498, 1.0, 42, 1239012341823719, 23, 'Unicode''s樂趣', '2015-07-01 00:00:00.001')
+ """.stripMargin
+ )
+ // scalastyle:on
+ conn.commit()
+ }
+
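+ /**
+ * Runs `body` with a table name built from `namePrefix` and the random suffix, dropping that
+ * table afterwards even if the body throws.
+ */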
+ protected def withTempRedshiftTable[T](namePrefix: String)(body: String => T): T = {
+ val tableName = s"$namePrefix$randomSuffix"
+ try {
+ body(tableName)
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+
+ /**
+ * Save the given DataFrame to Redshift, then load the results back into a DataFrame and check
+ * that the returned DataFrame matches the one that we saved.
+ *
+ * @param tableName the table name to use
+ * @param df the DataFrame to save
+ * @param expectedSchemaAfterLoad if specified, the expected schema after loading the data back
+ * from Redshift. This should be used in cases where you expect
+ * the schema to differ due to reasons like case-sensitivity.
+ * @param saveMode the [[SaveMode]] to use when writing data back to Redshift
+ */
+ def testRoundtripSaveAndLoad(
+ tableName: String,
+ df: DataFrame,
+ expectedSchemaAfterLoad: Option[StructType] = None,
+ saveMode: SaveMode = SaveMode.ErrorIfExists): Unit = {
+ try {
+ write(df)
+ .option("dbtable", tableName)
+ .mode(saveMode)
+ .save()
+ // Check that the table exists. It appears that creating a table in one connection then
+ // immediately querying for existence from another connection may result in spurious "table
+ // doesn't exist" errors; this caused the "save with all empty partitions" test to become
+ // flaky (see #146). To work around this, add a small sleep and check again:
+ if (!DefaultJDBCWrapper.tableExists(conn, tableName)) {
+ Thread.sleep(1000)
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ }
+ val loadedDf = read.option("dbtable", tableName).load()
+ assert(loadedDf.schema === expectedSchemaAfterLoad.getOrElse(df.schema))
+ checkAnswer(loadedDf, df.collect())
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+}
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/PostgresDriverIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/PostgresDriverIntegrationSuite.scala
new file mode 100755
index 0000000000000..3513ec48d56f7
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/PostgresDriverIntegrationSuite.scala
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.tags.ExtendedRedshiftTest
+
+/**
+ * Basic integration tests with the Postgres JDBC driver.
+ */
+@ExtendedRedshiftTest
+class PostgresDriverIntegrationSuite extends IntegrationSuiteBase {
+
+ override def jdbcUrl: String = {
+ super.jdbcUrl.replace("jdbc:redshift", "jdbc:postgresql")
+ }
+
+ test("postgresql driver takes precedence for jdbc:postgresql:// URIs") {
+ val conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None)
+ try {
+ // TODO(josh): this is slightly different from what was done in open-source spark-redshift,
+ // due to a conflicting PG driver being pulled in via transitive Spark test deps. We should
+ // consider removing Postgres driver support entirely in our Databricks-internal version.
+ assert(conn.getClass.getName === "org.postgresql.jdbc.PgConnection")
+ } finally {
+ conn.close()
+ }
+ }
+
+ test("roundtrip save and load") {
+ conn.setAutoCommit(false) // TODO(josh): Hack needed due to different PG driver version
+ val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1),
+ StructType(StructField("foo", IntegerType) :: Nil))
+ testRoundtripSaveAndLoad(s"save_with_one_empty_partition_$randomSuffix", df)
+ }
+}
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftCredentialsInConfIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftCredentialsInConfIntegrationSuite.scala
new file mode 100755
index 0000000000000..cde68da81d1a7
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftCredentialsInConfIntegrationSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.tags.ExtendedRedshiftTest
+
+/**
+ * This suite performs basic integration tests where the Redshift credentials have been
+ * specified via `spark-redshift`'s configuration rather than as part of the JDBC URL.
+ */
+@ExtendedRedshiftTest
+class RedshiftCredentialsInConfIntegrationSuite extends IntegrationSuiteBase {
+
+ test("roundtrip save and load") {
+ val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1),
+ StructType(StructField("foo", IntegerType) :: Nil))
+ val tableName = s"roundtrip_save_and_load_$randomSuffix"
+ try {
+ write(df)
+ .option("url", AWS_REDSHIFT_JDBC_URL)
+ .option("user", AWS_REDSHIFT_USER)
+ .option("password", AWS_REDSHIFT_PASSWORD)
+ .option("dbtable", tableName)
+ .save()
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ val loadedDf = read
+ .option("url", AWS_REDSHIFT_JDBC_URL)
+ .option("user", AWS_REDSHIFT_USER)
+ .option("password", AWS_REDSHIFT_PASSWORD)
+ .option("dbtable", tableName)
+ .load()
+ assert(loadedDf.schema === df.schema)
+ checkAnswer(loadedDf, df.collect())
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+
+}
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftReadSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftReadSuite.scala
new file mode 100755
index 0000000000000..dfd98281043f5
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftReadSuite.scala
@@ -0,0 +1,243 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.types.LongType
+import org.apache.spark.tags.ExtendedRedshiftTest
+
+/**
+ * End-to-end tests of functionality which only impacts the read path (e.g. filter pushdown).
+ */
+@ExtendedRedshiftTest
+class RedshiftReadSuite extends IntegrationSuiteBase {
+
+ private val test_table: String = s"read_suite_test_table_$randomSuffix"
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ conn.prepareStatement(s"drop table if exists $test_table").executeUpdate()
+ conn.commit()
+ createTestDataInRedshift(test_table)
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ conn.prepareStatement(s"drop table if exists $test_table").executeUpdate()
+ conn.commit()
+ } finally {
+ super.afterAll()
+ }
+ }
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ read.option("dbtable", test_table).load().createOrReplaceTempView("test_table")
+ }
+
+ test("DefaultSource can load Redshift UNLOAD output to a DataFrame") {
+ checkAnswer(
+ sqlContext.sql("select * from test_table"),
+ TestUtils.expectedData)
+ }
+
+ test("count() on DataFrame created from a Redshift table") {
+ checkAnswer(
+ sqlContext.sql("select count(*) from test_table"),
+ Seq(Row(TestUtils.expectedData.length))
+ )
+ }
+
+ test("count() on DataFrame created from a Redshift query") {
+ val loadedDf =
+ // scalastyle:off
+ read.option("query", s"select * from $test_table where teststring = 'Unicode''s樂趣'").load()
+ // scalastyle:on
+ checkAnswer(
+ loadedDf.selectExpr("count(*)"),
+ Seq(Row(1))
+ )
+ }
+
+ test("backslashes in queries/subqueries are escaped (regression test for #215)") {
+ val loadedDf =
+ read.option("query", s"select replace(teststring, '\\\\', '') as col from $test_table").load()
+ checkAnswer(
+ loadedDf.filter("col = 'asdf'"),
+ Seq(Row("asdf"))
+ )
+ }
+
+ test("Can load output when 'dbtable' is a subquery wrapped in parentheses") {
+ // scalastyle:off
+ val query =
+ s"""
+ |(select testbyte, testbool
+ |from $test_table
+ |where testbool = true
+ | and teststring = 'Unicode''s樂趣'
+ | and testdouble = 1234152.12312498
+ | and testfloat = 1.0
+ | and testint = 42)
+ """.stripMargin
+ // scalastyle:on
+ checkAnswer(read.option("dbtable", query).load(), Seq(Row(1, true)))
+ }
+
+ test("Can load output when 'query' is specified instead of 'dbtable'") {
+ // scalastyle:off
+ val query =
+ s"""
+ |select testbyte, testbool
+ |from $test_table
+ |where testbool = true
+ | and teststring = 'Unicode''s樂趣'
+ | and testdouble = 1234152.12312498
+ | and testfloat = 1.0
+ | and testint = 42
+ """.stripMargin
+ // scalastyle:on
+ checkAnswer(read.option("query", query).load(), Seq(Row(1, true)))
+ }
+
+ test("Can load output of Redshift aggregation queries") {
+ checkAnswer(
+ read.option("query", s"select testbool, count(*) from $test_table group by testbool").load(),
+ Seq(Row(true, 1), Row(false, 2), Row(null, 2)))
+ }
+
+ test("multiple scans on same table") {
+ // .rdd() forces the first query to be unloaded from Redshift
+ val rdd1 = sqlContext.sql("select testint from test_table").rdd
+ // Similarly, this also forces an unload:
+ sqlContext.sql("select testdouble from test_table").rdd
+ // If the unloads were performed into the same directory then this call would fail: the
+ // second unload would have overwritten the integers with doubles, so we'd get
+ // a NumberFormatException.
+ rdd1.count()
+ }
+
+ test("DefaultSource supports simple column filtering") {
+ checkAnswer(
+ sqlContext.sql("select testbyte, testbool from test_table"),
+ Seq(
+ Row(null, null),
+ Row(0.toByte, null),
+ Row(0.toByte, false),
+ Row(1.toByte, false),
+ Row(1.toByte, true)))
+ }
+
+ test("query with pruned and filtered scans") {
+ // scalastyle:off
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |select testbyte, testbool
+ |from test_table
+ |where testbool = true
+ | and teststring = "Unicode's樂趣"
+ | and testdouble = 1234152.12312498
+ | and testfloat = 1.0
+ | and testint = 42
+ """.stripMargin),
+ Seq(Row(1, true)))
+ // scalastyle:on
+ }
+
+ test("RedshiftRelation implements Spark 1.6+'s unhandledFilters API") {
+ assume(org.apache.spark.SPARK_VERSION.take(3) >= "1.6")
+ val df = sqlContext.sql("select testbool from test_table where testbool = true")
+ val physicalPlan = df.queryExecution.sparkPlan
+ physicalPlan.collectFirst { case f: execution.FilterExec => f }.foreach { filter =>
+ fail(s"Filter should have been eliminated:\n${df.queryExecution}")
+ }
+ }
+
+ test("filtering based on date constants (regression test for #152)") {
+ val date = TestUtils.toDate(year = 2015, zeroBasedMonth = 6, date = 3)
+ val df = sqlContext.sql("select testdate from test_table")
+
+ checkAnswer(df.filter(df("testdate") === date), Seq(Row(date)))
+ // This query failed in Spark 1.6.0 but not in earlier versions. It looks like 1.6.0 performs
+ // constant-folding, whereas earlier Spark versions would preserve the cast which prevented
+ // filter pushdown.
+ checkAnswer(df.filter("testdate = to_date('2015-07-03')"), Seq(Row(date)))
+ }
+
+ test("filtering based on timestamp constants (regression test for #152)") {
+ val timestamp = TestUtils.toTimestamp(2015, zeroBasedMonth = 6, 1, 0, 0, 0, 1)
+ val df = sqlContext.sql("select testtimestamp from test_table")
+
+ checkAnswer(df.filter(df("testtimestamp") === timestamp), Seq(Row(timestamp)))
+ // This query failed in Spark 1.6.0 but not in earlier versions. It looks like 1.6.0 performs
+ // constant-folding, whereas earlier Spark versions would preserve the cast which prevented
+ // filter pushdown.
+ checkAnswer(df.filter("testtimestamp = '2015-07-01 00:00:00.001'"), Seq(Row(timestamp)))
+ }
+
+ test("read special float values (regression test for #261)") {
+ val tableName = s"roundtrip_special_float_values_$randomSuffix"
+ try {
+ conn.createStatement().executeUpdate(
+ s"CREATE TABLE $tableName (x real)")
+ conn.createStatement().executeUpdate(
+ s"INSERT INTO $tableName VALUES ('NaN'), ('Infinity'), ('-Infinity')")
+ conn.commit()
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ // Due to #98, we use Double here instead of float:
+ checkAnswer(
+ read.option("dbtable", tableName).load(),
+ Seq(Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity).map(x => Row.apply(x)))
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+
+ test("read special double values (regression test for #261)") {
+ val tableName = s"roundtrip_special_double_values_$randomSuffix"
+ try {
+ conn.createStatement().executeUpdate(
+ s"CREATE TABLE $tableName (x double precision)")
+ conn.createStatement().executeUpdate(
+ s"INSERT INTO $tableName VALUES ('NaN'), ('Infinity'), ('-Infinity')")
+ conn.commit()
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ checkAnswer(
+ read.option("dbtable", tableName).load(),
+ Seq(Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity).map(x => Row.apply(x)))
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+
+ test("read records containing escaped characters") {
+ withTempRedshiftTable("records_with_escaped_characters") { tableName =>
+ conn.createStatement().executeUpdate(
+ s"CREATE TABLE $tableName (x text)")
+ conn.createStatement().executeUpdate(
+ s"""INSERT INTO $tableName VALUES ('a\\nb'), ('\\\\'), ('"')""")
+ conn.commit()
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ checkAnswer(
+ read.option("dbtable", tableName).load(),
+ Seq("a\nb", "\\", "\"").map(x => Row.apply(x)))
+ }
+ }
+
+ test("read result of approximate count(distinct) query (#300)") {
+ val df = read
+ .option("query", s"select approximate count(distinct testbool) as c from $test_table")
+ .load()
+ assert(df.schema.fields(0).dataType === LongType)
+ }
+}
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftWriteSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftWriteSuite.scala
new file mode 100755
index 0000000000000..a00fcf4d13823
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftWriteSuite.scala
@@ -0,0 +1,164 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.sql.SQLException
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.types._
+import org.apache.spark.tags.ExtendedRedshiftTest
+
+/**
+ * End-to-end tests of functionality which involves writing to Redshift via the connector.
+ */
+abstract class BaseRedshiftWriteSuite extends IntegrationSuiteBase {
+
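+ /** The intermediate format used when staging data in S3 (e.g. "AVRO" or "CSV"); set by subclasses. */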
+ protected val tempformat: String
+
+ override protected def write(df: DataFrame): DataFrameWriter[Row] =
+ super.write(df).option("tempformat", tempformat)
+
+ test("roundtrip save and load") {
+ // This test can be simplified once #98 is fixed.
+ val tableName = s"roundtrip_save_and_load_$randomSuffix"
+ try {
+ write(
+ sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema))
+ .option("dbtable", tableName)
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData)
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+
+ test("roundtrip save and load with uppercase column names") {
+ testRoundtripSaveAndLoad(
+ s"roundtrip_write_and_read_with_uppercase_column_names_$randomSuffix",
+ sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+ StructType(StructField("A", IntegerType) :: Nil)),
+ expectedSchemaAfterLoad = Some(StructType(StructField("a", IntegerType) :: Nil)))
+ }
+
+ test("save with column names that are reserved words") {
+ testRoundtripSaveAndLoad(
+ s"save_with_column_names_that_are_reserved_words_$randomSuffix",
+ sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+ StructType(StructField("table", IntegerType) :: Nil)))
+ }
+
+ test("save with one empty partition (regression test for #96)") {
+ val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 2),
+ StructType(StructField("foo", IntegerType) :: Nil))
+ assert(df.rdd.glom.collect() === Array(Array.empty[Row], Array(Row(1))))
+ testRoundtripSaveAndLoad(s"save_with_one_empty_partition_$randomSuffix", df)
+ }
+
+ test("save with all empty partitions (regression test for #96)") {
+ val df = sqlContext.createDataFrame(sc.parallelize(Seq.empty[Row], 2),
+ StructType(StructField("foo", IntegerType) :: Nil))
+ assert(df.rdd.glom.collect() === Array(Array.empty[Row], Array.empty[Row]))
+ testRoundtripSaveAndLoad(s"save_with_all_empty_partitions_$randomSuffix", df)
+ // Now try overwriting that table. Although the new table is empty, it should still overwrite
+ // the existing table.
+ val df2 = df.withColumnRenamed("foo", "bar")
+ testRoundtripSaveAndLoad(
+ s"save_with_all_empty_partitions_$randomSuffix", df2, saveMode = SaveMode.Overwrite)
+ }
+
+ test("informative error message when saving a table with string that is longer than max length") {
+ val tableName = s"error_message_when_string_too_long_$randomSuffix"
+ try {
+ val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 512))),
+ StructType(StructField("A", StringType) :: Nil))
+ val e = intercept[SQLException] {
+ write(df)
+ .option("dbtable", tableName)
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+ }
+ assert(e.getMessage.contains("while loading data into Redshift"))
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+
+ test("full timestamp precision is preserved in loads (regression test for #214)") {
+ val timestamps = Seq(
+ TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 1),
+ TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 10),
+ TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 100),
+ TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 1000))
+ testRoundtripSaveAndLoad(
+ s"full_timestamp_precision_is_preserved$randomSuffix",
+ sqlContext.createDataFrame(sc.parallelize(timestamps.map(Row(_))),
+ StructType(StructField("ts", TimestampType) :: Nil))
+ )
+ }
+}
+
+@ExtendedRedshiftTest
+class AvroRedshiftWriteSuite extends BaseRedshiftWriteSuite {
+ override protected val tempformat: String = "AVRO"
+
+ test("informative error message when saving with column names that contain spaces (#84)") {
+ intercept[IllegalArgumentException] {
+ testRoundtripSaveAndLoad(
+ s"error_when_saving_column_name_with_spaces_$randomSuffix",
+ sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+ StructType(StructField("column name with spaces", IntegerType) :: Nil)))
+ }
+ }
+}
+
+@ExtendedRedshiftTest
+class CSVRedshiftWriteSuite extends BaseRedshiftWriteSuite {
+ override protected val tempformat: String = "CSV"
+
+ test("save with column names that contain spaces (#84)") {
+ testRoundtripSaveAndLoad(
+ s"save_with_column_names_that_contain_spaces_$randomSuffix",
+ sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+ StructType(StructField("column name with spaces", IntegerType) :: Nil)))
+ }
+}
+
+@ExtendedRedshiftTest
+class CSVGZIPRedshiftWriteSuite extends IntegrationSuiteBase {
+ // Note: we purposely don't inherit from BaseRedshiftWriteSuite because we're only interested in
+ // testing basic functionality of the GZIP code; the rest of the write path should be unaffected
+ // by compression here.
+
+ override protected def write(df: DataFrame): DataFrameWriter[Row] =
+ super.write(df).option("tempformat", "CSV GZIP")
+
+ test("roundtrip save and load") {
+ // This test can be simplified once #98 is fixed.
+ val tableName = s"roundtrip_save_and_load_$randomSuffix"
+ try {
+ write(
+ sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema))
+ .option("dbtable", tableName)
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData)
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+}
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/STSIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/STSIntegrationSuite.scala
new file mode 100755
index 0000000000000..c86081799e7ca
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/STSIntegrationSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import com.amazonaws.auth.BasicAWSCredentials
+import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClient
+import com.amazonaws.services.securitytoken.model.AssumeRoleRequest
+
+import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.tags.ExtendedRedshiftTest
+
+/**
+ * Integration tests for accessing S3 using Amazon Security Token Service (STS) credentials.
+ */
+@ExtendedRedshiftTest
+class STSIntegrationSuite extends IntegrationSuiteBase {
+
+ private val STS_ROLE_ARN: String = loadConfigFromEnv("STS_ROLE_ARN")
+ private var STS_ACCESS_KEY_ID: String = _
+ private var STS_SECRET_ACCESS_KEY: String = _
+ private var STS_SESSION_TOKEN: String = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ val awsCredentials = new BasicAWSCredentials(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY)
+ val stsClient = new AWSSecurityTokenServiceClient(awsCredentials)
+ val assumeRoleRequest = new AssumeRoleRequest()
+ assumeRoleRequest.setDurationSeconds(900) // this is the minimum supported duration
+ assumeRoleRequest.setRoleArn(STS_ROLE_ARN)
+ assumeRoleRequest.setRoleSessionName(s"spark-$randomSuffix")
+ val creds = stsClient.assumeRole(assumeRoleRequest).getCredentials
+ STS_ACCESS_KEY_ID = creds.getAccessKeyId
+ STS_SECRET_ACCESS_KEY = creds.getSecretAccessKey
+ STS_SESSION_TOKEN = creds.getSessionToken
+ }
+
+ test("roundtrip save and load") {
+ val tableName = s"roundtrip_save_and_load$randomSuffix"
+ val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+ StructType(StructField("a", IntegerType) :: Nil))
+ try {
+ write(df)
+ .option("dbtable", tableName)
+ .option("forward_spark_s3_credentials", "false")
+ .option("temporary_aws_access_key_id", STS_ACCESS_KEY_ID)
+ .option("temporary_aws_secret_access_key", STS_SECRET_ACCESS_KEY)
+ .option("temporary_aws_session_token", STS_SESSION_TOKEN)
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ val loadedDf = read
+ .option("dbtable", tableName)
+ .option("forward_spark_s3_credentials", "false")
+ .option("temporary_aws_access_key_id", STS_ACCESS_KEY_ID)
+ .option("temporary_aws_secret_access_key", STS_SECRET_ACCESS_KEY)
+ .option("temporary_aws_session_token", STS_SESSION_TOKEN)
+ .load()
+ assert(loadedDf.schema.length === 1)
+ assert(loadedDf.columns === Seq("a"))
+ checkAnswer(loadedDf, Seq(Row(1)))
+ } finally {
+ conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+ conn.commit()
+ }
+ }
+}
diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/SaveModeIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/SaveModeIntegrationSuite.scala
new file mode 100755
index 0000000000000..0cdc8ba893171
--- /dev/null
+++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/SaveModeIntegrationSuite.scala
@@ -0,0 +1,122 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.tags.ExtendedRedshiftTest
+
+/**
+ * End-to-end tests of [[SaveMode]] behavior.
+ */
+@ExtendedRedshiftTest
+class SaveModeIntegrationSuite extends IntegrationSuiteBase {
+ test("SaveMode.Overwrite with schema-qualified table name (#97)") {
+ withTempRedshiftTable("overwrite_schema_qualified_table_name") { tableName =>
+ val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+ StructType(StructField("a", IntegerType) :: Nil))
+ // Ensure that the table exists:
+ write(df)
+ .option("dbtable", tableName)
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+ assert(DefaultJDBCWrapper.tableExists(conn, s"PUBLIC.$tableName"))
+ // Try overwriting that table while using the schema-qualified table name:
+ write(df)
+ .option("dbtable", s"PUBLIC.$tableName")
+ .mode(SaveMode.Overwrite)
+ .save()
+ }
+ }
+
+ test("SaveMode.Overwrite with non-existent table") {
+ testRoundtripSaveAndLoad(
+ s"overwrite_non_existent_table$randomSuffix",
+ sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+ StructType(StructField("a", IntegerType) :: Nil)),
+ saveMode = SaveMode.Overwrite)
+ }
+
+ test("SaveMode.Overwrite with existing table") {
+ withTempRedshiftTable("overwrite_existing_table") { tableName =>
+ // Create a table to overwrite
+ write(sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+ StructType(StructField("a", IntegerType) :: Nil)))
+ .option("dbtable", tableName)
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+
+ val overwritingDf =
+ sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema)
+ write(overwritingDf)
+ .option("dbtable", tableName)
+ .mode(SaveMode.Overwrite)
+ .save()
+
+ assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+ checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData)
+ }
+ }
+
+ // TODO: test an overwrite that fails.
+
+ test("Append SaveMode doesn't destroy existing data") {
+ withTempRedshiftTable("append_doesnt_destroy_existing_data") { tableName =>
+ createTestDataInRedshift(tableName)
+ val extraData = Seq(
+ Row(2.toByte, false, null, -1234152.12312498, 100000.0f, null, 1239012341823719L,
+ 24.toShort, "___|_123", null))
+
+ write(sqlContext.createDataFrame(sc.parallelize(extraData), TestUtils.testSchema))
+ .option("dbtable", tableName)
+ .mode(SaveMode.Append)
+ .saveAsTable(tableName)
+
+ checkAnswer(
+ sqlContext.sql(s"select * from $tableName"),
+ TestUtils.expectedData ++ extraData)
+ }
+ }
+
+ test("Respect SaveMode.ErrorIfExists when table exists") {
+ withTempRedshiftTable("respect_savemode_error_if_exists") { tableName =>
+ val rdd = sc.parallelize(TestUtils.expectedData)
+ val df = sqlContext.createDataFrame(rdd, TestUtils.testSchema)
+ createTestDataInRedshift(tableName) // to ensure that the table already exists
+
+ // Check that SaveMode.ErrorIfExists throws an exception
+ val e = intercept[Exception] {
+ write(df)
+ .option("dbtable", tableName)
+ .mode(SaveMode.ErrorIfExists)
+ .saveAsTable(tableName)
+ }
+ assert(e.getMessage.contains("exists"))
+ }
+ }
+
+ test("Do nothing when table exists if SaveMode = Ignore") {
+ withTempRedshiftTable("do_nothing_when_savemode_ignore") { tableName =>
+ val rdd = sc.parallelize(TestUtils.expectedData.drop(1))
+ val df = sqlContext.createDataFrame(rdd, TestUtils.testSchema)
+ createTestDataInRedshift(tableName) // to ensure that the table already exists
+ write(df)
+ .option("dbtable", tableName)
+ .mode(SaveMode.Ignore)
+ .saveAsTable(tableName)
+
+ // Check that SaveMode.Ignore does nothing
+ checkAnswer(
+ sqlContext.sql(s"select * from $tableName"),
+ TestUtils.expectedData)
+ }
+ }
+}
diff --git a/external/redshift/pom.xml b/external/redshift/pom.xml
new file mode 100644
index 0000000000000..6ea5e4318f5d3
--- /dev/null
+++ b/external/redshift/pom.xml
@@ -0,0 +1,145 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Copyright (C) 2016 Databricks, Inc.
+  ~
+  ~ Portions of this software incorporate or are derived from software contained within Apache Spark,
+  ~ and this modified software differs from the Apache Spark software provided under the Apache
+  ~ License, Version 2.0, a copy of which you may obtain at
+  ~ http://www.apache.org/licenses/LICENSE-2.0
+  -->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.11</artifactId>
+    <version>2.1.0</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <groupId>com.databricks</groupId>
+  <artifactId>spark-redshift_2.11</artifactId>
+  <properties>
+    <sbt.project.name>redshift</sbt.project.name>
+  </properties>
+  <packaging>jar</packaging>
+  <name>Spark Redshift</name>
+  <url>http://spark.apache.org/</url>
+
+  <profiles>
+    <profile>
+      <id>hadoop-2.7</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-aws</artifactId>
+          <version>${hadoop.version}</version>
+          <scope>test</scope>
+        </dependency>
+      </dependencies>
+    </profile>
+  </profiles>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.databricks</groupId>
+      <artifactId>spark-avro_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>aws-java-sdk-core</artifactId>
+      <version>1.9.40</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>aws-java-sdk-s3</artifactId>
+      <version>1.9.40</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>aws-java-sdk-sts</artifactId>
+      <version>1.9.40</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.eclipsesource.minimal-json</groupId>
+      <artifactId>minimal-json</artifactId>
+      <version>0.9.4</version>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-client</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <version>${hadoop.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-hive_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>redshift.jdbc4</artifactId>
+      <version>1.1.7.1007</version>
+      <type>jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+  </build>
+</project>
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/AWSCredentialsUtils.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/AWSCredentialsUtils.scala
new file mode 100755
index 0000000000000..0fb6a307ee644
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/AWSCredentialsUtils.scala
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.net.URI
+
+import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, AWSSessionCredentials, BasicAWSCredentials, DefaultAWSCredentialsProviderChain}
+import com.databricks.spark.redshift.Parameters.MergedParameters
+import org.apache.hadoop.conf.Configuration
+
+private[redshift] object AWSCredentialsUtils {
+
+ /**
+ * Generates a credentials string for use in Redshift COPY and UNLOAD statements.
+ * Favors a configured `aws_iam_role` if available in the parameters.
+ */
+ def getRedshiftCredentialsString(
+ params: MergedParameters,
+ sparkAwsCredentials: AWSCredentials): String = {
+
+ def awsCredsToString(credentials: AWSCredentials): String = {
+ credentials match {
+ case creds: AWSSessionCredentials =>
+ s"aws_access_key_id=${creds.getAWSAccessKeyId};" +
+ s"aws_secret_access_key=${creds.getAWSSecretKey};token=${creds.getSessionToken}"
+ case creds =>
+ s"aws_access_key_id=${creds.getAWSAccessKeyId};" +
+ s"aws_secret_access_key=${creds.getAWSSecretKey}"
+ }
+ }
+ if (params.iamRole.isDefined) {
+ s"aws_iam_role=${params.iamRole.get}"
+ } else if (params.temporaryAWSCredentials.isDefined) {
+ awsCredsToString(params.temporaryAWSCredentials.get.getCredentials)
+ } else if (params.forwardSparkS3Credentials) {
+ awsCredsToString(sparkAwsCredentials)
+ } else {
+ throw new IllegalStateException("No Redshift S3 authentication mechanism was specified")
+ }
+ }
+
+ def staticCredentialsProvider(credentials: AWSCredentials): AWSCredentialsProvider = {
+ new AWSCredentialsProvider {
+ override def getCredentials: AWSCredentials = credentials
+ override def refresh(): Unit = {}
+ }
+ }
+
+ def load(params: MergedParameters, hadoopConfiguration: Configuration): AWSCredentialsProvider = {
+ params.temporaryAWSCredentials.getOrElse(loadFromURI(params.rootTempDir, hadoopConfiguration))
+ }
+
+ private def loadFromURI(
+ tempPath: String,
+ hadoopConfiguration: Configuration): AWSCredentialsProvider = {
+ // scalastyle:off
+ // A good reference on Hadoop's configuration loading / precedence is
+ // https://github.com/apache/hadoop/blob/trunk/hadoop-tools/hadoop-aws/src/site/markdown/tools/hadoop-aws/index.md
+ // scalastyle:on
+ val uri = new URI(tempPath)
+ val uriScheme = uri.getScheme
+
+ uriScheme match {
+ case "s3" | "s3n" | "s3a" =>
+ // This matches what S3A does, with one exception: we don't support anonymous credentials.
+ // First, try to parse from URI:
+ Option(uri.getUserInfo).flatMap { userInfo =>
+ if (userInfo.contains(":")) {
+ val Array(accessKey, secretKey) = userInfo.split(":")
+ Some(staticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)))
+ } else {
+ None
+ }
+ }.orElse {
+ // Next, try to read from configuration
+ val accessKeyConfig = if (uriScheme == "s3a") "access.key" else "awsAccessKeyId"
+ val secretKeyConfig = if (uriScheme == "s3a") "secret.key" else "awsSecretAccessKey"
+
+ val accessKey = hadoopConfiguration.get(s"fs.$uriScheme.$accessKeyConfig", null)
+ val secretKey = hadoopConfiguration.get(s"fs.$uriScheme.$secretKeyConfig", null)
+ if (accessKey != null && secretKey != null) {
+ Some(staticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)))
+ } else {
+ None
+ }
+ }.getOrElse {
+ // Finally, fall back on the instance profile provider
+ new DefaultAWSCredentialsProviderChain()
+ }
+ case other =>
+ throw new IllegalArgumentException(s"Unrecognized scheme $other; expected s3, s3n, or s3a")
+ }
+ }
+}
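
For context, a minimal sketch of the credential strings that `getRedshiftCredentialsString` emits when `forward_spark_s3_credentials` is enabled. The helpers are `private[redshift]`, so this only compiles from code in the same package (e.g. a test), and the URL, bucket, and key values are placeholders:

```scala
import com.amazonaws.auth.{BasicAWSCredentials, BasicSessionCredentials}

// Placeholder connection settings; forward_spark_s3_credentials selects the branch that
// forwards the Spark-side S3 keys to Redshift.
val params = Parameters.mergeParameters(Map(
  "url" -> "jdbc:redshift://example.host:5439/db?user=u&password=p",
  "tempdir" -> "s3n://my-bucket/tmp/",
  "dbtable" -> "events",
  "forward_spark_s3_credentials" -> "true"))

// Plain keys produce the two-field form:
AWSCredentialsUtils.getRedshiftCredentialsString(
  params, new BasicAWSCredentials("ACCESS", "SECRET"))
// "aws_access_key_id=ACCESS;aws_secret_access_key=SECRET"

// Session credentials additionally carry the session token:
AWSCredentialsUtils.getRedshiftCredentialsString(
  params, new BasicSessionCredentials("ACCESS", "SECRET", "TOKEN"))
// "aws_access_key_id=ACCESS;aws_secret_access_key=SECRET;token=TOKEN"
```
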
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/Conversions.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/Conversions.scala
new file mode 100755
index 0000000000000..a83411db48c85
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/Conversions.scala
@@ -0,0 +1,118 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.sql.Timestamp
+import java.text.{DecimalFormat, DecimalFormatSymbols, SimpleDateFormat}
+import java.util.Locale
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.types._
+
+/**
+ * Data type conversions for Redshift unloaded data
+ */
+private[redshift] object Conversions {
+
+ /**
+ * Parse a boolean using Redshift's UNLOAD bool syntax
+ */
+ private def parseBoolean(s: String): Boolean = {
+ if (s == "t") true
+ else if (s == "f") false
+ else throw new IllegalArgumentException(s"Expected 't' or 'f' but got '$s'")
+ }
+
+ /**
+ * Formatter for writing decimals unloaded from Redshift.
+ *
+ * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this
+ * DecimalFormat across threads.
+ */
+ def createRedshiftDecimalFormat(): DecimalFormat = {
+ val format = new DecimalFormat()
+ format.setParseBigDecimal(true)
+ format.setDecimalFormatSymbols(new DecimalFormatSymbols(Locale.US))
+ format
+ }
+
+ /**
+ * Formatter for parsing strings exported from Redshift DATE columns.
+ *
+ * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this
+ * SimpleDateFormat across threads.
+ */
+ def createRedshiftDateFormat(): SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd")
+
+ /**
+ * Formatter for formatting timestamps for insertion into Redshift TIMESTAMP columns.
+ *
+ * This formatter should not be used to parse timestamps returned from Redshift UNLOAD commands;
+ * instead, use [[Timestamp.valueOf()]].
+ *
+ * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this
+ * SimpleDateFormat across threads.
+ */
+ def createRedshiftTimestampFormat(): SimpleDateFormat = {
+ new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS")
+ }
+
+ /**
+ * Return a function that will convert arrays of strings conforming to the given schema to Rows.
+ *
+ * Note that instances of this function are NOT thread-safe.
+ */
+ def createRowConverter(schema: StructType): Array[String] => InternalRow = {
+ val dateFormat = createRedshiftDateFormat()
+ val decimalFormat = createRedshiftDecimalFormat()
+ val conversionFunctions: Array[String => Any] = schema.fields.map { field =>
+ field.dataType match {
+ case ByteType => (data: String) => java.lang.Byte.parseByte(data)
+ case BooleanType => (data: String) => parseBoolean(data)
+ case DateType => (data: String) => new java.sql.Date(dateFormat.parse(data).getTime)
+ case DoubleType => (data: String) => data match {
+ case "nan" => Double.NaN
+ case "inf" => Double.PositiveInfinity
+ case "-inf" => Double.NegativeInfinity
+ case _ => java.lang.Double.parseDouble(data)
+ }
+ case FloatType => (data: String) => data match {
+ case "nan" => Float.NaN
+ case "inf" => Float.PositiveInfinity
+ case "-inf" => Float.NegativeInfinity
+ case _ => java.lang.Float.parseFloat(data)
+ }
+ case dt: DecimalType =>
+ (data: String) => decimalFormat.parse(data).asInstanceOf[java.math.BigDecimal]
+ case IntegerType => (data: String) => java.lang.Integer.parseInt(data)
+ case LongType => (data: String) => java.lang.Long.parseLong(data)
+ case ShortType => (data: String) => java.lang.Short.parseShort(data)
+ case StringType => (data: String) => data
+ case TimestampType => (data: String) => Timestamp.valueOf(data)
+ case _ => (data: String) => data
+ }
+ }
+ // As a performance optimization, re-use the same mutable row / array:
+ val converted: Array[Any] = Array.fill(schema.length)(null)
+ val externalRow = new GenericRow(converted)
+ val encoder = RowEncoder(schema)
+ (inputRow: Array[String]) => {
+ var i = 0
+ while (i < schema.length) {
+ val data = inputRow(i)
+ converted(i) = if (data == null || data.isEmpty) null else conversionFunctions(i)(data)
+ i += 1
+ }
+ encoder.toRow(externalRow)
+ }
+ }
+}
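
For context, a sketch of how the converter produced by `createRowConverter` is driven; the schema and values are invented, and same-package access is assumed since `Conversions` is `private[redshift]`:

```scala
import org.apache.spark.sql.types._

// Invented schema for a hypothetical unloaded table.
val schema = StructType(Seq(
  StructField("id", LongType),
  StructField("active", BooleanType),
  StructField("score", DoubleType),
  StructField("name", StringType)))

// Not thread-safe: build one converter per task/thread.
val convert = Conversions.createRowConverter(schema)

// Each unloaded record arrives as an Array[String], one element per column. Redshift
// writes booleans as "t"/"f" and special doubles as nan/inf/-inf; null or empty
// strings become SQL NULLs.
val row = convert(Array("42", "t", "-inf", null))
// row is an InternalRow holding [42L, true, Double.NegativeInfinity, null]
```
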
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala
new file mode 100755
index 0000000000000..258255676a860
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala
@@ -0,0 +1,108 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import com.amazonaws.auth.AWSCredentialsProvider
+import com.amazonaws.services.s3.AmazonS3Client
+import org.slf4j.LoggerFactory
+
+import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Redshift Source implementation for Spark SQL
+ */
+class DefaultSource(
+ jdbcWrapper: JDBCWrapper,
+ s3ClientFactory: AWSCredentialsProvider => AmazonS3Client)
+ extends RelationProvider
+ with SchemaRelationProvider
+ with CreatableRelationProvider {
+
+ private val log = LoggerFactory.getLogger(getClass)
+
+ /**
+ * Default constructor required by Data Source API
+ */
+ def this() = this(DefaultJDBCWrapper, awsCredentials => new AmazonS3Client(awsCredentials))
+
+ /**
+ * Create a new RedshiftRelation instance using parameters from Spark SQL DDL. Resolves the schema
+ * using a JDBC connection over the provided URL, which must contain credentials.
+ */
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String]): BaseRelation = {
+ val params = Parameters.mergeParameters(parameters)
+ RedshiftRelation(jdbcWrapper, s3ClientFactory, params, None)(sqlContext)
+ }
+
+ /**
+ * Load a RedshiftRelation using user-provided schema, so no inference over JDBC will be used.
+ */
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: StructType): BaseRelation = {
+ val params = Parameters.mergeParameters(parameters)
+ RedshiftRelation(jdbcWrapper, s3ClientFactory, params, Some(schema))(sqlContext)
+ }
+
+ /**
+ * Creates a Relation instance by first writing the contents of the given DataFrame to Redshift
+ */
+ override def createRelation(
+ sqlContext: SQLContext,
+ saveMode: SaveMode,
+ parameters: Map[String, String],
+ data: DataFrame): BaseRelation = {
+ val params = Parameters.mergeParameters(parameters)
+ val table = params.table.getOrElse {
+ throw new IllegalArgumentException(
+ "For save operations you must specify a Redshift table name with the 'dbtable' parameter")
+ }
+
+ def tableExists: Boolean = {
+ val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
+ try {
+ jdbcWrapper.tableExists(conn, table.toString)
+ } finally {
+ conn.close()
+ }
+ }
+
+ val (doSave, dropExisting) = saveMode match {
+ case SaveMode.Append => (true, false)
+ case SaveMode.Overwrite => (true, true)
+ case SaveMode.ErrorIfExists =>
+ if (tableExists) {
+ sys.error(s"Table $table already exists! (SaveMode is set to ErrorIfExists)")
+ } else {
+ (true, false)
+ }
+ case SaveMode.Ignore =>
+ if (tableExists) {
+ log.info(s"Table $table already exists -- ignoring save request.")
+ (false, false)
+ } else {
+ (true, false)
+ }
+ }
+
+ if (doSave) {
+ val updatedParams = parameters.updated("overwrite", dropExisting.toString)
+ new RedshiftWriter(jdbcWrapper, s3ClientFactory).saveToRedshift(
+ sqlContext, data, saveMode, Parameters.mergeParameters(updatedParams))
+ }
+
+ createRelation(sqlContext, parameters)
+ }
+}
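
For context, an illustrative end-user sketch of reading and writing through this provider; the SparkSession, JDBC URL, table names, and S3 path are placeholders, and the option names are the ones validated in Parameters.scala below:

```scala
import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().appName("redshift-example").getOrCreate()

val events = spark.read
  .format("com.databricks.spark.redshift")
  .option("url", "jdbc:redshift://example.host:5439/db?user=u&password=p")
  .option("dbtable", "events")
  .option("tempdir", "s3n://my-bucket/tmp/")
  .option("forward_spark_s3_credentials", "true")
  .load()

events.write
  .format("com.databricks.spark.redshift")
  .option("url", "jdbc:redshift://example.host:5439/db?user=u&password=p")
  .option("dbtable", "events_copy")
  .option("tempdir", "s3n://my-bucket/tmp/")
  .option("forward_spark_s3_credentials", "true")
  .mode(SaveMode.Overwrite) // exercises the Overwrite branch of createRelation above
  .save()
```
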
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/FilterPushdown.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/FilterPushdown.scala
new file mode 100755
index 0000000000000..786b3f6257f6e
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/FilterPushdown.scala
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+
+/**
+ * Helper methods for pushing filters into Redshift queries.
+ */
+private[redshift] object FilterPushdown {
+ /**
+ * Build a SQL WHERE clause for the given filters. If a filter cannot be pushed down then no
+ * condition will be added to the WHERE clause. If none of the filters can be pushed down then
+ * an empty string will be returned.
+ *
+ * @param schema the schema of the table being queried
+ * @param filters an array of filters, the conjunction of which is the filter condition for the
+ * scan.
+ */
+ def buildWhereClause(schema: StructType, filters: Seq[Filter]): String = {
+ val filterExpressions = filters.flatMap(f => buildFilterExpression(schema, f)).mkString(" AND ")
+ if (filterExpressions.isEmpty) "" else "WHERE " + filterExpressions
+ }
+
+ /**
+ * Attempt to convert the given filter into a SQL expression. Returns None if the expression
+ * could not be converted.
+ */
+ def buildFilterExpression(schema: StructType, filter: Filter): Option[String] = {
+ def buildComparison(attr: String, value: Any, comparisonOp: String): Option[String] = {
+ getTypeForAttribute(schema, attr).map { dataType =>
+ val sqlEscapedValue: String = dataType match {
+ case StringType => s"\\'${value.toString.replace("'", "\\'\\'")}\\'"
+ case DateType => s"\\'${value.asInstanceOf[Date]}\\'"
+ case TimestampType => s"\\'${value.asInstanceOf[Timestamp]}\\'"
+ case _ => value.toString
+ }
+ s""""$attr" $comparisonOp $sqlEscapedValue"""
+ }
+ }
+
+ filter match {
+ case EqualTo(attr, value) => buildComparison(attr, value, "=")
+ case LessThan(attr, value) => buildComparison(attr, value, "<")
+ case GreaterThan(attr, value) => buildComparison(attr, value, ">")
+ case LessThanOrEqual(attr, value) => buildComparison(attr, value, "<=")
+ case GreaterThanOrEqual(attr, value) => buildComparison(attr, value, ">=")
+ case IsNotNull(attr) =>
+ getTypeForAttribute(schema, attr).map(dataType => s""""$attr" IS NOT NULL""")
+ case IsNull(attr) =>
+ getTypeForAttribute(schema, attr).map(dataType => s""""$attr" IS NULL""")
+ case _ => None
+ }
+ }
+
+ /**
+ * Use the given schema to look up the attribute's data type. Returns None if the attribute could
+ * not be resolved.
+ */
+ private def getTypeForAttribute(schema: StructType, attribute: String): Option[DataType] = {
+ if (schema.fieldNames.contains(attribute)) {
+ Some(schema(attribute).dataType)
+ } else {
+ None
+ }
+ }
+}
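
For context, a sketch of the clauses this helper produces; the schema and filters are invented, and same-package access is assumed. The backslash-escaped quotes are intentional: the clause is presumably embedded inside Redshift's single-quoted UNLOAD ('...') statement, so the quoting must survive a second level of parsing:

```scala
import org.apache.spark.sql.sources.{EqualTo, GreaterThan}
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType)))

FilterPushdown.buildWhereClause(
  schema, Seq(EqualTo("name", "O'Brien"), GreaterThan("age", 21)))
// WHERE "name" = \'O\'\'Brien\' AND "age" > 21

// Filters on unknown columns (or unsupported filter types) simply drop out:
FilterPushdown.buildWhereClause(schema, Seq(EqualTo("no_such_column", 1)))
// "" (empty string, i.e. no WHERE clause)
```
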
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/Parameters.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/Parameters.scala
new file mode 100755
index 0000000000000..8789b5c8ae2d6
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -0,0 +1,282 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import com.amazonaws.auth.{AWSCredentialsProvider, BasicSessionCredentials}
+
+/**
+ * All user-specifiable parameters for spark-redshift, along with their validation rules and
+ * defaults.
+ */
+private[redshift] object Parameters {
+
+ val DEFAULT_PARAMETERS: Map[String, String] = Map(
+ // Notes:
+ // * tempdir, dbtable and url have no default and they *must* be provided
+ // * sortkeyspec has no default, but is optional
+ // * distkey has no default, but is optional unless using diststyle KEY
+ // * jdbcdriver has no default, but is optional
+
+ "forward_spark_s3_credentials" -> "false",
+ "tempformat" -> "AVRO",
+ "csvnullstring" -> "@NULL@",
+ "overwrite" -> "false",
+ "diststyle" -> "EVEN",
+ "usestagingtable" -> "true",
+ "preactions" -> ";",
+ "postactions" -> ";"
+ )
+
+ val VALID_TEMP_FORMATS = Set("AVRO", "CSV", "CSV GZIP")
+
+ /**
+ * Merge user parameters with the defaults, preferring user parameters if specified
+ */
+ def mergeParameters(userParameters: Map[String, String]): MergedParameters = {
+ if (!userParameters.contains("tempdir")) {
+ throw new IllegalArgumentException("'tempdir' is required for all Redshift loads and saves")
+ }
+ if (userParameters.contains("tempformat") &&
+ !VALID_TEMP_FORMATS.contains(userParameters("tempformat").toUpperCase)) {
+ throw new IllegalArgumentException(
+ s"""Invalid temp format: ${userParameters("tempformat")}; """ +
+ s"valid formats are: ${VALID_TEMP_FORMATS.mkString(", ")}")
+ }
+ if (!userParameters.contains("url")) {
+ throw new IllegalArgumentException("A JDBC URL must be provided with 'url' parameter")
+ }
+ if (!userParameters.contains("dbtable") && !userParameters.contains("query")) {
+ throw new IllegalArgumentException(
+ "You must specify a Redshift table name with the 'dbtable' parameter or a query with the " +
+ "'query' parameter.")
+ }
+ if (userParameters.contains("dbtable") && userParameters.contains("query")) {
+ throw new IllegalArgumentException(
+ "You cannot specify both the 'dbtable' and 'query' parameters at the same time.")
+ }
+ val credsInURL = userParameters.get("url")
+ .filter(url => url.contains("user=") || url.contains("password="))
+ if (userParameters.contains("user") || userParameters.contains("password")) {
+ if (credsInURL.isDefined) {
+ throw new IllegalArgumentException(
+ "You cannot specify credentials in both the URL and as user/password options")
+ }
+ } else if (credsInURL.isEmpty) {
+ throw new IllegalArgumentException(
+ "You must specify credentials in either the URL or as user/password options")
+ }
+
+ MergedParameters(DEFAULT_PARAMETERS ++ userParameters)
+ }
+
+ /**
+ * Adds validators and accessors to string map
+ */
+ case class MergedParameters(parameters: Map[String, String]) {
+
+ require(temporaryAWSCredentials.isDefined || iamRole.isDefined || forwardSparkS3Credentials,
+ "You must specify a method for authenticating Redshift's connection to S3 (aws_iam_role," +
+ " forward_spark_s3_credentials, or temporary_aws_*. For a discussion of the differences" +
+ " between these options, please see the README.")
+
+ require(Seq(
+ temporaryAWSCredentials.isDefined,
+ iamRole.isDefined,
+ forwardSparkS3Credentials).count(_ == true) == 1,
+ "The aws_iam_role, forward_spark_s3_credentials, and temporary_aws_*. options are " +
+ "mutually-exclusive; please specify only one.")
+
+ /**
+ * A root directory to be used for intermediate data exchange, expected to be on S3, or
+ * somewhere that can be written to and read from by Redshift. Make sure that AWS credentials
+ * are available for S3.
+ */
+ def rootTempDir: String = parameters("tempdir")
+
+ /**
+ * The format in which to save temporary files in S3. Defaults to "AVRO"; the other allowed
+ * values are "CSV" and "CSV GZIP" for CSV and gzipped CSV, respectively.
+ */
+ def tempFormat: String = parameters("tempformat").toUpperCase
+
+ /**
+ * The String value to write for nulls when using CSV.
+ * This should be a value which does not appear in your actual data.
+ */
+ def nullString: String = parameters("csvnullstring")
+
+ /**
+ * Creates a per-query subdirectory in the [[rootTempDir]], with a random UUID.
+ */
+ def createPerQueryTempDir(): String = Utils.makeTempPath(rootTempDir)
+
+ /**
+ * The Redshift table to be used as the target when loading or writing data.
+ */
+ def table: Option[TableName] = parameters.get("dbtable").map(_.trim).flatMap { dbtable =>
+ // We technically allow queries to be passed using `dbtable` as long as they are wrapped
+ // in parentheses. Valid SQL identifiers may contain parentheses but cannot begin with them,
+ // so there is no ambiguity in ignoring subqueries here and leaving their handling up to
+ // the `query` function defined below.
+ if (dbtable.startsWith("(") && dbtable.endsWith(")")) {
+ None
+ } else {
+ Some(TableName.parseFromEscaped(dbtable))
+ }
+ }
+
+ /**
+ * The Redshift query to be used as the target when loading data.
+ */
+ def query: Option[String] = parameters.get("query").orElse {
+ parameters.get("dbtable")
+ .map(_.trim)
+ .filter(t => t.startsWith("(") && t.endsWith(")"))
+ .map(t => t.drop(1).dropRight(1))
+ }
+
+ /**
+ * User and password to be used to authenticate to Redshift
+ */
+ def credentials: Option[(String, String)] = {
+ for (
+ user <- parameters.get("user");
+ password <- parameters.get("password")
+ ) yield (user, password)
+ }
+
+ /**
+ * A JDBC URL, of the format:
+ *
+ * jdbc:subprotocol://host:port/database?user=username&password=password
+ *
+ * Where:
+ * - subprotocol can be postgresql or redshift, depending on which JDBC driver you have loaded.
+ * Note however that one Redshift-compatible driver must be on the classpath and match this
+ * URL.
+ * - host and port should point to the Redshift master node, so security groups and/or VPC will
+ * need to be configured to allow access from the Spark driver
+ * - database identifies a Redshift database name
+ * - user and password are credentials to access the database, which must be embedded in this
+ * URL for JDBC
+ */
+ def jdbcUrl: String = parameters("url")
+
+ /**
+ * The JDBC driver class name. This is used to make sure the driver is registered before
+ * connecting over JDBC.
+ */
+ def jdbcDriver: Option[String] = parameters.get("jdbcdriver")
+
+ /**
+ * Set the Redshift table distribution style, which can be one of: EVEN, KEY or ALL. If you set
+ * it to KEY, you'll also need to use the distkey parameter to set the distribution key.
+ *
+ * Default is EVEN.
+ */
+ def distStyle: Option[String] = parameters.get("diststyle")
+
+ /**
+ * The name of a column in the table to use as the distribution key when using DISTSTYLE KEY.
+ * Not set by default, as default DISTSTYLE is EVEN.
+ */
+ def distKey: Option[String] = parameters.get("distkey")
+
+ /**
+ * A full Redshift SORTKEY specification. For full information, see latest Redshift docs:
+ * http://docs.aws.amazon.com/redshift/latest/dg/r_CREATE_TABLE_NEW.html
+ *
+ * Examples:
+ * SORTKEY (my_sort_column)
+ * COMPOUND SORTKEY (sort_col1, sort_col2)
+ * INTERLEAVED SORTKEY (sort_col1, sort_col2)
+ *
+ * Not set by default - table will be unsorted.
+ *
+ * Note: appending data to a table with a sort key only makes sense if you know that the data
+ * being added will be after the data already in the table according to the sort order. Redshift
+ * does not support random inserts according to sort order, so performance will degrade if you
+ * try this.
+ */
+ def sortKeySpec: Option[String] = parameters.get("sortkeyspec")
+
+ /**
+ * DEPRECATED: see PR #157.
+ *
+ * When true, data is always loaded into a new temporary table when performing an overwrite.
+ * This is to ensure that the whole load process succeeds before dropping any data from
+ * Redshift, which can be useful if, in the event of failures, stale data is better than no data
+ * for your systems.
+ *
+ * Defaults to true.
+ */
+ def useStagingTable: Boolean = parameters("usestagingtable").toBoolean
+
+ /**
+ * Extra options to append to the Redshift COPY command (e.g. "MAXERROR 100").
+ */
+ def extraCopyOptions: String = parameters.get("extracopyoptions").getOrElse("")
+
+ /**
+ * Description of the table, set using the SQL COMMENT command.
+ */
+ def description: Option[String] = parameters.get("description")
+
+ /**
+ * List of semi-colon separated SQL statements to run before write operations.
+ * This can be useful for running DELETE operations to clean up data.
+ *
+ * If the action string contains %s, the table name will be substituted in, in case a staging
+ * table is being used.
+ *
+ * Defaults to empty.
+ */
+ def preActions: Array[String] = parameters("preactions").split(";")
+
+ /**
+ * List of semi-colon separated SQL statements to run after successful write operations.
+ * This can be useful for running GRANT operations to make your new tables readable to other
+ * users and groups.
+ *
+ * If the action string contains %s, the table name will be substituted in, in case a staging
+ * table is being used.
+ *
+ * Defaults to empty.
+ */
+ def postActions: Array[String] = parameters("postactions").split(";")
+
+ /**
+ * The IAM role that Redshift should assume for COPY/UNLOAD operations.
+ */
+ def iamRole: Option[String] = parameters.get("aws_iam_role")
+
+ /**
+ * If true then this library will automatically discover the credentials that Spark is
+ * using to connect to S3 and will forward those credentials to Redshift over JDBC.
+ */
+ def forwardSparkS3Credentials: Boolean = parameters("forward_spark_s3_credentials").toBoolean
+
+ /**
+ * Temporary AWS credentials which are passed to Redshift. These only need to be supplied by
+ * the user when Hadoop is configured to authenticate to S3 via IAM roles assigned to EC2
+ * instances.
+ */
+ def temporaryAWSCredentials: Option[AWSCredentialsProvider] = {
+ for (
+ accessKey <- parameters.get("temporary_aws_access_key_id");
+ secretAccessKey <- parameters.get("temporary_aws_secret_access_key");
+ sessionToken <- parameters.get("temporary_aws_session_token")
+ ) yield {
+ AWSCredentialsUtils.staticCredentialsProvider(
+ new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken))
+ }
+ }
+ }
+}
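
For context, a sketch of how `dbtable` doubles as a subquery carrier when wrapped in parentheses, which is what the `table`/`query` accessors above implement; the values are placeholders and same-package access is assumed:

```scala
val base = Map(
  "url" -> "jdbc:redshift://example.host:5439/db?user=u&password=p",
  "tempdir" -> "s3n://my-bucket/tmp/",
  "forward_spark_s3_credentials" -> "true")

// A plain identifier is parsed as a table name:
val asTable = Parameters.mergeParameters(base + ("dbtable" -> "public.events"))
asTable.table.isDefined // true
asTable.query           // None

// A parenthesised value is treated as a subquery instead:
val asQuery = Parameters.mergeParameters(
  base + ("dbtable" -> "(SELECT id FROM events WHERE id > 10)"))
asQuery.table // None
asQuery.query // Some("SELECT id FROM events WHERE id > 10")
```
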
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RecordReaderIterator.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RecordReaderIterator.scala
new file mode 100755
index 0000000000000..acfb9140fde53
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RecordReaderIterator.scala
@@ -0,0 +1,62 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.io.Closeable
+
+import org.apache.hadoop.mapreduce.RecordReader
+
+/**
+ * An adaptor from a Hadoop [[RecordReader]] to an [[Iterator]] over the values returned.
+ *
+ * This is copied from Apache Spark and is inlined here to avoid depending on Spark internals
+ * in this external library.
+ */
+private[redshift] class RecordReaderIterator[T](
+ private[this] var rowReader: RecordReader[_, T]) extends Iterator[T] with Closeable {
+ private[this] var havePair = false
+ private[this] var finished = false
+
+ override def hasNext: Boolean = {
+ if (!finished && !havePair) {
+ finished = !rowReader.nextKeyValue
+ if (finished) {
+ // Close and release the reader here; close() will also be called when the task
+ // completes, but for tasks that read from many files, it helps to release the
+ // resources early.
+ close()
+ }
+ havePair = !finished
+ }
+ !finished
+ }
+
+ override def next(): T = {
+ if (!hasNext) {
+ throw new java.util.NoSuchElementException("End of stream")
+ }
+ havePair = false
+ rowReader.getCurrentValue
+ }
+
+ override def close(): Unit = {
+ if (rowReader != null) {
+ try {
+ // Close the reader and release it. Note: it's very important that we don't close the
+ // reader more than once, since that exposes us to MAPREDUCE-5918 when running against
+ // older Hadoop 2.x releases. That bug can lead to non-deterministic corruption issues
+ // when reading compressed input.
+ rowReader.close()
+ } finally {
+ rowReader = null
+ }
+ }
+ }
+}
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftFileFormat.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftFileFormat.scala
new file mode 100755
index 0000000000000..a45daadcf01f2
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftFileFormat.scala
@@ -0,0 +1,96 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Internal data source used for reading Redshift UNLOAD files.
+ *
+ * This is not intended for public consumption / use outside of this package and therefore
+ * no API stability is guaranteed.
+ */
+private[redshift] class RedshiftFileFormat extends FileFormat {
+ override def inferSchema(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Option[StructType] = {
+ // Schema is provided by caller.
+ None
+ }
+
+ override def prepareWrite(
+ sparkSession: SparkSession,
+ job: Job,
+ options: Map[String, String],
+ dataSchema: StructType): OutputWriterFactory = {
+ throw new UnsupportedOperationException(s"prepareWrite is not supported for $this")
+ }
+
+ override def isSplitable(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ path: Path): Boolean = {
+ // Our custom InputFormat handles split records properly
+ true
+ }
+
+ override def buildReader(
+ sparkSession: SparkSession,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String],
+ hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+
+ require(partitionSchema.isEmpty)
+ require(filters.isEmpty)
+ require(dataSchema == requiredSchema)
+
+ val broadcastedConf =
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+
+ (file: PartitionedFile) => {
+ val conf = broadcastedConf.value.value
+
+ val fileSplit = new FileSplit(
+ new Path(new URI(file.filePath)),
+ file.start,
+ file.length,
+ // TODO: Implement Locality
+ Array.empty)
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+ val reader = new RedshiftRecordReader
+ reader.initialize(fileSplit, hadoopAttemptContext)
+ val iter = new RecordReaderIterator[Array[String]](reader)
+ // Ensure that the record reader is closed upon task completion. It will ordinarily
+ // be closed once it is completely iterated, but this is necessary to guard against
+ // resource leaks in case the task fails or is interrupted.
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
+ val converter = Conversions.createRowConverter(requiredSchema)
+ iter.map(converter)
+ }
+ }
+}
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftInputFormat.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftInputFormat.scala
new file mode 100755
index 0000000000000..94829c028e520
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftInputFormat.scala
@@ -0,0 +1,252 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.io.{BufferedInputStream, IOException}
+import java.lang.{Long => JavaLong}
+import java.nio.charset.Charset
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.io.compress.CompressionCodecFactory
+import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
+
+/**
+ * Input format for text records saved with in-record delimiter and newline characters escaped.
+ *
+ * For example, a record containing two fields: `"a\n"` and `"|b\\"` saved with delimiter `|`
+ * should be the following:
+ * {{{
+ * a\\\n|\\|b\\\\\n
+ * }}},
+ * where the in-record `|`, `\r`, `\n`, and `\\` characters are escaped by `\\`.
+ * Users can configure the delimiter via [[RedshiftInputFormat$#KEY_DELIMITER]].
+ * Its default value [[RedshiftInputFormat$#DEFAULT_DELIMITER]] is set to match Redshift's UNLOAD
+ * with the ESCAPE option:
+ * {{{
+ * UNLOAD ('select_statement')
+ * TO 's3://object_path_prefix'
+ * ESCAPE
+ * }}}
+ *
+ * @see org.apache.spark.SparkContext#newAPIHadoopFile
+ */
+class RedshiftInputFormat extends FileInputFormat[JavaLong, Array[String]] {
+
+ override def createRecordReader(
+ split: InputSplit,
+ context: TaskAttemptContext): RecordReader[JavaLong, Array[String]] = {
+ new RedshiftRecordReader
+ }
+}
+
+object RedshiftInputFormat {
+
+ /** configuration key for delimiter */
+ val KEY_DELIMITER = "redshift.delimiter"
+ /** default delimiter */
+ val DEFAULT_DELIMITER = '|'
+
+ /** Gets the delimiter char from conf or the default. */
+ private[redshift] def getDelimiterOrDefault(conf: Configuration): Char = {
+ val c = conf.get(KEY_DELIMITER, DEFAULT_DELIMITER.toString)
+ if (c.length != 1) {
+ throw new IllegalArgumentException(s"Expect delimiter be a single character but got '$c'.")
+ } else {
+ c.charAt(0)
+ }
+ }
+}
+
+private[redshift] class RedshiftRecordReader extends RecordReader[JavaLong, Array[String]] {
+
+ private var reader: BufferedInputStream = _
+
+ private var key: JavaLong = _
+ private var value: Array[String] = _
+
+ private var start: Long = _
+ private var end: Long = _
+ private var cur: Long = _
+
+ private var eof: Boolean = false
+
+ private var delimiter: Byte = _
+ @inline private[this] final val escapeChar: Byte = '\\'
+ @inline private[this] final val lineFeed: Byte = '\n'
+ @inline private[this] final val carriageReturn: Byte = '\r'
+
+ @inline private[this] final val defaultBufferSize = 1024 * 1024
+
+ private[this] val chars = ArrayBuffer.empty[Byte]
+
+ override def initialize(inputSplit: InputSplit, context: TaskAttemptContext): Unit = {
+ val split = inputSplit.asInstanceOf[FileSplit]
+ val file = split.getPath
+ val conf: Configuration = context.getConfiguration
+ delimiter = RedshiftInputFormat.getDelimiterOrDefault(conf).asInstanceOf[Byte]
+ require(delimiter != escapeChar,
+ s"The delimiter and the escape char cannot be the same but found $delimiter.")
+ require(delimiter != lineFeed, "The delimiter cannot be the lineFeed character.")
+ require(delimiter != carriageReturn, "The delimiter cannot be the carriage return.")
+ val compressionCodecs = new CompressionCodecFactory(conf)
+ val codec = compressionCodecs.getCodec(file)
+ if (codec != null) {
+ throw new IOException(s"Do not support compressed files but found $file.")
+ }
+ val fs = file.getFileSystem(conf)
+ val size = fs.getFileStatus(file).getLen
+ start = findNext(fs, file, size, split.getStart)
+ end = findNext(fs, file, size, split.getStart + split.getLength)
+ cur = start
+ val in = fs.open(file)
+ if (cur > 0L) {
+ in.seek(cur - 1L)
+ in.read()
+ }
+ reader = new BufferedInputStream(in, defaultBufferSize)
+ }
+
+ override def getProgress: Float = {
+ if (start >= end) {
+ 1.0f
+ } else {
+ math.min((cur - start).toFloat / (end - start), 1.0f)
+ }
+ }
+
+ override def nextKeyValue(): Boolean = {
+ if (cur < end && !eof) {
+ key = cur
+ value = nextValue()
+ true
+ } else {
+ key = null
+ value = null
+ false
+ }
+ }
+
+ override def getCurrentValue: Array[String] = value
+
+ override def getCurrentKey: JavaLong = key
+
+ override def close(): Unit = {
+ if (reader != null) {
+ reader.close()
+ }
+ }
+
+ /**
+ * Finds the start of the next record.
+ * Because we don't know whether the first char is escaped or not, we need to first find a
+ * position that is not escaped.
+ *
+ * @param fs file system
+ * @param file file path
+ * @param size file size
+ * @param offset start offset
+ * @return the start position of the next record
+ */
+ private def findNext(fs: FileSystem, file: Path, size: Long, offset: Long): Long = {
+ if (offset == 0L) {
+ return 0L
+ } else if (offset >= size) {
+ return size
+ }
+ val in = fs.open(file)
+ var pos = offset
+ in.seek(pos)
+ val bis = new BufferedInputStream(in, defaultBufferSize)
+ // Find the first unescaped char.
+ var escaped = true
+ var thisEof = false
+ while (escaped && !thisEof) {
+ val v = bis.read()
+ if (v < 0) {
+ thisEof = true
+ } else {
+ pos += 1
+ if (v != escapeChar) {
+ escaped = false
+ }
+ }
+ }
+ // Find the next unescaped line feed.
+ var endOfRecord = false
+ while ((escaped || !endOfRecord) && !thisEof) {
+ val v = bis.read()
+ if (v < 0) {
+ thisEof = true
+ } else {
+ pos += 1
+ if (v == escapeChar) {
+ escaped = true
+ } else {
+ if (!escaped) {
+ endOfRecord = v == lineFeed
+ } else {
+ escaped = false
+ }
+ }
+ }
+ }
+ in.close()
+ pos
+ }
+
+ private def nextValue(): Array[String] = {
+ val fields = ArrayBuffer.empty[String]
+ var escaped = false
+ var endOfRecord = false
+ while (!endOfRecord && !eof) {
+ var endOfField = false
+ chars.clear()
+ while (!endOfField && !endOfRecord && !eof) {
+ val v = reader.read()
+ if (v < 0) {
+ eof = true
+ } else {
+ cur += 1L
+ val c = v.asInstanceOf[Byte]
+ if (escaped) {
+ if (c != escapeChar && c != delimiter && c != lineFeed && c != carriageReturn) {
+ throw new IllegalStateException(
+ s"Found `$c` (ASCII $v) after $escapeChar.")
+ }
+ chars.append(c)
+ escaped = false
+ } else {
+ if (c == escapeChar) {
+ escaped = true
+ } else if (c == delimiter) {
+ endOfField = true
+ } else if (c == lineFeed) {
+ endOfRecord = true
+ } else {
+ // also copy carriage return
+ chars.append(c)
+ }
+ }
+ }
+ }
+ // TODO: charset?
+ fields.append(new String(chars.toArray, Charset.forName("UTF-8")))
+ }
+ if (escaped) {
+ throw new IllegalStateException(s"Found hanging escape char.")
+ }
+ fields.toArray
+ }
+}
+
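
For context, and as the scaladoc's pointer to SparkContext#newAPIHadoopFile suggests, the input format can also be driven directly; a sketch with a placeholder path, keeping only the parsed fields and discarding the byte-offset keys:

```scala
import java.lang.{Long => JavaLong}

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

def readUnloaded(sc: SparkContext, path: String): RDD[Array[String]] = {
  // '|' is already the default; shown here only to illustrate the configuration key.
  sc.hadoopConfiguration.set(RedshiftInputFormat.KEY_DELIMITER, "|")
  sc.newAPIHadoopFile(
    path,
    classOf[RedshiftInputFormat],
    classOf[JavaLong],
    classOf[Array[String]])
    .map(_._2) // drop the byte-offset keys
}
```
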
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala
new file mode 100755
index 0000000000000..c110176859d9e
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala
@@ -0,0 +1,347 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
+import java.util.Properties
+import java.util.concurrent.{Executors, ThreadFactory}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.JavaConverters._
+import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+import scala.util.Try
+import scala.util.control.NonFatal
+
+import org.slf4j.LoggerFactory
+
+import org.apache.spark.SPARK_VERSION
+import org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry
+import org.apache.spark.sql.types._
+
+/**
+ * Shim which exposes some JDBC helper functions. Most of this code is copied from Spark SQL, with
+ * minor modifications for Redshift-specific features and limitations.
+ */
+private[redshift] class JDBCWrapper {
+
+ private val log = LoggerFactory.getLogger(getClass)
+
+ private val ec: ExecutionContext = {
+ val threadFactory = new ThreadFactory {
+ private[this] val count = new AtomicInteger()
+ override def newThread(r: Runnable) = {
+ val thread = new Thread(r)
+ thread.setName(s"spark-redshift-JDBCWrapper-${count.incrementAndGet}")
+ thread.setDaemon(true)
+ thread
+ }
+ }
+ ExecutionContext.fromExecutorService(Executors.newCachedThreadPool(threadFactory))
+ }
+
+ /**
+ * Given a JDBC subprotocol, returns the name of the appropriate driver class to use.
+ *
+ * If the user has explicitly specified a driver class in their configuration then that class will
+ * be used. Otherwise, we will attempt to load the correct driver class based on
+ * the JDBC subprotocol.
+ *
+ * @param jdbcSubprotocol 'redshift' or 'postgresql'
+ * @param userProvidedDriverClass an optional user-provided explicit driver class name
+ * @return the driver class
+ */
+ private def getDriverClass(
+ jdbcSubprotocol: String,
+ userProvidedDriverClass: Option[String]): String = {
+ userProvidedDriverClass.getOrElse {
+ jdbcSubprotocol match {
+ case "redshift" =>
+ try {
+ Utils.classForName("com.amazon.redshift.jdbc42.Driver").getName
+ } catch {
+ case _: ClassNotFoundException =>
+ try {
+ Utils.classForName("com.amazon.redshift.jdbc41.Driver").getName
+ } catch {
+ case _: ClassNotFoundException =>
+ try {
+ Utils.classForName("com.amazon.redshift.jdbc4.Driver").getName
+ } catch {
+ case e: ClassNotFoundException =>
+ throw new ClassNotFoundException(
+ "Could not load an Amazon Redshift JDBC driver; see the README for " +
+ "instructions on downloading and configuring the official Amazon driver.",
+ e
+ )
+ }
+ }
+ }
+ case "postgresql" => "org.postgresql.Driver"
+ case other => throw new IllegalArgumentException(s"Unsupported JDBC protocol: '$other'")
+ }
+ }
+ }
+
+ /**
+ * Execute the given SQL statement while supporting interruption.
+ * If InterruptedException is caught, then the statement will be cancelled if it is running.
+ *
+ * @return <code>true</code> if the first result is a <code>ResultSet</code>
+ * object; <code>false</code> if the first result is an update
+ * count or there is no result
+ */
+ def executeInterruptibly(statement: PreparedStatement): Boolean = {
+ executeInterruptibly(statement, _.execute())
+ }
+
+ /**
+ * Execute the given SQL statement while supporting interruption.
+ * If InterruptedException is caught, then the statement will be cancelled if it is running.
+ *
+ * @return a <code>ResultSet</code> object that contains the data produced by the
+ * query; never <code>null</code>
+ */
+ def executeQueryInterruptibly(statement: PreparedStatement): ResultSet = {
+ executeInterruptibly(statement, _.executeQuery())
+ }
+
+ private def executeInterruptibly[T](
+ statement: PreparedStatement,
+ op: PreparedStatement => T): T = {
+ try {
+ val future = Future[T](op(statement))(ec)
+ try {
+ // scalastyle:off awaitresult
+ Await.result(future, Duration.Inf)
+ // scalastyle:on awaitresult
+ } catch {
+ case e: SQLException =>
+ // Wrap and re-throw so that this thread's stacktrace appears to the user.
+ throw new SQLException("Exception thrown in awaitResult: ", e)
+ case NonFatal(t) =>
+ // Wrap and re-throw so that this thread's stacktrace appears to the user.
+ throw new Exception("Exception thrown in awaitResult: ", t)
+ }
+ } catch {
+ case e: InterruptedException =>
+ try {
+ statement.cancel()
+ throw e
+ } catch {
+ case s: SQLException =>
+ log.error("Exception occurred while cancelling query", s)
+ throw e
+ }
+ }
+ }
+
+ /**
+ * Takes a (schema, table) specification and returns the table's Catalyst
+ * schema.
+ *
+ * @param conn A JDBC connection to the database.
+ * @param table The table name of the desired table. This may also be a
+ * SQL query wrapped in parentheses.
+ *
+ * @return A StructType giving the table's Catalyst schema.
+ * @throws SQLException if the table specification is garbage.
+ * @throws SQLException if the table contains an unsupported type.
+ */
+ def resolveTable(conn: Connection, table: String): StructType = {
+ // It's important to leave the `LIMIT 1` clause in order to limit the work of the query in case
+ // the underlying JDBC driver implementation implements PreparedStatement.getMetaData() by
+ // executing the query. It looks like the standard Redshift and Postgres JDBC drivers don't do
+ // this but we leave the LIMIT condition here as a safety-net to guard against perf regressions.
+ val ps = conn.prepareStatement(s"SELECT * FROM $table LIMIT 1")
+ try {
+ val rsmd = executeInterruptibly(ps, _.getMetaData)
+ val ncols = rsmd.getColumnCount
+ val fields = new Array[StructField](ncols)
+ var i = 0
+ while (i < ncols) {
+ val columnName = rsmd.getColumnLabel(i + 1)
+ val dataType = rsmd.getColumnType(i + 1)
+ val fieldSize = rsmd.getPrecision(i + 1)
+ val fieldScale = rsmd.getScale(i + 1)
+ val isSigned = rsmd.isSigned(i + 1)
+ val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
+ val columnType = getCatalystType(dataType, fieldSize, fieldScale, isSigned)
+ fields(i) = StructField(columnName, columnType, nullable)
+ i = i + 1
+ }
+ new StructType(fields)
+ } finally {
+ ps.close()
+ }
+ }
+
+ /**
+ * Given a driver string and a JDBC url, load the specified driver and return a DB connection.
+ *
+ * @param userProvidedDriverClass the class name of the JDBC driver for the given url. If this
+ * is None then `spark-redshift` will attempt to automatically
+ * discover the appropriate driver class.
+ * @param url the JDBC url to connect to.
+ */
+ def getConnector(
+ userProvidedDriverClass: Option[String],
+ url: String,
+ credentials: Option[(String, String)]) : Connection = {
+ val subprotocol = url.stripPrefix("jdbc:").split(":")(0)
+ val driverClass: String = getDriverClass(subprotocol, userProvidedDriverClass)
+ DriverRegistry.register(driverClass)
+ val driverWrapperClass: Class[_] = if (SPARK_VERSION.startsWith("1.4")) {
+ Utils.classForName("org.apache.spark.sql.jdbc.package$DriverWrapper")
+ } else { // Spark 1.5.0+
+ Utils.classForName("org.apache.spark.sql.execution.datasources.jdbc.DriverWrapper")
+ }
+ def getWrapped(d: Driver): Driver = {
+ require(driverWrapperClass.isAssignableFrom(d.getClass))
+ driverWrapperClass.getDeclaredMethod("wrapped").invoke(d).asInstanceOf[Driver]
+ }
+ // Note that we purposely don't call DriverManager.getConnection() here: we want to ensure
+ // that an explicitly-specified user-provided driver class can take precedence over the default
+ // class, but DriverManager.getConnection() might pick a driver according to a different precedence.
+ // At the same time, we don't want to create a driver-per-connection, so we use the
+ // DriverManager's driver instances to handle that singleton logic for us.
+ val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
+ case d if driverWrapperClass.isAssignableFrom(d.getClass)
+ && getWrapped(d).getClass.getCanonicalName == driverClass => d
+ case d if d.getClass.getCanonicalName == driverClass => d
+ }.getOrElse {
+ throw new IllegalArgumentException(s"Did not find registered driver with class $driverClass")
+ }
+ val properties = new Properties()
+ credentials.foreach { case(user, password) =>
+ properties.setProperty("user", user)
+ properties.setProperty("password", password)
+ }
+ driver.connect(url, properties)
+ }
+
+ /**
+ * Compute the SQL schema string for the given Spark SQL Schema.
+ */
+ def schemaString(schema: StructType): String = {
+ val sb = new StringBuilder()
+ schema.fields.foreach { field => {
+ val name = field.name
+ val typ: String = if (field.metadata.contains("redshift_type")) {
+ field.metadata.getString("redshift_type")
+ } else {
+ field.dataType match {
+ case IntegerType => "INTEGER"
+ case LongType => "BIGINT"
+ case DoubleType => "DOUBLE PRECISION"
+ case FloatType => "REAL"
+ case ShortType => "INTEGER"
+ case ByteType => "SMALLINT" // Redshift does not support the BYTE type.
+ case BooleanType => "BOOLEAN"
+ case StringType =>
+ if (field.metadata.contains("maxlength")) {
+ s"VARCHAR(${field.metadata.getLong("maxlength")})"
+ } else {
+ "TEXT"
+ }
+ case TimestampType => "TIMESTAMP"
+ case DateType => "DATE"
+ case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
+ case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
+ }
+ }
+
+ val nullable = if (field.nullable) "" else "NOT NULL"
+ val encoding = if (field.metadata.contains("encoding")) {
+ s"ENCODE ${field.metadata.getString("encoding")}"
+ } else {
+ ""
+ }
+ sb.append(s""", "${name.replace("\"", "\\\"")}" $typ $nullable $encoding""".trim)
+ }}
+ if (sb.length < 2) "" else sb.substring(2)
+ }
+
+ /**
+ * Returns true if the table already exists in the JDBC database.
+ */
+ def tableExists(conn: Connection, table: String): Boolean = {
+ // Somewhat hacky, but there isn't a good way to identify whether a table exists for all
+ // SQL database systems, considering "table" could also include the database name.
+ Try {
+ val stmt = conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1")
+ executeInterruptibly(stmt, _.getMetaData).getColumnCount
+ }.isSuccess
+ }
+
+ /**
+ * Maps a JDBC type to a Catalyst type.
+ *
+ * @param sqlType - A field of java.sql.Types
+ * @return The Catalyst type corresponding to sqlType.
+ */
+ private def getCatalystType(
+ sqlType: Int,
+ precision: Int,
+ scale: Int,
+ signed: Boolean): DataType = {
+ // TODO: cleanup types which are irrelevant for Redshift.
+ val answer = sqlType match {
+ // scalastyle:off
+ case java.sql.Types.ARRAY => null
+ case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
+ case java.sql.Types.BINARY => BinaryType
+ case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
+ case java.sql.Types.BLOB => BinaryType
+ case java.sql.Types.BOOLEAN => BooleanType
+ case java.sql.Types.CHAR => StringType
+ case java.sql.Types.CLOB => StringType
+ case java.sql.Types.DATALINK => null
+ case java.sql.Types.DATE => DateType
+ case java.sql.Types.DECIMAL
+ if precision != 0 || scale != 0 => DecimalType(precision, scale)
+ case java.sql.Types.DECIMAL => DecimalType(38, 18) // Spark 1.5.0 default
+ case java.sql.Types.DISTINCT => null
+ case java.sql.Types.DOUBLE => DoubleType
+ case java.sql.Types.FLOAT => FloatType
+ case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
+ case java.sql.Types.JAVA_OBJECT => null
+ case java.sql.Types.LONGNVARCHAR => StringType
+ case java.sql.Types.LONGVARBINARY => BinaryType
+ case java.sql.Types.LONGVARCHAR => StringType
+ case java.sql.Types.NCHAR => StringType
+ case java.sql.Types.NCLOB => StringType
+ case java.sql.Types.NULL => null
+ case java.sql.Types.NUMERIC
+ if precision != 0 || scale != 0 => DecimalType(precision, scale)
+ case java.sql.Types.NUMERIC => DecimalType(38, 18) // Spark 1.5.0 default
+ case java.sql.Types.NVARCHAR => StringType
+ case java.sql.Types.OTHER => null
+ case java.sql.Types.REAL => DoubleType
+ case java.sql.Types.REF => StringType
+ case java.sql.Types.ROWID => LongType
+ case java.sql.Types.SMALLINT => IntegerType
+ case java.sql.Types.SQLXML => StringType
+ case java.sql.Types.STRUCT => StringType
+ case java.sql.Types.TIME => TimestampType
+ case java.sql.Types.TIMESTAMP => TimestampType
+ case java.sql.Types.TINYINT => IntegerType
+ case java.sql.Types.VARBINARY => BinaryType
+ case java.sql.Types.VARCHAR => StringType
+ case _ => null
+ // scalastyle:on
+ }
+
+ if (answer == null) throw new SQLException("Unsupported type " + sqlType)
+ answer
+ }
+}
+
+private[redshift] object DefaultJDBCWrapper extends JDBCWrapper
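
For context, a sketch of the DDL column list `schemaString` emits, including the metadata keys it recognises (`maxlength`, `encoding`, `redshift_type`); same-package access is assumed since the wrapper is `private[redshift]`:

```scala
import org.apache.spark.sql.types._

val nameMeta = new MetadataBuilder().putLong("maxlength", 64).build()
val schema = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("name", StringType, metadata = nameMeta),
  StructField("ts", TimestampType)))

DefaultJDBCWrapper.schemaString(schema)
// "id" BIGINT NOT NULL, "name" VARCHAR(64), "ts" TIMESTAMP
```
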
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
new file mode 100755
index 0000000000000..fd24546d452b5
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
@@ -0,0 +1,195 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.io.InputStreamReader
+import java.net.URI
+
+import scala.collection.JavaConverters._
+
+import com.amazonaws.auth.AWSCredentialsProvider
+import com.amazonaws.services.s3.AmazonS3Client
+import com.databricks.spark.redshift.Parameters.MergedParameters
+import com.eclipsesource.json.Json
+import org.slf4j.LoggerFactory
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+
+/**
+ * Data Source API implementation for Amazon Redshift database tables
+ */
+private[redshift] case class RedshiftRelation(
+ jdbcWrapper: JDBCWrapper,
+ s3ClientFactory: AWSCredentialsProvider => AmazonS3Client,
+ params: MergedParameters,
+ userSchema: Option[StructType])
+ (@transient val sqlContext: SQLContext)
+ extends BaseRelation
+ with PrunedFilteredScan
+ with InsertableRelation {
+
+ private val log = LoggerFactory.getLogger(getClass)
+
+ if (sqlContext != null) {
+ Utils.assertThatFileSystemIsNotS3BlockFileSystem(
+ new URI(params.rootTempDir), sqlContext.sparkContext.hadoopConfiguration)
+ }
+
+ private val tableNameOrSubquery =
+ params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
+
+ override lazy val schema: StructType = {
+ userSchema.getOrElse {
+ val tableNameOrSubquery =
+ params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
+ val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
+ try {
+ jdbcWrapper.resolveTable(conn, tableNameOrSubquery)
+ } finally {
+ conn.close()
+ }
+ }
+ }
+
+ override def toString: String = s"RedshiftRelation($tableNameOrSubquery)"
+
+ override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+ val saveMode = if (overwrite) {
+ SaveMode.Overwrite
+ } else {
+ SaveMode.Append
+ }
+ val writer = new RedshiftWriter(jdbcWrapper, s3ClientFactory)
+ writer.saveToRedshift(sqlContext, data, saveMode, params)
+ }
+
+ override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
+ filters.filterNot(filter => FilterPushdown.buildFilterExpression(schema, filter).isDefined)
+ }
+
+ override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
+ val creds = AWSCredentialsUtils.load(params, sqlContext.sparkContext.hadoopConfiguration)
+ for (
+ redshiftRegion <- Utils.getRegionForRedshiftCluster(params.jdbcUrl);
+ s3Region <- Utils.getRegionForS3Bucket(params.rootTempDir, s3ClientFactory(creds))
+ ) {
+ if (redshiftRegion != s3Region) {
+ // We don't currently support `extraunloadoptions`, so even if Amazon _did_ add a `region`
+ // option for this we wouldn't be able to pass in the new option. However, we choose to
+ // err on the side of caution and don't throw an exception because we don't want to break
+ // existing workloads in case the region detection logic is wrong.
+ log.error("The Redshift cluster and S3 bucket are in different regions " +
+ s"($redshiftRegion and $s3Region, respectively). Redshift's UNLOAD command requires " +
+ s"that the Redshift cluster and Amazon S3 bucket be located in the same region, so " +
+ s"this read will fail.")
+ }
+ }
+ Utils.checkThatBucketHasObjectLifecycleConfiguration(params.rootTempDir, s3ClientFactory(creds))
+ if (requiredColumns.isEmpty) {
+ // In the special case where no columns were requested, issue a `count(*)` against Redshift
+ // rather than unloading data.
+ val whereClause = FilterPushdown.buildWhereClause(schema, filters)
+ val countQuery = s"SELECT count(*) FROM $tableNameOrSubquery $whereClause"
+ log.info(countQuery)
+ val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
+ try {
+ val results = jdbcWrapper.executeQueryInterruptibly(conn.prepareStatement(countQuery))
+ if (results.next()) {
+ val numRows = results.getLong(1)
+ val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt
+ val emptyRow = RowEncoder(StructType(Seq.empty)).toRow(Row(Seq.empty))
+ sqlContext.sparkContext
+ .parallelize(1L to numRows, parallelism)
+ .map(_ => emptyRow)
+ .asInstanceOf[RDD[Row]]
+ } else {
+ throw new IllegalStateException("Could not read count from Redshift")
+ }
+ } finally {
+ conn.close()
+ }
+ } else {
+ // Unload data from Redshift into a temporary directory in S3:
+ val tempDir = params.createPerQueryTempDir()
+ val unloadSql = buildUnloadStmt(requiredColumns, filters, tempDir, creds)
+ log.info(unloadSql)
+ val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
+ try {
+ jdbcWrapper.executeInterruptibly(conn.prepareStatement(unloadSql))
+ } finally {
+ conn.close()
+ }
+ // Read the MANIFEST file to get the list of S3 part files that were written by Redshift.
+ // We need to use a manifest in order to guard against S3's eventually-consistent listings.
+ val filesToRead: Seq[String] = {
+ val cleanedTempDirUri =
+ Utils.fixS3Url(Utils.removeCredentialsFromURI(URI.create(tempDir)).toString)
+ val s3URI = Utils.createS3URI(cleanedTempDirUri)
+ val s3Client = s3ClientFactory(creds)
+ val is = s3Client.getObject(s3URI.getBucket, s3URI.getKey + "manifest").getObjectContent
+ val s3Files = try {
+ val entries = Json.parse(new InputStreamReader(is)).asObject().get("entries").asArray()
+ entries.iterator().asScala.map(_.asObject().get("url").asString()).toSeq
+ } finally {
+ is.close()
+ }
+ // The filenames in the manifest are of the form s3://bucket/key, without credentials.
+ // If the S3 credentials were originally specified in the tempdir's URI, then we need to
+ // reintroduce them here
+ s3Files.map { file =>
+ tempDir.stripSuffix("/") + '/' + file.stripPrefix(cleanedTempDirUri).stripPrefix("/")
+ }
+ }
+
+ val prunedSchema = pruneSchema(schema, requiredColumns)
+
+ sqlContext.read
+ .format(classOf[RedshiftFileFormat].getName)
+ .schema(prunedSchema)
+ .load(filesToRead: _*)
+ .queryExecution.executedPlan.execute().asInstanceOf[RDD[Row]]
+ }
+ }
+
+ override def needConversion: Boolean = false
+
+ private def buildUnloadStmt(
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ tempDir: String,
+ creds: AWSCredentialsProvider): String = {
+ assert(!requiredColumns.isEmpty)
+ // Always quote column names:
+ val columnList = requiredColumns.map(col => s""""$col"""").mkString(", ")
+ val whereClause = FilterPushdown.buildWhereClause(schema, filters)
+ val credsString: String =
+ AWSCredentialsUtils.getRedshiftCredentialsString(params, creds.getCredentials)
+ val query = {
+ // Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape
+ // any backslashes and single quotes that appear in the query itself
+      val escapedTableNameOrSubquery =
+        tableNameOrSubquery.replace("\\", "\\\\").replace("'", "\\'")
+      s"SELECT $columnList FROM $escapedTableNameOrSubquery $whereClause"
+ }
+ // We need to remove S3 credentials from the unload path URI because they will conflict with
+ // the credentials passed via `credsString`.
+ val fixedUrl = Utils.fixS3Url(Utils.removeCredentialsFromURI(new URI(tempDir)).toString)
+
+ s"UNLOAD ('$query') TO '$fixedUrl' WITH CREDENTIALS '$credsString' ESCAPE MANIFEST"
+ }
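+
+  // For illustration (identifiers and credentials below are hypothetical), the statement
+  // produced by buildUnloadStmt looks roughly like:
+  //   UNLOAD ('SELECT "id", "name" FROM "PUBLIC"."users" WHERE "id" > 10')
+  //   TO 's3://bucket/temp/<uuid>/'
+  //   WITH CREDENTIALS 'aws_access_key_id=...;aws_secret_access_key=...'
+  //   ESCAPE MANIFEST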
+
+ private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
+ val fieldMap = Map(schema.fields.map(x => x.name -> x): _*)
+ new StructType(columns.map(name => fieldMap(name)))
+ }
+}
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
new file mode 100755
index 0000000000000..2b1d691782bb2
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
@@ -0,0 +1,425 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.net.URI
+import java.sql.{Connection, Date, SQLException, Timestamp}
+
+import scala.collection.mutable
+import scala.util.control.NonFatal
+
+import com.amazonaws.auth.AWSCredentialsProvider
+import com.amazonaws.services.s3.AmazonS3Client
+import com.databricks.spark.redshift.Parameters.MergedParameters
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.slf4j.LoggerFactory
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.types._
+
+/**
+ * Functions to write data to Redshift.
+ *
+ * At a high level, writing data back to Redshift involves the following steps:
+ *
+ * - Use the spark-avro library to save the DataFrame to S3 using Avro serialization. Prior to
+ * saving the data, certain data type conversions are applied in order to work around
+ * limitations in Avro's data type support and Redshift's case-insensitive identifier handling.
+ *
+ * While writing the Avro files, we use accumulators to keep track of which partitions were
+ * non-empty. After the write operation completes, we use this to construct a list of non-empty
+ * Avro partition files.
+ *
+ * - If there is data to be written (i.e. not all partitions were empty), then use the list of
+ * non-empty Avro files to construct a JSON manifest file to tell Redshift to load those files.
+ * This manifest is written to S3 alongside the Avro files themselves. We need to use an
+ * explicit manifest, as opposed to simply passing the name of the directory containing the
+ * Avro files, in order to work around a bug related to parsing of empty Avro files (see #96).
+ *
+ * - Start a new JDBC transaction and disable auto-commit. Depending on the SaveMode, issue
+ * DROP TABLE or CREATE TABLE commands, then use the COPY command to instruct Redshift to load
+ * the Avro data into the appropriate table.
+ */
+private[redshift] class RedshiftWriter(
+ jdbcWrapper: JDBCWrapper,
+ s3ClientFactory: AWSCredentialsProvider => AmazonS3Client) {
+
+ private val log = LoggerFactory.getLogger(getClass)
+
+ /**
+ * Generate CREATE TABLE statement for Redshift
+ */
+ // Visible for testing.
+ private[redshift] def createTableSql(data: DataFrame, params: MergedParameters): String = {
+ val schemaSql = jdbcWrapper.schemaString(data.schema)
+ val distStyleDef = params.distStyle match {
+ case Some(style) => s"DISTSTYLE $style"
+ case None => ""
+ }
+ val distKeyDef = params.distKey match {
+ case Some(key) => s"DISTKEY ($key)"
+ case None => ""
+ }
+ val sortKeyDef = params.sortKeySpec.getOrElse("")
+ val table = params.table.get
+
+ s"CREATE TABLE IF NOT EXISTS $table ($schemaSql) $distStyleDef $distKeyDef $sortKeyDef"
+ }
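+
+  // For illustration (table name, columns, and rendered column types are hypothetical), a
+  // generated statement looks roughly like:
+  //   CREATE TABLE IF NOT EXISTS "PUBLIC"."users" ("id" INTEGER, "name" TEXT) DISTSTYLE EVEN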
+
+ /**
+ * Generate the COPY SQL command
+ */
+ private def copySql(
+ sqlContext: SQLContext,
+ params: MergedParameters,
+ creds: AWSCredentialsProvider,
+ manifestUrl: String): String = {
+ val credsString: String =
+ AWSCredentialsUtils.getRedshiftCredentialsString(params, creds.getCredentials)
+ val fixedUrl = Utils.fixS3Url(manifestUrl)
+ val format = params.tempFormat match {
+ case "AVRO" => "AVRO 'auto'"
+ case csv if csv == "CSV" || csv == "CSV GZIP" => csv + s" NULL AS '${params.nullString}'"
+ }
+ s"COPY ${params.table.get} FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
+ s"${format} manifest ${params.extraCopyOptions}"
+ }
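+
+  // For illustration (identifiers and credentials hypothetical), with the AVRO tempformat the
+  // generated statement looks roughly like:
+  //   COPY "PUBLIC"."users" FROM 's3://bucket/temp/<uuid>/manifest.json'
+  //   CREDENTIALS 'aws_access_key_id=...;aws_secret_access_key=...'
+  //   FORMAT AS AVRO 'auto' manifest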
+
+ /**
+ * Generate COMMENT SQL statements for the table and columns.
+ */
+ private[redshift] def commentActions(tableComment: Option[String], schema: StructType):
+ List[String] = {
+ tableComment.toList.map(desc => s"COMMENT ON TABLE %s IS '${desc.replace("'", "''")}'") ++
+ schema.fields
+ .withFilter(f => f.metadata.contains("description"))
+ .map(f => s"""COMMENT ON COLUMN %s."${f.name.replace("\"", "\\\"")}""""
+ + s" IS '${f.metadata.getString("description").replace("'", "''")}'")
+ }
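+
+  // For illustration (description text hypothetical), a table comment of "User data" and a
+  // column "name" described as "Full name" yield:
+  //   COMMENT ON TABLE %s IS 'User data'
+  //   COMMENT ON COLUMN %s."name" IS 'Full name'
+  // The %s placeholder is substituted with the target table name before execution.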
+
+ /**
+ * Perform the Redshift load by issuing a COPY statement.
+ */
+ private def doRedshiftLoad(
+ conn: Connection,
+ data: DataFrame,
+ params: MergedParameters,
+ creds: AWSCredentialsProvider,
+ manifestUrl: Option[String]): Unit = {
+
+ // If the table doesn't exist, we need to create it first, using JDBC to infer column types
+ val createStatement = createTableSql(data, params)
+ log.info(createStatement)
+ jdbcWrapper.executeInterruptibly(conn.prepareStatement(createStatement))
+
+ val preActions = commentActions(params.description, data.schema) ++ params.preActions
+
+ // Execute preActions
+ preActions.foreach { action =>
+ val actionSql = if (action.contains("%s")) action.format(params.table.get) else action
+ log.info("Executing preAction: " + actionSql)
+ jdbcWrapper.executeInterruptibly(conn.prepareStatement(actionSql))
+ }
+
+ manifestUrl.foreach { manifestUrl =>
+ // Load the temporary data into the new file
+ val copyStatement = copySql(data.sqlContext, params, creds, manifestUrl)
+ log.info(copyStatement)
+ try {
+ jdbcWrapper.executeInterruptibly(conn.prepareStatement(copyStatement))
+ } catch {
+ case e: SQLException =>
+ log.error("SQLException thrown while running COPY query; will attempt to retrieve " +
+ "more information by querying the STL_LOAD_ERRORS table", e)
+ // Try to query Redshift's STL_LOAD_ERRORS table to figure out why the load failed.
+ // See http://docs.aws.amazon.com/redshift/latest/dg/r_STL_LOAD_ERRORS.html for details.
+ conn.rollback()
+ val errorLookupQuery =
+ """
+ | SELECT *
+ | FROM stl_load_errors
+ | WHERE query = pg_last_query_id()
+ """.stripMargin
+ val detailedException: Option[SQLException] = try {
+ val results =
+ jdbcWrapper.executeQueryInterruptibly(conn.prepareStatement(errorLookupQuery))
+ if (results.next()) {
+ val errCode = results.getInt("err_code")
+ val errReason = results.getString("err_reason").trim
+ val columnLength: String =
+ Option(results.getString("col_length"))
+ .map(_.trim)
+ .filter(_.nonEmpty)
+ .map(n => s"($n)")
+ .getOrElse("")
+ val exceptionMessage =
+ s"""
+ |Error (code $errCode) while loading data into Redshift: "$errReason"
+ |Table name: ${params.table.get}
+ |Column name: ${results.getString("colname").trim}
+ |Column type: ${results.getString("type").trim}$columnLength
+ |Raw line: ${results.getString("raw_line")}
+ |Raw field value: ${results.getString("raw_field_value")}
+ """.stripMargin
+ Some(new SQLException(exceptionMessage, e))
+ } else {
+ None
+ }
+ } catch {
+ case NonFatal(e2) =>
+ log.error("Error occurred while querying STL_LOAD_ERRORS", e2)
+ None
+ }
+ throw detailedException.getOrElse(e)
+ }
+ }
+
+ // Execute postActions
+ params.postActions.foreach { action =>
+ val actionSql = if (action.contains("%s")) action.format(params.table.get) else action
+ log.info("Executing postAction: " + actionSql)
+ jdbcWrapper.executeInterruptibly(conn.prepareStatement(actionSql))
+ }
+ }
+
+ /**
+ * Serialize temporary data to S3, ready for Redshift COPY, and create a manifest file which can
+ * be used to instruct Redshift to load the non-empty temporary data partitions.
+ *
+ * @return the URL of the manifest file in S3, in `s3://path/to/file/manifest.json` format, if
+ * at least one record was written, and None otherwise.
+ */
+ private def unloadData(
+ sqlContext: SQLContext,
+ data: DataFrame,
+ tempDir: String,
+ tempFormat: String,
+ nullString: String): Option[String] = {
+ // spark-avro does not support Date types. In addition, it converts Timestamps into longs
+ // (milliseconds since the Unix epoch). Redshift is capable of loading timestamps in
+ // 'epochmillisecs' format but there's no equivalent format for dates. To work around this, we
+ // choose to write out both dates and timestamps as strings.
+ // For additional background and discussion, see #39.
+
+ // Convert the rows so that timestamps and dates become formatted strings.
+ // Formatters are not thread-safe, and thus these functions are not thread-safe.
+ // However, each task gets its own deserialized copy, making this safe.
+ val conversionFunctions: Array[Any => Any] = data.schema.fields.map { field =>
+ field.dataType match {
+ case DateType =>
+ val dateFormat = Conversions.createRedshiftDateFormat()
+ (v: Any) => {
+ if (v == null) null else dateFormat.format(v.asInstanceOf[Date])
+ }
+ case TimestampType =>
+ val timestampFormat = Conversions.createRedshiftTimestampFormat()
+ (v: Any) => {
+ if (v == null) null else timestampFormat.format(v.asInstanceOf[Timestamp])
+ }
+ case _ => (v: Any) => v
+ }
+ }
+
+ // Use Spark accumulators to determine which partitions were non-empty.
+ val nonEmptyPartitions =
+ sqlContext.sparkContext.accumulableCollection(mutable.HashSet.empty[Int])
+
+ val convertedRows: RDD[Row] = data.rdd.mapPartitions { iter: Iterator[Row] =>
+ if (iter.hasNext) {
+ nonEmptyPartitions += TaskContext.get.partitionId()
+ }
+ iter.map { row =>
+ val convertedValues: Array[Any] = new Array(conversionFunctions.length)
+ var i = 0
+ while (i < conversionFunctions.length) {
+ convertedValues(i) = conversionFunctions(i)(row(i))
+ i += 1
+ }
+ Row.fromSeq(convertedValues)
+ }
+ }
+
+ // Convert all column names to lowercase, which is necessary for Redshift to be able to load
+ // those columns (see #51).
+ val schemaWithLowercaseColumnNames: StructType =
+ StructType(data.schema.map(f => f.copy(name = f.name.toLowerCase)))
+
+ if (schemaWithLowercaseColumnNames.map(_.name).toSet.size != data.schema.size) {
+ throw new IllegalArgumentException(
+ "Cannot save table to Redshift because two or more column names would be identical" +
+ " after conversion to lowercase: " + data.schema.map(_.name).mkString(", "))
+ }
+
+ // Update the schema so that Avro writes date and timestamp columns as formatted timestamp
+ // strings. This is necessary for Redshift to be able to load these columns (see #39).
+ val convertedSchema: StructType = StructType(
+ schemaWithLowercaseColumnNames.map {
+ case StructField(name, DateType, nullable, meta) =>
+ StructField(name, StringType, nullable, meta)
+ case StructField(name, TimestampType, nullable, meta) =>
+ StructField(name, StringType, nullable, meta)
+ case other => other
+ }
+ )
+
+ val writer = sqlContext.createDataFrame(convertedRows, convertedSchema).write
+ (tempFormat match {
+ case "AVRO" =>
+ writer.format("com.databricks.spark.avro")
+ case "CSV" =>
+ writer.format("csv")
+ .option("escape", "\"")
+ .option("nullValue", nullString)
+ case "CSV GZIP" =>
+ writer.format("csv")
+ .option("escape", "\"")
+ .option("nullValue", nullString)
+ .option("compression", "gzip")
+ }).save(tempDir)
+
+ if (nonEmptyPartitions.value.isEmpty) {
+ None
+ } else {
+ // See https://docs.aws.amazon.com/redshift/latest/dg/loading-data-files-using-manifest.html
+ // for a description of the manifest file format. The URLs in this manifest must be absolute
+ // and complete.
+
+ // The partition filenames are of the form part-r-XXXXX-UUID.fileExtension.
+ val fs = FileSystem.get(URI.create(tempDir), sqlContext.sparkContext.hadoopConfiguration)
+ val partitionIdRegex = "^part-(?:r-)?(\\d+)[^\\d+].*$".r
+ val filesToLoad: Seq[String] = {
+ val nonEmptyPartitionIds = nonEmptyPartitions.value.toSet
+ fs.listStatus(new Path(tempDir)).map(_.getPath.getName).collect {
+ case file @ partitionIdRegex(id) if nonEmptyPartitionIds.contains(id.toInt) => file
+ }
+ }
+ // It's possible that tempDir contains AWS access keys. We shouldn't save those credentials to
+ // S3, so let's first sanitize `tempdir` and make sure that it uses the s3:// scheme:
+ val sanitizedTempDir = Utils.fixS3Url(
+ Utils.removeCredentialsFromURI(URI.create(tempDir)).toString).stripSuffix("/")
+ val manifestEntries = filesToLoad.map { file =>
+ s"""{"url":"$sanitizedTempDir/$file", "mandatory":true}"""
+ }
+ val manifest = s"""{"entries": [${manifestEntries.mkString(",\n")}]}"""
+ val manifestPath = sanitizedTempDir + "/manifest.json"
+ val fsDataOut = fs.create(new Path(manifestPath))
+ try {
+ fsDataOut.write(manifest.getBytes("utf-8"))
+ } finally {
+ fsDataOut.close()
+ }
+ Some(manifestPath)
+ }
+ }
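+
+  // For illustration (bucket, prefix, and part-file names hypothetical), the manifest written by
+  // unloadData looks roughly like:
+  //   {"entries": [
+  //     {"url":"s3://bucket/temp/<uuid>/part-r-00000-<uuid>.avro", "mandatory":true},
+  //     {"url":"s3://bucket/temp/<uuid>/part-r-00002-<uuid>.avro", "mandatory":true}
+  //   ]}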
+
+ /**
+ * Write a DataFrame to a Redshift table, using S3 and Avro or CSV serialization
+ */
+ def saveToRedshift(
+ sqlContext: SQLContext,
+ data: DataFrame,
+ saveMode: SaveMode,
+ params: MergedParameters) : Unit = {
+ if (params.table.isEmpty) {
+ throw new IllegalArgumentException(
+ "For save operations you must specify a Redshift table name with the 'dbtable' parameter")
+ }
+
+ if (!params.useStagingTable) {
+ log.warn("Setting useStagingTable=false is deprecated; instead, we recommend that you " +
+ "drop the target table yourself. For more details on this deprecation, see" +
+ "https://github.com/databricks/spark-redshift/pull/157")
+ }
+
+ val creds: AWSCredentialsProvider =
+ AWSCredentialsUtils.load(params, sqlContext.sparkContext.hadoopConfiguration)
+
+ for (
+ redshiftRegion <- Utils.getRegionForRedshiftCluster(params.jdbcUrl);
+ s3Region <- Utils.getRegionForS3Bucket(params.rootTempDir, s3ClientFactory(creds))
+ ) {
+ val regionIsSetInExtraCopyOptions =
+ params.extraCopyOptions.contains(s3Region) && params.extraCopyOptions.contains("region")
+ if (redshiftRegion != s3Region && !regionIsSetInExtraCopyOptions) {
+ log.error("The Redshift cluster and S3 bucket are in different regions " +
+ s"($redshiftRegion and $s3Region, respectively). In order to perform this cross-region " +
+ s"""write, you must add "region '$s3Region'" to the extracopyoptions parameter. """ +
+ "For more details on cross-region usage, see the README.")
+ }
+ }
+
+ // When using the Avro tempformat, log an informative error message in case any column names
+ // are unsupported by Avro's schema validation:
+ if (params.tempFormat == "AVRO") {
+ for (fieldName <- data.schema.fieldNames) {
+ // The following logic is based on Avro's Schema.validateName() method:
+ val firstChar = fieldName.charAt(0)
+ val isValid = (firstChar.isLetter || firstChar == '_') && fieldName.tail.forall { c =>
+ c.isLetterOrDigit || c == '_'
+ }
+ if (!isValid) {
+ throw new IllegalArgumentException(
+ s"The field name '$fieldName' is not supported when using the Avro tempformat. " +
+ "Try using the CSV tempformat instead. For more details, see " +
+ "https://github.com/databricks/spark-redshift/issues/84")
+ }
+ }
+ }
+
+ Utils.assertThatFileSystemIsNotS3BlockFileSystem(
+ new URI(params.rootTempDir), sqlContext.sparkContext.hadoopConfiguration)
+
+ Utils.checkThatBucketHasObjectLifecycleConfiguration(params.rootTempDir, s3ClientFactory(creds))
+
+ // Save the table's rows to S3:
+ val manifestUrl = unloadData(
+ sqlContext,
+ data,
+ tempDir = params.createPerQueryTempDir(),
+ tempFormat = params.tempFormat,
+ nullString = params.nullString)
+ val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
+ conn.setAutoCommit(false)
+ try {
+ val table: TableName = params.table.get
+ if (saveMode == SaveMode.Overwrite) {
+ // Overwrites must drop the table in case there has been a schema update
+ jdbcWrapper.executeInterruptibly(conn.prepareStatement(s"DROP TABLE IF EXISTS $table;"))
+ if (!params.useStagingTable) {
+ // If we're not using a staging table, commit now so that Redshift doesn't have to
+ // maintain a snapshot of the old table during the COPY; this sacrifices atomicity for
+ // performance.
+ conn.commit()
+ }
+ }
+ log.info(s"Loading new Redshift data to: $table")
+ doRedshiftLoad(conn, data, params, creds, manifestUrl)
+ conn.commit()
+ } catch {
+ case NonFatal(e) =>
+ try {
+ log.error("Exception thrown during Redshift load; will roll back transaction", e)
+ conn.rollback()
+ } catch {
+ case NonFatal(e2) =>
+ log.error("Exception while rolling back transaction", e2)
+ }
+ throw e
+ } finally {
+ conn.close()
+ }
+ }
+}
+
+object DefaultRedshiftWriter extends RedshiftWriter(
+ DefaultJDBCWrapper,
+ awsCredentials => new AmazonS3Client(awsCredentials))
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/SerializableConfiguration.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/SerializableConfiguration.scala
new file mode 100755
index 0000000000000..74e3d351e4e31
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/SerializableConfiguration.scala
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.io._
+
+import scala.util.control.NonFatal
+
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.hadoop.conf.Configuration
+import org.slf4j.LoggerFactory
+
+class SerializableConfiguration(@transient var value: Configuration)
+ extends Serializable with KryoSerializable {
+ @transient private[redshift] lazy val log = LoggerFactory.getLogger(getClass)
+
+ private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException {
+ out.defaultWriteObject()
+ value.write(out)
+ }
+
+ private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
+ value = new Configuration(false)
+ value.readFields(in)
+ }
+
+ private def tryOrIOException[T](block: => T): T = {
+ try {
+ block
+ } catch {
+ case e: IOException =>
+ log.error("Exception encountered", e)
+ throw e
+ case NonFatal(e) =>
+ log.error("Exception encountered", e)
+ throw new IOException(e)
+ }
+ }
+
+ def write(kryo: Kryo, out: Output): Unit = {
+ val dos = new DataOutputStream(out)
+ value.write(dos)
+ dos.flush()
+ }
+
+ def read(kryo: Kryo, in: Input): Unit = {
+ value = new Configuration(false)
+ value.readFields(new DataInputStream(in))
+ }
+}
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/TableName.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/TableName.scala
new file mode 100755
index 0000000000000..77ef3b6602b22
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/TableName.scala
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Wrapper class for representing the name of a Redshift table.
+ */
+private[redshift] case class TableName(unescapedSchemaName: String, unescapedTableName: String) {
+ private def quote(str: String) = '"' + str.replace("\"", "\"\"") + '"'
+ def escapedSchemaName: String = quote(unescapedSchemaName)
+ def escapedTableName: String = quote(unescapedTableName)
+ override def toString: String = s"$escapedSchemaName.$escapedTableName"
+}
+
+private[redshift] object TableName {
+ /**
+ * Parses a table name which is assumed to have been escaped according to Redshift's rules for
+ * delimited identifiers.
+ */
+ def parseFromEscaped(str: String): TableName = {
+ def dropOuterQuotes(s: String) =
+ if (s.startsWith("\"") && s.endsWith("\"")) s.drop(1).dropRight(1) else s
+ def unescapeQuotes(s: String) = s.replace("\"\"", "\"")
+ def unescape(s: String) = unescapeQuotes(dropOuterQuotes(s))
+ splitByDots(str) match {
+ case Seq(tableName) => TableName("PUBLIC", unescape(tableName))
+ case Seq(schemaName, tableName) => TableName(unescape(schemaName), unescape(tableName))
+ case other => throw new IllegalArgumentException(s"Could not parse table name from '$str'")
+ }
+ }
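+
+  // Illustrative examples only (hypothetical identifiers):
+  //   parseFromEscaped("my_table")               == TableName("PUBLIC", "my_table")
+  //   parseFromEscaped("my_schema.my_table")     == TableName("my_schema", "my_table")
+  //   parseFromEscaped("\"my.schema\".my_table") == TableName("my.schema", "my_table")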
+
+ /**
+ * Split by dots (.) while obeying our identifier quoting rules in order to allow dots to appear
+ * inside of quoted identifiers.
+ */
+ private def splitByDots(str: String): Seq[String] = {
+ val parts: ArrayBuffer[String] = ArrayBuffer.empty
+ val sb = new StringBuilder
+ var inQuotes: Boolean = false
+ for (c <- str) c match {
+ case '"' =>
+ // Note that double quotes are escaped by pairs of double quotes (""), so we don't need
+ // any extra code to handle them; we'll be back in inQuotes=true after seeing the pair.
+ sb.append('"')
+ inQuotes = !inQuotes
+ case '.' =>
+ if (!inQuotes) {
+ parts.append(sb.toString())
+ sb.clear()
+ } else {
+ sb.append('.')
+ }
+ case other =>
+ sb.append(other)
+ }
+ if (sb.nonEmpty) {
+ parts.append(sb.toString())
+ }
+ parts
+ }
+}
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/Utils.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/Utils.scala
new file mode 100755
index 0000000000000..ae5e4c9c9b78f
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/Utils.scala
@@ -0,0 +1,200 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.net.URI
+import java.util.UUID
+
+import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
+
+import com.amazonaws.services.s3.{AmazonS3Client, AmazonS3URI}
+import com.amazonaws.services.s3.model.BucketLifecycleConfiguration
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileSystem
+import org.slf4j.LoggerFactory
+
+/**
+ * Various arbitrary helper functions
+ */
+private[redshift] object Utils {
+
+ private val log = LoggerFactory.getLogger(getClass)
+
+ def classForName(className: String): Class[_] = {
+ val classLoader =
+ Option(Thread.currentThread().getContextClassLoader).getOrElse(this.getClass.getClassLoader)
+ // scalastyle:off
+ Class.forName(className, true, classLoader)
+ // scalastyle:on
+ }
+
+ /**
+ * Joins prefix URL a to path suffix b, and appends a trailing /, in order to create
+ * a temp directory path for S3.
+ */
+ def joinUrls(a: String, b: String): String = {
+ a.stripSuffix("/") + "/" + b.stripPrefix("/").stripSuffix("/") + "/"
+ }
+
+ /**
+ * Redshift COPY and UNLOAD commands don't support s3n or s3a, but users may wish to use them
+ * for data loads. This function converts the URL back to the s3:// format.
+ */
+ def fixS3Url(url: String): String = {
+ url.replaceAll("s3[an]://", "s3://")
+ }
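+
+  // For illustration (hypothetical bucket):
+  //   fixS3Url("s3a://my-bucket/temp") == "s3://my-bucket/temp"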
+
+ /**
+ * Factory method to create new S3URI in order to handle various library incompatibilities with
+ * older AWS Java Libraries
+ */
+ def createS3URI(url: String): AmazonS3URI = {
+ try {
+ // try to instantiate AmazonS3URI with url
+ new AmazonS3URI(url)
+ } catch {
+ case e: IllegalArgumentException if e.getMessage.
+ startsWith("Invalid S3 URI: hostname does not appear to be a valid S3 endpoint") => {
+ new AmazonS3URI(addEndpointToUrl(url))
+ }
+ }
+ }
+
+ /**
+ * Since older AWS Java Libraries do not handle S3 urls that have just the bucket name
+ * as the host, add the endpoint to the host
+ */
+ def addEndpointToUrl(url: String, domain: String = "s3.amazonaws.com"): String = {
+ val uri = new URI(url)
+ val hostWithEndpoint = uri.getHost + "." + domain
+ new URI(uri.getScheme,
+ uri.getUserInfo,
+ hostWithEndpoint,
+ uri.getPort,
+ uri.getPath,
+ uri.getQuery,
+ uri.getFragment).toString
+ }
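+
+  // For illustration (hypothetical bucket):
+  //   addEndpointToUrl("s3://my-bucket/temp") == "s3://my-bucket.s3.amazonaws.com/temp"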
+
+ /**
+ * Returns a copy of the given URI with the user credentials removed.
+ */
+ def removeCredentialsFromURI(uri: URI): URI = {
+ new URI(
+ uri.getScheme,
+ null, // no user info
+ uri.getHost,
+ uri.getPort,
+ uri.getPath,
+ uri.getQuery,
+ uri.getFragment)
+ }
+
+ // Visible for testing
+ private[redshift] var lastTempPathGenerated: String = null
+
+ /**
+ * Creates a randomly named temp directory path for intermediate data
+ */
+ def makeTempPath(tempRoot: String): String = {
+ lastTempPathGenerated = Utils.joinUrls(tempRoot, UUID.randomUUID().toString)
+ lastTempPathGenerated
+ }
+
+ /**
+   * Checks whether the S3 bucket for the given URI has an object lifecycle configuration to
+ * ensure cleanup of temporary files. If no applicable configuration is found, this method logs
+ * a helpful warning for the user.
+ */
+ def checkThatBucketHasObjectLifecycleConfiguration(
+ tempDir: String,
+ s3Client: AmazonS3Client): Unit = {
+ try {
+ val s3URI = createS3URI(Utils.fixS3Url(tempDir))
+ val bucket = s3URI.getBucket
+ assert(bucket != null, "Could not get bucket from S3 URI")
+ val key = Option(s3URI.getKey).getOrElse("")
+ val hasMatchingBucketLifecycleRule: Boolean = {
+ val rules = Option(s3Client.getBucketLifecycleConfiguration(bucket))
+ .map(_.getRules.asScala)
+ .getOrElse(Seq.empty)
+ rules.exists { rule =>
+ // Note: this only checks that there is an active rule which matches the temp directory;
+ // it does not actually check that the rule will delete the files. This check is still
+ // better than nothing, though, and we can always improve it later.
+ rule.getStatus == BucketLifecycleConfiguration.ENABLED && key.startsWith(rule.getPrefix)
+ }
+ }
+ if (!hasMatchingBucketLifecycleRule) {
+ log.warn(s"The S3 bucket $bucket does not have an object lifecycle configuration to " +
+ "ensure cleanup of temporary files. Consider configuring `tempdir` to point to a " +
+ "bucket with an object lifecycle policy that automatically deletes files after an " +
+ "expiration period. For more information, see " +
+ "https://docs.aws.amazon.com/AmazonS3/latest/dev/object-lifecycle-mgmt.html")
+ }
+ } catch {
+ case NonFatal(e) =>
+ log.warn("An error occurred while trying to read the S3 bucket lifecycle configuration", e)
+ }
+ }
+
+ /**
+ * Given a URI, verify that the Hadoop FileSystem for that URI is not the S3 block FileSystem.
+ * `spark-redshift` cannot use this FileSystem because the files written to it will not be
+ * readable by Redshift (and vice versa).
+ */
+ def assertThatFileSystemIsNotS3BlockFileSystem(uri: URI, hadoopConfig: Configuration): Unit = {
+ val fs = FileSystem.get(uri, hadoopConfig)
+ // Note that we do not want to use isInstanceOf here, since we're only interested in detecting
+ // exact matches. We compare the class names as strings in order to avoid introducing a binary
+ // dependency on classes which belong to the `hadoop-aws` JAR, as that artifact is not present
+ // in some environments (such as EMR). See #92 for details.
+ if (fs.getClass.getCanonicalName == "org.apache.hadoop.fs.s3.S3FileSystem") {
+ throw new IllegalArgumentException(
+ "spark-redshift does not support the S3 Block FileSystem. Please reconfigure `tempdir` to" +
+ "use a s3n:// or s3a:// scheme.")
+ }
+ }
+
+ /**
+ * Attempts to retrieve the region of the S3 bucket.
+ */
+ def getRegionForS3Bucket(tempDir: String, s3Client: AmazonS3Client): Option[String] = {
+ try {
+ val s3URI = createS3URI(Utils.fixS3Url(tempDir))
+ val bucket = s3URI.getBucket
+ assert(bucket != null, "Could not get bucket from S3 URI")
+ val region = s3Client.getBucketLocation(bucket) match {
+ // Map "US Standard" to us-east-1
+ case null | "US" => "us-east-1"
+ case other => other
+ }
+ Some(region)
+ } catch {
+ case NonFatal(e) =>
+ log.warn("An error occurred while trying to determine the S3 bucket's region", e)
+ None
+ }
+ }
+
+ /**
+   * Attempts to determine the region of a Redshift cluster based on its URL. It may not be
+   * possible to determine the region in some cases, such as when the Redshift cluster is placed
+   * behind a proxy.
+ */
+ def getRegionForRedshiftCluster(url: String): Option[String] = {
+ val regionRegex = """.*\.([^.]+)\.redshift\.amazonaws\.com.*""".r
+ url match {
+ case regionRegex(region) => Some(region)
+ case _ => None
+ }
+ }
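+
+  // For illustration (hypothetical cluster URL):
+  //   getRegionForRedshiftCluster(
+  //     "jdbc:redshift://examplecluster.abc123xyz789.us-west-2.redshift.amazonaws.com:5439/dev")
+  //   == Some("us-west-2")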
+}
diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/package.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/package.scala
new file mode 100755
index 0000000000000..233141efa23a7
--- /dev/null
+++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/package.scala
@@ -0,0 +1,48 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark
+
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+package object redshift {
+
+ /**
+   * Wrapper of SQLContext that provides the `redshiftFile` method.
+ */
+ implicit class RedshiftContext(sqlContext: SQLContext) {
+
+ /**
+ * Read a file unloaded from Redshift into a DataFrame.
+ * @param path input path
+ * @return a DataFrame with all string columns
+ */
+ def redshiftFile(path: String, columns: Seq[String]): DataFrame = {
+ val sc = sqlContext.sparkContext
+ val rdd = sc.newAPIHadoopFile(path, classOf[RedshiftInputFormat],
+ classOf[java.lang.Long], classOf[Array[String]], sc.hadoopConfiguration)
+ // TODO: allow setting NULL string.
+ val nullable = rdd.values.map(_.map(f => if (f.isEmpty) null else f)).map(x => Row(x: _*))
+ val schema = StructType(columns.map(c => StructField(c, StringType, nullable = true)))
+ sqlContext.createDataFrame(nullable, schema)
+ }
+
+ /**
+     * Reads a table unloaded from Redshift, casting the string columns to the given schema.
+ */
+ def redshiftFile(path: String, schema: StructType): DataFrame = {
+ val casts = schema.fields.map { field =>
+ col(field.name).cast(field.dataType).as(field.name)
+ }
+ redshiftFile(path, schema.fieldNames).select(casts: _*)
+ }
+ }
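+
+  // A minimal usage sketch (hypothetical path and column names), assuming the implicit class
+  // above has been brought into scope with `import com.databricks.spark.redshift._`:
+  //   val df = sqlContext.redshiftFile("s3n://bucket/unload-dir/", Seq("id", "name"))
+  //   df.printSchema()  // both columns are read as nullable strings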
+}
diff --git a/external/redshift/src/test/java/com/databricks/spark/redshift/S3NInMemoryFileSystem.java b/external/redshift/src/test/java/com/databricks/spark/redshift/S3NInMemoryFileSystem.java
new file mode 100755
index 0000000000000..b1a6a08a0decb
--- /dev/null
+++ b/external/redshift/src/test/java/com/databricks/spark/redshift/S3NInMemoryFileSystem.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.hadoop.fs.s3native;
+
+import org.apache.hadoop.fs.s3native.NativeS3FileSystem;
+import org.apache.hadoop.fs.s3native.InMemoryNativeFileSystemStore;
+
+/**
+ * A helper implementation of {@link NativeS3FileSystem}
+ * without actually connecting to S3 for unit testing.
+ */
+public class S3NInMemoryFileSystem extends NativeS3FileSystem {
+ public S3NInMemoryFileSystem() {
+ super(new InMemoryNativeFileSystemStore());
+ }
+}
diff --git a/external/redshift/src/test/java/org/apache/hadoop/fs/s3native/InMemoryNativeFileSystemStore.java b/external/redshift/src/test/java/org/apache/hadoop/fs/s3native/InMemoryNativeFileSystemStore.java
new file mode 100644
index 0000000000000..ac572aad40361
--- /dev/null
+++ b/external/redshift/src/test/java/org/apache/hadoop/fs/s3native/InMemoryNativeFileSystemStore.java
@@ -0,0 +1,206 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.fs.s3native;
+
+import static org.apache.hadoop.fs.s3native.NativeS3FileSystem.PATH_DELIMITER;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URI;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.SortedMap;
+import java.util.SortedSet;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.Map.Entry;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.util.Time;
+
+/**
+ *
+ * A stub implementation of {@link NativeFileSystemStore} for testing
+ * {@link NativeS3FileSystem} without actually connecting to S3.
+ *
+ */
+public class InMemoryNativeFileSystemStore implements NativeFileSystemStore {
+
+ private Configuration conf;
+
+  private SortedMap<String, FileMetadata> metadataMap =
+    new TreeMap<String, FileMetadata>();
+  private SortedMap<String, byte[]> dataMap = new TreeMap<String, byte[]>();
+
+ @Override
+ public void initialize(URI uri, Configuration conf) throws IOException {
+ this.conf = conf;
+ }
+
+ @Override
+ public void storeEmptyFile(String key) throws IOException {
+ metadataMap.put(key, new FileMetadata(key, 0, Time.now()));
+ dataMap.put(key, new byte[0]);
+ }
+
+ @Override
+ public void storeFile(String key, File file, byte[] md5Hash)
+ throws IOException {
+
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ byte[] buf = new byte[8192];
+ int numRead;
+ BufferedInputStream in = null;
+ try {
+ in = new BufferedInputStream(new FileInputStream(file));
+ while ((numRead = in.read(buf)) >= 0) {
+ out.write(buf, 0, numRead);
+ }
+ } finally {
+ if (in != null) {
+ in.close();
+ }
+ }
+ metadataMap.put(key,
+ new FileMetadata(key, file.length(), Time.now()));
+ dataMap.put(key, out.toByteArray());
+ }
+
+ @Override
+ public InputStream retrieve(String key) throws IOException {
+ return retrieve(key, 0);
+ }
+
+ @Override
+ public InputStream retrieve(String key, long byteRangeStart)
+ throws IOException {
+
+ byte[] data = dataMap.get(key);
+ File file = createTempFile();
+ BufferedOutputStream out = null;
+ try {
+ out = new BufferedOutputStream(new FileOutputStream(file));
+ out.write(data, (int) byteRangeStart,
+ data.length - (int) byteRangeStart);
+ } finally {
+ if (out != null) {
+ out.close();
+ }
+ }
+ return new FileInputStream(file);
+ }
+
+ private File createTempFile() throws IOException {
+ File dir = new File(conf.get("fs.s3.buffer.dir"));
+ if (!dir.exists() && !dir.mkdirs()) {
+ throw new IOException("Cannot create S3 buffer directory: " + dir);
+ }
+ File result = File.createTempFile("test-", ".tmp", dir);
+ result.deleteOnExit();
+ return result;
+ }
+
+ @Override
+ public FileMetadata retrieveMetadata(String key) throws IOException {
+ return metadataMap.get(key);
+ }
+
+ @Override
+ public PartialListing list(String prefix, int maxListingLength)
+ throws IOException {
+ return list(prefix, maxListingLength, null, false);
+ }
+
+ @Override
+ public PartialListing list(String prefix, int maxListingLength,
+ String priorLastKey, boolean recursive) throws IOException {
+
+ return list(prefix, recursive ? null : PATH_DELIMITER, maxListingLength, priorLastKey);
+ }
+
+ private PartialListing list(String prefix, String delimiter,
+ int maxListingLength, String priorLastKey) throws IOException {
+
+ if (prefix.length() > 0 && !prefix.endsWith(PATH_DELIMITER)) {
+ prefix += PATH_DELIMITER;
+ }
+
+    List<FileMetadata> metadata = new ArrayList<FileMetadata>();
+    SortedSet<String> commonPrefixes = new TreeSet<String>();
+ for (String key : dataMap.keySet()) {
+ if (key.startsWith(prefix)) {
+ if (delimiter == null) {
+ metadata.add(retrieveMetadata(key));
+ } else {
+ int delimIndex = key.indexOf(delimiter, prefix.length());
+ if (delimIndex == -1) {
+ metadata.add(retrieveMetadata(key));
+ } else {
+ String commonPrefix = key.substring(0, delimIndex);
+ commonPrefixes.add(commonPrefix);
+ }
+ }
+ }
+ if (metadata.size() + commonPrefixes.size() == maxListingLength) {
+        return new PartialListing(key, metadata.toArray(new FileMetadata[0]),
+          commonPrefixes.toArray(new String[0]));
+ }
+ }
+ return new PartialListing(null, metadata.toArray(new FileMetadata[0]),
+ commonPrefixes.toArray(new String[0]));
+ }
+
+ @Override
+ public void delete(String key) throws IOException {
+ metadataMap.remove(key);
+ dataMap.remove(key);
+ }
+
+ @Override
+ public void copy(String srcKey, String dstKey) throws IOException {
+ metadataMap.put(dstKey, metadataMap.get(srcKey));
+ dataMap.put(dstKey, dataMap.get(srcKey));
+ }
+
+ @Override
+ public void purge(String prefix) throws IOException {
+    Iterator<Entry<String, FileMetadata>> i =
+      metadataMap.entrySet().iterator();
+ while (i.hasNext()) {
+      Entry<String, FileMetadata> entry = i.next();
+ if (entry.getKey().startsWith(prefix)) {
+ dataMap.remove(entry.getKey());
+ i.remove();
+ }
+ }
+ }
+
+ @Override
+ public void dump() throws IOException {
+ System.out.println(metadataMap.values());
+ System.out.println(dataMap.keySet());
+ }
+}
diff --git a/external/redshift/src/test/resources/hive-site.xml b/external/redshift/src/test/resources/hive-site.xml
new file mode 100755
index 0000000000000..1a06bec091434
--- /dev/null
+++ b/external/redshift/src/test/resources/hive-site.xml
@@ -0,0 +1,22 @@
+<?xml version="1.0"?>
+<?xml-stylesheet type="text/xsl" href="configuration.xsl"?>
+
+<configuration>
+
+  <property>
+    <name>fs.permissions.umask-mode</name>
+    <value>022</value>
+    <description>Setting a value for fs.permissions.umask-mode to work around issue in HIVE-6962.
+    It has no impact in Hadoop 1.x line on HDFS operations.</description>
+  </property>
+
+</configuration>
diff --git a/external/redshift/src/test/resources/redshift_unload_data.txt b/external/redshift/src/test/resources/redshift_unload_data.txt
new file mode 100755
index 0000000000000..6b47543ba9b97
--- /dev/null
+++ b/external/redshift/src/test/resources/redshift_unload_data.txt
@@ -0,0 +1,5 @@
+1|t|2015-07-01|1234152.12312498|1.0|42|1239012341823719|23|Unicode's樂趣|2015-07-01 00:00:00.001
+1|f|2015-07-02|0|0.0|42|1239012341823719|-13|asdf|2015-07-02 00:00:00.0
+0||2015-07-03|0.0|-1.0|4141214|1239012341823719||f|2015-07-03 00:00:00
+0|f||-1234152.12312498|100000.0||1239012341823719|24|___\|_123|
+|||||||||
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/AWSCredentialsUtilsSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/AWSCredentialsUtilsSuite.scala
new file mode 100755
index 0000000000000..91f8d78d8f840
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/AWSCredentialsUtilsSuite.scala
@@ -0,0 +1,133 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import scala.language.implicitConversions
+
+import com.amazonaws.auth.{AWSSessionCredentials, BasicAWSCredentials, BasicSessionCredentials}
+import com.databricks.spark.redshift.Parameters.MergedParameters
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.SparkFunSuite
+
+class AWSCredentialsUtilsSuite extends SparkFunSuite {
+
+ val baseParams = Map(
+ "tempdir" -> "s3://foo/bar",
+ "dbtable" -> "test_schema.test_table",
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password")
+
+ private implicit def string2Params(tempdir: String): MergedParameters = {
+ Parameters.mergeParameters(baseParams ++ Map(
+ "tempdir" -> tempdir,
+ "forward_spark_s3_credentials" -> "true"))
+ }
+
+ test("credentialsString with regular keys") {
+ val creds = new BasicAWSCredentials("ACCESSKEYID", "SECRET/KEY/WITH/SLASHES")
+ val params =
+ Parameters.mergeParameters(baseParams ++ Map("forward_spark_s3_credentials" -> "true"))
+ assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, creds) ===
+ "aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY/WITH/SLASHES")
+ }
+
+ test("credentialsString with STS temporary keys") {
+ val params = Parameters.mergeParameters(baseParams ++ Map(
+ "temporary_aws_access_key_id" -> "ACCESSKEYID",
+ "temporary_aws_secret_access_key" -> "SECRET/KEY",
+ "temporary_aws_session_token" -> "SESSION/Token"))
+ assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) ===
+ "aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY;token=SESSION/Token")
+ }
+
+ test("Configured IAM roles should take precedence") {
+ val creds = new BasicSessionCredentials("ACCESSKEYID", "SECRET/KEY", "SESSION/Token")
+ val iamRole = "arn:aws:iam::123456789000:role/redshift_iam_role"
+ val params = Parameters.mergeParameters(baseParams ++ Map("aws_iam_role" -> iamRole))
+ assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) ===
+ s"aws_iam_role=$iamRole")
+ }
+
+ test("AWSCredentials.load() STS temporary keys should take precedence") {
+ val conf = new Configuration(false)
+ conf.set("fs.s3.awsAccessKeyId", "CONFID")
+ conf.set("fs.s3.awsSecretAccessKey", "CONFKEY")
+
+ val params = Parameters.mergeParameters(baseParams ++ Map(
+ "tempdir" -> "s3://URIID:URIKEY@bucket/path",
+ "temporary_aws_access_key_id" -> "key_id",
+ "temporary_aws_secret_access_key" -> "secret",
+ "temporary_aws_session_token" -> "token"
+ ))
+
+ val creds = AWSCredentialsUtils.load(params, conf).getCredentials
+ assert(creds.isInstanceOf[AWSSessionCredentials])
+ assert(creds.getAWSAccessKeyId === "key_id")
+ assert(creds.getAWSSecretKey === "secret")
+ assert(creds.asInstanceOf[AWSSessionCredentials].getSessionToken === "token")
+ }
+
+ test("AWSCredentials.load() credentials precedence for s3:// URIs") {
+ val conf = new Configuration(false)
+ conf.set("fs.s3.awsAccessKeyId", "CONFID")
+ conf.set("fs.s3.awsSecretAccessKey", "CONFKEY")
+
+ {
+ val creds = AWSCredentialsUtils.load("s3://URIID:URIKEY@bucket/path", conf).getCredentials
+ assert(creds.getAWSAccessKeyId === "URIID")
+ assert(creds.getAWSSecretKey === "URIKEY")
+ }
+
+ {
+ val creds = AWSCredentialsUtils.load("s3://bucket/path", conf).getCredentials
+ assert(creds.getAWSAccessKeyId === "CONFID")
+ assert(creds.getAWSSecretKey === "CONFKEY")
+ }
+
+ }
+
+ test("AWSCredentials.load() credentials precedence for s3n:// URIs") {
+ val conf = new Configuration(false)
+ conf.set("fs.s3n.awsAccessKeyId", "CONFID")
+ conf.set("fs.s3n.awsSecretAccessKey", "CONFKEY")
+
+ {
+ val creds = AWSCredentialsUtils.load("s3n://URIID:URIKEY@bucket/path", conf).getCredentials
+ assert(creds.getAWSAccessKeyId === "URIID")
+ assert(creds.getAWSSecretKey === "URIKEY")
+ }
+
+ {
+ val creds = AWSCredentialsUtils.load("s3n://bucket/path", conf).getCredentials
+ assert(creds.getAWSAccessKeyId === "CONFID")
+ assert(creds.getAWSSecretKey === "CONFKEY")
+ }
+
+ }
+
+ test("AWSCredentials.load() credentials precedence for s3a:// URIs") {
+ val conf = new Configuration(false)
+ conf.set("fs.s3a.access.key", "CONFID")
+ conf.set("fs.s3a.secret.key", "CONFKEY")
+
+ {
+ val creds = AWSCredentialsUtils.load("s3a://URIID:URIKEY@bucket/path", conf).getCredentials
+ assert(creds.getAWSAccessKeyId === "URIID")
+ assert(creds.getAWSSecretKey === "URIKEY")
+ }
+
+ {
+ val creds = AWSCredentialsUtils.load("s3a://bucket/path", conf).getCredentials
+ assert(creds.getAWSAccessKeyId === "CONFID")
+ assert(creds.getAWSSecretKey === "CONFKEY")
+ }
+
+ }
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala
new file mode 100755
index 0000000000000..41817311676c9
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.sql.Timestamp
+import java.util.Locale
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+
+/**
+ * Unit test for data type conversions
+ */
+class ConversionsSuite extends SparkFunSuite {
+
+ private def createRowConverter(schema: StructType) = {
+ Conversions.createRowConverter(schema).andThen(RowEncoder(schema).resolveAndBind().fromRow)
+ }
+
+ test("Data should be correctly converted") {
+ val convertRow = createRowConverter(TestUtils.testSchema)
+ val doubleMin = Double.MinValue.toString
+ val longMax = Long.MaxValue.toString
+ // scalastyle:off
+ val unicodeString = "Unicode是樂趣"
+ // scalastyle:on
+
+ val timestampWithMillis = "2014-03-01 00:00:01.123"
+
+ val expectedDateMillis = TestUtils.toMillis(2015, 6, 1, 0, 0, 0)
+ val expectedTimestampMillis = TestUtils.toMillis(2014, 2, 1, 0, 0, 1, 123)
+
+ val convertedRow = convertRow(
+ Array("1", "t", "2015-07-01", doubleMin, "1.0", "42",
+ longMax, "23", unicodeString, timestampWithMillis))
+
+ val expectedRow = Row(1.asInstanceOf[Byte], true, new Timestamp(expectedDateMillis),
+ Double.MinValue, 1.0f, 42, Long.MaxValue, 23.toShort, unicodeString,
+ new Timestamp(expectedTimestampMillis))
+
+ assert(convertedRow == expectedRow)
+ }
+
+ test("Row conversion handles null values") {
+ val convertRow = createRowConverter(TestUtils.testSchema)
+ val emptyRow = List.fill(TestUtils.testSchema.length)(null).toArray[String]
+ assert(convertRow(emptyRow) === Row(emptyRow: _*))
+ }
+
+ test("Booleans are correctly converted") {
+ val convertRow = createRowConverter(StructType(Seq(StructField("a", BooleanType))))
+ assert(convertRow(Array("t")) === Row(true))
+ assert(convertRow(Array("f")) === Row(false))
+ assert(convertRow(Array(null)) === Row(null))
+ intercept[IllegalArgumentException] {
+ convertRow(Array("not-a-boolean"))
+ }
+ }
+
+ test("timestamp conversion handles millisecond-level precision (regression test for #214)") {
+ val schema = StructType(Seq(StructField("a", TimestampType)))
+ val convertRow = createRowConverter(schema)
+ Seq(
+ "2014-03-01 00:00:01" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 1000),
+ "2014-03-01 00:00:01.000" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 1000),
+ "2014-03-01 00:00:00.1" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 100),
+ "2014-03-01 00:00:00.10" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 100),
+ "2014-03-01 00:00:00.100" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 100),
+ "2014-03-01 00:00:00.01" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 10),
+ "2014-03-01 00:00:00.010" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 10),
+ "2014-03-01 00:00:00.001" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 1)
+ ).foreach { case (timestampString, expectedTime) =>
+ withClue(s"timestamp string is '$timestampString'") {
+ val convertedRow = convertRow(Array(timestampString))
+ val convertedTimestamp = convertedRow.get(0).asInstanceOf[Timestamp]
+ assert(convertedTimestamp === new Timestamp(expectedTime))
+ }
+ }
+ }
+
+ test("RedshiftDecimalFormat is locale-insensitive (regression test for #243)") {
+ for (locale <- Seq(Locale.US, Locale.GERMAN, Locale.UK)) {
+ withClue(s"locale = $locale") {
+ TestUtils.withDefaultLocale(locale) {
+ val decimalFormat = Conversions.createRedshiftDecimalFormat()
+ val parsed = decimalFormat.parse("151.20").asInstanceOf[java.math.BigDecimal]
+ assert(parsed.doubleValue() === 151.20)
+ }
+ }
+ }
+ }
+
+ test("Row conversion properly handles NaN and Inf float values (regression test for #261)") {
+ val convertRow = createRowConverter(StructType(Seq(StructField("a", FloatType))))
+ assert(java.lang.Float.isNaN(convertRow(Array("nan")).getFloat(0)))
+ assert(convertRow(Array("inf")) === Row(Float.PositiveInfinity))
+ assert(convertRow(Array("-inf")) === Row(Float.NegativeInfinity))
+ }
+
+ test("Row conversion properly handles NaN and Inf double values (regression test for #261)") {
+ val convertRow = createRowConverter(StructType(Seq(StructField("a", DoubleType))))
+ assert(java.lang.Double.isNaN(convertRow(Array("nan")).getDouble(0)))
+ assert(convertRow(Array("inf")) === Row(Double.PositiveInfinity))
+ assert(convertRow(Array("-inf")) === Row(Double.NegativeInfinity))
+ }
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapredOutputCommitter.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapredOutputCommitter.scala
new file mode 100755
index 0000000000000..a24dd001920ec
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapredOutputCommitter.scala
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapred._
+
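+/**
+ * An OutputCommitter for the old `mapred` API whose task setup, commit, and abort methods are
+ * all no-ops; only a _SUCCESS marker is written when the job commits. Used in these tests to
+ * avoid renames that do not work well with the in-memory S3 filesystem.
+ */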
+class DirectMapredOutputCommitter extends OutputCommitter {
+ override def setupJob(jobContext: JobContext): Unit = { }
+
+ override def setupTask(taskContext: TaskAttemptContext): Unit = { }
+
+ override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = {
+ // We return true here to guard against implementations that do not handle false correctly.
+    // The meaning of returning false is not entirely clear, so it could be interpreted as an
+    // error. Returning true just means that commitTask() will be called, which is a no-op.
+ true
+ }
+
+ override def commitTask(taskContext: TaskAttemptContext): Unit = { }
+
+ override def abortTask(taskContext: TaskAttemptContext): Unit = { }
+
+ /**
+ * Creates a _SUCCESS file to indicate the entire job was successful.
+ * This mimics the behavior of FileOutputCommitter, reusing the same file name and conf option.
+ */
+ override def commitJob(context: JobContext): Unit = {
+ val conf = context.getJobConf
+ if (shouldCreateSuccessFile(conf)) {
+ val outputPath = FileOutputFormat.getOutputPath(conf)
+ if (outputPath != null) {
+ val fileSys = outputPath.getFileSystem(conf)
+ val filePath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME)
+ fileSys.create(filePath).close()
+ }
+ }
+ }
+
+ /** By default, we do create the _SUCCESS file, but we allow it to be turned off. */
+ private def shouldCreateSuccessFile(conf: JobConf): Boolean = {
+ conf.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)
+ }
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapreduceOutputCommitter.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapreduceOutputCommitter.scala
new file mode 100755
index 0000000000000..922706d30ad3b
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapreduceOutputCommitter.scala
@@ -0,0 +1,53 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat}
+
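+/**
+ * An OutputCommitter for the new `mapreduce` API whose task setup, commit, and abort methods
+ * are all no-ops; only a _SUCCESS marker is written when the job commits. Used in these tests
+ * to avoid renames that do not work well with the in-memory S3 filesystem.
+ */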
+class DirectMapreduceOutputCommitter extends OutputCommitter {
+ override def setupJob(jobContext: JobContext): Unit = { }
+
+ override def setupTask(taskContext: TaskAttemptContext): Unit = { }
+
+ override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = {
+ // We return true here to guard against implementations that do not handle false correctly.
+    // The meaning of returning false is not entirely clear, so it could be interpreted as an
+    // error. Returning true just means that commitTask() will be called, which is a no-op.
+ true
+ }
+
+ override def commitTask(taskContext: TaskAttemptContext): Unit = { }
+
+ override def abortTask(taskContext: TaskAttemptContext): Unit = { }
+
+ /**
+ * Creates a _SUCCESS file to indicate the entire job was successful.
+ * This mimics the behavior of FileOutputCommitter, reusing the same file name and conf option.
+ */
+ override def commitJob(context: JobContext): Unit = {
+ val conf = context.getConfiguration
+ if (shouldCreateSuccessFile(conf)) {
+ val outputPath = FileOutputFormat.getOutputPath(context)
+ if (outputPath != null) {
+ val fileSys = outputPath.getFileSystem(conf)
+ val filePath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME)
+ fileSys.create(filePath).close()
+ }
+ }
+ }
+
+ /** By default, we do create the _SUCCESS file, but we allow it to be turned off. */
+ private def shouldCreateSuccessFile(conf: Configuration): Boolean = {
+ conf.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)
+ }
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/FilterPushdownSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/FilterPushdownSuite.scala
new file mode 100755
index 0000000000000..d3e09d5e8a733
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/FilterPushdownSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import com.databricks.spark.redshift.FilterPushdown._
+
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+import org.apache.spark.SparkFunSuite
+
+class FilterPushdownSuite extends SparkFunSuite {
+ test("buildWhereClause with empty list of filters") {
+ assert(buildWhereClause(StructType(Nil), Seq.empty) === "")
+ }
+
+ test("buildWhereClause with no filters that can be pushed down") {
+ assert(buildWhereClause(StructType(Nil), Seq(NewFilter, NewFilter)) === "")
+ }
+
+ test("buildWhereClause with with some filters that cannot be pushed down") {
+ val whereClause = buildWhereClause(testSchema, Seq(EqualTo("test_int", 1), NewFilter))
+ assert(whereClause === """WHERE "test_int" = 1""")
+ }
+
+ test("buildWhereClause with string literals that contain Unicode characters") {
+ // scalastyle:off
+ val whereClause = buildWhereClause(testSchema, Seq(EqualTo("test_string", "Unicode's樂趣")))
+ // Here, the apostrophe in the string needs to be replaced with two single quotes, '', but we
+ // also need to escape those quotes with backslashes because this WHERE clause is going to
+ // eventually be embedded inside of a single-quoted string that's embedded inside of a larger
+ // Redshift query.
+ assert(whereClause === """WHERE "test_string" = \'Unicode\'\'s樂趣\'""")
+ // scalastyle:on
+ }
+
+ test("buildWhereClause with multiple filters") {
+ val filters = Seq(
+ EqualTo("test_bool", true),
+ // scalastyle:off
+ EqualTo("test_string", "Unicode是樂趣"),
+ // scalastyle:on
+ GreaterThan("test_double", 1000.0),
+ LessThan("test_double", Double.MaxValue),
+ GreaterThanOrEqual("test_float", 1.0f),
+ LessThanOrEqual("test_int", 43),
+ IsNotNull("test_int"),
+ IsNull("test_int"))
+ val whereClause = buildWhereClause(testSchema, filters)
+ // scalastyle:off
+ val expectedWhereClause =
+ """
+ |WHERE "test_bool" = true
+ |AND "test_string" = \'Unicode是樂趣\'
+ |AND "test_double" > 1000.0
+ |AND "test_double" < 1.7976931348623157E308
+ |AND "test_float" >= 1.0
+ |AND "test_int" <= 43
+ |AND "test_int" IS NOT NULL
+ |AND "test_int" IS NULL
+ """.stripMargin.lines.mkString(" ").trim
+ // scalastyle:on
+ assert(whereClause === expectedWhereClause)
+ }
+
+ private val testSchema: StructType = StructType(Seq(
+ StructField("test_byte", ByteType),
+ StructField("test_bool", BooleanType),
+ StructField("test_date", DateType),
+ StructField("test_double", DoubleType),
+ StructField("test_float", FloatType),
+ StructField("test_int", IntegerType),
+ StructField("test_long", LongType),
+ StructField("test_short", ShortType),
+ StructField("test_string", StringType),
+ StructField("test_timestamp", TimestampType)))
+
+  /** A new filter subclass which our pushdown logic does not know how to handle */
+ private case object NewFilter extends Filter {
+ override def references: Array[String] = Array.empty
+ }
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/MockRedshift.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/MockRedshift.scala
new file mode 100755
index 0000000000000..d665839fae5fc
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/MockRedshift.scala
@@ -0,0 +1,117 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.sql.{Connection, PreparedStatement, ResultSet, SQLException}
+
+import scala.collection.mutable
+import scala.util.matching.Regex
+
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.Assertions._
+
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Helper class for mocking Redshift / JDBC in unit tests.
+ */
+class MockRedshift(
+ jdbcUrl: String,
+ existingTablesAndSchemas: Map[String, StructType],
+ jdbcQueriesThatShouldFail: Seq[Regex] = Seq.empty) {
+
+ private[this] val queriesIssued: mutable.Buffer[String] = mutable.Buffer.empty
+ def getQueriesIssuedAgainstRedshift: Seq[String] = queriesIssued.toSeq
+
+ private[this] val jdbcConnections: mutable.Buffer[Connection] = mutable.Buffer.empty
+
+ val jdbcWrapper: JDBCWrapper = spy(new JDBCWrapper)
+
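+  /**
+   * Creates a mock JDBC connection whose prepared statements record their SQL text in
+   * `queriesIssued`. Queries matching `jdbcQueriesThatShouldFail` throw a SQLException when
+   * executed; all other queries succeed against mocked statements and result sets. The
+   * connection is tracked so that the verify* helpers below can inspect it.
+   */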
+ private def createMockConnection(): Connection = {
+ val conn = mock(classOf[Connection], RETURNS_SMART_NULLS)
+ jdbcConnections.append(conn)
+ when(conn.prepareStatement(anyString())).thenAnswer(new Answer[PreparedStatement] {
+ override def answer(invocation: InvocationOnMock): PreparedStatement = {
+ val query = invocation.getArguments()(0).asInstanceOf[String]
+ queriesIssued.append(query)
+ val mockStatement = mock(classOf[PreparedStatement], RETURNS_SMART_NULLS)
+ if (jdbcQueriesThatShouldFail.forall(_.findFirstMatchIn(query).isEmpty)) {
+ when(mockStatement.execute()).thenReturn(true)
+ when(mockStatement.executeQuery()).thenReturn(
+ mock(classOf[ResultSet], RETURNS_SMART_NULLS))
+ } else {
+ when(mockStatement.execute()).thenThrow(new SQLException(s"Error executing $query"))
+ when(mockStatement.executeQuery()).thenThrow(new SQLException(s"Error executing $query"))
+ }
+ mockStatement
+ }
+ })
+ conn
+ }
+
+ doAnswer(new Answer[Connection] {
+ override def answer(invocation: InvocationOnMock): Connection = createMockConnection()
+ }).when(jdbcWrapper)
+ .getConnector(any[Option[String]](), same(jdbcUrl), any[Option[(String, String)]]())
+
+ doAnswer(new Answer[Boolean] {
+ override def answer(invocation: InvocationOnMock): Boolean = {
+ existingTablesAndSchemas.contains(invocation.getArguments()(1).asInstanceOf[String])
+ }
+ }).when(jdbcWrapper).tableExists(any[Connection], anyString())
+
+ doAnswer(new Answer[StructType] {
+ override def answer(invocation: InvocationOnMock): StructType = {
+ existingTablesAndSchemas(invocation.getArguments()(1).asInstanceOf[String])
+ }
+ }).when(jdbcWrapper).resolveTable(any[Connection], anyString())
+
+ def verifyThatConnectionsWereClosed(): Unit = {
+ jdbcConnections.foreach { conn =>
+ verify(conn).close()
+ }
+ }
+
+ def verifyThatRollbackWasCalled(): Unit = {
+ jdbcConnections.foreach { conn =>
+ verify(conn, atLeastOnce()).rollback()
+ }
+ }
+
+ def verifyThatCommitWasNotCalled(): Unit = {
+ jdbcConnections.foreach { conn =>
+ verify(conn, never()).commit()
+ }
+ }
+
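+  /**
+   * Asserts that the queries issued through this mock match `expectedQueries`, in order, and
+   * fails with a descriptive message if a query does not match its expected pattern or if
+   * there are missing or extra queries.
+   */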
+ def verifyThatExpectedQueriesWereIssued(expectedQueries: Seq[Regex]): Unit = {
+ expectedQueries.zip(queriesIssued).foreach { case (expected, actual) =>
+ if (expected.findFirstMatchIn(actual).isEmpty) {
+ fail(
+ s"""
+ |Actual and expected JDBC queries did not match:
+ |Expected: $expected
+ |Actual: $actual
+ """.stripMargin)
+ }
+ }
+ if (expectedQueries.length > queriesIssued.length) {
+ val missingQueries = expectedQueries.drop(queriesIssued.length)
+ fail(s"Missing ${missingQueries.length} expected JDBC queries:" +
+ s"\n${missingQueries.mkString("\n")}")
+ } else if (queriesIssued.length > expectedQueries.length) {
+ val extraQueries = queriesIssued.drop(expectedQueries.length)
+ fail(s"Got ${extraQueries.length} unexpected JDBC queries:\n${extraQueries.mkString("\n")}")
+ }
+ }
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
new file mode 100755
index 0000000000000..dea37a5a31d82
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
@@ -0,0 +1,145 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import org.scalatest.Matchers
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Check validation of parameter config
+ */
+class ParametersSuite extends SparkFunSuite with Matchers {
+
+ test("Minimal valid parameter map is accepted") {
+ val params = Map(
+ "tempdir" -> "s3://foo/bar",
+ "dbtable" -> "test_schema.test_table",
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password",
+ "forward_spark_s3_credentials" -> "true")
+
+ val mergedParams = Parameters.mergeParameters(params)
+
+ mergedParams.rootTempDir should startWith (params("tempdir"))
+ mergedParams.createPerQueryTempDir() should startWith (params("tempdir"))
+ mergedParams.jdbcUrl shouldBe params("url")
+ mergedParams.table shouldBe Some(TableName("test_schema", "test_table"))
+ assert(mergedParams.forwardSparkS3Credentials)
+
+ // Check that the defaults have been added
+ (Parameters.DEFAULT_PARAMETERS - "forward_spark_s3_credentials").foreach {
+ case (key, value) => mergedParams.parameters(key) shouldBe value
+ }
+ }
+
+ test("createPerQueryTempDir() returns distinct temp paths") {
+ val params = Map(
+ "forward_spark_s3_credentials" -> "true",
+ "tempdir" -> "s3://foo/bar",
+ "dbtable" -> "test_table",
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password")
+
+ val mergedParams = Parameters.mergeParameters(params)
+
+ mergedParams.createPerQueryTempDir() should not equal mergedParams.createPerQueryTempDir()
+ }
+
+ test("Errors are thrown when mandatory parameters are not provided") {
+ def checkMerge(params: Map[String, String], err: String): Unit = {
+ val e = intercept[IllegalArgumentException] {
+ Parameters.mergeParameters(params)
+ }
+ assert(e.getMessage.contains(err))
+ }
+ val testURL = "jdbc:redshift://foo/bar?user=user&password=password"
+ checkMerge(Map("dbtable" -> "test_table", "url" -> testURL), "tempdir")
+ checkMerge(Map("tempdir" -> "s3://foo/bar", "url" -> testURL), "Redshift table name")
+ checkMerge(Map("dbtable" -> "test_table", "tempdir" -> "s3://foo/bar"), "JDBC URL")
+ checkMerge(Map("dbtable" -> "test_table", "tempdir" -> "s3://foo/bar", "url" -> testURL),
+ "method for authenticating")
+ }
+
+ test("Must specify either 'dbtable' or 'query' parameter, but not both") {
+ intercept[IllegalArgumentException] {
+ Parameters.mergeParameters(Map(
+ "forward_spark_s3_credentials" -> "true",
+ "tempdir" -> "s3://foo/bar",
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password"))
+ }.getMessage should (include ("dbtable") and include ("query"))
+
+ intercept[IllegalArgumentException] {
+ Parameters.mergeParameters(Map(
+ "forward_spark_s3_credentials" -> "true",
+ "tempdir" -> "s3://foo/bar",
+ "dbtable" -> "test_table",
+ "query" -> "select * from test_table",
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password"))
+ }.getMessage should (include ("dbtable") and include ("query") and include("both"))
+
+ Parameters.mergeParameters(Map(
+ "forward_spark_s3_credentials" -> "true",
+ "tempdir" -> "s3://foo/bar",
+ "query" -> "select * from test_table",
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password"))
+ }
+
+ test("Must specify credentials in either URL or 'user' and 'password' parameters, but not both") {
+ intercept[IllegalArgumentException] {
+ Parameters.mergeParameters(Map(
+ "forward_spark_s3_credentials" -> "true",
+ "tempdir" -> "s3://foo/bar",
+ "query" -> "select * from test_table",
+ "url" -> "jdbc:redshift://foo/bar"))
+ }.getMessage should (include ("credentials"))
+
+ intercept[IllegalArgumentException] {
+ Parameters.mergeParameters(Map(
+ "forward_spark_s3_credentials" -> "true",
+ "tempdir" -> "s3://foo/bar",
+ "query" -> "select * from test_table",
+ "user" -> "user",
+ "password" -> "password",
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password"))
+ }.getMessage should (include ("credentials") and include("both"))
+
+ Parameters.mergeParameters(Map(
+ "forward_spark_s3_credentials" -> "true",
+ "tempdir" -> "s3://foo/bar",
+ "query" -> "select * from test_table",
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password"))
+ }
+
+ test("tempformat option is case-insensitive") {
+ val params = Map(
+ "forward_spark_s3_credentials" -> "true",
+ "tempdir" -> "s3://foo/bar",
+ "dbtable" -> "test_schema.test_table",
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password")
+
+ Parameters.mergeParameters(params + ("tempformat" -> "csv"))
+ Parameters.mergeParameters(params + ("tempformat" -> "CSV"))
+
+ intercept[IllegalArgumentException] {
+ Parameters.mergeParameters(params + ("tempformat" -> "invalid-temp-format"))
+ }
+ }
+
+ test("can only specify one Redshift to S3 authentication mechanism") {
+ val e = intercept[IllegalArgumentException] {
+ Parameters.mergeParameters(Map(
+ "tempdir" -> "s3://foo/bar",
+ "dbtable" -> "test_schema.test_table",
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password",
+ "forward_spark_s3_credentials" -> "true",
+ "aws_iam_role" -> "role"))
+ }
+ assert(e.getMessage.contains("mutually-exclusive"))
+ }
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/QueryTest.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/QueryTest.scala
new file mode 100755
index 0000000000000..60533741e61ef
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/QueryTest.scala
@@ -0,0 +1,80 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.plans.logical
+
+/**
+ * Copy of Spark SQL's `QueryTest` trait.
+ */
+trait QueryTest extends SparkFunSuite {
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ * @param df the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ val isSorted = df.queryExecution.logical.collect { case s: logical.Sort => s }.nonEmpty
+ def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
+ // Converts data to types that we can do equality comparison using Scala collections.
+ // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+ // Java's java.math.BigDecimal.compareTo).
+      // For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for
+      // the equality test.
+ val converted: Seq[Row] = answer.map { s =>
+ Row.fromSeq(s.toSeq.map {
+ case d: java.math.BigDecimal => BigDecimal(d)
+ case b: Array[Byte] => b.toSeq
+ case o => o
+ })
+ }
+ if (!isSorted) converted.sortBy(_.toString()) else converted
+ }
+ val sparkAnswer = try df.collect().toSeq catch {
+ case e: Exception =>
+ val errorMessage =
+ s"""
+ |Exception thrown while executing query:
+ |${df.queryExecution}
+ |== Exception ==
+ |$e
+ |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ fail(errorMessage)
+ }
+
+ if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
+ val errorMessage =
+ s"""
+ |Results do not match for query:
+ |${df.queryExecution}
+ |== Results ==
+ |${sideBySide(
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString()),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
+ """.stripMargin
+ fail(errorMessage)
+ }
+ }
+
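+  /** Formats two sequences of lines side by side, prefixing rows that differ with a '!'. */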
+ private def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = {
+ val maxLeftSize = left.map(_.length).max
+ val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("")
+ val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("")
+
+ leftPadded.zip(rightPadded).map {
+ case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r
+ }
+ }
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftInputFormatSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftInputFormatSuite.scala
new file mode 100755
index 0000000000000..aaf496a82cd11
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftInputFormatSuite.scala
@@ -0,0 +1,147 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.io.{DataOutputStream, File, FileOutputStream}
+
+import scala.language.implicitConversions
+
+import com.databricks.spark.redshift.RedshiftInputFormat._
+import com.google.common.io.Files
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{SparkContext, SparkFunSuite}
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.types._
+
+class RedshiftInputFormatSuite extends SparkFunSuite {
+
+ import RedshiftInputFormatSuite._
+
+ private var sc: SparkContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sc = new SparkContext("local", this.getClass.getName)
+ }
+
+ override def afterAll(): Unit = {
+ sc.stop()
+ super.afterAll()
+ }
+
+ private def writeToFile(contents: String, file: File): Unit = {
+ val bytes = contents.getBytes
+ val out = new DataOutputStream(new FileOutputStream(file))
+ out.write(bytes, 0, bytes.length)
+ out.close()
+ }
+
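+  /**
+   * Escapes each record in the style of Redshift's UNLOAD ... ESCAPE output: backslashes,
+   * newlines, and the delimiter character are each prefixed with a backslash, fields are joined
+   * with the delimiter, and records are separated by newlines.
+   */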
+ private def escape(records: Set[Seq[String]], delimiter: Char): String = {
+ require(delimiter != '\\' && delimiter != '\n')
+ records.map { r =>
+ r.map { f =>
+ f.replace("\\", "\\\\")
+ .replace("\n", "\\\n")
+ .replace(delimiter, "\\" + delimiter)
+ }.mkString(delimiter)
+ }.mkString("", "\n", "\n")
+ }
+
+ private final val KEY_BLOCK_SIZE = "fs.local.block.size"
+
+ private final val TAB = '\t'
+
+ private val records = Set(
+ Seq("a\n", DEFAULT_DELIMITER + "b\\"),
+ Seq("c", TAB + "d"),
+ Seq("\ne", "\\\\f"))
+
+ private def withTempDir(func: File => Unit): Unit = {
+ val dir = Files.createTempDir()
+ dir.deleteOnExit()
+ func(dir)
+ }
+
+ test("default delimiter") {
+ withTempDir { dir =>
+ val escaped = escape(records, DEFAULT_DELIMITER)
+ writeToFile(escaped, new File(dir, "part-00000"))
+
+ val conf = new Configuration
+ conf.setLong(KEY_BLOCK_SIZE, 4)
+
+ val rdd = sc.newAPIHadoopFile(dir.toString, classOf[RedshiftInputFormat],
+ classOf[java.lang.Long], classOf[Array[String]], conf)
+
+      // TODO: Check this assertion - it fails on Travis only; unclear why, or what it's for
+      // assert(rdd.partitions.size > records.size) // so there is at least one empty partition
+
+ val actual = rdd.values.map(_.toSeq).collect()
+ assert(actual.size === records.size)
+ assert(actual.toSet === records)
+ }
+ }
+
+ test("customized delimiter") {
+ withTempDir { dir =>
+ val escaped = escape(records, TAB)
+ writeToFile(escaped, new File(dir, "part-00000"))
+
+ val conf = new Configuration
+ conf.setLong(KEY_BLOCK_SIZE, 4)
+ conf.set(KEY_DELIMITER, TAB)
+
+ val rdd = sc.newAPIHadoopFile(dir.toString, classOf[RedshiftInputFormat],
+ classOf[java.lang.Long], classOf[Array[String]], conf)
+
+      // TODO: Check this assertion - it fails on Travis only; unclear why, or what it's for
+      // assert(rdd.partitions.size > records.size) // so there is at least one empty partition
+
+ val actual = rdd.values.map(_.toSeq).collect()
+ assert(actual.size === records.size)
+ assert(actual.toSet === records)
+ }
+ }
+
+ test("schema parser") {
+ withTempDir { dir =>
+ val testRecords = Set(
+ Seq("a\n", "TX", 1, 1.0, 1000L, 200000000000L),
+ Seq("b", "CA", 2, 2.0, 2000L, 1231412314L))
+ val escaped = escape(testRecords.map(_.map(_.toString)), DEFAULT_DELIMITER)
+ writeToFile(escaped, new File(dir, "part-00000"))
+
+ val sqlContext = new SQLContext(sc)
+ val expectedSchema = StructType(Seq(
+ StructField("name", StringType, nullable = true),
+ StructField("state", StringType, nullable = true),
+ StructField("id", IntegerType, nullable = true),
+ StructField("score", DoubleType, nullable = true),
+ StructField("big_score", LongType, nullable = true),
+ StructField("some_long", LongType, nullable = true)))
+
+ val df = sqlContext.redshiftFile(dir.toString, expectedSchema)
+ assert(df.schema === expectedSchema)
+
+ val parsed = df.rdd.map {
+ case Row(
+ name: String, state: String, id: Int, score: Double, bigScore: Long, someLong: Long
+ ) => Seq(name, state, id, score, bigScore, someLong)
+ }.collect().toSet
+
+ assert(parsed === testRecords)
+ }
+ }
+}
+
+object RedshiftInputFormatSuite {
+ implicit def charToString(c: Char): String = c.toString
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
new file mode 100755
index 0000000000000..e039ad29ac59d
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
@@ -0,0 +1,576 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.io.{ByteArrayInputStream, OutputStreamWriter}
+import java.net.URI
+
+import com.amazonaws.services.s3.AmazonS3Client
+import com.amazonaws.services.s3.model.{BucketLifecycleConfiguration, S3Object, S3ObjectInputStream}
+import com.amazonaws.services.s3.model.BucketLifecycleConfiguration.Rule
+import com.databricks.spark.redshift.Parameters.MergedParameters
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.s3native.S3NInMemoryFileSystem
+import org.apache.http.client.methods.HttpRequestBase
+import org.mockito.Matchers._
+import org.mockito.Mockito
+import org.mockito.Mockito.when
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+
+/**
+ * Tests main DataFrame loading and writing functionality
+ */
+class RedshiftSourceSuite
+ extends QueryTest
+ with Matchers
+ with BeforeAndAfterAll
+ with BeforeAndAfterEach {
+
+ /**
+   * SparkContext for this suite, configured so that s3n:// reads and writes go to an in-memory
+   * filesystem rather than real S3, no matter what temp directory was generated and requested.
+ */
+ private var sc: SparkContext = _
+
+ private var testSqlContext: SQLContext = _
+
+ private var expectedDataDF: DataFrame = _
+
+ private var mockS3Client: AmazonS3Client = _
+
+ private var s3FileSystem: FileSystem = _
+
+ private val s3TempDir: String = "s3n://test-bucket/temp-dir/"
+
+ private var unloadedData: String = ""
+
+ // Parameters common to most tests. Some parameters are overridden in specific tests.
+ private def defaultParams: Map[String, String] = Map(
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password",
+ "tempdir" -> s3TempDir,
+ "dbtable" -> "test_table",
+ "forward_spark_s3_credentials" -> "true")
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sc = new SparkContext("local", "RedshiftSourceSuite")
+ sc.hadoopConfiguration.set("fs.s3n.impl", classOf[S3NInMemoryFileSystem].getName)
+ // We need to use a DirectOutputCommitter to work around an issue which occurs with renames
+ // while using the mocked S3 filesystem.
+ sc.hadoopConfiguration.set("spark.sql.sources.outputCommitterClass",
+ classOf[DirectMapreduceOutputCommitter].getName)
+ sc.hadoopConfiguration.set("mapred.output.committer.class",
+ classOf[DirectMapredOutputCommitter].getName)
+ sc.hadoopConfiguration.set("fs.s3.awsAccessKeyId", "test1")
+ sc.hadoopConfiguration.set("fs.s3.awsSecretAccessKey", "test2")
+ sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", "test1")
+ sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", "test2")
+ }
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ s3FileSystem = FileSystem.get(new URI(s3TempDir), sc.hadoopConfiguration)
+ testSqlContext = new SQLContext(sc)
+ expectedDataDF =
+ testSqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema)
+ // Configure a mock S3 client so that we don't hit errors when trying to access AWS in tests.
+ mockS3Client = Mockito.mock(classOf[AmazonS3Client], Mockito.RETURNS_SMART_NULLS)
+ when(mockS3Client.getBucketLifecycleConfiguration(anyString())).thenReturn(
+ new BucketLifecycleConfiguration().withRules(
+ new Rule().withPrefix("").withStatus(BucketLifecycleConfiguration.ENABLED)
+ ))
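+    // Mock the UNLOAD manifest: reading its content lazily writes `unloadedData` to the part
+    // file listed in the manifest and then returns the manifest JSON itself.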
+ val mockManifest = Mockito.mock(classOf[S3Object], Mockito.RETURNS_SMART_NULLS)
+ when(mockManifest.getObjectContent).thenAnswer {
+ new Answer[S3ObjectInputStream] {
+ override def answer(invocationOnMock: InvocationOnMock): S3ObjectInputStream = {
+ val manifest =
+ s"""
+ | {
+ | "entries": [
+ | { "url": "${Utils.fixS3Url(Utils.lastTempPathGenerated)}/part-00000" }
+ | ]
+ | }
+ """.stripMargin
+ // Write the data to the output file specified in the manifest:
+ val out = s3FileSystem.create(new Path(s"${Utils.lastTempPathGenerated}/part-00000"))
+ val ow = new OutputStreamWriter(out.getWrappedStream)
+ ow.write(unloadedData)
+ ow.close()
+ out.close()
+ val is = new ByteArrayInputStream(manifest.getBytes("UTF-8"))
+ new S3ObjectInputStream(
+ is,
+ Mockito.mock(classOf[HttpRequestBase], Mockito.RETURNS_SMART_NULLS))
+ }
+ }
+ }
+ when(mockS3Client.getObject(anyString(), endsWith("manifest"))).thenReturn(mockManifest)
+ }
+
+ override def afterEach(): Unit = {
+ super.afterEach()
+ testSqlContext = null
+ expectedDataDF = null
+ mockS3Client = null
+ FileSystem.closeAll()
+ }
+
+ override def afterAll(): Unit = {
+ sc.stop()
+ super.afterAll()
+ }
+
+ test("DefaultSource can load Redshift UNLOAD output to a DataFrame") {
+ // scalastyle:off
+ unloadedData =
+ """
+ |1|t|2015-07-01|1234152.12312498|1.0|42|1239012341823719|23|Unicode's樂趣|2015-07-01 00:00:00.001
+ |1|f|2015-07-02|0|0.0|42|1239012341823719|-13|asdf|2015-07-02 00:00:00.0
+ |0||2015-07-03|0.0|-1.0|4141214|1239012341823719||f|2015-07-03 00:00:00
+ |0|f||-1234152.12312498|100000.0||1239012341823719|24|___\|_123|
+ ||||||||||
+ """.stripMargin.trim
+ // scalastyle:on
+ val expectedQuery = (
+ "UNLOAD \\('SELECT \"testbyte\", \"testbool\", \"testdate\", \"testdouble\"," +
+ " \"testfloat\", \"testint\", \"testlong\", \"testshort\", \"teststring\", " +
+ "\"testtimestamp\" " +
+ "FROM \"PUBLIC\".\"test_table\" '\\) " +
+ "TO '.*' " +
+ "WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
+ "ESCAPE").r
+ val mockRedshift = new MockRedshift(
+ defaultParams("url"),
+ Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))
+
+ // Assert that we've loaded and converted all data in the test file
+ val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+ val relation = source.createRelation(testSqlContext, defaultParams)
+ val df = testSqlContext.baseRelationToDataFrame(relation)
+ checkAnswer(df, TestUtils.expectedData)
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedQuery))
+ }
+
+ test("Can load output of Redshift queries") {
+ // scalastyle:off
+ val expectedJDBCQuery =
+ """
+ |UNLOAD \('SELECT "testbyte", "testbool" FROM
+ | \(select testbyte, testbool
+ | from test_table
+ | where teststring = \\'\\\\\\\\Unicode\\'\\'s樂趣\\'\) '\)
+ """.stripMargin.lines.map(_.trim).mkString(" ").trim.r
+ val query =
+ """select testbyte, testbool from test_table where teststring = '\\Unicode''s樂趣'"""
+ unloadedData = "1|t"
+ // scalastyle:on
+ val querySchema =
+ StructType(Seq(StructField("testbyte", ByteType), StructField("testbool", BooleanType)))
+
+ val expectedValues = Array(Row(1.toByte, true))
+
+ // Test with dbtable parameter that wraps the query in parens:
+ {
+ val params = defaultParams + ("dbtable" -> s"($query)")
+ val mockRedshift =
+ new MockRedshift(defaultParams("url"), Map(params("dbtable") -> querySchema))
+ val relation = new DefaultSource(
+ mockRedshift.jdbcWrapper, _ => mockS3Client).createRelation(testSqlContext, params)
+ assert(testSqlContext.baseRelationToDataFrame(relation).collect() === expectedValues)
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedJDBCQuery))
+ }
+
+ // Test with query parameter
+ {
+ val params = defaultParams - "dbtable" + ("query" -> query)
+ val mockRedshift = new MockRedshift(defaultParams("url"), Map(s"($query)" -> querySchema))
+ val relation = new DefaultSource(
+ mockRedshift.jdbcWrapper, _ => mockS3Client).createRelation(testSqlContext, params)
+ assert(testSqlContext.baseRelationToDataFrame(relation).collect() === expectedValues)
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedJDBCQuery))
+ }
+ }
+
+ test("DefaultSource supports simple column filtering") {
+ // scalastyle:off
+ unloadedData =
+ """
+ |1|t
+ |1|f
+ |0|
+ |0|f
+ ||
+ """.stripMargin.trim
+ // scalastyle:on
+ val expectedQuery = (
+ "UNLOAD \\('SELECT \"testbyte\", \"testbool\" FROM \"PUBLIC\".\"test_table\" '\\) " +
+ "TO '.*' " +
+ "WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
+ "ESCAPE").r
+ val mockRedshift =
+ new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema))
+ // Construct the source with a custom schema
+ val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+ val relation = source.createRelation(testSqlContext, defaultParams, TestUtils.testSchema)
+ val resultSchema =
+ StructType(Seq(StructField("testbyte", ByteType), StructField("testbool", BooleanType)))
+
+ val rdd = relation.asInstanceOf[PrunedFilteredScan]
+ .buildScan(Array("testbyte", "testbool"), Array.empty[Filter])
+ .mapPartitions { iter =>
+ val fromRow = RowEncoder(resultSchema).resolveAndBind().fromRow _
+ iter.asInstanceOf[Iterator[InternalRow]].map(fromRow)
+ }
+ val prunedExpectedValues = Array(
+ Row(1.toByte, true),
+ Row(1.toByte, false),
+ Row(0.toByte, null),
+ Row(0.toByte, false),
+ Row(null, null))
+ assert(rdd.collect() === prunedExpectedValues)
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedQuery))
+ }
+
+ test("DefaultSource supports user schema, pruned and filtered scans") {
+ // scalastyle:off
+ unloadedData = "1|t"
+ val expectedQuery = (
+ "UNLOAD \\('SELECT \"testbyte\", \"testbool\" " +
+ "FROM \"PUBLIC\".\"test_table\" " +
+ "WHERE \"testbool\" = true " +
+ "AND \"teststring\" = \\\\'Unicode\\\\'\\\\'s樂趣\\\\' " +
+ "AND \"testdouble\" > 1000.0 " +
+ "AND \"testdouble\" < 1.7976931348623157E308 " +
+ "AND \"testfloat\" >= 1.0 " +
+ "AND \"testint\" <= 43'\\) " +
+ "TO '.*' " +
+ "WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
+ "ESCAPE").r
+ // scalastyle:on
+ val mockRedshift = new MockRedshift(
+ defaultParams("url"),
+ Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))
+
+ // Construct the source with a custom schema
+ val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+ val relation = source.createRelation(testSqlContext, defaultParams, TestUtils.testSchema)
+ val resultSchema =
+ StructType(Seq(StructField("testbyte", ByteType), StructField("testbool", BooleanType)))
+
+ // Define a simple filter to only include a subset of rows
+ val filters: Array[Filter] = Array(
+ EqualTo("testbool", true),
+ // scalastyle:off
+ EqualTo("teststring", "Unicode's樂趣"),
+ // scalastyle:on
+ GreaterThan("testdouble", 1000.0),
+ LessThan("testdouble", Double.MaxValue),
+ GreaterThanOrEqual("testfloat", 1.0f),
+ LessThanOrEqual("testint", 43))
+ val rdd = relation.asInstanceOf[PrunedFilteredScan]
+ .buildScan(Array("testbyte", "testbool"), filters)
+ .mapPartitions { iter =>
+ val fromRow = RowEncoder(resultSchema).resolveAndBind().fromRow _
+ iter.asInstanceOf[Iterator[InternalRow]].map(fromRow)
+ }
+
+ assert(rdd.collect() === Array(Row(1, true)))
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedQuery))
+ }
+
+ test("DefaultSource supports preactions options to run queries before running COPY command") {
+ val mockRedshift = new MockRedshift(
+ defaultParams("url"),
+ Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))
+ val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+ val params = defaultParams ++ Map(
+ "preactions" ->
+ """
+ | DELETE FROM %s WHERE id < 100;
+ | DELETE FROM %s WHERE id > 100;
+ | DELETE FROM %s WHERE id = -1;
+ """.stripMargin.trim,
+ "usestagingtable" -> "true")
+
+ val expectedCommands = Seq(
+ "DROP TABLE IF EXISTS \"PUBLIC\".\"test_table.*\"".r,
+ "CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table.*\"".r,
+ "DELETE FROM \"PUBLIC\".\"test_table.*\" WHERE id < 100".r,
+ "DELETE FROM \"PUBLIC\".\"test_table.*\" WHERE id > 100".r,
+ "DELETE FROM \"PUBLIC\".\"test_table.*\" WHERE id = -1".r,
+ "COPY \"PUBLIC\".\"test_table.*\"".r)
+
+ source.createRelation(testSqlContext, SaveMode.Overwrite, params, expectedDataDF)
+ mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
+ mockRedshift.verifyThatConnectionsWereClosed()
+ }
+
+ test("DefaultSource serializes data as Avro, then sends Redshift COPY command") {
+ val params = defaultParams ++ Map(
+ "postactions" -> "GRANT SELECT ON %s TO jeremy",
+ "diststyle" -> "KEY",
+ "distkey" -> "testint")
+
+ val expectedCommands = Seq(
+ "DROP TABLE IF EXISTS \"PUBLIC\"\\.\"test_table.*\"".r,
+ ("CREATE TABLE IF NOT EXISTS \"PUBLIC\"\\.\"test_table.*" +
+ " DISTSTYLE KEY DISTKEY \\(testint\\).*").r,
+ "COPY \"PUBLIC\"\\.\"test_table.*\"".r,
+ "GRANT SELECT ON \"PUBLIC\"\\.\"test_table\" TO jeremy".r)
+
+ val mockRedshift = new MockRedshift(
+ defaultParams("url"),
+ Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))
+
+ val relation = RedshiftRelation(
+ mockRedshift.jdbcWrapper,
+ _ => mockS3Client,
+ Parameters.mergeParameters(params),
+ userSchema = None)(testSqlContext)
+ relation.asInstanceOf[InsertableRelation].insert(expectedDataDF, overwrite = true)
+
+ // Make sure we wrote the data out ready for Redshift load, in the expected formats.
+ // The data should have been written to a random subdirectory of `tempdir`. Since we clear
+ // `tempdir` between every unit test, there should only be one directory here.
+ assert(s3FileSystem.listStatus(new Path(s3TempDir)).length === 1)
+ val dirWithAvroFiles = s3FileSystem.listStatus(new Path(s3TempDir)).head.getPath.toUri.toString
+ val written = testSqlContext.read.format("com.databricks.spark.avro").load(dirWithAvroFiles)
+ checkAnswer(written, TestUtils.expectedDataWithConvertedTimesAndDates)
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
+ }
+
+ test("Cannot write table with column names that become ambiguous under case insensitivity") {
+ val mockRedshift = new MockRedshift(
+ defaultParams("url"),
+ Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))
+
+ val schema = StructType(Seq(StructField("a", IntegerType), StructField("A", IntegerType)))
+ val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema)
+ val writer = new RedshiftWriter(mockRedshift.jdbcWrapper, _ => mockS3Client)
+
+ intercept[IllegalArgumentException] {
+ writer.saveToRedshift(
+ testSqlContext, df, SaveMode.Append, Parameters.mergeParameters(defaultParams))
+ }
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatCommitWasNotCalled()
+ mockRedshift.verifyThatRollbackWasCalled()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(Seq.empty)
+ }
+
+ test("Failed copies are handled gracefully when using a staging table") {
+ val params = defaultParams ++ Map("usestagingtable" -> "true")
+
+ val mockRedshift = new MockRedshift(
+ defaultParams("url"),
+ Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema),
+ jdbcQueriesThatShouldFail = Seq("COPY \"PUBLIC\".\"test_table.*\"".r))
+
+ val expectedCommands = Seq(
+ "DROP TABLE IF EXISTS \"PUBLIC\".\"test_table.*\"".r,
+ "CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table.*\"".r,
+ "COPY \"PUBLIC\".\"test_table.*\"".r,
+ ".*FROM stl_load_errors.*".r
+ )
+
+ val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+ intercept[Exception] {
+ source.createRelation(testSqlContext, SaveMode.Overwrite, params, expectedDataDF)
+ }
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatCommitWasNotCalled()
+ mockRedshift.verifyThatRollbackWasCalled()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
+ }
+
+ test("Append SaveMode doesn't destroy existing data") {
+ val expectedCommands =
+ Seq("CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r,
+ "COPY \"PUBLIC\".\"test_table\" .*".r)
+
+ val mockRedshift = new MockRedshift(
+ defaultParams("url"),
+ Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null))
+
+ val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+ source.createRelation(testSqlContext, SaveMode.Append, defaultParams, expectedDataDF)
+
+ // This test is "appending" to an empty table, so we expect all our test data to be
+ // the only content in the returned data frame.
+ // The data should have been written to a random subdirectory of `tempdir`. Since we clear
+ // `tempdir` between every unit test, there should only be one directory here.
+ assert(s3FileSystem.listStatus(new Path(s3TempDir)).length === 1)
+ val dirWithAvroFiles = s3FileSystem.listStatus(new Path(s3TempDir)).head.getPath.toUri.toString
+ val written = testSqlContext.read.format("com.databricks.spark.avro").load(dirWithAvroFiles)
+ checkAnswer(written, TestUtils.expectedDataWithConvertedTimesAndDates)
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
+ }
+
+ test("configuring maxlength on string columns") {
+ val longStrMetadata = new MetadataBuilder().putLong("maxlength", 512).build()
+ val shortStrMetadata = new MetadataBuilder().putLong("maxlength", 10).build()
+ val schema = StructType(
+ StructField("long_str", StringType, metadata = longStrMetadata) ::
+ StructField("short_str", StringType, metadata = shortStrMetadata) ::
+ StructField("default_str", StringType) ::
+ Nil)
+ val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema)
+ val createTableCommand =
+ DefaultRedshiftWriter.createTableSql(df, MergedParameters.apply(defaultParams)).trim
+ val expectedCreateTableCommand =
+ """CREATE TABLE IF NOT EXISTS "PUBLIC"."test_table" ("long_str" VARCHAR(512),""" +
+ """ "short_str" VARCHAR(10), "default_str" TEXT)"""
+ assert(createTableCommand === expectedCreateTableCommand)
+ }
+
+ test("configuring encoding on columns") {
+ val lzoMetadata = new MetadataBuilder().putString("encoding", "LZO").build()
+ val runlengthMetadata = new MetadataBuilder().putString("encoding", "RUNLENGTH").build()
+ val schema = StructType(
+ StructField("lzo_str", StringType, metadata = lzoMetadata) ::
+ StructField("runlength_str", StringType, metadata = runlengthMetadata) ::
+ StructField("default_str", StringType) ::
+ Nil)
+ val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema)
+ val createTableCommand =
+ DefaultRedshiftWriter.createTableSql(df, MergedParameters.apply(defaultParams)).trim
+ val expectedCreateTableCommand =
+ """CREATE TABLE IF NOT EXISTS "PUBLIC"."test_table" ("lzo_str" TEXT ENCODE LZO,""" +
+ """ "runlength_str" TEXT ENCODE RUNLENGTH, "default_str" TEXT)"""
+ assert(createTableCommand === expectedCreateTableCommand)
+ }
+
+ test("configuring descriptions on columns") {
+ val descriptionMetadata1 = new MetadataBuilder().putString("description", "Test1").build()
+ val descriptionMetadata2 = new MetadataBuilder().putString("description", "Test'2").build()
+ val schema = StructType(
+ StructField("first_str", StringType, metadata = descriptionMetadata1) ::
+ StructField("second_str", StringType, metadata = descriptionMetadata2) ::
+ StructField("default_str", StringType) ::
+ Nil)
+ val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema)
+ val commentCommands =
+ DefaultRedshiftWriter.commentActions(Some("Test"), schema)
+ val expectedCommentCommands = List(
+ "COMMENT ON TABLE %s IS 'Test'",
+ "COMMENT ON COLUMN %s.\"first_str\" IS 'Test1'",
+ "COMMENT ON COLUMN %s.\"second_str\" IS 'Test''2'")
+ assert(commentCommands === expectedCommentCommands)
+ }
+
+ test("configuring redshift_type on columns") {
+ val bpcharMetadata = new MetadataBuilder().putString("redshift_type", "BPCHAR(2)").build()
+ val nvarcharMetadata = new MetadataBuilder().putString("redshift_type", "NVARCHAR(123)").build()
+
+    val schema = StructType(
+      StructField("bpchar_str", StringType, metadata = bpcharMetadata) ::
+      StructField("nvarchar_str", StringType, metadata = nvarcharMetadata) ::
+      StructField("default_str", StringType) ::
+      Nil)
+
+ val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema)
+ val createTableCommand =
+ DefaultRedshiftWriter.createTableSql(df, MergedParameters.apply(defaultParams)).trim
+    val expectedCreateTableCommand =
+      """CREATE TABLE IF NOT EXISTS "PUBLIC"."test_table" ("bpchar_str" BPCHAR(2),""" +
+      """ "nvarchar_str" NVARCHAR(123), "default_str" TEXT)"""
+ assert(createTableCommand === expectedCreateTableCommand)
+ }
+
+ test("Respect SaveMode.ErrorIfExists when table exists") {
+ val mockRedshift = new MockRedshift(
+ defaultParams("url"),
+ Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null))
+ val errIfExistsSource = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+ intercept[Exception] {
+ errIfExistsSource.createRelation(
+ testSqlContext, SaveMode.ErrorIfExists, defaultParams, expectedDataDF)
+ }
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(Seq.empty)
+ }
+
+ test("Do nothing when table exists if SaveMode = Ignore") {
+ val mockRedshift = new MockRedshift(
+ defaultParams("url"),
+ Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null))
+ val ignoreSource = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+ ignoreSource.createRelation(testSqlContext, SaveMode.Ignore, defaultParams, expectedDataDF)
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(Seq.empty)
+ }
+
+ test("Cannot save when 'query' parameter is specified instead of 'dbtable'") {
+ val invalidParams = Map(
+ "url" -> "jdbc:redshift://foo/bar?user=user&password=password",
+ "tempdir" -> s3TempDir,
+ "query" -> "select * from test_table",
+ "forward_spark_s3_credentials" -> "true")
+
+ val e1 = intercept[IllegalArgumentException] {
+ expectedDataDF.write.format("com.databricks.spark.redshift").options(invalidParams).save()
+ }
+ assert(e1.getMessage.contains("dbtable"))
+ }
+
+ test("Public Scala API rejects invalid parameter maps") {
+ val invalidParams = Map("dbtable" -> "foo") // missing tempdir and url
+
+ val e1 = intercept[IllegalArgumentException] {
+ expectedDataDF.write.format("com.databricks.spark.redshift").options(invalidParams).save()
+ }
+ assert(e1.getMessage.contains("tempdir"))
+
+ val e2 = intercept[IllegalArgumentException] {
+ expectedDataDF.write.format("com.databricks.spark.redshift").options(invalidParams).save()
+ }
+ assert(e2.getMessage.contains("tempdir"))
+ }
+
+ test("DefaultSource has default constructor, required by Data Source API") {
+ new DefaultSource()
+ }
+
+ test("Saves throw error message if S3 Block FileSystem would be used") {
+ val params = defaultParams + ("tempdir" -> defaultParams("tempdir").replace("s3n", "s3"))
+ val e = intercept[IllegalArgumentException] {
+ expectedDataDF.write
+ .format("com.databricks.spark.redshift")
+ .mode("append")
+ .options(params)
+ .save()
+ }
+ assert(e.getMessage.contains("Block FileSystem"))
+ }
+
+ test("Loads throw error message if S3 Block FileSystem would be used") {
+ val params = defaultParams + ("tempdir" -> defaultParams("tempdir").replace("s3n", "s3"))
+ val e = intercept[IllegalArgumentException] {
+ testSqlContext.read.format("com.databricks.spark.redshift").options(params).load()
+ }
+ assert(e.getMessage.contains("Block FileSystem"))
+ }
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/SerializableConfigurationSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/SerializableConfigurationSuite.scala
new file mode 100755
index 0000000000000..bef296d5e5de0
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/SerializableConfigurationSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance}
+
+class SerializableConfigurationSuite extends SparkFunSuite {
+
+ private def testSerialization(serializer: SerializerInstance): Unit = {
+ val conf = new SerializableConfiguration(new Configuration())
+
+ val serialized = serializer.serialize(conf)
+
+ serializer.deserialize[Any](serialized) match {
+ case c: SerializableConfiguration =>
+ assert(c.log != null, "log was null")
+ assert(c.value != null, "value was null")
+ case other => fail(
+ s"Expecting ${classOf[SerializableConfiguration]}, but got ${other.getClass}.")
+ }
+ }
+
+ test("serialization with JavaSerializer") {
+ testSerialization(new JavaSerializer(new SparkConf()).newInstance())
+ }
+
+ test("serialization with KryoSerializer") {
+ testSerialization(new KryoSerializer(new SparkConf()).newInstance())
+ }
+
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala
new file mode 100755
index 0000000000000..6ffd0a1232624
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala
@@ -0,0 +1,30 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.spark.SparkFunSuite
+
+class TableNameSuite extends SparkFunSuite {
+ test("TableName.parseFromEscaped") {
+ assert(TableName.parseFromEscaped("foo.bar") === TableName("foo", "bar"))
+ assert(TableName.parseFromEscaped("foo") === TableName("PUBLIC", "foo"))
+ assert(TableName.parseFromEscaped("\"foo\"") === TableName("PUBLIC", "foo"))
+ assert(TableName.parseFromEscaped("\"\"\"foo\"\"\".bar") === TableName("\"foo\"", "bar"))
+ // Dots (.) can also appear inside of valid identifiers.
+ assert(TableName.parseFromEscaped("\"foo.bar\".baz") === TableName("foo.bar", "baz"))
+ assert(TableName.parseFromEscaped("\"foo\"\".bar\".baz") === TableName("foo\".bar", "baz"))
+ }
+
+ test("TableName.toString") {
+ assert(TableName("foo", "bar").toString === """"foo"."bar"""")
+ assert(TableName("PUBLIC", "bar").toString === """"PUBLIC"."bar"""")
+ assert(TableName("\"foo\"", "bar").toString === "\"\"\"foo\"\"\".\"bar\"")
+ }
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/TestUtils.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/TestUtils.scala
new file mode 100755
index 0000000000000..53f837b5e0df8
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/TestUtils.scala
@@ -0,0 +1,117 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.sql.{Date, Timestamp}
+import java.util.{Calendar, Locale}
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+
+/**
+ * Shared schemas, expected data, and date/locale helpers for the Redshift test suites.
+ */
+object TestUtils {
+
+ /**
+ * Simple schema that includes all data types we support
+ */
+ val testSchema: StructType = {
+ // These column names need to be lowercase; see #51
+ StructType(Seq(
+ StructField("testbyte", ByteType),
+ StructField("testbool", BooleanType),
+ StructField("testdate", DateType),
+ StructField("testdouble", DoubleType),
+ StructField("testfloat", FloatType),
+ StructField("testint", IntegerType),
+ StructField("testlong", LongType),
+ StructField("testshort", ShortType),
+ StructField("teststring", StringType),
+ StructField("testtimestamp", TimestampType)))
+ }
+
+ // scalastyle:off
+ /**
+ * Expected parsed output corresponding to the output of testData.
+ */
+ val expectedData: Seq[Row] = Seq(
+ Row(1.toByte, true, TestUtils.toDate(2015, 6, 1), 1234152.12312498,
+ 1.0f, 42, 1239012341823719L, 23.toShort, "Unicode's樂趣",
+ TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1)),
+ Row(1.toByte, false, TestUtils.toDate(2015, 6, 2), 0.0, 0.0f, 42,
+ 1239012341823719L, -13.toShort, "asdf", TestUtils.toTimestamp(2015, 6, 2, 0, 0, 0, 0)),
+ Row(0.toByte, null, TestUtils.toDate(2015, 6, 3), 0.0, -1.0f, 4141214,
+ 1239012341823719L, null, "f", TestUtils.toTimestamp(2015, 6, 3, 0, 0, 0)),
+ Row(0.toByte, false, null, -1234152.12312498, 100000.0f, null, 1239012341823719L, 24.toShort,
+ "___|_123", null),
+ Row(List.fill(10)(null): _*))
+ // scalastyle:on
+
+ /**
+ * The same as `expectedData`, but with dates and timestamps converted into string format.
+ * See #39 for context.
+ */
+ val expectedDataWithConvertedTimesAndDates: Seq[Row] = expectedData.map { row =>
+ Row.fromSeq(row.toSeq.map {
+ case t: Timestamp => Conversions.createRedshiftTimestampFormat().format(t)
+ case d: Date => Conversions.createRedshiftDateFormat().format(d)
+ case other => other
+ })
+ }
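+
+ // For illustration only (not part of the original tests): these fixtures can be turned into
+ // a DataFrame, assuming a SparkSession named `spark` is in scope, e.g.
+ //   val df = spark.createDataFrame(
+ //     spark.sparkContext.parallelize(TestUtils.expectedData), TestUtils.testSchema)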
+
+ /**
+ * Convert date components to a millisecond timestamp
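+ * The month is zero-based (0 = January), matching java.util.Calendar, and the value is
+ * computed in the JVM's default time zone.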
+ */
+ def toMillis(
+ year: Int,
+ zeroBasedMonth: Int,
+ date: Int,
+ hour: Int,
+ minutes: Int,
+ seconds: Int,
+ millis: Int = 0): Long = {
+ val calendar = Calendar.getInstance()
+ calendar.set(year, zeroBasedMonth, date, hour, minutes, seconds)
+ calendar.set(Calendar.MILLISECOND, millis)
+ calendar.getTime.getTime
+ }
+
+ /**
+ * Convert date components to a SQL [[Timestamp]].
+ */
+ def toTimestamp(
+ year: Int,
+ zeroBasedMonth: Int,
+ date: Int,
+ hour: Int,
+ minutes: Int,
+ seconds: Int,
+ millis: Int = 0): Timestamp = {
+ new Timestamp(toMillis(year, zeroBasedMonth, date, hour, minutes, seconds, millis))
+ }
+
+ /**
+ * Convert date components to a SQL [[Date]].
+ */
+ def toDate(year: Int, zeroBasedMonth: Int, date: Int): Date = {
+ new Date(toTimestamp(year, zeroBasedMonth, date, 0, 0, 0).getTime)
+ }
+
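+ /**
+ * Runs `block` with the JVM default locale temporarily set to `newDefaultLocale`,
+ * restoring the original default afterwards, even if `block` throws.
+ */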
+ def withDefaultLocale[T](newDefaultLocale: Locale)(block: => T): T = {
+ val originalDefaultLocale = Locale.getDefault
+ try {
+ Locale.setDefault(newDefaultLocale)
+ block
+ } finally {
+ Locale.setDefault(originalDefaultLocale)
+ }
+ }
+}
diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/UtilsSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/UtilsSuite.scala
new file mode 100755
index 0000000000000..20a9180fe3405
--- /dev/null
+++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/UtilsSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2016 Databricks, Inc.
+ *
+ * Portions of this software incorporate or are derived from software contained within Apache Spark,
+ * and this modified software differs from the Apache Spark software provided under the Apache
+ * License, Version 2.0, a copy of which you may obtain at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.databricks.spark.redshift
+
+import java.net.URI
+
+import org.scalatest.Matchers
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Unit tests for the helper functions in [[Utils]].
+ */
+class UtilsSuite extends SparkFunSuite with Matchers {
+
+ test("joinUrls preserves protocol information") {
+ Utils.joinUrls("s3n://foo/bar/", "/baz") shouldBe "s3n://foo/bar/baz/"
+ Utils.joinUrls("s3n://foo/bar/", "/baz/") shouldBe "s3n://foo/bar/baz/"
+ Utils.joinUrls("s3n://foo/bar/", "baz/") shouldBe "s3n://foo/bar/baz/"
+ Utils.joinUrls("s3n://foo/bar/", "baz") shouldBe "s3n://foo/bar/baz/"
+ Utils.joinUrls("s3n://foo/bar", "baz") shouldBe "s3n://foo/bar/baz/"
+ }
+
+ test("joinUrls preserves credentials") {
+ assert(
+ Utils.joinUrls("s3n://ACCESSKEY:SECRETKEY@bucket/tempdir", "subdir") ===
+ "s3n://ACCESSKEY:SECRETKEY@bucket/tempdir/subdir/")
+ }
+
+ test("fixUrl produces Redshift-compatible equivalents") {
+ Utils.fixS3Url("s3a://foo/bar/12345") shouldBe "s3://foo/bar/12345"
+ Utils.fixS3Url("s3n://foo/bar/baz") shouldBe "s3://foo/bar/baz"
+ }
+
+ test("addEndpointToUrl produces urls with endpoints added to host") {
+ Utils.addEndpointToUrl("s3a://foo/bar/12345") shouldBe "s3a://foo.s3.amazonaws.com/bar/12345"
+ Utils.addEndpointToUrl("s3n://foo/bar/baz") shouldBe "s3n://foo.s3.amazonaws.com/bar/baz"
+ }
+
+ test("temp paths are random subdirectories of root") {
+ val root = "s3n://temp/"
+ val firstTempPath = Utils.makeTempPath(root)
+
+ Utils.makeTempPath(root) should (startWith (root) and endWith ("/")
+ and not equal root and not equal firstTempPath)
+ }
+
+ test("removeCredentialsFromURI removes AWS access keys") {
+ def removeCreds(uri: String): String = {
+ Utils.removeCredentialsFromURI(URI.create(uri)).toString
+ }
+ assert(removeCreds("s3n://bucket/path/to/temp/dir") === "s3n://bucket/path/to/temp/dir")
+ assert(
+ removeCreds("s3n://ACCESSKEY:SECRETKEY@bucket/path/to/temp/dir") ===
+ "s3n://bucket/path/to/temp/dir")
+ }
+
+ test("getRegionForRedshiftCluster") {
+ val redshiftUrl =
+ "jdbc:redshift://example.secret.us-west-2.redshift.amazonaws.com:5439/database"
+ assert(Utils.getRegionForRedshiftCluster("mycluster.example.com") === None)
+ assert(Utils.getRegionForRedshiftCluster(redshiftUrl) === Some("us-west-2"))
+ }
+}
diff --git a/pom.xml b/pom.xml
index d269a4fcbeeb7..9e19ae588ef9c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -108,6 +108,8 @@
repl
launcher
external/avro
+ external/redshift
+ external/redshift-integration-tests
external/kafka-0-8
external/kafka-0-8-assembly
external/kafka-0-10
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 21e84021105a0..8dfba47145413 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -39,8 +39,9 @@ object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
- val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010, sqlKafka08, avro) = Seq(
- "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10", "sql-kafka-0-8", "avro"
+ val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010, sqlKafka08, avro, redshift, redshiftIntegrationTests) = Seq(
+ "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10", "sql-kafka-0-8", "avro",
+ "redshift", "redshift-integration-tests"
).map(ProjectRef(buildLocation, _))
val streamingProjects@Seq(
@@ -353,7 +354,7 @@ object SparkBuild extends PomBuild {
val mimaProjects = allProjects.filterNot { x =>
Seq(
spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn,
- unsafe, tags, sqlKafka010, sqlKafka08, avro
+ unsafe, tags, sqlKafka010, sqlKafka08, avro, redshift, redshiftIntegrationTests
).contains(x)
}
@@ -717,9 +718,9 @@ object Unidoc {
publish := {},
unidocProjectFilter in(ScalaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010, sqlKafka08, avro),
+ inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010, sqlKafka08, avro, redshift),
unidocProjectFilter in(JavaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010, sqlKafka08, avro),
+ inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010, sqlKafka08, avro, redshift),
unidocAllClasspaths in (ScalaUnidoc, unidoc) := {
ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value)