SPARK-635: Pass a TaskContext object to compute() interface and use that to close Hadoop input stream.
Author: Reynold Xin, 2012-12-13 15:41:53 -08:00
Parent: 391e5a194a
Commit: eacb98e900
26 changed files with 207 additions and 205 deletions
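To summarize the change before the per-file diffs: compute() and iterator() now take a TaskContext, and the context carries on-completion callbacks so an RDD that opens a resource (such as HadoopRDD's record reader) can have it closed when the task finishes. Below is a minimal, simplified sketch of the shape of the API after this commit; the types are stand-ins for illustration, not the real Spark classes (the actual callback type in this commit is `Unit => Unit` rather than `() => Unit`).

```scala
import scala.collection.mutable.ArrayBuffer

// Simplified stand-in for spark.TaskContext, for illustration only.
class TaskContextSketch(val stageId: Int, val splitId: Int, val attemptId: Long) {
  private val onCompleteCallbacks = new ArrayBuffer[() => Unit]

  // RDDs register cleanup work here, e.g. HadoopRDD closing its RecordReader.
  def registerOnCompleteCallback(f: () => Unit): Unit = onCompleteCallbacks += f

  // Task runners (ResultTask, ShuffleMapTask, the DAGScheduler's local run path)
  // call this once the task body has produced its result.
  def executeOnCompleteCallbacks(): Unit = onCompleteCallbacks.foreach(_())
}

// Simplified stand-ins for spark.Split and spark.RDD.
trait SplitSketch { def index: Int }

abstract class RDDSketch[T] {
  // The new signature: every partition computation receives the task's context.
  def compute(split: SplitSketch, taskContext: TaskContextSketch): Iterator[T]
}
```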

View file

@ -1,5 +1,9 @@
package spark
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
@ -8,10 +12,6 @@ import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import spark.storage.BlockManager
import spark.storage.StorageLevel
@ -41,7 +41,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging {
private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
def receive = {
case SlaveCacheStarted(host: String, size: Long) =>
slaveCapacity.put(host, size)
@ -92,14 +92,14 @@ private[spark] class CacheTrackerActor extends Actor with Logging {
private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
extends Logging {
// Tracker actor on the master, or remote reference to it on workers
val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "CacheTracker"
val timeout = 10.seconds
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
logInfo("Registered CacheTrackerActor actor")
@ -132,7 +132,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
throw new SparkException("Error reply received from CacheTracker")
}
}
// Registers an RDD (on master only)
def registerRDD(rddId: Int, numPartitions: Int) {
registeredRddIds.synchronized {
@ -143,7 +143,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
}
}
}
// For BlockManager.scala only
def cacheLost(host: String) {
communicate(MemoryCacheLost(host))
@ -155,19 +155,21 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
def getCacheStatus(): Seq[(String, Long, Long)] = {
askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]]
}
// For BlockManager.scala only
def notifyFromBlockManager(t: AddedToCache) {
communicate(t)
}
// Get a snapshot of the currently known locations
def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
}
// Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
def getOrCompute[T](
rdd: RDD[T], split: Split, taskContext: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
@ -209,7 +211,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
// TODO: also register a listener for when it unloads
logInfo("Computing partition " + split)
val elements = new ArrayBuffer[Any]
elements ++= rdd.compute(split)
elements ++= rdd.compute(split, taskContext)
try {
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)

View file

@ -35,11 +35,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
with Serializable {
/**
* Generic function to combine the elements for each key using a custom set of aggregation
* Generic function to combine the elements for each key using a custom set of aggregation
* functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C
* Note that V and C can be different -- for example, one might group an RDD of type
* (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions:
*
*
* - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
@ -118,7 +118,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/** Count the number of elements for each key, and return the result to the master as a Map. */
def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
/**
/**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
@ -224,7 +224,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
/**
/**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the default
* parallelism level.
*/
@ -628,7 +628,8 @@ class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)]
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override val partitioner = prev.partitioner
override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))}
override def compute(split: Split, taskContext: TaskContext) =
prev.iterator(split, taskContext).map{case (k, v) => (k, f(v))}
}
private[spark]
@ -639,8 +640,8 @@ class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]
override val dependencies = List(new OneToOneDependency(prev))
override val partitioner = prev.partitioner
override def compute(split: Split) = {
prev.iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) }
override def compute(split: Split, taskContext: TaskContext) = {
prev.iterator(split, taskContext).flatMap { case (k, v) => f(v).map(x => (k, x)) }
}
}

View file

@ -8,8 +8,8 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
val slice: Int,
values: Seq[T])
extends Split with Serializable {
def iterator(): Iterator[T] = values.iterator
def iterator: Iterator[T] = values.iterator
override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt
@ -22,7 +22,7 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
}
private[spark] class ParallelCollection[T: ClassManifest](
sc: SparkContext,
sc: SparkContext,
@transient data: Seq[T],
numSlices: Int)
extends RDD[T](sc) {
@ -38,17 +38,18 @@ private[spark] class ParallelCollection[T: ClassManifest](
override def splits = splits_.asInstanceOf[Array[Split]]
override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator
override def compute(s: Split, taskContext: TaskContext) =
s.asInstanceOf[ParallelCollectionSplit[T]].iterator
override def preferredLocations(s: Split): Seq[String] = Nil
override val dependencies: List[Dependency[_]] = Nil
}
private object ParallelCollection {
/**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
* it efficient to run Spark over RDDs representing large sets of numbers.
*/
def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
@ -58,7 +59,7 @@ private object ParallelCollection {
seq match {
case r: Range.Inclusive => {
val sign = if (r.step < 0) {
-1
-1
} else {
1
}

View file

@ -81,7 +81,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def splits: Array[Split]
/** Function for computing a given partition. */
def compute(split: Split): Iterator[T]
def compute(split: Split, taskContext: TaskContext): Iterator[T]
/** How this RDD depends on any parent RDDs. */
@transient val dependencies: List[Dependency[_]]
@ -155,11 +155,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
final def iterator(split: Split): Iterator[T] = {
final def iterator(split: Split, taskContext: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
SparkEnv.get.cacheTracker.getOrCompute[T](this, split, taskContext, storageLevel)
} else {
compute(split)
compute(split, taskContext)
}
}
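As a concrete illustration of what implementors of custom RDD subclasses now override, here is a hedged sketch of a trivial subclass written against the interfaces shown in this diff. The class itself (UppercasedRDD) and its package are hypothetical; it simply threads the TaskContext through to the parent iterator, the same way MappedRDD and FilteredRDD do elsewhere in this commit.

```scala
package example.sketch // hypothetical package, not part of this commit

import spark.{OneToOneDependency, RDD, Split, TaskContext}

// A one-to-one RDD that upper-cases its parent's strings. The only change this
// commit forces on such a subclass is the extra TaskContext parameter, which is
// forwarded to the parent via iterator(split, taskContext).
class UppercasedRDD(prev: RDD[String]) extends RDD[String](prev.context) {
  override def splits = prev.splits
  override val dependencies = List(new OneToOneDependency(prev))
  override def compute(split: Split, taskContext: TaskContext): Iterator[String] =
    prev.iterator(split, taskContext).map(_.toUpperCase)
}
```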

View file

@ -1,3 +1,20 @@
package spark
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable
import scala.collection.mutable.ArrayBuffer
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
@transient
val onCompleteCallbacks = new ArrayBuffer[Unit => Unit]
// Add a callback function to be executed on task completion. An example use
// is for HadoopRDD to register a callback to close the input stream.
def registerOnCompleteCallback(f: Unit => Unit) {
onCompleteCallbacks += f
}
def executeOnCompleteCallbacks() {
onCompleteCallbacks.foreach{_()}
}
}
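The callback flow, end to end, for anyone reading the scheduler changes further down: the RDD registers cleanup when it opens a resource, and the task runner fires the callbacks after the task body completes. A small hedged usage sketch follows; FakeReader is a hypothetical stand-in for a Hadoop RecordReader, and the example assumes the spark.TaskContext shown above is on the classpath.

```scala
import java.io.Closeable

// Hypothetical resource standing in for a Hadoop RecordReader.
class FakeReader extends Closeable {
  def records: Iterator[String] = Iterator("a", "b", "c")
  override def close(): Unit = println("reader closed")
}

object TaskContextCallbackExample {
  def main(args: Array[String]): Unit = {
    val context = new spark.TaskContext(stageId = 0, splitId = 0, attemptId = 0L)
    val reader = new FakeReader

    // Register cleanup once, when the resource is opened, mirroring HadoopRDD.
    // The callback type in this commit is Unit => Unit, hence the Unit parameter.
    context.registerOnCompleteCallback(Unit => reader.close())

    // ... task body consumes the reader; it may stop early without leaking ...
    println(reader.records.take(2).toList)

    // The task runner (ResultTask, ShuffleMapTask, the DAGScheduler's local run
    // path) fires the callbacks after the body finishes.
    context.executeOnCompleteCallbacks()
  }
}
```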

View file

@ -1,16 +1,15 @@
package spark.api.java
import spark.{SparkContext, Split, RDD}
import java.util.{List => JList}
import scala.Tuple2
import scala.collection.JavaConversions._
import spark.{SparkContext, Split, RDD, TaskContext}
import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
import spark.partial.{PartialResult, BoundedDouble}
import spark.storage.StorageLevel
import java.util.{List => JList}
import scala.collection.JavaConversions._
import java.{util, lang}
import scala.Tuple2
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
@ -24,7 +23,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/** The [[spark.SparkContext]] that this RDD was created on. */
def context: SparkContext = rdd.context
/** A unique ID for this RDD (within its SparkContext). */
def id: Int = rdd.id
@ -36,7 +35,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
def iterator(split: Split): java.util.Iterator[T] = asJavaIterator(rdd.iterator(split))
def iterator(split: Split, taskContext: TaskContext): java.util.Iterator[T] =
asJavaIterator(rdd.iterator(split, taskContext))
// Transformations (return a new RDD)
@ -99,7 +99,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
@ -183,7 +182,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
// Actions (launch a job to return a value to the user program)
/**
* Applies a function f to all elements of this RDD.
*/
@ -200,7 +199,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
val arr: java.util.Collection[T] = rdd.collect().toSeq
new java.util.ArrayList(arr)
}
/**
* Reduces the elements of this RDD using the specified associative binary operator.
*/
@ -208,7 +207,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
* modify t1 and return it as its result value to avoid object allocation; however, it should not
* modify t2.
*/
@ -251,7 +250,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): java.util.Map[T, java.lang.Long] =
mapAsJavaMap(rdd.countByValue().map((x => (x._1, new lang.Long(x._2)))))
mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
/**
* (Experimental) Approximate version of countByValue().

View file

@ -2,11 +2,8 @@ package spark.rdd
import scala.collection.mutable.HashMap
import spark.Dependency
import spark.RDD
import spark.SparkContext
import spark.SparkEnv
import spark.Split
import spark.{Dependency, RDD, SparkContext, SparkEnv, Split, TaskContext}
private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
val index = idx
@ -19,29 +16,29 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
@transient
val splits_ = (0 until blockIds.size).map(i => {
new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
}).toArray
@transient
}).toArray
@transient
lazy val locations_ = {
val blockManager = SparkEnv.get.blockManager
val blockManager = SparkEnv.get.blockManager
/*val locations = blockIds.map(id => blockManager.getLocations(id))*/
val locations = blockManager.getLocations(blockIds)
val locations = blockManager.getLocations(blockIds)
HashMap(blockIds.zip(locations):_*)
}
override def splits = splits_
override def compute(split: Split): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
override def compute(split: Split, taskContext: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
val blockId = split.asInstanceOf[BlockRDDSplit].blockId
blockManager.get(blockId) match {
case Some(block) => block.asInstanceOf[Iterator[T]]
case None =>
case None =>
throw new Exception("Could not compute split, block " + blockId + " not found")
}
}
override def preferredLocations(split: Split) =
override def preferredLocations(split: Split) =
locations_(split.asInstanceOf[BlockRDDSplit].blockId)
override val dependencies: List[Dependency[_]] = Nil

View file

@ -1,9 +1,7 @@
package spark.rdd
import spark.NarrowDependency
import spark.RDD
import spark.SparkContext
import spark.Split
import spark.{NarrowDependency, RDD, SparkContext, Split, TaskContext}
private[spark]
class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
@ -17,9 +15,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
rdd2: RDD[U])
extends RDD[Pair[T, U]](sc)
with Serializable {
val numSplitsInRdd2 = rdd2.splits.size
@transient
val splits_ = {
// create the cross product split
@ -38,11 +36,12 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
}
override def compute(split: Split) = {
override def compute(split: Split, taskContext: TaskContext) = {
val currSplit = split.asInstanceOf[CartesianSplit]
for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y)
for (x <- rdd1.iterator(currSplit.s1, taskContext);
y <- rdd2.iterator(currSplit.s2, taskContext)) yield (x, y)
}
override val dependencies = List(
new NarrowDependency(rdd1) {
def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)

View file

@ -3,21 +3,15 @@ package spark.rdd
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import spark.Aggregator
import spark.Dependency
import spark.Logging
import spark.OneToOneDependency
import spark.Partitioner
import spark.RDD
import spark.ShuffleDependency
import spark.SparkEnv
import spark.Split
import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
private[spark] sealed trait CoGroupSplitDep extends Serializable
private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark]
private[spark]
class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
@ -32,9 +26,9 @@ private[spark] class CoGroupAggregator
class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
val aggr = new CoGroupAggregator
@transient
override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
@ -50,7 +44,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
deps.toList
}
@transient
val splits_ : Array[Split] = {
val firstRdd = rdds.head
@ -69,12 +63,12 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
override def splits = splits_
override val partitioner = Some(part)
override def preferredLocations(s: Split) = Nil
override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
override def compute(s: Split, taskContext: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size
val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
@ -84,7 +78,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, itsSplit) => {
// Read them from the parent
for ((k, v) <- rdd.iterator(itsSplit)) {
for ((k, v) <- rdd.iterator(itsSplit, taskContext)) {
getSeq(k.asInstanceOf[K])(depNum) += v
}
}

View file

@ -1,8 +1,7 @@
package spark.rdd
import spark.NarrowDependency
import spark.RDD
import spark.Split
import spark.{NarrowDependency, RDD, Split, TaskContext}
private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
@ -32,9 +31,9 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
override def splits = splits_
override def compute(split: Split): Iterator[T] = {
override def compute(split: Split, taskContext: TaskContext): Iterator[T] = {
split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
parentSplit => prev.iterator(parentSplit)
parentSplit => prev.iterator(parentSplit, taskContext)
}
}

View file

@ -1,12 +1,12 @@
package spark.rdd
import spark.OneToOneDependency
import spark.RDD
import spark.Split
import spark.{OneToOneDependency, RDD, Split, TaskContext}
private[spark]
class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = prev.iterator(split).filter(f)
override def compute(split: Split, taskContext: TaskContext) =
prev.iterator(split, taskContext).filter(f)
}

View file

@ -1,16 +1,16 @@
package spark.rdd
import spark.OneToOneDependency
import spark.RDD
import spark.Split
import spark.{OneToOneDependency, RDD, Split, TaskContext}
private[spark]
class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => TraversableOnce[U])
extends RDD[U](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = prev.iterator(split).flatMap(f)
override def compute(split: Split, taskContext: TaskContext) =
prev.iterator(split, taskContext).flatMap(f)
}

View file

@ -1,12 +1,12 @@
package spark.rdd
import spark.OneToOneDependency
import spark.RDD
import spark.Split
import spark.{OneToOneDependency, RDD, Split, TaskContext}
private[spark]
class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator
override def compute(split: Split, taskContext: TaskContext) =
Array(prev.iterator(split, taskContext).toArray).iterator
}

View file

@ -15,19 +15,16 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
import spark.Dependency
import spark.RDD
import spark.SerializableWritable
import spark.SparkContext
import spark.Split
import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext}
/**
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
extends Split
with Serializable {
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@ -47,10 +44,10 @@ class HadoopRDD[K, V](
valueClass: Class[V],
minSplits: Int)
extends RDD[(K, V)](sc) {
// A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
@transient
val splits_ : Array[Split] = {
val inputFormat = createInputFormat(conf)
@ -69,7 +66,7 @@ class HadoopRDD[K, V](
override def splits = splits_
override def compute(theSplit: Split) = new Iterator[(K, V)] {
override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopSplit]
var reader: RecordReader[K, V] = null
@ -77,6 +74,9 @@ class HadoopRDD[K, V](
val fmt = createInputFormat(conf)
reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
taskContext.registerOnCompleteCallback(Unit => reader.close())
val key: K = reader.createKey()
val value: V = reader.createValue()
var gotNext = false
@ -115,6 +115,6 @@ class HadoopRDD[K, V](
val hadoopSplit = split.asInstanceOf[HadoopSplit]
hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
}
override val dependencies: List[Dependency[_]] = Nil
}

View file

@ -1,8 +1,7 @@
package spark.rdd
import spark.OneToOneDependency
import spark.RDD
import spark.Split
import spark.{OneToOneDependency, RDD, Split, TaskContext}
private[spark]
class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
@ -12,8 +11,9 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
extends RDD[U](prev.context) {
override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(prev.iterator(split))
override def compute(split: Split, taskContext: TaskContext) =
f(prev.iterator(split, taskContext))
}

View file

@ -1,8 +1,6 @@
package spark.rdd
import spark.OneToOneDependency
import spark.RDD
import spark.Split
import spark.{OneToOneDependency, RDD, Split, TaskContext}
/**
* A variant of the MapPartitionsRDD that passes the split index into the
@ -19,5 +17,6 @@ class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(split.index, prev.iterator(split))
override def compute(split: Split, taskContext: TaskContext) =
f(split.index, prev.iterator(split, taskContext))
}

View file

@ -1,16 +1,15 @@
package spark.rdd
import spark.OneToOneDependency
import spark.RDD
import spark.Split
import spark.{OneToOneDependency, RDD, Split, TaskContext}
private[spark]
class MappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => U)
extends RDD[U](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = prev.iterator(split).map(f)
override def compute(split: Split, taskContext: TaskContext) =
prev.iterator(split, taskContext).map(f)
}

View file

@ -1,22 +1,19 @@
package spark.rdd
import java.text.SimpleDateFormat
import java.util.Date
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
import java.util.Date
import java.text.SimpleDateFormat
import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext}
import spark.Dependency
import spark.RDD
import spark.SerializableWritable
import spark.SparkContext
import spark.Split
private[spark]
private[spark]
class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
extends Split {
val serializableHadoopSplit = new SerializableWritable(rawSplit)
override def hashCode(): Int = (41 * (41 + rddId) + index)
@ -29,7 +26,7 @@ class NewHadoopRDD[K, V](
@transient conf: Configuration)
extends RDD[(K, V)](sc)
with HadoopMapReduceUtil {
// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
// private val serializableConf = new SerializableWritable(conf)
@ -56,7 +53,7 @@ class NewHadoopRDD[K, V](
override def splits = splits_
override def compute(theSplit: Split) = new Iterator[(K, V)] {
override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopSplit]
val conf = confBroadcast.value.value
val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
@ -64,7 +61,10 @@ class NewHadoopRDD[K, V](
val format = inputFormatClass.newInstance
val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
reader.initialize(split.serializableHadoopSplit.value, context)
// Register an on-task-completion callback to close the input stream.
taskContext.registerOnCompleteCallback(Unit => reader.close())
var havePair = false
var finished = false
@ -72,9 +72,6 @@ class NewHadoopRDD[K, V](
if (!finished && !havePair) {
finished = !reader.nextKeyValue
havePair = !finished
if (finished) {
reader.close()
}
}
!finished
}

View file

@ -8,10 +8,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import spark.OneToOneDependency
import spark.RDD
import spark.SparkEnv
import spark.Split
import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
/**
@ -32,12 +29,12 @@ class PipedRDD[T: ClassManifest](
override val dependencies = List(new OneToOneDependency(parent))
override def compute(split: Split): Iterator[String] = {
override def compute(split: Split, taskContext: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
// Add the environmental variables to the process.
val currentEnvVars = pb.environment()
envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) }
val proc = pb.start()
val env = SparkEnv.get
@ -55,7 +52,7 @@ class PipedRDD[T: ClassManifest](
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
for (elem <- parent.iterator(split)) {
for (elem <- parent.iterator(split, taskContext)) {
out.println(elem)
}
out.close()

View file

@ -4,9 +4,8 @@ import java.util.Random
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
import spark.RDD
import spark.OneToOneDependency
import spark.Split
import spark.{OneToOneDependency, RDD, Split, TaskContext}
private[spark]
class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
@ -15,7 +14,7 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali
class SampledRDD[T: ClassManifest](
prev: RDD[T],
withReplacement: Boolean,
withReplacement: Boolean,
frac: Double,
seed: Int)
extends RDD[T](prev.context) {
@ -29,17 +28,17 @@ class SampledRDD[T: ClassManifest](
override def splits = splits_.asInstanceOf[Array[Split]]
override val dependencies = List(new OneToOneDependency(prev))
override def preferredLocations(split: Split) =
prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
override def compute(splitIn: Split) = {
override def compute(splitIn: Split, taskContext: TaskContext) = {
val split = splitIn.asInstanceOf[SampledRDDSplit]
if (withReplacement) {
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
val poisson = new Poisson(frac, new DRand(split.seed))
prev.iterator(split.prev).flatMap { element =>
prev.iterator(split.prev, taskContext).flatMap { element =>
val count = poisson.nextInt()
if (count == 0) {
Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
@ -49,7 +48,7 @@ class SampledRDD[T: ClassManifest](
}
} else { // Sampling without replacement
val rand = new Random(split.seed)
prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac))
prev.iterator(split.prev, taskContext).filter(x => (rand.nextDouble <= frac))
}
}
}

View file

@ -1,10 +1,7 @@
package spark.rdd
import spark.Partitioner
import spark.RDD
import spark.ShuffleDependency
import spark.SparkEnv
import spark.Split
import spark.{OneToOneDependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx
@ -34,7 +31,7 @@ class ShuffledRDD[K, V](
val dep = new ShuffleDependency(parent, part)
override val dependencies = List(dep)
override def compute(split: Split): Iterator[(K, V)] = {
override def compute(split: Split, taskContext: TaskContext): Iterator[(K, V)] = {
SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index)
}
}

View file

@ -2,20 +2,17 @@ package spark.rdd
import scala.collection.mutable.ArrayBuffer
import spark.Dependency
import spark.RangeDependency
import spark.RDD
import spark.SparkContext
import spark.Split
import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
private[spark] class UnionSplit[T: ClassManifest](
idx: Int,
idx: Int,
rdd: RDD[T],
split: Split)
extends Split
with Serializable {
def iterator() = rdd.iterator(split)
def iterator(taskContext: TaskContext) = rdd.iterator(split, taskContext)
def preferredLocations() = rdd.preferredLocations(split)
override val index: Int = idx
}
@ -25,7 +22,7 @@ class UnionRDD[T: ClassManifest](
@transient rdds: Seq[RDD[T]])
extends RDD[T](sc)
with Serializable {
@transient
val splits_ : Array[Split] = {
val array = new Array[Split](rdds.map(_.splits.size).sum)
@ -44,13 +41,14 @@ class UnionRDD[T: ClassManifest](
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
pos += rdd.splits.size
}
deps.toList
}
override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator()
override def compute(s: Split, taskContext: TaskContext): Iterator[T] =
s.asInstanceOf[UnionSplit[T]].iterator(taskContext)
override def preferredLocations(s: Split): Seq[String] =
s.asInstanceOf[UnionSplit[T]].preferredLocations()

View file

@ -1,21 +1,19 @@
package spark.rdd
import spark.Dependency
import spark.OneToOneDependency
import spark.RDD
import spark.SparkContext
import spark.Split
import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext}
private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
idx: Int,
idx: Int,
rdd1: RDD[T],
rdd2: RDD[U],
split1: Split,
split2: Split)
extends Split
with Serializable {
def iterator(): Iterator[(T, U)] = rdd1.iterator(split1).zip(rdd2.iterator(split2))
def iterator(taskContext: TaskContext): Iterator[(T, U)] =
rdd1.iterator(split1, taskContext).zip(rdd2.iterator(split2, taskContext))
def preferredLocations(): Seq[String] =
rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
@ -46,8 +44,9 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
@transient
override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
override def compute(s: Split): Iterator[(T, U)] = s.asInstanceOf[ZippedSplit[T, U]].iterator()
override def compute(s: Split, taskContext: TaskContext): Iterator[(T, U)] =
s.asInstanceOf[ZippedSplit[T, U]].iterator(taskContext)
override def preferredLocations(s: Split): Seq[String] =
s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()

View file

@ -16,8 +16,8 @@ import spark.storage.BlockManagerMaster
import spark.storage.BlockManagerId
/**
* A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
* each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
* A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
* each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
* schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
@ -73,7 +73,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
// that's not going to be a realistic assumption in general
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
@ -94,7 +94,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
cacheLocs(rdd.id)
}
def updateCacheLocs() {
cacheLocs = cacheTracker.getLocationsSnapshot()
}
@ -326,7 +326,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val rdd = job.finalStage.rdd
val split = rdd.splits(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
val result = job.func(taskContext, rdd.iterator(split))
val result = job.func(taskContext, rdd.iterator(split, taskContext))
taskContext.executeOnCompleteCallbacks()
job.listener.taskSucceeded(0, result)
} catch {
case e: Exception =>
@ -353,7 +354,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
}
def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
@ -395,7 +396,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val task = event.task
val stage = idToStage(task.stageId)
event.reason match {
case Success =>
case Success =>
logInfo("Completed " + task)
if (event.accumUpdates != null) {
Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
@ -519,7 +520,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
updateCacheLocs()
}
}
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.

View file

@ -10,12 +10,14 @@ private[spark] class ResultTask[T, U](
@transient locs: Seq[String],
val outputId: Int)
extends Task[U](stageId) {
val split = rdd.splits(partition)
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
func(context, rdd.iterator(split))
val result = func(context, rdd.iterator(split, context))
context.executeOnCompleteCallbacks()
result
}
override def preferredLocations: Seq[String] = locs

View file

@ -70,19 +70,19 @@ private[spark] object ShuffleMapTask {
private[spark] class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
var partition: Int,
var partition: Int,
@transient var locs: Seq[String])
extends Task[MapStatus](stageId)
with Externalizable
with Logging {
def this() = this(0, null, null, 0, null)
var split = if (rdd == null) {
null
} else {
null
} else {
rdd.splits(partition)
}
@ -113,9 +113,11 @@ private[spark] class ShuffleMapTask(
val numOutputSplits = dep.partitioner.numPartitions
val partitioner = dep.partitioner
val taskContext = new TaskContext(stageId, partition, attemptId)
// Partition the map output.
val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
for (elem <- rdd.iterator(split)) {
for (elem <- rdd.iterator(split, taskContext)) {
val pair = elem.asInstanceOf[(Any, Any)]
val bucketId = partitioner.getPartition(pair._1)
buckets(bucketId) += pair
@ -133,6 +135,9 @@ private[spark] class ShuffleMapTask(
compressedSizes(i) = MapOutputTracker.compressSize(size)
}
// Execute the callbacks on task completion.
taskContext.executeOnCompleteCallbacks()
return new MapStatus(blockManager.blockManagerId, compressedSizes)
}