SPARK-1868: Users should be allowed to cogroup at least 4 RDDs
Adds cogroup for 4 RDDs.
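A minimal usage sketch of the new Scala API (an illustration, not part of the diff; it assumes a running SparkContext `sc`, and the sample data mirrors the new Java test added below):

import org.apache.spark.SparkContext._  // PairRDDFunctions implicits

// Four pair RDDs keyed by String (hypothetical data).
val categories = sc.parallelize(Seq(("Apples", "Fruit"), ("Oranges", "Fruit"), ("Oranges", "Citrus")))
val prices     = sc.parallelize(Seq(("Oranges", 2), ("Apples", 3)))
val quantities = sc.parallelize(Seq(("Oranges", 21), ("Apples", 42)))
val countries  = sc.parallelize(Seq(("Oranges", "BR"), ("Apples", "US")))

// Cogrouping four RDDs yields, per key, a tuple of four Iterables:
// RDD[(String, (Iterable[String], Iterable[Int], Iterable[Int], Iterable[String]))]
val grouped = categories.cogroup(prices, quantities, countries)

// groupWith is an alias for cogroup.
val sameGrouping = categories.groupWith(prices, quantities, countries)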

Author: Allan Douglas R. de Oliveira <[email protected]>

Closes apache#813 from douglaz/more_cogroups and squashes the following commits:

f8d6273 [Allan Douglas R. de Oliveira] Test python groupWith for one more case
0e9009c [Allan Douglas R. de Oliveira] Added scala tests
c3ffcdd [Allan Douglas R. de Oliveira] Added java tests
517a67f [Allan Douglas R. de Oliveira] Added tests for python groupWith
2f402d5 [Allan Douglas R. de Oliveira] Removed TODO
17474f4 [Allan Douglas R. de Oliveira] Use new cogroup function
7877a2a [Allan Douglas R. de Oliveira] Fixed code
ba02414 [Allan Douglas R. de Oliveira] Added varargs cogroup to pyspark
c4a8a51 [Allan Douglas R. de Oliveira] Added java cogroup 4
e94963c [Allan Douglas R. de Oliveira] Fixed spacing
f1ee57b [Allan Douglas R. de Oliveira] Fixed scala style issues
d7196f1 [Allan Douglas R. de Oliveira] Allow the cogroup of 4 RDDs
douglaz authored and pwendell committed Jun 20, 2014
1 parent d484dde commit 6a224c3
Showing 6 changed files with 223 additions and 17 deletions.
51 changes: 51 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -543,6 +543,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
partitioner: Partitioner): JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
other2: JavaPairRDD[K, W2],
other3: JavaPairRDD[K, W3],
partitioner: Partitioner)
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, partitioner)))

/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
@@ -558,6 +570,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
other2: JavaPairRDD[K, W2],
other3: JavaPairRDD[K, W3])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3)))

/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
@@ -574,6 +597,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions)))

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
other2: JavaPairRDD[K, W2],
other3: JavaPairRDD[K, W3],
numPartitions: Int)
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, numPartitions)))

/** Alias for cogroup. */
def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] =
fromRDD(cogroupResultToJava(rdd.groupWith(other)))
@@ -583,6 +618,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))

/** Alias for cogroup. */
def groupWith[W1, W2, W3](other1: JavaPairRDD[K, W1],
other2: JavaPairRDD[K, W2],
other3: JavaPairRDD[K, W3])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
fromRDD(cogroupResult3ToJava(rdd.groupWith(other1, other2, other3)))

/**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the
* RDD has a known partitioner by only searching the partition that the key maps to.
@@ -786,6 +828,15 @@ object JavaPairRDD {
.mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3)))
}

private[spark]
def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3](
rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))])
: RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = {
rddToPairRDDFunctions(rdd)
.mapValues(x =>
(asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4)))
}

def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = {
new JavaPairRDD[K, V](rdd)
}
51 changes: 51 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -567,6 +567,28 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
new FlatMappedValuesRDD(self, cleanF)
}

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
other2: RDD[(K, W2)],
other3: RDD[(K, W3)],
partitioner: Partitioner)
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner)
cg.mapValues { case Seq(vs, w1s, w2s, w3s) =>
(vs.asInstanceOf[Seq[V]],
w1s.asInstanceOf[Seq[W1]],
w2s.asInstanceOf[Seq[W2]],
w3s.asInstanceOf[Seq[W3]])
}
}

/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
@@ -599,6 +621,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
}

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
}

/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
@@ -633,6 +665,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
cogroup(other1, other2, new HashPartitioner(numPartitions))
}

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
other2: RDD[(K, W2)],
other3: RDD[(K, W3)],
numPartitions: Int)
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
cogroup(other1, other2, other3, new HashPartitioner(numPartitions))
}

/** Alias for cogroup. */
def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = {
cogroup(other, defaultPartitioner(self, other))
@@ -644,6 +689,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
cogroup(other1, other2, defaultPartitioner(self, other1, other2))
}

/** Alias for cogroup. */
def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
}

/**
* Return an RDD with the pairs from `this` whose keys are not in `other`.
*
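Sketch of how the new PairRDDFunctions overloads above compose (illustration only; `categories`, `prices`, `quantities`, and `countries` are the hypothetical RDDs from the sketch near the top). A Partitioner can be passed explicitly, the numPartitions overload wraps a HashPartitioner, and the per-key tuple of Iterables can be consumed by pattern matching:

import org.apache.spark.HashPartitioner

// Explicit partitioner; cogroup(other1, other2, other3, numPartitions) is shorthand
// for passing new HashPartitioner(numPartitions).
val partitioned = categories.cogroup(prices, quantities, countries, new HashPartitioner(4))

// Each record exposes every value seen for the key in each of the four inputs.
val summary = partitioned.map { case (key, (cats, ps, qs, cs)) =>
  (key, cats.mkString("|"), ps.sum, qs.sum, cs.mkString("|"))
}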
63 changes: 63 additions & 0 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -21,6 +21,9 @@
import java.util.*;

import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;


import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
@@ -304,6 +307,66 @@ public void cogroup() {
cogrouped.collect();
}

@SuppressWarnings("unchecked")
@Test
public void cogroup3() {
JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, String>("Apples", "Fruit"),
new Tuple2<String, String>("Oranges", "Fruit"),
new Tuple2<String, String>("Oranges", "Citrus")
));
JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, Integer>("Oranges", 2),
new Tuple2<String, Integer>("Apples", 3)
));
JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, Integer>("Oranges", 21),
new Tuple2<String, Integer>("Apples", 42)
));

JavaPairRDD<String, Tuple3<Iterable<String>, Iterable<Integer>, Iterable<Integer>>> cogrouped =
categories.cogroup(prices, quantities);
Assert.assertEquals("[Fruit, Citrus]",
Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));


cogrouped.collect();
}

@SuppressWarnings("unchecked")
@Test
public void cogroup4() {
JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, String>("Apples", "Fruit"),
new Tuple2<String, String>("Oranges", "Fruit"),
new Tuple2<String, String>("Oranges", "Citrus")
));
JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, Integer>("Oranges", 2),
new Tuple2<String, Integer>("Apples", 3)
));
JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, Integer>("Oranges", 21),
new Tuple2<String, Integer>("Apples", 42)
));
JavaPairRDD<String, String> countries = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, String>("Oranges", "BR"),
new Tuple2<String, String>("Apples", "US")
));

JavaPairRDD<String, Tuple4<Iterable<String>, Iterable<Integer>, Iterable<Integer>, Iterable<String>>> cogrouped =
categories.cogroup(prices, quantities, countries);
Assert.assertEquals("[Fruit, Citrus]",
Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4()));

cogrouped.collect();
}

@SuppressWarnings("unchecked")
@Test
public void leftOuterJoin() {
33 changes: 33 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -249,6 +249,39 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
))
}

test("groupWith3") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
val joined = rdd1.groupWith(rdd2, rdd3).collect()
assert(joined.size === 4)
val joinedSet = joined.map(x => (x._1,
(x._2._1.toList, x._2._2.toList, x._2._3.toList))).toSet
assert(joinedSet === Set(
(1, (List(1, 2), List('x'), List('a'))),
(2, (List(1), List('y', 'z'), List())),
(3, (List(1), List(), List('b'))),
(4, (List(), List('w'), List('c', 'd')))
))
}

test("groupWith4") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
val rdd4 = sc.parallelize(Array((2, '@')))
val joined = rdd1.groupWith(rdd2, rdd3, rdd4).collect()
assert(joined.size === 4)
val joinedSet = joined.map(x => (x._1,
(x._2._1.toList, x._2._2.toList, x._2._3.toList, x._2._4.toList))).toSet
assert(joinedSet === Set(
(1, (List(1, 2), List('x'), List('a'), List())),
(2, (List(1), List('y', 'z'), List(), List('@'))),
(3, (List(1), List(), List('b'), List())),
(4, (List(), List('w'), List('c', 'd'), List()))
))
}

test("zero-partition RDD") {
val emptyDir = Files.createTempDir()
emptyDir.deleteOnExit()
20 changes: 10 additions & 10 deletions python/pyspark/join.py
@@ -79,15 +79,15 @@ def dispatch(seq):
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
-def python_cogroup(rdd, other, numPartitions):
-    vs = rdd.map(lambda (k, v): (k, (1, v)))
-    ws = other.map(lambda (k, v): (k, (2, v)))
+def python_cogroup(rdds, numPartitions):
+    def make_mapper(i):
+        return lambda (k, v): (k, (i, v))
+    vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
+    union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
+    rdd_len = len(vrdds)
     def dispatch(seq):
-        vbuf, wbuf = [], []
+        bufs = [[] for i in range(rdd_len)]
         for (n, v) in seq:
-            if n == 1:
-                vbuf.append(v)
-            elif n == 2:
-                wbuf.append(v)
-        return (ResultIterable(vbuf), ResultIterable(wbuf))
-    return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
+            bufs[n].append(v)
+        return tuple(map(ResultIterable, bufs))
+    return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)
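The rewritten python_cogroup generalizes the old two-RDD version to any number of inputs: each input RDD's values are tagged with the index of their source RDD, the tagged RDDs are unioned and grouped by key, and dispatch then places every tagged value into the buffer belonging to its source. A rough Scala analogue of that tag-union-dispatch idea (illustration only; the Scala-side cogroup uses CoGroupedRDD instead, this sketch assumes all inputs share one value type, and naiveCogroup is a hypothetical helper, not Spark API):

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import org.apache.spark.SparkContext._  // PairRDDFunctions implicits
import org.apache.spark.rdd.RDD

def naiveCogroup[K: ClassTag, V: ClassTag](
    rdds: Seq[RDD[(K, V)]], numPartitions: Int): RDD[(K, Seq[Seq[V]])] = {
  // Tag each value with the index of the RDD it came from.
  val tagged = rdds.zipWithIndex.map { case (rdd, i) =>
    rdd.map { case (k, v) => (k, (i, v)) }
  }
  // Union everything, group by key, then dispatch each value into its source's buffer.
  tagged.reduce(_ union _).groupByKey(numPartitions).mapValues { pairs =>
    val bufs = Seq.fill(rdds.length)(ArrayBuffer.empty[V])
    pairs.foreach { case (i, v) => bufs(i) += v }
    bufs.map(_.toSeq)
  }
}

The Python version does the same thing dynamically and wraps each per-source buffer in a ResultIterable, which is what groupWith and cogroup now return per key.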
22 changes: 15 additions & 7 deletions python/pyspark/rdd.py
Expand Up @@ -1233,7 +1233,7 @@ def _mergeCombiners(iterator):
combiners[k] = mergeCombiners(combiners[k], v)
return combiners.iteritems()
return shuffled.mapPartitions(_mergeCombiners)

def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
Aggregate the values of each key, using given combine functions and a neutral "zero value".
@@ -1245,7 +1245,7 @@ def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
def createZero():
return copy.deepcopy(zeroValue)

return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)

def foldByKey(self, zeroValue, func, numPartitions=None):
@@ -1323,12 +1323,20 @@ def mapValues(self, f):
         map_values_fn = lambda (k, v): (k, f(v))
         return self.map(map_values_fn, preservesPartitioning=True)
 
-    # TODO: support varargs cogroup of several RDDs.
-    def groupWith(self, other):
+    def groupWith(self, other, *others):
         """
-        Alias for cogroup.
+        Alias for cogroup but with support for multiple RDDs.
+
+        >>> w = sc.parallelize([("a", 5), ("b", 6)])
+        >>> x = sc.parallelize([("a", 1), ("b", 4)])
+        >>> y = sc.parallelize([("a", 2)])
+        >>> z = sc.parallelize([("b", 42)])
+        >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \
+                sorted(list(w.groupWith(x, y, z).collect())))
+        [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))]
+
         """
-        return self.cogroup(other)
+        return python_cogroup((self, other) + others, numPartitions=None)
 
     # TODO: add variant with custom parittioner
     def cogroup(self, other, numPartitions=None):
@@ -1342,7 +1350,7 @@ def cogroup(self, other, numPartitions=None):
         >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect())))
         [('a', ([1], [2])), ('b', ([4], []))]
         """
-        return python_cogroup(self, other, numPartitions)
+        return python_cogroup((self, other), numPartitions)
 
     def subtractByKey(self, other, numPartitions=None):
         """
