[SPARK-24117][SQL] Unified the getSizePerRow #21189

Closed · wants to merge 2 commits
@@ -21,6 +21,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.types.{StructField, StructType}

object LocalRelation {
@@ -77,7 +78,7 @@ case class LocalRelation(
}

override def computeStats(): Statistics =
Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length)
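To make the effect concrete, the sketch below is illustration only, not Spark source: the two-column schema and the row count are assumptions, with 4 and 20 mirroring Catalyst's defaultSize for IntegerType and StringType. It compares the old per-column sum with the new estimate from getSizePerRow, which adds a generic 8-byte Row-object overhead per row.

```scala
object LocalRelationSizeSketch {
  // Hypothetical two-column schema: (id: Int, name: String).
  // 4 and 20 mirror Catalyst's defaultSize for IntegerType and StringType.
  val columnDefaultSizes: Seq[Int] = Seq(4, 20)
  val rowCount: Int = 10

  // Old estimate: sum of column default sizes only.
  val oldSizeInBytes: BigInt = BigInt(columnDefaultSizes.sum) * rowCount     // 240

  // New estimate: getSizePerRow adds a generic 8-byte overhead for the
  // Row object before summing the column sizes.
  val newSizeInBytes: BigInt = BigInt(8 + columnDefaultSizes.sum) * rowCount // 320

  def main(args: Array[String]): Unit =
    println(s"old = $oldSizeInBytes, new = $newSizeInBytes")
}
```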

def toSQL(inlineTableName: String): String = {
require(data.nonEmpty)
@@ -17,15 +17,13 @@

package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.math.BigDecimal.RoundingMode

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{DecimalType, _}


object EstimationUtils {

/** Check if each plan has rowCount in its statistics. */
@@ -73,13 +71,12 @@
AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
}

def getOutputSize(
def getSizePerRow(
attributes: Seq[Attribute],
outputRowCount: BigInt,
attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
// We assign a generic overhead for a Row object, the actual overhead is different for different
// Row format.
val sizePerRow = 8 + attributes.map { attr =>
8 + attributes.map { attr =>
if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) {
attr.dataType match {
case StringType =>
@@ -92,10 +89,15 @@
attr.dataType.defaultSize
}
}.sum
}

def getOutputSize(
attributes: Seq[Attribute],
outputRowCount: BigInt,
attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
// Output size can't be zero, or sizeInBytes of BinaryNode will also be zero
// (simple computation of statistics returns product of children).
if (outputRowCount > 0) outputRowCount * sizePerRow else 1
if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1
}
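Since the refactoring is split across two hunks, here is a condensed, self-contained sketch of the resulting shape of the two methods. It is illustration only: Catalyst's Attribute, AttributeMap, and the StringType-specific handling of avgLen are replaced with a simplified stand-in.

```scala
// Simplified stand-in for a Catalyst Attribute, for illustration only.
case class Attr(name: String, defaultSize: Int, avgLen: Option[Long] = None)

object EstimationUtilsSketch {
  // Mirrors the new getSizePerRow: a generic 8-byte Row-object overhead plus a
  // per-attribute size (average length from column stats when available,
  // otherwise the data type's default size; string-specific padding elided).
  def getSizePerRow(attributes: Seq[Attr]): BigInt =
    8 + attributes.map(attr => BigInt(attr.avgLen.getOrElse(attr.defaultSize.toLong))).sum

  // Mirrors the new getOutputSize: row count times per-row size, floored at 1
  // so a BinaryNode's sizeInBytes (a product of children) can never become 0.
  def getOutputSize(attributes: Seq[Attr], outputRowCount: BigInt): BigInt =
    if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes) else 1
}
```

For example, a (id: Int, name: String) schema with no column stats gives a per-row size of 8 + 4 + 20 = 32 under this sketch.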

/**
@@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
private def visitUnaryNode(p: UnaryNode): Statistics = {
// There should be some overhead in Row object, the size should not be zero when there is
// no columns, this help to prevent divide-by-zero error.
val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8
val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8
val childRowSize = EstimationUtils.getSizePerRow(p.child.output)
val outputRowSize = EstimationUtils.getSizePerRow(p.output)
// Assume there will be the same number of rows as child has.
var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize
if (sizeInBytes == 0) {
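visitUnaryNode keeps its existing formula: assume the same row count as the child and rescale the child's sizeInBytes by the ratio of output row width to child row width. The only change is that both widths now come from getSizePerRow and therefore include the 8-byte overhead. A worked example with made-up sizes:

```scala
object UnaryNodeScalingSketch extends App {
  // Hypothetical projection keeping only the Int column of an (Int, Long) child.
  val childRowSize  = 8 + 4 + 8 // Row overhead + IntegerType + LongType = 20
  val outputRowSize = 8 + 4     // Row overhead + IntegerType            = 12

  val childSizeInBytes = BigInt(2000)
  // Same shape as visitUnaryNode: same row count as the child, rescaled widths.
  var sizeInBytes = (childSizeInBytes * outputRowSize) / childRowSize // = 1200
  if (sizeInBytes == 0) sizeInBytes = 1 // keep the estimate non-zero

  println(s"estimated output size = $sizeInBytes bytes")
}
```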
@@ -24,23 +24,21 @@ import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.{OutputMode, Trigger}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils


object MemoryStream {
protected val currentBlockId = new AtomicInteger(0)
protected val memoryStreamId = new AtomicInteger(0)
@@ -307,7 +305,7 @@
case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)

private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum
private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes)

override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
}
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
@@ -182,7 +183,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode)
* Used to query the data that has been written into a [[MemorySinkV2]].
*/
case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
private val sizePerRow = output.map(_.dataType.defaultSize).sum
gatorsmile (Member), Apr 30, 2018: (comment body not loaded)

Contributor: I wouldn't think it's possible.

private val sizePerRow = EstimationUtils.getSizePerRow(output)

override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
}
@@ -50,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
}

assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
assert(sizes.head === BigInt(96),
assert(sizes.head === BigInt(128),
s"expected exact size 96 for table 'test', got: ${sizes.head}")
}
}
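The expected value moves from 96 to 128 because the table's size estimate now includes the 8-byte per-row overhead. Assuming, purely for illustration, a two-column (Int, String) schema and 4 rows, which is consistent with the old expectation of 96, the arithmetic is:

```scala
// Illustrative arithmetic only; the schema and row count are assumptions
// chosen to be consistent with the old expected value of 96.
val oldPerRow = 4 + 20        // IntegerType + StringType default sizes
val newPerRow = 8 + oldPerRow // plus the generic Row-object overhead
val rowCount  = 4
assert(oldPerRow * rowCount == 96)  // old estimate
assert(newPerRow * rowCount == 128) // new estimate
```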
@@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {

sink.addBatch(0, 1 to 3)
plan.invalidateStatsCache()
assert(plan.stats.sizeInBytes === 12)
assert(plan.stats.sizeInBytes === 36)

sink.addBatch(1, 4 to 6)
plan.invalidateStatsCache()
assert(plan.stats.sizeInBytes === 24)
assert(plan.stats.sizeInBytes === 72)
Member: MemorySinkV2 is mainly for testing. I think the stats changes will not impact anything, right? @tdas @jose-torres

Contributor: It shouldn't impact anything, but abstractly it seems strange that this unification would cause the stats to change? What are we doing differently to cause this, and how confident are we this won't happen to production sinks?

Contributor: It seems we forgot to count the row object overhead (8 bytes) before in memory stream.

Contributor: SGTM then
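The new expectations follow directly from that overhead. The rows added by this test are single Int values (defaultSize 4), as the old expectations of 12 and 24 imply, so each row now counts 8 + 4 = 12 bytes instead of 4:

```scala
// Rows in this test are single Int values (defaultSize 4), as the old
// expectations of 12 and 24 imply.
val oldPerRow = 4     // old: column default size only
val newPerRow = 8 + 4 // new: plus the generic Row-object overhead
assert(oldPerRow * 3 == 12 && newPerRow * 3 == 36) // after addBatch(0, 1 to 3)
assert(oldPerRow * 6 == 24 && newPerRow * 6 == 72) // after addBatch(1, 4 to 6)
```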

}

ignore("stress test") {