зеркало из https://github.com/microsoft/spark.git
Start fetching a remote block when a received remote block has been passed
to the reduce function
This commit is contained in:
Родитель
389fb4cc54
Коммит
3076b038f4
|
@ -2,16 +2,14 @@ package spark.storage
|
||||||
|
|
||||||
import java.io._
|
import java.io._
|
||||||
import java.nio._
|
import java.nio._
|
||||||
import java.nio.channels.FileChannel.MapMode
|
|
||||||
import java.util.{HashMap => JHashMap}
|
|
||||||
import java.util.LinkedHashMap
|
|
||||||
import java.util.concurrent.ConcurrentHashMap
|
import java.util.concurrent.ConcurrentHashMap
|
||||||
import java.util.concurrent.LinkedBlockingQueue
|
import java.util.concurrent.LinkedBlockingQueue
|
||||||
import java.util.Collections
|
import java.util.Collections
|
||||||
|
|
||||||
import akka.dispatch.{Await, Future}
|
import akka.dispatch.{Await, Future}
|
||||||
import scala.collection.mutable.ArrayBuffer
|
import scala.collection.mutable.ArrayBuffer
|
||||||
import scala.collection.mutable.HashMap
|
import scala.collection.mutable.{HashMap, HashSet}
|
||||||
|
import scala.collection.mutable.Queue
|
||||||
import scala.collection.JavaConversions._
|
import scala.collection.JavaConversions._
|
||||||
|
|
||||||
import it.unimi.dsi.fastutil.io._
|
import it.unimi.dsi.fastutil.io._
|
||||||
|
@ -273,28 +271,19 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
|
||||||
logDebug("Getting " + totalBlocks + " blocks")
|
logDebug("Getting " + totalBlocks + " blocks")
|
||||||
var startTime = System.currentTimeMillis
|
var startTime = System.currentTimeMillis
|
||||||
val localBlockIds = new ArrayBuffer[String]()
|
val localBlockIds = new ArrayBuffer[String]()
|
||||||
val remoteBlockIds = new ArrayBuffer[String]()
|
val remoteBlockIds = new HashSet[String]()
|
||||||
val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]()
|
|
||||||
|
|
||||||
// A queue to hold our results. Because we want all the deserializing the happen in the
|
// A queue to hold our results. Because we want all the deserializing the happen in the
|
||||||
// caller's thread, this will actually hold functions to produce the Iterator for each block.
|
// caller's thread, this will actually hold functions to produce the Iterator for each block.
|
||||||
// For local blocks we'll have an iterator already, while for remote ones we'll deserialize.
|
// For local blocks we'll have an iterator already, while for remote ones we'll deserialize.
|
||||||
val results = new LinkedBlockingQueue[(String, Option[() => Iterator[Any]])]
|
val results = new LinkedBlockingQueue[(String, Option[() => Iterator[Any]])]
|
||||||
|
|
||||||
// Split local and remote blocks
|
// Bound the number and memory usage of fetched remote blocks.
|
||||||
for ((address, blockIds) <- blocksByAddress) {
|
val parallelFetches = BlockManager.getNumParallelFetchesFromSystemProperties
|
||||||
if (address == blockManagerId) {
|
val blocksToRequest = new Queue[(BlockManagerId, BlockMessage)]
|
||||||
localBlockIds ++= blockIds
|
|
||||||
} else {
|
def sendRequest(bmId: BlockManagerId, blockMessages: Seq[BlockMessage]) {
|
||||||
remoteBlockIds ++= blockIds
|
val cmId = new ConnectionManagerId(bmId.ip, bmId.port)
|
||||||
remoteBlockIdsPerLocation(address) = blockIds
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start getting remote blocks
|
|
||||||
for ((bmId, bIds) <- remoteBlockIdsPerLocation) {
|
|
||||||
val cmId = ConnectionManagerId(bmId.ip, bmId.port)
|
|
||||||
val blockMessages = bIds.map(bId => BlockMessage.fromGetBlock(GetBlock(bId)))
|
|
||||||
val blockMessageArray = new BlockMessageArray(blockMessages)
|
val blockMessageArray = new BlockMessageArray(blockMessages)
|
||||||
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
|
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
|
||||||
future.onSuccess {
|
future.onSuccess {
|
||||||
|
@ -312,17 +301,43 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
case None => {
|
case None => {
|
||||||
logError("Could not get blocks from " + cmId)
|
logError("Could not get block(s) from " + cmId)
|
||||||
for (blockId <- bIds) {
|
for (blockMessage <- blockMessages) {
|
||||||
results.put((blockId, None))
|
results.put((blockMessage.getId, None))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " +
|
|
||||||
|
// Split local and remote blocks. Remote blocks are further split into ones that will
|
||||||
|
// be requested initially and ones that will be added to a queue of blocks to request.
|
||||||
|
val initialRequestBlocks = new HashMap[BlockManagerId, ArrayBuffer[BlockMessage]]()
|
||||||
|
var initialRequests = 0
|
||||||
|
for ((address, blockIds) <- blocksByAddress) {
|
||||||
|
if (address == blockManagerId) {
|
||||||
|
localBlockIds ++= blockIds
|
||||||
|
} else {
|
||||||
|
remoteBlockIds ++= blockIds
|
||||||
|
blockIds.foreach{blockId =>
|
||||||
|
val blockMessage = BlockMessage.fromGetBlock(GetBlock(blockId))
|
||||||
|
if (initialRequests < parallelFetches) {
|
||||||
|
initialRequestBlocks.getOrElseUpdate(address, new ArrayBuffer[BlockMessage])
|
||||||
|
.append(blockMessage)
|
||||||
|
initialRequests += 1
|
||||||
|
} else {
|
||||||
|
blocksToRequest.enqueue((address, blockMessage))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send out initial request(s) for 'parallelFetches' blocks.
|
||||||
|
for ((bmId, blockMessages) <- initialRequestBlocks) { sendRequest(bmId, blockMessages) }
|
||||||
|
|
||||||
|
logDebug("Started remote gets for " + parallelFetches + " blocks in " +
|
||||||
Utils.getUsedTimeMs(startTime) + " ms")
|
Utils.getUsedTimeMs(startTime) + " ms")
|
||||||
|
|
||||||
// Get the local blocks while remote blocks are being fetched
|
// Get the local blocks while remote blocks are being fetched.
|
||||||
startTime = System.currentTimeMillis
|
startTime = System.currentTimeMillis
|
||||||
localBlockIds.foreach(id => {
|
localBlockIds.foreach(id => {
|
||||||
getLocal(id) match {
|
getLocal(id) match {
|
||||||
|
@ -337,7 +352,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
|
||||||
})
|
})
|
||||||
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
|
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
|
||||||
|
|
||||||
// Return an iterator that will read fetched blocks off the queue as they arrive
|
// Return an iterator that will read fetched blocks off the queue as they arrive.
|
||||||
return new Iterator[(String, Option[Iterator[Any]])] {
|
return new Iterator[(String, Option[Iterator[Any]])] {
|
||||||
var resultsGotten = 0
|
var resultsGotten = 0
|
||||||
|
|
||||||
|
@ -346,6 +361,10 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
|
||||||
def next(): (String, Option[Iterator[Any]]) = {
|
def next(): (String, Option[Iterator[Any]]) = {
|
||||||
resultsGotten += 1
|
resultsGotten += 1
|
||||||
val (blockId, functionOption) = results.take()
|
val (blockId, functionOption) = results.take()
|
||||||
|
if (remoteBlockIds.contains(blockId) && !blocksToRequest.isEmpty) {
|
||||||
|
val (bmId, blockMessage) = blocksToRequest.dequeue
|
||||||
|
sendRequest(bmId, Seq(blockMessage))
|
||||||
|
}
|
||||||
(blockId, functionOption.map(_.apply()))
|
(blockId, functionOption.map(_.apply()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -598,6 +617,11 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
|
||||||
}
|
}
|
||||||
|
|
||||||
object BlockManager {
|
object BlockManager {
|
||||||
|
|
||||||
|
def getNumParallelFetchesFromSystemProperties(): Int = {
|
||||||
|
System.getProperty("spark.blockManager.parallelFetches", "8").toInt
|
||||||
|
}
|
||||||
|
|
||||||
def getMaxMemoryFromSystemProperties(): Long = {
|
def getMaxMemoryFromSystemProperties(): Long = {
|
||||||
val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
|
val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
|
||||||
(Runtime.getRuntime.maxMemory * memoryFraction).toLong
|
(Runtime.getRuntime.maxMemory * memoryFraction).toLong
|
||||||
|
|
Загрузка…
Ссылка в новой задаче