Using (processId, threadId) as key to mantain threadpool executor instead of only threadId (#793)

This commit is contained in:
高阳阳 2021-02-04 03:10:44 +08:00 коммит произвёл GitHub
Родитель d3039e2f32
Коммит d608d8e6c5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
13 изменённых файлов: 82 добавлений и 56 удалений

Просмотреть файл

@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
using System.Diagnostics;
using System.Threading;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
@ -68,8 +69,9 @@ namespace Microsoft.Spark.E2ETest.IpcTests
[Fact]
public void TestTryAddThread()
{
int processId = Process.GetCurrentProcess().Id;
using var threadPool = new JvmThreadPoolGC(
_loggerService, _jvmBridge, TimeSpan.FromMinutes(30));
_loggerService, _jvmBridge, TimeSpan.FromMinutes(30), processId);
var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
thread.Start();
@ -88,14 +90,15 @@ namespace Microsoft.Spark.E2ETest.IpcTests
[Fact]
public void TestRmThread()
{
int processId = Process.GetCurrentProcess().Id;
// Create a thread and ensure that it is initialized in the JVM ThreadPool.
var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
thread.Start();
thread.Join();
// First call should return true. Second call should return false.
Assert.True((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId));
Assert.False((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId));
Assert.True((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", processId, thread.ManagedThreadId));
Assert.False((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", processId, thread.ManagedThreadId));
}
/// <summary>

Просмотреть файл

@ -5,6 +5,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Text;
@ -34,6 +35,7 @@ namespace Microsoft.Spark.Interop.Ipc
private const int SocketBufferThreshold = 3;
private const int ThreadIdForRepl = 1;
private readonly int _processId = Process.GetCurrentProcess().Id;
private readonly SemaphoreSlim _socketSemaphore;
private readonly ConcurrentQueue<ISocketWrapper> _sockets =
new ConcurrentQueue<ISocketWrapper>();
@ -54,7 +56,7 @@ namespace Microsoft.Spark.Interop.Ipc
_logger.LogInfo($"JvMBridge port is {portNumber}");
_jvmThreadPoolGC = new JvmThreadPoolGC(
_logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
_logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval, _processId);
_isRunningRepl = SparkEnvironment.ConfigurationService.IsRunningRepl();
@ -203,12 +205,15 @@ namespace Microsoft.Spark.Interop.Ipc
// call will never run because DotnetHandler will assign the method call to
// run on the same thread that `AwaitTermination` is running on.
Thread thread = _isRunningRepl ? null : Thread.CurrentThread;
int threadId = thread == null ? ThreadIdForRepl : thread.ManagedThreadId;
MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
payloadMemoryStream.Position = 0;
PayloadHelper.BuildPayload(
payloadMemoryStream,
isStatic,
thread == null ? ThreadIdForRepl : thread.ManagedThreadId,
_processId,
threadId,
classNameOrJvmObjectReference,
methodName,
args);

Просмотреть файл

@ -5,6 +5,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using Microsoft.Spark.Services;
@ -25,6 +26,7 @@ namespace Microsoft.Spark.Interop.Ipc
private readonly ILoggerService _loggerService;
private readonly IJvmBridge _jvmBridge;
private readonly TimeSpan _threadGCInterval;
private readonly int _processId;
private readonly ConcurrentDictionary<int, Thread> _activeThreads;
private readonly object _activeThreadGCTimerLock;
@ -36,11 +38,13 @@ namespace Microsoft.Spark.Interop.Ipc
/// <param name="loggerService">Logger service.</param>
/// <param name="jvmBridge">The JvmBridge used to call JVM methods.</param>
/// <param name="threadGCInterval">The interval to GC finished threads.</param>
public JvmThreadPoolGC(ILoggerService loggerService, IJvmBridge jvmBridge, TimeSpan threadGCInterval)
/// <param name="processId"> The ID of the process.</param>
public JvmThreadPoolGC(ILoggerService loggerService, IJvmBridge jvmBridge, TimeSpan threadGCInterval, int processId)
{
_loggerService = loggerService;
_jvmBridge = jvmBridge;
_threadGCInterval = threadGCInterval;
_processId = processId;
_activeThreads = new ConcurrentDictionary<int, Thread>();
_activeThreadGCTimerLock = new object();
@ -106,7 +110,7 @@ namespace Microsoft.Spark.Interop.Ipc
// class does not need to call Join() on the .NET Thread. However, this class is
// responsible for sending the rmThread command to the JVM to trigger disposal
// of the corresponding JVM thread.
if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId))
if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", _processId, threadId))
{
_loggerService.LogDebug($"GC'd JVM thread {threadId}.");
return true;

Просмотреть файл

@ -39,6 +39,7 @@ namespace Microsoft.Spark.Interop.Ipc
internal static void BuildPayload(
MemoryStream destination,
bool isStaticMethod,
int processId,
int threadId,
object classNameOrJvmObjectReference,
string methodName,
@ -49,6 +50,7 @@ namespace Microsoft.Spark.Interop.Ipc
destination.Position += sizeof(int);
SerDe.Write(destination, isStaticMethod);
SerDe.Write(destination, processId);
SerDe.Write(destination, threadId);
SerDe.Write(destination, classNameOrJvmObjectReference.ToString());
SerDe.Write(destination, methodName);

Просмотреть файл

@ -41,6 +41,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
// First bit is isStatic
val isStatic = serDe.readBoolean(dis)
val processId = serDe.readInt(dis)
val threadId = serDe.readInt(dis)
val objId = serDe.readString(dis)
val methodName = serDe.readString(dis)
@ -67,9 +68,11 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
}
case "rmThread" =>
try {
assert(serDe.readObjectType(dis) == 'i')
val processId = serDe.readInt(dis)
assert(serDe.readObjectType(dis) == 'i')
val threadToDelete = serDe.readInt(dis)
val result = ThreadPool.tryDeleteThread(threadToDelete)
val result = ThreadPool.tryDeleteThread(processId, threadToDelete)
serDe.writeInt(dos, 0)
serDe.writeObject(dos, result.asInstanceOf[AnyRef])
} catch {
@ -95,7 +98,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
}
} else {
ThreadPool
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
.run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
}
bos.toByteArray

Просмотреть файл

@ -17,19 +17,19 @@ import scala.collection.mutable
object ThreadPool {
/**
* Map from threadId to corresponding executor.
* Map from (processId, threadId) to corresponding executor.
*/
private val executors: mutable.HashMap[Int, ExecutorService] =
new mutable.HashMap[Int, ExecutorService]()
private val executors: mutable.HashMap[(Int, Int), ExecutorService] =
new mutable.HashMap[(Int, Int), ExecutorService]()
/**
* Run some code on a particular thread.
*
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @param task Function to run on the thread.
*/
def run(threadId: Int, task: () => Unit): Unit = {
val executor = getOrCreateExecutor(threadId)
def run(processId: Int, threadId: Int, task: () => Unit): Unit = {
val executor = getOrCreateExecutor(processId, threadId)
val future = executor.submit(new Runnable {
override def run(): Unit = task()
})
@ -39,12 +39,12 @@ object ThreadPool {
/**
* Try to delete a particular thread.
*
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @return True if successful, false if thread does not exist.
*/
def tryDeleteThread(threadId: Int): Boolean = synchronized {
executors.remove(threadId) match {
def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized {
executors.remove((processId, threadId)) match {
case Some(executorService) =>
executorService.shutdown()
true
@ -62,11 +62,11 @@ object ThreadPool {
/**
* Get the executor if it exists, otherwise create a new one.
*
* @param id Integer id of the thread.
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @return The new or existing executor with the given id.
*/
private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized {
executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor)
}
}

Просмотреть файл

@ -34,7 +34,8 @@ class DotnetBackendHandlerTest {
val message = givenMessage(m => {
val serDe = new SerDe(null)
m.writeBoolean(true) // static method
m.writeInt(1) // threadId
serDe.writeInt(m, 1) // processId
serDe.writeInt(m, 1) // threadId
serDe.writeString(m, "DotnetHandler") // class name
serDe.writeString(m, "connectCallback") // command (method) name
m.writeInt(2) // number of arguments

Просмотреть файл

@ -41,6 +41,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
// First bit is isStatic
val isStatic = serDe.readBoolean(dis)
val processId = serDe.readInt(dis)
val threadId = serDe.readInt(dis)
val objId = serDe.readString(dis)
val methodName = serDe.readString(dis)
@ -67,9 +68,11 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
}
case "rmThread" =>
try {
assert(serDe.readObjectType(dis) == 'i')
val processId = serDe.readInt(dis)
assert(serDe.readObjectType(dis) == 'i')
val threadToDelete = serDe.readInt(dis)
val result = ThreadPool.tryDeleteThread(threadToDelete)
val result = ThreadPool.tryDeleteThread(processId, threadToDelete)
serDe.writeInt(dos, 0)
serDe.writeObject(dos, result.asInstanceOf[AnyRef])
} catch {
@ -99,7 +102,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
}
} else {
ThreadPool
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
.run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
}
bos.toByteArray

Просмотреть файл

@ -17,19 +17,19 @@ import scala.collection.mutable
object ThreadPool {
/**
* Map from threadId to corresponding executor.
* Map from (processId, threadId) to corresponding executor.
*/
private val executors: mutable.HashMap[Int, ExecutorService] =
new mutable.HashMap[Int, ExecutorService]()
private val executors: mutable.HashMap[(Int, Int), ExecutorService] =
new mutable.HashMap[(Int, Int), ExecutorService]()
/**
* Run some code on a particular thread.
*
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @param task Function to run on the thread.
*/
def run(threadId: Int, task: () => Unit): Unit = {
val executor = getOrCreateExecutor(threadId)
def run(processId: Int, threadId: Int, task: () => Unit): Unit = {
val executor = getOrCreateExecutor(processId, threadId)
val future = executor.submit(new Runnable {
override def run(): Unit = task()
})
@ -39,12 +39,12 @@ object ThreadPool {
/**
* Try to delete a particular thread.
*
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @return True if successful, false if thread does not exist.
*/
def tryDeleteThread(threadId: Int): Boolean = synchronized {
executors.remove(threadId) match {
def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized {
executors.remove((processId, threadId)) match {
case Some(executorService) =>
executorService.shutdown()
true
@ -62,11 +62,11 @@ object ThreadPool {
/**
* Get the executor if it exists, otherwise create a new one.
*
* @param id Integer id of the thread.
* @return The new or existing executor with the given id.
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @return The new or existing executor with the given (processId, threadId).
*/
private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized {
executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor)
}
}

Просмотреть файл

@ -35,7 +35,8 @@ class DotnetBackendHandlerTest {
val message = givenMessage(m => {
val serDe = new SerDe(null)
m.writeBoolean(true) // static method
m.writeInt(1) // threadId
serDe.writeInt(m, 1) // processId
serDe.writeInt(m, 1) // threadId
serDe.writeString(m, "DotnetHandler") // class name
serDe.writeString(m, "connectCallback") // command (method) name
m.writeInt(2) // number of arguments

Просмотреть файл

@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
// First bit is isStatic
val isStatic = serDe.readBoolean(dis)
val processId = serDe.readInt(dis)
val threadId = serDe.readInt(dis)
val objId = serDe.readString(dis)
val methodName = serDe.readString(dis)
@ -68,9 +69,11 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
}
case "rmThread" =>
try {
assert(serDe.readObjectType(dis) == 'i')
val processId = serDe.readInt(dis)
assert(serDe.readObjectType(dis) == 'i')
val threadToDelete = serDe.readInt(dis)
val result = ThreadPool.tryDeleteThread(threadToDelete)
val result = ThreadPool.tryDeleteThread(processId, threadToDelete)
serDe.writeInt(dos, 0)
serDe.writeObject(dos, result.asInstanceOf[AnyRef])
} catch {
@ -99,7 +102,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
}
} else {
ThreadPool
.run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
.run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
}
bos.toByteArray

Просмотреть файл

@ -17,19 +17,19 @@ import scala.collection.mutable
object ThreadPool {
/**
* Map from threadId to corresponding executor.
* Map from (processId, threadId) to corresponding executor.
*/
private val executors: mutable.HashMap[Int, ExecutorService] =
new mutable.HashMap[Int, ExecutorService]()
private val executors: mutable.HashMap[(Int, Int), ExecutorService] =
new mutable.HashMap[(Int, Int), ExecutorService]()
/**
* Run some code on a particular thread.
*
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @param task Function to run on the thread.
*/
def run(threadId: Int, task: () => Unit): Unit = {
val executor = getOrCreateExecutor(threadId)
def run(processId: Int, threadId: Int, task: () => Unit): Unit = {
val executor = getOrCreateExecutor(processId, threadId)
val future = executor.submit(new Runnable {
override def run(): Unit = task()
})
@ -39,12 +39,12 @@ object ThreadPool {
/**
* Try to delete a particular thread.
*
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @return True if successful, false if thread does not exist.
*/
def tryDeleteThread(threadId: Int): Boolean = synchronized {
executors.remove(threadId) match {
def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized {
executors.remove((processId, threadId)) match {
case Some(executorService) =>
executorService.shutdown()
true
@ -62,11 +62,11 @@ object ThreadPool {
/**
* Get the executor if it exists, otherwise create a new one.
*
* @param id Integer id of the thread.
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @return The new or existing executor with the given id.
*/
private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized {
executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor)
}
}

Просмотреть файл

@ -36,7 +36,8 @@ class DotnetBackendHandlerTest {
val message = givenMessage(m => {
val serDe = new SerDe(null)
m.writeBoolean(true) // static method
m.writeInt(1) // threadId
serDe.writeInt(m, 1) // processId
serDe.writeInt(m, 1) // threadId
serDe.writeString(m, "DotnetHandler") // class name
serDe.writeString(m, "connectCallback") // command (method) name
m.writeInt(2) // number of arguments