Refactor scala reflection #7

Merged 3 commits on Jul 24, 2014
Changes from all commits
project/SparkBuild.scala: 2 changes (1 addition, 1 deletion)
@@ -499,7 +499,7 @@ object SparkBuild extends Build {
parallelExecution in Test := false,
libraryDependencies ++= Seq(
"com.typesafe" %% "scalalogging-slf4j" % "1.0.1",
"ch.epfl.lamp" %% "scala-records" % "0.2-SNAPSHOT"
"ch.epfl.lamp" %% "scala-records" % "0.1"
)
)

@@ -27,8 +27,16 @@ import org.apache.spark.sql.catalyst.types._
/**
* Provides experimental support for generating catalyst schemas for scala objects.
*/
object ScalaReflection {
import scala.reflect.runtime.universe._
object ScalaReflection extends ScalaReflection {
val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
}

trait ScalaReflection {

/** The universe we work in (runtime or macro) */
val universe: scala.reflect.api.Universe

import universe._

case class Schema(dataType: DataType, nullable: Boolean)

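The hunk above (presumably from Catalyst's ScalaReflection.scala; the file header is not captured in this view) turns the ScalaReflection object into a trait whose reflection universe is left abstract, so the same schema-inference code can run against runtime reflection or a macro's compile-time universe, which the Macros class below pins to c.universe. A minimal sketch of the pattern, assuming only the universe wiring matters here (ReflectionSketch, RuntimeReflection, and primitiveName are illustrative names, not part of the PR):

import scala.reflect.api.Universe

// Sketch of the cake-style refactor: code is written against an abstract universe.
trait ReflectionSketch {
  val universe: Universe
  import universe._

  // Universe-agnostic helper in the spirit of schemaFor: map a few Scala
  // primitive types to Catalyst-style type names.
  def primitiveName(tpe: Type): Option[String] =
    if (tpe <:< definitions.IntTpe) Some("IntegerType")
    else if (tpe <:< definitions.LongTpe) Some("LongType")
    else if (tpe <:< definitions.BooleanTpe) Some("BooleanType")
    else None
}

// Runtime instantiation, mirroring `object ScalaReflection extends ScalaReflection`.
object RuntimeReflection extends ReflectionSketch {
  val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
}

// Macro-time instantiation mirrors the Macros class in TypedSql.scala below:
//   class Macros[C <: Context](val c: C) extends ScalaReflection {
//     val universe: c.universe.type = c.universe
//   }

For example, RuntimeReflection.primitiveName(scala.reflect.runtime.universe.typeOf[Long]) yields Some("LongType"), while a macro-side instantiation runs the same logic over compile-time types.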
sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala: 197 changes (87 additions, 110 deletions)
@@ -6,6 +6,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._

import scala.language.experimental.macros
import scala.language.existentials

import records._
import Macros.RecordMacros
@@ -17,143 +18,119 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
object SQLMacros {
import scala.reflect.macros._

def sqlImpl(c: Context)(args: c.Expr[Any]*) =
new Macros[c.type](c).sql(args)

case class Schema(dataType: DataType, nullable: Boolean)

def sqlImpl(c: Context)(args: c.Expr[Any]*) = {
class Macros[C <: Context](val c: C) extends ScalaReflection {
val universe: c.universe.type = c.universe

import c.universe._

// TODO: Don't copy this function from ScalaReflection.
def schemaFor(tpe: `Type`): Schema = tpe match {
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
Schema(ArrayType(schemaFor(elementType).dataType), nullable = true)
case t if t <:< typeOf[Map[_,_]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
}
val rowTpe = tq"_root_.org.apache.spark.sql.catalyst.expressions.Row"

val q"""
org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
scala.StringContext.apply(..$rawParts))""" = c.prefix.tree
val rMacros = new RecordMacros[c.type](c)

val parts = rawParts.map(_.toString.stripPrefix("\"").stripSuffix("\""))
val query = parts(0) + (0 until args.size).map { i =>
s"table$i" + parts(i + 1)
}.mkString("")
case class RecSchema(name: String, index: Int,
cType: DataType, tpe: Type)

val parser = new SqlParser()
val logicalPlan = parser(query)
val catalog = new SimpleCatalog
val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, false)
def sql(args: Seq[c.Expr[Any]]) = {

val tables = args.zipWithIndex.map { case (arg, i) =>
val TypeRef(_, _, Seq(schemaType)) = arg.actualType
val q"""
org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
scala.StringContext.apply(..$rawParts))""" = c.prefix.tree

val inputSchema = schemaFor(schemaType).dataType.asInstanceOf[StructType].toAttributes
(s"table$i", LocalRelation(inputSchema:_*))
}
val parts = rawParts.map(_.toString.stripPrefix("\"").stripSuffix("\""))
val query = parts(0) + args.indices.map { i => s"table$i" + parts(i + 1) }.mkString("")

tables.foreach(t => catalog.registerTable(None, t._1, t._2))

val analyzedPlan = analyzer(logicalPlan)

// TODO: This shouldn't probably be here but somewhere generic
// which defines the catalyst <-> Scala type mapping
def toScalaType(dt: DataType) = dt match {
case IntegerType => definitions.IntTpe
case LongType => definitions.LongTpe
case ShortType => definitions.ShortTpe
case ByteType => definitions.ByteTpe
case DoubleType => definitions.DoubleTpe
case FloatType => definitions.FloatTpe
case BooleanType => definitions.BooleanTpe
case StringType => definitions.StringClass.toType
}
val analyzedPlan = analyzeQuery(query, args.map(_.actualType))

// TODO: Move this to a macro implementation class (we need it
// locally for `Type` which is on c.universe)
case class RecSchema(name: String, index: Int,
cType: DataType, tpe: Type)
val fields = analyzedPlan.output.map(attr => (attr.name, attr.dataType))
val record = genRecord(q"row", fields)

val tree = q"""
..${args.zipWithIndex.map{ case (r,i) => q"""$r.registerAsTable(${s"table$i"})""" }}
val result = sql($query)
result.map(row => $record)
"""

val fullSchema = analyzedPlan.output.zipWithIndex.map { case (attr, i) =>
RecSchema(attr.name, i, attr.dataType, toScalaType(attr.dataType))
println(tree)

c.Expr(tree)
}

val schema = fullSchema.map(s => (s.name, s.tpe))
// TODO: Handle nullable fields
def genRecord(row: Tree, fields: Seq[(String, DataType)]) = {
case class ImplSchema(name: String, tpe: Type, impl: Tree)

val rMacros = new RecordMacros[c.type](c)
val implSchemas = for {
((name, dataType),i) <- fields.zipWithIndex
} yield {
val tpe = c.typeCheck(genGetField(q"null: $rowTpe", i, dataType)).tpe
val tree = genGetField(row, i, dataType)

val (spFlds, objFields) = fullSchema.partition(s =>
rMacros.specializedTypes.contains(s.tpe))
ImplSchema(name, tpe, tree)
}

val spFldsByType = {
val grouped = spFlds.groupBy(_.tpe)
grouped.mapValues { _.map(s => s.name -> s).toMap }
}
val schema = implSchemas.map(f => (f.name, f.tpe))

val (spFlds, objFields) = implSchemas.partition(s =>
rMacros.specializedTypes.contains(s.tpe))

val spImplsByTpe = {
val grouped = spFlds.groupBy(_.tpe)
grouped.mapValues { _.map(s => s.name -> s.impl).toMap }
}

def methodName(t: DataType) = newTermName("get" + primitiveForType(t))
val dataObjImpl = {
val impls = objFields.map(s => s.name -> s.impl).toMap
val lookupTree = rMacros.genLookup(q"fieldName", impls, mayCache = false)
q"($lookupTree).asInstanceOf[T]"
}

val dataObjImpl = {
val fldTrees = objFields.map(s =>
s.name -> q"row.${methodName(s.cType)}(${s.index})"
).toMap
val lookupTree = rMacros.genLookup(q"fieldName", fldTrees, mayCache = false)
q"($lookupTree).asInstanceOf[T]"
rMacros.specializedRecord(schema)(tq"Serializable")()(dataObjImpl) {
case tpe if spImplsByTpe.contains(tpe) =>
rMacros.genLookup(q"fieldName", spImplsByTpe(tpe), mayCache = false)
}
}

val record = rMacros.specializedRecord(schema)(tq"Serializable")()(dataObjImpl) {
case tpe if spFldsByType.contains(tpe) =>
val fldTrees = spFldsByType(tpe).mapValues(s =>
q"row.${methodName(s.cType)}(${s.index})")
rMacros.genLookup(q"fieldName", fldTrees, mayCache = false)
/** Generate a tree that retrieves a given field for a given type.
* Constructs a nested record if necessary
*/
def genGetField(row: Tree, index: Int, t: DataType): Tree = t match {
case t: PrimitiveType =>
val methodName = newTermName("get" + primitiveForType(t))
q"$row.$methodName($index)"
case StructType(structFields) =>
val fields = structFields.map(f => (f.name, f.dataType))
genRecord(q"$row($index).asInstanceOf[$rowTpe]", fields)
case _ =>
c.abort(NoPosition, s"Query returns currently unhandled field type: $t")
}

val tree = q"""
..${args.zipWithIndex.map{ case (r,i) => q"""$r.registerAsTable(${s"table$i"})""" }}
val result = sql($query)
result.map(row => $record)
"""
def analyzeQuery(query: String, tableTypes: Seq[Type]) = {
val parser = new SqlParser()
val logicalPlan = parser(query)
val catalog = new SimpleCatalog
val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, false)

println(tree)
val tables = tableTypes.zipWithIndex.map { case (tblTpe, i) =>
val TypeRef(_, _, Seq(schemaType)) = tblTpe

val inputSchema = schemaFor(schemaType).dataType.asInstanceOf[StructType].toAttributes
(s"table$i", LocalRelation(inputSchema:_*))
}

tables.foreach(t => catalog.registerTable(None, t._1, t._2))

analyzer(logicalPlan)
}

c.Expr(tree)
}

// TODO: Duplicated from codegen PR...
protected def primitiveForType(dt: DataType) = dt match {
protected def primitiveForType(dt: PrimitiveType) = dt match {
case IntegerType => "Int"
case LongType => "Long"
case ShortType => "Short"
@@ -173,4 +150,4 @@ trait TypedSQL {
// TODO: Handle functions...
def sql(args: Any*): Any = macro SQLMacros.sqlImpl
}
}
}
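The sql interpolator macro above works in three steps: it rebuilds the query text by splicing a synthetic table name (table0, table1, ...) where each argument was interpolated, registers each argument under that name and runs the Catalyst analyzer to learn the output schema (analyzeQuery), and finally emits a tree that registers the tables, runs the query, and maps every Row to a specialized record (genRecord/genGetField). A standalone sketch of just the query-assembly step, using nothing beyond plain strings (QueryAssemblySketch and buildQuery are illustrative names, not part of the PR):

// Hypothetical standalone sketch of the text-assembly step, mirroring
//   parts(0) + args.indices.map(i => s"table$i" + parts(i + 1)).mkString("")
object QueryAssemblySketch {
  def buildQuery(parts: Seq[String], numArgs: Int): String =
    parts.head + (0 until numArgs).map(i => s"table$i" + parts(i + 1)).mkString("")

  def main(args: Array[String]): Unit = {
    // sql"SELECT a.name FROM $people a JOIN $people b ON a.age = b.age"
    // reaches the macro as three StringContext parts and two arguments:
    val parts = Seq("SELECT a.name FROM ", " a JOIN ", " b ON a.age = b.age")
    println(buildQuery(parts, numArgs = 2))
    // prints: SELECT a.name FROM table0 a JOIN table1 b ON a.age = b.age
  }
}

The expanded call then behaves roughly like registering the interpolated argument as table0 and table1, running sql("SELECT a.name FROM table0 a JOIN table1 b ON a.age = b.age"), and mapping each result row through the generated record, which is what the join-query test below exercises.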
@@ -46,14 +46,14 @@ class TypedSqlSuite extends FunSuite {
assert(results.first().age == 30)
}

ignore("nested results") {
test("nested results") {
val results = sql"SELECT * FROM $cars"
assert(results.first().owner.name === "Michael")
assert(results.first().owner.name == "Michael")
}

test("join query") {
val results = sql"""SELECT a.name FROM $people a JOIN $people b ON a.age = b.age"""

assert(results.first().name == "Michael")
}
}
}