Add read and write buffer for CSharpWorker (#521)

* Add read and write buffer for CSharpWorker

* Fix test coverage decrease issue

* Adjust read buffer

* Fix failed unit test

* use --conf option to pass buffer size

* Fix flaky unit test

* Address review comments

* Fix option name renaming issue
This commit is contained in:
Tao Wang 2016-08-12 14:21:31 +08:00 коммит произвёл GitHub
Родитель 3afa007e58
Коммит 4944319a55
7 изменённых файлов: 101 добавлений и 54 удалений

Просмотреть файл

@ -21,6 +21,8 @@ namespace Microsoft.Spark.CSharp.Configuration
public const string CSharpWorkerPathSettingKey = "CSharpWorkerPath";
public const string CSharpBackendPortNumberSettingKey = "CSharpBackendPortNumber";
public const string CSharpSocketTypeEnvName = "spark.mobius.CSharp.socketType";
public const string CSharpWorkerReadBufferSizeEnvName = "spark.mobius.CSharpWorker.readBufferSize";
public const string CSharpWorkerWriteBufferSizeEnvName = "spark.mobius.CSharpWorker.writeBufferSize";
public const string SPARKCLR_HOME = "SPARKCLR_HOME";
public const string SPARK_MASTER = "spark.master";
public const string CSHARPBACKEND_PORT = "CSHARPBACKEND_PORT";
@ -208,7 +210,7 @@ namespace Microsoft.Spark.CSharp.Configuration
logger.LogInfo("Worker path read from setting {0} in app config", CSharpWorkerPathSettingKey);
return workerPathConfig.Value;
}
var path = GetSparkCLRArtifactsPath("bin", ProcFileName);
logger.LogInfo("Worker path {0} constructed using {1} environment variable", path, SPARKCLR_HOME);

Просмотреть файл

