From 870b270d6d31efa200af34870873f85130c2ae1d Mon Sep 17 00:00:00 2001 From: Heather Miller Date: Thu, 11 Sep 2014 22:54:58 -0700 Subject: [PATCH] Recovering lost changes --- .../scala/org/apache/spark/sql/TypedSql.scala | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala b/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala index e4d9a8e7f2385..1379c5d30fc7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala @@ -18,9 +18,11 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} /** * A collection of Scala macros for working with SQL in a type-safe way. */ -private[sql] object SQLMacros { +object SQLMacros { import scala.reflect.macros._ + var currentContext: SQLContext = _ + def sqlImpl(c: Context)(args: c.Expr[Any]*) = new Macros[c.type](c).sql(args) @@ -68,10 +70,29 @@ private[sql] object SQLMacros { case class RecSchema(name: String, index: Int, cType: DataType, tpe: Type) + def getSchema(sqlQuery: String, interpolatedArguments: Seq[InterpolatedItem]) = { + if (currentContext == null) { + val parser = new SqlParser() + val logicalPlan = parser(sqlQuery) + val catalog = new SimpleCatalog(true) + val functionRegistry = new SimpleFunctionRegistry + val analyzer = new Analyzer(catalog, functionRegistry, true) + + interpolatedArguments.foreach(_.localRegister(catalog, functionRegistry)) + val analyzedPlan = analyzer(logicalPlan) + + analyzedPlan.output.map(attr => (attr.name, attr.dataType)) + } else { + interpolatedArguments.foreach( + _.localRegister(currentContext.catalog, currentContext.functionRegistry)) + currentContext.sql(sqlQuery).schema.fields.map(attr => (attr.name, attr.dataType)) + } + } + def sql(args: Seq[c.Expr[Any]]) = { val q""" - org.apache.spark.sql.test.TestSQLContext.SqlInterpolator( + $path.SQLInterpolation( scala.StringContext.apply(..$rawParts))""" = c.prefix.tree //rawParts.map(_.toString).foreach(println) @@ -96,16 +117,7 @@ private[sql] object SQLMacros { interpolatedArguments(i).placeholderName + parts(i + 1) }.mkString("") - val parser = new SqlParser() - val logicalPlan = parser(query) - val catalog = new SimpleCatalog(true) - val functionRegistry = new SimpleFunctionRegistry - val analyzer = new Analyzer(catalog, functionRegistry, true) - - interpolatedArguments.foreach(_.localRegister(catalog, functionRegistry)) - val analyzedPlan = analyzer(logicalPlan) - - val fields = analyzedPlan.output.map(attr => (attr.name, attr.dataType)) + val fields = getSchema(query, interpolatedArguments) val record = genRecord(q"row", fields) val tree = q"""