[SPARK-5702][SQL] Allow short names for built-in data sources.
rxin committed Feb 11, 2015
1 parent 6195e24 commit 74f42e3
Showing 4 changed files with 90 additions and 48 deletions.
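
The gist of the change: data source providers no longer have to be spelled as fully qualified class names. The short names jdbc, json, and parquet now resolve directly to the built-in DefaultSource classes, with reflective class loading kept as the fallback. A minimal sketch of the resolution behavior this commit introduces (results shown as comments):

    // Short names map straight to the built-in implementations:
    ResolvedDataSource.lookupDataSource("json")
    // => classOf[org.apache.spark.sql.json.DefaultSource]

    // Fully qualified names still work via the ".DefaultSource" fallback:
    ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json")
    // => org.apache.spark.sql.json.DefaultSource, after trying the bare name first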
@@ -48,11 +48,6 @@ private[sql] object JDBCRelation {
* exactly once. The parameters minValue and maxValue are advisory in that
* incorrect values may cause the partitioning to be poor, but no data
* will fail to be represented.
- *
- * @param column - Column name. Must refer to a column of integral type.
- * @param numPartitions - Number of partitions
- * @param minValue - Smallest value of column. Advisory.
- * @param maxValue - Largest value of column. Advisory.
*/
def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
if (partitioning == null) return Array[Partition](JDBCPartition(null, 0))
@@ -68,12 +63,17 @@ private[sql] object JDBCRelation {
var currentValue: Long = partitioning.lowerBound
var ans = new ArrayBuffer[Partition]()
while (i < numPartitions) {
-      val lowerBound = (if (i != 0) s"$column >= $currentValue" else null)
+      val lowerBound = if (i != 0) s"$column >= $currentValue" else null
      currentValue += stride
-      val upperBound = (if (i != numPartitions - 1) s"$column < $currentValue" else null)
-      val whereClause = (if (upperBound == null) lowerBound
-        else if (lowerBound == null) upperBound
-        else s"$lowerBound AND $upperBound")
+      val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
+      val whereClause =
+        if (upperBound == null) {
+          lowerBound
+        } else if (lowerBound == null) {
+          upperBound
+        } else {
+          s"$lowerBound AND $upperBound"
+        }
ans += JDBCPartition(whereClause, i)
i = i + 1
}
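
For intuition, the loop above turns the partitioning bounds into one WHERE clause per partition, leaving the first slice open below and the last slice open above. A worked sketch with hypothetical inputs (column "id", bounds 0 to 100, 4 partitions, so a stride of 25):

    // partition 0: "id < 25"               (i == 0, so no lower bound)
    // partition 1: "id >= 25 AND id < 50"
    // partition 2: "id >= 50 AND id < 75"
    // partition 3: "id >= 75"              (last partition, so no upper bound)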
@@ -96,8 +96,7 @@ private[sql] class DefaultSource extends RelationProvider {

if (driver != null) Class.forName(driver)

-    if (
-      partitionColumn != null
+    if (partitionColumn != null
&& (lowerBound == null || upperBound == null || numPartitions == null)) {
sys.error("Partitioning incompletely specified")
}
@@ -119,7 +118,8 @@ private[sql] class DefaultSource extends RelationProvider {
private[sql] case class JDBCRelation(
url: String,
table: String,
-    parts: Array[Partition])(@transient val sqlContext: SQLContext) extends PrunedFilteredScan {
+    parts: Array[Partition])(@transient val sqlContext: SQLContext)
+  extends PrunedFilteredScan {

override val schema = JDBCRDD.resolveTable(url, table)

@@ -20,6 +20,7 @@ package org.apache.spark.sql.json
import java.io.IOException

import org.apache.hadoop.fs.Path
+
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
77 changes: 42 additions & 35 deletions sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -234,65 +234,73 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
primitiveType
}

-object ResolvedDataSource {
-  def apply(
-      sqlContext: SQLContext,
-      userSpecifiedSchema: Option[StructType],
-      provider: String,
-      options: Map[String, String]): ResolvedDataSource = {
+private[sql] object ResolvedDataSource {
+
+  private val builtinSources = Map(
+    "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource],
+    "json" -> classOf[org.apache.spark.sql.json.DefaultSource],
+    "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource]
+  )

+  /** Given a provider name, look up the data source class definition. */
+  def lookupDataSource(provider: String): Class[_] = {
+    if (builtinSources.contains(provider)) {
+      return builtinSources(provider)
+    }

val loader = Utils.getContextOrSparkClassLoader
-    val clazz: Class[_] = try loader.loadClass(provider) catch {
+    try {
+      loader.loadClass(provider)
+    } catch {
      case cnf: java.lang.ClassNotFoundException =>
-        try loader.loadClass(provider + ".DefaultSource") catch {
+        try {
+          loader.loadClass(provider + ".DefaultSource")
+        } catch {
case cnf: java.lang.ClassNotFoundException =>
sys.error(s"Failed to load class for data source: $provider")
}
}
}

+  /** Create a [[ResolvedDataSource]] for reading data in. */
+  def apply(
+      sqlContext: SQLContext,
+      userSpecifiedSchema: Option[StructType],
+      provider: String,
+      options: Map[String, String]): ResolvedDataSource = {
+    val clazz: Class[_] = lookupDataSource(provider)
val relation = userSpecifiedSchema match {
-      case Some(schema: StructType) => {
-        clazz.newInstance match {
-          case dataSource: SchemaRelationProvider =>
-            dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
-          case dataSource: org.apache.spark.sql.sources.RelationProvider =>
-            sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
-        }
+      case Some(schema: StructType) => clazz.newInstance() match {
+        case dataSource: SchemaRelationProvider =>
+          dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+        case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+          sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
      }
-      case None => {
-        clazz.newInstance match {
-          case dataSource: RelationProvider =>
-            dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
-          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
-            sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
-        }
+
+      case None => clazz.newInstance() match {
+        case dataSource: RelationProvider =>
+          dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+        case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+          sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
+      }
    }

new ResolvedDataSource(clazz, relation)
}

+  /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */
def apply(
sqlContext: SQLContext,
provider: String,
mode: SaveMode,
options: Map[String, String],
data: DataFrame): ResolvedDataSource = {
-    val loader = Utils.getContextOrSparkClassLoader
-    val clazz: Class[_] = try loader.loadClass(provider) catch {
-      case cnf: java.lang.ClassNotFoundException =>
-        try loader.loadClass(provider + ".DefaultSource") catch {
-          case cnf: java.lang.ClassNotFoundException =>
-            sys.error(s"Failed to load class for data source: $provider")
-        }
-    }
-
-    val relation = clazz.newInstance match {
+    val clazz: Class[_] = lookupDataSource(provider)
+    val relation = clazz.newInstance() match {
case dataSource: CreatableRelationProvider =>
dataSource.createRelation(sqlContext, mode, options, data)
case _ =>
sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
}

new ResolvedDataSource(clazz, relation)
}
}
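
Worth noting: the built-in map shadows only the three short names, so third-party sources are unaffected. Any other provider string is loaded reflectively, first as a class name and then with a .DefaultSource suffix. A hedged sketch with a hypothetical external package:

    // Hypothetical package "com.example.avro" shipping com.example.avro.DefaultSource:
    ResolvedDataSource.lookupDataSource("com.example.avro")
    // tries loadClass("com.example.avro") first, then
    // loadClass("com.example.avro.DefaultSource"); fails with
    // "Failed to load class for data source: com.example.avro" if neither exists.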
@@ -405,6 +413,5 @@ protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[St

/**
* The exception thrown from the DDL parser.
- * @param message
*/
protected[sql] class DDLException(message: String) extends Exception(message)
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.sources

import org.scalatest.FunSuite

class ResolvedDataSourceSuite extends FunSuite {

test("builtin sources") {
assert(ResolvedDataSource.lookupDataSource("jdbc") ===
classOf[org.apache.spark.sql.jdbc.DefaultSource])

assert(ResolvedDataSource.lookupDataSource("json") ===
classOf[org.apache.spark.sql.json.DefaultSource])

assert(ResolvedDataSource.lookupDataSource("parquet") ===
classOf[org.apache.spark.sql.parquet.DefaultSource])
}
}
