Skip to content

Commit

Permalink
[SNAP-1067] Optimizations seen in perf analysis related to SnappyData…
Browse files Browse the repository at this point in the history
… PR#381 (#11)

 - added hashCode/equals to UnsafeMapData and optimized hashing/equals for Decimal
   (assuming the scale is the same for both operands, as in the calls from the Spark layer)
 - optimizations to UTF8String: cached "isAscii" and "hash"
 - more efficient ByteArrayMethods.arrayEquals (~3ns vs ~9ns for 15 byte array)
 - reverting aggregate attribute changes (nullability optimization) from Spark layer and instead take care of it on the SnappyData layer; also reverted other changes in HashAggregateExec made earlier for AQP and nullability
 - copy spark-version-info.properties in the generateSources target for IDEA
Conflicts:
	common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java

Conflicts:
	common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
	sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
  • Loading branch information
Sumedh Wale authored and ymahajan committed Mar 2, 2018
1 parent 0c52ebd commit 59d8076
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 90 deletions.
7 changes: 7 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,13 @@ subprojects {
task generateSources {
  // Depends on the Catalyst grammar generation and the flume-sink Avro generation tasks.
  dependsOn subprojectBase + 'snappy-spark-catalyst_' + scalaBinaryVersion + ':generateGrammarSource'
  dependsOn subprojectBase + 'snappy-spark-streaming-flume-sink_' + scalaBinaryVersion + ':generateAvroJava'
  // copy extra-resources in normal resource path for IDEA
  def coreProject = project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion)
  // Perform the copy as a task action. A bare copy { } placed directly in the task
  // body is executed while the build is being configured, i.e. on every Gradle
  // invocation and regardless of whether this task was requested.
  // NOTE(review): nothing visible here makes this task depend on whatever produces
  // extra-resources/spark-version-info.properties -- confirm it exists by the time
  // this action runs.
  doLast {
    copy {
      from "${coreProject.buildDir}/extra-resources"
      include 'spark-version-info.properties'
      into "${coreProject.buildDir}/resources/main"
    }
  }
}

