[SPARK-6124] Support jdbc connection properties in OPTIONS part of the query

One more thing: if this PR is considered to be OK, it might make sense to add extra .jdbc() APIs that take Properties to SQLContext.
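The overload suggested above is not part of this commit; the following is only a rough sketch of what a Properties-taking jdbc() could look like. The object and method names are hypothetical, and it assumes living inside the org.apache.spark.sql package, since JDBCRelation is private[sql].

package org.apache.spark.sql

import java.util.Properties

import org.apache.spark.sql.jdbc.JDBCRelation

// Hypothetical sketch only -- NOT part of this commit, just the follow-up the
// commit message suggests.
object JdbcWithPropertiesSketch {
  def jdbc(sqlContext: SQLContext, url: String, table: String,
      properties: Properties): DataFrame = {
    // Single whole-table partition, mirroring the existing no-partitioning jdbc().
    val parts = JDBCRelation.columnPartition(null)
    sqlContext.baseRelationToDataFrame(
      JDBCRelation(url, table, parts, properties)(sqlContext))
  }
}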

Author: Volodymyr Lyubinets <[email protected]>

Closes #4859 from vlyubin/jdbcProperties and squashes the following commits:

7a8cfda [Volodymyr Lyubinets] Support jdbc connection properties in OPTIONS part of the query
vlyubin authored and marmbrus committed Mar 24, 2015
1 parent 6cd7058 commit bfd3ee9
Showing 3 changed files with 59 additions and 29 deletions.
14 changes: 8 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.jdbc

import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException}
+ import java.util.Properties

import org.apache.commons.lang.StringEscapeUtils.escapeSql
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
@@ -90,9 +91,9 @@ private[sql] object JDBCRDD extends Logging {
* @throws SQLException if the table specification is garbage.
* @throws SQLException if the table contains an unsupported type.
*/
- def resolveTable(url: String, table: String): StructType = {
+ def resolveTable(url: String, table: String, properties: Properties): StructType = {
val quirks = DriverQuirks.get(url)
- val conn: Connection = DriverManager.getConnection(url)
+ val conn: Connection = DriverManager.getConnection(url, properties)
try {
val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
try {
@@ -147,7 +148,7 @@ private[sql] object JDBCRDD extends Logging {
*
* @return A function that loads the driver and connects to the url.
*/
- def getConnector(driver: String, url: String): () => Connection = {
+ def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
() => {
try {
if (driver != null) Class.forName(driver)
@@ -156,7 +157,7 @@
logWarning(s"Couldn't find class $driver", e);
}
}
- DriverManager.getConnection(url)
+ DriverManager.getConnection(url, properties)
}
}
/**
@@ -179,6 +180,7 @@ private[sql] object JDBCRDD extends Logging {
schema: StructType,
driver: String,
url: String,
+ properties: Properties,
fqTable: String,
requiredColumns: Array[String],
filters: Array[Filter],
@@ -189,7 +191,7 @@
return new
JDBCRDD(
sc,
- getConnector(driver, url),
+ getConnector(driver, url, properties),
prunedSchema,
fqTable,
requiredColumns,
@@ -361,7 +363,7 @@ private[sql] class JDBCRDD(
var ans = 0L
var j = 0
while (j < bytes.size) {
- ans = 256*ans + (255 & bytes(j))
+ ans = 256 * ans + (255 & bytes(j))
j = j + 1;
}
mutableRow.setLong(i, ans)
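A minimal sketch (not part of the diff) of how the new parameter is meant to be used: the Properties are handed unchanged to DriverManager.getConnection(url, properties) by resolveTable and getConnector. The URL, table and credentials below are the ones the test suite uses; JDBCRDD is private[sql], so this is only callable from within Spark's sql package.

// Sketch only: building the Properties that resolveTable and getConnector
// now forward to DriverManager.getConnection(url, props).
import java.util.Properties

import org.apache.spark.sql.jdbc.JDBCRDD

val props = new Properties()
props.setProperty("user", "testUser")
props.setProperty("password", "testPass")

val schema = JDBCRDD.resolveTable("jdbc:h2:mem:testdb0", "TEST.PEOPLE", props)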
19 changes: 12 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -17,16 +17,17 @@

package org.apache.spark.sql.jdbc

- import org.apache.spark.rdd.RDD
- import org.apache.spark.sql.catalyst.expressions.Row
- import org.apache.spark.sql.types.StructType
+ import java.sql.DriverManager
+ import java.util.Properties

import scala.collection.mutable.ArrayBuffer
- import java.sql.DriverManager

import org.apache.spark.Partition
+ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
+ import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.sources._
+ import org.apache.spark.sql.types.StructType

/**
* Data corresponding to one partition of a JDBCRDD.
@@ -115,18 +116,21 @@ private[sql] class DefaultSource extends RelationProvider {
numPartitions.toInt)
}
val parts = JDBCRelation.columnPartition(partitionInfo)
- JDBCRelation(url, table, parts)(sqlContext)
+ val properties = new Properties() // Additional properties that we will pass to getConnection
+ parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
+ JDBCRelation(url, table, parts, properties)(sqlContext)
}
}

private[sql] case class JDBCRelation(
url: String,
table: String,
- parts: Array[Partition])(@transient val sqlContext: SQLContext)
+ parts: Array[Partition],
+ properties: Properties = new Properties())(@transient val sqlContext: SQLContext)
extends BaseRelation
with PrunedFilteredScan {

- override val schema: StructType = JDBCRDD.resolveTable(url, table)
+ override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName
@@ -135,6 +139,7 @@ private[sql] case class JDBCRelation(
schema,
driver,
url,
+ properties,
table,
requiredColumns,
filters,
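Note that createRelation copies every OPTIONS entry, including url and dbtable themselves, into the Properties object, which the H2 driver used in the tests tolerates. A usage sketch mirroring the test suite, where `sqlContext` is assumed to be an existing SQLContext and the H2 table TEST.PEOPLE already exists:

// Extra OPTIONS keys (user, password, and driver-specific settings such as
// H2's rowId) now reach DriverManager.getConnection via the Properties.
sqlContext.sql(
  """
    |CREATE TEMPORARY TABLE people
    |USING org.apache.spark.sql.jdbc
    |OPTIONS (url 'jdbc:h2:mem:testdb0', dbtable 'TEST.PEOPLE',
    |         user 'testUser', password 'testPass')
  """.stripMargin.replaceAll("\n", " "))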
55 changes: 39 additions & 16 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -19,22 +19,31 @@ package org.apache.spark.sql.jdbc

import java.math.BigDecimal
import java.sql.DriverManager
- import java.util.{Calendar, GregorianCalendar}
+ import java.util.{Calendar, GregorianCalendar, Properties}

import org.apache.spark.sql.test._
+ import org.h2.jdbc.JdbcSQLException
import org.scalatest.{FunSuite, BeforeAndAfter}
import TestSQLContext._
import TestSQLContext.implicits._

class JDBCSuite extends FunSuite with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb0"
+ val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
var conn: java.sql.Connection = null

val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)

before {
Class.forName("org.h2.Driver")
- conn = DriverManager.getConnection(url)
+ // Extra properties that will be specified for our database. We need these to test
+ // usage of parameters from OPTIONS clause in queries.
+ val properties = new Properties()
+ properties.setProperty("user", "testUser")
+ properties.setProperty("password", "testPass")
+ properties.setProperty("rowId", "false")

+ conn = DriverManager.getConnection(url, properties)
conn.prepareStatement("create schema test").executeUpdate()
conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate()
@@ -46,15 +55,15 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
s"""
|CREATE TEMPORARY TABLE foobar
|USING org.apache.spark.sql.jdbc
- |OPTIONS (url '$url', dbtable 'TEST.PEOPLE')
+ |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))

sql(
s"""
|CREATE TEMPORARY TABLE parts
|USING org.apache.spark.sql.jdbc
- |OPTIONS (url '$url', dbtable 'TEST.PEOPLE',
- |partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3')
+ |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass',
+ | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3')
""".stripMargin.replaceAll("\n", " "))

conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, "
@@ -68,12 +77,12 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
s"""
|CREATE TEMPORARY TABLE inttypes
|USING org.apache.spark.sql.jdbc
- |OPTIONS (url '$url', dbtable 'TEST.INTTYPES')
+ |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))

conn.prepareStatement("create table test.strtypes (a BINARY(20), b VARCHAR(20), "
+ "c VARCHAR_IGNORECASE(20), d CHAR(20), e BLOB, f CLOB)").executeUpdate()
- var stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)")
+ val stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)")
stmt.setBytes(1, testBytes)
stmt.setString(2, "Sensitive")
stmt.setString(3, "Insensitive")
@@ -85,7 +94,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
s"""
|CREATE TEMPORARY TABLE strtypes
|USING org.apache.spark.sql.jdbc
- |OPTIONS (url '$url', dbtable 'TEST.STRTYPES')
+ |OPTIONS (url '$url', dbtable 'TEST.STRTYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))

conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)"
@@ -97,7 +106,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
s"""
|CREATE TEMPORARY TABLE timetypes
|USING org.apache.spark.sql.jdbc
- |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES')
+ |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))


@@ -112,7 +121,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
s"""
|CREATE TEMPORARY TABLE flttypes
|USING org.apache.spark.sql.jdbc
- |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES')
+ |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))

// Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
@@ -174,16 +183,17 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
}

test("Basic API") {
assert(TestSQLContext.jdbc(url, "TEST.PEOPLE").collect.size == 3)
assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect.size == 3)
}

test("Partitioning via JDBCPartitioningInfo API") {
assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", "THEID", 0, 4, 3).collect.size == 3)
assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3)
.collect.size == 3)
}

test("Partitioning via list-of-where-clauses API") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", parts).collect.size == 3)
assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect.size == 3)
}

test("H2 integral types") {
@@ -216,7 +226,6 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
assert(rows(0).getString(5).equals("I am a clob!"))
}


test("H2 time types") {
val rows = sql("SELECT * FROM timetypes").collect()
val cal = new GregorianCalendar(java.util.Locale.ROOT)
@@ -246,17 +255,31 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
.equals(new BigDecimal("123456789012345.54321543215432100000")))
}


test("SQL query as table name") {
sql(
s"""
|CREATE TEMPORARY TABLE hack
|USING org.apache.spark.sql.jdbc
- |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)')
+ |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)',
+ | user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
val rows = sql("SELECT * FROM hack").collect()
assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==.
// For some reason, H2 computes this square incorrectly...
assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12)
}

test("Pass extra properties via OPTIONS") {
// We set rowId to false during setup, which means that _ROWID_ column should be absent from
// all tables. If rowId is true (default), the query below doesn't throw an exception.
intercept[JdbcSQLException] {
sql(
s"""
|CREATE TEMPORARY TABLE abc
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable '(SELECT _ROWID_ FROM test.people)',
| user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
}
}
}
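The new test relies on the rowId 'false' property reaching H2, so selecting _ROWID_ fails. Until the Properties-taking overloads suggested in the commit message exist, the programmatic API still needs credentials embedded in the URL, as the "Basic API" test above does; a minimal sketch under that assumption:

// Sketch based on the "Basic API" test: the existing .jdbc() API has no
// Properties parameter yet, so credentials are embedded in the URL.
import org.apache.spark.sql.test.TestSQLContext

val people = TestSQLContext.jdbc(
  "jdbc:h2:mem:testdb0;user=testUser;password=testPass", "TEST.PEOPLE")
assert(people.collect().size == 3) // the fixture inserts three rows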
