зеркало из https://github.com/dotnet/spark.git
Full support for multithreaded applications (#641)
This commit is contained in:
Родитель
7bcd2a5060
Коммит
a67ad5907e
|
@ -0,0 +1,121 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Threading;
|
||||
using Microsoft.Spark.Interop;
|
||||
using Microsoft.Spark.Interop.Ipc;
|
||||
using Microsoft.Spark.Services;
|
||||
using Microsoft.Spark.Sql;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.Spark.E2ETest.IpcTests
|
||||
{
|
||||
[Collection("Spark E2E Tests")]
public class JvmThreadPoolGCTests
{
    /// <summary>
    /// Name of the environment variable that overrides the JVM thread GC interval.
    /// </summary>
    private const string GCIntervalEnvVar = "DOTNET_JVM_THREAD_GC_INTERVAL";

    private readonly ILoggerService _loggerService;
    private readonly SparkSession _spark;
    private readonly IJvmBridge _jvmBridge;

    /// <summary>
    /// Captures the shared SparkSession from the fixture along with the JVM bridge that
    /// backs it, so tests can issue raw JVM calls.
    /// </summary>
    /// <param name="fixture">Fixture that owns the SparkSession used by the tests.</param>
    public JvmThreadPoolGCTests(SparkFixture fixture)
    {
        _loggerService = LoggerServiceFactory.GetLogger(typeof(JvmThreadPoolGCTests));
        _spark = fixture.Spark;
        _jvmBridge = ((IJvmObjectReferenceProvider)_spark).Reference.Jvm;
    }

    /// <summary>
    /// Test that the active SparkSession is thread-specific.
    /// </summary>
    [Fact]
    public void TestThreadLocalSessions()
    {
        SparkSession.ClearActiveSession();

        void testChildThread(string appName)
        {
            var thread = new Thread(() =>
            {
                // A fresh thread must not inherit an active session.
                Assert.Null(SparkSession.GetActiveSession());

                SparkSession.SetActiveSession(
                    SparkSession.Builder().AppName(appName).GetOrCreate());

                // Since we are in the child thread, GetActiveSession() should return the child
                // SparkSession.
                SparkSession activeSession = SparkSession.GetActiveSession();
                Assert.NotNull(activeSession);
                Assert.Equal(appName, activeSession.Conf().Get("spark.app.name", null));
            });

            thread.Start();
            thread.Join();
        }

        for (int i = 0; i < 5; ++i)
        {
            testChildThread(i.ToString());
        }

        // Sessions set in the child threads must not be visible from this thread.
        Assert.Null(SparkSession.GetActiveSession());
    }

    /// <summary>
    /// Monitor a thread via the JvmThreadPoolGC.
    /// </summary>
    [Fact]
    public void TestTryAddThread()
    {
        // The long interval keeps the periodic GC from firing while the test runs.
        using var threadPool = new JvmThreadPoolGC(
            _loggerService, _jvmBridge, TimeSpan.FromMinutes(30));

        var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
        thread.Start();

        Assert.True(threadPool.TryAddThread(thread));
        // Subsequent call should return false, because the thread has already been added.
        Assert.False(threadPool.TryAddThread(thread));

        thread.Join();
    }

    /// <summary>
    /// Create a Spark worker thread in the JVM ThreadPool then remove it directly through
    /// the JvmBridge.
    /// </summary>
    [Fact]
    public void TestRmThread()
    {
        // Create a thread and ensure that it is initialized in the JVM ThreadPool.
        var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
        thread.Start();
        thread.Join();

        // First call should return true. Second call should return false.
        Assert.True(
            (bool)_jvmBridge.CallStaticJavaMethod(
                "DotnetHandler", "rmThread", thread.ManagedThreadId));
        Assert.False(
            (bool)_jvmBridge.CallStaticJavaMethod(
                "DotnetHandler", "rmThread", thread.ManagedThreadId));
    }

    /// <summary>
    /// Test that the GC interval configuration defaults to 5 minutes, and can be updated
    /// correctly by setting the environment variable.
    /// </summary>
    [Fact]
    public void TestIntervalConfiguration()
    {
        // Default value is 5 minutes.
        Assert.Null(Environment.GetEnvironmentVariable(GCIntervalEnvVar));
        Assert.Equal(
            TimeSpan.FromMinutes(5),
            SparkEnvironment.ConfigurationService.JvmThreadGCInterval);

        try
        {
            // Test a custom value.
            Environment.SetEnvironmentVariable(GCIntervalEnvVar, "1:30:00");
            Assert.Equal(
                TimeSpan.FromMinutes(90),
                SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
        }
        finally
        {
            // The variable was asserted to be unset above; unset it again so the custom
            // value does not leak into other tests running in this process.
            Environment.SetEnvironmentVariable(GCIntervalEnvVar, null);
        }
    }
}
|
||||
}
|
|
@ -35,6 +35,10 @@ namespace Microsoft.Spark.E2ETest.IpcTests
|
|||
|
||||
Assert.IsType<Builder>(SparkSession.Builder());
|
||||
|
||||
SparkSession.ClearActiveSession();
|
||||
SparkSession.SetActiveSession(_spark);
|
||||
Assert.IsType<SparkSession>(SparkSession.GetActiveSession());
|
||||
|
||||
SparkSession.ClearDefaultSession();
|
||||
SparkSession.SetDefaultSession(_spark);
|
||||
Assert.IsType<SparkSession>(SparkSession.GetDefaultSession());
|
||||
|
|
|
@ -8,6 +8,7 @@ using System.Collections.Generic;
|
|||
using System.IO;
|
||||
using System.Net;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using Microsoft.Spark.Network;
|
||||
using Microsoft.Spark.Services;
|
||||
|
||||
|
@ -35,6 +36,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
private readonly ILoggerService _logger =
|
||||
LoggerServiceFactory.GetLogger(typeof(JvmBridge));
|
||||
private readonly int _portNumber;
|
||||
private readonly JvmThreadPoolGC _jvmThreadPoolGC;
|
||||
|
||||
internal JvmBridge(int portNumber)
|
||||
{
|
||||
|
@ -45,6 +47,9 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
|
||||
_portNumber = portNumber;
|
||||
_logger.LogInfo($"JvMBridge port is {portNumber}");
|
||||
|
||||
_jvmThreadPoolGC = new JvmThreadPoolGC(
|
||||
_logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
|
||||
}
|
||||
|
||||
private ISocketWrapper GetConnection()
|
||||
|
@ -158,11 +163,13 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
ISocketWrapper socket = null;
|
||||
try
|
||||
{
|
||||
Thread thread = Thread.CurrentThread;
|
||||
MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
|
||||
payloadMemoryStream.Position = 0;
|
||||
PayloadHelper.BuildPayload(
|
||||
payloadMemoryStream,
|
||||
isStatic,
|
||||
thread.ManagedThreadId,
|
||||
classNameOrJvmObjectReference,
|
||||
methodName,
|
||||
args);
|
||||
|
@ -176,6 +183,8 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
(int)payloadMemoryStream.Position);
|
||||
outputStream.Flush();
|
||||
|
||||
_jvmThreadPoolGC.TryAddThread(thread);
|
||||
|
||||
Stream inputStream = socket.InputStream;
|
||||
int isMethodCallFailed = SerDe.ReadInt32(inputStream);
|
||||
if (isMethodCallFailed != 0)
|
||||
|
@ -410,6 +419,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
|
||||
public void Dispose()
|
||||
{
|
||||
_jvmThreadPoolGC.Dispose();
|
||||
while (_sockets.TryDequeue(out ISocketWrapper socket))
|
||||
{
|
||||
if (socket != null)
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.Threading;
|
||||
using Microsoft.Spark.Services;
|
||||
|
||||
namespace Microsoft.Spark.Interop.Ipc
|
||||
{
|
||||
/// <summary>
/// In .NET for Apache Spark, we maintain a 1-to-1 mapping between .NET application threads
/// and corresponding JVM threads. When a .NET thread calls a Spark API, that call is executed
/// by its corresponding JVM thread. This functionality allows for multithreaded applications
/// with thread-local variables.
///
/// This class keeps track of the .NET application thread lifecycle. When a .NET application
/// thread is no longer alive, this class submits an rmThread command to the JVM backend to
/// dispose of its corresponding JVM thread. All methods are thread-safe.
/// </summary>
internal class JvmThreadPoolGC : IDisposable
{
    private readonly ILoggerService _loggerService;
    private readonly IJvmBridge _jvmBridge;
    private readonly TimeSpan _threadGCInterval;

    // Monitored threads keyed by ManagedThreadId.
    private readonly ConcurrentDictionary<int, Thread> _activeThreads;

    // Guards creation/disposal of _activeThreadGCTimer and the _disposed flag.
    private readonly object _activeThreadGCTimerLock;
    private Timer _activeThreadGCTimer;

    // Set by Dispose(). Once set, TryAddThread no longer creates a timer. Without this, the
    // JVM calls issued by the final GCThreads() pass could re-enter TryAddThread (a bridge
    // may report its calling thread back to this instance) and create a timer that nothing
    // would ever dispose.
    private bool _disposed;

    /// <summary>
    /// Construct the JvmThreadPoolGC.
    /// </summary>
    /// <param name="loggerService">Logger service.</param>
    /// <param name="jvmBridge">The JvmBridge used to call JVM methods.</param>
    /// <param name="threadGCInterval">The interval to GC finished threads.</param>
    public JvmThreadPoolGC(
        ILoggerService loggerService,
        IJvmBridge jvmBridge,
        TimeSpan threadGCInterval)
    {
        _loggerService = loggerService;
        _jvmBridge = jvmBridge;
        _threadGCInterval = threadGCInterval;
        _activeThreads = new ConcurrentDictionary<int, Thread>();

        _activeThreadGCTimerLock = new object();
        // The timer is created lazily by the first TryAddThread call.
        _activeThreadGCTimer = null;
        _disposed = false;
    }

    /// <summary>
    /// Dispose of the GC timer and run a final round of thread GC. After this call the
    /// periodic GC timer is never started again.
    /// </summary>
    public void Dispose()
    {
        lock (_activeThreadGCTimerLock)
        {
            // Mark as disposed before the final GC pass so that any TryAddThread call made
            // while that pass runs cannot bring the timer back.
            _disposed = true;

            if (_activeThreadGCTimer != null)
            {
                _activeThreadGCTimer.Dispose();
                _activeThreadGCTimer = null;
            }
        }

        GCThreads();
    }

    /// <summary>
    /// Try to start monitoring a thread.
    /// </summary>
    /// <param name="thread">The thread to add.</param>
    /// <returns>True if success, false if already added.</returns>
    public bool TryAddThread(Thread thread)
    {
        bool returnValue = _activeThreads.TryAdd(thread.ManagedThreadId, thread);

        // Initialize the GC timer if necessary. The unlocked check keeps the common case
        // (timer already running) cheap; the decision is re-validated under the lock.
        if (_activeThreadGCTimer == null)
        {
            lock (_activeThreadGCTimerLock)
            {
                if (!_disposed && _activeThreadGCTimer == null && _activeThreads.Count > 0)
                {
                    _activeThreadGCTimer = new Timer(
                        (state) => GCThreads(),
                        null,
                        _threadGCInterval,
                        _threadGCInterval);
                }
            }
        }

        return returnValue;
    }

    /// <summary>
    /// Try to remove a thread from the pool. If the removal is successful, then the
    /// corresponding JVM thread will also be disposed.
    /// </summary>
    /// <param name="threadId">The ID of the thread to remove.</param>
    /// <returns>
    /// True if the thread was tracked and the JVM confirmed its removal; false if the thread
    /// was not tracked or the JVM's rmThread call returned false.
    /// </returns>
    private bool TryDisposeJvmThread(int threadId)
    {
        if (_activeThreads.TryRemove(threadId, out _))
        {
            // _activeThreads does not have ownership of the threads on the .NET side. This
            // class does not need to call Join() on the .NET Thread. However, this class is
            // responsible for sending the rmThread command to the JVM to trigger disposal
            // of the corresponding JVM thread.
            if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId))
            {
                _loggerService.LogDebug($"GC'd JVM thread {threadId}.");
                return true;
            }
            else
            {
                _loggerService.LogWarn(
                    $"rmThread returned false for JVM thread {threadId}. " +
                    $"Either thread does not exist or has already been GC'd.");
            }
        }

        return false;
    }

    /// <summary>
    /// Remove any threads that are no longer active, then stop the timer if nothing is left
    /// to monitor.
    /// </summary>
    private void GCThreads()
    {
        foreach (KeyValuePair<int, Thread> kvp in _activeThreads)
        {
            if (!kvp.Value.IsAlive)
            {
                TryDisposeJvmThread(kvp.Key);
            }
        }

        lock (_activeThreadGCTimerLock)
        {
            // Dispose of the timer if there are no threads to monitor.
            if (_activeThreadGCTimer != null && _activeThreads.IsEmpty)
            {
                _activeThreadGCTimer.Dispose();
                _activeThreadGCTimer = null;
            }
        }
    }
}
|
||||
}
|
|
@ -27,7 +27,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
private static readonly byte[] s_timestampTypeId = new[] { (byte)'t' };
|
||||
private static readonly byte[] s_jvmObjectTypeId = new[] { (byte)'j' };
|
||||
private static readonly byte[] s_byteArrayTypeId = new[] { (byte)'r' };
|
||||
private static readonly byte[] s_doubleArrayArrayTypeId = new[] { ( byte)'A' };
|
||||
private static readonly byte[] s_doubleArrayArrayTypeId = new[] { (byte)'A' };
|
||||
private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' };
|
||||
private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' };
|
||||
private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' };
|
||||
|
@ -39,6 +39,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
internal static void BuildPayload(
|
||||
MemoryStream destination,
|
||||
bool isStaticMethod,
|
||||
int threadId,
|
||||
object classNameOrJvmObjectReference,
|
||||
string methodName,
|
||||
object[] args)
|
||||
|
@ -48,6 +49,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
destination.Position += sizeof(int);
|
||||
|
||||
SerDe.Write(destination, isStaticMethod);
|
||||
SerDe.Write(destination, threadId);
|
||||
SerDe.Write(destination, classNameOrJvmObjectReference.ToString());
|
||||
SerDe.Write(destination, methodName);
|
||||
SerDe.Write(destination, args.Length);
|
||||
|
|
|
@ -33,6 +33,18 @@ namespace Microsoft.Spark.Services
|
|||
|
||||
private string _workerPath;
|
||||
|
||||
/// <summary>
/// How often to run GC on JVM ThreadPool threads. Defaults to 5 minutes.
/// </summary>
/// <remarks>
/// The value is read from the DOTNET_JVM_THREAD_GC_INTERVAL environment variable each time
/// the property is accessed. When the variable is unset or empty the default applies;
/// otherwise it is parsed as a <see cref="TimeSpan"/> (for example "1:30:00") using the
/// invariant culture, so the accepted format does not depend on the machine's regional
/// settings. A value that cannot be parsed causes <see cref="TimeSpan.Parse(string, IFormatProvider)"/>
/// to throw.
/// </remarks>
public TimeSpan JvmThreadGCInterval
{
    get
    {
        string envVar = Environment.GetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL");
        return string.IsNullOrEmpty(envVar)
            ? TimeSpan.FromMinutes(5)
            : TimeSpan.Parse(envVar, System.Globalization.CultureInfo.InvariantCulture);
    }
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the port number for socket communication between JVM and CLR.
|
||||
/// </summary>
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
|
||||
namespace Microsoft.Spark.Services
|
||||
{
|
||||
/// <summary>
|
||||
|
@ -9,6 +11,11 @@ namespace Microsoft.Spark.Services
|
|||
/// </summary>
|
||||
internal interface IConfigurationService
|
||||
{
|
||||
/// <summary>
|
||||
/// How often to run GC on JVM ThreadPool threads.
|
||||
/// </summary>
|
||||
TimeSpan JvmThreadGCInterval { get; }
|
||||
|
||||
/// <summary>
|
||||
/// The port number used for communicating with the .NET backend process.
|
||||
/// </summary>
|
||||
|
|
|
@ -61,10 +61,40 @@ namespace Microsoft.Spark.Sql
|
|||
/// <returns>Builder object</returns>
|
||||
public static Builder Builder() => new Builder();
|
||||
|
||||
/// Note that *ActiveSession() APIs are not exposed because these APIs work with a
|
||||
/// thread-local variable, which stores the session variable. Since the Netty server
|
||||
/// that handles the requests is multi-threaded, any thread can invoke these APIs,
|
||||
/// resulting in unexpected behaviors if different threads are used.
|
||||
/// <summary>
/// Registers <paramref name="session"/> as the active SparkSession for the calling thread,
/// so that <see cref="Builder.GetOrCreate"/> on this thread returns it. Use this to give a
/// thread its own isolated session rather than the global (first created) context.
/// </summary>
/// <param name="session">The SparkSession to make active.</param>
public static void SetActiveSession(SparkSession session)
{
    // The call is routed through the JVM bridge that owns the given session.
    var jvm = session._jvmObject.Jvm;
    jvm.CallStaticJavaMethod(s_sparkSessionClassName, "setActiveSession", session);
}
|
||||
|
||||
/// <summary>
/// Removes the active SparkSession of the calling thread. Afterwards,
/// <see cref="Builder.GetOrCreate"/> falls back to the first created context because no
/// thread-local override remains.
/// </summary>
public static void ClearActiveSession()
{
    SparkEnvironment.JvmBridge.CallStaticJavaMethod(
        s_sparkSessionClassName,
        "clearActiveSession");
}
|
||||
|
||||
/// <summary>
/// Looks up the SparkSession that is active for the calling thread, as returned by the
/// builder.
/// </summary>
/// <returns>
/// The active SparkSession, or null when the JVM reports that none is defined (for example
/// when this function is called on executors).
/// </returns>
public static SparkSession GetActiveSession()
{
    // The JVM side answers with an Option wrapping the session.
    object jvmOption = SparkEnvironment.JvmBridge.CallStaticJavaMethod(
        s_sparkSessionClassName, "getActiveSession");
    var activeSession = new Option((JvmObjectReference)jvmOption);

    if (!activeSession.IsDefined())
    {
        return null;
    }

    return new SparkSession((JvmObjectReference)activeSession.Get());
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets the default SparkSession that is returned by the builder.
|
||||
|
|
|
@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet
|
|||
|
||||
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
|
||||
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.language.existentials
|
||||
|
||||
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
|
||||
import org.apache.spark.api.dotnet.SerDe._
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.language.existentials
|
||||
|
||||
/**
|
||||
* Handler for DotnetBackend.
|
||||
* This implementation is similar to RBackendHandler.
|
||||
|
@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
|
||||
// First bit is isStatic
|
||||
val isStatic = readBoolean(dis)
|
||||
val threadId = readInt(dis)
|
||||
val objId = readString(dis)
|
||||
val methodName = readString(dis)
|
||||
val numArgs = readInt(dis)
|
||||
|
@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
logError(s"Removing $objId failed", e)
|
||||
writeInt(dos, -1)
|
||||
}
|
||||
case "rmThread" =>
|
||||
try {
|
||||
assert(readObjectType(dis) == 'i')
|
||||
val threadToDelete = readInt(dis)
|
||||
val result = ThreadPool.tryDeleteThread(threadToDelete)
|
||||
writeInt(dos, 0)
|
||||
writeObject(dos, result.asInstanceOf[AnyRef])
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError(s"Removing thread $threadId failed", e)
|
||||
writeInt(dos, -1)
|
||||
}
|
||||
case "connectCallback" =>
|
||||
assert(readObjectType(dis) == 'c')
|
||||
val address = readString(dis)
|
||||
assert(readObjectType(dis) == 'i')
|
||||
val port = readInt(dis)
|
||||
DotnetBackend.setCallbackClient(address, port);
|
||||
DotnetBackend.setCallbackClient(address, port)
|
||||
writeInt(dos, 0)
|
||||
writeType(dos, "void")
|
||||
case "closeCallback" =>
|
||||
|
@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
case _ => dos.writeInt(-1)
|
||||
}
|
||||
} else {
|
||||
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
|
||||
ThreadPool
|
||||
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||
}
|
||||
|
||||
bos.toByteArray
|
||||
|
@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
"invalid method " + methodName + " for object " + objId)
|
||||
}
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
case e: Throwable =>
|
||||
val jvmObj = JVMObjectTracker.get(objId)
|
||||
val jvmObjName = jvmObj match {
|
||||
case Some(jObj) => jObj.getClass.getName
|
||||
case None => "NullObject"
|
||||
}
|
||||
val argsStr = args.map(arg => {
|
||||
if (arg != null) {
|
||||
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
||||
} else {
|
||||
"[Value: NULL]"
|
||||
}
|
||||
}).mkString(", ")
|
||||
val argsStr = args
|
||||
.map(arg => {
|
||||
if (arg != null) {
|
||||
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
||||
} else {
|
||||
"[Value: NULL]"
|
||||
}
|
||||
})
|
||||
.mkString(", ")
|
||||
|
||||
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
import java.util.concurrent.{ExecutorService, Executors}
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
/**
 * Registry of single-thread executors, one per caller-supplied thread id. It exists so
 * that every C# thread is served by exactly one dedicated Java thread.
 */
object ThreadPool {

  /**
   * Executor assigned to each thread id. Accessed only while holding this object's monitor.
   */
  private val executors: mutable.HashMap[Int, ExecutorService] =
    new mutable.HashMap[Int, ExecutorService]()

  /**
   * Execute a task on the thread registered for the given id and block until it finishes.
   *
   * @param threadId Integer id of the thread.
   * @param task Function to run on the thread.
   */
  def run(threadId: Int, task: () => Unit): Unit = {
    val worker = getOrCreateExecutor(threadId)
    val job: Runnable = new Runnable {
      override def run(): Unit = task()
    }

    // Wait for completion so the caller observes the task's effects before continuing.
    worker.submit(job).get()
  }

  /**
   * Unregister the thread with the given id and shut its executor down.
   *
   * @param threadId Integer id of the thread.
   * @return True if an executor was registered and removed, false if none existed.
   */
  def tryDeleteThread(threadId: Int): Boolean = synchronized {
    val removed = executors.remove(threadId)
    removed.foreach(_.shutdown())
    removed.isDefined
  }

  /**
   * Look up the executor registered for an id, registering a new single-thread executor
   * when there is none yet.
   *
   * @param id Integer id of the thread.
   * @return The new or existing executor with the given id.
   */
  private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
    executors.get(id) match {
      case Some(existing) => existing
      case None =>
        val created = Executors.newSingleThreadExecutor()
        executors.put(id, created)
        created
    }
  }
}
|
|
@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet
|
|||
|
||||
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
|
||||
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.language.existentials
|
||||
|
||||
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
|
||||
import org.apache.spark.api.dotnet.SerDe._
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.language.existentials
|
||||
|
||||
/**
|
||||
* Handler for DotnetBackend.
|
||||
* This implementation is similar to RBackendHandler.
|
||||
|
@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
|
||||
// First bit is isStatic
|
||||
val isStatic = readBoolean(dis)
|
||||
val threadId = readInt(dis)
|
||||
val objId = readString(dis)
|
||||
val methodName = readString(dis)
|
||||
val numArgs = readInt(dis)
|
||||
|
@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
logError(s"Removing $objId failed", e)
|
||||
writeInt(dos, -1)
|
||||
}
|
||||
case "rmThread" =>
|
||||
try {
|
||||
assert(readObjectType(dis) == 'i')
|
||||
val threadToDelete = readInt(dis)
|
||||
val result = ThreadPool.tryDeleteThread(threadToDelete)
|
||||
writeInt(dos, 0)
|
||||
writeObject(dos, result.asInstanceOf[AnyRef])
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError(s"Removing thread $threadId failed", e)
|
||||
writeInt(dos, -1)
|
||||
}
|
||||
case "connectCallback" =>
|
||||
assert(readObjectType(dis) == 'c')
|
||||
val address = readString(dis)
|
||||
assert(readObjectType(dis) == 'i')
|
||||
val port = readInt(dis)
|
||||
DotnetBackend.setCallbackClient(address, port);
|
||||
DotnetBackend.setCallbackClient(address, port)
|
||||
writeInt(dos, 0)
|
||||
writeType(dos, "void")
|
||||
case "closeCallback" =>
|
||||
|
@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
case _ => dos.writeInt(-1)
|
||||
}
|
||||
} else {
|
||||
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
|
||||
ThreadPool
|
||||
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||
}
|
||||
|
||||
bos.toByteArray
|
||||
|
@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
"invalid method " + methodName + " for object " + objId)
|
||||
}
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
case e: Throwable =>
|
||||
val jvmObj = JVMObjectTracker.get(objId)
|
||||
val jvmObjName = jvmObj match {
|
||||
case Some(jObj) => jObj.getClass.getName
|
||||
case None => "NullObject"
|
||||
}
|
||||
val argsStr = args.map(arg => {
|
||||
if (arg != null) {
|
||||
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
||||
} else {
|
||||
"[Value: NULL]"
|
||||
}
|
||||
}).mkString(", ")
|
||||
val argsStr = args
|
||||
.map(arg => {
|
||||
if (arg != null) {
|
||||
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
||||
} else {
|
||||
"[Value: NULL]"
|
||||
}
|
||||
})
|
||||
.mkString(", ")
|
||||
|
||||
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
import java.util.concurrent.{ExecutorService, Executors}
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
/**
 * Registry of single-thread executors, one per caller-supplied thread id. It exists so
 * that every C# thread is served by exactly one dedicated Java thread.
 */
object ThreadPool {

  /**
   * Executor assigned to each thread id. Accessed only while holding this object's monitor.
   */
  private val executors: mutable.HashMap[Int, ExecutorService] =
    new mutable.HashMap[Int, ExecutorService]()

  /**
   * Execute a task on the thread registered for the given id and block until it finishes.
   *
   * @param threadId Integer id of the thread.
   * @param task Function to run on the thread.
   */
  def run(threadId: Int, task: () => Unit): Unit = {
    val worker = getOrCreateExecutor(threadId)
    val job: Runnable = new Runnable {
      override def run(): Unit = task()
    }

    // Wait for completion so the caller observes the task's effects before continuing.
    worker.submit(job).get()
  }

  /**
   * Unregister the thread with the given id and shut its executor down.
   *
   * @param threadId Integer id of the thread.
   * @return True if an executor was registered and removed, false if none existed.
   */
  def tryDeleteThread(threadId: Int): Boolean = synchronized {
    val removed = executors.remove(threadId)
    removed.foreach(_.shutdown())
    removed.isDefined
  }

  /**
   * Look up the executor registered for an id, registering a new single-thread executor
   * when there is none yet.
   *
   * @param id Integer id of the thread.
   * @return The new or existing executor with the given id.
   */
  private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
    executors.get(id) match {
      case Some(existing) => existing
      case None =>
        val created = Executors.newSingleThreadExecutor()
        executors.put(id, created)
        created
    }
  }
}
|
|
@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
|
||||
// First bit is isStatic
|
||||
val isStatic = readBoolean(dis)
|
||||
val threadId = readInt(dis)
|
||||
val objId = readString(dis)
|
||||
val methodName = readString(dis)
|
||||
val numArgs = readInt(dis)
|
||||
|
@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
logError(s"Removing $objId failed", e)
|
||||
writeInt(dos, -1)
|
||||
}
|
||||
case "rmThread" =>
|
||||
try {
|
||||
assert(readObjectType(dis) == 'i')
|
||||
val threadToDelete = readInt(dis)
|
||||
val result = ThreadPool.tryDeleteThread(threadToDelete)
|
||||
writeInt(dos, 0)
|
||||
writeObject(dos, result.asInstanceOf[AnyRef])
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError(s"Removing thread $threadId failed", e)
|
||||
writeInt(dos, -1)
|
||||
}
|
||||
case "connectCallback" =>
|
||||
assert(readObjectType(dis) == 'c')
|
||||
val address = readString(dis)
|
||||
assert(readObjectType(dis) == 'i')
|
||||
val port = readInt(dis)
|
||||
DotnetBackend.setCallbackClient(address, port);
|
||||
DotnetBackend.setCallbackClient(address, port)
|
||||
writeInt(dos, 0)
|
||||
writeType(dos, "void")
|
||||
case "closeCallback" =>
|
||||
|
@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
case _ => dos.writeInt(-1)
|
||||
}
|
||||
} else {
|
||||
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
|
||||
ThreadPool
|
||||
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||
}
|
||||
|
||||
bos.toByteArray
|
||||
|
@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
|
|||
"invalid method " + methodName + " for object " + objId)
|
||||
}
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
case e: Throwable =>
|
||||
val jvmObj = JVMObjectTracker.get(objId)
|
||||
val jvmObjName = jvmObj match {
|
||||
case Some(jObj) => jObj.getClass.getName
|
||||
case None => "NullObject"
|
||||
}
|
||||
val argsStr = args.map(arg => {
|
||||
if (arg != null) {
|
||||
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
||||
} else {
|
||||
"[Value: NULL]"
|
||||
}
|
||||
}).mkString(", ")
|
||||
val argsStr = args
|
||||
.map(arg => {
|
||||
if (arg != null) {
|
||||
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
||||
} else {
|
||||
"[Value: NULL]"
|
||||
}
|
||||
})
|
||||
.mkString(", ")
|
||||
|
||||
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
import java.util.concurrent.{ExecutorService, Executors}
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
/**
 * Registry of single-thread executors, one per caller-supplied thread id. It exists so
 * that every C# thread is served by exactly one dedicated Java thread.
 */
