go-amqp/conn_test.go

1110 строки
33 KiB
Go

package amqp
import (
"context"
"crypto/tls"
"errors"
"fmt"
"math"
"testing"
"time"
"github.com/Azure/go-amqp/internal/encoding"
"github.com/Azure/go-amqp/internal/fake"
"github.com/Azure/go-amqp/internal/frames"
"github.com/Azure/go-amqp/internal/test"
"github.com/stretchr/testify/require"
)
func TestConnOptions(t *testing.T) {
tests := []struct {
label string
opts ConnOptions
verify func(t *testing.T, c *Conn)
fails bool
}{
{
label: "no options",
verify: func(t *testing.T, c *Conn) {},
},
{
label: "multiple properties",
opts: ConnOptions{
Properties: map[string]any{
"x-opt-test1": "test3",
"x-opt-test2": "test2",
},
},
verify: func(t *testing.T, c *Conn) {
wantProperties := map[encoding.Symbol]any{
"x-opt-test1": "test3",
"x-opt-test2": "test2",
}
if !test.Equal(c.properties, wantProperties) {
require.Equal(t, wantProperties, c.properties)
}
},
},
{
label: "ConnServerHostname",
opts: ConnOptions{
HostName: "testhost",
},
verify: func(t *testing.T, c *Conn) {
if c.hostname != "testhost" {
t.Errorf("unexpected host name %s", c.hostname)
}
},
},
{
label: "ConnTLSConfig",
opts: ConnOptions{
TLSConfig: &tls.Config{MinVersion: tls.VersionTLS13},
},
verify: func(t *testing.T, c *Conn) {
if c.tlsConfig.MinVersion != tls.VersionTLS13 {
t.Errorf("unexpected TLS min version %d", c.tlsConfig.MinVersion)
}
},
},
{
label: "ConnIdleTimeout_Valid",
opts: ConnOptions{
IdleTimeout: 15 * time.Minute,
},
verify: func(t *testing.T, c *Conn) {
if c.idleTimeout != 15*time.Minute {
t.Errorf("unexpected idle timeout %s", c.idleTimeout)
}
},
},
{
label: "ConnIdleTimeout_Invalid",
fails: true,
opts: ConnOptions{
IdleTimeout: -15 * time.Minute,
},
},
{
label: "ConnMaxFrameSize_Valid",
opts: ConnOptions{
MaxFrameSize: 1024,
},
verify: func(t *testing.T, c *Conn) {
if c.maxFrameSize != 1024 {
t.Errorf("unexpected max frame size %d", c.maxFrameSize)
}
},
},
{
label: "ConnMaxFrameSize_Invalid",
fails: true,
opts: ConnOptions{
MaxFrameSize: 128,
},
},
{
label: "ConnMaxSessions_Success",
opts: ConnOptions{
MaxSessions: 32768,
},
verify: func(t *testing.T, c *Conn) {
if c.channelMax != 32768 {
t.Errorf("unexpected session count %d", c.channelMax)
}
},
},
{
label: "ConnMaxSessions_TooSmall",
fails: true,
opts: ConnOptions{
MaxSessions: 0,
},
},
{
label: "ConnContainerID",
opts: ConnOptions{
ContainerID: "myid",
},
verify: func(t *testing.T, c *Conn) {
if c.containerID != "myid" {
t.Errorf("unexpected container ID %s", c.containerID)
}
},
},
}
for _, tt := range tests {
t.Run(tt.label, func(t *testing.T) {
got, err := newConn(nil, &tt.opts)
if err != nil && !tt.fails {
t.Fatal(err)
}
if !tt.fails {
tt.verify(t, got)
}
})
}
}
type fakeDialer struct {
fail bool
}
func (f fakeDialer) NetDialerDial(ctx context.Context, c *Conn, host, port string) (err error) {
err = f.error()
return
}
func (f fakeDialer) TLSDialWithDialer(ctx context.Context, c *Conn, host, port string) (err error) {
err = f.error()
return
}
func (f fakeDialer) error() error {
if f.fail {
return errors.New("failed")
}
return nil
}
func TestDialConn(t *testing.T) {
c, err := dialConn(context.Background(), ":bad url/ value", &ConnOptions{dialer: fakeDialer{}})
require.Error(t, err)
require.Nil(t, c)
c, err = dialConn(context.Background(), "http://localhost", &ConnOptions{dialer: fakeDialer{}})
require.Error(t, err)
require.Nil(t, c)
c, err = dialConn(context.Background(), "amqp://localhost", &ConnOptions{dialer: fakeDialer{}})
require.NoError(t, err)
require.NotNil(t, c)
require.Nil(t, c.tlsConfig)
c, err = dialConn(context.Background(), "amqps://localhost", &ConnOptions{dialer: fakeDialer{}})
require.NoError(t, err)
require.NotNil(t, c)
require.NotNil(t, c.tlsConfig)
c, err = dialConn(context.Background(), "amqp://localhost:12345", &ConnOptions{dialer: fakeDialer{}})
require.NoError(t, err)
require.NotNil(t, c)
c, err = dialConn(context.Background(), "amqp://username:password@localhost", &ConnOptions{dialer: fakeDialer{}})
require.NoError(t, err)
require.NotNil(t, c)
if _, ok := c.saslHandlers[saslMechanismPLAIN]; !ok {
t.Fatal("missing SASL plain handler")
}
c, err = dialConn(context.Background(), "amqp://localhost", &ConnOptions{dialer: fakeDialer{fail: true}})
require.Error(t, err)
require.Nil(t, c)
}
func TestStart(t *testing.T) {
tests := []struct {
label string
fails bool
responder func(uint16, frames.FrameBody) (fake.Response, error)
}{
{
label: "bad header",
fails: true,
responder: func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
},
},
{
label: "incorrect version",
fails: true,
responder: func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
},
},
{
label: "failed PerformOpen",
fails: true,
responder: func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return fake.Response{}, errors.New("mock write failure")
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
},
},
{
label: "unexpected PerformOpen response",
fails: true,
responder: func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformBegin(0, 1))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
},
},
{
label: "success",
responder: func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
},
},
}
for _, tt := range tests {
t.Run(tt.label, func(t *testing.T) {
netConn := fake.NewNetConn(tt.responder, fake.NetConnOptions{})
conn, err := newConn(netConn, nil)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
err = conn.start(ctx)
cancel()
if tt.fails {
require.Error(t, err)
// verify that the conn was closed
err := netConn.Close()
require.ErrorIs(t, err, fake.ErrAlreadyClosed)
} else {
require.NoError(t, err)
// verify that the conn wasn't closed
err := netConn.Close()
require.NoError(t, err)
}
})
}
}
func TestClose(t *testing.T) {
netConn := fake.NewNetConn(senderFrameHandlerNoUnhandled(0, SenderSettleModeUnsettled), fake.NetConnOptions{})
conn, err := newConn(netConn, nil)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
require.NoError(t, conn.start(ctx))
cancel()
require.Nil(t, conn.Properties())
require.NoError(t, conn.Close())
// with Close error
netConn = fake.NewNetConn(senderFrameHandlerNoUnhandled(0, SenderSettleModeUnsettled), fake.NetConnOptions{})
conn, err = newConn(netConn, nil)
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
require.NoError(t, conn.start(ctx))
cancel()
netConn.OnClose = func() error {
return errors.New("mock close failed")
}
// wait a bit for connReader to read from the mock
time.Sleep(100 * time.Millisecond)
require.Error(t, conn.Close())
}
func TestServerSideClose(t *testing.T) {
closeReceived := make(chan struct{})
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformClose:
close(closeReceived)
return newResponse(fake.PerformClose(nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
conn, err := newConn(netConn, nil)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
require.NoError(t, conn.start(ctx))
cancel()
fr, err := fake.PerformClose(nil)
require.NoError(t, err)
netConn.SendFrame(fr)
<-closeReceived
err = conn.Close()
require.NoError(t, err)
// with error
closeReceived = make(chan struct{})
netConn = fake.NewNetConn(responder, fake.NetConnOptions{})
conn, err = newConn(netConn, nil)
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
require.NoError(t, conn.start(ctx))
cancel()
fr, err = fake.PerformClose(&Error{Condition: "Close", Description: "mock server error"})
require.NoError(t, err)
netConn.SendFrame(fr)
<-closeReceived
err = conn.Close()
var connErr *ConnError
require.ErrorAs(t, err, &connErr)
require.Equal(t, "*Error{Condition: Close, Description: mock server error, Info: map[]}", connErr.Error())
}
func TestKeepAlives(t *testing.T) {
// closing conn can race with keep-alive ticks, so sometimes we get
// two in this test. the test needs to receive at least one keep-alive,
// so use a buffered channel to absorb any extras.
keepAlives := make(chan struct{}, 3)
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
// specify small idle timeout so we receive a lot of keep-alives
return newResponse(fake.EncodeFrame(frames.TypeAMQP, 0, &frames.PerformOpen{ContainerID: "container", IdleTimeout: 100 * time.Millisecond}))
case *fake.KeepAlive:
keepAlives <- struct{}{}
return fake.Response{}, nil
case *frames.PerformClose:
return newResponse(fake.PerformClose(nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
conn, err := newConn(netConn, nil)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
require.NoError(t, conn.start(ctx))
cancel()
// send keepalive
netConn.SendKeepAlive()
ctx, cancel = context.WithTimeout(context.Background(), time.Second)
defer cancel()
select {
case <-keepAlives:
// got keep-alive
case <-ctx.Done():
t.Fatal("didn't receive any keepalive frames")
}
require.NoError(t, conn.Close())
}
func TestKeepAlivesIdleTimeout(t *testing.T) {
start := make(chan struct{})
done := make(chan struct{})
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
close(start)
return newResponse(fake.EncodeFrame(frames.TypeAMQP, 0, &frames.PerformOpen{ContainerID: "container", IdleTimeout: time.Minute}))
case *fake.KeepAlive:
return fake.Response{}, nil
case *frames.PerformClose:
close(done)
return newResponse(fake.PerformClose(nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
const idleTimeout = 100 * time.Millisecond
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
conn, err := newConn(netConn, &ConnOptions{
IdleTimeout: idleTimeout,
})
require.NoError(t, err)
go func() {
<-start
for {
select {
case <-time.After(idleTimeout / 2):
netConn.SendKeepAlive()
case <-done:
return
}
}
}()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
require.NoError(t, conn.start(ctx))
cancel()
time.Sleep(2 * idleTimeout)
require.NoError(t, conn.Close())
}
func TestConnReaderError(t *testing.T) {
netConn := fake.NewNetConn(senderFrameHandlerNoUnhandled(0, SenderSettleModeUnsettled), fake.NetConnOptions{})
conn, err := newConn(netConn, nil)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
require.NoError(t, conn.start(ctx))
cancel()
// trigger some kind of error
netConn.ReadErr <- errors.New("failed")
// wait a bit for the connReader goroutine to read from the mock
time.Sleep(100 * time.Millisecond)
err = conn.Close()
var connErr *ConnError
if !errors.As(err, &connErr) {
t.Fatalf("unexpected error type %T", err)
}
}
func TestConnWriterError(t *testing.T) {
netConn := fake.NewNetConn(senderFrameHandlerNoUnhandled(0, SenderSettleModeUnsettled), fake.NetConnOptions{})
conn, err := newConn(netConn, nil)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
require.NoError(t, conn.start(ctx))
cancel()
netConn.WriteErr <- errors.New("boom")
// wait a bit for connReader to read from the mock
time.Sleep(100 * time.Millisecond)
err = conn.Close()
var connErr *ConnError
if !errors.As(err, &connErr) {
t.Fatalf("unexpected error type %T", err)
}
}
func TestConnWithZeroByteReads(t *testing.T) {
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformClose:
return newResponse(fake.PerformClose(nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
netConn.SendFrame([]byte{})
conn, err := newConn(netConn, nil)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
require.NoError(t, conn.start(ctx))
cancel()
require.NoError(t, conn.Close())
}
func TestConnNegotiationTimeout(t *testing.T) {
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
return fake.Response{}, nil
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
conn, err := newConn(netConn, nil)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
require.ErrorIs(t, conn.start(ctx), context.DeadlineExceeded)
cancel()
}
type mockDialer struct {
resp func(uint16, frames.FrameBody) (fake.Response, error)
}
func (m mockDialer) NetDialerDial(ctx context.Context, c *Conn, host, port string) error {
c.net = fake.NewNetConn(m.resp, fake.NetConnOptions{})
return nil
}
func (mockDialer) TLSDialWithDialer(ctx context.Context, c *Conn, host, port string) error {
panic("nyi")
}
func TestClientDial(t *testing.T) {
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := Dial(ctx, "amqp://localhost", &ConnOptions{dialer: mockDialer{resp: responder}})
cancel()
require.NoError(t, err)
require.NotNil(t, client)
// error case
responder = func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return fake.Response{}, errors.New("mock read failed")
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
client, err = Dial(ctx, "amqp://localhost", &ConnOptions{dialer: mockDialer{resp: responder}})
cancel()
require.Error(t, err)
require.Nil(t, client)
}
func TestClientClose(t *testing.T) {
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformClose:
return newResponse(fake.PerformClose(nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := Dial(ctx, "amqp://localhost", &ConnOptions{dialer: mockDialer{resp: responder}})
cancel()
require.NoError(t, err)
require.NotNil(t, client)
require.NoError(t, client.Close())
require.NoError(t, client.Close())
}
func TestSessionOptions(t *testing.T) {
tests := []struct {
label string
opt SessionOptions
verify func(t *testing.T, s *Session)
}{
{
label: "SessionMaxLinks",
opt: SessionOptions{
MaxLinks: 4096,
},
verify: func(t *testing.T, s *Session) {
if s.handleMax != 4096-1 {
t.Errorf("unexpected max links %d", s.handleMax)
}
},
},
}
for _, tt := range tests {
t.Run(tt.label, func(t *testing.T) {
session := newSession(nil, 0, &tt.opt)
tt.verify(t, session)
})
}
}
func TestClientNewSession(t *testing.T) {
const channelNum = 0
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch tt := req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformBegin:
if tt.RemoteChannel != nil {
return fake.Response{}, errors.New("expected nil remote channel")
}
return newResponse(fake.PerformBegin(channelNum, remoteChannel))
case *frames.PerformClose:
return newResponse(fake.PerformClose(nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := NewConn(ctx, netConn, nil)
cancel()
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session, err := client.NewSession(ctx, nil)
cancel()
require.NoError(t, err)
require.NotNil(t, session)
require.Equal(t, uint16(channelNum), session.channel)
require.NoError(t, client.Close())
// creating a session after the connection has been closed returns nothing
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session, err = client.NewSession(ctx, nil)
cancel()
var connErr *ConnError
if !errors.As(err, &connErr) {
t.Fatalf("unexpected error type %T", err)
}
require.Equal(t, "amqp: connection closed", connErr.Error())
require.Nil(t, session)
}
func TestClientMultipleSessions(t *testing.T) {
channelNum := uint16(0)
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformBegin:
b, err := fake.PerformBegin(channelNum, remoteChannel)
if err != nil {
return fake.Response{}, err
}
channelNum++
return fake.Response{Payload: b}, nil
case *frames.PerformClose:
return newResponse(fake.PerformClose(nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := NewConn(ctx, netConn, nil)
cancel()
require.NoError(t, err)
// first session
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session1, err := client.NewSession(ctx, nil)
cancel()
require.NoError(t, err)
require.NotNil(t, session1)
require.Equal(t, channelNum-1, session1.channel)
// second session
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session2, err := client.NewSession(ctx, nil)
cancel()
require.NoError(t, err)
require.NotNil(t, session2)
require.Equal(t, channelNum-1, session2.channel)
require.NoError(t, client.Close())
}
func TestClientTooManySessions(t *testing.T) {
channelNum := uint16(0)
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
// return small number of max channels
return newResponse(fake.EncodeFrame(frames.TypeAMQP, 0, &frames.PerformOpen{
ChannelMax: 1,
ContainerID: "test",
IdleTimeout: time.Minute,
MaxFrameSize: 4294967295,
}))
case *frames.PerformBegin:
b, err := fake.PerformBegin(channelNum, remoteChannel)
if err != nil {
return fake.Response{}, err
}
channelNum++
return fake.Response{Payload: b}, nil
case *frames.PerformClose:
return newResponse(fake.PerformClose(nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := NewConn(ctx, netConn, nil)
cancel()
require.NoError(t, err)
for i := uint16(0); i < 3; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
session, err := client.NewSession(ctx, nil)
cancel()
if i < 2 {
require.NoError(t, err)
require.NotNil(t, session)
} else {
// third channel should fail
require.Error(t, err)
require.Nil(t, session)
}
}
require.NoError(t, client.Close())
}
func TestClientNewSessionMissingRemoteChannel(t *testing.T) {
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformBegin:
// return begin with nil RemoteChannel
return newResponse(fake.EncodeFrame(frames.TypeAMQP, 0, &frames.PerformBegin{
NextOutgoingID: 1,
IncomingWindow: 5000,
OutgoingWindow: 1000,
HandleMax: math.MaxInt16,
}))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := NewConn(ctx, netConn, nil)
cancel()
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session, err := client.NewSession(ctx, &SessionOptions{
MaxLinks: 1,
})
cancel()
require.Error(t, err)
require.Nil(t, session)
require.Error(t, client.Close())
}
func TestClientNewSessionInvalidInitialResponse(t *testing.T) {
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformBegin:
// respond with the wrong frame type
return newResponse(fake.PerformOpen("bad"))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := NewConn(ctx, netConn, nil)
cancel()
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session, err := client.NewSession(ctx, nil)
cancel()
require.Error(t, err)
require.Nil(t, session)
}
func TestClientNewSessionInvalidSecondResponseSameChannel(t *testing.T) {
firstChan := true
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformClose:
return newResponse(fake.PerformClose(nil))
case *frames.PerformBegin:
if firstChan {
firstChan = false
return newResponse(fake.PerformBegin(0, remoteChannel))
}
// respond with the wrong frame type
return newResponse(fake.PerformOpen("bad"))
case *frames.PerformEnd:
return newResponse(fake.PerformEnd(0, nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := NewConn(ctx, netConn, nil)
cancel()
require.NoError(t, err)
// fisrt session succeeds
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session, err := client.NewSession(ctx, nil)
cancel()
require.NoError(t, err)
require.NotNil(t, session)
// second session fails - times out as the ack is never received
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session, err = client.NewSession(ctx, nil)
cancel()
require.Error(t, err)
require.Nil(t, session)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.NoError(t, client.Close())
}
func TestClientNewSessionInvalidSecondResponseDifferentChannel(t *testing.T) {
firstChan := true
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformBegin:
if firstChan {
firstChan = false
return newResponse(fake.PerformBegin(0, remoteChannel))
}
// respond with the wrong frame type
// note that it has to be for the next channel
return newResponse(fake.PerformDisposition(encoding.RoleSender, 1, 0, nil, nil))
case *frames.PerformEnd:
return newResponse(fake.PerformEnd(0, nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := NewConn(ctx, netConn, nil)
cancel()
require.NoError(t, err)
// fisrt session succeeds
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session, err := client.NewSession(ctx, nil)
cancel()
require.NoError(t, err)
require.NotNil(t, session)
// second session fails
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session, err = client.NewSession(ctx, nil)
cancel()
require.Error(t, err)
require.Nil(t, session)
require.Error(t, client.Close())
}
func TestNewSessionTimedOut(t *testing.T) {
var sessionCount uint32
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformClose:
return newResponse(fake.PerformClose(nil))
case *frames.PerformBegin:
if sessionCount == 0 {
sessionCount++
fr, err := fake.PerformBegin(0, remoteChannel)
if err != nil {
return fake.Response{}, err
}
// include a write delay so NewSession times out
return fake.Response{Payload: fr, WriteDelay: 100 * time.Millisecond}, nil
}
return newResponse(fake.PerformBegin(1, remoteChannel))
case *frames.PerformEnd:
return newResponse(fake.PerformEnd(0, nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := NewConn(ctx, netConn, nil)
cancel()
require.NoError(t, err)
// fisrt session fails due to deadline exceeded
ctx, cancel = context.WithTimeout(context.Background(), 20*time.Millisecond)
session, err := client.NewSession(ctx, nil)
cancel()
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, session)
// should have one session to clean up
require.Len(t, client.abandonedSessions, 1)
require.Len(t, client.sessionsByChannel, 1)
// creating a new session cleans up the old one
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session, err = client.NewSession(ctx, nil)
cancel()
require.NoError(t, err)
require.NotNil(t, session)
require.Empty(t, client.abandonedSessions)
require.Len(t, client.sessionsByChannel, 1)
}
func TestNewSessionWriteError(t *testing.T) {
endAck := make(chan struct{})
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformBegin:
return fake.Response{}, errors.New("write error")
case *frames.PerformEnd:
close(endAck)
return newResponse(fake.PerformEnd(0, nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := NewConn(ctx, netConn, nil)
cancel()
require.NoError(t, err)
// fisrt session succeeds
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session, err := client.NewSession(ctx, nil)
cancel()
var connErr *ConnError
require.ErrorAs(t, err, &connErr)
require.Equal(t, "write error", connErr.Error())
require.Nil(t, session)
select {
case <-time.After(time.Second):
// expected
case <-endAck:
t.Fatal("unexpected ack")
}
}
func TestGetWriteTimeout(t *testing.T) {
conn, err := newConn(nil, nil)
require.NoError(t, err)
duration, err := conn.getWriteTimeout(context.Background())
require.NoError(t, err)
require.EqualValues(t, defaultWriteTimeout, duration)
ctx, cancel := context.WithCancel(context.Background())
duration, err = conn.getWriteTimeout(ctx)
require.NoError(t, err)
require.EqualValues(t, defaultWriteTimeout, duration)
cancel()
duration, err = conn.getWriteTimeout(ctx)
require.ErrorIs(t, err, context.Canceled)
require.Zero(t, duration)
const timeout = 10 * time.Millisecond
ctx, cancel = context.WithTimeout(context.Background(), timeout)
duration, err = conn.getWriteTimeout(ctx)
require.NoError(t, err)
require.InDelta(t, timeout, duration, float64(time.Millisecond))
// sleep until after the timeout expires
time.Sleep(2 * timeout)
duration, err = conn.getWriteTimeout(ctx)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Zero(t, duration)
cancel()
}
func TestConnSmallFrames(t *testing.T) {
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
return newResponse(fake.PerformOpen("container"))
case *frames.PerformClose:
return newResponse(fake.PerformClose(nil))
case *frames.PerformBegin:
return newResponse(fake.PerformBegin(0, 0))
case *frames.PerformEnd:
body, err := fake.PerformEnd(0, nil)
if err != nil {
return fake.Response{}, err
}
return fake.Response{
Payload: body,
ChunkSize: 8, // must be >= HeaderSize
}, nil
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
conn, err := newConn(netConn, nil)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
require.NoError(t, conn.start(ctx))
cancel()
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
session, err := conn.NewSession(ctx, nil)
require.NoError(t, err)
cancel()
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
require.NoError(t, session.Close(ctx))
cancel()
require.NoError(t, conn.Close())
}
func TestConnProperties(t *testing.T) {
responder := func(remoteChannel uint16, req frames.FrameBody) (fake.Response, error) {
switch req.(type) {
case *fake.AMQPProto:
return newResponse(fake.ProtoHeader(fake.ProtoAMQP))
case *frames.PerformOpen:
b, err := fake.EncodeFrame(frames.TypeAMQP, 0, &frames.PerformOpen{
ChannelMax: 65535,
ContainerID: "container",
IdleTimeout: time.Minute,
MaxFrameSize: 4294967295,
Properties: map[encoding.Symbol]any{
"ConnProperty1": "foo",
"ConnProperty2": 123,
},
})
return newResponse(b, err)
case *frames.PerformClose:
return newResponse(fake.PerformClose(nil))
default:
return fake.Response{}, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := fake.NewNetConn(responder, fake.NetConnOptions{})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
client, err := NewConn(ctx, netConn, nil)
cancel()
require.NoError(t, err)
require.Equal(t, map[string]any{
"ConnProperty1": "foo",
"ConnProperty2": int64(123),
}, client.Properties())
require.NoError(t, client.Close())
}