[SPARK-6996][SQL] Support map types in java beans #5578

Closed
wants to merge 3 commits
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst

import java.lang.{Iterable => JavaIterable}
import java.util.{Map => JavaMap}

import scala.collection.mutable.HashMap
@@ -49,6 +50,16 @@ object CatalystTypeConverters {
case (s: Seq[_], arrayType: ArrayType) =>
s.map(convertToCatalyst(_, arrayType.elementType))

case (jit: JavaIterable[_], arrayType: ArrayType) => {
val iter = jit.iterator
var listOfItems: List[Any] = List()
while (iter.hasNext) {
val item = iter.next()
listOfItems :+= convertToCatalyst(item, arrayType.elementType)
}
listOfItems
}

case (s: Array[_], arrayType: ArrayType) =>
s.toSeq.map(convertToCatalyst(_, arrayType.elementType))

@@ -124,6 +135,15 @@ object CatalystTypeConverters {
extractOption(item) match {
case a: Array[_] => a.toSeq.map(elementConverter)
case s: Seq[_] => s.map(elementConverter)
case i: JavaIterable[_] => {
val iter = i.iterator
var convertedIterable: List[Any] = List()
while (iter.hasNext) {
val item = iter.next()
convertedIterable :+= elementConverter(item)
}
convertedIterable
}
case null => null
}
}
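Note on the new JavaIterable cases above: they simply walk the Java iterable and run the per-element converter, accumulating the results into a Scala Seq. A behaviorally equivalent sketch using scala.collection.JavaConverters (an alternative phrasing for illustration, not part of this patch; convertElement stands in for the real per-element conversion):

import java.lang.{Iterable => JavaIterable}
import scala.collection.JavaConverters._

// Hypothetical stand-in for convertToCatalyst(_, elementType) / elementConverter.
def convertElement(item: Any): Any = item

// Equivalent to the while-loops above: convert each element of a
// java.lang.Iterable and return the results as a Scala Seq.
def convertJavaIterable(jit: JavaIterable[_]): Seq[Any] =
  jit.asScala.map(convertElement).toSeq
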
110 changes: 110 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.beans.Introspector
import java.lang.{Iterable => JIterable}
import java.util.{Iterator => JIterator, Map => JMap}

import com.google.common.reflect.TypeToken

import org.apache.spark.sql.types._

import scala.language.existentials

/**
* Type-inference utilities for POJOs and Java collections.
*/
private [sql] object JavaTypeInference {

private val iterableType = TypeToken.of(classOf[JIterable[_]])
private val mapType = TypeToken.of(classOf[JMap[_, _]])
private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType

/**
* Infers the corresponding SQL data type of a Java type.
* @param typeToken Java type
* @return (SQL data type, nullable)
*/
private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
typeToken.getRawType match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)

case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

case _ if typeToken.isArray =>
val (dataType, nullable) = inferDataType(typeToken.getComponentType)
(ArrayType(dataType, nullable), true)

case _ if iterableType.isAssignableFrom(typeToken) =>
val (dataType, nullable) = inferDataType(elementType(typeToken))
(ArrayType(dataType, nullable), true)

case _ if mapType.isAssignableFrom(typeToken) =>
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
val (keyDataType, _) = inferDataType(keyType)
val (valueDataType, nullable) = inferDataType(valueType)
(MapType(keyDataType, valueDataType, nullable), true)

case _ =>
val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
val fields = properties.map { property =>
val returnType = typeToken.method(property.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(returnType)
new StructField(property.getName, dataType, nullable)
}
(new StructType(fields), true)
}
}

private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
val itemType = iteratorType.resolveType(nextReturnType)
itemType
}
}
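
A rough usage sketch of the inference above, assuming a caller inside the org.apache.spark.sql package (inferDataType is private[sql]) and reusing the Bean class defined in the test further down: a Map<String, int[]> property becomes a MapType whose values are non-nullable integer arrays, and a List<String> property becomes an ArrayType of nullable strings.

import com.google.common.reflect.TypeToken
import org.apache.spark.sql.types._

// Sketch only: Bean is the test bean below, with getC(): Map<String, int[]>
// and getD(): List<String>.
val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(classOf[Bean]))
val schema = dataType.asInstanceOf[StructType]
// c -> map<string, array<int>>; primitive ints are not nullable.
assert(schema("c").dataType ==
  MapType(StringType, ArrayType(IntegerType, containsNull = false), valueContainsNull = true))
// d -> array<string> with nullable elements.
assert(schema("d").dataType == ArrayType(StringType, containsNull = true))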
52 changes: 5 additions & 47 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -25,6 +25,8 @@ import scala.collection.immutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag

import com.google.common.reflect.TypeToken

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
@@ -1222,56 +1224,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Returns a Catalyst Schema for the given java bean class.
*/
protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
val (dataType, _) = inferDataType(beanClass)
val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass))
dataType.asInstanceOf[StructType].fields.map { f =>
AttributeReference(f.name, f.dataType, f.nullable)()
}
}

/**
* Infers the corresponding SQL data type of a Java class.
* @param clazz Java class
* @return (SQL data type, nullable)
*/
private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
clazz match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)

case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

case c: Class[_] if c.isArray =>
val (dataType, nullable) = inferDataType(c.getComponentType)
(ArrayType(dataType, nullable), true)

case _ =>
val beanInfo = Introspector.getBeanInfo(clazz)
val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
val fields = properties.map { property =>
val (dataType, nullable) = inferDataType(property.getPropertyType)
new StructField(property.getName, dataType, nullable)
}
(new StructType(fields), true)
}
}
}
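
Net effect of this SQLContext change: getSchema now wraps the bean class in a Guava TypeToken and delegates to JavaTypeInference, so beans with java.util.Map and java.util.List properties can back a DataFrame. A minimal end-to-end sketch, assuming a live SQLContext named sqlContext and the Bean class from the test below:

// Sketch only: sqlContext and Bean are assumed from the surrounding context.
val rdd = sqlContext.sparkContext.parallelize(Seq(new Bean()))
val df = sqlContext.createDataFrame(rdd, classOf[Bean])
// The inferred schema now carries the map-typed column c and the list-typed column d.
assert(df.schema.fieldNames.toSet == Set("a", "b", "c", "d"))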


@@ -17,23 +17,28 @@

package test.org.apache.spark.sql;

import java.io.Serializable;
import java.util.Arrays;

import scala.collection.Seq;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.TestData$;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.*;
import org.junit.*;

import scala.collection.JavaConversions;
import scala.collection.Seq;
import scala.collection.mutable.Buffer;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static org.apache.spark.sql.functions.*;

@@ -106,6 +111,8 @@ public void testShow() {
public static class Bean implements Serializable {
private double a = 0.0;
private Integer[] b = new Integer[]{0, 1};
private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
private List<String> d = Arrays.asList("floppy", "disk");

public double getA() {
return a;
@@ -114,6 +121,14 @@ public double getA() {
public Integer[] getB() {
return b;
}

public Map<String, int[]> getC() {
return c;
}

public List<String> getD() {
return d;
}
}

@Test
@@ -127,7 +142,15 @@ public void testCreateDataFrameFromJavaBeans() {
Assert.assertEquals(
new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
schema.apply("b"));
Row first = df.select("a", "b").first();
ArrayType valueType = new ArrayType(DataTypes.IntegerType, false);
MapType mapType = new MapType(DataTypes.StringType, valueType, true);
Assert.assertEquals(
new StructField("c", mapType, true, Metadata.empty()),
schema.apply("c"));
Assert.assertEquals(
new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
schema.apply("d"));
Row first = df.select("a", "b", "c", "d").first();
Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
// Java lists and maps are now converted to Scala Seqs and Maps. Once we get a Seq below,
// verify that it has the expected length and contains the expected elements.
@@ -136,5 +159,15 @@
for (int i = 0; i < result.length(); i++) {
Assert.assertEquals(bean.getB()[i], result.apply(i));
}
Buffer<Integer> outputBuffer = (Buffer<Integer>) first.getJavaMap(2).get("hello");
Assert.assertArrayEquals(
bean.getC().get("hello"),
Ints.toArray(JavaConversions.asJavaList(outputBuffer)));
Seq<String> d = first.getAs(3);
Assert.assertEquals(bean.getD().size(), d.length());
for (int i = 0; i < d.length(); i++) {
Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
}

}