* Remove conn.mux

conn.connReader() dispatches frames directly to sessions now.
Added conn.NextSession() and conn.DeleteSession() for deterministic
session management.
Channel numbers are now recycled immediately which prompted a fix for
TestSessionClose.
Fixed various tests to handle close frame (the error was being swallowed
before).
Tests that utilize testconn were silently failing due to a bug in
Conn.Read which has been fixed.

* simplify reader/writer error handling

* clean-up

* fix testconn.SetDeadline

* update changelog

* refine error check when closing net.Conn

fixed propagation of RemoteErr on close

* always reset idle read timeout before reading

* remove connReaderRun as it's no longer necessary

* consolidate calls to closeOnce.Do

* replace magic number with constant
This commit is contained in:
Joel Hendrix 2022-12-13 14:00:03 -08:00 коммит произвёл GitHub
Родитель 1b6c612eb6
Коммит 5515441808
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
10 изменённых файлов: 435 добавлений и 336 удалений

Просмотреть файл

@ -1,5 +1,11 @@
# Release History
## 0.18.1 (unreleased)
### Other Changes
* The connection mux goroutine has been removed, eliminating a potential source of deadlocks.
## 0.18.0 (2022-12-06)
### Features Added

527
conn.go
Просмотреть файл

@ -149,14 +149,12 @@ type Conn struct {
peerMaxFrameSize uint32 // maximum frame size peer will accept
// conn state
doneErrMu sync.Mutex // mux holds doneErr from start until shutdown completes; operations are sequential before mux is started
doneErr error // error to be returned to client
done chan struct{} // indicates the connection is done
done chan struct{} // indicates the connection has terminated
doneErr error // contains the error state returned from Close(); DO NOT TOUCH outside of conn.go until Done has been closed!
// mux
connErr chan error // connReader/Writer notifications of an error
closeMux chan struct{} // indicates that the mux should stop
closeMuxOnce sync.Once
// connReader and connWriter management
rxtxExit chan struct{} // signals connReader and connWriter to exit
closeOnce sync.Once // ensures that close() is only called once
// session tracking
channels *bitmap.Bitmap
@ -164,15 +162,15 @@ type Conn struct {
sessionsByChannelMu sync.RWMutex
// connReader
rxProto chan protoHeader // protoHeaders received by connReader
rxFrame chan frames.Frame // AMQP frames received by connReader
rxDone chan struct{}
connReaderRun chan func() // functions to be run by conn reader (set deadline on conn to run)
rxBuf buffer.Buffer // incoming bytes buffer
rxDone chan struct{} // closed when connReader exits
rxErr error // contains last error reading from c.net; DO NOT TOUCH outside of connReader until rxDone has been closed!
// connWriter
txFrame chan frames.Frame // AMQP frames to be sent by connWriter
txBuf buffer.Buffer // buffer for marshaling frames before transmitting
txDone chan struct{}
txDone chan struct{} // closed when connWriter exits
txErr error // contains last error writing to c.net; DO NOT TOUCH outside of connWriter until txDone has been closed!
}
// used to abstract the underlying dialer for testing purposes
@ -256,12 +254,8 @@ func newConn(netConn net.Conn, opts *ConnOptions) (*Conn, error) {
idleTimeout: defaultIdleTimeout,
containerID: shared.RandString(40),
done: make(chan struct{}),
connErr: make(chan error, 2), // buffered to ensure connReader/Writer won't leak
closeMux: make(chan struct{}),
rxProto: make(chan protoHeader),
rxFrame: make(chan frames.Frame),
rxtxExit: make(chan struct{}),
rxDone: make(chan struct{}),
connReaderRun: make(chan func(), 1), // buffered to allow queueing function before interrupt
txFrame: make(chan frames.Frame),
txDone: make(chan struct{}),
sessionsByChannel: map[uint16]*Session{},
@ -311,7 +305,6 @@ func newConn(netConn net.Conn, opts *ConnOptions) (*Conn, error) {
if opts.dialer != nil {
c.dialer = opts.dialer
}
return c, nil
}
@ -327,12 +320,9 @@ func (c *Conn) initTLSConfig() {
}
}
// Start establishes the connection and begins multiplexing network IO.
// start establishes the connection and begins multiplexing network IO.
// It is an error to call Start() on a connection that's been closed.
func (c *Conn) start() error {
// start reader
go c.connReader()
// run connection establishment state machine
for state := c.negotiateProto; state != nil; {
var err error
@ -340,6 +330,7 @@ func (c *Conn) start() error {
// check if err occurred
if err != nil {
close(c.txDone) // close here since connWriter hasn't been started yet
close(c.rxDone)
_ = c.Close()
return err
}
@ -349,72 +340,61 @@ func (c *Conn) start() error {
// this is because our peer can tell us the max channels they support.
c.channels = bitmap.New(uint32(c.channelMax))
// start multiplexor and writer
go c.mux()
go c.connWriter()
go c.connReader()
return nil
}
// Close closes the connection.
func (c *Conn) Close() error {
c.closeMuxOnce.Do(func() { close(c.closeMux) })
err := c.err()
c.close()
var connErr *ConnError
if errors.As(err, &connErr) && connErr.RemoteErr == nil && connErr.inner == nil {
if errors.As(c.doneErr, &connErr) && connErr.RemoteErr == nil && connErr.inner == nil {
// an empty ConnectionError means the connection was closed by the caller
// or as requested by the peer and no error was provided in the close frame.
return nil
}
return err
// there was an error during shut-down or connReader/connWriter
// experienced a terminal error
return c.doneErr
}
// close should only be called by conn.mux.
// close is called once, either from Close() or when connReader/connWriter exits
func (c *Conn) close() {
close(c.done) // notify goroutines and blocked functions to exit
c.closeOnce.Do(func() {
defer close(c.done)
// wait for writing to stop, allows it to send the final close frame
<-c.txDone
close(c.rxtxExit)
// reading from connErr in mux can race with closeMux, causing
// a pending conn read/write error to be lost. now that the
// mux has exited, drain any pending error.
select {
case err := <-c.connErr:
c.doneErr = err
default:
// no pending read/write error
}
// wait for writing to stop, allows it to send the final close frame
<-c.txDone
err := c.net.Close()
switch {
// conn.err already set
// TODO: err info is lost, log it?
case c.doneErr != nil:
closeErr := c.net.Close()
// conn.err not set and c.net.Close() returned a non-nil error
case err != nil:
c.doneErr = err
// check rxDone after closing net, otherwise may block
// for up to c.idleTimeout
<-c.rxDone
// no errors
default:
}
if errors.Is(c.rxErr, net.ErrClosed) {
// this is the expected error when the connection is closed, swallow it
c.rxErr = nil
}
// check rxDone after closing net, otherwise may block
// for up to c.idleTimeout
<-c.rxDone
}
// Err returns the connection's error state after it's been closed.
// Calling this on an open connection will block until the connection is closed.
func (c *Conn) err() error {
c.doneErrMu.Lock()
defer c.doneErrMu.Unlock()
var amqpErr *Error
if errors.As(c.doneErr, &amqpErr) {
return &ConnError{RemoteErr: amqpErr}
}
return &ConnError{inner: c.doneErr}
if c.txErr == nil && c.rxErr == nil && closeErr == nil {
// if there are no errors, it means user initiated close() and we shut down cleanly
c.doneErr = &ConnError{}
} else if amqpErr, ok := c.rxErr.(*Error); ok {
// we experienced a peer-initiated close that contained an Error. return it
c.doneErr = &ConnError{RemoteErr: amqpErr}
} else if c.txErr != nil {
c.doneErr = &ConnError{inner: c.txErr}
} else if c.rxErr != nil {
c.doneErr = &ConnError{inner: c.rxErr}
} else {
c.doneErr = &ConnError{inner: closeErr}
}
})
}
func (c *Conn) NewSession(ctx context.Context, opts *SessionOptions) (*Session, error) {
@ -454,251 +434,181 @@ func (c *Conn) deleteSession(s *Session) {
c.channels.Remove(uint32(s.channel))
}
// mux is started in it's own goroutine after initial connection establishment.
// It handles muxing of sessions, keepalives, and connection errors.
func (c *Conn) mux() {
var (
// map channels to sessions
sessionsByRemoteChannel = make(map[uint16]*Session)
)
// hold the errMu lock until error or done
c.doneErrMu.Lock()
defer c.doneErrMu.Unlock()
defer c.close() // defer order is important. c.errMu unlock indicates that connection is finally complete
// connReader reads from the net.Conn, decodes frames, and either handles
// them here as appropriate or sends them to the session.rx channel.
func (c *Conn) connReader() {
defer func() {
close(c.rxDone)
c.close()
}()
var sessionsByRemoteChannel = make(map[uint16]*Session)
var err error
for {
// check if last loop returned an error
if c.doneErr != nil {
if err != nil {
debug.Log(1, "connReader terminal error: %v", err)
c.rxErr = err
return
}
select {
// error from connReader
case c.doneErr = <-c.connErr:
var fr frames.Frame
fr, err = c.readFrame()
if err != nil {
continue
}
// new frame from connReader
case fr := <-c.rxFrame:
var (
session *Session
ok bool
)
var (
session *Session
ok bool
)
switch body := fr.Body.(type) {
// Server initiated close.
case *frames.PerformClose:
if body.Error != nil {
c.doneErr = body.Error
}
switch body := fr.Body.(type) {
// Server initiated close.
case *frames.PerformClose:
// connWriter will send the close performative ack on its way out.
// it's a SHOULD though, not a MUST.
debug.Log(3, "RX (connReader): %s", body)
if body.Error == nil {
return
// RemoteChannel should be used when frame is Begin
case *frames.PerformBegin:
if body.RemoteChannel == nil {
// since we only support remotely-initiated sessions, this is an error
// TODO: it would be ideal to not have this kill the connection
c.doneErr = fmt.Errorf("%T: nil RemoteChannel", fr.Body)
break
}
c.sessionsByChannelMu.RLock()
session, ok = c.sessionsByChannel[*body.RemoteChannel]
c.sessionsByChannelMu.RUnlock()
if !ok {
c.doneErr = fmt.Errorf("unexpected remote channel number %d", *body.RemoteChannel)
break
}
session.remoteChannel = fr.Channel
sessionsByRemoteChannel[fr.Channel] = session
case *frames.PerformEnd:
session, ok = sessionsByRemoteChannel[fr.Channel]
if !ok {
c.doneErr = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel (PerformEnd)", fr.Body, fr.Channel)
break
}
// we MUST remove the remote channel from our map as soon as we receive
// the ack (i.e. before passing it on to the session mux) on the session
// ending since the numbers are recycled.
delete(sessionsByRemoteChannel, fr.Channel)
default:
// pass on performative to the correct session
session, ok = sessionsByRemoteChannel[fr.Channel]
if !ok {
c.doneErr = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel", fr.Body, fr.Channel)
}
}
err = body.Error
continue
// RemoteChannel should be used when frame is Begin
case *frames.PerformBegin:
if body.RemoteChannel == nil {
// since we only support remotely-initiated sessions, this is an error
// TODO: it would be ideal to not have this kill the connection
err = fmt.Errorf("%T: nil RemoteChannel", fr.Body)
continue
}
c.sessionsByChannelMu.RLock()
session, ok = c.sessionsByChannel[*body.RemoteChannel]
c.sessionsByChannelMu.RUnlock()
if !ok {
err = fmt.Errorf("unexpected remote channel number %d", *body.RemoteChannel)
continue
}
select {
case session.rx <- fr:
case <-c.closeMux:
return
}
session.remoteChannel = fr.Channel
sessionsByRemoteChannel[fr.Channel] = session
// connection is complete
case <-c.closeMux:
case *frames.PerformEnd:
session, ok = sessionsByRemoteChannel[fr.Channel]
if !ok {
err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel (PerformEnd)", fr.Body, fr.Channel)
continue
}
// we MUST remove the remote channel from our map as soon as we receive
// the ack (i.e. before passing it on to the session mux) on the session
// ending since the numbers are recycled.
delete(sessionsByRemoteChannel, fr.Channel)
default:
// pass on performative to the correct session
session, ok = sessionsByRemoteChannel[fr.Channel]
if !ok {
err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel", fr.Body, fr.Channel)
continue
}
}
select {
case session.rx <- fr:
case <-c.rxtxExit:
return
}
}
}
// connReader reads from the net.Conn, decodes frames, and passes them
// up via the conn.rxFrame and conn.rxProto channels.
func (c *Conn) connReader() {
defer close(c.rxDone)
// readFrame reads a complete frame from c.net.
// it assumes that any read deadline has already been applied.
// used externally by SASL only.
func (c *Conn) readFrame() (frames.Frame, error) {
switch {
// Cheaply reuse free buffer space when fully read.
case c.rxBuf.Len() == 0:
c.rxBuf.Reset()
buf := &buffer.Buffer{}
// Prevent excessive/unbounded growth by shifting data to beginning of buffer.
case int64(c.rxBuf.Size()) > int64(c.maxFrameSize):
c.rxBuf.Reclaim()
}
var (
negotiating = true // true during conn establishment, check for protoHeaders
currentHeader frames.Header // keep track of the current header, for frames split across multiple TCP packets
frameInProgress bool // true if in the middle of receiving data for currentHeader
)
for {
switch {
// Cheaply reuse free buffer space when fully read.
case buf.Len() == 0:
buf.Reset()
// Prevent excessive/unbounded growth by shifting data to beginning of buffer.
case int64(buf.Size()) > int64(c.maxFrameSize):
buf.Reclaim()
}
// need to read more if buf doesn't contain the complete frame
// or there's not enough in buf to parse the header
if frameInProgress || buf.Len() < frames.HeaderSize {
if frameInProgress || c.rxBuf.Len() < frames.HeaderSize {
// we MUST reset the idle timeout before each read from net.Conn
if c.idleTimeout > 0 {
_ = c.net.SetReadDeadline(time.Now().Add(c.idleTimeout))
}
err := buf.ReadFromOnce(c.net)
err := c.rxBuf.ReadFromOnce(c.net)
if err != nil {
debug.Log(1, "connReader error: %v", err)
select {
// check if error was due to close in progress
case <-c.done:
return
// if there is a pending connReaderRun function, execute it
case f := <-c.connReaderRun:
f()
continue
// send error to mux and return
default:
c.connErr <- err
return
}
debug.Log(1, "readFrame error: %v", err)
return frames.Frame{}, err
}
}
// read more if buf doesn't contain enough to parse the header
if buf.Len() < frames.HeaderSize {
continue
}
// during negotiation, check for proto frames
if negotiating && bytes.Equal(buf.Bytes()[:4], []byte{'A', 'M', 'Q', 'P'}) {
const protoHeaderSize = 8
buf, ok := buf.Next(protoHeaderSize)
if !ok {
c.connErr <- errors.New("invalid protoHeader")
return
}
_ = buf[7]
if !bytes.Equal(buf[:4], []byte{'A', 'M', 'Q', 'P'}) {
c.connErr <- fmt.Errorf("unexpected protocol %q", buf[:4])
return
}
p := protoHeader{
ProtoID: protoID(buf[4]),
Major: buf[5],
Minor: buf[6],
Revision: buf[7],
}
if p.Major != 1 || p.Minor != 0 || p.Revision != 0 {
c.connErr <- fmt.Errorf("unexpected protocol version %d.%d.%d", p.Major, p.Minor, p.Revision)
return
}
// negotiation is complete once an AMQP proto frame is received
if p.ProtoID == protoAMQP {
negotiating = false
}
// send proto header
select {
case <-c.done:
return
case c.rxProto <- p:
}
if c.rxBuf.Len() < frames.HeaderSize {
continue
}
// parse the header if a frame isn't in progress
if !frameInProgress {
var err error
currentHeader, err = frames.ParseHeader(buf)
currentHeader, err = frames.ParseHeader(&c.rxBuf)
if err != nil {
c.connErr <- err
return
return frames.Frame{}, err
}
frameInProgress = true
}
// check size is reasonable
if currentHeader.Size > math.MaxInt32 { // make max size configurable
c.connErr <- errors.New("payload too large")
return
return frames.Frame{}, errors.New("payload too large")
}
bodySize := int64(currentHeader.Size - frames.HeaderSize)
// the full frame has been received
if int64(buf.Len()) < bodySize {
// the full frame hasn't been received, keep reading
if int64(c.rxBuf.Len()) < bodySize {
continue
}
frameInProgress = false
// check if body is empty (keepalive)
if bodySize == 0 {
debug.Log(3, "received keep-alive frame")
continue
}
// parse the frame
b, ok := buf.Next(bodySize)
b, ok := c.rxBuf.Next(bodySize)
if !ok {
c.connErr <- fmt.Errorf("buffer EOF; requested bytes: %d, actual size: %d", bodySize, buf.Len())
return
return frames.Frame{}, fmt.Errorf("buffer EOF; requested bytes: %d, actual size: %d", bodySize, c.rxBuf.Len())
}
parsedBody, err := frames.ParseBody(buffer.New(b))
if err != nil {
c.connErr <- err
return
return frames.Frame{}, err
}
// send to mux
select {
case <-c.done:
return
case c.rxFrame <- frames.Frame{Channel: currentHeader.Channel, Body: parsedBody}:
}
return frames.Frame{Channel: currentHeader.Channel, Body: parsedBody}, nil
}
}
func (c *Conn) connWriter() {
defer close(c.txDone)
defer func() {
close(c.txDone)
c.close()
}()
// disable write timeout
if c.connectTimeout != 0 {
@ -724,8 +634,8 @@ func (c *Conn) connWriter() {
var err error
for {
if err != nil {
debug.Log(1, "connWriter error: %v", err)
c.connErr <- err
debug.Log(1, "connWriter terminal error: %v", err)
c.txErr = err
return
}
@ -750,11 +660,14 @@ func (c *Conn) connWriter() {
// possibly drained, then reset.)
// connection complete
case <-c.done:
// send close
case <-c.rxtxExit:
// send close performative. note that the spec says we
// SHOULD wait for the ack but we don't HAVE to, in order
// to be resilient to bad actors etc. so we just send
// the close performative and exit.
cls := &frames.PerformClose{}
debug.Log(1, "TX (connWriter): %s", cls)
_ = c.writeFrame(frames.Frame{
c.txErr = c.writeFrame(frames.Frame{
Type: frames.TypeAMQP,
Body: cls,
})
@ -810,7 +723,7 @@ func (c *Conn) sendFrame(fr frames.Frame) error {
case c.txFrame <- fr:
return nil
case <-c.done:
return c.err()
return c.doneErr
}
}
@ -876,55 +789,77 @@ func (c *Conn) exchangeProtoHeader(pID protoID) (stateFunc, error) {
// readProtoHeader reads a protocol header packet from c.rxProto.
func (c *Conn) readProtoHeader() (protoHeader, error) {
var deadline <-chan time.Time
if c.connectTimeout != 0 {
deadline = time.After(c.connectTimeout)
const protoHeaderSize = 8
// only read from the network once our buffer has been exhausted.
// TODO: this preserves existing behavior as some tests rely on this
// implementation detail (it lets you replay a stream of bytes). we
// might want to consider removing this and fixing the tests as the
// protocol doesn't actually work this way.
if c.rxBuf.Len() == 0 {
for {
if c.connectTimeout != 0 {
_ = c.net.SetReadDeadline(time.Now().Add(c.connectTimeout))
}
err := c.rxBuf.ReadFromOnce(c.net)
if err != nil {
return protoHeader{}, err
}
// read more if buf doesn't contain enough to parse the header
if c.rxBuf.Len() >= protoHeaderSize {
break
}
}
// reset outside the loop
if c.connectTimeout != 0 {
_ = c.net.SetReadDeadline(time.Time{})
}
}
var p protoHeader
select {
case p = <-c.rxProto:
return p, nil
case err := <-c.connErr:
return p, err
case fr := <-c.rxFrame:
return p, fmt.Errorf("readProtoHeader: unexpected frame %#v", fr)
case <-deadline:
return p, errors.New("amqp: timeout waiting for response")
buf, ok := c.rxBuf.Next(protoHeaderSize)
if !ok {
return protoHeader{}, errors.New("invalid protoHeader")
}
// bounds check hint to compiler; see golang.org/issue/14808
_ = buf[protoHeaderSize-1]
if !bytes.Equal(buf[:4], []byte{'A', 'M', 'Q', 'P'}) {
return protoHeader{}, fmt.Errorf("unexpected protocol %q", buf[:4])
}
p := protoHeader{
ProtoID: protoID(buf[4]),
Major: buf[5],
Minor: buf[6],
Revision: buf[7],
}
if p.Major != 1 || p.Minor != 0 || p.Revision != 0 {
return protoHeader{}, fmt.Errorf("unexpected protocol version %d.%d.%d", p.Major, p.Minor, p.Revision)
}
return p, nil
}
// startTLS wraps the conn with TLS and returns to Client.negotiateProto
func (c *Conn) startTLS() (stateFunc, error) {
c.initTLSConfig()
// buffered so connReaderRun won't block
done := make(chan error, 1)
_ = c.net.SetReadDeadline(time.Time{}) // clear timeout
// this function will be executed by connReader
c.connReaderRun <- func() {
defer close(done)
_ = c.net.SetReadDeadline(time.Time{}) // clear timeout
// wrap existing net.Conn and perform TLS handshake
tlsConn := tls.Client(c.net, c.tlsConfig)
if c.connectTimeout != 0 {
_ = tlsConn.SetWriteDeadline(time.Now().Add(c.connectTimeout))
}
done <- tlsConn.Handshake()
// TODO: return?
// swap net.Conn
c.net = tlsConn
c.tlsComplete = true
}
// set deadline to interrupt connReader
_ = c.net.SetReadDeadline(time.Time{}.Add(1))
if err := <-done; err != nil {
// wrap existing net.Conn and perform TLS handshake
tlsConn := tls.Client(c.net, c.tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return nil, err
}
// swap net.Conn
c.net = tlsConn
c.tlsComplete = true
// go to next protocol
return c.negotiateProto, nil
}
@ -951,7 +886,7 @@ func (c *Conn) openAMQP() (stateFunc, error) {
}
// get the response
fr, err := c.readFrame()
fr, err := c.readSingleFrame()
if err != nil {
return nil, err
}
@ -981,7 +916,7 @@ func (c *Conn) openAMQP() (stateFunc, error) {
// mechanism specified by the server
func (c *Conn) negotiateSASL() (stateFunc, error) {
// read mechanisms frame
fr, err := c.readFrame()
fr, err := c.readSingleFrame()
if err != nil {
return nil, err
}
@ -1010,7 +945,7 @@ func (c *Conn) negotiateSASL() (stateFunc, error) {
// used externally by SASL only.
func (c *Conn) saslOutcome() (stateFunc, error) {
// read outcome frame
fr, err := c.readFrame()
fr, err := c.readSingleFrame()
if err != nil {
return nil, err
}
@ -1030,27 +965,21 @@ func (c *Conn) saslOutcome() (stateFunc, error) {
return c.negotiateProto, nil
}
// readFrame is used during connection establishment to read a single frame.
// readSingleFrame is used during connection establishment to read a single frame.
//
// After setup, conn.mux handles incoming frames.
// used externally by SASL only.
func (c *Conn) readFrame() (frames.Frame, error) {
var deadline <-chan time.Time
// After setup, conn.connReader handles incoming frames.
func (c *Conn) readSingleFrame() (frames.Frame, error) {
if c.connectTimeout != 0 {
deadline = time.After(c.connectTimeout)
_ = c.net.SetDeadline(time.Now().Add(c.connectTimeout))
defer func() { _ = c.net.SetDeadline(time.Time{}) }()
}
var fr frames.Frame
select {
case fr = <-c.rxFrame:
return fr, nil
case err := <-c.connErr:
return fr, err
case p := <-c.rxProto:
return fr, fmt.Errorf("unexpected protocol header %#v", p)
case <-deadline:
return fr, errors.New("amqp: timeout waiting for response")
fr, err := c.readFrame()
if err != nil {
return frames.Frame{}, err
}
return fr, nil
}
type protoHeader struct {

Просмотреть файл

@ -316,30 +316,44 @@ func TestClose(t *testing.T) {
}
func TestServerSideClose(t *testing.T) {
netConn := mocks.NewNetConn(senderFrameHandlerNoUnhandled(SenderSettleModeUnsettled))
closeReceived := make(chan struct{})
responder := func(req frames.FrameBody) ([]byte, error) {
switch req.(type) {
case *mocks.AMQPProto:
return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil
case *frames.PerformOpen:
return mocks.PerformOpen("container")
case *frames.PerformClose:
close(closeReceived)
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := mocks.NewNetConn(responder)
conn, err := newConn(netConn, nil)
require.NoError(t, err)
require.NoError(t, conn.start())
fr, err := mocks.PerformClose(nil)
require.NoError(t, err)
netConn.SendFrame(fr)
<-closeReceived
err = conn.Close()
require.NoError(t, err)
// with error
netConn = mocks.NewNetConn(senderFrameHandlerNoUnhandled(SenderSettleModeUnsettled))
closeReceived = make(chan struct{})
netConn = mocks.NewNetConn(responder)
conn, err = newConn(netConn, nil)
require.NoError(t, err)
require.NoError(t, conn.start())
fr, err = mocks.PerformClose(&Error{Condition: "Close", Description: "mock server error"})
require.NoError(t, err)
netConn.SendFrame(fr)
// wait a bit for connReader to read from the mock
time.Sleep(100 * time.Millisecond)
<-closeReceived
err = conn.Close()
var connErr *ConnError
if !errors.As(err, &connErr) {
t.Fatalf("unexpected error type %T", err)
}
require.ErrorAs(t, err, &connErr)
require.Equal(t, "*Error{Condition: Close, Description: mock server error, Info: map[]}", connErr.Error())
}
@ -358,6 +372,8 @@ func TestKeepAlives(t *testing.T) {
case *mocks.KeepAlive:
keepAlives <- struct{}{}
return nil, nil
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -380,6 +396,48 @@ func TestKeepAlives(t *testing.T) {
require.NoError(t, conn.Close())
}
func TestKeepAlivesIdleTimeout(t *testing.T) {
responder := func(req frames.FrameBody) ([]byte, error) {
switch req.(type) {
case *mocks.AMQPProto:
return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil
case *frames.PerformOpen:
return mocks.EncodeFrame(mocks.FrameAMQP, 0, &frames.PerformOpen{ContainerID: "container", IdleTimeout: time.Minute})
case *mocks.KeepAlive:
return nil, nil
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
}
const idleTimeout = 100 * time.Millisecond
netConn := mocks.NewNetConn(responder)
conn, err := newConn(netConn, &ConnOptions{
IdleTimeout: idleTimeout,
})
require.NoError(t, err)
require.NoError(t, conn.start())
done := make(chan struct{})
defer close(done)
go func() {
for {
select {
case <-time.After(idleTimeout / 2):
netConn.SendKeepAlive()
case <-done:
return
}
}
}()
time.Sleep(2 * idleTimeout)
require.NoError(t, conn.Close())
}
func TestConnReaderError(t *testing.T) {
netConn := mocks.NewNetConn(senderFrameHandlerNoUnhandled(SenderSettleModeUnsettled))
conn, err := newConn(netConn, nil)
@ -415,6 +473,29 @@ func TestConnWriterError(t *testing.T) {
}
}
func TestConnWithZeroByteReads(t *testing.T) {
responder := func(req frames.FrameBody) ([]byte, error) {
switch req.(type) {
case *mocks.AMQPProto:
return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil
case *frames.PerformOpen:
return mocks.PerformOpen("container")
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
}
netConn := mocks.NewNetConn(responder)
netConn.SendFrame([]byte{})
conn, err := newConn(netConn, nil)
require.NoError(t, err)
require.NoError(t, conn.start())
require.NoError(t, conn.Close())
}
type mockDialer struct {
resp func(frames.FrameBody) ([]byte, error)
}
@ -465,6 +546,8 @@ func TestClientClose(t *testing.T) {
return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil
case *frames.PerformOpen:
return mocks.PerformOpen("container")
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -551,6 +634,8 @@ func TestClientNewSession(t *testing.T) {
return nil, fmt.Errorf("unexpected incoming window %d", tt.OutgoingWindow)
}
return mocks.PerformBegin(channelNum)
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -593,6 +678,8 @@ func TestClientMultipleSessions(t *testing.T) {
b, err := mocks.PerformBegin(channelNum)
channelNum++
return b, err
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -636,6 +723,8 @@ func TestClientTooManySessions(t *testing.T) {
b, err := mocks.PerformBegin(channelNum)
channelNum++
return b, err
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}

Просмотреть файл

@ -13,11 +13,13 @@ import (
"github.com/Azure/go-amqp/internal/frames"
"github.com/Azure/go-amqp/internal/testconn"
"github.com/fortytw2/leaktest"
"github.com/stretchr/testify/require"
)
func fuzzConn(data []byte) int {
// Receive
client, err := NewConn(testconn.New(data), &ConnOptions{
Timeout: 10 * time.Millisecond,
IdleTimeout: 10 * time.Millisecond,
SASLType: SASLTypePlain("listen", "3aCXZYFcuZA89xe6lZkfYJvOPnTGipA3ap7NvPruBhI="),
})
@ -464,8 +466,9 @@ func TestFuzzConnCrashers(t *testing.T) {
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
defer leaktest.Check(t)()
fuzzConn([]byte(tt))
end := leaktest.Check(t)
require.Zero(t, fuzzConn([]byte(tt)))
end()
})
}
}

Просмотреть файл

@ -759,15 +759,14 @@ func TestMultipleSessionsOpenClose(t *testing.T) {
if localBrokerAddr == "" {
t.Skip()
}
// TODO: connReader and connWriter goroutines will leak
//checkLeaks := leaktest.Check(t)
checkLeaks := leaktest.Check(t)
// Create client
client, err := amqp.Dial(localBrokerAddr, nil)
if err != nil {
t.Fatal(err)
}
defer client.Close()
sessions := [10]*amqp.Session{}
for i := 0; i < 10; i++ {
@ -790,22 +789,24 @@ func TestMultipleSessionsOpenClose(t *testing.T) {
}
}
}
//checkLeaks()
client.Close()
checkLeaks()
}
func TestConcurrentSessionsOpenClose(t *testing.T) {
if localBrokerAddr == "" {
t.Skip()
}
// TODO: connReader and connWriter goroutines will leak
//checkLeaks := leaktest.Check(t)
checkLeaks := leaktest.Check(t)
// Create client
client, err := amqp.Dial(localBrokerAddr, nil)
if err != nil {
t.Fatal(err)
}
defer client.Close()
wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
@ -827,7 +828,9 @@ func TestConcurrentSessionsOpenClose(t *testing.T) {
}()
}
wg.Wait()
//checkLeaks()
client.Close()
checkLeaks()
}
func repeatStrings(count int, strs ...string) []string {

Просмотреть файл

@ -27,6 +27,7 @@ func NewNetConn(resp func(frames.FrameBody) ([]byte, error)) *NetConn {
// writes from blocking shutdown. the size was arbitrarily picked.
readData: make(chan []byte, 10),
readClose: make(chan struct{}),
readDL: newNopTimer(), // default, no deadline
}
}
@ -47,7 +48,7 @@ type NetConn struct {
WriteErr chan error
resp func(frames.FrameBody) ([]byte, error)
readDL *time.Timer
readDL readTimer
readData chan []byte
readClose chan struct{}
closed bool
@ -90,15 +91,15 @@ func (n *NetConn) SendMultiFrameTransfer(remoteChannel uint16, linkHandle, deliv
func (n *NetConn) Read(b []byte) (int, error) {
select {
case <-n.readClose:
return 0, errors.New("mock connection was closed")
return 0, net.ErrClosed
default:
// not closed yet
}
select {
case <-n.readClose:
return 0, errors.New("mock connection was closed")
case <-n.readDL.C:
return 0, net.ErrClosed
case <-n.readDL.C():
return 0, errors.New("mock connection read deadline exceeded")
case rd := <-n.readData:
return copy(b, rd), nil
@ -116,7 +117,7 @@ func (n *NetConn) Read(b []byte) (int, error) {
func (n *NetConn) Write(b []byte) (int, error) {
select {
case <-n.readClose:
return 0, errors.New("mock connection was closed")
return 0, net.ErrClosed
default:
// not closed yet
}
@ -142,7 +143,7 @@ func (n *NetConn) Write(b []byte) (int, error) {
return len(b), nil
}
// Close is called by conn.close when conn.mux unwinds.
// Close is called by conn.close.
func (n *NetConn) Close() error {
if n.closed {
return errors.New("double close")
@ -175,9 +176,9 @@ func (n *NetConn) SetReadDeadline(t time.Time) error {
// called by conn.connReader before calling Read
// stop the last timer if available
if n.readDL != nil && !n.readDL.Stop() {
<-n.readDL.C
<-n.readDL.C()
}
n.readDL = time.NewTimer(time.Until(t))
n.readDL = timer{t: time.NewTimer(time.Until(t))}
return nil
}
@ -437,3 +438,37 @@ func encodeMultiFrameTransfer(remoteChannel uint16, linkHandle, deliveryID uint3
}
return frameData, nil
}
type readTimer interface {
C() <-chan time.Time
Stop() bool
}
func newNopTimer() nopTimer {
return nopTimer{t: make(chan time.Time)}
}
type nopTimer struct {
t chan time.Time
}
func (n nopTimer) C() <-chan time.Time {
return n.t
}
func (n nopTimer) Stop() bool {
close(n.t)
return true
}
type timer struct {
t *time.Timer
}
func (t timer) C() <-chan time.Time {
return t.t.C
}
func (t timer) Stop() bool {
return t.t.Stop()
}

Просмотреть файл

@ -35,7 +35,12 @@ func (c *Conn) Read(b []byte) (int, error) {
}
time.Sleep(1 * time.Millisecond)
n := copy(b, c.data[0])
c.data = c.data[1:]
// only move on to the next chunk if this one was entirely consumed
if n == len(c.data[0]) {
c.data = c.data[1:]
} else {
c.data[0] = c.data[0][n:]
}
return n, nil
}
@ -63,7 +68,7 @@ func (c *Conn) RemoteAddr() net.Addr {
}
func (c *Conn) SetDeadline(t time.Time) error {
return nil
return c.SetReadDeadline(t)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
@ -84,5 +89,5 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return nil
return errors.New("testconn.SetWriteDeadline NYI")
}

Просмотреть файл

@ -52,6 +52,8 @@ func TestSenderMethodsNoSend(t *testing.T) {
return mocks.SenderAttach(0, tt.Name, 0, SenderSettleModeUnsettled)
case *frames.PerformDetach:
return mocks.PerformDetach(0, 0, nil)
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -222,6 +224,8 @@ func TestSenderAttachError(t *testing.T) {
// we don't need to respond to the ack
detachAck <- true
return nil, nil
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -466,6 +470,8 @@ func TestSenderSendRejectedNoDetach(t *testing.T) {
return mocks.PerformDisposition(encoding.RoleReceiver, 0, *tt.DeliveryID, nil, &encoding.StateAccepted{})
case *frames.PerformDetach:
return mocks.PerformDetach(0, 0, nil)
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -602,6 +608,8 @@ func TestSenderSendMsgTooBig(t *testing.T) {
return mocks.PerformDisposition(encoding.RoleReceiver, 0, *tt.DeliveryID, nil, &encoding.StateAccepted{})
case *frames.PerformDetach:
return mocks.PerformDetach(0, 0, nil)
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -710,6 +718,8 @@ func TestSenderSendMultiTransfer(t *testing.T) {
return mocks.PerformDisposition(encoding.RoleReceiver, 0, deliveryID, nil, &encoding.StateAccepted{})
case *frames.PerformDetach:
return mocks.PerformDetach(0, 0, nil)
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -730,7 +740,7 @@ func TestSenderSendMultiTransfer(t *testing.T) {
sendInitialFlowFrame(t, netConn, 0, 100)
ctx, cancel = context.WithTimeout(context.Background(), 100000*time.Millisecond)
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
payload := make([]byte, maxReceiverFrameSize*4)
for i := 0; i < maxReceiverFrameSize*4; i++ {
payload[i] = byte(i % 256)

Просмотреть файл

@ -42,9 +42,9 @@ type SessionOptions struct {
// A session multiplexes Receivers.
type Session struct {
channel uint16 // session's local channel
remoteChannel uint16 // session's remote channel, owned by conn.mux
remoteChannel uint16 // session's remote channel, owned by conn.connReader
conn *Conn // underlying conn
rx chan frames.Frame // frames destined for this session are sent on this chan by conn.mux
rx chan frames.Frame // frames destined for this session are sent on this chan by conn.connReader
tx chan frames.FrameBody // non-transfer frames to be sent; session must track disposition
txTransfer chan *frames.PerformTransfer // transfer frames to be sent; session must track disposition
@ -139,7 +139,7 @@ func (s *Session) begin(ctx context.Context) error {
}()
return ctx.Err()
case <-s.conn.done:
return s.conn.err()
return s.conn.doneErr
case fr = <-s.rx:
// received ack that session was created
}
@ -148,7 +148,7 @@ func (s *Session) begin(ctx context.Context) error {
begin, ok := fr.Body.(*frames.PerformBegin)
if !ok {
// this codepath is hard to hit (impossible?). if the response isn't a PerformBegin and we've not
// yet seen the remote channel number, the default clause in conn.mux will protect us from that.
// yet seen the remote channel number, the default clause in conn.connReader will protect us from that.
// if we have seen the remote channel number then it's likely the session.mux for that channel will
// either swallow the frame or blow up in some other way, both causing this call to hang.
// deallocate session on error. we can't call
@ -279,7 +279,7 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) {
select {
// conn has completed, exit
case <-s.conn.done:
s.err = s.conn.err()
s.err = s.conn.doneErr
return
// session is being closed by user
@ -298,7 +298,7 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) {
break EndLoop
}
case <-s.conn.done:
s.err = s.conn.err()
s.err = s.conn.doneErr
return
}
}

Просмотреть файл

@ -35,6 +35,8 @@ func TestSessionClose(t *testing.T) {
}
channelNum--
return b, nil
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -48,8 +50,9 @@ func TestSessionClose(t *testing.T) {
session, err := client.NewSession(ctx, nil)
cancel()
require.NoErrorf(t, err, "iteration %d", i)
require.Equalf(t, channelNum-1, session.channel, "iteration %d", i)
ctx, cancel = context.WithTimeout(context.Background(), time.Second)
require.Equalf(t, uint16(0), session.channel, "iteration %d", i)
require.Equalf(t, channelNum-1, session.remoteChannel, "iteration %d", i)
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
err = session.Close(ctx)
cancel()
require.NoErrorf(t, err, "iteration %d", i)
@ -68,6 +71,8 @@ func TestSessionServerClose(t *testing.T) {
return mocks.PerformBegin(0)
case *frames.PerformEnd:
return nil, nil // swallow
case *frames.PerformClose:
return nil, nil // swallow
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -112,6 +117,8 @@ func TestSessionCloseTimeout(t *testing.T) {
// sleep to trigger session close timeout
time.Sleep(1 * time.Second)
return mocks.PerformEnd(0, nil)
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -199,6 +206,8 @@ func TestSessionNewReceiverBatchingOneCredit(t *testing.T) {
return mocks.ReceiverAttach(0, tt.Name, 0, ReceiverSettleModeFirst, nil)
case *frames.PerformFlow:
return nil, nil
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -242,6 +251,8 @@ func TestSessionNewReceiverBatchingEnabled(t *testing.T) {
return mocks.ReceiverAttach(0, tt.Name, 0, ReceiverSettleModeFirst, nil)
case *frames.PerformFlow:
return nil, nil
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -284,6 +295,8 @@ func TestSessionNewReceiverMismatchedLinkName(t *testing.T) {
return mocks.PerformEnd(0, nil)
case *frames.PerformAttach:
return mocks.ReceiverAttach(0, "wrong_name", 0, ReceiverSettleModeFirst, nil)
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -351,6 +364,8 @@ func TestSessionNewSenderMismatchedLinkName(t *testing.T) {
return mocks.PerformEnd(0, nil)
case *frames.PerformAttach:
return mocks.SenderAttach(0, "wrong_name", 0, SenderSettleModeUnsettled)
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -516,6 +531,8 @@ func TestSessionFlowFrameWithEcho(t *testing.T) {
return nil, nil
case *frames.PerformEnd:
return mocks.PerformEnd(0, nil)
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}
@ -562,6 +579,8 @@ func TestSessionInvalidAttachDeadlock(t *testing.T) {
case *frames.PerformAttach:
enqueueFrames(tt.Name)
return nil, nil
case *frames.PerformClose:
return mocks.PerformClose(nil)
default:
return nil, fmt.Errorf("unhandled frame %T", req)
}