Skip to content

Commit

Permalink
Merge pull request amplab#68 from mosharaf/master
Browse files Browse the repository at this point in the history
Faster and stable/reliable broadcast

HttpBroadcast is noticeably slow, but the alternatives (TreeBroadcast or BitTorrentBroadcast) are notoriously unreliable. The main problem with them is they try to manage the memory for the pieces of a broadcast themselves. Right now, the BroadcastManager does not know which machines the tasks reading from a broadcast variable is running and when they have finished. Consequently, we try to guess and often guess wrong, which blows up the memory usage and kills/hangs jobs.

This very simple implementation solves the problem by not trying to manage the intermediate pieces; instead, it offloads that duty to the BlockManager which is quite good at juggling blocks. Otherwise, it is very similar to the BitTorrentBroadcast implementation (without fancy optimizations). And it runs much faster than HttpBroadcast we have right now.

I've been using this for another project for last couple of weeks, and just today did some benchmarking against the Http one. The following shows the improvements for increasing broadcast size for cold runs. Each line represent the number of receivers.
![fix-bc-first](https://f.cloud.github.com/assets/232966/1349342/ffa149e4-36e7-11e3-9fa6-c74555829356.png)

After the first broadcast is over, i.e., after JVM is wormed up and for HttpBroadcast the server is already running (I think), the following are the improvements for warm runs.
![fix-bc-succ](https://f.cloud.github.com/assets/232966/1349352/5a948bae-36e8-11e3-98ce-34f19ebd33e0.jpg)
The curves are not as nice as the cold runs, but the improvements are obvious, specially for larger broadcasts and more receivers.

Depending on how it goes, we should deprecate and/or remove old TreeBroadcast and BitTorrentBroadcast implementations, and hopefully, SPARK-889 will not be necessary any more.
  • Loading branch information
mateiz committed Oct 19, 2013
2 parents 8d528af + 08391db commit e5316d0
Show file tree
Hide file tree
Showing 7 changed files with 328 additions and 12 deletions.
247 changes: 247 additions & 0 deletions core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.broadcast

import java.io._

import scala.math
import scala.util.Random

import org.apache.spark._
import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
import org.apache.spark.util.Utils


private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {

def value = value_

def broadcastId = BroadcastBlockId(id)

TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
}

@transient var arrayOfBlocks: Array[TorrentBlock] = null
@transient var totalBlocks = -1
@transient var totalBytes = -1
@transient var hasBlocks = 0

if (!isLocal) {
sendBroadcast()
}

def sendBroadcast() {
var tInfo = TorrentBroadcast.blockifyObject(value_)

totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
hasBlocks = tInfo.totalBlocks

// Store meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
}

// Store individual pieces
for (i <- 0 until totalBlocks) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
}
}
}

// Called by JVM when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(broadcastId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]

case None =>
val start = System.nanoTime
logInfo("Started reading broadcast variable " + id)

// Initialize @transient variables that will receive garbage values from the master.
resetWorkerVariables()

if (receiveBroadcast(id)) {
value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)

// Store the merged copy in cache so that the next worker doesn't need to rebuild it.
// This creates a tradeoff between memory usage and latency.
// Storing copy doubles the memory footprint; not storing doubles deserialization cost.
SparkEnv.get.blockManager.putSingle(
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)

// Remove arrayOfBlocks from memory once value_ is on local cache
resetWorkerVariables()
} else {
logError("Reading broadcast variable " + id + " failed")
}

val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}

private def resetWorkerVariables() {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
}

def receiveBroadcast(variableID: Long): Boolean = {
// Receive meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
var attemptId = 10
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(metaId) match {
case Some(x) =>
val tInfo = x.asInstanceOf[TorrentInfo]
totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
hasBlocks = 0

case None =>
Thread.sleep(500)
}
}
attemptId -= 1
}
if (totalBlocks == -1) {
return false
}

// Receive actual blocks
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
for (pid <- recvOrder) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
hasBlocks += 1
SparkEnv.get.blockManager.putSingle(
pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)

case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
}
}
}

(hasBlocks == totalBlocks)
}

}

