Added a unit test for local-cluster mode and simplified some of the code involved in that

Matei Zaharia 2012-09-07 17:08:36 -07:00 committed by Denny
Parent f2ac55840c
Commit a13780670d
8 changed files with 120 additions and 48 deletions
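
The feature being exercised here is the local-cluster[numSlaves, coresPerSlave, memoryPerSlaveMB] master URL, which boots a real standalone master and workers inside a single JVM. As a rough driver-side sketch (not part of this commit; it assumes only the two-argument SparkContext constructor that the new test suite below also uses):

import spark.SparkContext

object LocalClusterDemo {
  def main(args: Array[String]): Unit = {
    // 2 workers, 1 core and 512 MB each -- the same sizing DistributedSuite uses.
    val sc = new SparkContext("local-cluster[2,1,512]", "demo")
    val sum = sc.parallelize(1 to 100, 4).reduce(_ + _)
    println("sum = " + sum)   // 5050
    sc.stop()
  }
}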

View file

@@ -67,7 +67,7 @@ class SparkContext(
     System.setProperty("spark.master.port", "0")
   }

-  private val isLocal = (master == "local" || master.startsWith("local["))
+  private val isLocal = (master == "local" || master.startsWith("local\\["))

   // Create the Spark execution environment (cache, map output tracker, etc)
   val env = SparkEnv.createFromSystemProperties(
@@ -84,7 +84,7 @@ class SparkContext(
   // Regular expression for local[N, maxRetries], used in tests with failing tasks
   val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r
   // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
-  val SPARK_LOCALCLUSTER_REGEX = """local-cluster\[([0-9]+)\,([0-9]+),([0-9]+)]""".r
+  val LOCAL_CLUSTER_REGEX = """local-cluster\[([0-9]+),([0-9]+),([0-9]+)]""".r

   // Regular expression for connecting to Spark deploy clusters
   val SPARK_REGEX = """(spark://.*)""".r
@@ -104,13 +104,13 @@ class SparkContext(
         scheduler.initialize(backend)
         scheduler

-      case SPARK_LOCALCLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerlave) =>
+      case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerlave) =>
         val scheduler = new ClusterScheduler(this)
         val localCluster = new LocalSparkCluster(numSlaves.toInt, coresPerSlave.toInt, memoryPerlave.toInt)
         val sparkUrl = localCluster.start()
         val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
         scheduler.initialize(backend)
-        backend.shutdownHook = (backend: SparkDeploySchedulerBackend) => {
+        backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
          localCluster.stop()
        }
        scheduler
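
The renamed LOCAL_CLUSTER_REGEX is an ordinary Scala Regex used as an extractor, so the three capture groups come out as strings in the match above. A minimal, Spark-independent sketch of the same mechanism:

object LocalClusterRegexDemo {
  // Same pattern as LOCAL_CLUSTER_REGEX above.
  val LOCAL_CLUSTER_REGEX = """local-cluster\[([0-9]+),([0-9]+),([0-9]+)]""".r

  def main(args: Array[String]): Unit = {
    "local-cluster[2,1,512]" match {
      case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
        // Prints: 2 slaves, 1 core(s), 512 MB per slave
        println(numSlaves + " slaves, " + coresPerSlave + " core(s), " +
          memoryPerSlave + " MB per slave")
      case _ =>
        println("not a local-cluster master URL")
    }
  }
}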

View file

@@ -76,9 +76,12 @@ private object HttpBroadcast extends Logging {
   }

   def stop() {
-    if (server != null) {
-      server.stop()
-      server = null
+    synchronized {
+      if (server != null) {
+        server.stop()
+        server = null
+        initialized = false
+      }
     }
   }
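
The rewritten stop() takes the lock before the null check, so concurrent or repeated calls tear the server down at most once and also reset the initialized flag. A small standalone sketch of that idiom (HttpServerStub is a hypothetical stand-in for the server object HttpBroadcast manages):

// Hypothetical stand-in for the HTTP server object.
class HttpServerStub {
  def stop(): Unit = println("server stopped")
}

object StoppableService {
  private var server: HttpServerStub = new HttpServerStub
  private var initialized = true

  def stop(): Unit = {
    synchronized {
      // Only the first caller to acquire the lock sees a non-null server,
      // so double stops and races are harmless.
      if (server != null) {
        server.stop()
        server = null
        initialized = false
      }
    }
  }
}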

View file

@@ -9,10 +9,8 @@ import spark.{Logging, Utils}
 import scala.collection.mutable.ArrayBuffer

-class LocalSparkCluster(numSlaves : Int, coresPerSlave : Int,
-  memoryPerSlave : Int) extends Logging {
-  val threadPool = Utils.newDaemonFixedThreadPool(numSlaves + 1)
+class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging {
   val localIpAddress = Utils.localIpAddress
   var masterActor : ActorRef = _
@@ -24,35 +22,25 @@ class LocalSparkCluster(numSlaves : Int, coresPerSlave : Int,
   val slaveActors = ArrayBuffer[ActorRef]()

   def start() : String = {
     logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.")

     /* Start the Master */
-    val (masterActorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
+    val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
+    masterActorSystem = actorSystem
     masterUrl = "spark://" + localIpAddress + ":" + masterPort
-    threadPool.execute(new Runnable {
-      def run() {
-        val actor = masterActorSystem.actorOf(
-          Props(new Master(localIpAddress, masterPort, 8080)), name = "Master")
-        masterActor = actor
-        masterActorSystem.awaitTermination()
-      }
-    })
+    val actor = masterActorSystem.actorOf(
+      Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
+    masterActor = actor

     /* Start the Slaves */
-    (1 to numSlaves).foreach { slaveNum =>
+    for (slaveNum <- 1 to numSlaves) {
       val (actorSystem, boundPort) =
         AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0)
       slaveActorSystems += actorSystem
-      threadPool.execute(new Runnable {
-        def run() {
-          val actor = actorSystem.actorOf(
-            Props(new Worker(localIpAddress, boundPort, 8080 + slaveNum, coresPerSlave, memoryPerSlave, masterUrl)),
-            name = "Worker")
-          slaveActors += actor
-          actorSystem.awaitTermination()
-        }
-      })
+      val actor = actorSystem.actorOf(
+        Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
+        name = "Worker")
+      slaveActors += actor
     }

     return masterUrl
@@ -60,9 +48,10 @@ class LocalSparkCluster(numSlaves : Int, coresPerSlave : Int,
   def stop() {
     logInfo("Shutting down local Spark cluster.")
-    masterActorSystem.shutdown()
+    // Stop the slaves before the master so they don't get upset that it disconnected
     slaveActorSystems.foreach(_.shutdown())
+    slaveActorSystems.foreach(_.awaitTermination())
+    masterActorSystem.shutdown()
+    masterActorSystem.awaitTermination()
   }
 }
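
The simplification works because actorOf only schedules the actor and returns; the actor systems stay alive on their own dispatcher threads, so the old daemon thread pool and the blocking awaitTermination calls inside it were unnecessary. A rough illustration with plain Akka classic of that era (placeholder actors, not Spark's Master and Worker; shutdown()/awaitTermination() assume a pre-2.4 Akka where those methods still exist):

import akka.actor.{Actor, ActorSystem, Props}

class PlaceholderMaster extends Actor {
  def receive = { case msg => println("master got " + msg) }
}

class PlaceholderWorker(id: Int) extends Actor {
  def receive = { case msg => println("worker " + id + " got " + msg) }
}

object MiniClusterSketch {
  def main(args: Array[String]): Unit = {
    // actorOf does not block; each system keeps running on its own threads.
    val masterSystem = ActorSystem("sparkMaster")
    masterSystem.actorOf(Props(new PlaceholderMaster), name = "Master")

    val workerSystems = (1 to 2).map { i =>
      val system = ActorSystem("sparkWorker" + i)
      system.actorOf(Props(new PlaceholderWorker(i)), name = "Worker")
      system
    }

    // Mirror stop(): shut the workers down before the master so they
    // never see the master vanish, then wait for each system to exit.
    workerSystems.foreach(_.shutdown())
    workerSystems.foreach(_.awaitTermination())
    masterSystem.shutdown()
    masterSystem.awaitTermination()
  }
}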

View file

@@ -29,6 +29,7 @@ class ExecutorRunner(
   val fullId = jobId + "/" + execId
   var workerThread: Thread = null
   var process: Process = null
+  var shutdownHook: Thread = null

   def start() {
     workerThread = new Thread("ExecutorRunner for " + fullId) {
@@ -37,17 +38,16 @@ class ExecutorRunner(
     workerThread.start()

     // Shutdown hook that kills actors on shutdown.
-    Runtime.getRuntime.addShutdownHook(
-      new Thread() {
-        override def run() {
-          if(process != null) {
-            logInfo("Shutdown Hook killing process.")
-            process.destroy()
-            process.waitFor()
-          }
-        }
-      })
+    shutdownHook = new Thread() {
+      override def run() {
+        if (process != null) {
+          logInfo("Shutdown hook killing child process.")
+          process.destroy()
+          process.waitFor()
+        }
+      }
+    }
+    Runtime.getRuntime.addShutdownHook(shutdownHook)
   }

   /** Stop this executor runner, including killing the process it launched */
@@ -58,8 +58,10 @@ class ExecutorRunner(
     if (process != null) {
       logInfo("Killing process!")
       process.destroy()
+      process.waitFor()
     }
     worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None)
+    Runtime.getRuntime.removeShutdownHook(shutdownHook)
   }
 }
@@ -114,7 +116,12 @@ class ExecutorRunner(
     val out = new FileOutputStream(file)
     new Thread("redirect output to " + file) {
       override def run() {
-        Utils.copyStream(in, out, true)
+        try {
+          Utils.copyStream(in, out, true)
+        } catch {
+          case e: IOException =>
+            logInfo("Redirection to " + file + " closed: " + e.getMessage)
+        }
       }
     }.start()
   }
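
Keeping a handle on the shutdown hook lets kill() deregister it once the child process has been stopped deliberately, and the extra waitFor() makes sure the process is really gone before the KILLED state is reported. A self-contained sketch of the same pattern (the sleep command is an arbitrary placeholder child; ExecutorRunner launches a java command instead):

object ChildProcessGuard {
  @volatile private var process: Process = null
  private var shutdownHook: Thread = null

  def start(): Unit = {
    process = new ProcessBuilder("sleep", "60").start()

    // Kill the child if the JVM exits while it is still running.
    shutdownHook = new Thread() {
      override def run(): Unit = {
        if (process != null) {
          println("Shutdown hook killing child process.")
          process.destroy()
          process.waitFor()
        }
      }
    }
    Runtime.getRuntime.addShutdownHook(shutdownHook)
  }

  def kill(): Unit = {
    if (process != null) {
      process.destroy()
      process.waitFor()
      process = null
    }
    // The child is gone, so the hook is no longer needed.
    Runtime.getRuntime.removeShutdownHook(shutdownHook)
  }
}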

View file

@@ -153,6 +153,10 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
   def generateWorkerId(): String = {
     "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
   }
+
+  override def postStop() {
+    executors.values.foreach(_.kill())
+  }
 }

 object Worker {
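
postStop runs when the Worker actor is stopped, including when its ActorSystem shuts down, so LocalSparkCluster.stop() now reliably kills any executors the worker launched. A tiny generic sketch of the hook (ChildHandle is a hypothetical stand-in for an ExecutorRunner):

import akka.actor.Actor
import scala.collection.mutable

// Hypothetical resource handle, standing in for an ExecutorRunner.
class ChildHandle(val id: String) {
  def kill(): Unit = println("killed " + id)
}

class CleanupActor extends Actor {
  private val children = mutable.Map[String, ChildHandle]()

  def receive = {
    case id: String => children(id) = new ChildHandle(id)
  }

  // Called after the actor stops (e.g. on system shutdown): release
  // anything the actor was responsible for.
  override def postStop(): Unit = {
    children.values.foreach(_.kill())
  }
}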

View file

@@ -16,7 +16,7 @@ class SparkDeploySchedulerBackend(
   var client: Client = null
   var stopping = false
-  var shutdownHook : (SparkDeploySchedulerBackend) => Unit = _
+  var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _

   val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
@@ -62,8 +62,8 @@ class SparkDeploySchedulerBackend(
     stopping = true;
     super.stop()
     client.stop()
-    if (shutdownHook != null) {
-      shutdownHook(this)
+    if (shutdownCallback != null) {
+      shutdownCallback(this)
     }
   }
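
The rename from shutdownHook to shutdownCallback also reflects what the field is: a plain function-valued var that SparkContext fills in for local-cluster mode and that stop() invokes when present, rather than a JVM shutdown hook. A minimal sketch of that pattern outside Spark:

class StoppableBackend {
  // Optional extra teardown supplied by whoever created the backend.
  var shutdownCallback: (StoppableBackend) => Unit = _

  def stop(): Unit = {
    // ... normal teardown would happen here ...
    if (shutdownCallback != null) {
      shutdownCallback(this)
    }
  }
}

object CallbackDemo {
  def main(args: Array[String]): Unit = {
    val backend = new StoppableBackend
    backend.shutdownCallback = (b: StoppableBackend) => println("extra cleanup")
    backend.stop()
  }
}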

View file

@@ -0,0 +1,68 @@
package spark

import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._

import com.google.common.io.Files

import scala.collection.mutable.ArrayBuffer

import SparkContext._

class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
  val clusterUrl = "local-cluster[2,1,512]"

  var sc: SparkContext = _

  after {
    if (sc != null) {
      sc.stop()
      sc = null
    }
  }

  test("simple groupByKey") {
    sc = new SparkContext(clusterUrl, "test")
    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 5)
    val groups = pairs.groupByKey(5).collect()
    assert(groups.size === 2)
    val valuesFor1 = groups.find(_._1 == 1).get._2
    assert(valuesFor1.toList.sorted === List(1, 2, 3))
    val valuesFor2 = groups.find(_._1 == 2).get._2
    assert(valuesFor2.toList.sorted === List(1))
  }

  test("accumulators") {
    sc = new SparkContext(clusterUrl, "test")
    val accum = sc.accumulator(0)
    sc.parallelize(1 to 10, 10).foreach(x => accum += x)
    assert(accum.value === 55)
  }

  test("broadcast variables") {
    sc = new SparkContext(clusterUrl, "test")
    val array = new Array[Int](100)
    val bv = sc.broadcast(array)
    array(2) = 3     // Change the array -- this should not be seen on workers
    val rdd = sc.parallelize(1 to 10, 10)
    val sum = rdd.map(x => bv.value.sum).reduce(_ + _)
    assert(sum === 0)
  }

  test("repeatedly failing task") {
    sc = new SparkContext(clusterUrl, "test")
    val accum = sc.accumulator(0)
    val thrown = intercept[SparkException] {
      sc.parallelize(1 to 10, 10).foreach(x => println(x / 0))
    }
    assert(thrown.getClass === classOf[SparkException])
    assert(thrown.getMessage.contains("more than 4 times"))
  }
}
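
The failure test leans on ScalaTest's intercept, which runs a block, fails the test unless it throws the requested exception type, and returns the exception for further assertions. A tiny illustration outside Spark:

import org.scalatest.FunSuite

class InterceptDemoSuite extends FunSuite {
  test("intercept returns the thrown exception") {
    val zero = 0
    val thrown = intercept[ArithmeticException] {
      1 / zero
    }
    assert(thrown.getMessage.contains("/ by zero"))
  }
}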

View file: run

@@ -52,6 +52,7 @@ CLASSPATH="$SPARK_CLASSPATH"
 CLASSPATH+=":$MESOS_CLASSPATH"
 CLASSPATH+=":$FWDIR/conf"
 CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
 CLASSPATH+=":$CORE_DIR/src/main/resources"
 CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes"
 CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"