зеркало из https://github.com/Azure/DotNetty.git
SNI Support (ported netty SniHandler) (#219)
* First cut commit *unfinished * Replace SNI handler with TlsHandler with certificate selected based on host name found in clientHello * first cut tests for SniHandler * test update * Test update * Supress further read when handler is replaced * made the snitest more effective and IDN in hostname lower case as per netty impl * More asserts to check whether snihandler gets replaces with tlshaldler in the pipeline * assert server name is always found in clienthello as per the test setup * Provided option to select default host name in case of error or client hello does not contail SNI extension, otherwise handshake fails in those cases * More elaborate tests * verbosity in test * Fixed Read continues to get called after handler removed and removed the workaround in SniHandler * relaced goto statement with flag for breaking outer for loop from within switch * Update SniHandler.cs * trigger CI build * addressed review comments * Fixed task continuation option * addresses further review comments * triggere build again with some more assert in test #221 * suppress read logic is still needed due to async "void" * changing the map to (string -> Task<ServerTlsSettings) * one more constructor overload * extensive tls read/write test is not needed since that's already done in tlshandler test * more readable target host validation in test to force retrigger confusing CI build * retrigger * addressed review comment "this generates 30 random data frames. pls replace with new [] { 1 }"
This commit is contained in:
Родитель
8357d8ed40
Коммит
1d3eda9a74
Двоичный файл не отображается.
|
@ -151,6 +151,9 @@ namespace DotNetty.Codecs
|
|||
public override void HandlerRemoved(IChannelHandlerContext context)
|
||||
{
|
||||
IByteBuffer buf = this.InternalBuffer;
|
||||
|
||||
// Directly set this to null so we are sure we not access it in any other method here anymore.
|
||||
this.cumulation = null;
|
||||
int readable = buf.ReadableBytes;
|
||||
if (readable > 0)
|
||||
{
|
||||
|
@ -162,7 +165,7 @@ namespace DotNetty.Codecs
|
|||
{
|
||||
buf.Release();
|
||||
}
|
||||
this.cumulation = null;
|
||||
|
||||
context.FireChannelReadComplete();
|
||||
this.HandlerRemovedInternal(context);
|
||||
}
|
||||
|
|
|
@ -37,4 +37,9 @@
|
|||
<Reference Include="System" />
|
||||
<Reference Include="Microsoft.CSharp" />
|
||||
</ItemGroup>
|
||||
<ItemGroup Condition="'$(TargetFramework)' == 'netstandard1.3'">
|
||||
<PackageReference Include="System.Globalization.Extensions">
|
||||
<Version>4.3.0</Version>
|
||||
</PackageReference>
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
|
||||
namespace DotNetty.Handlers.Tls
|
||||
{
|
||||
using System;
|
||||
using System.Diagnostics.Contracts;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
public sealed class ServerTlsSniSettings
|
||||
{
|
||||
public ServerTlsSniSettings(Func<string, Task<ServerTlsSettings>> serverTlsSettingMap, string defaultServerHostName = null)
|
||||
{
|
||||
Contract.Requires(serverTlsSettingMap != null);
|
||||
this.ServerTlsSettingMap = serverTlsSettingMap;
|
||||
this.DefaultServerHostName = defaultServerHostName;
|
||||
}
|
||||
|
||||
public Func<string, Task<ServerTlsSettings>> ServerTlsSettingMap { get; }
|
||||
|
||||
public string DefaultServerHostName { get; }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,316 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
|
||||
namespace DotNetty.Handlers.Tls
|
||||
{
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.Contracts;
|
||||
using System.Globalization;
|
||||
using System.IO;
|
||||
using System.Net.Security;
|
||||
using System.Text;
|
||||
using DotNetty.Buffers;
|
||||
using DotNetty.Codecs;
|
||||
using DotNetty.Common.Internal.Logging;
|
||||
using DotNetty.Transport.Channels;
|
||||
|
||||
public sealed class SniHandler : ByteToMessageDecoder
|
||||
{
|
||||
// Maximal number of ssl records to inspect before fallback to the default (aligned with netty)
|
||||
const int MAX_SSL_RECORDS = 4;
|
||||
static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(typeof(SniHandler));
|
||||
readonly Func<Stream, SslStream> sslStreamFactory;
|
||||
readonly ServerTlsSniSettings serverTlsSniSettings;
|
||||
|
||||
bool handshakeFailed;
|
||||
bool suppressRead;
|
||||
bool readPending;
|
||||
|
||||
public SniHandler(ServerTlsSniSettings settings)
|
||||
: this(stream => new SslStream(stream, true), settings)
|
||||
{
|
||||
}
|
||||
|
||||
public SniHandler(Func<Stream, SslStream> sslStreamFactory, ServerTlsSniSettings settings)
|
||||
{
|
||||
Contract.Requires(settings != null);
|
||||
Contract.Requires(sslStreamFactory != null);
|
||||
this.sslStreamFactory = sslStreamFactory;
|
||||
this.serverTlsSniSettings = settings;
|
||||
}
|
||||
|
||||
protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List<object> output)
|
||||
{
|
||||
if (!this.suppressRead && !this.handshakeFailed)
|
||||
{
|
||||
int writerIndex = input.WriterIndex;
|
||||
Exception error = null;
|
||||
try
|
||||
{
|
||||
bool continueLoop = true;
|
||||
for (int i = 0; i < MAX_SSL_RECORDS && continueLoop; i++)
|
||||
{
|
||||
int readerIndex = input.ReaderIndex;
|
||||
int readableBytes = writerIndex - readerIndex;
|
||||
if (readableBytes < TlsUtils.SSL_RECORD_HEADER_LENGTH)
|
||||
{
|
||||
// Not enough data to determine the record type and length.
|
||||
return;
|
||||
}
|
||||
|
||||
int command = input.GetByte(readerIndex);
|
||||
// tls, but not handshake command
|
||||
switch (command)
|
||||
{
|
||||
case TlsUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
|
||||
case TlsUtils.SSL_CONTENT_TYPE_ALERT:
|
||||
int len = TlsUtils.GetEncryptedPacketLength(input, readerIndex);
|
||||
|
||||
// Not an SSL/TLS packet
|
||||
if (len == TlsUtils.NOT_ENCRYPTED)
|
||||
{
|
||||
this.handshakeFailed = true;
|
||||
var e = new NotSslRecordException(
|
||||
"not an SSL/TLS record: " + ByteBufferUtil.HexDump(input));
|
||||
input.SkipBytes(input.ReadableBytes);
|
||||
|
||||
TlsUtils.NotifyHandshakeFailure(context, e);
|
||||
throw e;
|
||||
}
|
||||
if (len == TlsUtils.NOT_ENOUGH_DATA ||
|
||||
writerIndex - readerIndex - TlsUtils.SSL_RECORD_HEADER_LENGTH < len)
|
||||
{
|
||||
// Not enough data
|
||||
return;
|
||||
}
|
||||
|
||||
// increase readerIndex and try again.
|
||||
input.SkipBytes(len);
|
||||
continue;
|
||||
|
||||
case TlsUtils.SSL_CONTENT_TYPE_HANDSHAKE:
|
||||
int majorVersion = input.GetByte(readerIndex + 1);
|
||||
|
||||
// SSLv3 or TLS
|
||||
if (majorVersion == 3)
|
||||
{
|
||||
int packetLength = input.GetUnsignedShort(readerIndex + 3) + TlsUtils.SSL_RECORD_HEADER_LENGTH;
|
||||
|
||||
if (readableBytes < packetLength)
|
||||
{
|
||||
// client hello incomplete; try again to decode once more data is ready.
|
||||
return;
|
||||
}
|
||||
|
||||
// See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
|
||||
//
|
||||
// Decode the ssl client hello packet.
|
||||
// We have to skip bytes until SessionID (which sum to 43 bytes).
|
||||
//
|
||||
// struct {
|
||||
// ProtocolVersion client_version;
|
||||
// Random random;
|
||||
// SessionID session_id;
|
||||
// CipherSuite cipher_suites<2..2^16-2>;
|
||||
// CompressionMethod compression_methods<1..2^8-1>;
|
||||
// select (extensions_present) {
|
||||
// case false:
|
||||
// struct {};
|
||||
// case true:
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// };
|
||||
// } ClientHello;
|
||||
//
|
||||
|
||||
int endOffset = readerIndex + packetLength;
|
||||
int offset = readerIndex + 43;
|
||||
|
||||
if (endOffset - offset < 6)
|
||||
{
|
||||
continueLoop = false;
|
||||
break;
|
||||
}
|
||||
|
||||
int sessionIdLength = input.GetByte(offset);
|
||||
offset += sessionIdLength + 1;
|
||||
|
||||
int cipherSuitesLength = input.GetUnsignedShort(offset);
|
||||
offset += cipherSuitesLength + 2;
|
||||
|
||||
int compressionMethodLength = input.GetByte(offset);
|
||||
offset += compressionMethodLength + 1;
|
||||
|
||||
int extensionsLength = input.GetUnsignedShort(offset);
|
||||
offset += 2;
|
||||
int extensionsLimit = offset + extensionsLength;
|
||||
|
||||
if (extensionsLimit > endOffset)
|
||||
{
|
||||
// Extensions should never exceed the record boundary.
|
||||
continueLoop = false;
|
||||
break;
|
||||
}
|
||||
|
||||
for (;;)
|
||||
{
|
||||
if (extensionsLimit - offset < 4)
|
||||
{
|
||||
continueLoop = false;
|
||||
break;
|
||||
}
|
||||
|
||||
int extensionType = input.GetUnsignedShort(offset);
|
||||
offset += 2;
|
||||
|
||||
int extensionLength = input.GetUnsignedShort(offset);
|
||||
offset += 2;
|
||||
|
||||
if (extensionsLimit - offset < extensionLength)
|
||||
{
|
||||
continueLoop = false;
|
||||
break;
|
||||
}
|
||||
|
||||
// SNI
|
||||
// See https://tools.ietf.org/html/rfc6066#page-6
|
||||
if (extensionType == 0)
|
||||
{
|
||||
offset += 2;
|
||||
if (extensionsLimit - offset < 3)
|
||||
{
|
||||
continueLoop = false;
|
||||
break;
|
||||
}
|
||||
|
||||
int serverNameType = input.GetByte(offset);
|
||||
offset++;
|
||||
|
||||
if (serverNameType == 0)
|
||||
{
|
||||
int serverNameLength = input.GetUnsignedShort(offset);
|
||||
offset += 2;
|
||||
|
||||
if (serverNameLength <= 0 || extensionsLimit - offset < serverNameLength)
|
||||
{
|
||||
continueLoop = false;
|
||||
break;
|
||||
}
|
||||
|
||||
string hostname = input.ToString(offset, serverNameLength, Encoding.UTF8);
|
||||
//try
|
||||
//{
|
||||
// select(ctx, IDN.toASCII(hostname,
|
||||
// IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US));
|
||||
//}
|
||||
//catch (Throwable t)
|
||||
//{
|
||||
// PlatformDependent.throwException(t);
|
||||
//}
|
||||
|
||||
var idn = new IdnMapping()
|
||||
{
|
||||
AllowUnassigned = true
|
||||
};
|
||||
|
||||
hostname = idn.GetAscii(hostname);
|
||||
#if NETSTANDARD1_3
|
||||
// TODO: netcore does not have culture sensitive tolower()
|
||||
hostname = hostname.ToLowerInvariant();
|
||||
#else
|
||||
hostname = hostname.ToLower(new CultureInfo("en-US"));
|
||||
#endif
|
||||
this.Select(context, hostname);
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
// invalid enum value
|
||||
continueLoop = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
offset += extensionLength;
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
// Fall-through
|
||||
default:
|
||||
//not tls, ssl or application data, do not try sni
|
||||
continueLoop = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
error = e;
|
||||
|
||||
// unexpected encoding, ignore sni and use default
|
||||
if (Logger.DebugEnabled)
|
||||
{
|
||||
Logger.Warn($"Unexpected client hello packet: {ByteBufferUtil.HexDump(input)}", e);
|
||||
}
|
||||
}
|
||||
|
||||
if (this.serverTlsSniSettings.DefaultServerHostName != null)
|
||||
{
|
||||
// Just select the default certifcate
|
||||
this.Select(context, this.serverTlsSniSettings.DefaultServerHostName);
|
||||
}
|
||||
else
|
||||
{
|
||||
this.handshakeFailed = true;
|
||||
var e = new DecoderException($"failed to get the Tls Certificate {error}");
|
||||
TlsUtils.NotifyHandshakeFailure(context, e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async void Select(IChannelHandlerContext context, string hostName)
|
||||
{
|
||||
Contract.Requires(hostName != null);
|
||||
this.suppressRead = true;
|
||||
try
|
||||
{
|
||||
var serverTlsSetting = await this.serverTlsSniSettings.ServerTlsSettingMap(hostName);
|
||||
this.ReplaceHandler(context, serverTlsSetting);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
this.ExceptionCaught(context, new DecoderException($"failed to get the Tls Certificate for {hostName}, {ex}"));
|
||||
}
|
||||
finally
|
||||
{
|
||||
this.suppressRead = false;
|
||||
if (this.readPending)
|
||||
{
|
||||
this.readPending = false;
|
||||
context.Read();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ReplaceHandler(IChannelHandlerContext context, ServerTlsSettings serverTlsSetting)
|
||||
{
|
||||
Contract.Requires(serverTlsSetting != null);
|
||||
var tlsHandler = new TlsHandler(this.sslStreamFactory, serverTlsSetting);
|
||||
context.Channel.Pipeline.Replace(this, nameof(TlsHandler), tlsHandler);
|
||||
}
|
||||
|
||||
public override void Read(IChannelHandlerContext context)
|
||||
{
|
||||
if (this.suppressRead)
|
||||
{
|
||||
this.readPending = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
base.Read(context);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -33,6 +33,12 @@ namespace DotNetty.Handlers.Tls
|
|||
/// the length of the ssl record header (in bytes)
|
||||
public const int SSL_RECORD_HEADER_LENGTH = 5;
|
||||
|
||||
// Not enough data in buffer to parse the record length
|
||||
public const int NOT_ENOUGH_DATA = -1;
|
||||
|
||||
// data is not encrypted
|
||||
public const int NOT_ENCRYPTED = -2;
|
||||
|
||||
/// <summary>
|
||||
/// Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase
|
||||
/// the readerIndex of the given <see cref="IByteBuffer"/>.
|
||||
|
|
|
@ -0,0 +1,280 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
|
||||
namespace DotNetty.Handlers.Tests
|
||||
{
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Linq;
|
||||
using System.Net.Security;
|
||||
using System.Security.Authentication;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Threading.Tasks;
|
||||
using DotNetty.Buffers;
|
||||
using DotNetty.Common.Concurrency;
|
||||
using DotNetty.Handlers.Tls;
|
||||
using DotNetty.Tests.Common;
|
||||
using DotNetty.Transport.Channels;
|
||||
using DotNetty.Transport.Channels.Embedded;
|
||||
using Xunit;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
public class SniHandlerTest : TestBase
|
||||
{
|
||||
static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(10);
|
||||
static readonly Dictionary<string, ServerTlsSettings> SettingMap = new Dictionary<string, ServerTlsSettings>();
|
||||
|
||||
static SniHandlerTest()
|
||||
{
|
||||
X509Certificate2 tlsCertificate = TestResourceHelper.GetTestCertificate();
|
||||
X509Certificate2 tlsCertificate2 = TestResourceHelper.GetTestCertificate2();
|
||||
|
||||
SettingMap[tlsCertificate.GetNameInfo(X509NameType.DnsName, false)] = new ServerTlsSettings(tlsCertificate, false, false, SslProtocols.Tls12);
|
||||
SettingMap[tlsCertificate2.GetNameInfo(X509NameType.DnsName, false)] = new ServerTlsSettings(tlsCertificate2, false, false, SslProtocols.Tls12);
|
||||
}
|
||||
|
||||
public SniHandlerTest(ITestOutputHelper output)
|
||||
: base(output)
|
||||
{
|
||||
}
|
||||
|
||||
public static IEnumerable<object[]> GetTlsReadTestData()
|
||||
{
|
||||
var lengthVariations =
|
||||
new[]
|
||||
{
|
||||
new[] { 1 }
|
||||
};
|
||||
var boolToggle = new[] { false, true };
|
||||
var protocols = new[] { SslProtocols.Tls12 };
|
||||
var writeStrategyFactories = new Func<IWriteStrategy>[]
|
||||
{
|
||||
() => new AsIsWriteStrategy()
|
||||
};
|
||||
|
||||
return
|
||||
from frameLengths in lengthVariations
|
||||
from isClient in boolToggle
|
||||
from writeStrategyFactory in writeStrategyFactories
|
||||
from protocol in protocols
|
||||
from targetHost in SettingMap.Keys
|
||||
select new object[] { frameLengths, isClient, writeStrategyFactory(), protocol, targetHost };
|
||||
}
|
||||
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(GetTlsReadTestData))]
|
||||
public async Task TlsRead(int[] frameLengths, bool isClient, IWriteStrategy writeStrategy, SslProtocols protocol, string targetHost)
|
||||
{
|
||||
this.Output.WriteLine($"frameLengths: {string.Join(", ", frameLengths)}");
|
||||
this.Output.WriteLine($"writeStrategy: {writeStrategy}");
|
||||
this.Output.WriteLine($"protocol: {protocol}");
|
||||
this.Output.WriteLine($"targetHost: {targetHost}");
|
||||
|
||||
var executor = new SingleThreadEventExecutor("test executor", TimeSpan.FromMilliseconds(10));
|
||||
|
||||
try
|
||||
{
|
||||
var writeTasks = new List<Task>();
|
||||
var pair = await SetupStreamAndChannelAsync(isClient, executor, writeStrategy, protocol, writeTasks, targetHost).WithTimeout(TimeSpan.FromSeconds(10));
|
||||
EmbeddedChannel ch = pair.Item1;
|
||||
SslStream driverStream = pair.Item2;
|
||||
|
||||
int randomSeed = Environment.TickCount;
|
||||
var random = new Random(randomSeed);
|
||||
IByteBuffer expectedBuffer = Unpooled.Buffer(16 * 1024);
|
||||
foreach (int len in frameLengths)
|
||||
{
|
||||
var data = new byte[len];
|
||||
random.NextBytes(data);
|
||||
expectedBuffer.WriteBytes(data);
|
||||
await driverStream.WriteAsync(data, 0, data.Length).WithTimeout(TimeSpan.FromSeconds(5));
|
||||
}
|
||||
await Task.WhenAll(writeTasks).WithTimeout(TimeSpan.FromSeconds(5));
|
||||
IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024);
|
||||
await ReadOutboundAsync(async () => ch.ReadInbound<IByteBuffer>(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
|
||||
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
|
||||
|
||||
if (!isClient)
|
||||
{
|
||||
// check if snihandler got replaced with tls handler
|
||||
Assert.Null(ch.Pipeline.Get<SniHandler>());
|
||||
Assert.NotNull(ch.Pipeline.Get<TlsHandler>());
|
||||
}
|
||||
|
||||
driverStream.Dispose();
|
||||
}
|
||||
finally
|
||||
{
|
||||
await executor.ShutdownGracefullyAsync(TimeSpan.FromMilliseconds(100), TimeSpan.FromMilliseconds(300));
|
||||
}
|
||||
}
|
||||
|
||||
public static IEnumerable<object[]> GetTlsWriteTestData()
|
||||
{
|
||||
var lengthVariations =
|
||||
new[]
|
||||
{
|
||||
new[] { 1 }
|
||||
};
|
||||
var boolToggle = new[] { false, true };
|
||||
var protocols = new[] { SslProtocols.Tls12 };
|
||||
|
||||
return
|
||||
from frameLengths in lengthVariations
|
||||
from isClient in boolToggle
|
||||
from protocol in protocols
|
||||
from targetHost in SettingMap.Keys
|
||||
select new object[] { frameLengths, isClient, protocol, targetHost };
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(GetTlsWriteTestData))]
|
||||
public async Task TlsWrite(int[] frameLengths, bool isClient, SslProtocols protocol, string targetHost)
|
||||
{
|
||||
this.Output.WriteLine("frameLengths: " + string.Join(", ", frameLengths));
|
||||
this.Output.WriteLine($"protocol: {protocol}");
|
||||
this.Output.WriteLine($"targetHost: {targetHost}");
|
||||
|
||||
var writeStrategy = new AsIsWriteStrategy();
|
||||
var executor = new SingleThreadEventExecutor("test executor", TimeSpan.FromMilliseconds(10));
|
||||
|
||||
try
|
||||
{
|
||||
var writeTasks = new List<Task>();
|
||||
var pair = await SetupStreamAndChannelAsync(isClient, executor, writeStrategy, protocol, writeTasks, targetHost);
|
||||
EmbeddedChannel ch = pair.Item1;
|
||||
SslStream driverStream = pair.Item2;
|
||||
|
||||
int randomSeed = Environment.TickCount;
|
||||
var random = new Random(randomSeed);
|
||||
IByteBuffer expectedBuffer = Unpooled.Buffer(16 * 1024);
|
||||
foreach (IEnumerable<int> lengths in frameLengths.Split(x => x < 0))
|
||||
{
|
||||
ch.WriteOutbound(lengths.Select(len =>
|
||||
{
|
||||
var data = new byte[len];
|
||||
random.NextBytes(data);
|
||||
expectedBuffer.WriteBytes(data);
|
||||
return (object)Unpooled.WrappedBuffer(data);
|
||||
}).ToArray());
|
||||
}
|
||||
|
||||
IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024);
|
||||
var readBuffer = new byte[16 * 1024 * 10];
|
||||
await ReadOutboundAsync(
|
||||
async () =>
|
||||
{
|
||||
int read = await driverStream.ReadAsync(readBuffer, 0, readBuffer.Length);
|
||||
return Unpooled.WrappedBuffer(readBuffer, 0, read);
|
||||
},
|
||||
expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
|
||||
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
|
||||
|
||||
if (!isClient)
|
||||
{
|
||||
// check if snihandler got replaced with tls handler
|
||||
Assert.Null(ch.Pipeline.Get<SniHandler>());
|
||||
Assert.NotNull(ch.Pipeline.Get<TlsHandler>());
|
||||
}
|
||||
|
||||
driverStream.Dispose();
|
||||
}
|
||||
finally
|
||||
{
|
||||
await executor.ShutdownGracefullyAsync(TimeSpan.FromMilliseconds(100), TimeSpan.FromMilliseconds(300));
|
||||
}
|
||||
}
|
||||
|
||||
static async Task<Tuple<EmbeddedChannel, SslStream>> SetupStreamAndChannelAsync(bool isClient, IEventExecutor executor, IWriteStrategy writeStrategy, SslProtocols protocol, List<Task> writeTasks, string targetHost)
|
||||
{
|
||||
IChannelHandler tlsHandler = isClient ?
|
||||
(IChannelHandler)new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) =>
|
||||
{
|
||||
Assert.Equal(targetHost, certificate.Issuer.Replace("CN=", string.Empty));
|
||||
return true;
|
||||
}), new ClientTlsSettings(SslProtocols.Tls12, false, new List<X509Certificate>(), targetHost)) :
|
||||
new SniHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ServerTlsSniSettings(CertificateSelector));
|
||||
//var ch = new EmbeddedChannel(new LoggingHandler("BEFORE"), tlsHandler, new LoggingHandler("AFTER"));
|
||||
var ch = new EmbeddedChannel(tlsHandler);
|
||||
|
||||
if (!isClient)
|
||||
{
|
||||
// check if in the beginning snihandler exists in the pipeline, but not tls handler
|
||||
Assert.NotNull(ch.Pipeline.Get<SniHandler>());
|
||||
Assert.Null(ch.Pipeline.Get<TlsHandler>());
|
||||
}
|
||||
|
||||
IByteBuffer readResultBuffer = Unpooled.Buffer(4 * 1024);
|
||||
Func<ArraySegment<byte>, Task<int>> readDataFunc = async output =>
|
||||
{
|
||||
if (writeTasks.Count > 0)
|
||||
{
|
||||
await Task.WhenAll(writeTasks).WithTimeout(TestTimeout);
|
||||
writeTasks.Clear();
|
||||
}
|
||||
|
||||
if (readResultBuffer.ReadableBytes < output.Count)
|
||||
{
|
||||
await ReadOutboundAsync(async () => ch.ReadOutbound<IByteBuffer>(), output.Count - readResultBuffer.ReadableBytes, readResultBuffer, TestTimeout);
|
||||
}
|
||||
Assert.NotEqual(0, readResultBuffer.ReadableBytes);
|
||||
int read = Math.Min(output.Count, readResultBuffer.ReadableBytes);
|
||||
readResultBuffer.ReadBytes(output.Array, output.Offset, read);
|
||||
return read;
|
||||
};
|
||||
var mediationStream = new MediationStream(readDataFunc, input =>
|
||||
{
|
||||
Task task = executor.SubmitAsync(() => writeStrategy.WriteToChannelAsync(ch, input)).Unwrap();
|
||||
writeTasks.Add(task);
|
||||
return task;
|
||||
});
|
||||
|
||||
var driverStream = new SslStream(mediationStream, true, (_1, _2, _3, _4) => true);
|
||||
if (isClient)
|
||||
{
|
||||
await Task.Run(() => driverStream.AuthenticateAsServerAsync(CertificateSelector(targetHost).Result.Certificate).WithTimeout(TimeSpan.FromSeconds(5)));
|
||||
}
|
||||
else
|
||||
{
|
||||
await Task.Run(() => driverStream.AuthenticateAsClientAsync(targetHost, null, protocol, false)).WithTimeout(TimeSpan.FromSeconds(5));
|
||||
}
|
||||
writeTasks.Clear();
|
||||
|
||||
return Tuple.Create(ch, driverStream);
|
||||
}
|
||||
|
||||
static Task<ServerTlsSettings> CertificateSelector(string hostName)
|
||||
{
|
||||
Assert.NotNull(hostName);
|
||||
Assert.Contains(hostName, SettingMap.Keys);
|
||||
return Task.FromResult(SettingMap[hostName]);
|
||||
}
|
||||
|
||||
static Task ReadOutboundAsync(Func<Task<IByteBuffer>> readFunc, int expectedBytes, IByteBuffer result, TimeSpan timeout)
|
||||
{
|
||||
Stopwatch stopwatch = Stopwatch.StartNew();
|
||||
int remaining = expectedBytes;
|
||||
return AssertEx.EventuallyAsync(
|
||||
async () =>
|
||||
{
|
||||
TimeSpan readTimeout = timeout - stopwatch.Elapsed;
|
||||
if (readTimeout <= TimeSpan.Zero)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
IByteBuffer output = await readFunc().WithTimeout(readTimeout);//inbound ? ch.ReadInbound<IByteBuffer>() : ch.ReadOutbound<IByteBuffer>();
|
||||
if (output != null)
|
||||
{
|
||||
remaining -= output.ReadableBytes;
|
||||
result.WriteBytes(output);
|
||||
}
|
||||
return remaining <= 0;
|
||||
},
|
||||
TimeSpan.FromMilliseconds(10),
|
||||
timeout);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -93,6 +93,7 @@ namespace DotNetty.Handlers.Tests
|
|||
IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024);
|
||||
await ReadOutboundAsync(async () => ch.ReadInbound<IByteBuffer>(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
|
||||
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
|
||||
driverStream.Dispose();
|
||||
}
|
||||
finally
|
||||
{
|
||||
|
@ -166,6 +167,7 @@ namespace DotNetty.Handlers.Tests
|
|||
},
|
||||
expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
|
||||
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
|
||||
driverStream.Dispose();
|
||||
}
|
||||
finally
|
||||
{
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<EmbeddedResource Include="..\..\shared\dotnetty.com.pfx" />
|
||||
<EmbeddedResource Include="..\..\shared\contoso.com.pfx" />
|
||||
</ItemGroup>
|
||||
<PropertyGroup Condition=" '$(TargetFramework)' == 'net45' ">
|
||||
<RuntimeIdentifier>win-x64</RuntimeIdentifier>
|
||||
|
|
|
@ -21,5 +21,18 @@ namespace DotNetty.Tests.Common
|
|||
|
||||
return new X509Certificate2(certData, "password");
|
||||
}
|
||||
|
||||
public static X509Certificate2 GetTestCertificate2()
|
||||
{
|
||||
byte[] certData;
|
||||
using (Stream resStream = typeof(TestResourceHelper).GetTypeInfo().Assembly.GetManifestResourceStream(typeof(TestResourceHelper).Namespace + "." + "contoso.com.pfx"))
|
||||
using (var memStream = new MemoryStream())
|
||||
{
|
||||
resStream.CopyTo(memStream);
|
||||
certData = memStream.ToArray();
|
||||
}
|
||||
|
||||
return new X509Certificate2(certData, "password");
|
||||
}
|
||||
}
|
||||
}
|
Загрузка…
Ссылка в новой задаче