Pass cache locations correctly to DAGScheduler.

Matei Zaharia 2011-03-06 12:16:38 -08:00
Parent e1436f1eaa
Commit 1df5a65a01
4 changed files: 80 additions and 32 deletions

View file

@@ -33,20 +33,14 @@ private abstract class DAGScheduler extends Scheduler with Logging {
   val shuffleToMapStage = new HashMap[ShuffleDependency[_,_,_], Stage]

-  val cacheLocs = new HashMap[RDD[_], Array[List[String]]]
+  var cacheLocs = new HashMap[Int, Array[List[String]]]

   def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
-    cacheLocs.getOrElseUpdate(rdd, Array.fill[List[String]](rdd.splits.size)(Nil))
+    cacheLocs(rdd.id)
   }

-  def addCacheLoc(rdd: RDD[_], partition: Int, host: String) {
-    val locs = getCacheLocs(rdd)
-    locs(partition) = host :: locs(partition)
-  }
-
-  def removeCacheLoc(rdd: RDD[_], partition: Int, host: String) {
-    val locs = getCacheLocs(rdd)
-    locs(partition) -= host
+  def updateCacheLocs() {
+    cacheLocs = RDDCache.getLocationsSnapshot()
   }

   def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
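With this hunk the scheduler stops maintaining per-RDD location lists itself: everything is keyed by rdd.id, and cacheLocs is swapped wholesale for snapshots fetched from the cache tracker. A minimal, illustrative sketch of how such a snapshot is read (isCached and preferredHosts are hypothetical helpers, not methods introduced by this commit):

    import scala.collection.mutable.HashMap

    object CacheLocsExample {
      var cacheLocs = new HashMap[Int, Array[List[String]]]   // rddId -> (split index -> hosts)

      // A split counts as cached if at least one host reports holding it
      def isCached(rddId: Int, split: Int): Boolean =
        cacheLocs.contains(rddId) && cacheLocs(rddId)(split) != Nil

      // Hosts to prefer when placing a task for this split (possibly empty)
      def preferredHosts(rddId: Int, split: Int): List[String] =
        if (cacheLocs.contains(rddId)) cacheLocs(rddId)(split) else Nil
    }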
@@ -60,6 +54,9 @@ private abstract class DAGScheduler extends Scheduler with Logging {
   }

   def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
+    // Kind of ugly: need to register RDDs with the cache here since
+    // we can't do it in its constructor because # of splits is unknown
+    RDDCache.registerRDD(rdd.id, rdd.splits.size)
     val id = newStageId()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
     idToStage(id) = stage
@@ -113,10 +110,10 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     missing.toList
   }

-  override def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
+  override def runJob[T, U](finalRdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
       : Array[U] = {
-    val numOutputParts: Int = rdd.splits.size
-    val finalStage = newStage(rdd, None)
+    val numOutputParts: Int = finalRdd.splits.size
+    val finalStage = newStage(finalRdd, None)
     val results = new Array[U](numOutputParts)
     val finished = new Array[Boolean](numOutputParts)
     var numFinished = 0
@@ -125,6 +122,8 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     val running = new HashSet[Stage]
     val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]

+    updateCacheLocs()
+
     logInfo("Final stage: " + finalStage)
     logInfo("Parents of final stage: " + finalStage.parents)
     logInfo("Missing parents: " + getMissingParentStages(finalStage))
@@ -145,12 +144,13 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     }

     def submitMissingTasks(stage: Stage) {
       // Get our pending tasks and remember them in our pendingTasks entry
       val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
       var tasks = ArrayBuffer[Task[_]]()
       if (stage == finalStage) {
         for (p <- 0 until numOutputParts if (!finished(p))) {
-          val locs = getPreferredLocs(rdd, p)
-          tasks += new ResultTask(finalStage.id, rdd, func, p, locs)
+          val locs = getPreferredLocs(finalRdd, p)
+          tasks += new ResultTask(finalStage.id, finalRdd, func, p, locs)
         }
       } else {
         for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
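Result tasks for the final stage are now built against finalRdd, and their placement comes from getPreferredLocs, whose body is not part of this diff. A plausible sketch only, under the assumption that it prefers hosts that already cache the split and otherwise falls back to the RDD's own preferred locations (the real method may differ):

    // Illustrative sketch, not the actual getPreferredLocs from this commit
    def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
      val cached = getCacheLocs(rdd)(partition)
      if (cached != Nil) cached
      else rdd.preferredLocations(rdd.splits(partition)).toList
    }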
@@ -186,6 +186,7 @@ private abstract class DAGScheduler extends Scheduler with Logging {
         if (pending.isEmpty) {
           logInfo(stage + " finished; looking for newly runnable stages")
           running -= stage
+          updateCacheLocs()
           val newlyRunnable = new ArrayBuffer[Stage]
           for (stage <- waiting if getMissingParentStages(stage) == Nil) {
             newlyRunnable += stage

View file

@@ -11,7 +11,7 @@ class MapOutputTracker extends DaemonActor with Logging {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('MapOutputTracker, self)
-    logInfo("Started on port " + port)
+    logInfo("Registered actor on port " + port)
   }
 }
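Both trackers follow the same scala.actors pattern: the master makes itself reachable with RemoteActor.alive and registers the actor under a symbolic name, and a worker obtains a remote reference by selecting that name on the master's node. A minimal sketch of the worker side (the "spark.master.host" property and default values are assumptions for illustration):

    import scala.actors.remote.{Node, RemoteActor}

    object TrackerLookupExample {
      // Worker-side counterpart of the alive/register calls above
      def lookupTracker() = {
        val host = System.getProperty("spark.master.host", "localhost")
        val port = System.getProperty("spark.master.port", "50501").toInt
        RemoteActor.select(Node(host, port), 'MapOutputTracker)
      }
    }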

View file

@@ -3,31 +3,57 @@ package spark
 import scala.actors._
 import scala.actors.Actor._
 import scala.actors.remote._
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet

 sealed trait CacheMessage
-case class CacheEntryAdded(rddId: Int, partition: Int, host: String)
-case class CacheEntryRemoved(rddId: Int, partition: Int, host: String)
+case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class MemoryCacheLost(host: String) extends CacheMessage
+case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
+case object GetCacheLocations extends CacheMessage

 class RDDCacheTracker extends DaemonActor with Logging {
+  val locs = new HashMap[Int, Array[List[String]]]
+  // TODO: Should probably store (String, CacheType) tuples
+
   def act() {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('RDDCacheTracker, self)
-    logInfo("Started on port " + port)
+    logInfo("Registered actor on port " + port)

     loop {
       react {
-        case CacheEntryAdded(rddId, partition, host) =>
-          logInfo("Cache entry added: %s, %s, %s".format(rddId, partition, host))
+        case RegisterRDD(rddId: Int, numPartitions: Int) =>
+          logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
+          locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
+          reply("")
+
+        case AddedToCache(rddId, partition, host) =>
+          logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) = host :: locs(rddId)(partition)

-        case CacheEntryRemoved(rddId, partition, host) =>
-          logInfo("Cache entry removed: %s, %s, %s".format(rddId, partition, host))
+        case DroppedFromCache(rddId, partition, host) =>
+          logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) -= host
+
+        case MemoryCacheLost(host) =>
+          logInfo("Memory cache lost on " + host)
+          // TODO: Drop host from the memory locations list of all RDDs
+
+        case GetCacheLocations =>
+          logInfo("Asked for current cache locations")
+          val locsCopy = new HashMap[Int, Array[List[String]]]
+          for ((rddId, array) <- locs) {
+            locsCopy(rddId) = array.clone()
+          }
+          reply(locsCopy)
       }
     }
   }
 }

 import scala.collection.mutable.HashSet

 private object RDDCache extends Logging {
   // Stores map results for various splits locally
   val cache = Cache.newKeySpace()
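From a client's point of view the tracker now speaks a small protocol: RegisterRDD and GetCacheLocations are request/reply (sent with !?, since the sender needs the acknowledgement or the data), while AddedToCache and DroppedFromCache are fire-and-forget notifications (sent with !). A hedged sketch of a client exercising it, assuming trackerActor already points at the tracker and the ids and host name are placeholders:

    import scala.collection.mutable.HashMap

    // Illustrative client-side traffic against the RDDCacheTracker defined above
    def exerciseTracker(trackerActor: scala.actors.AbstractActor) {
      trackerActor !? RegisterRDD(1, 4)             // blocks until the tracker replies
      trackerActor ! AddedToCache(1, 0, "host-a")   // asynchronous notification
      trackerActor ! DroppedFromCache(1, 0, "host-a")
      val locs = (trackerActor !? GetCacheLocations)
        .asInstanceOf[HashMap[Int, Array[List[String]]]]
    }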
@@ -38,6 +64,8 @@ private object RDDCache extends Logging {
   // Tracker actor on the master, or remote reference to it on workers
   var trackerActor: AbstractActor = null

+  val registeredRddIds = new HashSet[Int]
+
   def initialize(isMaster: Boolean) {
     if (isMaster) {
       val tracker = new RDDCacheTracker
@@ -50,16 +78,34 @@ private object RDDCache extends Logging {
     }
   }

+  // Registers an RDD (on master only)
+  def registerRDD(rddId: Int, numPartitions: Int) {
+    registeredRddIds.synchronized {
+      if (!registeredRddIds.contains(rddId)) {
+        registeredRddIds += rddId
+        trackerActor !? RegisterRDD(rddId, numPartitions)
+      }
+    }
+  }
+
+  // Get a snapshot of the currently known locations
+  def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
+    (trackerActor !? GetCacheLocations) match {
+      case h: HashMap[Int, Array[List[String]]] => h
+      case _ => throw new SparkException(
+        "Internal error: RDDCache did not reply with a HashMap")
+    }
+  }
+
   // Gets or computes an RDD split
   def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T])
       : Iterator[T] = {
     val key = (rdd.id, split.index)
-    logInfo("CachedRDD split key is " + key)
     val cache = RDDCache.cache
     val loading = RDDCache.loading
+    logInfo("CachedRDD partition key is " + key)
     val cachedVal = cache.get(key)
     if (cachedVal != null) {
       // Split is in cache, so just return its values
+      logInfo("Found partition in cache!")
       return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
     } else {
       // Mark the split as loading (unless someone else marks it first)
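registerRDD is written so that newStage can call it unconditionally each time a stage is created: the registeredRddIds set plus the synchronized block turn repeat calls into no-ops, and only the first registration pays the blocking !? round trip to the tracker. The same guard pattern in isolation (names here are illustrative, not from the commit):

    import scala.collection.mutable.HashSet

    object OncePerId {
      private val seen = new HashSet[Int]

      // Run the given action at most once per id, even with concurrent callers
      def ensure(id: Int)(action: => Unit) {
        seen.synchronized {
          if (!seen.contains(id)) {
            seen += id
            action
          }
        }
      }
    }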
@@ -73,13 +119,13 @@ private object RDDCache extends Logging {
           loading.add(key)
         }
       }
-      val host = System.getProperty("spark.hostname", Utils.localHostName)
-      trackerActor ! CacheEntryAdded(rdd.id, split.index, host)
+      // If we got here, we have to load the split
+      // Tell the master that we're doing so
+      val host = System.getProperty("spark.hostname", Utils.localHostName)
+      trackerActor ! AddedToCache(rdd.id, split.index, host)
       // TODO: fetch any remote copy of the split that may be available
-      // TODO: also notify the master that we're loading it
       // TODO: also register a listener for when it unloads
-      logInfo("Computing and caching " + split)
+      logInfo("Computing partition " + split)
       val array = rdd.compute(split).toArray(m)
       cache.put(key, array)
       loading.synchronized {
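The surrounding getOrCompute (only partly visible in this hunk) is a cache-aside routine with a "loading" set that keeps two threads on one node from computing the same split at once: check the cache, otherwise mark the key as loading, tell the master via AddedToCache, compute, store, and wake any waiters. A condensed sketch of that flow, using a plain HashMap in place of the real Cache/KeySpace classes and omitting the tracker message:

    import scala.collection.mutable.{HashMap, HashSet}

    // Illustrative version of the cache-aside + "loading" guard pattern
    object SplitCacheSketch {
      private val cache = new HashMap[(Int, Int), Any]   // (rddId, splitIndex) -> computed array
      private val loading = new HashSet[(Int, Int)]

      def getOrCompute[T](key: (Int, Int))(compute: => Array[T]): Array[T] = {
        cache.synchronized {
          if (cache.contains(key))
            return cache(key).asInstanceOf[Array[T]]
        }
        loading.synchronized {
          if (loading.contains(key)) {
            // Someone else is computing this split: wait for it, then reuse its result
            while (loading.contains(key)) loading.wait()
            return cache.synchronized { cache(key).asInstanceOf[Array[T]] }
          }
          loading.add(key)
        }
        // We are the loader: compute, publish, then wake up any waiters
        val result = compute
        cache.synchronized { cache.put(key, result) }
        loading.synchronized {
          loading.remove(key)
          loading.notifyAll()
        }
        result
      }
    }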

View file

@@ -175,6 +175,7 @@ extends Logging {
   private var nextRddId = new AtomicInteger(0)

+  // Register a new RDD, returning its RDD ID
   private[spark] def newRddId(): Int = {
     nextRddId.getAndIncrement()
   }