From 1ef4f0fbd27e54803f14fed1df541fb341daced8 Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Wed, 26 Sep 2012 19:18:47 -0700
Subject: [PATCH] Allow controlling number of splits in sortByKey.

---
 .../main/scala/spark/PairRDDFunctions.scala   |  4 +-
 core/src/main/scala/spark/ShuffledRDD.scala   |  9 ++--
 .../scala/spark/deploy/client/Client.scala    |  1 -
 core/src/test/scala/spark/SortingSuite.scala  | 48 +++++++++++++++++--
 4 files changed, 50 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index aa1d00c63c..4752bf8d9f 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -435,8 +435,8 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
   extends Logging
   with Serializable {
 
-  def sortByKey(ascending: Boolean = true): RDD[(K,V)] = {
-    new ShuffledSortedRDD(self, ascending)
+  def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = {
+    new ShuffledSortedRDD(self, ascending, numSplits)
   }
 }
 
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index be75890a40..7c11925f86 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -16,7 +16,7 @@ class ShuffledRDDSplit(val idx: Int) extends Split {
 abstract class ShuffledRDD[K, V, C](
     @transient parent: RDD[(K, V)],
     aggregator: Aggregator[K, V, C],
-    part : Partitioner)
+    part: Partitioner)
   extends RDD[(K, C)](parent.context) {
 
   override val partitioner = Some(part)
@@ -38,7 +38,7 @@ abstract class ShuffledRDD[K, V, C](
  */
 class RepartitionShuffledRDD[K, V](
     @transient parent: RDD[(K, V)],
-    part : Partitioner)
+    part: Partitioner)
   extends ShuffledRDD[K, V, V](
     parent,
     Aggregator[K, V, V](null, null, null, false),
@@ -60,10 +60,11 @@ class RepartitionShuffledRDD[K, V](
  */
 class ShuffledSortedRDD[K <% Ordered[K]: ClassManifest, V](
     @transient parent: RDD[(K, V)],
-    ascending: Boolean)
+    ascending: Boolean,
+    numSplits: Int)
   extends RepartitionShuffledRDD[K, V](
     parent,
-    new RangePartitioner(parent.splits.size, parent, ascending)) {
+    new RangePartitioner(numSplits, parent, ascending)) {
 
   override def compute(split: Split): Iterator[(K, V)] = {
     // By separating this from RepartitionShuffledRDD, we avoided a
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index c7fa8a3874..a2f88fc5e5 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -42,7 +42,6 @@ class Client(
       val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
       try {
         master = context.actorFor(akkaUrl)
-        //master ! RegisterWorker(ip, port, cores, memory)
         master ! RegisterJob(jobDescription)
         context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
         context.watch(master)  // Doesn't work with remote actors, but useful for testing
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 188a9b564e..c87595ecb3 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -17,7 +17,7 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
 
   test("sortByKey") {
     sc = new SparkContext("local", "test")
-    val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
+    val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
     assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
   }
 
@@ -25,18 +25,56 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
     sc = new SparkContext("local", "test")
     val rand = new scala.util.Random()
     val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
-    val pairs = sc.parallelize(pairArr)
-    assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+    val pairs = sc.parallelize(pairArr, 2)
+    val sorted = pairs.sortByKey()
+    assert(sorted.splits.size === 2)
+    assert(sorted.collect() === pairArr.sortBy(_._1))
   }
 
+  test("large array with one split") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 2)
+    val sorted = pairs.sortByKey(true, 1)
+    assert(sorted.splits.size === 1)
+    assert(sorted.collect() === pairArr.sortBy(_._1))
+  }
+
+  test("large array with many splits") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 2)
+    val sorted = pairs.sortByKey(true, 20)
+    assert(sorted.splits.size === 20)
+    assert(sorted.collect() === pairArr.sortBy(_._1))
+  }
+
   test("sort descending") {
     sc = new SparkContext("local", "test")
     val rand = new scala.util.Random()
     val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
-    val pairs = sc.parallelize(pairArr)
+    val pairs = sc.parallelize(pairArr, 2)
     assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
   }
 
+  test("sort descending with one split") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 1)
+    assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+  }
+
+  test("sort descending with many splits") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 2)
+    assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+  }
+
   test("more partitions than elements") {
     sc = new SparkContext("local", "test")
     val rand = new scala.util.Random()
@@ -48,7 +86,7 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
   test("empty RDD") {
     sc = new SparkContext("local", "test")
     val pairArr = new Array[(Int, Int)](0)
-    val pairs = sc.parallelize(pairArr)
+    val pairs = sc.parallelize(pairArr, 2)
     assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
   }
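
For reference, a minimal usage sketch of the new numSplits argument added by this patch. The setup mirrors the SortingSuite tests above; the object name, job name, and sample data are illustrative only, not part of the patch.

    // Sketch: sorting a pair RDD with an explicit number of output splits.
    import spark.SparkContext
    import spark.SparkContext._  // implicit conversions that provide sortByKey on pair RDDs

    object SortByKeyExample {
      def main(args: Array[String]) {
        // Local context, as in the tests above.
        val sc = new SparkContext("local", "sortByKeyExample")
        val pairs = sc.parallelize(Array((3, "c"), (1, "a"), (2, "b")), 2)

        // Default behavior: output has as many splits as the input RDD.
        val sortedDefault = pairs.sortByKey()

        // New in this patch: request a specific number of output splits.
        val sortedOne = pairs.sortByKey(true, 1)     // ascending, 1 split
        val sortedMany = pairs.sortByKey(false, 20)  // descending, 20 splits

        println(sortedOne.collect().mkString(", "))
        sc.stop()
      }
    }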