In the previous article we analyzed the write side of Shuffle; in this article we continue with the read side.
Let's start with the compute method of ShuffledRDD:
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[(K, C)]]
}
As you can see, it first calls ShuffleManager's getReader method to obtain a ShuffleReader, and then calls that reader's read method to read the intermediate data written by the map stage. Whether the manager is HashShuffleManager or SortShuffleManager, its getReader method simply instantiates a BlockStoreShuffleReader, which is the class that implements the ShuffleReader interface:
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
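For context, this read path runs whenever a reduce task computes a partition of a ShuffledRDD, which is what any wide operation such as reduceByKey produces. A minimal illustrative driver snippet (assuming an existing SparkContext named sc; this is an example, not Spark source):
// The collect() below launches reduce tasks; each of them runs
// ShuffledRDD.compute -> ShuffleManager.getReader -> BlockStoreShuffleReader.read
val words = sc.parallelize(Seq("a", "b", "a", "c", "b", "a"))
val counts = words.map(w => (w, 1)).reduceByKey(_ + _)   // creates a ShuffledRDD
counts.collect().foreach(println)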
Now let's look at how BlockStoreShuffleReader's read method actually works, i.e. how it reads the intermediate results produced by the map stage:
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
// First instantiate a ShuffleBlockFetcherIterator
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
// Wrap the streams for compression based on configuration
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}
val ser = Serializer.getSerializer(dep.serializer)
val serializerInstance = ser.newInstance()
// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map(record => {
readMetrics.incRecordsRead(1)
record
}),
context.taskMetrics().updateShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
}
The read method first instantiates a ShuffleBlockFetcherIterator, passing in several arguments; the important ones are:
- blockManager.shuffleClient
- mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
- SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024
shuffleClient is the client used to read shuffle files on other executors; it is either an ExternalShuffleClient or a BlockTransferService:
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores)
new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
securityManager.isSaslEncryptionEnabled())
} else {
blockTransferService
}
By default a BlockTransferService is used, because externalShuffleServiceEnabled defaults to false:
private[spark]
val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
Next, mapOutputTracker.getMapSizesByExecutorId returns the metadata describing where this reduce task's input data lives. It takes the shuffle ID plus the start and end partition, and returns a Seq[(BlockManagerId, Seq[(BlockId, Long)])], i.e. which blocks on which nodes hold the data, together with each block's size.
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
// Obtain the metadata for the intermediate results produced by the map stage
val statuses = getStatuses(shuffleId)
// Synchronize on the returned array because, on the driver, it gets mutated in place
// Convert the metadata into location information of the form Seq[(BlockManagerId, Seq[(BlockId, Long)])], used to read the specified map outputs
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
}
}
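To make the shape of the returned value concrete, here is a hedged sketch that hand-builds a value of the same type and walks over it (the executor IDs, hosts, ports and sizes are invented for illustration; this is not Spark code):
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

// A hand-built example with the same shape as getMapSizesByExecutorId's result
val mapSizes: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = Seq(
  BlockManagerId("exec-1", "host-a", 7337) -> Seq(ShuffleBlockId(0, 0, 2) -> 1024L),
  BlockManagerId("exec-2", "host-b", 7337) -> Seq(ShuffleBlockId(0, 1, 2) -> 2048L))
for ((bmId, blocks) <- mapSizes) {
  val totalBytes = blocks.map(_._2).sum   // sum of block sizes on this executor
  println(s"executor ${bmId.executorId} at ${bmId.host}:${bmId.port} holds ${blocks.size} blocks, $totalBytes bytes")
}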
Finally, getSizeAsMb reads the configuration entry spark.reducer.maxSizeInFlight, which caps how much map output data may be fetched (kept in flight) at one time.
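Both of the knobs just mentioned are ordinary Spark configuration entries; a hedged example of setting them on a SparkConf (the values are purely illustrative, not recommendations):
import org.apache.spark.SparkConf
val conf = new SparkConf()
  .set("spark.reducer.maxSizeInFlight", "96m")    // max map output data fetched concurrently per reduce task
  .set("spark.shuffle.service.enabled", "false")  // default: use BlockTransferService rather than the external shuffle service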
Fetching the metadata
Let's now focus on how getStatuses obtains that metadata:
private def getStatuses(shuffleId: Int): Array[MapStatus] = {
// Look up the Array[MapStatus] for this shuffleId
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
// Not found locally, so fetch it from the tracker
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
val startTime = System.currentTimeMillis
// Holds the fetched MapStatus array
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
// Another task may already be fetching it, so synchronize on the fetching set
// Someone else is fetching it; wait for them to be done
while (fetching.contains(shuffleId)) {
try {
fetching.wait()
} catch {
case e: InterruptedException =>
}
}
// Either while we waited the fetch happened successfully, or
// someone fetched it in between the get and the fetching.synchronized.
// After waiting, try the local map again
fetchedStatuses = mapStatuses.get(shuffleId).orNull
if (fetchedStatuses == null) {
// We have to do the fetch, get others to wait for us.
fetching += shuffleId
}
}
if (fetchedStatuses == null) {
// If we are the one responsible for fetching, do the fetch
// We won the race to fetch the statuses; do so
logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
// This try-finally prevents hangs due to timeouts:
try {
// Send a message of the form GetMapOutputStatuses(shuffleId) via askTracker
val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
// Deserialize the serialized bytes that came back
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
// Cache the result in the local mapStatuses map
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
}
logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
s"${System.currentTimeMillis - startTime} ms")
if (fetchedStatuses != null) {
// Finally return the fetched metadata
return fetchedStatuses
} else {
logError("Missing all output locations for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
}
} else {
// Already present locally, so return it directly
return statuses
}
}
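The wait/notifyAll dance around the fetching set is a general pattern: only one thread fetches a given key, while the others wait for it and then reuse the cached result. A simplified standalone sketch of the same idea (hypothetical names, not Spark code):
import scala.collection.mutable

object OncePerKeyFetcher {
  private val cache = new mutable.HashMap[Int, Array[Byte]]
  private val fetching = new mutable.HashSet[Int]

  def get(id: Int)(fetch: Int => Array[Byte]): Array[Byte] = {
    var cached: Option[Array[Byte]] = None
    fetching.synchronized {
      // Wait while some other thread is already fetching this id
      while (fetching.contains(id)) fetching.wait()
      cached = cache.get(id)
      if (cached.isEmpty) fetching += id   // we won the race, so we will do the fetch
    }
    cached.getOrElse {
      try {
        val bytes = fetch(id)
        fetching.synchronized { cache(id) = bytes }
        bytes
      } finally {
        fetching.synchronized { fetching -= id; fetching.notifyAll() }
      }
    }
  }
}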
Here is the askTracker method used to send that message; the message itself is a case class, GetMapOutputStatuses(shuffleId):
protected def askTracker[T: ClassTag](message: Any): T = {
try {
trackerEndpoint.askWithRetry[T](message)
} catch {
case e: Exception =>
logError("Error communicating with MapOutputTracker", e)
throw new SparkException("Error communicating with MapOutputTracker", e)
}
}
And here is how MapOutputTrackerMasterEndpoint handles the message when it receives it:
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = context.senderAddress.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
// Obtain the serialized metadata for the map stage's output
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
// Size of the serialized payload
val serializedSize = mapOutputStatuses.length
// Check whether it exceeds the maxAkkaFrameSize limit
if (serializedSize > maxAkkaFrameSize) {
val msg = s"Map output statuses were $serializedSize bytes which " +
s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
/* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender.
* A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */
val exception = new SparkException(msg)
logError(msg, exception)
context.sendFailure(exception)
} else {
// Within the limit, so reply with the metadata
context.reply(mapOutputStatuses)
}
case StopMapOutputTracker =>
logInfo("MapOutputTrackerMasterEndpoint stopped!")
context.reply(true)
stop()
}
Next, the getSerializedMapOutputStatuses method of the tracker (a MapOutputTrackerMaster):
def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
if (epoch > cacheEpoch) {
cachedSerializedStatuses.clear()
cacheEpoch = epoch
}
// Check whether a cached serialized copy already exists
cachedSerializedStatuses.get(shuffleId) match {
case Some(bytes) =>
// Cache hit: return it directly
return bytes
case None =>
// Cache miss: take a snapshot from mapStatuses
statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
epochGotten = epoch
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "statuses"; let's serialize and return that
// Serialize the statuses
val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
if (epoch == epochGotten) {
// Cache the serialized bytes
cachedSerializedStatuses(shuffleId) = bytes
}
}
// Return the serialized bytes
bytes
}
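The epoch/cacheEpoch fields implement a simple invalidation scheme: the epoch is bumped whenever map outputs change, the whole cache is dropped when the epoch has moved, and a freshly serialized result is only cached if the epoch did not change while it was being computed. A minimal hedged sketch of that idea (hypothetical class, not Spark code):
import scala.collection.mutable

class EpochCache[K, V](compute: K => V) {
  private val lock = new Object
  private var epoch = 0L
  private var cacheEpoch = 0L
  private val cached = new mutable.HashMap[K, V]

  def invalidate(): Unit = lock.synchronized { epoch += 1 }   // e.g. when map outputs change

  def get(key: K): V = {
    var hit: Option[V] = None
    var epochSeen = 0L
    lock.synchronized {
      if (epoch > cacheEpoch) { cached.clear(); cacheEpoch = epoch }   // drop stale entries
      hit = cached.get(key)
      epochSeen = epoch
    }
    hit.getOrElse {
      val value = compute(key)                        // possibly expensive, done outside the lock
      lock.synchronized {
        if (epoch == epochSeen) cached(key) = value   // only cache if nothing changed meanwhile
      }
      value
    }
  }
}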
Back in getStatuses, the serialized bytes are deserialized and stored in this executor's local mapStatuses, so the next use does not have to fetch again. Finally, the metadata is converted into location information of the form Seq[(BlockManagerId, Seq[(BlockId, Long)])], which is used to read the data produced by the specified map tasks.
Fetching the data described by the metadata (remote vs. local)
With those arguments covered, let's return to the construction of ShuffleBlockFetcherIterator. When it is instantiated it runs an initialize() method that performs a series of initialization steps:
private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
// Whether the task ultimately succeeds or fails, cleanup must run
context.addTaskCompletionListener(_ => cleanup())
// Split local and remote blocks.
// Separate local and remote blocks; the remote ones come back as remoteRequests
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
// fetchRequests is a queue; the remote requests are added to it in random order, and
// fetchUpToMaxBytes below dequeues them while enforcing the size limit
fetchRequests ++= Utils.randomize(remoteRequests)
// Send out initial requests for blocks, up to our maxBytesInFlight
// Dequeue remote requests from fetchRequests and send them via sendRequest
fetchUpToMaxBytes()
val numFetches = remoteRequests.size - fetchRequests.size
logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
// Get Local Blocks
// Fetch the local blocks
fetchLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
First, let's see how splitLocalRemoteBlocks separates the remote blocks from the local ones:
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
// Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
// Keeping each request at maxBytesInFlight / 5 allows fetching from up to 5 nodes in parallel while staying under maxBytesInFlight
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
// Tracks total number of blocks (including zero sized blocks)
var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
totalBlocks += blockInfos.size
// Decide whether the blocks to fetch are local or remote
if (address.executorId == blockManager.blockManagerId.executorId) {
// Filter out zero-sized blocks
// Filter out zero-sized blocks
localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
numBlocksToFetch += localBlocks.size
} else {
// Remote case: the main job is to build remoteRequests
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = new ArrayBuffer[(BlockId, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
// Skip empty blocks
if (size > 0) {
curBlocks += ((blockId, size))
remoteBlocks += blockId
numBlocksToFetch += 1
curRequestSize += size
} else if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
}
// Once the size threshold is reached, wrap the accumulated blocks in a FetchRequest and add it to remoteRequests
if (curRequestSize >= targetRequestSize) {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
curBlocks = new ArrayBuffer[(BlockId, Long)]
logDebug(s"Creating fetch request of $curRequestSize at $address")
curRequestSize = 0
}
}
// Add in the final request
// Put any remaining blocks into a final FetchRequest
if (curBlocks.nonEmpty) {
remoteRequests += new FetchRequest(address, curBlocks)
}
}
}
logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
// Finally return remoteRequests
remoteRequests
}
FetchRequest is a small data structure holding the location information of the blocks to fetch, and remoteRequests is simply an ArrayBuffer of these FetchRequests:
case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
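The grouping logic above packs blocks into requests of roughly targetRequestSize (maxBytesInFlight / 5) each. On its own, the packing step looks like this simplified, hypothetical sketch (not Spark code):
// Pack (id, size) pairs into batches whose accumulated size reaches the target,
// mirroring how splitLocalRemoteBlocks builds FetchRequests; empty blocks are skipped.
def packIntoBatches(blocks: Seq[(String, Long)], targetSize: Long): Seq[Seq[(String, Long)]] = {
  val batches = scala.collection.mutable.ArrayBuffer[Seq[(String, Long)]]()
  var current = scala.collection.mutable.ArrayBuffer[(String, Long)]()
  var currentSize = 0L
  for ((id, size) <- blocks if size > 0) {
    current += ((id, size))
    currentSize += size
    if (currentSize >= targetSize) {        // batch is large enough, start a new one
      batches += current.toSeq
      current = scala.collection.mutable.ArrayBuffer[(String, Long)]()
      currentSize = 0L
    }
  }
  if (current.nonEmpty) batches += current.toSeq
  batches.toSeq
}
// e.g. packIntoBatches(Seq(("b1", 10L), ("b2", 30L), ("b3", 5L)), targetSize = 20L)
//      => Seq(Seq(("b1", 10), ("b2", 30)), Seq(("b3", 5)))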
fetchUpToMaxBytes() is then used to fetch the remote blocks:
private def fetchUpToMaxBytes(): Unit = {
// Send fetch requests up to maxBytesInFlight
while (fetchRequests.nonEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
}
Internally it just dequeues FetchRequests built in the previous step and sends each one via sendRequest:
private[this] def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
bytesInFlight += req.size
// 1. First collect the information about the blocks to fetch
// so we can look up the size of each blockID
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val blockIds = req.blocks.map(_._1.toString)
val address = req.address
// 2. Then fetch the data on the remote node via shuffleClient's fetchBlocks method
// (by default this is implemented by NettyBlockTransferService.fetchBlocks)
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
new BlockFetchingListener {
// 3. Finally, whether the fetch succeeds or fails, the result is put into results
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
// Only add the buffer to results queue if the iterator is not zombie,
// i.e. cleanup() has not been called yet.
if (!isZombie) {
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
}
logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
results.put(new FailureFetchResult(BlockId(blockId), address, e))
}
}
)
}
The results of the requests end up in results: a SuccessFetchResult on success, a FailureFetchResult on failure. How fetchBlocks works internally is not covered here (the diagram at the end of this article gives a brief overview); interested readers can dig further. In short, it is implemented on top of NettyBlockTransferService, which uses the index file to locate the data file.
Next, let's look at how fetchLocalBlocks() retrieves the local blocks:
private[this] def fetchLocalBlocks() {
val iter = localBlocks.iterator
while (iter.hasNext) {
val blockId = iter.next()
try {
val buf = blockManager.getBlockData(blockId)
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
return
}
}
}
This part is comparatively simple: it iterates over the local blocks, putting a SuccessFetchResult into results when a block is read and a FailureFetchResult when it is not. That completes the instantiation and initialization of ShuffleBlockFetcherIterator, so let's return to BlockStoreShuffleReader's read method:
override def read(): Iterator[Product2[K, C]] = {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
// Wrap the streams for compression based on configuration
// Wrap the fetched streams so they are decompressed according to the configuration
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}
val ser = Serializer.getSerializer(dep.serializer)
val serializerInstance = ser.newInstance()
// Create a key/value iterator for each stream
// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// Metrics-related bookkeeping
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map(record => {
readMetrics.incRecordsRead(1)
record
}),
context.taskMetrics().updateShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
// Check whether map-side combine has already been performed
if (dep.mapSideCombine) {
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// Whether the output needs to be sorted
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
}
The aggregator and keyOrdering parts of the code above were already covered in the Shuffle Write analysis; here we take another quick look at the relevant pieces.
aggregator
Let's start with the aggregator. Whichever branch is taken, i.e. whether combineCombinersByKey (map-side combine already applied) or combineValuesByKey ends up being called, both eventually invoke ExternalAppendOnlyMap's insertAll method:
The implementation of combineCombinersByKey:
def combineCombinersByKey(
iter: Iterator[_ <: Product2[K, C]],
context: TaskContext): Iterator[(K, C)] = {
val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
combiners.insertAll(iter)
updateMetrics(context, combiners)
combiners.iterator
}
The implementation of combineValuesByKey:
def combineValuesByKey(
iter: Iterator[_ <: Product2[K, V]],
context: TaskContext): Iterator[(K, C)] = {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
combiners.insertAll(iter)
updateMetrics(context, combiners)
combiners.iterator
}
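createCombiner, mergeValue and mergeCombiners are the three functions carried by the dependency's Aggregator. As a hedged illustration of what they typically look like (computing a per-key average here, purely an example, not taken from the Spark source):
// The combiner type C is (sum, count); createCombiner builds it from the first value,
// mergeValue folds one more value in, and mergeCombiners merges two partial results.
val createCombiner: Double => (Double, Long) = v => (v, 1L)
val mergeValue: ((Double, Long), Double) => (Double, Long) =
  (c, v) => (c._1 + v, c._2 + 1)
val mergeCombiners: ((Double, Long), (Double, Long)) => (Double, Long) =
  (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2)
combineValuesByKey folds raw (K, V) records using createCombiner and mergeValue, while combineCombinersByKey merges records that are already of the combined type using mergeCombiners.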
The insertAll method itself was already introduced in the Shuffle Write article:
def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
if (currentMap == null) {
throw new IllegalStateException(
"Cannot insert new elements into a map after calling iterator")
}
// An update function for the map that we reuse across entries to avoid allocating
// a new closure each time
var curEntry: Product2[K, V] = null
val update: (Boolean, C) => C = (hadVal, oldVal) => {
if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
}
while (entries.hasNext) {
curEntry = entries.next()
val estimatedSize = currentMap.estimateSize()
if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
if (maybeSpill(currentMap, estimatedSize)) {
currentMap = new SizeTrackingAppendOnlyMap[K, C]
}
currentMap.changeValue(curEntry._1, update)
addElementsRead()
}
}
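The update closure is the heart of combine-by-key: if the key already has a combiner, merge the new value into it, otherwise create a new combiner from the value. The same idea on a plain mutable.HashMap, as a simplified sketch with no size tracking or spilling (hypothetical code, not Spark's):
import scala.collection.mutable

def combineByKey[K, V, C](records: Iterator[(K, V)],
                          createCombiner: V => C,
                          mergeValue: (C, V) => C): mutable.HashMap[K, C] = {
  val combiners = new mutable.HashMap[K, C]
  for ((k, v) <- records) {
    combiners(k) = combiners.get(k) match {
      case Some(c) => mergeValue(c, v)    // key seen before: merge into the existing combiner
      case None    => createCombiner(v)   // first time we see this key
    }
  }
  combiners
}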
The entries.next() call here ultimately ends up in ShuffleBlockFetcherIterator's next method:
override def next(): (BlockId, InputStream) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
val result = currentResult
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size
case _ =>
}
// Send fetch requests up to maxBytesInFlight
// This is the key part: keep issuing fetch requests until all of the data has been fetched
fetchUpToMaxBytes()
result match {
case FailureFetchResult(blockId, address, e) =>
throwFetchFailedException(blockId, address, e)
case SuccessFetchResult(blockId, address, _, buf) =>
try {
(result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
} catch {
case NonFatal(t) =>
throwFetchFailedException(blockId, address, t)
}
}
}
As you can see, this method keeps fetching until all of the data has been retrieved; after that, combiners.iterator is executed:
override def iterator: Iterator[(K, C)] = {
if (currentMap == null) {
throw new IllegalStateException(
"ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
}
if (spilledMaps.isEmpty) {
CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())
} else {
new ExternalIterator()
}
}
Next, let's see what the instantiation of ExternalIterator actually does:
// A queue that maintains a buffer for each stream we are currently merging
// This queue maintains the invariant that it only contains non-empty buffers
private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
// The in-memory map is sorted by the keys' hash codes
private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](
currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())
// Combine the iterators over the in-memory map and the spilled files
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
// For each stream, read the pairs sharing its first key hash and buffer them in mergeHeap
inputStreams.foreach { it =>
val kcPairs = new ArrayBuffer[(K, C)]
readNextHashCode(it, kcPairs)
if (kcPairs.length > 0) {
mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
}
}
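What ExternalIterator then does with mergeHeap, repeatedly taking the buffer whose next key hash is smallest and merging all matching pairs, is essentially a k-way merge of sorted streams. A hedged standalone sketch of such a merge (hypothetical names, not Spark code):
import scala.collection.mutable

// Merge several individually sorted iterators into one sorted iterator, always
// pulling from the stream whose head element is currently the smallest.
def mergeSorted[T](iters: Seq[Iterator[T]])(implicit ord: Ordering[T]): Iterator[T] = {
  val byHead = Ordering.by[(T, BufferedIterator[T]), T](_._1)
  val heads = mutable.PriorityQueue.empty[(T, BufferedIterator[T])](byHead.reverse) // min-heap
  iters.map(_.buffered).foreach(it => if (it.hasNext) heads.enqueue((it.head, it)))
  new Iterator[T] {
    def hasNext: Boolean = heads.nonEmpty
    def next(): T = {
      val (_, it) = heads.dequeue()            // stream with the smallest next element
      val value = it.next()
      if (it.hasNext) heads.enqueue((it.head, it))
      value
    }
  }
}
// e.g. mergeSorted(Seq(Iterator(1, 4), Iterator(2, 3))).toList == List(1, 2, 3, 4)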
So each input stream contributes a buffer of same-hash pairs to mergeHeap, and ExternalIterator merges them as records are consumed. Now let's look at the case where a key ordering is defined.
keyOrdering
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
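This branch only runs when the shuffle dependency actually defines a key ordering, which is the case for operations such as sortByKey; reduceByKey leaves keyOrdering as None, so its reduce side returns aggregatedIter directly. A small illustrative example (assuming an existing SparkContext named sc):
val pairs = sc.parallelize(Seq(("b", 2), ("a", 1), ("c", 3)))
val sorted = pairs.sortByKey()        // sets keyOrdering on the ShuffleDependency
sorted.collect().foreach(println)     // (a,1), (b,2), (c,3)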
As you can see, when a key ordering is defined the aggregated iterator is fed into an ExternalSorter via insertAll, exactly the same operation as during Shuffle Write, so we won't repeat the explanation here; see the previous article for details. Finally, here is a diagram summarizing the Shuffle Read flow:
This article is based on the Spark 1.6.3 source code; the corresponding Spark 2.1.0 source is linked here:
This is an original article. You are welcome to repost it; please credit the source and the author. Thanks!