[SPARK-23935][SQL] Adding map_entries function
## What changes were proposed in this pull request?

This PR adds the `map_entries` function, which returns an unordered array of all entries in the given map.
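
For example, a minimal usage sketch (assuming a `SparkSession` with `spark.implicits._` in scope; the `show()` output below is illustrative, not verbatim):

```
import org.apache.spark.sql.functions.map_entries

val df = Seq(Map(1 -> "a", 2 -> "b")).toDF("m")
df.select(map_entries('m)).show(false)
// one row holding an array of key/value structs, e.g. [[1, a], [2, b]]
```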

## How was this patch tested?

New tests were added to:
- `CollectionExpressionsSuite`
- `DataFrameFunctionsSuite`

## CodeGen examples
### Primitive types
```
val df = Seq(Map(1 -> 5, 2 -> 6)).toDF("m")
df.filter('m.isNotNull).select(map_entries('m)).debugCodegen
```
Result:
```
/* 042 */         boolean project_isNull_0 = false;
/* 043 */
/* 044 */         ArrayData project_value_0 = null;
/* 045 */
/* 046 */         final int project_numElements_0 = inputadapter_value_0.numElements();
/* 047 */         final ArrayData project_keys_0 = inputadapter_value_0.keyArray();
/* 048 */         final ArrayData project_values_0 = inputadapter_value_0.valueArray();
/* 049 */
/* 050 */         final long project_size_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
/* 051 */           project_numElements_0,
/* 052 */           32);
/* 053 */         if (project_size_0 > 2147483632) {
/* 054 */           final Object[] project_internalRowArray_0 = new Object[project_numElements_0];
/* 055 */           for (int z = 0; z < project_numElements_0; z++) {
/* 056 */             project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getInt(z), project_values_0.getInt(z)});
/* 057 */           }
/* 058 */           project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0);
/* 059 */
/* 060 */         } else {
/* 061 */           final byte[] project_arrayBytes_0 = new byte[(int)project_size_0];
/* 062 */           UnsafeArrayData project_unsafeArrayData_0 = new UnsafeArrayData();
/* 063 */           Platform.putLong(project_arrayBytes_0, 16, project_numElements_0);
/* 064 */           project_unsafeArrayData_0.pointTo(project_arrayBytes_0, 16, (int)project_size_0);
/* 065 */
/* 066 */           final int project_structsOffset_0 = UnsafeArrayData.calculateHeaderPortionInBytes(project_numElements_0) + project_numElements_0 * 8;
/* 067 */           UnsafeRow project_unsafeRow_0 = new UnsafeRow(2);
/* 068 */           for (int z = 0; z < project_numElements_0; z++) {
/* 069 */             long offset = project_structsOffset_0 + z * 24L;
/* 070 */             project_unsafeArrayData_0.setLong(z, (offset << 32) + 24L);
/* 071 */             project_unsafeRow_0.pointTo(project_arrayBytes_0, 16 + offset, 24);
/* 072 */             project_unsafeRow_0.setInt(0, project_keys_0.getInt(z));
/* 073 */             project_unsafeRow_0.setInt(1, project_values_0.getInt(z));
/* 074 */           }
/* 075 */           project_value_0 = project_unsafeArrayData_0;
/* 076 */
/* 077 */         }
```
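
A note on the constants in the generated code above: each map entry is written as a fixed-size `UnsafeRow` struct, and the `32` passed to `calculateSizeOfUnderlyingByteArray` is the per-entry byte budget. A sketch of that arithmetic (my reading of the generated code, not authoritative):

```
// Per-entry byte budget for the int -> int case, as read from the generated code above.
val bitSetBytes  = 8                          // UnsafeRow null-tracking bitset for 2 fields
val fieldBytes   = 2 * 8                      // key and value, one 8-byte word each
val structSize   = bitSetBytes + fieldBytes   // 24 -- the struct size/stride (`24L`) above
val offsetWord   = 8                          // per-element offset-and-size slot
val perEntrySize = structSize + offsetWord    // 32 -- the element size passed at /* 052 */
```
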
### Non-primitive types
```
val df = Seq(Map("a" -> "foo", "b" -> null)).toDF("m")
df.filter('m.isNotNull).select(map_entries('m)).debugCodegen
```
Result:
```
/* 042 */         boolean project_isNull_0 = false;
/* 043 */
/* 044 */         ArrayData project_value_0 = null;
/* 045 */
/* 046 */         final int project_numElements_0 = inputadapter_value_0.numElements();
/* 047 */         final ArrayData project_keys_0 = inputadapter_value_0.keyArray();
/* 048 */         final ArrayData project_values_0 = inputadapter_value_0.valueArray();
/* 049 */
/* 050 */         final Object[] project_internalRowArray_0 = new Object[project_numElements_0];
/* 051 */         for (int z = 0; z < project_numElements_0; z++) {
/* 052 */           project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getUTF8String(z), project_values_0.getUTF8String(z)});
/* 053 */         }
/* 054 */         project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0);
```
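
For non-primitive types the generated code always materializes `GenericInternalRow` entries (there is no `UnsafeArrayData` fast path); it is essentially the codegen form of the interpreted `nullSafeEval` implementation shown in the diff below. A simplified equivalent, with `map`, `keyType` and `valueType` as stand-ins:

```
// Simplified interpreted-path equivalent of the generated loop above (illustrative only).
val keys = map.keyArray()
val values = map.valueArray()
val entries = new Array[AnyRef](map.numElements())
var i = 0
while (i < map.numElements()) {
  entries(i) = new GenericInternalRow(Array[Any](keys.get(i, keyType), values.get(i, valueType)))
  i += 1
}
new GenericArrayData(entries)
```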

Author: Marek Novotny <[email protected]>

Closes #21236 from mn-mikke/feature/array-api-map_entries-to-master.
mn-mikke authored and ueshin committed May 21, 2018
1 parent e480ecc commit a6e883f
Showing 9 changed files with 287 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2344,6 +2344,26 @@ def map_values(col):
    return Column(sc._jvm.functions.map_values(_to_java_column(col)))


@since(2.4)
def map_entries(col):
    """
    Collection function: Returns an unordered array of all entries in the given map.

    :param col: name of column or expression

    >>> from pyspark.sql.functions import map_entries
    >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
    >>> df.select(map_entries("data").alias("entries")).show()
    +----------------+
    |         entries|
    +----------------+
    |[[1, a], [2, b]]|
    +----------------+
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.map_entries(_to_java_column(col)))


@ignore_unicode_prefix
@since(2.4)
def array_repeat(col, count):
@@ -62,6 +62,8 @@
*/
public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable {

public static final int WORD_SIZE = 8;

//////////////////////////////////////////////////////////////////////////////
// Static methods
//////////////////////////////////////////////////////////////////////////////
@@ -419,6 +419,7 @@ object FunctionRegistry {
expression[ElementAt]("element_at"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[MapEntries]("map_entries"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
@@ -764,6 +764,40 @@ class CodegenContext {
""".stripMargin
}

/**
* Generates code that creates an [[UnsafeArrayData]]. The generated code executes
* the provided fallback when the size of the backing array would exceed the array size limit.
* @param arrayName name of the array variable to create
* @param numElements a piece of code representing the number of elements the array should contain
* @param elementSize size of an element in bytes
* @param bodyCode a function that takes the name of the backing byte array as a parameter
* and generates code filling up the [[UnsafeArrayData]]
* @param fallbackCode a piece of code executed when the array size limit is exceeded
*/
def createUnsafeArrayWithFallback(
arrayName: String,
numElements: String,
elementSize: Int,
bodyCode: String => String,
fallbackCode: String): String = {
val arraySize = freshName("size")
val arrayBytes = freshName("arrayBytes")
s"""
|final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
| $numElements,
| $elementSize);
|if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| $fallbackCode
|} else {
| final byte[] $arrayBytes = new byte[(int)$arraySize];
| UnsafeArrayData $arrayName = new UnsafeArrayData();
| Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
| $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
| ${bodyCode(arrayBytes)}
|}
""".stripMargin
}
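
The real caller of this helper is `MapEntries.genCodeForPrimitiveElements` further down in this diff; a condensed, hypothetical invocation might look like the following sketch (identifiers such as `fillLoop`, `myArray` and `buildFallback` are invented for illustration):

```
// Hypothetical caller sketch of createUnsafeArrayWithFallback; names are invented for illustration.
val fillLoop: String => String = byteArray =>
  s"""
     |for (int i = 0; i < n; i++) { /* write element i of myArray into $byteArray */ }
     |result = myArray;
   """.stripMargin

val code = ctx.createUnsafeArrayWithFallback(
  arrayName = "myArray",      // UnsafeArrayData variable declared by the generated code
  numElements = "n",          // Java expression for the element count
  elementSize = 32,           // bytes per element, including the 8-byte offset-and-size word
  bodyCode = fillLoop,        // fills the backing byte[] and assigns the final result
  fallbackCode = "result = buildFallback();")  // runs when the size limit would be exceeded
```
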

/**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

@@ -154,6 +155,158 @@ case class MapValues(child: Expression)
override def prettyName: String = "map_values"
}

/**
* Returns an unordered array of all entries in the given map.
*/
@ExpressionDescription(
usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'));
[(1,"a"),(2,"b")]
""",
since = "2.4.0")
case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]

override def dataType: DataType = {
ArrayType(
StructType(
StructField("key", childDataType.keyType, false) ::
StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
Nil),
false)
}
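
So for a hypothetical `MapType(IntegerType, StringType, valueContainsNull = true)` input, the override above produces a result type equivalent to the following sketch (not part of this diff):

```
// Result-type sketch for map_entries over MapType(IntegerType, StringType, valueContainsNull = true).
import org.apache.spark.sql.types._

val resultType = ArrayType(
  StructType(
    StructField("key", IntegerType, nullable = false) ::
    StructField("value", StringType, nullable = true) :: Nil),
  containsNull = false)
```
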

override protected def nullSafeEval(input: Any): Any = {
val childMap = input.asInstanceOf[MapData]
val keys = childMap.keyArray()
val values = childMap.valueArray()
val length = childMap.numElements()
val resultData = new Array[AnyRef](length)
var i = 0
while (i < length) {
val key = keys.get(i, childDataType.keyType)
val value = values.get(i, childDataType.valueType)
val row = new GenericInternalRow(Array[Any](key, value))
resultData.update(i, row)
i += 1
}
new GenericArrayData(resultData)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
val numElements = ctx.freshName("numElements")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
val code = if (isKeyPrimitive && isValuePrimitive) {
genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements)
} else {
genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
}
s"""
|final int $numElements = $c.numElements();
|final ArrayData $keys = $c.keyArray();
|final ArrayData $values = $c.valueArray();
|$code
""".stripMargin
})
}

private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z")

private def getValue(varName: String) = {
CodeGenerator.getValue(varName, childDataType.valueType, "z")
}

private def genCodeForPrimitiveElements(
ctx: CodegenContext,
keys: String,
values: String,
arrayData: String,
numElements: String): String = {
val unsafeRow = ctx.freshName("unsafeRow")
val unsafeArrayData = ctx.freshName("unsafeArrayData")
val structsOffset = ctx.freshName("structsOffset")
val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"

val baseOffset = Platform.BYTE_ARRAY_OFFSET
val wordSize = UnsafeRow.WORD_SIZE
val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
val structSizeAsLong = structSize + "L"
val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)

val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});"
val valueAssignmentChecked = if (childDataType.valueContainsNull) {
s"""
|if ($values.isNullAt(z)) {
| $unsafeRow.setNullAt(1);
|} else {
| $valueAssignment
|}
""".stripMargin
} else {
valueAssignment
}

val assignmentLoop = (byteArray: String) =>
s"""
|final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
|UnsafeRow $unsafeRow = new UnsafeRow(2);
|for (int z = 0; z < $numElements; z++) {
| long offset = $structsOffset + z * $structSizeAsLong;
| $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
| $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize);
| $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
| $valueAssignmentChecked
|}
|$arrayData = $unsafeArrayData;
""".stripMargin

ctx.createUnsafeArrayWithFallback(
unsafeArrayData,
numElements,
structSize + wordSize,
assignmentLoop,
genCodeForAnyElements(ctx, keys, values, arrayData, numElements))
}

private def genCodeForAnyElements(
ctx: CodegenContext,
keys: String,
values: String,
arrayData: String,
numElements: String): String = {
val genericArrayClass = classOf[GenericArrayData].getName
val rowClass = classOf[GenericInternalRow].getName
val data = ctx.freshName("internalRowArray")

val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
} else {
getValue(values)
}

s"""
|final Object[] $data = new Object[$numElements];
|for (int z = 0; z < $numElements; z++) {
| $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck});
|}
|$arrayData = new $genericArrayClass($data);
""".stripMargin
}

override def prettyName: String = "map_entries"
}

/**
* Common base class for [[SortArray]] and [[ArraySort]].
*/
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,6 +57,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapValues(m2), null)
}

test("MapEntries") {
def r(values: Any*): InternalRow = create_row(values: _*)

// Primitive-type keys/values
val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType))
val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType))
val mi2 = Literal.create(null, MapType(IntegerType, IntegerType))

checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2)))
checkEvaluation(MapEntries(mi1), Seq.empty)
checkEvaluation(MapEntries(mi2), null)

// Non-primitive-type keys/values
val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType))
val ms1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
val ms2 = Literal.create(null, MapType(StringType, StringType))

checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null)))
checkEvaluation(MapEntries(ms1), Seq.empty)
checkEvaluation(MapEntries(ms2), null)
}

test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
@@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
if (expected.isNaN) result.isNaN else expected == result
case (result: Float, expected: Float) =>
if (expected.isNaN) result.isNaN else expected == result
case (result: UnsafeRow, expected: GenericInternalRow) =>
val structType = exprDataType.asInstanceOf[StructType]
result.toSeq(structType) == expected.toSeq(structType)
case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
case _ =>
result == expected
7 changes: 7 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3492,6 +3492,13 @@ object functions {
*/
def map_values(e: Column): Column = withExpr { MapValues(e.expr) }

/**
* Returns an unordered array of all entries in the given map.
* @group collection_funcs
* @since 2.4.0
*/
def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }

// scalastyle:off line.size.limit
// scalastyle:off parameter.number

@@ -405,6 +405,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("map_entries") {
val dummyFilter = (c: Column) => c.isNotNull || c.isNull

// Primitive-type elements
val idf = Seq(
Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300),
Map[Int, Int](),
null
).toDF("m")
val iExpected = Seq(
Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))),
Row(Seq.empty),
Row(null)
)

checkAnswer(idf.select(map_entries('m)), iExpected)
checkAnswer(idf.selectExpr("map_entries(m)"), iExpected)
checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected)
checkAnswer(
spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"),
Seq(Row(Seq(Row(1, null), Row(2, null)))))
checkAnswer(
spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"),
Seq(Row(Seq(Row(1, null), Row(2, null)))))

// Non-primitive-type elements
val sdf = Seq(
Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"),
Map[String, String]("a" -> null, "b" -> null),
Map[String, String](),
null
).toDF("m")
val sExpected = Seq(
Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))),
Row(Seq(Row("a", null), Row("b", null))),
Row(Seq.empty),
Row(null)
)

checkAnswer(sdf.select(map_entries('m)), sExpected)
checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected)
checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected)
}

test("array contains function") {
val df = Seq(
(Seq[Int](1, 2), "x"),