Mirror of https://github.com/Azure/go-amqp.git
Fix channel number handling and add option to set channel-max.
* Correctly record and use the remote vs. local channel. This worked with a single channel because both sides would pick zero and there would be no mismatch; when multiple channels were used, frames from the second session could be delivered to the first.
* Added `ConnMaxSessions()` to allow setting the channel-max, which currently defaults to 65536.
* Does not currently support reusing session numbers after a session has been closed.
* Modified integration tests to allow testing multiple sessions.

Updates #6
Parent: a6e33075b4
Commit: 2210add7da
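For context, a minimal sketch of how the new option is meant to be used. The import path and the package-level `amqp.Dial` constructor are assumptions based on the library's API around the time of this commit; only `ConnMaxSessions` and `NewSession` come from the change itself.

package main

import (
	"log"

	amqp "github.com/Azure/go-amqp" // import path assumed; use your checkout's module path
)

func main() {
	// Placeholder address; dial/auth options are out of scope for this commit.
	client, err := amqp.Dial("amqps://example.invalid",
		amqp.ConnMaxSessions(4), // up to 4 sessions, i.e. channel-max = 3
	)
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	// Each NewSession call is handed its own channel number by conn.mux.
	for i := 0; i < 4; i++ {
		if _, err := client.NewSession(); err != nil {
			// A fifth session would fail with "reached connection channel max (3)".
			log.Fatal(err)
		}
	}
}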
client.go (16 lines changed)

@@ -78,13 +78,18 @@ func (c *Client) Close() error {
 // NewSession opens a new AMQP session to the server.
 func (c *Client) NewSession() (*Session, error) {
 	// get a session allocated by Client.mux
-	var s *Session
+	var sResp newSessionResp
 	select {
 	case <-c.conn.done:
 		return nil, c.conn.getErr()
-	case s = <-c.conn.newSession:
+	case sResp = <-c.conn.newSession:
 	}
 
+	if sResp.err != nil {
+		return nil, sResp.err
+	}
+	s := sResp.session
+
 	// send Begin to server
 	begin := &performBegin{
 		NextOutgoingID: 0,

@@ -109,9 +114,6 @@ func (c *Client) NewSession() (*Session, error) {
 		return nil, errorErrorf("unexpected begin response: %+v", fr.body)
 	}
 
-	// TODO: record negotiated settings
-	s.remoteChannel = begin.RemoteChannel
-
 	// start Session multiplexor
 	go s.mux(begin)
 

@@ -123,7 +125,7 @@ func (c *Client) NewSession() (*Session, error) {
 // A session multiplexes Receivers.
 type Session struct {
 	channel       uint16 // session's local channel
-	remoteChannel uint16 // session's remote channel
+	remoteChannel uint16 // session's remote channel, owned by conn.mux
 	conn          *conn  // underlying conn
 	rx            chan frame     // frames destined for this session are sent on this chan by conn.mux
 	tx            chan frameBody // non-transfer frames to be sent; session must track disposition

@@ -167,7 +169,7 @@ func (s *Session) Close() error {
 func (s *Session) txFrame(p frameBody, done chan struct{}) {
 	s.conn.wantWriteFrame(frame{
 		typ:     frameTypeAMQP,
-		channel: s.remoteChannel,
+		channel: s.channel,
 		body:    p,
 		done:    done,
 	})
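From the caller's perspective, the visible change in client.go is that `NewSession` can now fail fast once the connection's channel allocation is exhausted, instead of frames being misrouted later. A hedged sketch of handling that; the helper name is made up, and `amqp.Client`/`amqp.Session` are the package's exported types under an assumed import path.

package example

import amqp "github.com/Azure/go-amqp" // import path assumed

// openSessions opens n sessions, stopping at the first allocation failure.
// Hypothetical helper, not part of the library.
func openSessions(client *amqp.Client, n int) ([]*amqp.Session, error) {
	sessions := make([]*amqp.Session, 0, n)
	for i := 0; i < n; i++ {
		s, err := client.NewSession()
		if err != nil {
			// conn.mux now reports e.g. "reached connection channel max (N)".
			return sessions, err
		}
		sessions = append(sessions, s)
	}
	return sessions, nil
}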
conn.go (95 lines changed)

@@ -13,10 +13,9 @@ import (
 
 // Default connection options
 const (
-	DefaultMaxFrameSize = 512
 	DefaultIdleTimeout  = 1 * time.Minute
-
-	defaultChannelMax = 1
+	DefaultMaxFrameSize = 512
+	DefaultMaxSessions  = 65536
 )
 
 // Errors

@@ -108,6 +107,28 @@ func ConnConnectTimeout(d time.Duration) ConnOption {
 	return func(c *conn) error { c.connectTimeout = d; return nil }
 }
 
+// ConnMaxSessions sets the maximum number of channels.
+//
+// n must be in the range 1 to 65536.
+//
+// BUG: Currently this limits how many channels can ever
+// be opened on this connection rather than how many
+// channels can be open at the same time.
+//
+// Default: 65536.
+func ConnMaxSessions(n int) ConnOption {
+	return func(c *conn) error {
+		if n < 1 {
+			return errorNew("max sessions cannot be less than 1")
+		}
+		if n > 65536 {
+			return errorNew("max sessions cannot be greater than 65536")
+		}
+		c.channelMax = uint16(n - 1)
+		return nil
+	}
+}
+
 // conn is an AMQP connection.
 type conn struct {
 	net net.Conn // underlying connection
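The option's argument counts sessions while the AMQP channel-max field counts the highest usable channel number, hence the `n - 1` conversion. A standalone sanity check of that relationship (sketch only, independent of the repository's code):

package main

import "fmt"

// channelMaxFor mirrors the conversion ConnMaxSessions performs:
// n sessions means usable channels 0..n-1, so channel-max is n-1.
func channelMaxFor(n int) (uint16, error) {
	if n < 1 || n > 65536 {
		return 0, fmt.Errorf("n must be in the range 1 to 65536, got %d", n)
	}
	return uint16(n - 1), nil
}

func main() {
	for _, n := range []int{1, 10, 65536} {
		max, err := channelMaxFor(n)
		if err != nil {
			fmt.Println(err)
			continue
		}
		fmt.Printf("ConnMaxSessions(%d) -> channel-max %d\n", n, max)
	}
}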
@@ -142,9 +163,9 @@ type conn struct {
 	closeOnce sync.Once
 
 	// mux
-	newSession chan *Session // new Sessions are requested from mux by reading off this channel
-	delSession chan *Session // session completion is indicated to mux by sending the Session on this channel
-	connErr    chan error    // connReader/Writer notifications of an error
+	newSession chan newSessionResp // new Sessions are requested from mux by reading off this channel
+	delSession chan *Session       // session completion is indicated to mux by sending the Session on this channel
+	connErr    chan error          // connReader/Writer notifications of an error
 
 	// connReader
 	rxProto chan protoHeader // protoHeaders received by connReader

@@ -157,19 +178,24 @@ type conn struct {
 	txDone  chan struct{}
 }
 
+type newSessionResp struct {
+	session *Session
+	err     error
+}
+
 func newConn(netConn net.Conn, opts ...ConnOption) (*conn, error) {
 	c := &conn{
 		net:              netConn,
 		maxFrameSize:     DefaultMaxFrameSize,
 		peerMaxFrameSize: DefaultMaxFrameSize,
-		channelMax:       defaultChannelMax,
+		channelMax:       DefaultMaxSessions - 1, // -1 because channel-max starts at zero
 		idleTimeout:      DefaultIdleTimeout,
 		done:             make(chan struct{}),
 		connErr:          make(chan error, 2), // buffered to ensure connReader/Writer won't leak
 		rxProto:          make(chan protoHeader),
 		rxFrame:          make(chan frame),
 		rxDone:           make(chan struct{}),
-		newSession:       make(chan *Session),
+		newSession:       make(chan newSessionResp),
 		delSession:       make(chan *Session),
 		txFrame:          make(chan frame),
 		txDone:           make(chan struct{}),

@@ -247,11 +273,14 @@ func (c *conn) getErr() error {
 // mux is start in it's own goroutine after initial connection establishment.
 // It handles muxing of sessions, keepalives, and connection errors.
 func (c *conn) mux() {
-	// create the next session to allocate
-	nextSession := newSession(c, 0)
+	var (
+		// create the next session to allocate
+		nextSession = newSessionResp{session: newSession(c, 0)}
 
-	// map channel to sessions
-	sessions := make(map[uint16]*Session)
+		// map channels to sessions
+		sessionsByChannel       = make(map[uint16]*Session)
+		sessionsByRemoteChannel = make(map[uint16]*Session)
+	)
 
 	// hold the errMu lock until error or done
 	c.errMu.Lock()

@@ -271,16 +300,27 @@ func (c *conn) mux() {
 		// new frame from connReader
 		case fr := <-c.rxFrame:
 			// lookup session and send to Session.mux
-			ch, ok := sessions[fr.channel]
+			session, ok := sessionsByRemoteChannel[fr.channel]
 			if !ok {
-				c.err = errorErrorf("unexpected frame: %#v", fr.body)
-				continue
+				// if this is a begin, RemoteChannel should be used
+				begin, ok := fr.body.(*performBegin)
+				if !ok {
+					c.err = errorErrorf("unexpected frame: %#v", fr.body)
+					continue
+				}
+
+				session, ok = sessionsByChannel[begin.RemoteChannel]
+				if !ok {
+					c.err = errorErrorf("unexpected frame: %#v", fr.body)
+					continue
+				}
+
+				session.remoteChannel = fr.channel
+				sessionsByRemoteChannel[fr.channel] = session
 			}
 
 			// TODO: handle session deletion while sending frame to
 			//       session mux?
 			select {
-			case ch.rx <- fr:
+			case session.rx <- fr:
 			case <-c.done:
 				return
 			}

@@ -293,14 +333,27 @@ func (c *conn) mux() {
 		// sessions are far less frequent than frames being sent to sessions,
 		// this avoids the lock/unlock for session lookup.
 		case c.newSession <- nextSession:
-			sessions[nextSession.channel] = nextSession
+			if nextSession.err != nil {
+				continue
+			}
+
+			ch := nextSession.session.channel
+			sessionsByChannel[ch] = nextSession.session
+
+			if ch >= c.channelMax {
+				nextSession.session = nil
+				nextSession.err = errorErrorf("reached connection channel max (%d)", c.channelMax)
+				continue
+			}
 
 			// create the next session to send
-			nextSession = newSession(c, nextSession.channel+1) // TODO: enforce max session/wrapping
+			nextSession.session = newSession(c, ch+1)
 
 		// session deletion
 		case s := <-c.delSession:
-			delete(sessions, s.channel)
+			// TODO: allow channel number reuse
+			delete(sessionsByChannel, s.channel)
+			delete(sessionsByRemoteChannel, s.remoteChannel)
 
 		// connection is complete
 		case <-c.done:
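The core of the conn.go change is that conn.mux now keeps two lookup tables: a session is registered under its local channel when allocated, and under the peer's channel once the peer's Begin (whose RemoteChannel field echoes our channel) arrives; later frames are routed by the peer's channel. A self-contained toy model of that routing, using stand-in types rather than the package's internal frame/Session types:

package main

import "fmt"

// Stand-ins for the package's internal performBegin, frame, and Session types.
type beginBody struct{ RemoteChannel uint16 }
type frame struct {
	channel uint16      // channel number stamped by the peer
	body    interface{} // frame body (Begin, Transfer, ...)
}
type session struct{ channel, remoteChannel uint16 }

func main() {
	byLocal := map[uint16]*session{0: {channel: 0}} // registered when the session is allocated
	byRemote := map[uint16]*session{}               // filled in once the peer's Begin arrives

	route := func(fr frame) *session {
		if s, ok := byRemote[fr.channel]; ok {
			return s // established session: keyed by the peer's channel
		}
		// First frame for a session should be the peer's Begin, whose
		// RemoteChannel echoes the channel we chose locally.
		if begin, ok := fr.body.(*beginBody); ok {
			if s, ok := byLocal[begin.RemoteChannel]; ok {
				s.remoteChannel = fr.channel // record the peer's channel choice
				byRemote[fr.channel] = s
				return s
			}
		}
		return nil // unexpected frame
	}

	// The peer may pick a different channel (here 7) than our local channel 0.
	fmt.Println(route(frame{channel: 7, body: &beginBody{RemoteChannel: 0}}).channel) // 0
	fmt.Println(route(frame{channel: 7, body: "transfer"}).remoteChannel)             // 7
}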
@@ -51,15 +51,18 @@ func TestIntegrationRoundTrip(t *testing.T) {
 	defer cleanup()
 
 	tests := []struct {
-		label string
-		data  []string
+		label    string
+		sessions int
+		data     []string
 	}{
 		{
-			label: "1 roundtrip, small payload",
-			data:  []string{"1Hello there!"},
+			label:    "1 roundtrip, small payload",
+			sessions: 1,
+			data:     []string{"1Hello there!"},
 		},
 		{
-			label: "3 roundtrip, small payload",
+			label:    "3 roundtrip, small payload",
+			sessions: 1,
 			data: []string{
 				"2Hey there!",
 				"2Hi there!",

@@ -67,13 +70,19 @@ func TestIntegrationRoundTrip(t *testing.T) {
 			},
 		},
 		{
-			label: "1000 roundtrip, small payload",
+			label:    "1000 roundtrip, small payload",
+			sessions: 1,
 			data: repeatStrings(1000,
 				"3Hey there!",
 				"3Hi there!",
 				"3Ho there!",
 			),
 		},
+		{
+			label:    "1 roundtrip, small payload, 10 sessions",
+			sessions: 10,
+			data:     []string{"1Hello there!"},
+		},
 	}
 
 	for _, tt := range tests {

@@ -81,84 +90,88 @@ func TestIntegrationRoundTrip(t *testing.T) {
 			checkLeaks := leaktest.CheckTimeout(t, 60*time.Second)
 
 			// Create client
-			client := newClient(t, tt.label)
+			client := newClient(t, tt.label,
+				amqp.ConnMaxSessions(tt.sessions),
+			)
 			defer client.Close()
 
-			// Open a session
-			session, err := client.NewSession()
-			if err != nil {
-				t.Fatal(err)
-			}
-
-			// Create a sender
-			sender, err := session.NewSender(
-				amqp.LinkAddress(queueName),
-			)
-			if err != nil {
-				t.Fatal(err)
-			}
-
-			// Perform test concurrently for speed and to catch races
-			var wg sync.WaitGroup
-			wg.Add(2)
-
-			var sendErr error
-			go func() {
-				defer wg.Done()
-				defer sender.Close()
-
-				for i, data := range tt.data {
-					ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-					err = sender.Send(ctx, &amqp.Message{
-						Data: []byte(data),
-					})
-					cancel()
-					if err != nil {
-						sendErr = fmt.Errorf("Error after %d sends: %+v", i, err)
-						return
-					}
-				}
-			}()
-
-			var receiveErr error
-			go func() {
-				defer wg.Done()
-
-				// Create a receiver
-				receiver, err := session.NewReceiver(
-					amqp.LinkAddress(queueName),
-					amqp.LinkCredit(10),
-					amqp.LinkBatching(false),
-				)
-				if err != nil {
-					receiveErr = err
-					return
-				}
-				defer receiver.Close()
-
-				for i, data := range tt.data {
-					ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-					msg, err := receiver.Receive(ctx)
-					cancel()
-					if err != nil {
-						receiveErr = fmt.Errorf("Error after %d receives: %+v", i, err)
-						return
-					}
-
-					// Accept message
-					msg.Accept()
-
-					if !bytes.Equal([]byte(data), msg.Data) {
-						receiveErr = fmt.Errorf("Expected received message %d to be %v, but it was %v", i+1, string(data), string(msg.Data))
-					}
-				}
-			}()
-
-			wg.Wait()
-
-			if sendErr != nil || receiveErr != nil {
-				t.Error("Send error:", sendErr)
-				t.Fatal("Receive error:", receiveErr)
-			}
+			for i := 0; i < tt.sessions; i++ {
+				// Open a session
+				session, err := client.NewSession()
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				// Create a sender
+				sender, err := session.NewSender(
+					amqp.LinkAddress(queueName),
+				)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				// Perform test concurrently for speed and to catch races
+				var wg sync.WaitGroup
+				wg.Add(2)
+
+				var sendErr error
+				go func() {
+					defer wg.Done()
+					defer sender.Close()
+
+					for i, data := range tt.data {
+						ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+						err = sender.Send(ctx, &amqp.Message{
+							Data: []byte(data),
+						})
+						cancel()
+						if err != nil {
+							sendErr = fmt.Errorf("Error after %d sends: %+v", i, err)
+							return
+						}
+					}
+				}()
+
+				var receiveErr error
+				go func() {
+					defer wg.Done()
+
+					// Create a receiver
+					receiver, err := session.NewReceiver(
+						amqp.LinkAddress(queueName),
+						amqp.LinkCredit(10),
+						amqp.LinkBatching(false),
+					)
+					if err != nil {
+						receiveErr = err
+						return
+					}
+					defer receiver.Close()
+
+					for i, data := range tt.data {
+						ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+						msg, err := receiver.Receive(ctx)
+						cancel()
+						if err != nil {
+							receiveErr = fmt.Errorf("Error after %d receives: %+v", i, err)
+							return
+						}
+
+						// Accept message
+						msg.Accept()
+
+						if !bytes.Equal([]byte(data), msg.Data) {
+							receiveErr = fmt.Errorf("Expected received message %d to be %v, but it was %v", i+1, string(data), string(msg.Data))
+						}
+					}
+				}()
+
+				wg.Wait()
+
+				if sendErr != nil || receiveErr != nil {
+					t.Error("Send error:", sendErr)
+					t.Fatal("Receive error:", receiveErr)
+				}
+			}
 
 			client.Close() // close before leak check