Fix channel number handling and add option to set channel-max.

* Correctly record and use remote vs local channel. This worked with a
  single channel because both sides would pick zero and there would be
  no mismatch. When multiple channels were used frames from the second
  session could be delivered to the first.
* Added `ConnMaxSessions()` to allow for setting the channel-max, which
  currently default to 65536.
* Does not currently support reusing sessions numbers after a session
  has been closed.
* Modified integration tests to allow testing multiple sessions.

Updates #6
This commit is contained in:
Kale Blankenship 2018-01-18 19:36:02 -08:00
Родитель a6e33075b4
Коммит 2210add7da
3 изменённых файлов: 164 добавлений и 96 удалений

Просмотреть файл

@ -78,13 +78,18 @@ func (c *Client) Close() error {
// NewSession opens a new AMQP session to the server.
func (c *Client) NewSession() (*Session, error) {
// get a session allocated by Client.mux
var s *Session
var sResp newSessionResp
select {
case <-c.conn.done:
return nil, c.conn.getErr()
case s = <-c.conn.newSession:
case sResp = <-c.conn.newSession:
}
if sResp.err != nil {
return nil, sResp.err
}
s := sResp.session
// send Begin to server
begin := &performBegin{
NextOutgoingID: 0,
@ -109,9 +114,6 @@ func (c *Client) NewSession() (*Session, error) {
return nil, errorErrorf("unexpected begin response: %+v", fr.body)
}
// TODO: record negotiated settings
s.remoteChannel = begin.RemoteChannel
// start Session multiplexor
go s.mux(begin)
@ -123,7 +125,7 @@ func (c *Client) NewSession() (*Session, error) {
// A session multiplexes Receivers.
type Session struct {
channel uint16 // session's local channel
remoteChannel uint16 // session's remote channel
remoteChannel uint16 // session's remote channel, owned by conn.mux
conn *conn // underlying conn
rx chan frame // frames destined for this session are sent on this chan by conn.mux
tx chan frameBody // non-transfer frames to be sent; session must track disposition
@ -167,7 +169,7 @@ func (s *Session) Close() error {
func (s *Session) txFrame(p frameBody, done chan struct{}) {
s.conn.wantWriteFrame(frame{
typ: frameTypeAMQP,
channel: s.remoteChannel,
channel: s.channel,
body: p,
done: done,
})

95
conn.go
Просмотреть файл

@ -13,10 +13,9 @@ import (
// Default connection options
const (
DefaultMaxFrameSize = 512
DefaultIdleTimeout = 1 * time.Minute
defaultChannelMax = 1
DefaultMaxFrameSize = 512
DefaultMaxSessions = 65536
)
// Errors
@ -108,6 +107,28 @@ func ConnConnectTimeout(d time.Duration) ConnOption {
return func(c *conn) error { c.connectTimeout = d; return nil }
}
// ConnMaxSessions sets the maximum number of channels.
//
// n must be in the range 1 to 65536.
//
// BUG: Currently this limits how many channels can ever
// be opened on this connection rather than how many
// channels can be open at the same time.
//
// Default: 65536.
func ConnMaxSessions(n int) ConnOption {
return func(c *conn) error {
if n < 1 {
return errorNew("max sessions cannot be less than 1")
}
if n > 65536 {
return errorNew("max sessions cannot be greater than 65536")
}
c.channelMax = uint16(n - 1)
return nil
}
}
// conn is an AMQP connection.
type conn struct {
net net.Conn // underlying connection
@ -142,9 +163,9 @@ type conn struct {
closeOnce sync.Once
// mux
newSession chan *Session // new Sessions are requested from mux by reading off this channel
delSession chan *Session // session completion is indicated to mux by sending the Session on this channel
connErr chan error // connReader/Writer notifications of an error
newSession chan newSessionResp // new Sessions are requested from mux by reading off this channel
delSession chan *Session // session completion is indicated to mux by sending the Session on this channel
connErr chan error // connReader/Writer notifications of an error
// connReader
rxProto chan protoHeader // protoHeaders received by connReader
@ -157,19 +178,24 @@ type conn struct {
txDone chan struct{}
}
type newSessionResp struct {
session *Session
err error
}
func newConn(netConn net.Conn, opts ...ConnOption) (*conn, error) {
c := &conn{
net: netConn,
maxFrameSize: DefaultMaxFrameSize,
peerMaxFrameSize: DefaultMaxFrameSize,
channelMax: defaultChannelMax,
channelMax: DefaultMaxSessions - 1, // -1 because channel-max starts at zero
idleTimeout: DefaultIdleTimeout,
done: make(chan struct{}),
connErr: make(chan error, 2), // buffered to ensure connReader/Writer won't leak
rxProto: make(chan protoHeader),
rxFrame: make(chan frame),
rxDone: make(chan struct{}),
newSession: make(chan *Session),
newSession: make(chan newSessionResp),
delSession: make(chan *Session),
txFrame: make(chan frame),
txDone: make(chan struct{}),
@ -247,11 +273,14 @@ func (c *conn) getErr() error {
// mux is start in it's own goroutine after initial connection establishment.
// It handles muxing of sessions, keepalives, and connection errors.
func (c *conn) mux() {
// create the next session to allocate
nextSession := newSession(c, 0)
var (
// create the next session to allocate
nextSession = newSessionResp{session: newSession(c, 0)}
// map channel to sessions
sessions := make(map[uint16]*Session)
// map channels to sessions
sessionsByChannel = make(map[uint16]*Session)
sessionsByRemoteChannel = make(map[uint16]*Session)
)
// hold the errMu lock until error or done
c.errMu.Lock()
@ -271,16 +300,27 @@ func (c *conn) mux() {
// new frame from connReader
case fr := <-c.rxFrame:
// lookup session and send to Session.mux
ch, ok := sessions[fr.channel]
session, ok := sessionsByRemoteChannel[fr.channel]
if !ok {
c.err = errorErrorf("unexpected frame: %#v", fr.body)
continue
// if this is a begin, RemoteChannel should be used
begin, ok := fr.body.(*performBegin)
if !ok {
c.err = errorErrorf("unexpected frame: %#v", fr.body)
continue
}
session, ok = sessionsByChannel[begin.RemoteChannel]
if !ok {
c.err = errorErrorf("unexpected frame: %#v", fr.body)
continue
}
session.remoteChannel = fr.channel
sessionsByRemoteChannel[fr.channel] = session
}
// TODO: handle session deletion while sending frame to
// session mux?
select {
case ch.rx <- fr:
case session.rx <- fr:
case <-c.done:
return
}
@ -293,14 +333,27 @@ func (c *conn) mux() {
// sessions are far less frequent than frames being sent to sessions,
// this avoids the lock/unlock for session lookup.
case c.newSession <- nextSession:
sessions[nextSession.channel] = nextSession
if nextSession.err != nil {
continue
}
ch := nextSession.session.channel
sessionsByChannel[ch] = nextSession.session
if ch >= c.channelMax {
nextSession.session = nil
nextSession.err = errorErrorf("reached connection channel max (%d)", c.channelMax)
continue
}
// create the next session to send
nextSession = newSession(c, nextSession.channel+1) // TODO: enforce max session/wrapping
nextSession.session = newSession(c, ch+1)
// session deletion
case s := <-c.delSession:
delete(sessions, s.channel)
// TODO: allow channel number reuse
delete(sessionsByChannel, s.channel)
delete(sessionsByRemoteChannel, s.remoteChannel)
// connection is complete
case <-c.done:

Просмотреть файл

@ -51,15 +51,18 @@ func TestIntegrationRoundTrip(t *testing.T) {
defer cleanup()
tests := []struct {
label string
data []string
label string
sessions int
data []string
}{
{
label: "1 roundtrip, small payload",
data: []string{"1Hello there!"},
label: "1 roundtrip, small payload",
sessions: 1,
data: []string{"1Hello there!"},
},
{
label: "3 roundtrip, small payload",
label: "3 roundtrip, small payload",
sessions: 1,
data: []string{
"2Hey there!",
"2Hi there!",
@ -67,13 +70,19 @@ func TestIntegrationRoundTrip(t *testing.T) {
},
},
{
label: "1000 roundtrip, small payload",
label: "1000 roundtrip, small payload",
sessions: 1,
data: repeatStrings(1000,
"3Hey there!",
"3Hi there!",
"3Ho there!",
),
},
{
label: "1 roundtrip, small payload, 10 sessions",
sessions: 10,
data: []string{"1Hello there!"},
},
}
for _, tt := range tests {
@ -81,84 +90,88 @@ func TestIntegrationRoundTrip(t *testing.T) {
checkLeaks := leaktest.CheckTimeout(t, 60*time.Second)
// Create client
client := newClient(t, tt.label)
client := newClient(t, tt.label,
amqp.ConnMaxSessions(tt.sessions),
)
defer client.Close()
// Open a session
session, err := client.NewSession()
if err != nil {
t.Fatal(err)
}
// Create a sender
sender, err := session.NewSender(
amqp.LinkAddress(queueName),
)
if err != nil {
t.Fatal(err)
}
// Perform test concurrently for speed and to catch races
var wg sync.WaitGroup
wg.Add(2)
var sendErr error
go func() {
defer wg.Done()
defer sender.Close()
for i, data := range tt.data {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
err = sender.Send(ctx, &amqp.Message{
Data: []byte(data),
})
cancel()
if err != nil {
sendErr = fmt.Errorf("Error after %d sends: %+v", i, err)
return
}
for i := 0; i < tt.sessions; i++ {
// Open a session
session, err := client.NewSession()
if err != nil {
t.Fatal(err)
}
}()
var receiveErr error
go func() {
defer wg.Done()
// Create a receiver
receiver, err := session.NewReceiver(
// Create a sender
sender, err := session.NewSender(
amqp.LinkAddress(queueName),
amqp.LinkCredit(10),
amqp.LinkBatching(false),
)
if err != nil {
receiveErr = err
return
t.Fatal(err)
}
defer receiver.Close()
for i, data := range tt.data {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
msg, err := receiver.Receive(ctx)
cancel()
// Perform test concurrently for speed and to catch races
var wg sync.WaitGroup
wg.Add(2)
var sendErr error
go func() {
defer wg.Done()
defer sender.Close()
for i, data := range tt.data {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
err = sender.Send(ctx, &amqp.Message{
Data: []byte(data),
})
cancel()
if err != nil {
sendErr = fmt.Errorf("Error after %d sends: %+v", i, err)
return
}
}
}()
var receiveErr error
go func() {
defer wg.Done()
// Create a receiver
receiver, err := session.NewReceiver(
amqp.LinkAddress(queueName),
amqp.LinkCredit(10),
amqp.LinkBatching(false),
)
if err != nil {
receiveErr = fmt.Errorf("Error after %d receives: %+v", i, err)
receiveErr = err
return
}
defer receiver.Close()
// Accept message
msg.Accept()
for i, data := range tt.data {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
msg, err := receiver.Receive(ctx)
cancel()
if err != nil {
receiveErr = fmt.Errorf("Error after %d receives: %+v", i, err)
return
}
if !bytes.Equal([]byte(data), msg.Data) {
receiveErr = fmt.Errorf("Expected received message %d to be %v, but it was %v", i+1, string(data), string(msg.Data))
// Accept message
msg.Accept()
if !bytes.Equal([]byte(data), msg.Data) {
receiveErr = fmt.Errorf("Expected received message %d to be %v, but it was %v", i+1, string(data), string(msg.Data))
}
}
}()
wg.Wait()
if sendErr != nil || receiveErr != nil {
t.Error("Send error:", sendErr)
t.Fatal("Receive error:", receiveErr)
}
}()
wg.Wait()
if sendErr != nil || receiveErr != nil {
t.Error("Send error:", sendErr)
t.Fatal("Receive error:", receiveErr)
}
client.Close() // close before leak check