From f227f95966c6380aee8d07c6f786873e974e3f8a Mon Sep 17 00:00:00 2001 From: Johnny Z Date: Thu, 12 Apr 2018 02:31:39 +1000 Subject: [PATCH] HTTP 1.1 codec #52 (#256) --- .gitattributes | 3 +- DotNetty.sln | 27 +- examples/HttpServer/HelloServerHandler.cs | 104 + examples/HttpServer/HttpServer.csproj | 30 + examples/HttpServer/MessageBody.cs | 17 + examples/HttpServer/Program.cs | 116 + .../HttpServer/Properties/AssemblyInfo.cs | 21 + examples/HttpServer/appsettings.json | 4 + src/DotNetty.Buffers/AbstractByteBuffer.cs | 86 + .../AbstractPooledDerivedByteBuffer.cs | 1 + src/DotNetty.Buffers/ByteBufferUtil.cs | 192 +- src/DotNetty.Buffers/EmptyByteBuffer.cs | 16 + src/DotNetty.Buffers/IByteBuffer.cs | 9 + src/DotNetty.Buffers/IByteBufferAllocator.cs | 2 +- src/DotNetty.Buffers/PoolArena.cs | 6 +- src/DotNetty.Buffers/PooledByteBuffer.cs | 6 +- .../PooledDuplicatedByteBuffer.cs | 4 +- .../PooledSlicedByteBuffer.cs | 4 +- src/DotNetty.Buffers/ThrowHelper.cs | 22 + .../UnpooledUnsafeDirectByteBuffer.cs | 2 +- ...BufferUtil .cs => UnsafeByteBufferUtil.cs} | 20 +- src/DotNetty.Buffers/WrappedByteBuffer.cs | 8 + .../WrappedCompositeByteBuffer.cs | 8 + .../CombinedHttpHeaders.cs | 221 ++ .../ComposedLastHttpContent.cs | 63 + .../Cookies/ClientCookieDecoder.cs | 266 +++ .../Cookies/ClientCookieEncoder.cs | 147 ++ .../Cookies/CookieDecoder.cs | 74 + .../Cookies/CookieEncoder.cs | 46 + .../Cookies/CookieHeaderNames.cs | 22 + .../Cookies/CookieUtil.cs | 203 ++ .../Cookies/DefaultCookie.cs | 227 ++ src/DotNetty.Codecs.Http/Cookies/ICookie.cs | 37 + .../Cookies/ServerCookieDecoder.cs | 146 ++ .../Cookies/ServerCookieEncoder.cs | 173 ++ src/DotNetty.Codecs.Http/Cors/CorsConfig.cs | 190 ++ .../Cors/CorsConfigBuilder.cs | 191 ++ src/DotNetty.Codecs.Http/Cors/CorsHandler.cs | 211 ++ .../DefaultFullHttpRequest.cs | 143 ++ .../DefaultFullHttpResponse.cs | 156 ++ .../DefaultHttpContent.cs | 66 + .../DefaultHttpHeaders.cs | 372 ++++ .../DefaultHttpMessage.cs | 73 + src/DotNetty.Codecs.Http/DefaultHttpObject.cs | 44 + .../DefaultHttpRequest.cs | 88 + .../DefaultHttpResponse.cs | 41 + .../DefaultLastHttpContent.cs | 84 + .../DotNetty.Codecs.Http.csproj | 48 + src/DotNetty.Codecs.Http/EmptyHttpHeaders.cs | 85 + .../EmptyLastHttpContent.cs | 53 + src/DotNetty.Codecs.Http/HttpChunkedInput.cs | 54 + src/DotNetty.Codecs.Http/HttpClientCodec.cs | 259 +++ .../HttpClientUpgradeHandler.cs | 197 ++ src/DotNetty.Codecs.Http/HttpConstants.cs | 55 + .../HttpContentCompressor.cs | 138 ++ .../HttpContentDecoder.cs | 243 +++ .../HttpContentDecompressor.cs | 50 + .../HttpContentEncoder.cs | 357 +++ .../HttpExpectationFailedEvent.cs | 16 + src/DotNetty.Codecs.Http/HttpHeaderNames.cs | 170 ++ src/DotNetty.Codecs.Http/HttpHeaderValues.cs | 100 + src/DotNetty.Codecs.Http/HttpHeaders.cs | 267 +++ .../HttpHeadersEncoder.cs | 46 + src/DotNetty.Codecs.Http/HttpMessageUtil.cs | 88 + src/DotNetty.Codecs.Http/HttpMethod.cs | 234 ++ .../HttpObjectAggregator.cs | 354 +++ src/DotNetty.Codecs.Http/HttpObjectDecoder.cs | 898 ++++++++ src/DotNetty.Codecs.Http/HttpObjectEncoder.cs | 252 +++ .../HttpRequestDecoder.cs | 42 + .../HttpRequestEncoder.cs | 79 + .../HttpResponseDecoder.cs | 40 + .../HttpResponseEncoder.cs | 61 + .../HttpResponseStatus.cs | 560 +++++ src/DotNetty.Codecs.Http/HttpScheme.cs | 39 + src/DotNetty.Codecs.Http/HttpServerCodec.cs | 109 + .../HttpServerExpectContinueHandler.cs | 65 + .../HttpServerKeepAliveHandler.cs | 98 + .../HttpServerUpgradeHandler.cs | 333 +++ src/DotNetty.Codecs.Http/HttpStatusClass.cs | 101 + src/DotNetty.Codecs.Http/HttpUtil.cs | 287 +++ src/DotNetty.Codecs.Http/HttpVersion.cs | 239 ++ src/DotNetty.Codecs.Http/IFullHttpMessage.cs | 9 + src/DotNetty.Codecs.Http/IFullHttpRequest.cs | 9 + src/DotNetty.Codecs.Http/IFullHttpResponse.cs | 9 + src/DotNetty.Codecs.Http/IHttpContent.cs | 11 + src/DotNetty.Codecs.Http/IHttpMessage.cs | 14 + src/DotNetty.Codecs.Http/IHttpObject.cs | 9 + src/DotNetty.Codecs.Http/IHttpRequest.cs | 16 + src/DotNetty.Codecs.Http/IHttpResponse.cs | 12 + src/DotNetty.Codecs.Http/ILastHttpContent.cs | 10 + .../Multipart/AbstractDiskHttpData.cs | 341 +++ .../Multipart/AbstractHttpData.cs | 137 ++ .../Multipart/AbstractMemoryHttpData.cs | 207 ++ .../Multipart/CaseIgnoringComparator.cs | 104 + .../Multipart/DefaultHttpDataFactory.cs | 283 +++ .../Multipart/DiskAttribute.cs | 180 ++ .../Multipart/DiskFileUpload.cs | 165 ++ .../Multipart/EndOfDataDecoderException.cs | 20 + .../Multipart/ErrorDataDecoderException.cs | 25 + .../Multipart/ErrorDataEncoderException.cs | 24 + .../Multipart/FileUploadUtil.cs | 18 + .../Multipart/HttpPostBodyUtil.cs | 116 + .../HttpPostMultipartRequestDecoder.cs | 1564 +++++++++++++ .../Multipart/HttpPostRequestDecoder.cs | 178 ++ .../Multipart/HttpPostRequestEncoder.cs | 1095 ++++++++++ .../HttpPostStandardRequestDecoder.cs | 586 +++++ .../Multipart/IAttribute.cs | 10 + .../Multipart/IFileUpload.cs | 14 + .../Multipart/IHttpData .cs | 48 + .../Multipart/IHttpDataFactory.cs | 30 + .../Multipart/IInterfaceHttpData.cs | 23 + .../IInterfaceHttpPostRequestDecoder.cs | 35 + .../Multipart/InternalAttribute.cs | 133 ++ .../Multipart/MemoryAttribute.cs | 160 ++ .../Multipart/MemoryFileUpload.cs | 146 ++ .../Multipart/MixedAttribute.cs | 249 +++ .../Multipart/MixedFileUpload.cs | 229 ++ .../Multipart/MultiPartStatus.cs | 23 + .../NotEnoughDataDecoderException.cs | 18 + .../Properties/AssemblyInfo.cs | 8 + .../Properties/Friends.cs | 7 + .../QueryStringDecoder.cs | 252 +++ .../QueryStringEncoder.cs | 83 + src/DotNetty.Codecs.Http/ThrowHelper.cs | 167 ++ src/DotNetty.Codecs.Http/UrlEncoder.cs | 52 + .../CharSequenceValueConverter.cs | 111 + src/DotNetty.Codecs/Compression/Adler32.cs | 127 ++ src/DotNetty.Codecs/Compression/CRC32.cs | 196 ++ .../Compression/CompressionException.cs | 20 + .../Compression/DecompressionException.cs | 20 + src/DotNetty.Codecs/Compression/Deflate.cs | 1943 +++++++++++++++++ src/DotNetty.Codecs/Compression/Deflater.cs | 190 ++ .../Compression/GZIPException.cs | 52 + src/DotNetty.Codecs/Compression/GZIPHeader.cs | 223 ++ src/DotNetty.Codecs/Compression/IChecksum.cs | 18 + src/DotNetty.Codecs/Compression/InfBlocks.cs | 695 ++++++ src/DotNetty.Codecs/Compression/InfCodes.cs | 695 ++++++ src/DotNetty.Codecs/Compression/InfTree.cs | 607 +++++ src/DotNetty.Codecs/Compression/Inflate.cs | 962 ++++++++ src/DotNetty.Codecs/Compression/Inflater.cs | 168 ++ src/DotNetty.Codecs/Compression/JZlib.cs | 97 + .../Compression/JZlibDecoder.cs | 160 ++ .../Compression/JZlibEncoder.cs | 249 +++ src/DotNetty.Codecs/Compression/StaticTree.cs | 158 ++ src/DotNetty.Codecs/Compression/Tree.cs | 414 ++++ src/DotNetty.Codecs/Compression/ZStream.cs | 374 ++++ .../Compression/ZlibCodecFactory.cs | 30 + .../Compression/ZlibDecoder.cs | 10 + .../Compression/ZlibEncoder.cs | 21 + src/DotNetty.Codecs/Compression/ZlibUtil.cs | 62 + .../Compression/ZlibWrapper.cs | 30 + src/DotNetty.Codecs/DateFormatter.cs | 484 ++++ src/DotNetty.Codecs/DecoderException.cs | 7 +- src/DotNetty.Codecs/DecoderResult.cs | 62 + src/DotNetty.Codecs/DefaultHeaders.cs | 1138 ++++++++++ src/DotNetty.Codecs/DotNetty.Codecs.csproj | 3 + src/DotNetty.Codecs/EncoderException.cs | 7 +- src/DotNetty.Codecs/HeadersUtils.cs | 88 + src/DotNetty.Codecs/IDecoderResultProvider.cs | 10 + src/DotNetty.Codecs/IHeaders.cs | 187 ++ src/DotNetty.Codecs/INameValidator.cs | 10 + src/DotNetty.Codecs/IValueConverter.cs | 46 + .../MessageAggregationException.cs | 20 + src/DotNetty.Codecs/MessageAggregator.cs | 367 ++++ src/DotNetty.Codecs/MessageToMessageCodec.cs | 64 + src/DotNetty.Codecs/NullNameValidator.cs | 18 + .../PrematureChannelClosureException.cs | 25 + src/DotNetty.Codecs/Properties/Friends.cs | 6 + .../Internal/AppendableCharSequence.cs | 227 ++ .../Internal/ConcurrentCircularArrayQueue.cs | 5 +- src/DotNetty.Common/Internal/EmptyArrays.cs | 22 + src/DotNetty.Common/Internal/IAppendable.cs | 16 + .../Internal/PlatformDependent.cs | 178 +- .../Internal/PlatformDependent0.cs | 75 +- src/DotNetty.Common/InternalThreadLocalMap.cs | 40 +- .../Utilities/AbstractReferenceCounted.cs | 2 +- .../Utilities/ArrayExtensions.cs | 1 - src/DotNetty.Common/Utilities/AsciiString.cs | 1584 ++++++++++++++ .../Utilities/ByteProcessor.cs | 30 +- .../Utilities/ByteProcessorUtils.cs | 13 + .../Utilities/CharSequenceEnumerator.cs | 86 + src/DotNetty.Common/Utilities/CharUtil.cs | 603 +++++ .../Utilities/ICharSequence.cs | 32 + .../Utilities/IHashingStrategy.cs | 21 + src/DotNetty.Common/Utilities/Signal.cs | 77 + .../Utilities/StringBuilderCharSequence.cs | 184 ++ .../Utilities/StringCharSequence.cs | 158 ++ src/DotNetty.Common/Utilities/StringUtil.cs | 598 +++-- .../Streams/ChunkedStream.cs | 81 + .../Streams/ChunkedWriteHandler.cs | 368 ++++ .../Streams/IChunkedInput.cs | 20 + .../Channels/ChannelHandlerAdapter.cs | 9 + .../Channels/CombinedChannelDuplexHandler.cs | 542 +++++ .../Channels/DefaultFileRegion.cs | 90 + .../Channels/Embedded/EmbeddedChannel.cs | 36 +- .../Channels/IFileRegion.cs | 19 + .../AbstractByteBufferTests.cs | 88 + ...AbstractReferenceCountedByteBufferTests.cs | 4 +- .../SimpleLeakAwareByteBufferTests.cs | 4 +- ...SimpleLeakAwareCompositeByteBufferTests.cs | 2 +- .../SlicedByteBufferTest.cs | 9 + .../CombinedHttpHeadersTest.cs | 368 ++++ .../Cookies/ClientCookieDecoderTest.cs | 275 +++ .../Cookies/ClientCookieEncoderTest.cs | 50 + .../Cookies/ServerCookieDecoderTest.cs | 190 ++ .../Cookies/ServerCookieEncoderTest.cs | 149 ++ .../Cors/CorsConfigTest.cs | 140 ++ .../Cors/CorsHandlerTest.cs | 432 ++++ .../DefaultHttpHeadersTest.cs | 151 ++ .../DefaultHttpRequestTest.cs | 39 + .../DotNetty.Codecs.Http.Tests.csproj | 38 + .../HttpChunkedInputTest.cs | 105 + .../HttpClientCodecTest.cs | 358 +++ .../HttpClientUpgradeHandlerTest.cs | 149 ++ .../HttpContentCompressorTest.cs | 441 ++++ .../HttpContentDecoderTest.cs | 657 ++++++ .../HttpContentEncoderTest.cs | 457 ++++ .../HttpHeadersTest.cs | 64 + .../HttpHeadersTestUtils.cs | 138 ++ .../HttpInvalidMessageTest.cs | 114 + .../HttpObjectAggregatorTest.cs | 500 +++++ .../HttpRequestDecoderTest.cs | 338 +++ .../HttpRequestEncoderTest.cs | 114 + .../HttpResponseDecoderTest.cs | 719 ++++++ .../HttpResponseEncoderTest.cs | 153 ++ .../HttpResponseStatusTest.cs | 90 + .../HttpServerCodecTest.cs | 175 ++ .../HttpServerExpectContinueHandlerTest.cs | 70 + .../HttpServerKeepAliveHandlerTest.cs | 180 ++ .../HttpUtilTest.cs | 256 +++ .../Multipart/AbstractMemoryHttpDataTest.cs | 82 + .../Multipart/DefaultHttpDataFactoryTest.cs | 139 ++ .../Multipart/DiskFileUploadTest.cs | 18 + .../Multipart/HttpPostRequestDecoderTest.cs | 607 +++++ .../Multipart/HttpPostRequestEncoderTest.cs | 442 ++++ .../Multipart/MemoryFileUploadTest.cs | 18 + .../Multipart/file-01.txt | 1 + .../Multipart/file-02.txt | 1 + .../Properties/AssemblyInfo.cs | 42 + .../QueryStringDecoderTest.cs | 366 ++++ .../QueryStringEncoderTest.cs | 65 + .../RoundTripTests.cs | 1 - .../DateFormatterTest.cs | 108 + .../DefaultHeadersTest.cs | 652 ++++++ .../Utilities/AsciiStringCharacterTest.cs | 313 +++ .../DotNetty.Handlers.Tests/TlsHandlerTest.cs | 4 +- .../PooledHeapByteBufferAllocatorBenchmark.cs | 15 + .../UnpooledByteBufferAllocatorBenchmark.cs | 4 +- .../Buffers/PooledByteBufferBenchmark.cs | 2 +- .../Buffers/UnpooledByteBufferBenchmark.cs | 2 +- .../Codecs/DateFormatterBenchmark.cs | 24 + .../Common/AsciiStringBenchmark.cs | 50 + .../DotNetty.Microbench.csproj | 1 + .../Headers/ExampleHeaders.cs | 149 ++ .../Headers/HeadersBenchmark.cs | 111 + .../Http/ClientCookieDecoderBenchmark.cs | 28 + .../Http/HttpRequestDecoderBenchmark.cs | 108 + .../Http/HttpRequestEncoderInsertBenchmark.cs | 43 + .../WriteBytesVsShortOrMediumBenchmark.cs | 58 + .../Internal/PlatformDependentBenchmark.cs | 35 + test/DotNetty.Microbench/Program.cs | 27 +- .../SocketDatagramChannelUnicastTest.cs | 10 +- 262 files changed, 42432 insertions(+), 240 deletions(-) create mode 100644 examples/HttpServer/HelloServerHandler.cs create mode 100644 examples/HttpServer/HttpServer.csproj create mode 100644 examples/HttpServer/MessageBody.cs create mode 100644 examples/HttpServer/Program.cs create mode 100644 examples/HttpServer/Properties/AssemblyInfo.cs create mode 100644 examples/HttpServer/appsettings.json rename src/DotNetty.Buffers/{UnsafeByteBufferUtil .cs => UnsafeByteBufferUtil.cs} (97%) create mode 100644 src/DotNetty.Codecs.Http/CombinedHttpHeaders.cs create mode 100644 src/DotNetty.Codecs.Http/ComposedLastHttpContent.cs create mode 100644 src/DotNetty.Codecs.Http/Cookies/ClientCookieDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/Cookies/ClientCookieEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/Cookies/CookieDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/Cookies/CookieEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/Cookies/CookieHeaderNames.cs create mode 100644 src/DotNetty.Codecs.Http/Cookies/CookieUtil.cs create mode 100644 src/DotNetty.Codecs.Http/Cookies/DefaultCookie.cs create mode 100644 src/DotNetty.Codecs.Http/Cookies/ICookie.cs create mode 100644 src/DotNetty.Codecs.Http/Cookies/ServerCookieDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/Cookies/ServerCookieEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/Cors/CorsConfig.cs create mode 100644 src/DotNetty.Codecs.Http/Cors/CorsConfigBuilder.cs create mode 100644 src/DotNetty.Codecs.Http/Cors/CorsHandler.cs create mode 100644 src/DotNetty.Codecs.Http/DefaultFullHttpRequest.cs create mode 100644 src/DotNetty.Codecs.Http/DefaultFullHttpResponse.cs create mode 100644 src/DotNetty.Codecs.Http/DefaultHttpContent.cs create mode 100644 src/DotNetty.Codecs.Http/DefaultHttpHeaders.cs create mode 100644 src/DotNetty.Codecs.Http/DefaultHttpMessage.cs create mode 100644 src/DotNetty.Codecs.Http/DefaultHttpObject.cs create mode 100644 src/DotNetty.Codecs.Http/DefaultHttpRequest.cs create mode 100644 src/DotNetty.Codecs.Http/DefaultHttpResponse.cs create mode 100644 src/DotNetty.Codecs.Http/DefaultLastHttpContent.cs create mode 100644 src/DotNetty.Codecs.Http/DotNetty.Codecs.Http.csproj create mode 100644 src/DotNetty.Codecs.Http/EmptyHttpHeaders.cs create mode 100644 src/DotNetty.Codecs.Http/EmptyLastHttpContent.cs create mode 100644 src/DotNetty.Codecs.Http/HttpChunkedInput.cs create mode 100644 src/DotNetty.Codecs.Http/HttpClientCodec.cs create mode 100644 src/DotNetty.Codecs.Http/HttpClientUpgradeHandler.cs create mode 100644 src/DotNetty.Codecs.Http/HttpConstants.cs create mode 100644 src/DotNetty.Codecs.Http/HttpContentCompressor.cs create mode 100644 src/DotNetty.Codecs.Http/HttpContentDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/HttpContentDecompressor.cs create mode 100644 src/DotNetty.Codecs.Http/HttpContentEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/HttpExpectationFailedEvent.cs create mode 100644 src/DotNetty.Codecs.Http/HttpHeaderNames.cs create mode 100644 src/DotNetty.Codecs.Http/HttpHeaderValues.cs create mode 100644 src/DotNetty.Codecs.Http/HttpHeaders.cs create mode 100644 src/DotNetty.Codecs.Http/HttpHeadersEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/HttpMessageUtil.cs create mode 100644 src/DotNetty.Codecs.Http/HttpMethod.cs create mode 100644 src/DotNetty.Codecs.Http/HttpObjectAggregator.cs create mode 100644 src/DotNetty.Codecs.Http/HttpObjectDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/HttpObjectEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/HttpRequestDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/HttpRequestEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/HttpResponseDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/HttpResponseEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/HttpResponseStatus.cs create mode 100644 src/DotNetty.Codecs.Http/HttpScheme.cs create mode 100644 src/DotNetty.Codecs.Http/HttpServerCodec.cs create mode 100644 src/DotNetty.Codecs.Http/HttpServerExpectContinueHandler.cs create mode 100644 src/DotNetty.Codecs.Http/HttpServerKeepAliveHandler.cs create mode 100644 src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs create mode 100644 src/DotNetty.Codecs.Http/HttpStatusClass.cs create mode 100644 src/DotNetty.Codecs.Http/HttpUtil.cs create mode 100644 src/DotNetty.Codecs.Http/HttpVersion.cs create mode 100644 src/DotNetty.Codecs.Http/IFullHttpMessage.cs create mode 100644 src/DotNetty.Codecs.Http/IFullHttpRequest.cs create mode 100644 src/DotNetty.Codecs.Http/IFullHttpResponse.cs create mode 100644 src/DotNetty.Codecs.Http/IHttpContent.cs create mode 100644 src/DotNetty.Codecs.Http/IHttpMessage.cs create mode 100644 src/DotNetty.Codecs.Http/IHttpObject.cs create mode 100644 src/DotNetty.Codecs.Http/IHttpRequest.cs create mode 100644 src/DotNetty.Codecs.Http/IHttpResponse.cs create mode 100644 src/DotNetty.Codecs.Http/ILastHttpContent.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/AbstractDiskHttpData.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/AbstractHttpData.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/AbstractMemoryHttpData.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/CaseIgnoringComparator.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/DefaultHttpDataFactory.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/DiskAttribute.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/DiskFileUpload.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/EndOfDataDecoderException.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/ErrorDataDecoderException.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/ErrorDataEncoderException.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/FileUploadUtil.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/HttpPostBodyUtil.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/HttpPostMultipartRequestDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/HttpPostRequestDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/HttpPostRequestEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/HttpPostStandardRequestDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/IAttribute.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/IFileUpload.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/IHttpData .cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/IHttpDataFactory.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/IInterfaceHttpData.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/IInterfaceHttpPostRequestDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/InternalAttribute.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/MemoryAttribute.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/MemoryFileUpload.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/MixedAttribute.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/MixedFileUpload.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/MultiPartStatus.cs create mode 100644 src/DotNetty.Codecs.Http/Multipart/NotEnoughDataDecoderException.cs create mode 100644 src/DotNetty.Codecs.Http/Properties/AssemblyInfo.cs create mode 100644 src/DotNetty.Codecs.Http/Properties/Friends.cs create mode 100644 src/DotNetty.Codecs.Http/QueryStringDecoder.cs create mode 100644 src/DotNetty.Codecs.Http/QueryStringEncoder.cs create mode 100644 src/DotNetty.Codecs.Http/ThrowHelper.cs create mode 100644 src/DotNetty.Codecs.Http/UrlEncoder.cs create mode 100644 src/DotNetty.Codecs/CharSequenceValueConverter.cs create mode 100644 src/DotNetty.Codecs/Compression/Adler32.cs create mode 100644 src/DotNetty.Codecs/Compression/CRC32.cs create mode 100644 src/DotNetty.Codecs/Compression/CompressionException.cs create mode 100644 src/DotNetty.Codecs/Compression/DecompressionException.cs create mode 100644 src/DotNetty.Codecs/Compression/Deflate.cs create mode 100644 src/DotNetty.Codecs/Compression/Deflater.cs create mode 100644 src/DotNetty.Codecs/Compression/GZIPException.cs create mode 100644 src/DotNetty.Codecs/Compression/GZIPHeader.cs create mode 100644 src/DotNetty.Codecs/Compression/IChecksum.cs create mode 100644 src/DotNetty.Codecs/Compression/InfBlocks.cs create mode 100644 src/DotNetty.Codecs/Compression/InfCodes.cs create mode 100644 src/DotNetty.Codecs/Compression/InfTree.cs create mode 100644 src/DotNetty.Codecs/Compression/Inflate.cs create mode 100644 src/DotNetty.Codecs/Compression/Inflater.cs create mode 100644 src/DotNetty.Codecs/Compression/JZlib.cs create mode 100644 src/DotNetty.Codecs/Compression/JZlibDecoder.cs create mode 100644 src/DotNetty.Codecs/Compression/JZlibEncoder.cs create mode 100644 src/DotNetty.Codecs/Compression/StaticTree.cs create mode 100644 src/DotNetty.Codecs/Compression/Tree.cs create mode 100644 src/DotNetty.Codecs/Compression/ZStream.cs create mode 100644 src/DotNetty.Codecs/Compression/ZlibCodecFactory.cs create mode 100644 src/DotNetty.Codecs/Compression/ZlibDecoder.cs create mode 100644 src/DotNetty.Codecs/Compression/ZlibEncoder.cs create mode 100644 src/DotNetty.Codecs/Compression/ZlibUtil.cs create mode 100644 src/DotNetty.Codecs/Compression/ZlibWrapper.cs create mode 100644 src/DotNetty.Codecs/DateFormatter.cs create mode 100644 src/DotNetty.Codecs/DecoderResult.cs create mode 100644 src/DotNetty.Codecs/DefaultHeaders.cs create mode 100644 src/DotNetty.Codecs/HeadersUtils.cs create mode 100644 src/DotNetty.Codecs/IDecoderResultProvider.cs create mode 100644 src/DotNetty.Codecs/IHeaders.cs create mode 100644 src/DotNetty.Codecs/INameValidator.cs create mode 100644 src/DotNetty.Codecs/IValueConverter.cs create mode 100644 src/DotNetty.Codecs/MessageAggregationException.cs create mode 100644 src/DotNetty.Codecs/MessageAggregator.cs create mode 100644 src/DotNetty.Codecs/MessageToMessageCodec.cs create mode 100644 src/DotNetty.Codecs/NullNameValidator.cs create mode 100644 src/DotNetty.Codecs/PrematureChannelClosureException.cs create mode 100644 src/DotNetty.Codecs/Properties/Friends.cs create mode 100644 src/DotNetty.Common/Internal/AppendableCharSequence.cs create mode 100644 src/DotNetty.Common/Internal/EmptyArrays.cs create mode 100644 src/DotNetty.Common/Internal/IAppendable.cs create mode 100644 src/DotNetty.Common/Utilities/AsciiString.cs create mode 100644 src/DotNetty.Common/Utilities/ByteProcessorUtils.cs create mode 100644 src/DotNetty.Common/Utilities/CharSequenceEnumerator.cs create mode 100644 src/DotNetty.Common/Utilities/ICharSequence.cs create mode 100644 src/DotNetty.Common/Utilities/IHashingStrategy.cs create mode 100644 src/DotNetty.Common/Utilities/Signal.cs create mode 100644 src/DotNetty.Common/Utilities/StringBuilderCharSequence.cs create mode 100644 src/DotNetty.Common/Utilities/StringCharSequence.cs create mode 100644 src/DotNetty.Handlers/Streams/ChunkedStream.cs create mode 100644 src/DotNetty.Handlers/Streams/ChunkedWriteHandler.cs create mode 100644 src/DotNetty.Handlers/Streams/IChunkedInput.cs create mode 100644 src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs create mode 100644 src/DotNetty.Transport/Channels/DefaultFileRegion.cs create mode 100644 src/DotNetty.Transport/Channels/IFileRegion.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/CombinedHttpHeadersTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Cookies/ClientCookieDecoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Cookies/ClientCookieEncoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Cookies/ServerCookieDecoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Cookies/ServerCookieEncoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Cors/CorsConfigTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Cors/CorsHandlerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/DefaultHttpHeadersTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/DefaultHttpRequestTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/DotNetty.Codecs.Http.Tests.csproj create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpChunkedInputTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpClientCodecTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpClientUpgradeHandlerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpContentCompressorTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpContentDecoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpContentEncoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpHeadersTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpHeadersTestUtils.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpInvalidMessageTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpObjectAggregatorTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpRequestDecoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpRequestEncoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpResponseDecoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpResponseEncoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpResponseStatusTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpServerCodecTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpServerExpectContinueHandlerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpServerKeepAliveHandlerTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/HttpUtilTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Multipart/AbstractMemoryHttpDataTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Multipart/DefaultHttpDataFactoryTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Multipart/DiskFileUploadTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Multipart/HttpPostRequestDecoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Multipart/HttpPostRequestEncoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Multipart/MemoryFileUploadTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/Multipart/file-01.txt create mode 100644 test/DotNetty.Codecs.Http.Tests/Multipart/file-02.txt create mode 100644 test/DotNetty.Codecs.Http.Tests/Properties/AssemblyInfo.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/QueryStringDecoderTest.cs create mode 100644 test/DotNetty.Codecs.Http.Tests/QueryStringEncoderTest.cs create mode 100644 test/DotNetty.Codecs.Tests/DateFormatterTest.cs create mode 100644 test/DotNetty.Codecs.Tests/DefaultHeadersTest.cs create mode 100644 test/DotNetty.Common.Tests/Utilities/AsciiStringCharacterTest.cs create mode 100644 test/DotNetty.Microbench/Allocators/PooledHeapByteBufferAllocatorBenchmark.cs create mode 100644 test/DotNetty.Microbench/Codecs/DateFormatterBenchmark.cs create mode 100644 test/DotNetty.Microbench/Common/AsciiStringBenchmark.cs create mode 100644 test/DotNetty.Microbench/Headers/ExampleHeaders.cs create mode 100644 test/DotNetty.Microbench/Headers/HeadersBenchmark.cs create mode 100644 test/DotNetty.Microbench/Http/ClientCookieDecoderBenchmark.cs create mode 100644 test/DotNetty.Microbench/Http/HttpRequestDecoderBenchmark.cs create mode 100644 test/DotNetty.Microbench/Http/HttpRequestEncoderInsertBenchmark.cs create mode 100644 test/DotNetty.Microbench/Http/WriteBytesVsShortOrMediumBenchmark.cs create mode 100644 test/DotNetty.Microbench/Internal/PlatformDependentBenchmark.cs diff --git a/.gitattributes b/.gitattributes index 4f1f650..cfa5d53 100644 --- a/.gitattributes +++ b/.gitattributes @@ -42,9 +42,10 @@ *.fs text=auto *.fsx text=auto *.hs text=auto +*.txt eol=crlf *.csproj text=auto *.vbproj text=auto *.fsproj text=auto *.dbproj text=auto -*.sln text=auto eol=crlf \ No newline at end of file +*.sln text=auto eol=crlf diff --git a/DotNetty.sln b/DotNetty.sln index f53522a..d9f9ccb 100644 --- a/DotNetty.sln +++ b/DotNetty.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 -VisualStudioVersion = 15.0.27004.2010 +VisualStudioVersion = 15.0.27130.2024 MinimumVisualStudioVersion = 10.0.40219.1 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "build", "build", "{F5B1CA65-5852-41C6-9D6F-184A3889237B}" ProjectSection(SolutionItems) = preProject @@ -95,7 +95,13 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DotNetty.Microbench", "test EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DotNetty.Transport.Libuv", "src\DotNetty.Transport.Libuv\DotNetty.Transport.Libuv.csproj", "{9FE6A783-C20D-4097-9988-4178E2C4CE75}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DotNetty.Transport.Libuv.Tests", "test\DotNetty.Transport.Libuv.Tests\DotNetty.Transport.Libuv.Tests.csproj", "{1012C962-7F6D-4EC5-A0EC-0741A95BAD6B}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DotNetty.Transport.Libuv.Tests", "test\DotNetty.Transport.Libuv.Tests\DotNetty.Transport.Libuv.Tests.csproj", "{1012C962-7F6D-4EC5-A0EC-0741A95BAD6B}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HttpServer", "examples\HttpServer\HttpServer.csproj", "{A7CACAE7-66E7-43DA-948B-28EB0DDDB582}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DotNetty.Codecs.Http", "src\DotNetty.Codecs.Http\DotNetty.Codecs.Http.csproj", "{5F68A5B1-7907-4B16-8AFE-326E9DD7D65B}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DotNetty.Codecs.Http.Tests", "test\DotNetty.Codecs.Http.Tests\DotNetty.Codecs.Http.Tests.csproj", "{16C89E7C-1575-4685-8DFA-8E7E2C6101BF}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -255,6 +261,18 @@ Global {1012C962-7F6D-4EC5-A0EC-0741A95BAD6B}.Debug|Any CPU.Build.0 = Debug|Any CPU {1012C962-7F6D-4EC5-A0EC-0741A95BAD6B}.Release|Any CPU.ActiveCfg = Release|Any CPU {1012C962-7F6D-4EC5-A0EC-0741A95BAD6B}.Release|Any CPU.Build.0 = Release|Any CPU + {A7CACAE7-66E7-43DA-948B-28EB0DDDB582}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A7CACAE7-66E7-43DA-948B-28EB0DDDB582}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A7CACAE7-66E7-43DA-948B-28EB0DDDB582}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A7CACAE7-66E7-43DA-948B-28EB0DDDB582}.Release|Any CPU.Build.0 = Release|Any CPU + {5F68A5B1-7907-4B16-8AFE-326E9DD7D65B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {5F68A5B1-7907-4B16-8AFE-326E9DD7D65B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {5F68A5B1-7907-4B16-8AFE-326E9DD7D65B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {5F68A5B1-7907-4B16-8AFE-326E9DD7D65B}.Release|Any CPU.Build.0 = Release|Any CPU + {16C89E7C-1575-4685-8DFA-8E7E2C6101BF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {16C89E7C-1575-4685-8DFA-8E7E2C6101BF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {16C89E7C-1575-4685-8DFA-8E7E2C6101BF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {16C89E7C-1575-4685-8DFA-8E7E2C6101BF}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -298,9 +316,12 @@ Global {7155D1E6-00CE-4081-B922-E6C5524EE600} = {541093F6-616E-43D9-B671-FCD1F9C0A181} {9FE6A783-C20D-4097-9988-4178E2C4CE75} = {126EA539-4B28-4B07-8B5D-D1D7F794D189} {1012C962-7F6D-4EC5-A0EC-0741A95BAD6B} = {541093F6-616E-43D9-B671-FCD1F9C0A181} + {A7CACAE7-66E7-43DA-948B-28EB0DDDB582} = {F716F1EF-81EF-4020-914A-5422A13A9E13} + {5F68A5B1-7907-4B16-8AFE-326E9DD7D65B} = {126EA539-4B28-4B07-8B5D-D1D7F794D189} + {16C89E7C-1575-4685-8DFA-8E7E2C6101BF} = {541093F6-616E-43D9-B671-FCD1F9C0A181} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution - SolutionGuid = {1FBD8DF1-D90A-4F21-8EB6-DA17B9431FE3} {9FE6A783-C20D-4097-9988-4178E2C4CE75} = {126EA539-4B28-4B07-8B5D-D1D7F794D189} + SolutionGuid = {1FBD8DF1-D90A-4F21-8EB6-DA17B9431FE3} EndGlobalSection EndGlobal diff --git a/examples/HttpServer/HelloServerHandler.cs b/examples/HttpServer/HelloServerHandler.cs new file mode 100644 index 0000000..22c1195 --- /dev/null +++ b/examples/HttpServer/HelloServerHandler.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace HttpServer +{ + using System.Text; + using DotNetty.Buffers; + using DotNetty.Codecs.Http; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using System; + using DotNetty.Common; + + sealed class HelloServerHandler : ChannelHandlerAdapter + { + static readonly ThreadLocalCache Cache = new ThreadLocalCache(); + + sealed class ThreadLocalCache : FastThreadLocal + { + protected override AsciiString GetInitialValue() + { + DateTime dateTime = DateTime.UtcNow; + return AsciiString.Cached($"{dateTime.DayOfWeek}, {dateTime:dd MMM yyyy HH:mm:ss z}"); + } + } + + static readonly byte[] StaticPlaintext = Encoding.UTF8.GetBytes("Hello, World!"); + static readonly int StaticPlaintextLen = StaticPlaintext.Length; + static readonly IByteBuffer PlaintextContentBuffer = Unpooled.UnreleasableBuffer(Unpooled.DirectBuffer().WriteBytes(StaticPlaintext)); + static readonly AsciiString PlaintextClheaderValue = AsciiString.Cached($"{StaticPlaintextLen}"); + static readonly AsciiString JsonClheaderValue = AsciiString.Cached($"{JsonLen()}"); + + static readonly AsciiString TypePlain = AsciiString.Cached("text/plain"); + static readonly AsciiString TypeJson = AsciiString.Cached("application/json"); + static readonly AsciiString ServerName = AsciiString.Cached("Netty"); + static readonly AsciiString ContentTypeEntity = HttpHeaderNames.ContentType; + static readonly AsciiString DateEntity = HttpHeaderNames.Date; + static readonly AsciiString ContentLengthEntity = HttpHeaderNames.ContentLength; + static readonly AsciiString ServerEntity = HttpHeaderNames.Server; + + volatile ICharSequence date = Cache.Value; + + static int JsonLen() => Encoding.UTF8.GetBytes(NewMessage().ToJsonFormat()).Length; + + static MessageBody NewMessage() => new MessageBody("Hello, World!"); + + public override void ChannelRead(IChannelHandlerContext ctx, object message) + { + if (message is IHttpRequest request) + { + try + { + this.Process(ctx, request); + } + finally + { + ReferenceCountUtil.Release(message); + } + } + else + { + ctx.FireChannelRead(message); + } + } + + void Process(IChannelHandlerContext ctx, IHttpRequest request) + { + string uri = request.Uri; + switch (uri) + { + case "/plaintext": + this.WriteResponse(ctx, PlaintextContentBuffer.Duplicate(), TypePlain, PlaintextClheaderValue); + break; + case "/json": + byte[] json = Encoding.UTF8.GetBytes(NewMessage().ToJsonFormat()); + this.WriteResponse(ctx, Unpooled.WrappedBuffer(json), TypeJson, JsonClheaderValue); + break; + default: + var response = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.NotFound, Unpooled.Empty, false); + ctx.WriteAndFlushAsync(response); + ctx.CloseAsync(); + break; + } + } + + void WriteResponse(IChannelHandlerContext ctx, IByteBuffer buf, ICharSequence contentType, ICharSequence contentLength) + { + // Build the response object. + var response = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, buf, false); + HttpHeaders headers = response.Headers; + headers.Set(ContentTypeEntity, contentType); + headers.Set(ServerEntity, ServerName); + headers.Set(DateEntity, this.date); + headers.Set(ContentLengthEntity, contentLength); + + // Close the non-keep-alive connection after the write operation is done. + ctx.WriteAsync(response); + } + + public override void ExceptionCaught(IChannelHandlerContext context, Exception exception) => context.CloseAsync(); + + public override void ChannelReadComplete(IChannelHandlerContext context) => context.Flush(); + } +} diff --git a/examples/HttpServer/HttpServer.csproj b/examples/HttpServer/HttpServer.csproj new file mode 100644 index 0000000..c154729 --- /dev/null +++ b/examples/HttpServer/HttpServer.csproj @@ -0,0 +1,30 @@ + + + Exe + netcoreapp1.1;net451 + 1.6.1 + false + true + + + win-x64 + + + + PreserveNewest + + + PreserveNewest + + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/HttpServer/MessageBody.cs b/examples/HttpServer/MessageBody.cs new file mode 100644 index 0000000..7ce89d5 --- /dev/null +++ b/examples/HttpServer/MessageBody.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace HttpServer +{ + sealed class MessageBody + { + public MessageBody(string message) + { + this.Message = message; + } + + public string Message { get; } + + public string ToJsonFormat() => "{" + $"\"{nameof(MessageBody)}\" :" + "{" + $"\"{nameof(this.Message)}\"" + " :\"" + this.Message + "\"}" +"}"; + } +} diff --git a/examples/HttpServer/Program.cs b/examples/HttpServer/Program.cs new file mode 100644 index 0000000..7dd99fa --- /dev/null +++ b/examples/HttpServer/Program.cs @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace HttpServer +{ + using System; + using System.IO; + using System.Net; + using System.Runtime; + using System.Runtime.InteropServices; + using System.Security.Cryptography.X509Certificates; + using System.Threading.Tasks; + using DotNetty.Codecs.Http; + using DotNetty.Common; + using DotNetty.Handlers.Tls; + using DotNetty.Transport.Bootstrapping; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Sockets; + using DotNetty.Transport.Libuv; + using Examples.Common; + + class Program + { + static Program() + { + ResourceLeakDetector.Level = ResourceLeakDetector.DetectionLevel.Disabled; + } + + static async Task RunServerAsync() + { + Console.WriteLine( + $"\n{RuntimeInformation.OSArchitecture} {RuntimeInformation.OSDescription}" + + $"\n{RuntimeInformation.ProcessArchitecture} {RuntimeInformation.FrameworkDescription}" + + $"\nProcessor Count : {Environment.ProcessorCount}\n"); + + bool useLibuv = ServerSettings.UseLibuv; + Console.WriteLine("Transport type : " + (useLibuv ? "Libuv" : "Socket")); + + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + GCSettings.LatencyMode = GCLatencyMode.SustainedLowLatency; + } + + Console.WriteLine($"Server garbage collection: {GCSettings.IsServerGC}"); + Console.WriteLine($"Current latency mode for garbage collection: {GCSettings.LatencyMode}"); + + IEventLoopGroup group; + IEventLoopGroup workGroup; + if (useLibuv) + { + var dispatcher = new DispatcherEventLoopGroup(); + group = dispatcher; + workGroup = new WorkerEventLoopGroup(dispatcher); + } + else + { + group = new MultithreadEventLoopGroup(1); + workGroup = new MultithreadEventLoopGroup(); + } + + X509Certificate2 tlsCertificate = null; + if (ServerSettings.IsSsl) + { + tlsCertificate = new X509Certificate2(Path.Combine(ExampleHelper.ProcessDirectory, "dotnetty.com.pfx"), "password"); + } + try + { + var bootstrap = new ServerBootstrap(); + bootstrap.Group(group, workGroup); + + if (useLibuv) + { + bootstrap.Channel(); + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux) + || RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + bootstrap + .Option(ChannelOption.SoReuseport, true) + .ChildOption(ChannelOption.SoReuseaddr, true); + } + } + else + { + bootstrap.Channel(); + } + + bootstrap + .Option(ChannelOption.SoBacklog, 8192) + .ChildHandler(new ActionChannelInitializer(channel => + { + IChannelPipeline pipeline = channel.Pipeline; + if (tlsCertificate != null) + { + pipeline.AddLast(TlsHandler.Server(tlsCertificate)); + } + pipeline.AddLast("encoder", new HttpResponseEncoder()); + pipeline.AddLast("decoder", new HttpRequestDecoder(4096, 8192, 8192, false)); + pipeline.AddLast("handler", new HelloServerHandler()); + })); + + IChannel bootstrapChannel = await bootstrap.BindAsync(IPAddress.IPv6Any, ServerSettings.Port); + + Console.WriteLine($"Httpd started. Listening on {bootstrapChannel.LocalAddress}"); + Console.ReadLine(); + + await bootstrapChannel.CloseAsync(); + } + finally + { + group.ShutdownGracefullyAsync().Wait(); + } + } + + static void Main() => RunServerAsync().Wait(); + } +} \ No newline at end of file diff --git a/examples/HttpServer/Properties/AssemblyInfo.cs b/examples/HttpServer/Properties/AssemblyInfo.cs new file mode 100644 index 0000000..fe194e9 --- /dev/null +++ b/examples/HttpServer/Properties/AssemblyInfo.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Reflection; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("HttpWebServer")] +[assembly: AssemblyTrademark("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("d4efc310-c3a7-42a2-bc9c-aa9cce3d1c63")] diff --git a/examples/HttpServer/appsettings.json b/examples/HttpServer/appsettings.json new file mode 100644 index 0000000..2942538 --- /dev/null +++ b/examples/HttpServer/appsettings.json @@ -0,0 +1,4 @@ +{ + "port": "7686", + "libuv": "true" +} \ No newline at end of file diff --git a/src/DotNetty.Buffers/AbstractByteBuffer.cs b/src/DotNetty.Buffers/AbstractByteBuffer.cs index 01b6289..96e85bd 100644 --- a/src/DotNetty.Buffers/AbstractByteBuffer.cs +++ b/src/DotNetty.Buffers/AbstractByteBuffer.cs @@ -491,6 +491,42 @@ namespace DotNetty.Buffers return value; } + public virtual unsafe ICharSequence GetCharSequence(int index, int length, Encoding encoding) + { + this.CheckIndex0(index, length); + if (length == 0) + { + return StringCharSequence.Empty; + } + + if (this.HasMemoryAddress) + { + IntPtr ptr = this.AddressOfPinnedMemory(); + if (ptr != IntPtr.Zero) + { + return new StringCharSequence(UnsafeByteBufferUtil.GetString((byte*)(ptr + index), length, encoding)); + } + else + { + fixed (byte* p = &this.GetPinnableMemoryAddress()) + return new StringCharSequence(UnsafeByteBufferUtil.GetString(p + index, length, encoding)); + } + } + if (this.HasArray) + { + return new StringCharSequence(encoding.GetString(this.Array, this.ArrayOffset + index, length)); + } + + return new StringCharSequence(this.ToString(index, length, encoding)); + } + + public virtual ICharSequence ReadCharSequence(int length, Encoding encoding) + { + ICharSequence sequence = this.GetCharSequence(this.readerIndex, length, encoding); + this.readerIndex += length; + return sequence; + } + public virtual IByteBuffer SetByte(int index, int value) { this.CheckIndex(index); @@ -633,6 +669,7 @@ namespace DotNetty.Buffers this.SetBytes(index, src, 0, src.Length); return this; } + public abstract IByteBuffer SetBytes(int index, byte[] src, int srcIndex, int length); public virtual IByteBuffer SetBytes(int index, IByteBuffer src) @@ -744,6 +781,48 @@ namespace DotNetty.Buffers return bytes.Length; } + public virtual int SetCharSequence(int index, ICharSequence sequence, Encoding encoding) => this.SetCharSequence0(index, sequence, encoding, false); + + int SetCharSequence0(int index, ICharSequence sequence, Encoding encoding, bool expand) + { + if (ReferenceEquals(encoding, Encoding.UTF8)) + { + int length = ByteBufferUtil.Utf8MaxBytes(sequence); + if (expand) + { + this.EnsureWritable0(length); + this.CheckIndex0(index, length); + } + else + { + this.CheckIndex(index, length); + } + return ByteBufferUtil.WriteUtf8(this, index, sequence, sequence.Count); + } + if (ReferenceEquals(encoding, Encoding.ASCII)) + { + int length = sequence.Count; + if (expand) + { + this.EnsureWritable0(length); + this.CheckIndex0(index, length); + } + else + { + this.CheckIndex(index, length); + } + return ByteBufferUtil.WriteAscii(this, index, sequence, length); + } + byte[] bytes = encoding.GetBytes(sequence.ToString()); + if (expand) + { + this.EnsureWritable0(bytes.Length); + // setBytes(...) will take care of checking the indices. + } + this.SetBytes(index, bytes); + return bytes.Length; + } + public virtual byte ReadByte() { this.CheckReadableBytes0(1); @@ -1184,6 +1263,13 @@ namespace DotNetty.Buffers return this; } + public virtual int WriteCharSequence(ICharSequence sequence, Encoding encoding) + { + int written = this.SetCharSequence0(this.writerIndex, sequence, encoding, true); + this.writerIndex += written; + return written; + } + public virtual int WriteString(string value, Encoding encoding) { int written = this.SetString0(this.writerIndex, value, encoding, true); diff --git a/src/DotNetty.Buffers/AbstractPooledDerivedByteBuffer.cs b/src/DotNetty.Buffers/AbstractPooledDerivedByteBuffer.cs index 9f18248..2544bb8 100644 --- a/src/DotNetty.Buffers/AbstractPooledDerivedByteBuffer.cs +++ b/src/DotNetty.Buffers/AbstractPooledDerivedByteBuffer.cs @@ -3,6 +3,7 @@ namespace DotNetty.Buffers { + using System; using System.Diagnostics; using DotNetty.Common; diff --git a/src/DotNetty.Buffers/ByteBufferUtil.cs b/src/DotNetty.Buffers/ByteBufferUtil.cs index 3ee6501..21b2604 100644 --- a/src/DotNetty.Buffers/ByteBufferUtil.cs +++ b/src/DotNetty.Buffers/ByteBufferUtil.cs @@ -5,6 +5,7 @@ namespace DotNetty.Buffers { using System; using System.Diagnostics.Contracts; + using System.Runtime.CompilerServices; using System.Text; using DotNetty.Common.Internal; using DotNetty.Common.Internal.Logging; @@ -68,8 +69,8 @@ namespace DotNetty.Buffers /// public static string HexDump(byte[] array, int fromIndex, int length) => HexUtil.DoHexDump(array, fromIndex, length); - public static bool EnsureWritableSuccess(int ensureWritableResult) => ensureWritableResult == 0 || ensureWritableResult == 2; - + public static bool EnsureWritableSuccess(int ensureWritableResult) => ensureWritableResult == 0 || ensureWritableResult == 2; + /// /// Calculates the hash code of the specified buffer. This method is /// useful when implementing a new buffer type. @@ -298,6 +299,104 @@ namespace DotNetty.Buffers return buffer.ForEachByteDesc(toIndex, fromIndex - toIndex, new IndexOfProcessor(value)); } + public static IByteBuffer WriteUtf8(IByteBufferAllocator alloc, ICharSequence seq) + { + // UTF-8 uses max. 3 bytes per char, so calculate the worst case. + IByteBuffer buf = alloc.Buffer(Utf8MaxBytes(seq)); + WriteUtf8(buf, seq); + return buf; + } + + public static int WriteUtf8(IByteBuffer buf, ICharSequence seq) => ReserveAndWriteUtf8(buf, seq, Utf8MaxBytes(seq)); + + public static int ReserveAndWriteUtf8(IByteBuffer buf, ICharSequence seq, int reserveBytes) + { + for (;;) + { + if (buf is AbstractByteBuffer byteBuf) + { + byteBuf.EnsureWritable0(reserveBytes); + int written = WriteUtf8(byteBuf, byteBuf.WriterIndex, seq, seq.Count); + byteBuf.SetWriterIndex(byteBuf.WriterIndex + written); + return written; + } + else if (buf is WrappedByteBuffer) + { + // Unwrap as the wrapped buffer may be an AbstractByteBuf and so we can use fast-path. + buf = buf.Unwrap(); + } + else + { + byte[] bytes = Encoding.UTF8.GetBytes(seq.ToString()); + buf.WriteBytes(bytes); + return bytes.Length; + } + } + } + + // Fast-Path implementation + internal static int WriteUtf8(AbstractByteBuffer buffer, int writerIndex, ICharSequence value, int len) + { + int oldWriterIndex = writerIndex; + + // We can use the _set methods as these not need to do any index checks and reference checks. + // This is possible as we called ensureWritable(...) before. + for (int i = 0; i < len; i++) + { + char c = value[i]; + if (c < 0x80) + { + buffer._SetByte(writerIndex++, (byte)c); + } + else if (c < 0x800) + { + buffer._SetByte(writerIndex++, (byte)(0xc0 | (c >> 6))); + buffer._SetByte(writerIndex++, (byte)(0x80 | (c & 0x3f))); + } + else if (char.IsSurrogate(c)) + { + if (!char.IsHighSurrogate(c)) + { + buffer._SetByte(writerIndex++, WriteUtfUnknown); + continue; + } + char c2; + try + { + // Surrogate Pair consumes 2 characters. Optimistically try to get the next character to avoid + // duplicate bounds checking with charAt. If an IndexOutOfBoundsException is thrown we will + // re-throw a more informative exception describing the problem. + c2 = value[++i]; + } + catch (IndexOutOfRangeException) + { + buffer._SetByte(writerIndex++, WriteUtfUnknown); + break; + } + if (!char.IsLowSurrogate(c2)) + { + buffer._SetByte(writerIndex++, WriteUtfUnknown); + buffer._SetByte(writerIndex++, char.IsHighSurrogate(c2) ? WriteUtfUnknown : c2); + continue; + } + int codePoint = CharUtil.ToCodePoint(c, c2); + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + buffer._SetByte(writerIndex++, (byte)(0xf0 | (codePoint >> 18))); + buffer._SetByte(writerIndex++, (byte)(0x80 | ((codePoint >> 12) & 0x3f))); + buffer._SetByte(writerIndex++, (byte)(0x80 | ((codePoint >> 6) & 0x3f))); + buffer._SetByte(writerIndex++, (byte)(0x80 | (codePoint & 0x3f))); + } + else + { + buffer._SetByte(writerIndex++, (byte)(0xe0 | (c >> 12))); + buffer._SetByte(writerIndex++, (byte)(0x80 | ((c >> 6) & 0x3f))); + buffer._SetByte(writerIndex++, (byte)(0x80 | (c & 0x3f))); + } + } + + return writerIndex - oldWriterIndex; + } + public static IByteBuffer WriteUtf8(IByteBufferAllocator alloc, string value) { // UTF-8 uses max. 3 bytes per char, so calculate the worst case. @@ -403,6 +502,8 @@ namespace DotNetty.Buffers return writerIndex - oldWriterIndex; } + internal static int Utf8MaxBytes(ICharSequence seq) => Utf8MaxBytes(seq.Count); + internal static int Utf8MaxBytes(string seq) => Utf8MaxBytes(seq.Length); internal static int Utf8MaxBytes(int seqLength) => seqLength * MaxBytesPerCharUtf8; @@ -470,6 +571,61 @@ namespace DotNetty.Buffers return encodedLength; } + public static IByteBuffer WriteAscii(IByteBufferAllocator alloc, ICharSequence seq) + { + // ASCII uses 1 byte per char + IByteBuffer buf = alloc.Buffer(seq.Count); + WriteAscii(buf, seq); + return buf; + } + + public static int WriteAscii(IByteBuffer buf, ICharSequence seq) + { + // ASCII uses 1 byte per char + int len = seq.Count; + if (seq is AsciiString asciiString) + { + buf.WriteBytes(asciiString.Array, asciiString.Offset, len); + } + else + { + for (;;) + { + if (buf is AbstractByteBuffer byteBuf) + { + byteBuf.EnsureWritable0(len); + int written = WriteAscii(byteBuf, byteBuf.WriterIndex, seq, len); + byteBuf.SetWriterIndex(byteBuf.WriterIndex + written); + return written; + } + else if (buf is WrappedByteBuffer) + { + // Unwrap as the wrapped buffer may be an AbstractByteBuf and so we can use fast-path. + buf = buf.Unwrap(); + } + else + { + byte[] bytes = Encoding.ASCII.GetBytes(seq.ToString()); + buf.WriteBytes(bytes); + return bytes.Length; + } + } + } + return len; + } + + // Fast-Path implementation + internal static int WriteAscii(AbstractByteBuffer buffer, int writerIndex, ICharSequence seq, int len) + { + // We can use the _set methods as these not need to do any index checks and reference checks. + // This is possible as we called ensureWritable(...) before. + for (int i = 0; i < len; i++) + { + buffer._SetByte(writerIndex++, AsciiString.CharToByte(seq[i])); + } + return len; + } + public static IByteBuffer WriteAscii(IByteBufferAllocator alloc, string value) { // ASCII uses 1 byte per char @@ -592,6 +748,38 @@ namespace DotNetty.Buffers } } + public static void Copy(AsciiString src, IByteBuffer dst) => Copy(src, 0, dst, src.Count); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Copy(AsciiString src, int srcIdx, IByteBuffer dst, int dstIdx, int length) + { + if (MathUtil.IsOutOfBounds(srcIdx, length, src.Count)) + { + ThrowHelper.ThrowIndexOutOfRangeException_Src(srcIdx, length, src.Count); + } + if (dst == null) + { + ThrowHelper.ThrowArgumentNullException_Dst(); + } + // ReSharper disable once PossibleNullReferenceException + dst.SetBytes(dstIdx, src.Array, srcIdx + src.Offset, length); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Copy(AsciiString src, int srcIdx, IByteBuffer dst, int length) + { + if (MathUtil.IsOutOfBounds(srcIdx, length, src.Count)) + { + ThrowHelper.ThrowIndexOutOfRangeException_Src(srcIdx, length, src.Count); + } + if (dst == null) + { + ThrowHelper.ThrowArgumentNullException_Dst(); + } + // ReSharper disable once PossibleNullReferenceException + dst.WriteBytes(src.Array, srcIdx + src.Offset, length); + } + /// /// Returns a multi-line hexadecimal dump of the specified {@link ByteBuf} that is easy to read by humans. /// diff --git a/src/DotNetty.Buffers/EmptyByteBuffer.cs b/src/DotNetty.Buffers/EmptyByteBuffer.cs index f65941c..c4e2316 100644 --- a/src/DotNetty.Buffers/EmptyByteBuffer.cs +++ b/src/DotNetty.Buffers/EmptyByteBuffer.cs @@ -156,6 +156,12 @@ namespace DotNetty.Buffers public IByteBuffer GetBytes(int index, Stream destination, int length) => this.CheckIndex(index, length); + public ICharSequence GetCharSequence(int index, int length, Encoding encoding) + { + this.CheckIndex(index, length); + return null; + } + public string GetString(int index, int length, Encoding encoding) { this.CheckIndex(index, length); @@ -218,6 +224,8 @@ namespace DotNetty.Buffers public IByteBuffer SetZero(int index, int length) => this.CheckIndex(index, length); + public int SetCharSequence(int index, ICharSequence sequence, Encoding encoding) => throw new IndexOutOfRangeException(); + public int SetString(int index, string value, Encoding encoding) => throw new IndexOutOfRangeException(); public bool ReadBoolean() => throw new IndexOutOfRangeException(); @@ -276,6 +284,12 @@ namespace DotNetty.Buffers public IByteBuffer ReadBytes(Stream destination, int length) => this.CheckLength(length); + public ICharSequence ReadCharSequence(int length, Encoding encoding) + { + this.CheckLength(length); + return null; + } + public string ReadString(int length, Encoding encoding) { this.CheckLength(length); @@ -338,6 +352,8 @@ namespace DotNetty.Buffers public IByteBuffer WriteZero(int length) => this.CheckLength(length); + public int WriteCharSequence(ICharSequence sequence, Encoding encoding) => throw new IndexOutOfRangeException(); + public int WriteString(string value, Encoding encoding) => throw new IndexOutOfRangeException(); public int IndexOf(int fromIndex, int toIndex, byte value) diff --git a/src/DotNetty.Buffers/IByteBuffer.cs b/src/DotNetty.Buffers/IByteBuffer.cs index de4b8d7..88bf344 100644 --- a/src/DotNetty.Buffers/IByteBuffer.cs +++ b/src/DotNetty.Buffers/IByteBuffer.cs @@ -483,6 +483,9 @@ namespace DotNetty.Buffers /// IByteBuffer GetBytes(int index, Stream destination, int length); + + ICharSequence GetCharSequence(int index, int length, Encoding encoding); + /// /// Gets a string with the given length at the given index. /// @@ -781,6 +784,8 @@ namespace DotNetty.Buffers /// IByteBuffer SetZero(int index, int length); + int SetCharSequence(int index, ICharSequence sequence, Encoding encoding); + /// /// Writes the specified string at the current writer index and increases /// the writer index by the written bytes. @@ -973,6 +978,8 @@ namespace DotNetty.Buffers IByteBuffer ReadBytes(Stream destination, int length); + ICharSequence ReadCharSequence(int length, Encoding encoding); + /// /// Gets a string with the given length at the current reader index /// and increases the reader index by the given length. @@ -1186,6 +1193,8 @@ namespace DotNetty.Buffers IByteBuffer WriteZero(int length); + int WriteCharSequence(ICharSequence sequence, Encoding encoding); + int WriteString(string value, Encoding encoding); int IndexOf(int fromIndex, int toIndex, byte value); diff --git a/src/DotNetty.Buffers/IByteBufferAllocator.cs b/src/DotNetty.Buffers/IByteBufferAllocator.cs index 548106b..484af2e 100644 --- a/src/DotNetty.Buffers/IByteBufferAllocator.cs +++ b/src/DotNetty.Buffers/IByteBufferAllocator.cs @@ -4,7 +4,7 @@ namespace DotNetty.Buffers { /// - /// Thread-safe interface for allocating instances for use inside Helios reactive I/O + /// Thread-safe interface for allocating /. /// public interface IByteBufferAllocator { diff --git a/src/DotNetty.Buffers/PoolArena.cs b/src/DotNetty.Buffers/PoolArena.cs index 72cc6d8..a5de2e0 100644 --- a/src/DotNetty.Buffers/PoolArena.cs +++ b/src/DotNetty.Buffers/PoolArena.cs @@ -648,7 +648,7 @@ namespace DotNetty.Buffers .Append(i) .Append(": "); PoolSubpage s = head.Next; - for (; ;) + for (; ; ) { buf.Append(s); s = s.Next; @@ -706,7 +706,7 @@ namespace DotNetty.Buffers // Rely on GC. } - protected override PooledByteBuffer NewByteBuf(int maxCapacity) => + protected override PooledByteBuffer NewByteBuf(int maxCapacity) => PooledHeapByteBuffer.NewInstance(maxCapacity); protected override void MemoryCopy(byte[] src, int srcOffset, byte[] dst, int dstOffset, int length) @@ -751,7 +751,7 @@ namespace DotNetty.Buffers return chunk; } - protected override PooledByteBuffer NewByteBuf(int maxCapacity) => + protected override PooledByteBuffer NewByteBuf(int maxCapacity) => PooledUnsafeDirectByteBuffer.NewInstance(maxCapacity); protected override unsafe void MemoryCopy(byte[] src, int srcOffset, byte[] dst, int dstOffset, int length) => diff --git a/src/DotNetty.Buffers/PooledByteBuffer.cs b/src/DotNetty.Buffers/PooledByteBuffer.cs index 9c8e191..8192b0c 100644 --- a/src/DotNetty.Buffers/PooledByteBuffer.cs +++ b/src/DotNetty.Buffers/PooledByteBuffer.cs @@ -59,7 +59,11 @@ namespace DotNetty.Buffers this.DiscardMarks(); } - public override int Capacity => this.Length; + public override int Capacity + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => this.Length; + } public sealed override IByteBuffer AdjustCapacity(int newCapacity) { diff --git a/src/DotNetty.Buffers/PooledDuplicatedByteBuffer.cs b/src/DotNetty.Buffers/PooledDuplicatedByteBuffer.cs index d3d51e4..c83eee4 100644 --- a/src/DotNetty.Buffers/PooledDuplicatedByteBuffer.cs +++ b/src/DotNetty.Buffers/PooledDuplicatedByteBuffer.cs @@ -36,7 +36,7 @@ namespace DotNetty.Buffers return this; } - public override int ArrayOffset => this.Unwrap().ArrayOffset; + public override int ArrayOffset => this.Unwrap().ArrayOffset; public override ref byte GetPinnableMemoryAddress() => ref this.Unwrap().GetPinnableMemoryAddress(); @@ -103,4 +103,4 @@ namespace DotNetty.Buffers protected internal override void _SetLongLE(int index, long value) => this.UnwrapCore()._SetLongLE(index, value); } -} +} \ No newline at end of file diff --git a/src/DotNetty.Buffers/PooledSlicedByteBuffer.cs b/src/DotNetty.Buffers/PooledSlicedByteBuffer.cs index 3a58227..f6277ec 100644 --- a/src/DotNetty.Buffers/PooledSlicedByteBuffer.cs +++ b/src/DotNetty.Buffers/PooledSlicedByteBuffer.cs @@ -41,7 +41,7 @@ namespace DotNetty.Buffers public override int Capacity => this.MaxCapacity; - public override IByteBuffer AdjustCapacity(int newCapacity) =>throw new NotSupportedException("sliced buffer"); + public override IByteBuffer AdjustCapacity(int newCapacity) => throw new NotSupportedException("sliced buffer"); public override int ArrayOffset => this.Idx(this.Unwrap().ArrayOffset); @@ -308,4 +308,4 @@ namespace DotNetty.Buffers int Idx(int index) => index + this.adjustment; } -} +} \ No newline at end of file diff --git a/src/DotNetty.Buffers/ThrowHelper.cs b/src/DotNetty.Buffers/ThrowHelper.cs index cbd68a7..fc54868 100644 --- a/src/DotNetty.Buffers/ThrowHelper.cs +++ b/src/DotNetty.Buffers/ThrowHelper.cs @@ -143,6 +143,17 @@ namespace DotNetty.Buffers } } + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowIndexOutOfRangeException_Src(int srcIndex, int length, int count) + { + throw GetIndexOutOfRangeException(); + + IndexOutOfRangeException GetIndexOutOfRangeException() + { + return new IndexOutOfRangeException(string.Format("expected: 0 <= srcIdx({0}) <= srcIdx + length({1}) <= srcLen({2})", srcIndex, length, count)); + } + } + [MethodImpl(MethodImplOptions.NoInlining)] internal static void ThrowIllegalReferenceCountException(int count) { @@ -230,5 +241,16 @@ namespace DotNetty.Buffers return new ArgumentOutOfRangeException("newCapacity", string.Format($"newCapacity: {0} (expected: 0-{1})", newCapacity, maxCapacity)); } } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentNullException_Dst() + { + throw GetArgumentOutOfRangeException(); + + ArgumentNullException GetArgumentOutOfRangeException() + { + return new ArgumentNullException("dst"); + } + } } } diff --git a/src/DotNetty.Buffers/UnpooledUnsafeDirectByteBuffer.cs b/src/DotNetty.Buffers/UnpooledUnsafeDirectByteBuffer.cs index 27bb7c8..9da1340 100644 --- a/src/DotNetty.Buffers/UnpooledUnsafeDirectByteBuffer.cs +++ b/src/DotNetty.Buffers/UnpooledUnsafeDirectByteBuffer.cs @@ -356,4 +356,4 @@ namespace DotNetty.Buffers return this; } } -} +} \ No newline at end of file diff --git a/src/DotNetty.Buffers/UnsafeByteBufferUtil .cs b/src/DotNetty.Buffers/UnsafeByteBufferUtil.cs similarity index 97% rename from src/DotNetty.Buffers/UnsafeByteBufferUtil .cs rename to src/DotNetty.Buffers/UnsafeByteBufferUtil.cs index cea24dc..28fa677 100644 --- a/src/DotNetty.Buffers/UnsafeByteBufferUtil .cs +++ b/src/DotNetty.Buffers/UnsafeByteBufferUtil.cs @@ -24,28 +24,28 @@ namespace DotNetty.Buffers [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static int GetUnsignedMedium(byte* bytes) => - *bytes << 16 | - *(bytes + 1) << 8 | + *bytes << 16 | + *(bytes + 1) << 8 | *(bytes + 2); [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static int GetUnsignedMediumLE(byte* bytes) => - *bytes | - *(bytes + 1) << 8 | + *bytes | + *(bytes + 1) << 8 | *(bytes + 2) << 16; [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static int GetInt(byte* bytes) => - (*bytes << 24) | - (*(bytes + 1) << 16) | - (*(bytes + 2) << 8) | + (*bytes << 24) | + (*(bytes + 1) << 16) | + (*(bytes + 2) << 8) | (*(bytes + 3)); [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static int GetIntLE(byte* bytes) => - *bytes | + *bytes | (*(bytes + 1) << 8) | - (*(bytes + 2) << 16) | + (*(bytes + 2) << 16) | (*(bytes + 3) << 24); [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -365,4 +365,4 @@ namespace DotNetty.Buffers internal static UnpooledUnsafeDirectByteBuffer NewUnsafeDirectByteBuffer(IByteBufferAllocator alloc, int initialCapacity, int maxCapacity) => new UnpooledUnsafeDirectByteBuffer(alloc, initialCapacity, maxCapacity); } -} +} \ No newline at end of file diff --git a/src/DotNetty.Buffers/WrappedByteBuffer.cs b/src/DotNetty.Buffers/WrappedByteBuffer.cs index 4fd2007..4fa7068 100644 --- a/src/DotNetty.Buffers/WrappedByteBuffer.cs +++ b/src/DotNetty.Buffers/WrappedByteBuffer.cs @@ -210,6 +210,8 @@ namespace DotNetty.Buffers return this; } + public ICharSequence GetCharSequence(int index, int length, Encoding encoding) => this.Buf.GetCharSequence(index, length, encoding); + public string GetString(int index, int length, Encoding encoding) => this.Buf.GetString(index, length, encoding); public virtual IByteBuffer SetBoolean(int index, bool value) @@ -350,6 +352,8 @@ namespace DotNetty.Buffers return this; } + public int SetCharSequence(int index, ICharSequence sequence, Encoding encoding) => this.Buf.SetCharSequence(index, sequence, encoding); + public virtual bool ReadBoolean() => this.Buf.ReadBoolean(); public virtual byte ReadByte() => this.Buf.ReadByte(); @@ -436,6 +440,8 @@ namespace DotNetty.Buffers return this; } + public ICharSequence ReadCharSequence(int length, Encoding encoding) => this.Buf.ReadCharSequence(length, encoding); + public string ReadString(int length, Encoding encoding) => this.Buf.ReadString(length, encoding); public virtual IByteBuffer SkipBytes(int length) @@ -576,6 +582,8 @@ namespace DotNetty.Buffers return this; } + public int WriteCharSequence(ICharSequence sequence, Encoding encoding) => this.Buf.WriteCharSequence(sequence, encoding); + public int WriteString(string value, Encoding encoding) => this.Buf.WriteString(value, encoding); public virtual int IndexOf(int fromIndex, int toIndex, byte value) => this.Buf.IndexOf(fromIndex, toIndex, value); diff --git a/src/DotNetty.Buffers/WrappedCompositeByteBuffer.cs b/src/DotNetty.Buffers/WrappedCompositeByteBuffer.cs index 362ceb4..07dff4b 100644 --- a/src/DotNetty.Buffers/WrappedCompositeByteBuffer.cs +++ b/src/DotNetty.Buffers/WrappedCompositeByteBuffer.cs @@ -522,6 +522,12 @@ namespace DotNetty.Buffers return this; } + public override ICharSequence GetCharSequence(int index, int length, Encoding encoding) => this.wrapped.GetCharSequence(index, length, encoding); + + public override ICharSequence ReadCharSequence(int length, Encoding encoding) => this.wrapped.ReadCharSequence(length, encoding); + + public override int SetCharSequence(int index, ICharSequence sequence, Encoding encoding) => this.wrapped.SetCharSequence(index, sequence, encoding); + public override string GetString(int index, int length, Encoding encoding) => this.wrapped.GetString(index, length, encoding); public override string ReadString(int length, Encoding encoding) => this.wrapped.ReadString(length, encoding); @@ -530,6 +536,8 @@ namespace DotNetty.Buffers public override IByteBuffer ReadBytes(Stream destination, int length) => this.wrapped.ReadBytes(destination, length); + public override int WriteCharSequence(ICharSequence sequence, Encoding encoding) => this.wrapped.WriteCharSequence(sequence, encoding); + public override int WriteString(string value, Encoding encoding) => this.wrapped.WriteString(value, encoding); public override IByteBuffer SkipBytes(int length) diff --git a/src/DotNetty.Codecs.Http/CombinedHttpHeaders.cs b/src/DotNetty.Codecs.Http/CombinedHttpHeaders.cs new file mode 100644 index 0000000..af460a0 --- /dev/null +++ b/src/DotNetty.Codecs.Http/CombinedHttpHeaders.cs @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Linq; + using DotNetty.Common.Utilities; + + using static Common.Utilities.StringUtil; + + public sealed class CombinedHttpHeaders : DefaultHttpHeaders + { + public CombinedHttpHeaders(bool validate) + : base(new CombinedHttpHeadersImpl(AsciiString.CaseSensitiveHasher, ValueConverter(validate), NameValidator(validate))) + { + } + + public override bool ContainsValue(AsciiString name, ICharSequence value, bool ignoreCase) => + base.ContainsValue(name, TrimOws(value), ignoreCase); + + sealed class CombinedHttpHeadersImpl : DefaultHeaders + { + // An estimate of the size of a header value. + const int ValueLengthEstimate = 10; + + public CombinedHttpHeadersImpl(IHashingStrategy nameHashingStrategy, + IValueConverter valueConverter, INameValidator nameValidator) + : base(nameHashingStrategy, valueConverter, nameValidator) + { + } + + public override IEnumerable ValueIterator(AsciiString name) + { + ICharSequence value = null; + foreach (ICharSequence v in base.ValueIterator(name)) + { + if (value != null) + { + throw new InvalidOperationException($"{nameof(CombinedHttpHeaders)} should only have one value"); + } + value = v; + } + return value != null ? UnescapeCsvFields(value) : Enumerable.Empty(); + } + + public override IList GetAll(AsciiString name) + { + IList values = base.GetAll(name); + if (values.Count == 0) + { + return values; + } + if (values.Count != 1) + { + throw new InvalidOperationException($"{nameof(CombinedHttpHeaders)} should only have one value"); + } + + return UnescapeCsvFields(values[0]); + } + + public override IHeaders Add(IHeaders headers) + { + // Override the fast-copy mechanism used by DefaultHeaders + if (ReferenceEquals(headers, this)) + { + throw new ArgumentException("can't add to itself."); + } + + if (headers is CombinedHttpHeadersImpl) + { + if (this.IsEmpty) + { + // Can use the fast underlying copy + this.AddImpl(headers); + } + else + { + // Values are already escaped so don't escape again + foreach (HeaderEntry header in headers) + { + this.AddEscapedValue(header.Key, header.Value); + } + } + } + else + { + foreach (HeaderEntry header in headers) + { + this.Add(header.Key, header.Value); + } + } + + return this; + } + + public override IHeaders Set(IHeaders headers) + { + if (ReferenceEquals(headers, this)) + { + return this; + } + this.Clear(); + return this.Add(headers); + } + + public override IHeaders SetAll(IHeaders headers) + { + if (ReferenceEquals(headers, this)) + { + return this; + } + foreach (AsciiString key in headers.Names()) + { + this.Remove(key); + } + return this.Add(headers); + } + + public override IHeaders Add(AsciiString name, ICharSequence value) => + this.AddEscapedValue(name, EscapeCsv(value)); + + public override IHeaders Add(AsciiString name, IEnumerable values) => + this.AddEscapedValue(name, CommaSeparate(values)); + + public override IHeaders AddObject(AsciiString name, object value) => + this.AddEscapedValue(name, EscapeCsv(this.ValueConverter.ConvertObject(value))); + + public override IHeaders AddObject(AsciiString name, IEnumerable values) => + this.AddEscapedValue(name, this.CommaSeparate(values)); + + public override IHeaders AddObject(AsciiString name, params object[] values) => + this.AddEscapedValue(name, this.CommaSeparate(values)); + + public override IHeaders Set(AsciiString name, IEnumerable values) + { + base.Set(name, CommaSeparate(values)); + return this; + } + + public override IHeaders SetObject(AsciiString name, object value) + { + ICharSequence charSequence = EscapeCsv(this.ValueConverter.ConvertObject(value)); + base.Set(name, charSequence); + return this; + } + + public override IHeaders SetObject(AsciiString name, IEnumerable values) + { + base.Set(name, this.CommaSeparate(values)); + return this; + } + + CombinedHttpHeadersImpl AddEscapedValue(AsciiString name, ICharSequence escapedValue) + { + if (!this.TryGet(name, out ICharSequence currentValue)) + { + base.Add(name, escapedValue); + } + else + { + base.Set(name, CommaSeparateEscapedValues(currentValue, escapedValue)); + } + + return this; + } + + ICharSequence CommaSeparate(IEnumerable values) + { + StringBuilderCharSequence sb = values is ICollection collection + ? new StringBuilderCharSequence(collection.Count * ValueLengthEstimate) + : new StringBuilderCharSequence(); + + foreach (object value in values) + { + if (sb.Count > 0) + { + sb.Append(Comma); + } + + sb.Append(EscapeCsv(this.ValueConverter.ConvertObject(value))); + } + + return sb; + } + + static ICharSequence CommaSeparate(IEnumerable values) + { + StringBuilderCharSequence sb = values is ICollection collection + ? new StringBuilderCharSequence(collection.Count * ValueLengthEstimate) + : new StringBuilderCharSequence(); + + foreach (ICharSequence value in values) + { + if (sb.Count > 0) + { + sb.Append(Comma); + } + + sb.Append(EscapeCsv(value)); + } + + return sb; + } + + static ICharSequence CommaSeparateEscapedValues(ICharSequence currentValue, ICharSequence value) + { + var builder = new StringBuilderCharSequence(currentValue.Count + 1 + value.Count); + builder.Append(currentValue); + builder.Append(Comma); + builder.Append(value); + + return builder; + } + + static ICharSequence EscapeCsv(ICharSequence value) => StringUtil.EscapeCsv(value, true); + } + } +} diff --git a/src/DotNetty.Codecs.Http/ComposedLastHttpContent.cs b/src/DotNetty.Codecs.Http/ComposedLastHttpContent.cs new file mode 100644 index 0000000..adf0f17 --- /dev/null +++ b/src/DotNetty.Codecs.Http/ComposedLastHttpContent.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using DotNetty.Buffers; + using DotNetty.Common; + + public sealed class ComposedLastHttpContent : ILastHttpContent + { + readonly HttpHeaders trailingHeaders; + DecoderResult result; + + internal ComposedLastHttpContent(HttpHeaders trailingHeaders) + { + this.trailingHeaders = trailingHeaders; + } + + public HttpHeaders TrailingHeaders => this.trailingHeaders; + + public IByteBufferHolder Copy() + { + var content = new DefaultLastHttpContent(Unpooled.Empty); + content.TrailingHeaders.Set(this.trailingHeaders); + return content; + } + + public IByteBufferHolder Duplicate() => this.Copy(); + + public IByteBufferHolder RetainedDuplicate() => this.Copy(); + + public IByteBufferHolder Replace(IByteBuffer content) + { + var dup = new DefaultLastHttpContent(content); + dup.TrailingHeaders.SetAll(this.trailingHeaders); + return dup; + } + + public IReferenceCounted Retain() => this; + + public IReferenceCounted Retain(int increment) => this; + + public IReferenceCounted Touch() => this; + + public IReferenceCounted Touch(object hint) => this; + + public IByteBuffer Content => Unpooled.Empty; + + public DecoderResult Result + { + get => this.result; + set => this.result = value; + } + + public int ReferenceCount => 1; + + public bool Release() => false; + + public bool Release(int decrement) => false; + } +} diff --git a/src/DotNetty.Codecs.Http/Cookies/ClientCookieDecoder.cs b/src/DotNetty.Codecs.Http/Cookies/ClientCookieDecoder.cs new file mode 100644 index 0000000..d0630c0 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cookies/ClientCookieDecoder.cs @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Cookies +{ + using System; + using System.Diagnostics; + using System.Diagnostics.Contracts; + using DotNetty.Common.Utilities; + + public sealed class ClientCookieDecoder : CookieDecoder + { + // Strict encoder that validates that name and value chars are in the valid scope + // defined in RFC6265 + public static readonly ClientCookieDecoder StrictDecoder = new ClientCookieDecoder(true); + + // Lax instance that doesn't validate name and value + public static readonly ClientCookieDecoder LaxDecoder = new ClientCookieDecoder(false); + + ClientCookieDecoder(bool strict) : base(strict) + { + } + + public ICookie Decode(string header) + { + Contract.Requires(header != null); + + int headerLen = header.Length; + if (headerLen == 0) + { + return null; + } + + CookieBuilder cookieBuilder = null; + //loop: + for (int i = 0;;) + { + + // Skip spaces and separators. + for (;;) + { + if (i == headerLen) + { + goto loop; + } + char c = header[i]; + if (c == ',') + { + // Having multiple cookies in a single Set-Cookie header is + // deprecated, modern browsers only parse the first one + goto loop; + + } + else if (c == '\t' || c == '\n' || c == 0x0b || c == '\f' + || c == '\r' || c == ' ' || c == ';') + { + i++; + continue; + } + break; + } + + int nameBegin = i; + int nameEnd; + int valueBegin; + int valueEnd; + + for (;;) + { + char curChar = header[i]; + if (curChar == ';') + { + // NAME; (no value till ';') + nameEnd = i; + valueBegin = valueEnd = -1; + break; + + } + else if (curChar == '=') + { + // NAME=VALUE + nameEnd = i; + i++; + if (i == headerLen) + { + // NAME= (empty value, i.e. nothing after '=') + valueBegin = valueEnd = 0; + break; + } + + valueBegin = i; + // NAME=VALUE; + int semiPos = header.IndexOf(';', i); + valueEnd = i = semiPos > 0 ? semiPos : headerLen; + break; + } + else + { + i++; + } + + if (i == headerLen) + { + // NAME (no value till the end of string) + nameEnd = headerLen; + valueBegin = valueEnd = -1; + break; + } + } + + if (valueEnd > 0 && header[valueEnd - 1] == ',') + { + // old multiple cookies separator, skipping it + valueEnd--; + } + + if (cookieBuilder == null) + { + // cookie name-value pair + DefaultCookie cookie = this.InitCookie(header, nameBegin, nameEnd, valueBegin, valueEnd); + + if (cookie == null) + { + return null; + } + + cookieBuilder = new CookieBuilder(cookie, header); + } + else + { + // cookie attribute + cookieBuilder.AppendAttribute(nameBegin, nameEnd, valueBegin, valueEnd); + } + } + + loop: + Debug.Assert(cookieBuilder != null); + return cookieBuilder.Cookie(); + } + + sealed class CookieBuilder + { + readonly string header; + readonly DefaultCookie cookie; + string domain; + string path; + long maxAge = long.MinValue; + int expiresStart; + int expiresEnd; + bool secure; + bool httpOnly; + + internal CookieBuilder(DefaultCookie cookie, string header) + { + this.cookie = cookie; + this.header = header; + } + + long MergeMaxAgeAndExpires() + { + // max age has precedence over expires + if (this.maxAge != long.MinValue) + { + return this.maxAge; + } + else if (IsValueDefined(this.expiresStart, this.expiresEnd)) + { + DateTime? expiresDate = DateFormatter.ParseHttpDate(this.header, this.expiresStart, this.expiresEnd); + if (expiresDate != null) + { + return (expiresDate.Value.Ticks - DateTime.UtcNow.Ticks) / TimeSpan.TicksPerSecond; + } + } + return long.MinValue; + } + + internal ICookie Cookie() + { + this.cookie.Domain = this.domain; + this.cookie.Path = this.path; + this.cookie.MaxAge = this.MergeMaxAgeAndExpires(); + this.cookie.IsSecure = this.secure; + this.cookie.IsHttpOnly = this.httpOnly; + + return this.cookie; + } + + public void AppendAttribute(int keyStart, int keyEnd, int valueStart, int valueEnd) + { + int length = keyEnd - keyStart; + + if (length == 4) + { + this.Parse4(keyStart, valueStart, valueEnd); + } + else if (length == 6) + { + this.Parse6(keyStart, valueStart, valueEnd); + } + else if (length == 7) + { + this.Parse7(keyStart, valueStart, valueEnd); + } + else if (length == 8) + { + this.Parse8(keyStart); + } + } + + void Parse4(int nameStart, int valueStart, int valueEnd) + { + if (CharUtil.RegionMatchesIgnoreCase(this.header, nameStart, CookieHeaderNames.Path, 0, 4)) + { + this.path = this.ComputeValue(valueStart, valueEnd); + } + } + + void Parse6(int nameStart, int valueStart, int valueEnd) + { + if (CharUtil.RegionMatchesIgnoreCase(this.header, nameStart, CookieHeaderNames.Domain, 0, 5)) + { + this.domain = this.ComputeValue(valueStart, valueEnd); + } + else if (CharUtil.RegionMatchesIgnoreCase(this.header, nameStart, CookieHeaderNames.Secure, 0, 5)) + { + this.secure = true; + } + } + + void SetMaxAge(string value) + { + if (long.TryParse(value, out long v)) + { + this.maxAge = Math.Max(v, 0); + } + } + + void Parse7(int nameStart, int valueStart, int valueEnd) + { + if (CharUtil.RegionMatchesIgnoreCase(this.header, nameStart, CookieHeaderNames.Expires, 0, 7)) + { + this.expiresStart = valueStart; + this.expiresEnd = valueEnd; + } + else if (CharUtil.RegionMatchesIgnoreCase(this.header, nameStart, CookieHeaderNames.MaxAge, 0, 7)) + { + this.SetMaxAge(this.ComputeValue(valueStart, valueEnd)); + } + } + + void Parse8(int nameStart) + { + if (CharUtil.RegionMatchesIgnoreCase(this.header, nameStart, CookieHeaderNames.HttpOnly, 0, 8)) + { + this.httpOnly = true; + } + } + + static bool IsValueDefined(int valueStart, int valueEnd) => valueStart != -1 && valueStart != valueEnd; + + string ComputeValue(int valueStart, int valueEnd) => IsValueDefined(valueStart, valueEnd) + ? this.header.Substring(valueStart, valueEnd - valueStart) + : null; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Cookies/ClientCookieEncoder.cs b/src/DotNetty.Codecs.Http/Cookies/ClientCookieEncoder.cs new file mode 100644 index 0000000..a9f425b --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cookies/ClientCookieEncoder.cs @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Cookies +{ + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Diagnostics.Contracts; + using System.Text; + + using static CookieUtil; + + public sealed class ClientCookieEncoder : CookieEncoder + { + // Strict encoder that validates that name and value chars are in the valid scope and (for methods that accept + // multiple cookies) sorts cookies into order of decreasing path length, as specified in RFC6265. + public static readonly ClientCookieEncoder StrictEncoder = new ClientCookieEncoder(true); + + // Lax instance that doesn't validate name and value, and (for methods that accept multiple cookies) keeps + // cookies in the order in which they were given. + public static readonly ClientCookieEncoder LaxEncoder = new ClientCookieEncoder(false); + + static readonly CookieComparer Comparer = new CookieComparer(); + + ClientCookieEncoder(bool strict) : base(strict) + { + } + + public string Encode(string name, string value) => this.Encode(new DefaultCookie(name, value)); + + public string Encode(ICookie cookie) + { + Contract.Requires(cookie != null); + + StringBuilder buf = StringBuilder(); + this.Encode(buf, cookie); + return StripTrailingSeparator(buf); + } + + sealed class CookieComparer : IComparer + { + public int Compare(ICookie c1, ICookie c2) + { + Debug.Assert(c1 != null && c2 != null); + + string path1 = c1.Path; + string path2 = c2.Path; + // Cookies with unspecified path default to the path of the request. We don't + // know the request path here, but we assume that the length of an unspecified + // path is longer than any specified path (i.e. pathless cookies come first), + // because setting cookies with a path longer than the request path is of + // limited use. + int len1 = path1?.Length ?? int.MaxValue; + int len2 = path2?.Length ?? int.MaxValue; + int diff = len2 - len1; + if (diff != 0) + { + return diff; + } + // Rely on Java's sort stability to retain creation order in cases where + // cookies have same path length. + return -1; + } + } + + public string Encode(params ICookie[] cookies) + { + if (cookies == null || cookies.Length == 0) + { + return null; + } + + StringBuilder buf = StringBuilder(); + if (this.Strict) + { + if (cookies.Length == 1) + { + this.Encode(buf, cookies[0]); + } + else + { + var cookiesSorted = new ICookie[cookies.Length]; + Array.Copy(cookies, cookiesSorted, cookies.Length); + Array.Sort(cookiesSorted, Comparer); + foreach(ICookie c in cookiesSorted) + { + this.Encode(buf, c); + } + } + } + else + { + foreach (ICookie c in cookies) + { + this.Encode(buf, c); + } + } + return StripTrailingSeparatorOrNull(buf); + } + + public string Encode(IEnumerable cookies) + { + Contract.Requires(cookies != null); + + StringBuilder buf = StringBuilder(); + if (this.Strict) + { + var cookiesList = new List(); + foreach (ICookie cookie in cookies) + { + cookiesList.Add(cookie); + } + cookiesList.Sort(Comparer); + foreach (ICookie c in cookiesList) + { + this.Encode(buf, c); + } + } + else + { + foreach (ICookie cookie in cookies) + { + this.Encode(buf, cookie); + } + } + return StripTrailingSeparatorOrNull(buf); + } + + void Encode(StringBuilder buf, ICookie c) + { + string name = c.Name; + string value = c.Value?? ""; + + this.ValidateCookie(name, value); + + if (c.Wrap) + { + AddQuoted(buf, name, value); + } + else + { + Add(buf, name, value); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/Cookies/CookieDecoder.cs b/src/DotNetty.Codecs.Http/Cookies/CookieDecoder.cs new file mode 100644 index 0000000..cac4328 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cookies/CookieDecoder.cs @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Cookies +{ + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + + using static CookieUtil; + + public abstract class CookieDecoder + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + protected readonly bool Strict; + + protected CookieDecoder(bool strict) + { + this.Strict = strict; + } + + protected DefaultCookie InitCookie(string header, int nameBegin, int nameEnd, int valueBegin, int valueEnd) + { + if (nameBegin == -1 || nameBegin == nameEnd) + { + Logger.Debug("Skipping cookie with null name"); + return null; + } + + if (valueBegin == -1) + { + Logger.Debug("Skipping cookie with null value"); + return null; + } + + var sequence = new StringCharSequence(header, valueBegin, valueEnd - valueBegin); + ICharSequence unwrappedValue = UnwrapValue(sequence); + if (unwrappedValue == null) + { + Logger.Debug("Skipping cookie because starting quotes are not properly balanced in '{}'", sequence); + return null; + } + + string name = header.Substring(nameBegin, nameEnd - nameBegin); + + int invalidOctetPos; + if (this.Strict && (invalidOctetPos = FirstInvalidCookieNameOctet(name)) >= 0) + { + if (Logger.DebugEnabled) + { + Logger.Debug("Skipping cookie because name '{}' contains invalid char '{}'", + name, name[invalidOctetPos]); + } + return null; + } + + bool wrap = unwrappedValue.Count != valueEnd - valueBegin; + + if (this.Strict && (invalidOctetPos = FirstInvalidCookieValueOctet(unwrappedValue)) >= 0) + { + if (Logger.DebugEnabled) + { + Logger.Debug("Skipping cookie because value '{}' contains invalid char '{}'", + unwrappedValue, unwrappedValue[invalidOctetPos]); + } + + return null; + } + + var cookie = new DefaultCookie(name, unwrappedValue.ToString()); + cookie.Wrap = wrap; + return cookie; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Cookies/CookieEncoder.cs b/src/DotNetty.Codecs.Http/Cookies/CookieEncoder.cs new file mode 100644 index 0000000..5c49c65 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cookies/CookieEncoder.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Cookies +{ + using System; + using DotNetty.Common.Utilities; + + using static CookieUtil; + + public abstract class CookieEncoder + { + protected readonly bool Strict; + + protected CookieEncoder(bool strict) + { + this.Strict = strict; + } + + protected void ValidateCookie(string name, string value) + { + if (!this.Strict) + { + return; + } + + int pos; + if ((pos = FirstInvalidCookieNameOctet(name)) >= 0) + { + throw new ArgumentException($"Cookie name contains an invalid char: {name[pos]}"); + } + + var sequnce = new StringCharSequence(value); + ICharSequence unwrappedValue = UnwrapValue(sequnce); + if (unwrappedValue == null) + { + throw new ArgumentException($"Cookie value wrapping quotes are not balanced: {value}"); + } + + if ((pos = FirstInvalidCookieValueOctet(unwrappedValue)) >= 0) + { + throw new ArgumentException($"Cookie value contains an invalid char: {value[pos]}"); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/Cookies/CookieHeaderNames.cs b/src/DotNetty.Codecs.Http/Cookies/CookieHeaderNames.cs new file mode 100644 index 0000000..fcbcb8a --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cookies/CookieHeaderNames.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Cookies +{ + using DotNetty.Common.Utilities; + + public static class CookieHeaderNames + { + public static readonly AsciiString Path = AsciiString.Cached("Path"); + + public static readonly AsciiString Expires = AsciiString.Cached("Expires"); + + public static readonly AsciiString MaxAge = AsciiString.Cached("Max-Age"); + + public static readonly AsciiString Domain = AsciiString.Cached("Domain"); + + public static readonly AsciiString Secure = AsciiString.Cached("Secure"); + + public static readonly AsciiString HttpOnly = AsciiString.Cached("HTTPOnly"); + } +} diff --git a/src/DotNetty.Codecs.Http/Cookies/CookieUtil.cs b/src/DotNetty.Codecs.Http/Cookies/CookieUtil.cs new file mode 100644 index 0000000..c098889 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cookies/CookieUtil.cs @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Cookies +{ + using System; + using System.Collections; + using System.Text; + using DotNetty.Common; + using DotNetty.Common.Utilities; + + static class CookieUtil + { + static readonly BitArray ValidCookieNameOctets = GetValidCookieNameOctets(); + static readonly BitArray ValidCookieValueOctects = GetValidCookieValueOctets(); + static readonly BitArray ValidCookieAttributeOctets = GetValidCookieAttributeValueOctets(); + + // token = 1* + // separators = + //"(" | ")" | "<" | ">" | "@" + // | "," | ";" | ":" | "\" | <"> + // | "/" | "[" | "]" | "?" | "=" + // | "{" | "}" | SP | HT + static BitArray GetValidCookieNameOctets() + { + var bitArray = new BitArray(128, false); + for (int i = 32; i < 127; i++) + { + bitArray[i] = true; + } + + var separators = new int[] + { '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t' }; + foreach (int separator in separators) + { + bitArray[separator] = false; + } + + return bitArray; + } + + // cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E + // US-ASCII characters excluding CTLs, whitespace, DQUOTE, comma, semicolon, and backslash + static BitArray GetValidCookieValueOctets() + { + var bitArray = new BitArray(128, false); + bitArray[0x21] = true; + for (int i = 0x23; i <= 0x2B; i++) + { + bitArray[i] = true; + } + for (int i = 0x2D; i <= 0x3A; i++) + { + bitArray[i] = true; + } + for (int i = 0x3C; i <= 0x5B; i++) + { + bitArray[i] = true; + } + for (int i = 0x5D; i <= 0x7E; i++) + { + bitArray[i] = true; + } + + return bitArray; + } + + // path-value = + static BitArray GetValidCookieAttributeValueOctets() + { + var bitArray = new BitArray(128, false); + for (int i = 32; i < 127; i++) + { + bitArray[i] = true; + } + bitArray[';'] = false; + + return bitArray; + } + + internal static StringBuilder StringBuilder() => InternalThreadLocalMap.Get().StringBuilder; + + internal static string StripTrailingSeparatorOrNull(StringBuilder buf) => + buf.Length == 0 ? null : StripTrailingSeparator(buf); + + internal static string StripTrailingSeparator(StringBuilder buf) + { + if (buf.Length > 0) + { + buf.Length = buf.Length - 2; + } + + return buf.ToString(); + } + + internal static void Add(StringBuilder sb, string name, long val) + { + sb.Append(name); + sb.Append((char)HttpConstants.EqualsSign); + sb.Append(val); + sb.Append((char)HttpConstants.Semicolon); + sb.Append((char)HttpConstants.HorizontalSpace); + } + + internal static void Add(StringBuilder sb, string name, string val) + { + sb.Append(name); + sb.Append((char)HttpConstants.EqualsSign); + sb.Append(val); + sb.Append((char)HttpConstants.Semicolon); + sb.Append((char)HttpConstants.HorizontalSpace); + } + + internal static void Add(StringBuilder sb, string name) + { + sb.Append(name); + sb.Append((char)HttpConstants.Semicolon); + sb.Append((char)HttpConstants.HorizontalSpace); + } + + internal static void AddQuoted(StringBuilder sb, string name, string val) + { + if (val == null) + { + val = ""; + } + + sb.Append(name); + sb.Append((char)HttpConstants.EqualsSign); + sb.Append((char)HttpConstants.DoubleQuote); + sb.Append(val); + sb.Append((char)HttpConstants.DoubleQuote); + sb.Append((char)HttpConstants.Semicolon); + sb.Append((char)HttpConstants.HorizontalSpace); + } + + internal static int FirstInvalidCookieNameOctet(string cs) => FirstInvalidOctet(cs, ValidCookieNameOctets); + + internal static int FirstInvalidCookieValueOctet(ICharSequence cs) => FirstInvalidOctet(cs, ValidCookieValueOctects); + + static int FirstInvalidOctet(string cs, BitArray bits) + { + for (int i = 0; i < cs.Length; i++) + { + char c = cs[i]; + if (!bits[c]) + { + return i; + } + } + return -1; + } + + static int FirstInvalidOctet(ICharSequence cs, BitArray bits) + { + for (int i = 0; i < cs.Count; i++) + { + char c = cs[i]; + if (!bits[c]) + { + return i; + } + } + return -1; + } + + internal static ICharSequence UnwrapValue(ICharSequence cs) + { + int len = cs.Count; + if (len > 0 && cs[0] == '"') + { + if (len >= 2 && cs[len - 1] == '"') + { + // properly balanced + return len == 2 ? StringCharSequence.Empty : cs.SubSequence(1, len - 1); + } + else + { + return null; + } + } + + return cs; + } + + internal static string ValidateAttributeValue(string name, string value) + { + value = value?.Trim(); + if (string.IsNullOrEmpty(value)) + { + return null; + } + + int i = FirstInvalidOctet(value, ValidCookieAttributeOctets); + if (i != -1) + { + throw new ArgumentException($"{name} contains the prohibited characters: ${value[i]}"); + } + + return value; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Cookies/DefaultCookie.cs b/src/DotNetty.Codecs.Http/Cookies/DefaultCookie.cs new file mode 100644 index 0000000..53a6899 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cookies/DefaultCookie.cs @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +// ReSharper disable ConvertToAutoPropertyWhenPossible +namespace DotNetty.Codecs.Http.Cookies +{ + using System; + using System.Diagnostics.Contracts; + using System.Text; + + using static CookieUtil; + + public sealed class DefaultCookie : ICookie + { + // Constant for undefined MaxAge attribute value. + const long UndefinedMaxAge = long.MinValue; + + readonly string name; + string value; + bool wrap; + string domain; + string path; + long maxAge = UndefinedMaxAge; + bool secure; + bool httpOnly; + + public DefaultCookie(string name, string value) + { + Contract.Requires(!string.IsNullOrEmpty(name?.Trim())); + Contract.Requires(value != null); + + this.name = name; + this.value = value; + } + + public string Name => this.name; + + public string Value + { + get => this.value; + set + { + Contract.Requires(value != null); + this.value = value; + } + } + + public bool Wrap + { + get => this.wrap; + set => this.wrap = value; + } + + public string Domain + { + get => this.domain; + set => this.domain = ValidateAttributeValue(nameof(this.domain), value); + } + + public string Path + { + get => this.path; + set => this.path = ValidateAttributeValue(nameof(this.path), value); + } + + public long MaxAge + { + get => this.maxAge; + set => this.maxAge = value; + } + + public bool IsSecure + { + get => this.secure; + set => this.secure = value; + } + + public bool IsHttpOnly + { + get => this.httpOnly; + set => this.httpOnly = value; + } + + public override int GetHashCode() => this.name.GetHashCode(); + + public override bool Equals(object obj) => obj is DefaultCookie cookie && this.Equals(cookie); + + public bool Equals(ICookie other) + { + if (ReferenceEquals(null, other)) + { + return false; + } + if (ReferenceEquals(this, other)) + { + return true; + } + + if (!this.name.Equals(other.Name)) + { + return false; + } + + if (this.path == null) + { + if (other.Path != null) + { + return false; + } + } + else if (other.Path == null) + { + return false; + } + else if (!this.path.Equals(other.Path)) + { + return false; + } + + if (this.domain == null) + { + if (other.Domain != null) + { + return false; + } + } + else + { + return this.domain.Equals(other.Domain, StringComparison.OrdinalIgnoreCase); + } + + return true; + } + + public int CompareTo(ICookie other) + { + int v = string.Compare(this.name, other.Name, StringComparison.Ordinal); + if (v != 0) + { + return v; + } + + if (this.path == null) + { + if (other.Path != null) + { + return -1; + } + } + else if (other.Path == null) + { + return 1; + } + else + { + v = string.Compare(this.path, other.Path, StringComparison.Ordinal); + if (v != 0) + { + return v; + } + } + + if (this.domain == null) + { + if (other.Domain != null) + { + return -1; + } + } + else if (other.Domain == null) + { + return 1; + } + else + { + v = string.Compare(this.domain, other.Domain, StringComparison.OrdinalIgnoreCase); + return v; + } + + return 0; + } + + public int CompareTo(object obj) + { + if (obj == null) + { + return 1; + } + + if (!(obj is ICookie cookie)) + { + throw new ArgumentException($"{nameof(obj)} must be of {nameof(ICookie)} type"); + } + + return this.CompareTo(cookie); + } + + public override string ToString() + { + StringBuilder buf = StringBuilder(); + buf.Append($"{this.name}={this.Value}"); + if (this.domain != null) + { + buf.Append($", domain={this.domain}"); + } + if (this.path != null) + { + buf.Append($", path={this.path}"); + } + if (this.maxAge >= 0) + { + buf.Append($", maxAge={this.maxAge}s"); + } + if (this.secure) + { + buf.Append(", secure"); + } + if (this.httpOnly) + { + buf.Append(", HTTPOnly"); + } + + return buf.ToString(); + } + } +} diff --git a/src/DotNetty.Codecs.Http/Cookies/ICookie.cs b/src/DotNetty.Codecs.Http/Cookies/ICookie.cs new file mode 100644 index 0000000..fe1a602 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cookies/ICookie.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Cookies +{ + using System; + + // http://en.wikipedia.org/wiki/HTTP_cookie + public interface ICookie : IEquatable, IComparable, IComparable + { + string Name { get; } + + string Value { get; set; } + + /// + /// Returns true if the raw value of this {@link Cookie}, + /// was wrapped with double quotes in original Set-Cookie header. + /// + bool Wrap { get; set; } + + string Domain { get; set; } + + string Path { get; set; } + + long MaxAge { get; set; } + + bool IsSecure { get; set; } + + /// + /// Checks to see if this Cookie can only be accessed via HTTP. + /// If this returns true, the Cookie cannot be accessed through + /// client side script - But only if the browser supports it. + /// For more information, please look "http://www.owasp.org/index.php/HTTPOnly". + /// + bool IsHttpOnly { get; set; } + } +} diff --git a/src/DotNetty.Codecs.Http/Cookies/ServerCookieDecoder.cs b/src/DotNetty.Codecs.Http/Cookies/ServerCookieDecoder.cs new file mode 100644 index 0000000..59266d0 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cookies/ServerCookieDecoder.cs @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Cookies +{ + using System.Collections.Generic; + using System.Collections.Immutable; + using System.Diagnostics.Contracts; + using DotNetty.Common.Utilities; + + // http://tools.ietf.org/html/rfc6265 + // compliant cookie decoder to be used server side. + // + // http://tools.ietf.org/html/rfc2965 + // cookies are still supported,old fields will simply be ignored. + public sealed class ServerCookieDecoder : CookieDecoder + { + static readonly AsciiString RFC2965Version = new AsciiString("$Version"); + static readonly AsciiString RFC2965Path = new AsciiString($"${CookieHeaderNames.Path}"); + static readonly AsciiString RFC2965Domain = new AsciiString($"${CookieHeaderNames.Domain}"); + static readonly AsciiString RFC2965Port = new AsciiString("$Port"); + + // + // Strict encoder that validates that name and value chars are in the valid scope + // defined in RFC6265 + // + public static readonly ServerCookieDecoder StrictDecoder = new ServerCookieDecoder(true); + + // + // Lax instance that doesn't validate name and value + // + public static readonly ServerCookieDecoder LaxDecoder = new ServerCookieDecoder(false); + + ServerCookieDecoder(bool strict) : base(strict) + { + } + + public ISet Decode(string header) + { + Contract.Requires(header != null); + + int headerLen = header.Length; + if (headerLen == 0) + { + return ImmutableHashSet.Empty; + } + + var cookies = new SortedSet(); + + int i = 0; + + bool rfc2965Style = false; + if (CharUtil.RegionMatchesIgnoreCase(header, 0, RFC2965Version, 0, RFC2965Version.Count)) + { + // RFC 2965 style cookie, move to after version value + i = header.IndexOf(';') + 1; + rfc2965Style = true; + } + + // loop + for (;;) + { + // Skip spaces and separators. + for (;;) + { + if (i == headerLen) + { + goto loop; + } + char c = header[i]; + if (c == '\t' || c == '\n' || c == 0x0b || c == '\f' + || c == '\r' || c == ' ' || c == ',' || c == ';') + { + i++; + continue; + } + break; + } + + int nameBegin = i; + int nameEnd; + int valueBegin; + int valueEnd; + + for (;;) + { + char curChar = header[i]; + if (curChar == ';') + { + // NAME; (no value till ';') + nameEnd = i; + valueBegin = valueEnd = -1; + break; + } + else if (curChar == '=') + { + // NAME=VALUE + nameEnd = i; + i++; + if (i == headerLen) + { + // NAME= (empty value, i.e. nothing after '=') + valueBegin = valueEnd = 0; + break; + } + + valueBegin = i; + // NAME=VALUE; + int semiPos = header.IndexOf(';', i); + valueEnd = i = semiPos > 0 ? semiPos : headerLen; + break; + } + else + { + i++; + } + + if (i == headerLen) + { + // NAME (no value till the end of string) + nameEnd = headerLen; + valueBegin = valueEnd = -1; + break; + } + } + + if (rfc2965Style && (CharUtil.RegionMatches(header, nameBegin, RFC2965Path, 0, RFC2965Path.Count) + || CharUtil.RegionMatches(header, nameBegin, RFC2965Domain, 0, RFC2965Domain.Count) + || CharUtil.RegionMatches(header, nameBegin, RFC2965Port, 0, RFC2965Port.Count))) + { + // skip obsolete RFC2965 fields + continue; + } + + DefaultCookie cookie = this.InitCookie(header, nameBegin, nameEnd, valueBegin, valueEnd); + if (cookie != null) + { + cookies.Add(cookie); + } + } + + loop: + return cookies; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Cookies/ServerCookieEncoder.cs b/src/DotNetty.Codecs.Http/Cookies/ServerCookieEncoder.cs new file mode 100644 index 0000000..3be495e --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cookies/ServerCookieEncoder.cs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Cookies +{ + using System; + using System.Collections.Generic; + using System.Collections.Immutable; + using System.Diagnostics.Contracts; + using System.Text; + + using static CookieUtil; + + // http://tools.ietf.org/html/rfc6265 compliant cookie encoder to be used server side, + // so some fields are sent (Version is typically ignored). + // + // As Netty's Cookie merges Expires and MaxAge into one single field, only Max-Age field is sent. + // Note that multiple cookies must be sent as separate "Set-Cookie" headers. + public sealed class ServerCookieEncoder : CookieEncoder + { + // + // Strict encoder that validates that name and value chars are in the valid scope + // defined in RFC6265, and(for methods that accept multiple cookies) that only + // one cookie is encoded with any given name. (If multiple cookies have the same + // name, the last one is the one that is encoded.) + // + public static readonly ServerCookieEncoder StrictEncoder = new ServerCookieEncoder(true); + + // + // Lax instance that doesn't validate name and value, and that allows multiple + // cookies with the same name. + // + public static readonly ServerCookieEncoder LaxEncoder = new ServerCookieEncoder(false); + + ServerCookieEncoder(bool strict) : base(strict) + { + } + + public string Encode(string name, string value) => this.Encode(new DefaultCookie(name, value)); + + public string Encode(ICookie cookie) + { + Contract.Requires(cookie != null); + + string name = cookie.Name ?? nameof(cookie); + string value = cookie.Value ?? ""; + + this.ValidateCookie(name, value); + + StringBuilder buf = StringBuilder(); + + if (cookie.Wrap) + { + AddQuoted(buf, name, value); + } + else + { + Add(buf, name, value); + } + + if (cookie.MaxAge != long.MinValue) + { + Add(buf, (string)CookieHeaderNames.MaxAge, cookie.MaxAge); + DateTime expires = DateTime.UtcNow.AddMilliseconds(cookie.MaxAge * 1000); + buf.Append(CookieHeaderNames.Expires); + buf.Append((char)HttpConstants.EqualsSign); + DateFormatter.Append(expires, buf); + buf.Append((char)HttpConstants.Semicolon); + buf.Append((char)HttpConstants.HorizontalSpace); + } + + if (cookie.Path != null) + { + Add(buf, (string)CookieHeaderNames.Path, cookie.Path); + } + + if (cookie.Domain != null) + { + Add(buf, (string)CookieHeaderNames.Domain, cookie.Domain); + } + + if (cookie.IsSecure) + { + Add(buf, (string)CookieHeaderNames.Secure); + } + + if (cookie.IsHttpOnly) + { + Add(buf, (string)CookieHeaderNames.HttpOnly); + } + + return StripTrailingSeparator(buf); + } + + static List Dedup(IReadOnlyList encoded, IDictionary nameToLastIndex) + { + var isLastInstance = new bool[encoded.Count]; + foreach (int idx in nameToLastIndex.Values) + { + isLastInstance[idx] = true; + } + + var dedupd = new List(nameToLastIndex.Count); + for (int i = 0, n = encoded.Count; i < n; i++) + { + if (isLastInstance[i]) + { + dedupd.Add(encoded[i]); + } + } + return dedupd; + } + + public IList Encode(params ICookie[] cookies) + { + if (cookies == null || cookies.Length == 0) + { + return ImmutableList.Empty; + } + + var encoded = new List(cookies.Length); + Dictionary nameToIndex = this.Strict && cookies.Length > 1 ? new Dictionary() : null; + bool hasDupdName = false; + for (int i = 0; i < cookies.Length; i++) + { + ICookie c = cookies[i]; + encoded.Add(this.Encode(c)); + if (nameToIndex != null) + { + if (nameToIndex.ContainsKey(c.Name)) + { + nameToIndex[c.Name] = i; + hasDupdName = true; + } + else + { + nameToIndex.Add(c.Name, i); + } + } + } + return hasDupdName ? Dedup(encoded, nameToIndex) : encoded; + } + + public IList Encode(ICollection cookies) + { + Contract.Requires(cookies != null); + if (cookies.Count == 0) + { + return ImmutableList.Empty; + } + + var encoded = new List(); + var nameToIndex = new Dictionary(); + bool hasDupdName = false; + int i = 0; + foreach (ICookie c in cookies) + { + encoded.Add(this.Encode(c)); + if (nameToIndex.ContainsKey(c.Name)) + { + nameToIndex[c.Name] = i; + hasDupdName = true; + } + else + { + nameToIndex.Add(c.Name, i); + } + i++; + } + return hasDupdName ? Dedup(encoded, nameToIndex) : encoded; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Cors/CorsConfig.cs b/src/DotNetty.Codecs.Http/Cors/CorsConfig.cs new file mode 100644 index 0000000..bc49318 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cors/CorsConfig.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +// ReSharper disable ConvertToAutoPropertyWhenPossible +namespace DotNetty.Codecs.Http.Cors +{ + using System; + using System.Collections.Generic; + using System.Collections.Immutable; + using System.Linq; + using System.Text; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Utilities; + + // Configuration for Cross-Origin Resource Sharing (CORS). + public sealed class CorsConfig + { + readonly ISet origins; + readonly bool anyOrigin; + readonly bool enabled; + readonly ISet exposeHeaders; + readonly bool allowCredentials; + readonly long maxAge; + readonly ISet allowedRequestMethods; + readonly ISet allowedRequestHeaders; + readonly bool allowNullOrigin; + readonly IDictionary> preflightHeaders; + readonly bool shortCircuit; + + internal CorsConfig(CorsConfigBuilder builder) + { + this.origins = new HashSet(builder.origins, AsciiString.CaseSensitiveHasher); + this.anyOrigin = builder.anyOrigin; + this.enabled = builder.enabled; + this.exposeHeaders = builder.exposeHeaders; + this.allowCredentials = builder.allowCredentials; + this.maxAge = builder.maxAge; + this.allowedRequestMethods = builder.requestMethods; + this.allowedRequestHeaders = builder.requestHeaders; + this.allowNullOrigin = builder.allowNullOrigin; + this.preflightHeaders = builder.preflightHeaders; + this.shortCircuit = builder.shortCircuit; + } + + public bool IsCorsSupportEnabled => this.enabled; + + public bool IsAnyOriginSupported => this.anyOrigin; + + public ICharSequence Origin => this.origins.Count == 0 ? CorsHandler.AnyOrigin : this.origins.First(); + + public ISet Origins => this.origins; + + public bool IsNullOriginAllowed => this.allowNullOrigin; + + public ISet ExposedHeaders() => this.exposeHeaders.ToImmutableHashSet(); + + public bool IsCredentialsAllowed => this.allowCredentials; + + public long MaxAge => this.maxAge; + + public ISet AllowedRequestMethods() => this.allowedRequestMethods.ToImmutableHashSet(); + + public ISet AllowedRequestHeaders() => this.allowedRequestHeaders.ToImmutableHashSet(); + + public HttpHeaders PreflightResponseHeaders() + { + if (this.preflightHeaders.Count == 0) + { + return EmptyHttpHeaders.Default; + } + HttpHeaders headers = new DefaultHttpHeaders(); + foreach (KeyValuePair> entry in this.preflightHeaders) + { + object value = GetValue(entry.Value); + if (value is IEnumerable values) + { + headers.Add(entry.Key, values); + } + else + { + headers.Add(entry.Key, value); + } + } + return headers; + } + + public bool IsShortCircuit => this.shortCircuit; + + static object GetValue(ICallable callable) + { + try + { + return callable.Call(); + } + catch (Exception exception) + { + throw new InvalidOperationException($"Could not generate value for callable [{callable}]", exception); + } + } + + public override string ToString() + { + var builder = new StringBuilder(); + builder.Append($"{StringUtil.SimpleClassName(this)}") + .Append($"[enabled = {this.enabled}"); + + builder.Append(", origins="); + if (this.Origins.Count == 0) + { + builder.Append("*"); + } + else + { + builder.Append("("); + foreach (ICharSequence value in this.Origins) + { + builder.Append($"'{value}'"); + } + builder.Append(")"); + } + + builder.Append(", exposedHeaders="); + if (this.exposeHeaders.Count == 0) + { + builder.Append("*"); + } + else + { + builder.Append("("); + foreach (ICharSequence value in this.exposeHeaders) + { + builder.Append($"'{value}'"); + } + builder.Append(")"); + } + + builder.Append($", isCredentialsAllowed={this.allowCredentials}"); + builder.Append($", maxAge={this.maxAge}"); + + builder.Append(", allowedRequestMethods="); + if (this.allowedRequestMethods.Count == 0) + { + builder.Append("*"); + } + else + { + builder.Append("("); + foreach (HttpMethod value in this.allowedRequestMethods) + { + builder.Append($"'{value}'"); + } + builder.Append(")"); + } + + builder.Append(", allowedRequestHeaders="); + if (this.allowedRequestHeaders.Count == 0) + { + builder.Append("*"); + } + else + { + builder.Append("("); + foreach(AsciiString value in this.allowedRequestHeaders) + { + builder.Append($"'{value}'"); + } + builder.Append(")"); + } + + builder.Append(", preflightHeaders="); + if (this.preflightHeaders.Count == 0) + { + builder.Append("*"); + } + else + { + builder.Append("("); + foreach (AsciiString value in this.preflightHeaders.Keys) + { + builder.Append($"'{value}'"); + } + builder.Append(")"); + } + + builder.Append("]"); + return builder.ToString(); + } + } +} diff --git a/src/DotNetty.Codecs.Http/Cors/CorsConfigBuilder.cs b/src/DotNetty.Codecs.Http/Cors/CorsConfigBuilder.cs new file mode 100644 index 0000000..254c124 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cors/CorsConfigBuilder.cs @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Http.Cors +{ + using System; + using System.Collections.Generic; + using System.Collections.Immutable; + using System.Diagnostics.Contracts; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Utilities; + + public sealed class CorsConfigBuilder + { + public static CorsConfigBuilder ForAnyOrigin() => new CorsConfigBuilder(); + + public static CorsConfigBuilder ForOrigin(ICharSequence origin) + { + return CorsHandler.AnyOrigin.ContentEquals(origin) // * AnyOrigin + ? new CorsConfigBuilder() + : new CorsConfigBuilder(origin); + } + + public static CorsConfigBuilder ForOrigins(params ICharSequence[] origins) => new CorsConfigBuilder(origins); + + internal readonly ISet origins; + internal readonly bool anyOrigin; + internal bool allowNullOrigin; + internal bool enabled = true; + internal bool allowCredentials; + internal readonly HashSet exposeHeaders = new HashSet(AsciiString.CaseSensitiveHasher); + internal long maxAge; + internal readonly ISet requestMethods = new HashSet(); + internal readonly ISet requestHeaders = new HashSet(); + internal readonly Dictionary> preflightHeaders = new Dictionary>(); + internal bool noPreflightHeaders; + internal bool shortCircuit; + + CorsConfigBuilder(params ICharSequence[] origins) + { + this.origins = new HashSet(origins); + this.anyOrigin = false; + } + + CorsConfigBuilder() + { + this.anyOrigin = true; + this.origins = ImmutableHashSet.Empty; + } + + public CorsConfigBuilder AllowNullOrigin() + { + this.allowNullOrigin = true; + return this; + } + + public CorsConfigBuilder Disable() + { + this.enabled = false; + return this; + } + + public CorsConfigBuilder ExposeHeaders(params ICharSequence[] headers) + { + foreach (ICharSequence header in headers) + { + this.exposeHeaders.Add(header); + } + return this; + } + + public CorsConfigBuilder ExposeHeaders(params string[] headers) + { + foreach (string header in headers) + { + this.exposeHeaders.Add(new StringCharSequence(header)); + } + return this; + } + + public CorsConfigBuilder AllowCredentials() + { + this.allowCredentials = true; + return this; + } + + public CorsConfigBuilder MaxAge(long max) + { + this.maxAge = max; + return this; + } + + public CorsConfigBuilder AllowedRequestMethods(params HttpMethod[] methods) + { + this.requestMethods.UnionWith(methods); + return this; + } + + public CorsConfigBuilder AllowedRequestHeaders(params AsciiString[] headers) + { + this.requestHeaders.UnionWith(headers); + return this; + } + + public CorsConfigBuilder AllowedRequestHeaders(params ICharSequence[] headers) + { + foreach (ICharSequence header in headers) + { + this.requestHeaders.Add(new AsciiString(header)); + } + return this; + } + + public CorsConfigBuilder PreflightResponseHeader(AsciiString name, params object[] values) + { + Contract.Requires(values != null); + + if (values.Length == 1) + { + this.preflightHeaders.Add(name, new ConstantValueGenerator(values[0])); + } + else + { + this.PreflightResponseHeader(name, new List(values)); + } + return this; + } + + public CorsConfigBuilder PreflightResponseHeader(AsciiString name, ICollection value) + { + this.preflightHeaders.Add(name, new ConstantValueGenerator(value)); + return this; + } + + public CorsConfigBuilder PreflightResponseHeader(AsciiString name, ICallable valueGenerator) + { + this.preflightHeaders.Add(name, valueGenerator); + return this; + } + + public CorsConfigBuilder NoPreflightResponseHeaders() + { + this.noPreflightHeaders = true; + return this; + } + + public CorsConfigBuilder ShortCircuit() + { + this.shortCircuit = true; + return this; + } + + public CorsConfig Build() + { + if (this.preflightHeaders.Count == 0 && !this.noPreflightHeaders) + { + this.preflightHeaders.Add(HttpHeaderNames.Date, DateValueGenerator.Default); + this.preflightHeaders.Add(HttpHeaderNames.ContentLength, new ConstantValueGenerator(new AsciiString("0"))); + } + return new CorsConfig(this); + } + + // This class is used for preflight HTTP response values that do not need to be + // generated, but instead the value is "static" in that the same value will be returned + // for each call. + sealed class ConstantValueGenerator : ICallable + { + readonly object value; + + internal ConstantValueGenerator(object value) + { + Contract.Requires(value != null); + this.value = value; + } + + public object Call() => this.value; + } + + // This callable is used for the DATE preflight HTTP response HTTP header. + // It's value must be generated when the response is generated, hence will be + // different for every call. + sealed class DateValueGenerator : ICallable + { + internal static readonly DateValueGenerator Default = new DateValueGenerator(); + + public object Call() => new DateTime(); + } + } +} + diff --git a/src/DotNetty.Codecs.Http/Cors/CorsHandler.cs b/src/DotNetty.Codecs.Http/Cors/CorsHandler.cs new file mode 100644 index 0000000..c6e38d2 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Cors/CorsHandler.cs @@ -0,0 +1,211 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Cors +{ + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Threading.Tasks; + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + using static Common.Utilities.ReferenceCountUtil; + + public class CorsHandler : ChannelDuplexHandler + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + + internal static readonly AsciiString AnyOrigin = new AsciiString("*"); + internal static readonly AsciiString NullOrigin = new AsciiString("null"); + + readonly CorsConfig config; + IHttpRequest request; + + public CorsHandler(CorsConfig config) + { + Contract.Requires(config != null); + + this.config = config; + } + + public override void ChannelRead(IChannelHandlerContext context, object message) + { + if (this.config.IsCorsSupportEnabled && message is IHttpRequest) + { + this.request = (IHttpRequest)message; + if (IsPreflightRequest(this.request)) + { + this.HandlePreflight(context, this.request); + return; + } + if (this.config.IsShortCircuit && !this.ValidateOrigin()) + { + Forbidden(context, this.request); + return; + } + } + context.FireChannelRead(message); + } + + void HandlePreflight(IChannelHandlerContext ctx, IHttpRequest req) + { + var response = new DefaultFullHttpResponse(req.ProtocolVersion, HttpResponseStatus.OK, true, true); + if (this.SetOrigin(response)) + { + this.SetAllowMethods(response); + this.SetAllowHeaders(response); + this.SetAllowCredentials(response); + this.SetMaxAge(response); + this.SetPreflightHeaders(response); + } + if (!response.Headers.Contains(HttpHeaderNames.ContentLength)) + { + response.Headers.Set(HttpHeaderNames.ContentLength, HttpHeaderValues.Zero); + } + + Release(req); + Respond(ctx, req, response); + } + + void SetPreflightHeaders(IHttpResponse response) => response.Headers.Add(this.config.PreflightResponseHeaders()); + + bool SetOrigin(IHttpResponse response) + { + if (!this.request.Headers.TryGet(HttpHeaderNames.Origin, out ICharSequence origin)) + { + return false; + } + if (NullOrigin.ContentEquals(origin) && this.config.IsNullOriginAllowed) + { + SetNullOrigin(response); + return true; + } + if (this.config.IsAnyOriginSupported) + { + if (this.config.IsCredentialsAllowed) + { + this.EchoRequestOrigin(response); + SetVaryHeader(response); + } + else + { + SetAnyOrigin(response); + } + return true; + } + if (this.config.Origins.Contains(origin)) + { + SetOrigin(response, origin); + SetVaryHeader(response); + return true; + } + Logger.Debug("Request origin [{}]] was not among the configured origins [{}]", origin, this.config.Origins); + + return false; + } + + bool ValidateOrigin() + { + if (this.config.IsAnyOriginSupported) + { + return true; + } + + if (!this.request.Headers.TryGet(HttpHeaderNames.Origin, out ICharSequence origin)) + { + // Not a CORS request so we cannot validate it. It may be a non CORS request. + return true; + } + + if (NullOrigin.ContentEquals(origin) && this.config.IsNullOriginAllowed) + { + return true; + } + return this.config.Origins.Contains(origin); + } + + void EchoRequestOrigin(IHttpResponse response) => SetOrigin(response, this.request.Headers.Get(HttpHeaderNames.Origin, null)); + + static void SetVaryHeader(IHttpResponse response) => response.Headers.Set(HttpHeaderNames.Vary, HttpHeaderNames.Origin); + + static void SetAnyOrigin(IHttpResponse response) => SetOrigin(response, AnyOrigin); + + static void SetNullOrigin(IHttpResponse response) => SetOrigin(response, NullOrigin); + + static void SetOrigin(IHttpResponse response, ICharSequence origin) => response.Headers.Set(HttpHeaderNames.AccessControlAllowOrigin, origin); + + void SetAllowCredentials(IHttpResponse response) + { + if (this.config.IsCredentialsAllowed + && !AsciiString.ContentEquals(response.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null), AnyOrigin)) + { + response.Headers.Set(HttpHeaderNames.AccessControlAllowCredentials, new AsciiString("true")); + } + } + + static bool IsPreflightRequest(IHttpRequest request) + { + HttpHeaders headers = request.Headers; + return request.Method.Equals(HttpMethod.Options) + && headers.Contains(HttpHeaderNames.Origin) + && headers.Contains(HttpHeaderNames.AccessControlRequestMethod); + } + + void SetExposeHeaders(IHttpResponse response) + { + ISet headers = this.config.ExposedHeaders(); + if (headers.Count > 0) + { + response.Headers.Set(HttpHeaderNames.AccessControlExposeHeaders, headers); + } + } + + void SetAllowMethods(IHttpResponse response) => response.Headers.Set(HttpHeaderNames.AccessControlAllowMethods, this.config.AllowedRequestMethods()); + + void SetAllowHeaders(IHttpResponse response) => response.Headers.Set(HttpHeaderNames.AccessControlAllowHeaders, this.config.AllowedRequestHeaders()); + + void SetMaxAge(IHttpResponse response) => response.Headers.Set(HttpHeaderNames.AccessControlMaxAge, this.config.MaxAge); + + public override Task WriteAsync(IChannelHandlerContext context, object message) + { + if (this.config.IsCorsSupportEnabled && message is IHttpResponse response) + { + if (this.SetOrigin(response)) + { + this.SetAllowCredentials(response); + this.SetExposeHeaders(response); + } + } + return context.WriteAndFlushAsync(message); + } + + static void Forbidden(IChannelHandlerContext ctx, IHttpRequest request) + { + var response = new DefaultFullHttpResponse(request.ProtocolVersion, HttpResponseStatus.Forbidden); + response.Headers.Set(HttpHeaderNames.ContentLength, HttpHeaderValues.Zero); + Release(request); + Respond(ctx, request, response); + } + + static void Respond(IChannelHandlerContext ctx, IHttpRequest request, IHttpResponse response) + { + bool keepAlive = HttpUtil.IsKeepAlive(request); + + HttpUtil.SetKeepAlive(response, keepAlive); + + Task task = ctx.WriteAndFlushAsync(response); + if (!keepAlive) + { + task.ContinueWith(CloseOnComplete, ctx, + TaskContinuationOptions.ExecuteSynchronously); + } + } + + static void CloseOnComplete(Task task, object state) + { + var ctx = (IChannelHandlerContext)state; + ctx.CloseAsync(); + } + } +} diff --git a/src/DotNetty.Codecs.Http/DefaultFullHttpRequest.cs b/src/DotNetty.Codecs.Http/DefaultFullHttpRequest.cs new file mode 100644 index 0000000..5f693e5 --- /dev/null +++ b/src/DotNetty.Codecs.Http/DefaultFullHttpRequest.cs @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using System.Diagnostics.Contracts; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common; + + public class DefaultFullHttpRequest : DefaultHttpRequest, IFullHttpRequest + { + readonly IByteBuffer content; + readonly HttpHeaders trailingHeader; + // Used to cache the value of the hash code and avoid {@link IllegalReferenceCountException}. + int hash; + + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, string uri) + : this(httpVersion, method, uri, Unpooled.Buffer(0)) + { + } + + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, string uri, IByteBuffer content) + : this(httpVersion, method, uri, content, true) + { + } + + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, string uri, bool validateHeaders) + : this(httpVersion, method, uri, Unpooled.Buffer(0), validateHeaders) + { + } + + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, string uri, + IByteBuffer content, bool validateHeaders) + : base(httpVersion, method, uri, validateHeaders) + { + Contract.Requires(content != null); + + this.content = content; + this.trailingHeader = new DefaultHttpHeaders(validateHeaders); + } + + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, string uri, + IByteBuffer content, HttpHeaders headers, HttpHeaders trailingHeader) + : base(httpVersion, method, uri, headers) + { + Contract.Requires(content != null); + Contract.Requires(trailingHeader != null); + + this.content = content; + this.trailingHeader = trailingHeader; + } + + public HttpHeaders TrailingHeaders => this.trailingHeader; + + public IByteBuffer Content => this.content; + + public int ReferenceCount => this.content.ReferenceCount; + + public IReferenceCounted Retain() + { + this.content.Retain(); + return this; + } + + public IReferenceCounted Retain(int increment) + { + this.content.Retain(increment); + return this; + } + + public IReferenceCounted Touch() + { + this.content.Touch(); + return this; + } + + public IReferenceCounted Touch(object hint) + { + this.content.Touch(hint); + return this; + } + + public bool Release() => this.content.Release(); + + public bool Release(int decrement) => this.content.Release(decrement); + + public IByteBufferHolder Copy() => this.Replace(this.content.Copy()); + + public IByteBufferHolder Duplicate() => this.Replace(this.content.Duplicate()); + + public IByteBufferHolder RetainedDuplicate() => this.Replace(this.content.RetainedDuplicate()); + + public IByteBufferHolder Replace(IByteBuffer newContent) => + new DefaultFullHttpRequest(this.ProtocolVersion, this.Method, this.Uri, newContent, this.Headers, this.trailingHeader); + + public override int GetHashCode() + { + // ReSharper disable NonReadonlyMemberInGetHashCode + int hashCode = this.hash; + if (hashCode == 0) + { + if (this.content.ReferenceCount != 0) + { + try + { + hashCode = 31 + this.content.GetHashCode(); + } + catch (IllegalReferenceCountException) + { + // Handle race condition between checking refCnt() == 0 and using the object. + hashCode = 31; + } + } + else + { + hashCode = 31; + } + hashCode = 31 * hashCode + this.trailingHeader.GetHashCode(); + hashCode = 31 * hashCode + base.GetHashCode(); + + this.hash = hashCode; + } + // ReSharper restore NonReadonlyMemberInGetHashCode + return hashCode; + } + + public override bool Equals(object obj) + { + if (!(obj is DefaultFullHttpRequest other)) + { + return false; + } + return base.Equals(other) + && this.content.Equals(other.content) + && this.trailingHeader.Equals(other.trailingHeader); + } + + public override string ToString() => HttpMessageUtil.AppendFullRequest(new StringBuilder(256), this).ToString(); + } +} diff --git a/src/DotNetty.Codecs.Http/DefaultFullHttpResponse.cs b/src/DotNetty.Codecs.Http/DefaultFullHttpResponse.cs new file mode 100644 index 0000000..a86bd84 --- /dev/null +++ b/src/DotNetty.Codecs.Http/DefaultFullHttpResponse.cs @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using System.Diagnostics.Contracts; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common; + + public class DefaultFullHttpResponse : DefaultHttpResponse, IFullHttpResponse + { + readonly IByteBuffer content; + readonly HttpHeaders trailingHeaders; + + // Used to cache the value of the hash code and avoid {@link IllegalReferenceCountException}. + int hash; + + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status) + : this(version, status, Unpooled.Buffer(0)) + { + } + + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, IByteBuffer content) + : this(version, status, content, true) + { + } + + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, bool validateHeaders) + : this(version, status, Unpooled.Buffer(0), validateHeaders, false) + { + } + + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, bool validateHeaders, + bool singleFieldHeaders) + : this(version, status, Unpooled.Buffer(0), validateHeaders, singleFieldHeaders) + { + } + + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, + IByteBuffer content, bool validateHeaders) + : this(version, status, content, validateHeaders, false) + { + } + + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, + IByteBuffer content, bool validateHeaders, bool singleFieldHeaders) + : base(version, status, validateHeaders, singleFieldHeaders) + { + Contract.Requires(content != null); + + this.content = content; + this.trailingHeaders = singleFieldHeaders + ? new CombinedHttpHeaders(validateHeaders) + : new DefaultHttpHeaders(validateHeaders); + } + + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, IByteBuffer content, HttpHeaders headers, HttpHeaders trailingHeaders) + : base(version, status, headers) + { + Contract.Requires(content != null); + Contract.Requires(trailingHeaders != null); + + this.content = content; + this.trailingHeaders = trailingHeaders; + } + + public HttpHeaders TrailingHeaders => this.trailingHeaders; + + public IByteBuffer Content => this.content; + + public int ReferenceCount => this.content.ReferenceCount; + + public IReferenceCounted Retain() + { + this.content.Retain(); + return this; + } + + public IReferenceCounted Retain(int increment) + { + this.content.Retain(increment); + return this; + } + + public IReferenceCounted Touch() + { + this.content.Touch(); + return this; + } + + public IReferenceCounted Touch(object hint) + { + this.content.Touch(hint); + return this; + } + + public bool Release() => this.content.Release(); + + public bool Release(int decrement) => this.content.Release(decrement); + + public IByteBufferHolder Copy() => this.Replace(this.content.Copy()); + + public IByteBufferHolder Duplicate() => this.Replace(this.content.Duplicate()); + + public IByteBufferHolder RetainedDuplicate() => this.Replace(this.content.RetainedDuplicate()); + + public IByteBufferHolder Replace(IByteBuffer newContent) => + new DefaultFullHttpResponse(this.ProtocolVersion, this.Status, newContent, this.Headers, this.trailingHeaders); + + public override int GetHashCode() + { + // ReSharper disable NonReadonlyMemberInGetHashCode + int hashCode = this.hash; + if (hashCode == 0) + { + if (this.content.ReferenceCount != 0) + { + try + { + hashCode = 31 + this.content.GetHashCode(); + } + catch (IllegalReferenceCountException) + { + // Handle race condition between checking refCnt() == 0 and using the object. + hashCode = 31; + } + } + else + { + hashCode = 31; + } + hashCode = 31 * hashCode + this.trailingHeaders.GetHashCode(); + hashCode = 31 * hashCode + base.GetHashCode(); + this.hash = hashCode; + } + // ReSharper restore NonReadonlyMemberInGetHashCode + return hashCode; + } + + public override bool Equals(object obj) + { + if (!(obj is DefaultFullHttpResponse other)) + { + return false; + } + return base.Equals(other) + && this.content.Equals(other.content) + && this.trailingHeaders.Equals(other.trailingHeaders); + } + + public override string ToString() => HttpMessageUtil.AppendFullResponse(new StringBuilder(256), this).ToString(); + } +} diff --git a/src/DotNetty.Codecs.Http/DefaultHttpContent.cs b/src/DotNetty.Codecs.Http/DefaultHttpContent.cs new file mode 100644 index 0000000..bc9315c --- /dev/null +++ b/src/DotNetty.Codecs.Http/DefaultHttpContent.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using System.Diagnostics.Contracts; + using DotNetty.Buffers; + using DotNetty.Common; + using DotNetty.Common.Utilities; + + public class DefaultHttpContent : DefaultHttpObject, IHttpContent + { + readonly IByteBuffer content; + + public DefaultHttpContent(IByteBuffer content) + { + Contract.Requires(content != null); + + this.content = content; + } + + public IByteBuffer Content => this.content; + + public IByteBufferHolder Copy() => this.Replace(this.content.Copy()); + + public IByteBufferHolder Duplicate() => this.Replace(this.content.Duplicate()); + + public IByteBufferHolder RetainedDuplicate() => this.Replace(this.content.RetainedDuplicate()); + + public virtual IByteBufferHolder Replace(IByteBuffer buffer) => new DefaultHttpContent(buffer); + + public int ReferenceCount => this.content.ReferenceCount; + + public IReferenceCounted Retain() + { + this.content.Retain(); + return this; + } + + public IReferenceCounted Retain(int increment) + { + this.content.Retain(increment); + return this; + } + + public IReferenceCounted Touch() + { + this.content.Touch(); + return this; + } + + public IReferenceCounted Touch(object hint) + { + this.content.Touch(hint); + return this; + } + + public bool Release() => this.content.Release(); + + public bool Release(int decrement) => this.content.Release(decrement); + + public override string ToString() => $"{StringUtil.SimpleClassName(this)} (data: {this.content}, decoderResult: {this.Result})"; + } +} diff --git a/src/DotNetty.Codecs.Http/DefaultHttpHeaders.cs b/src/DotNetty.Codecs.Http/DefaultHttpHeaders.cs new file mode 100644 index 0000000..45aad90 --- /dev/null +++ b/src/DotNetty.Codecs.Http/DefaultHttpHeaders.cs @@ -0,0 +1,372 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using System.Collections.Generic; + using System.Collections.Immutable; + using System.Diagnostics; + using System.Runtime.CompilerServices; + using DotNetty.Codecs; + using DotNetty.Common.Utilities; + + public class DefaultHttpHeaders : HttpHeaders + { + const int HighestInvalidValueCharMask = ~15; + internal static readonly INameValidator HttpNameValidator = new HeaderNameValidator(); + internal static readonly INameValidator NotNullValidator = new NullNameValidator(); + + sealed class NameProcessor : IByteProcessor + { + public bool Process(byte value) + { + ValidateHeaderNameElement(value); + return true; + } + } + + sealed class HeaderNameValidator : INameValidator + { + static readonly NameProcessor ByteProcessor = new NameProcessor(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void ValidateName(ICharSequence name) + { + if (name == null || name.Count == 0) + { + ThrowHelper.ThrowArgumentException_HeaderName(); + } + if (name is AsciiString asciiString) + { + asciiString.ForEachByte(ByteProcessor); + } + else + { + // Go through each character in the name + Debug.Assert(name != null); + // ReSharper disable once ForCanBeConvertedToForeach + // Avoid new enumerator instance + for (int index = 0; index < name.Count; ++index) + { + ValidateHeaderNameElement(name[index]); + } + } + } + } + + readonly DefaultHeaders headers; + + public DefaultHttpHeaders() : this(true) + { + } + + public DefaultHttpHeaders(bool validate) : this(validate, NameValidator(validate)) + { + } + + protected DefaultHttpHeaders(bool validate, INameValidator nameValidator) + : this(new DefaultHeaders(AsciiString.CaseInsensitiveHasher, + ValueConverter(validate), nameValidator)) + { + } + + protected DefaultHttpHeaders(DefaultHeaders headers) + { + this.headers = headers; + } + + public override HttpHeaders Add(HttpHeaders httpHeaders) + { + if (httpHeaders is DefaultHttpHeaders defaultHttpHeaders) + { + this.headers.Add(defaultHttpHeaders.headers); + return this; + } + return base.Add(httpHeaders); + } + + public override HttpHeaders Set(HttpHeaders httpHeaders) + { + if (httpHeaders is DefaultHttpHeaders defaultHttpHeaders) + { + this.headers.Set(defaultHttpHeaders.headers); + return this; + } + return base.Set(httpHeaders); + } + + public override HttpHeaders Add(AsciiString name, object value) + { + this.headers.AddObject(name, value); + return this; + } + + public override HttpHeaders AddInt(AsciiString name, int value) + { + this.headers.AddInt(name, value); + return this; + } + + public override HttpHeaders AddShort(AsciiString name, short value) + { + this.headers.AddShort(name, value); + return this; + } + + public override HttpHeaders Remove(AsciiString name) + { + this.headers.Remove(name); + return this; + } + + public override HttpHeaders Set(AsciiString name, object value) + { + this.headers.SetObject(name, value); + return this; + } + + public override HttpHeaders Set(AsciiString name, IEnumerable values) + { + this.headers.SetObject(name, values); + return this; + } + + public override HttpHeaders SetInt(AsciiString name, int value) + { + this.headers.SetInt(name, value); + return this; + } + + public override HttpHeaders SetShort(AsciiString name, short value) + { + this.headers.SetShort(name, value); + return this; + } + + public override HttpHeaders Clear() + { + this.headers.Clear(); + return this; + } + + public override bool TryGet(AsciiString name, out ICharSequence value) => this.headers.TryGet(name, out value); + + public override bool TryGetInt(AsciiString name, out int value) => this.headers.TryGetInt(name, out value); + + public override int GetInt(AsciiString name, int defaultValue) => this.headers.GetInt(name, defaultValue); + + public override bool TryGetShort(AsciiString name, out short value) => this.headers.TryGetShort(name, out value); + + public override short GetShort(AsciiString name, short defaultValue) => this.headers.GetShort(name, defaultValue); + + public override bool TryGetTimeMillis(AsciiString name, out long value) => this.headers.TryGetTimeMillis(name, out value); + + public override long GetTimeMillis(AsciiString name, long defaultValue) => this.headers.GetTimeMillis(name, defaultValue); + + public override IList GetAll(AsciiString name) => this.headers.GetAll(name); + + public override IEnumerable ValueCharSequenceIterator(AsciiString name) => this.headers.ValueIterator(name); + + public override IList> Entries() + { + if (this.IsEmpty) + { + return ImmutableList>.Empty; + } + var entriesConverted = new List>(this.headers.Size); + foreach(HeaderEntry entry in this) + { + entriesConverted.Add(entry); + } + return entriesConverted; + } + + public override IEnumerator> GetEnumerator() => this.headers.GetEnumerator(); + + public override bool Contains(AsciiString name) => this.headers.Contains(name); + + public override bool IsEmpty => this.headers.IsEmpty; + + public override int Size => this.headers.Size; + + public override bool Contains(AsciiString name, ICharSequence value, bool ignoreCase) => + this.headers.Contains(name, value, + ignoreCase ? AsciiString.CaseInsensitiveHasher : AsciiString.CaseSensitiveHasher); + + public override ISet Names() => this.headers.Names(); + + public override bool Equals(object obj) => obj is DefaultHttpHeaders other + && this.headers.Equals(other.headers, AsciiString.CaseSensitiveHasher); + + public override int GetHashCode() => this.headers.HashCode(AsciiString.CaseSensitiveHasher); + + public override HttpHeaders Copy() => new DefaultHttpHeaders(this.headers.Copy()); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void ValidateHeaderNameElement(byte value) + { + switch (value) + { + case 0x00: + case 0x09: //'\t': + case 0x0a: //'\n': + case 0x0b: + case 0x0c: //'\f': + case 0x0d: //'\r': + case 0x20: //' ': + case 0x2c: //',': + case 0x3a: //':': + case 0x3b: //';': + case 0x3d: //'=': + ThrowHelper.ThrowArgumentException_HeaderValue(value); + break; + default: + // Check to see if the character is not an ASCII character, or invalid + if (value > 127) + { + ThrowHelper.ThrowArgumentException_HeaderValueNonAscii(value); + } + break; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void ValidateHeaderNameElement(char value) + { + switch (value) + { + case '\x00': + case '\t': + case '\n': + case '\x0b': + case '\f': + case '\r': + case ' ': + case ',': + case ':': + case ';': + case '=': + ThrowHelper.ThrowArgumentException_HeaderValue(value); + break; + default: + // Check to see if the character is not an ASCII character, or invalid + if (value > 127) + { + ThrowHelper.ThrowArgumentException_HeaderValueNonAscii(value); + } + break; + } + } + + protected static IValueConverter ValueConverter(bool validate) => + validate ? DefaultHeaderValueConverterAndValidator : DefaultHeaderValueConverter; + + protected static INameValidator NameValidator(bool validate) => + validate ? HttpNameValidator : NotNullValidator; + + static readonly HeaderValueConverter DefaultHeaderValueConverter = new HeaderValueConverter(); + + class HeaderValueConverter : CharSequenceValueConverter + { + public override ICharSequence ConvertObject(object value) + { + if (value is ICharSequence seq) + { + return seq; + } + if (value is DateTime time) + { + return new StringCharSequence(DateFormatter.Format(time)); + } + return new StringCharSequence(value.ToString()); + } + } + + static readonly HeaderValueConverterAndValidator DefaultHeaderValueConverterAndValidator = new HeaderValueConverterAndValidator(); + + sealed class HeaderValueConverterAndValidator : HeaderValueConverter + { + public override ICharSequence ConvertObject(object value) + { + ICharSequence seq = base.ConvertObject(value); + int state = 0; + // Start looping through each of the character + // ReSharper disable once ForCanBeConvertedToForeach + // Avoid enumerator allocation + for (int index = 0; index < seq.Count; index++) + { + state = ValidateValueChar(state, seq[index]); + } + + if (state != 0) + { + ThrowHelper.ThrowArgumentException_HeaderValueEnd(seq); + } + return seq; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int ValidateValueChar(int state, char character) + { + // State: + // 0: Previous character was neither CR nor LF + // 1: The previous character was CR + // 2: The previous character was LF + if ((character & HighestInvalidValueCharMask) == 0) + { + // Check the absolutely prohibited characters. + switch (character) + { + case '\x00': // NULL + ThrowHelper.ThrowArgumentException_HeaderValueNullChar(); + break; + case '\x0b': // Vertical tab + ThrowHelper.ThrowArgumentException_HeaderValueVerticalTabChar(); + break; + case '\f': + ThrowHelper.ThrowArgumentException_HeaderValueFormFeed(); + break; + } + } + + // Check the CRLF (HT | SP) pattern + switch (state) + { + case 0: + switch (character) + { + case '\r': + return 1; + case '\n': + return 2; + } + break; + case 1: + switch (character) + { + case '\n': + return 2; + default: + ThrowHelper.ThrowArgumentException_NewLineAfterLineFeed(); + break; + } + break; + case 2: + switch (character) + { + case '\t': + case ' ': + return 0; + default: + ThrowHelper.ThrowArgumentException_TabAndSpaceAfterLineFeed(); + break; + } + break; + } + + return state; + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/DefaultHttpMessage.cs b/src/DotNetty.Codecs.Http/DefaultHttpMessage.cs new file mode 100644 index 0000000..4ba8e30 --- /dev/null +++ b/src/DotNetty.Codecs.Http/DefaultHttpMessage.cs @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +// ReSharper disable ConvertToAutoPropertyWithPrivateSetter +namespace DotNetty.Codecs.Http +{ + using System.Diagnostics.Contracts; + + public abstract class DefaultHttpMessage : DefaultHttpObject, IHttpMessage + { + const int HashCodePrime = 31; + HttpVersion version; + readonly HttpHeaders headers; + + protected DefaultHttpMessage(HttpVersion version) : this(version, true, false) + { + } + + protected DefaultHttpMessage(HttpVersion version, bool validateHeaders, bool singleFieldHeaders) + { + Contract.Requires(version != null); + + this.version = version; + this.headers = singleFieldHeaders + ? new CombinedHttpHeaders(validateHeaders) + : new DefaultHttpHeaders(validateHeaders); + } + + protected DefaultHttpMessage(HttpVersion version, HttpHeaders headers) + { + Contract.Requires(version != null); + Contract.Requires(headers != null); + + this.version = version; + this.headers = headers; + } + + public HttpHeaders Headers => this.headers; + + public HttpVersion ProtocolVersion => this.version; + + public override int GetHashCode() + { + int result = 1; + result = HashCodePrime * result + this.headers.GetHashCode(); + // ReSharper disable once NonReadonlyMemberInGetHashCode + result = HashCodePrime * result + this.version.GetHashCode(); + result = HashCodePrime * result + base.GetHashCode(); + return result; + } + + public override bool Equals(object obj) + { + if (!(obj is DefaultHttpMessage other)) + { + return false; + } + + return this.headers.Equals(other.headers) + && this.version.Equals(other.version) + && base.Equals(obj); + } + + public IHttpMessage SetProtocolVersion(HttpVersion value) + { + Contract.Requires(value != null); + this.version = value; + return this; + } + } +} diff --git a/src/DotNetty.Codecs.Http/DefaultHttpObject.cs b/src/DotNetty.Codecs.Http/DefaultHttpObject.cs new file mode 100644 index 0000000..0cd7d71 --- /dev/null +++ b/src/DotNetty.Codecs.Http/DefaultHttpObject.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System.Diagnostics.Contracts; + + public class DefaultHttpObject : IHttpObject + { + const int HashCodePrime = 31; + DecoderResult decoderResult = DecoderResult.Success; + + protected DefaultHttpObject() + { + } + + public DecoderResult Result + { + get => this.decoderResult; + set + { + Contract.Requires(value != null); + this.decoderResult = value; + } + } + + public override int GetHashCode() + { + int result = 1; + // ReSharper disable once NonReadonlyMemberInGetHashCode + result = HashCodePrime * result + this.decoderResult.GetHashCode(); + return result; + } + + public override bool Equals(object obj) + { + if (!(obj is DefaultHttpObject other)) + { + return false; + } + return this.decoderResult.Equals(other.decoderResult); + } + } +} diff --git a/src/DotNetty.Codecs.Http/DefaultHttpRequest.cs b/src/DotNetty.Codecs.Http/DefaultHttpRequest.cs new file mode 100644 index 0000000..16d09b9 --- /dev/null +++ b/src/DotNetty.Codecs.Http/DefaultHttpRequest.cs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoPropertyWithPrivateSetter +namespace DotNetty.Codecs.Http +{ + using System; + using System.Diagnostics.Contracts; + using System.Text; + + public class DefaultHttpRequest : DefaultHttpMessage, IHttpRequest + { + const int HashCodePrime = 31; + + HttpMethod method; + string uri; + + public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, string uri) + : this(httpVersion, method, uri, true) + { + } + + public DefaultHttpRequest(HttpVersion version, HttpMethod method, string uri, bool validateHeaders) + : base(version, validateHeaders, false) + { + Contract.Requires(method != null); + Contract.Requires(uri != null); + + this.method = method; + this.uri = uri; + } + + public DefaultHttpRequest(HttpVersion version, HttpMethod method, string uri, HttpHeaders headers) + : base(version, headers) + { + Contract.Requires(method != null); + Contract.Requires(uri != null); + + this.method = method; + this.uri = uri; + } + + public HttpMethod Method => this.method; + + public string Uri => this.uri; + + public IHttpRequest SetMethod(HttpMethod value) + { + Contract.Requires(value != null); + this.method = value; + return this; + } + + public IHttpRequest SetUri(string value) + { + Contract.Requires(value != null); + this.uri = value; + return this; + } + + // ReSharper disable NonReadonlyMemberInGetHashCode + public override int GetHashCode() + { + int result = 1; + result = HashCodePrime * result + this.method.GetHashCode(); + result = HashCodePrime * result + this.uri.GetHashCode(); + result = HashCodePrime * result + base.GetHashCode(); + + return result; + } + // ReSharper restore NonReadonlyMemberInGetHashCode + + public override bool Equals(object obj) + { + if (!(obj is DefaultHttpRequest other)) + { + return false; + } + + return this.method.Equals(other.method) + && this.uri.Equals(other.uri, StringComparison.OrdinalIgnoreCase) + && base.Equals(obj); + } + + public override string ToString() => HttpMessageUtil.AppendRequest(new StringBuilder(256), this).ToString(); + } +} diff --git a/src/DotNetty.Codecs.Http/DefaultHttpResponse.cs b/src/DotNetty.Codecs.Http/DefaultHttpResponse.cs new file mode 100644 index 0000000..5bfb9b6 --- /dev/null +++ b/src/DotNetty.Codecs.Http/DefaultHttpResponse.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWithPrivateSetter +namespace DotNetty.Codecs.Http +{ + using System.Diagnostics.Contracts; + using System.Text; + + public class DefaultHttpResponse : DefaultHttpMessage, IHttpResponse + { + HttpResponseStatus status; + + public DefaultHttpResponse(HttpVersion version, HttpResponseStatus status, bool validateHeaders = true, bool singleFieldHeaders = false) + : base(version, validateHeaders, singleFieldHeaders) + { + Contract.Requires(status != null); + + this.status = status; + } + + public DefaultHttpResponse(HttpVersion version, HttpResponseStatus status, HttpHeaders headers) + : base(version, headers) + { + Contract.Requires(status != null); + + this.status = status; + } + + public HttpResponseStatus Status => this.status; + + public IHttpResponse SetStatus(HttpResponseStatus value) + { + Contract.Requires(value != null); + this.status = value; + return this; + } + + public override string ToString() => HttpMessageUtil.AppendResponse(new StringBuilder(256), this).ToString(); + } +} diff --git a/src/DotNetty.Codecs.Http/DefaultLastHttpContent.cs b/src/DotNetty.Codecs.Http/DefaultLastHttpContent.cs new file mode 100644 index 0000000..cfa2f69 --- /dev/null +++ b/src/DotNetty.Codecs.Http/DefaultLastHttpContent.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + + public class DefaultLastHttpContent : DefaultHttpContent, ILastHttpContent + { + readonly HttpHeaders trailingHeaders; + readonly bool validateHeaders; + + public DefaultLastHttpContent() : this(Unpooled.Buffer(0), true) + { + } + + public DefaultLastHttpContent(IByteBuffer content) : this(content, true) + { + } + + public DefaultLastHttpContent(IByteBuffer content, bool validateHeaders) + : base(content) + { + this.trailingHeaders = new TrailingHttpHeaders(validateHeaders); + this.validateHeaders = validateHeaders; + } + + public HttpHeaders TrailingHeaders => this.trailingHeaders; + + public override IByteBufferHolder Replace(IByteBuffer buffer) + { + var dup = new DefaultLastHttpContent(this.Content, this.validateHeaders); + dup.TrailingHeaders.Set(this.trailingHeaders); + return dup; + } + + public override string ToString() + { + var buf = new StringBuilder(base.ToString()); + buf.Append(StringUtil.Newline); + this.AppendHeaders(buf); + + // Remove the last newline. + buf.Length = buf.Length - StringUtil.Newline.Length; + return buf.ToString(); + } + + void AppendHeaders(StringBuilder buf) + { + foreach (HeaderEntry e in this.trailingHeaders) + { + buf.Append($"{e.Key}: {e.Value}{StringUtil.Newline}"); + } + } + + sealed class TrailerNameValidator : INameValidator + { + public void ValidateName(ICharSequence name) + { + DefaultHttpHeaders.HttpNameValidator.ValidateName(name); + if (HttpHeaderNames.ContentLength.ContentEqualsIgnoreCase(name) + || HttpHeaderNames.TransferEncoding.ContentEqualsIgnoreCase(name) + || HttpHeaderNames.Trailer.ContentEqualsIgnoreCase(name)) + { + ThrowHelper.ThrowArgumentException_TrailingHeaderName(name); + } + } + } + + sealed class TrailingHttpHeaders : DefaultHttpHeaders + { + static readonly TrailerNameValidator TrailerNameValidator = new TrailerNameValidator(); + + public TrailingHttpHeaders(bool validate) + : base(validate, validate ? TrailerNameValidator : NotNullValidator) + { + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/DotNetty.Codecs.Http.csproj b/src/DotNetty.Codecs.Http/DotNetty.Codecs.Http.csproj new file mode 100644 index 0000000..b6e29b1 --- /dev/null +++ b/src/DotNetty.Codecs.Http/DotNetty.Codecs.Http.csproj @@ -0,0 +1,48 @@ + + + + netstandard1.3;net45 + true + DotNetty.Codecs.Http + Http codec for DotNetty + Copyright © Microsoft Corporation + DotNetty: Http codec + en-US + 0.4.7 + Microsoft Azure + $(NoWarn);CS1591 + true + false + true + DotNetty.Codecs.Http + ../../DotNetty.snk + true + true + socket;tcp;protocol;netty;dotnetty;network;http + https://github.com/Azure/DotNetty/ + https://github.com/Azure/DotNetty/blob/master/LICENSE.txt + git + https://github.com/Azure/DotNetty/ + 1.6.1 + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/DotNetty.Codecs.Http/EmptyHttpHeaders.cs b/src/DotNetty.Codecs.Http/EmptyHttpHeaders.cs new file mode 100644 index 0000000..15a53d6 --- /dev/null +++ b/src/DotNetty.Codecs.Http/EmptyHttpHeaders.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using System.Collections.Generic; + using System.Collections.Immutable; + using System.Linq; + using DotNetty.Common.Utilities; + + public class EmptyHttpHeaders : HttpHeaders + { + static readonly IEnumerator> EmptryEnumerator = + Enumerable.Empty>().GetEnumerator(); + + public static readonly EmptyHttpHeaders Default = new EmptyHttpHeaders(); + + protected EmptyHttpHeaders() + { + } + + public override bool TryGet(AsciiString name, out ICharSequence value) + { + value = default(ICharSequence); + return false; + } + + public override bool TryGetInt(AsciiString name, out int value) + { + value = default(int); + return false; + } + + public override int GetInt(AsciiString name, int defaultValue) => defaultValue; + + public override bool TryGetShort(AsciiString name, out short value) + { + value = default(short); + return false; + } + + public override short GetShort(AsciiString name, short defaultValue) => defaultValue; + + public override bool TryGetTimeMillis(AsciiString name, out long value) + { + value = default(long); + return false; + } + + public override long GetTimeMillis(AsciiString name, long defaultValue) => defaultValue; + + public override IList GetAll(AsciiString name) => ImmutableList.Empty; + + public override IList> Entries() => ImmutableList>.Empty; + + public override bool Contains(AsciiString name) => false; + + public override bool IsEmpty => true; + + public override int Size => 0; + + public override ISet Names() => ImmutableHashSet.Empty; + + public override HttpHeaders AddInt(AsciiString name, int value) => throw new NotSupportedException("read only"); + + public override HttpHeaders AddShort(AsciiString name, short value) => throw new NotSupportedException("read only"); + + public override HttpHeaders Set(AsciiString name, object value) => throw new NotSupportedException("read only"); + + public override HttpHeaders Set(AsciiString name, IEnumerable values) => throw new NotSupportedException("read only"); + + public override HttpHeaders SetInt(AsciiString name, int value) => throw new NotSupportedException("read only"); + + public override HttpHeaders SetShort(AsciiString name, short value) => throw new NotSupportedException("read only"); + + public override HttpHeaders Remove(AsciiString name) => throw new NotSupportedException("read only"); + + public override HttpHeaders Clear() => throw new NotSupportedException("read only"); + + public override HttpHeaders Add(AsciiString name, object value) => throw new NotSupportedException("read only"); + + public override IEnumerator> GetEnumerator() => EmptryEnumerator; + } +} diff --git a/src/DotNetty.Codecs.Http/EmptyLastHttpContent.cs b/src/DotNetty.Codecs.Http/EmptyLastHttpContent.cs new file mode 100644 index 0000000..11007bb --- /dev/null +++ b/src/DotNetty.Codecs.Http/EmptyLastHttpContent.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using DotNetty.Buffers; + using DotNetty.Common; + + public sealed class EmptyLastHttpContent : ILastHttpContent + { + public static readonly EmptyLastHttpContent Default = new EmptyLastHttpContent(); + + EmptyLastHttpContent() + { + this.Content = Unpooled.Empty; + } + + public DecoderResult Result + { + get => DecoderResult.Success; + set => throw new NotSupportedException("read only"); + } + + public int ReferenceCount => 1; + + public IReferenceCounted Retain() => this; + + public IReferenceCounted Retain(int increment) => this; + + public IReferenceCounted Touch() => this; + + public IReferenceCounted Touch(object hint) => this; + + public bool Release() => false; + + public bool Release(int decrement) => false; + + public IByteBuffer Content { get; } + + public IByteBufferHolder Copy() => this; + + public IByteBufferHolder Duplicate() => this; + + public IByteBufferHolder RetainedDuplicate() => this; + + public IByteBufferHolder Replace(IByteBuffer content) => new DefaultLastHttpContent(content); + + public HttpHeaders TrailingHeaders => EmptyHttpHeaders.Default; + + public override string ToString() => nameof(EmptyLastHttpContent); + } +} diff --git a/src/DotNetty.Codecs.Http/HttpChunkedInput.cs b/src/DotNetty.Codecs.Http/HttpChunkedInput.cs new file mode 100644 index 0000000..7633f52 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpChunkedInput.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using DotNetty.Buffers; + using DotNetty.Handlers.Streams; + + public class HttpChunkedInput : IChunkedInput + { + readonly IChunkedInput input; + readonly ILastHttpContent lastHttpContent; + bool sentLastChunk; + + public HttpChunkedInput(IChunkedInput input) + { + this.input = input; + this.lastHttpContent = EmptyLastHttpContent.Default; + } + + public HttpChunkedInput(IChunkedInput input, ILastHttpContent lastHttpContent) + { + this.input = input; + this.lastHttpContent = lastHttpContent; + } + + public bool IsEndOfInput => this.input.IsEndOfInput && this.sentLastChunk; + + public void Close() => this.input.Close(); + + public IHttpContent ReadChunk(IByteBufferAllocator allocator) + { + if (this.input.IsEndOfInput) + { + if (this.sentLastChunk) + { + return null; + } + // Send last chunk for this input + this.sentLastChunk = true; + return this.lastHttpContent; + } + else + { + IByteBuffer buf = this.input.ReadChunk(allocator); + return buf == null ? null : new DefaultHttpContent(buf); + } + } + + public long Length => this.input.Length; + + public long Progress => this.input.Progress; + } +} diff --git a/src/DotNetty.Codecs.Http/HttpClientCodec.cs b/src/DotNetty.Codecs.Http/HttpClientCodec.cs new file mode 100644 index 0000000..dd539d7 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpClientCodec.cs @@ -0,0 +1,259 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System.Collections.Generic; + using System.Threading; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class HttpClientCodec : CombinedChannelDuplexHandler, + HttpClientUpgradeHandler.ISourceCodec + { + // A queue that is used for correlating a request and a response. + readonly Queue queue = new Queue(); + readonly bool parseHttpAfterConnectRequest; + + // If true, decoding stops (i.e. pass-through) + bool done; + + long requestResponseCounter; + readonly bool failOnMissingResponse; + + public HttpClientCodec() : this(4096, 8192, 8192, false) + { + } + + public HttpClientCodec(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) + : this(maxInitialLineLength, maxHeaderSize, maxChunkSize, false) + { + } + + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool failOnMissingResponse) + : this(maxInitialLineLength, maxHeaderSize, maxChunkSize, failOnMissingResponse, true) + { + } + + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool failOnMissingResponse, + bool validateHeaders) + : this(maxInitialLineLength, maxHeaderSize, maxChunkSize, failOnMissingResponse, validateHeaders, false) + { + } + + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool failOnMissingResponse, + bool validateHeaders, bool parseHttpAfterConnectRequest) + { + this.Init(new Decoder(this, maxInitialLineLength, maxHeaderSize, maxChunkSize, validateHeaders), new Encoder(this)); + this.failOnMissingResponse = failOnMissingResponse; + this.parseHttpAfterConnectRequest = parseHttpAfterConnectRequest; + } + + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool failOnMissingResponse, + bool validateHeaders, int initialBufferSize) + : this(maxInitialLineLength, maxHeaderSize, maxChunkSize, failOnMissingResponse, validateHeaders, initialBufferSize, false) + { + } + + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool failOnMissingResponse, + bool validateHeaders, int initialBufferSize, bool parseHttpAfterConnectRequest) + { + this.Init(new Decoder(this, maxInitialLineLength, maxHeaderSize, maxChunkSize, validateHeaders, initialBufferSize), new Encoder(this)); + this.parseHttpAfterConnectRequest = parseHttpAfterConnectRequest; + this.failOnMissingResponse = failOnMissingResponse; + } + + public void PrepareUpgradeFrom(IChannelHandlerContext ctx) => ((Encoder)this.OutboundHandler).Upgraded = true; + + public void UpgradeFrom(IChannelHandlerContext ctx) + { + IChannelPipeline p = ctx.Channel.Pipeline; + p.Remove(this); + } + + public bool SingleDecode + { + get => this.InboundHandler.SingleDecode; + set => this.InboundHandler.SingleDecode = value; + } + + sealed class Encoder : HttpRequestEncoder + { + readonly HttpClientCodec clientCodec; + internal bool Upgraded; + + public Encoder(HttpClientCodec clientCodec) + { + this.clientCodec = clientCodec; + } + + protected override void Encode(IChannelHandlerContext context, object message, List output) + { + if (this.Upgraded) + { + output.Add(ReferenceCountUtil.Retain(message)); + return; + } + + if (message is IHttpRequest request && !this.clientCodec.done) + { + this.clientCodec.queue.Enqueue(request.Method); + } + + base.Encode(context, message, output); + + if (this.clientCodec.failOnMissingResponse && !this.clientCodec.done) + { + // check if the request is chunked if so do not increment + if (message is ILastHttpContent) + { + // increment as its the last chunk + Interlocked.Increment(ref this.clientCodec.requestResponseCounter); + } + } + } + } + + sealed class Decoder : HttpResponseDecoder + { + readonly HttpClientCodec clientCodec; + + internal Decoder(HttpClientCodec clientCodec, int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool validateHeaders) + : base(maxInitialLineLength, maxHeaderSize, maxChunkSize, validateHeaders) + { + this.clientCodec = clientCodec; + } + + internal Decoder(HttpClientCodec clientCodec, int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool validateHeaders, int initialBufferSize) + : base(maxInitialLineLength, maxHeaderSize, maxChunkSize, validateHeaders, initialBufferSize) + { + this.clientCodec = clientCodec; + } + + protected override void Decode(IChannelHandlerContext context, IByteBuffer buffer, List output) + { + if (this.clientCodec.done) + { + int readable = this.ActualReadableBytes; + if (readable == 0) + { + // if non is readable just return null + // https://github.com/netty/netty/issues/1159 + return; + } + output.Add(buffer.ReadBytes(readable)); + } + else + { + int oldSize = output.Count; + base.Decode(context, buffer, output); + if (this.clientCodec.failOnMissingResponse) + { + int size = output.Count; + for (int i = oldSize; i < size; i++) + { + this.Decrement(output[i]); + } + } + } + } + + void Decrement(object msg) + { + if (ReferenceEquals(null, msg)) + { + return; + } + + // check if it's an Header and its transfer encoding is not chunked. + if (msg is ILastHttpContent) + { + Interlocked.Decrement(ref this.clientCodec.requestResponseCounter); + } + } + + protected override bool IsContentAlwaysEmpty(IHttpMessage msg) + { + int statusCode = ((IHttpResponse)msg).Status.Code; + if (statusCode == 100 || statusCode == 101) + { + // 100-continue and 101 switching protocols response should be excluded from paired comparison. + // Just delegate to super method which has all the needed handling. + return base.IsContentAlwaysEmpty(msg); + } + + // Get the getMethod of the HTTP request that corresponds to the + // current response. + HttpMethod method = this.clientCodec.queue.Dequeue(); + + char firstChar = method.AsciiName[0]; + switch (firstChar) + { + case 'H': + // According to 4.3, RFC2616: + // All responses to the HEAD request getMethod MUST NOT include a + // message-body, even though the presence of entity-header fields + // might lead one to believe they do. + if (HttpMethod.Head.Equals(method)) + { + return true; + + // The following code was inserted to work around the servers + // that behave incorrectly. It has been commented out + // because it does not work with well behaving servers. + // Please note, even if the 'Transfer-Encoding: chunked' + // header exists in the HEAD response, the response should + // have absolutely no content. + // + // Interesting edge case: + // Some poorly implemented servers will send a zero-byte + // chunk if Transfer-Encoding of the response is 'chunked'. + // + // return !msg.isChunked(); + } + break; + case 'C': + // Successful CONNECT request results in a response with empty body. + if (statusCode == 200) + { + if (HttpMethod.Connect.Equals(method)) + { + // Proxy connection established - Parse HTTP only if configured by parseHttpAfterConnectRequest, + // else pass through. + if (!this.clientCodec.parseHttpAfterConnectRequest) + { + this.clientCodec.done = true; + this.clientCodec.queue.Clear(); + } + return true; + } + } + break; + } + + return base.IsContentAlwaysEmpty(msg); + } + + public override void ChannelInactive(IChannelHandlerContext ctx) + { + base.ChannelInactive(ctx); + + if (this.clientCodec.failOnMissingResponse) + { + long missingResponses = Interlocked.Read(ref this.clientCodec.requestResponseCounter); + if (missingResponses > 0) + { + ctx.FireExceptionCaught(new PrematureChannelClosureException( + $"channel gone inactive with {missingResponses} missing response(s)")); + } + } + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpClientUpgradeHandler.cs b/src/DotNetty.Codecs.Http/HttpClientUpgradeHandler.cs new file mode 100644 index 0000000..9f10eee --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpClientUpgradeHandler.cs @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Diagnostics.Contracts; + using System.Text; + using System.Threading.Tasks; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + // Note HttpObjectAggregator already implements IChannelHandler + public class HttpClientUpgradeHandler : HttpObjectAggregator + { + // User events that are fired to notify about upgrade status. + public enum UpgradeEvent + { + // The Upgrade request was sent to the server. + UpgradeIssued, + + // The Upgrade to the new protocol was successful. + UpgradeSuccessful, + + // The Upgrade was unsuccessful due to the server not issuing + // with a 101 Switching Protocols response. + UpgradeRejected + } + + public interface ISourceCodec + { + // Removes or disables the encoder of this codec so that the {@link UpgradeCodec} can send an initial greeting + // (if any). + void PrepareUpgradeFrom(IChannelHandlerContext ctx); + + // Removes this codec (i.e. all associated handlers) from the pipeline. + void UpgradeFrom(IChannelHandlerContext ctx); + } + + public interface IUpgradeCodec + { + // Returns the name of the protocol supported by this codec, as indicated by the {@code 'UPGRADE'} header. + ICharSequence Protocol { get; } + + // Sets any protocol-specific headers required to the upgrade request. Returns the names of + // all headers that were added. These headers will be used to populate the CONNECTION header. + ICollection SetUpgradeHeaders(IChannelHandlerContext ctx, IHttpRequest upgradeRequest); + + /// + // Performs an HTTP protocol upgrade from the source codec. This method is responsible for + // adding all handlers required for the new protocol. + // + // ctx the context for the current handler. + // upgradeResponse the 101 Switching Protocols response that indicates that the server + // has switched to this protocol. + void UpgradeTo(IChannelHandlerContext ctx, IFullHttpResponse upgradeResponse); + } + + readonly ISourceCodec sourceCodec; + readonly IUpgradeCodec upgradeCodec; + bool upgradeRequested; + + public HttpClientUpgradeHandler(ISourceCodec sourceCodec, IUpgradeCodec upgradeCodec, int maxContentLength) + : base(maxContentLength) + { + Contract.Requires(sourceCodec != null); + Contract.Requires(upgradeCodec != null); + + this.sourceCodec = sourceCodec; + this.upgradeCodec = upgradeCodec; + } + + public override Task WriteAsync(IChannelHandlerContext context, object message) + { + if (!(message is IHttpRequest)) + { + return context.WriteAsync(message); + } + + if (this.upgradeRequested) + { + return TaskEx.FromException(new InvalidOperationException("Attempting to write HTTP request with upgrade in progress")); + } + + this.upgradeRequested = true; + this.SetUpgradeRequestHeaders(context, (IHttpRequest)message); + + // Continue writing the request. + Task task = context.WriteAsync(message); + + // Notify that the upgrade request was issued. + context.FireUserEventTriggered(UpgradeEvent.UpgradeIssued); + // Now we wait for the next HTTP response to see if we switch protocols. + return task; + } + + protected override void Decode(IChannelHandlerContext context, IHttpObject message, List output) + { + IFullHttpResponse response = null; + try + { + if (!this.upgradeRequested) + { + throw new InvalidOperationException("Read HTTP response without requesting protocol switch"); + } + + if (message is IHttpResponse rep) + { + if (!HttpResponseStatus.SwitchingProtocols.Equals(rep.Status)) + { + // The server does not support the requested protocol, just remove this handler + // and continue processing HTTP. + // NOTE: not releasing the response since we're letting it propagate to the + // next handler. + context.FireUserEventTriggered(UpgradeEvent.UpgradeRejected); + RemoveThisHandler(context); + context.FireChannelRead(rep); + return; + } + } + + if (message is IFullHttpResponse fullRep) + { + response = fullRep; + // Need to retain since the base class will release after returning from this method. + response.Retain(); + output.Add(response); + } + else + { + // Call the base class to handle the aggregation of the full request. + base.Decode(context, message, output); + if (output.Count == 0) + { + // The full request hasn't been created yet, still awaiting more data. + return; + } + + Debug.Assert(output.Count == 1); + response = (IFullHttpResponse)output[0]; + } + + if (response.Headers.TryGet(HttpHeaderNames.Upgrade, out ICharSequence upgradeHeader) && !AsciiString.ContentEqualsIgnoreCase(this.upgradeCodec.Protocol, upgradeHeader)) + { + throw new InvalidOperationException($"Switching Protocols response with unexpected UPGRADE protocol: {upgradeHeader}"); + } + + // Upgrade to the new protocol. + this.sourceCodec.PrepareUpgradeFrom(context); + this.upgradeCodec.UpgradeTo(context, response); + + // Notify that the upgrade to the new protocol completed successfully. + context.FireUserEventTriggered(UpgradeEvent.UpgradeSuccessful); + + // We guarantee UPGRADE_SUCCESSFUL event will be arrived at the next handler + // before http2 setting frame and http response. + this.sourceCodec.UpgradeFrom(context); + + // We switched protocols, so we're done with the upgrade response. + // Release it and clear it from the output. + response.Release(); + output.Clear(); + RemoveThisHandler(context); + } + catch (Exception exception) + { + ReferenceCountUtil.Release(response); + context.FireExceptionCaught(exception); + RemoveThisHandler(context); + } + } + + static void RemoveThisHandler(IChannelHandlerContext ctx) => ctx.Channel.Pipeline.Remove(ctx.Name); + + void SetUpgradeRequestHeaders(IChannelHandlerContext ctx, IHttpRequest request) + { + // Set the UPGRADE header on the request. + request.Headers.Set(HttpHeaderNames.Upgrade, this.upgradeCodec.Protocol); + + // Add all protocol-specific headers to the request. + var connectionParts = new List(2); + connectionParts.AddRange(this.upgradeCodec.SetUpgradeHeaders(ctx, request)); + + // Set the CONNECTION header from the set of all protocol-specific headers that were added. + var builder = new StringBuilder(); + foreach (ICharSequence part in connectionParts) + { + builder.Append(part); + builder.Append(','); + } + builder.Append(HttpHeaderValues.Upgrade); + request.Headers.Set(HttpHeaderNames.Connection, new StringCharSequence(builder.ToString())); + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpConstants.cs b/src/DotNetty.Codecs.Http/HttpConstants.cs new file mode 100644 index 0000000..d8689f6 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpConstants.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System.Text; + using DotNetty.Buffers; + + public static class HttpConstants + { + // Horizontal space + public const byte HorizontalSpace = 32; + + // Horizontal tab + public const byte HorizontalTab = 9; + + // Carriage return + public const byte CarriageReturn = 13; + + // Equals '=' + public const byte EqualsSign = 61; + + // Line feed character + public const byte LineFeed = 10; + + // Colon ':' + public const byte Colon = 58; + + // Semicolon ';' + public const byte Semicolon = 59; + + // Comma ',' + public const byte Comma = 44; + + // Double quote '"' + public const byte DoubleQuote = (byte)'"'; + + // Default character set (UTF-8) + public static readonly Encoding DefaultEncoding = Encoding.UTF8; + + // Horizontal space in char + public static readonly char HorizontalSpaceChar = (char)HorizontalSpace; + + // For HttpObjectEncoder + internal static readonly int CrlfShort = (CarriageReturn << 8) | LineFeed; + + internal static readonly int ZeroCrlfMedium = ('0' << 16) | CrlfShort; + + internal static readonly byte[] ZeroCrlfCrlf = { (byte)'0', CarriageReturn, LineFeed, CarriageReturn, LineFeed }; + + internal static readonly IByteBuffer CrlfBuf = Unpooled.UnreleasableBuffer(Unpooled.WrappedBuffer(new[] { CarriageReturn, LineFeed })); + + internal static readonly IByteBuffer ZeroCrlfCrlfBuf = Unpooled.UnreleasableBuffer(Unpooled.WrappedBuffer(ZeroCrlfCrlf)); + } +} diff --git a/src/DotNetty.Codecs.Http/HttpContentCompressor.cs b/src/DotNetty.Codecs.Http/HttpContentCompressor.cs new file mode 100644 index 0000000..fae24a4 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpContentCompressor.cs @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using System.Diagnostics.Contracts; + using DotNetty.Codecs.Compression; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + + public class HttpContentCompressor : HttpContentEncoder + { + static readonly AsciiString GZipString = AsciiString.Cached("gzip"); + static readonly AsciiString DeflateString = AsciiString.Cached("deflate"); + + readonly int compressionLevel; + readonly int windowBits; + readonly int memLevel; + + IChannelHandlerContext handlerContext; + + public HttpContentCompressor() : this(6) + { + } + + public HttpContentCompressor(int compressionLevel) : this(compressionLevel, 15, 8) + { + } + + public HttpContentCompressor(int compressionLevel, int windowBits, int memLevel) + { + Contract.Requires(compressionLevel >= 0 && compressionLevel <= 9); + Contract.Requires(windowBits >= 9 && windowBits <= 15); + Contract.Requires(memLevel >= 1 && memLevel <= 9); + + this.compressionLevel = compressionLevel; + this.windowBits = windowBits; + this.memLevel = memLevel; + } + + public override void HandlerAdded(IChannelHandlerContext context) => this.handlerContext = context; + + protected override Result BeginEncode(IHttpResponse headers, ICharSequence acceptEncoding) + { + if (headers.Headers.Contains(HttpHeaderNames.ContentEncoding)) + { + // Content-Encoding was set, either as something specific or as the IDENTITY encoding + // Therefore, we should NOT encode here + return null; + } + + ZlibWrapper? wrapper = this.DetermineWrapper(acceptEncoding); + if (wrapper == null) + { + return null; + } + + ICharSequence targetContentEncoding; + switch (wrapper.Value) + { + case ZlibWrapper.Gzip: + targetContentEncoding = GZipString; + break; + case ZlibWrapper.Zlib: + targetContentEncoding = DeflateString; + break; + default: + throw new CodecException($"{wrapper.Value} not supported, only Gzip and Zlib are allowed."); + } + + return new Result(targetContentEncoding, + new EmbeddedChannel( + this.handlerContext.Channel.Id, + this.handlerContext.Channel.Metadata.HasDisconnect, + this.handlerContext.Channel.Configuration, + ZlibCodecFactory.NewZlibEncoder( + wrapper.Value, this.compressionLevel, this.windowBits, this.memLevel))); + } + + protected internal ZlibWrapper? DetermineWrapper(ICharSequence acceptEncoding) + { + float starQ = -1.0f; + float gzipQ = -1.0f; + float deflateQ = -1.0f; + ICharSequence[] parts = CharUtil.Split(acceptEncoding, ','); + foreach (ICharSequence encoding in parts) + { + float q = 1.0f; + int equalsPos = encoding.IndexOf('='); + if (equalsPos != -1) + { + try + { + q = float.Parse(encoding.ToString(equalsPos + 1)); + } + catch (FormatException) + { + // Ignore encoding + q = 0.0f; + } + } + + if (CharUtil.Contains(encoding, '*')) + { + starQ = q; + } + else if (AsciiString.Contains(encoding, GZipString) && q > gzipQ) + { + gzipQ = q; + } + else if (AsciiString.Contains(encoding, DeflateString) && q > deflateQ) + { + deflateQ = q; + } + } + if (gzipQ > 0.0f || deflateQ > 0.0f) + { + return gzipQ >= deflateQ ? ZlibWrapper.Gzip : ZlibWrapper.Zlib; + } + if (starQ > 0.0f) + { + // ReSharper disable CompareOfFloatsByEqualityOperator + if (gzipQ == -1.0f) + { + return ZlibWrapper.Gzip; + } + if (deflateQ == -1.0f) + { + return ZlibWrapper.Zlib; + } + // ReSharper restore CompareOfFloatsByEqualityOperator + } + return null; + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpContentDecoder.cs b/src/DotNetty.Codecs.Http/HttpContentDecoder.cs new file mode 100644 index 0000000..8932289 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpContentDecoder.cs @@ -0,0 +1,243 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using System.Collections.Generic; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + + public abstract class HttpContentDecoder : MessageToMessageDecoder + { + internal static readonly AsciiString Identity = HttpHeaderValues.Identity; + + protected IChannelHandlerContext HandlerContext; + EmbeddedChannel decoder; + bool continueResponse; + + protected override void Decode(IChannelHandlerContext context, IHttpObject message, List output) + { + if (message is IHttpResponse response && response.Status.Code == 100) + { + if (!(response is ILastHttpContent)) + { + this.continueResponse = true; + } + // 100-continue response must be passed through. + output.Add(ReferenceCountUtil.Retain(message)); + return; + } + + if (this.continueResponse) + { + if (message is ILastHttpContent) + { + this.continueResponse = false; + } + // 100-continue response must be passed through. + output.Add(ReferenceCountUtil.Retain(message)); + return; + } + + if (message is IHttpMessage httpMessage) + { + this.Cleanup(); + HttpHeaders headers = httpMessage.Headers; + + // Determine the content encoding. + if (headers.TryGet(HttpHeaderNames.ContentEncoding, out ICharSequence contentEncoding)) + { + contentEncoding = AsciiString.Trim(contentEncoding); + } + else + { + contentEncoding = Identity; + } + this.decoder = this.NewContentDecoder(contentEncoding); + + if (this.decoder == null) + { + if (httpMessage is IHttpContent httpContent) + { + httpContent.Retain(); + } + output.Add(httpMessage); + return; + } + + // Remove content-length header: + // the correct value can be set only after all chunks are processed/decoded. + // If buffering is not an issue, add HttpObjectAggregator down the chain, it will set the header. + // Otherwise, rely on LastHttpContent message. + if (headers.Contains(HttpHeaderNames.ContentLength)) + { + headers.Remove(HttpHeaderNames.ContentLength); + headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + } + // Either it is already chunked or EOF terminated. + // See https://github.com/netty/netty/issues/5892 + + // set new content encoding, + ICharSequence targetContentEncoding = this.GetTargetContentEncoding(contentEncoding); + if (HttpHeaderValues.Identity.ContentEquals(targetContentEncoding)) + { + // Do NOT set the 'Content-Encoding' header if the target encoding is 'identity' + // as per: http://tools.ietf.org/html/rfc2616#section-14.11 + headers.Remove(HttpHeaderNames.ContentEncoding); + } + else + { + headers.Set(HttpHeaderNames.ContentEncoding, targetContentEncoding); + } + + if (httpMessage is IHttpContent) + { + // If message is a full request or response object (headers + data), don't copy data part into out. + // Output headers only; data part will be decoded below. + // Note: "copy" object must not be an instance of LastHttpContent class, + // as this would (erroneously) indicate the end of the HttpMessage to other handlers. + IHttpMessage copy; + if (httpMessage is IHttpRequest req) + { + // HttpRequest or FullHttpRequest + copy = new DefaultHttpRequest(req.ProtocolVersion, req.Method, req.Uri); + } + else if (httpMessage is IHttpResponse res) + { + // HttpResponse or FullHttpResponse + copy = new DefaultHttpResponse(res.ProtocolVersion, res.Status); + } + else + { + throw new CodecException($"Object of class {StringUtil.SimpleClassName(httpMessage.GetType())} is not a HttpRequest or HttpResponse"); + } + copy.Headers.Set(httpMessage.Headers); + copy.Result = httpMessage.Result; + output.Add(copy); + } + else + { + output.Add(httpMessage); + } + } + + if (message is IHttpContent c) + { + if (this.decoder == null) + { + output.Add(c.Retain()); + } + else + { + this.DecodeContent(c, output); + } + } + } + + void DecodeContent(IHttpContent c, IList output) + { + IByteBuffer content = c.Content; + + this.Decode(content, output); + + if (c is ILastHttpContent last) + { + this.FinishDecode(output); + + // Generate an additional chunk if the decoder produced + // the last product on closure, + HttpHeaders headers = last.TrailingHeaders; + if (headers.IsEmpty) + { + output.Add(EmptyLastHttpContent.Default); + } + else + { + output.Add(new ComposedLastHttpContent(headers)); + } + } + } + + protected abstract EmbeddedChannel NewContentDecoder(ICharSequence contentEncoding); + + protected ICharSequence GetTargetContentEncoding(ICharSequence contentEncoding) => Identity; + + public override void HandlerRemoved(IChannelHandlerContext context) + { + this.CleanupSafely(context); + base.HandlerRemoved(context); + } + + public override void ChannelInactive(IChannelHandlerContext context) + { + this.CleanupSafely(context); + base.ChannelInactive(context); + } + + public override void HandlerAdded(IChannelHandlerContext context) + { + this.HandlerContext = context; + base.HandlerAdded(context); + } + + void Cleanup() + { + if (this.decoder != null) + { + this.decoder.FinishAndReleaseAll(); + this.decoder = null; + } + } + + void CleanupSafely(IChannelHandlerContext context) + { + try + { + this.Cleanup(); + } + catch (Exception cause) + { + // If cleanup throws any error we need to propagate it through the pipeline + // so we don't fail to propagate pipeline events. + context.FireExceptionCaught(cause); + } + } + + void Decode(IByteBuffer buf, IList output) + { + // call retain here as it will call release after its written to the channel + this.decoder.WriteInbound(buf.Retain()); + this.FetchDecoderOutput(output); + } + + void FinishDecode(ICollection output) + { + if (this.decoder.Finish()) + { + this.FetchDecoderOutput(output); + } + this.decoder = null; + } + + void FetchDecoderOutput(ICollection output) + { + for (;;) + { + var buf = this.decoder.ReadInbound(); + if (buf == null) + { + break; + } + if (!buf.IsReadable()) + { + buf.Release(); + continue; + } + output.Add(new DefaultHttpContent(buf)); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpContentDecompressor.cs b/src/DotNetty.Codecs.Http/HttpContentDecompressor.cs new file mode 100644 index 0000000..35eed08 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpContentDecompressor.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using DotNetty.Codecs.Compression; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + + public class HttpContentDecompressor : HttpContentDecoder + { + readonly bool strict; + + public HttpContentDecompressor() : this(false) + { + } + + public HttpContentDecompressor(bool strict) + { + this.strict = strict; + } + + protected override EmbeddedChannel NewContentDecoder(ICharSequence contentEncoding) + { + if (HttpHeaderValues.Gzip.ContentEqualsIgnoreCase(contentEncoding) + || HttpHeaderValues.XGzip.ContentEqualsIgnoreCase(contentEncoding)) + { + return new EmbeddedChannel( + this.HandlerContext.Channel.Id, + this.HandlerContext.Channel.Metadata.HasDisconnect, + this.HandlerContext.Channel.Configuration, + ZlibCodecFactory.NewZlibDecoder(ZlibWrapper.Gzip)); + } + + if (HttpHeaderValues.Deflate.ContentEqualsIgnoreCase(contentEncoding) + || HttpHeaderValues.XDeflate.ContentEqualsIgnoreCase(contentEncoding)) + { + ZlibWrapper wrapper = this.strict ? ZlibWrapper.Zlib : ZlibWrapper.ZlibOrNone; + return new EmbeddedChannel( + this.HandlerContext.Channel.Id, + this.HandlerContext.Channel.Metadata.HasDisconnect, + this.HandlerContext.Channel.Configuration, + ZlibCodecFactory.NewZlibDecoder(wrapper)); + } + + // 'identity' or unsupported + return null; + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpContentEncoder.cs b/src/DotNetty.Codecs.Http/HttpContentEncoder.cs new file mode 100644 index 0000000..eb7454e --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpContentEncoder.cs @@ -0,0 +1,357 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Diagnostics.Contracts; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + + public abstract class HttpContentEncoder : MessageToMessageCodec + { + enum State + { + PassThrough, + AwaitHeaders, + AwaitContent + } + + static readonly AsciiString ZeroLengthHead = AsciiString.Cached("HEAD"); + static readonly AsciiString ZeroLengthConnect = AsciiString.Cached("CONNECT"); + static readonly int ContinueCode = HttpResponseStatus.Continue.Code; + + readonly Queue acceptEncodingQueue = new Queue(); + EmbeddedChannel encoder; + State state = State.AwaitHeaders; + + public override bool AcceptOutboundMessage(object msg) => msg is IHttpContent || msg is IHttpResponse; + + protected override void Decode(IChannelHandlerContext ctx, IHttpRequest msg, List output) + { + ICharSequence acceptedEncoding = msg.Headers.Get(HttpHeaderNames.AcceptEncoding, HttpContentDecoder.Identity); + + HttpMethod meth = msg.Method; + if (ReferenceEquals(meth, HttpMethod.Head)) + { + acceptedEncoding = ZeroLengthHead; + } + else if (ReferenceEquals(meth, HttpMethod.Connect)) + { + acceptedEncoding = ZeroLengthConnect; + } + + this.acceptEncodingQueue.Enqueue(acceptedEncoding); + output.Add(ReferenceCountUtil.Retain(msg)); + } + + protected override void Encode(IChannelHandlerContext ctx, IHttpObject msg, List output) + { + bool isFull = msg is IHttpResponse && msg is ILastHttpContent; + switch (this.state) + { + case State.AwaitHeaders: + { + EnsureHeaders(msg); + Debug.Assert(this.encoder == null); + + var res = (IHttpResponse)msg; + int code = res.Status.Code; + ICharSequence acceptEncoding; + if (code == ContinueCode) + { + // We need to not poll the encoding when response with CONTINUE as another response will follow + // for the issued request. See https://github.com/netty/netty/issues/4079 + acceptEncoding = null; + } + else + { + // Get the list of encodings accepted by the peer. + acceptEncoding = this.acceptEncodingQueue.Count > 0 ? this.acceptEncodingQueue.Dequeue() : null; + if (acceptEncoding == null) + { + throw new InvalidOperationException("cannot send more responses than requests"); + } + } + + // + // per rfc2616 4.3 Message Body + // All 1xx (informational), 204 (no content), and 304 (not modified) responses MUST NOT include a + // message-body. All other responses do include a message-body, although it MAY be of zero length. + // + // 9.4 HEAD + // The HEAD method is identical to GET except that the server MUST NOT return a message-body + // in the response. + // + // Also we should pass through HTTP/1.0 as transfer-encoding: chunked is not supported. + // + // See https://github.com/netty/netty/issues/5382 + // + if (IsPassthru(res.ProtocolVersion, code, acceptEncoding)) + { + if (isFull) + { + output.Add(ReferenceCountUtil.Retain(res)); + } + else + { + output.Add(res); + // Pass through all following contents. + this.state = State.PassThrough; + } + break; + } + + if (isFull) + { + // Pass through the full response with empty content and continue waiting for the the next resp. + if (!((IByteBufferHolder)res).Content.IsReadable()) + { + output.Add(ReferenceCountUtil.Retain(res)); + break; + } + } + + // Prepare to encode the content. + Result result = this.BeginEncode(res, acceptEncoding); + + // If unable to encode, pass through. + if (result == null) + { + if (isFull) + { + output.Add(ReferenceCountUtil.Retain(res)); + } + else + { + output.Add(res); + // Pass through all following contents. + this.state = State.PassThrough; + } + break; + } + + this.encoder = result.ContentEncoder; + + // Encode the content and remove or replace the existing headers + // so that the message looks like a decoded message. + res.Headers.Set(HttpHeaderNames.ContentEncoding, result.TargetContentEncoding); + + // Output the rewritten response. + if (isFull) + { + // Convert full message into unfull one. + var newRes = new DefaultHttpResponse(res.ProtocolVersion, res.Status); + newRes.Headers.Set(res.Headers); + output.Add(newRes); + + EnsureContent(res); + this.EncodeFullResponse(newRes, (IHttpContent)res, output); + break; + } + else + { + // Make the response chunked to simplify content transformation. + res.Headers.Remove(HttpHeaderNames.ContentLength); + res.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + + output.Add(res); + this.state = State.AwaitContent; + + if (!(msg is IHttpContent)) + { + // only break out the switch statement if we have not content to process + // See https://github.com/netty/netty/issues/2006 + break; + } + // Fall through to encode the content + goto case State.AwaitContent; + } + } + case State.AwaitContent: + { + EnsureContent(msg); + if (this.EncodeContent((IHttpContent)msg, output)) + { + this.state = State.AwaitHeaders; + } + break; + } + case State.PassThrough: + { + EnsureContent(msg); + output.Add(ReferenceCountUtil.Retain(msg)); + // Passed through all following contents of the current response. + if (msg is ILastHttpContent) + { + this.state = State.AwaitHeaders; + } + break; + } + } + } + + void EncodeFullResponse(IHttpResponse newRes, IHttpContent content, IList output) + { + int existingMessages = output.Count; + this.EncodeContent(content, output); + + if (HttpUtil.IsContentLengthSet(newRes)) + { + // adjust the content-length header + int messageSize = 0; + for (int i = existingMessages; i < output.Count; i++) + { + if (output[i] is IHttpContent httpContent) + { + messageSize += httpContent.Content.ReadableBytes; + } + } + HttpUtil.SetContentLength(newRes, messageSize); + } + else + { + newRes.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + } + } + + static bool IsPassthru(HttpVersion version, int code, ICharSequence httpMethod) => + code < 200 || code == 204 || code == 304 + || (ReferenceEquals(httpMethod, ZeroLengthHead) || ReferenceEquals(httpMethod, ZeroLengthConnect) && code == 200) + || ReferenceEquals(version, HttpVersion.Http10); + + static void EnsureHeaders(IHttpObject msg) + { + if (!(msg is IHttpResponse)) + { + throw new CodecException($"unexpected message type: {msg.GetType().Name} (expected: {StringUtil.SimpleClassName()})"); + } + } + + static void EnsureContent(IHttpObject msg) + { + if (!(msg is IHttpContent)) + { + throw new CodecException($"unexpected message type: {msg.GetType().Name} (expected: {StringUtil.SimpleClassName()})"); + } + } + + bool EncodeContent(IHttpContent c, IList output) + { + IByteBuffer content = c.Content; + + this.Encode(content, output); + + if (c is ILastHttpContent last) + { + this.FinishEncode(output); + + // Generate an additional chunk if the decoder produced + // the last product on closure, + HttpHeaders headers = last.TrailingHeaders; + if (headers.IsEmpty) + { + output.Add(EmptyLastHttpContent.Default); + } + else + { + output.Add(new ComposedLastHttpContent(headers)); + } + return true; + } + return false; + } + + protected abstract Result BeginEncode(IHttpResponse headers, ICharSequence acceptEncoding); + + public override void HandlerRemoved(IChannelHandlerContext context) + { + this.CleanupSafely(context); + base.HandlerRemoved(context); + } + + public override void ChannelInactive(IChannelHandlerContext context) + { + this.CleanupSafely(context); + base.ChannelInactive(context); + } + + void Cleanup() + { + if (this.encoder != null) + { + // Clean-up the previous encoder if not cleaned up correctly. + this.encoder.FinishAndReleaseAll(); + this.encoder = null; + } + } + + void CleanupSafely(IChannelHandlerContext ctx) + { + try + { + this.Cleanup(); + } + catch (Exception cause) + { + // If cleanup throws any error we need to propagate it through the pipeline + // so we don't fail to propagate pipeline events. + ctx.FireExceptionCaught(cause); + } + } + + void Encode(IByteBuffer buf, IList output) + { + // call retain here as it will call release after its written to the channel + this.encoder.WriteOutbound(buf.Retain()); + this.FetchEncoderOutput(output); + } + + void FinishEncode(IList output) + { + if (this.encoder.Finish()) + { + this.FetchEncoderOutput(output); + } + this.encoder = null; + } + + void FetchEncoderOutput(ICollection output) + { + for (;;) + { + var buf = this.encoder.ReadOutbound(); + if (buf == null) + { + break; + } + if (!buf.IsReadable()) + { + buf.Release(); + continue; + } + output.Add(new DefaultHttpContent(buf)); + } + } + + public sealed class Result + { + public Result(ICharSequence targetContentEncoding, EmbeddedChannel contentEncoder) + { + Contract.Requires(targetContentEncoding != null); + Contract.Requires(contentEncoder != null); + + this.TargetContentEncoding = targetContentEncoding; + this.ContentEncoder = contentEncoder; + } + + public ICharSequence TargetContentEncoding { get; } + + public EmbeddedChannel ContentEncoder { get; } + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpExpectationFailedEvent.cs b/src/DotNetty.Codecs.Http/HttpExpectationFailedEvent.cs new file mode 100644 index 0000000..984bb3e --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpExpectationFailedEvent.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + // A user event designed to communicate that a expectation has failed and there should be no expectation that a + // body will follow. + public sealed class HttpExpectationFailedEvent + { + public static readonly HttpExpectationFailedEvent Default = new HttpExpectationFailedEvent(); + + HttpExpectationFailedEvent() + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpHeaderNames.cs b/src/DotNetty.Codecs.Http/HttpHeaderNames.cs new file mode 100644 index 0000000..711a7bd --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpHeaderNames.cs @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using DotNetty.Common.Utilities; + + /// + /// Standard HTTP header names. + /// + /// These are all defined as lowercase to support HTTP/2 requirements while also not + /// violating HTTP/1.x requirements.New header names should always be lowercase. + /// + public static class HttpHeaderNames + { + public static readonly AsciiString Accept = AsciiString.Cached("accept"); + + public static readonly AsciiString AcceptCharset = AsciiString.Cached("accept-charset"); + + public static readonly AsciiString AcceptEncoding = AsciiString.Cached("accept-encoding"); + + public static readonly AsciiString AcceptLanguage = AsciiString.Cached("accept-language"); + + public static readonly AsciiString AcceptRanges = AsciiString.Cached("accept-ranges"); + + public static readonly AsciiString AcceptPatch = AsciiString.Cached("accept-patch"); + + public static readonly AsciiString AccessControlAllowCredentials = AsciiString.Cached("access-control-allow-credentials"); + + public static readonly AsciiString AccessControlAllowHeaders = AsciiString.Cached("access-control-allow-headers"); + + public static readonly AsciiString AccessControlAllowMethods = AsciiString.Cached("access-control-allow-methods"); + + public static readonly AsciiString AccessControlAllowOrigin = AsciiString.Cached("access-control-allow-origin"); + + public static readonly AsciiString AccessControlExposeHeaders = AsciiString.Cached("access-control-expose-headers"); + + public static readonly AsciiString AccessControlMaxAge = AsciiString.Cached("access-control-max-age"); + + public static readonly AsciiString AccessControlRequestHeaders = AsciiString.Cached("access-control-request-headers"); + + public static readonly AsciiString AccessControlRequestMethod = AsciiString.Cached("access-control-request-method"); + + public static readonly AsciiString Age = AsciiString.Cached("age"); + + public static readonly AsciiString Allow = AsciiString.Cached("allow"); + + public static readonly AsciiString Authorization = AsciiString.Cached("authorization"); + + public static readonly AsciiString CacheControl = AsciiString.Cached("cache-control"); + + public static readonly AsciiString Connection = AsciiString.Cached("connection"); + + public static readonly AsciiString ContentBase = AsciiString.Cached("content-base"); + + public static readonly AsciiString ContentEncoding = AsciiString.Cached("content-encoding"); + + public static readonly AsciiString ContentLanguage = AsciiString.Cached("content-language"); + + public static readonly AsciiString ContentLength = AsciiString.Cached("content-length"); + + public static readonly AsciiString ContentLocation = AsciiString.Cached("content-location"); + + public static readonly AsciiString ContentTransferEncoding = AsciiString.Cached("content-transfer-encoding"); + + public static readonly AsciiString ContentDisposition = AsciiString.Cached("content-disposition"); + + public static readonly AsciiString ContentMD5 = AsciiString.Cached("content-md5"); + + public static readonly AsciiString ContentRange = AsciiString.Cached("content-range"); + + public static readonly AsciiString ContentSecurityPolicy = AsciiString.Cached("content-security-policy"); + + public static readonly AsciiString ContentType = AsciiString.Cached("content-type"); + + public static readonly AsciiString Cookie = AsciiString.Cached("cookie"); + + public static readonly AsciiString Date = AsciiString.Cached("date"); + + public static readonly AsciiString Etag = AsciiString.Cached("etag"); + + public static readonly AsciiString Expect = AsciiString.Cached("expect"); + + public static readonly AsciiString Expires = AsciiString.Cached("expires"); + + public static readonly AsciiString From = AsciiString.Cached("from"); + + public static readonly AsciiString Host = AsciiString.Cached("host"); + + public static readonly AsciiString IfMatch = AsciiString.Cached("if-match"); + + public static readonly AsciiString IfModifiedSince = AsciiString.Cached("if-modified-since"); + + public static readonly AsciiString IfNoneMatch = AsciiString.Cached("if-none-match"); + + public static readonly AsciiString IfRange = AsciiString.Cached("if-range"); + + public static readonly AsciiString IfUnmodifiedSince = AsciiString.Cached("if-unmodified-since"); + + public static readonly AsciiString LastModified = AsciiString.Cached("last-modified"); + + public static readonly AsciiString Location = AsciiString.Cached("location"); + + public static readonly AsciiString MaxForwards = AsciiString.Cached("max-forwards"); + + public static readonly AsciiString Origin = AsciiString.Cached("origin"); + + public static readonly AsciiString Pragma = AsciiString.Cached("pragma"); + + public static readonly AsciiString ProxyAuthenticate = AsciiString.Cached("proxy-authenticate"); + + public static readonly AsciiString ProxyAuthorization = AsciiString.Cached("proxy-authorization"); + + public static readonly AsciiString Range = AsciiString.Cached("range"); + + public static readonly AsciiString Referer = AsciiString.Cached("referer"); + + public static readonly AsciiString RetryAfter = AsciiString.Cached("retry-after"); + + public static readonly AsciiString SecWebsocketKey1 = AsciiString.Cached("sec-websocket-key1"); + + public static readonly AsciiString SecWebsocketKey2 = AsciiString.Cached("sec-websocket-key2"); + + public static readonly AsciiString SecWebsocketLocation = AsciiString.Cached("sec-websocket-location"); + + public static readonly AsciiString SecWebsocketOrigin = AsciiString.Cached("sec-websocket-origin"); + + public static readonly AsciiString SecWebsocketProtocol = AsciiString.Cached("sec-websocket-protocol"); + + public static readonly AsciiString SecWebsocketVersion = AsciiString.Cached("sec-websocket-version"); + + public static readonly AsciiString SecWebsocketKey = AsciiString.Cached("sec-websocket-key"); + + public static readonly AsciiString SecWebsocketAccept = AsciiString.Cached("sec-websocket-accept"); + + public static readonly AsciiString SecWebsocketExtensions = AsciiString.Cached("sec-websocket-extensions"); + + public static readonly AsciiString Server = AsciiString.Cached("server"); + + public static readonly AsciiString SetCookie = AsciiString.Cached("set-cookie"); + + public static readonly AsciiString SetCookie2 = AsciiString.Cached("set-cookie2"); + + public static readonly AsciiString Te = AsciiString.Cached("te"); + + public static readonly AsciiString Trailer = AsciiString.Cached("trailer"); + + public static readonly AsciiString TransferEncoding = AsciiString.Cached("transfer-encoding"); + + public static readonly AsciiString Upgrade = AsciiString.Cached("upgrade"); + + public static readonly AsciiString UserAgent = AsciiString.Cached("user-agent"); + + public static readonly AsciiString Vary = AsciiString.Cached("vary"); + + public static readonly AsciiString Via = AsciiString.Cached("via"); + + public static readonly AsciiString Warning = AsciiString.Cached("warning"); + + public static readonly AsciiString WebsocketLocation = AsciiString.Cached("websocket-location"); + + public static readonly AsciiString WebsocketOrigin = AsciiString.Cached("websocket-origin"); + + public static readonly AsciiString WebsocketProtocol = AsciiString.Cached("websocket-protocol"); + + public static readonly AsciiString WwwAuthenticate = AsciiString.Cached("www-authenticate"); + + public static readonly AsciiString XFrameOptions = AsciiString.Cached("x-frame-options"); + } +} diff --git a/src/DotNetty.Codecs.Http/HttpHeaderValues.cs b/src/DotNetty.Codecs.Http/HttpHeaderValues.cs new file mode 100644 index 0000000..b3ad16c --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpHeaderValues.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using DotNetty.Common.Utilities; + + public static class HttpHeaderValues + { + public static readonly AsciiString ApplicationJson = AsciiString.Cached("application/json"); + + public static readonly AsciiString ApplicationXWwwFormUrlencoded = AsciiString.Cached("application/x-www-form-urlencoded"); + + public static readonly AsciiString ApplicationOctetStream = AsciiString.Cached("application/octet-stream"); + + public static readonly AsciiString Attachment = AsciiString.Cached("attachment"); + + public static readonly AsciiString Base64 = AsciiString.Cached("base64"); + + public static readonly AsciiString Binary = AsciiString.Cached("binary"); + + public static readonly AsciiString Boundary = AsciiString.Cached("boundary"); + + public static readonly AsciiString Bytes = AsciiString.Cached("bytes"); + + public static readonly AsciiString Charset = AsciiString.Cached("charset"); + + public static readonly AsciiString Chunked = AsciiString.Cached("chunked"); + + public static readonly AsciiString Close = AsciiString.Cached("close"); + + public static readonly AsciiString Compress = AsciiString.Cached("compress"); + + public static readonly AsciiString Continue = AsciiString.Cached("100-continue"); + + public static readonly AsciiString Deflate = AsciiString.Cached("deflate"); + + public static readonly AsciiString XDeflate = AsciiString.Cached("x-deflate"); + + public static readonly AsciiString File = AsciiString.Cached("file"); + + public static readonly AsciiString FileName = AsciiString.Cached("filename"); + + public static readonly AsciiString FormData = AsciiString.Cached("form-data"); + + public static readonly AsciiString Gzip = AsciiString.Cached("gzip"); + + public static readonly AsciiString GzipDeflate = AsciiString.Cached("gzip,deflate"); + + public static readonly AsciiString XGzip = AsciiString.Cached("x-gzip"); + + public static readonly AsciiString Identity = AsciiString.Cached("identity"); + + public static readonly AsciiString KeepAlive = AsciiString.Cached("keep-alive"); + + public static readonly AsciiString MaxAge = AsciiString.Cached("max-age"); + + public static readonly AsciiString MaxStale = AsciiString.Cached("max-stale"); + + public static readonly AsciiString MinFresh = AsciiString.Cached("min-fresh"); + + public static readonly AsciiString MultipartFormData = AsciiString.Cached("multipart/form-data"); + + public static readonly AsciiString MultipartMixed = AsciiString.Cached("multipart/mixed"); + + public static readonly AsciiString MustRevalidate = AsciiString.Cached("must-revalidate"); + + public static readonly AsciiString Name = AsciiString.Cached("name"); + + public static readonly AsciiString NoCache = AsciiString.Cached("no-cache"); + + public static readonly AsciiString NoStore = AsciiString.Cached("no-store"); + + public static readonly AsciiString NoTransform = AsciiString.Cached("no-transform"); + + public static readonly AsciiString None = AsciiString.Cached("none"); + + public static readonly AsciiString Zero = AsciiString.Cached("0"); + + public static readonly AsciiString OnlyIfCached = AsciiString.Cached("only-if-cached"); + + public static readonly AsciiString Private = AsciiString.Cached("private"); + + public static readonly AsciiString ProxyRevalidate = AsciiString.Cached("proxy-revalidate"); + + public static readonly AsciiString Public = AsciiString.Cached("public"); + + public static readonly AsciiString QuotedPrintable = AsciiString.Cached("quoted-printable"); + + public static readonly AsciiString SMaxage = AsciiString.Cached("s-maxage"); + + public static readonly AsciiString TextPlain = AsciiString.Cached("text/plain"); + + public static readonly AsciiString Trailers = AsciiString.Cached("trailers"); + + public static readonly AsciiString Upgrade = AsciiString.Cached("upgrade"); + + public static readonly AsciiString Websocket = AsciiString.Cached("websocket"); + } +} diff --git a/src/DotNetty.Codecs.Http/HttpHeaders.cs b/src/DotNetty.Codecs.Http/HttpHeaders.cs new file mode 100644 index 0000000..025a870 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpHeaders.cs @@ -0,0 +1,267 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ForCanBeConvertedToForeach +namespace DotNetty.Codecs.Http +{ + using System.Collections; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using DotNetty.Common.Utilities; + + using static Common.Utilities.AsciiString; + + public abstract class HttpHeaders : IEnumerable> + { + public abstract bool TryGet(AsciiString name, out ICharSequence value); + + public ICharSequence Get(AsciiString name, ICharSequence defaultValue) => this.TryGet(name, out ICharSequence value) ? value : defaultValue; + + public abstract bool TryGetInt(AsciiString name, out int value); + + public abstract int GetInt(AsciiString name, int defaultValue); + + public abstract bool TryGetShort(AsciiString name, out short value); + + public abstract short GetShort(AsciiString name, short defaultValue); + + public abstract bool TryGetTimeMillis(AsciiString name, out long value); + + public abstract long GetTimeMillis(AsciiString name, long defaultValue); + + public abstract IList GetAll(AsciiString name); + + public abstract IList> Entries(); + + public virtual IEnumerable ValueCharSequenceIterator(AsciiString name) => this.GetAll(name); + + public abstract bool Contains(AsciiString name); + + public abstract bool IsEmpty { get; } + + public abstract int Size { get; } + + public abstract ISet Names(); + + public abstract HttpHeaders Add(AsciiString name, object value); + + public HttpHeaders Add(AsciiString name, IEnumerable values) + { + foreach (object value in values) + { + this.Add(name, value); + } + return this; + } + + public virtual HttpHeaders Add(HttpHeaders headers) + { + Contract.Requires(headers != null); + + foreach (HeaderEntry pair in headers) + { + this.Add(pair.Key, pair.Value); + } + return this; + } + + public abstract HttpHeaders AddInt(AsciiString name, int value); + + public abstract HttpHeaders AddShort(AsciiString name, short value); + + public abstract HttpHeaders Set(AsciiString name, object value); + + public abstract HttpHeaders Set(AsciiString name, IEnumerable values); + + public virtual HttpHeaders Set(HttpHeaders headers) + { + Contract.Requires(headers != null); + + this.Clear(); + + if (headers.IsEmpty) + { + return this; + } + + foreach(HeaderEntry pair in headers) + { + this.Add(pair.Key, pair.Value); + } + return this; + } + + public HttpHeaders SetAll(HttpHeaders headers) + { + Contract.Requires(headers != null); + + if (headers.IsEmpty) + { + return this; + } + + foreach (HeaderEntry pair in headers) + { + this.Add(pair.Key, pair.Value); + } + + return this; + } + + public abstract HttpHeaders SetInt(AsciiString name, int value); + + public abstract HttpHeaders SetShort(AsciiString name, short value); + + public abstract HttpHeaders Remove(AsciiString name); + + public abstract HttpHeaders Clear(); + + public virtual bool Contains(AsciiString name, ICharSequence value, bool ignoreCase) + { + IEnumerable values = this.ValueCharSequenceIterator(name); + if (ignoreCase) + { + foreach (ICharSequence v in values) + { + if (v.ContentEqualsIgnoreCase(value)) + { + return true; + } + } + } + else + { + foreach (ICharSequence v in this.ValueCharSequenceIterator(name)) + { + if (v.ContentEquals(value)) + { + return true; + } + } + } + return false; + } + + public virtual bool ContainsValue(AsciiString name, ICharSequence value, bool ignoreCase) + { + foreach (ICharSequence v in this.ValueCharSequenceIterator(name)) + { + if (ContainsCommaSeparatedTrimmed(v, value, ignoreCase)) + { + return true; + } + } + return false; + } + + static bool ContainsCommaSeparatedTrimmed(ICharSequence rawNext, ICharSequence expected, bool ignoreCase) + { + int begin = 0; + int end; + if (ignoreCase) + { + if ((end = IndexOf(rawNext, ',', begin)) == -1) + { + if (ContentEqualsIgnoreCase(Trim(rawNext), expected)) + { + return true; + } + } + else + { + do + { + if (ContentEqualsIgnoreCase(Trim(rawNext.SubSequence(begin, end)), expected)) + { + return true; + } + begin = end + 1; + } + while ((end = IndexOf(rawNext, ',', begin)) != -1); + + if (begin < rawNext.Count) + { + if (ContentEqualsIgnoreCase(Trim(rawNext.SubSequence(begin, rawNext.Count)), expected)) + { + return true; + } + } + } + } + else + { + if ((end = IndexOf(rawNext, ',', begin)) == -1) + { + if (ContentEquals(Trim(rawNext), expected)) + { + return true; + } + } + else + { + do + { + if (ContentEquals(Trim(rawNext.SubSequence(begin, end)), expected)) + { + return true; + } + begin = end + 1; + } + while ((end = IndexOf(rawNext, ',', begin)) != -1); + + if (begin < rawNext.Count) + { + if (ContentEquals(Trim(rawNext.SubSequence(begin, rawNext.Count)), expected)) + { + return true; + } + } + } + } + return false; + } + + public bool TryGetAsString(AsciiString name, out string value) + { + if (this.TryGet(name, out ICharSequence v)) + { + value = v.ToString(); + return true; + } + else + { + value = default(string); + return false; + } + } + + public IList GetAllAsString(AsciiString name) + { + var values = new List(); + IList list = this.GetAll(name); + foreach (ICharSequence value in list) + { + values.Add(value.ToString()); + } + + return values; + } + + public abstract IEnumerator> GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator(); + + public override string ToString() => HeadersUtils.ToString(this, this.Size); + + /// + /// Deep copy of the headers. + /// + /// A deap copy of this. + public virtual HttpHeaders Copy() + { + var copy = new DefaultHttpHeaders(); + copy.Set(this); + return copy; + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpHeadersEncoder.cs b/src/DotNetty.Codecs.Http/HttpHeadersEncoder.cs new file mode 100644 index 0000000..55fedb1 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpHeadersEncoder.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + + using static HttpConstants; + + static class HttpHeadersEncoder + { + const int ColonAndSpaceShort = (Colon << 8) | HorizontalSpace; + + public static void EncoderHeader(AsciiString name, ICharSequence value, IByteBuffer buf) + { + int nameLen = name.Count; + int valueLen = value.Count; + int entryLen = nameLen + valueLen + 4; + buf.EnsureWritable(entryLen); + int offset = buf.WriterIndex; + WriteAscii(buf, offset, name); + offset += nameLen; + buf.SetShort(offset, ColonAndSpaceShort); + offset += 2; + WriteAscii(buf, offset, value); + offset += valueLen; + buf.SetShort(offset, CrlfShort); + offset += 2; + buf.SetWriterIndex(offset); + } + + static void WriteAscii(IByteBuffer buf, int offset, ICharSequence value) + { + if (value is AsciiString asciiString) + { + ByteBufferUtil.Copy(asciiString, 0, buf, offset, value.Count); + } + else + { + buf.SetCharSequence(offset, value, Encoding.ASCII); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpMessageUtil.cs b/src/DotNetty.Codecs.Http/HttpMessageUtil.cs new file mode 100644 index 0000000..44dfdcf --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpMessageUtil.cs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System.Text; + using DotNetty.Common.Utilities; + + static class HttpMessageUtil + { + internal static StringBuilder AppendRequest(StringBuilder buf, IHttpRequest req) + { + AppendCommon(buf, req); + AppendInitialLine(buf, req); + AppendHeaders(buf, req.Headers); + RemoveLastNewLine(buf); + return buf; + } + + internal static StringBuilder AppendResponse(StringBuilder buf, IHttpResponse res) + { + AppendCommon(buf, res); + AppendInitialLine(buf, res); + AppendHeaders(buf, res.Headers); + RemoveLastNewLine(buf); + return buf; + } + + static void AppendCommon(StringBuilder buf, IHttpMessage msg) + { + buf.Append($"{StringUtil.SimpleClassName(msg)}"); + buf.Append("(decodeResult: "); + buf.Append(msg.Result); + buf.Append(", version: "); + buf.Append(msg.ProtocolVersion); + buf.Append($"){StringUtil.Newline}"); + } + + internal static StringBuilder AppendFullRequest(StringBuilder buf, IFullHttpRequest req) + { + AppendFullCommon(buf, req); + AppendInitialLine(buf, req); + AppendHeaders(buf, req.Headers); + AppendHeaders(buf, req.TrailingHeaders); + RemoveLastNewLine(buf); + return buf; + } + + internal static StringBuilder AppendFullResponse(StringBuilder buf, IFullHttpResponse res) + { + AppendFullCommon(buf, res); + AppendInitialLine(buf, res); + AppendHeaders(buf, res.Headers); + AppendHeaders(buf, res.TrailingHeaders); + RemoveLastNewLine(buf); + return buf; + } + + static void AppendFullCommon(StringBuilder buf, IFullHttpMessage msg) + { + buf.Append(StringUtil.SimpleClassName(msg)); + buf.Append("(decodeResult: "); + buf.Append(msg.Result); + buf.Append(", version: "); + buf.Append(msg.ProtocolVersion); + buf.Append(", content: "); + buf.Append(msg.Content); + buf.Append(')'); + buf.Append(StringUtil.Newline); + } + + static void AppendInitialLine(StringBuilder buf, IHttpRequest req) => + buf.Append($"{req.Method} {req.Uri} {req.ProtocolVersion}{StringUtil.Newline}"); + + static void AppendInitialLine(StringBuilder buf, IHttpResponse res) => + buf.Append($"{res.ProtocolVersion} {res.Status}{StringUtil.Newline}"); + + static void AppendHeaders(StringBuilder buf, HttpHeaders headers) + { + foreach(HeaderEntry e in headers) + { + buf.Append($"{e.Key}:{e.Value}{StringUtil.Newline}"); + } + } + + static void RemoveLastNewLine(StringBuilder buf) => buf.Length = buf.Length - StringUtil.Newline.Length; + } +} diff --git a/src/DotNetty.Codecs.Http/HttpMethod.cs b/src/DotNetty.Codecs.Http/HttpMethod.cs new file mode 100644 index 0000000..f530bec --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpMethod.cs @@ -0,0 +1,234 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ForCanBeConvertedToForeach +// ReSharper disable ConvertToAutoPropertyWhenPossible +namespace DotNetty.Codecs.Http +{ + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Runtime.CompilerServices; + using DotNetty.Common.Utilities; + + public sealed class HttpMethod : IComparable, IComparable + { + /** + * The OPTIONS method represents a request for information about the communication options + * available on the request/response chain identified by the Request-URI. This method allows + * the client to determine the options and/or requirements associated with a resource, or the + * capabilities of a server, without implying a resource action or initiating a resource + * retrieval. + */ + public static readonly HttpMethod Options = new HttpMethod("OPTIONS"); + + /** + * The GET method means retrieve whatever information (in the form of an entity) is identified + * by the Request-URI. If the Request-URI refers to a data-producing process, it is the + * produced data which shall be returned as the entity in the response and not the source text + * of the process, unless that text happens to be the output of the process. + */ + public static readonly HttpMethod Get = new HttpMethod("GET"); + + /** + * The HEAD method is identical to GET except that the server MUST NOT return a message-body + * in the response. + */ + public static readonly HttpMethod Head = new HttpMethod("HEAD"); + + /** + * The POST method is used to request that the origin server accept the entity enclosed in the + * request as a new subordinate of the resource identified by the Request-URI in the + * Request-Line. + */ + public static readonly HttpMethod Post = new HttpMethod("POST"); + + /** + * The PUT method requests that the enclosed entity be stored under the supplied Request-URI. + */ + public static readonly HttpMethod Put = new HttpMethod("PUT"); + + /** + * The PATCH method requests that a set of changes described in the + * request entity be applied to the resource identified by the Request-URI. + */ + public static readonly HttpMethod Patch = new HttpMethod("PATCH"); + + /** + * The DELETE method requests that the origin server delete the resource identified by the + * Request-URI. + */ + public static readonly HttpMethod Delete = new HttpMethod("DELETE"); + + /** + * The TRACE method is used to invoke a remote, application-layer loop- back of the request + * message. + */ + public static readonly HttpMethod Trace = new HttpMethod("TRACE"); + + /** + * This specification reserves the method name CONNECT for use with a proxy that can dynamically + * switch to being a tunnel + */ + public static readonly HttpMethod Connect = new HttpMethod("CONNECT"); + + // HashMap + static readonly Dictionary MethodMap; + + static HttpMethod() + { + MethodMap = new Dictionary + { + { Options.ToString(), Options }, + { Get.ToString(), Get }, + { Head.ToString(), Head }, + { Post.ToString(), Post }, + { Put.ToString(), Put }, + { Patch.ToString(), Patch }, + { Delete.ToString(), Delete }, + { Trace.ToString(), Trace }, + { Connect.ToString(), Connect }, + }; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static HttpMethod ValueOf(AsciiString name) + { + if (name != null) + { + HttpMethod result = ValueOfInline(name.Array); + if (result != null) + { + return result; + } + + // Fall back to slow path + if (MethodMap.TryGetValue(name.ToString(), out result)) + { + return result; + } + } + // Really slow path and error handling + return new HttpMethod(name?.ToString()); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static HttpMethod ValueOfInline(byte[] bytes) + { + if (bytes.Length <= 2) + { + return null; + } + + HttpMethod match = null; + int i = 0; + switch (bytes[i++]) + { + case (byte)'C': + match = Connect; + break; + case (byte)'D': + match = Delete; + break; + case (byte)'G': + match = Get; + break; + case (byte)'H': + match = Head; + break; + case (byte)'O': + match = Options; + break; + case (byte)'P': + switch (bytes[i++]) + { + case (byte)'O': + match = Post; + break; + case (byte)'U': + match = Put; + break; + case (byte)'A': + match = Patch; + break; + } + break; + case (byte)'T': + match = Trace; + break; + } + if (match != null) + { + byte[] array = match.name.Array; + if (bytes.Length == array.Length) + { + for (; i < bytes.Length; i++) + { + if (bytes[i] != array[i]) + { + match = null; + break; + } + } + } + else + { + match = null; + } + } + return match; + } + + readonly AsciiString name; + + // Creates a new HTTP method with the specified name. You will not need to + // create a new method unless you are implementing a protocol derived from + // HTTP, such as + // http://en.wikipedia.org/wiki/Real_Time_Streaming_Protocol and + // http://en.wikipedia.org/wiki/Internet_Content_Adaptation_Protocol + // + public HttpMethod(string name) + { + Contract.Requires(name != null); + + name = name.Trim(); + if (string.IsNullOrEmpty(name)) + { + throw new ArgumentException(nameof(name)); + } + + for (int i=0; i this.name.ToString(); + + public AsciiString AsciiName => this.name; + + public override int GetHashCode() => this.name.GetHashCode(); + + public override bool Equals(object obj) + { + if (!(obj is HttpMethod method)) + { + return false; + } + + return this.name.Equals(method.name); + } + + public override string ToString() => this.name.ToString(); + + public int CompareTo(object obj) => this.CompareTo(obj as HttpMethod); + + public int CompareTo(HttpMethod other) => this.name.CompareTo(other.name); + } +} diff --git a/src/DotNetty.Codecs.Http/HttpObjectAggregator.cs b/src/DotNetty.Codecs.Http/HttpObjectAggregator.cs new file mode 100644 index 0000000..6bf8de2 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpObjectAggregator.cs @@ -0,0 +1,354 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using System; + using System.Diagnostics; + using System.Text; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common; + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class HttpObjectAggregator : MessageAggregator + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + static readonly IFullHttpResponse Continue = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.Continue, Unpooled.Empty); + static readonly IFullHttpResponse ExpectationFailed = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.ExpectationFailed, Unpooled.Empty); + static readonly IFullHttpResponse TooLargeClose = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.RequestEntityTooLarge, Unpooled.Empty); + static readonly IFullHttpResponse TooLarge = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.RequestEntityTooLarge, Unpooled.Empty); + + static HttpObjectAggregator() + { + ExpectationFailed.Headers.Set(HttpHeaderNames.ContentLength, HttpHeaderValues.Zero); + TooLarge.Headers.Set(HttpHeaderNames.ContentLength, HttpHeaderValues.Zero); + + TooLargeClose.Headers.Set(HttpHeaderNames.ContentLength, HttpHeaderValues.Zero); + TooLargeClose.Headers.Set(HttpHeaderNames.Connection, HttpHeaderValues.Close); + } + + readonly bool closeOnExpectationFailed; + + public HttpObjectAggregator(int maxContentLength) + : this(maxContentLength, false) + { + } + + public HttpObjectAggregator(int maxContentLength, bool closeOnExpectationFailed) + : base(maxContentLength) + { + this.closeOnExpectationFailed = closeOnExpectationFailed; + } + + protected override bool IsStartMessage(IHttpObject msg) => msg is IHttpMessage; + + protected override bool IsContentMessage(IHttpObject msg) => msg is IHttpContent; + + protected override bool IsLastContentMessage(IHttpContent msg) => msg is ILastHttpContent; + + protected override bool IsAggregated(IHttpObject msg) => msg is IFullHttpMessage; + + protected override bool IsContentLengthInvalid(IHttpMessage start, int maxContentLength) + { + try + { + return HttpUtil.GetContentLength(start, -1) > maxContentLength; + } + catch (FormatException) + { + return false; + } + } + + static object ContinueResponse(IHttpMessage start, int maxContentLength, IChannelPipeline pipeline) + { + if (HttpUtil.IsUnsupportedExpectation(start)) + { + // if the request contains an unsupported expectation, we return 417 + pipeline.FireUserEventTriggered(HttpExpectationFailedEvent.Default); + return ExpectationFailed.RetainedDuplicate(); + } + else if (HttpUtil.Is100ContinueExpected(start)) + { + // if the request contains 100-continue but the content-length is too large, we return 413 + if (HttpUtil.GetContentLength(start, -1L) <= maxContentLength) + { + return Continue.RetainedDuplicate(); + } + pipeline.FireUserEventTriggered(HttpExpectationFailedEvent.Default); + return TooLarge.RetainedDuplicate(); + } + + return null; + } + + protected override object NewContinueResponse(IHttpMessage start, int maxContentLength, IChannelPipeline pipeline) + { + object response = ContinueResponse(start, maxContentLength, pipeline); + // we're going to respond based on the request expectation so there's no + // need to propagate the expectation further. + if (response != null) + { + start.Headers.Remove(HttpHeaderNames.Expect); + } + return response; + } + + protected override bool CloseAfterContinueResponse(object msg) => + this.closeOnExpectationFailed && this.IgnoreContentAfterContinueResponse(msg); + + protected override bool IgnoreContentAfterContinueResponse(object msg) => + msg is IHttpResponse response && response.Status.CodeClass.Equals(HttpStatusClass.ClientError); + + protected override IFullHttpMessage BeginAggregation(IHttpMessage start, IByteBuffer content) + { + Debug.Assert(!(start is IFullHttpMessage)); + + HttpUtil.SetTransferEncodingChunked(start, false); + + if (start is IHttpRequest request) + { + return new AggregatedFullHttpRequest(request, content, null); + } + else if (start is IHttpResponse response) + { + return new AggregatedFullHttpResponse(response, content, null); + } + + throw new CodecException($"Invalid type {StringUtil.SimpleClassName(start)} expecting {nameof(IHttpRequest)} or {nameof(IHttpResponse)}"); + } + + protected override void Aggregate(IFullHttpMessage aggregated, IHttpContent content) + { + if (content is ILastHttpContent httpContent) + { + // Merge trailing headers into the message. + ((AggregatedFullHttpMessage)aggregated).TrailingHeaders = httpContent.TrailingHeaders; + } + } + + protected override void FinishAggregation(IFullHttpMessage aggregated) + { + // Set the 'Content-Length' header. If one isn't already set. + // This is important as HEAD responses will use a 'Content-Length' header which + // does not match the actual body, but the number of bytes that would be + // transmitted if a GET would have been used. + // + // See rfc2616 14.13 Content-Length + if (!HttpUtil.IsContentLengthSet(aggregated)) + { + aggregated.Headers.Set( + HttpHeaderNames.ContentLength, + new AsciiString(aggregated.Content.ReadableBytes.ToString())); + } + } + + protected override void HandleOversizedMessage(IChannelHandlerContext ctx, IHttpMessage oversized) + { + if (oversized is IHttpRequest) + { + // send back a 413 and close the connection + + // If the client started to send data already, close because it's impossible to recover. + // If keep-alive is off and 'Expect: 100-continue' is missing, no need to leave the connection open. + if (oversized is IFullHttpMessage || + !HttpUtil.Is100ContinueExpected(oversized) && !HttpUtil.IsKeepAlive(oversized)) + { + ctx.WriteAndFlushAsync(TooLargeClose.RetainedDuplicate()).ContinueWith((t, s) => + { + if (t.IsFaulted) + { + Logger.Debug("Failed to send a 413 Request Entity Too Large.", t.Exception); + } + ((IChannelHandlerContext)s).CloseAsync(); + }, + ctx, + TaskContinuationOptions.ExecuteSynchronously); + } + else + { + ctx.WriteAndFlushAsync(TooLarge.RetainedDuplicate()).ContinueWith((t, s) => + { + if (t.IsFaulted) + { + Logger.Debug("Failed to send a 413 Request Entity Too Large.", t.Exception); + ((IChannelHandlerContext)s).CloseAsync(); + } + }, + ctx, + TaskContinuationOptions.ExecuteSynchronously); + } + // If an oversized request was handled properly and the connection is still alive + // (i.e. rejected 100-continue). the decoder should prepare to handle a new message. + var decoder = ctx.Channel.Pipeline.Get(); + decoder?.Reset(); + } + else if (oversized is IHttpResponse) + { + ctx.CloseAsync(); + throw new TooLongFrameException($"Response entity too large: {oversized}"); + } + else + { + throw new InvalidOperationException($"Invalid type {StringUtil.SimpleClassName(oversized)}, expecting {nameof(IHttpRequest)} or {nameof(IHttpResponse)}"); + } + } + + abstract class AggregatedFullHttpMessage : IFullHttpMessage + { + protected readonly IHttpMessage Message; + readonly IByteBuffer content; + HttpHeaders trailingHeaders; + + protected AggregatedFullHttpMessage(IHttpMessage message, IByteBuffer content, HttpHeaders trailingHeaders) + { + this.Message = message; + this.content = content; + this.trailingHeaders = trailingHeaders; + } + + public HttpHeaders TrailingHeaders + { + get + { + HttpHeaders headers = this.trailingHeaders; + return headers ?? EmptyHttpHeaders.Default; + } + internal set => this.trailingHeaders = value; + } + + public HttpVersion ProtocolVersion => this.Message.ProtocolVersion; + + public IHttpMessage SetProtocolVersion(HttpVersion version) + { + this.Message.SetProtocolVersion(version); + return this; + } + + public HttpHeaders Headers => this.Message.Headers; + + public DecoderResult Result + { + get => this.Message.Result; + set => this.Message.Result = value; + } + + public IByteBuffer Content => this.content; + + public int ReferenceCount => this.content.ReferenceCount; + + public IReferenceCounted Retain() + { + this.content.Retain(); + return this; + } + + public IReferenceCounted Retain(int increment) + { + this.content.Retain(increment); + return this; + } + + public IReferenceCounted Touch() + { + this.content.Touch(); + return this; + } + + public IReferenceCounted Touch(object hint) + { + this.content.Touch(hint); + return this; + } + + public bool Release() => this.content.Release(); + + public bool Release(int decrement) => this.content.Release(decrement); + + public abstract IByteBufferHolder Copy(); + + public abstract IByteBufferHolder Duplicate(); + + public abstract IByteBufferHolder RetainedDuplicate(); + + public abstract IByteBufferHolder Replace(IByteBuffer content); + } + + sealed class AggregatedFullHttpRequest : AggregatedFullHttpMessage, IFullHttpRequest + { + internal AggregatedFullHttpRequest(IHttpRequest message, IByteBuffer content, HttpHeaders trailingHeaders) + : base(message, content, trailingHeaders) + { + } + + public override IByteBufferHolder Copy() => this.Replace(this.Content.Copy()); + + public override IByteBufferHolder Duplicate() => this.Replace(this.Content.Duplicate()); + + public override IByteBufferHolder RetainedDuplicate() => this.Replace(this.Content.RetainedDuplicate()); + + public override IByteBufferHolder Replace(IByteBuffer content) + { + var dup = new DefaultFullHttpRequest(this.ProtocolVersion, this.Method, this.Uri, content, + this.Headers.Copy(), this.TrailingHeaders.Copy()); + dup.Result = this.Result; + return dup; + } + + public HttpMethod Method => ((IHttpRequest)this.Message).Method; + + public IHttpRequest SetMethod(HttpMethod method) + { + ((IHttpRequest)this.Message).SetMethod(method); + return this; + } + + public string Uri => ((IHttpRequest)this.Message).Uri; + + public IHttpRequest SetUri(string uri) + { + ((IHttpRequest)this.Message).SetUri(uri); + return this; + } + + public override string ToString() => HttpMessageUtil.AppendFullRequest(new StringBuilder(256), this).ToString(); + } + + sealed class AggregatedFullHttpResponse : AggregatedFullHttpMessage, IFullHttpResponse + { + public AggregatedFullHttpResponse(IHttpResponse message, IByteBuffer content, HttpHeaders trailingHeaders) + : base(message, content, trailingHeaders) + { + } + + public override IByteBufferHolder Copy() => this.Replace(this.Content.Copy()); + + public override IByteBufferHolder Duplicate() => this.Replace(this.Content.Duplicate()); + + public override IByteBufferHolder RetainedDuplicate() => this.Replace(this.Content.RetainedDuplicate()); + + public override IByteBufferHolder Replace(IByteBuffer content) + { + var dup = new DefaultFullHttpResponse(this.ProtocolVersion, this.Status, content, + this.Headers.Copy(), this.TrailingHeaders.Copy()); + dup.Result = this.Result; + return dup; + } + + public HttpResponseStatus Status => ((IHttpResponse)this.Message).Status; + + public IHttpResponse SetStatus(HttpResponseStatus status) + { + ((IHttpResponse)this.Message).SetStatus(status); + return this; + } + + public override string ToString() => HttpMessageUtil.AppendFullResponse(new StringBuilder(256), this).ToString(); + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpObjectDecoder.cs b/src/DotNetty.Codecs.Http/HttpObjectDecoder.cs new file mode 100644 index 0000000..2521b42 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpObjectDecoder.cs @@ -0,0 +1,898 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Diagnostics.Contracts; + using System.Runtime.CompilerServices; + using DotNetty.Buffers; + using DotNetty.Common.Internal; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public abstract class HttpObjectDecoder : ByteToMessageDecoder + { + readonly int maxChunkSize; + readonly bool chunkedSupported; + protected readonly bool ValidateHeaders; + readonly HeaderParser headerParser; + readonly LineParser lineParser; + + IHttpMessage message; + long chunkSize; + long contentLength = long.MinValue; + volatile bool resetRequested; + + // These will be updated by splitHeader(...) + AsciiString name; + AsciiString value; + + ILastHttpContent trailer; + + enum State + { + SkipControlChars, + ReadInitial, + ReadHeader, + ReadVariableLengthContent, + ReadFixedLengthContent, + ReadChunkSize, + ReadChunkedContent, + ReadChunkDelimiter, + ReadChunkFooter, + BadMessage, + Upgraded + } + + State currentState = State.SkipControlChars; + + protected HttpObjectDecoder() : this(4096, 8192, 8192, true) + { + } + + protected HttpObjectDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool chunkedSupported) + : this(maxInitialLineLength, maxHeaderSize, maxChunkSize, chunkedSupported, true) + { + } + + protected HttpObjectDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, + bool chunkedSupported, bool validateHeaders) + : this(maxInitialLineLength, maxHeaderSize, maxChunkSize, chunkedSupported, validateHeaders, 128) + { + } + + protected HttpObjectDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, + bool chunkedSupported, bool validateHeaders, int initialBufferSize) + { + Contract.Requires(maxInitialLineLength > 0); + Contract.Requires(maxHeaderSize > 0); + Contract.Requires(maxChunkSize > 0); + + var seq = new AppendableCharSequence(initialBufferSize); + this.lineParser = new LineParser(seq, maxInitialLineLength); + this.headerParser = new HeaderParser(seq, maxHeaderSize); + this.maxChunkSize = maxChunkSize; + this.chunkedSupported = chunkedSupported; + this.ValidateHeaders = validateHeaders; + } + + protected override void Decode(IChannelHandlerContext context, IByteBuffer buffer, List output) + { + if (this.resetRequested) + { + this.ResetNow(); + } + + switch (this.currentState) + { + case State.SkipControlChars: + { + if (!SkipControlCharacters(buffer)) + { + return; + } + this.currentState = State.ReadInitial; + goto case State.ReadInitial; // Fall through + } + case State.ReadInitial: + { + try + { + AppendableCharSequence line = this.lineParser.Parse(buffer); + if (line == null) + { + return; + } + AsciiString[] initialLine = SplitInitialLine(line); + if (initialLine.Length < 3) + { + // Invalid initial line - ignore. + this.currentState = State.SkipControlChars; + return; + } + + this.message = this.CreateMessage(initialLine); + this.currentState = State.ReadHeader; + goto case State.ReadHeader; // Fall through + } + catch (Exception e) + { + output.Add(this.InvalidMessage(buffer, e)); + return; + } + } + case State.ReadHeader: + { + try + { + State? nextState = this.ReadHeaders(buffer); + if (nextState == null) + { + return; + } + this.currentState = nextState.Value; + switch (nextState.Value) + { + case State.SkipControlChars: + { + // fast-path + // No content is expected. + output.Add(this.message); + output.Add(EmptyLastHttpContent.Default); + this.ResetNow(); + return; + } + case State.ReadChunkSize: + { + if (!this.chunkedSupported) + { + throw new ArgumentException("Chunked messages not supported"); + } + // Chunked encoding - generate HttpMessage first. HttpChunks will follow. + output.Add(this.message); + return; + } + default: + { + // RFC 7230, 3.3.3 states that if a + // request does not have either a transfer-encoding or a content-length header then the message body + // length is 0. However for a response the body length is the number of octets received prior to the + // server closing the connection. So we treat this as variable length chunked encoding. + long length = this.ContentLength(); + if (length == 0 || length == -1 && this.IsDecodingRequest()) + { + output.Add(this.message); + output.Add(EmptyLastHttpContent.Default); + this.ResetNow(); + return; + } + + Debug.Assert(nextState.Value == State.ReadFixedLengthContent + || nextState.Value == State.ReadVariableLengthContent); + + output.Add(this.message); + + if (nextState == State.ReadFixedLengthContent) + { + // chunkSize will be decreased as the READ_FIXED_LENGTH_CONTENT state reads data chunk by chunk. + this.chunkSize = length; + } + + // We return here, this forces decode to be called again where we will decode the content + return; + } + } + } + catch (Exception exception) + { + output.Add(this.InvalidMessage(buffer, exception)); + return; + } + } + case State.ReadVariableLengthContent: + { + // Keep reading data as a chunk until the end of connection is reached. + int toRead = Math.Min(buffer.ReadableBytes, this.maxChunkSize); + if (toRead > 0) + { + IByteBuffer content = buffer.ReadRetainedSlice(toRead); + output.Add(new DefaultHttpContent(content)); + } + return; + } + case State.ReadFixedLengthContent: + { + int readLimit = buffer.ReadableBytes; + + // Check if the buffer is readable first as we use the readable byte count + // to create the HttpChunk. This is needed as otherwise we may end up with + // create a HttpChunk instance that contains an empty buffer and so is + // handled like it is the last HttpChunk. + // + // See https://github.com/netty/netty/issues/433 + if (readLimit == 0) + { + return; + } + + int toRead = Math.Min(readLimit, this.maxChunkSize); + if (toRead > this.chunkSize) + { + toRead = (int)this.chunkSize; + } + IByteBuffer content = buffer.ReadRetainedSlice(toRead); + this.chunkSize -= toRead; + + if (this.chunkSize == 0) + { + // Read all content. + output.Add(new DefaultLastHttpContent(content, this.ValidateHeaders)); + this.ResetNow(); + } + else + { + output.Add(new DefaultHttpContent(content)); + } + return; + } + // everything else after this point takes care of reading chunked content. basically, read chunk size, + // read chunk, read and ignore the CRLF and repeat until 0 + case State.ReadChunkSize: + { + try + { + AppendableCharSequence line = this.lineParser.Parse(buffer); + if (line == null) + { + return; + } + int size = GetChunkSize(line.ToAsciiString()); + this.chunkSize = size; + if (size == 0) + { + this.currentState = State.ReadChunkFooter; + return; + } + this.currentState = State.ReadChunkedContent; + goto case State.ReadChunkedContent; // fall-through + } + catch (Exception e) + { + output.Add(this.InvalidChunk(buffer, e)); + return; + } + } + case State.ReadChunkedContent: + { + Debug.Assert(this.chunkSize <= int.MaxValue); + + int toRead = Math.Min((int)this.chunkSize, this.maxChunkSize); + toRead = Math.Min(toRead, buffer.ReadableBytes); + if (toRead == 0) + { + return; + } + IHttpContent chunk = new DefaultHttpContent(buffer.ReadRetainedSlice(toRead)); + this.chunkSize -= toRead; + + output.Add(chunk); + + if (this.chunkSize != 0) + { + return; + } + this.currentState = State.ReadChunkDelimiter; + goto case State.ReadChunkDelimiter; // fall-through + } + case State.ReadChunkDelimiter: + { + int wIdx = buffer.WriterIndex; + int rIdx = buffer.ReaderIndex; + while (wIdx > rIdx) + { + byte next = buffer.GetByte(rIdx++); + if (next == HttpConstants.LineFeed) + { + this.currentState = State.ReadChunkSize; + break; + } + } + buffer.SetReaderIndex(rIdx); + return; + } + case State.ReadChunkFooter: + { + try + { + ILastHttpContent lastTrialer = this.ReadTrailingHeaders(buffer); + if (lastTrialer == null) + { + return; + } + output.Add(lastTrialer); + this.ResetNow(); + return; + } + catch (Exception exception) + { + output.Add(this.InvalidChunk(buffer, exception)); + return; + } + } + case State.BadMessage: + { + // Keep discarding until disconnection. + buffer.SkipBytes(buffer.ReadableBytes); + break; + } + case State.Upgraded: + { + int readableBytes = buffer.ReadableBytes; + if (readableBytes > 0) + { + // Keep on consuming as otherwise we may trigger an DecoderException, + // other handler will replace this codec with the upgraded protocol codec to + // take the traffic over at some point then. + // See https://github.com/netty/netty/issues/2173 + output.Add(buffer.ReadBytes(readableBytes)); + } + break; + } + } + } + + protected override void DecodeLast(IChannelHandlerContext context, IByteBuffer input, List output) + { + base.DecodeLast(context, input, output); + + if (this.resetRequested) + { + // If a reset was requested by decodeLast() we need to do it now otherwise we may produce a + // LastHttpContent while there was already one. + this.ResetNow(); + } + + // Handle the last unfinished message. + if (this.message != null) + { + bool chunked = HttpUtil.IsTransferEncodingChunked(this.message); + if (this.currentState == State.ReadVariableLengthContent + && !input.IsReadable() && !chunked) + { + // End of connection. + output.Add(EmptyLastHttpContent.Default); + this.ResetNow(); + return; + } + + if (this.currentState == State.ReadHeader) + { + // If we are still in the state of reading headers we need to create a new invalid message that + // signals that the connection was closed before we received the headers. + output.Add(this.InvalidMessage(Unpooled.Empty, + new PrematureChannelClosureException("Connection closed before received headers"))); + this.ResetNow(); + return; + } + + // Check if the closure of the connection signifies the end of the content. + bool prematureClosure; + if (this.IsDecodingRequest() || chunked) + { + // The last request did not wait for a response. + prematureClosure = true; + } + else + { + // Compare the length of the received content and the 'Content-Length' header. + // If the 'Content-Length' header is absent, the length of the content is determined by the end of the + // connection, so it is perfectly fine. + prematureClosure = this.ContentLength() > 0; + } + + if (!prematureClosure) + { + output.Add(EmptyLastHttpContent.Default); + } + this.ResetNow(); + } + } + + public override void UserEventTriggered(IChannelHandlerContext context, object evt) + { + if (evt is HttpExpectationFailedEvent) + { + switch (this.currentState) + { + case State.ReadFixedLengthContent: + case State.ReadVariableLengthContent: + case State.ReadChunkSize: + this.Reset(); + break; + } + } + base.UserEventTriggered(context, evt); + } + + protected virtual bool IsContentAlwaysEmpty(IHttpMessage msg) + { + if (msg is IHttpResponse res) + { + int code = res.Status.Code; + + // Correctly handle return codes of 1xx. + // + // See: + // - http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html Section 4.4 + // - https://github.com/netty/netty/issues/222 + if (code >= 100 && code < 200) + { + // One exception: Hixie 76 websocket handshake response + return !(code == 101 && !res.Headers.Contains(HttpHeaderNames.SecWebsocketAccept) + && res.Headers.Contains(HttpHeaderNames.Upgrade, HttpHeaderValues.Websocket, true)); + } + switch (code) + { + case 204: + case 304: + return true; + } + } + return false; + } + + protected bool IsSwitchingToNonHttp1Protocol(IHttpResponse msg) + { + if (msg.Status.Code != HttpResponseStatus.SwitchingProtocols.Code) + { + return false; + } + + return !msg.Headers.TryGet(HttpHeaderNames.Upgrade, out ICharSequence newProtocol) + || !AsciiString.Contains(newProtocol, HttpVersion.Http10String) + && !AsciiString.Contains(newProtocol, HttpVersion.Http11String); + } + + // Resets the state of the decoder so that it is ready to decode a new message. + // This method is useful for handling a rejected request with {@code Expect: 100-continue} header. + public void Reset() => this.resetRequested = true; + + void ResetNow() + { + IHttpMessage msg = this.message; + this.message = null; + this.name = null; + this.value = null; + this.contentLength = long.MinValue; + this.lineParser.Reset(); + this.headerParser.Reset(); + this.trailer = null; + if (!this.IsDecodingRequest()) + { + if (msg is IHttpResponse res && this.IsSwitchingToNonHttp1Protocol(res)) + { + this.currentState = State.Upgraded; + return; + } + } + + this.resetRequested = false; + this.currentState = State.SkipControlChars; + } + + IHttpMessage InvalidMessage(IByteBuffer buf, Exception cause) + { + this.currentState = State.BadMessage; + + // Advance the readerIndex so that ByteToMessageDecoder does not complain + // when we produced an invalid message without consuming anything. + buf.SkipBytes(buf.ReadableBytes); + + if (this.message != null) + { + this.message.Result = DecoderResult.Failure(cause); + } + else + { + this.message = this.CreateInvalidMessage(); + this.message.Result = DecoderResult.Failure(cause); + } + + IHttpMessage ret = this.message; + this.message = null; + return ret; + } + + IHttpContent InvalidChunk(IByteBuffer buf, Exception cause) + { + this.currentState = State.BadMessage; + + // Advance the readerIndex so that ByteToMessageDecoder does not complain + // when we produced an invalid message without consuming anything. + buf.SkipBytes(buf.ReadableBytes); + + IHttpContent chunk = new DefaultLastHttpContent(Unpooled.Empty); + chunk.Result = DecoderResult.Failure(cause); + this.message = null; + this.trailer = null; + return chunk; + } + + static bool SkipControlCharacters(IByteBuffer buffer) + { + bool skiped = false; + int wIdx = buffer.WriterIndex; + int rIdx = buffer.ReaderIndex; + while (wIdx > rIdx) + { + byte c = buffer.GetByte(rIdx++); + if (!CharUtil.IsISOControl(c) && !IsWhiteSpace(c)) + { + rIdx--; + skiped = true; + break; + } + } + buffer.SetReaderIndex(rIdx); + return skiped; + } + + State? ReadHeaders(IByteBuffer buffer) + { + IHttpMessage httpMessage = this.message; + HttpHeaders headers = httpMessage.Headers; + + AppendableCharSequence line = this.headerParser.Parse(buffer); + if (line == null) + { + return null; + } + // ReSharper disable once ConvertIfDoToWhile + if (line.Count > 0) + { + do + { + byte firstChar = line.Bytes[0]; + if (this.name != null && (firstChar == ' ' || firstChar == '\t')) + { + ICharSequence trimmedLine = CharUtil.Trim(line); + this.value = new AsciiString($"{this.value} {trimmedLine}"); + } + else + { + if (this.name != null) + { + headers.Add(this.name, this.value); + } + this.SplitHeader(line); + } + + line = this.headerParser.Parse(buffer); + if (line == null) + { + return null; + } + } while (line.Count > 0); + } + + // Add the last header. + if (this.name != null) + { + headers.Add(this.name, this.value); + } + // reset name and value fields + this.name = null; + this.value = null; + + State nextState; + + if (this.IsContentAlwaysEmpty(httpMessage)) + { + HttpUtil.SetTransferEncodingChunked(httpMessage, false); + nextState = State.SkipControlChars; + } + else if (HttpUtil.IsTransferEncodingChunked(httpMessage)) + { + nextState = State.ReadChunkSize; + } + else if (this.ContentLength() >= 0) + { + nextState = State.ReadFixedLengthContent; + } + else + { + nextState = State.ReadVariableLengthContent; + } + return nextState; + } + + long ContentLength() + { + if (this.contentLength == long.MinValue) + { + this.contentLength = HttpUtil.GetContentLength(this.message, -1L); + } + return this.contentLength; + } + + ILastHttpContent ReadTrailingHeaders(IByteBuffer buffer) + { + AppendableCharSequence line = this.headerParser.Parse(buffer); + if (line == null) + { + return null; + } + AsciiString lastHeader = null; + if (line.Count > 0) + { + ILastHttpContent trailingHeaders = this.trailer; + if (trailingHeaders == null) + { + trailingHeaders = new DefaultLastHttpContent(Unpooled.Empty, this.ValidateHeaders); + this.trailer = trailingHeaders; + } + do + { + byte firstChar = line.Bytes[0]; + if (lastHeader != null && (firstChar == ' ' || firstChar == '\t')) + { + IList current = trailingHeaders.TrailingHeaders.GetAll(lastHeader); + if (current.Count > 0) + { + int lastPos = current.Count - 1; + ICharSequence lineTrimmed = CharUtil.Trim(line); + current[lastPos] = new AsciiString($"{current[lastPos]} {lineTrimmed}"); + } + } + else + { + this.SplitHeader(line); + AsciiString headerName = this.name; + if (!HttpHeaderNames.ContentLength.ContentEqualsIgnoreCase(headerName) + && !HttpHeaderNames.TransferEncoding.ContentEqualsIgnoreCase(headerName) + && !HttpHeaderNames.Trailer.ContentEqualsIgnoreCase(headerName)) + { + trailingHeaders.TrailingHeaders.Add(headerName, this.value); + } + lastHeader = this.name; + // reset name and value fields + this.name = null; + this.value = null; + } + + line = this.headerParser.Parse(buffer); + if (line == null) + { + return null; + } + } while (line.Count > 0); + + this.trailer = null; + return trailingHeaders; + } + + return EmptyLastHttpContent.Default; + } + + protected abstract bool IsDecodingRequest(); + + protected abstract IHttpMessage CreateMessage(AsciiString[] initialLine); + + protected abstract IHttpMessage CreateInvalidMessage(); + + static int GetChunkSize(AsciiString hex) + { + hex = hex.Trim(); + for (int i = hex.Offset; i < hex.Count; i++) + { + byte c = hex.Array[i]; + if (c == ';' || IsWhiteSpace(c) || CharUtil.IsISOControl(c)) + { + hex = (AsciiString)hex.SubSequence(0, i); + break; + } + } + + return hex.ParseInt(16); + } + + static AsciiString[] SplitInitialLine(AppendableCharSequence sb) + { + byte[] chars = sb.Bytes; + int length = sb.Count; + + int aStart = FindNonWhitespace(chars, 0, length); + int aEnd = FindWhitespace(chars, aStart, length); + + int bStart = FindNonWhitespace(chars, aEnd, length); + int bEnd = FindWhitespace(chars, bStart, length); + + int cStart = FindNonWhitespace(chars, bEnd, length); + int cEnd = FindEndOfString(chars, length); + + return new[] + { + sb.SubStringUnsafe(aStart, aEnd), + sb.SubStringUnsafe(bStart, bEnd), + cStart < cEnd ? sb.SubStringUnsafe(cStart, cEnd) : AsciiString.Empty + }; + } + + void SplitHeader(AppendableCharSequence sb) + { + byte[] chars = sb.Bytes; + int length = sb.Count; + int nameEnd; + int colonEnd; + + int nameStart = FindNonWhitespace(chars, 0, length); + for (nameEnd = nameStart; nameEnd < length; nameEnd++) + { + byte ch = chars[nameEnd]; + if (ch == ':' || IsWhiteSpace(ch)) + { + break; + } + } + + for (colonEnd = nameEnd; colonEnd < length; colonEnd++) + { + if (chars[colonEnd] == ':') + { + colonEnd++; + break; + } + } + + this.name = sb.SubStringUnsafe(nameStart, nameEnd); + int valueStart = FindNonWhitespace(chars, colonEnd, length); + if (valueStart == length) + { + this.value = AsciiString.Empty; + } + else + { + int valueEnd = FindEndOfString(chars, length); + this.value = sb.SubStringUnsafe(valueStart, valueEnd); + } + } + + static int FindNonWhitespace(byte[] sb, int offset, int length) + { + for (int result = offset; result < length; ++result) + { + if (!IsWhiteSpace(sb[result])) + { + return result; + } + } + return length; + } + + static int FindWhitespace(byte[] sb, int offset, int length) + { + for (int result = offset; result < length; ++result) + { + if (IsWhiteSpace(sb[result])) + { + return result; + } + } + return length; + } + + static int FindEndOfString(byte[] sb, int length) + { + for (int result = length - 1; result > 0; --result) + { + if (!IsWhiteSpace(sb[result])) + { + return result + 1; + } + } + return 0; + } + + class HeaderParser : IByteProcessor + { + readonly AppendableCharSequence seq; + readonly int maxLength; + int size; + + internal HeaderParser(AppendableCharSequence seq, int maxLength) + { + this.seq = seq; + this.maxLength = maxLength; + } + + public virtual AppendableCharSequence Parse(IByteBuffer buffer) + { + int oldSize = this.size; + this.seq.Reset(); + int i = buffer.ForEachByte(this); + if (i == -1) + { + this.size = oldSize; + return null; + } + buffer.SetReaderIndex(i + 1); + return this.seq; + } + + public void Reset() => this.size = 0; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool Process(byte value) + { + if (value == HttpConstants.CarriageReturn) + { + return true; + } + if (value == HttpConstants.LineFeed) + { + return false; + } + + if (++this.size > this.maxLength) + { + // TODO: Respond with Bad Request and discard the traffic + // or close the connection. + // No need to notify the upstream handlers - just log. + // If decoding a response, just throw an exception. + ThrowTooLongFrameException(this, this.maxLength); + } + + this.seq.Append(value); + return true; + } + + static void ThrowTooLongFrameException(HeaderParser parser, int length) + { + throw GetTooLongFrameException(); + + TooLongFrameException GetTooLongFrameException() + { + return new TooLongFrameException(parser.NewExceptionMessage(length)); + } + } + + protected virtual string NewExceptionMessage(int length) => $"HTTP header is larger than {length} bytes."; + } + + sealed class LineParser : HeaderParser + { + internal LineParser(AppendableCharSequence seq, int maxLength) + : base(seq, maxLength) + { + } + + public override AppendableCharSequence Parse(IByteBuffer buffer) + { + this.Reset(); + return base.Parse(buffer); + } + + protected override string NewExceptionMessage(int maxLength) => $"An HTTP line is larger than {maxLength} bytes."; + } + + // Similar to char.IsWhiteSpace for ascii + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static bool IsWhiteSpace(byte c) + { + switch (c) + { + case HttpConstants.HorizontalSpace: + case HttpConstants.HorizontalTab: + case HttpConstants.CarriageReturn: + return true; + } + return false; + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpObjectEncoder.cs b/src/DotNetty.Codecs.Http/HttpObjectEncoder.cs new file mode 100644 index 0000000..1dfd0b0 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpObjectEncoder.cs @@ -0,0 +1,252 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using System.Collections.Generic; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public abstract class HttpObjectEncoder : MessageToMessageEncoder where T : IHttpMessage + { + const float HeadersWeightNew = 1 / 5f; + const float HeadersWeightHistorical = 1 - HeadersWeightNew; + const float TrailersWeightNew = HeadersWeightNew; + const float TrailersWeightHistorical = HeadersWeightHistorical; + + const int StInit = 0; + const int StContentNonChunk = 1; + const int StContentChunk = 2; + const int StContentAlwaysEmpty = 3; + + int state = StInit; + + // Used to calculate an exponential moving average of the encoded size of the initial line and the headers for + // a guess for future buffer allocations. + float headersEncodedSizeAccumulator = 256; + + // Used to calculate an exponential moving average of the encoded size of the trailers for + // a guess for future buffer allocations. + float trailersEncodedSizeAccumulator = 256; + + protected override void Encode(IChannelHandlerContext context, object message, List output) + { + IByteBuffer buf = null; + if (message is IHttpMessage) + { + if (this.state != StInit) + { + throw new InvalidOperationException($"unexpected message type: {StringUtil.SimpleClassName(message)}"); + } + + var m = (T)message; + + buf = context.Allocator.Buffer((int)this.headersEncodedSizeAccumulator); + // Encode the message. + this.EncodeInitialLine(buf, m); + this.state = this.IsContentAlwaysEmpty(m) ? StContentAlwaysEmpty + : HttpUtil.IsTransferEncodingChunked(m) ? StContentChunk : StContentNonChunk; + + this.SanitizeHeadersBeforeEncode(m, this.state == StContentAlwaysEmpty); + + this.EncodeHeaders(m.Headers, buf); + buf.WriteShort(HttpConstants.CrlfShort); + + this.headersEncodedSizeAccumulator = HeadersWeightNew * PadSizeForAccumulation(buf.ReadableBytes) + + HeadersWeightHistorical * this.headersEncodedSizeAccumulator; + } + + // Bypass the encoder in case of an empty buffer, so that the following idiom works: + // + // ch.write(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); + // + // See https://github.com/netty/netty/issues/2983 for more information. + if (message is IByteBuffer potentialEmptyBuf) + { + if (!potentialEmptyBuf.IsReadable()) + { + output.Add(potentialEmptyBuf.Retain()); + return; + } + } + + if (message is IHttpContent || message is IByteBuffer || message is IFileRegion) + { + switch (this.state) + { + case StInit: + throw new InvalidOperationException($"unexpected message type: {StringUtil.SimpleClassName(message)}"); + case StContentNonChunk: + long contentLength = ContentLength(message); + if (contentLength > 0) + { + if (buf != null && buf.WritableBytes >= contentLength && message is IHttpContent) + { + // merge into other buffer for performance reasons + buf.WriteBytes(((IHttpContent)message).Content); + output.Add(buf); + } + else + { + if (buf != null) + { + output.Add(buf); + } + output.Add(EncodeAndRetain(message)); + } + + if (message is ILastHttpContent) + { + this.state = StInit; + } + break; + } + + goto case StContentAlwaysEmpty; // fall-through! + case StContentAlwaysEmpty: + // ReSharper disable once ConvertIfStatementToNullCoalescingExpression + if (buf != null) + { + // We allocated a buffer so add it now. + output.Add(buf); + } + else + { + // Need to produce some output otherwise an + // IllegalStateException will be thrown as we did not write anything + // Its ok to just write an EMPTY_BUFFER as if there are reference count issues these will be + // propagated as the caller of the encode(...) method will release the original + // buffer. + // Writing an empty buffer will not actually write anything on the wire, so if there is a user + // error with msg it will not be visible externally + output.Add(Unpooled.Empty); + } + + break; + case StContentChunk: + if (buf != null) + { + // We allocated a buffer so add it now. + output.Add(buf); + } + this.EncodeChunkedContent(context, message, ContentLength(message), output); + + break; + default: + throw new EncoderException($"unexpected state {this.state}: {StringUtil.SimpleClassName(message)}"); + } + + if (message is ILastHttpContent) + { + this.state = StInit; + } + } + else if (buf != null) + { + output.Add(buf); + } + } + + protected void EncodeHeaders(HttpHeaders headers, IByteBuffer buf) + { + foreach (HeaderEntry header in headers) + { + HttpHeadersEncoder.EncoderHeader(header.Key, header.Value, buf); + } + } + + void EncodeChunkedContent(IChannelHandlerContext context, object message, long contentLength, ICollection output) + { + if (contentLength > 0) + { + var lengthHex = new AsciiString(Convert.ToString(contentLength, 16), Encoding.ASCII); + IByteBuffer buf = context.Allocator.Buffer(lengthHex.Count + 2); + buf.WriteCharSequence(lengthHex, Encoding.ASCII); + buf.WriteShort(HttpConstants.CrlfShort); + output.Add(buf); + output.Add(EncodeAndRetain(message)); + output.Add(HttpConstants.CrlfBuf.Duplicate()); + } + + if (message is ILastHttpContent content) + { + HttpHeaders headers = content.TrailingHeaders; + if (headers.IsEmpty) + { + output.Add(HttpConstants.ZeroCrlfCrlfBuf.Duplicate()); + } + else + { + IByteBuffer buf = context.Allocator.Buffer((int)this.trailersEncodedSizeAccumulator); + buf.WriteMedium(HttpConstants.ZeroCrlfMedium); + this.EncodeHeaders(headers, buf); + buf.WriteShort(HttpConstants.CrlfShort); + this.trailersEncodedSizeAccumulator = TrailersWeightNew * PadSizeForAccumulation(buf.ReadableBytes) + + TrailersWeightHistorical * this.trailersEncodedSizeAccumulator; + output.Add(buf); + } + } + else if (contentLength == 0) + { + // Need to produce some output otherwise an + // IllegalstateException will be thrown + output.Add(ReferenceCountUtil.Retain(message)); + } + } + + // Allows to sanitize headers of the message before encoding these. + protected virtual void SanitizeHeadersBeforeEncode(T msg, bool isAlwaysEmpty) + { + // noop + } + + protected virtual bool IsContentAlwaysEmpty(T msg) => false; + + public override bool AcceptOutboundMessage(object msg) => msg is IHttpObject || msg is IByteBuffer || msg is IFileRegion; + + static object EncodeAndRetain(object message) + { + if (message is IByteBuffer buffer) + { + return buffer.Retain(); + } + if (message is IHttpContent content) + { + return content.Content.Retain(); + } + if (message is IFileRegion region) + { + return region.Retain(); + } + throw new InvalidOperationException($"unexpected message type: {StringUtil.SimpleClassName(message)}"); + } + + static long ContentLength(object message) + { + if (message is IHttpContent content) + { + return content.Content.ReadableBytes; + } + if (message is IByteBuffer buffer) + { + return buffer.ReadableBytes; + } + if (message is IFileRegion region) + { + return region.Count; + } + throw new InvalidOperationException($"unexpected message type: {StringUtil.SimpleClassName(message)}"); + } + + // Add some additional overhead to the buffer. The rational is that it is better to slightly over allocate and waste + // some memory, rather than under allocate and require a resize/copy. + // @param readableBytes The readable bytes in the buffer. + // @return The {@code readableBytes} with some additional padding. + static int PadSizeForAccumulation(int readableBytes) => (readableBytes << 2) / 3; + + protected internal abstract void EncodeInitialLine(IByteBuffer buf, T message); + } +} diff --git a/src/DotNetty.Codecs.Http/HttpRequestDecoder.cs b/src/DotNetty.Codecs.Http/HttpRequestDecoder.cs new file mode 100644 index 0000000..c17bc58 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpRequestDecoder.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using DotNetty.Common.Utilities; + + public class HttpRequestDecoder : HttpObjectDecoder + { + public HttpRequestDecoder() + { + } + + public HttpRequestDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) + : base(maxInitialLineLength, maxHeaderSize, maxChunkSize, true) + { + } + + public HttpRequestDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool validateHeaders) + : base(maxInitialLineLength, maxHeaderSize, maxChunkSize, true, validateHeaders) + { + } + + public HttpRequestDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool validateHeaders, + int initialBufferSize) + : base(maxInitialLineLength, maxHeaderSize, maxChunkSize, true, validateHeaders, initialBufferSize) + { + } + + protected sealed override IHttpMessage CreateMessage(AsciiString[] initialLine) => + new DefaultHttpRequest( + HttpVersion.ValueOf(initialLine[2]), + HttpMethod.ValueOf(initialLine[0]), initialLine[1].ToString(), this.ValidateHeaders); + + protected override IHttpMessage CreateInvalidMessage() => new DefaultFullHttpRequest(HttpVersion.Http10, HttpMethod.Get, "/bad-request", this.ValidateHeaders); + + protected override bool IsDecodingRequest() => true; + } +} diff --git a/src/DotNetty.Codecs.Http/HttpRequestEncoder.cs b/src/DotNetty.Codecs.Http/HttpRequestEncoder.cs new file mode 100644 index 0000000..ed0a596 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpRequestEncoder.cs @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + + using static HttpConstants; + + public class HttpRequestEncoder : HttpObjectEncoder + { + const char Slash = '/'; + const char QuestionMark = '?'; + const int SlashAndSpaceShort = (Slash << 8) | HorizontalSpace; + const int SpaceSlashAndSpaceMedium = (HorizontalSpace << 16) | SlashAndSpaceShort; + + public override bool AcceptOutboundMessage(object msg) => base.AcceptOutboundMessage(msg) && !(msg is IHttpResponse); + + protected internal override void EncodeInitialLine(IByteBuffer buf, IHttpRequest request) + { + ByteBufferUtil.Copy(request.Method.AsciiName, buf); + + string uri = request.Uri; + + if (string.IsNullOrEmpty(uri)) + { + // Add / as absolute path if no is present. + // See http://tools.ietf.org/html/rfc2616#section-5.1.2 + buf.WriteMedium(SpaceSlashAndSpaceMedium); + } + else + { + var uriCharSequence = new StringBuilderCharSequence(); + uriCharSequence.Append(uri); + + bool needSlash = false; + int start = uri.IndexOf("://", StringComparison.Ordinal); + if (start != -1 && uri[0] != Slash) + { + start += 3; + // Correctly handle query params. + // See https://github.com/netty/netty/issues/2732 + int index = uri.IndexOf(QuestionMark, start); + if (index == -1) + { + if (uri.LastIndexOf(Slash) < start) + { + needSlash = true; + } + } + else + { + if (uri.LastIndexOf(Slash, index) < start) + { + uriCharSequence.Insert(index, Slash); + } + } + } + + buf.WriteByte(HorizontalSpace).WriteCharSequence(uriCharSequence, Encoding.UTF8); + if (needSlash) + { + // write "/ " after uri + buf.WriteShort(SlashAndSpaceShort); + } + else + { + buf.WriteByte(HorizontalSpace); + } + } + + request.ProtocolVersion.Encode(buf); + buf.WriteShort(CrlfShort); + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpResponseDecoder.cs b/src/DotNetty.Codecs.Http/HttpResponseDecoder.cs new file mode 100644 index 0000000..e7ff8e7 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpResponseDecoder.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using DotNetty.Common.Utilities; + + public class HttpResponseDecoder : HttpObjectDecoder + { + static readonly HttpResponseStatus UnknownStatus = new HttpResponseStatus(999, new AsciiString("Unknown")); + + public HttpResponseDecoder() + { + } + + public HttpResponseDecoder(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) + : base(maxInitialLineLength, maxHeaderSize, maxChunkSize, true) + { + } + + public HttpResponseDecoder(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool validateHeaders) + : base(maxInitialLineLength, maxHeaderSize, maxChunkSize, true, validateHeaders) + { + } + + public HttpResponseDecoder(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool validateHeaders, int initialBufferSize) + : base(maxInitialLineLength, maxHeaderSize, maxChunkSize, true, validateHeaders, initialBufferSize) + { + } + + protected sealed override IHttpMessage CreateMessage(AsciiString[] initialLine) => + new DefaultHttpResponse( + HttpVersion.ValueOf(initialLine[0]), + HttpResponseStatus.ValueOf(initialLine[1].ParseInt() , initialLine[2]), this.ValidateHeaders); + + protected override IHttpMessage CreateInvalidMessage() => new DefaultFullHttpResponse(HttpVersion.Http10, UnknownStatus, this.ValidateHeaders); + + protected override bool IsDecodingRequest() => false; + } +} diff --git a/src/DotNetty.Codecs.Http/HttpResponseEncoder.cs b/src/DotNetty.Codecs.Http/HttpResponseEncoder.cs new file mode 100644 index 0000000..19572de --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpResponseEncoder.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using DotNetty.Buffers; + + public class HttpResponseEncoder : HttpObjectEncoder + { + public override bool AcceptOutboundMessage(object msg) => base.AcceptOutboundMessage(msg) && !(msg is IHttpRequest); + + protected internal override void EncodeInitialLine(IByteBuffer buf, IHttpResponse response) + { + response.ProtocolVersion.Encode(buf); + buf.WriteByte(HttpConstants.HorizontalSpace); + response.Status.Encode(buf); + buf.WriteShort(HttpConstants.CrlfShort); + } + + protected override void SanitizeHeadersBeforeEncode(IHttpResponse msg, bool isAlwaysEmpty) + { + if (isAlwaysEmpty) + { + HttpResponseStatus status = msg.Status; + if (status.CodeClass == HttpStatusClass.Informational + || status.Code == HttpResponseStatus.NoContent.Code) + { + + // Stripping Content-Length: + // See https://tools.ietf.org/html/rfc7230#section-3.3.2 + msg.Headers.Remove(HttpHeaderNames.ContentLength); + + // Stripping Transfer-Encoding: + // See https://tools.ietf.org/html/rfc7230#section-3.3.1 + msg.Headers.Remove(HttpHeaderNames.TransferEncoding); + } + } + } + + protected override bool IsContentAlwaysEmpty(IHttpResponse msg) + { + // Correctly handle special cases as stated in: + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + HttpResponseStatus status = msg.Status; + + if (status.CodeClass == HttpStatusClass.Informational) + { + if (status.Code == HttpResponseStatus.SwitchingProtocols.Code) + { + // We need special handling for WebSockets version 00 as it will include an body. + // Fortunally this version should not really be used in the wild very often. + // See https://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00#section-1.2 + return msg.Headers.Contains(HttpHeaderNames.SecWebsocketVersion); + } + return true; + } + return status.Code == HttpResponseStatus.NoContent.Code + || status.Code == HttpResponseStatus.NotModified.Code; + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpResponseStatus.cs b/src/DotNetty.Codecs.Http/HttpResponseStatus.cs new file mode 100644 index 0000000..45d0f1d --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpResponseStatus.cs @@ -0,0 +1,560 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using System; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + + /** + * The response code and its description of HTTP or its derived protocols, such as + * RTSP and + * ICAP. + */ + public class HttpResponseStatus : IComparable + { + /** + * 100 Continue + */ + public static readonly HttpResponseStatus Continue = NewStatus(100, "Continue"); + + /** + * 101 Switching Protocols + */ + public static readonly HttpResponseStatus SwitchingProtocols = NewStatus(101, "Switching Protocols"); + + /** + * 102 Processing (WebDAV, RFC2518) + */ + public static readonly HttpResponseStatus Processing = NewStatus(102, "Processing"); + + /** + * 200 OK + */ + public static readonly HttpResponseStatus OK = NewStatus(200, "OK"); + + /** + * 201 Created + */ + public static readonly HttpResponseStatus Created = NewStatus(201, "Created"); + + /** + * 202 Accepted + */ + public static readonly HttpResponseStatus Accepted = NewStatus(202, "Accepted"); + + /** + * 203 Non-Authoritative Information (since HTTP/1.1) + */ + public static readonly HttpResponseStatus NonAuthoritativeInformation = NewStatus(203, "Non-Authoritative Information"); + + /** + * 204 No Content + */ + public static readonly HttpResponseStatus NoContent = NewStatus(204, "No Content"); + + /** + * 205 Reset Content + */ + public static readonly HttpResponseStatus ResetContent = NewStatus(205, "Reset Content"); + + /** + * 206 Partial Content + */ + public static readonly HttpResponseStatus PartialContent = NewStatus(206, "Partial Content"); + + /** + * 207 Multi-Status (WebDAV, RFC2518) + */ + public static readonly HttpResponseStatus MultiStatus = NewStatus(207, "Multi-Status"); + + /** + * 300 Multiple Choices + */ + public static readonly HttpResponseStatus MultipleChoices = NewStatus(300, "Multiple Choices"); + + /** + * 301 Moved Permanently + */ + public static readonly HttpResponseStatus MovedPermanently = NewStatus(301, "Moved Permanently"); + + /** + * 302 Found + */ + public static readonly HttpResponseStatus Found = NewStatus(302, "Found"); + + /** + * 303 See Other (since HTTP/1.1) + */ + public static readonly HttpResponseStatus SeeOther = NewStatus(303, "See Other"); + + /** + * 304 Not Modified + */ + public static readonly HttpResponseStatus NotModified = NewStatus(304, "Not Modified"); + + /** + * 305 Use Proxy (since HTTP/1.1) + */ + public static readonly HttpResponseStatus UseProxy = NewStatus(305, "Use Proxy"); + + /** + * 307 Temporary Redirect (since HTTP/1.1) + */ + public static readonly HttpResponseStatus TemporaryRedirect = NewStatus(307, "Temporary Redirect"); + + /** + * 308 Permanent Redirect (RFC7538) + */ + public static readonly HttpResponseStatus PermanentRedirect = NewStatus(308, "Permanent Redirect"); + + /** + * 400 Bad Request + */ + public static readonly HttpResponseStatus BadRequest = NewStatus(400, "Bad Request"); + + /** + * 401 Unauthorized + */ + public static readonly HttpResponseStatus Unauthorized = NewStatus(401, "Unauthorized"); + + /** + * 402 Payment Required + */ + public static readonly HttpResponseStatus PaymentRequired = NewStatus(402, "Payment Required"); + + /** + * 403 Forbidden + */ + public static readonly HttpResponseStatus Forbidden = NewStatus(403, "Forbidden"); + + /** + * 404 Not Found + */ + public static readonly HttpResponseStatus NotFound = NewStatus(404, "Not Found"); + + /** + * 405 Method Not Allowed + */ + public static readonly HttpResponseStatus MethodNotAllowed = NewStatus(405, "Method Not Allowed"); + + /** + * 406 Not Acceptable + */ + public static readonly HttpResponseStatus NotAcceptable = NewStatus(406, "Not Acceptable"); + + /** + * 407 Proxy Authentication Required + */ + public static readonly HttpResponseStatus ProxyAuthenticationRequired = NewStatus(407, "Proxy Authentication Required"); + + /** + * 408 Request Timeout + */ + public static readonly HttpResponseStatus RequestTimeout = NewStatus(408, "Request Timeout"); + + /** + * 409 Conflict + */ + public static readonly HttpResponseStatus Conflict = NewStatus(409, "Conflict"); + + /** + * 410 Gone + */ + public static readonly HttpResponseStatus Gone = NewStatus(410, "Gone"); + + /** + * 411 Length Required + */ + public static readonly HttpResponseStatus LengthRequired = NewStatus(411, "Length Required"); + + /** + * 412 Precondition Failed + */ + public static readonly HttpResponseStatus PreconditionFailed = NewStatus(412, "Precondition Failed"); + + /** + * 413 Request Entity Too Large + */ + public static readonly HttpResponseStatus RequestEntityTooLarge = NewStatus(413, "Request Entity Too Large"); + + /** + * 414 Request-URI Too Long + */ + public static readonly HttpResponseStatus RequestUriTooLong = NewStatus(414, "Request-URI Too Long"); + + /** + * 415 Unsupported Media Type + */ + public static readonly HttpResponseStatus UnsupportedMediaType = NewStatus(415, "Unsupported Media Type"); + + /** + * 416 Requested Range Not Satisfiable + */ + public static readonly HttpResponseStatus RequestedRangeNotSatisfiable = NewStatus(416, "Requested Range Not Satisfiable"); + + /** + * 417 Expectation Failed + */ + public static readonly HttpResponseStatus ExpectationFailed = NewStatus(417, "Expectation Failed"); + + /** + * 421 Misdirected Request + * + * 421 Status Code + */ + public static readonly HttpResponseStatus MisdirectedRequest = NewStatus(421, "Misdirected Request"); + + /** + * 422 Unprocessable Entity (WebDAV, RFC4918) + */ + public static readonly HttpResponseStatus UnprocessableEntity = NewStatus(422, "Unprocessable Entity"); + + /** + * 423 Locked (WebDAV, RFC4918) + */ + public static readonly HttpResponseStatus Locked = NewStatus(423, "Locked"); + + /** + * 424 Failed Dependency (WebDAV, RFC4918) + */ + public static readonly HttpResponseStatus FailedDependency = NewStatus(424, "Failed Dependency"); + + /** + * 425 Unordered Collection (WebDAV, RFC3648) + */ + public static readonly HttpResponseStatus UnorderedCollection = NewStatus(425, "Unordered Collection"); + + /** + * 426 Upgrade Required (RFC2817) + */ + public static readonly HttpResponseStatus UpgradeRequired = NewStatus(426, "Upgrade Required"); + + /** + * 428 Precondition Required (RFC6585) + */ + public static readonly HttpResponseStatus PreconditionRequired = NewStatus(428, "Precondition Required"); + + /** + * 429 Too Many Requests (RFC6585) + */ + public static readonly HttpResponseStatus TooManyRequests = NewStatus(429, "Too Many Requests"); + + /** + * 431 Request Header Fields Too Large (RFC6585) + */ + public static readonly HttpResponseStatus RequestHeaderFieldsTooLarge = NewStatus(431, "Request Header Fields Too Large"); + + /** + * 500 Internal Server Error + */ + public static readonly HttpResponseStatus InternalServerError = NewStatus(500, "Internal Server Error"); + + /** + * 501 Not Implemented + */ + public static readonly HttpResponseStatus NotImplemented = NewStatus(501, "Not Implemented"); + + /** + * 502 Bad Gateway + */ + public static readonly HttpResponseStatus BadGateway = NewStatus(502, "Bad Gateway"); + + /** + * 503 Service Unavailable + */ + public static readonly HttpResponseStatus ServiceUnavailable = NewStatus(503, "Service Unavailable"); + + /** + * 504 Gateway Timeout + */ + public static readonly HttpResponseStatus GatewayTimeout = NewStatus(504, "Gateway Timeout"); + + /** + * 505 HTTP Version Not Supported + */ + public static readonly HttpResponseStatus HttpVersionNotSupported = NewStatus(505, "HTTP Version Not Supported"); + + /** + * 506 Variant Also Negotiates (RFC2295) + */ + public static readonly HttpResponseStatus VariantAlsoNegotiates = NewStatus(506, "Variant Also Negotiates"); + + /** + * 507 Insufficient Storage (WebDAV, RFC4918) + */ + public static readonly HttpResponseStatus InsufficientStorage = NewStatus(507, "Insufficient Storage"); + + /** + * 510 Not Extended (RFC2774) + */ + public static readonly HttpResponseStatus NotExtended = NewStatus(510, "Not Extended"); + + /** + * 511 Network Authentication Required (RFC6585) + */ + public static readonly HttpResponseStatus NetworkAuthenticationRequired = NewStatus(511, "Network Authentication Required"); + + static HttpResponseStatus NewStatus(int statusCode, string reasonPhrase) => new HttpResponseStatus(statusCode, new AsciiString(reasonPhrase), true); + + // Returns the {@link HttpResponseStatus} represented by the specified code. + // If the specified code is a standard HTTP getStatus code, a cached instance + // will be returned. Otherwise, a new instance will be returned. + public static HttpResponseStatus ValueOf(int code) => ValueOf0(code) ?? new HttpResponseStatus(code); + + static HttpResponseStatus ValueOf0(int code) + { + switch (code) + { + case 100: + return Continue; + case 101: + return SwitchingProtocols; + case 102: + return Processing; + case 200: + return OK; + case 201: + return Created; + case 202: + return Accepted; + case 203: + return NonAuthoritativeInformation; + case 204: + return NoContent; + case 205: + return ResetContent; + case 206: + return PartialContent; + case 207: + return MultiStatus; + case 300: + return MultipleChoices; + case 301: + return MovedPermanently; + case 302: + return Found; + case 303: + return SeeOther; + case 304: + return NotModified; + case 305: + return UseProxy; + case 307: + return TemporaryRedirect; + case 308: + return PermanentRedirect; + case 400: + return BadRequest; + case 401: + return Unauthorized; + case 402: + return PaymentRequired; + case 403: + return Forbidden; + case 404: + return NotFound; + case 405: + return MethodNotAllowed; + case 406: + return NotAcceptable; + case 407: + return ProxyAuthenticationRequired; + case 408: + return RequestTimeout; + case 409: + return Conflict; + case 410: + return Gone; + case 411: + return LengthRequired; + case 412: + return PreconditionFailed; + case 413: + return RequestEntityTooLarge; + case 414: + return RequestUriTooLong; + case 415: + return UnsupportedMediaType; + case 416: + return RequestedRangeNotSatisfiable; + case 417: + return ExpectationFailed; + case 421: + return MisdirectedRequest; + case 422: + return UnprocessableEntity; + case 423: + return Locked; + case 424: + return FailedDependency; + case 425: + return UnorderedCollection; + case 426: + return UpgradeRequired; + case 428: + return PreconditionRequired; + case 429: + return TooManyRequests; + case 431: + return RequestHeaderFieldsTooLarge; + case 500: + return InternalServerError; + case 501: + return NotImplemented; + case 502: + return BadGateway; + case 503: + return ServiceUnavailable; + case 504: + return GatewayTimeout; + case 505: + return HttpVersionNotSupported; + case 506: + return VariantAlsoNegotiates; + case 507: + return InsufficientStorage; + case 510: + return NotExtended; + case 511: + return NetworkAuthenticationRequired; + } + return null; + } + + public static HttpResponseStatus ValueOf(int code, AsciiString reasonPhrase) + { + HttpResponseStatus responseStatus = ValueOf0(code); + return responseStatus != null && responseStatus.ReasonPhrase.ContentEquals(reasonPhrase) + ? responseStatus + : new HttpResponseStatus(code, reasonPhrase); + } + + public static HttpResponseStatus ParseLine(ICharSequence line) => line is AsciiString asciiString ? ParseLine(asciiString) : ParseLine(line.ToString()); + + public static HttpResponseStatus ParseLine(string line) + { + try + { + int space = line.IndexOf(' '); + return space == -1 + ? ValueOf(int.Parse(line)) + : ValueOf(int.Parse(line.Substring(0, space)), new AsciiString(line.Substring(space + 1))); + } + catch (Exception e) + { + throw new ArgumentException($"malformed status line: {line}", e); + } + } + + public static HttpResponseStatus ParseLine(AsciiString line) + { + try + { + int space = line.ForEachByte(ByteProcessor.FindAsciiSpace); + return space == -1 + ? ValueOf(line.ParseInt()) + : ValueOf(line.ParseInt(0, space), (AsciiString)line.SubSequence(space + 1)); + } + catch (Exception e) + { + throw new ArgumentException($"malformed status line: {line}", e); + } + } + + + readonly int code; + readonly AsciiString codeAsText; + readonly HttpStatusClass codeClass; + + readonly AsciiString reasonPhrase; + readonly byte[] bytes; + + HttpResponseStatus(int code) + : this(code, new AsciiString($"{HttpStatusClass.ValueOf(code).DefaultReasonPhrase} ({code})"), false) + { + } + + public HttpResponseStatus(int code, AsciiString reasonPhrase) + : this(code, reasonPhrase, false) + { + } + + public HttpResponseStatus(int code, AsciiString reasonPhrase, bool bytes) + { + if (code < 0) + { + throw new ArgumentException($"code: {code} (expected: 0+)"); + } + if (reasonPhrase == null) + { + throw new ArgumentException(nameof(reasonPhrase)); + } + + // ReSharper disable once ForCanBeConvertedToForeach + for (int i = 0; i < reasonPhrase.Count; i++) + { + char c = reasonPhrase[i]; + // Check prohibited characters. + switch (c) + { + case '\n': + case '\r': + throw new ArgumentException($"reasonPhrase contains one of the following prohibited characters: \\r\\n: {reasonPhrase}"); + } + } + + this.code = code; + this.codeAsText = new AsciiString(Convert.ToString(code)); + this.reasonPhrase = reasonPhrase; + this.bytes = bytes ? Encoding.ASCII.GetBytes($"{code} {reasonPhrase}") : null; + this.codeClass = HttpStatusClass.ValueOf(code); + } + + public int Code => this.code; + + public AsciiString CodeAsText => this.codeAsText; + + public AsciiString ReasonPhrase => this.reasonPhrase; + + public HttpStatusClass CodeClass => this.codeClass; + + public override int GetHashCode() => this.code; + + public override bool Equals(object obj) + { + if (!(obj is HttpResponseStatus other)) + { + return false; + } + return this.code == other.code; + } + + public int CompareTo(HttpResponseStatus other) => this.code - other.code; + + public override string ToString() => + new StringBuilder(this.ReasonPhrase.Count + 4) + .Append(this.Code) + .Append(' ') + .Append(this.ReasonPhrase) + .ToString(); + + internal void Encode(IByteBuffer buf) + { + if (this.bytes == null) + { + ByteBufferUtil.Copy(this.codeAsText, buf); + buf.WriteByte(HttpConstants.HorizontalSpace); + buf.WriteCharSequence(this.reasonPhrase, Encoding.ASCII); + } + else + { + buf.WriteBytes(this.bytes); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpScheme.cs b/src/DotNetty.Codecs.Http/HttpScheme.cs new file mode 100644 index 0000000..8d8548a --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpScheme.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using DotNetty.Common.Utilities; + + public sealed class HttpScheme + { + readonly int port; + readonly AsciiString name; + + HttpScheme(int port, string name) + { + this.port = port; + this.name = AsciiString.Cached(name); + } + + public AsciiString Name => this.name; + + public int Port => this.port; + + public override bool Equals(object obj) + { + if (!(obj is HttpScheme other)) + { + return false; + } + + return other.port == this.port && other.name.Equals(this.name); + } + + public override int GetHashCode() => this.port * 31 + this.name.GetHashCode(); + + public override string ToString() => this.name.ToString(); + } +} diff --git a/src/DotNetty.Codecs.Http/HttpServerCodec.cs b/src/DotNetty.Codecs.Http/HttpServerCodec.cs new file mode 100644 index 0000000..7cae4e1 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpServerCodec.cs @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System.Collections.Generic; + using DotNetty.Buffers; + using DotNetty.Transport.Channels; + + public class HttpServerCodec : CombinedChannelDuplexHandler, + HttpServerUpgradeHandler.ISourceCodec + { + /** A queue that is used for correlating a request and a response. */ + readonly Queue queue = new Queue(); + + public HttpServerCodec() : this(4096, 8192, 8192) + { + } + + public HttpServerCodec(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) + { + this.Init(new HttpServerRequestDecoder(this, maxInitialLineLength, maxHeaderSize, maxChunkSize), + new HttpServerResponseEncoder(this)); + } + + public HttpServerCodec(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool validateHeaders) + { + this.Init(new HttpServerRequestDecoder(this, maxInitialLineLength, maxHeaderSize, maxChunkSize, validateHeaders), + new HttpServerResponseEncoder(this)); + } + + public HttpServerCodec(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool validateHeaders, int initialBufferSize) + { + this.Init(new HttpServerRequestDecoder(this, maxInitialLineLength, maxHeaderSize, maxChunkSize, validateHeaders, initialBufferSize), + new HttpServerResponseEncoder(this)); + } + + public void UpgradeFrom(IChannelHandlerContext ctx) => ctx.Channel.Pipeline.Remove(this); + + sealed class HttpServerRequestDecoder : HttpRequestDecoder + { + readonly HttpServerCodec serverCodec; + + public HttpServerRequestDecoder(HttpServerCodec serverCodec, int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) + : base(maxInitialLineLength, maxHeaderSize, maxChunkSize) + { + this.serverCodec = serverCodec; + } + + public HttpServerRequestDecoder(HttpServerCodec serverCodec, int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool validateHeaders) + :base(maxInitialLineLength, maxHeaderSize, maxChunkSize, validateHeaders) + { + this.serverCodec = serverCodec; + } + + public HttpServerRequestDecoder(HttpServerCodec serverCodec, + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, bool validateHeaders, int initialBufferSize) + : base(maxInitialLineLength, maxHeaderSize, maxChunkSize, validateHeaders, initialBufferSize) + { + this.serverCodec = serverCodec; + } + + protected override void Decode(IChannelHandlerContext context, IByteBuffer buffer, List output) + { + int oldSize = output.Count; + base.Decode(context, buffer, output); + int size = output.Count; + for (int i = oldSize; i < size; i++) + { + if (output[i] is IHttpRequest request) + { + this.serverCodec.queue.Enqueue(request.Method); + } + } + } + } + + sealed class HttpServerResponseEncoder : HttpResponseEncoder + { + readonly HttpServerCodec serverCodec; + HttpMethod method; + + public HttpServerResponseEncoder(HttpServerCodec serverCodec) + { + this.serverCodec = serverCodec; + } + + protected override void SanitizeHeadersBeforeEncode(IHttpResponse msg, bool isAlwaysEmpty) + { + if (!isAlwaysEmpty && ReferenceEquals(this.method, HttpMethod.Connect) && msg.Status.CodeClass == HttpStatusClass.Success) + { + // Stripping Transfer-Encoding: + // See https://tools.ietf.org/html/rfc7230#section-3.3.1 + msg.Headers.Remove(HttpHeaderNames.TransferEncoding); + return; + } + + base.SanitizeHeadersBeforeEncode(msg, isAlwaysEmpty); + } + + + protected override bool IsContentAlwaysEmpty(IHttpResponse msg) + { + this.method = this.serverCodec.queue.Count > 0 ? this.serverCodec.queue.Dequeue() : null; + return HttpMethod.Head.Equals(this.method) || base.IsContentAlwaysEmpty(msg); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpServerExpectContinueHandler.cs b/src/DotNetty.Codecs.Http/HttpServerExpectContinueHandler.cs new file mode 100644 index 0000000..de4e417 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpServerExpectContinueHandler.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class HttpServerExpectContinueHandler : ChannelHandlerAdapter + { + static readonly IFullHttpResponse ExpectationFailed = new DefaultFullHttpResponse( + HttpVersion.Http11, HttpResponseStatus.ExpectationFailed, Unpooled.Empty); + + static readonly IFullHttpResponse Accept = new DefaultFullHttpResponse( + HttpVersion.Http11, HttpResponseStatus.Continue, Unpooled.Empty); + + static HttpServerExpectContinueHandler() + { + ExpectationFailed.Headers.Set(HttpHeaderNames.ContentLength, HttpHeaderValues.Zero); + Accept.Headers.Set(HttpHeaderNames.ContentLength, HttpHeaderValues.Zero); + } + + protected virtual IHttpResponse AcceptMessage(IHttpRequest request) => (IHttpResponse)Accept.RetainedDuplicate(); + + protected virtual IHttpResponse RejectResponse(IHttpRequest request) => (IHttpResponse)ExpectationFailed.RetainedDuplicate(); + + public override void ChannelRead(IChannelHandlerContext context, object message) + { + if (message is IHttpRequest req) + { + if (HttpUtil.Is100ContinueExpected(req)) + { + IHttpResponse accept = this.AcceptMessage(req); + + if (accept == null) + { + // the expectation failed so we refuse the request. + IHttpResponse rejection = this.RejectResponse(req); + ReferenceCountUtil.Release(message); + context.WriteAndFlushAsync(rejection) + .ContinueWith(CloseOnFailure, context, TaskContinuationOptions.ExecuteSynchronously); + return; + } + + context.WriteAndFlushAsync(accept) + .ContinueWith(CloseOnFailure, context, TaskContinuationOptions.ExecuteSynchronously); + req.Headers.Remove(HttpHeaderNames.Expect); + } + base.ChannelRead(context, message); + } + } + + static Task CloseOnFailure(Task task, object state) + { + if (task.IsFaulted) + { + var context = (IChannelHandlerContext)state; + return context.CloseAsync(); + } + return TaskEx.Completed; + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpServerKeepAliveHandler.cs b/src/DotNetty.Codecs.Http/HttpServerKeepAliveHandler.cs new file mode 100644 index 0000000..981ff20 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpServerKeepAliveHandler.cs @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System.Threading.Tasks; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + using static HttpUtil; + + public class HttpServerKeepAliveHandler : ChannelDuplexHandler + { + static readonly AsciiString MultipartPrefix = new AsciiString("multipart"); + + bool persistentConnection = true; + // Track pending responses to support client pipelining: https://tools.ietf.org/html/rfc7230#section-6.3.2 + int pendingResponses; + + public override void ChannelRead(IChannelHandlerContext context, object message) + { + // read message and track if it was keepAlive + if (message is IHttpRequest request) + { + if (this.persistentConnection) + { + this.pendingResponses += 1; + this.persistentConnection = IsKeepAlive(request); + } + } + base.ChannelRead(context, message); + } + + public override Task WriteAsync(IChannelHandlerContext context, object message) + { + // modify message on way out to add headers if needed + if (message is IHttpResponse response) + { + this.TrackResponse(response); + // Assume the response writer knows if they can persist or not and sets isKeepAlive on the response + if (!IsKeepAlive(response) || !IsSelfDefinedMessageLength(response)) + { + // No longer keep alive as the client can't tell when the message is done unless we close connection + this.pendingResponses = 0; + this.persistentConnection = false; + } + // Server might think it can keep connection alive, but we should fix response header if we know better + if (!this.ShouldKeepAlive()) + { + SetKeepAlive(response, false); + } + } + if (message is ILastHttpContent && !this.ShouldKeepAlive()) + { + return base.WriteAsync(context, message) + .ContinueWith(CloseOnComplete, context, TaskContinuationOptions.ExecuteSynchronously); + } + return base.WriteAsync(context, message); + } + + static Task CloseOnComplete(Task task, object state) + { + var context = (IChannelHandlerContext)state; + return context.CloseAsync(); + } + + void TrackResponse(IHttpResponse response) + { + if (!IsInformational(response)) + { + this.pendingResponses -= 1; + } + } + + bool ShouldKeepAlive() => this.pendingResponses != 0 || this.persistentConnection; + + /// + /// Keep-alive only works if the client can detect when the message has ended without relying on the connection being + /// closed. + /// https://tools.ietf.org/html/rfc7230#section-6.3 + /// https://tools.ietf.org/html/rfc7230#section-3.3.2 + /// https://tools.ietf.org/html/rfc7230#section-3.3.3 + /// + /// The HttpResponse to check + /// true if the response has a self defined message length. + static bool IsSelfDefinedMessageLength(IHttpResponse response) => + IsContentLengthSet(response) || IsTransferEncodingChunked(response) || IsMultipart(response) + || IsInformational(response) || response.Status.Code == HttpResponseStatus.NoContent.Code; + + static bool IsInformational(IHttpResponse response) => response.Status.CodeClass == HttpStatusClass.Informational; + + static bool IsMultipart(IHttpResponse response) + { + return response.Headers.TryGet(HttpHeaderNames.ContentType, out ICharSequence contentType) + && contentType.RegionMatchesIgnoreCase(0, MultipartPrefix, 0, MultipartPrefix.Count); + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs b/src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs new file mode 100644 index 0000000..e7e4d55 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpServerUpgradeHandler.cs @@ -0,0 +1,333 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using System.Collections.Generic; + using System.Diagnostics; + using System.Diagnostics.Contracts; + using System.Text; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class HttpServerUpgradeHandler : HttpObjectAggregator + { + /// + /// The source codec that is used in the pipeline initially. + /// + public interface ISourceCodec + { + /// + /// Removes this codec (i.e. all associated handlers) from the pipeline. + /// + void UpgradeFrom(IChannelHandlerContext ctx); + } + + /// + /// A codec that the source can be upgraded to. + /// + public interface IUpgradeCodec + { + /// + /// Gets all protocol-specific headers required by this protocol for a successful upgrade. + /// Any supplied header will be required to appear in the {@link HttpHeaderNames#CONNECTION} header as well. + /// + ICollection RequiredUpgradeHeaders { get; } + + /// + /// Prepares the {@code upgradeHeaders} for a protocol update based upon the contents of {@code upgradeRequest}. + /// This method returns a boolean value to proceed or abort the upgrade in progress. If {@code false} is + /// returned, the upgrade is aborted and the {@code upgradeRequest} will be passed through the inbound pipeline + /// as if no upgrade was performed. If {@code true} is returned, the upgrade will proceed to the next + /// step which invokes {@link #upgradeTo}. When returning {@code true}, you can add headers to + /// the {@code upgradeHeaders} so that they are added to the 101 Switching protocols response. + /// + bool PrepareUpgradeResponse(IChannelHandlerContext ctx, IFullHttpRequest upgradeRequest, HttpHeaders upgradeHeaders); + + /// + /// Performs an HTTP protocol upgrade from the source codec. This method is responsible for + /// adding all handlers required for the new protocol. + /// + /// ctx the context for the current handler. + /// upgradeRequest the request that triggered the upgrade to this protocol. + /// + void UpgradeTo(IChannelHandlerContext ctx, IFullHttpRequest upgradeRequest); + } + + /// + /// Creates a new UpgradeCodec for the requested protocol name. + /// + public interface IUpgradeCodecFactory + { + /// + /// Invoked by {@link HttpServerUpgradeHandler} for all the requested protocol names in the order of + /// the client preference.The first non-{@code null} {@link UpgradeCodec} returned by this method + /// will be selected. + /// + IUpgradeCodec NewUpgradeCodec(ICharSequence protocol); + } + + public sealed class UpgradeEvent : IReferenceCounted + { + readonly ICharSequence protocol; + readonly IFullHttpRequest upgradeRequest; + + internal UpgradeEvent(ICharSequence protocol, IFullHttpRequest upgradeRequest) + { + this.protocol = protocol; + this.upgradeRequest = upgradeRequest; + } + + public ICharSequence Protocol => this.protocol; + + public IFullHttpRequest UpgradeRequest => this.upgradeRequest; + + public int ReferenceCount => this.upgradeRequest.ReferenceCount; + + public IReferenceCounted Retain() + { + this.upgradeRequest.Retain(); + return this; + } + + public IReferenceCounted Retain(int increment) + { + this.upgradeRequest.Retain(increment); + return this; + } + + public IReferenceCounted Touch() + { + this.upgradeRequest.Touch(); + return this; + } + + public IReferenceCounted Touch(object hint) + { + this.upgradeRequest.Touch(hint); + return this; + } + + public bool Release() => this.upgradeRequest.Release(); + + public bool Release(int decrement) => this.upgradeRequest.Release(decrement); + + public override string ToString() => $"UpgradeEvent [protocol={this.protocol}, upgradeRequest={this.upgradeRequest}]"; + } + + readonly ISourceCodec sourceCodec; + readonly IUpgradeCodecFactory upgradeCodecFactory; + bool handlingUpgrade; + + public HttpServerUpgradeHandler(ISourceCodec sourceCodec, IUpgradeCodecFactory upgradeCodecFactory) + : this(sourceCodec, upgradeCodecFactory, 0) + { + } + + public HttpServerUpgradeHandler(ISourceCodec sourceCodec, IUpgradeCodecFactory upgradeCodecFactory, int maxContentLength) + : base(maxContentLength) + { + Contract.Requires(sourceCodec != null); + Contract.Requires(upgradeCodecFactory != null); + + this.sourceCodec = sourceCodec; + this.upgradeCodecFactory = upgradeCodecFactory; + } + + protected override void Decode(IChannelHandlerContext context, IHttpObject message, List output) + { + // Determine if we're already handling an upgrade request or just starting a new one. + this.handlingUpgrade |= IsUpgradeRequest(message); + if (!this.handlingUpgrade) + { + // Not handling an upgrade request, just pass it to the next handler. + ReferenceCountUtil.Retain(message); + output.Add(message); + return; + } + + if (message is IFullHttpRequest fullRequest) + { + ReferenceCountUtil.Retain(fullRequest); + output.Add(fullRequest); + } + else + { + // Call the base class to handle the aggregation of the full request. + base.Decode(context, message, output); + if (output.Count == 0) + { + // The full request hasn't been created yet, still awaiting more data. + return; + } + + // Finished aggregating the full request, get it from the output list. + Debug.Assert(output.Count == 1); + this.handlingUpgrade = false; + fullRequest = (IFullHttpRequest)output[0]; + } + + if (this.Upgrade(context, fullRequest)) + { + // The upgrade was successful, remove the message from the output list + // so that it's not propagated to the next handler. This request will + // be propagated as a user event instead. + output.Clear(); + } + + // The upgrade did not succeed, just allow the full request to propagate to the + // next handler. + } + + static bool IsUpgradeRequest(IHttpObject msg) + { + if (!(msg is IHttpRequest request)) + { + return false; + } + return request.Headers.Contains(HttpHeaderNames.Upgrade); + } + + bool Upgrade(IChannelHandlerContext ctx, IFullHttpRequest request) + { + // Select the best protocol based on those requested in the UPGRADE header. + IList requestedProtocols = SplitHeader(request.Headers.Get(HttpHeaderNames.Upgrade, null)); + int numRequestedProtocols = requestedProtocols.Count; + IUpgradeCodec upgradeCodec = null; + ICharSequence upgradeProtocol = null; + for (int i = 0; i < numRequestedProtocols; i++) + { + ICharSequence p = requestedProtocols[i]; + IUpgradeCodec c = this.upgradeCodecFactory.NewUpgradeCodec(p); + if (c != null) + { + upgradeProtocol = p; + upgradeCodec = c; + break; + } + } + + if (upgradeCodec == null) + { + // None of the requested protocols are supported, don't upgrade. + return false; + } + + // Make sure the CONNECTION header is present. + ; + if (!request.Headers.TryGet(HttpHeaderNames.Connection, out ICharSequence connectionHeader)) + { + return false; + } + + // Make sure the CONNECTION header contains UPGRADE as well as all protocol-specific headers. + ICollection requiredHeaders = upgradeCodec.RequiredUpgradeHeaders; + IList values = SplitHeader(connectionHeader); + if (!AsciiString.ContainsContentEqualsIgnoreCase(values, HttpHeaderNames.Upgrade) + || !AsciiString.ContainsAllContentEqualsIgnoreCase(values, requiredHeaders)) + { + return false; + } + + // Ensure that all required protocol-specific headers are found in the request. + foreach (AsciiString requiredHeader in requiredHeaders) + { + if (!request.Headers.Contains(requiredHeader)) + { + return false; + } + } + + // Prepare and send the upgrade response. Wait for this write to complete before upgrading, + // since we need the old codec in-place to properly encode the response. + IFullHttpResponse upgradeResponse = CreateUpgradeResponse(upgradeProtocol); + if (!upgradeCodec.PrepareUpgradeResponse(ctx, request, upgradeResponse.Headers)) + { + return false; + } + + // Create the user event to be fired once the upgrade completes. + var upgradeEvent = new UpgradeEvent(upgradeProtocol, request); + + IUpgradeCodec finalUpgradeCodec = upgradeCodec; + ctx.WriteAndFlushAsync(upgradeResponse).ContinueWith(t => + { + try + { + if (t.Status == TaskStatus.RanToCompletion) + { + // Perform the upgrade to the new protocol. + this.sourceCodec.UpgradeFrom(ctx); + finalUpgradeCodec.UpgradeTo(ctx, request); + + // Notify that the upgrade has occurred. Retain the event to offset + // the release() in the finally block. + ctx.FireUserEventTriggered(upgradeEvent.Retain()); + + // Remove this handler from the pipeline. + ctx.Channel.Pipeline.Remove(this); + } + else + { + ctx.Channel.CloseAsync(); + } + } + finally + { + // Release the event if the upgrade event wasn't fired. + upgradeEvent.Release(); + } + }, TaskContinuationOptions.ExecuteSynchronously); + return true; + } + + static IFullHttpResponse CreateUpgradeResponse(ICharSequence upgradeProtocol) + { + var res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.SwitchingProtocols, + Unpooled.Empty, false); + res.Headers.Add(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade); + res.Headers.Add(HttpHeaderNames.Upgrade, upgradeProtocol); + return res; + } + + static IList SplitHeader(ICharSequence header) + { + var builder = new StringBuilder(header.Count); + var protocols = new List(4); + // ReSharper disable once ForCanBeConvertedToForeach + for (int i = 0; i < header.Count; ++i) + { + char c = header[i]; + if (char.IsWhiteSpace(c)) + { + // Don't include any whitespace. + continue; + } + if (c == ',') + { + // Add the string and reset the builder for the next protocol. + // Add the string and reset the builder for the next protocol. + protocols.Add(new AsciiString(builder.ToString())); + builder.Length = 0; + } + else + { + builder.Append(c); + } + } + + // Add the last protocol + if (builder.Length > 0) + { + protocols.Add(new AsciiString(builder.ToString())); + } + + return protocols; + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpStatusClass.cs b/src/DotNetty.Codecs.Http/HttpStatusClass.cs new file mode 100644 index 0000000..73f5dd0 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpStatusClass.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using System; + using DotNetty.Common.Utilities; + + public struct HttpStatusClass : IEquatable + { + public static readonly HttpStatusClass Informational = new HttpStatusClass(100, 200, "Informational"); + + public static readonly HttpStatusClass Success = new HttpStatusClass(200, 300, "Success"); + + public static readonly HttpStatusClass Redirection = new HttpStatusClass(300, 400, "Redirection"); + + public static readonly HttpStatusClass ClientError = new HttpStatusClass(400, 500, "Client Error"); + + public static readonly HttpStatusClass ServerError = new HttpStatusClass(500, 600, "Server Error"); + + public static readonly HttpStatusClass Unknown = new HttpStatusClass(0, 0, "Unknown Status"); + + public static HttpStatusClass ValueOf(int code) + { + if (Contains(Informational, code)) + { + return Informational; + } + if (Contains(Success, code)) + { + return Success; + } + if (Contains(Redirection, code)) + { + return Redirection; + } + if (Contains(ClientError, code)) + { + return ClientError; + } + if (Contains(ServerError, code)) + { + return ServerError; + } + return Unknown; + } + + public static HttpStatusClass ValueOf(ICharSequence code) + { + if (code != null && code.Count == 3) + { + char c0 = code[0]; + return IsDigit(c0) && IsDigit(code[1]) && IsDigit(code[2]) + ? ValueOf(Digit(c0) * 100) + : Unknown; + } + + return Unknown; + } + + static int Digit(char c) => c - '0'; + + static bool IsDigit(char c) => c >= '0' && c <= '9'; + + readonly int min; + readonly int max; + readonly AsciiString defaultReasonPhrase; + + HttpStatusClass(int min, int max, string defaultReasonPhrase) + { + this.min = min; + this.max = max; + this.defaultReasonPhrase = AsciiString.Cached(defaultReasonPhrase); + } + + public bool Contains(int code) => Contains(this, code); + + public static bool Contains(HttpStatusClass httpStatusClass, int code) + { + if ((httpStatusClass.min & httpStatusClass.max) == 0) + { + return code < 100 || code >= 600; + } + + return code >= httpStatusClass.min && code < httpStatusClass.max; + } + + public AsciiString DefaultReasonPhrase => this.defaultReasonPhrase; + + public bool Equals(HttpStatusClass other) => this.min == other.min && this.max == other.max; + + public override bool Equals(object obj) => obj is HttpStatusClass && this.Equals((HttpStatusClass)obj); + + public override int GetHashCode() => this.min.GetHashCode() ^ this.max.GetHashCode(); + + public static bool operator !=(HttpStatusClass left, HttpStatusClass right) => !(left == right); + + public static bool operator ==(HttpStatusClass left, HttpStatusClass right) => left.Equals(right); + } +} diff --git a/src/DotNetty.Codecs.Http/HttpUtil.cs b/src/DotNetty.Codecs.Http/HttpUtil.cs new file mode 100644 index 0000000..54bc86a --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpUtil.cs @@ -0,0 +1,287 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System; + using System.Collections.Generic; + using System.Text; + using DotNetty.Common.Utilities; + + public static class HttpUtil + { + static readonly AsciiString CharsetEquals = new AsciiString(HttpHeaderValues.Charset + "="); + static readonly AsciiString Semicolon = AsciiString.Cached(";"); + + public static bool IsKeepAlive(IHttpMessage message) + { + if (message.Headers.TryGet(HttpHeaderNames.Connection, out ICharSequence connection) + && HttpHeaderValues.Close.ContentEqualsIgnoreCase(connection)) + { + return false; + } + + if (message.ProtocolVersion.IsKeepAliveDefault) + { + return !HttpHeaderValues.Close.ContentEqualsIgnoreCase(connection); + } + else + { + return HttpHeaderValues.KeepAlive.ContentEqualsIgnoreCase(connection); + } + } + + public static void SetKeepAlive(IHttpMessage message, bool keepAlive) => SetKeepAlive(message.Headers, message.ProtocolVersion, keepAlive); + + public static void SetKeepAlive(HttpHeaders headers, HttpVersion httpVersion, bool keepAlive) + { + if (httpVersion.IsKeepAliveDefault) + { + if (keepAlive) + { + headers.Remove(HttpHeaderNames.Connection); + } + else + { + headers.Set(HttpHeaderNames.Connection, HttpHeaderValues.Close); + } + } + else + { + if (keepAlive) + { + headers.Set(HttpHeaderNames.Connection, HttpHeaderValues.KeepAlive); + } + else + { + headers.Remove(HttpHeaderNames.Connection); + } + } + } + + public static long GetContentLength(IHttpMessage message) + { + if (message.Headers.TryGet(HttpHeaderNames.ContentLength, out ICharSequence value)) + { + return CharUtil.ParseLong(value); + } + + // We know the content length if it's a Web Socket message even if + // Content-Length header is missing. + long webSocketContentLength = GetWebSocketContentLength(message); + if (webSocketContentLength >= 0) + { + return webSocketContentLength; + } + + // Otherwise we don't. + throw new FormatException($"header not found: {HttpHeaderNames.ContentLength}"); + } + + public static long GetContentLength(IHttpMessage message, long defaultValue) + { + if (message.Headers.TryGet(HttpHeaderNames.ContentLength, out ICharSequence value)) + { + return CharUtil.ParseLong(value); + } + + // We know the content length if it's a Web Socket message even if + // Content-Length header is missing. + long webSocketContentLength = GetWebSocketContentLength(message); + if (webSocketContentLength >= 0) + { + return webSocketContentLength; + } + + // Otherwise we don't. + return defaultValue; + } + + public static int GetContentLength(IHttpMessage message, int defaultValue) => + (int)Math.Min(int.MaxValue, GetContentLength(message, (long)defaultValue)); + + static int GetWebSocketContentLength(IHttpMessage message) + { + // WebSocket messages have constant content-lengths. + HttpHeaders h = message.Headers; + if (message is IHttpRequest req) + { + if (HttpMethod.Get.Equals(req.Method) + && h.Contains(HttpHeaderNames.SecWebsocketKey1) + && h.Contains(HttpHeaderNames.SecWebsocketKey2)) + { + return 8; + } + } + else if (message is IHttpResponse res) + { + if (res.Status.Code == 101 + && h.Contains(HttpHeaderNames.SecWebsocketOrigin) + && h.Contains(HttpHeaderNames.SecWebsocketLocation)) + { + return 16; + } + } + + // Not a web socket message + return -1; + } + + public static void SetContentLength(IHttpMessage message, long length) => message.Headers.Set(HttpHeaderNames.ContentLength, length); + + public static bool IsContentLengthSet(IHttpMessage message) => message.Headers.Contains(HttpHeaderNames.ContentLength); + + public static bool Is100ContinueExpected(IHttpMessage message) + { + if (!IsExpectHeaderValid(message)) + { + return false; + } + + ICharSequence expectValue = message.Headers.Get(HttpHeaderNames.Expect, null); + // unquoted tokens in the expect header are case-insensitive, thus 100-continue is case insensitive + return HttpHeaderValues.Continue.ContentEqualsIgnoreCase(expectValue); + } + + internal static bool IsUnsupportedExpectation(IHttpMessage message) + { + if (!IsExpectHeaderValid(message)) + { + return false; + } + + return message.Headers.TryGet(HttpHeaderNames.Expect, out ICharSequence expectValue) + && !HttpHeaderValues.Continue.ContentEqualsIgnoreCase(expectValue); + } + + // Expect: 100-continue is for requests only and it works only on HTTP/1.1 or later. Note further that RFC 7231 + // section 5.1.1 says "A server that receives a 100-continue expectation in an HTTP/1.0 request MUST ignore + // that expectation." + static bool IsExpectHeaderValid(IHttpMessage message) => message is IHttpRequest + && message.ProtocolVersion.CompareTo(HttpVersion.Http11) >= 0; + + public static void Set100ContinueExpected(IHttpMessage message, bool expected) + { + if (expected) + { + message.Headers.Set(HttpHeaderNames.Expect, HttpHeaderValues.Continue); + } + else + { + message.Headers.Remove(HttpHeaderNames.Expect); + } + } + + public static bool IsTransferEncodingChunked(IHttpMessage message) => message.Headers.Contains(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked, true); + + public static void SetTransferEncodingChunked(IHttpMessage m, bool chunked) + { + if (chunked) + { + m.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + m.Headers.Remove(HttpHeaderNames.ContentLength); + } + else + { + IList encodings = m.Headers.GetAll(HttpHeaderNames.TransferEncoding); + if (encodings.Count == 0) + { + return; + } + var values = new List(encodings); + foreach (ICharSequence value in encodings) + { + if (HttpHeaderValues.Chunked.ContentEqualsIgnoreCase(value)) + { + values.Remove(value); + } + } + if (values.Count == 0) + { + m.Headers.Remove(HttpHeaderNames.TransferEncoding); + } + else + { + m.Headers.Set(HttpHeaderNames.TransferEncoding, values); + } + } + } + + public static Encoding GetCharset(IHttpMessage message) => GetCharset(message, Encoding.UTF8); + + public static Encoding GetCharset(ICharSequence contentTypeValue) => contentTypeValue != null ? GetCharset(contentTypeValue, Encoding.UTF8) : Encoding.UTF8; + + public static Encoding GetCharset(IHttpMessage message, Encoding defaultCharset) + { + return message.Headers.TryGet(HttpHeaderNames.ContentType, out ICharSequence contentTypeValue) + ? GetCharset(contentTypeValue, defaultCharset) + : defaultCharset; + } + + public static Encoding GetCharset(ICharSequence contentTypeValue, Encoding defaultCharset) + { + if (contentTypeValue != null) + { + ICharSequence charsetCharSequence = GetCharsetAsSequence(contentTypeValue); + if (charsetCharSequence != null) + { + try + { + return Encoding.GetEncoding(charsetCharSequence.ToString()); + } + catch (ArgumentException) + { + return defaultCharset; + } + } + else + { + return defaultCharset; + } + } + else + { + return defaultCharset; + } + } + + public static ICharSequence GetCharsetAsSequence(IHttpMessage message) + => message.Headers.TryGet(HttpHeaderNames.ContentType, out ICharSequence contentTypeValue) ? GetCharsetAsSequence(contentTypeValue) : null; + + public static ICharSequence GetCharsetAsSequence(ICharSequence contentTypeValue) + { + if (contentTypeValue == null) + { + throw new ArgumentException(nameof(contentTypeValue)); + } + int indexOfCharset = AsciiString.IndexOfIgnoreCaseAscii(contentTypeValue, CharsetEquals, 0); + if (indexOfCharset != AsciiString.IndexNotFound) + { + int indexOfEncoding = indexOfCharset + CharsetEquals.Count; + if (indexOfEncoding < contentTypeValue.Count) + { + return contentTypeValue.SubSequence(indexOfEncoding, contentTypeValue.Count); + } + } + return null; + } + + public static ICharSequence GetMimeType(IHttpMessage message) => + message.Headers.TryGet(HttpHeaderNames.ContentType, out ICharSequence contentTypeValue) ? GetMimeType(contentTypeValue) : null; + + public static ICharSequence GetMimeType(ICharSequence contentTypeValue) + { + if (contentTypeValue == null) + { + throw new ArgumentException(nameof(contentTypeValue)); + } + int indexOfSemicolon = AsciiString.IndexOfIgnoreCaseAscii(contentTypeValue, Semicolon, 0); + if (indexOfSemicolon != AsciiString.IndexNotFound) + { + return contentTypeValue.SubSequence(0, indexOfSemicolon); + } + + return contentTypeValue.Count > 0 ? contentTypeValue : null; + } + } +} diff --git a/src/DotNetty.Codecs.Http/HttpVersion.cs b/src/DotNetty.Codecs.Http/HttpVersion.cs new file mode 100644 index 0000000..e128fb3 --- /dev/null +++ b/src/DotNetty.Codecs.Http/HttpVersion.cs @@ -0,0 +1,239 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using System; + using System.Diagnostics.Contracts; + using System.Runtime.CompilerServices; + using System.Text; + using System.Text.RegularExpressions; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + + public class HttpVersion : IComparable, IComparable + { + static readonly Regex VersionPattern = new Regex("^(\\S+)/(\\d+)\\.(\\d+)$", RegexOptions.Compiled); + + internal static readonly AsciiString Http10String = new AsciiString("HTTP/1.0"); + internal static readonly AsciiString Http11String = new AsciiString("HTTP/1.1"); + + public static readonly HttpVersion Http10 = new HttpVersion("HTTP", 1, 0, false, true); + public static readonly HttpVersion Http11 = new HttpVersion("HTTP", 1, 1, true, true); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static HttpVersion ValueOf(AsciiString text) + { + if (text == null) + { + ThrowHelper.ThrowArgumentException_NullText(); + } + + // ReSharper disable once PossibleNullReferenceException + HttpVersion version = ValueOfInline(text.Array); + if (version != null) + { + return version; + } + + // Fall back to slow path + text = text.Trim(); + + if (text.Count == 0) + { + ThrowHelper.ThrowArgumentException_EmptyText(); + } + + // Try to match without convert to uppercase first as this is what 99% of all clients + // will send anyway. Also there is a change to the RFC to make it clear that it is + // expected to be case-sensitive + // + // See: + // * http://trac.tools.ietf.org/wg/httpbis/trac/ticket/1 + // * http://trac.tools.ietf.org/wg/httpbis/trac/wiki + // + return Version0(text) ?? new HttpVersion(text.ToString(), true); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static HttpVersion ValueOfInline(byte[] bytes) + { + if (bytes.Length != 8) return null; + + if (bytes[0] != (byte)'H') return null; + if (bytes[1] != (byte)'T') return null; + if (bytes[2] != (byte)'T') return null; + if (bytes[3] != (byte)'P') return null; + if (bytes[4] != (byte)'/') return null; + if (bytes[5] != (byte)'1') return null; + if (bytes[6] != (byte)'.') return null; + switch (bytes[7]) + { + case (byte)'1': + return Http11; + case (byte)'0': + return Http10; + default: + return null; + } + } + + static HttpVersion Version0(AsciiString text) + { + if (Http11String.Equals(text)) + { + return Http11; + } + if (Http10String.Equals(text)) + { + return Http10; + } + + return null; + } + + readonly string protocolName; + readonly int majorVersion; + readonly int minorVersion; + readonly AsciiString text; + readonly bool keepAliveDefault; + readonly byte[] bytes; + + public HttpVersion(string text, bool keepAliveDefault) + { + Contract.Requires(text != null); + + text = text.Trim().ToUpper(); + if (string.IsNullOrEmpty(text)) + { + throw new ArgumentException("empty text"); + } + + Match match = VersionPattern.Match(text); + if (!match.Success) + { + throw new ArgumentException($"invalid version format: {text}"); + } + + this.protocolName = match.Groups[1].Value; + this.majorVersion = int.Parse(match.Groups[2].Value); + this.minorVersion = int.Parse(match.Groups[3].Value); + this.text = new AsciiString($"{this.ProtocolName}/{this.MajorVersion}.{this.MinorVersion}"); + this.keepAliveDefault = keepAliveDefault; + this.bytes = null; + } + + HttpVersion(string protocolName, int majorVersion, int minorVersion, bool keepAliveDefault, bool bytes) + { + if (protocolName == null) + { + throw new ArgumentException(nameof(protocolName)); + } + + protocolName = protocolName.Trim().ToUpper(); + if (string.IsNullOrEmpty(protocolName)) + { + throw new ArgumentException("empty protocolName"); + } + + // ReSharper disable once ForCanBeConvertedToForeach + for (int i = 0; i < protocolName.Length; i++) + { + char c = protocolName[i]; + if (CharUtil.IsISOControl(c) || char.IsWhiteSpace(c)) + { + throw new ArgumentException($"invalid character {c} in protocolName"); + } + } + + if (majorVersion < 0) + { + throw new ArgumentException("negative majorVersion"); + } + if (minorVersion < 0) + { + throw new ArgumentException("negative minorVersion"); + } + + this.protocolName = protocolName; + this.majorVersion = majorVersion; + this.minorVersion = minorVersion; + this.text = new AsciiString(protocolName + '/' + majorVersion + '.' + minorVersion); + this.keepAliveDefault = keepAliveDefault; + + this.bytes = bytes ? this.text.Array : null; + } + + public string ProtocolName => this.protocolName; + + public int MajorVersion => this.majorVersion; + + public int MinorVersion => this.minorVersion; + + public AsciiString Text => this.text; + + public bool IsKeepAliveDefault => this.keepAliveDefault; + + public override string ToString() => this.text.ToString(); + + public override int GetHashCode() => (this.protocolName.GetHashCode() * 31 + this.majorVersion) * 31 + this.minorVersion; + + public override bool Equals(object obj) + { + if (!(obj is HttpVersion that)) + { + return false; + } + + return this.minorVersion == that.minorVersion + && this.majorVersion == that.majorVersion + && this.protocolName.Equals(that.protocolName); + } + + public int CompareTo(HttpVersion other) + { + int v = string.CompareOrdinal(this.protocolName, other.protocolName); + if (v != 0) + { + return v; + } + + v = this.majorVersion - other.majorVersion; + if (v != 0) + { + return v; + } + + return this.minorVersion - other.minorVersion; + } + + public int CompareTo(object obj) + { + if (ReferenceEquals(this, obj)) + { + return 0; + } + + if (!(obj is HttpVersion)) + { + throw new ArgumentException($"{nameof(obj)} must be of {nameof(HttpVersion)} type"); + } + + return this.CompareTo((HttpVersion)obj); + } + + internal void Encode(IByteBuffer buf) + { + if (this.bytes == null) + { + buf.WriteCharSequence(this.text, Encoding.ASCII); + } + else + { + buf.WriteBytes(this.bytes); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/IFullHttpMessage.cs b/src/DotNetty.Codecs.Http/IFullHttpMessage.cs new file mode 100644 index 0000000..71da446 --- /dev/null +++ b/src/DotNetty.Codecs.Http/IFullHttpMessage.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + public interface IFullHttpMessage : IHttpMessage, ILastHttpContent + { + } +} diff --git a/src/DotNetty.Codecs.Http/IFullHttpRequest.cs b/src/DotNetty.Codecs.Http/IFullHttpRequest.cs new file mode 100644 index 0000000..3883766 --- /dev/null +++ b/src/DotNetty.Codecs.Http/IFullHttpRequest.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + public interface IFullHttpRequest : IHttpRequest, IFullHttpMessage + { + } +} diff --git a/src/DotNetty.Codecs.Http/IFullHttpResponse.cs b/src/DotNetty.Codecs.Http/IFullHttpResponse.cs new file mode 100644 index 0000000..8c34509 --- /dev/null +++ b/src/DotNetty.Codecs.Http/IFullHttpResponse.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + public interface IFullHttpResponse : IHttpResponse, IFullHttpMessage + { + } +} diff --git a/src/DotNetty.Codecs.Http/IHttpContent.cs b/src/DotNetty.Codecs.Http/IHttpContent.cs new file mode 100644 index 0000000..45e4f32 --- /dev/null +++ b/src/DotNetty.Codecs.Http/IHttpContent.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using DotNetty.Buffers; + + public interface IHttpContent : IHttpObject, IByteBufferHolder + { + } +} diff --git a/src/DotNetty.Codecs.Http/IHttpMessage.cs b/src/DotNetty.Codecs.Http/IHttpMessage.cs new file mode 100644 index 0000000..0f2a450 --- /dev/null +++ b/src/DotNetty.Codecs.Http/IHttpMessage.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + public interface IHttpMessage : IHttpObject + { + HttpVersion ProtocolVersion { get; } + + IHttpMessage SetProtocolVersion(HttpVersion version); + + HttpHeaders Headers { get; } + } +} diff --git a/src/DotNetty.Codecs.Http/IHttpObject.cs b/src/DotNetty.Codecs.Http/IHttpObject.cs new file mode 100644 index 0000000..34cf1d5 --- /dev/null +++ b/src/DotNetty.Codecs.Http/IHttpObject.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + public interface IHttpObject : IDecoderResultProvider + { + } +} diff --git a/src/DotNetty.Codecs.Http/IHttpRequest.cs b/src/DotNetty.Codecs.Http/IHttpRequest.cs new file mode 100644 index 0000000..839a757 --- /dev/null +++ b/src/DotNetty.Codecs.Http/IHttpRequest.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + public interface IHttpRequest : IHttpMessage + { + HttpMethod Method { get; } + + IHttpRequest SetMethod(HttpMethod method); + + string Uri { get; } + + IHttpRequest SetUri(string uri); + } +} diff --git a/src/DotNetty.Codecs.Http/IHttpResponse.cs b/src/DotNetty.Codecs.Http/IHttpResponse.cs new file mode 100644 index 0000000..d1a5031 --- /dev/null +++ b/src/DotNetty.Codecs.Http/IHttpResponse.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + public interface IHttpResponse : IHttpMessage + { + HttpResponseStatus Status { get; } + + IHttpResponse SetStatus(HttpResponseStatus status); + } +} diff --git a/src/DotNetty.Codecs.Http/ILastHttpContent.cs b/src/DotNetty.Codecs.Http/ILastHttpContent.cs new file mode 100644 index 0000000..490aa74 --- /dev/null +++ b/src/DotNetty.Codecs.Http/ILastHttpContent.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + public interface ILastHttpContent : IHttpContent + { + HttpHeaders TrailingHeaders { get; } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/AbstractDiskHttpData.cs b/src/DotNetty.Codecs.Http/Multipart/AbstractDiskHttpData.cs new file mode 100644 index 0000000..f8130a5 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/AbstractDiskHttpData.cs @@ -0,0 +1,341 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Diagnostics.Contracts; + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common; + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + + public abstract class AbstractDiskHttpData : AbstractHttpData + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + + FileStream fileStream; + + protected AbstractDiskHttpData(string name, Encoding charset, long size) : base(name, charset, size) + { + } + + protected abstract string DiskFilename { get; } + + protected abstract string Prefix { get; } + + protected abstract string BaseDirectory { get; } + + protected abstract string Postfix { get; } + + protected abstract bool DeleteOnExit { get; } + + FileStream TempFile() + { + string newpostfix; + string diskFilename = this.DiskFilename; + if (diskFilename != null) + { + newpostfix = '_' + diskFilename; + } + else + { + newpostfix = this.Postfix; + } + string directory = this.BaseDirectory == null + ? Path.GetTempPath() + : Path.Combine(Path.GetTempPath(), this.BaseDirectory); + // File.createTempFile + string fileName = Path.Combine(directory, $"{this.Prefix}{Path.GetRandomFileName()}{newpostfix}"); + FileStream tmpFile = File.Create(fileName, 4096, // DefaultBufferSize + this.DeleteOnExit ? FileOptions.DeleteOnClose : FileOptions.None); + return tmpFile; + } + + public override void SetContent(IByteBuffer buffer) + { + Contract.Requires(buffer != null); + try + { + this.Size = buffer.ReadableBytes; + this.CheckSize(this.Size); + if (this.DefinedSize > 0 && this.DefinedSize < this.Size) + { + throw new IOException($"Out of size: {this.Size} > {this.DefinedSize}"); + } + if (this.fileStream == null) + { + this.fileStream = this.TempFile(); + } + if (buffer.ReadableBytes == 0) + { + // empty file + return; + } + + buffer.GetBytes(buffer.ReaderIndex, this.fileStream, buffer.ReadableBytes); + buffer.SetReaderIndex(buffer.ReaderIndex + buffer.ReadableBytes); + this.fileStream.Flush(); + this.SetCompleted(); + } + finally + { + // Release the buffer as it was retained before and we not need a reference to it at all + // See https://github.com/netty/netty/issues/1516 + buffer.Release(); + } + } + + public override void AddContent(IByteBuffer buffer, bool last) + { + if (buffer != null) + { + try + { + int localsize = buffer.ReadableBytes; + this.CheckSize(this.Size + localsize); + if (this.DefinedSize > 0 && this.DefinedSize < this.Size + localsize) + { + throw new IOException($"Out of size: {this.Size} > {this.DefinedSize}"); + } + if (this.fileStream == null) + { + this.fileStream = this.TempFile(); + } + buffer.GetBytes(buffer.ReaderIndex, this.fileStream, buffer.ReadableBytes); + buffer.SetReaderIndex(buffer.ReaderIndex + localsize); + this.fileStream.Flush(); + + this.Size += buffer.ReadableBytes; + } + finally + { + // Release the buffer as it was retained before and we not need a reference to it at all + // See https://github.com/netty/netty/issues/1516 + buffer.Release(); + } + } + if (last) + { + if (this.fileStream == null) + { + this.fileStream = this.TempFile(); + } + this.SetCompleted(); + } + else + { + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + } + } + + public override void SetContent(Stream source) + { + Contract.Requires(source != null); + + if (this.fileStream != null) + { + this.Delete(); + } + + this.fileStream = this.TempFile(); + int written = 0; + var bytes = new byte[4096 * 4]; + while (true) + { + int read = source.Read(bytes, 0, bytes.Length); + if (read <= 0) + { + break; + } + + written += read; + this.CheckSize(written); + this.fileStream.Write(bytes, 0, read); + } + this.fileStream.Flush(); + // Reset the position to start for reads + this.fileStream.Position -= written; + + this.Size = written; + if (this.DefinedSize > 0 && this.DefinedSize < this.Size) + { + try + { + Delete(this.fileStream); + } + catch (Exception error) + { + Logger.Warn("Failed to delete: {} {}", this.fileStream, error); + } + this.fileStream = null; + throw new IOException($"Out of size: {this.Size} > {this.DefinedSize}"); + } + //isRenamed = true; + this.SetCompleted(); + } + + public override void Delete() + { + if (this.fileStream != null) + { + try + { + Delete(this.fileStream); + } + catch (IOException error) + { + Logger.Warn("Failed to delete file.", error); + } + + this.fileStream = null; + } + } + + public override byte[] GetBytes() => this.fileStream == null + ? ArrayExtensions.ZeroBytes : ReadFrom(this.fileStream); + + public override IByteBuffer GetByteBuffer() + { + if (this.fileStream == null) + { + return Unpooled.Empty; + } + + byte[] array = ReadFrom(this.fileStream); + return Unpooled.WrappedBuffer(array); + } + + public override IByteBuffer GetChunk(int length) + { + if (this.fileStream == null || length == 0) + { + return Unpooled.Empty; + } + int read = 0; + var bytes = new byte[length]; + while (read < length) + { + int readnow = this.fileStream.Read(bytes, read, length - read); + if (readnow <= 0) + { + break; + } + + read += readnow; + } + if (read == 0) + { + return Unpooled.Empty; + } + IByteBuffer buffer = Unpooled.WrappedBuffer(bytes); + buffer.SetReaderIndex(0); + buffer.SetWriterIndex(read); + return buffer; + } + + public override string GetString() => this.GetString(HttpConstants.DefaultEncoding); + + public override string GetString(Encoding encoding) + { + if (this.fileStream == null) + { + return string.Empty; + } + byte[] array = ReadFrom(this.fileStream); + if (encoding == null) + { + encoding = HttpConstants.DefaultEncoding; + } + + return encoding.GetString(array); + } + + public override bool IsInMemory => false; + + public override bool RenameTo(FileStream destination) + { + Contract.Requires(destination != null); + if (this.fileStream == null) + { + throw new InvalidOperationException("No file defined so cannot be renamed"); + } + + // must copy + long chunkSize = 8196; + int position = 0; + while (position < this.Size) + { + if (chunkSize < this.Size - position) + { + chunkSize = this.Size - position; + } + + var buffer = new byte[chunkSize]; + int read = this.fileStream.Read(buffer, 0, (int)chunkSize); + if (read <= 0) + { + break; + } + + destination.Write(buffer, 0, read); + position += read; + } + + if (position == this.Size) + { + try + { + Delete(this.fileStream); + } + catch (IOException exception) + { + Logger.Warn("Failed to delete file.", exception); + } + this.fileStream = destination; + return true; + } + else + { + try + { + Delete(destination); + } + catch (IOException exception) + { + Logger.Warn("Failed to delete file.", exception); + } + return false; + } + } + + static void Delete(FileStream fileStream) + { + string fileName = fileStream.Name; + fileStream.Dispose(); + File.Delete(fileName); + } + + static byte[] ReadFrom(Stream fileStream) + { + long srcsize = fileStream.Length; + if (srcsize > int.MaxValue) + { + throw new ArgumentException("File too big to be loaded in memory"); + } + + var array = new byte[(int)srcsize]; + fileStream.Read(array, 0, array.Length); + return array; + } + + public override FileStream GetFile() => this.fileStream; + + public override IReferenceCounted Touch(object hint) => this; + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/AbstractHttpData.cs b/src/DotNetty.Codecs.Http/Multipart/AbstractHttpData.cs new file mode 100644 index 0000000..75f2f0a --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/AbstractHttpData.cs @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +// ReSharper disable ConvertToAutoPropertyWithPrivateSetter +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Diagnostics.Contracts; + using System.IO; + using System.Text; + using System.Text.RegularExpressions; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public abstract class AbstractHttpData : AbstractReferenceCounted, IHttpData + { + static readonly Regex StripPattern = new Regex("(?:^\\s+|\\s+$|\\n)", RegexOptions.Compiled); + static readonly Regex ReplacePattern = new Regex("[\\r\\t]", RegexOptions.Compiled); + + readonly string name; + protected long DefinedSize; + protected long Size; + Encoding charset = HttpConstants.DefaultEncoding; + bool completed; + long maxSize = DefaultHttpDataFactory.MaxSize; + + protected AbstractHttpData(string name, Encoding charset, long size) + { + Contract.Requires(name != null); + + name = StripPattern.Replace(name, " "); + name = ReplacePattern.Replace(name, ""); + if (string.IsNullOrEmpty(name)) + { + throw new ArgumentException("empty name"); + } + + this.name = name; + if (charset != null) + { + this.charset = charset; + } + + this.DefinedSize = size; + } + + public long MaxSize + { + get => this.maxSize; + set => this.maxSize = value; + } + + public void CheckSize(long newSize) + { + if (this.MaxSize >= 0 && newSize > this.MaxSize) + { + throw new IOException("Size exceed allowed maximum capacity"); + } + } + + public string Name => this.name; + + public bool IsCompleted => this.completed; + + protected void SetCompleted() => this.completed = true; + + public Encoding Charset + { + get => this.charset; + set + { + Contract.Requires(value != null); + this.charset = value; + } + } + + public long Length => this.Size; + + public long DefinedLength => this.DefinedSize; + + public IByteBuffer Content + { + get + { + try + { + return this.GetByteBuffer(); + } + catch (IOException e) + { + throw new ChannelException(e); + } + } + } + + protected override void Deallocate() => this.Delete(); + + public abstract int CompareTo(IInterfaceHttpData other); + + public abstract HttpDataType DataType { get; } + + public abstract IByteBufferHolder Copy(); + + public abstract IByteBufferHolder Duplicate(); + + public abstract IByteBufferHolder RetainedDuplicate(); + + public abstract void SetContent(IByteBuffer buffer); + + public abstract void SetContent(Stream source); + + public abstract void AddContent(IByteBuffer buffer, bool last); + + public abstract void Delete(); + + public abstract byte[] GetBytes(); + + public abstract IByteBuffer GetByteBuffer(); + + public abstract IByteBuffer GetChunk(int length); + + public virtual string GetString() => this.GetString(this.charset); + + public abstract string GetString(Encoding encoding); + + public abstract bool RenameTo(FileStream destination); + + public abstract bool IsInMemory { get; } + + public abstract FileStream GetFile(); + + public abstract IByteBufferHolder Replace(IByteBuffer content); + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/AbstractMemoryHttpData.cs b/src/DotNetty.Codecs.Http/Multipart/AbstractMemoryHttpData.cs new file mode 100644 index 0000000..eb97433 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/AbstractMemoryHttpData.cs @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Diagnostics.Contracts; + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common; + + public abstract class AbstractMemoryHttpData : AbstractHttpData + { + IByteBuffer byteBuf; + int chunkPosition; + + protected AbstractMemoryHttpData(string name, Encoding charset, long size) + : base(name, charset, size) + { + } + + public override void SetContent(IByteBuffer buffer) + { + Contract.Requires(buffer != null); + + long localsize = buffer.ReadableBytes; + this.CheckSize(localsize); + if (this.DefinedSize > 0 && this.DefinedSize < localsize) + { + throw new IOException($"Out of size: {localsize} > {this.DefinedSize}"); + } + this.byteBuf?.Release(); + + this.byteBuf = buffer; + this.Size = localsize; + this.SetCompleted(); + } + + public override void SetContent(Stream inputStream) + { + Contract.Requires(inputStream != null); + + if (!inputStream.CanRead) + { + throw new ArgumentException($"{nameof(inputStream)} is not readable"); + } + + IByteBuffer buffer = Unpooled.Buffer(); + var bytes = new byte[4096 * 4]; + int written = 0; + while (true) + { + int read = inputStream.Read(bytes, 0, bytes.Length); + if (read <= 0) + { + break; + } + + buffer.WriteBytes(bytes, 0, read); + written += read; + this.CheckSize(written); + } + this.Size = written; + if (this.DefinedSize > 0 && this.DefinedSize < this.Size) + { + throw new IOException($"Out of size: {this.Size} > {this.DefinedSize}"); + } + + this.byteBuf?.Release(); + this.byteBuf = buffer; + this.SetCompleted(); + } + + public override void AddContent(IByteBuffer buffer, bool last) + { + if (buffer != null) + { + long localsize = buffer.ReadableBytes; + this.CheckSize(this.Size + localsize); + if (this.DefinedSize > 0 && this.DefinedSize < this.Size + localsize) + { + throw new IOException($"Out of size: {(this.Size + localsize)} > {this.DefinedSize}"); + } + + this.Size += localsize; + if (this.byteBuf == null) + { + this.byteBuf = buffer; + } + else if (this.byteBuf is CompositeByteBuffer buf) + { + buf.AddComponent(true, buffer); + buf.SetWriterIndex((int)this.Size); + } + else + { + CompositeByteBuffer compositeBuffer = Unpooled.CompositeBuffer(int.MaxValue); + compositeBuffer.AddComponents(true, this.byteBuf, buffer); + compositeBuffer.SetWriterIndex((int)this.Size); + this.byteBuf = compositeBuffer; + } + } + if (last) + { + this.SetCompleted(); + } + else + { + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + } + } + + public override void Delete() + { + if (this.byteBuf != null) + { + this.byteBuf.Release(); + this.byteBuf = null; + } + } + + public override byte[] GetBytes() + { + if (this.byteBuf == null) + { + return Unpooled.Empty.Array; + } + + var array = new byte[this.byteBuf.ReadableBytes]; + this.byteBuf.GetBytes(this.byteBuf.ReaderIndex, array); + return array; + } + + public override string GetString() => this.GetString(HttpConstants.DefaultEncoding); + + public override string GetString(Encoding encoding) + { + if (this.byteBuf == null) + { + return string.Empty; + } + if (encoding == null) + { + encoding = HttpConstants.DefaultEncoding; + } + return this.byteBuf.ToString(encoding); + } + + public override IByteBuffer GetByteBuffer() => this.byteBuf; + + public override IByteBuffer GetChunk(int length) + { + if (this.byteBuf == null || length == 0 || this.byteBuf.ReadableBytes == 0) + { + this.chunkPosition = 0; + return Unpooled.Empty; + } + int sizeLeft = this.byteBuf.ReadableBytes - this.chunkPosition; + if (sizeLeft == 0) + { + this.chunkPosition = 0; + return Unpooled.Empty; + } + int sliceLength = length; + if (sizeLeft < length) + { + sliceLength = sizeLeft; + } + + IByteBuffer chunk = this.byteBuf.RetainedSlice(this.chunkPosition, sliceLength); + this.chunkPosition += sliceLength; + return chunk; + } + + public override bool IsInMemory => true; + + public override bool RenameTo(FileStream destination) + { + Contract.Requires(destination != null); + + if (!destination.CanWrite) + { + throw new ArgumentException($"{nameof(destination)} is not writable"); + } + if (this.byteBuf == null) + { + return true; + } + + this.byteBuf.GetBytes(this.byteBuf.ReaderIndex, destination, this.byteBuf.ReadableBytes); + destination.Flush(); + return true; + } + + public override FileStream GetFile() => throw new IOException("Not represented by a stream"); + + public override IReferenceCounted Touch(object hint) + { + this.byteBuf?.Touch(hint); + return this; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/CaseIgnoringComparator.cs b/src/DotNetty.Codecs.Http/Multipart/CaseIgnoringComparator.cs new file mode 100644 index 0000000..f82756e --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/CaseIgnoringComparator.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Collections.Generic; + using DotNetty.Common.Utilities; + + sealed class CaseIgnoringComparator : IEqualityComparer, IComparer + { + public static readonly IEqualityComparer Default = new CaseIgnoringComparator(); + + CaseIgnoringComparator() + { + } + + public int Compare(ICharSequence x, ICharSequence y) + { + if (ReferenceEquals(x, y)) + { + return 0; + } + if (x == null) + { + return -1; + } + if (y == null) + { + return 1; + } + + int o1Length = x.Count; + int o2Length = y.Count; + int min = Math.Min(o1Length, o2Length); + for (int i = 0; i < min; i++) + { + char c1 = x[i]; + char c2 = y[i]; + if (c1 != c2) + { + c1 = char.ToUpper(c1); + c2 = char.ToUpper(c2); + if (c1 != c2) + { + c1 = char.ToLower(c1); + c2 = char.ToLower(c2); + if (c1 != c2) + { + return c1 - c2; + } + } + } + } + + return o1Length - o2Length; + } + + public bool Equals(ICharSequence x, ICharSequence y) + { + if (ReferenceEquals(x, y)) + { + return true; + } + + if (x == null || y == null) + { + return false; + } + + int o1Length = x.Count; + int o2Length = y.Count; + + if (o1Length != o2Length) + { + return false; + } + + for (int i = 0; i < o1Length; i++) + { + char c1 = x[i]; + char c2 = y[i]; + if (c1 != c2) + { + c1 = char.ToUpper(c1); + c2 = char.ToUpper(c2); + if (c1 != c2) + { + c1 = char.ToLower(c1); + c2 = char.ToLower(c2); + if (c1 != c2) + { + return false; + } + } + } + } + + return true; + } + + public int GetHashCode(ICharSequence obj) => obj.HashCode(true); + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/DefaultHttpDataFactory.cs b/src/DotNetty.Codecs.Http/Multipart/DefaultHttpDataFactory.cs new file mode 100644 index 0000000..20daafb --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/DefaultHttpDataFactory.cs @@ -0,0 +1,283 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Collections.Concurrent; + using System.Collections.Generic; + using System.IO; + using System.Linq; + using System.Text; + + public class DefaultHttpDataFactory : IHttpDataFactory + { + // Proposed default MINSIZE as 16 KB. + public static readonly long MinSize = 0x4000; + + // Proposed default MAXSIZE = -1 as UNLIMITED + public static readonly long MaxSize = -1; + + readonly bool useDisk; + readonly bool checkSize; + readonly long minSize; + long maxSize = MaxSize; + readonly Encoding charset = HttpConstants.DefaultEncoding; + + // Keep all HttpDatas until cleanAllHttpData() is called. + readonly ConcurrentDictionary> requestFileDeleteMap = + new ConcurrentDictionary>(IdentityComparer.Default); + + // HttpData will be in memory if less than default size (16KB). + // The type will be Mixed. + public DefaultHttpDataFactory() + { + this.useDisk = false; + this.checkSize = true; + this.minSize = MinSize; + } + + public DefaultHttpDataFactory(Encoding charset) : this() + { + this.charset = charset; + } + + // HttpData will be always on Disk if useDisk is True, else always in Memory if False + public DefaultHttpDataFactory(bool useDisk) + { + this.useDisk = useDisk; + this.checkSize = false; + } + + public DefaultHttpDataFactory(bool useDisk, Encoding charset) : this(useDisk) + { + this.charset = charset; + } + + public DefaultHttpDataFactory(long minSize) + { + this.useDisk = false; + this.checkSize = true; + this.minSize = minSize; + } + + public DefaultHttpDataFactory(long minSize, Encoding charset) : this(minSize) + { + this.charset = charset; + } + + public void SetMaxLimit(long max) => this.maxSize = max; + + List GetList(IHttpRequest request) + { + List list = this.requestFileDeleteMap.GetOrAdd(request, _ => new List()); + return list; + } + + public IAttribute CreateAttribute(IHttpRequest request, string name) + { + if (this.useDisk) + { + var diskAttribute = new DiskAttribute(name, this.charset); + diskAttribute.MaxSize = this.maxSize; + List list = this.GetList(request); + list.Add(diskAttribute); + return diskAttribute; + } + if (this.checkSize) + { + var mixedAttribute = new MixedAttribute(name, this.minSize, this.charset); + mixedAttribute.MaxSize = this.maxSize; + List list = this.GetList(request); + list.Add(mixedAttribute); + return mixedAttribute; + } + var attribute = new MemoryAttribute(name); + attribute.MaxSize = this.maxSize; + return attribute; + } + + public IAttribute CreateAttribute(IHttpRequest request, string name, long definedSize) + { + if (this.useDisk) + { + var diskAttribute = new DiskAttribute(name, definedSize, this.charset); + diskAttribute.MaxSize = this.maxSize; + List list = this.GetList(request); + list.Add(diskAttribute); + return diskAttribute; + } + if (this.checkSize) + { + var mixedAttribute = new MixedAttribute(name, definedSize, this.minSize, this.charset); + mixedAttribute.MaxSize = this.maxSize; + List list = this.GetList(request); + list.Add(mixedAttribute); + return mixedAttribute; + } + var attribute = new MemoryAttribute(name, definedSize); + attribute.MaxSize = this.maxSize; + return attribute; + } + + static void CheckHttpDataSize(IHttpData data) + { + try + { + data.CheckSize(data.Length); + } + catch (IOException) + { + throw new ArgumentException("Attribute bigger than maxSize allowed"); + } + } + + public IAttribute CreateAttribute(IHttpRequest request, string name, string value) + { + if (this.useDisk) + { + IAttribute attribute; + try + { + attribute = new DiskAttribute(name, value, this.charset); + attribute.MaxSize = this.maxSize; + } + catch (IOException) + { + // revert to Mixed mode + attribute = new MixedAttribute(name, value, this.minSize, this.charset); + attribute.MaxSize = this.maxSize; + } + CheckHttpDataSize(attribute); + List list = this.GetList(request); + list.Add(attribute); + return attribute; + } + if (this.checkSize) + { + var mixedAttribute = new MixedAttribute(name, value, this.minSize, this.charset); + mixedAttribute.MaxSize = this.maxSize; + CheckHttpDataSize(mixedAttribute); + List list = this.GetList(request); + list.Add(mixedAttribute); + return mixedAttribute; + } + try + { + var attribute = new MemoryAttribute(name, value, this.charset); + attribute.MaxSize = this.maxSize; + CheckHttpDataSize(attribute); + return attribute; + } + catch (IOException e) + { + throw new ArgumentException($"({request}, {name}, {value})", e); + } + } + + public IFileUpload CreateFileUpload(IHttpRequest request, string name, string fileName, + string contentType, string contentTransferEncoding, Encoding encoding, + long size) + { + if (this.useDisk) + { + var fileUpload = new DiskFileUpload(name, fileName, contentType, + contentTransferEncoding, encoding, size); + fileUpload.MaxSize = this.maxSize; + CheckHttpDataSize(fileUpload); + List list = this.GetList(request); + list.Add(fileUpload); + return fileUpload; + } + if (this.checkSize) + { + var fileUpload = new MixedFileUpload(name, fileName, contentType, + contentTransferEncoding, encoding, size, this.minSize); + fileUpload.MaxSize = this.maxSize; + CheckHttpDataSize(fileUpload); + List list = this.GetList(request); + list.Add(fileUpload); + return fileUpload; + } + var memoryFileUpload = new MemoryFileUpload(name, fileName, contentType, + contentTransferEncoding, encoding, size); + memoryFileUpload.MaxSize = this.maxSize; + CheckHttpDataSize(memoryFileUpload); + return memoryFileUpload; + } + + public void RemoveHttpDataFromClean(IHttpRequest request, IInterfaceHttpData data) + { + if (!(data is IHttpData httpData)) + { + return; + } + + // Do not use getList because it adds empty list to requestFileDeleteMap + // if request is not found + if (!this.requestFileDeleteMap.TryGetValue(request, out List list)) + { + return; + } + + // Can't simply call list.remove(data), because different data items may be equal. + // Need to check identity. + int index = -1; + for (int i = 0; i < list.Count; i++) + { + if (ReferenceEquals(list[i], httpData)) + { + index = i; + break; + } + } + if (index != -1) + { + list.RemoveAt(index); + } + if (list.Count == 0) + { + this.requestFileDeleteMap.TryRemove(request, out _); + } + } + + public void CleanRequestHttpData(IHttpRequest request) + { + if (this.requestFileDeleteMap.TryRemove(request, out List list)) + { + foreach (IHttpData data in list) + { + data.Release(); + } + } + } + + public void CleanAllHttpData() + { + while (!this.requestFileDeleteMap.IsEmpty) + { + IHttpRequest[] keys = this.requestFileDeleteMap.Keys.ToArray(); + foreach (IHttpRequest key in keys) + { + if (this.requestFileDeleteMap.TryRemove(key, out List list)) + { + foreach (IHttpData data in list) + { + data.Release(); + } + } + } + } + } + + // Similar to IdentityHashMap in Java + sealed class IdentityComparer : IEqualityComparer + { + internal static readonly IdentityComparer Default = new IdentityComparer(); + + public bool Equals(IHttpRequest x, IHttpRequest y) => ReferenceEquals(x, y); + + public int GetHashCode(IHttpRequest obj) => obj.GetHashCode(); + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/DiskAttribute.cs b/src/DotNetty.Codecs.Http/Multipart/DiskAttribute.cs new file mode 100644 index 0000000..52d5dd4 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/DiskAttribute.cs @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Diagnostics.Contracts; + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Transport.Channels; + + public class DiskAttribute : AbstractDiskHttpData, IAttribute + { + public static string DiskBaseDirectory; + public static bool DeleteOnExitTemporaryFile = true; + public static readonly string FilePrefix = "Attr_"; + public static readonly string FilePostfix = ".att"; + + public DiskAttribute(string name) + : this(name, HttpConstants.DefaultEncoding) + { + } + + public DiskAttribute(string name, long definedSize) + : this(name, definedSize, HttpConstants.DefaultEncoding) + { + } + + public DiskAttribute(string name, Encoding charset) + : base(name, charset, 0) + { + } + + public DiskAttribute(string name, long definedSize, Encoding charset) + : base(name, charset, definedSize) + { + } + + public DiskAttribute(string name, string value) + : this(name, value, HttpConstants.DefaultEncoding) + { + } + + public DiskAttribute(string name, string value, Encoding charset) + : base(name, charset, 0) // Attribute have no default size + { + this.Value = value; + } + + public override HttpDataType DataType => HttpDataType.Attribute; + + public string Value + { + get + { + byte[] bytes = this.GetBytes(); + return this.Charset.GetString(bytes); + } + set + { + Contract.Requires(value != null); + + byte[] bytes = this.Charset.GetBytes(value); + this.CheckSize(bytes.Length); + IByteBuffer buffer = Unpooled.WrappedBuffer(bytes); + if (this.DefinedSize > 0) + { + this.DefinedSize = buffer.ReadableBytes; + } + this.SetContent(buffer); + } + } + + public override void AddContent(IByteBuffer buffer, bool last) + { + long newDefinedSize = this.Size + buffer.ReadableBytes; + this.CheckSize(newDefinedSize); + if (this.DefinedSize > 0 && this.DefinedSize < newDefinedSize) + { + this.DefinedSize = newDefinedSize; + } + base.AddContent(buffer, last); + } + + public override int GetHashCode() => this.Name.GetHashCode(); + + public override bool Equals(object obj) + { + if (obj is IAttribute attribute) + { + return this.Name.Equals(attribute.Name, StringComparison.OrdinalIgnoreCase); + } + return false; + } + + public override int CompareTo(IInterfaceHttpData other) + { + if (!(other is IAttribute)) + { + throw new ArgumentException($"Cannot compare {this.DataType} with {other.DataType}"); + } + + return this.CompareTo((IAttribute)other); + } + + public int CompareTo(IAttribute attribute) => string.Compare(this.Name, attribute.Name, StringComparison.OrdinalIgnoreCase); + + public override string ToString() + { + try + { + return $"{this.Name}={this.Value}"; + } + catch (IOException e) + { + return $"{this.Name}={e}"; + } + } + + protected override bool DeleteOnExit => DeleteOnExitTemporaryFile; + + protected override string BaseDirectory => DiskBaseDirectory; + + protected override string DiskFilename => $"{this.Name}{this.Postfix}"; + + protected override string Postfix => FilePostfix; + + protected override string Prefix => FilePrefix; + + public override IByteBufferHolder Copy() => this.Replace(this.Content?.Copy()); + + public override IByteBufferHolder Duplicate() => this.Replace(this.Content?.Duplicate()); + + public override IByteBufferHolder RetainedDuplicate() + { + IByteBuffer content = this.Content; + if (content != null) + { + content = content.RetainedDuplicate(); + bool success = false; + try + { + var duplicate = (IAttribute)this.Replace(content); + success = true; + return duplicate; + } + finally + { + if (!success) + { + content.Release(); + } + } + } + else + { + return this.Replace(null); + } + } + + public override IByteBufferHolder Replace(IByteBuffer content) + { + var attr = new DiskAttribute(this.Name); + attr.Charset = this.Charset; + if (content != null) + { + try + { + attr.SetContent(content); + } + catch (IOException e) + { + throw new ChannelException(e); + } + } + return attr; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/DiskFileUpload.cs b/src/DotNetty.Codecs.Http/Multipart/DiskFileUpload.cs new file mode 100644 index 0000000..cc7da95 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/DiskFileUpload.cs @@ -0,0 +1,165 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Diagnostics.Contracts; + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Transport.Channels; + + public class DiskFileUpload : AbstractDiskHttpData, IFileUpload + { + public static string FileBaseDirectory; + public static bool DeleteOnExitTemporaryFile = true; + public static string FilePrefix = "FUp_"; + public static readonly string FilePostfix = ".tmp"; + + string filename; + string contentType; + string contentTransferEncoding; + + public DiskFileUpload(string name, string filename, string contentType, + string contentTransferEncoding, Encoding charset, long size) + : base(name, charset, size) + { + Contract.Requires(filename != null); + Contract.Requires(contentType != null); + + this.filename = filename; + this.contentType = contentType; + this.contentTransferEncoding = contentTransferEncoding; + } + + public override HttpDataType DataType => HttpDataType.FileUpload; + + public string FileName + { + get => this.filename; + set + { + Contract.Requires(value != null); + this.filename = value; + } + } + + public override int GetHashCode() => FileUploadUtil.HashCode(this); + + public override bool Equals(object obj) => obj is IFileUpload fileUpload && FileUploadUtil.Equals(this, fileUpload); + + public override int CompareTo(IInterfaceHttpData other) + { + if (!(other is IFileUpload)) + { + throw new ArgumentException($"Cannot compare {this.DataType} with {other.DataType}"); + } + + return this.CompareTo((IFileUpload)other); + } + + public int CompareTo(IFileUpload other) => FileUploadUtil.CompareTo(this, other); + + public string ContentType + { + get => this.contentType; + set + { + Contract.Requires(value != null); + this.contentType = value; + } + } + + public string ContentTransferEncoding + { + get => this.contentTransferEncoding; + set => this.contentTransferEncoding = value; + } + + public override string ToString() + { + FileStream fileStream = null; + try + { + fileStream = this.GetFile(); + } + catch (IOException) + { + // Should not occur. + } + + return HttpHeaderNames.ContentDisposition + ": " + + HttpHeaderValues.FormData + "; " + HttpHeaderValues.Name + "=\"" + this.Name + + "\"; " + HttpHeaderValues.FileName + "=\"" + this.filename + "\"\r\n" + + HttpHeaderNames.ContentType + ": " + this.contentType + + (this.Charset != null ? "; " + HttpHeaderValues.Charset + '=' + this.Charset.WebName + "\r\n" : "\r\n") + + HttpHeaderNames.ContentLength + ": " + this.Length + "\r\n" + + "Completed: " + this.IsCompleted + + "\r\nIsInMemory: " + this.IsInMemory + "\r\nRealFile: " + + (fileStream != null ? fileStream.Name : "null") + " DefaultDeleteAfter: " + + DeleteOnExitTemporaryFile; + } + + protected override bool DeleteOnExit => DeleteOnExitTemporaryFile; + + protected override string BaseDirectory => FileBaseDirectory; + + protected override string DiskFilename => "upload"; + + protected override string Postfix => FilePostfix; + + protected override string Prefix => FilePrefix; + + public override IByteBufferHolder Copy() => this.Replace(this.Content?.Copy()); + + public override IByteBufferHolder Duplicate() => this.Replace(this.Content?.Duplicate()); + + public override IByteBufferHolder RetainedDuplicate() + { + IByteBuffer content = this.Content; + if (content != null) + { + content = content.RetainedDuplicate(); + bool success = false; + try + { + var duplicate = (IFileUpload)this.Replace(content); + success = true; + return duplicate; + } + finally + { + if (!success) + { + content.Release(); + } + } + } + else + { + return this.Replace(null); + } + } + + public override IByteBufferHolder Replace(IByteBuffer content) + { + var upload = new DiskFileUpload( + this.Name, this.FileName, this.ContentType, this.ContentTransferEncoding, this.Charset, this.Size); + if (content != null) + { + try + { + upload.SetContent(content); + } + catch (IOException e) + { + throw new ChannelException(e); + } + } + + return upload; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/EndOfDataDecoderException.cs b/src/DotNetty.Codecs.Http/Multipart/EndOfDataDecoderException.cs new file mode 100644 index 0000000..4b80a1e --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/EndOfDataDecoderException.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + + public class EndOfDataDecoderException : DecoderException + { + public EndOfDataDecoderException(string message) + : base(message) + { + } + + public EndOfDataDecoderException(Exception innerException) + : base(innerException) + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/ErrorDataDecoderException.cs b/src/DotNetty.Codecs.Http/Multipart/ErrorDataDecoderException.cs new file mode 100644 index 0000000..229c0be --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/ErrorDataDecoderException.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + + public class ErrorDataDecoderException : DecoderException + { + public ErrorDataDecoderException(string message) + : base(message) + { + } + + public ErrorDataDecoderException(Exception innerException) + : base(innerException) + { + } + + public ErrorDataDecoderException(string message, Exception innerException) + : base(message, innerException) + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/ErrorDataEncoderException.cs b/src/DotNetty.Codecs.Http/Multipart/ErrorDataEncoderException.cs new file mode 100644 index 0000000..7995eb1 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/ErrorDataEncoderException.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + + public class ErrorDataEncoderException : Exception + { + public ErrorDataEncoderException(string message) + : base(message) + { + } + public ErrorDataEncoderException(Exception innerException) + : base(null, innerException) + { + } + + public ErrorDataEncoderException(string message, Exception innerException) + : base(message, innerException) + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/FileUploadUtil.cs b/src/DotNetty.Codecs.Http/Multipart/FileUploadUtil.cs new file mode 100644 index 0000000..0832ff6 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/FileUploadUtil.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + + static class FileUploadUtil + { + public static int HashCode(IFileUpload upload) => upload.Name.GetHashCode(); + + public static bool Equals(IFileUpload upload1, IFileUpload upload2) => + upload1.Name.Equals(upload2.Name, StringComparison.OrdinalIgnoreCase); + + public static int CompareTo(IFileUpload upload1, IFileUpload upload2) => + string.Compare(upload1.Name, upload2.Name, StringComparison.OrdinalIgnoreCase); + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/HttpPostBodyUtil.cs b/src/DotNetty.Codecs.Http/Multipart/HttpPostBodyUtil.cs new file mode 100644 index 0000000..d666a65 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/HttpPostBodyUtil.cs @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + + static class HttpPostBodyUtil + { + public static readonly int ChunkSize = 8096; + + public static readonly string DefaultBinaryContentType = "application/octet-stream"; + + public static readonly string DefaultTextContentType = "text/plain"; + + public sealed class TransferEncodingMechanism + { + // Default encoding + public static readonly TransferEncodingMechanism Bit7 = new TransferEncodingMechanism("7bit"); + + // Short lines but not in ASCII - no encoding + public static readonly TransferEncodingMechanism Bit8 = new TransferEncodingMechanism("8bit"); + + // Could be long text not in ASCII - no encoding + public static readonly TransferEncodingMechanism Binary = new TransferEncodingMechanism("binary"); + + readonly string value; + + TransferEncodingMechanism(string value) + { + this.value = value; + } + + public string Value => this.value; + + public override string ToString() => this.value; + } + + internal class SeekAheadOptimize + { + internal byte[] Bytes; + internal int ReaderIndex; + internal int Pos; + internal int OrigPos; + internal int Limit; + internal IByteBuffer Buffer; + + internal SeekAheadOptimize(IByteBuffer buffer) + { + if (!buffer.HasArray) + { + throw new ArgumentException("buffer hasn't backing byte array"); + } + this.Buffer = buffer; + this.Bytes = buffer.Array; + this.ReaderIndex = buffer.ReaderIndex; + this.OrigPos = this.Pos = buffer.ArrayOffset + this.ReaderIndex; + this.Limit = buffer.ArrayOffset + buffer.WriterIndex; + } + + internal void SetReadPosition(int minus) + { + this.Pos -= minus; + this.ReaderIndex = this.GetReadPosition(this.Pos); + this.Buffer.SetReaderIndex(this.ReaderIndex); + } + + internal int GetReadPosition(int index) => index - this.OrigPos + this.ReaderIndex; + } + + internal static int FindNonWhitespace(ICharSequence sb, int offset) + { + int result; + for (result = offset; result < sb.Count; result++) + { + if (!char.IsWhiteSpace(sb[result])) + { + break; + } + } + + return result; + } + + internal static int FindWhitespace(ICharSequence sb, int offset) + { + int result; + for (result = offset; result < sb.Count; result++) + { + if (char.IsWhiteSpace(sb[result])) + { + break; + } + } + + return result; + } + + internal static int FindEndOfString(ICharSequence sb) + { + int result; + for (result = sb.Count; result > 0; result--) + { + if (!char.IsWhiteSpace(sb[result - 1])) + { + break; + } + } + + return result; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/HttpPostMultipartRequestDecoder.cs b/src/DotNetty.Codecs.Http/Multipart/HttpPostMultipartRequestDecoder.cs new file mode 100644 index 0000000..046a635 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/HttpPostMultipartRequestDecoder.cs @@ -0,0 +1,1564 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common; + using DotNetty.Common.Utilities; + + public class HttpPostMultipartRequestDecoder : IInterfaceHttpPostRequestDecoder + { + // Factory used to create InterfaceHttpData + readonly IHttpDataFactory factory; + + // Request to decode + readonly IHttpRequest request; + + // Default charset to use + Encoding charset; + + // Does the last chunk already received + bool isLastChunk; + + // HttpDatas from Body + readonly List bodyListHttpData = new List(); + + // HttpDatas as Map from Body + readonly Dictionary> bodyMapHttpData = new Dictionary>(CaseIgnoringComparator.Default); + + // The current channelBuffer + IByteBuffer undecodedChunk; + + // Body HttpDatas current position + int bodyListHttpDataRank; + + // If multipart, this is the boundary for the global multipart + ICharSequence multipartDataBoundary; + + // If multipart, there could be internal multiparts (mixed) to the global + // multipart. Only one level is allowed. + ICharSequence multipartMixedBoundary; + + // Current getStatus + MultiPartStatus currentStatus = MultiPartStatus.Notstarted; + + // Used in Multipart + Dictionary currentFieldAttributes; + + // The current FileUpload that is currently in decode process + IFileUpload currentFileUpload; + + // The current Attribute that is currently in decode process + IAttribute currentAttribute; + + bool destroyed; + + int discardThreshold = HttpPostRequestDecoder.DefaultDiscardThreshold; + + public HttpPostMultipartRequestDecoder(IHttpRequest request) + : this(new DefaultHttpDataFactory(DefaultHttpDataFactory.MinSize), request, HttpConstants.DefaultEncoding) + { + } + + public HttpPostMultipartRequestDecoder(IHttpDataFactory factory, IHttpRequest request) + : this(factory, request, HttpConstants.DefaultEncoding) + { + } + + public HttpPostMultipartRequestDecoder(IHttpDataFactory factory, IHttpRequest request, Encoding charset) + { + Contract.Requires(request != null); + Contract.Requires(charset != null); + Contract.Requires(factory != null); + + this.factory = factory; + this.request = request; + this.charset = charset; + + // Fill default values + this.SetMultipart(this.request.Headers.Get(HttpHeaderNames.ContentType, null)); + if (request is IHttpContent content) + { + // Offer automatically if the given request is als type of HttpContent + // See #1089 + this.Offer(content); + } + else + { + this.undecodedChunk = Unpooled.Buffer(); + this.ParseBody(); + } + } + + void SetMultipart(ICharSequence contentType) + { + ICharSequence[] dataBoundary = HttpPostRequestDecoder.GetMultipartDataBoundary(contentType); + if (dataBoundary != null) + { + this.multipartDataBoundary = new AsciiString(dataBoundary[0]); + if (dataBoundary.Length > 1 && dataBoundary[1] != null) + { + this.charset = Encoding.GetEncoding(dataBoundary[1].ToString()); + } + } + else + { + this.multipartDataBoundary = null; + } + this.currentStatus = MultiPartStatus.HeaderDelimiter; + } + + void CheckDestroyed() + { + if (this.destroyed) + { + throw new InvalidOperationException($"{StringUtil.SimpleClassName()} was destroyed already"); + } + } + + public bool IsMultipart + { + get + { + this.CheckDestroyed(); + return true; + } + } + + public int DiscardThreshold + { + get => this.discardThreshold; + set + { + Contract.Requires(value >= 0); + this.discardThreshold = value; + } + } + + public List GetBodyHttpDatas() + { + this.CheckDestroyed(); + + if (!this.isLastChunk) + { + throw new NotEnoughDataDecoderException(nameof(HttpPostMultipartRequestDecoder)); + } + return this.bodyListHttpData; + } + + public List GetBodyHttpDatas(AsciiString name) + { + this.CheckDestroyed(); + + if (!this.isLastChunk) + { + throw new NotEnoughDataDecoderException(nameof(HttpPostMultipartRequestDecoder)); + } + return this.bodyMapHttpData[name]; + } + + public IInterfaceHttpData GetBodyHttpData(AsciiString name) + { + this.CheckDestroyed(); + + if (!this.isLastChunk) + { + throw new NotEnoughDataDecoderException(nameof(HttpPostMultipartRequestDecoder)); + } + if (this.bodyMapHttpData.TryGetValue(name, out List list)) + { + return list[0]; + } + return null; + } + + public IInterfaceHttpPostRequestDecoder Offer(IHttpContent content) + { + this.CheckDestroyed(); + + // Maybe we should better not copy here for performance reasons but this will need + // more care by the caller to release the content in a correct manner later + // So maybe something to optimize on a later stage + IByteBuffer buf = content.Content; + if (this.undecodedChunk == null) + { + this.undecodedChunk = buf.Copy(); + } + else + { + this.undecodedChunk.WriteBytes(buf); + } + if (content is ILastHttpContent) + { + this.isLastChunk = true; + } + this.ParseBody(); + if (this.undecodedChunk != null + && this.undecodedChunk.WriterIndex > this.discardThreshold) + { + this.undecodedChunk.DiscardReadBytes(); + } + return this; + } + + public bool HasNext + { + get + { + this.CheckDestroyed(); + + if (this.currentStatus == MultiPartStatus.Epilogue) + { + // OK except if end of list + if (this.bodyListHttpDataRank >= this.bodyListHttpData.Count) + { + throw new NotEnoughDataDecoderException(nameof(HttpPostMultipartRequestDecoder)); + } + } + return this.bodyListHttpData.Count > 0 && this.bodyListHttpDataRank < this.bodyListHttpData.Count; + } + } + + public IInterfaceHttpData Next() + { + this.CheckDestroyed(); + + return this.HasNext + ? this.bodyListHttpData[this.bodyListHttpDataRank++] + : null; + } + + public IInterfaceHttpData CurrentPartialHttpData + { + get + { + if (this.currentFileUpload != null) + { + return this.currentFileUpload; + } + else + { + return this.currentAttribute; + } + } + } + + void ParseBody() + { + if (this.currentStatus == MultiPartStatus.PreEpilogue + || this.currentStatus == MultiPartStatus.Epilogue) + { + if (this.isLastChunk) + { + this.currentStatus = MultiPartStatus.Epilogue; + } + return; + } + + this.ParseBodyMultipart(); + } + + protected void AddHttpData(IInterfaceHttpData data) + { + if (data == null) + { + return; + } + var name = new AsciiString(data.Name); + if (!this.bodyMapHttpData.TryGetValue(name, out List datas)) + { + datas = new List(1); + this.bodyMapHttpData.Add(name, datas); + } + datas.Add(data); + this.bodyListHttpData.Add(data); + } + + void ParseBodyMultipart() + { + if (this.undecodedChunk == null + || this.undecodedChunk.ReadableBytes == 0) + { + // nothing to decode + return; + } + + IInterfaceHttpData data = this.DecodeMultipart(this.currentStatus); + while (data != null) + { + this.AddHttpData(data); + if (this.currentStatus == MultiPartStatus.PreEpilogue + || this.currentStatus == MultiPartStatus.Epilogue) + { + break; + } + + data = this.DecodeMultipart(this.currentStatus); + } + } + + IInterfaceHttpData DecodeMultipart(MultiPartStatus state) + { + switch (state) + { + case MultiPartStatus.Notstarted: + throw new ErrorDataDecoderException("Should not be called with the current getStatus"); + case MultiPartStatus.Preamble: + // Content-type: multipart/form-data, boundary=AaB03x + throw new ErrorDataDecoderException("Should not be called with the current getStatus"); + case MultiPartStatus.HeaderDelimiter: + { + // --AaB03x or --AaB03x-- + return this.FindMultipartDelimiter(this.multipartDataBoundary, MultiPartStatus.Disposition, + MultiPartStatus.PreEpilogue); + } + case MultiPartStatus.Disposition: + { + // content-disposition: form-data; name="field1" + // content-disposition: form-data; name="pics"; filename="file1.txt" + // and other immediate values like + // Content-type: image/gif + // Content-Type: text/plain + // Content-Type: text/plain; charset=ISO-8859-1 + // Content-Transfer-Encoding: binary + // The following line implies a change of mode (mixed mode) + // Content-type: multipart/mixed, boundary=BbC04y + return this.FindMultipartDisposition(); + } + case MultiPartStatus.Field: + { + // Now get value according to Content-Type and Charset + Encoding localCharset = null; + if (this.currentFieldAttributes.TryGetValue(HttpHeaderValues.Charset, out IAttribute charsetAttribute)) + { + try + { + localCharset = Encoding.GetEncoding(charsetAttribute.Value); + } + catch (IOException e) + { + throw new ErrorDataDecoderException(e); + } + catch (ArgumentException e) + { + throw new ErrorDataDecoderException(e); + } + } + this.currentFieldAttributes.TryGetValue(HttpHeaderValues.Name, out IAttribute nameAttribute); + if (this.currentAttribute == null) + { + this.currentFieldAttributes.TryGetValue(HttpHeaderNames.ContentLength, out IAttribute lengthAttribute); + long size; + try + { + size = lengthAttribute != null ? long.Parse(lengthAttribute.Value) : 0L; + } + catch (IOException e) + { + throw new ErrorDataDecoderException(e); + } + catch (FormatException) + { + size = 0; + } + try + { + if (nameAttribute == null) + { + throw new ErrorDataDecoderException($"{HttpHeaderValues.Name} attribute cannot be null."); + } + if (size > 0) + { + this.currentAttribute = this.factory.CreateAttribute(this.request, + CleanString(nameAttribute.Value).ToString(), size); + } + else + { + this.currentAttribute = this.factory.CreateAttribute(this.request, + CleanString(nameAttribute.Value).ToString()); + } + } + catch (ArgumentException e) + { + throw new ErrorDataDecoderException(e); + } + catch (IOException e) + { + throw new ErrorDataDecoderException(e); + } + if (localCharset != null) + { + this.currentAttribute.Charset = localCharset; + } + } + // load data + if (!LoadDataMultipart(this.undecodedChunk, this.multipartDataBoundary, this.currentAttribute)) + { + // Delimiter is not found. Need more chunks. + return null; + } + IAttribute finalAttribute = this.currentAttribute; + this.currentAttribute = null; + this.currentFieldAttributes = null; + // ready to load the next one + this.currentStatus = MultiPartStatus.HeaderDelimiter; + return finalAttribute; + } + case MultiPartStatus.Fileupload: + { + // eventually restart from existing FileUpload + return this.GetFileUpload(this.multipartDataBoundary); + } + case MultiPartStatus.MixedDelimiter: + { + // --AaB03x or --AaB03x-- + // Note that currentFieldAttributes exists + return this.FindMultipartDelimiter(this.multipartMixedBoundary, MultiPartStatus.MixedDisposition, + MultiPartStatus.HeaderDelimiter); + } + case MultiPartStatus.MixedDisposition: + { + return this.FindMultipartDisposition(); + } + case MultiPartStatus.MixedFileUpload: + { + // eventually restart from existing FileUpload + return this.GetFileUpload(this.multipartMixedBoundary); + } + case MultiPartStatus.PreEpilogue: + case MultiPartStatus.Epilogue: + return null; + default: + throw new ErrorDataDecoderException("Shouldn't reach here."); + } + } + + static void SkipControlCharacters(IByteBuffer undecodedChunk) + { + if (!undecodedChunk.HasArray) + { + try + { + SkipControlCharactersStandard(undecodedChunk); + } + catch (IndexOutOfRangeException e) + { + throw new NotEnoughDataDecoderException(e); + } + return; + } + var sao = new HttpPostBodyUtil.SeekAheadOptimize(undecodedChunk); + while (sao.Pos < sao.Limit) + { + char c = (char)sao.Bytes[sao.Pos++]; + if (!CharUtil.IsISOControl(c) && !char.IsWhiteSpace(c)) + { + sao.SetReadPosition(1); + return; + } + } + throw new NotEnoughDataDecoderException("Access out of bounds"); + } + + static void SkipControlCharactersStandard(IByteBuffer undecodedChunk) + { + for (; ;) + { + char c = (char)undecodedChunk.ReadByte(); + if (!CharUtil.IsISOControl(c) && !char.IsWhiteSpace(c)) + { + undecodedChunk.SetReaderIndex(undecodedChunk.ReaderIndex - 1); + break; + } + } + } + + IInterfaceHttpData FindMultipartDelimiter(ICharSequence delimiter, MultiPartStatus dispositionStatus, + MultiPartStatus closeDelimiterStatus) + { + // --AaB03x or --AaB03x-- + int readerIndex = this.undecodedChunk.ReaderIndex; + try + { + SkipControlCharacters(this.undecodedChunk); + } + catch (NotEnoughDataDecoderException) + { + this.undecodedChunk.SetReaderIndex(readerIndex); + return null; + } + this.SkipOneLine(); + StringBuilderCharSequence newline; + try + { + newline = ReadDelimiter(this.undecodedChunk, delimiter); + } + catch (NotEnoughDataDecoderException) + { + this.undecodedChunk.SetReaderIndex(readerIndex); + return null; + } + if (newline.Equals(delimiter)) + { + this.currentStatus = dispositionStatus; + return this.DecodeMultipart(dispositionStatus); + } + if (AsciiString.ContentEquals(newline, new StringCharSequence(delimiter.ToString() + "--"))) + { + // CloseDelimiter or MIXED CloseDelimiter found + this.currentStatus = closeDelimiterStatus; + if (this.currentStatus == MultiPartStatus.HeaderDelimiter) + { + // MixedCloseDelimiter + // end of the Mixed part + this.currentFieldAttributes = null; + return this.DecodeMultipart(MultiPartStatus.HeaderDelimiter); + } + return null; + } + this.undecodedChunk.SetReaderIndex(readerIndex); + throw new ErrorDataDecoderException("No Multipart delimiter found"); + } + + IInterfaceHttpData FindMultipartDisposition() + { + int readerIndex = this.undecodedChunk.ReaderIndex; + if (this.currentStatus == MultiPartStatus.Disposition) + { + this.currentFieldAttributes = new Dictionary(CaseIgnoringComparator.Default); + } + // read many lines until empty line with newline found! Store all data + while (!this.SkipOneLine()) + { + StringCharSequence newline; + try + { + SkipControlCharacters(this.undecodedChunk); + newline = ReadLine(this.undecodedChunk, this.charset); + } + catch (NotEnoughDataDecoderException) + { + this.undecodedChunk.SetReaderIndex(readerIndex); + return null; + } + ICharSequence[] contents = SplitMultipartHeader(newline); + if (HttpHeaderNames.ContentDisposition.ContentEqualsIgnoreCase(contents[0])) + { + bool checkSecondArg; + if (this.currentStatus == MultiPartStatus.Disposition) + { + checkSecondArg = HttpHeaderValues.FormData.ContentEqualsIgnoreCase(contents[1]); + } + else + { + checkSecondArg = HttpHeaderValues.Attachment.ContentEqualsIgnoreCase(contents[1]) + || HttpHeaderValues.File.ContentEqualsIgnoreCase(contents[1]); + } + if (checkSecondArg) + { + // read next values and store them in the map as Attribute + for (int i = 2; i < contents.Length; i++) + { + ICharSequence[] values = CharUtil.Split(contents[i], '='); + IAttribute attribute; + try + { + attribute = this.GetContentDispositionAttribute(values); + } + catch (ArgumentNullException e) + { + throw new ErrorDataDecoderException(e); + } + catch (ArgumentException e) + { + throw new ErrorDataDecoderException(e); + } + this.currentFieldAttributes.Add(new AsciiString(attribute.Name), attribute); + } + } + } + else if (HttpHeaderNames.ContentTransferEncoding.ContentEqualsIgnoreCase(contents[0])) + { + IAttribute attribute; + try + { + attribute = this.factory.CreateAttribute(this.request, HttpHeaderNames.ContentTransferEncoding.ToString(), + CleanString(contents[1]).ToString()); + } + catch (ArgumentNullException e) + { + throw new ErrorDataDecoderException(e); + } + catch (ArgumentException e) + { + throw new ErrorDataDecoderException(e); + } + + this.currentFieldAttributes.Add(HttpHeaderNames.ContentTransferEncoding, attribute); + } + else if (HttpHeaderNames.ContentLength.ContentEqualsIgnoreCase(contents[0])) + { + IAttribute attribute; + try + { + attribute = this.factory.CreateAttribute(this.request, HttpHeaderNames.ContentLength.ToString(), + CleanString(contents[1]).ToString()); + } + catch (ArgumentNullException e) + { + throw new ErrorDataDecoderException(e); + } + catch (ArgumentException e) + { + throw new ErrorDataDecoderException(e); + } + + this.currentFieldAttributes.Add(HttpHeaderNames.ContentLength, attribute); + } + else if (HttpHeaderNames.ContentType.ContentEqualsIgnoreCase(contents[0])) + { + // Take care of possible "multipart/mixed" + if (HttpHeaderValues.MultipartMixed.ContentEqualsIgnoreCase(contents[1])) + { + if (this.currentStatus == MultiPartStatus.Disposition) + { + ICharSequence values = contents[2].SubstringAfter('='); + this.multipartMixedBoundary = new StringCharSequence("--" + values.ToString()); + this.currentStatus = MultiPartStatus.MixedDelimiter; + return this.DecodeMultipart(MultiPartStatus.MixedDelimiter); + } + else + { + throw new ErrorDataDecoderException("Mixed Multipart found in a previous Mixed Multipart"); + } + } + else + { + for (int i = 1; i < contents.Length; i++) + { + ICharSequence charsetHeader = HttpHeaderValues.Charset; + if (contents[i].RegionMatchesIgnoreCase(0, charsetHeader, 0, charsetHeader.Count)) + { + ICharSequence values = contents[i].SubstringAfter('='); + IAttribute attribute; + try + { + attribute = this.factory.CreateAttribute(this.request, charsetHeader.ToString(), CleanString(values).ToString()); + } + catch (ArgumentNullException e) + { + throw new ErrorDataDecoderException(e); + } + catch (ArgumentException e) + { + throw new ErrorDataDecoderException(e); + } + this.currentFieldAttributes.Add(HttpHeaderValues.Charset, attribute); + } + else + { + IAttribute attribute; + ICharSequence name; + try + { + name = CleanString(contents[0]); + attribute = this.factory.CreateAttribute(this.request, + name.ToString(), contents[i].ToString()); + } + catch (ArgumentNullException e) + { + throw new ErrorDataDecoderException(e); + } + catch (ArgumentException e) + { + throw new ErrorDataDecoderException(e); + } + this.currentFieldAttributes.Add(new AsciiString(name), attribute); + } + } + } + } + else + { + throw new ErrorDataDecoderException($"Unknown Params: {newline}"); + } + } + // Is it a FileUpload + this.currentFieldAttributes.TryGetValue(HttpHeaderValues.FileName, out IAttribute filenameAttribute); + if (this.currentStatus == MultiPartStatus.Disposition) + { + if (filenameAttribute != null) + { + // FileUpload + this.currentStatus = MultiPartStatus.Fileupload; + // do not change the buffer position + return this.DecodeMultipart(MultiPartStatus.Fileupload); + } + else + { + // Field + this.currentStatus = MultiPartStatus.Field; + // do not change the buffer position + return this.DecodeMultipart(MultiPartStatus.Field); + } + } + else + { + if (filenameAttribute != null) + { + // FileUpload + this.currentStatus = MultiPartStatus.MixedFileUpload; + // do not change the buffer position + return this.DecodeMultipart(MultiPartStatus.MixedFileUpload); + } + else + { + // Field is not supported in MIXED mode + throw new ErrorDataDecoderException("Filename not found"); + } + } + } + + static readonly AsciiString FilenameEncoded = AsciiString.Cached(HttpHeaderValues.FileName.ToString() + '*'); + + IAttribute GetContentDispositionAttribute(params ICharSequence[] values) + { + ICharSequence name = CleanString(values[0]); + ICharSequence value = values[1]; + + // Filename can be token, quoted or encoded. See https://tools.ietf.org/html/rfc5987 + if (HttpHeaderValues.FileName.ContentEquals(name)) + { + // Value is quoted or token. Strip if quoted: + int last = value.Count - 1; + if (last > 0 + && value[0] == HttpConstants.DoubleQuote + && value[last] == HttpConstants.DoubleQuote) + { + value = value.SubSequence(1, last); + } + } + else if (FilenameEncoded.ContentEquals(name)) + { + try + { + name = HttpHeaderValues.FileName; + string[] split = value.ToString().Split(new [] { '\'' }, 3); + value = new StringCharSequence( + QueryStringDecoder.DecodeComponent(split[2], Encoding.GetEncoding(split[0]))); + } + catch (IndexOutOfRangeException e) + { + throw new ErrorDataDecoderException(e); + } + catch (ArgumentException e) // Invalid encoding + { + throw new ErrorDataDecoderException(e); + } + } + else + { + // otherwise we need to clean the value + value = CleanString(value); + } + return this.factory.CreateAttribute(this.request, name.ToString(), value.ToString()); + } + + protected IInterfaceHttpData GetFileUpload(ICharSequence delimiter) + { + // eventually restart from existing FileUpload + // Now get value according to Content-Type and Charset + this.currentFieldAttributes.TryGetValue(HttpHeaderNames.ContentTransferEncoding, out IAttribute encodingAttribute); + Encoding localCharset = this.charset; + // Default + HttpPostBodyUtil.TransferEncodingMechanism mechanism = HttpPostBodyUtil.TransferEncodingMechanism.Bit7; + if (encodingAttribute != null) + { + string code; + try + { + code = encodingAttribute.Value.ToLower(); + } + catch (IOException e) + { + throw new ErrorDataDecoderException(e); + } + if (code.Equals(HttpPostBodyUtil.TransferEncodingMechanism.Bit7.Value)) + { + localCharset = Encoding.ASCII; + } + else if (code.Equals(HttpPostBodyUtil.TransferEncodingMechanism.Bit8.Value)) + { + localCharset = Encoding.UTF8; + mechanism = HttpPostBodyUtil.TransferEncodingMechanism.Bit8; + } + else if (code.Equals(HttpPostBodyUtil.TransferEncodingMechanism.Binary.Value)) + { + // no real charset, so let the default + mechanism = HttpPostBodyUtil.TransferEncodingMechanism.Binary; + } + else + { + throw new ErrorDataDecoderException("TransferEncoding Unknown: " + code); + } + } + if (this.currentFieldAttributes.TryGetValue(HttpHeaderValues.Charset, out IAttribute charsetAttribute)) + { + try + { + localCharset = Encoding.GetEncoding(charsetAttribute.Value); + } + catch (IOException e) + { + throw new ErrorDataDecoderException(e); + } + catch (ArgumentException e) + { + throw new ErrorDataDecoderException(e); + } + } + if (this.currentFileUpload == null) + { + this.currentFieldAttributes.TryGetValue(HttpHeaderValues.FileName, out IAttribute filenameAttribute); + this.currentFieldAttributes.TryGetValue(HttpHeaderValues.Name, out IAttribute nameAttribute); + this.currentFieldAttributes.TryGetValue(HttpHeaderNames.ContentType, out IAttribute contentTypeAttribute); + this.currentFieldAttributes.TryGetValue(HttpHeaderNames.ContentLength, out IAttribute lengthAttribute); + long size; + try + { + size = lengthAttribute != null ? long.Parse(lengthAttribute.Value) : 0L; + } + catch (IOException e) + { + throw new ErrorDataDecoderException(e); + } + catch (FormatException) + { + size = 0; + } + try + { + string contentType; + if (contentTypeAttribute != null) + { + contentType = contentTypeAttribute.Value; + } + else + { + contentType = HttpPostBodyUtil.DefaultBinaryContentType; + } + if (nameAttribute == null) + { + throw new ErrorDataDecoderException($"{HttpHeaderValues.Name} attribute cannot be null for file upload"); + } + if (filenameAttribute == null) + { + throw new ErrorDataDecoderException($"{HttpHeaderValues.FileName} attribute cannot be null for file upload"); + } + this.currentFileUpload = this.factory.CreateFileUpload(this.request, + CleanString(nameAttribute.Value).ToString(), CleanString(filenameAttribute.Value).ToString(), + contentType, mechanism.Value, localCharset, + size); + } + catch (ArgumentNullException e) + { + throw new ErrorDataDecoderException(e); + } + catch (ArgumentException e) + { + throw new ErrorDataDecoderException(e); + } + catch (IOException e) + { + throw new ErrorDataDecoderException(e); + } + } + // load data as much as possible + if (!LoadDataMultipart(this.undecodedChunk, delimiter, this.currentFileUpload)) + { + // Delimiter is not found. Need more chunks. + return null; + } + if (this.currentFileUpload.IsCompleted) + { + // ready to load the next one + if (this.currentStatus == MultiPartStatus.Fileupload) + { + this.currentStatus = MultiPartStatus.HeaderDelimiter; + this.currentFieldAttributes = null; + } + else + { + this.currentStatus = MultiPartStatus.MixedDelimiter; + this.CleanMixedAttributes(); + } + IFileUpload fileUpload = this.currentFileUpload; + this.currentFileUpload = null; + return fileUpload; + } + + // do not change the buffer position + // since some can be already saved into FileUpload + // So do not change the currentStatus + return null; + } + + public void Destroy() + { + this.CheckDestroyed(); + this.CleanFiles(); + this.destroyed = true; + + if (this.undecodedChunk != null && this.undecodedChunk.ReferenceCount > 0) + { + this.undecodedChunk.Release(); + this.undecodedChunk = null; + } + + // release all data which was not yet pulled + for (int i = this.bodyListHttpDataRank; i < this.bodyListHttpData.Count; i++) + { + this.bodyListHttpData[i].Release(); + } + } + + public void CleanFiles() + { + this.CheckDestroyed(); + this.factory.CleanRequestHttpData(this.request); + } + + public void RemoveHttpDataFromClean(IInterfaceHttpData data) + { + this.CheckDestroyed(); + + this.factory.RemoveHttpDataFromClean(this.request, data); + } + + + // Remove all Attributes that should be cleaned between two FileUpload in + // Mixed mode + void CleanMixedAttributes() + { + this.currentFieldAttributes.Remove(HttpHeaderValues.Charset); + this.currentFieldAttributes.Remove(HttpHeaderNames.ContentLength); + this.currentFieldAttributes.Remove(HttpHeaderNames.ContentTransferEncoding); + this.currentFieldAttributes.Remove(HttpHeaderNames.ContentType); + this.currentFieldAttributes.Remove(HttpHeaderValues.FileName); + } + + static StringCharSequence ReadLineStandard(IByteBuffer undecodedChunk, Encoding charset) + { + int readerIndex = undecodedChunk.ReaderIndex; + try + { + IByteBuffer line = Unpooled.Buffer(64); + + while (undecodedChunk.IsReadable()) + { + byte nextByte = undecodedChunk.ReadByte(); + if (nextByte == HttpConstants.CarriageReturn) + { + // check but do not changed readerIndex + nextByte = undecodedChunk.GetByte(undecodedChunk.ReaderIndex); + if (nextByte == HttpConstants.LineFeed) + { + // force read + undecodedChunk.ReadByte(); + return new StringCharSequence(line.ToString(charset)); + } + else + { + // Write CR (not followed by LF) + line.WriteByte(HttpConstants.CarriageReturn); + } + } + else if (nextByte == HttpConstants.LineFeed) + { + return new StringCharSequence(line.ToString(charset)); + } + else + { + line.WriteByte(nextByte); + } + } + } + catch (IndexOutOfRangeException e) + { + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(e); + } + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadDelimiterStandard)); + } + + static StringCharSequence ReadLine(IByteBuffer undecodedChunk, Encoding charset) + { + if (!undecodedChunk.HasArray) + { + return ReadLineStandard(undecodedChunk, charset); + } + var sao = new HttpPostBodyUtil.SeekAheadOptimize(undecodedChunk); + int readerIndex = undecodedChunk.ReaderIndex; + try + { + IByteBuffer line = Unpooled.Buffer(64); + + while (sao.Pos < sao.Limit) + { + byte nextByte = sao.Bytes[sao.Pos++]; + if (nextByte == HttpConstants.CarriageReturn) + { + if (sao.Pos < sao.Limit) + { + nextByte = sao.Bytes[sao.Pos++]; + if (nextByte == HttpConstants.LineFeed) + { + sao.SetReadPosition(0); + return new StringCharSequence(line.ToString(charset)); + } + else + { + // Write CR (not followed by LF) + sao.Pos--; + line.WriteByte(HttpConstants.CarriageReturn); + } + } + else + { + line.WriteByte(nextByte); + } + } + else if (nextByte == HttpConstants.LineFeed) + { + sao.SetReadPosition(0); + return new StringCharSequence(line.ToString(charset)); + } + else + { + line.WriteByte(nextByte); + } + } + } + catch (IndexOutOfRangeException e) + { + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(e); + } + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadLine)); + } + + static StringBuilderCharSequence ReadDelimiterStandard(IByteBuffer undecodedChunk, ICharSequence delimiter) + { + int readerIndex = undecodedChunk.ReaderIndex; + try + { + var sb = new StringBuilderCharSequence(64); + int delimiterPos = 0; + int len = delimiter.Count; + while (undecodedChunk.IsReadable() && delimiterPos < len) + { + byte nextByte = undecodedChunk.ReadByte(); + if (nextByte == delimiter[delimiterPos]) + { + delimiterPos++; + sb.Append((char)nextByte); + } + else + { + // delimiter not found so break here ! + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadDelimiterStandard)); + } + } + // Now check if either opening delimiter or closing delimiter + if (undecodedChunk.IsReadable()) + { + byte nextByte = undecodedChunk.ReadByte(); + // first check for opening delimiter + if (nextByte == HttpConstants.CarriageReturn) + { + nextByte = undecodedChunk.ReadByte(); + if (nextByte == HttpConstants.LineFeed) + { + return sb; + } + else + { + // error since CR must be followed by LF + // delimiter not found so break here ! + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadDelimiterStandard)); + } + } + else if (nextByte == HttpConstants.LineFeed) + { + return sb; + } + else if (nextByte == '-') + { + sb.Append('-'); + // second check for closing delimiter + nextByte = undecodedChunk.ReadByte(); + if (nextByte == '-') + { + sb.Append('-'); + // now try to find if CRLF or LF there + if (undecodedChunk.IsReadable()) + { + nextByte = undecodedChunk.ReadByte(); + if (nextByte == HttpConstants.CarriageReturn) + { + nextByte = undecodedChunk.ReadByte(); + if (nextByte == HttpConstants.LineFeed) + { + return sb; + } + else + { + // error CR without LF + // delimiter not found so break here ! + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadDelimiterStandard)); + } + } + else if (nextByte == HttpConstants.LineFeed) + { + return sb; + } + else + { + // No CRLF but ok however (Adobe Flash uploader) + // minus 1 since we read one char ahead but + // should not + undecodedChunk.SetReaderIndex(undecodedChunk.ReaderIndex - 1); + return sb; + } + } + // FIXME what do we do here? + // either considering it is fine, either waiting for + // more data to come? + // lets try considering it is fine... + return sb; + } + // only one '-' => not enough + // whatever now => error since incomplete + } + } + } + catch (IndexOutOfRangeException e) + { + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(e); + } + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadDelimiterStandard)); + } + + static StringBuilderCharSequence ReadDelimiter(IByteBuffer undecodedChunk, ICharSequence delimiter) + { + if (!undecodedChunk.HasArray) + { + return ReadDelimiterStandard(undecodedChunk, delimiter); + } + var sao = new HttpPostBodyUtil.SeekAheadOptimize(undecodedChunk); + int readerIndex = undecodedChunk.ReaderIndex; + int delimiterPos = 0; + int len = delimiter.Count; + try + { + var sb = new StringBuilderCharSequence(64); + // check conformity with delimiter + while (sao.Pos < sao.Limit && delimiterPos < len) + { + byte nextByte = sao.Bytes[sao.Pos++]; + if (nextByte == delimiter[delimiterPos]) + { + delimiterPos++; + sb.Append((char)nextByte); + } + else + { + // delimiter not found so break here ! + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadDelimiter)); + } + } + // Now check if either opening delimiter or closing delimiter + if (sao.Pos < sao.Limit) + { + byte nextByte = sao.Bytes[sao.Pos++]; + if (nextByte == HttpConstants.CarriageReturn) + { + // first check for opening delimiter + if (sao.Pos < sao.Limit) + { + nextByte = sao.Bytes[sao.Pos++]; + if (nextByte == HttpConstants.LineFeed) + { + sao.SetReadPosition(0); + return sb; + } + else + { + // error CR without LF + // delimiter not found so break here ! + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadDelimiter)); + } + } + else + { + // error since CR must be followed by LF + // delimiter not found so break here ! + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadDelimiter)); + } + } + else if (nextByte == HttpConstants.LineFeed) + { + // same first check for opening delimiter where LF used with + // no CR + sao.SetReadPosition(0); + return sb; + } + else if (nextByte == '-') + { + sb.Append('-'); + // second check for closing delimiter + if (sao.Pos < sao.Limit) + { + nextByte = sao.Bytes[sao.Pos++]; + if (nextByte == '-') + { + sb.Append('-'); + // now try to find if CRLF or LF there + if (sao.Pos < sao.Limit) + { + nextByte = sao.Bytes[sao.Pos++]; + if (nextByte == HttpConstants.CarriageReturn) + { + if (sao.Pos < sao.Limit) + { + nextByte = sao.Bytes[sao.Pos++]; + if (nextByte == HttpConstants.LineFeed) + { + sao.SetReadPosition(0); + return sb; + } + else + { + // error CR without LF + // delimiter not found so break here ! + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadDelimiter)); + } + } + else + { + // error CR without LF + // delimiter not found so break here ! + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadDelimiter)); + } + } + else if (nextByte == HttpConstants.LineFeed) + { + sao.SetReadPosition(0); + return sb; + } + else + { + // No CRLF but ok however (Adobe Flash + // uploader) + // minus 1 since we read one char ahead but + // should not + sao.SetReadPosition(1); + return sb; + } + } + // FIXME what do we do here? + // either considering it is fine, either waiting for + // more data to come? + // lets try considering it is fine... + sao.SetReadPosition(0); + return sb; + } + // whatever now => error since incomplete + // only one '-' => not enough or whatever not enough + // element + } + } + } + } + catch (IndexOutOfRangeException e) + { + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(e); + } + undecodedChunk.SetReaderIndex(readerIndex); + throw new NotEnoughDataDecoderException(nameof(ReadDelimiter)); + } + + static bool LoadDataMultipartStandard(IByteBuffer undecodedChunk, ICharSequence delimiter, IHttpData httpData) + { + int startReaderIndex = undecodedChunk.ReaderIndex; + int delimeterLength = delimiter.Count; + int index = 0; + int lastPosition = startReaderIndex; + byte prevByte = HttpConstants.LineFeed; + bool delimiterFound = false; + while (undecodedChunk.IsReadable()) + { + byte nextByte = undecodedChunk.ReadByte(); + // Check the delimiter + if (prevByte == HttpConstants.LineFeed && nextByte == CharUtil.CodePointAt(delimiter, index)) + { + index++; + if (delimeterLength == index) + { + delimiterFound = true; + break; + } + continue; + } + lastPosition = undecodedChunk.ReaderIndex; + if (nextByte == HttpConstants.LineFeed) + { + index = 0; + lastPosition -= (prevByte == HttpConstants.CarriageReturn) ? 2 : 1; + } + prevByte = nextByte; + } + if (prevByte == HttpConstants.CarriageReturn) + { + lastPosition--; + } + IByteBuffer content = undecodedChunk.Copy(startReaderIndex, lastPosition - startReaderIndex); + try + { + httpData.AddContent(content, delimiterFound); + } + catch (IOException e) + { + throw new ErrorDataDecoderException(e); + } + undecodedChunk.SetReaderIndex(lastPosition); + return delimiterFound; + } + + static bool LoadDataMultipart(IByteBuffer undecodedChunk, ICharSequence delimiter, IHttpData httpData) + { + if (!undecodedChunk.HasArray) + { + return LoadDataMultipartStandard(undecodedChunk, delimiter, httpData); + } + var sao = new HttpPostBodyUtil.SeekAheadOptimize(undecodedChunk); + int startReaderIndex = undecodedChunk.ReaderIndex; + int delimeterLength = delimiter.Count; + int index = 0; + int lastRealPos = sao.Pos; + byte prevByte = HttpConstants.LineFeed; + bool delimiterFound = false; + while (sao.Pos < sao.Limit) + { + byte nextByte = sao.Bytes[sao.Pos++]; + // Check the delimiter + if (prevByte == HttpConstants.LineFeed && nextByte == CharUtil.CodePointAt(delimiter, index)) + { + index++; + if (delimeterLength == index) + { + delimiterFound = true; + break; + } + continue; + } + lastRealPos = sao.Pos; + if (nextByte == HttpConstants.LineFeed) + { + index = 0; + lastRealPos -= (prevByte == HttpConstants.CarriageReturn) ? 2 : 1; + } + prevByte = nextByte; + } + if (prevByte == HttpConstants.CarriageReturn) + { + lastRealPos--; + } + int lastPosition = sao.GetReadPosition(lastRealPos); + IByteBuffer content = undecodedChunk.Copy(startReaderIndex, lastPosition - startReaderIndex); + try + { + httpData.AddContent(content, delimiterFound); + } + catch (IOException e) + { + throw new ErrorDataDecoderException(e); + } + undecodedChunk.SetReaderIndex(lastPosition); + return delimiterFound; + } + + static ICharSequence CleanString(string field) => CleanString(new StringCharSequence(field)); + + static ICharSequence CleanString(ICharSequence field) + { + int size = field.Count; + var sb = new StringBuilderCharSequence(size); + for (int i = 0; i < size; i++) + { + char nextChar = field[i]; + switch (nextChar) + { + case ':': // Colon + case ',': // Comma + case '=': // EqualsSign + case ';': // Semicolon + case '\t': // HorizontalTab + sb.Append(HttpConstants.HorizontalSpaceChar); + break; + case '"': // DoubleQuote + // nothing added, just removes it + break; + default: + sb.Append(nextChar); + break; + } + } + return CharUtil.Trim(sb); + } + + bool SkipOneLine() + { + if (!this.undecodedChunk.IsReadable()) + { + return false; + } + byte nextByte = this.undecodedChunk.ReadByte(); + if (nextByte == HttpConstants.CarriageReturn) + { + if (!this.undecodedChunk.IsReadable()) + { + this.undecodedChunk.SetReaderIndex(this.undecodedChunk.ReaderIndex - 1); + return false; + } + + nextByte = this.undecodedChunk.ReadByte(); + if (nextByte == HttpConstants.LineFeed) + { + return true; + } + + this.undecodedChunk.SetReaderIndex(this.undecodedChunk.ReaderIndex - 2); + return false; + } + + if (nextByte == HttpConstants.LineFeed) + { + return true; + } + this.undecodedChunk.SetReaderIndex(this.undecodedChunk.ReaderIndex - 1); + return false; + } + + + static ICharSequence[] SplitMultipartHeader(ICharSequence sb) + { + var headers = new List(1); + int nameEnd; + int colonEnd; + int nameStart = HttpPostBodyUtil.FindNonWhitespace(sb, 0); + for (nameEnd = nameStart; nameEnd < sb.Count; nameEnd++) + { + char ch = sb[nameEnd]; + if (ch == ':' || char.IsWhiteSpace(ch)) + { + break; + } + } + for (colonEnd = nameEnd; colonEnd < sb.Count; colonEnd++) + { + if (sb[colonEnd] == ':') + { + colonEnd++; + break; + } + } + int valueStart = HttpPostBodyUtil.FindNonWhitespace(sb, colonEnd); + int valueEnd = HttpPostBodyUtil.FindEndOfString(sb); + headers.Add(sb.SubSequence(nameStart, nameEnd)); + ICharSequence svalue = (valueStart >= valueEnd) ? AsciiString.Empty : sb.SubSequence(valueStart, valueEnd); + ICharSequence[] values; + if (svalue.IndexOf(';') >= 0) + { + values = SplitMultipartHeaderValues(svalue); + } + else + { + values = CharUtil.Split(svalue, ','); + } + foreach(ICharSequence value in values) + { + headers.Add(CharUtil.Trim(value)); + } + var array = new ICharSequence[headers.Count]; + for (int i = 0; i < headers.Count; i++) + { + array[i] = headers[i]; + } + return array; + } + + static ICharSequence[] SplitMultipartHeaderValues(ICharSequence svalue) + { + List values = InternalThreadLocalMap.Get().CharSequenceList(1); + bool inQuote = false; + bool escapeNext = false; + int start = 0; + for (int i = 0; i < svalue.Count; i++) + { + char c = svalue[i]; + if (inQuote) + { + if (escapeNext) + { + escapeNext = false; + } + else + { + if (c == '\\') + { + escapeNext = true; + } + else if (c == '"') + { + inQuote = false; + } + } + } + else + { + if (c == '"') + { + inQuote = true; + } + else if (c == ';') + { + values.Add(svalue.SubSequence(start, i)); + start = i + 1; + } + } + } + values.Add(svalue.SubSequence(start)); + return values.ToArray(); + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/HttpPostRequestDecoder.cs b/src/DotNetty.Codecs.Http/Multipart/HttpPostRequestDecoder.cs new file mode 100644 index 0000000..e7005c3 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/HttpPostRequestDecoder.cs @@ -0,0 +1,178 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Text; + using DotNetty.Common.Utilities; + + public class HttpPostRequestDecoder : IInterfaceHttpPostRequestDecoder + { + internal static readonly int DefaultDiscardThreshold = 10 * 1024 * 1024; + + readonly IInterfaceHttpPostRequestDecoder decoder; + + public HttpPostRequestDecoder(IHttpRequest request) + : this(new DefaultHttpDataFactory(DefaultHttpDataFactory.MinSize), request, HttpConstants.DefaultEncoding) + { + } + + public HttpPostRequestDecoder(IHttpDataFactory factory, IHttpRequest request) + : this(factory, request, HttpConstants.DefaultEncoding) + { + } + + public HttpPostRequestDecoder(IHttpDataFactory factory, IHttpRequest request, Encoding encoding) + { + Contract.Requires(factory != null); + Contract.Requires(request != null); + Contract.Requires(encoding != null); + + // Fill default values + if (IsMultipartRequest(request)) + { + this.decoder = new HttpPostMultipartRequestDecoder(factory, request, encoding); + } + else + { + this.decoder = new HttpPostStandardRequestDecoder(factory, request, encoding); + } + } + + public static bool IsMultipartRequest(IHttpRequest request) + { + if (request.Headers.TryGet(HttpHeaderNames.ContentType, out ICharSequence contentType)) + { + return GetMultipartDataBoundary(contentType) != null; + } + else + { + return false; + } + } + + // + // Check from the request ContentType if this request is a Multipart request. + // return an array of String if multipartDataBoundary exists with the multipartDataBoundary + // as first element, charset if any as second (missing if not set), else null + // + protected internal static ICharSequence[] GetMultipartDataBoundary(ICharSequence contentType) + { + // Check if Post using "multipart/form-data; boundary=--89421926422648 [; charset=xxx]" + ICharSequence[] headerContentType = SplitHeaderContentType(contentType); + AsciiString multiPartHeader = HttpHeaderValues.MultipartFormData; + if (headerContentType[0].RegionMatchesIgnoreCase(0, multiPartHeader, 0, multiPartHeader.Count)) + { + int mrank; + int crank; + AsciiString boundaryHeader = HttpHeaderValues.Boundary; + if (headerContentType[1].RegionMatchesIgnoreCase(0, boundaryHeader, 0, boundaryHeader.Count)) + { + mrank = 1; + crank = 2; + } + else if (headerContentType[2].RegionMatchesIgnoreCase(0, boundaryHeader, 0, boundaryHeader.Count)) + { + mrank = 2; + crank = 1; + } + else + { + return null; + } + ICharSequence boundary = headerContentType[mrank].SubstringAfter('='); + if (boundary == null) + { + throw new ErrorDataDecoderException("Needs a boundary value"); + } + if (boundary[0] == '"') + { + ICharSequence bound = CharUtil.Trim(boundary); + int index = bound.Count - 1; + if (bound[index] == '"') + { + boundary = bound.SubSequence(1, index); + } + } + AsciiString charsetHeader = HttpHeaderValues.Charset; + if (headerContentType[crank].RegionMatchesIgnoreCase(0, charsetHeader, 0, charsetHeader.Count)) + { + ICharSequence charset = headerContentType[crank].SubstringAfter('='); + if (charset != null) + { + return new [] + { + new StringCharSequence("--" + boundary.ToString()), + charset + }; + } + } + + return new ICharSequence[] + { + new StringCharSequence("--" + boundary.ToString()) + }; + } + + return null; + } + + public bool IsMultipart => this.decoder.IsMultipart; + + public int DiscardThreshold + { + get => this.decoder.DiscardThreshold; + set => this.decoder.DiscardThreshold = value; + } + + public List GetBodyHttpDatas() => this.decoder.GetBodyHttpDatas(); + + public List GetBodyHttpDatas(AsciiString name) => this.decoder.GetBodyHttpDatas(name); + + public IInterfaceHttpData GetBodyHttpData(AsciiString name) => this.decoder.GetBodyHttpData(name); + + public IInterfaceHttpPostRequestDecoder Offer(IHttpContent content) => this.decoder.Offer(content); + + public bool HasNext => this.decoder.HasNext; + + public IInterfaceHttpData Next() => this.decoder.Next(); + + public IInterfaceHttpData CurrentPartialHttpData => this.decoder.CurrentPartialHttpData; + + public void Destroy() => this.decoder.Destroy(); + + public void CleanFiles() => this.decoder.CleanFiles(); + + public void RemoveHttpDataFromClean(IInterfaceHttpData data) => this.decoder.RemoveHttpDataFromClean(data); + + static ICharSequence[] SplitHeaderContentType(ICharSequence sb) + { + int aStart = HttpPostBodyUtil.FindNonWhitespace(sb, 0); + int aEnd = sb.IndexOf(';'); + if (aEnd == -1) + { + return new [] { sb, StringCharSequence.Empty, StringCharSequence.Empty }; + } + int bStart = HttpPostBodyUtil.FindNonWhitespace(sb, aEnd + 1); + if (sb[aEnd - 1] == ' ') + { + aEnd--; + } + int bEnd = sb.IndexOf(';', bStart); + if (bEnd == -1) + { + bEnd = HttpPostBodyUtil.FindEndOfString(sb); + return new [] { sb.SubSequence(aStart, aEnd), sb.SubSequence(bStart, bEnd), StringCharSequence.Empty }; + } + int cStart = HttpPostBodyUtil.FindNonWhitespace(sb, bEnd + 1); + if (sb[bEnd - 1] == ' ') + { + bEnd--; + } + int cEnd = HttpPostBodyUtil.FindEndOfString(sb); + return new [] { sb.SubSequence(aStart, aEnd), sb.SubSequence(bStart, bEnd), sb.SubSequence(cStart, cEnd) }; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/HttpPostRequestEncoder.cs b/src/DotNetty.Codecs.Http/Multipart/HttpPostRequestEncoder.cs new file mode 100644 index 0000000..ad562f2 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/HttpPostRequestEncoder.cs @@ -0,0 +1,1095 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +// ReSharper disable ConvertToAutoPropertyWithPrivateSetter +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Diagnostics.Contracts; + using System.IO; + using System.Text; + using System.Text.RegularExpressions; + using DotNetty.Buffers; + using DotNetty.Common; + using DotNetty.Common.Internal; + using DotNetty.Common.Utilities; + using DotNetty.Handlers.Streams; + + public class HttpPostRequestEncoder : IChunkedInput + { + public enum EncoderMode + { + // Legacy mode which should work for most. It is known to not work with OAUTH. For OAUTH use + // {@link EncoderMode#RFC3986}. The W3C form recommendations this for submitting post form data. + RFC1738, + + // Mode which is more new and is used for OAUTH + RFC3986, + + // The HTML5 spec disallows mixed mode in multipart/form-data + // requests. More concretely this means that more files submitted + // under the same name will not be encoded using mixed mode, but + // will be treated as distinct fields. + // Reference: http://www.w3.org/TR/html5/forms.html#multipart-form-data + HTML5 + } + + static readonly KeyValuePair[] PercentEncodings; + + static HttpPostRequestEncoder() + { + PercentEncodings = new[] + { + new KeyValuePair(new Regex("\\*", RegexOptions.Compiled), "%2A"), + new KeyValuePair(new Regex("\\+", RegexOptions.Compiled), "%20"), + new KeyValuePair(new Regex("~", RegexOptions.Compiled), "%7E"), + }; + } + + // Factory used to create InterfaceHttpData + readonly IHttpDataFactory factory; + + // Request to encode + readonly IHttpRequest request; + + // Default charset to use + readonly Encoding charset; + + // Chunked false by default + bool isChunked; + + // InterfaceHttpData for Body (without encoding) + readonly List bodyListDatas; + + // The final Multipart List of InterfaceHttpData including encoding + internal readonly List MultipartHttpDatas; + + // Does this request is a Multipart request + readonly bool isMultipart; + + // If multipart, this is the boundary for the global multipart + internal string MultipartDataBoundary; + + // If multipart, there could be internal multiparts (mixed) to the global multipart. Only one level is allowed. + internal string MultipartMixedBoundary; + + // To check if the header has been finalized + bool headerFinalized; + + readonly EncoderMode encoderMode; + + public HttpPostRequestEncoder(IHttpRequest request, bool multipart) + : this(new DefaultHttpDataFactory(DefaultHttpDataFactory.MinSize), request, multipart, + HttpConstants.DefaultEncoding, EncoderMode.RFC1738) + { + } + + public HttpPostRequestEncoder(IHttpDataFactory factory, IHttpRequest request, bool multipart) + : this(factory, request, multipart, HttpConstants.DefaultEncoding, EncoderMode.RFC1738) + { + } + + public HttpPostRequestEncoder( + IHttpDataFactory factory, IHttpRequest request, bool multipart, Encoding charset, + EncoderMode encoderMode) + { + Contract.Requires(request != null); + Contract.Requires(factory != null); + Contract.Requires(charset != null); + + this.request = request; + this.charset = charset; + this.factory = factory; + HttpMethod method = request.Method; + if (method.Equals(HttpMethod.Trace)) + { + throw new ErrorDataEncoderException("Cannot create a Encoder if request is a TRACE"); + } + // Fill default values + this.bodyListDatas = new List(); + // default mode + this.isLastChunk = false; + this.isLastChunkSent = false; + this.isMultipart = multipart; + this.MultipartHttpDatas = new List(); + this.encoderMode = encoderMode; + if (this.isMultipart) + { + this.InitDataMultipart(); + } + } + // Clean all HttpDatas (on Disk) for the current request. + public void CleanFiles() => this.factory.CleanRequestHttpData(this.request); + + // Does the last non empty chunk already encoded so that next chunk will be empty (last chunk) + bool isLastChunk; + + // Last chunk already sent + bool isLastChunkSent; + + // The current FileUpload that is currently in encode process + IFileUpload currentFileUpload; + + // While adding a FileUpload, is the multipart currently in Mixed Mode + bool duringMixedMode; + + // Global Body size + long globalBodySize; + + // Global Transfer progress + long globalProgress; + + public bool IsMultipart => this.isMultipart; + + void InitDataMultipart() => this.MultipartDataBoundary = GetNewMultipartDelimiter(); + + void InitMixedMultipart() => this.MultipartMixedBoundary = GetNewMultipartDelimiter(); + + // construct a generated delimiter + static string GetNewMultipartDelimiter() => Convert.ToString(PlatformDependent.GetThreadLocalRandom().NextLong(), 16).ToLower(); + + public List GetBodyListAttributes() => this.bodyListDatas; + + public void SetBodyHttpDatas(List list) + { + Contract.Requires(list != null); + + this.globalBodySize = 0; + this.bodyListDatas.Clear(); + this.currentFileUpload = null; + this.duringMixedMode = false; + this.MultipartHttpDatas.Clear(); + foreach (IInterfaceHttpData data in list) + { + this.AddBodyHttpData(data); + } + } + + public void AddBodyAttribute(string name, string value) + { + Contract.Requires(name != null); + IAttribute data = this.factory.CreateAttribute(this.request, name, value ?? StringUtil.EmptyString); + this.AddBodyHttpData(data); + } + + public void AddBodyFileUpload(string name, FileStream fileStream, string contentType, bool isText) + { + string fileName = Path.GetFileName(fileStream.Name); + this.AddBodyFileUpload(name, fileName, fileStream, contentType, isText); + } + + public void AddBodyFileUpload(string name, string fileName, FileStream fileStream, string contentType, bool isText) + { + Contract.Requires(name != null); + Contract.Requires(fileStream != null); + + if (fileName == null) + { + fileName = StringUtil.EmptyString; + } + string scontentType = contentType; + string contentTransferEncoding = null; + if (contentType == null) + { + scontentType = isText + ? HttpPostBodyUtil.DefaultTextContentType + : HttpPostBodyUtil.DefaultBinaryContentType; + } + if (!isText) + { + contentTransferEncoding = HttpPostBodyUtil.TransferEncodingMechanism.Binary.Value; + } + + IFileUpload fileUpload = this.factory.CreateFileUpload(this.request, name, fileName, scontentType, + contentTransferEncoding, null, fileStream.Length); + try + { + fileUpload.SetContent(fileStream); + } + catch (IOException e) + { + throw new ErrorDataEncoderException(e); + } + + this.AddBodyHttpData(fileUpload); + } + + public void AddBodyFileUploads(string name, FileStream[] file, string[] contentType, bool[] isText) + { + if (file.Length != contentType.Length && file.Length != isText.Length) + { + throw new ArgumentException("Different array length"); + } + for (int i = 0; i < file.Length; i++) + { + this.AddBodyFileUpload(name, file[i], contentType[i], isText[i]); + } + } + + public void AddBodyHttpData(IInterfaceHttpData data) + { + Contract.Requires(data != null); + if (this.headerFinalized) + { + throw new ErrorDataEncoderException("Cannot add value once finalized"); + } + this.bodyListDatas.Add(data); + if (!this.isMultipart) + { + if (data is IAttribute dataAttribute) + { + try + { + // name=value& with encoded name and attribute + string key = this.EncodeAttribute(dataAttribute.Name, this.charset); + string value = this.EncodeAttribute(dataAttribute.Value, this.charset); + IAttribute newattribute = this.factory.CreateAttribute(this.request, key, value); + this.MultipartHttpDatas.Add(newattribute); + this.globalBodySize += newattribute.Name.Length + 1 + newattribute.Length + 1; + } + catch (IOException e) + { + throw new ErrorDataEncoderException(e); + } + } + else if (data is IFileUpload fileUpload) + { + // since not Multipart, only name=filename => Attribute + // name=filename& with encoded name and filename + string key = this.EncodeAttribute(fileUpload.Name, this.charset); + string value = this.EncodeAttribute(fileUpload.FileName, this.charset); + IAttribute newattribute = this.factory.CreateAttribute(this.request, key, value); + this.MultipartHttpDatas.Add(newattribute); + this.globalBodySize += newattribute.Name.Length + 1 + newattribute.Length + 1; + } + return; + } + // Logic: + // if not Attribute: + // add Data to body list + // if (duringMixedMode) + // add endmixedmultipart delimiter + // currentFileUpload = null + // duringMixedMode = false; + // add multipart delimiter, multipart body header and Data to multipart list + // reset currentFileUpload, duringMixedMode + // if FileUpload: take care of multiple file for one field => mixed mode + // if (duringMixedMode) + // if (currentFileUpload.name == data.name) + // add mixedmultipart delimiter, mixedmultipart body header and Data to multipart list + // else + // add endmixedmultipart delimiter, multipart body header and Data to multipart list + // currentFileUpload = data + // duringMixedMode = false; + // else + // if (currentFileUpload.name == data.name) + // change multipart body header of previous file into multipart list to + // mixedmultipart start, mixedmultipart body header + // add mixedmultipart delimiter, mixedmultipart body header and Data to multipart list + // duringMixedMode = true + // else + // add multipart delimiter, multipart body header and Data to multipart list + // currentFileUpload = data + // duringMixedMode = false; + // Do not add last delimiter! Could be: + // if duringmixedmode: endmixedmultipart + endmultipart + // else only endmultipart + // + if (data is IAttribute attribute) + { + InternalAttribute internalAttribute; + if (this.duringMixedMode) + { + internalAttribute = new InternalAttribute(this.charset); + internalAttribute.AddValue($"\r\n--{this.MultipartMixedBoundary}--"); + this.MultipartHttpDatas.Add(internalAttribute); + this.MultipartMixedBoundary = null; + this.currentFileUpload = null; + this.duringMixedMode = false; + } + internalAttribute = new InternalAttribute(this.charset); + if (this.MultipartHttpDatas.Count > 0) + { + // previously a data field so CRLF + internalAttribute.AddValue("\r\n"); + } + internalAttribute.AddValue($"--{this.MultipartDataBoundary}\r\n"); + // content-disposition: form-data; name="field1" + internalAttribute.AddValue($"{HttpHeaderNames.ContentDisposition}: {HttpHeaderValues.FormData}; {HttpHeaderValues.Name}=\"{attribute.Name}\"\r\n"); + // Add Content-Length: xxx + internalAttribute.AddValue($"{HttpHeaderNames.ContentLength}: {attribute.Length}\r\n"); + Encoding localcharset = attribute.Charset; + if (localcharset != null) + { + // Content-Type: text/plain; charset=charset + internalAttribute.AddValue($"{HttpHeaderNames.ContentType}: {HttpPostBodyUtil.DefaultTextContentType}; {HttpHeaderValues.Charset}={localcharset.WebName}\r\n"); + } + // CRLF between body header and data + internalAttribute.AddValue("\r\n"); + this.MultipartHttpDatas.Add(internalAttribute); + this.MultipartHttpDatas.Add(data); + this.globalBodySize += attribute.Length + internalAttribute.Size; + } + else if (data is IFileUpload fileUpload) + { + var internalAttribute = new InternalAttribute(this.charset); + if (this.MultipartHttpDatas.Count > 0) + { + // previously a data field so CRLF + internalAttribute.AddValue("\r\n"); + } + bool localMixed; + if (this.duringMixedMode) + { + if (this.currentFileUpload != null && this.currentFileUpload.Name.Equals(fileUpload.Name)) + { + // continue a mixed mode + + localMixed = true; + } + else + { + // end a mixed mode + + // add endmixedmultipart delimiter, multipart body header + // and + // Data to multipart list + internalAttribute.AddValue($"--{this.MultipartMixedBoundary}--"); + this.MultipartHttpDatas.Add(internalAttribute); + this.MultipartMixedBoundary = null; + // start a new one (could be replaced if mixed start again + // from here + internalAttribute = new InternalAttribute(this.charset); + internalAttribute.AddValue("\r\n"); + localMixed = false; + // new currentFileUpload and no more in Mixed mode + this.currentFileUpload = fileUpload; + this.duringMixedMode = false; + } + } + else + { + if (this.encoderMode != EncoderMode.HTML5 && this.currentFileUpload != null + && this.currentFileUpload.Name.Equals(fileUpload.Name)) + { + // create a new mixed mode (from previous file) + + // change multipart body header of previous file into + // multipart list to + // mixedmultipart start, mixedmultipart body header + + // change Internal (size()-2 position in multipartHttpDatas) + // from (line starting with *) + // --AaB03x + // * Content-Disposition: form-data; name="files"; + // filename="file1.txt" + // Content-Type: text/plain + // to (lines starting with *) + // --AaB03x + // * Content-Disposition: form-data; name="files" + // * Content-Type: multipart/mixed; boundary=BbC04y + // * + // * --BbC04y + // * Content-Disposition: attachment; filename="file1.txt" + // Content-Type: text/plain + + this.InitMixedMultipart(); + var pastAttribute = (InternalAttribute)this.MultipartHttpDatas[this.MultipartHttpDatas.Count - 2]; + // remove past size + this.globalBodySize -= pastAttribute.Size; + StringBuilder replacement = new StringBuilder( + 139 + this.MultipartDataBoundary.Length + this.MultipartMixedBoundary.Length * 2 + + fileUpload.FileName.Length + fileUpload.Name.Length) + .Append("--") + .Append(this.MultipartDataBoundary) + .Append("\r\n") + + .Append(HttpHeaderNames.ContentDisposition) + .Append(": ") + .Append(HttpHeaderValues.FormData) + .Append("; ") + .Append(HttpHeaderValues.Name) + .Append("=\"") + .Append(fileUpload.Name) + .Append("\"\r\n") + + .Append(HttpHeaderNames.ContentType) + .Append(": ") + .Append(HttpHeaderValues.MultipartMixed) + .Append("; ") + .Append(HttpHeaderValues.Boundary) + .Append('=') + .Append(this.MultipartMixedBoundary) + .Append("\r\n\r\n") + + .Append("--") + .Append(this.MultipartMixedBoundary) + .Append("\r\n") + + .Append(HttpHeaderNames.ContentDisposition) + .Append(": ") + .Append(HttpHeaderValues.Attachment); + + if (fileUpload.FileName.Length > 0) + { + replacement.Append("; ") + .Append(HttpHeaderValues.FileName) + .Append("=\"") + .Append(fileUpload.FileName) + .Append('"'); + } + + replacement.Append("\r\n"); + + pastAttribute.SetValue(replacement.ToString(), 1); + pastAttribute.SetValue("", 2); + + // update past size + this.globalBodySize += pastAttribute.Size; + + // now continue + // add mixedmultipart delimiter, mixedmultipart body header + // and + // Data to multipart list + localMixed = true; + this.duringMixedMode = true; + } + else + { + // a simple new multipart + // add multipart delimiter, multipart body header and Data + // to multipart list + localMixed = false; + this.currentFileUpload = fileUpload; + this.duringMixedMode = false; + } + } + + if (localMixed) + { + // add mixedmultipart delimiter, mixedmultipart body header and + // Data to multipart list + internalAttribute.AddValue($"--{this.MultipartMixedBoundary}\r\n"); + + if (fileUpload.FileName.Length == 0) + { + // Content-Disposition: attachment + internalAttribute.AddValue($"{HttpHeaderNames.ContentDisposition}: {HttpHeaderValues.Attachment}\r\n"); + } + else + { + // Content-Disposition: attachment; filename="file1.txt" + internalAttribute.AddValue($"{HttpHeaderNames.ContentDisposition}: {HttpHeaderValues.Attachment}; {HttpHeaderValues.FileName}=\"{fileUpload.FileName}\"\r\n"); + } + } + else + { + internalAttribute.AddValue($"--{this.MultipartDataBoundary}\r\n"); + + if (fileUpload.FileName.Length == 0) + { + // Content-Disposition: form-data; name="files"; + internalAttribute.AddValue($"{HttpHeaderNames.ContentDisposition}: {HttpHeaderValues.FormData}; {HttpHeaderValues.Name}=\"{fileUpload.Name}\"\r\n"); + } + else + { + // Content-Disposition: form-data; name="files"; + // filename="file1.txt" + internalAttribute.AddValue($"{HttpHeaderNames.ContentDisposition}: {HttpHeaderValues.FormData}; {HttpHeaderValues.Name}=\"{fileUpload.Name}\"; {HttpHeaderValues.FileName}=\"{fileUpload.FileName}\"\r\n"); + } + } + // Add Content-Length: xxx + internalAttribute.AddValue($"{HttpHeaderNames.ContentLength}: {fileUpload.Length}\r\n"); + // Content-Type: image/gif + // Content-Type: text/plain; charset=ISO-8859-1 + // Content-Transfer-Encoding: binary + internalAttribute.AddValue($"{HttpHeaderNames.ContentType}: {fileUpload.ContentType}"); + string contentTransferEncoding = fileUpload.ContentTransferEncoding; + if (contentTransferEncoding != null + && contentTransferEncoding.Equals(HttpPostBodyUtil.TransferEncodingMechanism.Binary.Value)) + { + internalAttribute.AddValue($"\r\n{HttpHeaderNames.ContentTransferEncoding}: {HttpPostBodyUtil.TransferEncodingMechanism.Binary.Value}\r\n\r\n"); + } + else if (fileUpload.Charset != null) + { + internalAttribute.AddValue($"; {HttpHeaderValues.Charset}={fileUpload.Charset.WebName}\r\n\r\n"); + } + else + { + internalAttribute.AddValue("\r\n\r\n"); + } + this.MultipartHttpDatas.Add(internalAttribute); + this.MultipartHttpDatas.Add(data); + this.globalBodySize += fileUpload.Length + internalAttribute.Size; + } + } + + ListIterator iterator; + + public IHttpRequest FinalizeRequest() + { + // Finalize the multipartHttpDatas + if (!this.headerFinalized) + { + if (this.isMultipart) + { + var attribute = new InternalAttribute(this.charset); + if (this.duringMixedMode) + { + attribute.AddValue($"\r\n--{this.MultipartMixedBoundary}--"); + } + + attribute.AddValue($"\r\n--{this.MultipartDataBoundary}--\r\n"); + this.MultipartHttpDatas.Add(attribute); + this.MultipartMixedBoundary = null; + this.currentFileUpload = null; + this.duringMixedMode = false; + this.globalBodySize += attribute.Size; + } + this.headerFinalized = true; + } + else + { + throw new ErrorDataEncoderException("Header already encoded"); + } + + HttpHeaders headers = this.request.Headers; + IList contentTypes = headers.GetAll(HttpHeaderNames.ContentType); + IList transferEncoding = headers.GetAll(HttpHeaderNames.TransferEncoding); + if (contentTypes != null) + { + headers.Remove(HttpHeaderNames.ContentType); + foreach (ICharSequence contentType in contentTypes) + { + // "multipart/form-data; boundary=--89421926422648" + string lowercased = contentType.ToString().ToLower(); + if (lowercased.StartsWith(HttpHeaderValues.MultipartFormData.ToString()) + || lowercased.StartsWith(HttpHeaderValues.ApplicationXWwwFormUrlencoded.ToString())) + { + // ignore + } + else + { + headers.Add(HttpHeaderNames.ContentType, contentType); + } + } + } + if (this.isMultipart) + { + string value = $"{HttpHeaderValues.MultipartFormData}; {HttpHeaderValues.Boundary}={this.MultipartDataBoundary}"; + headers.Add(HttpHeaderNames.ContentType, value); + } + else + { + // Not multipart + headers.Add(HttpHeaderNames.ContentType, HttpHeaderValues.ApplicationXWwwFormUrlencoded); + } + // Now consider size for chunk or not + long realSize = this.globalBodySize; + if (this.isMultipart) + { + this.iterator = new ListIterator(this.MultipartHttpDatas); + } + else + { + realSize -= 1; // last '&' removed + this.iterator = new ListIterator(this.MultipartHttpDatas); + } + headers.Set(HttpHeaderNames.ContentLength, Convert.ToString(realSize)); + if (realSize > HttpPostBodyUtil.ChunkSize || this.isMultipart) + { + this.isChunked = true; + if (transferEncoding != null) + { + headers.Remove(HttpHeaderNames.TransferEncoding); + foreach (ICharSequence v in transferEncoding) + { + if (HttpHeaderValues.Chunked.ContentEqualsIgnoreCase(v)) + { + // ignore + } + else + { + headers.Add(HttpHeaderNames.TransferEncoding, v); + } + } + } + HttpUtil.SetTransferEncodingChunked(this.request, true); + + // wrap to hide the possible content + return new WrappedHttpRequest(this.request); + } + else + { + // get the only one body and set it to the request + IHttpContent chunk = this.NextChunk(); + if (this.request is IFullHttpRequest fullRequest) + { + IByteBuffer chunkContent = chunk.Content; + if (!ReferenceEquals(fullRequest.Content, chunkContent)) + { + fullRequest.Content.Clear(); + fullRequest.Content.WriteBytes(chunkContent); + chunkContent.Release(); + } + return fullRequest; + } + else + { + return new WrappedFullHttpRequest(this.request, chunk); + } + } + } + + public bool IsChunked => this.isChunked; + + string EncodeAttribute(string value, Encoding stringEncoding) + { + if (value == null) + { + return string.Empty; + } + + string encoded = UrlEncoder.Encode(value, stringEncoding); + if (this.encoderMode == EncoderMode.RFC3986) + { + foreach (KeyValuePair entry in PercentEncodings) + { + string replacement = entry.Value; + encoded = entry.Key.Replace(encoded, replacement); + } + } + return encoded; + } + + // The ByteBuf currently used by the encoder + IByteBuffer currentBuffer; + + // The current InterfaceHttpData to encode (used if more chunks are available) + IInterfaceHttpData currentData; + + // If not multipart, does the currentBuffer stands for the Key or for the Value + bool isKey = true; + + IByteBuffer FillByteBuffer() + { + int length = this.currentBuffer.ReadableBytes; + if (length > HttpPostBodyUtil.ChunkSize) + { + return this.currentBuffer.ReadRetainedSlice(HttpPostBodyUtil.ChunkSize); + } + else + { + // to continue + IByteBuffer slice = this.currentBuffer; + this.currentBuffer = null; + return slice; + } + } + + // From the current context(currentBuffer and currentData), returns the next + // HttpChunk(if possible) trying to get sizeleft bytes more into the currentBuffer. + // This is the Multipart version. + IHttpContent EncodeNextChunkMultipart(int sizeleft) + { + if (this.currentData == null) + { + return null; + } + IByteBuffer buffer; + if (this.currentData is InternalAttribute internalAttribute) + { + buffer = internalAttribute.ToByteBuffer(); + this.currentData = null; + } + else + { + try + { + buffer = ((IHttpData)this.currentData).GetChunk(sizeleft); + } + catch (IOException e) + { + throw new ErrorDataEncoderException(e); + } + if (buffer.Capacity == 0) + { + // end for current InterfaceHttpData, need more data + this.currentData = null; + return null; + } + } + this.currentBuffer = this.currentBuffer == null + ? buffer + : Unpooled.WrappedBuffer(this.currentBuffer, buffer); + + if (this.currentBuffer.ReadableBytes < HttpPostBodyUtil.ChunkSize) + { + this.currentData = null; + return null; + } + + buffer = this.FillByteBuffer(); + return new DefaultHttpContent(buffer); + } + + // From the current context(currentBuffer and currentData), returns the next HttpChunk(if possible) + // trying to get* sizeleft bytes more into the currentBuffer.This is the UrlEncoded version. + IHttpContent EncodeNextChunkUrlEncoded(int sizeleft) + { + if (this.currentData == null) + { + return null; + } + int size = sizeleft; + IByteBuffer buffer; + + // Set name= + if (this.isKey) + { + string key = this.currentData.Name; + buffer = Unpooled.WrappedBuffer(Encoding.UTF8.GetBytes(key)); + this.isKey = false; + if (this.currentBuffer == null) + { + this.currentBuffer = Unpooled.WrappedBuffer(buffer, + Unpooled.WrappedBuffer(Encoding.UTF8.GetBytes("="))); + // continue + size -= buffer.ReadableBytes + 1; + } + else + { + this.currentBuffer = Unpooled.WrappedBuffer(this.currentBuffer, buffer, + Unpooled.WrappedBuffer(Encoding.UTF8.GetBytes("="))); + // continue + size -= buffer.ReadableBytes + 1; + } + if (this.currentBuffer.ReadableBytes >= HttpPostBodyUtil.ChunkSize) + { + buffer = this.FillByteBuffer(); + return new DefaultHttpContent(buffer); + } + } + + // Put value into buffer + try + { + buffer = ((IHttpData)this.currentData).GetChunk(size); + } + catch (IOException e) + { + throw new ErrorDataEncoderException(e); + } + + // Figure out delimiter + IByteBuffer delimiter = null; + if (buffer.ReadableBytes < size) + { + this.isKey = true; + delimiter = this.iterator.HasNext() + ? Unpooled.WrappedBuffer(Encoding.UTF8.GetBytes("&")) + : null; + } + + // End for current InterfaceHttpData, need potentially more data + if (buffer.Capacity == 0) + { + this.currentData = null; + if (this.currentBuffer == null) + { + this.currentBuffer = delimiter; + } + else + { + if (delimiter != null) + { + this.currentBuffer = Unpooled.WrappedBuffer(this.currentBuffer, delimiter); + } + } + Debug.Assert(this.currentBuffer != null); + if (this.currentBuffer.ReadableBytes >= HttpPostBodyUtil.ChunkSize) + { + buffer = this.FillByteBuffer(); + return new DefaultHttpContent(buffer); + } + return null; + } + + // Put it all together: name=value& + if (this.currentBuffer == null) + { + this.currentBuffer = delimiter != null + ? Unpooled.WrappedBuffer(buffer, delimiter) + : buffer; + } + else + { + this.currentBuffer = delimiter != null + ? Unpooled.WrappedBuffer(this.currentBuffer, buffer, delimiter) + : Unpooled.WrappedBuffer(this.currentBuffer, buffer); + } + + // end for current InterfaceHttpData, need more data + if (this.currentBuffer.ReadableBytes < HttpPostBodyUtil.ChunkSize) + { + this.currentData = null; + this.isKey = true; + return null; + } + + buffer = this.FillByteBuffer(); + return new DefaultHttpContent(buffer); + } + + public void Close() + { + // NO since the user can want to reuse (broadcast for instance) + // cleanFiles(); + } + public IHttpContent ReadChunk(IByteBufferAllocator allocator) + { + if (this.isLastChunkSent) + { + return null; + } + else + { + IHttpContent nextChunk = this.NextChunk(); + this.globalProgress += nextChunk.Content.ReadableBytes; + return nextChunk; + } + } + + IHttpContent NextChunk() + { + if (this.isLastChunk) + { + this.isLastChunkSent = true; + return EmptyLastHttpContent.Default; + } + // first test if previous buffer is not empty + int size = this.CalculateRemainingSize(); + if (size <= 0) + { + // NextChunk from buffer + IByteBuffer buffer = this.FillByteBuffer(); + return new DefaultHttpContent(buffer); + } + // size > 0 + if (this.currentData != null) + { + // continue to read data + IHttpContent chunk = this.isMultipart + ? this.EncodeNextChunkMultipart(size) + : this.EncodeNextChunkUrlEncoded(size); + if (chunk != null) + { + // NextChunk from data + return chunk; + } + size = this.CalculateRemainingSize(); + } + if (!this.iterator.HasNext()) + { + return this.LastChunk(); + } + while (size > 0 && this.iterator.HasNext()) + { + this.currentData = this.iterator.Next(); + IHttpContent chunk; + if (this.isMultipart) + { + chunk = this.EncodeNextChunkMultipart(size); + } + else + { + chunk = this.EncodeNextChunkUrlEncoded(size); + } + if (chunk == null) + { + // not enough + size = this.CalculateRemainingSize(); + continue; + } + // NextChunk from data + return chunk; + } + // end since no more data + return this.LastChunk(); + } + + int CalculateRemainingSize() + { + int size = HttpPostBodyUtil.ChunkSize; + if (this.currentBuffer != null) + { + size -= this.currentBuffer.ReadableBytes; + } + return size; + } + + IHttpContent LastChunk() + { + this.isLastChunk = true; + if (this.currentBuffer == null) + { + this.isLastChunkSent = true; + // LastChunk with no more data + return EmptyLastHttpContent.Default; + } + // NextChunk as last non empty from buffer + IByteBuffer buffer = this.currentBuffer; + this.currentBuffer = null; + return new DefaultHttpContent(buffer); + } + + public bool IsEndOfInput => this.isLastChunkSent; + + public long Length => this.isMultipart ? this.globalBodySize : this.globalBodySize - 1; + + // Global Transfer progress + public long Progress => this.globalProgress; + + class WrappedHttpRequest : IHttpRequest + { + readonly IHttpRequest request; + + internal WrappedHttpRequest(IHttpRequest request) + { + this.request = request; + } + + + public IHttpMessage SetProtocolVersion(HttpVersion version) + { + this.request.SetProtocolVersion(version); + return this; + } + + public IHttpRequest SetMethod(HttpMethod method) + { + this.request.SetMethod(method); + return this; + } + + public IHttpRequest SetUri(string uri) + { + this.request.SetUri(uri); + return this; + } + + public HttpVersion ProtocolVersion => this.request.ProtocolVersion; + + public HttpMethod Method => this.request.Method; + + public string Uri => this.request.Uri; + + public HttpHeaders Headers => this.request.Headers; + + public DecoderResult Result + { + get => this.request.Result; + set => this.request.Result = value; + } + } + + sealed class WrappedFullHttpRequest : WrappedHttpRequest, IFullHttpRequest + { + readonly IHttpContent content; + + public WrappedFullHttpRequest(IHttpRequest request, IHttpContent content) + : base(request) + { + this.content = content; + } + + public IByteBufferHolder Copy() => this.Replace(this.Content.Copy()); + + public IByteBufferHolder Duplicate() => this.Replace(this.Content.Duplicate()); + + public IByteBufferHolder RetainedDuplicate() => this.Replace(this.Content.RetainedDuplicate()); + + public IByteBufferHolder Replace(IByteBuffer newContent) + { + var duplicate = new DefaultFullHttpRequest(this.ProtocolVersion, this.Method, this.Uri, newContent); + duplicate.Headers.Set(this.Headers); + duplicate.TrailingHeaders.Set(this.TrailingHeaders); + return duplicate; + } + + public IReferenceCounted Retain(int increment) + { + this.content.Retain(increment); + return this; + } + + public IReferenceCounted Retain() + { + this.content.Retain(); + return this; + } + + public IReferenceCounted Touch() + { + this.content.Touch(); + return this; + } + + public IReferenceCounted Touch(object hint) + { + this.content.Touch(hint); + return this; + } + + public IByteBuffer Content => this.content.Content; + + public HttpHeaders TrailingHeaders + { + get + { + if (this.content is ILastHttpContent httpContent) + { + return httpContent.TrailingHeaders; + } + + return EmptyHttpHeaders.Default; + } + } + + public int ReferenceCount => this.content.ReferenceCount; + + public bool Release() => this.content.Release(); + + public bool Release(int decrement) => this.content.Release(decrement); + } + + sealed class ListIterator + { + readonly List list; + int index; + + public ListIterator(List list) + { + this.list = list; + this.index = 0; + } + + public bool HasNext() => this.index < this.list.Count; + + public IInterfaceHttpData Next() + { + if (!this.HasNext()) + { + throw new InvalidOperationException("No more element to iterate"); + } + + IInterfaceHttpData data = this.list[this.index++]; + return data; + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/HttpPostStandardRequestDecoder.cs b/src/DotNetty.Codecs.Http/Multipart/HttpPostStandardRequestDecoder.cs new file mode 100644 index 0000000..3711c97 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/HttpPostStandardRequestDecoder.cs @@ -0,0 +1,586 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable RedundantAssignment +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + + public class HttpPostStandardRequestDecoder : IInterfaceHttpPostRequestDecoder + { + // Factory used to create InterfaceHttpData + readonly IHttpDataFactory factory; + + // Request to decode + readonly IHttpRequest request; + + // Default charset to use + readonly Encoding charset; + + // Does the last chunk already received + bool isLastChunk; + + // HttpDatas from Body + readonly List bodyListHttpData = new List(); + + // HttpDatas as Map from Body + readonly Dictionary> bodyMapHttpData = new Dictionary>(CaseIgnoringComparator.Default); + + // The current channelBuffer + IByteBuffer undecodedChunk; + + // Body HttpDatas current position + int bodyListHttpDataRank; + + // Current getStatus + MultiPartStatus currentStatus = MultiPartStatus.Notstarted; + + // The current Attribute that is currently in decode process + IAttribute currentAttribute; + + bool destroyed; + + int discardThreshold = HttpPostRequestDecoder.DefaultDiscardThreshold; + + public HttpPostStandardRequestDecoder(IHttpRequest request) + : this(new DefaultHttpDataFactory(DefaultHttpDataFactory.MinSize), request, HttpConstants.DefaultEncoding) + { + } + + public HttpPostStandardRequestDecoder(IHttpDataFactory factory, IHttpRequest request) + : this(factory, request, HttpConstants.DefaultEncoding) + { + } + + public HttpPostStandardRequestDecoder(IHttpDataFactory factory, IHttpRequest request, Encoding charset) + { + Contract.Requires(request != null); + Contract.Requires(charset != null); + Contract.Requires(factory != null); + + this.factory = factory; + this.request = request; + this.charset = charset; + if (request is IHttpContent content) + { + // Offer automatically if the given request is als type of HttpContent + // See #1089 + this.Offer(content); + } + else + { + this.undecodedChunk = Unpooled.Buffer(); + this.ParseBody(); + } + } + + void CheckDestroyed() + { + if (this.destroyed) + { + throw new InvalidOperationException($"{StringUtil.SimpleClassName()} was destroyed already"); + } + } + + public bool IsMultipart + { + get + { + this.CheckDestroyed(); + return false; + } + } + + public int DiscardThreshold + { + get => this.discardThreshold; + set + { + Contract.Requires(value >= 0); + this.discardThreshold = value; + } + } + + public List GetBodyHttpDatas() + { + this.CheckDestroyed(); + + if (!this.isLastChunk) + { + throw new NotEnoughDataDecoderException(nameof(HttpPostStandardRequestDecoder)); + } + return this.bodyListHttpData; + } + + public List GetBodyHttpDatas(AsciiString name) + { + this.CheckDestroyed(); + + if (!this.isLastChunk) + { + throw new NotEnoughDataDecoderException(nameof(HttpPostStandardRequestDecoder)); + } + return this.bodyMapHttpData[name]; + } + + public IInterfaceHttpData GetBodyHttpData(AsciiString name) + { + this.CheckDestroyed(); + + if (!this.isLastChunk) + { + throw new NotEnoughDataDecoderException(nameof(HttpPostStandardRequestDecoder)); + } + + if (this.bodyMapHttpData.TryGetValue(name, out List list)) + { + return list[0]; + } + return null; + } + + public IInterfaceHttpPostRequestDecoder Offer(IHttpContent content) + { + this.CheckDestroyed(); + + // Maybe we should better not copy here for performance reasons but this will need + // more care by the caller to release the content in a correct manner later + // So maybe something to optimize on a later stage + IByteBuffer buf = content.Content; + if (this.undecodedChunk == null) + { + this.undecodedChunk = buf.Copy(); + } + else + { + this.undecodedChunk.WriteBytes(buf); + } + + if (content is ILastHttpContent) { + this.isLastChunk = true; + } + + this.ParseBody(); + if (this.undecodedChunk != null && this.undecodedChunk.WriterIndex > this.discardThreshold) + { + this.undecodedChunk.DiscardReadBytes(); + } + + return this; + } + + public bool HasNext + { + get + { + this.CheckDestroyed(); + + if (this.currentStatus == MultiPartStatus.Epilogue) + { + // OK except if end of list + if (this.bodyListHttpDataRank >= this.bodyListHttpData.Count) + { + throw new EndOfDataDecoderException(nameof(HttpPostStandardRequestDecoder)); + } + } + + return this.bodyListHttpData.Count > 0 && this.bodyListHttpDataRank < this.bodyListHttpData.Count; + } + } + + public IInterfaceHttpData Next() + { + this.CheckDestroyed(); + + return this.HasNext + ? this.bodyListHttpData[this.bodyListHttpDataRank++] + : null; + } + + public IInterfaceHttpData CurrentPartialHttpData => this.currentAttribute; + + void ParseBody() + { + if (this.currentStatus == MultiPartStatus.PreEpilogue || this.currentStatus == MultiPartStatus.Epilogue) + { + if (this.isLastChunk) + { + this.currentStatus = MultiPartStatus.Epilogue; + } + + return; + } + this.ParseBodyAttributes(); + } + + protected void AddHttpData(IInterfaceHttpData data) + { + if (data == null) + { + return; + } + ICharSequence name = new StringCharSequence(data.Name); + if (!this.bodyMapHttpData.TryGetValue(name, out List datas)) + { + datas = new List(1); + this.bodyMapHttpData.Add(name, datas); + } + datas.Add(data); + this.bodyListHttpData.Add(data); + } + + void ParseBodyAttributesStandard() + { + int firstpos = this.undecodedChunk.ReaderIndex; + int currentpos = firstpos; + if (this.currentStatus == MultiPartStatus.Notstarted) + { + this.currentStatus = MultiPartStatus.Disposition; + } + bool contRead = true; + try + { + int ampersandpos; + while (this.undecodedChunk.IsReadable() && contRead) + { + char read = (char)this.undecodedChunk.ReadByte(); + currentpos++; + switch (this.currentStatus) + { + case MultiPartStatus.Disposition:// search '=' + if (read == '=') + { + this.currentStatus = MultiPartStatus.Field; + int equalpos = currentpos - 1; + string key = DecodeAttribute(this.undecodedChunk.ToString(firstpos, equalpos - firstpos, this.charset), + this.charset); + this.currentAttribute = this.factory.CreateAttribute(this.request, key); + firstpos = currentpos; + } + else if (read == '&') + { // special empty FIELD + this.currentStatus = MultiPartStatus.Disposition; + ampersandpos = currentpos - 1; + string key = DecodeAttribute( + this.undecodedChunk.ToString(firstpos, ampersandpos - firstpos, this.charset), this.charset); + this.currentAttribute = this.factory.CreateAttribute(this.request, key); + this.currentAttribute.Value = ""; // empty + this.AddHttpData(this.currentAttribute); + this.currentAttribute = null; + firstpos = currentpos; + contRead = true; + } + break; + case MultiPartStatus.Field:// search '&' or end of line + if (read == '&') + { + this.currentStatus = MultiPartStatus.Disposition; + ampersandpos = currentpos - 1; + this.SetFinalBuffer(this.undecodedChunk.Copy(firstpos, ampersandpos - firstpos)); + firstpos = currentpos; + contRead = true; + } + else if (read == HttpConstants.CarriageReturn) + { + if (this.undecodedChunk.IsReadable()) + { + read = (char)this.undecodedChunk.ReadByte(); + currentpos++; + if (read == HttpConstants.LineFeed) + { + this.currentStatus = MultiPartStatus.PreEpilogue; + ampersandpos = currentpos - 2; + this.SetFinalBuffer(this.undecodedChunk.Copy(firstpos, ampersandpos - firstpos)); + firstpos = currentpos; + contRead = false; + } + else + { + // Error + throw new ErrorDataDecoderException("Bad end of line"); + } + } + else + { + currentpos--; + } + } + else if (read == HttpConstants.LineFeed) + { + this.currentStatus = MultiPartStatus.PreEpilogue; + ampersandpos = currentpos - 1; + this.SetFinalBuffer(this.undecodedChunk.Copy(firstpos, ampersandpos - firstpos)); + firstpos = currentpos; + contRead = false; + } + break; + default: + // just stop + contRead = false; + break; + } + } + if (this.isLastChunk && this.currentAttribute != null) + { + // special case + ampersandpos = currentpos; + if (ampersandpos > firstpos) + { + this.SetFinalBuffer(this.undecodedChunk.Copy(firstpos, ampersandpos - firstpos)); + } + else if (!this.currentAttribute.IsCompleted) + { + this.SetFinalBuffer(Unpooled.Empty); + } + firstpos = currentpos; + this.currentStatus = MultiPartStatus.Epilogue; + this.undecodedChunk.SetReaderIndex(firstpos); + return; + } + if (contRead && this.currentAttribute != null) + { + // reset index except if to continue in case of FIELD getStatus + if (this.currentStatus == MultiPartStatus.Field) + { + this.currentAttribute.AddContent(this.undecodedChunk.Copy(firstpos, currentpos - firstpos), + false); + firstpos = currentpos; + } + this.undecodedChunk.SetReaderIndex(firstpos); + } + else + { + // end of line or end of block so keep index to last valid position + this.undecodedChunk.SetReaderIndex(firstpos); + } + } + catch (ErrorDataDecoderException) + { + // error while decoding + this.undecodedChunk.SetReaderIndex(firstpos); + throw; + } + catch (IOException e) + { + // error while decoding + this.undecodedChunk.SetReaderIndex(firstpos); + throw new ErrorDataDecoderException(e); + } + } + + void ParseBodyAttributes() + { + if (!this.undecodedChunk.HasArray) + { + this.ParseBodyAttributesStandard(); + return; + } + var sao = new HttpPostBodyUtil.SeekAheadOptimize(this.undecodedChunk); + int firstpos = this.undecodedChunk.ReaderIndex; + int currentpos = firstpos; + if (this.currentStatus == MultiPartStatus.Notstarted) + { + this.currentStatus = MultiPartStatus.Disposition; + } + bool contRead = true; + try + { + //loop: + int ampersandpos; + while (sao.Pos < sao.Limit) + { + char read = (char)(sao.Bytes[sao.Pos++]); + currentpos++; + switch (this.currentStatus) + { + case MultiPartStatus.Disposition:// search '=' + if (read == '=') + { + this.currentStatus = MultiPartStatus.Field; + int equalpos = currentpos - 1; + string key = DecodeAttribute(this.undecodedChunk.ToString(firstpos, equalpos - firstpos, this.charset), + this.charset); + this.currentAttribute = this.factory.CreateAttribute(this.request, key); + firstpos = currentpos; + } + else if (read == '&') + { // special empty FIELD + this.currentStatus = MultiPartStatus.Disposition; + ampersandpos = currentpos - 1; + string key = DecodeAttribute( + this.undecodedChunk.ToString(firstpos, ampersandpos - firstpos, this.charset), this.charset); + this.currentAttribute = this.factory.CreateAttribute(this.request, key); + this.currentAttribute.Value = ""; // empty + this.AddHttpData(this.currentAttribute); + this.currentAttribute = null; + firstpos = currentpos; + contRead = true; + } + break; + case MultiPartStatus.Field:// search '&' or end of line + if (read == '&') + { + this.currentStatus = MultiPartStatus.Disposition; + ampersandpos = currentpos - 1; + this.SetFinalBuffer(this.undecodedChunk.Copy(firstpos, ampersandpos - firstpos)); + firstpos = currentpos; + contRead = true; + } + else if (read == HttpConstants.CarriageReturn) + { + if (sao.Pos < sao.Limit) + { + read = (char)(sao.Bytes[sao.Pos++]); + currentpos++; + if (read == HttpConstants.LineFeed) + { + this.currentStatus = MultiPartStatus.PreEpilogue; + ampersandpos = currentpos - 2; + sao.SetReadPosition(0); + this.SetFinalBuffer(this.undecodedChunk.Copy(firstpos, ampersandpos - firstpos)); + firstpos = currentpos; + contRead = false; + goto loop; + } + else + { + // Error + sao.SetReadPosition(0); + throw new ErrorDataDecoderException("Bad end of line"); + } + } + else + { + if (sao.Limit > 0) + { + currentpos--; + } + } + } + else if (read == HttpConstants.LineFeed) + { + this.currentStatus = MultiPartStatus.PreEpilogue; + ampersandpos = currentpos - 1; + sao.SetReadPosition(0); + this.SetFinalBuffer(this.undecodedChunk.Copy(firstpos, ampersandpos - firstpos)); + firstpos = currentpos; + contRead = false; + goto loop; + } + break; + default: + // just stop + sao.SetReadPosition(0); + contRead = false; + goto loop; + } + } + loop: + if (this.isLastChunk && this.currentAttribute != null) + { + // special case + ampersandpos = currentpos; + if (ampersandpos > firstpos) + { + this.SetFinalBuffer(this.undecodedChunk.Copy(firstpos, ampersandpos - firstpos)); + } + else if (!this.currentAttribute.IsCompleted) + { + this.SetFinalBuffer(Unpooled.Empty); + } + firstpos = currentpos; + this.currentStatus = MultiPartStatus.Epilogue; + this.undecodedChunk.SetReaderIndex(firstpos); + return; + } + if (contRead && this.currentAttribute != null) + { + // reset index except if to continue in case of FIELD getStatus + if (this.currentStatus == MultiPartStatus.Field) + { + this.currentAttribute.AddContent(this.undecodedChunk.Copy(firstpos, currentpos - firstpos), + false); + firstpos = currentpos; + } + this.undecodedChunk.SetReaderIndex(firstpos); + } + else + { + // end of line or end of block so keep index to last valid position + this.undecodedChunk.SetReaderIndex(firstpos); + } + } + catch (ErrorDataDecoderException) + { + // error while decoding + this.undecodedChunk.SetReaderIndex(firstpos); + throw; + } + catch (IOException e) + { + // error while decoding + this.undecodedChunk.SetReaderIndex(firstpos); + throw new ErrorDataDecoderException(e); + } + catch (ArgumentException e) + { + // error while decoding + this.undecodedChunk.SetReaderIndex(firstpos); + throw new ErrorDataDecoderException(e); + } + } + + void SetFinalBuffer(IByteBuffer buffer) + { + this.currentAttribute.AddContent(buffer, true); + string value = DecodeAttribute(this.currentAttribute.GetByteBuffer().ToString(this.charset), this.charset); + this.currentAttribute.Value = value; + this.AddHttpData(this.currentAttribute); + this.currentAttribute = null; + } + + static string DecodeAttribute(string s, Encoding charset) + { + try + { + return QueryStringDecoder.DecodeComponent(s, charset); + } + catch (ArgumentException e) + { + throw new ErrorDataDecoderException($"Bad string: '{s}'", e); + } + } + + public void Destroy() + { + this.CleanFiles(); + this.destroyed = true; + + if (this.undecodedChunk != null && this.undecodedChunk.ReferenceCount > 0) + { + this.undecodedChunk.Release(); + this.undecodedChunk = null; + } + } + + public void CleanFiles() + { + this.CheckDestroyed(); + + this.factory.CleanRequestHttpData(this.request); + } + + public void RemoveHttpDataFromClean(IInterfaceHttpData data) + { + this.CheckDestroyed(); + + this.factory.RemoveHttpDataFromClean(this.request, data); + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/IAttribute.cs b/src/DotNetty.Codecs.Http/Multipart/IAttribute.cs new file mode 100644 index 0000000..823a89c --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/IAttribute.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + public interface IAttribute : IHttpData + { + string Value { get; set; } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/IFileUpload.cs b/src/DotNetty.Codecs.Http/Multipart/IFileUpload.cs new file mode 100644 index 0000000..7f0c096 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/IFileUpload.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + public interface IFileUpload : IHttpData + { + string FileName { get; set; } + + string ContentType { get; set; } + + string ContentTransferEncoding { get; set; } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/IHttpData .cs b/src/DotNetty.Codecs.Http/Multipart/IHttpData .cs new file mode 100644 index 0000000..cab57ab --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/IHttpData .cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System.IO; + using System.Text; + using DotNetty.Buffers; + + public interface IHttpData : IInterfaceHttpData, IByteBufferHolder + { + long MaxSize { get; set; } + + void CheckSize(long newSize); + + void SetContent(IByteBuffer buffer); + + void SetContent(Stream source); + + void AddContent(IByteBuffer buffer, bool last); + + bool IsCompleted { get; } + + long Length { get; } + + long DefinedLength { get; } + + void Delete(); + + byte[] GetBytes(); + + IByteBuffer GetByteBuffer(); + + IByteBuffer GetChunk(int length); + + string GetString(); + + string GetString(Encoding charset); + + Encoding Charset { get; set; } + + bool RenameTo(FileStream destination); + + bool IsInMemory { get; } + + FileStream GetFile(); + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/IHttpDataFactory.cs b/src/DotNetty.Codecs.Http/Multipart/IHttpDataFactory.cs new file mode 100644 index 0000000..fc63738 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/IHttpDataFactory.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System.Text; + + /// + /// Interface to enable creation of IPostHttpData objects + /// + public interface IHttpDataFactory + { + void SetMaxLimit(long max); + + IAttribute CreateAttribute(IHttpRequest request, string name); + + IAttribute CreateAttribute(IHttpRequest request, string name, long definedSize); + + IAttribute CreateAttribute(IHttpRequest request, string name, string value); + + IFileUpload CreateFileUpload(IHttpRequest request, string name, string filename, + string contentType, string contentTransferEncoding, Encoding charset, long size); + + void RemoveHttpDataFromClean(IHttpRequest request, IInterfaceHttpData data); + + void CleanRequestHttpData(IHttpRequest request); + + void CleanAllHttpData(); + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/IInterfaceHttpData.cs b/src/DotNetty.Codecs.Http/Multipart/IInterfaceHttpData.cs new file mode 100644 index 0000000..70b3aa8 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/IInterfaceHttpData.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using DotNetty.Common; + + public enum HttpDataType + { + Attribute, + FileUpload, + InternalAttribute + } + + // Interface for all Objects that could be encoded/decoded using HttpPostRequestEncoder/Decoder + public interface IInterfaceHttpData : IComparable, IReferenceCounted + { + string Name { get; } + + HttpDataType DataType { get; } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/IInterfaceHttpPostRequestDecoder.cs b/src/DotNetty.Codecs.Http/Multipart/IInterfaceHttpPostRequestDecoder.cs new file mode 100644 index 0000000..0b790df --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/IInterfaceHttpPostRequestDecoder.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System.Collections.Generic; + using DotNetty.Common.Utilities; + + public interface IInterfaceHttpPostRequestDecoder + { + bool IsMultipart { get; } + + int DiscardThreshold { get; set; } + + List GetBodyHttpDatas(); + + List GetBodyHttpDatas(AsciiString name); + + IInterfaceHttpData GetBodyHttpData(AsciiString name); + + IInterfaceHttpPostRequestDecoder Offer(IHttpContent content); + + bool HasNext { get; } + + IInterfaceHttpData Next(); + + IInterfaceHttpData CurrentPartialHttpData { get; } + + void Destroy(); + + void CleanFiles(); + + void RemoveHttpDataFromClean(IInterfaceHttpData data); + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/InternalAttribute.cs b/src/DotNetty.Codecs.Http/Multipart/InternalAttribute.cs new file mode 100644 index 0000000..10bacef --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/InternalAttribute.cs @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common; + using DotNetty.Common.Utilities; + + sealed class InternalAttribute : AbstractReferenceCounted, IInterfaceHttpData + { + readonly List value = new List(); + readonly Encoding charset; + int size; + + internal InternalAttribute(Encoding charset) + { + this.charset = charset; + } + + public HttpDataType DataType => HttpDataType.InternalAttribute; + + public void AddValue(string stringValue) + { + Contract.Requires(stringValue != null); + + IByteBuffer buf = Unpooled.CopiedBuffer(this.charset.GetBytes(stringValue)); + this.value.Add(buf); + this.size += buf.ReadableBytes; + } + + public void AddValue(string stringValue, int rank) + { + Contract.Requires(stringValue != null); + + IByteBuffer buf = Unpooled.CopiedBuffer(this.charset.GetBytes(stringValue)); + this.value[rank] = buf; + this.size += buf.ReadableBytes; + } + + public void SetValue(string stringValue, int rank) + { + Contract.Requires(stringValue != null); + + IByteBuffer buf = Unpooled.CopiedBuffer(this.charset.GetBytes(stringValue)); + IByteBuffer old = this.value[rank]; + this.value[rank] = buf; + if (old != null) + { + this.size -= old.ReadableBytes; + old.Release(); + } + this.size += buf.ReadableBytes; + } + + public override int GetHashCode() => this.Name.GetHashCode(); + + public override bool Equals(object obj) + { + if (obj is InternalAttribute attribute) + { + return this.Name.Equals(attribute.Name, StringComparison.OrdinalIgnoreCase); + } + return false; + } + + public int CompareTo(IInterfaceHttpData other) + { + if (!(other is InternalAttribute)) + { + throw new ArgumentException($"Cannot compare {this.DataType} with {other.DataType}"); + } + + return this.CompareTo((InternalAttribute)other); + } + + public int CompareTo(InternalAttribute other) => string.Compare(this.Name, other.Name, StringComparison.OrdinalIgnoreCase); + + public override string ToString() + { + var result = new StringBuilder(); + foreach (IByteBuffer buf in this.value) + { + result.Append(buf.ToString(this.charset)); + } + + return result.ToString(); + } + + public int Size => this.size; + + public IByteBuffer ToByteBuffer() + { + CompositeByteBuffer compositeBuffer = Unpooled.CompositeBuffer(this.value.Count); + compositeBuffer.AddComponents(this.value); + compositeBuffer.SetWriterIndex(this.size); + compositeBuffer.SetReaderIndex(0); + + return compositeBuffer; + } + + public string Name => nameof(InternalAttribute); + + protected override void Deallocate() + { + // Do nothing + } + + protected override IReferenceCounted RetainCore(int increment) + { + foreach (IByteBuffer buf in this.value) + { + buf.Retain(increment); + } + return this; + } + + public override IReferenceCounted Touch(object hint) + { + foreach (IByteBuffer buf in this.value) + { + buf.Touch(hint); + } + return this; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/MemoryAttribute.cs b/src/DotNetty.Codecs.Http/Multipart/MemoryAttribute.cs new file mode 100644 index 0000000..a4ed93b --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/MemoryAttribute.cs @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Diagnostics.Contracts; + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Transport.Channels; + + public class MemoryAttribute : AbstractMemoryHttpData, IAttribute + { + public MemoryAttribute(string name) + : this(name, HttpConstants.DefaultEncoding) + { + } + + public MemoryAttribute(string name, long definedSize) + : this(name, definedSize, HttpConstants.DefaultEncoding) + { + } + + public MemoryAttribute(string name, Encoding charset) + : base(name, charset, 0) + { + } + + public MemoryAttribute(string name, long definedSize, Encoding charset) + : base(name, charset, definedSize) + { + } + + public MemoryAttribute(string name, string value) + : this(name, value, HttpConstants.DefaultEncoding) + { + } + + public MemoryAttribute(string name, string value, Encoding contentEncoding) + : base(name, contentEncoding, 0) + { + this.Value = value; + } + + public override HttpDataType DataType => HttpDataType.Attribute; + + public string Value + { + get => this.GetByteBuffer().ToString(this.Charset); + set + { + Contract.Requires(value != null); + + byte[] bytes = this.Charset.GetBytes(value); + this.CheckSize(bytes.Length); + IByteBuffer buffer = Unpooled.WrappedBuffer(bytes); + if (this.DefinedSize > 0) + { + this.DefinedSize = buffer.ReadableBytes; + } + this.SetContent(buffer); + } + } + + public override void AddContent(IByteBuffer buffer, bool last) + { + int localsize = buffer.ReadableBytes; + this.CheckSize(this.Size + localsize); + if (this.DefinedSize > 0 && this.DefinedSize < this.Size + localsize) + { + this.DefinedSize = this.Size + localsize; + } + base.AddContent(buffer, last); + } + + public override int GetHashCode() => this.Name.GetHashCode(); + + public override bool Equals(object obj) + { + if (obj is IAttribute attribute) + { + return this.Name.Equals(attribute.Name, StringComparison.OrdinalIgnoreCase); + } + + return false; + } + + public override int CompareTo(IInterfaceHttpData other) + { + if (!(other is IAttribute)) + { + throw new ArgumentException($"Cannot compare {this.DataType} with {other.DataType}"); + } + + return this.CompareTo((IAttribute)other); + } + + public int CompareTo(IAttribute attribute) => string.Compare(this.Name, attribute.Name, StringComparison.OrdinalIgnoreCase); + + public override string ToString() => $"{this.Name} = {this.Value}"; + + public override IByteBufferHolder Copy() + { + IByteBuffer content = this.Content; + return this.Replace(content?.Copy()); + } + + public override IByteBufferHolder Duplicate() + { + IByteBuffer content = this.Content; + return this.Replace(content?.Duplicate()); + } + + public override IByteBufferHolder RetainedDuplicate() + { + IByteBuffer content = this.Content; + if (content != null) + { + content = content.RetainedDuplicate(); + bool success = false; + try + { + var duplicate = (IAttribute)this.Replace(content); + success = true; + return duplicate; + } + finally + { + if (!success) + { + content.Release(); + } + } + } + else + { + return this.Replace(null); + } + } + + public override IByteBufferHolder Replace(IByteBuffer content) + { + var attr = new MemoryAttribute(this.Name); + attr.Charset = this.Charset; + if (content != null) + { + try + { + attr.SetContent(content); + } + catch (IOException e) + { + throw new ChannelException(e); + } + } + return attr; + } + } +} \ No newline at end of file diff --git a/src/DotNetty.Codecs.Http/Multipart/MemoryFileUpload.cs b/src/DotNetty.Codecs.Http/Multipart/MemoryFileUpload.cs new file mode 100644 index 0000000..12bff90 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/MemoryFileUpload.cs @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +// ReSharper disable ConvertToAutoPropertyWhenPossible +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.Diagnostics.Contracts; + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Transport.Channels; + + public class MemoryFileUpload : AbstractMemoryHttpData, IFileUpload + { + string fileName; + string contentType; + string contentTransferEncoding; + + public MemoryFileUpload(string name, string fileName, string contentType, + string contentTransferEncoding, Encoding charset, long size) + : base(name, charset, size) + { + Contract.Requires(fileName != null); + Contract.Requires(contentType != null); + + this.fileName = fileName; + this.contentType = contentType; + this.contentTransferEncoding = contentTransferEncoding; + } + + public override HttpDataType DataType => HttpDataType.FileUpload; + + public string FileName + { + get => this.fileName; + set + { + Contract.Requires(value != null); + this.fileName = value; + } + } + + public override int GetHashCode() => FileUploadUtil.HashCode(this); + + public override bool Equals(object obj) + { + if (obj is IFileUpload fileUpload) + { + return FileUploadUtil.Equals(this, fileUpload); + } + return false; + } + + public override int CompareTo(IInterfaceHttpData other) + { + if (!(other is IFileUpload)) + { + throw new ArgumentException($"Cannot compare {this.DataType} with {other.DataType}"); + } + + return this.CompareTo((IFileUpload)other); + } + + public int CompareTo(IFileUpload other) => FileUploadUtil.CompareTo(this, other); + + public string ContentType + { + get => this.contentType; + set + { + Contract.Requires(value != null); + this.contentType = value; + } + } + + public string ContentTransferEncoding + { + get => this.contentTransferEncoding; + set => this.contentTransferEncoding = value; + } + + public override string ToString() + { + return HttpHeaderNames.ContentDisposition + ": " + + HttpHeaderValues.FormData + "; " + HttpHeaderValues.Name + "=\"" + this.Name + + "\"; " + HttpHeaderValues.FileName + "=\"" + this.FileName + "\"\r\n" + + HttpHeaderNames.ContentType + ": " + this.contentType + + (this.Charset != null ? "; " + HttpHeaderValues.Charset + '=' + this.Charset.WebName + "\r\n" : "\r\n") + + HttpHeaderNames.ContentLength + ": " + this.Length + "\r\n" + + "Completed: " + this.IsCompleted + + "\r\nIsInMemory: " + this.IsInMemory; + } + + public override IByteBufferHolder Copy() => this.Replace(this.Content?.Copy()); + + public override IByteBufferHolder Duplicate() => this.Replace(this.Content?.Duplicate()); + + public override IByteBufferHolder RetainedDuplicate() + { + IByteBuffer content = this.Content; + if (content != null) + { + content = content.RetainedDuplicate(); + bool success = false; + try + { + var duplicate = (IFileUpload)this.Replace(content); + success = true; + return duplicate; + } + finally + { + if (!success) + { + content.Release(); + } + } + } + else + { + return this.Replace(null); + } + } + + public override IByteBufferHolder Replace(IByteBuffer content) + { + var upload = new MemoryFileUpload( + this.Name, this.FileName, this.ContentType, this.contentTransferEncoding, this.Charset, this.Size); + if (content != null) + { + try + { + upload.SetContent(content); + return upload; + } + catch (IOException e) + { + throw new ChannelException(e); + } + } + return upload; + } + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/MixedAttribute.cs b/src/DotNetty.Codecs.Http/Multipart/MixedAttribute.cs new file mode 100644 index 0000000..ed91a74 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/MixedAttribute.cs @@ -0,0 +1,249 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common; + + public class MixedAttribute : IAttribute + { + IAttribute attribute; + + readonly long limitSize; + long maxSize = DefaultHttpDataFactory.MaxSize; + + public MixedAttribute(string name, long limitSize) + : this(name, limitSize, HttpConstants.DefaultEncoding) + { + } + + public MixedAttribute(string name, long definedSize, long limitSize) + : this(name, definedSize, limitSize, HttpConstants.DefaultEncoding) + { + } + + public MixedAttribute(string name, long limitSize, Encoding contentEncoding) + { + this.limitSize = limitSize; + this.attribute = new MemoryAttribute(name, contentEncoding); + } + + public MixedAttribute(string name, long definedSize, long limitSize, Encoding contentEncoding) + { + this.limitSize = limitSize; + this.attribute = new MemoryAttribute(name, definedSize, contentEncoding); + } + + public MixedAttribute(string name, string value, long limitSize) + : this(name, value, limitSize, HttpConstants.DefaultEncoding) + { + } + + public MixedAttribute(string name, string value, long limitSize, Encoding charset) + { + this.limitSize = limitSize; + if (value.Length > this.limitSize) + { + try + { + this.attribute = new DiskAttribute(name, value, charset); + } + catch (IOException e) + { + // revert to Memory mode + try + { + this.attribute = new MemoryAttribute(name, value, charset); + } + catch (IOException) + { + throw new ArgumentException($"{name}", e); + } + } + } + else + { + try + { + this.attribute = new MemoryAttribute(name, value, charset); + } + catch (IOException e) + { + throw new ArgumentException($"{name}", e); + } + } + } + + + public long MaxSize + { + get => this.maxSize; + set + { + this.maxSize = value; + this.attribute.MaxSize = this.maxSize; + } + } + + public void CheckSize(long newSize) + { + if (this.maxSize >= 0 && newSize > this.maxSize) + { + throw new IOException($"Size exceed allowed maximum capacity of {this.maxSize}"); + } + } + + public void AddContent(IByteBuffer buffer, bool last) + { + if (this.attribute is MemoryAttribute memoryAttribute) + { + this.CheckSize(this.attribute.Length + buffer.ReadableBytes); + if (this.attribute.Length + buffer.ReadableBytes > this.limitSize) + { + var diskAttribute = new DiskAttribute(this.attribute.Name, this.attribute.DefinedLength); + diskAttribute.MaxSize = this.maxSize; + if (memoryAttribute.GetByteBuffer() != null) + { + diskAttribute.AddContent(memoryAttribute.GetByteBuffer(), false); + } + this.attribute = diskAttribute; + } + } + this.attribute.AddContent(buffer, last); + } + + public void Delete() => this.attribute.Delete(); + + public byte[] GetBytes() => this.attribute.GetBytes(); + + public IByteBuffer GetByteBuffer() => this.attribute.GetByteBuffer(); + + public Encoding Charset + { + get => this.attribute.Charset; + set => this.attribute.Charset = value; + } + + public string GetString() => this.attribute.GetString(); + + public string GetString(Encoding charset) => this.attribute.GetString(charset); + + public bool IsCompleted => this.attribute.IsCompleted; + + public bool IsInMemory => this.attribute.IsInMemory; + + public long Length => this.attribute.Length; + + public long DefinedLength => this.attribute.DefinedLength; + + public bool RenameTo(FileStream destination) => this.attribute.RenameTo(destination); + + public void SetContent(IByteBuffer buffer) + { + this.CheckSize(buffer.ReadableBytes); + if (buffer.ReadableBytes > this.limitSize) + { + if (this.attribute is MemoryAttribute) + { + // change to Disk + this.attribute = new DiskAttribute(this.attribute.Name, this.attribute.DefinedLength); + this.attribute.MaxSize = this.maxSize; + } + } + this.attribute.SetContent(buffer); + } + + public void SetContent(Stream source) + { + this.CheckSize(source.Length); + if (source.Length > this.limitSize) + { + if (this.attribute is MemoryAttribute) + { + // change to Disk + this.attribute = new DiskAttribute(this.attribute.Name, this.attribute.DefinedLength); + this.attribute.MaxSize = this.maxSize; + } + } + this.attribute.SetContent(source); + } + + public HttpDataType DataType => this.attribute.DataType; + + public string Name => this.attribute.Name; + + // ReSharper disable once NonReadonlyMemberInGetHashCode + public override int GetHashCode() => this.attribute.GetHashCode(); + + public override bool Equals(object obj) => this.attribute.Equals(obj); + + public int CompareTo(IInterfaceHttpData other) => this.attribute.CompareTo(other); + + public override string ToString() => $"Mixed: {this.attribute}"; + + public string Value + { + get => this.attribute.Value; + set + { + if (value != null) + { + byte[] bytes = this.Charset != null + ? this.Charset.GetBytes(value) + : HttpConstants.DefaultEncoding.GetBytes(value); + this.CheckSize(bytes.Length); + } + + this.attribute.Value = value; + } + } + + public IByteBuffer GetChunk(int length) => this.attribute.GetChunk(length); + + public FileStream GetFile() => this.attribute.GetFile(); + + public IByteBufferHolder Copy() => this.attribute.Copy(); + + public IByteBufferHolder Duplicate() => this.attribute.Duplicate(); + + public IByteBufferHolder RetainedDuplicate() => this.attribute.RetainedDuplicate(); + + public IByteBufferHolder Replace(IByteBuffer content) => this.attribute.Replace(content); + + public IByteBuffer Content => this.attribute.Content; + + public int ReferenceCount => this.attribute.ReferenceCount; + + public IReferenceCounted Retain() + { + this.attribute.Retain(); + return this; + } + + public IReferenceCounted Retain(int increment) + { + this.attribute.Retain(increment); + return this; + } + + public IReferenceCounted Touch() + { + this.attribute.Touch(); + return this; + } + + public IReferenceCounted Touch(object hint) + { + this.attribute.Touch(hint); + return this; + } + + public bool Release() => this.attribute.Release(); + + public bool Release(int decrement) => this.attribute.Release(decrement); + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/MixedFileUpload.cs b/src/DotNetty.Codecs.Http/Multipart/MixedFileUpload.cs new file mode 100644 index 0000000..104d10f --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/MixedFileUpload.cs @@ -0,0 +1,229 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common; + + public class MixedFileUpload : IFileUpload + { + IFileUpload fileUpload; + + readonly long limitSize; + readonly long definedSize; + + long maxSize = DefaultHttpDataFactory.MaxSize; + + + public MixedFileUpload(string name, string fileName, string contentType, + string contentTransferEncoding, Encoding charset, long size, + long limitSize) + { + this.limitSize = limitSize; + if (size > this.limitSize) + { + this.fileUpload = new DiskFileUpload(name, fileName, contentType, + contentTransferEncoding, charset, size); + } + else + { + this.fileUpload = new MemoryFileUpload(name, fileName, contentType, + contentTransferEncoding, charset, size); + } + this.definedSize = size; + } + + public long MaxSize + { + get => this.maxSize; + set + { + this.maxSize = value; + this.fileUpload.MaxSize = value; + } + } + + public void CheckSize(long newSize) + { + if (this.maxSize >= 0 && newSize > this.maxSize) + { + throw new IOException($"{this.DataType} Size exceed allowed maximum capacity"); + } + } + + public void AddContent(IByteBuffer buffer, bool last) + { + if (this.fileUpload is MemoryFileUpload) { + this.CheckSize(this.fileUpload.Length + buffer.ReadableBytes); + if (this.fileUpload.Length + buffer.ReadableBytes > this.limitSize) + { + var diskFileUpload = new DiskFileUpload( + this.fileUpload.Name, this.fileUpload.FileName, + this.fileUpload.ContentType, + this.fileUpload.ContentTransferEncoding, this.fileUpload.Charset, + this.definedSize); + diskFileUpload.MaxSize = this.maxSize; + IByteBuffer data = this.fileUpload.GetByteBuffer(); + if (data != null && data.IsReadable()) + { + diskFileUpload.AddContent((IByteBuffer)data.Retain(), false); + } + // release old upload + this.fileUpload.Release(); + + this.fileUpload = diskFileUpload; + } + } + this.fileUpload.AddContent(buffer, last); + } + + public void Delete() => this.fileUpload.Delete(); + + public byte[] GetBytes() => this.fileUpload.GetBytes(); + + public IByteBuffer GetByteBuffer() => this.fileUpload.GetByteBuffer(); + + public Encoding Charset + { + get => this.fileUpload.Charset; + set => this.fileUpload.Charset = value; + } + + public string ContentType + { + get => this.fileUpload.ContentType; + set => this.fileUpload.ContentType = value; + } + + public string ContentTransferEncoding + { + get => this.fileUpload.ContentTransferEncoding; + set => this.fileUpload.ContentTransferEncoding = value; + } + + public string FileName + { + get => this.fileUpload.FileName; + set => this.fileUpload.FileName = value; + } + + public string GetString() => this.fileUpload.GetString(); + + public string GetString(Encoding encoding) => this.fileUpload.GetString(encoding); + + public bool IsCompleted => this.fileUpload.IsCompleted; + + public bool IsInMemory => this.fileUpload.IsInMemory; + + public long Length => this.fileUpload.Length; + + public long DefinedLength => this.fileUpload.DefinedLength; + + public bool RenameTo(FileStream destination) => this.fileUpload.RenameTo(destination); + + public void SetContent(IByteBuffer buffer) + { + this.CheckSize(buffer.ReadableBytes); + if (buffer.ReadableBytes > this.limitSize) + { + if (this.fileUpload is MemoryFileUpload memoryUpload) + { + // change to Disk + this.fileUpload = new DiskFileUpload( + memoryUpload.Name, + memoryUpload.FileName, + memoryUpload.ContentType, + memoryUpload.ContentTransferEncoding, + memoryUpload.Charset, + this.definedSize); + this.fileUpload.MaxSize = this.maxSize; + + // release old upload + memoryUpload.Release(); + } + } + this.fileUpload.SetContent(buffer); + } + + public void SetContent(Stream inputStream) + { + if (this.fileUpload is MemoryFileUpload) + { + IFileUpload memoryUpload = this.fileUpload; + // change to Disk + this.fileUpload = new DiskFileUpload( + this.fileUpload.Name, + this.fileUpload.FileName, + this.fileUpload.ContentType, + this.fileUpload.ContentTransferEncoding, + this.fileUpload.Charset, + this.definedSize); + this.fileUpload.MaxSize = this.maxSize; + + // release old upload + memoryUpload.Release(); + } + this.fileUpload.SetContent(inputStream); + } + + public HttpDataType DataType => this.fileUpload.DataType; + + public string Name => this.fileUpload.Name; + + // ReSharper disable once NonReadonlyMemberInGetHashCode + public override int GetHashCode() => this.fileUpload.GetHashCode(); + + public override bool Equals(object obj) => this.fileUpload.Equals(obj); + + public int CompareTo(IInterfaceHttpData other) => this.fileUpload.CompareTo(other); + + public override string ToString() => $"Mixed: {this.fileUpload}"; + + public IByteBuffer GetChunk(int length) => this.fileUpload.GetChunk(length); + + public FileStream GetFile() => this.fileUpload.GetFile(); + + public IByteBufferHolder Copy() => this.fileUpload.Copy(); + + public IByteBufferHolder Duplicate() => this.fileUpload.Duplicate(); + + public IByteBufferHolder RetainedDuplicate() => this.fileUpload.RetainedDuplicate(); + + public IByteBufferHolder Replace(IByteBuffer content) => this.fileUpload.Replace(content); + + public IByteBuffer Content => this.fileUpload.Content; + + public int ReferenceCount => this.fileUpload.ReferenceCount; + + public IReferenceCounted Retain() + { + this.fileUpload.Retain(); + return this; + } + + public IReferenceCounted Retain(int increment) + { + this.fileUpload.Retain(increment); + return this; + } + + public IReferenceCounted Touch() + { + this.fileUpload.Touch(); + return this; + } + + public IReferenceCounted Touch(object hint) + { + this.fileUpload.Touch(hint); + return this; + } + + public bool Release() => this.fileUpload.Release(); + + public bool Release(int decrement) => this.fileUpload.Release(decrement); + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/MultiPartStatus.cs b/src/DotNetty.Codecs.Http/Multipart/MultiPartStatus.cs new file mode 100644 index 0000000..5b77ea0 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/MultiPartStatus.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + enum MultiPartStatus + { + Notstarted, + Preamble, + HeaderDelimiter, + Disposition, + Field, + Fileupload, + MixedPreamble, + MixedDelimiter, + MixedDisposition, + MixedFileUpload, + MixedCloseDelimiter, + CloseDelimiter, + PreEpilogue, + Epilogue + } +} diff --git a/src/DotNetty.Codecs.Http/Multipart/NotEnoughDataDecoderException.cs b/src/DotNetty.Codecs.Http/Multipart/NotEnoughDataDecoderException.cs new file mode 100644 index 0000000..4a7ec16 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Multipart/NotEnoughDataDecoderException.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Multipart +{ + using System; + + public class NotEnoughDataDecoderException : DecoderException + { + public NotEnoughDataDecoderException(string message) : base(message) + { + } + + public NotEnoughDataDecoderException(Exception innerException) : base(innerException) + { + } + } +} diff --git a/src/DotNetty.Codecs.Http/Properties/AssemblyInfo.cs b/src/DotNetty.Codecs.Http/Properties/AssemblyInfo.cs new file mode 100644 index 0000000..bea1740 --- /dev/null +++ b/src/DotNetty.Codecs.Http/Properties/AssemblyInfo.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Reflection; +using System.Resources; + +[assembly: NeutralResourcesLanguage("en-US")] +[assembly: AssemblyMetadata("Serviceable", "True")] \ No newline at end of file diff --git a/src/DotNetty.Codecs.Http/Properties/Friends.cs b/src/DotNetty.Codecs.Http/Properties/Friends.cs new file mode 100644 index 0000000..742a8ed --- /dev/null +++ b/src/DotNetty.Codecs.Http/Properties/Friends.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("DotNetty.Codecs.Http.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100d9782d5a0b850f230f71e06de2e101d8441d83e15eef715837eee38fdbf5cb369b41ec36e6e7668c18cbb09e5419c179360461e740c1cce6ffbdcf81f245e1e705482797fe42aff2d31ecd72ea87362ded3c14066746fbab4a8e1896f8b982323c84e2c1b08407c0de18b7feef1535fb972a3b26181f5a304ebd181795a46d8f")] +[assembly: InternalsVisibleTo("DotNetty.Microbench, PublicKey=0024000004800000940000000602000000240000525341310004000001000100d9782d5a0b850f230f71e06de2e101d8441d83e15eef715837eee38fdbf5cb369b41ec36e6e7668c18cbb09e5419c179360461e740c1cce6ffbdcf81f245e1e705482797fe42aff2d31ecd72ea87362ded3c14066746fbab4a8e1896f8b982323c84e2c1b08407c0de18b7feef1535fb972a3b26181f5a304ebd181795a46d8f")] diff --git a/src/DotNetty.Codecs.Http/QueryStringDecoder.cs b/src/DotNetty.Codecs.Http/QueryStringDecoder.cs new file mode 100644 index 0000000..40c93cc --- /dev/null +++ b/src/DotNetty.Codecs.Http/QueryStringDecoder.cs @@ -0,0 +1,252 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable ConvertToAutoProperty +namespace DotNetty.Codecs.Http +{ + using System; + using System.Collections.Generic; + using System.Collections.Immutable; + using System.Diagnostics.Contracts; + using System.Text; + using DotNetty.Common.Utilities; + + public class QueryStringDecoder + { + const int DefaultMaxParams = 1024; + + readonly Encoding charset; + readonly string uri; + readonly int maxParams; + int pathEndIdx; + string path; + IDictionary> parameters; + + public QueryStringDecoder(string uri) : this(uri, HttpConstants.DefaultEncoding) + { + } + + public QueryStringDecoder(string uri, bool hasPath) : this(uri, HttpConstants.DefaultEncoding, hasPath) + { + } + + public QueryStringDecoder(string uri, Encoding charset) : this(uri, charset, true) + { + } + + public QueryStringDecoder(string uri, Encoding charset, bool hasPath) : this(uri, charset, hasPath, DefaultMaxParams) + { + } + + public QueryStringDecoder(string uri, Encoding charset, bool hasPath, int maxParams) + { + Contract.Requires(uri != null); + Contract.Requires(charset != null); + Contract.Requires(maxParams > 0); + + this.uri = uri; + this.charset = charset; + this.maxParams = maxParams; + + // -1 means that path end index will be initialized lazily + this.pathEndIdx = hasPath ? -1 : 0; + } + + public QueryStringDecoder(Uri uri) : this(uri, HttpConstants.DefaultEncoding) + { + } + + public QueryStringDecoder(Uri uri, Encoding charset) : this(uri, charset, DefaultMaxParams) + { + } + + public QueryStringDecoder(Uri uri, Encoding charset, int maxParams) + { + Contract.Requires(uri != null); + Contract.Requires(charset != null); + Contract.Requires(maxParams > 0); + + string rawPath = uri.AbsolutePath; + // Also take care of cut of things like "http://localhost" + this.uri = uri.PathAndQuery; + this.charset = charset; + this.maxParams = maxParams; + this.pathEndIdx = rawPath.Length; + } + + public override string ToString() => this.uri; + + public string Path => this.path ?? + (this.path = DecodeComponent(this.uri, 0, this.PathEndIdx(), this.charset, true)); + + public IDictionary> Parameters => this.parameters ?? + (this.parameters = DecodeParams(this.uri, this.PathEndIdx(), this.charset, this.maxParams)); + + public string RawPath() => this.uri.Substring(0, this.PathEndIdx()); + + public string RawQuery() + { + int start = this.pathEndIdx + 1; + return start < this.uri.Length ? this.uri.Substring(start) : StringUtil.EmptyString; + } + + int PathEndIdx() + { + if (this.pathEndIdx == -1) + { + this.pathEndIdx = FindPathEndIndex(this.uri); + } + return this.pathEndIdx; + } + + static IDictionary> DecodeParams(string s, int from, Encoding charset, int paramsLimit) + { + int len = s.Length; + if (from >= len) + { + return ImmutableDictionary>.Empty; + } + if (s[from] == '?') + { + from++; + } + var parameters = new Dictionary>(); + int nameStart = from; + int valueStart = -1; + int i; + //loop: + for (i = from; i < len; i++) + { + switch (s[i]) + { + case '=': + if (nameStart == i) + { + nameStart = i + 1; + } + else if (valueStart < nameStart) + { + valueStart = i + 1; + } + break; + case '&': + case ';': + if (AddParam(s, nameStart, valueStart, i, parameters, charset)) + { + paramsLimit--; + if (paramsLimit == 0) + { + return parameters; + } + } + nameStart = i + 1; + break; + case '#': + goto loop; + } + } + loop: + AddParam(s, nameStart, valueStart, i, parameters, charset); + return parameters; + } + + static bool AddParam(string s, int nameStart, int valueStart, int valueEnd, + Dictionary> parameters, Encoding charset) + { + if (nameStart >= valueEnd) + { + return false; + } + if (valueStart <= nameStart) + { + valueStart = valueEnd + 1; + } + string name = DecodeComponent(s, nameStart, valueStart - 1, charset, false); + string value = DecodeComponent(s, valueStart, valueEnd, charset, false); + if (!parameters.TryGetValue(name, out List values)) + { + values = new List(1); // Often there's only 1 value. + parameters.Add(name, values); + } + values.Add(value); + return true; + } + + public static string DecodeComponent(string s) => DecodeComponent(s, HttpConstants.DefaultEncoding); + + public static string DecodeComponent(string s, Encoding charset) => s == null + ? StringUtil.EmptyString : DecodeComponent(s, 0, s.Length, charset, false); + + static string DecodeComponent(string s, int from, int toExcluded, Encoding charset, bool isPath) + { + int len = toExcluded - from; + if (len <= 0) + { + return StringUtil.EmptyString; + } + int firstEscaped = -1; + for (int i = from; i < toExcluded; i++) + { + char c = s[i]; + if (c == '%' || c == '+' && !isPath) + { + firstEscaped = i; + break; + } + } + if (firstEscaped == -1) + { + return s.Substring(from, len); + } + + // Each encoded byte takes 3 characters (e.g. "%20") + int decodedCapacity = (toExcluded - firstEscaped) / 3; + var byteBuf = new byte[decodedCapacity]; + int idx; + var strBuf = new StringBuilder(len); + strBuf.Append(s, from, firstEscaped - from); + + for (int i = firstEscaped; i < toExcluded; i++) + { + char c = s[i]; + if (c != '%') + { + strBuf.Append(c != '+' || isPath ? c : StringUtil.Space); + continue; + } + + idx = 0; + do + { + if (i + 3 > toExcluded) + { + throw new ArgumentException($"unterminated escape sequence at index {i} of: {s}"); + } + byteBuf[idx++] = StringUtil.DecodeHexByte(s, i + 1); + i += 3; + } + while (i < toExcluded && s[i] == '%'); + i--; + + strBuf.Append(charset.GetString(byteBuf, 0, idx)); + } + + return strBuf.ToString(); + } + + static int FindPathEndIndex(string uri) + { + int len = uri.Length; + for (int i = 0; i < len; i++) + { + char c = uri[i]; + if (c == '?' || c == '#') + { + return i; + } + } + return len; + } + } +} diff --git a/src/DotNetty.Codecs.Http/QueryStringEncoder.cs b/src/DotNetty.Codecs.Http/QueryStringEncoder.cs new file mode 100644 index 0000000..6f35660 --- /dev/null +++ b/src/DotNetty.Codecs.Http/QueryStringEncoder.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System.Diagnostics.Contracts; + using System.Text; + + /// + /// Creates an URL-encoded URI from a path string and key-value parameter pairs. + /// This encoder is for one time use only. Create a new instance for each URI. + /// + /// {@link QueryStringEncoder} encoder = new {@link QueryStringEncoder}("/hello"); + /// encoder.addParam("recipient", "world"); + /// assert encoder.toString().equals("/hello?recipient=world"); + /// + public class QueryStringEncoder + { + readonly Encoding encoding; + readonly StringBuilder uriBuilder; + bool hasParams; + + public QueryStringEncoder(string uri) : this(uri, HttpConstants.DefaultEncoding) + { + } + + public QueryStringEncoder(string uri, Encoding encoding) + { + this.uriBuilder = new StringBuilder(uri); + this.encoding = encoding; + } + + public void AddParam(string name, string value) + { + Contract.Requires(name != null); + if (this.hasParams) + { + this.uriBuilder.Append('&'); + } + else + { + this.uriBuilder.Append('?'); + this.hasParams = true; + } + + AppendComponent(name, this.encoding, this.uriBuilder); + if (value != null) + { + this.uriBuilder.Append('='); + AppendComponent(value, this.encoding, this.uriBuilder); + } + } + + public override string ToString() => this.uriBuilder.ToString(); + + static void AppendComponent(string s, Encoding encoding, StringBuilder sb) + { + s = UrlEncoder.Encode(s, encoding); + // replace all '+' with "%20" + int idx = s.IndexOf('+'); + if (idx == -1) + { + sb.Append(s); + return; + } + sb.Append(s, 0, idx).Append("%20"); + int size = s.Length; + idx++; + for (; idx < size; idx++) + { + char c = s[idx]; + if (c != '+') + { + sb.Append(c); + } + else + { + sb.Append("%20"); + } + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/ThrowHelper.cs b/src/DotNetty.Codecs.Http/ThrowHelper.cs new file mode 100644 index 0000000..55d2ff7 --- /dev/null +++ b/src/DotNetty.Codecs.Http/ThrowHelper.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable NotResolvedInText +namespace DotNetty.Codecs.Http +{ + using System; + using System.Runtime.CompilerServices; + using DotNetty.Common.Utilities; + + static class ThrowHelper + { + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_NullText() + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException("text"); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_EmptyText() + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException("text is empty (possibly HTTP/0.9)"); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_HeaderName() + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException("empty headers are not allowed", "name"); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_TrailingHeaderName(ICharSequence name) + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException(string.Format("prohibited trailing header: {0}", name)); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_HeaderValue(byte value) + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException(string.Format("a header name cannot contain the following prohibited characters: =,;: \\t\\r\\n\\v\\f: {0}", value)); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_HeaderValue(char value) + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException(string.Format("a header name cannot contain the following prohibited characters: =,;: \\t\\r\\n\\v\\f: {0}", value)); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_HeaderValueNonAscii(byte value) + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException(string.Format("a header name cannot contain non-ASCII character: {0}", value)); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_HeaderValueNonAscii(char value) + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException(string.Format("a header name cannot contain non-ASCII character: {0}", value)); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_HeaderValueEnd(ICharSequence seq) + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException(string.Format("a header value must not end with '\\r' or '\\n':{0}", seq)); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_HeaderValueNullChar() + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException("a header value contains a prohibited character '\0'"); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_HeaderValueVerticalTabChar() + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException("a header value contains a prohibited character '\\v'"); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_HeaderValueFormFeed() + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException("a header value contains a prohibited character '\\f'"); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_NewLineAfterLineFeed() + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException("only '\\n' is allowed after '\\r'"); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowArgumentException_TabAndSpaceAfterLineFeed() + { + throw GetArgumentException(); + + ArgumentException GetArgumentException() + { + return new ArgumentException("only ' ' and '\\t' are allowed after '\\n'"); + } + } + } +} diff --git a/src/DotNetty.Codecs.Http/UrlEncoder.cs b/src/DotNetty.Codecs.Http/UrlEncoder.cs new file mode 100644 index 0000000..fc79cc7 --- /dev/null +++ b/src/DotNetty.Codecs.Http/UrlEncoder.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http +{ + using System.Text; + using DotNetty.Common.Utilities; + + // Similar to java URLEncoder + static class UrlEncoder + { + public static unsafe string Encode(string s, Encoding encoding) + { + if (string.IsNullOrEmpty(s)) + { + return s; + } + + int length = s.Length; + fixed (char* p = s) + { + int count = encoding.GetMaxByteCount(1); + int total = length * count * 3; // bytes per char encoded maximum + char* bytes = stackalloc char[total]; + byte* buf = stackalloc byte[count]; + + int index = 0; + for (int i = 0; i < length; i++) + { + char ch = *(p + i); + if ((ch >= 'a' && ch <= 'z') + || (ch >= 'A' && ch <= 'Z') + || (ch >= '0' && ch <= '9')) + { + bytes[index++] = ch; + } + else + { + total = encoding.GetBytes(p + i, 1, buf, count); + for (int j = 0; j < total; j++) + { + bytes[index++] = '%'; + bytes[index++] = CharUtil.Digits[(buf[j] & 0xf0) >> 4]; + bytes[index++] = CharUtil.Digits[buf[j] & 0xf]; + } + } + } + return new string(bytes, 0, index); + } + } + } +} diff --git a/src/DotNetty.Codecs/CharSequenceValueConverter.cs b/src/DotNetty.Codecs/CharSequenceValueConverter.cs new file mode 100644 index 0000000..0cfffcd --- /dev/null +++ b/src/DotNetty.Codecs/CharSequenceValueConverter.cs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + using System; + using System.Globalization; + using DotNetty.Common.Utilities; + + public class CharSequenceValueConverter : IValueConverter + { + public static readonly CharSequenceValueConverter Default = new CharSequenceValueConverter(); + static readonly AsciiString TrueAscii = new AsciiString("true"); + + public virtual ICharSequence ConvertObject(object value) + { + if (value is ICharSequence sequence) + { + return sequence; + } + return new StringCharSequence(value.ToString()); + } + + public ICharSequence ConvertInt(int value) => new StringCharSequence(value.ToString()); + + public ICharSequence ConvertLong(long value) => new StringCharSequence(value.ToString()); + + public ICharSequence ConvertDouble(double value) => new StringCharSequence(value.ToString(CultureInfo.InvariantCulture)); + + public ICharSequence ConvertChar(char value) => new StringCharSequence(value.ToString()); + + public ICharSequence ConvertBoolean(bool value) => new StringCharSequence(value.ToString()); + + public ICharSequence ConvertFloat(float value) => new StringCharSequence(value.ToString(CultureInfo.InvariantCulture)); + + public bool ConvertToBoolean(ICharSequence value) => AsciiString.ContentEqualsIgnoreCase(value, TrueAscii); + + public ICharSequence ConvertByte(byte value) => new StringCharSequence(value.ToString()); + + public byte ConvertToByte(ICharSequence value) + { + if (value is AsciiString asciiString) + { + return asciiString.ByteAt(0); + } + return byte.Parse(value.ToString()); + } + + public char ConvertToChar(ICharSequence value) => value[0]; + + public ICharSequence ConvertShort(short value) => new StringCharSequence(value.ToString()); + + public short ConvertToShort(ICharSequence value) + { + if (value is AsciiString asciiString) + { + return asciiString.ParseShort(); + } + return short.Parse(value.ToString()); + } + + public int ConvertToInt(ICharSequence value) + { + if (value is AsciiString asciiString) + { + return asciiString.ParseInt(); + } + return int.Parse(value.ToString()); + } + + public long ConvertToLong(ICharSequence value) + { + if (value is AsciiString asciiString) + { + return asciiString.ParseLong(); + } + return long.Parse(value.ToString()); + } + + public ICharSequence ConvertTimeMillis(long value) => new StringCharSequence(value.ToString()); + + public long ConvertToTimeMillis(ICharSequence value) + { + DateTime? dateTime = DateFormatter.ParseHttpDate(value); + if (dateTime == null) + { + throw new FormatException($"header can't be parsed into a Date: {value}"); + } + return dateTime.Value.Ticks / TimeSpan.TicksPerMillisecond; + } + + public float ConvertToFloat(ICharSequence value) + { + if (value is AsciiString asciiString) + { + return asciiString.ParseFloat(); + } + return float.Parse(value.ToString(), NumberStyles.Any, CultureInfo.InvariantCulture); + } + + + public double ConvertToDouble(ICharSequence value) + { + if (value is AsciiString asciiString) + { + return asciiString.ParseDouble(); + } + return double.Parse(value.ToString(), NumberStyles.Any, CultureInfo.InvariantCulture); + } + } +} diff --git a/src/DotNetty.Codecs/Compression/Adler32.cs b/src/DotNetty.Codecs/Compression/Adler32.cs new file mode 100644 index 0000000..ed77905 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/Adler32.cs @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/Adler32.java + /// + sealed class Adler32 : IChecksum + { + // largest prime smaller than 65536 + const int BASE = 65521; + // NMAX is the largest n such that 255n(n+1)/2 + (n+1)(BASE-1) <= 2^32-1 + const int NMAX = 5552; + + long s1 = 1L; + long s2; + + public void Reset(long init) + { + s1 = init & 0xffff; + s2 = (init >> 16) & 0xffff; + } + + public void Reset() + { + s1 = 1L; + s2 = 0L; + } + + public long GetValue() => ((s2 << 16) | s1); + + public void Update(byte[] buf, int index, int len) + { + if (len == 1) + { + s1 += buf[index] & 0xff; s2 += s1; + s1 %= BASE; + s2 %= BASE; + return; + } + + int len1 = len / NMAX; + int len2 = len % NMAX; + while (len1-- > 0) + { + int k = NMAX; + len -= k; + while (k-- > 0) + { + s1 += buf[index++] & 0xff; s2 += s1; + } + s1 %= BASE; + s2 %= BASE; + } + + int k0 = len2; + while (k0-- > 0) + { + s1 += buf[index++] & 0xff; s2 += s1; + } + s1 %= BASE; + s2 %= BASE; + } + + public IChecksum Copy() + { + var foo = new Adler32(); + foo.s1 = this.s1; + foo.s2 = this.s2; + return foo; + } + + // The following logic has come from zlib.1.2. + internal static long Combine(long adler1, long adler2, long len2) + { + long BASEL = BASE; + long sum1; + long sum2; + long rem; // unsigned int + + rem = len2 % BASEL; + sum1 = adler1 & 0xffffL; + sum2 = rem * sum1; + sum2 %= BASEL; // MOD(sum2); + sum1 += (adler2 & 0xffffL) + BASEL - 1; + sum2 += ((adler1 >> 16) & 0xffffL) + ((adler2 >> 16) & 0xffffL) + BASEL - rem; + if (sum1 >= BASEL) sum1 -= BASEL; + if (sum1 >= BASEL) sum1 -= BASEL; + if (sum2 >= (BASEL << 1)) sum2 -= (BASEL << 1); + if (sum2 >= BASEL) sum2 -= BASEL; + return sum1 | (sum2 << 16); + } + } +} diff --git a/src/DotNetty.Codecs/Compression/CRC32.cs b/src/DotNetty.Codecs/Compression/CRC32.cs new file mode 100644 index 0000000..3ad39de --- /dev/null +++ b/src/DotNetty.Codecs/Compression/CRC32.cs @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + using System; + using DotNetty.Common.Utilities; + + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/CRC32.java + /// + sealed class CRC32 : IChecksum + { + /* + * The following logic has come from RFC1952. + */ + int v; + static readonly int[] crc_table; + + static CRC32() + { + crc_table = new int[256]; + for (int n = 0; n < 256; n++) + { + int c = n; + for (int k = 8; --k >= 0;) + { + if ((c & 1) != 0) + c = (int)(0xedb88320 ^ (c.RightUShift(1))); + else + c = c.RightUShift(1); + } + crc_table[n] = c; + } + } + + public void Update(byte[] buf, int index, int len) + { + int c = ~v; + while (--len >= 0) + c = crc_table[(c ^ buf[index++]) & 0xff] ^ (c.RightUShift(8)); + v = ~c; + } + + public void Reset() => v = 0; + + public void Reset(long vv) => v = (int)(vv & 0xffffffffL); + + public long GetValue() => v & 0xffffffffL; + + // The following logic has come from zlib.1.2. + static readonly int GF2_DIM = 32; + + internal static long Combine(long crc1, long crc2, long len2) + { + long row; + var even = new long[GF2_DIM]; + var odd = new long[GF2_DIM]; + + // degenerate case (also disallow negative lengths) + if (len2 <= 0) + return crc1; + + // put operator for one zero bit in odd + odd[0] = 0xedb88320L; // CRC-32 polynomial + row = 1; + for (int n = 1; n < GF2_DIM; n++) + { + odd[n] = row; + row <<= 1; + } + + // put operator for two zero bits in even + gf2_matrix_square(even, odd); + + // put operator for four zero bits in odd + gf2_matrix_square(odd, even); + + // apply len2 zeros to crc1 (first square will put the operator for one + // zero byte, eight zero bits, in even) + do + { + // apply zeros operator for this bit of len2 + gf2_matrix_square(even, odd); + if ((len2 & 1) != 0) + crc1 = gf2_matrix_times(even, crc1); + len2 >>= 1; + + // if no more bits set, then done + if (len2 == 0) + break; + + // another iteration of the loop with odd and even swapped + gf2_matrix_square(odd, even); + if ((len2 & 1) != 0) + crc1 = gf2_matrix_times(odd, crc1); + len2 >>= 1; + + // if no more bits set, then done + } + while (len2 != 0); + + /* return combined crc */ + crc1 ^= crc2; + return crc1; + } + + static long gf2_matrix_times(long[] mat, long vec) + { + long sum = 0; + int index = 0; + while (vec != 0) + { + if ((vec & 1) != 0) + sum ^= mat[index]; + vec >>= 1; + index++; + } + return sum; + } + + static void gf2_matrix_square(long[] square, long[] mat) + { + for (int n = 0; n < GF2_DIM; n++) + square[n] = gf2_matrix_times(mat, mat[n]); + } + + /* + private java.util.zip.CRC32 crc32 = new java.util.zip.CRC32(); + public void update(byte[] buf, int index, int len){ + if(buf==null) {crc32.reset();} + else{crc32.update(buf, index, len);} + } + public void reset(){ + crc32.reset(); + } + public void reset(long init){ + if(init==0L){ + crc32.reset(); + } + else{ + System.err.println("unsupported operation"); + } + } + public long getValue(){ + return crc32.getValue(); + } + */ + + public IChecksum Copy() + { + var foo = new CRC32(); + foo.v = this.v; + return foo; + } + + public static int[] getCRC32Table() + { + var tmp = new int[crc_table.Length]; + Array.Copy(crc_table, 0, tmp, 0, tmp.Length); + return tmp; + } + } +} diff --git a/src/DotNetty.Codecs/Compression/CompressionException.cs b/src/DotNetty.Codecs/Compression/CompressionException.cs new file mode 100644 index 0000000..7e7663f --- /dev/null +++ b/src/DotNetty.Codecs/Compression/CompressionException.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Compression +{ + using System; + + public class CompressionException : EncoderException + { + public CompressionException(string message) + : base(message) + { + } + + public CompressionException(string message, Exception exception) + : base(message, exception) + { + } + } +} diff --git a/src/DotNetty.Codecs/Compression/DecompressionException.cs b/src/DotNetty.Codecs/Compression/DecompressionException.cs new file mode 100644 index 0000000..bafd952 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/DecompressionException.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Compression +{ + using System; + + public class DecompressionException : DecoderException + { + public DecompressionException(string message) + : base(message) + { + } + + public DecompressionException(string message, Exception exception) + : base(message, exception) + { + } + } +} diff --git a/src/DotNetty.Codecs/Compression/Deflate.cs b/src/DotNetty.Codecs/Compression/Deflate.cs new file mode 100644 index 0000000..3bda831 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/Deflate.cs @@ -0,0 +1,1943 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + using System; + using DotNetty.Common.Utilities; + + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/Deflate.java + /// + sealed class Deflate + { + const int MAX_MEM_LEVEL = 9; + + const int Z_DEFAULT_COMPRESSION = -1; + + const int MAX_WBITS = 15; // 32K LZ77 window + const int DEF_MEM_LEVEL = 8; + + class Config + { + internal readonly int good_length; // reduce lazy search above this match length + internal readonly int max_lazy; // do not perform lazy search above this match length + internal readonly int nice_length; // quit search above this match length + internal readonly int max_chain; + internal readonly int func; + + internal Config( + int good_length, int max_lazy, + int nice_length, int max_chain, int func) + { + this.good_length = good_length; + this.max_lazy = max_lazy; + this.nice_length = nice_length; + this.max_chain = max_chain; + this.func = func; + } + } + + const int STORED = 0; + const int FAST = 1; + const int SLOW = 2; + + static readonly Config[] config_table; + static Deflate() + { + config_table = new Config[10]; + // good lazy nice chain + config_table[0] = new Config(0, 0, 0, 0, STORED); + config_table[1] = new Config(4, 4, 8, 4, FAST); + config_table[2] = new Config(4, 5, 16, 8, FAST); + config_table[3] = new Config(4, 6, 32, 32, FAST); + + config_table[4] = new Config(4, 4, 16, 16, SLOW); + config_table[5] = new Config(8, 16, 32, 32, SLOW); + config_table[6] = new Config(8, 16, 128, 128, SLOW); + config_table[7] = new Config(8, 32, 128, 256, SLOW); + config_table[8] = new Config(32, 128, 258, 1024, SLOW); + config_table[9] = new Config(32, 258, 258, 4096, SLOW); + } + + static readonly string[] z_errmsg = + { + "need dictionary", // Z_NEED_DICT 2 + "stream end", // Z_STREAM_END 1 + "", // Z_OK 0 + "file error", // Z_ERRNO (-1) + "stream error", // Z_STREAM_ERROR (-2) + "data error", // Z_DATA_ERROR (-3) + "insufficient memory", // Z_MEM_ERROR (-4) + "buffer error", // Z_BUF_ERROR (-5) + "incompatible version", // Z_VERSION_ERROR (-6) + "" + }; + + // block not completed, need more input or more output + const int NeedMore = 0; + + // block flush performed + const int BlockDone = 1; + + // finish started, need only more output at next deflate + const int FinishStarted = 2; + + // finish done, accept no more input or output + const int FinishDone = 3; + + // preset dictionary flag in zlib header + const int PRESET_DICT = 0x20; + + const int Z_FILTERED = 1; + const int Z_HUFFMAN_ONLY = 2; + const int Z_DEFAULT_STRATEGY = 0; + + const int Z_NO_FLUSH = 0; + const int Z_PARTIAL_FLUSH = 1; + //const int Z_SYNC_FLUSH = 2; + const int Z_FULL_FLUSH = 3; + const int Z_FINISH = 4; + + const int Z_OK = 0; + const int Z_STREAM_END = 1; + const int Z_NEED_DICT = 2; + //const int Z_ERRNO = -1; + const int Z_STREAM_ERROR = -2; + const int Z_DATA_ERROR = -3; + //const int Z_MEM_ERROR = -4; + const int Z_BUF_ERROR = -5; + //const int Z_VERSION_ERROR = -6; + + const int INIT_STATE = 42; + const int BUSY_STATE = 113; + const int FINISH_STATE = 666; + + // The deflate compression method + const int Z_DEFLATED = 8; + + const int STORED_BLOCK = 0; + const int STATIC_TREES = 1; + const int DYN_TREES = 2; + + // The three kinds of block type + const int Z_BINARY = 0; + const int Z_ASCII = 1; + const int Z_UNKNOWN = 2; + + const int Buf_size = 8 * 2; + + // repeat previous bit length 3-6 times (2 bits of repeat count) + const int REP_3_6 = 16; + + // repeat a zero length 3-10 times (3 bits of repeat count) + const int REPZ_3_10 = 17; + + // repeat a zero length 11-138 times (7 bits of repeat count) + const int REPZ_11_138 = 18; + + const int MIN_MATCH = 3; + const int MAX_MATCH = 258; + const int MIN_LOOKAHEAD = (MAX_MATCH + MIN_MATCH + 1); + + const int MAX_BITS = 15; + const int D_CODES = 30; + const int BL_CODES = 19; + const int LENGTH_CODES = 29; + const int LITERALS = 256; + const int L_CODES = (LITERALS + 1 + LENGTH_CODES); + const int HEAP_SIZE = (2 * L_CODES + 1); + + const int END_BLOCK = 256; + + ZStream strm; // pointer back to this zlib stream + int status; // as the name implies + internal byte[] pending_buf; // output still pending + int pending_buf_size; // size of pending_buf + internal int pending_out; // next pending byte to output to the stream + internal int pending; // nb of bytes in the pending buffer + internal int wrap = 1; + byte data_type; // UNKNOWN, BINARY or ASCII + byte method; // STORED (for zip only) or DEFLATED + int last_flush; // value of flush param for previous deflate call + + int w_size; // LZ77 window size (32K by default) + int w_bits; // log2(w_size) (8..16) + int w_mask; // w_size - 1 + + byte[] window; + // Sliding window. Input bytes are read into the second half of the window, + // and move to the first half later to keep a dictionary of at least wSize + // bytes. With this organization, matches are limited to a distance of + // wSize-MAX_MATCH bytes, but this ensures that IO is always + // performed with a length multiple of the block size. Also, it limits + // the window size to 64K, which is quite useful on MSDOS. + // To do: use the user input buffer as sliding window. + + int window_size; + // Actual size of window: 2*wSize, except when the user input buffer + // is directly used as sliding window. + + short[] prev; + // Link to older string with same hash index. To limit the size of this + // array to 64K, this link is maintained only for the last 32K strings. + // An index in this array is thus a window index modulo 32K. + + short[] head; // Heads of the hash chains or NIL. + + int ins_h; // hash index of string to be inserted + int hash_size; // number of elements in hash table + int hash_bits; // log2(hash_size) + int hash_mask; // hash_size-1 + + // Number of bits by which ins_h must be shifted at each input + // step. It must be such that after MIN_MATCH steps, the oldest + // byte no longer takes part in the hash key, that is: + // hash_shift * MIN_MATCH >= hash_bits + int hash_shift; + + // Window position at the beginning of the current output block. Gets + // negative when the window is moved backwards. + + int block_start; + + int match_length; // length of best match + int prev_match; // previous match + int match_available; // set if previous match exists + int strstart; // start of string to insert + int match_start; // start of matching string + int lookahead; // number of valid bytes ahead in window + + // Length of the best match at previous step. Matches not greater than this + // are discarded. This is used in the lazy match evaluation. + int prev_length; + + // To speed up deflation, hash chains are never searched beyond this + // length. A higher limit improves compression ratio but degrades the speed. + int max_chain_length; + + // Attempt to find a better match only when the current match is strictly + // smaller than this value. This mechanism is used only for compression + // levels >= 4. + int max_lazy_match; + + // Insert new strings in the hash table only if the match length is not + // greater than this length. This saves time but degrades compression. + // max_insert_length is used only for compression levels <= 3. + + internal int level; // compression level (1..9) + int strategy; // favor or force Huffman coding + + // Use a faster search when the previous match is longer than this + int good_match; + + // Stop searching when current match exceeds this + int nice_match; + + short[] dyn_ltree; // literal and length tree + short[] dyn_dtree; // distance tree + short[] bl_tree; // Huffman tree for bit lengths + + Tree l_desc = new Tree(); // desc for literal tree + Tree d_desc = new Tree(); // desc for distance tree + Tree bl_desc = new Tree(); // desc for bit length tree + + // number of codes at each bit length for an optimal tree + internal short[] bl_count = new short[MAX_BITS + 1]; + // working area to be used in Tree#gen_codes() + internal short[] next_code = new short[MAX_BITS + 1]; + + // heap used to build the Huffman trees + internal int[] heap = new int[2 * L_CODES + 1]; + + internal int heap_len; // number of elements in the heap + internal int heap_max; // element of largest frequency + // The sons of heap[n] are heap[2*n] and heap[2*n+1]. heap[0] is not used. + // The same heap array is used to build all trees. + + // Depth of each subtree used as tie breaker for trees of equal frequency + internal byte[] depth = new byte[2 * L_CODES + 1]; + + byte[] l_buf; // index for literals or lengths */ + + // Size of match buffer for literals/lengths. There are 4 reasons for + // limiting lit_bufsize to 64K: + // - frequencies can be kept in 16 bit counters + // - if compression is not successful for the first block, all input + // data is still in the window so we can still emit a stored block even + // when input comes from standard input. (This can also be done for + // all blocks if lit_bufsize is not greater than 32K.) + // - if compression is not successful for a file smaller than 64K, we can + // even emit a stored file instead of a stored block (saving 5 bytes). + // This is applicable only for zip (not gzip or zlib). + // - creating new Huffman trees less frequently may not provide fast + // adaptation to changes in the input data statistics. (Take for + // example a binary file with poorly compressible code followed by + // a highly compressible string table.) Smaller buffer sizes give + // fast adaptation but have of course the overhead of transmitting + // trees more frequently. + // - I can't count above 4 + int lit_bufsize; + + int last_lit; // running index in l_buf + + // Buffer for distances. To simplify the code, d_buf and l_buf have + // the same number of elements. To use different lengths, an extra flag + // array would be necessary. + + int d_buf; // index of pendig_buf + + internal int opt_len; // bit length of current block with optimal trees + internal int static_len; // bit length of current block with static trees + int matches; // number of string matches in current block + int last_eob_len; // bit length of EOB code for last block + + // Output buffer. bits are inserted starting at the bottom (least + // significant bits). + short bi_buf; + + // Number of valid bits in bi_buf. All bits above the last valid bit + // are always zero. + int bi_valid; + + GZIPHeader gheader = null; + + internal Deflate(ZStream strm) + { + this.strm = strm; + dyn_ltree = new short[HEAP_SIZE * 2]; + dyn_dtree = new short[(2 * D_CODES + 1) * 2]; // distance tree + bl_tree = new short[(2 * BL_CODES + 1) * 2]; // Huffman tree for bit lengths + } + + void Lm_init() + { + window_size = 2 * w_size; + + head[hash_size - 1] = 0; + for (int i = 0; i < hash_size - 1; i++) + { + head[i] = 0; + } + + // Set the default configuration parameters: + max_lazy_match = config_table[level].max_lazy; + good_match = config_table[level].good_length; + nice_match = config_table[level].nice_length; + max_chain_length = config_table[level].max_chain; + + strstart = 0; + block_start = 0; + lookahead = 0; + match_length = prev_length = MIN_MATCH - 1; + match_available = 0; + ins_h = 0; + } + + // Initialize the tree data structures for a new zlib stream. + void Tr_init() + { + + l_desc.dyn_tree = dyn_ltree; + l_desc.stat_desc = StaticTree.static_l_desc; + + d_desc.dyn_tree = dyn_dtree; + d_desc.stat_desc = StaticTree.static_d_desc; + + bl_desc.dyn_tree = bl_tree; + bl_desc.stat_desc = StaticTree.static_bl_desc; + + bi_buf = 0; + bi_valid = 0; + last_eob_len = 8; // enough lookahead for inflate + + // Initialize the first block of the first file: + Init_block(); + } + + void Init_block() + { + // Initialize the trees. + for (int i = 0; i < L_CODES; i++) dyn_ltree[i * 2] = 0; + for (int i = 0; i < D_CODES; i++) dyn_dtree[i * 2] = 0; + for (int i = 0; i < BL_CODES; i++) bl_tree[i * 2] = 0; + + dyn_ltree[END_BLOCK * 2] = 1; + opt_len = static_len = 0; + last_lit = matches = 0; + } + + // Restore the heap property by moving down the tree starting at node k, + // exchanging a node with the smallest of its two sons if necessary, stopping + // when the heap property is re-established (each father smaller than its + // two sons). + internal void Pqdownheap( + short[] tree, // the tree to restore + int k // node to move down + ) + { + int v = heap[k]; + int j = k << 1; // left son of k + while (j <= heap_len) + { + // Set j to the smallest of the two sons: + if (j < heap_len && + Smaller(tree, heap[j + 1], heap[j], depth)) + { + j++; + } + // Exit if v is smaller than both sons + if (Smaller(tree, v, heap[j], depth)) break; + + // Exchange v with the smallest son + heap[k] = heap[j]; k = j; + // And continue down the tree, setting j to the left son of k + j <<= 1; + } + heap[k] = v; + } + + static bool Smaller(short[] tree, int n, int m, byte[] depth) + { + short tn2 = tree[n * 2]; + short tm2 = tree[m * 2]; + return (tn2 < tm2 || + (tn2 == tm2 && depth[n] <= depth[m])); + } + + // Scan a literal or distance tree to determine the frequencies of the codes + // in the bit length tree. + void Scan_tree( + short[] tree,// the tree to be scanned + int max_code // and its largest code of non zero frequency + ) + { + int n; // iterates over all tree elements + int prevlen = -1; // last emitted length + int curlen; // length of current code + int nextlen = tree[0 * 2 + 1]; // length of next code + int count = 0; // repeat count of the current code + int max_count = 7; // max repeat count + int min_count = 4; // min repeat count + + if (nextlen == 0) { max_count = 138; min_count = 3; } + tree[(max_code + 1) * 2 + 1] = unchecked((short)0xffff); // guard + + for (n = 0; n <= max_code; n++) + { + curlen = nextlen; nextlen = tree[(n + 1) * 2 + 1]; + if (++count < max_count && curlen == nextlen) + { + continue; + } + else if (count < min_count) + { + bl_tree[curlen * 2] += (short)count; + } + else if (curlen != 0) + { + if (curlen != prevlen) bl_tree[curlen * 2]++; + bl_tree[REP_3_6 * 2]++; + } + else if (count <= 10) + { + bl_tree[REPZ_3_10 * 2]++; + } + else + { + bl_tree[REPZ_11_138 * 2]++; + } + count = 0; prevlen = curlen; + if (nextlen == 0) + { + max_count = 138; min_count = 3; + } + else if (curlen == nextlen) + { + max_count = 6; min_count = 3; + } + else + { + max_count = 7; min_count = 4; + } + } + } + // Construct the Huffman tree for the bit lengths and return the index in + // bl_order of the last bit length code to send. + int Build_bl_tree() + { + int max_blindex; // index of last bit length code of non zero freq + + // Determine the bit length frequencies for literal and distance trees + Scan_tree(dyn_ltree, l_desc.max_code); + Scan_tree(dyn_dtree, d_desc.max_code); + + // Build the bit length tree: + bl_desc.Build_tree(this); + // opt_len now includes the length of the tree representations, except + // the lengths of the bit lengths codes and the 5+5+4 bits for the counts. + + // Determine the number of bit length codes to send. The pkzip format + // requires that at least 4 bit length codes be sent. (appnote.txt says + // 3 but the actual value used is 4.) + for (max_blindex = BL_CODES - 1; max_blindex >= 3; max_blindex--) + { + if (bl_tree[Tree.bl_order[max_blindex] * 2 + 1] != 0) break; + } + // Update opt_len to include the bit length tree and counts + opt_len += 3 * (max_blindex + 1) + 5 + 5 + 4; + + return max_blindex; + } + + // Send the header for a block using dynamic Huffman trees: the counts, the + // lengths of the bit length codes, the literal tree and the distance tree. + // IN assertion: lcodes >= 257, dcodes >= 1, blcodes >= 4. + void Send_all_trees(int lcodes, int dcodes, int blcodes) + { + int rank; // index in bl_order + + Send_bits(lcodes - 257, 5); // not +255 as stated in appnote.txt + Send_bits(dcodes - 1, 5); + Send_bits(blcodes - 4, 4); // not -3 as stated in appnote.txt + for (rank = 0; rank < blcodes; rank++) + { + Send_bits(bl_tree[Tree.bl_order[rank] * 2 + 1], 3); + } + Send_tree(dyn_ltree, lcodes - 1); // literal tree + Send_tree(dyn_dtree, dcodes - 1); // distance tree + } + + // Send a literal or distance tree in compressed form, using the codes in + // bl_tree. + void Send_tree( + short[] tree,// the tree to be sent + int max_code // and its largest code of non zero frequency + ) + { + int n; // iterates over all tree elements + int prevlen = -1; // last emitted length + int curlen; // length of current code + int nextlen = tree[0 * 2 + 1]; // length of next code + int count = 0; // repeat count of the current code + int max_count = 7; // max repeat count + int min_count = 4; // min repeat count + + if (nextlen == 0) { max_count = 138; min_count = 3; } + + for (n = 0; n <= max_code; n++) + { + curlen = nextlen; nextlen = tree[(n + 1) * 2 + 1]; + if (++count < max_count && curlen == nextlen) + { + continue; + } + else if (count < min_count) + { + do { Send_code(curlen, bl_tree); } while (--count != 0); + } + else if (curlen != 0) + { + if (curlen != prevlen) + { + Send_code(curlen, bl_tree); count--; + } + Send_code(REP_3_6, bl_tree); + Send_bits(count - 3, 2); + } + else if (count <= 10) + { + Send_code(REPZ_3_10, bl_tree); + Send_bits(count - 3, 3); + } + else + { + Send_code(REPZ_11_138, bl_tree); + Send_bits(count - 11, 7); + } + count = 0; prevlen = curlen; + if (nextlen == 0) + { + max_count = 138; min_count = 3; + } + else if (curlen == nextlen) + { + max_count = 6; min_count = 3; + } + else + { + max_count = 7; min_count = 4; + } + } + } + + // Output a byte on the stream. + // IN assertion: there is enough room in pending_buf. + internal void Put_byte(byte[] p, int start, int len) + { + Array.Copy(p, start, pending_buf, pending, len); + pending += len; + } + + internal void Put_byte(byte c) => pending_buf[this.pending++] = c; + + internal void Put_short(int w) + { + Put_byte((byte)(w/*&0xff*/)); + Put_byte((byte)(w.RightUShift(8))); + } + + void PutShortMSB(int b) + { + Put_byte((byte)(b >> 8)); + Put_byte((byte)(b/*&0xff*/)); + } + + void Send_code(int c, short[] tree) + { + int c2 = c * 2; + Send_bits((tree[c2] & 0xffff), (tree[c2 + 1] & 0xffff)); + } + + void Send_bits(int value, int length) + { + int len = length; + if (bi_valid > (int)Buf_size - len) + { + int val = value; + // bi_buf |= (val << bi_valid); + bi_buf |= (short)((val << bi_valid) & 0xffff); + Put_short(bi_buf); + bi_buf = (short)(val.RightUShift(Buf_size - bi_valid)); + bi_valid += len - Buf_size; + } + else + { + // bi_buf |= (value) << bi_valid; + bi_buf |= (short)(((value) << bi_valid) & 0xffff); + bi_valid += len; + } + } + + // Send one empty static block to give enough lookahead for inflate. + // This takes 10 bits, of which 7 may remain in the bit buffer. + // The current inflate code requires 9 bits of lookahead. If the + // last two codes for the previous block (real code plus EOB) were coded + // on 5 bits or less, inflate may have only 5+3 bits of lookahead to decode + // the last real code. In this case we send two empty static blocks instead + // of one. (There are no problems if the previous block is stored or fixed.) + // To simplify the code, we assume the worst case of last real code encoded + // on one bit only. + void _tr_align() + { + Send_bits(STATIC_TREES << 1, 3); + Send_code(END_BLOCK, StaticTree.static_ltree); + + Bi_flush(); + + // Of the 10 bits for the empty block, we have already sent + // (10 - bi_valid) bits. The lookahead for the last real code (before + // the EOB of the previous block) was thus at least one plus the length + // of the EOB plus what we have just sent of the empty static block. + if (1 + last_eob_len + 10 - bi_valid < 9) + { + Send_bits(STATIC_TREES << 1, 3); + Send_code(END_BLOCK, StaticTree.static_ltree); + Bi_flush(); + } + last_eob_len = 7; + } + + // Save the match info and tally the frequency counts. Return true if + // the current block must be flushed. + bool _tr_tally(int dist, // distance of matched string + int lc // match length-MIN_MATCH or unmatched char (if dist==0) + ) + { + + pending_buf[d_buf + last_lit * 2] = (byte)(dist.RightUShift(8)); + pending_buf[d_buf + last_lit * 2 + 1] = (byte)dist; + + l_buf[last_lit] = (byte)lc; last_lit++; + + if (dist == 0) + { + // lc is the unmatched char + dyn_ltree[lc * 2]++; + } + else + { + matches++; + // Here, lc is the match length - MIN_MATCH + dist--; // dist = match distance - 1 + dyn_ltree[(Tree._length_code[lc] + LITERALS + 1) * 2]++; + dyn_dtree[Tree.D_code(dist) * 2]++; + } + + if ((last_lit & 0x1fff) == 0 && level > 2) + { + // Compute an upper bound for the compressed length + int out_length = last_lit * 8; + int in_length = strstart - block_start; + int dcode; + for (dcode = 0; dcode < D_CODES; dcode++) + { + out_length += (int)(dyn_dtree[dcode * 2] * + (5L + Tree.extra_dbits[dcode])); + } + out_length = out_length.RightUShift(3); + if ((matches < (last_lit / 2)) && out_length < in_length / 2) return true; + } + + return (last_lit == lit_bufsize - 1); + // We avoid equality with lit_bufsize because of wraparound at 64K + // on 16 bit machines and because stored blocks are restricted to + // 64K-1 bytes. + } + + // Send the block data compressed using the given Huffman trees + void Compress_block(short[] ltree, short[] dtree) + { + int dist; // distance of matched string + int lc; // match length or unmatched char (if dist == 0) + int lx = 0; // running index in l_buf + int code; // the code to send + int extra; // number of extra bits to send + + if (last_lit != 0) + { + do + { + dist = ((pending_buf[d_buf + lx * 2] << 8) & 0xff00) | + (pending_buf[d_buf + lx * 2 + 1] & 0xff); + lc = (l_buf[lx]) & 0xff; lx++; + + if (dist == 0) + { + Send_code(lc, ltree); // send a literal byte + } + else + { + // Here, lc is the match length - MIN_MATCH + code = Tree._length_code[lc]; + + Send_code(code + LITERALS + 1, ltree); // send the length code + extra = Tree.extra_lbits[code]; + if (extra != 0) + { + lc -= Tree.base_length[code]; + Send_bits(lc, extra); // send the extra length bits + } + dist--; // dist is now the match distance - 1 + code = Tree.D_code(dist); + + Send_code(code, dtree); // send the distance code + extra = Tree.extra_dbits[code]; + if (extra != 0) + { + dist -= Tree.base_dist[code]; + Send_bits(dist, extra); // send the extra distance bits + } + } // literal or match pair ? + + // Check that the overlay between pending_buf and d_buf+l_buf is ok: + } + while (lx < last_lit); + } + + Send_code(END_BLOCK, ltree); + last_eob_len = ltree[END_BLOCK * 2 + 1]; + } + + // Set the data type to ASCII or BINARY, using a crude approximation: + // binary if more than 20% of the bytes are <= 6 or >= 128, ascii otherwise. + // IN assertion: the fields freq of dyn_ltree are set and the total of all + // frequencies does not exceed 64K (to fit in an int on 16 bit machines). + void Set_data_type() + { + int n = 0; + int ascii_freq = 0; + int bin_freq = 0; + while (n < 7) { bin_freq += dyn_ltree[n * 2]; n++; } + while (n < 128) { ascii_freq += dyn_ltree[n * 2]; n++; } + while (n < LITERALS) { bin_freq += dyn_ltree[n * 2]; n++; } + data_type = (byte)(bin_freq > (ascii_freq.RightUShift(2)) ? Z_BINARY : Z_ASCII); + } + + // Flush the bit buffer, keeping at most 7 bits in it. + void Bi_flush() + { + if (bi_valid == 16) + { + Put_short(bi_buf); + bi_buf = 0; + bi_valid = 0; + } + else if (bi_valid >= 8) + { + Put_byte((byte)bi_buf); + bi_buf = (short)((int)this.bi_buf).RightUShift(8); + bi_valid -= 8; + } + } + + // Flush the bit buffer and align the output on a byte boundary + void Bi_windup() + { + if (bi_valid > 8) + { + Put_short(bi_buf); + } + else if (bi_valid > 0) + { + Put_byte((byte)bi_buf); + } + bi_buf = 0; + bi_valid = 0; + } + + // Copy a stored block, storing first the length and its + // one's complement if requested. + void Copy_block(int buf, // the input data + int len, // its length + bool header // true if block header must be written + ) + { + //int index = 0; + Bi_windup(); // align on byte boundary + last_eob_len = 8; // enough lookahead for inflate + + if (header) + { + Put_short((short)len); + Put_short((short)~len); + } + + // while(len--!=0) { + // put_byte(window[buf+index]); + // index++; + // } + Put_byte(window, buf, len); + } + + void Flush_block_only(bool eof) + { + _tr_flush_block(block_start >= 0 ? block_start : -1, + strstart - block_start, + eof); + block_start = strstart; + strm.Flush_pending(); + } + + // Copy without compression as much as possible from the input stream, return + // the current block state. + // This function does not insert new strings in the dictionary since + // uncompressible data is probably not useful. This function is used + // only for the level=0 compression option. + // NOTE: this function should be optimized to avoid extra copying from + // window to pending_buf. + int Deflate_stored(int flush) + { + // Stored blocks are limited to 0xffff bytes, pending_buf is limited + // to pending_buf_size, and each stored block has a 5 byte header: + + int max_block_size = 0xffff; + int max_start; + + if (max_block_size > pending_buf_size - 5) + { + max_block_size = pending_buf_size - 5; + } + + // Copy as much as possible from input to output: + while (true) + { + // Fill the window as much as possible: + if (lookahead <= 1) + { + Fill_window(); + if (lookahead == 0 && flush == Z_NO_FLUSH) return NeedMore; + if (lookahead == 0) break; // flush the current block + } + + strstart += lookahead; + lookahead = 0; + + // Emit a stored block if pending_buf will be full: + max_start = block_start + max_block_size; + if (strstart == 0 || strstart >= max_start) + { + // strstart == 0 is possible when wraparound on 16-bit machine + lookahead = (int)(strstart - max_start); + strstart = (int)max_start; + + Flush_block_only(false); + if (strm.avail_out == 0) return NeedMore; + + } + + // Flush if we may have to slide, otherwise block_start may become + // negative and the data will be gone: + if (strstart - block_start >= w_size - MIN_LOOKAHEAD) + { + Flush_block_only(false); + if (strm.avail_out == 0) return NeedMore; + } + } + + Flush_block_only(flush == Z_FINISH); + if (strm.avail_out == 0) + return (flush == Z_FINISH) ? FinishStarted : NeedMore; + + return flush == Z_FINISH ? FinishDone : BlockDone; + } + + // Send a stored block + void _tr_stored_block(int buf, // input block + int stored_len, // length of input block + bool eof // true if this is the last block for a file + ) + { + Send_bits((STORED_BLOCK << 1) + (eof ? 1 : 0), 3); // send block type + Copy_block(buf, stored_len, true); // with header + } + + // Determine the best encoding for the current block: dynamic trees, static + // trees or store, and output the encoded block to the zip file. + void _tr_flush_block(int buf, // input block, or NULL if too old + int stored_len, // length of input block + bool eof // true if this is the last block for a file + ) + { + int opt_lenb, static_lenb;// opt_len and static_len in bytes + int max_blindex = 0; // index of last bit length code of non zero freq + + // Build the Huffman trees unless a stored block is forced + if (level > 0) + { + // Check if the file is ascii or binary + if (data_type == Z_UNKNOWN) Set_data_type(); + + // Construct the literal and distance trees + l_desc.Build_tree(this); + + d_desc.Build_tree(this); + + // At this point, opt_len and static_len are the total bit lengths of + // the compressed block data, excluding the tree representations. + + // Build the bit length tree for the above two trees, and get the index + // in bl_order of the last bit length code to send. + max_blindex = Build_bl_tree(); + + // Determine the best encoding. Compute first the block length in bytes + opt_lenb = (opt_len + 3 + 7).RightUShift(3); + static_lenb = (static_len + 3 + 7).RightUShift(3); + + if (static_lenb <= opt_lenb) opt_lenb = static_lenb; + } + else + { + opt_lenb = static_lenb = stored_len + 5; // force a stored block + } + + if (stored_len + 4 <= opt_lenb && buf != -1) + { + // 4: two words for the lengths + // The test buf != NULL is only necessary if LIT_BUFSIZE > WSIZE. + // Otherwise we can't have processed more than WSIZE input bytes since + // the last block flush, because compression would have been + // successful. If LIT_BUFSIZE <= WSIZE, it is never too late to + // transform a block into a stored block. + _tr_stored_block(buf, stored_len, eof); + } + else if (static_lenb == opt_lenb) + { + Send_bits((STATIC_TREES << 1) + (eof ? 1 : 0), 3); + Compress_block(StaticTree.static_ltree, StaticTree.static_dtree); + } + else + { + Send_bits((DYN_TREES << 1) + (eof ? 1 : 0), 3); + Send_all_trees(l_desc.max_code + 1, d_desc.max_code + 1, max_blindex + 1); + Compress_block(dyn_ltree, dyn_dtree); + } + + // The above check is made mod 2^32, for files larger than 512 MB + // and uLong implemented on 32 bits. + + Init_block(); + + if (eof) + { + Bi_windup(); + } + } + + // Fill the window when the lookahead becomes insufficient. + // Updates strstart and lookahead. + // + // IN assertion: lookahead < MIN_LOOKAHEAD + // OUT assertions: strstart <= window_size-MIN_LOOKAHEAD + // At least one byte has been read, or avail_in == 0; reads are + // performed for at least two bytes (required for the zip translate_eol + // option -- not supported here). + void Fill_window() + { + int n, m; + int p; + int more; // Amount of free space at the end of the window. + + do + { + more = (window_size - lookahead - strstart); + + // Deal with !@#$% 64K limit: + if (more == 0 && strstart == 0 && lookahead == 0) + { + more = w_size; + } + else if (more == -1) + { + // Very unlikely, but possible on 16 bit machine if strstart == 0 + // and lookahead == 1 (input done one byte at time) + more--; + + // If the window is almost full and there is insufficient lookahead, + // move the upper half to the lower one to make room in the upper half. + } + else if (strstart >= w_size + w_size - MIN_LOOKAHEAD) + { + Array.Copy(window, w_size, window, 0, w_size); + match_start -= w_size; + strstart -= w_size; // we now have strstart >= MAX_DIST + block_start -= w_size; + + // Slide the hash table (could be avoided with 32 bit values + // at the expense of memory usage). We slide even when level == 0 + // to keep the hash table consistent if we switch back to level > 0 + // later. (Using level 0 permanently is not an optimal usage of + // zlib, so we don't care about this pathological case.) + + n = hash_size; + p = n; + do + { + m = (head[--p] & 0xffff); + head[p] = (short)(m >= w_size ? (short)(m - w_size) : 0); + } + while (--n != 0); + + n = w_size; + p = n; + do + { + m = (prev[--p] & 0xffff); + prev[p] = (short)(m >= w_size ? (short)(m - w_size) : 0); + // If n is not on any hash chain, prev[n] is garbage but + // its value will never be used. + } + while (--n != 0); + more += w_size; + } + + if (strm.avail_in == 0) return; + + // If there was no sliding: + // strstart <= WSIZE+MAX_DIST-1 && lookahead <= MIN_LOOKAHEAD - 1 && + // more == window_size - lookahead - strstart + // => more >= window_size - (MIN_LOOKAHEAD-1 + WSIZE + MAX_DIST-1) + // => more >= window_size - 2*WSIZE + 2 + // In the BIG_MEM or MMAP case (not yet supported), + // window_size == input_size + MIN_LOOKAHEAD && + // strstart + s->lookahead <= input_size => more >= MIN_LOOKAHEAD. + // Otherwise, window_size == 2*WSIZE so more >= 2. + // If there was sliding, more >= WSIZE. So in all cases, more >= 2. + + n = strm.Read_buf(window, strstart + lookahead, more); + lookahead += n; + + // Initialize the hash value now that we have some input: + if (lookahead >= MIN_MATCH) + { + ins_h = window[strstart] & 0xff; + ins_h = (((ins_h) << hash_shift) ^ (window[strstart + 1] & 0xff)) & hash_mask; + } + // If the whole input has less than MIN_MATCH bytes, ins_h is garbage, + // but this is not important since only literal bytes will be emitted. + } + while (lookahead < MIN_LOOKAHEAD && strm.avail_in != 0); + } + + // Compress as much as possible from the input stream, return the current + // block state. + // This function does not perform lazy evaluation of matches and inserts + // new strings in the dictionary only for unmatched strings or for short + // matches. It is used only for the fast compression options. + int Deflate_fast(int flush) + { + // short hash_head = 0; // head of the hash chain + int hash_head = 0; // head of the hash chain + bool bflush; // set if current block must be flushed + + while (true) + { + // Make sure that we always have enough lookahead, except + // at the end of the input file. We need MAX_MATCH bytes + // for the next match, plus MIN_MATCH bytes to insert the + // string following the next match. + if (lookahead < MIN_LOOKAHEAD) + { + Fill_window(); + if (lookahead < MIN_LOOKAHEAD && flush == Z_NO_FLUSH) + { + return NeedMore; + } + if (lookahead == 0) break; // flush the current block + } + + // Insert the string window[strstart .. strstart+2] in the + // dictionary, and set hash_head to the head of the hash chain: + if (lookahead >= MIN_MATCH) + { + ins_h = (((ins_h) << hash_shift) ^ (window[(strstart) + (MIN_MATCH - 1)] & 0xff)) & hash_mask; + + // prev[strstart&w_mask]=hash_head=head[ins_h]; + hash_head = (head[ins_h] & 0xffff); + prev[strstart & w_mask] = head[ins_h]; + head[ins_h] = (short)strstart; + } + + // Find the longest match, discarding those <= prev_length. + // At this point we have always match_length < MIN_MATCH + + if (hash_head != 0L && + ((strstart - hash_head) & 0xffff) <= w_size - MIN_LOOKAHEAD + ) + { + // To simplify the code, we prevent matches with the string + // of window index 0 (in particular we have to avoid a match + // of the string with itself at the start of the input file). + if (strategy != Z_HUFFMAN_ONLY) + { + match_length = Longest_match(hash_head); + } + // longest_match() sets match_start + } + if (match_length >= MIN_MATCH) + { + // check_match(strstart, match_start, match_length); + + bflush = _tr_tally(strstart - match_start, match_length - MIN_MATCH); + + lookahead -= match_length; + + // Insert new strings in the hash table only if the match length + // is not too large. This saves time but degrades compression. + if (match_length <= max_lazy_match && + lookahead >= MIN_MATCH) + { + match_length--; // string at strstart already in hash table + do + { + strstart++; + + ins_h = ((ins_h << hash_shift) ^ (window[(strstart) + (MIN_MATCH - 1)] & 0xff)) & hash_mask; + // prev[strstart&w_mask]=hash_head=head[ins_h]; + hash_head = (head[ins_h] & 0xffff); + prev[strstart & w_mask] = head[ins_h]; + head[ins_h] = (short)strstart; + + // strstart never exceeds WSIZE-MAX_MATCH, so there are + // always MIN_MATCH bytes ahead. + } + while (--match_length != 0); + strstart++; + } + else + { + strstart += match_length; + match_length = 0; + ins_h = window[strstart] & 0xff; + + ins_h = (((ins_h) << hash_shift) ^ (window[strstart + 1] & 0xff)) & hash_mask; + // If lookahead < MIN_MATCH, ins_h is garbage, but it does not + // matter since it will be recomputed at next deflate call. + } + } + else + { + // No match, output a literal byte + + bflush = _tr_tally(0, window[strstart] & 0xff); + lookahead--; + strstart++; + } + if (bflush) + { + + Flush_block_only(false); + if (strm.avail_out == 0) return NeedMore; + } + } + + Flush_block_only(flush == Z_FINISH); + if (strm.avail_out == 0) + { + if (flush == Z_FINISH) return FinishStarted; + else return NeedMore; + } + return flush == Z_FINISH ? FinishDone : BlockDone; + } + + // Same as above, but achieves better compression. We use a lazy + // evaluation for matches: a match is finally adopted only if there is + // no better match at the next window position. + int Deflate_slow(int flush) + { + // short hash_head = 0; // head of hash chain + int hash_head = 0; // head of hash chain + bool bflush; // set if current block must be flushed + + // Process the input block. + while (true) + { + // Make sure that we always have enough lookahead, except + // at the end of the input file. We need MAX_MATCH bytes + // for the next match, plus MIN_MATCH bytes to insert the + // string following the next match. + + if (lookahead < MIN_LOOKAHEAD) + { + Fill_window(); + if (lookahead < MIN_LOOKAHEAD && flush == Z_NO_FLUSH) + { + return NeedMore; + } + if (lookahead == 0) break; // flush the current block + } + + // Insert the string window[strstart .. strstart+2] in the + // dictionary, and set hash_head to the head of the hash chain: + + if (lookahead >= MIN_MATCH) + { + ins_h = (((ins_h) << hash_shift) ^ (window[(strstart) + (MIN_MATCH - 1)] & 0xff)) & hash_mask; + // prev[strstart&w_mask]=hash_head=head[ins_h]; + hash_head = (head[ins_h] & 0xffff); + prev[strstart & w_mask] = head[ins_h]; + head[ins_h] = (short)strstart; + } + + // Find the longest match, discarding those <= prev_length. + prev_length = match_length; prev_match = match_start; + match_length = MIN_MATCH - 1; + + if (hash_head != 0 && prev_length < max_lazy_match && + ((strstart - hash_head) & 0xffff) <= w_size - MIN_LOOKAHEAD + ) + { + // To simplify the code, we prevent matches with the string + // of window index 0 (in particular we have to avoid a match + // of the string with itself at the start of the input file). + + if (strategy != Z_HUFFMAN_ONLY) + { + match_length = Longest_match(hash_head); + } + // longest_match() sets match_start + + if (match_length <= 5 && (strategy == Z_FILTERED || + (match_length == MIN_MATCH && + strstart - match_start > 4096))) + { + + // If prev_match is also MIN_MATCH, match_start is garbage + // but we will ignore the current match anyway. + match_length = MIN_MATCH - 1; + } + } + + // If there was a match at the previous step and the current + // match is not better, output the previous match: + if (prev_length >= MIN_MATCH && match_length <= prev_length) + { + int max_insert = strstart + lookahead - MIN_MATCH; + // Do not insert strings in hash table beyond this. + + // check_match(strstart-1, prev_match, prev_length); + + bflush = _tr_tally(strstart - 1 - prev_match, prev_length - MIN_MATCH); + + // Insert in hash table all strings up to the end of the match. + // strstart-1 and strstart are already inserted. If there is not + // enough lookahead, the last two strings are not inserted in + // the hash table. + lookahead -= prev_length - 1; + prev_length -= 2; + do + { + if (++strstart <= max_insert) + { + ins_h = (((ins_h) << hash_shift) ^ (window[(strstart) + (MIN_MATCH - 1)] & 0xff)) & hash_mask; + //prev[strstart&w_mask]=hash_head=head[ins_h]; + hash_head = (head[ins_h] & 0xffff); + prev[strstart & w_mask] = head[ins_h]; + head[ins_h] = (short)strstart; + } + } + while (--prev_length != 0); + match_available = 0; + match_length = MIN_MATCH - 1; + strstart++; + + if (bflush) + { + Flush_block_only(false); + if (strm.avail_out == 0) return NeedMore; + } + } + else if (match_available != 0) + { + + // If there was no match at the previous position, output a + // single literal. If there was a match but the current match + // is longer, truncate the previous match to a single literal. + + bflush = _tr_tally(0, window[strstart - 1] & 0xff); + + if (bflush) + { + Flush_block_only(false); + } + strstart++; + lookahead--; + if (strm.avail_out == 0) return NeedMore; + } + else + { + // There is no previous match to compare with, wait for + // the next step to decide. + + match_available = 1; + strstart++; + lookahead--; + } + } + + if (match_available != 0) + { + bflush = _tr_tally(0, window[strstart - 1] & 0xff); + match_available = 0; + } + Flush_block_only(flush == Z_FINISH); + + if (strm.avail_out == 0) + { + if (flush == Z_FINISH) return FinishStarted; + else return NeedMore; + } + + return flush == Z_FINISH ? FinishDone : BlockDone; + } + + int Longest_match(int cur_match) + { + int chain_length = max_chain_length; // max hash chain length + int scan = strstart; // current string + int match; // matched string + int len; // length of current match + int best_len = prev_length; // best match length so far + int limit = strstart > (w_size - MIN_LOOKAHEAD) ? + strstart - (w_size - MIN_LOOKAHEAD) : 0; + int nice_match = this.nice_match; + + // Stop when cur_match becomes <= limit. To simplify the code, + // we prevent matches with the string of window index 0. + + int wmask = w_mask; + + int strend = strstart + MAX_MATCH; + byte scan_end1 = window[scan + best_len - 1]; + byte scan_end = window[scan + best_len]; + + // The code is optimized for HASH_BITS >= 8 and MAX_MATCH-2 multiple of 16. + // It is easy to get rid of this optimization if necessary. + + // Do not waste too much time if we already have a good match: + if (prev_length >= good_match) + { + chain_length >>= 2; + } + + // Do not look for matches beyond the end of the input. This is necessary + // to make deflate deterministic. + if (nice_match > lookahead) nice_match = lookahead; + + do + { + match = cur_match; + + // Skip to next match if the match length cannot increase + // or if the match length is less than 2: + if (window[match + best_len] != scan_end || + window[match + best_len - 1] != scan_end1 || + window[match] != window[scan] || + window[++match] != window[scan + 1]) continue; + + // The check at best_len-1 can be removed because it will be made + // again later. (This heuristic is not always a win.) + // It is not necessary to compare scan[2] and match[2] since they + // are always equal when the other bytes match, given that + // the hash keys are equal and that HASH_BITS >= 8. + scan += 2; match++; + + // We check for insufficient lookahead only every 8th comparison; + // the 256th check will be made at strstart+258. + do + { + } while (window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + scan < strend); + + len = MAX_MATCH - (int)(strend - scan); + scan = strend - MAX_MATCH; + + if (len > best_len) + { + match_start = cur_match; + best_len = len; + if (len >= nice_match) break; + scan_end1 = window[scan + best_len - 1]; + scan_end = window[scan + best_len]; + } + + } while ((cur_match = (prev[cur_match & wmask] & 0xffff)) > limit + && --chain_length != 0); + + if (best_len <= lookahead) return best_len; + return lookahead; + } + + internal int DeflateInit(int level, int bits, int memlevel) => + DeflateInit(level, Z_DEFLATED, bits, memlevel, Z_DEFAULT_STRATEGY); + + internal int DeflateInit(int level, int bits) => + DeflateInit(level, Z_DEFLATED, bits, DEF_MEM_LEVEL, Z_DEFAULT_STRATEGY); + + internal int DeflateInit(int level) => DeflateInit(level, MAX_WBITS); + + int DeflateInit(int level, int method, int windowBits, int memLevel, int strategy) + { + int wrap = 1; + // byte[] my_version=ZLIB_VERSION; + + // + // if (version == null || version[0] != my_version[0] + // || stream_size != sizeof(z_stream)) { + // return Z_VERSION_ERROR; + // } + + strm.msg = null; + + if (level == Z_DEFAULT_COMPRESSION) level = 6; + + if (windowBits < 0) + { // undocumented feature: suppress zlib header + wrap = 0; + windowBits = -windowBits; + } + else if (windowBits > 15) + { + wrap = 2; + windowBits -= 16; + strm.adler = new CRC32(); + } + + if (memLevel < 1 || memLevel > MAX_MEM_LEVEL || + method != Z_DEFLATED || + windowBits < 9 || windowBits > 15 || level < 0 || level > 9 || + strategy < 0 || strategy > Z_HUFFMAN_ONLY) + { + return Z_STREAM_ERROR; + } + + strm.dstate = this; + + this.wrap = wrap; + w_bits = windowBits; + w_size = 1 << w_bits; + w_mask = w_size - 1; + + hash_bits = memLevel + 7; + hash_size = 1 << hash_bits; + hash_mask = hash_size - 1; + hash_shift = ((hash_bits + MIN_MATCH - 1) / MIN_MATCH); + + window = new byte[w_size * 2]; + prev = new short[w_size]; + head = new short[hash_size]; + + lit_bufsize = 1 << (memLevel + 6); // 16K elements by default + + // We overlay pending_buf and d_buf+l_buf. This works since the average + // output size for (length,distance) codes is <= 24 bits. + pending_buf = new byte[lit_bufsize * 3]; + pending_buf_size = lit_bufsize * 3; + + d_buf = lit_bufsize; + l_buf = new byte[lit_bufsize]; + + this.level = level; + + this.strategy = strategy; + this.method = (byte)method; + + return DeflateReset(); + } + + int DeflateReset() + { + strm.total_in = strm.total_out = 0; + strm.msg = null; // + strm.data_type = Z_UNKNOWN; + + pending = 0; + pending_out = 0; + + if (wrap < 0) + { + wrap = -wrap; + } + status = (wrap == 0) ? BUSY_STATE : INIT_STATE; + strm.adler.Reset(); + + last_flush = Z_NO_FLUSH; + + Tr_init(); + Lm_init(); + return Z_OK; + } + + internal int DeflateEnd() + { + if (status != INIT_STATE && status != BUSY_STATE && status != FINISH_STATE) + { + return Z_STREAM_ERROR; + } + // Deallocate in reverse order of allocations: + pending_buf = null; + l_buf = null; + head = null; + prev = null; + window = null; + // free + // dstate=null; + return status == BUSY_STATE ? Z_DATA_ERROR : Z_OK; + } + + internal int DeflateParams(int _level, int _strategy) + { + int err = Z_OK; + + if (_level == Z_DEFAULT_COMPRESSION) + { + _level = 6; + } + if (_level < 0 || _level > 9 || + _strategy < 0 || _strategy > Z_HUFFMAN_ONLY) + { + return Z_STREAM_ERROR; + } + + if (config_table[level].func != config_table[_level].func && + strm.total_in != 0) + { + // Flush the last buffer: + err = strm.Deflate_z(Z_PARTIAL_FLUSH); + } + + if (level != _level) + { + level = _level; + max_lazy_match = config_table[level].max_lazy; + good_match = config_table[level].good_length; + nice_match = config_table[level].nice_length; + max_chain_length = config_table[level].max_chain; + } + strategy = _strategy; + return err; + } + + internal int DeflateSetDictionary(byte[] dictionary, int dictLength) + { + int length = dictLength; + int index = 0; + + if (dictionary == null || status != INIT_STATE) + return Z_STREAM_ERROR; + + strm.adler.Update(dictionary, 0, dictLength); + + if (length < MIN_MATCH) return Z_OK; + if (length > w_size - MIN_LOOKAHEAD) + { + length = w_size - MIN_LOOKAHEAD; + index = dictLength - length; // use the tail of the dictionary + } + Array.Copy(dictionary, index, window, 0, length); + strstart = length; + block_start = length; + + // Insert all strings in the hash table (except for the last two bytes). + // s->lookahead stays null, so s->ins_h will be recomputed at the next + // call of fill_window. + + ins_h = window[0] & 0xff; + ins_h = (((ins_h) << hash_shift) ^ (window[1] & 0xff)) & hash_mask; + + for (int n = 0; n <= length - MIN_MATCH; n++) + { + ins_h = (((ins_h) << hash_shift) ^ (window[(n) + (MIN_MATCH - 1)] & 0xff)) & hash_mask; + prev[n & w_mask] = head[ins_h]; + head[ins_h] = (short)n; + } + return Z_OK; + } + + internal int Deflate_D(int flush) + { + int old_flush; + + if (flush > Z_FINISH || flush < 0) + { + return Z_STREAM_ERROR; + } + + if (strm.next_out == null || + (strm.next_in == null && strm.avail_in != 0) || + (status == FINISH_STATE && flush != Z_FINISH)) + { + strm.msg = z_errmsg[Z_NEED_DICT - (Z_STREAM_ERROR)]; + return Z_STREAM_ERROR; + } + if (strm.avail_out == 0) + { + strm.msg = z_errmsg[Z_NEED_DICT - (Z_BUF_ERROR)]; + return Z_BUF_ERROR; + } + + old_flush = last_flush; + last_flush = flush; + + // Write the zlib header + if (status == INIT_STATE) + { + if (wrap == 2) + { + GetGZIPHeader().Put(this); + status = BUSY_STATE; + strm.adler.Reset(); + } + else + { + int header = (Z_DEFLATED + ((w_bits - 8) << 4)) << 8; + int level_flags = ((level - 1) & 0xff) >> 1; + + if (level_flags > 3) level_flags = 3; + header |= (level_flags << 6); + if (strstart != 0) header |= PRESET_DICT; + header += 31 - (header % 31); + + status = BUSY_STATE; + PutShortMSB(header); + + + // Save the adler32 of the preset dictionary: + if (strstart != 0) + { + long adler = strm.adler.GetValue(); + PutShortMSB((int)(adler.RightUShift(16))); + PutShortMSB((int)(adler & 0xffff)); + } + strm.adler.Reset(); + } + } + + // Flush as much pending output as possible + if (pending != 0) + { + strm.Flush_pending(); + if (strm.avail_out == 0) + { + // Since avail_out is 0, deflate will be called again with + // more output space, but possibly with both pending and + // avail_in equal to zero. There won't be anything to do, + // but this is not an error situation so make sure we + // return OK instead of BUF_ERROR at next call of deflate: + last_flush = -1; + return Z_OK; + } + + // Make sure there is something to do and avoid duplicate consecutive + // flushes. For repeated and useless calls with Z_FINISH, we keep + // returning Z_STREAM_END instead of Z_BUFF_ERROR. + } + else if (strm.avail_in == 0 && flush <= old_flush && + flush != Z_FINISH) + { + strm.msg = z_errmsg[Z_NEED_DICT - (Z_BUF_ERROR)]; + return Z_BUF_ERROR; + } + + // User must not provide more input after the first FINISH: + if (status == FINISH_STATE && strm.avail_in != 0) + { + strm.msg = z_errmsg[Z_NEED_DICT - (Z_BUF_ERROR)]; + return Z_BUF_ERROR; + } + + // Start a new block or continue the current one. + if (strm.avail_in != 0 || lookahead != 0 || + (flush != Z_NO_FLUSH && status != FINISH_STATE)) + { + int bstate = -1; + switch (config_table[level].func) + { + case STORED: + bstate = Deflate_stored(flush); + break; + case FAST: + bstate = Deflate_fast(flush); + break; + case SLOW: + bstate = Deflate_slow(flush); + break; + } + + if (bstate == FinishStarted || bstate == FinishDone) + { + status = FINISH_STATE; + } + if (bstate == NeedMore || bstate == FinishStarted) + { + if (strm.avail_out == 0) + { + last_flush = -1; // avoid BUF_ERROR next call, see above + } + return Z_OK; + // If flush != Z_NO_FLUSH && avail_out == 0, the next call + // of deflate should use the same flush parameter to make sure + // that the flush is complete. So we don't have to output an + // empty block here, this will be done at next call. This also + // ensures that for a very small output buffer, we emit at most + // one empty block. + } + + if (bstate == BlockDone) + { + if (flush == Z_PARTIAL_FLUSH) + { + _tr_align(); + } + else + { // FULL_FLUSH or SYNC_FLUSH + _tr_stored_block(0, 0, false); + // For a full flush, this empty block will be recognized + // as a special marker by inflate_sync(). + if (flush == Z_FULL_FLUSH) + { + //state.head[s.hash_size-1]=0; + for (int i = 0; i < hash_size/*-1*/; i++) // forget history + head[i] = 0; + } + } + strm.Flush_pending(); + if (strm.avail_out == 0) + { + last_flush = -1; // avoid BUF_ERROR at next call, see above + return Z_OK; + } + } + } + + if (flush != Z_FINISH) return Z_OK; + if (wrap <= 0) return Z_STREAM_END; + + if (wrap == 2) + { + long adler = strm.adler.GetValue(); + Put_byte((byte)(adler & 0xff)); + Put_byte((byte)((adler >> 8) & 0xff)); + Put_byte((byte)((adler >> 16) & 0xff)); + Put_byte((byte)((adler >> 24) & 0xff)); + Put_byte((byte)(strm.total_in & 0xff)); + Put_byte((byte)((strm.total_in >> 8) & 0xff)); + Put_byte((byte)((strm.total_in >> 16) & 0xff)); + Put_byte((byte)((strm.total_in >> 24) & 0xff)); + + GetGZIPHeader().SetCRC(adler); + } + else + { + // Write the zlib trailer (adler32) + long adler = strm.adler.GetValue(); + PutShortMSB((int)(adler.RightUShift(16))); + PutShortMSB((int)(adler & 0xffff)); + } + + strm.Flush_pending(); + + // If avail_out is zero, the application will call deflate again + // to flush the rest. + + if (wrap > 0) wrap = -wrap; // write the trailer only once! + return pending != 0 ? Z_OK : Z_STREAM_END; + } + + internal static int DeflateCopy(ZStream dest, ZStream src) + { + if (src.dstate == null) + { + return Z_STREAM_ERROR; + } + + if (src.next_in != null) + { + dest.next_in = new byte[src.next_in.Length]; + Array.Copy(src.next_in, 0, dest.next_in, 0, src.next_in.Length); + } + dest.next_in_index = src.next_in_index; + dest.avail_in = src.avail_in; + dest.total_in = src.total_in; + + if (src.next_out != null) + { + dest.next_out = new byte[src.next_out.Length]; + Array.Copy(src.next_out, 0, dest.next_out, 0, src.next_out.Length); + } + + dest.next_out_index = src.next_out_index; + dest.avail_out = src.avail_out; + dest.total_out = src.total_out; + + dest.msg = src.msg; + dest.data_type = src.data_type; + dest.adler = src.adler.Copy(); + + dest.dstate = src.dstate.Clone(dest); + dest.dstate.strm = dest; + + return Z_OK; + } + + public Deflate Clone(ZStream z) + { + var dest = new Deflate(z); + + dest.pending_buf = Dup(dest.pending_buf); + //dest.d_buf = dest.d_buf; + dest.l_buf = Dup(dest.l_buf); + dest.window = Dup(dest.window); + + dest.prev = Dup(dest.prev); + dest.head = Dup(dest.head); + dest.dyn_ltree = Dup(dest.dyn_ltree); + dest.dyn_dtree = Dup(dest.dyn_dtree); + dest.bl_tree = Dup(dest.bl_tree); + + dest.bl_count = Dup(dest.bl_count); + dest.next_code = Dup(dest.next_code); + dest.heap = Dup(dest.heap); + dest.depth = Dup(dest.depth); + + dest.l_desc.dyn_tree = dest.dyn_ltree; + dest.d_desc.dyn_tree = dest.dyn_dtree; + dest.bl_desc.dyn_tree = dest.bl_tree; + + /* + dest.l_desc.stat_desc = StaticTree.static_l_desc; + dest.d_desc.stat_desc = StaticTree.static_d_desc; + dest.bl_desc.stat_desc = StaticTree.static_bl_desc; + */ + + if (dest.gheader != null) + { + dest.gheader = dest.gheader.Clone(); + } + + return dest; + } + + static byte[] Dup(byte[] buf) + { + var foo = new byte[buf.Length]; + Array.Copy(buf, 0, foo, 0, foo.Length); + return foo; + } + + static short[] Dup(short[] buf) + { + var foo = new short[buf.Length]; + Array.Copy(buf, 0, foo, 0, foo.Length); + return foo; + } + + static int[] Dup(int[] buf) + { + var foo = new int[buf.Length]; + Array.Copy(buf, 0, foo, 0, foo.Length); + return foo; + } + + GZIPHeader GetGZIPHeader() + { + lock (this) + { + if (gheader == null) + { + gheader = new GZIPHeader(); + } + return gheader; + } + } + } +} diff --git a/src/DotNetty.Codecs/Compression/Deflater.cs b/src/DotNetty.Codecs/Compression/Deflater.cs new file mode 100644 index 0000000..733ebfe --- /dev/null +++ b/src/DotNetty.Codecs/Compression/Deflater.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/Deflater.java + /// + sealed class Deflater : ZStream + { + const int MAX_WBITS = 15; // 32K LZ77 window + //const int DEF_WBITS = MAX_WBITS; + + //const int Z_NO_FLUSH = 0; + //const int Z_PARTIAL_FLUSH = 1; + //const int Z_SYNC_FLUSH = 2; + //const int Z_FULL_FLUSH = 3; + //const int Z_FINISH = 4; + + //const int MAX_MEM_LEVEL = 9; + + const int Z_OK = 0; + const int Z_STREAM_END = 1; + //const int Z_NEED_DICT = 2; + //const int Z_ERRNO = -1; + const int Z_STREAM_ERROR = -2; + //const int Z_DATA_ERROR = -3; + //const int Z_MEM_ERROR = -4; + //const int Z_BUF_ERROR = -5; + //const int Z_VERSION_ERROR = -6; + + bool finished = false; + + public Deflater() + { + } + + public Deflater(int level) + : this(level, MAX_WBITS) + { + } + + public Deflater(int level, bool nowrap) + : this(level, MAX_WBITS, nowrap) + { + } + + public Deflater(int level, int bits) + : this(level, bits, false) + { + } + + public Deflater(int level, int bits, bool nowrap) + { + int ret = Init(level, bits, nowrap); + if (ret != Z_OK) throw new GZIPException(ret + ": " + msg); + } + + public Deflater(int level, int bits, int memlevel, JZlib.WrapperType wrapperType) + { + int ret = Init(level, bits, memlevel, wrapperType); + if (ret != Z_OK) throw new GZIPException(ret + ": " + msg); + } + + public Deflater(int level, int bits, int memlevel) + { + int ret = Init(level, bits, memlevel); + if (ret != Z_OK) throw new GZIPException(ret + ": " + msg); + } + + public int Init(int level) => Init(level, MAX_WBITS); + + public int Init(int level, bool nowrap) => Init(level, MAX_WBITS, nowrap); + + public int Init(int level, int bits) => Init(level, bits, false); + + public int Init(int level, int bits, int memlevel, JZlib.WrapperType wrapperType) + { + if (bits < 9 || bits > 15) + { + return Z_STREAM_ERROR; + } + if (wrapperType == JZlib.W_NONE) + { + bits *= -1; + } + else if (wrapperType == JZlib.W_GZIP) + { + bits += 16; + } + else if (wrapperType == JZlib.W_ANY) + { + return Z_STREAM_ERROR; + } + else if (wrapperType == JZlib.W_ZLIB) + { + } + return Init(level, bits, memlevel); + } + + public int Init(int level, int bits, int memlevel) + { + finished = false; + dstate = new Deflate(this); + return dstate.DeflateInit(level, bits, memlevel); + } + + public int Init(int level, int bits, bool nowrap) + { + finished = false; + dstate = new Deflate(this); + return dstate.DeflateInit(level, nowrap ? -bits : bits); + } + + public int Deflate(int flush) + { + if (dstate == null) + { + return Z_STREAM_ERROR; + } + int ret = dstate.Deflate_D(flush); + if (ret == Z_STREAM_END) + finished = true; + return ret; + } + + public override int End() + { + finished = true; + if (dstate == null) + return Z_STREAM_ERROR; + int ret = dstate.DeflateEnd(); + dstate = null; + Free(); + return ret; + } + + public int Params(int level, int strategy) + { + if (dstate == null) return Z_STREAM_ERROR; + return dstate.DeflateParams(level, strategy); + } + + public int SetDictionary(byte[] dictionary, int dictLength) + { + if (dstate == null) return Z_STREAM_ERROR; + return dstate.DeflateSetDictionary(dictionary, dictLength); + } + + public override bool Finished() => finished; + + public int Copy(Deflater src) + { + this.finished = src.finished; + return Compression.Deflate.DeflateCopy(this, src); + } + } +} diff --git a/src/DotNetty.Codecs/Compression/GZIPException.cs b/src/DotNetty.Codecs/Compression/GZIPException.cs new file mode 100644 index 0000000..793f619 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/GZIPException.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// + +namespace DotNetty.Codecs.Compression +{ + using System.IO; + + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/GZIPException.java + /// + public class GZIPException : IOException + { + public GZIPException() + { + } + + public GZIPException(string s) : base(s) + { + } + } +} diff --git a/src/DotNetty.Codecs/Compression/GZIPHeader.cs b/src/DotNetty.Codecs/Compression/GZIPHeader.cs new file mode 100644 index 0000000..1937f01 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/GZIPHeader.cs @@ -0,0 +1,223 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + using System; + using System.Runtime.InteropServices; + using System.Text; + + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/GZIPHeader.java + /// + /// http://www.ietf.org/rfc/rfc1952.txt + /// + class GZIPHeader + { + static readonly Encoding ISOEncoding = Encoding.GetEncoding("ISO-8859-1"); + static readonly byte Platform; + + public static readonly byte OS_MSDOS = (byte)0x00; + public static readonly byte OS_AMIGA = (byte)0x01; + public static readonly byte OS_VMS = (byte)0x02; + public static readonly byte OS_UNIX = (byte)0x03; + public static readonly byte OS_ATARI = (byte)0x05; + public static readonly byte OS_OS2 = (byte)0x06; + public static readonly byte OS_MACOS = (byte)0x07; + public static readonly byte OS_TOPS20 = (byte)0x0a; + public static readonly byte OS_WIN32 = (byte)0x0b; + public static readonly byte OS_VMCMS = (byte)0x04; + public static readonly byte OS_ZSYSTEM = (byte)0x08; + public static readonly byte OS_CPM = (byte)0x09; + public static readonly byte OS_QDOS = (byte)0x0c; + public static readonly byte OS_RISCOS = (byte)0x0d; + public static readonly byte OS_UNKNOWN = (byte)0xff; + + bool text = false; + bool fhcrc = false; + internal long time; + internal int xflags; + internal int os; + internal byte[] extra; + internal byte[] name; + internal byte[] comment; + internal int hcrc; + internal long crc; + //bool done = false; + long mtime = 0; + + static GZIPHeader() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Platform = OS_WIN32; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + Platform = OS_UNIX; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + Platform = OS_MACOS; + } + else + { + Platform = OS_UNKNOWN; + } + } + + internal GZIPHeader() + { + this.os = Platform; + } + + public void SetModifiedTime(long value) => this.mtime = value; + + public long GetModifiedTime() => this.mtime; + + public void SetOS(int value) + { + if ((0 <= value && value <= 13) || value == 255) + this.os = value; + else + throw new ArgumentException(nameof(value)); + } + + public int GetOS() => this.os; + + public void SetName(string value) => this.name = ISOEncoding.GetBytes(value); + + public string GetName() => this.name == null ? string.Empty : ISOEncoding.GetString(this.name); + + public void SetComment(string value) => this.comment = ISOEncoding.GetBytes(value); + + public string GetComment() => this.comment == null ? string.Empty : ISOEncoding.GetString(this.comment); + + public void SetCRC(long value) => this.crc = value; + + public long GetCRC() => this.crc; + + internal void Put(Deflate d) + { + int flag = 0; + if (text) + { + flag |= 1; // FTEXT + } + if (fhcrc) + { + flag |= 2; // FHCRC + } + if (extra != null) + { + flag |= 4; // FEXTRA + } + if (name != null) + { + flag |= 8; // FNAME + } + if (comment != null) + { + flag |= 16; // FCOMMENT + } + int xfl = 0; + if (d.level == JZlib.Z_BEST_SPEED) + { + xfl |= 4; + } + else if (d.level == JZlib.Z_BEST_COMPRESSION) + { + xfl |= 2; + } + + d.Put_short(unchecked((short)0x8b1f)); // ID1 ID2 + d.Put_byte((byte)8); // CM(Compression Method) + d.Put_byte((byte)flag); + d.Put_byte((byte)mtime); + d.Put_byte((byte)(mtime >> 8)); + d.Put_byte((byte)(mtime >> 16)); + d.Put_byte((byte)(mtime >> 24)); + d.Put_byte((byte)xfl); + d.Put_byte((byte)os); + + if (extra != null) + { + d.Put_byte((byte)extra.Length); + d.Put_byte((byte)(extra.Length >> 8)); + d.Put_byte(extra, 0, extra.Length); + } + + if (name != null) + { + d.Put_byte(name, 0, name.Length); + d.Put_byte((byte)0); + } + + if (comment != null) + { + d.Put_byte(comment, 0, comment.Length); + d.Put_byte((byte)0); + } + } + + public GZIPHeader Clone() + { + var gheader = new GZIPHeader(); + byte[] tmp; + if (gheader.extra != null) + { + tmp = new byte[gheader.extra.Length]; + Array.Copy(gheader.extra, 0, tmp, 0, tmp.Length); + gheader.extra = tmp; + } + + if (gheader.name != null) + { + tmp = new byte[gheader.name.Length]; + Array.Copy(gheader.name, 0, tmp, 0, tmp.Length); + gheader.name = tmp; + } + + if (gheader.comment != null) + { + tmp = new byte[gheader.comment.Length]; + Array.Copy(gheader.comment, 0, tmp, 0, tmp.Length); + gheader.comment = tmp; + } + + return gheader; + } + } +} diff --git a/src/DotNetty.Codecs/Compression/IChecksum.cs b/src/DotNetty.Codecs/Compression/IChecksum.cs new file mode 100644 index 0000000..f3a1136 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/IChecksum.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Compression +{ + interface IChecksum + { + void Update(byte[] buf, int index, int len); + + void Reset(); + + void Reset(long init); + + long GetValue(); + + IChecksum Copy(); + } +} diff --git a/src/DotNetty.Codecs/Compression/InfBlocks.cs b/src/DotNetty.Codecs/Compression/InfBlocks.cs new file mode 100644 index 0000000..6ba0fdc --- /dev/null +++ b/src/DotNetty.Codecs/Compression/InfBlocks.cs @@ -0,0 +1,695 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + using System; + using DotNetty.Common.Utilities; + + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/InfBlocks.java + /// + sealed class InfBlocks + { + const int MANY = 1440; + + // And'ing with mask[n] masks the lower n bits + static readonly int[] inflate_mask = + { + 0x00000000, 0x00000001, 0x00000003, 0x00000007, 0x0000000f, + 0x0000001f, 0x0000003f, 0x0000007f, 0x000000ff, 0x000001ff, + 0x000003ff, 0x000007ff, 0x00000fff, 0x00001fff, 0x00003fff, + 0x00007fff, 0x0000ffff + }; + + // Table for deflate from PKZIP's appnote.txt. + static readonly int[] border = + { + // Order of the bit length code lengths + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 + }; + + const int Z_OK = 0; + const int Z_STREAM_END = 1; + //const int Z_NEED_DICT = 2; + //const int Z_ERRNO = -1; + const int Z_STREAM_ERROR = -2; + const int Z_DATA_ERROR = -3; + //const int Z_MEM_ERROR = -4; + const int Z_BUF_ERROR = -5; + //const int Z_VERSION_ERROR = -6; + + const int TYPE = 0; // get type bits (3, including end bit) + const int LENS = 1; // get lengths for stored + const int STORED = 2;// processing stored block + const int TABLE = 3; // get table lengths + const int BTREE = 4; // get bit lengths tree for a dynamic block + const int DTREE = 5; // get length, distance trees for a dynamic block + const int CODES = 6; // processing fixed or dynamic block + const int DRY = 7; // output remaining window bytes + const int DONE = 8; // finished last block, done + const int BAD = 9; // ot a data error--stuck here + + int mode; // current inflate_block mode + + int left; // if STORED, bytes left to copy + + int table; // table lengths (14 bits) + int index; // index into blens (or border) + int[] blens; // bit lengths of codes + int[] bb = new int[1]; // bit length tree depth + int[] tb = new int[1]; // bit length decoding tree + + int[] bl = new int[1]; + int[] bd = new int[1]; + + int[][] tl = new int[1][]; + int[][] td = new int[1][]; + int[] tli = new int[1]; // tl_index + int[] tdi = new int[1]; // td_index + + readonly InfCodes codes; // if CODES, current state + + int last; // true if this block is the last block + + // mode independent information + internal int bitk; // bits in bit buffer + internal int bitb; // bit buffer + internal int[] hufts; // single malloc for tree space + internal byte[] window; // sliding window + internal int end; // one byte after sliding window + internal int read; // window read pointer + internal int write; // window write pointer + bool check; + + readonly InfTree inftree=new InfTree(); + + readonly ZStream z; + + internal InfBlocks(ZStream z, int w) + { + this.z = z; + this.codes = new InfCodes(this.z, this); + hufts = new int[MANY * 3]; + window = new byte[w]; + end = w; + this.check = (z.istate.wrap == 0) ? false : true; + mode = TYPE; + Reset(); + } + + internal void Reset() + { + if (mode == BTREE || mode == DTREE) + { + } + if (mode == CODES) + { + codes.Free(z); + } + mode = TYPE; + bitk = 0; + bitb = 0; + read = write = 0; + if (check) + { + z.adler.Reset(); + } + } + + internal int Proc(int r) + { + int t; // temporary storage + int b; // bit buffer + int k; // bits in bit buffer + int p; // input data pointer + int n; // bytes available there + int q; // output window write pointer + int m; // bytes to end of window or read pointer + + // copy input/output information to locals (UPDATE macro restores) + { p = z.next_in_index; n = z.avail_in; b = bitb; k = bitk; } + { q = write; m = (int)(q < read ? read - q - 1 : end - q); } + + // process input based on current state + while (true) + { + if (this.mode == TYPE) + { + while (k < (3)) + { + if (n != 0) + { + r = Z_OK; + } + else + { + bitb = b; bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + }; + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + t = (int)(b & 7); + last = t & 1; + + switch (t.RightUShift(1)) + { + case 0: // stored + { b = b.RightUShift(3); k -= (3); } + t = k & 7; // go to byte boundary + + { b = b.RightUShift(t); k -= (t); } + mode = LENS; // get length of stored block + break; + case 1: // fixed + InfTree.Inflate_trees_fixed(bl, bd, tl, td, z); + codes.Init(bl[0], bd[0], tl[0], 0, td[0], 0); + + { b = b.RightUShift(3); k -= (3); } + + mode = CODES; + break; + case 2: // dynamic + { b = b.RightUShift(3); k -= (3); } + + mode = TABLE; + break; + case 3: // illegal + { b = b.RightUShift(3); k -= (3); } + mode = BAD; + z.msg = "invalid block type"; + r = Z_DATA_ERROR; + + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } + continue; // break; + } // case TYPE + if (this.mode == LENS) + { + while (k < (32)) + { + if (n != 0) + { + r = Z_OK; + } + else + { + bitb = b; bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + }; + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + if ((((~b).RightUShift(16)) & 0xffff) != (b & 0xffff)) + { + mode = BAD; + z.msg = "invalid stored block lengths"; + r = Z_DATA_ERROR; + + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } + left = (b & 0xffff); + b = k = 0; // dump bits + mode = left != 0 ? STORED : (last != 0 ? DRY : TYPE); + continue; // break; + } // case LENS + if (this.mode == STORED) + { + if (n == 0) + { + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } + + if (m == 0) + { + if (q == end && read != 0) + { + q = 0; m = (int)(q < read ? read - q - 1 : end - q); + } + if (m == 0) + { + write = q; + r = Inflate_flush(r); + q = write; m = (int)(q < read ? read - q - 1 : end - q); + if (q == end && read != 0) + { + q = 0; m = (int)(q < read ? read - q - 1 : end - q); + } + if (m == 0) + { + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } + } + } + r = Z_OK; + + t = left; + if (t > n) t = n; + if (t > m) t = m; + Array.Copy(z.next_in, p, window, q, t); + p += t; n -= t; + q += t; m -= t; + if ((left -= t) != 0) + continue; // break; + mode = last != 0 ? DRY : TYPE; + continue; // break; + } // case STORED + if (this.mode == TABLE) + { + while (k < (14)) + { + if (n != 0) + { + r = Z_OK; + } + else + { + bitb = b; bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + }; + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + table = t = (b & 0x3fff); + if ((t & 0x1f) > 29 || ((t >> 5) & 0x1f) > 29) + { + mode = BAD; + z.msg = "too many length or distance symbols"; + r = Z_DATA_ERROR; + + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } + t = 258 + (t & 0x1f) + ((t >> 5) & 0x1f); + if (blens == null || blens.Length < t) + { + blens = new int[t]; + } + else + { + for (int i = 0; i < t; i++) { blens[i] = 0; } + } + + { b = b.RightUShift(14); k -= (14); } + + index = 0; + mode = BTREE; + } // case TABLE + if (this.mode == BTREE) + { + while (index < 4 + (table.RightUShift(10))) + { + while (k < (3)) + { + if (n != 0) + { + r = Z_OK; + } + else + { + bitb = b; bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + }; + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + blens[border[index++]] = b & 7; + + { b = b.RightUShift(3); k -= (3); } + } + + while (index < 19) + { + blens[border[index++]] = 0; + } + + bb[0] = 7; + t = inftree.Inflate_trees_bits(blens, bb, tb, hufts, z); + if (t != Z_OK) + { + r = t; + if (r == Z_DATA_ERROR) + { + blens = null; + mode = BAD; + } + + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } + + index = 0; + mode = DTREE; + } // case BTREE + if (this.mode == DTREE) + { + while (true) + { + t = table; + if (!(index < 258 + (t & 0x1f) + ((t >> 5) & 0x1f))) + { + break; + } + + //int[] h; + int i, j, c; + + t = bb[0]; + + while (k < (t)) + { + if (n != 0) + { + r = Z_OK; + } + else + { + bitb = b; bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + }; + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + if (tb[0] == -1) + { + //System.err.println("null..."); + } + + t = hufts[(tb[0] + (b & inflate_mask[t])) * 3 + 1]; + c = hufts[(tb[0] + (b & inflate_mask[t])) * 3 + 2]; + + if (c < 16) + { + b = b.RightUShift(t); k -= (t); + blens[index++] = c; + } + else + { // c == 16..18 + i = c == 18 ? 7 : c - 14; + j = c == 18 ? 11 : 3; + + while (k < (t + i)) + { + if (n != 0) + { + r = Z_OK; + } + else + { + bitb = b; bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + }; + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + b = b.RightUShift(t); k -= (t); + + j += (b & inflate_mask[i]); + + b = b.RightUShift(i); k -= (i); + + i = index; + t = table; + if (i + j > 258 + (t & 0x1f) + ((t >> 5) & 0x1f) || + (c == 16 && i < 1)) + { + blens = null; + mode = BAD; + z.msg = "invalid bit length repeat"; + r = Z_DATA_ERROR; + + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } + + c = c == 16 ? blens[i - 1] : 0; + do + { + blens[i++] = c; + } + while (--j != 0); + index = i; + } + } + + tb[0] = -1; + { + bl[0] = 9; // must be <= 9 for lookahead assumptions + bd[0] = 6; // must be <= 9 for lookahead assumptions + t = table; + t = inftree.Inflate_trees_dynamic(257 + (t & 0x1f), + 1 + ((t >> 5) & 0x1f), + blens, bl, bd, tli, tdi, hufts, z); + + if (t != Z_OK) + { + if (t == Z_DATA_ERROR) + { + blens = null; + mode = BAD; + } + r = t; + + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } + codes.Init(bl[0], bd[0], hufts, tli[0], hufts, tdi[0]); + } + mode = CODES; + } // case DTREE + if (this.mode == CODES) + { + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + + if ((r = codes.Proc(r)) != Z_STREAM_END) + { + return Inflate_flush(r); + } + r = Z_OK; + codes.Free(z); + + p = z.next_in_index; n = z.avail_in; b = bitb; k = bitk; + q = write; m = (int)(q < read ? read - q - 1 : end - q); + + if (last == 0) + { + mode = TYPE; + continue; // break; + } + mode = DRY; + } // case CODES + if (this.mode == DRY) + { + write = q; + r = Inflate_flush(r); + q = write; m = (int)(q < read ? read - q - 1 : end - q); + if (read != write) + { + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } + mode = DONE; + } // case DRY + if (this.mode == DONE) + { + r = Z_STREAM_END; + + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } // case DONE + if (this.mode == BAD) + { + r = Z_DATA_ERROR; + + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } // case BAD + + break; // default + } + + r = Z_STREAM_ERROR; + + bitb = b; bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + write = q; + return Inflate_flush(r); + } + + internal void Free() + { + Reset(); + window = null; + hufts = null; + //ZFREE(z, s); + } + + internal void Set_dictionary(byte[] d, int start, int n) + { + Array.Copy(d, start, window, 0, n); + read = write = n; + } + + // Returns true if inflate is currently at the end of a block generated + // by Z_SYNC_FLUSH or Z_FULL_FLUSH. + internal int Sync_point() => mode == LENS ? 1 : 0; + + // copy as much as possible from the sliding window to the output area + internal int Inflate_flush(int r) + { + int n; + int p; + int q; + + // local copies of source and destination pointers + p = z.next_out_index; + q = read; + + // compute number of bytes to copy as far as end of window + n = (int)((q <= write ? write : end) - q); + if (n > z.avail_out) n = z.avail_out; + if (n != 0 && r == Z_BUF_ERROR) r = Z_OK; + + // update counters + z.avail_out -= n; + z.total_out += n; + + // update check information + if (check && n > 0) + { + z.adler.Update(window, q, n); + } + + // copy as far as end of window + Array.Copy(window, q, z.next_out, p, n); + p += n; + q += n; + + // see if more to copy at beginning of window + if (q == end) + { + // wrap pointers + q = 0; + if (write == end) + write = 0; + + // compute bytes to copy + n = write - q; + if (n > z.avail_out) n = z.avail_out; + if (n != 0 && r == Z_BUF_ERROR) r = Z_OK; + + // update counters + z.avail_out -= n; + z.total_out += n; + + // update check information + if (check && n > 0) + { + z.adler.Update(window, q, n); + } + + // copy + Array.Copy(window, q, z.next_out, p, n); + p += n; + q += n; + } + + // update pointers + z.next_out_index = p; + read = q; + + // done + return r; + } + } +} diff --git a/src/DotNetty.Codecs/Compression/InfCodes.cs b/src/DotNetty.Codecs/Compression/InfCodes.cs new file mode 100644 index 0000000..3a429dd --- /dev/null +++ b/src/DotNetty.Codecs/Compression/InfCodes.cs @@ -0,0 +1,695 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + using System; + using DotNetty.Common.Utilities; + + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/InfCodes.java + /// + sealed class InfCodes + { + static readonly int[] inflate_mask = + { + 0x00000000, 0x00000001, 0x00000003, 0x00000007, 0x0000000f, + 0x0000001f, 0x0000003f, 0x0000007f, 0x000000ff, 0x000001ff, + 0x000003ff, 0x000007ff, 0x00000fff, 0x00001fff, 0x00003fff, + 0x00007fff, 0x0000ffff + }; + + const int Z_OK = 0; + const int Z_STREAM_END = 1; + //const int Z_NEED_DICT = 2; + //const int Z_ERRNO = -1; + const int Z_STREAM_ERROR = -2; + const int Z_DATA_ERROR = -3; + //const int Z_MEM_ERROR = -4; + //const int Z_BUF_ERROR = -5; + //const int Z_VERSION_ERROR = -6; + + // waiting for "i:"=input, + // "o:"=output, + // "x:"=nothing + const int START = 0; // x: set up for LEN + const int LEN = 1; // i: get length/literal/eob next + const int LENEXT = 2; // i: getting length extra (have base) + const int DIST = 3; // i: get distance next + const int DISTEXT = 4;// i: getting distance extra + const int COPY = 5; // o: copying bytes in window, waiting for space + const int LIT = 6; // o: got literal, waiting for output space + const int WASH = 7; // o: got eob, possibly still output waiting + const int END = 8; // x: got eob and all data flushed + const int BADCODE = 9;// x: got error + + int mode; // current inflate_codes mode + + // mode dependent information + int len; + + int[] tree; // pointer into tree + int tree_index; + int need; // bits needed + + int lit; + + // if EXT or COPY, where and how much + int get; // bits to get for extra + int dist; // distance back to copy from + + byte lbits; // ltree bits decoded per branch + byte dbits; // dtree bits decoder per branch + int[] ltree; // literal/length/eob tree + int ltree_index; // literal/length/eob tree + int[] dtree; // distance tree + int dtree_index; // distance tree + + readonly ZStream z; + readonly InfBlocks s; + + internal InfCodes(ZStream z, InfBlocks s) + { + this.z = z; + this.s = s; + } + + internal void Init(int bl, int bd, + int[] tl, int tl_index, + int[] td, int td_index) + { + mode = START; + lbits = (byte)bl; + dbits = (byte)bd; + ltree = tl; + ltree_index = tl_index; + dtree = td; + dtree_index = td_index; + tree = null; + } + + internal int Proc(int r) + { + int j; // temporary storage + //int[] t; // temporary pointer + int tindex; // temporary pointer + int e; // extra bits or operation + int b = 0; // bit buffer + int k = 0; // bits in bit buffer + int p = 0; // input data pointer + int n; // bytes available there + int q; // output window write pointer + int m; // bytes to end of window or read pointer + int f; // pointer to copy strings from + + // copy input/output information to locals (UPDATE macro restores) + p = z.next_in_index; n = z.avail_in; b = s.bitb; k = s.bitk; + q = s.write; m = q < s.read ? s.read - q - 1 : s.end - q; + + // process input and output based on current state + while (true) + { + // waiting for "i:"=input, "o:"=output, "x:"=nothing + if (this.mode == START) // x: set up for LEN + { + if (m >= 258 && n >= 10) + { + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + r = Inflate_fast(lbits, dbits, + ltree, ltree_index, + dtree, dtree_index, + s, z); + + p = z.next_in_index; n = z.avail_in; b = s.bitb; k = s.bitk; + q = s.write; m = q < s.read ? s.read - q - 1 : s.end - q; + + if (r != Z_OK) + { + mode = r == Z_STREAM_END ? WASH : BADCODE; + continue; // break; + } + } + need = lbits; + tree = ltree; + tree_index = ltree_index; + + mode = LEN; + } // case START + if (this.mode == LEN) // i: get length/literal/eob next + { + j = need; + + while (k < (j)) + { + if (n != 0) r = Z_OK; + else + { + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + return s.Inflate_flush(r); + } + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + tindex = (tree_index + (b & inflate_mask[j])) * 3; + + b = b.RightUShift(tree[tindex + 1]); + k -= (tree[tindex + 1]); + + e = tree[tindex]; + + if (e == 0) // literal + { + lit = tree[tindex + 2]; + mode = LIT; + continue; // break; + } + if ((e & 16) != 0) // length + { + get = e & 15; + len = tree[tindex + 2]; + mode = LENEXT; + continue; // break; + } + if ((e & 64) == 0) // next table + { + need = e; + tree_index = tindex / 3 + tree[tindex + 2]; + continue; // break; + } + if ((e & 32) != 0) // end of block + { + mode = WASH; + continue; // break; + } + mode = BADCODE; // invalid code + z.msg = "invalid literal/length code"; + r = Z_DATA_ERROR; + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + return s.Inflate_flush(r); + } // case LEN + if (this.mode == LENEXT) // i: getting length extra (have base) + { + j = get; + + while (k < (j)) + { + if (n != 0) r = Z_OK; + else + { + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + return s.Inflate_flush(r); + } + n--; b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + len += (b & inflate_mask[j]); + + b >>= j; + k -= j; + + need = dbits; + tree = dtree; + tree_index = dtree_index; + mode = DIST; + } // case LENEXT + if (this.mode == DIST) // i: get distance next + { + j = need; + + while (k < (j)) + { + if (n != 0) r = Z_OK; + else + { + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + return s.Inflate_flush(r); + } + n--; b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + tindex = (tree_index + (b & inflate_mask[j])) * 3; + + b >>= tree[tindex + 1]; + k -= tree[tindex + 1]; + + e = (tree[tindex]); + if ((e & 16) != 0) + { // distance + get = e & 15; + dist = tree[tindex + 2]; + mode = DISTEXT; + continue; // break; + } + if ((e & 64) == 0) + { // next table + need = e; + tree_index = tindex / 3 + tree[tindex + 2]; + continue; // break; + } + mode = BADCODE; // invalid code + z.msg = "invalid distance code"; + r = Z_DATA_ERROR; + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + return s.Inflate_flush(r); + } // case DIST + if (this.mode == DISTEXT) // i: getting distance extra + { + j = get; + + while (k < (j)) + { + if (n != 0) r = Z_OK; + else + { + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + return s.Inflate_flush(r); + } + n--; b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + dist += (b & inflate_mask[j]); + + b >>= j; + k -= j; + + mode = COPY; + } // case DISTEXT + if (this.mode == COPY) // o: copying bytes in window, waiting for space + { + f = q - dist; + while (f < 0) + { // modulo window size-"while" instead + f += s.end; // of "if" handles invalid distances + } + while (len != 0) + { + if (m == 0) + { + if (q == s.end && s.read != 0) { q = 0; m = q < s.read ? s.read - q - 1 : s.end - q; } + if (m == 0) + { + s.write = q; r = s.Inflate_flush(r); + q = s.write; m = q < s.read ? s.read - q - 1 : s.end - q; + + if (q == s.end && s.read != 0) { q = 0; m = q < s.read ? s.read - q - 1 : s.end - q; } + + if (m == 0) + { + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + return s.Inflate_flush(r); + } + } + } + + s.window[q++] = s.window[f++]; m--; + + if (f == s.end) + f = 0; + len--; + } + mode = START; + continue; + } // case COPY + if (this.mode == LIT) // o: got literal, waiting for output space + { + if (m == 0) + { + if (q == s.end && s.read != 0) { q = 0; m = q < s.read ? s.read - q - 1 : s.end - q; } + if (m == 0) + { + s.write = q; r = s.Inflate_flush(r); + q = s.write; m = q < s.read ? s.read - q - 1 : s.end - q; + + if (q == s.end && s.read != 0) { q = 0; m = q < s.read ? s.read - q - 1 : s.end - q; } + if (m == 0) + { + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + return s.Inflate_flush(r); + } + } + } + r = Z_OK; + + s.window[q++] = (byte)lit; m--; + + mode = START; + continue; + } // case LIT + if (this.mode == WASH) // o: got eob, possibly more output + { + if (k > 7) + { // return unused byte, if any + k -= 8; + n++; + p--; // can always return one + } + + s.write = q; r = s.Inflate_flush(r); + q = s.write; m = q < s.read ? s.read - q - 1 : s.end - q; + + if (s.read != s.write) + { + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + return s.Inflate_flush(r); + } + mode = END; + } // case WASH + if (this.mode == END) + { + r = Z_STREAM_END; + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + return s.Inflate_flush(r); + } // case END + if (this.mode == BADCODE) // x: got error + { + r = Z_DATA_ERROR; + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + return s.Inflate_flush(r); + } // case BADCODE + + // default + break; + } // while + + // default + r = Z_STREAM_ERROR; + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + + return s.Inflate_flush(r); + } + + internal void Free(ZStream z) + { + // ZFREE(z, c); + } + + // Called with number of bytes left to write in window at least 258 + // (the maximum string length) and number of input bytes available + // at least ten. The ten bytes are six bytes for the longest length/ + // distance pair plus four bytes for overloading the bit buffer. + + static int Inflate_fast( + int bl, int bd, + int[] tl, int tl_index, + int[] td, int td_index, + InfBlocks s, ZStream z) + { + int t; // temporary pointer + int[] tp; // temporary pointer + int tp_index; // temporary pointer + int e; // extra bits or operation + int b; // bit buffer + int k; // bits in bit buffer + int p; // input data pointer + int n; // bytes available there + int q; // output window write pointer + int m; // bytes to end of window or read pointer + int ml; // mask for literal/length tree + int md; // mask for distance tree + int c; // bytes to copy + int d; // distance back to copy from + int r; // copy source pointer + + int tp_index_t_3; // (tp_index+t)*3 + + // load input, output, bit values + p = z.next_in_index; n = z.avail_in; b = s.bitb; k = s.bitk; + q = s.write; m = q < s.read ? s.read - q - 1 : s.end - q; + + // initialize masks + ml = inflate_mask[bl]; + md = inflate_mask[bd]; + + // do until not enough input or output space for fast loop + do + { // assume called with m >= 258 && n >= 10 + // get literal/length code + while (k < (20)) + { // max bits for literal/length code + n--; + b |= (z.next_in[p++] & 0xff) << k; k += 8; + } + + t = b & ml; + tp = tl; + tp_index = tl_index; + tp_index_t_3 = (tp_index + t) * 3; + if ((e = tp[tp_index_t_3]) == 0) + { + b >>= (tp[tp_index_t_3 + 1]); k -= (tp[tp_index_t_3 + 1]); + + s.window[q++] = (byte)tp[tp_index_t_3 + 2]; + m--; + continue; + } + do + { + + b >>= (tp[tp_index_t_3 + 1]); k -= (tp[tp_index_t_3 + 1]); + + if ((e & 16) != 0) + { + e &= 15; + c = tp[tp_index_t_3 + 2] + ((int)b & inflate_mask[e]); + + b >>= e; k -= e; + + // decode distance base of block to copy + while (k < (15)) + { // max bits for distance code + n--; + b |= (z.next_in[p++] & 0xff) << k; k += 8; + } + + t = b & md; + tp = td; + tp_index = td_index; + tp_index_t_3 = (tp_index + t) * 3; + e = tp[tp_index_t_3]; + + do + { + + b >>= (tp[tp_index_t_3 + 1]); k -= (tp[tp_index_t_3 + 1]); + + if ((e & 16) != 0) + { + // get extra bits to add to distance base + e &= 15; + while (k < (e)) + { // get extra bits (up to 13) + n--; + b |= (z.next_in[p++] & 0xff) << k; k += 8; + } + + d = tp[tp_index_t_3 + 2] + (b & inflate_mask[e]); + + b >>= (e); k -= (e); + + // do the copy + m -= c; + if (q >= d) + { // offset before dest + // just copy + r = q - d; + if (q - r > 0 && 2 > (q - r)) + { + s.window[q++] = s.window[r++]; // minimum count is three, + s.window[q++] = s.window[r++]; // so unroll loop a little + c -= 2; + } + else + { + Array.Copy(s.window, r, s.window, q, 2); + q += 2; r += 2; c -= 2; + } + } + else + { // else offset after destination + r = q - d; + do + { + r += s.end; // force pointer in window + } while (r < 0); // covers invalid distances + e = s.end - r; + if (c > e) + { // if source crosses, + c -= e; // wrapped copy + if (q - r > 0 && e > (q - r)) + { + do { s.window[q++] = s.window[r++]; } + while (--e != 0); + } + else + { + Array.Copy(s.window, r, s.window, q, e); + q += e; r += e; e = 0; + } + r = 0; // copy rest from start of window + } + + } + + // copy all or what's left + if (q - r > 0 && c > (q - r)) + { + do { s.window[q++] = s.window[r++]; } + while (--c != 0); + } + else + { + Array.Copy(s.window, r, s.window, q, c); + q += c; r += c; c = 0; + } + break; + } + else if ((e & 64) == 0) + { + t += tp[tp_index_t_3 + 2]; + t += (b & inflate_mask[e]); + tp_index_t_3 = (tp_index + t) * 3; + e = tp[tp_index_t_3]; + } + else + { + z.msg = "invalid distance code"; + + c = z.avail_in - n; c = (k >> 3) < c ? k >> 3 : c; n += c; p -= c; k -= c << 3; + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + + return Z_DATA_ERROR; + } + } + while (true); + break; + } + + if ((e & 64) == 0) + { + t += tp[tp_index_t_3 + 2]; + t += (b & inflate_mask[e]); + tp_index_t_3 = (tp_index + t) * 3; + if ((e = tp[tp_index_t_3]) == 0) + { + + b >>= (tp[tp_index_t_3 + 1]); k -= (tp[tp_index_t_3 + 1]); + + s.window[q++] = (byte)tp[tp_index_t_3 + 2]; + m--; + break; + } + } + else if ((e & 32) != 0) + { + + c = z.avail_in - n; c = (k >> 3) < c ? k >> 3 : c; n += c; p -= c; k -= c << 3; + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + + return Z_STREAM_END; + } + else + { + z.msg = "invalid literal/length code"; + + c = z.avail_in - n; c = (k >> 3) < c ? k >> 3 : c; n += c; p -= c; k -= c << 3; + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + + return Z_DATA_ERROR; + } + } + while (true); + } + while (m >= 258 && n >= 10); + + // not enough input or output--restore pointers and return + c = z.avail_in - n; c = (k >> 3) < c ? k >> 3 : c; n += c; p -= c; k -= c << 3; + + s.bitb = b; s.bitk = k; + z.avail_in = n; z.total_in += p - z.next_in_index; z.next_in_index = p; + s.write = q; + + return Z_OK; + } + } +} diff --git a/src/DotNetty.Codecs/Compression/InfTree.cs b/src/DotNetty.Codecs/Compression/InfTree.cs new file mode 100644 index 0000000..181c7cb --- /dev/null +++ b/src/DotNetty.Codecs/Compression/InfTree.cs @@ -0,0 +1,607 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + using System; + using DotNetty.Common.Utilities; + + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/InfTree.java + /// + sealed class InfTree + { + const int MANY = 1440; + + const int Z_OK = 0; + //const int Z_STREAM_END = 1; + //const int Z_NEED_DICT = 2; + //const int Z_ERRNO = -1; + //const int Z_STREAM_ERROR = -2; + const int Z_DATA_ERROR = -3; + const int Z_MEM_ERROR = -4; + const int Z_BUF_ERROR = -5; + //const int Z_VERSION_ERROR = -6; + + const int fixed_bl = 9; + const int fixed_bd = 5; + + static readonly int[] fixed_tl = + { + 96, 7, 256, 0, 8, 80, 0, 8, 16, 84, 8, 115, + 82, 7, 31, 0, 8, 112, 0, 8, 48, 0, 9, 192, + 80, 7, 10, 0, 8, 96, 0, 8, 32, 0, 9, 160, + 0, 8, 0, 0, 8, 128, 0, 8, 64, 0, 9, 224, + 80, 7, 6, 0, 8, 88, 0, 8, 24, 0, 9, 144, + 83, 7, 59, 0, 8, 120, 0, 8, 56, 0, 9, 208, + 81, 7, 17, 0, 8, 104, 0, 8, 40, 0, 9, 176, + 0, 8, 8, 0, 8, 136, 0, 8, 72, 0, 9, 240, + 80, 7, 4, 0, 8, 84, 0, 8, 20, 85, 8, 227, + 83, 7, 43, 0, 8, 116, 0, 8, 52, 0, 9, 200, + 81, 7, 13, 0, 8, 100, 0, 8, 36, 0, 9, 168, + 0, 8, 4, 0, 8, 132, 0, 8, 68, 0, 9, 232, + 80, 7, 8, 0, 8, 92, 0, 8, 28, 0, 9, 152, + 84, 7, 83, 0, 8, 124, 0, 8, 60, 0, 9, 216, + 82, 7, 23, 0, 8, 108, 0, 8, 44, 0, 9, 184, + 0, 8, 12, 0, 8, 140, 0, 8, 76, 0, 9, 248, + 80, 7, 3, 0, 8, 82, 0, 8, 18, 85, 8, 163, + 83, 7, 35, 0, 8, 114, 0, 8, 50, 0, 9, 196, + 81, 7, 11, 0, 8, 98, 0, 8, 34, 0, 9, 164, + 0, 8, 2, 0, 8, 130, 0, 8, 66, 0, 9, 228, + 80, 7, 7, 0, 8, 90, 0, 8, 26, 0, 9, 148, + 84, 7, 67, 0, 8, 122, 0, 8, 58, 0, 9, 212, + 82, 7, 19, 0, 8, 106, 0, 8, 42, 0, 9, 180, + 0, 8, 10, 0, 8, 138, 0, 8, 74, 0, 9, 244, + 80, 7, 5, 0, 8, 86, 0, 8, 22, 192, 8, 0, + 83, 7, 51, 0, 8, 118, 0, 8, 54, 0, 9, 204, + 81, 7, 15, 0, 8, 102, 0, 8, 38, 0, 9, 172, + 0, 8, 6, 0, 8, 134, 0, 8, 70, 0, 9, 236, + 80, 7, 9, 0, 8, 94, 0, 8, 30, 0, 9, 156, + 84, 7, 99, 0, 8, 126, 0, 8, 62, 0, 9, 220, + 82, 7, 27, 0, 8, 110, 0, 8, 46, 0, 9, 188, + 0, 8, 14, 0, 8, 142, 0, 8, 78, 0, 9, 252, + 96, 7, 256, 0, 8, 81, 0, 8, 17, 85, 8, 131, + 82, 7, 31, 0, 8, 113, 0, 8, 49, 0, 9, 194, + 80, 7, 10, 0, 8, 97, 0, 8, 33, 0, 9, 162, + 0, 8, 1, 0, 8, 129, 0, 8, 65, 0, 9, 226, + 80, 7, 6, 0, 8, 89, 0, 8, 25, 0, 9, 146, + 83, 7, 59, 0, 8, 121, 0, 8, 57, 0, 9, 210, + 81, 7, 17, 0, 8, 105, 0, 8, 41, 0, 9, 178, + 0, 8, 9, 0, 8, 137, 0, 8, 73, 0, 9, 242, + 80, 7, 4, 0, 8, 85, 0, 8, 21, 80, 8, 258, + 83, 7, 43, 0, 8, 117, 0, 8, 53, 0, 9, 202, + 81, 7, 13, 0, 8, 101, 0, 8, 37, 0, 9, 170, + 0, 8, 5, 0, 8, 133, 0, 8, 69, 0, 9, 234, + 80, 7, 8, 0, 8, 93, 0, 8, 29, 0, 9, 154, + 84, 7, 83, 0, 8, 125, 0, 8, 61, 0, 9, 218, + 82, 7, 23, 0, 8, 109, 0, 8, 45, 0, 9, 186, + 0, 8, 13, 0, 8, 141, 0, 8, 77, 0, 9, 250, + 80, 7, 3, 0, 8, 83, 0, 8, 19, 85, 8, 195, + 83, 7, 35, 0, 8, 115, 0, 8, 51, 0, 9, 198, + 81, 7, 11, 0, 8, 99, 0, 8, 35, 0, 9, 166, + 0, 8, 3, 0, 8, 131, 0, 8, 67, 0, 9, 230, + 80, 7, 7, 0, 8, 91, 0, 8, 27, 0, 9, 150, + 84, 7, 67, 0, 8, 123, 0, 8, 59, 0, 9, 214, + 82, 7, 19, 0, 8, 107, 0, 8, 43, 0, 9, 182, + 0, 8, 11, 0, 8, 139, 0, 8, 75, 0, 9, 246, + 80, 7, 5, 0, 8, 87, 0, 8, 23, 192, 8, 0, + 83, 7, 51, 0, 8, 119, 0, 8, 55, 0, 9, 206, + 81, 7, 15, 0, 8, 103, 0, 8, 39, 0, 9, 174, + 0, 8, 7, 0, 8, 135, 0, 8, 71, 0, 9, 238, + 80, 7, 9, 0, 8, 95, 0, 8, 31, 0, 9, 158, + 84, 7, 99, 0, 8, 127, 0, 8, 63, 0, 9, 222, + 82, 7, 27, 0, 8, 111, 0, 8, 47, 0, 9, 190, + 0, 8, 15, 0, 8, 143, 0, 8, 79, 0, 9, 254, + 96, 7, 256, 0, 8, 80, 0, 8, 16, 84, 8, 115, + 82, 7, 31, 0, 8, 112, 0, 8, 48, 0, 9, 193, + + 80, 7, 10, 0, 8, 96, 0, 8, 32, 0, 9, 161, + 0, 8, 0, 0, 8, 128, 0, 8, 64, 0, 9, 225, + 80, 7, 6, 0, 8, 88, 0, 8, 24, 0, 9, 145, + 83, 7, 59, 0, 8, 120, 0, 8, 56, 0, 9, 209, + 81, 7, 17, 0, 8, 104, 0, 8, 40, 0, 9, 177, + 0, 8, 8, 0, 8, 136, 0, 8, 72, 0, 9, 241, + 80, 7, 4, 0, 8, 84, 0, 8, 20, 85, 8, 227, + 83, 7, 43, 0, 8, 116, 0, 8, 52, 0, 9, 201, + 81, 7, 13, 0, 8, 100, 0, 8, 36, 0, 9, 169, + 0, 8, 4, 0, 8, 132, 0, 8, 68, 0, 9, 233, + 80, 7, 8, 0, 8, 92, 0, 8, 28, 0, 9, 153, + 84, 7, 83, 0, 8, 124, 0, 8, 60, 0, 9, 217, + 82, 7, 23, 0, 8, 108, 0, 8, 44, 0, 9, 185, + 0, 8, 12, 0, 8, 140, 0, 8, 76, 0, 9, 249, + 80, 7, 3, 0, 8, 82, 0, 8, 18, 85, 8, 163, + 83, 7, 35, 0, 8, 114, 0, 8, 50, 0, 9, 197, + 81, 7, 11, 0, 8, 98, 0, 8, 34, 0, 9, 165, + 0, 8, 2, 0, 8, 130, 0, 8, 66, 0, 9, 229, + 80, 7, 7, 0, 8, 90, 0, 8, 26, 0, 9, 149, + 84, 7, 67, 0, 8, 122, 0, 8, 58, 0, 9, 213, + 82, 7, 19, 0, 8, 106, 0, 8, 42, 0, 9, 181, + 0, 8, 10, 0, 8, 138, 0, 8, 74, 0, 9, 245, + 80, 7, 5, 0, 8, 86, 0, 8, 22, 192, 8, 0, + 83, 7, 51, 0, 8, 118, 0, 8, 54, 0, 9, 205, + 81, 7, 15, 0, 8, 102, 0, 8, 38, 0, 9, 173, + 0, 8, 6, 0, 8, 134, 0, 8, 70, 0, 9, 237, + 80, 7, 9, 0, 8, 94, 0, 8, 30, 0, 9, 157, + 84, 7, 99, 0, 8, 126, 0, 8, 62, 0, 9, 221, + 82, 7, 27, 0, 8, 110, 0, 8, 46, 0, 9, 189, + 0, 8, 14, 0, 8, 142, 0, 8, 78, 0, 9, 253, + 96, 7, 256, 0, 8, 81, 0, 8, 17, 85, 8, 131, + 82, 7, 31, 0, 8, 113, 0, 8, 49, 0, 9, 195, + 80, 7, 10, 0, 8, 97, 0, 8, 33, 0, 9, 163, + 0, 8, 1, 0, 8, 129, 0, 8, 65, 0, 9, 227, + 80, 7, 6, 0, 8, 89, 0, 8, 25, 0, 9, 147, + 83, 7, 59, 0, 8, 121, 0, 8, 57, 0, 9, 211, + 81, 7, 17, 0, 8, 105, 0, 8, 41, 0, 9, 179, + 0, 8, 9, 0, 8, 137, 0, 8, 73, 0, 9, 243, + 80, 7, 4, 0, 8, 85, 0, 8, 21, 80, 8, 258, + 83, 7, 43, 0, 8, 117, 0, 8, 53, 0, 9, 203, + 81, 7, 13, 0, 8, 101, 0, 8, 37, 0, 9, 171, + 0, 8, 5, 0, 8, 133, 0, 8, 69, 0, 9, 235, + 80, 7, 8, 0, 8, 93, 0, 8, 29, 0, 9, 155, + 84, 7, 83, 0, 8, 125, 0, 8, 61, 0, 9, 219, + 82, 7, 23, 0, 8, 109, 0, 8, 45, 0, 9, 187, + 0, 8, 13, 0, 8, 141, 0, 8, 77, 0, 9, 251, + 80, 7, 3, 0, 8, 83, 0, 8, 19, 85, 8, 195, + 83, 7, 35, 0, 8, 115, 0, 8, 51, 0, 9, 199, + 81, 7, 11, 0, 8, 99, 0, 8, 35, 0, 9, 167, + 0, 8, 3, 0, 8, 131, 0, 8, 67, 0, 9, 231, + 80, 7, 7, 0, 8, 91, 0, 8, 27, 0, 9, 151, + 84, 7, 67, 0, 8, 123, 0, 8, 59, 0, 9, 215, + 82, 7, 19, 0, 8, 107, 0, 8, 43, 0, 9, 183, + 0, 8, 11, 0, 8, 139, 0, 8, 75, 0, 9, 247, + 80, 7, 5, 0, 8, 87, 0, 8, 23, 192, 8, 0, + 83, 7, 51, 0, 8, 119, 0, 8, 55, 0, 9, 207, + 81, 7, 15, 0, 8, 103, 0, 8, 39, 0, 9, 175, + 0, 8, 7, 0, 8, 135, 0, 8, 71, 0, 9, 239, + 80, 7, 9, 0, 8, 95, 0, 8, 31, 0, 9, 159, + 84, 7, 99, 0, 8, 127, 0, 8, 63, 0, 9, 223, + 82, 7, 27, 0, 8, 111, 0, 8, 47, 0, 9, 191, + 0, 8, 15, 0, 8, 143, 0, 8, 79, 0, 9, 255 + }; + + static readonly int[] fixed_td = + { + 80, 5, 1, 87, 5, 257, 83, 5, 17, 91, 5, 4097, + 81, 5, 5, 89, 5, 1025, 85, 5, 65, 93, 5, 16385, + 80, 5, 3, 88, 5, 513, 84, 5, 33, 92, 5, 8193, + 82, 5, 9, 90, 5, 2049, 86, 5, 129, 192, 5, 24577, + 80, 5, 2, 87, 5, 385, 83, 5, 25, 91, 5, 6145, + 81, 5, 7, 89, 5, 1537, 85, 5, 97, 93, 5, 24577, + 80, 5, 4, 88, 5, 769, 84, 5, 49, 92, 5, 12289, + 82, 5, 13, 90, 5, 3073, 86, 5, 193, 192, 5, 24577 + }; + + // Tables for deflate from PKZIP's appnote.txt. + static readonly int[] cplens = + { + // Copy lengths for literal codes 257..285 + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0 + }; + + // see note #13 above about 258 + static readonly int[] cplext = + { + // Extra bits for literal codes 257..285 + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, + 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0, 112, 112 // 112==invalid + }; + + static readonly int[] cpdist = + { + // Copy offsets for distance codes 0..29 + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, + 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, + 8193, 12289, 16385, 24577 + }; + + static readonly int[] cpdext = + { + // Extra bits for distance codes + 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, + 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, + 12, 12, 13, 13 + }; + + // If BMAX needs to be larger than 16, then h and x[] should be uLong. + const int BMAX = 15; // maximum bit length of any code + + int[] hn = null; // hufts used in space + int[] v = null; // work area for huft_build + int[] c = null; // bit length count table + int[] r = null; // table entry for structure assignment + int[] u = null; // table stack + int[] x = null; // bit offsets, then code stack + + int Huft_build( + int[] b, // code lengths in bits (all assumed <= BMAX) + int bindex, + int n, // number of codes (assumed <= 288) + int s, // number of simple-valued codes (0..s-1) + int[] d, // list of base values for non-simple codes + int[] e, // list of extra bits for non-simple codes + int[] t, // result: starting table + int[] m, // maximum lookup bits, returns actual + int[] hp, // space for trees + int[] hn, // hufts used in space + int[] v // working area: values in order of bit length + ) + { + // Given a list of code lengths and a maximum table size, make a set of + // tables to decode that set of codes. Return Z_OK on success, Z_BUF_ERROR + // if the given code set is incomplete (the tables are still built in this + // case), Z_DATA_ERROR if the input is invalid (an over-subscribed set of + // lengths), or Z_MEM_ERROR if not enough memory. + + int a; // counter for codes of length k + int f; // i repeats in table every f entries + int g; // maximum code length + int h; // table level + int i; // counter, current code + int j; // counter + int k; // number of bits in current code + int l; // bits per table (returned in m) + int mask; // (1 << w) - 1, to avoid cc -O bug on HP + int p; // pointer into c[], b[], or v[] + int q; // points to current table + int w; // bits before this table == (l * h) + int xp; // pointer into x + int y; // number of dummy codes added + int z; // number of entries in current table + + // Generate counts for each bit length + + p = 0; + i = n; + do + { + c[b[bindex + p]]++; + p++; + i--; // assume all entries <= BMAX + } + while (i != 0); + + if (c[0] == n) + { + // null input--all zero length codes + t[0] = -1; + m[0] = 0; + return Z_OK; + } + + // Find minimum and maximum length, bound *m by those + l = m[0]; + for (j = 1; j <= BMAX; j++) + if (c[j] != 0) + break; + k = j; // minimum code length + if (l < j) + { + l = j; + } + for (i = BMAX; i != 0; i--) + { + if (c[i] != 0) + break; + } + g = i; // maximum code length + if (l > i) + { + l = i; + } + m[0] = l; + + // Adjust last length count to fill out codes, if needed + for (y = 1 << j; j < i; j++, y <<= 1) + { + if ((y -= c[j]) < 0) + { + return Z_DATA_ERROR; + } + } + if ((y -= c[i]) < 0) + { + return Z_DATA_ERROR; + } + c[i] += y; + + // Generate starting offsets into the value table for each length + x[1] = j = 0; + p = 1; + xp = 2; + while (--i != 0) + { + // note that i == g from above + x[xp] = (j += c[p]); + xp++; + p++; + } + + // Make a table of values in order of bit lengths + i = 0; + p = 0; + do + { + if ((j = b[bindex + p]) != 0) + { + v[x[j]++] = i; + } + p++; + } + while (++i < n); + n = x[g]; // set n to length of v + + // Generate the Huffman codes and for each, make the table entries + x[0] = i = 0; // first Huffman code is zero + p = 0; // grab values in bit order + h = -1; // no tables yet--level -1 + w = -l; // bits decoded == (l * h) + u[0] = 0; // just to keep compilers happy + q = 0; // ditto + z = 0; // ditto + + // go through the bit lengths (k already is bits in shortest code) + for (; k <= g; k++) + { + a = c[k]; + while (a-- != 0) + { + // here i is the Huffman code of length k bits for value *p + // make tables up to required level + while (k > w + l) + { + h++; + w += l; // previous table always l bits + // compute minimum size table less than or equal to l bits + z = g - w; + z = (z > l) ? l : z; // table size upper limit + if ((f = 1 << (j = k - w)) > a + 1) + { + // try a k-w bit table + // too few codes for k-w bit table + f -= a + 1; // deduct codes from patterns left + xp = k; + if (j < z) + { + while (++j < z) + { + // try smaller tables up to z bits + if ((f <<= 1) <= c[++xp]) + break; // enough codes to use up j bits + f -= c[xp]; // else deduct codes from patterns + } + } + } + z = 1 << j; // table entries for j-bit table + + // allocate new table + if (hn[0] + z > MANY) + { + // (note: doesn't matter for fixed) + return Z_DATA_ERROR; // overflow of MANY + } + u[h] = q = /*hp+*/ hn[0]; // DEBUG + hn[0] += z; + + // connect to last table, if there is one + if (h != 0) + { + x[h] = i; // save pattern for backing up + r[0] = (byte)j; // bits in this table + r[1] = (byte)l; // bits to dump before this table + j = i.RightUShift(w - l); + r[2] = q - this.u[h - 1] - j; // offset to this table + Array.Copy(r, 0, hp, (u[h - 1] + j) * 3, 3); // connect to last table + } + else + { + t[0] = q; // first table is returned result + } + } + + // set up table entry in r + r[1] = (byte)(k - w); + if (p >= n) + { + r[0] = 128 + 64; // out of values--invalid code + } + else if (v[p] < s) + { + r[0] = (byte)(v[p] < 256 ? 0 : 32 + 64); // 256 is end-of-block + r[2] = v[p++]; // simple code is just the value + } + else + { + r[0] = (byte)(e[v[p] - s] + 16 + 64); // non-simple--look up in lists + r[2] = d[v[p++] - s]; + } + + // fill code-like entries with r + f = 1 << (k - w); + for (j = i.RightUShift(w); j < z; j += f) + { + Array.Copy(r, 0, hp, (q + j) * 3, 3); + } + + // backwards increment the k-bit code i + for (j = 1 << (k - 1); (i & j) != 0; j = j.RightUShift(1)) + { + i ^= j; + } + i ^= j; + + // backup over finished tables + mask = (1 << w) - 1; // needed on HP, cc -O bug + while ((i & mask) != x[h]) + { + h--; // don't need to update q + w -= l; + mask = (1 << w) - 1; + } + } + } + // Return Z_BUF_ERROR if we were given an incomplete table + return y != 0 && g != 1 ? Z_BUF_ERROR : Z_OK; + } + + internal int Inflate_trees_bits(int[] c, // 19 code lengths + int[] bb, // bits tree desired/actual depth + int[] tb, // bits tree result + int[] hp, // space for trees + ZStream z // for messages + ) + { + int result; + InitWorkArea(19); + hn[0] = 0; + result = Huft_build(c, 0, 19, 19, null, null, tb, bb, hp, hn, v); + + if (result == Z_DATA_ERROR) + { + z.msg = "oversubscribed dynamic bit lengths tree"; + } + else if (result == Z_BUF_ERROR || bb[0] == 0) + { + z.msg = "incomplete dynamic bit lengths tree"; + result = Z_DATA_ERROR; + } + return result; + } + + internal int Inflate_trees_dynamic( + int nl, // number of literal/length codes + int nd, // number of distance codes + int[] c, // that many (total) code lengths + int[] bl, // literal desired/actual bit depth + int[] bd, // distance desired/actual bit depth + int[] tl, // literal/length tree result + int[] td, // distance tree result + int[] hp, // space for trees + ZStream z // for messages + ) + { + int result; + + // build literal/length tree + InitWorkArea(288); + hn[0] = 0; + result = Huft_build(c, 0, nl, 257, cplens, cplext, tl, bl, hp, hn, v); + if (result != Z_OK || bl[0] == 0) + { + if (result == Z_DATA_ERROR) + { + z.msg = "oversubscribed literal/length tree"; + } + else if (result != Z_MEM_ERROR) + { + z.msg = "incomplete literal/length tree"; + result = Z_DATA_ERROR; + } + return result; + } + + // build distance tree + InitWorkArea(288); + result = Huft_build(c, nl, nd, 0, cpdist, cpdext, td, bd, hp, hn, v); + + if (result != Z_OK || (bd[0] == 0 && nl > 257)) + { + if (result == Z_DATA_ERROR) + { + z.msg = "oversubscribed distance tree"; + } + else if (result == Z_BUF_ERROR) + { + z.msg = "incomplete distance tree"; + result = Z_DATA_ERROR; + } + else if (result != Z_MEM_ERROR) + { + z.msg = "empty distance tree with lengths"; + result = Z_DATA_ERROR; + } + return result; + } + + return Z_OK; + } + + internal static int Inflate_trees_fixed( + int[] bl, //literal desired/actual bit depth + int[] bd, //distance desired/actual bit depth + int[][] tl, //literal/length tree result + int[][] td, //distance tree result + ZStream z //for memory allocation + ) + { + bl[0] = fixed_bl; + bd[0] = fixed_bd; + tl[0] = fixed_tl; + td[0] = fixed_td; + return Z_OK; + } + + void InitWorkArea(int vsize) + { + if (hn == null) + { + hn = new int[1]; + v = new int[vsize]; + c = new int[BMAX + 1]; + r = new int[3]; + u = new int[BMAX]; + x = new int[BMAX + 1]; + } + if (v.Length < vsize) + { + v = new int[vsize]; + } + for (int i = 0; i < vsize; i++) + { + v[i] = 0; + } + for (int i = 0; i < BMAX + 1; i++) + { + c[i] = 0; + } + for (int i = 0; i < 3; i++) + { + r[i] = 0; + } + Array.Copy(c, 0, u, 0, BMAX); + Array.Copy(c, 0, x, 0, BMAX + 1); + } + } +} diff --git a/src/DotNetty.Codecs/Compression/Inflate.cs b/src/DotNetty.Codecs/Compression/Inflate.cs new file mode 100644 index 0000000..ee10443 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/Inflate.cs @@ -0,0 +1,962 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + using System; + using System.IO; + + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/Inflate.java + /// + sealed class Inflate + { + //const int MAX_WBITS = 15; // 32K LZ77 window + + // preset dictionary flag in zlib header + const int PRESET_DICT = 0x20; + + //const int Z_NO_FLUSH = 0; + //const int Z_PARTIAL_FLUSH = 1; + //const int Z_SYNC_FLUSH = 2; + //const int Z_FULL_FLUSH = 3; + const int Z_FINISH = 4; + + const int Z_DEFLATED = 8; + + const int Z_OK = 0; + const int Z_STREAM_END = 1; + const int Z_NEED_DICT = 2; + //const int Z_ERRNO = -1; + const int Z_STREAM_ERROR = -2; + const int Z_DATA_ERROR = -3; + //const int Z_MEM_ERROR = -4; + const int Z_BUF_ERROR = -5; + //const int Z_VERSION_ERROR = -6; + + //const int METHOD = 0; // waiting for method byte + //const int FLAG = 1; // waiting for flag byte + const int DICT4 = 2; // four dictionary check bytes to go + const int DICT3 = 3; // three dictionary check bytes to go + const int DICT2 = 4; // two dictionary check bytes to go + const int DICT1 = 5; // one dictionary check byte to go + const int DICT0 = 6; // waiting for inflateSetDictionary + const int BLOCKS = 7; // decompressing blocks + const int CHECK4 = 8; // four check bytes to go + const int CHECK3 = 9; // three check bytes to go + const int CHECK2 = 10; // two check bytes to go + const int CHECK1 = 11; // one check byte to go + const int DONE = 12; // finished check, done + const int BAD = 13; // got an error--stay here + + const int HEAD = 14; + const int LENGTH = 15; + const int TIME = 16; + const int OS = 17; + const int EXLEN = 18; + const int EXTRA = 19; + const int NAME = 20; + const int COMMENT = 21; + const int HCRC = 22; + const int FLAGS = 23; + + internal static readonly int INFLATE_ANY = 0x40000000; + + internal int mode; // current inflate mode + + // mode dependent information + int method; // if FLAGS, method byte + + // if CHECK, check values to compare + long was = -1; // computed check value + long need; // stream check value + + // if BAD, inflateSync's marker bytes count + int marker; + + // mode independent information + internal int wrap; // flag for no wrapper + // 0: no wrapper + // 1: zlib header + // 2: gzip header + // 4: auto detection + + int wbits; // log2(window size) (8..15, defaults to 15) + + InfBlocks blocks; // current inflate_blocks state + + readonly ZStream z; + + int flags; + + int need_bytes = -1; + byte[] crcbuf = new byte[4]; + + GZIPHeader gheader = null; + + internal int InflateReset() + { + if (z == null) + return Z_STREAM_ERROR; + + z.total_in = z.total_out = 0; + z.msg = null; + this.mode = HEAD; + this.need_bytes = -1; + this.blocks.Reset(); + return Z_OK; + } + + internal int InflateEnd() + { + if (blocks != null) + { + blocks.Free(); + } + return Z_OK; + } + + internal Inflate(ZStream z) + { + this.z = z; + } + + internal int InflateInit(int w) + { + z.msg = null; + blocks = null; + + // handle undocumented wrap option (no zlib header or check) + wrap = 0; + if (w < 0) + { + w = -w; + } + else if ((w & INFLATE_ANY) != 0) + { + wrap = 4; + w &= ~INFLATE_ANY; + if (w < 48) + w &= 15; + } + else if ((w & ~31) != 0) // for example, DEF_WBITS + 32 + { + wrap = 4; // zlib and gzip wrapped data should be accepted. + w &= 15; + } + else + { + wrap = (w >> 4) + 1; + if (w < 48) + w &= 15; + } + + if (w < 8 || w > 15) + { + InflateEnd(); + return Z_STREAM_ERROR; + } + if (blocks != null && wbits != w) + { + blocks.Free(); + blocks = null; + } + + // set window size + wbits = w; + + this.blocks = new InfBlocks(z, 1 << w); + + // reset state + InflateReset(); + + return Z_OK; + } + + internal int Inflate_I(int f) + { + //int hold = 0; + + int r; + int b; + + if (z == null || z.next_in == null) + { + if (f == Z_FINISH && this.mode == HEAD) + return Z_OK; + return Z_STREAM_ERROR; + } + + f = f == Z_FINISH ? Z_BUF_ERROR : Z_OK; + r = Z_BUF_ERROR; + while (true) + { + if (mode == HEAD) + { + if (wrap == 0) + { + this.mode = BLOCKS; + continue; // break; + } + + try { r = ReadBytes(2, r, f); } + catch (Return e) { return e.r; } + + if ((wrap == 4 || (wrap & 2) != 0) && + this.need == 0x8b1fL) // gzip header + { + if (wrap == 4) + { + wrap = 2; + } + z.adler = new CRC32(); + Checksum(2, this.need); + + if (gheader == null) + gheader = new GZIPHeader(); + + this.mode = FLAGS; + continue; // break; + } + + if ((wrap & 2) != 0) + { + this.mode = BAD; + z.msg = "incorrect header check"; + continue; // break; + } + + flags = 0; + + this.method = ((int)this.need) & 0xff; + b = ((int)(this.need >> 8)) & 0xff; + + if (((wrap & 1) == 0 || // check if zlib header allowed + (((this.method << 8) + b) % 31) != 0) && + (this.method & 0xf) != Z_DEFLATED) + { + if (wrap == 4) + { + z.next_in_index -= 2; + z.avail_in += 2; + z.total_in -= 2; + wrap = 0; + this.mode = BLOCKS; + continue; // break; + } + this.mode = BAD; + z.msg = "incorrect header check"; + // since zlib 1.2, it is allowted to inflateSync for this case. + /* + this.marker = 5; // can't try inflateSync + */ + continue; // break; + } + + if ((this.method & 0xf) != Z_DEFLATED) + { + this.mode = BAD; + z.msg = "unknown compression method"; + // since zlib 1.2, it is allowted to inflateSync for this case. + /* + this.marker = 5; // can't try inflateSync + */ + continue; // break; + } + + if (wrap == 4) + { + wrap = 1; + } + + if ((this.method >> 4) + 8 > this.wbits) + { + this.mode = BAD; + z.msg = "invalid window size"; + // since zlib 1.2, it is allowted to inflateSync for this case. + /* + this.marker = 5; // can't try inflateSync + */ + continue; // break; + } + + z.adler = new Adler32(); + + if ((b & PRESET_DICT) == 0) + { + this.mode = BLOCKS; + continue; // break; + } + this.mode = DICT4; + + } // case HEAD + if (this.mode == DICT4) + { + if (z.avail_in == 0) + return r; + r = f; + + z.avail_in--; + z.total_in++; + this.need = ((z.next_in[z.next_in_index++] & 0xff) << 24) & 0xff000000L; + this.mode = DICT3; + } // case DICT4 + if (this.mode == DICT3) + { + if (z.avail_in == 0) + return r; + r = f; + + z.avail_in--; + z.total_in++; + this.need += ((z.next_in[z.next_in_index++] & 0xff) << 16) & 0xff0000L; + this.mode = DICT2; + + } // case DICT3 + if (this.mode == DICT2) + { + if (z.avail_in == 0) + return r; + r = f; + + z.avail_in--; + z.total_in++; + this.need += ((z.next_in[z.next_in_index++] & 0xff) << 8) & 0xff00L; + this.mode = DICT1; + + } // case DICT2 + if (this.mode == DICT1) + { + if (z.avail_in == 0) + return r; + r = f; + + z.avail_in--; + z.total_in++; + this.need += (z.next_in[z.next_in_index++] & 0xffL); + z.adler.Reset(this.need); + this.mode = DICT0; + return Z_NEED_DICT; + } // case DICT1 + if (this.mode == DICT0) + { + this.mode = BAD; + z.msg = "need dictionary"; + this.marker = 0; // can try inflateSync + return Z_STREAM_ERROR; + } // case DICT0 + if (this.mode == BLOCKS) + { + r = this.blocks.Proc(r); + if (r == Z_DATA_ERROR) + { + this.mode = BAD; + this.marker = 0; // can try inflateSync + continue; // break; + } + if (r == Z_OK) + { + r = f; + } + if (r != Z_STREAM_END) + { + return r; + } + r = f; + this.was = z.adler.GetValue(); + this.blocks.Reset(); + if (this.wrap == 0) + { + this.mode = DONE; + continue; // break; + } + this.mode = CHECK4; + } // case BLOCKS + if (this.mode == CHECK4) + { + if (z.avail_in == 0) + return r; + r = f; + + z.avail_in--; + z.total_in++; + this.need = ((z.next_in[z.next_in_index++] & 0xff) << 24) & 0xff000000L; + this.mode = CHECK3; + } // case CHECK4 + if (this.mode == CHECK3) + { + if (z.avail_in == 0) + return r; + r = f; + + z.avail_in--; + z.total_in++; + this.need += ((z.next_in[z.next_in_index++] & 0xff) << 16) & 0xff0000L; + this.mode = CHECK2; + } // case CHECK3 + if (this.mode == CHECK2) + { + if (z.avail_in == 0) + return r; + r = f; + + z.avail_in--; + z.total_in++; + this.need += ((z.next_in[z.next_in_index++] & 0xff) << 8) & 0xff00L; + this.mode = CHECK1; + } // case CHECK2 + if (this.mode == CHECK1) + { + if (z.avail_in == 0) + return r; + r = f; + + z.avail_in--; + z.total_in++; + this.need += (z.next_in[z.next_in_index++] & 0xffL); + + if (flags != 0) + { + // gzip + this.need = ((this.need & 0xff000000) >> 24 | + (this.need & 0x00ff0000) >> 8 | + (this.need & 0x0000ff00) << 8 | + (this.need & 0x0000ffff) << 24) & 0xffffffffL; + } + + if (((int)(this.was)) != ((int)(this.need))) + { + z.msg = "incorrect data check"; + // chack is delayed + /* + this.mode = BAD; + this.marker = 5; // can't try inflateSync + break; + */ + } + else if (flags != 0 && gheader != null) + { + gheader.crc = this.need; + } + + this.mode = LENGTH; + } // case CHECK1 + if (this.mode == LENGTH) + { + if (wrap != 0 && flags != 0) + { + try { r = ReadBytes(4, r, f); } + catch (Return e) { return e.r; } + + if (z.msg != null && z.msg.Equals("incorrect data check")) + { + this.mode = BAD; + this.marker = 5; // can't try inflateSync + continue; // break; + } + + if (this.need != (z.total_out & 0xffffffffL)) + { + z.msg = "incorrect length check"; + this.mode = BAD; + continue; // break; + } + z.msg = null; + } + else + { + if (z.msg != null && z.msg.Equals("incorrect data check")) + { + this.mode = BAD; + this.marker = 5; // can't try inflateSync + continue; // break; + } + } + + this.mode = DONE; + } // case LENGTH + if (this.mode == DONE) + { + return Z_STREAM_END; + } + if (this.mode == BAD) + { + return Z_DATA_ERROR; + } + if (this.mode == FLAGS) + { + try { r = ReadBytes(2, r, f); } + catch (Return e) { return e.r; } + + flags = ((int)this.need) & 0xffff; + + if ((flags & 0xff) != Z_DEFLATED) + { + z.msg = "unknown compression method"; + this.mode = BAD; + continue; // break; + } + if ((flags & 0xe000) != 0) + { + z.msg = "unknown header flags set"; + this.mode = BAD; + continue; // break; + } + + if ((flags & 0x0200) != 0) + { + Checksum(2, this.need); + } + + this.mode = TIME; + } // case FLAGS + if (this.mode == TIME) + { + try + { + r = ReadBytes(4, r, f); + } + catch (Return e) + { + return e.r; + } + if (gheader != null) + gheader.time = this.need; + if ((flags & 0x0200) != 0) + { + Checksum(4, this.need); + } + this.mode = OS; + + } // case TIME + if (this.mode == OS) + { + try + { + r = ReadBytes(2, r, f); + } + catch (Return e) + { + return e.r; + } + if (gheader != null) + { + gheader.xflags = ((int)this.need) & 0xff; + gheader.os = (((int)this.need) >> 8) & 0xff; + } + if ((flags & 0x0200) != 0) + { + Checksum(2, this.need); + } + this.mode = EXLEN; + + } // case OS + if (this.mode == EXLEN) + { + if ((flags & 0x0400) != 0) + { + try + { + r = ReadBytes(2, r, f); + } + catch (Return e) + { + return e.r; + } + if (gheader != null) + { + gheader.extra = new byte[((int)this.need) & 0xffff]; + } + if ((flags & 0x0200) != 0) + { + Checksum(2, this.need); + } + } + else if (gheader != null) + { + gheader.extra = null; + } + this.mode = EXTRA; + + } // case EXLEN + if (this.mode == EXTRA) + { + if ((flags & 0x0400) != 0) + { + try + { + r = ReadBytes(r, f); + if (gheader != null) + { + byte[] foo = tmp_string.ToArray(); + tmp_string = null; + if (foo.Length == gheader.extra.Length) + { + Array.Copy(foo, 0, gheader.extra, 0, foo.Length); + } + else + { + z.msg = "bad extra field length"; + this.mode = BAD; + continue; // break; + } + } + } + catch (Return e) + { + return e.r; + } + } + else if (gheader != null) + { + gheader.extra = null; + } + this.mode = NAME; + } // case EXTRA + if (this.mode == NAME) + { + if ((flags & 0x0800) != 0) + { + try + { + r = ReadString(r, f); + if (gheader != null) + { + gheader.name = tmp_string.ToArray(); + } + tmp_string = null; + } + catch (Return e) + { + return e.r; + } + } + else if (gheader != null) + { + gheader.name = null; + } + this.mode = COMMENT; + + } // case NAME + if (this.mode == COMMENT) + { + if ((flags & 0x1000) != 0) + { + try + { + r = ReadString(r, f); + if (gheader != null) + { + gheader.comment = tmp_string.ToArray(); + } + tmp_string = null; + } + catch (Return e) + { + return e.r; + } + } + else if (gheader != null) + { + gheader.comment = null; + } + this.mode = HCRC; + } // case COMMENT + if (this.mode == HCRC) + { + if ((flags & 0x0200) != 0) + { + try { r = ReadBytes(2, r, f); } + catch (Return e) { return e.r; } + if (gheader != null) + { + gheader.hcrc = (int)(this.need & 0xffff); + } + if (this.need != (z.adler.GetValue() & 0xffffL)) + { + this.mode = BAD; + z.msg = "header crc mismatch"; + this.marker = 5; // can't try inflateSync + continue; // break; + } + } + z.adler = new CRC32(); + this.mode = BLOCKS; + continue; // break; + } // case HCRC + + // BAD + // Default + break; + } + + // Default + return Z_STREAM_ERROR; + } + + internal int InflateSetDictionary(byte[] dictionary, int dictLength) + { + if (z == null || (this.mode != DICT0 && this.wrap != 0)) + { + return Z_STREAM_ERROR; + } + + int index = 0; + int length = dictLength; + + if (this.mode == DICT0) + { + long adler_need = z.adler.GetValue(); + z.adler.Reset(); + z.adler.Update(dictionary, 0, dictLength); + if (z.adler.GetValue() != adler_need) + { + return Z_DATA_ERROR; + } + } + + z.adler.Reset(); + + if (length >= (1 << this.wbits)) + { + length = (1 << this.wbits) - 1; + index = dictLength - length; + } + this.blocks.Set_dictionary(dictionary, index, length); + this.mode = BLOCKS; + return Z_OK; + } + + static byte[] mark = { (byte)0, (byte)0, (byte)0xff, (byte)0xff }; + + internal int InflateSync() + { + int n; // number of bytes to look at + int p; // pointer to bytes + int m; // number of marker bytes found in a row + long r, w; // temporaries to save total_in and total_out + + // set up + if (z == null) + return Z_STREAM_ERROR; + if (this.mode != BAD) + { + this.mode = BAD; + this.marker = 0; + } + if ((n = z.avail_in) == 0) + return Z_BUF_ERROR; + + p = z.next_in_index; + m = this.marker; + // search + while (n != 0 && m < 4) + { + if (z.next_in[p] == mark[m]) + { + m++; + } + else if (z.next_in[p] != 0) + { + m = 0; + } + else + { + m = 4 - m; + } + p++; + n--; + } + + // restore + z.total_in += p - z.next_in_index; + z.next_in_index = p; + z.avail_in = n; + this.marker = m; + + // return no joy or set up to restart on a new block + if (m != 4) + { + return Z_DATA_ERROR; + } + r = z.total_in; + w = z.total_out; + InflateReset(); + z.total_in = r; + z.total_out = w; + this.mode = BLOCKS; + + return Z_OK; + } + + // Returns true if inflate is currently at the end of a block generated + // by Z_SYNC_FLUSH or Z_FULL_FLUSH. This function is used by one PPP + // implementation to provide an additional safety check. PPP uses Z_SYNC_FLUSH + // but removes the length bytes of the resulting empty stored block. When + // decompressing, PPP checks that at the end of input packet, inflate is + // waiting for these length bytes. + internal int InflateSyncPoint() + { + if (z == null || this.blocks == null) + return Z_STREAM_ERROR; + return this.blocks.Sync_point(); + } + + int ReadBytes(int n, int r, int f) + { + if (need_bytes == -1) + { + need_bytes = n; + this.need = 0; + } + while (need_bytes > 0) + { + if (z.avail_in == 0) { throw new Return(r); } + r = f; + z.avail_in--; z.total_in++; + this.need = this.need | + (long)((z.next_in[z.next_in_index++] & 0xff) << ((n - need_bytes) * 8)); + need_bytes--; + } + if (n == 2) + { + this.need &= 0xffffL; + } + else if (n == 4) + { + this.need &= 0xffffffffL; + } + need_bytes = -1; + return r; + } + + class Return : Exception + { + internal int r; + + internal Return(int r) { this.r = r; } + } + + MemoryStream tmp_string = null; + + int ReadString(int r, int f) + { + if (tmp_string == null) + { + tmp_string = new MemoryStream(); + } + int b = 0; + do + { + if (z.avail_in == 0) + { + throw new Return(r); + } + ; + r = f; + z.avail_in--; + z.total_in++; + b = z.next_in[z.next_in_index]; + if (b != 0) + tmp_string.Write(z.next_in, z.next_in_index, 1); + z.adler.Update(z.next_in, z.next_in_index, 1); + z.next_in_index++; + } + while (b != 0); + return r; + } + + int ReadBytes(int r, int f) + { + if (tmp_string == null) + { + tmp_string = new MemoryStream(); + } + int b = 0; + while (this.need > 0) + { + if (z.avail_in == 0) + { + throw new Return(r); + } + ; + r = f; + z.avail_in--; + z.total_in++; + b = z.next_in[z.next_in_index]; + tmp_string.Write(z.next_in, z.next_in_index, 1); + z.adler.Update(z.next_in, z.next_in_index, 1); + z.next_in_index++; + this.need--; + } + return r; + } + + void Checksum(int n, long v) + { + for (int i = 0; i < n; i++) + { + crcbuf[i] = (byte)(v & 0xff); + v >>= 8; + } + z.adler.Update(crcbuf, 0, n); + } + + public GZIPHeader getGZIPHeader() => this.gheader; + + internal bool InParsingHeader() + { + switch (mode) + { + case HEAD: + case DICT4: + case DICT3: + case DICT2: + case DICT1: + case FLAGS: + case TIME: + case OS: + case EXLEN: + case EXTRA: + case NAME: + case COMMENT: + case HCRC: + return true; + default: + return false; + } + } + } +} \ No newline at end of file diff --git a/src/DotNetty.Codecs/Compression/Inflater.cs b/src/DotNetty.Codecs/Compression/Inflater.cs new file mode 100644 index 0000000..547fd58 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/Inflater.cs @@ -0,0 +1,168 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/Inflater.java + /// + sealed class Inflater : ZStream + { + const int MAX_WBITS = 15; // 32K LZ77 window + const int DEF_WBITS = MAX_WBITS; + + //const int Z_NO_FLUSH = 0; + //const int Z_PARTIAL_FLUSH = 1; + //const int Z_SYNC_FLUSH = 2; + //const int Z_FULL_FLUSH = 3; + //const int Z_FINISH = 4; + + //const int MAX_MEM_LEVEL = 9; + + const int Z_OK = 0; + //const int Z_STREAM_END = 1; + //const int Z_NEED_DICT = 2; + //const int Z_ERRNO = -1; + const int Z_STREAM_ERROR = -2; + //const int Z_DATA_ERROR = -3; + //const int Z_MEM_ERROR = -4; + //const int Z_BUF_ERROR = -5; + //const int Z_VERSION_ERROR = -6; + + public Inflater() + { + Init(); + } + + public Inflater(JZlib.WrapperType wrapperType) : this(DEF_WBITS, wrapperType) + { + } + + public Inflater(int w, JZlib.WrapperType wrapperType) + { + int ret = Init(w, wrapperType); + if (ret != Z_OK) throw new GZIPException(ret + ": " + msg); + } + + public Inflater(int w) : this(w, false) + { + } + + public Inflater(bool nowrap) : this(DEF_WBITS, nowrap) + { + } + + public Inflater(int w, bool nowrap) + { + int ret = Init(w, nowrap); + if (ret != Z_OK) throw new GZIPException(ret + ": " + msg); + } + + //bool finished = false; + + public int Init() => Init(DEF_WBITS); + + public int Init(JZlib.WrapperType wrapperType) => Init(DEF_WBITS, wrapperType); + + public int Init(int w, JZlib.WrapperType wrapperType) + { + bool nowrap = false; + if (wrapperType == JZlib.W_NONE) + { + nowrap = true; + } + else if (wrapperType == JZlib.W_GZIP) + { + w += 16; + } + else if (wrapperType == JZlib.W_ANY) + { + w |= Compression.Inflate.INFLATE_ANY; + } + else if (wrapperType == JZlib.W_ZLIB) + { + } + return Init(w, nowrap); + } + + public int Init(bool nowrap) => Init(DEF_WBITS, nowrap); + + public int Init(int w) => Init(w, false); + + public int Init(int w, bool nowrap) + { + //finished = false; + istate = new Inflate(this); + return istate.InflateInit(nowrap ? -w : w); + } + + public int Inflate(int f) + { + if (istate == null) return Z_STREAM_ERROR; + int ret = istate.Inflate_I(f); + //if (ret == Z_STREAM_END) + // finished = true; + return ret; + } + + public override int End() + { + // finished = true; + if (istate == null) return Z_STREAM_ERROR; + int ret = istate.InflateEnd(); + // istate = null; + return ret; + } + + public int Sync() + { + if (istate == null) return Z_STREAM_ERROR; + return istate.InflateSync(); + } + public int SyncPoint() + { + if (istate == null) return Z_STREAM_ERROR; + return istate.InflateSyncPoint(); + } + + public int SetDictionary(byte[] dictionary, int dictLength) + { + if (istate == null) return Z_STREAM_ERROR; + return istate.InflateSetDictionary(dictionary, dictLength); + } + + public override bool Finished() => istate.mode == 12 /*DONE*/; + } +} diff --git a/src/DotNetty.Codecs/Compression/JZlib.cs b/src/DotNetty.Codecs/Compression/JZlib.cs new file mode 100644 index 0000000..3085fb0 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/JZlib.cs @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/JZlib.java + /// + public sealed class JZlib + { + const string VersionString = "1.1.0"; + public static string Version() => VersionString; + + public static readonly int MAX_WBITS = 15; // 32K LZ77 window + public static readonly int DEF_WBITS = MAX_WBITS; + + public enum WrapperType + { + NONE, + ZLIB, + GZIP, + ANY + } + + public static readonly WrapperType W_NONE = WrapperType.NONE; + public static readonly WrapperType W_ZLIB = WrapperType.ZLIB; + public static readonly WrapperType W_GZIP = WrapperType.GZIP; + public static readonly WrapperType W_ANY = WrapperType.ANY; + + // compression levels + public static readonly int Z_NO_COMPRESSION = 0; + public static readonly int Z_BEST_SPEED = 1; + public static readonly int Z_BEST_COMPRESSION = 9; + public static readonly int Z_DEFAULT_COMPRESSION = (-1); + + // compression strategy + public static readonly int Z_FILTERED = 1; + public static readonly int Z_HUFFMAN_ONLY = 2; + public static readonly int Z_DEFAULT_STRATEGY = 0; + + public static readonly int Z_NO_FLUSH = 0; + public static readonly int Z_PARTIAL_FLUSH = 1; + public static readonly int Z_SYNC_FLUSH = 2; + public static readonly int Z_FULL_FLUSH = 3; + public static readonly int Z_FINISH = 4; + + public static readonly int Z_OK = 0; + public static readonly int Z_STREAM_END = 1; + public static readonly int Z_NEED_DICT = 2; + public static readonly int Z_ERRNO = -1; + public static readonly int Z_STREAM_ERROR = -2; + public static readonly int Z_DATA_ERROR = -3; + public static readonly int Z_MEM_ERROR = -4; + public static readonly int Z_BUF_ERROR = -5; + public static readonly int Z_VERSION_ERROR = -6; + + // The three kinds of block type + public static readonly byte Z_BINARY = 0; + public static readonly byte Z_ASCII = 1; + public static readonly byte Z_UNKNOWN = 2; + + public static long Adler32_combine(long adler1, long adler2, long len2) => + Adler32.Combine(adler1, adler2, len2); + } +} diff --git a/src/DotNetty.Codecs/Compression/JZlibDecoder.cs b/src/DotNetty.Codecs/Compression/JZlibDecoder.cs new file mode 100644 index 0000000..e8b4bce --- /dev/null +++ b/src/DotNetty.Codecs/Compression/JZlibDecoder.cs @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Compression +{ + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using DotNetty.Buffers; + using DotNetty.Transport.Channels; + + public class JZlibDecoder : ZlibDecoder + { + readonly Inflater z = new Inflater(); + readonly byte[] dictionary; + volatile bool finished; + + public JZlibDecoder() : this(ZlibWrapper.ZlibOrNone) + { + } + + public JZlibDecoder(ZlibWrapper wrapper) + { + int resultCode = this.z.Init(ZlibUtil.ConvertWrapperType(wrapper)); + if (resultCode != JZlib.Z_OK) + { + ZlibUtil.Fail(this.z, "initialization failure", resultCode); + } + } + + public JZlibDecoder(byte[] dictionary) + { + Contract.Requires(dictionary != null); + this.dictionary = dictionary; + + int resultCode; + resultCode = this.z.InflateInit(JZlib.W_ZLIB); + if (resultCode != JZlib.Z_OK) + { + ZlibUtil.Fail(this.z, "initialization failure", resultCode); + } + } + + public override bool IsClosed => this.finished; + + protected internal override void Decode(IChannelHandlerContext context, IByteBuffer input, List output) + { + if (this.finished) + { + // Skip data received after finished. + input.SkipBytes(input.ReadableBytes); + return; + } + + int inputLength = input.ReadableBytes; + if (inputLength == 0) + { + return; + } + + try + { + // Configure input. + this.z.avail_in = inputLength; + if (input.HasArray) + { + this.z.next_in = input.Array; + this.z.next_in_index = input.ArrayOffset + input.ReaderIndex; + } + else + { + var array = new byte[inputLength]; + input.GetBytes(input.ReaderIndex, array); + this.z.next_in = array; + this.z.next_in_index = 0; + } + int oldNextInIndex = this.z.next_in_index; + + // Configure output. + int maxOutputLength = inputLength << 1; + IByteBuffer decompressed = context.Allocator.Buffer(maxOutputLength); + + try + { + while (true) + { + this.z.avail_out = maxOutputLength; + decompressed.EnsureWritable(maxOutputLength); + this.z.next_out = decompressed.Array; + this.z.next_out_index = decompressed.ArrayOffset + decompressed.WriterIndex; + int oldNextOutIndex = this.z.next_out_index; + + // Decompress 'in' into 'out' + int resultCode = this.z.Inflate(JZlib.Z_SYNC_FLUSH); + int outputLength = this.z.next_out_index - oldNextOutIndex; + if (outputLength > 0) + { + decompressed.SetWriterIndex(decompressed.WriterIndex + outputLength); + } + + if (resultCode == JZlib.Z_NEED_DICT) + { + if (this.dictionary == null) + { + ZlibUtil.Fail(this.z, "decompression failure", resultCode); + } + else + { + resultCode = this.z.InflateSetDictionary(this.dictionary, this.dictionary.Length); + if (resultCode != JZlib.Z_OK) + { + ZlibUtil.Fail(this.z, "failed to set the dictionary", resultCode); + } + } + continue; + } + if (resultCode == JZlib.Z_STREAM_END) + { + this.finished = true; // Do not decode anymore. + this.z.InflateEnd(); + break; + } + if (resultCode == JZlib.Z_OK) + { + continue; + } + if (resultCode == JZlib.Z_BUF_ERROR) + { + if (this.z.avail_in <= 0) + { + break; + } + + continue; + } + //default + ZlibUtil.Fail(this.z, "decompression failure", resultCode); + } + } + finally + { + input.SkipBytes(this.z.next_in_index - oldNextInIndex); + if (decompressed.IsReadable()) + { + output.Add(decompressed); + } + else + { + decompressed.Release(); + } + } + } + finally + { + this.z.next_in = null; + this.z.next_out = null; + } + + } + } +} diff --git a/src/DotNetty.Codecs/Compression/JZlibEncoder.cs b/src/DotNetty.Codecs/Compression/JZlibEncoder.cs new file mode 100644 index 0000000..c460017 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/JZlibEncoder.cs @@ -0,0 +1,249 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Compression +{ + using System; + using System.Diagnostics.Contracts; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class JZlibEncoder : ZlibEncoder + { + readonly int wrapperOverhead; + readonly Deflater z = new Deflater(); + + volatile bool finished; + volatile IChannelHandlerContext ctx; + + public JZlibEncoder() : this(6) + { + } + + public JZlibEncoder(int compressionLevel) : this(ZlibWrapper.Zlib, compressionLevel) + { + } + + public JZlibEncoder(ZlibWrapper wrapper) : this(wrapper, 6) + { + } + + public JZlibEncoder(ZlibWrapper wrapper, int compressionLevel) : this(wrapper, compressionLevel, 15, 8) + { + } + + /** + * Creates a new zlib encoder with the specified {@code compressionLevel}, + * the specified {@code windowBits}, the specified {@code memLevel}, and + * the specified wrapper. + * + * @param compressionLevel + * {@code 1} yields the fastest compression and {@code 9} yields the + * best compression. {@code 0} means no compression. The default + * compression level is {@code 6}. + * @param windowBits + * The base two logarithm of the size of the history buffer. The + * value should be in the range {@code 9} to {@code 15} inclusive. + * Larger values result in better compression at the expense of + * memory usage. The default value is {@code 15}. + * @param memLevel + * How much memory should be allocated for the internal compression + * state. {@code 1} uses minimum memory and {@code 9} uses maximum + * memory. Larger values result in better and faster compression + * at the expense of memory usage. The default value is {@code 8} + * + * @throws CompressionException if failed to initialize zlib + */ + public JZlibEncoder(ZlibWrapper wrapper, int compressionLevel, int windowBits, int memLevel) + { + Contract.Requires(compressionLevel >= 0 && compressionLevel <= 9); + Contract.Requires(windowBits >= 9 && windowBits <= 15); + Contract.Requires(memLevel >= 1 && memLevel <= 9); + + int resultCode = this.z.Init( + compressionLevel, windowBits, memLevel, + ZlibUtil.ConvertWrapperType(wrapper)); + if (resultCode != JZlib.Z_OK) + { + ZlibUtil.Fail(this.z, "initialization failure", resultCode); + } + + this.wrapperOverhead = ZlibUtil.WrapperOverhead(wrapper); + } + public JZlibEncoder(byte[] dictionary) : this(6, dictionary) + { + } + + public JZlibEncoder(int compressionLevel, byte[] dictionary) : this(compressionLevel, 15, 8, dictionary) + { + } + + public JZlibEncoder(int compressionLevel, int windowBits, int memLevel, byte[] dictionary) + { + Contract.Requires(compressionLevel >= 0 && compressionLevel <= 9); + Contract.Requires(windowBits >= 9 && windowBits <= 15); + Contract.Requires(memLevel >= 1 && memLevel <= 9); + + int resultCode = this.z.DeflateInit( + compressionLevel, windowBits, memLevel, + JZlib.W_ZLIB); // Default: ZLIB format + + if (resultCode != JZlib.Z_OK) + { + ZlibUtil.Fail(this.z, "initialization failure", resultCode); + } + else + { + resultCode = this.z.DeflateSetDictionary(dictionary, dictionary.Length); + if (resultCode != JZlib.Z_OK) + { + ZlibUtil.Fail(this.z, "failed to set the dictionary", resultCode); + } + } + + this.wrapperOverhead = ZlibUtil.WrapperOverhead(ZlibWrapper.Zlib); + } + + public override Task CloseAsync() => this.CloseAsync(this.CurrentContext()); + + public override Task CloseAsync(IChannelHandlerContext context) => this.FinishEncode(context); + + IChannelHandlerContext CurrentContext() + { + IChannelHandlerContext context = this.ctx; + if (context == null) + { + throw new InvalidOperationException("not added to a pipeline"); + } + + return context; + } + + public override bool IsClosed => this.finished; + + protected override void Encode(IChannelHandlerContext context, IByteBuffer message, IByteBuffer output) + { + if (this.finished) + { + output.WriteBytes(message); + return; + } + + int inputLength = message.ReadableBytes; + if (inputLength == 0) + { + return; + } + + try + { + // Configure input. + bool inHasArray = message.HasArray; + this.z.avail_in = inputLength; + if (inHasArray) + { + this.z.next_in = message.Array; + this.z.next_in_index = message.ArrayOffset + message.ReaderIndex; + } + else + { + var array = new byte[inputLength]; + message.GetBytes(message.ReaderIndex, array); + this.z.next_in = array; + this.z.next_in_index = 0; + } + int oldNextInIndex = this.z.next_in_index; + + // Configure output. + int maxOutputLength = (int)Math.Ceiling(inputLength * 1.001) + 12 + this.wrapperOverhead; + output.EnsureWritable(maxOutputLength); + this.z.avail_out = maxOutputLength; + this.z.next_out = output.Array; + this.z.next_out_index = output.ArrayOffset + output.WriterIndex; + int oldNextOutIndex = this.z.next_out_index; + + int resultCode; + try + { + resultCode = this.z.Deflate(JZlib.Z_SYNC_FLUSH); + } + finally + { + message.SkipBytes(this.z.next_in_index - oldNextInIndex); + } + + if (resultCode != JZlib.Z_OK) + { + ZlibUtil.Fail(this.z, "compression failure", resultCode); + } + + int outputLength = this.z.next_out_index - oldNextOutIndex; + if (outputLength > 0) + { + output.SetWriterIndex(output.WriterIndex + outputLength); + } + } + finally + { + this.z.next_in = null; + this.z.next_out = null; + } + } + + Task FinishEncode(IChannelHandlerContext context) + { + if (this.finished) + { + return TaskEx.Completed; + } + + this.finished = true; + + IByteBuffer footer; + try + { + // Configure input. + this.z.next_in = ArrayExtensions.ZeroBytes; + this.z.next_in_index = 0; + this.z.avail_in = 0; + + // Configure output. + var output = new byte[32]; // room for ADLER32 + ZLIB / CRC32 + GZIP header + this.z.next_out = output; + this.z.next_out_index = 0; + this.z.avail_out = output.Length; + + // Write the ADLER32 checksum(stream footer). + int resultCode = this.z.Deflate(JZlib.Z_FINISH); + if (resultCode != JZlib.Z_OK && resultCode != JZlib.Z_STREAM_END) + { + context.FireExceptionCaught( + new CompressionException($"Compression failure ({resultCode}) {this.z.msg}")); + return context.CloseAsync(); + } + else if (this.z.next_out_index != 0) + { + footer = Unpooled.WrappedBuffer(output, 0, this.z.next_out_index); + } + else + { + footer = Unpooled.Empty; + } + } + finally + { + this.z.DeflateEnd(); + + this.z.next_in = null; + this.z.next_out = null; + } + + return context.WriteAndFlushAsync(footer) + .ContinueWith(_ => context.CloseAsync()); + } + + public override void HandlerAdded(IChannelHandlerContext context) => this.ctx = context; + } +} diff --git a/src/DotNetty.Codecs/Compression/StaticTree.cs b/src/DotNetty.Codecs/Compression/StaticTree.cs new file mode 100644 index 0000000..5c0e8e3 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/StaticTree.cs @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/StaticTree.java + /// + sealed class StaticTree + { + const int MAX_BITS = 15; + + const int BL_CODES = 19; + const int D_CODES = 30; + const int LITERALS = 256; + const int LENGTH_CODES = 29; + const int L_CODES = LITERALS + 1 + LENGTH_CODES; + + // Bit length codes must not exceed MAX_BL_BITS bits + const int MAX_BL_BITS = 7; + + internal static readonly short[] static_ltree = + { + 12, 8, 140, 8, 76, 8, 204, 8, 44, 8, + 172, 8, 108, 8, 236, 8, 28, 8, 156, 8, + 92, 8, 220, 8, 60, 8, 188, 8, 124, 8, + 252, 8, 2, 8, 130, 8, 66, 8, 194, 8, + 34, 8, 162, 8, 98, 8, 226, 8, 18, 8, + 146, 8, 82, 8, 210, 8, 50, 8, 178, 8, + 114, 8, 242, 8, 10, 8, 138, 8, 74, 8, + 202, 8, 42, 8, 170, 8, 106, 8, 234, 8, + 26, 8, 154, 8, 90, 8, 218, 8, 58, 8, + 186, 8, 122, 8, 250, 8, 6, 8, 134, 8, + 70, 8, 198, 8, 38, 8, 166, 8, 102, 8, + 230, 8, 22, 8, 150, 8, 86, 8, 214, 8, + 54, 8, 182, 8, 118, 8, 246, 8, 14, 8, + 142, 8, 78, 8, 206, 8, 46, 8, 174, 8, + 110, 8, 238, 8, 30, 8, 158, 8, 94, 8, + 222, 8, 62, 8, 190, 8, 126, 8, 254, 8, + 1, 8, 129, 8, 65, 8, 193, 8, 33, 8, + 161, 8, 97, 8, 225, 8, 17, 8, 145, 8, + 81, 8, 209, 8, 49, 8, 177, 8, 113, 8, + 241, 8, 9, 8, 137, 8, 73, 8, 201, 8, + 41, 8, 169, 8, 105, 8, 233, 8, 25, 8, + 153, 8, 89, 8, 217, 8, 57, 8, 185, 8, + 121, 8, 249, 8, 5, 8, 133, 8, 69, 8, + 197, 8, 37, 8, 165, 8, 101, 8, 229, 8, + 21, 8, 149, 8, 85, 8, 213, 8, 53, 8, + 181, 8, 117, 8, 245, 8, 13, 8, 141, 8, + 77, 8, 205, 8, 45, 8, 173, 8, 109, 8, + 237, 8, 29, 8, 157, 8, 93, 8, 221, 8, + 61, 8, 189, 8, 125, 8, 253, 8, 19, 9, + 275, 9, 147, 9, 403, 9, 83, 9, 339, 9, + 211, 9, 467, 9, 51, 9, 307, 9, 179, 9, + 435, 9, 115, 9, 371, 9, 243, 9, 499, 9, + 11, 9, 267, 9, 139, 9, 395, 9, 75, 9, + 331, 9, 203, 9, 459, 9, 43, 9, 299, 9, + 171, 9, 427, 9, 107, 9, 363, 9, 235, 9, + 491, 9, 27, 9, 283, 9, 155, 9, 411, 9, + 91, 9, 347, 9, 219, 9, 475, 9, 59, 9, + 315, 9, 187, 9, 443, 9, 123, 9, 379, 9, + 251, 9, 507, 9, 7, 9, 263, 9, 135, 9, + 391, 9, 71, 9, 327, 9, 199, 9, 455, 9, + 39, 9, 295, 9, 167, 9, 423, 9, 103, 9, + 359, 9, 231, 9, 487, 9, 23, 9, 279, 9, + 151, 9, 407, 9, 87, 9, 343, 9, 215, 9, + 471, 9, 55, 9, 311, 9, 183, 9, 439, 9, + 119, 9, 375, 9, 247, 9, 503, 9, 15, 9, + 271, 9, 143, 9, 399, 9, 79, 9, 335, 9, + 207, 9, 463, 9, 47, 9, 303, 9, 175, 9, + 431, 9, 111, 9, 367, 9, 239, 9, 495, 9, + 31, 9, 287, 9, 159, 9, 415, 9, 95, 9, + 351, 9, 223, 9, 479, 9, 63, 9, 319, 9, + 191, 9, 447, 9, 127, 9, 383, 9, 255, 9, + 511, 9, 0, 7, 64, 7, 32, 7, 96, 7, + 16, 7, 80, 7, 48, 7, 112, 7, 8, 7, + 72, 7, 40, 7, 104, 7, 24, 7, 88, 7, + 56, 7, 120, 7, 4, 7, 68, 7, 36, 7, + 100, 7, 20, 7, 84, 7, 52, 7, 116, 7, + 3, 8, 131, 8, 67, 8, 195, 8, 35, 8, + 163, 8, 99, 8, 227, 8 + }; + + internal static readonly short[] static_dtree = + { + 0, 5, 16, 5, 8, 5, 24, 5, 4, 5, + 20, 5, 12, 5, 28, 5, 2, 5, 18, 5, + 10, 5, 26, 5, 6, 5, 22, 5, 14, 5, + 30, 5, 1, 5, 17, 5, 9, 5, 25, 5, + 5, 5, 21, 5, 13, 5, 29, 5, 3, 5, + 19, 5, 11, 5, 27, 5, 7, 5, 23, 5 + }; + + internal static StaticTree static_l_desc = + new StaticTree(static_ltree, Tree.extra_lbits, + LITERALS + 1, L_CODES, MAX_BITS); + + internal static StaticTree static_d_desc = + new StaticTree(static_dtree, Tree.extra_dbits, + 0, D_CODES, MAX_BITS); + + internal static StaticTree static_bl_desc = + new StaticTree(null, Tree.extra_blbits, + 0, BL_CODES, MAX_BL_BITS); + + internal short[] static_tree; // static tree or null + internal int[] extra_bits; // extra bits for each code or null + internal int extra_base; // base index for extra_bits + internal int elems; // max number of elements in the tree + internal int max_length; // max bit length for the codes + + StaticTree( + short[] static_tree, + int[] extra_bits, + int extra_base, + int elems, + int max_length) + { + this.static_tree = static_tree; + this.extra_bits = extra_bits; + this.extra_base = extra_base; + this.elems = elems; + this.max_length = max_length; + } + } +} diff --git a/src/DotNetty.Codecs/Compression/Tree.cs b/src/DotNetty.Codecs/Compression/Tree.cs new file mode 100644 index 0000000..07402e8 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/Tree.cs @@ -0,0 +1,414 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + using System; + using DotNetty.Common.Utilities; + + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/Tree.java + /// + sealed class Tree + { + const int MAX_BITS = 15; + //const int BL_CODES = 19; + //const int D_CODES = 30; + const int LITERALS = 256; + const int LENGTH_CODES = 29; + const int L_CODES = (LITERALS + 1 + LENGTH_CODES); + const int HEAP_SIZE = (2 * L_CODES + 1); + + // Bit length codes must not exceed MAX_BL_BITS bits + internal static readonly int MAX_BL_BITS = 7; + + // end of block literal code + internal static readonly int END_BLOCK = 256; + + // repeat previous bit length 3-6 times (2 bits of repeat count) + internal static readonly int REP_3_6 = 16; + + // repeat a zero length 3-10 times (3 bits of repeat count) + internal static readonly int REPZ_3_10 = 17; + + // repeat a zero length 11-138 times (7 bits of repeat count) + internal static readonly int REPZ_11_138 = 18; + + // extra bits for each length code + internal static readonly int[] extra_lbits = + { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 + }; + + // extra bits for each distance code + internal static readonly int[] extra_dbits = + { + 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13 + }; + + // extra bits for each bit length code + internal static readonly int[] extra_blbits = + { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 7 + }; + + internal static readonly byte[] bl_order = + { + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 + }; + + + // The lengths of the bit length codes are sent in order of decreasing + // probability, to avoid transmitting the lengths for unused bit + // length codes. + + internal static readonly int Buf_size = 8 * 2; + + // see definition of array dist_code below + internal static readonly int DIST_CODE_LEN = 512; + + static readonly byte[] _dist_code = + { + 0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, + 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, + 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, + 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, + 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, + 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 0, 0, 16, 17, + 18, 18, 19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, + 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, + 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, + 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, + 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, + 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, + 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, + 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29 + }; + + internal static readonly byte[] _length_code = + { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 12, 12, + 13, 13, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, + 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, + 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, + 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, + 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, + 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, + 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, + 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28 + }; + + internal static readonly int[] base_length = + { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, + 64, 80, 96, 112, 128, 160, 192, 224, 0 + }; + + internal static readonly int[] base_dist = + { + 0, 1, 2, 3, 4, 6, 8, 12, 16, 24, + 32, 48, 64, 96, 128, 192, 256, 384, 512, 768, + 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 24576 + }; + + // Mapping from a distance to a distance code. dist is the distance - 1 and + // must not have side effects. _dist_code[256] and _dist_code[257] are never + // used. + internal static int D_code(int dist) => + (dist < 256 ? _dist_code[dist] : _dist_code[256 + dist.RightUShift(7)]); + + internal short[] dyn_tree; // the dynamic tree + internal int max_code; // largest code with non zero frequency + internal StaticTree stat_desc; // the corresponding static tree + + // Compute the optimal bit lengths for a tree and update the total bit length + // for the current block. + // IN assertion: the fields freq and dad are set, heap[heap_max] and + // above are the tree nodes sorted by increasing frequency. + // OUT assertions: the field len is set to the optimal bit length, the + // array bl_count contains the frequencies for each bit length. + // The length opt_len is updated; static_len is also updated if stree is + // not null. + void Gen_bitlen(Deflate s) + { + short[] tree = dyn_tree; + short[] stree = stat_desc.static_tree; + int[] extra = stat_desc.extra_bits; + int _base = stat_desc.extra_base; + int max_length = stat_desc.max_length; + int h; // heap index + int n, m; // iterate over the tree elements + int bits; // bit length + int xbits; // extra bits + short f; // frequency + int overflow = 0; // number of elements with bit length too large + + for (bits = 0; bits <= MAX_BITS; bits++) + s.bl_count[bits] = 0; + + // In a first pass, compute the optimal bit lengths (which may + // overflow in the case of the bit length tree). + tree[s.heap[s.heap_max] * 2 + 1] = 0; // root of the heap + + for (h = s.heap_max + 1; h < HEAP_SIZE; h++) + { + n = s.heap[h]; + bits = tree[tree[n * 2 + 1] * 2 + 1] + 1; + if (bits > max_length) + { + bits = max_length; + overflow++; + } + tree[n * 2 + 1] = (short)bits; + // We overwrite tree[n*2+1] which is no longer needed + + if (n > max_code) + continue; // not a leaf node + + s.bl_count[bits]++; + xbits = 0; + if (n >= _base) + xbits = extra[n - _base]; + f = tree[n * 2]; + s.opt_len += f * (bits + xbits); + if (stree != null) + s.static_len += f * (stree[n * 2 + 1] + xbits); + } + if (overflow == 0) + return; + + // This happens for example on obj2 and pic of the Calgary corpus + // Find the first bit length which could increase: + do + { + bits = max_length - 1; + while (s.bl_count[bits] == 0) + bits--; + s.bl_count[bits]--; // move one leaf down the tree + s.bl_count[bits + 1] += 2; // move one overflow item as its brother + s.bl_count[max_length]--; + // The brother of the overflow item also moves one step up, + // but this does not affect bl_count[max_length] + overflow -= 2; + } + while (overflow > 0); + + for (bits = max_length; bits != 0; bits--) + { + n = s.bl_count[bits]; + while (n != 0) + { + m = s.heap[--h]; + if (m > max_code) + continue; + if (tree[m * 2 + 1] != bits) + { + s.opt_len += (int)(((long)bits - (long)tree[m * 2 + 1]) * (long)tree[m * 2]); + tree[m * 2 + 1] = (short)bits; + } + n--; + } + } + } + + // Construct one Huffman tree and assigns the code bit strings and lengths. + // Update the total bit length for the current block. + // IN assertion: the field freq is set for all tree elements. + // OUT assertions: the fields len and code are set to the optimal bit length + // and corresponding code. The length opt_len is updated; static_len is + // also updated if stree is not null. The field max_code is set. + internal void Build_tree(Deflate s) + { + short[] tree = dyn_tree; + short[] stree = stat_desc.static_tree; + int elems = stat_desc.elems; + int n, m; // iterate over heap elements + int max_code = -1; // largest code with non zero frequency + int node; // new node being created + + // Construct the initial heap, with least frequent element in + // heap[1]. The sons of heap[n] are heap[2*n] and heap[2*n+1]. + // heap[0] is not used. + s.heap_len = 0; + s.heap_max = HEAP_SIZE; + + for (n = 0; n < elems; n++) + { + if (tree[n * 2] != 0) + { + s.heap[++s.heap_len] = max_code = n; + s.depth[n] = 0; + } + else + { + tree[n * 2 + 1] = 0; + } + } + + // The pkzip format requires that at least one distance code exists, + // and that at least one bit should be sent even if there is only one + // possible code. So to avoid special checks later on we force at least + // two codes of non zero frequency. + while (s.heap_len < 2) + { + node = s.heap[++s.heap_len] = (max_code < 2 ? ++max_code : 0); + tree[node * 2] = 1; + s.depth[node] = 0; + s.opt_len--; + if (stree != null) + s.static_len -= stree[node * 2 + 1]; + // node is 0 or 1 so it does not have extra bits + } + this.max_code = max_code; + + // The elements heap[heap_len/2+1 .. heap_len] are leaves of the tree, + // establish sub-heaps of increasing lengths: + + for (n = s.heap_len / 2; n >= 1; n--) + s.Pqdownheap(tree, n); + + // Construct the Huffman tree by repeatedly combining the least two + // frequent nodes. + + node = elems; // next internal node of the tree + do + { + // n = node of least frequency + n = s.heap[1]; + s.heap[1] = s.heap[s.heap_len--]; + s.Pqdownheap(tree, 1); + m = s.heap[1]; // m = node of next least frequency + + s.heap[--s.heap_max] = n; // keep the nodes sorted by frequency + s.heap[--s.heap_max] = m; + + // Create a new node father of n and m + tree[node * 2] = (short)(tree[n * 2] + tree[m * 2]); + s.depth[node] = (byte)(Math.Max(s.depth[n], s.depth[m]) + 1); + tree[n * 2 + 1] = tree[m * 2 + 1] = (short)node; + + // and insert the new node in the heap + s.heap[1] = node++; + s.Pqdownheap(tree, 1); + } + while (s.heap_len >= 2); + + s.heap[--s.heap_max] = s.heap[1]; + + // At this point, the fields freq and dad are set. We can now + // generate the bit lengths. + + this.Gen_bitlen(s); + + // The field len is now set, we can generate the bit codes + Gen_codes(tree, max_code, s.bl_count, s.next_code); + } + + // Generate the codes for a given tree and bit counts (which need not be + // optimal). + // IN assertion: the array bl_count contains the bit length statistics for + // the given tree and the field len is set for all tree elements. + // OUT assertion: the field code is set for all tree elements of non + // zero code length. + static void Gen_codes( + short[] tree, // the tree to decorate + int max_code, // largest code with non zero frequency + short[] bl_count, // number of codes at each bit length + short[] next_code) + { + short code = 0; // running code value + int bits; // bit index + int n; // code index + + // The distribution counts are first used to generate the code values + // without bit reversal. + next_code[0] = 0; + for (bits = 1; bits <= MAX_BITS; bits++) + { + next_code[bits] = code = (short)((code + bl_count[bits - 1]) << 1); + } + + // Check that the bit counts in bl_count are consistent. The last code + // must be all ones. + //Assert (code + bl_count[MAX_BITS]-1 == (1< 0); + return res.RightUShift(1); + } + } +} diff --git a/src/DotNetty.Codecs/Compression/ZStream.cs b/src/DotNetty.Codecs/Compression/ZStream.cs new file mode 100644 index 0000000..fab2609 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/ZStream.cs @@ -0,0 +1,374 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +//JZlib 0.0.* were released under the GNU LGPL license.Later, we have switched +//over to a BSD-style license. + +//------------------------------------------------------------------------------ +//Copyright (c) 2000-2011 ymnk, JCraft, Inc.All rights reserved. + +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in +// the documentation and/or other materials provided with the distribution. + +// 3. The names of the authors may not be used to endorse or promote products +// derived from this software without specific prior written permission. + +//THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, +//INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +//FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.IN NO EVENT SHALL JCRAFT, +//INC.OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +//LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +//OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +//LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING +//NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +//EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ReSharper disable ArrangeThisQualifier +// ReSharper disable InconsistentNaming +namespace DotNetty.Codecs.Compression +{ + using System; + + /// + /// https://github.com/ymnk/jzlib/blob/master/src/main/java/com/jcraft/jzlib/ZStream.java + /// + + class ZStream + { + const int MAX_WBITS = 15; // 32K LZ77 window + const int DEF_WBITS = MAX_WBITS; + + //const int Z_NO_FLUSH = 0; + //const int Z_PARTIAL_FLUSH = 1; + //const int Z_SYNC_FLUSH = 2; + //const int Z_FULL_FLUSH = 3; + //const int Z_FINISH = 4; + + //const int MAX_MEM_LEVEL = 9; + + const int Z_OK = 0; + //const int Z_STREAM_END = 1; + //const int Z_NEED_DICT = 2; + //const int Z_ERRNO = -1; + const int Z_STREAM_ERROR = -2; + //const int Z_DATA_ERROR = -3; + //const int Z_MEM_ERROR = -4; + //const int Z_BUF_ERROR = -5; + //const int Z_VERSION_ERROR = -6; + + public byte[] next_in; // next input byte + public int next_in_index; + public int avail_in; // number of bytes available at next_in + public long total_in; // total nb of input bytes read so far + + public byte[] next_out; // next output byte should be put there + public int next_out_index; + public int avail_out; // remaining free space at next_out + public long total_out; // total nb of bytes output so far + + public string msg; + + internal Deflate dstate; + internal Inflate istate; + + internal int data_type; // best guess about the data type: ascii or binary + + internal IChecksum adler; + + protected internal ZStream() : this(new Adler32()) + { + } + + protected ZStream(IChecksum adler) + { + this.adler = adler; + } + + internal int InflateInit() => InflateInit(DEF_WBITS); + + internal int InflateInit(bool nowrap) => InflateInit(DEF_WBITS, nowrap); + + internal int InflateInit(int w) => InflateInit(w, false); + + internal int InflateInit(JZlib.WrapperType wrapperType) => InflateInit(DEF_WBITS, wrapperType); + + internal int InflateInit(int w, JZlib.WrapperType wrapperType) + { + bool nowrap = false; + if (wrapperType == JZlib.W_NONE) + { + nowrap = true; + } + else if (wrapperType == JZlib.W_GZIP) + { + w += 16; + } + else if (wrapperType == JZlib.W_ANY) + { + w |= Inflate.INFLATE_ANY; + } + else if (wrapperType == JZlib.W_ZLIB) + { + } + return InflateInit(w, nowrap); + } + + internal int InflateInit(int w, bool nowrap) + { + istate = new Inflate(this); + return istate.InflateInit(nowrap ? -w : w); + } + + internal int Inflate_z(int f) + { + if (istate == null) return Z_STREAM_ERROR; + return istate.Inflate_I(f); + } + + internal int InflateEnd() + { + if (istate == null) return Z_STREAM_ERROR; + int ret = istate.InflateEnd(); + // istate = null; + return ret; + } + + internal int InflateSync() + { + if (istate == null) + return Z_STREAM_ERROR; + return istate.InflateSync(); + } + + internal int InflateSyncPoint() + { + if (istate == null) return Z_STREAM_ERROR; + return istate.InflateSyncPoint(); + } + + internal int InflateSetDictionary(byte[] dictionary, int dictLength) + { + if (istate == null) return Z_STREAM_ERROR; + return istate.InflateSetDictionary(dictionary, dictLength); + } + + internal bool InflateFinished() => this.istate.mode == 12; + + internal int DeflateInit(int level) => DeflateInit(level, MAX_WBITS); + + internal int DeflateInit(int level, bool nowrap) => DeflateInit(level, MAX_WBITS, nowrap); + + internal int DeflateInit(int level, int bits) => DeflateInit(level, bits, false); + + internal int DeflateInit(int level, int bits, int memlevel, JZlib.WrapperType wrapperType) + { + if (bits < 9 || bits > 15) + { + return Z_STREAM_ERROR; + } + if (wrapperType == JZlib.W_NONE) + { + bits *= -1; + } + else if (wrapperType == JZlib.W_GZIP) + { + bits += 16; + } + else if (wrapperType == JZlib.W_ANY) + { + return Z_STREAM_ERROR; + } + else if (wrapperType == JZlib.W_ZLIB) + { + } + + return DeflateInit(level, bits, memlevel); + } + + internal int DeflateInit(int level, int bits, int memlevel) + { + dstate = new Deflate(this); + return dstate.DeflateInit(level, bits, memlevel); + } + + internal int DeflateInit(int level, int bits, bool nowrap) + { + dstate = new Deflate(this); + return dstate.DeflateInit(level, nowrap ? -bits : bits); + } + + internal int Deflate_z(int flush) + { + if (dstate == null) return Z_STREAM_ERROR; + return dstate.Deflate_D(flush); + } + + internal int DeflateEnd() + { + if (dstate == null) return Z_STREAM_ERROR; + int ret = dstate.DeflateEnd(); + dstate = null; + return ret; + } + + internal int DeflateParams(int level, int strategy) + { + if (dstate == null) return Z_STREAM_ERROR; + return dstate.DeflateParams(level, strategy); + } + + internal int DeflateSetDictionary(byte[] dictionary, int dictLength) + { + if (dstate == null) return Z_STREAM_ERROR; + return dstate.DeflateSetDictionary(dictionary, dictLength); + } + + // Flush as much pending output as possible. All deflate() output goes + // through this function so some applications may wish to modify it + // to avoid allocating a large strm->next_out buffer and copying into it. + // (See also read_buf()). + internal void Flush_pending() + { + int len = dstate.pending; + + if (len > avail_out) len = avail_out; + if (len == 0) return; + + if (dstate.pending_buf.Length <= dstate.pending_out || + next_out.Length <= next_out_index || + dstate.pending_buf.Length < (dstate.pending_out + len) || + next_out.Length < (next_out_index + len)) + { + //System.out.println(dstate.pending_buf.length+", "+dstate.pending_out+ + // ", "+next_out.length+", "+next_out_index+", "+len); + //System.out.println("avail_out="+avail_out); + } + + Array.Copy(dstate.pending_buf, dstate.pending_out, + next_out, next_out_index, len); + + next_out_index += len; + dstate.pending_out += len; + total_out += len; + avail_out -= len; + dstate.pending -= len; + if (dstate.pending == 0) + { + dstate.pending_out = 0; + } + } + + // Read a new buffer from the current input stream, update the adler32 + // and total number of bytes read. All deflate() input goes through + // this function so some applications may wish to modify it to avoid + // allocating a large strm->next_in buffer and copying from it. + // (See also flush_pending()). + internal int Read_buf(byte[] buf, int start, int size) + { + int len = avail_in; + + if (len > size) len = size; + if (len == 0) return 0; + + avail_in -= len; + + if (dstate.wrap != 0) + { + adler.Update(next_in, next_in_index, len); + } + Array.Copy(next_in, next_in_index, buf, start, len); + next_in_index += len; + total_in += len; + return len; + } + + internal long GetAdler() => adler.GetValue(); + + internal void Free() + { + next_in = null; + next_out = null; + msg = null; + } + + internal void SetOutput(byte[] buf) => SetOutput(buf, 0, buf.Length); + + internal void SetOutput(byte[] buf, int off, int len) + { + next_out = buf; + next_out_index = off; + avail_out = len; + } + + internal void SetInput(byte[] buf) => SetInput(buf, 0, buf.Length, false); + + internal void SetInput(byte[] buf, bool append) => SetInput(buf, 0, buf.Length, append); + + internal void SetInput(byte[] buf, int off, int len, bool append) + { + if (len <= 0 && append && next_in != null) return; + + if (avail_in > 0 && append) + { + var tmp = new byte[avail_in + len]; + Array.Copy(next_in, next_in_index, tmp, 0, avail_in); + Array.Copy(buf, off, tmp, avail_in, len); + next_in = tmp; + next_in_index = 0; + avail_in += len; + } + else + { + next_in = buf; + next_in_index = off; + avail_in = len; + } + } + + internal byte[] GetNextIn() => next_in; + + internal void SetNextIn(byte[] next_in_value) => next_in = next_in_value; + + internal int GetNextInIndex() => next_in_index; + + internal void SetNextInIndex(int next_in_index_value) => next_in_index = next_in_index_value; + + internal int GetAvailIn() => avail_in; + + internal void SetAvailIn(int avail_in_value) => avail_in = avail_in_value; + + internal byte[] GetNextOut() => next_out; + + internal void SetNextOut(byte[] next_out_value) => next_out = next_out_value; + + internal int GetNextOutIndex() => next_out_index; + + internal void SetNextOutIndex(int next_out_index_value) => next_out_index = next_out_index_value; + + internal int GetAvailOut() => avail_out; + + internal void SetAvailOut(int avail_out_value) => avail_out = avail_out_value; + + internal long GetTotalOut() => total_out; + + internal long GetTotalIn() => total_in; + + internal string GetMessage() => msg; + + /** + * Those methods are expected to be override by Inflater and Deflater. + * In the future, they will become abstract methods. + */ + public virtual int End() => Z_OK; + + public virtual bool Finished() => false; + } +} diff --git a/src/DotNetty.Codecs/Compression/ZlibCodecFactory.cs b/src/DotNetty.Codecs/Compression/ZlibCodecFactory.cs new file mode 100644 index 0000000..adb22ce --- /dev/null +++ b/src/DotNetty.Codecs/Compression/ZlibCodecFactory.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Compression +{ + public static class ZlibCodecFactory + { + public static ZlibEncoder NewZlibEncoder(int compressionLevel) => new JZlibEncoder(compressionLevel); + + public static ZlibEncoder NewZlibEncoder(ZlibWrapper wrapper) => new JZlibEncoder(wrapper); + + public static ZlibEncoder NewZlibEncoder(ZlibWrapper wrapper, int compressionLevel) => new JZlibEncoder(wrapper, compressionLevel); + + public static ZlibEncoder NewZlibEncoder(ZlibWrapper wrapper, int compressionLevel, int windowBits, int memLevel) => + new JZlibEncoder(wrapper, compressionLevel, windowBits, memLevel); + + public static ZlibEncoder NewZlibEncoder(byte[] dictionary) => new JZlibEncoder(dictionary); + + public static ZlibEncoder NewZlibEncoder(int compressionLevel, byte[] dictionary) => new JZlibEncoder(compressionLevel, dictionary); + + public static ZlibEncoder NewZlibEncoder(int compressionLevel, int windowBits, int memLevel, byte[] dictionary) => + new JZlibEncoder(compressionLevel, windowBits, memLevel, dictionary); + + public static ZlibDecoder NewZlibDecoder() => new JZlibDecoder(); + + public static ZlibDecoder NewZlibDecoder(ZlibWrapper wrapper) => new JZlibDecoder(wrapper); + + public static ZlibDecoder NewZlibDecoder(byte[] dictionary) => new JZlibDecoder(dictionary); + } +} diff --git a/src/DotNetty.Codecs/Compression/ZlibDecoder.cs b/src/DotNetty.Codecs/Compression/ZlibDecoder.cs new file mode 100644 index 0000000..64bf9d6 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/ZlibDecoder.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Compression +{ + public abstract class ZlibDecoder : ByteToMessageDecoder + { + public abstract bool IsClosed { get; } + } +} diff --git a/src/DotNetty.Codecs/Compression/ZlibEncoder.cs b/src/DotNetty.Codecs/Compression/ZlibEncoder.cs new file mode 100644 index 0000000..d16f237 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/ZlibEncoder.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Compression +{ + using System.Threading.Tasks; + using DotNetty.Buffers; + + public abstract class ZlibEncoder : MessageToByteEncoder + { + public abstract bool IsClosed { get; } + + /** + * Close this {@link ZlibEncoder} and so finish the encoding. + * + * The returned {@link ChannelFuture} will be notified once the + * operation completes. + */ + public abstract Task CloseAsync(); + } +} diff --git a/src/DotNetty.Codecs/Compression/ZlibUtil.cs b/src/DotNetty.Codecs/Compression/ZlibUtil.cs new file mode 100644 index 0000000..3b2d96f --- /dev/null +++ b/src/DotNetty.Codecs/Compression/ZlibUtil.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Compression +{ + using System; + + static class ZlibUtil + { + public static void Fail(Inflater z, string message, int resultCode) + { + throw new DecompressionException($"{message} ({resultCode})" + (z.msg != null ? " : " + z.msg : "")); + } + + public static void Fail(Deflater z, string message, int resultCode) + { + throw new CompressionException($"{message} ({resultCode})" + (z.msg != null ? " : " + z.msg : "")); + } + + public static JZlib.WrapperType ConvertWrapperType(ZlibWrapper wrapper) + { + JZlib.WrapperType convertedWrapperType; + switch (wrapper) + { + case ZlibWrapper.None: + convertedWrapperType = JZlib.W_NONE; + break; + case ZlibWrapper.Zlib: + convertedWrapperType = JZlib.W_ZLIB; + break; + case ZlibWrapper.Gzip: + convertedWrapperType = JZlib.W_GZIP; + break; + case ZlibWrapper.ZlibOrNone: + convertedWrapperType = JZlib.W_ANY; + break; + default: + throw new ArgumentException($"Unknown type {wrapper}"); + } + + return convertedWrapperType; + } + + public static int WrapperOverhead(ZlibWrapper wrapper) + { + int overhead; + switch (wrapper) + { + case ZlibWrapper.Zlib: + overhead = 2; + break; + case ZlibWrapper.Gzip: + overhead = 10; + break; + default: + throw new NotSupportedException($"Unknown value {wrapper}"); + } + + return overhead; + } + } +} diff --git a/src/DotNetty.Codecs/Compression/ZlibWrapper.cs b/src/DotNetty.Codecs/Compression/ZlibWrapper.cs new file mode 100644 index 0000000..0a61d06 --- /dev/null +++ b/src/DotNetty.Codecs/Compression/ZlibWrapper.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Compression +{ + /** + * The container file formats that wrap the stream compressed by the DEFLATE + * algorithm. + */ + public enum ZlibWrapper + { + /** + * The ZLIB wrapper as specified in RFC 1950. + */ + Zlib, + /** + * The GZIP wrapper as specified in RFC 1952. + */ + Gzip, + /** + * Raw DEFLATE stream only (no header and no footer). + */ + None, + /** + * Try {@link #ZLIB} first and then {@link #NONE} if the first attempt fails. + * Please note that you can specify this wrapper type only when decompressing. + */ + ZlibOrNone + } +} diff --git a/src/DotNetty.Codecs/DateFormatter.cs b/src/DotNetty.Codecs/DateFormatter.cs new file mode 100644 index 0000000..0890a34 --- /dev/null +++ b/src/DotNetty.Codecs/DateFormatter.cs @@ -0,0 +1,484 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + using System; + using System.Collections; + using System.Diagnostics.Contracts; + using System.Text; + using DotNetty.Common; + using DotNetty.Common.Utilities; + + public sealed class DateFormatter + { + static readonly BitArray Delimiters = GetDelimiters(); + + static BitArray GetDelimiters() + { + var bitArray = new BitArray(128, false); + bitArray[0x09] = true; + for (int c = 0x20; c <= 0x2F; c++) + { + bitArray[c] = true; + } + + for (int c = 0x3B; c <= 0x40; c++) + { + bitArray[c] = true; + } + + for (int c = 0x5B; c <= 0x60; c++) + { + bitArray[c] = true; + } + + for (int c = 0x7B; c <= 0x7E; c++) + { + bitArray[c] = true; + } + + return bitArray; + } + + // The order is the same as dateTime.DayOfWeek + static readonly string[] DayOfWeekToShortName = + { "Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat" }; + + static readonly string[] CalendarMonthToShortName = + { "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" }; + + static readonly ThreadLocalCache Cache = new ThreadLocalCache(); + + public static DateTime? ParseHttpDate(string txt) => ParseHttpDate(AsciiString.Cached(txt)); + + public static DateTime? ParseHttpDate(ICharSequence txt) => ParseHttpDate(txt, 0, txt.Count); + + public static DateTime? ParseHttpDate(string txt, int start, int end) => ParseHttpDate(AsciiString.Cached(txt), start, end); + + public static DateTime? ParseHttpDate(ICharSequence txt, int start, int end) + { + Contract.Requires(txt != null); + + int length = end - start; + if (length == 0) + { + return null; + } + else if (length < 0) + { + throw new ArgumentException("Can't have end < start"); + } + else if (length > 64) + { + throw new ArgumentException("Can't parse more than 64 chars, looks like a user error or a malformed header"); + } + + return Formatter().Parse0(txt, start, end); + } + + public static string Format(DateTime dateTime) => Formatter().Format0(dateTime); + + public static StringBuilder Append(DateTime dateTime, StringBuilder sb) => Append0(dateTime, sb); + + static DateFormatter Formatter() + { + DateFormatter formatter = Cache.Value; + formatter.Reset(); + return formatter; + } + + // delimiter = %x09 / %x20-2F / %x3B-40 / %x5B-60 / %x7B-7E + static bool IsDelim(char c) => Delimiters[c]; + + static bool IsDigit(char c) => c >= 48 && c <= 57; + + static int GetNumericalValue(char c) => c - 48; + + readonly StringBuilder sb = new StringBuilder(29); // Sun, 27 Nov 2016 19:37:15 GMT + bool timeFound; + int hours; + int minutes; + int seconds; + bool dayOfMonthFound; + int dayOfMonth; + bool monthFound; + int month; + bool yearFound; + int year; + + DateFormatter() + { + this.Reset(); + } + + public void Reset() + { + this.timeFound = false; + this.hours = -1; + this.minutes = -1; + this.seconds = -1; + this.dayOfMonthFound = false; + this.dayOfMonth = -1; + this.monthFound = false; + this.month = -1; + this.yearFound = false; + this.year = -1; + this.sb.Length = 0; + } + + bool TryParseTime(ICharSequence txt, int tokenStart, int tokenEnd) + { + int len = tokenEnd - tokenStart; + + // h:m:s to hh:mm:ss + if (len < 5 || len > 8) + { + return false; + } + + int localHours = -1; + int localMinutes = -1; + int localSeconds = -1; + int currentPartNumber = 0; + int currentPartValue = 0; + int numDigits = 0; + + for (int i = tokenStart; i < tokenEnd; i++) + { + char c = txt[i]; + if (IsDigit(c)) + { + currentPartValue = currentPartValue * 10 + GetNumericalValue(c); + if (++numDigits > 2) + { + return false; // too many digits in this part + } + } + else if (c == ':') + { + if (numDigits == 0) + { + // no digits between separators + return false; + } + switch (currentPartNumber) + { + case 0: + // flushing hours + localHours = currentPartValue; + break; + case 1: + // flushing minutes + localMinutes = currentPartValue; + break; + default: + // invalid, too many : + return false; + } + currentPartValue = 0; + currentPartNumber++; + numDigits = 0; + } + else + { + // invalid char + return false; + } + } + + if (numDigits > 0) + { + // pending seconds + localSeconds = currentPartValue; + } + + if (localHours >= 0 && localMinutes >= 0 && localSeconds >= 0) + { + this.hours = localHours; + this.minutes = localMinutes; + this.seconds = localSeconds; + return true; + } + + return false; + } + + bool TryParseDayOfMonth(ICharSequence txt, int tokenStart, int tokenEnd) + { + int len = tokenEnd - tokenStart; + + if (len == 1) + { + char c0 = txt[tokenStart]; + if (IsDigit(c0)) + { + this.dayOfMonth = GetNumericalValue(c0); + return true; + } + + } + else if (len == 2) + { + char c0 = txt[tokenStart]; + char c1 = txt[tokenStart + 1]; + if (IsDigit(c0) && IsDigit(c1)) + { + this.dayOfMonth = GetNumericalValue(c0) * 10 + GetNumericalValue(c1); + return true; + } + } + + return false; + } + + static bool MatchMonth(ICharSequence month, ICharSequence txt, int tokenStart) => + AsciiString.RegionMatchesAscii(month, true, 0, txt, tokenStart, 3); + + bool TryParseMonth(ICharSequence txt, int tokenStart, int tokenEnd) + { + int len = tokenEnd - tokenStart; + + if (len != 3) + { + return false; + } + + if (MatchMonth(Jan, txt, tokenStart)) + { + this.month = 1; + } + else if (MatchMonth(Feb, txt, tokenStart)) + { + this.month = 2; + } + else if (MatchMonth(Mar, txt, tokenStart)) + { + this.month = 3; + } + else if (MatchMonth(Apr, txt, tokenStart)) + { + this.month = 4; + } + else if (MatchMonth(May, txt, tokenStart)) + { + this.month = 5; + } + else if (MatchMonth(Jun, txt, tokenStart)) + { + this.month = 6; + } + else if (MatchMonth(Jul, txt, tokenStart)) + { + this.month = 7; + } + else if (MatchMonth(Aug, txt, tokenStart)) + { + this.month = 8; + } + else if (MatchMonth(Sep, txt, tokenStart)) + { + this.month = 9; + } + else if (MatchMonth(Oct, txt, tokenStart)) + { + this.month = 10; + } + else if (MatchMonth(Nov, txt, tokenStart)) + { + this.month = 11; + } + else if (MatchMonth(Dec, txt, tokenStart)) + { + this.month = 12; + } + else + { + return false; + } + + return true; + } + + static readonly AsciiString Jan = AsciiString.Cached("Jan"); + static readonly AsciiString Feb = AsciiString.Cached("Feb"); + static readonly AsciiString Mar = AsciiString.Cached("Mar"); + static readonly AsciiString Apr = AsciiString.Cached("Apr"); + static readonly AsciiString May = AsciiString.Cached("May"); + static readonly AsciiString Jun = AsciiString.Cached("Jun"); + static readonly AsciiString Jul = AsciiString.Cached("Jul"); + static readonly AsciiString Aug = AsciiString.Cached("Aug"); + static readonly AsciiString Sep = AsciiString.Cached("Sep"); + static readonly AsciiString Oct = AsciiString.Cached("Oct"); + static readonly AsciiString Nov = AsciiString.Cached("Nov"); + static readonly AsciiString Dec = AsciiString.Cached("Dec"); + + bool TryParseYear(ICharSequence txt, int tokenStart, int tokenEnd) + { + int len = tokenEnd - tokenStart; + + if (len == 2) + { + char c0 = txt[tokenStart]; + char c1 = txt[tokenStart + 1]; + if (IsDigit(c0) && IsDigit(c1)) + { + this.year = GetNumericalValue(c0) * 10 + GetNumericalValue(c1); + return true; + } + + } + else if (len == 4) + { + char c0 = txt[tokenStart]; + char c1 = txt[tokenStart + 1]; + char c2 = txt[tokenStart + 2]; + char c3 = txt[tokenStart + 3]; + if (IsDigit(c0) && IsDigit(c1) && IsDigit(c2) && IsDigit(c3)) + { + this.year = GetNumericalValue(c0) * 1000 + + GetNumericalValue(c1) * 100 + + GetNumericalValue(c2) * 10 + + GetNumericalValue(c3); + + return true; + } + } + + return false; + } + + bool ParseToken(ICharSequence txt, int tokenStart, int tokenEnd) + { + // return true if all parts are found + if (!this.timeFound) + { + this.timeFound = this.TryParseTime(txt, tokenStart, tokenEnd); + if (this.timeFound) + { + return this.dayOfMonthFound && this.monthFound && this.yearFound; + } + } + + if (!this.dayOfMonthFound) + { + this.dayOfMonthFound = this.TryParseDayOfMonth(txt, tokenStart, tokenEnd); + if (this.dayOfMonthFound) + { + return this.timeFound && this.monthFound && this.yearFound; + } + } + + if (!this.monthFound) + { + this.monthFound = this.TryParseMonth(txt, tokenStart, tokenEnd); + if (this.monthFound) + { + return this.timeFound && this.dayOfMonthFound && this.yearFound; + } + } + + if (!this.yearFound) + { + this.yearFound = this.TryParseYear(txt, tokenStart, tokenEnd); + } + + return this.timeFound && this.dayOfMonthFound && this.monthFound && this.yearFound; + } + + DateTime? Parse0(ICharSequence txt, int start, int end) + { + bool allPartsFound = this.Parse1(txt, start, end); + return allPartsFound && this.NormalizeAndValidate() ? this.ComputeDate() : default(DateTime?); + } + + bool Parse1(ICharSequence txt, int start, int end) + { + // return true if all parts are found + int tokenStart = -1; + + for (int i = start; i < end; i++) + { + char c = txt[i]; + + if (IsDelim(c)) + { + if (tokenStart != -1) + { + // terminate token + if (this.ParseToken(txt, tokenStart, i)) + { + return true; + } + tokenStart = -1; + } + } + else if (tokenStart == -1) + { + // start new token + tokenStart = i; + } + } + + // terminate trailing token + return tokenStart != -1 && this.ParseToken(txt, tokenStart, txt.Count); + } + + bool NormalizeAndValidate() + { + if (this.dayOfMonth < 1 + || this.dayOfMonth > 31 + || this.hours > 23 + || this.minutes > 59 + || this.seconds > 59) + { + return false; + } + + if (this.year >= 70 && this.year <= 99) + { + this.year += 1900; + } + else if (this.year >= 0 && this.year < 70) + { + this.year += 2000; + } + else if (this.year < 1601) + { + // invalid value + return false; + } + return true; + } + + DateTime ComputeDate() => new DateTime(this.year, this.month, this.dayOfMonth, this.hours, this.minutes, this.seconds, DateTimeKind.Utc); + + string Format0(DateTime dateTime) => Append0(dateTime, this.sb).ToString(); + + static StringBuilder Append0(DateTime dateTime, StringBuilder buffer) + { + buffer.Append(DayOfWeekToShortName[(int)dateTime.DayOfWeek]).Append(", "); + buffer.Append(dateTime.Day).Append(' '); + buffer.Append(CalendarMonthToShortName[dateTime.Month - 1]).Append(' '); + buffer.Append(dateTime.Year).Append(' '); + + AppendZeroLeftPadded(dateTime.Hour, buffer).Append(':'); + AppendZeroLeftPadded(dateTime.Minute, buffer).Append(':'); + return AppendZeroLeftPadded(dateTime.Second, buffer).Append(" GMT"); + } + + static StringBuilder AppendZeroLeftPadded(int value, StringBuilder sb) + { + if (value < 10) + { + sb.Append('0'); + } + return sb.Append(value); + } + + sealed class ThreadLocalCache : FastThreadLocal + { + protected override DateFormatter GetInitialValue() => new DateFormatter(); + } + } +} diff --git a/src/DotNetty.Codecs/DecoderException.cs b/src/DotNetty.Codecs/DecoderException.cs index 7217566..66c2450 100644 --- a/src/DotNetty.Codecs/DecoderException.cs +++ b/src/DotNetty.Codecs/DecoderException.cs @@ -5,7 +5,7 @@ namespace DotNetty.Codecs { using System; - public class DecoderException : Exception + public class DecoderException : CodecException { public DecoderException(string message) : base(message) @@ -16,5 +16,10 @@ namespace DotNetty.Codecs : base(null, cause) { } + + public DecoderException(string message, Exception cause) + : base(message, cause) + { + } } } \ No newline at end of file diff --git a/src/DotNetty.Codecs/DecoderResult.cs b/src/DotNetty.Codecs/DecoderResult.cs new file mode 100644 index 0000000..5da9208 --- /dev/null +++ b/src/DotNetty.Codecs/DecoderResult.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + using System; + using System.Diagnostics.Contracts; + using System.Text; + using DotNetty.Common.Utilities; + + public class DecoderResult + { + protected static readonly Signal SignalUnfinished = Signal.ValueOf(typeof(DecoderResult), "UNFINISHED"); + protected static readonly Signal SignalSuccess = Signal.ValueOf(typeof(DecoderResult), "SUCCESS"); + + public static readonly DecoderResult Unfinished = new DecoderResult(SignalUnfinished); + public static readonly DecoderResult Success = new DecoderResult(SignalSuccess); + + public static DecoderResult Failure(Exception cause) + { + Contract.Requires(cause != null); + return new DecoderResult(cause); + } + + readonly Exception cause; + + protected DecoderResult(Exception cause) + { + Contract.Requires(cause != null); + this.cause = cause; + } + + public bool IsFinished => !ReferenceEquals(this.cause, SignalUnfinished); + + public bool IsSuccess => ReferenceEquals(this.cause, SignalSuccess); + + public bool IsFailure => !ReferenceEquals(this.cause, SignalSuccess) + && !ReferenceEquals(this.cause, SignalUnfinished); + + public Exception Cause => this.IsFailure ? this.cause : null; + + public override string ToString() + { + if (!this.IsFinished) + { + return "unfinished"; + } + + if (this.IsSuccess) + { + return "success"; + } + + string error = this.cause.ToString(); + return new StringBuilder(error.Length + 17) + .Append("failure(") + .Append(error) + .Append(')') + .ToString(); + } + } +} diff --git a/src/DotNetty.Codecs/DefaultHeaders.cs b/src/DotNetty.Codecs/DefaultHeaders.cs new file mode 100644 index 0000000..baefbb6 --- /dev/null +++ b/src/DotNetty.Codecs/DefaultHeaders.cs @@ -0,0 +1,1138 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +// ReSharper disable ConvertToAutoPropertyWithPrivateSetter +// ReSharper disable ForCanBeConvertedToForeach +// ReSharper disable PossibleUnintendedReferenceComparison +// ReSharper disable EmptyGeneralCatchClause +namespace DotNetty.Codecs +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Collections.Immutable; + using System.Runtime.CompilerServices; + using DotNetty.Common.Utilities; + + using static Common.Internal.MathUtil; + using static HeadersUtils; + + public class DefaultHeaders : IHeaders + where TKey : class + { + const int HashCodeSeed = unchecked((int)0xc2b2ae35); + + static readonly DefaultHashingStrategy DefaultValueHashingStrategy = new DefaultHashingStrategy(); + static readonly DefaultHashingStrategy DefaultKeyHashingStragety = new DefaultHashingStrategy(); + static readonly NullNameValidator DefaultKeyNameValidator = new NullNameValidator(); + + readonly HeaderEntry[] entries; + readonly HeaderEntry head; + + readonly byte hashMask; + protected readonly IValueConverter ValueConverter; + readonly INameValidator nameValidator; + readonly IHashingStrategy hashingStrategy; + int size; + + public DefaultHeaders(IValueConverter valueConverter) + : this(DefaultKeyHashingStragety, valueConverter, DefaultKeyNameValidator, 16) + { + } + + public DefaultHeaders(IValueConverter valueConverter, INameValidator nameValidator) + : this(DefaultKeyHashingStragety, valueConverter, nameValidator, 16) + { + } + + public DefaultHeaders(IHashingStrategy nameHashingStrategy, IValueConverter valueConverter, INameValidator nameValidator) + : this(nameHashingStrategy, valueConverter, nameValidator, 16) + { + } + + public DefaultHeaders(IHashingStrategy nameHashingStrategy, + IValueConverter valueConverter, INameValidator nameValidator, int arraySizeHint) + { + if (ReferenceEquals(nameHashingStrategy, null)) ThrowArgumentNullException(nameof(nameHashingStrategy)); + if (ReferenceEquals(valueConverter, null)) ThrowArgumentNullException(nameof(valueConverter)); + if (ReferenceEquals(nameValidator, null)) ThrowArgumentNullException(nameof(nameValidator)); + + this.hashingStrategy = nameHashingStrategy; + this.ValueConverter = valueConverter; + this.nameValidator = nameValidator; + + // Enforce a bound of [2, 128] because hashMask is a byte. The max possible value of hashMask is one less + // than the length of this array, and we want the mask to be > 0. + this.entries = new HeaderEntry[FindNextPositivePowerOfTwo(Math.Max(2, Math.Min(arraySizeHint, 128)))]; + this.hashMask = (byte)(this.entries.Length - 1); + this.head = new HeaderEntry(); + } + + public bool TryGet(TKey name, out TValue value) + { + if (name == null) ThrowArgumentNullException(nameof(name)); + + bool found = false; + int h = this.hashingStrategy.HashCode(name); + int i = this.Index(h); + HeaderEntry e = this.entries[i]; + value = default(TValue); + // loop until the first header was found + while (e != null) + { + if (e.Hash == h && this.hashingStrategy.Equals(name, e.key)) + { + value = e.value; + found = true; + } + + e = e.Next; + } + return found; + } + + public TValue Get(TKey name, TValue defaultValue) => this.TryGet(name, out TValue value) ? value : defaultValue; + + public bool TryGetAndRemove(TKey name, out TValue value) + { + if (name == null) ThrowArgumentNullException(nameof(name)); + + int h = this.hashingStrategy.HashCode(name); + return this.TryRemove0(h, this.Index(h), name, out value); + } + + public TValue GetAndRemove(TKey name, TValue defaultValue) => this.TryGetAndRemove(name, out TValue value) ? value : defaultValue; + + public virtual IList GetAll(TKey name) + { + if (name == null) ThrowArgumentNullException(nameof(name)); + + var values = new List(); + int h = this.hashingStrategy.HashCode(name); + int i = this.Index(h); + HeaderEntry e = this.entries[i]; + while (e != null) + { + if (e.Hash == h && this.hashingStrategy.Equals(name, e.key)) + { + values.Insert(0, e.value); + } + + e = e.Next; + } + return values; + } + + public virtual IEnumerable ValueIterator(TKey name) => new ValueEnumerator(this, name); + + public IList GetAllAndRemove(TKey name) + { + IList all = this.GetAll(name); + this.Remove(name); + return all; + } + + public bool Contains(TKey name) => this.TryGet(name, out _); + + public bool ContainsObject(TKey name, object value) + { + if (value == null) ThrowArgumentNullException(nameof(value)); + + return this.Contains(name, this.ValueConverter.ConvertObject(value)); + } + + public bool ContainsBoolean(TKey name, bool value) => this.Contains(name, this.ValueConverter.ConvertBoolean(value)); + + public bool ContainsByte(TKey name, byte value) => this.Contains(name, this.ValueConverter.ConvertByte(value)); + + public bool ContainsChar(TKey name, char value) => this.Contains(name, this.ValueConverter.ConvertChar(value)); + + public bool ContainsShort(TKey name, short value) => this.Contains(name, this.ValueConverter.ConvertShort(value)); + + public bool ContainsInt(TKey name, int value) => this.Contains(name, this.ValueConverter.ConvertInt(value)); + + public bool ContainsLong(TKey name, long value) => this.Contains(name, this.ValueConverter.ConvertLong(value)); + + public bool ContainsFloat(TKey name, float value) => this.Contains(name, this.ValueConverter.ConvertFloat(value)); + + public bool ContainsDouble(TKey name, double value) => this.Contains(name, this.ValueConverter.ConvertDouble(value)); + + public bool ContainsTimeMillis(TKey name, long value) => this.Contains(name, this.ValueConverter.ConvertTimeMillis(value)); + + public bool Contains(TKey name, TValue value) => this.Contains(name, value, DefaultValueHashingStrategy); + + public bool Contains(TKey name, TValue value, IHashingStrategy valueHashingStrategy) + { + if (name == null) ThrowArgumentNullException(nameof(name)); + + int h = this.hashingStrategy.HashCode(name); + int i = this.Index(h); + HeaderEntry e = this.entries[i]; + while (e != null) + { + if (e.Hash == h && this.hashingStrategy.Equals(name, e.key) + && valueHashingStrategy.Equals(value, e.value)) + { + return true; + } + e = e.Next; + } + return false; + } + + public int Size => this.size; + + public bool IsEmpty => this.head == this.head.After; + + public ISet Names() + { + if (this.IsEmpty) + { + return ImmutableHashSet.Empty; + } + + var names = new HashSet(this.hashingStrategy); + HeaderEntry e = this.head.After; + while (e != this.head) + { + names.Add(e.key); + e = e.After; + } + return names; + } + + public virtual IHeaders Add(TKey name, TValue value) + { + if (ReferenceEquals(value, null)) ThrowArgumentNullException(nameof(value)); + + this.nameValidator.ValidateName(name); + int h = this.hashingStrategy.HashCode(name); + int i = this.Index(h); + this.Add0(h, i, name, value); + return this; + } + + public virtual IHeaders Add(TKey name, IEnumerable values) + { + this.nameValidator.ValidateName(name); + int h = this.hashingStrategy.HashCode(name); + int i = this.Index(h); + foreach (TValue v in values) + { + this.Add0(h, i, name, v); + } + return this; + } + + public virtual IHeaders AddObject(TKey name, object value) + { + if (value == null) ThrowArgumentNullException(nameof(value)); + + return this.Add(name, this.ValueConverter.ConvertObject(value)); + } + + public virtual IHeaders AddObject(TKey name, IEnumerable values) + { + foreach (object value in values) + { + this.AddObject(name, value); + } + return this; + } + + public virtual IHeaders AddObject(TKey name, params object[] values) + { + // ReSharper disable once ForCanBeConvertedToForeach + // Avoid enumerator allocations + for (int i = 0; i < values.Length; i++) + { + this.AddObject(name, values[i]); + } + + return this; + } + + public IHeaders AddInt(TKey name, int value) => this.Add(name, this.ValueConverter.ConvertInt(value)); + + public IHeaders AddLong(TKey name, long value) => this.Add(name, this.ValueConverter.ConvertLong(value)); + + public IHeaders AddDouble(TKey name, double value) => this.Add(name, this.ValueConverter.ConvertDouble(value)); + + public IHeaders AddTimeMillis(TKey name, long value) => this.Add(name, this.ValueConverter.ConvertTimeMillis(value)); + + public IHeaders AddChar(TKey name, char value) => this.Add(name, this.ValueConverter.ConvertChar(value)); + + public IHeaders AddBoolean(TKey name, bool value) => this.Add(name, this.ValueConverter.ConvertBoolean(value)); + + public IHeaders AddFloat(TKey name, float value) => this.Add(name, this.ValueConverter.ConvertFloat(value)); + + public IHeaders AddByte(TKey name, byte value) => this.Add(name, this.ValueConverter.ConvertByte(value)); + + public IHeaders AddShort(TKey name, short value) => this.Add(name, this.ValueConverter.ConvertShort(value)); + + public virtual IHeaders Add(IHeaders headers) + { + if (ReferenceEquals(headers, this)) + { + ThrowArgumentException("can't add to itself."); + } + this.AddImpl(headers); + return this; + } + + protected void AddImpl(IHeaders headers) + { + if (headers is DefaultHeaders defaultHeaders) + { + HeaderEntry e = defaultHeaders.head.After; + + if (defaultHeaders.hashingStrategy == this.hashingStrategy + && defaultHeaders.nameValidator == this.nameValidator) + { + // Fastest copy + while (e != defaultHeaders.head) + { + this.Add0(e.Hash, this.Index(e.Hash), e.key, e.value); + e = e.After; + } + } + else + { + // Fast copy + while (e != defaultHeaders.head) + { + this.Add(e.key, e.value); + e = e.After; + } + } + } + else + { + // Slow copy + foreach (HeaderEntry header in headers) + { + this.Add(header.key, header.value); + } + } + } + + public IHeaders Set(TKey name, TValue value) + { + if (ReferenceEquals(value, null)) ThrowArgumentNullException(nameof(value)); + + this.nameValidator.ValidateName(name); + int h = this.hashingStrategy.HashCode(name); + int i = this.Index(h); + this.TryRemove0(h, i, name, out _); + this.Add0(h, i, name, value); + return this; + } + + public virtual IHeaders Set(TKey name, IEnumerable values) + { + if (ReferenceEquals(values, null)) ThrowArgumentNullException(nameof(values)); + + this.nameValidator.ValidateName(name); + int h = this.hashingStrategy.HashCode(name); + int i = this.Index(h); + + this.TryRemove0(h, i, name, out _); + // ReSharper disable once PossibleNullReferenceException + foreach (TValue v in values) + { + if (v == null) + { + break; + } + this.Add0(h, i, name, v); + } + + return this; + } + + public virtual IHeaders SetObject(TKey name, object value) + { + if (value == null) ThrowArgumentNullException(nameof(value)); + + TValue convertedValue = this.ValueConverter.ConvertObject(value); + return this.Set(name, convertedValue); + } + + public virtual IHeaders SetObject(TKey name, IEnumerable values) + { + if (ReferenceEquals(values, null)) ThrowArgumentNullException(nameof(values)); + + this.nameValidator.ValidateName(name); + int h = this.hashingStrategy.HashCode(name); + int i = this.Index(h); + + this.TryRemove0(h, i, name, out _); + // ReSharper disable once PossibleNullReferenceException + foreach (object v in values) + { + if (v == null) + { + break; + } + this.Add0(h, i, name, this.ValueConverter.ConvertObject(v)); + } + + return this; + } + + public IHeaders SetInt(TKey name, int value) => this.Set(name, this.ValueConverter.ConvertInt(value)); + + public IHeaders SetLong(TKey name, long value) => this.Set(name, this.ValueConverter.ConvertLong(value)); + + public IHeaders SetDouble(TKey name, double value) => this.Set(name, this.ValueConverter.ConvertDouble(value)); + + public IHeaders SetTimeMillis(TKey name, long value) => this.Set(name, this.ValueConverter.ConvertTimeMillis(value)); + + public IHeaders SetFloat(TKey name, float value) => this.Set(name, this.ValueConverter.ConvertFloat(value)); + + public IHeaders SetChar(TKey name, char value) => this.Set(name, this.ValueConverter.ConvertChar(value)); + + public IHeaders SetBoolean(TKey name, bool value) => this.Set(name, this.ValueConverter.ConvertBoolean(value)); + + public IHeaders SetByte(TKey name, byte value) => this.Set(name, this.ValueConverter.ConvertByte(value)); + + public IHeaders SetShort(TKey name, short value) => this.Set(name, this.ValueConverter.ConvertShort(value)); + + public virtual IHeaders Set(IHeaders headers) + { + if (!ReferenceEquals(headers, this)) + { + this.Clear(); + this.AddImpl(headers); + } + return this; + } + + public virtual IHeaders SetAll(IHeaders headers) + { + if (!ReferenceEquals(headers, this)) + { + foreach (TKey key in headers.Names()) + { + this.Remove(key); + } + this.AddImpl(headers); + } + return this; + } + + public bool Remove(TKey name) => this.TryGetAndRemove(name, out _); + + public IHeaders Clear() + { + Array.Clear(this.entries, 0, this.entries.Length); + this.head.Before = this.head.After = this.head; + this.size = 0; + return this; + } + + public IEnumerator> GetEnumerator() => new HeaderEnumerator(this); + + IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator(); + + public bool TryGetBoolean(TKey name, out bool value) + { + if (this.TryGet(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToBoolean(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(bool); + return false; + } + + public bool GetBoolean(TKey name, bool defaultValue) => this.TryGetBoolean(name, out bool value) ? value : defaultValue; + + public bool TryGetByte(TKey name, out byte value) + { + if (this.TryGet(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToByte(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(byte); + return false; + } + + public byte GetByte(TKey name, byte defaultValue) => this.TryGetByte(name, out byte value) ? value : defaultValue; + + public bool TryGetChar(TKey name, out char value) + { + if (this.TryGet(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToChar(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(char); + return false; + } + + public char GetChar(TKey name, char defaultValue) => this.TryGetChar(name, out char value) ? value : defaultValue; + + public bool TryGetShort(TKey name, out short value) + { + if (this.TryGet(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToShort(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(short); + return false; + } + + public short GetShort(TKey name, short defaultValue) => this.TryGetShort(name, out short value) ? value : defaultValue; + + public bool TryGetInt(TKey name, out int value) + { + if (this.TryGet(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToInt(v); + return true; + } + catch(Exception) + { + // Ignore + } + } + + value = default(int); + return false; + } + + public int GetInt(TKey name, int defaultValue) => this.TryGetInt(name, out int value) ? value : defaultValue; + + public bool TryGetLong(TKey name, out long value) + { + if (this.TryGet(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToLong(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(long); + return false; + } + + public long GetLong(TKey name, long defaultValue) => this.TryGetLong(name, out long value) ? value : defaultValue; + + public bool TryGetFloat(TKey name, out float value) + { + if (this.TryGet(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToFloat(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(float); + return false; + } + + public float GetFloat(TKey name, float defaultValue) => this.TryGetFloat(name, out float value) ? value : defaultValue; + + public bool TryGetDouble(TKey name, out double value) + { + if (this.TryGet(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToDouble(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(double); + return false; + } + + public double GetDouble(TKey name, double defaultValue) => this.TryGetDouble(name, out double value) ? value : defaultValue; + + public bool TryGetTimeMillis(TKey name, out long value) + { + if (this.TryGet(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToTimeMillis(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(long); + return false; + } + + public long GetTimeMillis(TKey name, long defaultValue) => this.TryGetTimeMillis(name, out long value) ? value : defaultValue; + + public bool TryGetBooleanAndRemove(TKey name, out bool value) + { + if (this.TryGetAndRemove(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToBoolean(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(bool); + return false; + } + + public bool GetBooleanAndRemove(TKey name, bool defaultValue) => this.TryGetBooleanAndRemove(name, out bool value) ? value : defaultValue; + + public bool TryGetByteAndRemove(TKey name, out byte value) + { + if (this.TryGetAndRemove(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToByte(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + value = default(byte); + return false; + } + + public byte GetByteAndRemove(TKey name, byte defaultValue) => this.TryGetByteAndRemove(name, out byte value) ? value : defaultValue; + + public bool TryGetCharAndRemove(TKey name, out char value) + { + if (this.TryGetAndRemove(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToChar(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(char); + return false; + } + + public char GetCharAndRemove(TKey name, char defaultValue) => this.TryGetCharAndRemove(name, out char value) ? value : defaultValue; + + public bool TryGetShortAndRemove(TKey name, out short value) + { + if (this.TryGetAndRemove(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToShort(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(short); + return false; + } + + public short GetShortAndRemove(TKey name, short defaultValue) => this.TryGetShortAndRemove(name, out short value) ? value : defaultValue; + + public bool TryGetIntAndRemove(TKey name, out int value) + { + if (this.TryGetAndRemove(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToInt(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(int); + return false; + } + + public int GetIntAndRemove(TKey name, int defaultValue) => this.TryGetIntAndRemove(name, out int value) ? value : defaultValue; + + public bool TryGetLongAndRemove(TKey name, out long value) + { + if (this.TryGetAndRemove(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToLong(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(long); + return false; + } + + public long GetLongAndRemove(TKey name, long defaultValue) => this.TryGetLongAndRemove(name, out long value) ? value : defaultValue; + + public bool TryGetFloatAndRemove(TKey name, out float value) + { + if (this.TryGetAndRemove(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToFloat(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(float); + return false; + } + + public float GetFloatAndRemove(TKey name, float defaultValue) => this.TryGetFloatAndRemove(name, out float value) ? value : defaultValue; + + public bool TryGetDoubleAndRemove(TKey name, out double value) + { + if (this.TryGetAndRemove(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToDouble(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(double); + return false; + } + + public double GetDoubleAndRemove(TKey name, double defaultValue) => this.TryGetDoubleAndRemove(name, out double value) ? value : defaultValue; + + public bool TryGetTimeMillisAndRemove(TKey name, out long value) + { + if (this.TryGetAndRemove(name, out TValue v)) + { + try + { + value = this.ValueConverter.ConvertToTimeMillis(v); + return true; + } + catch (Exception) + { + // Ignore + } + } + + value = default(long); + return false; + } + + public long GetTimeMillisAndRemove(TKey name, long defaultValue) => this.TryGetTimeMillisAndRemove(name, out long value) ? value : defaultValue; + + public override bool Equals(object obj) => obj is IHeaders headers && this.Equals(headers, DefaultValueHashingStrategy); + + public override int GetHashCode() => this.HashCode(DefaultValueHashingStrategy); + + public bool Equals(IHeaders h2, IHashingStrategy valueHashingStrategy) + { + if (h2.Size != this.size) + { + return false; + } + + if (ReferenceEquals(this, h2)) + { + return true; + } + + foreach (TKey name in this.Names()) + { + IList otherValues = h2.GetAll(name); + IList values = this.GetAll(name); + if (otherValues.Count != values.Count) + { + return false; + } + for (int i = 0; i < otherValues.Count; i++) + { + if (!valueHashingStrategy.Equals(otherValues[i], values[i])) + { + return false; + } + } + } + return true; + } + + public int HashCode(IHashingStrategy valueHashingStrategy) + { + int result = HashCodeSeed; + foreach (TKey name in this.Names()) + { + result = 31 * result + this.hashingStrategy.HashCode(name); + IList values = this.GetAll(name); + for (int i = 0; i < values.Count; ++i) + { + result = 31 * result + valueHashingStrategy.HashCode(values[i]); + } + } + return result; + } + + public override string ToString() => HeadersUtils.ToString(this, this.size); + + protected HeaderEntry NewHeaderEntry(int h, TKey name, TValue value, HeaderEntry next) => + new HeaderEntry(h, name, value, next, this.head); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + int Index(int hash) => hash & this.hashMask; + + void Add0(int h, int i, TKey name, TValue value) + { + // Update the hash table. + this.entries[i] = this.NewHeaderEntry(h, name, value, this.entries[i]); + ++this.size; + } + + bool TryRemove0(int h, int i, TKey name, out TValue value) + { + value = default(TValue); + + HeaderEntry e = this.entries[i]; + if (e == null) + { + return false; + } + + bool result = false; + + HeaderEntry next = e.Next; + while (next != null) + { + if (next.Hash == h && this.hashingStrategy.Equals(name, next.key)) + { + value = next.value; + e.Next = next.Next; + next.Remove(); + --this.size; + result = true; + } + else + { + e = next; + } + + next = e.Next; + } + + e = this.entries[i]; + if (e.Hash == h && this.hashingStrategy.Equals(name, e.key)) + { + if (!result) + { + value = e.value; + result = true; + } + this.entries[i] = e.Next; + e.Remove(); + --this.size; + } + + return result; + } + + public DefaultHeaders Copy() + { + var copy = new DefaultHeaders(this.hashingStrategy, this.ValueConverter, this.nameValidator, this.entries.Length); + copy.AddImpl(this); + return copy; + } + + struct ValueEnumerator : IEnumerator, IEnumerable + { + readonly IHashingStrategy hashingStrategy; + readonly int hash; + readonly TKey name; + readonly HeaderEntry head; + HeaderEntry node; + TValue current; + + public ValueEnumerator(DefaultHeaders headers, TKey name) + { + if (name == null) ThrowArgumentNullException(nameof(name)); + + this.hashingStrategy = headers.hashingStrategy; + this.hash = this.hashingStrategy.HashCode(name); + this.name = name; + this.node = this.head = headers.entries[headers.Index(this.hash)]; + this.current = default(TValue); + } + + bool IEnumerator.MoveNext() + { + if (this.node == null) + { + return false; + } + + this.current = this.node.value; + this.CalculateNext(this.node.Next); + return true; + } + + void CalculateNext(HeaderEntry entry) + { + while (entry != null) + { + if (entry.Hash == this.hash && this.hashingStrategy.Equals(this.name, entry.key)) + { + this.node = entry; + return; + } + entry = entry.Next; + } + this.node = null; + } + + TValue IEnumerator.Current => this.current; + + object IEnumerator.Current => this.current; + + void IEnumerator.Reset() + { + this.node = this.head; + this.current = default(TValue); + } + + void IDisposable.Dispose() + { + this.node = null; + this.current = default(TValue); + } + + public IEnumerator GetEnumerator() => this; + + IEnumerator IEnumerable.GetEnumerator() => this; + } + + struct HeaderEnumerator : IEnumerator> + { + readonly HeaderEntry head; + readonly int size; + + HeaderEntry node; + int index; + + public HeaderEnumerator(DefaultHeaders headers) + { + this.head = headers.head; + this.size = headers.size; + this.node = this.head; + this.index = 0; + } + + public HeaderEntry Current => this.node; + + object IEnumerator.Current + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + if (this.index == 0 || this.index == this.size + 1) + { + ThrowInvalidOperationException("Enumerator not initialized or completed."); + } + return this.node; + } + } + + public bool MoveNext() + { + if (this.node == null) + { + this.index = this.size + 1; + return false; + } + + this.index++; + this.node = this.node.After; + if (this.node == this.head) + { + this.node = null; + return false; + } + return true; + } + + public void Reset() + { + this.node = this.head.After; + this.index = 0; + } + + public void Dispose() + { + this.node = null; + this.index = 0; + } + } + } + + public sealed class HeaderEntry + where TKey : class + { + internal readonly int Hash; + // ReSharper disable InconsistentNaming + internal readonly TKey key; + internal TValue value; + // ReSharper restore InconsistentNaming + + internal HeaderEntry Next; + internal HeaderEntry Before; + internal HeaderEntry After; + + public HeaderEntry(int hash, TKey key) + { + this.Hash = hash; + this.key = key; + } + + internal HeaderEntry() + { + this.Hash = -1; + this.key = default(TKey); + this.Before = this; + this.After = this; + } + + internal HeaderEntry(int hash, TKey key, TValue value, + HeaderEntry next, HeaderEntry head) + { + this.Hash = hash; + this.key = key; + this.value = value; + this.Next = next; + + this.After = head; + this.Before = head.Before; + // PointNeighborsToThis + this.Before.After = this; + this.After.Before = this; + } + + internal void Remove() + { + this.Before.After = this.After; + this.After.Before = this.Before; + } + + public TKey Key => this.key; + + public TValue Value => this.value; + + public TValue SetValue(TValue newValue) + { + if (ReferenceEquals(newValue, null)) ThrowArgumentNullException(nameof(newValue)); + + TValue oldValue = this.value; + this.value = newValue; + return oldValue; + } + + public override string ToString() => $"{this.key}={this.value}"; + + // ReSharper disable once MergeConditionalExpression + public override bool Equals(object obj) => obj is HeaderEntry other + && (this.key == null ? other.key == null : this.key.Equals(other.key)) + && (ReferenceEquals(this.value, null) ? ReferenceEquals(other.value, null) : this.value.Equals(other.value)); + + // ReSharper disable NonReadonlyMemberInGetHashCode + public override int GetHashCode() => (this.key == null ? 0 : this.key.GetHashCode()) + ^ (ReferenceEquals(this.value, null) ? 0 : this.value.GetHashCode()); + // ReSharper restore NonReadonlyMemberInGetHashCode + } +} diff --git a/src/DotNetty.Codecs/DotNetty.Codecs.csproj b/src/DotNetty.Codecs/DotNetty.Codecs.csproj index 1a81cc4..837c58a 100644 --- a/src/DotNetty.Codecs/DotNetty.Codecs.csproj +++ b/src/DotNetty.Codecs/DotNetty.Codecs.csproj @@ -28,6 +28,9 @@ + + + diff --git a/src/DotNetty.Codecs/EncoderException.cs b/src/DotNetty.Codecs/EncoderException.cs index 5966bbd..a146d21 100644 --- a/src/DotNetty.Codecs/EncoderException.cs +++ b/src/DotNetty.Codecs/EncoderException.cs @@ -5,7 +5,7 @@ namespace DotNetty.Codecs { using System; - public class EncoderException : Exception + public class EncoderException : CodecException { public EncoderException(string message) : base(message) @@ -16,5 +16,10 @@ namespace DotNetty.Codecs : base(null, innerException) { } + + public EncoderException(string message, Exception innerException) + : base(message, innerException) + { + } } } \ No newline at end of file diff --git a/src/DotNetty.Codecs/HeadersUtils.cs b/src/DotNetty.Codecs/HeadersUtils.cs new file mode 100644 index 0000000..e69d55d --- /dev/null +++ b/src/DotNetty.Codecs/HeadersUtils.cs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + using System; + using System.Collections.Generic; + using System.Text; + using DotNetty.Common.Utilities; + + public static class HeadersUtils + { + public static List GetAllAsString(IHeaders headers, TKey name) + where TKey : class + { + IList allNames = headers.GetAll(name); + var values = new List(); + + // ReSharper disable once ForCanBeConvertedToForeach + // Avoid enumerator allocation + for (int i = 0; i < allNames.Count; i++) + { + TValue value = allNames[i]; + values.Add(value?.ToString()); + } + + return values; + } + + public static bool TryGetAsString(IHeaders headers, TKey name, out string value) + where TKey : class + { + if (headers.TryGet(name, out TValue orig)) + { + value = orig.ToString(); + return true; + } + else + { + value = default(string); + return false; + } + } + + public static string ToString(IEnumerable> headers, int size) + where TKey : class + { + string simpleName = StringUtil.SimpleClassName(headers); + if (size == 0) + { + return simpleName + "[]"; + } + else + { + // original capacity assumes 20 chars per headers + StringBuilder sb = new StringBuilder(simpleName.Length + 2 + size * 20) + .Append(simpleName) + .Append('['); + foreach (HeaderEntry header in headers) + { + sb.Append(header.Key).Append(": ").Append(header.Value).Append(", "); + } + sb.Length = sb.Length - 2; + return sb.Append(']').ToString(); + } + } + + public static IList NamesAsString(IHeaders headers) + { + ISet allNames = headers.Names(); + + var names = new List(); + + foreach (ICharSequence name in allNames) + { + names.Add(name.ToString()); + } + + return names; + } + + internal static void ThrowArgumentNullException(string name) => throw new ArgumentNullException(name); + + internal static void ThrowArgumentException(string message) => throw new ArgumentException(message); + + internal static void ThrowInvalidOperationException(string message) => throw new InvalidOperationException(message); + } +} diff --git a/src/DotNetty.Codecs/IDecoderResultProvider.cs b/src/DotNetty.Codecs/IDecoderResultProvider.cs new file mode 100644 index 0000000..5aa0765 --- /dev/null +++ b/src/DotNetty.Codecs/IDecoderResultProvider.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + public interface IDecoderResultProvider + { + DecoderResult Result { get; set; } + } +} diff --git a/src/DotNetty.Codecs/IHeaders.cs b/src/DotNetty.Codecs/IHeaders.cs new file mode 100644 index 0000000..c31f640 --- /dev/null +++ b/src/DotNetty.Codecs/IHeaders.cs @@ -0,0 +1,187 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + using System.Collections.Generic; + + public interface IHeaders : IEnumerable> + where TKey : class + { + bool TryGet(TKey name, out TValue value); + + TValue Get(TKey name, TValue defaultValue); + + bool TryGetAndRemove(TKey name, out TValue value); + + TValue GetAndRemove(TKey name, TValue defaultValue); + + IList GetAll(TKey name); + + IList GetAllAndRemove(TKey name); + + bool TryGetBoolean(TKey name, out bool value); + + bool GetBoolean(TKey name, bool defaultValue); + + bool TryGetByte(TKey name, out byte value); + + byte GetByte(TKey name, byte defaultValue); + + bool TryGetChar(TKey name, out char value); + + char GetChar(TKey name, char defaultValue); + + bool TryGetShort(TKey name, out short value); + + short GetShort(TKey name, short defaultValue); + + bool TryGetInt(TKey name, out int value); + + int GetInt(TKey name, int defaultValue); + + bool TryGetLong(TKey name, out long value); + + long GetLong(TKey name, long defaultValue); + + bool TryGetFloat(TKey name, out float value); + + float GetFloat(TKey name, float defaultValue); + + bool TryGetDouble(TKey name, out double value); + + double GetDouble(TKey name, double defaultValue); + + bool TryGetTimeMillis(TKey name, out long value); + + long GetTimeMillis(TKey name, long defaultValue); + + bool TryGetBooleanAndRemove(TKey name, out bool value); + + bool GetBooleanAndRemove(TKey name, bool defaultValue); + + bool TryGetByteAndRemove(TKey name, out byte value); + + byte GetByteAndRemove(TKey name, byte defaultValue); + + bool TryGetCharAndRemove(TKey name, out char value); + + char GetCharAndRemove(TKey name, char defaultValue); + + bool TryGetShortAndRemove(TKey name, out short value); + + short GetShortAndRemove(TKey name, short defaultValue); + + bool TryGetIntAndRemove(TKey name, out int value); + + int GetIntAndRemove(TKey name, int defaultValue); + + bool TryGetLongAndRemove(TKey name, out long value); + + long GetLongAndRemove(TKey name, long defaultValue); + + bool TryGetFloatAndRemove(TKey name, out float value); + + float GetFloatAndRemove(TKey name, float defaultValue); + + bool TryGetDoubleAndRemove(TKey name, out double value); + + double GetDoubleAndRemove(TKey name, double defaultValue); + + bool TryGetTimeMillisAndRemove(TKey name, out long value); + + long GetTimeMillisAndRemove(TKey name, long defaultValue); + + bool Contains(TKey name); + + bool Contains(TKey name, TValue value); + + bool ContainsObject(TKey name, object value); + + bool ContainsBoolean(TKey name, bool value); + + bool ContainsByte(TKey name, byte value); + + bool ContainsChar(TKey name, char value); + + bool ContainsShort(TKey name, short value); + + bool ContainsInt(TKey name, int value); + + bool ContainsLong(TKey name, long value); + + bool ContainsFloat(TKey name, float value); + + bool ContainsDouble(TKey name, double value); + + bool ContainsTimeMillis(TKey name, long value); + + int Size { get; } + + bool IsEmpty { get; } + + ISet Names(); + + IHeaders Add(TKey name, TValue value); + + IHeaders Add(TKey name, IEnumerable values); + + IHeaders AddObject(TKey name, object value); + + IHeaders AddObject(TKey name, IEnumerable values); + + IHeaders AddBoolean(TKey name, bool value); + + IHeaders AddByte(TKey name, byte value); + + IHeaders AddChar(TKey name, char value); + + IHeaders AddShort(TKey name, short value); + + IHeaders AddInt(TKey name, int value); + + IHeaders AddLong(TKey name, long value); + + IHeaders AddFloat(TKey name, float value); + + IHeaders AddDouble(TKey name, double value); + + IHeaders AddTimeMillis(TKey name, long value); + + IHeaders Add(IHeaders headers); + + IHeaders Set(TKey name, TValue value); + + IHeaders Set(TKey name, IEnumerable values); + + IHeaders SetObject(TKey name, object value); + + IHeaders SetObject(TKey name, IEnumerable values); + + IHeaders SetBoolean(TKey name, bool value); + + IHeaders SetByte(TKey name, byte value); + + IHeaders SetChar(TKey name, char value); + + IHeaders SetShort(TKey name, short value); + + IHeaders SetInt(TKey name, int value); + + IHeaders SetLong(TKey name, long value); + + IHeaders SetFloat(TKey name, float value); + + IHeaders SetDouble(TKey name, double value); + + IHeaders SetTimeMillis(TKey name, long value); + + IHeaders Set(IHeaders headers); + + IHeaders SetAll(IHeaders headers); + + bool Remove(TKey name); + + IHeaders Clear(); + } +} diff --git a/src/DotNetty.Codecs/INameValidator.cs b/src/DotNetty.Codecs/INameValidator.cs new file mode 100644 index 0000000..01c9d6b --- /dev/null +++ b/src/DotNetty.Codecs/INameValidator.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + public interface INameValidator + { + void ValidateName(T name); + } +} diff --git a/src/DotNetty.Codecs/IValueConverter.cs b/src/DotNetty.Codecs/IValueConverter.cs new file mode 100644 index 0000000..803f301 --- /dev/null +++ b/src/DotNetty.Codecs/IValueConverter.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + public interface IValueConverter + { + T ConvertObject(object value); + + T ConvertBoolean(bool value); + + bool ConvertToBoolean(T value); + + T ConvertByte(byte value); + + byte ConvertToByte(T value); + + T ConvertChar(char value); + + char ConvertToChar(T value); + + T ConvertShort(short value); + + short ConvertToShort(T value); + + T ConvertInt(int value); + + int ConvertToInt(T value); + + T ConvertLong(long value); + + long ConvertToLong(T value); + + T ConvertTimeMillis(long value); + + long ConvertToTimeMillis(T value); + + T ConvertFloat(float value); + + float ConvertToFloat(T value); + + T ConvertDouble(double value); + + double ConvertToDouble(T value); + } +} diff --git a/src/DotNetty.Codecs/MessageAggregationException.cs b/src/DotNetty.Codecs/MessageAggregationException.cs new file mode 100644 index 0000000..7d22eaa --- /dev/null +++ b/src/DotNetty.Codecs/MessageAggregationException.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + using System; + + public class MessageAggregationException : InvalidOperationException + { + public MessageAggregationException(string message) + : base(message) + { + } + + public MessageAggregationException(string message, Exception cause) + : base(message, cause) + { + } + } +} diff --git a/src/DotNetty.Codecs/MessageAggregator.cs b/src/DotNetty.Codecs/MessageAggregator.cs new file mode 100644 index 0000000..79f5e9c --- /dev/null +++ b/src/DotNetty.Codecs/MessageAggregator.cs @@ -0,0 +1,367 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + /// + /// + /// An abstract that aggregates a series of message objects + /// into a single aggregated message. + /// 'A series of messages' is composed of the following: + /// a single start message which optionally contains the first part of the content, and + /// 1 or more content messages. The content of the aggregated message will be the merged + /// content of the start message and its following content messages. If this aggregator + /// encounters a content message where { @link #isLastContentMessage(ByteBufHolder)} + /// return true for, the aggregator will finish the aggregation and produce the aggregated + /// message and expect another start message. + /// + /// The type that covers both start message and content message + /// The type of the start message + /// The type of the content message + /// The type of the aggregated message + public abstract class MessageAggregator : MessageToMessageDecoder + where TContent : IByteBufferHolder + where TOutput : IByteBufferHolder + { + const int DefaultMaxCompositebufferComponents = 1024; + + int maxCumulationBufferComponents = DefaultMaxCompositebufferComponents; + + TOutput currentMessage; + bool handlingOversizedMessage; + + IChannelHandlerContext handlerContext; + + protected MessageAggregator(int maxContentLength) + { + ValidateMaxContentLength(maxContentLength); + this.MaxContentLength = maxContentLength; + } + + static void ValidateMaxContentLength(int maxContentLength) + { + if (maxContentLength < 0) + { + throw new ArgumentException($"maxContentLength: {maxContentLength}(expected: >= 0)", nameof(maxContentLength)); + } + } + + public override bool AcceptInboundMessage(object msg) + { + // No need to match last and full types because they are subset of first and middle types. + if (!base.AcceptInboundMessage(msg)) + { + return false; + } + + var message = (TMessage)msg; + return (this.IsContentMessage(message) || this.IsStartMessage(message)) + && !this.IsAggregated(message); + } + + protected abstract bool IsStartMessage(TMessage msg); + + protected abstract bool IsContentMessage(TMessage msg); + + protected abstract bool IsLastContentMessage(TContent msg); + + protected abstract bool IsAggregated(TMessage msg); + + public int MaxContentLength { get; } + + public int MaxCumulationBufferComponents + { + get => this.maxCumulationBufferComponents; + set + { + if (value < 2) + { + throw new ArgumentException($"maxCumulationBufferComponents: {value} (expected: >= 2)"); + } + if (this.handlerContext != null) + { + throw new InvalidOperationException("decoder properties cannot be changed once the decoder is added to a pipeline."); + } + + this.maxCumulationBufferComponents = value; + } + } + + protected IChannelHandlerContext HandlerContext() + { + if (this.handlerContext == null) + { + throw new InvalidOperationException("not added to a pipeline yet"); + } + + return this.handlerContext; + } + + protected internal override void Decode(IChannelHandlerContext context, TMessage message, List output) + { + if (this.IsStartMessage(message)) + { + this.handlingOversizedMessage = false; + if (this.currentMessage != null) + { + this.currentMessage.Release(); + this.currentMessage = default(TOutput); + + throw new MessageAggregationException("Start message should not have any current content."); + } + + var m = As(message); + Contract.Assert(m != null); + + // Send the continue response if necessary(e.g. 'Expect: 100-continue' header) + // Check before content length. Failing an expectation may result in a different response being sent. + object continueResponse = this.NewContinueResponse(m, this.MaxContentLength, context.Channel.Pipeline); + if (continueResponse != null) + { + // Make sure to call this before writing, otherwise reference counts may be invalid. + bool closeAfterWrite = this.CloseAfterContinueResponse(continueResponse); + this.handlingOversizedMessage = this.IgnoreContentAfterContinueResponse(continueResponse); + + Task task = context + .WriteAndFlushAsync(continueResponse) + .ContinueWith(ContinueResponseWriteAction, context, TaskContinuationOptions.ExecuteSynchronously); + + if (closeAfterWrite) + { + task.ContinueWith(CloseAfterWriteAction, context, TaskContinuationOptions.ExecuteSynchronously); + return; + } + + if (this.handlingOversizedMessage) + { + return; + } + } + else if (this.IsContentLengthInvalid(m, this.MaxContentLength)) + { + // if content length is set, preemptively close if it's too large + this.InvokeHandleOversizedMessage(context, m); + return; + } + + if (m is IDecoderResultProvider provider && !provider.Result.IsSuccess) + { + TOutput aggregated; + if (m is IByteBufferHolder holder) + { + aggregated = this.BeginAggregation(m, (IByteBuffer)holder.Content.Retain()); + } + else + { + aggregated = this.BeginAggregation(m, Unpooled.Empty); + } + this.FinishAggregation(aggregated); + output.Add(aggregated); + return; + } + + // A streamed message - initialize the cumulative buffer, and wait for incoming chunks. + CompositeByteBuffer content = context.Allocator.CompositeBuffer(this.maxCumulationBufferComponents); + if (m is IByteBufferHolder bufferHolder) + { + AppendPartialContent(content, bufferHolder.Content); + } + this.currentMessage = this.BeginAggregation(m, content); + } + else if (this.IsContentMessage(message)) + { + if (this.currentMessage == null) + { + // it is possible that a TooLongFrameException was already thrown but we can still discard data + // until the begging of the next request/response. + return; + } + + // Merge the received chunk into the content of the current message. + var content = (CompositeByteBuffer)this.currentMessage.Content; + + var m = As(message); + + // Handle oversized message. + if (content.ReadableBytes > this.MaxContentLength - m.Content.ReadableBytes) + { + // By convention, full message type extends first message type. + //@SuppressWarnings("unchecked") + var s = As(this.currentMessage); + Contract.Assert(s != null); + + this.InvokeHandleOversizedMessage(context, s); + return; + } + + // Append the content of the chunk. + AppendPartialContent(content, m.Content); + + // Give the subtypes a chance to merge additional information such as trailing headers. + this.Aggregate(this.currentMessage, m); + + bool last; + if (m is IDecoderResultProvider provider) + { + DecoderResult decoderResult = provider.Result; + if (!decoderResult.IsSuccess) + { + if (this.currentMessage is IDecoderResultProvider resultProvider) + { + resultProvider.Result = DecoderResult.Failure(decoderResult.Cause); + } + + last = true; + } + else + { + last = this.IsLastContentMessage(m); + } + } + else + { + last = this.IsLastContentMessage(m); + } + + if (last) + { + this.FinishAggregation(this.currentMessage); + + // All done + output.Add(this.currentMessage); + this.currentMessage = default(TOutput); + } + } + else + { + throw new MessageAggregationException("Unknown aggregation state."); + } + } + + static void CloseAfterWriteAction(Task task, object state) + { + var ctx = (IChannelHandlerContext)state; + ctx.Channel.CloseAsync(); + } + + static void ContinueResponseWriteAction(Task task, object state) + { + if (task.IsFaulted) + { + var ctx = (IChannelHandlerContext)state; + ctx.FireExceptionCaught(task.Exception); + } + } + + static T As(object obj) => (T)obj; + + static void AppendPartialContent(CompositeByteBuffer content, IByteBuffer partialContent) + { + if (!partialContent.IsReadable()) + { + return; + } + + content.AddComponent((IByteBuffer)partialContent.Retain()); + content.SetWriterIndex(content.WriterIndex + partialContent.ReadableBytes); + } + + protected abstract bool IsContentLengthInvalid(TStart start, int maxContentLength); + + protected abstract object NewContinueResponse(TStart start, int maxContentLength, IChannelPipeline pipeline); + + protected abstract bool CloseAfterContinueResponse(object msg); + + protected abstract bool IgnoreContentAfterContinueResponse(object msg); + + protected abstract TOutput BeginAggregation(TStart start, IByteBuffer content); + + protected virtual void Aggregate(TOutput aggregated, TContent content) + { + } + + protected virtual void FinishAggregation(TOutput aggregated) + { + } + + void InvokeHandleOversizedMessage(IChannelHandlerContext ctx, TStart oversized) + { + this.handlingOversizedMessage = true; + this.currentMessage = default(TOutput); + try + { + this.HandleOversizedMessage(ctx, oversized); + } + finally + { + // Release the message in case it is a full one. + ReferenceCountUtil.Release(oversized); + } + } + + protected virtual void HandleOversizedMessage(IChannelHandlerContext ctx, TStart oversized) => + ctx.FireExceptionCaught(new TooLongFrameException($"content length exceeded {this.MaxContentLength} bytes.")); + + public override void ChannelReadComplete(IChannelHandlerContext context) + { + // We might need keep reading the channel until the full message is aggregated. + // + // See https://github.com/netty/netty/issues/6583 + if (this.currentMessage != null && !this.handlerContext.Channel.Configuration.AutoRead) + { + context.Read(); + } + + context.FireChannelReadComplete(); + } + + public override void ChannelInactive(IChannelHandlerContext context) + { + try + { + // release current message if it is not null as it may be a left-over + base.ChannelInactive(context); + } + finally + { + this.ReleaseCurrentMessage(); + } + } + + public override void HandlerAdded(IChannelHandlerContext context) => this.handlerContext = context; + + public override void HandlerRemoved(IChannelHandlerContext context) + { + try + { + base.HandlerRemoved(context); + } + finally + { + // release current message if it is not null as it may be a left-over as there is not much more we can do in + // this case + this.ReleaseCurrentMessage(); + } + } + + void ReleaseCurrentMessage() + { + if (this.currentMessage == null) + { + return; + } + + this.currentMessage.Release(); + this.currentMessage = default(TOutput); + this.handlingOversizedMessage = false; + } + } +} diff --git a/src/DotNetty.Codecs/MessageToMessageCodec.cs b/src/DotNetty.Codecs/MessageToMessageCodec.cs new file mode 100644 index 0000000..a990193 --- /dev/null +++ b/src/DotNetty.Codecs/MessageToMessageCodec.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + using System.Collections.Generic; + using System.Threading.Tasks; + using DotNetty.Transport.Channels; + + public abstract class MessageToMessageCodec : ChannelDuplexHandler + { + readonly Encoder encoder; + readonly Decoder decoder; + + sealed class Encoder : MessageToMessageEncoder + { + readonly MessageToMessageCodec codec; + + public Encoder(MessageToMessageCodec codec) + { + this.codec = codec; + } + + public override bool AcceptOutboundMessage(object msg) => this.codec.AcceptOutboundMessage(msg); + + protected internal override void Encode(IChannelHandlerContext context, object message, List output) => this.codec.Encode(context, (TOutbound)message, output); + } + + sealed class Decoder : MessageToMessageDecoder + { + readonly MessageToMessageCodec codec; + + public Decoder(MessageToMessageCodec codec) + { + this.codec = codec; + } + + public override bool AcceptInboundMessage(object msg) => this.codec.AcceptInboundMessage(msg); + + protected internal override void Decode(IChannelHandlerContext context, object message, List output) => + this.codec.Decode(context, (TInbound)message, output); + } + + protected MessageToMessageCodec() + { + this.encoder = new Encoder(this); + this.decoder = new Decoder(this); + } + + public sealed override void ChannelRead(IChannelHandlerContext context, object message) => + this.decoder.ChannelRead(context, message); + + public sealed override Task WriteAsync(IChannelHandlerContext context, object message) => + this.encoder.WriteAsync(context, message); + + public virtual bool AcceptInboundMessage(object msg) => msg is TInbound; + + public virtual bool AcceptOutboundMessage(object msg) => msg is TOutbound; + + protected abstract void Encode(IChannelHandlerContext ctx, TOutbound msg, List output); + + protected abstract void Decode(IChannelHandlerContext ctx, TInbound msg, List output); + } +} diff --git a/src/DotNetty.Codecs/NullNameValidator.cs b/src/DotNetty.Codecs/NullNameValidator.cs new file mode 100644 index 0000000..4d52283 --- /dev/null +++ b/src/DotNetty.Codecs/NullNameValidator.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + using System; + + public sealed class NullNameValidator : INameValidator + { + public void ValidateName(T name) + { + if (name == null) + { + throw new ArgumentNullException(nameof(name)); + } + } + } +} diff --git a/src/DotNetty.Codecs/PrematureChannelClosureException.cs b/src/DotNetty.Codecs/PrematureChannelClosureException.cs new file mode 100644 index 0000000..bc57de6 --- /dev/null +++ b/src/DotNetty.Codecs/PrematureChannelClosureException.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs +{ + using System; + + public class PrematureChannelClosureException : CodecException + { + public PrematureChannelClosureException(string message) + : this(message, null) + { + } + + public PrematureChannelClosureException(Exception exception) + : this(null, exception) + { + } + + public PrematureChannelClosureException(string message, Exception exception) + : base(message, exception) + { + } + } +} diff --git a/src/DotNetty.Codecs/Properties/Friends.cs b/src/DotNetty.Codecs/Properties/Friends.cs new file mode 100644 index 0000000..900a37e --- /dev/null +++ b/src/DotNetty.Codecs/Properties/Friends.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("DotNetty.Codecs.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100d9782d5a0b850f230f71e06de2e101d8441d83e15eef715837eee38fdbf5cb369b41ec36e6e7668c18cbb09e5419c179360461e740c1cce6ffbdcf81f245e1e705482797fe42aff2d31ecd72ea87362ded3c14066746fbab4a8e1896f8b982323c84e2c1b08407c0de18b7feef1535fb972a3b26181f5a304ebd181795a46d8f")] diff --git a/src/DotNetty.Common/Internal/AppendableCharSequence.cs b/src/DotNetty.Common/Internal/AppendableCharSequence.cs new file mode 100644 index 0000000..f9af4c7 --- /dev/null +++ b/src/DotNetty.Common/Internal/AppendableCharSequence.cs @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWithPrivateSetter +namespace DotNetty.Common.Internal +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Runtime.CompilerServices; + using System.Text; + using DotNetty.Common.Utilities; + + // + // This is used exclusively for http headers as a buffer + // + // In original Netty, this is backed by char array because http parsing + // converts each byte to char, then chars to string which implements ICharSequence + // in java and can be used in the same way as AsciiString. + // + // This approach performs poorly on .net because DotNetty only uses AsciiString + // for headers. DotNetty converts each byte to char, then chars back to bytes + // again when reading out to AsciiString. + // + // Each byte to char and each char to byte forwards and backwards! + // + // Therefore this is backed by bytes directly in DotNetty to avoid double conversions, + // and all chars are assumed to be ASCII! + // + public sealed class AppendableCharSequence : ICharSequence, IAppendable, IEquatable + { + byte[] chars; + int pos; + + public AppendableCharSequence(int length) + { + Contract.Requires(length > 0); + + this.chars = new byte[length]; + } + + public AppendableCharSequence(byte[] chars) + { + Contract.Requires(chars.Length > 0); + + this.chars = chars; + this.pos = chars.Length; + } + + public IEnumerator GetEnumerator() => new CharSequenceEnumerator(this); + + IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator(); + + public int Count => this.pos; + + public char this[int index] + { + get + { + Contract.Requires(index <= this.pos); + return AsciiString.ByteToChar(this.chars[index]); + } + } + + public ref byte[] Bytes => ref this.chars; + + public ICharSequence SubSequence(int start) => this.SubSequence(start, this.pos); + + public ICharSequence SubSequence(int start, int end) + { + int length = end - start; + var data = new byte[length]; + PlatformDependent.CopyMemory(this.chars, start, data, 0, length); + return new AppendableCharSequence(data); + } + + public int IndexOf(char ch, int start = 0) => CharUtil.IndexOf(this, ch, start); + + public bool RegionMatches(int thisStart, ICharSequence seq, int start, int length) => + CharUtil.RegionMatches(this, thisStart, seq, start, length); + + public bool RegionMatchesIgnoreCase(int thisStart, ICharSequence seq, int start, int length) => + CharUtil.RegionMatchesIgnoreCase(this, thisStart, seq, start, length); + + public bool ContentEquals(ICharSequence other) => CharUtil.ContentEquals(this, other); + + public bool ContentEqualsIgnoreCase(ICharSequence other) => CharUtil.ContentEqualsIgnoreCase(this, other); + + public bool Equals(AppendableCharSequence other) + { + if (other == null) + { + return false; + } + if (ReferenceEquals(this, other)) + { + return true; + } + + return this.pos == other.pos + && PlatformDependent.ByteArrayEquals(this.chars, 0, other.chars, 0, this.pos); + } + + public override bool Equals(object obj) + { + if (obj == null) + { + return false; + } + if (ReferenceEquals(this, obj)) + { + return true; + } + + if (obj is AppendableCharSequence other) + { + return this.Equals(other); + } + if (obj is ICharSequence seq) + { + return this.ContentEquals(seq); + } + + return false; + } + + public int HashCode(bool ignoreCase) => ignoreCase + ? StringComparer.OrdinalIgnoreCase.GetHashCode(this.ToString()) + : StringComparer.Ordinal.GetHashCode(this.ToString()); + + public override int GetHashCode() => this.HashCode(true); + + public IAppendable Append(char c) => this.Append((byte)c); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public IAppendable Append(byte c) + { + if (this.pos == this.chars.Length) + { + byte[] old = this.chars; + this.chars = new byte[old.Length << 1]; + PlatformDependent.CopyMemory(old, 0, this.chars, 0, old.Length); + } + this.chars[this.pos++] = c; + return this; + } + + public IAppendable Append(ICharSequence sequence) => this.Append(sequence, 0, sequence.Count); + + public IAppendable Append(ICharSequence sequence, int start, int end) + { + Contract.Requires(sequence.Count >= end); + + int length = end - start; + if (length > this.chars.Length - this.pos) + { + this.chars = Expand(this.chars, this.pos + length, this.pos); + } + + if (sequence is AppendableCharSequence seq) + { + // Optimize append operations via array copy + byte[] src = seq.chars; + PlatformDependent.CopyMemory(src, start, this.chars, this.pos, length); + this.pos += length; + + return this; + } + + for (int i = start; i < end; i++) + { + this.chars[this.pos++] = (byte)sequence[i]; + } + + return this; + } + + // Reset the {@link AppendableCharSequence}. Be aware this will only reset the current internal position and not + // shrink the internal char array. + public void Reset() => this.pos = 0; + + public string ToString(int start) + { + Contract.Requires(start >= 0 && start < this.pos); + return Encoding.ASCII.GetString(this.chars, start, this.pos); + } + + public override string ToString() => this.pos == 0 ? string.Empty : this.ToString(0); + + public AsciiString ToAsciiString() => this.pos == 0 ? AsciiString.Empty : new AsciiString(this.chars, 0, this.pos, true); + + // Create a new ascii string, this method assumes all chars has been sanitized + // to ascii chars when appending to the array + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public unsafe AsciiString SubStringUnsafe(int start, int end) + { + var bytes = new byte[end - start]; + fixed (byte* src = &this.chars[start]) + fixed (byte* dst = bytes) + { + PlatformDependent.CopyMemory(src, dst, bytes.Length); + } + return new AsciiString(bytes); + } + + static byte[] Expand(byte[] array, int neededSpace, int size) + { + int newCapacity = array.Length; + do + { + // double capacity until it is big enough + newCapacity <<= 1; + + if (newCapacity < 0) + { + throw new InvalidOperationException($"New capacity {newCapacity} must be positive"); + } + } + while (neededSpace > newCapacity); + + var newArray = new byte[newCapacity]; + PlatformDependent.CopyMemory(array, 0, newArray, 0, size); + return newArray; + } + } +} diff --git a/src/DotNetty.Common/Internal/ConcurrentCircularArrayQueue.cs b/src/DotNetty.Common/Internal/ConcurrentCircularArrayQueue.cs index d1e4d38..4ec0a12 100644 --- a/src/DotNetty.Common/Internal/ConcurrentCircularArrayQueue.cs +++ b/src/DotNetty.Common/Internal/ConcurrentCircularArrayQueue.cs @@ -20,8 +20,6 @@ namespace DotNetty.Common.Internal /// parameter are provided to allow the prevention of field reload after a /// LoadLoad barrier. ///

- /// @param - /// abstract class ConcurrentCircularArrayQueue : ConcurrentCircularArrayQueueL0Pad where T : class { @@ -62,8 +60,7 @@ namespace DotNetty.Common.Internal public override void Clear() { - T item; - while (this.TryDequeue(out item) || !this.IsEmpty) + while (this.TryDequeue(out T _) || !this.IsEmpty) { // looping } diff --git a/src/DotNetty.Common/Internal/EmptyArrays.cs b/src/DotNetty.Common/Internal/EmptyArrays.cs new file mode 100644 index 0000000..b56258b --- /dev/null +++ b/src/DotNetty.Common/Internal/EmptyArrays.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Internal +{ + using DotNetty.Common.Utilities; + + public static class EmptyArrays + { + public static readonly int[] EmptyInts = { }; + + public static readonly byte[] EmptyBytes = { }; + + public static readonly char[] EmptyChars = { }; + + public static readonly object[] EmptyObjects = { }; + + public static readonly string[] EmptyStrings = { }; + + public static readonly AsciiString[] EmptyAsciiStrings = { }; + } +} diff --git a/src/DotNetty.Common/Internal/IAppendable.cs b/src/DotNetty.Common/Internal/IAppendable.cs new file mode 100644 index 0000000..7a60e9d --- /dev/null +++ b/src/DotNetty.Common/Internal/IAppendable.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Internal +{ + using DotNetty.Common.Utilities; + + public interface IAppendable + { + IAppendable Append(char c); + + IAppendable Append(ICharSequence sequence); + + IAppendable Append(ICharSequence sequence, int start, int end); + } +} diff --git a/src/DotNetty.Common/Internal/PlatformDependent.cs b/src/DotNetty.Common/Internal/PlatformDependent.cs index f16258c..f1fb07e 100644 --- a/src/DotNetty.Common/Internal/PlatformDependent.cs +++ b/src/DotNetty.Common/Internal/PlatformDependent.cs @@ -11,6 +11,9 @@ namespace DotNetty.Common.Internal using System.Runtime.CompilerServices; using System.Threading; using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + + using static PlatformDependent0; public static class PlatformDependent { @@ -31,6 +34,7 @@ namespace DotNetty.Common.Internal static int seed = (int)(Stopwatch.GetTimestamp() & 0xFFFFFFFF); //used to safly cast long to int, because the timestamp returned is long and it doesn't fit into an int static readonly ThreadLocal ThreadLocalRandom = new ThreadLocal(() => new Random(Interlocked.Increment(ref seed))); //used to simulate java ThreadLocalRandom + static readonly bool IsLittleEndian = BitConverter.IsLittleEndian; public static IQueue NewFixedMpscQueue(int capacity) where T : class => new MpscArrayQueue(capacity); @@ -44,11 +48,179 @@ namespace DotNetty.Common.Internal public static unsafe bool ByteArrayEquals(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) { - fixed (byte* array1 = bytes1) - fixed (byte* array2 = bytes2) - return PlatformDependent0.ByteArrayEquals(array1, startPos1, array2, startPos2, length); + if (length <= 0) + { + return true; + } + + fixed (byte* array1 = &bytes1[startPos1]) + fixed (byte* array2 = &bytes2[startPos2]) + return PlatformDependent0.ByteArrayEquals(array1, array2, length); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe int HashCodeAscii(byte[] bytes, int startPos, int length) + { + if (length == 0) + { + return HashCodeAsciiSeed; + } + + fixed (byte* array = &bytes[startPos]) + return PlatformDependent0.HashCodeAscii(array, length); + } + + public static int HashCodeAscii(ICharSequence bytes) + { + int hash = HashCodeAsciiSeed; + int remainingBytes = bytes.Count & 7; + + // Benchmarking shows that by just naively looping for inputs 8~31 bytes long we incur a relatively large + // performance penalty (only achieve about 60% performance of loop which iterates over each char). So because + // of this we take special provisions to unroll the looping for these conditions. + switch (bytes.Count) + { + case 31: + case 30: + case 29: + case 28: + case 27: + case 26: + case 25: + case 24: + hash = HashCodeAsciiCompute( + bytes, + bytes.Count - 24, + HashCodeAsciiCompute( + bytes, + bytes.Count - 16, + HashCodeAsciiCompute(bytes, bytes.Count - 8, hash))); + break; + case 23: + case 22: + case 21: + case 20: + case 19: + case 18: + case 17: + case 16: + hash = HashCodeAsciiCompute( + bytes, + bytes.Count - 16, + HashCodeAsciiCompute(bytes, bytes.Count - 8, hash)); + break; + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + case 8: + hash = HashCodeAsciiCompute(bytes, bytes.Count - 8, hash); + break; + case 7: + case 6: + case 5: + case 4: + case 3: + case 2: + case 1: + case 0: + break; + default: + for (int i = bytes.Count - 8; i >= remainingBytes; i -= 8) + { + hash = HashCodeAsciiCompute(bytes, i, hash); + } + break; + } + switch (remainingBytes) + { + case 7: + return ((hash + * HashCodeC1 + HashCodeAsciiSanitizsByte(bytes[0])) + * HashCodeC2 + HashCodeAsciiSanitizeShort(bytes, 1)) + * HashCodeC1 + HashCodeAsciiSanitizeInt(bytes, 3); + case 6: + return (hash + * HashCodeC1 + HashCodeAsciiSanitizeShort(bytes, 0)) + * HashCodeC2 + HashCodeAsciiSanitizeInt(bytes, 2); + case 5: + return (hash + * HashCodeC1 + HashCodeAsciiSanitizsByte(bytes[0])) + * HashCodeC2 + HashCodeAsciiSanitizeInt(bytes, 1); + case 4: + return hash + * HashCodeC1 + HashCodeAsciiSanitizeInt(bytes, 0); + case 3: + return (hash + * HashCodeC1 + HashCodeAsciiSanitizsByte(bytes[0])) + * HashCodeC2 + HashCodeAsciiSanitizeShort(bytes, 1); + case 2: + return hash + * HashCodeC1 + HashCodeAsciiSanitizeShort(bytes, 0); + case 1: + return hash + * HashCodeC1 + HashCodeAsciiSanitizsByte(bytes[0]); + default: + return hash; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int HashCodeAsciiCompute(ICharSequence value, int offset, int hash) + { + if (!IsLittleEndian) + { + return hash * HashCodeC1 + + // Low order int + HashCodeAsciiSanitizeInt(value, offset + 4) * HashCodeC2 + + // High order int + HashCodeAsciiSanitizeInt(value, offset); + } + return hash * HashCodeC1 + + // Low order int + HashCodeAsciiSanitizeInt(value, offset) * HashCodeC2 + + // High order int + HashCodeAsciiSanitizeInt(value, offset + 4); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int HashCodeAsciiSanitizeInt(ICharSequence value, int offset) + { + if (!IsLittleEndian) + { + // mimic a unsafe.getInt call on a big endian machine + return (value[offset + 3] & 0x1f) + | (value[offset + 2] & 0x1f) << 8 + | (value[offset + 1] & 0x1f) << 16 + | (value[offset] & 0x1f) << 24; + } + + return (value[offset + 3] & 0x1f) << 24 + | (value[offset + 2] & 0x1f) << 16 + | (value[offset + 1] & 0x1f) << 8 + | (value[offset] & 0x1f); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int HashCodeAsciiSanitizeShort(ICharSequence value, int offset) + { + if (!IsLittleEndian) + { + // mimic a unsafe.getShort call on a big endian machine + return (value[offset + 1] & 0x1f) + | (value[offset] & 0x1f) << 8; + } + + return (value[offset + 1] & 0x1f) << 8 + | (value[offset] & 0x1f); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int HashCodeAsciiSanitizsByte(char value) => value & 0x1f; + public static void CopyMemory(byte[] src, int srcIndex, byte[] dst, int dstIndex, int length) { if (length > 0) diff --git a/src/DotNetty.Common/Internal/PlatformDependent0.cs b/src/DotNetty.Common/Internal/PlatformDependent0.cs index a7b7378..ed35711 100644 --- a/src/DotNetty.Common/Internal/PlatformDependent0.cs +++ b/src/DotNetty.Common/Internal/PlatformDependent0.cs @@ -4,19 +4,23 @@ namespace DotNetty.Common.Internal { using System.Runtime.CompilerServices; + using DotNetty.Common.Utilities; static class PlatformDependent0 { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static unsafe bool ByteArrayEquals(byte* bytes1, int startPos1, byte* bytes2, int startPos2, int length) + internal static readonly int HashCodeAsciiSeed = unchecked((int)0xc2b2ae35); + internal static readonly int HashCodeC1 = unchecked((int)0xcc9e2d51); + internal static readonly int HashCodeC2 = 0x1b873593; + + internal static unsafe bool ByteArrayEquals(byte* bytes1, byte* bytes2, int length) { if (length <= 0) { return true; } - byte* baseOffset1 = bytes1 + startPos1; - byte* baseOffset2 = bytes2 + startPos2; + byte* baseOffset1 = bytes1; + byte* baseOffset2 = bytes2; int remainingBytes = length & 7; byte* end = baseOffset1 + remainingBytes; for (byte* i = baseOffset1 - 8 + length, j = baseOffset2 - 8 + length; i >= end; i -= 8, j -= 8) @@ -38,9 +42,70 @@ namespace DotNetty.Common.Internal if (remainingBytes >= 2) { return Unsafe.ReadUnaligned(baseOffset1) == Unsafe.ReadUnaligned(baseOffset2) - && (remainingBytes == 2 || *(bytes1 + startPos1 + 2) == *(bytes2 + startPos2 + 2)); + && (remainingBytes == 2 || *(bytes1 + 2) == *(bytes2 + 2)); } return *baseOffset1 == *baseOffset2; } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static unsafe int HashCodeAscii(byte* bytes, int length) + { + int hash = HashCodeAsciiSeed; + int remainingBytes = length & 7; + byte* end = bytes + remainingBytes; + for (byte* i = bytes - 8 + length; i >= end; i -= 8) + { + hash = HashCodeAsciiCompute(Unsafe.ReadUnaligned(i), hash); + } + + switch (remainingBytes) + { + case 7: + return ((hash * HashCodeC1 + HashCodeAsciiSanitize(*bytes)) + * HashCodeC2 + HashCodeAsciiSanitize(Unsafe.ReadUnaligned(bytes + 1))) + * HashCodeC1 + HashCodeAsciiSanitize(Unsafe.ReadUnaligned(bytes + 3)); + case 6: + return (hash * HashCodeC1 + HashCodeAsciiSanitize(Unsafe.ReadUnaligned(bytes))) + * HashCodeC2 + HashCodeAsciiSanitize(Unsafe.ReadUnaligned(bytes + 2)); + case 5: + return (hash * HashCodeC1 + HashCodeAsciiSanitize(*bytes)) + * HashCodeC2 + HashCodeAsciiSanitize(Unsafe.ReadUnaligned(bytes + 1)); + case 4: + return hash * HashCodeC1 + HashCodeAsciiSanitize(Unsafe.ReadUnaligned(bytes)); + case 3: + return (hash * HashCodeC1 + HashCodeAsciiSanitize(*bytes)) + * HashCodeC2 + HashCodeAsciiSanitize(Unsafe.ReadUnaligned(bytes + 1)); + case 2: + return hash * HashCodeC1 + HashCodeAsciiSanitize(Unsafe.ReadUnaligned(bytes)); + case 1: + return hash * HashCodeC1 + HashCodeAsciiSanitize(*bytes); + default: + return hash; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int HashCodeAsciiCompute(long value, int hash) + { + // masking with 0x1f reduces the number of overall bits that impact the hash code but makes the hash + // code the same regardless of character case (upper case or lower case hash is the same). + unchecked + { + return hash * HashCodeC1 + + // Low order int + HashCodeAsciiSanitize((int)value) * HashCodeC2 + + // High order int + (int)(value & 0x1f1f1f1f00000000L).RightUShift(32); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int HashCodeAsciiSanitize(int value) => value & 0x1f1f1f1f; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int HashCodeAsciiSanitize(short value) => value & 0x1f1f; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int HashCodeAsciiSanitize(byte value) => value & 0x1f; } } diff --git a/src/DotNetty.Common/InternalThreadLocalMap.cs b/src/DotNetty.Common/InternalThreadLocalMap.cs index 4c4abac..a75bbae 100644 --- a/src/DotNetty.Common/InternalThreadLocalMap.cs +++ b/src/DotNetty.Common/InternalThreadLocalMap.cs @@ -4,6 +4,7 @@ namespace DotNetty.Common { using System; + using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Text; using System.Threading; @@ -16,8 +17,9 @@ namespace DotNetty.Common /// public sealed class InternalThreadLocalMap { - public static readonly object Unset = new object(); + const int DefaultArrayListInitialCapacity = 8; + public static readonly object Unset = new object(); [ThreadStatic] static InternalThreadLocalMap slowThreadLocalMap; @@ -33,6 +35,10 @@ namespace DotNetty.Common // String-related thread-locals StringBuilder stringBuilder; + // ArrayList-related thread-locals + List charSequences; + List asciiStrings; + internal static int NextVariableIndex() { int index = Interlocked.Increment(ref nextIndex); @@ -65,7 +71,9 @@ namespace DotNetty.Common // Cache line padding (must be public) // With CompressedOops enabled, an instance of this class should occupy at least 128 bytes. + // ReSharper disable InconsistentNaming public long rp1, rp2, rp3, rp4, rp5, rp6, rp7, rp8, rp9; + // ReSharper restore InconsistentNaming InternalThreadLocalMap() { @@ -129,6 +137,36 @@ namespace DotNetty.Common } } + public List CharSequenceList(int minCapacity = DefaultArrayListInitialCapacity) + { + List localList = this.charSequences; + if (localList == null) + { + this.charSequences = new List(minCapacity); + return this.charSequences; + } + + localList.Clear(); + // ensureCapacity + localList.Capacity = minCapacity; + return localList; + } + + public List AsciiStringList(int minCapacity = DefaultArrayListInitialCapacity) + { + List localList = this.asciiStrings; + if (localList == null) + { + this.asciiStrings = new List(minCapacity); + return this.asciiStrings; + } + + localList.Clear(); + // ensureCapacity + localList.Capacity = minCapacity; + return localList; + } + public int FutureListenerStackDepth { get => this.futureListenerStackDepth; diff --git a/src/DotNetty.Common/Utilities/AbstractReferenceCounted.cs b/src/DotNetty.Common/Utilities/AbstractReferenceCounted.cs index 3f7cadc..2853c65 100644 --- a/src/DotNetty.Common/Utilities/AbstractReferenceCounted.cs +++ b/src/DotNetty.Common/Utilities/AbstractReferenceCounted.cs @@ -20,7 +20,7 @@ namespace DotNetty.Common.Utilities return this.RetainCore(increment); } - IReferenceCounted RetainCore(int increment) + protected virtual IReferenceCounted RetainCore(int increment) { while (true) { diff --git a/src/DotNetty.Common/Utilities/ArrayExtensions.cs b/src/DotNetty.Common/Utilities/ArrayExtensions.cs index cc5859a..b163118 100644 --- a/src/DotNetty.Common/Utilities/ArrayExtensions.cs +++ b/src/DotNetty.Common/Utilities/ArrayExtensions.cs @@ -4,7 +4,6 @@ namespace DotNetty.Common.Utilities { using System; - using System.Collections.Generic; using System.Diagnostics.Contracts; ///

diff --git a/src/DotNetty.Common/Utilities/AsciiString.cs b/src/DotNetty.Common/Utilities/AsciiString.cs new file mode 100644 index 0000000..15945ab --- /dev/null +++ b/src/DotNetty.Common/Utilities/AsciiString.cs @@ -0,0 +1,1584 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoPropertyWhenPossible +// ReSharper disable UseStringInterpolation +namespace DotNetty.Common.Utilities +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + using System.Text; + using DotNetty.Common.Internal; + + public sealed class AsciiString : ICharSequence, IEquatable, IComparable, IComparable + { + public static readonly AsciiString Empty = Cached(string.Empty); + const int MaxCharValue = 255; + const byte Replacement = (byte)'?'; + public static readonly int IndexNotFound = -1; + + public static readonly IHashingStrategy CaseInsensitiveHasher = new CaseInsensitiveHashingStrategy(); + public static readonly IHashingStrategy CaseSensitiveHasher = new CaseSensitiveHashingStrategy(); + + static readonly ICharEqualityComparator DefaultCharComparator = new DefaultCharEqualityComparator(); + static readonly ICharEqualityComparator GeneralCaseInsensitiveComparator = new GeneralCaseInsensitiveCharEqualityComparator(); + static readonly ICharEqualityComparator AsciiCaseInsensitiveCharComparator = new AsciiCaseInsensitiveCharEqualityComparator(); + + sealed class CaseInsensitiveHashingStrategy : IHashingStrategy + { + public int HashCode(ICharSequence obj) => AsciiString.GetHashCode(obj); + + int IEqualityComparer.GetHashCode(ICharSequence obj) => this.HashCode(obj); + + public bool Equals(ICharSequence a, ICharSequence b) => ContentEqualsIgnoreCase(a, b); + } + + sealed class CaseSensitiveHashingStrategy : IHashingStrategy + { + public int HashCode(ICharSequence obj) => AsciiString.GetHashCode(obj); + + int IEqualityComparer.GetHashCode(ICharSequence obj) => this.HashCode(obj); + + public bool Equals(ICharSequence a, ICharSequence b) => ContentEquals(a, b); + } + + readonly byte[] value; + readonly int offset; + readonly int length; + + int hash; + + //Used to cache the ToString() value. + string stringValue; + + // Called by AppendableCharSequence for http headers + internal AsciiString(byte[] value) + { + this.value = value; + this.offset = 0; + this.length = value.Length; + } + + public AsciiString(byte[] value, bool copy) : this(value, 0, value.Length, copy) + { + } + + public AsciiString(byte[] value, int start, int length, bool copy) + { + if (copy) + { + this.value = new byte[length]; + PlatformDependent.CopyMemory(value, start, this.value, 0, length); + this.offset = 0; + } + else + { + if (MathUtil.IsOutOfBounds(start, length, value.Length)) + { + ThrowIndexOutOfRangeException_Start(start, length, value.Length); + } + + this.value = value; + this.offset = start; + } + + this.length = length; + } + + public AsciiString(char[] value) : this(value, 0, value.Length) + { + } + + public unsafe AsciiString(char[] value, int start, int length) + { + if (MathUtil.IsOutOfBounds(start, length, value.Length)) + { + ThrowIndexOutOfRangeException_Start(start, length, value.Length); + } + + this.value = new byte[length]; + fixed (char* chars = value) + fixed (byte* bytes = this.value) + GetBytes(chars + start, length, bytes); + + this.offset = 0; + this.length = length; + } + + public AsciiString(char[] value, Encoding encoding) : this(value, encoding, 0, value.Length) + { + } + + public AsciiString(char[] value, Encoding encoding, int start, int length) + { + this.value = encoding.GetBytes(value, start, length); + this.offset = 0; + this.length = this.value.Length; + } + + public AsciiString(ICharSequence value) : this(value, 0, value.Count) + { + } + + public AsciiString(ICharSequence value, int start, int length) + { + if (MathUtil.IsOutOfBounds(start, length, value.Count)) + { + ThrowIndexOutOfRangeException_Start(start, length, value.Count); + } + + this.value = new byte[length]; + for (int i = 0, j = start; i < length; i++, j++) + { + this.value[i] = CharToByte(value[j]); + } + + this.offset = 0; + this.length = length; + } + + public AsciiString(string value, Encoding encoding) : this(value, encoding, 0, value.Length) + { + } + + public AsciiString(string value, Encoding encoding, int start, int length) + { + int count = encoding.GetMaxByteCount(length); + var bytes = new byte[count]; + count = encoding.GetBytes(value, start, length, bytes, 0); + + this.value = new byte[count]; + PlatformDependent.CopyMemory(bytes, 0, this.value, 0, count); + + this.offset = 0; + this.length = this.value.Length; + } + + public AsciiString(string value) : this(value, 0, value.Length) + { + } + + public AsciiString(string value, int start, int length) + { + if (MathUtil.IsOutOfBounds(start, length, value.Length)) + { + ThrowIndexOutOfRangeException_Start(start, length, value.Length); + } + + this.value = new byte[value.Length]; + for (int i = 0; i < value.Length; i++) + { + this.value[i] = CharToByte(value[i]); + } + + this.offset = 0; + this.length = value.Length; + } + + public int ForEachByte(IByteProcessor visitor) => this.ForEachByte0(0, this.length, visitor); + + public int ForEachByte(int index, int count, IByteProcessor visitor) + { + if (MathUtil.IsOutOfBounds(index, count, this.length)) + { + ThrowIndexOutOfRangeException_Index(index, count, this.length); + } + return this.ForEachByte0(index, count, visitor); + } + + int ForEachByte0(int index, int count, IByteProcessor visitor) + { + int len = this.offset + index + count; + for (int i = this.offset + index; i < len; ++i) + { + if (!visitor.Process(this.value[i])) + { + return i - this.offset; + } + } + + return -1; + } + + public int ForEachByteDesc(IByteProcessor visitor) => this.ForEachByteDesc0(0, this.length, visitor); + + public int ForEachByteDesc(int index, int count, IByteProcessor visitor) + { + if (MathUtil.IsOutOfBounds(index, count, this.length)) + { + ThrowIndexOutOfRangeException_Index(index, count, this.length); + } + + return this.ForEachByteDesc0(index, count, visitor); + } + + int ForEachByteDesc0(int index, int count, IByteProcessor visitor) + { + int end = this.offset + index; + for (int i = this.offset + index + count - 1; i >= end; --i) + { + if (!visitor.Process(this.value[i])) + { + return i - this.offset; + } + } + + return -1; + } + + public byte ByteAt(int index) + { + // We must do a range check here to enforce the access does not go outside our sub region of the array. + // We rely on the array access itself to pick up the array out of bounds conditions + if (index < 0 || index >= this.length) + { + ThrowIndexOutOfRangeException_Index(index, this.length); + } + + return this.value[index + this.offset]; + } + + public bool IsEmpty => this.length == 0; + + public int Count => this.length; + + /// + /// During normal use cases the AsciiString should be immutable, but if the + /// underlying array is shared, and changes then this needs to be called. + /// + public void ArrayChanged() + { + this.stringValue = null; + this.hash = 0; + } + + public byte[] Array => this.value; + + public int Offset => this.offset; + + public bool IsEntireArrayUsed => this.offset == 0 && this.length == this.value.Length; + + public byte[] ToByteArray() => this.ToByteArray(0, this.length); + + public byte[] ToByteArray(int start, int end) + { + int count = end - start; + var bytes = new byte[count]; + PlatformDependent.CopyMemory(this.value, this.offset + start, bytes, 0, count); + + return bytes; + } + + public void Copy(int srcIdx, byte[] dst, int dstIdx, int count) + { + Contract.Requires(dst != null && dst.Length >= count); + + if (MathUtil.IsOutOfBounds(srcIdx, count, this.length)) + { + ThrowIndexOutOfRangeException_SrcIndex(srcIdx, count, this.length); + } + if (count == 0) + { + return; + } + + PlatformDependent.CopyMemory(this.value, srcIdx + this.offset, dst, dstIdx, count); + } + + public char this[int index] => ByteToChar(this.ByteAt(index)); + + public bool Contains(ICharSequence sequence) => this.IndexOf(sequence) >= 0; + + public int CompareTo(ICharSequence other) + { + if (ReferenceEquals(this, other)) + { + return 0; + } + + int length1 = this.length; + int length2 = other.Count; + int minLength = Math.Min(length1, length2); + for (int i = 0, j = this.offset; i < minLength; i++, j++) + { + int result = ByteToChar(this.value[j]) - other[i]; + if (result != 0) + { + return result; + } + } + + return length1 - length2; + } + + public AsciiString Concat(ICharSequence charSequence) + { + int thisLen = this.length; + int thatLen = charSequence.Count; + if (thatLen == 0) + { + return this; + } + + byte[] newValue; + if (charSequence is AsciiString that) + { + if (this.IsEmpty) + { + return that; + } + + newValue = new byte[thisLen + thatLen]; + PlatformDependent.CopyMemory(this.value, this.offset, newValue, 0, thisLen); + PlatformDependent.CopyMemory(that.value, that.offset, newValue, thisLen, thatLen); + + return new AsciiString(newValue, false); + } + + if (this.IsEmpty) + { + return new AsciiString(charSequence); + } + + newValue = new byte[thisLen + thatLen]; + PlatformDependent.CopyMemory(this.value, this.offset, newValue, 0, thisLen); + for (int i = thisLen, j = 0; i < newValue.Length; i++, j++) + { + newValue[i] = CharToByte(charSequence[j]); + } + + return new AsciiString(newValue, false); + } + + public bool EndsWith(ICharSequence suffix) + { + int suffixLen = suffix.Count; + return this.RegionMatches(this.length - suffixLen, suffix, 0, suffixLen); + } + + public bool ContentEqualsIgnoreCase(ICharSequence other) + { + if (other == null || other.Count != this.length) + { + return false; + } + + if (other is AsciiString rhs) + { + for (int i = this.offset, j = rhs.offset; i < this.length; ++i, ++j) + { + if (!EqualsIgnoreCase(this.value[i], rhs.value[j])) + { + return false; + } + } + return true; + } + + for (int i = this.offset, j = 0; i < this.length; ++i, ++j) + { + if (!EqualsIgnoreCase(ByteToChar(this.value[i]), other[j])) + { + return false; + } + } + + return true; + } + + public char[] ToCharArray() => this.ToCharArray(0, this.length); + + public char[] ToCharArray(int start, int end) + { + int count = end - start; + if (count == 0) + { + return EmptyArrays.EmptyChars; + } + + if (MathUtil.IsOutOfBounds(start, count, this.length)) + { + ThrowIndexOutOfRangeException_SrcIndex(start, count, this.length); + } + + var buffer = new char[count]; + for (int i = 0, j = start + this.offset; i < count; i++, j++) + { + buffer[i] = ByteToChar(this.value[j]); + } + + return buffer; + } + + public void Copy(int srcIdx, char[] dst, int dstIdx, int count) + { + Contract.Requires(dst != null); + + if (MathUtil.IsOutOfBounds(srcIdx, count, this.length)) + { + ThrowIndexOutOfRangeException_SrcIndex(srcIdx, count, this.length); + } + + int dstEnd = dstIdx + count; + for (int i = dstIdx, j = srcIdx + this.offset; i < dstEnd; i++, j++) + { + dst[i] = ByteToChar(this.value[j]); + } + } + + public ICharSequence SubSequence(int start) => (AsciiString)this.SubSequence(start, this.length); + + public ICharSequence SubSequence(int start, int end) => this.SubSequence(start, end, true); + + public AsciiString SubSequence(int start, int end, bool copy) + { + if (MathUtil.IsOutOfBounds(start, end - start, this.length)) + { + ThrowIndexOutOfRangeException_StartEnd(start, end, this.length); + } + + if (start == 0 && end == this.length) + { + return this; + } + + return end == start ? Empty : new AsciiString(this.value, start + this.offset, end - start, copy); + } + + public int IndexOf(ICharSequence sequence) => this.IndexOf(sequence, 0); + + public int IndexOf(ICharSequence subString, int start) + { + if (start < 0) + { + start = 0; + } + + int thisLen = this.length; + + int subCount = subString.Count; + if (subCount <= 0) + { + return start < thisLen ? start : thisLen; + } + if (subCount > thisLen - start) + { + return -1; + } + + char firstChar = subString[0]; + if (firstChar > MaxCharValue) + { + return -1; + } + + var indexOfVisitor = new IndexOfProcessor((byte)firstChar); + for (; ; ) + { + int i = this.ForEachByte(start, thisLen - start, indexOfVisitor); + if (i == -1 || subCount + i > thisLen) + { + return -1; // handles subCount > count || start >= count + } + int o1 = i, o2 = 0; + while (++o2 < subCount && ByteToChar(this.value[++o1 + this.offset]) == subString[o2]) + { + // Intentionally empty + } + if (o2 == subCount) + { + return i; + } + start = i + 1; + } + } + + public int IndexOf(char ch, int start) + { + if (start < 0) + { + start = 0; + } + + int thisLen = this.length; + if (ch > MaxCharValue) + { + return -1; + } + + return this.ForEachByte(start, thisLen - start, new IndexOfProcessor((byte)ch)); + } + + // Use count instead of count - 1 so lastIndexOf("") answers count + public int LastIndexOf(ICharSequence charSequence) => this.LastIndexOf(charSequence, this.length); + + public int LastIndexOf(ICharSequence subString, int start) + { + int thisLen = this.length; + int subCount = subString.Count; + + if (subCount > thisLen || start < 0) + { + return -1; + } + + if (subCount <= 0) + { + return start < thisLen ? start : thisLen; + } + + start = Math.Min(start, thisLen - subCount); + + // count and subCount are both >= 1 + char firstChar = subString[0]; + if (firstChar > MaxCharValue) + { + return -1; + } + var indexOfVisitor = new IndexOfProcessor((byte)firstChar); + for (; ;) + { + int i = this.ForEachByteDesc(start, thisLen - start, indexOfVisitor); + if (i == -1) + { + return -1; + } + int o1 = i, o2 = 0; + while (++o2 < subCount && ByteToChar(this.value[++o1 + this.offset]) == subString[o2]) + { + // Intentionally empty + } + if (o2 == subCount) + { + return i; + } + start = i - 1; + } + } + + public bool RegionMatches(int thisStart, ICharSequence seq, int start, int count) + { + Contract.Requires(seq != null); + + if (start < 0 || seq.Count - start < count) + { + return false; + } + + int thisLen = this.length; + if (thisStart < 0 || thisLen - thisStart < count) + { + return false; + } + + if (count <= 0) + { + return true; + } + + int thatEnd = start + count; + for (int i = start, j = thisStart + this.offset; i < thatEnd; i++, j++) + { + if (ByteToChar(this.value[j]) != seq[i]) + { + return false; + } + } + + return true; + } + + public bool RegionMatchesIgnoreCase(int thisStart, ICharSequence seq, int start, int count) + { + Contract.Requires(seq != null); + + int thisLen = this.length; + if (thisStart < 0 || count > thisLen - thisStart) + { + return false; + } + if (start < 0 || count > seq.Count - start) + { + return false; + } + + thisStart += this.offset; + int thisEnd = thisStart + count; + while (thisStart < thisEnd) + { + if (!EqualsIgnoreCase(ByteToChar(this.value[thisStart++]), seq[start++])) + { + return false; + } + } + + return true; + } + + public AsciiString Replace(char oldChar, char newChar) + { + if (oldChar > MaxCharValue) + { + return this; + } + + byte oldCharByte = CharToByte(oldChar); + int index = this.ForEachByte(new IndexOfProcessor(oldCharByte)); + if (index == -1) + { + return this; + } + + byte newCharByte = CharToByte(newChar); + var buffer = new byte[this.length]; + for (int i = 0, j = this.offset; i < buffer.Length; i++, j++) + { + byte b = this.value[j]; + if (b == oldCharByte) + { + b = newCharByte; + } + buffer[i] = b; + } + + return new AsciiString(buffer, false); + } + + public bool StartsWith(ICharSequence prefix) => this.StartsWith(prefix, 0); + + public bool StartsWith(ICharSequence prefix, int start) => this.RegionMatches(start, prefix, 0, prefix.Count); + + public AsciiString ToLowerCase() + { + bool lowercased = true; + int i, j; + int len = this.length + this.offset; + for (i = this.offset; i < len; ++i) + { + byte b = this.value[i]; + if (b >= 'A' && b <= 'Z') + { + lowercased = false; + break; + } + } + + // Check if this string does not contain any uppercase characters. + if (lowercased) + { + return this; + } + + var newValue = new byte[this.length]; + for (i = 0, j = this.offset; i < newValue.Length; ++i, ++j) + { + newValue[i] = ToLowerCase(this.value[j]); + } + + return new AsciiString(newValue, false); + } + + public AsciiString ToUpperCase() + { + bool uppercased = true; + int i, j; + int len = this.length + this.offset; + for (i = this.offset; i < len; ++i) + { + byte b = this.value[i]; + if (b >= 'a' && b <= 'z') + { + uppercased = false; + break; + } + } + + // Check if this string does not contain any lowercase characters. + if (uppercased) + { + return this; + } + + var newValue = new byte[this.length]; + for (i = 0, j = this.offset; i < newValue.Length; ++i, ++j) + { + newValue[i] = ToUpperCase(this.value[j]); + } + + return new AsciiString(newValue, false); + } + + public static ICharSequence Trim(ICharSequence c) + { + if (c is AsciiString asciiString) + { + return asciiString.Trim(); + } + int start = 0; + int last = c.Count - 1; + int end = last; + while (start <= end && c[start] <= ' ') + { + start++; + } + while (end >= start && c[end] <= ' ') + { + end--; + } + if (start == 0 && end == last) + { + return c; + } + return c.SubSequence(start, end + 1); + } + + public AsciiString Trim() + { + int start = this.offset; + int last = this.offset + this.length - 1; + int end = last; + while (start <= end && this.value[start] <= ' ') + { + start++; + } + while (end >= start && this.value[end] <= ' ') + { + end--; + } + if (start == 0 && end == last) + { + return this; + } + + return new AsciiString(this.value, start, end - start + 1, false); + } + + public bool ContentEquals(ICharSequence a) + { + if (a == null || a.Count != this.length) + { + return false; + } + + if (a is AsciiString asciiString) + { + return this.Equals(asciiString); + } + + for (int i = this.offset, j = 0; j < a.Count; ++i, ++j) + { + if (ByteToChar(this.value[i]) != a[j]) + { + return false; + } + } + + return true; + } + + public AsciiString[] Split(char delim) + { + List res = InternalThreadLocalMap.Get().AsciiStringList(); + + int start = 0; + int count = this.length; + for (int i = start; i < count; i++) + { + if (this[i] == delim) + { + if (start == i) + { + res.Add(Empty); + } + else + { + res.Add(new AsciiString(this.value, start + this.offset, i - start, false)); + } + start = i + 1; + } + } + + if (start == 0) + { + // If no delimiter was found in the value + res.Add(this); + } + else + { + if (start != count) + { + // Add the last element if it's not empty. + res.Add(new AsciiString(this.value, start + this.offset, count - start, false)); + } + else + { + // Truncate trailing empty elements. + while (res.Count > 0) + { + int i = res.Count - 1; + if (!res[i].IsEmpty) + { + res.RemoveAt(i); + } + else + { + break; + } + } + } + } + + var strings = new AsciiString[res.Count]; + res.CopyTo(strings); + return strings; + } + + // ReSharper disable NonReadonlyMemberInGetHashCode + public override int GetHashCode() + { + int h = this.hash; + if (h == 0) + { + h = PlatformDependent.HashCodeAscii(this.value, this.offset, this.length); + this.hash = h; + } + + return h; + } + // ReSharper restore NonReadonlyMemberInGetHashCode + + public bool Equals(AsciiString other) + { + if (other == null) + { + return false; + } + if (ReferenceEquals(this, other)) + { + return true; + } + + return this.length == other.length + && this.GetHashCode() == other.GetHashCode() + && PlatformDependent.ByteArrayEquals(this.value, this.offset, other.value, other.offset, this.length); + } + + public override bool Equals(object obj) + { + if (obj == null) + { + return false; + } + if (ReferenceEquals(this, obj)) + { + return true; + } + + if (obj is AsciiString ascii) + { + return this.Equals(ascii); + } + if (obj is ICharSequence seq) + { + return this.ContentEquals(seq); + } + + return false; + } + + public override string ToString() + { + if (this.stringValue != null) + { + return this.stringValue; + } + + this.stringValue = this.ToString(0); + return this.stringValue; + } + + public string ToString(int start) => this.ToString(start, this.length); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public unsafe string ToString(int start, int end) + { + int count = end - start; + if (MathUtil.IsOutOfBounds(start, count, this.length)) + { + ThrowIndexOutOfRangeException_SrcIndex(start, count, this.length); + } + if (count == 0) + { + return string.Empty; + } + + fixed (byte* p = &this.value[this.offset + start]) + { + return Marshal.PtrToStringAnsi((IntPtr)p, count); + } + } + + public bool ParseBoolean() => this.length >= 1 && this.value[this.offset] != 0; + + public char ParseChar() => this.ParseChar(0); + + public char ParseChar(int start) + { + if (start + 1 >= this.length) + { + throw new IndexOutOfRangeException($"2 bytes required to convert to character. index {start} would go out of bounds."); + } + + int startWithOffset = start + this.offset; + + return (char)((ByteToChar(this.value[startWithOffset]) << 8) + | ByteToChar(this.value[startWithOffset + 1])); + } + + public short ParseShort() => this.ParseShort(0, this.length, 10); + + public short ParseShort(int radix) => this.ParseShort(0, this.length, radix); + + public short ParseShort(int start, int end) => this.ParseShort(start, end, 10); + + public short ParseShort(int start, int end, int radix) + { + int intValue = this.ParseInt(start, end, radix); + short result = (short)intValue; + if (result != intValue) + { + throw new FormatException(this.SubSequence(start, end).ToString()); + } + + return result; + } + + public int ParseInt() => this.ParseInt(0, this.length, 10); + + public int ParseInt(int radix) => this.ParseInt(0, this.length, radix); + + public int ParseInt(int start, int end) => this.ParseInt(start, end, 10); + + public int ParseInt(int start, int end, int radix) + { + if (radix < CharUtil.MinRadix || radix > CharUtil.MaxRadix) + { + throw new FormatException($"Radix must be from {CharUtil.MinRadix} to {CharUtil.MaxRadix}"); + } + if (start == end) + { + throw new FormatException($"Content is empty because {start} and {end} are the same."); + } + + int i = start; + bool negative = this.ByteAt(i) == '-'; + if (negative && ++i == end) + { + throw new FormatException(this.SubSequence(start, end).ToString()); + } + + return this.ParseInt(i, end, radix, negative); + } + + int ParseInt(int start, int end, int radix, bool negative) + { + int max = int.MinValue / radix; + int result = 0; + int currOffset = start; + while (currOffset < end) + { + int digit = CharUtil.Digit((char)(this.value[currOffset++ + this.offset]), radix); + if (digit == -1) + { + throw new FormatException(this.SubSequence(start, end).ToString()); + } + if (max > result) + { + throw new FormatException(this.SubSequence(start, end).ToString()); + } + int next = result * radix - digit; + if (next > result) + { + throw new FormatException(this.SubSequence(start, end).ToString()); + } + result = next; + } + + if (!negative) + { + result = -result; + if (result < 0) + { + throw new FormatException(this.SubSequence(start, end).ToString()); + } + } + + return result; + } + + public long ParseLong() => this.ParseLong(0, this.length, 10); + + public long ParseLong(int radix) => this.ParseLong(0, this.length, radix); + + public long ParseLong(int start, int end) => this.ParseLong(start, end, 10); + + public long ParseLong(int start, int end, int radix) + { + if (radix < CharUtil.MinRadix || radix > CharUtil.MaxRadix) + { + throw new FormatException($"Radix must be from {CharUtil.MinRadix} to {CharUtil.MaxRadix}"); + } + + if (start == end) + { + throw new FormatException($"Content is empty because {start} and {end} are the same."); + } + + int i = start; + bool negative = this.ByteAt(i) == '-'; + if (negative && ++i == end) + { + throw new FormatException(this.SubSequence(start, end).ToString()); + } + + return this.ParseLong(i, end, radix, negative); + } + + long ParseLong(int start, int end, int radix, bool negative) + { + long max = long.MinValue / radix; + long result = 0; + int currOffset = start; + while (currOffset < end) + { + int digit = CharUtil.Digit((char)(this.value[currOffset++ + this.offset]), radix); + if (digit == -1) + { + throw new FormatException(this.SubSequence(start, end).ToString()); + } + if (max > result) + { + throw new FormatException(this.SubSequence(start, end).ToString()); + } + long next = result * radix - digit; + if (next > result) + { + throw new FormatException(this.SubSequence(start, end).ToString()); + } + result = next; + } + + if (!negative) + { + result = -result; + if (result < 0) + { + throw new FormatException(this.SubSequence(start, end).ToString()); + } + } + + return result; + } + + public float ParseFloat() => this.ParseFloat(0, this.length); + + public float ParseFloat(int start, int end) => Convert.ToSingle(this.ToString(start, end)); + + public double ParseDouble() => this.ParseDouble(0, this.length); + + public double ParseDouble(int start, int end) => Convert.ToDouble(this.ToString(start, end)); + + public static AsciiString Of(string value) => new AsciiString(value); + + public static AsciiString Of(ICharSequence charSequence) => charSequence is AsciiString s ? s : new AsciiString(charSequence); + + public static AsciiString Cached(string value) + { + var asciiString = new AsciiString(value); + asciiString.stringValue = value; + return asciiString; + } + + public static int GetHashCode(ICharSequence value) + { + if (value == null) + { + return 0; + } + if (value is AsciiString) + { + return value.GetHashCode(); + } + + return PlatformDependent.HashCodeAscii(value); + } + + public static bool Contains(ICharSequence a, ICharSequence b) => Contains(a, b, DefaultCharComparator); + + public static bool ContainsIgnoreCase(ICharSequence a, ICharSequence b) => Contains(a, b, AsciiCaseInsensitiveCharComparator); + + public static bool ContentEqualsIgnoreCase(ICharSequence a, ICharSequence b) + { + if (a == null || b == null) + { + return ReferenceEquals(a, b); + } + + if (a is AsciiString stringA) + { + return stringA.ContentEqualsIgnoreCase(b); + } + if (b is AsciiString stringB) + { + return stringB.ContentEqualsIgnoreCase(a); + } + + if (a.Count != b.Count) + { + return false; + } + for (int i = 0, j = 0; i < a.Count; ++i, ++j) + { + if (!EqualsIgnoreCase(a[i], b[j])) + { + return false; + } + } + + return true; + } + + public static bool ContainsContentEqualsIgnoreCase(ICollection collection, ICharSequence value) + { + foreach (ICharSequence v in collection) + { + if (ContentEqualsIgnoreCase(value, v)) + { + return true; + } + } + + return false; + } + + public static bool ContainsAllContentEqualsIgnoreCase(ICollection a, ICollection b) + { + foreach (AsciiString v in b) + { + if (!ContainsContentEqualsIgnoreCase(a, v)) + { + return false; + } + } + + return true; + } + + public static bool ContentEquals(ICharSequence a, ICharSequence b) + { + if (a == null || b == null) + { + return ReferenceEquals(a, b); + } + + if (a is AsciiString stringA) + { + return stringA.ContentEquals(b); + } + if (b is AsciiString stringB) + { + return stringB.ContentEquals(a); + } + + if (a.Count != b.Count) + { + return false; + } + + for (int i = 0; i < a.Count; ++i) + { + if (a[i] != b[i]) + { + return false; + } + } + + return true; + } + + static bool Contains(ICharSequence a, ICharSequence b, ICharEqualityComparator comparator) + { + if (a == null || b == null || a.Count < b.Count) + { + return false; + } + if (b.Count == 0) + { + return true; + } + + int bStart = 0; + for (int i = 0; i < a.Count; ++i) + { + if (comparator.CharEquals(b[bStart], a[i])) + { + // If b is consumed then true. + if (++bStart == b.Count) + { + return true; + } + } + else if (a.Count - i < b.Count) + { + // If there are not enough characters left in a for b to be contained, then false. + return false; + } + else + { + bStart = 0; + } + } + + return false; + } + + static bool RegionMatchesCharSequences(ICharSequence cs, int csStart, + ICharSequence seq, int start, int length, ICharEqualityComparator charEqualityComparator) + { + //general purpose implementation for CharSequences + if (csStart < 0 || length > cs.Count - csStart) + { + return false; + } + if (start < 0 || length > seq.Count - start) + { + return false; + } + + int csIndex = csStart; + int csEnd = csIndex + length; + int stringIndex = start; + + while (csIndex < csEnd) + { + char c1 = cs[csIndex++]; + char c2 = seq[stringIndex++]; + + if (!charEqualityComparator.CharEquals(c1, c2)) + { + return false; + } + } + + return true; + } + + public static bool RegionMatches(ICharSequence cs, bool ignoreCase, int csStart, ICharSequence seq, int start, int length) + { + if (cs == null || seq == null) + { + return false; + } + if (cs is StringCharSequence stringCharSequence && seq is StringCharSequence) + { + return ignoreCase + ? stringCharSequence.RegionMatchesIgnoreCase(csStart, seq, start, length) + : stringCharSequence.RegionMatches (csStart, seq, start, length); + } + if (cs is AsciiString asciiString) + { + return ignoreCase + ? asciiString.RegionMatchesIgnoreCase(csStart, seq, start, length) + : asciiString.RegionMatches(csStart, seq, start, length); + } + + return RegionMatchesCharSequences(cs, csStart, seq, start, length, + ignoreCase ? GeneralCaseInsensitiveComparator : DefaultCharComparator); + } + + public static bool RegionMatchesAscii(ICharSequence cs, bool ignoreCase, int csStart, ICharSequence seq, int start, int length) + { + if (cs == null || seq == null) + { + return false; + } + + if (!ignoreCase && cs is StringCharSequence && seq is StringCharSequence) + { + //we don't call regionMatches from String for ignoreCase==true. It's a general purpose method, + //which make complex comparison in case of ignoreCase==true, which is useless for ASCII-only strings. + //To avoid applying this complex ignore-case comparison, we will use regionMatchesCharSequences + return cs.RegionMatches(csStart, seq, start, length); + } + + if (cs is AsciiString asciiString) + { + return ignoreCase + ? asciiString.RegionMatchesIgnoreCase(csStart, seq, start, length) + : asciiString.RegionMatches(csStart, seq, start, length); + } + + return RegionMatchesCharSequences(cs, csStart, seq, start, length, + ignoreCase ? AsciiCaseInsensitiveCharComparator : DefaultCharComparator); + } + + public static int IndexOfIgnoreCase(ICharSequence str, ICharSequence searchStr, int startPos) + { + if (str == null || searchStr == null) + { + return IndexNotFound; + } + + if (startPos < 0) + { + startPos = 0; + } + int searchStrLen = searchStr.Count; + int endLimit = str.Count - searchStrLen + 1; + if (startPos > endLimit) + { + return IndexNotFound; + } + if (searchStrLen == 0) + { + return startPos; + } + for (int i = startPos; i < endLimit; i++) + { + if (RegionMatches(str, true, i, searchStr, 0, searchStrLen)) + { + return i; + } + } + + return IndexNotFound; + } + + public static int IndexOfIgnoreCaseAscii(ICharSequence str, ICharSequence searchStr, int startPos) + { + if (str == null || searchStr == null) + { + return IndexNotFound; + } + + if (startPos < 0) + { + startPos = 0; + } + int searchStrLen = searchStr.Count; + int endLimit = str.Count - searchStrLen + 1; + if (startPos > endLimit) + { + return IndexNotFound; + } + if (searchStrLen == 0) + { + return startPos; + } + for (int i = startPos; i < endLimit; i++) + { + if (RegionMatchesAscii(str, true, i, searchStr, 0, searchStrLen)) + { + return i; + } + } + + return IndexNotFound; + } + + public static int IndexOf(ICharSequence cs, char searchChar, int start) + { + if (cs is StringCharSequence stringCharSequence) + { + return stringCharSequence.IndexOf(searchChar, start); + } + else if (cs is AsciiString asciiString) + { + return asciiString.IndexOf(searchChar, start); + } + if (cs == null) + { + return IndexNotFound; + } + int sz = cs.Count; + if (start < 0) + { + start = 0; + } + for (int i = start; i < sz; i++) + { + if (cs[i] == searchChar) + { + return i; + } + } + + return IndexNotFound; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static bool EqualsIgnoreCase(byte a, byte b) => a == b || ToLowerCase(a) == ToLowerCase(b); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static bool EqualsIgnoreCase(char a, char b) => a == b || ToLowerCase(a) == ToLowerCase(b); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static byte ToLowerCase(byte b) => IsUpperCase(b) ? (byte)(b + 32) : b; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static char ToLowerCase(char c) => IsUpperCase(c) ? (char)(c + 32) : c; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static byte ToUpperCase(byte b) => IsLowerCase(b) ? (byte)(b - 32) : b; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static bool IsLowerCase(byte value) => value >= 'a' && value <= 'z'; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsUpperCase(byte value) => value >= 'A' && value <= 'Z'; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsUpperCase(char value) => value >= 'A' && value <= 'Z'; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static byte CharToByte(char c) => c > MaxCharValue ? Replacement : unchecked((byte)c); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static char ByteToChar(byte b) => (char)(b); + + public static explicit operator string(AsciiString value) => value?.ToString() ?? string.Empty; + + public static explicit operator AsciiString(string value) => value != null ? new AsciiString(value) : Empty; + + static unsafe void GetBytes(char* chars, int length, byte* bytes) + { + char* charEnd = chars + length; + while (chars < charEnd) + { + char ch = *(chars++); + // ByteToChar + if (ch > MaxCharValue) + { + *(bytes++) = Replacement; + } + else + { + *(bytes++) = unchecked((byte)ch); + } + } + } + + public int HashCode(bool ignoreCase) => !ignoreCase ? this.GetHashCode() : CaseInsensitiveHasher.GetHashCode(this); + + // + // Compares the specified string to this string using the ASCII values of the characters. Returns 0 if the strings + // contain the same characters in the same order. Returns a negative integer if the first non-equal character in + // this string has an ASCII value which is less than the ASCII value of the character at the same position in the + // specified string, or if this string is a prefix of the specified string. Returns a positive integer if the first + // non-equal character in this string has a ASCII value which is greater than the ASCII value of the character at + // the same position in the specified string, or if the specified string is a prefix of this string. + // + public int CompareTo(AsciiString other) + { + if (ReferenceEquals(this, other)) + { + return 0; + } + + int length1 = this.length; + int length2 = other.length; + int minLength = Math.Min(length1, length2); + for (int i = 0, j = this.offset; i < minLength; i++, j++) + { + int result = ByteToChar(this.value[j]) - other[i]; + if (result != 0) + { + return result; + } + } + + return length1 - length2; + } + + public int CompareTo(object obj) => this.CompareTo(obj as AsciiString); + + public IEnumerator GetEnumerator() => new CharSequenceEnumerator(this); + + IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator(); + + static void ThrowIndexOutOfRangeException_Start(int start, int length, int count) + { + throw GetIndexOutOfRangeException(); + + IndexOutOfRangeException GetIndexOutOfRangeException() + { + return new IndexOutOfRangeException(string.Format("expected: 0 <= start({0}) <= start + length({1}) <= value.length({2})", start, length, count)); + } + } + + static void ThrowIndexOutOfRangeException_StartEnd(int start, int end, int length) + { + throw GetIndexOutOfRangeException(); + + IndexOutOfRangeException GetIndexOutOfRangeException() + { + return new IndexOutOfRangeException(string.Format("expected: 0 <= start({0}) <= end ({1}) <= length({2})", start, end, length)); + } + } + + static void ThrowIndexOutOfRangeException_SrcIndex(int start, int count, int length) + { + throw GetIndexOutOfRangeException(); + + IndexOutOfRangeException GetIndexOutOfRangeException() + { + return new IndexOutOfRangeException(string.Format("expected: 0 <= start({0}) <= srcIdx + length({1}) <= srcLen({2})", start, count, length)); + } + } + + static void ThrowIndexOutOfRangeException_Index(int index, int length, int count) + { + throw GetIndexOutOfRangeException(); + + IndexOutOfRangeException GetIndexOutOfRangeException() + { + return new IndexOutOfRangeException(string.Format("expected: 0 <= index({0} <= start + length({1}) <= length({2})", index, length, count)); + } + } + + static void ThrowIndexOutOfRangeException_Index(int index, int length) + { + throw GetIndexOutOfRangeException(); + + IndexOutOfRangeException GetIndexOutOfRangeException() + { + return new IndexOutOfRangeException(string.Format("index: {0} must be in the range [0,{1})", index, length)); + } + } + + interface ICharEqualityComparator + { + bool CharEquals(char a, char b); + } + + sealed class DefaultCharEqualityComparator : ICharEqualityComparator + { + public bool CharEquals(char a, char b) => a == b; + } + + sealed class GeneralCaseInsensitiveCharEqualityComparator : ICharEqualityComparator + { + public bool CharEquals(char a, char b) => + char.ToUpper(a) == char.ToUpper(b) || char.ToLower(a) == char.ToLower(b); + } + + sealed class AsciiCaseInsensitiveCharEqualityComparator : ICharEqualityComparator + { + public bool CharEquals(char a, char b) => EqualsIgnoreCase(a, b); + } + } +} diff --git a/src/DotNetty.Common/Utilities/ByteProcessor.cs b/src/DotNetty.Common/Utilities/ByteProcessor.cs index b13abfe..6422222 100644 --- a/src/DotNetty.Common/Utilities/ByteProcessor.cs +++ b/src/DotNetty.Common/Utilities/ByteProcessor.cs @@ -6,6 +6,8 @@ namespace DotNetty.Common.Utilities using System; using System.Diagnostics.Contracts; + using static ByteProcessorUtils; + /// /// Provides a mechanism to iterate over a collection of bytes. /// @@ -64,46 +66,56 @@ namespace DotNetty.Common.Utilities /// /// Aborts on a {@code CR ('\r')}. /// - public static IByteProcessor FindCR = new IndexOfProcessor((byte)'\r'); + public static IByteProcessor FindCR = new IndexOfProcessor(CarriageReturn); /// /// Aborts on a non-{@code CR ('\r')}. /// - public static IByteProcessor FindNonCR = new IndexNotOfProcessor((byte)'\r'); + public static IByteProcessor FindNonCR = new IndexNotOfProcessor(CarriageReturn); /// /// Aborts on a {@code LF ('\n')}. /// - public static IByteProcessor FindLF = new IndexOfProcessor((byte)'\n'); + public static IByteProcessor FindLF = new IndexOfProcessor(LineFeed); /// /// Aborts on a non-{@code LF ('\n')}. /// - public static IByteProcessor FindNonLF = new IndexNotOfProcessor((byte)'\n'); + public static IByteProcessor FindNonLF = new IndexNotOfProcessor(LineFeed); /// /// Aborts on a {@code CR (';')}. /// - public static IByteProcessor FindSemiCOLON = new IndexOfProcessor((byte)';'); + public static IByteProcessor FindSemicolon = new IndexOfProcessor((byte)';'); + + /// + /// Aborts on a comma {@code (',')}. + /// + public static IByteProcessor FindComma = new IndexOfProcessor((byte)','); + + /// + /// Aborts on a ascii space character ({@code ' '}). + /// + public static IByteProcessor FindAsciiSpace = new IndexOfProcessor(Space); /// /// Aborts on a {@code CR ('\r')} or a {@code LF ('\n')}. /// - public static IByteProcessor FindCrlf = new ByteProcessor(new Func(value => value != '\r' && value != '\n')); + public static IByteProcessor FindCrlf = new ByteProcessor(new Func(value => value != CarriageReturn && value != LineFeed)); /// /// Aborts on a byte which is neither a {@code CR ('\r')} nor a {@code LF ('\n')}. /// - public static IByteProcessor FindNonCrlf = new ByteProcessor(new Func(value => value == '\r' || value == '\n')); + public static IByteProcessor FindNonCrlf = new ByteProcessor(new Func(value => value == CarriageReturn || value == LineFeed)); /// /// Aborts on a linear whitespace (a ({@code ' '} or a {@code '\t'}). /// - public static IByteProcessor FindLinearWhitespace = new ByteProcessor(new Func(value => value != ' ' && value != '\t')); + public static IByteProcessor FindLinearWhitespace = new ByteProcessor(new Func(value => value != Space && value != HTab)); /// /// Aborts on a byte which is not a linear whitespace (neither {@code ' '} nor {@code '\t'}). /// - public static IByteProcessor FindNonLinearWhitespace = new ByteProcessor(new Func(value => value == ' ' || value == '\t')); + public static IByteProcessor FindNonLinearWhitespace = new ByteProcessor(new Func(value => value == Space || value == HTab)); } } \ No newline at end of file diff --git a/src/DotNetty.Common/Utilities/ByteProcessorUtils.cs b/src/DotNetty.Common/Utilities/ByteProcessorUtils.cs new file mode 100644 index 0000000..347ece1 --- /dev/null +++ b/src/DotNetty.Common/Utilities/ByteProcessorUtils.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Utilities +{ + static class ByteProcessorUtils + { + internal static readonly byte Space = (byte)' '; + internal static readonly byte HTab = (byte)'\t'; + internal static readonly byte CarriageReturn = (byte)'\r'; + internal static readonly byte LineFeed = (byte)'\n'; + } +} diff --git a/src/DotNetty.Common/Utilities/CharSequenceEnumerator.cs b/src/DotNetty.Common/Utilities/CharSequenceEnumerator.cs new file mode 100644 index 0000000..fa9862b --- /dev/null +++ b/src/DotNetty.Common/Utilities/CharSequenceEnumerator.cs @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Utilities +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + + struct CharSequenceEnumerator : IEnumerator + { + ICharSequence charSequence; + int index; + char currentElement; + + internal CharSequenceEnumerator(ICharSequence charSequence) + { + Contract.Requires(charSequence != null); + + this.charSequence = charSequence; + this.index = -1; + this.currentElement = (char)0; + } + + public bool MoveNext() + { + if (this.index < this.charSequence.Count - 1) + { + this.index++; + this.currentElement = this.charSequence[this.index]; + return true; + } + + this.index = this.charSequence.Count; + return false; + } + + object IEnumerator.Current + { + get + { + if (this.index == -1) + { + throw new InvalidOperationException("Enumerator not initialized."); + } + if (this.index >= this.charSequence.Count) + { + throw new InvalidOperationException("Eumerator already completed."); + } + return this.currentElement; + } + } + + public char Current + { + get + { + if (this.index == -1) + { + throw new InvalidOperationException("Enumerator not initialized."); + } + if (this.index >= this.charSequence.Count) + { + throw new InvalidOperationException("Eumerator already completed."); + } + return this.currentElement; + } + } + + public void Reset() + { + this.index = -1; + this.currentElement = (char)0; + } + + public void Dispose() + { + if (this.charSequence != null) + { + this.index = this.charSequence.Count; + } + this.charSequence = null; + } + } +} diff --git a/src/DotNetty.Common/Utilities/CharUtil.cs b/src/DotNetty.Common/Utilities/CharUtil.cs index 2598c02..087575c 100644 --- a/src/DotNetty.Common/Utilities/CharUtil.cs +++ b/src/DotNetty.Common/Utilities/CharUtil.cs @@ -3,10 +3,589 @@ namespace DotNetty.Common.Utilities { + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; using System.Runtime.CompilerServices; public static class CharUtil { + public static readonly string Digits = "0123456789ABCDEF"; + + public static readonly int MinRadix = 2; + public static readonly int MaxRadix = 36; + + const string DigitKeys = "0Aa\u0660\u06f0\u0966\u09e6\u0a66\u0ae6\u0b66\u0be7\u0c66\u0ce6\u0d66\u0e50\u0ed0\u0f20\u1040\u1369\u17e0\u1810\uff10\uff21\uff41"; + static readonly char[] DigitValues = "90Z7zW\u0669\u0660\u06f9\u06f0\u096f\u0966\u09ef\u09e6\u0a6f\u0a66\u0aef\u0ae6\u0b6f\u0b66\u0bef\u0be6\u0c6f\u0c66\u0cef\u0ce6\u0d6f\u0d66\u0e59\u0e50\u0ed9\u0ed0\u0f29\u0f20\u1049\u1040\u1371\u1368\u17e9\u17e0\u1819\u1810\uff19\uff10\uff3a\uff17\uff5a\uff37".ToCharArray(); + + public static int BinarySearchRange(string data, char c) + { + char value = '\u0000'; + int low = 0, mid = -1, high = data.Length - 1; + while (low <= high) + { + mid = (low + high) >> 1; + value = data[mid]; + if (c > value) + low = mid + 1; + else if (c == value) + return mid; + else + high = mid - 1; + } + + return mid - (c < value ? 1 : 0); + } + + public static int ParseInt(ICharSequence seq, int start, int end, int radix) + { + Contract.Requires(seq != null); + Contract.Requires(radix >= MinRadix && radix <= MaxRadix); + + if (start == end) + { + throw new FormatException(); + } + + int i = start; + bool negative = seq[i] == '-'; + if (negative && ++i == end) + { + throw new FormatException(seq.SubSequence(start, end).ToString()); + } + + return ParseInt(seq, i, end, radix, negative); + } + + public static int ParseInt(ICharSequence seq) => ParseInt(seq, 0, seq.Count, 10, false); + + public static int ParseInt(ICharSequence seq, int start, int end, int radix, bool negative) + { + Contract.Requires(seq != null); + Contract.Requires(radix >= MinRadix && radix <= MaxRadix); + + int max = int.MinValue / radix; + int result = 0; + int currOffset = start; + while (currOffset < end) + { + int digit = Digit((char)(seq[currOffset++] & 0xFF), radix); + if (digit == -1) + { + throw new FormatException(seq.SubSequence(start, end).ToString()); + } + if (max > result) + { + throw new FormatException(seq.SubSequence(start, end).ToString()); + } + int next = result * radix - digit; + if (next > result) + { + throw new FormatException(seq.SubSequence(start, end).ToString()); + } + result = next; + } + + if (!negative) + { + result = -result; + if (result < 0) + { + throw new FormatException(seq.SubSequence(start, end).ToString()); + } + } + + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long ParseLong(ICharSequence str, int radix = 10) + { + if (str is AsciiString asciiString) + { + return asciiString.ParseLong(radix); + } + + if (str == null + || radix < MinRadix + || radix > MaxRadix) + { + ThrowFormatException(str); + } + + // ReSharper disable once PossibleNullReferenceException + int length = str.Count; + int i = 0; + if (length == 0) + { + ThrowFormatException(str); + } + bool negative = str[i] == '-'; + if (negative && ++i == length) + { + ThrowFormatException(str); + } + + return ParseLong(str, i, radix, negative); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static long ParseLong(ICharSequence str, int offset, int radix, bool negative) + { + long max = long.MinValue / radix; + long result = 0, length = str.Count; + while (offset < length) + { + int digit = Digit(str[offset++], radix); + if (digit == -1) + { + ThrowFormatException(str); + } + if (max > result) + { + ThrowFormatException(str); + } + long next = result * radix - digit; + if (next > result) + { + ThrowFormatException(str); + } + result = next; + } + + if (!negative) + { + result = -result; + if (result < 0) + { + ThrowFormatException(str); + } + } + + return result; + } + + static void ThrowFormatException(ICharSequence str) => throw new FormatException(str.ToString()); + + public static bool IsNullOrEmpty(ICharSequence sequence) => sequence == null || sequence.Count == 0; + + public static ICharSequence[] Split(ICharSequence sequence, params char[] delimiters) => Split(sequence, 0, delimiters); + + public static ICharSequence[] Split(ICharSequence sequence, int startIndex, params char[] delimiters) + { + Contract.Requires(sequence != null); + Contract.Requires(delimiters != null); + Contract.Requires(startIndex >= 0 && startIndex < sequence.Count); + + List result = InternalThreadLocalMap.Get().CharSequenceList(); + + int i = startIndex; + int length = sequence.Count; + + while (i < length) + { + while (i < length && IndexOf(delimiters, sequence[i]) >= 0) + { + i++; + } + + int position = i; + if (i < length) + { + if (IndexOf(delimiters, sequence[position]) >= 0) + { + result.Add(sequence.SubSequence(position++, i + 1)); + } + else + { + ICharSequence seq = null; + for (position++; position < length; position++) + { + if (IndexOf(delimiters, sequence[position]) >= 0) + { + seq = sequence.SubSequence(i, position); + break; + } + } + result.Add(seq ?? sequence.SubSequence(i)); + } + i = position; + } + } + + return result.Count == 0 ? new[] { sequence } : result.ToArray(); + } + + internal static bool ContentEquals(ICharSequence left, ICharSequence right) + { + if (left == null || right == null) + { + return ReferenceEquals(left, right); + } + + if (ReferenceEquals(left, right)) + { + return true; + } + if (left.Count != right.Count) + { + return false; + } + + for (int i = 0; i < left.Count; i++) + { + char c1 = left[i]; + char c2 = right[i]; + if (c1 != c2 + && char.ToUpper(c1).CompareTo(char.ToUpper(c2)) != 0 + && char.ToLower(c1).CompareTo(char.ToLower(c2)) != 0) + { + return false; + } + } + + return true; + } + + internal static bool ContentEqualsIgnoreCase(ICharSequence left, ICharSequence right) + { + if (left == null || right == null) + { + return ReferenceEquals(left, right); + } + + if (ReferenceEquals(left, right)) + { + return true; + } + if (left.Count != right.Count) + { + return false; + } + + for (int i = 0; i < left.Count; i++) + { + char c1 = left[i]; + char c2 = right[i]; + if (char.ToLower(c1).CompareTo(char.ToLower(c2)) != 0) + { + return false; + } + } + + return true; + } + + public static bool RegionMatches(string value, int thisStart, ICharSequence other, int start, int length) + { + Contract.Requires(value != null && other != null); + + if (start < 0 + || other.Count - start < length) + { + return false; + } + + if (thisStart < 0 + || value.Length - thisStart < length) + { + return false; + } + + if (length <= 0) + { + return true; + } + + int o1 = thisStart; + int o2 = start; + for (int i = 0; i < length; ++i) + { + if (value[o1 + i] != other[o2 + i]) + { + return false; + } + } + + return true; + } + + public static bool RegionMatchesIgnoreCase(string value, int thisStart, ICharSequence other, int start, int length) + { + Contract.Requires(value != null && other != null); + + if (thisStart < 0 + || length > value.Length - thisStart) + { + return false; + } + + if (start < 0 || length > other.Count - start) + { + return false; + } + + int end = thisStart + length; + while (thisStart < end) + { + char c1 = value[thisStart++]; + char c2 = other[start++]; + if (c1 != c2 + && char.ToUpper(c1).CompareTo(char.ToUpper(c2)) != 0 + && char.ToLower(c1).CompareTo(char.ToLower(c2)) != 0) + { + return false; + } + } + + return true; + } + + public static bool RegionMatches(IReadOnlyList value, int thisStart, ICharSequence other, int start, int length) + { + Contract.Requires(value != null && other != null); + + if (start < 0 || other.Count - start < length) + { + return false; + } + + if (thisStart < 0 || value.Count - thisStart < length) + { + return false; + } + + if (length <= 0) + { + return true; + } + + int o1 = thisStart; + int o2 = start; + for (int i = 0; i < length; ++i) + { + if (value[o1 + i] != other[o2 + i]) + { + return false; + } + } + + return true; + } + + public static bool RegionMatchesIgnoreCase(IReadOnlyList value, int thisStart, ICharSequence other, int start, int length) + { + Contract.Requires(value != null && other != null); + + if (thisStart < 0 || length > value.Count - thisStart) + { + return false; + } + + if (start < 0 || length > other.Count - start) + { + return false; + } + + int end = thisStart + length; + while (thisStart < end) + { + char c1 = value[thisStart++]; + char c2 = other[start++]; + if (c1 != c2 + && char.ToUpper(c1).CompareTo(char.ToUpper(c2)) != 0 + && char.ToLower(c1).CompareTo(char.ToLower(c2)) != 0) + { + return false; + } + } + + return true; + } + + public static ICharSequence SubstringAfter(this ICharSequence value, char delim) + { + int pos = value.IndexOf(delim); + return pos >= 0 ? value.SubSequence(pos + 1, value.Count) : null; + } + + public static ICharSequence Trim(ICharSequence sequence) + { + Contract.Requires(sequence != null); + + int length = sequence.Count; + int start = IndexOfFirstNonWhiteSpace(sequence); + if (start == length) + { + return StringCharSequence.Empty; + } + + int last = IndexOfLastNonWhiteSpaceChar(sequence, start); + + length = last - start + 1; + return length == sequence.Count + ? sequence + : sequence.SubSequence(start, last + 1); + } + + static int IndexOfFirstNonWhiteSpace(IReadOnlyList value) + { + Contract.Requires(value != null); + + int i = 0; + while (i < value.Count && char.IsWhiteSpace(value[i])) + { + i++; + } + + return i; + } + + static int IndexOfLastNonWhiteSpaceChar(IReadOnlyList value, int start) + { + int i = value.Count - 1; + while (i > start && char.IsWhiteSpace(value[i])) + { + i--; + } + + return i; + } + + public static bool Contains(IReadOnlyList value, char c) + { + if (value != null) + { + int length = value.Count; + for (int i = 0; i < length; i++) + { + if (value[i] == c) + { + return true; + } + } + } + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Digit(byte b) + { + const byte First = (byte)'0'; + const byte Last = (byte)'9'; + + if (b < First || b > Last) + { + return -1; + } + + return b - First; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Digit(char c, int radix) + { + if (radix >= MinRadix && radix <= MaxRadix) + { + if (c < 128) + { + int result = -1; + if ('0' <= c && c <= '9') + { + result = c - '0'; + } + else if ('a' <= c && c <= 'z') + { + result = c - ('a' - 10); + } + else if ('A' <= c && c <= 'Z') + { + result = c - ('A' - 10); + } + + return result < radix ? result : -1; + } + + int result1 = BinarySearchRange(DigitKeys, c); + if (result1 >= 0 && c <= DigitValues[result1 * 2]) + { + int value = (char)(c - DigitValues[result1 * 2 + 1]); + if (value >= radix) + { + return -1; + } + return value; + } + } + + return -1; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsISOControl(int c) => (c >= 0 && c <= 0x1f) || (c >= 0x7f && c <= 0x9f); + + public static int IndexOf(this ICharSequence cs, char searchChar, int start) + { + if (cs == null) + { + return AsciiString.IndexNotFound; + } + + if (cs is StringCharSequence sequence) + { + return sequence.IndexOf(searchChar, start); + } + + if (cs is AsciiString s) + { + return s.IndexOf(searchChar, start); + } + + int sz = cs.Count; + if (start < 0) + { + start = 0; + } + for (int i = start; i < sz; i++) + { + if (cs[i] == searchChar) + { + return i; + } + } + + return -1; + } + + static int IndexOf(char[] tokens, char value) + { + for (int i = 0; i < tokens.Length; i++) + { + if (tokens[i] == value) + { + return i; + } + } + + return -1; + } + + public static int CodePointAt(IReadOnlyList seq, int index) + { + Contract.Requires(seq != null); + Contract.Requires(index >= 0 && index < seq.Count); + + char high = seq[index++]; + if (index >= seq.Count) + { + return high; + } + + char low = seq[index]; + + return IsSurrogatePair(high, low) ? ToCodePoint(high, low) : high; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int ToCodePoint(char high, char low) { @@ -16,5 +595,29 @@ namespace DotNetty.Common.Utilities int l = low & 0x3FF; return (h | l) + 0x10000; } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static bool IsSurrogatePair(char high, char low) => char.IsHighSurrogate(high) && char.IsLowSurrogate(low); + + internal static int IndexOf(IReadOnlyList value, char ch, int start) + { + char upper = char.ToUpper(ch); + char lower = char.ToLower(ch); + int i = start; + while (i < value.Count) + { + char c1 = value[i]; + if (c1 == ch + && char.ToUpper(c1).CompareTo(upper) != 0 + && char.ToLower(c1).CompareTo(lower) != 0) + { + return i; + } + + i++; + } + + return -1; + } } } diff --git a/src/DotNetty.Common/Utilities/ICharSequence.cs b/src/DotNetty.Common/Utilities/ICharSequence.cs new file mode 100644 index 0000000..40ced38 --- /dev/null +++ b/src/DotNetty.Common/Utilities/ICharSequence.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Utilities +{ + using System.Collections.Generic; + + public interface ICharSequence : IReadOnlyList + { + /// Start is the inclusive start index to begin the subsequence. + /// End is the exclusive end index to end the subsequence. + ICharSequence SubSequence(int start, int end); + + ICharSequence SubSequence(int start); + + int IndexOf(char ch, int start = 0); + + bool RegionMatches(int thisStart, ICharSequence seq, int start, int length); + + bool RegionMatchesIgnoreCase(int thisStart, ICharSequence seq, int start, int length); + + bool ContentEquals(ICharSequence other); + + bool ContentEqualsIgnoreCase(ICharSequence other); + + int HashCode(bool ignoreCase); + + string ToString(int start); + + string ToString(); + } +} diff --git a/src/DotNetty.Common/Utilities/IHashingStrategy.cs b/src/DotNetty.Common/Utilities/IHashingStrategy.cs new file mode 100644 index 0000000..fc70295 --- /dev/null +++ b/src/DotNetty.Common/Utilities/IHashingStrategy.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Utilities +{ + using System.Collections.Generic; + + public interface IHashingStrategy : IEqualityComparer + { + int HashCode(T obj); + } + + public sealed class DefaultHashingStrategy : IHashingStrategy + { + public int GetHashCode(T obj) => obj.GetHashCode(); + + public int HashCode(T obj) => obj != null ? this.GetHashCode(obj) : 0; + + public bool Equals(T a, T b) => ReferenceEquals(a, b) || (!ReferenceEquals(a, null) && a.Equals(b)); + } +} diff --git a/src/DotNetty.Common/Utilities/Signal.cs b/src/DotNetty.Common/Utilities/Signal.cs new file mode 100644 index 0000000..72124bf --- /dev/null +++ b/src/DotNetty.Common/Utilities/Signal.cs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Utilities +{ + using System; + + public sealed class Signal : Exception, IConstant, IComparable, IComparable + { + static readonly SignalConstantPool Pool = new SignalConstantPool(); + + sealed class SignalConstantPool : ConstantPool + { + protected override IConstant NewConstant(int id, string name) => new Signal(id, name); + }; + + public static Signal ValueOf(string name) => (Signal)Pool.ValueOf(name); + + public static Signal ValueOf(Type firstNameComponent, string secondNameComponent) => (Signal)Pool.ValueOf(firstNameComponent, secondNameComponent); + + readonly SignalConstant constant; + + Signal(int id, string name) + { + this.constant = new SignalConstant(id, name); + } + + public void Expect(Signal signal) + { + if (!ReferenceEquals(this, signal)) + { + throw new InvalidOperationException($"unexpected signal: {signal}"); + } + } + + public int Id => this.constant.Id; + + public string Name => this.constant.Name; + + public override bool Equals(object obj) => ReferenceEquals(this, obj); + + public override int GetHashCode() => this.Id; + + public int CompareTo(object obj) + { + if (ReferenceEquals(this, obj)) + { + return 0; + } + if (!ReferenceEquals(obj, null) && obj is Signal) + { + return this.CompareTo((Signal)obj); + } + + throw new Exception("failed to compare two different signal constants"); + } + + public int CompareTo(Signal other) + { + if (ReferenceEquals(this, other)) + { + return 0; + } + + return this.constant.CompareTo(other.constant); + } + + public override string ToString() => this.Name; + + sealed class SignalConstant : AbstractConstant + { + public SignalConstant(int id, string name) : base(id, name) + { + } + } + } +} diff --git a/src/DotNetty.Common/Utilities/StringBuilderCharSequence.cs b/src/DotNetty.Common/Utilities/StringBuilderCharSequence.cs new file mode 100644 index 0000000..54d880c --- /dev/null +++ b/src/DotNetty.Common/Utilities/StringBuilderCharSequence.cs @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Utilities +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Text; + + public sealed class StringBuilderCharSequence : ICharSequence, IEquatable + { + readonly StringBuilder builder; + readonly int offset; + + public StringBuilderCharSequence(int capacity = 0) + { + Contract.Requires(capacity >= 0); + + this.builder = new StringBuilder(capacity); + this.offset = 0; + this.Count = 0; + } + + public StringBuilderCharSequence(StringBuilder builder) : this(builder, 0, builder.Length) + { + } + + public StringBuilderCharSequence(StringBuilder builder, int offset, int count) + { + Contract.Requires(builder != null); + Contract.Requires(offset >= 0 && count >= 0); + Contract.Requires(offset <= builder.Length - count); + + this.builder = builder; + this.offset = offset; + this.Count = count; + } + + public ICharSequence SubSequence(int start) => this.SubSequence(start, this.Count); + + public ICharSequence SubSequence(int start, int end) + { + Contract.Requires(start >= 0 && end >= start); + Contract.Requires(end <= this.Count); + + return end == start + ? new StringBuilderCharSequence() + : new StringBuilderCharSequence(this.builder, this.offset + start, end - start); + } + + public int Count { get; private set; } + + public char this[int index] + { + get + { + Contract.Requires(index >= 0 && index < this.Count); + return this.builder[this.offset + index]; + } + } + + public void Append(string value) + { + this.builder.Append(value); + this.Count += value.Length; + } + + public void Append(string value, int index, int count) + { + this.builder.Append(value, index, count); + this.Count += count; + } + + public void Append(ICharSequence value) + { + if (value == null || value.Count == 0) + { + return; + } + + this.builder.Append(value); + this.Count += value.Count; + } + + public void Append(ICharSequence value, int index, int count) + { + if (value == null || count == 0) + { + return; + } + + this.Append(value.SubSequence(index, index + count)); + } + + public void Append(char value) + { + this.builder.Append(value); + this.Count++; + } + + public void Insert(int start, char value) + { + Contract.Requires(start >= 0 && start < this.Count); + + this.builder.Insert(this.offset + start, value); + this.Count++; + } + + public bool RegionMatches(int thisStart, ICharSequence seq, int start, int length) => + CharUtil.RegionMatches(this, this.offset + thisStart, seq, start, length); + + public bool RegionMatchesIgnoreCase(int thisStart, ICharSequence seq, int start, int length) => + CharUtil.RegionMatchesIgnoreCase(this, this.offset + thisStart, seq, start, length); + + public int IndexOf(char ch, int start = 0) => CharUtil.IndexOf(this, ch, start); + + public string ToString(int start) + { + Contract.Requires(start >= 0 && start < this.Count); + + return this.builder.ToString(this.offset + start, this.Count); + } + + public override string ToString() => this.Count == 0 ? string.Empty : this.ToString(0); + + public bool Equals(StringBuilderCharSequence other) + { + if (other == null) + { + return false; + } + if (ReferenceEquals(this, other)) + { + return true; + } + if (this.Count != other.Count) + { + return false; + } + + return this.builder.ToString(this.offset, this.Count) + .Equals(other.builder.ToString(other.offset, this.Count)); + } + + public override bool Equals(object obj) + { + if (obj == null) + { + return false; + } + if (ReferenceEquals(this, obj)) + { + return true; + } + + if (obj is StringBuilderCharSequence other) + { + return this.Equals(other); + } + if (obj is ICharSequence seq) + { + return this.ContentEquals(seq); + } + + return false; + } + + public int HashCode(bool ignoreCase) => ignoreCase + ? StringComparer.OrdinalIgnoreCase.GetHashCode(this.ToString()) + : StringComparer.Ordinal.GetHashCode(this.ToString()); + + public override int GetHashCode() => this.HashCode(true); + + public bool ContentEquals(ICharSequence other) => CharUtil.ContentEquals(this, other); + + public bool ContentEqualsIgnoreCase(ICharSequence other) => CharUtil.ContentEqualsIgnoreCase(this, other); + + public IEnumerator GetEnumerator() => new CharSequenceEnumerator(this); + + IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator(); + } +} diff --git a/src/DotNetty.Common/Utilities/StringCharSequence.cs b/src/DotNetty.Common/Utilities/StringCharSequence.cs new file mode 100644 index 0000000..d822d1b --- /dev/null +++ b/src/DotNetty.Common/Utilities/StringCharSequence.cs @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +// ReSharper disable ConvertToAutoProperty +// ReSharper disable ConvertToAutoPropertyWhenPossible +namespace DotNetty.Common.Utilities +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + + public sealed class StringCharSequence : ICharSequence, IEquatable + { + public static readonly StringCharSequence Empty = new StringCharSequence(string.Empty); + + readonly string value; + readonly int offset; + readonly int count; + + public StringCharSequence(string value) + { + Contract.Requires(value != null); + + this.value = value; + this.offset = 0; + this.count = this.value.Length; + } + + public StringCharSequence(string value, int offset, int count) + { + Contract.Requires(value != null); + Contract.Requires(offset >= 0 && count >= 0); + Contract.Requires(offset <= value.Length - count); + + this.value = value; + this.offset = offset; + this.count = count; + } + + public int Count => this.count; + + public static explicit operator string(StringCharSequence charSequence) + { + Contract.Requires(charSequence != null); + return charSequence.ToString(); + } + + public static explicit operator StringCharSequence(string value) + { + Contract.Requires(value != null); + + return value.Length > 0 ? new StringCharSequence(value) : Empty; + } + + public ICharSequence SubSequence(int start) => this.SubSequence(start, this.count); + + public ICharSequence SubSequence(int start, int end) + { + Contract.Requires(start >= 0 && end >= start); + Contract.Requires(end <= this.count); + + return end == start + ? Empty + : new StringCharSequence(this.value, this.offset + start, end - start); + } + + public char this[int index] + { + get + { + Contract.Requires(index >= 0 && index < this.count); + return this.value[this.offset + index]; + } + } + + public bool RegionMatches(int thisStart, ICharSequence seq, int start, int length) => + CharUtil.RegionMatches(this, thisStart, seq, start, length); + + public bool RegionMatchesIgnoreCase(int thisStart, ICharSequence seq, int start, int length) => + CharUtil.RegionMatchesIgnoreCase(this, thisStart, seq, start, length); + + public int IndexOf(char ch, int start = 0) + { + Contract.Requires(start >= 0 && start < this.count); + + int index = this.value.IndexOf(ch, this.offset + start); + return index < 0 ? index : index - this.offset; + } + + public int IndexOf(string target, int start = 0) => this.value.IndexOf(target, StringComparison.Ordinal); + + public string ToString(int start) + { + Contract.Requires(start >= 0 && start < this.count); + + return this.value.Substring(this.offset + start, this.count); + } + + public override string ToString() => this.count == 0 ? string.Empty : this.ToString(0); + + public bool Equals(StringCharSequence other) + { + if (other == null) + { + return false; + } + if (ReferenceEquals(this, other)) + { + return true; + } + if (this.count != other.count) + { + return false; + } + + return string.Compare(this.value, this.offset, other.value, other.offset, this.count, + StringComparison.Ordinal) == 0; + } + + public override bool Equals(object obj) + { + if (ReferenceEquals(obj, null)) + { + return false; + } + if (ReferenceEquals(this, obj)) + { + return true; + } + + if (obj is StringCharSequence other) + { + return this.Equals(other); + } + if (obj is ICharSequence seq) + { + return this.ContentEquals(seq); + } + + return false; + } + + public int HashCode(bool ignoreCase) => ignoreCase + ? StringComparer.OrdinalIgnoreCase.GetHashCode(this.ToString()) + : StringComparer.Ordinal.GetHashCode(this.ToString()); + + public override int GetHashCode() => this.HashCode(false); + + public bool ContentEquals(ICharSequence other) => CharUtil.ContentEquals(this, other); + + public bool ContentEqualsIgnoreCase(ICharSequence other) => CharUtil.ContentEqualsIgnoreCase(this, other); + + public IEnumerator GetEnumerator() => new CharSequenceEnumerator(this); + + IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator(); + } +} diff --git a/src/DotNetty.Common/Utilities/StringUtil.cs b/src/DotNetty.Common/Utilities/StringUtil.cs index 3b6f1f5..bb182ca 100644 --- a/src/DotNetty.Common/Utilities/StringUtil.cs +++ b/src/DotNetty.Common/Utilities/StringUtil.cs @@ -5,35 +5,36 @@ namespace DotNetty.Common.Utilities { using System; using System.Collections.Generic; - using System.Diagnostics; using System.Diagnostics.Contracts; + using System.Runtime.CompilerServices; using System.Text; + using DotNetty.Common.Internal; /// /// String utility class. /// public static class StringUtil { + public static readonly string EmptyString = ""; + public static readonly string Newline = SystemPropertyUtil.Get("line.separator", Environment.NewLine); + public const char DoubleQuote = '\"'; public const char Comma = ','; public const char LineFeed = '\n'; public const char CarriageReturn = '\r'; public const char Tab = '\t'; + public static readonly char Space = '\x20'; public const byte UpperCaseToLowerCaseAsciiOffset = 'a' - 'A'; - public static readonly string Newline; static readonly string[] Byte2HexPad = new string[256]; static readonly string[] Byte2HexNopad = new string[256]; /** * 2 - Quote character at beginning and end. * 5 - Extra allowance for anticipated escape characters that may be added. */ - static readonly int CsvNumberEscapeCharacters = 2 + 5; - static readonly char PackageSeparatorChar = '.'; + const int CsvNumberEscapeCharacters = 2 + 5; static StringUtil() { - Newline = Environment.NewLine; - // Generate the lookup table that converts a byte into a 2-digit hexadecimal integer. int i; for (i = 0; i < 10; i++) @@ -63,102 +64,57 @@ namespace DotNetty.Common.Utilities } } - /// - /// Splits the specified {@link String} with the specified delimiter in maxParts maximum parts. - /// This operation is a simplified and optimized - /// version of {@link String#split(String, int)}. - /// - /// - /// - /// - /// - public static string[] Split(string value, char delim, int maxParts) - { - int end = value.Length; - var res = new List(); - - int start = 0; - int cpt = 1; - for (int i = 0; i < end && cpt < maxParts; i++) - { - if (value[i] == delim) - { - if (start == i) - { - res.Add(string.Empty); - } - else - { - res.Add(value.Substring(start, i)); - } - start = i + 1; - cpt++; - } - } - - if (start == 0) - { - // If no delimiter was found in the value - res.Add(value); - } - else - { - if (start != end) - { - // Add the last element if it's not empty. - res.Add(value.Substring(start, end)); - } - else - { - // Truncate trailing empty elements. - for (int i = res.Count - 1; i >= 0; i--) - { - if (res[i] == "") - { - res.Remove(res[i]); - } - else - { - break; - } - } - } - } - - return res.ToArray(); - } - - /// - /// Get the item after one char delim if the delim is found (else null). - /// This operation is a simplified and optimized - /// version of {@link String#split(String, int)}. - /// - /// - /// - /// - public static string SubstringAfter(this string value, char delim) + public static string SubstringAfter(string value, char delim) { int pos = value.IndexOf(delim); return pos >= 0 ? value.Substring(pos + 1) : null; } + public static bool CommonSuffixOfLength(string s, string p, int len) => s != null && p != null && len >= 0 && RegionMatches(s, s.Length - len, p, p.Length - len, len); + + static bool RegionMatches(string value, int thisStart, string other, int start, int length) + { + if (start < 0 || other.Length - start < length) + { + return false; + } + + if (thisStart < 0 || value.Length - thisStart < length) + { + return false; + } + + if (length <= 0) + { + return true; + } + + int o1 = thisStart; + int o2 = start; + for (int i = 0; i < length; ++i) + { + if (value[o1 + i] != other[o2 + i]) + { + return false; + } + } + + return true; + } + /// /// Converts the specified byte value into a 2-digit hexadecimal integer. /// public static string ByteToHexStringPadded(int value) => Byte2HexPad[value & 0xff]; - //todo: port - // /** - // * Converts the specified byte value into a 2-digit hexadecimal integer and appends it to the specified buffer. - // */ - //public static T byteToHexStringPadded(T buf, int value) { - // try { - // buf.append(byteToHexStringPadded(value)); - // } catch (IOException e) { - // PlatformDependent.throwException(e); - // } - // return buf; - //} + // + // Converts the specified byte value into a 2-digit hexadecimal integer and appends it to the specified buffer. + // + public static T ByteToHexStringPadded(T buf, int value) where T : IAppendable + { + buf.Append(new StringCharSequence(ByteToHexStringPadded(value))); + return buf; + } /// /// Converts the specified byte array into a hexadecimal value. @@ -179,15 +135,16 @@ namespace DotNetty.Common.Utilities return sb.ToString(); } - public static StringBuilder ToHexStringPadded(StringBuilder sb, byte[] src, int offset, int length) + public static T ToHexStringPadded(T dst, byte[] src) where T : IAppendable => ToHexStringPadded(dst, src, 0, src.Length); + + public static T ToHexStringPadded(T dst, byte[] src, int offset, int length) where T : IAppendable { - Contract.Requires((offset + length) <= src.Length); int end = offset + length; for (int i = offset; i < end; i++) { - sb.Append(ByteToHexStringPadded(src[i])); + ByteToHexStringPadded(dst, src[i]); } - return sb; + return dst; } /// @@ -195,27 +152,31 @@ namespace DotNetty.Common.Utilities /// public static string ByteToHexString(byte value) => Byte2HexNopad[value & 0xff]; - public static StringBuilder ByteToHexString(StringBuilder buf, byte value) => buf.Append(ByteToHexString(value)); + public static T ByteToHexString(T buf, byte value) where T : IAppendable + { + buf.Append(new StringCharSequence(ByteToHexString(value))); + return buf; + } public static string ToHexString(byte[] src) => ToHexString(src, 0, src.Length); - public static string ToHexString(byte[] src, int offset, int length) => ToHexString(new StringBuilder(length << 1), src, offset, length).ToString(); + public static string ToHexString(byte[] src, int offset, int length) => ToHexString(new AppendableCharSequence(length << 1), src, offset, length).ToString(); - public static StringBuilder ToHexString(StringBuilder dst, byte[] src) => ToHexString(dst, src, 0, src.Length); + public static T ToHexString(T dst, byte[] src) where T : IAppendable => ToHexString(dst, src, 0, src.Length); - /// - /// Converts the specified byte array into a hexadecimal value and appends it to the specified buffer. - /// - public static StringBuilder ToHexString(StringBuilder dst, byte[] src, int offset, int length) + public static T ToHexString(T dst, byte[] src, int offset, int length) where T : IAppendable { - Debug.Assert(length >= 0); + Contract.Requires(length >= 0); + if (length == 0) { return dst; } + int end = offset + length; int endMinusOne = end - 1; int i; + // Skip preceding zeroes. for (i = offset; i < endMinusOne; i++) { @@ -232,68 +193,59 @@ namespace DotNetty.Common.Utilities return dst; } - /// - /// Escapes the specified value, if necessary according to - /// RFC-4180. - /// - /// - /// The value which will be escaped according to - /// RFC-4180 - /// - /// the escaped value if necessary, or the value unchanged - public static string EscapeCsv(string value) + public static int DecodeHexNibble(char c) { - int length = value.Length; - if (length == 0) + // Character.digit() is not used here, as it addresses a larger + // set of characters (both ASCII and full-width latin letters). + if (c >= '0' && c <= '9') { - return value; + return c - '0'; } - int last = length - 1; - bool quoted = IsDoubleQuote(value[0]) && IsDoubleQuote(value[last]) && length != 1; - bool foundSpecialCharacter = false; - bool escapedDoubleQuote = false; - StringBuilder escaped = new StringBuilder(length + CsvNumberEscapeCharacters).Append(DoubleQuote); - for (int i = 0; i < length; i++) + if (c >= 'A' && c <= 'F') { - char current = value[i]; - switch (current) - { - case DoubleQuote: - if (i == 0 || i == last) - { - if (!quoted) - { - escaped.Append(DoubleQuote); - } - else - { - continue; - } - } - else - { - bool isNextCharDoubleQuote = IsDoubleQuote(value[i + 1]); - if (!IsDoubleQuote(value[i - 1]) && - (!isNextCharDoubleQuote || i + 1 == last)) - { - escaped.Append(DoubleQuote); - escapedDoubleQuote = true; - } - } - break; - case LineFeed: - case CarriageReturn: - case Comma: - foundSpecialCharacter = true; - break; - } - escaped.Append(current); + return c - 'A' + 0xA; } - return escapedDoubleQuote || foundSpecialCharacter && !quoted ? - escaped.Append(DoubleQuote).ToString() : value; + if (c >= 'a' && c <= 'f') + { + return c - 'a' + 0xA; + } + return -1; } - static bool IsDoubleQuote(char c) => c == DoubleQuote; + // Decode a 2-digit hex byte from within a string. + public static byte DecodeHexByte(string s, int pos) + { + int hi = DecodeHexNibble(s[pos]); + int lo = DecodeHexNibble(s[pos + 1]); + if (hi == -1 || lo == -1) + { + throw new ArgumentException($"invalid hex byte '{s.Substring(pos, 2)}' at index {pos} of '{s}'"); + } + + return (byte)((hi << 4) + lo); + } + + //Decodes part of a string with hex dump + public static byte[] DecodeHexDump(string hexDump, int fromIndex, int length) + { + if (length < 0 || (length & 1) != 0) + { + throw new ArgumentException($"length: {length}"); + } + if (length == 0) + { + return EmptyArrays.EmptyBytes; + } + var bytes = new byte[length.RightUShift(1)]; + for (int i = 0; i < length; i += 2) + { + bytes[i.RightUShift(1)] = DecodeHexByte(hexDump, fromIndex + i); + } + return bytes; + } + + // Decodes a hex dump + public static byte[] DecodeHexDump(string hexDump) => DecodeHexDump(hexDump, 0, hexDump.Length); /// /// The shortcut to SimpleClassName(o.GetType()). @@ -311,5 +263,333 @@ namespace DotNetty.Common.Utilities /// with anonymous classes. /// public static string SimpleClassName(Type type) => type.Name; + + /// + /// Escapes the specified value, if necessary according to + /// RFC-4180. + /// + /// + /// The value which will be escaped according to + /// RFC-4180 + /// + /// + /// The value will first be trimmed of its optional white-space characters, according to + /// RFC-7230 + /// + /// the escaped value if necessary, or the value unchanged + public static ICharSequence EscapeCsv(ICharSequence value, bool trimWhiteSpace = false) + { + Contract.Requires(value != null); + + int length = value.Count; + if (length == 0) + { + return value; + } + + int start; + int last; + if (trimWhiteSpace) + { + start = IndexOfFirstNonOwsChar(value, length); + last = IndexOfLastNonOwsChar(value, start, length); + } + else + { + start = 0; + last = length - 1; + } + if (start > last) + { + return StringCharSequence.Empty; + } + + int firstUnescapedSpecial = -1; + bool quoted = false; + if (IsDoubleQuote(value[start])) + { + quoted = IsDoubleQuote(value[last]) && last > start; + if (quoted) + { + start++; + last--; + } + else + { + firstUnescapedSpecial = start; + } + } + + if (firstUnescapedSpecial < 0) + { + if (quoted) + { + for (int i = start; i <= last; i++) + { + if (IsDoubleQuote(value[i])) + { + if (i == last || !IsDoubleQuote(value[i + 1])) + { + firstUnescapedSpecial = i; + break; + } + i++; + } + } + } + else + { + for (int i = start; i <= last; i++) + { + char c = value[i]; + if (c == LineFeed || c == CarriageReturn || c == Comma) + { + firstUnescapedSpecial = i; + break; + } + if (IsDoubleQuote(c)) + { + if (i == last || !IsDoubleQuote(value[i + 1])) + { + firstUnescapedSpecial = i; + break; + } + i++; + } + } + } + if (firstUnescapedSpecial < 0) + { + // Special characters is not found or all of them already escaped. + // In the most cases returns a same string. New string will be instantiated (via StringBuilder) + // only if it really needed. It's important to prevent GC extra load. + return quoted ? value.SubSequence(start - 1, last + 2) : value.SubSequence(start, last + 1); + } + } + + var result = new StringBuilderCharSequence(last - start + 1 + CsvNumberEscapeCharacters); + result.Append(DoubleQuote); + result.Append(value, start, firstUnescapedSpecial - start); + for (int i = firstUnescapedSpecial; i <= last; i++) + { + char c = value[i]; + if (IsDoubleQuote(c)) + { + result.Append(DoubleQuote); + if (i < last && IsDoubleQuote(value[i + 1])) + { + i++; + } + } + result.Append(c); + } + + result.Append(DoubleQuote); + return result; + } + + public static ICharSequence UnescapeCsv(ICharSequence value) + { + Contract.Requires(value != null); + int length = value.Count; + if (length == 0) + { + return value; + } + int last = length - 1; + bool quoted = IsDoubleQuote(value[0]) && IsDoubleQuote(value[last]) && length != 1; + if (!quoted) + { + ValidateCsvFormat(value); + return value; + } + StringBuilder unescaped = InternalThreadLocalMap.Get().StringBuilder; + for (int i = 1; i < last; i++) + { + char current = value[i]; + if (current == DoubleQuote) + { + if (IsDoubleQuote(value[i + 1]) && (i + 1) != last) + { + // Followed by a double-quote but not the last character + // Just skip the next double-quote + i++; + } + else + { + // Not followed by a double-quote or the following double-quote is the last character + throw NewInvalidEscapedCsvFieldException(value, i); + } + } + unescaped.Append(current); + } + + return new StringCharSequence(unescaped.ToString()); + } + + public static IList UnescapeCsvFields(ICharSequence value) + { + var unescaped = new List(2); + StringBuilder current = InternalThreadLocalMap.Get().StringBuilder; + bool quoted = false; + int last = value.Count - 1; + for (int i = 0; i <= last; i++) + { + char c = value[i]; + if (quoted) + { + switch (c) + { + case DoubleQuote: + if (i == last) + { + // Add the last field and return + unescaped.Add((StringCharSequence)current.ToString()); + return unescaped; + } + char next = value[++i]; + if (next == DoubleQuote) + { + // 2 double-quotes should be unescaped to one + current.Append(DoubleQuote); + } + else if (next == Comma) + { + // This is the end of a field. Let's start to parse the next field. + quoted = false; + unescaped.Add((StringCharSequence)current.ToString()); + current.Length = 0; + } + else + { + // double-quote followed by other character is invalid + throw new ArgumentException($"invalid escaped CSV field: {value} index: {i - 1}"); + } + break; + default: + current.Append(c); + break; + } + } + else + { + switch (c) + { + case Comma: + // Start to parse the next field + unescaped.Add((StringCharSequence)current.ToString()); + current.Length = 0; + break; + case DoubleQuote: + if (current.Length == 0) + { + quoted = true; + } + else + { + // double-quote appears without being enclosed with double-quotes + current.Append(c); + } + break; + case LineFeed: + case CarriageReturn: + // special characters appears without being enclosed with double-quotes + throw new ArgumentException($"invalid escaped CSV field: {value} index: {i}"); + default: + current.Append(c); + break; + } + } + } + if (quoted) + { + throw new ArgumentException($"invalid escaped CSV field: {value} index: {last}"); + } + + unescaped.Add((StringCharSequence)current.ToString()); + return unescaped; + } + + static void ValidateCsvFormat(ICharSequence value) + { + int length = value.Count; + for (int i = 0; i < length; i++) + { + switch (value[i]) + { + case DoubleQuote: + case LineFeed: + case CarriageReturn: + case Comma: + // If value contains any special character, it should be enclosed with double-quotes + throw NewInvalidEscapedCsvFieldException(value, i); + } + } + } + + static ArgumentException NewInvalidEscapedCsvFieldException(ICharSequence value, int index) => new ArgumentException($"invalid escaped CSV field: {value} index: {index}"); + + public static int Length(string s) => s?.Length ?? 0; + + public static int IndexOfNonWhiteSpace(IReadOnlyList seq, int offset) + { + for (; offset < seq.Count; ++offset) + { + if (!char.IsWhiteSpace(seq[offset])) + { + return offset; + } + } + + return -1; + } + + public static bool IsSurrogate(char c) => c >= '\uD800' && c <= '\uDFFF'; + + static bool IsDoubleQuote(char c) => c == DoubleQuote; + + public static bool EndsWith(IReadOnlyList s, char c) + { + int len = s.Count; + return len > 0 && s[len - 1] == c; + } + + public static ICharSequence TrimOws(ICharSequence value) + { + int length = value.Count; + if (length == 0) + { + return value; + } + + int start = IndexOfFirstNonOwsChar(value, length); + int end = IndexOfLastNonOwsChar(value, start, length); + return start == 0 && end == length - 1 ? value : value.SubSequence(start, end + 1); + } + + static int IndexOfFirstNonOwsChar(IReadOnlyList value, int length) + { + int i = 0; + while (i < length && IsOws(value[i])) + { + i++; + } + + return i; + } + + static int IndexOfLastNonOwsChar(IReadOnlyList value, int start, int length) + { + int i = length - 1; + while (i > start && IsOws(value[i])) + { + i--; + } + + return i; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static bool IsOws(char c) => c == Space || c == Tab; } } \ No newline at end of file diff --git a/src/DotNetty.Handlers/Streams/ChunkedStream.cs b/src/DotNetty.Handlers/Streams/ChunkedStream.cs new file mode 100644 index 0000000..1410045 --- /dev/null +++ b/src/DotNetty.Handlers/Streams/ChunkedStream.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Handlers.Streams +{ + using System; + using System.Diagnostics.Contracts; + using System.IO; + using System.Threading; + using DotNetty.Buffers; + + public class ChunkedStream : IChunkedInput + { + public static readonly int DefaultChunkSize = 8192; + + readonly Stream input; + readonly int chunkSize; + bool closed; + + public ChunkedStream(Stream input) : this(input, DefaultChunkSize) + { + } + + public ChunkedStream(Stream input, int chunkSize) + { + Contract.Requires(input != null); + Contract.Requires(chunkSize > 0); + + this.input = input; + this.chunkSize = chunkSize; + } + + public long TransferredBytes { get; private set; } + + public bool IsEndOfInput => this.closed || (this.input.Position == this.input.Length); + + public void Close() + { + this.closed = true; + this.input.Dispose(); + } + + public IByteBuffer ReadChunk(IByteBufferAllocator allocator) + { + if (this.IsEndOfInput) + { + return null; + } + + long availableBytes = this.input.Length - this.input.Position; + int readChunkSize = availableBytes <= 0 + ? this.chunkSize + : (int)Math.Min(this.chunkSize, availableBytes); + + bool release = true; + IByteBuffer buffer = allocator.Buffer(readChunkSize); + try + { + // transfer to buffer + int count = buffer.SetBytesAsync(buffer.WriterIndex, this.input, readChunkSize, CancellationToken.None).Result; + buffer.SetWriterIndex(buffer.WriterIndex + count); + this.TransferredBytes += count; + + release = false; + } + finally + { + if (release) + { + buffer.Release(); + } + } + + return buffer; + } + + public long Length => -1; + + public long Progress => this.TransferredBytes; + } +} diff --git a/src/DotNetty.Handlers/Streams/ChunkedWriteHandler.cs b/src/DotNetty.Handlers/Streams/ChunkedWriteHandler.cs new file mode 100644 index 0000000..438856a --- /dev/null +++ b/src/DotNetty.Handlers/Streams/ChunkedWriteHandler.cs @@ -0,0 +1,368 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Handlers.Streams +{ + using System; + using System.Collections.Generic; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + + public class ChunkedWriteHandler : ChannelDuplexHandler + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance>(); + + readonly Queue queue = new Queue(); + volatile IChannelHandlerContext ctx; + PendingWrite currentWrite; + + public override void HandlerAdded(IChannelHandlerContext context) => this.ctx = context; + + public void ResumeTransfer() + { + if (this.ctx == null) + { + return; + } + + if (this.ctx.Executor.InEventLoop) + { + this.InvokeDoFlush(this.ctx); + } + else + { + this.ctx.Executor.Execute(state => this.InvokeDoFlush((IChannelHandlerContext)state), this.ctx); + } + } + + public override Task WriteAsync(IChannelHandlerContext context, object message) + { + var pendingWrite = new PendingWrite(message); + this.queue.Enqueue(pendingWrite); + return pendingWrite.PendingTask; + } + + public override void Flush(IChannelHandlerContext context) => this.DoFlush(context); + + public override void ChannelInactive(IChannelHandlerContext context) + { + this.DoFlush(context); + context.FireChannelInactive(); + } + + public override void ChannelWritabilityChanged(IChannelHandlerContext context) + { + if (context.Channel.IsWritable) + { + // channel is writable again try to continue flushing + this.DoFlush(context); + } + + context.FireChannelWritabilityChanged(); + } + + void Discard(Exception cause = null) + { + for (;;) + { + PendingWrite current = this.currentWrite; + if (this.currentWrite == null) + { + current = this.queue.Count > 0 ? this.queue.Dequeue() : null; + } + else + { + this.currentWrite = null; + } + + if (current == null) + { + break; + } + + object message = current.Message; + var chunks = message as IChunkedInput; + if (chunks != null) + { + try + { + if (!chunks.IsEndOfInput) + { + if (cause == null) + { + cause = new ClosedChannelException(); + } + + current.Fail(cause); + } + else + { + current.Success(); + } + } + catch (Exception exception) + { + current.Fail(exception); + Logger.Warn($"{StringUtil.SimpleClassName(typeof(ChunkedWriteHandler))}.IsEndOfInput failed", exception); + } + finally + { + CloseInput(chunks); + } + } + else + { + if (cause == null) + { + cause = new ClosedChannelException(); + } + + current.Fail(cause); + } + } + } + + void InvokeDoFlush(IChannelHandlerContext context) + { + try + { + this.DoFlush(context); + } + catch (Exception exception) + { + if (Logger.WarnEnabled) + { + Logger.Warn("Unexpected exception while sending chunks.", exception); + } + } + } + + void DoFlush(IChannelHandlerContext context) + { + IChannel channel = context.Channel; + if (!channel.Active) + { + this.Discard(); + return; + } + + bool requiresFlush = true; + IByteBufferAllocator allocator = context.Allocator; + while (channel.IsWritable) + { + if (this.currentWrite == null) + { + this.currentWrite = this.queue.Count > 0 ? this.queue.Dequeue() : null; + } + + if (this.currentWrite == null) + { + break; + } + + PendingWrite current = this.currentWrite; + object pendingMessage = current.Message; + + var chunks = pendingMessage as IChunkedInput; + if (chunks != null) + { + bool endOfInput; + bool suspend; + object message = null; + + try + { + message = chunks.ReadChunk(allocator); + endOfInput = chunks.IsEndOfInput; + if (message == null) + { + // No need to suspend when reached at the end. + suspend = !endOfInput; + } + else + { + suspend = false; + } + } + catch (Exception exception) + { + this.currentWrite = null; + + if (message != null) + { + ReferenceCountUtil.Release(message); + } + + current.Fail(exception); + CloseInput(chunks); + + break; + } + + if (suspend) + { + // ChunkedInput.nextChunk() returned null and it has + // not reached at the end of input. Let's wait until + // more chunks arrive. Nothing to write or notify. + break; + } + + if (message == null) + { + // If message is null write an empty ByteBuf. + // See https://github.com/netty/netty/issues/1671 + message = Unpooled.Empty; + } + + Task future = context.WriteAsync(message); + if (endOfInput) + { + this.currentWrite = null; + + // Register a listener which will close the input once the write is complete. + // This is needed because the Chunk may have some resource bound that can not + // be closed before its not written. + // + // See https://github.com/netty/netty/issues/303 + future.ContinueWith((_, state) => + { + var pendingTask = (PendingWrite)state; + CloseInput((IChunkedInput)pendingTask.Message); + pendingTask.Success(); + }, + current, + TaskContinuationOptions.ExecuteSynchronously); + } + else if (channel.IsWritable) + { + future.ContinueWith((task, state) => + { + var pendingTask = (PendingWrite)state; + if (task.IsFaulted) + { + CloseInput((IChunkedInput)pendingTask.Message); + pendingTask.Fail(task.Exception); + } + else + { + pendingTask.Progress(chunks.Progress, chunks.Length); + } + }, + current, + TaskContinuationOptions.ExecuteSynchronously); + } + else + { + future.ContinueWith((task, state) => + { + var handler = (ChunkedWriteHandler) state; + if (task.IsFaulted) + { + CloseInput((IChunkedInput)handler.currentWrite.Message); + handler.currentWrite.Fail(task.Exception); + } + else + { + handler.currentWrite.Progress(chunks.Progress, chunks.Length); + if (channel.IsWritable) + { + handler.ResumeTransfer(); + } + } + }, + this, + TaskContinuationOptions.ExecuteSynchronously); + } + + // Flush each chunk to conserve memory + context.Flush(); + requiresFlush = false; + } + else + { + context.WriteAsync(pendingMessage) + .ContinueWith((task, state) => + { + var pendingTask = (PendingWrite)state; + if (task.IsFaulted) + { + pendingTask.Fail(task.Exception); + } + else + { + pendingTask.Success(); + } + }, + current, + TaskContinuationOptions.ExecuteSynchronously); + + this.currentWrite = null; + requiresFlush = true; + } + + if (!channel.Active) + { + this.Discard(new ClosedChannelException()); + break; + } + } + + if (requiresFlush) + { + context.Flush(); + } + } + + static void CloseInput(IChunkedInput chunks) + { + try + { + chunks.Close(); + } + catch (Exception exception) + { + if (Logger.WarnEnabled) + { + Logger.Warn("Failed to close a chunked input.", exception); + } + } + } + + sealed class PendingWrite + { + readonly TaskCompletionSource promise; + + public PendingWrite(object msg) + { + this.Message = msg; + this.promise = new TaskCompletionSource(); + } + + public object Message { get; } + + public void Success() => this.promise.TryComplete(); + + public void Fail(Exception error) + { + ReferenceCountUtil.Release(this.Message); + this.promise.TrySetException(error); + } + + public void Progress(long progress, long total) + { + if (progress < total) + { + return; + } + + this.Success(); + } + + public Task PendingTask => this.promise.Task; + } + } +} diff --git a/src/DotNetty.Handlers/Streams/IChunkedInput.cs b/src/DotNetty.Handlers/Streams/IChunkedInput.cs new file mode 100644 index 0000000..6d113be --- /dev/null +++ b/src/DotNetty.Handlers/Streams/IChunkedInput.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Handlers.Streams +{ + using DotNetty.Buffers; + + public interface IChunkedInput + { + bool IsEndOfInput { get; } + + void Close(); + + T ReadChunk(IByteBufferAllocator allocator); + + long Length { get; } + + long Progress { get; } + } +} diff --git a/src/DotNetty.Transport/Channels/ChannelHandlerAdapter.cs b/src/DotNetty.Transport/Channels/ChannelHandlerAdapter.cs index a6b4c84..243dfb6 100644 --- a/src/DotNetty.Transport/Channels/ChannelHandlerAdapter.cs +++ b/src/DotNetty.Transport/Channels/ChannelHandlerAdapter.cs @@ -6,6 +6,7 @@ namespace DotNetty.Transport.Channels using System; using System.Net; using System.Threading.Tasks; + using DotNetty.Common.Utilities; public class ChannelHandlerAdapter : IChannelHandler { @@ -73,5 +74,13 @@ namespace DotNetty.Transport.Channels public virtual void Read(IChannelHandlerContext context) => context.Read(); public virtual bool IsSharable => false; + + protected void EnsureNotSharable() + { + if (this.IsSharable) + { + throw new InvalidOperationException($"ChannelHandler {StringUtil.SimpleClassName(this)} is not allowed to be shared"); + } + } } } \ No newline at end of file diff --git a/src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs b/src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs new file mode 100644 index 0000000..5b854e7 --- /dev/null +++ b/src/DotNetty.Transport/Channels/CombinedChannelDuplexHandler.cs @@ -0,0 +1,542 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Transport.Channels +{ + using System; + using System.Diagnostics.Contracts; + using System.Net; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + + public class CombinedChannelDuplexHandler : ChannelDuplexHandler + where TIn : IChannelHandler + where TOut : IChannelHandler + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance>(); + + DelegatingChannelHandlerContext inboundCtx; + DelegatingChannelHandlerContext outboundCtx; + volatile bool handlerAdded; + + protected CombinedChannelDuplexHandler() + { + this.EnsureNotSharable(); + } + + public CombinedChannelDuplexHandler(TIn inboundHandler, TOut outboundHandler) + { + Contract.Requires(inboundHandler != null); + Contract.Requires(outboundHandler != null); + + this.EnsureNotSharable(); + this.Init(inboundHandler, outboundHandler); + } + + protected void Init(TIn inbound, TOut outbound) + { + this.Validate(inbound, outbound); + + this.InboundHandler = inbound; + this.OutboundHandler = outbound; + } + + protected TIn InboundHandler { get; private set; } + + protected TOut OutboundHandler { get; private set; } + + void Validate(TIn inbound, TOut outbound) + { + if (this.InboundHandler != null) + { + throw new InvalidOperationException($"init() can not be invoked if {StringUtil.SimpleClassName(this)} was constructed with non-default constructor."); + } + + if (inbound == null) + { + throw new ArgumentNullException(nameof(inbound)); + } + + if (outbound == null) + { + throw new ArgumentNullException(nameof(outbound)); + } + } + + + void CheckAdded() + { + if (!this.handlerAdded) + { + throw new InvalidOperationException("handler not added to pipeline yet"); + } + } + + public void RemoveInboundHandler() + { + this.CheckAdded(); + this.inboundCtx.Remove(); + } + + public void RemoveOutboundHandler() + { + this.CheckAdded(); + this.outboundCtx.Remove(); + } + + public override void HandlerAdded(IChannelHandlerContext context) + { + if (this.InboundHandler == null) + { + throw new InvalidOperationException($"Init() must be invoked before being added to a {nameof(IChannelPipeline)} if {StringUtil.SimpleClassName(this)} was constructed with the default constructor."); + } + + this.outboundCtx = new DelegatingChannelHandlerContext(context, this.OutboundHandler); + this.inboundCtx = new DelegatingChannelHandlerContext(context, this.InboundHandler, + cause => + { + try + { + this.OutboundHandler.ExceptionCaught(this.outboundCtx, cause); + } + catch (Exception error) + { + if (Logger.DebugEnabled) + { + Logger.Debug("An exception {}" + + "was thrown by a user handler's exceptionCaught() " + + "method while handling the following exception:", error, cause); + } + else if (Logger.WarnEnabled) + { + Logger.Warn("An exception '{}' [enable DEBUG level for full stacktrace] " + + "was thrown by a user handler's exceptionCaught() " + + "method while handling the following exception:", error, cause); + } + } + }); + + // The inboundCtx and outboundCtx were created and set now it's safe to call removeInboundHandler() and + // removeOutboundHandler(). + this.handlerAdded = true; + + try + { + this.InboundHandler.HandlerAdded(this.inboundCtx); + } + finally + { + this.OutboundHandler.HandlerAdded(this.outboundCtx); + } + } + + public override void HandlerRemoved(IChannelHandlerContext context) + { + try + { + this.inboundCtx.Remove(); + } + finally + { + this.outboundCtx.Remove(); + } + } + + public override void ChannelRegistered(IChannelHandlerContext context) + { + Contract.Assert(context == this.inboundCtx.InnerContext); + + if (!this.inboundCtx.Removed) + { + this.InboundHandler.ChannelRegistered(this.inboundCtx); + } + else + { + this.inboundCtx.FireChannelRegistered(); + } + } + + public override void ChannelUnregistered(IChannelHandlerContext context) + { + Contract.Assert(context == this.inboundCtx.InnerContext); + + if (!this.inboundCtx.Removed) + { + this.InboundHandler.ChannelUnregistered(this.inboundCtx); + } + else + { + this.inboundCtx.FireChannelUnregistered(); + } + } + + public override void ChannelActive(IChannelHandlerContext context) + { + Contract.Assert(context == this.inboundCtx.InnerContext); + + if (!this.inboundCtx.Removed) + { + this.InboundHandler.ChannelActive(this.inboundCtx); + } + else + { + this.inboundCtx.FireChannelActive(); + } + } + + public override void ChannelInactive(IChannelHandlerContext context) + { + Contract.Assert(context == this.inboundCtx.InnerContext); + + if (!this.inboundCtx.Removed) + { + this.InboundHandler.ChannelInactive(this.inboundCtx); + } + else + { + this.inboundCtx.FireChannelInactive(); + } + } + + public override void ExceptionCaught(IChannelHandlerContext context, Exception exception) + { + Contract.Assert(context == this.inboundCtx.InnerContext); + + if (!this.inboundCtx.Removed) + { + this.InboundHandler.ExceptionCaught(this.inboundCtx, exception); + } + else + { + this.inboundCtx.FireExceptionCaught(exception); + } + } + + public override void UserEventTriggered(IChannelHandlerContext context, object evt) + { + Contract.Assert(context == this.inboundCtx.InnerContext); + + if (!this.inboundCtx.Removed) + { + this.InboundHandler.UserEventTriggered(this.inboundCtx, evt); + } + else + { + this.inboundCtx.FireUserEventTriggered(evt); + } + } + + public override void ChannelRead(IChannelHandlerContext context, object message) + { + Contract.Assert(context == this.inboundCtx.InnerContext); + + if (!this.inboundCtx.Removed) + { + this.InboundHandler.ChannelRead(this.inboundCtx, message); + } + else + { + this.inboundCtx.FireChannelRead(message); + } + } + + public override void ChannelReadComplete(IChannelHandlerContext context) + { + Contract.Assert(context == this.inboundCtx.InnerContext); + + if (!this.inboundCtx.Removed) + { + this.InboundHandler.ChannelReadComplete(this.inboundCtx); + } + else + { + this.inboundCtx.FireChannelReadComplete(); + } + } + + public override void ChannelWritabilityChanged(IChannelHandlerContext context) + { + Contract.Assert(context == this.inboundCtx.InnerContext); + + if (!this.inboundCtx.Removed) + { + this.InboundHandler.ChannelWritabilityChanged(this.inboundCtx); + } + else + { + this.inboundCtx.FireChannelWritabilityChanged(); + } + } + + public override Task BindAsync(IChannelHandlerContext context, EndPoint localAddress) + { + Contract.Assert(context == this.outboundCtx.InnerContext); + + if (!this.outboundCtx.Removed) + { + return this.OutboundHandler.BindAsync(this.outboundCtx, localAddress); + } + else + { + return this.outboundCtx.BindAsync(localAddress); + } + } + + public override Task ConnectAsync(IChannelHandlerContext context, EndPoint remoteAddress, EndPoint localAddress) + { + Contract.Assert(context == this.outboundCtx.InnerContext); + + if (!this.outboundCtx.Removed) + { + return this.OutboundHandler.ConnectAsync(this.outboundCtx, remoteAddress, localAddress); + } + else + { + return this.outboundCtx.ConnectAsync(localAddress); + } + } + + public override Task DisconnectAsync(IChannelHandlerContext context) + { + Contract.Assert(context == this.outboundCtx.InnerContext); + + if (!this.outboundCtx.Removed) + { + return this.OutboundHandler.DisconnectAsync(this.outboundCtx); + } + else + { + return this.outboundCtx.DisconnectAsync(); + } + } + + public override Task CloseAsync(IChannelHandlerContext context) + { + Contract.Assert(context == this.outboundCtx.InnerContext); + + if (!this.outboundCtx.Removed) + { + return this.OutboundHandler.CloseAsync(this.outboundCtx); + } + else + { + return this.outboundCtx.CloseAsync(); + } + } + + public override Task DeregisterAsync(IChannelHandlerContext context) + { + Contract.Assert(context == this.outboundCtx.InnerContext); + + if (!this.outboundCtx.Removed) + { + return this.OutboundHandler.DeregisterAsync(this.outboundCtx); + } + else + { + return this.outboundCtx.DeregisterAsync(); + } + } + + public override void Read(IChannelHandlerContext context) + { + Contract.Assert(context == this.outboundCtx.InnerContext); + + if (!this.outboundCtx.Removed) + { + this.OutboundHandler.Read(this.outboundCtx); + } + else + { + this.outboundCtx.Read(); + } + } + + public override Task WriteAsync(IChannelHandlerContext context, object message) + { + Contract.Assert(context == this.outboundCtx.InnerContext); + + if (!this.outboundCtx.Removed) + { + return this.OutboundHandler.WriteAsync(this.outboundCtx, message); + } + else + { + return this.outboundCtx.WriteAsync(message); + } + } + + public override void Flush(IChannelHandlerContext context) + { + Contract.Assert(context == this.outboundCtx.InnerContext); + + if (!this.outboundCtx.Removed) + { + this.OutboundHandler.Flush(this.outboundCtx); + } + else + { + this.outboundCtx.Flush(); + } + } + + sealed class DelegatingChannelHandlerContext : IChannelHandlerContext + { + + readonly IChannelHandlerContext ctx; + readonly IChannelHandler handler; + readonly Action onError; + bool removed; + + public DelegatingChannelHandlerContext(IChannelHandlerContext ctx, IChannelHandler handler, Action onError = null) + { + this.ctx = ctx; + this.handler = handler; + this.onError = onError; + } + + public IChannelHandlerContext InnerContext => this.ctx; + + public IChannel Channel => this.ctx.Channel; + + public IByteBufferAllocator Allocator => this.ctx.Allocator; + + public IEventExecutor Executor => this.ctx.Executor; + + public string Name => this.ctx.Name; + + public IChannelHandler Handler => this.ctx.Handler; + + public bool Removed => this.removed || this.ctx.Removed; + + public IChannelHandlerContext FireChannelRegistered() + { + this.ctx.FireChannelRegistered(); + return this; + } + + public IChannelHandlerContext FireChannelUnregistered() + { + this.ctx.FireChannelUnregistered(); + return this; + } + + public IChannelHandlerContext FireChannelActive() + { + this.ctx.FireChannelActive(); + return this; + } + + public IChannelHandlerContext FireChannelInactive() + { + this.ctx.FireChannelInactive(); + return this; + } + + public IChannelHandlerContext FireExceptionCaught(Exception ex) + { + if (this.onError != null) + { + this.onError(ex); + } + else + { + this.ctx.FireExceptionCaught(ex); + } + + return this; + } + + public IChannelHandlerContext FireUserEventTriggered(object evt) + { + this.ctx.FireUserEventTriggered(evt); + return this; + } + + public IChannelHandlerContext FireChannelRead(object message) + { + this.ctx.FireChannelRead(message); + return this; + } + + public IChannelHandlerContext FireChannelReadComplete() + { + this.ctx.FireChannelReadComplete(); + return this; + } + + public IChannelHandlerContext FireChannelWritabilityChanged() + { + this.ctx.FireChannelWritabilityChanged(); + return this; + } + + public Task BindAsync(EndPoint localAddress) => this.ctx.BindAsync(localAddress); + + public Task ConnectAsync(EndPoint remoteAddress) => this.ctx.ConnectAsync(remoteAddress); + + public Task ConnectAsync(EndPoint remoteAddress, EndPoint localAddress) => this.ctx.ConnectAsync(remoteAddress, localAddress); + + public Task DisconnectAsync() => this.ctx.DisconnectAsync(); + + public Task CloseAsync() => this.ctx.CloseAsync(); + + public Task DeregisterAsync() => this.ctx.DeregisterAsync(); + + public IChannelHandlerContext Read() + { + this.ctx.Read(); + return this; + } + + public Task WriteAsync(object message) => this.ctx.WriteAsync(message); + + public IChannelHandlerContext Flush() + { + this.ctx.Flush(); + return this; + } + + public Task WriteAndFlushAsync(object message) => this.ctx.WriteAndFlushAsync(message); + + public IAttribute GetAttribute(AttributeKey key) where T : class => this.ctx.GetAttribute(key); + + public bool HasAttribute(AttributeKey key) where T : class => this.ctx.HasAttribute(key); + + internal void Remove() + { + IEventExecutor executor = this.Executor; + if (executor.InEventLoop) + { + this.Remove0(); + } + else + { + executor.Execute(() => this.Remove0()); + } + } + + void Remove0() + { + if (this.removed) + { + return; + } + + this.removed = true; + try + { + this.handler.HandlerRemoved(this); + } + catch (Exception cause) + { + this.FireExceptionCaught( + new ChannelPipelineException($"{StringUtil.SimpleClassName(this.handler)}.handlerRemoved() has thrown an exception.", cause)); + } + } + } + } +} diff --git a/src/DotNetty.Transport/Channels/DefaultFileRegion.cs b/src/DotNetty.Transport/Channels/DefaultFileRegion.cs new file mode 100644 index 0000000..a31ea36 --- /dev/null +++ b/src/DotNetty.Transport/Channels/DefaultFileRegion.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Transport.Channels +{ + using System; + using System.Diagnostics.Contracts; + using System.IO; + using DotNetty.Buffers; + using DotNetty.Common; + using DotNetty.Common.Internal.Logging; + using DotNetty.Common.Utilities; + + public class DefaultFileRegion : AbstractReferenceCounted, IFileRegion + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + + readonly FileStream file; + + public DefaultFileRegion(FileStream file, long position, long count) + { + Contract.Requires(file != null && file.CanRead); + Contract.Requires(position >= 0 && count >= 0); + + this.file = file; + this.Position = position; + this.Count = count; + } + + public override IReferenceCounted Touch(object hint) => this; + + public long Position { get; } + + public long Transferred { get; set; } + + public long Count { get; } + + public long TransferTo(Stream target, long pos) + { + Contract.Requires(target != null); + Contract.Requires(pos >= 0); + + long totalCount = this.Count - pos; + if (totalCount < 0) + { + throw new ArgumentOutOfRangeException($"position out of range: {pos} (expected: 0 - {this.Count - 1})"); + } + + if (totalCount == 0) + { + return 0L; + } + if (this.ReferenceCount == 0) + { + throw new IllegalReferenceCountException(0); + } + + var buffer = new byte[totalCount]; + int total = this.file.Read(buffer, (int)(this.Position + pos), (int)totalCount); + target.Write(buffer, 0, total); + if (total > 0) + { + this.Transferred += total; + } + + return total; + } + + protected override void Deallocate() + { + FileStream fileStream = this.file; + if (!fileStream.CanRead) + { + return; + } + + try + { + fileStream.Dispose(); + } + catch (Exception exception) + { + if (Logger.WarnEnabled) + { + Logger.Warn("Failed to close a file stream.", exception); + } + } + } + } +} diff --git a/src/DotNetty.Transport/Channels/Embedded/EmbeddedChannel.cs b/src/DotNetty.Transport/Channels/Embedded/EmbeddedChannel.cs index e1d9a27..d3bd4ca 100644 --- a/src/DotNetty.Transport/Channels/Embedded/EmbeddedChannel.cs +++ b/src/DotNetty.Transport/Channels/Embedded/EmbeddedChannel.cs @@ -6,6 +6,7 @@ namespace DotNetty.Transport.Channels.Embedded using System; using System.Collections.Generic; using System.Diagnostics; + using System.Diagnostics.Contracts; using System.Net; using System.Runtime.ExceptionServices; using System.Threading.Tasks; @@ -85,15 +86,30 @@ namespace DotNetty.Transport.Channels.Embedded : this(id, hasDisconnect, true, handlers) { } - public EmbeddedChannel(IChannelId id, bool hasDisconnect, bool start, params IChannelHandler[] handlers) + public EmbeddedChannel(IChannelId id, bool hasDisconnect, bool register, params IChannelHandler[] handlers) : base(null, id) { - this.Metadata = hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT; + this.Metadata = GetMetadata(hasDisconnect); this.Configuration = new DefaultChannelConfiguration(this); - if (handlers == null) - { - throw new ArgumentNullException(nameof(handlers)); - } + this.Setup(register, handlers); + } + + public EmbeddedChannel(IChannelId id, bool hasDisconnect, IChannelConfiguration config, + params IChannelHandler[] handlers) + : base(null, id) + { + Contract.Requires(config != null); + + this.Metadata = GetMetadata(hasDisconnect); + this.Configuration = config; + this.Setup(true, handlers); + } + + static ChannelMetadata GetMetadata(bool hasDisconnect) => hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT; + + void Setup(bool register, params IChannelHandler[] handlers) + { + Contract.Requires(handlers != null); IChannelPipeline p = this.Pipeline; p.AddLast(new ActionChannelInitializer(channel => @@ -104,18 +120,20 @@ namespace DotNetty.Transport.Channels.Embedded if (h == null) { break; + } pipeline.AddLast(h); } })); - if (start) + if (register) { - this.Start(); + Task future = this.loop.RegisterAsync(this); + Debug.Assert(future.IsCompleted); } } - public void Start() + public void Register() { Task future = this.loop.RegisterAsync(this); Debug.Assert(future.IsCompleted); diff --git a/src/DotNetty.Transport/Channels/IFileRegion.cs b/src/DotNetty.Transport/Channels/IFileRegion.cs new file mode 100644 index 0000000..b17da4a --- /dev/null +++ b/src/DotNetty.Transport/Channels/IFileRegion.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Transport.Channels +{ + using System.IO; + using DotNetty.Common; + + public interface IFileRegion : IReferenceCounted + { + long Position { get; } + + long Transferred { get; } + + long Count { get; } + + long TransferTo(Stream target, long position); + } +} diff --git a/test/DotNetty.Buffers.Tests/AbstractByteBufferTests.cs b/test/DotNetty.Buffers.Tests/AbstractByteBufferTests.cs index 2c39623..e941168 100644 --- a/test/DotNetty.Buffers.Tests/AbstractByteBufferTests.cs +++ b/test/DotNetty.Buffers.Tests/AbstractByteBufferTests.cs @@ -2262,6 +2262,17 @@ namespace DotNetty.Buffers.Tests [Fact] public void SetBytesAfterRelease3() => Assert.Throws(() => this.ReleasedBuffer().SetBytes(0, this.ReleaseLater(Unpooled.Buffer()), 0, 1)); + [Fact] + public void SetUsAsciiCharSequenceAfterRelease() => Assert.Throws(() => this.SetCharSequenceAfterRelease0(Encoding.ASCII)); + + [Fact] + public void SetUtf8CharSequenceAfterRelease() => Assert.Throws(() => this.SetCharSequenceAfterRelease0(Encoding.UTF8)); + + [Fact] + public void SetUtf16CharSequenceAfterRelease() => Assert.Throws(() => this.SetCharSequenceAfterRelease0(Encoding.Unicode)); + + void SetCharSequenceAfterRelease0(Encoding encoding) => this.ReleasedBuffer().SetCharSequence(0, new StringCharSequence("x"), encoding); + [Fact] public void SetUsAsciiStringAfterRelease() => Assert.Throws(() => this.SetStringAfterRelease0(Encoding.ASCII)); @@ -2414,6 +2425,17 @@ namespace DotNetty.Buffers.Tests [Fact] public void WriteZeroAfterRelease() => Assert.Throws(() => this.ReleasedBuffer().WriteZero(1)); + [Fact] + public void WriteUsAsciiCharSequenceAfterRelease() => Assert.Throws(() => this.WriteCharSequenceAfterRelease0(Encoding.ASCII)); + + [Fact] + public void WriteUtf8CharSequenceAfterRelease() => Assert.Throws(() => this.WriteCharSequenceAfterRelease0(Encoding.UTF8)); + + [Fact] + public void WriteUtf16CharSequenceAfterRelease() => Assert.Throws(() => this.WriteCharSequenceAfterRelease0(Encoding.Unicode)); + + void WriteCharSequenceAfterRelease0(Encoding encoding) => this.ReleasedBuffer().WriteCharSequence(new StringCharSequence("x"), encoding); + [Fact] public void WriteUsAsciiStringAfterRelease() => Assert.Throws(() => this.WriteStringAfterRelease0(Encoding.ASCII)); @@ -2514,6 +2536,54 @@ namespace DotNetty.Buffers.Tests } } + + [Fact] + public virtual void WriteUsAsciiCharSequenceExpand() => this.WriteCharSequenceExpand(Encoding.ASCII); + + [Fact] + public virtual void WriteUtf8CharSequenceExpand() => this.WriteCharSequenceExpand(Encoding.UTF8); + + [Fact] + public virtual void WriteUtf16CharSequenceExpand() => this.WriteCharSequenceExpand(Encoding.Unicode); + + void WriteCharSequenceExpand(Encoding encoding) + { + IByteBuffer buf = this.NewBuffer(1); + try + { + int writerIndex = buf.Capacity - 1; + buf.SetWriterIndex(writerIndex); + int written = buf.WriteCharSequence(new StringCharSequence("AB"), encoding); + Assert.Equal(writerIndex, buf.WriterIndex - written); + } + finally + { + buf.Release(); + } + } + + [Fact] + public void SetUsAsciiCharSequenceNoExpand() => Assert.Throws(() => this.SetCharSequenceNoExpand(Encoding.ASCII)); + + [Fact] + public void SetUtf8CharSequenceNoExpand() => Assert.Throws(() => this.SetCharSequenceNoExpand(Encoding.UTF8)); + + [Fact] + public void SetUtf16CharSequenceNoExpand() => Assert.Throws(() => this.SetCharSequenceNoExpand(Encoding.Unicode)); + + void SetCharSequenceNoExpand(Encoding encoding) + { + IByteBuffer buf = this.NewBuffer(1); + try + { + buf.SetCharSequence(0, new StringCharSequence("AB"), encoding); + } + finally + { + buf.Release(); + } + } + [Fact] public void SetUsAsciiStringNoExpand() => Assert.Throws(() => this.SetStringNoExpand(Encoding.ASCII)); @@ -2536,6 +2606,24 @@ namespace DotNetty.Buffers.Tests } } + [Fact] + public void SetUsAsciiCharSequence() => this.SetGetCharSequence(Encoding.ASCII); + + [Fact] + public void SetUtf8CharSequence() => this.SetGetCharSequence(Encoding.UTF8); + + [Fact] + public void SetUtf16CharSequence() => this.SetGetCharSequence(Encoding.Unicode); + + void SetGetCharSequence(Encoding encoding) + { + IByteBuffer buf = this.NewBuffer(16); + var sequence = new StringCharSequence("AB"); + int bytes = buf.SetCharSequence(1, sequence, encoding); + Assert.Equal(sequence, buf.GetCharSequence(1, bytes, encoding)); + buf.Release(); + } + [Fact] public void SetUsAsciiString() => this.SetGetString(Encoding.ASCII); diff --git a/test/DotNetty.Buffers.Tests/AbstractReferenceCountedByteBufferTests.cs b/test/DotNetty.Buffers.Tests/AbstractReferenceCountedByteBufferTests.cs index 2abddff..1c32959 100644 --- a/test/DotNetty.Buffers.Tests/AbstractReferenceCountedByteBufferTests.cs +++ b/test/DotNetty.Buffers.Tests/AbstractReferenceCountedByteBufferTests.cs @@ -128,10 +128,10 @@ namespace DotNetty.Buffers.Tests public override bool HasMemoryAddress => throw new NotSupportedException(); - public override ref byte GetPinnableMemoryAddress() => throw new NotSupportedException(); - public override IntPtr AddressOfPinnedMemory() => throw new NotSupportedException(); + public override ref byte GetPinnableMemoryAddress() => throw new NotSupportedException(); + public override IByteBuffer Unwrap() => throw new NotSupportedException(); public override bool IsDirect => throw new NotSupportedException(); diff --git a/test/DotNetty.Buffers.Tests/SimpleLeakAwareByteBufferTests.cs b/test/DotNetty.Buffers.Tests/SimpleLeakAwareByteBufferTests.cs index c87d55e..92ae259 100644 --- a/test/DotNetty.Buffers.Tests/SimpleLeakAwareByteBufferTests.cs +++ b/test/DotNetty.Buffers.Tests/SimpleLeakAwareByteBufferTests.cs @@ -30,7 +30,7 @@ namespace DotNetty.Buffers.Tests { base.Dispose(); - for (;;) + for (; ; ) { NoopResourceLeakTracker tracker = null; if (this.trackers.Count > 0) @@ -120,4 +120,4 @@ namespace DotNetty.Buffers.Tests } } } -} +} \ No newline at end of file diff --git a/test/DotNetty.Buffers.Tests/SimpleLeakAwareCompositeByteBufferTests.cs b/test/DotNetty.Buffers.Tests/SimpleLeakAwareCompositeByteBufferTests.cs index 37011af..36eb063 100644 --- a/test/DotNetty.Buffers.Tests/SimpleLeakAwareCompositeByteBufferTests.cs +++ b/test/DotNetty.Buffers.Tests/SimpleLeakAwareCompositeByteBufferTests.cs @@ -118,4 +118,4 @@ namespace DotNetty.Buffers.Tests } } } -} +} \ No newline at end of file diff --git a/test/DotNetty.Buffers.Tests/SlicedByteBufferTest.cs b/test/DotNetty.Buffers.Tests/SlicedByteBufferTest.cs index d34c64e..824b902 100644 --- a/test/DotNetty.Buffers.Tests/SlicedByteBufferTest.cs +++ b/test/DotNetty.Buffers.Tests/SlicedByteBufferTest.cs @@ -106,6 +106,15 @@ namespace DotNetty.Buffers.Tests Assert.Equal(0, slice2.ReferenceCount); } + [Fact] + public override void WriteUsAsciiCharSequenceExpand() => Assert.Throws(() => base.WriteUsAsciiCharSequenceExpand()); + + [Fact] + public override void WriteUtf8CharSequenceExpand() => Assert.Throws(() => base.WriteUtf8CharSequenceExpand()); + + [Fact] + public override void WriteUtf16CharSequenceExpand() => Assert.Throws(() => base.WriteUtf16CharSequenceExpand()); + [Fact] public void EnsureWritableWithEnoughSpaceShouldNotThrow() { diff --git a/test/DotNetty.Codecs.Http.Tests/CombinedHttpHeadersTest.cs b/test/DotNetty.Codecs.Http.Tests/CombinedHttpHeadersTest.cs new file mode 100644 index 0000000..7389c64 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/CombinedHttpHeadersTest.cs @@ -0,0 +1,368 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System.Collections.Generic; + using System.Linq; + using DotNetty.Common.Utilities; + using Xunit; + + using static Common.Utilities.AsciiString; + using static HttpHeadersTestUtils; + + public sealed class CombinedHttpHeadersTest + { + static readonly AsciiString HeaderName = new AsciiString("testHeader"); + + [Fact] + public void AddCharSequencesCsv() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, HeaderValue.Three.AsList()); + AssertCsvValues(headers, HeaderValue.Three); + } + + [Fact] + public void AddCharSequencesCsvWithExistingHeader() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, HeaderValue.Three.AsList()); + headers.Add(HeaderName, HeaderValue.Five.Subset(4)); + AssertCsvValues(headers, HeaderValue.Five); + } + + [Fact] + public void AddCombinedHeadersWhenEmpty() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + CombinedHttpHeaders otherHeaders = NewCombinedHttpHeaders(); + otherHeaders.Add(HeaderName, "a"); + otherHeaders.Add(HeaderName, "b"); + headers.Add(otherHeaders); + Assert.Equal("a,b", headers.Get(HeaderName, null)?.ToString()); + } + + [Fact] + public void AddCombinedHeadersWhenNotEmpty() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, "a"); + CombinedHttpHeaders otherHeaders = NewCombinedHttpHeaders(); + otherHeaders.Add(HeaderName, "b"); + otherHeaders.Add(HeaderName, "c"); + headers.Add(otherHeaders); + Assert.Equal("a,b,c", headers.Get(HeaderName, null)?.ToString()); + } + + [Fact] + public void SetCombinedHeadersWhenNotEmpty() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, "a"); + CombinedHttpHeaders otherHeaders = NewCombinedHttpHeaders(); + otherHeaders.Add(HeaderName, "b"); + otherHeaders.Add(HeaderName, "c"); + headers.Set(otherHeaders); + Assert.Equal("b,c", headers.Get(HeaderName, null)?.ToString()); + } + + [Fact] + public void AddUncombinedHeaders() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, "a"); + var otherHeaders = new DefaultHttpHeaders(); + otherHeaders.Add(HeaderName, "b"); + otherHeaders.Add(HeaderName, "c"); + headers.Add(otherHeaders); + Assert.Equal("a,b,c", headers.Get(HeaderName, null)?.ToString()); + } + + [Fact] + public void SetUncombinedHeaders() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, "a"); + var otherHeaders = new DefaultHttpHeaders(); + otherHeaders.Add(HeaderName, "b"); + otherHeaders.Add(HeaderName, "c"); + headers.Set(otherHeaders); + Assert.Equal("b,c", headers.Get(HeaderName, null)?.ToString()); + } + + [Fact] + public void AddCharSequencesCsvWithValueContainingComma() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, HeaderValue.SixQuoted.Subset(4)); + Assert.True(ContentEquals((StringCharSequence)HeaderValue.SixQuoted.SubsetAsCsvString(4), headers.Get(HeaderName, null))); + Assert.Equal(HeaderValue.SixQuoted.Subset(4), headers.GetAll(HeaderName)); + } + + [Fact] + public void AddCharSequencesCsvWithValueContainingCommas() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, HeaderValue.Eight.Subset(6)); + Assert.True(ContentEquals((StringCharSequence)HeaderValue.Eight.SubsetAsCsvString(6), headers.Get(HeaderName, null))); + Assert.Equal(HeaderValue.Eight.Subset(6), headers.GetAll(HeaderName)); + } + + [Fact] + public void AddCharSequencesCsvMultipleTimes() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + for (int i = 0; i < 5; ++i) + { + headers.Add(HeaderName, "value"); + } + Assert.True(ContentEquals((StringCharSequence)"value,value,value,value,value", headers.Get(HeaderName, null))); + } + + [Fact] + public void AddCharSequenceCsv() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + AddValues(headers, HeaderValue.One, HeaderValue.Two, HeaderValue.Three); + AssertCsvValues(headers, HeaderValue.Three); + } + + [Fact] + public void AddCharSequenceCsvSingleValue() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + AddValues(headers, HeaderValue.One); + AssertCsvValue(headers, HeaderValue.One); + } + + [Fact] + public void AddIterableCsv() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, HeaderValue.Three.AsList()); + AssertCsvValues(headers, HeaderValue.Three); + } + + [Fact] + public void AddIterableCsvWithExistingHeader() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, HeaderValue.Three.AsList()); + headers.Add(HeaderName, HeaderValue.Five.Subset(4)); + AssertCsvValues(headers, HeaderValue.Five); + } + + [Fact] + public void AddIterableCsvSingleValue() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, HeaderValue.One.AsList()); + AssertCsvValue(headers, HeaderValue.One); + } + + [Fact] + public void AddIterableCsvEmpty() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, new List()); + Assert.Equal(0, headers.GetAll(HeaderName).Count); + } + + [Fact] + public void AddObjectCsv() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + AddObjectValues(headers, HeaderValue.One, HeaderValue.Two, HeaderValue.Three); + AssertCsvValues(headers, HeaderValue.Three); + } + + [Fact] + public void AddObjectsCsv() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + List list = HeaderValue.Three.AsList(); + Assert.Equal(3, list.Count); + headers.Add(HeaderName, list); + AssertCsvValues(headers, HeaderValue.Three); + } + + [Fact] + public void AddObjectsIterableCsv() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, HeaderValue.Three.AsList()); + AssertCsvValues(headers, HeaderValue.Three); + } + + [Fact] + public void AddObjectsCsvWithExistingHeader() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Add(HeaderName, HeaderValue.Three.AsList()); + headers.Add(HeaderName, HeaderValue.Five.Subset(4)); + AssertCsvValues(headers, HeaderValue.Five); + } + + [Fact] + public void SetCharSequenceCsv() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Set(HeaderName, HeaderValue.Three.AsList()); + AssertCsvValues(headers, HeaderValue.Three); + } + + [Fact] + public void SetIterableCsv() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Set(HeaderName, HeaderValue.Three.AsList()); + AssertCsvValues(headers, HeaderValue.Three); + } + + [Fact] + public void SetObjectObjectsCsv() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Set(HeaderName, HeaderValue.Three.AsList()); + AssertCsvValues(headers, HeaderValue.Three); + } + + [Fact] + public void SetObjectIterableCsv() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Set(HeaderName, HeaderValue.Three.AsList()); + AssertCsvValues(headers, HeaderValue.Three); + } + + static CombinedHttpHeaders NewCombinedHttpHeaders() => new CombinedHttpHeaders(true); + + static void AssertCsvValues(CombinedHttpHeaders headers, HeaderValue headerValue) + { + Assert.True(ContentEquals(headerValue.AsCsv(), headers.Get(HeaderName, null))); + + List expected = headerValue.AsList(); + IList values = headers.GetAll(HeaderName); + + Assert.Equal(expected.Count, values.Count); + for (int i = 0; i < expected.Count; i++) + { + Assert.True(expected[i].ContentEquals(values[i])); + } + } + + static void AssertCsvValue(CombinedHttpHeaders headers, HeaderValue headerValue) + { + Assert.True(ContentEquals((StringCharSequence)headerValue.ToString(), headers.Get(HeaderName, null))); + Assert.True(ContentEquals((StringCharSequence)headerValue.ToString(), headers.GetAll(HeaderName)[0])); + } + + static void AddValues(CombinedHttpHeaders headers, params HeaderValue[] headerValues) + { + foreach (HeaderValue v in headerValues) + { + headers.Add(HeaderName, (StringCharSequence)v.ToString()); + } + } + + static void AddObjectValues(CombinedHttpHeaders headers, params HeaderValue[] headerValues) + { + foreach (HeaderValue v in headerValues) + { + headers.Add(HeaderName, v.ToString()); + } + } + + [Fact] + public void GetAll() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Set(HeaderName, new List { (StringCharSequence)"a", (StringCharSequence)"b", (StringCharSequence)"c" }); + var expected = new ICharSequence[] { (StringCharSequence)"a", (StringCharSequence)"b", (StringCharSequence)"c" }; + IList actual = headers.GetAll(HeaderName); + Assert.True(expected.SequenceEqual(actual)); + + headers.Set(HeaderName, new List { (StringCharSequence)"a,", (StringCharSequence)"b,", (StringCharSequence)"c," }); + expected = new ICharSequence[] { (StringCharSequence)"a,", (StringCharSequence)"b,", (StringCharSequence)"c," }; + actual = headers.GetAll(HeaderName); + Assert.True(expected.SequenceEqual(actual)); + + headers.Set(HeaderName, new List { (StringCharSequence)"a\"", (StringCharSequence)"b\"", (StringCharSequence)"c\"" }); + expected = new ICharSequence[] { (StringCharSequence)"a\"", (StringCharSequence)"b\"", (StringCharSequence)"c\"" }; + actual = headers.GetAll(HeaderName); + Assert.True(expected.SequenceEqual(actual)); + + headers.Set(HeaderName, new List { (StringCharSequence)"\"a\"", (StringCharSequence)"\"b\"", (StringCharSequence)"\"c\"" }); + expected = new ICharSequence[] { (StringCharSequence)"a", (StringCharSequence)"b", (StringCharSequence)"c" }; + actual = headers.GetAll(HeaderName); + Assert.True(expected.SequenceEqual(actual)); + + headers.Set(HeaderName, (StringCharSequence)"a,b,c"); + expected = new ICharSequence[] { (StringCharSequence)"a,b,c" }; + actual = headers.GetAll(HeaderName); + Assert.True(expected.SequenceEqual(actual)); + + headers.Set(HeaderName, (StringCharSequence)"\"a,b,c\""); + actual = headers.GetAll(HeaderName); + Assert.True(expected.SequenceEqual(actual)); + } + + [Fact] + public void OwsTrimming() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Set(HeaderName, new List { (StringCharSequence)"\ta", (StringCharSequence)" ", (StringCharSequence)" b ", (StringCharSequence)"\t \t"}); + headers.Add(HeaderName, new List { (StringCharSequence)" c, d \t" }); + + var expected = new List { (StringCharSequence)"a", (StringCharSequence)"", (StringCharSequence)"b", (StringCharSequence)"", (StringCharSequence)"c, d" }; + IList actual = headers.GetAll(HeaderName); + Assert.True(expected.SequenceEqual(actual)); + Assert.Equal("a,,b,,\"c, d\"", headers.Get(HeaderName, null)?.ToString()); + + Assert.True(headers.ContainsValue(HeaderName, (StringCharSequence)"a", true)); + Assert.True(headers.ContainsValue(HeaderName, (StringCharSequence)" a ", true)); + Assert.True(headers.ContainsValue(HeaderName, (StringCharSequence)"a", true)); + Assert.False(headers.ContainsValue(HeaderName, (StringCharSequence)"a,b", true)); + + Assert.False(headers.ContainsValue(HeaderName, (StringCharSequence)" c, d ", true)); + Assert.False(headers.ContainsValue(HeaderName, (StringCharSequence)"c, d", true)); + Assert.True(headers.ContainsValue(HeaderName, (StringCharSequence)" c ", true)); + Assert.True(headers.ContainsValue(HeaderName, (StringCharSequence)"d", true)); + + Assert.True(headers.ContainsValue(HeaderName, (StringCharSequence)"\t", true)); + Assert.True(headers.ContainsValue(HeaderName, (StringCharSequence)"", true)); + + Assert.False(headers.ContainsValue(HeaderName, (StringCharSequence)"e", true)); + + CombinedHttpHeaders copiedHeaders = NewCombinedHttpHeaders(); + copiedHeaders.Add(headers); + Assert.Equal(new List{ (StringCharSequence)"a", (StringCharSequence)"", (StringCharSequence)"b", (StringCharSequence)"", (StringCharSequence)"c, d" }, copiedHeaders.GetAll(HeaderName)); + } + + [Fact] + public void ValueIterator() + { + CombinedHttpHeaders headers = NewCombinedHttpHeaders(); + headers.Set(HeaderName, new List { (StringCharSequence)"\ta", (StringCharSequence)" ", (StringCharSequence)" b ", (StringCharSequence)"\t \t" }); + headers.Add(HeaderName, new List { (StringCharSequence)" c, d \t" }); + + var list = new List(headers.ValueCharSequenceIterator(new AsciiString("foo"))); + Assert.Empty(list); + AssertValueIterator(headers.ValueCharSequenceIterator(HeaderName)); + } + + static void AssertValueIterator(IEnumerable values) + { + var expected = new[] { "a", "", "b", "", "c, d" }; + int index = 0; + foreach (ICharSequence value in values) + { + Assert.True(index < expected.Length, "Wrong number of values"); + Assert.Equal(expected[index], value.ToString()); + index++; + } + Assert.Equal(expected.Length, index); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Cookies/ClientCookieDecoderTest.cs b/test/DotNetty.Codecs.Http.Tests/Cookies/ClientCookieDecoderTest.cs new file mode 100644 index 0000000..4fdf0b8 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Cookies/ClientCookieDecoderTest.cs @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.Cookies +{ + using System; + using System.Collections.Generic; + using DotNetty.Codecs.Http.Cookies; + using Xunit; + + public sealed class ClientCookieDecoderTest + { + [Fact] + public void DecodingSingleCookieV0() + { + string cookieString = "myCookie=myValue;expires=" + + DateFormatter.Format(DateTime.UtcNow.AddMilliseconds(50000)) + + ";path=/apathsomewhere;domain=.adomainsomewhere;secure;"; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(cookieString); + Assert.NotNull(cookie); + + Assert.Equal("myValue", cookie.Value); + Assert.Equal(".adomainsomewhere", cookie.Domain); + Assert.True(cookie.MaxAge >= 40 && cookie.MaxAge <= 60, "maxAge should be about 50ms when parsing cookie " + cookieString); + Assert.Equal("/apathsomewhere", cookie.Path); + Assert.True(cookie.IsSecure); + } + + [Fact] + public void DecodingSingleCookieV0ExtraParamsIgnored() + { + const string CookieString = "myCookie=myValue;max-age=50;path=/apathsomewhere;" + + "domain=.adomainsomewhere;secure;comment=this is a comment;version=0;" + + "commentURL=http://aurl.com;port=\"80,8080\";discard;"; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(CookieString); + Assert.NotNull(cookie); + + Assert.Equal("myValue", cookie.Value); + Assert.Equal(".adomainsomewhere", cookie.Domain); + Assert.Equal(50, cookie.MaxAge); + Assert.Equal("/apathsomewhere", cookie.Path); + Assert.True(cookie.IsSecure); + } + + [Fact] + public void DecodingSingleCookieV1() + { + const string CookieString = "myCookie=myValue;max-age=50;path=/apathsomewhere;domain=.adomainsomewhere" + + ";secure;comment=this is a comment;version=1;"; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(CookieString); + Assert.NotNull(cookie); + + Assert.Equal("myValue", cookie.Value); + Assert.Equal(".adomainsomewhere", cookie.Domain); + Assert.Equal(50, cookie.MaxAge); + Assert.Equal("/apathsomewhere", cookie.Path); + Assert.True(cookie.IsSecure); + } + + [Fact] + public void DecodingSingleCookieV1ExtraParamsIgnored() + { + const string CookieString = "myCookie=myValue;max-age=50;path=/apathsomewhere;" + + "domain=.adomainsomewhere;secure;comment=this is a comment;version=1;" + + "commentURL=http://aurl.com;port='80,8080';discard;"; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(CookieString); + Assert.NotNull(cookie); + + Assert.Equal("myValue", cookie.Value); + Assert.Equal(".adomainsomewhere", cookie.Domain); + Assert.Equal(50, cookie.MaxAge); + Assert.Equal("/apathsomewhere", cookie.Path); + Assert.True(cookie.IsSecure); + } + + [Fact] + public void DecodingSingleCookieV2() + { + const string CookieString = "myCookie=myValue;max-age=50;path=/apathsomewhere;" + + "domain=.adomainsomewhere;secure;comment=this is a comment;version=2;" + + "commentURL=http://aurl.com;port=\"80,8080\";discard;"; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(CookieString); + Assert.NotNull(cookie); + + Assert.Equal("myValue", cookie.Value); + Assert.Equal(".adomainsomewhere", cookie.Domain); + Assert.Equal(50, cookie.MaxAge); + Assert.Equal("/apathsomewhere", cookie.Path); + Assert.True(cookie.IsSecure); + } + + [Fact] + public void DecodingComplexCookie() + { + const string CookieString = "myCookie=myValue;max-age=50;path=/apathsomewhere;" + + "domain=.adomainsomewhere;secure;comment=this is a comment;version=2;" + + "commentURL=\"http://aurl.com\";port='80,8080';discard;"; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(CookieString); + Assert.NotNull(cookie); + + Assert.Equal("myValue", cookie.Value); + Assert.Equal(".adomainsomewhere", cookie.Domain); + Assert.Equal(50, cookie.MaxAge); + Assert.Equal("/apathsomewhere", cookie.Path); + Assert.True(cookie.IsSecure); + } + + [Fact] + public void DecodingQuotedCookie() + { + var sources = new List + { + "a=\"\",", + "b=\"1\"," + }; + + var cookies = new List(); + foreach (string source in sources) + { + cookies.Add(ClientCookieDecoder.StrictDecoder.Decode(source)); + } + Assert.Equal(2, cookies.Count); + + ICookie c = cookies[0]; + Assert.Equal("a", c.Name); + Assert.Equal("", c.Value); + + c = cookies[1]; + Assert.Equal("b", c.Name); + Assert.Equal("1", c.Value); + } + + [Fact] + public void DecodingGoogleAnalyticsCookie() + { + const string Source = "ARPT=LWUKQPSWRTUN04CKKJI; " + + "kw-2E343B92-B097-442c-BFA5-BE371E0325A2=unfinished furniture; " + + "__utma=48461872.1094088325.1258140131.1258140131.1258140131.1; " + + "__utmb=48461872.13.10.1258140131; __utmc=48461872; " + + "__utmz=48461872.1258140131.1.1.utmcsr=overstock.com|utmccn=(referral)|" + + "utmcmd=referral|utmcct=/Home-Garden/Furniture/Clearance,/clearance,/32/dept.html"; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(Source); + Assert.NotNull(cookie); + + Assert.Equal("ARPT", cookie.Name); + Assert.Equal("LWUKQPSWRTUN04CKKJI", cookie.Value); + } + + [Fact] + public void DecodingLongDates() + { + var cookieDate = new DateTime(9999, 12, 31, 23, 59, 59, DateTimeKind.Utc); + long expectedMaxAge = (cookieDate.Ticks - DateTime.UtcNow.Ticks) / TimeSpan.TicksPerSecond; + + const string Source = "Format=EU; expires=Fri, 31-Dec-9999 23:59:59 GMT; path=/"; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(Source); + Assert.NotNull(cookie); + + Assert.True(Math.Abs(expectedMaxAge - cookie.MaxAge) < 2); + } + + [Fact] + public void DecodingValueWithCommaFails() + { + const string Source = "UserCookie=timeZoneName=(GMT+04:00) Moscow, St. Petersburg, Volgograd&promocode=®ion=BE;" + + " expires=Sat, 01-Dec-2012 10:53:31 GMT; path=/"; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(Source); + Assert.Null(cookie); + } + + [Fact] + public void DecodingWeirdNames1() + { + const string Source = "path=; expires=Mon, 01-Jan-1990 00:00:00 GMT; path=/; domain=.www.google.com"; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(Source); + Assert.NotNull(cookie); + + Assert.Equal("path", cookie.Name); + Assert.Equal("", cookie.Value); + Assert.Equal("/", cookie.Path); + } + + [Fact] + public void DecodingWeirdNames2() + { + const string Source = "HTTPOnly="; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(Source); + Assert.NotNull(cookie); + + Assert.Equal("HTTPOnly", cookie.Name); + Assert.Equal("", cookie.Value); + } + + [Fact] + public void DecodingValuesWithCommasAndEqualsFails() + { + const string Source = "A=v=1&lg=en-US,it-IT,it&intl=it&np=1;T=z=E"; + ICookie cookie = ClientCookieDecoder.StrictDecoder.Decode(Source); + Assert.Null(cookie); + } + + [Fact] + public void DecodingLongValue() + { + const string LongValue = + "b___$Q__$ha______" + + "%=J^wI__3iD____$=HbQW__3iF____#=J^wI__3iH____%=J^wI__3iM____%=J^wI__3iS____" + + "#=J^wI__3iU____%=J^wI__3iZ____#=J^wI__3i]____%=J^wI__3ig____%=J^wI__3ij____" + + "%=J^wI__3ik____#=J^wI__3il____$=HbQW__3in____%=J^wI__3ip____$=HbQW__3iq____" + + "$=HbQW__3it____%=J^wI__3ix____#=J^wI__3j_____$=HbQW__3j%____$=HbQW__3j'____" + + "%=J^wI__3j(____%=J^wI__9mJ____'=KqtH__=SE__M____" + + "'=KqtH__s1X____$=MMyc__s1_____#=MN#O__ypn____'=KqtH__ypr____'=KqtH_#%h_____" + + "%=KqtH_#%o_____'=KqtH_#)H6______'=KqtH_#]9R____$=H/Lt_#]I6____#=KqtH_#]Z#____%=KqtH_#^*N____" + + "#=KqtH_#^:m____#=KqtH_#_*_____%=J^wI_#`-7____#=KqtH_#`T>____'=KqtH_#`T?____" + + "'=KqtH_#`TA____'=KqtH_#`TB____'=KqtH_#`TG____'=KqtH_#`TP____#=KqtH_#`U_____" + + "'=KqtH_#`U/____'=KqtH_#`U0____#=KqtH_#`U9____'=KqtH_#aEQ____%=KqtH_#b<)____" + + "'=KqtH_#c9-____%=KqtH_#dxC____%=KqtH_#dxE____%=KqtH_#ev$____'=KqtH_#fBi____" + + "#=KqtH_#fBj____'=KqtH_#fG)____'=KqtH_#fG+____'=KqtH_#g*B____'=KqtH_$>hD____+=J^x0_$?lW____'=KqtH_$?ll____'=KqtH_$?lm____" + + "%=KqtH_$?mi____'=KqtH_$?mx____'=KqtH_$D7]____#=J_#p_$D@T____#=J_#p_$V + Assert.Throws(() => ClientCookieEncoder.StrictEncoder.Encode(new DefaultCookie("myCookie", "foo;bar"))); + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Cookies/ServerCookieDecoderTest.cs b/test/DotNetty.Codecs.Http.Tests/Cookies/ServerCookieDecoderTest.cs new file mode 100644 index 0000000..76679ae --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Cookies/ServerCookieDecoderTest.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.Cookies +{ + using System.Collections.Generic; + using System.Linq; + using DotNetty.Codecs.Http.Cookies; + using Xunit; + + public sealed class ServerCookieDecoderTest + { + [Fact] + public void DecodingSingleCookie() + { + const string CookieString = "myCookie=myValue"; + ISet cookies = ServerCookieDecoder.StrictDecoder.Decode(CookieString); + + Assert.NotNull(cookies); + Assert.Equal(1, cookies.Count); + ICookie cookie = cookies.First(); + Assert.Equal("myValue", cookie.Value); + } + + [Fact] + public void DecodingMultipleCookies() + { + const string C1 = "myCookie=myValue;"; + const string C2 = "myCookie2=myValue2;"; + const string C3 = "myCookie3=myValue3;"; + + ISet cookies = ServerCookieDecoder.StrictDecoder.Decode(C1 + C2 + C3); + Assert.NotNull(cookies); + Assert.Equal(3, cookies.Count); + + List list = cookies.ToList(); + Assert.Equal("myValue", list[0].Value); + Assert.Equal("myValue2", list[1].Value); + Assert.Equal("myValue3", list[2].Value); + } + + + [Fact] + public void DecodingGoogleAnalyticsCookie() + { + const string Source = "ARPT=LWUKQPSWRTUN04CKKJI; " + + "kw-2E343B92-B097-442c-BFA5-BE371E0325A2=unfinished_furniture; " + + "__utma=48461872.1094088325.1258140131.1258140131.1258140131.1; " + + "__utmb=48461872.13.10.1258140131; __utmc=48461872; " + + "__utmz=48461872.1258140131.1.1.utmcsr=overstock.com|utmccn=(referral)|" + + "utmcmd=referral|utmcct=/Home-Garden/Furniture/Clearance/clearance/32/dept.html"; + + ISet cookies = ServerCookieDecoder.StrictDecoder.Decode(Source); + Assert.NotNull(cookies); + Assert.Equal(6, cookies.Count); + + List list = cookies.ToList(); + + ICookie c = list[0]; + Assert.Equal("ARPT", c.Name); + Assert.Equal("LWUKQPSWRTUN04CKKJI", c.Value); + + c = list[1]; + Assert.Equal("__utma", c.Name); + Assert.Equal("48461872.1094088325.1258140131.1258140131.1258140131.1", c.Value); + + c = list[2]; + Assert.Equal("__utmb", c.Name); + Assert.Equal("48461872.13.10.1258140131", c.Value); + + c = list[3]; + Assert.Equal("__utmc", c.Name); + Assert.Equal("48461872", c.Value); + + c = list[4]; + Assert.Equal("__utmz", c.Name); + Assert.Equal("48461872.1258140131.1.1.utmcsr=overstock.com|" + + "utmccn=(referral)|utmcmd=referral|utmcct=/Home-Garden/Furniture/Clearance/clearance/32/dept.html", + c.Value); + + c = list[5]; + Assert.Equal("kw-2E343B92-B097-442c-BFA5-BE371E0325A2", c.Name); + Assert.Equal("unfinished_furniture", c.Value); + } + + [Fact] + public void DecodingLongValue() + { + const string LongValue = + "b___$Q__$ha______" + + "%=J^wI__3iD____$=HbQW__3iF____#=J^wI__3iH____%=J^wI__3iM____%=J^wI__3iS____" + + "#=J^wI__3iU____%=J^wI__3iZ____#=J^wI__3i]____%=J^wI__3ig____%=J^wI__3ij____" + + "%=J^wI__3ik____#=J^wI__3il____$=HbQW__3in____%=J^wI__3ip____$=HbQW__3iq____" + + "$=HbQW__3it____%=J^wI__3ix____#=J^wI__3j_____$=HbQW__3j%____$=HbQW__3j'____" + + "%=J^wI__3j(____%=J^wI__9mJ____'=KqtH__=SE__M____" + + "'=KqtH__s1X____$=MMyc__s1_____#=MN#O__ypn____'=KqtH__ypr____'=KqtH_#%h_____" + + "%=KqtH_#%o_____'=KqtH_#)H6______'=KqtH_#]9R____$=H/Lt_#]I6____#=KqtH_#]Z#____%=KqtH_#^*N____" + + "#=KqtH_#^:m____#=KqtH_#_*_____%=J^wI_#`-7____#=KqtH_#`T>____'=KqtH_#`T?____" + + "'=KqtH_#`TA____'=KqtH_#`TB____'=KqtH_#`TG____'=KqtH_#`TP____#=KqtH_#`U_____" + + "'=KqtH_#`U/____'=KqtH_#`U0____#=KqtH_#`U9____'=KqtH_#aEQ____%=KqtH_#b<)____" + + "'=KqtH_#c9-____%=KqtH_#dxC____%=KqtH_#dxE____%=KqtH_#ev$____'=KqtH_#fBi____" + + "#=KqtH_#fBj____'=KqtH_#fG)____'=KqtH_#fG+____'=KqtH_#g*B____'=KqtH_$>hD____+=J^x0_$?lW____'=KqtH_$?ll____'=KqtH_$?lm____" + + "%=KqtH_$?mi____'=KqtH_$?mx____'=KqtH_$D7]____#=J_#p_$D@T____#=J_#p_$V cookies = ServerCookieDecoder.StrictDecoder.Decode("bh=\"" + LongValue + "\";"); + Assert.NotNull(cookies); + Assert.Equal(1, cookies.Count); + + ICookie c = cookies.First(); + Assert.Equal("bh", c.Name); + Assert.Equal(LongValue, c.Value); + } + + [Fact] + public void DecodingOldRFC2965Cookies() + { + const string Source = "$Version=\"1\"; " + + "Part_Number1=\"Riding_Rocket_0023\"; $Path=\"/acme/ammo\"; " + + "Part_Number2=\"Rocket_Launcher_0001\"; $Path=\"/acme\""; + + ISet cookies = ServerCookieDecoder.StrictDecoder.Decode(Source); + Assert.NotNull(cookies); + Assert.Equal(2, cookies.Count); + + List list = cookies.ToList(); + ICookie c = list[0]; + Assert.Equal("Part_Number1", c.Name); + Assert.Equal("Riding_Rocket_0023", c.Value); + + c = list[1]; + Assert.Equal("Part_Number2", c.Name); + Assert.Equal("Rocket_Launcher_0001", c.Value); + } + + [Fact] + public void RejectCookieValueWithSemicolon() + { + ISet cookies = ServerCookieDecoder.StrictDecoder.Decode("name=\"foo;bar\";"); + Assert.NotNull(cookies); + Assert.Equal(0, cookies.Count); + } + + [Fact] + public void CaseSensitiveNames() + { + ISet cookies = ServerCookieDecoder.StrictDecoder.Decode("session_id=a; Session_id=b;"); + Assert.NotNull(cookies); + Assert.Equal(2, cookies.Count); + + List list = cookies.ToList(); + ICookie c = list[0]; + + Assert.Equal("Session_id", c.Name); + Assert.Equal("b", c.Value); + + c = list[1]; + Assert.Equal("session_id", c.Name); + Assert.Equal("a", c.Value); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Cookies/ServerCookieEncoderTest.cs b/test/DotNetty.Codecs.Http.Tests/Cookies/ServerCookieEncoderTest.cs new file mode 100644 index 0000000..d0ba900 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Cookies/ServerCookieEncoderTest.cs @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.Cookies +{ + using System; + using System.Collections.Generic; + using System.Text.RegularExpressions; + using DotNetty.Codecs.Http.Cookies; + using Xunit; + + public sealed class ServerCookieEncoderTest + { + [Fact] + public void EncodingSingleCookieV0() + { + const int MaxAge = 50; + const string Result = "myCookie=myValue; Max-Age=50; Expires=(.+?); Path=/apathsomewhere; Domain=.adomainsomewhere; Secure"; + + var cookie = new DefaultCookie("myCookie", "myValue") + { + Domain = ".adomainsomewhere", + MaxAge = MaxAge, + Path = "/apathsomewhere", + IsSecure = true + }; + string encodedCookie = ServerCookieEncoder.StrictEncoder.Encode(cookie); + + var regex = new Regex(Result, RegexOptions.Compiled); + MatchCollection matches = regex.Matches(encodedCookie); + Assert.Single(matches); + + Match match = matches[0]; + Assert.NotNull(match); + + DateTime? expiresDate = DateFormatter.ParseHttpDate(match.Groups[1].Value); + Assert.True(expiresDate.HasValue, $"Parse http date failed : {match.Groups[1].Value}"); + long diff = (expiresDate.Value.Ticks - DateTime.UtcNow.Ticks) / TimeSpan.TicksPerSecond; + // 2 secs should be fine + Assert.True(Math.Abs(diff - MaxAge) <= 2, $"Expecting difference of MaxAge < 2s, but was {diff}s (MaxAge={MaxAge})"); + } + + [Fact] + public void EncodingWithNoCookies() + { + string encodedCookie1 = ClientCookieEncoder.StrictEncoder.Encode(default(ICookie[])); + IList encodedCookie2 = ServerCookieEncoder.StrictEncoder.Encode(default(ICookie[])); + + Assert.Null(encodedCookie1); + Assert.NotNull(encodedCookie2); + Assert.Empty(encodedCookie2); + } + + [Fact] + public void EncodingMultipleCookiesStrict() + { + var result = new List + { + "cookie2=value2", + "cookie1=value3" + }; + ICookie cookie1 = new DefaultCookie("cookie1", "value1"); + ICookie cookie2 = new DefaultCookie("cookie2", "value2"); + ICookie cookie3 = new DefaultCookie("cookie1", "value3"); + + IList encodedCookies = ServerCookieEncoder.StrictEncoder.Encode(cookie1, cookie2, cookie3); + Assert.Equal(result, encodedCookies); + } + + [Fact] + public void IllegalCharInCookieNameMakesStrictEncoderThrowsException() + { + var illegalChars = new HashSet(); + + // CTLs + for (int i = 0x00; i <= 0x1F; i++) + { + illegalChars.Add((char)i); + } + illegalChars.Add((char)0x7F); + + var separaters = new [] + { + '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', + '?', '=', '{', '}', ' ', '\t' + }; + + // separators + foreach(char c in separaters) + { + illegalChars.Add(c); + } + + foreach (char c in illegalChars) + { + Assert.Throws(() => ServerCookieEncoder.StrictEncoder.Encode( + new DefaultCookie("foo" + c + "bar", "value"))); + } + } + + [Fact] + public void IllegalCharInCookieValueMakesStrictEncoderThrowsException() + { + var illegalChars = new HashSet(); + // CTLs + for (int i = 0x00; i <= 0x1F; i++) + { + illegalChars.Add((char)i); + } + illegalChars.Add((char)0x7F); + + + // whitespace, DQUOTE, comma, semicolon, and backslash + var separaters = new[] + { + ' ', '"', ',', ';', '\\' + }; + + foreach(char c in separaters) + { + illegalChars.Add(c); + } + + foreach (char c in illegalChars) + { + Assert.Throws(() => ServerCookieEncoder.StrictEncoder.Encode( + new DefaultCookie("name", "value" + c))); + } + } + + [Fact] + public void EncodingMultipleCookiesLax() + { + var result = new List + { + "cookie1=value1", + "cookie2=value2", + "cookie1=value3" + }; + + ICookie cookie1 = new DefaultCookie("cookie1", "value1"); + ICookie cookie2 = new DefaultCookie("cookie2", "value2"); + ICookie cookie3 = new DefaultCookie("cookie1", "value3"); + IList encodedCookies = ServerCookieEncoder.LaxEncoder.Encode(cookie1, cookie2, cookie3); + + Assert.Equal(result, encodedCookies); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Cors/CorsConfigTest.cs b/test/DotNetty.Codecs.Http.Tests/Cors/CorsConfigTest.cs new file mode 100644 index 0000000..7edf70d --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Cors/CorsConfigTest.cs @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.Cors +{ + using System.Collections.Generic; + using DotNetty.Codecs.Http.Cors; + using DotNetty.Common.Utilities; + using Xunit; + + public sealed class CorsConfigTest + { + [Fact] + public void Disabled() + { + CorsConfig cors = CorsConfigBuilder.ForAnyOrigin().Disable().Build(); + Assert.False(cors.IsCorsSupportEnabled); + } + + [Fact] + public void AnyOrigin() + { + CorsConfig cors = CorsConfigBuilder.ForAnyOrigin().Build(); + Assert.True(cors.IsAnyOriginSupported); + Assert.Equal("*", cors.Origin.ToString()); + Assert.Equal(0, cors.Origins.Count); + } + + [Fact] + public void WildcardOrigin() + { + CorsConfig cors = CorsConfigBuilder.ForOrigin(CorsHandler.AnyOrigin).Build(); + Assert.True(cors.IsAnyOriginSupported); + Assert.Equal("*", cors.Origin.ToString()); + Assert.Equal(0, cors.Origins.Count); + } + + [Fact] + public void Origin() + { + CorsConfig cors = CorsConfigBuilder.ForOrigin((StringCharSequence)"http://localhost:7888").Build(); + Assert.Equal("http://localhost:7888", cors.Origin.ToString()); + Assert.False(cors.IsAnyOriginSupported); + } + + [Fact] + public void Origins() + { + ICharSequence[] origins = { (StringCharSequence)"http://localhost:7888", (StringCharSequence)"https://localhost:7888"}; + CorsConfig cors = CorsConfigBuilder.ForOrigins(origins).Build(); + Assert.Equal(2, cors.Origins.Count); + Assert.True(cors.Origins.Contains(origins[0])); + Assert.True(cors.Origins.Contains(origins[1])); + Assert.False(cors.IsAnyOriginSupported); + } + + [Fact] + public void ExposeHeaders() + { + CorsConfig cors = CorsConfigBuilder.ForAnyOrigin() + .ExposeHeaders((StringCharSequence)"custom-header1", (StringCharSequence)"custom-header2").Build(); + Assert.True(cors.ExposedHeaders().Contains((StringCharSequence)"custom-header1")); + Assert.True(cors.ExposedHeaders().Contains((StringCharSequence)"custom-header2")); + } + + [Fact] + public void AllowCredentials() + { + CorsConfig cors = CorsConfigBuilder.ForAnyOrigin().AllowCredentials().Build(); + Assert.True(cors.IsCredentialsAllowed); + } + + [Fact] + public void MaxAge() + { + CorsConfig cors = CorsConfigBuilder.ForAnyOrigin().MaxAge(3000).Build(); + Assert.Equal(3000, cors.MaxAge); + } + + [Fact] + public void RequestMethods() + { + CorsConfig cors = CorsConfigBuilder.ForAnyOrigin() + .AllowedRequestMethods(HttpMethod.Post, HttpMethod.Get).Build(); + Assert.True(cors.AllowedRequestMethods().Contains(HttpMethod.Post)); + Assert.True(cors.AllowedRequestMethods().Contains(HttpMethod.Get)); + } + + [Fact] + public void RequestHeaders() + { + CorsConfig cors = CorsConfigBuilder.ForAnyOrigin() + .AllowedRequestHeaders((AsciiString)"preflight-header1", (AsciiString)"preflight-header2").Build(); + Assert.True(cors.AllowedRequestHeaders().Contains((AsciiString)"preflight-header1")); + Assert.True(cors.AllowedRequestHeaders().Contains((AsciiString)"preflight-header2")); + } + + [Fact] + public void PreflightResponseHeadersSingleValue() + { + CorsConfig cors = CorsConfigBuilder.ForAnyOrigin() + .PreflightResponseHeader((AsciiString)"SingleValue", (StringCharSequence)"value").Build(); + Assert.Equal((AsciiString)"value", cors.PreflightResponseHeaders().Get((AsciiString)"SingleValue", null)); + } + + [Fact] + public void PreflightResponseHeadersMultipleValues() + { + CorsConfig cors = CorsConfigBuilder.ForAnyOrigin() + .PreflightResponseHeader((AsciiString)"MultipleValues", (StringCharSequence)"value1", (StringCharSequence)"value2").Build(); + IList values = cors.PreflightResponseHeaders().GetAll((AsciiString)"MultipleValues"); + Assert.NotNull(values); + Assert.True(values.Contains((AsciiString)"value1")); + Assert.True(values.Contains((AsciiString)"value2")); + } + + [Fact] + public void DefaultPreflightResponseHeaders() + { + CorsConfig cors = CorsConfigBuilder.ForAnyOrigin().Build(); + Assert.NotNull(cors.PreflightResponseHeaders().Get(HttpHeaderNames.Date, null)); + Assert.Equal("0", cors.PreflightResponseHeaders().Get(HttpHeaderNames.ContentLength, null)); + } + + [Fact] + public void EmptyPreflightResponseHeaders() + { + CorsConfig cors = CorsConfigBuilder.ForAnyOrigin().NoPreflightResponseHeaders().Build(); + Assert.Same(EmptyHttpHeaders.Default, cors.PreflightResponseHeaders()); + } + + [Fact] + public void ShortCircuit() + { + CorsConfig cors = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8080") + .ShortCircuit().Build(); + Assert.True(cors.IsShortCircuit); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Cors/CorsHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/Cors/CorsHandlerTest.cs new file mode 100644 index 0000000..180dc0b --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Cors/CorsHandlerTest.cs @@ -0,0 +1,432 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.Cors +{ + using System.Collections.Generic; + using DotNetty.Codecs.Http.Cors; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class CorsHandlerTest + { + [Fact] + public void NonCorsRequest() + { + IHttpResponse response = SimpleRequest(CorsConfigBuilder.ForAnyOrigin().Build(), null); + Assert.False(response.Headers.Contains(HttpHeaderNames.AccessControlAllowOrigin)); + } + + [Fact] + public void SimpleRequestWithAnyOrigin() + { + IHttpResponse response = SimpleRequest(CorsConfigBuilder.ForAnyOrigin().Build(), "http://localhost:7777"); + Assert.Equal("*", response.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null).ToString()); + Assert.Null(response.Headers.Get(HttpHeaderNames.AccessControlAllowHeaders, null)); + } + + [Fact] + public void SimpleRequestWithNullOrigin() + { + IHttpResponse response = SimpleRequest(CorsConfigBuilder.ForOrigin((AsciiString)"http://test.com") + .AllowNullOrigin() + .AllowCredentials() + .Build(), "null"); + Assert.Equal("null", response.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null).ToString()); + Assert.Equal("true", response.Headers.Get(HttpHeaderNames.AccessControlAllowCredentials, null).ToString()); + Assert.Null(response.Headers.Get(HttpHeaderNames.AccessControlAllowHeaders, null)); + } + + [Fact] + public void SimpleRequestWithOrigin() + { + var origin = new AsciiString("http://localhost:8888"); + IHttpResponse response = SimpleRequest(CorsConfigBuilder.ForOrigin(origin).Build(), origin.ToString()); + Assert.Equal(origin, response.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null)); + Assert.Null(response.Headers.Get(HttpHeaderNames.AccessControlAllowHeaders, null)); + } + + [Fact] + public void SimpleRequestWithOrigins() + { + var origin1 = new AsciiString("http://localhost:8888"); + var origin2 = new AsciiString("https://localhost:8888"); + ICharSequence[] origins = { origin1, origin2}; + IHttpResponse response1 = SimpleRequest(CorsConfigBuilder.ForOrigins(origins).Build(), origin1.ToString()); + Assert.Equal(origin1, response1.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null)); + Assert.Null(response1.Headers.Get(HttpHeaderNames.AccessControlAllowHeaders, null)); + IHttpResponse response2 = SimpleRequest(CorsConfigBuilder.ForOrigins(origins).Build(), origin2.ToString()); + Assert.Equal(origin2, response2.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null)); + Assert.Null(response2.Headers.Get(HttpHeaderNames.AccessControlAllowHeaders, null)); + } + + [Fact] + public void SimpleRequestWithNoMatchingOrigin() + { + var origin = new AsciiString("http://localhost:8888"); + IHttpResponse response = SimpleRequest(CorsConfigBuilder.ForOrigins( + new AsciiString("https://localhost:8888")).Build(), origin.ToString()); + Assert.Null(response.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null)); + Assert.Null(response.Headers.Get(HttpHeaderNames.AccessControlAllowHeaders, null)); + } + + [Fact] + public void PreflightDeleteRequestWithCustomHeaders() + { + CorsConfig config = CorsConfigBuilder.ForOrigin( + new AsciiString("http://localhost:8888")).AllowedRequestMethods(HttpMethod.Get, HttpMethod.Delete).Build(); + IHttpResponse response = PreflightRequest(config, "http://localhost:8888", "content-type, xheader1"); + Assert.Equal("http://localhost:8888", response.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null)); + Assert.Contains("GET", response.Headers.Get(HttpHeaderNames.AccessControlAllowMethods, null).ToString()); + Assert.Contains("DELETE", response.Headers.Get(HttpHeaderNames.AccessControlAllowMethods, null).ToString()); + Assert.Equal(HttpHeaderNames.Origin.ToString(), response.Headers.Get(HttpHeaderNames.Vary, null)); + } + + [Fact] + public void PreflightRequestWithCustomHeaders() + { + const string HeaderName = "CustomHeader"; + const string Value1 = "value1"; + const string Value2 = "value2"; + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8888") + .PreflightResponseHeader((AsciiString)HeaderName, (AsciiString)Value1, (AsciiString)Value2).Build(); + IHttpResponse response = PreflightRequest(config, "http://localhost:8888", "content-type, xheader1"); + AssertValues(response, HeaderName, Value1, Value2); + Assert.Equal(HttpHeaderNames.Origin.ToString(), response.Headers.Get(HttpHeaderNames.Vary, null)); + Assert.Equal("0", response.Headers.Get(HttpHeaderNames.ContentLength, null).ToString()); + } + + [Fact] + public void PreflightRequestWithCustomHeadersIterable() + { + const string HeaderName = "CustomHeader"; + const string Value1 = "value1"; + const string Value2 = "value2"; + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8888") + .PreflightResponseHeader((AsciiString)HeaderName, new List { (AsciiString)Value1, (AsciiString)Value2 }) + .Build(); + IHttpResponse response = PreflightRequest(config, "http://localhost:8888", "content-type, xheader1"); + AssertValues(response, HeaderName, Value1, Value2); + Assert.Equal(HttpHeaderNames.Origin.ToString(), response.Headers.Get(HttpHeaderNames.Vary, null)); + } + + class ValueGenerator : ICallable + { + public object Call() => "generatedValue"; + } + + [Fact] + public void PreflightRequestWithValueGenerator() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8888") + .PreflightResponseHeader((AsciiString)"GenHeader", new ValueGenerator()).Build(); + IHttpResponse response = PreflightRequest(config, "http://localhost:8888", "content-type, xheader1"); + Assert.Equal("generatedValue", response.Headers.Get((AsciiString)"GenHeader", null).ToString()); + Assert.Equal(HttpHeaderNames.Origin.ToString(), response.Headers.Get(HttpHeaderNames.Vary, null)); + } + + [Fact] + public void PreflightRequestWithNullOrigin() + { + var origin = new AsciiString("null"); + CorsConfig config = CorsConfigBuilder.ForOrigin(origin) + .AllowNullOrigin() + .AllowCredentials() + .Build(); + IHttpResponse response = PreflightRequest(config, origin.ToString(), "content-type, xheader1"); + Assert.Equal("null", response.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null)); + Assert.Equal("true", response.Headers.Get(HttpHeaderNames.AccessControlAllowCredentials, null)); + } + + [Fact] + public void PreflightRequestAllowCredentials() + { + var origin = new AsciiString("null"); + CorsConfig config = CorsConfigBuilder.ForOrigin(origin).AllowCredentials().Build(); + IHttpResponse response = PreflightRequest(config, origin.ToString(), "content-type, xheader1"); + Assert.Equal("true", response.Headers.Get(HttpHeaderNames.AccessControlAllowCredentials, null)); + } + + [Fact] + public void PreflightRequestDoNotAllowCredentials() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8888").Build(); + IHttpResponse response = PreflightRequest(config, "http://localhost:8888", ""); + // the only valid value for Access-Control-Allow-Credentials is true. + Assert.False(response.Headers.Contains(HttpHeaderNames.AccessControlAllowCredentials)); + } + + [Fact] + public void SimpleRequestCustomHeaders() + { + CorsConfig config = CorsConfigBuilder.ForAnyOrigin() + .ExposeHeaders((AsciiString)"custom1", (AsciiString)"custom2").Build(); + IHttpResponse response = SimpleRequest(config, "http://localhost:7777"); + Assert.Equal("*", response.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null)); + Assert.Contains("custom1", response.Headers.Get(HttpHeaderNames.AccessControlExposeHeaders, null).ToString()); + Assert.Contains("custom2", response.Headers.Get(HttpHeaderNames.AccessControlExposeHeaders, null).ToString()); + } + + [Fact] + public void SimpleRequestAllowCredentials() + { + CorsConfig config = CorsConfigBuilder.ForAnyOrigin().AllowCredentials().Build(); + IHttpResponse response = SimpleRequest(config, "http://localhost:7777"); + Assert.Equal("true", response.Headers.Get(HttpHeaderNames.AccessControlAllowCredentials, null)); + } + + [Fact] + public void SimpleRequestDoNotAllowCredentials() + { + CorsConfig config = CorsConfigBuilder.ForAnyOrigin().Build(); + IHttpResponse response = SimpleRequest(config, "http://localhost:7777"); + Assert.False(response.Headers.Contains(HttpHeaderNames.AccessControlAllowCredentials)); + } + + [Fact] + public void AnyOriginAndAllowCredentialsShouldEchoRequestOrigin() + { + CorsConfig config = CorsConfigBuilder.ForAnyOrigin().AllowCredentials().Build(); + IHttpResponse response = SimpleRequest(config, "http://localhost:7777"); + Assert.Equal("true", response.Headers.Get(HttpHeaderNames.AccessControlAllowCredentials, null)); + Assert.Equal("http://localhost:7777", response.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null).ToString()); + Assert.Equal(HttpHeaderNames.Origin.ToString(), response.Headers.Get(HttpHeaderNames.Vary, null)); + } + + [Fact] + public void SimpleRequestExposeHeaders() + { + CorsConfig config = CorsConfigBuilder.ForAnyOrigin() + .ExposeHeaders((AsciiString)"one", (AsciiString)"two").Build(); + IHttpResponse response = SimpleRequest(config, "http://localhost:7777"); + Assert.Contains("one", response.Headers.Get(HttpHeaderNames.AccessControlExposeHeaders, null).ToString()); + Assert.Contains("two", response.Headers.Get(HttpHeaderNames.AccessControlExposeHeaders, null).ToString()); + } + + [Fact] + public void SimpleRequestShortCircuit() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8080") + .ShortCircuit().Build(); + IHttpResponse response = SimpleRequest(config, "http://localhost:7777"); + Assert.Equal(HttpResponseStatus.Forbidden, response.Status); + Assert.Equal("0", response.Headers.Get(HttpHeaderNames.ContentLength, null).ToString()); + } + + [Fact] + public void SimpleRequestNoShortCircuit() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8080").Build(); + IHttpResponse response = SimpleRequest(config, "http://localhost:7777"); + Assert.Equal(HttpResponseStatus.OK, response.Status); + Assert.Null(response.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null)); + } + + [Fact] + public void ShortCircuitNonCorsRequest() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"https://localhost") + .ShortCircuit().Build(); + IHttpResponse response = SimpleRequest(config, null); + Assert.Equal(HttpResponseStatus.OK, response.Status); + Assert.Null(response.Headers.Get(HttpHeaderNames.AccessControlAllowOrigin, null)); + } + + [Fact] + public void ShortCircuitWithConnectionKeepAliveShouldStayOpen() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8080") + .ShortCircuit().Build(); + var channel = new EmbeddedChannel(new CorsHandler(config)); + IFullHttpRequest request = CreateHttpRequest(HttpMethod.Get); + request.Headers.Set(HttpHeaderNames.Origin, (AsciiString)"http://localhost:8888"); + request.Headers.Set(HttpHeaderNames.Connection, HttpHeaderValues.KeepAlive); + + Assert.False(channel.WriteInbound(request)); + var response = channel.ReadOutbound(); + Assert.True(HttpUtil.IsKeepAlive(response)); + + Assert.True(channel.Open); + Assert.Equal(HttpResponseStatus.Forbidden, response.Status); + Assert.True(ReferenceCountUtil.Release(response)); + Assert.False(channel.Finish()); + } + + [Fact] + public void ShortCircuitWithoutConnectionShouldStayOpen() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8080") + .ShortCircuit().Build(); + var channel = new EmbeddedChannel(new CorsHandler(config)); + IFullHttpRequest request = CreateHttpRequest(HttpMethod.Get); + request.Headers.Set(HttpHeaderNames.Origin, (AsciiString)"http://localhost:8888"); + + Assert.False(channel.WriteInbound(request)); + var response = channel.ReadOutbound(); + Assert.True(HttpUtil.IsKeepAlive(response)); + + Assert.True(channel.Open); + Assert.Equal(HttpResponseStatus.Forbidden, response.Status); + Assert.True(ReferenceCountUtil.Release(response)); + Assert.False(channel.Finish()); + } + + [Fact] + public void ShortCircuitWithConnectionCloseShouldClose() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8080") + .ShortCircuit().Build(); + var channel = new EmbeddedChannel(new CorsHandler(config)); + IFullHttpRequest request = CreateHttpRequest(HttpMethod.Get); + request.Headers.Set(HttpHeaderNames.Origin, "http://localhost:8888"); + request.Headers.Set(HttpHeaderNames.Connection, HttpHeaderValues.Close); + + Assert.False(channel.WriteInbound(request)); + var response = channel.ReadOutbound(); + Assert.False(HttpUtil.IsKeepAlive(response)); + + Assert.False(channel.Open); + Assert.Equal(HttpResponseStatus.Forbidden, response.Status); + Assert.True(ReferenceCountUtil.Release(response)); + Assert.False(channel.Finish()); + } + + [Fact] + public void PreflightRequestShouldReleaseRequest() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8888") + .PreflightResponseHeader((AsciiString)"CustomHeader", new List{(AsciiString)"value1", (AsciiString)"value2"}) + .Build(); + var channel = new EmbeddedChannel(new CorsHandler(config)); + IFullHttpRequest request = OptionsRequest("http://localhost:8888", "content-type, xheader1", null); + Assert.False(channel.WriteInbound(request)); + Assert.Equal(0, request.ReferenceCount); + Assert.True(ReferenceCountUtil.Release(channel.ReadOutbound())); + Assert.False(channel.Finish()); + } + + [Fact] + public void PreflightRequestWithConnectionKeepAliveShouldStayOpen() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8888").Build(); + var channel = new EmbeddedChannel(new CorsHandler(config)); + IFullHttpRequest request = OptionsRequest("http://localhost:8888", "", HttpHeaderValues.KeepAlive); + Assert.False(channel.WriteInbound(request)); + var response = channel.ReadOutbound(); + Assert.True(HttpUtil.IsKeepAlive(response)); + + Assert.True(channel.Open); + Assert.Equal(HttpResponseStatus.OK, response.Status); + Assert.True(ReferenceCountUtil.Release(response)); + Assert.False(channel.Finish()); + } + + [Fact] + public void PreflightRequestWithoutConnectionShouldStayOpen() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8888").Build(); + var channel = new EmbeddedChannel(new CorsHandler(config)); + IFullHttpRequest request = OptionsRequest("http://localhost:8888", "", null); + Assert.False(channel.WriteInbound(request)); + var response = channel.ReadOutbound(); + Assert.True(HttpUtil.IsKeepAlive(response)); + + Assert.True(channel.Open); + Assert.Equal(HttpResponseStatus.OK, response.Status); + Assert.True(ReferenceCountUtil.Release(response)); + Assert.False(channel.Finish()); + } + + [Fact] + public void PreflightRequestWithConnectionCloseShouldClose() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"http://localhost:8888").Build(); + var channel = new EmbeddedChannel(new CorsHandler(config)); + IFullHttpRequest request = OptionsRequest("http://localhost:8888", "", HttpHeaderValues.Close); + Assert.False(channel.WriteInbound(request)); + var response = channel.ReadOutbound(); + Assert.False(HttpUtil.IsKeepAlive(response)); + + Assert.False(channel.Open); + Assert.Equal(HttpResponseStatus.OK, response.Status); + Assert.True(ReferenceCountUtil.Release(response)); + Assert.False(channel.Finish()); + } + + [Fact] + public void ForbiddenShouldReleaseRequest() + { + CorsConfig config = CorsConfigBuilder.ForOrigin((AsciiString)"https://localhost").ShortCircuit().Build(); + var channel = new EmbeddedChannel(new CorsHandler(config), new EchoHandler()); + IFullHttpRequest request = CreateHttpRequest(HttpMethod.Get); + request.Headers.Set(HttpHeaderNames.Origin, "http://localhost:8888"); + Assert.False(channel.WriteInbound(request)); + Assert.Equal(0, request.ReferenceCount); + Assert.True(ReferenceCountUtil.Release(channel.ReadOutbound())); + Assert.False(channel.Finish()); + } + + static IHttpResponse SimpleRequest(CorsConfig config, string origin, string requestHeaders = null) => + SimpleRequest(config, origin, requestHeaders, HttpMethod.Get); + + static IHttpResponse SimpleRequest(CorsConfig config, string origin, string requestHeaders, HttpMethod method) + { + var channel = new EmbeddedChannel(new CorsHandler(config), new EchoHandler()); + IFullHttpRequest httpRequest = CreateHttpRequest(method); + if (origin != null) + { + httpRequest.Headers.Set(HttpHeaderNames.Origin, new AsciiString(origin)); + } + if (requestHeaders != null) + { + httpRequest.Headers.Set(HttpHeaderNames.AccessControlRequestHeaders, new AsciiString(requestHeaders)); + } + + Assert.False(channel.WriteInbound(httpRequest)); + return channel.ReadOutbound(); + } + + static IHttpResponse PreflightRequest(CorsConfig config, string origin, string requestHeaders) + { + var channel = new EmbeddedChannel(new CorsHandler(config)); + Assert.False(channel.WriteInbound(OptionsRequest(origin, requestHeaders, null))); + var response = channel.ReadOutbound(); + Assert.False(channel.Finish()); + return response; + } + + static IFullHttpRequest OptionsRequest(string origin, string requestHeaders, AsciiString connection) + { + IFullHttpRequest httpRequest = CreateHttpRequest(HttpMethod.Options); + httpRequest.Headers.Set(HttpHeaderNames.Origin, new AsciiString(origin)); + httpRequest.Headers.Set(HttpHeaderNames.AccessControlRequestMethod, httpRequest.Method); + httpRequest.Headers.Set(HttpHeaderNames.AccessControlRequestHeaders, new AsciiString(requestHeaders)); + if (connection != null) + { + httpRequest.Headers.Set(HttpHeaderNames.Connection, connection); + } + + return httpRequest; + } + + static IFullHttpRequest CreateHttpRequest(HttpMethod method) => new DefaultFullHttpRequest(HttpVersion.Http11, method, "/info"); + + sealed class EchoHandler : SimpleChannelInboundHandler + { + protected override void ChannelRead0(IChannelHandlerContext ctx, object msg) => + ctx.WriteAndFlushAsync(new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, true, true)); + } + + static void AssertValues(IHttpResponse response, string headerName, params string[] values) + { + ICharSequence header = response.Headers.Get(new AsciiString(headerName), null); + foreach (string value in values) + { + Assert.Contains(value, header.ToString()); + } + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/DefaultHttpHeadersTest.cs b/test/DotNetty.Codecs.Http.Tests/DefaultHttpHeadersTest.cs new file mode 100644 index 0000000..486604c --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/DefaultHttpHeadersTest.cs @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System.Collections.Generic; + using DotNetty.Common.Utilities; + using Xunit; + + public sealed class DefaultHttpHeadersTest + { + static readonly AsciiString HeaderName = new AsciiString("testHeader"); + + [Fact] + public void KeysShouldBeCaseInsensitive() + { + var headers = new DefaultHttpHeaders(); + headers.Add(HttpHeadersTestUtils.Of("Name"), HttpHeadersTestUtils.Of("value1")); + headers.Add(HttpHeadersTestUtils.Of("name"), HttpHeadersTestUtils.Of("value2")); + headers.Add(HttpHeadersTestUtils.Of("NAME"), HttpHeadersTestUtils.Of("value3")); + Assert.Equal(3, headers.Size); + + var values = new List + { + HttpHeadersTestUtils.Of("value1"), + HttpHeadersTestUtils.Of("value2"), + HttpHeadersTestUtils.Of("value3") + }; + + Assert.Equal(values, headers.GetAll(HttpHeadersTestUtils.Of("NAME"))); + Assert.Equal(values, headers.GetAll(HttpHeadersTestUtils.Of("name"))); + Assert.Equal(values, headers.GetAll(HttpHeadersTestUtils.Of("Name"))); + Assert.Equal(values, headers.GetAll(HttpHeadersTestUtils.Of("nAmE"))); + } + + [Fact] + public void KeysShouldBeCaseInsensitiveInHeadersEquals() + { + var headers1 = new DefaultHttpHeaders(); + headers1.Add(HttpHeadersTestUtils.Of("name1"), new[] { "value1", "value2", "value3" }); + headers1.Add(HttpHeadersTestUtils.Of("nAmE2"), HttpHeadersTestUtils.Of("value4")); + + var headers2 = new DefaultHttpHeaders(); + headers2.Add(HttpHeadersTestUtils.Of("naMe1"), new[] { "value1", "value2", "value3" }); + headers2.Add(HttpHeadersTestUtils.Of("NAME2"), HttpHeadersTestUtils.Of("value4")); + + Assert.True(Equals(headers1, headers2)); + Assert.True(Equals(headers2, headers1)); + Assert.Equal(headers1.GetHashCode(), headers2.GetHashCode()); + } + + [Fact] + public void StringKeyRetrievedAsAsciiString() + { + var headers = new DefaultHttpHeaders(false); + + // Test adding String key and retrieving it using a AsciiString key + const string Connection = "keep-alive"; + headers.Add(HttpHeadersTestUtils.Of("Connection"), Connection); + + // Passes + headers.TryGetAsString(HttpHeaderNames.Connection, out string value); + Assert.NotNull(value); + Assert.Equal(Connection, value); + + // Passes + ICharSequence value2 = headers.Get(HttpHeaderNames.Connection, null); + Assert.NotNull(value2); + Assert.Equal(Connection, value2); + } + + [Fact] + public void AsciiStringKeyRetrievedAsString() + { + var headers = new DefaultHttpHeaders(false); + + // Test adding AsciiString key and retrieving it using a String key + const string CacheControl = "no-cache"; + headers.Add(HttpHeaderNames.CacheControl, CacheControl); + + headers.TryGetAsString(HttpHeaderNames.CacheControl, out string value); + Assert.NotNull(value); + Assert.Equal(CacheControl, value); + + ICharSequence value2 = headers.Get(HttpHeaderNames.CacheControl, null); + Assert.NotNull(value2); + Assert.Equal(CacheControl, value2); + } + + [Fact] + public void GetOperations() + { + var headers = new DefaultHttpHeaders(); + headers.Add(HttpHeadersTestUtils.Of("Foo"), HttpHeadersTestUtils.Of("1")); + headers.Add(HttpHeadersTestUtils.Of("Foo"), HttpHeadersTestUtils.Of("2")); + + Assert.Equal("1", headers.Get(HttpHeadersTestUtils.Of("Foo"), null)); + + IList values = headers.GetAll(HttpHeadersTestUtils.Of("Foo")); + Assert.Equal(2, values.Count); + Assert.Equal("1", values[0]); + Assert.Equal("2", values[1]); + } + + [Fact] + public void EqualsIgnoreCase() + { + Assert.True(AsciiString.ContentEqualsIgnoreCase(null, null)); + Assert.False(AsciiString.ContentEqualsIgnoreCase(null, (StringCharSequence)"foo")); + Assert.False(AsciiString.ContentEqualsIgnoreCase((StringCharSequence)"bar", null)); + Assert.True(AsciiString.ContentEqualsIgnoreCase((StringCharSequence)"FoO", (StringCharSequence)"fOo")); + } + + [Fact] + public void AddCharSequences() + { + var headers = new DefaultHttpHeaders(); + headers.Add(HeaderName, HttpHeadersTestUtils.HeaderValue.Three.AsList()); + AssertDefaultValues(headers, HttpHeadersTestUtils.HeaderValue.Three); + } + + [Fact] + public void AddObjects() + { + var headers = new DefaultHttpHeaders(); + headers.Add(HeaderName, HttpHeadersTestUtils.HeaderValue.Three.AsList()); + AssertDefaultValues(headers, HttpHeadersTestUtils.HeaderValue.Three); + } + + [Fact] + public void SetCharSequences() + { + var headers = new DefaultHttpHeaders(); + headers.Set(HeaderName, HttpHeadersTestUtils.HeaderValue.Three.AsList()); + AssertDefaultValues(headers, HttpHeadersTestUtils.HeaderValue.Three); + } + + static void AssertDefaultValues(HttpHeaders headers, HttpHeadersTestUtils.HeaderValue headerValue) + { + Assert.True(AsciiString.ContentEquals(headerValue.AsList()[0], (StringCharSequence)headers.Get(HeaderName, null))); + List expected = headerValue.AsList(); + IList actual = headers.GetAll(HeaderName); + Assert.Equal(expected.Count, actual.Count); + + for (int i =0; i < expected.Count; i++) + { + Assert.True(AsciiString.ContentEquals(expected[i], (StringCharSequence)actual[i])); + } + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/DefaultHttpRequestTest.cs b/test/DotNetty.Codecs.Http.Tests/DefaultHttpRequestTest.cs new file mode 100644 index 0000000..a2158ee --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/DefaultHttpRequestTest.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using DotNetty.Common.Utilities; + using Xunit; + + public sealed class DefaultHttpRequestTest + { + [Fact] + public void HeaderRemoval() + { + var m = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/"); + HttpHeaders h = m.Headers; + + // Insert sample keys. + for (int i = 0; i < 1000; i++) + { + h.Set(HttpHeadersTestUtils.Of(i.ToString()), AsciiString.Empty); + } + + // Remove in reversed order. + for (int i = 999; i >= 0; i--) + { + h.Remove(HttpHeadersTestUtils.Of(i.ToString())); + } + + // Check if random access returns nothing. + for (int i = 0; i < 1000; i++) + { + Assert.False(h.TryGet(HttpHeadersTestUtils.Of(i.ToString()), out _)); + } + + // Check if sequential access returns nothing. + Assert.True(h.IsEmpty); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/DotNetty.Codecs.Http.Tests.csproj b/test/DotNetty.Codecs.Http.Tests/DotNetty.Codecs.Http.Tests.csproj new file mode 100644 index 0000000..860ece2 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/DotNetty.Codecs.Http.Tests.csproj @@ -0,0 +1,38 @@ + + + true + netcoreapp1.1;net452 + false + ../../DotNetty.snk + true + + + + + + + + + + + + + + PreserveNewest + + + PreserveNewest + + + + + + + + + + + + + + \ No newline at end of file diff --git a/test/DotNetty.Codecs.Http.Tests/HttpChunkedInputTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpChunkedInputTest.cs new file mode 100644 index 0000000..8c0493f --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpChunkedInputTest.cs @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System.IO; + using DotNetty.Buffers; + using DotNetty.Handlers.Streams; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpChunkedInputTest + { + static readonly byte[] Bytes = new byte[1024 * 64]; + + static HttpChunkedInputTest() + { + for (int i = 0; i < Bytes.Length; i++) + { + Bytes[i] = (byte)i; + } + } + + [Fact] + public void ChunkedStream() + { + var stream = new ChunkedStream(new MemoryStream(Bytes)); + Check(new HttpChunkedInput(stream)); + } + + [Fact] + public void WrappedReturnNull() + { + var input = new EmptyChunkedInput(); + var httpInput = new HttpChunkedInput(input); + + IHttpContent result = httpInput.ReadChunk(PooledByteBufferAllocator.Default); + Assert.Null(result); + } + + sealed class EmptyChunkedInput : IChunkedInput + { + public bool IsEndOfInput => false; + + public void Close() + { + // NOOP + } + + public IByteBuffer ReadChunk(IByteBufferAllocator allocator) => null; + + public long Length => 0; + + public long Progress => 0; + } + + static void Check(params IChunkedInput[] inputs) + { + var ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + foreach (IChunkedInput input in inputs) + { + ch.WriteOutbound(input); + } + Assert.True(ch.Finish()); + + int i = 0; + int read = 0; + IHttpContent lastHttpContent = null; + for (;;) + { + var httpContent = ch.ReadOutbound(); + if (httpContent == null) + { + break; + } + + if (lastHttpContent != null) + { + Assert.True(lastHttpContent is DefaultHttpContent, "Chunk must be DefaultHttpContent"); + } + + IByteBuffer buffer = httpContent.Content; + while (buffer.IsReadable()) + { + Assert.Equal(Bytes[i++], buffer.ReadByte()); + read++; + if (i == Bytes.Length) + { + i = 0; + } + } + buffer.Release(); + + // Save last chunk + lastHttpContent = httpContent; + } + + Assert.Equal(Bytes.Length * inputs.Length, read); + + //Last chunk must be EmptyLastHttpContent.Default + Assert.Same(EmptyLastHttpContent.Default, lastHttpContent); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpClientCodecTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpClientCodecTest.cs new file mode 100644 index 0000000..09876ee --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpClientCodecTest.cs @@ -0,0 +1,358 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Net; + using System.Text; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Bootstrapping; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + using DotNetty.Transport.Channels.Sockets; + using Xunit; + + using HttpVersion = DotNetty.Codecs.Http.HttpVersion; + + public sealed class HttpClientCodecTest + { + const string EmptyResponse = "HTTP/1.0 200 OK\r\nContent-Length: 0\r\n\r\n"; + + const string Response = "HTTP/1.0 200 OK\r\n" + + "Date: Fri, 31 Dec 1999 23:59:59 GMT\r\n" + + "Content-Type: text/html\r\n" + "Content-Length: 28\r\n" + "\r\n" + + "\r\n"; + + const string IncompleteChunkedResponse = "HTTP/1.1 200 OK\r\n" + + "Content-Type: text/plain\r\n" + + "Transfer-Encoding: chunked\r\n" + "\r\n" + + "5\r\n" + "first\r\n" + "6\r\n" + "second\r\n" + "0\r\n"; + + const string ChunkedResponse = IncompleteChunkedResponse + "\r\n"; + + [Fact] + public void ConnectWithResponseContent() + { + var codec = new HttpClientCodec(4096, 8192, 8192, true); + var ch = new EmbeddedChannel(codec); + + SendRequestAndReadResponse(ch, HttpMethod.Connect, Response); + ch.Finish(); + } + + [Fact] + public void FailsNotOnRequestResponseChunked() + { + var codec = new HttpClientCodec(4096, 8192, 8192, true); + var ch = new EmbeddedChannel(codec); + + SendRequestAndReadResponse(ch, HttpMethod.Get, ChunkedResponse); + ch.Finish(); + } + + [Fact] + public void FailsOnMissingResponse() + { + var codec = new HttpClientCodec(4096, 8192, 8192, true); + var ch = new EmbeddedChannel(codec); + + Assert.True(ch.WriteOutbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "http://localhost/"))); + var buffer = ch.ReadOutbound(); + Assert.NotNull(buffer); + buffer.Release(); + + Assert.Throws(() => ch.Finish()); + } + + [Fact] + public void FailsOnIncompleteChunkedResponse() + { + var codec = new HttpClientCodec(4096, 8192, 8192, true); + var ch = new EmbeddedChannel(codec); + + ch.WriteOutbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "http://localhost/")); + var buffer = ch.ReadOutbound(); + Assert.NotNull(buffer); + buffer.Release(); + Assert.Null(ch.ReadInbound()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes(IncompleteChunkedResponse))); + var response = ch.ReadInbound(); + Assert.NotNull(response); + var content = ch.ReadInbound(); + Assert.NotNull(content); // Chunk 'first' + content.Release(); + + content = ch.ReadInbound(); + Assert.NotNull(content); // Chunk 'second' + content.Release(); + + content = ch.ReadInbound(); + Assert.Null(content); + + Assert.Throws(() => ch.Finish()); + } + + [Fact] + public void ServerCloseSocketInputProvidesData() + { + var clientGroup = new MultithreadEventLoopGroup(1); + var serverGroup = new MultithreadEventLoopGroup(1); + try + { + var serverCompletion = new TaskCompletionSource(); + + var serverHandler = new ServerHandler(); + ServerBootstrap sb = new ServerBootstrap() + .Group(serverGroup) + .Channel() + .ChildHandler( + new ActionChannelInitializer( + ch => + { + // Don't use the HttpServerCodec, because we don't want to have content-length or anything added. + ch.Pipeline.AddLast(new HttpRequestDecoder(4096, 8192, 8192, true)); + ch.Pipeline.AddLast(new HttpObjectAggregator(4096)); + ch.Pipeline.AddLast(serverHandler); + serverCompletion.TryComplete(); + })); + + var clientHandler = new ClientHandler(); + Bootstrap cb = new Bootstrap() + .Group(clientGroup) + .Channel() + .Handler( + new ActionChannelInitializer( + ch => + { + ch.Pipeline.AddLast(new HttpClientCodec(4096, 8192, 8192, true)); + ch.Pipeline.AddLast(new HttpObjectAggregator(4096)); + ch.Pipeline.AddLast(clientHandler); + })); + + Task task = sb.BindAsync(IPAddress.Loopback, IPEndPoint.MinPort); + task.Wait(TimeSpan.FromSeconds(5)); + Assert.True(task.Status == TaskStatus.RanToCompletion); + IChannel serverChannel = task.Result; + int port = ((IPEndPoint)serverChannel.LocalAddress).Port; + + task = cb.ConnectAsync(IPAddress.Loopback, port); + task.Wait(TimeSpan.FromSeconds(5)); + Assert.True(task.Status == TaskStatus.RanToCompletion); + IChannel clientChannel = task.Result; + + serverCompletion.Task.Wait(TimeSpan.FromSeconds(5)); + clientChannel.WriteAndFlushAsync(new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/")).Wait(TimeSpan.FromSeconds(1)); + Assert.True(serverHandler.WaitForCompletion()); + Assert.True(clientHandler.WaitForCompletion()); + } + finally + { + Task.WaitAll( + clientGroup.ShutdownGracefullyAsync(TimeSpan.FromMilliseconds(100), TimeSpan.FromSeconds(1)), + serverGroup.ShutdownGracefullyAsync(TimeSpan.FromMilliseconds(100), TimeSpan.FromSeconds(1))); + } + } + + class ClientHandler : SimpleChannelInboundHandler + { + readonly TaskCompletionSource completion = new TaskCompletionSource(); + + public bool WaitForCompletion() + { + this.completion.Task.Wait(TimeSpan.FromSeconds(5)); + return this.completion.Task.Status == TaskStatus.RanToCompletion; + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, IFullHttpResponse msg) => + this.completion.TryComplete(); + + public override void ExceptionCaught(IChannelHandlerContext context, Exception exception) => + this.completion.TrySetException(exception); + } + + class ServerHandler : SimpleChannelInboundHandler + { + readonly TaskCompletionSource completion = new TaskCompletionSource(); + + public bool WaitForCompletion() + { + this.completion.Task.Wait(TimeSpan.FromSeconds(5)); + return this.completion.Task.Status == TaskStatus.RanToCompletion; + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, IFullHttpRequest msg) + { + // This is just a simple demo...don't block in IO + Assert.IsAssignableFrom(ctx.Channel); + + var sChannel = (ISocketChannel)ctx.Channel; + /** + * The point of this test is to not add any content-length or content-encoding headers + * and the client should still handle this. + * See RFC 7230, 3.3.3. + */ + + sChannel.WriteAndFlushAsync(Unpooled.WrappedBuffer(Encoding.UTF8.GetBytes("HTTP/1.0 200 OK\r\n" + "Date: Fri, 31 Dec 1999 23:59:59 GMT\r\n" + "Content-Type: text/html\r\n\r\n"))); + sChannel.WriteAndFlushAsync(Unpooled.WrappedBuffer(Encoding.UTF8.GetBytes("hello half closed!\r\n"))); + sChannel.CloseAsync(); + + sChannel.CloseCompletion.LinkOutcome(this.completion); + } + + public override void ExceptionCaught(IChannelHandlerContext context, Exception exception) => this.completion.TrySetException(exception); + } + + [Fact] + public void ContinueParsingAfterConnect() => AfterConnect(true); + + [Fact] + public void PassThroughAfterConnect() => AfterConnect(false); + + static void AfterConnect(bool parseAfterConnect) + { + var ch = new EmbeddedChannel(new HttpClientCodec(4096, 8192, 8192, true, true, parseAfterConnect)); + var connectResponseConsumer = new Consumer(); + SendRequestAndReadResponse(ch, HttpMethod.Connect, EmptyResponse, connectResponseConsumer); + + Assert.True(connectResponseConsumer.ReceivedCount > 0, "No connect response messages received."); + + void Handler(object m) + { + if (parseAfterConnect) + { + Assert.True(m is IHttpObject, "Unexpected response message type."); + } + else + { + Assert.False(m is IHttpObject, "Unexpected response message type."); + } + } + + var responseConsumer = new Consumer(Handler); + + SendRequestAndReadResponse(ch, HttpMethod.Get, Response, responseConsumer); + Assert.True(responseConsumer.ReceivedCount > 0, "No response messages received."); + Assert.False(ch.Finish(), "Channel finish failed."); + } + + static void SendRequestAndReadResponse(EmbeddedChannel ch, HttpMethod httpMethod, string response) => + SendRequestAndReadResponse(ch, httpMethod, response, new Consumer()); + + static void SendRequestAndReadResponse( + EmbeddedChannel ch, + HttpMethod httpMethod, + string response, + Consumer responseConsumer) + { + Assert.True( + ch.WriteOutbound(new DefaultFullHttpRequest(HttpVersion.Http11, httpMethod, "http://localhost/")), + "Channel outbound write failed."); + Assert.True( + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes(response))), + "Channel inbound write failed."); + + for (;;) + { + var msg = ch.ReadOutbound(); + if (msg == null) + { + break; + } + ReferenceCountUtil.Release(msg); + } + for (;;) + { + var msg = ch.ReadInbound(); + if (msg == null) + { + break; + } + responseConsumer.OnResponse(msg); + ReferenceCountUtil.Release(msg); + } + } + + sealed class Consumer + { + readonly Action handler; + + public Consumer(Action handler = null) + { + this.handler = handler; + } + + public void OnResponse(object response) + { + this.ReceivedCount++; + this.Accept(response); + } + + void Accept(object response) + { + this.handler?.Invoke(response); + } + + public int ReceivedCount { get; private set; } + } + + [Fact] + public void DecodesFinalResponseAfterSwitchingProtocols() + { + const string SwitchingProtocolsResponse = "HTTP/1.1 101 Switching Protocols\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: TLS/1.2, HTTP/1.1\r\n\r\n"; + + var codec = new HttpClientCodec(4096, 8192, 8192, true); + var ch = new EmbeddedChannel(codec, new HttpObjectAggregator(1024)); + + IHttpRequest request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "http://localhost/"); + request.Headers.Set(HttpHeaderNames.Connection, HttpHeaderValues.Upgrade); + request.Headers.Set(HttpHeaderNames.Upgrade, "TLS/1.2"); + Assert.True(ch.WriteOutbound(request), "Channel outbound write failed."); + + Assert.True( + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes(SwitchingProtocolsResponse))), + "Channel inbound write failed."); + var switchingProtocolsResponse = ch.ReadInbound(); + Assert.NotNull(switchingProtocolsResponse); + switchingProtocolsResponse.Release(); + + Assert.True( + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes(Response))), + "Channel inbound write failed"); + var finalResponse = ch.ReadInbound(); + Assert.NotNull(finalResponse); + finalResponse.Release(); + Assert.True(ch.FinishAndReleaseAll(), "Channel finish failed"); + } + + [Fact] + public void WebSocket00Response() + { + byte[] data = Encoding.UTF8.GetBytes("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + + "Upgrade: WebSocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Origin: http://localhost:8080\r\n" + + "Sec-WebSocket-Location: ws://localhost/some/path\r\n" + + "\r\n" + + "1234567812345678"); + var ch = new EmbeddedChannel(new HttpClientCodec()); + Assert.True(ch.WriteInbound(Unpooled.WrappedBuffer(data))); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.SwitchingProtocols, res.Status); + var content = ch.ReadInbound(); + Assert.Equal(16, content.Content.ReadableBytes); + content.Release(); + + Assert.False(ch.Finish()); + var next = ch.ReadInbound(); + Assert.Null(next); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpClientUpgradeHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpClientUpgradeHandlerTest.cs new file mode 100644 index 0000000..29e13e5 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpClientUpgradeHandlerTest.cs @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System.Collections.Generic; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpClientUpgradeHandlerTest + { + sealed class FakeSourceCodec : HttpClientUpgradeHandler.ISourceCodec + { + public void PrepareUpgradeFrom(IChannelHandlerContext ctx) + { + //NOP + } + + public void UpgradeFrom(IChannelHandlerContext ctx) + { + //NOP + } + } + + sealed class FakeUpgradeCodec : HttpClientUpgradeHandler.IUpgradeCodec + { + public ICharSequence Protocol => new AsciiString("fancyhttp"); + + public ICollection SetUpgradeHeaders(IChannelHandlerContext ctx, IHttpRequest upgradeRequest) => + new List(); + + public void UpgradeTo(IChannelHandlerContext ctx, IFullHttpResponse upgradeResponse) + { + //NOP + } + } + + sealed class UserEventCatcher : ChannelHandlerAdapter + { + public object UserEvent { get; private set; } + + public override void UserEventTriggered(IChannelHandlerContext context, object evt) => + this.UserEvent = evt; + } + + [Fact] + public void SuccessfulUpgrade() + { + HttpClientUpgradeHandler.ISourceCodec sourceCodec = new FakeSourceCodec(); + HttpClientUpgradeHandler.IUpgradeCodec upgradeCodec = new FakeUpgradeCodec(); + var handler = new HttpClientUpgradeHandler(sourceCodec, upgradeCodec, 1024); + var catcher = new UserEventCatcher(); + var channel = new EmbeddedChannel(catcher); + channel.Pipeline.AddFirst("upgrade", handler); + + Assert.True(channel.WriteOutbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "netty.io"))); + var request = channel.ReadOutbound(); + + Assert.Equal(2, request.Headers.Size); + Assert.True(request.Headers.Contains(HttpHeaderNames.Upgrade, (AsciiString)"fancyhttp", false)); + Assert.True(request.Headers.Contains((AsciiString)"connection", (AsciiString)"upgrade", false)); + Assert.True(request.Release()); + Assert.Equal(HttpClientUpgradeHandler.UpgradeEvent.UpgradeIssued, catcher.UserEvent); + + var upgradeResponse = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.SwitchingProtocols); + upgradeResponse.Headers.Add(HttpHeaderNames.Upgrade, (AsciiString)"fancyhttp"); + Assert.False(channel.WriteInbound(upgradeResponse)); + Assert.False(channel.WriteInbound(EmptyLastHttpContent.Default)); + + Assert.Equal(HttpClientUpgradeHandler.UpgradeEvent.UpgradeSuccessful, catcher.UserEvent); + Assert.Null(channel.Pipeline.Get("upgrade")); + + Assert.True(channel.WriteInbound(new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK))); + var response = channel.ReadInbound(); + Assert.Equal(HttpResponseStatus.OK, response.Status); + Assert.True(response.Release()); + Assert.False(channel.Finish()); + } + + [Fact] + public void UpgradeRejected() + { + HttpClientUpgradeHandler.ISourceCodec sourceCodec = new FakeSourceCodec(); + HttpClientUpgradeHandler.IUpgradeCodec upgradeCodec = new FakeUpgradeCodec(); + var handler = new HttpClientUpgradeHandler(sourceCodec, upgradeCodec, 1024); + var catcher = new UserEventCatcher(); + var channel = new EmbeddedChannel(catcher); + channel.Pipeline.AddFirst("upgrade", handler); + + Assert.True(channel.WriteOutbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "netty.io"))); + var request = channel.ReadOutbound(); + + Assert.Equal(2, request.Headers.Size); + Assert.True(request.Headers.Contains(HttpHeaderNames.Upgrade, (AsciiString)"fancyhttp", false)); + Assert.True(request.Headers.Contains((AsciiString)"connection", (AsciiString)"upgrade", false)); + Assert.True(request.Release()); + Assert.Equal(HttpClientUpgradeHandler.UpgradeEvent.UpgradeIssued, catcher.UserEvent); + + var upgradeResponse = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.SwitchingProtocols); + upgradeResponse.Headers.Add(HttpHeaderNames.Upgrade, (AsciiString)"fancyhttp"); + Assert.True(channel.WriteInbound(new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK))); + Assert.True(channel.WriteInbound(EmptyLastHttpContent.Default)); + + Assert.Equal(HttpClientUpgradeHandler.UpgradeEvent.UpgradeRejected, catcher.UserEvent); + Assert.Null(channel.Pipeline.Get("upgrade")); + + var response = channel.ReadInbound(); + Assert.Equal(HttpResponseStatus.OK, response.Status); + + var last = channel.ReadInbound(); + Assert.Equal(EmptyLastHttpContent.Default, last); + Assert.False(last.Release()); + Assert.False(channel.Finish()); + } + + [Fact] + public void EarlyBailout() + { + HttpClientUpgradeHandler.ISourceCodec sourceCodec = new FakeSourceCodec(); + HttpClientUpgradeHandler.IUpgradeCodec upgradeCodec = new FakeUpgradeCodec(); + var handler = new HttpClientUpgradeHandler(sourceCodec, upgradeCodec, 1024); + var catcher = new UserEventCatcher(); + var channel = new EmbeddedChannel(catcher); + channel.Pipeline.AddFirst("upgrade", handler); + + Assert.True(channel.WriteOutbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "netty.io"))); + var request = channel.ReadOutbound(); + + Assert.Equal(2, request.Headers.Size); + Assert.True(request.Headers.Contains(HttpHeaderNames.Upgrade, (AsciiString)"fancyhttp", false)); + Assert.True(request.Headers.Contains((AsciiString)"connection", (AsciiString)"upgrade", false)); + Assert.True(request.Release()); + Assert.Equal(HttpClientUpgradeHandler.UpgradeEvent.UpgradeIssued, catcher.UserEvent); + + var upgradeResponse = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.SwitchingProtocols); + upgradeResponse.Headers.Add(HttpHeaderNames.Upgrade, (AsciiString)"fancyhttp"); + Assert.True(channel.WriteInbound(new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK))); + + Assert.Equal(HttpClientUpgradeHandler.UpgradeEvent.UpgradeRejected, catcher.UserEvent); + Assert.Null(channel.Pipeline.Get("upgrade")); + + var response = channel.ReadInbound(); + Assert.Equal(HttpResponseStatus.OK, response.Status); + Assert.False(channel.Finish()); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpContentCompressorTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpContentCompressorTest.cs new file mode 100644 index 0000000..ed0bcd2 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpContentCompressorTest.cs @@ -0,0 +1,441 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Codecs.Compression; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpContentCompressorTest + { + [Fact] + public void GetTargetContentEncoding() + { + var compressor = new HttpContentCompressor(); + + string[] tests = + { + // Accept-Encoding -> Content-Encoding + "", null, + "*", "gzip", + "*;q=0.0", null, + "gzip", "gzip", + "compress, gzip;q=0.5", "gzip", + "gzip; q=0.5, identity", "gzip", + "gzip ; q=0.1", "gzip", + "gzip; q=0, deflate", "deflate", + " deflate ; q=0 , *;q=0.5", "gzip" + }; + for (int i = 0; i < tests.Length; i += 2) + { + var acceptEncoding = (AsciiString)tests[i]; + string contentEncoding = tests[i + 1]; + ZlibWrapper? targetWrapper = compressor.DetermineWrapper(acceptEncoding); + string targetEncoding = null; + if (targetWrapper != null) + { + switch (targetWrapper) + { + case ZlibWrapper.Gzip: + targetEncoding = "gzip"; + break; + case ZlibWrapper.Zlib: + targetEncoding = "deflate"; + break; + default: + Assert.True(false, $"Invalid type {targetWrapper}"); + break; + } + } + Assert.Equal(contentEncoding, targetEncoding); + } + } + + [Fact] + public void SplitContent() + { + var ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.WriteInbound(NewRequest()); + + ch.WriteOutbound(new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK)); + ch.WriteOutbound(new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("Hell")))); + ch.WriteOutbound(new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("o, w")))); + ch.WriteOutbound(new DefaultLastHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("orld")))); + + AssertEncodedResponse(ch); + + var chunk = ch.ReadOutbound(); + Assert.Equal("1f8b080000000000000bf248cdc901000000ffff", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("cad7512807000000ffff", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("ca2fca4901000000ffff", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("0300c2a99ae70c000000", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.False(chunk.Content.IsReadable()); + Assert.Equal(EmptyLastHttpContent.Default, chunk); + chunk.Release(); + + var last = ch.ReadOutbound(); + Assert.Null(last); + } + + [Fact] + public void ChunkedContent() + { + var ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.WriteInbound(NewRequest()); + + var res = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + res.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + ch.WriteOutbound(res); + + AssertEncodedResponse(ch); + + ch.WriteOutbound(new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("Hell")))); + ch.WriteOutbound(new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("o, w")))); + ch.WriteOutbound(new DefaultLastHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("orld")))); + + var chunk = ch.ReadOutbound(); + Assert.Equal("1f8b080000000000000bf248cdc901000000ffff", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("cad7512807000000ffff", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("ca2fca4901000000ffff", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("0300c2a99ae70c000000", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.False(chunk.Content.IsReadable()); + Assert.Equal(EmptyLastHttpContent.Default, chunk); + chunk.Release(); + + var last = ch.ReadOutbound(); + Assert.Null(last); + } + + [Fact] + public void ChunkedContentWithTrailingHeader() + { + var ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.WriteInbound(NewRequest()); + + var res = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + res.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + ch.WriteOutbound(res); + + AssertEncodedResponse(ch); + + ch.WriteOutbound(new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("Hell")))); + ch.WriteOutbound(new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("o, w")))); + var content = new DefaultLastHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("orld"))); + content.TrailingHeaders.Set((AsciiString)"X-Test", (AsciiString)"Netty"); + ch.WriteOutbound(content); + + var chunk = ch.ReadOutbound(); + Assert.Equal("1f8b080000000000000bf248cdc901000000ffff", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("cad7512807000000ffff", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("ca2fca4901000000ffff", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("0300c2a99ae70c000000", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + var lastChunk = ch.ReadOutbound(); + Assert.NotNull(lastChunk); + Assert.Equal("Netty", lastChunk.TrailingHeaders.Get((AsciiString)"X-Test", null).ToString()); + lastChunk.Release(); + + var last = ch.ReadOutbound(); + Assert.Null(last); + } + + [Fact] + public void FullContentWithContentLength() + { + var ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.WriteInbound(NewRequest()); + + var fullRes = new DefaultFullHttpResponse( + HttpVersion.Http11, + HttpResponseStatus.OK, + Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("Hello, World"))); + fullRes.Headers.Set(HttpHeaderNames.ContentLength, fullRes.Content.ReadableBytes); + ch.WriteOutbound(fullRes); + + var res = ch.ReadOutbound(); + Assert.NotNull(res); + Assert.False(res is IHttpContent, $"{res.GetType()}"); + + Assert.False(res.Headers.TryGet(HttpHeaderNames.TransferEncoding, out _)); + Assert.Equal("gzip", res.Headers.Get(HttpHeaderNames.ContentEncoding, null).ToString()); + + long contentLengthHeaderValue = HttpUtil.GetContentLength(res); + long observedLength = 0; + + var c = ch.ReadOutbound(); + observedLength += c.Content.ReadableBytes; + Assert.Equal("1f8b080000000000000bf248cdc9c9d75108cf2fca4901000000ffff", ByteBufferUtil.HexDump(c.Content)); + c.Release(); + + c = ch.ReadOutbound(); + observedLength += c.Content.ReadableBytes; + Assert.Equal("0300c6865b260c000000", ByteBufferUtil.HexDump(c.Content)); + c.Release(); + + var last = ch.ReadOutbound(); + Assert.Equal(0, last.Content.ReadableBytes); + last.Release(); + + var next = ch.ReadOutbound(); + Assert.Null(next); + Assert.Equal(contentLengthHeaderValue, observedLength); + } + + [Fact] + public void FullContent() + { + var ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.WriteInbound(NewRequest()); + + var res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, + Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("Hello, World"))); + ch.WriteOutbound(res); + + AssertEncodedResponse(ch); + + var chunk = ch.ReadOutbound(); + Assert.Equal("1f8b080000000000000bf248cdc9c9d75108cf2fca4901000000ffff", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("0300c6865b260c000000", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + var lastChunk = ch.ReadOutbound(); + Assert.NotNull(lastChunk); + Assert.Equal(0, lastChunk.Content.ReadableBytes); + lastChunk.Release(); + + var last = ch.ReadOutbound(); + Assert.Null(last); + } + + [Fact] + public void EmptySplitContent() + { + var ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.WriteInbound(NewRequest()); + + ch.WriteOutbound(new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK)); + AssertEncodedResponse(ch); + + ch.WriteOutbound(EmptyLastHttpContent.Default); + var chunk = ch.ReadOutbound(); + Assert.Equal("1f8b080000000000000b03000000000000000000", ByteBufferUtil.HexDump(chunk.Content)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.False(chunk.Content.IsReadable()); + Assert.IsAssignableFrom(chunk); + + var last = ch.ReadOutbound(); + Assert.Null(last); + } + + [Fact] + public void EmptyFullContent() + { + var ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.WriteInbound(NewRequest()); + + IFullHttpResponse res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, Unpooled.Empty); + ch.WriteOutbound(res); + + res = ch.ReadOutbound(); + Assert.NotNull(res); + + Assert.False(res.Headers.TryGet(HttpHeaderNames.TransferEncoding, out _)); + + // Content encoding shouldn't be modified. + Assert.False(res.Headers.TryGet(HttpHeaderNames.ContentEncoding, out _)); + Assert.Equal(0, res.Content.ReadableBytes); + Assert.Equal("", res.Content.ToString(Encoding.ASCII)); + res.Release(); + + var last = ch.ReadOutbound(); + Assert.Null(last); + } + + [Fact] + public void EmptyFullContentWithTrailer() + { + var ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.WriteInbound(NewRequest()); + + IFullHttpResponse res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, Unpooled.Empty); + res.TrailingHeaders.Set((AsciiString)"X-Test", (AsciiString)"Netty"); + ch.WriteOutbound(res); + + res = ch.ReadOutbound(); + Assert.False(res.Headers.TryGet(HttpHeaderNames.TransferEncoding, out _)); + + // Content encoding shouldn't be modified. + Assert.False(res.Headers.TryGet(HttpHeaderNames.ContentEncoding, out _)); + Assert.Equal(0, res.Content.ReadableBytes); + Assert.Equal("", res.Content.ToString(Encoding.ASCII)); + Assert.Equal("Netty", res.TrailingHeaders.Get((AsciiString)"X-Test", null)); + + var last = ch.ReadOutbound(); + Assert.Null(last); + } + + [Fact] + public void Status100Continue() + { + IFullHttpRequest request = NewRequest(); + HttpUtil.Set100ContinueExpected(request, true); + + var ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.WriteInbound(request); + + var continueResponse = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.Continue, Unpooled.Empty); + ch.WriteOutbound(continueResponse); + + IFullHttpResponse res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, Unpooled.Empty); + res.TrailingHeaders.Set((AsciiString)"X-Test", (AsciiString)"Netty"); + ch.WriteOutbound(res); + + res = ch.ReadOutbound(); + Assert.NotNull(res); + Assert.Same(continueResponse, res); + res.Release(); + + res = ch.ReadOutbound(); + Assert.NotNull(res); + Assert.False(res.Headers.TryGet(HttpHeaderNames.TransferEncoding, out _)); + + // Content encoding shouldn't be modified. + Assert.False(res.Headers.TryGet(HttpHeaderNames.ContentEncoding, out _)); + Assert.Equal(0, res.Content.ReadableBytes); + Assert.Equal("", res.Content.ToString(Encoding.ASCII)); + Assert.Equal("Netty", res.TrailingHeaders.Get((AsciiString)"X-Test", null)); + + var last = ch.ReadOutbound(); + Assert.Null(last); + } + + [Fact] + public void TooManyResponses() + { + var ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.WriteInbound(NewRequest()); + + ch.WriteOutbound(new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, Unpooled.Empty)); + + try + { + ch.WriteOutbound(new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, Unpooled.Empty)); + Assert.True(false, "Should not get here, expecting exception thrown"); + } + catch (AggregateException e) + { + Assert.Single(e.InnerExceptions); + Assert.IsType(e.InnerExceptions[0]); + Exception exception = e.InnerExceptions[0]; + Assert.IsType(exception.InnerException); + } + + Assert.True(ch.Finish()); + + for (;;) + { + var message = ch.ReadOutbound(); + if (message == null) + { + break; + } + ReferenceCountUtil.Release(message); + } + for (;;) + { + var message = ch.ReadInbound(); + if (message == null) + { + break; + } + ReferenceCountUtil.Release(message); + } + } + + [Fact] + public void Identity() + { + var ch = new EmbeddedChannel(new HttpContentCompressor()); + Assert.True(ch.WriteInbound(NewRequest())); + + var res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, + Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("Hello, World"))); + int len = res.Content.ReadableBytes; + res.Headers.Set(HttpHeaderNames.ContentLength, len); + res.Headers.Set(HttpHeaderNames.ContentEncoding, HttpHeaderValues.Identity); + Assert.True(ch.WriteOutbound(res)); + + var response = ch.ReadOutbound(); + Assert.Equal(len.ToString(), response.Headers.Get(HttpHeaderNames.ContentLength, null).ToString()); + Assert.Equal(HttpHeaderValues.Identity.ToString(), response.Headers.Get(HttpHeaderNames.ContentEncoding, null).ToString()); + Assert.Equal("Hello, World", response.Content.ToString(Encoding.ASCII)); + response.Release(); + + Assert.True(ch.FinishAndReleaseAll()); + } + + static IFullHttpRequest NewRequest() + { + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/"); + req.Headers.Set(HttpHeaderNames.AcceptEncoding, "gzip"); + return req; + } + + static void AssertEncodedResponse(EmbeddedChannel ch) + { + var res = ch.ReadOutbound(); + Assert.NotNull(res); + + var content = res as IHttpContent; + Assert.Null(content); + + Assert.Equal("chunked", res.Headers.Get(HttpHeaderNames.TransferEncoding, null)); + Assert.False(res.Headers.TryGet(HttpHeaderNames.ContentLength, out _)); + Assert.Equal("gzip", res.Headers.Get(HttpHeaderNames.ContentEncoding, null)); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpContentDecoderTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpContentDecoderTest.cs new file mode 100644 index 0000000..f088977 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpContentDecoderTest.cs @@ -0,0 +1,657 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System.Collections.Generic; + using System.Linq; + using System.Text; + using System.Threading; + using DotNetty.Buffers; + using DotNetty.Codecs.Compression; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpContentDecoderTest + { + const string HelloWorld = "hello, world"; + static readonly byte[] GzHelloWorld = { + 31, (256 -117), 8, 8, 12, 3, (256 -74), 84, 0, 3, 50, 0, (256 -53), 72, (256 -51), (256 -55), (256 -55), + (256 -41), 81, 40, (256 -49), 47, (256 -54), 73, 1, 0, 58, 114, (256 -85), (256 -1), 12, 0, 0, 0 + }; + + [Fact] + public void BinaryDecompression() + { + // baseline test: zlib library and test helpers work correctly. + byte[] helloWorld = GzDecompress(GzHelloWorld); + byte[] expected = Encoding.ASCII.GetBytes(HelloWorld); + + Assert.True(expected.SequenceEqual(helloWorld)); + + const string FullCycleTest = "full cycle test"; + byte[] compressed = GzCompress(Encoding.ASCII.GetBytes(FullCycleTest)); + byte[] decompressed = GzDecompress(compressed); + + string result = Encoding.ASCII.GetString(decompressed); + Assert.Equal(FullCycleTest, result); + } + + [Fact] + public void RequestDecompression() + { + // baseline test: request decoder, content decompressor && request aggregator work as expected + var decoder = new HttpRequestDecoder(); + var decompressor = new HttpContentDecompressor(); + var aggregator = new HttpObjectAggregator(1024); + var channel = new EmbeddedChannel(decoder, decompressor, aggregator); + + string headers = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + IByteBuffer buf = Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(headers), GzHelloWorld); + Assert.True(channel.WriteInbound(buf)); + + var req = channel.ReadInbound(); + Assert.NotNull(req); + Assert.True(req.Headers.TryGetInt(HttpHeaderNames.ContentLength, out int length)); + Assert.Equal(HelloWorld.Length, length); + Assert.Equal(HelloWorld, req.Content.ToString(Encoding.ASCII)); + req.Release(); + + AssertHasInboundMessages(channel, false); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); // assert that no messages are left in channel + } + + [Fact] + public void ResponseDecompression() + { + // baseline test: response decoder, content decompressor && request aggregator work as expected + var decoder = new HttpResponseDecoder(); + var decompressor = new HttpContentDecompressor(); + var aggregator = new HttpObjectAggregator(1024); + var channel = new EmbeddedChannel(decoder, decompressor, aggregator); + + string headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + IByteBuffer buf = Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(headers), GzHelloWorld); + Assert.True(channel.WriteInbound(buf)); + + var resp = channel.ReadInbound(); + Assert.NotNull(resp); + Assert.True(resp.Headers.TryGetInt(HttpHeaderNames.ContentLength, out int length)); + Assert.Equal(HelloWorld.Length, length); + Assert.Equal(HelloWorld, resp.Content.ToString(Encoding.ASCII)); + resp.Release(); + + AssertHasInboundMessages(channel, false); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); // assert that no messages are left in channel + } + + [Fact] + public void ExpectContinueResponse1() + { + // request with header "Expect: 100-continue" must be replied with one "100 Continue" response + // case 1: no ContentDecoder in chain at all (baseline test) + var decoder = new HttpRequestDecoder(); + var aggregator = new HttpObjectAggregator(1024); + var channel = new EmbeddedChannel(decoder, aggregator); + string req = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Expect: 100-continue\r\n" + + "\r\n"; + // note: the following writeInbound() returns false as there is no message is inbound buffer + // until HttpObjectAggregator caches composes a complete message. + // however, http response "100 continue" must be sent as soon as headers are received + Assert.False(channel.WriteInbound(Unpooled.WrappedBuffer(Encoding.ASCII.GetBytes(req)))); + + var resp = channel.ReadOutbound(); + Assert.NotNull(resp); + Assert.Equal(100, resp.Status.Code); + Assert.True(channel.WriteInbound(Unpooled.WrappedBuffer(GzHelloWorld))); + resp.Release(); + + AssertHasInboundMessages(channel, true); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + [Fact] + public void ExpectContinueResponse2() + { + // request with header "Expect: 100-continue" must be replied with one "100 Continue" response + // case 2: contentDecoder is in chain, but the content is not encoded, should be no-op + var decoder = new HttpRequestDecoder(); + var decompressor = new HttpContentDecompressor(); + var aggregator = new HttpObjectAggregator(1024); + var channel = new EmbeddedChannel(decoder, decompressor, aggregator); + string req = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Expect: 100-continue\r\n" + + "\r\n"; + Assert.False(channel.WriteInbound(Unpooled.WrappedBuffer(Encoding.ASCII.GetBytes(req)))); + + var resp = channel.ReadOutbound(); + Assert.NotNull(resp); + Assert.Equal(100, resp.Status.Code); + resp.Release(); + Assert.True(channel.WriteInbound(Unpooled.WrappedBuffer(GzHelloWorld))); + + AssertHasInboundMessages(channel, true); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + [Fact] + public void ExpectContinueResponse3() + { + // request with header "Expect: 100-continue" must be replied with one "100 Continue" response + // case 3: ContentDecoder is in chain and content is encoded + var decoder = new HttpRequestDecoder(); + var decompressor = new HttpContentDecompressor(); + var aggregator = new HttpObjectAggregator(1024); + var channel = new EmbeddedChannel(decoder, decompressor, aggregator); + string req = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Expect: 100-continue\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + Assert.False(channel.WriteInbound(Unpooled.WrappedBuffer(Encoding.ASCII.GetBytes(req)))); + + var resp = channel.ReadOutbound(); + Assert.Equal(100, resp.Status.Code); + resp.Release(); + Assert.True(channel.WriteInbound(Unpooled.WrappedBuffer(GzHelloWorld))); + + AssertHasInboundMessages(channel, true); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + [Fact] + public void ExpectContinueResponse4() + { + // request with header "Expect: 100-continue" must be replied with one "100 Continue" response + // case 4: ObjectAggregator is up in chain + var decoder = new HttpRequestDecoder(); + var aggregator = new HttpObjectAggregator(1024); + var decompressor = new HttpContentDecompressor(); + var channel = new EmbeddedChannel(decoder, aggregator, decompressor); + string req = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Expect: 100-continue\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + Assert.False(channel.WriteInbound(Unpooled.WrappedBuffer(Encoding.ASCII.GetBytes(req)))); + + var resp = channel.ReadOutbound(); + Assert.NotNull(resp); + Assert.Equal(100, resp.Status.Code); + resp.Release(); + Assert.True(channel.WriteInbound(Unpooled.WrappedBuffer(GzHelloWorld))); + + AssertHasInboundMessages(channel, true); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + sealed class TestHandler : ChannelHandlerAdapter + { + IFullHttpRequest request; + + public IFullHttpRequest Request => this.request; + + public override void ChannelRead(IChannelHandlerContext context, object message) + { + if (message is IFullHttpRequest value) + { + if (Interlocked.CompareExchange(ref this.request, value, null) != null) + { + value.Release(); + } + } + else + { + ReferenceCountUtil.Release(message); + } + } + } + + [Fact] + public void ExpectContinueResetHttpObjectDecoder() + { + // request with header "Expect: 100-continue" must be replied with one "100 Continue" response + // case 5: Test that HttpObjectDecoder correctly resets its internal state after a failed expectation. + var decoder = new HttpRequestDecoder(); + const int MaxBytes = 10; + var aggregator = new HttpObjectAggregator(MaxBytes); + + var testHandler = new TestHandler(); + var channel = new EmbeddedChannel(decoder, aggregator, testHandler); + string req1 = "POST /1 HTTP/1.1\r\n" + + "Content-Length: " + (MaxBytes + 1) + "\r\n" + + "Expect: 100-continue\r\n" + + "\r\n"; + Assert.False(channel.WriteInbound(Unpooled.WrappedBuffer(Encoding.ASCII.GetBytes(req1)))); + + var resp = channel.ReadOutbound(); + Assert.Equal(HttpStatusClass.ClientError, resp.Status.CodeClass); + resp.Release(); + + string req2 = "POST /2 HTTP/1.1\r\n" + + "Content-Length: " + MaxBytes + "\r\n" + + "Expect: 100-continue\r\n" + + "\r\n"; + Assert.False(channel.WriteInbound(Unpooled.WrappedBuffer(Encoding.ASCII.GetBytes(req2)))); + + resp = channel.ReadOutbound(); + Assert.Equal(100, resp.Status.Code); + resp.Release(); + + var content = new byte[MaxBytes]; + Assert.False(channel.WriteInbound(Unpooled.WrappedBuffer(content))); + + IFullHttpRequest req = testHandler.Request; + Assert.NotNull(req); + Assert.Equal("/2", req.Uri); + Assert.Equal(10, req.Content.ReadableBytes); + req.Release(); + + AssertHasInboundMessages(channel, false); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + [Fact] + public void RequestContentLength1() + { + // case 1: test that ContentDecompressor either sets the correct Content-Length header + // or removes it completely (handlers down the chain must rely on LastHttpContent object) + + // force content to be in more than one chunk (5 bytes/chunk) + var decoder = new HttpRequestDecoder(4096, 4096, 5); + var decompressor = new HttpContentDecompressor(); + var channel = new EmbeddedChannel(decoder, decompressor); + string headers = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + IByteBuffer buf = Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(headers), GzHelloWorld); + Assert.True(channel.WriteInbound(buf)); + + Queue req = channel.InboundMessages; + Assert.True(req.Count >= 1); + object o = req.Peek(); + Assert.IsAssignableFrom(o); + var request = (IHttpRequest)o; + if (request.Headers.TryGet(HttpHeaderNames.ContentLength, out ICharSequence v)) + { + Assert.Equal(HelloWorld.Length, long.Parse(v.ToString())); + } + + AssertHasInboundMessages(channel, true); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + [Fact] + public void RequestContentLength2() + { + // case 2: if HttpObjectAggregator is down the chain, then correct Content-Length header must be set + + // force content to be in more than one chunk (5 bytes/chunk) + var decoder = new HttpRequestDecoder(4096, 4096, 5); + var decompressor = new HttpContentDecompressor(); + var aggregator = new HttpObjectAggregator(1024); + var channel = new EmbeddedChannel(decoder, decompressor, aggregator); + string headers = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + IByteBuffer buf = Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(headers), GzHelloWorld); + Assert.True(channel.WriteInbound(buf)); + + var req = channel.ReadInbound(); + Assert.NotNull(req); + Assert.True(req.Headers.TryGet(HttpHeaderNames.ContentLength, out ICharSequence value)); + Assert.Equal(HelloWorld.Length, long.Parse(value.ToString())); + req.Release(); + + AssertHasInboundMessages(channel, false); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + [Fact] + public void ResponseContentLength1() + { + // case 1: test that ContentDecompressor either sets the correct Content-Length header + // or removes it completely (handlers down the chain must rely on LastHttpContent object) + + // force content to be in more than one chunk (5 bytes/chunk) + var decoder = new HttpResponseDecoder(4096, 4096, 5); + var decompressor = new HttpContentDecompressor(); + var channel = new EmbeddedChannel(decoder, decompressor); + string headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + IByteBuffer buf = Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(headers), GzHelloWorld); + Assert.True(channel.WriteInbound(buf)); + + Queue resp = channel.InboundMessages; + Assert.True(resp.Count >= 1); + object o = resp.Peek(); + Assert.IsAssignableFrom(o); + var r = (IHttpResponse)o; + + Assert.False(r.Headers.Contains(HttpHeaderNames.ContentLength)); + Assert.True(r.Headers.TryGet(HttpHeaderNames.TransferEncoding, out ICharSequence transferEncoding)); + Assert.NotNull(transferEncoding); + Assert.Equal(HttpHeaderValues.Chunked, transferEncoding); + + AssertHasInboundMessages(channel, true); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + [Fact] + public void ResponseContentLength2() + { + // case 2: if HttpObjectAggregator is down the chain, then correct Content - Length header must be set + + // force content to be in more than one chunk (5 bytes/chunk) + var decoder = new HttpResponseDecoder(4096, 4096, 5); + var decompressor = new HttpContentDecompressor(); + var aggregator = new HttpObjectAggregator(1024); + var channel = new EmbeddedChannel(decoder, decompressor, aggregator); + string headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + IByteBuffer buf = Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(headers), GzHelloWorld); + Assert.True(channel.WriteInbound(buf)); + + var res = channel.ReadInbound(); + Assert.NotNull(res); + Assert.True(res.Headers.TryGet(HttpHeaderNames.ContentLength, out ICharSequence value)); + Assert.Equal(HelloWorld.Length, long.Parse(value.ToString())); + res.Release(); + + AssertHasInboundMessages(channel, false); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + [Fact] + public void FullHttpRequest() + { + // test that ContentDecoder can be used after the ObjectAggregator + var decoder = new HttpRequestDecoder(4096, 4096, 5); + var aggregator = new HttpObjectAggregator(1024); + var decompressor = new HttpContentDecompressor(); + var channel = new EmbeddedChannel(decoder, aggregator, decompressor); + string headers = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + Assert.True(channel.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(headers), GzHelloWorld))); + + Queue req = channel.InboundMessages; + Assert.True(req.Count > 1); + int contentLength = 0; + contentLength = CalculateContentLength(req, contentLength); + byte[] receivedContent = ReadContent(req, contentLength, true); + Assert.Equal(HelloWorld, Encoding.ASCII.GetString(receivedContent)); + + AssertHasInboundMessages(channel, true); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + [Fact] + public void FullHttpResponse() + { + // test that ContentDecoder can be used after the ObjectAggregator + var decoder = new HttpResponseDecoder(4096, 4096, 5); + var aggregator = new HttpObjectAggregator(1024); + var decompressor = new HttpContentDecompressor(); + var channel = new EmbeddedChannel(decoder, aggregator, decompressor); + string headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + GzHelloWorld.Length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + Assert.True(channel.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(headers), GzHelloWorld))); + + Queue resp = channel.InboundMessages; + Assert.True(resp.Count > 1); + int contentLength = 0; + contentLength = CalculateContentLength(resp, contentLength); + byte[] receivedContent = ReadContent(resp, contentLength, true); + Assert.Equal(HelloWorld, Encoding.ASCII.GetString(receivedContent)); + + AssertHasInboundMessages(channel, true); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + // See https://github.com/netty/netty/issues/5892 + [Fact] + public void FullHttpResponseEOF() + { + // test that ContentDecoder can be used after the ObjectAggregator + var decoder = new HttpResponseDecoder(4096, 4096, 5); + var decompressor = new HttpContentDecompressor(); + var channel = new EmbeddedChannel(decoder, decompressor); + string headers = "HTTP/1.1 200 OK\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + Assert.True(channel.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(headers), GzHelloWorld))); + // This should terminate it. + Assert.True(channel.Finish()); + + Queue resp = channel.InboundMessages; + Assert.True(resp.Count > 1); + int contentLength = 0; + contentLength = CalculateContentLength(resp, contentLength); + byte[] receivedContent = ReadContent(resp, contentLength, false); + Assert.Equal(HelloWorld, Encoding.ASCII.GetString(receivedContent)); + + AssertHasInboundMessages(channel, true); + AssertHasOutboundMessages(channel, false); + Assert.False(channel.Finish()); + } + + // ReSharper disable once ParameterOnlyUsedForPreconditionCheck.Local + static byte[] ReadContent(IEnumerable req, int contentLength, bool hasTransferEncoding) + { + var receivedContent = new byte[contentLength]; + int readCount = 0; + foreach (object o in req) + { + if (o is IHttpContent content) + { + int readableBytes = content.Content.ReadableBytes; + content.Content.ReadBytes(receivedContent, readCount, readableBytes); + readCount += readableBytes; + } + + if (o is IHttpMessage message) + { + Assert.Equal(hasTransferEncoding, message.Headers.Contains(HttpHeaderNames.TransferEncoding)); + } + } + + return receivedContent; + } + + [Fact] + public void CleanupThrows() + { + var decoder = new CleanupDecoder(); + var inboundHandler = new InboundHandler(); + var channel = new EmbeddedChannel(decoder, inboundHandler); + + Assert.True(channel.WriteInbound(new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/"))); + var content = new DefaultHttpContent(Unpooled.Buffer().WriteZero(10)); + Assert.True(channel.WriteInbound(content)); + Assert.Equal(1, content.ReferenceCount); + + Assert.Throws(() => channel.FinishAndReleaseAll()); + Assert.Equal(1, inboundHandler.ChannelInactiveCalled); + Assert.Equal(0, content.ReferenceCount); + } + + sealed class CleanupDecoder : HttpContentDecoder + { + protected override EmbeddedChannel NewContentDecoder(ICharSequence contentEncoding) => new EmbeddedChannel(new Handler()); + + sealed class Handler : ChannelHandlerAdapter + { + public override void ChannelInactive(IChannelHandlerContext context) + { + context.FireExceptionCaught(new DecoderException("CleanupThrows")); + context.FireChannelInactive(); + } + } + } + + sealed class InboundHandler : ChannelHandlerAdapter + { + public int ChannelInactiveCalled; + + public override void ChannelInactive(IChannelHandlerContext context) + { + Interlocked.CompareExchange(ref this.ChannelInactiveCalled, 1, 0); + base.ChannelInactive(context); + } + } + + static int CalculateContentLength(IEnumerable req, int contentLength) + { + foreach (object o in req) + { + if (o is IHttpContent content) + { + Assert.True(content.ReferenceCount > 0); + contentLength += content.Content.ReadableBytes; + } + } + + return contentLength; + } + + static byte[] GzCompress(byte[] input) + { + ZlibEncoder encoder = ZlibCodecFactory.NewZlibEncoder(ZlibWrapper.Gzip); + var channel = new EmbeddedChannel(encoder); + Assert.True(channel.WriteOutbound(Unpooled.WrappedBuffer(input))); + Assert.True(channel.Finish()); // close the channel to indicate end-of-data + + int outputSize = 0; + IByteBuffer o; + var outbound = new List(); + while ((o = channel.ReadOutbound()) != null) + { + outbound.Add(o); + outputSize += o.ReadableBytes; + } + + var output = new byte[outputSize]; + int readCount = 0; + foreach (IByteBuffer b in outbound) + { + int readableBytes = b.ReadableBytes; + b.ReadBytes(output, readCount, readableBytes); + b.Release(); + readCount += readableBytes; + } + Assert.True(channel.InboundMessages.Count == 0 && channel.OutboundMessages.Count == 0); + + return output; + } + + static byte[] GzDecompress(byte[] input) + { + ZlibDecoder decoder = ZlibCodecFactory.NewZlibDecoder(ZlibWrapper.Gzip); + var channel = new EmbeddedChannel(decoder); + Assert.True(channel.WriteInbound(Unpooled.WrappedBuffer(input))); + Assert.True(channel.Finish()); // close the channel to indicate end-of-data + + int outputSize = 0; + IByteBuffer o; + var inbound = new List(); + while ((o = channel.ReadInbound()) != null) + { + inbound.Add(o); + outputSize += o.ReadableBytes; + } + + var output = new byte[outputSize]; + int readCount = 0; + foreach (IByteBuffer b in inbound) + { + int readableBytes = b.ReadableBytes; + b.ReadBytes(output, readCount, readableBytes); + b.Release(); + readCount += readableBytes; + } + Assert.True(channel.InboundMessages.Count == 0 && channel.OutboundMessages.Count == 0); + + return output; + } + + static void AssertHasInboundMessages(EmbeddedChannel channel, bool hasMessages) + { + object o; + if (hasMessages) + { + while (true) + { + o = channel.ReadInbound(); + Assert.NotNull(o); + ReferenceCountUtil.Release(o); + if (o is ILastHttpContent) + { + break; + } + } + } + else + { + o = channel.ReadInbound(); + Assert.Null(o); + } + } + + static void AssertHasOutboundMessages(EmbeddedChannel channel, bool hasMessages) + { + object o; + if (hasMessages) + { + while (true) + { + o = channel.ReadOutbound(); + Assert.NotNull(o); + ReferenceCountUtil.Release(o); + if (o is ILastHttpContent) + { + break; + } + } + } + else + { + o = channel.ReadOutbound(); + Assert.Null(o); + } + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpContentEncoderTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpContentEncoderTest.cs new file mode 100644 index 0000000..f0a71ae --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpContentEncoderTest.cs @@ -0,0 +1,457 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Text; + using System.Threading; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpContentEncoderTest + { + sealed class TestEncoder : HttpContentEncoder + { + protected override Result BeginEncode(IHttpResponse headers, ICharSequence acceptEncoding) => + new Result(new StringCharSequence("test"), new EmbeddedChannel(new EmbeddedMessageEncoder())); + } + + sealed class EmbeddedMessageEncoder : MessageToByteEncoder + { + protected override void Encode(IChannelHandlerContext context, IByteBuffer message, IByteBuffer output) + { + output.WriteBytes(Encoding.ASCII.GetBytes(Convert.ToString(message.ReadableBytes))); + message.SkipBytes(message.ReadableBytes); + } + } + + [Fact] + public void SplitContent() + { + var ch = new EmbeddedChannel(new TestEncoder()); + ch.WriteInbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/")); + + ch.WriteOutbound(new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK)); + ch.WriteOutbound(new DefaultHttpContent(Unpooled.WrappedBuffer(new byte[3]))); + ch.WriteOutbound(new DefaultHttpContent(Unpooled.WrappedBuffer(new byte[2]))); + ch.WriteOutbound(new DefaultLastHttpContent(Unpooled.WrappedBuffer(new byte[1]))); + + AssertEncodedResponse(ch); + + var chunk = ch.ReadOutbound(); + Assert.Equal("3", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("2", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("1", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.False(chunk.Content.IsReadable()); + Assert.IsAssignableFrom(chunk); + chunk.Release(); + + var next = ch.ReadOutbound(); + Assert.Null(next); + } + + [Fact] + public void ChunkedContent() + { + var ch = new EmbeddedChannel(new TestEncoder()); + ch.WriteInbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/")); + + var res = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + res.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + ch.WriteOutbound(res); + + AssertEncodedResponse(ch); + + ch.WriteOutbound(new DefaultHttpContent(Unpooled.WrappedBuffer(new byte[3]))); + ch.WriteOutbound(new DefaultHttpContent(Unpooled.WrappedBuffer(new byte[2]))); + ch.WriteOutbound(new DefaultLastHttpContent(Unpooled.WrappedBuffer(new byte[1]))); + + var chunk = ch.ReadOutbound(); + Assert.Equal("3", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("2", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("1", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.False(chunk.Content.IsReadable()); + Assert.IsAssignableFrom(chunk); + chunk.Release(); + + var next = ch.ReadOutbound(); + Assert.Null(next); + } + + [Fact] + public void ChunkedContentWithTrailingHeader() + { + var ch = new EmbeddedChannel(new TestEncoder()); + ch.WriteInbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/")); + + var res = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + res.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + ch.WriteOutbound(res); + + AssertEncodedResponse(ch); + + ch.WriteOutbound(new DefaultHttpContent(Unpooled.WrappedBuffer(new byte[3]))); + ch.WriteOutbound(new DefaultHttpContent(Unpooled.WrappedBuffer(new byte[2]))); + var content = new DefaultLastHttpContent(Unpooled.WrappedBuffer(new byte[1])); + content.TrailingHeaders.Set((AsciiString)"X-Test", (AsciiString)"Netty"); + ch.WriteOutbound(content); + + var chunk = ch.ReadOutbound(); + Assert.Equal("3", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("2", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.Equal("1", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + var last = ch.ReadOutbound(); + Assert.NotNull(last); + Assert.False(last.Content.IsReadable()); + Assert.Equal("Netty", last.TrailingHeaders.Get((AsciiString)"X-Test", null).ToString()); + last.Release(); + + var next = ch.ReadOutbound(); + Assert.Null(next); + } + + [Fact] + public void FullContentWithContentLength() + { + var ch = new EmbeddedChannel(new TestEncoder()); + ch.WriteInbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/")); + + var fullRes = new DefaultFullHttpResponse( + HttpVersion.Http11, HttpResponseStatus.OK, Unpooled.WrappedBuffer(new byte[42])); + fullRes.Headers.Set(HttpHeaderNames.ContentLength, 42); + ch.WriteOutbound(fullRes); + + var res = ch.ReadOutbound(); + Assert.NotNull(res); + Assert.False(res is IHttpContent, $"{res.GetType()}"); + Assert.False(res.Headers.TryGet(HttpHeaderNames.TransferEncoding, out _)); + Assert.Equal("2", res.Headers.Get(HttpHeaderNames.ContentLength, null).ToString()); + Assert.Equal("test", res.Headers.Get(HttpHeaderNames.ContentEncoding, null).ToString()); + + var c = ch.ReadOutbound(); + Assert.Equal(2, c.Content.ReadableBytes); + Assert.Equal("42", c.Content.ToString(Encoding.ASCII)); + c.Release(); + + var last = ch.ReadOutbound(); + Assert.Equal(0, last.Content.ReadableBytes); + last.Release(); + + var next = ch.ReadOutbound(); + Assert.Null(next); + } + + [Fact] + public void FullContent() + { + var ch = new EmbeddedChannel(new TestEncoder()); + ch.WriteInbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/")); + + var res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, + Unpooled.WrappedBuffer(new byte[42])); + ch.WriteOutbound(res); + + AssertEncodedResponse(ch); + var c = ch.ReadOutbound(); + Assert.NotNull(c); + Assert.Equal(2, c.Content.ReadableBytes); + Assert.Equal("42", c.Content.ToString(Encoding.ASCII)); + c.Release(); + + var last = ch.ReadOutbound(); + Assert.NotNull(last); + Assert.Equal(0, last.Content.ReadableBytes); + last.Release(); + + var next = ch.ReadOutbound(); + Assert.Null(next); + } + + // If the length of the content is unknown, {@link HttpContentEncoder} should not skip encoding the content + // even if the actual length is turned out to be 0. + [Fact] + public void EmptySplitContent() + { + var ch = new EmbeddedChannel(new TestEncoder()); + ch.WriteInbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/")); + + ch.WriteOutbound(new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK)); + + AssertEncodedResponse(ch); + + ch.WriteOutbound(EmptyLastHttpContent.Default); + + var chunk = ch.ReadOutbound(); + Assert.NotNull(chunk); + Assert.Equal("0", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + var last = ch.ReadOutbound(); + Assert.False(last.Content.IsReadable()); + last.Release(); + + var next = ch.ReadOutbound(); + Assert.Null(next); + } + + // If the length of the content is 0 for sure, {@link HttpContentEncoder} should skip encoding. + [Fact] + public void EmptyFullContent() + { + var ch = new EmbeddedChannel(new TestEncoder()); + ch.WriteInbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/")); + + IFullHttpResponse res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK, Unpooled.Empty); + ch.WriteOutbound(res); + + res = ch.ReadOutbound(); + Assert.NotNull(res); + Assert.False(res.Headers.TryGet(HttpHeaderNames.TransferEncoding, out _)); + + // Content encoding shouldn't be modified. + Assert.False(res.Headers.TryGet(HttpHeaderNames.ContentEncoding, out _)); + Assert.Equal(0, res.Content.ReadableBytes); + Assert.Equal("", res.Content.ToString(Encoding.ASCII)); + res.Release(); + + var next = ch.ReadOutbound(); + Assert.Null(next); + } + + [Fact] + public void EmptyFullContentWithTrailer() + { + var ch = new EmbeddedChannel(new TestEncoder()); + ch.WriteInbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/")); + + IFullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.Http11, HttpResponseStatus.OK, Unpooled.Empty); + res.TrailingHeaders.Set((AsciiString)"X-Test", (StringCharSequence)"Netty"); + ch.WriteOutbound(res); + + res = ch.ReadOutbound(); + Assert.NotNull(res); + Assert.False(res.Headers.TryGet(HttpHeaderNames.TransferEncoding, out _)); + + // Content encoding shouldn't be modified. + Assert.False(res.Headers.TryGet(HttpHeaderNames.ContentEncoding, out _)); + Assert.Equal(0, res.Content.ReadableBytes); + Assert.Equal("", res.Content.ToString(Encoding.ASCII)); + Assert.Equal("Netty", res.TrailingHeaders.Get((AsciiString)"X-Test", null)); + + var next = ch.ReadOutbound(); + Assert.Null(next); + } + + [Fact] + public void EmptyHeadResponse() + { + var ch = new EmbeddedChannel(new TestEncoder()); + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Head, "/"); + ch.WriteInbound(req); + + var res = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + res.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + ch.WriteOutbound(res); + ch.WriteOutbound(EmptyLastHttpContent.Default); + + AssertEmptyResponse(ch); + } + + [Fact] + public void Http304Response() + { + var ch = new EmbeddedChannel(new TestEncoder()); + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Head, "/"); + req.Headers.Set(HttpHeaderNames.AcceptEncoding, HttpHeaderValues.Gzip); + ch.WriteInbound(req); + + var res = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.NotModified); + res.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + ch.WriteOutbound(res); + ch.WriteOutbound(EmptyLastHttpContent.Default); + + AssertEmptyResponse(ch); + } + + [Fact] + public void Connect200Response() + { + var ch = new EmbeddedChannel(new TestEncoder()); + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Connect, "google.com:80"); + ch.WriteInbound(req); + + var res = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + res.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + ch.WriteOutbound(res); + ch.WriteOutbound(EmptyLastHttpContent.Default); + + AssertEmptyResponse(ch); + } + + [Fact] + public void ConnectFailureResponse() + { + const string Content = "Not allowed by configuration"; + + var ch = new EmbeddedChannel(new TestEncoder()); + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Connect, "google.com:80"); + ch.WriteInbound(req); + + var res = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.MethodNotAllowed); + res.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + ch.WriteOutbound(res); + ch.WriteOutbound(new DefaultHttpContent(Unpooled.WrappedBuffer(Encoding.ASCII.GetBytes(Content)))); + ch.WriteOutbound(EmptyLastHttpContent.Default); + + AssertEncodedResponse(ch); + + var chunk = ch.ReadOutbound(); + Assert.NotNull(chunk); + Assert.Equal("28", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + chunk = ch.ReadOutbound(); + Assert.True(chunk.Content.IsReadable()); + Assert.Equal("0", chunk.Content.ToString(Encoding.ASCII)); + chunk.Release(); + + var last = ch.ReadOutbound(); + Assert.NotNull(last); + last.Release(); + + var next = ch.ReadOutbound(); + Assert.Null(next); + } + + [Fact] + public void Http10() + { + var ch = new EmbeddedChannel(new TestEncoder()); + var req = new DefaultFullHttpRequest(HttpVersion.Http10, HttpMethod.Get, "/"); + Assert.True(ch.WriteInbound(req)); + + var res = new DefaultHttpResponse(HttpVersion.Http10, HttpResponseStatus.OK); + res.Headers.Set(HttpHeaderNames.ContentLength, HttpHeaderValues.Zero); + Assert.True(ch.WriteOutbound(res)); + Assert.True(ch.WriteOutbound(EmptyLastHttpContent.Default)); + Assert.True(ch.Finish()); + + var request = ch.ReadInbound(); + Assert.True(request.Release()); + var next = ch.ReadInbound(); + Assert.Null(next); + + var response = ch.ReadOutbound(); + Assert.Same(res, response); + + var content = ch.ReadOutbound(); + Assert.Same(content, EmptyLastHttpContent.Default); + content.Release(); + + next = ch.ReadOutbound(); + Assert.Null(next); + } + + [Fact] + public void CleanupThrows() + { + var encoder = new CleanupEncoder(); + var inboundHandler = new InboundHandler(); + var channel = new EmbeddedChannel(encoder, inboundHandler); + + Assert.True(channel.WriteInbound(new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/"))); + Assert.True(channel.WriteOutbound(new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK))); + var content = new DefaultHttpContent(Unpooled.Buffer().WriteZero(10)); + Assert.True(channel.WriteOutbound(content)); + Assert.Equal(1, content.ReferenceCount); + + Assert.Throws(() => channel.FinishAndReleaseAll()); + Assert.Equal(1, inboundHandler.ChannelInactiveCalled); + Assert.Equal(0, content.ReferenceCount); + } + + sealed class CleanupEncoder : HttpContentEncoder + { + protected override Result BeginEncode(IHttpResponse headers, ICharSequence acceptEncoding) => + new Result(new StringCharSequence("myencoding"), new EmbeddedChannel(new Handler())); + + sealed class Handler : ChannelHandlerAdapter + { + public override void ChannelInactive(IChannelHandlerContext context) + { + context.FireExceptionCaught(new EncoderException("CleanupThrows")); + context.FireChannelInactive(); + } + } + } + + sealed class InboundHandler : ChannelHandlerAdapter + { + public int ChannelInactiveCalled; + + public override void ChannelInactive(IChannelHandlerContext context) + { + Interlocked.CompareExchange(ref this.ChannelInactiveCalled, 1, 0); + base.ChannelInactive(context); + } + } + + static void AssertEmptyResponse(EmbeddedChannel ch) + { + var res = ch.ReadOutbound(); + Assert.NotNull(res); + Assert.False(res is IHttpContent); + Assert.Equal("chunked", res.Headers.Get(HttpHeaderNames.TransferEncoding, null).ToString()); + Assert.False(res.Headers.TryGet(HttpHeaderNames.ContentLength, out _)); + + var chunk = ch.ReadOutbound(); + Assert.NotNull(chunk); + chunk.Release(); + + var next = ch.ReadOutbound(); + Assert.Null(next); + } + + static void AssertEncodedResponse(EmbeddedChannel ch) + { + var res = ch.ReadOutbound(); + Assert.NotNull(res); + + Assert.False(res is IHttpContent, $"{res.GetType()}"); + Assert.Equal("chunked", res.Headers.Get(HttpHeaderNames.TransferEncoding, null).ToString()); + Assert.False(res.Headers.TryGet(HttpHeaderNames.ContentLength, out _)); + Assert.Equal("test", res.Headers.Get(HttpHeaderNames.ContentEncoding, null).ToString()); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpHeadersTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpHeadersTest.cs new file mode 100644 index 0000000..958e571 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpHeadersTest.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Collections.Generic; + using DotNetty.Common.Utilities; + using Xunit; + + public sealed class HttpHeadersTest + { + [Fact] + public void RemoveTransferEncodingIgnoreCase() + { + var message = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + message.Headers.Set(HttpHeaderNames.TransferEncoding, "Chunked"); + Assert.False(message.Headers.IsEmpty); + HttpUtil.SetTransferEncodingChunked(message, false); + Assert.True(message.Headers.IsEmpty); + } + + // Test for https://github.com/netty/netty/issues/1690 + [Fact] + public void GetOperations() + { + HttpHeaders headers = new DefaultHttpHeaders(); + headers.Add(HttpHeadersTestUtils.Of("Foo"), HttpHeadersTestUtils.Of("1")); + headers.Add(HttpHeadersTestUtils.Of("Foo"), HttpHeadersTestUtils.Of("2")); + + Assert.Equal("1", headers.Get(HttpHeadersTestUtils.Of("Foo"), null)); + + IList values = headers.GetAll(HttpHeadersTestUtils.Of("Foo")); + Assert.Equal(2, values.Count); + Assert.Equal("1", values[0].ToString()); + Assert.Equal("2", values[1].ToString()); + } + + [Fact] + public void EqualsIgnoreCase() + { + Assert.True(AsciiString.ContentEqualsIgnoreCase(null, null)); + Assert.False(AsciiString.ContentEqualsIgnoreCase(null, (StringCharSequence)"foo")); + Assert.False(AsciiString.ContentEqualsIgnoreCase((StringCharSequence)"bar", null)); + Assert.True(AsciiString.ContentEqualsIgnoreCase((StringCharSequence)"FoO", (StringCharSequence)"fOo")); + } + + [Fact] + public void AddSelf() + { + HttpHeaders headers = new DefaultHttpHeaders(false); + Assert.Throws(() => headers.Add(headers)); + } + + [Fact] + public void SetSelfIsNoOp() + { + HttpHeaders headers = new DefaultHttpHeaders(false); + headers.Add((AsciiString)"name", (StringCharSequence)"value"); + headers.Set(headers); + Assert.Equal(1, headers.Size); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpHeadersTestUtils.cs b/test/DotNetty.Codecs.Http.Tests/HttpHeadersTestUtils.cs new file mode 100644 index 0000000..c30069e --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpHeadersTestUtils.cs @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System.Collections.Generic; + using System.Text; + using DotNetty.Common.Utilities; + using Xunit; + + static class HttpHeadersTestUtils + { + static readonly IReadOnlyDictionary ValueMap = new Dictionary + { + { 0, HeaderValue.Unknown }, + { 1, HeaderValue.One }, + { 2, HeaderValue.Two }, + { 3, HeaderValue.Three }, + { 4, HeaderValue.Four }, + { 5, HeaderValue.Five }, + { 6, HeaderValue.SixQuoted }, + { 7, HeaderValue.SevenQuoted }, + { 8, HeaderValue.Eight } + }; + + public class HeaderValue + { + public static readonly HeaderValue Unknown = new HeaderValue("Unknown", 0); + public static readonly HeaderValue One = new HeaderValue("One", 1); + public static readonly HeaderValue Two = new HeaderValue("Two", 2); + public static readonly HeaderValue Three = new HeaderValue("Three", 3); + public static readonly HeaderValue Four = new HeaderValue("Four", 4); + public static readonly HeaderValue Five = new HeaderValue("Five", 5); + public static readonly HeaderValue SixQuoted = new HeaderValue("Six,", 6); + public static readonly HeaderValue SevenQuoted = new HeaderValue("Seven; , GMT", 7); + public static readonly HeaderValue Eight = new HeaderValue("Eight", 8); + + readonly int nr; + readonly string value; + List array; + + HeaderValue(string value, int nr) + { + this.nr = nr; + this.value = value; + } + + public override string ToString() => this.value; + + public List Subset(int from) + { + Assert.True(from > 0); + --from; + int size = this.nr - from; + int end = from + size; + var list = new List(size); + List fullList = this.AsList(); + for (int i = from; i < end; ++i) + { + list.Add(fullList[i]); + } + + return list; + } + + public string SubsetAsCsvString(int from) + { + List subset = this.Subset(from); + return this.AsCsv(subset); + } + + public List AsList() + { + if (this.array == null) + { + var list = new List(this.nr); + for (int i = 1; i <= this.nr; i++) + { + list.Add(new StringCharSequence(Of(i).ToString())); + } + + this.array = list; + } + + return this.array; + } + + public string AsCsv(IList arr) + { + if (arr == null || arr.Count == 0) + { + return ""; + } + + var sb = new StringBuilder(arr.Count * 10); + int end = arr.Count - 1; + for (int i = 0; i < end; ++i) + { + Quoted(sb, arr[i]).Append(StringUtil.Comma); + } + + Quoted(sb, arr[end]); + return sb.ToString(); + } + + public ICharSequence AsCsv() => (StringCharSequence)this.AsCsv(this.AsList()); + + public static HeaderValue Of(int nr) => ValueMap.TryGetValue(nr, out HeaderValue v) ? v : Unknown; + } + + public static AsciiString Of(string s) => new AsciiString(s); + + static StringBuilder Quoted(StringBuilder sb, ICharSequence value) + { + if (Contains(value, StringUtil.Comma) && !Contains(value, StringUtil.DoubleQuote)) + { + return sb.Append(StringUtil.DoubleQuote) + .Append(value) + .Append(StringUtil.DoubleQuote); + } + + return sb.Append(value); + } + + static bool Contains(IEnumerable value, char c) + { + foreach (char t in value) + { + if (t == c) + { + return true; + } + } + + return false; + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpInvalidMessageTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpInvalidMessageTest.cs new file mode 100644 index 0000000..5f2267e --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpInvalidMessageTest.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpInvalidMessageTest + { + readonly Random rnd = new Random(); + + [Fact] + public void RequestWithBadInitialLine() + { + var ch = new EmbeddedChannel(new HttpRequestDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("GET / HTTP/1.0 with extra\r\n"))); + var req = ch.ReadInbound(); + DecoderResult dr = req.Result; + Assert.NotNull(dr); + Assert.False(dr.IsSuccess); + Assert.True(dr.IsFailure); + this.EnsureInboundTrafficDiscarded(ch); + } + + [Fact] + public void RequestWithBadHeader() + { + var ch = new EmbeddedChannel(new HttpRequestDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("GET /maybe-something HTTP/1.0\r\n"))); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("Good_Name: Good Value\r\n"))); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("Bad=Name: Bad Value\r\n"))); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("\r\n"))); + var req = ch.ReadInbound(); + DecoderResult dr = req.Result; + Assert.NotNull(dr); + Assert.False(dr.IsSuccess); + Assert.True(dr.IsFailure); + Assert.Equal("Good Value", req.Headers.Get((AsciiString)"Good_Name", null).ToString()); + Assert.Equal("/maybe-something", req.Uri); + this.EnsureInboundTrafficDiscarded(ch); + } + + [Fact] + public void ResponseWithBadInitialLine() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("HTTP/1.0 BAD_CODE Bad Server\r\n"))); + var res = ch.ReadInbound(); + DecoderResult dr = res.Result; + Assert.NotNull(dr); + Assert.False(dr.IsSuccess); + Assert.True(dr.IsFailure); + this.EnsureInboundTrafficDiscarded(ch); + } + + [Fact] + public void ResponseWithBadHeader() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("HTTP/1.0 200 Maybe OK\r\n"))); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("Good_Name: Good Value\r\n"))); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("Bad=Name: Bad Value\r\n"))); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("\r\n"))); + var res = ch.ReadInbound(); + DecoderResult dr = res.Result; + Assert.NotNull(dr); + Assert.False(dr.IsSuccess); + Assert.True(dr.IsFailure); + Assert.Equal("Maybe OK", res.Status.ReasonPhrase); + Assert.Equal("Good Value", res.Headers.Get((AsciiString)"Good_Name", null).ToString()); + this.EnsureInboundTrafficDiscarded(ch); + } + + [Fact] + public void BadChunk() + { + var ch = new EmbeddedChannel(new HttpRequestDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("GET / HTTP/1.0\r\n"))); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("Transfer-Encoding: chunked\r\n\r\n"))); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("BAD_LENGTH\r\n"))); + var req = ch.ReadInbound(); + DecoderResult dr = req.Result; + Assert.NotNull(dr); + Assert.True(dr.IsSuccess); + var chunk = ch.ReadInbound(); + dr = chunk.Result; + Assert.False(dr.IsSuccess); + Assert.True(dr.IsFailure); + this.EnsureInboundTrafficDiscarded(ch); + } + + void EnsureInboundTrafficDiscarded(EmbeddedChannel ch) + { + // Generate a lot of random traffic to ensure that it's discarded silently. + var data = new byte[1048576]; + this.rnd.NextBytes(data); + + IByteBuffer buf = Unpooled.WrappedBuffer(data); + for (int i = 0; i < 4096; i++) + { + buf.SetIndex(0, data.Length); + ch.WriteInbound(buf.Retain()); + ch.CheckException(); + Assert.Null(ch.ReadInbound()); + } + buf.Release(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpObjectAggregatorTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpObjectAggregatorTest.cs new file mode 100644 index 0000000..2acfc48 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpObjectAggregatorTest.cs @@ -0,0 +1,500 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Collections.Generic; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpObjectAggregatorTest + { + [Fact] + public void Aggregate() + { + var aggregator = new HttpObjectAggregator(1024 * 1024); + var ch = new EmbeddedChannel(aggregator); + + var message = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Get, "http://localhost"); + message.Headers.Set((AsciiString)"X-Test", true); + IHttpContent chunk1 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test"))); + IHttpContent chunk2 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test2"))); + IHttpContent chunk3 = new DefaultLastHttpContent(Unpooled.Empty); + + Assert.False(ch.WriteInbound(message)); + Assert.False(ch.WriteInbound(chunk1)); + Assert.False(ch.WriteInbound(chunk2)); + + // this should trigger a channelRead event so return true + Assert.True(ch.WriteInbound(chunk3)); + Assert.True(ch.Finish()); + var aggregatedMessage = ch.ReadInbound(); + Assert.NotNull(aggregatedMessage); + + Assert.Equal(chunk1.Content.ReadableBytes + chunk2.Content.ReadableBytes, HttpUtil.GetContentLength(aggregatedMessage)); + Assert.Equal(bool.TrueString, aggregatedMessage.Headers.Get((AsciiString)"X-Test", null)?.ToString()); + CheckContentBuffer(aggregatedMessage); + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void AggregateWithTrailer() + { + var aggregator = new HttpObjectAggregator(1024 * 1024); + var ch = new EmbeddedChannel(aggregator); + var message = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Get, "http://localhost"); + message.Headers.Set((AsciiString)"X-Test", true); + HttpUtil.SetTransferEncodingChunked(message, true); + var chunk1 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test"))); + var chunk2 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test2"))); + var trailer = new DefaultLastHttpContent(); + trailer.TrailingHeaders.Set((AsciiString)"X-Trailer", true); + + Assert.False(ch.WriteInbound(message)); + Assert.False(ch.WriteInbound(chunk1)); + Assert.False(ch.WriteInbound(chunk2)); + + // this should trigger a channelRead event so return true + Assert.True(ch.WriteInbound(trailer)); + Assert.True(ch.Finish()); + var aggregatedMessage = ch.ReadInbound(); + Assert.NotNull(aggregatedMessage); + + Assert.Equal(chunk1.Content.ReadableBytes + chunk2.Content.ReadableBytes, HttpUtil.GetContentLength(aggregatedMessage)); + Assert.Equal(bool.TrueString, aggregatedMessage.Headers.Get((AsciiString)"X-Test", null)?.ToString()); + Assert.Equal(bool.TrueString, aggregatedMessage.TrailingHeaders.Get((AsciiString)"X-Trailer", null)?.ToString()); + CheckContentBuffer(aggregatedMessage); + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void OversizedRequest() + { + var aggregator = new HttpObjectAggregator(4); + var ch = new EmbeddedChannel(aggregator); + var message = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Put, "http://localhost"); + var chunk1 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test"))); + var chunk2 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test2"))); + EmptyLastHttpContent chunk3 = EmptyLastHttpContent.Default; + + Assert.False(ch.WriteInbound(message)); + Assert.False(ch.WriteInbound(chunk1)); + Assert.False(ch.WriteInbound(chunk2)); + + var response = ch.ReadOutbound(); + Assert.Equal(HttpResponseStatus.RequestEntityTooLarge, response.Status); + Assert.Equal("0", response.Headers.Get(HttpHeaderNames.ContentLength, null)); + Assert.False(ch.Open); + + try + { + Assert.False(ch.WriteInbound(chunk3)); + Assert.True(false, "Shoud not get here, expecting exception thrown."); + } + catch (Exception e) + { + Assert.True(e is ClosedChannelException); + } + + Assert.False(ch.Finish()); + } + + [Fact] + public void OversizedRequestWithoutKeepAlive() + { + // send a HTTP/1.0 request with no keep-alive header + var message = new DefaultHttpRequest(HttpVersion.Http10, HttpMethod.Put, "http://localhost"); + HttpUtil.SetContentLength(message, 5); + CheckOversizedRequest(message); + } + + [Fact] + public void OversizedRequestWithContentLength() + { + var message = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Put, "http://localhost"); + HttpUtil.SetContentLength(message, 5); + CheckOversizedRequest(message); + } + + [Fact] + public void OversizedResponse() + { + var aggregator = new HttpObjectAggregator(4); + var ch = new EmbeddedChannel(aggregator); + var message = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + var chunk1 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test"))); + var chunk2 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test2"))); + + Assert.False(ch.WriteInbound(message)); + Assert.False(ch.WriteInbound(chunk1)); + Assert.Throws(() => ch.WriteInbound(chunk2)); + + Assert.False(ch.Open); + Assert.False(ch.Finish()); + } + + [Fact] + public void InvalidConstructorUsage() + { + var error = Assert.Throws(() => new HttpObjectAggregator(-1)); + Assert.Equal("maxContentLength", error.ParamName); + } + + [Fact] + public void InvalidMaxCumulationBufferComponents() + { + var aggregator = new HttpObjectAggregator(int.MaxValue); + Assert.Throws(() => aggregator.MaxCumulationBufferComponents = 1); + } + + [Fact] + public void SetMaxCumulationBufferComponentsAfterInit() + { + var aggregator = new HttpObjectAggregator(int.MaxValue); + var ch = new EmbeddedChannel(aggregator); + Assert.Throws(() => aggregator.MaxCumulationBufferComponents = 10); + Assert.False(ch.Finish()); + } + + [Fact] + public void AggregateTransferEncodingChunked() + { + var aggregator = new HttpObjectAggregator(1024 * 1024); + var ch = new EmbeddedChannel(aggregator); + + var message = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Put, "http://localhost"); + message.Headers.Set((AsciiString)"X-Test", true); + message.Headers.Set((AsciiString)"Transfer-Encoding", (AsciiString)"Chunked"); + var chunk1 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test"))); + var chunk2 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test2"))); + EmptyLastHttpContent chunk3 = EmptyLastHttpContent.Default; + Assert.False(ch.WriteInbound(message)); + Assert.False(ch.WriteInbound(chunk1)); + Assert.False(ch.WriteInbound(chunk2)); + + // this should trigger a channelRead event so return true + Assert.True(ch.WriteInbound(chunk3)); + Assert.True(ch.Finish()); + var aggregatedMessage = ch.ReadInbound(); + Assert.NotNull(aggregatedMessage); + + Assert.Equal(chunk1.Content.ReadableBytes + chunk2.Content.ReadableBytes, HttpUtil.GetContentLength(aggregatedMessage)); + Assert.Equal(bool.TrueString, aggregatedMessage.Headers.Get((AsciiString)"X-Test", null)); + CheckContentBuffer(aggregatedMessage); + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void BadRequest() + { + var ch = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(1024 * 1024)); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("GET / HTTP/1.0 with extra\r\n"))); + var req = ch.ReadInbound(); + Assert.NotNull(req); + Assert.True(req.Result.IsFailure); + var last = ch.ReadInbound(); + Assert.Null(last); + ch.Finish(); + } + + [Fact] + public void BadResponse() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder(), new HttpObjectAggregator(1024 * 1024)); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("HTTP/1.0 BAD_CODE Bad Server\r\n"))); + var resp = ch.ReadInbound(); + Assert.NotNull(resp); + Assert.True(resp.Result.IsFailure); + var last = ch.ReadInbound(); + Assert.Null(last); + ch.Finish(); + } + + [Fact] + public void OversizedRequestWith100Continue() + { + var ch = new EmbeddedChannel(new HttpObjectAggregator(8)); + + // Send an oversized request with 100 continue. + var message = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Put, "http://localhost"); + HttpUtil.Set100ContinueExpected(message, true); + HttpUtil.SetContentLength(message, 16); + + var chunk1 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("some"))); + var chunk2 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test"))); + EmptyLastHttpContent chunk3 = EmptyLastHttpContent.Default; + + // Send a request with 100-continue + large Content-Length header value. + Assert.False(ch.WriteInbound(message)); + + // The aggregator should respond with '413.' + var response = ch.ReadOutbound(); + Assert.Equal(HttpResponseStatus.RequestEntityTooLarge, response.Status); + Assert.Equal((AsciiString)"0", response.Headers.Get(HttpHeaderNames.ContentLength, null)); + + // An ill-behaving client could continue to send data without a respect, and such data should be discarded. + Assert.False(ch.WriteInbound(chunk1)); + + // The aggregator should not close the connection because keep-alive is on. + Assert.True(ch.Open); + + // Now send a valid request. + var message2 = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Put, "http://localhost"); + + Assert.False(ch.WriteInbound(message2)); + Assert.False(ch.WriteInbound(chunk2)); + Assert.True(ch.WriteInbound(chunk3)); + + var fullMsg = ch.ReadInbound(); + Assert.NotNull(fullMsg); + + Assert.Equal(chunk2.Content.ReadableBytes + chunk3.Content.ReadableBytes, HttpUtil.GetContentLength(fullMsg)); + Assert.Equal(HttpUtil.GetContentLength(fullMsg), fullMsg.Content.ReadableBytes); + + fullMsg.Release(); + Assert.False(ch.Finish()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void UnsupportedExpectHeaderExpectation(bool close) + { + int maxContentLength = 4; + var aggregator = new HttpObjectAggregator(maxContentLength, close); + var ch = new EmbeddedChannel(new HttpRequestDecoder(), aggregator); + + Assert.False(ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes( + "GET / HTTP/1.1\r\n" + + "Expect: chocolate=yummy\r\n" + + "Content-Length: 100\r\n\r\n")))); + var next = ch.ReadInbound(); + Assert.Null(next); + + var response = ch.ReadOutbound(); + Assert.Equal(HttpResponseStatus.ExpectationFailed, response.Status); + Assert.Equal((AsciiString)"0", response.Headers.Get(HttpHeaderNames.ContentLength, null)); + response.Release(); + + if (close) + { + Assert.False(ch.Open); + } + else + { + // keep-alive is on by default in HTTP/1.1, so the connection should be still alive + Assert.True(ch.Open); + + // the decoder should be reset by the aggregator at this point and be able to decode the next request + Assert.True(ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\n\r\n")))); + + var request = ch.ReadInbound(); + Assert.NotNull(request); + Assert.Equal(HttpMethod.Get, request.Method); + Assert.Equal("/", request.Uri); + Assert.Equal(0, request.Content.ReadableBytes); + request.Release(); + } + + Assert.False(ch.Finish()); + } + + [Fact] + public void OversizedRequestWith100ContinueAndDecoder() + { + var ch = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(4)); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes( + "PUT /upload HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length: 100\r\n\r\n"))); + + var next = ch.ReadInbound(); + Assert.Null(next); + + var response = ch.ReadOutbound(); + Assert.Equal(HttpResponseStatus.RequestEntityTooLarge, response.Status); + Assert.Equal((AsciiString)"0", response.Headers.Get(HttpHeaderNames.ContentLength, null)); + + // Keep-alive is on by default in HTTP/1.1, so the connection should be still alive. + Assert.True(ch.Open); + + // The decoder should be reset by the aggregator at this point and be able to decode the next request. + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("GET /max-upload-size HTTP/1.1\r\n\r\n"))); + + var request = ch.ReadInbound(); + Assert.Equal(HttpMethod.Get, request.Method); + Assert.Equal("/max-upload-size", request.Uri); + Assert.Equal(0, request.Content.ReadableBytes); + request.Release(); + + Assert.False(ch.Finish()); + } + + [Fact] + public void OversizedRequestWith100ContinueAndDecoderCloseConnection() + { + var ch = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(4, true)); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes( + "PUT /upload HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length: 100\r\n\r\n"))); + + var next = ch.ReadInbound(); + Assert.Null(next); + + var response = ch.ReadOutbound(); + Assert.Equal(HttpResponseStatus.RequestEntityTooLarge, response.Status); + Assert.Equal((AsciiString)"0", response.Headers.Get(HttpHeaderNames.ContentLength, null)); + + // We are forcing the connection closed if an expectation is exceeded. + Assert.False(ch.Open); + Assert.False(ch.Finish()); + } + + [Fact] + public void RequestAfterOversized100ContinueAndDecoder() + { + var ch = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(15)); + + // Write first request with Expect: 100-continue. + var message = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Put, "http://localhost"); + HttpUtil.Set100ContinueExpected(message, true); + HttpUtil.SetContentLength(message, 16); + + var chunk1 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("some"))); + var chunk2 = new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("test"))); + EmptyLastHttpContent chunk3 = EmptyLastHttpContent.Default; + + // Send a request with 100-continue + large Content-Length header value. + Assert.False(ch.WriteInbound(message)); + + // The aggregator should respond with '413'. + var response = ch.ReadOutbound(); + Assert.Equal(HttpResponseStatus.RequestEntityTooLarge, response.Status); + Assert.Equal((AsciiString)"0", response.Headers.Get(HttpHeaderNames.ContentLength, null)); + + // An ill-behaving client could continue to send data without a respect, and such data should be discarded. + Assert.False(ch.WriteInbound(chunk1)); + + // The aggregator should not close the connection because keep-alive is on. + Assert.True(ch.Open); + + // Now send a valid request. + var message2 = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Put, "http://localhost"); + + Assert.False(ch.WriteInbound(message2)); + Assert.False(ch.WriteInbound(chunk2)); + Assert.True(ch.WriteInbound(chunk3)); + + var fullMsg = ch.ReadInbound(); + Assert.NotNull(fullMsg); + + Assert.Equal(chunk2.Content.ReadableBytes + chunk3.Content.ReadableBytes,HttpUtil.GetContentLength(fullMsg)); + Assert.Equal(HttpUtil.GetContentLength(fullMsg), fullMsg.Content.ReadableBytes); + + fullMsg.Release(); + Assert.False(ch.Finish()); + } + + [Fact] + public void ReplaceAggregatedRequest() + { + var ch = new EmbeddedChannel(new HttpObjectAggregator(1024 * 1024)); + + var boom = new Exception("boom"); + var req = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Get, "http://localhost"); + req.Result = DecoderResult.Failure(boom); + + Assert.True(ch.WriteInbound(req) && ch.Finish()); + + var aggregatedReq = ch.ReadInbound(); + var replacedReq = (IFullHttpRequest)aggregatedReq.Replace(Unpooled.Empty); + + Assert.Equal(replacedReq.Result, aggregatedReq.Result); + aggregatedReq.Release(); + replacedReq.Release(); + } + + [Fact] + public void ReplaceAggregatedResponse() + { + var ch = new EmbeddedChannel(new HttpObjectAggregator(1024 * 1024)); + + var boom = new Exception("boom"); + var rep = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + rep.Result = DecoderResult.Failure(boom); + + Assert.True(ch.WriteInbound(rep) && ch.Finish()); + + var aggregatedRep = ch.ReadInbound(); + var replacedRep = (IFullHttpResponse)aggregatedRep.Replace(Unpooled.Empty); + + Assert.Equal(replacedRep.Result, aggregatedRep.Result); + aggregatedRep.Release(); + replacedRep.Release(); + } + + static void CheckOversizedRequest(IHttpRequest message) + { + var ch = new EmbeddedChannel(new HttpObjectAggregator(4)); + + Assert.False(ch.WriteInbound(message)); + var response = ch.ReadOutbound(); + Assert.Equal(HttpResponseStatus.RequestEntityTooLarge, response.Status); + Assert.Equal("0", response.Headers.Get(HttpHeaderNames.ContentLength, null)); + + if (ServerShouldCloseConnection(message, response)) + { + Assert.False(ch.Open); + Assert.False(ch.Finish()); + } + else + { + Assert.True(ch.Open); + } + } + + static bool ServerShouldCloseConnection(IHttpRequest message, IHttpResponse response) + { + // If the response wasn't keep-alive, the server should close the connection. + if (!HttpUtil.IsKeepAlive(response)) + { + return true; + } + // The connection should only be kept open if Expect: 100-continue is set, + // or if keep-alive is on. + if (HttpUtil.Is100ContinueExpected(message)) + { + return false; + } + if (HttpUtil.IsKeepAlive(message)) + { + return false; + } + + return true; + } + + static void CheckContentBuffer(IFullHttpRequest aggregatedMessage) + { + var buffer = (CompositeByteBuffer)aggregatedMessage.Content; + Assert.Equal(2, buffer.NumComponents); + IList buffers = buffer.Decompose(0, buffer.Capacity); + Assert.Equal(2, buffers.Count); + foreach (IByteBuffer buf in buffers) + { + // This should be false as we decompose the buffer before to not have deep hierarchy + Assert.False(buf is CompositeByteBuffer); + } + aggregatedMessage.Release(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpRequestDecoderTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpRequestDecoderTest.cs new file mode 100644 index 0000000..693f863 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpRequestDecoderTest.cs @@ -0,0 +1,338 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System.Collections.Generic; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpRequestDecoderTest + { + const int ContentLength = 8; + + static readonly byte[] ContentCrlfDelimiters = CreateContent("\r\n"); + static readonly byte[] ContentLfDelimiters = CreateContent("\n"); + static readonly byte[] ContentMixedDelimiters = CreateContent("\r\n", "\n"); + + static byte[] CreateContent(params string[] lineDelimiters) + { + string lineDelimiter; + string lineDelimiter2; + if (lineDelimiters.Length == 2) + { + lineDelimiter = lineDelimiters[0]; + lineDelimiter2 = lineDelimiters[1]; + } + else + { + lineDelimiter = lineDelimiters[0]; + lineDelimiter2 = lineDelimiters[0]; + } + + string content = "GET /some/path?foo=bar&wibble=eek HTTP/1.1" + "\r\n" + + "Upgrade: WebSocket" + lineDelimiter2 + + "Connection: Upgrade" + lineDelimiter + + "Host: localhost" + lineDelimiter2 + + "Origin: http://localhost:8080" + lineDelimiter + + "Sec-WebSocket-Key1: 10 28 8V7 8 48 0" + lineDelimiter2 + + "Sec-WebSocket-Key2: 8 Xt754O3Q3QW 0 _60" + lineDelimiter + + "Content-Length: " + ContentLength + lineDelimiter2 + + "\r\n" + + "12345678"; + + return Encoding.ASCII.GetBytes(content); + } + + [Fact] + public void DecodeWholeRequestAtOnceCrlfDelimiters() => DecodeWholeRequestAtOnce(ContentCrlfDelimiters); + + [Fact] + public void DecodeWholeRequestAtOnceLfDelimiters() => DecodeWholeRequestAtOnce(ContentLfDelimiters); + + [Fact] + public void DecodeWholeRequestAtOnceMixedDelimiters() => DecodeWholeRequestAtOnce(ContentMixedDelimiters); + + static void DecodeWholeRequestAtOnce(byte[] content) + { + var channel = new EmbeddedChannel(new HttpRequestDecoder()); + Assert.True(channel.WriteInbound(Unpooled.WrappedBuffer(content))); + var req = channel.ReadInbound(); + Assert.NotNull(req); + CheckHeaders(req.Headers); + + var c = channel.ReadInbound(); + Assert.Equal(ContentLength, c.Content.ReadableBytes); + Assert.Equal( + Unpooled.WrappedBuffer(content, content.Length - ContentLength, ContentLength), + c.Content.ReadSlice(ContentLength)); + c.Release(); + + Assert.False(channel.Finish()); + Assert.Null(channel.ReadInbound()); + } + + static void CheckHeaders(HttpHeaders headers) + { + Assert.Equal(7, headers.Names().Count); + CheckHeader(headers, "Upgrade", "WebSocket"); + CheckHeader(headers, "Connection", "Upgrade"); + CheckHeader(headers, "Host", "localhost"); + CheckHeader(headers, "Origin", "http://localhost:8080"); + CheckHeader(headers, "Sec-WebSocket-Key1", "10 28 8V7 8 48 0"); + CheckHeader(headers, "Sec-WebSocket-Key2", "8 Xt754O3Q3QW 0 _60"); + CheckHeader(headers, "Content-Length", $"{ContentLength}"); + } + + static void CheckHeader(HttpHeaders headers, string name, string value) + { + var headerName = (AsciiString)name; + var headerValue = (StringCharSequence)value; + + IList header1 = headers.GetAll(headerName); + Assert.Equal(1, header1.Count); + Assert.Equal(headerValue, header1[0]); + } + + [Fact] + public void DecodeWholeRequestInMultipleStepsCrlfDelimiters() => DecodeWholeRequestInMultipleSteps(ContentCrlfDelimiters); + + [Fact] + public void DecodeWholeRequestInMultipleStepsLFDelimiters() => DecodeWholeRequestInMultipleSteps(ContentLfDelimiters); + + [Fact] + public void DecodeWholeRequestInMultipleStepsMixedDelimiters() => DecodeWholeRequestInMultipleSteps(ContentMixedDelimiters); + + static void DecodeWholeRequestInMultipleSteps(byte[] content) + { + for (int i = 1; i < content.Length; i++) + { + DecodeWholeRequestInMultipleSteps(content, i); + } + } + + static void DecodeWholeRequestInMultipleSteps(byte[] content, int fragmentSize) + { + var channel = new EmbeddedChannel(new HttpRequestDecoder()); + int headerLength = content.Length - ContentLength; + + // split up the header + for (int a = 0; a < headerLength;) + { + int amount = fragmentSize; + if (a + amount > headerLength) + { + amount = headerLength - a; + } + + // if header is done it should produce a HttpRequest + channel.WriteInbound(Unpooled.WrappedBuffer(content, a, amount)); + a += amount; + } + + for (int i = ContentLength; i > 0; i--) + { + // Should produce HttpContent + channel.WriteInbound(Unpooled.WrappedBuffer(content, content.Length - i, 1)); + } + + var req = channel.ReadInbound(); + Assert.NotNull(req); + CheckHeaders(req.Headers); + + for (int i = ContentLength; i > 1; i--) + { + var c = channel.ReadInbound(); + Assert.Equal(1, c.Content.ReadableBytes); + Assert.Equal(content[content.Length - i], c.Content.ReadByte()); + c.Release(); + } + + var last = channel.ReadInbound(); + Assert.Equal(1, last.Content.ReadableBytes); + Assert.Equal(content[content.Length - 1], last.Content.ReadByte()); + last.Release(); + + Assert.False(channel.Finish()); + Assert.Null(channel.ReadInbound()); + } + + [Fact] + public void MultiLineHeader() + { + var channel = new EmbeddedChannel(new HttpRequestDecoder()); + const string Crlf = "\r\n"; + const string Request = "GET /some/path HTTP/1.1" + Crlf + + "Host: localhost" + Crlf + + "MyTestHeader: part1" + Crlf + + " newLinePart2" + Crlf + + "MyTestHeader2: part21" + Crlf + + "\t newLinePart22" + + Crlf + Crlf; + Assert.True(channel.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(Request)))); + var req = channel.ReadInbound(); + Assert.Equal("part1 newLinePart2", req.Headers.Get(new AsciiString("MyTestHeader"), null).ToString()); + Assert.Equal("part21 newLinePart22", req.Headers.Get(new AsciiString("MyTestHeader2"), null).ToString()); + + var c = channel.ReadInbound(); + c.Release(); + + Assert.False(channel.Finish()); + var last = channel.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void EmptyHeaderValue() + { + var channel = new EmbeddedChannel(new HttpRequestDecoder()); + const string Crlf = "\r\n"; + const string Request = "GET /some/path HTTP/1.1" + Crlf + + "Host: localhost" + Crlf + + "EmptyHeader:" + Crlf + Crlf; + byte[] data = Encoding.ASCII.GetBytes(Request); + + channel.WriteInbound(Unpooled.WrappedBuffer(data)); + var req = channel.ReadInbound(); + Assert.Equal("", req.Headers.Get((AsciiString)"EmptyHeader", null).ToString()); + } + + [Fact] + public void Http100Continue() + { + var decoder = new HttpRequestDecoder(); + var channel = new EmbeddedChannel(decoder); + const string Oversized = "PUT /file HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length: 1048576000\r\n\r\n"; + byte[] data = Encoding.ASCII.GetBytes(Oversized); + channel.WriteInbound(Unpooled.CopiedBuffer(data)); + var req = channel.ReadInbound(); + Assert.NotNull(req); + + // At this point, we assume that we sent '413 Entity Too Large' to the peer without closing the connection + // so that the client can try again. + decoder.Reset(); + + const string Query = "GET /max-file-size HTTP/1.1\r\n\r\n"; + data = Encoding.ASCII.GetBytes(Query); + channel.WriteInbound(Unpooled.CopiedBuffer(data)); + + req = channel.ReadInbound(); + Assert.NotNull(req); + + var last = channel.ReadInbound(); + Assert.NotNull(last); + Assert.IsType(last); + + Assert.False(channel.Finish()); + } + + [Fact] + public void Http100ContinueWithBadClient() + { + var decoder = new HttpRequestDecoder(); + var channel = new EmbeddedChannel(decoder); + const string Oversized = + "PUT /file HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length: 1048576000\r\n\r\n" + + "WAY_TOO_LARGE_DATA_BEGINS"; + byte[] data = Encoding.ASCII.GetBytes(Oversized); + channel.WriteInbound(Unpooled.CopiedBuffer(data)); + var req = channel.ReadInbound(); + Assert.NotNull(req); + + var prematureData = channel.ReadInbound(); + prematureData.Release(); + + req = channel.ReadInbound(); + Assert.Null(req); + + // At this point, we assume that we sent '413 Entity Too Large' to the peer without closing the connection + // so that the client can try again. + decoder.Reset(); + + const string Query = "GET /max-file-size HTTP/1.1\r\n\r\n"; + data = Encoding.ASCII.GetBytes(Query); + channel.WriteInbound(Unpooled.CopiedBuffer(data)); + + req = channel.ReadInbound(); + Assert.NotNull(req); + + var last = channel.ReadInbound(); + Assert.NotNull(last); + Assert.IsType(last); + + Assert.False(channel.Finish()); + } + + [Fact] + public void MessagesSplitBetweenMultipleBuffers() + { + var channel = new EmbeddedChannel(new HttpRequestDecoder()); + const string Crlf = "\r\n"; + const string Str1 = "GET /some/path HTTP/1.1" + Crlf + + "Host: localhost1" + Crlf + Crlf + + "GET /some/other/path HTTP/1.0" + Crlf + + "Hos"; + const string Str2 = "t: localhost2" + Crlf + + "content-length: 0" + Crlf + Crlf; + + byte[] data = Encoding.ASCII.GetBytes(Str1); + channel.WriteInbound(Unpooled.CopiedBuffer(data)); + + var req = channel.ReadInbound(); + Assert.Equal(HttpVersion.Http11, req.ProtocolVersion); + Assert.Equal("/some/path", req.Uri); + Assert.Equal(1, req.Headers.Size); + Assert.True(AsciiString.ContentEqualsIgnoreCase((StringCharSequence)"localhost1", req.Headers.Get(HttpHeaderNames.Host, null))); + var cnt = channel.ReadInbound(); + cnt.Release(); + + data = Encoding.ASCII.GetBytes(Str2); + channel.WriteInbound(Unpooled.CopiedBuffer(data)); + req = channel.ReadInbound(); + Assert.Equal(HttpVersion.Http10, req.ProtocolVersion); + Assert.Equal("/some/other/path", req.Uri); + Assert.Equal(2, req.Headers.Size); + Assert.True(AsciiString.ContentEqualsIgnoreCase((StringCharSequence)"localhost2", req.Headers.Get(HttpHeaderNames.Host, null))); + Assert.True(AsciiString.ContentEqualsIgnoreCase((StringCharSequence)"0", req.Headers.Get(HttpHeaderNames.ContentLength, null))); + cnt = channel.ReadInbound(); + cnt.Release(); + + Assert.False(channel.FinishAndReleaseAll()); + } + + [Fact] + public void TooLargeInitialLine() + { + var channel = new EmbeddedChannel(new HttpRequestDecoder(10, 1024, 1024)); + const string RequestStr = "GET /some/path HTTP/1.1\r\n" + + "Host: localhost1\r\n\r\n"; + + Assert.True(channel.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(RequestStr)))); + var request = channel.ReadInbound(); + Assert.True(request.Result.IsFailure); + Assert.IsType(request.Result.Cause); + Assert.False(channel.Finish()); + } + + [Fact] + public void TooLargeHeaders() + { + var channel = new EmbeddedChannel(new HttpRequestDecoder(1024, 10, 1024)); + const string RequestStr = "GET /some/path HTTP/1.1\r\n" + + "Host: localhost1\r\n\r\n"; + + Assert.True(channel.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(RequestStr)))); + var request = channel.ReadInbound(); + Assert.True(request.Result.IsFailure); + Assert.IsType(request.Result.Cause); + Assert.False(channel.Finish()); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpRequestEncoderTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpRequestEncoderTest.cs new file mode 100644 index 0000000..ae0745c --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpRequestEncoderTest.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpRequestEncoderTest + { + [Fact] + public void UriWithoutPath() + { + var encoder = new HttpRequestEncoder(); + IByteBuffer buffer = Unpooled.Buffer(64); + encoder.EncodeInitialLine( + buffer, + new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Get, "http://localhost")); + string req = buffer.ToString(Encoding.ASCII); + Assert.Equal("GET http://localhost/ HTTP/1.1\r\n", req); + } + + [Fact] + public void UriWithoutPath2() + { + var encoder = new HttpRequestEncoder(); + IByteBuffer buffer = Unpooled.Buffer(64); + encoder.EncodeInitialLine( + buffer, + new DefaultHttpRequest( + HttpVersion.Http11, + HttpMethod.Get, + "http://localhost:9999?p1=v1")); + string req = buffer.ToString(Encoding.ASCII); + Assert.Equal("GET http://localhost:9999/?p1=v1 HTTP/1.1\r\n", req); + } + + [Fact] + public void UriWithPath() + { + var encoder = new HttpRequestEncoder(); + IByteBuffer buffer = Unpooled.Buffer(64); + encoder.EncodeInitialLine( + buffer, + new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Get, "http://localhost/")); + string req = buffer.ToString(Encoding.ASCII); + Assert.Equal("GET http://localhost/ HTTP/1.1\r\n", req); + } + + [Fact] + public void AbsPath() + { + var encoder = new HttpRequestEncoder(); + IByteBuffer buffer = Unpooled.Buffer(64); + encoder.EncodeInitialLine( + buffer, + new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/")); + string req = buffer.ToString(Encoding.ASCII); + Assert.Equal("GET / HTTP/1.1\r\n", req); + } + + [Fact] + public void EmptyAbsPath() + { + var encoder = new HttpRequestEncoder(); + IByteBuffer buffer = Unpooled.Buffer(64); + encoder.EncodeInitialLine( + buffer, + new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Get, "")); + string req = buffer.ToString(Encoding.ASCII); + Assert.Equal("GET / HTTP/1.1\r\n", req); + } + + [Fact] + public void QueryStringPath() + { + var encoder = new HttpRequestEncoder(); + IByteBuffer buffer = Unpooled.Buffer(64); + encoder.EncodeInitialLine( + buffer, + new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/?url=http://example.com")); + string req = buffer.ToString(Encoding.ASCII); + Assert.Equal("GET /?url=http://example.com HTTP/1.1\r\n", req); + } + + [Fact] + public void EmptyReleasedBufferShouldNotWriteEmptyBufferToChannel() + { + var encoder = new HttpRequestEncoder(); + var channel = new EmbeddedChannel(encoder); + IByteBuffer buf = Unpooled.Buffer(); + buf.Release(); + var exception = Assert.Throws(() => channel.WriteAndFlushAsync(buf).Wait()); + Assert.Single(exception.InnerExceptions); + Assert.IsType(exception.InnerExceptions[0]); + Assert.IsType(exception.InnerExceptions[0].InnerException); + channel.FinishAndReleaseAll(); + } + + [Fact] + public void EmptydBufferShouldPassThrough() + { + var encoder = new HttpRequestEncoder(); + var channel = new EmbeddedChannel(encoder); + IByteBuffer buffer = Unpooled.Buffer(); + channel.WriteAndFlushAsync(buffer).Wait(); + channel.FinishAndReleaseAll(); + Assert.Equal(0, buffer.ReferenceCount); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpResponseDecoderTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpResponseDecoderTest.cs new file mode 100644 index 0000000..a7bf3c8 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpResponseDecoderTest.cs @@ -0,0 +1,719 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpResponseDecoderTest + { + // The size of headers should be calculated correctly even if a single header is split into multiple fragments. + // see #3445 + [Fact] + public void MaxHeaderSize1() + { + const int MaxHeaderSize = 8192; + + var ch = new EmbeddedChannel(new HttpResponseDecoder(4096, MaxHeaderSize, 8192)); + var bytes = new byte[MaxHeaderSize / 2 - 2]; + bytes.Fill((byte)'a'); + + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\n"))); + + // Write two 4096-byte headers (= 8192 bytes) + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("A:"))); + ch.WriteInbound(Unpooled.CopiedBuffer(bytes)); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("\r\n"))); + Assert.Null(ch.ReadInbound()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("B:"))); + ch.WriteInbound(Unpooled.CopiedBuffer(bytes)); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("\r\n"))); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("\r\n"))); + + var res = ch.ReadInbound(); + Assert.Null(res.Result.Cause); + Assert.True(res.Result.IsSuccess); + + Assert.Null(ch.ReadInbound()); + Assert.True(ch.Finish()); + + var last = ch.ReadInbound(); + Assert.NotNull(last); + } + + // Complementary test case of {@link #testMaxHeaderSize1()} When it actually exceeds the maximum, it should fail. + [Fact] + public void MaxHeaderSize2() + { + const int MaxHeaderSize = 8192; + + var ch = new EmbeddedChannel(new HttpResponseDecoder(4096, MaxHeaderSize, 8192)); + var bytes = new byte[MaxHeaderSize / 2 - 2]; + bytes.Fill((byte)'a'); + + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\n"))); + + // Write a 4096-byte header and a 4097-byte header to test an off-by-one case (= 8193 bytes) + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("A:"))); + ch.WriteInbound(Unpooled.CopiedBuffer(bytes)); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("\r\n"))); + Assert.Null(ch.ReadInbound()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("B: "))); // Note an extra space. + ch.WriteInbound(Unpooled.CopiedBuffer(bytes)); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("\r\n"))); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("\r\n"))); + + var res = ch.ReadInbound(); + Assert.True(res.Result.Cause is TooLongFrameException); + + Assert.False(ch.Finish()); + + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void ResponseChunked() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"))); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + + var data = new byte[64]; + for (int i = 0; i < data.Length; i++) + { + data[i] = (byte)i; + } + + for (int i = 0; i < 10; i++) + { + Assert.False(ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes($"{Convert.ToString(data.Length, 16)}\r\n")))); + Assert.True(ch.WriteInbound(Unpooled.WrappedBuffer(data))); + var content = ch.ReadInbound(); + Assert.Equal(data.Length, content.Content.ReadableBytes); + + var decodedData = new byte[data.Length]; + content.Content.ReadBytes(decodedData); + Assert.True(data.SequenceEqual(decodedData)); + content.Release(); + + Assert.False(ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("\r\n")))); + } + + // Write the last chunk. + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("0\r\n\r\n"))); + + // Ensure the last chunk was decoded. + var lastContent = ch.ReadInbound(); + Assert.False(lastContent.Content.IsReadable()); + lastContent.Release(); + + ch.Finish(); + + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void ResponseChunkedExceedMaxChunkSize() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder(4096, 8192, 32)); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"))); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + + var data = new byte[64]; + for (int i = 0; i < data.Length; i++) + { + data[i] = (byte)i; + } + + for (int i = 0; i < 10; i++) + { + Assert.False(ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes($"{Convert.ToString(data.Length, 16)}\r\n")))); + Assert.True(ch.WriteInbound(Unpooled.WrappedBuffer(data))); + + var decodedData = new byte[data.Length]; + var content = ch.ReadInbound(); + Assert.Equal(32, content.Content.ReadableBytes); + content.Content.ReadBytes(decodedData, 0, 32); + content.Release(); + + content = ch.ReadInbound(); + Assert.NotNull(content); + Assert.Equal(32, content.Content.ReadableBytes); + + content.Content.ReadBytes(decodedData, 32, 32); + + Assert.True(decodedData.SequenceEqual(data)); + content.Release(); + + Assert.False(ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("\r\n")))); + } + + // Write the last chunk. + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("0\r\n\r\n"))); + + // Ensure the last chunk was decoded. + var lastContent = ch.ReadInbound(); + Assert.False(lastContent.Content.IsReadable()); + lastContent.Release(); + + ch.Finish(); + + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void ClosureWithoutContentLength1() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\n\r\n"))); + + // Read the response headers. + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + + res = ch.ReadInbound(); + Assert.Null(res); + + // Close the connection without sending anything. + Assert.True(ch.Finish()); + + // The decoder should still produce the last content. + var content = ch.ReadInbound(); + Assert.NotNull(content); + Assert.False(content.Content.IsReadable()); + content.Release(); + + // But nothing more. + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void ClosureWithoutContentLength2() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + + // Write the partial response. + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\n\r\n12345678"))); + + // Read the response headers. + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + + // Read the partial content. + var content = ch.ReadInbound(); + Assert.Equal("12345678", content.Content.ToString(Encoding.ASCII)); + Assert.Null(content as ILastHttpContent); + content.Release(); + + res = ch.ReadInbound(); + Assert.Null(res); + + // Close the connection. + Assert.True(ch.Finish()); + + // The decoder should still produce the last content. + var lastContent = ch.ReadInbound(); + Assert.NotNull(lastContent); + Assert.False(lastContent.Content.IsReadable()); + lastContent.Release(); + + // But nothing more. + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void PrematureClosureWithChunkedEncoding1() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"))); + + // Read the response headers. + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + Assert.Equal("chunked", res.Headers.Get(HttpHeaderNames.TransferEncoding, null).ToString()); + res = ch.ReadInbound(); + Assert.Null(res); + + // Close the connection without sending anything. + ch.Finish(); + + // The decoder should not generate the last chunk because it's closed prematurely. + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void PrematureClosureWithChunkedEncoding2() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + + // Write the partial response. + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n8\r\n12345678"))); + + // Read the response headers. + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + Assert.Equal("chunked", res.Headers.Get(HttpHeaderNames.TransferEncoding, null).ToString()); + + // Read the partial content. + var content = ch.ReadInbound(); + Assert.Equal("12345678", content.Content.ToString(Encoding.ASCII)); + Assert.Null(content as ILastHttpContent); + content.Release(); + + content = ch.ReadInbound(); + Assert.Null(content); + + // Close the connection. + ch.Finish(); + + // The decoder should not generate the last chunk because it's closed prematurely. + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void LastResponseWithEmptyHeaderAndEmptyContent() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\n\r\n"))); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + + var content = ch.ReadInbound(); + Assert.Null(content); + + Assert.True(ch.Finish()); + + var lastContent = ch.ReadInbound(); + Assert.False(lastContent.Content.IsReadable()); + lastContent.Release(); + + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void LastResponseWithoutContentLengthHeader() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\n\r\n"))); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + var content = ch.ReadInbound(); + Assert.Null(content); + + ch.WriteInbound(Unpooled.WrappedBuffer(new byte[1024])); + content = ch.ReadInbound(); + Assert.Equal(1024, content.Content.ReadableBytes); + content.Release(); + + Assert.True(ch.Finish()); + + var lastContent = ch.ReadInbound(); + Assert.False(lastContent.Content.IsReadable()); + lastContent.Release(); + + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void LastResponseWithHeaderRemoveTrailingSpaces() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\nX-Header: h2=h2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT \r\n\r\n"))); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + Assert.Equal("h2=h2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT", res.Headers.Get((AsciiString)"X-Header", null).ToString()); + var content = ch.ReadInbound(); + Assert.Null(content); + + ch.WriteInbound(Unpooled.WrappedBuffer(new byte[1024])); + content = ch.ReadInbound(); + Assert.Equal(1024, content.Content.ReadableBytes); + content.Release(); + + Assert.True(ch.Finish()); + + var lastContent = ch.ReadInbound(); + Assert.False(lastContent.Content.IsReadable()); + lastContent.Release(); + + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void ResetContentResponseWithTransferEncoding() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + Assert.True(ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes( + "HTTP/1.1 205 Reset Content\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0\r\n" + + "\r\n")))); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.ResetContent, res.Status); + + var lastContent = ch.ReadInbound(); + Assert.False(lastContent.Content.IsReadable()); + lastContent.Release(); + + Assert.False(ch.Finish()); + } + + [Fact] + public void LastResponseWithTrailingHeader() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes( + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0\r\n" + + "Set-Cookie: t1=t1v1\r\n" + + "Set-Cookie: t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT\r\n" + + "\r\n"))); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + + var lastContent = ch.ReadInbound(); + Assert.False(lastContent.Content.IsReadable()); + HttpHeaders headers = lastContent.TrailingHeaders; + Assert.Equal(1, headers.Names().Count); + IList values = headers.GetAll((AsciiString)"Set-Cookie"); + Assert.Equal(2, values.Count); + Assert.True(values.Contains((AsciiString)"t1=t1v1")); + Assert.True(values.Contains((AsciiString)"t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT")); + lastContent.Release(); + + Assert.False(ch.Finish()); + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void LastResponseWithTrailingHeaderFragmented() + { + byte[] data = Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0\r\n" + + "Set-Cookie: t1=t1v1\r\n" + + "Set-Cookie: t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT\r\n" + + "\r\n"); + + for (int i = 1; i < data.Length; i++) + { + LastResponseWithTrailingHeaderFragmented0(data, i); + } + } + + static void LastResponseWithTrailingHeaderFragmented0(byte[] content, int fragmentSize) + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + const int HeaderLength = 47; + + // split up the header + for (int a = 0; a < HeaderLength;) + { + int amount = fragmentSize; + if (a + amount > HeaderLength) + { + amount = HeaderLength - a; + } + + // if header is done it should produce a HttpRequest + bool headerDone = a + amount == HeaderLength; + Assert.Equal(headerDone, ch.WriteInbound(Unpooled.WrappedBuffer(content, a, amount))); + a += amount; + } + + ch.WriteInbound(Unpooled.WrappedBuffer(content, HeaderLength, content.Length - HeaderLength)); + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + + var lastContent = ch.ReadInbound(); + Assert.False(lastContent.Content.IsReadable()); + + HttpHeaders headers = lastContent.TrailingHeaders; + Assert.Equal(1, headers.Names().Count); + IList values = headers.GetAll((AsciiString)"Set-Cookie"); + Assert.Equal(2, values.Count); + Assert.True(values.Contains((AsciiString)"t1=t1v1")); + Assert.True(values.Contains((AsciiString)"t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT")); + lastContent.Release(); + + Assert.False(ch.Finish()); + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void ResponseWithContentLength() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes( + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 10\r\n" + + "\r\n"))); + + var data = new byte[10]; + for (int i = 0; i < data.Length; i++) + { + data[i] = (byte)i; + } + ch.WriteInbound(Unpooled.WrappedBuffer(data, 0, data.Length / 2)); + ch.WriteInbound(Unpooled.WrappedBuffer(data, 5, data.Length / 2)); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + + var firstContent = ch.ReadInbound(); + Assert.Equal(5, firstContent.Content.ReadableBytes); + Assert.Equal(Unpooled.WrappedBuffer(data, 0, 5), firstContent.Content); + firstContent.Release(); + + var lastContent = ch.ReadInbound(); + Assert.Equal(5, lastContent.Content.ReadableBytes); + Assert.Equal(Unpooled.WrappedBuffer(data, 5, 5), lastContent.Content); + lastContent.Release(); + + Assert.False(ch.Finish()); + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void ResponseWithContentLengthFragmented() + { + byte[] data = Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\n" + + "Content-Length: 10\r\n" + + "\r\n"); + + for (int i = 1; i < data.Length; i++) + { + ResponseWithContentLengthFragmented0(data, i); + } + } + + static void ResponseWithContentLengthFragmented0(byte[] header, int fragmentSize) + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + // split up the header + for (int a = 0; a < header.Length;) + { + int amount = fragmentSize; + if (a + amount > header.Length) + { + amount = header.Length - a; + } + + ch.WriteInbound(Unpooled.WrappedBuffer(header, a, amount)); + a += amount; + } + var data = new byte[10]; + for (int i = 0; i < data.Length; i++) + { + data[i] = (byte)i; + } + ch.WriteInbound(Unpooled.WrappedBuffer(data, 0, data.Length / 2)); + ch.WriteInbound(Unpooled.WrappedBuffer(data, 5, data.Length / 2)); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.OK, res.Status); + + var firstContent = ch.ReadInbound(); + Assert.Equal(5, firstContent.Content.ReadableBytes); + Assert.Equal(Unpooled.WrappedBuffer(data, 0, 5), firstContent.Content); + firstContent.Release(); + + var lastContent = ch.ReadInbound(); + Assert.Equal(5, lastContent.Content.ReadableBytes); + Assert.Equal(Unpooled.WrappedBuffer(data, 5, 5), lastContent.Content); + lastContent.Release(); + + Assert.False(ch.Finish()); + var last = ch.ReadInbound(); + Assert.Null(last); + } + + [Fact] + public void WebSocketResponse() + { + byte[] data = Encoding.ASCII.GetBytes("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + + "Upgrade: WebSocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Origin: http://localhost:8080\r\n" + + "Sec-WebSocket-Location: ws://localhost/some/path\r\n" + + "\r\n" + + "1234567812345678"); + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.WriteInbound(Unpooled.WrappedBuffer(data)); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.SwitchingProtocols, res.Status); + var content = ch.ReadInbound(); + Assert.Equal(16, content.Content.ReadableBytes); + content.Release(); + + Assert.False(ch.Finish()); + var last = ch.ReadInbound(); + Assert.Null(last); + } + + // See https://github.com/netty/netty/issues/2173 + [Fact] + public void WebSocketResponseWithDataFollowing() + { + byte[] data = Encoding.ASCII.GetBytes("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + + "Upgrade: WebSocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Origin: http://localhost:8080\r\n" + + "Sec-WebSocket-Location: ws://localhost/some/path\r\n" + + "\r\n" + + "1234567812345678"); + byte[] otherData = { 1, 2, 3, 4 }; + + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + IByteBuffer compositeBuffer = Unpooled.WrappedBuffer(data, otherData); + ch.WriteInbound(compositeBuffer); + + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http11, res.ProtocolVersion); + Assert.Equal(HttpResponseStatus.SwitchingProtocols, res.Status); + var content = ch.ReadInbound(); + Assert.Equal(16, content.Content.ReadableBytes); + content.Release(); + + Assert.True(ch.Finish()); + + IByteBuffer expected = Unpooled.WrappedBuffer(otherData); + var buffer = ch.ReadInbound(); + try + { + Assert.Equal(expected, buffer); + } + finally + { + expected.Release(); + buffer?.Release(); + } + } + + [Fact] + public void GarbageHeaders() + { + // A response without headers - from https://github.com/netty/netty/issues/2103 + byte[] data = Encoding.ASCII.GetBytes("\r\n" + + "400 Bad Request\r\n" + + "\r\n" + + "

400 Bad Request

\r\n" + + "
nginx/1.1.19
\r\n" + + "\r\n" + + "\r\n"); + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + + ch.WriteInbound(Unpooled.WrappedBuffer(data)); + // Garbage input should generate the 999 Unknown response. + var res = ch.ReadInbound(); + Assert.Same(HttpVersion.Http10, res.ProtocolVersion); + Assert.Equal(999, res.Status.Code); + Assert.True(res.Result.IsFailure); + Assert.True(res.Result.IsFinished); + + var next = ch.ReadInbound(); + Assert.Null(next); + + // More garbage should not generate anything (i.e. the decoder discards anything beyond this point.) + ch.WriteInbound(Unpooled.WrappedBuffer(data)); + next = ch.ReadInbound(); + Assert.Null(next); + + // Closing the connection should not generate anything since the protocol has been violated. + ch.Finish(); + next = ch.ReadInbound(); + Assert.Null(next); + } + + // Tests if the decoder produces one and only {@link LastHttpContent} when an invalid chunk is received and + // the connection is closed. + [Fact] + public void GarbageChunk() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + const string ResponseWithIllegalChunk = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "NOT_A_CHUNK_LENGTH\r\n"; + + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(ResponseWithIllegalChunk))); + var res = ch.ReadInbound(); + Assert.NotNull(res); + + // Ensure that the decoder generates the last chunk with correct decoder result. + var invalidChunk = ch.ReadInbound(); + Assert.True(invalidChunk.Result.IsFailure); + invalidChunk.Release(); + + // And no more messages should be produced by the decoder. + var next = ch.ReadInbound(); + Assert.Null(next); + + // .. even after the connection is closed. + Assert.False(ch.Finish()); + } + + [Fact] + public void ConnectionClosedBeforeHeadersReceived() + { + var ch = new EmbeddedChannel(new HttpResponseDecoder()); + const string ResponseInitialLine = "HTTP/1.1 200 OK\r\n"; + Assert.False(ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(ResponseInitialLine)))); + Assert.True(ch.Finish()); + var message = ch.ReadInbound(); + Assert.True(message.Result.IsFailure); + Assert.IsType(message.Result.Cause); + + var last = ch.ReadInbound(); + Assert.Null(last); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpResponseEncoderTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpResponseEncoderTest.cs new file mode 100644 index 0000000..4196bc3 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpResponseEncoderTest.cs @@ -0,0 +1,153 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Common; + using DotNetty.Transport.Channels; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpResponseEncoderTest + { + const long IntegerOverflow = (long)int.MaxValue + 1; + static readonly IFileRegion FileRegion = new DummyLongFileRegion(); + + [Fact] + public void LargeFileRegionChunked() + { + var channel = new EmbeddedChannel(new HttpResponseEncoder()); + IHttpResponse response = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + response.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + + Assert.True(channel.WriteOutbound(response)); + + var buffer = channel.ReadOutbound(); + Assert.Equal( + "HTTP/1.1 200 OK\r\n" + HttpHeaderNames.TransferEncoding + ": " + + HttpHeaderValues.Chunked + "\r\n\r\n", + buffer.ToString(Encoding.ASCII)); + buffer.Release(); + + Assert.True(channel.WriteOutbound(FileRegion)); + buffer = channel.ReadOutbound(); + Assert.Equal("80000000\r\n", buffer.ToString(Encoding.ASCII)); + buffer.Release(); + + var region = channel.ReadOutbound(); + Assert.Same(FileRegion, region); + region.Release(); + buffer = channel.ReadOutbound(); + Assert.Equal("\r\n", buffer.ToString(Encoding.ASCII)); + buffer.Release(); + + Assert.True(channel.WriteOutbound(EmptyLastHttpContent.Default)); + buffer = channel.ReadOutbound(); + Assert.Equal("0\r\n\r\n", buffer.ToString(Encoding.ASCII)); + buffer.Release(); + + Assert.False(channel.Finish()); + } + + class DummyLongFileRegion : IFileRegion + { + public int ReferenceCount => 1; + + public IReferenceCounted Retain() => this; + + public IReferenceCounted Retain(int increment) => this; + + public IReferenceCounted Touch() => this; + + public IReferenceCounted Touch(object hint) => this; + + public bool Release() => false; + + public bool Release(int decrement) => false; + + public long Position => 0; + + public long Transferred => 0; + + public long Count => IntegerOverflow; + + public long TransferTo(Stream target, long position) + { + throw new NotSupportedException(); + } + } + + [Fact] + public void EmptyBufferBypass() + { + var channel = new EmbeddedChannel(new HttpResponseEncoder()); + + // Test writing an empty buffer works when the encoder is at ST_INIT. + channel.WriteOutbound(Unpooled.Empty); + var buffer = channel.ReadOutbound(); + Assert.Same(buffer, Unpooled.Empty); + + // Leave the ST_INIT state. + IHttpResponse response = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + Assert.True(channel.WriteOutbound(response)); + buffer = channel.ReadOutbound(); + Assert.Equal("HTTP/1.1 200 OK\r\n\r\n", buffer.ToString(Encoding.ASCII)); + buffer.Release(); + + // Test writing an empty buffer works when the encoder is not at ST_INIT. + channel.WriteOutbound(Unpooled.Empty); + buffer = channel.ReadOutbound(); + Assert.Same(buffer, Unpooled.Empty); + + Assert.False(channel.Finish()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void EmptyContent(bool chunked) + { + const string Content = "netty rocks"; + IByteBuffer contentBuffer = Unpooled.CopiedBuffer(Encoding.ASCII.GetBytes(Content)); + int length = contentBuffer.ReadableBytes; + + var channel = new EmbeddedChannel(new HttpResponseEncoder()); + IHttpResponse response = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + if (!chunked) + { + HttpUtil.SetContentLength(response, length); + } + Assert.True(channel.WriteOutbound(response)); + Assert.True(channel.WriteOutbound(new DefaultHttpContent(Unpooled.Empty))); + Assert.True(channel.WriteOutbound(new DefaultLastHttpContent(contentBuffer))); + + var buffer = channel.ReadOutbound(); + if (!chunked) + { + Assert.Equal( + "HTTP/1.1 200 OK\r\ncontent-length: " + length + "\r\n\r\n", + buffer.ToString(Encoding.ASCII)); + } + else + { + Assert.Equal("HTTP/1.1 200 OK\r\n\r\n", buffer.ToString(Encoding.ASCII)); + } + buffer.Release(); + + // Test writing an empty buffer works when the encoder is not at ST_INIT. + buffer = channel.ReadOutbound(); + Assert.Equal(0, buffer.ReadableBytes); + buffer.Release(); + + buffer = channel.ReadOutbound(); + Assert.Equal(length, buffer.ReadableBytes); + buffer.Release(); + + Assert.False(channel.Finish()); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpResponseStatusTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpResponseStatusTest.cs new file mode 100644 index 0000000..011d5e7 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpResponseStatusTest.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using DotNetty.Common.Utilities; + using Xunit; + + public sealed class HttpResponseStatusTest + { + [Fact] + public void ParseLineStringJustCode() + { + Assert.Same(HttpResponseStatus.OK, HttpResponseStatus.ParseLine("200")); + } + + [Fact] + public void ParseLineStringCodeAndPhrase() + { + Assert.Same(HttpResponseStatus.OK, HttpResponseStatus.ParseLine("200 OK")); + } + + [Fact] + public void ParseLineStringCustomCode() + { + HttpResponseStatus customStatus = HttpResponseStatus.ParseLine("612"); + Assert.Equal(612, customStatus.Code); + } + + [Fact] + public void ParseLineStringCustomCodeAndPhrase() + { + HttpResponseStatus customStatus = HttpResponseStatus.ParseLine("612 FOO"); + Assert.Equal(612, customStatus.Code); + Assert.Equal(new AsciiString("FOO"), customStatus.ReasonPhrase); + } + + [Fact] + public void ParseLineStringMalformedCode() + { + Assert.Throws(() => HttpResponseStatus.ParseLine("200a")); + } + + [Fact] + public void ParseLineStringMalformedCodeWithPhrase() + { + Assert.Throws(() => HttpResponseStatus.ParseLine("200a foo")); + } + + [Fact] + public void ParseLineAsciiStringJustCode() + { + Assert.Same(HttpResponseStatus.OK, HttpResponseStatus.ParseLine(new AsciiString("200"))); + } + + [Fact] + public void ParseLineAsciiStringCodeAndPhrase() + { + Assert.Same(HttpResponseStatus.OK, HttpResponseStatus.ParseLine(new AsciiString("200 OK"))); + } + + [Fact] + public void ParseLineAsciiStringCustomCode() + { + HttpResponseStatus customStatus = HttpResponseStatus.ParseLine(new AsciiString("612")); + Assert.Equal(612, customStatus.Code); + } + + [Fact] + public void ParseLineAsciiStringCustomCodeAndPhrase() + { + HttpResponseStatus customStatus = HttpResponseStatus.ParseLine(new AsciiString("612 FOO")); + Assert.Equal(612, customStatus.Code); + Assert.Equal("FOO", customStatus.ReasonPhrase); + } + + [Fact] + public void ParseLineAsciiStringMalformedCode() + { + Assert.Throws(() => HttpResponseStatus.ParseLine(new AsciiString("200a"))); + } + + [Fact] + public void ParseLineAsciiStringMalformedCodeWithPhrase() + { + Assert.Throws(() => HttpResponseStatus.ParseLine(new AsciiString("200a foo"))); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpServerCodecTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpServerCodecTest.cs new file mode 100644 index 0000000..22d0fa5 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpServerCodecTest.cs @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System.Text; + using DotNetty.Buffers; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpServerCodecTest + { + // Testcase for https://github.com/netty/netty/issues/433 + [Fact] + public void UnfinishedChunkedHttpRequestIsLastFlag() + { + const int MaxChunkSize = 2000; + var httpServerCodec = new HttpServerCodec(1000, 1000, MaxChunkSize); + var ch = new EmbeddedChannel(httpServerCodec); + + int totalContentLength = MaxChunkSize * 5; + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes( + "PUT /test HTTP/1.1\r\n" + + "Content-Length: " + totalContentLength + "\r\n" + + "\r\n"))); + + int offeredContentLength = (int)(MaxChunkSize * 2.5); + ch.WriteInbound(PrepareDataChunk(offeredContentLength)); + ch.Finish(); + + var httpMessage = ch.ReadInbound(); + Assert.NotNull(httpMessage); + + bool empty = true; + int totalBytesPolled = 0; + for (;;) + { + var httpChunk = ch.ReadInbound(); + if (httpChunk == null) + { + break; + } + empty = false; + totalBytesPolled += httpChunk.Content.ReadableBytes; + Assert.False(httpChunk is ILastHttpContent); + httpChunk.Release(); + } + + Assert.False(empty); + Assert.Equal(offeredContentLength, totalBytesPolled); + } + + [Fact] + public void Code100Continue() + { + var ch = new EmbeddedChannel(new HttpServerCodec(), new HttpObjectAggregator(1024)); + + // Send the request headers. + ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes( + "PUT /upload-large HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length: 1\r\n\r\n"))); + + // Ensure the aggregator generates nothing. + var next = ch.ReadInbound(); + Assert.Null(next); + + // Ensure the aggregator writes a 100 Continue response. + var continueResponse = ch.ReadOutbound(); + Assert.Equal("HTTP/1.1 100 Continue\r\n\r\n", continueResponse.ToString(Encoding.UTF8)); + continueResponse.Release(); + + // But nothing more. + next = ch.ReadInbound(); + Assert.Null(next); + + // Send the content of the request. + ch.WriteInbound(Unpooled.WrappedBuffer(new byte[] { 42 })); + + // Ensure the aggregator generates a full request. + var req = ch.ReadInbound(); + Assert.Equal("1", req.Headers.Get(HttpHeaderNames.ContentLength, null).ToString()); + Assert.Equal(1, req.Content.ReadableBytes); + Assert.Equal((byte)42, req.Content.ReadByte()); + req.Release(); + + // But nothing more. + next = ch.ReadInbound(); + Assert.Null(next); + + // Send the actual response. + var res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.Created); + res.Content.WriteBytes(Encoding.UTF8.GetBytes("OK")); + res.Headers.SetInt(HttpHeaderNames.ContentLength, 2); + ch.WriteOutbound(res); + + // Ensure the encoder handles the response after handling 100 Continue. + var encodedRes = ch.ReadOutbound(); + Assert.Equal("HTTP/1.1 201 Created\r\n" + HttpHeaderNames.ContentLength + ": 2\r\n\r\nOK", encodedRes.ToString(Encoding.UTF8)); + encodedRes.Release(); + + ch.Finish(); + } + + [Fact] + public void ChunkedHeadResponse() + { + var ch = new EmbeddedChannel(new HttpServerCodec()); + + // Send the request headers. + Assert.True(ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes( + "HEAD / HTTP/1.1\r\n\r\n")))); + + var request = ch.ReadInbound(); + Assert.Equal(HttpMethod.Head, request.Method); + var content = ch.ReadInbound(); + Assert.False(content.Content.IsReadable()); + content.Release(); + + var response = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + HttpUtil.SetTransferEncodingChunked(response, true); + Assert.True(ch.WriteOutbound(response)); + Assert.True(ch.WriteOutbound(EmptyLastHttpContent.Default)); + Assert.True(ch.Finish()); + + var buf = ch.ReadOutbound(); + Assert.Equal("HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\n\r\n", buf.ToString(Encoding.ASCII)); + buf.Release(); + + buf = ch.ReadOutbound(); + Assert.False(buf.IsReadable()); + buf.Release(); + + Assert.False(ch.FinishAndReleaseAll()); + } + + [Fact] + public void ChunkedHeadFullHttpResponse() + { + var ch = new EmbeddedChannel(new HttpServerCodec()); + + // Send the request headers. + Assert.True(ch.WriteInbound(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes( + "HEAD / HTTP/1.1\r\n\r\n")))); + + var request = ch.ReadInbound(); + Assert.Equal(HttpMethod.Head, request.Method); + var content = ch.ReadInbound(); + Assert.False(content.Content.IsReadable()); + content.Release(); + + var response = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + HttpUtil.SetTransferEncodingChunked(response, true); + Assert.True(ch.WriteOutbound(response)); + Assert.True(ch.Finish()); + + var buf = ch.ReadOutbound(); + Assert.Equal("HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\n\r\n", buf.ToString(Encoding.ASCII)); + buf.Release(); + + Assert.False(ch.FinishAndReleaseAll()); + } + + static IByteBuffer PrepareDataChunk(int size) + { + var sb = new StringBuilder(); + for (int i = 0; i < size; ++i) + { + sb.Append('a'); + } + + return Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes(sb.ToString())); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpServerExpectContinueHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpServerExpectContinueHandlerTest.cs new file mode 100644 index 0000000..896e029 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpServerExpectContinueHandlerTest.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + public sealed class HttpServerExpectContinueHandlerTest + { + sealed class ContinueHandler : HttpServerExpectContinueHandler + { + protected override IHttpResponse AcceptMessage(IHttpRequest request) + { + var response = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.Continue); + response.Headers.Set((AsciiString)"foo", (AsciiString)"bar"); + return response; + } + } + + [Fact] + public void ShouldRespondToExpectedHeader() + { + var channel = new EmbeddedChannel(new ContinueHandler()); + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/"); + HttpUtil.Set100ContinueExpected(request, true); + + channel.WriteInbound(request); + var response = channel.ReadOutbound(); + + Assert.Equal(HttpResponseStatus.Continue, response.Status); + Assert.Equal((AsciiString)"bar", response.Headers.Get((AsciiString)"foo", null)); + ReferenceCountUtil.Release(response); + + var processedRequest = channel.ReadInbound(); + Assert.NotNull(processedRequest); + Assert.False(processedRequest.Headers.Contains(HttpHeaderNames.Expect)); + ReferenceCountUtil.Release(processedRequest); + Assert.False(channel.FinishAndReleaseAll()); + } + + sealed class CustomHandler : HttpServerExpectContinueHandler + { + protected override IHttpResponse AcceptMessage(IHttpRequest request) => null; + + protected override IHttpResponse RejectResponse(IHttpRequest request) => + new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.RequestEntityTooLarge); + } + + [Fact] + public void ShouldAllowCustomResponses() + { + var channel = new EmbeddedChannel(new CustomHandler()); + + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Get, "/"); + HttpUtil.Set100ContinueExpected(request, true); + + channel.WriteInbound(request); + var response = channel.ReadOutbound(); + + Assert.Equal(HttpResponseStatus.RequestEntityTooLarge, response.Status); + ReferenceCountUtil.Release(response); + + // request was swallowed + Assert.Empty(channel.InboundMessages); + Assert.False(channel.FinishAndReleaseAll()); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpServerKeepAliveHandlerTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpServerKeepAliveHandlerTest.cs new file mode 100644 index 0000000..9dd004d --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpServerKeepAliveHandlerTest.cs @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Collections.Generic; + using DotNetty.Common.Utilities; + using DotNetty.Transport.Channels.Embedded; + using Xunit; + + using static HttpResponseStatus; + + public sealed class HttpServerKeepAliveHandlerTest + { + const string RequestKeepAlive = "REQUEST_KEEP_ALIVE"; + const int NotSelfDefinedMsgLength = 0; + const int SetResponseLength = 1; + const int SetMultipart = 2; + const int SetChunked = 4; + + public static IEnumerable GetKeepAliveCases() => new[] + { + new object[] { true, HttpVersion.Http10, OK, RequestKeepAlive, SetResponseLength, HttpHeaderValues.KeepAlive }, // 0 + new object[] { true, HttpVersion.Http10, OK, RequestKeepAlive, SetMultipart, HttpHeaderValues.KeepAlive }, // 1 + new object[] { false, HttpVersion.Http10, OK, null, SetResponseLength, null }, // 2 + new object[] { true, HttpVersion.Http11, OK, RequestKeepAlive, SetResponseLength, null }, // 3 + new object[] { false, HttpVersion.Http11, OK, RequestKeepAlive, SetResponseLength, HttpHeaderValues.Close }, // 4 + new object[] { true, HttpVersion.Http11, OK, RequestKeepAlive, SetMultipart, null }, // 5 + new object[] { true, HttpVersion.Http11, OK, RequestKeepAlive, SetChunked, null }, // 6 + new object[] { false, HttpVersion.Http11, OK, null, SetResponseLength, null }, // 7 + new object[] { false, HttpVersion.Http10, OK, RequestKeepAlive, NotSelfDefinedMsgLength, null }, // 8 + new object[] { false, HttpVersion.Http10, OK, null, NotSelfDefinedMsgLength, null }, // 9 + new object[] { false, HttpVersion.Http11, OK, RequestKeepAlive, NotSelfDefinedMsgLength, null }, // 10 + new object[] { false, HttpVersion.Http11, OK, null, NotSelfDefinedMsgLength, null }, // 11 + new object[] { false, HttpVersion.Http10, OK, RequestKeepAlive, SetResponseLength, null }, // 12 + new object[] { true, HttpVersion.Http11, NoContent, RequestKeepAlive, NotSelfDefinedMsgLength, null}, // 13 + new object[] { false, HttpVersion.Http10, NoContent, null, NotSelfDefinedMsgLength, null} // 14 + }; + + [Theory] + [MemberData(nameof(GetKeepAliveCases))] + public void KeepAlive(bool isKeepAliveResponseExpected, HttpVersion httpVersion, HttpResponseStatus responseStatus, string sendKeepAlive, int setSelfDefinedMessageLength, ICharSequence setResponseConnection) + { + var channel = new EmbeddedChannel(new HttpServerKeepAliveHandler()); + var request = new DefaultFullHttpRequest(httpVersion, HttpMethod.Get, "/v1/foo/bar"); + HttpUtil.SetKeepAlive(request, RequestKeepAlive.Equals(sendKeepAlive)); + var response = new DefaultFullHttpResponse(httpVersion, responseStatus); + if (!CharUtil.IsNullOrEmpty(setResponseConnection)) + { + response.Headers.Set(HttpHeaderNames.Connection, setResponseConnection); + } + SetupMessageLength(setSelfDefinedMessageLength, response); + + Assert.True(channel.WriteInbound(request)); + var requestForwarded = channel.ReadInbound(); + Assert.Equal(request, requestForwarded); + ReferenceCountUtil.Release(requestForwarded); + channel.WriteAndFlushAsync(response).Wait(TimeSpan.FromSeconds(1)); + var writtenResponse = channel.ReadOutbound(); + + Assert.Equal(isKeepAliveResponseExpected, channel.Open); + Assert.Equal(isKeepAliveResponseExpected, HttpUtil.IsKeepAlive(writtenResponse)); + ReferenceCountUtil.Release(writtenResponse); + Assert.False(channel.FinishAndReleaseAll()); + } + + [Theory] + [MemberData(nameof(GetKeepAliveCases))] +#pragma warning disable xUnit1026 // Theory methods should use all of their parameters + public void ConnectionCloseHeaderHandledCorrectly(bool isKeepAliveResponseExpected, HttpVersion httpVersion, HttpResponseStatus responseStatus, string sendKeepAlive, int setSelfDefinedMessageLength, ICharSequence setResponseConnection) +#pragma warning restore xUnit1026 // Theory methods should use all of their parameters + { + var channel = new EmbeddedChannel(new HttpServerKeepAliveHandler()); + var response = new DefaultFullHttpResponse(httpVersion, responseStatus); + response.Headers.Set(HttpHeaderNames.Connection, HttpHeaderValues.Close); + SetupMessageLength(setSelfDefinedMessageLength, response); + + channel.WriteAndFlushAsync(response).Wait(TimeSpan.FromSeconds(1)); + var writtenResponse = channel.ReadOutbound(); + + Assert.False(channel.Open); + ReferenceCountUtil.Release(writtenResponse); + Assert.False(channel.FinishAndReleaseAll()); + } + + [Theory] + [MemberData(nameof(GetKeepAliveCases))] + public void PipelineKeepAlive(bool isKeepAliveResponseExpected, HttpVersion httpVersion, HttpResponseStatus responseStatus, string sendKeepAlive, int setSelfDefinedMessageLength, ICharSequence setResponseConnection) + { + var channel = new EmbeddedChannel(new HttpServerKeepAliveHandler()); + var firstRequest = new DefaultFullHttpRequest(httpVersion, HttpMethod.Get, "/v1/foo/bar"); + HttpUtil.SetKeepAlive(firstRequest, true); + var secondRequest = new DefaultFullHttpRequest(httpVersion, HttpMethod.Get, "/v1/foo/bar"); + HttpUtil.SetKeepAlive(secondRequest, RequestKeepAlive.Equals(sendKeepAlive)); + var finalRequest = new DefaultFullHttpRequest(httpVersion, HttpMethod.Get, "/v1/foo/bar"); + HttpUtil.SetKeepAlive(finalRequest, false); + var response = new DefaultFullHttpResponse(httpVersion, responseStatus); + var informationalResp = new DefaultFullHttpResponse(httpVersion, Processing); + HttpUtil.SetKeepAlive(response, true); + HttpUtil.SetContentLength(response, 0); + HttpUtil.SetKeepAlive(informationalResp, true); + + Assert.True(channel.WriteInbound(firstRequest, secondRequest, finalRequest)); + + var requestForwarded = channel.ReadInbound(); + Assert.Equal(firstRequest, requestForwarded); + ReferenceCountUtil.Release(requestForwarded); + + channel.WriteAndFlushAsync(response.Duplicate().Retain()).Wait(TimeSpan.FromSeconds(1)); + var firstResponse = channel.ReadOutbound(); + Assert.True(channel.Open); + Assert.True(HttpUtil.IsKeepAlive(firstResponse)); + ReferenceCountUtil.Release(firstResponse); + + requestForwarded = channel.ReadInbound(); + Assert.Equal(secondRequest, requestForwarded); + ReferenceCountUtil.Release(requestForwarded); + + channel.WriteAndFlushAsync(informationalResp).Wait(TimeSpan.FromSeconds(1)); + var writtenInfoResp = channel.ReadOutbound(); + Assert.True(channel.Open); + Assert.True(HttpUtil.IsKeepAlive(writtenInfoResp)); + ReferenceCountUtil.Release(writtenInfoResp); + + if (!CharUtil.IsNullOrEmpty(setResponseConnection)) + { + response.Headers.Set(HttpHeaderNames.Connection, setResponseConnection); + } + else + { + response.Headers.Remove(HttpHeaderNames.Connection); + } + SetupMessageLength(setSelfDefinedMessageLength, response); + channel.WriteAndFlushAsync(response.Duplicate().Retain()).Wait(TimeSpan.FromSeconds(1)); + var secondResponse = channel.ReadOutbound(); + Assert.Equal(isKeepAliveResponseExpected, channel.Open); + Assert.Equal(isKeepAliveResponseExpected, HttpUtil.IsKeepAlive(secondResponse)); + ReferenceCountUtil.Release(secondResponse); + + requestForwarded = channel.ReadInbound(); + Assert.Equal(finalRequest, requestForwarded); + ReferenceCountUtil.Release(requestForwarded); + + if (isKeepAliveResponseExpected) + { + channel.WriteAndFlushAsync(response).Wait(TimeSpan.FromSeconds(1)); + var finalResponse = channel.ReadOutbound(); + Assert.False(channel.Open); + Assert.False(HttpUtil.IsKeepAlive(finalResponse)); + } + ReferenceCountUtil.Release(response); + Assert.False(channel.FinishAndReleaseAll()); + } + + static void SetupMessageLength(int setSelfDefinedMessageLength, IHttpResponse response) + { + switch (setSelfDefinedMessageLength) + { + case NotSelfDefinedMsgLength: + if (HttpUtil.IsContentLengthSet(response)) + { + response.Headers.Remove(HttpHeaderNames.ContentLength); + } + break; + case SetResponseLength: + HttpUtil.SetContentLength(response, 0); + break; + case SetChunked: + HttpUtil.SetTransferEncodingChunked(response, true); + break; + case SetMultipart: + response.Headers.Set(HttpHeaderNames.ContentType, HttpHeaderValues.MultipartMixed); + break; + default: + throw new ArgumentException($"Unknown selfDefinedMessageLength: {setSelfDefinedMessageLength}"); + } + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/HttpUtilTest.cs b/test/DotNetty.Codecs.Http.Tests/HttpUtilTest.cs new file mode 100644 index 0000000..b053b95 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/HttpUtilTest.cs @@ -0,0 +1,256 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Text; + using DotNetty.Common.Utilities; + using Xunit; + + public sealed class HttpUtilTest + { + [Fact] + public void RemoveTransferEncodingIgnoreCase() + { + var message = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + message.Headers.Set(HttpHeaderNames.TransferEncoding, "Chunked"); + Assert.False(message.Headers.IsEmpty); + HttpUtil.SetTransferEncodingChunked(message, false); + Assert.True(message.Headers.IsEmpty); + } + + // See https://github.com/netty/netty/issues/1690 + [Fact] + public void GetOperations() + { + HttpHeaders headers = new DefaultHttpHeaders(); + headers.Add(new AsciiString("Foo"), new AsciiString("1")); + headers.Add(new AsciiString("Foo"), new AsciiString("2")); + + Assert.True(headers.TryGet(new AsciiString("Foo"), out ICharSequence value)); + Assert.Equal("1", value.ToString()); + + IList values = headers.GetAll(new AsciiString("Foo")); + Assert.NotNull(values); + Assert.Equal(2, values.Count); + Assert.Equal("1", values[0].ToString()); + Assert.Equal("2", values[1].ToString()); + } + + [Fact] + public void GetCharsetAsRawCharSequence() + { + const string QuotesCharsetContentType = "text/html; charset=\"utf8\""; + const string SimpleContentType = "text/html"; + + var message = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + message.Headers.Set(HttpHeaderNames.ContentType, QuotesCharsetContentType); + Assert.Equal("\"utf8\"", HttpUtil.GetCharsetAsSequence(message).ToString()); + Assert.Equal("\"utf8\"", HttpUtil.GetCharsetAsSequence(new AsciiString(QuotesCharsetContentType))); + + message.Headers.Set(HttpHeaderNames.ContentType, "text/html"); + Assert.Null(HttpUtil.GetCharsetAsSequence(message)); + Assert.Null(HttpUtil.GetCharsetAsSequence(new AsciiString(SimpleContentType))); + } + + [Fact] + public void GetCharset() + { + const string NormalContentType = "text/html; charset=utf-8"; + const string UpperCaseNormalContentType = "TEXT/HTML; CHARSET=UTF-8"; + + var message = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + message.Headers.Set(HttpHeaderNames.ContentType, NormalContentType); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(message)); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(new AsciiString(NormalContentType))); + + message.Headers.Set(HttpHeaderNames.ContentType, UpperCaseNormalContentType); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(message)); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(new AsciiString(UpperCaseNormalContentType))); + } + + [Fact] + public void GetCharsetDefaultValue() + { + const string SimpleContentType = "text/html"; + const string ContentTypeWithIncorrectCharset = "text/html; charset=UTFFF"; + + var message = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + message.Headers.Set(HttpHeaderNames.ContentType, SimpleContentType); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(message)); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(new AsciiString(SimpleContentType))); + + message.Headers.Set(HttpHeaderNames.ContentType, SimpleContentType); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(message, Encoding.UTF8)); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(new AsciiString(SimpleContentType), Encoding.UTF8)); + + message.Headers.Set(HttpHeaderNames.ContentType, ContentTypeWithIncorrectCharset); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(message)); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(new AsciiString(ContentTypeWithIncorrectCharset))); + + message.Headers.Set(HttpHeaderNames.ContentType, ContentTypeWithIncorrectCharset); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(message, Encoding.UTF8)); + Assert.Equal(Encoding.UTF8, HttpUtil.GetCharset(new AsciiString(ContentTypeWithIncorrectCharset), Encoding.UTF8)); + } + + [Fact] + public void GetMimeType() + { + const string SimpleContentType = "text/html"; + const string NormalContentType = "text/html; charset=utf-8"; + + var message = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + Assert.Null(HttpUtil.GetMimeType(message)); + message.Headers.Set(HttpHeaderNames.ContentType, ""); + Assert.Null(HttpUtil.GetMimeType(message)); + Assert.Null(HttpUtil.GetMimeType(new AsciiString(""))); + message.Headers.Set(HttpHeaderNames.ContentType, SimpleContentType); + Assert.Equal("text/html", HttpUtil.GetMimeType(message)); + Assert.Equal("text/html", HttpUtil.GetMimeType(new AsciiString(SimpleContentType))); + + message.Headers.Set(HttpHeaderNames.ContentType, NormalContentType); + Assert.Equal("text/html", HttpUtil.GetMimeType(message)); + Assert.Equal("text/html", HttpUtil.GetMimeType(new AsciiString(NormalContentType))); + } + + [Fact] + public void GetContentLengthThrowsNumberFormatException() + { + var message = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + message.Headers.Set(HttpHeaderNames.ContentLength, "bar"); + Assert.Throws(() => HttpUtil.GetContentLength(message)); + } + + [Fact] + public void GetContentLengthIntDefaultValueThrowsNumberFormatException() + { + var message = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + message.Headers.Set(HttpHeaderNames.ContentLength, "bar"); + Assert.Throws(() => HttpUtil.GetContentLength(message, 1)); + } + + [Fact] + public void GetContentLengthLongDefaultValueThrowsNumberFormatException() + { + var message = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + message.Headers.Set(HttpHeaderNames.ContentLength, "bar"); + Assert.Throws(() => HttpUtil.GetContentLength(message, 1L)); + } + + [Fact] + public void DoubleChunkedHeader() + { + var message = new DefaultHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + message.Headers.Add(HttpHeaderNames.TransferEncoding, "chunked"); + HttpUtil.SetTransferEncodingChunked(message, true); + + IList list = message.Headers.GetAll(HttpHeaderNames.TransferEncoding); + Assert.NotNull(list); + + var expected = new List {"chunked"}; + Assert.True(expected.SequenceEqual(list.Select(x => x.ToString()))); + } + + static IEnumerable AllPossibleCasesOfContinue() + { + var cases = new List(); + string c = "continue"; + for (int i = 0; i < Math.Pow(2, c.Length); i++) + { + var sb = new StringBuilder(c.Length); + int j = i; + int k = 0; + while (j > 0) + { + if ((j & 1) == 1) + { + sb.Append(char.ToUpper(c[k++])); + } + else + { + sb.Append(c[k++]); + } + j >>= 1; + } + for (; k < c.Length; k++) + { + sb.Append(c[k]); + } + + cases.Add(sb.ToString()); + } + + return cases; + } + + [Fact] + public void Is100Continue() + { + // test all possible cases of 100-continue + foreach (string continueCase in AllPossibleCasesOfContinue()) + { + Run100ContinueTest(HttpVersion.Http11, "100-" + continueCase, true); + } + Run100ContinueTest(HttpVersion.Http11, null, false); + Run100ContinueTest(HttpVersion.Http11, "chocolate=yummy", false); + Run100ContinueTest(HttpVersion.Http10, "100-continue", false); + + var message = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + message.Headers.Set(HttpHeaderNames.Expect, "100-continue"); + Run100ContinueTest(message, false); + } + + static void Run100ContinueTest(HttpVersion version, string expectations, bool expect) + { + var message = new DefaultFullHttpRequest(version, HttpMethod.Get, "/"); + if (expectations != null) + { + message.Headers.Set(HttpHeaderNames.Expect, expectations); + } + + Run100ContinueTest(message, expect); + } + + static void Run100ContinueTest(IHttpMessage message, bool expected) + { + Assert.Equal(expected, HttpUtil.Is100ContinueExpected(message)); + ReferenceCountUtil.Release(message); + } + + [Fact] + public void ContainsUnsupportedExpectation() + { + // test all possible cases of 100-continue + foreach (string continueCase in AllPossibleCasesOfContinue()) + { + RunUnsupportedExpectationTest(HttpVersion.Http11, "100-" + continueCase, false); + } + RunUnsupportedExpectationTest(HttpVersion.Http11, null, false); + RunUnsupportedExpectationTest(HttpVersion.Http11, "chocolate=yummy", true); + RunUnsupportedExpectationTest(HttpVersion.Http10, "100-continue", false); + + var message = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + message.Headers.Set(new AsciiString("Expect"), "100-continue"); + RunUnsupportedExpectationTest(message, false); + } + + static void RunUnsupportedExpectationTest(HttpVersion version, string expectations, bool expect) + { + var message = new DefaultFullHttpRequest(version, HttpMethod.Get, "/"); + if (expectations != null) + { + message.Headers.Set(new AsciiString("Expect"), expectations); + } + RunUnsupportedExpectationTest(message, expect); + } + + static void RunUnsupportedExpectationTest(IHttpMessage message, bool expected) + { + Assert.Equal(expected, HttpUtil.IsUnsupportedExpectation(message)); + ReferenceCountUtil.Release(message); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Multipart/AbstractMemoryHttpDataTest.cs b/test/DotNetty.Codecs.Http.Tests/Multipart/AbstractMemoryHttpDataTest.cs new file mode 100644 index 0000000..7480e12 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Multipart/AbstractMemoryHttpDataTest.cs @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.Multipart +{ + using System; + using System.IO; + using System.Linq; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.Multipart; + using Xunit; + + public sealed class AbstractMemoryHttpDataTest + { + [Fact] + public void SetContentFromStream() + { + var random = new Random(); + + for (int i = 0; i < 20; i++) + { + // Generate input data bytes. + int size = random.Next(short.MaxValue); + var bytes = new byte[size]; + + random.NextBytes(bytes); + + // Generate parsed HTTP data block. + var httpData = new TestHttpData("name", Encoding.UTF8, 0); + + httpData.SetContent(new MemoryStream(bytes)); + + // Validate stored data. + IByteBuffer buffer = httpData.GetByteBuffer(); + + Assert.Equal(0, buffer.ReaderIndex); + Assert.Equal(bytes.Length, buffer.WriterIndex); + + var data = new byte[bytes.Length]; + buffer.GetBytes(buffer.ReaderIndex, data); + + Assert.True(data.SequenceEqual(bytes)); + } + } + + sealed class TestHttpData : AbstractMemoryHttpData + { + public TestHttpData(string name, Encoding contentEncoding, long size) + : base(name, contentEncoding, size) + { + } + + public override int CompareTo(IInterfaceHttpData other) + { + throw new NotSupportedException("Should never be called."); + } + + public override HttpDataType DataType => throw new NotSupportedException("Should never be called."); + + public override IByteBufferHolder Copy() + { + throw new NotSupportedException("Should never be called."); + } + + public override IByteBufferHolder Duplicate() + { + throw new NotSupportedException("Should never be called."); + } + + public override IByteBufferHolder RetainedDuplicate() + { + throw new NotSupportedException("Should never be called."); + } + + public override IByteBufferHolder Replace(IByteBuffer content) + { + throw new NotSupportedException("Should never be called."); + } + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Multipart/DefaultHttpDataFactoryTest.cs b/test/DotNetty.Codecs.Http.Tests/Multipart/DefaultHttpDataFactoryTest.cs new file mode 100644 index 0000000..4e868cb --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Multipart/DefaultHttpDataFactoryTest.cs @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.Multipart +{ + using System; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.Multipart; + using Xunit; + + public sealed class DefaultHttpDataFactoryTest : IDisposable + { + // req1 equals req2 + readonly IHttpRequest req1 = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Post, "/form"); + readonly IHttpRequest req2 = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Post, "/form"); + readonly DefaultHttpDataFactory factory; + + public DefaultHttpDataFactoryTest() + { + // Before doing anything, assert that the requests are equal + Assert.Equal(this.req1.GetHashCode(), this.req2.GetHashCode()); + Assert.True(this.req1.Equals(this.req2)); + + this.factory = new DefaultHttpDataFactory(); + } + + [Fact] + public void CleanRequestHttpDataShouldIdentifiesRequestsByTheirIdentities() + { + // Create some data belonging to req1 and req2 + IAttribute attribute1 = this.factory.CreateAttribute(this.req1, "attribute1", "value1"); + IAttribute attribute2 = this.factory.CreateAttribute(this.req2, "attribute2", "value2"); + IFileUpload file1 = this.factory.CreateFileUpload( + this.req1, + "file1", + "file1.txt", + HttpPostBodyUtil.DefaultTextContentType, + HttpHeaderValues.Identity.ToString(), + Encoding.UTF8, + 123); + + IFileUpload file2 = this.factory.CreateFileUpload( + this.req2, + "file2", + "file2.txt", + HttpPostBodyUtil.DefaultTextContentType, + HttpHeaderValues.Identity.ToString(), + Encoding.UTF8, + 123); + file1.SetContent(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("file1 content"))); + file2.SetContent(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("file2 content"))); + + // Assert that they are not deleted + Assert.NotNull(attribute1.GetByteBuffer()); + Assert.NotNull(attribute2.GetByteBuffer()); + Assert.NotNull(file1.GetByteBuffer()); + Assert.NotNull(file2.GetByteBuffer()); + Assert.Equal(1, attribute1.ReferenceCount); + Assert.Equal(1, attribute2.ReferenceCount); + Assert.Equal(1, file1.ReferenceCount); + Assert.Equal(1, file2.ReferenceCount); + + // Clean up by req1 + this.factory.CleanRequestHttpData(this.req1); + + // Assert that data belonging to req1 has been cleaned up + Assert.Null(attribute1.GetByteBuffer()); + Assert.Null(file1.GetByteBuffer()); + Assert.Equal(0, attribute1.ReferenceCount); + Assert.Equal(0, file1.ReferenceCount); + + // But not req2 + Assert.NotNull(attribute2.GetByteBuffer()); + Assert.NotNull(file2.GetByteBuffer()); + Assert.Equal(1, attribute2.ReferenceCount); + Assert.Equal(1, file2.ReferenceCount); + } + + [Fact] + public void RemoveHttpDataFromCleanShouldIdentifiesDataByTheirIdentities() + { + // Create some equal data items belonging to the same request + IAttribute attribute1 = this.factory.CreateAttribute(this.req1, "attribute", "value"); + IAttribute attribute2 = this.factory.CreateAttribute(this.req1, "attribute", "value"); + IFileUpload file1 = this.factory.CreateFileUpload( + this.req1, + "file", + "file.txt", + HttpPostBodyUtil.DefaultTextContentType, + HttpHeaderValues.Identity.ToString(), + Encoding.UTF8, + 123); + IFileUpload file2 = this.factory.CreateFileUpload( + this.req1, + "file", + "file.txt", + HttpPostBodyUtil.DefaultTextContentType, + HttpHeaderValues.Identity.ToString(), + Encoding.UTF8, + 123); + file1.SetContent(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("file content"))); + file2.SetContent(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes("file content"))); + + // Before doing anything, assert that the data items are equal + Assert.Equal(attribute1.GetHashCode(), attribute2.GetHashCode()); + Assert.True(attribute1.Equals(attribute2)); + Assert.Equal(file1.GetHashCode(), file2.GetHashCode()); + Assert.True(file1.Equals(file2)); + + // Remove attribute2 and file2 from being cleaned up by factory + this.factory.RemoveHttpDataFromClean(this.req1, attribute2); + this.factory.RemoveHttpDataFromClean(this.req1, file2); + + // Clean up by req1 + this.factory.CleanRequestHttpData(this.req1); + + // Assert that attribute1 and file1 have been cleaned up + Assert.Null(attribute1.GetByteBuffer()); + Assert.Null(file1.GetByteBuffer()); + Assert.Equal(0, attribute1.ReferenceCount); + Assert.Equal(0, file1.ReferenceCount); + + // But not attribute2 and file2 + Assert.NotNull(attribute2.GetByteBuffer()); + Assert.NotNull(file2.GetByteBuffer()); + Assert.Equal(1, attribute2.ReferenceCount); + Assert.Equal(1, file2.ReferenceCount); + + // Cleanup attribute2 and file2 manually to avoid memory leak, not via factory + attribute2.Release(); + file2.Release(); + Assert.Equal(0, attribute2.ReferenceCount); + Assert.Equal(0, file2.ReferenceCount); + } + + public void Dispose() => this.factory.CleanAllHttpData(); + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Multipart/DiskFileUploadTest.cs b/test/DotNetty.Codecs.Http.Tests/Multipart/DiskFileUploadTest.cs new file mode 100644 index 0000000..49bed41 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Multipart/DiskFileUploadTest.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.Multipart +{ + using DotNetty.Codecs.Http.Multipart; + using Xunit; + + public sealed class DiskFileUploadTest + { + [Fact] + public void DiskFileUploadEquals() + { + var f2 = new DiskFileUpload("d1", "d1", "application/json", null, null, 100); + Assert.Equal(f2, f2); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Multipart/HttpPostRequestDecoderTest.cs b/test/DotNetty.Codecs.Http.Tests/Multipart/HttpPostRequestDecoderTest.cs new file mode 100644 index 0000000..94fa180 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Multipart/HttpPostRequestDecoderTest.cs @@ -0,0 +1,607 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.Multipart +{ + using System.Collections.Generic; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.Multipart; + using DotNetty.Common.Utilities; + using Xunit; + + public sealed class HttpPostRequestDecoderTest + { + // https://github.com/netty/netty/issues/1575 + [Theory] + [InlineData(true)] + [InlineData(false)] + public void BinaryStreamUpload(bool withSpace) + { + const string Boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + string contentTypeValue; + if (withSpace) + { + contentTypeValue = "multipart/form-data; boundary=" + Boundary; + } + else + { + contentTypeValue = "multipart/form-data;boundary=" + Boundary; + } + var req = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + req.Result = DecoderResult.Success; + req.Headers.Add(HttpHeaderNames.ContentType, contentTypeValue); + req.Headers.Add(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + + // Force to use memory-based data. + var inMemoryFactory = new DefaultHttpDataFactory(false); + + var values = new[] { "", "\r", "\r\r", "\r\r\r" }; + foreach (string data in values) + { + string body = + "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"tmp-0.txt\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + data + "\r\n" + + "--" + Boundary + "--\r\n"; + + // Create decoder instance to test. + var decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + + decoder.Offer(new DefaultHttpContent(Unpooled.CopiedBuffer(Encoding.UTF8.GetBytes(body)))); + decoder.Offer(new DefaultHttpContent(Unpooled.Empty)); + + // Validate it's enough chunks to decode upload. + Assert.True(decoder.HasNext); + + // Decode binary upload. + IInterfaceHttpData next = decoder.Next(); + Assert.IsType(next); + var upload = (MemoryFileUpload)next; + + // Validate data has been parsed correctly as it was passed into request. + Assert.Equal(data, upload.GetString(Encoding.UTF8)); + upload.Release(); + decoder.Destroy(); + } + } + + // See https://github.com/netty/netty/issues/1089 + [Fact] + public void FullHttpRequestUpload() + { + const string Boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + req.Result = DecoderResult.Success; + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + req.Headers.Add(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + + // Force to use memory-based data. + var inMemoryFactory = new DefaultHttpDataFactory(false); + + var values = new[] { "", "\r", "\r\r", "\r\r\r" }; + foreach (string data in values) + { + string body = + "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"tmp-0.txt\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + data + "\r\n" + + "--" + Boundary + "--\r\n"; + + req.Content.WriteBytes(Encoding.UTF8.GetBytes(body)); + } + + // Create decoder instance to test. + var decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + List list = decoder.GetBodyHttpDatas(); + Assert.NotNull(list); + Assert.False(list.Count == 0); + decoder.Destroy(); + } + + // See https://github.com/netty/netty/issues/2544 + [Fact] + public void MultipartCodecWithCRasEndOfAttribute() + { + const string Boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + + // Force to use memory-based data. + var inMemoryFactory = new DefaultHttpDataFactory(false); + + const string Extradata = "aaaa"; + var strings = new string[5]; + for (int i = 0; i < 4; i++) + { + strings[i] = Extradata; + for (int j = 0; j < i; j++) + { + strings[i] += '\r'; + } + } + + for (int i = 0; i < 4; i++) + { + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + req.Result = DecoderResult.Success; + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + req.Headers.Add(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + string body = + "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file" + i + "\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + strings[i] + "\r\n" + + "--" + Boundary + "--\r\n"; + + req.Content.WriteBytes(Encoding.UTF8.GetBytes(body)); + // Create decoder instance to test. + var decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + List list = decoder.GetBodyHttpDatas(); + Assert.NotNull(list); + Assert.False(list.Count == 0); + + // Check correctness: data size + IInterfaceHttpData httpData = decoder.GetBodyHttpData(new AsciiString($"file{i}")); + Assert.NotNull(httpData); + var attribute = httpData as IAttribute; + Assert.NotNull(attribute); + + byte[] data = attribute.GetBytes(); + Assert.NotNull(data); + Assert.Equal(Encoding.UTF8.GetBytes(strings[i]).Length, data.Length); + + decoder.Destroy(); + } + } + + // See https://github.com/netty/netty/issues/2542 + [Fact] + public void QuotedBoundary() + { + const string Boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + + req.Result = DecoderResult.Success; + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=\"" + Boundary + '"'); + req.Headers.Add(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + + // Force to use memory-based data. + var inMemoryFactory = new DefaultHttpDataFactory(false); + + var values = new[] { "", "\r", "\r\r", "\r\r\r" }; + foreach (string data in values) + { + string body = + "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"tmp-0.txt\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + data + "\r\n" + + "--" + Boundary + "--\r\n"; + + req.Content.WriteBytes(Encoding.UTF8.GetBytes(body)); + } + + // Create decoder instance to test. + var decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + List list = decoder.GetBodyHttpDatas(); + Assert.NotNull(list); + Assert.False(list.Count == 0); + decoder.Destroy(); + } + + [Fact] + public void NoZeroOut() + { + const string Boundary = "E832jQp_Rq2ErFmAduHSR8YlMSm0FCY"; + + var aMemFactory = new DefaultHttpDataFactory(false); + var aRequest = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + aRequest.Headers.Set(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + aRequest.Headers.Set(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + + var aDecoder = new HttpPostRequestDecoder(aMemFactory, aRequest); + + const string BodyData = "some data would be here. the data should be long enough that it " + + "will be longer than the original buffer length of 256 bytes in " + + "the HttpPostRequestDecoder in order to trigger the issue. Some more " + + "data just to be on the safe side."; + + const string Body = "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"root\"\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + BodyData + + "\r\n" + + "--" + Boundary + "--\r\n"; + + byte[] aBytes = Encoding.UTF8.GetBytes(Body); + const int Split = 125; + + UnpooledByteBufferAllocator aAlloc = UnpooledByteBufferAllocator.Default; + IByteBuffer aSmallBuf = aAlloc.Buffer(Split, Split); + IByteBuffer aLargeBuf = aAlloc.Buffer(aBytes.Length - Split, aBytes.Length - Split); + + aSmallBuf.WriteBytes(aBytes, 0, Split); + aLargeBuf.WriteBytes(aBytes, Split, aBytes.Length - Split); + + aDecoder.Offer(new DefaultHttpContent(aSmallBuf)); + aDecoder.Offer(new DefaultHttpContent(aLargeBuf)); + aDecoder.Offer(EmptyLastHttpContent.Default); + + Assert.True(aDecoder.HasNext); + IInterfaceHttpData aDecodedData = aDecoder.Next(); + Assert.Equal(HttpDataType.Attribute, aDecodedData.DataType); + + var aAttr = (IAttribute)aDecodedData; + Assert.Equal(BodyData, aAttr.Value); + + aDecodedData.Release(); + aDecoder.Destroy(); + } + + // See https://github.com/netty/netty/issues/2305 + [Fact] + public void ChunkCorrect() + { + const string Payload = "town=794649819&town=784444184&town=794649672&town=794657800&town=" + + "794655734&town=794649377&town=794652136&town=789936338&town=789948986&town=" + + "789949643&town=786358677&town=794655880&town=786398977&town=789901165&town=" + + "789913325&town=789903418&town=789903579&town=794645251&town=794694126&town=" + + "794694831&town=794655274&town=789913656&town=794653956&town=794665634&town=" + + "789936598&town=789904658&town=789899210&town=799696252&town=794657521&town=" + + "789904837&town=789961286&town=789958704&town=789948839&town=789933899&town=" + + "793060398&town=794659180&town=794659365&town=799724096&town=794696332&town=" + + "789953438&town=786398499&town=794693372&town=789935439&town=794658041&town=" + + "789917595&town=794655427&town=791930372&town=794652891&town=794656365&town=" + + "789960339&town=794645586&town=794657688&town=794697211&town=789937427&town=" + + "789902813&town=789941130&town=794696907&town=789904328&town=789955151&town=" + + "789911570&town=794655074&town=789939531&town=789935242&town=789903835&town=" + + "789953800&town=794649962&town=789939841&town=789934819&town=789959672&town=" + + "794659043&town=794657035&town=794658938&town=794651746&town=794653732&town=" + + "794653881&town=786397909&town=794695736&town=799724044&town=794695926&town=" + + "789912270&town=794649030&town=794657946&town=794655370&town=794659660&town=" + + "794694617&town=799149862&town=789953234&town=789900476&town=794654995&town=" + + "794671126&town=789908868&town=794652942&town=789955605&town=789901934&town=" + + "789950015&town=789937922&town=789962576&town=786360170&town=789954264&town=" + + "789911738&town=789955416&town=799724187&town=789911879&town=794657462&town=" + + "789912561&town=789913167&town=794655195&town=789938266&town=789952099&town=" + + "794657160&town=789949414&town=794691293&town=794698153&town=789935636&town=" + + "789956374&town=789934635&town=789935475&town=789935085&town=794651425&town=" + + "794654936&town=794655680&town=789908669&town=794652031&town=789951298&town=" + + "789938382&town=794651503&town=794653330&town=817675037&town=789951623&town=" + + "789958999&town=789961555&town=794694050&town=794650241&town=794656286&town=" + + "794692081&town=794660090&town=794665227&town=794665136&town=794669931"; + + var defaultHttpRequest = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Post, "/"); + + var decoder = new HttpPostRequestDecoder(defaultHttpRequest); + + const int FirstChunk = 10; + const int MiddleChunk = 1024; + + var part1 = new DefaultHttpContent(Unpooled.WrappedBuffer( + Encoding.UTF8.GetBytes(Payload.Substring(0, FirstChunk)))); + var part2 = new DefaultHttpContent(Unpooled.WrappedBuffer( + Encoding.UTF8.GetBytes(Payload.Substring(FirstChunk, MiddleChunk)))); + var part3 = new DefaultHttpContent(Unpooled.WrappedBuffer( + Encoding.UTF8.GetBytes(Payload.Substring(FirstChunk + MiddleChunk, MiddleChunk)))); + var part4 = new DefaultHttpContent(Unpooled.WrappedBuffer( + Encoding.UTF8.GetBytes(Payload.Substring(FirstChunk + MiddleChunk * 2)))); + + decoder.Offer(part1); + decoder.Offer(part2); + decoder.Offer(part3); + decoder.Offer(part4); + } + + // See https://github.com/netty/netty/issues/3326 + [Fact] + public void FilenameContainingSemicolon() + { + const string Boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + + // Force to use memory-based data. + var inMemoryFactory = new DefaultHttpDataFactory(false); + + const string Data = "asdf"; + const string Filename = "tmp;0.txt"; + const string Body = "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"" + Filename + "\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + Data + "\r\n" + + "--" + Boundary + "--\r\n"; + + req.Content.WriteBytes(Encoding.UTF8.GetBytes(Body)); + var decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + List list = decoder.GetBodyHttpDatas(); + Assert.NotNull(list); + Assert.False(list.Count == 0); + decoder.Destroy(); + } + + [Fact] + public void FilenameContainingSemicolon2() + { + const string Boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + + // Force to use memory-based data. + var inMemoryFactory = new DefaultHttpDataFactory(false); + const string Data = "asdf"; + const string Filename = "tmp;0.txt"; + const string Body = "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"" + Filename + "\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + Data + "\r\n" + + "--" + Boundary + "--\r\n"; + + req.Content.WriteBytes(Encoding.UTF8.GetBytes(Body)); + var decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + List list = decoder.GetBodyHttpDatas(); + Assert.NotNull(list); + Assert.False(list.Count == 0); + + IInterfaceHttpData part1 = list[0]; + Assert.IsAssignableFrom(part1); + var fileUpload = (IFileUpload)part1; + Assert.Equal("tmp 0.txt", fileUpload.FileName); + decoder.Destroy(); + } + + [Fact] + public void MultipartRequestWithoutContentTypeBody() + { + const string Boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + req.Result = DecoderResult.Success; + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + req.Headers.Add(HttpHeaderNames.TransferEncoding, HttpHeaderValues.Chunked); + + // Force to use memory-based data. + var inMemoryFactory = new DefaultHttpDataFactory(false); + + var values = new[] { "", "\r", "\r\r", "\r\r\r" }; + foreach (string data in values) + { + string body = + "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"tmp-0.txt\"\r\n" + + "\r\n" + + data + "\r\n" + + "--" + Boundary + "--\r\n"; + + req.Content.WriteBytes(Encoding.UTF8.GetBytes(body)); + } + + // Create decoder instance to test without any exception. + var decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + List list = decoder.GetBodyHttpDatas(); + Assert.NotNull(list); + Assert.False(list.Count == 0); + decoder.Destroy(); + } + + [Fact] + public void MultipartRequestWithFileInvalidCharset() + { + const string Boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + // Force to use memory-based data. + var inMemoryFactory = new DefaultHttpDataFactory(false); + + const string Data = "asdf"; + const string FileName = "tmp;0.txt"; + string body = + "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"" + FileName + "\"\r\n" + + "Content-Type: image/gif; charset=ABCD\r\n" + + "\r\n" + + Data + "\r\n" + + "--" + Boundary + "--\r\n"; + + req.Content.WriteBytes(Encoding.UTF8.GetBytes(body)); + Assert.Throws(() => new HttpPostRequestDecoder(inMemoryFactory, req)); + } + + [Fact] + public void MultipartRequestWithFieldInvalidCharset() + { + const string Boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + var req = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + // Force to use memory-based data. + var inMemoryFactory = new DefaultHttpDataFactory(false); + + const string BodyData = "some data would be here. the data should be long enough that it " + + "will be longer than the original buffer length of 256 bytes in " + + "the HttpPostRequestDecoder in order to trigger the issue. Some more " + + "data just to be on the safe side."; + + const string Body = "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"root\"\r\n" + + "Content-Type: text/plain; charset=ABCD\r\n" + + "\r\n" + + BodyData + + "\r\n" + + "--" + Boundary + "--\r\n"; + + req.Content.WriteBytes(Encoding.UTF8.GetBytes(Body)); + Assert.Throws(() => new HttpPostRequestDecoder(inMemoryFactory, req)); + } + + [Fact] + public void FormEncodeIncorrect() + { + var content = new DefaultLastHttpContent(Unpooled.CopiedBuffer( + Encoding.ASCII.GetBytes("project=netty&&project=netty"))); + var req = new DefaultHttpRequest(HttpVersion.Http11, HttpMethod.Post, "/"); + var decoder = new HttpPostRequestDecoder(req); + Assert.Throws(() => decoder.Offer(content)); + decoder.Destroy(); + content.Release(); + } + + [Fact] + public void DecodeContentDispositionFieldParameters() + { + const string Boundary = "74e78d11b0214bdcbc2f86491eeb4902"; + const string Charset = "utf-8"; + const string Filename = "attached_файл.txt"; + string filenameEncoded = UrlEncoder.Encode(Filename, Encoding.UTF8); + + string body = "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename*=" + Charset + "''" + filenameEncoded + "\r\n" + + "\r\n" + + "foo\r\n" + + "\r\n" + + "--" + Boundary + "--"; + + var req = new DefaultFullHttpRequest(HttpVersion.Http11, + HttpMethod.Post, + "http://localhost", + Unpooled.WrappedBuffer(Encoding.UTF8.GetBytes(body))); + + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + var inMemoryFactory = new DefaultHttpDataFactory(false); + var decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + Assert.False(decoder.GetBodyHttpDatas().Count == 0); + IInterfaceHttpData part1 = decoder.GetBodyHttpDatas()[0]; + Assert.IsAssignableFrom(part1); + + var fileUpload = (IFileUpload)part1; + Assert.Equal(Filename, fileUpload.FileName); + decoder.Destroy(); + req.Release(); + } + + [Fact] + public void DecodeWithLanguageContentDispositionFieldParameters() + { + const string Boundary = "74e78d11b0214bdcbc2f86491eeb4902"; + const string Charset = "utf-8"; + const string Filename = "attached_файл.txt"; + const string Language = "anything"; + string filenameEncoded = UrlEncoder.Encode(Filename, Encoding.UTF8); + + string body = "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename*=" + + Charset + "'" + Language + "'" + filenameEncoded + "\r\n" + + "\r\n" + + "foo\r\n" + + "\r\n" + + "--" + Boundary + "--"; + + var req = new DefaultFullHttpRequest( + HttpVersion.Http11, + HttpMethod.Post, + "http://localhost", + Unpooled.WrappedBuffer(Encoding.UTF8.GetBytes(body))); + + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + var inMemoryFactory = new DefaultHttpDataFactory(false); + var decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + Assert.False(decoder.GetBodyHttpDatas().Count == 0); + IInterfaceHttpData part1 = decoder.GetBodyHttpDatas()[0]; + Assert.IsAssignableFrom(part1); + var fileUpload = (IFileUpload)part1; + Assert.Equal(Filename, fileUpload.FileName); + decoder.Destroy(); + req.Release(); + } + + [Fact] + public void DecodeMalformedNotEncodedContentDispositionFieldParameters() + { + const string Boundary = "74e78d11b0214bdcbc2f86491eeb4902"; + + const string Body = "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename*=not-encoded\r\n" + + "\r\n" + + "foo\r\n" + + "\r\n" + + "--" + Boundary + "--"; + + var req = new DefaultFullHttpRequest( + HttpVersion.Http11, + HttpMethod.Post, + "http://localhost", + Unpooled.WrappedBuffer(Encoding.UTF8.GetBytes(Body))); + + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + var inMemoryFactory = new DefaultHttpDataFactory(false); + Assert.Throws(() => new HttpPostRequestDecoder(inMemoryFactory, req)); + req.Release(); + } + + [Fact] + public void DecodeMalformedBadCharsetContentDispositionFieldParameters() + { + const string Boundary = "74e78d11b0214bdcbc2f86491eeb4902"; + + const string Body = "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename*=not-a-charset''filename\r\n" + + "\r\n" + + "foo\r\n" + + "\r\n" + + "--" + Boundary + "--"; + + var req = new DefaultFullHttpRequest( + HttpVersion.Http11, + HttpMethod.Post, + "http://localhost", + Unpooled.WrappedBuffer(Encoding.UTF8.GetBytes(Body))); + + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + + var inMemoryFactory = new DefaultHttpDataFactory(false); + Assert.Throws(() => new HttpPostRequestDecoder(inMemoryFactory, req)); + req.Release(); + } + + [Fact] + public void DecodeMalformedEmptyContentTypeFieldParameters() + { + const string Boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + var req = new DefaultFullHttpRequest( + HttpVersion.Http11, + HttpMethod.Post, + "http://localhost"); + + req.Headers.Add(HttpHeaderNames.ContentType, "multipart/form-data; boundary=" + Boundary); + // Force to use memory-based data. + var inMemoryFactory = new DefaultHttpDataFactory(false); + const string Data = "asdf"; + const string Filename = "tmp-0.txt"; + const string Body = "--" + Boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"" + Filename + "\"\r\n" + + "Content-Type: \r\n" + + "\r\n" + + Data + "\r\n" + + "--" + Boundary + "--\r\n"; + + req.Content.WriteBytes(Encoding.UTF8.GetBytes(Body)); + // Create decoder instance to test. + var decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + Assert.False(decoder.GetBodyHttpDatas().Count == 0); + IInterfaceHttpData part1 = decoder.GetBodyHttpDatas()[0]; + Assert.IsAssignableFrom(part1); + var fileUpload = (IFileUpload)part1; + Assert.Equal(Filename, fileUpload.FileName); + decoder.Destroy(); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Multipart/HttpPostRequestEncoderTest.cs b/test/DotNetty.Codecs.Http.Tests/Multipart/HttpPostRequestEncoderTest.cs new file mode 100644 index 0000000..7c541fb --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Multipart/HttpPostRequestEncoderTest.cs @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.Multipart +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Text; + using DotNetty.Buffers; + using DotNetty.Codecs.Http.Multipart; + using DotNetty.Common.Utilities; + using Xunit; + + public sealed class HttpPostRequestEncoderTest : IDisposable + { + readonly List files = new List(); + + [Fact] + public void AllowedMethods() + { + FileStream fileStream = File.Open("./Multipart/file-01.txt", FileMode.Open, FileAccess.Read); + this.files.Add(fileStream); + + ShouldThrowExceptionIfNotAllowed(HttpMethod.Connect, fileStream); + ShouldThrowExceptionIfNotAllowed(HttpMethod.Put, fileStream); + ShouldThrowExceptionIfNotAllowed(HttpMethod.Post, fileStream); + ShouldThrowExceptionIfNotAllowed(HttpMethod.Patch, fileStream); + ShouldThrowExceptionIfNotAllowed(HttpMethod.Delete, fileStream); + ShouldThrowExceptionIfNotAllowed(HttpMethod.Get, fileStream); + ShouldThrowExceptionIfNotAllowed(HttpMethod.Head, fileStream); + ShouldThrowExceptionIfNotAllowed(HttpMethod.Options, fileStream); + Assert.Throws(() => ShouldThrowExceptionIfNotAllowed(HttpMethod.Trace, fileStream)); + } + + static void ShouldThrowExceptionIfNotAllowed(HttpMethod method, FileStream fileStream) + { + fileStream.Position = 0; // Reset to the begining + var request = new DefaultFullHttpRequest(HttpVersion.Http11, method, "http://localhost"); + + var encoder = new HttpPostRequestEncoder(request, true); + encoder.AddBodyAttribute("foo", "bar"); + encoder.AddBodyFileUpload("quux", fileStream, "text/plain", false); + + string multipartDataBoundary = encoder.MultipartDataBoundary; + string content = GetRequestBody(encoder); + + string expected = "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"foo\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": 3" + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain; charset=utf-8" + "\r\n" + + "\r\n" + + "bar" + + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"quux\"; filename=\"file-01.txt\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": " + fileStream.Length + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain" + "\r\n" + + HttpHeaderNames.ContentTransferEncoding + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.Newline + + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + Assert.Equal(expected, content); + } + + [Fact] + public void SingleFileUploadNoName() + { + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + var encoder = new HttpPostRequestEncoder(request, true); + + FileStream fileStream = File.Open("./Multipart/file-01.txt", FileMode.Open, FileAccess.Read); + this.files.Add(fileStream); + encoder.AddBodyAttribute("foo", "bar"); + encoder.AddBodyFileUpload("quux", "", fileStream, "text/plain", false); + + string multipartDataBoundary = encoder.MultipartDataBoundary; + string content = GetRequestBody(encoder); + + string expected = "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"foo\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": 3" + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain; charset=utf-8" + "\r\n" + + "\r\n" + + "bar" + + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"quux\"\r\n" + + HttpHeaderNames.ContentLength + ": " + fileStream.Length + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain" + "\r\n" + + HttpHeaderNames.ContentTransferEncoding + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.Newline + + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + Assert.Equal(expected, content); + } + + [Fact] + public void MultiFileUploadInMixedMode() + { + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + var encoder = new HttpPostRequestEncoder(request, true); + + FileStream fileStream1 = File.Open("./Multipart/file-01.txt", FileMode.Open, FileAccess.Read); + this.files.Add(fileStream1); + FileStream fileStream2 = File.Open("./Multipart/file-02.txt", FileMode.Open, FileAccess.Read); + this.files.Add(fileStream2); + + encoder.AddBodyAttribute("foo", "bar"); + encoder.AddBodyFileUpload("quux", fileStream1, "text/plain", false); + encoder.AddBodyFileUpload("quux", fileStream2, "text/plain", false); + + // We have to query the value of these two fields before finalizing + // the request, which unsets one of them. + string multipartDataBoundary = encoder.MultipartDataBoundary; + string multipartMixedBoundary = encoder.MultipartMixedBoundary; + string content = GetRequestBody(encoder); + + string expected = "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"foo\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": 3" + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain; charset=utf-8" + "\r\n" + + "\r\n" + + "bar" + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"quux\"" + "\r\n" + + HttpHeaderNames.ContentType + ": multipart/mixed; boundary=" + multipartMixedBoundary + "\r\n" + + "\r\n" + + "--" + multipartMixedBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": attachment; filename=\"file-02.txt\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": " + fileStream1.Length + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain" + "\r\n" + + HttpHeaderNames.ContentTransferEncoding + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.Newline + + "\r\n" + + "--" + multipartMixedBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": attachment; filename=\"file-02.txt\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": " + fileStream2.Length + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain" + "\r\n" + + HttpHeaderNames.ContentTransferEncoding + ": binary" + "\r\n" + + "\r\n" + + "File 02" + StringUtil.Newline + + "\r\n" + + "--" + multipartMixedBoundary + "--" + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + Assert.Equal(expected, content); + } + + [Fact] + public void MultiFileUploadInMixedModeNoName() + { + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + var encoder = new HttpPostRequestEncoder(request, true); + + FileStream fileStream1 = File.Open("./Multipart/file-01.txt", FileMode.Open, FileAccess.Read); + this.files.Add(fileStream1); + FileStream fileStream2 = File.Open("./Multipart/file-02.txt", FileMode.Open, FileAccess.Read); + this.files.Add(fileStream2); + + encoder.AddBodyAttribute("foo", "bar"); + encoder.AddBodyFileUpload("quux", "", fileStream1, "text/plain", false); + encoder.AddBodyFileUpload("quux", "", fileStream2, "text/plain", false); + + // We have to query the value of these two fields before finalizing + // the request, which unsets one of them. + string multipartDataBoundary = encoder.MultipartDataBoundary; + string multipartMixedBoundary = encoder.MultipartMixedBoundary; + string content = GetRequestBody(encoder); + + string expected = "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"foo\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": 3" + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain; charset=utf-8" + "\r\n" + + "\r\n" + + "bar" + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"quux\"" + "\r\n" + + HttpHeaderNames.ContentType + ": multipart/mixed; boundary=" + multipartMixedBoundary + "\r\n" + + "\r\n" + + "--" + multipartMixedBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": attachment\r\n" + + HttpHeaderNames.ContentLength + ": " + fileStream1.Length + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain" + "\r\n" + + HttpHeaderNames.ContentTransferEncoding + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.Newline + + "\r\n" + + "--" + multipartMixedBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": attachment\r\n" + + HttpHeaderNames.ContentLength + ": " + fileStream2.Length + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain" + "\r\n" + + HttpHeaderNames.ContentTransferEncoding + ": binary" + "\r\n" + + "\r\n" + + "File 02" + StringUtil.Newline + + "\r\n" + + "--" + multipartMixedBoundary + "--" + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + Assert.Equal(expected, content); + } + + [Fact] + public void SingleFileUploadInHtml5Mode() + { + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + var factory = new DefaultHttpDataFactory(DefaultHttpDataFactory.MinSize); + var encoder = new HttpPostRequestEncoder( + factory, + request, + true, + Encoding.UTF8, + HttpPostRequestEncoder.EncoderMode.HTML5); + + FileStream fileStream1 = File.Open("./Multipart/file-01.txt", FileMode.Open, FileAccess.Read); + this.files.Add(fileStream1); + FileStream fileStream2 = File.Open("./Multipart/file-02.txt", FileMode.Open, FileAccess.Read); + this.files.Add(fileStream2); + + encoder.AddBodyAttribute("foo", "bar"); + encoder.AddBodyFileUpload("quux", fileStream1, "text/plain", false); + encoder.AddBodyFileUpload("quux", fileStream2, "text/plain", false); + + string multipartDataBoundary = encoder.MultipartDataBoundary; + string content = GetRequestBody(encoder); + + string expected = "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"foo\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": 3" + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain; charset=utf-8" + "\r\n" + + "\r\n" + + "bar" + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"quux\"; filename=\"file-01.txt\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": " + fileStream1.Length + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain" + "\r\n" + + HttpHeaderNames.ContentTransferEncoding + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.Newline + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"quux\"; filename=\"file-02.txt\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": " + fileStream2.Length + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain" + "\r\n" + + HttpHeaderNames.ContentTransferEncoding + ": binary" + "\r\n" + + "\r\n" + + "File 02" + StringUtil.Newline + + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + Assert.Equal(expected, content); + } + + [Fact] + public void MultiFileUploadInHtml5Mode() + { + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + var factory = new DefaultHttpDataFactory(DefaultHttpDataFactory.MinSize); + + var encoder = new HttpPostRequestEncoder( + factory, + request, + true, + Encoding.UTF8, + HttpPostRequestEncoder.EncoderMode.HTML5); + FileStream fileStream1 = File.Open("./Multipart/file-01.txt", FileMode.Open, FileAccess.Read); + this.files.Add(fileStream1); + + encoder.AddBodyAttribute("foo", "bar"); + encoder.AddBodyFileUpload("quux", fileStream1, "text/plain", false); + + string multipartDataBoundary = encoder.MultipartDataBoundary; + string content = GetRequestBody(encoder); + + string expected = "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"foo\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": 3" + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain; charset=utf-8" + "\r\n" + + "\r\n" + + "bar" + + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + HttpHeaderNames.ContentDisposition + ": form-data; name=\"quux\"; filename=\"file-01.txt\"" + "\r\n" + + HttpHeaderNames.ContentLength + ": " + fileStream1.Length + "\r\n" + + HttpHeaderNames.ContentType + ": text/plain" + "\r\n" + + HttpHeaderNames.ContentTransferEncoding + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.Newline + + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + Assert.Equal(expected, content); + } + + [Fact] + public void HttpPostRequestEncoderSlicedBuffer() + { + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + + var encoder = new HttpPostRequestEncoder(request, true); + // add Form attribute + encoder.AddBodyAttribute("getform", "POST"); + encoder.AddBodyAttribute("info", "first value"); + encoder.AddBodyAttribute("secondinfo", "secondvalue a&"); + encoder.AddBodyAttribute("thirdinfo", "short text"); + + const int Length = 100000; + var array = new char[Length]; + array.Fill('a'); + string longText = new string(array); + encoder.AddBodyAttribute("fourthinfo", longText.Substring(0, 7470)); + + FileStream fileStream1 = File.Open("./Multipart/file-01.txt", FileMode.Open, FileAccess.Read); + this.files.Add(fileStream1); + encoder.AddBodyFileUpload("myfile", fileStream1, "application/x-zip-compressed", false); + encoder.FinalizeRequest(); + + while (!encoder.IsEndOfInput) + { + IHttpContent httpContent = encoder.ReadChunk(null); + IByteBuffer content = httpContent.Content; + int refCnt = content.ReferenceCount; + Assert.True( + (ReferenceEquals(content.Unwrap(), content) || content.Unwrap() == null) && refCnt == 1 + || !ReferenceEquals(content.Unwrap(), content) && refCnt == 2, + "content: " + content + " content.unwrap(): " + content.Unwrap() + " refCnt: " + refCnt); + httpContent.Release(); + } + + encoder.CleanFiles(); + encoder.Close(); + } + + static string GetRequestBody(HttpPostRequestEncoder encoder) + { + encoder.FinalizeRequest(); + + List chunks = encoder.MultipartHttpDatas; + var buffers = new IByteBuffer[chunks.Count]; + + for (int i = 0; i < buffers.Length; i++) + { + IInterfaceHttpData data = chunks[i]; + if (data is InternalAttribute attribute) + { + buffers[i] = attribute.ToByteBuffer(); + } + else if (data is IHttpData httpData) + { + buffers[i] = httpData.GetByteBuffer(); + } + } + + IByteBuffer content = Unpooled.WrappedBuffer(buffers); + string contentStr = content.ToString(Encoding.UTF8); + content.Release(); + return contentStr; + } + + [Fact] + public void DataIsMultipleOfChunkSize1() + { + var factory = new DefaultHttpDataFactory(DefaultHttpDataFactory.MinSize); + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + var encoder = new HttpPostRequestEncoder(factory, request, true, + HttpConstants.DefaultEncoding, HttpPostRequestEncoder.EncoderMode.RFC1738); + + var first = new MemoryFileUpload("resources", "", "application/json", null, Encoding.UTF8, -1); + first.MaxSize = -1; + first.SetContent(new MemoryStream(new byte[7955])); + encoder.AddBodyHttpData(first); + + var second = new MemoryFileUpload("resources2", "", "application/json", null, Encoding.UTF8, -1); + second.MaxSize = -1; + second.SetContent(new MemoryStream(new byte[7928])); + encoder.AddBodyHttpData(second); + + Assert.NotNull(encoder.FinalizeRequest()); + + CheckNextChunkSize(encoder, 8080); + CheckNextChunkSize(encoder, 8055); + + IHttpContent httpContent = encoder.ReadChunk(default(IByteBufferAllocator)); + Assert.True(httpContent is ILastHttpContent, "Expected LastHttpContent is not received"); + httpContent.Release(); + + Assert.True(encoder.IsEndOfInput, "Expected end of input is not receive"); + } + + [Fact] + public void DataIsMultipleOfChunkSize2() + { + var request = new DefaultFullHttpRequest(HttpVersion.Http11, HttpMethod.Post, "http://localhost"); + var encoder = new HttpPostRequestEncoder(request, true); + const int Length = 7943; + var array = new char[Length]; + array.Fill('a'); + string longText = new string(array); + encoder.AddBodyAttribute("foo", longText); + + Assert.NotNull(encoder.FinalizeRequest()); + + // In Netty this is 8080 due to random long hex size difference + CheckNextChunkSize(encoder, 109 + Length + 8); + + IHttpContent httpContent = encoder.ReadChunk(default(IByteBufferAllocator)); + Assert.True(httpContent is ILastHttpContent, "Expected LastHttpContent is not received"); + httpContent.Release(); + + Assert.True(encoder.IsEndOfInput, "Expected end of input is not receive"); + } + + static void CheckNextChunkSize(HttpPostRequestEncoder encoder, int sizeWithoutDelimiter) + { + // 16 bytes as HttpPostRequestEncoder uses Long.toHexString(...) to generate a hex-string which will be between + // 2 and 16 bytes. + // See https://github.com/netty/netty/blob/4.1/codec-http/src/main/java/io/netty/handler/ + // codec/http/multipart/HttpPostRequestEncoder.java#L291 + int expectedSizeMin = sizeWithoutDelimiter + (2 + 2); // Two multipar boundary strings + int expectedSizeMax = sizeWithoutDelimiter + (16 + 16); // Two multipar boundary strings + + IHttpContent httpContent = encoder.ReadChunk(default(IByteBufferAllocator)); + + int readable = httpContent.Content.ReadableBytes; + bool expectedSize = readable >= expectedSizeMin && readable <= expectedSizeMax; + Assert.True(expectedSize, $"Chunk size is not in expected range ({expectedSizeMin} - {expectedSizeMax}), was: {readable}"); + httpContent.Release(); + } + + public void Dispose() + { + foreach (IDisposable file in this.files) + { + file.Dispose(); + } + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Multipart/MemoryFileUploadTest.cs b/test/DotNetty.Codecs.Http.Tests/Multipart/MemoryFileUploadTest.cs new file mode 100644 index 0000000..615282a --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Multipart/MemoryFileUploadTest.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests.Multipart +{ + using DotNetty.Codecs.Http.Multipart; + using Xunit; + + public sealed class MemoryFileUploadTest + { + [Fact] + public void MemoryFileUploadEquals() + { + var f1 = new MemoryFileUpload("m1", "m1", "application/json", null, null, 100); + Assert.Equal(f1, f1); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/Multipart/file-01.txt b/test/DotNetty.Codecs.Http.Tests/Multipart/file-01.txt new file mode 100644 index 0000000..a94c45f --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Multipart/file-01.txt @@ -0,0 +1 @@ +File 01 diff --git a/test/DotNetty.Codecs.Http.Tests/Multipart/file-02.txt b/test/DotNetty.Codecs.Http.Tests/Multipart/file-02.txt new file mode 100644 index 0000000..e2e0c12 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Multipart/file-02.txt @@ -0,0 +1 @@ +File 02 diff --git a/test/DotNetty.Codecs.Http.Tests/Properties/AssemblyInfo.cs b/test/DotNetty.Codecs.Http.Tests/Properties/AssemblyInfo.cs new file mode 100644 index 0000000..e3f1cab --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/Properties/AssemblyInfo.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Reflection; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. + +[assembly: AssemblyTitle("DotNetty.Codecs.Http.Tests")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("DotNetty.Codecs.Http.Tests")] +[assembly: AssemblyCopyright("Copyright © 2015")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. + +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM + +[assembly: Guid("4d4c9d88-34e3-41c3-955b-61da1ae07b7c")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("1.0.*")] + +[assembly: AssemblyVersion("1.0.0.0")] +[assembly: AssemblyFileVersion("1.0.0.0")] \ No newline at end of file diff --git a/test/DotNetty.Codecs.Http.Tests/QueryStringDecoderTest.cs b/test/DotNetty.Codecs.Http.Tests/QueryStringDecoderTest.cs new file mode 100644 index 0000000..ae32b13 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/QueryStringDecoderTest.cs @@ -0,0 +1,366 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Text; + using Xunit; + + public sealed class QueryStringDecoderTest + { + [Fact] + public void BasicUris() + { + var d = new QueryStringDecoder("http://localhost/path"); + Assert.Equal(0, d.Parameters.Count); + } + + [Fact] + public void Basic() + { + var d = new QueryStringDecoder("/foo"); + Assert.Equal("/foo", d.Path); + Assert.Equal(0, d.Parameters.Count); + + d = new QueryStringDecoder("/foo%20bar"); + Assert.Equal("/foo bar", d.Path); + Assert.Equal(0, d.Parameters.Count); + + d = new QueryStringDecoder("/foo?a=b=c"); + Assert.Equal("/foo", d.Path); + Assert.Equal(1, d.Parameters.Count); + Assert.Single(d.Parameters["a"]); + Assert.Equal("b=c", d.Parameters["a"][0]); + + d = new QueryStringDecoder("/foo?a=1&a=2"); + Assert.Equal("/foo", d.Path); + Assert.Equal(1, d.Parameters.Count); + Assert.Equal(2, d.Parameters["a"].Count); + Assert.Equal("1", d.Parameters["a"][0]); + Assert.Equal("2", d.Parameters["a"][1]); + + d = new QueryStringDecoder("/foo%20bar?a=1&a=2"); + Assert.Equal("/foo bar", d.Path); + Assert.Equal(1, d.Parameters.Count); + Assert.Equal(2, d.Parameters["a"].Count); + Assert.Equal("1", d.Parameters["a"][0]); + Assert.Equal("2", d.Parameters["a"][1]); + + d = new QueryStringDecoder("/foo?a=&a=2"); + Assert.Equal("/foo", d.Path); + Assert.Equal(1, d.Parameters.Count); + Assert.Equal(2, d.Parameters["a"].Count); + Assert.Equal("", d.Parameters["a"][0]); + Assert.Equal("2", d.Parameters["a"][1]); + + d = new QueryStringDecoder("/foo?a=1&a="); + Assert.Equal("/foo", d.Path); + Assert.Equal(1, d.Parameters.Count); + Assert.Equal(2, d.Parameters["a"].Count); + Assert.Equal("1", d.Parameters["a"][0]); + Assert.Equal("", d.Parameters["a"][1]); + + d = new QueryStringDecoder("/foo?a=1&a=&a="); + Assert.Equal("/foo", d.Path); + Assert.Equal(1, d.Parameters.Count); + Assert.Equal(3, d.Parameters["a"].Count); + Assert.Equal("1", d.Parameters["a"][0]); + Assert.Equal("", d.Parameters["a"][1]); + Assert.Equal("", d.Parameters["a"][2]); + + d = new QueryStringDecoder("/foo?a=1=&a==2"); + Assert.Equal("/foo", d.Path); + Assert.Equal(1, d.Parameters.Count); + Assert.Equal(2, d.Parameters["a"].Count); + Assert.Equal("1=", d.Parameters["a"][0]); + Assert.Equal("=2", d.Parameters["a"][1]); + + d = new QueryStringDecoder("/foo?abc=1%2023&abc=124%20"); + Assert.Equal("/foo", d.Path); + Assert.Equal(1, d.Parameters.Count); + Assert.Equal(2, d.Parameters["abc"].Count); + Assert.Equal("1 23", d.Parameters["abc"][0]); + Assert.Equal("124 ", d.Parameters["abc"][1]); + } + + [Fact] + public void Exotic() + { + AssertQueryString("", ""); + AssertQueryString("foo", "foo"); + AssertQueryString("foo", "foo?"); + AssertQueryString("/foo", "/foo?"); + AssertQueryString("/foo", "/foo"); + AssertQueryString("?a=", "?a"); + AssertQueryString("foo?a=", "foo?a"); + AssertQueryString("/foo?a=", "/foo?a"); + AssertQueryString("/foo?a=", "/foo?a&"); + AssertQueryString("/foo?a=", "/foo?&a"); + AssertQueryString("/foo?a=", "/foo?&a&"); + AssertQueryString("/foo?a=", "/foo?&=a"); + AssertQueryString("/foo?a=", "/foo?=a&"); + AssertQueryString("/foo?a=", "/foo?a=&"); + AssertQueryString("/foo?a=b&c=d", "/foo?a=b&&c=d"); + AssertQueryString("/foo?a=b&c=d", "/foo?a=b&=&c=d"); + AssertQueryString("/foo?a=b&c=d", "/foo?a=b&==&c=d"); + AssertQueryString("/foo?a=b&c=&x=y", "/foo?a=b&c&x=y"); + AssertQueryString("/foo?a=", "/foo?a="); + AssertQueryString("/foo?a=", "/foo?&a="); + AssertQueryString("/foo?a=b&c=d", "/foo?a=b&c=d"); + AssertQueryString("/foo?a=1&a=&a=", "/foo?a=1&a&a="); + } + + [Fact] + public void PathSpecific() + { + // decode escaped characters + Assert.Equal("/foo bar/", new QueryStringDecoder("/foo%20bar/?").Path); + Assert.Equal("/foo\r\n\\bar/", new QueryStringDecoder("/foo%0D%0A\\bar/?").Path); + + // a 'fragment' after '#' should be cuted (see RFC 3986) + Assert.Equal("", new QueryStringDecoder("#123").Path); + Assert.Equal("foo", new QueryStringDecoder("foo?bar#anchor").Path); + Assert.Equal("/foo-bar", new QueryStringDecoder("/foo-bar#anchor").Path); + Assert.Equal("/foo-bar", new QueryStringDecoder("/foo-bar#a#b?c=d").Path); + + // '+' is not escape ' ' for the path + Assert.Equal("+", new QueryStringDecoder("+").Path); + Assert.Equal("/foo+bar/", new QueryStringDecoder("/foo+bar/?").Path); + Assert.Equal("/foo++", new QueryStringDecoder("/foo++?index.php").Path); + Assert.Equal("/foo +", new QueryStringDecoder("/foo%20+?index.php").Path); + Assert.Equal("/foo+ ", new QueryStringDecoder("/foo+%20").Path); + } + + [Fact] + public void ExcludeFragment() + { + // a 'fragment' after '#' should be cuted (see RFC 3986) + Assert.Equal("a", new QueryStringDecoder("?a#anchor").Parameters.Keys.ElementAt(0)); + Assert.Equal("b", new QueryStringDecoder("?a=b#anchor").Parameters["a"][0]); + Assert.True(new QueryStringDecoder("?#").Parameters.Count == 0); + Assert.True(new QueryStringDecoder("?#anchor").Parameters.Count == 0); + Assert.True(new QueryStringDecoder("#?a=b#anchor").Parameters.Count == 0); + Assert.True(new QueryStringDecoder("?#a=b#anchor").Parameters.Count == 0); + } + + [Fact] + public void HashDos() + { + var buf = new StringBuilder(); + buf.Append('?'); + for (int i = 0; i < 65536; i++) + { + buf.Append('k'); + buf.Append(i); + buf.Append("=v"); + buf.Append(i); + buf.Append('&'); + } + + var d = new QueryStringDecoder(buf.ToString()); + IDictionary> parameters = d.Parameters; + Assert.Equal(1024, parameters.Count); + } + + [Fact] + public void HasPath() + { + var d = new QueryStringDecoder("1=2", false); + Assert.Equal("", d.Path); + IDictionary> parameters = d.Parameters; + Assert.Equal(1, parameters.Count); + Assert.True(parameters.ContainsKey("1")); + List param = parameters["1"]; + Assert.NotNull(param); + Assert.Single(param); + Assert.Equal("2", param[0]); + } + + [Fact] + public void UrlDecoding() + { + string caffe = new string( + // "Caffé" but instead of putting the literal E-acute in the + // source file, we directly use the UTF-8 encoding so as to + // not rely on the platform's default encoding (not portable). + new [] { 'C', 'a', 'f', 'f', '\u00E9' /* C3 A9 */ }); + + string[] tests = + { + // Encoded -> Decoded or error message substring + "", "", + "foo", "foo", + "f+o", "f o", + "f++", "f ", + "fo%", "unterminated escape sequence at index 2 of: fo%", + "%42", "B", + "%5f", "_", + "f%4", "unterminated escape sequence at index 1 of: f%4", + "%x2", "invalid hex byte 'x2' at index 1 of '%x2'", + "%4x", "invalid hex byte '4x' at index 1 of '%4x'", + "Caff%C3%A9", caffe, + "случайный праздник", "случайный праздник", + "случайный%20праздник", "случайный праздник", + "случайный%20праздник%20%E2%98%BA", "случайный праздник ☺", + }; + + for (int i = 0; i < tests.Length; i += 2) + { + string encoded = tests[i]; + string expected = tests[i + 1]; + try + { + string decoded = QueryStringDecoder.DecodeComponent(encoded); + Assert.Equal(expected, decoded); + } + catch (ArgumentException e) + { + Assert.Equal(expected, e.Message); + } + } + } + + static void AssertQueryString(string expected, string actual) + { + var ed = new QueryStringDecoder(expected); + var ad = new QueryStringDecoder(actual); + Assert.Equal(ed.Path, ad.Path); + + IDictionary> edParams = ed.Parameters; + IDictionary> adParams = ad.Parameters; + Assert.Equal(edParams.Count, adParams.Count); + + foreach (string name in edParams.Keys) + { + List expectedValues = edParams[name]; + + Assert.True(adParams.ContainsKey(name)); + List values = adParams[name]; + Assert.Equal(expectedValues.Count, values.Count); + + foreach (string value in expectedValues) + { + Assert.Contains(value, values); + } + } + } + + // See #189 + [Fact] + public void UrlString() + { + var uri = new Uri("http://localhost:8080/foo?param1=value1¶m2=value2¶m3=value3"); + var d = new QueryStringDecoder(uri); + Assert.Equal("/foo", d.Path); + IDictionary > parameters = d.Parameters; + Assert.Equal(3, parameters.Count); + + KeyValuePair> entry = parameters.ElementAt(0); + Assert.Equal("param1", entry.Key); + Assert.Single(entry.Value); + Assert.Equal("value1", entry.Value[0]); + + entry = parameters.ElementAt(1); + Assert.Equal("param2", entry.Key); + Assert.Single(entry.Value); + Assert.Equal("value2", entry.Value[0]); + + entry = parameters.ElementAt(2); + Assert.Equal("param3", entry.Key); + Assert.Single(entry.Value); + Assert.Equal("value3", entry.Value[0]); + } + + // See #189 + [Fact] + public void UriSlashPath() + { + var uri = new Uri("http://localhost:8080/?param1=value1¶m2=value2¶m3=value3"); + var d = new QueryStringDecoder(uri); + Assert.Equal("/", d.Path); + IDictionary> parameters = d.Parameters; + Assert.Equal(3, parameters.Count); + + KeyValuePair> entry = parameters.ElementAt(0); + Assert.Equal("param1", entry.Key); + Assert.Single(entry.Value); + Assert.Equal("value1", entry.Value[0]); + + entry = parameters.ElementAt(1); + Assert.Equal("param2", entry.Key); + Assert.Single(entry.Value); + Assert.Equal("value2", entry.Value[0]); + + entry = parameters.ElementAt(2); + Assert.Equal("param3", entry.Key); + Assert.Single(entry.Value); + Assert.Equal("value3", entry.Value[0]); + } + + // See #189 + [Fact] + public void UriNoPath() + { + var uri = new Uri("http://localhost:8080?param1=value1¶m2=value2¶m3=value3"); + var d = new QueryStringDecoder(uri); + // The path component cannot be empty string, + // if there are no path component, it shoudl be '/' as above UriSlashPath test + Assert.Equal("/", d.Path); + IDictionary> parameters = d.Parameters; + Assert.Equal(3, parameters.Count); + + KeyValuePair> entry = parameters.ElementAt(0); + Assert.Equal("param1", entry.Key); + Assert.Single(entry.Value); + Assert.Equal("value1", entry.Value[0]); + + entry = parameters.ElementAt(1); + Assert.Equal("param2", entry.Key); + Assert.Single(entry.Value); + Assert.Equal("value2", entry.Value[0]); + + entry = parameters.ElementAt(2); + Assert.Equal("param3", entry.Key); + Assert.Single(entry.Value); + Assert.Equal("value3", entry.Value[0]); + } + + // See https://github.com/netty/netty/issues/1833 + [Fact] + public void Uri2() + { + var uri = new Uri("http://foo.com/images;num=10?query=name;value=123"); + var d = new QueryStringDecoder(uri); + Assert.Equal("/images;num=10", d.Path); + IDictionary> parameters = d.Parameters; + Assert.Equal(2, parameters.Count); + + KeyValuePair> entry = parameters.ElementAt(0); + Assert.Equal("query", entry.Key); + Assert.Single(entry.Value); + Assert.Equal("name", entry.Value[0]); + + entry = parameters.ElementAt(1); + Assert.Equal("value", entry.Key); + Assert.Single(entry.Value); + Assert.Equal("123", entry.Value[0]); + } + + [Fact] + public void EmptyStrings() + { + var pathSlash = new QueryStringDecoder("path/"); + Assert.Equal("path/", pathSlash.RawPath()); + Assert.Equal("", pathSlash.RawQuery()); + var pathQuestion = new QueryStringDecoder("path?"); + Assert.Equal("path", pathQuestion.RawPath()); + Assert.Equal("", pathQuestion.RawQuery()); + var empty = new QueryStringDecoder(""); + Assert.Equal("", empty.RawPath()); + Assert.Equal("", empty.RawQuery()); + } + } +} diff --git a/test/DotNetty.Codecs.Http.Tests/QueryStringEncoderTest.cs b/test/DotNetty.Codecs.Http.Tests/QueryStringEncoderTest.cs new file mode 100644 index 0000000..0a8e765 --- /dev/null +++ b/test/DotNetty.Codecs.Http.Tests/QueryStringEncoderTest.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Http.Tests +{ + using System.Text; + using Xunit; + + public sealed class QueryStringEncoderTest + { + [Fact] + public void DefaultEncoding() + { + var e = new QueryStringEncoder("/foo"); + e.AddParam("a", "b=c"); + Assert.Equal("/foo?a=b%3Dc", e.ToString()); + + e = new QueryStringEncoder("/foo/\u00A5"); + e.AddParam("a", "\u00A5"); + Assert.Equal("/foo/\u00A5?a=%C2%A5", e.ToString()); + + e = new QueryStringEncoder("/foo"); + e.AddParam("a", "1"); + e.AddParam("b", "2"); + Assert.Equal("/foo?a=1&b=2", e.ToString()); + + e = new QueryStringEncoder("/foo"); + e.AddParam("a", "1"); + e.AddParam("b", ""); + e.AddParam("c", null); + e.AddParam("d", null); + Assert.Equal("/foo?a=1&b=&c&d", e.ToString()); + } + + [Fact] + public void NonDefaultEncoding() + { + var e = new QueryStringEncoder("/foo/\u00A5", Encoding.BigEndianUnicode); + e.AddParam("a", "\u00A5"); + + // + // Note that java emits endianess byte order mark results + // automatically, therefore the result is: + // + // %FE%FF%00%A5. + // + // .NET does not do this automatically by GetPreamble() method + // and manually write to results, therefore the result is: + // + // %00%A5 + // + // URL query strings do not need to encode this + + Assert.Equal("/foo/\u00A5?a=%00%A5", e.ToString()); + } + + [Fact] + public void WhitespaceEncoding() + { + var e = new QueryStringEncoder("/foo"); + e.AddParam("a", "b c"); + Assert.Equal("/foo?a=b%20c", e.ToString()); + } + } +} diff --git a/test/DotNetty.Codecs.Protobuf.Tests/RoundTripTests.cs b/test/DotNetty.Codecs.Protobuf.Tests/RoundTripTests.cs index 473b409..75c02b8 100644 --- a/test/DotNetty.Codecs.Protobuf.Tests/RoundTripTests.cs +++ b/test/DotNetty.Codecs.Protobuf.Tests/RoundTripTests.cs @@ -3,7 +3,6 @@ namespace DotNetty.Codecs.Protobuf.Tests { - using System; using System.Collections.Generic; using DotNetty.Buffers; using DotNetty.Transport.Channels.Embedded; diff --git a/test/DotNetty.Codecs.Tests/DateFormatterTest.cs b/test/DotNetty.Codecs.Tests/DateFormatterTest.cs new file mode 100644 index 0000000..d02f2d7 --- /dev/null +++ b/test/DotNetty.Codecs.Tests/DateFormatterTest.cs @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Tests +{ + using System; + using Xunit; + + public sealed class DateFormatterTest + { + // This date is set at "06 Nov 1994 08:49:37 GMT", from + // examples in RFC documentation + readonly DateTime expectedTime = new DateTime(1994, 11, 6, 8, 49, 37, DateTimeKind.Utc); + + [Fact] + public void ParseWithSingleDigitDay() + { + Assert.Equal(this.expectedTime, DateFormatter.ParseHttpDate("Sun, 6 Nov 1994 08:49:37 GMT")); + } + + [Fact] + public void ParseWithDoubleDigitDay() + { + Assert.Equal(this.expectedTime, DateFormatter.ParseHttpDate("Sun, 06 Nov 1994 08:49:37 GMT")); + } + + [Fact] + public void ParseWithDashSeparatorSingleDigitDay() + { + Assert.Equal(this.expectedTime, DateFormatter.ParseHttpDate("Sunday, 06-Nov-94 08:49:37 GMT")); + } + + [Fact] + public void ParseWithSingleDoubleDigitDay() + { + Assert.Equal(this.expectedTime, DateFormatter.ParseHttpDate("Sunday, 6-Nov-94 08:49:37 GMT")); + } + + [Fact] + public void ParseWithoutGmt() + { + Assert.Equal(this.expectedTime, DateFormatter.ParseHttpDate("Sun Nov 6 08:49:37 1994")); + } + + [Fact] + public void ParseWithFunkyTimezone() + { + Assert.Equal(this.expectedTime, DateFormatter.ParseHttpDate("Sun Nov 6 08:49:37 1994 -0000")); + } + + [Fact] + public void ParseWithSingleDigitHourMinutesAndSecond() + { + Assert.Equal(this.expectedTime, DateFormatter.ParseHttpDate("Sunday, 6-Nov-94 8:49:37 GMT")); + } + + [Fact] + public void ParseWithSingleDigitTime() + { + Assert.Equal(this.expectedTime, DateFormatter.ParseHttpDate("Sunday, 6 Nov 1994 8:49:37 GMT")); + + DateTime time080937 = this.expectedTime - TimeSpan.FromMilliseconds(40 * 60 * 1000); + Assert.Equal(time080937, DateFormatter.ParseHttpDate("Sunday, 6 Nov 1994 8:9:37 GMT")); + Assert.Equal(time080937, DateFormatter.ParseHttpDate("Sunday, 6 Nov 1994 8:09:37 GMT")); + + DateTime time080907 = this.expectedTime - TimeSpan.FromMilliseconds((40 * 60 + 30) * 1000); + Assert.Equal(time080907, DateFormatter.ParseHttpDate("Sunday, 6 Nov 1994 8:9:7 GMT")); + Assert.Equal(time080907, DateFormatter.ParseHttpDate("Sunday, 6 Nov 1994 8:9:07 GMT")); + } + + [Fact] + public void ParseMidnight() + { + Assert.Equal(new DateTime(1994, 11, 6, 0, 0, 0, DateTimeKind.Utc), DateFormatter.ParseHttpDate("Sunday, 6 Nov 1994 00:00:00 GMT")); + } + + [Fact] + public void ParseInvalidInput() + { + // missing field + Assert.Null(DateFormatter.ParseHttpDate("Sun, Nov 1994 08:49:37 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 1994 08:49:37 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 Nov 08:49:37 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 Nov 1994 :49:37 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 Nov 1994 49:37 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 Nov 1994 08::37 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 Nov 1994 08:37 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 Nov 1994 08:49: GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 Nov 1994 08:49 GMT")); + //invalid value + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 FOO 1994 08:49:37 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 36 Nov 1994 08:49:37 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 Nov 1994 28:49:37 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 Nov 1994 08:69:37 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sun, 6 Nov 1994 08:49:67 GMT")); + //wrong number of digits in timestamp + Assert.Null(DateFormatter.ParseHttpDate("Sunday, 6 Nov 1994 0:0:000 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sunday, 6 Nov 1994 0:000:0 GMT")); + Assert.Null(DateFormatter.ParseHttpDate("Sunday, 6 Nov 1994 000:0:0 GMT")); + } + + [Fact] + public void Format() + { + Assert.Equal("Sun, 6 Nov 1994 08:49:37 GMT", DateFormatter.Format(this.expectedTime)); + } + } +} diff --git a/test/DotNetty.Codecs.Tests/DefaultHeadersTest.cs b/test/DotNetty.Codecs.Tests/DefaultHeadersTest.cs new file mode 100644 index 0000000..bea4b79 --- /dev/null +++ b/test/DotNetty.Codecs.Tests/DefaultHeadersTest.cs @@ -0,0 +1,652 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Codecs.Tests +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Text; + using DotNetty.Common.Utilities; + using Xunit; + + using static Common.Utilities.AsciiString; + + public sealed class DefaultHeadersTest + { + sealed class TestDefaultHeaders : DefaultHeaders + { + public TestDefaultHeaders() : this(CharSequenceValueConverter.Default) + { + } + + public TestDefaultHeaders(IValueConverter converter) : base(converter) + { + } + } + + static TestDefaultHeaders NewInstance() => new TestDefaultHeaders(); + + [Fact] + public void AddShouldIncreaseAndRemoveShouldDecreaseTheSize() + { + TestDefaultHeaders headers = NewInstance(); + Assert.Equal(0, headers.Size); + headers.Add(Of("name1"), new[] { Of("value1"), Of("value2") }); + Assert.Equal(2, headers.Size); + headers.Add(Of("name2"), new[] { Of("value3"), Of("value4") }); + Assert.Equal(4, headers.Size); + headers.Add(Of("name3"), Of("value5")); + Assert.Equal(5, headers.Size); + + headers.Remove(Of("name3")); + Assert.Equal(4, headers.Size); + headers.Remove(Of("name1")); + Assert.Equal(2, headers.Size); + headers.Remove(Of("name2")); + Assert.Equal(0, headers.Size); + Assert.True(headers.IsEmpty); + } + + [Fact] + public void AfterClearHeadersShouldBeEmpty() + { + TestDefaultHeaders headers = NewInstance(); + headers.Add(Of("name1"), Of("value1")); + headers.Add(Of("name2"), Of("value2")); + Assert.Equal(2, headers.Size); + headers.Clear(); + Assert.Equal(0, headers.Size); + Assert.True(headers.IsEmpty); + Assert.False(headers.Contains(Of("name1"))); + Assert.False(headers.Contains(Of("name2"))); + } + + [Fact] + public void RemovingANameForASecondTimeShouldReturnFalse() + { + TestDefaultHeaders headers = NewInstance(); + headers.Add(Of("name1"), Of("value1")); + headers.Add(Of("name2"), Of("value2")); + Assert.True(headers.Remove(Of("name2"))); + Assert.False(headers.Remove(Of("name2"))); + } + + [Fact] + public void MultipleValuesPerNameShouldBeAllowed() + { + TestDefaultHeaders headers = NewInstance(); + headers.Add(Of("name"), Of("value1")); + headers.Add(Of("name"), Of("value2")); + headers.Add(Of("name"), Of("value3")); + Assert.Equal(3, headers.Size); + + IList values = headers.GetAll(Of("name")); + Assert.Equal(3, values.Count); + Assert.True(values.Contains(Of("value1"))); + Assert.True(values.Contains(Of("value2"))); + Assert.True(values.Contains(Of("value3"))); + } + + [Fact] + public void MultipleValuesPerNameIterator() + { + TestDefaultHeaders headers = NewInstance(); + headers.Add(Of("name"), Of("value1")); + headers.Add(Of("name"), Of("value2")); + headers.Add(Of("name"), Of("value3")); + Assert.Equal(3, headers.Size); + + var values = new List(); + foreach (ICharSequence value in headers.ValueIterator(Of("name"))) + { + values.Add(value); + } + Assert.Equal(3, values.Count); + Assert.Contains(Of("value1"), values); + Assert.Contains(Of("value2"), values); + Assert.Contains(Of("value3"), values); + } + + [Fact] + public void MultipleValuesPerNameIteratorEmpty() + { + TestDefaultHeaders headers = NewInstance(); + var values = new List(); + foreach (ICharSequence value in headers.ValueIterator(Of("name"))) + { + values.Add(value); + } + Assert.Empty(values); + } + + [Fact] + public void Contains() + { + TestDefaultHeaders headers = NewInstance(); + headers.AddBoolean(Of("boolean"), true); + Assert.True(headers.ContainsBoolean(Of("boolean"), true)); + Assert.False(headers.ContainsBoolean(Of("boolean"), false)); + + headers.AddLong(Of("long"), long.MaxValue); + Assert.True(headers.ContainsLong(Of("long"), long.MaxValue)); + Assert.False(headers.ContainsLong(Of("long"), long.MinValue)); + + headers.AddInt(Of("int"), int.MinValue); + Assert.True(headers.ContainsInt(Of("int"), int.MinValue)); + Assert.False(headers.ContainsInt(Of("int"), int.MaxValue)); + + headers.AddShort(Of("short"), short.MaxValue); + Assert.True(headers.ContainsShort(Of("short"), short.MaxValue)); + Assert.False(headers.ContainsShort(Of("short"), short.MinValue)); + + headers.AddChar(Of("char"), char.MaxValue); + Assert.True(headers.ContainsChar(Of("char"), char.MaxValue)); + Assert.False(headers.ContainsChar(Of("char"), char.MinValue)); + + headers.AddByte(Of("byte"), byte.MaxValue); + Assert.True(headers.ContainsByte(Of("byte"), byte.MaxValue)); + Assert.False(headers.ContainsByte(Of("byte"), byte.MinValue)); + + headers.AddDouble(Of("double"), double.MaxValue); + Assert.True(headers.ContainsDouble(Of("double"), double.MaxValue)); + Assert.False(headers.ContainsDouble(Of("double"), double.MinValue)); + + headers.AddFloat(Of("float"), float.MaxValue); + Assert.True(headers.ContainsFloat(Of("float"), float.MaxValue)); + Assert.False(headers.ContainsFloat(Of("float"), float.MinValue)); + + long millis = (long)Math.Floor(DateTime.UtcNow.Ticks / (double)TimeSpan.TicksPerMillisecond); + headers.AddTimeMillis(Of("millis"), millis); + Assert.True(headers.ContainsTimeMillis(Of("millis"), millis)); + // This test doesn't work on midnight, January 1, 1970 UTC + Assert.False(headers.ContainsTimeMillis(Of("millis"), 0)); + + headers.AddObject(Of("object"), "Hello World"); + Assert.True(headers.ContainsObject(Of("object"), "Hello World")); + Assert.False(headers.ContainsObject(Of("object"), "")); + + headers.Add(Of("name"), Of("value")); + Assert.True(headers.Contains(Of("name"), Of("value"))); + Assert.False(headers.Contains(Of("name"), Of("value1"))); + } + + [Fact] + public void Copy() + { + IHeaders headers = NewInstance(); + headers.AddBoolean(Of("boolean"), true); + headers.AddLong(Of("long"), long.MaxValue); + headers.AddInt(Of("int"), int.MinValue); + headers.AddShort(Of("short"), short.MaxValue); + headers.AddChar(Of("char"), char.MaxValue); + headers.AddByte(Of("byte"), byte.MaxValue); + headers.AddDouble(Of("double"), double.MaxValue); + headers.AddFloat(Of("float"), float.MaxValue); + long millis = (long)Math.Floor(DateTime.UtcNow.Ticks / (double)TimeSpan.TicksPerMillisecond); + headers.AddTimeMillis(Of("millis"), millis); + headers.AddObject(Of("object"), "Hello World"); + headers.Add(Of("name"), Of("value")); + + headers = NewInstance().Add(headers); + + Assert.True(headers.ContainsBoolean(Of("boolean"), true)); + Assert.False(headers.ContainsBoolean(Of("boolean"), false)); + + Assert.True(headers.ContainsLong(Of("long"), long.MaxValue)); + Assert.False(headers.ContainsLong(Of("long"), long.MinValue)); + + Assert.True(headers.ContainsInt(Of("int"), int.MinValue)); + Assert.False(headers.ContainsInt(Of("int"), int.MaxValue)); + + Assert.True(headers.ContainsShort(Of("short"), short.MaxValue)); + Assert.False(headers.ContainsShort(Of("short"), short.MinValue)); + + Assert.True(headers.ContainsChar(Of("char"), char.MaxValue)); + Assert.False(headers.ContainsChar(Of("char"), char.MinValue)); + + Assert.True(headers.ContainsByte(Of("byte"), byte.MaxValue)); + Assert.False(headers.ContainsLong(Of("byte"), byte.MinValue)); + + Assert.True(headers.ContainsDouble(Of("double"), double.MaxValue)); + Assert.False(headers.ContainsDouble(Of("double"), double.MinValue)); + + Assert.True(headers.ContainsFloat(Of("float"), float.MaxValue)); + Assert.False(headers.ContainsFloat(Of("float"), float.MinValue)); + + Assert.True(headers.ContainsTimeMillis(Of("millis"), millis)); + // This test doesn't work on midnight, January 1, 1970 UTC + Assert.False(headers.ContainsTimeMillis(Of("millis"), 0)); + + Assert.True(headers.ContainsObject(Of("object"), "Hello World")); + Assert.False(headers.ContainsObject(Of("object"), "")); + + Assert.True(headers.Contains(Of("name"), Of("value"))); + Assert.False(headers.Contains(Of("name"), Of("value1"))); + } + + [Fact] + public void CanMixConvertedAndNormalValues() + { + TestDefaultHeaders headers = NewInstance(); + headers.Add(Of("name"), Of("value")); + headers.AddInt(Of("name"), 100); + headers.AddBoolean(Of("name"), false); + + Assert.Equal(3, headers.Size); + Assert.True(headers.Contains(Of("name"))); + Assert.True(headers.Contains(Of("name"), Of("value"))); + Assert.True(headers.ContainsInt(Of("name"), 100)); + Assert.True(headers.ContainsBoolean(Of("name"), false)); + } + + [Fact] + public void GetAndRemove() + { + TestDefaultHeaders headers = NewInstance(); + headers.Add(Of("name1"), Of("value1")); + headers.Add(Of("name2"), new [] { Of("value2"), Of("value3")}); + headers.Add(Of("name3"), new [] { Of("value4"), Of("value5"), Of("value6") }); + + Assert.Equal(Of("value1"), headers.GetAndRemove(Of("name1"), Of("defaultvalue"))); + Assert.True(headers.TryGetAndRemove(Of("name2"), out ICharSequence value)); + Assert.Equal(Of("value2"), value); + Assert.False(headers.TryGetAndRemove(Of("name2"), out value)); + Assert.Null(value); + Assert.True(new [] { Of("value4"), Of("value5"), Of("value6") }.SequenceEqual(headers.GetAllAndRemove(Of("name3")))); + Assert.Equal(0, headers.Size); + Assert.False(headers.TryGetAndRemove(Of("noname"), out value)); + Assert.Null(value); + Assert.Equal(Of("defaultvalue"), headers.GetAndRemove(Of("noname"), Of("defaultvalue"))); + } + + [Fact] + public void WhenNameContainsMultipleValuesGetShouldReturnTheFirst() + { + TestDefaultHeaders headers = NewInstance(); + headers.Add(Of("name1"), new []{ Of("value1"), Of("value2")}); + Assert.True(headers.TryGet(Of("name1"), out ICharSequence value)); + Assert.Equal(Of("value1"), value); + } + + [Fact] + public void GetWithDefaultValue() + { + TestDefaultHeaders headers = NewInstance(); + headers.Add(Of("name1"), Of("value1")); + + Assert.Equal(Of("value1"), headers.Get(Of("name1"), Of("defaultvalue"))); + Assert.Equal(Of("defaultvalue"), headers.Get(Of("noname"), Of("defaultvalue"))); + } + + [Fact] + public void SetShouldOverWritePreviousValue() + { + TestDefaultHeaders headers = NewInstance(); + headers.Set(Of("name"), Of("value1")); + headers.Set(Of("name"), Of("value2")); + Assert.Equal(1, headers.Size); + Assert.Equal(1, headers.GetAll(Of("name")).Count); + Assert.Equal(Of("value2"), headers.GetAll(Of("name"))[0]); + Assert.True(headers.TryGet(Of("name"), out ICharSequence value)); + Assert.Equal(Of("value2"), value); + } + + [Fact] + public void SetAllShouldOverwriteSomeAndLeaveOthersUntouched() + { + TestDefaultHeaders h1 = NewInstance(); + + h1.Add(Of("name1"), Of("value1")); + h1.Add(Of("name2"), Of("value2")); + h1.Add(Of("name2"), Of("value3")); + h1.Add(Of("name3"), Of("value4")); + + TestDefaultHeaders h2 = NewInstance(); + h2.Add(Of("name1"), Of("value5")); + h2.Add(Of("name2"), Of("value6")); + h2.Add(Of("name1"), Of("value7")); + + TestDefaultHeaders expected = NewInstance(); + expected.Add(Of("name1"), Of("value5")); + expected.Add(Of("name2"), Of("value6")); + expected.Add(Of("name1"), Of("value7")); + expected.Add(Of("name3"), Of("value4")); + + h1.SetAll(h2); + + Assert.True(expected.Equals(h1)); + } + + [Fact] + public void HeadersWithSameNamesAndValuesShouldBeEquivalent() + { + TestDefaultHeaders headers1 = NewInstance(); + headers1.Add(Of("name1"), Of("value1")); + headers1.Add(Of("name2"), Of("value2")); + headers1.Add(Of("name2"), Of("value3")); + + TestDefaultHeaders headers2 = NewInstance(); + headers2.Add(Of("name1"), Of("value1")); + headers2.Add(Of("name2"), Of("value2")); + headers2.Add(Of("name2"), Of("value3")); + + Assert.True(headers1.Equals(headers2)); + Assert.True(headers2.Equals(headers1)); + Assert.Equal(headers1.GetHashCode(), headers2.GetHashCode()); + Assert.Equal(headers1.GetHashCode(), headers1.GetHashCode()); + Assert.Equal(headers2.GetHashCode(), headers2.GetHashCode()); + } + + [Fact] + public void EmptyHeadersShouldBeEqual() + { + TestDefaultHeaders headers1 = NewInstance(); + TestDefaultHeaders headers2 = NewInstance(); + Assert.NotSame(headers1, headers2); + Assert.True(headers1.Equals(headers2)); + Assert.Equal(headers1.GetHashCode(), headers2.GetHashCode()); + } + + [Fact] + public void HeadersWithSameNamesButDifferentValuesShouldNotBeEquivalent() + { + TestDefaultHeaders headers1 = NewInstance(); + headers1.Add(Of("name1"), Of("value1")); + TestDefaultHeaders headers2 = NewInstance(); + headers1.Add(Of("name1"), Of("value2")); + Assert.False(headers1.Equals(headers2)); + } + + [Fact] + public void SubsetOfHeadersShouldNotBeEquivalent() + { + TestDefaultHeaders headers1 = NewInstance(); + headers1.Add(Of("name1"), Of("value1")); + headers1.Add(Of("name2"), Of("value2")); + TestDefaultHeaders headers2 = NewInstance(); + headers1.Add(Of("name1"), Of("value1")); + Assert.False(headers1.Equals(headers2)); + } + + [Fact] + public void HeadersWithDifferentNamesAndValuesShouldNotBeEquivalent() + { + TestDefaultHeaders h1 = NewInstance(); + h1.Set(Of("name1"), Of("value1")); + TestDefaultHeaders h2 = NewInstance(); + h2.Set(Of("name2"), Of("value2")); + Assert.False(h1.Equals(h2)); + Assert.False(h2.Equals(h1)); + } + + [Fact] + public void IterateEmptyHeaders() + { + TestDefaultHeaders headers = NewInstance(); + var list = new List>(headers); + Assert.Empty(list); + } + + [Fact] + public void IteratorShouldReturnAllNameValuePairs() + { + TestDefaultHeaders headers1 = NewInstance(); + headers1.Add(Of("name1"), new[] { Of("value1"), Of("value2") }); + headers1.Add(Of("name2"), Of("value3")); + headers1.Add(Of("name3"), new[] { Of("value4"), Of("value5"), Of("value6") }); + headers1.Add(Of("name1"), new[] { Of("value7"), Of("value8") }); + Assert.Equal(8, headers1.Size); + + TestDefaultHeaders headers2 = NewInstance(); + foreach (HeaderEntry entry in headers1) + { + headers2.Add(entry.Key, entry.Value); + } + + Assert.True(headers1.Equals(headers2)); + } + + [Fact] + public void IteratorSetValueShouldChangeHeaderValue() + { + TestDefaultHeaders headers = NewInstance(); + headers.Add(Of("name1"), new[] { Of("value1"), Of("value2"), Of("value3")}); + headers.Add(Of("name2"), Of("value4")); + Assert.Equal(4, headers.Size); + + foreach(HeaderEntry header in headers) + { + if (Of("name1").Equals(header.Key) && Of("value2").Equals(header.Value)) + { + header.SetValue(Of("updatedvalue2")); + Assert.Equal(Of("updatedvalue2"), header.Value); + } + if (Of("name1").Equals(header.Key) && Of("value3").Equals(header.Value)) + { + header.SetValue(Of("updatedvalue3")); + Assert.Equal(Of("updatedvalue3"), header.Value); + } + } + + Assert.Equal(4, headers.Size); + Assert.True(headers.Contains(Of("name1"), Of("updatedvalue2"))); + Assert.False(headers.Contains(Of("name1"), Of("value2"))); + Assert.True(headers.Contains(Of("name1"), Of("updatedvalue3"))); + Assert.False(headers.Contains(Of("name1"), Of("value3"))); + } + + [Fact] + public void EntryEquals() + { + IHeaders same1 = NewInstance().Add(Of("name"), Of("value")); + IHeaders same2 = NewInstance().Add(Of("name"), Of("value")); + Assert.True(same1.Equals(same2)); + Assert.Equal(same1.GetHashCode(), same2.GetHashCode()); + + IHeaders nameDifferent1 = NewInstance().Add(Of("name1"), Of("value")); + IHeaders nameDifferent2 = NewInstance().Add(Of("name2"), Of("value")); + Assert.False(nameDifferent1.Equals(nameDifferent2)); + Assert.NotEqual(nameDifferent1.GetHashCode(), nameDifferent2.GetHashCode()); + + IHeaders valueDifferent1 = NewInstance().Add(Of("name"), Of("value1")); + IHeaders valueDifferent2 = NewInstance().Add(Of("name"), Of("value2")); + Assert.False(valueDifferent1.Equals(valueDifferent2)); + Assert.NotEqual(valueDifferent1.GetHashCode(), valueDifferent2.GetHashCode()); + } + + [Fact] + public void GetAllReturnsEmptyListForUnknownName() + { + TestDefaultHeaders headers = NewInstance(); + Assert.Equal(0, headers.GetAll(Of("noname")).Count); + } + + [Fact] + public void SetHeadersShouldClearAndOverwrite() + { + TestDefaultHeaders headers1 = NewInstance(); + headers1.Add(Of("name"), Of("value")); + + TestDefaultHeaders headers2 = NewInstance(); + headers2.Add(Of("name"), Of("newvalue")); + headers2.Add(Of("name1"), Of("value1")); + + headers1.Set(headers2); + Assert.True(headers1.Equals(headers2)); + } + + [Fact] + public void SetAllHeadersShouldOnlyOverwriteHeaders() + { + TestDefaultHeaders headers1 = NewInstance(); + headers1.Add(Of("name"), Of("value")); + headers1.Add(Of("name1"), Of("value1")); + + TestDefaultHeaders headers2 = NewInstance(); + headers2.Add(Of("name"), Of("newvalue")); + headers2.Add(Of("name2"), Of("value2")); + + TestDefaultHeaders expected = NewInstance(); + expected.Add(Of("name"), Of("newvalue")); + expected.Add(Of("name1"), Of("value1")); + expected.Add(Of("name2"), Of("value2")); + + headers1.SetAll(headers2); + Assert.True(headers1.Equals(expected)); + } + + [Fact] + public void AddSelf() + { + TestDefaultHeaders headers = NewInstance(); + Assert.Throws(() => headers.Add(headers)); + } + + [Fact] + public void SetSelfIsNoOp() + { + TestDefaultHeaders headers = NewInstance(); + headers.Add(Of("name"), Of("value")); + headers.Set(headers); + Assert.Equal(1, headers.Size); + } + + [Fact] + public void ConvertToString() + { + TestDefaultHeaders headers = NewInstance(); + headers.Add(Of("name1"), Of("value1")); + headers.Add(Of("name1"), Of("value2")); + headers.Add(Of("name2"), Of("value3")); + Assert.Equal("TestDefaultHeaders[name1: value1, name1: value2, name2: value3]", headers.ToString()); + + headers = NewInstance(); + headers.Add(Of("name1"), Of("value1")); + headers.Add(Of("name2"), Of("value2")); + headers.Add(Of("name3"), Of("value3")); + Assert.Equal("TestDefaultHeaders[name1: value1, name2: value2, name3: value3]", headers.ToString()); + + headers = NewInstance(); + headers.Add(Of("name1"), Of("value1")); + Assert.Equal("TestDefaultHeaders[name1: value1]", headers.ToString()); + + headers = NewInstance(); + Assert.Equal("TestDefaultHeaders[]", headers.ToString()); + } + + sealed class ThrowConverter : IValueConverter + { + public ICharSequence ConvertObject(object value) => throw new ArgumentException(); + + public ICharSequence ConvertBoolean(bool value) => throw new ArgumentException(); + + public bool ConvertToBoolean(ICharSequence value) => throw new ArgumentException(); + + public ICharSequence ConvertByte(byte value) => throw new ArgumentException(); + + public byte ConvertToByte(ICharSequence value) => throw new ArgumentException(); + + public ICharSequence ConvertChar(char value) => throw new ArgumentException(); + + public char ConvertToChar(ICharSequence value) => throw new ArgumentException(); + + public ICharSequence ConvertShort(short value) => throw new ArgumentException(); + + public short ConvertToShort(ICharSequence value) => throw new ArgumentException(); + + public ICharSequence ConvertInt(int value) => throw new ArgumentException(); + + public int ConvertToInt(ICharSequence value) => throw new ArgumentException(); + + public ICharSequence ConvertLong(long value) => throw new ArgumentException(); + + public long ConvertToLong(ICharSequence value) => throw new ArgumentException(); + + public ICharSequence ConvertTimeMillis(long value) => throw new ArgumentException(); + + public long ConvertToTimeMillis(ICharSequence value) => throw new ArgumentException(); + + public ICharSequence ConvertFloat(float value) => throw new ArgumentException(); + + public float ConvertToFloat(ICharSequence value) => throw new ArgumentException(); + + public ICharSequence ConvertDouble(double value) => throw new ArgumentException(); + + public double ConvertToDouble(ICharSequence value) => throw new ArgumentException(); + } + + [Fact] + public void NotThrowWhenConvertFails() + { + var headers = new TestDefaultHeaders(new ThrowConverter()); + + headers.Set(Of("name1"), Of("")); + Assert.False(headers.TryGetInt(Of("name1"), out int _)); + Assert.Equal(1, headers.GetInt(Of("name1"), 1)); + + Assert.False(headers.TryGetBoolean(Of(""), out bool _)); + Assert.False(headers.GetBoolean(Of("name1"), false)); + + Assert.False(headers.TryGetByte(Of("name1"), out byte _)); + Assert.Equal(1, headers.GetByte(Of("name1"), 1)); + + Assert.False(headers.TryGetChar(Of("name"), out char _)); + Assert.Equal('n', headers.GetChar(Of("name1"), 'n')); + + Assert.False(headers.TryGetDouble(Of("name"), out double _)); + Assert.Equal(1, headers.GetDouble(Of("name1"), 1), 0); + + Assert.False(headers.TryGetFloat(Of("name"), out float _)); + Assert.Equal(float.MaxValue, headers.GetFloat(Of("name1"), float.MaxValue), 0); + + Assert.False(headers.TryGetLong(Of("name"), out long _)); + Assert.Equal(long.MaxValue, headers.GetLong(Of("name1"), long.MaxValue)); + + Assert.False(headers.TryGetShort(Of("name"), out short _)); + Assert.Equal(short.MaxValue, headers.GetShort(Of("name1"), short.MaxValue)); + + Assert.False(headers.TryGetTimeMillis(Of("name"), out long _)); + Assert.Equal(long.MaxValue, headers.GetTimeMillis(Of("name1"), long.MaxValue)); + } + + [Fact] + public void GetBooleanInvalidValue() + { + TestDefaultHeaders headers = NewInstance(); + headers.Set(Of("name1"), new StringCharSequence("invalid")); + headers.Set(Of("name2"), new AsciiString("invalid")); + headers.Set(Of("name3"), new StringBuilderCharSequence(new StringBuilder("invalid"))); + + Assert.False(headers.GetBoolean(Of("name1"), false)); + Assert.False(headers.GetBoolean(Of("name2"), false)); + Assert.False(headers.GetBoolean(Of("name3"), false)); + } + + [Fact] + public void GetBooleanFalseValue() + { + TestDefaultHeaders headers = NewInstance(); + headers.Set(Of("name1"), new StringCharSequence("false")); + headers.Set(Of("name2"), new AsciiString("false")); + headers.Set(Of("name3"), new StringBuilderCharSequence(new StringBuilder("false"))); + + Assert.False(headers.GetBoolean(Of("name1"), true)); + Assert.False(headers.GetBoolean(Of("name2"), true)); + Assert.False(headers.GetBoolean(Of("name3"), true)); + } + + [Fact] + public void GetBooleanTrueValue() + { + TestDefaultHeaders headers = NewInstance(); + headers.Set(Of("name1"), new StringCharSequence("true")); + headers.Set(Of("name2"), new AsciiString("true")); + headers.Set(Of("name3"), new StringBuilderCharSequence(new StringBuilder("true"))); + + Assert.True(headers.GetBoolean(Of("name1"), false)); + Assert.True(headers.GetBoolean(Of("name2"), false)); + Assert.True(headers.GetBoolean(Of("name3"), false)); + } + } +} diff --git a/test/DotNetty.Common.Tests/Utilities/AsciiStringCharacterTest.cs b/test/DotNetty.Common.Tests/Utilities/AsciiStringCharacterTest.cs new file mode 100644 index 0000000..344fe5c --- /dev/null +++ b/test/DotNetty.Common.Tests/Utilities/AsciiStringCharacterTest.cs @@ -0,0 +1,313 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Common.Tests.Utilities +{ + using System; + using System.Linq; + using System.Text; + using DotNetty.Common.Utilities; + using Xunit; + + public sealed class AsciiStringCharacterTest + { + static readonly Random Rand = new Random(); + + [Fact] + public void GetBytes() + { + var b = new StringBuilder(); + for (int i = 0; i < 1 << 16; ++i) + { + b.Append("eéaà"); + } + string bString = b.ToString(); + var encodingList = new[] + { + Encoding.ASCII, + Encoding.BigEndianUnicode, + Encoding.UTF32, + Encoding.UTF7, + Encoding.UTF8, + Encoding.Unicode + }; + + foreach (Encoding encoding in encodingList) + { + byte[] expected = encoding.GetBytes(bString); + var value = new AsciiString(bString, encoding); + byte[] actual = value.ToByteArray(); + Assert.True(expected.SequenceEqual(actual)); + } + } + + [Fact] + public void ComparisonWithString() + { + const string Value = "shouldn't fail"; + var ascii = new AsciiString(Value.ToCharArray()); + Assert.Equal(Value, ascii.ToString()); + } + + [Fact] + public void SubSequence() + { + char[] initChars = { 't', 'h', 'i', 's', ' ', 'i', 's', ' ', 'a', ' ', 't', 'e', 's', 't' }; + byte[] init = initChars.Select(c => (byte)c).ToArray(); + var ascii = new AsciiString(init); + const int Start = 2; + int end = init.Length; + AsciiString sub1 = ascii.SubSequence(Start, end, false); + AsciiString sub2 = ascii.SubSequence(Start, end, true); + Assert.Equal(sub1.GetHashCode(), sub2.GetHashCode()); + Assert.Equal(sub1, sub2); + for (int i = Start; i < end; ++i) + { + Assert.Equal(init[i], sub1.ByteAt(i - Start)); + } + } + + [Fact] + public void Contains() + { + string[] falseLhs = { "a", "aa", "aaa" }; + string[] falseRhs = { "b", "ba", "baa" }; + foreach (string lhs in falseLhs) + { + foreach (string rhs in falseRhs) + { + AssertContains(lhs, rhs, false, false); + } + } + + AssertContains("", "", true, true); + AssertContains("AsfdsF", "", true, true); + AssertContains("", "b", false, false); + AssertContains("a", "a", true, true); + AssertContains("a", "b", false, false); + AssertContains("a", "A", false, true); + string b = "xyz"; + string a = b; + AssertContains(a, b, true, true); + + a = "a" + b; + AssertContains(a, b, true, true); + + a = b + "a"; + AssertContains(a, b, true, true); + + a = "a" + b + "a"; + AssertContains(a, b, true, true); + + b = "xYz"; + a = "xyz"; + AssertContains(a, b, false, true); + + b = "xYz"; + a = "xyzxxxXyZ" + b + "aaa"; + AssertContains(a, b, true, true); + + b = "foOo"; + a = "fooofoO"; + AssertContains(a, b, false, true); + + b = "Content-Equals: 10000"; + a = "content-equals: 1000"; + AssertContains(a, b, false, false); + a += "0"; + AssertContains(a, b, false, true); + } + + static void AssertContains(string a, string b, bool caseSensitiveEquals, bool caseInsenstaiveEquals) + { + var asciiA = new AsciiString(a); + var asciiB = new AsciiString(b); + Assert.Equal(caseSensitiveEquals, AsciiString.Contains(asciiA, asciiB)); + Assert.Equal(caseInsenstaiveEquals, AsciiString.ContainsIgnoreCase(asciiA, asciiB)); + } + + [Fact] + public void CaseSensitivity() + { + int i = 0; + for (; i < 32; i++) + { + DoCaseSensitivity(i); + } + int min = i; + const int Max = 4000; + int len = Rand.Next((Max - min) + 1) + min; + DoCaseSensitivity(len); + } + + static void DoCaseSensitivity(int len) + { + // Build an upper case and lower case string + const int UpperA = 'A'; + const int UpperZ = 'Z'; + const int UpperToLower = (int)'a' - UpperA; + + var lowerCaseBytes = new byte[len]; + var upperCaseBuilder = new StringBuilderCharSequence(len); + for (int i = 0; i < len; ++i) + { + char upper = (char)(Rand.Next((UpperZ - UpperA) + 1) + UpperA); + upperCaseBuilder.Append(upper); + lowerCaseBytes[i] = (byte)(upper + UpperToLower); + } + var upperCaseString = (StringCharSequence)upperCaseBuilder.ToString(); + var lowerCaseString = (StringCharSequence)new string(lowerCaseBytes.Select(x => (char)x).ToArray()); + var lowerCaseAscii = new AsciiString(lowerCaseBytes, false); + var upperCaseAscii = new AsciiString(upperCaseString); + + // Test upper case hash codes are equal + int upperCaseExpected = upperCaseAscii.GetHashCode(); + Assert.Equal(upperCaseExpected, AsciiString.GetHashCode(upperCaseBuilder)); + Assert.Equal(upperCaseExpected, AsciiString.GetHashCode(upperCaseString)); + Assert.Equal(upperCaseExpected, upperCaseAscii.GetHashCode()); + + // Test lower case hash codes are equal + int lowerCaseExpected = lowerCaseAscii.GetHashCode(); + Assert.Equal(lowerCaseExpected, AsciiString.GetHashCode(lowerCaseAscii)); + Assert.Equal(lowerCaseExpected, AsciiString.GetHashCode(lowerCaseString)); + Assert.Equal(lowerCaseExpected, lowerCaseAscii.GetHashCode()); + + // Test case insensitive hash codes are equal + int expectedCaseInsensitive = lowerCaseAscii.GetHashCode(); + Assert.Equal(expectedCaseInsensitive, AsciiString.GetHashCode(upperCaseBuilder)); + Assert.Equal(expectedCaseInsensitive, AsciiString.GetHashCode(upperCaseString)); + Assert.Equal(expectedCaseInsensitive, AsciiString.GetHashCode(lowerCaseString)); + Assert.Equal(expectedCaseInsensitive, AsciiString.GetHashCode(lowerCaseAscii)); + Assert.Equal(expectedCaseInsensitive, AsciiString.GetHashCode(upperCaseAscii)); + Assert.Equal(expectedCaseInsensitive, lowerCaseAscii.GetHashCode()); + Assert.Equal(expectedCaseInsensitive, upperCaseAscii.GetHashCode()); + + // Test that opposite cases are equal + Assert.Equal(lowerCaseAscii.GetHashCode(), AsciiString.GetHashCode(upperCaseString)); + Assert.Equal(upperCaseAscii.GetHashCode(), AsciiString.GetHashCode(lowerCaseString)); + } + + [Fact] + public void CaseInsensitiveHasherCharBuffer() + { + const string S1 = "TRANSFER-ENCODING"; + var array = new char[128]; + const int Offset = 100; + for (int i = 0; i < S1.Length; ++i) + { + array[Offset + i] = S1[i]; + } + + var s = new AsciiString(S1); + var b = new AsciiString(array, Offset, S1.Length); + Assert.Equal(AsciiString.GetHashCode(s), AsciiString.GetHashCode(b)); + } + + [Fact] + public void BooleanUtilityMethods() + { + Assert.True(new AsciiString(new byte[] { 1 }).ParseBoolean()); + Assert.False(AsciiString.Empty.ParseBoolean()); + Assert.False(new AsciiString(new byte[] { 0 }).ParseBoolean()); + Assert.True(new AsciiString(new byte[] { 5 }).ParseBoolean()); + Assert.True(new AsciiString(new byte[] { 2, 0 }).ParseBoolean()); + } + + [Fact] + public void EqualsIgnoreCase() + { + Assert.True(AsciiString.ContentEqualsIgnoreCase(null, null)); + Assert.False(AsciiString.ContentEqualsIgnoreCase(null, (StringCharSequence)"foo")); + Assert.False(AsciiString.ContentEqualsIgnoreCase((StringCharSequence)"bar", null)); + Assert.True(AsciiString.ContentEqualsIgnoreCase((StringCharSequence)"FoO", (StringCharSequence)"fOo")); + + // Test variations (Ascii + String, Ascii + Ascii, String + Ascii) + Assert.True(AsciiString.ContentEqualsIgnoreCase((AsciiString)"FoO", (StringCharSequence)"fOo")); + Assert.True(AsciiString.ContentEqualsIgnoreCase((AsciiString)"FoO", (AsciiString)"fOo")); + Assert.True(AsciiString.ContentEqualsIgnoreCase((StringCharSequence)"FoO", (AsciiString)"fOo")); + + // Test variations (Ascii + String, Ascii + Ascii, String + Ascii) + Assert.False(AsciiString.ContentEqualsIgnoreCase((AsciiString)"FoO", (StringCharSequence)"bAr")); + Assert.False(AsciiString.ContentEqualsIgnoreCase((AsciiString)"FoO", (AsciiString)"bAr")); + Assert.False(AsciiString.ContentEqualsIgnoreCase((StringCharSequence)"FoO", (AsciiString)"bAr")); + } + + [Fact] + public void IndexOfIgnoreCase() + { + Assert.Equal(-1, AsciiString.IndexOfIgnoreCase(null, (AsciiString)"abc", 1)); + Assert.Equal(-1, AsciiString.IndexOfIgnoreCase((AsciiString)"abc", null, 1)); + Assert.Equal(0, AsciiString.IndexOfIgnoreCase((AsciiString)"", (StringCharSequence)"", 0)); + Assert.Equal(0, AsciiString.IndexOfIgnoreCase((StringCharSequence)"aabaabaa", (AsciiString)"A", 0)); + Assert.Equal(2, AsciiString.IndexOfIgnoreCase((AsciiString)"aabaabaa", (StringCharSequence)"B", 0)); + Assert.Equal(1, AsciiString.IndexOfIgnoreCase((StringCharSequence)"aabaabaa", (AsciiString)"AB", 0)); + Assert.Equal(5, AsciiString.IndexOfIgnoreCase((AsciiString)"aabaabaa", (StringCharSequence)"B", 3)); + Assert.Equal(-1, AsciiString.IndexOfIgnoreCase((StringCharSequence)"aabaabaa", (AsciiString)"B", 9)); + Assert.Equal(2, AsciiString.IndexOfIgnoreCase((AsciiString)"aabaabaa", (StringCharSequence)"B", -1)); + Assert.Equal(2, AsciiString.IndexOfIgnoreCase((StringCharSequence)"aabaabaa", (StringCharSequence)"", 2)); + Assert.Equal(-1, AsciiString.IndexOfIgnoreCase((AsciiString)"abc", (AsciiString)"", 9)); + Assert.Equal(0, AsciiString.IndexOfIgnoreCase((StringCharSequence)"ãabaabaa", (AsciiString)"Ã", 0)); + } + + [Fact] + public void IndexOfIgnoreCaseAscii() + { + Assert.Equal(-1, AsciiString.IndexOfIgnoreCaseAscii(null, (AsciiString)"abc", 1)); + Assert.Equal(-1, AsciiString.IndexOfIgnoreCaseAscii((AsciiString)"abc", null, 1)); + Assert.Equal(0, AsciiString.IndexOfIgnoreCaseAscii((AsciiString)"", (StringCharSequence)"", 0)); + Assert.Equal(0, AsciiString.IndexOfIgnoreCaseAscii((StringCharSequence)"aabaabaa", (AsciiString)"A", 0)); + Assert.Equal(2, AsciiString.IndexOfIgnoreCaseAscii((AsciiString)"aabaabaa", (StringCharSequence)"B", 0)); + Assert.Equal(1, AsciiString.IndexOfIgnoreCaseAscii((StringCharSequence)"aabaabaa", (StringCharSequence)"AB", 0)); + Assert.Equal(5, AsciiString.IndexOfIgnoreCaseAscii((StringCharSequence)"aabaabaa", (AsciiString)"B", 3)); + Assert.Equal(-1, AsciiString.IndexOfIgnoreCaseAscii((AsciiString)"aabaabaa", (StringCharSequence)"B", 9)); + Assert.Equal(2, AsciiString.IndexOfIgnoreCaseAscii((AsciiString)"aabaabaa", new AsciiString("B"), -1)); + Assert.Equal(2, AsciiString.IndexOfIgnoreCaseAscii((StringCharSequence)"aabaabaa", (AsciiString)"", 2)); + Assert.Equal(-1, AsciiString.IndexOfIgnoreCaseAscii((AsciiString)"abc", (StringCharSequence)"", 9)); + } + + [Fact] + public void Trim() + { + Assert.Equal("", AsciiString.Empty.Trim().ToString()); + Assert.Equal("abc", new AsciiString(" abc").Trim().ToString()); + Assert.Equal("abc", new AsciiString("abc ").Trim().ToString()); + Assert.Equal("abc", new AsciiString(" abc ").Trim().ToString()); + } + + [Fact] + public void IndexOfChar() + { + Assert.Equal(-1, CharUtil.IndexOf(null, 'a', 0)); + Assert.Equal(-1, ((AsciiString)"").IndexOf('a', 0)); + Assert.Equal(-1, ((AsciiString)"abc").IndexOf('d', 0)); + Assert.Equal(-1, ((AsciiString)"aabaabaa").IndexOf('A', 0)); + Assert.Equal(0, ((AsciiString)"aabaabaa").IndexOf('a', 0)); + Assert.Equal(1, ((AsciiString)"aabaabaa").IndexOf('a', 1)); + Assert.Equal(3, ((AsciiString)"aabaabaa").IndexOf('a', 2)); + Assert.Equal(3, ((AsciiString)"aabdabaa").IndexOf('d', 1)); + } + + [Fact] + public void StaticIndexOfChar() + { + Assert.Equal(-1, CharUtil.IndexOf(null, 'a', 0)); + Assert.Equal(-1, CharUtil.IndexOf((AsciiString)"", 'a', 0)); + Assert.Equal(-1, CharUtil.IndexOf((AsciiString)"abc", 'd', 0)); + Assert.Equal(-1, CharUtil.IndexOf((AsciiString)"aabaabaa", 'A', 0)); + Assert.Equal(0, CharUtil.IndexOf((AsciiString)"aabaabaa", 'a', 0)); + Assert.Equal(1, CharUtil.IndexOf((AsciiString)"aabaabaa", 'a', 1)); + Assert.Equal(3, CharUtil.IndexOf((AsciiString)"aabaabaa", 'a', 2)); + Assert.Equal(3, CharUtil.IndexOf((AsciiString)"aabdabaa", 'd', 1)); + } + + [Fact] + public void SubStringHashCode() + { + var value1 = new AsciiString("123"); + var value2 = new AsciiString("a123".Substring(1)); + + //two "123"s + Assert.Equal(AsciiString.GetHashCode(value1), AsciiString.GetHashCode(value2)); + } + } +} diff --git a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs index e998c82..b138b76 100644 --- a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs +++ b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs @@ -298,14 +298,14 @@ namespace DotNetty.Handlers.Tests public void NoAutoReadHandshakeProgresses(bool dropChannelActive) { var readHandler = new ReadRegisterHandler(); - EmbeddedChannel ch = new EmbeddedChannel(EmbeddedChannelId.Instance, false, false, + var ch = new EmbeddedChannel(EmbeddedChannelId.Instance, false, false, readHandler, TlsHandler.Client("dotnetty.com"), new ActivatingHandler(dropChannelActive) ); ch.Configuration.AutoRead = false; - ch.Start(); + ch.Register(); Assert.False(ch.Configuration.AutoRead); Assert.True(ch.WriteOutbound(Unpooled.Empty)); Assert.True(readHandler.ReadIssued); diff --git a/test/DotNetty.Microbench/Allocators/PooledHeapByteBufferAllocatorBenchmark.cs b/test/DotNetty.Microbench/Allocators/PooledHeapByteBufferAllocatorBenchmark.cs new file mode 100644 index 0000000..843fdf2 --- /dev/null +++ b/test/DotNetty.Microbench/Allocators/PooledHeapByteBufferAllocatorBenchmark.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Microbench.Allocators +{ + using DotNetty.Buffers; + + public class PooledHeapByteBufferAllocatorBenchmark : AbstractByteBufferAllocatorBenchmark + { + public PooledHeapByteBufferAllocatorBenchmark() + : base( new PooledByteBufferAllocator(true, 4, 4, 8192, 11, 0, 0, 0)) // Disable thread-local cache + { + } + } +} diff --git a/test/DotNetty.Microbench/Allocators/UnpooledByteBufferAllocatorBenchmark.cs b/test/DotNetty.Microbench/Allocators/UnpooledByteBufferAllocatorBenchmark.cs index f0a8ce2..b84b53e 100644 --- a/test/DotNetty.Microbench/Allocators/UnpooledByteBufferAllocatorBenchmark.cs +++ b/test/DotNetty.Microbench/Allocators/UnpooledByteBufferAllocatorBenchmark.cs @@ -5,9 +5,9 @@ namespace DotNetty.Microbench.Allocators { using DotNetty.Buffers; - public class UnpooledByteBufferAllocatorBenchmark : AbstractByteBufferAllocatorBenchmark + public class UnpooledHeapByteBufferAllocatorBenchmark : AbstractByteBufferAllocatorBenchmark { - public UnpooledByteBufferAllocatorBenchmark() : base(new UnpooledByteBufferAllocator(true)) + public UnpooledHeapByteBufferAllocatorBenchmark() : base(new UnpooledByteBufferAllocator(true)) { } } diff --git a/test/DotNetty.Microbench/Buffers/PooledByteBufferBenchmark.cs b/test/DotNetty.Microbench/Buffers/PooledByteBufferBenchmark.cs index 64cc9c5..5edecf7 100644 --- a/test/DotNetty.Microbench/Buffers/PooledByteBufferBenchmark.cs +++ b/test/DotNetty.Microbench/Buffers/PooledByteBufferBenchmark.cs @@ -79,4 +79,4 @@ namespace DotNetty.Microbench.Buffers [Benchmark] public long GetLong() => this.buffer.GetLong(0); } -} +} \ No newline at end of file diff --git a/test/DotNetty.Microbench/Buffers/UnpooledByteBufferBenchmark.cs b/test/DotNetty.Microbench/Buffers/UnpooledByteBufferBenchmark.cs index eba878e..68d0f52 100644 --- a/test/DotNetty.Microbench/Buffers/UnpooledByteBufferBenchmark.cs +++ b/test/DotNetty.Microbench/Buffers/UnpooledByteBufferBenchmark.cs @@ -73,4 +73,4 @@ namespace DotNetty.Microbench.Buffers [Benchmark] public long GetLong() => this.buffer.GetLong(0); } -} +} \ No newline at end of file diff --git a/test/DotNetty.Microbench/Codecs/DateFormatterBenchmark.cs b/test/DotNetty.Microbench/Codecs/DateFormatterBenchmark.cs new file mode 100644 index 0000000..615e73c --- /dev/null +++ b/test/DotNetty.Microbench/Codecs/DateFormatterBenchmark.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Microbench.Codecs +{ + using System; + using BenchmarkDotNet.Attributes; + using BenchmarkDotNet.Attributes.Jobs; + using DotNetty.Codecs; + + [CoreJob] + [BenchmarkCategory("Codecs")] + public class DateFormatterBenchmark + { + const string DateString = "Sun, 27 Nov 2016 19:18:46 GMT"; + readonly DateTime date = new DateTime(784111777000L); + + [Benchmark] + public DateTime? ParseHttpHeaderDateFormatter() => DateFormatter.ParseHttpDate(DateString); + + [Benchmark] + public string FormatHttpHeaderDateFormatter() => DateFormatter.Format(this.date); + } +} diff --git a/test/DotNetty.Microbench/Common/AsciiStringBenchmark.cs b/test/DotNetty.Microbench/Common/AsciiStringBenchmark.cs new file mode 100644 index 0000000..bad2dcf --- /dev/null +++ b/test/DotNetty.Microbench/Common/AsciiStringBenchmark.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Microbench.Common +{ + using System; + using System.Text; + using BenchmarkDotNet.Attributes; + using BenchmarkDotNet.Attributes.Jobs; + using DotNetty.Common.Internal; + using DotNetty.Common.Utilities; +#if NET46 + using BenchmarkDotNet.Diagnostics.Windows.Configs; +#endif + +#if !NET46 + [CoreJob] +#else + [ClrJob] + [InliningDiagnoser] +#endif + [BenchmarkCategory("Common")] + public class AsciiStringBenchmark + { + [Params(3, 5, 7, 8, 10, 20, 50, 100, 1000)] + public int Size { get; set; } + + AsciiString asciiString; + StringCharSequence stringValue; + static readonly Random RandomGenerator = new Random(); + + [GlobalSetup] + public void GlobalSetup() + { + var bytes = new byte[this.Size]; + RandomGenerator.NextBytes(bytes); + + this.asciiString = new AsciiString(bytes, false); + string value = Encoding.ASCII.GetString(bytes); + this.stringValue = new StringCharSequence(value); + } + + [Benchmark] + public int CharSequenceHashCode() => PlatformDependent.HashCodeAscii(this.stringValue); + + [Benchmark] + public int AsciiStringHashCode() => PlatformDependent.HashCodeAscii( + this.asciiString.Array, this.asciiString.Offset, this.asciiString.Count); + } +} diff --git a/test/DotNetty.Microbench/DotNetty.Microbench.csproj b/test/DotNetty.Microbench/DotNetty.Microbench.csproj index 0234caa..bf7c66e 100644 --- a/test/DotNetty.Microbench/DotNetty.Microbench.csproj +++ b/test/DotNetty.Microbench/DotNetty.Microbench.csproj @@ -18,6 +18,7 @@ + diff --git a/test/DotNetty.Microbench/Headers/ExampleHeaders.cs b/test/DotNetty.Microbench/Headers/ExampleHeaders.cs new file mode 100644 index 0000000..39118fd --- /dev/null +++ b/test/DotNetty.Microbench/Headers/ExampleHeaders.cs @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Microbench.Headers +{ + using System.Collections.Generic; + + public enum HeaderExample + { + Three = 3, + Five = 5, + Six = 6, + Eight = 8, + Eleven = 11, + Twentytwo = 22, + Thirty = 30 + } + + static class ExampleHeaders + { + public static Dictionary> GetExamples() + { + var examples = new Dictionary>(); + + var header = new Dictionary + { + { ":method", "GET" }, + { ":scheme", "https" }, + { ":path", "/index.html" } + }; + examples.Add(HeaderExample.Three, header); + + // Headers used by Norman's HTTP benchmarks with wrk + header = new Dictionary + { + { "Method", "GET" }, + { "Path", "/plaintext" }, + { "Host", "localhost" }, + { "Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8" }, + { "Connection", "keep-alive" } + }; + examples.Add(HeaderExample.Five, header); + + header = new Dictionary + { + { ":authority", "127.0.0.1:33333" }, + { ":method", "POST" }, + { ":path", "/grpc.testing.TestService/UnaryCall" }, + { ":scheme", "http" }, + { "content-type", "application/grpc" }, + { "te", "trailers" } + }; + examples.Add(HeaderExample.Six, header); + + header = new Dictionary + { + { ":method", "POST" }, + { ":scheme", "http" }, + { ":path", "/google.pubsub.v2.PublisherService/CreateTopic" }, + { ":authority", "pubsub.googleapis.com" }, + { "grpc-timeout", "1S" }, + { "content-type", "application/grpc+proto" }, + { "grpc-encoding", "gzip" }, + { "authorization", "Bearer y235.wef315yfh138vh31hv93hv8h3v" } + }; + examples.Add(HeaderExample.Eight, header); + + header = new Dictionary + { + { ":host", "twitter.com" }, + { ":method", "GET" }, + { ":path", "/" }, + { ":scheme", "https" }, + { ":version", "HTTP/1.1" }, + { "accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8" }, + { "accept-encoding", "gzip, deflate, sdch" }, + { "accept-language", "en-US,en;q=0.8" }, + { "cache-control", "max-age=0" }, + { "cookie", "noneofyourbusiness" }, + { "user-agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko)" } + }; + examples.Add(HeaderExample.Eleven, header); + + header = new Dictionary + { + { "cache-control", "no-cache, no-store, must-revalidate, pre-check=0, post-check=0" }, + { "content-encoding", "gzip" }, + { "content-security-policy", "default-src https:; connect-src https:;" }, + { "content-type", "text/html;charset=utf-8" }, + { "date", "Wed, 22 Apr 2015 00:40:28 GMT" }, + { "expires", "Tue, 31 Mar 1981 05:00:00 GMT" }, + { "last-modified", "Wed, 22 Apr 2015 00:40:28 GMT" }, + { "ms", "ms" }, + { "pragma", "no-cache" }, + { "server", "tsa_b" }, + { "set-cookie", "noneofyourbusiness" }, + { "status", "200 OK" }, + { "strict-transport-security", "max-age=631138519" }, + { "version", "HTTP/1.1" }, + { "x-connection-hash", "e176fe40accc1e2c613a34bc1941aa98" }, + { "x-content-type-options", "nosniff" }, + { "x-frame-options", "SAMEORIGIN" }, + { "x-response-time", "22" }, + { "x-transaction", "a54142ede693444d9" }, + { "x-twitter-response-tags", "BouncerCompliant" }, + { "x-ua-compatible", "IE=edge,chrome=1" }, + { "x-xss-protection", "1; mode=block" } + }; + examples.Add(HeaderExample.Twentytwo, header); + + header = new Dictionary + { + { "Cache-Control", "no-cache" }, + { "Content-Encoding", "gzip" }, + { "Content-Security-Policy", "default-src *; script-src assets-cdn.github.com ..." }, + { "Content-Type", "text/html; charset=utf-8" }, + { "Date", "Fri, 10 Apr 2015 02:15:38 GMT" }, + { "Server", "GitHub.com" }, + { "Set-Cookie", "_gh_sess=eyJzZXNza...; path=/; secure; HttpOnly" }, + { "Status", "200 OK" }, + { "Strict-Transport-Security", "max-age=31536000; includeSubdomains; preload" }, + { "Transfer-Encoding", "chunked" }, + { "Vary", "X-PJAX" }, + { "X-Content-Type-Options", "nosniff" }, + { "X-Frame-Options", "deny" }, + { "X-GitHub-Request-Id", "1" }, + { "X-GitHub-Session-Id", "1" }, + { "X-GitHub-User", "buchgr" }, + { "X-Request-Id", "28f245e02fc872dcf7f31149e52931dd" }, + { "X-Runtime", "0.082529" }, + { "X-Served-By", "b9c2a233f7f3119b174dbd8be2" }, + { "X-UA-Compatible", "IE=Edge,chrome=1" }, + { "X-XSS-Protection", "1; mode=block" }, + { "Via", "http/1.1 ir50.fp.bf1.yahoo.com (ApacheTrafficServer)" }, + { "Content-Language", "en" }, + { "Connection", "keep-alive" }, + { "Pragma", "no-cache" }, + { "Expires", "Sat, 01 Jan 2000 00:00:00 GMT" }, + { "X-Moose", "majestic" }, + { "x-ua-compatible", "IE=edge" }, + { "CF-Cache-Status", "HIT" }, + { "CF-RAY", "6a47f4f911e3-" } + }; + examples.Add(HeaderExample.Thirty, header); + + return examples; + } + } +} diff --git a/test/DotNetty.Microbench/Headers/HeadersBenchmark.cs b/test/DotNetty.Microbench/Headers/HeadersBenchmark.cs new file mode 100644 index 0000000..305f96b --- /dev/null +++ b/test/DotNetty.Microbench/Headers/HeadersBenchmark.cs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Microbench.Headers +{ + using System.Collections.Generic; + using BenchmarkDotNet.Attributes; + using BenchmarkDotNet.Attributes.Jobs; + using BenchmarkDotNet.Engines; + using DotNetty.Codecs; + using DotNetty.Codecs.Http; + using DotNetty.Common; + using DotNetty.Common.Utilities; + + [SimpleJob(RunStrategy.Monitoring, 10, 5, 10)] + [BenchmarkCategory("Headers")] + public class HeadersBenchmark + { + [Params(3, 5, 6, 8, 11, 22, 30)] + public int HeaderSize { get; set; } + + AsciiString[] httpNames; + AsciiString[] httpValues; + + DefaultHttpHeaders httpHeaders; + DefaultHttpHeaders emptyHttpHeaders; + DefaultHttpHeaders emptyHttpHeadersNoValidate; + + static string ToHttpName(string name) => name.StartsWith(":") ? name.Substring(1) : name; + + [GlobalSetup] + public void GlobalSetup() + { + ResourceLeakDetector.Level = ResourceLeakDetector.DetectionLevel.Disabled; + Dictionary> headersSet = ExampleHeaders.GetExamples(); + Dictionary headers = headersSet[(HeaderExample)this.HeaderSize]; + this.httpNames = new AsciiString[headers.Count]; + this.httpValues = new AsciiString[headers.Count]; + this.httpHeaders = new DefaultHttpHeaders(false); + int idx = 0; + foreach (KeyValuePair header in headers) + { + string httpName = ToHttpName(header.Key); + string value = header.Value; + this.httpNames[idx] = new AsciiString(httpName); + this.httpValues[idx] = new AsciiString(value); + this.httpHeaders.Add(this.httpNames[idx], this.httpValues[idx]); + idx++; + } + this.emptyHttpHeaders = new DefaultHttpHeaders(); + this.emptyHttpHeadersNoValidate = new DefaultHttpHeaders(false); + } + + [Benchmark] + public DefaultHttpHeaders HttpRemove() + { + foreach(AsciiString name in this.httpNames) + { + this.httpHeaders.Remove(name); + } + + return this.httpHeaders; + } + + [Benchmark] + public DefaultHttpHeaders HttpGet() + { + foreach (AsciiString name in this.httpNames) + { + this.httpHeaders.TryGet(name, out _); + } + return this.httpHeaders; + } + + [Benchmark] + public DefaultHttpHeaders HttpPut() + { + var headers = new DefaultHttpHeaders(false); + for (int i = 0; i < this.httpNames.Length; i++) + { + headers.Add(this.httpNames[i], this.httpValues[i]); + } + return headers; + } + + [Benchmark] + public List> HttpIterate() + { + var list = new List>(); + foreach (HeaderEntry header in this.httpHeaders) + { + list.Add(header); + } + return list; + } + + [Benchmark] + public DefaultHttpHeaders HttpAddAllFastest() + { + this.emptyHttpHeadersNoValidate.Add(this.httpHeaders); + return this.emptyHttpHeadersNoValidate; + } + + [Benchmark] + public DefaultHttpHeaders HttpAddAllFast() + { + this.emptyHttpHeaders.Add(this.httpHeaders); + return this.emptyHttpHeaders; + } + } +} diff --git a/test/DotNetty.Microbench/Http/ClientCookieDecoderBenchmark.cs b/test/DotNetty.Microbench/Http/ClientCookieDecoderBenchmark.cs new file mode 100644 index 0000000..b7f1a9b --- /dev/null +++ b/test/DotNetty.Microbench/Http/ClientCookieDecoderBenchmark.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Microbench.Http +{ + using BenchmarkDotNet.Attributes; + using BenchmarkDotNet.Attributes.Jobs; + using DotNetty.Codecs.Http.Cookies; + using DotNetty.Common; + + [CoreJob] + [BenchmarkCategory("Http")] + public class ClientCookieDecoderBenchmark + { + const string CookieString = + "__Host-user_session_same_site=fgfMsM59vJTpZg88nxqKkIhgOt0ADF8LX8wjMMbtcb4IJMufWCnCcXORhbo9QMuyiybdtx; " + + "path=/; expires=Mon, 28 Nov 2016 13:56:01 GMT; secure; HttpOnly"; + + [GlobalSetup] + public void GlobalSetup() + { + ResourceLeakDetector.Level = ResourceLeakDetector.DetectionLevel.Disabled; + } + + [Benchmark] + public ICookie DecodeCookieWithRfc1123ExpiresField() => ClientCookieDecoder.StrictDecoder.Decode(CookieString); + } +} diff --git a/test/DotNetty.Microbench/Http/HttpRequestDecoderBenchmark.cs b/test/DotNetty.Microbench/Http/HttpRequestDecoderBenchmark.cs new file mode 100644 index 0000000..bc70966 --- /dev/null +++ b/test/DotNetty.Microbench/Http/HttpRequestDecoderBenchmark.cs @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Microbench.Http +{ + using System.Text; + using BenchmarkDotNet.Attributes; + using BenchmarkDotNet.Attributes.Jobs; + using BenchmarkDotNet.Engines; + using DotNetty.Buffers; + using DotNetty.Codecs.Http; + using DotNetty.Common; + using DotNetty.Transport.Channels.Embedded; + + [SimpleJob(RunStrategy.Monitoring, 10, 5, 20)] + [BenchmarkCategory("Http")] + public class HttpRequestDecoderBenchmark + { + const int ContentLength = 120; + + [Params(2, 4, 8, 16, 32)] + public int Step { get; set; } + + byte[] contentMixedDelimiters; + + [GlobalSetup] + public void GlobalSetup() + { + ResourceLeakDetector.Level = ResourceLeakDetector.DetectionLevel.Disabled; + this.contentMixedDelimiters = CreateContent("\r\n", "\n"); + } + + [Benchmark] + public void DecodeWholeRequestInMultipleStepsMixedDelimiters() => + DecodeWholeRequestInMultipleSteps(this.contentMixedDelimiters, this.Step); + + static void DecodeWholeRequestInMultipleSteps(byte[] content, int fragmentSize) + { + var channel = new EmbeddedChannel(new HttpRequestDecoder()); + + int headerLength = content.Length - ContentLength; + + // split up the header + for (int a = 0; a < headerLength;) + { + int amount = fragmentSize; + if (a + amount > headerLength) + { + amount = headerLength - a; + } + + // if header is done it should produce a HttpRequest + channel.WriteInbound(Unpooled.WrappedBuffer(content, a, amount)); + a += amount; + } + + for (int i = ContentLength; i > 0; i--) + { + // Should produce HttpContent + channel.WriteInbound(Unpooled.WrappedBuffer(content, content.Length - i, 1)); + } + } + + static byte[] CreateContent(params string[] lineDelimiters) + { + string lineDelimiter; + string lineDelimiter2; + if (lineDelimiters.Length == 2) + { + lineDelimiter = lineDelimiters[0]; + lineDelimiter2 = lineDelimiters[1]; + } + else + { + lineDelimiter = lineDelimiters[0]; + lineDelimiter2 = lineDelimiters[0]; + } + // This GET request is incorrect but it does not matter for HttpRequestDecoder. + // It used only to get a long request. + return Encoding.ASCII.GetBytes("GET /some/path?foo=bar&wibble=eek HTTP/1.1" + "\r\n" + + "Upgrade: WebSocket" + lineDelimiter2 + + "Connection: Upgrade" + lineDelimiter + + "Host: localhost" + lineDelimiter2 + + "Referer: http://www.site.ru/index.html" + lineDelimiter + + "User-Agent: Mozilla/5.0 (X11; U; Linux i686; ru; rv:1.9b5) Gecko/2008050509 Firefox/3.0b5" + + lineDelimiter2 + + "Accept: text/html" + lineDelimiter + + "Cookie: income=1" + lineDelimiter2 + + "Origin: http://localhost:8080" + lineDelimiter + + "Sec-WebSocket-Key1: 10 28 8V7 8 48 0" + lineDelimiter2 + + "Sec-WebSocket-Key2: 8 Xt754O3Q3QW 0 _60" + lineDelimiter + + "Content-Type: application/x-www-form-urlencoded" + lineDelimiter2 + + "Content-Length: " + ContentLength + lineDelimiter + + "\r\n" + + "1234567890\r\n" + + "1234567890\r\n" + + "1234567890\r\n" + + "1234567890\r\n" + + "1234567890\r\n" + + "1234567890\r\n" + + "1234567890\r\n" + + "1234567890\r\n" + + "1234567890\r\n" + + "1234567890\r\n" + ); + } + } +} diff --git a/test/DotNetty.Microbench/Http/HttpRequestEncoderInsertBenchmark.cs b/test/DotNetty.Microbench/Http/HttpRequestEncoderInsertBenchmark.cs new file mode 100644 index 0000000..f428c89 --- /dev/null +++ b/test/DotNetty.Microbench/Http/HttpRequestEncoderInsertBenchmark.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Microbench.Http +{ + using BenchmarkDotNet.Attributes; + using BenchmarkDotNet.Attributes.Jobs; + using DotNetty.Buffers; + using DotNetty.Codecs.Http; + using DotNetty.Common; + + [CoreJob] + [BenchmarkCategory("Http")] + public class HttpRequestEncoderInsertBenchmark + { + string uri; + HttpRequestEncoder encoder; + + [GlobalSetup] + public void GlobalSetup() + { + ResourceLeakDetector.Level = ResourceLeakDetector.DetectionLevel.Disabled; + this.uri = "http://localhost?eventType=CRITICAL&from=0&to=1497437160327&limit=10&offset=0"; + this. encoder = new HttpRequestEncoder(); + } + + [Benchmark] + public IByteBuffer EncodeInitialLine() + { + IByteBuffer buffer = Unpooled.Buffer(100); + try + { + this.encoder.EncodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.Http11, + HttpMethod.Get,this.uri)); + return buffer; + } + finally + { + buffer.Release(); + } + } + } +} diff --git a/test/DotNetty.Microbench/Http/WriteBytesVsShortOrMediumBenchmark.cs b/test/DotNetty.Microbench/Http/WriteBytesVsShortOrMediumBenchmark.cs new file mode 100644 index 0000000..ddb5a91 --- /dev/null +++ b/test/DotNetty.Microbench/Http/WriteBytesVsShortOrMediumBenchmark.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Microbench.Http +{ + using BenchmarkDotNet.Attributes; + using BenchmarkDotNet.Attributes.Jobs; + using DotNetty.Buffers; + using DotNetty.Codecs.Http; + using DotNetty.Common; + using DotNetty.Common.Utilities; + + [CoreJob] + [BenchmarkCategory("Http")] + public class WriteBytesVsShortOrMediumBenchmark + { + const int CrlfShort = (HttpConstants.CarriageReturn << 8) + HttpConstants.LineFeed; + const int ZeroCrlfMedium = ('0' << 16) + (HttpConstants.CarriageReturn << 8) + HttpConstants.LineFeed; + static readonly byte[] Crlf = { HttpConstants.CarriageReturn, HttpConstants.LineFeed }; + static readonly byte[] ZeroCrlf = { (byte)'0', HttpConstants.CarriageReturn, HttpConstants.LineFeed }; + + IByteBuffer buf; + + [GlobalSetup] + public void GlobalSetup() + { + ResourceLeakDetector.Level = ResourceLeakDetector.DetectionLevel.Disabled; + this.buf = Unpooled.Buffer(16); + } + + [Benchmark] + public IByteBuffer ShortInt() => this.buf.WriteShort(CrlfShort).ResetWriterIndex(); + + [Benchmark] + public IByteBuffer MediumInt() => this.buf.WriteMedium(ZeroCrlfMedium).ResetWriterIndex(); + + [Benchmark] + public IByteBuffer ByteArray2() => this.buf.WriteBytes(Crlf).ResetWriterIndex(); + + [Benchmark] + public IByteBuffer ByteArray3() => this.buf.WriteBytes(ZeroCrlf).ResetWriterIndex(); + + [Benchmark] + public IByteBuffer ChainedBytes2() => + this.buf.WriteByte(HttpConstants.CarriageReturn).WriteByte(HttpConstants.LineFeed).ResetWriterIndex(); + + [Benchmark] + public IByteBuffer ChainedBytes3() => + this.buf.WriteByte('0').WriteByte(HttpConstants.CarriageReturn).WriteByte(HttpConstants.LineFeed).ResetWriterIndex(); + + [GlobalCleanup] + public void GlobalCleanup() + { + this.buf?.SafeRelease(); + this.buf = null; + } + } +} diff --git a/test/DotNetty.Microbench/Internal/PlatformDependentBenchmark.cs b/test/DotNetty.Microbench/Internal/PlatformDependentBenchmark.cs new file mode 100644 index 0000000..51292af --- /dev/null +++ b/test/DotNetty.Microbench/Internal/PlatformDependentBenchmark.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Microbench.Internal +{ + using BenchmarkDotNet.Attributes; + using BenchmarkDotNet.Attributes.Jobs; + using DotNetty.Common.Internal; + + [CoreJob] + [BenchmarkCategory("Internal")] + public class PlatformDependentBenchmark + { + [Params(10, 50, 100, 1000, 10000, 100000)] + public int Size { get; set; } + + byte[] bytes1; + byte[] bytes2; + + [GlobalSetup] + public void GlobalSetup() + { + this.bytes1 = new byte[this.Size]; + this.bytes2 = new byte[this.Size]; + for (int i = 0; i < this.Size; i++) + { + this.bytes1[i] = this.bytes2[i] = (byte)i; + } + } + + [Benchmark] + public bool UnsafeBytesEqual() => + PlatformDependent.ByteArrayEquals(this.bytes1, 0, this.bytes2, 0, this.bytes1.Length); + } +} diff --git a/test/DotNetty.Microbench/Program.cs b/test/DotNetty.Microbench/Program.cs index 2c81355..935030e 100644 --- a/test/DotNetty.Microbench/Program.cs +++ b/test/DotNetty.Microbench/Program.cs @@ -7,20 +7,41 @@ namespace DotNetty.Microbench using BenchmarkDotNet.Running; using DotNetty.Microbench.Allocators; using DotNetty.Microbench.Buffers; + using DotNetty.Microbench.Codecs; + using DotNetty.Microbench.Common; using DotNetty.Microbench.Concurrency; + using DotNetty.Microbench.Headers; + using DotNetty.Microbench.Http; + using DotNetty.Microbench.Internal; class Program { static readonly Type[] BenchmarkTypes = { - typeof(PooledByteBufferAllocatorBenchmark), - typeof(UnpooledByteBufferAllocatorBenchmark), + typeof(PooledHeapByteBufferAllocatorBenchmark), + typeof(UnpooledHeapByteBufferAllocatorBenchmark), + typeof(ByteBufferBenchmark), + typeof(PooledByteBufferBenchmark), typeof(UnpooledByteBufferBenchmark), typeof(PooledByteBufferBenchmark), typeof(ByteBufUtilBenchmark), + + typeof(DateFormatterBenchmark), + + typeof(AsciiStringBenchmark), + typeof(FastThreadLocalBenchmark), - typeof(SingleThreadEventExecutorBenchmark) + typeof(SingleThreadEventExecutorBenchmark), + + typeof(HeadersBenchmark), + + typeof(ClientCookieDecoderBenchmark), + typeof(HttpRequestDecoderBenchmark), + typeof(HttpRequestEncoderInsertBenchmark), + typeof(WriteBytesVsShortOrMediumBenchmark), + + typeof(PlatformDependentBenchmark) }; static void Main(string[] args) diff --git a/test/DotNetty.Transport.Tests/Channel/Sockets/SocketDatagramChannelUnicastTest.cs b/test/DotNetty.Transport.Tests/Channel/Sockets/SocketDatagramChannelUnicastTest.cs index 719161d..908a03c 100644 --- a/test/DotNetty.Transport.Tests/Channel/Sockets/SocketDatagramChannelUnicastTest.cs +++ b/test/DotNetty.Transport.Tests/Channel/Sockets/SocketDatagramChannelUnicastTest.cs @@ -77,14 +77,14 @@ namespace DotNetty.Transport.Tests.Channel.Sockets result = this.sequenceEqual; } } - finally + finally { this.resetEvent.Reset(); this.sequenceEqual = false; } return result; - } + } } static readonly byte[] Data = { 0, 1, 2, 3 }; @@ -171,7 +171,7 @@ namespace DotNetty.Transport.Tests.Channel.Sockets this.Output.WriteLine($"Unicast server binding to:({addressFamily}){address}"); Task task = serverBootstrap.BindAsync(address, IPEndPoint.MinPort); - Assert.True(task.Wait(TimeSpan.FromMilliseconds(DefaultTimeOutInMilliseconds * 5)), + Assert.True(task.Wait(TimeSpan.FromMilliseconds(DefaultTimeOutInMilliseconds * 5)), $"Unicast server binding to:({addressFamily}){address} timed out!"); serverChannel = (SocketDatagramChannel)task.Result; @@ -190,7 +190,7 @@ namespace DotNetty.Transport.Tests.Channel.Sockets })); var clientEndPoint = new IPEndPoint( - addressFamily == AddressFamily.InterNetwork ? IPAddress.Any : IPAddress.IPv6Any, + addressFamily == AddressFamily.InterNetwork ? IPAddress.Any : IPAddress.IPv6Any, IPEndPoint.MinPort); clientBootstrap @@ -238,4 +238,4 @@ namespace DotNetty.Transport.Tests.Channel.Sockets } } } -} +} \ No newline at end of file