object ThreadPool {

  /**
   * Executor assigned to each thread id. Accessed only while holding this object's monitor.
   */
  private val executors: mutable.HashMap[Int, ExecutorService] =
    new mutable.HashMap[Int, ExecutorService]()

  /**
   * Execute a task on the thread registered for the given id and block until it finishes.
   *
   * @param threadId Integer id of the thread.
   * @param task Function to run on the thread.
   */
  def run(threadId: Int, task: () => Unit): Unit = {
    val worker = getOrCreateExecutor(threadId)
    val job: Runnable = new Runnable {
      override def run(): Unit = task()
    }

    // Wait for completion so the caller observes the task's effects before continuing.
    worker.submit(job).get()
  }

  /**
   * Unregister the thread with the given id and shut its executor down.
   *
   * @param threadId Integer id of the thread.
   * @return True if an executor was registered and removed, false if none existed.
   */
  def tryDeleteThread(threadId: Int): Boolean = synchronized {
    val removed = executors.remove(threadId)
    removed.foreach(_.shutdown())
    removed.isDefined
  }

  /**
   * Look up the executor registered for an id, registering a new single-thread executor
   * when there is none yet.
   *
   * @param id Integer id of the thread.
   * @return The new or existing executor with the given id.
   */
  private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
    executors.get(id) match {
      case Some(existing) => existing
      case None =>
        val created = Executors.newSingleThreadExecutor()
        executors.put(id, created)
        created
    }
  }
}
|
Загрузка…
Ссылка в новой задаче