Merge pull request #107 from qintao1976/current
add thread-safe feature to callback socket server through a socket co…
This commit is contained in:
Коммит
ea5c734cf6
|
@ -10,6 +10,7 @@ using System.Net.Sockets;
|
|||
using System.Runtime.Serialization;
|
||||
using System.Runtime.Serialization.Formatters.Binary;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Microsoft.Spark.CSharp.Core;
|
||||
|
@ -136,6 +137,83 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc
|
|||
SparkCLRIpcProxy.JvmBridge.CallNonStaticJavaMethod(jvmStreamingContextReference, "awaitTermination", new object[] { timeout });
|
||||
}
|
||||
|
||||
/// <summary>
/// Serves a single callback connection until the peer sends "close" or the
/// connection fails. Intended to be run as a <see cref="WaitCallback"/> work item,
/// one invocation per accepted socket.
/// </summary>
/// <remarks>
/// Wire protocol handled here, per request:
///   "callback", number of RDDs, that many object ids, time (double),
///   serialized delegate bytes, deserializer name [, second deserializer name].
/// The reply is 'j' followed by the resulting RDD's object id when the delegate
/// returns an RDD, or 'n' when the delegate is an action with no result.
/// Any command other than "close" or "callback" is ignored.
/// </remarks>
/// <param name="socket">
/// The accepted <see cref="Socket"/>, passed as <c>object</c> to match the
/// <see cref="WaitCallback"/> signature. It is disposed when this method returns.
/// </param>
private void ProcessCallbackRequest(object socket)
{
    logger.LogInfo("new thread created to process callback request");

    try
    {
        using (Socket sock = (Socket)socket)
        using (var s = new NetworkStream(sock))
        {
            while (true)
            {
                try
                {
                    string cmd = SerDe.ReadString(s);
                    if (cmd == "close")
                    {
                        logger.LogInfo("receive close cmd from Scala side");
                        break;
                    }
                    else if (cmd == "callback")
                    {
                        int numRDDs = SerDe.ReadInt(s);
                        var jrdds = new List<JvmObjectReference>();
                        for (int i = 0; i < numRDDs; i++)
                        {
                            jrdds.Add(new JvmObjectReference(SerDe.ReadObjectId(s)));
                        }
                        double time = SerDe.ReadDouble(s);

                        // NOTE(review): BinaryFormatter deserializes arbitrary types from the
                        // socket payload. This is only acceptable while the peer is the trusted
                        // local JVM process; do not expose this listener beyond loopback.
                        IFormatter formatter = new BinaryFormatter();
                        object func;
                        using (var funcStream = new MemoryStream(SerDe.ReadBytes(s)))
                        {
                            func = formatter.Deserialize(funcStream);
                        }

                        string deserializer = SerDe.ReadString(s);
                        RDD<dynamic> rdd = null;
                        if (jrdds[0].Id != null)
                        {
                            rdd = new RDD<dynamic>(new RDDIpcProxy(jrdds[0]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer));
                        }

                        var transformOne = func as Func<double, RDD<dynamic>, RDD<dynamic>>;
                        var transformTwo = func as Func<double, RDD<dynamic>, RDD<dynamic>, RDD<dynamic>>;

                        if (transformOne != null)
                        {
                            // Direct cast: an unexpected proxy type fails with a descriptive
                            // InvalidCastException instead of a NullReferenceException.
                            JvmObjectReference jrdd = ((RDDIpcProxy)transformOne(time, rdd).RddProxy).JvmRddReference;
                            SerDe.Write(s, (byte)'j');
                            SerDe.Write(s, jrdd.Id);
                        }
                        else if (transformTwo != null)
                        {
                            // The second deserializer name is only sent for two-input transforms,
                            // so it must be read before the delegate is invoked.
                            string deserializer2 = SerDe.ReadString(s);
                            RDD<dynamic> rdd2 = new RDD<dynamic>(new RDDIpcProxy(jrdds[1]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer2));
                            JvmObjectReference jrdd = ((RDDIpcProxy)transformTwo(time, rdd, rdd2).RddProxy).JvmRddReference;
                            SerDe.Write(s, (byte)'j');
                            SerDe.Write(s, jrdd.Id);
                        }
                        else
                        {
                            ((Action<double, RDD<dynamic>>)func)(time, rdd);
                            SerDe.Write(s, (byte)'n');
                        }
                    }
                }
                catch (Exception e)
                {
                    //log exception only when callback socket is not shutdown explicitly
                    if (!callbackSocketShutdown)
                    {
                        logger.LogException(e);
                    }

                    // Stop serving this connection. After a failure the stream is either
                    // closed or no longer aligned on a message boundary, so looping back
                    // to ReadString would spin or block forever and pin this pool thread.
                    // Leaving the loop disposes the socket, which lets the peer observe
                    // the failure instead of waiting for a reply that will never come.
                    break;
                }
            }
        }
    }
    catch (Exception e)
    {
        logger.LogException(e);
    }

    logger.LogInfo("thread to process callback request exit");
}
|
||||
|
||||
public int StartCallback()
|
||||
{
|
||||
TcpListener callbackServer = new TcpListener(IPAddress.Parse("127.0.0.1"), 0);
|
||||
|
@ -145,72 +223,16 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc
|
|||
{
|
||||
try
|
||||
{
|
||||
using (Socket sock = callbackServer.AcceptSocket())
|
||||
using (var s = new NetworkStream(sock))
|
||||
ThreadPool.SetMaxThreads(10, 10);
|
||||
while (!callbackSocketShutdown)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
try
|
||||
{
|
||||
string cmd = SerDe.ReadString(s);
|
||||
if (cmd == "close")
|
||||
{
|
||||
logger.LogInfo("receive close cmd from Scala side");
|
||||
break;
|
||||
}
|
||||
else if (cmd == "callback")
|
||||
{
|
||||
int numRDDs = SerDe.ReadInt(s);
|
||||
var jrdds = new List<JvmObjectReference>();
|
||||
for (int i = 0; i < numRDDs; i++)
|
||||
{
|
||||
jrdds.Add(new JvmObjectReference(SerDe.ReadObjectId(s)));
|
||||
}
|
||||
double time = SerDe.ReadDouble(s);
|
||||
|
||||
IFormatter formatter = new BinaryFormatter();
|
||||
object func = formatter.Deserialize(new MemoryStream(SerDe.ReadBytes(s)));
|
||||
|
||||
string deserializer = SerDe.ReadString(s);
|
||||
RDD<dynamic> rdd = null;
|
||||
if (jrdds[0].Id != null)
|
||||
rdd = new RDD<dynamic>(new RDDIpcProxy(jrdds[0]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer));
|
||||
|
||||
if (func is Func<double, RDD<dynamic>, RDD<dynamic>>)
|
||||
{
|
||||
JvmObjectReference jrdd = (((Func<double, RDD<dynamic>, RDD<dynamic>>)func)(time, rdd).RddProxy as RDDIpcProxy).JvmRddReference;
|
||||
SerDe.Write(s, (byte)'j');
|
||||
SerDe.Write(s, jrdd.Id);
|
||||
}
|
||||
else if (func is Func<double, RDD<dynamic>, RDD<dynamic>, RDD<dynamic>>)
|
||||
{
|
||||
string deserializer2 = SerDe.ReadString(s);
|
||||
RDD<dynamic> rdd2 = new RDD<dynamic>(new RDDIpcProxy(jrdds[1]), sparkContext, (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializer2));
|
||||
JvmObjectReference jrdd = (((Func<double, RDD<dynamic>, RDD<dynamic>, RDD<dynamic>>)func)(time, rdd, rdd2).RddProxy as RDDIpcProxy).JvmRddReference;
|
||||
SerDe.Write(s, (byte)'j');
|
||||
SerDe.Write(s, jrdd.Id);
|
||||
}
|
||||
else
|
||||
{
|
||||
((Action<double, RDD<dynamic>>)func)(time, rdd);
|
||||
SerDe.Write(s, (byte)'n');
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
//log exception only when callback socket is not shutdown explicitly
|
||||
if (!callbackSocketShutdown)
|
||||
{
|
||||
logger.LogInfo(e.ToString());
|
||||
}
|
||||
}
|
||||
}
|
||||
Socket sock = callbackServer.AcceptSocket();
|
||||
ThreadPool.QueueUserWorkItem(new WaitCallback(ProcessCallbackRequest), sock);
|
||||
}
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
logger.LogInfo(e.ToString());
|
||||
logger.LogException(e);
|
||||
throw;
|
||||
}
|
||||
finally
|
||||
|
|
|
@ -5,7 +5,7 @@ package org.apache.spark.api.csharp
|
|||
|
||||
import java.io.{DataOutputStream, File, FileOutputStream, IOException}
|
||||
import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.{LinkedBlockingQueue, BlockingQueue, TimeUnit}
|
||||
|
||||
import io.netty.bootstrap.ServerBootstrap
|
||||
import io.netty.channel.nio.NioEventLoopGroup
|
||||
|
@ -30,8 +30,8 @@ class CSharpBackend {
|
|||
private[this] var bossGroup: EventLoopGroup = null
|
||||
|
||||
def init(): Int = {
|
||||
|
||||
bossGroup = new NioEventLoopGroup(2)
|
||||
// need at least 3 threads, use 10 here for safety
|
||||
bossGroup = new NioEventLoopGroup(10)
|
||||
val workerGroup = bossGroup
|
||||
val handler = new CSharpBackendHandler(this) //TODO - work with SparkR devs to make this configurable and reuse RBackend
|
||||
|
||||
|
@ -78,23 +78,29 @@ class CSharpBackend {
|
|||
bootstrap.childGroup().shutdownGracefully()
|
||||
}
|
||||
bootstrap = null
|
||||
// Send close to CSharp callback server.
|
||||
if (CSharpBackend.callbackSocket != null &&
|
||||
!CSharpBackend.callbackSocket.isClosed()) {
|
||||
try {
|
||||
println("Requesting to close a call back server.")
|
||||
val dos = new DataOutputStream(CSharpBackend.callbackSocket.getOutputStream())
|
||||
|
||||
// Send close to CSharp callback server.
|
||||
println("Requesting to close all call back sockets.")
|
||||
var socket: Socket = null
|
||||
do {
|
||||
socket = CSharpBackend.callbackSockets.poll()
|
||||
if (socket != null) {
|
||||
val dos = new DataOutputStream(socket.getOutputStream)
|
||||
SerDe.writeString(dos, "close")
|
||||
CSharpBackend.callbackSocket.close()
|
||||
CSharpBackend.callbackSocketShutdown = true
|
||||
try {
|
||||
socket.close()
|
||||
socket = null
|
||||
}
|
||||
}
|
||||
}
|
||||
} while (socket != null)
|
||||
CSharpBackend.callbackSocketShutdown = true
|
||||
}
|
||||
}
|
||||
|
||||
object CSharpBackend {
|
||||
// Channel to callback server.
|
||||
private[spark] var callbackSocket: Socket = null
|
||||
// Channels to callback server.
|
||||
private[spark] val callbackSockets: BlockingQueue[Socket] = new LinkedBlockingQueue[Socket]()
|
||||
@volatile private[spark] var callbackPort: Int = 0
|
||||
|
||||
// flag to denote whether the callback socket is shutdown explicitly
|
||||
@volatile private[spark] var callbackSocketShutdown: Boolean = false
|
||||
|
|
|
@ -61,23 +61,29 @@ class CSharpBackendHandler(server: CSharpBackend) extends SimpleChannelInboundHa
|
|||
assert(t == 'i')
|
||||
val port = readInt(dis)
|
||||
println("Connecting to a callback server at port " + port)
|
||||
CSharpBackend.callbackSocket = new Socket("localhost", port)
|
||||
CSharpBackend.callbackPort = port
|
||||
writeInt(dos, 0)
|
||||
writeType(dos, "void")
|
||||
case "closeCallback" =>
|
||||
// Send close to CSharp callback server.
|
||||
if (CSharpBackend.callbackSocket != null &&
|
||||
CSharpBackend.callbackSocket.isConnected()) {
|
||||
try {
|
||||
println("Requesting to close a call back server.")
|
||||
val os = new DataOutputStream(CSharpBackend.callbackSocket.getOutputStream())
|
||||
writeString(os, "close")
|
||||
CSharpBackend.callbackSocket.close()
|
||||
println("Requesting to close all call back sockets.")
|
||||
var socket: Socket = null
|
||||
do {
|
||||
socket = CSharpBackend.callbackSockets.poll()
|
||||
if (socket != null) {
|
||||
val dataOutputStream = new DataOutputStream(socket.getOutputStream)
|
||||
SerDe.writeString(dataOutputStream, "close")
|
||||
try {
|
||||
socket.close()
|
||||
socket = null
|
||||
}
|
||||
}
|
||||
writeInt(dos, 0)
|
||||
writeType(dos, "void")
|
||||
}
|
||||
} while (socket != null)
|
||||
CSharpBackend.callbackSocketShutdown = true
|
||||
|
||||
writeInt(dos, 0)
|
||||
writeType(dos, "void")
|
||||
|
||||
case _ => dos.writeInt(-1)
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -51,10 +51,16 @@ object CSharpDStream {
|
|||
}
|
||||
|
||||
def callCSharpTransform(rdds: List[Option[RDD[_]]], time: Time, rfunc: Array[Byte],
|
||||
deserializers: List[String]): Option[RDD[Array[Byte]]] = synchronized {
|
||||
deserializers: List[String]): Option[RDD[Array[Byte]]] = {
|
||||
var socket: Socket = null
|
||||
try {
|
||||
val dos = new DataOutputStream(CSharpBackend.callbackSocket.getOutputStream())
|
||||
val dis = new DataInputStream(CSharpBackend.callbackSocket.getInputStream())
|
||||
socket = CSharpBackend.callbackSockets.poll()
|
||||
if (socket == null) {
|
||||
socket = new Socket("localhost", CSharpBackend.callbackPort)
|
||||
}
|
||||
|
||||
val dos = new DataOutputStream(socket.getOutputStream())
|
||||
val dis = new DataInputStream(socket.getInputStream())
|
||||
|
||||
writeString(dos, "callback")
|
||||
writeInt(dos, rdds.size)
|
||||
|
@ -64,7 +70,9 @@ object CSharpDStream {
|
|||
writeBytes(dos, rfunc)
|
||||
deserializers.foreach(x => writeString(dos, x))
|
||||
dos.flush()
|
||||
Option(readObject(dis).asInstanceOf[JavaRDD[Array[Byte]]]).map(_.rdd)
|
||||
val result = Option(readObject(dis).asInstanceOf[JavaRDD[Array[Byte]]]).map(_.rdd)
|
||||
CSharpBackend.callbackSockets.offer(socket)
|
||||
result
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
// log exception only when callback socket is not shutdown explicitly
|
||||
|
@ -74,6 +82,13 @@ object CSharpDStream {
|
|||
e.printStackTrace()
|
||||
}
|
||||
|
||||
// close this socket if error happen
|
||||
if (socket != null) {
|
||||
try {
|
||||
socket.close()
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче