diff --git a/run b/run index 20b80ac170..e6723ccd7c 100755 --- a/run +++ b/run @@ -4,7 +4,7 @@ FWDIR=`dirname $0` # Set JAVA_OPTS to be able to load libnexus.so and set various other misc options -JAVA_OPTS="-Djava.library.path=$FWDIR/third_party:$FWDIR/src/native -Xmx750m" +export JAVA_OPTS="-Djava.library.path=$FWDIR/third_party:$FWDIR/src/native -Xms100m -Xmx750m" if [ -e $FWDIR/conf/java-opts ] ; then JAVA_OPTS+=" `cat $FWDIR/conf/java-opts`" fi diff --git a/src/examples/BroadcastTest.scala b/src/examples/BroadcastTest.scala new file mode 100644 index 0000000000..7764013413 --- /dev/null +++ b/src/examples/BroadcastTest.scala @@ -0,0 +1,24 @@ +import spark.SparkContext + +object BroadcastTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: BroadcastTest <host> [<slices>] [<numElem>]") + System.exit(1) + } + val spark = new SparkContext(args(0), "Broadcast Test") + val slices = if (args.length > 1) args(1).toInt else 2 + val num = if (args.length > 2) args(2).toInt else 1000000 + + var arr = new Array[Int](num) + for (i <- 0 until arr.length) + arr(i) = i + + val barr = spark.broadcast(arr) + spark.parallelize(1 to 10, slices).foreach { + println("in task: barr = " + barr) + i => println(barr.value.size) + } + } +} + diff --git a/src/examples/SparkALS.scala b/src/examples/SparkALS.scala index a5d8559d7b..6fae3c0940 100644 --- a/src/examples/SparkALS.scala +++ b/src/examples/SparkALS.scala @@ -120,18 +120,18 @@ object SparkALS { // Iteratively update movies then users val Rc = spark.broadcast(R) - var msb = spark.broadcast(ms) - var usb = spark.broadcast(us) + var msc = spark.broadcast(ms) + var usc = spark.broadcast(us) for (iter <- 1 to ITERATIONS) { println("Iteration " + iter + ":") ms = spark.parallelize(0 until M, slices) - .map(i => updateMovie(i, msb.value(i), usb.value, Rc.value)) + .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value)) .toArray - msb = spark.broadcast(ms) // Re-broadcast ms because it was updated + msc = spark.broadcast(ms) // Re-broadcast ms because it was updated us = spark.parallelize(0 until U, slices) - .map(i => updateUser(i, usb.value(i), msb.value, Rc.value)) + .map(i => updateUser(i, usc.value(i), msc.value, Rc.value)) .toArray - usb = spark.broadcast(us) // Re-broadcast us because it was updated + usc = spark.broadcast(us) // Re-broadcast us because it was updated println("RMSE = " + rmse(R, ms, us)) println() } diff --git a/src/scala/spark/Broadcast.scala b/src/scala/spark/Broadcast.scala new file mode 100644 index 0000000000..7fe84da47c --- /dev/null +++ b/src/scala/spark/Broadcast.scala @@ -0,0 +1,798 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{UUID, PriorityQueue, Comparator} + +import com.google.common.collect.MapMaker + +import java.util.concurrent.{Executors, ExecutorService} + +import scala.actors.Actor +import scala.actors.Actor._ + +import scala.collection.mutable.Map + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} + +import spark.compress.lzf.{LZFInputStream, LZFOutputStream} + +@serializable +trait BroadcastRecipe { + val uuid = UUID.randomUUID + + // We cannot have an abstract readObject here due to some weird issues with + // readObject having to be 'private' in sub-classes. Possibly a Scala bug!
+ def sendBroadcast: Unit + + override def toString = "spark.Broadcast(" + uuid + ")" +} + +// TODO: Should think about storing in HDFS in the future +// TODO: Right, now no parallelization between multiple broadcasts +@serializable +class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) + extends BroadcastRecipe { + + def value = value_ + + BroadcastCS.synchronized { BroadcastCS.values.put (uuid, value_) } + + if (!local) { sendBroadcast } + + def sendBroadcast () { + // Create a variableInfo object and store it in valueInfos + var variableInfo = blockifyObject (value_, BroadcastCS.blockSize) + // TODO: Even though this part is not in use now, there is problem in the + // following statement. Shouldn't use constant port and hostAddress anymore? + // val masterSource = + // new SourceInfo (BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort, + // variableInfo.totalBlocks, variableInfo.totalBytes, 0) + // variableInfo.pqOfSources.add (masterSource) + + BroadcastCS.synchronized { + // BroadcastCS.valueInfos.put (uuid, variableInfo) + + // TODO: Not using variableInfo in current implementation. Manually + // setting all the variables inside BroadcastCS object + + BroadcastCS.initializeVariable (variableInfo) + } + + // Now store a persistent copy in HDFS, just in case + val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid)) + out.writeObject (value_) + out.close + } + + private def readObject (in: ObjectInputStream) { + in.defaultReadObject + BroadcastCS.synchronized { + val cachedVal = BroadcastCS.values.get (uuid) + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + // Only a single worker (the first one) in the same node can ever be + // here. The rest will always get the value ready + val start = System.nanoTime + + val retByteArray = BroadcastCS.receiveBroadcast (uuid) + // If does not succeed, then get from HDFS copy + if (retByteArray != null) { + value_ = byteArrayToObject[T] (retByteArray) + BroadcastCS.values.put (uuid, value_) + // val variableInfo = blockifyObject (value_, BroadcastCS.blockSize) + // BroadcastCS.valueInfos.put (uuid, variableInfo) + } else { + val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid)) + value_ = fileIn.readObject.asInstanceOf[T] + BroadcastCH.values.put(uuid, value_) + fileIn.close + } + + val time = (System.nanoTime - start) / 1e9 + println( System.currentTimeMillis + ": " + "Reading Broadcasted variable " + uuid + " took " + time + " s") + } + } + } + + private def blockifyObject (obj: T, blockSize: Int): VariableInfo = { + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream (baos) + oos.writeObject (obj) + oos.close + baos.close + val byteArray = baos.toByteArray + val bais = new ByteArrayInputStream (byteArray) + + var blockNum = (byteArray.length / blockSize) + if (byteArray.length % blockSize != 0) + blockNum += 1 + + var retVal = new Array[BroadcastBlock] (blockNum) + var blockID = 0 + + // TODO: What happens in byteArray.length == 0 => blockNum == 0 + for (i <- 0 until (byteArray.length, blockSize)) { + val thisBlockSize = Math.min (blockSize, byteArray.length - i) + var tempByteArray = new Array[Byte] (thisBlockSize) + val hasRead = bais.read (tempByteArray, 0, thisBlockSize) + + retVal (blockID) = new BroadcastBlock (blockID, tempByteArray) + blockID += 1 + } + bais.close + + var variableInfo = VariableInfo (retVal, blockNum, byteArray.length) + variableInfo.hasBlocks = blockNum + + return variableInfo + } + + private def 
byteArrayToObject[A] (bytes: Array[Byte]): A = { + val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) + val retVal = in.readObject.asInstanceOf[A] + in.close + return retVal + } + + private def getByteArrayOutputStream (obj: T): ByteArrayOutputStream = { + val bOut = new ByteArrayOutputStream + val out = new ObjectOutputStream (bOut) + out.writeObject (obj) + out.close + bOut.close + return bOut + } +} + +@serializable +class CentralizedHDFSBroadcast[T](@transient var value_ : T, local: Boolean) + extends BroadcastRecipe { + + def value = value_ + + BroadcastCH.synchronized { BroadcastCH.values.put(uuid, value_) } + + if (!local) { sendBroadcast } + + def sendBroadcast () { + val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid)) + out.writeObject (value_) + out.close + } + + // Called by Java when deserializing an object + private def readObject(in: ObjectInputStream) { + in.defaultReadObject + BroadcastCH.synchronized { + val cachedVal = BroadcastCH.values.get(uuid) + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + val start = System.nanoTime + + val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid)) + value_ = fileIn.readObject.asInstanceOf[T] + BroadcastCH.values.put(uuid, value_) + fileIn.close + + val time = (System.nanoTime - start) / 1e9 + println( System.currentTimeMillis + ": " + "Reading Broadcasted variable " + uuid + " took " + time + " s") + } + } + } +} + +@serializable +case class SourceInfo (val hostAddress: String, val listenPort: Int, + val totalBlocks: Int, val totalBytes: Int, val replicaID: Int) + extends Comparable [SourceInfo]{ + + var currentLeechers = 0 + var receptionFailed = false + + def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) +} + +@serializable +case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { } + +@serializable +case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock], + val totalBlocks: Int, val totalBytes: Int) { + + @transient var hasBlocks = 0 + + val listenPortLock = new AnyRef + val totalBlocksLock = new AnyRef + val hasBlocksLock = new AnyRef + + @transient var pqOfSources = new PriorityQueue[SourceInfo] +} + +private object Broadcast { + private var initialized = false + + // Will be called by SparkContext or Executor before using Broadcast + // Calls all other initializers here + def initialize (isMaster: Boolean) { + synchronized { + if (!initialized) { + // Initialization for CentralizedHDFSBroadcast + BroadcastCH.initialize + // Initialization for ChainedStreamingBroadcast + //BroadcastCS.initialize (isMaster) + + initialized = true + } + } + } +} + +private object BroadcastCS { + val values = new MapMaker ().softValues ().makeMap[UUID, Any] + // val valueInfos = new MapMaker ().softValues ().makeMap[UUID, Any] + + // private var valueToPort = Map[UUID, Int] () + + private var initialized = false + private var isMaster_ = false + + private var masterHostAddress_ = "127.0.0.1" + private var masterListenPort_ : Int = 11111 + private var blockSize_ : Int = 512 * 1024 + private var maxRetryCount_ : Int = 2 + private var serverSocketTimout_ : Int = 50000 + private var dualMode_ : Boolean = false + + private val hostAddress = InetAddress.getLocalHost.getHostAddress + private var listenPort = -1 + + var arrayOfBlocks: Array[BroadcastBlock] = null + var totalBytes = -1 + var totalBlocks = -1 + var hasBlocks = 0 + + val listenPortLock = new Object + val totalBlocksLock = new Object + val hasBlocksLock = 
new Object + + var pqOfSources = new PriorityQueue[SourceInfo] + + private var serveMR: ServeMultipleRequests = null + private var guideMR: GuideMultipleRequests = null + + def initialize (isMaster__ : Boolean) { + synchronized { + if (!initialized) { + masterHostAddress_ = + System.getProperty ("spark.broadcast.masterHostAddress", "127.0.0.1") + masterListenPort_ = + System.getProperty ("spark.broadcast.masterListenPort", "11111").toInt + blockSize_ = + System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024 + maxRetryCount_ = + System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt + serverSocketTimout_ = + System.getProperty ("spark.broadcast.serverSocketTimout", "50000").toInt + dualMode_ = + System.getProperty ("spark.broadcast.dualMode", "false").toBoolean + + isMaster_ = isMaster__ + + if (isMaster) { + guideMR = new GuideMultipleRequests + guideMR.setDaemon (true) + guideMR.start + println (System.currentTimeMillis + ": " + "GuideMultipleRequests started") + } + serveMR = new ServeMultipleRequests + serveMR.setDaemon (true) + serveMR.start + + println (System.currentTimeMillis + ": " + "ServeMultipleRequests started") + + println (System.currentTimeMillis + ": " + "BroadcastCS object has been initialized") + + initialized = true + } + } + } + + // TODO: This should change in future implementation. + // Called from the Master constructor to setup states for this particular that + // is being broadcasted + def initializeVariable (variableInfo: VariableInfo) { + arrayOfBlocks = variableInfo.arrayOfBlocks + totalBytes = variableInfo.totalBytes + totalBlocks = variableInfo.totalBlocks + hasBlocks = variableInfo.totalBlocks + + // listenPort should already be valid + assert (listenPort != -1) + + pqOfSources = new PriorityQueue[SourceInfo] + val masterSource_0 = + new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0) + BroadcastCS.pqOfSources.add (masterSource_0) + // Add one more time to have two replicas of any seeds in the PQ + if (BroadcastCS.dualMode) { + val masterSource_1 = + new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 1) + BroadcastCS.pqOfSources.add (masterSource_1) + } + } + + def masterHostAddress = masterHostAddress_ + def masterListenPort = masterListenPort_ + def blockSize = blockSize_ + def maxRetryCount = maxRetryCount_ + def serverSocketTimout = serverSocketTimout_ + def dualMode = dualMode_ + + def isMaster = isMaster_ + + def receiveBroadcast (variableUUID: UUID): Array[Byte] = { + // Wait until hostAddress and listenPort are created by the + // ServeMultipleRequests thread + // NO need to wait; ServeMultipleRequests is created much further ahead + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + // Connect and receive broadcast from the specified source, retrying the + // specified number of times in case of failures + var retriesLeft = BroadcastCS.maxRetryCount + var retByteArray: Array[Byte] = null + do { + // Connect to Master and send this worker's Information + val clientSocketToMaster = + new Socket(BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort) + println (System.currentTimeMillis + ": " + "Connected to Master's guiding object") + // TODO: Guiding object connection is reusable + val oisMaster = + new ObjectInputStream (clientSocketToMaster.getInputStream) + val oosMaster = + new ObjectOutputStream (clientSocketToMaster.getOutputStream) + + oosMaster.writeObject(new SourceInfo (hostAddress, listenPort, -1, -1, 0)) + oosMaster.flush + + // Receive 
source information from Master + var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] + totalBlocks = sourceInfo.totalBlocks + arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks) + totalBlocksLock.synchronized { + totalBlocksLock.notifyAll + } + totalBytes = sourceInfo.totalBytes + + println (System.currentTimeMillis + ": " + "Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) + + retByteArray = receiveSingleTransmission (sourceInfo) + + println (System.currentTimeMillis + ": " + "I got this from receiveSingleTransmission: " + retByteArray) + + // TODO: Update sourceInfo to add error notifactions for Master + if (retByteArray == null) { sourceInfo.receptionFailed = true } + + // TODO: Supposed to update values here, but we don't support advanced + // statistics right now. Master can handle leecherCount by itself. + + // Send back statistics to the Master + oosMaster.writeObject (sourceInfo) + + oisMaster.close + oosMaster.close + clientSocketToMaster.close + + retriesLeft -= 1 + } while (retriesLeft > 0 && retByteArray == null) + + return retByteArray + } + + // Tries to receive broadcast from the Master and returns Boolean status. + // This might be called multiple times to retry a defined number of times. + private def receiveSingleTransmission(sourceInfo: SourceInfo): Array[Byte] = { + var clientSocketToSource: Socket = null + var oisSource: ObjectInputStream = null + var oosSource: ObjectOutputStream = null + + var retByteArray:Array[Byte] = null + + try { + // Connect to the source to get the object itself + clientSocketToSource = + new Socket (sourceInfo.hostAddress, sourceInfo.listenPort) + oosSource = + new ObjectOutputStream (clientSocketToSource.getOutputStream) + oisSource = + new ObjectInputStream (clientSocketToSource.getInputStream) + + println (System.currentTimeMillis + ": " + "Inside receiveSingleTransmission") + println (System.currentTimeMillis + ": " + "totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) + retByteArray = new Array[Byte] (totalBytes) + for (i <- 0 until totalBlocks) { + val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] + System.arraycopy (bcBlock.byteArray, 0, retByteArray, + i * BroadcastCS.blockSize, bcBlock.byteArray.length) + arrayOfBlocks(hasBlocks) = bcBlock + hasBlocks += 1 + hasBlocksLock.synchronized { + hasBlocksLock.notifyAll + } + println (System.currentTimeMillis + ": " + "Received block: " + i + " " + bcBlock) + } + assert (hasBlocks == totalBlocks) + println (System.currentTimeMillis + ": " + "After the receive loop") + } catch { + case e: Exception => { + retByteArray = null + println (System.currentTimeMillis + ": " + "receiveSingleTransmission had a " + e) + } + } finally { + if (oisSource != null) { oisSource.close } + if (oosSource != null) { + oosSource.close + } + if (clientSocketToSource != null) { clientSocketToSource.close } + } + + return retByteArray + } + + class TrackMultipleValues extends Thread { + override def run = { + var threadPool = Executors.newCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (BroadcastCS.masterListenPort) + println (System.currentTimeMillis + ": " + "TrackMultipleVariables" + serverSocket + " " + listenPort) + + var keepAccepting = true + try { + while (true) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (serverSocketTimout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + println ("TrackMultipleValues Timeout. 
Stopping listening...") + keepAccepting = false + } + } + println (System.currentTimeMillis + ": " + "TrackMultipleValues:Got new request:" + clientSocket) + if (clientSocket != null) { + try { + threadPool.execute (new Runnable { + def run = { + val oos = new ObjectOutputStream (clientSocket.getOutputStream) + val ois = new ObjectInputStream (clientSocket.getInputStream) + try { + val variableUUID = ois.readObject.asInstanceOf[UUID] + var contactPort = 0 + // TODO: Add logic and data structures to find out UUID->port + // mapping. 0 = missed the broadcast, read from HDFS; <0 = + // Haven't started yet, wait & retry; >0 = Read from this port + oos.writeObject (contactPort) + } catch { + case e: Exception => { } + } finally { + ois.close + oos.close + clientSocket.close + } + } + }) + } catch { + // In failure, close the socket here; else, the thread will close it + case ioe: IOException => clientSocket.close + } + } + } + } finally { + serverSocket.close + } + } + } + + class TrackSingleValue { + + } + + class GuideMultipleRequests extends Thread { + override def run = { + var threadPool = Executors.newCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (BroadcastCS.masterListenPort) + // listenPort = BroadcastCS.masterListenPort + println (System.currentTimeMillis + ": " + "GuideMultipleRequests" + serverSocket + " " + listenPort) + + var keepAccepting = true + try { + while (keepAccepting) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (serverSocketTimout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + println ("GuideMultipleRequests Timeout. Stopping listening...") + keepAccepting = false + } + } + if (clientSocket != null) { + println (System.currentTimeMillis + ": " + "Guide:Accepted new client connection:" + clientSocket) + try { + threadPool.execute (new GuideSingleRequest (clientSocket)) + } catch { + // In failure, close the socket here; else, the thread will close it + case ioe: IOException => clientSocket.close + } + } + } + } finally { + serverSocket.close + } + } + + class GuideSingleRequest (val clientSocket: Socket) extends Runnable { + private val oos = new ObjectOutputStream (clientSocket.getOutputStream) + private val ois = new ObjectInputStream (clientSocket.getInputStream) + + private var selectedSourceInfo: SourceInfo = null + private var thisWorkerInfo:SourceInfo = null + + def run = { + try { + println (System.currentTimeMillis + ": " + "new GuideSingleRequest is running") + // Connecting worker is sending in its hostAddress and listenPort it will + // be listening to. ReplicaID is 0 and other fields are invalid (-1) + var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + // Select a suitable source and send it back to the worker + selectedSourceInfo = selectSuitableSource (sourceInfo) + println (System.currentTimeMillis + ": " + "Sending selectedSourceInfo:" + selectedSourceInfo) + oos.writeObject (selectedSourceInfo) + oos.flush + + // Add this new (if it can finish) source to the PQ of sources + thisWorkerInfo = new SourceInfo(sourceInfo.hostAddress, + sourceInfo.listenPort, totalBlocks, totalBytes, 0) + println (System.currentTimeMillis + ": " + "Adding possible new source to pqOfSources: " + thisWorkerInfo) + pqOfSources.synchronized { + pqOfSources.add (thisWorkerInfo) + } + + // Wait till the whole transfer is done. 
Then receive and update source + // statistics in pqOfSources + sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + pqOfSources.synchronized { + // This should work since SourceInfo is a case class + assert (pqOfSources.contains (selectedSourceInfo)) + + // Remove first + pqOfSources.remove (selectedSourceInfo) + // TODO: Removing a source based on just one failure notification! + // Update leecher count and put it back in IF reception succeeded + if (!sourceInfo.receptionFailed) { + selectedSourceInfo.currentLeechers -= 1 + pqOfSources.add (selectedSourceInfo) + + // No need to find and update thisWorkerInfo, but add its replica + if (BroadcastCS.dualMode) { + pqOfSources.add (new SourceInfo (thisWorkerInfo.hostAddress, + thisWorkerInfo.listenPort, totalBlocks, totalBytes, 1)) + } + } + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. + // then close everything up + case e: Exception => { + // Assuming that exception caused due to receiver worker failure + // Remove failed worker from pqOfSources and update leecherCount of + // corresponding source worker + pqOfSources.synchronized { + if (selectedSourceInfo != null) { + // Remove first + pqOfSources.remove (selectedSourceInfo) + // Update leecher count and put it back in + selectedSourceInfo.currentLeechers -= 1 + pqOfSources.add (selectedSourceInfo) + } + + // Remove thisWorkerInfo + if (pqOfSources != null) { pqOfSources.remove (thisWorkerInfo) } + } + } + } finally { + ois.close + oos.close + clientSocket.close + } + } + + // TODO: If a worker fails to get the broadcasted variable from a source and + // comes back to Master, this function might choose the worker itself as a + // source tp create a dependency cycle (this worker was put into pqOfSources + // as a streming source when it first arrived). The length of this cycle can + // be arbitrarily long. + private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { + // Select one with the lowest number of leechers + pqOfSources.synchronized { + // take is a blocking call removing the element from PQ + var selectedSource = pqOfSources.poll + assert (selectedSource != null) + // Update leecher count + selectedSource.currentLeechers += 1 + // Add it back and then return + pqOfSources.add (selectedSource) + return selectedSource + } + } + } + } + + class ServeMultipleRequests extends Thread { + override def run = { + var threadPool = Executors.newCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (0) + listenPort = serverSocket.getLocalPort + println (System.currentTimeMillis + ": " + "ServeMultipleRequests" + serverSocket + " " + listenPort) + + listenPortLock.synchronized { + listenPortLock.notifyAll + } + + var keepAccepting = true + try { + while (keepAccepting) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (serverSocketTimout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + println ("ServeMultipleRequests Timeout. 
Stopping listening...") + keepAccepting = false + } + } + if (clientSocket != null) { + println (System.currentTimeMillis + ": " + "Serve:Accepted new client connection:" + clientSocket) + try { + threadPool.execute (new ServeSingleRequest (clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => clientSocket.close + } + } + } + } finally { + serverSocket.close + } + } + + class ServeSingleRequest (val clientSocket: Socket) extends Runnable { + private val oos = new ObjectOutputStream (clientSocket.getOutputStream) + private val ois = new ObjectInputStream (clientSocket.getInputStream) + + def run = { + try { + println (System.currentTimeMillis + ": " + "new ServeSingleRequest is running") + sendObject + } catch { + // TODO: Need to add better exception handling here + // If something went wrong, e.g., the worker at the other end died etc. + // then close everything up + case e: Exception => { + println (System.currentTimeMillis + ": " + "ServeSingleRequest had a " + e) + } + } finally { + println (System.currentTimeMillis + ": " + "ServeSingleRequest is closing streams and sockets") + ois.close + oos.close + clientSocket.close + } + } + + private def sendObject = { + // Wait till receiving the SourceInfo from Master + while (totalBlocks == -1) { + totalBlocksLock.synchronized { + totalBlocksLock.wait + } + } + + for (i <- 0 until totalBlocks) { + while (i == hasBlocks) { + hasBlocksLock.synchronized { + hasBlocksLock.wait + } + } + try { + oos.writeObject (arrayOfBlocks(i)) + oos.flush + } catch { + case e: Exception => { } + } + println (System.currentTimeMillis + ": " + "Send block: " + i + " " + arrayOfBlocks(i)) + } + } + } + + } +} + +private object BroadcastCH { + val values = new MapMaker ().softValues ().makeMap[UUID, Any] + + private var initialized = false + + private var fileSystem: FileSystem = null + private var workDir: String = null + private var compress: Boolean = false + private var bufferSize: Int = 65536 + + def initialize () { + synchronized { + if (!initialized) { + bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val dfs = System.getProperty("spark.dfs", "file:///") + if (!dfs.startsWith("file://")) { + val conf = new Configuration() + conf.setInt("io.file.buffer.size", bufferSize) + val rep = System.getProperty("spark.dfs.replication", "3").toInt + conf.setInt("dfs.replication", rep) + fileSystem = FileSystem.get(new URI(dfs), conf) + } + workDir = System.getProperty("spark.dfs.workdir", "/tmp") + compress = System.getProperty("spark.compress", "false").toBoolean + + initialized = true + } + } + } + + private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid) + + def openFileForReading(uuid: UUID): InputStream = { + val fileStream = if (fileSystem != null) { + fileSystem.open(getPath(uuid)) + } else { + // Local filesystem + new FileInputStream(getPath(uuid).toString) + } + if (compress) + new LZFInputStream(fileStream) // LZF stream does its own buffering + else if (fileSystem == null) + new BufferedInputStream(fileStream, bufferSize) + else + fileStream // Hadoop streams do their own buffering + } + + def openFileForWriting(uuid: UUID): OutputStream = { + val fileStream = if (fileSystem != null) { + fileSystem.create(getPath(uuid)) + } else { + // Local filesystem + new FileOutputStream(getPath(uuid).toString) + } + if (compress) + new LZFOutputStream(fileStream) // LZF stream does its own buffering + else if (fileSystem == null) + new 
BufferedOutputStream(fileStream, bufferSize) + else + fileStream // Hadoop streams do their own buffering + } +} diff --git a/src/scala/spark/Cached.scala b/src/scala/spark/Cached.scala deleted file mode 100644 index 8113340e1f..0000000000 --- a/src/scala/spark/Cached.scala +++ /dev/null @@ -1,110 +0,0 @@ -package spark - -import java.io._ -import java.net.URI -import java.util.UUID - -import com.google.common.collect.MapMaker - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} - -import spark.compress.lzf.{LZFInputStream, LZFOutputStream} - -@serializable class Cached[T](@transient var value_ : T, local: Boolean) { - val uuid = UUID.randomUUID() - def value = value_ - - Cache.synchronized { Cache.values.put(uuid, value_) } - - if (!local) writeCacheFile() - - private def writeCacheFile() { - val out = new ObjectOutputStream(Cache.openFileForWriting(uuid)) - out.writeObject(value_) - out.close() - } - - // Called by Java when deserializing an object - private def readObject(in: ObjectInputStream) { - in.defaultReadObject - Cache.synchronized { - val cachedVal = Cache.values.get(uuid) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - val start = System.nanoTime - val fileIn = new ObjectInputStream(Cache.openFileForReading(uuid)) - value_ = fileIn.readObject().asInstanceOf[T] - Cache.values.put(uuid, value_) - fileIn.close() - val time = (System.nanoTime - start) / 1e9 - println("Reading cached variable " + uuid + " took " + time + " s") - } - } - } - - override def toString = "spark.Cached(" + uuid + ")" -} - -private object Cache { - val values = new MapMaker().softValues().makeMap[UUID, Any]() - - private var initialized = false - private var fileSystem: FileSystem = null - private var workDir: String = null - private var compress: Boolean = false - private var bufferSize: Int = 65536 - - // Will be called by SparkContext or Executor before using cache - def initialize() { - synchronized { - if (!initialized) { - bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val dfs = System.getProperty("spark.dfs", "file:///") - if (!dfs.startsWith("file://")) { - val conf = new Configuration() - conf.setInt("io.file.buffer.size", bufferSize) - val rep = System.getProperty("spark.dfs.replication", "3").toInt - conf.setInt("dfs.replication", rep) - fileSystem = FileSystem.get(new URI(dfs), conf) - } - workDir = System.getProperty("spark.dfs.workdir", "/tmp") - compress = System.getProperty("spark.compress", "false").toBoolean - initialized = true - } - } - } - - private def getPath(uuid: UUID) = new Path(workDir + "/cache-" + uuid) - - def openFileForReading(uuid: UUID): InputStream = { - val fileStream = if (fileSystem != null) { - fileSystem.open(getPath(uuid)) - } else { - // Local filesystem - new FileInputStream(getPath(uuid).toString) - } - if (compress) - new LZFInputStream(fileStream) // LZF stream does its own buffering - else if (fileSystem == null) - new BufferedInputStream(fileStream, bufferSize) - else - fileStream // Hadoop streams do their own buffering - } - - def openFileForWriting(uuid: UUID): OutputStream = { - val fileStream = if (fileSystem != null) { - fileSystem.create(getPath(uuid)) - } else { - // Local filesystem - new FileOutputStream(getPath(uuid).toString) - } - if (compress) - new LZFOutputStream(fileStream) // LZF stream does its own buffering - else if (fileSystem == null) - new BufferedOutputStream(fileStream, bufferSize) - else - fileStream // Hadoop streams 
do their own buffering - } -} diff --git a/src/scala/spark/Executor.scala b/src/scala/spark/Executor.scala index 5c3d2523bf..679a61f3c0 100644 --- a/src/scala/spark/Executor.scala +++ b/src/scala/spark/Executor.scala @@ -19,8 +19,8 @@ object Executor { for ((key, value) <- props) System.setProperty(key, value) - // Initialize cache (uses some properties read above) - Cache.initialize() + // Initialize broadcast system (uses some properties read above) + Broadcast.initialize(false) // If the REPL is in use, create a ClassLoader that will be able to // read new classes defined by the REPL as the user types code diff --git a/src/scala/spark/HdfsFile.scala b/src/scala/spark/HdfsFile.scala index 8c702c9226..6aa0e22338 100644 --- a/src/scala/spark/HdfsFile.scala +++ b/src/scala/spark/HdfsFile.scala @@ -60,8 +60,10 @@ extends RDD[String, HdfsSplit](sc) { } } - override def prefers(split: HdfsSplit, slot: SlaveOffer) = - split.value.getLocations().contains(slot.getHost) + override def preferredLocations(split: HdfsSplit) = { + // TODO: Filtering out "localhost" in case of file:// URLs + split.value.getLocations().filter(_ != "localhost") + } } object ConfigureLock {} diff --git a/src/scala/spark/NexusScheduler.scala b/src/scala/spark/NexusScheduler.scala index 29c2011093..a5343039ef 100644 --- a/src/scala/spark/NexusScheduler.scala +++ b/src/scala/spark/NexusScheduler.scala @@ -1,50 +1,45 @@ package spark import java.io.File -import java.util.concurrent.Semaphore -import nexus.{ExecutorInfo, TaskDescription, TaskState, TaskStatus} -import nexus.{SlaveOffer, SchedulerDriver, NexusSchedulerDriver} -import nexus.{SlaveOfferVector, TaskDescriptionVector, StringMap} +import scala.collection.mutable.Map + +import nexus.{Scheduler => NScheduler} +import nexus._ // The main Scheduler implementation, which talks to Nexus. Clients are expected // to first call start(), then submit tasks through the runTasks method. // // This implementation is currently a little quick and dirty. The following // improvements need to be made to it: -// 1) Fault tolerance should be added - if a task fails, just re-run it anywhere. -// 2) Right now, the scheduler uses a linear scan through the tasks to find a +// 1) Right now, the scheduler uses a linear scan through the tasks to find a // local one for a given node. It would be faster to have a separate list of // pending tasks for each node. -// 3) The Callbacks way of organizing things didn't work out too well, so the -// way the scheduler keeps track of the currently active runTasks operation -// can be made cleaner. +// 2) Presenting a single slave in ParallelOperation.slaveOffer makes it +// difficult to balance tasks across nodes. It would be better to pass +// all the offers to the ParallelOperation and have it load-balance. 
private class NexusScheduler( master: String, frameworkName: String, execArg: Array[Byte]) -extends nexus.Scheduler with spark.Scheduler +extends NScheduler with spark.Scheduler { - // Semaphore used by runTasks to ensure only one thread can be in it - val semaphore = new Semaphore(1) + // Lock used by runTasks to ensure only one thread can be in it + val runTasksMutex = new Object() // Lock used to wait for scheduler to be registered var isRegistered = false val registeredLock = new Object() - // Trait representing a set of scheduler callbacks - trait Callbacks { - def slotOffer(s: SlaveOffer): Option[TaskDescription] - def taskFinished(t: TaskStatus): Unit - def error(code: Int, message: String): Unit - } - // Current callback object (may be null) - var callbacks: Callbacks = null + var activeOp: ParallelOperation = null // Incrementing task ID - var nextTaskId = 0 + private var nextTaskId = 0 - // Maximum time to wait to run a task in a preferred location (in ms) - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong + def newTaskId(): Int = { + val id = nextTaskId; + nextTaskId += 1; + return id + } // Driver for talking to Nexus var driver: SchedulerDriver = null @@ -65,126 +60,28 @@ extends nexus.Scheduler with spark.Scheduler override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo = new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg) - override def runTasks[T: ClassManifest](tasks: Array[Task[T]]) : Array[T] = { - val results = new Array[T](tasks.length) - if (tasks.length == 0) - return results + override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = { + runTasksMutex.synchronized { + waitForRegister() + val myOp = new SimpleParallelOperation(this, tasks) - val launched = new Array[Boolean](tasks.length) - - val callingThread = currentThread - - var errorHappened = false - var errorCode = 0 - var errorMessage = "" - - // Wait for scheduler to be registered with Nexus - waitForRegister() - - try { - // Acquire the runTasks semaphore - semaphore.acquire() - - val myCallbacks = new Callbacks { - val firstTaskId = nextTaskId - var tasksLaunched = 0 - var tasksFinished = 0 - var lastPreferredLaunchTime = System.currentTimeMillis - - def slotOffer(slot: SlaveOffer): Option[TaskDescription] = { - try { - if (tasksLaunched < tasks.length) { - // TODO: Add a short wait if no task with location pref is found - // TODO: Figure out why a function is needed around this to - // avoid scala.runtime.NonLocalReturnException - def findTask: Option[TaskDescription] = { - var checkPrefVals: Array[Boolean] = Array(true) - val time = System.currentTimeMillis - if (time - lastPreferredLaunchTime > LOCALITY_WAIT) - checkPrefVals = Array(true, false) // Allow non-preferred tasks - // TODO: Make desiredCpus and desiredMem configurable - val desiredCpus = 1 - val desiredMem = 750L * 1024L * 1024L - if (slot.getParams.get("cpus").toInt < desiredCpus || - slot.getParams.get("mem").toLong < desiredMem) - return None - for (checkPref <- checkPrefVals; - i <- 0 until tasks.length; - if !launched(i) && (!checkPref || tasks(i).prefers(slot))) - { - val taskId = nextTaskId - nextTaskId += 1 - printf("Starting task %d as TID %d on slave %d: %s (%s)\n", - i, taskId, slot.getSlaveId, slot.getHost, - if(checkPref) "preferred" else "non-preferred") - tasks(i).markStarted(slot) - launched(i) = true - tasksLaunched += 1 - if (checkPref) - lastPreferredLaunchTime = time - val params = new StringMap - params.set("cpus", "" + desiredCpus) - 
params.set("mem", "" + desiredMem) - val serializedTask = Utils.serialize(tasks(i)) - return Some(new TaskDescription(taskId, slot.getSlaveId, - "task_" + taskId, params, serializedTask)) - } - return None - } - return findTask - } else { - return None - } - } catch { - case e: Exception => { - e.printStackTrace - System.exit(1) - return None - } - } + try { + this.synchronized { + this.activeOp = myOp } - - def taskFinished(status: TaskStatus) { - println("Finished TID " + status.getTaskId) - // Deserialize task result - val result = Utils.deserialize[TaskResult[T]](status.getData) - results(status.getTaskId - firstTaskId) = result.value - // Update accumulators - Accumulators.add(callingThread, result.accumUpdates) - // Stop if we've finished all the tasks - tasksFinished += 1 - if (tasksFinished == tasks.length) { - NexusScheduler.this.callbacks = null - NexusScheduler.this.notifyAll() - } - } - - def error(code: Int, message: String) { - // Save the error message - errorHappened = true - errorCode = code - errorMessage = message - // Indicate to caller thread that we're done - NexusScheduler.this.callbacks = null - NexusScheduler.this.notifyAll() + driver.reviveOffers(); + myOp.join(); + } finally { + this.synchronized { + this.activeOp = null } } - this.synchronized { - this.callbacks = myCallbacks - } - driver.reviveOffers(); - this.synchronized { - while (this.callbacks != null) this.wait() - } - } finally { - semaphore.release() + if (myOp.errorHappened) + throw new SparkException(myOp.errorMessage, myOp.errorCode) + else + return myOp.results } - - if (errorHappened) - throw new SparkException(errorMessage, errorCode) - else - return results } override def registered(d: SchedulerDriver, frameworkId: Int) { @@ -197,18 +94,19 @@ extends nexus.Scheduler with spark.Scheduler override def waitForRegister() { registeredLock.synchronized { - while (!isRegistered) registeredLock.wait() + while (!isRegistered) + registeredLock.wait() } } override def resourceOffer( - d: SchedulerDriver, oid: Long, slots: SlaveOfferVector) { + d: SchedulerDriver, oid: Long, offers: SlaveOfferVector) { synchronized { val tasks = new TaskDescriptionVector - if (callbacks != null) { + if (activeOp != null) { try { - for (i <- 0 until slots.size.toInt) { - callbacks.slotOffer(slots.get(i)) match { + for (i <- 0 until offers.size.toInt) { + activeOp.slaveOffer(offers.get(i)) match { case Some(task) => tasks.add(task) case None => {} } @@ -225,21 +123,21 @@ extends nexus.Scheduler with spark.Scheduler override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { synchronized { - if (callbacks != null && status.getState == TaskState.TASK_FINISHED) { - try { - callbacks.taskFinished(status) - } catch { - case e: Exception => e.printStackTrace + try { + if (activeOp != null) { + activeOp.statusUpdate(status) } + } catch { + case e: Exception => e.printStackTrace } } } override def error(d: SchedulerDriver, code: Int, message: String) { synchronized { - if (callbacks != null) { + if (activeOp != null) { try { - callbacks.error(code, message) + activeOp.error(code, message) } catch { case e: Exception => e.printStackTrace } @@ -256,3 +154,137 @@ extends nexus.Scheduler with spark.Scheduler driver.stop() } } + + +// Trait representing an object that manages a parallel operation by +// implementing various scheduler callbacks. 
+trait ParallelOperation { + def slaveOffer(s: SlaveOffer): Option[TaskDescription] + def statusUpdate(t: TaskStatus): Unit + def error(code: Int, message: String): Unit +} + + +class SimpleParallelOperation[T: ClassManifest]( + sched: NexusScheduler, tasks: Array[Task[T]]) +extends ParallelOperation +{ + // Maximum time to wait to run a task in a preferred location (in ms) + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong + + val callingThread = currentThread + val numTasks = tasks.length + val results = new Array[T](numTasks) + val launched = new Array[Boolean](numTasks) + val finished = new Array[Boolean](numTasks) + val tidToIndex = Map[Int, Int]() + + var allFinished = false + val joinLock = new Object() + + var errorHappened = false + var errorCode = 0 + var errorMessage = "" + + var tasksLaunched = 0 + var tasksFinished = 0 + var lastPreferredLaunchTime = System.currentTimeMillis + + def setAllFinished() { + joinLock.synchronized { + allFinished = true + joinLock.notifyAll() + } + } + + def join() { + joinLock.synchronized { + while (!allFinished) + joinLock.wait() + } + } + + def slaveOffer(offer: SlaveOffer): Option[TaskDescription] = { + if (tasksLaunched < numTasks) { + var checkPrefVals: Array[Boolean] = Array(true) + val time = System.currentTimeMillis + if (time - lastPreferredLaunchTime > LOCALITY_WAIT) + checkPrefVals = Array(true, false) // Allow non-preferred tasks + // TODO: Make desiredCpus and desiredMem configurable + val desiredCpus = 1 + val desiredMem = 750L * 1024L * 1024L + if (offer.getParams.get("cpus").toInt < desiredCpus || + offer.getParams.get("mem").toLong < desiredMem) + return None + for (checkPref <- checkPrefVals; i <- 0 until numTasks) { + if (!launched(i) && (!checkPref || + tasks(i).preferredLocations.contains(offer.getHost) || + tasks(i).preferredLocations.isEmpty)) + { + val taskId = sched.newTaskId() + tidToIndex(taskId) = i + printf("Starting task %d as TID %d on slave %d: %s (%s)\n", + i, taskId, offer.getSlaveId, offer.getHost, + if(checkPref) "preferred" else "non-preferred") + tasks(i).markStarted(offer) + launched(i) = true + tasksLaunched += 1 + if (checkPref) + lastPreferredLaunchTime = time + val params = new StringMap + params.set("cpus", "" + desiredCpus) + params.set("mem", "" + desiredMem) + val serializedTask = Utils.serialize(tasks(i)) + return Some(new TaskDescription(taskId, offer.getSlaveId, + "task_" + taskId, params, serializedTask)) + } + } + } + return None + } + + def statusUpdate(status: TaskStatus) { + status.getState match { + case TaskState.TASK_FINISHED => + taskFinished(status) + case TaskState.TASK_LOST => + taskLost(status) + case TaskState.TASK_FAILED => + taskLost(status) + case TaskState.TASK_KILLED => + taskLost(status) + case _ => + } + } + + def taskFinished(status: TaskStatus) { + val tid = status.getTaskId + println("Finished TID " + tid) + // Deserialize task result + val result = Utils.deserialize[TaskResult[T]](status.getData) + results(tidToIndex(tid)) = result.value + // Update accumulators + Accumulators.add(callingThread, result.accumUpdates) + // Mark finished and stop if we've finished all the tasks + finished(tidToIndex(tid)) = true + tasksFinished += 1 + if (tasksFinished == numTasks) + setAllFinished() + } + + def taskLost(status: TaskStatus) { + val tid = status.getTaskId + println("Lost TID " + tid) + launched(tidToIndex(tid)) = false + tasksLaunched -= 1 + } + + def error(code: Int, message: String) { + // Save the error message + errorHappened = true + errorCode = 
code + errorMessage = message + // Indicate to caller thread that we're done + setAllFinished() + } +} diff --git a/src/scala/spark/ParallelArray.scala b/src/scala/spark/ParallelArray.scala index 86bcff6e20..39ca867cb9 100644 --- a/src/scala/spark/ParallelArray.scala +++ b/src/scala/spark/ParallelArray.scala @@ -38,7 +38,7 @@ extends RDD[T, ParallelArraySplit[T]](sc) { override def iterator(s: ParallelArraySplit[T]) = s.iterator - override def prefers(s: ParallelArraySplit[T], offer: SlaveOffer) = true + override def preferredLocations(s: ParallelArraySplit[T]): Seq[String] = Nil } private object ParallelArray { diff --git a/src/scala/spark/RDD.scala b/src/scala/spark/RDD.scala index 6c30636ed9..f9a16ed782 100644 --- a/src/scala/spark/RDD.scala +++ b/src/scala/spark/RDD.scala @@ -16,7 +16,7 @@ abstract class RDD[T: ClassManifest, Split]( @transient sc: SparkContext) { def splits: Array[Split] def iterator(split: Split): Iterator[T] - def prefers(split: Split, slot: SlaveOffer): Boolean + def preferredLocations(split: Split): Seq[String] def taskStarted(split: Split, slot: SlaveOffer) {} @@ -83,7 +83,7 @@ abstract class RDD[T: ClassManifest, Split]( abstract class RDDTask[U: ClassManifest, T: ClassManifest, Split]( val rdd: RDD[T, Split], val split: Split) extends Task[U] { - override def prefers(slot: SlaveOffer) = rdd.prefers(split, slot) + override def preferredLocations() = rdd.preferredLocations(split) override def markStarted(slot: SlaveOffer) { rdd.taskStarted(split, slot) } } @@ -122,7 +122,7 @@ class MappedRDD[U: ClassManifest, T: ClassManifest, Split]( prev: RDD[T, Split], f: T => U) extends RDD[U, Split](prev.sparkContext) { override def splits = prev.splits - override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot) + override def preferredLocations(split: Split) = prev.preferredLocations(split) override def iterator(split: Split) = prev.iterator(split).map(f) override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot) } @@ -131,7 +131,7 @@ class FilteredRDD[T: ClassManifest, Split]( prev: RDD[T, Split], f: T => Boolean) extends RDD[T, Split](prev.sparkContext) { override def splits = prev.splits - override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot) + override def preferredLocations(split: Split) = prev.preferredLocations(split) override def iterator(split: Split) = prev.iterator(split).filter(f) override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot) } @@ -140,15 +140,15 @@ class CachedRDD[T, Split]( prev: RDD[T, Split])(implicit m: ClassManifest[T]) extends RDD[T, Split](prev.sparkContext) { val id = CachedRDD.newId() - @transient val cacheLocs = Map[Split, List[Int]]() + @transient val cacheLocs = Map[Split, List[String]]() override def splits = prev.splits - override def prefers(split: Split, slot: SlaveOffer): Boolean = { + override def preferredLocations(split: Split) = { if (cacheLocs.contains(split)) - cacheLocs(split).contains(slot.getSlaveId) + cacheLocs(split) else - prev.prefers(split, slot) + prev.preferredLocations(split) } override def iterator(split: Split): Iterator[T] = { @@ -185,9 +185,9 @@ extends RDD[T, Split](prev.sparkContext) { override def taskStarted(split: Split, slot: SlaveOffer) { val oldList = cacheLocs.getOrElse(split, Nil) - val slaveId = slot.getSlaveId - if (!oldList.contains(slaveId)) - cacheLocs(split) = slaveId :: oldList + val host = slot.getHost + if (!oldList.contains(host)) + cacheLocs(split) = host :: oldList } } @@ -205,7 
+205,7 @@ private object CachedRDD { @serializable abstract class UnionSplit[T: ClassManifest] { def iterator(): Iterator[T] - def prefers(offer: SlaveOffer): Boolean + def preferredLocations(): Seq[String] } @serializable @@ -213,7 +213,7 @@ class UnionSplitImpl[T: ClassManifest, Split]( rdd: RDD[T, Split], split: Split) extends UnionSplit[T] { override def iterator() = rdd.iterator(split) - override def prefers(offer: SlaveOffer) = rdd.prefers(split, offer) + override def preferredLocations() = rdd.preferredLocations(split) } @serializable @@ -231,5 +231,6 @@ extends RDD[T, UnionSplit[T]](sc) { override def iterator(s: UnionSplit[T]): Iterator[T] = s.iterator() - override def prefers(s: UnionSplit[T], o: SlaveOffer) = s.prefers(o) + override def preferredLocations(s: UnionSplit[T]): Seq[String] = + s.preferredLocations() } diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala index 82456b43ab..6b5a07cff1 100644 --- a/src/scala/spark/SparkContext.scala +++ b/src/scala/spark/SparkContext.scala @@ -6,7 +6,7 @@ import java.util.UUID import scala.collection.mutable.ArrayBuffer class SparkContext(master: String, frameworkName: String) { - Cache.initialize() + Broadcast.initialize(true) def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int) = new ParallelArray[T](this, seq, numSlices) @@ -18,7 +18,8 @@ class SparkContext(master: String, frameworkName: String) { new Accumulator(initialValue, param) // TODO: Keep around a weak hash map of values to Cached versions? - def broadcast[T](value: T) = new Cached(value, local) + def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, local) + //def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, local) def textFile(path: String) = new HdfsTextFile(this, path) diff --git a/src/scala/spark/Task.scala b/src/scala/spark/Task.scala index e559996a37..efb864472d 100644 --- a/src/scala/spark/Task.scala +++ b/src/scala/spark/Task.scala @@ -5,8 +5,8 @@ import nexus._ @serializable trait Task[T] { def run: T - def prefers(slot: SlaveOffer): Boolean = true - def markStarted(slot: SlaveOffer) {} + def preferredLocations: Seq[String] = Nil + def markStarted(offer: SlaveOffer) {} } @serializable
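
For orientation, here is a minimal driver sketch of how the pieces introduced by this diff fit together: SparkContext.broadcast now returns a CentralizedHDFSBroadcast (the ChainedStreamingBroadcast line is still commented out), and BroadcastCH reads its configuration from the spark.dfs*, spark.compress and spark.buffer.size system properties. Everything in the sketch that is not named in the diff (the BroadcastDemo object, the sample data, the property values) is an illustrative assumption, not part of the commit.

// Hedged sketch, not part of the commit: exercises the new broadcast path end to end.
import spark.SparkContext

object BroadcastDemo {
  def main(args: Array[String]) {
    // These must be set before creating the SparkContext, whose constructor
    // calls Broadcast.initialize(true). With the default "file:///" setting,
    // each broadcast is persisted under spark.dfs.workdir as "broadcast-<uuid>".
    System.setProperty("spark.dfs", "file:///")
    System.setProperty("spark.dfs.workdir", "/tmp")
    System.setProperty("spark.compress", "false")

    val spark = new SparkContext(args(0), "Broadcast Demo")
    val words = spark.broadcast(Array("one", "two", "three"))
    println("created " + words)   // prints spark.Broadcast(<uuid>)

    // Tasks capture only the lightweight wrapper (value_ is @transient); the
    // first task to deserialize it on a worker reads the persisted copy, and
    // later tasks hit the soft-value cache in BroadcastCH.values.
    spark.parallelize(1 to 3, 3).foreach { i =>
      println("task " + i + " sees " + words.value.length + " words")
    }
  }
}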