Adding sse_kms_key optional parameter and using in Redshift UNLOAD #458

Closed · wants to merge 2 commits
13 changes: 13 additions & 0 deletions README.md
@@ -624,6 +624,19 @@ for other options).</p>
<p>Note that since these options are appended to the end of the <tt>COPY</tt> command, only options that make sense
at the end of the command can be used, but that should cover most possible use cases.</p>
</td>
</tr>
<tr>
<td><tt>sse_kms_key</tt></td>
<td>No</td>
<td>No default</td>
<td>
<p>The KMS key ID to use for server-side encryption in S3 during the Redshift <tt>UNLOAD</tt> operation rather than AWS's default
encryption. The Redshift IAM role must have access to the KMS key for writing with it, and the Spark IAM role must have access
to the key for read operations. Reading the encrypted data requires no changes (AWS handles this under-the-hood) so long as
Spark's IAM role has the proper access.</p>
<p>See the <a href="https://docs.aws.amazon.com/redshift/latest/dg/t_unloading_encrypted_files.html">Redshift docs</a>
for more information.</p>
</td>
</tr>
<tr>
<td><tt>tempformat</tt> (Experimental)</td>
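For context, here is a minimal usage sketch of the new option from the Spark side (not part of this diff). The cluster URL, database, table, bucket, and KMS key ID below are placeholders; the other options shown already exist in the connector.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: the JDBC URL, table, tempdir, and key ID are placeholders.
val spark = SparkSession.builder().appName("kms-unload-example").getOrCreate()

val df = spark.read
  .format("com.databricks.spark.redshift")
  .option("url", "jdbc:redshift://example-cluster:5439/dev?user=user&password=pass")
  .option("dbtable", "my_table")
  .option("tempdir", "s3a://my-bucket/temp/")
  .option("forward_spark_s3_credentials", "true")
  // New option from this PR: Redshift UNLOADs to S3 encrypted with this KMS key
  .option("sse_kms_key", "1234abcd-12ab-34cd-56ef-1234567890ab")
  .load()
```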
6 changes: 6 additions & 0 deletions src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -30,6 +30,7 @@ private[redshift] object Parameters {
// * sortkeyspec has no default, but is optional
// * distkey has no default, but is optional unless using diststyle KEY
// * jdbcdriver has no default, but is optional
// * sse_kms_key has no default, but is optional

"forward_spark_s3_credentials" -> "false",
"tempformat" -> "AVRO",
@@ -285,5 +286,10 @@ private[redshift] object Parameters {
new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken))
}
}

/**
* The AWS SSE-KMS key to use for encryption during UNLOAD operations instead of AWS's default encryption
*/
def sseKmsKey: Option[String] = parameters.get("sse_kms_key")
}
}
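A small self-contained sketch of how the Option-based accessor behaves; `ParamsSketch` is a stand-in for `MergedParameters` introduced only for illustration, and the key ID is a placeholder.

```scala
// Simplified stand-in for MergedParameters, just to show the accessor's behaviour.
case class ParamsSketch(parameters: Map[String, String]) {
  def sseKmsKey: Option[String] = parameters.get("sse_kms_key")
}

val withKey    = ParamsSketch(Map("sse_kms_key" -> "1234abcd-12ab-34cd-56ef-1234567890ab"))
val withoutKey = ParamsSketch(Map.empty)

assert(withKey.sseKmsKey.contains("1234abcd-12ab-34cd-56ef-1234567890ab"))
assert(withoutKey.sseKmsKey.isEmpty) // option omitted -> None, so no KMS clause is added
```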
src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
@@ -130,7 +130,7 @@ private[redshift] case class RedshiftRelation(
} else {
// Unload data from Redshift into a temporary directory in S3:
val tempDir = params.createPerQueryTempDir()
val unloadSql = buildUnloadStmt(requiredColumns, filters, tempDir, creds)
val unloadSql = buildUnloadStmt(requiredColumns, filters, tempDir, creds, params.sseKmsKey)
log.info(unloadSql)
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
try {
@@ -176,7 +176,8 @@ private[redshift] case class RedshiftRelation(
requiredColumns: Array[String],
filters: Array[Filter],
tempDir: String,
creds: AWSCredentialsProvider): String = {
creds: AWSCredentialsProvider,
sseKmsKey: Option[String]): String = {
assert(!requiredColumns.isEmpty)
// Always quote column names:
val columnList = requiredColumns.map(col => s""""$col"""").mkString(", ")
@@ -193,7 +194,9 @@
// the credentials passed via `credsString`.
val fixedUrl = Utils.fixS3Url(Utils.removeCredentialsFromURI(new URI(tempDir)).toString)

s"UNLOAD ('$query') TO '$fixedUrl' WITH CREDENTIALS '$credsString' ESCAPE MANIFEST"
val sseKmsClause = sseKmsKey.map(key => s"KMS_KEY_ID '$key' ENCRYPTED").getOrElse("")

s"UNLOAD ('$query') TO '$fixedUrl' WITH CREDENTIALS '$credsString' ESCAPE MANIFEST $sseKmsClause"
}

private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
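To make the generated SQL concrete, here is a standalone sketch that mirrors the `sseKmsClause` logic above; the query, URL, and credentials strings are placeholders, not values from the codebase.

```scala
// Mirrors the Option.map/getOrElse pattern from buildUnloadStmt; inputs are placeholders.
def unloadSql(query: String, url: String, credsString: String, sseKmsKey: Option[String]): String = {
  val sseKmsClause = sseKmsKey.map(key => s"KMS_KEY_ID '$key' ENCRYPTED").getOrElse("")
  s"UNLOAD ('$query') TO '$url' WITH CREDENTIALS '$credsString' ESCAPE MANIFEST $sseKmsClause"
}

// With a key, the KMS clause is appended:
// UNLOAD ('SELECT 1') TO 's3://bucket/path' WITH CREDENTIALS '...' ESCAPE MANIFEST KMS_KEY_ID 'abc123' ENCRYPTED
println(unloadSql("SELECT 1", "s3://bucket/path", "aws_iam_role=...", Some("abc123")))

// Without a key, the statement is unchanged (apart from a trailing space):
println(unloadSql("SELECT 1", "s3://bucket/path", "aws_iam_role=...", None))
```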
src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
@@ -305,6 +305,50 @@ class RedshiftSourceSuite
mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedQuery))
}

test("DefaultSource adds SSE-KMS clause") {
// scalastyle:off
unloadedData =
"""
|1|t
|1|f
|0|
|0|f
||
""".stripMargin.trim
// scalastyle:on
val kmsKeyId = "abc123"
val expectedQuery = (
"UNLOAD \\('SELECT \"testbyte\", \"testbool\" FROM \"PUBLIC\".\"test_table\" '\\) " +
"TO '.*' " +
"WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
"ESCAPE MANIFEST " +
"KMS_KEY_ID 'abc123' ENCRYPTED").r
val mockRedshift =
new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema))
// Construct the source with a custom schema
val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
val params = defaultParams + ("sse_kms_key" -> kmsKeyId)
val relation = source.createRelation(testSqlContext, params, TestUtils.testSchema)
val resultSchema =
StructType(Seq(StructField("testbyte", ByteType), StructField("testbool", BooleanType)))

val rdd = relation.asInstanceOf[PrunedFilteredScan]
.buildScan(Array("testbyte", "testbool"), Array.empty[Filter])
.mapPartitions { iter =>
val fromRow = RowEncoder(resultSchema).resolveAndBind().fromRow _
iter.asInstanceOf[Iterator[InternalRow]].map(fromRow)
}
val prunedExpectedValues = Array(
Row(1.toByte, true),
Row(1.toByte, false),
Row(0.toByte, null),
Row(0.toByte, false),
Row(null, null))
assert(rdd.collect() === prunedExpectedValues)
mockRedshift.verifyThatConnectionsWereClosed()
mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedQuery))
}

test("DefaultSource supports preactions options to run queries before running COPY command") {
val mockRedshift = new MockRedshift(
defaultParams("url"),