From 16a4ca45373cd4b75032f88668610d9b693fb4b3 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Thu, 14 Mar 2013 13:58:37 -0700 Subject: [PATCH] restrict V type of foldByKey in order to retain ClassManifest; added foldByKey to Java API and test --- .../main/scala/spark/PairRDDFunctions.scala | 4 ++-- .../scala/spark/api/java/JavaPairRDD.scala | 6 +++++ core/src/test/scala/spark/JavaAPISuite.java | 24 ++++++++++++++++++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index a6e00c3a84..0fde902261 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -91,8 +91,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( /** * Merge the values for each key using an associative function and a neutral "zero value". */ - def foldByKey[V1 >: V](zeroValue: V1)(op: (V1, V1) => V1): RDD[(K, V1)] = { - groupByKey.mapValues(seq => seq.fold(zeroValue)(op)) + def foldByKey(zeroValue: V)(op: (V, V) => V): RDD[(K, V)] = { + groupByKey.mapValues(seq => seq.fold[V](zeroValue)(op)) } /** diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index c1bd13c49a..1e1c910202 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -160,6 +160,12 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif : PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) + /** + * Merge the values for each key using an associative function and a neutral "zero value". + */ + def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = + fromRDD(rdd.foldByKey(zeroValue)(func)) + /** * Merge the values for each key using an associative reduce function. 
This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 26e3ab72c0..b83076b929 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -196,7 +196,29 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(33, sum); } - @Test + @Test + public void foldByKey() { + List<Tuple2<Integer, Integer>> pairs = Arrays.asList( + new Tuple2<Integer, Integer>(2, 1), + new Tuple2<Integer, Integer>(2, 1), + new Tuple2<Integer, Integer>(1, 1), + new Tuple2<Integer, Integer>(3, 2), + new Tuple2<Integer, Integer>(3, 1) + ); + JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs); + JavaPairRDD<Integer, Integer> sums = rdd.foldByKey(0, + new Function2<Integer, Integer, Integer>() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }); + Assert.assertEquals(1, sums.lookup(1).get(0).intValue()); + Assert.assertEquals(2, sums.lookup(2).get(0).intValue()); + Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); + } + + @Test public void reduceByKey() { List<Tuple2<Integer, Integer>> pairs = Arrays.asList( new Tuple2<Integer, Integer>(2, 1),