Mirror of https://github.com/dotnet/spark.git
Improve Pickling performance (#111)
This commit is contained in:
Parent: 20177c69cc
Commit: 3c986e10bd
@@ -86,7 +86,8 @@ namespace Microsoft.Spark.UnitTest
             var pickledBytes = pickler.dumps(new[] { row1, row2 });

             // Note that the following will invoke RowConstructor.construct().
-            var unpickledData = PythonSerDe.GetUnpickledObjects(new MemoryStream(pickledBytes));
+            var unpickledData = PythonSerDe.GetUnpickledObjects(
+                new MemoryStream(pickledBytes), pickledBytes.Length);

             Assert.Equal(2, unpickledData.Length);
             Assert.Equal(row1, (unpickledData[0] as RowConstructor).GetRow());
@@ -11,7 +11,6 @@ using System.Linq;
 using Apache.Arrow;
 using Apache.Arrow.Ipc;
 using Microsoft.Spark.Interop.Ipc;
-using Microsoft.Spark.IO;
 using Microsoft.Spark.Sql;
 using Microsoft.Spark.Utils;
 using Razorvine.Pickle;
@@ -80,13 +79,12 @@ namespace Microsoft.Spark.Worker.Command
     /// </summary>
     internal class PicklingSqlCommandExecutor : SqlCommandExecutor
     {
-        [ThreadStatic]
-        private static MemoryStream s_writeOutputStream;
-        [ThreadStatic]
-        private static MaxLengthReadStream s_slicedReadStream;
         [ThreadStatic]
         private static Pickler s_pickler;

+        [ThreadStatic]
+        private static byte[] s_outputBuffer;
+
         protected override CommandExecutorStat ExecuteCore(
             Stream inputStream,
             Stream outputStream,
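Note: the executor now caches its Pickler and output buffer per worker thread. A minimal sketch of the [ThreadStatic] lazy-initialization idiom the fields above rely on (the class name here is illustrative, not from the repo):

    using System;
    using Razorvine.Pickle;

    internal static class PicklerCache
    {
        // One instance per thread: no locks are needed, and repeated calls
        // on the same thread reuse the cached Pickler instead of allocating.
        [ThreadStatic]
        private static Pickler s_pickler;

        // A field initializer on a [ThreadStatic] field would run only on the
        // first thread that touches the type, so each thread must initialize
        // lazily, as the null-coalescing pattern in the diff does.
        public static Pickler Get() => s_pickler ?? (s_pickler = new Pickler(false));
    }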
@@ -118,10 +116,6 @@ namespace Microsoft.Spark.Worker.Command
                         $"Invalid message length: {messageLength}");
                 }

-                MaxLengthReadStream readStream = s_slicedReadStream ??
-                    (s_slicedReadStream = new MaxLengthReadStream());
-                readStream.Reset(inputStream, messageLength);
-
                 // Each row in inputRows is of type object[]. If a null is present in a row
                 // then the corresponding index column of the row object[] will be set to null.
                 // For example, (inputRows.Length == 2) and (inputRows[0][0] == null)
@@ -131,7 +125,7 @@ namespace Microsoft.Spark.Worker.Command
                 // |null|
                 // |  11|
                 // +----+
-                object[] inputRows = PythonSerDe.GetUnpickledObjects(readStream);
+                object[] inputRows = PythonSerDe.GetUnpickledObjects(inputStream, messageLength);

                 for (int i = 0; i < inputRows.Length; ++i)
                 {
@@ -139,7 +133,9 @@ namespace Microsoft.Spark.Worker.Command
                     outputRows.Add(commandRunner.Run(0, inputRows[i]));
                 }

-                WriteOutput(outputStream, outputRows);
+                // The initial (estimated) buffer size for pickling rows is set to the size of
+                // the input pickled rows because the number of rows is the same for input and output.
+                WriteOutput(outputStream, outputRows, messageLength);
                 stat.NumEntriesProcessed += inputRows.Length;
                 outputRows.Clear();
             }
@@ -153,22 +149,25 @@ namespace Microsoft.Spark.Worker.Command
         /// </summary>
         /// <param name="stream">Stream to write to</param>
         /// <param name="rows">Rows to write to</param>
-        private void WriteOutput(Stream stream, IEnumerable<object> rows)
+        /// <param name="sizeHint">
+        /// Estimated max size of the serialized output.
+        /// If it's not big enough, the pickler increases the buffer.
+        /// </param>
+        private void WriteOutput(Stream stream, IEnumerable<object> rows, int sizeHint)
         {
-            MemoryStream writeOutputStream = s_writeOutputStream ??
-                (s_writeOutputStream = new MemoryStream());
-            writeOutputStream.Position = 0;
+            if (s_outputBuffer == null)
+                s_outputBuffer = new byte[sizeHint];

             Pickler pickler = s_pickler ?? (s_pickler = new Pickler(false));
-            pickler.dump(rows, writeOutputStream);
+            pickler.dumps(rows, ref s_outputBuffer, out int bytesWritten);

-            if (writeOutputStream.Position == 0)
+            if (bytesWritten <= 0)
             {
-                throw new Exception("Message buffer cannot be null.");
+                throw new Exception($"Serialized output size must be positive. Was {bytesWritten}.");
             }

-            SerDe.Write(stream, (int)writeOutputStream.Position);
-            SerDe.Write(stream, writeOutputStream.GetBuffer(), (int)writeOutputStream.Position);
+            SerDe.Write(stream, bytesWritten);
+            SerDe.Write(stream, s_outputBuffer, bytesWritten);
         }

         /// <summary>
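Note: the rewritten WriteOutput pickles directly into a reusable byte[] via the Pickler.dumps(object, ref byte[], out int) overload (available as of Pyrolite 4.26, bumped later in this commit) instead of going through a MemoryStream and copying the buffer out. A hedged sketch of that flow; the class, method, and framing below are illustrative, and the real code delegates the length-prefixed write to SerDe.Write:

    using System;
    using System.IO;
    using Razorvine.Pickle;

    internal static class PickleWriter
    {
        [ThreadStatic]
        private static byte[] s_buffer;

        // Pickles 'rows' into a thread-cached buffer, then writes an
        // "<int32 length><payload>" frame to the output stream.
        public static void WriteFrame(Stream output, object rows, int sizeHint)
        {
            if (s_buffer == null)
                s_buffer = new byte[sizeHint];

            var pickler = new Pickler(false);
            // dumps grows s_buffer (hence 'ref') when the pickled output
            // does not fit the current buffer, so sizeHint is only a guess.
            pickler.dumps(rows, ref s_buffer, out int bytesWritten);

            byte[] prefix = BitConverter.GetBytes(bytesWritten);
            if (BitConverter.IsLittleEndian)
            {
                Array.Reverse(prefix); // the JVM side reads the length big-endian
            }
            output.Write(prefix, 0, prefix.Length);
            output.Write(s_buffer, 0, bytesWritten);
        }
    }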
@@ -1,87 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.IO;
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace Microsoft.Spark.IO
-{
-    /// <summary>
-    /// Provides a stream wrapper that allows reading only up to the specified number of bytes.
-    /// </summary>
-    internal sealed class MaxLengthReadStream : Stream
-    {
-        private Stream _stream;
-        private int _remainingAllowedLength;
-
-        public void Reset(Stream stream, int maxLength)
-        {
-            _stream = stream;
-            _remainingAllowedLength = maxLength;
-        }
-
-        public override int ReadByte()
-        {
-            int result = -1;
-            if ((_remainingAllowedLength > 0) && (result = _stream.ReadByte()) != -1)
-            {
-                --_remainingAllowedLength;
-            }
-            return result;
-        }
-
-        public override int Read(byte[] buffer, int offset, int count)
-        {
-            if (count > _remainingAllowedLength)
-            {
-                count = _remainingAllowedLength;
-            }
-
-            int read = _stream.Read(buffer, offset, count);
-            _remainingAllowedLength -= read;
-            return read;
-        }
-
-        public override async Task<int> ReadAsync(
-            byte[] buffer,
-            int offset,
-            int count,
-            CancellationToken cancellationToken)
-        {
-            if (count > _remainingAllowedLength)
-            {
-                count = _remainingAllowedLength;
-            }
-
-            int read = await _stream.ReadAsync(buffer, offset, count, cancellationToken)
-                .ConfigureAwait(false);
-
-            _remainingAllowedLength -= read;
-            return read;
-        }
-
-        // TODO: On .NET Core 2.1+ / .NET Standard 2.1+, also override ReadAsync that
-        // returns ValueTask<int>.
-
-        public override void Flush() => _stream.Flush();
-        public override Task FlushAsync(CancellationToken cancellationToken) =>
-            _stream.FlushAsync();
-        public override bool CanRead => true;
-        public override bool CanSeek => false;
-        public override bool CanWrite => false;
-        public override long Length => throw new NotSupportedException();
-        public override long Position
-        {
-            get => throw new NotSupportedException();
-            set => throw new NotSupportedException();
-        }
-        public override long Seek(long offset, SeekOrigin origin) =>
-            throw new NotSupportedException();
-        public override void SetLength(long value) => throw new NotSupportedException();
-        public override void Write(byte[] buffer, int offset, int count) =>
-            throw new NotSupportedException();
-    }
-}
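Note: this wrapper existed so an unpickler reading from a raw Stream could not consume bytes past the current message. It becomes unnecessary once GetUnpickledObjects copies exactly messageLength bytes into a buffer up front. One detail that copy has to get right: a single Stream.Read call may legally return fewer bytes than requested, so a helper in the spirit of the SerDe.TryReadBytes call in the new code (whose actual implementation is not shown in this diff) must loop, for example:

    using System.IO;

    internal static class StreamUtils
    {
        // Fills buffer[0..length) from the stream. Returns false if the
        // stream ends before 'length' bytes arrive.
        public static bool TryReadExactly(Stream stream, byte[] buffer, int length)
        {
            int total = 0;
            while (total < length)
            {
                int read = stream.Read(buffer, total, length - total);
                if (read == 0)
                {
                    return false; // end of stream reached early
                }
                total += read;
            }
            return true;
        }
    }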
@@ -151,8 +151,7 @@ namespace Microsoft.Spark.Interop.Ipc
         {
             if (length < 0)
             {
-                throw new ArgumentOutOfRangeException(
-                    "length", length, "length can't be negative.");
+                throw new ArgumentOutOfRangeException(nameof(length), length, "length can't be negative.");
             }

             var buffer = new byte[length];
@@ -1,4 +1,4 @@
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">

   <PropertyGroup>
     <TargetFramework>netstandard2.0</TargetFramework>
@@ -16,7 +16,7 @@
     <PackageReference Include="Apache.Arrow" Version="0.13.0" />
     <PackageReference Include="Microsoft.CSharp" Version="4.5.0" />
     <PackageReference Include="Newtonsoft.Json" Version="11.0.2" />
-    <PackageReference Include="Razorvine.Pyrolite" Version="4.25.0" />
+    <PackageReference Include="Razorvine.Pyrolite" Version="4.26.0" />
     <PackageReference Include="System.Memory" Version="4.5.2" />
   </ItemGroup>
@@ -7,7 +7,6 @@ using System.Collections.Generic;
 using System.IO;
 using System.Runtime.Serialization.Formatters.Binary;
 using Microsoft.Spark.Interop.Ipc;
-using Microsoft.Spark.IO;
 using static Microsoft.Spark.Utils.CommandSerDe;

 namespace Microsoft.Spark.RDD
@@ -66,18 +65,10 @@ namespace Microsoft.Spark.RDD
         /// </summary>
         private sealed class BinaryDeserializer : IDeserializer
         {
-            [ThreadStatic]
-            private static MaxLengthReadStream s_slicedReadStream;
-
             private readonly BinaryFormatter _formater = new BinaryFormatter();

             public object Deserialize(Stream stream, int length)
             {
-                MaxLengthReadStream readStream = s_slicedReadStream ??
-                    (s_slicedReadStream = new MaxLengthReadStream());
-
-                readStream.Reset(stream, length);
-
                 return _formater.Deserialize(stream);
             }
         }
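Note: this hunk is effectively dead-code removal. The old Deserialize built and Reset a MaxLengthReadStream but then deserialized from the raw stream anyway, so the length guard never took effect. It was also unneeded: BinaryFormatter consumes exactly one serialized object graph and leaves the stream positioned right after it, as this small self-contained sketch illustrates:

    using System;
    using System.IO;
    using System.Runtime.Serialization.Formatters.Binary;

    internal static class BinaryFormatterDemo
    {
        public static void Main()
        {
            var formatter = new BinaryFormatter();
            var stream = new MemoryStream();
            formatter.Serialize(stream, "first");
            formatter.Serialize(stream, "second");

            stream.Position = 0;
            // Each Deserialize call reads exactly one graph and stops.
            Console.WriteLine(formatter.Deserialize(stream)); // "first"
            Console.WriteLine(formatter.Deserialize(stream)); // "second"
        }
    }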
@@ -2,11 +2,9 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

-using System;
 using System.Collections.Generic;
 using System.IO;
 using Microsoft.Spark.Interop.Ipc;
-using Microsoft.Spark.IO;
 using Microsoft.Spark.Network;
 using Microsoft.Spark.Utils;
@@ -17,9 +15,6 @@ namespace Microsoft.Spark.Sql
     /// </summary>
     internal sealed class RowCollector
     {
-        [ThreadStatic]
-        private static MaxLengthReadStream s_slicedReadStream;
-
         /// <summary>
         /// Collects pickled row objects from the given socket.
         /// </summary>
@@ -30,14 +25,9 @@ namespace Microsoft.Spark.Sql
             Stream inputStream = socket.InputStream;

             int? length;
-            while (((length = SerDe.ReadBytesLength(inputStream)) != null) &&
-                (length.GetValueOrDefault() > 0))
+            while (((length = SerDe.ReadBytesLength(inputStream)) != null) && (length.GetValueOrDefault() > 0))
             {
-                MaxLengthReadStream readStream = s_slicedReadStream ??
-                    (s_slicedReadStream = new MaxLengthReadStream());
-
-                readStream.Reset(inputStream, length.GetValueOrDefault());
-                var unpickledObjects = PythonSerDe.GetUnpickledObjects(readStream);
+                object[] unpickledObjects = PythonSerDe.GetUnpickledObjects(inputStream, length.GetValueOrDefault());

                 foreach (object unpickled in unpickledObjects)
                 {
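Note: the collector loop reads length-prefixed pickle frames until the socket closes (ReadBytesLength returns null) or a non-positive length marker arrives, and now hands each frame's exact length to PythonSerDe.GetUnpickledObjects. A hedged sketch of that framing protocol; ReadFrameLength is an illustrative stand-in for SerDe.ReadBytesLength, whose implementation is not part of this diff:

    using System.Collections.Generic;
    using System.IO;

    internal static class FrameReader
    {
        // Yields the length of each "<int32 length><payload>" frame until
        // end-of-stream or an end marker (length <= 0). The caller is then
        // expected to read exactly that many payload bytes.
        public static IEnumerable<int> FrameLengths(Stream input)
        {
            int? length;
            while ((length = ReadFrameLength(input)) != null && length.Value > 0)
            {
                yield return length.Value;
            }
        }

        private static int? ReadFrameLength(Stream input)
        {
            var prefix = new byte[4];
            int total = 0;
            while (total < 4)
            {
                int read = input.Read(prefix, total, 4 - total);
                if (read == 0)
                {
                    return null; // socket closed before a full prefix arrived
                }
                total += read;
            }
            // The JVM side writes the length big-endian.
            return (prefix[0] << 24) | (prefix[1] << 16) | (prefix[2] << 8) | prefix[3];
        }
    }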
@@ -2,8 +2,11 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

+using System;
+using System.Buffers;
 using System.Diagnostics;
 using System.IO;
+using Microsoft.Spark.Interop.Ipc;
 using Microsoft.Spark.Sql;
 using Razorvine.Pickle;
 using Razorvine.Pickle.Objects;
@@ -26,17 +29,34 @@ namespace Microsoft.Spark.Utils
         }

         /// <summary>
-        /// Unpickles objects from byte[].
+        /// Unpickles objects from Stream.
         /// </summary>
-        /// <param name="s">Pickled byte stream</param>
+        /// <param name="stream">Pickled byte stream</param>
+        /// <param name="messageLength">Size (in bytes) of the pickled input</param>
         /// <returns>Unpickled objects</returns>
-        internal static object[] GetUnpickledObjects(Stream s)
+        internal static object[] GetUnpickledObjects(Stream stream, int messageLength)
         {
-            // Not making any assumptions about the implementation and hence not a class member.
-            var unpickler = new Unpickler();
-            var unpickledItems = unpickler.load(s);
-            Debug.Assert(unpickledItems != null);
-            return (unpickledItems as object[]);
+            byte[] buffer = ArrayPool<byte>.Shared.Rent(messageLength);
+
+            try
+            {
+                if (!SerDe.TryReadBytes(stream, buffer, messageLength))
+                {
+                    throw new ArgumentException("The stream is closed.");
+                }
+
+                // Not making any assumptions about the implementation and hence not a class member.
+                var unpickler = new Unpickler();
+                object unpickledItems = unpickler.loads(
+                    new ReadOnlyMemory<byte>(buffer, 0, messageLength),
+                    stackCapacity: 102); // Spark always sends batches of 100 rows; +2 is for markers.
+                Debug.Assert(unpickledItems != null);
+                return (unpickledItems as object[]);
+            }
+            finally
+            {
+                ArrayPool<byte>.Shared.Return(buffer);
+            }
         }
     }
 }
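Note: GetUnpickledObjects now rents its scratch buffer from the shared ArrayPool instead of allocating one per message, and returns it in a finally block so the array goes back to the pool even if unpickling throws. Rent(n) may hand back an array longer than n, which is why the code slices with new ReadOnlyMemory<byte>(buffer, 0, messageLength) rather than passing the whole array; the stackCapacity: 102 hint just presizes the unpickler's working stack for Spark's usual 100-row batches. A minimal self-contained sketch of the rent/use/return discipline (the class and method here are illustrative):

    using System;
    using System.Buffers;
    using System.IO;

    internal static class PooledRead
    {
        // Reads 'length' bytes into a pooled buffer, processes them, and
        // guarantees the buffer is returned to the pool.
        public static int SumBytes(Stream stream, int length)
        {
            // Rent may return an array longer than 'length'.
            byte[] buffer = ArrayPool<byte>.Shared.Rent(length);
            try
            {
                int total = 0;
                while (total < length)
                {
                    int read = stream.Read(buffer, total, length - total);
                    if (read == 0)
                    {
                        throw new EndOfStreamException();
                    }
                    total += read;
                }

                int sum = 0;
                for (int i = 0; i < length; i++) // only the first 'length' bytes are valid
                {
                    sum += buffer[i];
                }
                return sum;
            }
            finally
            {
                ArrayPool<byte>.Shared.Return(buffer);
            }
        }
    }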