From 83dd0928db6d1109a9290dd14e7208c90ee75c60 Mon Sep 17 00:00:00 2001
From: Tobias Schlatter
Date: Tue, 22 Jul 2014 14:19:59 +0200
Subject: [PATCH 1/3] Fix records version to 0.1

---
 project/SparkBuild.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 9ed0250eee8a9..2c82b73943570 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -499,7 +499,7 @@ object SparkBuild extends Build {
     parallelExecution in Test := false,
     libraryDependencies ++= Seq(
       "com.typesafe" %% "scalalogging-slf4j" % "1.0.1",
-      "ch.epfl.lamp" %% "scala-records" % "0.2-SNAPSHOT"
+      "ch.epfl.lamp" %% "scala-records" % "0.1"
     )
   )
 

From ae5ecaf56fe2dab90327635dcc58e59ab236bb4d Mon Sep 17 00:00:00 2001
From: Tobias Schlatter
Date: Tue, 22 Jul 2014 14:20:53 +0200
Subject: [PATCH 2/3] Handle nested fields

---
 .../scala/org/apache/spark/sql/TypedSql.scala | 194 ++++++++++--------
 .../org/apache/spark/sql/TypedSqlSuite.scala  |   6 +-
 2 files changed, 109 insertions(+), 91 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala b/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
index 119a786e2e6fa..6cc2d882b7dac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
@@ -6,6 +6,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.types._
 
 import scala.language.experimental.macros
+import scala.language.existentials
 
 import records._
 import Macros.RecordMacros
@@ -17,11 +18,113 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 object SQLMacros {
   import scala.reflect.macros._
 
+  def sqlImpl(c: Context)(args: c.Expr[Any]*) =
+    new Macros[c.type](c).sql(args)
+
   case class Schema(dataType: DataType, nullable: Boolean)
 
-  def sqlImpl(c: Context)(args: c.Expr[Any]*) = {
+  class Macros[C <: Context](val c: C) {
     import c.universe._
 
+    val rowTpe = tq"_root_.org.apache.spark.sql.catalyst.expressions.Row"
+
+    val rMacros = new RecordMacros[c.type](c)
+
+    case class RecSchema(name: String, index: Int,
+      cType: DataType, tpe: Type)
+
+    def sql(args: Seq[c.Expr[Any]]) = {
+
+      val q"""
+        org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
+          scala.StringContext.apply(..$rawParts))""" = c.prefix.tree
+
+      val parts = rawParts.map(_.toString.stripPrefix("\"").stripSuffix("\""))
+      val query = parts(0) + args.indices.map { i => s"table$i" + parts(i + 1) }.mkString("")
+
+      val analyzedPlan = analyzeQuery(query, args.map(_.actualType))
+
+      val fields = analyzedPlan.output.map(attr => (attr.name, attr.dataType))
+      val record = genRecord(q"row", fields)
+
+      val tree = q"""
+        ..${args.zipWithIndex.map{ case (r,i) => q"""$r.registerAsTable(${s"table$i"})""" }}
+        val result = sql($query)
+        result.map(row => $record)
+      """
+
+      println(tree)
+
+      c.Expr(tree)
+    }
+
+    // TODO: Handle nullable fields
+    def genRecord(row: Tree, fields: Seq[(String, DataType)]) = {
+      case class ImplSchema(name: String, tpe: Type, impl: Tree)
+
+      val implSchemas = for {
+        ((name, dataType),i) <- fields.zipWithIndex
+      } yield {
+        val tpe = c.typeCheck(genGetField(q"null: $rowTpe", i, dataType)).tpe
+        val tree = genGetField(row, i, dataType)
+
+        ImplSchema(name, tpe, tree)
+      }
+
+      val schema = implSchemas.map(f => (f.name, f.tpe))
+
+      val (spFlds, objFields) = implSchemas.partition(s =>
+        rMacros.specializedTypes.contains(s.tpe))
+
+      val spImplsByTpe = {
+        val grouped = spFlds.groupBy(_.tpe)
+        grouped.mapValues { _.map(s => s.name -> s.impl).toMap }
+      }
+
+      val dataObjImpl = {
+        val impls = objFields.map(s => s.name -> s.impl).toMap
+        val lookupTree = rMacros.genLookup(q"fieldName", impls, mayCache = false)
+        q"($lookupTree).asInstanceOf[T]"
+      }
+
+      rMacros.specializedRecord(schema)(tq"Serializable")()(dataObjImpl) {
+        case tpe if spImplsByTpe.contains(tpe) =>
+          rMacros.genLookup(q"fieldName", spImplsByTpe(tpe), mayCache = false)
+      }
+    }
+
+    /** Generate a tree that retrieves a given field for a given type.
+     * Constructs a nested record if necessary
+     */
+    def genGetField(row: Tree, index: Int, t: DataType): Tree = t match {
+      case t: PrimitiveType =>
+        val methodName = newTermName("get" + primitiveForType(t))
+        q"$row.$methodName($index)"
+      case StructType(structFields) =>
+        val fields = structFields.map(f => (f.name, f.dataType))
+        genRecord(q"$row($index).asInstanceOf[$rowTpe]", fields)
+      case _ =>
+        c.abort(NoPosition, s"Query returns currently unhandled field type: $t")
+    }
+
+    def analyzeQuery(query: String, tableTypes: Seq[Type]) = {
+      val parser = new SqlParser()
+      val logicalPlan = parser(query)
+      val catalog = new SimpleCatalog
+      val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, false)
+
+      val tables = tableTypes.zipWithIndex.map { case (tblTpe, i) =>
+        val TypeRef(_, _, Seq(schemaType)) = tblTpe
+
+        val inputSchema = schemaFor(schemaType).dataType.asInstanceOf[StructType].toAttributes
+        (s"table$i", LocalRelation(inputSchema:_*))
+      }
+
+      tables.foreach(t => catalog.registerTable(None, t._1, t._2))
+
+      analyzer(logicalPlan)
+    }
+
     // TODO: Don't copy this function from ScalaReflection.
     def schemaFor(tpe: `Type`): Schema = tpe match {
       case t if t <:< typeOf[Option[_]] =>
@@ -65,95 +168,10 @@ object SQLMacros {
       case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
     }
 
-    val q"""
-      org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
-        scala.StringContext.apply(..$rawParts))""" = c.prefix.tree
-
-    val parts = rawParts.map(_.toString.stripPrefix("\"").stripSuffix("\""))
-    val query = parts(0) + (0 until args.size).map { i =>
-      s"table$i" + parts(i + 1)
-    }.mkString("")
-
-    val parser = new SqlParser()
-    val logicalPlan = parser(query)
-    val catalog = new SimpleCatalog
-    val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, false)
-
-    val tables = args.zipWithIndex.map { case (arg, i) =>
-      val TypeRef(_, _, Seq(schemaType)) = arg.actualType
-
-      val inputSchema = schemaFor(schemaType).dataType.asInstanceOf[StructType].toAttributes
-      (s"table$i", LocalRelation(inputSchema:_*))
-    }
-
-    tables.foreach(t => catalog.registerTable(None, t._1, t._2))
-
-    val analyzedPlan = analyzer(logicalPlan)
-
-    // TODO: This shouldn't probably be here but somewhere generic
-    // which defines the catalyst <-> Scala type mapping
-    def toScalaType(dt: DataType) = dt match {
-      case IntegerType => definitions.IntTpe
-      case LongType => definitions.LongTpe
-      case ShortType => definitions.ShortTpe
-      case ByteType => definitions.ByteTpe
-      case DoubleType => definitions.DoubleTpe
-      case FloatType => definitions.FloatTpe
-      case BooleanType => definitions.BooleanTpe
-      case StringType => definitions.StringClass.toType
-    }
-
-    // TODO: Move this to a macro implementation class (we need it
-    // locally for `Type` which is on c.universe)
-    case class RecSchema(name: String, index: Int,
-      cType: DataType, tpe: Type)
-
-    val fullSchema = analyzedPlan.output.zipWithIndex.map { case (attr, i) =>
-      RecSchema(attr.name, i, attr.dataType, toScalaType(attr.dataType))
-    }
-    val schema = fullSchema.map(s => (s.name, s.tpe))
-
-    val rMacros = new RecordMacros[c.type](c)
-
-    val (spFlds, objFields) = fullSchema.partition(s =>
-      rMacros.specializedTypes.contains(s.tpe))
-
-    val spFldsByType = {
-      val grouped = spFlds.groupBy(_.tpe)
-      grouped.mapValues { _.map(s => s.name -> s).toMap }
-    }
-
-    def methodName(t: DataType) = newTermName("get" + primitiveForType(t))
-
-    val dataObjImpl = {
-      val fldTrees = objFields.map(s =>
-        s.name -> q"row.${methodName(s.cType)}(${s.index})"
-      ).toMap
-      val lookupTree = rMacros.genLookup(q"fieldName", fldTrees, mayCache = false)
-      q"($lookupTree).asInstanceOf[T]"
-    }
-
-    val record = rMacros.specializedRecord(schema)(tq"Serializable")()(dataObjImpl) {
-      case tpe if spFldsByType.contains(tpe) =>
-        val fldTrees = spFldsByType(tpe).mapValues(s =>
-          q"row.${methodName(s.cType)}(${s.index})")
-        rMacros.genLookup(q"fieldName", fldTrees, mayCache = false)
-    }
-
-    val tree = q"""
-      ..${args.zipWithIndex.map{ case (r,i) => q"""$r.registerAsTable(${s"table$i"})""" }}
-      val result = sql($query)
-      result.map(row => $record)
-    """
-
-    println(tree)
-
-    c.Expr(tree)
   }
 
   // TODO: Duplicated from codegen PR...
-  protected def primitiveForType(dt: DataType) = dt match {
+  protected def primitiveForType(dt: PrimitiveType) = dt match {
     case IntegerType => "Int"
     case LongType => "Long"
     case ShortType => "Short"
@@ -173,4 +191,4 @@ trait TypedSQL {
     // TODO: Handle functions...
     def sql(args: Any*): Any = macro SQLMacros.sqlImpl
   }
-}
\ No newline at end of file
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
index 5ab67670a1d3e..57b0318450dad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
@@ -46,9 +46,9 @@ class TypedSqlSuite extends FunSuite {
     assert(results.first().age == 30)
   }
 
-  ignore("nested results") {
+  test("nested results") {
     val results = sql"SELECT * FROM $cars"
-    assert(results.first().owner.name === "Michael")
+    assert(results.first().owner.name == "Michael")
   }
 
   test("join query") {
@@ -56,4 +56,4 @@ class TypedSqlSuite extends FunSuite {
     assert(results.first().name == "Michael")
   }
 
-}
\ No newline at end of file
+}

From b38fef3b7520d68668c235922ed229d2e0a5b20f Mon Sep 17 00:00:00 2001
From: Tobias Schlatter
Date: Tue, 22 Jul 2014 14:31:08 +0200
Subject: [PATCH 3/3] Refactor ScalaReflection to support compile-time reflection

---
 .../spark/sql/catalyst/ScalaReflection.scala  | 12 ++++-
 .../scala/org/apache/spark/sql/TypedSql.scala | 47 ++-----------------
 2 files changed, 13 insertions(+), 46 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 5a55be1e51558..28957e6b5ee19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -27,8 +27,16 @@ import org.apache.spark.sql.catalyst.types._
 /**
  * Provides experimental support for generating catalyst schemas for scala objects.
  */
-object ScalaReflection {
-  import scala.reflect.runtime.universe._
+object ScalaReflection extends ScalaReflection {
+  val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
+}
+
+trait ScalaReflection {
+
+  /** The universe we work in (runtime or macro) */
+  val universe: scala.reflect.api.Universe
+
+  import universe._
 
   case class Schema(dataType: DataType, nullable: Boolean)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala b/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
index 6cc2d882b7dac..d23b623490fa0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
@@ -23,7 +23,9 @@ object SQLMacros {
 
   case class Schema(dataType: DataType, nullable: Boolean)
 
-  class Macros[C <: Context](val c: C) {
+  class Macros[C <: Context](val c: C) extends ScalaReflection {
+    val universe: c.universe.type = c.universe
+
     import c.universe._
 
     val rowTpe = tq"_root_.org.apache.spark.sql.catalyst.expressions.Row"
@@ -125,49 +127,6 @@ object SQLMacros {
       analyzer(logicalPlan)
     }
 
-    // TODO: Don't copy this function from ScalaReflection.
-    def schemaFor(tpe: `Type`): Schema = tpe match {
-      case t if t <:< typeOf[Option[_]] =>
-        val TypeRef(_, _, Seq(optType)) = t
-        Schema(schemaFor(optType).dataType, nullable = true)
-      case t if t <:< typeOf[Product] =>
-        val formalTypeArgs = t.typeSymbol.asClass.typeParams
-        val TypeRef(_, _, actualTypeArgs) = t
-        val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
-        Schema(StructType(
-          params.head.map { p =>
-            val Schema(dataType, nullable) =
-              schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
-            StructField(p.name.toString, dataType, nullable)
-          }), nullable = true)
-      // Need to decide if we actually need a special type here.
-      case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
-      case t if t <:< typeOf[Array[_]] =>
-        sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
-      case t if t <:< typeOf[Seq[_]] =>
-        val TypeRef(_, _, Seq(elementType)) = t
-        Schema(ArrayType(schemaFor(elementType).dataType), nullable = true)
-      case t if t <:< typeOf[Map[_,_]] =>
-        val TypeRef(_, _, Seq(keyType, valueType)) = t
-        Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true)
-      case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
-      case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
-      case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
-      case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
-      case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
-      case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
-      case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
-      case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
-      case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
-      case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
-      case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
-      case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
-      case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
-      case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
-      case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
-      case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
-    }
-
   }
 
   // TODO: Duplicated from codegen PR...
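A rough sketch of the usage this series is aiming at, modelled on the "nested results" test in TypedSqlSuite. The Person and Car case classes, the example data, and the TestSQLContext._ import are illustrative assumptions rather than code from the patches:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.test.TestSQLContext._ // assumed: brings the sql"..." interpolator and implicits into scope

    // Assumed schema classes; the real suite defines its own test data.
    case class Person(name: String, age: Int)
    case class Car(model: String, owner: Person)

    val cars: RDD[Car] =
      sparkContext.parallelize(Car("Tesla", Person("Michael", 30)) :: Nil)

    // The sql macro parses and analyzes the query at compile time and maps each
    // result Row onto a scala-records record, so the nested field access below
    // is checked statically instead of failing at runtime.
    val results = sql"SELECT * FROM $cars"
    assert(results.first().owner.name == "Michael")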