private object TorrentBroadcast
extends Logging {

private var initialized = false

def initialize(_isDriver: Boolean) {
synchronized {
if (!initialized) {
initialized = true
}
}
}

def stop() {
initialized = false
}

val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024

def blockifyObject[T](obj: T): TorrentInfo = {
val byteArray = Utils.serialize[T](obj)
val bais = new ByteArrayInputStream(byteArray)

var blockNum = (byteArray.length / BLOCK_SIZE)
if (byteArray.length % BLOCK_SIZE != 0)
blockNum += 1

var retVal = new Array[TorrentBlock](blockNum)
var blockID = 0

for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
var tempByteArray = new Array[Byte](thisBlockSize)
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)

retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
blockID += 1
}
bais.close()

var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
tInfo.hasBlocks = blockNum

return tInfo
}

def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
totalBytes: Int,
totalBlocks: Int): T = {
var retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
}
Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
}

}

private[spark] case class TorrentBlock(
blockID: Int,
byteArray: Array[Byte])
extends Serializable

private[spark] case class TorrentInfo(
@transient arrayOfBlocks : Array[TorrentBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {

@transient var hasBlocks = 0
}

private[spark] class TorrentBroadcastFactory
extends BroadcastFactory {

def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)

def stop() { TorrentBroadcast.stop() }
}
9 changes: 8 additions & 1 deletion core/src/main/scala/org/apache/spark/storage/BlockId.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ private[spark] sealed abstract class BlockId {
def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
def isRDD = isInstanceOf[RDDBlockId]
def isShuffle = isInstanceOf[ShuffleBlockId]
def isBroadcast = isInstanceOf[BroadcastBlockId]
def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]

override def toString = name
override def hashCode = name.hashCode
Expand All @@ -55,6 +55,10 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def name = "broadcast_" + broadcastId
}

private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
def name = broadcastId.name + "_" + hType
}

private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
def name = "taskresult_" + taskId
}
Expand All @@ -72,6 +76,7 @@ private[spark] object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)".r
val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
val TEST = "test_(.*)".r
Expand All @@ -84,6 +89,8 @@ private[spark] object BlockId {
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case BROADCAST(broadcastId) =>
BroadcastBlockId(broadcastId.toLong)
case BROADCAST_HELPER(broadcastId, hType) =>
BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
case TASKRESULT(taskId) =>
TaskResultBlockId(taskId.toLong)
case STREAM(streamId, uniqueId) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.io.{InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}

import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
import scala.util.Random

import akka.actor.{ActorSystem, Cancellable, Props}
import akka.dispatch.{Await, Future}
Expand Down Expand Up @@ -269,7 +270,7 @@ private[spark] class BlockManager(
}

/**
* Actually send a UpdateBlockInfo message. Returns the mater's response,
* Actually send a UpdateBlockInfo message. Returns the master's response,
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
Expand Down Expand Up @@ -478,7 +479,7 @@ private[spark] class BlockManager(
}
logDebug("Getting remote block " + blockId)
// Get locations of block
val locations = master.getLocations(blockId)
val locations = Random.shuffle(master.getLocations(blockId))

// Get block from remote locations
for (loc <- locations) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}

private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
if (id.executorId == "<driver>" && !isLocal) {
// Got a register message from the master node; don't register it
} else if (!blockManagerInfo.contains(id)) {
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
// A block manager of the same executor already exists.
Expand Down
52 changes: 49 additions & 3 deletions core/src/test/scala/org/apache/spark/BroadcastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,66 @@ package org.apache.spark
import org.scalatest.FunSuite

class BroadcastSuite extends FunSuite with LocalSparkContext {

test("basic broadcast") {

override def afterEach() {
super.afterEach()
System.clearProperty("spark.broadcast.factory")
}

test("Using HttpBroadcast locally") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === Set((1, 10), (2, 10)))
}

test("Accessing HttpBroadcast variables from multiple threads") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
}

test("Accessing HttpBroadcast variables in a local cluster") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
val numSlaves = 4
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}

test("Using TorrentBroadcast locally") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === Set((1, 10), (2, 10)))
}

test("broadcast variables accessed in multiple threads") {
test("Accessing TorrentBroadcast variables from multiple threads") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
}

test("Accessing TorrentBroadcast variables in a local cluster") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
val numSlaves = 4
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}

}
Loading

0 comments on commit e5316d0

Please sign in to comment.