diff --git a/src/csharp/Microsoft.Spark.UnitTest/Sql/RowTests.cs b/src/csharp/Microsoft.Spark.UnitTest/Sql/RowTests.cs
index c86bf9d9..e057761b 100644
--- a/src/csharp/Microsoft.Spark.UnitTest/Sql/RowTests.cs
+++ b/src/csharp/Microsoft.Spark.UnitTest/Sql/RowTests.cs
@@ -86,7 +86,8 @@ namespace Microsoft.Spark.UnitTest
             var pickledBytes = pickler.dumps(new[] { row1, row2 });
 
             // Note that the following will invoke RowConstructor.construct().
-            var unpickledData = PythonSerDe.GetUnpickledObjects(new MemoryStream(pickledBytes));
+            var unpickledData = PythonSerDe.GetUnpickledObjects(
+                new MemoryStream(pickledBytes), pickledBytes.Length);
 
             Assert.Equal(2, unpickledData.Length);
             Assert.Equal(row1, (unpickledData[0] as RowConstructor).GetRow());
diff --git a/src/csharp/Microsoft.Spark.Worker/Command/SqlCommandExecutor.cs b/src/csharp/Microsoft.Spark.Worker/Command/SqlCommandExecutor.cs
index 1526c9ef..b232b83a 100644
--- a/src/csharp/Microsoft.Spark.Worker/Command/SqlCommandExecutor.cs
+++ b/src/csharp/Microsoft.Spark.Worker/Command/SqlCommandExecutor.cs
@@ -11,7 +11,6 @@ using System.Linq;
 using Apache.Arrow;
 using Apache.Arrow.Ipc;
 using Microsoft.Spark.Interop.Ipc;
-using Microsoft.Spark.IO;
 using Microsoft.Spark.Sql;
 using Microsoft.Spark.Utils;
 using Razorvine.Pickle;
@@ -80,13 +79,12 @@ namespace Microsoft.Spark.Worker.Command
     /// </summary>
     internal class PicklingSqlCommandExecutor : SqlCommandExecutor
     {
-        [ThreadStatic]
-        private static MemoryStream s_writeOutputStream;
-        [ThreadStatic]
-        private static MaxLengthReadStream s_slicedReadStream;
         [ThreadStatic]
         private static Pickler s_pickler;
 
+        [ThreadStatic]
+        private static byte[] s_outputBuffer;
+
         protected override CommandExecutorStat ExecuteCore(
             Stream inputStream,
             Stream outputStream,
@@ -118,10 +116,6 @@ namespace Microsoft.Spark.Worker.Command
                         $"Invalid message length: {messageLength}");
                 }
 
-                MaxLengthReadStream readStream = s_slicedReadStream ??
-                    (s_slicedReadStream = new MaxLengthReadStream());
-                readStream.Reset(inputStream, messageLength);
-
                 // Each row in inputRows is of type object[]. If a null is present in a row
                 // then the corresponding index column of the row object[] will be set to null.
                 // For example, (inputRows.Length == 2) and (inputRows[0][0] == null)
@@ -131,7 +125,7 @@ namespace Microsoft.Spark.Worker.Command
                 // |null|
                 // |  11|
                 // +----+
-                object[] inputRows = PythonSerDe.GetUnpickledObjects(readStream);
+                object[] inputRows = PythonSerDe.GetUnpickledObjects(inputStream, messageLength);
 
                 for (int i = 0; i < inputRows.Length; ++i)
                 {
@@ -139,7 +133,9 @@ namespace Microsoft.Spark.Worker.Command
                     outputRows.Add(commandRunner.Run(0, inputRows[i]));
                 }
 
-                WriteOutput(outputStream, outputRows);
+                // The initial (estimated) buffer size for pickling rows is set to the size of the
+                // input pickled rows because the number of rows is the same for input and output.
+                WriteOutput(outputStream, outputRows, messageLength);
                 stat.NumEntriesProcessed += inputRows.Length;
                 outputRows.Clear();
             }
@@ -153,22 +149,25 @@ namespace Microsoft.Spark.Worker.Command
         /// </summary>
         /// <param name="stream">Stream to write to</param>
         /// <param name="rows">Rows to write to</param>
-        private void WriteOutput(Stream stream, IEnumerable<object> rows)
+        /// <param name="sizeHint">
+        /// Estimated max size of the serialized output.
+        /// If it's not big enough, the pickler increases the buffer.
+        /// </param>
+        private void WriteOutput(Stream stream, IEnumerable<object> rows, int sizeHint)
         {
-            MemoryStream writeOutputStream = s_writeOutputStream ??
-                (s_writeOutputStream = new MemoryStream());
-            writeOutputStream.Position = 0;
+            if (s_outputBuffer == null)
+                s_outputBuffer = new byte[sizeHint];
 
             Pickler pickler = s_pickler ?? (s_pickler = new Pickler(false));
-            pickler.dump(rows, writeOutputStream);
+            pickler.dumps(rows, ref s_outputBuffer, out int bytesWritten);
 
-            if (writeOutputStream.Position == 0)
+            if (bytesWritten <= 0)
             {
-                throw new Exception("Message buffer cannot be null.");
+                throw new Exception($"Serialized output size must be positive. Was {bytesWritten}.");
             }
 
-            SerDe.Write(stream, (int)writeOutputStream.Position);
-            SerDe.Write(stream, writeOutputStream.GetBuffer(), (int)writeOutputStream.Position);
+            SerDe.Write(stream, bytesWritten);
+            SerDe.Write(stream, s_outputBuffer, bytesWritten);
         }
 
         /// <summary>
diff --git a/src/csharp/Microsoft.Spark/IO/MaxLengthReadStream.cs b/src/csharp/Microsoft.Spark/IO/MaxLengthReadStream.cs
deleted file mode 100644
index a0de064a..00000000
--- a/src/csharp/Microsoft.Spark/IO/MaxLengthReadStream.cs
+++ /dev/null
@@ -1,87 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.IO;
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace Microsoft.Spark.IO
-{
-    /// <summary>
-    /// Provides a stream wrapper that allows reading only up to the specified number of bytes.
-    /// </summary>
-    internal sealed class MaxLengthReadStream : Stream
-    {
-        private Stream _stream;
-        private int _remainingAllowedLength;
-
-        public void Reset(Stream stream, int maxLength)
-        {
-            _stream = stream;
-            _remainingAllowedLength = maxLength;
-        }
-
-        public override int ReadByte()
-        {
-            int result = -1;
-            if ((_remainingAllowedLength > 0) && (result = _stream.ReadByte()) != -1)
-            {
-                --_remainingAllowedLength;
-            }
-            return result;
-        }
-
-        public override int Read(byte[] buffer, int offset, int count)
-        {
-            if (count > _remainingAllowedLength)
-            {
-                count = _remainingAllowedLength;
-            }
-
-            int read = _stream.Read(buffer, offset, count);
-            _remainingAllowedLength -= read;
-            return read;
-        }
-
-        public override async Task<int> ReadAsync(
-            byte[] buffer,
-            int offset,
-            int count,
-            CancellationToken cancellationToken)
-        {
-            if (count > _remainingAllowedLength)
-            {
-                count = _remainingAllowedLength;
-            }
-
-            int read = await _stream.ReadAsync(buffer, offset, count, cancellationToken)
-                .ConfigureAwait(false);
-
-            _remainingAllowedLength -= read;
-            return read;
-        }
-
-        // TODO: On .NET Core 2.1+ / .NET Standard 2.1+, also override ReadAsync that
-        // returns ValueTask<int>.
-
-        public override void Flush() => _stream.Flush();
-        public override Task FlushAsync(CancellationToken cancellationToken) =>
-            _stream.FlushAsync();
-        public override bool CanRead => true;
-        public override bool CanSeek => false;
-        public override bool CanWrite => false;
-        public override long Length => throw new NotSupportedException();
-        public override long Position
-        {
-            get => throw new NotSupportedException();
-            set => throw new NotSupportedException();
-        }
-        public override long Seek(long offset, SeekOrigin origin) =>
-            throw new NotSupportedException();
-        public override void SetLength(long value) => throw new NotSupportedException();
-        public override void Write(byte[] buffer, int offset, int count) =>
-            throw new NotSupportedException();
-    }
-}
diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs
index 268d83c8..d0f6937a 100644
--- a/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs
+++ b/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs
@@ -151,8 +151,7 @@ namespace Microsoft.Spark.Interop.Ipc
         {
             if (length < 0)
             {
-                throw new ArgumentOutOfRangeException(
-                    "length", length, "length can't be negative.");
+                throw new ArgumentOutOfRangeException(nameof(length), length, "length can't be negative.");
             }
 
             var buffer = new byte[length];
diff --git a/src/csharp/Microsoft.Spark/Microsoft.Spark.csproj b/src/csharp/Microsoft.Spark/Microsoft.Spark.csproj
index c4774729..77f8e1b4 100644
--- a/src/csharp/Microsoft.Spark/Microsoft.Spark.csproj
+++ b/src/csharp/Microsoft.Spark/Microsoft.Spark.csproj
@@ -1,4 +1,4 @@
-
+
 
   <PropertyGroup>
     <TargetFramework>netstandard2.0</TargetFramework>
@@ -16,7 +16,7 @@
-
+
diff --git a/src/csharp/Microsoft.Spark/RDD/Collector.cs b/src/csharp/Microsoft.Spark/RDD/Collector.cs
index c695c729..489fe3b7 100644
--- a/src/csharp/Microsoft.Spark/RDD/Collector.cs
+++ b/src/csharp/Microsoft.Spark/RDD/Collector.cs
@@ -7,7 +7,6 @@ using System.Collections.Generic;
 using System.IO;
 using System.Runtime.Serialization.Formatters.Binary;
 using Microsoft.Spark.Interop.Ipc;
-using Microsoft.Spark.IO;
 using static Microsoft.Spark.Utils.CommandSerDe;
 
 namespace Microsoft.Spark.RDD
@@ -66,18 +65,10 @@ namespace Microsoft.Spark.RDD
         /// </summary>
         private sealed class BinaryDeserializer : IDeserializer
         {
-            [ThreadStatic]
-            private static MaxLengthReadStream s_slicedReadStream;
-
             private readonly BinaryFormatter _formater = new BinaryFormatter();
 
             public object Deserialize(Stream stream, int length)
             {
-                MaxLengthReadStream readStream = s_slicedReadStream ??
-                    (s_slicedReadStream = new MaxLengthReadStream());
-
-                readStream.Reset(stream, length);
-
                 return _formater.Deserialize(stream);
             }
         }
diff --git a/src/csharp/Microsoft.Spark/Sql/RowCollector.cs b/src/csharp/Microsoft.Spark/Sql/RowCollector.cs
index c4d3f404..d265d37f 100644
--- a/src/csharp/Microsoft.Spark/Sql/RowCollector.cs
+++ b/src/csharp/Microsoft.Spark/Sql/RowCollector.cs
@@ -2,11 +2,9 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using System;
 using System.Collections.Generic;
 using System.IO;
 using Microsoft.Spark.Interop.Ipc;
-using Microsoft.Spark.IO;
 using Microsoft.Spark.Network;
 using Microsoft.Spark.Utils;
 
@@ -17,9 +15,6 @@ namespace Microsoft.Spark.Sql
     /// </summary>
     internal sealed class RowCollector
    {
-        [ThreadStatic]
-        private static MaxLengthReadStream s_slicedReadStream;
-
        /// <summary>
        /// Collects pickled row objects from the given socket.
        /// </summary>
@@ -30,14 +25,9 @@ namespace Microsoft.Spark.Sql
             Stream inputStream = socket.InputStream;
 
             int? length;
-            while (((length = SerDe.ReadBytesLength(inputStream)) != null) &&
-                (length.GetValueOrDefault() > 0))
+            while (((length = SerDe.ReadBytesLength(inputStream)) != null) && (length.GetValueOrDefault() > 0))
             {
-                MaxLengthReadStream readStream = s_slicedReadStream ??
-                    (s_slicedReadStream = new MaxLengthReadStream());
-
-                readStream.Reset(inputStream, length.GetValueOrDefault());
-                var unpickledObjects = PythonSerDe.GetUnpickledObjects(readStream);
+                object[] unpickledObjects = PythonSerDe.GetUnpickledObjects(inputStream, length.GetValueOrDefault());
 
                 foreach (object unpickled in unpickledObjects)
                 {
diff --git a/src/csharp/Microsoft.Spark/Utils/PythonSerDe.cs b/src/csharp/Microsoft.Spark/Utils/PythonSerDe.cs
index f546a820..c5e62e7f 100644
--- a/src/csharp/Microsoft.Spark/Utils/PythonSerDe.cs
+++ b/src/csharp/Microsoft.Spark/Utils/PythonSerDe.cs
@@ -2,8 +2,11 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System;
+using System.Buffers;
 using System.Diagnostics;
 using System.IO;
+using Microsoft.Spark.Interop.Ipc;
 using Microsoft.Spark.Sql;
 using Razorvine.Pickle;
 using Razorvine.Pickle.Objects;
@@ -26,17 +29,34 @@ namespace Microsoft.Spark.Utils
         }
 
         /// <summary>
-        /// Unpickles objects from byte[].
+        /// Unpickles objects from Stream.
         /// </summary>
-        /// <param name="s">Pickled byte stream</param>
+        /// <param name="stream">Pickled byte stream</param>
+        /// <param name="messageLength">Size (in bytes) of the pickled input</param>
         /// <returns>Unpicked objects</returns>
-        internal static object[] GetUnpickledObjects(Stream s)
+        internal static object[] GetUnpickledObjects(Stream stream, int messageLength)
         {
-            // Not making any assumptions about the implementation and hence not a class member.
-            var unpickler = new Unpickler();
-            var unpickledItems = unpickler.load(s);
-            Debug.Assert(unpickledItems != null);
-            return (unpickledItems as object[]);
+            byte[] buffer = ArrayPool<byte>.Shared.Rent(messageLength);
+
+            try
+            {
+                if (!SerDe.TryReadBytes(stream, buffer, messageLength))
+                {
+                    throw new ArgumentException("The stream is closed.");
+                }
+
+                // Not making any assumptions about the implementation and hence not a class member.
+                var unpickler = new Unpickler();
+                object unpickledItems = unpickler.loads(
+                    new ReadOnlyMemory<byte>(buffer, 0, messageLength),
+                    stackCapacity: 102); // spark always sends batches of 100 rows, +2 is for markers
+                Debug.Assert(unpickledItems != null);
+                return (unpickledItems as object[]);
+            }
+            finally
+            {
+                ArrayPool<byte>.Shared.Return(buffer);
+            }
         }
     }
 }
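
Below the patch, a minimal standalone sketch of the buffer-reuse pattern the new WriteOutput and GetUnpickledObjects code paths follow. Only Pickler.dumps(object, ref byte[], out int), Unpickler.loads(ReadOnlyMemory<byte>, stackCapacity) and ArrayPool<byte> come from the change above; the PicklingBufferSketch class, the BitConverter length prefix, and the read loop are illustrative assumptions, not the worker's SerDe framing.

using System;
using System.Buffers;
using System.IO;
using Razorvine.Pickle;

internal static class PicklingBufferSketch
{
    [ThreadStatic]
    private static byte[] s_outputBuffer;

    // Pickles a batch into a reusable per-thread buffer and writes a length-prefixed frame.
    public static void WriteBatch(Stream output, object[] rows, int sizeHint)
    {
        if (s_outputBuffer == null)
        {
            s_outputBuffer = new byte[sizeHint];
        }

        // dumps() grows s_outputBuffer if the pickled batch does not fit in sizeHint bytes.
        var pickler = new Pickler(false);
        pickler.dumps(rows, ref s_outputBuffer, out int bytesWritten);

        output.Write(BitConverter.GetBytes(bytesWritten), 0, sizeof(int));
        output.Write(s_outputBuffer, 0, bytesWritten);
    }

    // Reads one frame back and unpickles it from a pooled buffer instead of a MemoryStream.
    public static object ReadBatch(Stream input)
    {
        var lengthBytes = new byte[sizeof(int)];
        FillBuffer(input, lengthBytes, lengthBytes.Length);
        int messageLength = BitConverter.ToInt32(lengthBytes, 0);

        byte[] buffer = ArrayPool<byte>.Shared.Rent(messageLength);
        try
        {
            FillBuffer(input, buffer, messageLength);
            return new Unpickler().loads(
                new ReadOnlyMemory<byte>(buffer, 0, messageLength),
                stackCapacity: 102);
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(buffer);
        }
    }

    // Reads exactly 'count' bytes or throws if the stream ends early.
    private static void FillBuffer(Stream stream, byte[] buffer, int count)
    {
        int total = 0;
        while (total < count)
        {
            int read = stream.Read(buffer, total, count - total);
            if (read == 0)
            {
                throw new EndOfStreamException();
            }
            total += read;
        }
    }
}

The per-thread buffer mirrors s_outputBuffer in PicklingSqlCommandExecutor: it is allocated once per worker thread and reallocated only when a batch outgrows it, instead of filling and copying a MemoryStream for every batch.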