Exposes SSL Stream and adds more TLS settings (#132)

Motivation:

Some important SSL Stream settings are hidden in the TlsHandler class

Modifications:
SSLStream is provided by user now via factory method;
TLS settings extended

Results:
More advanced scenarios, like X509 client authentication, are possible to do now
This commit is contained in:
Mikhail Tuhckov 2016-06-16 12:07:04 -07:00 коммит произвёл Max Gortman
Родитель 8ab8c8bf0c
Коммит fb18eaffd7
9 изменённых файлов: 160 добавлений и 52 удалений

Просмотреть файл

@ -6,6 +6,7 @@ namespace Echo.Client
using System;
using System.Diagnostics.Tracing;
using System.Net;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using DotNetty.Codecs;
@ -26,6 +27,13 @@ namespace Echo.Client
var group = new MultithreadEventLoopGroup();
X509Certificate2 cert = null;
string targetHost = null;
if (EchoClientSettings.IsSsl)
{
cert = new X509Certificate2("dotnetty.com.pfx", "password");
targetHost = cert.GetNameInfo(X509NameType.DnsName, false);
}
try
{
var bootstrap = new Bootstrap();
@ -37,11 +45,9 @@ namespace Echo.Client
{
IChannelPipeline pipeline = channel.Pipeline;
if (EchoClientSettings.IsSsl)
if (cert != null)
{
var cert = new X509Certificate2("dotnetty.com.pfx", "password");
string targetHost = cert.GetNameInfo(X509NameType.DnsName, false);
pipeline.AddLast(TlsHandler.Client(targetHost, null, (sender, certificate, chain, errors) => true));
pipeline.AddLast(new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ClientTlsSettings(targetHost)));
}
pipeline.AddLast(new LengthFieldPrepender(2));
pipeline.AddLast(new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));

Просмотреть файл

@ -5,6 +5,7 @@ namespace Echo.Server
{
using System;
using System.Diagnostics.Tracing;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using DotNetty.Codecs;
@ -26,6 +27,11 @@ namespace Echo.Server
var bossGroup = new MultithreadEventLoopGroup(1);
var workerGroup = new MultithreadEventLoopGroup();
X509Certificate2 tlsCertificate = null;
if (EchoServerSettings.IsSsl)
{
tlsCertificate = new X509Certificate2("dotnetty.com.pfx", "password");
}
try
{
var bootstrap = new ServerBootstrap();
@ -37,10 +43,9 @@ namespace Echo.Server
.ChildHandler(new ActionChannelInitializer<ISocketChannel>(channel =>
{
IChannelPipeline pipeline = channel.Pipeline;
if (EchoServerSettings.IsSsl)
if (tlsCertificate != null)
{
pipeline.AddLast(TlsHandler.Server(new X509Certificate2("dotnetty.com.pfx", "password")));
pipeline.AddLast(TlsHandler.Server(tlsCertificate));
}
pipeline.AddLast(new LengthFieldPrepender(2));
pipeline.AddLast(new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));

Просмотреть файл

@ -47,7 +47,9 @@
<Compile Include="Logging\LogLevel.cs" />
<Compile Include="Logging\LogLevelExtensions.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="Tls\ClientTlsSettings.cs" />
<Compile Include="Tls\NotSslRecordException.cs" />
<Compile Include="Tls\ServerTlsSettings.cs" />
<Compile Include="Tls\TlsHandshakeCompletionEvent.cs" />
<Compile Include="Tls\TlsHandler.cs" />
<Compile Include="Timeout\IdleState.cs" />
@ -58,6 +60,7 @@
<Compile Include="Timeout\WriteTimeoutException.cs" />
<Compile Include="Timeout\ReadTimeoutHandler.cs" />
<Compile Include="Timeout\WriteTimeoutHandler.cs" />
<Compile Include="Tls\TlsSettings.cs" />
<Compile Include="Tls\TlsUtils.cs" />
</ItemGroup>
<ItemGroup>

Просмотреть файл

@ -0,0 +1,44 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
namespace DotNetty.Handlers.Tls
{
using System.Collections.Generic;
using System.Linq;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
public sealed class ClientTlsSettings : TlsSettings
{
IReadOnlyCollection<X509Certificate2> certificates;
public ClientTlsSettings(string targetHost)
: this(targetHost, new List<X509Certificate>())
{
}
public ClientTlsSettings(string targetHost, List<X509Certificate> certificates)
: this(false, certificates, targetHost)
{
}
public ClientTlsSettings(bool checkCertificateRevocation, List<X509Certificate> certificates, string targetHost)
: this(SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, checkCertificateRevocation, certificates, targetHost)
{
}
public ClientTlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation, List<X509Certificate> certificates, string targetHost)
:base(enabledProtocols, checkCertificateRevocation)
{
this.X509CertificateCollection = new X509CertificateCollection(certificates.ToArray());
this.TargetHost = targetHost;
this.Certificates = certificates.AsReadOnly();
}
internal X509CertificateCollection X509CertificateCollection { get; set; }
public IReadOnlyCollection<X509Certificate> Certificates { get; }
public string TargetHost { get; }
}
}

Просмотреть файл

@ -0,0 +1,29 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
namespace DotNetty.Handlers.Tls
{
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
public sealed class ServerTlsSettings : TlsSettings
{
public ServerTlsSettings(X509Certificate certificate)
: this(false, certificate)
{
}
public ServerTlsSettings(bool checkCertificateRevocation, X509Certificate certificate)
: this(SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, checkCertificateRevocation, certificate)
{
}
public ServerTlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation, X509Certificate certificate)
: base(enabledProtocols, checkCertificateRevocation)
{
this.Certificate = certificate;
}
public X509Certificate Certificate { get; }
}
}

Просмотреть файл

@ -9,7 +9,6 @@ namespace DotNetty.Handlers.Tls
using System.IO;
using System.Net.Security;
using System.Runtime.ExceptionServices;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
@ -21,6 +20,7 @@ namespace DotNetty.Handlers.Tls
public sealed class TlsHandler : ByteToMessageDecoder
{
readonly TlsSettings settings;
const int FallbackReadBufferSize = 256;
const int UnencryptedWriteBatchSize = 14 * 1024;
@ -28,41 +28,39 @@ namespace DotNetty.Handlers.Tls
static readonly Action<Task, object> HandshakeCompletionCallback = new Action<Task, object>(HandleHandshakeCompleted);
readonly SslStream sslStream;
readonly MediationStream mediationStream;
readonly TaskCompletionSource closeFuture;
TlsHandlerState state;
int packetLength;
readonly MediationStream mediationStream;
volatile IChannelHandlerContext capturedContext;
BatchingPendingWriteQueue pendingUnencryptedWrites;
Task lastContextWriteTask;
readonly TaskCompletionSource closeFuture;
readonly bool isServer;
readonly X509Certificate2 certificate;
readonly string targetHost;
bool firedChannelRead;
IByteBuffer pendingSslStreamReadBuffer;
Task<int> pendingSslStreamReadFuture;
TlsHandler(bool isServer, X509Certificate2 certificate, string targetHost, RemoteCertificateValidationCallback certificateValidationCallback)
public TlsHandler(TlsSettings settings)
: this(stream => new SslStream(stream, true), settings)
{
Contract.Requires(!isServer || certificate != null);
Contract.Requires(isServer || !string.IsNullOrEmpty(targetHost));
this.closeFuture = new TaskCompletionSource();
this.isServer = isServer;
this.certificate = certificate;
this.targetHost = targetHost;
this.mediationStream = new MediationStream(this);
this.sslStream = new SslStream(this.mediationStream, true, certificateValidationCallback);
}
public static TlsHandler Client(string targetHost) => new TlsHandler(false, null, targetHost, null);
public TlsHandler(Func<Stream, SslStream> sslStreamFactory, TlsSettings settings)
{
Contract.Requires(sslStreamFactory != null);
Contract.Requires(settings != null);
public static TlsHandler Client(string targetHost, X509Certificate2 certificate) => new TlsHandler(false, certificate, targetHost, null);
this.settings = settings;
this.closeFuture = new TaskCompletionSource();
this.mediationStream = new MediationStream(this);
this.sslStream = sslStreamFactory(this.mediationStream);
}
public static TlsHandler Client(string targetHost, X509Certificate2 certificate, RemoteCertificateValidationCallback certificateValidationCallback) => new TlsHandler(false, certificate, targetHost, certificateValidationCallback);
public static TlsHandler Client(string targetHost) => new TlsHandler(new ClientTlsSettings(targetHost));
public static TlsHandler Server(X509Certificate2 certificate) => new TlsHandler(true, certificate, null, null);
public static TlsHandler Client(string targetHost, X509Certificate clientCertificate) => new TlsHandler(new ClientTlsSettings(targetHost, new List<X509Certificate>{ clientCertificate }));
public static TlsHandler Server(X509Certificate certificate) => new TlsHandler(new ServerTlsSettings(certificate));
public X509Certificate LocalCertificate => this.sslStream.LocalCertificate;
@ -74,7 +72,7 @@ namespace DotNetty.Handlers.Tls
{
base.ChannelActive(context);
if (!this.isServer)
if (this.settings is ServerTlsSettings)
{
this.EnsureAuthenticated();
}
@ -161,7 +159,7 @@ namespace DotNetty.Handlers.Tls
base.HandlerAdded(context);
this.capturedContext = context;
this.pendingUnencryptedWrites = new BatchingPendingWriteQueue(context, UnencryptedWriteBatchSize);
if (context.Channel.Active && !this.isServer)
if (context.Channel.Active && this.settings is ClientTlsSettings)
{
// todo: support delayed initialization on an existing/active channel if in client mode
this.EnsureAuthenticated();
@ -217,23 +215,23 @@ namespace DotNetty.Handlers.Tls
break;
}
int packetLength = TlsUtils.GetEncryptedPacketLength(input, offset);
if (packetLength == -1)
int encryptedPacketLength = TlsUtils.GetEncryptedPacketLength(input, offset);
if (encryptedPacketLength == -1)
{
nonSslRecord = true;
break;
}
Contract.Assert(packetLength > 0);
Contract.Assert(encryptedPacketLength > 0);
if (packetLength > readableBytes)
if (encryptedPacketLength > readableBytes)
{
// wait until the whole packet can be read
this.packetLength = packetLength;
this.packetLength = encryptedPacketLength;
break;
}
int newTotalLength = totalLength + packetLength;
int newTotalLength = totalLength + encryptedPacketLength;
if (newTotalLength > TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH)
{
// Don't read too much.
@ -245,8 +243,8 @@ namespace DotNetty.Handlers.Tls
// We have a whole packet.
// Increment the offset to handle the next packet.
packetLengths.Add(packetLength);
offset += packetLength;
packetLengths.Add(encryptedPacketLength);
offset += encryptedPacketLength;
totalLength = newTotalLength;
}
@ -482,19 +480,16 @@ namespace DotNetty.Handlers.Tls
if (!oldState.HasAny(TlsHandlerState.AuthenticationStarted))
{
this.state = oldState | TlsHandlerState.Authenticating;
if (this.isServer)
var serverSettings = settings as ServerTlsSettings;
if (serverSettings != null)
{
this.sslStream.AuthenticateAsServerAsync(this.certificate, false, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false) // todo: change to begin/end
this.sslStream.AuthenticateAsServerAsync(serverSettings.Certificate, false, serverSettings.EnabledProtocols, serverSettings.CheckCertificateRevocation)
.ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
}
else
{
var certificateCollection = new X509Certificate2Collection();
if (this.certificate != null)
{
certificateCollection.Add(this.certificate);
}
this.sslStream.AuthenticateAsClientAsync(this.targetHost, certificateCollection, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false) // todo: change to begin/end
var clientSettings = (ClientTlsSettings)settings;
this.sslStream.AuthenticateAsClientAsync(clientSettings.TargetHost, clientSettings.X509CertificateCollection, clientSettings.EnabledProtocols, clientSettings.CheckCertificateRevocation)
.ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
}
return false;

Просмотреть файл

@ -0,0 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
namespace DotNetty.Handlers.Tls
{
using System.Security.Authentication;
public abstract class TlsSettings
{
protected TlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation)
{
this.EnabledProtocols = enabledProtocols;
this.CheckCertificateRevocation = checkCertificateRevocation;
}
public SslProtocols EnabledProtocols { get; }
public bool CheckCertificateRevocation { get; }
}
}

Просмотреть файл

@ -154,7 +154,9 @@ namespace DotNetty.Handlers.Tests
{
var tlsCertificate = new X509Certificate2("dotnetty.com.pfx", "password");
string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false);
TlsHandler tlsHandler = isClient ? TlsHandler.Client(targetHost, null, (_1, _2, _3, _4) => true) : TlsHandler.Server(tlsCertificate);
TlsHandler tlsHandler = isClient ?
new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ClientTlsSettings(targetHost)) :
TlsHandler.Server(tlsCertificate);
//var ch = new EmbeddedChannel(new LoggingHandler("BEFORE"), tlsHandler, new LoggingHandler("AFTER"));
var ch = new EmbeddedChannel(tlsHandler);

Просмотреть файл

@ -7,6 +7,7 @@ namespace DotNetty.Tests.End2End
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading.Tasks;
@ -53,7 +54,7 @@ namespace DotNetty.Tests.End2End
{
ch.Pipeline.AddLast("server logger", new LoggingHandler("SERVER"));
ch.Pipeline.AddLast("server tls", TlsHandler.Server(tlsCertificate));
ch.Pipeline.AddLast("server logger2", new LoggingHandler("SER***"));
ch.Pipeline.AddLast("server logger2", new LoggingHandler("SER***"));
ch.Pipeline.AddLast("server prepender", new LengthFieldPrepender(2));
ch.Pipeline.AddLast("server decoder", new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));
ch.Pipeline.AddLast(new EchoChannelHandler());
@ -67,8 +68,9 @@ namespace DotNetty.Tests.End2End
.Handler(new ActionChannelInitializer<ISocketChannel>(ch =>
{
string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false);
var clientTlsSettings = new ClientTlsSettings(targetHost);
ch.Pipeline.AddLast("client logger", new LoggingHandler("CLIENT"));
ch.Pipeline.AddLast("client tls", TlsHandler.Client(targetHost, null, (sender, certificate, chain, errors) => true));
ch.Pipeline.AddLast("client tls", new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), clientTlsSettings));
ch.Pipeline.AddLast("client logger2", new LoggingHandler("CLI***"));
ch.Pipeline.AddLast("client prepender", new LengthFieldPrepender(2));
ch.Pipeline.AddLast("client decoder", new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));
@ -109,7 +111,7 @@ namespace DotNetty.Tests.End2End
Func<Task> closeServerFunc = await this.StartServerAsync(true, ch =>
{
ch.Pipeline.AddLast("server logger", new LoggingHandler("SERVER"));
ch.Pipeline.AddLast("client tls", TlsHandler.Server(tlsCertificate));
ch.Pipeline.AddLast("server tls", TlsHandler.Server(tlsCertificate));
ch.Pipeline.AddLast("server logger2", new LoggingHandler("SER***"));
ch.Pipeline.AddLast(
MqttEncoder.Instance,
@ -124,9 +126,11 @@ namespace DotNetty.Tests.End2End
.Option(ChannelOption.TcpNodelay, true)
.Handler(new ActionChannelInitializer<ISocketChannel>(ch =>
{
ch.Pipeline.AddLast("client logger", new LoggingHandler("CLIENT"));
string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false);
ch.Pipeline.AddLast("client tls", TlsHandler.Client(targetHost, null, (sender, certificate, chain, errors) => true));
var clientTlsSettings = new ClientTlsSettings(targetHost);
ch.Pipeline.AddLast("client logger", new LoggingHandler("CLIENT"));
ch.Pipeline.AddLast("client tls", new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), clientTlsSettings));
ch.Pipeline.AddLast("client logger2", new LoggingHandler("CLI***"));
ch.Pipeline.AddLast(
MqttEncoder.Instance,