[SPARK-24117][SQL] Unified the getSizePerRow
## What changes were proposed in this pull request?

This PR unifies the duplicated per-row size estimation into a single `getSizePerRow` helper, since the same computation currently appears in several places (a sketch of the unified helper follows the list below). For example:

1. [LocalRelation.scala#L80](https://github.com/wangyum/spark/blob/f70f46d1e5bc503e9071707d837df618b7696d32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala#L80)
2. [SizeInBytesOnlyStatsPlanVisitor.scala#L36](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala#L36)
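
For reference, a minimal sketch of the unified helper, simplified from the actual `EstimationUtils.getSizePerRow` in the diff below (it ignores the branch that consults per-column `avgLen` statistics when they are available):

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute

// Simplified sketch: 8 bytes of generic row overhead plus each column's default size.
// The real helper also uses per-column statistics (avgLen) when they exist.
def getSizePerRow(attributes: Seq[Attribute]): BigInt = {
  BigInt(8) + attributes.map(_.dataType.defaultSize).sum
}

// Callers such as LocalRelation.computeStats can then estimate
//   sizeInBytes = getSizePerRow(output) * rowCount
```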

## How was this patch tested?
Existing tests.
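
The changed expectations in `StatisticsCollectionSuite` and `MemorySinkSuite` follow from the 8-byte per-row overhead that `getSizePerRow` now adds for these plans. A sketch of the arithmetic for the memory sink case, assuming a single `Int` column and `IntegerType.defaultSize == 4`:

```scala
// Worked numbers behind the MemorySinkSuite expectations (assumed schema: one Int column).
val sizePerRow = 8 + 4        // 8-byte generic row overhead + IntegerType.defaultSize
assert(sizePerRow * 3 == 36)  // after sink.addBatch(0, 1 to 3)
assert(sizePerRow * 6 == 72)  // after sink.addBatch(1, 4 to 6)
```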

Author: Yuming Wang <[email protected]>

Closes #21189 from wangyum/SPARK-24117.
wangyum authored and cloud-fan committed May 8, 2018
1 parent 2f6fe7d commit 487faf1
Showing 7 changed files with 21 additions and 19 deletions.
LocalRelation.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.types.{StructField, StructType}

object LocalRelation {
@@ -77,7 +78,7 @@ case class LocalRelation(
}

override def computeStats(): Statistics =
Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length)

def toSQL(inlineTableName: String): String = {
require(data.nonEmpty)
EstimationUtils.scala
@@ -17,15 +17,13 @@

package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.math.BigDecimal.RoundingMode

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{DecimalType, _}


object EstimationUtils {

/** Check if each plan has rowCount in its statistics. */
@@ -73,13 +71,12 @@ object EstimationUtils {
AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
}

def getOutputSize(
def getSizePerRow(
attributes: Seq[Attribute],
outputRowCount: BigInt,
attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
// We assign a generic overhead for a Row object, the actual overhead is different for different
// Row format.
val sizePerRow = 8 + attributes.map { attr =>
8 + attributes.map { attr =>
if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) {
attr.dataType match {
case StringType =>
@@ -92,10 +89,15 @@
attr.dataType.defaultSize
}
}.sum
}

def getOutputSize(
attributes: Seq[Attribute],
outputRowCount: BigInt,
attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
// Output size can't be zero, or sizeInBytes of BinaryNode will also be zero
// (simple computation of statistics returns product of children).
if (outputRowCount > 0) outputRowCount * sizePerRow else 1
if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1
}

/**
SizeInBytesOnlyStatsPlanVisitor.scala
@@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
private def visitUnaryNode(p: UnaryNode): Statistics = {
// There should be some overhead in Row object, the size should not be zero when there is
// no columns, this help to prevent divide-by-zero error.
val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8
val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8
val childRowSize = EstimationUtils.getSizePerRow(p.child.output)
val outputRowSize = EstimationUtils.getSizePerRow(p.output)
// Assume there will be the same number of rows as child has.
var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize
if (sizeInBytes == 0) {
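As an illustration of the unary-node estimate above: the child's `sizeInBytes` is scaled by the ratio of output row size to child row size. A small numeric sketch with hypothetical values, assuming default sizes of 4 for `IntegerType` and 20 for `StringType`:

```scala
// Hypothetical example: a projection that drops a String column.
val childRowSize  = 8 + 4 + 20   // row overhead + Int column + String column
val outputRowSize = 8 + 4        // row overhead + Int column only
val childSizeInBytes = BigInt(3200)
// 3200 * 12 / 32 = 1200
val estimated = childSizeInBytes * outputRowSize / childRowSize
assert(estimated == 1200)
```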
memory.scala
@@ -24,23 +24,21 @@ import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.{OutputMode, Trigger}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils


object MemoryStream {
protected val currentBlockId = new AtomicInteger(0)
protected val memoryStreamId = new AtomicInteger(0)
@@ -307,7 +305,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)

private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum
private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes)

override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
}
memoryV2.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
@@ -182,7 +183,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode)
* Used to query the data that has been written into a [[MemorySinkV2]].
*/
case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
private val sizePerRow = output.map(_.dataType.defaultSize).sum
private val sizePerRow = EstimationUtils.getSizePerRow(output)

override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
}
StatisticsCollectionSuite.scala
@@ -50,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
}

assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
assert(sizes.head === BigInt(96),
assert(sizes.head === BigInt(128),
s"expected exact size 96 for table 'test', got: ${sizes.head}")
}
}
MemorySinkSuite.scala
@@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {

sink.addBatch(0, 1 to 3)
plan.invalidateStatsCache()
assert(plan.stats.sizeInBytes === 12)
assert(plan.stats.sizeInBytes === 36)

sink.addBatch(1, 4 to 6)
plan.invalidateStatsCache()
assert(plan.stats.sizeInBytes === 24)
assert(plan.stats.sizeInBytes === 72)
}

ignore("stress test") {