if (rootProject.name == 'snappy-spark') {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,22 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) {
public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15;

private static final boolean unaligned = Platform.unaligned();

/**
* Optimized byte array equality check for byte arrays.
* @return true if the arrays are equal, false otherwise
*/
public static boolean arrayEquals(
Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) {
int i = 0;

// check if stars align and we can get both offsets to be aligned
if ((leftOffset % 8) == (rightOffset % 8)) {
while ((leftOffset + i) % 8 != 0 && i < length) {
if (Platform.getByte(leftBase, leftOffset + i) !=
Platform.getByte(rightBase, rightOffset + i)) {
return false;
}
i += 1;
final Object leftBase, long leftOffset, final Object rightBase,
long rightOffset, final long length) {
long endOffset = leftOffset + length - 8;
while (leftOffset <= endOffset) {
if (Platform.getLong(leftBase, leftOffset) !=
Platform.getLong(rightBase, rightOffset)) {
return false;
}
leftOffset += 8;
rightOffset += 8;
}
// for architectures that support unaligned accesses, chew it up 8 bytes at a time
if (unaligned || (((leftOffset + i) % 8 == 0) && ((rightOffset + i) % 8 == 0))) {
Expand All @@ -75,15 +74,17 @@ public static boolean arrayEquals(
}
i += 8;
}
leftOffset += 4;
rightOffset += 4;
}
// this will finish off the unaligned comparisons, or do the entire aligned
// comparison whichever is needed.
while (i < length) {
if (Platform.getByte(leftBase, leftOffset + i) !=
Platform.getByte(rightBase, rightOffset + i)) {
return false;
endOffset += 4;
while (leftOffset < endOffset) {
if (Platform.getByte(leftBase, leftOffset) !=
Platform.getByte(rightBase, rightOffset)) {
return false;
}
i += 1;
leftOffset++;
rightOffset++;
}
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
private long offset;
private int numBytes;

private transient int hash;
private transient boolean isAscii;

public Object getBaseObject() { return base; }
public long getBaseOffset() { return offset; }

Expand Down Expand Up @@ -187,6 +190,7 @@ public void writeTo(OutputStream out) throws IOException {
* @param b The first byte of a code point
*/
private static int numBytesForFirstByte(final byte b) {
  // A non-negative byte has its high bit clear: it is a single-byte code point.
  if (b >= 0) {
    return 1;
  }
  // The lookup table is indexed relative to 192 (0xC0); anything below that
  // is counted as one byte.
  final int tableIndex = (b & 0xFF) - 192;
  if (tableIndex < 0) {
    return 1;
  }
  return bytesOfCodePointInUTF8[tableIndex];
}
Expand All @@ -202,10 +206,14 @@ public int numBytes() {
* Returns the number of code points in it.
*/
public int numChars() {
if (isAscii) return numBytes;
final long endOffset = offset + numBytes;
int len = 0;
for (int i = 0; i < numBytes; i += numBytesForFirstByte(getByte(i))) {
len += 1;
for (long offset = this.offset; offset < endOffset;
offset += numBytesForFirstByte(Platform.getByte(base, offset))) {
len++;
}
if (len == numBytes) isAscii = true;
return len;
}

Expand Down Expand Up @@ -332,7 +340,7 @@ public boolean contains(final UTF8String substring) {
/**
* Returns the byte at position `i`.
*/
private byte getByte(int i) {
public byte getByte(int i) {
return Platform.getByte(base, offset + i);
}

Expand Down Expand Up @@ -1256,6 +1264,12 @@ public boolean equals(final Object other) {
}
}

/**
 * Type-specific equality check against another UTF8String.
 *
 * @param o the string to compare with; may be null
 * @return true when {@code o} is non-null, has the same number of bytes, and
 *         ByteArrayMethods.arrayEquals reports the two byte ranges as equal
 */
public boolean equals(final UTF8String o) {
  if (o == null) {
    return false;
  }
  final int len = this.numBytes;
  if (len != o.numBytes) {
    return false;
  }
  return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, len);
}

/**
* Levenshtein distance is a metric for measuring the distance of two strings. The distance is
* defined by the minimum number of single-character edits (i.e. insertions, deletions or
Expand Down Expand Up @@ -1322,7 +1336,10 @@ public int levenshteinDistance(UTF8String other) {

@Override
public int hashCode() {
return Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, 42);
final int h = this.hash;
if (h != 0) return h;
return (this.hash = Murmur3_x86_32.hashUnsafeBytes(
base, offset, numBytes, 42));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit

override lazy val aggBufferAttributes = sum :: count :: Nil

override lazy val aggBufferAttributesForGroup: Seq[AttributeReference] = {
if (child.nullable) aggBufferAttributes
else sum.copy(nullable = false)(sum.exprId, sum.qualifier,
sum.isGenerated) :: count :: Nil
}

override lazy val initialValues = Seq(
/* sum = */ Cast(Literal(0), sumDataType),
/* count = */ Literal(0L)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,10 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast

override lazy val aggBufferAttributes = sum :: Nil

override lazy val aggBufferAttributesForGroup: Seq[AttributeReference] = {
if (child.nullable) aggBufferAttributes
else sum.copy(nullable = false)(sum.exprId, sum.qualifier,
sum.isGenerated) :: Nil
}

override lazy val initialValues: Seq[Expression] = Seq(
/* sum = */ Literal.create(null, sumDataType)
)

override lazy val initialValuesForGroup: Seq[Expression] = Seq(
/* sum = */
if (child.nullable) Literal.create(null, sumDataType)
else Cast(Literal(0), sumDataType)
)

override lazy val updateExpressions: Seq[Expression] = {
if (child.nullable) {
Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,6 @@ abstract class AggregateFunction extends Expression {
/** Attributes of fields in aggBufferSchema. */
def aggBufferAttributes: Seq[AttributeReference]

/** Attributes of fields in aggBufferSchema used for group by. */
def aggBufferAttributesForGroup: Seq[AttributeReference] = aggBufferAttributes

/**
* Attributes of fields in input aggregation buffers (immutable aggregation buffers that are
* merged with mutable aggregation buffers in the merge() function or merge expressions).
Expand Down Expand Up @@ -378,11 +375,6 @@ abstract class DeclarativeAggregate
*/
val initialValues: Seq[Expression]

/**
* Expressions for initializing empty aggregation buffers for group by.
*/
def initialValuesForGroup: Seq[Expression] = initialValues

/**
* Expressions for updating the mutable aggregation buffer based on an input row.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,14 @@ object DecimalAggregates extends Rule[LogicalPlan] {
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone))

case Max(e @ DecimalType.Expression(prec, scale)) if prec <= MAX_LONG_DIGITS =>
MakeDecimal(we.copy(windowFunction = ae.copy(
aggregateFunction = Max(UnscaledValue(e)))), prec, scale)

case Min(e @ DecimalType.Expression(prec, scale)) if prec <= MAX_LONG_DIGITS =>
MakeDecimal(we.copy(windowFunction = ae.copy(
aggregateFunction = Min(UnscaledValue(e)))), prec, scale)

case _ => we
}
case ae @ AggregateExpression(af, _, _, _) => af match {
Expand All @@ -1170,6 +1178,12 @@ object DecimalAggregates extends Rule[LogicalPlan] {
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone))

case Max(e @ DecimalType.Expression(prec, scale)) if prec <= MAX_LONG_DIGITS =>
MakeDecimal(ae.copy(aggregateFunction = Max(UnscaledValue(e))), prec, scale)

case Min(e @ DecimalType.Expression(prec, scale)) if prec <= MAX_LONG_DIGITS =>
MakeDecimal(ae.copy(aggregateFunction = Min(UnscaledValue(e))), prec, scale)

case _ => ae
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
/**
 * Returns this value as a `java.math.BigInteger`, built from the result of `toLong`.
 * NOTE(review): how a fractional part or a value outside the Long range is handled
 * is whatever `toLong` does -- confirm against its definition.
 */
def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong)

def toUnscaledLong: Long = {
if (decimalVal.ne(null)) {
decimalVal.underlying().unscaledValue().longValueExact()
} else {
if (decimalVal eq null) {
longVal
} else {
decimalVal.underlying().unscaledValue().longValueExact()
}
}

Expand Down Expand Up @@ -339,14 +339,41 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}

override def equals(other: Any): Boolean = other match {
case d: Decimal =>
compare(d) == 0
case _ =>
false
case d: Decimal => equals(d)
case _ => false
}

/** Hash code of this value, delegated to the hash code of `toBigDecimal`. */
override def hashCode(): Int = toBigDecimal.hashCode()

/**
 * Equality against another [[Decimal]].
 *
 * Returns false for `null`. When both sides hold only a `longVal` and share the
 * same `_scale`, the unscaled longs are compared directly. In every other case
 * the two values are compared numerically as `java.math.BigDecimal` using
 * `compareTo`, so values that differ only in scale (1.0 vs 1.00) are equal.
 * `java.math.BigDecimal.equals` is deliberately not used: it also compares the
 * scale, which would make this method disagree with a `compare(d) == 0` check.
 *
 * @param other the decimal to compare with; may be null
 * @return true when both decimals represent the same numeric value
 */
def equals(other: Decimal): Boolean = {
  if (other eq null) {
    false
  } else {
    val thisBig = this.decimalVal
    val otherBig = other.decimalVal
    if (thisBig eq null) {
      if (otherBig eq null) {
        // Fast path: both compact with identical scale, compare unscaled longs.
        if (_scale == other._scale) longVal == other.longVal
        else toJavaBigDecimal.compareTo(other.toJavaBigDecimal) == 0
      } else {
        toJavaBigDecimal.compareTo(otherBig.bigDecimal) == 0
      }
    } else if (otherBig ne null) {
      thisBig.bigDecimal.compareTo(otherBig.bigDecimal) == 0
    } else {
      thisBig.bigDecimal.compareTo(other.toJavaBigDecimal) == 0
    }
  }
}

/**
 * Hash code computed directly from the current internal representation.
 *
 * If `decimalVal` is set, returns the hash of its underlying
 * `java.math.BigDecimal`; otherwise folds the 64-bit `longVal` into 32 bits by
 * XOR-ing its upper and lower halves.
 *
 * NOTE(review): the two branches use different formulas and the long branch
 * ignores `_scale`, so the same numeric value may hash differently depending on
 * which representation holds it -- confirm callers only mix like with like.
 */
def fastHashCode(): Int = {
  val big = this.decimalVal
  if (big eq null) {
    val unscaled = this.longVal
    (unscaled ^ (unscaled >>> 32)).toInt
  } else {
    big.bigDecimal.hashCode()
  }
}

/** True when this value is zero, checked against whichever representation is in use. */
def isZero: Boolean = {
  if (decimalVal eq null) longVal == 0
  else decimalVal == BIG_DEC_ZERO
}

def + (that: Decimal): Decimal = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ object AggUtils {
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
__resultExpressions = resultExpressions,
resultExpressions = resultExpressions,
child = child)
} else {
val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Changes for SnappyData data platform.
*
* Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You
* may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License. See accompanying
* LICENSE file.
*/

package org.apache.spark.sql.execution.aggregate

Expand Down Expand Up @@ -60,20 +42,14 @@ case class HashAggregateExec(
aggregateExpressions: Seq[AggregateExpression],
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
__resultExpressions: Seq[NamedExpression],
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryExecNode with CodegenSupport {

@transient lazy val resultExpressions = __resultExpressions

@transient lazy private[this] val aggregateBufferAttributes = {
private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}

@transient lazy private[this] val aggregateBufferAttributesForGroup = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributesForGroup)
}

require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))

override lazy val allAttributes: AttributeSeq =
Expand Down Expand Up @@ -320,7 +296,7 @@ case class HashAggregateExec(
// The aggregate functions of `aggregateExpressions` that are DeclarativeAggregate,
// in their original order. A single `collect` replaces the isInstanceOf filter
// followed by an asInstanceOf cast.
private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
  .collect { case declarative: DeclarativeAggregate => declarative }
private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributesForGroup)
private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)

// The name for Fast HashMap
// `_` default-initialises the field (null for a String).
private var fastHashMapTerm: String = _
Expand All @@ -340,7 +316,7 @@ b private var fastHashMapTerm: String = _
*/
def createHashMap(): UnsafeFixedWidthAggregationMap = {
// create initialized aggregate buffer
val initExpr = declFunctions.flatMap(_.initialValuesForGroup)
val initExpr = declFunctions.flatMap(f => f.initialValues)
val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow)

// create hashMap
Expand Down Expand Up @@ -409,7 +385,7 @@ b private var fastHashMapTerm: String = _
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
val mergeProjection = newMutableProjection(
mergeExpr,
aggregateBufferAttributesForGroup ++ declFunctions.flatMap(_.inputAggBufferAttributes),
aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
subexpressionEliminationEnabled)
val joinedRow = new JoinedRow()

Expand Down Expand Up @@ -477,14 +453,14 @@ b private var fastHashMapTerm: String = _
}
val evaluateKeyVars = evaluateVariables(keyVars)
ctx.INPUT_ROW = bufferTerm
val bufferVars = aggregateBufferAttributesForGroup.zipWithIndex.map { case (e, i) =>
val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).genCode(ctx)
}
val evaluateBufferVars = evaluateVariables(bufferVars)
// evaluate the aggregation result
ctx.currentVars = bufferVars
val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
BindReferences.bindReference(e, aggregateBufferAttributesForGroup).genCode(ctx)
BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
}
val evaluateAggResults = evaluateVariables(aggResults)
// generate the final result
Expand Down Expand Up @@ -779,8 +755,8 @@ b private var fastHashMapTerm: String = _
val hashExpr = Murmur3Hash(groupingExpressions, 42)
val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx)

val inputAttr = aggregateBufferAttributesForGroup ++ child.output
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributesForGroup.length) ++ input
val inputAttr = aggregateBufferAttributes ++ child.output
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input

val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
incCounter) = if (testFallbackStartsAt.isDefined) {
Expand Down
Loading

0 comments on commit 59d8076

Please sign in to comment.