Merge pull request #107 from qintao1976/current

add thread-safe feature to callback socket server through a socket co…
This commit is contained in:
Tao Qin 2015-12-02 05:59:30 +08:00
Родитель b480379b77 e726b082c1
Коммит ea5c734cf6
4 изменённых файлов: 139 добавлений и 90 удалений

Просмотреть файл

@ -10,6 +10,7 @@ using System.Net.Sockets;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Spark.CSharp.Core;
@ -136,6 +137,83 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc
SparkCLRIpcProxy.JvmBridge.CallNonStaticJavaMethod(jvmStreamingContextReference, "awaitTermination", new object[] { timeout });
}
/// <summary>
/// Serves callback requests arriving on one accepted socket until the peer sends
/// "close" or the connection fails. The signature takes <c>object</c> so the method
/// can be queued as a thread-pool work item.
/// </summary>
/// <param name="socket">
/// The accepted <see cref="Socket"/>, passed as <c>object</c>. It is disposed before
/// this method returns.
/// </param>
private void ProcessCallbackRequest(object socket)
{
    logger.LogInfo("new thread created to process callback request");
    try
    {
        using (Socket sock = (Socket)socket)
        using (var s = new NetworkStream(sock))
        {
            while (true)
            {
                try
                {
                    string cmd = SerDe.ReadString(s);
                    if (cmd == "close")
                    {
                        logger.LogInfo("receive close cmd from Scala side");
                        break;
                    }
                    else if (cmd == "callback")
                    {
                        // Request layout, in read order: RDD count, one object id per RDD,
                        // batch time, serialized delegate, deserializer name of the first RDD.
                        int numRDDs = SerDe.ReadInt(s);
                        var jrdds = new List<JvmObjectReference>();
                        for (int i = 0; i < numRDDs; i++)
                        {
                            jrdds.Add(new JvmObjectReference(SerDe.ReadObjectId(s)));
                        }

                        double time = SerDe.ReadDouble(s);

                        // NOTE(review): BinaryFormatter deserializes bytes read straight from the
                        // socket; this is only safe while the peer is trusted - confirm.
                        IFormatter formatter = new BinaryFormatter();
                        object func;
                        using (var funcBytes = new MemoryStream(SerDe.ReadBytes(s)))
                        {
                            func = formatter.Deserialize(funcBytes);
                        }

                        string deserializer = SerDe.ReadString(s);

                        // A null id for the first RDD leaves rdd as null when the delegate is invoked.
                        RDD<dynamic> rdd = null;
                        if (jrdds[0].Id != null)
                        {
                            rdd = new RDD<dynamic>(new RDDIpcProxy(jrdds[0]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer));
                        }

                        var transform = func as Func<double, RDD<dynamic>, RDD<dynamic>>;
                        var transformWith = func as Func<double, RDD<dynamic>, RDD<dynamic>, RDD<dynamic>>;
                        if (transform != null)
                        {
                            // Direct cast: an unexpected proxy type fails with a descriptive
                            // InvalidCastException rather than a NullReferenceException.
                            JvmObjectReference jrdd = ((RDDIpcProxy)transform(time, rdd).RddProxy).JvmRddReference;
                            SerDe.Write(s, (byte)'j');
                            SerDe.Write(s, jrdd.Id);
                        }
                        else if (transformWith != null)
                        {
                            // The two-RDD form carries one more deserializer name for the second RDD.
                            string deserializer2 = SerDe.ReadString(s);
                            RDD<dynamic> rdd2 = new RDD<dynamic>(new RDDIpcProxy(jrdds[1]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer2));
                            JvmObjectReference jrdd = ((RDDIpcProxy)transformWith(time, rdd, rdd2).RddProxy).JvmRddReference;
                            SerDe.Write(s, (byte)'j');
                            SerDe.Write(s, jrdd.Id);
                        }
                        else
                        {
                            // Anything else must be an action; 'n' is written instead of an RDD id.
                            ((Action<double, RDD<dynamic>>)func)(time, rdd);
                            SerDe.Write(s, (byte)'n');
                        }
                    }
                }
                catch (Exception e)
                {
                    //log exception only when callback socket is not shutdown explicitly
                    if (!callbackSocketShutdown)
                    {
                        logger.LogException(e);
                    }

                    // After a failure the stream is either closed or no longer aligned with the
                    // request protocol, so further reads cannot succeed. Leave the loop (and let
                    // the using blocks close the socket) instead of spinning on the same error.
                    break;
                }
            }
        }
    }
    catch (Exception e)
    {
        logger.LogException(e);
    }
    logger.LogInfo("thread to process callback request exit");
}
public int StartCallback()
{
TcpListener callbackServer = new TcpListener(IPAddress.Parse("127.0.0.1"), 0);
@ -145,72 +223,16 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc
{
try
{
using (Socket sock = callbackServer.AcceptSocket())
using (var s = new NetworkStream(sock))
ThreadPool.SetMaxThreads(10, 10);
while (!callbackSocketShutdown)
{
while (true)
{
try
{
string cmd = SerDe.ReadString(s);
if (cmd == "close")
{
logger.LogInfo("receive close cmd from Scala side");
break;
}
else if (cmd == "callback")
{
int numRDDs = SerDe.ReadInt(s);
var jrdds = new List<JvmObjectReference>();
for (int i = 0; i < numRDDs; i++)
{
jrdds.Add(new JvmObjectReference(SerDe.ReadObjectId(s)));
}
double time = SerDe.ReadDouble(s);
IFormatter formatter = new BinaryFormatter();
object func = formatter.Deserialize(new MemoryStream(SerDe.ReadBytes(s)));
string deserializer = SerDe.ReadString(s);
RDD<dynamic> rdd = null;
if (jrdds[0].Id != null)
rdd = new RDD<dynamic>(new RDDIpcProxy(jrdds[0]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer));
if (func is Func<double, RDD<dynamic>, RDD<dynamic>>)
{
JvmObjectReference jrdd = (((Func<double, RDD<dynamic>, RDD<dynamic>>)func)(time, rdd).RddProxy as RDDIpcProxy).JvmRddReference;
SerDe.Write(s, (byte)'j');
SerDe.Write(s, jrdd.Id);
}
else if (func is Func<double, RDD<dynamic>, RDD<dynamic>, RDD<dynamic>>)
{
string deserializer2 = SerDe.ReadString(s);
RDD<dynamic> rdd2 = new RDD<dynamic>(new RDDIpcProxy(jrdds[1]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer2));
JvmObjectReference jrdd = (((Func<double, RDD<dynamic>, RDD<dynamic>, RDD<dynamic>>)func)(time, rdd, rdd2).RddProxy as RDDIpcProxy).JvmRddReference;
SerDe.Write(s, (byte)'j');
SerDe.Write(s, jrdd.Id);
}
else
{
((Action<double, RDD<dynamic>>)func)(time, rdd);
SerDe.Write(s, (byte)'n');
}
}
}
catch (Exception e)
{
//log exception only when callback socket is not shutdown explicitly
if (!callbackSocketShutdown)
{
logger.LogInfo(e.ToString());
}
}
}
Socket sock = callbackServer.AcceptSocket();
ThreadPool.QueueUserWorkItem(new WaitCallback(ProcessCallbackRequest), sock);
}
}
catch (Exception e)
{
logger.LogInfo(e.ToString());
logger.LogException(e);
throw;
}
finally

Просмотреть файл

@ -5,7 +5,7 @@ package org.apache.spark.api.csharp
import java.io.{DataOutputStream, File, FileOutputStream, IOException}
import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}
import java.util.concurrent.TimeUnit
import java.util.concurrent.{LinkedBlockingQueue, BlockingQueue, TimeUnit}
import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.nio.NioEventLoopGroup
@ -30,8 +30,8 @@ class CSharpBackend {
private[this] var bossGroup: EventLoopGroup = null
def init(): Int = {
bossGroup = new NioEventLoopGroup(2)
// need at least 3 threads, use 10 here for safety
bossGroup = new NioEventLoopGroup(10)
val workerGroup = bossGroup
val handler = new CSharpBackendHandler(this) //TODO - work with SparkR devs to make this configurable and reuse RBackend
@ -78,23 +78,29 @@ class CSharpBackend {
bootstrap.childGroup().shutdownGracefully()
}
bootstrap = null
// Send close to CSharp callback server.
if (CSharpBackend.callbackSocket != null &&
!CSharpBackend.callbackSocket.isClosed()) {
try {
println("Requesting to close a call back server.")
val dos = new DataOutputStream(CSharpBackend.callbackSocket.getOutputStream())
// Send close to CSharp callback server.
println("Requesting to close all call back sockets.")
var socket: Socket = null
do {
socket = CSharpBackend.callbackSockets.poll()
if (socket != null) {
val dos = new DataOutputStream(socket.getOutputStream)
SerDe.writeString(dos, "close")
CSharpBackend.callbackSocket.close()
CSharpBackend.callbackSocketShutdown = true
try {
socket.close()
socket = null
}
}
}
} while (socket != null)
CSharpBackend.callbackSocketShutdown = true
}
}
object CSharpBackend {
// Channel to callback server.
private[spark] var callbackSocket: Socket = null
// Channels to callback server.
private[spark] val callbackSockets: BlockingQueue[Socket] = new LinkedBlockingQueue[Socket]()
@volatile private[spark] var callbackPort: Int = 0
// flag to denote whether the callback socket is shutdown explicitly
@volatile private[spark] var callbackSocketShutdown: Boolean = false

Просмотреть файл

@ -61,23 +61,29 @@ class CSharpBackendHandler(server: CSharpBackend) extends SimpleChannelInboundHa
assert(t == 'i')
val port = readInt(dis)
println("Connecting to a callback server at port " + port)
CSharpBackend.callbackSocket = new Socket("localhost", port)
CSharpBackend.callbackPort = port
writeInt(dos, 0)
writeType(dos, "void")
case "closeCallback" =>
// Send close to CSharp callback server.
if (CSharpBackend.callbackSocket != null &&
CSharpBackend.callbackSocket.isConnected()) {
try {
println("Requesting to close a call back server.")
val os = new DataOutputStream(CSharpBackend.callbackSocket.getOutputStream())
writeString(os, "close")
CSharpBackend.callbackSocket.close()
println("Requesting to close all call back sockets.")
var socket: Socket = null
do {
socket = CSharpBackend.callbackSockets.poll()
if (socket != null) {
val dataOutputStream = new DataOutputStream(socket.getOutputStream)
SerDe.writeString(dataOutputStream, "close")
try {
socket.close()
socket = null
}
}
writeInt(dos, 0)
writeType(dos, "void")
}
} while (socket != null)
CSharpBackend.callbackSocketShutdown = true
writeInt(dos, 0)
writeType(dos, "void")
case _ => dos.writeInt(-1)
}
} else {

Просмотреть файл

@ -51,10 +51,16 @@ object CSharpDStream {
}
def callCSharpTransform(rdds: List[Option[RDD[_]]], time: Time, rfunc: Array[Byte],
deserializers: List[String]): Option[RDD[Array[Byte]]] = synchronized {
deserializers: List[String]): Option[RDD[Array[Byte]]] = {
var socket: Socket = null
try {
val dos = new DataOutputStream(CSharpBackend.callbackSocket.getOutputStream())
val dis = new DataInputStream(CSharpBackend.callbackSocket.getInputStream())
socket = CSharpBackend.callbackSockets.poll()
if (socket == null) {
socket = new Socket("localhost", CSharpBackend.callbackPort)
}
val dos = new DataOutputStream(socket.getOutputStream())
val dis = new DataInputStream(socket.getInputStream())
writeString(dos, "callback")
writeInt(dos, rdds.size)
@ -64,7 +70,9 @@ object CSharpDStream {
writeBytes(dos, rfunc)
deserializers.foreach(x => writeString(dos, x))
dos.flush()
Option(readObject(dis).asInstanceOf[JavaRDD[Array[Byte]]]).map(_.rdd)
val result = Option(readObject(dis).asInstanceOf[JavaRDD[Array[Byte]]]).map(_.rdd)
CSharpBackend.callbackSockets.offer(socket)
result
} catch {
case e: Exception =>
// log exception only when callback socket is not shutdown explicitly
@ -74,6 +82,13 @@ object CSharpDStream {
e.printStackTrace()
}
// close this socket if an error happens
if (socket != null) {
try {
socket.close()
}
}
None
}
}