Automatically use Redshift 4.0 or 4.1 JDBC driver, depending on which is installed

This patch refactors the driver-loading code so that it automatically uses either the Redshift JDBC 4.0 or 4.1 driver, depending on which one is installed. I also added support for automatically supplying an appropriate default JDBC driver class name when the `postgres://` subprotocol is used.
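A hedged usage sketch of what this enables (assuming a `sqlContext` in scope, e.g. in spark-shell; the cluster endpoint, credentials, and S3 bucket are placeholders, only the option names are real): no `jdbcdriver` option is needed, because the driver class is inferred from the URL's `redshift` subprotocol.

    // Placeholder endpoint, credentials, and tempdir; the driver class is picked automatically.
    val df = sqlContext.read
      .format("com.databricks.spark.redshift")
      .option("url", "jdbc:redshift://examplecluster.abc123.us-west-2.redshift.amazonaws.com:5439/dev?user=fred&password=secret")
      .option("dbtable", "my_table")
      .option("tempdir", "s3n://my-bucket/tmp/")
      .load()  // the JDBC 4.1 Redshift driver is tried first, then the 4.0 driver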

I tested all of the error-handling corner-cases manually.

Fixes #83.

Author: Josh Rosen <[email protected]>

Closes #90 from JoshRosen/auto-configure-jdbc-driver-class.
JoshRosen committed Sep 15, 2015
1 parent de513a5 commit 43f9709
Showing 7 changed files with 95 additions and 65 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -239,8 +239,8 @@ and use that as a temp location for this data.
<tr>
<td><tt>jdbcdriver</tt></td>
<td>No</td>
<td><tt>com.amazon.redshift.jdbc4.Driver</tt></td>
<td>The class name of the JDBC driver to load before JDBC operations. Must be on classpath.</td>
<td>Determined by the JDBC URL's subprotocol</td>
<td>The class name of the JDBC driver to load before JDBC operations. This class must be on the classpath. In most cases, it should not be necessary to specify this option, as the appropriate driver classname should automatically be determined by the JDBC URL's subprotocol.</td>
</tr>
<tr>
<td><tt>diststyle</tt></td>
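If the automatic choice ever needs to be overridden, the option can still be set explicitly — a minimal sketch, assuming the same kind of read configuration as in the example near the top of this commit (placeholder endpoint and credentials):

    // Hypothetical override: pin a specific driver class instead of relying on
    // the subprotocol-based default described in the row above.
    sqlContext.read
      .format("com.databricks.spark.redshift")
      .option("jdbcdriver", "com.amazon.redshift.jdbc41.Driver")
      .option("url", "jdbc:redshift://examplecluster.abc123.us-west-2.redshift.amazonaws.com:5439/dev?user=fred&password=secret")
      .option("dbtable", "my_table")
      .option("tempdir", "s3n://my-bucket/tmp/")
      .load()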
@@ -86,7 +86,7 @@ trait IntegrationSuiteBase
sc.hadoopConfiguration.setBoolean("fs.s3n.impl.disable.cache", true)
sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
conn = DefaultJDBCWrapper.getConnector("com.amazon.redshift.jdbc4.Driver", jdbcUrl)
conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl)
}

override def afterAll(): Unit = {
6 changes: 3 additions & 3 deletions src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -29,8 +29,8 @@ private[redshift] object Parameters {
// * tempdir, dbtable and url have no default and they *must* be provided
// * sortkeyspec has no default, but is optional
// * distkey has no default, but is optional unless using diststyle KEY
// * jdbcdriver has no default, but is optional

"jdbcdriver" -> "com.amazon.redshift.jdbc4.Driver",
"overwrite" -> "false",
"diststyle" -> "EVEN",
"usestagingtable" -> "true",
@@ -106,9 +106,9 @@ private[redshift] object Parameters {

/**
* The JDBC driver class name. This is used to make sure the driver is registered before
* connecting over JDBC. Default is "com.amazon.redshift.jdbc4.Driver"
* connecting over JDBC.
*/
def jdbcDriver: String = parameters("jdbcdriver")
def jdbcDriver: Option[String] = parameters.get("jdbcdriver")

/**
* If true, when writing, replace any existing data. When false, append to the table instead.
116 changes: 68 additions & 48 deletions src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala
@@ -17,7 +17,8 @@

package com.databricks.spark.redshift

import java.sql.{Connection, DriverManager, ResultSetMetaData, SQLException}
import java.net.URI
import java.sql.{Connection, Driver, DriverManager, ResultSetMetaData, SQLException}
import java.util.Properties

import scala.util.Try
@@ -34,27 +35,54 @@ private[redshift] class JDBCWrapper {

private val log = LoggerFactory.getLogger(getClass)

def registerDriver(driverClass: String): Unit = {
/**
* Given a JDBC subprotocol, returns the appropriate driver class so that it can be registered
* with Spark. If the user has explicitly specified a driver class in their configuration then
* that class will be used. Otherwise, we will attempt to load the correct driver class based on
* the JDBC subprotocol.
*
* @param jdbcSubprotocol 'redshift' or 'postgres'
* @param userProvidedDriverClass an optional user-provided explicit driver class name
* @return the driver class
*/
private def getDriverClass(
jdbcSubprotocol: String,
userProvidedDriverClass: Option[String]): Class[Driver] = {
userProvidedDriverClass.map(Utils.classForName).getOrElse {
jdbcSubprotocol match {
case "redshift" =>
try {
Utils.classForName("com.amazon.redshift.jdbc41.Driver")
} catch {
case _: ClassNotFoundException =>
try {
Utils.classForName("com.amazon.redshift.jdbc4.Driver")
} catch {
case e: ClassNotFoundException =>
throw new ClassNotFoundException(
"Could not load an Amazon Redshift JDBC driver; see the README for " +
"instructions on downloading and configuring the official Amazon driver.", e)
}
}
case "postgres" => Utils.classForName("org.postgresql.Driver")
case other => throw new IllegalArgumentException(s"Unsupported JDBC protocol: '$other'")
}
}.asInstanceOf[Class[Driver]]
}

private def registerDriver(driverClass: String): Unit = {
// DriverRegistry.register() is one of the few pieces of private Spark functionality which
// we need to rely on. This class was relocated in Spark 1.5.0, so we need to use reflection
// in order to support both Spark 1.4.x and 1.5.x.
// TODO: once 1.5.0 snapshots are on Maven, update this to switch the class name based on
// SPARK_VERSION.
val classLoader =
Option(Thread.currentThread().getContextClassLoader).getOrElse(this.getClass.getClassLoader)
if (SPARK_VERSION.startsWith("1.4")) {
val className = "org.apache.spark.sql.jdbc.package$DriverRegistry$"
// scalastyle:off
val driverRegistryClass = Class.forName(className, true, classLoader)
// scalastyle:on
val driverRegistryClass = Utils.classForName(className)
val registerMethod = driverRegistryClass.getDeclaredMethod("register", classOf[String])
val companionObject = driverRegistryClass.getDeclaredField("MODULE$").get(null)
registerMethod.invoke(companionObject, driverClass)
} else { // Spark 1.5.0+
val className = "org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry"
// scalastyle:off
val driverRegistryClass = Class.forName(className, true, classLoader)
// scalastyle:on
val driverRegistryClass = Utils.classForName(className)
val registerMethod = driverRegistryClass.getDeclaredMethod("register", classOf[String])
registerMethod.invoke(null, driverClass)
}
@@ -64,59 +92,51 @@ private[redshift] class JDBCWrapper {
* Takes a (schema, table) specification and returns the table's Catalyst
* schema.
*
* @param url - The JDBC url to fetch information from.
* @param table - The table name of the desired table. This may also be a
* @param conn A JDBC connection to the database.
* @param table The table name of the desired table. This may also be a
* SQL query wrapped in parentheses.
*
* @return A StructType giving the table's Catalyst schema.
* @throws SQLException if the table specification is garbage.
* @throws SQLException if the table contains an unsupported type.
*/
def resolveTable(url: String, table: String): StructType = {
val conn: Connection = DriverManager.getConnection(url, new Properties())
def resolveTable(conn: Connection, table: String): StructType = {
val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
try {
val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
try {
val rsmd = rs.getMetaData
val ncols = rsmd.getColumnCount
val fields = new Array[StructField](ncols)
var i = 0
while (i < ncols) {
val columnName = rsmd.getColumnLabel(i + 1)
val dataType = rsmd.getColumnType(i + 1)
val typeName = rsmd.getColumnTypeName(i + 1)
val fieldSize = rsmd.getPrecision(i + 1)
val fieldScale = rsmd.getScale(i + 1)
val isSigned = rsmd.isSigned(i + 1)
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val columnType = getCatalystType(dataType, fieldSize, fieldScale, isSigned)
fields(i) = StructField(columnName, columnType, nullable)
i = i + 1
}
return new StructType(fields)
} finally {
rs.close()
val rsmd = rs.getMetaData
val ncols = rsmd.getColumnCount
val fields = new Array[StructField](ncols)
var i = 0
while (i < ncols) {
val columnName = rsmd.getColumnLabel(i + 1)
val dataType = rsmd.getColumnType(i + 1)
val typeName = rsmd.getColumnTypeName(i + 1)
val fieldSize = rsmd.getPrecision(i + 1)
val fieldScale = rsmd.getScale(i + 1)
val isSigned = rsmd.isSigned(i + 1)
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val columnType = getCatalystType(dataType, fieldSize, fieldScale, isSigned)
fields(i) = StructField(columnName, columnType, nullable)
i = i + 1
}
new StructType(fields)
} finally {
conn.close()
rs.close()
}

throw new RuntimeException("This line is unreachable.")
}

/**
* Given a driver string and a JDBC url, load the specified driver and return a DB connection.
*
* @param driver the class name of the JDBC driver for the given url.
* @param userProvidedDriverClass the class name of the JDBC driver for the given url. If this
* is None then `spark-redshift` will attempt to automatically
* discover the appropriate driver class.
* @param url the JDBC url to connect to.
*/
def getConnector(driver: String, url: String): Connection = {
try {
if (driver != null) registerDriver(driver)
} catch {
case e: ClassNotFoundException =>
log.warn(s"Couldn't find class $driver", e)
}
def getConnector(userProvidedDriverClass: Option[String], url: String): Connection = {
val subprotocol = new URI(url.stripPrefix("jdbc:")).getScheme
val driverClass: Class[Driver] = getDriverClass(subprotocol, userProvidedDriverClass)
registerDriver(driverClass.getCanonicalName)
DriverManager.getConnection(url, new Properties())
}
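
For reference, a small standalone sketch of the subprotocol extraction that getConnector performs above (the URL here is made up for illustration):

    import java.net.URI

    // Strip the "jdbc:" prefix so that java.net.URI can parse the remainder;
    // the scheme of that remainder is the JDBC subprotocol ("redshift" here).
    val url = "jdbc:redshift://examplecluster.abc123.us-west-2.redshift.amazonaws.com:5439/dev"
    val subprotocol = new URI(url.stripPrefix("jdbc:")).getScheme
    assert(subprotocol == "redshift")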

@@ -49,13 +49,14 @@ private[redshift] case class RedshiftRelation(
new URI(params.rootTempDir), sqlContext.sparkContext.hadoopConfiguration)
}

override def schema: StructType = {
userSchema match {
case Some(schema) => schema
case None => {
jdbcWrapper.registerDriver(params.jdbcDriver)
val tableNameOrSubquery = params.query.map(q => s"($q)").orElse(params.table).get
jdbcWrapper.resolveTable(params.jdbcUrl, tableNameOrSubquery)
override lazy val schema: StructType = {
userSchema.getOrElse {
val tableNameOrSubquery = params.query.map(q => s"($q)").orElse(params.table).get
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
try {
jdbcWrapper.resolveTable(conn, tableNameOrSubquery)
} finally {
conn.close()
}
}
}
8 changes: 8 additions & 0 deletions src/main/scala/com/databricks/spark/redshift/Utils.scala
@@ -37,6 +37,14 @@ private[redshift] object Utils {

private val log = LoggerFactory.getLogger(getClass)

def classForName(className: String): Class[_] = {
val classLoader =
Option(Thread.currentThread().getContextClassLoader).getOrElse(this.getClass.getClassLoader)
// scalastyle:off
Class.forName(className, true, classLoader)
// scalastyle:on
}

/**
* Joins prefix URL a to path suffix b, and appends a trailing /, in order to create
* a temp directory path for S3.
@@ -63,17 +63,18 @@ class MockRedshift(
conn
}

when(jdbcWrapper.getConnector(anyString(), same(jdbcUrl))).thenAnswer(new Answer[Connection] {
override def answer(invocation: InvocationOnMock): Connection = createMockConnection()
})
when(jdbcWrapper.getConnector(any[Option[String]](), same(jdbcUrl))).thenAnswer(
new Answer[Connection] {
override def answer(invocation: InvocationOnMock): Connection = createMockConnection()
})

when(jdbcWrapper.tableExists(any[Connection], anyString())).thenAnswer(new Answer[Boolean] {
override def answer(invocation: InvocationOnMock): Boolean = {
existingTablesAndSchemas.contains(invocation.getArguments()(1).asInstanceOf[String])
}
})

when(jdbcWrapper.resolveTable(same(jdbcUrl), anyString())).thenAnswer(new Answer[StructType] {
when(jdbcWrapper.resolveTable(any[Connection], anyString())).thenAnswer(new Answer[StructType] {
override def answer(invocation: InvocationOnMock): StructType = {
existingTablesAndSchemas(invocation.getArguments()(1).asInstanceOf[String])
}