From f0b44ac7b51ffd53051aeef3846c1cdaf07e9bc6 Mon Sep 17 00:00:00 2001 From: Pavel Krymets Date: Thu, 27 Oct 2016 14:12:59 -0700 Subject: [PATCH] Check Accept-Encoding headers before creating compression provider (#154) --- .../BodyWrapperStream.cs | 115 +++++++--- .../IResponseCompressionProvider.cs | 7 + .../Properties/AssemblyInfo.cs | 3 + .../ResponseCompressionMiddleware.cs | 20 +- .../ResponseCompressionProvider.cs | 18 ++ .../BodyWrapperStreamTests.cs | 211 ++++++++++++++++++ .../project.json | 2 + 7 files changed, 325 insertions(+), 51 deletions(-) create mode 100644 test/Microsoft.AspNetCore.ResponseCompression.Tests/BodyWrapperStreamTests.cs diff --git a/src/Microsoft.AspNetCore.ResponseCompression/BodyWrapperStream.cs b/src/Microsoft.AspNetCore.ResponseCompression/BodyWrapperStream.cs index 1b82521..93197fa 100644 --- a/src/Microsoft.AspNetCore.ResponseCompression/BodyWrapperStream.cs +++ b/src/Microsoft.AspNetCore.ResponseCompression/BodyWrapperStream.cs @@ -17,23 +17,24 @@ namespace Microsoft.AspNetCore.ResponseCompression /// internal class BodyWrapperStream : Stream, IHttpBufferingFeature, IHttpSendFileFeature { - private readonly HttpResponse _response; + private readonly HttpContext _context; private readonly Stream _bodyOriginalStream; private readonly IResponseCompressionProvider _provider; - private readonly ICompressionProvider _compressionProvider; private readonly IHttpBufferingFeature _innerBufferFeature; private readonly IHttpSendFileFeature _innerSendFileFeature; + private ICompressionProvider _compressionProvider = null; private bool _compressionChecked = false; private Stream _compressionStream = null; + private bool _providerCreated = false; + private bool _autoFlush = false; - internal BodyWrapperStream(HttpResponse response, Stream bodyOriginalStream, IResponseCompressionProvider provider, ICompressionProvider compressionProvider, + internal BodyWrapperStream(HttpContext context, Stream bodyOriginalStream, IResponseCompressionProvider provider, IHttpBufferingFeature innerBufferFeature, IHttpSendFileFeature innerSendFileFeature) { - _response = response; + _context = context; _bodyOriginalStream = bodyOriginalStream; _provider = provider; - _compressionProvider = compressionProvider; _innerBufferFeature = innerBufferFeature; _innerSendFileFeature = innerSendFileFeature; } @@ -125,6 +126,10 @@ namespace Microsoft.AspNetCore.ResponseCompression if (_compressionStream != null) { _compressionStream.Write(buffer, offset, count); + if (_autoFlush) + { + _compressionStream.Flush(); + } } else { @@ -133,44 +138,70 @@ namespace Microsoft.AspNetCore.ResponseCompression } #if NET451 - public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, Object state) { - OnWrite(); + var tcs = new TaskCompletionSource(state); + InternalWriteAsync(buffer, offset, count, callback, tcs); + return tcs.Task; + } - if (_compressionStream != null) + private async void InternalWriteAsync(byte[] buffer, int offset, int count, AsyncCallback callback, TaskCompletionSource tcs) + { + try { - return _compressionStream.BeginWrite(buffer, offset, count, callback, state); + await WriteAsync(buffer, offset, count); + tcs.TrySetResult(null); + } + catch (Exception ex) + { + tcs.TrySetException(ex); + } + + if (callback != null) + { + // Offload callbacks to avoid stack dives on sync completions. + var ignored = Task.Run(() => + { + try + { + callback(tcs.Task); + } + catch (Exception) + { + // Suppress exceptions on background threads. + } + }); } - return _bodyOriginalStream.BeginWrite(buffer, offset, count, callback, state); } public override void EndWrite(IAsyncResult asyncResult) { - if (!_compressionChecked) + if (asyncResult == null) { - throw new InvalidOperationException("BeginWrite was not called before EndWrite"); + throw new ArgumentNullException(nameof(asyncResult)); } - if (_compressionStream != null) - { - _compressionStream.EndWrite(asyncResult); - } - else - { - _bodyOriginalStream.EndWrite(asyncResult); - } + var task = (Task)asyncResult; + task.GetAwaiter().GetResult(); } #endif - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { OnWrite(); if (_compressionStream != null) { - return _compressionStream.WriteAsync(buffer, offset, count, cancellationToken); + await _compressionStream.WriteAsync(buffer, offset, count, cancellationToken); + if (_autoFlush) + { + await _compressionStream.FlushAsync(cancellationToken); + } + } + else + { + await _bodyOriginalStream.WriteAsync(buffer, offset, count, cancellationToken); } - return _bodyOriginalStream.WriteAsync(buffer, offset, count, cancellationToken); } private void OnWrite() @@ -178,22 +209,30 @@ namespace Microsoft.AspNetCore.ResponseCompression if (!_compressionChecked) { _compressionChecked = true; - - if (IsCompressable()) + if (_provider.ShouldCompressResponse(_context)) { - _response.Headers.Append(HeaderNames.ContentEncoding, _compressionProvider.EncodingName); - _response.Headers.Remove(HeaderNames.ContentMD5); // Reset the MD5 because the content changed. - _response.Headers.Remove(HeaderNames.ContentLength); + var compressionProvider = ResolveCompressionProvider(); + if (compressionProvider != null) + { + _context.Response.Headers.Append(HeaderNames.ContentEncoding, compressionProvider.EncodingName); + _context.Response.Headers.Remove(HeaderNames.ContentMD5); // Reset the MD5 because the content changed. + _context.Response.Headers.Remove(HeaderNames.ContentLength); - _compressionStream = _compressionProvider.CreateStream(_bodyOriginalStream); + _compressionStream = compressionProvider.CreateStream(_bodyOriginalStream); + } } } } - private bool IsCompressable() + private ICompressionProvider ResolveCompressionProvider() { - return !_response.Headers.ContainsKey(HeaderNames.ContentRange) && // The response is not partial - _provider.ShouldCompressResponse(_response.HttpContext); + if (!_providerCreated) + { + _providerCreated = true; + _compressionProvider = _provider.GetCompressionProvider(_context); + } + + return _compressionProvider; } public void DisableRequestBuffering() @@ -205,13 +244,16 @@ namespace Microsoft.AspNetCore.ResponseCompression // For this to be effective it needs to be called before the first write. public void DisableResponseBuffering() { - if (!_compressionProvider.SupportsFlush) + if (ResolveCompressionProvider()?.SupportsFlush == false) { // Don't compress, some of the providers don't implement Flush (e.g. .NET 4.5.1 GZip/Deflate stream) // which would block real-time responses like SignalR. _compressionChecked = true; } - + else + { + _autoFlush = true; + } _innerBufferFeature?.DisableResponseBuffering(); } @@ -257,6 +299,11 @@ namespace Microsoft.AspNetCore.ResponseCompression { fileStream.Seek(offset, SeekOrigin.Begin); await StreamCopyOperation.CopyToAsync(fileStream, _compressionStream, count, cancellation); + + if (_autoFlush) + { + await _compressionStream.FlushAsync(cancellation); + } } } } diff --git a/src/Microsoft.AspNetCore.ResponseCompression/IResponseCompressionProvider.cs b/src/Microsoft.AspNetCore.ResponseCompression/IResponseCompressionProvider.cs index 8b118a2..c206acb 100644 --- a/src/Microsoft.AspNetCore.ResponseCompression/IResponseCompressionProvider.cs +++ b/src/Microsoft.AspNetCore.ResponseCompression/IResponseCompressionProvider.cs @@ -23,5 +23,12 @@ namespace Microsoft.AspNetCore.ResponseCompression /// /// bool ShouldCompressResponse(HttpContext context); + + /// + /// Examines the request to see if compression should be used for response. + /// + /// + /// + bool CheckRequestAcceptsCompression(HttpContext context); } } diff --git a/src/Microsoft.AspNetCore.ResponseCompression/Properties/AssemblyInfo.cs b/src/Microsoft.AspNetCore.ResponseCompression/Properties/AssemblyInfo.cs index 2dc4003..58848c0 100644 --- a/src/Microsoft.AspNetCore.ResponseCompression/Properties/AssemblyInfo.cs +++ b/src/Microsoft.AspNetCore.ResponseCompression/Properties/AssemblyInfo.cs @@ -3,6 +3,9 @@ using System.Reflection; using System.Resources; +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.ResponseCompression.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] [assembly: AssemblyMetadata("Serviceable", "True")] [assembly: NeutralResourcesLanguage("en-us")] diff --git a/src/Microsoft.AspNetCore.ResponseCompression/ResponseCompressionMiddleware.cs b/src/Microsoft.AspNetCore.ResponseCompression/ResponseCompressionMiddleware.cs index b684145..7fac933 100644 --- a/src/Microsoft.AspNetCore.ResponseCompression/ResponseCompressionMiddleware.cs +++ b/src/Microsoft.AspNetCore.ResponseCompression/ResponseCompressionMiddleware.cs @@ -18,15 +18,13 @@ namespace Microsoft.AspNetCore.ResponseCompression private readonly IResponseCompressionProvider _provider; - private readonly bool _enableForHttps; /// /// Initialize the Response Compression middleware. /// /// /// - /// - public ResponseCompressionMiddleware(RequestDelegate next, IResponseCompressionProvider provider, IOptions options) + public ResponseCompressionMiddleware(RequestDelegate next, IResponseCompressionProvider provider) { if (next == null) { @@ -36,14 +34,9 @@ namespace Microsoft.AspNetCore.ResponseCompression { throw new ArgumentNullException(nameof(provider)); } - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } _next = next; _provider = provider; - _enableForHttps = options.Value.EnableForHttps; } /// @@ -53,14 +46,7 @@ namespace Microsoft.AspNetCore.ResponseCompression /// public async Task Invoke(HttpContext context) { - ICompressionProvider compressionProvider = null; - - if (!context.Request.IsHttps || _enableForHttps) - { - compressionProvider = _provider.GetCompressionProvider(context); - } - - if (compressionProvider == null) + if (!_provider.CheckRequestAcceptsCompression(context)) { await _next(context); return; @@ -70,7 +56,7 @@ namespace Microsoft.AspNetCore.ResponseCompression var originalBufferFeature = context.Features.Get(); var originalSendFileFeature = context.Features.Get(); - var bodyWrapperStream = new BodyWrapperStream(context.Response, bodyStream, _provider, compressionProvider, + var bodyWrapperStream = new BodyWrapperStream(context, bodyStream, _provider, originalBufferFeature, originalSendFileFeature); context.Response.Body = bodyWrapperStream; context.Features.Set(bodyWrapperStream); diff --git a/src/Microsoft.AspNetCore.ResponseCompression/ResponseCompressionProvider.cs b/src/Microsoft.AspNetCore.ResponseCompression/ResponseCompressionProvider.cs index b546a60..e16e6a1 100644 --- a/src/Microsoft.AspNetCore.ResponseCompression/ResponseCompressionProvider.cs +++ b/src/Microsoft.AspNetCore.ResponseCompression/ResponseCompressionProvider.cs @@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.ResponseCompression { private readonly ICompressionProvider[] _providers; private readonly HashSet _mimeTypes; + private readonly bool _enableForHttps; /// /// If no compression providers are specified then GZip is used by default. @@ -54,6 +55,8 @@ namespace Microsoft.AspNetCore.ResponseCompression mimeTypes = ResponseCompressionDefaults.MimeTypes; } _mimeTypes = new HashSet(mimeTypes, StringComparer.OrdinalIgnoreCase); + + _enableForHttps = options.Value.EnableForHttps; } /// @@ -103,6 +106,11 @@ namespace Microsoft.AspNetCore.ResponseCompression /// public virtual bool ShouldCompressResponse(HttpContext context) { + if (context.Response.Headers.ContainsKey(HeaderNames.ContentRange)) + { + return false; + } + var mimeType = context.Response.ContentType; if (string.IsNullOrEmpty(mimeType)) @@ -121,5 +129,15 @@ namespace Microsoft.AspNetCore.ResponseCompression // TODO PERF: StringSegments? return _mimeTypes.Contains(mimeType); } + + /// + public bool CheckRequestAcceptsCompression(HttpContext context) + { + if (context.Request.IsHttps && !_enableForHttps) + { + return false; + } + return !string.IsNullOrEmpty(context.Request.Headers[HeaderNames.AcceptEncoding]); + } } } diff --git a/test/Microsoft.AspNetCore.ResponseCompression.Tests/BodyWrapperStreamTests.cs b/test/Microsoft.AspNetCore.ResponseCompression.Tests/BodyWrapperStreamTests.cs new file mode 100644 index 0000000..75ca3d7 --- /dev/null +++ b/test/Microsoft.AspNetCore.ResponseCompression.Tests/BodyWrapperStreamTests.cs @@ -0,0 +1,211 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.ResponseCompression.Tests +{ + public class BodyWrapperStreamTests + { + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Write_IsPassedToUnderlyingStream_WhenDisableResponseBuffering(bool flushable) + { + + var buffer = new byte[] { 1 }; + byte[] written = null; + + var mock = new Mock(); + mock.SetupGet(s => s.CanWrite).Returns(true); + mock.Setup(s => s.Write(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((b, o, c) => + { + written = new ArraySegment(b, 0, c).ToArray(); + }); + + var stream = new BodyWrapperStream(new DefaultHttpContext(), mock.Object, new MockResponseCompressionProvider(flushable), null, null); + + stream.DisableResponseBuffering(); + stream.Write(buffer, 0, buffer.Length); + + Assert.Equal(buffer, written); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WriteAsync_IsPassedToUnderlyingStream_WhenDisableResponseBuffering(bool flushable) + { + var buffer = new byte[] { 1 }; + byte[] written = null; + + var mock = new Mock(); + mock.SetupGet(s => s.CanWrite).Returns(true); + mock.Setup(s => s.WriteAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((b, o, c, t) => + { + written = new ArraySegment(b, 0, c).ToArray(); + }) + .Returns(Task.FromResult(0)); + + var stream = new BodyWrapperStream(new DefaultHttpContext(), mock.Object, new MockResponseCompressionProvider(flushable), null, null); + + stream.DisableResponseBuffering(); + await stream.WriteAsync(buffer, 0, buffer.Length); + + Assert.Equal(buffer, written); + } + + [Fact] + public async Task SendFileAsync_IsPassedToUnderlyingStream_WhenDisableResponseBuffering() + { + byte[] written = null; + + var mock = new Mock(); + mock.SetupGet(s => s.CanWrite).Returns(true); + mock.Setup(s => s.WriteAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((b, o, c, t) => + { + written = new ArraySegment(b, 0, c).ToArray(); + }) + .Returns(Task.FromResult(0)); + + var stream = new BodyWrapperStream(new DefaultHttpContext(), mock.Object, new MockResponseCompressionProvider(true), null, null); + + stream.DisableResponseBuffering(); + + var path = "testfile1kb.txt"; + await stream.SendFileAsync(path, 0, null, CancellationToken.None); + + Assert.Equal(File.ReadAllBytes(path), written); + } + +#if NET451 + [Theory] + [InlineData(true)] + [InlineData(false)] + public void BeginWrite_IsPassedToUnderlyingStream_WhenDisableResponseBuffering(bool flushable) + { + var buffer = new byte[] { 1 }; + byte[] written = null; + + var mock = new Mock(); + mock.SetupGet(s => s.CanWrite).Returns(true); + mock.Setup(s => s.WriteAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((b, o, c, t) => + { + written = new ArraySegment(b, 0, c).ToArray(); + }) + .Returns(Task.FromResult(0)); + + var stream = new BodyWrapperStream(new DefaultHttpContext(), mock.Object, new MockResponseCompressionProvider(flushable), null, null); + + stream.DisableResponseBuffering(); + stream.BeginWrite(buffer, 0, buffer.Length, (o) => {}, null); + + Assert.Equal(buffer, written); + } +#endif + + private class MockResponseCompressionProvider: IResponseCompressionProvider + { + private readonly bool _flushable; + + public MockResponseCompressionProvider(bool flushable) + { + _flushable = flushable; + } + + public ICompressionProvider GetCompressionProvider(HttpContext context) + { + return new MockCompressionProvider(_flushable); + } + + public bool ShouldCompressResponse(HttpContext context) + { + return true; + } + + public bool CheckRequestAcceptsCompression(HttpContext context) + { + return true; + } + } + + + private class MockCompressionProvider : ICompressionProvider + { + public MockCompressionProvider(bool flushable) + { + SupportsFlush = flushable; + } + + public string EncodingName { get; } + + public bool SupportsFlush { get; } + + public Stream CreateStream(Stream outputStream) + { + if (SupportsFlush) + { + return new BufferedStream(outputStream); + } + else + { + return new NoFlushBufferedStream(outputStream); + } + + } + } + + private class NoFlushBufferedStream : Stream + { + private readonly BufferedStream _bufferedStream; + + public NoFlushBufferedStream(Stream outputStream) + { + _bufferedStream = new BufferedStream(outputStream); + } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) => _bufferedStream.Read(buffer, offset, count); + + public override long Seek(long offset, SeekOrigin origin) => _bufferedStream.Seek(offset, origin); + + public override void SetLength(long value) => _bufferedStream.SetLength(value); + + public override void Write(byte[] buffer, int offset, int count) => _bufferedStream.Write(buffer, offset, count); + + public override bool CanRead => _bufferedStream.CanRead; + + public override bool CanSeek => _bufferedStream.CanSeek; + + public override bool CanWrite => _bufferedStream.CanWrite; + + public override long Length => _bufferedStream.Length; + + public override long Position + { + get { return _bufferedStream.Position; } + set { _bufferedStream.Position = value; } + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + _bufferedStream.Flush(); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.ResponseCompression.Tests/project.json b/test/Microsoft.AspNetCore.ResponseCompression.Tests/project.json index 27f2b14..18ac322 100644 --- a/test/Microsoft.AspNetCore.ResponseCompression.Tests/project.json +++ b/test/Microsoft.AspNetCore.ResponseCompression.Tests/project.json @@ -1,5 +1,6 @@ { "buildOptions": { + "keyFile": "../../tools/Key.snk", "copyToOutput": [ "testfile1kb.txt" ], @@ -11,6 +12,7 @@ "Microsoft.AspNetCore.ResponseCompression": "1.0.0-*", "Microsoft.AspNetCore.TestHost": "1.1.0-*", "Microsoft.Net.Http.Headers": "1.1.0-*", + "Moq": "4.6.36-*", "xunit": "2.2.0-*" }, "frameworks": {