[SPARK-36559][SQL][PYTHON] Create plans dedicated to distributed-sequence index for optimization #33807
Closed · wants to merge 2 commits
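For context, the 'distributed-sequence' default index in pandas API on Spark attaches a consecutive 0, 1, 2, ... long column to the underlying Spark DataFrame, and this PR gives that step its own logical/physical plan node so the optimizer can reason about it. A minimal usage sketch (the exact explain output is an assumption and may vary by version):

import pyspark.pandas as ps

# Use the distributed-sequence default index for newly created frames.
ps.set_option("compute.default_index_type", "distributed-sequence")

psdf = ps.DataFrame({"a": [1, 2, 3]})
# With the dedicated plan node from this PR, the plan should show an
# AttachDistributedSequence operator instead of a pre-materialized RDD scan.
psdf.spark.explain(True)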
39 changes: 19 additions & 20 deletions python/pyspark/pandas/tests/test_dataframe.py
@@ -5160,26 +5160,25 @@ def test_print_schema(self):
sys.stdout = prev

def test_explain_hint(self):
-        with ps.option_context("compute.default_index_type", "sequence"):
-            psdf1 = ps.DataFrame(
-                {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]},
-                columns=["lkey", "value"],
-            )
-            psdf2 = ps.DataFrame(
-                {"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]},
-                columns=["rkey", "value"],
-            )
-            merged = psdf1.merge(psdf2.spark.hint("broadcast"), left_on="lkey", right_on="rkey")
-            prev = sys.stdout
-            try:
-                out = StringIO()
-                sys.stdout = out
-                merged.spark.explain()
-                actual = out.getvalue().strip()
-
-                self.assertTrue("Broadcast" in actual, actual)
-            finally:
-                sys.stdout = prev
+        psdf1 = ps.DataFrame(
+            {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]},
+            columns=["lkey", "value"],
+        )
+        psdf2 = ps.DataFrame(
+            {"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]},
+            columns=["rkey", "value"],
+        )
+        merged = psdf1.merge(psdf2.spark.hint("broadcast"), left_on="lkey", right_on="rkey")
+        prev = sys.stdout
+        try:
+            out = StringIO()
+            sys.stdout = out
+            merged.spark.explain()
+            actual = out.getvalue().strip()
+
+            self.assertTrue("Broadcast" in actual, actual)
+        finally:
+            sys.stdout = prev

def test_mad(self):
pdf = pd.DataFrame(
@@ -225,6 +225,10 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))

case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())))

case oldVersion: Generate
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
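The new DeduplicateRelations case above matters for self-joins: when the same child plan, and hence the same AttachDistributedSequence node, appears on both sides, the produced sequence attribute must be re-instanced to avoid conflicting attribute IDs. A hedged pandas-on-Spark sketch of such a self-merge:

import pyspark.pandas as ps

with ps.option_context("compute.default_index_type", "distributed-sequence"):
    psdf = ps.DataFrame({"a": [1, 2, 3]})
    # Both sides of the join carry the same distributed-sequence index attribute,
    # which DeduplicateRelations has to deduplicate.
    merged = psdf.merge(psdf, left_index=True, right_index=True)
    merged.spark.explain()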
@@ -800,6 +800,11 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
a.copy(child = Expand(newProjects, newOutput, grandChild))

// Prune and drop AttachDistributedSequence if the produced attribute is not referred.
case p @ Project(_, a @ AttachDistributedSequence(_, grandChild))
if !p.references.contains(a.sequenceAttr) =>
p.copy(child = prunedChild(grandChild, p.references))

// Prunes the unused columns from child of `DeserializeToObject`
case d @ DeserializeToObject(_, _, child) if !child.outputSet.subsetOf(d.references) =>
d.copy(child = prunedChild(child, d.references))
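To see when the new ColumnPruning case can fire: if nothing downstream references the sequence column, the whole AttachDistributedSequence node can be dropped. A hedged sketch, assuming to_spark() without index_col does not carry the index over:

import pyspark.pandas as ps

with ps.option_context("compute.default_index_type", "distributed-sequence"):
    psdf = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    sdf = psdf.to_spark()  # index column not requested, so the sequence attribute is unreferenced
    # The optimized plan is expected to contain no AttachDistributedSequence node.
    sdf.explain(True)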
@@ -115,3 +115,20 @@ case class ArrowEvalPython(
override protected def withNewChildInternal(newChild: LogicalPlan): ArrowEvalPython =
copy(child = newChild)
}

/**
* A logical plan that adds a new long column `sequenceAttr` that
* increases one by one. This is for 'distributed-sequence' default index
* in pandas API on Spark.
*/
case class AttachDistributedSequence(
sequenceAttr: Attribute,
child: LogicalPlan) extends UnaryNode {

override val producedAttributes: AttributeSet = AttributeSet(sequenceAttr)

override val output: Seq[Attribute] = sequenceAttr +: child.output

override protected def withNewChildInternal(newChild: LogicalPlan): AttachDistributedSequence =
copy(child = newChild)
}
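As a quick illustration of the node's semantics, the attached column is a consecutive 0, 1, 2, ... sequence across all partitions, which is what backs the distributed-sequence index (a sketch; the printed value is the expected result, not verified output):

import pyspark.pandas as ps

with ps.option_context("compute.default_index_type", "distributed-sequence"):
    psdf = ps.DataFrame({"x": list("abcde")})
    print(psdf.index.to_numpy())  # expected: [0 1 2 3 4]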
@@ -452,5 +452,11 @@ class ColumnPruningSuite extends PlanTest {
val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze
comparePlans(optimized, expected)
}
// todo: add more tests for column pruning

test("SPARK-36559 Prune and drop distributed-sequence if the produced column is not referred") {
val input = LocalRelation('a.int, 'b.int, 'c.int)
val plan1 = AttachDistributedSequence('d.int, input).select('a)
val correctAnswer1 = Project(Seq('a), input).analyze
comparePlans(Optimize.execute(plan1.analyze), correctAnswer1)
}
}
23 changes: 5 additions & 18 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3541,24 +3541,11 @@ class Dataset[T] private[sql](
* This is for 'distributed-sequence' default index in pandas API on Spark.
*/
private[sql] def withSequenceColumn(name: String) = {
-    val rdd: RDD[InternalRow] =
-      // Checkpoint the DataFrame to fix the partition ID.
-      localCheckpoint(false)
-        .queryExecution.toRdd.zipWithIndex().mapPartitions { iter =>
-        val joinedRow = new JoinedRow
-        val unsafeRowWriter =
-          new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1)
-
-        iter.map { case (row, id) =>
-          // Writes to an UnsafeRow directly
-          unsafeRowWriter.reset()
-          unsafeRowWriter.write(0, id)
-          joinedRow(unsafeRowWriter.getRow, row)
-        }
-      }
-
-    sparkSession.internalCreateDataFrame(
-      rdd, StructType(StructField(name, LongType, nullable = false) +: schema), isStreaming)
+    Dataset.ofRows(
+      sparkSession,
+      AttachDistributedSequence(
+        AttributeReference(name, LongType, nullable = false)(),
+        logicalPlan))
}
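The practical effect of this rewrite: the old code ran zipWithIndex, which launches its own Spark job, while the DataFrame was being built, whereas the new code only attaches a logical node and defers the work to execution. A hedged sketch of the expected behavior (the job-count claims are an assumption based on this reading of the change):

import pyspark.pandas as ps

ps.set_option("compute.default_index_type", "distributed-sequence")

psdf = ps.DataFrame({"a": range(10)})  # with this change, no Spark job is expected yet
psdf.spark.explain()                   # the plan should contain AttachDistributedSequence
print(psdf.head(3))                    # the sequence is attached when the plan actually runs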

/**
@@ -754,6 +754,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
func, output, planLater(left), planLater(right)) :: Nil
case logical.MapInPandas(func, output, child) =>
execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
case logical.AttachDistributedSequence(attr, child) =>
execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>
execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
case logical.AppendColumns(f, _, _, in, out, child) =>
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.python

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}

/**
* A physical plan that adds a new long column with `sequenceAttr` that
* increases one by one. This is for 'distributed-sequence' default index
* in pandas API on Spark.
*/
case class AttachDistributedSequenceExec(
Member Author: We could think about implementing this with an expression (like a Python UDF or Window), but I just decided to do it with plans to avoid making it too complicated.
sequenceAttr: Attribute,
child: SparkPlan)
extends UnaryExecNode {

override def producedAttributes: AttributeSet = AttributeSet(sequenceAttr)

override val output: Seq[Attribute] = sequenceAttr +: child.output

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def doExecute(): RDD[InternalRow] = {
child.execute().map(_.copy())
Contributor: Why do we need to copy the unsafe rows before calling localCheckpoint?

Member Author (HyukjinKwon, Aug 24, 2021): Oh, I forgot to describe it. localCheckpoint caches (persists) the data, and since it stores the rows, they need to be copied first. This is actually done in Dataset.checkpoint as well: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L679

.localCheckpoint() // to avoid executing multiple jobs. zipWithIndex launches a Spark job.
Member Author: I am still not sure we need the localCheckpoint in the middle here, but let me keep it as is for now.

Member Author: E.g., if the child RDD has a shuffle, the shuffle would be triggered twice; this checkpoint is there to avoid that.

Contributor: The shuffle will be reused. I think localCheckpoint is useful to save computation: e.g., for df.sort(...).withSequenceColumn, if we don't localCheckpoint, the shuffle is still done only once, but the local sort after the shuffle will be done twice.

.zipWithIndex().mapPartitions { iter =>
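// Flattens the JoinedRow (generated sequence id + the original columns) into a single
// UnsafeRow; applied below via .map(unsafeProj).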
val unsafeProj = UnsafeProjection.create(output, output)
val joinedRow = new JoinedRow
val unsafeRowWriter =
new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1)

iter.map { case (row, id) =>
// Writes to an UnsafeRow directly
unsafeRowWriter.reset()
unsafeRowWriter.write(0, id)
joinedRow(unsafeRowWriter.getRow, row)
}.map(unsafeProj)
}
}

override protected def withNewChildInternal(newChild: SparkPlan): AttachDistributedSequenceExec =
copy(child = newChild)
}
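To make the localCheckpoint discussion in the review thread above concrete: zipWithIndex launches its own job just to compute per-partition offsets, so without caching or checkpointing, the computation feeding it runs twice (as noted above, shuffle outputs are reused, but work after the shuffle is repeated). A minimal RDD-level sketch (expensive_fn is a hypothetical stand-in for the child plan's work, e.g. a local sort after a shuffle):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

def expensive_fn(x):
    # Hypothetical stand-in for the child plan's per-row computation.
    return x * 2

rdd = spark.sparkContext.parallelize(range(100), 4).map(expensive_fn)
indexed = rdd.zipWithIndex()  # job 1: computes per-partition counts, running expensive_fn
print(indexed.take(3))        # job 2: without caching/localCheckpoint, expensive_fn runs again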