зеркало из https://github.com/Azure/DotNetty.git
Exposes SSL Stream and adds more TLS settings (#132)
Motivation: Some important SSL Stream settings are hidden in the TlsHandler class Modifications: SSLStream is provided by user now via factory method; TLS settings extended Results: More advanced scenarios, like X509 client authentication, are possible to do now
This commit is contained in:
Родитель
8ab8c8bf0c
Коммит
fb18eaffd7
|
@ -6,6 +6,7 @@ namespace Echo.Client
|
|||
using System;
|
||||
using System.Diagnostics.Tracing;
|
||||
using System.Net;
|
||||
using System.Net.Security;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Threading.Tasks;
|
||||
using DotNetty.Codecs;
|
||||
|
@ -26,6 +27,13 @@ namespace Echo.Client
|
|||
|
||||
var group = new MultithreadEventLoopGroup();
|
||||
|
||||
X509Certificate2 cert = null;
|
||||
string targetHost = null;
|
||||
if (EchoClientSettings.IsSsl)
|
||||
{
|
||||
cert = new X509Certificate2("dotnetty.com.pfx", "password");
|
||||
targetHost = cert.GetNameInfo(X509NameType.DnsName, false);
|
||||
}
|
||||
try
|
||||
{
|
||||
var bootstrap = new Bootstrap();
|
||||
|
@ -37,11 +45,9 @@ namespace Echo.Client
|
|||
{
|
||||
IChannelPipeline pipeline = channel.Pipeline;
|
||||
|
||||
if (EchoClientSettings.IsSsl)
|
||||
if (cert != null)
|
||||
{
|
||||
var cert = new X509Certificate2("dotnetty.com.pfx", "password");
|
||||
string targetHost = cert.GetNameInfo(X509NameType.DnsName, false);
|
||||
pipeline.AddLast(TlsHandler.Client(targetHost, null, (sender, certificate, chain, errors) => true));
|
||||
pipeline.AddLast(new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ClientTlsSettings(targetHost)));
|
||||
}
|
||||
pipeline.AddLast(new LengthFieldPrepender(2));
|
||||
pipeline.AddLast(new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));
|
||||
|
|
|
@ -5,6 +5,7 @@ namespace Echo.Server
|
|||
{
|
||||
using System;
|
||||
using System.Diagnostics.Tracing;
|
||||
using System.Net.Security;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Threading.Tasks;
|
||||
using DotNetty.Codecs;
|
||||
|
@ -26,6 +27,11 @@ namespace Echo.Server
|
|||
|
||||
var bossGroup = new MultithreadEventLoopGroup(1);
|
||||
var workerGroup = new MultithreadEventLoopGroup();
|
||||
X509Certificate2 tlsCertificate = null;
|
||||
if (EchoServerSettings.IsSsl)
|
||||
{
|
||||
tlsCertificate = new X509Certificate2("dotnetty.com.pfx", "password");
|
||||
}
|
||||
try
|
||||
{
|
||||
var bootstrap = new ServerBootstrap();
|
||||
|
@ -37,10 +43,9 @@ namespace Echo.Server
|
|||
.ChildHandler(new ActionChannelInitializer<ISocketChannel>(channel =>
|
||||
{
|
||||
IChannelPipeline pipeline = channel.Pipeline;
|
||||
|
||||
if (EchoServerSettings.IsSsl)
|
||||
if (tlsCertificate != null)
|
||||
{
|
||||
pipeline.AddLast(TlsHandler.Server(new X509Certificate2("dotnetty.com.pfx", "password")));
|
||||
pipeline.AddLast(TlsHandler.Server(tlsCertificate));
|
||||
}
|
||||
pipeline.AddLast(new LengthFieldPrepender(2));
|
||||
pipeline.AddLast(new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));
|
||||
|
|
|
@ -47,7 +47,9 @@
|
|||
<Compile Include="Logging\LogLevel.cs" />
|
||||
<Compile Include="Logging\LogLevelExtensions.cs" />
|
||||
<Compile Include="Properties\AssemblyInfo.cs" />
|
||||
<Compile Include="Tls\ClientTlsSettings.cs" />
|
||||
<Compile Include="Tls\NotSslRecordException.cs" />
|
||||
<Compile Include="Tls\ServerTlsSettings.cs" />
|
||||
<Compile Include="Tls\TlsHandshakeCompletionEvent.cs" />
|
||||
<Compile Include="Tls\TlsHandler.cs" />
|
||||
<Compile Include="Timeout\IdleState.cs" />
|
||||
|
@ -58,6 +60,7 @@
|
|||
<Compile Include="Timeout\WriteTimeoutException.cs" />
|
||||
<Compile Include="Timeout\ReadTimeoutHandler.cs" />
|
||||
<Compile Include="Timeout\WriteTimeoutHandler.cs" />
|
||||
<Compile Include="Tls\TlsSettings.cs" />
|
||||
<Compile Include="Tls\TlsUtils.cs" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
|
||||
namespace DotNetty.Handlers.Tls
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Security.Authentication;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
|
||||
public sealed class ClientTlsSettings : TlsSettings
|
||||
{
|
||||
IReadOnlyCollection<X509Certificate2> certificates;
|
||||
|
||||
public ClientTlsSettings(string targetHost)
|
||||
: this(targetHost, new List<X509Certificate>())
|
||||
{
|
||||
}
|
||||
|
||||
public ClientTlsSettings(string targetHost, List<X509Certificate> certificates)
|
||||
: this(false, certificates, targetHost)
|
||||
{
|
||||
}
|
||||
|
||||
public ClientTlsSettings(bool checkCertificateRevocation, List<X509Certificate> certificates, string targetHost)
|
||||
: this(SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, checkCertificateRevocation, certificates, targetHost)
|
||||
{
|
||||
}
|
||||
|
||||
public ClientTlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation, List<X509Certificate> certificates, string targetHost)
|
||||
:base(enabledProtocols, checkCertificateRevocation)
|
||||
{
|
||||
this.X509CertificateCollection = new X509CertificateCollection(certificates.ToArray());
|
||||
this.TargetHost = targetHost;
|
||||
this.Certificates = certificates.AsReadOnly();
|
||||
}
|
||||
|
||||
internal X509CertificateCollection X509CertificateCollection { get; set; }
|
||||
|
||||
public IReadOnlyCollection<X509Certificate> Certificates { get; }
|
||||
|
||||
public string TargetHost { get; }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
|
||||
namespace DotNetty.Handlers.Tls
|
||||
{
|
||||
using System.Security.Authentication;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
|
||||
public sealed class ServerTlsSettings : TlsSettings
|
||||
{
|
||||
public ServerTlsSettings(X509Certificate certificate)
|
||||
: this(false, certificate)
|
||||
{
|
||||
}
|
||||
|
||||
public ServerTlsSettings(bool checkCertificateRevocation, X509Certificate certificate)
|
||||
: this(SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, checkCertificateRevocation, certificate)
|
||||
{
|
||||
}
|
||||
|
||||
public ServerTlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation, X509Certificate certificate)
|
||||
: base(enabledProtocols, checkCertificateRevocation)
|
||||
{
|
||||
this.Certificate = certificate;
|
||||
}
|
||||
|
||||
public X509Certificate Certificate { get; }
|
||||
}
|
||||
}
|
|
@ -9,7 +9,6 @@ namespace DotNetty.Handlers.Tls
|
|||
using System.IO;
|
||||
using System.Net.Security;
|
||||
using System.Runtime.ExceptionServices;
|
||||
using System.Security.Authentication;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
@ -21,6 +20,7 @@ namespace DotNetty.Handlers.Tls
|
|||
|
||||
public sealed class TlsHandler : ByteToMessageDecoder
|
||||
{
|
||||
readonly TlsSettings settings;
|
||||
const int FallbackReadBufferSize = 256;
|
||||
const int UnencryptedWriteBatchSize = 14 * 1024;
|
||||
|
||||
|
@ -28,41 +28,39 @@ namespace DotNetty.Handlers.Tls
|
|||
static readonly Action<Task, object> HandshakeCompletionCallback = new Action<Task, object>(HandleHandshakeCompleted);
|
||||
|
||||
readonly SslStream sslStream;
|
||||
readonly MediationStream mediationStream;
|
||||
readonly TaskCompletionSource closeFuture;
|
||||
|
||||
TlsHandlerState state;
|
||||
int packetLength;
|
||||
readonly MediationStream mediationStream;
|
||||
volatile IChannelHandlerContext capturedContext;
|
||||
BatchingPendingWriteQueue pendingUnencryptedWrites;
|
||||
Task lastContextWriteTask;
|
||||
readonly TaskCompletionSource closeFuture;
|
||||
readonly bool isServer;
|
||||
readonly X509Certificate2 certificate;
|
||||
readonly string targetHost;
|
||||
bool firedChannelRead;
|
||||
IByteBuffer pendingSslStreamReadBuffer;
|
||||
Task<int> pendingSslStreamReadFuture;
|
||||
|
||||
TlsHandler(bool isServer, X509Certificate2 certificate, string targetHost, RemoteCertificateValidationCallback certificateValidationCallback)
|
||||
public TlsHandler(TlsSettings settings)
|
||||
: this(stream => new SslStream(stream, true), settings)
|
||||
{
|
||||
Contract.Requires(!isServer || certificate != null);
|
||||
Contract.Requires(isServer || !string.IsNullOrEmpty(targetHost));
|
||||
|
||||
this.closeFuture = new TaskCompletionSource();
|
||||
|
||||
this.isServer = isServer;
|
||||
this.certificate = certificate;
|
||||
this.targetHost = targetHost;
|
||||
this.mediationStream = new MediationStream(this);
|
||||
this.sslStream = new SslStream(this.mediationStream, true, certificateValidationCallback);
|
||||
}
|
||||
|
||||
public static TlsHandler Client(string targetHost) => new TlsHandler(false, null, targetHost, null);
|
||||
public TlsHandler(Func<Stream, SslStream> sslStreamFactory, TlsSettings settings)
|
||||
{
|
||||
Contract.Requires(sslStreamFactory != null);
|
||||
Contract.Requires(settings != null);
|
||||
|
||||
public static TlsHandler Client(string targetHost, X509Certificate2 certificate) => new TlsHandler(false, certificate, targetHost, null);
|
||||
this.settings = settings;
|
||||
this.closeFuture = new TaskCompletionSource();
|
||||
this.mediationStream = new MediationStream(this);
|
||||
this.sslStream = sslStreamFactory(this.mediationStream);
|
||||
}
|
||||
|
||||
public static TlsHandler Client(string targetHost, X509Certificate2 certificate, RemoteCertificateValidationCallback certificateValidationCallback) => new TlsHandler(false, certificate, targetHost, certificateValidationCallback);
|
||||
public static TlsHandler Client(string targetHost) => new TlsHandler(new ClientTlsSettings(targetHost));
|
||||
|
||||
public static TlsHandler Server(X509Certificate2 certificate) => new TlsHandler(true, certificate, null, null);
|
||||
public static TlsHandler Client(string targetHost, X509Certificate clientCertificate) => new TlsHandler(new ClientTlsSettings(targetHost, new List<X509Certificate>{ clientCertificate }));
|
||||
|
||||
public static TlsHandler Server(X509Certificate certificate) => new TlsHandler(new ServerTlsSettings(certificate));
|
||||
|
||||
public X509Certificate LocalCertificate => this.sslStream.LocalCertificate;
|
||||
|
||||
|
@ -74,7 +72,7 @@ namespace DotNetty.Handlers.Tls
|
|||
{
|
||||
base.ChannelActive(context);
|
||||
|
||||
if (!this.isServer)
|
||||
if (this.settings is ServerTlsSettings)
|
||||
{
|
||||
this.EnsureAuthenticated();
|
||||
}
|
||||
|
@ -161,7 +159,7 @@ namespace DotNetty.Handlers.Tls
|
|||
base.HandlerAdded(context);
|
||||
this.capturedContext = context;
|
||||
this.pendingUnencryptedWrites = new BatchingPendingWriteQueue(context, UnencryptedWriteBatchSize);
|
||||
if (context.Channel.Active && !this.isServer)
|
||||
if (context.Channel.Active && this.settings is ClientTlsSettings)
|
||||
{
|
||||
// todo: support delayed initialization on an existing/active channel if in client mode
|
||||
this.EnsureAuthenticated();
|
||||
|
@ -217,23 +215,23 @@ namespace DotNetty.Handlers.Tls
|
|||
break;
|
||||
}
|
||||
|
||||
int packetLength = TlsUtils.GetEncryptedPacketLength(input, offset);
|
||||
if (packetLength == -1)
|
||||
int encryptedPacketLength = TlsUtils.GetEncryptedPacketLength(input, offset);
|
||||
if (encryptedPacketLength == -1)
|
||||
{
|
||||
nonSslRecord = true;
|
||||
break;
|
||||
}
|
||||
|
||||
Contract.Assert(packetLength > 0);
|
||||
Contract.Assert(encryptedPacketLength > 0);
|
||||
|
||||
if (packetLength > readableBytes)
|
||||
if (encryptedPacketLength > readableBytes)
|
||||
{
|
||||
// wait until the whole packet can be read
|
||||
this.packetLength = packetLength;
|
||||
this.packetLength = encryptedPacketLength;
|
||||
break;
|
||||
}
|
||||
|
||||
int newTotalLength = totalLength + packetLength;
|
||||
int newTotalLength = totalLength + encryptedPacketLength;
|
||||
if (newTotalLength > TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH)
|
||||
{
|
||||
// Don't read too much.
|
||||
|
@ -245,8 +243,8 @@ namespace DotNetty.Handlers.Tls
|
|||
|
||||
// We have a whole packet.
|
||||
// Increment the offset to handle the next packet.
|
||||
packetLengths.Add(packetLength);
|
||||
offset += packetLength;
|
||||
packetLengths.Add(encryptedPacketLength);
|
||||
offset += encryptedPacketLength;
|
||||
totalLength = newTotalLength;
|
||||
}
|
||||
|
||||
|
@ -482,19 +480,16 @@ namespace DotNetty.Handlers.Tls
|
|||
if (!oldState.HasAny(TlsHandlerState.AuthenticationStarted))
|
||||
{
|
||||
this.state = oldState | TlsHandlerState.Authenticating;
|
||||
if (this.isServer)
|
||||
var serverSettings = settings as ServerTlsSettings;
|
||||
if (serverSettings != null)
|
||||
{
|
||||
this.sslStream.AuthenticateAsServerAsync(this.certificate, false, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false) // todo: change to begin/end
|
||||
this.sslStream.AuthenticateAsServerAsync(serverSettings.Certificate, false, serverSettings.EnabledProtocols, serverSettings.CheckCertificateRevocation)
|
||||
.ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
|
||||
}
|
||||
else
|
||||
{
|
||||
var certificateCollection = new X509Certificate2Collection();
|
||||
if (this.certificate != null)
|
||||
{
|
||||
certificateCollection.Add(this.certificate);
|
||||
}
|
||||
this.sslStream.AuthenticateAsClientAsync(this.targetHost, certificateCollection, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false) // todo: change to begin/end
|
||||
var clientSettings = (ClientTlsSettings)settings;
|
||||
this.sslStream.AuthenticateAsClientAsync(clientSettings.TargetHost, clientSettings.X509CertificateCollection, clientSettings.EnabledProtocols, clientSettings.CheckCertificateRevocation)
|
||||
.ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
|
||||
}
|
||||
return false;
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
|
||||
namespace DotNetty.Handlers.Tls
|
||||
{
|
||||
using System.Security.Authentication;
|
||||
|
||||
public abstract class TlsSettings
|
||||
{
|
||||
protected TlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation)
|
||||
{
|
||||
this.EnabledProtocols = enabledProtocols;
|
||||
this.CheckCertificateRevocation = checkCertificateRevocation;
|
||||
}
|
||||
|
||||
public SslProtocols EnabledProtocols { get; }
|
||||
|
||||
public bool CheckCertificateRevocation { get; }
|
||||
}
|
||||
}
|
|
@ -154,7 +154,9 @@ namespace DotNetty.Handlers.Tests
|
|||
{
|
||||
var tlsCertificate = new X509Certificate2("dotnetty.com.pfx", "password");
|
||||
string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false);
|
||||
TlsHandler tlsHandler = isClient ? TlsHandler.Client(targetHost, null, (_1, _2, _3, _4) => true) : TlsHandler.Server(tlsCertificate);
|
||||
TlsHandler tlsHandler = isClient ?
|
||||
new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ClientTlsSettings(targetHost)) :
|
||||
TlsHandler.Server(tlsCertificate);
|
||||
//var ch = new EmbeddedChannel(new LoggingHandler("BEFORE"), tlsHandler, new LoggingHandler("AFTER"));
|
||||
var ch = new EmbeddedChannel(tlsHandler);
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ namespace DotNetty.Tests.End2End
|
|||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Net;
|
||||
using System.Net.Security;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
@ -53,7 +54,7 @@ namespace DotNetty.Tests.End2End
|
|||
{
|
||||
ch.Pipeline.AddLast("server logger", new LoggingHandler("SERVER"));
|
||||
ch.Pipeline.AddLast("server tls", TlsHandler.Server(tlsCertificate));
|
||||
ch.Pipeline.AddLast("server logger2", new LoggingHandler("SER***"));
|
||||
ch.Pipeline.AddLast("server logger2", new LoggingHandler("SER***"));
|
||||
ch.Pipeline.AddLast("server prepender", new LengthFieldPrepender(2));
|
||||
ch.Pipeline.AddLast("server decoder", new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));
|
||||
ch.Pipeline.AddLast(new EchoChannelHandler());
|
||||
|
@ -67,8 +68,9 @@ namespace DotNetty.Tests.End2End
|
|||
.Handler(new ActionChannelInitializer<ISocketChannel>(ch =>
|
||||
{
|
||||
string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false);
|
||||
var clientTlsSettings = new ClientTlsSettings(targetHost);
|
||||
ch.Pipeline.AddLast("client logger", new LoggingHandler("CLIENT"));
|
||||
ch.Pipeline.AddLast("client tls", TlsHandler.Client(targetHost, null, (sender, certificate, chain, errors) => true));
|
||||
ch.Pipeline.AddLast("client tls", new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), clientTlsSettings));
|
||||
ch.Pipeline.AddLast("client logger2", new LoggingHandler("CLI***"));
|
||||
ch.Pipeline.AddLast("client prepender", new LengthFieldPrepender(2));
|
||||
ch.Pipeline.AddLast("client decoder", new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));
|
||||
|
@ -109,7 +111,7 @@ namespace DotNetty.Tests.End2End
|
|||
Func<Task> closeServerFunc = await this.StartServerAsync(true, ch =>
|
||||
{
|
||||
ch.Pipeline.AddLast("server logger", new LoggingHandler("SERVER"));
|
||||
ch.Pipeline.AddLast("client tls", TlsHandler.Server(tlsCertificate));
|
||||
ch.Pipeline.AddLast("server tls", TlsHandler.Server(tlsCertificate));
|
||||
ch.Pipeline.AddLast("server logger2", new LoggingHandler("SER***"));
|
||||
ch.Pipeline.AddLast(
|
||||
MqttEncoder.Instance,
|
||||
|
@ -124,9 +126,11 @@ namespace DotNetty.Tests.End2End
|
|||
.Option(ChannelOption.TcpNodelay, true)
|
||||
.Handler(new ActionChannelInitializer<ISocketChannel>(ch =>
|
||||
{
|
||||
ch.Pipeline.AddLast("client logger", new LoggingHandler("CLIENT"));
|
||||
string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false);
|
||||
ch.Pipeline.AddLast("client tls", TlsHandler.Client(targetHost, null, (sender, certificate, chain, errors) => true));
|
||||
var clientTlsSettings = new ClientTlsSettings(targetHost);
|
||||
|
||||
ch.Pipeline.AddLast("client logger", new LoggingHandler("CLIENT"));
|
||||
ch.Pipeline.AddLast("client tls", new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), clientTlsSettings));
|
||||
ch.Pipeline.AddLast("client logger2", new LoggingHandler("CLI***"));
|
||||
ch.Pipeline.AddLast(
|
||||
MqttEncoder.Instance,
|
||||
|
|
Загрузка…
Ссылка в новой задаче