[#590] ConnectionFactory: support CancellationToken

This commit is contained in:
xinchen 2024-09-04 17:09:53 -07:00
Родитель 1362ac56b2
Коммит d5a0fcb443
12 изменённых файлов: 154 добавлений и 151 удалений

Просмотреть файл

@ -36,6 +36,9 @@
<StartupObject />
</PropertyGroup>
<ItemGroup>
<Compile Include="..\..\..\test\Common\Extensions.cs">
<Link>Extensions.cs</Link>
</Compile>
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="..\..\..\test\Common\TestAmqpBroker.cs">
<Link>TestAmqpBroker.cs</Link>

Просмотреть файл

@ -42,6 +42,9 @@
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
<Compile Include="..\..\..\test\Common\Extensions.cs">
<Link>Extensions.cs</Link>
</Compile>
<Compile Include="Program.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>

Просмотреть файл

@ -37,7 +37,7 @@ namespace PeerToPeer.Certificate
Console.WriteLine("Starting server...");
ContainerHost host = new ContainerHost(address);
var listener = host.Listeners[0];
listener.SSL.Certificate = GetCertificate("localhost");
listener.SSL.Certificate = Test.Common.Extensions.GetCertificate("localhost");
listener.SSL.ClientCertificateRequired = true;
listener.SSL.RemoteCertificateValidationCallback = ValidateServerCertificate;
listener.SASL.EnableExternalMechanism = true;
@ -50,7 +50,7 @@ namespace PeerToPeer.Certificate
Console.WriteLine("Starting client...");
ConnectionFactory factory = new ConnectionFactory();
factory.SSL.ClientCertificates.Add(GetCertificate("localhost"));
factory.SSL.ClientCertificates.Add(Test.Common.Extensions.GetCertificate("localhost"));
factory.SSL.RemoteCertificateValidationCallback = ValidateServerCertificate;
factory.SASL.Profile = SaslProfile.External;
Console.WriteLine("Sending message...");
@ -73,38 +73,6 @@ namespace PeerToPeer.Certificate
return true;
}
static X509Certificate2 GetCertificate(string certFindValue)
{
StoreLocation[] locations = new StoreLocation[] { StoreLocation.LocalMachine, StoreLocation.CurrentUser };
foreach (StoreLocation location in locations)
{
X509Store store = new X509Store(StoreName.My, location);
store.Open(OpenFlags.OpenExistingOnly);
X509Certificate2Collection collection = store.Certificates.Find(
X509FindType.FindBySubjectName,
certFindValue,
false);
if (collection.Count == 0)
{
collection = store.Certificates.Find(
X509FindType.FindByThumbprint,
certFindValue,
false);
}
store.Close();
if (collection.Count > 0)
{
return collection[0];
}
}
throw new ArgumentException("No certificate can be found using the find value " + certFindValue);
}
class MessageProcessor : IMessageProcessor
{
int IMessageProcessor.Credit

Просмотреть файл

@ -22,6 +22,7 @@ namespace Amqp
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Amqp.Framing;
using Amqp.Handler;
@ -103,7 +104,7 @@ namespace Amqp
/// <returns>A task for the connection creation operation. On success, the result is an AMQP <see cref="Connection"/></returns>
public Task<Connection> CreateAsync(Address address, IHandler handler)
{
return this.CreateAsync(address, null, null, handler);
return this.CreateAsync(address, null, null, handler, CancellationToken.None);
}
/// <summary>
@ -116,7 +117,19 @@ namespace Amqp
/// <remarks>The Open object, when provided, is used as is, and not augmented by the AMQP settings.</remarks>
public Task<Connection> CreateAsync(Address address, Open open = null, OnOpened onOpened = null)
{
return this.CreateAsync(address, open, onOpened, null);
return this.CreateAsync(address, open, onOpened, null, CancellationToken.None);
}
/// <summary>
/// Creates a new connection with an optional protocol handler.
/// </summary>
/// <param name="address">The address of remote endpoint to connect to.</param>
/// <param name="cancellationToken">The cancellation token associated with the async operation.</param>
/// <param name="handler">The protocol handler.</param>
/// <returns>A task for the connection creation operation. On success, the result is an AMQP <see cref="Connection"/></returns>
public Task<Connection> CreateAsync(Address address, CancellationToken cancellationToken, IHandler handler = null)
{
return this.CreateAsync(address, null, null, handler, cancellationToken);
}
internal async Task ConnectAsync(Address address, SaslProfile saslProfile, Open open, Connection connection)
@ -133,14 +146,14 @@ namespace Amqp
}
}
IAsyncTransport transport = await this.CreateTransportAsync(address, saslProfile, connection.Handler).ConfigureAwait(false);
IAsyncTransport transport = await this.CreateTransportAsync(address, saslProfile, connection.Handler, CancellationToken.None).ConfigureAwait(false);
connection.Init(this.BufferManager, this.AMQP, transport, open);
AsyncPump pump = new AsyncPump(this.BufferManager, transport);
pump.Start(connection);
}
async Task<IAsyncTransport> CreateTransportAsync(Address address, SaslProfile saslProfile, IHandler handler)
async Task<IAsyncTransport> CreateTransportAsync(Address address, SaslProfile saslProfile, IHandler handler, CancellationToken cancellationToken)
{
IAsyncTransport transport;
TransportProvider provider;
@ -151,7 +164,7 @@ namespace Amqp
else if (TcpTransport.MatchScheme(address.Scheme))
{
TcpTransport tcpTransport = new TcpTransport(this.BufferManager);
await tcpTransport.ConnectAsync(address, this, handler).ConfigureAwait(false);
await tcpTransport.ConnectAsync(address, this, handler, cancellationToken).ConfigureAwait(false);
transport = tcpTransport;
}
#if NETFX
@ -183,7 +196,7 @@ namespace Amqp
return transport;
}
async Task<Connection> CreateAsync(Address address, Open open, OnOpened onOpened, IHandler handler)
async Task<Connection> CreateAsync(Address address, Open open, OnOpened onOpened, IHandler handler, CancellationToken cancellationToken)
{
SaslProfile saslProfile = null;
if (address.User != null)
@ -195,9 +208,9 @@ namespace Amqp
saslProfile = this.saslSettings.Profile;
}
IAsyncTransport transport = await this.CreateTransportAsync(address, saslProfile, handler).ConfigureAwait(false);
IAsyncTransport transport = await this.CreateTransportAsync(address, saslProfile, handler, cancellationToken).ConfigureAwait(false);
var tcs = new ConnectTaskCompletionSource(this, address, open, onOpened, handler, transport);
var tcs = new ConnectTaskCompletionSource(this, address, open, onOpened, handler, transport, cancellationToken);
return await tcs.Task.ConfigureAwait(false);
}
@ -283,12 +296,20 @@ namespace Amqp
{
readonly ConnectionFactory factory;
readonly OnOpened onOpened;
readonly IAsyncTransport transport;
readonly CancellationTokenRegistration ctr;
Connection connection;
public ConnectTaskCompletionSource(ConnectionFactory factory, Address address, Open open, OnOpened onOpened, IHandler handler, IAsyncTransport transport)
public ConnectTaskCompletionSource(ConnectionFactory factory, Address address, Open open,
OnOpened onOpened, IHandler handler, IAsyncTransport transport, CancellationToken cancellationToken)
{
this.factory = factory;
this.onOpened = onOpened;
this.transport = transport;
if (cancellationToken.CanBeCanceled)
{
this.ctr = cancellationToken.Register(o => ((ConnectTaskCompletionSource)o).OnCancel(), this);
}
this.connection = new Connection(this.factory.BufferManager, this.factory.AMQP, address, transport, open, this.OnOpen, handler);
AsyncPump pump = new AsyncPump(this.factory.BufferManager, transport);
@ -297,16 +318,30 @@ namespace Amqp
void OnOpen(IConnection connection, Open open)
{
this.ctr.Dispose();
if (this.onOpened != null)
{
this.onOpened(connection, open);
}
this.TrySetResult(this.connection);
if (!this.TrySetResult(this.connection))
{
this.transport.Close();
}
}
void OnCancel()
{
this.ctr.Dispose();
if (this.TrySetCanceled())
{
this.transport.Close();
}
}
void OnException(Exception exception)
{
this.ctr.Dispose();
this.TrySetException(exception);
}
}

Просмотреть файл

@ -21,6 +21,7 @@ namespace Amqp
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
static class SocketExtensions
@ -64,26 +65,28 @@ namespace Amqp
public static void Complete<T>(object sender, SocketAsyncEventArgs args, bool throwOnError, T result)
{
var tcs = (TaskCompletionSource<T>)args.UserToken;
args.UserToken = null;
if (tcs == null)
using (var tcs = (SocketTaskCompletionSource<T>)args.UserToken)
{
return;
}
args.UserToken = null;
if (tcs == null)
{
return;
}
if (args.SocketError != SocketError.Success && throwOnError)
{
tcs.TrySetException(new SocketException((int)args.SocketError));
}
else
{
tcs.TrySetResult(result);
if (args.SocketError != SocketError.Success && throwOnError)
{
tcs.TrySetException(new SocketException((int)args.SocketError));
}
else
{
tcs.TrySetResult(result);
}
}
}
public static Task ConnectAsync(this Socket socket, IPAddress addr, int port)
public static Task ConnectAsync(this Socket socket, IPAddress addr, int port, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<int>();
var tcs = new SocketTaskCompletionSource<int>(cancellationToken);
var args = new SocketAsyncEventArgs();
args.RemoteEndPoint = new IPEndPoint(addr, port);
args.UserToken = tcs;
@ -99,7 +102,7 @@ namespace Amqp
public static Task<int> ReceiveAsync(this Socket socket, SocketAsyncEventArgs args, byte[] buffer, int offset, int count)
{
var tcs = new TaskCompletionSource<int>();
var tcs = new SocketTaskCompletionSource<int>(CancellationToken.None);
args.SetBuffer(buffer, offset, count);
args.UserToken = tcs;
if (!socket.ReceiveAsync(args))
@ -112,7 +115,7 @@ namespace Amqp
public static Task<int> SendAsync(this Socket socket, SocketAsyncEventArgs args, IList<ArraySegment<byte>> buffers)
{
var tcs = new TaskCompletionSource<int>();
var tcs = new SocketTaskCompletionSource<int>(CancellationToken.None);
args.SetBuffer(null, 0, 0);
args.BufferList = buffers;
args.UserToken = tcs;
@ -126,7 +129,7 @@ namespace Amqp
public static Task<Socket> AcceptAsync(this Socket socket, SocketAsyncEventArgs args, SocketFlags flags)
{
var tcs = new TaskCompletionSource<Socket>();
var tcs = new SocketTaskCompletionSource<Socket>(CancellationToken.None);
args.UserToken = tcs;
if (!socket.AcceptAsync(args))
{
@ -135,5 +138,30 @@ namespace Amqp
return tcs.Task;
}
sealed class SocketTaskCompletionSource<T> : TaskCompletionSource<T>, IDisposable
{
readonly CancellationTokenRegistration ctr;
public SocketTaskCompletionSource(CancellationToken ct)
{
if (ct.CanBeCanceled)
{
this.ctr = ct.Register(o => OnCancel(o), this);
}
}
public void Dispose()
{
this.ctr.Dispose();
}
static void OnCancel(object state)
{
var thisPtr = (SocketTaskCompletionSource<T>)state;
thisPtr.ctr.Dispose();
thisPtr.TrySetCanceled();
}
}
}
}

Просмотреть файл

@ -22,6 +22,7 @@ namespace Amqp
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Amqp.Handler;
@ -57,10 +58,10 @@ namespace Amqp
factory.SSL.RemoteCertificateValidationCallback = noneCertValidator;
}
this.ConnectAsync(address, factory, connection.Handler).ConfigureAwait(false).GetAwaiter().GetResult();
this.ConnectAsync(address, factory, connection.Handler, CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult();
}
public async Task ConnectAsync(Address address, ConnectionFactory factory, IHandler handler)
public async Task ConnectAsync(Address address, ConnectionFactory factory, IHandler handler, CancellationToken cancellationToken)
{
IPAddress[] ipAddresses;
IPAddress ip;
@ -88,7 +89,7 @@ namespace Amqp
socket = new Socket(ipAddresses[i].AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try
{
await socket.ConnectAsync(ipAddresses[i], address.Port).ConfigureAwait(false);
await socket.ConnectAsync(ipAddresses[i], address.Port, cancellationToken).ConfigureAwait(false);
exception = null;
break;

Просмотреть файл

@ -1189,7 +1189,7 @@ namespace Test.Amqp
try
{
cert = GetCertificate(StoreLocation.LocalMachine, StoreName.My, "localhost");
cert = Test.Common.Extensions.GetCertificate("localhost");
}
catch (PlatformNotSupportedException)
{
@ -1456,27 +1456,6 @@ namespace Test.Amqp
Assert.AreEqual((message.BodySection as AmqpValue).Value, (copy.BodySection as AmqpValue).Value);
}
public static X509Certificate2 GetCertificate(StoreLocation storeLocation, StoreName storeName, string certFindValue)
{
X509Store store = new X509Store(storeName, storeLocation);
store.Open(OpenFlags.OpenExistingOnly);
X509Certificate2Collection collection = store.Certificates.Find(
X509FindType.FindBySubjectName,
certFindValue,
false);
if (collection.Count == 0)
{
throw new ArgumentException("No certificate can be found using the find value " + certFindValue);
}
#if DOTNET
store.Dispose();
#else
store.Close();
#endif
return collection[0];
}
}
class TestMessageProcessor : IMessageProcessor

Просмотреть файл

@ -55,41 +55,47 @@ namespace Test.Common
return null;
}
if (certFindValue == null)
return GetCertificate(certFindValue ?? host);
}
public static X509Certificate2 GetCertificate(string certFindValue)
{
if (TryGetCertificate(StoreLocation.CurrentUser, StoreName.My, certFindValue, out X509Certificate2 cert))
{
certFindValue = host;
return cert;
}
StoreLocation[] locations = new StoreLocation[] { StoreLocation.LocalMachine, StoreLocation.CurrentUser };
foreach (StoreLocation location in locations)
if (TryGetCertificate(StoreLocation.LocalMachine, StoreName.My, certFindValue, out cert))
{
X509Store store = new X509Store(StoreName.My, location);
store.Open(OpenFlags.OpenExistingOnly);
X509Certificate2Collection collection = store.Certificates.Find(
X509FindType.FindBySubjectName,
certFindValue,
false);
if (collection.Count == 0)
{
collection = store.Certificates.Find(
X509FindType.FindByThumbprint,
certFindValue,
false);
}
store.Close();
if (collection.Count > 0)
{
return collection[0];
}
return cert;
}
throw new ArgumentException("No certificate can be found using the find value " + certFindValue);
}
public static bool TryGetCertificate(StoreLocation storeLocation, StoreName storeName, string certFindValue, out X509Certificate2 cert)
{
X509Store store = new X509Store(storeName, storeLocation);
store.Open(OpenFlags.OpenExistingOnly);
X509Certificate2Collection collection = store.Certificates.Find(
X509FindType.FindBySubjectName,
certFindValue,
true);
if (collection.Count == 0)
{
cert = null;
return false;
}
#if DOTNET
store.Dispose();
#else
store.Close();
#endif
cert = collection[0];
return true;
}
static Dictionary<string, TraceLevel> GetTraceMapping()
{
return new Dictionary<string, TraceLevel>()

Просмотреть файл

@ -55,7 +55,7 @@ namespace Listener.IContainer
this.implicitQueue = true;
}
this.certificate = certValue == null ? null : GetCertificate(certValue);
this.certificate = certValue == null ? null : Test.Common.Extensions.GetCertificate(certValue);
string containerId = "AMQPNetLite-TestBroker-" + Guid.NewGuid().ToString().Substring(0, 8);
this.listeners = new ConnectionListener[endpoints.Count];
@ -124,41 +124,6 @@ namespace Listener.IContainer
}
}
static X509Certificate2 GetCertificate(string certFindValue)
{
StoreLocation[] locations = new StoreLocation[] { StoreLocation.LocalMachine, StoreLocation.CurrentUser };
foreach (StoreLocation location in locations)
{
X509Store store = new X509Store(StoreName.My, location);
store.Open(OpenFlags.OpenExistingOnly);
X509Certificate2Collection collection = store.Certificates.Find(
X509FindType.FindBySubjectName,
certFindValue,
false);
if (collection.Count == 0)
{
collection = store.Certificates.Find(
X509FindType.FindByThumbprint,
certFindValue,
false);
}
#if DOTNET
store.Dispose();
#else
store.Close();
#endif
if (collection.Count > 0)
{
return collection[0];
}
}
throw new ArgumentException("No certificate can be found using the find value " + certFindValue);
}
X509Certificate2 IContainer.ServiceCertificate
{
get { return this.certificate; }

Просмотреть файл

@ -85,7 +85,7 @@ namespace Test.Amqp
string testName = "WebSocketSslMutalAuthTest";
Address listenAddress = new Address("wss://localhost:18081/" + testName + "/");
X509Certificate2 cert = ContainerHostTests.GetCertificate(StoreLocation.LocalMachine, StoreName.My, "localhost");
X509Certificate2 cert = Test.Common.Extensions.GetCertificate("localhost");
string output;
int code = Exec("netsh.exe", string.Format("http show sslcert hostnameport={0}:{1}", listenAddress.Host, listenAddress.Port), out output);
@ -117,7 +117,7 @@ namespace Test.Amqp
var wssFactory = new WebSocketTransportFactory();
wssFactory.Options = o =>
{
o.ClientCertificates.Add(ContainerHostTests.GetCertificate(StoreLocation.LocalMachine, StoreName.My, listenAddress.Host));
o.ClientCertificates.Add(Test.Common.Extensions.GetCertificate(listenAddress.Host));
};
ConnectionFactory connectionFactory = new ConnectionFactory(new TransportProvider[] { wssFactory });

Просмотреть файл

@ -168,6 +168,21 @@ namespace Test.Amqp
await connection.CloseAsync();
}
[TestMethod]
public async Task ConnectionFactoryCancelAsync()
{
// Invalid address for test
var address = new Address("amqp://192.0.2.3:5672");
try
{
Connection connection = await Connection.Factory.CreateAsync(address, new CancellationTokenSource(300).Token);
Assert.IsTrue(false, "Connection creation should have been cancelled.");
}
catch (TaskCanceledException)
{
}
}
[TestMethod]
public async Task ReceiverSenderAsync()
{

Просмотреть файл

@ -6,7 +6,7 @@
<CheckEolTargetFramework>false</CheckEolTargetFramework>
</PropertyGroup>
<ItemGroup>
<Compile Include="../TestAmqpBroker/Program.cs;../Common/TestAmqpBroker.cs" />
<Compile Include="../TestAmqpBroker/Program.cs;../Common/Extensions.cs;../Common/TestAmqpBroker.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Amqp.csproj" />