Mirror of https://github.com/microsoft/spark.git
Conflict fixed
This commit is contained in:
@ -13,6 +13,8 @@ third_party/libmesos.dylib
@ -28,3 +30,4 @@ project/plugins/lib_managed/
@ -6,16 +6,14 @@ Lightning-Fast Cluster Computing - <http://www.spark-project.org/>
## Online Documentation
You can find the latest Spark documentation, including a programming
guide, on the project wiki at <http://github.com/mesos/spark/wiki>. This
file only contains basic setup instructions.
guide, on the project webpage at <http://spark-project.org/documentation.html>.
This README file only contains basic setup instructions.
## Building
Spark requires Scala 2.9.1. This version has been tested with 2.9.1.final.
The project is built using Simple Build Tool (SBT), which is packaged with it.
To build Spark and its example programs, run:
Spark requires Scala 2.9.2. The project is built using Simple Build Tool (SBT),
which is packaged with it. To build Spark and its example programs, run:
sbt/sbt compile
@ -142,7 +142,7 @@ class WPRSerializerInstance extends SerializerInstance {
class WPRSerializationStream(os: OutputStream) extends SerializationStream {
val dos = new DataOutputStream(os)
def writeObject[T](t: T): Unit = t match {
def writeObject[T](t: T): SerializationStream = t match {
case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match {
case links: Array[String] => {
dos.writeInt(0) // links
@ -151,17 +151,20 @@ class WPRSerializationStream(os: OutputStream) extends SerializationStream {
for (link <- links) {
case rank: Double => {
dos.writeInt(1) // rank
case (id: String, rank: Double) => {
dos.writeInt(2) // rank without wrapper
@ -1,8 +1,10 @@
# Set everything to be logged to the console
log4j.rootCategory=WARN, console
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
log4j.rootCategory=INFO, file
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
@ -1,5 +1,8 @@
#!/usr/bin/env bash
# This Spark deploy script is a modified version of the Apache Hadoop deploy
# script, available under the Apache 2 license:
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
@ -1,5 +1,8 @@
#!/usr/bin/env bash
# This Spark deploy script is a modified version of the Apache Hadoop deploy
# script, available under the Apache 2 license:
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
@ -3,6 +3,7 @@ package spark
import java.io._
import scala.collection.mutable.Map
import scala.collection.generic.Growable
* A datatype that can be accumulated, i.e. has a commutative and associative +.
@ -92,6 +93,29 @@ trait AccumulableParam[R, T] extends Serializable {
def zero(initialValue: R): R
class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
extends AccumulableParam[R,T] {
def addAccumulator(growable: R, elem: T) : R = {
growable += elem
def addInPlace(t1: R, t2: R) : R = {
t1 ++= t2
def zero(initialValue: R): R = {
// We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
// Instead we'll serialize it to a buffer and load it back.
val ser = (new spark.JavaSerializer).newInstance()
val copy = ser.deserialize[R](ser.serialize(initialValue))
copy.clear() // In case it contained stuff
* A simpler version of [[spark.Accumulable]] where the result type being accumulated is the same
* as the type of the elements being merged.
@ -9,9 +9,9 @@ package spark
* known as map-side aggregations. When set to false,
* mergeCombiners function is not used.
class Aggregator[K, V, C] (
case class Aggregator[K, V, C] (
val createCombiner: V => C,
val mergeValue: (C, V) => C,
val mergeCombiners: (C, C) => C,
val mapSideCombine: Boolean = true)
extends Serializable
@ -2,12 +2,13 @@ package spark
import scala.collection.mutable.HashMap
class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
val index = idx
class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) {
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc) {
val splits_ = (0 until blockIds.size).map(i => {
@ -11,8 +11,7 @@ import spark.storage.BlockManagerId
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
@ -29,39 +28,32 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map {
case (address, splits) =>
(address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId)))
(address, splits.map(i => "shuffle_%d_%d_%d".format(shuffleId, i, reduceId)))
try {
for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) {
blockOption match {
case Some(block) => {
val values = block
for(value <- values) {
val v = value.asInstanceOf[(K, V)]
func(v._1, v._2)
case None => {
throw new BlockException(blockId, "Did not get block " + blockId)
for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) {
blockOption match {
case Some(block) => {
val values = block
for(value <- values) {
val v = value.asInstanceOf[(K, V)]
func(v._1, v._2)
} catch {
// TODO: this is really ugly -- let's find a better way of throwing a FetchFailedException
case be: BlockException => {
val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]]*)".r
be.blockId match {
case regex(sId, mId, rId) => {
val address = addresses(mId.toInt)
throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be)
case _ => {
throw be
case None => {
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
case regex(shufId, mapId, reduceId) =>
val addr = addresses(mapId.toInt)
throw new FetchFailedException(addr, shufId.toInt, mapId.toInt, reduceId.toInt, null)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block")
logDebug("Fetching and merging outputs of shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
@ -9,7 +9,7 @@ import java.util.LinkedHashMap
* some cache entries have pointers to a shared object. Nonetheless, this Cache should work well
* when most of the space is used by arrays of primitives or of simple classes.
class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
def this() {
@ -104,9 +104,9 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
// An entry in our map; stores a cached object and its size in bytes
case class Entry(value: Any, size: Long)
private[spark] case class Entry(value: Any, size: Long)
object BoundedMemoryCache {
private[spark] object BoundedMemoryCache {
* Get maximum cache capacity from system configuration
@ -2,9 +2,9 @@ package spark
import java.util.concurrent.atomic.AtomicInteger
sealed trait CachePutResponse
case class CachePutSuccess(size: Long) extends CachePutResponse
case class CachePutFailure() extends CachePutResponse
private[spark] sealed trait CachePutResponse
private[spark] case class CachePutSuccess(size: Long) extends CachePutResponse
private[spark] case class CachePutFailure() extends CachePutResponse
* An interface for caches in Spark, to allow for multiple implementations. Caches are used to store
@ -22,7 +22,7 @@ case class CachePutFailure() extends CachePutResponse
* This abstract class handles the creation of key spaces, so that subclasses need only deal with
* keys that are unique across modules.
abstract class Cache {
private[spark] abstract class Cache {
private val nextKeySpaceId = new AtomicInteger(0)
private def newKeySpaceId() = nextKeySpaceId.getAndIncrement()
@ -52,7 +52,7 @@ abstract class Cache {
* A key namespace in a Cache.
class KeySpace(cache: Cache, val keySpaceId: Int) {
private[spark] class KeySpace(cache: Cache, val keySpaceId: Int) {
def get(datasetId: Any, partition: Int): Any =
cache.get((keySpaceId, datasetId), partition)
@ -15,19 +15,20 @@ import scala.collection.mutable.HashSet
import spark.storage.BlockManager
import spark.storage.StorageLevel
sealed trait CacheTrackerMessage
case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
case class MemoryCacheLost(host: String) extends CacheTrackerMessage
case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
case object GetCacheStatus extends CacheTrackerMessage
case object GetCacheLocations extends CacheTrackerMessage
case object StopCacheTracker extends CacheTrackerMessage
private[spark] sealed trait CacheTrackerMessage
class CacheTrackerActor extends Actor with Logging {
private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
private[spark] case object GetCacheStatus extends CacheTrackerMessage
private[spark] case object GetCacheLocations extends CacheTrackerMessage
private[spark] case object StopCacheTracker extends CacheTrackerMessage
private[spark] class CacheTrackerActor extends Actor with Logging {
// TODO: Should probably store (String, CacheType) tuples
private val locs = new HashMap[Int, Array[List[String]]]
@ -43,8 +44,6 @@ class CacheTrackerActor extends Actor with Logging {
def receive = {
case SlaveCacheStarted(host: String, size: Long) =>
logInfo("Started slave cache (size %s) on %s".format(
Utils.memoryBytesToString(size), host))
slaveCapacity.put(host, size)
slaveUsage.put(host, 0)
sender ! true
@ -56,22 +55,12 @@ class CacheTrackerActor extends Actor with Logging {
case AddedToCache(rddId, partition, host, size) =>
slaveUsage.put(host, getCacheUsage(host) + size)
logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
locs(rddId)(partition) = host :: locs(rddId)(partition)
sender ! true
case DroppedFromCache(rddId, partition, host, size) =>
logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
slaveUsage.put(host, getCacheUsage(host) - size)
// Do a sanity check to make sure usage is not negative.
val usage = getCacheUsage(host)
if (usage < 0) {
logError("Cache usage on %s is negative (%d)".format(host, usage))
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
sender ! true
@ -101,7 +90,7 @@ class CacheTrackerActor extends Actor with Logging {
class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
extends Logging {
// Tracker actor on the master, or remote reference to it on workers
@ -151,7 +140,6 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
logInfo("Registering RDD ID " + rddId + " with cache")
registeredRddIds += rddId
communicate(RegisterRDD(rddId, numPartitions))
logInfo(RegisterRDD(rddId, numPartitions) + " successful")
@ -171,7 +159,6 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
// For BlockManager.scala only
def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) {
logInfo("notifyTheCacheTrackerFromBlockManager successful")
// Get a snapshot of the currently known locations
@ -181,7 +168,7 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
// Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
val key = "rdd:%d:%d".format(rdd.id, split.index)
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
@ -223,7 +210,7 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
logInfo("Computing partition " + split)
try {
// BlockManager will iterate over results from compute to create RDD
blockManager.put(key, rdd.compute(split), storageLevel, false)
blockManager.put(key, rdd.compute(split), storageLevel, true)
//future.apply() // Wait for the reply from the cache tracker
blockManager.get(key) match {
case Some(values) =>
@ -1,5 +1,6 @@
package spark
class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
override val index: Int = idx
@ -9,7 +9,7 @@ import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._
object ClosureCleaner extends Logging {
private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
private def getClassReader(cls: Class[_]): ClassReader = {
new ClassReader(cls.getResourceAsStream(
@ -154,7 +154,7 @@ object ClosureCleaner extends Logging {
class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
return new EmptyVisitor {
@ -180,7 +180,7 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor
class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
var myName: String = null
override def visit(version: Int, access: Int, name: String, sig: String,
@ -6,16 +6,17 @@ import java.io.ObjectInputStream
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
sealed trait CoGroupSplitDep extends Serializable
case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark] sealed trait CoGroupSplitDep extends Serializable
private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
class CoGroupAggregator
private[spark] class CoGroupAggregator
extends Aggregator[Any, Any, ArrayBuffer[Any]](
{ x => ArrayBuffer(x) },
{ (b, x) => b += x },
@ -0,0 +1,43 @@
package spark
private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
* Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of
* this RDD computes one or more of the parent ones. Will produce exactly `maxPartitions` if the
* parent had more than this many partitions, or fewer if the parent had fewer.
* This transformation is useful when an RDD with many partitions gets filtered into a smaller one,
* or to avoid having a large number of small tasks when processing a directory with many files.
class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
extends RDD[T](prev.context) {
@transient val splits_ : Array[Split] = {
val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) {
prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) }
} else {
(0 until maxPartitions).map { i =>
val rangeStart = (i * prevSplits.length) / maxPartitions
val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd))
override def splits = splits_
override def compute(split: Split): Iterator[T] = {
split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
parentSplit => prev.iterator(parentSplit)
val dependencies = List(
new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
@ -6,9 +6,13 @@ import java.util.concurrent.ThreadFactory
* A ThreadFactory that creates daemon threads
private object DaemonThreadFactory extends ThreadFactory {
override def newThread(r: Runnable): Thread = {
val t = new Thread(r)
return t
override def newThread(r: Runnable): Thread = new DaemonThread(r)
private class DaemonThread(r: Runnable = null) extends Thread {
override def run() {
if (r != null) {
@ -2,7 +2,7 @@ package spark
import spark.storage.BlockManagerId
class FetchFailedException(
private[spark] class FetchFailedException(
val bmAddress: BlockManagerId,
val shuffleId: Int,
val mapId: Int,
@ -18,7 +18,7 @@ import org.apache.hadoop.util.ReflectionUtils
* A Spark split class that wraps around a Hadoop InputSplit.
class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
extends Split
with Serializable {
@ -42,7 +42,8 @@ class HadoopRDD[K, V](
minSplits: Int)
extends RDD[(K, V)](sc) {
val serializableConf = new SerializableWritable(conf)
// A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
val splits_ : Array[Split] = {
@ -66,7 +67,7 @@ class HadoopRDD[K, V](
val split = theSplit.asInstanceOf[HadoopSplit]
var reader: RecordReader[K, V] = null
val conf = serializableConf.value
val conf = confBroadcast.value.value
val fmt = createInputFormat(conf)
reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
@ -0,0 +1,47 @@
package spark
import java.io.{File, PrintWriter}
import java.net.URL
import scala.collection.mutable.HashMap
import org.apache.hadoop.fs.FileUtil
private[spark] class HttpFileServer extends Logging {
var baseDir : File = null
var fileDir : File = null
var jarDir : File = null
var httpServer : HttpServer = null
var serverUri : String = null
def initialize() {
baseDir = Utils.createTempDir()
fileDir = new File(baseDir, "files")
jarDir = new File(baseDir, "jars")
logInfo("HTTP File server directory is " + baseDir)
httpServer = new HttpServer(baseDir)
serverUri = httpServer.uri
def stop() {
def addFile(file: File) : String = {
addFileToDir(file, fileDir)
return serverUri + "/files/" + file.getName
def addJar(file: File) : String = {
addFileToDir(file, jarDir)
return serverUri + "/jars/" + file.getName
def addFileToDir(file: File, dir: File) : String = {
Utils.copyFile(file, new File(dir, file.getName))
return dir + "/" + file.getName
@ -12,14 +12,14 @@ import org.eclipse.jetty.util.thread.QueuedThreadPool
* Exception type thrown by HttpServer when it is in the wrong state for an operation.
class ServerStateException(message: String) extends Exception(message)
private[spark] class ServerStateException(message: String) extends Exception(message)
* An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext
* as well as classes created by the interpreter when the user types in code. This is just a wrapper
* around a Jetty server.
class HttpServer(resourceBase: File) extends Logging {
private[spark] class HttpServer(resourceBase: File) extends Logging {
private var server: Server = null
private var port: Int = -1
@ -5,14 +5,14 @@ import java.nio.ByteBuffer
import spark.util.ByteBufferInputStream
class JavaSerializationStream(out: OutputStream) extends SerializationStream {
private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream {
val objOut = new ObjectOutputStream(out)
def writeObject[T](t: T) { objOut.writeObject(t) }
def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this }
def flush() { objOut.flush() }
def close() { objOut.close() }
class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
extends DeserializationStream {
val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
@ -23,7 +23,7 @@ extends DeserializationStream {
def close() { objIn.close() }
class JavaSerializerInstance extends SerializerInstance {
private[spark] class JavaSerializerInstance extends SerializerInstance {
def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@ -57,6 +57,6 @@ class JavaSerializerInstance extends SerializerInstance {
class JavaSerializer extends Serializer {
private[spark] class JavaSerializer extends Serializer {
def newInstance(): SerializerInstance = new JavaSerializerInstance
@ -10,15 +10,17 @@ import scala.collection.mutable
import com.esotericsoftware.kryo._
import com.esotericsoftware.kryo.{Serializer => KSerializer}
import com.esotericsoftware.kryo.serialize.ClassSerializer
import com.esotericsoftware.kryo.serialize.SerializableSerializer
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
import spark.broadcast._
import spark.storage._
* Zig-zag encoder used to write object sizes to serialization streams.
* Based on Kryo's integer encoder.
object ZigZag {
private[spark] object ZigZag {
def writeInt(n: Int, out: OutputStream) {
var value = n
if ((value & ~0x7F) == 0) {
@ -66,22 +68,25 @@ object ZigZag {
class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
extends SerializationStream {
val channel = Channels.newChannel(out)
def writeObject[T](t: T) {
def writeObject[T](t: T): SerializationStream = {
kryo.writeClassAndObject(threadBuffer, t)
ZigZag.writeInt(threadBuffer.position(), out)
def flush() { out.flush() }
def close() { out.close() }
class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
extends DeserializationStream {
def readObject[T](): T = {
@ -92,7 +97,7 @@ extends DeserializationStream {
def close() { in.close() }
class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
val kryo = ks.kryo
val threadBuffer = ks.threadBuffer.get()
val objectBuffer = ks.objectBuffer.get()
@ -159,7 +164,9 @@ trait KryoRegistrator {
class KryoSerializer extends Serializer with Logging {
val kryo = createKryo()
// Make this lazy so that it only gets called once we receive our first task on each executor,
// so we can pull out any custom Kryo registrator from the user's JARs.
lazy val kryo = createKryo()
val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
@ -190,8 +197,8 @@ class KryoSerializer extends Serializer with Logging {
(1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1),
PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER),
PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
GotBlock("1", ByteBuffer.allocate(1)),
@ -203,6 +210,10 @@ class KryoSerializer extends Serializer with Logging {
kryo.register(classOf[Class[_]], new ClassSerializer(kryo))
// Allow sending SerializableWritable
kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer())
kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer())
// Register some commonly used Scala singleton objects. Because these
// are singletons, we must return the exact same local object when we
// deserialize rather than returning a clone as FieldSerializer would.
@ -250,7 +261,8 @@ class KryoSerializer extends Serializer with Logging {
val regCls = System.getProperty("spark.kryo.registrator")
if (regCls != null) {
logInfo("Running user registrator: " + regCls)
val reg = Class.forName(regCls).newInstance().asInstanceOf[KryoRegistrator]
val classLoader = Thread.currentThread.getContextClassLoader
val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]
@ -1,5 +1,6 @@
package spark
import java.io.{DataInputStream, DataOutputStream, ByteArrayOutputStream, ByteArrayInputStream}
import java.util.concurrent.ConcurrentHashMap
import akka.actor._
@ -10,20 +11,20 @@ import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import spark.storage.BlockManagerId
sealed trait MapOutputTrackerMessage
case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
case object StopMapOutputTracker extends MapOutputTrackerMessage
private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]])
extends Actor with Logging {
private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
def receive = {
case GetMapOutputLocations(shuffleId: Int) =>
logInfo("Asked to get map output locations for shuffle " + shuffleId)
sender ! bmAddresses.get(shuffleId)
sender ! tracker.getSerializedLocations(shuffleId)
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
@ -32,22 +33,26 @@ extends Actor with Logging {
class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "MapOutputTracker"
val timeout = 10.seconds
private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var generation: Long = 0
private var generationLock = new java.lang.Object
// Cache a serialized version of the output locations for each shuffle to send them out faster
var cacheGeneration = generation
val cachedSerializedLocs = new HashMap[Int, Array[Byte]]
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(bmAddresses)), name = actorName)
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
logInfo("Registered MapOutputTrackerActor actor")
} else {
@ -134,15 +139,16 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val fetched = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[BlockManagerId]]
val fetchedBytes = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[Byte]]
val fetchedLocs = deserializeLocations(fetchedBytes)
logInfo("Got the output locations")
bmAddresses.put(shuffleId, fetched)
bmAddresses.put(shuffleId, fetchedLocs)
fetching.synchronized {
fetching -= shuffleId
return fetched
return fetchedLocs
} else {
return locs
@ -181,4 +187,70 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
def getSerializedLocations(shuffleId: Int): Array[Byte] = {
var locs: Array[BlockManagerId] = null
var generationGotten: Long = -1
generationLock.synchronized {
if (generation > cacheGeneration) {
cacheGeneration = generation
cachedSerializedLocs.get(shuffleId) match {
case Some(bytes) =>
return bytes
case None =>
locs = bmAddresses.get(shuffleId)
generationGotten = generation
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
val bytes = serializeLocations(locs)
// Add them into the table only if the generation hasn't changed while we were working
generationLock.synchronized {
if (generation == generationGotten) {
cachedSerializedLocs(shuffleId) = bytes
return bytes
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by grouping together the locations by block manager ID.
def serializeLocations(locs: Array[BlockManagerId]): Array[Byte] = {
val out = new ByteArrayOutputStream
val dataOut = new DataOutputStream(out)
val grouped = locs.zipWithIndex.groupBy(_._1)
for ((id, pairs) <- grouped if id != null) {
for ((_, blockIndex) <- pairs) {
// Opposite of serializeLocations.
def deserializeLocations(bytes: Array[Byte]): Array[BlockManagerId] = {
val dataIn = new DataInputStream(new ByteArrayInputStream(bytes))
val length = dataIn.readInt()
val array = new Array[BlockManagerId](length)
val numGroups = dataIn.readInt()
for (i <- 0 until numGroups) {
val ip = dataIn.readUTF()
val port = dataIn.readInt()
val id = new BlockManagerId(ip, port)
val numBlocks = dataIn.readInt()
for (j <- 0 until numBlocks) {
array(dataIn.readInt()) = id
@ -13,6 +13,7 @@ import org.apache.hadoop.mapreduce.TaskAttemptID
import java.util.Date
import java.text.SimpleDateFormat
class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
extends Split {
@ -28,7 +29,9 @@ class NewHadoopRDD[K, V](
@transient conf: Configuration)
extends RDD[(K, V)](sc) {
private val serializableConf = new SerializableWritable(conf)
// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
// private val serializableConf = new SerializableWritable(conf)
private val jobtrackerId: String = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
@ -41,7 +44,7 @@ class NewHadoopRDD[K, V](
private val splits_ : Array[Split] = {
val inputFormat = inputFormatClass.newInstance
val jobContext = new JobContext(serializableConf.value, jobId)
val jobContext = new JobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Split](rawSplits.size)
for (i <- 0 until rawSplits.size) {
@ -54,9 +57,9 @@ class NewHadoopRDD[K, V](
override def compute(theSplit: Split) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopSplit]
val conf = serializableConf.value
val conf = confBroadcast.value.value
val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
val context = new TaskAttemptContext(serializableConf.value, attemptId)
val context = new TaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
reader.initialize(split.serializableHadoopSplit.value, context)
@ -1,11 +1,10 @@
package spark
import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.net.URL
import java.util.{Date, HashMap => JHashMap}
import java.util.concurrent.atomic.AtomicLong
import java.util.{HashMap => JHashMap}
import java.util.Date
import java.text.SimpleDateFormat
import scala.collection.Map
@ -50,9 +49,18 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
partitioner: Partitioner): RDD[(K, C)] = {
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
new ShuffledRDD(self, aggregator, partitioner)
partitioner: Partitioner,
mapSideCombine: Boolean = true): RDD[(K, C)] = {
val aggregator =
if (mapSideCombine) {
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
new Aggregator[K, V, C](createCombiner, mergeValue, null, false)
new ShuffledAggregatedRDD(self, aggregator, partitioner)
def combineByKey[C](createCombiner: V => C,
@ -65,7 +73,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = {
combineByKey[V]((v: V) => v, func, func, partitioner)
def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
val map = new JHashMap[K, V]
@ -116,13 +124,24 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
groupByKey(new HashPartitioner(numSplits))
def partitionBy(partitioner: Partitioner): RDD[(K, V)] = {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
createCombiner _, mergeValue _, mergeCombiners _, partitioner)
bufs.flatMapValues(buf => buf)
* Repartition the RDD using the specified partitioner. If mapSideCombine is
* true, Spark will group values of the same key together on the map side
* before the repartitioning. If a large number of duplicated keys are
* expected, and the size of the keys is large, mapSideCombine should be set
* to true.
def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = {
if (mapSideCombine) {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
createCombiner _, mergeValue _, mergeCombiners _, partitioner)
bufs.flatMapValues(buf => buf)
} else {
new RepartitionShuffledRDD(self, partitioner)
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
@ -194,17 +213,17 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def collectAsMap(): Map[K, V] = HashMap(self.collect(): _*)
def mapValues[U](f: V => U): RDD[(K, U)] = {
val cleanF = self.context.clean(f)
new MappedValuesRDD(self, cleanF)
def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = {
val cleanF = self.context.clean(f)
new FlatMappedValuesRDD(self, cleanF)
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = {
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]),
@ -215,12 +234,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
(vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]],
other1.asInstanceOf[RDD[(_, _)]],
other1.asInstanceOf[RDD[(_, _)]],
other2.asInstanceOf[RDD[(_, _)]]),
val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
@ -289,7 +308,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
@ -363,7 +382,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf))
def saveAsHadoopDataset(conf: JobConf) {
val outputFormatClass = conf.getOutputFormat
val keyClass = conf.getOutputKeyClass
@ -377,7 +396,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
if (valueClass == null) {
throw new SparkException("Output value class not set")
logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")")
val writer = new HadoopWriter(conf)
@ -390,14 +409,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
writer.setup(context.stageId, context.splitId, attemptNumber)
var count = 0
while(iter.hasNext) {
val record = iter.next
count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
@ -413,28 +432,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
self: RDD[(K, V)])
extends Logging
extends Logging
with Serializable {
def sortByKey(ascending: Boolean = true): RDD[(K,V)] = {
val rangePartitionedRDD = self.partitionBy(new RangePartitioner(self.splits.size, self, ascending))
new SortedRDD(rangePartitionedRDD, ascending)
def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = {
new ShuffledSortedRDD(self, ascending, numSplits)
class SortedRDD[K <% Ordered[K], V](prev: RDD[(K, V)], ascending: Boolean)
extends RDD[(K, V)](prev.context) {
override def splits = prev.splits
override val partitioner = prev.partitioner
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = {
.sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1).iterator
class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
@ -444,7 +449,7 @@ class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)]
class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U])
extends RDD[(K, U)](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override val partitioner = prev.partitioner
@ -454,6 +459,6 @@ class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]
object Manifests {
private[spark] object Manifests {
val seqSeqManifest = classManifest[Seq[Seq[_]]]
@ -3,7 +3,7 @@ package spark
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
class ParallelCollectionSplit[T: ClassManifest](
private[spark] class ParallelCollectionSplit[T: ClassManifest](
val rddId: Long,
val slice: Int,
values: Seq[T])
@ -41,9 +41,9 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
} else {
val rddSize = rdd.count()
val maxSampleSize = partitions * 10.0
val maxSampleSize = partitions * 20.0
val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _)
val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _)
if (rddSample.length == 0) {
} else {
@ -61,6 +61,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def compute(split: Split): Iterator[T]
@transient val dependencies: List[Dependency[_]]
// Record user function generating this RDD
val origin = Utils.getSparkCallSite
// Optionally overridden by subclasses to specify how they are partitioned
val partitioner: Option[Partitioner] = None
@ -68,6 +71,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def preferredLocations(split: Split): Seq[String] = Nil
def context = sc
def elementClassManifest: ClassManifest[T] = classManifest[T]
// Get a unique ID for this RDD
val id = sc.newRddId()
@ -87,21 +92,21 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
// Turn on the default caching level for this RDD
def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER)
def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
// Turn on the default caching level for this RDD
def cache(): RDD[T] = persist()
def getStorageLevel = storageLevel
def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = {
private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
if (!level.useDisk && level.replication < 2) {
throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
// This is a hack. Ideally this should re-use the code used by the CacheTracker
// to generate the key.
def getSplitKey(split: Split) = "rdd:%d:%d".format(this.id, split.index)
def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
sc.runJob(this, (iter: Iterator[T]) => {} )
@ -131,7 +136,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f))
def distinct(): RDD[T] = map(x => (x, "")).reduceByKey((x, y) => x).map(_._1)
def distinct(numSplits: Int = splits.size): RDD[T] =
map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1)
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
new SampledRDD(this, withReplacement, fraction, seed)
@ -143,8 +149,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
var initialCount = count()
var maxSelected = 0
if (initialCount > Integer.MAX_VALUE) {
maxSelected = Integer.MAX_VALUE
if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - 1
} else {
maxSelected = initialCount.toInt
@ -159,15 +165,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
total = num
var samples = this.sample(withReplacement, fraction, seed).collect()
val rand = new Random(seed)
var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
while (samples.length < total) {
samples = this.sample(withReplacement, fraction, seed).collect()
samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
val arr = samples.take(total)
return arr
Utils.randomizeInPlace(samples, rand).take(total)
def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
@ -195,6 +200,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f))
def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
new MapPartitionsWithSplitRDD(this, sc.clean(f))
// Actions (launch a job to return a value to the user program)
def foreach(f: T => Unit) {
@ -416,3 +424,18 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(prev.iterator(split))
* A variant of the MapPartitionsRDD that passes the split index into the
* closure. This can be used to generate or collect partition-specific
* information such as the number of tuples in a partition.
class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: (Int, Iterator[T]) => Iterator[U])
extends RDD[U](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(split.index, prev.iterator(split))
@ -1,7 +1,10 @@
package spark
import java.util.Random
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
override val index: Int = prev.index
@ -28,19 +31,21 @@ class SampledRDD[T: ClassManifest](
override def compute(splitIn: Split) = {
val split = splitIn.asInstanceOf[SampledRDDSplit]
val rg = new Random(split.seed)
// Sampling with replacement (TODO: use reservoir sampling to make this more efficient?)
if (withReplacement) {
val oldData = prev.iterator(split.prev).toArray
val sampleSize = (oldData.size * frac).ceil.toInt
val sampledData = {
// all of oldData's indices are candidates, even if sampleSize < oldData.size
for (i <- 1 to sampleSize)
yield oldData(rg.nextInt(oldData.size))
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
val poisson = new Poisson(frac, new DRand(split.seed))
prev.iterator(split.prev).flatMap { element =>
val count = poisson.nextInt()
if (count == 0) {
Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
} else {
} else { // Sampling without replacement
prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac))
val rand = new Random(split.seed)
prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac))
@ -12,14 +12,14 @@ import spark.util.ByteBufferInputStream
* A serializer. Because some serialization libraries are not thread safe, this class is used to
* create SerializerInstances that do the actual serialization.
trait Serializer {
private[spark] trait Serializer {
def newInstance(): SerializerInstance
* An instance of the serializer, for use by one thread at a time.
trait SerializerInstance {
private[spark] trait SerializerInstance {
def serialize[T](t: T): ByteBuffer
def deserialize[T](bytes: ByteBuffer): T
@ -43,15 +43,15 @@ trait SerializerInstance {
def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
// Default implementation uses deserializeStream
deserializeStream(new ByteBufferInputStream(buffer)).toIterator
deserializeStream(new ByteBufferInputStream(buffer)).asIterator
* A stream for writing serialized objects.
trait SerializationStream {
def writeObject[T](t: T): Unit
private[spark] trait SerializationStream {
def writeObject[T](t: T): SerializationStream
def flush(): Unit
def close(): Unit
@ -66,7 +66,7 @@ trait SerializationStream {
* A stream for reading serialized objects.
trait DeserializationStream {
private[spark] trait DeserializationStream {
def readObject[T](): T
def close(): Unit
@ -74,7 +74,7 @@ trait DeserializationStream {
* Read the elements of this stream through an iterator. This can only be called once, as
* reading each element will consume data from the input source.
def toIterator: Iterator[Any] = new Iterator[Any] {
def asIterator: Iterator[Any] = new Iterator[Any] {
var gotNext = false
var finished = false
var nextValue: Any = null
@ -1,6 +1,6 @@
package spark
abstract class ShuffleFetcher {
private[spark] abstract class ShuffleFetcher {
// Fetch the shuffle outputs for a given ShuffleDependency, calling func exactly
// once on each key-value pair obtained.
def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit)
@ -1,98 +0,0 @@
package spark
import java.io._
import java.net.URL
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.{ArrayBuffer, HashMap}
import spark._
class ShuffleManager extends Logging {
private var nextShuffleId = new AtomicLong(0)
private var shuffleDir: File = null
private var server: HttpServer = null
private var serverUri: String = null
private def initialize() {
// TODO: localDir should be created by some mechanism common to Spark
// so that it can be shared among shuffle, broadcast, etc
val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
var tries = 0
var foundLocalDir = false
var localDir: File = null
var localDirUuid: UUID = null
while (!foundLocalDir && tries < 10) {
tries += 1
try {
localDirUuid = UUID.randomUUID
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
if (!localDir.exists) {
foundLocalDir = true
} catch {
case e: Exception =>
logWarning("Attempt " + tries + " to create local dir failed", e)
if (!foundLocalDir) {
logError("Failed 10 attempts to create local dir in " + localDirRoot)
shuffleDir = new File(localDir, "shuffle")
logInfo("Shuffle dir: " + shuffleDir)
// Add a shutdown hook to delete the local dir
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dir") {
override def run() {
val extServerPort = System.getProperty(
"spark.localFileShuffle.external.server.port", "-1").toInt
if (extServerPort != -1) {
// We're using an external HTTP server; set URI relative to its root
var extServerPath = System.getProperty(
"spark.localFileShuffle.external.server.path", "")
if (extServerPath != "" && !extServerPath.endsWith("/")) {
extServerPath += "/"
serverUri = "http://%s:%d/%s/spark-local-%s".format(
Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
} else {
// Create our own server
server = new HttpServer(localDir)
serverUri = server.uri
logInfo("Local URI: " + serverUri)
def stop() {
if (server != null) {
def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
val dir = new File(shuffleDir, shuffleId + "/" + inputId)
val file = new File(dir, "" + outputId)
return file
def getServerUri(): String = {
def newShuffleId(): Long = {
@ -1,29 +1,94 @@
package spark
import scala.collection.mutable.ArrayBuffer
import java.util.{HashMap => JHashMap}
class ShuffledRDDSplit(val idx: Int) extends Split {
private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx
override def hashCode(): Int = idx
class ShuffledRDD[K, V, C](
* The resulting RDD from a shuffle (e.g. repartitioning of data).
abstract class ShuffledRDD[K, V, C](
@transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
part : Partitioner)
part: Partitioner)
extends RDD[(K, C)](parent.context) {
//override val partitioner = Some(part)
override val partitioner = Some(part)
val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
override def splits = splits_
override def preferredLocations(split: Split) = Nil
val dep = new ShuffleDependency(context.newShuffleId, parent, aggregator, part)
override val dependencies = List(dep)
* Repartition a key-value pair RDD.
class RepartitionShuffledRDD[K, V](
@transient parent: RDD[(K, V)],
part: Partitioner)
extends ShuffledRDD[K, V, V](
Aggregator[K, V, V](null, null, null, false),
part) {
override def compute(split: Split): Iterator[(K, V)] = {
val buf = new ArrayBuffer[(K, V)]
val fetcher = SparkEnv.get.shuffleFetcher
def addTupleToBuffer(k: K, v: V) = { buf += Tuple(k, v) }
fetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer)
* A sort-based shuffle (that doesn't apply aggregation). It does so by first
* repartitioning the RDD by range, and then sorting within each range.
class ShuffledSortedRDD[K <% Ordered[K]: ClassManifest, V](
@transient parent: RDD[(K, V)],
ascending: Boolean,
numSplits: Int)
extends RepartitionShuffledRDD[K, V](
new RangePartitioner(numSplits, parent, ascending)) {
override def compute(split: Split): Iterator[(K, V)] = {
// By separating this from RepartitionShuffledRDD, we avoided a
// buf.iterator.toArray call, thus avoiding building up the buffer twice.
val buf = new ArrayBuffer[(K, V)]
def addTupleToBuffer(k: K, v: V) { buf += ((k, v)) }
SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer)
if (ascending) {
buf.sortWith((x, y) => x._1 < y._1).iterator
} else {
buf.sortWith((x, y) => x._1 > y._1).iterator
* The resulting RDD from shuffle and running (hash-based) aggregation.
class ShuffledAggregatedRDD[K, V, C](
@transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
part : Partitioner)
extends ShuffledRDD[K, V, C](parent, aggregator, part) {
override def compute(split: Split): Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
@ -22,7 +22,7 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet
* Based on the following JavaWorld article:
* http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html
object SizeEstimator extends Logging {
private[spark] object SizeEstimator extends Logging {
// Sizes of primitive types
private val BYTE_SIZE = 1
@ -77,22 +77,18 @@ object SizeEstimator extends Logging {
return System.getProperty("spark.test.useCompressedOops").toBoolean
try {
val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic";
val server = ManagementFactory.getPlatformMBeanServer();
val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"
val server = ManagementFactory.getPlatformMBeanServer()
val bean = ManagementFactory.newPlatformMXBeanProxy(server,
hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean]);
hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean])
return bean.getVMOption("UseCompressedOops").getValue.toBoolean
} catch {
case e: IllegalArgumentException => {
logWarning("Exception while trying to check if compressed oops is enabled", e)
// Fall back to checking if maxMemory < 32GB
return Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
case e: SecurityException => {
logWarning("No permission to create MBeanServer", e)
// Fall back to checking if maxMemory < 32GB
return Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
case e: Exception => {
// Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB
val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
val guessInWords = if (guess) "yes" else "not"
logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords)
return guess
@ -146,6 +142,10 @@ object SizeEstimator extends Logging {
val cls = obj.getClass
if (cls.isArray) {
visitArray(obj, cls, state)
} else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) {
// Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses
// the size estimator since it references the whole REPL. Do nothing in this case. In
// general all ClassLoaders and Classes will be shared between objects anyway.
} else {
val classInfo = getClassInfo(cls)
state.size += classInfo.shellSize
@ -5,7 +5,7 @@ import com.google.common.collect.MapMaker
* An implementation of Cache that uses soft references.
class SoftReferenceCache extends Cache {
private[spark] class SoftReferenceCache extends Cache {
val map = new MapMaker().softValues().makeMap[Any, Any]()
override def get(datasetId: Any, partition: Int): Any =
@ -2,13 +2,15 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader}
import akka.actor.Actor
import akka.actor.Actor._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.generic.Growable
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
@ -34,6 +36,8 @@ import org.apache.mesos.{Scheduler, MesosNativeLibrary}
import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
@ -51,7 +55,7 @@ class SparkContext(
val sparkHome: String,
val jars: Seq[String])
extends Logging {
def this(master: String, frameworkName: String) = this(master, frameworkName, null, Nil)
// Ensure logging is initialized before we spawn any threads
@ -75,24 +79,33 @@ class SparkContext(
// Used to store a URL for each static file/jar together with the file's local timestamp
val addedFiles = HashMap[String, Long]()
val addedJars = HashMap[String, Long]()
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r
master match {
case "local" =>
new LocalScheduler(1, 0)
case "local" =>
new LocalScheduler(1, 0, this)
case LOCAL_N_REGEX(threads) =>
new LocalScheduler(threads.toInt, 0)
case LOCAL_N_REGEX(threads) =>
new LocalScheduler(threads.toInt, 0, this)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
new LocalScheduler(threads.toInt, maxFailures.toInt)
new LocalScheduler(threads.toInt, maxFailures.toInt, this)
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
@ -100,6 +113,28 @@ class SparkContext(
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
// Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang.
val memoryPerSlaveInt = memoryPerSlave.toInt
val sparkMemEnv = System.getenv("SPARK_MEM")
val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512
if (sparkMemEnvInt > memoryPerSlaveInt) {
throw new SparkException(
"Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format(
memoryPerSlaveInt, sparkMemEnvInt))
val scheduler = new ClusterScheduler(this)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val sparkUrl = localCluster.start()
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
case _ =>
val scheduler = new ClusterScheduler(this)
@ -122,7 +157,7 @@ class SparkContext(
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
new ParallelCollection[T](this, seq, numSlices)
def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
parallelize(seq, numSlices)
@ -163,14 +198,14 @@ class SparkContext(
* Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
* Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly.
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
: RDD[(K, V)] = {
@ -191,7 +226,7 @@ class SparkContext(
new Configuration)
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
@ -207,7 +242,7 @@ class SparkContext(
new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf)
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
@ -233,14 +268,14 @@ class SparkContext(
sequenceFile(path, keyClass, valueClass, defaultMinSplits)
* Version of sequenceFile() for types implicitly convertible to Writables through a
* Version of sequenceFile() for types implicitly convertible to Writables through a
* WritableConverter.
* WritableConverters are provided in a somewhat strange way (by an implicit function) to support
* both subclasses of Writable and types for which we define a converter (e.g. Int to
* both subclasses of Writable and types for which we define a converter (e.g. Int to
* IntWritable). The most natural thing would've been to have implicit objects for the
* converters, but then we couldn't have an object for every subclass of Writable (you can't
* have a parameterized singleton object). We use functions instead to create a new converter
* have a parameterized singleton object). We use functions instead to create a new converter
* for the appropriate type. In addition, we pass the converter a ClassManifest of its type to
* allow it to figure out the Writable class to use in the subclass case.
@ -265,7 +300,7 @@ class SparkContext(
* that there's very little effort required to save arbitrary objects.
def objectFile[T: ClassManifest](
path: String,
path: String,
minSplits: Int = defaultMinSplits
): RDD[T] = {
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits)
@ -292,10 +327,57 @@ class SparkContext(
def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
new Accumulable(initialValue, param)
* Create an accumulator from a "mutable collection" type.
* Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
* standard mutable collections. So you can use this with mutable Map, Set, etc.
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
// Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
// Adds a file dependency to all Tasks executed in the future.
def addFile(path: String) {
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
case _ => path
addedFiles(key) = System.currentTimeMillis
// Fetch the file locally in case the task is executed locally
val filename = new File(path.split("/").last)
Utils.fetchFile(path, new File("."))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
def clearFiles() {
addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
// Adds a jar dependency to all Tasks executed in the future.
def addJar(path: String) {
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
case _ => path
addedJars(key) = System.currentTimeMillis
logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
def clearJars() {
addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
// Stop the SparkContext
def stop() {
@ -303,6 +385,9 @@ class SparkContext(
taskScheduler = null
// TODO: Cache.stop()?
// Clean up locally linked files
logInfo("Successfully stopped SparkContext")
@ -326,7 +411,7 @@ class SparkContext(
* Run a function on a given set of partitions in an RDD and return the results. This is the main
* entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
* whether the scheduler can run the computation on the master rather than shipping it out to the
* whether the scheduler can run the computation on the master rather than shipping it out to the
* cluster, for short actions like first().
def runJob[T, U: ClassManifest](
@ -335,22 +420,23 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
logInfo("Starting job...")
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, allowLocal)
logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: Iterator[T] => U,
func: Iterator[T] => U,
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
* Run a job on all partitions in an RDD and return the results in an array.
@ -371,10 +457,11 @@ class SparkContext(
evaluator: ApproximateEvaluator[U, R],
timeout: Long
): PartialResult[R] = {
logInfo("Starting job...")
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout)
logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
@ -396,7 +483,7 @@ class SparkContext(
private[spark] def newShuffleId(): Int = {
private var nextRddId = new AtomicInteger(0)
// Register a new RDD, returning its RDD ID
@ -424,7 +511,7 @@ object SparkContext {
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
@ -445,7 +532,7 @@ object SparkContext {
implicit def longToLongWritable(l: Long) = new LongWritable(l)
implicit def floatToFloatWritable(f: Float) = new FloatWritable(f)
implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d)
implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b)
@ -456,7 +543,7 @@ object SparkContext {
private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = {
def anyToWritable[U <% Writable](u: U): Writable = u
new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]],
arr.map(x => anyToWritable(x)).toArray)
@ -500,7 +587,7 @@ object SparkContext {
// Find the JAR that contains the class of a particular object
def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
@ -513,7 +600,7 @@ object SparkContext {
* that doesn't know the type of T when it is created. This sounds strange but is necessary to
* support converting subclasses of Writable to themselves (writableWritableConverter).
class WritableConverter[T](
private[spark] class WritableConverter[T](
val writableClass: ClassManifest[T] => Class[_ <: Writable],
val convert: Writable => T)
extends Serializable
@ -1,6 +1,8 @@
package spark
import akka.actor.ActorSystem
import akka.actor.ActorSystemImpl
import akka.remote.RemoteActorRefProvider
import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
@ -8,35 +10,45 @@ import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
import spark.util.AkkaUtils
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
* Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
* objects needs to have the right SparkEnv set. You can get the current environment with
* SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
class SparkEnv (
val actorSystem: ActorSystem,
val cache: Cache,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheTracker: CacheTracker,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager
val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer
) {
/** No-parameter constructor for unit tests. */
def this() = {
this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
def stop() {
// Akka's awaitTermination doesn't actually wait until the port is unbound, so sleep a bit
// Akka's awaitTermination doesn't actually wait until the port is unbound, so sleep a bit
@ -66,66 +78,49 @@ object SparkEnv {
System.setProperty("spark.master.port", boundPort.toString)
val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer")
val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
val classLoader = Thread.currentThread.getContextClassLoader
// Create an instance of the class named by the given Java system property, or by
// defaultClassName if the property is not set, and return it as a T
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
  // The system property wins; defaultClassName is used only when it is unset.
  val className = System.getProperty(propertyName, defaultClassName)
  // Load and initialize the class through the enclosing scope's classLoader, then build it
  // with its no-argument constructor.
  val loaded = Class.forName(className, true, classLoader)
  val instance = loaded.newInstance()
  instance.asInstanceOf[T]
}
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal)
val blockManager = new BlockManager(blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager
val shuffleManager = new ShuffleManager()
val connectionManager = blockManager.connectionManager
val broadcastManager = new BroadcastManager(isMaster)
val closureSerializerClass =
System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
val closureSerializer =
val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
val closureSerializer = instantiateClass[Serializer](
"spark.closure.serializer", "spark.JavaSerializer")
val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager)
blockManager.cacheTracker = cacheTracker
val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster)
val shuffleFetcherClass =
System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val shuffleFetcher =
if (System.getProperty("spark.stream.distributed", "false") == "true") {
val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
if (isLocal || !isMaster) {
(new Thread() {
override def run() {
println("Wait started")
println("Wait ended")
val receiverClass = Class.forName("spark.stream.TestStreamReceiver4")
val constructor = receiverClass.getConstructor(blockManagerClass)
val receiver = constructor.newInstance(blockManager)
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val httpFileServer = new HttpFileServer()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
new SparkEnv(
@ -7,10 +7,16 @@ import spark.storage.BlockManagerId
* tasks several times for "ephemeral" failures, and only report back failures that require some
* old stages to be resubmitted, such as shuffle map fetch failures.
sealed trait TaskEndReason
private[spark] sealed trait TaskEndReason
case object Success extends TaskEndReason
private[spark] case object Success extends TaskEndReason
case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
case class ExceptionFailure(exception: Throwable) extends TaskEndReason
case class OtherFailure(message: String) extends TaskEndReason
private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason
private[spark] case class OtherFailure(message: String) extends TaskEndReason
@ -2,7 +2,7 @@ package spark
import org.apache.mesos.Protos.{TaskState => MesosTaskState}
object TaskState
private[spark] object TaskState
extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") {
@ -2,7 +2,7 @@ package spark
import scala.collection.mutable.ArrayBuffer
class UnionSplit[T: ClassManifest](
private[spark] class UnionSplit[T: ClassManifest](
idx: Int,
rdd: RDD[T],
split: Split)
@ -1,18 +1,18 @@
package spark
import java.io._
import java.net.InetAddress
import java.net.{InetAddress, URL, URI}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import java.util.{Locale, UUID}
import scala.io.Source
* Various utility methods used by Spark.
object Utils {
private object Utils extends Logging {
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@ -116,22 +116,75 @@ object Utils {
copyStream(in, out, true)
/** Download a file from a given URL to the local filesystem */
def downloadFile(url: URL, localPath: String) {
val in = url.openStream()
val out = new FileOutputStream(localPath)
Utils.copyStream(in, out, true)
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
/**
 * Fetch the file named by `url` into `targetDir`, keeping the last path segment of the URL
 * as the local file name.
 *
 * - "http", "https" and "ftp" URLs are streamed with java.net.URL.
 * - "file" URLs and scheme-less paths are symlinked into `targetDir`, replacing any file
 *   already present under the same name.
 * - Every other scheme is read through the Hadoop FileSystem API.
 *
 * After fetching, .tar / .tar.gz / .tgz archives are unpacked inside `targetDir` and the
 * fetched file is made executable.
 *
 * @param url       location of the file to fetch
 * @param targetDir local directory that receives the file
 */
def fetchFile(url: String, targetDir: File) {
  val filename = url.split("/").last
  val targetFile = new File(targetDir, filename)
  val uri = new URI(url)
  uri.getScheme match {
    case "http" | "https" | "ftp" =>
      logInfo("Fetching " + url + " to " + targetFile)
      val in = new URL(url).openStream()
      val out = new FileOutputStream(targetFile)
      Utils.copyStream(in, out, true)
    case "file" | null =>
      // Remove the file if it already exists, otherwise creating the link fails
      if (targetFile.exists()) {
        targetFile.delete()
      }
      // Symlink the file locally. The link must point at a filesystem path, so strip the
      // "file:" scheme when there is one; scheme-less input is already a path.
      val sourcePath = uri.getScheme match {
        case "file" if uri.getPath != null => uri.getPath
        case _ => url
      }
      logInfo("Symlinking " + url + " to " + targetFile)
      FileUtil.symLink(sourcePath, targetFile.toString)
    case _ =>
      // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
      val conf = new Configuration()
      val fs = FileSystem.get(uri, conf)
      val in = fs.open(new Path(uri))
      val out = new FileOutputStream(targetFile)
      Utils.copyStream(in, out, true)
  }
  // Decompress the file if it's a .tar or .tar.gz; tar runs with targetDir as its working
  // directory, so the bare file name is the right argument here.
  if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
    logInfo("Untarring " + filename)
    Utils.execute(Seq("tar", "-xzf", filename), targetDir)
  } else if (filename.endsWith(".tar")) {
    logInfo("Untarring " + filename)
    Utils.execute(Seq("tar", "-xf", filename), targetDir)
  }
  // Make the file executable - that's necessary for scripts. Use the absolute path so the
  // call does not depend on the JVM's working directory being targetDir.
  FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
}
* Shuffle the elements of a collection into a random order, returning the
* result in a new collection. Unlike scala.util.Random.shuffle, this method
* uses a local random number generator, avoiding inter-thread contention.
def randomize[T](seq: TraversableOnce[T]): Seq[T] = {
val buf = new ArrayBuffer[T]()
buf ++= seq
val rand = new Random()
for (i <- (buf.size - 1) to 1 by -1) {
def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
* Shuffle the elements of an array into a random order, modifying the
* original array. Returns the original array.
/**
 * Shuffle `arr` in place with the Fisher-Yates algorithm and return the same array.
 *
 * @param arr  array to shuffle; it is modified
 * @param rand source of randomness; defaults to a fresh generator
 * @return `arr` itself, after shuffling
 */
def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
  for (i <- (arr.length - 1) to 1 by -1) {
    // j must range over 0..i inclusive. With nextInt(i) an element could never stay in
    // place, so only cyclic permutations were reachable (two elements were always swapped).
    val j = rand.nextInt(i + 1)
    val tmp = arr(j)
    arr(j) = arr(i)
    arr(i) = tmp
  }
  arr
}
@ -294,4 +347,43 @@ object Utils {
def execute(command: Seq[String]) {
execute(command, new File("."))
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
def getSparkCallSite: String = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
// Keep crawling up the stack trace until we find the first function not inside of the spark
// package. We track the last (shallowest) contiguous Spark method. This might be an RDD
// transformation, a SparkContext function (such as parallelize), or anything else that leads
// to instantiation of an RDD. We also track the first (deepest) user method, file, and line.
var lastSparkMethod = "<unknown>"
var firstUserFile = "<unknown>"
var firstUserLine = 0
var finished = false
for (el <- trace) {
if (!finished) {
if (el.getClassName.startsWith("spark.") && !el.getClassName.startsWith("spark.examples.")) {
lastSparkMethod = if (el.getMethodName == "<init>") {
// Spark method is a constructor; get its class name
el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
} else {
else {
firstUserLine = el.getLineNumber
firstUserFile = el.getFileName
finished = true
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
@ -33,6 +33,8 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct())
def distinct(numSplits: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numSplits))
def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD =
fromRDD(srdd.filter(x => f(x).booleanValue()))
@ -40,6 +40,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct())
def distinct(numSplits: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numSplits))
def filter(f: Function[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue()))
@ -19,6 +19,8 @@ JavaRDDLike[T, JavaRDD[T]] {
def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct())
def distinct(numSplits: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numSplits))
def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] =
wrapRDD(rdd.filter((x => f(x).booleanValue())))
@ -7,7 +7,7 @@ import scala.runtime.AbstractFunction1
* apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply
* isn't marked to allow that).
abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
def call(t: T): R
@ -7,7 +7,7 @@ import scala.runtime.AbstractFunction2
* apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply
* isn't marked to allow that).
abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] {
private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] {
def call(t1: T1, t2: T2): R
@ -11,14 +11,17 @@ import scala.math
import spark._
import spark.storage.StorageLevel
class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id)
with Logging
with Serializable {
def value = value_
def blockId: String = "broadcast_" + id
MultiTracker.synchronized {
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@ -45,7 +48,7 @@ extends Broadcast[T] with Logging with Serializable {
// Used only in Workers
@transient var ttGuide: TalkToGuide = null
@transient var hostAddress = Utils.localIpAddress
@transient var hostAddress = Utils.localIpAddress()
@transient var listenPort = -1
@transient var guidePort = -1
@ -53,7 +56,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must call this after all the variables have been created/initialized
if (!isLocal) {
def sendBroadcast() {
@ -106,20 +109,22 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource
// Register with the Tracker
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
private def readObject(in: ObjectInputStream) {
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
case Some(x) => x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
logInfo("Local host address: " + hostAddress)
@ -131,18 +136,17 @@ extends Broadcast[T] with Logging with Serializable {
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else {
logError("Reading Broadcasted variable " + uuid + " failed")
logError("Reading broadcast variable " + id + " failed")
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
logInfo("Reading broadcast variable " + id + " took " + time + " s")
@ -254,8 +258,8 @@ extends Broadcast[T] with Logging with Serializable {
def receiveBroadcast(variableUUID: UUID): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableUUID)
def receiveBroadcast(variableID: Long): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
@ -764,7 +768,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
@ -1025,9 +1029,12 @@ extends Broadcast[T] with Logging with Serializable {
class BitTorrentBroadcastFactory
private[spark] class BitTorrentBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new BitTorrentBroadcast[T](value_, isLocal)
def stop() = MultiTracker.stop
def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new BitTorrentBroadcast[T](value_, isLocal, id)
def stop() { MultiTracker.stop() }
@ -1,25 +1,20 @@
package spark.broadcast
import java.io._
import java.net._
import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.Map
import java.util.concurrent.atomic.AtomicLong
import spark._
trait Broadcast[T] extends Serializable {
val uuid = UUID.randomUUID
abstract class Broadcast[T](id: Long) extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.
override def toString = "spark.Broadcast(" + uuid + ")"
override def toString = "spark.Broadcast(" + id + ")"
class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable {
private var initialized = false
@ -49,14 +44,10 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
private def getBroadcastFactory: BroadcastFactory = {
if (broadcastFactory == null) {
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
private val nextBroadcastId = new AtomicLong(0)
def newBroadcast[T](value_ : T, isLocal: Boolean) = broadcastFactory.newBroadcast[T](value_, isLocal)
def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def isMaster = isMaster_
@ -6,8 +6,8 @@ package spark.broadcast
* BroadcastFactory implementation to instantiate a particular broadcast for the
* entire Spark job.
trait BroadcastFactory {
private[spark] trait BroadcastFactory {
def initialize(isMaster: Boolean): Unit
def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
@ -12,44 +12,47 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark._
import spark.storage.StorageLevel
class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def blockId: String = "broadcast_" + id
HttpBroadcast.synchronized {
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
if (!isLocal) {
HttpBroadcast.write(uuid, value_)
HttpBroadcast.write(id, value_)
// Called by JVM when deserializing an object
private def readObject(in: ObjectInputStream) {
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) => value_ = x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](uuid)
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
logInfo("Reading broadcast variable " + id + " took " + time + " s")
class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isMaster: Boolean) = HttpBroadcast.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal)
def stop() = HttpBroadcast.stop()
private[spark] class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def stop() { HttpBroadcast.stop() }
private object HttpBroadcast extends Logging {
@ -65,7 +68,7 @@ private object HttpBroadcast extends Logging {
synchronized {
if (!initialized) {
bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
compress = System.getProperty("spark.compress", "false").toBoolean
compress = System.getProperty("spark.broadcast.compress", "true").toBoolean
if (isMaster) {
@ -76,9 +79,12 @@ private object HttpBroadcast extends Logging {
def stop() {
if (server != null) {
server = null
synchronized {
if (server != null) {
server = null
initialized = false
@ -91,8 +97,8 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri)
def write(uuid: UUID, value: Any) {
val file = new File(broadcastDir, "broadcast-" + uuid)
def write(id: Long, value: Any) {
val file = new File(broadcastDir, "broadcast-" + id)
val out: OutputStream = if (compress) {
new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering
} else {
@ -104,8 +110,8 @@ private object HttpBroadcast extends Logging {
def read[T](uuid: UUID): T = {
val url = serverUri + "/broadcast-" + uuid
def read[T](id: Long): T = {
val url = serverUri + "/broadcast-" + id
var in = if (compress) {
new LZFInputStream(new URL(url).openStream()) // Does its own buffering
} else {
@ -2,8 +2,7 @@ package spark.broadcast
import java.io._
import java.net._
import java.util.{UUID, Random}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import java.util.Random
import scala.collection.mutable.Map
@ -18,7 +17,7 @@ extends Logging {
// Map to keep track of guides of ongoing broadcasts
var valueToGuideMap = Map[UUID, SourceInfo]()
var valueToGuideMap = Map[Long, SourceInfo]()
// Random number generator
var ranGen = new Random
@ -154,44 +153,44 @@ extends Logging {
val messageType = ois.readObject.asInstanceOf[Int]
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Receive the broadcast ID
val id = ois.readObject.asInstanceOf[Long]
// Receive hostAddress and listenPort
val gInfo = ois.readObject.asInstanceOf[SourceInfo]
// Add to the map
valueToGuideMap.synchronized {
valueToGuideMap += (uuid -> gInfo)
valueToGuideMap += (id -> gInfo)
logInfo ("New broadcast " + uuid + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
} else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Receive the broadcast ID
val id = ois.readObject.asInstanceOf[Long]
// Remove from the map
valueToGuideMap.synchronized {
valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToDefault)
valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
logInfo ("Broadcast " + uuid + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
} else if (messageType == FIND_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Receive the broadcast ID
val id = ois.readObject.asInstanceOf[Long]
var gInfo =
if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid)
if (valueToGuideMap.contains(id)) valueToGuideMap(id)
else SourceInfo("", SourceInfo.TxNotStartedRetry)
logDebug("Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
// Send reply back
@ -224,7 +223,7 @@ extends Logging {
def getGuideInfo(variableUUID: UUID): SourceInfo = {
def getGuideInfo(variableLong: Long): SourceInfo = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
@ -247,8 +246,8 @@ extends Logging {
// Send UUID and receive GuideInfo
// Send the broadcast ID and receive GuideInfo
gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
} catch {
@ -276,7 +275,7 @@ extends Logging {
return gInfo
def registerBroadcast(uuid: UUID, gInfo: SourceInfo) {
def registerBroadcast(id: Long, gInfo: SourceInfo) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
@ -286,8 +285,8 @@ extends Logging {
// Send UUID of this broadcast
// Send the ID of this broadcast
// Send this tracker's information
@ -303,7 +302,7 @@ extends Logging {
def unregisterBroadcast(uuid: UUID) {
def unregisterBroadcast(id: Long) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
@ -313,8 +312,8 @@ extends Logging {
// Send UUID of this broadcast
// Send the ID of this broadcast
// Receive ACK and throw it away
@ -383,10 +382,10 @@ extends Logging {
case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
extends Serializable
case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@ -7,7 +7,7 @@ import spark._
* Used to keep and pass around information of peers involved in a broadcast
case class SourceInfo (hostAddress: String,
private[spark] case class SourceInfo (hostAddress: String,
listenPort: Int,
totalBlocks: Int = SourceInfo.UnusedParam,
totalBytes: Int = SourceInfo.UnusedParam)
@ -26,7 +26,7 @@ extends Comparable[SourceInfo] with Logging {
* Helper Object of SourceInfo for its constants
object SourceInfo {
private[spark] object SourceInfo {
// Broadcast has not started yet! Should never happen.
val TxNotStartedRetry = -1
// Broadcast has already finished. Try default mechanism.
@ -10,14 +10,15 @@ import scala.math
import spark._
import spark.storage.StorageLevel
class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def blockId = "broadcast_" + id
MultiTracker.synchronized {
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@ -35,7 +36,7 @@ extends Broadcast[T] with Logging with Serializable {
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress
@transient var hostAddress = Utils.localIpAddress()
@transient var listenPort = -1
@transient var guidePort = -1
@ -43,7 +44,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must call this after all the variables have been created/initialized
if (!isLocal) {
def sendBroadcast() {
@ -84,20 +85,22 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource
// Register with the Tracker
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
private def readObject(in: ObjectInputStream) {
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
case Some(x) => x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
logInfo("Local host address: " + hostAddress)
@ -108,18 +111,17 @@ extends Broadcast[T] with Logging with Serializable {
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else {
logError("Reading Broadcasted variable " + uuid + " failed")
logError("Reading broadcast variable " + id + " failed")
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
logInfo("Reading broadcast variable " + id + " took " + time + " s")
@ -136,14 +138,14 @@ extends Broadcast[T] with Logging with Serializable {
serveMR = null
hostAddress = Utils.localIpAddress
hostAddress = Utils.localIpAddress()
listenPort = -1
stopBroadcast = false
def receiveBroadcast(variableUUID: UUID): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableUUID)
def receiveBroadcast(variableID: Long): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
@ -318,7 +320,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
@ -572,9 +574,12 @@ extends Broadcast[T] with Logging with Serializable {
class TreeBroadcastFactory
private[spark] class TreeBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new TreeBroadcast[T](value_, isLocal)
def stop() = MultiTracker.stop
def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id)
def stop() { MultiTracker.stop() }
@ -2,7 +2,7 @@ package spark.deploy
import scala.collection.Map
case class Command(
private[spark] case class Command(
mainClass: String,
arguments: Seq[String],
environment: Map[String, String]) {
@ -7,13 +7,15 @@ import scala.collection.immutable.List
import scala.collection.mutable.HashMap
sealed trait DeployMessage extends Serializable
private[spark] sealed trait DeployMessage extends Serializable
// Worker to Master
case class RegisterWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int)
extends DeployMessage
case class ExecutorStateChanged(
jobId: String,
execId: Int,
@ -23,11 +25,11 @@ case class ExecutorStateChanged(
// Master to Worker
case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
case class RegisterWorkerFailed(message: String) extends DeployMessage
case class KillExecutor(jobId: String, execId: Int) extends DeployMessage
private[spark] case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
private[spark] case class RegisterWorkerFailed(message: String) extends DeployMessage
private[spark] case class KillExecutor(jobId: String, execId: Int) extends DeployMessage
case class LaunchExecutor(
private[spark] case class LaunchExecutor(
jobId: String,
execId: Int,
jobDesc: JobDescription,
@ -38,33 +40,42 @@ case class LaunchExecutor(
// Client to Master
case class RegisterJob(jobDescription: JobDescription) extends DeployMessage
private[spark] case class RegisterJob(jobDescription: JobDescription) extends DeployMessage
// Master to Client
case class RegisteredJob(jobId: String) extends DeployMessage
case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String])
case class JobKilled(message: String)
// Internal message in Client
case object StopClient
private[spark] case object StopClient
// MasterWebUI To Master
case object RequestMasterState
private[spark] case object RequestMasterState
// Master to MasterWebUI
case class MasterState(uri : String, workers: List[WorkerInfo], activeJobs: List[JobInfo],
completedJobs: List[JobInfo])
// WorkerWebUI to Worker
case object RequestWorkerState
private[spark] case object RequestWorkerState
// Worker to WorkerWebUI
case class WorkerState(uri: String, workerId: String, executors: List[ExecutorRunner],
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
@ -1,6 +1,6 @@
package spark.deploy
object ExecutorState
private[spark] object ExecutorState
extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") {
@ -1,6 +1,6 @@
package spark.deploy
class JobDescription(
private[spark] class JobDescription(
val name: String,
val cores: Int,
val memoryPerSlave: Int,
@ -0,0 +1,58 @@
package spark.deploy
import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
import spark.deploy.worker.Worker
import spark.deploy.master.Master
import spark.util.AkkaUtils
import spark.{Logging, Utils}
import scala.collection.mutable.ArrayBuffer
class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging {
val localIpAddress = Utils.localIpAddress
var masterActor : ActorRef = _
var masterActorSystem : ActorSystem = _
var masterPort : Int = _
var masterUrl : String = _
val slaveActorSystems = ArrayBuffer[ActorSystem]()
val slaveActors = ArrayBuffer[ActorRef]()
def start() : String = {
logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.")
/* Start the Master */
// Bind the port to a fresh local name: destructuring into `masterPort` would declare a new
// local that shadows the class-level `masterPort` var and leave that member unassigned.
val (actorSystem, boundMasterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
masterActorSystem = actorSystem
masterPort = boundMasterPort
masterUrl = "spark://" + localIpAddress + ":" + masterPort
val actor = masterActorSystem.actorOf(
Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
masterActor = actor
/* Start the Slaves */
for (slaveNum <- 1 to numSlaves) {
val (actorSystem, boundPort) =
AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0)
slaveActorSystems += actorSystem
val actor = actorSystem.actorOf(
Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
name = "Worker")
slaveActors += actor
return masterUrl
def stop() {
logInfo("Shutting down local Spark cluster.")
// Stop the slaves before the master so they don't get upset that it disconnected
@ -16,7 +16,7 @@ import akka.dispatch.Await
* The main class used to talk to a Spark deploy cluster. Takes a master URL, a job description,
* and a listener for job events, and calls back the listener when various events occur.
class Client(
private[spark] class Client(
actorSystem: ActorSystem,
masterUrl: String,
jobDescription: JobDescription,
@ -42,7 +42,6 @@ class Client(
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
try {
master = context.actorFor(akkaUrl)
//master ! RegisterWorker(ip, port, cores, memory)
master ! RegisterJob(jobDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
@ -7,7 +7,7 @@ package spark.deploy.client
* Users of this API should *not* block inside the callback methods.
trait ClientListener {
private[spark] trait ClientListener {
def connected(jobId: String): Unit
def disconnected(): Unit
@ -4,7 +4,7 @@ import spark.util.AkkaUtils
import spark.{Logging, Utils}
import spark.deploy.{Command, JobDescription}
object TestClient {
private[spark] object TestClient {
class TestListener extends ClientListener with Logging {
def connected(id: String) {
@ -1,6 +1,6 @@
package spark.deploy.client
object TestExecutor {
private[spark] object TestExecutor {
def main(args: Array[String]) {
println("Hello world!")
while (true) {
@ -2,7 +2,7 @@ package spark.deploy.master
import spark.deploy.ExecutorState
class ExecutorInfo(
private[spark] class ExecutorInfo(
val id: Int,
val job: JobInfo,
val worker: WorkerInfo,
@ -5,6 +5,7 @@ import java.util.Date
import akka.actor.ActorRef
import scala.collection.mutable
class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, val actor: ActorRef) {
var state = JobState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]
@ -31,4 +32,13 @@ class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, va
def coresLeft: Int = desc.cores - coresGranted
private var _retryCount = 0
def retryCount = _retryCount
def incrementRetryCount = {
_retryCount += 1
@ -1,7 +1,9 @@
package spark.deploy.master
object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
private[spark] object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
type JobState = Value
val MAX_NUM_RETRY = 10
@ -1,21 +1,20 @@
package spark.deploy.master
import akka.actor._
import akka.actor.Terminated
import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown}
import java.text.SimpleDateFormat
import java.util.Date
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import akka.actor._
import spark.{Logging, Utils}
import spark.util.AkkaUtils
import java.text.SimpleDateFormat
import java.util.Date
import akka.remote.RemoteClientLifeCycleEvent
import spark.deploy._
import akka.remote.RemoteClientShutdown
import akka.remote.RemoteClientDisconnected
import spark.deploy.RegisterWorker
import spark.deploy.RegisterWorkerFailed
import akka.actor.Terminated
import spark.{Logging, SparkException, Utils}
import spark.util.AkkaUtils
class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For job IDs
var nextJobNumber = 0
@ -81,12 +80,22 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
exec.state = state
exec.job.actor ! ExecutorUpdated(execId, state, message)
if (ExecutorState.isFinished(state)) {
val jobInfo = idToJob(jobId)
// Remove this executor from the worker and job
logInfo("Removing executor " + exec.fullId + " because it is " + state)
// TODO: the worker would probably want to restart the executor a few times
// Only retry certain number of times so we don't go into an infinite loop.
if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) {
} else {
// The same exception object is logged and then thrown, so both carry this message.
val e = new SparkException("Job %s with ID %s failed %d times.".format(
jobInfo.desc.name, jobInfo.id, jobInfo.retryCount))
logError(e.getMessage, e)
throw e
case None =>
@ -112,7 +121,7 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
case RequestMasterState => {
sender ! MasterState(ip + ":" + port, workers.toList, jobs.toList, completedJobs.toList)
@ -203,7 +212,7 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
object Master {
private[spark] object Master {
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
@ -6,7 +6,7 @@ import spark.Utils
* Command-line parser for the master.
class MasterArguments(args: Array[String]) {
private[spark] class MasterArguments(args: Array[String]) {
var ip = Utils.localIpAddress()
var port = 7077
var webUiPort = 8080
@ -51,7 +51,7 @@ class MasterArguments(args: Array[String]) {
def printUsageAndExit(exitCode: Int) {
"Usage: spark-master [options]\n" +
"Usage: Master [options]\n" +
"\n" +
"Options:\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
@ -10,6 +10,7 @@ import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._
import spark.deploy._
class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
@ -22,7 +23,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
completeWith {
val future = master ? RequestMasterState
future.map {
masterState => masterui.html.index.render(masterState.asInstanceOf[MasterState])
masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState])
} ~
@ -36,7 +37,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
// A bit ugly and inefficient, but we won't have a number of jobs
// so large that it will make a significant difference.
(masterState.activeJobs ::: masterState.completedJobs).find(_.id == jobId) match {
case Some(job) => masterui.html.job_details.render(job)
case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => null
@ -3,7 +3,7 @@ package spark.deploy.master
import akka.actor.ActorRef
import scala.collection.mutable
class WorkerInfo(
private[spark] class WorkerInfo(
val id: String,
val host: String,
val port: Int,
@ -13,7 +13,7 @@ import spark.deploy.ExecutorStateChanged
* Manages the execution of one executor process.
class ExecutorRunner(
private[spark] class ExecutorRunner(
val jobId: String,
val execId: Int,
val jobDesc: JobDescription,
@ -29,12 +29,25 @@ class ExecutorRunner(
val fullId = jobId + "/" + execId
var workerThread: Thread = null
var process: Process = null
var shutdownHook: Thread = null
def start() {
workerThread = new Thread("ExecutorRunner for " + fullId) {
override def run() { fetchAndRunExecutor() }
// Shutdown hook that kills the launched child process on shutdown.
shutdownHook = new Thread() {
override def run() {
if (process != null) {
logInfo("Shutdown hook killing child process.")
/** Stop this executor runner, including killing the process it launched */
@ -45,40 +58,10 @@ class ExecutorRunner(
if (process != null) {
logInfo("Killing process!")
worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None)
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
val targetFile = new File(targetDir, filename)
if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
// Use the java.net library to fetch it
logInfo("Fetching " + url + " to " + targetFile)
val in = new URL(url).openStream()
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
} else {
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
val conf = new Configuration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
// Decompress the file if it's a .tar or .tar.gz
if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xzf", filename), targetDir)
} else if (filename.endsWith(".tar")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xf", filename), targetDir)
@ -92,7 +75,8 @@ class ExecutorRunner(
def buildCommandSeq(): Seq[String] = {
val command = jobDesc.command
val runScript = new File(sparkHome, "run").getCanonicalPath
val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run";
val runScript = new File(sparkHome, script).getCanonicalPath
Seq(runScript, command.mainClass) ++ command.arguments.map(substituteVariables)
@ -101,7 +85,12 @@ class ExecutorRunner(
val out = new FileOutputStream(file)
new Thread("redirect output to " + file) {
override def run() {
Utils.copyStream(in, out, true)
try {
Utils.copyStream(in, out, true)
} catch {
case e: IOException =>
logInfo("Redirection to " + file + " closed: " + e.getMessage)
@ -131,6 +120,9 @@ class ExecutorRunner(
env.put("SPARK_CORES", cores.toString)
env.put("SPARK_MEMORY", memory.toString)
// In case we are running this from within the Spark Shell
// so we are not creating a parent process.
process = builder.start()
// Redirect its stdout and stderr to files
@ -16,7 +16,14 @@ import spark.deploy.RegisterWorkerFailed
import akka.actor.Terminated
import java.io.File
class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, masterUrl: String)
private[spark] class Worker(
ip: String,
port: Int,
webUiPort: Int,
cores: Int,
memory: Int,
masterUrl: String,
workDirPath: String = null)
extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
@ -37,7 +44,11 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
def memoryFree: Int = memory - memoryUsed
def createWorkDir() {
workDir = new File(sparkHome, "work")
workDir = if (workDirPath != null) {
new File(workDirPath)
} else {
new File(sparkHome, "work")
try {
if (!workDir.exists() && !workDir.mkdirs()) {
logError("Failed to create work directory " + workDir)
@ -153,14 +164,19 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
def generateWorkerId(): String = {
"worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
override def postStop() {
object Worker {
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
val actor = actorSystem.actorOf(
Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory, args.master)),
Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
args.master, args.workDir)),
name = "Worker")
@ -8,13 +8,14 @@ import java.lang.management.ManagementFactory
* Command-line parser for the worker.
class WorkerArguments(args: Array[String]) {
private[spark] class WorkerArguments(args: Array[String]) {
var ip = Utils.localIpAddress()
var port = 0
var webUiPort = 8081
var cores = inferDefaultCores()
var memory = inferDefaultMemory()
var master: String = null
var workDir: String = null
// Check for settings in environment variables
if (System.getenv("SPARK_WORKER_PORT") != null) {
@ -29,6 +30,9 @@ class WorkerArguments(args: Array[String]) {
if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) {
webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt
if (System.getenv("SPARK_WORKER_DIR") != null) {
workDir = System.getenv("SPARK_WORKER_DIR")
@ -49,6 +53,10 @@ class WorkerArguments(args: Array[String]) {
memory = value
case ("--work-dir" | "-d") :: value :: tail =>
workDir = value
case "--webui-port" :: IntParam(value) :: tail =>
webUiPort = value
@ -77,13 +85,14 @@ class WorkerArguments(args: Array[String]) {
def printUsageAndExit(exitCode: Int) {
"Usage: spark-worker [options] <master>\n" +
"Usage: Worker [options] <master>\n" +
"\n" +
"Master must be a URL of the form spark://hostname:port\n" +
"\n" +
"Options:\n" +
" -c CORES, --cores CORES Number of cores to use\n" +
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
" -d DIR, --work-dir DIR Directory to run jobs in (default: SPARK_HOME/work)\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
" --webui-port PORT Port for web UI (default: 8081)")
@ -9,6 +9,7 @@ import cc.spray.Directives
import cc.spray.typeconversion.TwirlSupport._
import spark.deploy.{WorkerState, RequestWorkerState}
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
@ -21,7 +22,7 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
val future = worker ? RequestWorkerState
future.map { workerState =>
} ~
@ -1,10 +1,12 @@
package spark.executor
import java.io.{File, FileOutputStream}
import java.net.{URL, URLClassLoader}
import java.net.{URI, URL, URLClassLoader}
import java.util.concurrent._
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.FileUtil
import scala.collection.mutable.{ArrayBuffer, Map, HashMap}
import spark.broadcast._
import spark.scheduler._
@ -14,11 +16,16 @@ import java.nio.ByteBuffer
* The Mesos executor for Spark.
class Executor extends Logging {
var classLoader: ClassLoader = null
private[spark] class Executor extends Logging {
var urlClassLoader : ExecutorURLClassLoader = null
var threadPool: ExecutorService = null
var env: SparkEnv = null
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
@ -32,14 +39,14 @@ class Executor extends Logging {
System.setProperty(key, value)
// Create our ClassLoader and set it on this thread
urlClassLoader = createClassLoader()
// Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
// Create our ClassLoader (using spark properties) and set it on this thread
classLoader = createClassLoader()
// Start worker thread pool
threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
@ -54,15 +61,16 @@ class Executor extends Logging {
override def run() {
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
try {
val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
logInfo("Its generation is " + task.generation)
val value = task.run(taskId.toInt)
@ -96,25 +104,15 @@ class Executor extends Logging {
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
private def createClassLoader(): ClassLoader = {
private def createClassLoader(): ExecutorURLClassLoader = {
var loader = this.getClass.getClassLoader
// If any JAR URIs are given through spark.jar.uris, fetch them to the
// current directory and put them all on the classpath. We assume that
// each URL has a unique file name so that no local filenames will clash
// in this process. This is guaranteed by ClusterScheduler.
val uris = System.getProperty("spark.jar.uris", "")
val localFiles = ArrayBuffer[String]()
for (uri <- uris.split(",").filter(_.size > 0)) {
val url = new URL(uri)
val filename = url.getPath.split("/").last
downloadFile(url, filename)
localFiles += filename
if (localFiles.size > 0) {
val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
loader = new URLClassLoader(urls, loader)
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
val urls = currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
loader = new URLClassLoader(urls, loader)
// If the REPL is in use, add another ClassLoader that will read
// new classes defined by the REPL as the user types code
@ -133,13 +131,31 @@ class Executor extends Logging {
return loader
return new ExecutorURLClassLoader(Array(), loader)
// Download a file from a given URL to the local filesystem
private def downloadFile(url: URL, localPath: String) {
val in = url.openStream()
val out = new FileOutputStream(localPath)
Utils.copyStream(in, out, true)
* Download any missing dependencies if we receive a new set of files and JARs from the
* SparkContext. Also adds any new JARs we fetched to the class loader.
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name)
Utils.fetchFile(name, new File("."))
currentFiles(name) = timestamp
// JAR timestamps are recorded in currentJars (see the assignment below), so staleness must be
// checked against that map; consulting currentFiles would report a JAR as stale on every call
// unless a plain file happened to be registered under the same name.
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name)
Utils.fetchFile(name, new File("."))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
val url = new File(".", localName).toURI.toURL
if (!urlClassLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
@ -6,6 +6,6 @@ import spark.TaskState.TaskState
* A pluggable interface used by the Executor to send updates to the cluster scheduler.
trait ExecutorBackend {
private[spark] trait ExecutorBackend {
def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer)
@ -0,0 +1,14 @@
package spark.executor
import java.net.{URLClassLoader, URL}
* The addURL method in URLClassLoader is protected. We subclass it to make this accessible.
private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
extends URLClassLoader(urls, parent) {
override def addURL(url: URL) {
@ -8,7 +8,7 @@ import com.google.protobuf.ByteString
import spark.{Utils, Logging}
import spark.TaskState
class MesosExecutorBackend(executor: Executor)
private[spark] class MesosExecutorBackend(executor: Executor)
extends MesosExecutor
with ExecutorBackend
with Logging {
@ -59,7 +59,7 @@ class MesosExecutorBackend(executor: Executor)
* Entry point for Mesos executor.
object MesosExecutorBackend {
private[spark] object MesosExecutorBackend {
def main(args: Array[String]) {
// Create a new Executor and start it running
@ -14,7 +14,7 @@ import spark.scheduler.cluster.RegisterSlaveFailed
import spark.scheduler.cluster.RegisterSlave
class StandaloneExecutorBackend(
private[spark] class StandaloneExecutorBackend(
executor: Executor,
masterUrl: String,
slaveId: String,
@ -62,7 +62,7 @@ class StandaloneExecutorBackend(
object StandaloneExecutorBackend {
private[spark] object StandaloneExecutorBackend {
def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
@ -11,6 +11,7 @@ import java.nio.channels.spi._
import java.net._
abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
@ -23,8 +24,8 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
var onExceptionCallback: (Connection, Exception) => Unit = null
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
lazy val remoteAddress = getRemoteAddress()
lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
val remoteAddress = getRemoteAddress()
val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
def key() = channel.keyFor(selector)
@ -39,7 +40,10 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
def close() {
val k = key()
if (k != null) {
@ -99,7 +103,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
extends Connection(SocketChannel.open, selector_) {
class Outbox(fair: Int = 0) {
@ -134,9 +138,12 @@ extends Connection(SocketChannel.open, selector_) {
if (!message.started) logDebug("Starting to send [" + message + "]")
message.started = true
return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
"] in " + message.timeTaken )
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
@ -159,10 +166,11 @@ extends Connection(SocketChannel.open, selector_) {
logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
return chunk
/*messages -= message*/
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
} else {
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
"] in " + message.timeTaken )
@ -216,7 +224,7 @@ extends Connection(SocketChannel.open, selector_) {
while(true) {
if (currentBuffers.size == 0) {
outbox.synchronized {
outbox.getChunk match {
outbox.getChunk() match {
case Some(chunk) => {
currentBuffers ++= chunk.buffers
@ -252,7 +260,7 @@ extends Connection(SocketChannel.open, selector_) {
class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
extends Connection(channel_, selector_) {
class Inbox() {
@ -16,18 +16,19 @@ import scala.collection.mutable.ArrayBuffer
import akka.dispatch.{Await, Promise, ExecutionContext, Future}
import akka.util.Duration
import akka.util.duration._
case class ConnectionManagerId(host: String, port: Int) {
private[spark] case class ConnectionManagerId(host: String, port: Int) {
def toSocketAddress() = new InetSocketAddress(host, port)
object ConnectionManagerId {
private[spark] object ConnectionManagerId {
def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
class ConnectionManager(port: Int) extends Logging {
private[spark] class ConnectionManager(port: Int) extends Logging {
class MessageStatus(
val message: Message,
@ -348,7 +349,7 @@ class ConnectionManager(port: Int) extends Logging {
object ConnectionManager {
private[spark] object ConnectionManager {
def main(args: Array[String]) {
@ -403,7 +404,10 @@ object ConnectionManager {
(0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {if (!f().isDefined) println("Failed")})
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
val finishTime = System.currentTimeMillis
val mb = size * count / 1024.0 / 1024.0
@ -430,7 +434,10 @@ object ConnectionManager {
(0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {if (!f().isDefined) println("Failed")})
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
val finishTime = System.currentTimeMillis
val ms = finishTime - startTime
@ -457,7 +464,10 @@ object ConnectionManager {
(0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {if (!f().isDefined) println("Failed")})
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
val finishTime = System.currentTimeMillis
val mb = size * count / 1024.0 / 1024.0
@ -8,7 +8,10 @@ import scala.io.Source
import java.nio.ByteBuffer
import java.net.InetAddress
object ConnectionManagerTest extends Logging{
import akka.dispatch.Await
import akka.util.duration._
private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) {
if (args.length < 2) {
println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
@ -53,7 +56,7 @@ object ConnectionManagerTest extends Logging{
logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
val results = futures.map(f => f())
val results = futures.map(f => Await.result(f, 1.second))
val finishTime = System.currentTimeMillis
@ -7,8 +7,9 @@ import scala.collection.mutable.ArrayBuffer
import java.nio.ByteBuffer
import java.net.InetAddress
import java.net.InetSocketAddress
import storage.BlockManager
class MessageChunkHeader(
private[spark] class MessageChunkHeader(
val typ: Long,
val id: Int,
val totalSize: Int,
@ -36,7 +37,7 @@ class MessageChunkHeader(
" and sizes " + totalSize + " / " + chunkSize + " bytes"
class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
val size = if (buffer == null) 0 else buffer.remaining
lazy val buffers = {
val ab = new ArrayBuffer[ByteBuffer]()
@ -50,7 +51,7 @@ class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
abstract class Message(val typ: Long, val id: Int) {
private[spark] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
var started = false
var startTime = -1L
@ -64,10 +65,10 @@ abstract class Message(val typ: Long, val id: Int) {
def timeTaken(): String = (finishTime - startTime).toString + " ms"
override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
extends Message(Message.BUFFER_MESSAGE, id_) {
val initialSize = currentSize()
@ -97,10 +98,11 @@ extends Message(Message.BUFFER_MESSAGE, id_) {
while(!buffers.isEmpty) {
val buffer = buffers(0)
if (buffer.remaining == 0) {
buffers -= buffer
} else {
val newBuffer = if (buffer.remaining <= maxChunkSize) {
} else {
@ -147,11 +149,10 @@ extends Message(Message.BUFFER_MESSAGE, id_) {
} else {
"BufferMessage(id = " + id + ", size = " + size + ")"
object MessageChunkHeader {
private[spark] object MessageChunkHeader {
val HEADER_SIZE = 40
def create(buffer: ByteBuffer): MessageChunkHeader = {
@ -172,7 +173,7 @@ object MessageChunkHeader {
object Message {
private[spark] object Message {
val BUFFER_MESSAGE = 1111111111L
var lastId = 1
@ -3,7 +3,7 @@ package spark.network
import java.nio.ByteBuffer
import java.net.InetAddress
object ReceiverTest {
private[spark] object ReceiverTest {
def main(args: Array[String]) {
val manager = new ConnectionManager(9999)
@ -3,7 +3,7 @@ package spark.network
import java.nio.ByteBuffer
import java.net.InetAddress
object SenderTest {
private[spark] object SenderTest {
def main(args: Array[String]) {
@ -12,7 +12,7 @@ import spark.scheduler.JobListener
* a result of type U for each partition, and that the action returns a partial or complete result
* of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedInt).
class ApproximateActionListener[T, U, R](
private[spark] class ApproximateActionListener[T, U, R](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
@ -4,7 +4,7 @@ package spark.partial
* An object that computes a function incrementally by merging in results of type U from multiple
* tasks. Allows partial evaluation at any point by calling currentResult().
trait ApproximateEvaluator[U, R] {
private[spark] trait ApproximateEvaluator[U, R] {
def merge(outputId: Int, taskResult: U): Unit
def currentResult(): R
@ -3,6 +3,7 @@ package spark.partial
* A Double with error bars on it.
class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
override def toString(): String = "[%.3f, %.3f]".format(low, high)
@ -8,7 +8,7 @@ import cern.jet.stat.Probability
* TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might
* be best to make this a special case of GroupedCountEvaluator with one group.
class CountEvaluator(totalOutputs: Int, confidence: Double)
private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[Long, BoundedDouble] {
var outputsMerged = 0
@ -14,7 +14,7 @@ import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
* An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval.
class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] {
var outputsMerged = 0
@ -12,7 +12,7 @@ import spark.util.StatCounter
* An ApproximateEvaluator for means by key. Returns a map of key to confidence interval.
class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
var outputsMerged = 0
@ -12,7 +12,7 @@ import spark.util.StatCounter
* An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval.
class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
var outputsMerged = 0
@ -7,7 +7,7 @@ import spark.util.StatCounter
* An ApproximateEvaluator for means.
class MeanEvaluator(totalOutputs: Int, confidence: Double)
private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] {
var outputsMerged = 0
@ -1,6 +1,6 @@
package spark.partial
class PartialResult[R](initialVal: R, isFinal: Boolean) {
private[spark] class PartialResult[R](initialVal: R, isFinal: Boolean) {
private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None
private var failure: Option[Exception] = None
private var completionHandler: Option[R => Unit] = None
@ -7,7 +7,7 @@ import cern.jet.stat.Probability
* and various sample sizes. This is used by the MeanEvaluator to efficiently calculate
* confidence intervals for many keys.
class StudentTCacher(confidence: Double) {
private[spark] class StudentTCacher(confidence: Double) {
val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
@ -9,7 +9,7 @@ import spark.util.StatCounter
* together, then uses the formula for the variance of two independent random variables to get
* a variance for the result and compute a confidence interval.
class SumEvaluator(totalOutputs: Int, confidence: Double)
private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] {
var outputsMerged = 0
@ -5,11 +5,12 @@ import spark.TaskContext
* Tracks information about an active job in the DAGScheduler.
class ActiveJob(
private[spark] class ActiveJob(
val runId: Int,
val finalStage: Stage,
val func: (TaskContext, Iterator[_]) => _,
val partitions: Array[Int],
val callSite: String,
val listener: JobListener) {
val numPartitions = partitions.length
@ -21,6 +21,7 @@ import spark.storage.BlockManagerId
* schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
@ -38,6 +39,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
@ -116,7 +122,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + ": " + rdd)
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
cacheTracker.registerRDD(rdd.id, rdd.splits.size)
if (shuffleDep != None) {
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
@ -139,7 +145,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
logInfo("Registering parent RDD " + r.id + ": " + r)
logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")")
cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
@ -183,23 +189,25 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
def runJob[T, U](
def runJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean)
(implicit m: ClassManifest[U]): Array[U] =
: Array[U] =
if (partitions.size == 0) {
return new Array[U](0)
val waiter = new JobWaiter(partitions.size)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter))
eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
waiter.getResult() match {
case JobSucceeded(results: Seq[_]) =>
return results.asInstanceOf[Seq[U]].toArray
case JobFailed(exception: Exception) =>
logInfo("Failed to run " + callSite)
throw exception
@ -208,13 +216,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long
): PartialResult[R] =
callSite: String,
timeout: Long)
: PartialResult[R] =
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.splits.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener))
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
return listener.getResult() // Will throw an exception if the job fails
@ -234,13 +243,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
event match {
case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) =>
case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, listener)
val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
logInfo("Got job " + job.runId + " with " + partitions.length + " output partitions")
logInfo("Final stage: " + finalStage)
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions")
logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
@ -258,6 +268,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case completion: CompletionEvent =>
case TaskSetFailed(taskSet, reason) =>
abortStage(idToStage(taskSet.stageId), reason)
case StopDAGScheduler =>
// Cancel any active jobs
for (job <- activeJobs) {
@ -329,7 +342,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
logInfo("Submitting " + stage + ", which has no missing parents")
logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents")
running += stage
} else {
@ -416,7 +429,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
stage.addOutputLoc(smt.partition, bmAddress)
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
logInfo(stage + " finished; looking for newly runnable stages")
logInfo(stage + " (" + stage.origin + ") finished; looking for newly runnable stages")
running -= stage
logInfo("running: " + running)
logInfo("waiting: " + waiting)
@ -430,7 +443,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (stage.outputLocs.count(_ == Nil) != 0) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + stage + " because some of its tasks had failed: " +
logInfo("Resubmitting " + stage + " (" + stage.origin +
") because some of its tasks had failed: " +
stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
} else {
@ -444,6 +458,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) {
logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable")
@ -460,12 +475,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
running -= failedStage
failed += failedStage
// TODO: Cancel running tasks in the stage
logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
logInfo("Marking " + failedStage + " (" + failedStage.origin +
") for resubmision due to a fetch failure")
// Mark the map whose fetch failed as broken in the map stage
val mapStage = shuffleToMapStage(shuffleId)
mapStage.removeOutputLoc(mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
logInfo("The failed fetch was from " + mapStage + " (" + mapStage.origin +
"); marking it for resubmission")
failed += mapStage
// Remember that a fetch failed now; this is used to resubmit the broken
// stages later, after a small wait (to give other tasks the chance to fail)
@ -475,18 +492,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case _ =>
// Non-fetch failure -- probably a bug in the job, so bail out
// TODO: Cancel all tasks that are still running
resultStageToJob.get(stage) match {
case Some(job) =>
val error = new SparkException("Task failed: " + task + ", reason: " + event.reason)
activeJobs -= job
resultStageToJob -= stage
case None =>
logInfo("Ignoring result from " + task + " because its job has finished")
case other =>
// Non-fetch failure -- probably a bug in user code; abort all jobs depending on this stage
abortStage(idToStage(task.stageId), task + " failed: " + other)
@ -509,6 +517,53 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
job.listener.jobFailed(new SparkException("Job failed: " + reason))
activeJobs -= job
resultStageToJob -= resultStage
if (dependentStages.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
* Return true if one of stage's ancestors is target.
def stageDependsOn(stage: Stage, target: Stage): Boolean = {
if (stage == target) {
return true
val visitedRdds = new HashSet[RDD[_]]
val visitedStages = new HashSet[Stage]
def visit(rdd: RDD[_]) {
if (!visitedRdds(rdd)) {
visitedRdds += rdd
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_,_] =>
val mapStage = getShuffleMapStage(shufDep, stage.priority)
if (!mapStage.isAvailable) {
visitedStages += mapStage
} // Otherwise there's no need to follow the dependency back
case narrowDep: NarrowDependency[_] =>
def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations
@ -10,23 +10,26 @@ import spark._
* submitted) but there is a single "logic" thread that reads these events and takes decisions.
* This greatly simplifies synchronization.
sealed trait DAGSchedulerEvent
private[spark] sealed trait DAGSchedulerEvent
case class JobSubmitted(
private[spark] case class JobSubmitted(
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
allowLocal: Boolean,
callSite: String,
listener: JobListener)
extends DAGSchedulerEvent
case class CompletionEvent(
private[spark] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
accumUpdates: Map[Long, Any])
extends DAGSchedulerEvent
case class HostLost(host: String) extends DAGSchedulerEvent
private[spark] case class HostLost(host: String) extends DAGSchedulerEvent
case object StopDAGScheduler extends DAGSchedulerEvent
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Ссылка в новой задаче