Merge in engine improvements from the Spark Streaming project, developed
jointly with Tathagata Das and Haoyuan Li. This commit imports the changes
and ports them to Mesos 0.9, but does not yet pass unit tests due to
various classes not supporting a graceful stop() yet.
This commit is contained in:
Matei Zaharia 2012-06-07 00:25:47 -07:00
Parent 7e1c97fc4b
Commit 63051dd2bc
86 changed files: 6376 additions and 1411 deletions

View file

@ -11,6 +11,7 @@ import scala.xml.{XML,NodeSeq}
import scala.collection.mutable.ArrayBuffer
import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
object WikipediaPageRankStandalone {
def main(args: Array[String]) {
@ -118,23 +119,23 @@ class WPRSerializer extends spark.Serializer {
}
class WPRSerializerInstance extends SerializerInstance {
def serialize[T](t: T): Array[Byte] = {
def serialize[T](t: T): ByteBuffer = {
throw new UnsupportedOperationException()
}
def deserialize[T](bytes: Array[Byte]): T = {
def deserialize[T](bytes: ByteBuffer): T = {
throw new UnsupportedOperationException()
}
def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
throw new UnsupportedOperationException()
}
def outputStream(s: OutputStream): SerializationStream = {
def serializeStream(s: OutputStream): SerializationStream = {
new WPRSerializationStream(s)
}
def inputStream(s: InputStream): DeserializationStream = {
def deserializeStream(s: InputStream): DeserializationStream = {
new WPRDeserializationStream(s)
}
}

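Taken together with the JavaSerializer and KryoSerializer hunks further down, this hunk implies the SerializerInstance contract now has roughly the following shape (a sketch inferred from the renamed methods, not the actual trait source):

import java.io.{InputStream, OutputStream}
import java.nio.ByteBuffer
import spark.{SerializationStream, DeserializationStream}

// Sketch only: ByteBuffer replaces Array[Byte], and outputStream/inputStream
// are renamed to serializeStream/deserializeStream.
trait SerializerInstanceSketch {
  def serialize[T](t: T): ByteBuffer
  def deserialize[T](bytes: ByteBuffer): T
  def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T
  def serializeStream(s: OutputStream): SerializationStream
  def deserializeStream(s: InputStream): DeserializationStream
}
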
View file

@ -0,0 +1,70 @@
package spark
import java.io.EOFException
import java.net.URL
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import spark.storage.BlockException
import spark.storage.BlockManagerId
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val ser = SparkEnv.get.serializer.newInstance()
val blockManager = SparkEnv.get.blockManager
val startTime = System.currentTimeMillis
val addresses = SparkEnv.get.mapOutputTracker.getServerAddresses(shuffleId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]]
for ((address, index) <- addresses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index
}
val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map {
case (address, splits) =>
(address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId)))
}
try {
val blockOptions = blockManager.get(blocksByAddress)
logDebug("Fetching map output blocks for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
blockOptions.foreach(x => {
val (blockId, blockOption) = x
blockOption match {
case Some(block) => {
val values = block.asInstanceOf[Iterator[Any]]
for(value <- values) {
val v = value.asInstanceOf[(K, V)]
func(v._1, v._2)
}
}
case None => {
throw new BlockException(blockId, "Did not get block " + blockId)
}
}
})
} catch {
case be: BlockException => {
val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r
be.blockId match {
case regex(sId, mId, rId) => {
val address = addresses(mId.toInt)
throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be)
}
case _ => {
throw be
}
}
}
}
}
}

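A hedged usage sketch of the callback-style fetch above, as a reducer might use it (the SparkEnv.shuffleFetcher accessor and the key/value types are assumptions):

import scala.collection.mutable.HashMap
import spark.SparkEnv

// Sketch only: fold the fetched (K, V) pairs of one reduce partition into a local map.
// Assumes SparkEnv exposes the configured ShuffleFetcher as `shuffleFetcher`.
object ShuffleFetchSketch {
  def collectReducePartition(shuffleId: Int, reduceId: Int): HashMap[String, Int] = {
    val combined = new HashMap[String, Int]
    SparkEnv.get.shuffleFetcher.fetch[String, Int](shuffleId, reduceId, (k, v) => {
      combined(k) = combined.getOrElse(k, 0) + v
    })
    combined
  }
}
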
View file

@ -90,7 +90,8 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
SparkEnv.get.cacheTracker.dropEntry(datasetId, partition)
// TODO: remove BoundedMemoryCache
SparkEnv.get.cacheTracker.dropEntry(datasetId.asInstanceOf[(Int, Int)]._2, partition)
}
}

View file

@ -1,11 +1,17 @@
package spark
import scala.actors._
import scala.actors.Actor._
import scala.actors.remote._
import akka.actor._
import akka.actor.Actor
import akka.actor.Actor._
import akka.util.duration._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import spark.storage.BlockManager
import spark.storage.StorageLevel
sealed trait CacheTrackerMessage
case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
@ -18,8 +24,8 @@ case object GetCacheStatus extends CacheTrackerMessage
case object GetCacheLocations extends CacheTrackerMessage
case object StopCacheTracker extends CacheTrackerMessage
class CacheTrackerActor extends DaemonActor with Logging {
class CacheTrackerActor extends Actor with Logging {
// TODO: Should probably store (String, CacheType) tuples
private val locs = new HashMap[Int, Array[List[String]]]
/**
@ -28,109 +34,93 @@ class CacheTrackerActor extends DaemonActor with Logging {
private val slaveCapacity = new HashMap[String, Long]
private val slaveUsage = new HashMap[String, Long]
// TODO: Should probably store (String, CacheType) tuples
private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
def act() {
val port = System.getProperty("spark.master.port").toInt
RemoteActor.alive(port)
RemoteActor.register('CacheTracker, self)
logInfo("Registered actor on port " + port)
loop {
react {
case SlaveCacheStarted(host: String, size: Long) =>
logInfo("Started slave cache (size %s) on %s".format(
Utils.memoryBytesToString(size), host))
slaveCapacity.put(host, size)
slaveUsage.put(host, 0)
reply('OK)
def receive = {
case SlaveCacheStarted(host: String, size: Long) =>
logInfo("Started slave cache (size %s) on %s".format(
Utils.memoryBytesToString(size), host))
slaveCapacity.put(host, size)
slaveUsage.put(host, 0)
self.reply(true)
case RegisterRDD(rddId: Int, numPartitions: Int) =>
logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
reply('OK)
case AddedToCache(rddId, partition, host, size) =>
if (size > 0) {
slaveUsage.put(host, getCacheUsage(host) + size)
logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
Utils.memoryBytesToString(getCacheAvailable(host))))
} else {
logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
}
locs(rddId)(partition) = host :: locs(rddId)(partition)
reply('OK)
case DroppedFromCache(rddId, partition, host, size) =>
if (size > 0) {
logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
Utils.memoryBytesToString(getCacheAvailable(host))))
slaveUsage.put(host, getCacheUsage(host) - size)
case RegisterRDD(rddId: Int, numPartitions: Int) =>
logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
self.reply(true)
// Do a sanity check to make sure usage is greater than 0.
val usage = getCacheUsage(host)
if (usage < 0) {
logError("Cache usage on %s is negative (%d)".format(host, usage))
}
} else {
logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
}
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
reply('OK)
case AddedToCache(rddId, partition, host, size) =>
slaveUsage.put(host, getCacheUsage(host) + size)
logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
Utils.memoryBytesToString(getCacheAvailable(host))))
locs(rddId)(partition) = host :: locs(rddId)(partition)
self.reply(true)
case MemoryCacheLost(host) =>
logInfo("Memory cache lost on " + host)
// TODO: Drop host from the memory locations list of all RDDs
case GetCacheLocations =>
logInfo("Asked for current cache locations")
reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())})
case GetCacheStatus =>
val status = slaveCapacity.map { case (host,capacity) =>
(host, capacity, getCacheUsage(host))
}.toSeq
reply(status)
case StopCacheTracker =>
reply('OK)
exit()
case DroppedFromCache(rddId, partition, host, size) =>
logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
Utils.memoryBytesToString(getCacheAvailable(host))))
slaveUsage.put(host, getCacheUsage(host) - size)
// Do a sanity check to make sure usage is greater than 0.
val usage = getCacheUsage(host)
if (usage < 0) {
logError("Cache usage on %s is negative (%d)".format(host, usage))
}
}
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
self.reply(true)
case MemoryCacheLost(host) =>
logInfo("Memory cache lost on " + host)
for ((id, locations) <- locs) {
for (i <- 0 until locations.length) {
locations(i) = locations(i).filterNot(_ == host)
}
}
self.reply(true)
case GetCacheLocations =>
logInfo("Asked for current cache locations")
self.reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())})
case GetCacheStatus =>
val status = slaveCapacity.map { case (host, capacity) =>
(host, capacity, getCacheUsage(host))
}.toSeq
self.reply(status)
case StopCacheTracker =>
logInfo("CacheTrackerActor Server stopped!")
self.reply(true)
self.exit()
}
}
class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Logging {
// Tracker actor on the master, or remote reference to it on workers
var trackerActor: AbstractActor = null
val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val aName: String = "CacheTracker"
if (isMaster) {
}
var trackerActor: ActorRef = if (isMaster) {
val actor = actorOf(new CacheTrackerActor)
remote.register(aName, actor)
actor.start()
logInfo("Registered CacheTrackerActor actor @ " + ip + ":" + port)
actor
} else {
remote.actorFor(aName, ip, port)
}
val registeredRddIds = new HashSet[Int]
// Stores map results for various splits locally
val cache = theCache.newKeySpace()
if (isMaster) {
val tracker = new CacheTrackerActor
tracker.start()
trackerActor = tracker
} else {
val host = System.getProperty("spark.master.host")
val port = System.getProperty("spark.master.port").toInt
trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker)
}
// Report the cache being started.
trackerActor !? SlaveCacheStarted(Utils.getHost, cache.getCapacity)
// Remembers which splits are currently being loaded (on worker nodes)
val loading = new HashSet[(Int, Int)]
val loading = new HashSet[String]
// Registers an RDD (on master only)
def registerRDD(rddId: Int, numPartitions: Int) {
@ -138,24 +128,33 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
if (!registeredRddIds.contains(rddId)) {
logInfo("Registering RDD ID " + rddId + " with cache")
registeredRddIds += rddId
trackerActor !? RegisterRDD(rddId, numPartitions)
(trackerActor ? RegisterRDD(rddId, numPartitions)).as[Any] match {
case Some(true) =>
logInfo("CacheTracker registerRDD " + RegisterRDD(rddId, numPartitions) + " successfully.")
case Some(oops) =>
logError("CacheTracker registerRDD" + RegisterRDD(rddId, numPartitions) + " failed: " + oops)
case None =>
logError("CacheTracker registerRDD None. " + RegisterRDD(rddId, numPartitions))
throw new SparkException("Internal error: CacheTracker registerRDD None.")
}
}
}
}
// Get a snapshot of the currently known locations
def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
(trackerActor !? GetCacheLocations) match {
case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]]
case _ => throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
// For BlockManager.scala only
def cacheLost(host: String) {
(trackerActor ? MemoryCacheLost(host)).as[Any] match {
case Some(true) =>
logInfo("CacheTracker successfully removed entries on " + host)
case _ =>
logError("CacheTracker did not reply to MemoryCacheLost")
}
}
// Get the usage status of slave caches. Each tuple in the returned sequence
// is in the form of (host name, capacity, usage).
def getCacheStatus(): Seq[(String, Long, Long)] = {
(trackerActor !? GetCacheStatus) match {
(trackerActor ? GetCacheStatus) match {
case h: Seq[(String, Long, Long)] => h.asInstanceOf[Seq[(String, Long, Long)]]
case _ =>
@ -164,75 +163,94 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
}
}
// For BlockManager.scala only
def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) {
(trackerActor ? t).as[Any] match {
case Some(true) =>
logInfo("CacheTracker notifyTheCacheTrackerFromBlockManager successfully.")
case Some(oops) =>
logError("CacheTracker notifyTheCacheTrackerFromBlockManager failed: " + oops)
case None =>
logError("CacheTracker notifyTheCacheTrackerFromBlockManager None.")
}
}
// Get a snapshot of the currently known locations
def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
(trackerActor ? GetCacheLocations).as[Any] match {
case Some(h: HashMap[_, _]) =>
h.asInstanceOf[HashMap[Int, Array[List[String]]]]
case _ =>
throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
}
}
// Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = {
logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index))
val cachedVal = cache.get(rdd.id, split.index)
if (cachedVal != null) {
// Split is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedVal.asInstanceOf[Array[T]].iterator
} else {
// Mark the split as loading (unless someone else marks it first)
val key = (rdd.id, split.index)
loading.synchronized {
while (loading.contains(key)) {
// Someone else is loading it; let's wait for them
try { loading.wait() } catch { case _ => }
}
// See whether someone else has successfully loaded it. The main way this would fail
// is for the RDD-level cache eviction policy if someone else has loaded the same RDD
// partition but we didn't want to make space for it. However, that case is unlikely
// because it's unlikely that two threads would work on the same RDD partition. One
// downside of the current code is that threads wait serially if this does happen.
val cachedVal = cache.get(rdd.id, split.index)
if (cachedVal != null) {
return cachedVal.asInstanceOf[Array[T]].iterator
}
// Nobody's loading it and it's not in the cache; let's load it ourselves
loading.add(key)
}
// If we got here, we have to load the split
// Tell the master that we're doing so
def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
val key = "rdd:%d:%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
// Split is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedValues.asInstanceOf[Iterator[T]]
// TODO: fetch any remote copy of the split that may be available
logInfo("Computing partition " + split)
var array: Array[T] = null
var putResponse: CachePutResponse = null
try {
array = rdd.compute(split).toArray(m)
putResponse = cache.put(rdd.id, split.index, array)
} finally {
// Tell other threads that we've finished our attempt to load the key (whether or not
// we've actually succeeded to put it in the map)
case None =>
// Mark the split as loading (unless someone else marks it first)
loading.synchronized {
loading.remove(key)
loading.notifyAll()
if (loading.contains(key)) {
logInfo("Loading contains " + key + ", waiting...")
while (loading.contains(key)) {
try {loading.wait()} catch {case _ =>}
}
logInfo("Loading no longer contains " + key + ", so returning cached result")
// See whether someone else has successfully loaded it. The main way this would fail
// is for the RDD-level cache eviction policy if someone else has loaded the same RDD
// partition but we didn't want to make space for it. However, that case is unlikely
// because it's unlikely that two threads would work on the same RDD partition. One
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
case None =>
logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
loading.add(key)
}
} else {
loading.add(key)
}
}
}
putResponse match {
case CachePutSuccess(size) => {
// Tell the master that we added the entry. Don't return until it
// replies so it can properly schedule future tasks that use this RDD.
trackerActor !? AddedToCache(rdd.id, split.index, Utils.getHost, size)
// If we got here, we have to load the split
// Tell the master that we're doing so
//val host = System.getProperty("spark.hostname", Utils.localHostName)
//val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
// TODO: fetch any remote copy of the split that may be available
// TODO: also register a listener for when it unloads
logInfo("Computing partition " + split)
try {
val values = new ArrayBuffer[Any]
values ++= rdd.compute(split)
blockManager.put(key, values.iterator, storageLevel, false)
//future.apply() // Wait for the reply from the cache tracker
return values.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
case _ => null
}
return array.iterator
}
}
// Called by the Cache to report that an entry has been dropped from it
def dropEntry(datasetId: Any, partition: Int) {
datasetId match {
//TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here.
case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost)
}
def dropEntry(rddId: Int, partition: Int) {
//TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here.
trackerActor !! DroppedFromCache(rddId, partition, Utils.localHostName())
}
def stop() {
trackerActor !? StopCacheTracker
trackerActor !! StopCacheTracker
registeredRddIds.clear()
trackerActor = null
}

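A hedged sketch of how a cached partition is now materialized through the BlockManager rather than the old key-space Cache (the storage level constant comes from the Kryo registration list later in this commit; the wrapper itself is hypothetical):

import spark.{RDD, Split, SparkEnv}
import spark.storage.StorageLevel

// Sketch only: compute-or-fetch one partition through the new CacheTracker, which
// keys blocks as "rdd:<rddId>:<splitIndex>" and stores values in the BlockManager.
object CachedIteratorSketch {
  def cachedIterator[T](rdd: RDD[T], split: Split): Iterator[T] =
    SparkEnv.get.cacheTracker.getOrCompute(rdd, split, StorageLevel.MEMORY_ONLY_DESER)
}
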
View file

@ -22,11 +22,12 @@ class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
val aggr = new CoGroupAggregator
@transient
override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
@ -67,9 +68,10 @@ class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size
val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
map.getOrElseUpdate(k, Array.fill(rdds.size)(new ArrayBuffer[Any]))
map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))
}
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, itsSplit) => {

View file

@ -1,374 +0,0 @@
package spark
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
/**
* A task created by the DAG scheduler. Knows its stage ID and map output tracker generation.
*/
abstract class DAGTask[T](val runId: Int, val stageId: Int) extends Task[T] {
val gen = SparkEnv.get.mapOutputTracker.getGeneration
override def generation: Option[Long] = Some(gen)
}
/**
* A completion event passed by the underlying task scheduler to the DAG scheduler.
*/
case class CompletionEvent(
task: DAGTask[_],
reason: TaskEndReason,
result: Any,
accumUpdates: Map[Long, Any])
/**
* Various possible reasons why a DAG task ended. The underlying scheduler is supposed to retry
* tasks several times for "ephemeral" failures, and only report back failures that require some
* old stages to be resubmitted, such as shuffle map fetch failures.
*/
sealed trait TaskEndReason
case object Success extends TaskEndReason
case class FetchFailed(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
case class ExceptionFailure(exception: Throwable) extends TaskEndReason
case class OtherFailure(message: String) extends TaskEndReason
/**
* A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
* each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
* schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
private trait DAGScheduler extends Scheduler with Logging {
// Must be implemented by subclasses to start running a set of tasks. The subclass should also
// attempt to run different sets of tasks in the order given by runId (lower values first).
def submitTasks(tasks: Seq[Task[_]], runId: Int): Unit
// Must be called by subclasses to report task completions or failures.
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]) {
lock.synchronized {
val dagTask = task.asInstanceOf[DAGTask[_]]
eventQueues.get(dagTask.runId) match {
case Some(queue) =>
queue += CompletionEvent(dagTask, reason, result, accumUpdates)
lock.notifyAll()
case None =>
logInfo("Ignoring completion event for DAG job " + dagTask.runId + " because it's gone")
}
}
}
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
val RESUBMIT_TIMEOUT = 2000L
// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
val POLL_TIMEOUT = 500L
private val lock = new Object // Used for access to the entire DAGScheduler
private val eventQueues = new HashMap[Int, Queue[CompletionEvent]] // Indexed by run ID
val nextRunId = new AtomicInteger(0)
val nextStageId = new AtomicInteger(0)
val idToStage = new HashMap[Int, Stage]
val shuffleToMapStage = new HashMap[Int, Stage]
var cacheLocs = new HashMap[Int, Array[List[String]]]
val env = SparkEnv.get
val cacheTracker = env.cacheTracker
val mapOutputTracker = env.mapOutputTracker
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
cacheLocs(rdd.id)
}
def updateCacheLocs() {
cacheLocs = cacheTracker.getLocationsSnapshot()
}
def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
shuffleToMapStage.get(shuf.shuffleId) match {
case Some(stage) => stage
case None =>
val stage = newStage(shuf.rdd, Some(shuf))
shuffleToMapStage(shuf.shuffleId) = stage
stage
}
}
def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
cacheTracker.registerRDD(rdd.id, rdd.splits.size)
if (shuffleDep != None) {
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
}
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
idToStage(id) = stage
stage
}
def getParentStages(rdd: RDD[_]): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(r: RDD[_]) {
if (!visited(r)) {
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_,_] =>
parents += getShuffleMapStage(shufDep)
case _ =>
visit(dep.rdd)
}
}
}
}
visit(rdd)
parents.toList
}
def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
val locs = getCacheLocs(rdd)
for (p <- 0 until rdd.splits.size) {
if (locs(p) == Nil) {
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_,_] =>
val stage = getShuffleMapStage(shufDep)
if (!stage.isAvailable) {
missing += stage
}
case narrowDep: NarrowDependency[_] =>
visit(narrowDep.rdd)
}
}
}
}
}
}
visit(stage.rdd)
missing.toList
}
override def runJob[T, U](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean)
(implicit m: ClassManifest[U]): Array[U] = {
lock.synchronized {
val runId = nextRunId.getAndIncrement()
val outputParts = partitions.toArray
val numOutputParts: Int = partitions.size
val finalStage = newStage(finalRdd, None)
val results = new Array[U](numOutputParts)
val finished = new Array[Boolean](numOutputParts)
var numFinished = 0
val waiting = new HashSet[Stage] // stages we need to run whose parents aren't done
val running = new HashSet[Stage] // stages we are running right now
val failed = new HashSet[Stage] // stages that must be resubmitted due to fetch failures
val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // missing tasks from each stage
var lastFetchFailureTime: Long = 0 // used to wait a bit to avoid repeated resubmits
SparkEnv.set(env)
updateCacheLocs()
logInfo("Final stage: " + finalStage)
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
// Optimization for short actions like first() and take() that can be computed locally
// without shipping tasks to the cluster.
if (allowLocal && finalStage.parents.size == 0 && numOutputParts == 1) {
logInfo("Computing the requested partition locally")
val split = finalRdd.splits(outputParts(0))
val taskContext = new TaskContext(finalStage.id, outputParts(0), 0)
return Array(func(taskContext, finalRdd.iterator(split)))
}
// Register the job ID so that we can get completion events for it
eventQueues(runId) = new Queue[CompletionEvent]
def submitStage(stage: Stage) {
if (!waiting(stage) && !running(stage)) {
val missing = getMissingParentStages(stage)
if (missing == Nil) {
logInfo("Submitting " + stage + ", which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
for (parent <- missing) {
submitStage(parent)
}
waiting += stage
}
}
}
def submitMissingTasks(stage: Stage) {
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
var tasks = ArrayBuffer[Task[_]]()
if (stage == finalStage) {
for (id <- 0 until numOutputParts if (!finished(id))) {
val part = outputParts(id)
val locs = getPreferredLocs(finalRdd, part)
tasks += new ResultTask(runId, finalStage.id, finalRdd, func, part, locs, id)
}
} else {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
val locs = getPreferredLocs(stage.rdd, p)
tasks += new ShuffleMapTask(runId, stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
}
}
myPending ++= tasks
submitTasks(tasks, runId)
}
submitStage(finalStage)
while (numFinished != numOutputParts) {
val eventOption = waitForEvent(runId, POLL_TIMEOUT)
val time = System.currentTimeMillis // TODO: use a pluggable clock for testability
// If we got an event off the queue, mark the task done or react to a fetch failure
if (eventOption != None) {
val evt = eventOption.get
val stage = idToStage(evt.task.stageId)
pendingTasks(stage) -= evt.task
if (evt.reason == Success) {
// A task ended
logInfo("Completed " + evt.task)
Accumulators.add(evt.accumUpdates)
evt.task match {
case rt: ResultTask[_, _] =>
results(rt.outputId) = evt.result.asInstanceOf[U]
finished(rt.outputId) = true
numFinished += 1
case smt: ShuffleMapTask =>
val stage = idToStage(smt.stageId)
stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String])
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
logInfo(stage + " finished; looking for newly runnable stages")
running -= stage
if (stage.shuffleDep != None) {
mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId,
stage.outputLocs.map(_.head).toArray)
}
updateCacheLocs()
val newlyRunnable = new ArrayBuffer[Stage]
for (stage <- waiting if getMissingParentStages(stage) == Nil) {
newlyRunnable += stage
}
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable) {
submitMissingTasks(stage)
}
}
}
} else {
evt.reason match {
case FetchFailed(serverUri, shuffleId, mapId, reduceId) =>
// Mark the stage that the reducer was in as unrunnable
val failedStage = idToStage(evt.task.stageId)
running -= failedStage
failed += failedStage
// TODO: Cancel running tasks in the stage
logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
// Mark the map whose fetch failed as broken in the map stage
val mapStage = shuffleToMapStage(shuffleId)
mapStage.removeOutputLoc(mapId, serverUri)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, serverUri)
logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
failed += mapStage
// Remember that a fetch failed now; this is used to resubmit the broken
// stages later, after a small wait (to give other tasks the chance to fail)
lastFetchFailureTime = time
// TODO: If there are a lot of fetch failures on the same node, maybe mark all
// outputs on the node as dead.
case _ =>
// Non-fetch failure -- probably a bug in the job, so bail out
throw new SparkException("Task failed: " + evt.task + ", reason: " + evt.reason)
// TODO: Cancel all tasks that are still running
}
}
} // end if (evt != null)
// If fetches have failed recently and we've waited for the right timeout,
// resubmit all the failed stages
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
logInfo("Resubmitting failed stages")
updateCacheLocs()
for (stage <- failed) {
submitStage(stage)
}
failed.clear()
}
}
eventQueues -= runId
return results
}
}
def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations
val cached = getCacheLocs(rdd)(partition)
if (cached != Nil) {
return cached
}
// If the RDD has some placement preferences (as is the case for input RDDs), get those
val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
if (rddPrefs != Nil) {
return rddPrefs
}
// If the RDD has narrow dependencies, pick the first partition of the first narrow dep
// that has any placement preferences. Ideally we would choose based on transfer sizes,
// but this will do for now.
rdd.dependencies.foreach(_ match {
case n: NarrowDependency[_] =>
for (inPart <- n.getParents(partition)) {
val locs = getPreferredLocs(n.rdd, inPart)
if (locs != Nil)
return locs;
}
case _ =>
})
return Nil
}
// Assumes that lock is held on entrance, but will release it to wait for the next event.
def waitForEvent(runId: Int, timeout: Long): Option[CompletionEvent] = {
val endTime = System.currentTimeMillis() + timeout // TODO: Use pluggable clock for testing
while (eventQueues(runId).isEmpty) {
val time = System.currentTimeMillis()
if (time >= endTime) {
return None
} else {
lock.wait(endTime - time)
}
}
return Some(eventQueues(runId).dequeue())
}
}

View file

@ -8,7 +8,7 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd, false) {
class ShuffleDependency[K, V, C](
val shuffleId: Int,
rdd: RDD[(K, V)],
@transient rdd: RDD[(K, V)],
val aggregator: Aggregator[K, V, C],
val partitioner: Partitioner)
extends Dependency(rdd, true)

View file

@ -1,75 +0,0 @@
package spark
import java.io.File
import java.io.{FileOutputStream,FileInputStream}
import java.io.IOException
import java.util.LinkedHashMap
import java.util.UUID
// TODO: cache into a separate directory using Utils.createTempDir
// TODO: clean up disk cache afterwards
class DiskSpillingCache extends BoundedMemoryCache {
private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true)
override def get(datasetId: Any, partition: Int): Any = {
synchronized {
val ser = SparkEnv.get.serializer.newInstance()
super.get(datasetId, partition) match {
case bytes: Any => // found in memory
ser.deserialize(bytes.asInstanceOf[Array[Byte]])
case _ => diskMap.get((datasetId, partition)) match {
case file: Any => // found on disk
try {
val startTime = System.currentTimeMillis
val bytes = new Array[Byte](file.length.toInt)
new FileInputStream(file).read(bytes)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format(
datasetId, partition, file.length, timeTaken))
super.put(datasetId, partition, bytes)
ser.deserialize(bytes.asInstanceOf[Array[Byte]])
} catch {
case e: IOException =>
logWarning("Failed to read key (%s, %d) from disk at %s: %s".format(
datasetId, partition, file.getPath(), e.getMessage()))
diskMap.remove((datasetId, partition)) // remove dead entry
null
}
case _ => // not found
null
}
}
}
}
override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
var ser = SparkEnv.get.serializer.newInstance()
super.put(datasetId, partition, ser.serialize(value))
}
/**
* Spill the given entry to disk. Assumes that a lock is held on the
* DiskSpillingCache. Assumes that entry.value is a byte array.
*/
override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
logInfo("Spilling key (%s, %d) of size %d to make space".format(
datasetId, partition, entry.size))
val cacheDir = System.getProperty(
"spark.diskSpillingCache.cacheDir",
System.getProperty("java.io.tmpdir"))
val file = new File(cacheDir, "spark-dsc-" + UUID.randomUUID.toString)
try {
val stream = new FileOutputStream(file)
stream.write(entry.value.asInstanceOf[Array[Byte]])
stream.close()
diskMap.put((datasetId, partition), file)
} catch {
case e: IOException =>
logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format(
datasetId, partition, file.getPath(), e.getMessage()))
// Do nothing and let the entry be discarded
}
}
}

View file

@ -0,0 +1,39 @@
package spark
import spark.partial.BoundedDouble
import spark.partial.MeanEvaluator
import spark.partial.PartialResult
import spark.partial.SumEvaluator
import spark.util.StatCounter
/**
* Extra functions available on RDDs of Doubles through an implicit conversion.
*/
class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
def sum(): Double = {
self.reduce(_ + _)
}
def stats(): StatCounter = {
self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
}
def mean(): Double = stats().mean
def variance(): Double = stats().variance
def stdev(): Double = stats().stdev
def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
val evaluator = new MeanEvaluator(self.splits.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
val evaluator = new SumEvaluator(self.splits.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
}

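A hedged usage sketch of the new double-RDD helpers (the local master string and the implicit conversion from RDD[Double] in SparkContext are assumptions):

import spark.SparkContext
import spark.SparkContext._ // assumed to provide the implicit conversion to DoubleRDDFunctions

// Sketch only: basic and approximate statistics over an RDD of doubles.
object DoubleStatsExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "DoubleStatsExample")
    val nums = sc.parallelize(1 to 1000, 2).map(_.toDouble)
    println(nums.stats())                // count, mean, variance, stdev via StatCounter
    println(nums.sum())
    println(nums.meanApprox(500L, 0.90)) // PartialResult[BoundedDouble]
  }
}
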
View file

@ -10,9 +10,10 @@ import scala.collection.mutable.ArrayBuffer
import com.google.protobuf.ByteString
import org.apache.mesos._
import org.apache.mesos.Protos._
import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
import spark.broadcast._
import spark.scheduler._
/**
* The Mesos executor for Spark.
@ -29,6 +30,9 @@ class Executor extends org.apache.mesos.Executor with Logging {
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
// Make sure the local hostname we report matches Mesos's name for this host
Utils.setCustomHostname(slaveInfo.getHostname())
// Read spark.* system properties from executor arg
val props = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
for ((key, value) <- props) {
@ -39,7 +43,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
RemoteActor.classLoader = getClass.getClassLoader
// Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(false)
env = SparkEnv.createFromSystemProperties(false, false)
SparkEnv.set(env)
// Old stuff that isn't yet using env
Broadcast.initialize(false)
@ -57,11 +61,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
override def reregistered(d: ExecutorDriver, s: SlaveInfo) {}
override def launchTask(d: ExecutorDriver, task: TaskInfo) {
override def launchTask(d: ExecutorDriver, task: MTaskInfo) {
threadPool.execute(new TaskRunner(task, d))
}
class TaskRunner(info: TaskInfo, d: ExecutorDriver)
class TaskRunner(info: MTaskInfo, d: ExecutorDriver)
extends Runnable {
override def run() = {
val tid = info.getTaskId.getValue
@ -74,11 +78,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
.setState(TaskState.TASK_RUNNING)
.build())
try {
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear
val task = ser.deserialize[Task[Any]](info.getData.toByteArray, classLoader)
for (gen <- task.generation) {// Update generation if any is set
env.mapOutputTracker.updateGeneration(gen)
}
val task = ser.deserialize[Task[Any]](info.getData.asReadOnlyByteBuffer, classLoader)
env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(tid.toInt)
val accumUpdates = Accumulators.values
val result = new TaskResult(value, accumUpdates)
@ -105,9 +109,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
.setData(ByteString.copyFrom(ser.serialize(reason)))
.build())
// TODO: Handle errors in tasks less dramatically
// TODO: Should we exit the whole executor here? On the one hand, the failed task may
// have left some weird state around depending on when the exception was thrown, but on
// the other hand, maybe we could detect that when future tasks fail and exit then.
logError("Exception in task ID " + tid, t)
System.exit(1)
//System.exit(1)
}
}
}

View file

@ -1,7 +1,9 @@
package spark
import spark.storage.BlockManagerId
class FetchFailedException(
val serverUri: String,
val bmAddress: BlockManagerId,
val shuffleId: Int,
val mapId: Int,
val reduceId: Int,
@ -9,10 +11,10 @@ class FetchFailedException(
extends Exception {
override def getMessage(): String =
"Fetch failed: %s %d %d %d".format(serverUri, shuffleId, mapId, reduceId)
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
override def getCause(): Throwable = cause
def toTaskEndReason: TaskEndReason =
FetchFailed(serverUri, shuffleId, mapId, reduceId)
FetchFailed(bmAddress, shuffleId, mapId, reduceId)
}

View file

@ -1,6 +1,7 @@
package spark
import java.io._
import java.nio.ByteBuffer
class JavaSerializationStream(out: OutputStream) extends SerializationStream {
val objOut = new ObjectOutputStream(out)
@ -9,10 +10,11 @@ class JavaSerializationStream(out: OutputStream) extends SerializationStream {
def close() { objOut.close() }
}
class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
extends DeserializationStream {
val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
Class.forName(desc.getName, false, loader)
}
def readObject[T](): T = objIn.readObject().asInstanceOf[T]
@ -20,35 +22,36 @@ class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
}
class JavaSerializerInstance extends SerializerInstance {
def serialize[T](t: T): Array[Byte] = {
def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = outputStream(bos)
val out = serializeStream(bos)
out.writeObject(t)
out.close()
bos.toByteArray
ByteBuffer.wrap(bos.toByteArray)
}
def deserialize[T](bytes: Array[Byte]): T = {
val bis = new ByteArrayInputStream(bytes)
val in = inputStream(bis)
def deserialize[T](bytes: ByteBuffer): T = {
val bis = new ByteArrayInputStream(bytes.array())
val in = deserializeStream(bis)
in.readObject().asInstanceOf[T]
}
def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
val bis = new ByteArrayInputStream(bytes)
val ois = new ObjectInputStream(bis) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
return ois.readObject.asInstanceOf[T]
def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
val bis = new ByteArrayInputStream(bytes.array())
val in = deserializeStream(bis, loader)
in.readObject().asInstanceOf[T]
}
def outputStream(s: OutputStream): SerializationStream = {
def serializeStream(s: OutputStream): SerializationStream = {
new JavaSerializationStream(s)
}
def inputStream(s: InputStream): DeserializationStream = {
new JavaDeserializationStream(s)
def deserializeStream(s: InputStream): DeserializationStream = {
new JavaDeserializationStream(s, currentThread.getContextClassLoader)
}
def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
new JavaDeserializationStream(s, loader)
}
}

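A hedged round-trip through the renamed ByteBuffer-based API (the enclosing JavaSerializer factory class and its newInstance() are assumed from the Serializer trait):

import java.nio.ByteBuffer
import spark.JavaSerializer

// Sketch only: serialize a value to a ByteBuffer and read it back.
object JavaSerializerRoundTrip {
  def main(args: Array[String]) {
    val ser = new JavaSerializer().newInstance()
    val buf: ByteBuffer = ser.serialize(("answer", 42))
    val back: (String, Int) = ser.deserialize(buf)
    assert(back == ("answer", 42))
  }
}
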
View file

@ -1,16 +0,0 @@
package spark
import org.apache.mesos._
import org.apache.mesos.Protos._
/**
* Class representing a parallel job in MesosScheduler. Schedules the job by implementing various
* callbacks.
*/
abstract class Job(val runId: Int, val jobId: Int) {
def slaveOffer(s: Offer, availableCpus: Double): Option[TaskInfo]
def statusUpdate(t: TaskStatus): Unit
def error(message: String): Unit
}

View file

@ -12,6 +12,8 @@ import com.esotericsoftware.kryo.{Serializer => KSerializer}
import com.esotericsoftware.kryo.serialize.ClassSerializer
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
import spark.storage._
/**
* Zig-zag encoder used to write object sizes to serialization streams.
* Based on Kryo's integer encoder.
@ -64,57 +66,90 @@ object ZigZag {
}
}
class KryoSerializationStream(kryo: Kryo, buf: ByteBuffer, out: OutputStream)
class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
extends SerializationStream {
val channel = Channels.newChannel(out)
def writeObject[T](t: T) {
kryo.writeClassAndObject(buf, t)
ZigZag.writeInt(buf.position(), out)
buf.flip()
channel.write(buf)
buf.clear()
kryo.writeClassAndObject(threadBuffer, t)
ZigZag.writeInt(threadBuffer.position(), out)
threadBuffer.flip()
channel.write(threadBuffer)
threadBuffer.clear()
}
def flush() { out.flush() }
def close() { out.close() }
}
class KryoDeserializationStream(buf: ObjectBuffer, in: InputStream)
class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
extends DeserializationStream {
def readObject[T](): T = {
val len = ZigZag.readInt(in)
buf.readClassAndObject(in, len).asInstanceOf[T]
objectBuffer.readClassAndObject(in, len).asInstanceOf[T]
}
def close() { in.close() }
}
class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
val buf = ks.threadBuf.get()
val kryo = ks.kryo
val threadBuffer = ks.threadBuffer.get()
val objectBuffer = ks.objectBuffer.get()
def serialize[T](t: T): Array[Byte] = {
buf.writeClassAndObject(t)
def serialize[T](t: T): ByteBuffer = {
// Write it to our thread-local scratch buffer first to figure out the size, then return a new
// ByteBuffer of the appropriate size
threadBuffer.clear()
kryo.writeClassAndObject(threadBuffer, t)
val newBuf = ByteBuffer.allocate(threadBuffer.position)
threadBuffer.flip()
newBuf.put(threadBuffer)
newBuf.flip()
newBuf
}
def deserialize[T](bytes: Array[Byte]): T = {
buf.readClassAndObject(bytes).asInstanceOf[T]
def deserialize[T](bytes: ByteBuffer): T = {
kryo.readClassAndObject(bytes).asInstanceOf[T]
}
def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
val oldClassLoader = ks.kryo.getClassLoader
ks.kryo.setClassLoader(loader)
val obj = buf.readClassAndObject(bytes).asInstanceOf[T]
ks.kryo.setClassLoader(oldClassLoader)
def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
val oldClassLoader = kryo.getClassLoader
kryo.setClassLoader(loader)
val obj = kryo.readClassAndObject(bytes).asInstanceOf[T]
kryo.setClassLoader(oldClassLoader)
obj
}
def outputStream(s: OutputStream): SerializationStream = {
new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s)
def serializeStream(s: OutputStream): SerializationStream = {
threadBuffer.clear()
new KryoSerializationStream(kryo, threadBuffer, s)
}
def inputStream(s: InputStream): DeserializationStream = {
new KryoDeserializationStream(buf, s)
def deserializeStream(s: InputStream): DeserializationStream = {
new KryoDeserializationStream(objectBuffer, s)
}
override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
threadBuffer.clear()
while (iterator.hasNext) {
val element = iterator.next()
// TODO: Do we also want to write the object's size? Doesn't seem necessary.
kryo.writeClassAndObject(threadBuffer, element)
}
val newBuf = ByteBuffer.allocate(threadBuffer.position)
threadBuffer.flip()
newBuf.put(threadBuffer)
newBuf.flip()
newBuf
}
override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
buffer.rewind()
new Iterator[Any] {
override def hasNext: Boolean = buffer.remaining > 0
override def next(): Any = kryo.readClassAndObject(buffer)
}
}
}
@ -126,20 +161,17 @@ trait KryoRegistrator {
class KryoSerializer extends Serializer with Logging {
val kryo = createKryo()
val bufferSize =
System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
val threadBuf = new ThreadLocal[ObjectBuffer] {
val objectBuffer = new ThreadLocal[ObjectBuffer] {
override def initialValue = new ObjectBuffer(kryo, bufferSize)
}
val threadByteBuf = new ThreadLocal[ByteBuffer] {
val threadBuffer = new ThreadLocal[ByteBuffer] {
override def initialValue = ByteBuffer.allocate(bufferSize)
}
def createKryo(): Kryo = {
// This is used so we can serialize/deserialize objects without a zero-arg
// constructor.
val kryo = new KryoReflectionFactorySupport()
// Register some commonly used classes
@ -148,14 +180,20 @@ class KryoSerializer extends Serializer with Logging {
Array(1), Array(1.0), Array(1.0f), Array(1L), Array(""), Array(("", "")),
Array(new java.lang.Object), Array(1.toByte), Array(true), Array('c'),
// Specialized Tuple2s
("", ""), (1, 1), (1.0, 1.0), (1L, 1L),
("", ""), ("", 1), (1, 1), (1.0, 1.0), (1L, 1L),
(1, 1.0), (1.0, 1), (1L, 1.0), (1.0, 1L), (1, 1L), (1L, 1),
// Scala collections
List(1), mutable.ArrayBuffer(1),
// Options and Either
Some(1), Left(1), Right(1),
// Higher-dimensional tuples
(1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1)
(1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1),
None,
ByteBuffer.allocate(1),
StorageLevel.MEMORY_ONLY_DESER,
PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER),
GotBlock("1", ByteBuffer.allocate(1)),
GetBlock("1")
)
for (obj <- toRegister) {
kryo.register(obj.getClass)

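A hedged sketch of the new batch hooks serializeMany/deserializeMany added above (KryoSerializer's no-arg constructor and newInstance() are taken from the surrounding code; the example types rely on the Tuple2 registrations in createKryo):

import spark.KryoSerializer

// Sketch only: batch-serialize an iterator into one ByteBuffer and stream it back.
object KryoBatchRoundTrip {
  def main(args: Array[String]) {
    val ser = new KryoSerializer().newInstance()
    val buffer = ser.serializeMany(Iterator(("a", 1), ("b", 2)))
    val restored = ser.deserializeMany(buffer).toList // List[Any], in insertion order
    println(restored)
  }
}
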
View file

@ -28,9 +28,11 @@ trait Logging {
}
// Log methods that take only a String
def logInfo(msg: => String) = if (log.isInfoEnabled) log.info(msg)
def logInfo(msg: => String) = if (log.isInfoEnabled /*&& msg.contains("job finished in")*/) log.info(msg)
def logDebug(msg: => String) = if (log.isDebugEnabled) log.debug(msg)
def logTrace(msg: => String) = if (log.isTraceEnabled) log.trace(msg)
def logWarning(msg: => String) = if (log.isWarnEnabled) log.warn(msg)
@ -43,6 +45,9 @@ trait Logging {
def logDebug(msg: => String, throwable: Throwable) =
if (log.isDebugEnabled) log.debug(msg)
def logTrace(msg: => String, throwable: Throwable) =
if (log.isTraceEnabled) log.trace(msg)
def logWarning(msg: => String, throwable: Throwable) =
if (log.isWarnEnabled) log.warn(msg, throwable)

View file

@ -2,80 +2,80 @@ package spark
import java.util.concurrent.ConcurrentHashMap
import scala.actors._
import scala.actors.Actor._
import scala.actors.remote._
import akka.actor._
import akka.actor.Actor
import akka.actor.Actor._
import akka.util.duration._
import scala.collection.mutable.HashSet
import spark.storage.BlockManagerId
sealed trait MapOutputTrackerMessage
case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
case object StopMapOutputTracker extends MapOutputTrackerMessage
class MapOutputTrackerActor(serverUris: ConcurrentHashMap[Int, Array[String]])
extends DaemonActor with Logging {
def act() {
val port = System.getProperty("spark.master.port").toInt
RemoteActor.alive(port)
RemoteActor.register('MapOutputTracker, self)
logInfo("Registered actor on port " + port)
loop {
react {
case GetMapOutputLocations(shuffleId: Int) =>
logInfo("Asked to get map output locations for shuffle " + shuffleId)
reply(serverUris.get(shuffleId))
case StopMapOutputTracker =>
reply('OK)
exit()
}
}
class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]])
extends Actor with Logging {
def receive = {
case GetMapOutputLocations(shuffleId: Int) =>
logInfo("Asked to get map output locations for shuffle " + shuffleId)
self.reply(bmAddresses.get(shuffleId))
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
self.reply(true)
self.exit()
}
}
class MapOutputTracker(isMaster: Boolean) extends Logging {
var trackerActor: AbstractActor = null
val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val aName: String = "MapOutputTracker"
private var serverUris = new ConcurrentHashMap[Int, Array[String]]
private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var generation: Long = 0
private var generationLock = new java.lang.Object
if (isMaster) {
val tracker = new MapOutputTrackerActor(serverUris)
tracker.start()
trackerActor = tracker
var trackerActor: ActorRef = if (isMaster) {
val actor = actorOf(new MapOutputTrackerActor(bmAddresses))
remote.register(aName, actor)
logInfo("Registered MapOutputTrackerActor actor @ " + ip + ":" + port)
actor
} else {
val host = System.getProperty("spark.master.host")
val port = System.getProperty("spark.master.port").toInt
trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
remote.actorFor(aName, ip, port)
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (serverUris.get(shuffleId) != null) {
if (bmAddresses.get(shuffleId) != null) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
serverUris.put(shuffleId, new Array[String](numMaps))
bmAddresses.put(shuffleId, new Array[BlockManagerId](numMaps))
}
def registerMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
var array = serverUris.get(shuffleId)
def registerMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var array = bmAddresses.get(shuffleId)
array.synchronized {
array(mapId) = serverUri
array(mapId) = bmAddress
}
}
def registerMapOutputs(shuffleId: Int, locs: Array[String]) {
serverUris.put(shuffleId, Array[String]() ++ locs)
def registerMapOutputs(shuffleId: Int, locs: Array[BlockManagerId], changeGeneration: Boolean = false) {
bmAddresses.put(shuffleId, Array[BlockManagerId]() ++ locs)
if (changeGeneration) {
incrementGeneration()
}
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
var array = serverUris.get(shuffleId)
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var array = bmAddresses.get(shuffleId)
if (array != null) {
array.synchronized {
if (array(mapId) == serverUri) {
if (array(mapId) == bmAddress) {
array(mapId) = null
}
}
@ -89,10 +89,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs for a given shuffle
def getServerUris(shuffleId: Int): Array[String] = {
val locs = serverUris.get(shuffleId)
def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = {
val locs = bmAddresses.get(shuffleId)
if (locs == null) {
logInfo("Don't have map outputs for " + shuffleId + ", fetching them")
logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them")
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@ -103,15 +103,17 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
case _ =>
}
}
return serverUris.get(shuffleId)
return bmAddresses.get(shuffleId)
} else {
fetching += shuffleId
}
}
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val fetched = (trackerActor !? GetMapOutputLocations(shuffleId)).asInstanceOf[Array[String]]
serverUris.put(shuffleId, fetched)
val fetched = (trackerActor ? GetMapOutputLocations(shuffleId)).as[Array[BlockManagerId]].get
logInfo("Got the output locations")
bmAddresses.put(shuffleId, fetched)
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
@ -121,14 +123,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
return locs
}
}
def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = {
"%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId)
}
def stop() {
trackerActor !? StopMapOutputTracker
serverUris.clear()
trackerActor !! StopMapOutputTracker
bmAddresses.clear()
trackerActor = null
}
@ -153,7 +151,7 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
serverUris = new ConcurrentHashMap[Int, Array[String]]
bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
generation = newGen
}
}

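A hedged sketch of the reduce-side lookup against the reworked tracker, which now hands back BlockManagerId addresses instead of server URIs (the helper object is hypothetical):

import spark.SparkEnv
import spark.storage.BlockManagerId

// Sketch only: ask the tracker where the map outputs of a shuffle live.
object MapOutputLookupSketch {
  def locate(shuffleId: Int): Array[BlockManagerId] = {
    val tracker = SparkEnv.get.mapOutputTracker
    // Executors clear this cache when a task arrives with a newer generation
    // (see updateGeneration above); here we just read the current addresses.
    tracker.getServerAddresses(shuffleId)
  }
}
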
View file

@ -4,14 +4,14 @@ import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong
import java.util.HashSet
import java.util.Random
import java.util.{HashMap => JHashMap}
import java.util.Date
import java.text.SimpleDateFormat
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.Map
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
@ -34,7 +34,9 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
import org.apache.hadoop.mapreduce.TaskAttemptID
import org.apache.hadoop.mapreduce.TaskAttemptContext
import SparkContext._
import spark.SparkContext._
import spark.partial.BoundedDouble
import spark.partial.PartialResult
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@ -43,19 +45,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
self: RDD[(K, V)])
extends Logging
with Serializable {
def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = {
def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
for ((k, v) <- m2) {
m1.get(k) match {
case None => m1(k) = v
case Some(w) => m1(k) = func(w, v)
}
}
return m1
}
self.map(pair => HashMap(pair)).reduce(mergeMaps)
}
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
@ -77,6 +66,39 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
combineByKey[V]((v: V) => v, func, func, numSplits)
}
def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
val map = new JHashMap[K, V]
for ((k, v) <- iter) {
val old = map.get(k)
map.put(k, if (old == null) v else func(old, v))
}
Iterator(map)
}
def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = {
for ((k, v) <- m2) {
val old = m1.get(k)
m1.put(k, if (old == null) v else func(old, v))
}
return m1
}
self.mapPartitions(reducePartition).reduce(mergeMaps)
}
// Alias for backwards compatibility
def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
// TODO: This should probably be a distributed version
def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
// TODO: This should probably be a distributed version
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[Map[K, BoundedDouble]] = {
self.map(_._1).countByValueApprox(timeout, confidence)
}
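
// Usage sketch for the methods added above; assumes `sc` is an existing SparkContext and
// the sample data is illustrative:
val pairs = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 2)), 2)
val sums = pairs.reduceByKeyLocally(_ + _)     // Map("a" -> 3, "b" -> 1), merged on the driver
val counts = pairs.countByKey()                // Map("a" -> 2, "b" -> 1)
val rough = pairs.countByKeyApprox(1000L)      // PartialResult; may be incomplete after 1 second
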
def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = {
def createCombiner(v: V) = ArrayBuffer(v)

Просмотреть файл

@ -1,119 +0,0 @@
package spark
import java.io.ByteArrayInputStream
import java.io.EOFException
import java.net.URL
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
class ParallelShuffleFetcher extends ShuffleFetcher with Logging {
val parallelFetches = System.getProperty("spark.parallel.fetches", "3").toInt
def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
// Figure out a list of input IDs (mapper IDs) for each server
val ser = SparkEnv.get.serializer.newInstance()
val inputsByUri = new HashMap[String, ArrayBuffer[Int]]
val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
for ((serverUri, index) <- serverUris.zipWithIndex) {
inputsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
}
// Randomize them and put them in a LinkedBlockingQueue
val serverQueue = new LinkedBlockingQueue[(String, ArrayBuffer[Int])]
for (pair <- Utils.randomize(inputsByUri)) {
serverQueue.put(pair)
}
// Create a queue to hold the fetched data
val resultQueue = new LinkedBlockingQueue[Array[Byte]]
// Atomic variables to communicate failures and # of fetches done
var failure = new AtomicReference[FetchFailedException](null)
// Start multiple threads to do the fetching (TODO: may be possible to do it asynchronously)
for (i <- 0 until parallelFetches) {
new Thread("Fetch thread " + i + " for reduce " + reduceId) {
override def run() {
while (true) {
val pair = serverQueue.poll()
if (pair == null)
return
val (serverUri, inputIds) = pair
//logInfo("Pulled out server URI " + serverUri)
for (i <- inputIds) {
if (failure.get != null)
return
val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
logInfo("Starting HTTP request for " + url)
try {
val conn = new URL(url).openConnection()
conn.connect()
val len = conn.getContentLength()
if (len == -1) {
throw new SparkException("Content length was not specified by server")
}
val buf = new Array[Byte](len)
val in = new FastBufferedInputStream(conn.getInputStream())
var pos = 0
while (pos < len) {
val n = in.read(buf, pos, len-pos)
if (n == -1) {
throw new SparkException("EOF before reading the expected " + len + " bytes")
} else {
pos += n
}
}
// Done reading everything
resultQueue.put(buf)
in.close()
} catch {
case e: Exception =>
logError("Fetch failed from " + url, e)
failure.set(new FetchFailedException(serverUri, shuffleId, i, reduceId, e))
return
}
}
//logInfo("Done with server URI " + serverUri)
}
}
}.start()
}
// Wait for results from the threads (either a failure or all servers done)
var resultsDone = 0
var totalResults = inputsByUri.map{case (uri, inputs) => inputs.size}.sum
while (failure.get == null && resultsDone < totalResults) {
try {
val result = resultQueue.poll(100, TimeUnit.MILLISECONDS)
if (result != null) {
//logInfo("Pulled out a result")
val in = ser.inputStream(new ByteArrayInputStream(result))
try {
while (true) {
val pair = in.readObject().asInstanceOf[(K, V)]
func(pair._1, pair._2)
}
} catch {
case e: EOFException => {} // TODO: cleaner way to detect EOF, such as a sentinel
}
resultsDone += 1
//logInfo("Results done = " + resultsDone)
}
} catch { case e: InterruptedException => {} }
}
if (failure.get != null) {
throw failure.get
}
}
}

Просмотреть файл

@ -70,4 +70,3 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
false
}
}

Просмотреть файл

@ -3,6 +3,7 @@ package spark
import java.io.PrintWriter
import java.util.StringTokenizer
import scala.collection.Map
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source

Просмотреть файл

@ -4,11 +4,14 @@ import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong
import java.util.HashSet
import java.util.Random
import java.util.Date
import java.util.{HashMap => JHashMap}
import scala.collection.mutable.ArrayBuffer
import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions.mapAsScalaMap
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
@ -22,6 +25,14 @@ import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult
import spark.storage.StorageLevel
import SparkContext._
/**
@ -61,19 +72,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
// Get a unique ID for this RDD
val id = sc.newRddId()
// Variables relating to caching
private var shouldCache = false
// Variables relating to persistence
private var storageLevel: StorageLevel = StorageLevel.NONE
// Change this RDD's caching
def cache(): RDD[T] = {
shouldCache = true
// Change this RDD's storage level
def persist(newLevel: StorageLevel): RDD[T] = {
// TODO: Handle changes of StorageLevel
if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
throw new UnsupportedOperationException(
"Cannot change storage level of an RDD after it was already assigned a level")
}
storageLevel = newLevel
this
}
// Turn on the default caching level for this RDD
def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER)
// Turn on the default caching level for this RDD
def cache(): RDD[T] = persist()
def getStorageLevel = storageLevel
// Read this RDD; will read from cache if applicable, or otherwise compute
final def iterator(split: Split): Iterator[T] = {
if (shouldCache) {
SparkEnv.get.cacheTracker.getOrCompute[T](this, split)
if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
} else {
compute(split)
}
@ -162,6 +186,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
Array.concat(results: _*)
}
def toArray(): Array[T] = collect()
def reduce(f: (T, T) => T): T = {
val cleanF = sc.clean(f)
val reducePartition: Iterator[T] => Option[T] = iter => {
@ -222,7 +248,67 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}).sum
}
def toArray(): Array[T] = collect()
/**
* Approximate version of count() that returns a potentially incomplete result after a timeout.
*/
def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) =>
var result = 0L
while (iter.hasNext) {
result += 1L
iter.next
}
result
}
val evaluator = new CountEvaluator(splits.size, confidence)
sc.runApproximateJob(this, countElements, evaluator, timeout)
}
/**
* Count elements equal to each value, returning a map of (value, count) pairs. The final combine
* step happens locally on the master, equivalent to running a single reduce task.
*
* TODO: This should perhaps be distributed by default.
*/
def countByValue(): Map[T, Long] = {
def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
val map = new OLMap[T]
while (iter.hasNext) {
val v = iter.next()
map.put(v, map.getLong(v) + 1L)
}
Iterator(map)
}
def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = {
val iter = m2.object2LongEntrySet.fastIterator()
while (iter.hasNext) {
val entry = iter.next()
m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue)
}
return m1
}
val myResult = mapPartitions(countPartition).reduce(mergeMaps)
myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map
}
/**
* Approximate version of countByValue().
*/
def countByValueApprox(
timeout: Long,
confidence: Double = 0.95
): PartialResult[Map[T, BoundedDouble]] = {
val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
val map = new OLMap[T]
while (iter.hasNext) {
val v = iter.next()
map.put(v, map.getLong(v) + 1L)
}
map
}
val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
sc.runApproximateJob(this, countPartition, evaluator, timeout)
}
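
// Usage sketch for the persistence and counting APIs above; assumes `sc` is an existing
// SparkContext and the data is illustrative:
import spark.storage.StorageLevel
val nums = sc.parallelize(Seq(1, 2, 2, 3, 3, 3), 2)
nums.persist(StorageLevel.MEMORY_ONLY_DESER)   // the level that cache() now maps to
val exact = nums.countByValue()                // Map(1 -> 1, 2 -> 2, 3 -> 3)
val rough = nums.countApprox(1000L, 0.90)      // PartialResult[BoundedDouble] within ~1 second
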
/**
* Take the first num elements of the RDD. This currently scans the partitions *one by one*, so

Просмотреть файл

@ -1,27 +0,0 @@
package spark
/**
* Scheduler trait, implemented by both MesosScheduler and LocalScheduler.
*/
private trait Scheduler {
def start()
// Wait for registration with Mesos.
def waitForRegister()
/**
* Run a function on some partitions of an RDD, returning an array of results. The allowLocal
* flag specifies whether the scheduler is allowed to run the job on the master machine rather
* than shipping it to the cluster, for actions that create short jobs such as first() and take().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean): Array[U]
def stop()
// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int
}

Просмотреть файл

@ -44,7 +44,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
}
// TODO: use something like WritableConverter to avoid reflection
}
c.asInstanceOf[Class[ _ <: Writable]]
c.asInstanceOf[Class[_ <: Writable]]
}
def saveAsSequenceFile(path: String) {

Просмотреть файл

@ -1,6 +1,12 @@
package spark
import java.io.{InputStream, OutputStream}
import java.io.{EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer
import java.nio.channels.Channels
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import spark.util.ByteBufferInputStream
/**
* A serializer. Because some serialization libraries are not thread safe, this class is used to
@ -14,11 +20,31 @@ trait Serializer {
* An instance of the serializer, for use by one thread at a time.
*/
trait SerializerInstance {
def serialize[T](t: T): Array[Byte]
def deserialize[T](bytes: Array[Byte]): T
def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T
def outputStream(s: OutputStream): SerializationStream
def inputStream(s: InputStream): DeserializationStream
def serialize[T](t: T): ByteBuffer
def deserialize[T](bytes: ByteBuffer): T
def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T
def serializeStream(s: OutputStream): SerializationStream
def deserializeStream(s: InputStream): DeserializationStream
def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
// Default implementation uses serializeStream
val stream = new FastByteArrayOutputStream()
serializeStream(stream).writeAll(iterator)
val buffer = ByteBuffer.allocate(stream.position.toInt)
buffer.put(stream.array, 0, stream.position.toInt)
buffer.flip()
buffer
}
def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
// Default implementation uses deserializeStream
buffer.rewind()
deserializeStream(new ByteBufferInputStream(buffer)).toIterator
}
}
/**
@ -28,6 +54,13 @@ trait SerializationStream {
def writeObject[T](t: T): Unit
def flush(): Unit
def close(): Unit
def writeAll[T](iter: Iterator[T]): SerializationStream = {
while (iter.hasNext) {
writeObject(iter.next())
}
this
}
}
/**
@ -36,4 +69,45 @@ trait SerializationStream {
trait DeserializationStream {
def readObject[T](): T
def close(): Unit
/**
* Read the elements of this stream through an iterator. This can only be called once, as
* reading each element will consume data from the input source.
*/
def toIterator: Iterator[Any] = new Iterator[Any] {
var gotNext = false
var finished = false
var nextValue: Any = null
private def getNext() {
try {
nextValue = readObject[Any]()
} catch {
case eof: EOFException =>
finished = true
}
gotNext = true
}
override def hasNext: Boolean = {
if (!gotNext) {
getNext()
}
if (finished) {
close()
}
!finished
}
override def next(): Any = {
if (!gotNext) {
getNext()
}
if (finished) {
throw new NoSuchElementException("End of stream")
}
gotNext = false
nextValue
}
}
}
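
// Round-trip sketch for the ByteBuffer-based interface above, via the serializeMany /
// deserializeMany defaults; JavaSerializer is the fallback serializer named elsewhere in
// this commit, and its use here is only illustrative.
import java.nio.ByteBuffer
val ser = new spark.JavaSerializer().newInstance()
val buf: ByteBuffer = ser.serializeMany(Iterator("a", "b", "c"))
val values = ser.deserializeMany(buf)          // Iterator[Any], driven by toIterator above
values.foreach(println)                        // prints a, b, c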

Просмотреть файл

@ -1,26 +0,0 @@
package spark
import java.io._
/**
* Wrapper around a BoundedMemoryCache that stores serialized objects as byte arrays in order to
* reduce storage cost and GC overhead
*/
class SerializingCache extends Cache with Logging {
val bmc = new BoundedMemoryCache
override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
val ser = SparkEnv.get.serializer.newInstance()
bmc.put(datasetId, partition, ser.serialize(value))
}
override def get(datasetId: Any, partition: Int): Any = {
val bytes = bmc.get(datasetId, partition)
if (bytes != null) {
val ser = SparkEnv.get.serializer.newInstance()
return ser.deserialize(bytes.asInstanceOf[Array[Byte]])
} else {
return null
}
}
}

Просмотреть файл

@ -1,56 +0,0 @@
package spark
import java.io.BufferedOutputStream
import java.io.FileOutputStream
import java.io.ObjectOutputStream
import java.util.{HashMap => JHashMap}
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
class ShuffleMapTask(
runId: Int,
stageId: Int,
rdd: RDD[_],
dep: ShuffleDependency[_,_,_],
val partition: Int,
locs: Seq[String])
extends DAGTask[String](runId, stageId)
with Logging {
val split = rdd.splits(partition)
override def run (attemptId: Int): String = {
val numOutputSplits = dep.partitioner.numPartitions
val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
val partitioner = dep.partitioner.asInstanceOf[Partitioner]
val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
for (elem <- rdd.iterator(split)) {
val (k, v) = elem.asInstanceOf[(Any, Any)]
var bucketId = partitioner.getPartition(k)
val bucket = buckets(bucketId)
var existing = bucket.get(k)
if (existing == null) {
bucket.put(k, aggregator.createCombiner(v))
} else {
bucket.put(k, aggregator.mergeValue(existing, v))
}
}
val ser = SparkEnv.get.serializer.newInstance()
for (i <- 0 until numOutputSplits) {
val file = SparkEnv.get.shuffleManager.getOutputFile(dep.shuffleId, partition, i)
val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file)))
val iter = buckets(i).entrySet().iterator()
while (iter.hasNext()) {
val entry = iter.next()
out.writeObject((entry.getKey, entry.getValue))
}
// TODO: have some kind of EOF marker
out.close()
}
return SparkEnv.get.shuffleManager.getServerUri
}
override def preferredLocations: Seq[String] = locs
override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
}

Просмотреть файл

@ -8,7 +8,7 @@ class ShuffledRDDSplit(val idx: Int) extends Split {
}
class ShuffledRDD[K, V, C](
parent: RDD[(K, V)],
@transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
part : Partitioner)
extends RDD[(K, C)](parent.context) {

Просмотреть файл

@ -1,46 +0,0 @@
package spark
import java.io.EOFException
import java.net.URL
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val ser = SparkEnv.get.serializer.newInstance()
val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
for ((serverUri, index) <- serverUris.zipWithIndex) {
splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
}
for ((serverUri, inputIds) <- Utils.randomize(splitsByUri)) {
for (i <- inputIds) {
try {
val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
// TODO: multithreaded fetch
// TODO: would be nice to retry multiple times
val inputStream = ser.inputStream(
new FastBufferedInputStream(new URL(url).openStream()))
try {
while (true) {
val pair = inputStream.readObject().asInstanceOf[(K, V)]
func(pair._1, pair._2)
}
} finally {
inputStream.close()
}
} catch {
case e: EOFException => {} // We currently assume EOF means we read the whole thing
case other: Exception => {
logError("Fetch failed", other)
throw new FetchFailedException(serverUri, shuffleId, i, reduceId, other)
}
}
}
}
}
}

Просмотреть файл

@ -3,6 +3,9 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
import akka.actor.Actor
import akka.actor.Actor._
import scala.actors.remote.RemoteActor
import scala.collection.mutable.ArrayBuffer
@ -32,6 +35,15 @@ import org.apache.mesos.MesosNativeLibrary
import spark.broadcast._
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import spark.scheduler.DAGScheduler
import spark.scheduler.TaskScheduler
import spark.scheduler.local.LocalScheduler
import spark.scheduler.mesos.MesosScheduler
import spark.scheduler.mesos.CoarseMesosScheduler
class SparkContext(
master: String,
frameworkName: String,
@ -54,14 +66,19 @@ class SparkContext(
if (RemoteActor.classLoader == null) {
RemoteActor.classLoader = getClass.getClassLoader
}
remote.start(System.getProperty("spark.master.host"),
System.getProperty("spark.master.port").toInt)
private val isLocal = master.startsWith("local") // TODO: better check for local
// Create the Spark execution environment (cache, map output tracker, etc)
val env = SparkEnv.createFromSystemProperties(true)
val env = SparkEnv.createFromSystemProperties(true, isLocal)
SparkEnv.set(env)
Broadcast.initialize(true)
// Create and start the scheduler
private var scheduler: Scheduler = {
private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
@ -74,13 +91,17 @@ class SparkContext(
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
new LocalScheduler(threads.toInt, maxFailures.toInt)
case _ =>
MesosNativeLibrary.load()
new MesosScheduler(this, master, frameworkName)
System.loadLibrary("mesos")
if (System.getProperty("spark.mesos.coarse", "false") == "true") {
new CoarseMesosScheduler(this, master, frameworkName)
} else {
new MesosScheduler(this, master, frameworkName)
}
}
}
scheduler.start()
taskScheduler.start()
private val isLocal = scheduler.isInstanceOf[LocalScheduler]
private var dagScheduler = new DAGScheduler(taskScheduler)
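
// For reference, how the master string above picks a scheduler (the Mesos address and app
// name below are placeholders, not values from this commit):
//   "local[4]"    -> LocalScheduler with 4 threads
//   "local[4,2]"  -> LocalScheduler with 4 threads and at most 2 failures per task (tests)
//   anything else -> MesosScheduler, or CoarseMesosScheduler when spark.mesos.coarse=true
System.setProperty("spark.mesos.coarse", "true")
val sc = new SparkContext("someMesosMaster:5050", "CoarseModeApp")
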
// Methods for creating RDDs
@ -237,19 +258,21 @@ class SparkContext(
// Stop the SparkContext
def stop() {
scheduler.stop()
scheduler = null
dagScheduler.stop()
dagScheduler = null
taskScheduler = null
// TODO: Broadcast.stop(), Cache.stop()?
env.mapOutputTracker.stop()
env.cacheTracker.stop()
env.shuffleFetcher.stop()
env.shuffleManager.stop()
env.connectionManager.stop()
SparkEnv.set(null)
}
// Wait for the scheduler to be registered
// Wait for the scheduler to be registered with the cluster manager
def waitForRegister() {
scheduler.waitForRegister()
taskScheduler.waitForRegister()
}
// Get Spark's home location from either a value set through the constructor,
@ -281,7 +304,7 @@ class SparkContext(
): Array[U] = {
logInfo("Starting job...")
val start = System.nanoTime
val result = scheduler.runJob(rdd, func, partitions, allowLocal)
val result = dagScheduler.runJob(rdd, func, partitions, allowLocal)
logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
result
}
@ -306,6 +329,22 @@ class SparkContext(
runJob(rdd, func, 0 until rdd.splits.size, false)
}
/**
* Run a job that can return approximate results.
*/
def runApproximateJob[T, U, R](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long
): PartialResult[R] = {
logInfo("Starting job...")
val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout)
logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
result
}
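
// Sketch of calling runApproximateJob directly, mirroring how RDD.countApprox drives it;
// assumes an existing SparkContext `sc`, and the data below is illustrative:
import spark.partial.CountEvaluator
val data = sc.parallelize(1 to 100000, 8)
val countElements = (ctx: TaskContext, iter: Iterator[Int]) => {
  var n = 0L
  while (iter.hasNext) { n += 1L; iter.next }
  n
}
val evaluator = new CountEvaluator(data.splits.size, 0.95)
val partial = sc.runApproximateJob(data, countElements, evaluator, 2000L)
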
// Clean a closure to make it ready to serialized and send to tasks
// (removes unreferenced variables in $outer's, updates REPL variables)
private[spark] def clean[F <: AnyRef](f: F): F = {
@ -314,7 +353,7 @@ class SparkContext(
}
// Default level of parallelism to use when not given by user (e.g. for reduce tasks)
def defaultParallelism: Int = scheduler.defaultParallelism
def defaultParallelism: Int = taskScheduler.defaultParallelism
// Default min number of splits for Hadoop RDDs when not given by user
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
@ -349,15 +388,23 @@ object SparkContext {
}
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](rdd: RDD[(K, V)]) =
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
rdd: RDD[(K, V)]) =
new OrderedRDDFunctions(rdd)
implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd)
implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
// Implicit conversions to common Writable types, for saveAsSequenceFile
implicit def intToIntWritable(i: Int) = new IntWritable(i)

Просмотреть файл

@ -1,14 +1,26 @@
package spark
import akka.actor.Actor
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
class SparkEnv (
val cache: Cache,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheTracker: CacheTracker,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val shuffleManager: ShuffleManager
)
val cache: Cache,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheTracker: CacheTracker,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val shuffleManager: ShuffleManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager
) {
/** No-parameter constructor for unit tests. */
def this() = this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null)
}
object SparkEnv {
private val env = new ThreadLocal[SparkEnv]
@ -21,36 +33,55 @@ object SparkEnv {
env.get()
}
def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
def createFromSystemProperties(isMaster: Boolean, isLocal: Boolean): SparkEnv = {
val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer")
val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
BlockManagerMaster.startBlockManagerMaster(isMaster, isLocal)
var blockManager = new BlockManager(serializer)
val connectionManager = blockManager.connectionManager
val shuffleManager = new ShuffleManager()
val closureSerializerClass =
System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
val closureSerializer =
Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
val cacheTracker = new CacheTracker(isMaster, cache)
val cacheTracker = new CacheTracker(isMaster, blockManager)
blockManager.cacheTracker = cacheTracker
val mapOutputTracker = new MapOutputTracker(isMaster)
val shuffleFetcherClass =
System.getProperty("spark.shuffle.fetcher", "spark.SimpleShuffleFetcher")
System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val shuffleFetcher =
Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
val shuffleMgr = new ShuffleManager()
/*
if (System.getProperty("spark.stream.distributed", "false") == "true") {
val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
if (isLocal || !isMaster) {
(new Thread() {
override def run() {
println("Wait started")
Thread.sleep(60000)
println("Wait ended")
val receiverClass = Class.forName("spark.stream.TestStreamReceiver4")
val constructor = receiverClass.getConstructor(blockManagerClass)
val receiver = constructor.newInstance(blockManager)
receiver.asInstanceOf[Thread].start()
}
}).start()
}
}
*/
new SparkEnv(
cache,
serializer,
closureSerializer,
cacheTracker,
mapOutputTracker,
shuffleFetcher,
shuffleMgr)
new SparkEnv(cache, serializer, closureSerializer, cacheTracker, mapOutputTracker, shuffleFetcher,
shuffleManager, blockManager, connectionManager)
}
}
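
// Because the components above are created by reflection from system properties, they can be
// swapped before the environment is built. A sketch using only the property keys and class
// names that appear in this file:
System.setProperty("spark.serializer", "spark.JavaSerializer")          // default: spark.KryoSerializer
System.setProperty("spark.closure.serializer", "spark.JavaSerializer")
System.setProperty("spark.cache.class", "spark.BoundedMemoryCache")
System.setProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val env = SparkEnv.createFromSystemProperties(isMaster = true, isLocal = true)
SparkEnv.set(env)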

Просмотреть файл

@ -1,41 +0,0 @@
package spark
class Stage(
val id: Int,
val rdd: RDD[_],
val shuffleDep: Option[ShuffleDependency[_,_,_]],
val parents: List[Stage]) {
val isShuffleMap = shuffleDep != None
val numPartitions = rdd.splits.size
val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
var numAvailableOutputs = 0
def isAvailable: Boolean = {
if (parents.size == 0 && !isShuffleMap) {
true
} else {
numAvailableOutputs == numPartitions
}
}
def addOutputLoc(partition: Int, host: String) {
val prevList = outputLocs(partition)
outputLocs(partition) = host :: prevList
if (prevList == Nil)
numAvailableOutputs += 1
}
def removeOutputLoc(partition: Int, host: String) {
val prevList = outputLocs(partition)
val newList = prevList.filterNot(_ == host)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
numAvailableOutputs -= 1
}
}
override def toString = "Stage " + id
override def hashCode(): Int = id
}

Просмотреть файл

@ -1,9 +0,0 @@
package spark
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable
abstract class Task[T] extends Serializable {
def run(id: Int): T
def preferredLocations: Seq[String] = Nil
def generation: Option[Long] = None
}

Просмотреть файл

@ -0,0 +1,3 @@
package spark
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable

Просмотреть файл

@ -0,0 +1,16 @@
package spark
import spark.storage.BlockManagerId
/**
* Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry
* tasks several times for "ephemeral" failures, and only report back failures that require some
* old stages to be resubmitted, such as shuffle map fetch failures.
*/
sealed trait TaskEndReason
case object Success extends TaskEndReason
case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
case class ExceptionFailure(exception: Throwable) extends TaskEndReason
case class OtherFailure(message: String) extends TaskEndReason
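
// Sketch of exhaustive handling for the sealed trait above; the handler bodies are placeholders.
import spark._
def handleTaskEnd(reason: TaskEndReason): Unit = reason match {
  case Success                     => // mark the task as finished
  case Resubmitted                 => // earlier output was lost; the task will run again
  case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
    // resubmit the stage whose map output on bmAddress could not be fetched
  case ExceptionFailure(exception) => // log the exception and retry a limited number of times
  case OtherFailure(message)       => // log the message and fail the job
}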

Просмотреть файл

@ -1,8 +0,0 @@
package spark
import scala.collection.mutable.Map
// Task result. Also contains updates to accumulator variables.
// TODO: Use of distributed cache to return result is a hack to get around
// what seems to be a bug with messages over 60KB in libprocess; fix it
private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any]) extends Serializable

Просмотреть файл

@ -33,7 +33,8 @@ class UnionRDD[T: ClassManifest](
override def splits = splits_
@transient override val dependencies = {
@transient
override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for ((rdd, index) <- rdds.zipWithIndex) {

Просмотреть файл

@ -118,6 +118,23 @@ object Utils {
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
*/
def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress
private var customHostname: Option[String] = None
/**
* Allow setting a custom hostname, because when we run on Mesos we need to use the same
* hostname that Mesos reports to the master.
*/
def setCustomHostname(hostname: String) {
customHostname = Some(hostname)
}
/**
* Get the local machine's hostname
*/
def localHostName(): String = {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
}
/**
* Returns a standard ThreadFactory except all threads are daemons.
@ -142,6 +159,14 @@ object Utils {
return threadPool
}
/**
* Return a string giving the time elapsed, in milliseconds, since the given start time,
* which should itself be a timestamp in milliseconds.
*/
def getUsedTimeMs(startTimeMs: Long): String = {
return " " + (System.currentTimeMillis - startTimeMs) + " ms "
}
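
// Usage sketch for the two helpers above; the hostname is a placeholder:
import spark.Utils
Utils.setCustomHostname("worker-7")            // e.g. the name Mesos reported for this slave
val startTimeMs = System.currentTimeMillis
Thread.sleep(250)
println("Work on " + Utils.localHostName() + " took" + Utils.getUsedTimeMs(startTimeMs))
// prints something like: Work on worker-7 took 250 ms
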
/**
* Wrapper over newFixedThreadPool.
@ -154,16 +179,6 @@ object Utils {
return threadPool
}
/**
* Get the local machine's hostname.
*/
def localHostName(): String = InetAddress.getLocalHost.getHostName
/**
* Get current host
*/
def getHost = System.getProperty("spark.hostname", localHostName())
/**
* Delete a file or directory and its contents recursively.
*/

Просмотреть файл

@ -0,0 +1,364 @@
package spark.network
import spark._
import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
import java.io._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.net._
abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
channel.configureBlocking(false)
channel.socket.setTcpNoDelay(true)
channel.socket.setReuseAddress(true)
channel.socket.setKeepAlive(true)
/*channel.socket.setReceiveBufferSize(32768) */
var onCloseCallback: Connection => Unit = null
var onExceptionCallback: (Connection, Exception) => Unit = null
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
lazy val remoteAddress = getRemoteAddress()
lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
def key() = channel.keyFor(selector)
def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
def read() {
throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
}
def write() {
throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
}
def close() {
key.cancel()
channel.close()
callOnCloseCallback()
}
def onClose(callback: Connection => Unit) {onCloseCallback = callback}
def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback}
def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback}
def callOnExceptionCallback(e: Exception) {
if (onExceptionCallback != null) {
onExceptionCallback(this, e)
} else {
logError("Error in connection to " + remoteConnectionManagerId +
" and OnExceptionCallback not registered", e)
}
}
def callOnCloseCallback() {
if (onCloseCallback != null) {
onCloseCallback(this)
} else {
logWarning("Connection to " + remoteConnectionManagerId +
" closed and OnExceptionCallback not registered")
}
}
def changeConnectionKeyInterest(ops: Int) {
if (onKeyInterestChangeCallback != null) {
onKeyInterestChangeCallback(this, ops)
} else {
throw new Exception("OnKeyInterestChangeCallback not registered")
}
}
def printRemainingBuffer(buffer: ByteBuffer) {
val bytes = new Array[Byte](buffer.remaining)
val curPosition = buffer.position
buffer.get(bytes)
bytes.foreach(x => print(x + " "))
buffer.position(curPosition)
print(" (" + bytes.size + ")")
}
def printBuffer(buffer: ByteBuffer, position: Int, length: Int) {
val bytes = new Array[Byte](length)
val curPosition = buffer.position
buffer.position(position)
buffer.get(bytes)
bytes.foreach(x => print(x + " "))
print(" (" + position + ", " + length + ")")
buffer.position(curPosition)
}
}
class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
extends Connection(SocketChannel.open, selector_) {
class Outbox(fair: Int = 0) {
val messages = new Queue[Message]()
val defaultChunkSize = 65536 //32768 //16384
var nextMessageToBeUsed = 0
def addMessage(message: Message): Unit = {
messages.synchronized{
/*messages += message*/
messages.enqueue(message)
logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
}
}
def getChunk(): Option[MessageChunk] = {
fair match {
case 0 => getChunkFIFO()
case 1 => getChunkRR()
case _ => throw new Exception("Unexpected fairness policy in outbox")
}
}
private def getChunkFIFO(): Option[MessageChunk] = {
/*logInfo("Using FIFO")*/
messages.synchronized {
while (!messages.isEmpty) {
val message = messages(0)
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
messages += message // this is probably incorrect: it won't work as FIFO
if (!message.started) logDebug("Starting to send [" + message + "]")
message.started = true
return chunk
}
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
}
}
None
}
private def getChunkRR(): Option[MessageChunk] = {
messages.synchronized {
while (!messages.isEmpty) {
/*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
/*val message = messages(nextMessageToBeUsed)*/
val message = messages.dequeue
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
messages.enqueue(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
message.started = true
message.startTime = System.currentTimeMillis
}
logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
return chunk
}
/*messages -= message*/
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
}
}
None
}
}
val outbox = new Outbox(1)
val currentBuffers = new ArrayBuffer[ByteBuffer]()
/*channel.socket.setSendBufferSize(256 * 1024)*/
override def getRemoteAddress() = address
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
if (channel.isConnected) {
changeConnectionKeyInterest(SelectionKey.OP_WRITE)
}
}
}
def connect() {
try{
channel.connect(address)
channel.register(selector, SelectionKey.OP_CONNECT)
logInfo("Initiating connection to [" + address + "]")
} catch {
case e: Exception => {
logError("Error connecting to " + address, e)
callOnExceptionCallback(e)
}
}
}
def finishConnect() {
try {
channel.finishConnect
changeConnectionKeyInterest(SelectionKey.OP_WRITE)
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
} catch {
case e: Exception => {
logWarning("Error finishing connection to " + address, e)
callOnExceptionCallback(e)
}
}
}
override def write() {
try{
while(true) {
if (currentBuffers.size == 0) {
outbox.synchronized {
outbox.getChunk match {
case Some(chunk) => {
currentBuffers ++= chunk.buffers
}
case None => {
changeConnectionKeyInterest(0)
/*key.interestOps(0)*/
return
}
}
}
}
if (currentBuffers.size > 0) {
val buffer = currentBuffers(0)
val remainingBytes = buffer.remaining
val writtenBytes = channel.write(buffer)
if (buffer.remaining == 0) {
currentBuffers -= buffer
}
if (writtenBytes < remainingBytes) {
return
}
}
}
} catch {
case e: Exception => {
logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
callOnExceptionCallback(e)
close()
}
}
}
}
class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
extends Connection(channel_, selector_) {
class Inbox() {
val messages = new HashMap[Int, BufferMessage]()
def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
def createNewMessage: BufferMessage = {
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]")
messages += ((newMessage.id, newMessage))
newMessage
}
val message = messages.getOrElseUpdate(header.id, createNewMessage)
logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
message.getChunkForReceiving(header.chunkSize)
}
def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
messages.get(chunk.header.id)
}
def removeMessage(message: Message) {
messages -= message.id
}
}
val inbox = new Inbox()
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
var onReceiveCallback: (Connection , Message) => Unit = null
var currentChunk: MessageChunk = null
channel.register(selector, SelectionKey.OP_READ)
override def read() {
try {
while (true) {
if (currentChunk == null) {
val headerBytesRead = channel.read(headerBuffer)
if (headerBytesRead == -1) {
close()
return
}
if (headerBuffer.remaining > 0) {
return
}
headerBuffer.flip
if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
}
val header = MessageChunkHeader.create(headerBuffer)
headerBuffer.clear()
header.typ match {
case Message.BUFFER_MESSAGE => {
if (header.totalSize == 0) {
if (onReceiveCallback != null) {
onReceiveCallback(this, Message.create(header))
}
currentChunk = null
return
} else {
currentChunk = inbox.getChunk(header).orNull
}
}
case _ => throw new Exception("Message of unknown type received")
}
}
if (currentChunk == null) throw new Exception("No message chunk to receive data")
val bytesRead = channel.read(currentChunk.buffer)
if (bytesRead == 0) {
return
} else if (bytesRead == -1) {
close()
return
}
/*logDebug("Read " + bytesRead + " bytes for the buffer")*/
if (currentChunk.buffer.remaining == 0) {
/*println("Filled buffer at " + System.currentTimeMillis)*/
val bufferMessage = inbox.getMessageForChunk(currentChunk).get
if (bufferMessage.isCompletelyReceived) {
bufferMessage.flip
bufferMessage.finishTime = System.currentTimeMillis
logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken)
if (onReceiveCallback != null) {
onReceiveCallback(this, bufferMessage)
}
inbox.removeMessage(bufferMessage)
}
currentChunk = null
}
}
} catch {
case e: Exception => {
logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
callOnExceptionCallback(e)
close()
}
}
}
def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
}

Просмотреть файл

@ -0,0 +1,467 @@
package spark.network
import spark._
import scala.actors.Future
import scala.actors.Futures.future
import scala.collection.mutable.HashMap
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
import scala.collection.mutable.Queue
import scala.collection.mutable.ArrayBuffer
import java.io._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.net._
import java.util.concurrent.Executors
case class ConnectionManagerId(val host: String, val port: Int) {
def toSocketAddress() = new InetSocketAddress(host, port)
}
object ConnectionManagerId {
def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
}
}
class ConnectionManager(port: Int) extends Logging {
case class MessageStatus(message: Message, connectionManagerId: ConnectionManagerId) {
var ackMessage: Option[Message] = None
var attempted = false
var acked = false
}
val selector = SelectorProvider.provider.openSelector()
/*val handleMessageExecutor = new ThreadPoolExecutor(4, 4, 600, TimeUnit.SECONDS, new LinkedBlockingQueue()) */
val handleMessageExecutor = Executors.newFixedThreadPool(4)
val serverChannel = ServerSocketChannel.open()
val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val messageStatuses = new HashMap[Int, MessageStatus]
val connectionRequests = new SynchronizedQueue[SendingConnection]
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)]
var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
serverChannel.socket.setReceiveBufferSize(256 * 1024)
serverChannel.socket.bind(new InetSocketAddress(port))
serverChannel.register(selector, SelectionKey.OP_ACCEPT)
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
val thisInstance = this
var selectorThread = new Thread("connection-manager-thread") {
override def run() {
thisInstance.run()
}
}
selectorThread.setDaemon(true)
selectorThread.start()
def run() {
try {
var interrupted = false
while(!interrupted) {
while(!connectionRequests.isEmpty) {
val sendingConnection = connectionRequests.dequeue
sendingConnection.connect()
addConnection(sendingConnection)
}
sendMessageRequests.synchronized {
while(!sendMessageRequests.isEmpty) {
val (message, connection) = sendMessageRequests.dequeue
connection.send(message)
}
}
while(!keyInterestChangeRequests.isEmpty) {
val (key, ops) = keyInterestChangeRequests.dequeue
val connection = connectionsByKey(key)
val lastOps = key.interestOps()
key.interestOps(ops)
def intToOpStr(op: Int): String = {
val opStrs = ArrayBuffer[String]()
if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
}
logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
"] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
}
val selectedKeysCount = selector.select()
if (selectedKeysCount == 0) logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
interrupted = selectorThread.isInterrupted
val selectedKeys = selector.selectedKeys().iterator()
while (selectedKeys.hasNext()) {
val key = selectedKeys.next.asInstanceOf[SelectionKey]
selectedKeys.remove()
if (key.isValid) {
if (key.isAcceptable) {
acceptConnection(key)
} else
if (key.isConnectable) {
connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
} else
if (key.isReadable) {
connectionsByKey(key).read()
} else
if (key.isWritable) {
connectionsByKey(key).write()
}
}
}
}
} catch {
case e: Exception => logError("Error in select loop", e)
}
}
def acceptConnection(key: SelectionKey) {
val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
val newChannel = serverChannel.accept()
val newConnection = new ReceivingConnection(newChannel, selector)
newConnection.onReceive(receiveMessage)
newConnection.onClose(removeConnection)
addConnection(newConnection)
logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
}
def addConnection(connection: Connection) {
connectionsByKey += ((connection.key, connection))
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
}
connection.onKeyInterestChange(changeConnectionKeyInterest)
connection.onException(handleConnectionError)
connection.onClose(removeConnection)
}
def removeConnection(connection: Connection) {
/*logInfo("Removing connection")*/
connectionsByKey -= connection.key
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
connectionsById -= sendingConnectionManagerId
messageStatuses.synchronized {
messageStatuses
.values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
logInfo("Notifying " + status)
status.synchronized {
status.attempted = true
status.acked = false
status.notifyAll
}
})
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
} else if (connection.isInstanceOf[ReceivingConnection]) {
val receivingConnection = connection.asInstanceOf[ReceivingConnection]
val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
if (sendingConnectionManagerId == null) {
logError("Corresponding SendingConnectionManagerId not found")
return
}
logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
val sendingConnection = connectionsById(sendingConnectionManagerId)
sendingConnection.close()
connectionsById -= sendingConnectionManagerId
messageStatuses.synchronized {
messageStatuses
.values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
logInfo("Notifying " + status)
status.synchronized {
status.attempted = true
status.acked = false
status.notifyAll
}
})
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
}
}
def handleConnectionError(connection: Connection, e: Exception) {
logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
removeConnection(connection)
}
def changeConnectionKeyInterest(connection: Connection, ops: Int) {
keyInterestChangeRequests += ((connection.key, ops))
}
def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
logInfo("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
val creationTime = System.currentTimeMillis
def run() {
logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
handleMessage(connectionManagerId, message)
logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
}
}
handleMessageExecutor.execute(runnable)
/*handleMessage(connection, message)*/
}
private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
logInfo("Handling [" + message + "] from [" + connectionManagerId + "]")
message match {
case bufferMessage: BufferMessage => {
if (bufferMessage.hasAckId) {
val sentMessageStatus = messageStatuses.synchronized {
messageStatuses.get(bufferMessage.ackId) match {
case Some(status) => {
messageStatuses -= bufferMessage.ackId
status
}
case None => {
throw new Exception("Could not find reference for received ack message " + message.id)
null
}
}
}
sentMessageStatus.synchronized {
sentMessageStatus.ackMessage = Some(message)
sentMessageStatus.attempted = true
sentMessageStatus.acked = true
sentMessageStatus.notifyAll
}
} else {
val ackMessage = if (onReceiveCallback != null) {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
logWarning("Not calling back as callback is null")
None
}
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logWarning("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
logWarning("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
}
}
sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}
}
case _ => throw new Exception("Unknown type message received")
}
}
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
val newConnection = new SendingConnection(inetSocketAddress, selector)
connectionRequests += newConnection
newConnection
}
val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection)
message.senderAddress = id.toSocketAddress()
logInfo("Sending [" + message + "] to [" + connectionManagerId + "]")
/*connection.send(message)*/
sendMessageRequests.synchronized {
sendMessageRequests += ((message, connection))
}
selector.wakeup()
}
def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message): Future[Option[Message]] = {
val messageStatus = new MessageStatus(message, connectionManagerId)
messageStatuses.synchronized {
messageStatuses += ((message.id, messageStatus))
}
sendMessage(connectionManagerId, message)
future {
messageStatus.synchronized {
if (!messageStatus.attempted) {
logTrace("Waiting, " + messageStatuses.size + " statuses" )
messageStatus.wait()
logTrace("Done waiting")
}
}
messageStatus.ackMessage
}
}
def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = {
sendMessageReliably(connectionManagerId, message)()
}
def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
onReceiveCallback = callback
}
def stop() {
selectorThread.interrupt()
selectorThread.join()
selector.close()
val connections = connectionsByKey.values
connections.foreach(_.close())
if (connectionsByKey.size != 0) {
logWarning("All connections not cleaned up")
}
handleMessageExecutor.shutdown()
logInfo("ConnectionManager stopped")
}
}
object ConnectionManager {
def main(args: Array[String]) {
val manager = new ConnectionManager(9999)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
})
/*testSequentialSending(manager)*/
/*System.gc()*/
/*testParallelSending(manager)*/
/*System.gc()*/
/*testParallelDecreasingSending(manager)*/
/*System.gc()*/
testContinuousSending(manager)
System.gc()
}
def testSequentialSending(manager: ConnectionManager) {
println("--------------------------")
println("Sequential Sending")
println("--------------------------")
val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
(0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliablySync(manager.id, bufferMessage)
})
println("--------------------------")
println()
}
def testParallelSending(manager: ConnectionManager) {
println("--------------------------")
println("Parallel Sending")
println("--------------------------")
val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
val startTime = System.currentTimeMillis
(0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {if (!f().isDefined) println("Failed")})
val finishTime = System.currentTimeMillis
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
println("Started at " + startTime + ", finished at " + finishTime)
println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
}
def testParallelDecreasingSending(manager: ConnectionManager) {
println("--------------------------")
println("Parallel Decreasing Sending")
println("--------------------------")
val size = 10 * 1024 * 1024
val count = 10
val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
buffers.foreach(_.flip)
val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0
val startTime = System.currentTimeMillis
(0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {if (!f().isDefined) println("Failed")})
val finishTime = System.currentTimeMillis
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
/*println("Started at " + startTime + ", finished at " + finishTime) */
println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
}
def testContinuousSending(manager: ConnectionManager) {
println("--------------------------")
println("Continuous Sending")
println("--------------------------")
val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
val startTime = System.currentTimeMillis
while(true) {
(0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {if (!f().isDefined) println("Failed")})
val finishTime = System.currentTimeMillis
Thread.sleep(1000)
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
println()
}
}
}

Просмотреть файл

@ -0,0 +1,74 @@
package spark.network
import spark._
import spark.SparkContext._
import scala.io.Source
import java.nio.ByteBuffer
import java.net.InetAddress
object ConnectionManagerTest extends Logging {
def main(args: Array[String]) {
if (args.length < 2) {
println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
System.exit(1)
}
if (args(0).startsWith("local")) {
println("This runs only on a mesos cluster")
}
val sc = new SparkContext(args(0), "ConnectionManagerTest")
val slavesFile = Source.fromFile(args(1))
val slaves = slavesFile.mkString.split("\n")
slavesFile.close()
/*println("Slaves")*/
/*slaves.foreach(println)*/
val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map(
i => SparkEnv.get.connectionManager.id).collect()
println("\nSlave ConnectionManagerIds")
slaveConnManagerIds.foreach(println)
println
val count = 10
(0 until count).foreach(i => {
val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => {
val connManager = SparkEnv.get.connectionManager
val thisConnManagerId = connManager.id
connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
logInfo("Received [" + msg + "] from [" + id + "]")
None
})
val size = 100 * 1024 * 1024
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
val startTime = System.currentTimeMillis
val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
})
val results = futures.map(f => f())
val finishTime = System.currentTimeMillis
Thread.sleep(5000)
val mb = size * results.size / 1024.0 / 1024.0
val ms = finishTime - startTime
val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
logInfo(resultStr)
resultStr
}).collect()
println("---------------------")
println("Run " + i)
resultStrs.foreach(println)
println("---------------------")
})
}
}

Просмотреть файл

@ -0,0 +1,219 @@
package spark.network
import spark._
import scala.collection.mutable.ArrayBuffer
import java.nio.ByteBuffer
import java.net.InetAddress
import java.net.InetSocketAddress
class MessageChunkHeader(
val typ: Long,
val id: Int,
val totalSize: Int,
val chunkSize: Int,
val other: Int,
val address: InetSocketAddress) {
lazy val buffer = {
val ip = address.getAddress.getAddress()
val port = address.getPort()
ByteBuffer.
allocate(MessageChunkHeader.HEADER_SIZE).
putLong(typ).
putInt(id).
putInt(totalSize).
putInt(chunkSize).
putInt(other).
putInt(ip.size).
put(ip).
putInt(port).
position(MessageChunkHeader.HEADER_SIZE).
flip.asInstanceOf[ByteBuffer]
}
override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
" and sizes " + totalSize + " / " + chunkSize + " bytes"
}
class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
val size = if (buffer == null) 0 else buffer.remaining
lazy val buffers = {
val ab = new ArrayBuffer[ByteBuffer]()
ab += header.buffer
if (buffer != null) {
ab += buffer
}
ab
}
override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
}
abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
var started = false
var startTime = -1L
var finishTime = -1L
def size: Int
def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
def timeTaken(): String = (finishTime - startTime).toString + " ms"
override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
}
class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
extends Message(Message.BUFFER_MESSAGE, id_) {
val initialSize = currentSize()
var gotChunkForSendingOnce = false
def size = initialSize
def currentSize() = {
if (buffers == null || buffers.isEmpty) {
0
} else {
buffers.map(_.remaining).reduceLeft(_ + _)
}
}
def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
if (maxChunkSize <= 0) {
throw new Exception("Max chunk size is " + maxChunkSize)
}
    if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
while(!buffers.isEmpty) {
val buffer = buffers(0)
if (buffer.remaining == 0) {
buffers -= buffer
} else {
val newBuffer = if (buffer.remaining <= maxChunkSize) {
buffer.duplicate
} else {
buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
}
None
}
def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
// STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
if (buffers.size > 1) {
throw new Exception("Attempting to get chunk from message with multiple data buffers")
}
val buffer = buffers(0)
if (buffer.remaining > 0) {
if (buffer.remaining < chunkSize) {
throw new Exception("Not enough space in data buffer for receiving chunk")
}
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
return Some(newChunk)
}
None
}
def flip() {
buffers.foreach(_.flip)
}
def hasAckId() = (ackId != 0)
def isCompletelyReceived() = !buffers(0).hasRemaining
override def toString = {
if (hasAckId) {
"BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
} else {
"BufferMessage(id = " + id + ", size = " + size + ")"
}
}
}
object MessageChunkHeader {
val HEADER_SIZE = 40
def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
throw new IllegalArgumentException("Cannot convert buffer data to Message")
}
val typ = buffer.getLong()
val id = buffer.getInt()
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
}
}
object Message {
val BUFFER_MESSAGE = 1111111111L
var lastId = 1
def getNewId() = synchronized {
lastId += 1
if (lastId == 0) lastId += 1
lastId
}
def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
if (dataBuffers == null) {
return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
}
if (dataBuffers.exists(_ == null)) {
throw new Exception("Attempting to create buffer message with null buffer")
}
return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
}
def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
createBufferMessage(dataBuffers, 0)
def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
if (dataBuffer == null) {
return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
} else {
return createBufferMessage(Array(dataBuffer), ackId)
}
}
def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
createBufferMessage(dataBuffer, 0)
def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId)
def create(header: MessageChunkHeader): Message = {
val newMessage: Message = header.typ match {
case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
}
newMessage.senderAddress = header.address
newMessage
}
}

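The classes above define the wire format used by the connection layer: each chunk carries a fixed 40-byte header (message type, message id, total size, chunk size, ack id, and the sender's IP and port) followed by an optional payload. As a rough usage sketch, not part of this commit, the snippet below builds a 1 MB BufferMessage and drains it into chunks the way a sending connection would; the object name and the 64 KB chunk size are invented for illustration.

import java.net.InetSocketAddress
import java.nio.ByteBuffer

import spark.network.Message

object MessageChunkingSketch {
  def main(args: Array[String]) {
    // Build a 1 MB message and pretend it is being sent from localhost:9999.
    val data = ByteBuffer.allocate(1024 * 1024)
    val msg = Message.createBufferMessage(data)
    msg.senderAddress = new InetSocketAddress("localhost", 9999)

    // Repeatedly ask for 64 KB chunks until the message is fully drained.
    var chunks = 0
    var chunk = msg.getChunkForSending(64 * 1024)
    while (chunk.isDefined) {
      chunks += 1
      chunk = msg.getChunkForSending(64 * 1024)
    }
    println("Message of " + msg.size + " bytes was split into " + chunks + " chunks")
  }
}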
View file

@@ -0,0 +1,20 @@
package spark.network
import java.nio.ByteBuffer
import java.net.InetAddress
object ReceiverTest {
def main(args: Array[String]) {
val manager = new ConnectionManager(9999)
println("Started connection manager with id = " + manager.id)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
/*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/
val buffer = ByteBuffer.wrap("response".getBytes())
Some(Message.createBufferMessage(buffer, msg.id))
})
Thread.currentThread.join()
}
}

View file

@@ -0,0 +1,53 @@
package spark.network
import java.nio.ByteBuffer
import java.net.InetAddress
object SenderTest {
def main(args: Array[String]) {
if (args.length < 2) {
println("Usage: SenderTest <target host> <target port>")
System.exit(1)
}
val targetHost = args(0)
val targetPort = args(1).toInt
val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
val manager = new ConnectionManager(0)
println("Started connection manager with id = " + manager.id)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
})
val size = 100 * 1024 * 1024
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
val targetServer = args(0)
val count = 100
(0 until count).foreach(i => {
val dataMessage = Message.createBufferMessage(buffer.duplicate)
val startTime = System.currentTimeMillis
/*println("Started timer at " + startTime)*/
val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match {
case Some(response) =>
val buffer = response.asInstanceOf[BufferMessage].buffers(0)
new String(buffer.array)
case None => "none"
}
val finishTime = System.currentTimeMillis
val mb = size / 1024.0 / 1024.0
val ms = finishTime - startTime
/*val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"*/
val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr
println(resultStr)
})
}
}

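Taken together, ReceiverTest and SenderTest form a simple point-to-point benchmark for the connection manager: ReceiverTest starts a ConnectionManager on port 9999 and answers every incoming message with a small "response" buffer, while SenderTest connects to the given host and port and synchronously sends 100 messages of 100 MB each, printing the measured throughput and the echoed response for every round trip.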
View file

@@ -0,0 +1,66 @@
package spark.partial
import spark._
import spark.scheduler.JobListener
/**
* A JobListener for an approximate single-result action, such as count() or non-parallel reduce().
* This listener waits up to timeout milliseconds and will return a partial answer even if the
* complete answer is not available by then.
*
* This class assumes that the action is performed on an entire RDD[T] via a function that computes
* a result of type U for each partition, and that the action returns a partial or complete result
 * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedDouble).
*/
class ApproximateActionListener[T, U, R](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long)
extends JobListener {
val startTime = System.currentTimeMillis()
val totalTasks = rdd.splits.size
var finishedTasks = 0
var failure: Option[Exception] = None // Set if the job has failed (permanently)
var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult
override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
evaluator.merge(index, result.asInstanceOf[U])
finishedTasks += 1
if (finishedTasks == totalTasks) {
// If we had already returned a PartialResult, set its final value
resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
// Notify any waiting thread that may have called getResult
this.notifyAll()
}
}
override def jobFailed(exception: Exception): Unit = synchronized {
failure = Some(exception)
this.notifyAll()
}
/**
* Waits for up to timeout milliseconds since the listener was created and then returns a
* PartialResult with the result so far. This may be complete if the whole job is done.
*/
def getResult(): PartialResult[R] = synchronized {
val finishTime = startTime + timeout
while (true) {
val time = System.currentTimeMillis()
if (failure != None) {
throw failure.get
} else if (finishedTasks == totalTasks) {
return new PartialResult(evaluator.currentResult(), true)
} else if (time >= finishTime) {
resultObject = Some(new PartialResult(evaluator.currentResult(), false))
return resultObject.get
} else {
this.wait(finishTime - time)
}
}
// Should never be reached, but required to keep the compiler happy
return null
}
}

View file

@@ -0,0 +1,10 @@
package spark.partial
/**
* An object that computes a function incrementally by merging in results of type U from multiple
* tasks. Allows partial evaluation at any point by calling currentResult().
*/
trait ApproximateEvaluator[U, R] {
def merge(outputId: Int, taskResult: U): Unit
def currentResult(): R
}

View file

@@ -0,0 +1,8 @@
package spark.partial
/**
* A Double with error bars on it.
*/
class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
override def toString(): String = "[%.3f, %.3f]".format(low, high)
}

View file

@@ -0,0 +1,38 @@
package spark.partial
import cern.jet.stat.Probability
/**
* An ApproximateEvaluator for counts.
*
* TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might
* be best to make this a special case of GroupedCountEvaluator with one group.
*/
class CountEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[Long, BoundedDouble] {
var outputsMerged = 0
var sum: Long = 0
override def merge(outputId: Int, taskResult: Long) {
outputsMerged += 1
sum += taskResult
}
override def currentResult(): BoundedDouble = {
if (outputsMerged == totalOutputs) {
new BoundedDouble(sum, 1.0, sum, sum)
} else if (outputsMerged == 0) {
new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
} else {
val p = outputsMerged.toDouble / totalOutputs
val mean = (sum + 1 - p) / p
val variance = (sum + 1) * (1 - p) / (p * p)
val stdev = math.sqrt(variance)
val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
val low = mean - confFactor * stdev
val high = mean + confFactor * stdev
new BoundedDouble(mean, confidence, low, high)
}
}
}

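As an illustration-only sketch (the object name and numbers are invented, not from the commit), the snippet below shows how the extrapolation above behaves: with half of the outputs merged, the observed sum is scaled up by 1/p and wrapped in a confidence interval, and once every output has arrived the interval collapses to the exact count.

import spark.partial.CountEvaluator

object CountEvaluatorSketch {
  def main(args: Array[String]) {
    // 10 map outputs in total, 95% confidence, 1000 elements counted by each task.
    val eval = new CountEvaluator(10, 0.95)
    (0 until 5).foreach(i => eval.merge(i, 1000L))
    // p = 0.5 and sum = 5000, so the mean is about 10001 and the bound roughly [9805, 10197].
    println("After 5 of 10 outputs: " + eval.currentResult())
    (5 until 10).foreach(i => eval.merge(i, 1000L))
    // With all outputs merged the result is exact: [10000.000, 10000.000].
    println("After all 10 outputs:  " + eval.currentResult())
  }
}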
View file

@@ -0,0 +1,62 @@
package spark.partial
import java.util.{HashMap => JHashMap}
import java.util.{Map => JMap}
import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions.mapAsScalaMap
import cern.jet.stat.Probability
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
/**
* An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval.
*/
class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] {
var outputsMerged = 0
var sums = new OLMap[T] // Sum of counts for each key
override def merge(outputId: Int, taskResult: OLMap[T]) {
outputsMerged += 1
val iter = taskResult.object2LongEntrySet.fastIterator()
while (iter.hasNext) {
val entry = iter.next()
sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue)
}
}
override def currentResult(): Map[T, BoundedDouble] = {
if (outputsMerged == totalOutputs) {
val result = new JHashMap[T, BoundedDouble](sums.size)
val iter = sums.object2LongEntrySet.fastIterator()
while (iter.hasNext) {
val entry = iter.next()
val sum = entry.getLongValue()
result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
}
result
} else if (outputsMerged == 0) {
new HashMap[T, BoundedDouble]
} else {
val p = outputsMerged.toDouble / totalOutputs
val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
val result = new JHashMap[T, BoundedDouble](sums.size)
val iter = sums.object2LongEntrySet.fastIterator()
while (iter.hasNext) {
val entry = iter.next()
val sum = entry.getLongValue
val mean = (sum + 1 - p) / p
val variance = (sum + 1) * (1 - p) / (p * p)
val stdev = math.sqrt(variance)
val low = mean - confFactor * stdev
val high = mean + confFactor * stdev
result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
}
result
}
}
}

View file

@@ -0,0 +1,65 @@
package spark.partial
import java.util.{HashMap => JHashMap}
import java.util.{Map => JMap}
import scala.collection.mutable.HashMap
import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import spark.util.StatCounter
/**
* An ApproximateEvaluator for means by key. Returns a map of key to confidence interval.
*/
class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
var outputsMerged = 0
var sums = new JHashMap[T, StatCounter] // Sum of counts for each key
override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
outputsMerged += 1
val iter = taskResult.entrySet.iterator()
while (iter.hasNext) {
val entry = iter.next()
val old = sums.get(entry.getKey)
if (old != null) {
old.merge(entry.getValue)
} else {
sums.put(entry.getKey, entry.getValue)
}
}
}
override def currentResult(): Map[T, BoundedDouble] = {
if (outputsMerged == totalOutputs) {
val result = new JHashMap[T, BoundedDouble](sums.size)
val iter = sums.entrySet.iterator()
while (iter.hasNext) {
val entry = iter.next()
val mean = entry.getValue.mean
result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean)
}
result
} else if (outputsMerged == 0) {
new HashMap[T, BoundedDouble]
} else {
val p = outputsMerged.toDouble / totalOutputs
val studentTCacher = new StudentTCacher(confidence)
val result = new JHashMap[T, BoundedDouble](sums.size)
val iter = sums.entrySet.iterator()
while (iter.hasNext) {
val entry = iter.next()
val counter = entry.getValue
val mean = counter.mean
val stdev = math.sqrt(counter.sampleVariance / counter.count)
val confFactor = studentTCacher.get(counter.count)
val low = mean - confFactor * stdev
val high = mean + confFactor * stdev
result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
}
result
}
}
}

View file

@@ -0,0 +1,72 @@
package spark.partial
import java.util.{HashMap => JHashMap}
import java.util.{Map => JMap}
import scala.collection.mutable.HashMap
import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import spark.util.StatCounter
/**
* An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval.
*/
class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
var outputsMerged = 0
var sums = new JHashMap[T, StatCounter] // Sum of counts for each key
override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
outputsMerged += 1
val iter = taskResult.entrySet.iterator()
while (iter.hasNext) {
val entry = iter.next()
val old = sums.get(entry.getKey)
if (old != null) {
old.merge(entry.getValue)
} else {
sums.put(entry.getKey, entry.getValue)
}
}
}
override def currentResult(): Map[T, BoundedDouble] = {
if (outputsMerged == totalOutputs) {
val result = new JHashMap[T, BoundedDouble](sums.size)
val iter = sums.entrySet.iterator()
while (iter.hasNext) {
val entry = iter.next()
val sum = entry.getValue.sum
result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
}
result
} else if (outputsMerged == 0) {
new HashMap[T, BoundedDouble]
} else {
val p = outputsMerged.toDouble / totalOutputs
val studentTCacher = new StudentTCacher(confidence)
val result = new JHashMap[T, BoundedDouble](sums.size)
val iter = sums.entrySet.iterator()
while (iter.hasNext) {
val entry = iter.next()
val counter = entry.getValue
val meanEstimate = counter.mean
val meanVar = counter.sampleVariance / counter.count
val countEstimate = (counter.count + 1 - p) / p
val countVar = (counter.count + 1) * (1 - p) / (p * p)
val sumEstimate = meanEstimate * countEstimate
val sumVar = (meanEstimate * meanEstimate * countVar) +
(countEstimate * countEstimate * meanVar) +
(meanVar * countVar)
val sumStdev = math.sqrt(sumVar)
val confFactor = studentTCacher.get(counter.count)
val low = sumEstimate - confFactor * sumStdev
val high = sumEstimate + confFactor * sumStdev
result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high)
}
result
}
}
}

View file

@@ -0,0 +1,41 @@
package spark.partial
import cern.jet.stat.Probability
import spark.util.StatCounter
/**
* An ApproximateEvaluator for means.
*/
class MeanEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] {
var outputsMerged = 0
var counter = new StatCounter
override def merge(outputId: Int, taskResult: StatCounter) {
outputsMerged += 1
counter.merge(taskResult)
}
override def currentResult(): BoundedDouble = {
if (outputsMerged == totalOutputs) {
new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean)
} else if (outputsMerged == 0) {
new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
} else {
val mean = counter.mean
val stdev = math.sqrt(counter.sampleVariance / counter.count)
val confFactor = {
if (counter.count > 100) {
Probability.normalInverse(1 - (1 - confidence) / 2)
} else {
Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
}
}
val low = mean - confFactor * stdev
val high = mean + confFactor * stdev
new BoundedDouble(mean, confidence, low, high)
}
}
}

View file

@@ -0,0 +1,86 @@
package spark.partial
class PartialResult[R](initialVal: R, isFinal: Boolean) {
private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None
private var failure: Option[Exception] = None
private var completionHandler: Option[R => Unit] = None
private var failureHandler: Option[Exception => Unit] = None
def initialValue: R = initialVal
def isInitialValueFinal: Boolean = isFinal
/**
* Blocking method to wait for and return the final value.
*/
def getFinalValue(): R = synchronized {
while (finalValue == None && failure == None) {
this.wait()
}
if (finalValue != None) {
return finalValue.get
} else {
throw failure.get
}
}
/**
* Set a handler to be called when this PartialResult completes. Only one completion handler
* is supported per PartialResult.
*/
def onComplete(handler: R => Unit): PartialResult[R] = synchronized {
if (completionHandler != None) {
throw new UnsupportedOperationException("onComplete cannot be called twice")
}
completionHandler = Some(handler)
if (finalValue != None) {
// We already have a final value, so let's call the handler
handler(finalValue.get)
}
return this
}
/**
* Set a handler to be called if this PartialResult's job fails. Only one failure handler
* is supported per PartialResult.
*/
def onFail(handler: Exception => Unit): Unit = synchronized {
if (failureHandler != None) {
throw new UnsupportedOperationException("onFail cannot be called twice")
}
failureHandler = Some(handler)
if (failure != None) {
// We already have a failure, so let's call the handler
handler(failure.get)
}
}
private[spark] def setFinalValue(value: R): Unit = synchronized {
if (finalValue != None) {
throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult")
}
finalValue = Some(value)
// Call the completion handler if it was set
completionHandler.foreach(h => h(value))
// Notify any threads that may be calling getFinalValue()
this.notifyAll()
}
private[spark] def setFailure(exception: Exception): Unit = synchronized {
if (failure != None) {
throw new UnsupportedOperationException("setFailure called twice on a PartialResult")
}
failure = Some(exception)
// Call the failure handler if it was set
failureHandler.foreach(h => h(exception))
// Notify any threads that may be calling getFinalValue()
this.notifyAll()
}
override def toString: String = synchronized {
finalValue match {
case Some(value) => "(final: " + value + ")"
case None => "(partial: " + initialValue + ")"
}
}
}

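The sketch below (illustration only, with invented names and values) walks through the lifecycle this class implements: one thread registers a completion handler and blocks in getFinalValue(), while a second thread, standing in for the scheduler, eventually supplies the final value. Since setFinalValue() is private[spark], the sketch is placed in the spark.partial package.

package spark.partial

object PartialResultSketch {
  def main(args: Array[String]) {
    val initial = new BoundedDouble(100.0, 0.90, 90.0, 110.0)
    val result = new PartialResult[BoundedDouble](initial, false)
    result.onComplete(b => println("completion handler got " + b))

    // Stand-in for the scheduler finishing the job a little later.
    new Thread("fake-scheduler") {
      override def run() {
        Thread.sleep(500)
        result.setFinalValue(new BoundedDouble(102.0, 1.0, 102.0, 102.0))
      }
    }.start()

    // Blocks until setFinalValue() is called; the handler above fires at that point too.
    println("getFinalValue() returned " + result.getFinalValue())
  }
}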
View file

@@ -0,0 +1,26 @@
package spark.partial
import cern.jet.stat.Probability
/**
* A utility class for caching Student's T distribution values for a given confidence level
 * and various sample sizes. This is used by GroupedMeanEvaluator and GroupedSumEvaluator to
 * efficiently calculate confidence intervals for many keys.
*/
class StudentTCacher(confidence: Double) {
val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
def get(sampleSize: Long): Double = {
if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) {
normalApprox
} else {
val size = sampleSize.toInt
if (cache(size) < 0) {
cache(size) = Probability.studentTInverse(1 - confidence, size - 1)
}
cache(size)
}
}
}

View file

@@ -0,0 +1,51 @@
package spark.partial
import cern.jet.stat.Probability
import spark.util.StatCounter
/**
 * An ApproximateEvaluator for sums. It estimates the mean and the count and multiplies them
* together, then uses the formula for the variance of two independent random variables to get
* a variance for the result and compute a confidence interval.
*/
class SumEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] {
var outputsMerged = 0
var counter = new StatCounter
override def merge(outputId: Int, taskResult: StatCounter) {
outputsMerged += 1
counter.merge(taskResult)
}
override def currentResult(): BoundedDouble = {
if (outputsMerged == totalOutputs) {
new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
} else if (outputsMerged == 0) {
new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
} else {
val p = outputsMerged.toDouble / totalOutputs
val meanEstimate = counter.mean
val meanVar = counter.sampleVariance / counter.count
val countEstimate = (counter.count + 1 - p) / p
val countVar = (counter.count + 1) * (1 - p) / (p * p)
val sumEstimate = meanEstimate * countEstimate
val sumVar = (meanEstimate * meanEstimate * countVar) +
(countEstimate * countEstimate * meanVar) +
(meanVar * countVar)
val sumStdev = math.sqrt(sumVar)
val confFactor = {
if (counter.count > 100) {
Probability.normalInverse(1 - (1 - confidence) / 2)
} else {
Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
}
}
val low = sumEstimate - confFactor * sumStdev
val high = sumEstimate + confFactor * sumStdev
new BoundedDouble(sumEstimate, confidence, low, high)
}
}
}

View file

@@ -0,0 +1,18 @@
package spark.scheduler
import spark.TaskContext
/**
* Tracks information about an active job in the DAGScheduler.
*/
class ActiveJob(
val runId: Int,
val finalStage: Stage,
val func: (TaskContext, Iterator[_]) => _,
val partitions: Array[Int],
val listener: JobListener) {
val numPartitions = partitions.length
val finished = Array.fill[Boolean](numPartitions)(false)
var numFinished = 0
}

View file

@@ -0,0 +1,532 @@
package spark.scheduler
import java.net.URI
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.Future
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
import spark._
import spark.partial.ApproximateActionListener
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import spark.storage.BlockManagerMaster
import spark.storage.BlockManagerId
/**
 * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
 * stages for each job, keeps track of which RDDs and stage outputs are materialized, and computes
 * a minimal schedule to run the job. It then submits each stage as a TaskSet to an underlying
 * TaskScheduler implementation, which runs the tasks on the cluster and reports completions and
 * fetch failures back through the TaskSchedulerListener interface.
*/
class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
taskSched.setListener(this)
// Called by TaskScheduler to report task completions or failures.
override def taskEnded(
task: Task[_],
reason: TaskEndReason,
result: Any,
accumUpdates: Map[Long, Any]) {
eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
}
// Called by TaskScheduler when a host fails.
override def hostLost(host: String) {
eventQueue.put(HostLost(host))
}
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
val RESUBMIT_TIMEOUT = 50L
// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
val POLL_TIMEOUT = 10L
private val lock = new Object // Used for access to the entire DAGScheduler
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
val nextRunId = new AtomicInteger(0)
val nextStageId = new AtomicInteger(0)
val idToStage = new HashMap[Int, Stage]
val shuffleToMapStage = new HashMap[Int, Stage]
var cacheLocs = new HashMap[Int, Array[List[String]]]
val env = SparkEnv.get
val cacheTracker = env.cacheTracker
val mapOutputTracker = env.mapOutputTracker
val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
// that's not going to be a realistic assumption in general
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits
val activeJobs = new HashSet[ActiveJob]
val resultStageToJob = new HashMap[Stage, ActiveJob]
// Start a thread to run the DAGScheduler event loop
new Thread("DAGScheduler") {
setDaemon(true)
override def run() {
DAGScheduler.this.run()
}
}.start()
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
cacheLocs(rdd.id)
}
def updateCacheLocs() {
cacheLocs = cacheTracker.getLocationsSnapshot()
}
/**
* Get or create a shuffle map stage for the given shuffle dependency's map side.
* The priority value passed in will be used if the stage doesn't already exist with
* a lower priority (we assume that priorities always increase across jobs for now).
*/
def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_,_], priority: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
}
/**
* Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or
* as a result stage for the final RDD used directly in an action. The stage will also be given
* the provided priority.
*/
def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + ": " + rdd)
cacheTracker.registerRDD(rdd.id, rdd.splits.size)
if (shuffleDep != None) {
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
}
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
idToStage(id) = stage
stage
}
/**
* Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided priority if they haven't already been created with a lower priority.
*/
def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(r: RDD[_]) {
if (!visited(r)) {
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
logInfo("Registering parent RDD " + r.id + ": " + r)
cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_,_] =>
parents += getShuffleMapStage(shufDep, priority)
case _ =>
visit(dep.rdd)
}
}
}
}
visit(rdd)
parents.toList
}
def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
val locs = getCacheLocs(rdd)
for (p <- 0 until rdd.splits.size) {
if (locs(p) == Nil) {
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_,_] =>
val mapStage = getShuffleMapStage(shufDep, stage.priority)
if (!mapStage.isAvailable) {
missing += mapStage
}
case narrowDep: NarrowDependency[_] =>
visit(narrowDep.rdd)
}
}
}
}
}
}
visit(stage.rdd)
missing.toList
}
def runJob[T, U](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean)
(implicit m: ClassManifest[U]): Array[U] =
{
val waiter = new JobWaiter(partitions.size)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter))
waiter.getResult() match {
case JobSucceeded(results: Seq[_]) =>
return results.asInstanceOf[Seq[U]].toArray
case JobFailed(exception: Exception) =>
throw exception
}
}
def runApproximateJob[T, U, R](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long
): PartialResult[R] =
{
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.splits.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener))
return listener.getResult() // Will throw an exception if the job fails
}
/**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue.
*/
def run() = {
SparkEnv.set(env)
while (true) {
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
if (event != null) {
logDebug("Got event of type " + event.getClass.getName)
}
event match {
case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) =>
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, listener)
updateCacheLocs()
logInfo("Got job " + job.runId + " with " + partitions.length + " output partitions")
logInfo("Final stage: " + finalStage)
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
// Compute very short actions like first() or take() with no parent stages locally.
runLocally(job)
} else {
activeJobs += job
resultStageToJob(finalStage) = job
submitStage(finalStage)
}
case HostLost(host) =>
handleHostLost(host)
case completion: CompletionEvent =>
handleTaskCompletion(completion)
case null =>
// queue.poll() timed out, ignore it
}
// Periodically resubmit failed stages if some map output fetches have failed and we have
// waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
// tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
// the same time, so we want to make sure we've identified all the reduce tasks that depend
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
logInfo("Resubmitting failed stages")
updateCacheLocs()
val failed2 = failed.toArray
failed.clear()
for (stage <- failed2.sortBy(_.priority)) {
submitStage(stage)
}
} else {
// TODO: We might want to run this less often, when we are sure that something has become
// runnable that wasn't before.
logDebug("Checking for newly runnable parent stages")
logDebug("running: " + running)
logDebug("waiting: " + waiting)
logDebug("failed: " + failed)
val waiting2 = waiting.toArray
waiting.clear()
for (stage <- waiting2.sortBy(_.priority)) {
submitStage(stage)
}
}
}
}
/**
* Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
* don't block the DAGScheduler event loop or other concurrent jobs.
*/
def runLocally(job: ActiveJob) {
logInfo("Computing the requested partition locally")
new Thread("Local computation of job " + job.runId) {
override def run() {
try {
SparkEnv.set(env)
val rdd = job.finalStage.rdd
val split = rdd.splits(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
val result = job.func(taskContext, rdd.iterator(split))
job.listener.taskSucceeded(0, result)
} catch {
case e: Exception =>
job.listener.jobFailed(e)
}
}
}.start()
}
def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")")
if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
logInfo("Submitting " + stage + ", which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
for (parent <- missing) {
submitStage(parent)
}
waiting += stage
}
}
}
def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
myPending.clear()
var tasks = ArrayBuffer[Task[_]]()
if (stage.isShuffleMap) {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
val locs = getPreferredLocs(stage.rdd, p)
tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
}
} else {
// This is a final stage; figure out its job's missing partitions
val job = resultStageToJob(stage)
for (id <- 0 until job.numPartitions if (!job.finished(id))) {
val partition = job.partitions(id)
val locs = getPreferredLocs(stage.rdd, partition)
tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
}
}
if (tasks.size > 0) {
logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
running -= stage
}
}
/**
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/
def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
val stage = idToStage(task.stageId)
event.reason match {
case Success =>
logInfo("Completed " + task)
if (event.accumUpdates != null) {
Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
}
pendingTasks(stage) -= task
task match {
case rt: ResultTask[_, _] =>
resultStageToJob.get(stage) match {
case Some(job) =>
if (!job.finished(rt.outputId)) {
job.finished(rt.outputId) = true
job.numFinished += 1
job.listener.taskSucceeded(rt.outputId, event.result)
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
activeJobs -= job
resultStageToJob -= stage
running -= stage
}
}
case None =>
logInfo("Ignoring result from " + rt + " because its job has finished")
}
case smt: ShuffleMapTask =>
val stage = idToStage(smt.stageId)
val bmAddress = event.result.asInstanceOf[BlockManagerId]
val host = bmAddress.ip
logInfo("ShuffleMapTask finished with host " + host)
if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
stage.addOutputLoc(smt.partition, bmAddress)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
logInfo(stage + " finished; looking for newly runnable stages")
running -= stage
logInfo("running: " + running)
logInfo("waiting: " + waiting)
logInfo("failed: " + failed)
if (stage.shuffleDep != None) {
mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId,
stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
}
updateCacheLocs()
if (stage.outputLocs.count(_ == Nil) != 0) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + stage + " because some of its tasks had failed: " +
stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
submitStage(stage)
} else {
val newlyRunnable = new ArrayBuffer[Stage]
for (stage <- waiting) {
logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
}
for (stage <- waiting if getMissingParentStages(stage) == Nil) {
newlyRunnable += stage
}
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) {
submitMissingTasks(stage)
}
}
}
}
case Resubmitted =>
logInfo("Resubmitted " + task + ", so marking it as still running")
pendingTasks(stage) += task
case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
// Mark the stage that the reducer was in as unrunnable
val failedStage = idToStage(task.stageId)
running -= failedStage
failed += failedStage
// TODO: Cancel running tasks in the stage
logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
// Mark the map whose fetch failed as broken in the map stage
val mapStage = shuffleToMapStage(shuffleId)
mapStage.removeOutputLoc(mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
failed += mapStage
// Remember that a fetch failed now; this is used to resubmit the broken
// stages later, after a small wait (to give other tasks the chance to fail)
lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
// TODO: mark the host as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
handleHostLost(bmAddress.ip)
}
case _ =>
// Non-fetch failure -- probably a bug in the job, so bail out
// TODO: Cancel all tasks that are still running
resultStageToJob.get(stage) match {
case Some(job) =>
val error = new SparkException("Task failed: " + task + ", reason: " + event.reason)
job.listener.jobFailed(error)
activeJobs -= job
resultStageToJob -= stage
case None =>
logInfo("Ignoring result from " + task + " because its job has finished")
}
}
}
/**
* Responds to a host being lost. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use hostLost() to post a host lost event from outside.
*/
def handleHostLost(host: String) {
if (!deadHosts.contains(host)) {
logInfo("Host lost: " + host)
deadHosts += host
BlockManagerMaster.notifyADeadHost(host)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
stage.removeOutputsOnHost(host)
val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
}
cacheTracker.cacheLost(host)
updateCacheLocs()
}
}
def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations
val cached = getCacheLocs(rdd)(partition)
if (cached != Nil) {
return cached
}
// If the RDD has some placement preferences (as is the case for input RDDs), get those
val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
if (rddPrefs != Nil) {
return rddPrefs
}
// If the RDD has narrow dependencies, pick the first partition of the first narrow dep
// that has any placement preferences. Ideally we would choose based on transfer sizes,
// but this will do for now.
rdd.dependencies.foreach(_ match {
case n: NarrowDependency[_] =>
for (inPart <- n.getParents(partition)) {
val locs = getPreferredLocs(n.rdd, inPart)
if (locs != Nil)
            return locs
}
case _ =>
})
return Nil
}
def stop() {
// TODO: Put a stop event on our queue and break the event loop
taskSched.stop()
}
}

View file

@@ -0,0 +1,30 @@
package spark.scheduler
import scala.collection.mutable.Map
import spark._
/**
* Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue
* architecture where any thread can post an event (e.g. a task finishing or a new job being
* submitted) but there is a single "logic" thread that reads these events and takes decisions.
* This greatly simplifies synchronization.
*/
sealed trait DAGSchedulerEvent
case class JobSubmitted(
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
allowLocal: Boolean,
listener: JobListener)
extends DAGSchedulerEvent
case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
accumUpdates: Map[Long, Any])
extends DAGSchedulerEvent
case class HostLost(host: String) extends DAGSchedulerEvent

View file

@@ -0,0 +1,11 @@
package spark.scheduler
/**
* Interface used to listen for job completion or failure events after submitting a job to the
* DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole
* job fails (and no further taskSucceeded events will happen).
*/
trait JobListener {
def taskSucceeded(index: Int, result: Any)
def jobFailed(exception: Exception)
}

View file

@@ -0,0 +1,9 @@
package spark.scheduler
/**
* A result of a job in the DAGScheduler.
*/
sealed trait JobResult
case class JobSucceeded(results: Seq[_]) extends JobResult
case class JobFailed(exception: Exception) extends JobResult

View file

@@ -0,0 +1,43 @@
package spark.scheduler
import scala.collection.mutable.ArrayBuffer
/**
* An object that waits for a DAGScheduler job to complete.
*/
class JobWaiter(totalTasks: Int) extends JobListener {
private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
private var finishedTasks = 0
private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
private var jobResult: JobResult = null // If the job is finished, this will be its result
override def taskSucceeded(index: Int, result: Any) = synchronized {
if (jobFinished) {
throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
}
taskResults(index) = result
finishedTasks += 1
if (finishedTasks == totalTasks) {
jobFinished = true
jobResult = JobSucceeded(taskResults)
this.notifyAll()
}
}
override def jobFailed(exception: Exception) = synchronized {
if (jobFinished) {
throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
}
jobFinished = true
jobResult = JobFailed(exception)
this.notifyAll()
}
def getResult(): JobResult = synchronized {
while (!jobFinished) {
this.wait()
}
return jobResult
}
}

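A stand-alone sketch of the JobWaiter contract (illustration only; the object and thread names are invented): the scheduler side reports one result per task through taskSucceeded(), and the submitting thread blocks in getResult() until either all tasks have finished or jobFailed() is called.

package spark.scheduler

object JobWaiterSketch {
  def main(args: Array[String]) {
    val waiter = new JobWaiter(2)

    // Stand-in for the DAGScheduler event loop reporting task completions.
    new Thread("fake-dag-scheduler") {
      override def run() {
        waiter.taskSucceeded(0, "partition-0 result")
        waiter.taskSucceeded(1, "partition-1 result")
      }
    }.start()

    // Blocks until both tasks have been reported, then inspects the outcome.
    waiter.getResult() match {
      case JobSucceeded(results) => println("job succeeded with " + results)
      case JobFailed(exception) => throw exception
    }
  }
}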
View file

@@ -1,14 +1,15 @@
package spark
package spark.scheduler
import spark._
class ResultTask[T, U](
runId: Int,
stageId: Int,
rdd: RDD[T],
stageId: Int,
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
val partition: Int,
locs: Seq[String],
val partition: Int,
@transient locs: Seq[String],
val outputId: Int)
extends DAGTask[U](runId, stageId) {
extends Task[U](stageId) {
val split = rdd.splits(partition)

View file

@@ -0,0 +1,135 @@
package spark.scheduler
import java.io._
import java.util.HashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.ArrayBuffer
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import com.ning.compress.lzf.LZFInputStream
import com.ning.compress.lzf.LZFOutputStream
import spark._
import spark.storage._
object ShuffleMapTask {
val serializedInfoCache = new HashMap[Int, Array[Byte]]
val deserializedInfoCache = new HashMap[Int, (RDD[_], ShuffleDependency[_,_,_])]
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
synchronized {
val old = serializedInfoCache.get(stageId)
if (old != null) {
return old
} else {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(dep)
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
return bytes
}
}
}
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
synchronized {
val old = deserializedInfoCache.get(stageId)
if (old != null) {
return old
} else {
val loader = currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]]
val tuple = (rdd, dep)
deserializedInfoCache.put(stageId, tuple)
return tuple
}
}
}
}
class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
var dep: ShuffleDependency[_,_,_],
var partition: Int,
@transient var locs: Seq[String])
extends Task[BlockManagerId](stageId)
with Externalizable
with Logging {
def this() = this(0, null, null, 0, null)
var split = if (rdd == null) {
null
} else {
rdd.splits(partition)
}
override def writeExternal(out: ObjectOutput) {
out.writeInt(stageId)
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
out.write(bytes)
out.writeInt(partition)
out.writeObject(split)
}
override def readExternal(in: ObjectInput) {
val stageId = in.readInt()
val numBytes = in.readInt()
val bytes = new Array[Byte](numBytes)
in.readFully(bytes)
val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
rdd = rdd_
dep = dep_
partition = in.readInt()
split = in.readObject().asInstanceOf[Split]
}
override def run(attemptId: Int): BlockManagerId = {
val numOutputSplits = dep.partitioner.numPartitions
val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
val partitioner = dep.partitioner.asInstanceOf[Partitioner]
val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
for (elem <- rdd.iterator(split)) {
val (k, v) = elem.asInstanceOf[(Any, Any)]
var bucketId = partitioner.getPartition(k)
val bucket = buckets(bucketId)
var existing = bucket.get(k)
if (existing == null) {
bucket.put(k, aggregator.createCombiner(v))
} else {
bucket.put(k, aggregator.mergeValue(existing, v))
}
}
val ser = SparkEnv.get.serializer.newInstance()
val blockManager = SparkEnv.get.blockManager
for (i <- 0 until numOutputSplits) {
val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
val arr = new ArrayBuffer[Any]
val iter = buckets(i).entrySet().iterator()
while (iter.hasNext()) {
val entry = iter.next()
arr += ((entry.getKey(), entry.getValue()))
}
// TODO: This should probably be DISK_ONLY
blockManager.put(blockId, arr.iterator, StorageLevel.MEMORY_ONLY, false)
}
return SparkEnv.get.blockManager.blockManagerId
}
override def preferredLocations: Seq[String] = locs
override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
}

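Each map task therefore writes one block per reduce partition into the block manager, keyed by shuffle id, map partition and reduce partition: for example, map partition 3 of shuffle 0 stores the bucket destined for reducer 7 under the id "shuffleid_0_3_7".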
View file

@@ -0,0 +1,86 @@
package spark.scheduler
import java.net.URI
import spark._
import spark.storage.BlockManagerId
/**
* A stage is a set of independent tasks all computing the same function that need to run as part
* of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run
* by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the
* DAGScheduler runs these stages in topological order.
*
* Each Stage can either be a shuffle map stage, in which case its tasks' results are input for
* another stage, or a result stage, in which case its tasks directly compute the action that
* initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes
* that each output partition is on.
*
* Each Stage also has a priority, which is (by default) based on the job it was submitted in.
* This allows Stages from earlier jobs to be computed first or recovered faster on failure.
*/
class Stage(
val id: Int,
val rdd: RDD[_],
val shuffleDep: Option[ShuffleDependency[_,_,_]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val priority: Int)
extends Logging {
val isShuffleMap = shuffleDep != None
val numPartitions = rdd.splits.size
val outputLocs = Array.fill[List[BlockManagerId]](numPartitions)(Nil)
var numAvailableOutputs = 0
private var nextAttemptId = 0
def isAvailable: Boolean = {
if (/*parents.size == 0 &&*/ !isShuffleMap) {
true
} else {
numAvailableOutputs == numPartitions
}
}
def addOutputLoc(partition: Int, bmAddress: BlockManagerId) {
val prevList = outputLocs(partition)
outputLocs(partition) = bmAddress :: prevList
if (prevList == Nil)
numAvailableOutputs += 1
}
def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
val prevList = outputLocs(partition)
val newList = prevList.filterNot(_ == bmAddress)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
numAvailableOutputs -= 1
}
}
def removeOutputsOnHost(host: String) {
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
val prevList = outputLocs(partition)
val newList = prevList.filterNot(_.ip == host)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
becameUnavailable = true
numAvailableOutputs -= 1
}
}
if (becameUnavailable) {
logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable))
}
}
def newAttemptId(): Int = {
val id = nextAttemptId
nextAttemptId += 1
return id
}
override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]"
override def hashCode(): Int = id
}

View file

@@ -0,0 +1,11 @@
package spark.scheduler
/**
* A task to execute on a worker node.
*/
abstract class Task[T](val stageId: Int) extends Serializable {
def run(attemptId: Int): T
def preferredLocations: Seq[String] = Nil
var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
}

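A minimal, invented Task subclass to make the contract concrete: run() does the work for one attempt of one partition and returns a value of the task's result type, while preferredLocations gives the scheduler a placement hint.

import spark.scheduler.Task

// Illustration only; not part of the commit.
class EchoTask(stageId: Int, message: String) extends Task[String](stageId) {
  override def run(attemptId: Int): String = message + " (attempt " + attemptId + ")"
  override def preferredLocations: Seq[String] = Seq("localhost")
}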
View file

@@ -0,0 +1,34 @@
package spark.scheduler
import java.io._
import scala.collection.mutable.Map
// Task result. Also contains updates to accumulator variables.
// TODO: Use of distributed cache to return result is a hack to get around
// what seems to be a bug with messages over 60KB in libprocess; fix it
class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable {
def this() = this(null.asInstanceOf[T], null)
override def writeExternal(out: ObjectOutput) {
out.writeObject(value)
out.writeInt(accumUpdates.size)
for ((key, value) <- accumUpdates) {
out.writeLong(key)
out.writeObject(value)
}
}
override def readExternal(in: ObjectInput) {
value = in.readObject().asInstanceOf[T]
val numUpdates = in.readInt
if (numUpdates == 0) {
accumUpdates = null
} else {
accumUpdates = Map()
for (i <- 0 until numUpdates) {
accumUpdates(in.readLong()) = in.readObject()
}
}
}
}

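An illustrative round trip (not in the commit; names are invented) showing the Externalizable format above in action: the value is written first, then the number of accumulator updates, then each (id, value) pair, and readExternal rebuilds the same object on the other side.

package spark.scheduler

import java.io._
import scala.collection.mutable.Map

object TaskResultSketch {
  def main(args: Array[String]) {
    val original = new TaskResult("forty-two", Map[Long, Any](1L -> 42, 2L -> "x"))

    // Serialize through writeExternal()...
    val bytes = new ByteArrayOutputStream
    val out = new ObjectOutputStream(bytes)
    out.writeObject(original)
    out.close()

    // ...and rebuild through readExternal().
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val copy = in.readObject().asInstanceOf[TaskResult[String]]
    println(copy.value + " with accumulator updates " + copy.accumUpdates)
  }
}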
View file

@@ -0,0 +1,27 @@
package spark.scheduler
/**
* Low-level task scheduler interface, implemented by both MesosScheduler and LocalScheduler.
* These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
* and are responsible for sending the tasks to the cluster, running them, retrying if there
* are failures, and mitigating stragglers. They return events to the DAGScheduler through
* the TaskSchedulerListener interface.
*/
trait TaskScheduler {
def start(): Unit
// Wait for registration with Mesos.
def waitForRegister(): Unit
// Disconnect from the cluster.
def stop(): Unit
// Submit a sequence of tasks to run.
def submitTasks(taskSet: TaskSet): Unit
// Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
def setListener(listener: TaskSchedulerListener): Unit
// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int
}

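To make the interface concrete, here is a deliberately trivial, illustration-only TaskScheduler that runs every task inline on the calling thread and reports each completion straight back through the listener; the real implementations (LocalScheduler below and the Mesos schedulers) add thread pools, retries and cluster communication on top of the same interface. The class name is invented; spark.Success is the TaskEndReason used elsewhere in this commit for successful tasks.

package spark.scheduler

import scala.collection.mutable.Map

import spark.Success

// Illustration only; not part of the commit.
class InlineTaskScheduler extends TaskScheduler {
  private var listener: TaskSchedulerListener = null

  def start() {}
  def waitForRegister() {}
  def stop() {}
  def defaultParallelism(): Int = 1
  def setListener(listener: TaskSchedulerListener) { this.listener = listener }

  def submitTasks(taskSet: TaskSet) {
    for ((task, index) <- taskSet.tasks.zipWithIndex) {
      val result = task.run(index)                                // run inline; attempt id = index
      listener.taskEnded(task, Success, result, Map[Long, Any]()) // report success, no accumulator updates
    }
  }
}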
View file

@@ -0,0 +1,16 @@
package spark.scheduler
import scala.collection.mutable.Map
import spark.TaskEndReason
/**
* Interface for getting events back from the TaskScheduler.
*/
trait TaskSchedulerListener {
// A task has finished or failed.
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
// A node was lost from the cluster.
def hostLost(host: String): Unit
}

View file

@@ -0,0 +1,9 @@
package spark.scheduler
/**
* A set of tasks submitted together to the low-level TaskScheduler, usually representing
* missing partitions of a particular stage.
*/
class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) {
val id: String = stageId + "." + attempt
}

View file

@@ -1,16 +1,21 @@
package spark
package spark.scheduler.local
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
import spark._
import spark.scheduler._
/**
* A simple Scheduler implementation that runs tasks locally in a thread pool. Optionally the
* scheduler also allows each task to fail up to maxFailures times, which is useful for testing
* fault recovery.
* A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGScheduler with Logging {
class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging {
var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
// TODO: Need to take into account stage priority in scheduling
@@ -18,7 +23,12 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
override def waitForRegister() {}
override def submitTasks(tasks: Seq[Task[_]], runId: Int) {
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
@@ -38,23 +48,14 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
Accumulators.clear
val ser = SparkEnv.get.closureSerializer.newInstance()
val startTime = System.currentTimeMillis
val bytes = ser.serialize(task)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Size of task %d is %d bytes and took %d ms to serialize".format(
idInJob, bytes.size, timeTaken))
val deserializedTask = ser.deserialize[Task[_]](bytes, currentThread.getContextClassLoader)
val bytes = Utils.serialize(task)
logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
val deserializedTask = Utils.deserialize[Task[_]](
bytes, Thread.currentThread.getContextClassLoader)
val result: Any = deserializedTask.run(attemptId)
// Serialize and deserialize the result to emulate what the mesos
// executor does. This is useful to catch serialization errors early
// on in development (so when users move their local Spark programs
// to the cluster, they don't get surprised by serialization errors).
val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = Accumulators.values
logInfo("Finished task " + idInJob)
taskEnded(task, Success, resultToReturn, accumUpdates)
listener.taskEnded(task, Success, result, accumUpdates)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@ -64,7 +65,7 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
taskEnded(task, new ExceptionFailure(t), null, null)
listener.taskEnded(task, new ExceptionFailure(t), null, null)
}
}
}
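
The round trip described in the comment above can be summarized by the following sketch (illustrative, using the Utils helpers and the AtomicInteger attemptId from this file): the task is pushed through bytes even in local mode so that non-serializable closures fail during development rather than on a real cluster.

    // Sketch only: emulate shipping the task over the wire, as the Mesos executor would.
    val bytes = Utils.serialize(task)
    val shippedTask = Utils.deserialize[Task[_]](bytes, Thread.currentThread.getContextClassLoader)
    val result: Any = shippedTask.run(attemptId.getAndIncrement())  // run under a fresh attempt id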

View file

@ -0,0 +1,364 @@
package spark.scheduler.mesos
import java.io.{File, FileInputStream, FileOutputStream}
import java.util.{ArrayList => JArrayList}
import java.util.{List => JList}
import java.util.{HashMap => JHashMap}
import java.util.concurrent._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.collection.mutable.Map
import scala.collection.mutable.PriorityQueue
import scala.collection.JavaConversions._
import scala.math.Ordering
import akka.actor._
import akka.actor.Actor
import akka.actor.Actor._
import akka.actor.Channel
import akka.serialization.RemoteActorSerialization._
import com.google.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
import spark._
import spark.scheduler._
sealed trait CoarseMesosSchedulerMessage
case class RegisterSlave(slaveId: String, host: String, port: Int) extends CoarseMesosSchedulerMessage
case class StatusUpdate(slaveId: String, status: TaskStatus) extends CoarseMesosSchedulerMessage
case class LaunchTask(slaveId: String, task: MTaskInfo) extends CoarseMesosSchedulerMessage
case class ReviveOffers() extends CoarseMesosSchedulerMessage
case class FakeOffer(slaveId: String, host: String, cores: Int)
/**
* Mesos scheduler that uses coarse-grained tasks and does its own fine-grained scheduling inside
* them using Akka actors for messaging. Clients should first call start(), then submit task sets
* through the runTasks method.
*
* TODO: This is a pretty big hack for now.
*/
class CoarseMesosScheduler(
sc: SparkContext,
master: String,
frameworkName: String)
extends MesosScheduler(sc, master, frameworkName) {
val CORES_PER_SLAVE = System.getProperty("spark.coarseMesosScheduler.coresPerSlave", "4").toInt
class MasterActor extends Actor {
val slaveActor = new HashMap[String, ActorRef]
val slaveHost = new HashMap[String, String]
val freeCores = new HashMap[String, Int]
def receive = {
case RegisterSlave(slaveId, host, port) =>
slaveActor(slaveId) = remote.actorFor("WorkerActor", host, port)
logInfo("Slave actor: " + slaveActor(slaveId))
slaveHost(slaveId) = host
freeCores(slaveId) = CORES_PER_SLAVE
makeFakeOffers()
case StatusUpdate(slaveId, status) =>
fakeStatusUpdate(status)
if (isFinished(status.getState)) {
freeCores(slaveId) += 1
makeFakeOffers(slaveId)
}
case LaunchTask(slaveId, task) =>
freeCores(slaveId) -= 1
slaveActor(slaveId) ! LaunchTask(slaveId, task)
case ReviveOffers() =>
logInfo("Reviving offers")
makeFakeOffers()
}
// Make fake resource offers for all slaves
def makeFakeOffers() {
fakeResourceOffers(slaveHost.toSeq.map{case (id, host) => FakeOffer(id, host, freeCores(id))})
}
// Make a fake resource offer for a single slave
def makeFakeOffers(slaveId: String) {
fakeResourceOffers(Seq(FakeOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))
}
}
val masterActor: ActorRef = actorOf(new MasterActor)
remote.register("MasterActor", masterActor)
masterActor.start()
val taskIdsOnSlave = new HashMap[String, HashSet[String]]
/**
* Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
* for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
* tasks are balanced across the cluster.
*/
override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
synchronized {
val tasks = offers.map(o => new JArrayList[MTaskInfo])
for (i <- 0 until offers.size) {
val o = offers.get(i)
val slaveId = o.getSlaveId.getValue
if (!slaveIdToHost.contains(slaveId)) {
slaveIdToHost(slaveId) = o.getHostname
hostsAlive += o.getHostname
taskIdsOnSlave(slaveId) = new HashSet[String]
// Launch an infinite task on the node that will talk to the MasterActor to get fake tasks
val cpuRes = Resource.newBuilder()
.setName("cpus")
.setType(Value.Type.SCALAR)
.setScalar(Value.Scalar.newBuilder().setValue(1).build())
.build()
val task = new WorkerTask(slaveId, o.getHostname)
val serializedTask = Utils.serialize(task)
tasks(i).add(MTaskInfo.newBuilder()
.setTaskId(newTaskId())
.setSlaveId(o.getSlaveId)
.setExecutor(executorInfo)
.setName("worker task")
.addResources(cpuRes)
.setData(ByteString.copyFrom(serializedTask))
.build())
}
}
val filters = Filters.newBuilder().setRefuseSeconds(10).build()
for (i <- 0 until offers.size) {
d.launchTasks(offers(i).getId(), tasks(i), filters)
}
}
}
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
val tid = status.getTaskId.getValue
var taskSetToUpdate: Option[TaskSetManager] = None
var taskFailed = false
synchronized {
try {
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
if (activeTaskSets.contains(taskSetId)) {
//activeTaskSets(taskSetId).statusUpdate(status)
taskSetToUpdate = Some(activeTaskSets(taskSetId))
}
if (isFinished(status.getState)) {
taskIdToTaskSetId.remove(tid)
if (taskSetTaskIds.contains(taskSetId)) {
taskSetTaskIds(taskSetId) -= tid
}
val slaveId = taskIdToSlaveId(tid)
taskIdToSlaveId -= tid
taskIdsOnSlave(slaveId) -= tid
}
if (status.getState == TaskState.TASK_FAILED) {
taskFailed = true
}
case None =>
logInfo("Ignoring update from TID " + tid + " because its task set is gone")
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
}
}
// Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
if (taskSetToUpdate != None) {
taskSetToUpdate.get.statusUpdate(status)
}
if (taskFailed) {
// Revive offers if a task had failed for some reason other than host lost
reviveOffers()
}
}
override def slaveLost(d: SchedulerDriver, s: SlaveID) {
logInfo("Slave lost: " + s.getValue)
var failedHost: Option[String] = None
var lostTids: Option[HashSet[String]] = None
synchronized {
val slaveId = s.getValue
val host = slaveIdToHost(slaveId)
if (hostsAlive.contains(host)) {
slaveIdsWithExecutors -= slaveId
hostsAlive -= host
failedHost = Some(host)
lostTids = Some(taskIdsOnSlave(slaveId))
logInfo("failedHost: " + host)
logInfo("lostTids: " + lostTids)
taskIdsOnSlave -= slaveId
activeTaskSetsQueue.foreach(_.hostLost(host))
}
}
if (failedHost != None) {
// Report all the tasks on the failed host as lost, without holding a lock on this
for (tid <- lostTids.get; taskSetId <- taskIdToTaskSetId.get(tid)) {
// TODO: Maybe call our statusUpdate() instead to clean our internal data structures
activeTaskSets(taskSetId).statusUpdate(TaskStatus.newBuilder()
.setTaskId(TaskID.newBuilder().setValue(tid).build())
.setState(TaskState.TASK_LOST)
.build())
}
// Also report the loss to the DAGScheduler
listener.hostLost(failedHost.get)
reviveOffers();
}
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
// Check for speculatable tasks in all our active jobs.
override def checkSpeculatableTasks() {
var shouldRevive = false
synchronized {
for (ts <- activeTaskSetsQueue) {
shouldRevive |= ts.checkSpeculatableTasks()
}
}
if (shouldRevive) {
reviveOffers()
}
}
val lock2 = new Object
var firstWait = true
override def waitForRegister() {
lock2.synchronized {
if (firstWait) {
super.waitForRegister()
Thread.sleep(5000)
firstWait = false
}
}
}
def fakeStatusUpdate(status: TaskStatus) {
statusUpdate(driver, status)
}
def fakeResourceOffers(offers: Seq[FakeOffer]) {
logDebug("fakeResourceOffers: " + offers)
val availableCpus = offers.map(_.cores.toDouble).toArray
var launchedTask = false
for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
do {
launchedTask = false
for (i <- 0 until offers.size if hostsAlive.contains(offers(i).host)) {
manager.slaveOffer(offers(i).slaveId, offers(i).host, availableCpus(i)) match {
case Some(task) =>
val tid = task.getTaskId.getValue
val sid = offers(i).slaveId
taskIdToTaskSetId(tid) = manager.taskSet.id
taskSetTaskIds(manager.taskSet.id) += tid
taskIdToSlaveId(tid) = sid
taskIdsOnSlave(sid) += tid
slaveIdsWithExecutors += sid
availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
launchedTask = true
masterActor ! LaunchTask(sid, task)
case None => {}
}
}
} while (launchedTask)
}
}
override def reviveOffers() {
masterActor ! ReviveOffers()
}
}
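
As a rough usage sketch (not shown in this commit), a client such as SparkContext would wire the coarse-grained scheduler up as follows; the master address and the dagScheduler/taskSet variables are placeholders:

    val sched = new CoarseMesosScheduler(sc, "localhost:5050", "MyApp")
    sched.setListener(dagScheduler)  // any TaskSchedulerListener, normally the DAGScheduler
    sched.start()
    sched.waitForRegister()          // blocks until the framework is registered with Mesos
    sched.submitTasks(taskSet)       // a spark.scheduler.TaskSet, as defined earlier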
class WorkerTask(slaveId: String, host: String) extends Task[Unit](-1) {
generation = 0
def run(id: Int): Unit = {
val actor = actorOf(new WorkerActor(slaveId, host))
if (!remote.isRunning) {
remote.start(Utils.localIpAddress, 7078)
}
remote.register("WorkerActor", actor)
actor.start()
while (true) {
Thread.sleep(10000)
}
}
}
class WorkerActor(slaveId: String, host: String) extends Actor with Logging {
val env = SparkEnv.get
val classLoader = currentThread.getContextClassLoader
val threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
val masterIp: String = System.getProperty("spark.master.host", "localhost")
val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
val masterActor = remote.actorFor("MasterActor", masterIp, masterPort)
class TaskRunner(desc: MTaskInfo)
extends Runnable {
override def run() = {
val tid = desc.getTaskId.getValue
logInfo("Running task ID " + tid)
try {
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear
val task = Utils.deserialize[Task[Any]](desc.getData.toByteArray, classLoader)
env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(tid.toInt)
val accumUpdates = Accumulators.values
val result = new TaskResult(value, accumUpdates)
masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
.setTaskId(desc.getTaskId)
.setState(TaskState.TASK_FINISHED)
.setData(ByteString.copyFrom(Utils.serialize(result)))
.build())
logInfo("Finished task ID " + tid)
} catch {
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
.setTaskId(desc.getTaskId)
.setState(TaskState.TASK_FAILED)
.setData(ByteString.copyFrom(Utils.serialize(reason)))
.build())
}
case t: Throwable => {
val reason = ExceptionFailure(t)
masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
.setTaskId(desc.getTaskId)
.setState(TaskState.TASK_FAILED)
.setData(ByteString.copyFrom(Utils.serialize(reason)))
.build())
// TODO: Should we exit the whole executor here? On the one hand, the failed task may
// have left some weird state around depending on when the exception was thrown, but on
// the other hand, maybe we could detect that when future tasks fail and exit then.
logError("Exception in task ID " + tid, t)
//System.exit(1)
}
}
}
}
override def preStart {
val ref = toRemoteActorRefProtocol(self).toByteArray
logInfo("Registering with master")
masterActor ! RegisterSlave(slaveId, host, remote.address.getPort)
}
override def receive = {
case LaunchTask(slaveId, task) =>
threadPool.execute(new TaskRunner(task))
}
}

View file

@ -1,4 +1,4 @@
package spark
package spark.scheduler.mesos
import java.io.{File, FileInputStream, FileOutputStream}
import java.util.{ArrayList => JArrayList}
@ -17,20 +17,23 @@ import com.google.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
import org.apache.mesos._
import org.apache.mesos.Protos._
import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
import spark._
import spark.scheduler._
/**
* The main Scheduler implementation, which runs jobs on Mesos. Clients should first call start(),
* then submit tasks through the runTasks method.
* The main TaskScheduler implementation, which runs tasks on Mesos. Clients should first call
* start(), then submit task sets through the runTasks method.
*/
private class MesosScheduler(
class MesosScheduler(
sc: SparkContext,
master: String,
frameworkName: String)
extends MScheduler
with DAGScheduler
extends TaskScheduler
with MScheduler
with Logging {
// Environment variables to pass to our executors
val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
"SPARK_MEM",
@ -49,55 +52,60 @@ private class MesosScheduler(
}
}
// How often to check for speculative tasks
val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
// Lock used to wait for scheduler to be registered
private var isRegistered = false
private val registeredLock = new Object()
var isRegistered = false
val registeredLock = new Object()
private val activeJobs = new HashMap[Int, Job]
private var activeJobsQueue = new ArrayBuffer[Job]
val activeTaskSets = new HashMap[String, TaskSetManager]
var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
private val taskIdToJobId = new HashMap[String, Int]
private val taskIdToSlaveId = new HashMap[String, String]
private val jobTasks = new HashMap[Int, HashSet[String]]
val taskIdToTaskSetId = new HashMap[String, String]
val taskIdToSlaveId = new HashMap[String, String]
val taskSetTaskIds = new HashMap[String, HashSet[String]]
// Incrementing job and task IDs
private var nextJobId = 0
private var nextTaskId = 0
// Incrementing Mesos task IDs
var nextTaskId = 0
// Driver for talking to Mesos
var driver: SchedulerDriver = null
// Which nodes we have executors on
private val slavesWithExecutors = new HashSet[String]
// Which hosts in the cluster are alive (contains hostnames)
val hostsAlive = new HashSet[String]
// Which slave IDs we have executors on
val slaveIdsWithExecutors = new HashSet[String]
val slaveIdToHost = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
// URIs of JARs to pass to executor
var jarUris: String = ""
// Create an ExecutorInfo for our tasks
val executorInfo = createExecutorInfo()
// Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first)
private val jobOrdering = new Ordering[Job] {
override def compare(j1: Job, j2: Job): Int = j2.runId - j1.runId
}
def newJobId(): Int = this.synchronized {
val id = nextJobId
nextJobId += 1
return id
// Listener object to pass upcalls into
var listener: TaskSchedulerListener = null
val mapOutputTracker = SparkEnv.get.mapOutputTracker
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
def newTaskId(): TaskID = {
val id = "" + nextTaskId;
nextTaskId += 1;
return TaskID.newBuilder().setValue(id).build()
val id = TaskID.newBuilder().setValue("" + nextTaskId).build()
nextTaskId += 1
return id
}
override def start() {
new Thread("Spark scheduler") {
new Thread("MesosScheduler driver") {
setDaemon(true)
override def run {
val sched = MesosScheduler.this
@ -110,12 +118,27 @@ private class MesosScheduler(
case e: Exception => logError("driver.run() failed", e)
}
}
}.start
}.start()
if (System.getProperty("spark.speculation", "false") == "true") {
new Thread("MesosScheduler speculation check") {
setDaemon(true)
override def run {
waitForRegister()
while (true) {
try {
Thread.sleep(SPECULATION_INTERVAL)
} catch { case e: InterruptedException => {} }
checkSpeculatableTasks()
}
}
}.start()
}
}
def createExecutorInfo(): ExecutorInfo = {
val sparkHome = sc.getSparkHome match {
case Some(path) => path
case Some(path) =>
path
case None =>
throw new SparkException("Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor")
@ -151,27 +174,26 @@ private class MesosScheduler(
.build()
}
def submitTasks(tasks: Seq[Task[_]], runId: Int) {
logInfo("Got a job with " + tasks.size + " tasks")
def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.size + " tasks")
waitForRegister()
this.synchronized {
val jobId = newJobId()
val myJob = new SimpleJob(this, tasks, runId, jobId)
activeJobs(jobId) = myJob
activeJobsQueue += myJob
logInfo("Adding job with ID " + jobId)
jobTasks(jobId) = HashSet.empty[String]
val manager = new TaskSetManager(this, taskSet)
activeTaskSets(taskSet.id) = manager
activeTaskSetsQueue += manager
taskSetTaskIds(taskSet.id) = new HashSet()
}
driver.reviveOffers();
reviveOffers();
}
def jobFinished(job: Job) {
def taskSetFinished(manager: TaskSetManager) {
this.synchronized {
activeJobs -= job.jobId
activeJobsQueue -= job
taskIdToJobId --= jobTasks(job.jobId)
taskIdToSlaveId --= jobTasks(job.jobId)
jobTasks.remove(job.jobId)
activeTaskSets -= manager.taskSet.id
activeTaskSetsQueue -= manager
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds.remove(manager.taskSet.id)
}
}
@ -196,33 +218,40 @@ private class MesosScheduler(
override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
/**
* Method called by Mesos to offer resources on slaves. We respond by asking our active jobs for
* tasks in FIFO order. We fill each node with tasks in a round-robin manner so that tasks are
* balanced across the cluster.
* Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
* for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
* tasks are balanced across the cluster.
*/
override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
synchronized {
val tasks = offers.map(o => new JArrayList[TaskInfo])
// Mark each slave as alive and remember its hostname
for (o <- offers) {
slaveIdToHost(o.getSlaveId.getValue) = o.getHostname
hostsAlive += o.getHostname
}
// Build a list of tasks to assign to each slave
val tasks = offers.map(o => new JArrayList[MTaskInfo])
val availableCpus = offers.map(o => getResource(o.getResourcesList(), "cpus"))
val enoughMem = offers.map(o => {
val mem = getResource(o.getResourcesList(), "mem")
val slaveId = o.getSlaveId.getValue
mem >= EXECUTOR_MEMORY || slavesWithExecutors.contains(slaveId)
mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
})
var launchedTask = false
for (job <- activeJobsQueue.sorted(jobOrdering)) {
for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
do {
launchedTask = false
for (i <- 0 until offers.size if enoughMem(i)) {
job.slaveOffer(offers(i), availableCpus(i)) match {
val sid = offers(i).getSlaveId.getValue
val host = offers(i).getHostname
manager.slaveOffer(sid, host, availableCpus(i)) match {
case Some(task) =>
tasks(i).add(task)
val tid = task.getTaskId.getValue
val sid = offers(i).getSlaveId.getValue
taskIdToJobId(tid) = job.jobId
jobTasks(job.jobId) += tid
taskIdToTaskSetId(tid) = manager.taskSet.id
taskSetTaskIds(manager.taskSet.id) += tid
taskIdToSlaveId(tid) = sid
slavesWithExecutors += sid
slaveIdsWithExecutors += sid
availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
launchedTask = true
@ -256,53 +285,74 @@ private class MesosScheduler(
}
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
var jobToUpdate: Option[Job] = None
val tid = status.getTaskId.getValue
var taskSetToUpdate: Option[TaskSetManager] = None
var failedHost: Option[String] = None
var taskFailed = false
synchronized {
try {
val tid = status.getTaskId.getValue
if (status.getState == TaskState.TASK_LOST
&& taskIdToSlaveId.contains(tid)) {
if (status.getState == TaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) {
// We lost the executor on this slave, so remember that it's gone
slavesWithExecutors -= taskIdToSlaveId(tid)
val slaveId = taskIdToSlaveId(tid)
val host = slaveIdToHost(slaveId)
if (hostsAlive.contains(host)) {
slaveIdsWithExecutors -= slaveId
hostsAlive -= host
activeTaskSetsQueue.foreach(_.hostLost(host))
failedHost = Some(host)
}
}
taskIdToJobId.get(tid) match {
case Some(jobId) =>
if (activeJobs.contains(jobId)) {
jobToUpdate = Some(activeJobs(jobId))
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
if (activeTaskSets.contains(taskSetId)) {
//activeTaskSets(taskSetId).statusUpdate(status)
taskSetToUpdate = Some(activeTaskSets(taskSetId))
}
if (isFinished(status.getState)) {
taskIdToJobId.remove(tid)
if (jobTasks.contains(jobId)) {
jobTasks(jobId) -= tid
taskIdToTaskSetId.remove(tid)
if (taskSetTaskIds.contains(taskSetId)) {
taskSetTaskIds(taskSetId) -= tid
}
taskIdToSlaveId.remove(tid)
}
if (status.getState == TaskState.TASK_FAILED) {
taskFailed = true
}
case None =>
logInfo("Ignoring update from TID " + tid + " because its job is gone")
logInfo("Ignoring update from TID " + tid + " because its task set is gone")
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
}
}
for (j <- jobToUpdate) {
j.statusUpdate(status)
// Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
if (taskSetToUpdate != None) {
taskSetToUpdate.get.statusUpdate(status)
}
if (failedHost != None) {
listener.hostLost(failedHost.get)
reviveOffers();
}
if (taskFailed) {
// Also revive offers if a task had failed for some reason other than host lost
reviveOffers()
}
}
override def error(d: SchedulerDriver, message: String) {
logError("Mesos error: " + message)
synchronized {
if (activeJobs.size > 0) {
// Have each job throw a SparkException with the error
for ((jobId, activeJob) <- activeJobs) {
if (activeTaskSets.size > 0) {
// Have each task set throw a SparkException with the error
for ((taskSetId, manager) <- activeTaskSets) {
try {
activeJob.error(message)
manager.error(message)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
}
} else {
// No jobs are active but we still got an error. Just exit since this
// No task sets are active but we still got an error. Just exit since this
// must mean the error is during registration.
// It might be good to do something smarter here in the future.
System.exit(1)
@ -373,41 +423,68 @@ private class MesosScheduler(
return Utils.serialize(props.toArray)
}
override def frameworkMessage(
d: SchedulerDriver,
e: ExecutorID,
s: SlaveID,
b: Array[Byte]) {}
override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
override def slaveLost(d: SchedulerDriver, s: SlaveID) {
slavesWithExecutors.remove(s.getValue)
var failedHost: Option[String] = None
synchronized {
val slaveId = s.getValue
val host = slaveIdToHost(slaveId)
if (hostsAlive.contains(host)) {
slaveIdsWithExecutors -= slaveId
hostsAlive -= host
activeTaskSetsQueue.foreach(_.hostLost(host))
failedHost = Some(host)
}
}
if (failedHost != None) {
listener.hostLost(failedHost.get)
reviveOffers();
}
}
override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) {
slavesWithExecutors.remove(s.getValue)
logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue))
slaveLost(d, s)
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
synchronized {
for (ts <- activeTaskSetsQueue) {
shouldRevive |= ts.checkSpeculatableTasks()
}
}
if (shouldRevive) {
reviveOffers()
}
}
def reviveOffers() {
driver.reviveOffers()
}
}
object MesosScheduler {
/**
* Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
* This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
* Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
* This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
* environment variable.
*/
def memoryStringToMb(str: String): Int = {
val lower = str.toLowerCase
if (lower.endsWith("k")) {
(lower.substring(0, lower.length - 1).toLong / 1024).toInt
(lower.substring(0, lower.length-1).toLong / 1024).toInt
} else if (lower.endsWith("m")) {
lower.substring(0, lower.length - 1).toInt
lower.substring(0, lower.length-1).toInt
} else if (lower.endsWith("g")) {
lower.substring(0, lower.length - 1).toInt * 1024
lower.substring(0, lower.length-1).toInt * 1024
} else if (lower.endsWith("t")) {
lower.substring(0, lower.length - 1).toInt * 1024 * 1024
} else {
// no suffix, so it's just a number in bytes
lower.substring(0, lower.length-1).toInt * 1024 * 1024
} else { // no suffix, so it's just a number in bytes
(lower.toLong / 1024 / 1024).toInt
}
}
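
Worked examples of the conversion above (values computed from the code as written):

    MesosScheduler.memoryStringToMb("512m")      // 512
    MesosScheduler.memoryStringToMb("2g")        // 2 * 1024 = 2048
    MesosScheduler.memoryStringToMb("1048576k")  // 1048576 / 1024 = 1024
    MesosScheduler.memoryStringToMb("300000000") // no suffix, treated as bytes: 286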

View file

@ -0,0 +1,32 @@
package spark.scheduler.mesos
/**
* Information about a running task attempt.
*/
class TaskInfo(val taskId: String, val index: Int, val launchTime: Long, val host: String) {
var finishTime: Long = 0
var failed = false
def markSuccessful(time: Long = System.currentTimeMillis) {
finishTime = time
}
def markFailed(time: Long = System.currentTimeMillis) {
finishTime = time
failed = true
}
def finished: Boolean = finishTime != 0
def successful: Boolean = finished && !failed
def duration: Long = {
if (!finished) {
throw new UnsupportedOperationException("duration() called on unfinished tasks")
} else {
finishTime - launchTime
}
}
def timeRunning(currentTime: Long): Long = currentTime - launchTime
}
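
A typical lifecycle, with hypothetical values: an attempt is created when the task launches, marked successful or failed when its terminal status arrives, and duration is only defined once it has finished.

    val info = new TaskInfo("42", 5, System.currentTimeMillis, "node1")  // TID, index, launch time, host
    // ... later, when a TASK_FINISHED status update arrives:
    info.markSuccessful()
    info.successful   // true
    info.duration     // finishTime - launchTime, in ms; throws if the task is still running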

View file

@ -1,28 +1,32 @@
package spark
package spark.scheduler.mesos
import java.util.Arrays
import java.util.{HashMap => JHashMap}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
import com.google.protobuf.ByteString
import org.apache.mesos._
import org.apache.mesos.Protos._
import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
import spark._
import spark.scheduler._
/**
* A Job that runs a set of tasks with no interdependencies.
* Schedules the tasks within a single TaskSet in the MesosScheduler.
*/
class SimpleJob(
class TaskSetManager(
sched: MesosScheduler,
tasksSeq: Seq[Task[_]],
runId: Int,
jobId: Int)
extends Job(runId, jobId)
with Logging {
val taskSet: TaskSet)
extends Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "5000").toLong
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
// CPUs to request per task
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
@ -30,18 +34,20 @@ class SimpleJob(
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = 4
// Quantile of tasks at which to start speculation
val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
// Serializer for closures and tasks.
val ser = SparkEnv.get.closureSerializer.newInstance()
val callingThread = Thread.currentThread
val tasks = tasksSeq.toArray
val priority = taskSet.priority
val tasks = taskSet.tasks
val numTasks = tasks.length
val launched = new Array[Boolean](numTasks)
val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val tidToIndex = HashMap[String, Int]()
var tasksLaunched = 0
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
var tasksFinished = 0
// Last time when we launched a preferred task (for delay scheduling)
@ -62,6 +68,13 @@ class SimpleJob(
// List containing all pending tasks (also used as a stack, as above)
val allPendingTasks = new ArrayBuffer[Int]
// Tasks that can be speculated. Since these will be a small fraction of total
// tasks, we'll just hold them in a HashSet.
val speculatableTasks = new HashSet[Int]
// Task index, start and finish time for each task attempt (indexed by task ID)
val taskInfos = new HashMap[String, TaskInfo]
// Did the job fail?
var failed = false
var causeOfFailure = ""
@ -76,6 +89,12 @@ class SimpleJob(
// exceptions automatically.
val recentExceptions = HashMap[String, (Int, Long)]()
// Figure out the current map output tracker generation and set it on all tasks
val generation = sched.mapOutputTracker.getGeneration
for (t <- tasks) {
t.generation = generation
}
// Add all our tasks to the pending lists. We do this in reverse order
// of task index so that tasks with low indices get launched first.
for (i <- (0 until numTasks).reverse) {
@ -84,7 +103,7 @@ class SimpleJob(
// Add a task to all the pending-task lists that it should be on.
def addPendingTask(index: Int) {
val locations = tasks(index).preferredLocations
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (locations.size == 0) {
pendingTasksWithNoPrefs += index
} else {
@ -110,13 +129,37 @@ class SimpleJob(
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
if (!launched(index) && !finished(index)) {
if (copiesRunning(index) == 0 && !finished(index)) {
return Some(index)
}
}
return None
}
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
// task must have a preference for this host (or no preferred locations at all).
def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
val localTask = speculatableTasks.find { index =>
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
val attemptLocs = taskAttempts(index).map(_.host)
(locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
}
if (localTask != None) {
speculatableTasks -= localTask.get
return localTask
}
if (!localOnly && speculatableTasks.size > 0) {
val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host))
if (nonLocalTask != None) {
speculatableTasks -= nonLocalTask.get
return nonLocalTask
}
}
return None
}
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
def findTask(host: String, localOnly: Boolean): Option[Int] = {
@ -129,10 +172,13 @@ class SimpleJob(
return noPrefTask
}
if (!localOnly) {
return findTaskFromList(allPendingTasks) // Look for non-local task
} else {
return None
val nonLocalTask = findTaskFromList(allPendingTasks)
if (nonLocalTask != None) {
return nonLocalTask
}
}
// Finally, if all else has failed, find a speculative task
return findSpeculativeTask(host, localOnly)
}
// Does a host count as a preferred location for a task? This is true if
@ -144,11 +190,11 @@ class SimpleJob(
}
// Respond to an offer of a single slave from the scheduler by finding a task
def slaveOffer(offer: Offer, availableCpus: Double): Option[TaskInfo] = {
if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK) {
def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[MTaskInfo] = {
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
val time = System.currentTimeMillis
val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
val host = offer.getHostname
var localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
findTask(host, localOnly) match {
case Some(index) => {
// Found a task; do some bookkeeping and return a Mesos task for it
@ -156,17 +202,17 @@ class SimpleJob(
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val preferred = isPreferredLocation(task, host)
val prefStr = if(preferred) "preferred" else "non-preferred"
val message =
"Starting task %d:%d as TID %s on slave %s: %s (%s)".format(
jobId, index, taskId.getValue, offer.getSlaveId.getValue, host, prefStr)
logInfo(message)
val prefStr = if (preferred) "preferred" else "non-preferred"
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
taskSet.id, index, taskId.getValue, slaveId, host, prefStr))
// Do various bookkeeping
tidToIndex(taskId.getValue) = index
launched(index) = true
tasksLaunched += 1
if (preferred)
copiesRunning(index) += 1
val info = new TaskInfo(taskId.getValue, index, time, host)
taskInfos(taskId.getValue) = info
taskAttempts(index) = info :: taskAttempts(index)
if (preferred) {
lastPreferredLaunchTime = time
}
// Create and return the Mesos task object
val cpuRes = Resource.newBuilder()
.setName("cpus")
@ -178,13 +224,13 @@ class SimpleJob(
val serializedTask = ser.serialize(task)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s"
.format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName))
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %d:%d".format(jobId, index)
return Some(TaskInfo.newBuilder()
val taskName = "task %s:%d".format(taskSet.id, index)
return Some(MTaskInfo.newBuilder()
.setTaskId(taskId)
.setSlaveId(offer.getSlaveId)
.setSlaveId(SlaveID.newBuilder().setValue(slaveId))
.setExecutor(sched.executorInfo)
.setName(taskName)
.addResources(cpuRes)
@ -213,18 +259,21 @@ class SimpleJob(
def taskFinished(status: TaskStatus) {
val tid = status.getTaskId.getValue
val index = tidToIndex(tid)
val info = taskInfos(tid)
val index = info.index
info.markSuccessful()
if (!finished(index)) {
tasksFinished += 1
logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks))
// Deserialize task result
val result = ser.deserialize[TaskResult[_]](
status.getData.toByteArray, getClass.getClassLoader)
sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
tid, info.duration, tasksFinished, numTasks))
// Deserialize task result and pass it to the scheduler
val result = ser.deserialize[TaskResult[_]](status.getData.asReadOnlyByteBuffer)
sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
// Mark finished and stop if we've finished all the tasks
finished(index) = true
if (tasksFinished == numTasks)
sched.jobFinished(this)
if (tasksFinished == numTasks) {
sched.taskSetFinished(this)
}
} else {
logInfo("Ignoring task-finished event for TID " + tid +
" because task " + index + " is already finished")
@ -233,30 +282,29 @@ class SimpleJob(
def taskLost(status: TaskStatus) {
val tid = status.getTaskId.getValue
val index = tidToIndex(tid)
val info = taskInfos(tid)
val index = info.index
info.markFailed()
if (!finished(index)) {
logInfo("Lost TID %s (task %d:%d)".format(tid, jobId, index))
launched(index) = false
tasksLaunched -= 1
logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
if (status.getData != null && status.getData.size > 0) {
val reason = ser.deserialize[TaskEndReason](
status.getData.toByteArray, getClass.getClassLoader)
val reason = ser.deserialize[TaskEndReason](status.getData.asReadOnlyByteBuffer)
reason match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri)
sched.taskEnded(tasks(index), fetchFailed, null, null)
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
sched.listener.taskEnded(tasks(index), fetchFailed, null, null)
finished(index) = true
tasksFinished += 1
if (tasksFinished == numTasks) {
sched.jobFinished(this)
}
sched.taskSetFinished(this)
return
case ef: ExceptionFailure =>
val key = ef.exception.toString
val now = System.currentTimeMillis
val (printFull, dupCount) =
val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
@ -267,32 +315,28 @@ class SimpleJob(
(false, dupCount + 1)
}
} else {
recentExceptions += Tuple(key, (0, now))
recentExceptions(key) = (0, now)
(true, 0)
}
if (printFull) {
val stackTrace =
for (elem <- ef.exception.getStackTrace)
yield "\tat %s".format(elem.toString)
logInfo("Loss was due to %s\n%s".format(
ef.exception.toString, stackTrace.mkString("\n")))
} else {
logInfo("Loss was due to %s [duplicate %d]".format(
ef.exception.toString, dupCount))
}
if (printFull) {
val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString))
logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n")))
} else {
logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount))
}
case _ => {}
}
}
// On other failures, re-enqueue the task as pending for a max number of retries
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
// Count attempts only on FAILED and LOST state (not on KILLED)
if (status.getState == TaskState.TASK_FAILED ||
status.getState == TaskState.TASK_LOST) {
// Count failed attempts only on FAILED and LOST state (not on KILLED)
if (status.getState == TaskState.TASK_FAILED || status.getState == TaskState.TASK_LOST) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %d:%d failed more than %d times; aborting job".format(
jobId, index, MAX_TASK_FAILURES))
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
}
}
@ -311,6 +355,71 @@ class SimpleJob(
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.jobFinished(this)
sched.taskSetFinished(this)
}
def hostLost(hostname: String) {
logInfo("Re-queueing tasks for " + hostname)
// If some task has preferred locations only on hostname, put it in the no-prefs list
// to avoid the wait from delay scheduling
for (index <- getPendingTasksForHost(hostname)) {
val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (newLocs.isEmpty) {
pendingTasksWithNoPrefs += index
}
}
// Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.host == hostname) {
val index = taskInfos(tid).index
if (finished(index)) {
finished(index) = false
copiesRunning(index) -= 1
tasksFinished -= 1
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
sched.listener.taskEnded(tasks(index), Resubmitted, null, null)
}
}
}
}
/**
* Check for tasks to be speculated and return true if there are any. This is called periodically
* by the MesosScheduler.
*
* TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
* we don't scan the whole task set. It might also help to make this sorted by launch time.
*/
def checkSpeculatableTasks(): Boolean = {
// Can't speculate if we only have one task, or if all tasks have finished.
if (numTasks == 1 || tasksFinished == numTasks) {
return false
}
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
if (tasksFinished >= minFinishedForSpeculation) {
val time = System.currentTimeMillis()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
logDebug("Task length threshold for speculation: " + threshold)
for ((tid, info) <- taskInfos) {
val index = info.index
if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
!speculatableTasks.contains(index)) {
logInfo("Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
taskSet.id, index, info.host, threshold))
speculatableTasks += index
foundTasks = true
}
}
}
return foundTasks
}
}
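
To make the speculation threshold concrete, here is a small worked example under the default settings (quantile 0.75, multiplier 1.5); the task count and durations are hypothetical:

    import java.util.Arrays
    import scala.math.{max, min}

    val numTasks = 10
    val durations = Array[Long](2000, 2100, 2200, 2300, 2400, 2500, 2600)  // successful tasks, in ms
    Arrays.sort(durations)
    val minFinishedForSpeculation = (0.75 * numTasks).floor.toInt  // 7, so speculation is considered
    val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))  // 2500
    val threshold = max(1.5 * medianDuration, 100)  // 3750.0 ms
    // Any unfinished task with exactly one running copy that has been running longer
    // than ~3750 ms would be added to speculatableTasks on the next check.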

View file

@ -0,0 +1,507 @@
package spark.storage
import java.io._
import java.nio._
import java.nio.channels.FileChannel.MapMode
import java.util.{HashMap => JHashMap}
import java.util.LinkedHashMap
import java.util.UUID
import java.util.Collections
import scala.actors._
import scala.actors.Actor._
import scala.actors.Future
import scala.actors.Futures.future
import scala.actors.remote._
import scala.actors.remote.RemoteActor._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
import it.unimi.dsi.fastutil.io._
import spark.CacheTracker
import spark.Logging
import spark.Serializer
import spark.SizeEstimator
import spark.SparkEnv
import spark.SparkException
import spark.Utils
import spark.network._
class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
def this() = this(null, 0)
override def writeExternal(out: ObjectOutput) {
out.writeUTF(ip)
out.writeInt(port)
}
override def readExternal(in: ObjectInput) {
ip = in.readUTF()
port = in.readInt()
}
override def toString = "BlockManagerId(" + ip + ", " + port + ")"
override def hashCode = ip.hashCode * 41 + port
override def equals(that: Any) = that match {
case id: BlockManagerId => port == id.port && ip == id.ip
case _ => false
}
}
case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message)
class BlockLocker(numLockers: Int) {
private val hashLocker = Array.fill(numLockers)(new Object())
def getLock(blockId: String): Object = {
return hashLocker(Math.abs(blockId.hashCode % numLockers))
}
}
/**
* A start towards a block manager class. This will eventually be used for both RDD persistence
* and shuffle outputs.
*
* TODO: Should make the communication with Master or Peers code more robust and log friendly.
*/
class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging {
private val NUM_LOCKS = 337
private val locker = new BlockLocker(NUM_LOCKS)
private val storageLevels = Collections.synchronizedMap(new JHashMap[String, StorageLevel])
private val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
private val diskStore: BlockStore = new DiskStore(this,
System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
val connectionManager = new ConnectionManager(0)
val connectionManagerId = connectionManager.id
val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
// TODO(Haoyuan): This will be removed after cacheTracker is removed from the code base.
var cacheTracker: CacheTracker = null
initLogging()
initialize()
/**
* Construct a BlockManager with a memory limit set based on system properties.
*/
def this(serializer: Serializer) =
this(BlockManager.getMaxMemoryFromSystemProperties(), serializer)
/**
* Initialize the BlockManager. Register to the BlockManagerMaster, and start the
* BlockManagerWorker actor.
*/
def initialize() {
BlockManagerMaster.mustRegisterBlockManager(
RegisterBlockManager(blockManagerId, maxMemory, maxMemory))
BlockManagerWorker.startBlockManagerWorker(this)
}
/**
* Get locations of the block.
*/
def getLocations(blockId: String): Seq[String] = {
val startTimeMs = System.currentTimeMillis
var managers: Array[BlockManagerId] = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
val locations = managers.map((manager: BlockManagerId) => { manager.ip }).toSeq
logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
/**
* Get locations of an array of blocks
*/
def getLocationsMultipleBlockIds(blockIds: Array[String]): Array[Seq[String]] = {
val startTimeMs = System.currentTimeMillis
val locations = BlockManagerMaster.mustGetLocationsMultipleBlockIds(
GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray
logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
def getLocal(blockId: String): Option[Iterator[Any]] = {
logDebug("Getting block " + blockId)
locker.getLock(blockId).synchronized {
// Check storage level of block
val level = storageLevels.get(blockId)
if (level != null) {
logDebug("Level for block " + blockId + " is " + level + " on local machine")
// Look for the block in memory
if (level.useMemory) {
logDebug("Getting block " + blockId + " from memory")
memoryStore.getValues(blockId) match {
case Some(iterator) => {
logDebug("Block " + blockId + " found in memory")
return Some(iterator)
}
case None => {
logDebug("Block " + blockId + " not found in memory")
}
}
} else {
logDebug("Not getting block " + blockId + " from memory")
}
// Look for block in disk
if (level.useDisk) {
logDebug("Getting block " + blockId + " from disk")
diskStore.getValues(blockId) match {
case Some(iterator) => {
logDebug("Block " + blockId + " found in disk")
return Some(iterator)
}
case None => {
throw new Exception("Block " + blockId + " not found in disk")
return None
}
}
} else {
logDebug("Not getting block " + blockId + " from disk")
}
} else {
logDebug("Level for block " + blockId + " not found")
}
}
return None
}
def getRemote(blockId: String): Option[Iterator[Any]] = {
// Get locations of block
val locations = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
// Get block from remote locations
for (loc <- locations) {
val data = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
if (data != null) {
logDebug("Data is not null: " + data)
return Some(dataDeserialize(data))
}
logDebug("Data is null")
}
logDebug("Data not found")
return None
}
/**
* Read a block from the block manager.
*/
def get(blockId: String): Option[Iterator[Any]] = {
getLocal(blockId).orElse(getRemote(blockId))
}
/**
* Read many blocks from block manager using their BlockManagerIds.
*/
def get(blocksByAddress: Seq[(BlockManagerId, Seq[String])]): HashMap[String, Option[Iterator[Any]]] = {
logDebug("Getting " + blocksByAddress.map(_._2.size).sum + " blocks")
var startTime = System.currentTimeMillis
val blocks = new HashMap[String,Option[Iterator[Any]]]()
val localBlockIds = new ArrayBuffer[String]()
val remoteBlockIds = new ArrayBuffer[String]()
val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]()
// Split local and remote blocks
for ((address, blockIds) <- blocksByAddress) {
if (address == blockManagerId) {
localBlockIds ++= blockIds
} else {
remoteBlockIds ++= blockIds
remoteBlockIdsPerLocation(address) = blockIds
}
}
// Start getting remote blocks
val remoteBlockFutures = remoteBlockIdsPerLocation.toSeq.map { case (bmId, bIds) =>
val cmId = ConnectionManagerId(bmId.ip, bmId.port)
val blockMessages = bIds.map(bId => BlockMessage.fromGetBlock(GetBlock(bId)))
val blockMessageArray = new BlockMessageArray(blockMessages)
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
(cmId, future)
}
logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
// Get the local blocks while remote blocks are being fetched
startTime = System.currentTimeMillis
localBlockIds.foreach(id => {
get(id) match {
case Some(block) => {
blocks.update(id, Some(block))
logDebug("Got local block " + id)
}
case None => {
throw new BlockException(id, "Could not get block " + id + " from local machine")
}
}
})
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
// wait for and gather all the remote blocks
for ((cmId, future) <- remoteBlockFutures) {
var count = 0
val oneBlockId = remoteBlockIdsPerLocation(new BlockManagerId(cmId.host, cmId.port)).first
future() match {
case Some(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
blockMessageArray.foreach(blockMessage => {
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
throw new BlockException(oneBlockId, "Unexpected message received from " + cmId)
}
val buffer = blockMessage.getData()
val blockId = blockMessage.getId()
val block = dataDeserialize(buffer)
blocks.update(blockId, Some(block))
logDebug("Got remote block " + blockId + " in " + Utils.getUsedTimeMs(startTime))
count += 1
})
}
case None => {
throw new BlockException(oneBlockId, "Could not get blocks from " + cmId)
}
}
logDebug("Got remote " + count + " blocks from " + cmId.host + " in " + Utils.getUsedTimeMs(startTime) + " ms")
}
logDebug("Got all blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
return blocks
}
/**
* Write a new block to the block manager.
*/
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true) {
if (!level.useDisk && !level.useMemory) {
throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set")
}
val startTimeMs = System.currentTimeMillis
var bytes: ByteBuffer = null
locker.getLock(blockId).synchronized {
logDebug("Put for block " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
// Check and warn if block with same id already exists
if (storageLevels.get(blockId) != null) {
logWarning("Block " + blockId + " already exists in local machine")
return
}
// Store the storage level
storageLevels.put(blockId, level)
if (level.useMemory && level.useDisk) {
// If saving to both memory and disk, then serialize only once
memoryStore.putValues(blockId, values, level) match {
case Left(newValues) =>
diskStore.putValues(blockId, newValues, level) match {
case Right(newBytes) => bytes = newBytes
case _ => throw new Exception("Unexpected return value")
}
case Right(newBytes) =>
bytes = newBytes
diskStore.putBytes(blockId, newBytes, level)
}
} else if (level.useMemory) {
// If only save to memory
memoryStore.putValues(blockId, values, level) match {
case Right(newBytes) => bytes = newBytes
case _ =>
}
} else {
// If only save to disk
diskStore.putValues(blockId, values, level) match {
case Right(newBytes) => bytes = newBytes
case _ => throw new Exception("Unexpected return value")
}
}
if (tellMaster) {
notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
logDebug("Put block " + blockId + " after notifying the master " + Utils.getUsedTimeMs(startTimeMs))
}
}
// Replicate block if required
if (level.replication > 1) {
if (bytes == null) {
bytes = dataSerialize(values) // serialize the block if not already done
}
replicate(blockId, bytes, level)
}
// TODO(Haoyuan): This code will be removed when CacheTracker is gone.
if (blockId.startsWith("rdd")) {
notifyTheCacheTracker(blockId)
}
logDebug("Put block " + blockId + " after notifying the CacheTracker " + Utils.getUsedTimeMs(startTimeMs))
}
def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
val startTime = System.currentTimeMillis
if (!level.useDisk && !level.useMemory) {
throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set")
} else if (level.deserialized) {
throw new IllegalArgumentException("Storage level cannot have deserialized when putBytes is used")
}
val replicationFuture = if (level.replication > 1) {
future {
replicate(blockId, bytes, level)
}
} else {
null
}
locker.getLock(blockId).synchronized {
logDebug("PutBytes for block " + blockId + " used " + Utils.getUsedTimeMs(startTime)
+ " to get into synchronized block")
if (storageLevels.get(blockId) != null) {
logWarning("Block " + blockId + " already exists")
return
}
storageLevels.put(blockId, level)
if (level.useMemory) {
memoryStore.putBytes(blockId, bytes, level)
}
if (level.useDisk) {
diskStore.putBytes(blockId, bytes, level)
}
if (tellMaster) {
notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
}
}
if (blockId.startsWith("rdd")) {
notifyTheCacheTracker(blockId)
}
if (level.replication > 1) {
if (replicationFuture == null) {
throw new Exception("Unexpected")
}
replicationFuture()
}
val finishTime = System.currentTimeMillis
if (level.replication > 1) {
logDebug("PutBytes with replication took " + (finishTime - startTime) + " ms")
} else {
logDebug("PutBytes without replication took " + (finishTime - startTime) + " ms")
}
}
private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
val tLevel: StorageLevel =
new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
var peers: Array[BlockManagerId] = BlockManagerMaster.mustGetPeers(
GetPeers(blockManagerId, level.replication - 1))
for (peer: BlockManagerId <- peers) {
val start = System.nanoTime
logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
+ data.array().length + " Bytes. To node: " + peer)
if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
new ConnectionManagerId(peer.ip, peer.port))) {
logError("Failed to call syncPutBlock to " + peer)
}
logDebug("Replicated BlockId " + blockId + " once used " +
(System.nanoTime - start) / 1e6 + " ms; The size of the data is " +
data.array().length + " bytes.")
}
}
// TODO(Haoyuan): This code will be removed when CacheTracker is gone.
def notifyTheCacheTracker(key: String) {
val rddInfo = key.split(":")
val rddId: Int = rddInfo(1).toInt
val splitIndex: Int = rddInfo(2).toInt
val host = System.getProperty("spark.hostname", Utils.localHostName)
cacheTracker.notifyTheCacheTrackerFromBlockManager(spark.AddedToCache(rddId, splitIndex, host))
}
/**
* Read a block consisting of a single object.
*/
def getSingle(blockId: String): Option[Any] = {
get(blockId).map(_.next)
}
/**
* Write a block consisting of a single object.
*/
def putSingle(blockId: String, value: Any, level: StorageLevel) {
put(blockId, Iterator(value), level)
}
/**
* Drop block from memory (called when the memory store has reached its limit)
*/
def dropFromMemory(blockId: String) {
locker.getLock(blockId).synchronized {
val level = storageLevels.get(blockId)
if (level == null) {
logWarning("Block " + blockId + " cannot be removed from memory as it does not exist")
return
}
if (!level.useMemory) {
logWarning("Block " + blockId + " cannot be removed from memory as it is not in memory")
return
}
memoryStore.remove(blockId)
if (!level.useDisk) {
storageLevels.remove(blockId)
} else {
val newLevel = level.clone
newLevel.useMemory = false
storageLevels.remove(blockId)
storageLevels.put(blockId, newLevel)
}
}
}
def dataSerialize(values: Iterator[Any]): ByteBuffer = {
/*serializer.newInstance().serializeMany(values)*/
val byteStream = new FastByteArrayOutputStream(4096)
serializer.newInstance().serializeStream(byteStream).writeAll(values).close()
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = {
/*serializer.newInstance().deserializeMany(bytes)*/
val ser = serializer.newInstance()
return ser.deserializeStream(new FastByteArrayInputStream(bytes.array())).toIterator
}
private def notifyMaster(heartBeat: HeartBeat) {
BlockManagerMaster.mustHeartBeat(heartBeat)
}
}
object BlockManager extends Logging {
def getMaxMemoryFromSystemProperties(): Long = {
val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
val bytes = (Runtime.getRuntime.totalMemory * memoryFraction).toLong
logInfo("Maximum memory to use: " + bytes)
bytes
}
}
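
A minimal usage sketch of the API above, not from this commit: it assumes a BlockManagerMaster actor is already running, uses spark.JavaSerializer as the serializer (an assumption), and reuses the (useDisk, useMemory, deserialized, replication) StorageLevel constructor seen in replicate().

    val blockManager = new BlockManager(new spark.JavaSerializer)  // memory limit from system properties
    val memoryOnly = new StorageLevel(false, true, true, 1)        // deserialized, in memory, no replication
    blockManager.putSingle("test_block_0", "hello", memoryOnly)    // store a single object
    blockManager.getSingle("test_block_0") match {
      case Some(value) => println("got back " + value)             // served from the memory store
      case None => println("block not found locally or remotely")
    }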

View file

@ -0,0 +1,516 @@
package spark.storage
import java.io._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.util.Random
import akka.actor._
import akka.actor.Actor
import akka.actor.Actor._
import akka.util.duration._
import spark.Logging
import spark.Utils
sealed trait ToBlockManagerMaster
case class RegisterBlockManager(
blockManagerId: BlockManagerId,
maxMemSize: Long,
maxDiskSize: Long)
extends ToBlockManagerMaster
class HeartBeat(
var blockManagerId: BlockManagerId,
var blockId: String,
var storageLevel: StorageLevel,
var deserializedSize: Long,
var size: Long)
extends ToBlockManagerMaster
with Externalizable {
def this() = this(null, null, null, 0, 0) // For deserialization only
override def writeExternal(out: ObjectOutput) {
blockManagerId.writeExternal(out)
out.writeUTF(blockId)
storageLevel.writeExternal(out)
out.writeInt(deserializedSize.toInt)
out.writeInt(size.toInt)
}
override def readExternal(in: ObjectInput) {
blockManagerId = new BlockManagerId()
blockManagerId.readExternal(in)
blockId = in.readUTF()
storageLevel = new StorageLevel()
storageLevel.readExternal(in)
deserializedSize = in.readInt()
size = in.readInt()
}
}
object HeartBeat {
def apply(blockManagerId: BlockManagerId,
blockId: String,
storageLevel: StorageLevel,
deserializedSize: Long,
size: Long): HeartBeat = {
new HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
}
// For pattern-matching
def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
Some((h.blockManagerId, h.blockId, h.storageLevel, h.deserializedSize, h.size))
}
}
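// Hedged round-trip sketch for the Externalizable HeartBeat above. It assumes that
// BlockManagerId's own writeExternal/readExternal (defined elsewhere in this commit)
// round-trip the host and port; the block id and sizes here are made up for illustration.
// The stream classes come from the java.io._ import at the top of this file.
object HeartBeatSerializationSketch {
  def main(args: Array[String]) {
    val original = HeartBeat(new BlockManagerId("localhost", 10902),
      "rdd_1_0", StorageLevel.MEMORY_ONLY, 0L, 1024L)
    val byteStream = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(byteStream)
    original.writeExternal(out)
    out.flush()
    val in = new ObjectInputStream(new ByteArrayInputStream(byteStream.toByteArray))
    val restored = new HeartBeat()
    restored.readExternal(in)
    println(restored.blockId + " " + restored.size)   // rdd_1_0 1024
  }
}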
case class GetLocations(
blockId: String)
extends ToBlockManagerMaster
case class GetLocationsMultipleBlockIds(
blockIds: Array[String])
extends ToBlockManagerMaster
case class GetPeers(
blockManagerId: BlockManagerId,
size: Int)
extends ToBlockManagerMaster
case class RemoveHost(
host: String)
extends ToBlockManagerMaster
class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
class BlockManagerInfo(
timeMs: Long,
maxMem: Long,
maxDisk: Long) {
private var lastSeenMs = timeMs
private var remainedMem = maxMem
private var remainedDisk = maxDisk
private val blocks = new HashMap[String, StorageLevel]
def updateLastSeenMs() {
      lastSeenMs = System.currentTimeMillis()
}
def addBlock(blockId: String, storageLevel: StorageLevel, deserializedSize: Long, size: Long) =
synchronized {
updateLastSeenMs()
if (blocks.contains(blockId)) {
val oriLevel: StorageLevel = blocks(blockId)
if (oriLevel.deserialized) {
remainedMem += deserializedSize
}
if (oriLevel.useMemory) {
remainedMem += size
}
if (oriLevel.useDisk) {
remainedDisk += size
}
}
blocks += (blockId -> storageLevel)
if (storageLevel.deserialized) {
remainedMem -= deserializedSize
}
if (storageLevel.useMemory) {
remainedMem -= size
}
if (storageLevel.useDisk) {
remainedDisk -= size
}
if (!(storageLevel.deserialized || storageLevel.useMemory || storageLevel.useDisk)) {
blocks.remove(blockId)
}
}
def getLastSeenMs(): Long = {
return lastSeenMs
}
def getRemainedMem(): Long = {
return remainedMem
}
def getRemainedDisk(): Long = {
return remainedDisk
}
override def toString(): String = {
return "BlockManagerInfo " + timeMs + " " + remainedMem + " " + remainedDisk
}
}
private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo]
private val blockIdMap = new HashMap[String, Pair[Int, HashSet[BlockManagerId]]]
initLogging()
def removeHost(host: String) {
logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
val ip = host.split(":")(0)
val port = host.split(":")(1)
blockManagerInfo.remove(new BlockManagerId(ip, port.toInt))
logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
self.reply(true)
}
def receive = {
case RegisterBlockManager(blockManagerId, maxMemSize, maxDiskSize) =>
register(blockManagerId, maxMemSize, maxDiskSize)
case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
case GetLocations(blockId) =>
getLocations(blockId)
case GetLocationsMultipleBlockIds(blockIds) =>
getLocationsMultipleBlockIds(blockIds)
case GetPeers(blockManagerId, size) =>
getPeers_Deterministic(blockManagerId, size)
/*getPeers(blockManagerId, size)*/
case RemoveHost(host) =>
removeHost(host)
case msg =>
logInfo("Got unknown msg: " + msg)
}
private def register(blockManagerId: BlockManagerId, maxMemSize: Long, maxDiskSize: Long) {
val startTimeMs = System.currentTimeMillis()
val tmp = " " + blockManagerId + " "
logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
logInfo("Got Register Msg from " + blockManagerId)
if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
logInfo("Got Register Msg from master node, don't register it")
} else {
      blockManagerInfo += (blockManagerId -> new BlockManagerInfo(
        System.currentTimeMillis(), maxMemSize, maxDiskSize))
}
logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
self.reply(true)
}
private def heartBeat(
blockManagerId: BlockManagerId,
blockId: String,
storageLevel: StorageLevel,
deserializedSize: Long,
size: Long) {
val startTimeMs = System.currentTimeMillis()
val tmp = " " + blockManagerId + " " + blockId + " "
logDebug("Got in heartBeat 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
if (blockId == null) {
blockManagerInfo(blockManagerId).updateLastSeenMs()
logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
      self.reply(true)
      return
    }
blockManagerInfo(blockManagerId).addBlock(blockId, storageLevel, deserializedSize, size)
logDebug("Got in heartBeat 2" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
var locations: HashSet[BlockManagerId] = null
if (blockIdMap.contains(blockId)) {
locations = blockIdMap(blockId)._2
} else {
locations = new HashSet[BlockManagerId]
blockIdMap += (blockId -> (storageLevel.replication, locations))
}
logDebug("Got in heartBeat 3" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
if (storageLevel.deserialized || storageLevel.useDisk || storageLevel.useMemory) {
locations += blockManagerId
} else {
locations.remove(blockManagerId)
}
logDebug("Got in heartBeat 4" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
if (locations.size == 0) {
blockIdMap.remove(blockId)
}
logDebug("Got in heartBeat 5" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
self.reply(true)
}
private def getLocations(blockId: String) {
val startTimeMs = System.currentTimeMillis()
val tmp = " " + blockId + " "
logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
if (blockIdMap.contains(blockId)) {
var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
res.appendAll(blockIdMap(blockId)._2)
logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at "
+ Utils.getUsedTimeMs(startTimeMs))
self.reply(res.toSeq)
} else {
logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
self.reply(res)
}
}
private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
def getLocations(blockId: String): Seq[BlockManagerId] = {
val tmp = blockId
logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp)
if (blockIdMap.contains(blockId)) {
var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
res.appendAll(blockIdMap(blockId)._2)
logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq)
return res.toSeq
} else {
logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp)
var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
return res.toSeq
}
}
logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq)
var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]]
for (blockId <- blockIds) {
res.append(getLocations(blockId))
}
logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq)
self.reply(res.toSeq)
}
private def getPeers(blockManagerId: BlockManagerId, size: Int) {
val startTimeMs = System.currentTimeMillis()
val tmp = " " + blockManagerId + " "
logDebug("Got in getPeers 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
res.appendAll(peers)
res -= blockManagerId
val rand = new Random(System.currentTimeMillis())
logDebug("Got in getPeers 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
while (res.length > size) {
res.remove(rand.nextInt(res.length))
}
logDebug("Got in getPeers 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
self.reply(res.toSeq)
}
private def getPeers_Deterministic(blockManagerId: BlockManagerId, size: Int) {
val startTimeMs = System.currentTimeMillis()
var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
val peersWithIndices = peers.zipWithIndex
val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1)
if (selfIndex == -1) {
throw new Exception("Self index for " + blockManagerId + " not found")
}
var index = selfIndex
while (res.size < size) {
      index += 1
      if (index % peers.size == selfIndex) {
        throw new Exception("More peers expected than available")
      }
res += peers(index % peers.size)
}
val resStr = res.map(_.toString).reduceLeft(_ + ", " + _)
logDebug("Got peers for " + blockManagerId + " as [" + resStr + "]")
self.reply(res.toSeq)
}
}
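// Hedged sketch of getPeers_Deterministic above, on plain strings: walk the peer
// list as a ring, starting just after this manager's own position (host names assumed).
object PeerSelectionSketch {
  def main(args: Array[String]) {
    val peers = Array("hostA:10902", "hostB:10902", "hostC:10902", "hostD:10902")
    val selfIndex = 2                              // pretend this manager is hostC
    val wanted = 2                                 // replication - 1 peers to forward to
    val res = (1 to wanted).map(i => peers((selfIndex + i) % peers.length))
    println(res)                                   // Vector(hostD:10902, hostA:10902)
  }
}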
object BlockManagerMaster extends Logging {
initLogging()
val AKKA_ACTOR_NAME: String = "BlockMasterManager"
val REQUEST_RETRY_INTERVAL_MS = 100
val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost")
val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt
val DEFAULT_MANAGER_IP: String = Utils.localHostName()
val DEFAULT_MANAGER_PORT: String = "10902"
implicit val TIME_OUT_SEC = Actor.Timeout(3000 millis)
var masterActor: ActorRef = null
def startBlockManagerMaster(isMaster: Boolean, isLocal: Boolean) {
if (isMaster) {
masterActor = actorOf(new BlockManagerMaster(isLocal))
remote.register(AKKA_ACTOR_NAME, masterActor)
logInfo("Registered BlockManagerMaster Actor: " + DEFAULT_MASTER_IP + ":" + DEFAULT_MASTER_PORT)
masterActor.start()
} else {
masterActor = remote.actorFor(AKKA_ACTOR_NAME, DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT)
}
}
def notifyADeadHost(host: String) {
(masterActor ? RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)).as[Any] match {
case Some(true) =>
logInfo("Removed " + host + " successfully. @ notifyADeadHost")
case Some(oops) =>
logError("Failed @ notifyADeadHost: " + oops)
case None =>
logError("None @ notifyADeadHost.")
}
}
def mustRegisterBlockManager(msg: RegisterBlockManager) {
while (! syncRegisterBlockManager(msg)) {
logWarning("Failed to register " + msg)
Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
}
}
def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = {
//val masterActor = RemoteActor.select(node, name)
val startTimeMs = System.currentTimeMillis()
val tmp = " msg " + msg + " "
logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
(masterActor ? msg).as[Any] match {
case Some(true) =>
logInfo("BlockManager registered successfully @ syncRegisterBlockManager.")
logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
return true
case Some(oops) =>
logError("Failed @ syncRegisterBlockManager: " + oops)
return false
case None =>
logError("None @ syncRegisterBlockManager.")
return false
}
}
def mustHeartBeat(msg: HeartBeat) {
while (! syncHeartBeat(msg)) {
logWarning("Failed to send heartbeat" + msg)
Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
}
}
def syncHeartBeat(msg: HeartBeat): Boolean = {
val startTimeMs = System.currentTimeMillis()
val tmp = " msg " + msg + " "
logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs))
(masterActor ? msg).as[Any] match {
case Some(true) =>
logInfo("Heartbeat sent successfully.")
logDebug("Got in syncHeartBeat " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs))
return true
case Some(oops) =>
logError("Failed: " + oops)
return false
case None =>
logError("None.")
return false
}
}
def mustGetLocations(msg: GetLocations): Array[BlockManagerId] = {
var res: Array[BlockManagerId] = syncGetLocations(msg)
while (res == null) {
logInfo("Failed to get locations " + msg)
Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
res = syncGetLocations(msg)
}
return res
}
def syncGetLocations(msg: GetLocations): Array[BlockManagerId] = {
val startTimeMs = System.currentTimeMillis()
val tmp = " msg " + msg + " "
logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
(masterActor ? msg).as[Seq[BlockManagerId]] match {
case Some(arr) =>
logDebug("GetLocations successfully.")
logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
for (ele <- arr) {
res += ele
}
logDebug("Got in syncGetLocations 2 " + tmp + Utils.getUsedTimeMs(startTimeMs))
return res.toArray
case None =>
logError("GetLocations call returned None.")
return null
}
}
def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
Seq[Seq[BlockManagerId]] = {
var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg)
while (res == null) {
logWarning("Failed to GetLocationsMultipleBlockIds " + msg)
Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
res = syncGetLocationsMultipleBlockIds(msg)
}
return res
}
def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
Seq[Seq[BlockManagerId]] = {
val startTimeMs = System.currentTimeMillis
val tmp = " msg " + msg + " "
logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
(masterActor ? msg).as[Any] match {
case Some(arr: Seq[Seq[BlockManagerId]]) =>
logDebug("GetLocationsMultipleBlockIds successfully: " + arr)
logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
return arr
case Some(oops) =>
logError("Failed: " + oops)
return null
case None =>
logInfo("None.")
return null
}
}
def mustGetPeers(msg: GetPeers): Array[BlockManagerId] = {
var res: Array[BlockManagerId] = syncGetPeers(msg)
while ((res == null) || (res.length != msg.size)) {
logInfo("Failed to get peers " + msg)
Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
res = syncGetPeers(msg)
}
return res
}
def syncGetPeers(msg: GetPeers): Array[BlockManagerId] = {
val startTimeMs = System.currentTimeMillis
val tmp = " msg " + msg + " "
logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
(masterActor ? msg).as[Seq[BlockManagerId]] match {
case Some(arr) =>
logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
logInfo("GetPeers successfully: " + arr.length)
res.appendAll(arr)
logDebug("Got in syncGetPeers 2 " + tmp + Utils.getUsedTimeMs(startTimeMs))
return res.toArray
case None =>
logError("GetPeers call returned None.")
return null
}
}
}
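// Hedged sketch of the retry pattern shared by the must* helpers above: keep calling
// the sync* variant until it reports success, sleeping between attempts.
object RetrySketch {
  def retryUntilTrue(intervalMs: Long)(attempt: => Boolean) {
    while (!attempt) {
      Thread.sleep(intervalMs)
    }
  }
  def main(args: Array[String]) {
    var tries = 0
    retryUntilTrue(10) { tries += 1; tries >= 3 }
    println("Succeeded after " + tries + " tries")   // Succeeded after 3 tries
  }
}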


@ -0,0 +1,142 @@
package spark.storage
import java.nio._
import scala.actors._
import scala.actors.Actor._
import scala.actors.remote._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.util.Random
import spark.Logging
import spark.Utils
import spark.SparkEnv
import spark.network._
/**
* This should be changed to use an event model later.
*/
class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
initLogging()
blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
logDebug("Handling message " + msg)
msg match {
case bufferMessage: BufferMessage => {
try {
logDebug("Handling as a buffer message " + bufferMessage)
val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
logDebug("Parsed as a block message array")
val responseMessages = blockMessages.map(processBlockMessage _).filter(_ != None).map(_.get)
/*logDebug("Processed block messages")*/
return Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
case e: Exception => logError("Exception handling buffer message: " + e.getMessage)
return None
}
}
case otherMessage: Any => {
logError("Unknown type message received: " + otherMessage)
return None
}
}
}
def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = {
blockMessage.getType() match {
case BlockMessage.TYPE_PUT_BLOCK => {
val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel())
logInfo("Received [" + pB + "]")
putBlock(pB.id, pB.data, pB.level)
return None
}
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId())
logInfo("Received [" + gB + "]")
val buffer = getBlock(gB.id)
if (buffer == null) {
return None
}
return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
}
case _ => return None
}
}
private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
blockManager.putBytes(id, bytes, level)
logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " with data size: " + bytes.array().length)
}
private def getBlock(id: String): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
logDebug("Getblock " + id + " started from " + startTimeMs)
val block = blockManager.get(id)
val buffer = block match {
case Some(tValues) => {
val values = tValues.asInstanceOf[Iterator[Any]]
val buffer = blockManager.dataSerialize(values)
buffer
}
case None => {
null
}
}
logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " and got buffer " + buffer)
return buffer
}
}
object BlockManagerWorker extends Logging {
private var blockManagerWorker: BlockManagerWorker = null
private val DATA_TRANSFER_TIME_OUT_MS: Long = 500
private val REQUEST_RETRY_INTERVAL_MS: Long = 1000
initLogging()
def startBlockManagerWorker(manager: BlockManager) {
blockManagerWorker = new BlockManagerWorker(manager)
}
def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
val blockManager = blockManagerWorker.blockManager
val connectionManager = blockManager.connectionManager
val serializer = blockManager.serializer
val blockMessage = BlockMessage.fromPutBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val resultMessage = connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage())
return (resultMessage != None)
}
def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
val blockManager = blockManagerWorker.blockManager
val connectionManager = blockManager.connectionManager
val serializer = blockManager.serializer
val blockMessage = BlockMessage.fromGetBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val responseMessage = connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage())
responseMessage match {
case Some(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
logDebug("Response message received " + bufferMessage)
BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
logDebug("Found " + blockMessage)
return blockMessage.getData
})
}
case None => logDebug("No response message received"); return null
}
return null
}
}


@ -0,0 +1,219 @@
package spark.storage
import java.nio._
import scala.collection.mutable.StringBuilder
import scala.collection.mutable.ArrayBuffer
import spark._
import spark.network._
case class GetBlock(id: String)
case class GotBlock(id: String, data: ByteBuffer)
case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel)
class BlockMessage() extends Logging {
// Un-initialized: typ = 0
// GetBlock: typ = 1
// GotBlock: typ = 2
// PutBlock: typ = 3
private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
private var id: String = null
private var data: ByteBuffer = null
private var level: StorageLevel = null
initLogging()
def set(getBlock: GetBlock) {
typ = BlockMessage.TYPE_GET_BLOCK
id = getBlock.id
}
def set(gotBlock: GotBlock) {
typ = BlockMessage.TYPE_GOT_BLOCK
id = gotBlock.id
data = gotBlock.data
}
def set(putBlock: PutBlock) {
typ = BlockMessage.TYPE_PUT_BLOCK
id = putBlock.id
data = putBlock.data
level = putBlock.level
}
def set(buffer: ByteBuffer) {
val startTime = System.currentTimeMillis
/*
println()
println("BlockMessage: ")
while(buffer.remaining > 0) {
print(buffer.get())
}
buffer.rewind()
println()
println()
*/
typ = buffer.getInt()
val idLength = buffer.getInt()
val idBuilder = new StringBuilder(idLength)
for (i <- 1 to idLength) {
idBuilder += buffer.getChar()
}
id = idBuilder.toString()
logDebug("Set from buffer Result: " + typ + " " + id)
logDebug("Buffer position is " + buffer.position)
if (typ == BlockMessage.TYPE_PUT_BLOCK) {
val booleanInt = buffer.getInt()
val replication = buffer.getInt()
level = new StorageLevel(booleanInt, replication)
val dataLength = buffer.getInt()
data = ByteBuffer.allocate(dataLength)
if (dataLength != buffer.remaining) {
throw new Exception("Error parsing buffer")
}
data.put(buffer)
data.flip()
logDebug("Set from buffer Result 2: " + level + " " + data)
} else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
val dataLength = buffer.getInt()
logDebug("Data length is "+ dataLength)
logDebug("Buffer position is " + buffer.position)
data = ByteBuffer.allocate(dataLength)
if (dataLength != buffer.remaining) {
throw new Exception("Error parsing buffer")
}
data.put(buffer)
data.flip()
logDebug("Set from buffer Result 3: " + data)
}
val finishTime = System.currentTimeMillis
logDebug("Converted " + id + " from bytebuffer in " + (finishTime - startTime) / 1000.0 + " s")
}
def set(bufferMsg: BufferMessage) {
val buffer = bufferMsg.buffers.apply(0)
buffer.clear()
set(buffer)
}
def getType(): Int = {
return typ
}
def getId(): String = {
return id
}
def getData(): ByteBuffer = {
return data
}
def getLevel(): StorageLevel = {
return level
}
def toBufferMessage(): BufferMessage = {
val startTime = System.currentTimeMillis
val buffers = new ArrayBuffer[ByteBuffer]()
var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2)
buffer.putInt(typ).putInt(id.length())
id.foreach((x: Char) => buffer.putChar(x))
buffer.flip()
buffers += buffer
if (typ == BlockMessage.TYPE_PUT_BLOCK) {
buffer = ByteBuffer.allocate(8).putInt(level.toInt()).putInt(level.replication)
buffer.flip()
buffers += buffer
buffer = ByteBuffer.allocate(4).putInt(data.remaining)
buffer.flip()
buffers += buffer
buffers += data
} else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
buffer = ByteBuffer.allocate(4).putInt(data.remaining)
buffer.flip()
buffers += buffer
buffers += data
}
logDebug("Start to log buffers.")
buffers.foreach((x: ByteBuffer) => logDebug("" + x))
/*
println()
println("BlockMessage: ")
buffers.foreach(b => {
while(b.remaining > 0) {
print(b.get())
}
b.rewind()
})
println()
println()
*/
val finishTime = System.currentTimeMillis
logDebug("Converted " + id + " to buffer message in " + (finishTime - startTime) / 1000.0 + " s")
return Message.createBufferMessage(buffers)
}
override def toString(): String = {
"BlockMessage [type = " + typ + ", id = " + id + ", level = " + level +
", data = " + (if (data != null) data.remaining.toString else "null") + "]"
}
}
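// Wire layout produced by toBufferMessage above (first buffer, then optional extras):
//   [typ: Int][id.length: Int][id as UTF-16 chars]
//   TYPE_PUT_BLOCK adds [level.toInt: Int][level.replication: Int][data.remaining: Int][data bytes]
//   TYPE_GOT_BLOCK adds [data.remaining: Int][data bytes]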
object BlockMessage {
val TYPE_NON_INITIALIZED: Int = 0
val TYPE_GET_BLOCK: Int = 1
val TYPE_GOT_BLOCK: Int = 2
val TYPE_PUT_BLOCK: Int = 3
def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = {
val newBlockMessage = new BlockMessage()
newBlockMessage.set(bufferMessage)
newBlockMessage
}
def fromByteBuffer(buffer: ByteBuffer): BlockMessage = {
val newBlockMessage = new BlockMessage()
newBlockMessage.set(buffer)
newBlockMessage
}
def fromGetBlock(getBlock: GetBlock): BlockMessage = {
val newBlockMessage = new BlockMessage()
newBlockMessage.set(getBlock)
newBlockMessage
}
def fromGotBlock(gotBlock: GotBlock): BlockMessage = {
val newBlockMessage = new BlockMessage()
newBlockMessage.set(gotBlock)
newBlockMessage
}
def fromPutBlock(putBlock: PutBlock): BlockMessage = {
val newBlockMessage = new BlockMessage()
newBlockMessage.set(putBlock)
newBlockMessage
}
def main(args: Array[String]) {
val B = new BlockMessage()
B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.DISK_AND_MEMORY_2))
val bMsg = B.toBufferMessage()
val C = new BlockMessage()
C.set(bMsg)
println(B.getId() + " " + B.getLevel())
println(C.getId() + " " + C.getLevel())
}
}


@ -0,0 +1,140 @@
package spark.storage
import java.nio._
import scala.collection.mutable.StringBuilder
import scala.collection.mutable.ArrayBuffer
import spark._
import spark.network._
class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging {
def this(bm: BlockMessage) = this(Array(bm))
def this() = this(null.asInstanceOf[Seq[BlockMessage]])
def apply(i: Int) = blockMessages(i)
def iterator = blockMessages.iterator
def length = blockMessages.length
initLogging()
def set(bufferMessage: BufferMessage) {
val startTime = System.currentTimeMillis
val newBlockMessages = new ArrayBuffer[BlockMessage]()
val buffer = bufferMessage.buffers(0)
buffer.clear()
/*
println()
println("BlockMessageArray: ")
while(buffer.remaining > 0) {
print(buffer.get())
}
buffer.rewind()
println()
println()
*/
while(buffer.remaining() > 0) {
val size = buffer.getInt()
logDebug("Creating block message of size " + size + " bytes")
val newBuffer = buffer.slice()
newBuffer.clear()
newBuffer.limit(size)
logDebug("Trying to convert buffer " + newBuffer + " to block message")
val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer)
logDebug("Created " + newBlockMessage)
newBlockMessages += newBlockMessage
buffer.position(buffer.position() + size)
}
val finishTime = System.currentTimeMillis
logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s")
this.blockMessages = newBlockMessages
}
def toBufferMessage(): BufferMessage = {
val buffers = new ArrayBuffer[ByteBuffer]()
blockMessages.foreach(blockMessage => {
val bufferMessage = blockMessage.toBufferMessage
logDebug("Adding " + blockMessage)
val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size)
sizeBuffer.flip
buffers += sizeBuffer
buffers ++= bufferMessage.buffers
logDebug("Added " + bufferMessage)
})
logDebug("Buffer list:")
buffers.foreach((x: ByteBuffer) => logDebug("" + x))
/*
println()
println("BlockMessageArray: ")
buffers.foreach(b => {
while(b.remaining > 0) {
print(b.get())
}
b.rewind()
})
println()
println()
*/
return Message.createBufferMessage(buffers)
}
}
object BlockMessageArray {
def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
val newBlockMessageArray = new BlockMessageArray()
newBlockMessageArray.set(bufferMessage)
newBlockMessageArray
}
def main(args: Array[String]) {
val blockMessages =
(0 until 10).map(i => {
if (i % 2 == 0) {
val buffer = ByteBuffer.allocate(100)
buffer.clear
BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY))
} else {
BlockMessage.fromGetBlock(GetBlock(i.toString))
}
})
val blockMessageArray = new BlockMessageArray(blockMessages)
println("Block message array created")
val bufferMessage = blockMessageArray.toBufferMessage
println("Converted to buffer message")
val totalSize = bufferMessage.size
val newBuffer = ByteBuffer.allocate(totalSize)
newBuffer.clear()
bufferMessage.buffers.foreach(buffer => {
newBuffer.put(buffer)
buffer.rewind()
})
newBuffer.flip
val newBufferMessage = Message.createBufferMessage(newBuffer)
println("Copied to new buffer message, size = " + newBufferMessage.size)
val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
println("Converted back to block message array")
newBlockMessageArray.foreach(blockMessage => {
blockMessage.getType() match {
case BlockMessage.TYPE_PUT_BLOCK => {
val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel())
println(pB)
}
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId())
println(gB)
}
}
})
}
}


@ -0,0 +1,282 @@
package spark.storage
import spark.{Utils, Logging, Serializer, SizeEstimator}
import scala.collection.mutable.ArrayBuffer
import java.io.{File, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode
import java.util.{UUID, LinkedHashMap}
import java.util.concurrent.Executors
import it.unimi.dsi.fastutil.io._
/**
* Abstract class to store blocks
*/
abstract class BlockStore(blockManager: BlockManager) extends Logging {
initLogging()
def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel)
def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer]
def getBytes(blockId: String): Option[ByteBuffer]
def getValues(blockId: String): Option[Iterator[Any]]
def remove(blockId: String)
def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values)
def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes)
}
/**
* Class to store blocks in memory
*/
class MemoryStore(blockManager: BlockManager, maxMemory: Long)
extends BlockStore(blockManager) {
class Entry(var value: Any, val size: Long, val deserialized: Boolean)
private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true)
private var currentMemory = 0L
private val blockDropper = Executors.newSingleThreadExecutor()
def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
if (level.deserialized) {
bytes.rewind()
val values = dataDeserialize(bytes)
val elements = new ArrayBuffer[Any]
elements ++= values
val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
ensureFreeSpace(sizeEstimate)
val entry = new Entry(elements, sizeEstimate, true)
memoryStore.synchronized { memoryStore.put(blockId, entry) }
currentMemory += sizeEstimate
logDebug("Block " + blockId + " stored as values to memory")
} else {
val entry = new Entry(bytes, bytes.array().length, false)
ensureFreeSpace(bytes.array.length)
memoryStore.synchronized { memoryStore.put(blockId, entry) }
currentMemory += bytes.array().length
logDebug("Block " + blockId + " stored as " + bytes.array().length + " bytes to memory")
}
}
def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
if (level.deserialized) {
val elements = new ArrayBuffer[Any]
elements ++= values
val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
ensureFreeSpace(sizeEstimate)
val entry = new Entry(elements, sizeEstimate, true)
memoryStore.synchronized { memoryStore.put(blockId, entry) }
currentMemory += sizeEstimate
logDebug("Block " + blockId + " stored as values to memory")
return Left(elements.iterator)
} else {
val bytes = dataSerialize(values)
ensureFreeSpace(bytes.array().length)
val entry = new Entry(bytes, bytes.array().length, false)
memoryStore.synchronized { memoryStore.put(blockId, entry) }
currentMemory += bytes.array().length
logDebug("Block " + blockId + " stored as " + bytes.array.length + " bytes to memory")
return Right(bytes)
}
}
def getBytes(blockId: String): Option[ByteBuffer] = {
throw new UnsupportedOperationException("Not implemented")
}
def getValues(blockId: String): Option[Iterator[Any]] = {
val entry = memoryStore.synchronized { memoryStore.get(blockId) }
if (entry == null) {
return None
}
if (entry.deserialized) {
return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator)
} else {
return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer]))
}
}
def remove(blockId: String) {
memoryStore.synchronized {
val entry = memoryStore.get(blockId)
if (entry != null) {
memoryStore.remove(blockId)
currentMemory -= entry.size
logDebug("Block " + blockId + " of size " + entry.size + " dropped from memory")
} else {
logWarning("Block " + blockId + " could not be removed as it doesnt exist")
}
}
}
private def drop(blockId: String) {
blockDropper.submit(new Runnable() {
def run() {
blockManager.dropFromMemory(blockId)
}
})
}
private def ensureFreeSpace(space: Long) {
logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
space, currentMemory, maxMemory))
val droppedBlockIds = new ArrayBuffer[String]()
var droppedMemory = 0L
memoryStore.synchronized {
val iter = memoryStore.entrySet().iterator()
while (maxMemory - (currentMemory - droppedMemory) < space && iter.hasNext) {
val pair = iter.next()
val blockId = pair.getKey
droppedBlockIds += blockId
droppedMemory += pair.getValue.size
logDebug("Decided to drop " + blockId)
}
}
for (blockId <- droppedBlockIds) {
drop(blockId)
}
droppedBlockIds.clear
}
}
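// Hedged sketch of why the MemoryStore above uses an access-ordered LinkedHashMap:
// iteration runs from least- to most-recently used, which is exactly the order
// ensureFreeSpace walks when picking blocks to drop (block ids here are made up).
import java.util.LinkedHashMap

object LruOrderSketch {
  def main(args: Array[String]) {
    val map = new LinkedHashMap[String, Int](32, 0.75f, true)
    map.put("block_a", 1)
    map.put("block_b", 2)
    map.put("block_c", 3)
    map.get("block_a")                             // touching block_a makes it most recent
    val iter = map.keySet().iterator()
    while (iter.hasNext) {
      print(iter.next() + " ")                     // block_b block_c block_a
    }
    println()
  }
}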
/**
* Class to store blocks on disk
*/
class DiskStore(blockManager: BlockManager, rootDirs: String)
extends BlockStore(blockManager) {
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
val localDirs = createLocalDirs()
var lastLocalDirUsed = 0
addShutdownHook()
def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
val file = createFile(blockId)
if (file != null) {
val channel = new RandomAccessFile(file, "rw").getChannel()
val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.array.length)
buffer.put(bytes.array)
channel.close()
val finishTime = System.currentTimeMillis
logDebug("Block " + blockId + " stored to file of " + bytes.array.length + " bytes to disk in " + (finishTime - startTime) + " ms")
} else {
logError("File not created for block " + blockId)
}
}
def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
val bytes = dataSerialize(values)
logDebug("Converted block " + blockId + " to " + bytes.array.length + " bytes")
putBytes(blockId, bytes, level)
return Right(bytes)
}
  def getBytes(blockId: String): Option[ByteBuffer] = {
    val file = getFile(blockId)
    val length = file.length().toInt
    val channel = new RandomAccessFile(file, "r").getChannel()
    val bytes = ByteBuffer.allocate(length)
    bytes.put(channel.map(MapMode.READ_ONLY, 0, length))
    bytes.flip()
    channel.close()
    return Some(bytes)
  }
def getValues(blockId: String): Option[Iterator[Any]] = {
val file = getFile(blockId)
val length = file.length().toInt
val channel = new RandomAccessFile(file, "r").getChannel()
val bytes = channel.map(MapMode.READ_ONLY, 0, length)
val buffer = dataDeserialize(bytes)
channel.close()
return Some(buffer)
}
def remove(blockId: String) {
throw new UnsupportedOperationException("Not implemented")
}
private def createFile(blockId: String): File = {
val file = getFile(blockId)
if (file == null) {
lastLocalDirUsed = (lastLocalDirUsed + 1) % localDirs.size
val newFile = new File(localDirs(lastLocalDirUsed), blockId)
newFile.getParentFile.mkdirs()
return newFile
} else {
logError("File for block " + blockId + " already exists on disk, " + file)
return null
}
}
private def getFile(blockId: String): File = {
logDebug("Getting file for block " + blockId)
// Search for the file in all the local directories, only one of them should have the file
val files = localDirs.map(localDir => new File(localDir, blockId)).filter(_.exists)
    if (files.size > 1) {
      throw new Exception("Multiple files for the same block " + blockId + " exist: " +
        files.map(_.toString).reduceLeft(_ + ", " + _))
    } else if (files.size == 0) {
return null
} else {
logDebug("Got file " + files(0) + " of size " + files(0).length + " bytes")
return files(0)
}
}
private def createLocalDirs(): Seq[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
rootDirs.split("[;,:]").map(rootDir => {
var foundLocalDir: Boolean = false
var localDir: File = null
var localDirUuid: UUID = null
var tries = 0
while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
tries += 1
try {
localDirUuid = UUID.randomUUID()
localDir = new File(rootDir, "spark-local-" + localDirUuid)
if (!localDir.exists) {
localDir.mkdirs()
foundLocalDir = true
}
} catch {
case e: Exception =>
logWarning("Attempt " + tries + " to create local dir failed", e)
}
}
if (!foundLocalDir) {
logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
" attempts to create local dir in " + rootDir)
System.exit(1)
}
logDebug("Created local directory at " + localDir)
localDir
})
}
private def addShutdownHook() {
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
}
})
}
}


@ -0,0 +1,78 @@
package spark.storage
import java.io._
class StorageLevel(
var useDisk: Boolean,
var useMemory: Boolean,
var deserialized: Boolean,
var replication: Int = 1)
extends Externalizable {
// TODO: Also add fields for caching priority, dataset ID, and flushing.
def this(booleanInt: Int, replication: Int) {
this(((booleanInt & 4) != 0),
((booleanInt & 2) != 0),
((booleanInt & 1) != 0),
replication)
}
def this() = this(false, true, false) // For deserialization
override def clone(): StorageLevel = new StorageLevel(
this.useDisk, this.useMemory, this.deserialized, this.replication)
override def equals(other: Any): Boolean = other match {
case s: StorageLevel =>
s.useDisk == useDisk &&
s.useMemory == useMemory &&
s.deserialized == deserialized &&
s.replication == replication
case _ =>
false
}
def toInt(): Int = {
var ret = 0
if (useDisk) {
ret += 4
}
if (useMemory) {
ret += 2
}
if (deserialized) {
ret += 1
}
return ret
}
override def writeExternal(out: ObjectOutput) {
out.writeByte(toInt().toByte)
out.writeByte(replication.toByte)
}
override def readExternal(in: ObjectInput) {
val flags = in.readByte()
useDisk = (flags & 4) != 0
useMemory = (flags & 2) != 0
deserialized = (flags & 1) != 0
replication = in.readByte()
}
override def toString(): String =
"StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
}
object StorageLevel {
val NONE = new StorageLevel(false, false, false)
val DISK_ONLY = new StorageLevel(true, false, false)
val MEMORY_ONLY = new StorageLevel(false, true, false)
val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2)
val MEMORY_ONLY_DESER = new StorageLevel(false, true, true)
val MEMORY_ONLY_DESER_2 = new StorageLevel(false, true, true, 2)
val DISK_AND_MEMORY = new StorageLevel(true, true, false)
val DISK_AND_MEMORY_2 = new StorageLevel(true, true, false, 2)
val DISK_AND_MEMORY_DESER = new StorageLevel(true, true, true)
val DISK_AND_MEMORY_DESER_2 = new StorageLevel(true, true, true, 2)
}
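// Worked example of the toInt bit encoding above: disk = 4, memory = 2, deserialized = 1.
object StorageLevelEncodingSketch {
  def main(args: Array[String]) {
    val level = StorageLevel.DISK_AND_MEMORY_DESER_2
    println(level.toInt())                                   // 7 (4 + 2 + 1)
    val roundTripped = new StorageLevel(level.toInt(), level.replication)
    println(roundTripped == level)                           // true (equals compares all four fields)
    println(roundTripped)                                    // StorageLevel(true, true, true, 2)
  }
}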


@ -0,0 +1,30 @@
package spark.util
import java.io.InputStream
import java.nio.ByteBuffer
class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
override def read(): Int = {
if (buffer.remaining() == 0) {
-1
} else {
buffer.get()
}
}
override def read(dest: Array[Byte]): Int = {
read(dest, 0, dest.length)
}
override def read(dest: Array[Byte], offset: Int, length: Int): Int = {
val amountToGet = math.min(buffer.remaining(), length)
buffer.get(dest, offset, amountToGet)
return amountToGet
}
override def skip(bytes: Long): Long = {
val amountToSkip = math.min(bytes, buffer.remaining).toInt
buffer.position(buffer.position + amountToSkip)
return amountToSkip
}
}
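// Usage sketch: wrap a ByteBuffer and read it back through the InputStream API
// (the example bytes are made up; ByteBuffer comes from the import at the top of this file).
object ByteBufferInputStreamSketch {
  def main(args: Array[String]) {
    val buffer = ByteBuffer.wrap("spark".getBytes("UTF-8"))
    val in = new ByteBufferInputStream(buffer)
    val dest = new Array[Byte](5)
    val read = in.read(dest)
    println(read + " " + new String(dest, "UTF-8"))   // 5 spark
    println(in.read())                                // -1 once the buffer is exhausted
  }
}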


@ -0,0 +1,89 @@
package spark.util
/**
* A class for tracking the statistics of a set of numbers (count, mean and variance) in a
* numerically robust way. Includes support for merging two StatCounters. Based on Welford and
* Chan's algorithms described at http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
*/
class StatCounter(values: TraversableOnce[Double]) {
private var n: Long = 0 // Running count of our values
private var mu: Double = 0 // Running mean of our values
private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2)
merge(values)
def this() = this(Nil)
def merge(value: Double): StatCounter = {
val delta = value - mu
n += 1
mu += delta / n
m2 += delta * (value - mu)
this
}
def merge(values: TraversableOnce[Double]): StatCounter = {
values.foreach(v => merge(v))
this
}
def merge(other: StatCounter): StatCounter = {
if (other == this) {
merge(other.copy()) // Avoid overwriting fields in a weird order
} else {
val delta = other.mu - mu
if (other.n * 10 < n) {
mu = mu + (delta * other.n) / (n + other.n)
} else if (n * 10 < other.n) {
mu = other.mu - (delta * n) / (n + other.n)
} else {
mu = (mu * n + other.mu * other.n) / (n + other.n)
}
m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
n += other.n
this
}
}
def copy(): StatCounter = {
val other = new StatCounter
other.n = n
other.mu = mu
other.m2 = m2
other
}
def count: Long = n
def mean: Double = mu
def sum: Double = n * mu
def variance: Double = {
if (n == 0)
Double.NaN
else
m2 / n
}
def sampleVariance: Double = {
if (n <= 1)
Double.NaN
else
m2 / (n - 1)
}
def stdev: Double = math.sqrt(variance)
def sampleStdev: Double = math.sqrt(sampleVariance)
override def toString: String = {
"(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev)
}
}
object StatCounter {
def apply(values: TraversableOnce[Double]) = new StatCounter(values)
def apply(values: Double*) = new StatCounter(values)
}
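// Usage sketch: build two counters, merge them, and read the combined statistics.
object StatCounterSketch {
  def main(args: Array[String]) {
    val a = StatCounter(1.0, 2.0, 3.0)
    val b = StatCounter(4.0, 5.0)
    a.merge(b)
    println(a.count)   // 5
    println(a.mean)    // 3.0
    println(a.sum)     // 15.0
    println(a.stdev)   // ~1.414 (population standard deviation)
  }
}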


@ -1,95 +1,103 @@
package spark
import org.scalatest.FunSuite
import collection.mutable.HashMap
import scala.collection.mutable.HashMap
import akka.actor._
import akka.actor.Actor
import akka.actor.Actor._
class CacheTrackerSuite extends FunSuite {
test("CacheTrackerActor slave initialization & cache status") {
System.setProperty("spark.master.port", "1345")
//System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val tracker = new CacheTrackerActor
val tracker = actorOf(new CacheTrackerActor)
tracker.start()
tracker !? SlaveCacheStarted("host001", initialSize)
tracker !! SlaveCacheStarted("host001", initialSize)
assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 0L)))
assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 0L)))
tracker !? StopCacheTracker
tracker !! StopCacheTracker
}
test("RegisterRDD") {
System.setProperty("spark.master.port", "1345")
//System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val tracker = new CacheTrackerActor
val tracker = actorOf(new CacheTrackerActor)
tracker.start()
tracker !? SlaveCacheStarted("host001", initialSize)
tracker !! SlaveCacheStarted("host001", initialSize)
tracker !? RegisterRDD(1, 3)
tracker !? RegisterRDD(2, 1)
tracker !! RegisterRDD(1, 3)
tracker !! RegisterRDD(2, 1)
assert(getCacheLocations(tracker) == Map(1 -> List(List(), List(), List()), 2 -> List(List())))
assert(getCacheLocations(tracker) === Map(1 -> List(List(), List(), List()), 2 -> List(List())))
tracker !? StopCacheTracker
tracker !! StopCacheTracker
}
test("AddedToCache") {
System.setProperty("spark.master.port", "1345")
//System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val tracker = new CacheTrackerActor
val tracker = actorOf(new CacheTrackerActor)
tracker.start()
tracker !? SlaveCacheStarted("host001", initialSize)
tracker !! SlaveCacheStarted("host001", initialSize)
tracker !? RegisterRDD(1, 2)
tracker !? RegisterRDD(2, 1)
tracker !! RegisterRDD(1, 2)
tracker !! RegisterRDD(2, 1)
tracker !? AddedToCache(1, 0, "host001", 2L << 15)
tracker !? AddedToCache(1, 1, "host001", 2L << 11)
tracker !? AddedToCache(2, 0, "host001", 3L << 10)
tracker !! AddedToCache(1, 0, "host001", 2L << 15)
tracker !! AddedToCache(1, 1, "host001", 2L << 11)
tracker !! AddedToCache(2, 0, "host001", 3L << 10)
assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L)))
assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
assert(getCacheLocations(tracker) ===
Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
tracker !? StopCacheTracker
tracker !! StopCacheTracker
}
test("DroppedFromCache") {
System.setProperty("spark.master.port", "1345")
//System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val tracker = new CacheTrackerActor
val tracker = actorOf(new CacheTrackerActor)
tracker.start()
tracker !? SlaveCacheStarted("host001", initialSize)
tracker !! SlaveCacheStarted("host001", initialSize)
tracker !? RegisterRDD(1, 2)
tracker !? RegisterRDD(2, 1)
tracker !! RegisterRDD(1, 2)
tracker !! RegisterRDD(2, 1)
tracker !? AddedToCache(1, 0, "host001", 2L << 15)
tracker !? AddedToCache(1, 1, "host001", 2L << 11)
tracker !? AddedToCache(2, 0, "host001", 3L << 10)
tracker !! AddedToCache(1, 0, "host001", 2L << 15)
tracker !! AddedToCache(1, 1, "host001", 2L << 11)
tracker !! AddedToCache(2, 0, "host001", 3L << 10)
assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L)))
assert(getCacheLocations(tracker) ===
Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
tracker !? DroppedFromCache(1, 1, "host001", 2L << 11)
tracker !! DroppedFromCache(1, 1, "host001", 2L << 11)
assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 68608L)))
assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 68608L)))
assert(getCacheLocations(tracker) ===
Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
tracker !? StopCacheTracker
tracker !! StopCacheTracker
}
/**
* Helper function to get cacheLocations from CacheTracker
*/
def getCacheLocations(tracker: CacheTrackerActor) = tracker !? GetCacheLocations match {
def getCacheLocations(tracker: ActorRef) = (tracker ? GetCacheLocations).get match {
case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map {
case (i, arr) => (i -> arr.toList)
}


@ -2,6 +2,8 @@ package spark
import org.scalatest.FunSuite
import spark.scheduler.mesos.MesosScheduler
class MesosSchedulerSuite extends FunSuite {
test("memoryStringToMb"){


@ -2,7 +2,7 @@ package spark
import org.scalatest.FunSuite
import java.io.{ByteArrayOutputStream, ByteArrayInputStream}
import util.Random
import scala.util.Random
class UtilsSuite extends FunSuite {


@ -33,6 +33,7 @@ object SparkBuild extends Build {
"org.scalatest" %% "scalatest" % "1.6.1" % "test",
"org.scala-tools.testing" %% "scalacheck" % "1.9" % "test"
),
parallelExecution in Test := false,
/* Workaround for issue #206 (fixed after SBT 0.11.0) */
watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task,
const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) }
@ -57,8 +58,12 @@ object SparkBuild extends Build {
"asm" % "asm-all" % "3.3.1",
"com.google.protobuf" % "protobuf-java" % "2.4.1",
"de.javakaffee" % "kryo-serializers" % "0.9",
"se.scalablesolutions.akka" % "akka-actor" % "1.3.1",
"se.scalablesolutions.akka" % "akka-remote" % "1.3.1",
"se.scalablesolutions.akka" % "akka-slf4j" % "1.3.1",
"org.jboss.netty" % "netty" % "3.2.6.Final",
"it.unimi.dsi" % "fastutil" % "6.4.2"
"it.unimi.dsi" % "fastutil" % "6.4.4",
"colt" % "colt" % "1.2.0"
)
) ++ assemblySettings ++ Seq(test in assembly := {})
@ -68,8 +73,7 @@ object SparkBuild extends Build {
) ++ assemblySettings ++ Seq(test in assembly := {})
def examplesSettings = sharedSettings ++ Seq(
name := "spark-examples",
libraryDependencies += "colt" % "colt" % "1.2.0"
name := "spark-examples"
)
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")