From abf0504cc9c6306e5da429d0e9eaec34873d97ae Mon Sep 17 00:00:00 2001 From: xinchen Date: Thu, 20 Jun 2019 15:40:02 -0700 Subject: [PATCH] Fix IndexOutOfRangeException. --- src/Connection.cs | 24 ++++++++++---- src/Listener/ListenerConnection.cs | 7 ++-- src/Listener/ListenerSession.cs | 2 ++ src/Session.cs | 22 +++++++++---- test/Common/ProtocolTests.cs | 53 ++++++++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 18 deletions(-) diff --git a/src/Connection.cs b/src/Connection.cs index 624ed18..308e25e 100644 --- a/src/Connection.cs +++ b/src/Connection.cs @@ -488,21 +488,22 @@ namespace Amqp internal virtual void OnBegin(ushort remoteChannel, Begin begin) { + this.ValidateChannel(remoteChannel); + lock (this.ThisLock) { - if (remoteChannel > this.channelMax) - { - throw new AmqpException(ErrorCode.NotAllowed, - Fx.Format(SRAmqp.AmqpHandleExceeded, this.channelMax + 1)); - } - Session session = this.GetSession(this.localSessions, begin.RemoteChannel); session.OnBegin(remoteChannel, begin); int count = this.remoteSessions.Length; if (count - 1 < remoteChannel) { - int size = Math.Min(count * 2, this.channelMax + 1); + int size = count * 2; + while (size - 1 < remoteChannel) + { + size *= 2; + } + Session[] expanded = new Session[size]; Array.Copy(this.remoteSessions, expanded, count); this.remoteSessions = expanded; @@ -519,6 +520,15 @@ namespace Amqp } } + internal void ValidateChannel(ushort channel) + { + if (channel > this.channelMax) + { + throw new AmqpException(ErrorCode.NotAllowed, + Fx.Format(SRAmqp.AmqpHandleExceeded, this.channelMax + 1)); + } + } + void OnEnd(ushort remoteChannel, End end) { Session session = this.GetSession(this.remoteSessions, remoteChannel); diff --git a/src/Listener/ListenerConnection.cs b/src/Listener/ListenerConnection.cs index 122e26f..28e72aa 100644 --- a/src/Listener/ListenerConnection.cs +++ b/src/Listener/ListenerConnection.cs @@ -59,6 +59,8 @@ namespace Amqp.Listener internal override void OnBegin(ushort remoteChannel, Begin begin) { + this.ValidateChannel(remoteChannel); + // this sends a begin to the remote peer Begin local = new Begin() { @@ -69,11 +71,6 @@ namespace Amqp.Listener HandleMax = (uint)(this.listener.AMQP.MaxLinksPerSession - 1) }; - if (begin.HandleMax < local.HandleMax) - { - local.HandleMax = begin.HandleMax; - } - var session = new ListenerSession(this, local); // this updates the local session state diff --git a/src/Listener/ListenerSession.cs b/src/Listener/ListenerSession.cs index e8c1f94..a09fe47 100644 --- a/src/Listener/ListenerSession.cs +++ b/src/Listener/ListenerSession.cs @@ -33,6 +33,8 @@ namespace Amqp.Listener internal override void OnAttach(Attach attach) { + this.ValidateHandle(attach.Handle); + var connection = (ListenerConnection)this.Connection; Link link = connection.Listener.Container.CreateLink(connection, this, attach); this.AddRemoteLink(attach.Handle, link); diff --git a/src/Session.cs b/src/Session.cs index 8d5b357..5c17f84 100644 --- a/src/Session.cs +++ b/src/Session.cs @@ -382,11 +382,7 @@ namespace Amqp internal virtual void OnAttach(Attach attach) { - if (attach.Handle > this.handleMax) - { - throw new AmqpException(ErrorCode.NotAllowed, - Fx.Format(SRAmqp.AmqpHandleExceeded, this.handleMax + 1)); - } + this.ValidateHandle(attach.Handle); Link link = null; lock (this.ThisLock) @@ -419,7 +415,12 @@ namespace Amqp int count = this.remoteLinks.Length; if (count - 1 < remoteHandle) { - int size = (int)Math.Min(count * 2 - 1, this.handleMax) + 1; + int size = count * 2; + while (size - 1 < remoteHandle) + { + size *= 2; + } + Link[] expanded = new Link[size]; Array.Copy(this.remoteLinks, expanded, count); this.remoteLinks = expanded; @@ -447,6 +448,15 @@ namespace Amqp }; } + internal void ValidateHandle(uint handle) + { + if (handle > this.handleMax) + { + throw new AmqpException(ErrorCode.NotAllowed, + Fx.Format(SRAmqp.AmqpHandleExceeded, this.handleMax + 1)); + } + } + internal Delivery RemoveDeliveries(Link link) { LinkedList list = null; diff --git a/test/Common/ProtocolTests.cs b/test/Common/ProtocolTests.cs index 2ad22de..aec3ed6 100644 --- a/test/Common/ProtocolTests.cs +++ b/test/Common/ProtocolTests.cs @@ -99,6 +99,59 @@ namespace Test.Amqp }).Unwrap().GetAwaiter().GetResult(); } + [TestMethod] + public void RemoteSessionChannelTest() + { + this.testListener.RegisterTarget(TestPoint.Begin, (stream, channel, fields) => + { + // send a large channel number to test if client can grow the table correctly + TestListener.FRM(stream, 0x11UL, 0, (ushort)(channel + 100), channel, 0u, 100u, 100u, 8u); + return TestOutcome.Stop; + }); + + string testName = "ConnectionChannelTest"; + + Open open = new Open() { ContainerId = testName, HostName = "localhost", MaxFrameSize = 2048 }; + Connection connection = new Connection(this.address, null, open, null); + for (int i = 0; i < 10; i++) + { + Session session = new Session(connection); + } + + connection.Close(); + Assert.IsTrue(connection.Error == null, "connection has error!" + connection.Error); + } + + [TestMethod] + public void RemoteLinkHandleTest() + { + this.testListener.RegisterTarget(TestPoint.Begin, (stream, channel, fields) => + { + TestListener.FRM(stream, 0x11UL, 0, channel, channel, 0u, 100u, 100u, 8000u); + return TestOutcome.Stop; + }); + + this.testListener.RegisterTarget(TestPoint.Attach, (stream, channel, fields) => + { + uint handle = (uint)fields[1]; + fields[1] = handle + 100u; + return TestOutcome.Continue; + }); + + string testName = "RemoteLinkHandleTest"; + + Open open = new Open() { ContainerId = testName, HostName = "localhost", MaxFrameSize = 2048 }; + Connection connection = new Connection(this.address, null, open, null); + Session session = new Session(connection); + for (int i = 0; i < 10; i++) + { + SenderLink sender = new SenderLink(session, "sender-" + i, "any"); + } + + connection.Close(); + Assert.IsTrue(connection.Error == null, "connection has error!" + connection.Error); + } + [TestMethod] public void ConnectionRemoteIdleTimeoutTest() {