@ -2,11 +2,16 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Configuration;
using System.IO;
using System.Runtime.CompilerServices;
using System.Threading;
using Microsoft.Spark.CSharp.Configuration;
using Microsoft.Spark.CSharp.Interop.Ipc;
using Microsoft.Spark.CSharp.Network;
using Microsoft.Spark.CSharp.Services;
[assembly: InternalsVisibleTo("WorkerTest")]
namespace Microsoft.Spark.CSharp
{
/// <summary>
@ -43,21 +48,32 @@ namespace Microsoft.Spark.CSharp
this.socketReuse = socketReuse;
}
private Stream GetStream(Stream stream, int bufferSize)
{
return bufferSize > 0 ? new BufferedStream(stream, bufferSize) : stream;
}
public void Run()
{
Logger.LogInfo(string.Format("TaskRunner [{0}] is running ...", trId));
int readBufferSize = int.Parse(Environment.GetEnvironmentVariable(ConfigurationService.CSharpWorkerReadBufferSizeEnvName) ?? "8192");
int writeBufferSize = int.Parse(Environment.GetEnvironmentVariable(ConfigurationService.CSharpWorkerWriteBufferSizeEnvName) ?? "8192");
Logger.LogInfo(string.Format("TaskRunner [{0}] is running ..., read buffer size: {1}, write buffer size: {2}", trId, readBufferSize, writeBufferSize));
try
{
while (!stop)
{
using (var networkStream = socket.GetStream())
using (var inputStream = GetStream(networkStream, readBufferSize))
using (var outputStream = GetStream(networkStream, writeBufferSize))
{
byte[] bytes = SerDe.ReadBytes(networkStream, sizeof(int));
byte[] bytes = SerDe.ReadBytes(inputStream, sizeof(int));
if (bytes != null)
{
int splitIndex = SerDe.ToInt(bytes);
bool readComplete = Worker.ProcessStream(networkStream, splitIndex);
bool readComplete = Worker.ProcessStream(inputStream, outputStream, splitIndex);
outputStream.Flush();
if (!readComplete) // if the socket is not read through completely, then it can't be reused
{
stop = true;
@ -72,7 +88,7 @@ namespace Microsoft.Spark.CSharp
// Use SerDe.ReadBytes() to detect java side has closed socket properly
// ReadBytes() will block until the socket is closed
Logger.LogInfo("waiting JVM side to close socket...");
SerDe.ReadBytes(networkStream);
SerDe.ReadBytes(inputStream);
Logger.LogInfo("JVM side has closed socket");
}
}

Просмотреть файл

@ -155,7 +155,7 @@ namespace Microsoft.Spark.CSharp
return socket;
}
public static bool ProcessStream(Stream networkStream, int splitIndex)
public static bool ProcessStream(Stream inputStream, Stream outputStream, int splitIndex)
{
logger.LogInfo(string.Format("Start of stream processing, splitIndex: {0}", splitIndex));
bool readComplete = true; // Whether all input data from the socket is read though completely
@ -164,7 +164,7 @@ namespace Microsoft.Spark.CSharp
{
DateTime bootTime = DateTime.UtcNow;
string ver = SerDe.ReadString(networkStream);
string ver = SerDe.ReadString(inputStream);
logger.LogDebug("version: " + ver);
//// initialize global state
@ -172,30 +172,30 @@ namespace Microsoft.Spark.CSharp
//shuffle.DiskBytesSpilled = 0
// fetch name of workdir
string sparkFilesDir = SerDe.ReadString(networkStream);
string sparkFilesDir = SerDe.ReadString(inputStream);
logger.LogDebug("spark_files_dir: " + sparkFilesDir);
//SparkFiles._root_directory = sparkFilesDir
//SparkFiles._is_running_on_worker = True
ProcessIncludesItems(networkStream);
ProcessIncludesItems(inputStream);
ProcessBroadcastVariables(networkStream);
ProcessBroadcastVariables(inputStream);
Accumulator.threadLocalAccumulatorRegistry = new Dictionary<int, Accumulator>();
var formatter = ProcessCommand(networkStream, splitIndex, bootTime);
var formatter = ProcessCommand(inputStream, outputStream, splitIndex, bootTime);
// Mark the beginning of the accumulators section of the output
SerDe.Write(networkStream, (int)SpecialLengths.END_OF_DATA_SECTION);
SerDe.Write(outputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
WriteAccumulatorValues(networkStream, formatter);
WriteAccumulatorValues(outputStream, formatter);
int end = SerDe.ReadInt(networkStream);
int end = SerDe.ReadInt(inputStream);
// check end of stream
if (end == (int)SpecialLengths.END_OF_STREAM)
{
SerDe.Write(networkStream, (int)SpecialLengths.END_OF_STREAM);
SerDe.Write(outputStream, (int)SpecialLengths.END_OF_STREAM);
logger.LogDebug("END_OF_STREAM: " + (int)SpecialLengths.END_OF_STREAM);
}
else
@ -203,11 +203,11 @@ namespace Microsoft.Spark.CSharp
// This may happen when the input data is not read completely, e.g., when take() operation is performed
logger.LogWarn(string.Format("**** unexpected read: {0}, not all data is read", end));
// write a different value to tell JVM to not reuse this worker
SerDe.Write(networkStream, (int)SpecialLengths.END_OF_DATA_SECTION);
SerDe.Write(outputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
readComplete = false;
}
networkStream.Flush();
outputStream.Flush();
// log bytes read and write
logger.LogDebug(string.Format("total read bytes: {0}", SerDe.totalReadNum));
@ -222,7 +222,7 @@ namespace Microsoft.Spark.CSharp
try
{
logger.LogError("Trying to write error to stream");
SerDe.Write(networkStream, e.ToString());
SerDe.Write(outputStream, e.ToString());
}
catch (IOException)
{
@ -281,9 +281,9 @@ namespace Microsoft.Spark.CSharp
}
}
private static IFormatter ProcessCommand(Stream networkStream, int splitIndex, DateTime bootTime)
private static IFormatter ProcessCommand(Stream inputStream, Stream outputStream, int splitIndex, DateTime bootTime)
{
int lengthOfCommandByteArray = SerDe.ReadInt(networkStream);
int lengthOfCommandByteArray = SerDe.ReadInt(inputStream);
logger.LogDebug("command length: " + lengthOfCommandByteArray);
IFormatter formatter = new BinaryFormatter();
@ -293,18 +293,18 @@ namespace Microsoft.Spark.CSharp
var commandProcessWatch = new Stopwatch();
commandProcessWatch.Start();
int stageId = ReadDiagnosticsInfo(networkStream);
int stageId = ReadDiagnosticsInfo(inputStream);
string deserializerMode = SerDe.ReadString(networkStream);
string deserializerMode = SerDe.ReadString(inputStream);
logger.LogDebug("Deserializer mode: " + deserializerMode);
string serializerMode = SerDe.ReadString(networkStream);
string serializerMode = SerDe.ReadString(inputStream);
logger.LogDebug("Serializer mode: " + serializerMode);
string runMode = SerDe.ReadString(networkStream);
string runMode = SerDe.ReadString(inputStream);
if ("R".Equals(runMode, StringComparison.InvariantCultureIgnoreCase))
{
var compilationDumpDir = SerDe.ReadString(networkStream);
var compilationDumpDir = SerDe.ReadString(inputStream);
if (Directory.Exists(compilationDumpDir))
{
assemblyHandler.LoadAssemblies(Directory.GetFiles(compilationDumpDir, "ReplCompilation.*",
@ -316,7 +316,7 @@ namespace Microsoft.Spark.CSharp
}
}
byte[] command = SerDe.ReadBytes(networkStream);
byte[] command = SerDe.ReadBytes(inputStream);
logger.LogDebug("command bytes read: " + command.Length);
var stream = new MemoryStream(command);
@ -333,7 +333,7 @@ namespace Microsoft.Spark.CSharp
int count = 0;
int nullMessageCount = 0;
var funcProcessWatch = Stopwatch.StartNew();
foreach (var message in func(splitIndex, GetIterator(networkStream, deserializerMode)))
foreach (var message in func(splitIndex, GetIterator(inputStream, deserializerMode)))
{
funcProcessWatch.Stop();
@ -343,7 +343,7 @@ namespace Microsoft.Spark.CSharp
continue;
}
WriteOutput(networkStream, serializerMode, message, formatter);
WriteOutput(outputStream, serializerMode, message, formatter);
count++;
funcProcessWatch.Start();
}
@ -356,7 +356,7 @@ namespace Microsoft.Spark.CSharp
//else:
// process()
WriteDiagnosticsInfo(networkStream, bootTime, initTime);
WriteDiagnosticsInfo(outputStream, bootTime, initTime);
commandProcessWatch.Stop();

Просмотреть файл

@ -41,6 +41,7 @@
<HintPath>..\..\packages\Razorvine.Serpent.1.12.0.0\lib\net40\Razorvine.Serpent.dll</HintPath>
</Reference>
<Reference Include="System" />
<Reference Include="System.Configuration" />
<Reference Include="System.Core" />
<Reference Include="Microsoft.CSharp" />
</ItemGroup>

Просмотреть файл

@ -11,6 +11,8 @@ using System.Net;
using System.Reflection;
using System.Runtime.Serialization.Formatters.Binary;
using System.Text;
using System.Threading;
using Microsoft.Spark.CSharp;
using Microsoft.Spark.CSharp.Configuration;
using Microsoft.Spark.CSharp.Core;
using Microsoft.Spark.CSharp.Sql;
@ -338,21 +340,34 @@ namespace WorkerTest
[Test]
public void TestWorkerIncompleteBytes()
{
Process worker;
var CSharpRDD_SocketServer = CreateServer(out worker);
var originalReadBufferSize = Environment.GetEnvironmentVariable(ConfigurationService.CSharpWorkerReadBufferSizeEnvName);
using (var serverSocket = CSharpRDD_SocketServer.Accept())
using (var s = serverSocket.GetStream())
try
{
WritePayloadHeaderToWorker(s);
if (SocketFactory.SocketWrapperType.Equals(SocketWrapperType.Rio))
{
Environment.SetEnvironmentVariable(ConfigurationService.CSharpWorkerReadBufferSizeEnvName, "0");
}
SerDe.Write(s, command.Length);
s.Write(command, 0, command.Length / 2);
Process worker;
var CSharpRDD_SocketServer = CreateServer(out worker);
using (var serverSocket = CSharpRDD_SocketServer.Accept())
using (var s = serverSocket.GetStream())
{
WritePayloadHeaderToWorker(s);
SerDe.Write(s, command.Length);
s.Write(command, 0, command.Length/2);
}
AssertWorker(worker, 0, "System.ArgumentException: Incomplete bytes read: ");
CSharpRDD_SocketServer.Close();
}
finally
{
Environment.SetEnvironmentVariable(ConfigurationService.CSharpWorkerReadBufferSizeEnvName, originalReadBufferSize);
}
AssertWorker(worker, 0, "System.ArgumentException: Incomplete bytes read: ");
CSharpRDD_SocketServer.Close();
}
/// <summary>
@ -374,14 +389,11 @@ namespace WorkerTest
for (int i = 0; i < 100; i++)
SerDe.Write(s, i.ToString());
s.Flush();
int count = 0;
foreach (var bytes in ReadWorker(s, 100))
{
Assert.AreEqual(count++.ToString(), Encoding.UTF8.GetString(bytes));
}
Assert.AreEqual(100, count);
// Note: as send buffer is enabled by default, and CSharpWorker only flushes output after receives all data (receive END_OF_DATA_SECTION flag),
// so in current test we can't ensure expected number of result will be received at this point, validation for returned data is not enabled to avoid flaky test.
}
AssertWorker(worker, 0, "System.NullReferenceException: Object reference not set to an instance of an object.");

Просмотреть файл

@ -9,5 +9,7 @@
|Streaming (Kafka) |spark.mobius.streaming.kafka.numReceivers |Set the number of threads used to materialize the RDD created by applying the user read function to the original KafkaRDD. |
|Streaming (UpdateStateByKey) |spark.mobius.streaming.parallelJobs |Sets 0-based max number of parallel jobs for UpdateStateByKey so that next N batches can start its tasks on time even if previous batch not completed yet. default: 0, recommended: 1. It's a special version of spark.streaming.concurrentJobs which does not observe UpdateStateByKey's state ordering properly |
|Worker |spark.mobius.CSharp.socketType |Sets the socket type that will be used in IPC when transferring data between JVM and CLR. Valid values for this setting are: <ul><li>**Normal**: default .Net Socket implementation will be used. This is the default socket type in Mobius.</li><li>**Rio**: Windows RIO socket will be used. This option can be used **only in Windows OS**</li><li>**Saea**: .Net Socket implementation with SocketAsyncEventArgs class will be used</li></ul> Riosocket and SaeaSocket has better performance when dealing with larger data transmission than traditional .Net Socket. Significant performance improvement has been observed by using RIO/SAEA socket types when the average size of each row in the data processed in Mobius is over 4KB. You can profile your application for different socket types and decide which one offers best performance for your data. Depending on the OS, either Rio (Windows-only) or Saea (Windows/Linux) socket types can be used for data with larger row sizes|
|Worker |spark.mobius.CSharpWorker.readBufferSize |Sets the buffer size in bytes for data read operation from JVM to CSharpWorker. By default the value is 8KB if not explicitly specified. A typical scenario which can benefits a lot from this option is that CSharpWorker reads large amount of small records from JVM process. Please adjust the number based on your scenario. |
|Worker |spark.mobius.CSharpWorker.writeBufferSize |Sets the buffer size in bytes for data write operation from CSharpWorker to JVM. The default value is 8KB. Usually better performance can be gained if specify this option with a proper value when CSharpWorker needs to sends lots of small records (multiple bytes size) back to JVM process. Please adjust the buffer size based on your scenario. |

Просмотреть файл

@ -8,19 +8,17 @@ package org.apache.spark.api.csharp
import java.io._
import java.nio.ByteBuffer
import java.nio.channels.{FileChannel, FileLock, OverlappingFileLockException}
import java.nio.file.Files
import java.nio.file.Paths
import java.nio.file.attribute.PosixFilePermission._
import java.util.{List => JList, Map => JMap}
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.spark._
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, PythonRunner}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.util.csharp.{Utils => CSharpUtils}
import org.apache.spark.api.python.PythonRunner
import scala.collection.JavaConverters._
/**
* RDD used for forking an external C# process and pipe in & out the data
@ -74,9 +72,21 @@ class CSharpRDD(
if (!CSharpRDD.csharpWorkerSocketType.isEmpty) {
envVars.put("spark.mobius.CSharp.socketType", CSharpRDD.csharpWorkerSocketType)
logInfo(s"CSharpWorker socket type: $CSharpRDD.csharpWorkerSocketType")
logInfo(s"CSharpWorker socket type: ${CSharpRDD.csharpWorkerSocketType}")
}
if (CSharpRDD.csharpWorkerReadBufferSize >= 0) {
envVars.put("spark.mobius.CSharpWorker.readBufferSize",
CSharpRDD.csharpWorkerReadBufferSize.toString)
}
if (CSharpRDD.csharpWorkerWriteBufferSize >= 0) {
envVars.put("spark.mobius.CSharpWorker.writeBufferSize",
CSharpRDD.csharpWorkerWriteBufferSize.toString)
}
logInfo("Env vars: " + envVars.asScala.mkString(", "))
val runner = new PythonRunner(
command, envVars, cSharpIncludes, cSharpWorker.getAbsolutePath, unUsedVersionIdentifier,
broadcastVars, accumulator, bufferSize, reuse_worker)
@ -207,6 +217,10 @@ object CSharpRDD {
var maxCSharpWorkerProcessCount: Int = SparkEnv.get.conf.getInt("spark.mobius.CSharpWorker.maxProcessCount", -1)
// socket type for CSharpWorker
var csharpWorkerSocketType: String = SparkEnv.get.conf.get("spark.mobius.CSharp.socketType", "")
// Buffer size in bytes for operation of reading data from JVM process
var csharpWorkerReadBufferSize: Int = SparkEnv.get.conf.getInt("spark.mobius.CSharpWorker.readBufferSize", -1)
// Buffer size in bytes for operation of writing data to JVM process
var csharpWorkerWriteBufferSize: Int = SparkEnv.get.conf.getInt("spark.mobius.CSharpWorker.writeBufferSize", -1)
def createRDDFromArray(
sc: SparkContext,
@ -214,7 +228,7 @@ object CSharpRDD {
numSlices: Int): JavaRDD[Array[Byte]] = {
JavaRDD.fromRDD(sc.parallelize(arr, numSlices))
}
// this method is called when saveAsTextFile is called on RDD<string>
// calling saveAsTextFile() on CSharpRDDs result in bytes written to text file
// - this method converts bytes to string before writing to file