Merge pull request #208 from rxin/dev

Separated ShuffledRDD into multiple classes.
Matei Zaharia 2012-09-24 12:32:01 -07:00
Parents 107a5ca879 397d3816e1
Commit f855e4fad2
5 changed files with 128 additions and 62 deletions

View file

@@ -9,9 +9,9 @@ package spark
* known as map-side aggregations. When set to false,
* mergeCombiners function is not used.
*/
class Aggregator[K, V, C] (
case class Aggregator[K, V, C] (
val createCombiner: V => C,
val mergeValue: (C, V) => C,
val mergeCombiners: (C, C) => C,
val mapSideCombine: Boolean = true)
extends Serializable
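For illustration, a minimal construction sketch showing the new case-class syntax and the mapSideCombine flag (the key/value types and lambdas here are hypothetical):

// Sum integer values per key; mapSideCombine defaults to true.
val sumAggregator = Aggregator[String, Int, Int](
  (v: Int) => v,
  (c: Int, v: Int) => c + v,
  (c1: Int, c2: Int) => c1 + c2)

// With map-side combining disabled, mergeCombiners is never called and may be left null.
val noMapSideAggregator =
  Aggregator[String, Int, Int]((v: Int) => v, (c: Int, v: Int) => c + v, null, mapSideCombine = false)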

View file

@@ -1,11 +1,10 @@
package spark
import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.net.URL
import java.util.{Date, HashMap => JHashMap}
import java.util.concurrent.atomic.AtomicLong
import java.util.{HashMap => JHashMap}
import java.util.Date
import java.text.SimpleDateFormat
import scala.collection.Map
@@ -50,9 +49,18 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
partitioner: Partitioner): RDD[(K, C)] = {
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
new ShuffledRDD(self, aggregator, partitioner)
partitioner: Partitioner,
mapSideCombine: Boolean = true): RDD[(K, C)] = {
val aggregator =
if (mapSideCombine) {
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
new Aggregator[K, V, C](createCombiner, mergeValue, null, false)
}
new ShuffledAggregatedRDD(self, aggregator, partitioner)
}
def combineByKey[C](createCombiner: V => C,
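As a hedged usage sketch of the extended signature (the pairs RDD, the SparkContext sc, and the import are illustrative assumptions, not part of this commit):

// Hypothetical input; pair operations come from the implicit conversions in SparkContext.
// import spark.SparkContext._
val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))

// Default path: map-side combine enabled, backed by ShuffledAggregatedRDD.
val sums = pairs.combineByKey(
  (v: Int) => v, (c: Int, v: Int) => c + v, (c1: Int, c2: Int) => c1 + c2, new HashPartitioner(4))

// Map-side combine disabled: mergeCombiners must be null, as the assert above enforces.
val sumsNoMapSide = pairs.combineByKey(
  (v: Int) => v, (c: Int, v: Int) => c + v, null, new HashPartitioner(4), mapSideCombine = false)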
@@ -116,13 +124,24 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
groupByKey(new HashPartitioner(numSplits))
}
def partitionBy(partitioner: Partitioner): RDD[(K, V)] = {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
createCombiner _, mergeValue _, mergeCombiners _, partitioner)
bufs.flatMapValues(buf => buf)
/**
* Repartition the RDD using the specified partitioner. If mapSideCombine is
* true, Spark will group values of the same key together on the map side
* before the repartitioning. If a large number of duplicated keys are
* expected and the keys themselves are large, mapSideCombine should be set
* to true.
*/
def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = {
if (mapSideCombine) {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
createCombiner _, mergeValue _, mergeCombiners _, partitioner)
bufs.flatMapValues(buf => buf)
} else {
new RepartitionShuffledRDD(self, partitioner)
}
}
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
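A usage sketch for the new flag, reusing the hypothetical pairs RDD from the combineByKey sketch above:

// Default (mapSideCombine = false): records are shuffled as-is via RepartitionShuffledRDD.
val repartitioned = pairs.partitionBy(new HashPartitioner(8))

// Buffer values of the same key on the map side before shuffling; helpful when keys are
// large and heavily duplicated, as the doc comment above notes.
val repartitionedCombined = pairs.partitionBy(new HashPartitioner(8), mapSideCombine = true)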
@@ -417,21 +436,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
with Serializable {
def sortByKey(ascending: Boolean = true): RDD[(K,V)] = {
val rangePartitionedRDD = self.partitionBy(new RangePartitioner(self.splits.size, self, ascending))
new SortedRDD(rangePartitionedRDD, ascending)
}
}
class SortedRDD[K <% Ordered[K], V](prev: RDD[(K, V)], ascending: Boolean)
extends RDD[(K, V)](prev.context) {
override def splits = prev.splits
override val partitioner = prev.partitioner
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = {
prev.iterator(split).toArray
.sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1).iterator
new ShuffledSortedRDD(self, ascending)
}
}
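Correspondingly, a hedged sortByKey sketch (again using the hypothetical pairs RDD from above):

// sortByKey now range-partitions and then sorts each partition via ShuffledSortedRDD.
val ascendingByKey = pairs.sortByKey()        // ascending defaults to true
val descendingByKey = pairs.sortByKey(false)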

View file

@@ -1,18 +1,24 @@
package spark
import scala.collection.mutable.ArrayBuffer
import java.util.{HashMap => JHashMap}
class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx
override def hashCode(): Int = idx
}
class ShuffledRDD[K, V, C](
/**
* The resulting RDD from a shuffle (e.g. repartitioning of data).
*/
abstract class ShuffledRDD[K, V, C](
@transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
part : Partitioner)
extends RDD[(K, C)](parent.context) {
//override val partitioner = Some(part)
override val partitioner = Some(part)
@transient
@@ -24,6 +30,60 @@ class ShuffledRDD[K, V, C](
val dep = new ShuffleDependency(context.newShuffleId, parent, aggregator, part)
override val dependencies = List(dep)
}
/**
* Repartition a key-value pair RDD.
*/
class RepartitionShuffledRDD[K, V](
@transient parent: RDD[(K, V)],
part : Partitioner)
extends ShuffledRDD[K, V, V](
parent,
Aggregator[K, V, V](null, null, null, false),
part) {
override def compute(split: Split): Iterator[(K, V)] = {
val buf = new ArrayBuffer[(K, V)]
val fetcher = SparkEnv.get.shuffleFetcher
def addTupleToBuffer(k: K, v: V) = { buf += Tuple2(k, v) }
fetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer)
buf.iterator
}
}
/**
* A sort-based shuffle (that doesn't apply aggregation). It works by first
* repartitioning the RDD by range, and then sorting within each range.
*/
class ShuffledSortedRDD[K <% Ordered[K]: ClassManifest, V](
@transient parent: RDD[(K, V)],
ascending: Boolean)
extends RepartitionShuffledRDD[K, V](
parent,
new RangePartitioner(parent.splits.size, parent, ascending)) {
override def compute(split: Split): Iterator[(K, V)] = {
// By separating this from RepartitionShuffledRDD, we avoid a
// buf.iterator.toArray call and thus avoid building up the buffer twice.
val buf = new ArrayBuffer[(K, V)]
def addTupleToBuffer(k: K, v: V) = { buf += Tuple2(k, v) }
SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer)
buf.sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1).iterator
}
}
/**
* The resulting RDD from a shuffle with (hash-based) aggregation applied.
*/
class ShuffledAggregatedRDD[K, V, C](
@transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
part : Partitioner)
extends ShuffledRDD[K, V, C](parent, aggregator, part) {
override def compute(split: Split): Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
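For illustration only, a minimal generic sketch of the core step of hash-based aggregation: merging one fetched combiner into a running hash table (a hypothetical helper, not code from this commit):

import java.util.{HashMap => JHashMap}

// Apply mergeCombiners when the key is already present, otherwise insert the new combiner.
def mergeInto[K, C](table: JHashMap[K, C], key: K, combiner: C, mergeCombiners: (C, C) => C) {
  val existing = table.get(key)
  if (existing == null) table.put(key, combiner) else table.put(key, mergeCombiners(existing, combiner))
}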

View file

@@ -44,7 +44,8 @@ object ShuffleMapTask {
}
// Since both the JarSet and FileSet have the same format, this is used for both.
def serializeFileSet(set : HashMap[String, Long], stageId: Int, cache : JHashMap[Int, Array[Byte]]) : Array[Byte] = {
def serializeFileSet(
set : HashMap[String, Long], stageId: Int, cache : JHashMap[Int, Array[Byte]]) : Array[Byte] = {
val old = cache.get(stageId)
if (old != null) {
return old
@@ -59,7 +60,6 @@
}
}
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
synchronized {
val loader = Thread.currentThread.getContextClassLoader
@@ -113,7 +113,8 @@ class ShuffleMapTask(
out.writeInt(bytes.length)
out.write(bytes)
val fileSetBytes = ShuffleMapTask.serializeFileSet(fileSet, stageId, ShuffleMapTask.fileSetCache)
val fileSetBytes = ShuffleMapTask.serializeFileSet(
fileSet, stageId, ShuffleMapTask.fileSetCache)
out.writeInt(fileSetBytes.length)
out.write(fileSetBytes)
val jarSetBytes = ShuffleMapTask.serializeFileSet(jarSet, stageId, ShuffleMapTask.jarSetCache)
@@ -172,7 +173,7 @@ class ShuffleMapTask(
buckets.map(_.iterator)
} else {
// No combiners (no map-side aggregation). Simply partition the map output.
val buckets = Array.tabulate(numOutputSplits)(_ => new ArrayBuffer[(Any, Any)])
val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
for (elem <- rdd.iterator(split)) {
val pair = elem.asInstanceOf[(Any, Any)]
val bucketId = partitioner.getPartition(pair._1)
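A small aside on the Array.fill change above: fill takes its element argument by name, so each slot gets a fresh buffer, making it equivalent to the Array.tabulate form it replaces (standalone sketch):

import scala.collection.mutable.ArrayBuffer

val buckets = Array.fill(3)(new ArrayBuffer[(Any, Any)])
buckets(0) += (("a", 1))
assert(buckets(1).isEmpty)   // each bucket is an independent buffer instance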

View file

@@ -212,7 +212,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
_+_,
_+_,
false)
val shuffledRdd = new ShuffledRDD(
val shuffledRdd = new ShuffledAggregatedRDD(
pairs, aggregator, new HashPartitioner(2))
assert(shuffledRdd.collect().toSet === Set((1, 8), (2, 1)))
@@ -220,7 +220,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
// not see an exception because mergeCombine should not have been called.
val aggregatorWithException = new Aggregator[Int, Int, Int](
(v: Int) => v, _+_, ShuffleSuite.mergeCombineException, false)
val shuffledRdd1 = new ShuffledRDD(
val shuffledRdd1 = new ShuffledAggregatedRDD(
pairs, aggregatorWithException, new HashPartitioner(2))
assert(shuffledRdd1.collect().toSet === Set((1, 8), (2, 1)))
@@ -228,7 +228,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
// expect to see an exception thrown.
val aggregatorWithException1 = new Aggregator[Int, Int, Int](
(v: Int) => v, _+_, ShuffleSuite.mergeCombineException)
val shuffledRdd2 = new ShuffledRDD(
val shuffledRdd2 = new ShuffledAggregatedRDD(
pairs, aggregatorWithException1, new HashPartitioner(2))
evaluating { shuffledRdd2.collect() } should produce [SparkException]
}