зеркало из https://github.com/Azure/go-amqp.git
Remove conn.mux (#173)
* Remove conn.mux conn.connReader() dispatches frames directly to sessions now. Added conn.NextSession() and conn.DeleteSession() for deterministic session management. Channel numbers are now recycled immediately which prompted a fix for TestSessionClose. Fixed various tests to handle close frame (the error was being swallowed before). Tests that utilize testconn were silently failing due to a bug in Conn.Read which has been fixed. * simplify reader/writer error handling * clean-up * fix testconn.SetDeadline * update changelog * refine error check when closing net.Conn fixed propagation of RemoteErr on close * always reset idle read timeout before reading * remove connReaderRun as it's no longer necessary * consolidate calls to closeOnce.Do * replace magic number with constant
This commit is contained in:
Родитель
1b6c612eb6
Коммит
5515441808
|
@ -1,5 +1,11 @@
|
|||
# Release History
|
||||
|
||||
## 0.18.1 (unreleased)
|
||||
|
||||
### Other Changes
|
||||
|
||||
* The connection mux goroutine has been removed, eliminating a potential source of deadlocks.
|
||||
|
||||
## 0.18.0 (2022-12-06)
|
||||
|
||||
### Features Added
|
||||
|
|
527
conn.go
527
conn.go
|
@ -149,14 +149,12 @@ type Conn struct {
|
|||
peerMaxFrameSize uint32 // maximum frame size peer will accept
|
||||
|
||||
// conn state
|
||||
doneErrMu sync.Mutex // mux holds doneErr from start until shutdown completes; operations are sequential before mux is started
|
||||
doneErr error // error to be returned to client
|
||||
done chan struct{} // indicates the connection is done
|
||||
done chan struct{} // indicates the connection has terminated
|
||||
doneErr error // contains the error state returned from Close(); DO NOT TOUCH outside of conn.go until Done has been closed!
|
||||
|
||||
// mux
|
||||
connErr chan error // connReader/Writer notifications of an error
|
||||
closeMux chan struct{} // indicates that the mux should stop
|
||||
closeMuxOnce sync.Once
|
||||
// connReader and connWriter management
|
||||
rxtxExit chan struct{} // signals connReader and connWriter to exit
|
||||
closeOnce sync.Once // ensures that close() is only called once
|
||||
|
||||
// session tracking
|
||||
channels *bitmap.Bitmap
|
||||
|
@ -164,15 +162,15 @@ type Conn struct {
|
|||
sessionsByChannelMu sync.RWMutex
|
||||
|
||||
// connReader
|
||||
rxProto chan protoHeader // protoHeaders received by connReader
|
||||
rxFrame chan frames.Frame // AMQP frames received by connReader
|
||||
rxDone chan struct{}
|
||||
connReaderRun chan func() // functions to be run by conn reader (set deadline on conn to run)
|
||||
rxBuf buffer.Buffer // incoming bytes buffer
|
||||
rxDone chan struct{} // closed when connReader exits
|
||||
rxErr error // contains last error reading from c.net; DO NOT TOUCH outside of connReader until rxDone has been closed!
|
||||
|
||||
// connWriter
|
||||
txFrame chan frames.Frame // AMQP frames to be sent by connWriter
|
||||
txBuf buffer.Buffer // buffer for marshaling frames before transmitting
|
||||
txDone chan struct{}
|
||||
txDone chan struct{} // closed when connWriter exits
|
||||
txErr error // contains last error writing to c.net; DO NOT TOUCH outside of connWriter until txDone has been closed!
|
||||
}
|
||||
|
||||
// used to abstract the underlying dialer for testing purposes
|
||||
|
@ -256,12 +254,8 @@ func newConn(netConn net.Conn, opts *ConnOptions) (*Conn, error) {
|
|||
idleTimeout: defaultIdleTimeout,
|
||||
containerID: shared.RandString(40),
|
||||
done: make(chan struct{}),
|
||||
connErr: make(chan error, 2), // buffered to ensure connReader/Writer won't leak
|
||||
closeMux: make(chan struct{}),
|
||||
rxProto: make(chan protoHeader),
|
||||
rxFrame: make(chan frames.Frame),
|
||||
rxtxExit: make(chan struct{}),
|
||||
rxDone: make(chan struct{}),
|
||||
connReaderRun: make(chan func(), 1), // buffered to allow queueing function before interrupt
|
||||
txFrame: make(chan frames.Frame),
|
||||
txDone: make(chan struct{}),
|
||||
sessionsByChannel: map[uint16]*Session{},
|
||||
|
@ -311,7 +305,6 @@ func newConn(netConn net.Conn, opts *ConnOptions) (*Conn, error) {
|
|||
if opts.dialer != nil {
|
||||
c.dialer = opts.dialer
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
|
@ -327,12 +320,9 @@ func (c *Conn) initTLSConfig() {
|
|||
}
|
||||
}
|
||||
|
||||
// Start establishes the connection and begins multiplexing network IO.
|
||||
// start establishes the connection and begins multiplexing network IO.
|
||||
// It is an error to call Start() on a connection that's been closed.
|
||||
func (c *Conn) start() error {
|
||||
// start reader
|
||||
go c.connReader()
|
||||
|
||||
// run connection establishment state machine
|
||||
for state := c.negotiateProto; state != nil; {
|
||||
var err error
|
||||
|
@ -340,6 +330,7 @@ func (c *Conn) start() error {
|
|||
// check if err occurred
|
||||
if err != nil {
|
||||
close(c.txDone) // close here since connWriter hasn't been started yet
|
||||
close(c.rxDone)
|
||||
_ = c.Close()
|
||||
return err
|
||||
}
|
||||
|
@ -349,72 +340,61 @@ func (c *Conn) start() error {
|
|||
// this is because our peer can tell us the max channels they support.
|
||||
c.channels = bitmap.New(uint32(c.channelMax))
|
||||
|
||||
// start multiplexor and writer
|
||||
go c.mux()
|
||||
go c.connWriter()
|
||||
go c.connReader()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *Conn) Close() error {
|
||||
c.closeMuxOnce.Do(func() { close(c.closeMux) })
|
||||
err := c.err()
|
||||
c.close()
|
||||
var connErr *ConnError
|
||||
if errors.As(err, &connErr) && connErr.RemoteErr == nil && connErr.inner == nil {
|
||||
if errors.As(c.doneErr, &connErr) && connErr.RemoteErr == nil && connErr.inner == nil {
|
||||
// an empty ConnectionError means the connection was closed by the caller
|
||||
// or as requested by the peer and no error was provided in the close frame.
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
|
||||
// there was an error during shut-down or connReader/connWriter
|
||||
// experienced a terminal error
|
||||
return c.doneErr
|
||||
}
|
||||
|
||||
// close should only be called by conn.mux.
|
||||
// close is called once, either from Close() or when connReader/connWriter exits
|
||||
func (c *Conn) close() {
|
||||
close(c.done) // notify goroutines and blocked functions to exit
|
||||
c.closeOnce.Do(func() {
|
||||
defer close(c.done)
|
||||
|
||||
// wait for writing to stop, allows it to send the final close frame
|
||||
<-c.txDone
|
||||
close(c.rxtxExit)
|
||||
|
||||
// reading from connErr in mux can race with closeMux, causing
|
||||
// a pending conn read/write error to be lost. now that the
|
||||
// mux has exited, drain any pending error.
|
||||
select {
|
||||
case err := <-c.connErr:
|
||||
c.doneErr = err
|
||||
default:
|
||||
// no pending read/write error
|
||||
}
|
||||
// wait for writing to stop, allows it to send the final close frame
|
||||
<-c.txDone
|
||||
|
||||
err := c.net.Close()
|
||||
switch {
|
||||
// conn.err already set
|
||||
// TODO: err info is lost, log it?
|
||||
case c.doneErr != nil:
|
||||
closeErr := c.net.Close()
|
||||
|
||||
// conn.err not set and c.net.Close() returned a non-nil error
|
||||
case err != nil:
|
||||
c.doneErr = err
|
||||
// check rxDone after closing net, otherwise may block
|
||||
// for up to c.idleTimeout
|
||||
<-c.rxDone
|
||||
|
||||
// no errors
|
||||
default:
|
||||
}
|
||||
if errors.Is(c.rxErr, net.ErrClosed) {
|
||||
// this is the expected error when the connection is closed, swallow it
|
||||
c.rxErr = nil
|
||||
}
|
||||
|
||||
// check rxDone after closing net, otherwise may block
|
||||
// for up to c.idleTimeout
|
||||
<-c.rxDone
|
||||
}
|
||||
|
||||
// Err returns the connection's error state after it's been closed.
|
||||
// Calling this on an open connection will block until the connection is closed.
|
||||
func (c *Conn) err() error {
|
||||
c.doneErrMu.Lock()
|
||||
defer c.doneErrMu.Unlock()
|
||||
var amqpErr *Error
|
||||
if errors.As(c.doneErr, &amqpErr) {
|
||||
return &ConnError{RemoteErr: amqpErr}
|
||||
}
|
||||
return &ConnError{inner: c.doneErr}
|
||||
if c.txErr == nil && c.rxErr == nil && closeErr == nil {
|
||||
// if there are no errors, it means user initiated close() and we shut down cleanly
|
||||
c.doneErr = &ConnError{}
|
||||
} else if amqpErr, ok := c.rxErr.(*Error); ok {
|
||||
// we experienced a peer-initiated close that contained an Error. return it
|
||||
c.doneErr = &ConnError{RemoteErr: amqpErr}
|
||||
} else if c.txErr != nil {
|
||||
c.doneErr = &ConnError{inner: c.txErr}
|
||||
} else if c.rxErr != nil {
|
||||
c.doneErr = &ConnError{inner: c.rxErr}
|
||||
} else {
|
||||
c.doneErr = &ConnError{inner: closeErr}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) NewSession(ctx context.Context, opts *SessionOptions) (*Session, error) {
|
||||
|
@ -454,251 +434,181 @@ func (c *Conn) deleteSession(s *Session) {
|
|||
c.channels.Remove(uint32(s.channel))
|
||||
}
|
||||
|
||||
// mux is started in it's own goroutine after initial connection establishment.
|
||||
// It handles muxing of sessions, keepalives, and connection errors.
|
||||
func (c *Conn) mux() {
|
||||
var (
|
||||
// map channels to sessions
|
||||
sessionsByRemoteChannel = make(map[uint16]*Session)
|
||||
)
|
||||
|
||||
// hold the errMu lock until error or done
|
||||
c.doneErrMu.Lock()
|
||||
defer c.doneErrMu.Unlock()
|
||||
defer c.close() // defer order is important. c.errMu unlock indicates that connection is finally complete
|
||||
// connReader reads from the net.Conn, decodes frames, and either handles
|
||||
// them here as appropriate or sends them to the session.rx channel.
|
||||
func (c *Conn) connReader() {
|
||||
defer func() {
|
||||
close(c.rxDone)
|
||||
c.close()
|
||||
}()
|
||||
|
||||
var sessionsByRemoteChannel = make(map[uint16]*Session)
|
||||
var err error
|
||||
for {
|
||||
// check if last loop returned an error
|
||||
if c.doneErr != nil {
|
||||
if err != nil {
|
||||
debug.Log(1, "connReader terminal error: %v", err)
|
||||
c.rxErr = err
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
// error from connReader
|
||||
case c.doneErr = <-c.connErr:
|
||||
var fr frames.Frame
|
||||
fr, err = c.readFrame()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// new frame from connReader
|
||||
case fr := <-c.rxFrame:
|
||||
var (
|
||||
session *Session
|
||||
ok bool
|
||||
)
|
||||
var (
|
||||
session *Session
|
||||
ok bool
|
||||
)
|
||||
|
||||
switch body := fr.Body.(type) {
|
||||
// Server initiated close.
|
||||
case *frames.PerformClose:
|
||||
if body.Error != nil {
|
||||
c.doneErr = body.Error
|
||||
}
|
||||
switch body := fr.Body.(type) {
|
||||
// Server initiated close.
|
||||
case *frames.PerformClose:
|
||||
// connWriter will send the close performative ack on its way out.
|
||||
// it's a SHOULD though, not a MUST.
|
||||
debug.Log(3, "RX (connReader): %s", body)
|
||||
if body.Error == nil {
|
||||
return
|
||||
|
||||
// RemoteChannel should be used when frame is Begin
|
||||
case *frames.PerformBegin:
|
||||
if body.RemoteChannel == nil {
|
||||
// since we only support remotely-initiated sessions, this is an error
|
||||
// TODO: it would be ideal to not have this kill the connection
|
||||
c.doneErr = fmt.Errorf("%T: nil RemoteChannel", fr.Body)
|
||||
break
|
||||
}
|
||||
c.sessionsByChannelMu.RLock()
|
||||
session, ok = c.sessionsByChannel[*body.RemoteChannel]
|
||||
c.sessionsByChannelMu.RUnlock()
|
||||
if !ok {
|
||||
c.doneErr = fmt.Errorf("unexpected remote channel number %d", *body.RemoteChannel)
|
||||
break
|
||||
}
|
||||
|
||||
session.remoteChannel = fr.Channel
|
||||
sessionsByRemoteChannel[fr.Channel] = session
|
||||
|
||||
case *frames.PerformEnd:
|
||||
session, ok = sessionsByRemoteChannel[fr.Channel]
|
||||
if !ok {
|
||||
c.doneErr = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel (PerformEnd)", fr.Body, fr.Channel)
|
||||
break
|
||||
}
|
||||
// we MUST remove the remote channel from our map as soon as we receive
|
||||
// the ack (i.e. before passing it on to the session mux) on the session
|
||||
// ending since the numbers are recycled.
|
||||
delete(sessionsByRemoteChannel, fr.Channel)
|
||||
|
||||
default:
|
||||
// pass on performative to the correct session
|
||||
session, ok = sessionsByRemoteChannel[fr.Channel]
|
||||
if !ok {
|
||||
c.doneErr = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel", fr.Body, fr.Channel)
|
||||
}
|
||||
}
|
||||
err = body.Error
|
||||
continue
|
||||
|
||||
// RemoteChannel should be used when frame is Begin
|
||||
case *frames.PerformBegin:
|
||||
if body.RemoteChannel == nil {
|
||||
// since we only support remotely-initiated sessions, this is an error
|
||||
// TODO: it would be ideal to not have this kill the connection
|
||||
err = fmt.Errorf("%T: nil RemoteChannel", fr.Body)
|
||||
continue
|
||||
}
|
||||
c.sessionsByChannelMu.RLock()
|
||||
session, ok = c.sessionsByChannel[*body.RemoteChannel]
|
||||
c.sessionsByChannelMu.RUnlock()
|
||||
if !ok {
|
||||
err = fmt.Errorf("unexpected remote channel number %d", *body.RemoteChannel)
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case session.rx <- fr:
|
||||
case <-c.closeMux:
|
||||
return
|
||||
}
|
||||
session.remoteChannel = fr.Channel
|
||||
sessionsByRemoteChannel[fr.Channel] = session
|
||||
|
||||
// connection is complete
|
||||
case <-c.closeMux:
|
||||
case *frames.PerformEnd:
|
||||
session, ok = sessionsByRemoteChannel[fr.Channel]
|
||||
if !ok {
|
||||
err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel (PerformEnd)", fr.Body, fr.Channel)
|
||||
continue
|
||||
}
|
||||
// we MUST remove the remote channel from our map as soon as we receive
|
||||
// the ack (i.e. before passing it on to the session mux) on the session
|
||||
// ending since the numbers are recycled.
|
||||
delete(sessionsByRemoteChannel, fr.Channel)
|
||||
|
||||
default:
|
||||
// pass on performative to the correct session
|
||||
session, ok = sessionsByRemoteChannel[fr.Channel]
|
||||
if !ok {
|
||||
err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel", fr.Body, fr.Channel)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case session.rx <- fr:
|
||||
case <-c.rxtxExit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// connReader reads from the net.Conn, decodes frames, and passes them
|
||||
// up via the conn.rxFrame and conn.rxProto channels.
|
||||
func (c *Conn) connReader() {
|
||||
defer close(c.rxDone)
|
||||
// readFrame reads a complete frame from c.net.
|
||||
// it assumes that any read deadline has already been applied.
|
||||
// used externally by SASL only.
|
||||
func (c *Conn) readFrame() (frames.Frame, error) {
|
||||
switch {
|
||||
// Cheaply reuse free buffer space when fully read.
|
||||
case c.rxBuf.Len() == 0:
|
||||
c.rxBuf.Reset()
|
||||
|
||||
buf := &buffer.Buffer{}
|
||||
// Prevent excessive/unbounded growth by shifting data to beginning of buffer.
|
||||
case int64(c.rxBuf.Size()) > int64(c.maxFrameSize):
|
||||
c.rxBuf.Reclaim()
|
||||
}
|
||||
|
||||
var (
|
||||
negotiating = true // true during conn establishment, check for protoHeaders
|
||||
currentHeader frames.Header // keep track of the current header, for frames split across multiple TCP packets
|
||||
frameInProgress bool // true if in the middle of receiving data for currentHeader
|
||||
)
|
||||
|
||||
for {
|
||||
switch {
|
||||
// Cheaply reuse free buffer space when fully read.
|
||||
case buf.Len() == 0:
|
||||
buf.Reset()
|
||||
|
||||
// Prevent excessive/unbounded growth by shifting data to beginning of buffer.
|
||||
case int64(buf.Size()) > int64(c.maxFrameSize):
|
||||
buf.Reclaim()
|
||||
}
|
||||
|
||||
// need to read more if buf doesn't contain the complete frame
|
||||
// or there's not enough in buf to parse the header
|
||||
if frameInProgress || buf.Len() < frames.HeaderSize {
|
||||
if frameInProgress || c.rxBuf.Len() < frames.HeaderSize {
|
||||
// we MUST reset the idle timeout before each read from net.Conn
|
||||
if c.idleTimeout > 0 {
|
||||
_ = c.net.SetReadDeadline(time.Now().Add(c.idleTimeout))
|
||||
}
|
||||
err := buf.ReadFromOnce(c.net)
|
||||
err := c.rxBuf.ReadFromOnce(c.net)
|
||||
if err != nil {
|
||||
debug.Log(1, "connReader error: %v", err)
|
||||
select {
|
||||
// check if error was due to close in progress
|
||||
case <-c.done:
|
||||
return
|
||||
|
||||
// if there is a pending connReaderRun function, execute it
|
||||
case f := <-c.connReaderRun:
|
||||
f()
|
||||
continue
|
||||
|
||||
// send error to mux and return
|
||||
default:
|
||||
c.connErr <- err
|
||||
return
|
||||
}
|
||||
debug.Log(1, "readFrame error: %v", err)
|
||||
return frames.Frame{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// read more if buf doesn't contain enough to parse the header
|
||||
if buf.Len() < frames.HeaderSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// during negotiation, check for proto frames
|
||||
if negotiating && bytes.Equal(buf.Bytes()[:4], []byte{'A', 'M', 'Q', 'P'}) {
|
||||
const protoHeaderSize = 8
|
||||
buf, ok := buf.Next(protoHeaderSize)
|
||||
if !ok {
|
||||
c.connErr <- errors.New("invalid protoHeader")
|
||||
return
|
||||
}
|
||||
_ = buf[7]
|
||||
|
||||
if !bytes.Equal(buf[:4], []byte{'A', 'M', 'Q', 'P'}) {
|
||||
c.connErr <- fmt.Errorf("unexpected protocol %q", buf[:4])
|
||||
return
|
||||
}
|
||||
|
||||
p := protoHeader{
|
||||
ProtoID: protoID(buf[4]),
|
||||
Major: buf[5],
|
||||
Minor: buf[6],
|
||||
Revision: buf[7],
|
||||
}
|
||||
|
||||
if p.Major != 1 || p.Minor != 0 || p.Revision != 0 {
|
||||
c.connErr <- fmt.Errorf("unexpected protocol version %d.%d.%d", p.Major, p.Minor, p.Revision)
|
||||
return
|
||||
}
|
||||
|
||||
// negotiation is complete once an AMQP proto frame is received
|
||||
if p.ProtoID == protoAMQP {
|
||||
negotiating = false
|
||||
}
|
||||
|
||||
// send proto header
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case c.rxProto <- p:
|
||||
}
|
||||
|
||||
if c.rxBuf.Len() < frames.HeaderSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// parse the header if a frame isn't in progress
|
||||
if !frameInProgress {
|
||||
var err error
|
||||
currentHeader, err = frames.ParseHeader(buf)
|
||||
currentHeader, err = frames.ParseHeader(&c.rxBuf)
|
||||
if err != nil {
|
||||
c.connErr <- err
|
||||
return
|
||||
return frames.Frame{}, err
|
||||
}
|
||||
frameInProgress = true
|
||||
}
|
||||
|
||||
// check size is reasonable
|
||||
if currentHeader.Size > math.MaxInt32 { // make max size configurable
|
||||
c.connErr <- errors.New("payload too large")
|
||||
return
|
||||
return frames.Frame{}, errors.New("payload too large")
|
||||
}
|
||||
|
||||
bodySize := int64(currentHeader.Size - frames.HeaderSize)
|
||||
|
||||
// the full frame has been received
|
||||
if int64(buf.Len()) < bodySize {
|
||||
// the full frame hasn't been received, keep reading
|
||||
if int64(c.rxBuf.Len()) < bodySize {
|
||||
continue
|
||||
}
|
||||
frameInProgress = false
|
||||
|
||||
// check if body is empty (keepalive)
|
||||
if bodySize == 0 {
|
||||
debug.Log(3, "received keep-alive frame")
|
||||
continue
|
||||
}
|
||||
|
||||
// parse the frame
|
||||
b, ok := buf.Next(bodySize)
|
||||
b, ok := c.rxBuf.Next(bodySize)
|
||||
if !ok {
|
||||
c.connErr <- fmt.Errorf("buffer EOF; requested bytes: %d, actual size: %d", bodySize, buf.Len())
|
||||
return
|
||||
return frames.Frame{}, fmt.Errorf("buffer EOF; requested bytes: %d, actual size: %d", bodySize, c.rxBuf.Len())
|
||||
}
|
||||
|
||||
parsedBody, err := frames.ParseBody(buffer.New(b))
|
||||
if err != nil {
|
||||
c.connErr <- err
|
||||
return
|
||||
return frames.Frame{}, err
|
||||
}
|
||||
|
||||
// send to mux
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case c.rxFrame <- frames.Frame{Channel: currentHeader.Channel, Body: parsedBody}:
|
||||
}
|
||||
return frames.Frame{Channel: currentHeader.Channel, Body: parsedBody}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) connWriter() {
|
||||
defer close(c.txDone)
|
||||
defer func() {
|
||||
close(c.txDone)
|
||||
c.close()
|
||||
}()
|
||||
|
||||
// disable write timeout
|
||||
if c.connectTimeout != 0 {
|
||||
|
@ -724,8 +634,8 @@ func (c *Conn) connWriter() {
|
|||
var err error
|
||||
for {
|
||||
if err != nil {
|
||||
debug.Log(1, "connWriter error: %v", err)
|
||||
c.connErr <- err
|
||||
debug.Log(1, "connWriter terminal error: %v", err)
|
||||
c.txErr = err
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -750,11 +660,14 @@ func (c *Conn) connWriter() {
|
|||
// possibly drained, then reset.)
|
||||
|
||||
// connection complete
|
||||
case <-c.done:
|
||||
// send close
|
||||
case <-c.rxtxExit:
|
||||
// send close performative. note that the spec says we
|
||||
// SHOULD wait for the ack but we don't HAVE to, in order
|
||||
// to be resilient to bad actors etc. so we just send
|
||||
// the close performative and exit.
|
||||
cls := &frames.PerformClose{}
|
||||
debug.Log(1, "TX (connWriter): %s", cls)
|
||||
_ = c.writeFrame(frames.Frame{
|
||||
c.txErr = c.writeFrame(frames.Frame{
|
||||
Type: frames.TypeAMQP,
|
||||
Body: cls,
|
||||
})
|
||||
|
@ -810,7 +723,7 @@ func (c *Conn) sendFrame(fr frames.Frame) error {
|
|||
case c.txFrame <- fr:
|
||||
return nil
|
||||
case <-c.done:
|
||||
return c.err()
|
||||
return c.doneErr
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -876,55 +789,77 @@ func (c *Conn) exchangeProtoHeader(pID protoID) (stateFunc, error) {
|
|||
|
||||
// readProtoHeader reads a protocol header packet from c.rxProto.
|
||||
func (c *Conn) readProtoHeader() (protoHeader, error) {
|
||||
var deadline <-chan time.Time
|
||||
if c.connectTimeout != 0 {
|
||||
deadline = time.After(c.connectTimeout)
|
||||
const protoHeaderSize = 8
|
||||
|
||||
// only read from the network once our buffer has been exhausted.
|
||||
// TODO: this preserves existing behavior as some tests rely on this
|
||||
// implementation detail (it lets you replay a stream of bytes). we
|
||||
// might want to consider removing this and fixing the tests as the
|
||||
// protocol doesn't actually work this way.
|
||||
if c.rxBuf.Len() == 0 {
|
||||
for {
|
||||
if c.connectTimeout != 0 {
|
||||
_ = c.net.SetReadDeadline(time.Now().Add(c.connectTimeout))
|
||||
}
|
||||
|
||||
err := c.rxBuf.ReadFromOnce(c.net)
|
||||
if err != nil {
|
||||
return protoHeader{}, err
|
||||
}
|
||||
|
||||
// read more if buf doesn't contain enough to parse the header
|
||||
if c.rxBuf.Len() >= protoHeaderSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// reset outside the loop
|
||||
if c.connectTimeout != 0 {
|
||||
_ = c.net.SetReadDeadline(time.Time{})
|
||||
}
|
||||
}
|
||||
var p protoHeader
|
||||
select {
|
||||
case p = <-c.rxProto:
|
||||
return p, nil
|
||||
case err := <-c.connErr:
|
||||
return p, err
|
||||
case fr := <-c.rxFrame:
|
||||
return p, fmt.Errorf("readProtoHeader: unexpected frame %#v", fr)
|
||||
case <-deadline:
|
||||
return p, errors.New("amqp: timeout waiting for response")
|
||||
|
||||
buf, ok := c.rxBuf.Next(protoHeaderSize)
|
||||
if !ok {
|
||||
return protoHeader{}, errors.New("invalid protoHeader")
|
||||
}
|
||||
// bounds check hint to compiler; see golang.org/issue/14808
|
||||
_ = buf[protoHeaderSize-1]
|
||||
|
||||
if !bytes.Equal(buf[:4], []byte{'A', 'M', 'Q', 'P'}) {
|
||||
return protoHeader{}, fmt.Errorf("unexpected protocol %q", buf[:4])
|
||||
}
|
||||
|
||||
p := protoHeader{
|
||||
ProtoID: protoID(buf[4]),
|
||||
Major: buf[5],
|
||||
Minor: buf[6],
|
||||
Revision: buf[7],
|
||||
}
|
||||
|
||||
if p.Major != 1 || p.Minor != 0 || p.Revision != 0 {
|
||||
return protoHeader{}, fmt.Errorf("unexpected protocol version %d.%d.%d", p.Major, p.Minor, p.Revision)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// startTLS wraps the conn with TLS and returns to Client.negotiateProto
|
||||
func (c *Conn) startTLS() (stateFunc, error) {
|
||||
c.initTLSConfig()
|
||||
|
||||
// buffered so connReaderRun won't block
|
||||
done := make(chan error, 1)
|
||||
_ = c.net.SetReadDeadline(time.Time{}) // clear timeout
|
||||
|
||||
// this function will be executed by connReader
|
||||
c.connReaderRun <- func() {
|
||||
defer close(done)
|
||||
_ = c.net.SetReadDeadline(time.Time{}) // clear timeout
|
||||
|
||||
// wrap existing net.Conn and perform TLS handshake
|
||||
tlsConn := tls.Client(c.net, c.tlsConfig)
|
||||
if c.connectTimeout != 0 {
|
||||
_ = tlsConn.SetWriteDeadline(time.Now().Add(c.connectTimeout))
|
||||
}
|
||||
done <- tlsConn.Handshake()
|
||||
// TODO: return?
|
||||
|
||||
// swap net.Conn
|
||||
c.net = tlsConn
|
||||
c.tlsComplete = true
|
||||
}
|
||||
|
||||
// set deadline to interrupt connReader
|
||||
_ = c.net.SetReadDeadline(time.Time{}.Add(1))
|
||||
|
||||
if err := <-done; err != nil {
|
||||
// wrap existing net.Conn and perform TLS handshake
|
||||
tlsConn := tls.Client(c.net, c.tlsConfig)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// swap net.Conn
|
||||
c.net = tlsConn
|
||||
c.tlsComplete = true
|
||||
|
||||
// go to next protocol
|
||||
return c.negotiateProto, nil
|
||||
}
|
||||
|
@ -951,7 +886,7 @@ func (c *Conn) openAMQP() (stateFunc, error) {
|
|||
}
|
||||
|
||||
// get the response
|
||||
fr, err := c.readFrame()
|
||||
fr, err := c.readSingleFrame()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -981,7 +916,7 @@ func (c *Conn) openAMQP() (stateFunc, error) {
|
|||
// mechanism specified by the server
|
||||
func (c *Conn) negotiateSASL() (stateFunc, error) {
|
||||
// read mechanisms frame
|
||||
fr, err := c.readFrame()
|
||||
fr, err := c.readSingleFrame()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1010,7 +945,7 @@ func (c *Conn) negotiateSASL() (stateFunc, error) {
|
|||
// used externally by SASL only.
|
||||
func (c *Conn) saslOutcome() (stateFunc, error) {
|
||||
// read outcome frame
|
||||
fr, err := c.readFrame()
|
||||
fr, err := c.readSingleFrame()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1030,27 +965,21 @@ func (c *Conn) saslOutcome() (stateFunc, error) {
|
|||
return c.negotiateProto, nil
|
||||
}
|
||||
|
||||
// readFrame is used during connection establishment to read a single frame.
|
||||
// readSingleFrame is used during connection establishment to read a single frame.
|
||||
//
|
||||
// After setup, conn.mux handles incoming frames.
|
||||
// used externally by SASL only.
|
||||
func (c *Conn) readFrame() (frames.Frame, error) {
|
||||
var deadline <-chan time.Time
|
||||
// After setup, conn.connReader handles incoming frames.
|
||||
func (c *Conn) readSingleFrame() (frames.Frame, error) {
|
||||
if c.connectTimeout != 0 {
|
||||
deadline = time.After(c.connectTimeout)
|
||||
_ = c.net.SetDeadline(time.Now().Add(c.connectTimeout))
|
||||
defer func() { _ = c.net.SetDeadline(time.Time{}) }()
|
||||
}
|
||||
|
||||
var fr frames.Frame
|
||||
select {
|
||||
case fr = <-c.rxFrame:
|
||||
return fr, nil
|
||||
case err := <-c.connErr:
|
||||
return fr, err
|
||||
case p := <-c.rxProto:
|
||||
return fr, fmt.Errorf("unexpected protocol header %#v", p)
|
||||
case <-deadline:
|
||||
return fr, errors.New("amqp: timeout waiting for response")
|
||||
fr, err := c.readFrame()
|
||||
if err != nil {
|
||||
return frames.Frame{}, err
|
||||
}
|
||||
|
||||
return fr, nil
|
||||
}
|
||||
|
||||
type protoHeader struct {
|
||||
|
|
103
conn_test.go
103
conn_test.go
|
@ -316,30 +316,44 @@ func TestClose(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestServerSideClose(t *testing.T) {
|
||||
netConn := mocks.NewNetConn(senderFrameHandlerNoUnhandled(SenderSettleModeUnsettled))
|
||||
closeReceived := make(chan struct{})
|
||||
responder := func(req frames.FrameBody) ([]byte, error) {
|
||||
switch req.(type) {
|
||||
case *mocks.AMQPProto:
|
||||
return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil
|
||||
case *frames.PerformOpen:
|
||||
return mocks.PerformOpen("container")
|
||||
case *frames.PerformClose:
|
||||
close(closeReceived)
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
}
|
||||
netConn := mocks.NewNetConn(responder)
|
||||
conn, err := newConn(netConn, nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.start())
|
||||
fr, err := mocks.PerformClose(nil)
|
||||
require.NoError(t, err)
|
||||
netConn.SendFrame(fr)
|
||||
<-closeReceived
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// with error
|
||||
netConn = mocks.NewNetConn(senderFrameHandlerNoUnhandled(SenderSettleModeUnsettled))
|
||||
closeReceived = make(chan struct{})
|
||||
netConn = mocks.NewNetConn(responder)
|
||||
conn, err = newConn(netConn, nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.start())
|
||||
fr, err = mocks.PerformClose(&Error{Condition: "Close", Description: "mock server error"})
|
||||
require.NoError(t, err)
|
||||
netConn.SendFrame(fr)
|
||||
// wait a bit for connReader to read from the mock
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
<-closeReceived
|
||||
err = conn.Close()
|
||||
var connErr *ConnError
|
||||
if !errors.As(err, &connErr) {
|
||||
t.Fatalf("unexpected error type %T", err)
|
||||
}
|
||||
require.ErrorAs(t, err, &connErr)
|
||||
require.Equal(t, "*Error{Condition: Close, Description: mock server error, Info: map[]}", connErr.Error())
|
||||
}
|
||||
|
||||
|
@ -358,6 +372,8 @@ func TestKeepAlives(t *testing.T) {
|
|||
case *mocks.KeepAlive:
|
||||
keepAlives <- struct{}{}
|
||||
return nil, nil
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -380,6 +396,48 @@ func TestKeepAlives(t *testing.T) {
|
|||
require.NoError(t, conn.Close())
|
||||
}
|
||||
|
||||
func TestKeepAlivesIdleTimeout(t *testing.T) {
|
||||
responder := func(req frames.FrameBody) ([]byte, error) {
|
||||
switch req.(type) {
|
||||
case *mocks.AMQPProto:
|
||||
return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil
|
||||
case *frames.PerformOpen:
|
||||
return mocks.EncodeFrame(mocks.FrameAMQP, 0, &frames.PerformOpen{ContainerID: "container", IdleTimeout: time.Minute})
|
||||
case *mocks.KeepAlive:
|
||||
return nil, nil
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
}
|
||||
|
||||
const idleTimeout = 100 * time.Millisecond
|
||||
|
||||
netConn := mocks.NewNetConn(responder)
|
||||
conn, err := newConn(netConn, &ConnOptions{
|
||||
IdleTimeout: idleTimeout,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.start())
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(idleTimeout / 2):
|
||||
netConn.SendKeepAlive()
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(2 * idleTimeout)
|
||||
require.NoError(t, conn.Close())
|
||||
}
|
||||
|
||||
func TestConnReaderError(t *testing.T) {
|
||||
netConn := mocks.NewNetConn(senderFrameHandlerNoUnhandled(SenderSettleModeUnsettled))
|
||||
conn, err := newConn(netConn, nil)
|
||||
|
@ -415,6 +473,29 @@ func TestConnWriterError(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnWithZeroByteReads(t *testing.T) {
|
||||
responder := func(req frames.FrameBody) ([]byte, error) {
|
||||
switch req.(type) {
|
||||
case *mocks.AMQPProto:
|
||||
return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil
|
||||
case *frames.PerformOpen:
|
||||
return mocks.PerformOpen("container")
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
}
|
||||
|
||||
netConn := mocks.NewNetConn(responder)
|
||||
netConn.SendFrame([]byte{})
|
||||
|
||||
conn, err := newConn(netConn, nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.start())
|
||||
require.NoError(t, conn.Close())
|
||||
}
|
||||
|
||||
type mockDialer struct {
|
||||
resp func(frames.FrameBody) ([]byte, error)
|
||||
}
|
||||
|
@ -465,6 +546,8 @@ func TestClientClose(t *testing.T) {
|
|||
return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil
|
||||
case *frames.PerformOpen:
|
||||
return mocks.PerformOpen("container")
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -551,6 +634,8 @@ func TestClientNewSession(t *testing.T) {
|
|||
return nil, fmt.Errorf("unexpected incoming window %d", tt.OutgoingWindow)
|
||||
}
|
||||
return mocks.PerformBegin(channelNum)
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -593,6 +678,8 @@ func TestClientMultipleSessions(t *testing.T) {
|
|||
b, err := mocks.PerformBegin(channelNum)
|
||||
channelNum++
|
||||
return b, err
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -636,6 +723,8 @@ func TestClientTooManySessions(t *testing.T) {
|
|||
b, err := mocks.PerformBegin(channelNum)
|
||||
channelNum++
|
||||
return b, err
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
|
|
@ -13,11 +13,13 @@ import (
|
|||
"github.com/Azure/go-amqp/internal/frames"
|
||||
"github.com/Azure/go-amqp/internal/testconn"
|
||||
"github.com/fortytw2/leaktest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func fuzzConn(data []byte) int {
|
||||
// Receive
|
||||
client, err := NewConn(testconn.New(data), &ConnOptions{
|
||||
Timeout: 10 * time.Millisecond,
|
||||
IdleTimeout: 10 * time.Millisecond,
|
||||
SASLType: SASLTypePlain("listen", "3aCXZYFcuZA89xe6lZkfYJvOPnTGipA3ap7NvPruBhI="),
|
||||
})
|
||||
|
@ -464,8 +466,9 @@ func TestFuzzConnCrashers(t *testing.T) {
|
|||
|
||||
for i, tt := range tests {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
fuzzConn([]byte(tt))
|
||||
end := leaktest.Check(t)
|
||||
require.Zero(t, fuzzConn([]byte(tt)))
|
||||
end()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -759,15 +759,14 @@ func TestMultipleSessionsOpenClose(t *testing.T) {
|
|||
if localBrokerAddr == "" {
|
||||
t.Skip()
|
||||
}
|
||||
// TODO: connReader and connWriter goroutines will leak
|
||||
//checkLeaks := leaktest.Check(t)
|
||||
|
||||
checkLeaks := leaktest.Check(t)
|
||||
|
||||
// Create client
|
||||
client, err := amqp.Dial(localBrokerAddr, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
sessions := [10]*amqp.Session{}
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -790,22 +789,24 @@ func TestMultipleSessionsOpenClose(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
//checkLeaks()
|
||||
|
||||
client.Close()
|
||||
checkLeaks()
|
||||
}
|
||||
|
||||
func TestConcurrentSessionsOpenClose(t *testing.T) {
|
||||
if localBrokerAddr == "" {
|
||||
t.Skip()
|
||||
}
|
||||
// TODO: connReader and connWriter goroutines will leak
|
||||
//checkLeaks := leaktest.Check(t)
|
||||
|
||||
checkLeaks := leaktest.Check(t)
|
||||
|
||||
// Create client
|
||||
client, err := amqp.Dial(localBrokerAddr, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
|
@ -827,7 +828,9 @@ func TestConcurrentSessionsOpenClose(t *testing.T) {
|
|||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
//checkLeaks()
|
||||
|
||||
client.Close()
|
||||
checkLeaks()
|
||||
}
|
||||
|
||||
func repeatStrings(count int, strs ...string) []string {
|
||||
|
|
|
@ -27,6 +27,7 @@ func NewNetConn(resp func(frames.FrameBody) ([]byte, error)) *NetConn {
|
|||
// writes from blocking shutdown. the size was arbitrarily picked.
|
||||
readData: make(chan []byte, 10),
|
||||
readClose: make(chan struct{}),
|
||||
readDL: newNopTimer(), // default, no deadline
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -47,7 +48,7 @@ type NetConn struct {
|
|||
WriteErr chan error
|
||||
|
||||
resp func(frames.FrameBody) ([]byte, error)
|
||||
readDL *time.Timer
|
||||
readDL readTimer
|
||||
readData chan []byte
|
||||
readClose chan struct{}
|
||||
closed bool
|
||||
|
@ -90,15 +91,15 @@ func (n *NetConn) SendMultiFrameTransfer(remoteChannel uint16, linkHandle, deliv
|
|||
func (n *NetConn) Read(b []byte) (int, error) {
|
||||
select {
|
||||
case <-n.readClose:
|
||||
return 0, errors.New("mock connection was closed")
|
||||
return 0, net.ErrClosed
|
||||
default:
|
||||
// not closed yet
|
||||
}
|
||||
|
||||
select {
|
||||
case <-n.readClose:
|
||||
return 0, errors.New("mock connection was closed")
|
||||
case <-n.readDL.C:
|
||||
return 0, net.ErrClosed
|
||||
case <-n.readDL.C():
|
||||
return 0, errors.New("mock connection read deadline exceeded")
|
||||
case rd := <-n.readData:
|
||||
return copy(b, rd), nil
|
||||
|
@ -116,7 +117,7 @@ func (n *NetConn) Read(b []byte) (int, error) {
|
|||
func (n *NetConn) Write(b []byte) (int, error) {
|
||||
select {
|
||||
case <-n.readClose:
|
||||
return 0, errors.New("mock connection was closed")
|
||||
return 0, net.ErrClosed
|
||||
default:
|
||||
// not closed yet
|
||||
}
|
||||
|
@ -142,7 +143,7 @@ func (n *NetConn) Write(b []byte) (int, error) {
|
|||
return len(b), nil
|
||||
}
|
||||
|
||||
// Close is called by conn.close when conn.mux unwinds.
|
||||
// Close is called by conn.close.
|
||||
func (n *NetConn) Close() error {
|
||||
if n.closed {
|
||||
return errors.New("double close")
|
||||
|
@ -175,9 +176,9 @@ func (n *NetConn) SetReadDeadline(t time.Time) error {
|
|||
// called by conn.connReader before calling Read
|
||||
// stop the last timer if available
|
||||
if n.readDL != nil && !n.readDL.Stop() {
|
||||
<-n.readDL.C
|
||||
<-n.readDL.C()
|
||||
}
|
||||
n.readDL = time.NewTimer(time.Until(t))
|
||||
n.readDL = timer{t: time.NewTimer(time.Until(t))}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -437,3 +438,37 @@ func encodeMultiFrameTransfer(remoteChannel uint16, linkHandle, deliveryID uint3
|
|||
}
|
||||
return frameData, nil
|
||||
}
|
||||
|
||||
type readTimer interface {
|
||||
C() <-chan time.Time
|
||||
Stop() bool
|
||||
}
|
||||
|
||||
func newNopTimer() nopTimer {
|
||||
return nopTimer{t: make(chan time.Time)}
|
||||
}
|
||||
|
||||
type nopTimer struct {
|
||||
t chan time.Time
|
||||
}
|
||||
|
||||
func (n nopTimer) C() <-chan time.Time {
|
||||
return n.t
|
||||
}
|
||||
|
||||
func (n nopTimer) Stop() bool {
|
||||
close(n.t)
|
||||
return true
|
||||
}
|
||||
|
||||
type timer struct {
|
||||
t *time.Timer
|
||||
}
|
||||
|
||||
func (t timer) C() <-chan time.Time {
|
||||
return t.t.C
|
||||
}
|
||||
|
||||
func (t timer) Stop() bool {
|
||||
return t.t.Stop()
|
||||
}
|
||||
|
|
|
@ -35,7 +35,12 @@ func (c *Conn) Read(b []byte) (int, error) {
|
|||
}
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
n := copy(b, c.data[0])
|
||||
c.data = c.data[1:]
|
||||
// only move on to the next chunk if this one was entirely consumed
|
||||
if n == len(c.data[0]) {
|
||||
c.data = c.data[1:]
|
||||
} else {
|
||||
c.data[0] = c.data[0][n:]
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
|
@ -63,7 +68,7 @@ func (c *Conn) RemoteAddr() net.Addr {
|
|||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
return c.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
|
@ -84,5 +89,5 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
|
|||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
return errors.New("testconn.SetWriteDeadline NYI")
|
||||
}
|
||||
|
|
|
@ -52,6 +52,8 @@ func TestSenderMethodsNoSend(t *testing.T) {
|
|||
return mocks.SenderAttach(0, tt.Name, 0, SenderSettleModeUnsettled)
|
||||
case *frames.PerformDetach:
|
||||
return mocks.PerformDetach(0, 0, nil)
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -222,6 +224,8 @@ func TestSenderAttachError(t *testing.T) {
|
|||
// we don't need to respond to the ack
|
||||
detachAck <- true
|
||||
return nil, nil
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -466,6 +470,8 @@ func TestSenderSendRejectedNoDetach(t *testing.T) {
|
|||
return mocks.PerformDisposition(encoding.RoleReceiver, 0, *tt.DeliveryID, nil, &encoding.StateAccepted{})
|
||||
case *frames.PerformDetach:
|
||||
return mocks.PerformDetach(0, 0, nil)
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -602,6 +608,8 @@ func TestSenderSendMsgTooBig(t *testing.T) {
|
|||
return mocks.PerformDisposition(encoding.RoleReceiver, 0, *tt.DeliveryID, nil, &encoding.StateAccepted{})
|
||||
case *frames.PerformDetach:
|
||||
return mocks.PerformDetach(0, 0, nil)
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -710,6 +718,8 @@ func TestSenderSendMultiTransfer(t *testing.T) {
|
|||
return mocks.PerformDisposition(encoding.RoleReceiver, 0, deliveryID, nil, &encoding.StateAccepted{})
|
||||
case *frames.PerformDetach:
|
||||
return mocks.PerformDetach(0, 0, nil)
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -730,7 +740,7 @@ func TestSenderSendMultiTransfer(t *testing.T) {
|
|||
|
||||
sendInitialFlowFrame(t, netConn, 0, 100)
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 100000*time.Millisecond)
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
payload := make([]byte, maxReceiverFrameSize*4)
|
||||
for i := 0; i < maxReceiverFrameSize*4; i++ {
|
||||
payload[i] = byte(i % 256)
|
||||
|
|
12
session.go
12
session.go
|
@ -42,9 +42,9 @@ type SessionOptions struct {
|
|||
// A session multiplexes Receivers.
|
||||
type Session struct {
|
||||
channel uint16 // session's local channel
|
||||
remoteChannel uint16 // session's remote channel, owned by conn.mux
|
||||
remoteChannel uint16 // session's remote channel, owned by conn.connReader
|
||||
conn *Conn // underlying conn
|
||||
rx chan frames.Frame // frames destined for this session are sent on this chan by conn.mux
|
||||
rx chan frames.Frame // frames destined for this session are sent on this chan by conn.connReader
|
||||
tx chan frames.FrameBody // non-transfer frames to be sent; session must track disposition
|
||||
txTransfer chan *frames.PerformTransfer // transfer frames to be sent; session must track disposition
|
||||
|
||||
|
@ -139,7 +139,7 @@ func (s *Session) begin(ctx context.Context) error {
|
|||
}()
|
||||
return ctx.Err()
|
||||
case <-s.conn.done:
|
||||
return s.conn.err()
|
||||
return s.conn.doneErr
|
||||
case fr = <-s.rx:
|
||||
// received ack that session was created
|
||||
}
|
||||
|
@ -148,7 +148,7 @@ func (s *Session) begin(ctx context.Context) error {
|
|||
begin, ok := fr.Body.(*frames.PerformBegin)
|
||||
if !ok {
|
||||
// this codepath is hard to hit (impossible?). if the response isn't a PerformBegin and we've not
|
||||
// yet seen the remote channel number, the default clause in conn.mux will protect us from that.
|
||||
// yet seen the remote channel number, the default clause in conn.connReader will protect us from that.
|
||||
// if we have seen the remote channel number then it's likely the session.mux for that channel will
|
||||
// either swallow the frame or blow up in some other way, both causing this call to hang.
|
||||
// deallocate session on error. we can't call
|
||||
|
@ -279,7 +279,7 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) {
|
|||
select {
|
||||
// conn has completed, exit
|
||||
case <-s.conn.done:
|
||||
s.err = s.conn.err()
|
||||
s.err = s.conn.doneErr
|
||||
return
|
||||
|
||||
// session is being closed by user
|
||||
|
@ -298,7 +298,7 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) {
|
|||
break EndLoop
|
||||
}
|
||||
case <-s.conn.done:
|
||||
s.err = s.conn.err()
|
||||
s.err = s.conn.doneErr
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,6 +35,8 @@ func TestSessionClose(t *testing.T) {
|
|||
}
|
||||
channelNum--
|
||||
return b, nil
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -48,8 +50,9 @@ func TestSessionClose(t *testing.T) {
|
|||
session, err := client.NewSession(ctx, nil)
|
||||
cancel()
|
||||
require.NoErrorf(t, err, "iteration %d", i)
|
||||
require.Equalf(t, channelNum-1, session.channel, "iteration %d", i)
|
||||
ctx, cancel = context.WithTimeout(context.Background(), time.Second)
|
||||
require.Equalf(t, uint16(0), session.channel, "iteration %d", i)
|
||||
require.Equalf(t, channelNum-1, session.remoteChannel, "iteration %d", i)
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
|
||||
err = session.Close(ctx)
|
||||
cancel()
|
||||
require.NoErrorf(t, err, "iteration %d", i)
|
||||
|
@ -68,6 +71,8 @@ func TestSessionServerClose(t *testing.T) {
|
|||
return mocks.PerformBegin(0)
|
||||
case *frames.PerformEnd:
|
||||
return nil, nil // swallow
|
||||
case *frames.PerformClose:
|
||||
return nil, nil // swallow
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -112,6 +117,8 @@ func TestSessionCloseTimeout(t *testing.T) {
|
|||
// sleep to trigger session close timeout
|
||||
time.Sleep(1 * time.Second)
|
||||
return mocks.PerformEnd(0, nil)
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -199,6 +206,8 @@ func TestSessionNewReceiverBatchingOneCredit(t *testing.T) {
|
|||
return mocks.ReceiverAttach(0, tt.Name, 0, ReceiverSettleModeFirst, nil)
|
||||
case *frames.PerformFlow:
|
||||
return nil, nil
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -242,6 +251,8 @@ func TestSessionNewReceiverBatchingEnabled(t *testing.T) {
|
|||
return mocks.ReceiverAttach(0, tt.Name, 0, ReceiverSettleModeFirst, nil)
|
||||
case *frames.PerformFlow:
|
||||
return nil, nil
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -284,6 +295,8 @@ func TestSessionNewReceiverMismatchedLinkName(t *testing.T) {
|
|||
return mocks.PerformEnd(0, nil)
|
||||
case *frames.PerformAttach:
|
||||
return mocks.ReceiverAttach(0, "wrong_name", 0, ReceiverSettleModeFirst, nil)
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -351,6 +364,8 @@ func TestSessionNewSenderMismatchedLinkName(t *testing.T) {
|
|||
return mocks.PerformEnd(0, nil)
|
||||
case *frames.PerformAttach:
|
||||
return mocks.SenderAttach(0, "wrong_name", 0, SenderSettleModeUnsettled)
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -516,6 +531,8 @@ func TestSessionFlowFrameWithEcho(t *testing.T) {
|
|||
return nil, nil
|
||||
case *frames.PerformEnd:
|
||||
return mocks.PerformEnd(0, nil)
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
@ -562,6 +579,8 @@ func TestSessionInvalidAttachDeadlock(t *testing.T) {
|
|||
case *frames.PerformAttach:
|
||||
enqueueFrames(tt.Name)
|
||||
return nil, nil
|
||||
case *frames.PerformClose:
|
||||
return mocks.PerformClose(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled frame %T", req)
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче