Full support for multithreaded applications (#641)

This commit is contained in:
Andrew Fogarty 2020-10-05 17:31:54 -07:00 коммит произвёл GitHub
Родитель 7bcd2a5060
Коммит a67ad5907e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
14 изменённых файлов: 618 добавлений и 43 удалений

Просмотреть файл

@ -0,0 +1,121 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Threading;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Services;
using Microsoft.Spark.Sql;
using Xunit;
namespace Microsoft.Spark.E2ETest.IpcTests
{
[Collection("Spark E2E Tests")]
public class JvmThreadPoolGCTests
{
    private readonly ILoggerService _loggerService;
    private readonly SparkSession _spark;
    private readonly IJvmBridge _jvmBridge;

    /// <summary>
    /// Captures the shared SparkSession from the fixture along with the JVM bridge that
    /// backs it, so the tests can talk to the JVM directly.
    /// </summary>
    /// <param name="fixture">Fixture that owns the shared SparkSession.</param>
    public JvmThreadPoolGCTests(SparkFixture fixture)
    {
        _loggerService = LoggerServiceFactory.GetLogger(typeof(JvmThreadPoolGCTests));
        _spark = fixture.Spark;
        _jvmBridge = ((IJvmObjectReferenceProvider)_spark).Reference.Jvm;
    }

    /// <summary>
    /// Test that the active SparkSession is thread-specific.
    /// </summary>
    [Fact]
    public void TestThreadLocalSessions()
    {
        SparkSession.ClearActiveSession();

        void testChildThread(string appName)
        {
            // An exception that escapes a thread's entry point is unhandled and takes the
            // whole test process down instead of failing this one test. Capture it on the
            // child thread and rethrow it on the test thread after Join().
            Exception childFailure = null;

            var thread = new Thread(() =>
            {
                try
                {
                    Assert.Null(SparkSession.GetActiveSession());

                    SparkSession.SetActiveSession(
                        SparkSession.Builder().AppName(appName).GetOrCreate());

                    // Since we are in the child thread, GetActiveSession() should return
                    // the child SparkSession.
                    SparkSession activeSession = SparkSession.GetActiveSession();
                    Assert.NotNull(activeSession);
                    Assert.Equal(appName, activeSession.Conf().Get("spark.app.name", null));
                }
                catch (Exception e)
                {
                    childFailure = e;
                }
            });

            thread.Start();
            thread.Join();

            if (childFailure != null)
            {
                // Rethrow with the original stack trace so the failing assertion is
                // reported at its real location.
                System.Runtime.ExceptionServices.ExceptionDispatchInfo
                    .Capture(childFailure)
                    .Throw();
            }
        }

        for (int i = 0; i < 5; ++i)
        {
            testChildThread(i.ToString());
        }

        // Nothing the child threads did may leak into this thread's active session.
        Assert.Null(SparkSession.GetActiveSession());
    }

    /// <summary>
    /// Monitor a thread via the JvmThreadPoolGC.
    /// </summary>
    [Fact]
    public void TestTryAddThread()
    {
        using var threadPool = new JvmThreadPoolGC(
            _loggerService, _jvmBridge, TimeSpan.FromMinutes(30));

        var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
        thread.Start();

        Assert.True(threadPool.TryAddThread(thread));
        // Subsequent call should return false, because the thread has already been added.
        Assert.False(threadPool.TryAddThread(thread));

        thread.Join();
    }

    /// <summary>
    /// Create a Spark worker thread in the JVM ThreadPool then remove it directly through
    /// the JvmBridge.
    /// </summary>
    [Fact]
    public void TestRmThread()
    {
        // Create a thread and ensure that it is initialized in the JVM ThreadPool.
        var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
        thread.Start();
        thread.Join();

        // First call should return true. Second call should return false.
        Assert.True(
            (bool)_jvmBridge.CallStaticJavaMethod(
                "DotnetHandler", "rmThread", thread.ManagedThreadId));
        Assert.False(
            (bool)_jvmBridge.CallStaticJavaMethod(
                "DotnetHandler", "rmThread", thread.ManagedThreadId));
    }

    /// <summary>
    /// Test that the GC interval configuration defaults to 5 minutes, and can be updated
    /// correctly by setting the environment variable.
    /// </summary>
    [Fact]
    public void TestIntervalConfiguration()
    {
        const string intervalEnvVar = "DOTNET_JVM_THREAD_GC_INTERVAL";

        // Default value is 5 minutes.
        Assert.Null(Environment.GetEnvironmentVariable(intervalEnvVar));
        Assert.Equal(
            TimeSpan.FromMinutes(5),
            SparkEnvironment.ConfigurationService.JvmThreadGCInterval);

        try
        {
            // Test a custom value.
            Environment.SetEnvironmentVariable(intervalEnvVar, "1:30:00");
            Assert.Equal(
                TimeSpan.FromMinutes(90),
                SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
        }
        finally
        {
            // Environment variables are process-wide. Put the variable back to the unset
            // state asserted above so the override cannot leak into other tests or make
            // this test fail when it is run again in the same process.
            Environment.SetEnvironmentVariable(intervalEnvVar, null);
        }
    }
}
}

Просмотреть файл

@ -35,6 +35,10 @@ namespace Microsoft.Spark.E2ETest.IpcTests
Assert.IsType<Builder>(SparkSession.Builder()); Assert.IsType<Builder>(SparkSession.Builder());
SparkSession.ClearActiveSession();
SparkSession.SetActiveSession(_spark);
Assert.IsType<SparkSession>(SparkSession.GetActiveSession());
SparkSession.ClearDefaultSession(); SparkSession.ClearDefaultSession();
SparkSession.SetDefaultSession(_spark); SparkSession.SetDefaultSession(_spark);
Assert.IsType<SparkSession>(SparkSession.GetDefaultSession()); Assert.IsType<SparkSession>(SparkSession.GetDefaultSession());
@ -76,7 +80,7 @@ namespace Microsoft.Spark.E2ETest.IpcTests
/// </summary> /// </summary>
[Fact] [Fact]
public void TestCreateDataFrame() public void TestCreateDataFrame()
{ {
// Calling CreateDataFrame with schema // Calling CreateDataFrame with schema
{ {
var data = new List<GenericRow> var data = new List<GenericRow>

Просмотреть файл

@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using System.Net; using System.Net;
using System.Text; using System.Text;
using System.Threading;
using Microsoft.Spark.Network; using Microsoft.Spark.Network;
using Microsoft.Spark.Services; using Microsoft.Spark.Services;
@ -35,6 +36,7 @@ namespace Microsoft.Spark.Interop.Ipc
private readonly ILoggerService _logger = private readonly ILoggerService _logger =
LoggerServiceFactory.GetLogger(typeof(JvmBridge)); LoggerServiceFactory.GetLogger(typeof(JvmBridge));
private readonly int _portNumber; private readonly int _portNumber;
private readonly JvmThreadPoolGC _jvmThreadPoolGC;
internal JvmBridge(int portNumber) internal JvmBridge(int portNumber)
{ {
@ -45,6 +47,9 @@ namespace Microsoft.Spark.Interop.Ipc
_portNumber = portNumber; _portNumber = portNumber;
_logger.LogInfo($"JvMBridge port is {portNumber}"); _logger.LogInfo($"JvMBridge port is {portNumber}");
_jvmThreadPoolGC = new JvmThreadPoolGC(
_logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
} }
private ISocketWrapper GetConnection() private ISocketWrapper GetConnection()
@ -158,11 +163,13 @@ namespace Microsoft.Spark.Interop.Ipc
ISocketWrapper socket = null; ISocketWrapper socket = null;
try try
{ {
Thread thread = Thread.CurrentThread;
MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream(); MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
payloadMemoryStream.Position = 0; payloadMemoryStream.Position = 0;
PayloadHelper.BuildPayload( PayloadHelper.BuildPayload(
payloadMemoryStream, payloadMemoryStream,
isStatic, isStatic,
thread.ManagedThreadId,
classNameOrJvmObjectReference, classNameOrJvmObjectReference,
methodName, methodName,
args); args);
@ -176,6 +183,8 @@ namespace Microsoft.Spark.Interop.Ipc
(int)payloadMemoryStream.Position); (int)payloadMemoryStream.Position);
outputStream.Flush(); outputStream.Flush();
_jvmThreadPoolGC.TryAddThread(thread);
Stream inputStream = socket.InputStream; Stream inputStream = socket.InputStream;
int isMethodCallFailed = SerDe.ReadInt32(inputStream); int isMethodCallFailed = SerDe.ReadInt32(inputStream);
if (isMethodCallFailed != 0) if (isMethodCallFailed != 0)
@ -410,6 +419,7 @@ namespace Microsoft.Spark.Interop.Ipc
public void Dispose() public void Dispose()
{ {
_jvmThreadPoolGC.Dispose();
while (_sockets.TryDequeue(out ISocketWrapper socket)) while (_sockets.TryDequeue(out ISocketWrapper socket))
{ {
if (socket != null) if (socket != null)

Просмотреть файл

@ -0,0 +1,149 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using Microsoft.Spark.Services;
namespace Microsoft.Spark.Interop.Ipc
{
/// <summary>
/// In .NET for Apache Spark, we maintain a 1-to-1 mapping between .NET application threads
/// and corresponding JVM threads. When a .NET thread calls a Spark API, that call is executed
/// by its corresponding JVM thread. This functionality allows for multithreaded applications
/// with thread-local variables.
///
/// This class keeps track of the .NET application thread lifecycle. When a .NET application
/// thread is no longer alive, this class submits an rmThread command to the JVM backend to
/// dispose of its corresponding JVM thread. All methods are thread-safe.
/// </summary>
internal class JvmThreadPoolGC : IDisposable
{
    private readonly ILoggerService _loggerService;
    private readonly IJvmBridge _jvmBridge;
    private readonly TimeSpan _threadGCInterval;

    // Threads currently being monitored, keyed by ManagedThreadId.
    private readonly ConcurrentDictionary<int, Thread> _activeThreads;

    // Guards creation and disposal of _activeThreadGCTimer, and _isDisposed.
    private readonly object _activeThreadGCTimerLock;

    // Periodic timer that runs GCThreads(); null while there is nothing to monitor.
    private Timer _activeThreadGCTimer;

    // Set once Dispose() has been called; from then on no new timer may be created.
    private bool _isDisposed;

    /// <summary>
    /// Construct the JvmThreadPoolGC.
    /// </summary>
    /// <param name="loggerService">Logger service.</param>
    /// <param name="jvmBridge">The JvmBridge used to call JVM methods.</param>
    /// <param name="threadGCInterval">The interval to GC finished threads.</param>
    public JvmThreadPoolGC(ILoggerService loggerService, IJvmBridge jvmBridge, TimeSpan threadGCInterval)
    {
        _loggerService = loggerService;
        _jvmBridge = jvmBridge;
        _threadGCInterval = threadGCInterval;
        _activeThreads = new ConcurrentDictionary<int, Thread>();
        _activeThreadGCTimerLock = new object();
        _activeThreadGCTimer = null;
        _isDisposed = false;
    }

    /// <summary>
    /// Dispose of the GC timer and run a final round of thread GC.
    /// </summary>
    public void Dispose()
    {
        lock (_activeThreadGCTimerLock)
        {
            // Mark as disposed before the final GC pass. That pass issues rmThread calls
            // through the bridge, and a TryAddThread() call arriving during or after it
            // would otherwise see a null timer with a non-empty thread map and create a
            // new timer that nobody ever disposes.
            _isDisposed = true;

            _activeThreadGCTimer?.Dispose();
            _activeThreadGCTimer = null;
        }

        GCThreads();
    }

    /// <summary>
    /// Try to start monitoring a thread. The periodic GC timer is started lazily on the
    /// first call that leaves at least one monitored thread, unless this instance has
    /// already been disposed.
    /// </summary>
    /// <param name="thread">The thread to add.</param>
    /// <returns>True if success, false if already added.</returns>
    public bool TryAddThread(Thread thread)
    {
        bool returnValue = _activeThreads.TryAdd(thread.ManagedThreadId, thread);

        // Initialize the GC timer if necessary. The unlocked check keeps the common case
        // (timer already running) from taking the lock on every call.
        if (_activeThreadGCTimer == null)
        {
            lock (_activeThreadGCTimerLock)
            {
                if (!_isDisposed && _activeThreadGCTimer == null && _activeThreads.Count > 0)
                {
                    _activeThreadGCTimer = new Timer(
                        (state) => GCThreads(),
                        null,
                        _threadGCInterval,
                        _threadGCInterval);
                }
            }
        }

        return returnValue;
    }

    /// <summary>
    /// Try to remove a thread from the pool. If the removal is successful, then the
    /// corresponding JVM thread will also be disposed.
    /// </summary>
    /// <param name="threadId">The ID of the thread to remove.</param>
    /// <returns>
    /// True if the thread was being monitored and the JVM confirmed the removal; false if
    /// the thread was not being monitored or the JVM's rmThread call returned false.
    /// </returns>
    private bool TryDisposeJvmThread(int threadId)
    {
        if (_activeThreads.TryRemove(threadId, out _))
        {
            // _activeThreads does not have ownership of the threads on the .NET side. This
            // class does not need to call Join() on the .NET Thread. However, this class is
            // responsible for sending the rmThread command to the JVM to trigger disposal
            // of the corresponding JVM thread.
            if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId))
            {
                _loggerService.LogDebug($"GC'd JVM thread {threadId}.");
                return true;
            }
            else
            {
                _loggerService.LogWarn(
                    $"rmThread returned false for JVM thread {threadId}. " +
                    $"Either thread does not exist or has already been GC'd.");
            }
        }

        return false;
    }

    /// <summary>
    /// Remove any threads that are no longer active, then stop the timer if nothing is
    /// left to monitor.
    /// </summary>
    private void GCThreads()
    {
        foreach (KeyValuePair<int, Thread> kvp in _activeThreads)
        {
            if (!kvp.Value.IsAlive)
            {
                TryDisposeJvmThread(kvp.Key);
            }
        }

        lock (_activeThreadGCTimerLock)
        {
            // Dispose of the timer if there are no threads to monitor.
            if (_activeThreadGCTimer != null && _activeThreads.IsEmpty)
            {
                _activeThreadGCTimer.Dispose();
                _activeThreadGCTimer = null;
            }
        }
    }
}
}

Просмотреть файл

@ -27,7 +27,7 @@ namespace Microsoft.Spark.Interop.Ipc
private static readonly byte[] s_timestampTypeId = new[] { (byte)'t' }; private static readonly byte[] s_timestampTypeId = new[] { (byte)'t' };
private static readonly byte[] s_jvmObjectTypeId = new[] { (byte)'j' }; private static readonly byte[] s_jvmObjectTypeId = new[] { (byte)'j' };
private static readonly byte[] s_byteArrayTypeId = new[] { (byte)'r' }; private static readonly byte[] s_byteArrayTypeId = new[] { (byte)'r' };
private static readonly byte[] s_doubleArrayArrayTypeId = new[] { ( byte)'A' }; private static readonly byte[] s_doubleArrayArrayTypeId = new[] { (byte)'A' };
private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' }; private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' };
private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' }; private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' };
private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' }; private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' };
@ -39,6 +39,7 @@ namespace Microsoft.Spark.Interop.Ipc
internal static void BuildPayload( internal static void BuildPayload(
MemoryStream destination, MemoryStream destination,
bool isStaticMethod, bool isStaticMethod,
int threadId,
object classNameOrJvmObjectReference, object classNameOrJvmObjectReference,
string methodName, string methodName,
object[] args) object[] args)
@ -48,6 +49,7 @@ namespace Microsoft.Spark.Interop.Ipc
destination.Position += sizeof(int); destination.Position += sizeof(int);
SerDe.Write(destination, isStaticMethod); SerDe.Write(destination, isStaticMethod);
SerDe.Write(destination, threadId);
SerDe.Write(destination, classNameOrJvmObjectReference.ToString()); SerDe.Write(destination, classNameOrJvmObjectReference.ToString());
SerDe.Write(destination, methodName); SerDe.Write(destination, methodName);
SerDe.Write(destination, args.Length); SerDe.Write(destination, args.Length);
@ -140,7 +142,7 @@ namespace Microsoft.Spark.Interop.Ipc
SerDe.Write(destination, d); SerDe.Write(destination, d);
} }
break; break;
case double[][] argDoubleArrayArray: case double[][] argDoubleArrayArray:
SerDe.Write(destination, s_doubleArrayArrayTypeId); SerDe.Write(destination, s_doubleArrayArrayTypeId);
SerDe.Write(destination, argDoubleArrayArray.Length); SerDe.Write(destination, argDoubleArrayArray.Length);

Просмотреть файл

@ -33,6 +33,18 @@ namespace Microsoft.Spark.Services
private string _workerPath; private string _workerPath;
/// <summary>
/// How often to run GC on JVM ThreadPool threads. The value is read from the
/// DOTNET_JVM_THREAD_GC_INTERVAL environment variable on every access and defaults to
/// 5 minutes when the variable is unset or empty.
/// </summary>
/// <exception cref="FormatException">
/// The environment variable is set but is not a valid <see cref="TimeSpan"/> string.
/// </exception>
/// <exception cref="OverflowException">
/// The environment variable holds a value outside the range of <see cref="TimeSpan"/>.
/// </exception>
public TimeSpan JvmThreadGCInterval
{
    get
    {
        string envVar = Environment.GetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL");
        if (string.IsNullOrEmpty(envVar))
        {
            return TimeSpan.FromMinutes(5);
        }

        // The variable is machine-readable configuration, so parse it with the invariant
        // culture. The single-argument TimeSpan.Parse follows the current culture, which
        // would make the same value parse differently depending on the host's locale.
        return TimeSpan.Parse(envVar, System.Globalization.CultureInfo.InvariantCulture);
    }
}
/// <summary> /// <summary>
/// Returns the port number for socket communication between JVM and CLR. /// Returns the port number for socket communication between JVM and CLR.
/// </summary> /// </summary>

Просмотреть файл

@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license. // The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information. // See the LICENSE file in the project root for more information.
using System;
namespace Microsoft.Spark.Services namespace Microsoft.Spark.Services
{ {
/// <summary> /// <summary>
@ -9,6 +11,11 @@ namespace Microsoft.Spark.Services
/// </summary> /// </summary>
internal interface IConfigurationService internal interface IConfigurationService
{ {
/// <summary>
/// How often to run GC on JVM ThreadPool threads, i.e. the interval between passes that
/// clean up JVM threads whose corresponding .NET threads are no longer alive.
/// </summary>
TimeSpan JvmThreadGCInterval { get; }
/// <summary> /// <summary>
/// The port number used for communicating with the .NET backend process. /// The port number used for communicating with the .NET backend process.
/// </summary> /// </summary>

Просмотреть файл

@ -61,10 +61,40 @@ namespace Microsoft.Spark.Sql
/// <returns>Builder object</returns> /// <returns>Builder object</returns>
public static Builder Builder() => new Builder(); public static Builder Builder() => new Builder();
/// Note that *ActiveSession() APIs are not exposed because these APIs work with a /// <summary>
/// thread-local variable, which stores the session variable. Since the Netty server /// Changes the SparkSession that will be returned in this thread when
/// that handles the requests is multi-threaded, any thread can invoke these APIs, /// <see cref="Builder.GetOrCreate"/> is called. This can be used to ensure that a given
/// resulting in unexpected behaviors if different threads are used. /// thread receives a SparkSession with an isolated session, instead of the global
/// (first created) context.
/// </summary>
/// <param name="session">SparkSession object</param>
public static void SetActiveSession(SparkSession session)
{
    // Forward to the JVM-side static "setActiveSession" through the bridge that backs
    // the given session.
    session._jvmObject.Jvm.CallStaticJavaMethod(
        s_sparkSessionClassName,
        "setActiveSession",
        session);
}
/// <summary>
/// Clears the active SparkSession for the current thread. After this call,
/// <see cref="Builder.GetOrCreate"/> returns the first created context rather than a
/// thread-local override.
/// </summary>
public static void ClearActiveSession()
{
    SparkEnvironment.JvmBridge.CallStaticJavaMethod(
        s_sparkSessionClassName,
        "clearActiveSession");
}
/// <summary>
/// Returns the active SparkSession for the current thread, returned by the builder.
/// </summary>
/// <returns>
/// The active SparkSession, or null when the option returned by the JVM is not defined
/// (for example, when calling this function on executors).
/// </returns>
public static SparkSession GetActiveSession()
{
    var jvmOption = (JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod(
        s_sparkSessionClassName, "getActiveSession");
    var activeSession = new Option(jvmOption);

    // No session is active on this thread.
    if (!activeSession.IsDefined())
    {
        return null;
    }

    return new SparkSession((JvmObjectReference)activeSession.Get());
}
/// <summary> /// <summary>
/// Sets the default SparkSession that is returned by the builder. /// Sets the default SparkSession that is returned by the builder.

Просмотреть файл

@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import scala.collection.mutable.HashMap
import scala.language.existentials
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.api.dotnet.SerDe._
import org.apache.spark.internal.Logging import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
import scala.collection.mutable.HashMap
import scala.language.existentials
/** /**
* Handler for DotnetBackend. * Handler for DotnetBackend.
* This implementation is similar to RBackendHandler. * This implementation is similar to RBackendHandler.
@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
// First bit is isStatic // First bit is isStatic
val isStatic = readBoolean(dis) val isStatic = readBoolean(dis)
val threadId = readInt(dis)
val objId = readString(dis) val objId = readString(dis)
val methodName = readString(dis) val methodName = readString(dis)
val numArgs = readInt(dis) val numArgs = readInt(dis)
@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
logError(s"Removing $objId failed", e) logError(s"Removing $objId failed", e)
writeInt(dos, -1) writeInt(dos, -1)
} }
case "rmThread" =>
try {
assert(readObjectType(dis) == 'i')
val threadToDelete = readInt(dis)
val result = ThreadPool.tryDeleteThread(threadToDelete)
writeInt(dos, 0)
writeObject(dos, result.asInstanceOf[AnyRef])
} catch {
case e: Exception =>
logError(s"Removing thread $threadId failed", e)
writeInt(dos, -1)
}
case "connectCallback" => case "connectCallback" =>
assert(readObjectType(dis) == 'c') assert(readObjectType(dis) == 'c')
val address = readString(dis) val address = readString(dis)
assert(readObjectType(dis) == 'i') assert(readObjectType(dis) == 'i')
val port = readInt(dis) val port = readInt(dis)
DotnetBackend.setCallbackClient(address, port); DotnetBackend.setCallbackClient(address, port)
writeInt(dos, 0) writeInt(dos, 0)
writeType(dos, "void") writeType(dos, "void")
case "closeCallback" => case "closeCallback" =>
@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
case _ => dos.writeInt(-1) case _ => dos.writeInt(-1)
} }
} else { } else {
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) ThreadPool
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
} }
bos.toByteArray bos.toByteArray
@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
"invalid method " + methodName + " for object " + objId) "invalid method " + methodName + " for object " + objId)
} }
} catch { } catch {
case e: Exception => case e: Throwable =>
val jvmObj = JVMObjectTracker.get(objId) val jvmObj = JVMObjectTracker.get(objId)
val jvmObjName = jvmObj match { val jvmObjName = jvmObj match {
case Some(jObj) => jObj.getClass.getName case Some(jObj) => jObj.getClass.getName
case None => "NullObject" case None => "NullObject"
} }
val argsStr = args.map(arg => { val argsStr = args
if (arg != null) { .map(arg => {
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" if (arg != null) {
} else { s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
"[Value: NULL]" } else {
} "[Value: NULL]"
}).mkString(", ") }
})
.mkString(", ")
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")

Просмотреть файл

@ -0,0 +1,64 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import java.util.concurrent.{ExecutorService, Executors}
import scala.collection.mutable
/**
 * Registry of single-thread executors, one per thread id. Work submitted under the same
 * id always runs on the same JVM thread, which is what gives each C# thread its own
 * matching Java thread.
 */
object ThreadPool {

  /**
   * Executors currently registered, keyed by thread id. Only touched from the
   * `synchronized` methods below.
   */
  private val executors: mutable.HashMap[Int, ExecutorService] =
    new mutable.HashMap[Int, ExecutorService]()

  /**
   * Execute a task on the thread associated with the given id and block until it has
   * finished.
   *
   * @param threadId Integer id of the thread.
   * @param task Function to run on the thread.
   */
  def run(threadId: Int, task: () => Unit): Unit = {
    val work: Runnable = new Runnable {
      override def run(): Unit = task()
    }
    // Waiting on the future makes the call synchronous for the caller.
    getOrCreateExecutor(threadId).submit(work).get()
  }

  /**
   * Unregister the executor for a thread id and shut it down.
   *
   * @param threadId Integer id of the thread.
   * @return True if an executor was registered for the id, false otherwise.
   */
  def tryDeleteThread(threadId: Int): Boolean = synchronized {
    val removed = executors.remove(threadId)
    removed.foreach(_.shutdown())
    removed.isDefined
  }

  /**
   * Look up the executor registered for an id, registering a fresh single-thread
   * executor first if there is none.
   *
   * @param id Integer id of the thread.
   * @return The new or existing executor with the given id.
   */
  private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
    executors.get(id) match {
      case Some(existing) => existing
      case None =>
        val created = Executors.newSingleThreadExecutor()
        executors.put(id, created)
        created
    }
  }
}

Просмотреть файл

@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import scala.collection.mutable.HashMap
import scala.language.existentials
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.api.dotnet.SerDe._
import org.apache.spark.internal.Logging import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
import scala.collection.mutable.HashMap
import scala.language.existentials
/** /**
* Handler for DotnetBackend. * Handler for DotnetBackend.
* This implementation is similar to RBackendHandler. * This implementation is similar to RBackendHandler.
@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
// First bit is isStatic // First bit is isStatic
val isStatic = readBoolean(dis) val isStatic = readBoolean(dis)
val threadId = readInt(dis)
val objId = readString(dis) val objId = readString(dis)
val methodName = readString(dis) val methodName = readString(dis)
val numArgs = readInt(dis) val numArgs = readInt(dis)
@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
logError(s"Removing $objId failed", e) logError(s"Removing $objId failed", e)
writeInt(dos, -1) writeInt(dos, -1)
} }
case "rmThread" =>
try {
assert(readObjectType(dis) == 'i')
val threadToDelete = readInt(dis)
val result = ThreadPool.tryDeleteThread(threadToDelete)
writeInt(dos, 0)
writeObject(dos, result.asInstanceOf[AnyRef])
} catch {
case e: Exception =>
logError(s"Removing thread $threadId failed", e)
writeInt(dos, -1)
}
case "connectCallback" => case "connectCallback" =>
assert(readObjectType(dis) == 'c') assert(readObjectType(dis) == 'c')
val address = readString(dis) val address = readString(dis)
assert(readObjectType(dis) == 'i') assert(readObjectType(dis) == 'i')
val port = readInt(dis) val port = readInt(dis)
DotnetBackend.setCallbackClient(address, port); DotnetBackend.setCallbackClient(address, port)
writeInt(dos, 0) writeInt(dos, 0)
writeType(dos, "void") writeType(dos, "void")
case "closeCallback" => case "closeCallback" =>
@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
case _ => dos.writeInt(-1) case _ => dos.writeInt(-1)
} }
} else { } else {
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) ThreadPool
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
} }
bos.toByteArray bos.toByteArray
@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
"invalid method " + methodName + " for object " + objId) "invalid method " + methodName + " for object " + objId)
} }
} catch { } catch {
case e: Exception => case e: Throwable =>
val jvmObj = JVMObjectTracker.get(objId) val jvmObj = JVMObjectTracker.get(objId)
val jvmObjName = jvmObj match { val jvmObjName = jvmObj match {
case Some(jObj) => jObj.getClass.getName case Some(jObj) => jObj.getClass.getName
case None => "NullObject" case None => "NullObject"
} }
val argsStr = args.map(arg => { val argsStr = args
if (arg != null) { .map(arg => {
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" if (arg != null) {
} else { s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
"[Value: NULL]" } else {
} "[Value: NULL]"
}).mkString(", ") }
})
.mkString(", ")
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")

Просмотреть файл

@ -0,0 +1,64 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import java.util.concurrent.{ExecutorService, Executors}
import scala.collection.mutable
/**
 * Registry of single-thread executors, one per thread id. Work submitted under the same
 * id always runs on the same JVM thread, which is what gives each C# thread its own
 * matching Java thread.
 */
object ThreadPool {

  /**
   * Executors currently registered, keyed by thread id. Only touched from the
   * `synchronized` methods below.
   */
  private val executors: mutable.HashMap[Int, ExecutorService] =
    new mutable.HashMap[Int, ExecutorService]()

  /**
   * Execute a task on the thread associated with the given id and block until it has
   * finished.
   *
   * @param threadId Integer id of the thread.
   * @param task Function to run on the thread.
   */
  def run(threadId: Int, task: () => Unit): Unit = {
    val work: Runnable = new Runnable {
      override def run(): Unit = task()
    }
    // Waiting on the future makes the call synchronous for the caller.
    getOrCreateExecutor(threadId).submit(work).get()
  }

  /**
   * Unregister the executor for a thread id and shut it down.
   *
   * @param threadId Integer id of the thread.
   * @return True if an executor was registered for the id, false otherwise.
   */
  def tryDeleteThread(threadId: Int): Boolean = synchronized {
    val removed = executors.remove(threadId)
    removed.foreach(_.shutdown())
    removed.isDefined
  }

  /**
   * Look up the executor registered for an id, registering a fresh single-thread
   * executor first if there is none.
   *
   * @param id Integer id of the thread.
   * @return The new or existing executor with the given id.
   */
  private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
    executors.get(id) match {
      case Some(existing) => existing
      case None =>
        val created = Executors.newSingleThreadExecutor()
        executors.put(id, created)
        created
    }
  }
}

Просмотреть файл

@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
// First bit is isStatic // First bit is isStatic
val isStatic = readBoolean(dis) val isStatic = readBoolean(dis)
val threadId = readInt(dis)
val objId = readString(dis) val objId = readString(dis)
val methodName = readString(dis) val methodName = readString(dis)
val numArgs = readInt(dis) val numArgs = readInt(dis)
@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
logError(s"Removing $objId failed", e) logError(s"Removing $objId failed", e)
writeInt(dos, -1) writeInt(dos, -1)
} }
case "rmThread" =>
try {
assert(readObjectType(dis) == 'i')
val threadToDelete = readInt(dis)
val result = ThreadPool.tryDeleteThread(threadToDelete)
writeInt(dos, 0)
writeObject(dos, result.asInstanceOf[AnyRef])
} catch {
case e: Exception =>
logError(s"Removing thread $threadId failed", e)
writeInt(dos, -1)
}
case "connectCallback" => case "connectCallback" =>
assert(readObjectType(dis) == 'c') assert(readObjectType(dis) == 'c')
val address = readString(dis) val address = readString(dis)
assert(readObjectType(dis) == 'i') assert(readObjectType(dis) == 'i')
val port = readInt(dis) val port = readInt(dis)
DotnetBackend.setCallbackClient(address, port); DotnetBackend.setCallbackClient(address, port)
writeInt(dos, 0) writeInt(dos, 0)
writeType(dos, "void") writeType(dos, "void")
case "closeCallback" => case "closeCallback" =>
@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
case _ => dos.writeInt(-1) case _ => dos.writeInt(-1)
} }
} else { } else {
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) ThreadPool
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
} }
bos.toByteArray bos.toByteArray
@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
"invalid method " + methodName + " for object " + objId) "invalid method " + methodName + " for object " + objId)
} }
} catch { } catch {
case e: Exception => case e: Throwable =>
val jvmObj = JVMObjectTracker.get(objId) val jvmObj = JVMObjectTracker.get(objId)
val jvmObjName = jvmObj match { val jvmObjName = jvmObj match {
case Some(jObj) => jObj.getClass.getName case Some(jObj) => jObj.getClass.getName
case None => "NullObject" case None => "NullObject"
} }
val argsStr = args.map(arg => { val argsStr = args
if (arg != null) { .map(arg => {
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" if (arg != null) {
} else { s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
"[Value: NULL]" } else {
} "[Value: NULL]"
}).mkString(", ") }
})
.mkString(", ")
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")

Просмотреть файл

@ -0,0 +1,64 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import java.util.concurrent.{ExecutorService, Executors}
import scala.collection.mutable
/**
 * Registry of single-thread executors, one per thread id. Work submitted under the same
 * id always runs on the same JVM thread, which is what gives each C# thread its own
 * matching Java thread.
 */
object ThreadPool {

  /**
   * Executors currently registered, keyed by thread id. Only touched from the
   * `synchronized` methods below.
   */
  private val executors: mutable.HashMap[Int, ExecutorService] =
    new mutable.HashMap[Int, ExecutorService]()

  /**
   * Execute a task on the thread associated with the given id and block until it has
   * finished.
   *
   * @param threadId Integer id of the thread.
   * @param task Function to run on the thread.
   */
  def run(threadId: Int, task: () => Unit): Unit = {
    val work: Runnable = new Runnable {
      override def run(): Unit = task()
    }
    // Waiting on the future makes the call synchronous for the caller.
    getOrCreateExecutor(threadId).submit(work).get()
  }

  /**
   * Unregister the executor for a thread id and shut it down.
   *
   * @param threadId Integer id of the thread.
   * @return True if an executor was registered for the id, false otherwise.
   */
  def tryDeleteThread(threadId: Int): Boolean = synchronized {
    val removed = executors.remove(threadId)
    removed.foreach(_.shutdown())
    removed.isDefined
  }

  /**
   * Look up the executor registered for an id, registering a fresh single-thread
   * executor first if there is none.
   *
   * @param id Integer id of the thread.
   * @return The new or existing executor with the given id.
   */
  private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
    executors.get(id) match {
      case Some(existing) => existing
      case None =>
        val created = Executors.newSingleThreadExecutor()
        executors.put(id, created)
        created
    }
  }
}