Skip to content

Commit

Permalink
Addresses PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
liancheng committed Jan 25, 2016
1 parent e97d7f9 commit 12bbefb
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.util.sketch;

import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
Expand Down Expand Up @@ -136,28 +135,7 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other)
* Reads in a {@link CountMinSketch} from an input stream.
*/
public static CountMinSketch readFrom(InputStream in) throws IOException {
DataInputStream dis = new DataInputStream(in);

// Ignores version number
dis.readInt();

long totalCount = dis.readLong();
int depth = dis.readInt();
int width = dis.readInt();

long hashA[] = new long[depth];
for (int i = 0; i < depth; ++i) {
hashA[i] = dis.readLong();
}

long table[][] = new long[depth][width];
for (int i = 0; i < depth; ++i) {
for (int j = 0; j < width; ++j) {
table[i][j] = dis.readLong();
}
}

return new CountMinSketchImpl(depth, width, totalCount, hashA, table);
return CountMinSketchImpl.readFrom(in);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,30 @@

package org.apache.spark.util.sketch;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;

/*
* Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian
* order):
*
* - Version number, always 1 (32 bit)
* - Total count of added items (64 bit)
* - Depth (32 bit)
* - Width (32 bit)
* - Hash functions (depth * 64 bit)
* - Count table
* - Row 0 (width * 64 bit)
* - Row 1 (width * 64 bit)
* - ...
* - Row depth - 1 (width * 64 bit)
*/
class CountMinSketchImpl extends CountMinSketch {
public static final long PRIME_MODULUS = (1L << 31) - 1;

Expand All @@ -35,15 +52,15 @@ class CountMinSketchImpl extends CountMinSketch {
private double eps;
private double confidence;

public CountMinSketchImpl(int depth, int width, int seed) {
CountMinSketchImpl(int depth, int width, int seed) {
this.depth = depth;
this.width = width;
this.eps = 2.0 / width;
this.confidence = 1 - 1 / Math.pow(2, depth);
initTablesWith(depth, width, seed);
}

public CountMinSketchImpl(double eps, double confidence, int seed) {
CountMinSketchImpl(double eps, double confidence, int seed) {
// 2/w = eps ; w = 2/eps
// 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence)
this.eps = eps;
Expand All @@ -53,7 +70,7 @@ public CountMinSketchImpl(double eps, double confidence, int seed) {
initTablesWith(depth, width, seed);
}

public CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) {
CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) {
this.depth = depth;
this.width = width;
this.eps = 2.0 / width;
Expand All @@ -75,24 +92,24 @@ public boolean equals(Object other) {

CountMinSketchImpl that = (CountMinSketchImpl) other;

if (this.depth == that.depth &&
this.width == that.width &&
this.totalCount == that.totalCount) {
for (int i = 0; i < depth; ++i) {
if (this.hashA[i] != that.hashA[i]) {
return false;
}

for (int j = 0; j < width; ++j) {
if (this.table[i][j] != that.table[i][j]) {
return false;
}
}
}
return true;
} else {
return false;
}
return
this.depth == that.depth &&
this.width == that.width &&
this.totalCount == that.totalCount &&
Arrays.equals(this.hashA, that.hashA) &&
Arrays.deepEquals(this.table, that.table);
}

@Override
public int hashCode() {
int hash = depth;

hash = hash * 31 + width;
hash = hash * 31 + (int) (totalCount ^ (totalCount >>> 32));
hash = hash * 31 + Arrays.hashCode(hashA);
hash = hash * 31 + Arrays.deepHashCode(table);

return hash;
}

@Override
Expand Down Expand Up @@ -324,4 +341,29 @@ public void writeTo(OutputStream out) throws IOException {
}
}
}

public static CountMinSketchImpl readFrom(InputStream in) throws IOException {
DataInputStream dis = new DataInputStream(in);

// Ignores version number
dis.readInt();

long totalCount = dis.readLong();
int depth = dis.readInt();
int width = dis.readInt();

long hashA[] = new long[depth];
for (int i = 0; i < depth; ++i) {
hashA[i] = dis.readLong();
}

long table[][] = new long[depth][width];
for (int i = 0; i < depth; ++i) {
for (int j = 0; j < width; ++j) {
table[i][j] = dis.readLong();
}
}

return new CountMinSketchImpl(depth, width, totalCount, hashA, table);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

private val seed = 42

// Serializes and deserializes a given `CountMinSketch`, then checks whether the deserialized
// version is equivalent to the original one.
private def checkSerDe(sketch: CountMinSketch): Unit = {
val out = new ByteArrayOutputStream()
sketch.writeTo(out)
Expand All @@ -43,7 +45,8 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
test(s"accuracy - $typeName") {
val r = new Random()
// Uses fixed seed to ensure reproducible test execution
val r = new Random(31)

val numAllItems = 1000000
val allItems = Array.fill(numAllItems)(itemGenerator(r))
Expand Down

0 comments on commit 12bbefb

Please sign in to comment.