
[SPARK-23935][SQL] Adding map_entries function #21236

Closed
20 changes: 20 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2304,6 +2304,26 @@ def map_values(col):
return Column(sc._jvm.functions.map_values(_to_java_column(col)))


@since(2.4)
def map_entries(col):
"""
Collection function: Returns an unordered array of all entries in the given map.

:param col: name of column or expression

>>> from pyspark.sql.functions import map_entries
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
>>> df.select(map_entries("data").alias("entries")).show()
+----------------+
| entries|
+----------------+
|[[1, a], [2, b]]|
+----------------+
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.map_entries(_to_java_column(col)))


# ---------------------------- User Defined Function ----------------------------------

class PandasUDFType(object):
@@ -409,6 +409,7 @@ object FunctionRegistry {
expression[ElementAt]("element_at"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[MapEntries]("map_entries"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

@@ -118,6 +119,162 @@ case class MapValues(child: Expression)
override def prettyName: String = "map_values"
}

/**
* Returns an unordered array of all entries in the given map.
*/
@ExpressionDescription(
usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'));
[(1,"a"),(2,"b")]
""",
since = "2.4.0")
case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]

override def dataType: DataType = {
ArrayType(
StructType(
StructField("key", childDataType.keyType, false) ::
StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
Nil),
false)
}

override protected def nullSafeEval(input: Any): Any = {
val childMap = input.asInstanceOf[MapData]
val keys = childMap.keyArray()
val values = childMap.valueArray()
val length = childMap.numElements()
val resultData = new Array[AnyRef](length)
var i = 0
while (i < length) {
val key = keys.get(i, childDataType.keyType)
val value = values.get(i, childDataType.valueType)
val row = new GenericInternalRow(Array[Any](key, value))
resultData.update(i, row)
i += 1
}
new GenericArrayData(resultData)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
val numElements = ctx.freshName("numElements")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
val code = if (isKeyPrimitive && isValuePrimitive) {
genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements)
} else {
genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
}
s"""
|final int $numElements = $c.numElements();
|final ArrayData $keys = $c.keyArray();
|final ArrayData $values = $c.valueArray();
|$code
""".stripMargin
})
}

private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z")

private def getValue(varName: String) = {
CodeGenerator.getValue(varName, childDataType.valueType, "z")
}

private def genCodeForPrimitiveElements(
ctx: CodegenContext,
keys: String,
values: String,
arrayData: String,
numElements: String): String = {
val byteArraySize = ctx.freshName("byteArraySize")
val data = ctx.freshName("byteArray")
val unsafeRow = ctx.freshName("unsafeRow")
val unsafeArrayData = ctx.freshName("unsafeArrayData")
val structsOffset = ctx.freshName("structsOffset")
val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"

val baseOffset = Platform.BYTE_ARRAY_OFFSET
val longSize = LongType.defaultSize
val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2
Member:
I'm wondering whether it is right to use longSize here?
I know the value is 8 and is the same as the word size, but it feels like the meaning is different.
cc @gatorsmile @cloud-fan

Contributor Author:
@ueshin Really good question. I'm eager to learn the true purpose of the DataType.defaultSize function. Currently, it's used with this meaning in more places (e.g. GenArrayData.genCodeToCreateArrayData and CodeGenerator.createUnsafeArray).

What about using Long.BYTES from Java 8 instead?

Member:
IMHO, 8 is the better choice, since this is not related to the element size of a long.
My best guess is that it would be best to define a new constant.

Contributor Author:
@kiszk Thanks for your suggestion, but it seems to me that LongType.defaultSize could be used in this case. It seems that the purpose of defaultSize is not only the calculation of estimated data size in statistics. GenerateUnsafeProjection.writeArrayToBuffer, InterpretedUnsafeProjection.getElementSize and other parts utilize defaultSize in the same way.

Member:
This is not for the element size of arrays. I agree with @kiszk to use 8.
Maybe we need to add a constant representing the word size in UnsafeRow or somewhere else in a future PR.

Contributor Author:
Oh OK, I misunderstood the comments. Thanks guys!
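
To make the arithmetic under discussion concrete, here is a minimal sketch of the layout computation for a two-field UnsafeRow, assuming the 8-byte word size used by the unsafe format (illustrative only, not PR code):

// null-tracking bitset: one 8-byte word covers up to 64 fields
val wordSize = 8
val bitSetWidth = ((2 + 63) / 64) * wordSize   // = 8, i.e. UnsafeRow.calculateBitSetWidthInBytes(2)
val structSize = bitSetWidth + 2 * wordSize    // = 24 bytes per key/value struct
// The "8" here is the word size; it only happens to equal LongType.defaultSize.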

val structSizeAsLong = structSize + "L"
val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)

val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});"
val valueAssignmentChecked = if (childDataType.valueContainsNull) {
s"""
|if ($values.isNullAt(z)) {
| $unsafeRow.setNullAt(1);
|} else {
| $valueAssignment
|}
""".stripMargin
} else {
valueAssignment
}

s"""
|final long $byteArraySize = $calculateArraySize($numElements, ${longSize + structSize});
|if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)}
Member:
Hmm, should we use this idiom for other array functions? WDYT?

Contributor Author:
For now, I separated out the logic so that I can leverage it for the map_from_entries function. Moreover, I think it should be possible to replace UnsafeArrayData.createUnsafeArray with that logic, but I will do that in a different PR.
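
The idiom in question is: compute the size of the flat byte array up front and, if it would exceed the maximum safe array length, fall back to the boxed GenericArrayData path. A self-contained sketch under simplified assumptions (the object, helper names, and header formula are stand-ins, not PR code):

object SizeFallbackSketch {
  // stand-in for ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
  val MaxRoundedArrayLength: Long = Int.MaxValue - 15

  /** Left = flat byte-array layout, Right = boxed object-array fallback. */
  def allocate(numElements: Int, bytesPerElement: Int): Either[Array[Byte], Array[AnyRef]] = {
    val headerBytes = 8L + ((numElements + 63) / 64) * 8L        // simplified: length word + null bitset
    val byteArraySize = headerBytes + numElements.toLong * bytesPerElement
    if (byteArraySize > MaxRoundedArrayLength) {
      Right(new Array[AnyRef](numElements))                      // generic, boxed fallback
    } else {
      Left(new Array[Byte](byteArraySize.toInt))                 // flat unsafe-style layout
    }
  }
}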

|} else {
| final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize;
| final byte[] $data = new byte[(int)$byteArraySize];
| UnsafeArrayData $unsafeArrayData = new UnsafeArrayData();
| Platform.putLong($data, $baseOffset, $numElements);
| $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize);
| UnsafeRow $unsafeRow = new UnsafeRow(2);
| for (int z = 0; z < $numElements; z++) {
| long offset = $structsOffset + z * $structSizeAsLong;
| $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
| $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
| $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
| $valueAssignmentChecked
| }
| $arrayData = $unsafeArrayData;
|}
""".stripMargin
}

private def genCodeForAnyElements(
ctx: CodegenContext,
keys: String,
values: String,
arrayData: String,
numElements: String): String = {
val genericArrayClass = classOf[GenericArrayData].getName
val rowClass = classOf[GenericInternalRow].getName
val data = ctx.freshName("internalRowArray")

val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
} else {
getValue(values)
}
Member:
nit: indent


s"""
|final Object[] $data = new Object[$numElements];
|for (int z = 0; z < $numElements; z++) {
| $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck});
|}
|$arrayData = new $genericArrayClass($data);
""".stripMargin
}

override def prettyName: String = "map_entries"
}

/**
* Common base class for [[SortArray]] and [[ArraySort]].
*/
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,6 +57,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapValues(m2), null)
}

test("MapEntries") {
def r(values: Any*): InternalRow = create_row(values: _*)

// Primitive-type keys/values
val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType))
val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType))
val mi2 = Literal.create(null, MapType(IntegerType, IntegerType))

checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2)))
checkEvaluation(MapEntries(mi1), Seq.empty)
checkEvaluation(MapEntries(mi2), null)

// Non-primitive-type keys/values
val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType))
val ms1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
val ms2 = Literal.create(null, MapType(StringType, StringType))

checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null)))
checkEvaluation(MapEntries(ms1), Seq.empty)
checkEvaluation(MapEntries(ms2), null)
}

test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
@@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
if (expected.isNaN) result.isNaN else expected == result
case (result: Float, expected: Float) =>
if (expected.isNaN) result.isNaN else expected == result
case (result: UnsafeRow, expected: GenericInternalRow) =>
Member:
@mn-mikke I was just looking over compiler warnings, and noticed it claims this case is never triggered. I think it's because it would always first match the (InternalRow, InternalRow) case above. Should it go before that then?

Contributor Author:
Hi @srowen,
The (InternalRow, InternalRow) case was introduced later in #21838 and covers the logic of the UnsafeRow case, so we can just remove the unreachable piece of code.

Member:
Roger that, looks like Wenchen just did so. Thanks!
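
To illustrate the unreachable-case issue in isolation, a tiny standalone sketch with hypothetical types (not PR code): because values of the subtype also match the earlier, more general pattern, the later, more specific case can never fire, so it has to be moved up or removed.

class Internal
class Unsafe extends Internal

def describe(result: Any, expected: Any): String = (result, expected) match {
  case (_: Internal, _: Internal) => "general case"   // also matches Unsafe values
  case (_: Unsafe, _: Internal)   => "never reached"  // unreachable: shadowed by the case above
  case _                          => "other"
}

// describe(new Unsafe, new Internal) returns "general case", and the compiler
// warns that the second case is unreachable, which is exactly the warning reported here.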

val structType = exprDataType.asInstanceOf[StructType]
result.toSeq(structType) == expected.toSeq(structType)
case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
case _ =>
result == expected
7 changes: 7 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3414,6 +3414,13 @@
*/
def map_values(e: Column): Column = withExpr { MapValues(e.expr) }

/**
* Returns an unordered array of all entries in the given map.
* @group collection_funcs
* @since 2.4.0
*/
def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }
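
A minimal usage sketch for the new Scala API, assuming an active SparkSession named spark (the output shape mirrors the PySpark doctest above; illustrative only, not PR code):

import org.apache.spark.sql.functions.{lit, map, map_entries}

val df = spark.range(1).select(map(lit(1), lit("a"), lit(2), lit("b")).as("m"))
df.select(map_entries(df("m")).as("entries")).show()
// +----------------+
// |         entries|
// +----------------+
// |[[1, a], [2, b]]|
// +----------------+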

// scalastyle:off line.size.limit
// scalastyle:off parameter.number

@@ -405,6 +405,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("map_entries") {
val dummyFilter = (c: Column) => c.isNotNull || c.isNull

// Primitive-type elements
val idf = Seq(
Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300),
Map[Int, Int](),
null
).toDF("m")
val iExpected = Seq(
Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))),
Row(Seq.empty),
Row(null)
)

checkAnswer(idf.select(map_entries('m)), iExpected)
checkAnswer(idf.selectExpr("map_entries(m)"), iExpected)
checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected)
checkAnswer(
spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"),
Seq(Row(Seq(Row(1, null), Row(2, null)))))
checkAnswer(
spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"),
Seq(Row(Seq(Row(1, null), Row(2, null)))))

// Non-primitive-type elements
val sdf = Seq(
Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"),
Map[String, String]("a" -> null, "b" -> null),
Map[String, String](),
null
).toDF("m")
val sExpected = Seq(
Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))),
Row(Seq(Row("a", null), Row("b", null))),
Row(Seq.empty),
Row(null)
)

checkAnswer(sdf.select(map_entries('m)), sExpected)
checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected)
checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected)
}

test("array contains function") {
val df = Seq(
(Seq[Int](1, 2), "x"),