Merge pull request #390 from qintao1976/dev

Make CSharpBackendHandler thread safe
This commit is contained in:
Tao Qin 2016-04-19 11:49:53 +08:00
Родитель 416245b2fe f6e20ebd19
Коммит 674944d093
2 изменённых файлов: 20 добавлений и 16 удалений

Просмотреть файл

@ -26,7 +26,7 @@ import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
*/
// Since SparkCLR is a package to Spark and not a part of spark-core it mirrors the implementation of
// selected parts from RBackend with SparkCLR customizations
class CSharpBackend {
class CSharpBackend { self => // for accessing the this reference in inner class(ChannelInitializer)
private[this] var channelFuture: ChannelFuture = null
private[this] var bootstrap: ServerBootstrap = null
private[this] var bossGroup: EventLoopGroup = null
@ -35,7 +35,6 @@ class CSharpBackend {
// need at least 3 threads, use 10 here for safety
bossGroup = new NioEventLoopGroup(10)
val workerGroup = bossGroup
val handler = new CSharpBackendHandler(this) //TODO - work with SparkR devs to make this configurable and reuse RBackend
bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
@ -54,7 +53,8 @@ class CSharpBackend {
//new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
.addLast("decoder", new ByteArrayDecoder())
.addLast("handler", handler)
//TODO - work with SparkR devs to make this configurable and reuse RBackend
.addLast("handler", new CSharpBackendHandler(self))
}
})

Просмотреть файл

@ -23,7 +23,6 @@ import scala.collection.mutable.HashMap
*/
// Since SparkCLR is a package to Spark and not a part of spark-core, it mirrors the implementation
// of selected parts from RBackend with SparkCLR customizations
@Sharable
class CSharpBackendHandler(server: CSharpBackend) extends SimpleChannelInboundHandler[Array[Byte]] {
override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
@ -269,31 +268,36 @@ class CSharpBackendHandler(server: CSharpBackend) extends SimpleChannelInboundHa
*/
private object JVMObjectTracker {
// TODO: This map should be thread-safe if we want to support multiple
// connections at the same time
// Muliple threads may access objMap and increase objCounter. Because get method return Option,
// it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap.
private[this] val objMap = new HashMap[String, Object]
// TODO: We support only one connection now, so an integer is fine.
// Investigate using use atomic integer in the future.
private[this] var objCounter: Int = 1
def getObject(id: String): Object = {
objMap(id)
synchronized {
objMap(id)
}
}
def get(id: String): Option[Object] = {
objMap.get(id)
synchronized {
objMap.get(id)
}
}
def put(obj: Object): String = {
val objId = objCounter.toString
objCounter = objCounter + 1
objMap.put(objId, obj)
objId
synchronized {
val objId = objCounter.toString
objCounter = objCounter + 1
objMap.put(objId, obj)
objId
}
}
def remove(id: String): Option[Object] = {
objMap.remove(id)
synchronized {
objMap.remove(id)
}
}
}