From 32ebffa9c0a0a8ed5f37f422751558957ed9a387 Mon Sep 17 00:00:00 2001 From: xinchen Date: Sun, 20 Nov 2016 20:17:09 -0800 Subject: [PATCH] SASL PLAIN should throw fail on invalid credentials --- src/Net/TaskExtensions.cs | 8 +++- src/Sasl/SaslProfile.cs | 10 ++++- test/Common/ContainerHostTests.cs | 62 +++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 2 deletions(-) diff --git a/src/Net/TaskExtensions.cs b/src/Net/TaskExtensions.cs index 31d505d..2a5af1c 100644 --- a/src/Net/TaskExtensions.cs +++ b/src/Net/TaskExtensions.cs @@ -278,6 +278,7 @@ namespace Amqp ProtocolHeader myHeader = saslProfile.Start(hostname, writer); AsyncPump pump = new AsyncPump(bufferManager, transport); + SaslCode code = SaslCode.Auth; await pump.PumpAsync( header => @@ -287,12 +288,17 @@ namespace Amqp }, buffer => { - SaslCode code; return saslProfile.OnFrame(writer, buffer, out code); }); await writer.FlushAsync(); + if (code != SaslCode.Ok) + { + throw new AmqpException(ErrorCode.UnauthorizedAccess, + Fx.Format(SRAmqp.SaslNegoFailed, code)); + } + return (IAsyncTransport)saslProfile.UpgradeTransportInternal(transport); } } diff --git a/src/Sasl/SaslProfile.cs b/src/Sasl/SaslProfile.cs index 2868953..5ced82e 100644 --- a/src/Sasl/SaslProfile.cs +++ b/src/Sasl/SaslProfile.cs @@ -128,7 +128,15 @@ namespace Amqp.Sasl if (response != null) { this.SendCommand(transport, response); - shouldContinue = response.Descriptor.Code != Codec.SaslOutcome.Code; + if (response.Descriptor.Code == Codec.SaslOutcome.Code) + { + code = ((SaslOutcome)response).Code; + shouldContinue = false; + } + else + { + shouldContinue = true; + } } } diff --git a/test/Common/ContainerHostTests.cs b/test/Common/ContainerHostTests.cs index 08b4c20..ffceb48 100644 --- a/test/Common/ContainerHostTests.cs +++ b/test/Common/ContainerHostTests.cs @@ -27,6 +27,8 @@ using Amqp.Listener; using Amqp.Sasl; using Amqp.Types; using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Net.Sockets; +using System.Text; namespace Test.Amqp { @@ -735,6 +737,66 @@ namespace Test.Amqp Assert.IsTrue(listenerConnection.Principal.Identity.AuthenticationType == "PLAIN", "wrong auth type"); } + [TestMethod] + public void ContainerHostSaslPlainNegativeTest() + { + string address = new UriBuilder(this.Uri) { Password = "invalid" }.Uri.AbsoluteUri; + Trace.WriteLine(TraceLevel.Information, "sync test"); + { + try + { + var connection = new Connection(new Address(address)); + Assert.IsTrue(false, "Exception not thrown"); + } + catch (AmqpException ae) + { + Assert.AreEqual(ErrorCode.UnauthorizedAccess, ae.Error.Condition.ToString()); + } + } + + Trace.WriteLine(TraceLevel.Information, "async test"); + Task.Factory.StartNew(async () => + { + try + { + Connection connection = await Connection.Factory.CreateAsync(new Address(address)); + Assert.IsTrue(false, "Exception not thrown"); + } + catch (AmqpException ae) + { + Assert.AreEqual(ErrorCode.UnauthorizedAccess, ae.Error.Condition.ToString()); + } + }).Unwrap().GetAwaiter().GetResult(); + } + + [TestMethod] + public void ContainerHostListenerSaslPlainNegativeTest() + { + var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + socket.Connect(this.Uri.Host, this.Uri.Port); + var stream = new NetworkStream(socket); + + stream.Write(new byte[] { (byte)'A', (byte)'M', (byte)'Q', (byte)'P', 3, 1, 0, 0 }, 0, 8); + TestListener.FRM(stream, 0x41, 3, 0, new Symbol("PLAIN"), Encoding.ASCII.GetBytes("guest\0invalid")); + + byte[] buffer = new byte[1024]; + int total = 0; + int readSize = 0; + for (int i = 0; i < 1000; i++) + { + readSize = socket.Receive(buffer, 0, buffer.Length, SocketFlags.None); + if (readSize == 0) + { + break; + } + + total += readSize; + } + + Assert.IsTrue(total > 0, "No response received from listener"); + Assert.AreEqual(0, readSize, "last read should be 0 as socket should be closed"); + } + [TestMethod] public void ContainerHostX509PrincipalTest() {