This commit is contained in:
Mosharaf Chowdhury 2012-10-02 22:25:39 -07:00
Parents 31ffe8d528 ff813e4380
Commit 119e50c7b9
217 changed files with 14272 additions and 1538 deletions

.gitignore vendored
View file

@ -13,6 +13,8 @@ third_party/libmesos.dylib
conf/java-opts conf/java-opts
conf/spark-env.sh conf/spark-env.sh
conf/log4j.properties conf/log4j.properties
docs/_site
docs/api
target/ target/
reports/ reports/
.project .project
@ -28,3 +30,4 @@ project/plugins/lib_managed/
project/plugins/src_managed/ project/plugins/src_managed/
logs/ logs/
log/ log/
spark-tests.log

View file

@ -6,16 +6,14 @@ Lightning-Fast Cluster Computing - <http://www.spark-project.org/>
## Online Documentation ## Online Documentation
You can find the latest Spark documentation, including a programming You can find the latest Spark documentation, including a programming
guide, on the project wiki at <http://github.com/mesos/spark/wiki>. This guide, on the project webpage at <http://spark-project.org/documentation.html>.
file only contains basic setup instructions. This README file only contains basic setup instructions.
## Building ## Building
Spark requires Scala 2.9.1. This version has been tested with 2.9.1.final. Spark requires Scala 2.9.2. The project is built using Simple Build Tool (SBT),
which is packaged with it. To build Spark and its example programs, run:
The project is built using Simple Build Tool (SBT), which is packaged with it.
To build Spark and its example programs, run:
sbt/sbt compile sbt/sbt compile

View file

@ -142,7 +142,7 @@ class WPRSerializerInstance extends SerializerInstance {
class WPRSerializationStream(os: OutputStream) extends SerializationStream { class WPRSerializationStream(os: OutputStream) extends SerializationStream {
val dos = new DataOutputStream(os) val dos = new DataOutputStream(os)
def writeObject[T](t: T): Unit = t match { def writeObject[T](t: T): SerializationStream = t match {
case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match {
case links: Array[String] => { case links: Array[String] => {
dos.writeInt(0) // links dos.writeInt(0) // links
@ -151,17 +151,20 @@ class WPRSerializationStream(os: OutputStream) extends SerializationStream {
for (link <- links) { for (link <- links) {
dos.writeUTF(link) dos.writeUTF(link)
} }
this
} }
case rank: Double => { case rank: Double => {
dos.writeInt(1) // rank dos.writeInt(1) // rank
dos.writeUTF(id) dos.writeUTF(id)
dos.writeDouble(rank) dos.writeDouble(rank)
this
} }
} }
case (id: String, rank: Double) => { case (id: String, rank: Double) => {
dos.writeInt(2) // rank without wrapper dos.writeInt(2) // rank without wrapper
dos.writeUTF(id) dos.writeUTF(id)
dos.writeDouble(rank) dos.writeDouble(rank)
this
} }
} }

View file

@ -1,8 +1,10 @@
# Set everything to be logged to the console # Set everything to be logged to the console
log4j.rootCategory=WARN, console log4j.rootCategory=INFO, file
log4j.appender.console=org.apache.log4j.ConsoleAppender log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.file.append=false
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n log4j.appender.file.file=spark-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose # Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN log4j.logger.org.eclipse.jetty=WARN

View file

@ -1,5 +1,8 @@
#!/usr/bin/env bash #!/usr/bin/env bash
# This Spark deploy script is a modified version of the Apache Hadoop deploy
# script, available under the Apache 2 license:
#
# Licensed to the Apache Software Foundation (ASF) under one or more # Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with # contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership. # this work for additional information regarding copyright ownership.

View file

@ -1,5 +1,8 @@
#!/usr/bin/env bash #!/usr/bin/env bash
# This Spark deploy script is a modified version of the Apache Hadoop deploy
# script, available under the Apache 2 license:
#
# Licensed to the Apache Software Foundation (ASF) under one or more # Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with # contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership. # this work for additional information regarding copyright ownership.

View file

@ -3,6 +3,7 @@ package spark
import java.io._ import java.io._
import scala.collection.mutable.Map import scala.collection.mutable.Map
import scala.collection.generic.Growable
/** /**
* A datatype that can be accumulated, i.e. has an commutative and associative +. * A datatype that can be accumulated, i.e. has an commutative and associative +.
@ -92,6 +93,29 @@ trait AccumulableParam[R, T] extends Serializable {
def zero(initialValue: R): R def zero(initialValue: R): R
} }
class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
extends AccumulableParam[R,T] {
def addAccumulator(growable: R, elem: T) : R = {
growable += elem
growable
}
def addInPlace(t1: R, t2: R) : R = {
t1 ++= t2
t1
}
def zero(initialValue: R): R = {
// We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
// Instead we'll serialize it to a buffer and load it back.
val ser = (new spark.JavaSerializer).newInstance()
val copy = ser.deserialize[R](ser.serialize(initialValue))
copy.clear() // In case it contained stuff
copy
}
}
/** /**
* A simpler value of [[spark.Accumulable]] where the result type being accumulated is the same * A simpler value of [[spark.Accumulable]] where the result type being accumulated is the same
* as the types of elements being merged. * as the types of elements being merged.
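
The new GrowableAccumulableParam above builds its zero value by serializing the initial collection and clearing the copy. Below is a minimal sketch, not part of the commit, that exercises the three operations directly on an ArrayBuffer, outside of any SparkContext; the concrete values are illustrative:

    import scala.collection.mutable.ArrayBuffer

    object GrowableParamSketch {
      def main(args: Array[String]) {
        val param = new spark.GrowableAccumulableParam[ArrayBuffer[Int], Int]
        // zero() serializes the initial value, deserializes the copy, and clears it.
        val zero = param.zero(ArrayBuffer(1, 2, 3))                // ArrayBuffer()
        val partial = param.addAccumulator(zero, 4)                // ArrayBuffer(4)
        val merged = param.addInPlace(partial, ArrayBuffer(5, 6))  // ArrayBuffer(4, 5, 6)
        println(merged)
      }
    }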

View file

@ -9,9 +9,9 @@ package spark
* known as map-side aggregations. When set to false, * known as map-side aggregations. When set to false,
* mergeCombiners function is not used. * mergeCombiners function is not used.
*/ */
class Aggregator[K, V, C] ( case class Aggregator[K, V, C] (
val createCombiner: V => C, val createCombiner: V => C,
val mergeValue: (C, V) => C, val mergeValue: (C, V) => C,
val mergeCombiners: (C, C) => C, val mergeCombiners: (C, C) => C,
val mapSideCombine: Boolean = true) val mapSideCombine: Boolean = true)
extends Serializable

View file

@ -2,12 +2,13 @@ package spark
import scala.collection.mutable.HashMap import scala.collection.mutable.HashMap
class BlockRDDSplit(val blockId: String, idx: Int) extends Split { private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
val index = idx val index = idx
} }
class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) { class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc) {
@transient @transient
val splits_ = (0 until blockIds.size).map(i => { val splits_ = (0 until blockIds.size).map(i => {

View file

@ -11,8 +11,7 @@ import spark.storage.BlockManagerId
import it.unimi.dsi.fastutil.io.FastBufferedInputStream import it.unimi.dsi.fastutil.io.FastBufferedInputStream
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) { def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager val blockManager = SparkEnv.get.blockManager
@ -29,39 +28,32 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map { val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map {
case (address, splits) => case (address, splits) =>
(address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId))) (address, splits.map(i => "shuffle_%d_%d_%d".format(shuffleId, i, reduceId)))
} }
try { for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) {
for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { blockOption match {
blockOption match { case Some(block) => {
case Some(block) => { val values = block
val values = block for(value <- values) {
for(value <- values) { val v = value.asInstanceOf[(K, V)]
val v = value.asInstanceOf[(K, V)] func(v._1, v._2)
func(v._1, v._2)
}
}
case None => {
throw new BlockException(blockId, "Did not get block " + blockId)
} }
} }
} case None => {
} catch { val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
// TODO: this is really ugly -- let's find a better way of throwing a FetchFailedException blockId match {
case be: BlockException => { case regex(shufId, mapId, reduceId) =>
val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]]*)".r val addr = addresses(mapId.toInt)
be.blockId match { throw new FetchFailedException(addr, shufId.toInt, mapId.toInt, reduceId.toInt, null)
case regex(sId, mId, rId) => { case _ =>
val address = addresses(mId.toInt) throw new SparkException(
throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be) "Failed to get block " + blockId + ", which is not a shuffle block")
}
case _ => {
throw be
} }
} }
} }
} }
logDebug("Fetching and merging outputs of shuffle %d, reduce %d took %d ms".format( logDebug("Fetching and merging outputs of shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime)) shuffleId, reduceId, System.currentTimeMillis - startTime))
} }
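
The rewritten fetch loop above recognizes failed shuffle blocks by parsing the block id directly instead of wrapping everything in a BlockException. A small standalone sketch (not from the commit) of that regex extractor:

    object ShuffleBlockIdSketch {
      def main(args: Array[String]) {
        // Shuffle blocks are named "shuffle_<shuffleId>_<mapId>_<reduceId>".
        val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
        "shuffle_3_12_7" match {
          case regex(shufId, mapId, reduceId) =>
            println((shufId.toInt, mapId.toInt, reduceId.toInt)) // (3,12,7)
          case _ =>
            println("not a shuffle block")
        }
      }
    }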

View file

@ -9,7 +9,7 @@ import java.util.LinkedHashMap
* some cache entries have pointers to a shared object. Nonetheless, this Cache should work well * some cache entries have pointers to a shared object. Nonetheless, this Cache should work well
* when most of the space is used by arrays of primitives or of simple classes. * when most of the space is used by arrays of primitives or of simple classes.
*/ */
class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
logInfo("BoundedMemoryCache.maxBytes = " + maxBytes) logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
def this() { def this() {
@ -104,9 +104,9 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
} }
// An entry in our map; stores a cached object and its size in bytes // An entry in our map; stores a cached object and its size in bytes
case class Entry(value: Any, size: Long) private[spark] case class Entry(value: Any, size: Long)
object BoundedMemoryCache { private[spark] object BoundedMemoryCache {
/** /**
* Get maximum cache capacity from system configuration * Get maximum cache capacity from system configuration
*/ */

View file

@ -2,9 +2,9 @@ package spark
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
sealed trait CachePutResponse private[spark] sealed trait CachePutResponse
case class CachePutSuccess(size: Long) extends CachePutResponse private[spark] case class CachePutSuccess(size: Long) extends CachePutResponse
case class CachePutFailure() extends CachePutResponse private[spark] case class CachePutFailure() extends CachePutResponse
/** /**
* An interface for caches in Spark, to allow for multiple implementations. Caches are used to store * An interface for caches in Spark, to allow for multiple implementations. Caches are used to store
@ -22,7 +22,7 @@ case class CachePutFailure() extends CachePutResponse
* This abstract class handles the creation of key spaces, so that subclasses need only deal with * This abstract class handles the creation of key spaces, so that subclasses need only deal with
* keys that are unique across modules. * keys that are unique across modules.
*/ */
abstract class Cache { private[spark] abstract class Cache {
private val nextKeySpaceId = new AtomicInteger(0) private val nextKeySpaceId = new AtomicInteger(0)
private def newKeySpaceId() = nextKeySpaceId.getAndIncrement() private def newKeySpaceId() = nextKeySpaceId.getAndIncrement()
@ -52,7 +52,7 @@ abstract class Cache {
/** /**
* A key namespace in a Cache. * A key namespace in a Cache.
*/ */
class KeySpace(cache: Cache, val keySpaceId: Int) { private[spark] class KeySpace(cache: Cache, val keySpaceId: Int) {
def get(datasetId: Any, partition: Int): Any = def get(datasetId: Any, partition: Int): Any =
cache.get((keySpaceId, datasetId), partition) cache.get((keySpaceId, datasetId), partition)

View file

@ -15,19 +15,20 @@ import scala.collection.mutable.HashSet
import spark.storage.BlockManager import spark.storage.BlockManager
import spark.storage.StorageLevel import spark.storage.StorageLevel
sealed trait CacheTrackerMessage private[spark] sealed trait CacheTrackerMessage
case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
case class MemoryCacheLost(host: String) extends CacheTrackerMessage
case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
case object GetCacheStatus extends CacheTrackerMessage
case object GetCacheLocations extends CacheTrackerMessage
case object StopCacheTracker extends CacheTrackerMessage
class CacheTrackerActor extends Actor with Logging { private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
private[spark] case object GetCacheStatus extends CacheTrackerMessage
private[spark] case object GetCacheLocations extends CacheTrackerMessage
private[spark] case object StopCacheTracker extends CacheTrackerMessage
private[spark] class CacheTrackerActor extends Actor with Logging {
// TODO: Should probably store (String, CacheType) tuples // TODO: Should probably store (String, CacheType) tuples
private val locs = new HashMap[Int, Array[List[String]]] private val locs = new HashMap[Int, Array[List[String]]]
@ -43,8 +44,6 @@ class CacheTrackerActor extends Actor with Logging {
def receive = { def receive = {
case SlaveCacheStarted(host: String, size: Long) => case SlaveCacheStarted(host: String, size: Long) =>
logInfo("Started slave cache (size %s) on %s".format(
Utils.memoryBytesToString(size), host))
slaveCapacity.put(host, size) slaveCapacity.put(host, size)
slaveUsage.put(host, 0) slaveUsage.put(host, 0)
sender ! true sender ! true
@ -56,22 +55,12 @@ class CacheTrackerActor extends Actor with Logging {
case AddedToCache(rddId, partition, host, size) => case AddedToCache(rddId, partition, host, size) =>
slaveUsage.put(host, getCacheUsage(host) + size) slaveUsage.put(host, getCacheUsage(host) + size)
logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
Utils.memoryBytesToString(getCacheAvailable(host))))
locs(rddId)(partition) = host :: locs(rddId)(partition) locs(rddId)(partition) = host :: locs(rddId)(partition)
sender ! true sender ! true
case DroppedFromCache(rddId, partition, host, size) => case DroppedFromCache(rddId, partition, host, size) =>
logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
Utils.memoryBytesToString(getCacheAvailable(host))))
slaveUsage.put(host, getCacheUsage(host) - size) slaveUsage.put(host, getCacheUsage(host) - size)
// Do a sanity check to make sure usage is greater than 0. // Do a sanity check to make sure usage is greater than 0.
val usage = getCacheUsage(host)
if (usage < 0) {
logError("Cache usage on %s is negative (%d)".format(host, usage))
}
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
sender ! true sender ! true
@ -101,7 +90,7 @@ class CacheTrackerActor extends Actor with Logging {
} }
} }
class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager) private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
extends Logging { extends Logging {
// Tracker actor on the master, or remote reference to it on workers // Tracker actor on the master, or remote reference to it on workers
@ -151,7 +140,6 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
logInfo("Registering RDD ID " + rddId + " with cache") logInfo("Registering RDD ID " + rddId + " with cache")
registeredRddIds += rddId registeredRddIds += rddId
communicate(RegisterRDD(rddId, numPartitions)) communicate(RegisterRDD(rddId, numPartitions))
logInfo(RegisterRDD(rddId, numPartitions) + " successful")
} }
} }
} }
@ -171,7 +159,6 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
// For BlockManager.scala only // For BlockManager.scala only
def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) { def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) {
communicate(t) communicate(t)
logInfo("notifyTheCacheTrackerFromBlockManager successful")
} }
// Get a snapshot of the currently known locations // Get a snapshot of the currently known locations
@ -181,7 +168,7 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
// Gets or computes an RDD split // Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = { def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
val key = "rdd:%d:%d".format(rdd.id, split.index) val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key) logInfo("Cache key is " + key)
blockManager.get(key) match { blockManager.get(key) match {
case Some(cachedValues) => case Some(cachedValues) =>
@ -223,7 +210,7 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
logInfo("Computing partition " + split) logInfo("Computing partition " + split)
try { try {
// BlockManager will iterate over results from compute to create RDD // BlockManager will iterate over results from compute to create RDD
blockManager.put(key, rdd.compute(split), storageLevel, false) blockManager.put(key, rdd.compute(split), storageLevel, true)
//future.apply() // Wait for the reply from the cache tracker //future.apply() // Wait for the reply from the cache tracker
blockManager.get(key) match { blockManager.get(key) match {
case Some(values) => case Some(values) =>

View file

@ -1,5 +1,6 @@
package spark package spark
private[spark]
class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
override val index: Int = idx override val index: Int = idx
} }

View file

@ -9,7 +9,7 @@ import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
import org.objectweb.asm.commons.EmptyVisitor import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._ import org.objectweb.asm.Opcodes._
object ClosureCleaner extends Logging { private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it // Get an ASM class reader for a given class from the JAR that loaded it
private def getClassReader(cls: Class[_]): ClassReader = { private def getClassReader(cls: Class[_]): ClassReader = {
new ClassReader(cls.getResourceAsStream( new ClassReader(cls.getResourceAsStream(
@ -154,7 +154,7 @@ object ClosureCleaner extends Logging {
} }
} }
class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor { private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
override def visitMethod(access: Int, name: String, desc: String, override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = { sig: String, exceptions: Array[String]): MethodVisitor = {
return new EmptyVisitor { return new EmptyVisitor {
@ -180,7 +180,7 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor
} }
} }
class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor { private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
var myName: String = null var myName: String = null
override def visit(version: Int, access: Int, name: String, sig: String, override def visit(version: Int, access: Int, name: String, sig: String,

View file

@ -6,16 +6,17 @@ import java.io.ObjectInputStream
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap import scala.collection.mutable.HashMap
sealed trait CoGroupSplitDep extends Serializable private[spark] sealed trait CoGroupSplitDep extends Serializable
case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark]
class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable { class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable {
override val index: Int = idx override val index: Int = idx
override def hashCode(): Int = idx override def hashCode(): Int = idx
} }
class CoGroupAggregator private[spark] class CoGroupAggregator
extends Aggregator[Any, Any, ArrayBuffer[Any]]( extends Aggregator[Any, Any, ArrayBuffer[Any]](
{ x => ArrayBuffer(x) }, { x => ArrayBuffer(x) },
{ (b, x) => b += x }, { (b, x) => b += x },

View file

@ -0,0 +1,43 @@
package spark
private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
/**
* Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of
* this RDD computes one or more of the parent ones. Will produce exactly `maxPartitions` if the
* parent had more than this many partitions, or fewer if the parent had fewer.
*
* This transformation is useful when an RDD with many partitions gets filtered into a smaller one,
* or to avoid having a large number of small tasks when processing a directory with many files.
*/
class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
extends RDD[T](prev.context) {
@transient val splits_ : Array[Split] = {
val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) {
prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) }
} else {
(0 until maxPartitions).map { i =>
val rangeStart = (i * prevSplits.length) / maxPartitions
val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd))
}.toArray
}
}
override def splits = splits_
override def compute(split: Split): Iterator[T] = {
split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
parentSplit => prev.iterator(parentSplit)
}
}
val dependencies = List(
new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index)
}
)
}
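
CoalescedRDD assigns each output split a contiguous slice of parent splits using integer range arithmetic. The standalone sketch below (not part of the commit) shows how, for example, 10 parent partitions coalesced to 3 produce the index ranges [0,3), [3,6) and [6,10):

    object CoalesceRangesSketch {
      // Same arithmetic as CoalescedRDD's splits_ computation.
      def ranges(prevSplits: Int, maxPartitions: Int): Seq[(Int, Int)] =
        (0 until maxPartitions).map { i =>
          val rangeStart = (i * prevSplits) / maxPartitions
          val rangeEnd = ((i + 1) * prevSplits) / maxPartitions
          (rangeStart, rangeEnd) // parent split indices [rangeStart, rangeEnd)
        }

      def main(args: Array[String]) {
        println(ranges(10, 3)) // Vector((0,3), (3,6), (6,10))
      }
    }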

View file

@ -6,9 +6,13 @@ import java.util.concurrent.ThreadFactory
* A ThreadFactory that creates daemon threads * A ThreadFactory that creates daemon threads
*/ */
private object DaemonThreadFactory extends ThreadFactory { private object DaemonThreadFactory extends ThreadFactory {
override def newThread(r: Runnable): Thread = { override def newThread(r: Runnable): Thread = new DaemonThread(r)
val t = new Thread(r) }
t.setDaemon(true)
return t private class DaemonThread(r: Runnable = null) extends Thread {
override def run() {
if (r != null) {
r.run()
}
} }
} }

View file

@ -2,7 +2,7 @@ package spark
import spark.storage.BlockManagerId import spark.storage.BlockManagerId
class FetchFailedException( private[spark] class FetchFailedException(
val bmAddress: BlockManagerId, val bmAddress: BlockManagerId,
val shuffleId: Int, val shuffleId: Int,
val mapId: Int, val mapId: Int,

View file

@ -18,7 +18,7 @@ import org.apache.hadoop.util.ReflectionUtils
/** /**
* A Spark split class that wraps around a Hadoop InputSplit. * A Spark split class that wraps around a Hadoop InputSplit.
*/ */
class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit) private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
extends Split extends Split
with Serializable { with Serializable {
@ -42,7 +42,8 @@ class HadoopRDD[K, V](
minSplits: Int) minSplits: Int)
extends RDD[(K, V)](sc) { extends RDD[(K, V)](sc) {
val serializableConf = new SerializableWritable(conf) // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
@transient @transient
val splits_ : Array[Split] = { val splits_ : Array[Split] = {
@ -66,7 +67,7 @@ class HadoopRDD[K, V](
val split = theSplit.asInstanceOf[HadoopSplit] val split = theSplit.asInstanceOf[HadoopSplit]
var reader: RecordReader[K, V] = null var reader: RecordReader[K, V] = null
val conf = serializableConf.value val conf = confBroadcast.value.value
val fmt = createInputFormat(conf) val fmt = createInputFormat(conf)
reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)

View file

@ -0,0 +1,47 @@
package spark
import java.io.{File, PrintWriter}
import java.net.URL
import scala.collection.mutable.HashMap
import org.apache.hadoop.fs.FileUtil
private[spark] class HttpFileServer extends Logging {
var baseDir : File = null
var fileDir : File = null
var jarDir : File = null
var httpServer : HttpServer = null
var serverUri : String = null
def initialize() {
baseDir = Utils.createTempDir()
fileDir = new File(baseDir, "files")
jarDir = new File(baseDir, "jars")
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
httpServer = new HttpServer(baseDir)
httpServer.start()
serverUri = httpServer.uri
}
def stop() {
httpServer.stop()
}
def addFile(file: File) : String = {
addFileToDir(file, fileDir)
return serverUri + "/files/" + file.getName
}
def addJar(file: File) : String = {
addFileToDir(file, jarDir)
return serverUri + "/jars/" + file.getName
}
def addFileToDir(file: File, dir: File) : String = {
Utils.copyFile(file, new File(dir, file.getName))
return dir + "/" + file.getName
}
}

View file

@ -12,14 +12,14 @@ import org.eclipse.jetty.util.thread.QueuedThreadPool
/** /**
* Exception type thrown by HttpServer when it is in the wrong state for an operation. * Exception type thrown by HttpServer when it is in the wrong state for an operation.
*/ */
class ServerStateException(message: String) extends Exception(message) private[spark] class ServerStateException(message: String) extends Exception(message)
/** /**
* An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext * An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext
* as well as classes created by the interpreter when the user types in code. This is just a wrapper * as well as classes created by the interpreter when the user types in code. This is just a wrapper
* around a Jetty server. * around a Jetty server.
*/ */
class HttpServer(resourceBase: File) extends Logging { private[spark] class HttpServer(resourceBase: File) extends Logging {
private var server: Server = null private var server: Server = null
private var port: Int = -1 private var port: Int = -1

View file

@ -5,14 +5,14 @@ import java.nio.ByteBuffer
import spark.util.ByteBufferInputStream import spark.util.ByteBufferInputStream
class JavaSerializationStream(out: OutputStream) extends SerializationStream { private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream {
val objOut = new ObjectOutputStream(out) val objOut = new ObjectOutputStream(out)
def writeObject[T](t: T) { objOut.writeObject(t) } def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this }
def flush() { objOut.flush() } def flush() { objOut.flush() }
def close() { objOut.close() } def close() { objOut.close() }
} }
class JavaDeserializationStream(in: InputStream, loader: ClassLoader) private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
extends DeserializationStream { extends DeserializationStream {
val objIn = new ObjectInputStream(in) { val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) = override def resolveClass(desc: ObjectStreamClass) =
@ -23,7 +23,7 @@ extends DeserializationStream {
def close() { objIn.close() } def close() { objIn.close() }
} }
class JavaSerializerInstance extends SerializerInstance { private[spark] class JavaSerializerInstance extends SerializerInstance {
def serialize[T](t: T): ByteBuffer = { def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream() val bos = new ByteArrayOutputStream()
val out = serializeStream(bos) val out = serializeStream(bos)
@ -57,6 +57,6 @@ class JavaSerializerInstance extends SerializerInstance {
} }
} }
class JavaSerializer extends Serializer { private[spark] class JavaSerializer extends Serializer {
def newInstance(): SerializerInstance = new JavaSerializerInstance def newInstance(): SerializerInstance = new JavaSerializerInstance
} }

View file

@ -10,15 +10,17 @@ import scala.collection.mutable
import com.esotericsoftware.kryo._ import com.esotericsoftware.kryo._
import com.esotericsoftware.kryo.{Serializer => KSerializer} import com.esotericsoftware.kryo.{Serializer => KSerializer}
import com.esotericsoftware.kryo.serialize.ClassSerializer import com.esotericsoftware.kryo.serialize.ClassSerializer
import com.esotericsoftware.kryo.serialize.SerializableSerializer
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
import spark.broadcast._
import spark.storage._ import spark.storage._
/** /**
* Zig-zag encoder used to write object sizes to serialization streams. * Zig-zag encoder used to write object sizes to serialization streams.
* Based on Kryo's integer encoder. * Based on Kryo's integer encoder.
*/ */
object ZigZag { private[spark] object ZigZag {
def writeInt(n: Int, out: OutputStream) { def writeInt(n: Int, out: OutputStream) {
var value = n var value = n
if ((value & ~0x7F) == 0) { if ((value & ~0x7F) == 0) {
@ -66,22 +68,25 @@ object ZigZag {
} }
} }
private[spark]
class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream) class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
extends SerializationStream { extends SerializationStream {
val channel = Channels.newChannel(out) val channel = Channels.newChannel(out)
def writeObject[T](t: T) { def writeObject[T](t: T): SerializationStream = {
kryo.writeClassAndObject(threadBuffer, t) kryo.writeClassAndObject(threadBuffer, t)
ZigZag.writeInt(threadBuffer.position(), out) ZigZag.writeInt(threadBuffer.position(), out)
threadBuffer.flip() threadBuffer.flip()
channel.write(threadBuffer) channel.write(threadBuffer)
threadBuffer.clear() threadBuffer.clear()
this
} }
def flush() { out.flush() } def flush() { out.flush() }
def close() { out.close() } def close() { out.close() }
} }
private[spark]
class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream) class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
extends DeserializationStream { extends DeserializationStream {
def readObject[T](): T = { def readObject[T](): T = {
@ -92,7 +97,7 @@ extends DeserializationStream {
def close() { in.close() } def close() { in.close() }
} }
class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
val kryo = ks.kryo val kryo = ks.kryo
val threadBuffer = ks.threadBuffer.get() val threadBuffer = ks.threadBuffer.get()
val objectBuffer = ks.objectBuffer.get() val objectBuffer = ks.objectBuffer.get()
@ -159,7 +164,9 @@ trait KryoRegistrator {
} }
class KryoSerializer extends Serializer with Logging { class KryoSerializer extends Serializer with Logging {
val kryo = createKryo() // Make this lazy so that it only gets called once we receive our first task on each executor,
// so we can pull out any custom Kryo registrator from the user's JARs.
lazy val kryo = createKryo()
val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024 val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
@ -190,8 +197,8 @@ class KryoSerializer extends Serializer with Logging {
(1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1), (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1),
None, None,
ByteBuffer.allocate(1), ByteBuffer.allocate(1),
StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY,
PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER), PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
GotBlock("1", ByteBuffer.allocate(1)), GotBlock("1", ByteBuffer.allocate(1)),
GetBlock("1") GetBlock("1")
) )
@ -203,6 +210,10 @@ class KryoSerializer extends Serializer with Logging {
kryo.register(classOf[Class[_]], new ClassSerializer(kryo)) kryo.register(classOf[Class[_]], new ClassSerializer(kryo))
kryo.setRegistrationOptional(true) kryo.setRegistrationOptional(true)
// Allow sending SerializableWritable
kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer())
kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer())
// Register some commonly used Scala singleton objects. Because these // Register some commonly used Scala singleton objects. Because these
// are singletons, we must return the exact same local object when we // are singletons, we must return the exact same local object when we
// deserialize rather than returning a clone as FieldSerializer would. // deserialize rather than returning a clone as FieldSerializer would.
@ -250,7 +261,8 @@ class KryoSerializer extends Serializer with Logging {
val regCls = System.getProperty("spark.kryo.registrator") val regCls = System.getProperty("spark.kryo.registrator")
if (regCls != null) { if (regCls != null) {
logInfo("Running user registrator: " + regCls) logInfo("Running user registrator: " + regCls)
val reg = Class.forName(regCls).newInstance().asInstanceOf[KryoRegistrator] val classLoader = Thread.currentThread.getContextClassLoader
val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]
reg.registerClasses(kryo) reg.registerClasses(kryo)
} }
kryo kryo
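
Because `kryo` is now created lazily and the registrator class is loaded through the thread's context class loader, a registrator shipped in a user JAR can be picked up on executors. Below is a hedged sketch of such a registrator; `MyPoint` and `MyRegistrator` are hypothetical user classes, not part of the commit:

    import com.esotericsoftware.kryo.Kryo
    import spark.KryoRegistrator

    // Hypothetical user class to be serialized with Kryo.
    case class MyPoint(x: Double, y: Double)

    class MyRegistrator extends KryoRegistrator {
      def registerClasses(kryo: Kryo) {
        kryo.register(classOf[MyPoint])
      }
    }

    // Selected via the system property read in createKryo():
    //   System.setProperty("spark.kryo.registrator", "MyRegistrator")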

View file

@ -1,5 +1,6 @@
package spark package spark
import java.io.{DataInputStream, DataOutputStream, ByteArrayOutputStream, ByteArrayInputStream}
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import akka.actor._ import akka.actor._
@ -10,20 +11,20 @@ import akka.util.Duration
import akka.util.Timeout import akka.util.Timeout
import akka.util.duration._ import akka.util.duration._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet import scala.collection.mutable.HashSet
import spark.storage.BlockManagerId import spark.storage.BlockManagerId
sealed trait MapOutputTrackerMessage private[spark] sealed trait MapOutputTrackerMessage
case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
case object StopMapOutputTracker extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]]) private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
extends Actor with Logging {
def receive = { def receive = {
case GetMapOutputLocations(shuffleId: Int) => case GetMapOutputLocations(shuffleId: Int) =>
logInfo("Asked to get map output locations for shuffle " + shuffleId) logInfo("Asked to get map output locations for shuffle " + shuffleId)
sender ! bmAddresses.get(shuffleId) sender ! tracker.getSerializedLocations(shuffleId)
case StopMapOutputTracker => case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!") logInfo("MapOutputTrackerActor stopped!")
@ -32,22 +33,26 @@ extends Actor with Logging {
} }
} }
class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging { private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
val ip: String = System.getProperty("spark.master.host", "localhost") val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "MapOutputTracker" val actorName: String = "MapOutputTracker"
val timeout = 10.seconds val timeout = 10.seconds
private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
// Incremented every time a fetch fails so that client nodes know to clear // Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens. // their cache of map output locations if this happens.
private var generation: Long = 0 private var generation: Long = 0
private var generationLock = new java.lang.Object private var generationLock = new java.lang.Object
// Cache a serialized version of the output locations for each shuffle to send them out faster
var cacheGeneration = generation
val cachedSerializedLocs = new HashMap[Int, Array[Byte]]
var trackerActor: ActorRef = if (isMaster) { var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(bmAddresses)), name = actorName) val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
logInfo("Registered MapOutputTrackerActor actor") logInfo("Registered MapOutputTrackerActor actor")
actor actor
} else { } else {
@ -134,15 +139,16 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
} }
// We won the race to fetch the output locs; do so // We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor) logInfo("Doing the fetch; tracker actor = " + trackerActor)
val fetched = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[BlockManagerId]] val fetchedBytes = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[Byte]]
val fetchedLocs = deserializeLocations(fetchedBytes)
logInfo("Got the output locations") logInfo("Got the output locations")
bmAddresses.put(shuffleId, fetched) bmAddresses.put(shuffleId, fetchedLocs)
fetching.synchronized { fetching.synchronized {
fetching -= shuffleId fetching -= shuffleId
fetching.notifyAll() fetching.notifyAll()
} }
return fetched return fetchedLocs
} else { } else {
return locs return locs
} }
@ -181,4 +187,70 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
} }
} }
} }
def getSerializedLocations(shuffleId: Int): Array[Byte] = {
var locs: Array[BlockManagerId] = null
var generationGotten: Long = -1
generationLock.synchronized {
if (generation > cacheGeneration) {
cachedSerializedLocs.clear()
cacheGeneration = generation
}
cachedSerializedLocs.get(shuffleId) match {
case Some(bytes) =>
return bytes
case None =>
locs = bmAddresses.get(shuffleId)
generationGotten = generation
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
val bytes = serializeLocations(locs)
// Add them into the table only if the generation hasn't changed while we were working
generationLock.synchronized {
if (generation == generationGotten) {
cachedSerializedLocs(shuffleId) = bytes
}
}
return bytes
}
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by grouping together the locations by block manager ID.
def serializeLocations(locs: Array[BlockManagerId]): Array[Byte] = {
val out = new ByteArrayOutputStream
val dataOut = new DataOutputStream(out)
dataOut.writeInt(locs.length)
val grouped = locs.zipWithIndex.groupBy(_._1)
dataOut.writeInt(grouped.size)
for ((id, pairs) <- grouped if id != null) {
dataOut.writeUTF(id.ip)
dataOut.writeInt(id.port)
dataOut.writeInt(pairs.length)
for ((_, blockIndex) <- pairs) {
dataOut.writeInt(blockIndex)
}
}
dataOut.close()
out.toByteArray
}
// Opposite of serializeLocations.
def deserializeLocations(bytes: Array[Byte]): Array[BlockManagerId] = {
val dataIn = new DataInputStream(new ByteArrayInputStream(bytes))
val length = dataIn.readInt()
val array = new Array[BlockManagerId](length)
val numGroups = dataIn.readInt()
for (i <- 0 until numGroups) {
val ip = dataIn.readUTF()
val port = dataIn.readInt()
val id = new BlockManagerId(ip, port)
val numBlocks = dataIn.readInt()
for (j <- 0 until numBlocks) {
array(dataIn.readInt()) = id
}
}
array
}
} }
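
getSerializedLocations caches a compact byte encoding of each shuffle's output locations: the total count, the number of distinct block managers, then each manager's address followed by the map indices it holds. The self-contained sketch below (not from the commit) round-trips that framing using plain (ip, port) pairs in place of BlockManagerId:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    object MapOutputLocSketch {
      type Id = (String, Int) // stand-in for BlockManagerId: (ip, port)

      def serialize(locs: Array[Id]): Array[Byte] = {
        val out = new ByteArrayOutputStream
        val dataOut = new DataOutputStream(out)
        dataOut.writeInt(locs.length)          // total number of map outputs
        val grouped = locs.zipWithIndex.groupBy(_._1)
        dataOut.writeInt(grouped.size)         // number of distinct managers
        for ((id, pairs) <- grouped) {
          dataOut.writeUTF(id._1)              // ip
          dataOut.writeInt(id._2)              // port
          dataOut.writeInt(pairs.length)       // how many map outputs it holds
          for ((_, blockIndex) <- pairs) dataOut.writeInt(blockIndex)
        }
        dataOut.close()
        out.toByteArray
      }

      def deserialize(bytes: Array[Byte]): Array[Id] = {
        val dataIn = new DataInputStream(new ByteArrayInputStream(bytes))
        val array = new Array[Id](dataIn.readInt())
        for (i <- 0 until dataIn.readInt()) {
          val id = (dataIn.readUTF(), dataIn.readInt())
          for (j <- 0 until dataIn.readInt()) array(dataIn.readInt()) = id
        }
        array
      }

      def main(args: Array[String]) {
        val locs = Array(("host1", 7077), ("host2", 7077), ("host1", 7077))
        println(deserialize(serialize(locs)).toSeq) // same three locations back
      }
    }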

View file

@ -13,6 +13,7 @@ import org.apache.hadoop.mapreduce.TaskAttemptID
import java.util.Date import java.util.Date
import java.text.SimpleDateFormat import java.text.SimpleDateFormat
private[spark]
class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable) class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
extends Split { extends Split {
@ -28,7 +29,9 @@ class NewHadoopRDD[K, V](
@transient conf: Configuration) @transient conf: Configuration)
extends RDD[(K, V)](sc) { extends RDD[(K, V)](sc) {
private val serializableConf = new SerializableWritable(conf) // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
// private val serializableConf = new SerializableWritable(conf)
private val jobtrackerId: String = { private val jobtrackerId: String = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm") val formatter = new SimpleDateFormat("yyyyMMddHHmm")
@ -41,7 +44,7 @@ class NewHadoopRDD[K, V](
@transient @transient
private val splits_ : Array[Split] = { private val splits_ : Array[Split] = {
val inputFormat = inputFormatClass.newInstance val inputFormat = inputFormatClass.newInstance
val jobContext = new JobContext(serializableConf.value, jobId) val jobContext = new JobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Split](rawSplits.size) val result = new Array[Split](rawSplits.size)
for (i <- 0 until rawSplits.size) { for (i <- 0 until rawSplits.size) {
@ -54,9 +57,9 @@ class NewHadoopRDD[K, V](
override def compute(theSplit: Split) = new Iterator[(K, V)] { override def compute(theSplit: Split) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopSplit] val split = theSplit.asInstanceOf[NewHadoopSplit]
val conf = serializableConf.value val conf = confBroadcast.value.value
val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
val context = new TaskAttemptContext(serializableConf.value, attemptId) val context = new TaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance val format = inputFormatClass.newInstance
val reader = format.createRecordReader(split.serializableHadoopSplit.value, context) val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
reader.initialize(split.serializableHadoopSplit.value, context) reader.initialize(split.serializableHadoopSplit.value, context)

View file

@ -1,11 +1,10 @@
package spark package spark
import java.io.EOFException import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream import java.io.ObjectInputStream
import java.net.URL
import java.util.{Date, HashMap => JHashMap}
import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.atomic.AtomicLong
import java.util.{HashMap => JHashMap}
import java.util.Date
import java.text.SimpleDateFormat import java.text.SimpleDateFormat
import scala.collection.Map import scala.collection.Map
@ -50,9 +49,18 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def combineByKey[C](createCombiner: V => C, def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C, mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C, mergeCombiners: (C, C) => C,
partitioner: Partitioner): RDD[(K, C)] = { partitioner: Partitioner,
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) mapSideCombine: Boolean = true): RDD[(K, C)] = {
new ShuffledRDD(self, aggregator, partitioner) val aggregator =
if (mapSideCombine) {
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
new Aggregator[K, V, C](createCombiner, mergeValue, null, false)
}
new ShuffledAggregatedRDD(self, aggregator, partitioner)
} }
def combineByKey[C](createCombiner: V => C, def combineByKey[C](createCombiner: V => C,
@ -116,13 +124,24 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
groupByKey(new HashPartitioner(numSplits)) groupByKey(new HashPartitioner(numSplits))
} }
def partitionBy(partitioner: Partitioner): RDD[(K, V)] = { /**
def createCombiner(v: V) = ArrayBuffer(v) * Repartition the RDD using the specified partitioner. If mapSideCombine is
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v * true, Spark will group values of the same key together on the map side
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2 * before the repartitioning. If a large number of duplicated keys are
val bufs = combineByKey[ArrayBuffer[V]]( * expected, and the size of the keys are large, mapSideCombine should be set
createCombiner _, mergeValue _, mergeCombiners _, partitioner) * to true.
bufs.flatMapValues(buf => buf) */
def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = {
if (mapSideCombine) {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
createCombiner _, mergeValue _, mergeCombiners _, partitioner)
bufs.flatMapValues(buf => buf)
} else {
new RepartitionShuffledRDD(self, partitioner)
}
} }
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
@ -416,22 +435,8 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
extends Logging extends Logging
with Serializable { with Serializable {
def sortByKey(ascending: Boolean = true): RDD[(K,V)] = { def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = {
val rangePartitionedRDD = self.partitionBy(new RangePartitioner(self.splits.size, self, ascending)) new ShuffledSortedRDD(self, ascending, numSplits)
new SortedRDD(rangePartitionedRDD, ascending)
}
}
class SortedRDD[K <% Ordered[K], V](prev: RDD[(K, V)], ascending: Boolean)
extends RDD[(K, V)](prev.context) {
override def splits = prev.splits
override val partitioner = prev.partitioner
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = {
prev.iterator(split).toArray
.sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1).iterator
} }
} }
@ -454,6 +459,6 @@ class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]
} }
} }
object Manifests { private[spark] object Manifests {
val seqSeqManifest = classManifest[Seq[Seq[_]]] val seqSeqManifest = classManifest[Seq[Seq[_]]]
} }
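
A hedged usage sketch of the API changes above: the `mapSideCombine` flag on `combineByKey`/`partitionBy` and the `numSplits` argument on `sortByKey`. It assumes a local SparkContext and the usual implicit conversions from `spark.SparkContext._`, and has not been verified against this exact revision:

    import spark.SparkContext
    import spark.SparkContext._

    object PairOpsSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "PairOpsSketch")
        val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

        // Per-key sums with the default map-side combine.
        val sums = pairs.combineByKey[Int](v => v, _ + _, _ + _, new spark.HashPartitioner(2))

        // Repartition without grouping values on the map side (the new default).
        val repartitioned = pairs.partitionBy(new spark.HashPartitioner(2))

        // Sort with an explicit number of output splits.
        val sorted = pairs.sortByKey(ascending = true, numSplits = 2)

        println((sums.collect.toSeq, repartitioned.collect.toSeq, sorted.collect.toSeq))
        sc.stop()
      }
    }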

View file

@ -3,7 +3,7 @@ package spark
import scala.collection.immutable.NumericRange import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
class ParallelCollectionSplit[T: ClassManifest]( private[spark] class ParallelCollectionSplit[T: ClassManifest](
val rddId: Long, val rddId: Long,
val slice: Int, val slice: Int,
values: Seq[T]) values: Seq[T])

View file

@ -41,9 +41,9 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
Array() Array()
} else { } else {
val rddSize = rdd.count() val rddSize = rdd.count()
val maxSampleSize = partitions * 10.0 val maxSampleSize = partitions * 20.0
val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0) val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _) val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _)
if (rddSample.length == 0) { if (rddSample.length == 0) {
Array() Array()
} else { } else {

View file

@ -61,6 +61,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def compute(split: Split): Iterator[T] def compute(split: Split): Iterator[T]
@transient val dependencies: List[Dependency[_]] @transient val dependencies: List[Dependency[_]]
// Record user function generating this RDD
val origin = Utils.getSparkCallSite
// Optionally overridden by subclasses to specify how they are partitioned // Optionally overridden by subclasses to specify how they are partitioned
val partitioner: Option[Partitioner] = None val partitioner: Option[Partitioner] = None
@ -69,6 +72,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def context = sc def context = sc
def elementClassManifest: ClassManifest[T] = classManifest[T]
// Get a unique ID for this RDD // Get a unique ID for this RDD
val id = sc.newRddId() val id = sc.newRddId()
@ -87,21 +92,21 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
} }
// Turn on the default caching level for this RDD // Turn on the default caching level for this RDD
def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER) def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
// Turn on the default caching level for this RDD // Turn on the default caching level for this RDD
def cache(): RDD[T] = persist() def cache(): RDD[T] = persist()
def getStorageLevel = storageLevel def getStorageLevel = storageLevel
def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = { private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
if (!level.useDisk && level.replication < 2) { if (!level.useDisk && level.replication < 2) {
throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
} }
// This is a hack. Ideally this should re-use the code used by the CacheTracker // This is a hack. Ideally this should re-use the code used by the CacheTracker
// to generate the key. // to generate the key.
def getSplitKey(split: Split) = "rdd:%d:%d".format(this.id, split.index) def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
persist(level) persist(level)
sc.runJob(this, (iter: Iterator[T]) => {} ) sc.runJob(this, (iter: Iterator[T]) => {} )
@ -131,7 +136,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f)) def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f))
def distinct(): RDD[T] = map(x => (x, "")).reduceByKey((x, y) => x).map(_._1) def distinct(numSplits: Int = splits.size): RDD[T] =
map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1)
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
new SampledRDD(this, withReplacement, fraction, seed) new SampledRDD(this, withReplacement, fraction, seed)
@ -143,8 +149,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
var initialCount = count() var initialCount = count()
var maxSelected = 0 var maxSelected = 0
if (initialCount > Integer.MAX_VALUE) { if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE maxSelected = Integer.MAX_VALUE - 1
} else { } else {
maxSelected = initialCount.toInt maxSelected = initialCount.toInt
} }
@ -159,15 +165,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
total = num total = num
} }
var samples = this.sample(withReplacement, fraction, seed).collect() val rand = new Random(seed)
var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
while (samples.length < total) { while (samples.length < total) {
samples = this.sample(withReplacement, fraction, seed).collect() samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
} }
val arr = samples.take(total) Utils.randomizeInPlace(samples, rand).take(total)
return arr
} }
def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other)) def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
@ -195,6 +200,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] = def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f)) new MapPartitionsRDD(this, sc.clean(f))
def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
new MapPartitionsWithSplitRDD(this, sc.clean(f))
// Actions (launch a job to return a value to the user program) // Actions (launch a job to return a value to the user program)
def foreach(f: T => Unit) { def foreach(f: T => Unit) {
@ -416,3 +424,18 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
override val dependencies = List(new OneToOneDependency(prev)) override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(prev.iterator(split)) override def compute(split: Split) = f(prev.iterator(split))
} }
/**
* A variant of the MapPartitionsRDD that passes the split index into the
* closure. This can be used to generate or collect partition specific
* information such as the number of tuples in a partition.
*/
class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: (Int, Iterator[T]) => Iterator[U])
extends RDD[U](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(split.index, prev.iterator(split))
}
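
A short usage sketch of `mapPartitionsWithSplit`, matching the use case described in the comment above (counting tuples per partition); it assumes a local SparkContext and is not part of the commit:

    import spark.SparkContext

    object PartitionCountsSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "PartitionCountsSketch")
        val data = sc.parallelize(1 to 100, 4)
        // Emit (split index, number of elements in that split).
        val counts = data.mapPartitionsWithSplit { (index, iter) =>
          Iterator((index, iter.size))
        }
        println(counts.collect.toSeq) // e.g. (0,25), (1,25), (2,25), (3,25)
        sc.stop()
      }
    }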

View file

@ -1,7 +1,10 @@
package spark package spark
import java.util.Random import java.util.Random
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
private[spark]
class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
override val index: Int = prev.index override val index: Int = prev.index
} }
@ -28,19 +31,21 @@ class SampledRDD[T: ClassManifest](
override def compute(splitIn: Split) = { override def compute(splitIn: Split) = {
val split = splitIn.asInstanceOf[SampledRDDSplit] val split = splitIn.asInstanceOf[SampledRDDSplit]
val rg = new Random(split.seed)
// Sampling with replacement (TODO: use reservoir sampling to make this more efficient?)
if (withReplacement) { if (withReplacement) {
val oldData = prev.iterator(split.prev).toArray // For large datasets, the expected number of occurrences of each element in a sample with
val sampleSize = (oldData.size * frac).ceil.toInt // replacement is Poisson(frac). We use that to get a count for each element.
val sampledData = { val poisson = new Poisson(frac, new DRand(split.seed))
// all of oldData's indices are candidates, even if sampleSize < oldData.size prev.iterator(split.prev).flatMap { element =>
for (i <- 1 to sampleSize) val count = poisson.nextInt()
yield oldData(rg.nextInt(oldData.size)) if (count == 0) {
Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
} else {
Iterator.fill(count)(element)
}
} }
sampledData.iterator
} else { // Sampling without replacement } else { // Sampling without replacement
prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac)) val rand = new Random(split.seed)
prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac))
} }
} }
} }
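The Poisson trick used above can be sketched on plain Scala collections; the fraction, seed and data below are invented, and the Colt classes are the same ones this file now imports:

    import cern.jet.random.Poisson
    import cern.jet.random.engine.DRand

    val frac = 0.1
    val poisson = new Poisson(frac, new DRand(42))
    // Each element is emitted Poisson(frac) times, so the expected sample size is frac * 20
    val sampled = (1 to 20).flatMap(x => Seq.fill(poisson.nextInt())(x))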

@ -12,14 +12,14 @@ import spark.util.ByteBufferInputStream
* A serializer. Because some serialization libraries are not thread safe, this class is used to * A serializer. Because some serialization libraries are not thread safe, this class is used to
* create SerializerInstances that do the actual serialization. * create SerializerInstances that do the actual serialization.
*/ */
trait Serializer { private[spark] trait Serializer {
def newInstance(): SerializerInstance def newInstance(): SerializerInstance
} }
/** /**
* An instance of the serializer, for use by one thread at a time. * An instance of the serializer, for use by one thread at a time.
*/ */
trait SerializerInstance { private[spark] trait SerializerInstance {
def serialize[T](t: T): ByteBuffer def serialize[T](t: T): ByteBuffer
def deserialize[T](bytes: ByteBuffer): T def deserialize[T](bytes: ByteBuffer): T
@ -43,15 +43,15 @@ trait SerializerInstance {
def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
// Default implementation uses deserializeStream // Default implementation uses deserializeStream
buffer.rewind() buffer.rewind()
deserializeStream(new ByteBufferInputStream(buffer)).toIterator deserializeStream(new ByteBufferInputStream(buffer)).asIterator
} }
} }
/** /**
* A stream for writing serialized objects. * A stream for writing serialized objects.
*/ */
trait SerializationStream { private[spark] trait SerializationStream {
def writeObject[T](t: T): Unit def writeObject[T](t: T): SerializationStream
def flush(): Unit def flush(): Unit
def close(): Unit def close(): Unit
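Because writeObject now returns the stream itself, writes can be chained; a hypothetical sketch using the Java serializer (the output file name is invented):

    val ser = new JavaSerializer().newInstance()
    val out = ser.serializeStream(new java.io.FileOutputStream("objects.bin"))
    out.writeObject("a").writeObject("b").flush()
    out.close()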
@ -66,7 +66,7 @@ trait SerializationStream {
/** /**
* A stream for reading serialized objects. * A stream for reading serialized objects.
*/ */
trait DeserializationStream { private[spark] trait DeserializationStream {
def readObject[T](): T def readObject[T](): T
def close(): Unit def close(): Unit
@ -74,7 +74,7 @@ trait DeserializationStream {
* Read the elements of this stream through an iterator. This can only be called once, as * Read the elements of this stream through an iterator. This can only be called once, as
* reading each element will consume data from the input source. * reading each element will consume data from the input source.
*/ */
def toIterator: Iterator[Any] = new Iterator[Any] { def asIterator: Iterator[Any] = new Iterator[Any] {
var gotNext = false var gotNext = false
var finished = false var finished = false
var nextValue: Any = null var nextValue: Any = null

@ -1,6 +1,6 @@
package spark package spark
abstract class ShuffleFetcher { private[spark] abstract class ShuffleFetcher {
// Fetch the shuffle outputs for a given ShuffleDependency, calling func exactly // Fetch the shuffle outputs for a given ShuffleDependency, calling func exactly
// once on each key-value pair obtained. // once on each key-value pair obtained.
def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit)

@ -1,98 +0,0 @@
package spark
import java.io._
import java.net.URL
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.{ArrayBuffer, HashMap}
import spark._
class ShuffleManager extends Logging {
private var nextShuffleId = new AtomicLong(0)
private var shuffleDir: File = null
private var server: HttpServer = null
private var serverUri: String = null
initialize()
private def initialize() {
// TODO: localDir should be created by some mechanism common to Spark
// so that it can be shared among shuffle, broadcast, etc
val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
var tries = 0
var foundLocalDir = false
var localDir: File = null
var localDirUuid: UUID = null
while (!foundLocalDir && tries < 10) {
tries += 1
try {
localDirUuid = UUID.randomUUID
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
if (!localDir.exists) {
localDir.mkdirs()
foundLocalDir = true
}
} catch {
case e: Exception =>
logWarning("Attempt " + tries + " to create local dir failed", e)
}
}
if (!foundLocalDir) {
logError("Failed 10 attempts to create local dir in " + localDirRoot)
System.exit(1)
}
shuffleDir = new File(localDir, "shuffle")
shuffleDir.mkdirs()
logInfo("Shuffle dir: " + shuffleDir)
// Add a shutdown hook to delete the local dir
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dir") {
override def run() {
Utils.deleteRecursively(localDir)
}
})
val extServerPort = System.getProperty(
"spark.localFileShuffle.external.server.port", "-1").toInt
if (extServerPort != -1) {
// We're using an external HTTP server; set URI relative to its root
var extServerPath = System.getProperty(
"spark.localFileShuffle.external.server.path", "")
if (extServerPath != "" && !extServerPath.endsWith("/")) {
extServerPath += "/"
}
serverUri = "http://%s:%d/%s/spark-local-%s".format(
Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
} else {
// Create our own server
server = new HttpServer(localDir)
server.start()
serverUri = server.uri
}
logInfo("Local URI: " + serverUri)
}
def stop() {
if (server != null) {
server.stop()
}
}
def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
val dir = new File(shuffleDir, shuffleId + "/" + inputId)
dir.mkdirs()
val file = new File(dir, "" + outputId)
return file
}
def getServerUri(): String = {
serverUri
}
def newShuffleId(): Long = {
nextShuffleId.getAndIncrement()
}
}

@ -1,18 +1,24 @@
package spark package spark
import scala.collection.mutable.ArrayBuffer
import java.util.{HashMap => JHashMap} import java.util.{HashMap => JHashMap}
class ShuffledRDDSplit(val idx: Int) extends Split {
private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx override val index = idx
override def hashCode(): Int = idx override def hashCode(): Int = idx
} }
class ShuffledRDD[K, V, C](
/**
* The resulting RDD from a shuffle (e.g. repartitioning of data).
*/
abstract class ShuffledRDD[K, V, C](
@transient parent: RDD[(K, V)], @transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C], aggregator: Aggregator[K, V, C],
part : Partitioner) part: Partitioner)
extends RDD[(K, C)](parent.context) { extends RDD[(K, C)](parent.context) {
//override val partitioner = Some(part)
override val partitioner = Some(part) override val partitioner = Some(part)
@transient @transient
@ -24,6 +30,65 @@ class ShuffledRDD[K, V, C](
val dep = new ShuffleDependency(context.newShuffleId, parent, aggregator, part) val dep = new ShuffleDependency(context.newShuffleId, parent, aggregator, part)
override val dependencies = List(dep) override val dependencies = List(dep)
}
/**
* Repartition a key-value pair RDD.
*/
class RepartitionShuffledRDD[K, V](
@transient parent: RDD[(K, V)],
part: Partitioner)
extends ShuffledRDD[K, V, V](
parent,
Aggregator[K, V, V](null, null, null, false),
part) {
override def compute(split: Split): Iterator[(K, V)] = {
val buf = new ArrayBuffer[(K, V)]
val fetcher = SparkEnv.get.shuffleFetcher
def addTupleToBuffer(k: K, v: V) = { buf += Tuple(k, v) }
fetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer)
buf.iterator
}
}
/**
* A sort-based shuffle (that doesn't apply aggregation). It does so by first
* repartitioning the RDD by range, and then sorting within each range.
*/
class ShuffledSortedRDD[K <% Ordered[K]: ClassManifest, V](
@transient parent: RDD[(K, V)],
ascending: Boolean,
numSplits: Int)
extends RepartitionShuffledRDD[K, V](
parent,
new RangePartitioner(numSplits, parent, ascending)) {
override def compute(split: Split): Iterator[(K, V)] = {
// By separating this from RepartitionShuffledRDD, we avoided a
// buf.iterator.toArray call, thus avoiding building up the buffer twice.
val buf = new ArrayBuffer[(K, V)]
def addTupleToBuffer(k: K, v: V) { buf += ((k, v)) }
SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer)
if (ascending) {
buf.sortWith((x, y) => x._1 < y._1).iterator
} else {
buf.sortWith((x, y) => x._1 > y._1).iterator
}
}
}
/**
* The resulting RDD from shuffle and running (hash-based) aggregation.
*/
class ShuffledAggregatedRDD[K, V, C](
@transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
part : Partitioner)
extends ShuffledRDD[K, V, C](parent, aggregator, part) {
override def compute(split: Split): Iterator[(K, C)] = { override def compute(split: Split): Iterator[(K, C)] = {
val combiners = new JHashMap[K, C] val combiners = new JHashMap[K, C]
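These shuffle RDDs are presumably reached through the pair-RDD wrappers rather than constructed directly; for example, a range-partitioned sort along the lines of ShuffledSortedRDD would look roughly like this (data invented, and it assumes the usual sortByKey wrapper is wired to this class):

    val pairs = sc.parallelize(Seq(("b", 2), ("a", 1), ("c", 3)))
    val sorted = pairs.sortByKey(true, 2)   // range-partition into 2 splits, then sort each split
    sorted.collect()                        // Array((a,1), (b,2), (c,3))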

@ -22,7 +22,7 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet
* Based on the following JavaWorld article: * Based on the following JavaWorld article:
* http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html
*/ */
object SizeEstimator extends Logging { private[spark] object SizeEstimator extends Logging {
// Sizes of primitive types // Sizes of primitive types
private val BYTE_SIZE = 1 private val BYTE_SIZE = 1
@ -77,22 +77,18 @@ object SizeEstimator extends Logging {
return System.getProperty("spark.test.useCompressedOops").toBoolean return System.getProperty("spark.test.useCompressedOops").toBoolean
} }
try { try {
val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"; val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"
val server = ManagementFactory.getPlatformMBeanServer(); val server = ManagementFactory.getPlatformMBeanServer()
val bean = ManagementFactory.newPlatformMXBeanProxy(server, val bean = ManagementFactory.newPlatformMXBeanProxy(server,
hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean]); hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean])
return bean.getVMOption("UseCompressedOops").getValue.toBoolean return bean.getVMOption("UseCompressedOops").getValue.toBoolean
} catch { } catch {
case e: IllegalArgumentException => { case e: Exception => {
logWarning("Exception while trying to check if compressed oops is enabled", e) // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB
// Fall back to checking if maxMemory < 32GB val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
return Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) val guessInWords = if (guess) "yes" else "not"
} logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords)
return guess
case e: SecurityException => {
logWarning("No permission to create MBeanServer", e)
// Fall back to checking if maxMemory < 32GB
return Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
} }
} }
} }
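The fallback guess is simple enough to state on its own; a sketch of just the heuristic:

    // Heaps of 32 GB or more cannot use compressed oops, so guess from the max heap size
    val probablyCompressed = Runtime.getRuntime.maxMemory < 32L * 1024 * 1024 * 1024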
@ -146,6 +142,10 @@ object SizeEstimator extends Logging {
val cls = obj.getClass val cls = obj.getClass
if (cls.isArray) { if (cls.isArray) {
visitArray(obj, cls, state) visitArray(obj, cls, state)
} else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) {
// Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses
// the size estimator since it references the whole REPL. Do nothing in this case. In
// general all ClassLoaders and Classes will be shared between objects anyway.
} else { } else {
val classInfo = getClassInfo(cls) val classInfo = getClassInfo(cls)
state.size += classInfo.shellSize state.size += classInfo.shellSize

@ -5,7 +5,7 @@ import com.google.common.collect.MapMaker
/** /**
* An implementation of Cache that uses soft references. * An implementation of Cache that uses soft references.
*/ */
class SoftReferenceCache extends Cache { private[spark] class SoftReferenceCache extends Cache {
val map = new MapMaker().softValues().makeMap[Any, Any]() val map = new MapMaker().softValues().makeMap[Any, Any]()
override def get(datasetId: Any, partition: Int): Any = override def get(datasetId: Any, partition: Int): Any =

@ -2,13 +2,15 @@ package spark
import java.io._ import java.io._
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader}
import akka.actor.Actor import akka.actor.Actor
import akka.actor.Actor._ import akka.actor.Actor._
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.generic.Growable
import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat import org.apache.hadoop.mapred.SequenceFileInputFormat
@ -34,6 +36,8 @@ import org.apache.mesos.{Scheduler, MesosNativeLibrary}
import spark.broadcast._ import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult import spark.partial.PartialResult
@ -75,24 +79,33 @@ class SparkContext(
isLocal) isLocal)
SparkEnv.set(env) SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
val addedFiles = HashMap[String, Long]()
val addedJars = HashMap[String, Long]()
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
// Create and start the scheduler // Create and start the scheduler
private var taskScheduler: TaskScheduler = { private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format // Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks // Regular expression for local[N, maxRetries], used in tests with failing tasks
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters // Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r val SPARK_REGEX = """(spark://.*)""".r
master match { master match {
case "local" => case "local" =>
new LocalScheduler(1, 0) new LocalScheduler(1, 0, this)
case LOCAL_N_REGEX(threads) => case LOCAL_N_REGEX(threads) =>
new LocalScheduler(threads.toInt, 0) new LocalScheduler(threads.toInt, 0, this)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
new LocalScheduler(threads.toInt, maxFailures.toInt) new LocalScheduler(threads.toInt, maxFailures.toInt, this)
case SPARK_REGEX(sparkUrl) => case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this) val scheduler = new ClusterScheduler(this)
@ -100,6 +113,28 @@ class SparkContext(
scheduler.initialize(backend) scheduler.initialize(backend)
scheduler scheduler
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
// Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang.
val memoryPerSlaveInt = memoryPerSlave.toInt
val sparkMemEnv = System.getenv("SPARK_MEM")
val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512
if (sparkMemEnvInt > memoryPerSlaveInt) {
throw new SparkException(
"Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format(
memoryPerSlaveInt, sparkMemEnvInt))
}
val scheduler = new ClusterScheduler(this)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val sparkUrl = localCluster.start()
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
}
scheduler
case _ => case _ =>
MesosNativeLibrary.load() MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this) val scheduler = new ClusterScheduler(this)
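The master strings accepted by those regexes look like the following; host names, counts and memory sizes are placeholders:

    new SparkContext("local", "demo")                     // 1 thread, no task retries
    new SparkContext("local[4]", "demo")                  // 4 threads
    new SparkContext("local[4, 2]", "demo")               // 4 threads, up to 2 failures per task
    new SparkContext("local-cluster[2, 1, 512]", "demo")  // 2 slaves, 1 core and 512 MB each
    new SparkContext("spark://master-host:7077", "demo")  // standalone deploy cluster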
@ -292,10 +327,57 @@ class SparkContext(
def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) = def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
new Accumulable(initialValue, param) new Accumulable(initialValue, param)
/**
* Create an accumulator from a "mutable collection" type.
*
* Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
}
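A small usage sketch for accumulableCollection (dataset invented, SparkContext named sc assumed):

    import scala.collection.mutable

    val seen = sc.accumulableCollection(mutable.HashSet[String]())
    sc.parallelize(Seq("a", "b", "a")).foreach(word => seen += word)
    seen.value   // contains "a" and "b"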
// Keep around a weak hash map of values to Cached versions? // Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal) def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
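And a broadcast usage sketch for comparison (lookup contents invented):

    val lookup = sc.broadcast(Map(1 -> "one", 2 -> "two"))
    sc.parallelize(Seq(1, 2, 2)).map(x => lookup.value.getOrElse(x, "?")).collect()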
// Adds a file dependency to all Tasks executed in the future.
def addFile(path: String) {
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
case _ => path
}
addedFiles(key) = System.currentTimeMillis
// Fetch the file locally in case the task is executed locally
val filename = new File(path.split("/").last)
Utils.fetchFile(path, new File("."))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
def clearFiles() {
addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedFiles.clear()
}
// Adds a jar dependency to all Tasks executed in the future.
def addJar(path: String) {
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
case _ => path
}
addedJars(key) = System.currentTimeMillis
logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
def clearJars() {
addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedJars.clear()
}
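The new file/jar distribution is meant to be driven like this (the paths are hypothetical); tasks find added files under their base name in the working directory, because the fetch logic downloads them there:

    sc.addFile("hdfs://namenode:9000/share/lookup.txt")   // hypothetical path
    sc.addJar("/tmp/extra-udfs.jar")                      // hypothetical path
    sc.parallelize(Seq("spark", "shuffle")).map { word =>
      // "lookup.txt" was fetched into the task's working directory
      scala.io.Source.fromFile("lookup.txt").getLines().count(_.contains(word))
    }.collect()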
// Stop the SparkContext // Stop the SparkContext
def stop() { def stop() {
dagScheduler.stop() dagScheduler.stop()
@ -303,6 +385,9 @@ class SparkContext(
taskScheduler = null taskScheduler = null
// TODO: Cache.stop()? // TODO: Cache.stop()?
env.stop() env.stop()
// Clean up locally linked files
clearFiles()
clearJars()
SparkEnv.set(null) SparkEnv.set(null)
ShuffleMapTask.clearCache() ShuffleMapTask.clearCache()
logInfo("Successfully stopped SparkContext") logInfo("Successfully stopped SparkContext")
@ -335,10 +420,11 @@ class SparkContext(
partitions: Seq[Int], partitions: Seq[Int],
allowLocal: Boolean allowLocal: Boolean
): Array[U] = { ): Array[U] = {
logInfo("Starting job...") val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, allowLocal) val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s") logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
result result
} }
@ -371,10 +457,11 @@ class SparkContext(
evaluator: ApproximateEvaluator[U, R], evaluator: ApproximateEvaluator[U, R],
timeout: Long timeout: Long
): PartialResult[R] = { ): PartialResult[R] = {
logInfo("Starting job...") val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout) val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout)
logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s") logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
result result
} }
@ -513,7 +600,7 @@ object SparkContext {
* that doesn't know the type of T when it is created. This sounds strange but is necessary to * that doesn't know the type of T when it is created. This sounds strange but is necessary to
* support converting subclasses of Writable to themselves (writableWritableConverter). * support converting subclasses of Writable to themselves (writableWritableConverter).
*/ */
class WritableConverter[T]( private[spark] class WritableConverter[T](
val writableClass: ClassManifest[T] => Class[_ <: Writable], val writableClass: ClassManifest[T] => Class[_ <: Writable],
val convert: Writable => T) val convert: Writable => T)
extends Serializable extends Serializable

@ -1,6 +1,8 @@
package spark package spark
import akka.actor.ActorSystem import akka.actor.ActorSystem
import akka.actor.ActorSystemImpl
import akka.remote.RemoteActorRefProvider
import spark.broadcast.BroadcastManager import spark.broadcast.BroadcastManager
import spark.storage.BlockManager import spark.storage.BlockManager
@ -8,35 +10,45 @@ import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager import spark.network.ConnectionManager
import spark.util.AkkaUtils import spark.util.AkkaUtils
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
* Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
* objects needs to have the right SparkEnv set. You can get the current environment with
* SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
*/
class SparkEnv ( class SparkEnv (
val actorSystem: ActorSystem, val actorSystem: ActorSystem,
val cache: Cache,
val serializer: Serializer, val serializer: Serializer,
val closureSerializer: Serializer, val closureSerializer: Serializer,
val cacheTracker: CacheTracker, val cacheTracker: CacheTracker,
val mapOutputTracker: MapOutputTracker, val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher, val shuffleFetcher: ShuffleFetcher,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager, val broadcastManager: BroadcastManager,
val blockManager: BlockManager, val blockManager: BlockManager,
val connectionManager: ConnectionManager val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer
) { ) {
/** No-parameter constructor for unit tests. */ /** No-parameter constructor for unit tests. */
def this() = { def this() = {
this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null) this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
} }
def stop() { def stop() {
httpFileServer.stop()
mapOutputTracker.stop() mapOutputTracker.stop()
cacheTracker.stop() cacheTracker.stop()
shuffleFetcher.stop() shuffleFetcher.stop()
shuffleManager.stop()
broadcastManager.stop() broadcastManager.stop()
blockManager.stop() blockManager.stop()
blockManager.master.stop() blockManager.master.stop()
actorSystem.shutdown() actorSystem.shutdown()
// Akka's awaitTermination doesn't actually wait until the port is unbound, so sleep a bit
Thread.sleep(100)
actorSystem.awaitTermination() actorSystem.awaitTermination()
// Akka's awaitTermination doesn't actually wait until the port is unbound, so sleep a bit
Thread.sleep(100)
} }
} }
@ -66,66 +78,49 @@ object SparkEnv {
System.setProperty("spark.master.port", boundPort.toString) System.setProperty("spark.master.port", boundPort.toString)
} }
val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer") val classLoader = Thread.currentThread.getContextClassLoader
val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
// Create an instance of the class named by the given Java system property, or by
// defaultClassName if the property is not set, and return it as a T
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = System.getProperty(propertyName, defaultClassName)
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal) val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal)
val blockManager = new BlockManager(blockManagerMaster, serializer) val blockManager = new BlockManager(blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager val connectionManager = blockManager.connectionManager
val shuffleManager = new ShuffleManager()
val broadcastManager = new BroadcastManager(isMaster) val broadcastManager = new BroadcastManager(isMaster)
val closureSerializerClass = val closureSerializer = instantiateClass[Serializer](
System.getProperty("spark.closure.serializer", "spark.JavaSerializer") "spark.closure.serializer", "spark.JavaSerializer")
val closureSerializer =
Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager) val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager)
blockManager.cacheTracker = cacheTracker blockManager.cacheTracker = cacheTracker
val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster)
val shuffleFetcherClass = val shuffleFetcher = instantiateClass[ShuffleFetcher](
System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val shuffleFetcher =
Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
/* val httpFileServer = new HttpFileServer()
if (System.getProperty("spark.stream.distributed", "false") == "true") { httpFileServer.initialize()
val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]] System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
if (isLocal || !isMaster) {
(new Thread() {
override def run() {
println("Wait started")
Thread.sleep(60000)
println("Wait ended")
val receiverClass = Class.forName("spark.stream.TestStreamReceiver4")
val constructor = receiverClass.getConstructor(blockManagerClass)
val receiver = constructor.newInstance(blockManager)
receiver.asInstanceOf[Thread].start()
}
}).start()
}
}
*/
new SparkEnv( new SparkEnv(
actorSystem, actorSystem,
cache,
serializer, serializer,
closureSerializer, closureSerializer,
cacheTracker, cacheTracker,
mapOutputTracker, mapOutputTracker,
shuffleFetcher, shuffleFetcher,
shuffleManager,
broadcastManager, broadcastManager,
blockManager, blockManager,
connectionManager) connectionManager,
httpFileServer)
} }
} }
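With the instantiateClass helper, the serializer, closure serializer and shuffle fetcher are all chosen from Java system properties; a sketch of overriding them before the context is created (the class names are the ones this file already references):

    System.setProperty("spark.serializer", "spark.KryoSerializer")
    System.setProperty("spark.closure.serializer", "spark.JavaSerializer")
    System.setProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")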

@ -7,10 +7,16 @@ import spark.storage.BlockManagerId
* tasks several times for "ephemeral" failures, and only report back failures that require some * tasks several times for "ephemeral" failures, and only report back failures that require some
* old stages to be resubmitted, such as shuffle map fetch failures. * old stages to be resubmitted, such as shuffle map fetch failures.
*/ */
sealed trait TaskEndReason private[spark] sealed trait TaskEndReason
case object Success extends TaskEndReason private[spark] case object Success extends TaskEndReason
private[spark]
case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
private[spark]
case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
case class ExceptionFailure(exception: Throwable) extends TaskEndReason
case class OtherFailure(message: String) extends TaskEndReason private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason
private[spark] case class OtherFailure(message: String) extends TaskEndReason

@ -2,7 +2,7 @@ package spark
import org.apache.mesos.Protos.{TaskState => MesosTaskState} import org.apache.mesos.Protos.{TaskState => MesosTaskState}
object TaskState private[spark] object TaskState
extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") { extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") {
val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value

@ -2,7 +2,7 @@ package spark
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
class UnionSplit[T: ClassManifest]( private[spark] class UnionSplit[T: ClassManifest](
idx: Int, idx: Int,
rdd: RDD[T], rdd: RDD[T],
split: Split) split: Split)

@ -1,18 +1,18 @@
package spark package spark
import java.io._ import java.io._
import java.net.InetAddress import java.net.{InetAddress, URL, URI}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import java.util.{Locale, UUID}
import scala.io.Source import scala.io.Source
/** /**
* Various utility methods used by Spark. * Various utility methods used by Spark.
*/ */
object Utils { private object Utils extends Logging {
/** Serialize an object using Java serialization */ /** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = { def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream() val bos = new ByteArrayOutputStream()
@ -116,22 +116,75 @@ object Utils {
copyStream(in, out, true) copyStream(in, out, true)
} }
/** Download a file from a given URL to the local filesystem */
def downloadFile(url: URL, localPath: String) {
val in = url.openStream()
val out = new FileOutputStream(localPath)
Utils.copyStream(in, out, true)
}
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
*/
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
uri.getScheme match {
case "http" | "https" | "ftp" =>
logInfo("Fetching " + url + " to " + targetFile)
val in = new URL(url).openStream()
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
case "file" | null =>
// Remove the file if it already exists
targetFile.delete()
// Symlink the file locally
logInfo("Symlinking " + url + " to " + targetFile)
FileUtil.symLink(url, targetFile.toString)
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
val conf = new Configuration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
}
// Decompress the file if it's a .tar or .tar.gz
if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xzf", filename), targetDir)
} else if (filename.endsWith(".tar")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xf", filename), targetDir)
}
// Make the file executable - That's necessary for scripts
FileUtil.chmod(filename, "a+x")
}
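A sketch of calling the new fetchFile directly, with a placeholder URL; note that Utils is now package-private, so this only works from code inside the spark package, and .tar.gz payloads are unpacked into the target directory as described above:

    import java.io.File

    Utils.fetchFile("http://repo.example.com/dist/tools.tar.gz", new File("."))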
/** /**
* Shuffle the elements of a collection into a random order, returning the * Shuffle the elements of a collection into a random order, returning the
* result in a new collection. Unlike scala.util.Random.shuffle, this method * result in a new collection. Unlike scala.util.Random.shuffle, this method
* uses a local random number generator, avoiding inter-thread contention. * uses a local random number generator, avoiding inter-thread contention.
*/ */
def randomize[T](seq: TraversableOnce[T]): Seq[T] = { def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
val buf = new ArrayBuffer[T]() randomizeInPlace(seq.toArray)
buf ++= seq }
val rand = new Random()
for (i <- (buf.size - 1) to 1 by -1) { /**
* Shuffle the elements of an array into a random order, modifying the
* original array. Returns the original array.
*/
def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
for (i <- (arr.length - 1) to 1 by -1) {
val j = rand.nextInt(i) val j = rand.nextInt(i)
val tmp = buf(j) val tmp = arr(j)
buf(j) = buf(i) arr(j) = arr(i)
buf(i) = tmp arr(i) = tmp
} }
buf arr
} }
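And a quick randomizeInPlace sketch (again only callable from the spark package; values invented):

    import java.util.Random

    val arr = Array(1, 2, 3, 4, 5)
    Utils.randomizeInPlace(arr, new Random(7))   // shuffles arr in place and returns it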
/** /**
@ -294,4 +347,43 @@ object Utils {
def execute(command: Seq[String]) { def execute(command: Seq[String]) {
execute(command, new File(".")) execute(command, new File("."))
} }
/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
*/
def getSparkCallSite: String = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
(!el.getMethodName.contains("getStackTrace")))
// Keep crawling up the stack trace until we find the first function not inside of the spark
// package. We track the last (shallowest) contiguous Spark method. This might be an RDD
// transformation, a SparkContext function (such as parallelize), or anything else that leads
// to instantiation of an RDD. We also track the first (deepest) user method, file, and line.
var lastSparkMethod = "<unknown>"
var firstUserFile = "<unknown>"
var firstUserLine = 0
var finished = false
for (el <- trace) {
if (!finished) {
if (el.getClassName.startsWith("spark.") && !el.getClassName.startsWith("spark.examples.")) {
lastSparkMethod = if (el.getMethodName == "<init>") {
// Spark method is a constructor; get its class name
el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
} else {
el.getMethodName
}
}
else {
firstUserLine = el.getLineNumber
firstUserFile = el.getFileName
finished = true
}
}
}
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
}
} }
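The resulting call-site string has the shape "<last Spark method> at <user file>:<line>"; for example, if user code in a hypothetical MyJob.scala called textFile on line 42, the job logs above would read roughly as follows:

    val logs = sc.textFile("hdfs://namenode:9000/logs")   // hypothetical call in MyJob.scala, line 42
    // getSparkCallSite => "textFile at MyJob.scala:42"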

@ -33,6 +33,8 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct()) def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct())
def distinct(numSplits: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numSplits))
def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD = def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD =
fromRDD(srdd.filter(x => f(x).booleanValue())) fromRDD(srdd.filter(x => f(x).booleanValue()))

@ -40,6 +40,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct()) def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct())
def distinct(numSplits: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numSplits))
def filter(f: Function[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] = def filter(f: Function[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue())) new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue()))

@ -19,6 +19,8 @@ JavaRDDLike[T, JavaRDD[T]] {
def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct()) def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct())
def distinct(numSplits: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numSplits))
def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] = def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] =
wrapRDD(rdd.filter((x => f(x).booleanValue()))) wrapRDD(rdd.filter((x => f(x).booleanValue())))

@ -7,7 +7,7 @@ import scala.runtime.AbstractFunction1
* apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply * apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply
* isn't marked to allow that). * isn't marked to allow that).
*/ */
abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] { private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
@throws(classOf[Exception]) @throws(classOf[Exception])
def call(t: T): R def call(t: T): R

@ -7,7 +7,7 @@ import scala.runtime.AbstractFunction2
* apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply * apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply
* isn't marked to allow that). * isn't marked to allow that).
*/ */
abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] { private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] {
@throws(classOf[Exception]) @throws(classOf[Exception])
def call(t1: T1, t2: T2): R def call(t1: T1, t2: T2): R

@ -11,14 +11,17 @@ import scala.math
import spark._ import spark._
import spark.storage.StorageLevel import spark.storage.StorageLevel
class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean) private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T] with Logging with Serializable { extends Broadcast[T](id)
with Logging
with Serializable {
def value = value_ def value = value_
def blockId: String = "broadcast_" + id
MultiTracker.synchronized { MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle( SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
} }
@transient var arrayOfBlocks: Array[BroadcastBlock] = null @transient var arrayOfBlocks: Array[BroadcastBlock] = null
@ -45,7 +48,7 @@ extends Broadcast[T] with Logging with Serializable {
// Used only in Workers // Used only in Workers
@transient var ttGuide: TalkToGuide = null @transient var ttGuide: TalkToGuide = null
@transient var hostAddress = Utils.localIpAddress @transient var hostAddress = Utils.localIpAddress()
@transient var listenPort = -1 @transient var listenPort = -1
@transient var guidePort = -1 @transient var guidePort = -1
@ -53,7 +56,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must call this after all the variables have been created/initialized // Must call this after all the variables have been created/initialized
if (!isLocal) { if (!isLocal) {
sendBroadcast sendBroadcast()
} }
def sendBroadcast() { def sendBroadcast() {
@ -106,20 +109,22 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource listOfSources += masterSource
// Register with the Tracker // Register with the Tracker
MultiTracker.registerBroadcast(uuid, MultiTracker.registerBroadcast(id,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
} }
private def readObject(in: ObjectInputStream) { private def readObject(in: ObjectInputStream) {
in.defaultReadObject() in.defaultReadObject()
MultiTracker.synchronized { MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match { SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) => x.asInstanceOf[T] case Some(x) =>
case None => { value_ = x.asInstanceOf[T]
logInfo("Started reading broadcast variable " + uuid)
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values // Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache // Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables initializeWorkerVariables()
logInfo("Local host address: " + hostAddress) logInfo("Local host address: " + hostAddress)
@ -131,18 +136,17 @@ extends Broadcast[T] with Logging with Serializable {
val start = System.nanoTime val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid) val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) { if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle( SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false) blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else { } else {
logError("Reading Broadcasted variable " + uuid + " failed") logError("Reading broadcast variable " + id + " failed")
} }
val time = (System.nanoTime - start) / 1e9 val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
} }
} }
} }
@ -254,8 +258,8 @@ extends Broadcast[T] with Logging with Serializable {
} }
} }
def receiveBroadcast(variableUUID: UUID): Boolean = { def receiveBroadcast(variableID: Long): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableUUID) val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false return false
@ -764,7 +768,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...") logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications sendStopBroadcastNotifications
MultiTracker.unregisterBroadcast(uuid) MultiTracker.unregisterBroadcast(id)
} finally { } finally {
if (serverSocket != null) { if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...") logInfo("GuideMultipleRequests now stopping...")
@ -1025,9 +1029,12 @@ extends Broadcast[T] with Logging with Serializable {
} }
} }
class BitTorrentBroadcastFactory private[spark] class BitTorrentBroadcastFactory
extends BroadcastFactory { extends BroadcastFactory {
def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster) def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean) = new BitTorrentBroadcast[T](value_, isLocal)
def stop() = MultiTracker.stop def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new BitTorrentBroadcast[T](value_, isLocal, id)
def stop() { MultiTracker.stop() }
} }

Просмотреть файл

@ -1,25 +1,20 @@
package spark.broadcast package spark.broadcast
import java.io._ import java.io._
import java.net._ import java.util.concurrent.atomic.AtomicLong
import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.Map
import spark._ import spark._
trait Broadcast[T] extends Serializable { abstract class Broadcast[T](id: Long) extends Serializable {
val uuid = UUID.randomUUID
def value: T def value: T
// We cannot have an abstract readObject here due to some weird issues with // We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes. // readObject having to be 'private' in sub-classes.
override def toString = "spark.Broadcast(" + uuid + ")" override def toString = "spark.Broadcast(" + id + ")"
} }
private[spark]
class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable { class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable {
private var initialized = false private var initialized = false
@ -49,14 +44,10 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
broadcastFactory.stop() broadcastFactory.stop()
} }
private def getBroadcastFactory: BroadcastFactory = { private val nextBroadcastId = new AtomicLong(0)
if (broadcastFactory == null) {
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
}
broadcastFactory
}
def newBroadcast[T](value_ : T, isLocal: Boolean) = broadcastFactory.newBroadcast[T](value_, isLocal) def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def isMaster = isMaster_ def isMaster = isMaster_
} }

@ -6,8 +6,8 @@ package spark.broadcast
* BroadcastFactory implementation to instantiate a particular broadcast for the * BroadcastFactory implementation to instantiate a particular broadcast for the
* entire Spark job. * entire Spark job.
*/ */
trait BroadcastFactory { private[spark] trait BroadcastFactory {
def initialize(isMaster: Boolean): Unit def initialize(isMaster: Boolean): Unit
def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T] def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit def stop(): Unit
} }

@ -12,44 +12,47 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark._ import spark._
import spark.storage.StorageLevel import spark.storage.StorageLevel
class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean) private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T] with Logging with Serializable { extends Broadcast[T](id) with Logging with Serializable {
def value = value_ def value = value_
def blockId: String = "broadcast_" + id
HttpBroadcast.synchronized { HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle( SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
} }
if (!isLocal) { if (!isLocal) {
HttpBroadcast.write(uuid, value_) HttpBroadcast.write(id, value_)
} }
// Called by JVM when deserializing an object // Called by JVM when deserializing an object
private def readObject(in: ObjectInputStream) { private def readObject(in: ObjectInputStream) {
in.defaultReadObject() in.defaultReadObject()
HttpBroadcast.synchronized { HttpBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match { SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) => value_ = x.asInstanceOf[T] case Some(x) => value_ = x.asInstanceOf[T]
case None => { case None => {
logInfo("Started reading broadcast variable " + uuid) logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime val start = System.nanoTime
value_ = HttpBroadcast.read[T](uuid) value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle( SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
val time = (System.nanoTime - start) / 1e9 val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + uuid + " took " + time + " s") logInfo("Reading broadcast variable " + id + " took " + time + " s")
} }
} }
} }
} }
} }
class HttpBroadcastFactory extends BroadcastFactory { private[spark] class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isMaster: Boolean) = HttpBroadcast.initialize(isMaster) def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal)
def stop() = HttpBroadcast.stop() def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def stop() { HttpBroadcast.stop() }
} }
private object HttpBroadcast extends Logging { private object HttpBroadcast extends Logging {
@ -65,7 +68,7 @@ private object HttpBroadcast extends Logging {
synchronized { synchronized {
if (!initialized) { if (!initialized) {
bufferSize = System.getProperty("spark.buffer.size", "65536").toInt bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
compress = System.getProperty("spark.compress", "false").toBoolean compress = System.getProperty("spark.broadcast.compress", "true").toBoolean
if (isMaster) { if (isMaster) {
createServer() createServer()
} }
@ -76,9 +79,12 @@ private object HttpBroadcast extends Logging {
} }
def stop() { def stop() {
if (server != null) { synchronized {
server.stop() if (server != null) {
server = null server.stop()
server = null
}
initialized = false
} }
} }
@ -91,8 +97,8 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri) logInfo("Broadcast server started at " + serverUri)
} }
def write(uuid: UUID, value: Any) { def write(id: Long, value: Any) {
val file = new File(broadcastDir, "broadcast-" + uuid) val file = new File(broadcastDir, "broadcast-" + id)
val out: OutputStream = if (compress) { val out: OutputStream = if (compress) {
new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering
} else { } else {
@ -104,8 +110,8 @@ private object HttpBroadcast extends Logging {
serOut.close() serOut.close()
} }
def read[T](uuid: UUID): T = { def read[T](id: Long): T = {
val url = serverUri + "/broadcast-" + uuid val url = serverUri + "/broadcast-" + id
var in = if (compress) { var in = if (compress) {
new LZFInputStream(new URL(url).openStream()) // Does its own buffering new LZFInputStream(new URL(url).openStream()) // Does its own buffering
} else { } else {

@ -2,8 +2,7 @@ package spark.broadcast
import java.io._ import java.io._
import java.net._ import java.net._
import java.util.{UUID, Random} import java.util.Random
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.Map import scala.collection.mutable.Map
@ -18,7 +17,7 @@ extends Logging {
val FIND_BROADCAST_TRACKER = 2 val FIND_BROADCAST_TRACKER = 2
// Map to keep track of guides of ongoing broadcasts // Map to keep track of guides of ongoing broadcasts
var valueToGuideMap = Map[UUID, SourceInfo]() var valueToGuideMap = Map[Long, SourceInfo]()
// Random number generator // Random number generator
var ranGen = new Random var ranGen = new Random
@ -154,44 +153,44 @@ extends Logging {
val messageType = ois.readObject.asInstanceOf[Int] val messageType = ois.readObject.asInstanceOf[Int]
if (messageType == REGISTER_BROADCAST_TRACKER) { if (messageType == REGISTER_BROADCAST_TRACKER) {
// Receive UUID // Receive Long
val uuid = ois.readObject.asInstanceOf[UUID] val id = ois.readObject.asInstanceOf[Long]
// Receive hostAddress and listenPort // Receive hostAddress and listenPort
val gInfo = ois.readObject.asInstanceOf[SourceInfo] val gInfo = ois.readObject.asInstanceOf[SourceInfo]
// Add to the map // Add to the map
valueToGuideMap.synchronized { valueToGuideMap.synchronized {
valueToGuideMap += (uuid -> gInfo) valueToGuideMap += (id -> gInfo)
} }
logInfo ("New broadcast " + uuid + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap) logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK // Send dummy ACK
oos.writeObject(-1) oos.writeObject(-1)
oos.flush() oos.flush()
} else if (messageType == UNREGISTER_BROADCAST_TRACKER) { } else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
// Receive UUID // Receive Long
val uuid = ois.readObject.asInstanceOf[UUID] val id = ois.readObject.asInstanceOf[Long]
// Remove from the map // Remove from the map
valueToGuideMap.synchronized { valueToGuideMap.synchronized {
valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToDefault) valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
} }
logInfo ("Broadcast " + uuid + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap) logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK // Send dummy ACK
oos.writeObject(-1) oos.writeObject(-1)
oos.flush() oos.flush()
} else if (messageType == FIND_BROADCAST_TRACKER) { } else if (messageType == FIND_BROADCAST_TRACKER) {
// Receive UUID // Receive Long
val uuid = ois.readObject.asInstanceOf[UUID] val id = ois.readObject.asInstanceOf[Long]
var gInfo = var gInfo =
if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid) if (valueToGuideMap.contains(id)) valueToGuideMap(id)
else SourceInfo("", SourceInfo.TxNotStartedRetry) else SourceInfo("", SourceInfo.TxNotStartedRetry)
logDebug("Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort) logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
// Send reply back // Send reply back
oos.writeObject(gInfo) oos.writeObject(gInfo)
@ -224,7 +223,7 @@ extends Logging {
} }
} }
def getGuideInfo(variableUUID: UUID): SourceInfo = { def getGuideInfo(variableLong: Long): SourceInfo = {
var clientSocketToTracker: Socket = null var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null var oisTracker: ObjectInputStream = null
@ -247,8 +246,8 @@ extends Logging {
oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER) oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
oosTracker.flush() oosTracker.flush()
// Send UUID and receive GuideInfo // Send Long and receive GuideInfo
oosTracker.writeObject(variableUUID) oosTracker.writeObject(variableLong)
oosTracker.flush() oosTracker.flush()
gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
} catch { } catch {
@ -276,7 +275,7 @@ extends Logging {
return gInfo return gInfo
} }
def registerBroadcast(uuid: UUID, gInfo: SourceInfo) { def registerBroadcast(id: Long, gInfo: SourceInfo) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream) val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush() oosST.flush()
@ -286,8 +285,8 @@ extends Logging {
oosST.writeObject(REGISTER_BROADCAST_TRACKER) oosST.writeObject(REGISTER_BROADCAST_TRACKER)
oosST.flush() oosST.flush()
// Send UUID of this broadcast // Send Long of this broadcast
oosST.writeObject(uuid) oosST.writeObject(id)
oosST.flush() oosST.flush()
// Send this tracker's information // Send this tracker's information
@ -303,7 +302,7 @@ extends Logging {
socket.close() socket.close()
} }
def unregisterBroadcast(uuid: UUID) { def unregisterBroadcast(id: Long) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream) val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush() oosST.flush()
@ -313,8 +312,8 @@ extends Logging {
oosST.writeObject(UNREGISTER_BROADCAST_TRACKER) oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
oosST.flush() oosST.flush()
// Send UUID of this broadcast // Send Long of this broadcast
oosST.writeObject(uuid) oosST.writeObject(id)
oosST.flush() oosST.flush()
// Receive ACK and throw it away // Receive ACK and throw it away
@ -383,10 +382,10 @@ extends Logging {
} }
} }
case class BroadcastBlock(blockID: Int, byteArray: Array[Byte]) private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
extends Serializable extends Serializable
case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock], private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
totalBlocks: Int, totalBlocks: Int,
totalBytes: Int) totalBytes: Int)
extends Serializable { extends Serializable {


@ -7,7 +7,7 @@ import spark._
/** /**
* Used to keep and pass around information of peers involved in a broadcast * Used to keep and pass around information of peers involved in a broadcast
*/ */
case class SourceInfo (hostAddress: String, private[spark] case class SourceInfo (hostAddress: String,
listenPort: Int, listenPort: Int,
totalBlocks: Int = SourceInfo.UnusedParam, totalBlocks: Int = SourceInfo.UnusedParam,
totalBytes: Int = SourceInfo.UnusedParam) totalBytes: Int = SourceInfo.UnusedParam)
@ -26,7 +26,7 @@ extends Comparable[SourceInfo] with Logging {
/** /**
* Helper Object of SourceInfo for its constants * Helper Object of SourceInfo for its constants
*/ */
object SourceInfo { private[spark] object SourceInfo {
// Broadcast has not started yet! Should never happen. // Broadcast has not started yet! Should never happen.
val TxNotStartedRetry = -1 val TxNotStartedRetry = -1
// Broadcast has already finished. Try default mechanism. // Broadcast has already finished. Try default mechanism.


@ -10,14 +10,15 @@ import scala.math
import spark._ import spark._
import spark.storage.StorageLevel import spark.storage.StorageLevel
class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean) private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T] with Logging with Serializable { extends Broadcast[T](id) with Logging with Serializable {
def value = value_ def value = value_
def blockId = "broadcast_" + id
MultiTracker.synchronized { MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle( SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
} }
@transient var arrayOfBlocks: Array[BroadcastBlock] = null @transient var arrayOfBlocks: Array[BroadcastBlock] = null
@ -35,7 +36,7 @@ extends Broadcast[T] with Logging with Serializable {
@transient var serveMR: ServeMultipleRequests = null @transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null @transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress @transient var hostAddress = Utils.localIpAddress()
@transient var listenPort = -1 @transient var listenPort = -1
@transient var guidePort = -1 @transient var guidePort = -1
@ -43,7 +44,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must call this after all the variables have been created/initialized // Must call this after all the variables have been created/initialized
if (!isLocal) { if (!isLocal) {
sendBroadcast sendBroadcast()
} }
def sendBroadcast() { def sendBroadcast() {
@ -84,20 +85,22 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource listOfSources += masterSource
// Register with the Tracker // Register with the Tracker
MultiTracker.registerBroadcast(uuid, MultiTracker.registerBroadcast(id,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
} }
private def readObject(in: ObjectInputStream) { private def readObject(in: ObjectInputStream) {
in.defaultReadObject() in.defaultReadObject()
MultiTracker.synchronized { MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match { SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) => x.asInstanceOf[T] case Some(x) =>
case None => { value_ = x.asInstanceOf[T]
logInfo("Started reading broadcast variable " + uuid)
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values // Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache // Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables initializeWorkerVariables()
logInfo("Local host address: " + hostAddress) logInfo("Local host address: " + hostAddress)
@ -108,18 +111,17 @@ extends Broadcast[T] with Logging with Serializable {
val start = System.nanoTime val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid) val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) { if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle( SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false) blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else { } else {
logError("Reading Broadcasted variable " + uuid + " failed") logError("Reading broadcast variable " + id + " failed")
} }
val time = (System.nanoTime - start) / 1e9 val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
} }
} }
} }
@ -136,14 +138,14 @@ extends Broadcast[T] with Logging with Serializable {
serveMR = null serveMR = null
hostAddress = Utils.localIpAddress hostAddress = Utils.localIpAddress()
listenPort = -1 listenPort = -1
stopBroadcast = false stopBroadcast = false
} }
def receiveBroadcast(variableUUID: UUID): Boolean = { def receiveBroadcast(variableID: Long): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableUUID) val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false return false
@ -318,7 +320,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...") logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications sendStopBroadcastNotifications
MultiTracker.unregisterBroadcast(uuid) MultiTracker.unregisterBroadcast(id)
} finally { } finally {
if (serverSocket != null) { if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...") logInfo("GuideMultipleRequests now stopping...")
@ -572,9 +574,12 @@ extends Broadcast[T] with Logging with Serializable {
} }
} }
class TreeBroadcastFactory private[spark] class TreeBroadcastFactory
extends BroadcastFactory { extends BroadcastFactory {
def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster) def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean) = new TreeBroadcast[T](value_, isLocal)
def stop() = MultiTracker.stop def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id)
def stop() { MultiTracker.stop() }
} }
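
For reference, the BroadcastFactory contract shown above now passes a numeric broadcast id through to the concrete Broadcast[T](id). A minimal sketch of an alternative factory against the same signatures (IdentityBroadcast and IdentityBroadcastFactory are hypothetical names; the broadcast simply serializes its value inline and is for illustration only, not part of this commit):

    package spark.broadcast

    // Hypothetical: a broadcast that just carries its value with every task.
    private[spark] class IdentityBroadcast[T](var value_ : T, isLocal: Boolean, id: Long)
      extends Broadcast[T](id) with Serializable {
      def value = value_
    }

    private[spark] class IdentityBroadcastFactory extends BroadcastFactory {
      def initialize(isMaster: Boolean) { /* nothing to set up */ }
      def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
        new IdentityBroadcast[T](value_, isLocal, id)
      def stop() { /* nothing to tear down */ }
    }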


@ -2,7 +2,7 @@ package spark.deploy
import scala.collection.Map import scala.collection.Map
case class Command( private[spark] case class Command(
mainClass: String, mainClass: String,
arguments: Seq[String], arguments: Seq[String],
environment: Map[String, String]) { environment: Map[String, String]) {


@ -7,13 +7,15 @@ import scala.collection.immutable.List
import scala.collection.mutable.HashMap import scala.collection.mutable.HashMap
sealed trait DeployMessage extends Serializable private[spark] sealed trait DeployMessage extends Serializable
// Worker to Master // Worker to Master
private[spark]
case class RegisterWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int) case class RegisterWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int)
extends DeployMessage extends DeployMessage
private[spark]
case class ExecutorStateChanged( case class ExecutorStateChanged(
jobId: String, jobId: String,
execId: Int, execId: Int,
@ -23,11 +25,11 @@ case class ExecutorStateChanged(
// Master to Worker // Master to Worker
case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage private[spark] case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
case class RegisterWorkerFailed(message: String) extends DeployMessage private[spark] case class RegisterWorkerFailed(message: String) extends DeployMessage
case class KillExecutor(jobId: String, execId: Int) extends DeployMessage private[spark] case class KillExecutor(jobId: String, execId: Int) extends DeployMessage
case class LaunchExecutor( private[spark] case class LaunchExecutor(
jobId: String, jobId: String,
execId: Int, execId: Int,
jobDesc: JobDescription, jobDesc: JobDescription,
@ -38,33 +40,42 @@ case class LaunchExecutor(
// Client to Master // Client to Master
case class RegisterJob(jobDescription: JobDescription) extends DeployMessage private[spark] case class RegisterJob(jobDescription: JobDescription) extends DeployMessage
// Master to Client // Master to Client
private[spark]
case class RegisteredJob(jobId: String) extends DeployMessage case class RegisteredJob(jobId: String) extends DeployMessage
private[spark]
case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
private[spark]
case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String]) case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String])
private[spark]
case class JobKilled(message: String) case class JobKilled(message: String)
// Internal message in Client // Internal message in Client
case object StopClient private[spark] case object StopClient
// MasterWebUI To Master // MasterWebUI To Master
case object RequestMasterState private[spark] case object RequestMasterState
// Master to MasterWebUI // Master to MasterWebUI
private[spark]
case class MasterState(uri : String, workers: List[WorkerInfo], activeJobs: List[JobInfo], case class MasterState(uri : String, workers: List[WorkerInfo], activeJobs: List[JobInfo],
completedJobs: List[JobInfo]) completedJobs: List[JobInfo])
// WorkerWebUI to Worker // WorkerWebUI to Worker
case object RequestWorkerState private[spark] case object RequestWorkerState
// Worker to WorkerWebUI // Worker to WorkerWebUI
private[spark]
case class WorkerState(uri: String, workerId: String, executors: List[ExecutorRunner], case class WorkerState(uri: String, workerId: String, executors: List[ExecutorRunner],
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int, finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
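
To illustrate how these deploy messages flow between Akka actors, here is a toy receive loop that acknowledges or rejects a worker registration (ToyMaster is a made-up example, not the real Master; it only assumes the case classes defined above and the Akka actor API already used in this commit):

    package spark.deploy

    import akka.actor.Actor

    class ToyMaster(masterWebUiUrl: String) extends Actor {
      def receive = {
        case RegisterWorker(id, host, port, cores, memory, webUiPort) =>
          // Accept the worker if it actually offers resources, otherwise reject it.
          if (cores > 0 && memory > 0) {
            sender ! RegisteredWorker(masterWebUiUrl)
          } else {
            sender ! RegisterWorkerFailed("Worker " + id + " offered no resources")
          }
      }
    }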


@ -1,6 +1,6 @@
package spark.deploy package spark.deploy
object ExecutorState private[spark] object ExecutorState
extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") { extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") {
val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value


@ -1,6 +1,6 @@
package spark.deploy package spark.deploy
class JobDescription( private[spark] class JobDescription(
val name: String, val name: String,
val cores: Int, val cores: Int,
val memoryPerSlave: Int, val memoryPerSlave: Int,


@ -0,0 +1,58 @@
package spark.deploy
import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
import spark.deploy.worker.Worker
import spark.deploy.master.Master
import spark.util.AkkaUtils
import spark.{Logging, Utils}
import scala.collection.mutable.ArrayBuffer
private[spark]
class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging {
val localIpAddress = Utils.localIpAddress
var masterActor : ActorRef = _
var masterActorSystem : ActorSystem = _
var masterPort : Int = _
var masterUrl : String = _
val slaveActorSystems = ArrayBuffer[ActorSystem]()
val slaveActors = ArrayBuffer[ActorRef]()
def start() : String = {
logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.")
/* Start the Master */
val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
masterActorSystem = actorSystem
masterUrl = "spark://" + localIpAddress + ":" + masterPort
val actor = masterActorSystem.actorOf(
Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
masterActor = actor
/* Start the Slaves */
for (slaveNum <- 1 to numSlaves) {
val (actorSystem, boundPort) =
AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0)
slaveActorSystems += actorSystem
val actor = actorSystem.actorOf(
Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
name = "Worker")
slaveActors += actor
}
return masterUrl
}
def stop() {
logInfo("Shutting down local Spark cluster.")
// Stop the slaves before the master so they don't get upset that it disconnected
slaveActorSystems.foreach(_.shutdown())
slaveActorSystems.foreach(_.awaitTermination())
masterActorSystem.shutdown()
masterActorSystem.awaitTermination()
}
}
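
A minimal usage sketch for LocalSparkCluster, assuming it is driven from code inside the spark package since the class is private[spark] (the slave count, core count and memory size below are arbitrary):

    package spark.deploy

    object LocalClusterSketch {
      def main(args: Array[String]) {
        // Start two slaves with one core and 512 MB each; start() returns the master URL.
        val cluster = new LocalSparkCluster(numSlaves = 2, coresPerSlave = 1, memoryPerSlave = 512)
        val masterUrl = cluster.start()
        println("Local cluster master: " + masterUrl)
        // ... submit work against masterUrl here ...
        cluster.stop()
      }
    }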


@ -16,7 +16,7 @@ import akka.dispatch.Await
* The main class used to talk to a Spark deploy cluster. Takes a master URL, a job description, * The main class used to talk to a Spark deploy cluster. Takes a master URL, a job description,
* and a listener for job events, and calls back the listener when various events occur. * and a listener for job events, and calls back the listener when various events occur.
*/ */
class Client( private[spark] class Client(
actorSystem: ActorSystem, actorSystem: ActorSystem,
masterUrl: String, masterUrl: String,
jobDescription: JobDescription, jobDescription: JobDescription,
@ -42,7 +42,6 @@ class Client(
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort) val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
try { try {
master = context.actorFor(akkaUrl) master = context.actorFor(akkaUrl)
//master ! RegisterWorker(ip, port, cores, memory)
master ! RegisterJob(jobDescription) master ! RegisterJob(jobDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing context.watch(master) // Doesn't work with remote actors, but useful for testing


@ -7,7 +7,7 @@ package spark.deploy.client
* *
* Users of this API should *not* block inside the callback methods. * Users of this API should *not* block inside the callback methods.
*/ */
trait ClientListener { private[spark] trait ClientListener {
def connected(jobId: String): Unit def connected(jobId: String): Unit
def disconnected(): Unit def disconnected(): Unit


@ -4,7 +4,7 @@ import spark.util.AkkaUtils
import spark.{Logging, Utils} import spark.{Logging, Utils}
import spark.deploy.{Command, JobDescription} import spark.deploy.{Command, JobDescription}
object TestClient { private[spark] object TestClient {
class TestListener extends ClientListener with Logging { class TestListener extends ClientListener with Logging {
def connected(id: String) { def connected(id: String) {


@ -1,6 +1,6 @@
package spark.deploy.client package spark.deploy.client
object TestExecutor { private[spark] object TestExecutor {
def main(args: Array[String]) { def main(args: Array[String]) {
println("Hello world!") println("Hello world!")
while (true) { while (true) {


@ -2,7 +2,7 @@ package spark.deploy.master
import spark.deploy.ExecutorState import spark.deploy.ExecutorState
class ExecutorInfo( private[spark] class ExecutorInfo(
val id: Int, val id: Int,
val job: JobInfo, val job: JobInfo,
val worker: WorkerInfo, val worker: WorkerInfo,


@ -5,6 +5,7 @@ import java.util.Date
import akka.actor.ActorRef import akka.actor.ActorRef
import scala.collection.mutable import scala.collection.mutable
private[spark]
class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, val actor: ActorRef) { class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, val actor: ActorRef) {
var state = JobState.WAITING var state = JobState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo] var executors = new mutable.HashMap[Int, ExecutorInfo]
@ -31,4 +32,13 @@ class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, va
} }
def coresLeft: Int = desc.cores - coresGranted def coresLeft: Int = desc.cores - coresGranted
private var _retryCount = 0
def retryCount = _retryCount
def incrementRetryCount = {
_retryCount += 1
_retryCount
}
} }


@ -1,7 +1,9 @@
package spark.deploy.master package spark.deploy.master
object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") { private[spark] object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
type JobState = Value type JobState = Value
val WAITING, RUNNING, FINISHED, FAILED = Value val WAITING, RUNNING, FINISHED, FAILED = Value
val MAX_NUM_RETRY = 10
} }


@ -1,21 +1,20 @@
package spark.deploy.master package spark.deploy.master
import akka.actor._
import akka.actor.Terminated
import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown}
import java.text.SimpleDateFormat
import java.util.Date
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import akka.actor._
import spark.{Logging, Utils}
import spark.util.AkkaUtils
import java.text.SimpleDateFormat
import java.util.Date
import akka.remote.RemoteClientLifeCycleEvent
import spark.deploy._ import spark.deploy._
import akka.remote.RemoteClientShutdown import spark.{Logging, SparkException, Utils}
import akka.remote.RemoteClientDisconnected import spark.util.AkkaUtils
import spark.deploy.RegisterWorker
import spark.deploy.RegisterWorkerFailed
import akka.actor.Terminated
class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For job IDs val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For job IDs
var nextJobNumber = 0 var nextJobNumber = 0
@ -81,12 +80,22 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
exec.state = state exec.state = state
exec.job.actor ! ExecutorUpdated(execId, state, message) exec.job.actor ! ExecutorUpdated(execId, state, message)
if (ExecutorState.isFinished(state)) { if (ExecutorState.isFinished(state)) {
val jobInfo = idToJob(jobId)
// Remove this executor from the worker and job // Remove this executor from the worker and job
logInfo("Removing executor " + exec.fullId + " because it is " + state) logInfo("Removing executor " + exec.fullId + " because it is " + state)
idToJob(jobId).removeExecutor(exec) jobInfo.removeExecutor(exec)
exec.worker.removeExecutor(exec) exec.worker.removeExecutor(exec)
// TODO: the worker would probably want to restart the executor a few times
schedule() // Only retry a certain number of times so we don't go into an infinite loop.
if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) {
schedule()
} else {
val e = new SparkException("Job %s with ID %s failed %d times.".format(
jobInfo.desc.name, jobInfo.id, jobInfo.retryCount))
logError(e.getMessage, e)
throw e
//System.exit(1)
}
} }
} }
case None => case None =>
@ -203,7 +212,7 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
} }
} }
object Master { private[spark] object Master {
def main(argStrings: Array[String]) { def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings) val args = new MasterArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port) val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)


@ -6,7 +6,7 @@ import spark.Utils
/** /**
* Command-line parser for the master. * Command-line parser for the master.
*/ */
class MasterArguments(args: Array[String]) { private[spark] class MasterArguments(args: Array[String]) {
var ip = Utils.localIpAddress() var ip = Utils.localIpAddress()
var port = 7077 var port = 7077
var webUiPort = 8080 var webUiPort = 8080
@ -51,7 +51,7 @@ class MasterArguments(args: Array[String]) {
*/ */
def printUsageAndExit(exitCode: Int) { def printUsageAndExit(exitCode: Int) {
System.err.println( System.err.println(
"Usage: spark-master [options]\n" + "Usage: Master [options]\n" +
"\n" + "\n" +
"Options:\n" + "Options:\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" + " -i IP, --ip IP IP address or DNS name to listen on\n" +


@ -10,6 +10,7 @@ import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._ import cc.spray.typeconversion.TwirlSupport._
import spark.deploy._ import spark.deploy._
private[spark]
class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/master/webui" val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static" val STATIC_RESOURCE_DIR = "spark/deploy/static"
@ -22,7 +23,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
completeWith { completeWith {
val future = master ? RequestMasterState val future = master ? RequestMasterState
future.map { future.map {
masterState => masterui.html.index.render(masterState.asInstanceOf[MasterState]) masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState])
} }
} }
} ~ } ~
@ -36,7 +37,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
// A bit ugly and inefficient, but we won't have a number of jobs // A bit ugly and inefficient, but we won't have a number of jobs
// so large that it will make a significant difference. // so large that it will make a significant difference.
(masterState.activeJobs ::: masterState.completedJobs).find(_.id == jobId) match { (masterState.activeJobs ::: masterState.completedJobs).find(_.id == jobId) match {
case Some(job) => masterui.html.job_details.render(job) case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => null case _ => null
} }
} }


@ -3,7 +3,7 @@ package spark.deploy.master
import akka.actor.ActorRef import akka.actor.ActorRef
import scala.collection.mutable import scala.collection.mutable
class WorkerInfo( private[spark] class WorkerInfo(
val id: String, val id: String,
val host: String, val host: String,
val port: Int, val port: Int,


@ -13,7 +13,7 @@ import spark.deploy.ExecutorStateChanged
/** /**
* Manages the execution of one executor process. * Manages the execution of one executor process.
*/ */
class ExecutorRunner( private[spark] class ExecutorRunner(
val jobId: String, val jobId: String,
val execId: Int, val execId: Int,
val jobDesc: JobDescription, val jobDesc: JobDescription,
@ -29,12 +29,25 @@ class ExecutorRunner(
val fullId = jobId + "/" + execId val fullId = jobId + "/" + execId
var workerThread: Thread = null var workerThread: Thread = null
var process: Process = null var process: Process = null
var shutdownHook: Thread = null
def start() { def start() {
workerThread = new Thread("ExecutorRunner for " + fullId) { workerThread = new Thread("ExecutorRunner for " + fullId) {
override def run() { fetchAndRunExecutor() } override def run() { fetchAndRunExecutor() }
} }
workerThread.start() workerThread.start()
// Shutdown hook that kills actors on shutdown.
shutdownHook = new Thread() {
override def run() {
if (process != null) {
logInfo("Shutdown hook killing child process.")
process.destroy()
process.waitFor()
}
}
}
Runtime.getRuntime.addShutdownHook(shutdownHook)
} }
/** Stop this executor runner, including killing the process it launched */ /** Stop this executor runner, including killing the process it launched */
@ -45,40 +58,10 @@ class ExecutorRunner(
if (process != null) { if (process != null) {
logInfo("Killing process!") logInfo("Killing process!")
process.destroy() process.destroy()
process.waitFor()
} }
worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None) worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None)
} Runtime.getRuntime.removeShutdownHook(shutdownHook)
}
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
*/
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
val targetFile = new File(targetDir, filename)
if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
// Use the java.net library to fetch it
logInfo("Fetching " + url + " to " + targetFile)
val in = new URL(url).openStream()
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
} else {
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
val conf = new Configuration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
}
// Decompress the file if it's a .tar or .tar.gz
if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xzf", filename), targetDir)
} else if (filename.endsWith(".tar")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xf", filename), targetDir)
} }
} }
@ -92,7 +75,8 @@ class ExecutorRunner(
def buildCommandSeq(): Seq[String] = { def buildCommandSeq(): Seq[String] = {
val command = jobDesc.command val command = jobDesc.command
val runScript = new File(sparkHome, "run").getCanonicalPath val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run";
val runScript = new File(sparkHome, script).getCanonicalPath
Seq(runScript, command.mainClass) ++ command.arguments.map(substituteVariables) Seq(runScript, command.mainClass) ++ command.arguments.map(substituteVariables)
} }
@ -101,7 +85,12 @@ class ExecutorRunner(
val out = new FileOutputStream(file) val out = new FileOutputStream(file)
new Thread("redirect output to " + file) { new Thread("redirect output to " + file) {
override def run() { override def run() {
Utils.copyStream(in, out, true) try {
Utils.copyStream(in, out, true)
} catch {
case e: IOException =>
logInfo("Redirection to " + file + " closed: " + e.getMessage)
}
} }
}.start() }.start()
} }
@ -131,6 +120,9 @@ class ExecutorRunner(
} }
env.put("SPARK_CORES", cores.toString) env.put("SPARK_CORES", cores.toString)
env.put("SPARK_MEMORY", memory.toString) env.put("SPARK_MEMORY", memory.toString)
// In case we are running this from within the Spark Shell,
// make sure we do not create another parent process.
env.put("SPARK_LAUNCH_WITH_SCALA", "0")
process = builder.start() process = builder.start()
// Redirect its stdout and stderr to files // Redirect its stdout and stderr to files
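
The shutdown hook added above is the standard JVM pattern for making sure a child process does not outlive the worker. A self-contained illustration of the same idea, independent of the Spark classes (the "sleep" command is just a stand-in child process):

    object ShutdownHookSketch {
      def main(args: Array[String]) {
        val child = new ProcessBuilder("sleep", "60").start()
        // Kill the child if this JVM exits for any reason.
        val hook = new Thread() {
          override def run() {
            child.destroy()
            child.waitFor()
          }
        }
        Runtime.getRuntime.addShutdownHook(hook)
        // ... normal work ...
        // When the child is stopped deliberately, the hook is no longer needed.
        child.destroy()
        child.waitFor()
        Runtime.getRuntime.removeShutdownHook(hook)
      }
    }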


@ -16,7 +16,14 @@ import spark.deploy.RegisterWorkerFailed
import akka.actor.Terminated import akka.actor.Terminated
import java.io.File import java.io.File
class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, masterUrl: String) private[spark] class Worker(
ip: String,
port: Int,
webUiPort: Int,
cores: Int,
memory: Int,
masterUrl: String,
workDirPath: String = null)
extends Actor with Logging { extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
@ -37,7 +44,11 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
def memoryFree: Int = memory - memoryUsed def memoryFree: Int = memory - memoryUsed
def createWorkDir() { def createWorkDir() {
workDir = new File(sparkHome, "work") workDir = if (workDirPath != null) {
new File(workDirPath)
} else {
new File(sparkHome, "work")
}
try { try {
if (!workDir.exists() && !workDir.mkdirs()) { if (!workDir.exists() && !workDir.mkdirs()) {
logError("Failed to create work directory " + workDir) logError("Failed to create work directory " + workDir)
@ -153,14 +164,19 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
def generateWorkerId(): String = { def generateWorkerId(): String = {
"worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port) "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
} }
override def postStop() {
executors.values.foreach(_.kill())
}
} }
object Worker { private[spark] object Worker {
def main(argStrings: Array[String]) { def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings) val args = new WorkerArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port) val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
val actor = actorSystem.actorOf( val actor = actorSystem.actorOf(
Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory, args.master)), Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
args.master, args.workDir)),
name = "Worker") name = "Worker")
actorSystem.awaitTermination() actorSystem.awaitTermination()
} }


@ -8,13 +8,14 @@ import java.lang.management.ManagementFactory
/** /**
* Command-line parser for the master. * Command-line parser for the master.
*/ */
class WorkerArguments(args: Array[String]) { private[spark] class WorkerArguments(args: Array[String]) {
var ip = Utils.localIpAddress() var ip = Utils.localIpAddress()
var port = 0 var port = 0
var webUiPort = 8081 var webUiPort = 8081
var cores = inferDefaultCores() var cores = inferDefaultCores()
var memory = inferDefaultMemory() var memory = inferDefaultMemory()
var master: String = null var master: String = null
var workDir: String = null
// Check for settings in environment variables // Check for settings in environment variables
if (System.getenv("SPARK_WORKER_PORT") != null) { if (System.getenv("SPARK_WORKER_PORT") != null) {
@ -29,6 +30,9 @@ class WorkerArguments(args: Array[String]) {
if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) {
webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt
} }
if (System.getenv("SPARK_WORKER_DIR") != null) {
workDir = System.getenv("SPARK_WORKER_DIR")
}
parse(args.toList) parse(args.toList)
@ -49,6 +53,10 @@ class WorkerArguments(args: Array[String]) {
memory = value memory = value
parse(tail) parse(tail)
case ("--work-dir" | "-d") :: value :: tail =>
workDir = value
parse(tail)
case "--webui-port" :: IntParam(value) :: tail => case "--webui-port" :: IntParam(value) :: tail =>
webUiPort = value webUiPort = value
parse(tail) parse(tail)
@ -77,13 +85,14 @@ class WorkerArguments(args: Array[String]) {
*/ */
def printUsageAndExit(exitCode: Int) { def printUsageAndExit(exitCode: Int) {
System.err.println( System.err.println(
"Usage: spark-worker [options] <master>\n" + "Usage: Worker [options] <master>\n" +
"\n" + "\n" +
"Master must be a URL of the form spark://hostname:port\n" + "Master must be a URL of the form spark://hostname:port\n" +
"\n" + "\n" +
"Options:\n" + "Options:\n" +
" -c CORES, --cores CORES Number of cores to use\n" + " -c CORES, --cores CORES Number of cores to use\n" +
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" + " -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
" -d DIR, --work-dir DIR Directory to run jobs in (default: SPARK_HOME/work)\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" + " -i IP, --ip IP IP address or DNS name to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" +
" --webui-port PORT Port for web UI (default: 8081)") " --webui-port PORT Port for web UI (default: 8081)")


@ -9,6 +9,7 @@ import cc.spray.Directives
import cc.spray.typeconversion.TwirlSupport._ import cc.spray.typeconversion.TwirlSupport._
import spark.deploy.{WorkerState, RequestWorkerState} import spark.deploy.{WorkerState, RequestWorkerState}
private[spark]
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/worker/webui" val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static" val STATIC_RESOURCE_DIR = "spark/deploy/static"
@ -21,7 +22,7 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
completeWith{ completeWith{
val future = worker ? RequestWorkerState val future = worker ? RequestWorkerState
future.map { workerState => future.map { workerState =>
workerui.html.index(workerState.asInstanceOf[WorkerState]) spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState])
} }
} }
} ~ } ~


@ -1,10 +1,12 @@
package spark.executor package spark.executor
import java.io.{File, FileOutputStream} import java.io.{File, FileOutputStream}
import java.net.{URL, URLClassLoader} import java.net.{URI, URL, URLClassLoader}
import java.util.concurrent._ import java.util.concurrent._
import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.FileUtil
import scala.collection.mutable.{ArrayBuffer, Map, HashMap}
import spark.broadcast._ import spark.broadcast._
import spark.scheduler._ import spark.scheduler._
@ -14,11 +16,16 @@ import java.nio.ByteBuffer
/** /**
* The Mesos executor for Spark. * The Mesos executor for Spark.
*/ */
class Executor extends Logging { private[spark] class Executor extends Logging {
var classLoader: ClassLoader = null var urlClassLoader : ExecutorURLClassLoader = null
var threadPool: ExecutorService = null var threadPool: ExecutorService = null
var env: SparkEnv = null var env: SparkEnv = null
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
initLogging() initLogging()
@ -32,14 +39,14 @@ class Executor extends Logging {
System.setProperty(key, value) System.setProperty(key, value)
} }
// Create our ClassLoader and set it on this thread
urlClassLoader = createClassLoader()
Thread.currentThread.setContextClassLoader(urlClassLoader)
// Initialize Spark environment (using system properties read above) // Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false) env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
SparkEnv.set(env) SparkEnv.set(env)
// Create our ClassLoader (using spark properties) and set it on this thread
classLoader = createClassLoader()
Thread.currentThread.setContextClassLoader(classLoader)
// Start worker thread pool // Start worker thread pool
threadPool = new ThreadPoolExecutor( threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
@ -54,15 +61,16 @@ class Executor extends Logging {
override def run() { override def run() {
SparkEnv.set(env) SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(classLoader) Thread.currentThread.setContextClassLoader(urlClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance() val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId) logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
try { try {
SparkEnv.set(env) SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear() Accumulators.clear()
val task = ser.deserialize[Task[Any]](serializedTask, classLoader) val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
logInfo("Its generation is " + task.generation) logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation) env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(taskId.toInt) val value = task.run(taskId.toInt)
@ -96,25 +104,15 @@ class Executor extends Logging {
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path * created by the interpreter to the search path
*/ */
private def createClassLoader(): ClassLoader = { private def createClassLoader(): ExecutorURLClassLoader = {
var loader = this.getClass.getClassLoader var loader = this.getClass.getClassLoader
// If any JAR URIs are given through spark.jar.uris, fetch them to the // For each of the jars in the jarSet, add them to the class loader.
// current directory and put them all on the classpath. We assume that // We assume each of the files has already been fetched.
// each URL has a unique file name so that no local filenames will clash val urls = currentJars.keySet.map { uri =>
// in this process. This is guaranteed by ClusterScheduler. new File(uri.split("/").last).toURI.toURL
val uris = System.getProperty("spark.jar.uris", "") }.toArray
val localFiles = ArrayBuffer[String]() loader = new URLClassLoader(urls, loader)
for (uri <- uris.split(",").filter(_.size > 0)) {
val url = new URL(uri)
val filename = url.getPath.split("/").last
downloadFile(url, filename)
localFiles += filename
}
if (localFiles.size > 0) {
val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
loader = new URLClassLoader(urls, loader)
}
// If the REPL is in use, add another ClassLoader that will read // If the REPL is in use, add another ClassLoader that will read
// new classes defined by the REPL as the user types code // new classes defined by the REPL as the user types code
@ -133,13 +131,31 @@ class Executor extends Logging {
} }
} }
return loader return new ExecutorURLClassLoader(Array(), loader)
} }
// Download a file from a given URL to the local filesystem /**
private def downloadFile(url: URL, localPath: String) { * Download any missing dependencies if we receive a new set of files and JARs from the
val in = url.openStream() * SparkContext. Also adds any new JARs we fetched to the class loader.
val out = new FileOutputStream(localPath) */
Utils.copyStream(in, out, true) private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name)
Utils.fetchFile(name, new File("."))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name)
Utils.fetchFile(name, new File("."))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
val url = new File(".", localName).toURI.toURL
if (!urlClassLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
}
}
} }
} }
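
The dependency update above re-fetches an entry only when the timestamp received with the task is newer than the one recorded locally. A stripped-down sketch of that bookkeeping, with the actual download call replaced by a placeholder (in the executor it is Utils.fetchFile):

    import scala.collection.mutable.HashMap

    object TimestampedFetchSketch {
      val current = new HashMap[String, Long]()

      // Placeholder for the real fetch call.
      def fetch(name: String) { println("fetching " + name) }

      def update(incoming: Seq[(String, Long)]) {
        for ((name, timestamp) <- incoming if current.getOrElse(name, -1L) < timestamp) {
          fetch(name)
          current(name) = timestamp
        }
      }
    }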


@ -6,6 +6,6 @@ import spark.TaskState.TaskState
/** /**
* A pluggable interface used by the Executor to send updates to the cluster scheduler. * A pluggable interface used by the Executor to send updates to the cluster scheduler.
*/ */
trait ExecutorBackend { private[spark] trait ExecutorBackend {
def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer)
} }


@ -0,0 +1,14 @@
package spark.executor
import java.net.{URLClassLoader, URL}
/**
* The addURL method in URLClassLoader is protected. We subclass it to make this accessible.
*/
private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
extends URLClassLoader(urls, parent) {
override def addURL(url: URL) {
super.addURL(url)
}
}
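
A small usage sketch for this class loader; the jar path is a placeholder, and the object lives in spark.executor only because the class is private[spark]:

    package spark.executor

    import java.io.File

    object ClassLoaderSketch {
      def main(args: Array[String]) {
        val parent = Thread.currentThread.getContextClassLoader
        val loader = new ExecutorURLClassLoader(Array(), parent)
        // Expose a newly fetched dependency to code loaded through this loader.
        loader.addURL(new File("some-dependency.jar").toURI.toURL)
        Thread.currentThread.setContextClassLoader(loader)
      }
    }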


@ -8,7 +8,7 @@ import com.google.protobuf.ByteString
import spark.{Utils, Logging} import spark.{Utils, Logging}
import spark.TaskState import spark.TaskState
class MesosExecutorBackend(executor: Executor) private[spark] class MesosExecutorBackend(executor: Executor)
extends MesosExecutor extends MesosExecutor
with ExecutorBackend with ExecutorBackend
with Logging { with Logging {
@ -59,7 +59,7 @@ class MesosExecutorBackend(executor: Executor)
/** /**
* Entry point for Mesos executor. * Entry point for Mesos executor.
*/ */
object MesosExecutorBackend { private[spark] object MesosExecutorBackend {
def main(args: Array[String]) { def main(args: Array[String]) {
MesosNativeLibrary.load() MesosNativeLibrary.load()
// Create a new Executor and start it running // Create a new Executor and start it running


@ -14,7 +14,7 @@ import spark.scheduler.cluster.RegisterSlaveFailed
import spark.scheduler.cluster.RegisterSlave import spark.scheduler.cluster.RegisterSlave
class StandaloneExecutorBackend( private[spark] class StandaloneExecutorBackend(
executor: Executor, executor: Executor,
masterUrl: String, masterUrl: String,
slaveId: String, slaveId: String,
@ -62,7 +62,7 @@ class StandaloneExecutorBackend(
} }
} }
object StandaloneExecutorBackend { private[spark] object StandaloneExecutorBackend {
def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) { def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc // before getting started with all our system properties, etc


@ -11,6 +11,7 @@ import java.nio.channels.spi._
import java.net._ import java.net._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging { abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
channel.configureBlocking(false) channel.configureBlocking(false)
@ -23,8 +24,8 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
var onExceptionCallback: (Connection, Exception) => Unit = null var onExceptionCallback: (Connection, Exception) => Unit = null
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
lazy val remoteAddress = getRemoteAddress() val remoteAddress = getRemoteAddress()
lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress) val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
def key() = channel.keyFor(selector) def key() = channel.keyFor(selector)
@ -39,7 +40,10 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
} }
def close() { def close() {
key.cancel() val k = key()
if (k != null) {
k.cancel()
}
channel.close() channel.close()
callOnCloseCallback() callOnCloseCallback()
} }
@ -99,7 +103,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
} }
class SendingConnection(val address: InetSocketAddress, selector_ : Selector) private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
extends Connection(SocketChannel.open, selector_) { extends Connection(SocketChannel.open, selector_) {
class Outbox(fair: Int = 0) { class Outbox(fair: Int = 0) {
@ -134,9 +138,12 @@ extends Connection(SocketChannel.open, selector_) {
if (!message.started) logDebug("Starting to send [" + message + "]") if (!message.started) logDebug("Starting to send [" + message + "]")
message.started = true message.started = true
return chunk return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
"] in " + message.timeTaken )
} }
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
} }
} }
None None
@ -159,10 +166,11 @@ extends Connection(SocketChannel.open, selector_) {
} }
logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]") logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
return chunk return chunk
} else {
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
"] in " + message.timeTaken )
} }
/*messages -= message*/
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
} }
} }
None None
@ -216,7 +224,7 @@ extends Connection(SocketChannel.open, selector_) {
while(true) { while(true) {
if (currentBuffers.size == 0) { if (currentBuffers.size == 0) {
outbox.synchronized { outbox.synchronized {
outbox.getChunk match { outbox.getChunk() match {
case Some(chunk) => { case Some(chunk) => {
currentBuffers ++= chunk.buffers currentBuffers ++= chunk.buffers
} }
@ -252,7 +260,7 @@ extends Connection(SocketChannel.open, selector_) {
} }
class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
extends Connection(channel_, selector_) { extends Connection(channel_, selector_) {
class Inbox() { class Inbox() {


@ -16,18 +16,19 @@ import scala.collection.mutable.ArrayBuffer
import akka.dispatch.{Await, Promise, ExecutionContext, Future} import akka.dispatch.{Await, Promise, ExecutionContext, Future}
import akka.util.Duration import akka.util.Duration
import akka.util.duration._
case class ConnectionManagerId(host: String, port: Int) { private[spark] case class ConnectionManagerId(host: String, port: Int) {
def toSocketAddress() = new InetSocketAddress(host, port) def toSocketAddress() = new InetSocketAddress(host, port)
} }
object ConnectionManagerId { private[spark] object ConnectionManagerId {
def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
} }
} }
class ConnectionManager(port: Int) extends Logging { private[spark] class ConnectionManager(port: Int) extends Logging {
class MessageStatus( class MessageStatus(
val message: Message, val message: Message,
@ -348,7 +349,7 @@ class ConnectionManager(port: Int) extends Logging {
} }
object ConnectionManager { private[spark] object ConnectionManager {
def main(args: Array[String]) { def main(args: Array[String]) {
@ -403,7 +404,10 @@ object ConnectionManager {
(0 until count).map(i => { (0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate) val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage) manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {if (!f().isDefined) println("Failed")}) }).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis val finishTime = System.currentTimeMillis
val mb = size * count / 1024.0 / 1024.0 val mb = size * count / 1024.0 / 1024.0
@ -430,7 +434,10 @@ object ConnectionManager {
(0 until count).map(i => { (0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
manager.sendMessageReliably(manager.id, bufferMessage) manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {if (!f().isDefined) println("Failed")}) }).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis val finishTime = System.currentTimeMillis
val ms = finishTime - startTime val ms = finishTime - startTime
@ -457,7 +464,10 @@ object ConnectionManager {
(0 until count).map(i => { (0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate) val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage) manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {if (!f().isDefined) println("Failed")}) }).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis val finishTime = System.currentTimeMillis
Thread.sleep(1000) Thread.sleep(1000)
val mb = size * count / 1024.0 / 1024.0 val mb = size * count / 1024.0 / 1024.0
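
The test-harness changes above follow the Akka 2.0 futures API, where a future is no longer applied directly but awaited with an explicit timeout. A minimal stand-alone illustration of that pattern (the ActorSystem is only there to provide an ExecutionContext):

    import akka.actor.ActorSystem
    import akka.dispatch.{Await, Future}
    import akka.util.duration._

    object AwaitSketch {
      def main(args: Array[String]) {
        val system = ActorSystem("awaitSketch")
        implicit val ec = system.dispatcher
        val f: Future[Int] = Future { 21 * 2 }
        // Block for at most one second, as the ConnectionManager tests now do.
        println(Await.result(f, 1 second))
        system.shutdown()
      }
    }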


@ -8,7 +8,10 @@ import scala.io.Source
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.net.InetAddress import java.net.InetAddress
object ConnectionManagerTest extends Logging{ import akka.dispatch.Await
import akka.util.duration._
private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) { def main(args: Array[String]) {
if (args.length < 2) { if (args.length < 2) {
println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>") println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
@ -53,7 +56,7 @@ object ConnectionManagerTest extends Logging{
logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
}) })
val results = futures.map(f => f()) val results = futures.map(f => Await.result(f, 1.second))
val finishTime = System.currentTimeMillis val finishTime = System.currentTimeMillis
Thread.sleep(5000) Thread.sleep(5000)


@ -7,8 +7,9 @@ import scala.collection.mutable.ArrayBuffer
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.net.InetAddress import java.net.InetAddress
import java.net.InetSocketAddress import java.net.InetSocketAddress
import storage.BlockManager
class MessageChunkHeader( private[spark] class MessageChunkHeader(
val typ: Long, val typ: Long,
val id: Int, val id: Int,
val totalSize: Int, val totalSize: Int,
@ -36,7 +37,7 @@ class MessageChunkHeader(
" and sizes " + totalSize + " / " + chunkSize + " bytes" " and sizes " + totalSize + " / " + chunkSize + " bytes"
} }
class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
val size = if (buffer == null) 0 else buffer.remaining val size = if (buffer == null) 0 else buffer.remaining
lazy val buffers = { lazy val buffers = {
val ab = new ArrayBuffer[ByteBuffer]() val ab = new ArrayBuffer[ByteBuffer]()
@ -50,7 +51,7 @@ class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
} }
abstract class Message(val typ: Long, val id: Int) { private[spark] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null var senderAddress: InetSocketAddress = null
var started = false var started = false
var startTime = -1L var startTime = -1L
@ -64,10 +65,10 @@ abstract class Message(val typ: Long, val id: Int) {
def timeTaken(): String = (finishTime - startTime).toString + " ms" def timeTaken(): String = (finishTime - startTime).toString + " ms"
override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
} }
class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
extends Message(Message.BUFFER_MESSAGE, id_) { extends Message(Message.BUFFER_MESSAGE, id_) {
val initialSize = currentSize() val initialSize = currentSize()
@ -97,10 +98,11 @@ extends Message(Message.BUFFER_MESSAGE, id_) {
while(!buffers.isEmpty) { while(!buffers.isEmpty) {
val buffer = buffers(0) val buffer = buffers(0)
if (buffer.remaining == 0) { if (buffer.remaining == 0) {
BlockManager.dispose(buffer)
buffers -= buffer buffers -= buffer
} else { } else {
val newBuffer = if (buffer.remaining <= maxChunkSize) { val newBuffer = if (buffer.remaining <= maxChunkSize) {
buffer.duplicate buffer.duplicate()
} else { } else {
buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
} }
@ -147,11 +149,10 @@ extends Message(Message.BUFFER_MESSAGE, id_) {
} else { } else {
"BufferMessage(id = " + id + ", size = " + size + ")" "BufferMessage(id = " + id + ", size = " + size + ")"
} }
} }
} }
object MessageChunkHeader { private[spark] object MessageChunkHeader {
val HEADER_SIZE = 40 val HEADER_SIZE = 40
def create(buffer: ByteBuffer): MessageChunkHeader = { def create(buffer: ByteBuffer): MessageChunkHeader = {
@ -172,7 +173,7 @@ object MessageChunkHeader {
} }
} }
object Message { private[spark] object Message {
val BUFFER_MESSAGE = 1111111111L val BUFFER_MESSAGE = 1111111111L
var lastId = 1 var lastId = 1


@ -3,7 +3,7 @@ package spark.network
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.net.InetAddress import java.net.InetAddress
object ReceiverTest { private[spark] object ReceiverTest {
def main(args: Array[String]) { def main(args: Array[String]) {
val manager = new ConnectionManager(9999) val manager = new ConnectionManager(9999)


@ -3,7 +3,7 @@ package spark.network
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.net.InetAddress import java.net.InetAddress
object SenderTest { private[spark] object SenderTest {
def main(args: Array[String]) { def main(args: Array[String]) {


@ -12,7 +12,7 @@ import spark.scheduler.JobListener
* a result of type U for each partition, and that the action returns a partial or complete result * a result of type U for each partition, and that the action returns a partial or complete result
* of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedDouble). * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedDouble).
*/ */
class ApproximateActionListener[T, U, R]( private[spark] class ApproximateActionListener[T, U, R](
rdd: RDD[T], rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U, func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R], evaluator: ApproximateEvaluator[U, R],

@ -4,7 +4,7 @@ package spark.partial
* An object that computes a function incrementally by merging in results of type U from multiple * An object that computes a function incrementally by merging in results of type U from multiple
* tasks. Allows partial evaluation at any point by calling currentResult(). * tasks. Allows partial evaluation at any point by calling currentResult().
*/ */
trait ApproximateEvaluator[U, R] { private[spark] trait ApproximateEvaluator[U, R] {
def merge(outputId: Int, taskResult: U): Unit def merge(outputId: Int, taskResult: U): Unit
def currentResult(): R def currentResult(): R
} }

@ -3,6 +3,7 @@ package spark.partial
/** /**
* A Double with error bars on it. * A Double with error bars on it.
*/ */
private[spark]
class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) { class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
override def toString(): String = "[%.3f, %.3f]".format(low, high) override def toString(): String = "[%.3f, %.3f]".format(low, high)
} }
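
Given the constructor and toString above, a BoundedDouble renders as a three-decimal interval, for example:

    val estimate = new BoundedDouble(5.0, 0.95, 4.2, 5.8)
    println(estimate)   // prints [4.200, 5.800]
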

@ -8,7 +8,7 @@ import cern.jet.stat.Probability
* TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might
* be best to make this a special case of GroupedCountEvaluator with one group. * be best to make this a special case of GroupedCountEvaluator with one group.
*/ */
class CountEvaluator(totalOutputs: Int, confidence: Double) private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[Long, BoundedDouble] { extends ApproximateEvaluator[Long, BoundedDouble] {
var outputsMerged = 0 var outputsMerged = 0

@ -14,7 +14,7 @@ import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
/** /**
* An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval.
*/ */
class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double) private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] { extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] {
var outputsMerged = 0 var outputsMerged = 0

@ -12,7 +12,7 @@ import spark.util.StatCounter
/** /**
* An ApproximateEvaluator for means by key. Returns a map of key to confidence interval. * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval.
*/ */
class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double) private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
var outputsMerged = 0 var outputsMerged = 0

@ -12,7 +12,7 @@ import spark.util.StatCounter
/** /**
* An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval. * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval.
*/ */
class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double) private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
var outputsMerged = 0 var outputsMerged = 0

@ -7,7 +7,7 @@ import spark.util.StatCounter
/** /**
* An ApproximateEvaluator for means. * An ApproximateEvaluator for means.
*/ */
class MeanEvaluator(totalOutputs: Int, confidence: Double) private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] { extends ApproximateEvaluator[StatCounter, BoundedDouble] {
var outputsMerged = 0 var outputsMerged = 0
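
An evaluator like this one presumably reports a Student-t style interval for a partially computed mean: with n values merged so far, sample mean \bar{x}, sample standard deviation s, and confidence level c, the interval would have the form

    \bar{x} \;\pm\; t_{\,n-1,\;(1+c)/2} \cdot \frac{s}{\sqrt{n}}

where the two-sided t quantile is exactly the kind of value the StudentTCacher further below caches.
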

@ -1,6 +1,6 @@
package spark.partial package spark.partial
class PartialResult[R](initialVal: R, isFinal: Boolean) { private[spark] class PartialResult[R](initialVal: R, isFinal: Boolean) {
private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None
private var failure: Option[Exception] = None private var failure: Option[Exception] = None
private var completionHandler: Option[R => Unit] = None private var completionHandler: Option[R => Unit] = None
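
The hunk only shows PartialResult's three fields. Purely as an illustrative sketch of how such a holder could behave (the class and method names here are assumptions, not the class's actual API):

    // Hypothetical stand-in: a final value wakes up waiters and fires the handler; a failure is rethrown.
    class SimplePartialResult[R] {
      private var finalValue: Option[R] = None
      private var failure: Option[Exception] = None
      private var completionHandler: Option[R => Unit] = None

      def onComplete(handler: R => Unit): Unit = synchronized { completionHandler = Some(handler) }

      def setFinalValue(value: R): Unit = synchronized {
        finalValue = Some(value)
        completionHandler.foreach(h => h(value))
        notifyAll()
      }

      def setFailure(exception: Exception): Unit = synchronized {
        failure = Some(exception)
        notifyAll()
      }

      def getFinalValue(): R = synchronized {
        while (finalValue.isEmpty && failure.isEmpty) wait()
        failure.foreach(e => throw e)
        finalValue.get
      }
    }
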

@ -7,7 +7,7 @@ import cern.jet.stat.Probability
* and various sample sizes. This is used by the MeanEvaluator to efficiently calculate * and various sample sizes. This is used by the MeanEvaluator to efficiently calculate
* confidence intervals for many keys. * confidence intervals for many keys.
*/ */
class StudentTCacher(confidence: Double) { private[spark] class StudentTCacher(confidence: Double) {
val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2) val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0) val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
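
The cache above presumably memoizes a Student-t quantile per sample size and falls back to the precomputed Gaussian quantile for large samples. A self-contained sketch of that lookup pattern (the class name, the get method, and the tQuantile callback are our assumptions; the hunk only shows that the file imports cern.jet.stat.Probability):

    import cern.jet.stat.Probability

    // Illustrative only: memoized per-sample-size quantiles with a Gaussian fallback.
    class QuantileCache(confidence: Double, tQuantile: Int => Double) {
      val NORMAL_APPROX_SAMPLE_SIZE = 100
      val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
      val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)

      def get(sampleSize: Int): Double = {
        if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) {
          normalApprox
        } else {
          if (cache(sampleSize) < 0) {
            cache(sampleSize) = tQuantile(sampleSize - 1)  // quantile for n - 1 degrees of freedom
          }
          cache(sampleSize)
        }
      }
    }
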

@ -9,7 +9,7 @@ import spark.util.StatCounter
* together, then uses the formula for the variance of two independent random variables to get * together, then uses the formula for the variance of two independent random variables to get
* a variance for the result and compute a confidence interval. * a variance for the result and compute a confidence interval.
*/ */
class SumEvaluator(totalOutputs: Int, confidence: Double) private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] { extends ApproximateEvaluator[StatCounter, BoundedDouble] {
var outputsMerged = 0 var outputsMerged = 0
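
The "formula for the variance of two independent random variables" referenced in the comment above is presumably the standard identity for the variance of a product of independent estimates, written here for an estimated count \hat{N} and an estimated mean \hat{\mu}:

    \operatorname{Var}(\hat{N}\hat{\mu}) = \mathrm{E}[\hat{N}]^2 \operatorname{Var}(\hat{\mu}) + \mathrm{E}[\hat{\mu}]^2 \operatorname{Var}(\hat{N}) + \operatorname{Var}(\hat{N}) \operatorname{Var}(\hat{\mu})

A confidence interval around \hat{N}\hat{\mu} can then be formed from the square root of that variance and the chosen confidence level's quantile.
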

@ -5,11 +5,12 @@ import spark.TaskContext
/** /**
* Tracks information about an active job in the DAGScheduler. * Tracks information about an active job in the DAGScheduler.
*/ */
class ActiveJob( private[spark] class ActiveJob(
val runId: Int, val runId: Int,
val finalStage: Stage, val finalStage: Stage,
val func: (TaskContext, Iterator[_]) => _, val func: (TaskContext, Iterator[_]) => _,
val partitions: Array[Int], val partitions: Array[Int],
val callSite: String,
val listener: JobListener) { val listener: JobListener) {
val numPartitions = partitions.length val numPartitions = partitions.length

@ -21,6 +21,7 @@ import spark.storage.BlockManagerId
* schedule to run the job. Subclasses only need to implement the code to send a task to the cluster * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents). * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/ */
private[spark]
class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging { class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
taskSched.setListener(this) taskSched.setListener(this)
@ -38,6 +39,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
eventQueue.put(HostLost(host)) eventQueue.put(HostLost(host))
} }
// Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected; // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in // as more failure events come in
@ -116,7 +122,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = { def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = {
// Kind of ugly: need to register RDDs with the cache and map output tracker here // Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown // since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + ": " + rdd) logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
cacheTracker.registerRDD(rdd.id, rdd.splits.size) cacheTracker.registerRDD(rdd.id, rdd.splits.size)
if (shuffleDep != None) { if (shuffleDep != None) {
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
@ -139,7 +145,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visited += r visited += r
// Kind of ugly: need to register RDDs with the cache here since // Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown // we can't do it in its constructor because # of splits is unknown
logInfo("Registering parent RDD " + r.id + ": " + r) logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")")
cacheTracker.registerRDD(r.id, r.splits.size) cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) { for (dep <- r.dependencies) {
dep match { dep match {
@ -183,23 +189,25 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
missing.toList missing.toList
} }
def runJob[T, U]( def runJob[T, U: ClassManifest](
finalRdd: RDD[T], finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U, func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int], partitions: Seq[Int],
callSite: String,
allowLocal: Boolean) allowLocal: Boolean)
(implicit m: ClassManifest[U]): Array[U] = : Array[U] =
{ {
if (partitions.size == 0) { if (partitions.size == 0) {
return new Array[U](0) return new Array[U](0)
} }
val waiter = new JobWaiter(partitions.size) val waiter = new JobWaiter(partitions.size)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter)) eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
waiter.getResult() match { waiter.getResult() match {
case JobSucceeded(results: Seq[_]) => case JobSucceeded(results: Seq[_]) =>
return results.asInstanceOf[Seq[U]].toArray return results.asInstanceOf[Seq[U]].toArray
case JobFailed(exception: Exception) => case JobFailed(exception: Exception) =>
logInfo("Failed to run " + callSite)
throw exception throw exception
} }
} }
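
The signature change to runJob above is purely syntactic: the context bound [U: ClassManifest] is sugar for the trailing implicit parameter it replaces, and either form lets the body build new Array[U](0). A minimal illustration (the method names are ours):

    // Context-bound form, matching the new signature style.
    def emptyResult[U: ClassManifest](n: Int): Array[U] = new Array[U](n)

    // Equivalent desugared form, matching the old signature style.
    def emptyResultDesugared[U](n: Int)(implicit m: ClassManifest[U]): Array[U] = new Array[U](n)
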
@ -208,13 +216,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
rdd: RDD[T], rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U, func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R], evaluator: ApproximateEvaluator[U, R],
timeout: Long callSite: String,
): PartialResult[R] = timeout: Long)
: PartialResult[R] =
{ {
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.splits.size).toArray val partitions = (0 until rdd.splits.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener)) eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
return listener.getResult() // Will throw an exception if the job fails return listener.getResult() // Will throw an exception if the job fails
} }
@ -234,13 +243,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
} }
event match { event match {
case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) => case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
val runId = nextRunId.getAndIncrement() val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId) val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, listener) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
updateCacheLocs() updateCacheLocs()
logInfo("Got job " + job.runId + " with " + partitions.length + " output partitions") logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
logInfo("Final stage: " + finalStage) " output partitions")
logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
logInfo("Parents of final stage: " + finalStage.parents) logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage)) logInfo("Missing parents: " + getMissingParentStages(finalStage))
if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
@ -258,6 +268,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case completion: CompletionEvent => case completion: CompletionEvent =>
handleTaskCompletion(completion) handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
abortStage(idToStage(taskSet.stageId), reason)
case StopDAGScheduler => case StopDAGScheduler =>
// Cancel any active jobs // Cancel any active jobs
for (job <- activeJobs) { for (job <- activeJobs) {
@ -329,7 +342,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val missing = getMissingParentStages(stage).sortBy(_.id) val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing) logDebug("missing: " + missing)
if (missing == Nil) { if (missing == Nil) {
logInfo("Submitting " + stage + ", which has no missing parents") logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents")
submitMissingTasks(stage) submitMissingTasks(stage)
running += stage running += stage
} else { } else {
@ -416,7 +429,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
stage.addOutputLoc(smt.partition, bmAddress) stage.addOutputLoc(smt.partition, bmAddress)
} }
if (running.contains(stage) && pendingTasks(stage).isEmpty) { if (running.contains(stage) && pendingTasks(stage).isEmpty) {
logInfo(stage + " finished; looking for newly runnable stages") logInfo(stage + " (" + stage.origin + ") finished; looking for newly runnable stages")
running -= stage running -= stage
logInfo("running: " + running) logInfo("running: " + running)
logInfo("waiting: " + waiting) logInfo("waiting: " + waiting)
@ -430,7 +443,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (stage.outputLocs.count(_ == Nil) != 0) { if (stage.outputLocs.count(_ == Nil) != 0) {
// Some tasks had failed; let's resubmit this stage // Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this // TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + stage + " because some of its tasks had failed: " + logInfo("Resubmitting " + stage + " (" + stage.origin +
") because some of its tasks had failed: " +
stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", ")) stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
submitStage(stage) submitStage(stage)
} else { } else {
@ -444,6 +458,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
waiting --= newlyRunnable waiting --= newlyRunnable
running ++= newlyRunnable running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) { for (stage <- newlyRunnable.sortBy(_.id)) {
logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable")
submitMissingTasks(stage) submitMissingTasks(stage)
} }
} }
@ -460,12 +475,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
running -= failedStage running -= failedStage
failed += failedStage failed += failedStage
// TODO: Cancel running tasks in the stage // TODO: Cancel running tasks in the stage
logInfo("Marking " + failedStage + " for resubmision due to a fetch failure") logInfo("Marking " + failedStage + " (" + failedStage.origin +
") for resubmision due to a fetch failure")
// Mark the map whose fetch failed as broken in the map stage // Mark the map whose fetch failed as broken in the map stage
val mapStage = shuffleToMapStage(shuffleId) val mapStage = shuffleToMapStage(shuffleId)
mapStage.removeOutputLoc(mapId, bmAddress) mapStage.removeOutputLoc(mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission") logInfo("The failed fetch was from " + mapStage + " (" + mapStage.origin +
"); marking it for resubmission")
failed += mapStage failed += mapStage
// Remember that a fetch failed now; this is used to resubmit the broken // Remember that a fetch failed now; this is used to resubmit the broken
// stages later, after a small wait (to give other tasks the chance to fail) // stages later, after a small wait (to give other tasks the chance to fail)
@ -475,18 +492,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
handleHostLost(bmAddress.ip) handleHostLost(bmAddress.ip)
} }
case _ => case other =>
// Non-fetch failure -- probably a bug in the job, so bail out // Non-fetch failure -- probably a bug in user code; abort all jobs depending on this stage
// TODO: Cancel all tasks that are still running abortStage(idToStage(task.stageId), task + " failed: " + other)
resultStageToJob.get(stage) match {
case Some(job) =>
val error = new SparkException("Task failed: " + task + ", reason: " + event.reason)
job.listener.jobFailed(error)
activeJobs -= job
resultStageToJob -= stage
case None =>
logInfo("Ignoring result from " + task + " because its job has finished")
}
} }
} }
@ -510,6 +518,53 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
} }
} }
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
job.listener.jobFailed(new SparkException("Job failed: " + reason))
activeJobs -= job
resultStageToJob -= resultStage
}
if (dependentStages.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
}
}
/**
* Return true if one of stage's ancestors is target.
*/
def stageDependsOn(stage: Stage, target: Stage): Boolean = {
if (stage == target) {
return true
}
val visitedRdds = new HashSet[RDD[_]]
val visitedStages = new HashSet[Stage]
def visit(rdd: RDD[_]) {
if (!visitedRdds(rdd)) {
visitedRdds += rdd
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_,_] =>
val mapStage = getShuffleMapStage(shufDep, stage.priority)
if (!mapStage.isAvailable) {
visitedStages += mapStage
visit(mapStage.rdd)
} // Otherwise there's no need to follow the dependency back
case narrowDep: NarrowDependency[_] =>
visit(narrowDep.rdd)
}
}
}
}
visit(stage.rdd)
visitedRdds.contains(target.rdd)
}
def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations // If the partition is cached, return the cache locations
val cached = getCacheLocs(rdd)(partition) val cached = getCacheLocs(rdd)(partition)

@ -10,23 +10,26 @@ import spark._
* submitted) but there is a single "logic" thread that reads these events and takes decisions. * submitted) but there is a single "logic" thread that reads these events and takes decisions.
* This greatly simplifies synchronization. * This greatly simplifies synchronization.
*/ */
sealed trait DAGSchedulerEvent private[spark] sealed trait DAGSchedulerEvent
case class JobSubmitted( private[spark] case class JobSubmitted(
finalRDD: RDD[_], finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _, func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int], partitions: Array[Int],
allowLocal: Boolean, allowLocal: Boolean,
callSite: String,
listener: JobListener) listener: JobListener)
extends DAGSchedulerEvent extends DAGSchedulerEvent
case class CompletionEvent( private[spark] case class CompletionEvent(
task: Task[_], task: Task[_],
reason: TaskEndReason, reason: TaskEndReason,
result: Any, result: Any,
accumUpdates: Map[Long, Any]) accumUpdates: Map[Long, Any])
extends DAGSchedulerEvent extends DAGSchedulerEvent
case class HostLost(host: String) extends DAGSchedulerEvent private[spark] case class HostLost(host: String) extends DAGSchedulerEvent
case object StopDAGScheduler extends DAGSchedulerEvent private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
private[spark] case object StopDAGScheduler extends DAGSchedulerEvent

Some files were not shown because too many files have changed in this diff.