[SPARK-23381][CORE] Murmur3 hash generates a different value from other implementations

## What changes were proposed in this pull request?
Murmur3 hash generates a different value from the original and other implementations (like Scala standard library and Guava or so) when the length of a bytes array is not multiple of 4.
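As a quick illustration (not part of this PR; the class name and the sample input are only assumptions), a sketch comparing the existing `hashUnsafeBytes`, the new `hashUnsafeBytes2`, and the Scala standard library on a 3-byte input might look like this:

```java
import java.nio.charset.StandardCharsets;

import scala.util.hashing.MurmurHash3$;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;

// Hypothetical demo class, not part of this PR.
public class TailBytesDemo {
  public static void main(String[] args) {
    // 3 bytes, so the tail handling (length % 4 != 0) kicks in.
    byte[] tes = "tes".getBytes(StandardCharsets.UTF_8);

    // Reference value from the Scala standard library.
    int reference = MurmurHash3$.MODULE$.bytesHash(tes, 0);

    // Pre-2.3 method: diverges on inputs whose length is not a multiple of 4.
    int legacy = Murmur3_x86_32.hashUnsafeBytes(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0);

    // Method added by this PR: packs the trailing bytes into one word, matching the reference.
    int fixed = Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0);

    System.out.println(reference == fixed);   // expected: true
    System.out.println(reference == legacy);  // expected: false
  }
}
```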

## How was this patch tested?
Added a unit test.

**Note: When we merge this PR, please give all the credit to Shintaro Murakami.**

Author: Shintaro Murakami <mrkm4ntr@gmail.com>

Author: gatorsmile <[email protected]>
Author: Shintaro Murakami <[email protected]>

Closes #20630 from gatorsmile/pr-20568.
mrkm4ntr authored and gatorsmile committed Feb 17, 2018
1 parent 0a73aa3 commit d5ed210
Showing 7 changed files with 96 additions and 5 deletions.
@@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i
}

public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
// This is not compatible with the original and other implementations.
// It is kept for backward compatibility with components that existed before Spark 2.3.
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
@@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i
return fmix(h1, lengthInBytes);
}

public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) {
// This is compatible with the original and other implementations.
// Use this method for new components added after Spark 2.3.
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
int k1 = 0;
for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) {
k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift;
}
h1 ^= mixK1(k1);
return fmix(h1, lengthInBytes);
}

private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
assert (lengthInBytes % 4 == 0);
int h1 = seed;
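A side note on why the new method matches the reference implementations: the loop in `hashUnsafeBytes2` packs the 1–3 trailing bytes into a single little-endian word and feeds it through one `mixK1` round, instead of mixing each trailing byte separately as the pre-2.3 `hashUnsafeBytes` path does (that part of the old loop is elided in this diff). A minimal, self-contained sketch of just the packing step (the class name and hard-coded bytes are illustrative only):

```java
public class TailPackingTrace {
  public static void main(String[] args) {
    byte[] tail = {0x74, 0x65, 0x73};  // the ASCII bytes of "tes"
    int k1 = 0;
    for (int i = 0, shift = 0; i < tail.length; i++, shift += 8) {
      // Same packing as the loop in hashUnsafeBytes2: little-endian, zero-extended bytes.
      k1 ^= (tail[i] & 0xFF) << shift;
    }
    System.out.printf("k1 = 0x%08x%n", k1);  // prints k1 = 0x00736574
  }
}
```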
@@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i
}

public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
// This is not compatible with the original and other implementations.
// It is kept for backward compatibility with components that existed before Spark 2.3.
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
@@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i
return fmix(h1, lengthInBytes);
}

public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) {
// This is compatible with the original and other implementations.
// Use this method for new components added after Spark 2.3.
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
int k1 = 0;
for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) {
k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift;
}
h1 ^= mixK1(k1);
return fmix(h1, lengthInBytes);
}

private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
assert (lengthInBytes % 4 == 0);
int h1 = seed;
@@ -22,6 +22,8 @@
import java.util.Random;
import java.util.Set;

import scala.util.hashing.MurmurHash3$;

import org.apache.spark.unsafe.Platform;
import org.junit.Assert;
import org.junit.Test;
@@ -51,6 +53,23 @@ public void testKnownLongInputs() {
Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE));
}

// SPARK-23381: Check whether the hash of a byte array is the same as other implementations
@Test
public void testKnownBytesInputs() {
byte[] test = "test".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test, 0),
Murmur3_x86_32.hashUnsafeBytes2(test, Platform.BYTE_ARRAY_OFFSET, test.length, 0));
byte[] test1 = "test1".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test1, 0),
Murmur3_x86_32.hashUnsafeBytes2(test1, Platform.BYTE_ARRAY_OFFSET, test1.length, 0));
byte[] te = "te".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(te, 0),
Murmur3_x86_32.hashUnsafeBytes2(te, Platform.BYTE_ARRAY_OFFSET, te.length, 0));
byte[] tes = "tes".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(tes, 0),
Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0));
}

@Test
public void randomizedStressTest() {
int size = 65536;
@@ -17,6 +17,7 @@

package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
@@ -28,6 +29,8 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashMap

@@ -138,7 +141,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme

@Since("2.3.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val hashFunc: Any => Int = OldHashingTF.murmur3Hash
val hashFunc: Any => Int = FeatureHasher.murmur3Hash
val n = $(numFeatures)
val localInputCols = $(inputCols)
val catCols = if (isSet(categoricalCols)) {
@@ -218,4 +221,32 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] {

@Since("2.3.0")
override def load(path: String): FeatureHasher = super.load(path)

private val seed = OldHashingTF.seed

/**
* Calculate a hash code value for the term object using
* Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32).
* This is the default hash algorithm used from Spark 2.0 onwards.
* Strings are hashed with hashUnsafeBytes2 so that the result matches the original
* algorithm and other implementations (see SPARK-23381).
*/
@Since("2.3.0")
private[feature] def murmur3Hash(term: Any): Int = {
term match {
case null => seed
case b: Boolean => hashInt(if (b) 1 else 0, seed)
case b: Byte => hashInt(b, seed)
case s: Short => hashInt(s, seed)
case i: Int => hashInt(i, seed)
case l: Long => hashLong(l, seed)
case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
case s: String =>
val utf8 = UTF8String.fromString(s)
hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " +
s"support type ${term.getClass.getCanonicalName} of input data.")
}
}
}
@@ -135,7 +135,7 @@ object HashingTF {

private[HashingTF] val Murmur3: String = "murmur3"

private val seed = 42
private[spark] val seed = 42

/**
* Calculate a hash code value for the term object using the native Scala implementation.
@@ -27,14 +27,15 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class FeatureHasherSuite extends SparkFunSuite
with MLlibTestSparkContext
with DefaultReadWriteTest {

import testImplicits._

import HashingTFSuite.murmur3FeatureIdx
import FeatureHasherSuite.murmur3FeatureIdx

implicit private val vectorEncoder = ExpressionEncoder[Vector]()

@@ -216,3 +217,11 @@ class FeatureHasherSuite extends SparkFunSuite
testDefaultReadWrite(t)
}
}

object FeatureHasherSuite {

private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = {
Utils.nonNegativeMod(FeatureHasher.murmur3Hash(term), numFeatures)
}

}
4 changes: 2 additions & 2 deletions python/pyspark/ml/feature.py
@@ -741,9 +741,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
>>> df = spark.createDataFrame(data, cols)
>>> hasher = FeatureHasher(inputCols=cols, outputCol="features")
>>> hasher.transform(df).head().features
SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0})
SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
>>> hasher.setCategoricalCols(["real"]).transform(df).head().features
SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0})
SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
>>> hasherPath = temp_path + "/hasher"
>>> hasher.save(hasherPath)
>>> loadedHasher = FeatureHasher.load(hasherPath)
