Mirror of https://github.com/microsoft/spark.git
Pass cache locations correctly to DAGScheduler.
Parent: e1436f1eaa
Commit: 1df5a65a01
@@ -33,20 +33,14 @@ private abstract class DAGScheduler extends Scheduler with Logging {
 
   val shuffleToMapStage = new HashMap[ShuffleDependency[_,_,_], Stage]
 
-  val cacheLocs = new HashMap[RDD[_], Array[List[String]]]
+  var cacheLocs = new HashMap[Int, Array[List[String]]]
 
   def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
-    cacheLocs.getOrElseUpdate(rdd, Array.fill[List[String]](rdd.splits.size)(Nil))
+    cacheLocs(rdd.id)
   }
 
-  def addCacheLoc(rdd: RDD[_], partition: Int, host: String) {
-    val locs = getCacheLocs(rdd)
-    locs(partition) = host :: locs(partition)
-  }
-
-  def removeCacheLoc(rdd: RDD[_], partition: Int, host: String) {
-    val locs = getCacheLocs(rdd)
-    locs(partition) -= host
+  def updateCacheLocs() {
+    cacheLocs = RDDCache.getLocationsSnapshot()
   }
 
   def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {

@@ -60,6 +54,9 @@ private abstract class DAGScheduler extends Scheduler with Logging {
   }
 
   def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
+    // Kind of ugly: need to register RDDs with the cache here since
+    // we can't do it in its constructor because # of splits is unknown
+    RDDCache.registerRDD(rdd.id, rdd.splits.size)
     val id = newStageId()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
     idToStage(id) = stage

@@ -113,10 +110,10 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     missing.toList
   }
 
-  override def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
+  override def runJob[T, U](finalRdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
       : Array[U] = {
-    val numOutputParts: Int = rdd.splits.size
-    val finalStage = newStage(rdd, None)
+    val numOutputParts: Int = finalRdd.splits.size
+    val finalStage = newStage(finalRdd, None)
     val results = new Array[U](numOutputParts)
     val finished = new Array[Boolean](numOutputParts)
     var numFinished = 0

@@ -125,6 +122,8 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     val running = new HashSet[Stage]
     val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
 
+    updateCacheLocs()
+
     logInfo("Final stage: " + finalStage)
     logInfo("Parents of final stage: " + finalStage.parents)
     logInfo("Missing parents: " + getMissingParentStages(finalStage))

@@ -145,12 +144,13 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     }
 
     def submitMissingTasks(stage: Stage) {
       // Get our pending tasks and remember them in our pendingTasks entry
       val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
       var tasks = ArrayBuffer[Task[_]]()
       if (stage == finalStage) {
         for (p <- 0 until numOutputParts if (!finished(p))) {
-          val locs = getPreferredLocs(rdd, p)
-          tasks += new ResultTask(finalStage.id, rdd, func, p, locs)
+          val locs = getPreferredLocs(finalRdd, p)
+          tasks += new ResultTask(finalStage.id, finalRdd, func, p, locs)
         }
       } else {
         for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {

@@ -186,6 +186,7 @@ private abstract class DAGScheduler extends Scheduler with Logging {
         if (pending.isEmpty) {
           logInfo(stage + " finished; looking for newly runnable stages")
           running -= stage
+          updateCacheLocs()
           val newlyRunnable = new ArrayBuffer[Stage]
           for (stage <- waiting if getMissingParentStages(stage) == Nil) {
             newlyRunnable += stage
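Note: taken together, the DAGScheduler hunks above drop the scheduler's own addCacheLoc/removeCacheLoc bookkeeping in favor of a map keyed by RDD id that is replaced wholesale from RDDCache.getLocationsSnapshot() at job start and whenever a stage finishes. A standalone Scala sketch of that read path follows; it is not Spark code, and fetchSnapshot(), the hard-coded RDD id, and the host name are illustrative stand-ins.

    import scala.collection.mutable.HashMap

    // Standalone sketch of the scheduler-side read path after this commit.
    // fetchSnapshot() stands in for RDDCache.getLocationsSnapshot().
    object SchedulerCacheLocsSketch {
      // rddId -> array of per-split host lists, replaced wholesale by updateCacheLocs()
      var cacheLocs = new HashMap[Int, Array[List[String]]]

      def fetchSnapshot(): HashMap[Int, Array[List[String]]] = {
        val snap = new HashMap[Int, Array[List[String]]]
        snap(0) = Array(List("host-a"), Nil)   // pretend RDD 0 has split 0 cached on host-a
        snap
      }

      // Same shape as the new updateCacheLocs() / getCacheLocs(rdd) pair in the diff,
      // except that getCacheLocs takes the id directly instead of an RDD[_]
      def updateCacheLocs() { cacheLocs = fetchSnapshot() }
      def getCacheLocs(rddId: Int): Array[List[String]] = cacheLocs(rddId)

      def main(args: Array[String]) {
        updateCacheLocs()                      // done at job start and when a stage finishes
        println(getCacheLocs(0).toList)        // List(List(host-a), List())
      }
    }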
@@ -11,7 +11,7 @@ class MapOutputTracker extends DaemonActor with Logging {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('MapOutputTracker, self)
-    logInfo("Started on port " + port)
+    logInfo("Registered actor on port " + port)
   }
 }
@@ -3,31 +3,57 @@ package spark
 import scala.actors._
 import scala.actors.Actor._
 import scala.actors.remote._
 import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
 
+sealed trait CacheMessage
-case class CacheEntryAdded(rddId: Int, partition: Int, host: String)
-case class CacheEntryRemoved(rddId: Int, partition: Int, host: String)
+case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class MemoryCacheLost(host: String) extends CacheMessage
+case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
+case object GetCacheLocations extends CacheMessage
 
 class RDDCacheTracker extends DaemonActor with Logging {
+  val locs = new HashMap[Int, Array[List[String]]]
+  // TODO: Should probably store (String, CacheType) tuples
+
   def act() {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('RDDCacheTracker, self)
-    logInfo("Started on port " + port)
+    logInfo("Registered actor on port " + port)
 
     loop {
       react {
-        case CacheEntryAdded(rddId, partition, host) =>
-          logInfo("Cache entry added: %s, %s, %s".format(rddId, partition, host))
+        case RegisterRDD(rddId: Int, numPartitions: Int) =>
+          logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
+          locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
+          reply("")
+
+        case AddedToCache(rddId, partition, host) =>
+          logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) = host :: locs(rddId)(partition)
 
-        case CacheEntryRemoved(rddId, partition, host) =>
-          logInfo("Cache entry removed: %s, %s, %s".format(rddId, partition, host))
+        case DroppedFromCache(rddId, partition, host) =>
+          logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) -= host
+
+        case MemoryCacheLost(host) =>
+          logInfo("Memory cache lost on " + host)
+          // TODO: Drop host from the memory locations list of all RDDs
+
+        case GetCacheLocations =>
+          logInfo("Asked for current cache locations")
+          val locsCopy = new HashMap[Int, Array[List[String]]]
+          for ((rddId, array) <- locs) {
+            locsCopy(rddId) = array.clone()
+          }
+          reply(locsCopy)
       }
     }
   }
 }
 
-import scala.collection.mutable.HashSet
 private object RDDCache extends Logging {
   // Stores map results for various splits locally
   val cache = Cache.newKeySpace()

@@ -38,6 +64,8 @@ private object RDDCache extends Logging {
   // Tracker actor on the master, or remote reference to it on workers
   var trackerActor: AbstractActor = null
 
+  val registeredRddIds = new HashSet[Int]
+
   def initialize(isMaster: Boolean) {
     if (isMaster) {
       val tracker = new RDDCacheTracker

@@ -50,16 +78,34 @@ private object RDDCache extends Logging {
     }
   }
 
+  // Registers an RDD (on master only)
+  def registerRDD(rddId: Int, numPartitions: Int) {
+    registeredRddIds.synchronized {
+      if (!registeredRddIds.contains(rddId)) {
+        registeredRddIds += rddId
+        trackerActor !? RegisterRDD(rddId, numPartitions)
+      }
+    }
+  }
+
+  // Get a snapshot of the currently known locations
+  def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
+    (trackerActor !? GetCacheLocations) match {
+      case h: HashMap[Int, Array[List[String]]] => h
+      case _ => throw new SparkException(
+          "Internal error: RDDCache did not reply with a HashMap")
+    }
+  }
+
   // Gets or computes an RDD split
   def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T])
       : Iterator[T] = {
     val key = (rdd.id, split.index)
-    logInfo("CachedRDD split key is " + key)
     val cache = RDDCache.cache
     val loading = RDDCache.loading
+    logInfo("CachedRDD partition key is " + key)
     val cachedVal = cache.get(key)
     if (cachedVal != null) {
       // Split is in cache, so just return its values
+      logInfo("Found partition in cache!")
       return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
     } else {
       // Mark the split as loading (unless someone else marks it first)

@@ -73,13 +119,13 @@ private object RDDCache extends Logging {
           loading.add(key)
         }
       }
-      val host = System.getProperty("spark.hostname", Utils.localHostName)
-      trackerActor ! CacheEntryAdded(rdd.id, split.index, host)
       // If we got here, we have to load the split
+      // Tell the master that we're doing so
+      val host = System.getProperty("spark.hostname", Utils.localHostName)
+      trackerActor ! AddedToCache(rdd.id, split.index, host)
      // TODO: fetch any remote copy of the split that may be available
-      // TODO: also notify the master that we're loading it
      // TODO: also register a listener for when it unloads
-      logInfo("Computing and caching " + split)
+      logInfo("Computing partition " + split)
      val array = rdd.compute(split).toArray(m)
      cache.put(key, array)
      loading.synchronized {
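Note: the new tracker protocol above is small: RegisterRDD sizes the per-RDD location arrays, AddedToCache/DroppedFromCache update them, MemoryCacheLost is left as a TODO, and GetCacheLocations replies with a copy that getLocationsSnapshot() hands to the scheduler. Below is a standalone sketch of the same handling with the remote-actor plumbing stripped out; the object name, handle() method, and main() driver are illustrative only, not Spark code.

    import scala.collection.mutable.HashMap

    sealed trait CacheMessage
    case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
    case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
    case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
    case class MemoryCacheLost(host: String) extends CacheMessage
    case object GetCacheLocations extends CacheMessage

    object TrackerSketch {
      val locs = new HashMap[Int, Array[List[String]]]

      // Plays the role of the react block above, minus actors and logging
      def handle(msg: CacheMessage): Any = msg match {
        case RegisterRDD(rddId, numPartitions) =>
          locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
          ""                                              // the real tracker replies "" as an ack
        case AddedToCache(rddId, partition, host) =>
          locs(rddId)(partition) = host :: locs(rddId)(partition)
        case DroppedFromCache(rddId, partition, host) =>
          locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
        case MemoryCacheLost(host) =>
          for ((_, array) <- locs; p <- array.indices)    // the diff leaves this case as a TODO;
            array(p) = array(p).filterNot(_ == host)      // this is one plausible way to do it
        case GetCacheLocations =>
          val locsCopy = new HashMap[Int, Array[List[String]]]
          for ((rddId, array) <- locs) locsCopy(rddId) = array.clone()
          locsCopy                                        // copied so the caller can read it safely
      }

      def main(args: Array[String]) {
        handle(RegisterRDD(0, 2))
        handle(AddedToCache(0, 1, "host-a"))
        val snapshot = handle(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
        println(snapshot(0).toList)                       // List(List(), List(host-a))
      }
    }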
@@ -175,6 +175,7 @@ extends Logging {
 
   private var nextRddId = new AtomicInteger(0)
 
+  // Register a new RDD, returning its RDD ID
   private[spark] def newRddId(): Int = {
     nextRddId.getAndIncrement()
   }
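Note: nextRddId above is an AtomicInteger, so newRddId() can hand out unique, increasing ids even if RDDs are created from several threads. A tiny standalone sketch of the pattern (the object name is illustrative):

    import java.util.concurrent.atomic.AtomicInteger

    // getAndIncrement() returns the current value and bumps it atomically
    object RddIdSketch {
      private val nextRddId = new AtomicInteger(0)
      def newRddId(): Int = nextRddId.getAndIncrement()

      def main(args: Array[String]) {
        println(Seq.fill(3)(newRddId()))   // List(0, 1, 2)
      }
    }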