Adam Sitnik 2019-05-21 13:02:07 -07:00 committed by Terry Kim
Parent 20177c69cc
Commit 3c986e10bd
8 changed files with 54 additions and 141 deletions

View file

@@ -86,7 +86,8 @@ namespace Microsoft.Spark.UnitTest
var pickledBytes = pickler.dumps(new[] { row1, row2 });
// Note that the following will invoke RowConstructor.construct().
var unpickledData = PythonSerDe.GetUnpickledObjects(new MemoryStream(pickledBytes));
var unpickledData = PythonSerDe.GetUnpickledObjects(
new MemoryStream(pickledBytes), pickledBytes.Length);
Assert.Equal(2, unpickledData.Length);
Assert.Equal(row1, (unpickledData[0] as RowConstructor).GetRow());
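For readers who have not used Pyrolite from C#, here is a minimal round-trip sketch of the API this test exercises; the literal payload stands in for the Row objects above, and only Pickler/Unpickler are real Razorvine.Pickle types:

using Razorvine.Pickle;

// Pickle an object graph into Python pickle bytes (what the test does with two Rows).
var pickler = new Pickler();
byte[] pickledBytes = pickler.dumps(new object[] { "a", 1, 2.5 });

// Unpickle them back. PythonSerDe's new overload wraps this and additionally
// takes an explicit length, so it can read exactly one message from a stream.
var unpickler = new Unpickler();
var roundTripped = (object[])unpickler.loads(pickledBytes);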

View file

@@ -11,7 +11,6 @@ using System.Linq;
using Apache.Arrow;
using Apache.Arrow.Ipc;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.IO;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Utils;
using Razorvine.Pickle;
@@ -80,13 +79,12 @@ namespace Microsoft.Spark.Worker.Command
/// </summary>
internal class PicklingSqlCommandExecutor : SqlCommandExecutor
{
[ThreadStatic]
private static MemoryStream s_writeOutputStream;
[ThreadStatic]
private static MaxLengthReadStream s_slicedReadStream;
[ThreadStatic]
private static Pickler s_pickler;
[ThreadStatic]
private static byte[] s_outputBuffer;
protected override CommandExecutorStat ExecuteCore(
Stream inputStream,
Stream outputStream,
@@ -118,10 +116,6 @@ namespace Microsoft.Spark.Worker.Command
$"Invalid message length: {messageLength}");
}
MaxLengthReadStream readStream = s_slicedReadStream ??
(s_slicedReadStream = new MaxLengthReadStream());
readStream.Reset(inputStream, messageLength);
// Each row in inputRows is of type object[]. If a null is present in a row
// then the corresponding index column of the row object[] will be set to null.
// For example, (inputRows.Length == 2) and (inputRows[0][0] == null)
@@ -131,7 +125,7 @@ namespace Microsoft.Spark.Worker.Command
// |null|
// | 11|
// +----+
object[] inputRows = PythonSerDe.GetUnpickledObjects(readStream);
object[] inputRows = PythonSerDe.GetUnpickledObjects(inputStream, messageLength);
for (int i = 0; i < inputRows.Length; ++i)
{
@@ -139,7 +133,9 @@ namespace Microsoft.Spark.Worker.Command
outputRows.Add(commandRunner.Run(0, inputRows[i]));
}
WriteOutput(outputStream, outputRows);
// The initial (estimated) buffer size for pickling rows is set to the size of the pickled input rows
// because the number of rows is the same for both input and output.
WriteOutput(outputStream, outputRows, messageLength);
stat.NumEntriesProcessed += inputRows.Length;
outputRows.Clear();
}
@@ -153,22 +149,25 @@ namespace Microsoft.Spark.Worker.Command
/// </summary>
/// <param name="stream">Stream to write to</param>
/// <param name="rows">Rows to write to</param>
private void WriteOutput(Stream stream, IEnumerable<object> rows)
/// <param name="sizeHint">
/// Estimated max size of the serialized output.
/// If it's not big enough, the pickler grows the buffer.
/// </param>
private void WriteOutput(Stream stream, IEnumerable<object> rows, int sizeHint)
{
MemoryStream writeOutputStream = s_writeOutputStream ??
(s_writeOutputStream = new MemoryStream());
writeOutputStream.Position = 0;
if (s_outputBuffer == null)
s_outputBuffer = new byte[sizeHint];
Pickler pickler = s_pickler ?? (s_pickler = new Pickler(false));
pickler.dump(rows, writeOutputStream);
pickler.dumps(rows, ref s_outputBuffer, out int bytesWritten);
if (writeOutputStream.Position == 0)
if (bytesWritten <= 0)
{
throw new Exception("Message buffer cannot be null.");
throw new Exception($"Serialized output size must be positive. Was {bytesWritten}.");
}
SerDe.Write(stream, (int)writeOutputStream.Position);
SerDe.Write(stream, writeOutputStream.GetBuffer(), (int)writeOutputStream.Position);
SerDe.Write(stream, bytesWritten);
SerDe.Write(stream, s_outputBuffer, bytesWritten);
}
/// <summary>
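The write path mirrors the read path above: every batch travels as a 4-byte length prefix followed by the pickled payload. A hedged sketch of that framing, assuming the big-endian byte order the JVM side expects (SerDe.Write encapsulates this in the repo; WriteFramed is an illustrative name):

using System.Buffers.Binary;
using System.IO;

static void WriteFramed(Stream stream, byte[] payload, int count)
{
    // 4-byte big-endian length prefix, then exactly `count` payload bytes.
    var prefix = new byte[4];
    BinaryPrimitives.WriteInt32BigEndian(prefix, count);
    stream.Write(prefix, 0, prefix.Length);
    stream.Write(payload, 0, count);
}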

View file

@@ -1,87 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Spark.IO
{
/// <summary>
/// Provides a stream wrapper that allows reading only up to the specified number of bytes.
/// </summary>
internal sealed class MaxLengthReadStream : Stream
{
private Stream _stream;
private int _remainingAllowedLength;
public void Reset(Stream stream, int maxLength)
{
_stream = stream;
_remainingAllowedLength = maxLength;
}
public override int ReadByte()
{
int result = -1;
if ((_remainingAllowedLength > 0) && (result = _stream.ReadByte()) != -1)
{
--_remainingAllowedLength;
}
return result;
}
public override int Read(byte[] buffer, int offset, int count)
{
if (count > _remainingAllowedLength)
{
count = _remainingAllowedLength;
}
int read = _stream.Read(buffer, offset, count);
_remainingAllowedLength -= read;
return read;
}
public override async Task<int> ReadAsync(
byte[] buffer,
int offset,
int count,
CancellationToken cancellationToken)
{
if (count > _remainingAllowedLength)
{
count = _remainingAllowedLength;
}
int read = await _stream.ReadAsync(buffer, offset, count, cancellationToken)
.ConfigureAwait(false);
_remainingAllowedLength -= read;
return read;
}
// TODO: On .NET Core 2.1+ / .NET Standard 2.1+, also override ReadAsync that
// returns ValueTask<int>.
public override void Flush() => _stream.Flush();
public override Task FlushAsync(CancellationToken cancellationToken) =>
_stream.FlushAsync();
public override bool CanRead => true;
public override bool CanSeek => false;
public override bool CanWrite => false;
public override long Length => throw new NotSupportedException();
public override long Position
{
get => throw new NotSupportedException();
set => throw new NotSupportedException();
}
public override long Seek(long offset, SeekOrigin origin) =>
throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
public override void Write(byte[] buffer, int offset, int count) =>
throw new NotSupportedException();
}
}
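For context on what is being deleted: callers would Reset() one per-thread wrapper over the shared input stream, so a consumer such as the unpickler saw a stream that ended at the message boundary. A rough usage sketch matching the removed call sites (ReadOneBatch is an illustrative name):

using System.IO;
using Microsoft.Spark.IO;
using Microsoft.Spark.Utils;

static object[] ReadOneBatch(Stream inputStream, int messageLength)
{
    // Cap reads at messageLength so the unpickler cannot consume bytes
    // that belong to the next message on the same socket.
    var readStream = new MaxLengthReadStream();
    readStream.Reset(inputStream, messageLength);
    return PythonSerDe.GetUnpickledObjects(readStream);
}

After this commit, the exact message length is read into a pooled buffer up front, so the capping wrapper is no longer needed.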

View file

@@ -151,8 +151,7 @@ namespace Microsoft.Spark.Interop.Ipc
{
if (length < 0)
{
throw new ArgumentOutOfRangeException(
"length", length, "length can't be negative.");
throw new ArgumentOutOfRangeException(nameof(length), length, "length can't be negative.");
}
var buffer = new byte[length];
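The replacement line uses nameof rather than a string literal, so the exception stays correct if the parameter is ever renamed. A minimal sketch of the idiom (AllocateBuffer is an illustrative name):

using System;

static byte[] AllocateBuffer(int length)
{
    if (length < 0)
    {
        // nameof(length) is checked by the compiler, unlike the old "length" literal.
        throw new ArgumentOutOfRangeException(nameof(length), length, "length can't be negative.");
    }
    return new byte[length];
}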

View file

@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
@@ -16,7 +16,7 @@
<PackageReference Include="Apache.Arrow" Version="0.13.0" />
<PackageReference Include="Microsoft.CSharp" Version="4.5.0" />
<PackageReference Include="Newtonsoft.Json" Version="11.0.2" />
<PackageReference Include="Razorvine.Pyrolite" Version="4.25.0" />
<PackageReference Include="Razorvine.Pyrolite" Version="4.26.0" />
<PackageReference Include="System.Memory" Version="4.5.2" />
</ItemGroup>

View file

@@ -7,7 +7,6 @@ using System.Collections.Generic;
using System.IO;
using System.Runtime.Serialization.Formatters.Binary;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.IO;
using static Microsoft.Spark.Utils.CommandSerDe;
namespace Microsoft.Spark.RDD
@@ -66,18 +65,10 @@ namespace Microsoft.Spark.RDD
/// </summary>
private sealed class BinaryDeserializer : IDeserializer
{
[ThreadStatic]
private static MaxLengthReadStream s_slicedReadStream;
private readonly BinaryFormatter _formater = new BinaryFormatter();
public object Deserialize(Stream stream, int length)
{
MaxLengthReadStream readStream = s_slicedReadStream ??
(s_slicedReadStream = new MaxLengthReadStream());
readStream.Reset(stream, length);
return _formater.Deserialize(stream);
}
}
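The `s_field ?? (s_field = ...)` shape that this commit deletes here (and keeps for the pickler earlier) is the standard [ThreadStatic] lazy-initialization idiom: a field initializer would run only on the first thread to touch the type, so each thread must initialize on first use. A minimal sketch with illustrative names:

using System;

internal static class PerThreadBuffer
{
    // Do not use a field initializer with [ThreadStatic]: it runs once,
    // on one thread, leaving every other thread's copy null.
    [ThreadStatic]
    private static byte[] s_buffer;

    internal static byte[] Get(int minimumSize)
    {
        // Each thread lazily creates and then reuses its own buffer, so no
        // locking is needed and buffers are never shared across threads.
        byte[] buffer = s_buffer ?? (s_buffer = new byte[minimumSize]);
        return buffer;
    }
}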

View file

@@ -2,11 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.IO;
using Microsoft.Spark.Network;
using Microsoft.Spark.Utils;
@@ -17,9 +15,6 @@ namespace Microsoft.Spark.Sql
/// </summary>
internal sealed class RowCollector
{
[ThreadStatic]
private static MaxLengthReadStream s_slicedReadStream;
/// <summary>
/// Collects pickled row objects from the given socket.
/// </summary>
@@ -30,14 +25,9 @@
Stream inputStream = socket.InputStream;
int? length;
while (((length = SerDe.ReadBytesLength(inputStream)) != null) &&
(length.GetValueOrDefault() > 0))
while (((length = SerDe.ReadBytesLength(inputStream)) != null) && (length.GetValueOrDefault() > 0))
{
MaxLengthReadStream readStream = s_slicedReadStream ??
(s_slicedReadStream = new MaxLengthReadStream());
readStream.Reset(inputStream, length.GetValueOrDefault());
var unpickledObjects = PythonSerDe.GetUnpickledObjects(readStream);
object[] unpickledObjects = PythonSerDe.GetUnpickledObjects(inputStream, length.GetValueOrDefault());
foreach (object unpickled in unpickledObjects)
{

View file

@@ -2,8 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Diagnostics;
using System.IO;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;
using Razorvine.Pickle;
using Razorvine.Pickle.Objects;
@@ -26,17 +29,34 @@ namespace Microsoft.Spark.Utils
}
/// <summary>
/// Unpickles objects from byte[].
/// Unpickles objects from Stream.
/// </summary>
/// <param name="s">Pickled byte stream</param>
/// <param name="stream">Pickled byte stream</param>
/// <param name="messageLength">Size (in bytes) of the pickled input</param>
/// <returns>Unpickled objects</returns>
internal static object[] GetUnpickledObjects(Stream s)
internal static object[] GetUnpickledObjects(Stream stream, int messageLength)
{
// Not making any assumptions about the implementation and hence not a class member.
var unpickler = new Unpickler();
var unpickledItems = unpickler.load(s);
Debug.Assert(unpickledItems != null);
return (unpickledItems as object[]);
byte[] buffer = ArrayPool<byte>.Shared.Rent(messageLength);
try
{
if (!SerDe.TryReadBytes(stream, buffer, messageLength))
{
throw new ArgumentException("The stream is closed.");
}
// Not making any assumptions about the implementation and hence not a class member.
var unpickler = new Unpickler();
object unpickledItems = unpickler.loads(
new ReadOnlyMemory<byte>(buffer, 0, messageLength),
stackCapacity: 102); // Spark always sends batches of 100 rows; +2 is for markers
Debug.Assert(unpickledItems != null);
return (unpickledItems as object[]);
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}
}
}
}
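One subtlety worth spelling out in the new implementation: ArrayPool&lt;byte&gt;.Shared.Rent(n) only guarantees a minimum length, which is why the code above slices with new ReadOnlyMemory&lt;byte&gt;(buffer, 0, messageLength) rather than handing the whole rented array to the unpickler. A small self-contained sketch of the rent/use/return discipline:

using System;
using System.Buffers;

byte[] buffer = ArrayPool<byte>.Shared.Rent(1000);
try
{
    // Rent may hand back a larger (and dirty) array, e.g. rounded up to a
    // power of two, so callers must track the logical length themselves.
    Console.WriteLine($"requested 1000, got {buffer.Length}");
    var payload = new ReadOnlyMemory<byte>(buffer, 0, 1000);
    Console.WriteLine($"logical slice length: {payload.Length}");
}
finally
{
    // Always return, even on exceptions; otherwise the pool allocates anew.
    ArrayPool<byte>.Shared.Return(buffer);
}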