Merge branch 'master' into mesos-0.9

Conflicts:
	core/src/main/scala/spark/Executor.scala
Matei Zaharia 2012-06-03 17:44:04 -07:00
Parents bd2ab635a7 1dd7d3dfff
Commit dbc3c86ae3
31 changed files with 655 additions and 185 deletions

View file

@ -126,6 +126,10 @@ class WPRSerializerInstance extends SerializerInstance {
throw new UnsupportedOperationException()
}
def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
throw new UnsupportedOperationException()
}
def outputStream(s: OutputStream): SerializationStream = {
new WPRSerializationStream(s)
}

View file

@ -9,19 +9,19 @@ import java.util.LinkedHashMap
* some cache entries have pointers to a shared object. Nonetheless, this Cache should work well
* when most of the space is used by arrays of primitives or of simple classes.
*/
class BoundedMemoryCache extends Cache with Logging {
private val maxBytes: Long = getMaxBytes()
class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
def this() {
this(BoundedMemoryCache.getMaxBytes)
}
private var currentBytes = 0L
private val map = new LinkedHashMap[Any, Entry](32, 0.75f, true)
private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true)
// An entry in our map; stores a cached object and its size in bytes
class Entry(val value: Any, val size: Long) {}
override def get(key: Any): Any = {
override def get(datasetId: Any, partition: Int): Any = {
synchronized {
val entry = map.get(key)
val entry = map.get((datasetId, partition))
if (entry != null) {
entry.value
} else {
@ -30,46 +30,80 @@ class BoundedMemoryCache extends Cache with Logging {
}
}
override def put(key: Any, value: Any) {
override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
val key = (datasetId, partition)
logInfo("Asked to add key " + key)
val size = estimateValueSize(key, value)
synchronized {
if (size > getCapacity) {
return CachePutFailure()
} else if (ensureFreeSpace(datasetId, size)) {
logInfo("Adding key " + key)
map.put(key, new Entry(value, size))
currentBytes += size
logInfo("Number of entries is now " + map.size)
return CachePutSuccess(size)
} else {
logInfo("Didn't add key " + key + " because we would have evicted part of same dataset")
return CachePutFailure()
}
}
}
override def getCapacity: Long = maxBytes
/**
* Estimate sizeOf 'value'
*/
private def estimateValueSize(key: (Any, Int), value: Any) = {
val startTime = System.currentTimeMillis
val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef])
val timeTaken = System.currentTimeMillis - startTime
logInfo("Estimated size for key %s is %d".format(key, size))
logInfo("Size estimation for key %s took %d ms".format(key, timeTaken))
synchronized {
ensureFreeSpace(size)
logInfo("Adding key " + key)
map.put(key, new Entry(value, size))
currentBytes += size
logInfo("Number of entries is now " + map.size)
}
}
private def getMaxBytes(): Long = {
val memoryFractionToUse = System.getProperty(
"spark.boundedMemoryCache.memoryFraction", "0.66").toDouble
(Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong
size
}
/**
* Remove least recently used entries from the map until at least space bytes are free. Assumes
* Remove least recently used entries from the map until at least space bytes are free, in order
* to make space for a partition from the given dataset ID. If this cannot be done without
* evicting other data from the same dataset, returns false; otherwise, returns true. Assumes
* that a lock is held on the BoundedMemoryCache.
*/
private def ensureFreeSpace(space: Long) {
logInfo("ensureFreeSpace(%d) called with curBytes=%d, maxBytes=%d".format(
space, currentBytes, maxBytes))
val iter = map.entrySet.iterator
private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = {
logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format(
datasetId, space, currentBytes, maxBytes))
val iter = map.entrySet.iterator // Will give entries in LRU order
while (maxBytes - currentBytes < space && iter.hasNext) {
val mapEntry = iter.next()
dropEntry(mapEntry.getKey, mapEntry.getValue)
val (entryDatasetId, entryPartition) = mapEntry.getKey
if (entryDatasetId == datasetId) {
// Cannot make space without removing part of the same dataset, or a more recently used one
return false
}
reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue)
currentBytes -= mapEntry.getValue.size
iter.remove()
}
return true
}
protected def dropEntry(key: Any, entry: Entry) {
logInfo("Dropping key %s of size %d to make space".format(key, entry.size))
SparkEnv.get.cacheTracker.dropEntry(key)
protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
SparkEnv.get.cacheTracker.dropEntry(datasetId, partition)
}
}
// An entry in our map; stores a cached object and its size in bytes
case class Entry(value: Any, size: Long)
object BoundedMemoryCache {
/**
* Get maximum cache capacity from system configuration
*/
def getMaxBytes: Long = {
val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble
(Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong
}
}
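The "Will give entries in LRU order" comment above relies on the third constructor argument of java.util.LinkedHashMap (accessOrder = true), which makes iteration run from the least to the most recently accessed entry, so ensureFreeSpace evicts in LRU order. A minimal, hypothetical standalone demo of that behavior (not part of this commit):
import java.util.LinkedHashMap
object LruOrderDemo {
  def main(args: Array[String]) {
    // accessOrder = true: iteration runs from least to most recently accessed
    val map = new LinkedHashMap[String, Int](32, 0.75f, true)
    map.put("a", 1)
    map.put("b", 2)
    map.put("c", 3)
    map.get("a")                      // touching "a" makes it the most recently used
    val iter = map.entrySet.iterator  // yields b, c, a -- the eviction order used above
    while (iter.hasNext) {
      val entry = iter.next()
      println(entry.getKey + " -> " + entry.getValue)
    }
  }
}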

View file

@ -1,10 +1,16 @@
package spark
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.atomic.AtomicInteger
sealed trait CachePutResponse
case class CachePutSuccess(size: Long) extends CachePutResponse
case class CachePutFailure() extends CachePutResponse
/**
* An interface for caches in Spark, to allow for multiple implementations. Caches are used to store
* both partitions of cached RDDs and broadcast variables on Spark executors.
* both partitions of cached RDDs and broadcast variables on Spark executors. Caches are also aware
* of which entries are part of the same dataset (for example, partitions in the same RDD). The key
* for each value in a cache is a (datasetID, partition) pair.
*
* A single Cache instance gets created on each machine and is shared by all caches (i.e. both the
* RDD split cache and the broadcast variable cache), to enable global replacement policies.
@ -17,19 +23,41 @@ import java.util.concurrent.atomic.AtomicLong
* keys that are unique across modules.
*/
abstract class Cache {
private val nextKeySpaceId = new AtomicLong(0)
private val nextKeySpaceId = new AtomicInteger(0)
private def newKeySpaceId() = nextKeySpaceId.getAndIncrement()
def newKeySpace() = new KeySpace(this, newKeySpaceId())
def get(key: Any): Any
def put(key: Any, value: Any): Unit
/**
* Get the value for a given (datasetId, partition), or null if it is not
* found.
*/
def get(datasetId: Any, partition: Int): Any
/**
* Attempt to put a value in the cache; returns CachePutFailure if this was
* not successful (e.g. because the cache replacement policy forbids it), and
* CachePutSuccess if successful. If size estimation is available, the cache
* implementation should set the size field in CachePutSuccess.
*/
def put(datasetId: Any, partition: Int, value: Any): CachePutResponse
/**
* Report the capacity of the cache partition. By default this just reports
* zero. Specific implementations can choose to provide the capacity number.
*/
def getCapacity: Long = 0L
}
/**
* A key namespace in a Cache.
*/
class KeySpace(cache: Cache, id: Long) {
def get(key: Any): Any = cache.get((id, key))
def put(key: Any, value: Any): Unit = cache.put((id, key), value)
class KeySpace(cache: Cache, val keySpaceId: Int) {
def get(datasetId: Any, partition: Int): Any =
cache.get((keySpaceId, datasetId), partition)
def put(datasetId: Any, partition: Int, value: Any): CachePutResponse =
cache.put((keySpaceId, datasetId), partition, value)
def getCapacity: Long = cache.getCapacity
}
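To make the new keying scheme concrete, here is a minimal hypothetical Cache implementation backed by a mutable HashMap. It only illustrates the (datasetId, partition) API and the CachePutResponse protocol described above; it is not one of the caches in this commit, and it assumes it is compiled in package spark so the Cache types are visible.
package spark
import scala.collection.mutable.HashMap
// Illustration only: an unbounded cache that always accepts puts.
class SimpleMapCache extends Cache {
  private val map = new HashMap[(Any, Int), Any]
  override def get(datasetId: Any, partition: Int): Any = synchronized {
    map.getOrElse((datasetId, partition), null)
  }
  override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = synchronized {
    map.put((datasetId, partition), value)
    CachePutSuccess(0) // no size estimation available, so report 0 like SoftReferenceCache
  }
}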

View file

@ -7,16 +7,32 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
sealed trait CacheTrackerMessage
case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage
case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage
case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
case class MemoryCacheLost(host: String) extends CacheTrackerMessage
case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
case object GetCacheStatus extends CacheTrackerMessage
case object GetCacheLocations extends CacheTrackerMessage
case object StopCacheTracker extends CacheTrackerMessage
class CacheTrackerActor extends DaemonActor with Logging {
val locs = new HashMap[Int, Array[List[String]]]
private val locs = new HashMap[Int, Array[List[String]]]
/**
* A map from the slave's host name to its cache size.
*/
private val slaveCapacity = new HashMap[String, Long]
private val slaveUsage = new HashMap[String, Long]
// TODO: Should probably store (String, CacheType) tuples
private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
def act() {
val port = System.getProperty("spark.master.port").toInt
@ -26,31 +42,61 @@ class CacheTrackerActor extends DaemonActor with Logging {
loop {
react {
case SlaveCacheStarted(host: String, size: Long) =>
logInfo("Started slave cache (size %s) on %s".format(
Utils.memoryBytesToString(size), host))
slaveCapacity.put(host, size)
slaveUsage.put(host, 0)
reply('OK)
case RegisterRDD(rddId: Int, numPartitions: Int) =>
logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
reply('OK)
case AddedToCache(rddId, partition, host) =>
logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
case AddedToCache(rddId, partition, host, size) =>
if (size > 0) {
slaveUsage.put(host, getCacheUsage(host) + size)
logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
Utils.memoryBytesToString(getCacheAvailable(host))))
} else {
logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
}
locs(rddId)(partition) = host :: locs(rddId)(partition)
reply('OK)
case DroppedFromCache(rddId, partition, host) =>
logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
case DroppedFromCache(rddId, partition, host, size) =>
if (size > 0) {
logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
Utils.memoryBytesToString(getCacheAvailable(host))))
slaveUsage.put(host, getCacheUsage(host) - size)
// Sanity check: usage should never go negative.
val usage = getCacheUsage(host)
if (usage < 0) {
logError("Cache usage on %s is negative (%d)".format(host, usage))
}
} else {
logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
}
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
reply('OK)
case MemoryCacheLost(host) =>
logInfo("Memory cache lost on " + host)
// TODO: Drop host from the memory locations list of all RDDs
case GetCacheLocations =>
logInfo("Asked for current cache locations")
val locsCopy = new HashMap[Int, Array[List[String]]]
for ((rddId, array) <- locs) {
locsCopy(rddId) = array.clone()
}
reply(locsCopy)
reply(locs.map { case (rddId, array) => (rddId -> array.clone()) })
case GetCacheStatus =>
val status = slaveCapacity.map { case (host,capacity) =>
(host, capacity, getCacheUsage(host))
}.toSeq
reply(status)
case StopCacheTracker =>
reply('OK)
@ -60,10 +106,16 @@ class CacheTrackerActor extends DaemonActor with Logging {
}
}
class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
// Tracker actor on the master, or remote reference to it on workers
var trackerActor: AbstractActor = null
val registeredRddIds = new HashSet[Int]
// Stores map results for various splits locally
val cache = theCache.newKeySpace()
if (isMaster) {
val tracker = new CacheTrackerActor
tracker.start()
@ -74,10 +126,8 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker)
}
val registeredRddIds = new HashSet[Int]
// Stores map results for various splits locally
val cache = theCache.newKeySpace()
// Report the cache being started.
trackerActor !? SlaveCacheStarted(Utils.getHost, cache.getCapacity)
// Remembers which splits are currently being loaded (on worker nodes)
val loading = new HashSet[(Int, Int)]
@ -92,65 +142,92 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
}
}
}
// Get a snapshot of the currently known locations
def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
(trackerActor !? GetCacheLocations) match {
case h: HashMap[_, _] =>
h.asInstanceOf[HashMap[Int, Array[List[String]]]]
case _ =>
throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]]
case _ => throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
}
}
// Get the usage status of slave caches. Each tuple in the returned sequence
// is in the form of (host name, capacity, usage).
def getCacheStatus(): Seq[(String, Long, Long)] = {
(trackerActor !? GetCacheStatus) match {
case h: Seq[(String, Long, Long)] => h.asInstanceOf[Seq[(String, Long, Long)]]
case _ =>
throw new SparkException(
"Internal error: CacheTrackerActor did not reply with a Seq[Tuple3[String, Long, Long]")
}
}
// Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = {
val key = (rdd.id, split.index)
logInfo("CachedRDD partition key is " + key)
val cachedVal = cache.get(key)
logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index))
val cachedVal = cache.get(rdd.id, split.index)
if (cachedVal != null) {
// Split is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedVal.asInstanceOf[Array[T]].iterator
} else {
// Mark the split as loading (unless someone else marks it first)
val key = (rdd.id, split.index)
loading.synchronized {
if (loading.contains(key)) {
while (loading.contains(key)) {
try {loading.wait()} catch {case _ =>}
}
return cache.get(key).asInstanceOf[Array[T]].iterator
} else {
loading.add(key)
while (loading.contains(key)) {
// Someone else is loading it; let's wait for them
try { loading.wait() } catch { case _ => }
}
// See whether someone else has successfully loaded it. The main way this would fail
// is for the RDD-level cache eviction policy if someone else has loaded the same RDD
// partition but we didn't want to make space for it. However, that case is unlikely
// because it's unlikely that two threads would work on the same RDD partition. One
// downside of the current code is that threads wait serially if this does happen.
val cachedVal = cache.get(rdd.id, split.index)
if (cachedVal != null) {
return cachedVal.asInstanceOf[Array[T]].iterator
}
// Nobody's loading it and it's not in the cache; let's load it ourselves
loading.add(key)
}
// If we got here, we have to load the split
// Tell the master that we're doing so
val host = System.getProperty("spark.hostname", Utils.localHostName)
val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
// TODO: fetch any remote copy of the split that may be available
// TODO: also register a listener for when it unloads
logInfo("Computing partition " + split)
val array = rdd.compute(split).toArray(m)
cache.put(key, array)
loading.synchronized {
loading.remove(key)
loading.notifyAll()
var array: Array[T] = null
var putResponse: CachePutResponse = null
try {
array = rdd.compute(split).toArray(m)
putResponse = cache.put(rdd.id, split.index, array)
} finally {
// Tell other threads that we've finished our attempt to load the key (whether or not
// we've actually succeeded to put it in the map)
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
putResponse match {
case CachePutSuccess(size) => {
// Tell the master that we added the entry. Don't return until it
// replies so it can properly schedule future tasks that use this RDD.
trackerActor !? AddedToCache(rdd.id, split.index, Utils.getHost, size)
}
case _ => null
}
future.apply() // Wait for the reply from the cache tracker
return array.iterator
}
}
// Reports that an entry has been dropped from the cache
def dropEntry(key: Any) {
key match {
case (keySpaceId: Long, (rddId: Int, partition: Int)) =>
val host = System.getProperty("spark.hostname", Utils.localHostName)
trackerActor !! DroppedFromCache(rddId, partition, host)
case _ =>
logWarning("Unknown key format: %s".format(key))
// Called by the Cache to report that an entry has been dropped from it
def dropEntry(datasetId: Any, partition: Int) {
datasetId match {
// TODO: do we really want to use '!!' when nobody checks the returned future? '!' seems to be enough here.
case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost)
}
}
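The getOrCompute change above coordinates concurrent loads of the same partition through the `loading` set plus wait/notifyAll. The generic sketch below shows the same pattern in isolation; SingleLoader is a hypothetical stand-in, not Spark's CacheTracker, but it keeps the two essential details: re-checking the result after waking up, and always notifying waiters in a finally block even if the computation fails.
import scala.collection.mutable.{HashMap, HashSet}
// Generic sketch of the "only one thread computes a key" pattern; both
// collections are guarded by the monitor of `loading`, and compute() runs
// outside the lock.
class SingleLoader[K, V](compute: K => V) {
  private val loading = new HashSet[K]
  private val results = new HashMap[K, V]
  def getOrLoad(key: K): V = {
    loading.synchronized {
      if (results.contains(key)) return results(key)
      // If another thread is computing this key, wait for it, then re-check.
      while (loading.contains(key)) {
        try { loading.wait() } catch { case _: InterruptedException => }
      }
      if (results.contains(key)) return results(key)
      // Nobody is computing it and it is not cached; claim it ourselves.
      loading.add(key)
    }
    try {
      val value = compute(key)
      loading.synchronized { results(key) = value }
      value
    } finally {
      // Always wake up waiters, whether or not compute() succeeded.
      loading.synchronized {
        loading.remove(key)
        loading.notifyAll()
      }
    }
  }
}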

View file

@ -9,31 +9,31 @@ import java.util.UUID
// TODO: cache into a separate directory using Utils.createTempDir
// TODO: clean up disk cache afterwards
class DiskSpillingCache extends BoundedMemoryCache {
private val diskMap = new LinkedHashMap[Any, File](32, 0.75f, true)
private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true)
override def get(key: Any): Any = {
override def get(datasetId: Any, partition: Int): Any = {
synchronized {
val ser = SparkEnv.get.serializer.newInstance()
super.get(key) match {
super.get(datasetId, partition) match {
case bytes: Any => // found in memory
ser.deserialize(bytes.asInstanceOf[Array[Byte]])
case _ => diskMap.get(key) match {
case _ => diskMap.get((datasetId, partition)) match {
case file: Any => // found on disk
try {
val startTime = System.currentTimeMillis
val bytes = new Array[Byte](file.length.toInt)
new FileInputStream(file).read(bytes)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Reading key %s of size %d bytes from disk took %d ms".format(
key, file.length, timeTaken))
super.put(key, bytes)
logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format(
datasetId, partition, file.length, timeTaken))
super.put(datasetId, partition, bytes)
ser.deserialize(bytes.asInstanceOf[Array[Byte]])
} catch {
case e: IOException =>
logWarning("Failed to read key %s from disk at %s: %s".format(
key, file.getPath(), e.getMessage()))
diskMap.remove(key) // remove dead entry
logWarning("Failed to read key (%s, %d) from disk at %s: %s".format(
datasetId, partition, file.getPath(), e.getMessage()))
diskMap.remove((datasetId, partition)) // remove dead entry
null
}
@ -44,18 +44,18 @@ class DiskSpillingCache extends BoundedMemoryCache {
}
}
override def put(key: Any, value: Any) {
override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
var ser = SparkEnv.get.serializer.newInstance()
super.put(key, ser.serialize(value))
super.put(datasetId, partition, ser.serialize(value))
}
/**
* Spill the given entry to disk. Assumes that a lock is held on the
* DiskSpillingCache. Assumes that entry.value is a byte array.
*/
override protected def dropEntry(key: Any, entry: Entry) {
logInfo("Spilling key %s of size %d to make space".format(
key, entry.size))
override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
logInfo("Spilling key (%s, %d) of size %d to make space".format(
datasetId, partition, entry.size))
val cacheDir = System.getProperty(
"spark.diskSpillingCache.cacheDir",
System.getProperty("java.io.tmpdir"))
@ -64,11 +64,11 @@ class DiskSpillingCache extends BoundedMemoryCache {
val stream = new FileOutputStream(file)
stream.write(entry.value.asInstanceOf[Array[Byte]])
stream.close()
diskMap.put(key, file)
diskMap.put((datasetId, partition), file)
} catch {
case e: IOException =>
logWarning("Failed to spill key %s to disk at %s: %s".format(
key, file.getPath(), e.getMessage()))
logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format(
datasetId, partition, file.getPath(), e.getMessage()))
// Do nothing and let the entry be discarded
}
}

View file

@ -65,16 +65,17 @@ class Executor extends org.apache.mesos.Executor with Logging {
extends Runnable {
override def run() = {
val tid = info.getTaskId.getValue
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(classLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + tid)
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(info.getTaskId)
.setState(TaskState.TASK_RUNNING)
.build())
try {
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear
val task = Utils.deserialize[Task[Any]](info.getData.toByteArray, classLoader)
val task = ser.deserialize[Task[Any]](info.getData.toByteArray, classLoader)
for (gen <- task.generation) {// Update generation if any is set
env.mapOutputTracker.updateGeneration(gen)
}
@ -84,7 +85,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(info.getTaskId)
.setState(TaskState.TASK_FINISHED)
.setData(ByteString.copyFrom(Utils.serialize(result)))
.setData(ByteString.copyFrom(ser.serialize(result)))
.build())
logInfo("Finished task ID " + tid)
} catch {
@ -93,7 +94,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(info.getTaskId)
.setState(TaskState.TASK_FAILED)
.setData(ByteString.copyFrom(Utils.serialize(reason)))
.setData(ByteString.copyFrom(ser.serialize(reason)))
.build())
}
case t: Throwable => {
@ -101,7 +102,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(info.getTaskId)
.setState(TaskState.TASK_FAILED)
.setData(ByteString.copyFrom(Utils.serialize(reason)))
.setData(ByteString.copyFrom(ser.serialize(reason)))
.build())
// TODO: Handle errors in tasks less dramatically

View file

@ -34,6 +34,15 @@ class JavaSerializerInstance extends SerializerInstance {
in.readObject().asInstanceOf[T]
}
def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
val bis = new ByteArrayInputStream(bytes)
val ois = new ObjectInputStream(bis) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
return ois.readObject.asInstanceOf[T]
}
def outputStream(s: OutputStream): SerializationStream = {
new JavaSerializationStream(s)
}
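The resolveClass override added above is needed because ObjectInputStream otherwise resolves classes against the nearest caller's class loader, which on an executor may not see classes shipped with the user's job. A standalone sketch of the same trick with explanatory comments (the object name is illustrative only):
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectStreamClass}
// Standalone illustration of the class-loader-aware deserialization above.
object LoaderAwareDeserializer {
  def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes)) {
      // The default resolveClass uses the nearest caller's class loader; route
      // class lookups through the supplied loader instead.
      override def resolveClass(desc: ObjectStreamClass) =
        Class.forName(desc.getName, false, loader)
    }
    ois.readObject.asInstanceOf[T]
  }
}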

View file

@ -9,6 +9,7 @@ import scala.collection.mutable
import com.esotericsoftware.kryo._
import com.esotericsoftware.kryo.{Serializer => KSerializer}
import com.esotericsoftware.kryo.serialize.ClassSerializer
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
/**
@ -100,6 +101,14 @@ class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
buf.readClassAndObject(bytes).asInstanceOf[T]
}
def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
val oldClassLoader = ks.kryo.getClassLoader
ks.kryo.setClassLoader(loader)
val obj = buf.readClassAndObject(bytes).asInstanceOf[T]
ks.kryo.setClassLoader(oldClassLoader)
obj
}
def outputStream(s: OutputStream): SerializationStream = {
new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s)
}
@ -129,6 +138,8 @@ class KryoSerializer extends Serializer with Logging {
}
def createKryo(): Kryo = {
// This is used so we can serialize/deserialize objects without a zero-arg
// constructor.
val kryo = new KryoReflectionFactorySupport()
// Register some commonly used classes
@ -150,6 +161,10 @@ class KryoSerializer extends Serializer with Logging {
kryo.register(obj.getClass)
}
// Register the following classes for passing closures.
kryo.register(classOf[Class[_]], new ClassSerializer(kryo))
kryo.setRegistrationOptional(true)
// Register some commonly used Scala singleton objects. Because these
// are singletons, we must return the exact same local object when we
// deserialize rather than returning a clone as FieldSerializer would.

View file

@ -38,14 +38,23 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
Accumulators.clear
val bytes = Utils.serialize(task)
logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
val deserializedTask = Utils.deserialize[Task[_]](
bytes, Thread.currentThread.getContextClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
val startTime = System.currentTimeMillis
val bytes = ser.serialize(task)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Size of task %d is %d bytes and took %d ms to serialize".format(
idInJob, bytes.size, timeTaken))
val deserializedTask = ser.deserialize[Task[_]](bytes, currentThread.getContextClassLoader)
val result: Any = deserializedTask.run(attemptId)
// Serialize and deserialize the result to emulate what the mesos
// executor does. This is useful to catch serialization errors early
// on in development (so when users move their local Spark programs
// to the cluster, they don't get surprised by serialization errors).
val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = Accumulators.values
logInfo("Finished task " + idInJob)
taskEnded(task, Success, result, accumUpdates)
taskEnded(task, Success, resultToReturn, accumUpdates)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@ -55,7 +64,7 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
System.exit(1)
taskEnded(task, new ExceptionFailure(t), null, null)
}
}
}

View file

@ -42,7 +42,7 @@ private class MesosScheduler(
// Memory used by each executor (in megabytes)
val EXECUTOR_MEMORY = {
if (System.getenv("SPARK_MEM") != null) {
memoryStringToMb(System.getenv("SPARK_MEM"))
MesosScheduler.memoryStringToMb(System.getenv("SPARK_MEM"))
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
} else {
512
@ -81,9 +81,7 @@ private class MesosScheduler(
// Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first)
private val jobOrdering = new Ordering[Job] {
override def compare(j1: Job, j2: Job): Int = {
return j2.runId - j1.runId
}
override def compare(j1: Job, j2: Job): Int = j2.runId - j1.runId
}
def newJobId(): Int = this.synchronized {
@ -162,7 +160,7 @@ private class MesosScheduler(
activeJobs(jobId) = myJob
activeJobsQueue += myJob
logInfo("Adding job with ID " + jobId)
jobTasks(jobId) = new HashSet()
jobTasks(jobId) = HashSet.empty[String]
}
driver.reviveOffers();
}
@ -390,23 +388,26 @@ private class MesosScheduler(
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
}
object MesosScheduler {
/**
* Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
* This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
* Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
* This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
* environment variable.
*/
def memoryStringToMb(str: String): Int = {
val lower = str.toLowerCase
if (lower.endsWith("k")) {
(lower.substring(0, lower.length-1).toLong / 1024).toInt
(lower.substring(0, lower.length - 1).toLong / 1024).toInt
} else if (lower.endsWith("m")) {
lower.substring(0, lower.length-1).toInt
lower.substring(0, lower.length - 1).toInt
} else if (lower.endsWith("g")) {
lower.substring(0, lower.length-1).toInt * 1024
lower.substring(0, lower.length - 1).toInt * 1024
} else if (lower.endsWith("t")) {
lower.substring(0, lower.length-1).toInt * 1024 * 1024
} else {// no suffix, so it's just a number in bytes
lower.substring(0, lower.length - 1).toInt * 1024 * 1024
} else {
// no suffix, so it's just a number in bytes
(lower.toLong / 1024 / 1024).toInt
}
}

View file

@ -3,6 +3,7 @@ package spark
import java.io.PrintWriter
import java.util.StringTokenizer
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
@ -10,8 +11,12 @@ import scala.io.Source
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
*/
class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
class PipedRDD[T: ClassManifest](
parent: RDD[T], command: Seq[String], envVars: Map[String, String])
extends RDD[String](parent.context) {
def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map())
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command))
@ -21,7 +26,12 @@ class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
override val dependencies = List(new OneToOneDependency(parent))
override def compute(split: Split): Iterator[String] = {
val proc = Runtime.getRuntime.exec(command.toArray)
val pb = new ProcessBuilder(command)
// Add the environmental variables to the process.
val currentEnvVars = pb.environment()
envVars.foreach { case(variable, value) => currentEnvVars.put(variable, value) }
val proc = pb.start()
val env = SparkEnv.get
// Start a thread to print the process's stderr to ours

View file

@ -9,8 +9,6 @@ import java.util.Random
import java.util.Date
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.Map
import scala.collection.mutable.HashMap
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
@ -50,7 +48,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
// Methods that must be implemented by subclasses
def splits: Array[Split]
def compute(split: Split): Iterator[T]
val dependencies: List[Dependency[_]]
@transient val dependencies: List[Dependency[_]]
// Optionally overridden by subclasses to specify how they are partitioned
val partitioner: Option[Partitioner] = None
@ -146,6 +144,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f))

View file

@ -16,6 +16,7 @@ trait Serializer {
trait SerializerInstance {
def serialize[T](t: T): Array[Byte]
def deserialize[T](bytes: Array[Byte]): T
def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T
def outputStream(s: OutputStream): SerializationStream
def inputStream(s: InputStream): DeserializationStream
}

View file

@ -9,13 +9,13 @@ import java.io._
class SerializingCache extends Cache with Logging {
val bmc = new BoundedMemoryCache
override def put(key: Any, value: Any) {
override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
val ser = SparkEnv.get.serializer.newInstance()
bmc.put(key, ser.serialize(value))
bmc.put(datasetId, partition, ser.serialize(value))
}
override def get(key: Any): Any = {
val bytes = bmc.get(key)
override def get(datasetId: Any, partition: Int): Any = {
val bytes = bmc.get(datasetId, partition)
if (bytes != null) {
val ser = SparkEnv.get.serializer.newInstance()
return ser.deserialize(bytes.asInstanceOf[Array[Byte]])

View file

@ -30,6 +30,9 @@ class SimpleJob(
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = 4
// Serializer for closures and tasks.
val ser = SparkEnv.get.closureSerializer.newInstance()
val callingThread = Thread.currentThread
val tasks = tasksSeq.toArray
val numTasks = tasks.length
@ -170,8 +173,14 @@ class SimpleJob(
.setType(Value.Type.SCALAR)
.setScalar(Value.Scalar.newBuilder().setValue(CPUS_PER_TASK).build())
.build()
val serializedTask = Utils.serialize(task)
logDebug("Serialized size: " + serializedTask.size)
val startTime = System.currentTimeMillis
val serializedTask = ser.serialize(task)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s"
.format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName))
val taskName = "task %d:%d".format(jobId, index)
return Some(TaskInfo.newBuilder()
.setTaskId(taskId)
@ -209,7 +218,8 @@ class SimpleJob(
tasksFinished += 1
logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks))
// Deserialize task result
val result = Utils.deserialize[TaskResult[_]](status.getData.toByteArray)
val result = ser.deserialize[TaskResult[_]](
status.getData.toByteArray, getClass.getClassLoader)
sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
// Mark finished and stop if we've finished all the tasks
finished(index) = true
@ -231,7 +241,8 @@ class SimpleJob(
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
if (status.getData != null && status.getData.size > 0) {
val reason = Utils.deserialize[TaskEndReason](status.getData.toByteArray)
val reason = ser.deserialize[TaskEndReason](
status.getData.toByteArray, getClass.getClassLoader)
reason match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri)

View file

@ -8,6 +8,11 @@ import com.google.common.collect.MapMaker
class SoftReferenceCache extends Cache {
val map = new MapMaker().softValues().makeMap[Any, Any]()
override def get(key: Any): Any = map.get(key)
override def put(key: Any, value: Any) = map.put(key, value)
override def get(datasetId: Any, partition: Int): Any =
map.get((datasetId, partition))
override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
map.put((datasetId, partition), value)
return CachePutSuccess(0)
}
}

View file

@ -3,6 +3,7 @@ package spark
class SparkEnv (
val cache: Cache,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheTracker: CacheTracker,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
@ -27,6 +28,11 @@ object SparkEnv {
val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
val closureSerializerClass =
System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
val closureSerializer =
Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
val cacheTracker = new CacheTracker(isMaster, cache)
val mapOutputTracker = new MapOutputTracker(isMaster)
@ -38,6 +44,13 @@ object SparkEnv {
val shuffleMgr = new ShuffleManager()
new SparkEnv(cache, serializer, cacheTracker, mapOutputTracker, shuffleFetcher, shuffleMgr)
new SparkEnv(
cache,
serializer,
closureSerializer,
cacheTracker,
mapOutputTracker,
shuffleFetcher,
shuffleMgr)
}
}
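The closure serializer gets its own system property, spark.closure.serializer, defaulting to spark.JavaSerializer. A sketch of how a user program might set the two serializer properties; the property names and class names come from the code above, but the particular combination and the context name are just an example:
// Set before constructing the SparkContext so SparkEnv picks the values up.
System.setProperty("spark.serializer", "spark.KryoSerializer")           // RDD data
System.setProperty("spark.closure.serializer", "spark.JavaSerializer")   // closures and tasks
val sc = new SparkContext("local", "example")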

View file

@ -16,7 +16,7 @@ class UnionSplit[T: ClassManifest](
class UnionRDD[T: ClassManifest](
sc: SparkContext,
rdds: Seq[RDD[T]])
@transient rdds: Seq[RDD[T]])
extends RDD[T](sc)
with Serializable {
@ -33,7 +33,7 @@ class UnionRDD[T: ClassManifest](
override def splits = splits_
override val dependencies = {
@transient override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for ((rdd, index) <- rdds.zipWithIndex) {
@ -47,4 +47,4 @@ class UnionRDD[T: ClassManifest](
override def preferredLocations(s: Split): Seq[String] =
s.asInstanceOf[UnionSplit[T]].preferredLocations()
}
}

View file

@ -2,11 +2,11 @@ package spark
import java.io._
import java.net.InetAddress
import java.util.UUID
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import java.util.{Locale, UUID}
/**
* Various utility methods used by Spark.
@ -157,9 +157,12 @@ object Utils {
/**
* Get the local machine's hostname.
*/
def localHostName(): String = {
return InetAddress.getLocalHost().getHostName
}
def localHostName(): String = InetAddress.getLocalHost.getHostName
/**
* Get current host
*/
def getHost = System.getProperty("spark.hostname", localHostName())
/**
* Delete a file or directory and its contents recursively.
@ -174,4 +177,28 @@ object Utils {
throw new IOException("Failed to delete: " + file)
}
}
/**
* Convert a quantity of bytes to a human-readable string, using unit suffixes
* (B, KB, MB and GB) to keep the number of digits small. For example,
* 4,000,000 is returned as "3.8MB".
*/
def memoryBytesToString(size: Long): String = {
val GB = 1L << 30
val MB = 1L << 20
val KB = 1L << 10
val (value, unit) = {
if (size >= 2*GB) {
(size.asInstanceOf[Double] / GB, "GB")
} else if (size >= 2*MB) {
(size.asInstanceOf[Double] / MB, "MB")
} else if (size >= 2*KB) {
(size.asInstanceOf[Double] / KB, "KB")
} else {
(size.asInstanceOf[Double], "B")
}
}
"%.1f%s".formatLocal(Locale.US, value, unit)
}
}

View file

@ -1,14 +0,0 @@
package spark
import com.google.common.collect.MapMaker
/**
* An implementation of Cache that uses weak references.
*/
class WeakReferenceCache extends Cache {
val map = new MapMaker().weakValues().makeMap[Any, Any]()
override def get(key: Any): Any = map.get(key)
override def put(key: Any, value: Any) = map.put(key, value)
}

View file

@ -16,7 +16,7 @@ extends Broadcast[T] with Logging with Serializable {
def value = value_
BitTorrentBroadcast.synchronized {
BitTorrentBroadcast.values.put(uuid, value_)
BitTorrentBroadcast.values.put(uuid, 0, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@ -130,7 +130,7 @@ extends Broadcast[T] with Logging with Serializable {
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
BitTorrentBroadcast.synchronized {
val cachedVal = BitTorrentBroadcast.values.get(uuid)
val cachedVal = BitTorrentBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
@ -152,12 +152,12 @@ extends Broadcast[T] with Logging with Serializable {
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
BitTorrentBroadcast.values.put(uuid, value_)
BitTorrentBroadcast.values.put(uuid, 0, value_)
} else {
// TODO: This part won't work, cause HDFS writing is turned OFF
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
BitTorrentBroadcast.values.put(uuid, value_)
BitTorrentBroadcast.values.put(uuid, 0, value_)
fileIn.close()
}

View file

@ -15,7 +15,7 @@ extends Broadcast[T] with Logging with Serializable {
def value = value_
ChainedBroadcast.synchronized {
ChainedBroadcast.values.put(uuid, value_)
ChainedBroadcast.values.put(uuid, 0, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@ -101,7 +101,7 @@ extends Broadcast[T] with Logging with Serializable {
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
ChainedBroadcast.synchronized {
val cachedVal = ChainedBroadcast.values.get(uuid)
val cachedVal = ChainedBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
@ -121,11 +121,11 @@ extends Broadcast[T] with Logging with Serializable {
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
ChainedBroadcast.values.put(uuid, value_)
ChainedBroadcast.values.put(uuid, 0, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
ChainedBroadcast.values.put(uuid, value_)
ChainedBroadcast.values.put(uuid, 0, value_)
fileIn.close()
}

View file

@ -17,7 +17,7 @@ extends Broadcast[T] with Logging with Serializable {
def value = value_
DfsBroadcast.synchronized {
DfsBroadcast.values.put(uuid, value_)
DfsBroadcast.values.put(uuid, 0, value_)
}
if (!isLocal) {
@ -34,7 +34,7 @@ extends Broadcast[T] with Logging with Serializable {
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
DfsBroadcast.synchronized {
val cachedVal = DfsBroadcast.values.get(uuid)
val cachedVal = DfsBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
@ -43,7 +43,7 @@ extends Broadcast[T] with Logging with Serializable {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
DfsBroadcast.values.put(uuid, value_)
DfsBroadcast.values.put(uuid, 0, value_)
fileIn.close
val time = (System.nanoTime - start) / 1e9

View file

@ -15,7 +15,7 @@ extends Broadcast[T] with Logging with Serializable {
def value = value_
TreeBroadcast.synchronized {
TreeBroadcast.values.put(uuid, value_)
TreeBroadcast.values.put(uuid, 0, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@ -104,7 +104,7 @@ extends Broadcast[T] with Logging with Serializable {
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
TreeBroadcast.synchronized {
val cachedVal = TreeBroadcast.values.get(uuid)
val cachedVal = TreeBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
@ -124,11 +124,11 @@ extends Broadcast[T] with Logging with Serializable {
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
TreeBroadcast.values.put(uuid, value_)
TreeBroadcast.values.put(uuid, 0, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
TreeBroadcast.values.put(uuid, value_)
TreeBroadcast.values.put(uuid, 0, value_)
fileIn.close()
}

View file

@ -0,0 +1,31 @@
package spark
import org.scalatest.FunSuite
class BoundedMemoryCacheTest extends FunSuite {
test("constructor test") {
val cache = new BoundedMemoryCache(40)
expect(40)(cache.getCapacity)
}
test("caching") {
val cache = new BoundedMemoryCache(40) {
// TODO: sorry about this, but there is no better way to skip 'cacheTracker.dropEntry'
override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
}
}
//should be OK
expect(CachePutSuccess(30))(cache.put("1", 0, "Meh"))
//this should fail: there is not enough free space in the cache, and we cannot evict the only
//cached value because it belongs to the same dataset
expect(CachePutFailure())(cache.put("1", 1, "Meh"))
//should be OK, dataset '1' can be evicted from the cache
expect(CachePutSuccess(30))(cache.put("2", 0, "Meh"))
//should fail, the cache should obey its capacity
expect(CachePutFailure())(cache.put("3", 0, "Very_long_and_useless_string"))
}
}

View file

@ -0,0 +1,97 @@
package spark
import org.scalatest.FunSuite
import collection.mutable.HashMap
class CacheTrackerSuite extends FunSuite {
test("CacheTrackerActor slave initialization & cache status") {
System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val tracker = new CacheTrackerActor
tracker.start()
tracker !? SlaveCacheStarted("host001", initialSize)
assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 0L)))
tracker !? StopCacheTracker
}
test("RegisterRDD") {
System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val tracker = new CacheTrackerActor
tracker.start()
tracker !? SlaveCacheStarted("host001", initialSize)
tracker !? RegisterRDD(1, 3)
tracker !? RegisterRDD(2, 1)
assert(getCacheLocations(tracker) == Map(1 -> List(List(), List(), List()), 2 -> List(List())))
tracker !? StopCacheTracker
}
test("AddedToCache") {
System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val tracker = new CacheTrackerActor
tracker.start()
tracker !? SlaveCacheStarted("host001", initialSize)
tracker !? RegisterRDD(1, 2)
tracker !? RegisterRDD(2, 1)
tracker !? AddedToCache(1, 0, "host001", 2L << 15)
tracker !? AddedToCache(1, 1, "host001", 2L << 11)
tracker !? AddedToCache(2, 0, "host001", 3L << 10)
assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
tracker !? StopCacheTracker
}
test("DroppedFromCache") {
System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val tracker = new CacheTrackerActor
tracker.start()
tracker !? SlaveCacheStarted("host001", initialSize)
tracker !? RegisterRDD(1, 2)
tracker !? RegisterRDD(2, 1)
tracker !? AddedToCache(1, 0, "host001", 2L << 15)
tracker !? AddedToCache(1, 1, "host001", 2L << 11)
tracker !? AddedToCache(2, 0, "host001", 3L << 10)
assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
tracker !? DroppedFromCache(1, 1, "host001", 2L << 11)
assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 68608L)))
assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
tracker !? StopCacheTracker
}
/**
* Helper function to get cacheLocations from CacheTracker
*/
def getCacheLocations(tracker: CacheTrackerActor) = tracker !? GetCacheLocations match {
case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map {
case (i, arr) => (i -> arr.toList)
}
}
}

View file

@ -65,5 +65,21 @@ class FailureSuite extends FunSuite {
FailureSuiteState.clear()
}
test("failure because task results are not serializable") {
val sc = new SparkContext("local[1,1]", "test")
val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
val thrown = intercept[spark.SparkException] {
results.collect()
}
assert(thrown.getClass === classOf[spark.SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))
sc.stop()
FailureSuiteState.clear()
}
// TODO: Need to add tests with shuffle fetch failures.
}

View file

@ -0,0 +1,28 @@
package spark
import org.scalatest.FunSuite
class MesosSchedulerSuite extends FunSuite {
test("memoryStringToMb"){
assert(MesosScheduler.memoryStringToMb("1") == 0)
assert(MesosScheduler.memoryStringToMb("1048575") == 0)
assert(MesosScheduler.memoryStringToMb("3145728") == 3)
assert(MesosScheduler.memoryStringToMb("1024k") == 1)
assert(MesosScheduler.memoryStringToMb("5000k") == 4)
assert(MesosScheduler.memoryStringToMb("4024k") == MesosScheduler.memoryStringToMb("4024K"))
assert(MesosScheduler.memoryStringToMb("1024m") == 1024)
assert(MesosScheduler.memoryStringToMb("5000m") == 5000)
assert(MesosScheduler.memoryStringToMb("4024m") == MesosScheduler.memoryStringToMb("4024M"))
assert(MesosScheduler.memoryStringToMb("2g") == 2048)
assert(MesosScheduler.memoryStringToMb("3g") == MesosScheduler.memoryStringToMb("3G"))
assert(MesosScheduler.memoryStringToMb("2t") == 2097152)
assert(MesosScheduler.memoryStringToMb("3t") == MesosScheduler.memoryStringToMb("3T"))
}
}

View file

@ -0,0 +1,37 @@
package spark
import org.scalatest.FunSuite
import SparkContext._
class PipedRDDSuite extends FunSuite {
test("basic pipe") {
val sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("cat"))
val c = piped.collect()
println(c.toSeq)
assert(c.size === 4)
assert(c(0) === "1")
assert(c(1) === "2")
assert(c(2) === "3")
assert(c(3) === "4")
sc.stop()
}
test("pipe with env variable") {
val sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
val c = piped.collect()
assert(c.size === 2)
assert(c(0) === "LALALA")
assert(c(1) === "LALALA")
sc.stop()
}
}

View file

@ -0,0 +1,29 @@
package spark
import org.scalatest.FunSuite
import java.io.{ByteArrayOutputStream, ByteArrayInputStream}
import util.Random
class UtilsSuite extends FunSuite {
test("memoryBytesToString") {
assert(Utils.memoryBytesToString(10) === "10.0B")
assert(Utils.memoryBytesToString(1500) === "1500.0B")
assert(Utils.memoryBytesToString(2000000) === "1953.1KB")
assert(Utils.memoryBytesToString(2097152) === "2.0MB")
assert(Utils.memoryBytesToString(2306867) === "2.2MB")
assert(Utils.memoryBytesToString(5368709120L) === "5.0GB")
}
test("copyStream") {
//input array initialization
val bytes = Array.ofDim[Byte](9000)
Random.nextBytes(bytes)
val os = new ByteArrayOutputStream()
Utils.copyStream(new ByteArrayInputStream(bytes), os)
assert(os.toByteArray.toList.equals(bytes.toList))
}
}

View file

@ -7,7 +7,7 @@ object Main {
def interp = _interp
private[repl] def interp_=(i: SparkILoop) { _interp = i }
def interp_=(i: SparkILoop) { _interp = i }
def main(args: Array[String]) {
_interp = new SparkILoop