diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 3431ad2258..45a14c8290 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -48,8 +48,9 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { } } } catch { + // TODO: this is really ugly -- let's find a better way of throwing a FetchFailedException case be: BlockException => { - val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r + val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]]*)".r be.blockId match { case regex(sId, mId, rId) => { val address = addresses(mId.toInt) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 0c97cd44a1..de23eb6f48 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -116,7 +116,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = { val locs = bmAddresses.get(shuffleId) if (locs == null) { - logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them") + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for them to be done @@ -158,6 +158,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def incrementGeneration() { generationLock.synchronized { generation += 1 + logDebug("Increasing generation to " + generation) } } diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 9e335c25f7..dba209ac27 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -63,6 +63,7 @@ class Executor extends Logging { Thread.currentThread.setContextClassLoader(classLoader) Accumulators.clear() val task = ser.deserialize[Task[Any]](serializedTask, classLoader) + logInfo("Its generation is " + task.generation) env.mapOutputTracker.updateGeneration(task.generation) val value = task.run(taskId.toInt) val accumUpdates = Accumulators.values diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 451faee66e..da8aff9dd5 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -111,7 +111,7 @@ extends Connection(SocketChannel.open, selector_) { messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") + logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") } } @@ -136,7 +136,7 @@ extends Connection(SocketChannel.open, selector_) { return chunk } /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ - logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) + logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) } } None diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 1a22d06cc8..0e764fff81 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -14,7 +14,8 @@ import scala.collection.mutable.SynchronizedQueue import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer -import akka.dispatch.{Promise, ExecutionContext, Future} +import akka.dispatch.{Await, Promise, ExecutionContext, Future} +import akka.util.Duration case class ConnectionManagerId(host: String, port: Int) { def toSocketAddress() = new InetSocketAddress(host, port) @@ -247,7 +248,7 @@ class ConnectionManager(port: Int) extends Logging { } private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { - logInfo("Handling [" + message + "] from [" + connectionManagerId + "]") + logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") message match { case bufferMessage: BufferMessage => { if (bufferMessage.hasAckId) { @@ -305,7 +306,7 @@ class ConnectionManager(port: Int) extends Logging { } val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection()) message.senderAddress = id.toSocketAddress() - logInfo("Sending [" + message + "] to [" + connectionManagerId + "]") + logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") /*connection.send(message)*/ sendMessageRequests.synchronized { sendMessageRequests += ((message, connection)) @@ -325,7 +326,7 @@ class ConnectionManager(port: Int) extends Logging { } def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = { - sendMessageReliably(connectionManagerId, message)() + Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) } def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index e0e050d7c9..618d7b9794 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -85,6 +85,7 @@ class ShuffleMapTask( out.writeInt(bytes.length) out.write(bytes) out.writeInt(partition) + out.writeLong(generation) out.writeObject(split) } @@ -97,6 +98,7 @@ class ShuffleMapTask( rdd = rdd_ dep = dep_ partition = in.readInt() + generation = in.readLong() split = in.readObject().asInstanceOf[Split] } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 013671c1c8..83e7c6e036 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -2,13 +2,14 @@ package spark.scheduler.cluster import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import akka.actor.{Props, Actor, ActorRef, ActorSystem} +import akka.actor._ import akka.util.duration._ import akka.pattern.ask import spark.{SparkException, Logging, TaskState} import akka.dispatch.Await import java.util.concurrent.atomic.AtomicInteger +import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} /** * A standalone scheduler backend, which waits for standalone executors to connect to it through @@ -23,8 +24,16 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { val slaveActor = new HashMap[String, ActorRef] + val slaveAddress = new HashMap[String, Address] val slaveHost = new HashMap[String, String] val freeCores = new HashMap[String, Int] + val actorToSlaveId = new HashMap[ActorRef, String] + val addressToSlaveId = new HashMap[Address, String] + + override def preStart() { + // Listen for remote client disconnection events, since they don't go through Akka's watch() + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + } def receive = { case RegisterSlave(slaveId, host, cores) => @@ -33,9 +42,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } else { logInfo("Registered slave: " + sender + " with ID " + slaveId) sender ! RegisteredSlave(sparkProperties) + context.watch(sender) slaveActor(slaveId) = sender slaveHost(slaveId) = host freeCores(slaveId) = cores + slaveAddress(slaveId) = sender.path.address + actorToSlaveId(sender) = slaveId + addressToSlaveId(sender.path.address) = slaveId totalCoreCount.addAndGet(cores) makeOffers() } @@ -54,7 +67,14 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor sender ! true context.stop(self) - // TODO: Deal with nodes disconnecting too! (Including decreasing totalCoreCount) + case Terminated(actor) => + actorToSlaveId.get(actor).foreach(removeSlave) + + case RemoteClientDisconnected(transport, address) => + addressToSlaveId.get(address).foreach(removeSlave) + + case RemoteClientShutdown(transport, address) => + addressToSlaveId.get(address).foreach(removeSlave) } // Make fake resource offers on all slaves @@ -76,6 +96,20 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor slaveActor(task.slaveId) ! LaunchTask(task) } } + + // Remove a disconnected slave from the cluster + def removeSlave(slaveId: String) { + logInfo("Slave " + slaveId + " disconnected, so removing it") + val numCores = freeCores(slaveId) + actorToSlaveId -= slaveActor(slaveId) + addressToSlaveId -= slaveAddress(slaveId) + slaveActor -= slaveId + slaveHost -= slaveId + freeCores -= slaveId + slaveHost -= slaveId + totalCoreCount.addAndGet(-numCores) + scheduler.slaveLost(slaveId) + } } var masterActor: ActorRef = null diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index 0fc1d8ed30..65e59841a9 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -20,6 +20,8 @@ class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: def successful: Boolean = finished && !failed + def running: Boolean = !finished + def duration: Long = { if (!finished) { throw new UnsupportedOperationException("duration() called on unfinished tasks") diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index be24316e80..5a7df6040c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -88,6 +88,7 @@ class TaskSetManager( // Figure out the current map output tracker generation and set it on all tasks val generation = sched.mapOutputTracker.getGeneration + logDebug("Generation for " + taskSet.id + ": " + generation) for (t <- tasks) { t.generation = generation } @@ -264,6 +265,11 @@ class TaskSetManager( def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } val index = info.index info.markFailed() if (!finished(index)) { @@ -340,7 +346,7 @@ class TaskSetManager( } def hostLost(hostname: String) { - logInfo("Re-queueing tasks for " + hostname) + logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id) // If some task has preferred locations only on hostname, put it in the no-prefs list // to avoid the wait from delay scheduling for (index <- getPendingTasksForHost(hostname)) { @@ -349,7 +355,7 @@ class TaskSetManager( pendingTasksWithNoPrefs += index } } - // Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage + // Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage if (tasks(0).isInstanceOf[ShuffleMapTask]) { for ((tid, info) <- taskInfos if info.host == hostname) { val index = taskInfos(tid).index @@ -364,6 +370,10 @@ class TaskSetManager( } } } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.host == hostname) { + taskLost(tid, TaskState.KILLED, null) + } } /** diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index ff9914ae25..45f99717bc 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -364,6 +364,12 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val startTimeMs = System.currentTimeMillis var bytes: ByteBuffer = null + + // If we need to replicate the data, we'll want access to the values, but because our + // put will read the whole iterator, there will be no values left. For the case where + // the put serializes data, we'll remember the bytes, above; but for the case where + // it doesn't, such as MEMORY_ONLY_DESER, let's rely on the put returning an Iterator. + var valuesAfterPut: Iterator[Any] = null locker.getLock(blockId).synchronized { logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) @@ -391,7 +397,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // If only save to memory memoryStore.putValues(blockId, values, level) match { case Right(newBytes) => bytes = newBytes - case _ => + case Left(newIterator) => valuesAfterPut = newIterator } } else { // If only save to disk @@ -408,8 +414,13 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Replicate block if required if (level.replication > 1) { + // Serialize the block if not already done if (bytes == null) { - bytes = dataSerialize(values) // serialize the block if not already done + if (valuesAfterPut == null) { + throw new SparkException( + "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") + } + bytes = dataSerialize(valuesAfterPut) } replicate(blockId, bytes, level) }