Mirror of https://github.com/microsoft/spark.git
Merge branch 'dev'
Conflicts:
    src/scala/spark/HdfsFile.scala
    src/scala/spark/NexusScheduler.scala
    src/test/spark/repl/ReplSuite.scala
This commit is contained in:
Commit
7d0eae17e3
run
@@ -4,7 +4,7 @@
FWDIR=`dirname $0`

# Set JAVA_OPTS to be able to load libnexus.so and set various other misc options
JAVA_OPTS="-Djava.library.path=$FWDIR/third_party:$FWDIR/src/native -Xmx750m"
export JAVA_OPTS="-Djava.library.path=$FWDIR/third_party:$FWDIR/src/native -Xms100m -Xmx750m"
if [ -e $FWDIR/conf/java-opts ] ; then
  JAVA_OPTS+=" `cat $FWDIR/conf/java-opts`"
fi
@@ -0,0 +1,24 @@
import spark.SparkContext

object BroadcastTest {
  def main(args: Array[String]) {
    if (args.length == 0) {
      System.err.println("Usage: BroadcastTest <host> [<slices>]")
      System.exit(1)
    }
    val spark = new SparkContext(args(0), "Broadcast Test")
    val slices = if (args.length > 1) args(1).toInt else 2
    val num = if (args.length > 2) args(2).toInt else 1000000

    var arr = new Array[Int](num)
    for (i <- 0 until arr.length)
      arr(i) = i

    val barr = spark.broadcast(arr)
    spark.parallelize(1 to 10, slices).foreach {
      println("in task: barr = " + barr)
      i => println(barr.value.size)
    }
  }
}
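For orientation, here is a minimal driver sketch of the same broadcast pattern this test exercises. It assumes only the calls visible in the diff above (SparkContext, broadcast, parallelize, foreach, value); the "local" master string and the lookupTable contents are illustrative, not part of this commit.

import spark.SparkContext

object BroadcastUsageSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "Broadcast Usage Sketch")
    // A read-only table every task needs; broadcasting it means each worker
    // JVM fetches and caches it once instead of receiving it in every closure.
    val lookupTable = Array.tabulate(1000)(i => i * i)
    val table = sc.broadcast(lookupTable)
    sc.parallelize(1 to 4, 4).foreach {
      // Tasks dereference the handle; only the small handle travels with the task.
      i => println("task " + i + " sees " + table.value.size + " entries")
    }
  }
}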
@@ -120,18 +120,18 @@ object SparkALS {

    // Iteratively update movies then users
    val Rc = spark.broadcast(R)
    var msb = spark.broadcast(ms)
    var usb = spark.broadcast(us)
    var msc = spark.broadcast(ms)
    var usc = spark.broadcast(us)
    for (iter <- 1 to ITERATIONS) {
      println("Iteration " + iter + ":")
      ms = spark.parallelize(0 until M, slices)
                .map(i => updateMovie(i, msb.value(i), usb.value, Rc.value))
                .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value))
                .toArray
      msb = spark.broadcast(ms) // Re-broadcast ms because it was updated
      msc = spark.broadcast(ms) // Re-broadcast ms because it was updated
      us = spark.parallelize(0 until U, slices)
                .map(i => updateUser(i, usb.value(i), msb.value, Rc.value))
                .map(i => updateUser(i, usc.value(i), msc.value, Rc.value))
                .toArray
      usb = spark.broadcast(us) // Re-broadcast us because it was updated
      usc = spark.broadcast(us) // Re-broadcast us because it was updated
      println("RMSE = " + rmse(R, ms, us))
      println()
    }
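The hunk above rewrites SparkALS to publish a fresh broadcast of the factor arrays on every iteration. A minimal standalone sketch of that pattern follows; it assumes only the broadcast/value/parallelize/map/toArray calls shown in this diff, and the update expression is a stand-in for updateMovie/updateUser.

import spark.SparkContext

object RebroadcastSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "Rebroadcast Sketch")
    var factors = Array.fill(100)(1.0)
    for (iter <- 1 to 5) {
      val fb = sc.broadcast(factors)                 // publish the current factors
      factors = sc.parallelize(0 until factors.length, 4)
                  .map(i => fb.value(i) * 0.5 + 1.0) // stand-in for the real update
                  .toArray
      // A broadcast value is fixed once published, so the updated array
      // must be re-broadcast before the next iteration reads it.
    }
    println("factors(0) after 5 iterations: " + factors(0))
  }
}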
@@ -0,0 +1,798 @@
package spark

import java.io._
import java.net._
import java.util.{UUID, PriorityQueue, Comparator}

import com.google.common.collect.MapMaker

import java.util.concurrent.{Executors, ExecutorService}

import scala.actors.Actor
import scala.actors.Actor._

import scala.collection.mutable.Map

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}

import spark.compress.lzf.{LZFInputStream, LZFOutputStream}

@serializable
trait BroadcastRecipe {
  val uuid = UUID.randomUUID

  // We cannot have an abstract readObject here due to some weird issues with
  // readObject having to be 'private' in sub-classes. Possibly a Scala bug!
  def sendBroadcast: Unit

  override def toString = "spark.Broadcast(" + uuid + ")"
}

// TODO: Should think about storing in HDFS in the future
// TODO: Right now, no parallelization between multiple broadcasts
@serializable
class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean)
  extends BroadcastRecipe {

  def value = value_

  BroadcastCS.synchronized { BroadcastCS.values.put (uuid, value_) }

  if (!local) { sendBroadcast }

  def sendBroadcast () {
    // Create a variableInfo object and store it in valueInfos
    var variableInfo = blockifyObject (value_, BroadcastCS.blockSize)
    // TODO: Even though this part is not in use now, there is a problem in the
    // following statement. Shouldn't use constant port and hostAddress anymore?
    // val masterSource =
    //   new SourceInfo (BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort,
    //     variableInfo.totalBlocks, variableInfo.totalBytes, 0)
    // variableInfo.pqOfSources.add (masterSource)

    BroadcastCS.synchronized {
      // BroadcastCS.valueInfos.put (uuid, variableInfo)

      // TODO: Not using variableInfo in current implementation. Manually
      // setting all the variables inside BroadcastCS object

      BroadcastCS.initializeVariable (variableInfo)
    }

    // Now store a persistent copy in HDFS, just in case
    val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid))
    out.writeObject (value_)
    out.close
  }

  private def readObject (in: ObjectInputStream) {
    in.defaultReadObject
    BroadcastCS.synchronized {
      val cachedVal = BroadcastCS.values.get (uuid)
      if (cachedVal != null) {
        value_ = cachedVal.asInstanceOf[T]
      } else {
        // Only a single worker (the first one) in the same node can ever be
        // here. The rest will always get the value ready
        val start = System.nanoTime

        val retByteArray = BroadcastCS.receiveBroadcast (uuid)
        // If it does not succeed, then get from the HDFS copy
        if (retByteArray != null) {
          value_ = byteArrayToObject[T] (retByteArray)
          BroadcastCS.values.put (uuid, value_)
          // val variableInfo = blockifyObject (value_, BroadcastCS.blockSize)
          // BroadcastCS.valueInfos.put (uuid, variableInfo)
        } else {
          val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid))
          value_ = fileIn.readObject.asInstanceOf[T]
          BroadcastCH.values.put(uuid, value_)
          fileIn.close
        }

        val time = (System.nanoTime - start) / 1e9
        println (System.currentTimeMillis + ": " + "Reading Broadcasted variable " + uuid + " took " + time + " s")
      }
    }
  }

  private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
    val baos = new ByteArrayOutputStream
    val oos = new ObjectOutputStream (baos)
    oos.writeObject (obj)
    oos.close
    baos.close
    val byteArray = baos.toByteArray
    val bais = new ByteArrayInputStream (byteArray)

    var blockNum = (byteArray.length / blockSize)
    if (byteArray.length % blockSize != 0)
      blockNum += 1

    var retVal = new Array[BroadcastBlock] (blockNum)
    var blockID = 0

    // TODO: What happens if byteArray.length == 0 => blockNum == 0
    for (i <- 0 until (byteArray.length, blockSize)) {
      val thisBlockSize = Math.min (blockSize, byteArray.length - i)
      var tempByteArray = new Array[Byte] (thisBlockSize)
      val hasRead = bais.read (tempByteArray, 0, thisBlockSize)

      retVal (blockID) = new BroadcastBlock (blockID, tempByteArray)
      blockID += 1
    }
    bais.close

    var variableInfo = VariableInfo (retVal, blockNum, byteArray.length)
    variableInfo.hasBlocks = blockNum

    return variableInfo
  }

  private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
    val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
    val retVal = in.readObject.asInstanceOf[A]
    in.close
    return retVal
  }

  private def getByteArrayOutputStream (obj: T): ByteArrayOutputStream = {
    val bOut = new ByteArrayOutputStream
    val out = new ObjectOutputStream (bOut)
    out.writeObject (obj)
    out.close
    bOut.close
    return bOut
  }
}
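The chained-streaming class above ships values as fixed-size blocks. Here is a standalone sketch of that blockify/reassemble round trip using only plain Java serialization; the block size and sample payload are arbitrary choices for illustration, not values used by this commit.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object BlockifySketch {
  def main(args: Array[String]) {
    val blockSize = 4096
    val payload = Array.tabulate(100000)(i => i)   // sample object to ship

    // Serialize, then cut the byte array into blockSize-sized chunks,
    // mirroring blockifyObject above.
    val baos = new ByteArrayOutputStream
    val oos = new ObjectOutputStream(baos)
    oos.writeObject(payload)
    oos.close
    val bytes = baos.toByteArray
    val blocks = bytes.grouped(blockSize).toArray

    // Reassemble on the receiving side and deserialize, mirroring the
    // receive loop plus byteArrayToObject.
    val assembled = new Array[Byte](bytes.length)
    var offset = 0
    for (block <- blocks) {
      System.arraycopy(block, 0, assembled, offset, block.length)
      offset += block.length
    }
    val in = new ObjectInputStream(new ByteArrayInputStream(assembled))
    val restored = in.readObject.asInstanceOf[Array[Int]]
    in.close
    println("blocks: " + blocks.length + ", restored(99999) = " + restored(99999))
  }
}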
@serializable
class CentralizedHDFSBroadcast[T](@transient var value_ : T, local: Boolean)
  extends BroadcastRecipe {

  def value = value_

  BroadcastCH.synchronized { BroadcastCH.values.put(uuid, value_) }

  if (!local) { sendBroadcast }

  def sendBroadcast () {
    val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid))
    out.writeObject (value_)
    out.close
  }

  // Called by Java when deserializing an object
  private def readObject(in: ObjectInputStream) {
    in.defaultReadObject
    BroadcastCH.synchronized {
      val cachedVal = BroadcastCH.values.get(uuid)
      if (cachedVal != null) {
        value_ = cachedVal.asInstanceOf[T]
      } else {
        val start = System.nanoTime

        val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid))
        value_ = fileIn.readObject.asInstanceOf[T]
        BroadcastCH.values.put(uuid, value_)
        fileIn.close

        val time = (System.nanoTime - start) / 1e9
        println (System.currentTimeMillis + ": " + "Reading Broadcasted variable " + uuid + " took " + time + " s")
      }
    }
  }
}
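Both broadcast classes rely on a per-JVM map keyed by UUID so that a value is materialized at most once per process, no matter how many task closures deserialize a handle to it. A standalone sketch of that idea follows; it uses a plain ConcurrentHashMap where the real code uses Guava's MapMaker with soft values, and the fetch function is a stand-in for the HDFS read.

import java.util.UUID
import java.util.concurrent.ConcurrentHashMap

object DeserializeOnceSketch {
  // Per-JVM cache keyed by broadcast UUID, standing in for BroadcastCH.values.
  val values = new ConcurrentHashMap[UUID, AnyRef]

  // Stand-in for the expensive fetch (HDFS read + deserialization).
  def fetch(uuid: UUID): AnyRef = {
    println("fetching " + uuid + " from stable storage...")
    Array.fill(10)(uuid.toString)
  }

  // What readObject effectively does: the first deserialization pays the
  // fetch cost, later deserializations of the same UUID hit the cache.
  def valueFor(uuid: UUID): AnyRef = synchronized {
    val cached = values.get(uuid)
    if (cached != null) cached
    else {
      val v = fetch(uuid)
      values.put(uuid, v)
      v
    }
  }

  def main(args: Array[String]) {
    val id = UUID.randomUUID
    valueFor(id)   // triggers the fetch
    valueFor(id)   // served from the in-memory map
  }
}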
@serializable
case class SourceInfo (val hostAddress: String, val listenPort: Int,
  val totalBlocks: Int, val totalBytes: Int, val replicaID: Int)
  extends Comparable[SourceInfo] {

  var currentLeechers = 0
  var receptionFailed = false

  def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
}

@serializable
case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }

@serializable
case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
  val totalBlocks: Int, val totalBytes: Int) {

  @transient var hasBlocks = 0

  val listenPortLock = new AnyRef
  val totalBlocksLock = new AnyRef
  val hasBlocksLock = new AnyRef

  @transient var pqOfSources = new PriorityQueue[SourceInfo]
}

private object Broadcast {
  private var initialized = false

  // Will be called by SparkContext or Executor before using Broadcast
  // Calls all other initializers here
  def initialize (isMaster: Boolean) {
    synchronized {
      if (!initialized) {
        // Initialization for CentralizedHDFSBroadcast
        BroadcastCH.initialize
        // Initialization for ChainedStreamingBroadcast
        //BroadcastCS.initialize (isMaster)

        initialized = true
      }
    }
  }
}
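SourceInfo above orders itself by currentLeechers, so a java.util.PriorityQueue of sources always yields the least-loaded one first. A small standalone sketch of that discipline follows; the Src class is a simplified, hypothetical stand-in whose field names merely mirror SourceInfo.

import java.util.PriorityQueue

// Simplified stand-in for SourceInfo: ordered by how many receivers
// ("leechers") are currently streaming from this source.
case class Src(host: String, var currentLeechers: Int) extends Comparable[Src] {
  def compareTo(o: Src): Int = currentLeechers - o.currentLeechers
}

object LeastLeechersSketch {
  def main(args: Array[String]) {
    val pq = new PriorityQueue[Src]
    pq.add(Src("a", 2))
    pq.add(Src("b", 0))
    pq.add(Src("c", 1))

    // Same discipline as selectSuitableSource: poll the least-loaded source,
    // bump its leecher count, and put it back so load stays balanced.
    val chosen = pq.poll
    chosen.currentLeechers += 1
    pq.add(chosen)
    println("chosen source: " + chosen)   // prints Src(b,1)
  }
}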
private object BroadcastCS {
  val values = new MapMaker ().softValues ().makeMap[UUID, Any]
  // val valueInfos = new MapMaker ().softValues ().makeMap[UUID, Any]

  // private var valueToPort = Map[UUID, Int] ()

  private var initialized = false
  private var isMaster_ = false

  private var masterHostAddress_ = "127.0.0.1"
  private var masterListenPort_ : Int = 11111
  private var blockSize_ : Int = 512 * 1024
  private var maxRetryCount_ : Int = 2
  private var serverSocketTimout_ : Int = 50000
  private var dualMode_ : Boolean = false

  private val hostAddress = InetAddress.getLocalHost.getHostAddress
  private var listenPort = -1

  var arrayOfBlocks: Array[BroadcastBlock] = null
  var totalBytes = -1
  var totalBlocks = -1
  var hasBlocks = 0

  val listenPortLock = new Object
  val totalBlocksLock = new Object
  val hasBlocksLock = new Object

  var pqOfSources = new PriorityQueue[SourceInfo]

  private var serveMR: ServeMultipleRequests = null
  private var guideMR: GuideMultipleRequests = null

  def initialize (isMaster__ : Boolean) {
    synchronized {
      if (!initialized) {
        masterHostAddress_ =
          System.getProperty ("spark.broadcast.masterHostAddress", "127.0.0.1")
        masterListenPort_ =
          System.getProperty ("spark.broadcast.masterListenPort", "11111").toInt
        blockSize_ =
          System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024
        maxRetryCount_ =
          System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt
        serverSocketTimout_ =
          System.getProperty ("spark.broadcast.serverSocketTimout", "50000").toInt
        dualMode_ =
          System.getProperty ("spark.broadcast.dualMode", "false").toBoolean

        isMaster_ = isMaster__

        if (isMaster) {
          guideMR = new GuideMultipleRequests
          guideMR.setDaemon (true)
          guideMR.start
          println (System.currentTimeMillis + ": " + "GuideMultipleRequests started")
        }
        serveMR = new ServeMultipleRequests
        serveMR.setDaemon (true)
        serveMR.start

        println (System.currentTimeMillis + ": " + "ServeMultipleRequests started")

        println (System.currentTimeMillis + ": " + "BroadcastCS object has been initialized")

        initialized = true
      }
    }
  }

  // TODO: This should change in future implementation.
  // Called from the Master constructor to set up state for the particular
  // variable that is being broadcasted
  def initializeVariable (variableInfo: VariableInfo) {
    arrayOfBlocks = variableInfo.arrayOfBlocks
    totalBytes = variableInfo.totalBytes
    totalBlocks = variableInfo.totalBlocks
    hasBlocks = variableInfo.totalBlocks

    // listenPort should already be valid
    assert (listenPort != -1)

    pqOfSources = new PriorityQueue[SourceInfo]
    val masterSource_0 =
      new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0)
    BroadcastCS.pqOfSources.add (masterSource_0)
    // Add one more time to have two replicas of any seeds in the PQ
    if (BroadcastCS.dualMode) {
      val masterSource_1 =
        new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 1)
      BroadcastCS.pqOfSources.add (masterSource_1)
    }
  }

  def masterHostAddress = masterHostAddress_
  def masterListenPort = masterListenPort_
  def blockSize = blockSize_
  def maxRetryCount = maxRetryCount_
  def serverSocketTimout = serverSocketTimout_
  def dualMode = dualMode_

  def isMaster = isMaster_

  def receiveBroadcast (variableUUID: UUID): Array[Byte] = {
    // Wait until hostAddress and listenPort are created by the
    // ServeMultipleRequests thread
    // No need to wait; ServeMultipleRequests is created much further ahead
    while (listenPort == -1) {
      listenPortLock.synchronized {
        listenPortLock.wait
      }
    }

    // Connect and receive broadcast from the specified source, retrying the
    // specified number of times in case of failures
    var retriesLeft = BroadcastCS.maxRetryCount
    var retByteArray: Array[Byte] = null
    do {
      // Connect to Master and send this worker's information
      val clientSocketToMaster =
        new Socket(BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort)
      println (System.currentTimeMillis + ": " + "Connected to Master's guiding object")
      // TODO: Guiding object connection is reusable
      val oisMaster =
        new ObjectInputStream (clientSocketToMaster.getInputStream)
      val oosMaster =
        new ObjectOutputStream (clientSocketToMaster.getOutputStream)

      oosMaster.writeObject(new SourceInfo (hostAddress, listenPort, -1, -1, 0))
      oosMaster.flush

      // Receive source information from Master
      var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
      totalBlocks = sourceInfo.totalBlocks
      arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
      totalBlocksLock.synchronized {
        totalBlocksLock.notifyAll
      }
      totalBytes = sourceInfo.totalBytes

      println (System.currentTimeMillis + ": " + "Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)

      retByteArray = receiveSingleTransmission (sourceInfo)

      println (System.currentTimeMillis + ": " + "I got this from receiveSingleTransmission: " + retByteArray)

      // TODO: Update sourceInfo to add error notifications for Master
      if (retByteArray == null) { sourceInfo.receptionFailed = true }

      // TODO: Supposed to update values here, but we don't support advanced
      // statistics right now. Master can handle leecherCount by itself.

      // Send back statistics to the Master
      oosMaster.writeObject (sourceInfo)

      oisMaster.close
      oosMaster.close
      clientSocketToMaster.close

      retriesLeft -= 1
    } while (retriesLeft > 0 && retByteArray == null)

    return retByteArray
  }

  // Tries to receive the broadcast from a single source and returns the bytes
  // received, or null on failure. This might be called multiple times to retry
  // a defined number of times.
  private def receiveSingleTransmission(sourceInfo: SourceInfo): Array[Byte] = {
    var clientSocketToSource: Socket = null
    var oisSource: ObjectInputStream = null
    var oosSource: ObjectOutputStream = null

    var retByteArray: Array[Byte] = null

    try {
      // Connect to the source to get the object itself
      clientSocketToSource =
        new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
      oosSource =
        new ObjectOutputStream (clientSocketToSource.getOutputStream)
      oisSource =
        new ObjectInputStream (clientSocketToSource.getInputStream)

      println (System.currentTimeMillis + ": " + "Inside receiveSingleTransmission")
      println (System.currentTimeMillis + ": " + "totalBlocks: " + totalBlocks + " " + "hasBlocks: " + hasBlocks)
      retByteArray = new Array[Byte] (totalBytes)
      for (i <- 0 until totalBlocks) {
        val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
        System.arraycopy (bcBlock.byteArray, 0, retByteArray,
          i * BroadcastCS.blockSize, bcBlock.byteArray.length)
        arrayOfBlocks(hasBlocks) = bcBlock
        hasBlocks += 1
        hasBlocksLock.synchronized {
          hasBlocksLock.notifyAll
        }
        println (System.currentTimeMillis + ": " + "Received block: " + i + " " + bcBlock)
      }
      assert (hasBlocks == totalBlocks)
      println (System.currentTimeMillis + ": " + "After the receive loop")
    } catch {
      case e: Exception => {
        retByteArray = null
        println (System.currentTimeMillis + ": " + "receiveSingleTransmission had a " + e)
      }
    } finally {
      if (oisSource != null) { oisSource.close }
      if (oosSource != null) {
        oosSource.close
      }
      if (clientSocketToSource != null) { clientSocketToSource.close }
    }

    return retByteArray
  }
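receiveBroadcast above blocks on listenPortLock until the serving thread has bound its socket and published listenPort. A standalone sketch of that wait/notifyAll handshake follows; the thread body and port number are illustrative, and the condition check is placed inside the synchronized block so the sketch itself cannot miss a notification.

object PortHandshakeSketch {
  @volatile var listenPort = -1
  val listenPortLock = new Object

  def main(args: Array[String]) {
    // Stand-in for ServeMultipleRequests: binds a port, then wakes up waiters.
    val server = new Thread {
      override def run() {
        Thread.sleep(100)                  // pretend to open a ServerSocket
        listenPort = 31337                 // illustrative port number
        listenPortLock.synchronized { listenPortLock.notifyAll }
      }
    }
    server.setDaemon(true)
    server.start

    // Stand-in for receiveBroadcast: it cannot report its address to the
    // guide until the serving port exists, so it waits on the same lock.
    listenPortLock.synchronized {
      while (listenPort == -1)
        listenPortLock.wait
    }
    println("worker will advertise port " + listenPort)
  }
}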
  class TrackMultipleValues extends Thread {
    override def run = {
      var threadPool = Executors.newCachedThreadPool
      var serverSocket: ServerSocket = null

      serverSocket = new ServerSocket (BroadcastCS.masterListenPort)
      println (System.currentTimeMillis + ": " + "TrackMultipleVariables" + serverSocket + " " + listenPort)

      var keepAccepting = true
      try {
        while (true) {
          var clientSocket: Socket = null
          try {
            serverSocket.setSoTimeout (serverSocketTimout)
            clientSocket = serverSocket.accept
          } catch {
            case e: Exception => {
              println ("TrackMultipleValues Timeout. Stopping listening...")
              keepAccepting = false
            }
          }
          println (System.currentTimeMillis + ": " + "TrackMultipleValues:Got new request:" + clientSocket)
          if (clientSocket != null) {
            try {
              threadPool.execute (new Runnable {
                def run = {
                  val oos = new ObjectOutputStream (clientSocket.getOutputStream)
                  val ois = new ObjectInputStream (clientSocket.getInputStream)
                  try {
                    val variableUUID = ois.readObject.asInstanceOf[UUID]
                    var contactPort = 0
                    // TODO: Add logic and data structures to find out UUID->port
                    // mapping. 0 = missed the broadcast, read from HDFS; <0 =
                    // Haven't started yet, wait & retry; >0 = Read from this port
                    oos.writeObject (contactPort)
                  } catch {
                    case e: Exception => { }
                  } finally {
                    ois.close
                    oos.close
                    clientSocket.close
                  }
                }
              })
            } catch {
              // In failure, close the socket here; else, the thread will close it
              case ioe: IOException => clientSocket.close
            }
          }
        }
      } finally {
        serverSocket.close
      }
    }
  }

  class TrackSingleValue {

  }

  class GuideMultipleRequests extends Thread {
    override def run = {
      var threadPool = Executors.newCachedThreadPool
      var serverSocket: ServerSocket = null

      serverSocket = new ServerSocket (BroadcastCS.masterListenPort)
      // listenPort = BroadcastCS.masterListenPort
      println (System.currentTimeMillis + ": " + "GuideMultipleRequests" + serverSocket + " " + listenPort)

      var keepAccepting = true
      try {
        while (keepAccepting) {
          var clientSocket: Socket = null
          try {
            serverSocket.setSoTimeout (serverSocketTimout)
            clientSocket = serverSocket.accept
          } catch {
            case e: Exception => {
              println ("GuideMultipleRequests Timeout. Stopping listening...")
              keepAccepting = false
            }
          }
          if (clientSocket != null) {
            println (System.currentTimeMillis + ": " + "Guide:Accepted new client connection:" + clientSocket)
            try {
              threadPool.execute (new GuideSingleRequest (clientSocket))
            } catch {
              // In failure, close the socket here; else, the thread will close it
              case ioe: IOException => clientSocket.close
            }
          }
        }
      } finally {
        serverSocket.close
      }
    }

    class GuideSingleRequest (val clientSocket: Socket) extends Runnable {
      private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
      private val ois = new ObjectInputStream (clientSocket.getInputStream)

      private var selectedSourceInfo: SourceInfo = null
      private var thisWorkerInfo: SourceInfo = null

      def run = {
        try {
          println (System.currentTimeMillis + ": " + "new GuideSingleRequest is running")
          // The connecting worker sends in its hostAddress and the listenPort it will
          // be listening to. ReplicaID is 0 and other fields are invalid (-1)
          var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]

          // Select a suitable source and send it back to the worker
          selectedSourceInfo = selectSuitableSource (sourceInfo)
          println (System.currentTimeMillis + ": " + "Sending selectedSourceInfo:" + selectedSourceInfo)
          oos.writeObject (selectedSourceInfo)
          oos.flush

          // Add this new (if it can finish) source to the PQ of sources
          thisWorkerInfo = new SourceInfo(sourceInfo.hostAddress,
            sourceInfo.listenPort, totalBlocks, totalBytes, 0)
          println (System.currentTimeMillis + ": " + "Adding possible new source to pqOfSources: " + thisWorkerInfo)
          pqOfSources.synchronized {
            pqOfSources.add (thisWorkerInfo)
          }

          // Wait till the whole transfer is done. Then receive and update source
          // statistics in pqOfSources
          sourceInfo = ois.readObject.asInstanceOf[SourceInfo]

          pqOfSources.synchronized {
            // This should work since SourceInfo is a case class
            assert (pqOfSources.contains (selectedSourceInfo))

            // Remove first
            pqOfSources.remove (selectedSourceInfo)
            // TODO: Removing a source based on just one failure notification!
            // Update leecher count and put it back in IF reception succeeded
            if (!sourceInfo.receptionFailed) {
              selectedSourceInfo.currentLeechers -= 1
              pqOfSources.add (selectedSourceInfo)

              // No need to find and update thisWorkerInfo, but add its replica
              if (BroadcastCS.dualMode) {
                pqOfSources.add (new SourceInfo (thisWorkerInfo.hostAddress,
                  thisWorkerInfo.listenPort, totalBlocks, totalBytes, 1))
              }
            }
          }
        } catch {
          // If something went wrong, e.g., the worker at the other end died etc.,
          // then close everything up
          case e: Exception => {
            // Assuming the exception was caused by a receiver worker failure:
            // remove the failed worker from pqOfSources and update the leecherCount
            // of the corresponding source worker
            pqOfSources.synchronized {
              if (selectedSourceInfo != null) {
                // Remove first
                pqOfSources.remove (selectedSourceInfo)
                // Update leecher count and put it back in
                selectedSourceInfo.currentLeechers -= 1
                pqOfSources.add (selectedSourceInfo)
              }

              // Remove thisWorkerInfo
              if (pqOfSources != null) { pqOfSources.remove (thisWorkerInfo) }
            }
          }
        } finally {
          ois.close
          oos.close
          clientSocket.close
        }
      }

      // TODO: If a worker fails to get the broadcasted variable from a source and
      // comes back to the Master, this function might choose the worker itself as a
      // source, creating a dependency cycle (this worker was put into pqOfSources
      // as a streaming source when it first arrived). The length of this cycle can
      // be arbitrarily long.
      private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
        // Select one with the lowest number of leechers
        pqOfSources.synchronized {
          // poll removes the head of the PQ, i.e., the least-loaded source
          var selectedSource = pqOfSources.poll
          assert (selectedSource != null)
          // Update leecher count
          selectedSource.currentLeechers += 1
          // Add it back and then return
          pqOfSources.add (selectedSource)
          return selectedSource
        }
      }
    }
  }

  class ServeMultipleRequests extends Thread {
    override def run = {
      var threadPool = Executors.newCachedThreadPool
      var serverSocket: ServerSocket = null

      serverSocket = new ServerSocket (0)
      listenPort = serverSocket.getLocalPort
      println (System.currentTimeMillis + ": " + "ServeMultipleRequests" + serverSocket + " " + listenPort)

      listenPortLock.synchronized {
        listenPortLock.notifyAll
      }

      var keepAccepting = true
      try {
        while (keepAccepting) {
          var clientSocket: Socket = null
          try {
            serverSocket.setSoTimeout (serverSocketTimout)
            clientSocket = serverSocket.accept
          } catch {
            case e: Exception => {
              println ("ServeMultipleRequests Timeout. Stopping listening...")
              keepAccepting = false
            }
          }
          if (clientSocket != null) {
            println (System.currentTimeMillis + ": " + "Serve:Accepted new client connection:" + clientSocket)
            try {
              threadPool.execute (new ServeSingleRequest (clientSocket))
            } catch {
              // In failure, close socket here; else, the thread will close it
              case ioe: IOException => clientSocket.close
            }
          }
        }
      } finally {
        serverSocket.close
      }
    }

    class ServeSingleRequest (val clientSocket: Socket) extends Runnable {
      private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
      private val ois = new ObjectInputStream (clientSocket.getInputStream)

      def run = {
        try {
          println (System.currentTimeMillis + ": " + "new ServeSingleRequest is running")
          sendObject
        } catch {
          // TODO: Need to add better exception handling here
          // If something went wrong, e.g., the worker at the other end died etc.,
          // then close everything up
          case e: Exception => {
            println (System.currentTimeMillis + ": " + "ServeSingleRequest had a " + e)
          }
        } finally {
          println (System.currentTimeMillis + ": " + "ServeSingleRequest is closing streams and sockets")
          ois.close
          oos.close
          clientSocket.close
        }
      }

      private def sendObject = {
        // Wait till receiving the SourceInfo from Master
        while (totalBlocks == -1) {
          totalBlocksLock.synchronized {
            totalBlocksLock.wait
          }
        }

        for (i <- 0 until totalBlocks) {
          while (i == hasBlocks) {
            hasBlocksLock.synchronized {
              hasBlocksLock.wait
            }
          }
          try {
            oos.writeObject (arrayOfBlocks(i))
            oos.flush
          } catch {
            case e: Exception => { }
          }
          println (System.currentTimeMillis + ": " + "Send block: " + i + " " + arrayOfBlocks(i))
        }
      }
    }

  }
}
private object BroadcastCH {
  val values = new MapMaker ().softValues ().makeMap[UUID, Any]

  private var initialized = false

  private var fileSystem: FileSystem = null
  private var workDir: String = null
  private var compress: Boolean = false
  private var bufferSize: Int = 65536

  def initialize () {
    synchronized {
      if (!initialized) {
        bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
        val dfs = System.getProperty("spark.dfs", "file:///")
        if (!dfs.startsWith("file://")) {
          val conf = new Configuration()
          conf.setInt("io.file.buffer.size", bufferSize)
          val rep = System.getProperty("spark.dfs.replication", "3").toInt
          conf.setInt("dfs.replication", rep)
          fileSystem = FileSystem.get(new URI(dfs), conf)
        }
        workDir = System.getProperty("spark.dfs.workdir", "/tmp")
        compress = System.getProperty("spark.compress", "false").toBoolean

        initialized = true
      }
    }
  }

  private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid)

  def openFileForReading(uuid: UUID): InputStream = {
    val fileStream = if (fileSystem != null) {
      fileSystem.open(getPath(uuid))
    } else {
      // Local filesystem
      new FileInputStream(getPath(uuid).toString)
    }
    if (compress)
      new LZFInputStream(fileStream) // LZF stream does its own buffering
    else if (fileSystem == null)
      new BufferedInputStream(fileStream, bufferSize)
    else
      fileStream // Hadoop streams do their own buffering
  }

  def openFileForWriting(uuid: UUID): OutputStream = {
    val fileStream = if (fileSystem != null) {
      fileSystem.create(getPath(uuid))
    } else {
      // Local filesystem
      new FileOutputStream(getPath(uuid).toString)
    }
    if (compress)
      new LZFOutputStream(fileStream) // LZF stream does its own buffering
    else if (fileSystem == null)
      new BufferedOutputStream(fileStream, bufferSize)
    else
      fileStream // Hadoop streams do their own buffering
  }
}
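BroadcastCH (and the removed Cache object in the next hunk) choose a stream wrapper from the spark.compress and filesystem settings. A standalone sketch of that decision for the local-filesystem case follows; the LZF wrapper from spark.compress.lzf is only referenced in a comment so the sketch runs without it, and the path and buffer size are illustrative.

import java.io.{BufferedOutputStream, FileOutputStream, OutputStream}

object OpenForWritingSketch {
  // Mirrors openFileForWriting's decision: a compressing stream does its own
  // buffering, a raw local stream gets an explicit buffer, and (in the real
  // code) Hadoop streams are returned as-is.
  def openLocal(path: String, compress: Boolean, bufferSize: Int): OutputStream = {
    val raw = new FileOutputStream(path)
    if (compress)
      raw // placeholder: the real code wraps this in new LZFOutputStream(raw)
    else
      new BufferedOutputStream(raw, bufferSize)
  }

  def main(args: Array[String]) {
    val out = openLocal("/tmp/broadcast-sketch", compress = false, bufferSize = 65536)
    out.write("hello".getBytes)
    out.close()
  }
}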
@@ -1,110 +0,0 @@
package spark

import java.io._
import java.net.URI
import java.util.UUID

import com.google.common.collect.MapMaker

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}

import spark.compress.lzf.{LZFInputStream, LZFOutputStream}

@serializable class Cached[T](@transient var value_ : T, local: Boolean) {
  val uuid = UUID.randomUUID()
  def value = value_

  Cache.synchronized { Cache.values.put(uuid, value_) }

  if (!local) writeCacheFile()

  private def writeCacheFile() {
    val out = new ObjectOutputStream(Cache.openFileForWriting(uuid))
    out.writeObject(value_)
    out.close()
  }

  // Called by Java when deserializing an object
  private def readObject(in: ObjectInputStream) {
    in.defaultReadObject
    Cache.synchronized {
      val cachedVal = Cache.values.get(uuid)
      if (cachedVal != null) {
        value_ = cachedVal.asInstanceOf[T]
      } else {
        val start = System.nanoTime
        val fileIn = new ObjectInputStream(Cache.openFileForReading(uuid))
        value_ = fileIn.readObject().asInstanceOf[T]
        Cache.values.put(uuid, value_)
        fileIn.close()
        val time = (System.nanoTime - start) / 1e9
        println("Reading cached variable " + uuid + " took " + time + " s")
      }
    }
  }

  override def toString = "spark.Cached(" + uuid + ")"
}

private object Cache {
  val values = new MapMaker().softValues().makeMap[UUID, Any]()

  private var initialized = false
  private var fileSystem: FileSystem = null
  private var workDir: String = null
  private var compress: Boolean = false
  private var bufferSize: Int = 65536

  // Will be called by SparkContext or Executor before using cache
  def initialize() {
    synchronized {
      if (!initialized) {
        bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
        val dfs = System.getProperty("spark.dfs", "file:///")
        if (!dfs.startsWith("file://")) {
          val conf = new Configuration()
          conf.setInt("io.file.buffer.size", bufferSize)
          val rep = System.getProperty("spark.dfs.replication", "3").toInt
          conf.setInt("dfs.replication", rep)
          fileSystem = FileSystem.get(new URI(dfs), conf)
        }
        workDir = System.getProperty("spark.dfs.workdir", "/tmp")
        compress = System.getProperty("spark.compress", "false").toBoolean
        initialized = true
      }
    }
  }

  private def getPath(uuid: UUID) = new Path(workDir + "/cache-" + uuid)

  def openFileForReading(uuid: UUID): InputStream = {
    val fileStream = if (fileSystem != null) {
      fileSystem.open(getPath(uuid))
    } else {
      // Local filesystem
      new FileInputStream(getPath(uuid).toString)
    }
    if (compress)
      new LZFInputStream(fileStream) // LZF stream does its own buffering
    else if (fileSystem == null)
      new BufferedInputStream(fileStream, bufferSize)
    else
      fileStream // Hadoop streams do their own buffering
  }

  def openFileForWriting(uuid: UUID): OutputStream = {
    val fileStream = if (fileSystem != null) {
      fileSystem.create(getPath(uuid))
    } else {
      // Local filesystem
      new FileOutputStream(getPath(uuid).toString)
    }
    if (compress)
      new LZFOutputStream(fileStream) // LZF stream does its own buffering
    else if (fileSystem == null)
      new BufferedOutputStream(fileStream, bufferSize)
    else
      fileStream // Hadoop streams do their own buffering
  }
}
@@ -19,8 +19,8 @@ object Executor {
    for ((key, value) <- props)
      System.setProperty(key, value)

    // Initialize cache (uses some properties read above)
    Cache.initialize()
    // Initialize broadcast system (uses some properties read above)
    Broadcast.initialize(false)

    // If the REPL is in use, create a ClassLoader that will be able to
    // read new classes defined by the REPL as the user types code
@@ -60,8 +60,10 @@ extends RDD[String, HdfsSplit](sc) {
    }
  }

  override def prefers(split: HdfsSplit, slot: SlaveOffer) =
    split.value.getLocations().contains(slot.getHost)
  override def preferredLocations(split: HdfsSplit) = {
    // TODO: Filtering out "localhost" in case of file:// URLs
    split.value.getLocations().filter(_ != "localhost")
  }
}

object ConfigureLock {}
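The new preferredLocations above drops "localhost" from the HDFS-reported block locations so that a file:// URL does not pin every task to a meaningless host name. A standalone sketch of that filter over a plain list; the host names are made up.

object PreferredLocationsSketch {
  def main(args: Array[String]) {
    // What an HDFS getLocations() call might report for one split; for a
    // file:// URL it tends to be just "localhost", which is useless for
    // scheduling, so it is filtered out.
    val reported = Seq("localhost", "node7.cluster", "node12.cluster")
    val preferred = reported.filter(_ != "localhost")
    println("preferred locations: " + preferred.mkString(", "))
  }
}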
@@ -1,50 +1,45 @@
package spark

import java.io.File
import java.util.concurrent.Semaphore

import nexus.{ExecutorInfo, TaskDescription, TaskState, TaskStatus}
import nexus.{SlaveOffer, SchedulerDriver, NexusSchedulerDriver}
import nexus.{SlaveOfferVector, TaskDescriptionVector, StringMap}
import scala.collection.mutable.Map

import nexus.{Scheduler => NScheduler}
import nexus._

// The main Scheduler implementation, which talks to Nexus. Clients are expected
// to first call start(), then submit tasks through the runTasks method.
//
// This implementation is currently a little quick and dirty. The following
// improvements need to be made to it:
// 1) Fault tolerance should be added - if a task fails, just re-run it anywhere.
// 2) Right now, the scheduler uses a linear scan through the tasks to find a
// 1) Right now, the scheduler uses a linear scan through the tasks to find a
//    local one for a given node. It would be faster to have a separate list of
//    pending tasks for each node.
// 3) The Callbacks way of organizing things didn't work out too well, so the
//    way the scheduler keeps track of the currently active runTasks operation
//    can be made cleaner.
// 2) Presenting a single slave in ParallelOperation.slaveOffer makes it
//    difficult to balance tasks across nodes. It would be better to pass
//    all the offers to the ParallelOperation and have it load-balance.
private class NexusScheduler(
  master: String, frameworkName: String, execArg: Array[Byte])
extends nexus.Scheduler with spark.Scheduler
extends NScheduler with spark.Scheduler
{
  // Semaphore used by runTasks to ensure only one thread can be in it
  val semaphore = new Semaphore(1)
  // Lock used by runTasks to ensure only one thread can be in it
  val runTasksMutex = new Object()

  // Lock used to wait for scheduler to be registered
  var isRegistered = false
  val registeredLock = new Object()

  // Trait representing a set of scheduler callbacks
  trait Callbacks {
    def slotOffer(s: SlaveOffer): Option[TaskDescription]
    def taskFinished(t: TaskStatus): Unit
    def error(code: Int, message: String): Unit
  }

  // Current callback object (may be null)
  var callbacks: Callbacks = null
  var activeOp: ParallelOperation = null

  // Incrementing task ID
  var nextTaskId = 0
  private var nextTaskId = 0

  // Maximum time to wait to run a task in a preferred location (in ms)
  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong
  def newTaskId(): Int = {
    val id = nextTaskId;
    nextTaskId += 1;
    return id
  }

  // Driver for talking to Nexus
  var driver: SchedulerDriver = null
@@ -65,126 +60,28 @@ extends nexus.Scheduler with spark.Scheduler
  override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo =
    new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)

  override def runTasks[T: ClassManifest](tasks: Array[Task[T]]) : Array[T] = {
    val results = new Array[T](tasks.length)
    if (tasks.length == 0)
      return results
  override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = {
    runTasksMutex.synchronized {
      waitForRegister()
      val myOp = new SimpleParallelOperation(this, tasks)

    val launched = new Array[Boolean](tasks.length)

    val callingThread = currentThread

    var errorHappened = false
    var errorCode = 0
    var errorMessage = ""

    // Wait for scheduler to be registered with Nexus
    waitForRegister()

    try {
      // Acquire the runTasks semaphore
      semaphore.acquire()

      val myCallbacks = new Callbacks {
        val firstTaskId = nextTaskId
        var tasksLaunched = 0
        var tasksFinished = 0
        var lastPreferredLaunchTime = System.currentTimeMillis

        def slotOffer(slot: SlaveOffer): Option[TaskDescription] = {
          try {
            if (tasksLaunched < tasks.length) {
              // TODO: Add a short wait if no task with location pref is found
              // TODO: Figure out why a function is needed around this to
              // avoid scala.runtime.NonLocalReturnException
              def findTask: Option[TaskDescription] = {
                var checkPrefVals: Array[Boolean] = Array(true)
                val time = System.currentTimeMillis
                if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
                  checkPrefVals = Array(true, false) // Allow non-preferred tasks
                // TODO: Make desiredCpus and desiredMem configurable
                val desiredCpus = 1
                val desiredMem = 750L * 1024L * 1024L
                if (slot.getParams.get("cpus").toInt < desiredCpus ||
                    slot.getParams.get("mem").toLong < desiredMem)
                  return None
                for (checkPref <- checkPrefVals;
                     i <- 0 until tasks.length;
                     if !launched(i) && (!checkPref || tasks(i).prefers(slot)))
                {
                  val taskId = nextTaskId
                  nextTaskId += 1
                  printf("Starting task %d as TID %d on slave %d: %s (%s)\n",
                    i, taskId, slot.getSlaveId, slot.getHost,
                    if(checkPref) "preferred" else "non-preferred")
                  tasks(i).markStarted(slot)
                  launched(i) = true
                  tasksLaunched += 1
                  if (checkPref)
                    lastPreferredLaunchTime = time
                  val params = new StringMap
                  params.set("cpus", "" + desiredCpus)
                  params.set("mem", "" + desiredMem)
                  val serializedTask = Utils.serialize(tasks(i))
                  return Some(new TaskDescription(taskId, slot.getSlaveId,
                    "task_" + taskId, params, serializedTask))
                }
                return None
              }
              return findTask
            } else {
              return None
            }
          } catch {
            case e: Exception => {
              e.printStackTrace
              System.exit(1)
              return None
            }
          }
      try {
        this.synchronized {
          this.activeOp = myOp
        }

        def taskFinished(status: TaskStatus) {
          println("Finished TID " + status.getTaskId)
          // Deserialize task result
          val result = Utils.deserialize[TaskResult[T]](status.getData)
          results(status.getTaskId - firstTaskId) = result.value
          // Update accumulators
          Accumulators.add(callingThread, result.accumUpdates)
          // Stop if we've finished all the tasks
          tasksFinished += 1
          if (tasksFinished == tasks.length) {
            NexusScheduler.this.callbacks = null
            NexusScheduler.this.notifyAll()
          }
        }

        def error(code: Int, message: String) {
          // Save the error message
          errorHappened = true
          errorCode = code
          errorMessage = message
          // Indicate to caller thread that we're done
          NexusScheduler.this.callbacks = null
          NexusScheduler.this.notifyAll()
        driver.reviveOffers();
        myOp.join();
      } finally {
        this.synchronized {
          this.activeOp = null
        }
      }

      this.synchronized {
        this.callbacks = myCallbacks
      }
      driver.reviveOffers();
      this.synchronized {
        while (this.callbacks != null) this.wait()
      }
    } finally {
      semaphore.release()
      if (myOp.errorHappened)
        throw new SparkException(myOp.errorMessage, myOp.errorCode)
      else
        return myOp.results
    }

    if (errorHappened)
      throw new SparkException(errorMessage, errorCode)
    else
      return results
  }

  override def registered(d: SchedulerDriver, frameworkId: Int) {
@@ -197,18 +94,19 @@ extends nexus.Scheduler with spark.Scheduler

  override def waitForRegister() {
    registeredLock.synchronized {
      while (!isRegistered) registeredLock.wait()
      while (!isRegistered)
        registeredLock.wait()
    }
  }

  override def resourceOffer(
      d: SchedulerDriver, oid: Long, slots: SlaveOfferVector) {
      d: SchedulerDriver, oid: Long, offers: SlaveOfferVector) {
    synchronized {
      val tasks = new TaskDescriptionVector
      if (callbacks != null) {
      if (activeOp != null) {
        try {
          for (i <- 0 until slots.size.toInt) {
            callbacks.slotOffer(slots.get(i)) match {
          for (i <- 0 until offers.size.toInt) {
            activeOp.slaveOffer(offers.get(i)) match {
              case Some(task) => tasks.add(task)
              case None => {}
            }
@@ -225,21 +123,21 @@ extends nexus.Scheduler with spark.Scheduler

  override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
    synchronized {
      if (callbacks != null && status.getState == TaskState.TASK_FINISHED) {
        try {
          callbacks.taskFinished(status)
        } catch {
          case e: Exception => e.printStackTrace
      try {
        if (activeOp != null) {
          activeOp.statusUpdate(status)
        }
      } catch {
        case e: Exception => e.printStackTrace
      }
    }
  }

  override def error(d: SchedulerDriver, code: Int, message: String) {
    synchronized {
      if (callbacks != null) {
      if (activeOp != null) {
        try {
          callbacks.error(code, message)
          activeOp.error(code, message)
        } catch {
          case e: Exception => e.printStackTrace
        }
@@ -256,3 +154,137 @@ extends nexus.Scheduler with spark.Scheduler
    driver.stop()
  }
}


// Trait representing an object that manages a parallel operation by
// implementing various scheduler callbacks.
trait ParallelOperation {
  def slaveOffer(s: SlaveOffer): Option[TaskDescription]
  def statusUpdate(t: TaskStatus): Unit
  def error(code: Int, message: String): Unit
}


class SimpleParallelOperation[T: ClassManifest](
  sched: NexusScheduler, tasks: Array[Task[T]])
extends ParallelOperation
{
  // Maximum time to wait to run a task in a preferred location (in ms)
  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong

  val callingThread = currentThread
  val numTasks = tasks.length
  val results = new Array[T](numTasks)
  val launched = new Array[Boolean](numTasks)
  val finished = new Array[Boolean](numTasks)
  val tidToIndex = Map[Int, Int]()

  var allFinished = false
  val joinLock = new Object()

  var errorHappened = false
  var errorCode = 0
  var errorMessage = ""

  var tasksLaunched = 0
  var tasksFinished = 0
  var lastPreferredLaunchTime = System.currentTimeMillis

  def setAllFinished() {
    joinLock.synchronized {
      allFinished = true
      joinLock.notifyAll()
    }
  }

  def join() {
    joinLock.synchronized {
      while (!allFinished)
        joinLock.wait()
    }
  }

  def slaveOffer(offer: SlaveOffer): Option[TaskDescription] = {
    if (tasksLaunched < numTasks) {
      var checkPrefVals: Array[Boolean] = Array(true)
      val time = System.currentTimeMillis
      if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
        checkPrefVals = Array(true, false) // Allow non-preferred tasks
      // TODO: Make desiredCpus and desiredMem configurable
      val desiredCpus = 1
      val desiredMem = 750L * 1024L * 1024L
      if (offer.getParams.get("cpus").toInt < desiredCpus ||
          offer.getParams.get("mem").toLong < desiredMem)
        return None
      for (checkPref <- checkPrefVals; i <- 0 until numTasks) {
        if (!launched(i) && (!checkPref ||
            tasks(i).preferredLocations.contains(offer.getHost) ||
            tasks(i).preferredLocations.isEmpty))
        {
          val taskId = sched.newTaskId()
          tidToIndex(taskId) = i
          printf("Starting task %d as TID %d on slave %d: %s (%s)\n",
            i, taskId, offer.getSlaveId, offer.getHost,
            if(checkPref) "preferred" else "non-preferred")
          tasks(i).markStarted(offer)
          launched(i) = true
          tasksLaunched += 1
          if (checkPref)
            lastPreferredLaunchTime = time
          val params = new StringMap
          params.set("cpus", "" + desiredCpus)
          params.set("mem", "" + desiredMem)
          val serializedTask = Utils.serialize(tasks(i))
          return Some(new TaskDescription(taskId, offer.getSlaveId,
            "task_" + taskId, params, serializedTask))
        }
      }
    }
    return None
  }

  def statusUpdate(status: TaskStatus) {
    status.getState match {
      case TaskState.TASK_FINISHED =>
        taskFinished(status)
      case TaskState.TASK_LOST =>
        taskLost(status)
      case TaskState.TASK_FAILED =>
        taskLost(status)
      case TaskState.TASK_KILLED =>
        taskLost(status)
      case _ =>
    }
  }

  def taskFinished(status: TaskStatus) {
    val tid = status.getTaskId
    println("Finished TID " + tid)
    // Deserialize task result
    val result = Utils.deserialize[TaskResult[T]](status.getData)
    results(tidToIndex(tid)) = result.value
    // Update accumulators
    Accumulators.add(callingThread, result.accumUpdates)
    // Mark finished and stop if we've finished all the tasks
    finished(tidToIndex(tid)) = true
    tasksFinished += 1
    if (tasksFinished == numTasks)
      setAllFinished()
  }

  def taskLost(status: TaskStatus) {
    val tid = status.getTaskId
    println("Lost TID " + tid)
    launched(tidToIndex(tid)) = false
    tasksLaunched -= 1
  }

  def error(code: Int, message: String) {
    // Save the error message
    errorHappened = true
    errorCode = code
    errorMessage = message
    // Indicate to caller thread that we're done
    setAllFinished()
  }
}
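SimpleParallelOperation.slaveOffer above implements a simple form of delay scheduling: non-preferred placements are only accepted once LOCALITY_WAIT ms have passed without a preferred launch. A standalone sketch of that decision follows; the offer is reduced to a host name, tasks to their preferred-location lists, and all names and timings are illustrative.

object DelaySchedulingSketch {
  val LOCALITY_WAIT = 1000L   // ms, same default as spark.locality.wait

  // Each entry is one pending task's preferred hosts (Nil = runs anywhere).
  val prefs = Array(Seq("node1"), Seq("node2"), Nil)
  val launched = Array(false, false, false)
  var lastPreferredLaunchTime = System.currentTimeMillis

  // Mirrors the two-pass loop in slaveOffer: first look for a task that
  // prefers this host; only if we have been waiting longer than
  // LOCALITY_WAIT do we settle for any unlaunched task.
  def pickTask(offerHost: String): Option[Int] = {
    val now = System.currentTimeMillis
    val checkPrefVals =
      if (now - lastPreferredLaunchTime > LOCALITY_WAIT) Array(true, false)
      else Array(true)
    for (checkPref <- checkPrefVals; i <- 0 until prefs.length) {
      if (!launched(i) &&
          (!checkPref || prefs(i).isEmpty || prefs(i).contains(offerHost))) {
        launched(i) = true
        if (checkPref) lastPreferredLaunchTime = now
        return Some(i)
      }
    }
    None
  }

  def main(args: Array[String]) {
    println(pickTask("node2"))   // Some(1): preferred match
    println(pickTask("node9"))   // Some(2): task with no preference
    println(pickTask("node9"))   // None until LOCALITY_WAIT elapses for task 0
  }
}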
@@ -38,7 +38,7 @@ extends RDD[T, ParallelArraySplit[T]](sc) {

  override def iterator(s: ParallelArraySplit[T]) = s.iterator

  override def prefers(s: ParallelArraySplit[T], offer: SlaveOffer) = true
  override def preferredLocations(s: ParallelArraySplit[T]): Seq[String] = Nil
}

private object ParallelArray {
@@ -16,7 +16,7 @@ abstract class RDD[T: ClassManifest, Split](
  @transient sc: SparkContext) {
  def splits: Array[Split]
  def iterator(split: Split): Iterator[T]
  def prefers(split: Split, slot: SlaveOffer): Boolean
  def preferredLocations(split: Split): Seq[String]

  def taskStarted(split: Split, slot: SlaveOffer) {}
@@ -83,7 +83,7 @@ abstract class RDD[T: ClassManifest, Split](
abstract class RDDTask[U: ClassManifest, T: ClassManifest, Split](
  val rdd: RDD[T, Split], val split: Split)
extends Task[U] {
  override def prefers(slot: SlaveOffer) = rdd.prefers(split, slot)
  override def preferredLocations() = rdd.preferredLocations(split)
  override def markStarted(slot: SlaveOffer) { rdd.taskStarted(split, slot) }
}
@@ -122,7 +122,7 @@ class MappedRDD[U: ClassManifest, T: ClassManifest, Split](
  prev: RDD[T, Split], f: T => U)
extends RDD[U, Split](prev.sparkContext) {
  override def splits = prev.splits
  override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
  override def preferredLocations(split: Split) = prev.preferredLocations(split)
  override def iterator(split: Split) = prev.iterator(split).map(f)
  override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
}
@@ -131,7 +131,7 @@ class FilteredRDD[T: ClassManifest, Split](
  prev: RDD[T, Split], f: T => Boolean)
extends RDD[T, Split](prev.sparkContext) {
  override def splits = prev.splits
  override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
  override def preferredLocations(split: Split) = prev.preferredLocations(split)
  override def iterator(split: Split) = prev.iterator(split).filter(f)
  override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
}
@@ -140,15 +140,15 @@ class CachedRDD[T, Split](
  prev: RDD[T, Split])(implicit m: ClassManifest[T])
extends RDD[T, Split](prev.sparkContext) {
  val id = CachedRDD.newId()
  @transient val cacheLocs = Map[Split, List[Int]]()
  @transient val cacheLocs = Map[Split, List[String]]()

  override def splits = prev.splits

  override def prefers(split: Split, slot: SlaveOffer): Boolean = {
  override def preferredLocations(split: Split) = {
    if (cacheLocs.contains(split))
      cacheLocs(split).contains(slot.getSlaveId)
      cacheLocs(split)
    else
      prev.prefers(split, slot)
      prev.preferredLocations(split)
  }

  override def iterator(split: Split): Iterator[T] = {
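The CachedRDD change above turns the old boolean prefers() into preferredLocations(): if a split has already been computed somewhere, prefer the hosts that cached it, otherwise fall back to the parent RDD's preference. A standalone sketch of that lookup over a plain map; the split ids and host names are made up.

import scala.collection.mutable.Map

object CacheAwarePrefsSketch {
  // Hosts known to hold each split's data in their cache,
  // standing in for CachedRDD.cacheLocs.
  val cacheLocs = Map[Int, List[String]](1 -> List("node3", "node5"))

  // Fallback preference from the parent RDD (e.g. HDFS block locations).
  def parentPrefs(split: Int): Seq[String] = Seq("node1")

  def preferredLocations(split: Int): Seq[String] =
    if (cacheLocs.contains(split)) cacheLocs(split)
    else parentPrefs(split)

  def main(args: Array[String]) {
    println(preferredLocations(1))   // List(node3, node5): cached copies win
    println(preferredLocations(2))   // List(node1): fall back to the parent
  }
}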
@@ -185,9 +185,9 @@ extends RDD[T, Split](prev.sparkContext) {

  override def taskStarted(split: Split, slot: SlaveOffer) {
    val oldList = cacheLocs.getOrElse(split, Nil)
    val slaveId = slot.getSlaveId
    if (!oldList.contains(slaveId))
      cacheLocs(split) = slaveId :: oldList
    val host = slot.getHost
    if (!oldList.contains(host))
      cacheLocs(split) = host :: oldList
  }
}
@@ -205,7 +205,7 @@ private object CachedRDD {
@serializable
abstract class UnionSplit[T: ClassManifest] {
  def iterator(): Iterator[T]
  def prefers(offer: SlaveOffer): Boolean
  def preferredLocations(): Seq[String]
}

@serializable
@@ -213,7 +213,7 @@ class UnionSplitImpl[T: ClassManifest, Split](
  rdd: RDD[T, Split], split: Split)
extends UnionSplit[T] {
  override def iterator() = rdd.iterator(split)
  override def prefers(offer: SlaveOffer) = rdd.prefers(split, offer)
  override def preferredLocations() = rdd.preferredLocations(split)
}

@serializable
@@ -231,5 +231,6 @@ extends RDD[T, UnionSplit[T]](sc) {

  override def iterator(s: UnionSplit[T]): Iterator[T] = s.iterator()

  override def prefers(s: UnionSplit[T], o: SlaveOffer) = s.prefers(o)
  override def preferredLocations(s: UnionSplit[T]): Seq[String] =
    s.preferredLocations()
}
@@ -6,7 +6,7 @@ import java.util.UUID
import scala.collection.mutable.ArrayBuffer

class SparkContext(master: String, frameworkName: String) {
  Cache.initialize()
  Broadcast.initialize(true)

  def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int) =
    new ParallelArray[T](this, seq, numSlices)
@@ -18,7 +18,8 @@ class SparkContext(master: String, frameworkName: String) {
    new Accumulator(initialValue, param)

  // TODO: Keep around a weak hash map of values to Cached versions?
  def broadcast[T](value: T) = new Cached(value, local)
  def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, local)
  //def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, local)

  def textFile(path: String) = new HdfsTextFile(this, path)
@@ -5,8 +5,8 @@ import nexus._
@serializable
trait Task[T] {
  def run: T
  def prefers(slot: SlaveOffer): Boolean = true
  def markStarted(slot: SlaveOffer) {}
  def preferredLocations: Seq[String] = Nil
  def markStarted(offer: SlaveOffer) {}
}

@serializable