зеркало из https://github.com/dotnet/spark.git
Using (processId, threadId) as key to mantain threadpool executor instead of only threadId (#793)
This commit is contained in:
Родитель
d3039e2f32
Коммит
d608d8e6c5
|
@ -3,6 +3,7 @@
|
|||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Diagnostics;
|
||||
using System.Threading;
|
||||
using Microsoft.Spark.Interop;
|
||||
using Microsoft.Spark.Interop.Ipc;
|
||||
|
@ -68,8 +69,9 @@ namespace Microsoft.Spark.E2ETest.IpcTests
|
|||
[Fact]
|
||||
public void TestTryAddThread()
|
||||
{
|
||||
int processId = Process.GetCurrentProcess().Id;
|
||||
using var threadPool = new JvmThreadPoolGC(
|
||||
_loggerService, _jvmBridge, TimeSpan.FromMinutes(30));
|
||||
_loggerService, _jvmBridge, TimeSpan.FromMinutes(30), processId);
|
||||
|
||||
var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
|
||||
thread.Start();
|
||||
|
@ -88,14 +90,15 @@ namespace Microsoft.Spark.E2ETest.IpcTests
|
|||
[Fact]
|
||||
public void TestRmThread()
|
||||
{
|
||||
int processId = Process.GetCurrentProcess().Id;
|
||||
// Create a thread and ensure that it is initialized in the JVM ThreadPool.
|
||||
var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
|
||||
thread.Start();
|
||||
thread.Join();
|
||||
|
||||
// First call should return true. Second call should return false.
|
||||
Assert.True((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId));
|
||||
Assert.False((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId));
|
||||
Assert.True((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", processId, thread.ManagedThreadId));
|
||||
Assert.False((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", processId, thread.ManagedThreadId));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Net;
|
||||
using System.Text;
|
||||
|
@ -34,6 +35,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
private const int SocketBufferThreshold = 3;
|
||||
private const int ThreadIdForRepl = 1;
|
||||
|
||||
private readonly int _processId = Process.GetCurrentProcess().Id;
|
||||
private readonly SemaphoreSlim _socketSemaphore;
|
||||
private readonly ConcurrentQueue<ISocketWrapper> _sockets =
|
||||
new ConcurrentQueue<ISocketWrapper>();
|
||||
|
@ -54,7 +56,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
_logger.LogInfo($"JvMBridge port is {portNumber}");
|
||||
|
||||
_jvmThreadPoolGC = new JvmThreadPoolGC(
|
||||
_logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
|
||||
_logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval, _processId);
|
||||
|
||||
_isRunningRepl = SparkEnvironment.ConfigurationService.IsRunningRepl();
|
||||
|
||||
|
@ -203,12 +205,15 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
// call will never run because DotnetHandler will assign the method call to
|
||||
// run on the same thread that `AwaitTermination` is running on.
|
||||
Thread thread = _isRunningRepl ? null : Thread.CurrentThread;
|
||||
int threadId = thread == null ? ThreadIdForRepl : thread.ManagedThreadId;
|
||||
MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
|
||||
payloadMemoryStream.Position = 0;
|
||||
|
||||
PayloadHelper.BuildPayload(
|
||||
payloadMemoryStream,
|
||||
isStatic,
|
||||
thread == null ? ThreadIdForRepl : thread.ManagedThreadId,
|
||||
_processId,
|
||||
threadId,
|
||||
classNameOrJvmObjectReference,
|
||||
methodName,
|
||||
args);
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Threading;
|
||||
using Microsoft.Spark.Services;
|
||||
|
||||
|
@ -25,6 +26,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
private readonly ILoggerService _loggerService;
|
||||
private readonly IJvmBridge _jvmBridge;
|
||||
private readonly TimeSpan _threadGCInterval;
|
||||
private readonly int _processId;
|
||||
private readonly ConcurrentDictionary<int, Thread> _activeThreads;
|
||||
|
||||
private readonly object _activeThreadGCTimerLock;
|
||||
|
@ -36,11 +38,13 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
/// <param name="loggerService">Logger service.</param>
|
||||
/// <param name="jvmBridge">The JvmBridge used to call JVM methods.</param>
|
||||
/// <param name="threadGCInterval">The interval to GC finished threads.</param>
|
||||
public JvmThreadPoolGC(ILoggerService loggerService, IJvmBridge jvmBridge, TimeSpan threadGCInterval)
|
||||
/// <param name="processId"> The ID of the process.</param>
|
||||
public JvmThreadPoolGC(ILoggerService loggerService, IJvmBridge jvmBridge, TimeSpan threadGCInterval, int processId)
|
||||
{
|
||||
_loggerService = loggerService;
|
||||
_jvmBridge = jvmBridge;
|
||||
_threadGCInterval = threadGCInterval;
|
||||
_processId = processId;
|
||||
_activeThreads = new ConcurrentDictionary<int, Thread>();
|
||||
|
||||
_activeThreadGCTimerLock = new object();
|
||||
|
@ -106,7 +110,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
// class does not need to call Join() on the .NET Thread. However, this class is
|
||||
// responsible for sending the rmThread command to the JVM to trigger disposal
|
||||
// of the corresponding JVM thread.
|
||||
if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId))
|
||||
if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", _processId, threadId))
|
||||
{
|
||||
_loggerService.LogDebug($"GC'd JVM thread {threadId}.");
|
||||
return true;
|
||||
|
|
|
@ -39,6 +39,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
internal static void BuildPayload(
|
||||
MemoryStream destination,
|
||||
bool isStaticMethod,
|
||||
int processId,
|
||||
int threadId,
|
||||
object classNameOrJvmObjectReference,
|
||||
string methodName,
|
||||
|
@ -49,6 +50,7 @@ namespace Microsoft.Spark.Interop.Ipc
|
|||
destination.Position += sizeof(int);
|
||||
|
||||
SerDe.Write(destination, isStaticMethod);
|
||||
SerDe.Write(destination, processId);
|
||||
SerDe.Write(destination, threadId);
|
||||
SerDe.Write(destination, classNameOrJvmObjectReference.ToString());
|
||||
SerDe.Write(destination, methodName);
|
||||
|
|
|
@ -41,6 +41,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
|
|||
|
||||
// First bit is isStatic
|
||||
val isStatic = serDe.readBoolean(dis)
|
||||
val processId = serDe.readInt(dis)
|
||||
val threadId = serDe.readInt(dis)
|
||||
val objId = serDe.readString(dis)
|
||||
val methodName = serDe.readString(dis)
|
||||
|
@ -67,9 +68,11 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
|
|||
}
|
||||
case "rmThread" =>
|
||||
try {
|
||||
assert(serDe.readObjectType(dis) == 'i')
|
||||
val processId = serDe.readInt(dis)
|
||||
assert(serDe.readObjectType(dis) == 'i')
|
||||
val threadToDelete = serDe.readInt(dis)
|
||||
val result = ThreadPool.tryDeleteThread(threadToDelete)
|
||||
val result = ThreadPool.tryDeleteThread(processId, threadToDelete)
|
||||
serDe.writeInt(dos, 0)
|
||||
serDe.writeObject(dos, result.asInstanceOf[AnyRef])
|
||||
} catch {
|
||||
|
@ -95,7 +98,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
|
|||
}
|
||||
} else {
|
||||
ThreadPool
|
||||
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||
.run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||
}
|
||||
|
||||
bos.toByteArray
|
||||
|
|
|
@ -17,19 +17,19 @@ import scala.collection.mutable
|
|||
object ThreadPool {
|
||||
|
||||
/**
|
||||
* Map from threadId to corresponding executor.
|
||||
* Map from (processId, threadId) to corresponding executor.
|
||||
*/
|
||||
private val executors: mutable.HashMap[Int, ExecutorService] =
|
||||
new mutable.HashMap[Int, ExecutorService]()
|
||||
private val executors: mutable.HashMap[(Int, Int), ExecutorService] =
|
||||
new mutable.HashMap[(Int, Int), ExecutorService]()
|
||||
|
||||
/**
|
||||
* Run some code on a particular thread.
|
||||
*
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @param task Function to run on the thread.
|
||||
*/
|
||||
def run(threadId: Int, task: () => Unit): Unit = {
|
||||
val executor = getOrCreateExecutor(threadId)
|
||||
def run(processId: Int, threadId: Int, task: () => Unit): Unit = {
|
||||
val executor = getOrCreateExecutor(processId, threadId)
|
||||
val future = executor.submit(new Runnable {
|
||||
override def run(): Unit = task()
|
||||
})
|
||||
|
@ -39,12 +39,12 @@ object ThreadPool {
|
|||
|
||||
/**
|
||||
* Try to delete a particular thread.
|
||||
*
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @return True if successful, false if thread does not exist.
|
||||
*/
|
||||
def tryDeleteThread(threadId: Int): Boolean = synchronized {
|
||||
executors.remove(threadId) match {
|
||||
def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized {
|
||||
executors.remove((processId, threadId)) match {
|
||||
case Some(executorService) =>
|
||||
executorService.shutdown()
|
||||
true
|
||||
|
@ -62,11 +62,11 @@ object ThreadPool {
|
|||
|
||||
/**
|
||||
* Get the executor if it exists, otherwise create a new one.
|
||||
*
|
||||
* @param id Integer id of the thread.
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @return The new or existing executor with the given id.
|
||||
*/
|
||||
private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
|
||||
executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
|
||||
private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized {
|
||||
executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,7 +34,8 @@ class DotnetBackendHandlerTest {
|
|||
val message = givenMessage(m => {
|
||||
val serDe = new SerDe(null)
|
||||
m.writeBoolean(true) // static method
|
||||
m.writeInt(1) // threadId
|
||||
serDe.writeInt(m, 1) // processId
|
||||
serDe.writeInt(m, 1) // threadId
|
||||
serDe.writeString(m, "DotnetHandler") // class name
|
||||
serDe.writeString(m, "connectCallback") // command (method) name
|
||||
m.writeInt(2) // number of arguments
|
||||
|
|
|
@ -41,6 +41,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
|
|||
|
||||
// First bit is isStatic
|
||||
val isStatic = serDe.readBoolean(dis)
|
||||
val processId = serDe.readInt(dis)
|
||||
val threadId = serDe.readInt(dis)
|
||||
val objId = serDe.readString(dis)
|
||||
val methodName = serDe.readString(dis)
|
||||
|
@ -67,9 +68,11 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
|
|||
}
|
||||
case "rmThread" =>
|
||||
try {
|
||||
assert(serDe.readObjectType(dis) == 'i')
|
||||
val processId = serDe.readInt(dis)
|
||||
assert(serDe.readObjectType(dis) == 'i')
|
||||
val threadToDelete = serDe.readInt(dis)
|
||||
val result = ThreadPool.tryDeleteThread(threadToDelete)
|
||||
val result = ThreadPool.tryDeleteThread(processId, threadToDelete)
|
||||
serDe.writeInt(dos, 0)
|
||||
serDe.writeObject(dos, result.asInstanceOf[AnyRef])
|
||||
} catch {
|
||||
|
@ -99,7 +102,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
|
|||
}
|
||||
} else {
|
||||
ThreadPool
|
||||
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||
.run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||
}
|
||||
|
||||
bos.toByteArray
|
||||
|
|
|
@ -17,19 +17,19 @@ import scala.collection.mutable
|
|||
object ThreadPool {
|
||||
|
||||
/**
|
||||
* Map from threadId to corresponding executor.
|
||||
* Map from (processId, threadId) to corresponding executor.
|
||||
*/
|
||||
private val executors: mutable.HashMap[Int, ExecutorService] =
|
||||
new mutable.HashMap[Int, ExecutorService]()
|
||||
private val executors: mutable.HashMap[(Int, Int), ExecutorService] =
|
||||
new mutable.HashMap[(Int, Int), ExecutorService]()
|
||||
|
||||
/**
|
||||
* Run some code on a particular thread.
|
||||
*
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @param task Function to run on the thread.
|
||||
*/
|
||||
def run(threadId: Int, task: () => Unit): Unit = {
|
||||
val executor = getOrCreateExecutor(threadId)
|
||||
def run(processId: Int, threadId: Int, task: () => Unit): Unit = {
|
||||
val executor = getOrCreateExecutor(processId, threadId)
|
||||
val future = executor.submit(new Runnable {
|
||||
override def run(): Unit = task()
|
||||
})
|
||||
|
@ -39,12 +39,12 @@ object ThreadPool {
|
|||
|
||||
/**
|
||||
* Try to delete a particular thread.
|
||||
*
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @return True if successful, false if thread does not exist.
|
||||
*/
|
||||
def tryDeleteThread(threadId: Int): Boolean = synchronized {
|
||||
executors.remove(threadId) match {
|
||||
def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized {
|
||||
executors.remove((processId, threadId)) match {
|
||||
case Some(executorService) =>
|
||||
executorService.shutdown()
|
||||
true
|
||||
|
@ -62,11 +62,11 @@ object ThreadPool {
|
|||
|
||||
/**
|
||||
* Get the executor if it exists, otherwise create a new one.
|
||||
*
|
||||
* @param id Integer id of the thread.
|
||||
* @return The new or existing executor with the given id.
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @return The new or existing executor with the given (processId, threadId).
|
||||
*/
|
||||
private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
|
||||
executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
|
||||
private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized {
|
||||
executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,7 +35,8 @@ class DotnetBackendHandlerTest {
|
|||
val message = givenMessage(m => {
|
||||
val serDe = new SerDe(null)
|
||||
m.writeBoolean(true) // static method
|
||||
m.writeInt(1) // threadId
|
||||
serDe.writeInt(m, 1) // processId
|
||||
serDe.writeInt(m, 1) // threadId
|
||||
serDe.writeString(m, "DotnetHandler") // class name
|
||||
serDe.writeString(m, "connectCallback") // command (method) name
|
||||
m.writeInt(2) // number of arguments
|
||||
|
|
|
@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
|
|||
|
||||
// First bit is isStatic
|
||||
val isStatic = serDe.readBoolean(dis)
|
||||
val processId = serDe.readInt(dis)
|
||||
val threadId = serDe.readInt(dis)
|
||||
val objId = serDe.readString(dis)
|
||||
val methodName = serDe.readString(dis)
|
||||
|
@ -68,9 +69,11 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
|
|||
}
|
||||
case "rmThread" =>
|
||||
try {
|
||||
assert(serDe.readObjectType(dis) == 'i')
|
||||
val processId = serDe.readInt(dis)
|
||||
assert(serDe.readObjectType(dis) == 'i')
|
||||
val threadToDelete = serDe.readInt(dis)
|
||||
val result = ThreadPool.tryDeleteThread(threadToDelete)
|
||||
val result = ThreadPool.tryDeleteThread(processId, threadToDelete)
|
||||
serDe.writeInt(dos, 0)
|
||||
serDe.writeObject(dos, result.asInstanceOf[AnyRef])
|
||||
} catch {
|
||||
|
@ -99,7 +102,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
|
|||
}
|
||||
} else {
|
||||
ThreadPool
|
||||
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||
.run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||
}
|
||||
|
||||
bos.toByteArray
|
||||
|
|
|
@ -17,19 +17,19 @@ import scala.collection.mutable
|
|||
object ThreadPool {
|
||||
|
||||
/**
|
||||
* Map from threadId to corresponding executor.
|
||||
* Map from (processId, threadId) to corresponding executor.
|
||||
*/
|
||||
private val executors: mutable.HashMap[Int, ExecutorService] =
|
||||
new mutable.HashMap[Int, ExecutorService]()
|
||||
private val executors: mutable.HashMap[(Int, Int), ExecutorService] =
|
||||
new mutable.HashMap[(Int, Int), ExecutorService]()
|
||||
|
||||
/**
|
||||
* Run some code on a particular thread.
|
||||
*
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @param task Function to run on the thread.
|
||||
*/
|
||||
def run(threadId: Int, task: () => Unit): Unit = {
|
||||
val executor = getOrCreateExecutor(threadId)
|
||||
def run(processId: Int, threadId: Int, task: () => Unit): Unit = {
|
||||
val executor = getOrCreateExecutor(processId, threadId)
|
||||
val future = executor.submit(new Runnable {
|
||||
override def run(): Unit = task()
|
||||
})
|
||||
|
@ -39,12 +39,12 @@ object ThreadPool {
|
|||
|
||||
/**
|
||||
* Try to delete a particular thread.
|
||||
*
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @return True if successful, false if thread does not exist.
|
||||
*/
|
||||
def tryDeleteThread(threadId: Int): Boolean = synchronized {
|
||||
executors.remove(threadId) match {
|
||||
def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized {
|
||||
executors.remove((processId, threadId)) match {
|
||||
case Some(executorService) =>
|
||||
executorService.shutdown()
|
||||
true
|
||||
|
@ -62,11 +62,11 @@ object ThreadPool {
|
|||
|
||||
/**
|
||||
* Get the executor if it exists, otherwise create a new one.
|
||||
*
|
||||
* @param id Integer id of the thread.
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @return The new or existing executor with the given id.
|
||||
*/
|
||||
private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
|
||||
executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
|
||||
private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized {
|
||||
executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,7 +36,8 @@ class DotnetBackendHandlerTest {
|
|||
val message = givenMessage(m => {
|
||||
val serDe = new SerDe(null)
|
||||
m.writeBoolean(true) // static method
|
||||
m.writeInt(1) // threadId
|
||||
serDe.writeInt(m, 1) // processId
|
||||
serDe.writeInt(m, 1) // threadId
|
||||
serDe.writeString(m, "DotnetHandler") // class name
|
||||
serDe.writeString(m, "connectCallback") // command (method) name
|
||||
m.writeInt(2) // number of arguments
|
||||
|
|
Загрузка…
Ссылка в новой задаче