зеркало из https://github.com/Azure/DotNetty.git
SNI Support (ported netty SniHandler) (#219)
* First cut commit *unfinished * Replace SNI handler with TlsHandler with certificate selected based on host name found in clientHello * first cut tests for SniHandler * test update * Test update * Supress further read when handler is replaced * made the snitest more effective and IDN in hostname lower case as per netty impl * More asserts to check whether snihandler gets replaces with tlshaldler in the pipeline * assert server name is always found in clienthello as per the test setup * Provided option to select default host name in case of error or client hello does not contail SNI extension, otherwise handshake fails in those cases * More elaborate tests * verbosity in test * Fixed Read continues to get called after handler removed and removed the workaround in SniHandler * relaced goto statement with flag for breaking outer for loop from within switch * Update SniHandler.cs * trigger CI build * addressed review comments * Fixed task continuation option * addresses further review comments * triggere build again with some more assert in test #221 * suppress read logic is still needed due to async "void" * changing the map to (string -> Task<ServerTlsSettings) * one more constructor overload * extensive tls read/write test is not needed since that's already done in tlshandler test * more readable target host validation in test to force retrigger confusing CI build * retrigger * addressed review comment "this generates 30 random data frames. pls replace with new [] { 1 }"
This commit is contained in:
Родитель
8357d8ed40
Коммит
1d3eda9a74
Двоичный файл не отображается.
|
@ -151,6 +151,9 @@ namespace DotNetty.Codecs
|
||||||
public override void HandlerRemoved(IChannelHandlerContext context)
|
public override void HandlerRemoved(IChannelHandlerContext context)
|
||||||
{
|
{
|
||||||
IByteBuffer buf = this.InternalBuffer;
|
IByteBuffer buf = this.InternalBuffer;
|
||||||
|
|
||||||
|
// Directly set this to null so we are sure we not access it in any other method here anymore.
|
||||||
|
this.cumulation = null;
|
||||||
int readable = buf.ReadableBytes;
|
int readable = buf.ReadableBytes;
|
||||||
if (readable > 0)
|
if (readable > 0)
|
||||||
{
|
{
|
||||||
|
@ -162,7 +165,7 @@ namespace DotNetty.Codecs
|
||||||
{
|
{
|
||||||
buf.Release();
|
buf.Release();
|
||||||
}
|
}
|
||||||
this.cumulation = null;
|
|
||||||
context.FireChannelReadComplete();
|
context.FireChannelReadComplete();
|
||||||
this.HandlerRemovedInternal(context);
|
this.HandlerRemovedInternal(context);
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,4 +37,9 @@
|
||||||
<Reference Include="System" />
|
<Reference Include="System" />
|
||||||
<Reference Include="Microsoft.CSharp" />
|
<Reference Include="Microsoft.CSharp" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
|
<ItemGroup Condition="'$(TargetFramework)' == 'netstandard1.3'">
|
||||||
|
<PackageReference Include="System.Globalization.Extensions">
|
||||||
|
<Version>4.3.0</Version>
|
||||||
|
</PackageReference>
|
||||||
|
</ItemGroup>
|
||||||
</Project>
|
</Project>
|
|
@ -0,0 +1,23 @@
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
|
||||||
|
namespace DotNetty.Handlers.Tls
|
||||||
|
{
|
||||||
|
using System;
|
||||||
|
using System.Diagnostics.Contracts;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
|
||||||
|
public sealed class ServerTlsSniSettings
|
||||||
|
{
|
||||||
|
public ServerTlsSniSettings(Func<string, Task<ServerTlsSettings>> serverTlsSettingMap, string defaultServerHostName = null)
|
||||||
|
{
|
||||||
|
Contract.Requires(serverTlsSettingMap != null);
|
||||||
|
this.ServerTlsSettingMap = serverTlsSettingMap;
|
||||||
|
this.DefaultServerHostName = defaultServerHostName;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Func<string, Task<ServerTlsSettings>> ServerTlsSettingMap { get; }
|
||||||
|
|
||||||
|
public string DefaultServerHostName { get; }
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,316 @@
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
|
||||||
|
namespace DotNetty.Handlers.Tls
|
||||||
|
{
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Diagnostics.Contracts;
|
||||||
|
using System.Globalization;
|
||||||
|
using System.IO;
|
||||||
|
using System.Net.Security;
|
||||||
|
using System.Text;
|
||||||
|
using DotNetty.Buffers;
|
||||||
|
using DotNetty.Codecs;
|
||||||
|
using DotNetty.Common.Internal.Logging;
|
||||||
|
using DotNetty.Transport.Channels;
|
||||||
|
|
||||||
|
public sealed class SniHandler : ByteToMessageDecoder
|
||||||
|
{
|
||||||
|
// Maximal number of ssl records to inspect before fallback to the default (aligned with netty)
|
||||||
|
const int MAX_SSL_RECORDS = 4;
|
||||||
|
static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(typeof(SniHandler));
|
||||||
|
readonly Func<Stream, SslStream> sslStreamFactory;
|
||||||
|
readonly ServerTlsSniSettings serverTlsSniSettings;
|
||||||
|
|
||||||
|
bool handshakeFailed;
|
||||||
|
bool suppressRead;
|
||||||
|
bool readPending;
|
||||||
|
|
||||||
|
public SniHandler(ServerTlsSniSettings settings)
|
||||||
|
: this(stream => new SslStream(stream, true), settings)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
public SniHandler(Func<Stream, SslStream> sslStreamFactory, ServerTlsSniSettings settings)
|
||||||
|
{
|
||||||
|
Contract.Requires(settings != null);
|
||||||
|
Contract.Requires(sslStreamFactory != null);
|
||||||
|
this.sslStreamFactory = sslStreamFactory;
|
||||||
|
this.serverTlsSniSettings = settings;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List<object> output)
|
||||||
|
{
|
||||||
|
if (!this.suppressRead && !this.handshakeFailed)
|
||||||
|
{
|
||||||
|
int writerIndex = input.WriterIndex;
|
||||||
|
Exception error = null;
|
||||||
|
try
|
||||||
|
{
|
||||||
|
bool continueLoop = true;
|
||||||
|
for (int i = 0; i < MAX_SSL_RECORDS && continueLoop; i++)
|
||||||
|
{
|
||||||
|
int readerIndex = input.ReaderIndex;
|
||||||
|
int readableBytes = writerIndex - readerIndex;
|
||||||
|
if (readableBytes < TlsUtils.SSL_RECORD_HEADER_LENGTH)
|
||||||
|
{
|
||||||
|
// Not enough data to determine the record type and length.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int command = input.GetByte(readerIndex);
|
||||||
|
// tls, but not handshake command
|
||||||
|
switch (command)
|
||||||
|
{
|
||||||
|
case TlsUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
|
||||||
|
case TlsUtils.SSL_CONTENT_TYPE_ALERT:
|
||||||
|
int len = TlsUtils.GetEncryptedPacketLength(input, readerIndex);
|
||||||
|
|
||||||
|
// Not an SSL/TLS packet
|
||||||
|
if (len == TlsUtils.NOT_ENCRYPTED)
|
||||||
|
{
|
||||||
|
this.handshakeFailed = true;
|
||||||
|
var e = new NotSslRecordException(
|
||||||
|
"not an SSL/TLS record: " + ByteBufferUtil.HexDump(input));
|
||||||
|
input.SkipBytes(input.ReadableBytes);
|
||||||
|
|
||||||
|
TlsUtils.NotifyHandshakeFailure(context, e);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
if (len == TlsUtils.NOT_ENOUGH_DATA ||
|
||||||
|
writerIndex - readerIndex - TlsUtils.SSL_RECORD_HEADER_LENGTH < len)
|
||||||
|
{
|
||||||
|
// Not enough data
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// increase readerIndex and try again.
|
||||||
|
input.SkipBytes(len);
|
||||||
|
continue;
|
||||||
|
|
||||||
|
case TlsUtils.SSL_CONTENT_TYPE_HANDSHAKE:
|
||||||
|
int majorVersion = input.GetByte(readerIndex + 1);
|
||||||
|
|
||||||
|
// SSLv3 or TLS
|
||||||
|
if (majorVersion == 3)
|
||||||
|
{
|
||||||
|
int packetLength = input.GetUnsignedShort(readerIndex + 3) + TlsUtils.SSL_RECORD_HEADER_LENGTH;
|
||||||
|
|
||||||
|
if (readableBytes < packetLength)
|
||||||
|
{
|
||||||
|
// client hello incomplete; try again to decode once more data is ready.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
|
||||||
|
//
|
||||||
|
// Decode the ssl client hello packet.
|
||||||
|
// We have to skip bytes until SessionID (which sum to 43 bytes).
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// ProtocolVersion client_version;
|
||||||
|
// Random random;
|
||||||
|
// SessionID session_id;
|
||||||
|
// CipherSuite cipher_suites<2..2^16-2>;
|
||||||
|
// CompressionMethod compression_methods<1..2^8-1>;
|
||||||
|
// select (extensions_present) {
|
||||||
|
// case false:
|
||||||
|
// struct {};
|
||||||
|
// case true:
|
||||||
|
// Extension extensions<0..2^16-1>;
|
||||||
|
// };
|
||||||
|
// } ClientHello;
|
||||||
|
//
|
||||||
|
|
||||||
|
int endOffset = readerIndex + packetLength;
|
||||||
|
int offset = readerIndex + 43;
|
||||||
|
|
||||||
|
if (endOffset - offset < 6)
|
||||||
|
{
|
||||||
|
continueLoop = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
int sessionIdLength = input.GetByte(offset);
|
||||||
|
offset += sessionIdLength + 1;
|
||||||
|
|
||||||
|
int cipherSuitesLength = input.GetUnsignedShort(offset);
|
||||||
|
offset += cipherSuitesLength + 2;
|
||||||
|
|
||||||
|
int compressionMethodLength = input.GetByte(offset);
|
||||||
|
offset += compressionMethodLength + 1;
|
||||||
|
|
||||||
|
int extensionsLength = input.GetUnsignedShort(offset);
|
||||||
|
offset += 2;
|
||||||
|
int extensionsLimit = offset + extensionsLength;
|
||||||
|
|
||||||
|
if (extensionsLimit > endOffset)
|
||||||
|
{
|
||||||
|
// Extensions should never exceed the record boundary.
|
||||||
|
continueLoop = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (;;)
|
||||||
|
{
|
||||||
|
if (extensionsLimit - offset < 4)
|
||||||
|
{
|
||||||
|
continueLoop = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
int extensionType = input.GetUnsignedShort(offset);
|
||||||
|
offset += 2;
|
||||||
|
|
||||||
|
int extensionLength = input.GetUnsignedShort(offset);
|
||||||
|
offset += 2;
|
||||||
|
|
||||||
|
if (extensionsLimit - offset < extensionLength)
|
||||||
|
{
|
||||||
|
continueLoop = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// SNI
|
||||||
|
// See https://tools.ietf.org/html/rfc6066#page-6
|
||||||
|
if (extensionType == 0)
|
||||||
|
{
|
||||||
|
offset += 2;
|
||||||
|
if (extensionsLimit - offset < 3)
|
||||||
|
{
|
||||||
|
continueLoop = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
int serverNameType = input.GetByte(offset);
|
||||||
|
offset++;
|
||||||
|
|
||||||
|
if (serverNameType == 0)
|
||||||
|
{
|
||||||
|
int serverNameLength = input.GetUnsignedShort(offset);
|
||||||
|
offset += 2;
|
||||||
|
|
||||||
|
if (serverNameLength <= 0 || extensionsLimit - offset < serverNameLength)
|
||||||
|
{
|
||||||
|
continueLoop = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
string hostname = input.ToString(offset, serverNameLength, Encoding.UTF8);
|
||||||
|
//try
|
||||||
|
//{
|
||||||
|
// select(ctx, IDN.toASCII(hostname,
|
||||||
|
// IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US));
|
||||||
|
//}
|
||||||
|
//catch (Throwable t)
|
||||||
|
//{
|
||||||
|
// PlatformDependent.throwException(t);
|
||||||
|
//}
|
||||||
|
|
||||||
|
var idn = new IdnMapping()
|
||||||
|
{
|
||||||
|
AllowUnassigned = true
|
||||||
|
};
|
||||||
|
|
||||||
|
hostname = idn.GetAscii(hostname);
|
||||||
|
#if NETSTANDARD1_3
|
||||||
|
// TODO: netcore does not have culture sensitive tolower()
|
||||||
|
hostname = hostname.ToLowerInvariant();
|
||||||
|
#else
|
||||||
|
hostname = hostname.ToLower(new CultureInfo("en-US"));
|
||||||
|
#endif
|
||||||
|
this.Select(context, hostname);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// invalid enum value
|
||||||
|
continueLoop = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
offset += extensionLength;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
break;
|
||||||
|
// Fall-through
|
||||||
|
default:
|
||||||
|
//not tls, ssl or application data, do not try sni
|
||||||
|
continueLoop = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (Exception e)
|
||||||
|
{
|
||||||
|
error = e;
|
||||||
|
|
||||||
|
// unexpected encoding, ignore sni and use default
|
||||||
|
if (Logger.DebugEnabled)
|
||||||
|
{
|
||||||
|
Logger.Warn($"Unexpected client hello packet: {ByteBufferUtil.HexDump(input)}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.serverTlsSniSettings.DefaultServerHostName != null)
|
||||||
|
{
|
||||||
|
// Just select the default certifcate
|
||||||
|
this.Select(context, this.serverTlsSniSettings.DefaultServerHostName);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
this.handshakeFailed = true;
|
||||||
|
var e = new DecoderException($"failed to get the Tls Certificate {error}");
|
||||||
|
TlsUtils.NotifyHandshakeFailure(context, e);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async void Select(IChannelHandlerContext context, string hostName)
|
||||||
|
{
|
||||||
|
Contract.Requires(hostName != null);
|
||||||
|
this.suppressRead = true;
|
||||||
|
try
|
||||||
|
{
|
||||||
|
var serverTlsSetting = await this.serverTlsSniSettings.ServerTlsSettingMap(hostName);
|
||||||
|
this.ReplaceHandler(context, serverTlsSetting);
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
this.ExceptionCaught(context, new DecoderException($"failed to get the Tls Certificate for {hostName}, {ex}"));
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
this.suppressRead = false;
|
||||||
|
if (this.readPending)
|
||||||
|
{
|
||||||
|
this.readPending = false;
|
||||||
|
context.Read();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReplaceHandler(IChannelHandlerContext context, ServerTlsSettings serverTlsSetting)
|
||||||
|
{
|
||||||
|
Contract.Requires(serverTlsSetting != null);
|
||||||
|
var tlsHandler = new TlsHandler(this.sslStreamFactory, serverTlsSetting);
|
||||||
|
context.Channel.Pipeline.Replace(this, nameof(TlsHandler), tlsHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
public override void Read(IChannelHandlerContext context)
|
||||||
|
{
|
||||||
|
if (this.suppressRead)
|
||||||
|
{
|
||||||
|
this.readPending = true;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
base.Read(context);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -33,6 +33,12 @@ namespace DotNetty.Handlers.Tls
|
||||||
/// the length of the ssl record header (in bytes)
|
/// the length of the ssl record header (in bytes)
|
||||||
public const int SSL_RECORD_HEADER_LENGTH = 5;
|
public const int SSL_RECORD_HEADER_LENGTH = 5;
|
||||||
|
|
||||||
|
// Not enough data in buffer to parse the record length
|
||||||
|
public const int NOT_ENOUGH_DATA = -1;
|
||||||
|
|
||||||
|
// data is not encrypted
|
||||||
|
public const int NOT_ENCRYPTED = -2;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase
|
/// Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase
|
||||||
/// the readerIndex of the given <see cref="IByteBuffer"/>.
|
/// the readerIndex of the given <see cref="IByteBuffer"/>.
|
||||||
|
|
|
@ -0,0 +1,280 @@
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
|
||||||
|
namespace DotNetty.Handlers.Tests
|
||||||
|
{
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Diagnostics;
|
||||||
|
using System.Linq;
|
||||||
|
using System.Net.Security;
|
||||||
|
using System.Security.Authentication;
|
||||||
|
using System.Security.Cryptography.X509Certificates;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using DotNetty.Buffers;
|
||||||
|
using DotNetty.Common.Concurrency;
|
||||||
|
using DotNetty.Handlers.Tls;
|
||||||
|
using DotNetty.Tests.Common;
|
||||||
|
using DotNetty.Transport.Channels;
|
||||||
|
using DotNetty.Transport.Channels.Embedded;
|
||||||
|
using Xunit;
|
||||||
|
using Xunit.Abstractions;
|
||||||
|
|
||||||
|
public class SniHandlerTest : TestBase
|
||||||
|
{
|
||||||
|
static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(10);
|
||||||
|
static readonly Dictionary<string, ServerTlsSettings> SettingMap = new Dictionary<string, ServerTlsSettings>();
|
||||||
|
|
||||||
|
static SniHandlerTest()
|
||||||
|
{
|
||||||
|
X509Certificate2 tlsCertificate = TestResourceHelper.GetTestCertificate();
|
||||||
|
X509Certificate2 tlsCertificate2 = TestResourceHelper.GetTestCertificate2();
|
||||||
|
|
||||||
|
SettingMap[tlsCertificate.GetNameInfo(X509NameType.DnsName, false)] = new ServerTlsSettings(tlsCertificate, false, false, SslProtocols.Tls12);
|
||||||
|
SettingMap[tlsCertificate2.GetNameInfo(X509NameType.DnsName, false)] = new ServerTlsSettings(tlsCertificate2, false, false, SslProtocols.Tls12);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SniHandlerTest(ITestOutputHelper output)
|
||||||
|
: base(output)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
public static IEnumerable<object[]> GetTlsReadTestData()
|
||||||
|
{
|
||||||
|
var lengthVariations =
|
||||||
|
new[]
|
||||||
|
{
|
||||||
|
new[] { 1 }
|
||||||
|
};
|
||||||
|
var boolToggle = new[] { false, true };
|
||||||
|
var protocols = new[] { SslProtocols.Tls12 };
|
||||||
|
var writeStrategyFactories = new Func<IWriteStrategy>[]
|
||||||
|
{
|
||||||
|
() => new AsIsWriteStrategy()
|
||||||
|
};
|
||||||
|
|
||||||
|
return
|
||||||
|
from frameLengths in lengthVariations
|
||||||
|
from isClient in boolToggle
|
||||||
|
from writeStrategyFactory in writeStrategyFactories
|
||||||
|
from protocol in protocols
|
||||||
|
from targetHost in SettingMap.Keys
|
||||||
|
select new object[] { frameLengths, isClient, writeStrategyFactory(), protocol, targetHost };
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
[Theory]
|
||||||
|
[MemberData(nameof(GetTlsReadTestData))]
|
||||||
|
public async Task TlsRead(int[] frameLengths, bool isClient, IWriteStrategy writeStrategy, SslProtocols protocol, string targetHost)
|
||||||
|
{
|
||||||
|
this.Output.WriteLine($"frameLengths: {string.Join(", ", frameLengths)}");
|
||||||
|
this.Output.WriteLine($"writeStrategy: {writeStrategy}");
|
||||||
|
this.Output.WriteLine($"protocol: {protocol}");
|
||||||
|
this.Output.WriteLine($"targetHost: {targetHost}");
|
||||||
|
|
||||||
|
var executor = new SingleThreadEventExecutor("test executor", TimeSpan.FromMilliseconds(10));
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
var writeTasks = new List<Task>();
|
||||||
|
var pair = await SetupStreamAndChannelAsync(isClient, executor, writeStrategy, protocol, writeTasks, targetHost).WithTimeout(TimeSpan.FromSeconds(10));
|
||||||
|
EmbeddedChannel ch = pair.Item1;
|
||||||
|
SslStream driverStream = pair.Item2;
|
||||||
|
|
||||||
|
int randomSeed = Environment.TickCount;
|
||||||
|
var random = new Random(randomSeed);
|
||||||
|
IByteBuffer expectedBuffer = Unpooled.Buffer(16 * 1024);
|
||||||
|
foreach (int len in frameLengths)
|
||||||
|
{
|
||||||
|
var data = new byte[len];
|
||||||
|
random.NextBytes(data);
|
||||||
|
expectedBuffer.WriteBytes(data);
|
||||||
|
await driverStream.WriteAsync(data, 0, data.Length).WithTimeout(TimeSpan.FromSeconds(5));
|
||||||
|
}
|
||||||
|
await Task.WhenAll(writeTasks).WithTimeout(TimeSpan.FromSeconds(5));
|
||||||
|
IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024);
|
||||||
|
await ReadOutboundAsync(async () => ch.ReadInbound<IByteBuffer>(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
|
||||||
|
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
|
||||||
|
|
||||||
|
if (!isClient)
|
||||||
|
{
|
||||||
|
// check if snihandler got replaced with tls handler
|
||||||
|
Assert.Null(ch.Pipeline.Get<SniHandler>());
|
||||||
|
Assert.NotNull(ch.Pipeline.Get<TlsHandler>());
|
||||||
|
}
|
||||||
|
|
||||||
|
driverStream.Dispose();
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
await executor.ShutdownGracefullyAsync(TimeSpan.FromMilliseconds(100), TimeSpan.FromMilliseconds(300));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static IEnumerable<object[]> GetTlsWriteTestData()
|
||||||
|
{
|
||||||
|
var lengthVariations =
|
||||||
|
new[]
|
||||||
|
{
|
||||||
|
new[] { 1 }
|
||||||
|
};
|
||||||
|
var boolToggle = new[] { false, true };
|
||||||
|
var protocols = new[] { SslProtocols.Tls12 };
|
||||||
|
|
||||||
|
return
|
||||||
|
from frameLengths in lengthVariations
|
||||||
|
from isClient in boolToggle
|
||||||
|
from protocol in protocols
|
||||||
|
from targetHost in SettingMap.Keys
|
||||||
|
select new object[] { frameLengths, isClient, protocol, targetHost };
|
||||||
|
}
|
||||||
|
|
||||||
|
[Theory]
|
||||||
|
[MemberData(nameof(GetTlsWriteTestData))]
|
||||||
|
public async Task TlsWrite(int[] frameLengths, bool isClient, SslProtocols protocol, string targetHost)
|
||||||
|
{
|
||||||
|
this.Output.WriteLine("frameLengths: " + string.Join(", ", frameLengths));
|
||||||
|
this.Output.WriteLine($"protocol: {protocol}");
|
||||||
|
this.Output.WriteLine($"targetHost: {targetHost}");
|
||||||
|
|
||||||
|
var writeStrategy = new AsIsWriteStrategy();
|
||||||
|
var executor = new SingleThreadEventExecutor("test executor", TimeSpan.FromMilliseconds(10));
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
var writeTasks = new List<Task>();
|
||||||
|
var pair = await SetupStreamAndChannelAsync(isClient, executor, writeStrategy, protocol, writeTasks, targetHost);
|
||||||
|
EmbeddedChannel ch = pair.Item1;
|
||||||
|
SslStream driverStream = pair.Item2;
|
||||||
|
|
||||||
|
int randomSeed = Environment.TickCount;
|
||||||
|
var random = new Random(randomSeed);
|
||||||
|
IByteBuffer expectedBuffer = Unpooled.Buffer(16 * 1024);
|
||||||
|
foreach (IEnumerable<int> lengths in frameLengths.Split(x => x < 0))
|
||||||
|
{
|
||||||
|
ch.WriteOutbound(lengths.Select(len =>
|
||||||
|
{
|
||||||
|
var data = new byte[len];
|
||||||
|
random.NextBytes(data);
|
||||||
|
expectedBuffer.WriteBytes(data);
|
||||||
|
return (object)Unpooled.WrappedBuffer(data);
|
||||||
|
}).ToArray());
|
||||||
|
}
|
||||||
|
|
||||||
|
IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024);
|
||||||
|
var readBuffer = new byte[16 * 1024 * 10];
|
||||||
|
await ReadOutboundAsync(
|
||||||
|
async () =>
|
||||||
|
{
|
||||||
|
int read = await driverStream.ReadAsync(readBuffer, 0, readBuffer.Length);
|
||||||
|
return Unpooled.WrappedBuffer(readBuffer, 0, read);
|
||||||
|
},
|
||||||
|
expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
|
||||||
|
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
|
||||||
|
|
||||||
|
if (!isClient)
|
||||||
|
{
|
||||||
|
// check if snihandler got replaced with tls handler
|
||||||
|
Assert.Null(ch.Pipeline.Get<SniHandler>());
|
||||||
|
Assert.NotNull(ch.Pipeline.Get<TlsHandler>());
|
||||||
|
}
|
||||||
|
|
||||||
|
driverStream.Dispose();
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
await executor.ShutdownGracefullyAsync(TimeSpan.FromMilliseconds(100), TimeSpan.FromMilliseconds(300));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static async Task<Tuple<EmbeddedChannel, SslStream>> SetupStreamAndChannelAsync(bool isClient, IEventExecutor executor, IWriteStrategy writeStrategy, SslProtocols protocol, List<Task> writeTasks, string targetHost)
|
||||||
|
{
|
||||||
|
IChannelHandler tlsHandler = isClient ?
|
||||||
|
(IChannelHandler)new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) =>
|
||||||
|
{
|
||||||
|
Assert.Equal(targetHost, certificate.Issuer.Replace("CN=", string.Empty));
|
||||||
|
return true;
|
||||||
|
}), new ClientTlsSettings(SslProtocols.Tls12, false, new List<X509Certificate>(), targetHost)) :
|
||||||
|
new SniHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ServerTlsSniSettings(CertificateSelector));
|
||||||
|
//var ch = new EmbeddedChannel(new LoggingHandler("BEFORE"), tlsHandler, new LoggingHandler("AFTER"));
|
||||||
|
var ch = new EmbeddedChannel(tlsHandler);
|
||||||
|
|
||||||
|
if (!isClient)
|
||||||
|
{
|
||||||
|
// check if in the beginning snihandler exists in the pipeline, but not tls handler
|
||||||
|
Assert.NotNull(ch.Pipeline.Get<SniHandler>());
|
||||||
|
Assert.Null(ch.Pipeline.Get<TlsHandler>());
|
||||||
|
}
|
||||||
|
|
||||||
|
IByteBuffer readResultBuffer = Unpooled.Buffer(4 * 1024);
|
||||||
|
Func<ArraySegment<byte>, Task<int>> readDataFunc = async output =>
|
||||||
|
{
|
||||||
|
if (writeTasks.Count > 0)
|
||||||
|
{
|
||||||
|
await Task.WhenAll(writeTasks).WithTimeout(TestTimeout);
|
||||||
|
writeTasks.Clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (readResultBuffer.ReadableBytes < output.Count)
|
||||||
|
{
|
||||||
|
await ReadOutboundAsync(async () => ch.ReadOutbound<IByteBuffer>(), output.Count - readResultBuffer.ReadableBytes, readResultBuffer, TestTimeout);
|
||||||
|
}
|
||||||
|
Assert.NotEqual(0, readResultBuffer.ReadableBytes);
|
||||||
|
int read = Math.Min(output.Count, readResultBuffer.ReadableBytes);
|
||||||
|
readResultBuffer.ReadBytes(output.Array, output.Offset, read);
|
||||||
|
return read;
|
||||||
|
};
|
||||||
|
var mediationStream = new MediationStream(readDataFunc, input =>
|
||||||
|
{
|
||||||
|
Task task = executor.SubmitAsync(() => writeStrategy.WriteToChannelAsync(ch, input)).Unwrap();
|
||||||
|
writeTasks.Add(task);
|
||||||
|
return task;
|
||||||
|
});
|
||||||
|
|
||||||
|
var driverStream = new SslStream(mediationStream, true, (_1, _2, _3, _4) => true);
|
||||||
|
if (isClient)
|
||||||
|
{
|
||||||
|
await Task.Run(() => driverStream.AuthenticateAsServerAsync(CertificateSelector(targetHost).Result.Certificate).WithTimeout(TimeSpan.FromSeconds(5)));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
await Task.Run(() => driverStream.AuthenticateAsClientAsync(targetHost, null, protocol, false)).WithTimeout(TimeSpan.FromSeconds(5));
|
||||||
|
}
|
||||||
|
writeTasks.Clear();
|
||||||
|
|
||||||
|
return Tuple.Create(ch, driverStream);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Task<ServerTlsSettings> CertificateSelector(string hostName)
|
||||||
|
{
|
||||||
|
Assert.NotNull(hostName);
|
||||||
|
Assert.Contains(hostName, SettingMap.Keys);
|
||||||
|
return Task.FromResult(SettingMap[hostName]);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Task ReadOutboundAsync(Func<Task<IByteBuffer>> readFunc, int expectedBytes, IByteBuffer result, TimeSpan timeout)
|
||||||
|
{
|
||||||
|
Stopwatch stopwatch = Stopwatch.StartNew();
|
||||||
|
int remaining = expectedBytes;
|
||||||
|
return AssertEx.EventuallyAsync(
|
||||||
|
async () =>
|
||||||
|
{
|
||||||
|
TimeSpan readTimeout = timeout - stopwatch.Elapsed;
|
||||||
|
if (readTimeout <= TimeSpan.Zero)
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
IByteBuffer output = await readFunc().WithTimeout(readTimeout);//inbound ? ch.ReadInbound<IByteBuffer>() : ch.ReadOutbound<IByteBuffer>();
|
||||||
|
if (output != null)
|
||||||
|
{
|
||||||
|
remaining -= output.ReadableBytes;
|
||||||
|
result.WriteBytes(output);
|
||||||
|
}
|
||||||
|
return remaining <= 0;
|
||||||
|
},
|
||||||
|
TimeSpan.FromMilliseconds(10),
|
||||||
|
timeout);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -93,6 +93,7 @@ namespace DotNetty.Handlers.Tests
|
||||||
IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024);
|
IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024);
|
||||||
await ReadOutboundAsync(async () => ch.ReadInbound<IByteBuffer>(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
|
await ReadOutboundAsync(async () => ch.ReadInbound<IByteBuffer>(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
|
||||||
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
|
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
|
||||||
|
driverStream.Dispose();
|
||||||
}
|
}
|
||||||
finally
|
finally
|
||||||
{
|
{
|
||||||
|
@ -166,6 +167,7 @@ namespace DotNetty.Handlers.Tests
|
||||||
},
|
},
|
||||||
expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
|
expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
|
||||||
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
|
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
|
||||||
|
driverStream.Dispose();
|
||||||
}
|
}
|
||||||
finally
|
finally
|
||||||
{
|
{
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
</PropertyGroup>
|
</PropertyGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<EmbeddedResource Include="..\..\shared\dotnetty.com.pfx" />
|
<EmbeddedResource Include="..\..\shared\dotnetty.com.pfx" />
|
||||||
|
<EmbeddedResource Include="..\..\shared\contoso.com.pfx" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<PropertyGroup Condition=" '$(TargetFramework)' == 'net45' ">
|
<PropertyGroup Condition=" '$(TargetFramework)' == 'net45' ">
|
||||||
<RuntimeIdentifier>win-x64</RuntimeIdentifier>
|
<RuntimeIdentifier>win-x64</RuntimeIdentifier>
|
||||||
|
|
|
@ -21,5 +21,18 @@ namespace DotNetty.Tests.Common
|
||||||
|
|
||||||
return new X509Certificate2(certData, "password");
|
return new X509Certificate2(certData, "password");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static X509Certificate2 GetTestCertificate2()
|
||||||
|
{
|
||||||
|
byte[] certData;
|
||||||
|
using (Stream resStream = typeof(TestResourceHelper).GetTypeInfo().Assembly.GetManifestResourceStream(typeof(TestResourceHelper).Namespace + "." + "contoso.com.pfx"))
|
||||||
|
using (var memStream = new MemoryStream())
|
||||||
|
{
|
||||||
|
resStream.CopyTo(memStream);
|
||||||
|
certData = memStream.ToArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new X509Certificate2(certData, "password");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
Загрузка…
Ссылка в новой задаче