diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/StreamingContextIpcProxy.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/StreamingContextIpcProxy.cs
index cf86fd8..fd8ea12 100644
--- a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/StreamingContextIpcProxy.cs
+++ b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/StreamingContextIpcProxy.cs
@@ -10,6 +10,7 @@ using System.Net.Sockets;
 using System.Runtime.Serialization;
 using System.Runtime.Serialization.Formatters.Binary;
 using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 
 using Microsoft.Spark.CSharp.Core;
@@ -136,6 +137,92 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc
             SparkCLRIpcProxy.JvmBridge.CallNonStaticJavaMethod(jvmStreamingContextReference, "awaitTermination", new object[] { timeout });
         }
 
+        /// <summary>
+        /// Serves a single callback connection from the Scala side: handles "callback" commands
+        /// until a "close" command is received or the connection fails.
+        /// </summary>
+        /// <param name="socket">The accepted <see cref="Socket"/>; it is disposed when this method returns.</param>
+        private void ProcessCallbackRequest(object socket)
+        {
+            logger.LogInfo("new thread created to process callback request");
+
+            try
+            {
+                using (Socket sock = (Socket)socket)
+                using (var s = new NetworkStream(sock))
+                {
+                    while (true)
+                    {
+                        try
+                        {
+                            string cmd = SerDe.ReadString(s);
+                            if (cmd == "close")
+                            {
+                                logger.LogInfo("receive close cmd from Scala side");
+                                break;
+                            }
+                            else if (cmd == "callback")
+                            {
+                                int numRDDs = SerDe.ReadInt(s);
+                                var jrdds = new List();
+                                for (int i = 0; i < numRDDs; i++)
+                                {
+                                    jrdds.Add(new JvmObjectReference(SerDe.ReadObjectId(s)));
+                                }
+                                double time = SerDe.ReadDouble(s);
+
+                                IFormatter formatter = new BinaryFormatter();
+                                object func = formatter.Deserialize(new MemoryStream(SerDe.ReadBytes(s)));
+
+                                string deserializer = SerDe.ReadString(s);
+                                RDD rdd = null;
+                                if (jrdds[0].Id != null)
+                                    rdd = new RDD(new RDDIpcProxy(jrdds[0]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer));
+
+                                if (func is Func, RDD>)
+                                {
+                                    JvmObjectReference jrdd = (((Func, RDD>)func)(time, rdd).RddProxy as RDDIpcProxy).JvmRddReference;
+                                    SerDe.Write(s, (byte)'j');
+                                    SerDe.Write(s, jrdd.Id);
+                                }
+                                else if (func is Func, RDD, RDD>)
+                                {
+                                    string deserializer2 = SerDe.ReadString(s);
+                                    RDD rdd2 = new RDD(new RDDIpcProxy(jrdds[1]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer2));
+                                    JvmObjectReference jrdd = (((Func, RDD, RDD>)func)(time, rdd, rdd2).RddProxy as RDDIpcProxy).JvmRddReference;
+                                    SerDe.Write(s, (byte)'j');
+                                    SerDe.Write(s, jrdd.Id);
+                                }
+                                else
+                                {
+                                    ((Action>)func)(time, rdd);
+                                    SerDe.Write(s, (byte)'n');
+                                }
+                            }
+                        }
+                        catch (Exception e)
+                        {
+                            //log exception only when callback socket is not shutdown explicitly
+                            if (!callbackSocketShutdown)
+                            {
+                                logger.LogException(e);
+                            }
+
+                            // after a failure no reply was written and the stream state is unknown;
+                            // stop serving this connection instead of retrying forever on a broken socket
+                            break;
+                        }
+                    }
+                }
+            }
+            catch (Exception e)
+            {
+                logger.LogException(e);
+            }
+
+            logger.LogInfo("thread to process callback request exit");
+        }
+
         public int StartCallback()
         {
             TcpListener callbackServer = new TcpListener(IPAddress.Parse("127.0.0.1"), 0);
@@ -145,72 +232,17 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc
             {
                 try
                 {
-                    using (Socket sock = callbackServer.AcceptSocket())
-                    using (var s = new NetworkStream(sock))
+                    // each connection is served for its whole lifetime, so give it a dedicated
+                    // background thread instead of occupying (and globally capping) the thread pool
+                    while (!callbackSocketShutdown)
                     {
-                        while (true)
-                        {
-                            try
-                            {
-                                string cmd = SerDe.ReadString(s);
-                                if (cmd == "close")
-                                {
-                                    logger.LogInfo("receive close cmd from Scala side");
-                                    break;
-                                }
-                                else if (cmd == "callback")
-                                {
-                                    int numRDDs = SerDe.ReadInt(s);
-                                    var jrdds = new List();
-                                    for (int i = 0; i < numRDDs; i++)
-                                    {
-                                        jrdds.Add(new JvmObjectReference(SerDe.ReadObjectId(s)));
-                                    }
-                                    double time = SerDe.ReadDouble(s);
-
-                                    IFormatter formatter = new BinaryFormatter();
-                                    object func = formatter.Deserialize(new MemoryStream(SerDe.ReadBytes(s)));
-
-                                    string deserializer = SerDe.ReadString(s);
-                                    RDD rdd = null;
-                                    if (jrdds[0].Id != null)
-                                        rdd = new RDD(new RDDIpcProxy(jrdds[0]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer));
-
-                                    if (func is Func, RDD>)
-                                    {
-                                        JvmObjectReference jrdd = (((Func, RDD>)func)(time, rdd).RddProxy as RDDIpcProxy).JvmRddReference;
-                                        SerDe.Write(s, (byte)'j');
-                                        SerDe.Write(s, jrdd.Id);
-                                    }
-                                    else if (func is Func, RDD, RDD>)
-                                    {
-                                        string deserializer2 = SerDe.ReadString(s);
-                                        RDD rdd2 = new RDD(new RDDIpcProxy(jrdds[1]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer2));
-                                        JvmObjectReference jrdd = (((Func, RDD, RDD>)func)(time, rdd, rdd2).RddProxy as RDDIpcProxy).JvmRddReference;
-                                        SerDe.Write(s, (byte)'j');
-                                        SerDe.Write(s, jrdd.Id);
-                                    }
-                                    else
-                                    {
-                                        ((Action>)func)(time, rdd);
-                                        SerDe.Write(s, (byte)'n');
-                                    }
-                                }
-                            }
-                            catch (Exception e)
-                            {
-                                //log exception only when callback socket is not shutdown explicitly
-                                if (!callbackSocketShutdown)
-                                {
-                                    logger.LogInfo(e.ToString());
-                                }
-                            }
-                        }
+                        Socket sock = callbackServer.AcceptSocket();
+                        new Thread(ProcessCallbackRequest) { IsBackground = true }.Start(sock);
                     }
                 }
                 catch (Exception e)
                 {
-                    logger.LogInfo(e.ToString());
+                    logger.LogException(e);
                     throw;
                 }
                 finally
diff --git a/scala/src/main/org/apache/spark/api/csharp/CSharpBackend.scala b/scala/src/main/org/apache/spark/api/csharp/CSharpBackend.scala
index 538273b..4e1c896 100644
--- a/scala/src/main/org/apache/spark/api/csharp/CSharpBackend.scala
+++ b/scala/src/main/org/apache/spark/api/csharp/CSharpBackend.scala
@@ -5,7 +5,7 @@ package org.apache.spark.api.csharp
 
 import java.io.{DataOutputStream, File, FileOutputStream, IOException}
 import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}
-import java.util.concurrent.TimeUnit
+import java.util.concurrent.{LinkedBlockingQueue, BlockingQueue, TimeUnit}
 
 import io.netty.bootstrap.ServerBootstrap
 import io.netty.channel.nio.NioEventLoopGroup
@@ -30,8 +30,8 @@ class CSharpBackend {
   private[this] var bossGroup: EventLoopGroup = null
 
   def init(): Int = {
-
-    bossGroup = new NioEventLoopGroup(2)
+    // need at least 3 threads, use 10 here for safety
+    bossGroup = new NioEventLoopGroup(10)
     val workerGroup = bossGroup
     val handler = new CSharpBackendHandler(this)
     //TODO - work with SparkR devs to make this configurable and reuse RBackend
@@ -78,23 +78,41 @@ class CSharpBackend {
       bootstrap.childGroup().shutdownGracefully()
     }
     bootstrap = null
-    // Send close to CSharp callback server.
-    if (CSharpBackend.callbackSocket != null &&
-      !CSharpBackend.callbackSocket.isClosed()) {
-      try {
-        println("Requesting to close a call back server.")
-        val dos = new DataOutputStream(CSharpBackend.callbackSocket.getOutputStream())
-        SerDe.writeString(dos, "close")
-        CSharpBackend.callbackSocket.close()
-        CSharpBackend.callbackSocketShutdown = true
-      }
-    }
+
+    // Send close to CSharp callback server.
+    CSharpBackend.closeCallbackSockets()
+    CSharpBackend.callbackSocketShutdown = true
   }
 }
 
 object CSharpBackend {
-  // Channel to callback server.
-  private[spark] var callbackSocket: Socket = null
+  // Channels to callback server.
+  private[spark] val callbackSockets: BlockingQueue[Socket] = new LinkedBlockingQueue[Socket]()
+  @volatile private[spark] var callbackPort: Int = 0
+
+  /**
+   * Sends "close" to every pooled callback socket and then closes it.
+   * A failure on one socket is reported and does not prevent the remaining
+   * sockets from being closed.
+   */
+  private[spark] def closeCallbackSockets(): Unit = {
+    println("Requesting to close all call back sockets.")
+    var socket: Socket = callbackSockets.poll()
+    while (socket != null) {
+      try {
+        SerDe.writeString(new DataOutputStream(socket.getOutputStream), "close")
+      } catch {
+        case e: Exception => e.printStackTrace()
+      } finally {
+        try {
+          socket.close()
+        } catch {
+          case e: Exception => e.printStackTrace()
+        }
+      }
+      socket = callbackSockets.poll()
+    }
+  }
 
   // flag to denote whether the callback socket is shutdown explicitly
   @volatile private[spark] var callbackSocketShutdown: Boolean = false
diff --git a/scala/src/main/org/apache/spark/api/csharp/CSharpBackendHandler.scala b/scala/src/main/org/apache/spark/api/csharp/CSharpBackendHandler.scala
index 935c877..cc43aae 100644
--- a/scala/src/main/org/apache/spark/api/csharp/CSharpBackendHandler.scala
+++ b/scala/src/main/org/apache/spark/api/csharp/CSharpBackendHandler.scala
@@ -61,23 +61,17 @@ class CSharpBackendHandler(server: CSharpBackend) extends SimpleChannelInboundHa
           assert(t == 'i')
           val port = readInt(dis)
           println("Connecting to a callback server at port " + port)
-          CSharpBackend.callbackSocket = new Socket("localhost", port)
+          CSharpBackend.callbackPort = port
           writeInt(dos, 0)
           writeType(dos, "void")
         case "closeCallback" =>
           // Send close to CSharp callback server.
-          if (CSharpBackend.callbackSocket != null &&
-            CSharpBackend.callbackSocket.isConnected()) {
-            try {
-              println("Requesting to close a call back server.")
-              val os = new DataOutputStream(CSharpBackend.callbackSocket.getOutputStream())
-              writeString(os, "close")
-              CSharpBackend.callbackSocket.close()
-            }
-            writeInt(dos, 0)
-            writeType(dos, "void")
-          }
+          CSharpBackend.closeCallbackSockets()
           CSharpBackend.callbackSocketShutdown = true
+
+          writeInt(dos, 0)
+          writeType(dos, "void")
+
         case _ => dos.writeInt(-1)
       }
     } else {
diff --git a/scala/src/main/org/apache/spark/streaming/api/csharp/CSharpDStream.scala b/scala/src/main/org/apache/spark/streaming/api/csharp/CSharpDStream.scala
index 6b916f2..c44c45b 100644
--- a/scala/src/main/org/apache/spark/streaming/api/csharp/CSharpDStream.scala
+++ b/scala/src/main/org/apache/spark/streaming/api/csharp/CSharpDStream.scala
@@ -51,10 +51,16 @@ object CSharpDStream {
   }
 
   def callCSharpTransform(rdds: List[Option[RDD[_]]], time: Time, rfunc: Array[Byte],
-      deserializers: List[String]): Option[RDD[Array[Byte]]] = synchronized {
+      deserializers: List[String]): Option[RDD[Array[Byte]]] = {
+    var socket: Socket = null
     try {
-      val dos = new DataOutputStream(CSharpBackend.callbackSocket.getOutputStream())
-      val dis = new DataInputStream(CSharpBackend.callbackSocket.getInputStream())
+      socket = CSharpBackend.callbackSockets.poll()
+      if (socket == null) {
+        socket = new Socket("localhost", CSharpBackend.callbackPort)
+      }
+
+      val dos = new DataOutputStream(socket.getOutputStream())
+      val dis = new DataInputStream(socket.getInputStream())
 
       writeString(dos, "callback")
       writeInt(dos, rdds.size)
@@ -64,7 +70,9 @@ object CSharpDStream {
       writeBytes(dos, rfunc)
       deserializers.foreach(x => writeString(dos, x))
       dos.flush()
-      Option(readObject(dis).asInstanceOf[JavaRDD[Array[Byte]]]).map(_.rdd)
+      val result = Option(readObject(dis).asInstanceOf[JavaRDD[Array[Byte]]]).map(_.rdd)
+      CSharpBackend.callbackSockets.offer(socket)
+      result
     } catch {
       case e: Exception =>
         // log exception only when callback socket is not shutdown explicitly
@@ -74,6 +82,15 @@ object CSharpDStream {
           e.printStackTrace()
         }
 
+        // close this socket if an error happens; a failing close must not escape this handler
+        if (socket != null) {
+          try {
+            socket.close()
+          } catch {
+            case closeError: Exception => closeError.printStackTrace()
+          }
+        }
+
         None
     }
   }