Mirror of https://github.com/microsoft/spark.git
General FileServer
A general fileserver for both JARs and regular files.
This commit is contained in:
@ -0,0 +1,31 @@
package spark
import java.io.{File, PrintWriter}
import java.net.URL
import scala.collection.mutable.HashMap
import org.apache.hadoop.fs.FileUtil
class HttpFileServer extends Logging {
var fileDir : File = null
var httpServer : HttpServer = null
var serverUri : String = null
/**
 * Prepares the file server: creates a temporary directory, builds an
 * HttpServer over that directory and records the server's URI.
 */
def initialize(): Unit = {
  val servedDir = Utils.createTempDir()
  fileDir = servedDir
  logInfo("HTTP File server directory is " + servedDir)
  val server = new HttpServer(servedDir)
  httpServer = server
  serverUri = server.uri
}
/**
 * Copies `file` into the served directory under its own name.
 *
 * @param file the local file to publish
 * @return the server URI followed by "/" and the file's name
 */
def addFile(file: File) : String = {
  val name = file.getName
  val servedCopy = new File(fileDir, name)
  Utils.copyFile(file, servedCopy)
  serverUri + "/" + name
}
def stop() {
@ -2,11 +2,12 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
import java.net.URI
import akka.actor.Actor
import akka.actor.Actor._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
@ -76,7 +77,10 @@ class SparkContext(
// Used to store a URL for each static file together with the file's local timestamp
val files = HashMap[String, Long]()
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
@ -90,13 +94,13 @@ class SparkContext(
master match {
case "local" =>
new LocalScheduler(1, 0)
new LocalScheduler(1, 0, this)
case LOCAL_N_REGEX(threads) =>
new LocalScheduler(threads.toInt, 0)
new LocalScheduler(threads.toInt, 0, this)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
new LocalScheduler(threads.toInt, maxFailures.toInt)
new LocalScheduler(threads.toInt, maxFailures.toInt, this)
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
@ -131,7 +135,7 @@ class SparkContext(
private var dagScheduler = new DAGScheduler(taskScheduler)
// Methods for creating RDDs
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
@ -310,7 +314,24 @@ class SparkContext(
// Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
/**
 * Adds a file dependency to all Tasks executed in the future.
 *
 * A path with no scheme or with the "file" scheme is treated as a local file:
 * it is handed to the HTTP file server and registered under the URL the
 * server returns. Any other path is registered unchanged.
 *
 * @param path location of the file to add
 * @return the key under which the file was recorded in `files`
 */
def addFile(path: String) : String = {
  val parsed = new URI(path)
  val scheme = parsed.getScheme
  if (scheme == null || scheme == "file") {
    // A local file: publish it through the file server and record the served URL
    val localFile = new File(parsed.getPath)
    val servedUrl = env.httpFileServer.addFile(localFile)
    files(servedUrl) = System.currentTimeMillis
    logInfo("Added file " + path + " at " + servedUrl + " with timestamp " + files(servedUrl))
    servedUrl
  } else {
    // Any other scheme: record the path as given
    files(path) = System.currentTimeMillis
    path
  }
}
// Stop the SparkContext
def stop() {
@ -19,15 +19,17 @@ class SparkEnv (
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager
val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer
) {
/** No-parameter constructor for unit tests. */
def this() = {
this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null, null)
def stop() {
@ -95,7 +97,11 @@ object SparkEnv {
System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val shuffleFetcher =
val httpFileServer = new HttpFileServer()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
if (System.getProperty("spark.stream.distributed", "false") == "true") {
val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
@ -126,6 +132,7 @@ object SparkEnv {
@ -1,18 +1,19 @@
package spark
import java.io._
import java.net.InetAddress
import java.net.{InetAddress, URL, URI}
import java.util.{Locale, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import java.util.{Locale, UUID}
import scala.io.Source
* Various utility methods used by Spark.
object Utils {
object Utils extends Logging {
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@ -115,6 +116,47 @@ object Utils {
val out = new FileOutputStream(dest)
copyStream(in, out, true)
/**
 * Download a file from a given URL to the local filesystem.
 *
 * @param url       location to read from
 * @param localPath path of the local file to write
 */
def downloadFile(url: URL, localPath: String) {
  val in = url.openStream()
  var out: FileOutputStream = null
  try {
    out = new FileOutputStream(localPath)
  } finally {
    // If the destination could not be opened, nothing downstream would ever
    // close the already-open source stream, so release it here.
    if (out == null) {
      in.close()
    }
  }
  // NOTE(review): the `true` flag presumably asks copyStream to close both streams - confirm
  Utils.copyStream(in, out, true)
}
/**
 * Download a file requested by the executor. Supports fetching the file in a variety of ways,
 * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
 *
 * The file is stored in `targetDir` under the last "/"-separated segment of `url`.
 * If that name ends in .tar.gz, .tgz or .tar, `tar` is then run on it with
 * `targetDir` passed to Utils.execute.
 *
 * @param url       location of the file to fetch
 * @param targetDir directory in which the fetched file is stored
 */
def fetchFile(url: String, targetDir: File): Unit = {
  val filename = url.split("/").last
  val targetFile = new File(targetDir, filename)
  val in: InputStream =
    if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
      // Use the java.net library to fetch it
      logInfo("Fetching " + url + " to " + targetFile)
      new URL(url).openStream()
    } else {
      // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
      val uri = new URI(url)
      val conf = new Configuration()
      val fs = FileSystem.get(uri, conf)
      fs.open(new Path(uri))
    }
  writeStreamToFile(in, targetFile)
  // Decompress the file if it's a .tar or .tar.gz
  val tarFlags =
    if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
      Some("-xzf")
    } else if (filename.endsWith(".tar")) {
      Some("-xf")
    } else {
      None
    }
  tarFlags.foreach { flags =>
    logInfo("Untarring " + filename)
    Utils.execute(Seq("tar", flags, filename), targetDir)
  }
}

/** Copies `in` into `dest`, closing `in` if `dest` cannot be opened for writing. */
private def writeStreamToFile(in: InputStream, dest: File): Unit = {
  var out: FileOutputStream = null
  try {
    out = new FileOutputStream(dest)
  } finally {
    // Without this the source stream would leak when the destination cannot be opened.
    if (out == null) {
      in.close()
    }
  }
  // NOTE(review): the `true` flag presumably asks copyStream to close both streams - confirm
  Utils.copyStream(in, out, true)
}
* Shuffle the elements of a collection into a random order, returning the
@ -65,38 +65,6 @@ class ExecutorRunner(
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
val targetFile = new File(targetDir, filename)
if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
// Use the java.net library to fetch it
logInfo("Fetching " + url + " to " + targetFile)
val in = new URL(url).openStream()
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
} else {
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
val conf = new Configuration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
// Decompress the file if it's a .tar or .tar.gz
if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xzf", filename), targetDir)
} else if (filename.endsWith(".tar")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xf", filename), targetDir)
/** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
case "{{SLAVEID}}" => workerId
@ -4,7 +4,9 @@ import java.io.{File, FileOutputStream}
import java.net.{URL, URLClassLoader}
import java.util.concurrent._
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.FileUtil
import scala.collection.mutable.{ArrayBuffer, HashMap}
import spark.broadcast._
import spark.scheduler._
@ -18,6 +20,8 @@ class Executor extends Logging {
var classLoader: ClassLoader = null
var threadPool: ExecutorService = null
var env: SparkEnv = null
val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
@ -63,6 +67,7 @@ class Executor extends Logging {
val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
logInfo("Its generation is " + task.generation)
val value = task.run(taskId.toInt)
@ -108,7 +113,7 @@ class Executor extends Logging {
for (uri <- uris.split(",").filter(_.size > 0)) {
val url = new URL(uri)
val filename = url.getPath.split("/").last
downloadFile(url, filename)
Utils.downloadFile(url, filename)
localFiles += filename
if (localFiles.size > 0) {
@ -136,10 +141,4 @@ class Executor extends Logging {
return loader
// Download a file from a given URL to the local filesystem
private def downloadFile(url: URL, localPath: String) {
val in = url.openStream()
val out = new FileOutputStream(localPath)
Utils.copyStream(in, out, true)
@ -1,5 +1,10 @@
package spark.scheduler
import scala.collection.mutable.HashMap
import spark.HttpFileServer
import spark.Utils
import java.io.File
* A task to execute on a worker node.
@ -8,4 +13,21 @@ abstract class Task[T](val stageId: Int) extends Serializable {
def preferredLocations: Seq[String] = Nil
var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
// Stores file dependencies for this task.
var fileSet : HashMap[String, Long] = new HashMap[String, Long]()
/**
 * Downloads all file dependencies from the Master file server.
 *
 * A file is fetched when `currentFileSet` has no entry for it or holds an
 * earlier timestamp than the one this task requires. Files whose recorded
 * timestamp already matches are left alone, so they are not fetched again.
 * Each fetched file is stored in the directory named by the "user.dir"
 * system property, and `currentFileSet` is updated with its timestamp.
 *
 * @param currentFileSet files already fetched, keyed by URL, with the
 *                       timestamp of the fetched version; updated in place
 */
def downloadFileDependencies(currentFileSet : HashMap[String, Long]) {
  // Find files that either don't exist or have an earlier timestamp
  val missingFiles = fileSet.filter { case (k, v) =>
    !currentFileSet.contains(k) || currentFileSet(k) < v
  }
  // Fetch each missing file
  missingFiles.foreach { case (k, v) =>
    Utils.fetchFile(k, new File(System.getProperty("user.dir")))
    currentFileSet(k) = v
  }
}
@ -88,6 +88,7 @@ class ClusterScheduler(sc: SparkContext)
def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
tasks.foreach { task => task.fileSet ++= sc.files }
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = new TaskSetManager(this, taskSet)
@ -235,30 +236,24 @@ class ClusterScheduler(sc: SparkContext)
override def defaultParallelism() = backend.defaultParallelism()
// Create a server for all the JARs added by the user to SparkContext.
// We first copy the JARs to a temp directory for easier server setup.
// Copies all the JARs added by the user to the SparkContext
// to the fileserver directory.
private def createJarServer() {
val jarDir = Utils.createTempDir()
logInfo("Temp directory for JARs: " + jarDir)
val fileServerDir = SparkEnv.get.httpFileServer.fileDir
val fileServerUri = SparkEnv.get.httpFileServer.serverUri
val filenames = ArrayBuffer[String]()
// Copy each JAR to a unique filename in the jarDir
for ((path, index) <- sc.jars.zipWithIndex) {
val file = new File(path)
if (file.exists) {
val filename = index + "_" + file.getName
Utils.copyFile(file, new File(jarDir, filename))
Utils.copyFile(file, new File(fileServerDir, filename))
filenames += filename
// Create the server
jarServer = new HttpServer(jarDir)
// Build up the jar URI list
val serverUri = jarServer.uri
jarUris = filenames.map(f => serverUri + "/" + f).mkString(",")
jarUris = filenames.map(f => fileServerUri + "/" + f).mkString(",")
System.setProperty("spark.jar.uris", jarUris)
logInfo("JAR server started at " + serverUri)
logInfo("JARs available at " + jarUris)
// Check for speculatable tasks in all our active jobs.
@ -2,6 +2,7 @@ package spark.scheduler.local
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.HashMap
import spark._
import spark.scheduler._
@ -11,12 +12,13 @@ import spark.scheduler._
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging {
class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends TaskScheduler with Logging {
var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
// TODO: Need to take into account stage priority in scheduling
override def start() {}
@ -30,6 +32,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
task.fileSet ++= sc.files
val myAttemptId = attemptId.getAndIncrement()
threadPool.submit(new Runnable {
def run() {
@ -42,6 +45,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
logInfo("Running task " + idInJob)
// Set the Spark execution environment for the worker thread
try {
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
@ -81,6 +85,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
override def stop() {
@ -0,0 +1,43 @@
package spark
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import java.io.{File, PrintWriter}
/**
 * Checks that a file registered with SparkContext.addFile can be read by tasks.
 */
class FileServerSuite extends FunSuite with BeforeAndAfter {

  var sc: SparkContext = _

  // Location of the sample file. The "java.io.tmpdir" property is not guaranteed
  // to end with a path separator, so the path is built with File(parent, child)
  // rather than by string concatenation.
  private def sampleFile: File =
    new File(System.getProperty("java.io.tmpdir"), "FileServerSuite.txt")

  before {
    // Create a sample text file
    val pw = new PrintWriter(sampleFile)
    try {
      pw.println("100")
    } finally {
      pw.close()
    }
  }

  after {
    if (sc != null) {
      sc.stop()
      sc = null
    }
    // Clean up downloaded file
    val tmpFile = new File("FileServerSuite.txt")
    if (tmpFile.exists) {
      tmpFile.delete()
    }
  }

  test("Distributing files") {
    sc = new SparkContext("local[4]", "test")
    sc.addFile(sampleFile.getPath)
    val testRdd = sc.parallelize(List(1,2,3,4))
    val result = testRdd.map { x =>
      // Only string literals are used in this closure so that it does not
      // capture the enclosing suite instance.
      val in = new java.io.BufferedReader(new java.io.FileReader("FileServerSuite.txt"))
      val fileVal = try {
        in.readLine().toInt
      } finally {
        in.close()
      }
      fileVal
    }.reduce(_ + _)
    // Four elements, each mapped to the 100 stored in the sample file
    assert(result == 400)
  }
}
Reference in new issue