diff --git a/NuGet.config b/NuGet.config index 69bab2da..ce00c056 100644 --- a/NuGet.config +++ b/NuGet.config @@ -4,6 +4,8 @@ + + diff --git a/ToProjectReferences.ps1 b/ToProjectReferences.ps1 new file mode 100644 index 00000000..4273aff9 --- /dev/null +++ b/ToProjectReferences.ps1 @@ -0,0 +1,45 @@ +param($references) +$ErrorActionPreference = "Stop"; + +function ToProjectName($file) +{ + return $file.Directory.Name; +} + +$projectreferences = ls (Join-Path $references *.csproj) -rec; + +$localprojects = ls -rec *.csproj; + +foreach ($project in $localprojects) +{ + Write-Host "Processing $project"; + + [Reflection.Assembly]::LoadWithPartialName("System.Xml.Linq") | Out-Null; + + $changed = $false + $xDoc = [System.Xml.Linq.XDocument]::Load($project, [System.Xml.Linq.LoadOptions]::PreserveWhitespace); + $endpoints = $xDoc.Descendants("PackageReference") | %{ + $packageName = $_.Attribute("Include").Value; + $replacementProject = $projectreferences | ? { + return (ToProjectName($_)) -eq $packageName + }; + + if ($replacementProject) + { + $changed = $true + Write-Host " Replacing $packageName with $($project.FullName)"; + $_.Name = "ProjectReference"; + $_.Attribute("Include").Value = $replacementProject.FullName; + } + }; + if ($changed) + { + $settings = New-Object System.Xml.XmlWriterSettings + $settings.OmitXmlDeclaration = $true; + $writer = [System.Xml.XmlWriter]::Create($project, $settings) + + $xDoc.Save($writer); + $writer.Dispose(); + } + +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/AdaptedPipeline.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/AdaptedPipeline.cs index 28c8fd66..f3ca1812 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/AdaptedPipeline.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/AdaptedPipeline.cs @@ -3,37 +3,40 @@ using System; using System.IO; +using System.IO.Pipelines; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; +using MemoryPool = Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure.MemoryPool; namespace Microsoft.AspNetCore.Server.Kestrel.Adapter.Internal { public class AdaptedPipeline : IDisposable { + private const int MinAllocBufferSize = 2048; + private readonly Stream _filteredStream; public AdaptedPipeline( string connectionId, Stream filteredStream, + IPipe pipe, MemoryPool memory, - IKestrelTrace logger, - IThreadPool threadPool, - IBufferSizeControl bufferSizeControl) + IKestrelTrace logger) { - SocketInput = new SocketInput(memory, threadPool, bufferSizeControl); - SocketOutput = new StreamSocketOutput(connectionId, filteredStream, memory, logger); + Input = pipe; + Output = new StreamSocketOutput(connectionId, filteredStream, memory, logger); _filteredStream = filteredStream; } - public SocketInput SocketInput { get; } + public IPipe Input { get; } - public ISocketOutput SocketOutput { get; } + public ISocketOutput Output { get; } public void Dispose() { - SocketInput.Dispose(); + Input.Writer.Complete(); } public async Task ReadInputAsync() @@ -42,21 +45,29 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Adapter.Internal do { - var block = SocketInput.IncomingStart(); + var block = Input.Writer.Alloc(MinAllocBufferSize); try { - var count = block.Data.Offset + block.Data.Count - block.End; - bytesRead = await _filteredStream.ReadAsync(block.Array, block.End, count); + var array = block.Memory.GetArray(); + try + { + bytesRead = await _filteredStream.ReadAsync(array.Array, array.Offset, array.Count); + block.Advance(bytesRead); + } + finally + { + await block.FlushAsync(); + } } catch (Exception ex) { - SocketInput.IncomingComplete(0, ex); + Input.Writer.Complete(ex); throw; } - - SocketInput.IncomingComplete(bytesRead, error: null); } while (bytesRead != 0); + + Input.Writer.Complete(); } } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/RawStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/RawStream.cs index 0824eeb5..3ec258b1 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/RawStream.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/RawStream.cs @@ -3,6 +3,7 @@ using System; using System.IO; +using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; @@ -12,12 +13,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Adapter.Internal { public class RawStream : Stream { - private readonly SocketInput _input; + private readonly IPipeReader _input; private readonly ISocketOutput _output; - private Task _cachedTask = TaskCache.DefaultCompletedTask; - - public RawStream(SocketInput input, ISocketOutput output) + public RawStream(IPipeReader input, ISocketOutput output) { _input = input; _output = output; @@ -68,23 +67,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Adapter.Internal public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - var task = ReadAsync(new ArraySegment(buffer, offset, count)); - - if (task.IsCompletedSuccessfully) - { - if (_cachedTask.Result != task.Result) - { - // Needs .AsTask to match Stream's Async method return types - _cachedTask = task.AsTask(); - } - } - else - { - // Needs .AsTask to match Stream's Async method return types - _cachedTask = task.AsTask(); - } - - return _cachedTask; + return ReadAsync(new ArraySegment(buffer, offset, count)); } public override void Write(byte[] buffer, int offset, int count) @@ -125,10 +108,31 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Adapter.Internal return _output.FlushAsync(cancellationToken); } - - private ValueTask ReadAsync(ArraySegment buffer) + private async Task ReadAsync(ArraySegment buffer) { - return _input.ReadAsync(buffer.Array, buffer.Offset, buffer.Count); + while (true) + { + var result = await _input.ReadAsync(); + var readableBuffer = result.Buffer; + try + { + if (!readableBuffer.IsEmpty) + { + var count = Math.Min(readableBuffer.Length, buffer.Count); + readableBuffer = readableBuffer.Slice(0, count); + readableBuffer.CopyTo(buffer); + return count; + } + else if (result.IsCompleted || result.IsCancelled) + { + return 0; + } + } + finally + { + _input.Advance(readableBuffer.End, readableBuffer.End); + } + } } #if NET451 diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/BufferSizeControl.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/BufferSizeControl.cs deleted file mode 100644 index 364a1a2c..00000000 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/BufferSizeControl.cs +++ /dev/null @@ -1,77 +0,0 @@ -using System.Diagnostics; - -namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http -{ - public class BufferSizeControl : IBufferSizeControl - { - private readonly long _maxSize; - private readonly IConnectionControl _connectionControl; - - private readonly object _lock = new object(); - - private long _size; - private bool _connectionPaused; - - public BufferSizeControl(long maxSize, IConnectionControl connectionControl) - { - _maxSize = maxSize; - _connectionControl = connectionControl; - } - - private long Size - { - get - { - return _size; - } - set - { - // Caller should ensure that bytes are never consumed before the producer has called Add() - Debug.Assert(value >= 0); - _size = value; - } - } - - public void Add(int count) - { - Debug.Assert(count >= 0); - - if (count == 0) - { - // No-op and avoid taking lock to reduce contention - return; - } - - lock (_lock) - { - Size += count; - if (!_connectionPaused && Size >= _maxSize) - { - _connectionPaused = true; - _connectionControl.Pause(); - } - } - } - - public void Subtract(int count) - { - Debug.Assert(count >= 0); - - if (count == 0) - { - // No-op and avoid taking lock to reduce contention - return; - } - - lock (_lock) - { - Size -= count; - if (_connectionPaused && Size < _maxSize) - { - _connectionPaused = false; - _connectionControl.Resume(); - } - } - } - } -} diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs index 12b97868..f2211cdc 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; using System.IO; +using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Adapter; @@ -18,6 +19,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { public class Connection : ConnectionContext, IConnectionControl { + private const int MinAllocBufferSize = 2048; + // Base32 encoding - in ascii sort order for easy text based sorting private static readonly string _encode32Chars = "0123456789ABCDEFGHIJKLMNOPQRSTUV"; @@ -40,11 +43,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http private Task _readInputTask; private TaskCompletionSource _socketClosedTcs = new TaskCompletionSource(); - private BufferSizeControl _bufferSizeControl; private long _lastTimestamp; private long _timeoutTimestamp = long.MaxValue; private TimeoutAction _timeoutAction; + private WritableBuffer? _currentWritableBuffer; public Connection(ListenerContext context, UvStreamHandle socket) : base(context) { @@ -55,12 +58,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http ConnectionId = GenerateConnectionId(Interlocked.Increment(ref _lastConnectionId)); - if (ServerOptions.Limits.MaxRequestBufferSize.HasValue) - { - _bufferSizeControl = new BufferSizeControl(ServerOptions.Limits.MaxRequestBufferSize.Value, this); - } - - Input = new SocketInput(Thread.Memory, ThreadPool, _bufferSizeControl); + Input = Thread.PipelineFactory.Create(ListenerContext.LibuvPipeOptions); Output = new SocketOutput(Thread, _socket, this, ConnectionId, Log, ThreadPool); var tcpHandle = _socket as UvTcpHandle; @@ -92,6 +90,21 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http // Start socket prior to applying the ConnectionAdapter _socket.ReadStart(_allocCallback, _readCallback, this); + // Dispatch to a thread pool so if the first read completes synchronously + // we won't be on IO thread + try + { + ThreadPool.UnsafeRun(state => ((Connection)state).StartFrame(), this); + } + catch (Exception e) + { + Log.LogError(0, e, "Connection.StartFrame"); + throw; + } + } + + private void StartFrame() + { if (_connectionAdapters.Count == 0) { _frame.Start(); @@ -107,7 +120,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public Task StopAsync() { _frame.StopAsync(); - _frame.Input.CompleteAwaiting(); + _frame.Input.Reader.CancelPendingRead(); return _socketClosedTcs.Task; } @@ -138,11 +151,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http var connection2 = (Connection)state2; connection2._filteredStream.Dispose(); connection2._adaptedPipeline.Dispose(); + Input.Reader.Complete(); }, connection); } }, this); - Input.Dispose(); + Input.Writer.Complete(new TaskCanceledException("The request was aborted")); _socketClosedTcs.TrySetResult(null); } @@ -168,7 +182,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { try { - var rawStream = new RawStream(Input, Output); + var rawStream = new RawStream(Input.Reader, Output); var adapterContext = new ConnectionAdapterContext(rawStream); var adaptedConnections = new IAdaptedConnection[_connectionAdapters.Count]; @@ -182,11 +196,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http if (adapterContext.ConnectionStream != rawStream) { _filteredStream = adapterContext.ConnectionStream; - _adaptedPipeline = new AdaptedPipeline(ConnectionId, adapterContext.ConnectionStream, - Thread.Memory, Log, ThreadPool, _bufferSizeControl); + _adaptedPipeline = new AdaptedPipeline( + ConnectionId, + adapterContext.ConnectionStream, + Thread.PipelineFactory.Create(ListenerContext.AdaptedPipeOptions), + Thread.Memory, + Log); - _frame.Input = _adaptedPipeline.SocketInput; - _frame.Output = _adaptedPipeline.SocketOutput; + _frame.Input = _adaptedPipeline.Input; + _frame.Output = _adaptedPipeline.Output; // Don't attempt to read input if connection has already closed. // This can happen if a client opens a connection and immediately closes it. @@ -201,6 +219,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http catch (Exception ex) { Log.LogError(0, ex, $"Uncaught exception from the {nameof(IConnectionAdapter.OnConnectionAsync)} method of an {nameof(IConnectionAdapter)}."); + Input.Reader.Complete(); ConnectionControl.End(ProduceEndType.SocketDisconnect); } } @@ -210,13 +229,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http return ((Connection)state).OnAlloc(handle, suggestedSize); } - private Libuv.uv_buf_t OnAlloc(UvStreamHandle handle, int suggestedSize) + private unsafe Libuv.uv_buf_t OnAlloc(UvStreamHandle handle, int suggestedSize) { - var result = Input.IncomingStart(); + Debug.Assert(_currentWritableBuffer == null); + var currentWritableBuffer = Input.Writer.Alloc(MinAllocBufferSize); + _currentWritableBuffer = currentWritableBuffer; + void* dataPtr; + var tryGetPointer = currentWritableBuffer.Memory.TryGetPointer(out dataPtr); + Debug.Assert(tryGetPointer); return handle.Libuv.buf_init( - result.DataArrayPtr + result.End, - result.Data.Offset + result.Data.Count - result.End); + (IntPtr)dataPtr, + currentWritableBuffer.Memory.Length); } private static void ReadCallback(UvStreamHandle handle, int status, object state) @@ -224,19 +248,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http ((Connection)state).OnRead(handle, status); } - private void OnRead(UvStreamHandle handle, int status) + private async void OnRead(UvStreamHandle handle, int status) { - if (status == 0) - { - // A zero status does not indicate an error or connection end. It indicates - // there is no data to be read right now. - // See the note at http://docs.libuv.org/en/v1.x/stream.html#c.uv_read_cb. - // We need to clean up whatever was allocated by OnAlloc. - Input.IncomingDeferred(); - return; - } - - var normalRead = status > 0; + var normalRead = status >= 0; var normalDone = status == Constants.EOF; var errorDone = !(normalDone || normalRead); var readCount = normalRead ? status : 0; @@ -256,6 +270,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } IOException error = null; + WritableBufferAwaitable? flushTask = null; if (errorDone) { Exception uvError; @@ -272,13 +287,31 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } error = new IOException(uvError.Message, uvError); + _currentWritableBuffer?.Commit(); + } + else + { + Debug.Assert(_currentWritableBuffer != null); + + var currentWritableBuffer = _currentWritableBuffer.Value; + currentWritableBuffer.Advance(readCount); + flushTask = currentWritableBuffer.FlushAsync(); } - Input.IncomingComplete(readCount, error); + _currentWritableBuffer = null; + if (flushTask?.IsCompleted == false) + { + OnPausePosted(); + if (await flushTask.Value) + { + OnResumePosted(); + } + } if (!normalRead) { - AbortAsync(error); + Input.Writer.Complete(error); + var ignore = AbortAsync(error); } } @@ -289,7 +322,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http // Even though this method is called on the event loop already, // post anyway so the ReadStop() call doesn't get reordered // relative to the ReadStart() call made in Resume(). - Thread.Post(state => ((Connection)state).OnPausePosted(), this); + Thread.Post(state => state.OnPausePosted(), this); } void IConnectionControl.Resume() @@ -297,7 +330,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http Log.ConnectionResume(ConnectionId); // This is called from the consuming thread. - Thread.Post(state => ((Connection)state).OnResumePosted(), this); + Thread.Post(state => state.OnResumePosted(), this); } private void OnPausePosted() @@ -316,14 +349,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { try { - _socket.ReadStart(_allocCallback, _readCallback, this); + _socket.ReadStart(_allocCallback, _readCallback, this); } catch (UvException) { // ReadStart() can throw a UvException in some cases (e.g. socket is no longer connected). // This should be treated the same as OnRead() seeing a "normalDone" condition. Log.ConnectionReadFin(ConnectionId); - Input.IncomingComplete(0, null); + Input.Writer.Complete(); } } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ConnectionContext.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ConnectionContext.cs index decc5dbf..267f34dd 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ConnectionContext.cs @@ -3,6 +3,7 @@ using System; using System.Net; +using System.IO.Pipelines; using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http @@ -20,7 +21,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public ListenerContext ListenerContext { get; set; } - public SocketInput Input { get; set; } + public IPipe Input { get; set; } public ISocketOutput Output { get; set; } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ConnectionManager.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ConnectionManager.cs index 884324f0..dbea0eef 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ConnectionManager.cs @@ -34,7 +34,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { var tcs = new TaskCompletionSource(); - _thread.Post(state => action((ConnectionManager)state, tcs), this); + _thread.Post(state => action(state, tcs), this); return await Task.WhenAny(tcs.Task, Task.Delay(timeout)).ConfigureAwait(false) == tcs.Task; } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs index e72549c7..7a3d04e8 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs @@ -5,10 +5,14 @@ using System; using System.Collections.Generic; using System.Diagnostics; using System.IO; +using System.IO.Pipelines; +using System.IO.Pipelines.Text.Primitives; using System.Linq; using System.Net; using System.Runtime.CompilerServices; using System.Text; +using System.Text.Encodings.Web.Utf8; +using System.Text.Utf8; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; @@ -95,7 +99,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } public ConnectionContext ConnectionContext { get; } - public SocketInput Input { get; set; } + public IPipe Input { get; set; } public ISocketOutput Output { get; set; } public IEnumerable AdaptedConnections { get; set; } @@ -386,13 +390,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public void Start() { Reset(); - _requestProcessingTask = - Task.Factory.StartNew( - (o) => ((Frame)o).RequestProcessingAsync(), - this, - default(CancellationToken), - TaskCreationOptions.DenyChildAttach, - TaskScheduler.Default).Unwrap(); + _requestProcessingTask = RequestProcessingAsync(); _frameStartedTcs.SetResult(null); } @@ -986,216 +984,204 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http Output.ProducingComplete(end); } - public RequestLineStatus TakeStartLine(SocketInput input) + public bool TakeStartLine(ReadableBuffer buffer, out ReadCursor consumed, out ReadCursor examined) { - const int MaxInvalidRequestLineChars = 32; + var start = buffer.Start; + var end = buffer.Start; - var scan = input.ConsumingStart(); - var start = scan; - var consumed = scan; - var end = scan; + examined = buffer.End; + consumed = buffer.Start; - try + if (_requestProcessingStatus == RequestProcessingStatus.RequestPending) { - // We may hit this when the client has stopped sending data but - // the connection hasn't closed yet, and therefore Frame.Stop() - // hasn't been called yet. - if (scan.Peek() == -1) + ConnectionControl.ResetTimeout(_requestHeadersTimeoutMilliseconds, TimeoutAction.SendTimeoutResponse); + } + + _requestProcessingStatus = RequestProcessingStatus.RequestStarted; + + var limitedBuffer = buffer; + if (buffer.Length >= ServerOptions.Limits.MaxRequestLineSize) + { + limitedBuffer = buffer.Slice(0, ServerOptions.Limits.MaxRequestLineSize); + } + if (ReadCursorOperations.Seek(limitedBuffer.Start, limitedBuffer.End, out end, ByteLF) == -1) + { + if (limitedBuffer.Length == ServerOptions.Limits.MaxRequestLineSize) { - return RequestLineStatus.Empty; - } - - if (_requestProcessingStatus == RequestProcessingStatus.RequestPending) - { - ConnectionControl.ResetTimeout(_requestHeadersTimeoutMilliseconds, TimeoutAction.SendTimeoutResponse); - } - - _requestProcessingStatus = RequestProcessingStatus.RequestStarted; - - int bytesScanned; - if (end.Seek(ByteLF, out bytesScanned, ServerOptions.Limits.MaxRequestLineSize) == -1) - { - if (bytesScanned >= ServerOptions.Limits.MaxRequestLineSize) - { - RejectRequest(RequestRejectionReason.RequestLineTooLong); - } - else - { - return RequestLineStatus.Incomplete; - } - } - end.Take(); - - string method; - var begin = scan; - if (!begin.GetKnownMethod(out method)) - { - if (scan.Seek(ByteSpace, ref end) == -1) - { - RejectRequest(RequestRejectionReason.InvalidRequestLine, - Log.IsEnabled(LogLevel.Information) ? start.GetAsciiStringEscaped(end, MaxInvalidRequestLineChars) : string.Empty); - } - - method = begin.GetAsciiString(ref scan); - - if (method == null) - { - RejectRequest(RequestRejectionReason.InvalidRequestLine, - Log.IsEnabled(LogLevel.Information) ? start.GetAsciiStringEscaped(end, MaxInvalidRequestLineChars) : string.Empty); - } - - // Note: We're not in the fast path any more (GetKnownMethod should have handled any HTTP Method we're aware of) - // So we can be a tiny bit slower and more careful here. - for (int i = 0; i < method.Length; i++) - { - if (!IsValidTokenChar(method[i])) - { - RejectRequest(RequestRejectionReason.InvalidRequestLine, - Log.IsEnabled(LogLevel.Information) ? start.GetAsciiStringEscaped(end, MaxInvalidRequestLineChars) : string.Empty); - } - } + RejectRequest(RequestRejectionReason.RequestLineTooLong); } else { - scan.Skip(method.Length); + return false; + } + } + + end = buffer.Move(end, 1); + ReadCursor methodEnd; + string method; + if (!buffer.GetKnownMethod(out method)) + { + if (ReadCursorOperations.Seek(buffer.Start, end, out methodEnd, ByteSpace) == -1) + { + RejectRequestLine(start, end); } - scan.Take(); - begin = scan; - var needDecode = false; - var chFound = scan.Seek(ByteSpace, ByteQuestionMark, BytePercentage, ref end); + method = buffer.Slice(buffer.Start, methodEnd).GetAsciiString(); + + if (method == null) + { + RejectRequestLine(start, end); + } + + // Note: We're not in the fast path any more (GetKnownMethod should have handled any HTTP Method we're aware of) + // So we can be a tiny bit slower and more careful here. + for (int i = 0; i < method.Length; i++) + { + if (!IsValidTokenChar(method[i])) + { + RejectRequestLine(start, end); + } + } + } + else + { + methodEnd = buffer.Slice(method.Length).Start; + } + + var needDecode = false; + ReadCursor pathEnd; + + var pathBegin = buffer.Move(methodEnd, 1); + + var chFound = ReadCursorOperations.Seek(pathBegin, end, out pathEnd, ByteSpace, ByteQuestionMark, BytePercentage); + if (chFound == -1) + { + RejectRequestLine(start, end); + } + else if (chFound == BytePercentage) + { + needDecode = true; + chFound = ReadCursorOperations.Seek(pathBegin, end, out pathEnd, ByteSpace, ByteQuestionMark); if (chFound == -1) { - RejectRequest(RequestRejectionReason.InvalidRequestLine, - Log.IsEnabled(LogLevel.Information) ? start.GetAsciiStringEscaped(end, MaxInvalidRequestLineChars) : string.Empty); - } - else if (chFound == BytePercentage) - { - needDecode = true; - chFound = scan.Seek(ByteSpace, ByteQuestionMark, ref end); - if (chFound == -1) - { - RejectRequest(RequestRejectionReason.InvalidRequestLine, - Log.IsEnabled(LogLevel.Information) ? start.GetAsciiStringEscaped(end, MaxInvalidRequestLineChars) : string.Empty); - } + RejectRequestLine(start, end); } + }; - var pathBegin = begin; - var pathEnd = scan; - - var queryString = ""; - if (chFound == ByteQuestionMark) - { - begin = scan; - if (scan.Seek(ByteSpace, ref end) == -1) - { - RejectRequest(RequestRejectionReason.InvalidRequestLine, - Log.IsEnabled(LogLevel.Information) ? start.GetAsciiStringEscaped(end, MaxInvalidRequestLineChars) : string.Empty); - } - queryString = begin.GetAsciiString(ref scan); - } - - var queryEnd = scan; - - if (pathBegin.Peek() == ByteSpace) - { - RejectRequest(RequestRejectionReason.InvalidRequestLine, - Log.IsEnabled(LogLevel.Information) ? start.GetAsciiStringEscaped(end, MaxInvalidRequestLineChars) : string.Empty); - } - - scan.Take(); - begin = scan; - if (scan.Seek(ByteCR, ref end) == -1) - { - RejectRequest(RequestRejectionReason.InvalidRequestLine, - Log.IsEnabled(LogLevel.Information) ? start.GetAsciiStringEscaped(end, MaxInvalidRequestLineChars) : string.Empty); - } - - string httpVersion; - if (!begin.GetKnownVersion(out httpVersion)) - { - httpVersion = begin.GetAsciiStringEscaped(scan, 9); - - if (httpVersion == string.Empty) - { - RejectRequest(RequestRejectionReason.InvalidRequestLine, - Log.IsEnabled(LogLevel.Information) ? start.GetAsciiStringEscaped(end, MaxInvalidRequestLineChars) : string.Empty); - } - else - { - RejectRequest(RequestRejectionReason.UnrecognizedHTTPVersion, httpVersion); - } - } - - scan.Take(); // consume CR - if (scan.Take() != ByteLF) - { - RejectRequest(RequestRejectionReason.InvalidRequestLine, - Log.IsEnabled(LogLevel.Information) ? start.GetAsciiStringEscaped(end, MaxInvalidRequestLineChars) : string.Empty); - } - - // URIs are always encoded/escaped to ASCII https://tools.ietf.org/html/rfc3986#page-11 - // Multibyte Internationalized Resource Identifiers (IRIs) are first converted to utf8; - // then encoded/escaped to ASCII https://www.ietf.org/rfc/rfc3987.txt "Mapping of IRIs to URIs" - string requestUrlPath; - string rawTarget; - if (needDecode) - { - // Read raw target before mutating memory. - rawTarget = pathBegin.GetAsciiString(ref queryEnd); - - // URI was encoded, unescape and then parse as utf8 - pathEnd = UrlPathDecoder.Unescape(pathBegin, pathEnd); - requestUrlPath = pathBegin.GetUtf8String(ref pathEnd); - } - else - { - // URI wasn't encoded, parse as ASCII - requestUrlPath = pathBegin.GetAsciiString(ref pathEnd); - - if (queryString.Length == 0) - { - // No need to allocate an extra string if the path didn't need - // decoding and there's no query string following it. - rawTarget = requestUrlPath; - } - else - { - rawTarget = pathBegin.GetAsciiString(ref queryEnd); - } - } - - var normalizedTarget = PathNormalizer.RemoveDotSegments(requestUrlPath); - - consumed = scan; - Method = method; - QueryString = queryString; - RawTarget = rawTarget; - HttpVersion = httpVersion; - - bool caseMatches; - if (RequestUrlStartsWithPathBase(normalizedTarget, out caseMatches)) - { - PathBase = caseMatches ? _pathBase : normalizedTarget.Substring(0, _pathBase.Length); - Path = normalizedTarget.Substring(_pathBase.Length); - } - else if (rawTarget[0] == '/') // check rawTarget since normalizedTarget can be "" or "/" after dot segment removal - { - Path = normalizedTarget; - } - else - { - Path = string.Empty; - PathBase = string.Empty; - QueryString = string.Empty; - } - - return RequestLineStatus.Done; - } - finally + var queryString = ""; + ReadCursor queryEnd = pathEnd; + if (chFound == ByteQuestionMark) { - input.ConsumingComplete(consumed, end); + if (ReadCursorOperations.Seek(pathEnd, end, out queryEnd, ByteSpace) == -1) + { + RejectRequestLine(start, end); + } + queryString = buffer.Slice(pathEnd, queryEnd).GetAsciiString(); } + + // No path + if (pathBegin == pathEnd) + { + RejectRequestLine(start, end); + } + + ReadCursor versionEnd; + if (ReadCursorOperations.Seek(queryEnd, end, out versionEnd, ByteCR) == -1) + { + RejectRequestLine(start, end); + } + + string httpVersion; + var versionBuffer = buffer.Slice(queryEnd, end).Slice(1); + if (!versionBuffer.GetKnownVersion(out httpVersion)) + { + httpVersion = versionBuffer.Start.GetAsciiStringEscaped(versionEnd, 9); + + if (httpVersion == string.Empty) + { + RejectRequestLine(start, end); + } + else + { + RejectRequest(RequestRejectionReason.UnrecognizedHTTPVersion, httpVersion); + } + } + + var lineEnd = buffer.Slice(versionEnd, 2).ToSpan(); + if (lineEnd[1] != ByteLF) + { + RejectRequestLine(start, end); + } + + var pathBuffer = buffer.Slice(pathBegin, pathEnd); + var targetBuffer = buffer.Slice(pathBegin, queryEnd); + + // URIs are always encoded/escaped to ASCII https://tools.ietf.org/html/rfc3986#page-11 + // Multibyte Internationalized Resource Identifiers (IRIs) are first converted to utf8; + // then encoded/escaped to ASCII https://www.ietf.org/rfc/rfc3987.txt "Mapping of IRIs to URIs" + string requestUrlPath; + string rawTarget; + if (needDecode) + { + // Read raw target before mutating memory. + rawTarget = targetBuffer.GetAsciiString() ?? string.Empty; + + // URI was encoded, unescape and then parse as utf8 + var pathSpan = pathBuffer.ToSpan(); + int pathLength = UrlEncoder.Decode(pathSpan, pathSpan); + requestUrlPath = new Utf8String(pathSpan.Slice(0, pathLength)).ToString(); + } + else + { + // URI wasn't encoded, parse as ASCII + requestUrlPath = pathBuffer.GetAsciiString() ?? string.Empty; + + if (queryString.Length == 0) + { + // No need to allocate an extra string if the path didn't need + // decoding and there's no query string following it. + rawTarget = requestUrlPath; + } + else + { + rawTarget = targetBuffer.GetAsciiString() ?? string.Empty; + } + } + + var normalizedTarget = PathNormalizer.RemoveDotSegments(requestUrlPath); + + consumed = end; + examined = end; + Method = method; + QueryString = queryString; + RawTarget = rawTarget; + HttpVersion = httpVersion; + + bool caseMatches; + if (RequestUrlStartsWithPathBase(normalizedTarget, out caseMatches)) + { + PathBase = caseMatches ? _pathBase : normalizedTarget.Substring(0, _pathBase.Length); + Path = normalizedTarget.Substring(_pathBase.Length); + } + else if (rawTarget[0] == '/') // check rawTarget since normalizedTarget can be "" or "/" after dot segment removal + { + Path = normalizedTarget; + } + else + { + Path = string.Empty; + PathBase = string.Empty; + QueryString = string.Empty; + } + + return true; + } + + private void RejectRequestLine(ReadCursor start, ReadCursor end) + { + const int MaxRequestLineError = 32; + RejectRequest(RequestRejectionReason.InvalidRequestLine, + Log.IsEnabled(LogLevel.Information) ? start.GetAsciiStringEscaped(end, MaxRequestLineError) : string.Empty); } private static bool IsValidTokenChar(char c) @@ -1255,34 +1241,35 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http return true; } - public bool TakeMessageHeaders(SocketInput input, FrameRequestHeaders requestHeaders) + public bool TakeMessageHeaders(ReadableBuffer buffer, FrameRequestHeaders requestHeaders, out ReadCursor consumed, out ReadCursor examined) { - var scan = input.ConsumingStart(); - var consumed = scan; - var end = scan; - try + consumed = buffer.Start; + examined = buffer.End; + + while (true) { - while (!end.IsEnd) + var headersEnd = buffer.Slice(0, Math.Min(buffer.Length, 2)); + var headersEndSpan = headersEnd.ToSpan(); + + if (headersEndSpan.Length == 0) { - var ch = end.Peek(); - if (ch == -1) - { - return false; - } - else if (ch == ByteCR) + return false; + } + else + { + var ch = headersEndSpan[0]; + if (ch == ByteCR) { // Check for final CRLF. - end.Take(); - ch = end.Take(); - - if (ch == -1) + if (headersEndSpan.Length < 2) { return false; } - else if (ch == ByteLF) + else if (headersEndSpan[1] == ByteLF) { + consumed = headersEnd.End; + examined = consumed; ConnectionControl.CancelTimeout(); - consumed = end; return true; } @@ -1293,129 +1280,113 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { RejectRequest(RequestRejectionReason.HeaderLineMustNotStartWithWhitespace); } + } - // If we've parsed the max allowed numbers of headers and we're starting a new - // one, we've gone over the limit. - if (_requestHeadersParsed == ServerOptions.Limits.MaxRequestHeaderCount) + // If we've parsed the max allowed numbers of headers and we're starting a new + // one, we've gone over the limit. + if (_requestHeadersParsed == ServerOptions.Limits.MaxRequestHeaderCount) + { + RejectRequest(RequestRejectionReason.TooManyHeaders); + } + + ReadCursor lineEnd; + var limitedBuffer = buffer; + if (buffer.Length >= _remainingRequestHeadersBytesAllowed) + { + limitedBuffer = buffer.Slice(0, _remainingRequestHeadersBytesAllowed); + } + if (ReadCursorOperations.Seek(limitedBuffer.Start, limitedBuffer.End, out lineEnd, ByteLF) == -1) + { + if (limitedBuffer.Length == _remainingRequestHeadersBytesAllowed) { - RejectRequest(RequestRejectionReason.TooManyHeaders); + RejectRequest(RequestRejectionReason.HeadersExceedMaxTotalSize); } - - int bytesScanned; - if (end.Seek(ByteLF, out bytesScanned, _remainingRequestHeadersBytesAllowed) == -1) - { - if (bytesScanned >= _remainingRequestHeadersBytesAllowed) - { - RejectRequest(RequestRejectionReason.HeadersExceedMaxTotalSize); - } - else - { - return false; - } - } - - var beginName = scan; - if (scan.Seek(ByteColon, ref end) == -1) - { - RejectRequest(RequestRejectionReason.NoColonCharacterFoundInHeaderLine); - } - var endName = scan; - - scan.Take(); - - var validateName = beginName; - if (validateName.Seek(ByteSpace, ByteTab, ref endName) != -1) - { - RejectRequest(RequestRejectionReason.WhitespaceIsNotAllowedInHeaderName); - } - - var beginValue = scan; - ch = scan.Take(); - - while (ch == ByteSpace || ch == ByteTab) - { - beginValue = scan; - ch = scan.Take(); - } - - scan = beginValue; - if (scan.Seek(ByteCR, ref end) == -1) - { - RejectRequest(RequestRejectionReason.MissingCRInHeaderLine); - } - - scan.Take(); // we know this is '\r' - ch = scan.Take(); // expecting '\n' - end = scan; - - if (ch != ByteLF) - { - RejectRequest(RequestRejectionReason.HeaderValueMustNotContainCR); - } - - var next = scan.Peek(); - if (next == -1) + else { return false; } - else if (next == ByteSpace || next == ByteTab) - { - // From https://tools.ietf.org/html/rfc7230#section-3.2.4: - // - // Historically, HTTP header field values could be extended over - // multiple lines by preceding each extra line with at least one space - // or horizontal tab (obs-fold). This specification deprecates such - // line folding except within the message/http media type - // (Section 8.3.1). A sender MUST NOT generate a message that includes - // line folding (i.e., that has any field-value that contains a match to - // the obs-fold rule) unless the message is intended for packaging - // within the message/http media type. - // - // A server that receives an obs-fold in a request message that is not - // within a message/http container MUST either reject the message by - // sending a 400 (Bad Request), preferably with a representation - // explaining that obsolete line folding is unacceptable, or replace - // each received obs-fold with one or more SP octets prior to - // interpreting the field value or forwarding the message downstream. - RejectRequest(RequestRejectionReason.HeaderValueLineFoldingNotSupported); - } - - // Trim trailing whitespace from header value by repeatedly advancing to next - // whitespace or CR. - // - // - If CR is found, this is the end of the header value. - // - If whitespace is found, this is the _tentative_ end of the header value. - // If non-whitespace is found after it and it's not CR, seek again to the next - // whitespace or CR for a new (possibly tentative) end of value. - var ws = beginValue; - var endValue = scan; - do - { - ws.Seek(ByteSpace, ByteTab, ByteCR); - endValue = ws; - - ch = ws.Take(); - while (ch == ByteSpace || ch == ByteTab) - { - ch = ws.Take(); - } - } while (ch != ByteCR); - - var name = beginName.GetArraySegment(endName); - var value = beginValue.GetAsciiString(ref endValue); - - consumed = scan; - requestHeaders.Append(name.Array, name.Offset, name.Count, value); - - _remainingRequestHeadersBytesAllowed -= bytesScanned; - _requestHeadersParsed++; } - return false; - } - finally - { - input.ConsumingComplete(consumed, end); + var beginName = buffer.Start; + ReadCursor endName; + if (ReadCursorOperations.Seek(buffer.Start, lineEnd, out endName, ByteColon) == -1) + { + RejectRequest(RequestRejectionReason.NoColonCharacterFoundInHeaderLine); + } + + ReadCursor whitespace; + if (ReadCursorOperations.Seek(beginName, endName, out whitespace, ByteTab, ByteSpace) != -1) + { + RejectRequest(RequestRejectionReason.WhitespaceIsNotAllowedInHeaderName); + } + + ReadCursor endValue; + if (ReadCursorOperations.Seek(beginName, lineEnd, out endValue, ByteCR) == -1) + { + RejectRequest(RequestRejectionReason.MissingCRInHeaderLine); + } + + var lineSufix = buffer.Slice(endValue); + if (lineSufix.Length < 3) + { + return false; + } + lineSufix = lineSufix.Slice(0, 3); // \r\n\r + var lineSufixSpan = lineSufix.ToSpan(); + // This check and MissingCRInHeaderLine is a bit backwards, we should do it at once instead of having another seek + if (lineSufixSpan[1] != ByteLF) + { + RejectRequest(RequestRejectionReason.HeaderValueMustNotContainCR); + } + + var next = lineSufixSpan[2]; + if (next == ByteSpace || next == ByteTab) + { + // From https://tools.ietf.org/html/rfc7230#section-3.2.4: + // + // Historically, HTTP header field values could be extended over + // multiple lines by preceding each extra line with at least one space + // or horizontal tab (obs-fold). This specification deprecates such + // line folding except within the message/http media type + // (Section 8.3.1). A sender MUST NOT generate a message that includes + // line folding (i.e., that has any field-value that contains a match to + // the obs-fold rule) unless the message is intended for packaging + // within the message/http media type. + // + // A server that receives an obs-fold in a request message that is not + // within a message/http container MUST either reject the message by + // sending a 400 (Bad Request), preferably with a representation + // explaining that obsolete line folding is unacceptable, or replace + // each received obs-fold with one or more SP octets prior to + // interpreting the field value or forwarding the message downstream. + RejectRequest(RequestRejectionReason.HeaderValueLineFoldingNotSupported); + } + + // Trim trailing whitespace from header value by repeatedly advancing to next + // whitespace or CR. + // + // - If CR is found, this is the end of the header value. + // - If whitespace is found, this is the _tentative_ end of the header value. + // If non-whitespace is found after it and it's not CR, seek again to the next + // whitespace or CR for a new (possibly tentative) end of value. + + var nameBuffer = buffer.Slice(beginName, endName); + + // TODO: TrimStart and TrimEnd are pretty slow + var valueBuffer = buffer.Slice(endName, endValue).Slice(1).TrimStart().TrimEnd(); + + var name = nameBuffer.ToArraySegment(); + var value = valueBuffer.GetAsciiString(); + + lineEnd = limitedBuffer.Move(lineEnd, 1); + + // TODO: bad + _remainingRequestHeadersBytesAllowed -= buffer.Slice(0, lineEnd).Length; + _requestHeadersParsed++; + + requestHeaders.Append(name.Array, name.Offset, name.Count, value); + buffer = buffer.Slice(lineEnd); + consumed = buffer.Start; } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs index f6d9128f..40fc9e20 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs @@ -3,6 +3,7 @@ using System; using System.IO; +using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting.Server; @@ -31,66 +32,95 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http /// public override async Task RequestProcessingAsync() { - var requestLineStatus = RequestLineStatus.Empty; + var requestLineStatus = default(RequestLineStatus); try { while (!_requestProcessingStopping) { + // If writer completes with an error Input.ReadAsyncDispatched would throw and + // this would not be reset to empty. But it's required by ECONNRESET check lower in the method. + requestLineStatus = RequestLineStatus.Empty; + ConnectionControl.SetTimeout(_keepAliveMilliseconds, TimeoutAction.CloseConnection); while (!_requestProcessingStopping) { - requestLineStatus = TakeStartLine(Input); + var result = await Input.Reader.ReadAsync(); + var examined = result.Buffer.End; + var consumed = result.Buffer.End; + + try + { + if (!result.Buffer.IsEmpty) + { + requestLineStatus = TakeStartLine(result.Buffer, out consumed, out examined) + ? RequestLineStatus.Done : RequestLineStatus.Incomplete; + } + else + { + requestLineStatus = RequestLineStatus.Empty; + } + } + catch (InvalidOperationException) + { + throw BadHttpRequestException.GetException(RequestRejectionReason.InvalidRequestLine); + } + finally + { + Input.Reader.Advance(consumed, examined); + } if (requestLineStatus == RequestLineStatus.Done) { break; } - if (Input.CheckFinOrThrow()) + if (result.IsCompleted) { - // We need to attempt to consume start lines and headers even after - // SocketInput.RemoteIntakeFin is set to true to ensure we don't close a - // connection without giving the application a chance to respond to a request - // sent immediately before the a FIN from the client. - requestLineStatus = TakeStartLine(Input); - if (requestLineStatus == RequestLineStatus.Empty) { return; } - if (requestLineStatus != RequestLineStatus.Done) - { - RejectRequest(RequestRejectionReason.InvalidRequestLine, requestLineStatus.ToString()); - } - - break; + RejectRequest(RequestRejectionReason.InvalidRequestLine, requestLineStatus.ToString()); } - - await Input; } InitializeHeaders(); - while (!_requestProcessingStopping && !TakeMessageHeaders(Input, FrameRequestHeaders)) + while (!_requestProcessingStopping) { - if (Input.CheckFinOrThrow()) - { - // We need to attempt to consume start lines and headers even after - // SocketInput.RemoteIntakeFin is set to true to ensure we don't close a - // connection without giving the application a chance to respond to a request - // sent immediately before the a FIN from the client. - if (!TakeMessageHeaders(Input, FrameRequestHeaders)) - { - RejectRequest(RequestRejectionReason.MalformedRequestInvalidHeaders); - } + var result = await Input.Reader.ReadAsync(); + var examined = result.Buffer.End; + var consumed = result.Buffer.End; + + bool headersDone; + + try + { + headersDone = TakeMessageHeaders(result.Buffer, FrameRequestHeaders, out consumed, + out examined); + } + catch (InvalidOperationException) + { + throw BadHttpRequestException.GetException(RequestRejectionReason.MalformedRequestInvalidHeaders); + } + finally + { + Input.Reader.Advance(consumed, examined); + } + + if (headersDone) + { break; } - await Input; + if (result.IsCompleted) + { + RejectRequest(RequestRejectionReason.MalformedRequestInvalidHeaders); + } } if (!_requestProcessingStopping) @@ -216,6 +246,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { try { + Input.Reader.Complete(); // If _requestAborted is set, the connection has already been closed. if (Volatile.Read(ref _requestAborted) == 0) { diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Listener.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Listener.cs index 3bb269f6..a8624d17 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Listener.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Listener.cs @@ -36,7 +36,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http Thread.Post(state => { - var tcs2 = (TaskCompletionSource) state; + var tcs2 = state; try { var listener = ((Listener) tcs2.Task.AsyncState); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ListenerContext.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ListenerContext.cs index 083f1a21..e5a7ae80 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ListenerContext.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ListenerContext.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; using Microsoft.AspNetCore.Server.Kestrel.Internal.Networking; namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http @@ -40,5 +41,21 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http throw new InvalidOperationException(); } } + + public PipeOptions LibuvPipeOptions => new PipeOptions + { + ReaderScheduler = TaskRunScheduler.Default, + WriterScheduler = Thread, + MaximumSizeHigh = ServiceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0, + MaximumSizeLow = ServiceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0 + }; + + public PipeOptions AdaptedPipeOptions => new PipeOptions + { + ReaderScheduler = InlineScheduler.Default, + WriterScheduler = InlineScheduler.Default, + MaximumSizeHigh = ServiceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0, + MaximumSizeLow = ServiceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0 + }; } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ListenerSecondary.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ListenerSecondary.cs index 9ce1478b..41de1df4 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ListenerSecondary.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ListenerSecondary.cs @@ -47,7 +47,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http DispatchPipe = new UvPipeHandle(Log); var tcs = new TaskCompletionSource(this); - Thread.Post(state => StartCallback((TaskCompletionSource)state), tcs); + Thread.Post(StartCallback, tcs); return tcs.Task; } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/MessageBody.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/MessageBody.cs index 99b10d1b..ce0edff4 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/MessageBody.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/MessageBody.cs @@ -3,10 +3,10 @@ using System; using System.IO; +using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Microsoft.Extensions.Internal; namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http @@ -215,9 +215,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http private void ConsumedBytes(int count) { - var scan = _context.Input.ConsumingStart(); - scan.Skip(count); - _context.Input.ConsumingComplete(scan, scan); + var scan = _context.Input.Reader.ReadAsync().GetResult().Buffer; + var consumed = scan.Move(scan.Start, count); + _context.Input.Reader.Advance(consumed, consumed); OnConsumedBytes(count); } @@ -304,7 +304,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http protected override ValueTask> PeekAsync(CancellationToken cancellationToken) { - return _context.Input.PeekAsync(); + return _context.Input.Reader.PeekAsync(); } } @@ -351,7 +351,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http return new ValueTask>(); } - var task = _context.Input.PeekAsync(); + var task = _context.Input.Reader.PeekAsync(); if (task.IsCompleted) { @@ -413,7 +413,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http // byte consts don't have a data type annotation so we pre-cast it private const byte ByteCR = (byte)'\r'; - private readonly SocketInput _input; + private readonly IPipeReader _input; private readonly FrameRequestHeaders _requestHeaders; private int _inputLength; @@ -423,7 +423,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http : base(context) { RequestKeepAlive = keepAlive; - _input = _context.Input; + _input = _context.Input.Reader; _requestHeaders = headers; } @@ -443,45 +443,71 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { while (_mode == Mode.Prefix) { - var fin = _input.CheckFinOrThrow(); + var result = await _input.ReadAsync(); + var buffer = result.Buffer; + var consumed = default(ReadCursor); + var examined = default(ReadCursor); - ParseChunkedPrefix(); + try + { + ParseChunkedPrefix(buffer, out consumed, out examined); + } + finally + { + _input.Advance(consumed, examined); + } if (_mode != Mode.Prefix) { break; } - else if (fin) + else if (result.IsCompleted) { _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); } - await _input; } while (_mode == Mode.Extension) { - var fin = _input.CheckFinOrThrow(); + var result = await _input.ReadAsync(); + var buffer = result.Buffer; + var consumed = default(ReadCursor); + var examined = default(ReadCursor); - ParseExtension(); + try + { + ParseExtension(buffer, out consumed, out examined); + } + finally + { + _input.Advance(consumed, examined); + } if (_mode != Mode.Extension) { break; } - else if (fin) + else if (result.IsCompleted) { _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); } - await _input; } while (_mode == Mode.Data) { - var fin = _input.CheckFinOrThrow(); - - var segment = PeekChunkedData(); + var result = await _input.ReadAsync(); + var buffer = result.Buffer; + ArraySegment segment; + try + { + segment = PeekChunkedData(buffer); + } + finally + { + _input.Advance(buffer.Start, buffer.Start); + } if (segment.Count != 0) { @@ -491,195 +517,214 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { break; } - else if (fin) + else if (result.IsCompleted) { _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); } - - await _input; } while (_mode == Mode.Suffix) { - var fin = _input.CheckFinOrThrow(); + var result = await _input.ReadAsync(); + var buffer = result.Buffer; + var consumed = default(ReadCursor); + var examined = default(ReadCursor); - ParseChunkedSuffix(); + try + { + ParseChunkedSuffix(buffer, out consumed, out examined); + } + finally + { + _input.Advance(consumed, examined); + } if (_mode != Mode.Suffix) { break; } - else if (fin) + else if (result.IsCompleted) { _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); } - - await _input; } } // Chunks finished, parse trailers while (_mode == Mode.Trailer) { - var fin = _input.CheckFinOrThrow(); + var result = await _input.ReadAsync(); + var buffer = result.Buffer; + var consumed = default(ReadCursor); + var examined = default(ReadCursor); - ParseChunkedTrailer(); + try + { + ParseChunkedTrailer(buffer, out consumed, out examined); + } + finally + { + _input.Advance(consumed, examined); + } if (_mode != Mode.Trailer) { break; } - else if (fin) + else if (result.IsCompleted) { _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); } - await _input; } if (_mode == Mode.TrailerHeaders) { - while (!_context.TakeMessageHeaders(_input, _requestHeaders)) + while (true) { - if (_input.CheckFinOrThrow()) + var result = await _input.ReadAsync(); + var buffer = result.Buffer; + + if (buffer.IsEmpty && result.IsCompleted) { - if (_context.TakeMessageHeaders(_input, _requestHeaders)) + _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); + } + + var consumed = default(ReadCursor); + var examined = default(ReadCursor); + + try + { + if (_context.TakeMessageHeaders(buffer, _requestHeaders, out consumed, out examined)) { break; } - else - { - _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); - } } - - await _input; + finally + { + _input.Advance(consumed, examined); + } } - _mode = Mode.Complete; } return default(ArraySegment); } - private void ParseChunkedPrefix() + private void ParseChunkedPrefix(ReadableBuffer buffer, out ReadCursor consumed, out ReadCursor examined) { - var scan = _input.ConsumingStart(); - var consumed = scan; - try + consumed = buffer.Start; + examined = buffer.Start; + var reader = new ReadableBufferReader(buffer); + var ch1 = reader.Take(); + var ch2 = reader.Take(); + + if (ch1 == -1 || ch2 == -1) { - var ch1 = scan.Take(); - var ch2 = scan.Take(); - if (ch1 == -1 || ch2 == -1) + examined = reader.Cursor; + return; + } + + var chunkSize = CalculateChunkSize(ch1, 0); + ch1 = ch2; + + do + { + if (ch1 == ';') { + consumed = reader.Cursor; + examined = reader.Cursor; + + _inputLength = chunkSize; + _mode = Mode.Extension; return; } - var chunkSize = CalculateChunkSize(ch1, 0); + ch2 = reader.Take(); + if (ch2 == -1) + { + examined = reader.Cursor; + return; + } + + if (ch1 == '\r' && ch2 == '\n') + { + consumed = reader.Cursor; + examined = reader.Cursor; + + _inputLength = chunkSize; + + if (chunkSize > 0) + { + _mode = Mode.Data; + } + else + { + _mode = Mode.Trailer; + } + + return; + } + + chunkSize = CalculateChunkSize(ch1, chunkSize); ch1 = ch2; - - do - { - if (ch1 == ';') - { - consumed = scan; - - _inputLength = chunkSize; - _mode = Mode.Extension; - return; - } - - ch2 = scan.Take(); - if (ch2 == -1) - { - return; - } - - if (ch1 == '\r' && ch2 == '\n') - { - consumed = scan; - _inputLength = chunkSize; - - if (chunkSize > 0) - { - _mode = Mode.Data; - } - else - { - _mode = Mode.Trailer; - } - - return; - } - - chunkSize = CalculateChunkSize(ch1, chunkSize); - ch1 = ch2; - } while (ch1 != -1); - } - finally - { - _input.ConsumingComplete(consumed, scan); - } + } while (ch1 != -1); } - private void ParseExtension() + private void ParseExtension(ReadableBuffer buffer, out ReadCursor consumed, out ReadCursor examined) { - var scan = _input.ConsumingStart(); - var consumed = scan; - try + // Chunk-extensions not currently parsed + // Just drain the data + consumed = buffer.Start; + examined = buffer.Start; + do { - // Chunk-extensions not currently parsed - // Just drain the data - do + ReadCursor extensionCursor; + if (ReadCursorOperations.Seek(buffer.Start, buffer.End, out extensionCursor, ByteCR) == -1) { - if (scan.Seek(ByteCR) == -1) - { - // End marker not found yet - consumed = scan; - return; - }; + // End marker not found yet + examined = buffer.End; + return; + }; - var ch1 = scan.Take(); - var ch2 = scan.Take(); + var sufixBuffer = buffer.Slice(extensionCursor); + if (sufixBuffer.Length < 2) + { + examined = buffer.End; + return; + } - if (ch2 == '\n') + sufixBuffer = sufixBuffer.Slice(0, 2); + var sufixSpan = sufixBuffer.ToSpan(); + + + if (sufixSpan[1] == '\n') + { + consumed = sufixBuffer.End; + examined = sufixBuffer.End; + if (_inputLength > 0) { - consumed = scan; - if (_inputLength > 0) - { - _mode = Mode.Data; - } - else - { - _mode = Mode.Trailer; - } + _mode = Mode.Data; } - else if (ch2 == -1) + else { - return; + _mode = Mode.Trailer; } - } while (_mode == Mode.Extension); - } - finally - { - _input.ConsumingComplete(consumed, scan); - } + } + } while (_mode == Mode.Extension); } - private ArraySegment PeekChunkedData() + private ArraySegment PeekChunkedData(ReadableBuffer buffer) { if (_inputLength == 0) { _mode = Mode.Suffix; return default(ArraySegment); } + var segment = buffer.First.GetArray(); - var scan = _input.ConsumingStart(); - var segment = scan.PeekArraySegment(); int actual = Math.Min(segment.Count, _inputLength); // Nothing is consumed yet. ConsumedBytes(int) will move the iterator. - _input.ConsumingComplete(scan, scan); - if (actual == segment.Count) { return segment; @@ -690,60 +735,54 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } } - private void ParseChunkedSuffix() + private void ParseChunkedSuffix(ReadableBuffer buffer, out ReadCursor consumed, out ReadCursor examined) { - var scan = _input.ConsumingStart(); - var consumed = scan; - try + consumed = buffer.Start; + examined = buffer.Start; + + if (buffer.Length < 2) { - var ch1 = scan.Take(); - var ch2 = scan.Take(); - if (ch1 == -1 || ch2 == -1) - { - return; - } - else if (ch1 == '\r' && ch2 == '\n') - { - consumed = scan; - _mode = Mode.Prefix; - } - else - { - _context.RejectRequest(RequestRejectionReason.BadChunkSuffix); - } + examined = buffer.End; + return; } - finally + + var sufixBuffer = buffer.Slice(0, 2); + var sufixSpan = sufixBuffer.ToSpan(); + if (sufixSpan[0] == '\r' && sufixSpan[1] == '\n') { - _input.ConsumingComplete(consumed, scan); + consumed = sufixBuffer.End; + examined = sufixBuffer.End; + _mode = Mode.Prefix; + } + else + { + _context.RejectRequest(RequestRejectionReason.BadChunkSuffix); } } - private void ParseChunkedTrailer() + private void ParseChunkedTrailer(ReadableBuffer buffer, out ReadCursor consumed, out ReadCursor examined) { - var scan = _input.ConsumingStart(); - var consumed = scan; - try - { - var ch1 = scan.Take(); - var ch2 = scan.Take(); + consumed = buffer.Start; + examined = buffer.Start; - if (ch1 == -1 || ch2 == -1) - { - return; - } - else if (ch1 == '\r' && ch2 == '\n') - { - consumed = scan; - _mode = Mode.Complete; - } - else - { - _mode = Mode.TrailerHeaders; - } - } - finally + if (buffer.Length < 2) { - _input.ConsumingComplete(consumed, scan); + examined = buffer.End; + return; + } + + var trailerBuffer = buffer.Slice(0, 2); + var trailerSpan = trailerBuffer.ToSpan(); + + if (trailerSpan[0] == '\r' && trailerSpan[1] == '\n') + { + consumed = trailerBuffer.End; + examined = trailerBuffer.End; + _mode = Mode.Complete; + } + else + { + _mode = Mode.TrailerHeaders; } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/PipelineExtensions.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/PipelineExtensions.cs new file mode 100644 index 00000000..4698eb51 --- /dev/null +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/PipelineExtensions.cs @@ -0,0 +1,108 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http +{ + public static class PipelineExtensions + { + public static ValueTask> PeekAsync(this IPipeReader pipelineReader) + { + var input = pipelineReader.ReadAsync(); + while (input.IsCompleted) + { + var result = input.GetResult(); + try + { + if (!result.Buffer.IsEmpty) + { + var segment = result.Buffer.First; + var data = segment.GetArray(); + + return new ValueTask>(data); + } + else if (result.IsCompleted || result.IsCancelled) + { + return default(ValueTask>); + } + } + finally + { + pipelineReader.Advance(result.Buffer.Start, result.Buffer.Start); + } + input = pipelineReader.ReadAsync(); + } + + return new ValueTask>(pipelineReader.PeekAsyncAwaited(input)); + } + + private static async Task> PeekAsyncAwaited(this IPipeReader pipelineReader, ReadableBufferAwaitable readingTask) + { + while (true) + { + var result = await readingTask; + + await AwaitableThreadPool.Yield(); + + try + { + if (!result.Buffer.IsEmpty) + { + var segment = result.Buffer.First; + return segment.GetArray(); + } + else if (result.IsCompleted || result.IsCancelled) + { + return default(ArraySegment); + } + } + finally + { + pipelineReader.Advance(result.Buffer.Start, result.Buffer.Start); + } + + readingTask = pipelineReader.ReadAsync(); + } + } + + private static async Task ReadAsyncDispatchedAwaited(ReadableBufferAwaitable awaitable) + { + var result = await awaitable; + await AwaitableThreadPool.Yield(); + return result; + } + + public static Span ToSpan(this ReadableBuffer buffer) + { + if (buffer.IsSingleSpan) + { + return buffer.First.Span; + } + return buffer.ToArray(); + } + + public static ArraySegment ToArraySegment(this ReadableBuffer buffer) + { + if (buffer.IsSingleSpan) + { + return buffer.First.GetArray(); + } + return new ArraySegment(buffer.ToArray()); + } + + public static ArraySegment GetArray(this Memory memory) + { + ArraySegment result; + if (!memory.TryGetArray(out result)) + { + throw new InvalidOperationException("Memory backed by array was expected"); + } + return result; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/RequestRejectionReason.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/RequestRejectionReason.cs index caa6475d..042a25d9 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/RequestRejectionReason.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/RequestRejectionReason.cs @@ -30,6 +30,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http RequestTimeout, FinalTransferCodingNotChunked, LengthRequired, - LengthRequiredHttp10, + LengthRequiredHttp10 } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketInput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketInput.cs deleted file mode 100644 index a89fff7b..00000000 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketInput.cs +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Diagnostics; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; - -namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http -{ - public class SocketInput : ICriticalNotifyCompletion, IDisposable - { - private static readonly Action _awaitableIsCompleted = () => { }; - private static readonly Action _awaitableIsNotCompleted = () => { }; - - private readonly MemoryPool _memory; - private readonly IThreadPool _threadPool; - private readonly IBufferSizeControl _bufferSizeControl; - private readonly ManualResetEventSlim _manualResetEvent = new ManualResetEventSlim(false, 0); - - private Action _awaitableState; - - private MemoryPoolBlock _head; - private MemoryPoolBlock _tail; - private MemoryPoolBlock _pinned; - - private object _sync = new object(); - - private bool _consuming; - private bool _disposed; - - private TaskCompletionSource _tcs = new TaskCompletionSource(); - - public SocketInput(MemoryPool memory, IThreadPool threadPool, IBufferSizeControl bufferSizeControl = null) - { - _memory = memory; - _threadPool = threadPool; - _bufferSizeControl = bufferSizeControl; - _awaitableState = _awaitableIsNotCompleted; - } - - public bool IsCompleted => ReferenceEquals(_awaitableState, _awaitableIsCompleted); - - private bool ReadingInput => _tcs.Task.Status == TaskStatus.WaitingForActivation; - - public bool CheckFinOrThrow() - { - CheckConnectionError(); - return _tcs.Task.Status == TaskStatus.RanToCompletion; - } - - public MemoryPoolBlock IncomingStart() - { - lock (_sync) - { - const int minimumSize = 2048; - - if (_tail != null && minimumSize <= _tail.Data.Offset + _tail.Data.Count - _tail.End) - { - _pinned = _tail; - } - else - { - _pinned = _memory.Lease(); - } - - return _pinned; - } - } - - public void IncomingComplete(int count, Exception error) - { - Action awaitableState; - - lock (_sync) - { - // Must call Add() before bytes are available to consumer, to ensure that Length is >= 0 - _bufferSizeControl?.Add(count); - - if (_pinned != null) - { - _pinned.End += count; - - if (_head == null) - { - _head = _tail = _pinned; - } - else if (_tail == _pinned) - { - // NO-OP: this was a read into unoccupied tail-space - } - else - { - Volatile.Write(ref _tail.Next, _pinned); - _tail = _pinned; - } - - _pinned = null; - } - - if (error != null) - { - SetConnectionError(error); - } - else if (count == 0) - { - FinReceived(); - } - - awaitableState = Interlocked.Exchange(ref _awaitableState, _awaitableIsCompleted); - } - - Complete(awaitableState); - } - - public void IncomingDeferred() - { - Debug.Assert(_pinned != null); - - lock (_sync) - { - if (_pinned != null) - { - if (_pinned != _tail) - { - _memory.Return(_pinned); - } - - _pinned = null; - } - } - } - - private void Complete(Action awaitableState) - { - _manualResetEvent.Set(); - - if (!ReferenceEquals(awaitableState, _awaitableIsCompleted) && - !ReferenceEquals(awaitableState, _awaitableIsNotCompleted)) - { - _threadPool.Run(awaitableState); - } - } - - public MemoryPoolIterator ConsumingStart() - { - MemoryPoolBlock head; - bool isAlreadyConsuming; - - lock (_sync) - { - isAlreadyConsuming = _consuming; - head = _head; - _consuming = true; - } - - if (isAlreadyConsuming) - { - throw new InvalidOperationException("Already consuming input."); - } - - return new MemoryPoolIterator(head); - } - - public void ConsumingComplete( - MemoryPoolIterator consumed, - MemoryPoolIterator examined) - { - bool isConsuming; - MemoryPoolBlock returnStart = null; - MemoryPoolBlock returnEnd = null; - - lock (_sync) - { - if (!_disposed) - { - if (!consumed.IsDefault) - { - // Compute lengthConsumed before modifying _head or consumed - var lengthConsumed = 0; - if (_bufferSizeControl != null) - { - lengthConsumed = new MemoryPoolIterator(_head).GetLength(consumed); - } - - returnStart = _head; - - var consumedAll = !consumed.IsDefault && consumed.IsEnd; - if (consumedAll && _pinned != _tail) - { - // Everything has been consumed and no data is being written to the - // _tail block, so return all blocks between _head and _tail inclusive. - _head = null; - _tail = null; - } - else - { - returnEnd = consumed.Block; - _head = consumed.Block; - _head.Start = consumed.Index; - } - - // Must call Subtract() after _head has been advanced, to avoid producer starting too early and growing - // buffer beyond max length. - _bufferSizeControl?.Subtract(lengthConsumed); - } - - // If _head is null, everything has been consumed and examined. - var examinedAll = (!examined.IsDefault && examined.IsEnd) || _head == null; - if (examinedAll && ReadingInput) - { - _manualResetEvent.Reset(); - - Interlocked.CompareExchange( - ref _awaitableState, - _awaitableIsNotCompleted, - _awaitableIsCompleted); - } - } - else - { - // Dispose won't have returned the blocks if we were consuming, so return them now - returnStart = _head; - _head = null; - _tail = null; - } - - isConsuming = _consuming; - _consuming = false; - } - - ReturnBlocks(returnStart, returnEnd); - - if (!isConsuming) - { - throw new InvalidOperationException("No ongoing consuming operation to complete."); - } - } - - public void CompleteAwaiting() - { - Complete(Interlocked.Exchange(ref _awaitableState, _awaitableIsCompleted)); - } - - public void AbortAwaiting() - { - SetConnectionError(new TaskCanceledException("The request was aborted")); - - CompleteAwaiting(); - } - - public SocketInput GetAwaiter() - { - return this; - } - - public void OnCompleted(Action continuation) - { - var awaitableState = Interlocked.CompareExchange( - ref _awaitableState, - continuation, - _awaitableIsNotCompleted); - - if (ReferenceEquals(awaitableState, _awaitableIsCompleted)) - { - _threadPool.Run(continuation); - } - else if (!ReferenceEquals(awaitableState, _awaitableIsNotCompleted)) - { - SetConnectionError(new InvalidOperationException("Concurrent reads are not supported.")); - - Interlocked.Exchange( - ref _awaitableState, - _awaitableIsCompleted); - - _manualResetEvent.Set(); - - _threadPool.Run(continuation); - _threadPool.Run(awaitableState); - } - } - - public void UnsafeOnCompleted(Action continuation) - { - OnCompleted(continuation); - } - - public void GetResult() - { - if (!IsCompleted) - { - _manualResetEvent.Wait(); - } - - CheckConnectionError(); - } - - public void Dispose() - { - AbortAwaiting(); - - MemoryPoolBlock block = null; - - lock (_sync) - { - if (!_consuming) - { - block = _head; - _head = null; - _tail = null; - } - - _disposed = true; - } - - ReturnBlocks(block, null); - } - - private static void ReturnBlocks(MemoryPoolBlock block, MemoryPoolBlock end) - { - while (block != end) - { - var returnBlock = block; - block = block.Next; - - returnBlock.Pool.Return(returnBlock); - } - } - - private void SetConnectionError(Exception error) - { - _tcs.TrySetException(error); - } - - private void FinReceived() - { - _tcs.TrySetResult(null); - } - - private void CheckConnectionError() - { - var error = _tcs.Task.Exception?.InnerException; - if (error != null) - { - throw error; - } - } - } -} diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketInputExtensions.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketInputExtensions.cs deleted file mode 100644 index 8dd26803..00000000 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketInputExtensions.cs +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; - -namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http -{ - public static class SocketInputExtensions - { - public static ValueTask ReadAsync(this SocketInput input, byte[] buffer, int offset, int count) - { - while (input.IsCompleted) - { - var fin = input.CheckFinOrThrow(); - - var begin = input.ConsumingStart(); - int actual; - var end = begin.CopyTo(buffer, offset, count, out actual); - input.ConsumingComplete(end, end); - - if (actual != 0 || fin) - { - return new ValueTask(actual); - } - } - - return new ValueTask(input.ReadAsyncAwaited(buffer, offset, count)); - } - - private static async Task ReadAsyncAwaited(this SocketInput input, byte[] buffer, int offset, int count) - { - while (true) - { - await input; - - var fin = input.CheckFinOrThrow(); - - var begin = input.ConsumingStart(); - int actual; - var end = begin.CopyTo(buffer, offset, count, out actual); - input.ConsumingComplete(end, end); - - if (actual != 0 || fin) - { - return actual; - } - } - } - - public static ValueTask> PeekAsync(this SocketInput input) - { - while (input.IsCompleted) - { - var fin = input.CheckFinOrThrow(); - - var begin = input.ConsumingStart(); - var segment = begin.PeekArraySegment(); - input.ConsumingComplete(begin, begin); - - if (segment.Count != 0 || fin) - { - return new ValueTask>(segment); - } - } - - return new ValueTask>(input.PeekAsyncAwaited()); - } - - private static async Task> PeekAsyncAwaited(this SocketInput input) - { - while (true) - { - await input; - - var fin = input.CheckFinOrThrow(); - - var begin = input.ConsumingStart(); - var segment = begin.PeekArraySegment(); - input.ConsumingComplete(begin, begin); - - if (segment.Count != 0 || fin) - { - return segment; - } - } - } - } -} diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs index c06bd4b2..52a3d5b1 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs @@ -327,7 +327,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http private void ScheduleWrite() { - _thread.Post(state => ((SocketOutput)state).WriteAllPending(), this); + _thread.Post(state => state.WriteAllPending(), this); } // This is called on the libuv event loop diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/AwaitableThreadPool.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/AwaitableThreadPool.cs new file mode 100644 index 00000000..70930c73 --- /dev/null +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/AwaitableThreadPool.cs @@ -0,0 +1,39 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure +{ + internal static class AwaitableThreadPool + { + internal static Awaitable Yield() + { + return new Awaitable(); + } + + internal struct Awaitable : ICriticalNotifyCompletion + { + public void GetResult() + { + + } + + public Awaitable GetAwaiter() => this; + + public bool IsCompleted => false; + + public void OnCompleted(Action continuation) + { + Task.Run(continuation); + } + + public void UnsafeOnCompleted(Action continuation) + { + OnCompleted(continuation); + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelThread.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelThread.cs index 383f804c..ea40ccf3 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelThread.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelThread.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.IO.Pipelines; using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; using System.Threading; @@ -12,18 +13,17 @@ using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Internal.Networking; using Microsoft.Extensions.Logging; +using MemoryPool = Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure.MemoryPool; namespace Microsoft.AspNetCore.Server.Kestrel.Internal { /// /// Summary description for KestrelThread /// - public class KestrelThread + public class KestrelThread: IScheduler { public const long HeartbeatMilliseconds = 1000; - private static readonly Action _postCallbackAdapter = (callback, state) => ((Action)callback).Invoke(state); - private static readonly Action _postAsyncCallbackAdapter = (callback, state) => ((Action)callback).Invoke(state); private static readonly Libuv.uv_walk_cb _heartbeatWalkCallback = (ptr, arg) => { var streamHandle = UvMemory.FromIntPtr(ptr) as UvStreamHandle; @@ -78,10 +78,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal QueueCloseHandle = PostCloseHandle; QueueCloseAsyncHandle = EnqueueCloseHandle; Memory = new MemoryPool(); + PipelineFactory = new PipeFactory(); WriteReqPool = new WriteReqPool(this, _log); ConnectionManager = new ConnectionManager(this, _threadPool); } - // For testing internal KestrelThread(KestrelEngine engine, int maxLoops) : this(engine) @@ -93,6 +93,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal public MemoryPool Memory { get; } + public PipeFactory PipelineFactory { get; } + public ConnectionManager ConnectionManager { get; } public WriteReqPool WriteReqPool { get; } @@ -180,7 +182,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal var result = await WaitAsync(PostAsync(state => { - var listener = (KestrelThread)state; + var listener = state; listener.WriteReqPool.Dispose(); }, this), _shutdownTimeout).ConfigureAwait(false); @@ -193,6 +195,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal finally { Memory.Dispose(); + PipelineFactory.Dispose(); } } @@ -224,13 +227,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal _loop.Stop(); } - public void Post(Action callback, object state) + public void Post(Action callback, T state) { lock (_workSync) { _workAdding.Enqueue(new Work { - CallbackAdapter = _postCallbackAdapter, + CallbackAdapter = CallbackAdapter.PostCallbackAdapter, Callback = callback, State = state }); @@ -240,17 +243,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal private void Post(Action callback) { - Post(thread => callback((KestrelThread)thread), this); + Post(callback, this); } - public Task PostAsync(Action callback, object state) + public Task PostAsync(Action callback, T state) { var tcs = new TaskCompletionSource(); lock (_workSync) { _workAdding.Enqueue(new Work { - CallbackAdapter = _postAsyncCallbackAdapter, + CallbackAdapter = CallbackAdapter.PostAsyncCallbackAdapter, Callback = callback, State = state, Completion = tcs @@ -439,6 +442,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal return await Task.WhenAny(task, Task.Delay(timeout)).ConfigureAwait(false) == task; } + public void Schedule(Action action) + { + Post(state => state(), action); + } + private struct Work { public Action CallbackAdapter; @@ -452,5 +460,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal public Action Callback; public IntPtr Handle; } + + private class CallbackAdapter + { + public static readonly Action PostCallbackAdapter = (callback, state) => ((Action)callback).Invoke((T)state); + public static readonly Action PostAsyncCallbackAdapter = (callback, state) => ((Action)callback).Invoke((T)state); + } + } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs index 1839d330..94cf66ef 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics; +using System.IO.Pipelines; using System.Runtime.CompilerServices; using System.Text; using Microsoft.AspNetCore.Http; @@ -69,19 +70,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure } } - public static string GetAsciiStringEscaped(this MemoryPoolIterator start, MemoryPoolIterator end, int maxChars) + public static string GetAsciiStringEscaped(this ReadCursor start, ReadCursor end, int maxChars) { var sb = new StringBuilder(); - var scan = start; + var reader = new ReadableBufferReader(start, end); - while (maxChars > 0 && (scan.Block != end.Block || scan.Index != end.Index)) + while (maxChars > 0 && !reader.End) { - var ch = scan.Take(); - sb.Append(ch < 0x20 || ch >= 0x7F ? $"<0x{ch.ToString("X2")}>" : ((char)ch).ToString()); + var ch = reader.Take(); + sb.Append(ch < 0x20 || ch >= 0x7F ? $"<0x{ch:X2}>" : ((char)ch).ToString()); maxChars--; } - if (scan.Block != end.Block || scan.Index != end.Index) + if (!reader.End) { sb.Append("..."); } @@ -130,16 +131,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure /// A reference to a pre-allocated known string, if the input matches any. /// true if the input matches a known string, false otherwise. [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool GetKnownMethod(this MemoryPoolIterator begin, out string knownMethod) + public static bool GetKnownMethod(this ReadableBuffer begin, out string knownMethod) { knownMethod = null; - - ulong value; - if (!begin.TryPeekLong(out value)) + if (begin.Length < sizeof(ulong)) { return false; } + ulong value = begin.ReadLittleEndian(); if ((value & _mask4Chars) == _httpGetMethodLong) { knownMethod = HttpMethods.Get; @@ -171,16 +171,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure /// A reference to a pre-allocated known string, if the input matches any. /// true if the input matches a known string, false otherwise. [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool GetKnownVersion(this MemoryPoolIterator begin, out string knownVersion) + public static bool GetKnownVersion(this ReadableBuffer begin, out string knownVersion) { knownVersion = null; - ulong value; - if (!begin.TryPeekLong(out value)) + if (begin.Length < sizeof(ulong)) { return false; } + var value = begin.ReadLittleEndian(); if (value == _http11VersionLong) { knownVersion = Http11Version; @@ -192,9 +192,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure if (knownVersion != null) { - begin.Skip(knownVersion.Length); - - if (begin.Peek() != '\r') + if (begin.Slice(sizeof(ulong)).Peek() != '\r') { knownVersion = null; } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Microsoft.AspNetCore.Server.Kestrel.csproj b/src/Microsoft.AspNetCore.Server.Kestrel/Microsoft.AspNetCore.Server.Kestrel.csproj index 02214cba..c62c37e2 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Microsoft.AspNetCore.Server.Kestrel.csproj +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Microsoft.AspNetCore.Server.Kestrel.csproj @@ -18,6 +18,10 @@ + + + + diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestBufferSizeTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestBufferSizeTests.cs index 2acd8f7f..cfbb6cc2 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestBufferSizeTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestBufferSizeTests.cs @@ -70,7 +70,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } [Theory] - [MemberData("LargeUploadData")] + [MemberData(nameof(LargeUploadData))] public async Task LargeUpload(long? maxRequestBufferSize, bool ssl, bool expectPause) { // Parameters diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests.csproj b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests.csproj index ac06d2b3..af33bd65 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests.csproj +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests.csproj @@ -3,7 +3,7 @@ - netcoreapp1.1;net452 + netcoreapp1.1 netcoreapp1.1 win7-x64 diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/PathBaseTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/PathBaseTests.cs index 0e6045f1..8c0edb6e 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/PathBaseTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/PathBaseTests.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Net.Http; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; @@ -87,7 +88,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests using (var client = new HttpClient()) { - var response = await client.GetAsync($"http://localhost:{host.GetPort()}{requestPath}"); + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}{requestPath}"); response.EnsureSuccessStatusCode(); var responseText = await response.Content.ReadAsStringAsync(); diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs index bc136202..77ae582b 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs @@ -60,7 +60,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { for (var i = 0; i < received; i++) { - Assert.Equal((byte)((total + i) % 256), receivedBytes[i]); + // Do not use Assert.Equal here, it is to slow for this hot path + Assert.True((byte)((total + i) % 256) == receivedBytes[i], "Data received is incorrect"); } } @@ -143,7 +144,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests client.DefaultRequestHeaders.Connection.Clear(); client.DefaultRequestHeaders.Connection.Add("close"); - var response = await client.GetAsync($"http://localhost:{host.GetPort()}/"); + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); response.EnsureSuccessStatusCode(); } } diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs index edf5d397..e027de1e 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs @@ -57,7 +57,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests using (var client = new HttpClient()) { - var response = await client.GetAsync($"http://localhost:{host.GetPort()}/"); + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); response.EnsureSuccessStatusCode(); var responseBody = await response.Content.ReadAsStreamAsync(); @@ -100,7 +100,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests using (var client = new HttpClient()) { - var response = await client.GetAsync($"http://localhost:{host.GetPort()}/"); + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); response.EnsureSuccessStatusCode(); var headers = response.Headers; @@ -145,7 +145,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests using (var client = new HttpClient()) { - var response = await client.GetAsync($"http://localhost:{host.GetPort()}/"); + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); Assert.False(onStartingCalled); @@ -178,7 +178,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests using (var client = new HttpClient()) { - var response = await client.GetAsync($"http://localhost:{host.GetPort()}/"); + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); // Despite the error, the response had already started Assert.Equal(HttpStatusCode.OK, response.StatusCode); diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ThreadCountTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ThreadCountTests.cs index e4a723a9..cfd2e133 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ThreadCountTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ThreadCountTests.cs @@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests var requestTasks = new List>(); for (int i = 0; i < 20; i++) { - var requestTask = client.GetStringAsync($"http://localhost:{host.GetPort()}/"); + var requestTask = client.GetStringAsync($"http://127.0.0.1:{host.GetPort()}/"); requestTasks.Add(requestTask); } diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Microsoft.AspNetCore.Server.Kestrel.Performance.csproj b/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Microsoft.AspNetCore.Server.Kestrel.Performance.csproj index bae45dfa..4e4fabb7 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Microsoft.AspNetCore.Server.Kestrel.Performance.csproj +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Microsoft.AspNetCore.Server.Kestrel.Performance.csproj @@ -15,7 +15,6 @@ - diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Performance/PipeThroughput.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Performance/PipeThroughput.cs new file mode 100644 index 00000000..dc991012 --- /dev/null +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Performance/PipeThroughput.cs @@ -0,0 +1,68 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Text; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + [Config(typeof(CoreConfig))] + public class PipeThroughput + { + private const int _writeLenght = 57; + private const int InnerLoopCount = 512; + + private IPipe _pipe; + private PipeFactory _pipelineFactory; + + [Setup] + public void Setup() + { + _pipelineFactory = new PipeFactory(); + _pipe = _pipelineFactory.Create(); + } + + [Benchmark(OperationsPerInvoke = InnerLoopCount)] + public void ParseLiveAspNetTwoTasks() + { + var writing = Task.Run(async () => + { + for (int i = 0; i < InnerLoopCount; i++) + { + var writableBuffer = _pipe.Writer.Alloc(_writeLenght); + writableBuffer.Advance(_writeLenght); + await writableBuffer.FlushAsync(); + } + }); + + var reading = Task.Run(async () => + { + int remaining = InnerLoopCount * _writeLenght; + while (remaining != 0) + { + var result = await _pipe.Reader.ReadAsync(); + remaining -= result.Buffer.Length; + _pipe.Reader.Advance(result.Buffer.End, result.Buffer.End); + } + }); + + Task.WaitAll(writing, reading); + } + + [Benchmark(OperationsPerInvoke = InnerLoopCount)] + public void ParseLiveAspNetInline() + { + for (int i = 0; i < InnerLoopCount; i++) + { + var writableBuffer = _pipe.Writer.Alloc(_writeLenght); + writableBuffer.Advance(_writeLenght); + writableBuffer.FlushAsync().GetAwaiter().GetResult(); + var result = _pipe.Reader.ReadAsync().GetAwaiter().GetResult(); + _pipe.Reader.Advance(result.Buffer.End, result.Buffer.End); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Program.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Program.cs index 3e2d9277..d687d6be 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Program.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Program.cs @@ -36,6 +36,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance { BenchmarkRunner.Run(); } + if (type.HasFlag(BenchmarkType.Throughput)) + { + BenchmarkRunner.Run(); + } } } @@ -44,6 +48,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance { RequestParsing = 1, Writing = 2, + Throughput = 4, // add new ones in powers of two - e.g. 2,4,8,16... All = uint.MaxValue diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Performance/RequestParsing.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Performance/RequestParsing.cs index facb7bce..e42f1091 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Performance/RequestParsing.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Performance/RequestParsing.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; using System.Linq; using System.Text; using BenchmarkDotNet.Attributes; @@ -9,6 +10,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Internal; using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Microsoft.AspNetCore.Testing; +using MemoryPool = Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure.MemoryPool; using RequestLineStatus = Microsoft.AspNetCore.Server.Kestrel.Internal.Http.Frame.RequestLineStatus; namespace Microsoft.AspNetCore.Server.Kestrel.Performance @@ -21,14 +23,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance private const string plaintextRequest = "GET /plaintext HTTP/1.1\r\nHost: www.example.com\r\n\r\n"; - private const string liveaspnetRequest = "GET https://live.asp.net/ HTTP/1.1\r\n" + - "Host: live.asp.net\r\n" + - "Connection: keep-alive\r\n" + - "Upgrade-Insecure-Requests: 1\r\n" + - "User-Agent: Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/54.0.2840.99 Safari/537.36\r\n" + - "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8\r\n" + - "DNT: 1\r\n" + - "Accept-Encoding: gzip, deflate, sdch, br\r\n" + + private const string liveaspnetRequest = "GET https://live.asp.net/ HTTP/1.1\r\n" + + "Host: live.asp.net\r\n" + + "Connection: keep-alive\r\n" + + "Upgrade-Insecure-Requests: 1\r\n" + + "User-Agent: Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/54.0.2840.99 Safari/537.36\r\n" + + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8\r\n" + + "DNT: 1\r\n" + + "Accept-Encoding: gzip, deflate, sdch, br\r\n" + "Accept-Language: en-US,en;q=0.8\r\n" + "Cookie: __unam=7a67379-1s65dc575c4-6d778abe-1; omniID=9519gfde_3347_4762_8762_df51458c8ec2\r\n\r\n"; @@ -48,7 +50,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance "Cookie: prov=20629ccd-8b0f-e8ef-2935-cd26609fc0bc; __qca=P0-1591065732-1479167353442; _ga=GA1.2.1298898376.1479167354; _gat=1; sgt=id=9519gfde_3347_4762_8762_df51458c8ec2; acct=t=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric&s=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric\r\n\r\n"; private static readonly byte[] _plaintextPipelinedRequests = Encoding.ASCII.GetBytes(string.Concat(Enumerable.Repeat(plaintextRequest, Pipelining))); - private static readonly byte[] _plaintextRequest = Encoding.ASCII.GetBytes(plaintextRequest); + private static readonly byte[] _plaintextRequest = Encoding.ASCII.GetBytes(plaintextRequest); private static readonly byte[] _liveaspnentPipelinedRequests = Encoding.ASCII.GetBytes(string.Concat(Enumerable.Repeat(liveaspnetRequest, Pipelining))); private static readonly byte[] _liveaspnentRequest = Encoding.ASCII.GetBytes(liveaspnetRequest); @@ -56,19 +58,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance private static readonly byte[] _unicodePipelinedRequests = Encoding.ASCII.GetBytes(string.Concat(Enumerable.Repeat(unicodeRequest, Pipelining))); private static readonly byte[] _unicodeRequest = Encoding.ASCII.GetBytes(unicodeRequest); - private KestrelTrace Trace; - private LoggingThreadPool ThreadPool; - private MemoryPool MemoryPool; - private SocketInput SocketInput; - private Frame Frame; - [Benchmark(Baseline = true, OperationsPerInvoke = InnerLoopCount)] public void ParsePlaintext() { for (var i = 0; i < InnerLoopCount; i++) { InsertData(_plaintextRequest); - ParseData(); } } @@ -79,7 +74,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance for (var i = 0; i < InnerLoopCount; i++) { InsertData(_plaintextPipelinedRequests); - ParseData(); } } @@ -90,7 +84,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance for (var i = 0; i < InnerLoopCount; i++) { InsertData(_liveaspnentRequest); - ParseData(); } } @@ -101,7 +94,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance for (var i = 0; i < InnerLoopCount; i++) { InsertData(_liveaspnentPipelinedRequests); - ParseData(); } } @@ -112,7 +104,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance for (var i = 0; i < InnerLoopCount; i++) { InsertData(_unicodeRequest); - ParseData(); } } @@ -123,34 +114,52 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance for (var i = 0; i < InnerLoopCount; i++) { InsertData(_unicodePipelinedRequests); - ParseData(); } } - private void InsertData(byte[] dataBytes) + private void InsertData(byte[] bytes) { - SocketInput.IncomingData(dataBytes, 0, dataBytes.Length); + // There should not be any backpressure and task completes immediately + Pipe.Writer.WriteAsync(bytes).GetAwaiter().GetResult(); } private void ParseData() { - while (SocketInput.GetAwaiter().IsCompleted) + do { + var awaitable = Pipe.Reader.ReadAsync(); + if (!awaitable.IsCompleted) + { + // No more data + return; + } + + var result = awaitable.GetAwaiter().GetResult(); + var readableBuffer = result.Buffer; + Frame.Reset(); - if (Frame.TakeStartLine(SocketInput) != RequestLineStatus.Done) + ReadCursor consumed; + ReadCursor examined; + if (!Frame.TakeStartLine(readableBuffer, out consumed, out examined)) { ThrowInvalidStartLine(); } + Pipe.Reader.Advance(consumed, examined); + + result = Pipe.Reader.ReadAsync().GetAwaiter().GetResult(); + readableBuffer = result.Buffer; Frame.InitializeHeaders(); - if (!Frame.TakeMessageHeaders(SocketInput, (FrameRequestHeaders) Frame.RequestHeaders)) + if (!Frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)Frame.RequestHeaders, out consumed, out examined)) { ThrowInvalidMessageHeaders(); } + Pipe.Reader.Advance(consumed, examined); } + while(true); } private void ThrowInvalidStartLine() @@ -166,23 +175,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance [Setup] public void Setup() { - Trace = new KestrelTrace(new TestKestrelTrace()); - ThreadPool = new LoggingThreadPool(Trace); - MemoryPool = new MemoryPool(); - SocketInput = new SocketInput(MemoryPool, ThreadPool); - var connectionContext = new MockConnection(new KestrelServerOptions()); - connectionContext.Input = SocketInput; - Frame = new Frame(application: null, context: connectionContext); + PipelineFactory = new PipeFactory(); + Pipe = PipelineFactory.Create(); } - [Cleanup] - public void Cleanup() - { - SocketInput.IncomingFin(); - SocketInput.Dispose(); - MemoryPool.Dispose(); - } + public IPipe Pipe { get; set; } + + public Frame Frame { get; set; } + + public PipeFactory PipelineFactory { get; set; } } } diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Writing.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Writing.cs index fdbc73d6..04c8e414 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Writing.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Performance/Writing.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; using System.Net; using System.Threading; using System.Threading.Tasks; @@ -88,9 +89,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance private TestFrame MakeFrame() { - var ltp = new LoggingThreadPool(Mock.Of()); - var pool = new MemoryPool(); - var socketInput = new SocketInput(pool, ltp); + var socketInput = new PipeFactory().Create(); var serviceContext = new ServiceContext { diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/ConnectionTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/ConnectionTests.cs index 518bb721..cbffa6aa 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/ConnectionTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/ConnectionTests.cs @@ -49,11 +49,15 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Libuv.uv_buf_t ignored; mockLibuv.AllocCallback(socket.InternalGetHandle(), 2048, out ignored); mockLibuv.ReadCallback(socket.InternalGetHandle(), 0, ref ignored); - Assert.False(connection.Input.CheckFinOrThrow()); - }, null); + + var readAwaitable = connection.Input.Reader.ReadAsync(); + + var result = readAwaitable.GetResult(); + Assert.False(result.IsCompleted); + }, (object)null); connection.ConnectionControl.End(ProduceEndType.SocketDisconnect); } } } -} +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs index abbed6df..bf6fc9c7 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs @@ -3,6 +3,7 @@ using System; using System.IO; +using System.IO.Pipelines; using System.Net; using System.Text; using System.Threading; @@ -13,7 +14,6 @@ using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel; using Microsoft.AspNetCore.Server.Kestrel.Internal; using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; -using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Internal; using Moq; @@ -23,11 +23,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { public class FrameTests : IDisposable { - private readonly SocketInput _socketInput; - private readonly MemoryPool _pool; + private readonly IPipe _socketInput; private readonly TestFrame _frame; private readonly ServiceContext _serviceContext; private readonly ConnectionContext _connectionContext; + private PipeFactory _pipelineFactory; + + ReadCursor consumed; + ReadCursor examined; private class TestFrame : Frame { @@ -45,9 +48,8 @@ namespace Microsoft.AspNetCore.Server.KestrelTests public FrameTests() { var trace = new KestrelTrace(new TestKestrelTrace()); - var ltp = new LoggingThreadPool(trace); - _pool = new MemoryPool(); - _socketInput = new SocketInput(_pool, ltp); + _pipelineFactory = new PipeFactory(); + _socketInput = _pipelineFactory.Create(); _serviceContext = new ServiceContext { @@ -73,27 +75,26 @@ namespace Microsoft.AspNetCore.Server.KestrelTests public void Dispose() { - _pool.Dispose(); - _socketInput.Dispose(); + _socketInput.Reader.Complete(); + _socketInput.Writer.Complete(); + _pipelineFactory.Dispose(); } [Fact] - public void CanReadHeaderValueWithoutLeadingWhitespace() + public async Task CanReadHeaderValueWithoutLeadingWhitespace() { _frame.InitializeHeaders(); - var headerArray = Encoding.ASCII.GetBytes("Header:value\r\n\r\n"); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes("Header:value\r\n\r\n")); - var success = _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + var success = _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders) _frame.RequestHeaders, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); Assert.True(success); Assert.Equal(1, _frame.RequestHeaders.Count); Assert.Equal("value", _frame.RequestHeaders["Header"]); - - // Assert TakeMessageHeaders consumed all the input - var scan = _socketInput.ConsumingStart(); - Assert.True(scan.IsEnd); + Assert.Equal(readableBuffer.End, consumed); } [Theory] @@ -107,20 +108,18 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("Header: \t\tvalue\r\n\r\n")] [InlineData("Header: \t\t value\r\n\r\n")] [InlineData("Header: \t \t value\r\n\r\n")] - public void LeadingWhitespaceIsNotIncludedInHeaderValue(string rawHeaders) + public async Task LeadingWhitespaceIsNotIncludedInHeaderValue(string rawHeaders) { - var headerArray = Encoding.ASCII.GetBytes(rawHeaders); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(rawHeaders)); - var success = _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + var success = _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); Assert.True(success); Assert.Equal(1, _frame.RequestHeaders.Count); Assert.Equal("value", _frame.RequestHeaders["Header"]); - - // Assert TakeMessageHeaders consumed all the input - var scan = _socketInput.ConsumingStart(); - Assert.True(scan.IsEnd); + Assert.Equal(readableBuffer.End, consumed); } [Theory] @@ -133,20 +132,18 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("Header: value \t\t\r\n\r\n")] [InlineData("Header: value \t\t \r\n\r\n")] [InlineData("Header: value \t \t \r\n\r\n")] - public void TrailingWhitespaceIsNotIncludedInHeaderValue(string rawHeaders) + public async Task TrailingWhitespaceIsNotIncludedInHeaderValue(string rawHeaders) { - var headerArray = Encoding.ASCII.GetBytes(rawHeaders); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(rawHeaders)); - var success = _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + var success = _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); Assert.True(success); Assert.Equal(1, _frame.RequestHeaders.Count); Assert.Equal("value", _frame.RequestHeaders["Header"]); - - // Assert TakeMessageHeaders consumed all the input - var scan = _socketInput.ConsumingStart(); - Assert.True(scan.IsEnd); + Assert.Equal(readableBuffer.End, consumed); } [Theory] @@ -158,20 +155,18 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("Header: one \ttwo \tthree\r\n\r\n", "one \ttwo \tthree")] [InlineData("Header: one\t two\t three\r\n\r\n", "one\t two\t three")] [InlineData("Header: one \ttwo\t three\r\n\r\n", "one \ttwo\t three")] - public void WhitespaceWithinHeaderValueIsPreserved(string rawHeaders, string expectedValue) + public async Task WhitespaceWithinHeaderValueIsPreserved(string rawHeaders, string expectedValue) { - var headerArray = Encoding.ASCII.GetBytes(rawHeaders); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(rawHeaders)); - var success = _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + var success = _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); Assert.True(success); Assert.Equal(1, _frame.RequestHeaders.Count); Assert.Equal(expectedValue, _frame.RequestHeaders["Header"]); - - // Assert TakeMessageHeaders consumed all the input - var scan = _socketInput.ConsumingStart(); - Assert.True(scan.IsEnd); + Assert.Equal(readableBuffer.End, consumed); } [Theory] @@ -183,27 +178,32 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("Header: line1\r\n\t\tline2\r\n\r\n")] [InlineData("Header: line1\r\n \t\t line2\r\n\r\n")] [InlineData("Header: line1\r\n \t \t line2\r\n\r\n")] - public void TakeMessageHeadersThrowsOnHeaderValueWithLineFolding(string rawHeaders) + public async Task TakeMessageHeadersThrowsOnHeaderValueWithLineFolding(string rawHeaders) { - var headerArray = Encoding.ASCII.GetBytes(rawHeaders); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(rawHeaders)); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); + var exception = Assert.Throws(() => _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined)); Assert.Equal("Header value line folding not supported.", exception.Message); Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); } [Fact] - public void TakeMessageHeadersThrowsOnHeaderValueWithLineFolding_CharacterNotAvailableOnFirstAttempt() + public async Task TakeMessageHeadersThrowsOnHeaderValueWithLineFolding_CharacterNotAvailableOnFirstAttempt() { - var headerArray = Encoding.ASCII.GetBytes("Header-1: value1\r\n"); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes("Header-1: value1\r\n")); - Assert.False(_frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + Assert.False(_frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - _socketInput.IncomingData(Encoding.ASCII.GetBytes(" "), 0, 1); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(" ")); + + readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + var exception = Assert.Throws(() => _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); Assert.Equal("Header value line folding not supported.", exception.Message); Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); } @@ -214,13 +214,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("Header-1: value1\rHeader-2: value2\r\n\r\n")] [InlineData("Header-1: value1\r\nHeader-2: value2\r\r\n")] [InlineData("Header-1: value1\r\nHeader-2: v\ralue2\r\n")] - public void TakeMessageHeadersThrowsOnHeaderValueContainingCR(string rawHeaders) + public async Task TakeMessageHeadersThrowsOnHeaderValueContainingCR(string rawHeaders) { + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(rawHeaders)); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; - var headerArray = Encoding.ASCII.GetBytes(rawHeaders); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + var exception = Assert.Throws(() => _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); Assert.Equal("Header value must not contain CR characters.", exception.Message); Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); } @@ -229,12 +230,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("Header-1 value1\r\n\r\n")] [InlineData("Header-1 value1\r\nHeader-2: value2\r\n\r\n")] [InlineData("Header-1: value1\r\nHeader-2 value2\r\n\r\n")] - public void TakeMessageHeadersThrowsOnHeaderLineMissingColon(string rawHeaders) + public async Task TakeMessageHeadersThrowsOnHeaderLineMissingColon(string rawHeaders) { - var headerArray = Encoding.ASCII.GetBytes(rawHeaders); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(rawHeaders)); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); Assert.Equal("No ':' character found in header line.", exception.Message); Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); } @@ -244,12 +247,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("\tHeader: value\r\n\r\n")] [InlineData(" Header-1: value1\r\nHeader-2: value2\r\n\r\n")] [InlineData("\tHeader-1: value1\r\nHeader-2: value2\r\n\r\n")] - public void TakeMessageHeadersThrowsOnHeaderLineStartingWithWhitespace(string rawHeaders) + public async Task TakeMessageHeadersThrowsOnHeaderLineStartingWithWhitespace(string rawHeaders) { - var headerArray = Encoding.ASCII.GetBytes(rawHeaders); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(rawHeaders)); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); Assert.Equal("Header line must not start with whitespace.", exception.Message); Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); } @@ -263,12 +268,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("Header-1: value1\r\nHeader 2: value2\r\n\r\n")] [InlineData("Header-1: value1\r\nHeader-2 : value2\r\n\r\n")] [InlineData("Header-1: value1\r\nHeader-2\t: value2\r\n\r\n")] - public void TakeMessageHeadersThrowsOnWhitespaceInHeaderName(string rawHeaders) + public async Task TakeMessageHeadersThrowsOnWhitespaceInHeaderName(string rawHeaders) { - var headerArray = Encoding.ASCII.GetBytes(rawHeaders); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(rawHeaders)); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); Assert.Equal("Whitespace is not allowed in header name.", exception.Message); Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); } @@ -277,41 +284,47 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("Header-1: value1\r\nHeader-2: value2\r\n\r\r")] [InlineData("Header-1: value1\r\nHeader-2: value2\r\n\r ")] [InlineData("Header-1: value1\r\nHeader-2: value2\r\n\r \n")] - public void TakeMessageHeadersThrowsOnHeadersNotEndingInCRLFLine(string rawHeaders) + public async Task TakeMessageHeadersThrowsOnHeadersNotEndingInCRLFLine(string rawHeaders) { - var headerArray = Encoding.ASCII.GetBytes(rawHeaders); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(rawHeaders)); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); Assert.Equal("Headers corrupted, invalid header sequence.", exception.Message); Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); } [Fact] - public void TakeMessageHeadersThrowsWhenHeadersExceedTotalSizeLimit() + public async Task TakeMessageHeadersThrowsWhenHeadersExceedTotalSizeLimit() { const string headerLine = "Header: value\r\n"; _serviceContext.ServerOptions.Limits.MaxRequestHeadersTotalSize = headerLine.Length - 1; _frame.Reset(); - var headerArray = Encoding.ASCII.GetBytes($"{headerLine}\r\n"); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine}\r\n")); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); Assert.Equal("Request headers too long.", exception.Message); Assert.Equal(StatusCodes.Status431RequestHeaderFieldsTooLarge, exception.StatusCode); } [Fact] - public void TakeMessageHeadersThrowsWhenHeadersExceedCountLimit() + public async Task TakeMessageHeadersThrowsWhenHeadersExceedCountLimit() { const string headerLines = "Header-1: value1\r\nHeader-2: value2\r\n"; _serviceContext.ServerOptions.Limits.MaxRequestHeaderCount = 1; - var headerArray = Encoding.ASCII.GetBytes($"{headerLines}\r\n"); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes($"{headerLines}\r\n")); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); Assert.Equal("Request contains too many headers.", exception.Message); Assert.Equal(StatusCodes.Status431RequestHeaderFieldsTooLarge, exception.StatusCode); } @@ -323,19 +336,17 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("Cookie:\r\nConnection: close\r\n\r\n", 2)] [InlineData("Connection: close\r\nCookie: \r\n\r\n", 2)] [InlineData("Connection: close\r\nCookie:\r\n\r\n", 2)] - public void EmptyHeaderValuesCanBeParsed(string rawHeaders, int numHeaders) + public async Task EmptyHeaderValuesCanBeParsed(string rawHeaders, int numHeaders) { - var headerArray = Encoding.ASCII.GetBytes(rawHeaders); - _socketInput.IncomingData(headerArray, 0, headerArray.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(rawHeaders)); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; - var success = _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders); + var success = _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); Assert.True(success); Assert.Equal(numHeaders, _frame.RequestHeaders.Count); - - // Assert TakeMessageHeaders consumed all the input - var scan = _socketInput.ConsumingStart(); - Assert.True(scan.IsEnd); + Assert.Equal(readableBuffer.End, consumed); } [Fact] @@ -351,7 +362,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } [Fact] - public void ResetResetsHeaderLimits() + public async Task ResetResetsHeaderLimits() { const string headerLine1 = "Header-1: value1\r\n"; const string headerLine2 = "Header-2: value2\r\n"; @@ -361,19 +372,25 @@ namespace Microsoft.AspNetCore.Server.KestrelTests options.Limits.MaxRequestHeaderCount = 1; _serviceContext.ServerOptions = options; - var headerArray1 = Encoding.ASCII.GetBytes($"{headerLine1}\r\n"); - _socketInput.IncomingData(headerArray1, 0, headerArray1.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine1}\r\n")); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; - Assert.True(_frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); + var takeMessageHeaders = _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); + + Assert.True(takeMessageHeaders); Assert.Equal(1, _frame.RequestHeaders.Count); Assert.Equal("value1", _frame.RequestHeaders["Header-1"]); _frame.Reset(); - var headerArray2 = Encoding.ASCII.GetBytes($"{headerLine2}\r\n"); - _socketInput.IncomingData(headerArray2, 0, headerArray1.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine2}\r\n")); + readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; - Assert.True(_frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); + takeMessageHeaders = _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); + + Assert.True(takeMessageHeaders); Assert.Equal(1, _frame.RequestHeaders.Count); Assert.Equal("value2", _frame.RequestHeaders["Header-2"]); } @@ -462,78 +479,84 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } [Fact] - public void TakeStartLineCallsConsumingCompleteWithFurthestExamined() + public async Task TakeStartLineCallsConsumingCompleteWithFurthestExamined() { var requestLineBytes = Encoding.ASCII.GetBytes("GET / "); - _socketInput.IncomingData(requestLineBytes, 0, requestLineBytes.Length); - _frame.TakeStartLine(_socketInput); - Assert.False(_socketInput.IsCompleted); + await _socketInput.Writer.WriteAsync(requestLineBytes); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + _frame.TakeStartLine(readableBuffer, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); + + Assert.Equal(readableBuffer.Start, consumed); + Assert.Equal(readableBuffer.End, examined); requestLineBytes = Encoding.ASCII.GetBytes("HTTP/1.1\r\n"); - _socketInput.IncomingData(requestLineBytes, 0, requestLineBytes.Length); - _frame.TakeStartLine(_socketInput); - Assert.False(_socketInput.IsCompleted); + await _socketInput.Writer.WriteAsync(requestLineBytes); + readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + _frame.TakeStartLine(readableBuffer, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); + + Assert.Equal(readableBuffer.End, consumed); + Assert.Equal(readableBuffer.End, examined); } [Theory] - [InlineData("", Frame.RequestLineStatus.Empty)] - [InlineData("G", Frame.RequestLineStatus.Incomplete)] - [InlineData("GE", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET ", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET /", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET / ", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET / H", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET / HT", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET / HTT", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET / HTTP", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET / HTTP/", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET / HTTP/1", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET / HTTP/1.", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET / HTTP/1.1", Frame.RequestLineStatus.Incomplete)] - [InlineData("GET / HTTP/1.1\r", Frame.RequestLineStatus.Incomplete)] - public void TakeStartLineReturnsWhenGivenIncompleteRequestLines(string requestLine, Frame.RequestLineStatus expectedReturnValue) + [InlineData("G")] + [InlineData("GE")] + [InlineData("GET")] + [InlineData("GET ")] + [InlineData("GET /")] + [InlineData("GET / ")] + [InlineData("GET / H")] + [InlineData("GET / HT")] + [InlineData("GET / HTT")] + [InlineData("GET / HTTP")] + [InlineData("GET / HTTP/")] + [InlineData("GET / HTTP/1")] + [InlineData("GET / HTTP/1.")] + [InlineData("GET / HTTP/1.1")] + [InlineData("GET / HTTP/1.1\r")] + public async Task TakeStartLineReturnsWhenGivenIncompleteRequestLines(string requestLine) { var requestLineBytes = Encoding.ASCII.GetBytes(requestLine); - _socketInput.IncomingData(requestLineBytes, 0, requestLineBytes.Length); + await _socketInput.Writer.WriteAsync(requestLineBytes); - var returnValue = _frame.TakeStartLine(_socketInput); - Assert.Equal(expectedReturnValue, returnValue); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + var returnValue = _frame.TakeStartLine(readableBuffer, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); + + Assert.False(returnValue); } [Fact] - public void TakeStartLineStartsRequestHeadersTimeoutOnFirstByteAvailable() + public async Task TakeStartLineStartsRequestHeadersTimeoutOnFirstByteAvailable() { var connectionControl = new Mock(); _connectionContext.ConnectionControl = connectionControl.Object; - var requestLineBytes = Encoding.ASCII.GetBytes("G"); - _socketInput.IncomingData(requestLineBytes, 0, requestLineBytes.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes("G")); + + _frame.TakeStartLine((await _socketInput.Reader.ReadAsync()).Buffer, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); - _frame.TakeStartLine(_socketInput); var expectedRequestHeadersTimeout = (long)_serviceContext.ServerOptions.Limits.RequestHeadersTimeout.TotalMilliseconds; connectionControl.Verify(cc => cc.ResetTimeout(expectedRequestHeadersTimeout, TimeoutAction.SendTimeoutResponse)); } [Fact] - public void TakeStartLineDoesNotStartRequestHeadersTimeoutIfNoDataAvailable() - { - var connectionControl = new Mock(); - _connectionContext.ConnectionControl = connectionControl.Object; - - _frame.TakeStartLine(_socketInput); - connectionControl.Verify(cc => cc.ResetTimeout(It.IsAny(), It.IsAny()), Times.Never); - } - - [Fact] - public void TakeStartLineThrowsWhenTooLong() + public async Task TakeStartLineThrowsWhenTooLong() { _serviceContext.ServerOptions.Limits.MaxRequestLineSize = "GET / HTTP/1.1\r\n".Length; var requestLineBytes = Encoding.ASCII.GetBytes("GET /a HTTP/1.1\r\n"); - _socketInput.IncomingData(requestLineBytes, 0, requestLineBytes.Length); + await _socketInput.Writer.WriteAsync(requestLineBytes); + + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeStartLine(_socketInput)); Assert.Equal("Request line too long.", exception.Message); Assert.Equal(StatusCodes.Status414UriTooLong, exception.StatusCode); } @@ -550,55 +573,60 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("GET / HTTP/1.1\n", "Invalid request line: GET / HTTP/1.1<0x0A>")] [InlineData("GET / \r\n", "Invalid request line: GET / <0x0D><0x0A>")] [InlineData("GET / HTTP/1.1\ra\n", "Invalid request line: GET / HTTP/1.1<0x0D>a<0x0A>")] - public void TakeStartLineThrowsWhenInvalid(string requestLine, string expectedExceptionMessage) + public async Task TakeStartLineThrowsWhenInvalid(string requestLine, string expectedExceptionMessage) { - var requestLineBytes = Encoding.ASCII.GetBytes(requestLine); - _socketInput.IncomingData(requestLineBytes, 0, requestLineBytes.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(requestLine)); + + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeStartLine(_socketInput)); Assert.Equal(expectedExceptionMessage, exception.Message); Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); } [Fact] - public void TakeStartLineThrowsOnUnsupportedHttpVersion() + public async Task TakeStartLineThrowsOnUnsupportedHttpVersion() { - var requestLineBytes = Encoding.ASCII.GetBytes("GET / HTTP/1.2\r\n"); - _socketInput.IncomingData(requestLineBytes, 0, requestLineBytes.Length); + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes("GET / HTTP/1.2\r\n")); + + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeStartLine(_socketInput)); Assert.Equal("Unrecognized HTTP version: HTTP/1.2", exception.Message); Assert.Equal(StatusCodes.Status505HttpVersionNotsupported, exception.StatusCode); } [Fact] - public void TakeStartLineThrowsOnUnsupportedHttpVersionLongerThanEightCharacters() + public async Task TakeStartLineThrowsOnUnsupportedHttpVersionLongerThanEightCharacters() { var requestLineBytes = Encoding.ASCII.GetBytes("GET / HTTP/1.1ab\r\n"); - _socketInput.IncomingData(requestLineBytes, 0, requestLineBytes.Length); + await _socketInput.Writer.WriteAsync(requestLineBytes); + + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); - var exception = Assert.Throws(() => _frame.TakeStartLine(_socketInput)); Assert.Equal("Unrecognized HTTP version: HTTP/1.1a...", exception.Message); Assert.Equal(StatusCodes.Status505HttpVersionNotsupported, exception.StatusCode); } [Fact] - public void TakeMessageHeadersCallsConsumingCompleteWithFurthestExamined() + public async Task TakeMessageHeadersCallsConsumingCompleteWithFurthestExamined() { - var headersBytes = Encoding.ASCII.GetBytes("Header: "); - _socketInput.IncomingData(headersBytes, 0, headersBytes.Length); - _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders); - Assert.False(_socketInput.IsCompleted); + foreach (var rawHeader in new [] { "Header: " , "value\r\n" , "\r\n"}) + { + await _socketInput.Writer.WriteAsync(Encoding.ASCII.GetBytes(rawHeader)); - headersBytes = Encoding.ASCII.GetBytes("value\r\n"); - _socketInput.IncomingData(headersBytes, 0, headersBytes.Length); - _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders); - Assert.False(_socketInput.IsCompleted); - - headersBytes = Encoding.ASCII.GetBytes("\r\n"); - _socketInput.IncomingData(headersBytes, 0, headersBytes.Length); - _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders); - Assert.False(_socketInput.IsCompleted); + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined); + _socketInput.Reader.Advance(consumed, examined); + Assert.Equal(readableBuffer.End, examined); + } } [Theory] @@ -619,12 +647,17 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("Header: value\r")] [InlineData("Header: value\r\n")] [InlineData("Header: value\r\n\r")] - public void TakeMessageHeadersReturnsWhenGivenIncompleteHeaders(string headers) + public async Task TakeMessageHeadersReturnsWhenGivenIncompleteHeaders(string headers) { var headerBytes = Encoding.ASCII.GetBytes(headers); - _socketInput.IncomingData(headerBytes, 0, headerBytes.Length); + await _socketInput.Writer.WriteAsync(headerBytes); - Assert.Equal(false, _frame.TakeMessageHeaders(_socketInput, (FrameRequestHeaders)_frame.RequestHeaders)); + ReadCursor consumed; + ReadCursor examined; + var readableBuffer = (await _socketInput.Reader.ReadAsync()).Buffer; + + Assert.Equal(false, _frame.TakeMessageHeaders(readableBuffer, (FrameRequestHeaders)_frame.RequestHeaders, out consumed, out examined)); + _socketInput.Reader.Advance(consumed, examined); } [Fact] @@ -639,7 +672,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests connectionControl.Verify(cc => cc.SetTimeout(expectedKeepAliveTimeout, TimeoutAction.CloseConnection)); _frame.StopAsync(); - _socketInput.IncomingFin(); + _socketInput.Writer.Complete(); requestProcessingTask.Wait(); } @@ -721,13 +754,13 @@ namespace Microsoft.AspNetCore.Server.KestrelTests _frame.Start(); var data = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\n\r\n"); - _socketInput.IncomingData(data, 0, data.Length); + await _socketInput.Writer.WriteAsync(data); var requestProcessingTask = _frame.StopAsync(); Assert.IsNotType(typeof(Task), requestProcessingTask); await requestProcessingTask.TimeoutAfter(TimeSpan.FromSeconds(10)); - _socketInput.IncomingFin(); + _socketInput.Writer.Complete(); } [Fact] diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/ListenerPrimaryTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/ListenerPrimaryTests.cs index 43e7b78e..137eaee9 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/ListenerPrimaryTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/ListenerPrimaryTests.cs @@ -182,7 +182,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } }, null); - }, null); + }, (object)null); await connectTcs.Task; @@ -191,7 +191,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); Assert.Equal("Secondary", await HttpClientSlim.GetStringAsync(address)); - await kestrelThreadPrimary.PostAsync(_ => pipe.Dispose(), null); + await kestrelThreadPrimary.PostAsync(_ => pipe.Dispose(), (object)null); // Wait up to 10 seconds for error to be logged for (var i = 0; i < 10 && primaryTrace.Logger.TotalErrorsLogged == 0; i++) diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs index 1239cfa9..78b7e1fc 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs @@ -3,11 +3,14 @@ using System; using System.Collections.Generic; +using System.IO.Pipelines; using System.Linq; using System.Numerics; using System.Text; using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Xunit; +using MemoryPool = Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure.MemoryPool; +using MemoryPoolBlock = Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure.MemoryPoolBlock; namespace Microsoft.AspNetCore.Server.KestrelTests { @@ -68,7 +71,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests public void MemorySeek(string raw, string search, char expectResult, int expectIndex) { var block = _pool.Lease(); - var chars = raw.ToCharArray().Select(c => (byte) c).ToArray(); + var chars = raw.ToCharArray().Select(c => (byte)c).ToArray(); Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length); block.End += chars.Length; @@ -150,7 +153,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests head = blocks[0].GetIterator(); for (var i = 0; i < 64; ++i) { - Assert.True(head.Put((byte) i), $"Fail to put data at {i}."); + Assert.True(head.Put((byte)i), $"Fail to put data at {i}."); } // Can't put anything by the end @@ -167,7 +170,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { // Arrange var block = _pool.Lease(); - var bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7}; + var bytes = new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 }; Buffer.BlockCopy(bytes, 0, block.Array, block.Start, bytes.Length); block.End += bytes.Length; var scan = block.GetIterator(); @@ -177,7 +180,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests var result = scan.PeekArraySegment(); // Assert - Assert.Equal(new byte[] {0, 1, 2, 3, 4, 5, 6, 7}, result); + Assert.Equal(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 }, result); Assert.Equal(originalIndex, scan.Index); _pool.Return(block); @@ -195,7 +198,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { // Arrange var block = _pool.Lease(); - var bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7}; + var bytes = new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 }; Buffer.BlockCopy(bytes, 0, block.Array, block.Start, bytes.Length); block.End += bytes.Length; block.Start = block.End; @@ -559,50 +562,63 @@ namespace Microsoft.AspNetCore.Server.KestrelTests public void GetsKnownMethod(string input, bool expectedResult, string expectedKnownString) { // Arrange - var block = _pool.Lease(); - var chars = input.ToCharArray().Select(c => (byte)c).ToArray(); - Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length); - block.End += chars.Length; - var begin = block.GetIterator(); - string knownString; + var block = ReadableBuffer.Create(Encoding.ASCII.GetBytes(input)); // Act - var result = begin.GetKnownMethod(out knownString); - + string knownString; + var result = block.GetKnownMethod(out knownString); // Assert Assert.Equal(expectedResult, result); Assert.Equal(expectedKnownString, knownString); + } + + [Theory] + [InlineData("CONNECT / HTTP/1.1", true, "CONNECT")] + [InlineData("DELETE / HTTP/1.1", true, "DELETE")] + [InlineData("GET / HTTP/1.1", true, "GET")] + [InlineData("HEAD / HTTP/1.1", true, "HEAD")] + [InlineData("PATCH / HTTP/1.1", true, "PATCH")] + [InlineData("POST / HTTP/1.1", true, "POST")] + [InlineData("PUT / HTTP/1.1", true, "PUT")] + [InlineData("OPTIONS / HTTP/1.1", true, "OPTIONS")] + [InlineData("TRACE / HTTP/1.1", true, "TRACE")] + [InlineData("GET/ HTTP/1.1", false, null)] + [InlineData("get / HTTP/1.1", false, null)] + [InlineData("GOT / HTTP/1.1", false, null)] + [InlineData("ABC / HTTP/1.1", false, null)] + [InlineData("PO / HTTP/1.1", false, null)] + [InlineData("PO ST / HTTP/1.1", false, null)] + [InlineData("short ", false, null)] + public void GetsKnownMethodOnBoundary(string input, bool expectedResult, string expectedKnownString) + { // Test at boundary var maxSplit = Math.Min(input.Length, 8); - var nextBlock = _pool.Lease(); for (var split = 0; split <= maxSplit; split++) { - // Arrange - block.Reset(); - nextBlock.Reset(); + using (var pipelineFactory = new PipeFactory()) + { + // Arrange + var pipe = pipelineFactory.Create(); + var buffer = pipe.Writer.Alloc(); + var block1Input = input.Substring(0, split); + var block2Input = input.Substring(split); + buffer.Append(ReadableBuffer.Create(Encoding.ASCII.GetBytes(block1Input))); + buffer.Append(ReadableBuffer.Create(Encoding.ASCII.GetBytes(block2Input))); + buffer.FlushAsync().GetAwaiter().GetResult(); - Buffer.BlockCopy(chars, 0, block.Array, block.Start, split); - Buffer.BlockCopy(chars, split, nextBlock.Array, nextBlock.Start, chars.Length - split); + var readResult = pipe.Reader.ReadAsync().GetAwaiter().GetResult(); - block.End += split; - nextBlock.End += chars.Length - split; - block.Next = nextBlock; + // Act + string boundaryKnownString; + var boundaryResult = readResult.Buffer.GetKnownMethod(out boundaryKnownString); - var boundaryBegin = block.GetIterator(); - string boundaryKnownString; - - // Act - var boundaryResult = boundaryBegin.GetKnownMethod(out boundaryKnownString); - - // Assert - Assert.Equal(expectedResult, boundaryResult); - Assert.Equal(expectedKnownString, boundaryKnownString); + // Assert + Assert.Equal(expectedResult, boundaryResult); + Assert.Equal(expectedKnownString, boundaryKnownString); + } } - - _pool.Return(block); - _pool.Return(nextBlock); } [Theory] @@ -615,49 +631,52 @@ namespace Microsoft.AspNetCore.Server.KestrelTests public void GetsKnownVersion(string input, bool expectedResult, string expectedKnownString) { // Arrange - var block = _pool.Lease(); - var chars = input.ToCharArray().Select(c => (byte)c).ToArray(); - Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length); - block.End += chars.Length; - var begin = block.GetIterator(); - string knownString; + var block = ReadableBuffer.Create(Encoding.ASCII.GetBytes(input)); // Act - var result = begin.GetKnownVersion(out knownString); + string knownString; + var result = block.GetKnownVersion(out knownString); // Assert Assert.Equal(expectedResult, result); Assert.Equal(expectedKnownString, knownString); + } + [Theory] + [InlineData("HTTP/1.0\r", true, MemoryPoolIteratorExtensions.Http10Version)] + [InlineData("HTTP/1.1\r", true, MemoryPoolIteratorExtensions.Http11Version)] + [InlineData("HTTP/3.0\r", false, null)] + [InlineData("http/1.0\r", false, null)] + [InlineData("http/1.1\r", false, null)] + [InlineData("short ", false, null)] + public void GetsKnownVersionOnBoundary(string input, bool expectedResult, string expectedKnownString) + { // Test at boundary var maxSplit = Math.Min(input.Length, 9); - var nextBlock = _pool.Lease(); for (var split = 0; split <= maxSplit; split++) { - // Arrange - block.Reset(); - nextBlock.Reset(); + using (var pipelineFactory = new PipeFactory()) + { + // Arrange + var pipe = pipelineFactory.Create(); + var buffer = pipe.Writer.Alloc(); + var block1Input = input.Substring(0, split); + var block2Input = input.Substring(split); + buffer.Append(ReadableBuffer.Create(Encoding.ASCII.GetBytes(block1Input))); + buffer.Append(ReadableBuffer.Create(Encoding.ASCII.GetBytes(block2Input))); + buffer.FlushAsync().GetAwaiter().GetResult(); - Buffer.BlockCopy(chars, 0, block.Array, block.Start, split); - Buffer.BlockCopy(chars, split, nextBlock.Array, nextBlock.Start, chars.Length - split); + var readResult = pipe.Reader.ReadAsync().GetAwaiter().GetResult(); - block.End += split; - nextBlock.End += chars.Length - split; - block.Next = nextBlock; + // Act + string boundaryKnownString; + var boundaryResult = readResult.Buffer.GetKnownVersion(out boundaryKnownString); - var boundaryBegin = block.GetIterator(); - string boundaryKnownString; - - // Act - var boundaryResult = boundaryBegin.GetKnownVersion(out boundaryKnownString); - - // Assert - Assert.Equal(expectedResult, boundaryResult); - Assert.Equal(expectedKnownString, boundaryKnownString); + // Assert + Assert.Equal(expectedResult, boundaryResult); + Assert.Equal(expectedKnownString, boundaryKnownString); + } } - - _pool.Return(block); - _pool.Return(nextBlock); } [Theory] @@ -681,37 +700,24 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData("HTTP/1.1\r", "")] public void KnownVersionCanBeReadAtAnyBlockBoundary(string block1Input, string block2Input) { - MemoryPoolBlock block1 = null; - MemoryPoolBlock block2 = null; - - try + using (var pipelineFactory = new PipeFactory()) { // Arrange - var chars1 = block1Input.ToCharArray().Select(c => (byte)c).ToArray(); - var chars2 = block2Input.ToCharArray().Select(c => (byte)c).ToArray(); - block1 = _pool.Lease(); - block2 = _pool.Lease(); - Buffer.BlockCopy(chars1, 0, block1.Array, block1.Start, chars1.Length); - Buffer.BlockCopy(chars2, 0, block2.Array, block2.Start, chars2.Length); - block1.End += chars1.Length; - block2.End += chars2.Length; - block1.Next = block2; - var iterator = block1.GetIterator(); + var pipe = pipelineFactory.Create(); + var buffer = pipe.Writer.Alloc(); + buffer.Append(ReadableBuffer.Create(Encoding.ASCII.GetBytes(block1Input))); + buffer.Append(ReadableBuffer.Create(Encoding.ASCII.GetBytes(block2Input))); + buffer.FlushAsync().GetAwaiter().GetResult(); + var readResult = pipe.Reader.ReadAsync().GetAwaiter().GetResult(); // Act string knownVersion; - var result = iterator.GetKnownVersion(out knownVersion); + var result = readResult.Buffer.GetKnownVersion(out knownVersion); // Assert Assert.True(result); Assert.Equal("HTTP/1.1", knownVersion); } - finally - { - // Cleanup - if (block1 != null) _pool.Return(block1); - if (block2 != null) _pool.Return(block2); - } } [Theory] @@ -740,7 +746,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests // Arrange block = _pool.Lease(); - var chars = input.ToString().ToCharArray().Select(c => (byte)c).ToArray(); + var chars = input.ToString().ToCharArray().Select(c => (byte) c).ToArray(); Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length); block.End += chars.Length; var scan = block.GetIterator(); @@ -974,7 +980,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [Fact] public void EmptyIteratorBehaviourIsValid() { - const byte byteCr = (byte) '\n'; + const byte byteCr = (byte)'\n'; ulong longValue; var end = default(MemoryPoolIterator); @@ -1194,16 +1200,10 @@ namespace Microsoft.AspNetCore.Server.KestrelTests try { // Arrange - block = _pool.Lease(); - var chars = input.ToCharArray().Select(c => (byte)c).ToArray(); - Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length); - block.End += chars.Length; - var start = block.GetIterator(); - var end = start; - end.Skip(input.Length); + var buffer = ReadableBuffer.Create(Encoding.ASCII.GetBytes(input)); // Act - var result = start.GetAsciiStringEscaped(end, maxChars); + var result = buffer.Start.GetAsciiStringEscaped(buffer.End, maxChars); // Assert Assert.Equal(expected, result); @@ -1294,28 +1294,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } - private delegate bool GetKnownString(MemoryPoolIterator iter, out string result); + private delegate bool GetKnownString(ReadableBuffer iter, out string result); private void TestKnownStringsInterning(string input, string expected, GetKnownString action) { - // Arrange - var chars = input.ToCharArray().Select(c => (byte)c).ToArray(); - var block1 = _pool.Lease(); - var block2 = _pool.Lease(); - Buffer.BlockCopy(chars, 0, block1.Array, block1.Start, chars.Length); - Buffer.BlockCopy(chars, 0, block2.Array, block2.Start, chars.Length); - block1.End += chars.Length; - block2.End += chars.Length; - var begin1 = block1.GetIterator(); - var begin2 = block2.GetIterator(); - // Act string knownString1, knownString2; - var result1 = action(begin1, out knownString1); - var result2 = action(begin2, out knownString2); - - _pool.Return(block1); - _pool.Return(block2); + var result1 = action(ReadableBuffer.Create(Encoding.ASCII.GetBytes(input)), out knownString1); + var result2 = action(ReadableBuffer.Create(Encoding.ASCII.GetBytes(input)), out knownString2); // Assert Assert.True(result1); diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/MessageBodyTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/MessageBodyTests.cs index 3e11350d..0cce8696 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/MessageBodyTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/MessageBodyTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.IO.Pipelines; using System.Linq; using System.Text; using System.Threading; @@ -282,24 +283,29 @@ namespace Microsoft.AspNetCore.Server.KestrelTests // so no need to bounds check in this test. var socketInput = input.FrameContext.Input; var bytes = Encoding.ASCII.GetBytes(data[0]); - var block = socketInput.IncomingStart(); - Buffer.BlockCopy(bytes, 0, block.Array, block.End, bytes.Length); - socketInput.IncomingComplete(bytes.Length, null); + var buffer = socketInput.Writer.Alloc(2048); + ArraySegment block; + Assert.True(buffer.Memory.TryGetArray(out block)); + Buffer.BlockCopy(bytes, 0, block.Array, block.Offset, bytes.Length); + buffer.Advance(bytes.Length); + await buffer.FlushAsync(); // Verify the block passed to WriteAsync is the same one incoming data was written into. Assert.Same(block.Array, await writeTcs.Task); writeTcs = new TaskCompletionSource(); bytes = Encoding.ASCII.GetBytes(data[1]); - block = socketInput.IncomingStart(); - Buffer.BlockCopy(bytes, 0, block.Array, block.End, bytes.Length); - socketInput.IncomingComplete(bytes.Length, null); + buffer = socketInput.Writer.Alloc(2048); + Assert.True(buffer.Memory.TryGetArray(out block)); + Buffer.BlockCopy(bytes, 0, block.Array, block.Offset, bytes.Length); + buffer.Advance(bytes.Length); + await buffer.FlushAsync(); Assert.Same(block.Array, await writeTcs.Task); if (headers.HeaderConnection == "close") { - socketInput.IncomingFin(); + socketInput.Writer.Complete(); } await copyToAsyncTask; diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/Microsoft.AspNetCore.Server.KestrelTests.csproj b/test/Microsoft.AspNetCore.Server.KestrelTests/Microsoft.AspNetCore.Server.KestrelTests.csproj index 322597f3..a2ceb96d 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/Microsoft.AspNetCore.Server.KestrelTests.csproj +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/Microsoft.AspNetCore.Server.KestrelTests.csproj @@ -3,7 +3,7 @@ - netcoreapp1.1;net452 + netcoreapp1.1 netcoreapp1.1 true diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs deleted file mode 100644 index a751bbe3..00000000 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Server.Kestrel.Internal; -using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; -using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; -using Microsoft.AspNetCore.Server.KestrelTests.TestHelpers; -using Microsoft.AspNetCore.Testing; -using Moq; -using Xunit; - -namespace Microsoft.AspNetCore.Server.KestrelTests -{ - public class SocketInputTests - { - public static readonly TheoryData> MockBufferSizeControlData = - new TheoryData>() { new Mock(), null }; - - [Theory] - [MemberData(nameof(MockBufferSizeControlData))] - public void IncomingDataCallsBufferSizeControlAdd(Mock mockBufferSizeControl) - { - using (var memory = new MemoryPool()) - using (var socketInput = new SocketInput(memory, null, mockBufferSizeControl?.Object)) - { - socketInput.IncomingData(new byte[5], 0, 5); - mockBufferSizeControl?.Verify(b => b.Add(5)); - } - } - - [Theory] - [MemberData(nameof(MockBufferSizeControlData))] - public void IncomingCompleteCallsBufferSizeControlAdd(Mock mockBufferSizeControl) - { - using (var memory = new MemoryPool()) - using (var socketInput = new SocketInput(memory, null, mockBufferSizeControl?.Object)) - { - socketInput.IncomingComplete(5, null); - mockBufferSizeControl?.Verify(b => b.Add(5)); - } - } - - [Theory] - [MemberData(nameof(MockBufferSizeControlData))] - public void ConsumingCompleteCallsBufferSizeControlSubtract(Mock mockBufferSizeControl) - { - using (var kestrelEngine = new KestrelEngine(new MockLibuv(), new TestServiceContext())) - { - kestrelEngine.Start(1); - - using (var memory = new MemoryPool()) - using (var socketInput = new SocketInput(memory, null, mockBufferSizeControl?.Object)) - { - socketInput.IncomingData(new byte[20], 0, 20); - - var iterator = socketInput.ConsumingStart(); - iterator.Skip(5); - socketInput.ConsumingComplete(iterator, iterator); - mockBufferSizeControl?.Verify(b => b.Subtract(5)); - } - } - } - - [Fact] - public async Task ConcurrentReadsFailGracefully() - { - // Arrange - var trace = new KestrelTrace(new TestKestrelTrace()); - var ltp = new LoggingThreadPool(trace); - using (var memory2 = new MemoryPool()) - using (var socketInput = new SocketInput(memory2, ltp)) - { - var task0Threw = false; - var task1Threw = false; - var task2Threw = false; - - var task0 = AwaitAsTaskAsync(socketInput); - - Assert.False(task0.IsFaulted); - - var task = task0.ContinueWith( - (t) => - { - TestConcurrentFaultedTask(t); - task0Threw = true; - }, - TaskContinuationOptions.OnlyOnFaulted); - - Assert.False(task0.IsFaulted); - - // Awaiting/continuing two tasks faults both - - var task1 = AwaitAsTaskAsync(socketInput); - - await task1.ContinueWith( - (t) => - { - TestConcurrentFaultedTask(t); - task1Threw = true; - }, - TaskContinuationOptions.OnlyOnFaulted); - - await task; - - Assert.True(task0.IsFaulted); - Assert.True(task1.IsFaulted); - - Assert.True(task0Threw); - Assert.True(task1Threw); - - // socket stays faulted - - var task2 = AwaitAsTaskAsync(socketInput); - - await task2.ContinueWith( - (t) => - { - TestConcurrentFaultedTask(t); - task2Threw = true; - }, - TaskContinuationOptions.OnlyOnFaulted); - - Assert.True(task2.IsFaulted); - Assert.True(task2Threw); - } - } - - [Fact] - public void ConsumingOutOfOrderFailsGracefully() - { - var defultIter = new MemoryPoolIterator(); - - // Calling ConsumingComplete without a preceding calling to ConsumingStart fails - using (var socketInput = new SocketInput(null, null)) - { - Assert.Throws(() => socketInput.ConsumingComplete(defultIter, defultIter)); - } - - // Calling ConsumingComplete twice in a row fails - using (var socketInput = new SocketInput(null, null)) - { - socketInput.ConsumingStart(); - socketInput.ConsumingComplete(defultIter, defultIter); - Assert.Throws(() => socketInput.ConsumingComplete(defultIter, defultIter)); - } - - // Calling ConsumingStart twice in a row fails - using (var socketInput = new SocketInput(null, null)) - { - socketInput.ConsumingStart(); - Assert.Throws(() => socketInput.ConsumingStart()); - } - } - - [Fact] - public async Task PeekAsyncRereturnsTheSameData() - { - using (var memory = new MemoryPool()) - using (var socketInput = new SocketInput(memory, new SynchronousThreadPool())) - { - socketInput.IncomingData(new byte[5], 0, 5); - - Assert.True(socketInput.IsCompleted); - Assert.Equal(5, (await socketInput.PeekAsync()).Count); - - // The same 5 bytes will be returned again since it hasn't been consumed. - Assert.True(socketInput.IsCompleted); - Assert.Equal(5, (await socketInput.PeekAsync()).Count); - - var scan = socketInput.ConsumingStart(); - scan.Skip(3); - socketInput.ConsumingComplete(scan, scan); - - // The remaining 2 unconsumed bytes will be returned. - Assert.True(socketInput.IsCompleted); - Assert.Equal(2, (await socketInput.PeekAsync()).Count); - - scan = socketInput.ConsumingStart(); - scan.Skip(2); - socketInput.ConsumingComplete(scan, scan); - - // Everything has been consume so socketInput is no longer in the completed state - Assert.False(socketInput.IsCompleted); - } - } - - [Fact] - public async Task CompleteAwaitingDoesNotCauseZeroLengthRead() - { - using (var memory = new MemoryPool()) - using (var socketInput = new SocketInput(memory, new SynchronousThreadPool())) - { - var readBuffer = new byte[20]; - - socketInput.IncomingData(new byte[5], 0, 5); - Assert.Equal(5, await socketInput.ReadAsync(readBuffer, 0, 20)); - - var readTask = socketInput.ReadAsync(readBuffer, 0, 20); - socketInput.CompleteAwaiting(); - Assert.False(readTask.IsCompleted); - - socketInput.IncomingData(new byte[5], 0, 5); - Assert.Equal(5, await readTask); - } - } - - [Fact] - public async Task CompleteAwaitingDoesNotCauseZeroLengthPeek() - { - using (var memory = new MemoryPool()) - using (var socketInput = new SocketInput(memory, new SynchronousThreadPool())) - { - socketInput.IncomingData(new byte[5], 0, 5); - Assert.Equal(5, (await socketInput.PeekAsync()).Count); - - var scan = socketInput.ConsumingStart(); - scan.Skip(5); - socketInput.ConsumingComplete(scan, scan); - - var peekTask = socketInput.PeekAsync(); - socketInput.CompleteAwaiting(); - Assert.False(peekTask.IsCompleted); - - socketInput.IncomingData(new byte[5], 0, 5); - Assert.Equal(5, (await socketInput.PeekAsync()).Count); - } - } - - private static void TestConcurrentFaultedTask(Task t) - { - Assert.True(t.IsFaulted); - Assert.IsType(typeof(System.InvalidOperationException), t.Exception.InnerException); - Assert.Equal(t.Exception.InnerException.Message, "Concurrent reads are not supported."); - } - - private async Task AwaitAsTaskAsync(SocketInput socketInput) - { - await socketInput; - } - } -} diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/TestInput.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/TestInput.cs index 5a798e80..ff0d2d94 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/TestInput.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/TestInput.cs @@ -2,6 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; +using System.Text; using System.Net; using System.Threading; using System.Threading.Tasks; @@ -11,12 +13,14 @@ using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Internal; +using MemoryPool = Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure.MemoryPool; namespace Microsoft.AspNetCore.Server.KestrelTests { class TestInput : IConnectionControl, IFrameControl, IDisposable { private MemoryPool _memoryPool; + private PipeFactory _pipelineFactory; public TestInput() { @@ -41,18 +45,19 @@ namespace Microsoft.AspNetCore.Server.KestrelTests FrameContext.ConnectionContext.ListenerContext.ServiceContext.Log = trace; _memoryPool = new MemoryPool(); - FrameContext.Input = new SocketInput(_memoryPool, ltp); + _pipelineFactory = new PipeFactory(); + FrameContext.Input = _pipelineFactory.Create();; } public Frame FrameContext { get; set; } public void Add(string text, bool fin = false) { - var data = System.Text.Encoding.ASCII.GetBytes(text); - FrameContext.Input.IncomingData(data, 0, data.Length); + var data = Encoding.ASCII.GetBytes(text); + FrameContext.Input.Writer.WriteAsync(data).Wait(); if (fin) { - FrameContext.Input.IncomingFin(); + FrameContext.Input.Writer.Complete(); } } @@ -116,7 +121,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests public void Dispose() { - FrameContext.Input.Dispose(); + _pipelineFactory.Dispose(); _memoryPool.Dispose(); } } diff --git a/test/shared/SocketInputExtensions.cs b/test/shared/SocketInputExtensions.cs deleted file mode 100644 index d6dbbb7e..00000000 --- a/test/shared/SocketInputExtensions.cs +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; - -namespace Microsoft.AspNetCore.Testing -{ - public static class SocketInputExtensions - { - public static void IncomingData(this SocketInput input, byte[] buffer, int offset, int count) - { - var bufferIndex = offset; - var remaining = count; - - while (remaining > 0) - { - var block = input.IncomingStart(); - - var bytesLeftInBlock = block.Data.Offset + block.Data.Count - block.End; - var bytesToCopy = remaining < bytesLeftInBlock ? remaining : bytesLeftInBlock; - - Buffer.BlockCopy(buffer, bufferIndex, block.Array, block.End, bytesToCopy); - - bufferIndex += bytesToCopy; - remaining -= bytesToCopy; - - input.IncomingComplete(bytesToCopy, null); - } - } - - public static void IncomingFin(this SocketInput input) - { - input.IncomingComplete(0, null); - } - } -}