Skip to content

Commit

Permalink
Recovering lost changes
Browse files Browse the repository at this point in the history
  • Loading branch information
heathermiller committed Sep 12, 2014
1 parent 6e1eaf3 commit 870b270
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
/**
* A collection of Scala macros for working with SQL in a type-safe way.
*/
private[sql] object SQLMacros {
object SQLMacros {
import scala.reflect.macros._

var currentContext: SQLContext = _

def sqlImpl(c: Context)(args: c.Expr[Any]*) =
new Macros[c.type](c).sql(args)

Expand Down Expand Up @@ -68,10 +70,29 @@ private[sql] object SQLMacros {

case class RecSchema(name: String, index: Int, cType: DataType, tpe: Type)

def getSchema(sqlQuery: String, interpolatedArguments: Seq[InterpolatedItem]) = {
if (currentContext == null) {
val parser = new SqlParser()
val logicalPlan = parser(sqlQuery)
val catalog = new SimpleCatalog(true)
val functionRegistry = new SimpleFunctionRegistry
val analyzer = new Analyzer(catalog, functionRegistry, true)

interpolatedArguments.foreach(_.localRegister(catalog, functionRegistry))
val analyzedPlan = analyzer(logicalPlan)

analyzedPlan.output.map(attr => (attr.name, attr.dataType))
} else {
interpolatedArguments.foreach(
_.localRegister(currentContext.catalog, currentContext.functionRegistry))
currentContext.sql(sqlQuery).schema.fields.map(attr => (attr.name, attr.dataType))
}
}

def sql(args: Seq[c.Expr[Any]]) = {

val q"""
org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
$path.SQLInterpolation(
scala.StringContext.apply(..$rawParts))""" = c.prefix.tree

//rawParts.map(_.toString).foreach(println)
Expand All @@ -96,16 +117,7 @@ private[sql] object SQLMacros {
interpolatedArguments(i).placeholderName + parts(i + 1)
}.mkString("")

val parser = new SqlParser()
val logicalPlan = parser(query)
val catalog = new SimpleCatalog(true)
val functionRegistry = new SimpleFunctionRegistry
val analyzer = new Analyzer(catalog, functionRegistry, true)

interpolatedArguments.foreach(_.localRegister(catalog, functionRegistry))
val analyzedPlan = analyzer(logicalPlan)

val fields = analyzedPlan.output.map(attr => (attr.name, attr.dataType))
val fields = getSchema(query, interpolatedArguments)
val record = genRecord(q"row", fields)

val tree = q"""
Expand Down

0 comments on commit 870b270

Please sign in to comment.