From 389fb4cc54f0c433f3afd56913229aa9fb4bf2fd Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Fri, 31 Aug 2012 17:47:43 -0700
Subject: [PATCH] End runJob() with a SparkException when a task fails too many
 times in one of the cluster schedulers.

---
 .../scala/spark/scheduler/DAGScheduler.scala  | 70 +++++++++++++++----
 .../spark/scheduler/DAGSchedulerEvent.scala   |  2 +
 .../scheduler/TaskSchedulerListener.scala     |  3 +
 .../main/scala/spark/scheduler/TaskSet.scala  |  2 +
 .../scheduler/cluster/TaskSetManager.scala    |  1 +
 5 files changed, 66 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index f7472971b5..2e2dc295b6 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -38,6 +38,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     eventQueue.put(HostLost(host))
   }
 
+  // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
+  override def taskSetFailed(taskSet: TaskSet, reason: String) {
+    eventQueue.put(TaskSetFailed(taskSet, reason))
+  }
+
   // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
   // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
   // as more failure events come in
@@ -258,6 +263,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
         case completion: CompletionEvent =>
           handleTaskCompletion(completion)
 
+        case TaskSetFailed(taskSet, reason) =>
+          abortStage(idToStage(taskSet.stageId), reason)
+
         case StopDAGScheduler =>
           // Cancel any active jobs
           for (job <- activeJobs) {
@@ -475,18 +483,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
             handleHostLost(bmAddress.ip)
         }
 
-      case _ =>
-        // Non-fetch failure -- probably a bug in the job, so bail out
-        // TODO: Cancel all tasks that are still running
-        resultStageToJob.get(stage) match {
-          case Some(job) =>
-            val error = new SparkException("Task failed: " + task + ", reason: " + event.reason)
-            job.listener.jobFailed(error)
-            activeJobs -= job
-            resultStageToJob -= stage
-          case None =>
-            logInfo("Ignoring result from " + task + " because its job has finished")
-        }
+      case other =>
+        // Non-fetch failure -- probably a bug in user code; abort all jobs depending on this stage
+        abortStage(idToStage(task.stageId), task + " failed: " + other)
     }
   }
 
@@ -509,6 +508,53 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
       updateCacheLocs()
     }
   }
+
+  /**
+   * Aborts all jobs depending on a particular Stage. This is called in response to a task set
+   * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
+   */
+  def abortStage(failedStage: Stage, reason: String) {
+    val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
+    for (resultStage <- dependentStages) {
+      val job = resultStageToJob(resultStage)
+      job.listener.jobFailed(new SparkException("Job failed: " + reason))
+      activeJobs -= job
+      resultStageToJob -= resultStage
+    }
+    if (dependentStages.isEmpty) {
+      logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
+    }
+  }
+
+  /**
+   * Return true if one of stage's ancestors is target.
+   */
+  def stageDependsOn(stage: Stage, target: Stage): Boolean = {
+    if (stage == target) {
+      return true
+    }
+    val visitedRdds = new HashSet[RDD[_]]
+    val visitedStages = new HashSet[Stage]
+    def visit(rdd: RDD[_]) {
+      if (!visitedRdds(rdd)) {
+        visitedRdds += rdd
+        for (dep <- rdd.dependencies) {
+          dep match {
+            case shufDep: ShuffleDependency[_,_,_] =>
+              val mapStage = getShuffleMapStage(shufDep, stage.priority)
+              if (!mapStage.isAvailable) {
+                visitedStages += mapStage
+                visit(mapStage.rdd)
+              }  // Otherwise there's no need to follow the dependency back
+            case narrowDep: NarrowDependency[_] =>
+              visit(narrowDep.rdd)
+          }
+        }
+      }
+    }
+    visit(stage.rdd)
+    visitedRdds.contains(target.rdd)
+  }
 
   def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
     // If the partition is cached, return the cache locations
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index 0fc73059c3..1322aae3a3 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -29,4 +29,6 @@ case class CompletionEvent(
 
 case class HostLost(host: String) extends DAGSchedulerEvent
 
+case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+
 case object StopDAGScheduler extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
index a647eec9e4..f838272fb4 100644
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -13,4 +13,7 @@ trait TaskSchedulerListener {
 
   // A node was lost from the cluster.
   def hostLost(host: String): Unit
+
+  // The TaskScheduler wants to abort an entire task set.
+  def taskSetFailed(taskSet: TaskSet, reason: String): Unit
 }
diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala
index 6f29dd2e9d..3f4a464902 100644
--- a/core/src/main/scala/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSet.scala
@@ -6,4 +6,6 @@ package spark.scheduler
  */
 class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) {
   val id: String = stageId + "." + attempt
+
+  override def toString: String = "TaskSet " + id
 }
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 5a7df6040c..5f98a396b4 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -342,6 +342,7 @@ class TaskSetManager(
       failed = true
       causeOfFailure = message
       // TODO: Kill running tasks if we were not terminated due to a Mesos error
+      sched.listener.taskSetFailed(taskSet, message)
      sched.taskSetFinished(this)
    }
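
Note (not part of the patch): a minimal sketch of how the new failure path is
expected to surface to a caller of runJob(). When the cluster scheduler gives up
on a task set after repeated failures, TaskSetManager calls
sched.listener.taskSetFailed(taskSet, message); the DAGScheduler turns that into
a TaskSetFailed event, abortStage() fails every dependent job with a
SparkException ("Job failed: <reason>"), and the blocked runJob() call throws.
The master URL, app name, and deliberately failing closure below are
placeholders, and per the commit message this behavior applies only to the
cluster schedulers, not local mode.

    import spark.{SparkContext, SparkException}

    object TaskSetFailureExample {
      def main(args: Array[String]) {
        // Placeholder cluster master URL -- the new behavior is specific to
        // the cluster schedulers.
        val sc = new SparkContext("spark://master:7077", "taskSetFailureExample")
        try {
          // Every attempt of every task throws, so the TaskSetManager
          // eventually gives up on the TaskSet and reports it via
          // taskSetFailed().
          sc.parallelize(1 to 4, 2).foreach { i =>
            throw new RuntimeException("deliberate failure for element " + i)
          }
        } catch {
          case e: SparkException =>
            // abortStage() wraps the scheduler's reason as "Job failed: <reason>".
            println("runJob() ended with: " + e.getMessage)
        } finally {
          sc.stop()
        }
      }
    }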