Skip to content

Commit

Permalink
[SPARK-16387][SQL] JDBC Writer should use dialect to quote field names.
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Currently, JDBC Writer uses dialects to get datatypes, but doesn't to quote field names. This PR uses dialects to quote the field names, too.

**Reported Error Scenario (MySQL case)**
```scala
scala> val url="jdbc:mysql://localhost:3306/temp"
scala> val prop = new java.util.Properties
scala> prop.setProperty("user","root")
scala> spark.createDataset(Seq("a","b","c")).toDF("order")
scala> df.write.mode("overwrite").jdbc(url, "temptable", prop)
...MySQLSyntaxErrorException: ... near 'order TEXT )
```

## How was this patch tested?

Pass the Jenkins tests and manually do the above case.

Author: Dongjoon Hyun <[email protected]>

Closes apache#14107 from dongjoon-hyun/SPARK-16387.
  • Loading branch information
dongjoon-hyun authored and rxin committed Jul 8, 2016
1 parent 60ba436 commit 3b22291
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ object JdbcUtils extends Logging {
/**
* Returns a PreparedStatement that inserts a row into table via conn.
*/
def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
val columns = rddSchema.fields.map(_.name).mkString(",")
def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
: PreparedStatement = {
val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
conn.prepareStatement(sql)
Expand Down Expand Up @@ -177,7 +178,7 @@ object JdbcUtils extends Logging {
if (supportsTransactions) {
conn.setAutoCommit(false) // Everything in the same db transaction.
}
val stmt = insertStatement(conn, table, rddSchema)
val stmt = insertStatement(conn, table, rddSchema, dialect)
try {
var rowCount = 0
while (iterator.hasNext) {
Expand Down Expand Up @@ -260,7 +261,7 @@ object JdbcUtils extends Logging {
val sb = new StringBuilder()
val dialect = JdbcDialects.get(url)
df.schema.fields foreach { field =>
val name = field.name
val name = dialect.quoteIdentifier(field.name)
val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -764,4 +764,10 @@ class JDBCSuite extends SparkFunSuite
assertEmptyQuery(s"SELECT * FROM tempFrame where $FALSE2")
}
}

test("SPARK-16387: Reserved SQL words are not escaped by JDBC writer") {
val df = spark.createDataset(Seq("a", "b", "c")).toDF("order")
val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp")
assert(schema.contains("`order` TEXT"))
}
}

0 comments on commit 3b22291

Please sign in to comment.