[SPARK-24215][PYSPARK][FOLLOW UP] Implement eager evaluation for DataFrame APIs in PySpark

## What changes were proposed in this pull request?

Address comments in #21370 and add more tests.

## How was this patch tested?

Enhance tests in pyspark/sql/tests.py and DataFrameSuite.

Author: Yuanjian Li <[email protected]>

Closes #21553 from xuanyuanking/SPARK-24215-follow.
xuanyuanking authored and gatorsmile committed Jun 27, 2018
1 parent a1a64e3 commit 6a0b77a
Showing 6 changed files with 131 additions and 38 deletions.
27 changes: 0 additions & 27 deletions docs/configuration.md
@@ -456,33 +456,6 @@ Apart from these, the following properties are also available, and may be useful
from JVM to Python worker for every task.
</td>
</tr>
<tr>
<td><code>spark.sql.repl.eagerEval.enabled</code></td>
<td>false</td>
<td>
Enable eager evaluation or not. If true and the REPL you are using supports eager evaluation,
Dataset will be ran automatically. The HTML table which generated by <code>_repl_html_</code>
called by notebooks like Jupyter will feedback the queries user have defined. For plain Python
REPL, the output will be shown like <code>dataframe.show()</code>
(see <a href="https://issues.apache.org/jira/browse/SPARK-24215">SPARK-24215</a> for more details).
</td>
</tr>
<tr>
<td><code>spark.sql.repl.eagerEval.maxNumRows</code></td>
<td>20</td>
<td>
Default number of rows in eager evaluation output HTML table generated by <code>_repr_html_</code> or plain text,
this only take effect when <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
</td>
</tr>
<tr>
<td><code>spark.sql.repl.eagerEval.truncate</code></td>
<td>20</td>
<td>
Default number of truncate in eager evaluation output HTML table generated by <code>_repr_html_</code> or
plain text, this only take effect when <code>spark.sql.repl.eagerEval.enabled</code> set to true.
</td>
</tr>
<tr>
<td><code>spark.files</code></td>
<td></td>
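The entries removed above are consolidated into SQLConf later in this commit (see the REPL_EAGER_EVAL_* definitions below); the setting names and defaults are unchanged. A minimal PySpark sketch of how the settings are used, assuming a local SparkSession and output shapes as described above:

```python
from pyspark.sql import SparkSession

# Hedged sketch: enable eager evaluation and bound its output.
spark = (SparkSession.builder
         .master("local[2]")
         .appName("eager-eval-demo")
         .config("spark.sql.repl.eagerEval.enabled", "true")   # default: false
         .config("spark.sql.repl.eagerEval.maxNumRows", "5")   # default: 20
         .config("spark.sql.repl.eagerEval.truncate", "10")    # default: 20
         .getOrCreate())

df = spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))

# In a Jupyter notebook, evaluating `df` calls _repr_html_ and renders an
# HTML table; in a plain Python REPL, repr(df) prints a text table similar
# to df.show().
print(repr(df))
```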
3 changes: 1 addition & 2 deletions python/pyspark/sql/dataframe.py
@@ -393,9 +393,8 @@ def _repr_html_(self):
self._support_repr_html = True
if self._eager_eval:
max_num_rows = max(self._max_num_rows, 0)
vertical = False
sock_info = self._jdf.getRowsToPython(
max_num_rows, self._truncate, vertical)
max_num_rows, self._truncate)
rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
head = rows[0]
row_data = rows[1:]
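With the `vertical` flag dropped, `_repr_html_` now fetches rows through `getRowsToPython(max_num_rows, truncate)`; the first returned row is the header and the remaining rows are string-rendered data. A hedged sketch (not the actual PySpark code) of turning such a row set into an HTML table:

```python
import html

def rows_to_html_table(rows):
    # rows[0] is the header, rows[1:] are data rows; every cell is already
    # a string because Dataset.getRows renders values on the JVM side.
    head, data = rows[0], rows[1:]
    parts = ["<table border='1'>"]
    parts.append("<tr>" + "".join("<th>%s</th>" % html.escape(c) for c in head) + "</tr>")
    for row in data:
        parts.append("<tr>" + "".join("<td>%s</td>" % html.escape(c) for c in row) + "</tr>")
    parts.append("</table>")
    return "\n".join(parts)

print(rows_to_html_table([["key", "value"], ["1", "1"], ["22222", "22222"]]))
```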
46 changes: 44 additions & 2 deletions python/pyspark/sql/tests.py
@@ -3351,11 +3351,41 @@ def test_checking_csv_header(self):
finally:
shutil.rmtree(path)

def test_repr_html(self):
def test_repr_behaviors(self):
import re
pattern = re.compile(r'^ *\|', re.MULTILINE)
df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
self.assertEquals(None, df._repr_html_())

# test when eager evaluation is enabled and _repr_html_ will not be called
with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
expected1 = """+-----+-----+
|| key|value|
|+-----+-----+
|| 1| 1|
||22222|22222|
|+-----+-----+
|"""
self.assertEquals(re.sub(pattern, '', expected1), df.__repr__())
with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
expected2 = """+---+-----+
||key|value|
|+---+-----+
|| 1| 1|
||222| 222|
|+---+-----+
|"""
self.assertEquals(re.sub(pattern, '', expected2), df.__repr__())
with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
expected3 = """+---+-----+
||key|value|
|+---+-----+
|| 1| 1|
|+---+-----+
|only showing top 1 row
|"""
self.assertEquals(re.sub(pattern, '', expected3), df.__repr__())

# test when eager evaluation is enabled and _repr_html_ will be called
with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
expected1 = """<table border='1'>
|<tr><th>key</th><th>value</th></tr>
@@ -3381,6 +3411,18 @@ def test_repr_html(self):
|"""
self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())

# test when eager evaluation is disabled and _repr_html_ will be called
with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
expected = "DataFrame[key: bigint, value: string]"
self.assertEquals(None, df._repr_html_())
self.assertEquals(expected, df.__repr__())
with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
self.assertEquals(None, df._repr_html_())
self.assertEquals(expected, df.__repr__())
with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
self.assertEquals(None, df._repr_html_())
self.assertEquals(expected, df.__repr__())
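The expected strings in `test_repr_behaviors` above carry a leading `|` margin so the multi-line literals can stay indented inside the test method; `re.sub(pattern, '', expected)` strips that margin before comparison. A standalone sketch of the same trick:

```python
import re

# Strip the leading-pipe margin used in the expected strings above.
pattern = re.compile(r'^ *\|', re.MULTILINE)

expected = """+-----+-----+
    || key|value|
    |+-----+-----+
    |"""
print(re.sub(pattern, '', expected))
# +-----+-----+
# | key|value|
# +-----+-----+
```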


class HiveSparkSubmitTests(SparkSubmitTests):

@@ -1330,6 +1330,29 @@ object SQLConf {
"The size function returns null for null input if the flag is disabled.")
.booleanConf
.createWithDefault(true)

val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled")
.doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " +
"displayed if and only if the REPL supports the eager evaluation. Currently, the " +
"eager evaluation is only supported in PySpark. For the notebooks like Jupyter, " +
"the HTML table (generated by _repr_html_) will be returned. For plain Python REPL, " +
"the returned outputs are formatted like dataframe.show().")
.booleanConf
.createWithDefault(false)

val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows")
.doc("The max number of rows that are returned by eager evaluation. This only takes " +
"effect when spark.sql.repl.eagerEval.enabled is set to true. The valid range of this " +
"config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " +
"greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).")
.intConf
.createWithDefault(20)

val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate")
.doc("The max number of characters for each cell that is returned by eager evaluation. " +
"This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.")
.intConf
.createWithDefault(20)
}

/**
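The maxNumRows doc added above says out-of-range values are normalized into [0, Int.MaxValue - 1], and Dataset.scala below applies exactly that with `_numRows.max(0).min(Int.MaxValue - 1)`. A small Python sketch of the same clamping rule:

```python
# JVM Int.MaxValue; the upper bound is one less, presumably so that fetching
# numRows + 1 rows (to probe for more data) cannot overflow.
INT_MAX = 2 ** 31 - 1

def normalize_num_rows(requested):
    # Negative values clamp to 0, oversized values clamp to Int.MaxValue - 1.
    return max(0, min(requested, INT_MAX - 1))

assert normalize_num_rows(-5) == 0
assert normalize_num_rows(20) == 20
assert normalize_num_rows(INT_MAX) == INT_MAX - 1
```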
11 changes: 4 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -236,12 +236,10 @@ class Dataset[T] private[sql](
* @param numRows Number of rows to return
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
* @param vertical If set to true, the rows to return do not need truncate.
*/
private[sql] def getRows(
numRows: Int,
truncate: Int,
vertical: Boolean): Seq[Seq[String]] = {
truncate: Int): Seq[Seq[String]] = {
val newDf = toDF()
val castCols = newDf.logicalPlan.output.map { col =>
// Since binary types in top-level schema fields have a specific format to print,
@@ -289,7 +287,7 @@ class Dataset[T] private[sql](
vertical: Boolean = false): String = {
val numRows = _numRows.max(0).min(Int.MaxValue - 1)
// Get rows represented by Seq[Seq[String]], we may get one more line if it has more data.
val tmpRows = getRows(numRows, truncate, vertical)
val tmpRows = getRows(numRows, truncate)

val hasMoreData = tmpRows.length - 1 > numRows
val rows = tmpRows.take(numRows + 1)
@@ -3226,11 +3224,10 @@ class Dataset[T] private[sql](

private[sql] def getRowsToPython(
_numRows: Int,
truncate: Int,
vertical: Boolean): Array[Any] = {
truncate: Int): Array[Any] = {
EvaluatePython.registerPicklers()
val numRows = _numRows.max(0).min(Int.MaxValue - 1)
val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray
val rows = getRows(numRows, truncate).map(_.toArray).toArray
val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType)))
val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
rows.iterator.map(toJava))
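`showString` above asks `getRows` for up to `numRows` rows but may receive one extra data row; that extra row only signals that more data exists, is dropped from the output, and triggers the "only showing top N rows" footer. A hedged sketch of the pattern (plain-text layout simplified, not the Scala implementation):

```python
def render_with_footer(rows, num_rows):
    # rows[0] is the header; getRows may hand back num_rows + 1 data rows,
    # where the extra row merely proves that more data exists.
    header, data = rows[0], rows[1:]
    has_more_data = len(data) > num_rows
    shown = data[:num_rows]
    lines = [" | ".join(header)] + [" | ".join(r) for r in shown]
    if has_more_data:
        lines.append("only showing top %d row%s"
                     % (num_rows, "" if num_rows == 1 else "s"))
    return "\n".join(lines)

print(render_with_footer([["key", "value"], ["1", "1"], ["22222", "22222"]], 1))
```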
59 changes: 59 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1044,6 +1044,65 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
testData.select($"*").show(1000)
}

test("getRows: truncate = [0, 20]") {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = Seq(
Seq("value"),
Seq("1"),
Seq("111111111111111111111"))
assert(df.getRows(10, 0) === expectedAnswerForFalse)
val expectedAnswerForTrue = Seq(
Seq("value"),
Seq("1"),
Seq("11111111111111111..."))
assert(df.getRows(10, 20) === expectedAnswerForTrue)
}

test("getRows: truncate = [3, 17]") {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = Seq(
Seq("value"),
Seq("1"),
Seq("111"))
assert(df.getRows(10, 3) === expectedAnswerForFalse)
val expectedAnswerForTrue = Seq(
Seq("value"),
Seq("1"),
Seq("11111111111111..."))
assert(df.getRows(10, 17) === expectedAnswerForTrue)
}

test("getRows: numRows = 0") {
val expectedAnswer = Seq(Seq("key", "value"), Seq("1", "1"))
assert(testData.select($"*").getRows(0, 20) === expectedAnswer)
}

test("getRows: array") {
val df = Seq(
(Array(1, 2, 3), Array(1, 2, 3)),
(Array(2, 3, 4), Array(2, 3, 4))
).toDF()
val expectedAnswer = Seq(
Seq("_1", "_2"),
Seq("[1, 2, 3]", "[1, 2, 3]"),
Seq("[2, 3, 4]", "[2, 3, 4]"))
assert(df.getRows(10, 20) === expectedAnswer)
}

test("getRows: binary") {
val df = Seq(
("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)),
("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8))
).toDF()
val expectedAnswer = Seq(
Seq("_1", "_2"),
Seq("[31 32]", "[41 42 43 2E]"),
Seq("[33 34]", "[31 32 33 34 36]"))
assert(df.getRows(10, 20) === expectedAnswer)
}

test("showString: truncate = [0, 20]") {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
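The `getRows` tests above pin down the cell-truncation rule: `truncate = 0` disables truncation, values below 4 cut the cell to exactly that many characters, and larger values reserve the last three characters for an ellipsis. A hedged Python sketch of the behaviour implied by the expected answers (not the Scala code itself):

```python
def truncate_cell(value, truncate):
    # truncate <= 0 disables truncation entirely.
    if truncate <= 0 or len(value) <= truncate:
        return value
    # Too short to fit "...", so just cut hard.
    if truncate < 4:
        return value[:truncate]
    # Reserve the last three characters for the ellipsis.
    return value[:truncate - 3] + "..."

long_string = "1" * 21
assert truncate_cell(long_string, 0) == "1" * 21
assert truncate_cell(long_string, 3) == "111"
assert truncate_cell(long_string, 17) == "1" * 14 + "..."
assert truncate_cell(long_string, 20) == "1" * 17 + "..."
```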
