Merge pull request #222 from rxin/dev

Added MapPartitionsWithSplitRDD.
This commit is contained in:
Matei Zaharia 2012-09-26 23:16:45 -07:00
Родитель ea05fc130b 1ad1331a34
Коммит 920fab23c3
2 изменённых файлов: 23 добавлений и 0 удалений

Просмотреть файл

@ -196,6 +196,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] = def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f)) new MapPartitionsRDD(this, sc.clean(f))
def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
new MapPartitionsWithSplitRDD(this, sc.clean(f))
// Actions (launch a job to return a value to the user program) // Actions (launch a job to return a value to the user program)
def foreach(f: T => Unit) { def foreach(f: T => Unit) {
@ -417,3 +420,18 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
override val dependencies = List(new OneToOneDependency(prev)) override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(prev.iterator(split)) override def compute(split: Split) = f(prev.iterator(split))
} }
/**
* A variant of the MapPartitionsRDD that passes the split index into the
* closure. This can be used to generate or collect partition specific
* information such as the number of tuples in a partition.
*/
class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: (Int, Iterator[T]) => Iterator[U])
extends RDD[U](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(split.index, prev.iterator(split))
}

Просмотреть файл

@ -29,6 +29,11 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
assert(partitionSums.collect().toList === List(3, 7)) assert(partitionSums.collect().toList === List(3, 7))
val partitionSumsWithSplit = nums.mapPartitionsWithSplit {
case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
}
assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
} }
test("SparkContext.union") { test("SparkContext.union") {