diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index a76253ea14..9e1bde3fbe 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -67,8 +67,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes") val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) updateDependencies(taskFiles, taskJars) // Download any files added with addFile + val deserStart = System.currentTimeMillis() val deserializedTask = ser.deserialize[Task[_]]( taskBytes, Thread.currentThread.getContextClassLoader) + val deserTime = System.currentTimeMillis() - deserStart // Run it val result: Any = deserializedTask.run(attemptId) @@ -77,15 +79,19 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon // executor does. This is useful to catch serialization errors early // on in development (so when users move their local Spark programs // to the cluster, they don't get surprised by serialization errors). - val resultToReturn = ser.deserialize[Any](ser.serialize(result)) + val serResult = ser.serialize(result) + deserializedTask.metrics.get.resultSize = serResult.limit() + val resultToReturn = ser.deserialize[Any](serResult) val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( ser.serialize(Accumulators.values)) logInfo("Finished " + task) info.markSuccessful() + deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough + deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt // If the threadpool has not already been shutdown, notify DAGScheduler if (!Thread.currentThread().isInterrupted) - listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, null) + listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null)) } catch { case t: Throwable => { logError("Exception in task " + idInJob, t) diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala new file mode 100644 index 0000000000..dd9f2d7e91 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala @@ -0,0 +1,80 @@ +package spark.scheduler + +import org.scalatest.FunSuite +import spark.{SparkContext, LocalSparkContext} +import scala.collection.mutable +import org.scalatest.matchers.ShouldMatchers +import spark.SparkContext._ + +/** + * + */ + +class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + + test("local metrics") { + sc = new SparkContext("local[4]", "test") + val listener = new SaveStageInfo + sc.addSparkListener(listener) + sc.addSparkListener(new StatsReportListener) + + val d = sc.parallelize(1 to 1e4.toInt, 64) + d.count + listener.stageInfos.size should be (1) + + val d2 = d.map{i => i -> i * 2}.setName("shuffle input 1") + + val d3 = d.map{i => i -> (0 to (i % 5))}.setName("shuffle input 2") + + val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => k -> (v1.size, v2.size)} + d4.setName("A Cogroup") + + d4.collectAsMap + + listener.stageInfos.size should be (4) + listener.stageInfos.foreach {stageInfo => + //small test, so some tasks might take less than 1 millisecond, but average should be greater than 1 ms + checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration") + checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime") + checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime") + if (stageInfo.stage.rdd.name == d4.name) { + checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime") + } + + stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) => + taskMetrics.resultSize should be > (0l) + if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) { + taskMetrics.shuffleWriteMetrics should be ('defined) + taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l) + } + if (stageInfo.stage.rdd.name == d4.name) { + taskMetrics.shuffleReadMetrics should be ('defined) + val sm = taskMetrics.shuffleReadMetrics.get + sm.totalBlocksFetched should be > (0) + sm.shuffleReadMillis should be > (0l) + sm.localBlocksFetched should be > (0) + sm.remoteBlocksFetched should be (0) + sm.remoteBytesRead should be (0l) + sm.remoteFetchTime should be (0l) + } + } + } + } + + def checkNonZeroAvg(m: Traversable[Long], msg: String) { + assert(m.sum / m.size.toDouble > 0.0, msg) + } + + def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = { + val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name} + !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty + } + + class SaveStageInfo extends SparkListener { + val stageInfos = mutable.Buffer[StageInfo]() + def onStageCompleted(stage: StageCompleted) { + stageInfos += stage.stageInfo + } + } + +}