зеркало из https://github.com/dotnet/spark.git
Full support for multithreaded applications (#641)
This commit is contained in:
Родитель
7bcd2a5060
Коммит
a67ad5907e
|
@ -0,0 +1,121 @@
|
||||||
|
// Licensed to the .NET Foundation under one or more agreements.
|
||||||
|
// The .NET Foundation licenses this file to you under the MIT license.
|
||||||
|
// See the LICENSE file in the project root for more information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Threading;
|
||||||
|
using Microsoft.Spark.Interop;
|
||||||
|
using Microsoft.Spark.Interop.Ipc;
|
||||||
|
using Microsoft.Spark.Services;
|
||||||
|
using Microsoft.Spark.Sql;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace Microsoft.Spark.E2ETest.IpcTests
|
||||||
|
{
|
||||||
|
[Collection("Spark E2E Tests")]
|
||||||
|
public class JvmThreadPoolGCTests
|
||||||
|
{
|
||||||
|
private readonly ILoggerService _loggerService;
|
||||||
|
private readonly SparkSession _spark;
|
||||||
|
private readonly IJvmBridge _jvmBridge;
|
||||||
|
|
||||||
|
public JvmThreadPoolGCTests(SparkFixture fixture)
|
||||||
|
{
|
||||||
|
_loggerService = LoggerServiceFactory.GetLogger(typeof(JvmThreadPoolGCTests));
|
||||||
|
_spark = fixture.Spark;
|
||||||
|
_jvmBridge = ((IJvmObjectReferenceProvider)_spark).Reference.Jvm;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Test that the active SparkSession is thread-specific.
|
||||||
|
/// </summary>
|
||||||
|
[Fact]
|
||||||
|
public void TestThreadLocalSessions()
|
||||||
|
{
|
||||||
|
SparkSession.ClearActiveSession();
|
||||||
|
|
||||||
|
void testChildThread(string appName)
|
||||||
|
{
|
||||||
|
var thread = new Thread(() =>
|
||||||
|
{
|
||||||
|
Assert.Null(SparkSession.GetActiveSession());
|
||||||
|
|
||||||
|
SparkSession.SetActiveSession(
|
||||||
|
SparkSession.Builder().AppName(appName).GetOrCreate());
|
||||||
|
|
||||||
|
// Since we are in the child thread, GetActiveSession() should return the child
|
||||||
|
// SparkSession.
|
||||||
|
SparkSession activeSession = SparkSession.GetActiveSession();
|
||||||
|
Assert.NotNull(activeSession);
|
||||||
|
Assert.Equal(appName, activeSession.Conf().Get("spark.app.name", null));
|
||||||
|
});
|
||||||
|
|
||||||
|
thread.Start();
|
||||||
|
thread.Join();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 5; ++i)
|
||||||
|
{
|
||||||
|
testChildThread(i.ToString());
|
||||||
|
}
|
||||||
|
|
||||||
|
Assert.Null(SparkSession.GetActiveSession());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Monitor a thread via the JvmThreadPoolGC.
|
||||||
|
/// </summary>
|
||||||
|
[Fact]
|
||||||
|
public void TestTryAddThread()
|
||||||
|
{
|
||||||
|
using var threadPool = new JvmThreadPoolGC(
|
||||||
|
_loggerService, _jvmBridge, TimeSpan.FromMinutes(30));
|
||||||
|
|
||||||
|
var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
|
||||||
|
thread.Start();
|
||||||
|
|
||||||
|
Assert.True(threadPool.TryAddThread(thread));
|
||||||
|
// Subsequent call should return false, because the thread has already been added.
|
||||||
|
Assert.False(threadPool.TryAddThread(thread));
|
||||||
|
|
||||||
|
thread.Join();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Create a Spark worker thread in the JVM ThreadPool then remove it directly through
|
||||||
|
/// the JvmBridge.
|
||||||
|
/// </summary>
|
||||||
|
[Fact]
|
||||||
|
public void TestRmThread()
|
||||||
|
{
|
||||||
|
// Create a thread and ensure that it is initialized in the JVM ThreadPool.
|
||||||
|
var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
|
||||||
|
thread.Start();
|
||||||
|
thread.Join();
|
||||||
|
|
||||||
|
// First call should return true. Second call should return false.
|
||||||
|
Assert.True((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId));
|
||||||
|
Assert.False((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Test that the GC interval configuration defaults to 5 minutes, and can be updated
|
||||||
|
/// correctly by setting the environment variable.
|
||||||
|
/// </summary>
|
||||||
|
[Fact]
|
||||||
|
public void TestIntervalConfiguration()
|
||||||
|
{
|
||||||
|
// Default value is 5 minutes.
|
||||||
|
Assert.Null(Environment.GetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL"));
|
||||||
|
Assert.Equal(
|
||||||
|
TimeSpan.FromMinutes(5),
|
||||||
|
SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
|
||||||
|
|
||||||
|
// Test a custom value.
|
||||||
|
Environment.SetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL", "1:30:00");
|
||||||
|
Assert.Equal(
|
||||||
|
TimeSpan.FromMinutes(90),
|
||||||
|
SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -35,6 +35,10 @@ namespace Microsoft.Spark.E2ETest.IpcTests
|
||||||
|
|
||||||
Assert.IsType<Builder>(SparkSession.Builder());
|
Assert.IsType<Builder>(SparkSession.Builder());
|
||||||
|
|
||||||
|
SparkSession.ClearActiveSession();
|
||||||
|
SparkSession.SetActiveSession(_spark);
|
||||||
|
Assert.IsType<SparkSession>(SparkSession.GetActiveSession());
|
||||||
|
|
||||||
SparkSession.ClearDefaultSession();
|
SparkSession.ClearDefaultSession();
|
||||||
SparkSession.SetDefaultSession(_spark);
|
SparkSession.SetDefaultSession(_spark);
|
||||||
Assert.IsType<SparkSession>(SparkSession.GetDefaultSession());
|
Assert.IsType<SparkSession>(SparkSession.GetDefaultSession());
|
||||||
|
@ -76,7 +80,7 @@ namespace Microsoft.Spark.E2ETest.IpcTests
|
||||||
/// </summary>
|
/// </summary>
|
||||||
[Fact]
|
[Fact]
|
||||||
public void TestCreateDataFrame()
|
public void TestCreateDataFrame()
|
||||||
{
|
{
|
||||||
// Calling CreateDataFrame with schema
|
// Calling CreateDataFrame with schema
|
||||||
{
|
{
|
||||||
var data = new List<GenericRow>
|
var data = new List<GenericRow>
|
||||||
|
|
|
@ -8,6 +8,7 @@ using System.Collections.Generic;
|
||||||
using System.IO;
|
using System.IO;
|
||||||
using System.Net;
|
using System.Net;
|
||||||
using System.Text;
|
using System.Text;
|
||||||
|
using System.Threading;
|
||||||
using Microsoft.Spark.Network;
|
using Microsoft.Spark.Network;
|
||||||
using Microsoft.Spark.Services;
|
using Microsoft.Spark.Services;
|
||||||
|
|
||||||
|
@ -35,6 +36,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
||||||
private readonly ILoggerService _logger =
|
private readonly ILoggerService _logger =
|
||||||
LoggerServiceFactory.GetLogger(typeof(JvmBridge));
|
LoggerServiceFactory.GetLogger(typeof(JvmBridge));
|
||||||
private readonly int _portNumber;
|
private readonly int _portNumber;
|
||||||
|
private readonly JvmThreadPoolGC _jvmThreadPoolGC;
|
||||||
|
|
||||||
internal JvmBridge(int portNumber)
|
internal JvmBridge(int portNumber)
|
||||||
{
|
{
|
||||||
|
@ -45,6 +47,9 @@ namespace Microsoft.Spark.Interop.Ipc
|
||||||
|
|
||||||
_portNumber = portNumber;
|
_portNumber = portNumber;
|
||||||
_logger.LogInfo($"JvMBridge port is {portNumber}");
|
_logger.LogInfo($"JvMBridge port is {portNumber}");
|
||||||
|
|
||||||
|
_jvmThreadPoolGC = new JvmThreadPoolGC(
|
||||||
|
_logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
|
||||||
}
|
}
|
||||||
|
|
||||||
private ISocketWrapper GetConnection()
|
private ISocketWrapper GetConnection()
|
||||||
|
@ -158,11 +163,13 @@ namespace Microsoft.Spark.Interop.Ipc
|
||||||
ISocketWrapper socket = null;
|
ISocketWrapper socket = null;
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
|
Thread thread = Thread.CurrentThread;
|
||||||
MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
|
MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
|
||||||
payloadMemoryStream.Position = 0;
|
payloadMemoryStream.Position = 0;
|
||||||
PayloadHelper.BuildPayload(
|
PayloadHelper.BuildPayload(
|
||||||
payloadMemoryStream,
|
payloadMemoryStream,
|
||||||
isStatic,
|
isStatic,
|
||||||
|
thread.ManagedThreadId,
|
||||||
classNameOrJvmObjectReference,
|
classNameOrJvmObjectReference,
|
||||||
methodName,
|
methodName,
|
||||||
args);
|
args);
|
||||||
|
@ -176,6 +183,8 @@ namespace Microsoft.Spark.Interop.Ipc
|
||||||
(int)payloadMemoryStream.Position);
|
(int)payloadMemoryStream.Position);
|
||||||
outputStream.Flush();
|
outputStream.Flush();
|
||||||
|
|
||||||
|
_jvmThreadPoolGC.TryAddThread(thread);
|
||||||
|
|
||||||
Stream inputStream = socket.InputStream;
|
Stream inputStream = socket.InputStream;
|
||||||
int isMethodCallFailed = SerDe.ReadInt32(inputStream);
|
int isMethodCallFailed = SerDe.ReadInt32(inputStream);
|
||||||
if (isMethodCallFailed != 0)
|
if (isMethodCallFailed != 0)
|
||||||
|
@ -410,6 +419,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
||||||
|
|
||||||
public void Dispose()
|
public void Dispose()
|
||||||
{
|
{
|
||||||
|
_jvmThreadPoolGC.Dispose();
|
||||||
while (_sockets.TryDequeue(out ISocketWrapper socket))
|
while (_sockets.TryDequeue(out ISocketWrapper socket))
|
||||||
{
|
{
|
||||||
if (socket != null)
|
if (socket != null)
|
||||||
|
|
|
@ -0,0 +1,149 @@
|
||||||
|
// Licensed to the .NET Foundation under one or more agreements.
|
||||||
|
// The .NET Foundation licenses this file to you under the MIT license.
|
||||||
|
// See the LICENSE file in the project root for more information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Collections.Concurrent;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Threading;
|
||||||
|
using Microsoft.Spark.Services;
|
||||||
|
|
||||||
|
namespace Microsoft.Spark.Interop.Ipc
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// In .NET for Apache Spark, we maintain a 1-to-1 mapping between .NET application threads
|
||||||
|
/// and corresponding JVM threads. When a .NET thread calls a Spark API, that call is executed
|
||||||
|
/// by its corresponding JVM thread. This functionality allows for multithreaded applications
|
||||||
|
/// with thread-local variables.
|
||||||
|
///
|
||||||
|
/// This class keeps track of the .NET application thread lifecycle. When a .NET application
|
||||||
|
/// thread is no longer alive, this class submits an rmThread command to the JVM backend to
|
||||||
|
/// dispose of its corresponding JVM thread. All methods are thread-safe.
|
||||||
|
/// </summary>
|
||||||
|
internal class JvmThreadPoolGC : IDisposable
|
||||||
|
{
|
||||||
|
private readonly ILoggerService _loggerService;
|
||||||
|
private readonly IJvmBridge _jvmBridge;
|
||||||
|
private readonly TimeSpan _threadGCInterval;
|
||||||
|
private readonly ConcurrentDictionary<int, Thread> _activeThreads;
|
||||||
|
|
||||||
|
private readonly object _activeThreadGCTimerLock;
|
||||||
|
private Timer _activeThreadGCTimer;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Construct the JvmThreadPoolGC.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="loggerService">Logger service.</param>
|
||||||
|
/// <param name="jvmBridge">The JvmBridge used to call JVM methods.</param>
|
||||||
|
/// <param name="threadGCInterval">The interval to GC finished threads.</param>
|
||||||
|
public JvmThreadPoolGC(ILoggerService loggerService, IJvmBridge jvmBridge, TimeSpan threadGCInterval)
|
||||||
|
{
|
||||||
|
_loggerService = loggerService;
|
||||||
|
_jvmBridge = jvmBridge;
|
||||||
|
_threadGCInterval = threadGCInterval;
|
||||||
|
_activeThreads = new ConcurrentDictionary<int, Thread>();
|
||||||
|
|
||||||
|
_activeThreadGCTimerLock = new object();
|
||||||
|
_activeThreadGCTimer = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Dispose of the GC timer and run a final round of thread GC.
|
||||||
|
/// </summary>
|
||||||
|
public void Dispose()
|
||||||
|
{
|
||||||
|
lock (_activeThreadGCTimerLock)
|
||||||
|
{
|
||||||
|
if (_activeThreadGCTimer != null)
|
||||||
|
{
|
||||||
|
_activeThreadGCTimer.Dispose();
|
||||||
|
_activeThreadGCTimer = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GCThreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Try to start monitoring a thread.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="thread">The thread to add.</param>
|
||||||
|
/// <returns>True if success, false if already added.</returns>
|
||||||
|
public bool TryAddThread(Thread thread)
|
||||||
|
{
|
||||||
|
bool returnValue = _activeThreads.TryAdd(thread.ManagedThreadId, thread);
|
||||||
|
|
||||||
|
// Initialize the GC timer if necessary.
|
||||||
|
if (_activeThreadGCTimer == null)
|
||||||
|
{
|
||||||
|
lock (_activeThreadGCTimerLock)
|
||||||
|
{
|
||||||
|
if (_activeThreadGCTimer == null && _activeThreads.Count > 0)
|
||||||
|
{
|
||||||
|
_activeThreadGCTimer = new Timer(
|
||||||
|
(state) => GCThreads(),
|
||||||
|
null,
|
||||||
|
_threadGCInterval,
|
||||||
|
_threadGCInterval);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return returnValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Try to remove a thread from the pool. If the removal is successful, then the
|
||||||
|
/// corresponding JVM thread will also be disposed.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="threadId">The ID of the thread to remove.</param>
|
||||||
|
/// <returns>True if success, false if the thread cannot be found.</returns>
|
||||||
|
private bool TryDisposeJvmThread(int threadId)
|
||||||
|
{
|
||||||
|
if (_activeThreads.TryRemove(threadId, out _))
|
||||||
|
{
|
||||||
|
// _activeThreads does not have ownership of the threads on the .NET side. This
|
||||||
|
// class does not need to call Join() on the .NET Thread. However, this class is
|
||||||
|
// responsible for sending the rmThread command to the JVM to trigger disposal
|
||||||
|
// of the corresponding JVM thread.
|
||||||
|
if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId))
|
||||||
|
{
|
||||||
|
_loggerService.LogDebug($"GC'd JVM thread {threadId}.");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
_loggerService.LogWarn(
|
||||||
|
$"rmThread returned false for JVM thread {threadId}. " +
|
||||||
|
$"Either thread does not exist or has already been GC'd.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Remove any threads that are no longer active.
|
||||||
|
/// </summary>
|
||||||
|
private void GCThreads()
|
||||||
|
{
|
||||||
|
foreach (KeyValuePair<int, Thread> kvp in _activeThreads)
|
||||||
|
{
|
||||||
|
if (!kvp.Value.IsAlive)
|
||||||
|
{
|
||||||
|
TryDisposeJvmThread(kvp.Key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lock (_activeThreadGCTimerLock)
|
||||||
|
{
|
||||||
|
// Dispose of the timer if there are no threads to monitor.
|
||||||
|
if (_activeThreadGCTimer != null && _activeThreads.IsEmpty)
|
||||||
|
{
|
||||||
|
_activeThreadGCTimer.Dispose();
|
||||||
|
_activeThreadGCTimer = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -27,7 +27,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
||||||
private static readonly byte[] s_timestampTypeId = new[] { (byte)'t' };
|
private static readonly byte[] s_timestampTypeId = new[] { (byte)'t' };
|
||||||
private static readonly byte[] s_jvmObjectTypeId = new[] { (byte)'j' };
|
private static readonly byte[] s_jvmObjectTypeId = new[] { (byte)'j' };
|
||||||
private static readonly byte[] s_byteArrayTypeId = new[] { (byte)'r' };
|
private static readonly byte[] s_byteArrayTypeId = new[] { (byte)'r' };
|
||||||
private static readonly byte[] s_doubleArrayArrayTypeId = new[] { ( byte)'A' };
|
private static readonly byte[] s_doubleArrayArrayTypeId = new[] { (byte)'A' };
|
||||||
private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' };
|
private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' };
|
||||||
private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' };
|
private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' };
|
||||||
private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' };
|
private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' };
|
||||||
|
@ -39,6 +39,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
||||||
internal static void BuildPayload(
|
internal static void BuildPayload(
|
||||||
MemoryStream destination,
|
MemoryStream destination,
|
||||||
bool isStaticMethod,
|
bool isStaticMethod,
|
||||||
|
int threadId,
|
||||||
object classNameOrJvmObjectReference,
|
object classNameOrJvmObjectReference,
|
||||||
string methodName,
|
string methodName,
|
||||||
object[] args)
|
object[] args)
|
||||||
|
@ -48,6 +49,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
||||||
destination.Position += sizeof(int);
|
destination.Position += sizeof(int);
|
||||||
|
|
||||||
SerDe.Write(destination, isStaticMethod);
|
SerDe.Write(destination, isStaticMethod);
|
||||||
|
SerDe.Write(destination, threadId);
|
||||||
SerDe.Write(destination, classNameOrJvmObjectReference.ToString());
|
SerDe.Write(destination, classNameOrJvmObjectReference.ToString());
|
||||||
SerDe.Write(destination, methodName);
|
SerDe.Write(destination, methodName);
|
||||||
SerDe.Write(destination, args.Length);
|
SerDe.Write(destination, args.Length);
|
||||||
|
@ -140,7 +142,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
||||||
SerDe.Write(destination, d);
|
SerDe.Write(destination, d);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case double[][] argDoubleArrayArray:
|
case double[][] argDoubleArrayArray:
|
||||||
SerDe.Write(destination, s_doubleArrayArrayTypeId);
|
SerDe.Write(destination, s_doubleArrayArrayTypeId);
|
||||||
SerDe.Write(destination, argDoubleArrayArray.Length);
|
SerDe.Write(destination, argDoubleArrayArray.Length);
|
||||||
|
|
|
@ -33,6 +33,18 @@ namespace Microsoft.Spark.Services
|
||||||
|
|
||||||
private string _workerPath;
|
private string _workerPath;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// How often to run GC on JVM ThreadPool threads. Defaults to 5 minutes.
|
||||||
|
/// </summary>
|
||||||
|
public TimeSpan JvmThreadGCInterval
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
string envVar = Environment.GetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL");
|
||||||
|
return string.IsNullOrEmpty(envVar) ? TimeSpan.FromMinutes(5) : TimeSpan.Parse(envVar);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Returns the port number for socket communication between JVM and CLR.
|
/// Returns the port number for socket communication between JVM and CLR.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
|
|
@ -2,6 +2,8 @@
|
||||||
// The .NET Foundation licenses this file to you under the MIT license.
|
// The .NET Foundation licenses this file to you under the MIT license.
|
||||||
// See the LICENSE file in the project root for more information.
|
// See the LICENSE file in the project root for more information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
|
||||||
namespace Microsoft.Spark.Services
|
namespace Microsoft.Spark.Services
|
||||||
{
|
{
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
@ -9,6 +11,11 @@ namespace Microsoft.Spark.Services
|
||||||
/// </summary>
|
/// </summary>
|
||||||
internal interface IConfigurationService
|
internal interface IConfigurationService
|
||||||
{
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// How often to run GC on JVM ThreadPool threads.
|
||||||
|
/// </summary>
|
||||||
|
TimeSpan JvmThreadGCInterval { get; }
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// The port number used for communicating with the .NET backend process.
|
/// The port number used for communicating with the .NET backend process.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
|
|
@ -61,10 +61,40 @@ namespace Microsoft.Spark.Sql
|
||||||
/// <returns>Builder object</returns>
|
/// <returns>Builder object</returns>
|
||||||
public static Builder Builder() => new Builder();
|
public static Builder Builder() => new Builder();
|
||||||
|
|
||||||
/// Note that *ActiveSession() APIs are not exposed because these APIs work with a
|
/// <summary>
|
||||||
/// thread-local variable, which stores the session variable. Since the Netty server
|
/// Changes the SparkSession that will be returned in this thread when
|
||||||
/// that handles the requests is multi-threaded, any thread can invoke these APIs,
|
/// <see cref="Builder.GetOrCreate"/> is called. This can be used to ensure that a given
|
||||||
/// resulting in unexpected behaviors if different threads are used.
|
/// thread receives a SparkSession with an isolated session, instead of the global
|
||||||
|
/// (first created) context.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="session">SparkSession object</param>
|
||||||
|
public static void SetActiveSession(SparkSession session) =>
|
||||||
|
session._jvmObject.Jvm.CallStaticJavaMethod(
|
||||||
|
s_sparkSessionClassName, "setActiveSession", session);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Clears the active SparkSession for current thread. Subsequent calls to
|
||||||
|
/// <see cref="Builder.GetOrCreate"/> will return the first created context
|
||||||
|
/// instead of a thread-local override.
|
||||||
|
/// </summary>
|
||||||
|
public static void ClearActiveSession() =>
|
||||||
|
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
|
||||||
|
s_sparkSessionClassName, "clearActiveSession");
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Returns the active SparkSession for the current thread, returned by the builder.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>Return null, when calling this function on executors</returns>
|
||||||
|
public static SparkSession GetActiveSession()
|
||||||
|
{
|
||||||
|
var optionalSession = new Option(
|
||||||
|
(JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod(
|
||||||
|
s_sparkSessionClassName, "getActiveSession"));
|
||||||
|
|
||||||
|
return optionalSession.IsDefined()
|
||||||
|
? new SparkSession((JvmObjectReference)optionalSession.Get())
|
||||||
|
: null;
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Sets the default SparkSession that is returned by the builder.
|
/// Sets the default SparkSession that is returned by the builder.
|
||||||
|
|
|
@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet
|
||||||
|
|
||||||
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
|
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
|
||||||
|
|
||||||
|
import scala.collection.mutable.HashMap
|
||||||
|
import scala.language.existentials
|
||||||
|
|
||||||
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
|
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
|
||||||
import org.apache.spark.api.dotnet.SerDe._
|
import org.apache.spark.api.dotnet.SerDe._
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.util.Utils
|
import org.apache.spark.util.Utils
|
||||||
|
|
||||||
import scala.collection.mutable.HashMap
|
|
||||||
import scala.language.existentials
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handler for DotnetBackend.
|
* Handler for DotnetBackend.
|
||||||
* This implementation is similar to RBackendHandler.
|
* This implementation is similar to RBackendHandler.
|
||||||
|
@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
|
|
||||||
// First bit is isStatic
|
// First bit is isStatic
|
||||||
val isStatic = readBoolean(dis)
|
val isStatic = readBoolean(dis)
|
||||||
|
val threadId = readInt(dis)
|
||||||
val objId = readString(dis)
|
val objId = readString(dis)
|
||||||
val methodName = readString(dis)
|
val methodName = readString(dis)
|
||||||
val numArgs = readInt(dis)
|
val numArgs = readInt(dis)
|
||||||
|
@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
logError(s"Removing $objId failed", e)
|
logError(s"Removing $objId failed", e)
|
||||||
writeInt(dos, -1)
|
writeInt(dos, -1)
|
||||||
}
|
}
|
||||||
|
case "rmThread" =>
|
||||||
|
try {
|
||||||
|
assert(readObjectType(dis) == 'i')
|
||||||
|
val threadToDelete = readInt(dis)
|
||||||
|
val result = ThreadPool.tryDeleteThread(threadToDelete)
|
||||||
|
writeInt(dos, 0)
|
||||||
|
writeObject(dos, result.asInstanceOf[AnyRef])
|
||||||
|
} catch {
|
||||||
|
case e: Exception =>
|
||||||
|
logError(s"Removing thread $threadId failed", e)
|
||||||
|
writeInt(dos, -1)
|
||||||
|
}
|
||||||
case "connectCallback" =>
|
case "connectCallback" =>
|
||||||
assert(readObjectType(dis) == 'c')
|
assert(readObjectType(dis) == 'c')
|
||||||
val address = readString(dis)
|
val address = readString(dis)
|
||||||
assert(readObjectType(dis) == 'i')
|
assert(readObjectType(dis) == 'i')
|
||||||
val port = readInt(dis)
|
val port = readInt(dis)
|
||||||
DotnetBackend.setCallbackClient(address, port);
|
DotnetBackend.setCallbackClient(address, port)
|
||||||
writeInt(dos, 0)
|
writeInt(dos, 0)
|
||||||
writeType(dos, "void")
|
writeType(dos, "void")
|
||||||
case "closeCallback" =>
|
case "closeCallback" =>
|
||||||
|
@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
case _ => dos.writeInt(-1)
|
case _ => dos.writeInt(-1)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
|
ThreadPool
|
||||||
|
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||||
}
|
}
|
||||||
|
|
||||||
bos.toByteArray
|
bos.toByteArray
|
||||||
|
@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
"invalid method " + methodName + " for object " + objId)
|
"invalid method " + methodName + " for object " + objId)
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
case e: Exception =>
|
case e: Throwable =>
|
||||||
val jvmObj = JVMObjectTracker.get(objId)
|
val jvmObj = JVMObjectTracker.get(objId)
|
||||||
val jvmObjName = jvmObj match {
|
val jvmObjName = jvmObj match {
|
||||||
case Some(jObj) => jObj.getClass.getName
|
case Some(jObj) => jObj.getClass.getName
|
||||||
case None => "NullObject"
|
case None => "NullObject"
|
||||||
}
|
}
|
||||||
val argsStr = args.map(arg => {
|
val argsStr = args
|
||||||
if (arg != null) {
|
.map(arg => {
|
||||||
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
if (arg != null) {
|
||||||
} else {
|
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
||||||
"[Value: NULL]"
|
} else {
|
||||||
}
|
"[Value: NULL]"
|
||||||
}).mkString(", ")
|
}
|
||||||
|
})
|
||||||
|
.mkString(", ")
|
||||||
|
|
||||||
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
|
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the .NET Foundation under one or more agreements.
|
||||||
|
* The .NET Foundation licenses this file to you under the MIT license.
|
||||||
|
* See the LICENSE file in the project root for more information.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.api.dotnet
|
||||||
|
|
||||||
|
import java.util.concurrent.{ExecutorService, Executors}
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Pool of thread executors. There should be a 1-1 correspondence between C# threads
|
||||||
|
* and Java threads.
|
||||||
|
*/
|
||||||
|
object ThreadPool {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map from threadId to corresponding executor.
|
||||||
|
*/
|
||||||
|
private val executors: mutable.HashMap[Int, ExecutorService] =
|
||||||
|
new mutable.HashMap[Int, ExecutorService]()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Run some code on a particular thread.
|
||||||
|
*
|
||||||
|
* @param threadId Integer id of the thread.
|
||||||
|
* @param task Function to run on the thread.
|
||||||
|
*/
|
||||||
|
def run(threadId: Int, task: () => Unit): Unit = {
|
||||||
|
val executor = getOrCreateExecutor(threadId)
|
||||||
|
val future = executor.submit(new Runnable {
|
||||||
|
override def run(): Unit = task()
|
||||||
|
})
|
||||||
|
|
||||||
|
future.get()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Try to delete a particular thread.
|
||||||
|
*
|
||||||
|
* @param threadId Integer id of the thread.
|
||||||
|
* @return True if successful, false if thread does not exist.
|
||||||
|
*/
|
||||||
|
def tryDeleteThread(threadId: Int): Boolean = synchronized {
|
||||||
|
executors.remove(threadId) match {
|
||||||
|
case Some(executorService) =>
|
||||||
|
executorService.shutdown()
|
||||||
|
true
|
||||||
|
case None => false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the executor if it exists, otherwise create a new one.
|
||||||
|
*
|
||||||
|
* @param id Integer id of the thread.
|
||||||
|
* @return The new or existing executor with the given id.
|
||||||
|
*/
|
||||||
|
private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
|
||||||
|
executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet
|
||||||
|
|
||||||
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
|
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
|
||||||
|
|
||||||
|
import scala.collection.mutable.HashMap
|
||||||
|
import scala.language.existentials
|
||||||
|
|
||||||
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
|
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
|
||||||
import org.apache.spark.api.dotnet.SerDe._
|
import org.apache.spark.api.dotnet.SerDe._
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.util.Utils
|
import org.apache.spark.util.Utils
|
||||||
|
|
||||||
import scala.collection.mutable.HashMap
|
|
||||||
import scala.language.existentials
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handler for DotnetBackend.
|
* Handler for DotnetBackend.
|
||||||
* This implementation is similar to RBackendHandler.
|
* This implementation is similar to RBackendHandler.
|
||||||
|
@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
|
|
||||||
// First bit is isStatic
|
// First bit is isStatic
|
||||||
val isStatic = readBoolean(dis)
|
val isStatic = readBoolean(dis)
|
||||||
|
val threadId = readInt(dis)
|
||||||
val objId = readString(dis)
|
val objId = readString(dis)
|
||||||
val methodName = readString(dis)
|
val methodName = readString(dis)
|
||||||
val numArgs = readInt(dis)
|
val numArgs = readInt(dis)
|
||||||
|
@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
logError(s"Removing $objId failed", e)
|
logError(s"Removing $objId failed", e)
|
||||||
writeInt(dos, -1)
|
writeInt(dos, -1)
|
||||||
}
|
}
|
||||||
|
case "rmThread" =>
|
||||||
|
try {
|
||||||
|
assert(readObjectType(dis) == 'i')
|
||||||
|
val threadToDelete = readInt(dis)
|
||||||
|
val result = ThreadPool.tryDeleteThread(threadToDelete)
|
||||||
|
writeInt(dos, 0)
|
||||||
|
writeObject(dos, result.asInstanceOf[AnyRef])
|
||||||
|
} catch {
|
||||||
|
case e: Exception =>
|
||||||
|
logError(s"Removing thread $threadId failed", e)
|
||||||
|
writeInt(dos, -1)
|
||||||
|
}
|
||||||
case "connectCallback" =>
|
case "connectCallback" =>
|
||||||
assert(readObjectType(dis) == 'c')
|
assert(readObjectType(dis) == 'c')
|
||||||
val address = readString(dis)
|
val address = readString(dis)
|
||||||
assert(readObjectType(dis) == 'i')
|
assert(readObjectType(dis) == 'i')
|
||||||
val port = readInt(dis)
|
val port = readInt(dis)
|
||||||
DotnetBackend.setCallbackClient(address, port);
|
DotnetBackend.setCallbackClient(address, port)
|
||||||
writeInt(dos, 0)
|
writeInt(dos, 0)
|
||||||
writeType(dos, "void")
|
writeType(dos, "void")
|
||||||
case "closeCallback" =>
|
case "closeCallback" =>
|
||||||
|
@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
case _ => dos.writeInt(-1)
|
case _ => dos.writeInt(-1)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
|
ThreadPool
|
||||||
|
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||||
}
|
}
|
||||||
|
|
||||||
bos.toByteArray
|
bos.toByteArray
|
||||||
|
@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
"invalid method " + methodName + " for object " + objId)
|
"invalid method " + methodName + " for object " + objId)
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
case e: Exception =>
|
case e: Throwable =>
|
||||||
val jvmObj = JVMObjectTracker.get(objId)
|
val jvmObj = JVMObjectTracker.get(objId)
|
||||||
val jvmObjName = jvmObj match {
|
val jvmObjName = jvmObj match {
|
||||||
case Some(jObj) => jObj.getClass.getName
|
case Some(jObj) => jObj.getClass.getName
|
||||||
case None => "NullObject"
|
case None => "NullObject"
|
||||||
}
|
}
|
||||||
val argsStr = args.map(arg => {
|
val argsStr = args
|
||||||
if (arg != null) {
|
.map(arg => {
|
||||||
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
if (arg != null) {
|
||||||
} else {
|
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
||||||
"[Value: NULL]"
|
} else {
|
||||||
}
|
"[Value: NULL]"
|
||||||
}).mkString(", ")
|
}
|
||||||
|
})
|
||||||
|
.mkString(", ")
|
||||||
|
|
||||||
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
|
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the .NET Foundation under one or more agreements.
|
||||||
|
* The .NET Foundation licenses this file to you under the MIT license.
|
||||||
|
* See the LICENSE file in the project root for more information.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.api.dotnet
|
||||||
|
|
||||||
|
import java.util.concurrent.{ExecutorService, Executors}
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Pool of thread executors. There should be a 1-1 correspondence between C# threads
|
||||||
|
* and Java threads.
|
||||||
|
*/
|
||||||
|
object ThreadPool {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map from threadId to corresponding executor.
|
||||||
|
*/
|
||||||
|
private val executors: mutable.HashMap[Int, ExecutorService] =
|
||||||
|
new mutable.HashMap[Int, ExecutorService]()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Run some code on a particular thread.
|
||||||
|
*
|
||||||
|
* @param threadId Integer id of the thread.
|
||||||
|
* @param task Function to run on the thread.
|
||||||
|
*/
|
||||||
|
def run(threadId: Int, task: () => Unit): Unit = {
|
||||||
|
val executor = getOrCreateExecutor(threadId)
|
||||||
|
val future = executor.submit(new Runnable {
|
||||||
|
override def run(): Unit = task()
|
||||||
|
})
|
||||||
|
|
||||||
|
future.get()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Try to delete a particular thread.
|
||||||
|
*
|
||||||
|
* @param threadId Integer id of the thread.
|
||||||
|
* @return True if successful, false if thread does not exist.
|
||||||
|
*/
|
||||||
|
def tryDeleteThread(threadId: Int): Boolean = synchronized {
|
||||||
|
executors.remove(threadId) match {
|
||||||
|
case Some(executorService) =>
|
||||||
|
executorService.shutdown()
|
||||||
|
true
|
||||||
|
case None => false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the executor if it exists, otherwise create a new one.
|
||||||
|
*
|
||||||
|
* @param id Integer id of the thread.
|
||||||
|
* @return The new or existing executor with the given id.
|
||||||
|
*/
|
||||||
|
private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
|
||||||
|
executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
|
||||||
|
}
|
||||||
|
}
|
|
@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
|
|
||||||
// First bit is isStatic
|
// First bit is isStatic
|
||||||
val isStatic = readBoolean(dis)
|
val isStatic = readBoolean(dis)
|
||||||
|
val threadId = readInt(dis)
|
||||||
val objId = readString(dis)
|
val objId = readString(dis)
|
||||||
val methodName = readString(dis)
|
val methodName = readString(dis)
|
||||||
val numArgs = readInt(dis)
|
val numArgs = readInt(dis)
|
||||||
|
@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
logError(s"Removing $objId failed", e)
|
logError(s"Removing $objId failed", e)
|
||||||
writeInt(dos, -1)
|
writeInt(dos, -1)
|
||||||
}
|
}
|
||||||
|
case "rmThread" =>
|
||||||
|
try {
|
||||||
|
assert(readObjectType(dis) == 'i')
|
||||||
|
val threadToDelete = readInt(dis)
|
||||||
|
val result = ThreadPool.tryDeleteThread(threadToDelete)
|
||||||
|
writeInt(dos, 0)
|
||||||
|
writeObject(dos, result.asInstanceOf[AnyRef])
|
||||||
|
} catch {
|
||||||
|
case e: Exception =>
|
||||||
|
logError(s"Removing thread $threadId failed", e)
|
||||||
|
writeInt(dos, -1)
|
||||||
|
}
|
||||||
case "connectCallback" =>
|
case "connectCallback" =>
|
||||||
assert(readObjectType(dis) == 'c')
|
assert(readObjectType(dis) == 'c')
|
||||||
val address = readString(dis)
|
val address = readString(dis)
|
||||||
assert(readObjectType(dis) == 'i')
|
assert(readObjectType(dis) == 'i')
|
||||||
val port = readInt(dis)
|
val port = readInt(dis)
|
||||||
DotnetBackend.setCallbackClient(address, port);
|
DotnetBackend.setCallbackClient(address, port)
|
||||||
writeInt(dos, 0)
|
writeInt(dos, 0)
|
||||||
writeType(dos, "void")
|
writeType(dos, "void")
|
||||||
case "closeCallback" =>
|
case "closeCallback" =>
|
||||||
|
@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
case _ => dos.writeInt(-1)
|
case _ => dos.writeInt(-1)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
|
ThreadPool
|
||||||
|
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||||
}
|
}
|
||||||
|
|
||||||
bos.toByteArray
|
bos.toByteArray
|
||||||
|
@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
|
||||||
"invalid method " + methodName + " for object " + objId)
|
"invalid method " + methodName + " for object " + objId)
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
case e: Exception =>
|
case e: Throwable =>
|
||||||
val jvmObj = JVMObjectTracker.get(objId)
|
val jvmObj = JVMObjectTracker.get(objId)
|
||||||
val jvmObjName = jvmObj match {
|
val jvmObjName = jvmObj match {
|
||||||
case Some(jObj) => jObj.getClass.getName
|
case Some(jObj) => jObj.getClass.getName
|
||||||
case None => "NullObject"
|
case None => "NullObject"
|
||||||
}
|
}
|
||||||
val argsStr = args.map(arg => {
|
val argsStr = args
|
||||||
if (arg != null) {
|
.map(arg => {
|
||||||
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
if (arg != null) {
|
||||||
} else {
|
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
||||||
"[Value: NULL]"
|
} else {
|
||||||
}
|
"[Value: NULL]"
|
||||||
}).mkString(", ")
|
}
|
||||||
|
})
|
||||||
|
.mkString(", ")
|
||||||
|
|
||||||
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
|
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the .NET Foundation under one or more agreements.
|
||||||
|
* The .NET Foundation licenses this file to you under the MIT license.
|
||||||
|
* See the LICENSE file in the project root for more information.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.api.dotnet
|
||||||
|
|
||||||
|
import java.util.concurrent.{ExecutorService, Executors}
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Pool of thread executors. There should be a 1-1 correspondence between C# threads
|
||||||
|
* and Java threads.
|
||||||
|
*/
|
||||||
|
object ThreadPool {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map from threadId to corresponding executor.
|
||||||
|
*/
|
||||||
|
private val executors: mutable.HashMap[Int, ExecutorService] =
|
||||||
|
new mutable.HashMap[Int, ExecutorService]()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Run some code on a particular thread.
|
||||||
|
*
|
||||||
|
* @param threadId Integer id of the thread.
|
||||||
|
* @param task Function to run on the thread.
|
||||||
|
*/
|
||||||
|
def run(threadId: Int, task: () => Unit): Unit = {
|
||||||
|
val executor = getOrCreateExecutor(threadId)
|
||||||
|
val future = executor.submit(new Runnable {
|
||||||
|
override def run(): Unit = task()
|
||||||
|
})
|
||||||
|
|
||||||
|
future.get()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Try to delete a particular thread.
|
||||||
|
*
|
||||||
|
* @param threadId Integer id of the thread.
|
||||||
|
* @return True if successful, false if thread does not exist.
|
||||||
|
*/
|
||||||
|
def tryDeleteThread(threadId: Int): Boolean = synchronized {
|
||||||
|
executors.remove(threadId) match {
|
||||||
|
case Some(executorService) =>
|
||||||
|
executorService.shutdown()
|
||||||
|
true
|
||||||
|
case None => false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the executor if it exists, otherwise create a new one.
|
||||||
|
*
|
||||||
|
* @param id Integer id of the thread.
|
||||||
|
* @return The new or existing executor with the given id.
|
||||||
|
*/
|
||||||
|
private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
|
||||||
|
executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
|
||||||
|
}
|
||||||
|
}
|
Загрузка…
Ссылка в новой задаче