diff --git a/CHANGELOG.md b/CHANGELOG.md index b0b6738..c245ffc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Release History +## 0.18.1 (unreleased) + +### Other Changes + +* The connection mux goroutine has been removed, eliminating a potential source of deadlocks. + ## 0.18.0 (2022-12-06) ### Features Added diff --git a/conn.go b/conn.go index 5859df0..107c224 100644 --- a/conn.go +++ b/conn.go @@ -149,14 +149,12 @@ type Conn struct { peerMaxFrameSize uint32 // maximum frame size peer will accept // conn state - doneErrMu sync.Mutex // mux holds doneErr from start until shutdown completes; operations are sequential before mux is started - doneErr error // error to be returned to client - done chan struct{} // indicates the connection is done + done chan struct{} // indicates the connection has terminated + doneErr error // contains the error state returned from Close(); DO NOT TOUCH outside of conn.go until Done has been closed! - // mux - connErr chan error // connReader/Writer notifications of an error - closeMux chan struct{} // indicates that the mux should stop - closeMuxOnce sync.Once + // connReader and connWriter management + rxtxExit chan struct{} // signals connReader and connWriter to exit + closeOnce sync.Once // ensures that close() is only called once // session tracking channels *bitmap.Bitmap @@ -164,15 +162,15 @@ type Conn struct { sessionsByChannelMu sync.RWMutex // connReader - rxProto chan protoHeader // protoHeaders received by connReader - rxFrame chan frames.Frame // AMQP frames received by connReader - rxDone chan struct{} - connReaderRun chan func() // functions to be run by conn reader (set deadline on conn to run) + rxBuf buffer.Buffer // incoming bytes buffer + rxDone chan struct{} // closed when connReader exits + rxErr error // contains last error reading from c.net; DO NOT TOUCH outside of connReader until rxDone has been closed! // connWriter txFrame chan frames.Frame // AMQP frames to be sent by connWriter txBuf buffer.Buffer // buffer for marshaling frames before transmitting - txDone chan struct{} + txDone chan struct{} // closed when connWriter exits + txErr error // contains last error writing to c.net; DO NOT TOUCH outside of connWriter until txDone has been closed! } // used to abstract the underlying dialer for testing purposes @@ -256,12 +254,8 @@ func newConn(netConn net.Conn, opts *ConnOptions) (*Conn, error) { idleTimeout: defaultIdleTimeout, containerID: shared.RandString(40), done: make(chan struct{}), - connErr: make(chan error, 2), // buffered to ensure connReader/Writer won't leak - closeMux: make(chan struct{}), - rxProto: make(chan protoHeader), - rxFrame: make(chan frames.Frame), + rxtxExit: make(chan struct{}), rxDone: make(chan struct{}), - connReaderRun: make(chan func(), 1), // buffered to allow queueing function before interrupt txFrame: make(chan frames.Frame), txDone: make(chan struct{}), sessionsByChannel: map[uint16]*Session{}, @@ -311,7 +305,6 @@ func newConn(netConn net.Conn, opts *ConnOptions) (*Conn, error) { if opts.dialer != nil { c.dialer = opts.dialer } - return c, nil } @@ -327,12 +320,9 @@ func (c *Conn) initTLSConfig() { } } -// Start establishes the connection and begins multiplexing network IO. +// start establishes the connection and begins multiplexing network IO. // It is an error to call Start() on a connection that's been closed. func (c *Conn) start() error { - // start reader - go c.connReader() - // run connection establishment state machine for state := c.negotiateProto; state != nil; { var err error @@ -340,6 +330,7 @@ func (c *Conn) start() error { // check if err occurred if err != nil { close(c.txDone) // close here since connWriter hasn't been started yet + close(c.rxDone) _ = c.Close() return err } @@ -349,72 +340,61 @@ func (c *Conn) start() error { // this is because our peer can tell us the max channels they support. c.channels = bitmap.New(uint32(c.channelMax)) - // start multiplexor and writer - go c.mux() go c.connWriter() + go c.connReader() return nil } // Close closes the connection. func (c *Conn) Close() error { - c.closeMuxOnce.Do(func() { close(c.closeMux) }) - err := c.err() + c.close() var connErr *ConnError - if errors.As(err, &connErr) && connErr.RemoteErr == nil && connErr.inner == nil { + if errors.As(c.doneErr, &connErr) && connErr.RemoteErr == nil && connErr.inner == nil { // an empty ConnectionError means the connection was closed by the caller - // or as requested by the peer and no error was provided in the close frame. return nil } - return err + + // there was an error during shut-down or connReader/connWriter + // experienced a terminal error + return c.doneErr } -// close should only be called by conn.mux. +// close is called once, either from Close() or when connReader/connWriter exits func (c *Conn) close() { - close(c.done) // notify goroutines and blocked functions to exit + c.closeOnce.Do(func() { + defer close(c.done) - // wait for writing to stop, allows it to send the final close frame - <-c.txDone + close(c.rxtxExit) - // reading from connErr in mux can race with closeMux, causing - // a pending conn read/write error to be lost. now that the - // mux has exited, drain any pending error. - select { - case err := <-c.connErr: - c.doneErr = err - default: - // no pending read/write error - } + // wait for writing to stop, allows it to send the final close frame + <-c.txDone - err := c.net.Close() - switch { - // conn.err already set - // TODO: err info is lost, log it? - case c.doneErr != nil: + closeErr := c.net.Close() - // conn.err not set and c.net.Close() returned a non-nil error - case err != nil: - c.doneErr = err + // check rxDone after closing net, otherwise may block + // for up to c.idleTimeout + <-c.rxDone - // no errors - default: - } + if errors.Is(c.rxErr, net.ErrClosed) { + // this is the expected error when the connection is closed, swallow it + c.rxErr = nil + } - // check rxDone after closing net, otherwise may block - // for up to c.idleTimeout - <-c.rxDone -} - -// Err returns the connection's error state after it's been closed. -// Calling this on an open connection will block until the connection is closed. -func (c *Conn) err() error { - c.doneErrMu.Lock() - defer c.doneErrMu.Unlock() - var amqpErr *Error - if errors.As(c.doneErr, &amqpErr) { - return &ConnError{RemoteErr: amqpErr} - } - return &ConnError{inner: c.doneErr} + if c.txErr == nil && c.rxErr == nil && closeErr == nil { + // if there are no errors, it means user initiated close() and we shut down cleanly + c.doneErr = &ConnError{} + } else if amqpErr, ok := c.rxErr.(*Error); ok { + // we experienced a peer-initiated close that contained an Error. return it + c.doneErr = &ConnError{RemoteErr: amqpErr} + } else if c.txErr != nil { + c.doneErr = &ConnError{inner: c.txErr} + } else if c.rxErr != nil { + c.doneErr = &ConnError{inner: c.rxErr} + } else { + c.doneErr = &ConnError{inner: closeErr} + } + }) } func (c *Conn) NewSession(ctx context.Context, opts *SessionOptions) (*Session, error) { @@ -454,251 +434,181 @@ func (c *Conn) deleteSession(s *Session) { c.channels.Remove(uint32(s.channel)) } -// mux is started in it's own goroutine after initial connection establishment. -// It handles muxing of sessions, keepalives, and connection errors. -func (c *Conn) mux() { - var ( - // map channels to sessions - sessionsByRemoteChannel = make(map[uint16]*Session) - ) - - // hold the errMu lock until error or done - c.doneErrMu.Lock() - defer c.doneErrMu.Unlock() - defer c.close() // defer order is important. c.errMu unlock indicates that connection is finally complete +// connReader reads from the net.Conn, decodes frames, and either handles +// them here as appropriate or sends them to the session.rx channel. +func (c *Conn) connReader() { + defer func() { + close(c.rxDone) + c.close() + }() + var sessionsByRemoteChannel = make(map[uint16]*Session) + var err error for { - // check if last loop returned an error - if c.doneErr != nil { + if err != nil { + debug.Log(1, "connReader terminal error: %v", err) + c.rxErr = err return } - select { - // error from connReader - case c.doneErr = <-c.connErr: + var fr frames.Frame + fr, err = c.readFrame() + if err != nil { + continue + } - // new frame from connReader - case fr := <-c.rxFrame: - var ( - session *Session - ok bool - ) + var ( + session *Session + ok bool + ) - switch body := fr.Body.(type) { - // Server initiated close. - case *frames.PerformClose: - if body.Error != nil { - c.doneErr = body.Error - } + switch body := fr.Body.(type) { + // Server initiated close. + case *frames.PerformClose: + // connWriter will send the close performative ack on its way out. + // it's a SHOULD though, not a MUST. + debug.Log(3, "RX (connReader): %s", body) + if body.Error == nil { return - - // RemoteChannel should be used when frame is Begin - case *frames.PerformBegin: - if body.RemoteChannel == nil { - // since we only support remotely-initiated sessions, this is an error - // TODO: it would be ideal to not have this kill the connection - c.doneErr = fmt.Errorf("%T: nil RemoteChannel", fr.Body) - break - } - c.sessionsByChannelMu.RLock() - session, ok = c.sessionsByChannel[*body.RemoteChannel] - c.sessionsByChannelMu.RUnlock() - if !ok { - c.doneErr = fmt.Errorf("unexpected remote channel number %d", *body.RemoteChannel) - break - } - - session.remoteChannel = fr.Channel - sessionsByRemoteChannel[fr.Channel] = session - - case *frames.PerformEnd: - session, ok = sessionsByRemoteChannel[fr.Channel] - if !ok { - c.doneErr = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel (PerformEnd)", fr.Body, fr.Channel) - break - } - // we MUST remove the remote channel from our map as soon as we receive - // the ack (i.e. before passing it on to the session mux) on the session - // ending since the numbers are recycled. - delete(sessionsByRemoteChannel, fr.Channel) - - default: - // pass on performative to the correct session - session, ok = sessionsByRemoteChannel[fr.Channel] - if !ok { - c.doneErr = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel", fr.Body, fr.Channel) - } } + err = body.Error + continue + // RemoteChannel should be used when frame is Begin + case *frames.PerformBegin: + if body.RemoteChannel == nil { + // since we only support remotely-initiated sessions, this is an error + // TODO: it would be ideal to not have this kill the connection + err = fmt.Errorf("%T: nil RemoteChannel", fr.Body) + continue + } + c.sessionsByChannelMu.RLock() + session, ok = c.sessionsByChannel[*body.RemoteChannel] + c.sessionsByChannelMu.RUnlock() if !ok { + err = fmt.Errorf("unexpected remote channel number %d", *body.RemoteChannel) continue } - select { - case session.rx <- fr: - case <-c.closeMux: - return - } + session.remoteChannel = fr.Channel + sessionsByRemoteChannel[fr.Channel] = session - // connection is complete - case <-c.closeMux: + case *frames.PerformEnd: + session, ok = sessionsByRemoteChannel[fr.Channel] + if !ok { + err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel (PerformEnd)", fr.Body, fr.Channel) + continue + } + // we MUST remove the remote channel from our map as soon as we receive + // the ack (i.e. before passing it on to the session mux) on the session + // ending since the numbers are recycled. + delete(sessionsByRemoteChannel, fr.Channel) + + default: + // pass on performative to the correct session + session, ok = sessionsByRemoteChannel[fr.Channel] + if !ok { + err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel", fr.Body, fr.Channel) + continue + } + } + + select { + case session.rx <- fr: + case <-c.rxtxExit: return } } } -// connReader reads from the net.Conn, decodes frames, and passes them -// up via the conn.rxFrame and conn.rxProto channels. -func (c *Conn) connReader() { - defer close(c.rxDone) +// readFrame reads a complete frame from c.net. +// it assumes that any read deadline has already been applied. +// used externally by SASL only. +func (c *Conn) readFrame() (frames.Frame, error) { + switch { + // Cheaply reuse free buffer space when fully read. + case c.rxBuf.Len() == 0: + c.rxBuf.Reset() - buf := &buffer.Buffer{} + // Prevent excessive/unbounded growth by shifting data to beginning of buffer. + case int64(c.rxBuf.Size()) > int64(c.maxFrameSize): + c.rxBuf.Reclaim() + } var ( - negotiating = true // true during conn establishment, check for protoHeaders currentHeader frames.Header // keep track of the current header, for frames split across multiple TCP packets frameInProgress bool // true if in the middle of receiving data for currentHeader ) for { - switch { - // Cheaply reuse free buffer space when fully read. - case buf.Len() == 0: - buf.Reset() - - // Prevent excessive/unbounded growth by shifting data to beginning of buffer. - case int64(buf.Size()) > int64(c.maxFrameSize): - buf.Reclaim() - } - // need to read more if buf doesn't contain the complete frame // or there's not enough in buf to parse the header - if frameInProgress || buf.Len() < frames.HeaderSize { + if frameInProgress || c.rxBuf.Len() < frames.HeaderSize { + // we MUST reset the idle timeout before each read from net.Conn if c.idleTimeout > 0 { _ = c.net.SetReadDeadline(time.Now().Add(c.idleTimeout)) } - err := buf.ReadFromOnce(c.net) + err := c.rxBuf.ReadFromOnce(c.net) if err != nil { - debug.Log(1, "connReader error: %v", err) - select { - // check if error was due to close in progress - case <-c.done: - return - - // if there is a pending connReaderRun function, execute it - case f := <-c.connReaderRun: - f() - continue - - // send error to mux and return - default: - c.connErr <- err - return - } + debug.Log(1, "readFrame error: %v", err) + return frames.Frame{}, err } } // read more if buf doesn't contain enough to parse the header - if buf.Len() < frames.HeaderSize { - continue - } - - // during negotiation, check for proto frames - if negotiating && bytes.Equal(buf.Bytes()[:4], []byte{'A', 'M', 'Q', 'P'}) { - const protoHeaderSize = 8 - buf, ok := buf.Next(protoHeaderSize) - if !ok { - c.connErr <- errors.New("invalid protoHeader") - return - } - _ = buf[7] - - if !bytes.Equal(buf[:4], []byte{'A', 'M', 'Q', 'P'}) { - c.connErr <- fmt.Errorf("unexpected protocol %q", buf[:4]) - return - } - - p := protoHeader{ - ProtoID: protoID(buf[4]), - Major: buf[5], - Minor: buf[6], - Revision: buf[7], - } - - if p.Major != 1 || p.Minor != 0 || p.Revision != 0 { - c.connErr <- fmt.Errorf("unexpected protocol version %d.%d.%d", p.Major, p.Minor, p.Revision) - return - } - - // negotiation is complete once an AMQP proto frame is received - if p.ProtoID == protoAMQP { - negotiating = false - } - - // send proto header - select { - case <-c.done: - return - case c.rxProto <- p: - } - + if c.rxBuf.Len() < frames.HeaderSize { continue } // parse the header if a frame isn't in progress if !frameInProgress { var err error - currentHeader, err = frames.ParseHeader(buf) + currentHeader, err = frames.ParseHeader(&c.rxBuf) if err != nil { - c.connErr <- err - return + return frames.Frame{}, err } frameInProgress = true } // check size is reasonable if currentHeader.Size > math.MaxInt32 { // make max size configurable - c.connErr <- errors.New("payload too large") - return + return frames.Frame{}, errors.New("payload too large") } bodySize := int64(currentHeader.Size - frames.HeaderSize) - // the full frame has been received - if int64(buf.Len()) < bodySize { + // the full frame hasn't been received, keep reading + if int64(c.rxBuf.Len()) < bodySize { continue } frameInProgress = false // check if body is empty (keepalive) if bodySize == 0 { + debug.Log(3, "received keep-alive frame") continue } // parse the frame - b, ok := buf.Next(bodySize) + b, ok := c.rxBuf.Next(bodySize) if !ok { - c.connErr <- fmt.Errorf("buffer EOF; requested bytes: %d, actual size: %d", bodySize, buf.Len()) - return + return frames.Frame{}, fmt.Errorf("buffer EOF; requested bytes: %d, actual size: %d", bodySize, c.rxBuf.Len()) } parsedBody, err := frames.ParseBody(buffer.New(b)) if err != nil { - c.connErr <- err - return + return frames.Frame{}, err } - // send to mux - select { - case <-c.done: - return - case c.rxFrame <- frames.Frame{Channel: currentHeader.Channel, Body: parsedBody}: - } + return frames.Frame{Channel: currentHeader.Channel, Body: parsedBody}, nil } } func (c *Conn) connWriter() { - defer close(c.txDone) + defer func() { + close(c.txDone) + c.close() + }() // disable write timeout if c.connectTimeout != 0 { @@ -724,8 +634,8 @@ func (c *Conn) connWriter() { var err error for { if err != nil { - debug.Log(1, "connWriter error: %v", err) - c.connErr <- err + debug.Log(1, "connWriter terminal error: %v", err) + c.txErr = err return } @@ -750,11 +660,14 @@ func (c *Conn) connWriter() { // possibly drained, then reset.) // connection complete - case <-c.done: - // send close + case <-c.rxtxExit: + // send close performative. note that the spec says we + // SHOULD wait for the ack but we don't HAVE to, in order + // to be resilient to bad actors etc. so we just send + // the close performative and exit. cls := &frames.PerformClose{} debug.Log(1, "TX (connWriter): %s", cls) - _ = c.writeFrame(frames.Frame{ + c.txErr = c.writeFrame(frames.Frame{ Type: frames.TypeAMQP, Body: cls, }) @@ -810,7 +723,7 @@ func (c *Conn) sendFrame(fr frames.Frame) error { case c.txFrame <- fr: return nil case <-c.done: - return c.err() + return c.doneErr } } @@ -876,55 +789,77 @@ func (c *Conn) exchangeProtoHeader(pID protoID) (stateFunc, error) { // readProtoHeader reads a protocol header packet from c.rxProto. func (c *Conn) readProtoHeader() (protoHeader, error) { - var deadline <-chan time.Time - if c.connectTimeout != 0 { - deadline = time.After(c.connectTimeout) + const protoHeaderSize = 8 + + // only read from the network once our buffer has been exhausted. + // TODO: this preserves existing behavior as some tests rely on this + // implementation detail (it lets you replay a stream of bytes). we + // might want to consider removing this and fixing the tests as the + // protocol doesn't actually work this way. + if c.rxBuf.Len() == 0 { + for { + if c.connectTimeout != 0 { + _ = c.net.SetReadDeadline(time.Now().Add(c.connectTimeout)) + } + + err := c.rxBuf.ReadFromOnce(c.net) + if err != nil { + return protoHeader{}, err + } + + // read more if buf doesn't contain enough to parse the header + if c.rxBuf.Len() >= protoHeaderSize { + break + } + } + + // reset outside the loop + if c.connectTimeout != 0 { + _ = c.net.SetReadDeadline(time.Time{}) + } } - var p protoHeader - select { - case p = <-c.rxProto: - return p, nil - case err := <-c.connErr: - return p, err - case fr := <-c.rxFrame: - return p, fmt.Errorf("readProtoHeader: unexpected frame %#v", fr) - case <-deadline: - return p, errors.New("amqp: timeout waiting for response") + + buf, ok := c.rxBuf.Next(protoHeaderSize) + if !ok { + return protoHeader{}, errors.New("invalid protoHeader") } + // bounds check hint to compiler; see golang.org/issue/14808 + _ = buf[protoHeaderSize-1] + + if !bytes.Equal(buf[:4], []byte{'A', 'M', 'Q', 'P'}) { + return protoHeader{}, fmt.Errorf("unexpected protocol %q", buf[:4]) + } + + p := protoHeader{ + ProtoID: protoID(buf[4]), + Major: buf[5], + Minor: buf[6], + Revision: buf[7], + } + + if p.Major != 1 || p.Minor != 0 || p.Revision != 0 { + return protoHeader{}, fmt.Errorf("unexpected protocol version %d.%d.%d", p.Major, p.Minor, p.Revision) + } + + return p, nil } // startTLS wraps the conn with TLS and returns to Client.negotiateProto func (c *Conn) startTLS() (stateFunc, error) { c.initTLSConfig() - // buffered so connReaderRun won't block - done := make(chan error, 1) + _ = c.net.SetReadDeadline(time.Time{}) // clear timeout - // this function will be executed by connReader - c.connReaderRun <- func() { - defer close(done) - _ = c.net.SetReadDeadline(time.Time{}) // clear timeout - - // wrap existing net.Conn and perform TLS handshake - tlsConn := tls.Client(c.net, c.tlsConfig) - if c.connectTimeout != 0 { - _ = tlsConn.SetWriteDeadline(time.Now().Add(c.connectTimeout)) - } - done <- tlsConn.Handshake() - // TODO: return? - - // swap net.Conn - c.net = tlsConn - c.tlsComplete = true - } - - // set deadline to interrupt connReader - _ = c.net.SetReadDeadline(time.Time{}.Add(1)) - - if err := <-done; err != nil { + // wrap existing net.Conn and perform TLS handshake + tlsConn := tls.Client(c.net, c.tlsConfig) + if err := tlsConn.Handshake(); err != nil { return nil, err } + // swap net.Conn + c.net = tlsConn + c.tlsComplete = true + // go to next protocol return c.negotiateProto, nil } @@ -951,7 +886,7 @@ func (c *Conn) openAMQP() (stateFunc, error) { } // get the response - fr, err := c.readFrame() + fr, err := c.readSingleFrame() if err != nil { return nil, err } @@ -981,7 +916,7 @@ func (c *Conn) openAMQP() (stateFunc, error) { // mechanism specified by the server func (c *Conn) negotiateSASL() (stateFunc, error) { // read mechanisms frame - fr, err := c.readFrame() + fr, err := c.readSingleFrame() if err != nil { return nil, err } @@ -1010,7 +945,7 @@ func (c *Conn) negotiateSASL() (stateFunc, error) { // used externally by SASL only. func (c *Conn) saslOutcome() (stateFunc, error) { // read outcome frame - fr, err := c.readFrame() + fr, err := c.readSingleFrame() if err != nil { return nil, err } @@ -1030,27 +965,21 @@ func (c *Conn) saslOutcome() (stateFunc, error) { return c.negotiateProto, nil } -// readFrame is used during connection establishment to read a single frame. +// readSingleFrame is used during connection establishment to read a single frame. // -// After setup, conn.mux handles incoming frames. -// used externally by SASL only. -func (c *Conn) readFrame() (frames.Frame, error) { - var deadline <-chan time.Time +// After setup, conn.connReader handles incoming frames. +func (c *Conn) readSingleFrame() (frames.Frame, error) { if c.connectTimeout != 0 { - deadline = time.After(c.connectTimeout) + _ = c.net.SetDeadline(time.Now().Add(c.connectTimeout)) + defer func() { _ = c.net.SetDeadline(time.Time{}) }() } - var fr frames.Frame - select { - case fr = <-c.rxFrame: - return fr, nil - case err := <-c.connErr: - return fr, err - case p := <-c.rxProto: - return fr, fmt.Errorf("unexpected protocol header %#v", p) - case <-deadline: - return fr, errors.New("amqp: timeout waiting for response") + fr, err := c.readFrame() + if err != nil { + return frames.Frame{}, err } + + return fr, nil } type protoHeader struct { diff --git a/conn_test.go b/conn_test.go index 64ad6a4..3d3296f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -316,30 +316,44 @@ func TestClose(t *testing.T) { } func TestServerSideClose(t *testing.T) { - netConn := mocks.NewNetConn(senderFrameHandlerNoUnhandled(SenderSettleModeUnsettled)) + closeReceived := make(chan struct{}) + responder := func(req frames.FrameBody) ([]byte, error) { + switch req.(type) { + case *mocks.AMQPProto: + return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil + case *frames.PerformOpen: + return mocks.PerformOpen("container") + case *frames.PerformClose: + close(closeReceived) + return mocks.PerformClose(nil) + default: + return nil, fmt.Errorf("unhandled frame %T", req) + } + } + netConn := mocks.NewNetConn(responder) conn, err := newConn(netConn, nil) require.NoError(t, err) require.NoError(t, conn.start()) fr, err := mocks.PerformClose(nil) require.NoError(t, err) netConn.SendFrame(fr) + <-closeReceived err = conn.Close() require.NoError(t, err) + // with error - netConn = mocks.NewNetConn(senderFrameHandlerNoUnhandled(SenderSettleModeUnsettled)) + closeReceived = make(chan struct{}) + netConn = mocks.NewNetConn(responder) conn, err = newConn(netConn, nil) require.NoError(t, err) require.NoError(t, conn.start()) fr, err = mocks.PerformClose(&Error{Condition: "Close", Description: "mock server error"}) require.NoError(t, err) netConn.SendFrame(fr) - // wait a bit for connReader to read from the mock - time.Sleep(100 * time.Millisecond) + <-closeReceived err = conn.Close() var connErr *ConnError - if !errors.As(err, &connErr) { - t.Fatalf("unexpected error type %T", err) - } + require.ErrorAs(t, err, &connErr) require.Equal(t, "*Error{Condition: Close, Description: mock server error, Info: map[]}", connErr.Error()) } @@ -358,6 +372,8 @@ func TestKeepAlives(t *testing.T) { case *mocks.KeepAlive: keepAlives <- struct{}{} return nil, nil + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -380,6 +396,48 @@ func TestKeepAlives(t *testing.T) { require.NoError(t, conn.Close()) } +func TestKeepAlivesIdleTimeout(t *testing.T) { + responder := func(req frames.FrameBody) ([]byte, error) { + switch req.(type) { + case *mocks.AMQPProto: + return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil + case *frames.PerformOpen: + return mocks.EncodeFrame(mocks.FrameAMQP, 0, &frames.PerformOpen{ContainerID: "container", IdleTimeout: time.Minute}) + case *mocks.KeepAlive: + return nil, nil + case *frames.PerformClose: + return mocks.PerformClose(nil) + default: + return nil, fmt.Errorf("unhandled frame %T", req) + } + } + + const idleTimeout = 100 * time.Millisecond + + netConn := mocks.NewNetConn(responder) + conn, err := newConn(netConn, &ConnOptions{ + IdleTimeout: idleTimeout, + }) + require.NoError(t, err) + require.NoError(t, conn.start()) + + done := make(chan struct{}) + defer close(done) + go func() { + for { + select { + case <-time.After(idleTimeout / 2): + netConn.SendKeepAlive() + case <-done: + return + } + } + }() + + time.Sleep(2 * idleTimeout) + require.NoError(t, conn.Close()) +} + func TestConnReaderError(t *testing.T) { netConn := mocks.NewNetConn(senderFrameHandlerNoUnhandled(SenderSettleModeUnsettled)) conn, err := newConn(netConn, nil) @@ -415,6 +473,29 @@ func TestConnWriterError(t *testing.T) { } } +func TestConnWithZeroByteReads(t *testing.T) { + responder := func(req frames.FrameBody) ([]byte, error) { + switch req.(type) { + case *mocks.AMQPProto: + return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil + case *frames.PerformOpen: + return mocks.PerformOpen("container") + case *frames.PerformClose: + return mocks.PerformClose(nil) + default: + return nil, fmt.Errorf("unhandled frame %T", req) + } + } + + netConn := mocks.NewNetConn(responder) + netConn.SendFrame([]byte{}) + + conn, err := newConn(netConn, nil) + require.NoError(t, err) + require.NoError(t, conn.start()) + require.NoError(t, conn.Close()) +} + type mockDialer struct { resp func(frames.FrameBody) ([]byte, error) } @@ -465,6 +546,8 @@ func TestClientClose(t *testing.T) { return []byte{'A', 'M', 'Q', 'P', 0, 1, 0, 0}, nil case *frames.PerformOpen: return mocks.PerformOpen("container") + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -551,6 +634,8 @@ func TestClientNewSession(t *testing.T) { return nil, fmt.Errorf("unexpected incoming window %d", tt.OutgoingWindow) } return mocks.PerformBegin(channelNum) + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -593,6 +678,8 @@ func TestClientMultipleSessions(t *testing.T) { b, err := mocks.PerformBegin(channelNum) channelNum++ return b, err + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -636,6 +723,8 @@ func TestClientTooManySessions(t *testing.T) { b, err := mocks.PerformBegin(channelNum) channelNum++ return b, err + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } diff --git a/fuzz_test.go b/fuzz_test.go index 78f1a14..62389fd 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -13,11 +13,13 @@ import ( "github.com/Azure/go-amqp/internal/frames" "github.com/Azure/go-amqp/internal/testconn" "github.com/fortytw2/leaktest" + "github.com/stretchr/testify/require" ) func fuzzConn(data []byte) int { // Receive client, err := NewConn(testconn.New(data), &ConnOptions{ + Timeout: 10 * time.Millisecond, IdleTimeout: 10 * time.Millisecond, SASLType: SASLTypePlain("listen", "3aCXZYFcuZA89xe6lZkfYJvOPnTGipA3ap7NvPruBhI="), }) @@ -464,8 +466,9 @@ func TestFuzzConnCrashers(t *testing.T) { for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { - defer leaktest.Check(t)() - fuzzConn([]byte(tt)) + end := leaktest.Check(t) + require.Zero(t, fuzzConn([]byte(tt))) + end() }) } } diff --git a/integration_test.go b/integration_test.go index 55e80e5..df5ed21 100644 --- a/integration_test.go +++ b/integration_test.go @@ -759,15 +759,14 @@ func TestMultipleSessionsOpenClose(t *testing.T) { if localBrokerAddr == "" { t.Skip() } - // TODO: connReader and connWriter goroutines will leak - //checkLeaks := leaktest.Check(t) + + checkLeaks := leaktest.Check(t) // Create client client, err := amqp.Dial(localBrokerAddr, nil) if err != nil { t.Fatal(err) } - defer client.Close() sessions := [10]*amqp.Session{} for i := 0; i < 10; i++ { @@ -790,22 +789,24 @@ func TestMultipleSessionsOpenClose(t *testing.T) { } } } - //checkLeaks() + + client.Close() + checkLeaks() } func TestConcurrentSessionsOpenClose(t *testing.T) { if localBrokerAddr == "" { t.Skip() } - // TODO: connReader and connWriter goroutines will leak - //checkLeaks := leaktest.Check(t) + + checkLeaks := leaktest.Check(t) // Create client client, err := amqp.Dial(localBrokerAddr, nil) if err != nil { t.Fatal(err) } - defer client.Close() + wg := sync.WaitGroup{} for i := 0; i < 100; i++ { wg.Add(1) @@ -827,7 +828,9 @@ func TestConcurrentSessionsOpenClose(t *testing.T) { }() } wg.Wait() - //checkLeaks() + + client.Close() + checkLeaks() } func repeatStrings(count int, strs ...string) []string { diff --git a/internal/mocks/net_conn.go b/internal/mocks/net_conn.go index 3f640b4..c306a47 100644 --- a/internal/mocks/net_conn.go +++ b/internal/mocks/net_conn.go @@ -27,6 +27,7 @@ func NewNetConn(resp func(frames.FrameBody) ([]byte, error)) *NetConn { // writes from blocking shutdown. the size was arbitrarily picked. readData: make(chan []byte, 10), readClose: make(chan struct{}), + readDL: newNopTimer(), // default, no deadline } } @@ -47,7 +48,7 @@ type NetConn struct { WriteErr chan error resp func(frames.FrameBody) ([]byte, error) - readDL *time.Timer + readDL readTimer readData chan []byte readClose chan struct{} closed bool @@ -90,15 +91,15 @@ func (n *NetConn) SendMultiFrameTransfer(remoteChannel uint16, linkHandle, deliv func (n *NetConn) Read(b []byte) (int, error) { select { case <-n.readClose: - return 0, errors.New("mock connection was closed") + return 0, net.ErrClosed default: // not closed yet } select { case <-n.readClose: - return 0, errors.New("mock connection was closed") - case <-n.readDL.C: + return 0, net.ErrClosed + case <-n.readDL.C(): return 0, errors.New("mock connection read deadline exceeded") case rd := <-n.readData: return copy(b, rd), nil @@ -116,7 +117,7 @@ func (n *NetConn) Read(b []byte) (int, error) { func (n *NetConn) Write(b []byte) (int, error) { select { case <-n.readClose: - return 0, errors.New("mock connection was closed") + return 0, net.ErrClosed default: // not closed yet } @@ -142,7 +143,7 @@ func (n *NetConn) Write(b []byte) (int, error) { return len(b), nil } -// Close is called by conn.close when conn.mux unwinds. +// Close is called by conn.close. func (n *NetConn) Close() error { if n.closed { return errors.New("double close") @@ -175,9 +176,9 @@ func (n *NetConn) SetReadDeadline(t time.Time) error { // called by conn.connReader before calling Read // stop the last timer if available if n.readDL != nil && !n.readDL.Stop() { - <-n.readDL.C + <-n.readDL.C() } - n.readDL = time.NewTimer(time.Until(t)) + n.readDL = timer{t: time.NewTimer(time.Until(t))} return nil } @@ -437,3 +438,37 @@ func encodeMultiFrameTransfer(remoteChannel uint16, linkHandle, deliveryID uint3 } return frameData, nil } + +type readTimer interface { + C() <-chan time.Time + Stop() bool +} + +func newNopTimer() nopTimer { + return nopTimer{t: make(chan time.Time)} +} + +type nopTimer struct { + t chan time.Time +} + +func (n nopTimer) C() <-chan time.Time { + return n.t +} + +func (n nopTimer) Stop() bool { + close(n.t) + return true +} + +type timer struct { + t *time.Timer +} + +func (t timer) C() <-chan time.Time { + return t.t.C +} + +func (t timer) Stop() bool { + return t.t.Stop() +} diff --git a/internal/testconn/testconn.go b/internal/testconn/testconn.go index 14e6f40..f048f39 100644 --- a/internal/testconn/testconn.go +++ b/internal/testconn/testconn.go @@ -35,7 +35,12 @@ func (c *Conn) Read(b []byte) (int, error) { } time.Sleep(1 * time.Millisecond) n := copy(b, c.data[0]) - c.data = c.data[1:] + // only move on to the next chunk if this one was entirely consumed + if n == len(c.data[0]) { + c.data = c.data[1:] + } else { + c.data[0] = c.data[0][n:] + } return n, nil } @@ -63,7 +68,7 @@ func (c *Conn) RemoteAddr() net.Addr { } func (c *Conn) SetDeadline(t time.Time) error { - return nil + return c.SetReadDeadline(t) } func (c *Conn) SetReadDeadline(t time.Time) error { @@ -84,5 +89,5 @@ func (c *Conn) SetReadDeadline(t time.Time) error { } func (c *Conn) SetWriteDeadline(t time.Time) error { - return nil + return errors.New("testconn.SetWriteDeadline NYI") } diff --git a/sender_test.go b/sender_test.go index e1946b8..d17d2cc 100644 --- a/sender_test.go +++ b/sender_test.go @@ -52,6 +52,8 @@ func TestSenderMethodsNoSend(t *testing.T) { return mocks.SenderAttach(0, tt.Name, 0, SenderSettleModeUnsettled) case *frames.PerformDetach: return mocks.PerformDetach(0, 0, nil) + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -222,6 +224,8 @@ func TestSenderAttachError(t *testing.T) { // we don't need to respond to the ack detachAck <- true return nil, nil + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -466,6 +470,8 @@ func TestSenderSendRejectedNoDetach(t *testing.T) { return mocks.PerformDisposition(encoding.RoleReceiver, 0, *tt.DeliveryID, nil, &encoding.StateAccepted{}) case *frames.PerformDetach: return mocks.PerformDetach(0, 0, nil) + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -602,6 +608,8 @@ func TestSenderSendMsgTooBig(t *testing.T) { return mocks.PerformDisposition(encoding.RoleReceiver, 0, *tt.DeliveryID, nil, &encoding.StateAccepted{}) case *frames.PerformDetach: return mocks.PerformDetach(0, 0, nil) + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -710,6 +718,8 @@ func TestSenderSendMultiTransfer(t *testing.T) { return mocks.PerformDisposition(encoding.RoleReceiver, 0, deliveryID, nil, &encoding.StateAccepted{}) case *frames.PerformDetach: return mocks.PerformDetach(0, 0, nil) + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -730,7 +740,7 @@ func TestSenderSendMultiTransfer(t *testing.T) { sendInitialFlowFrame(t, netConn, 0, 100) - ctx, cancel = context.WithTimeout(context.Background(), 100000*time.Millisecond) + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) payload := make([]byte, maxReceiverFrameSize*4) for i := 0; i < maxReceiverFrameSize*4; i++ { payload[i] = byte(i % 256) diff --git a/session.go b/session.go index cf216e4..6094853 100644 --- a/session.go +++ b/session.go @@ -42,9 +42,9 @@ type SessionOptions struct { // A session multiplexes Receivers. type Session struct { channel uint16 // session's local channel - remoteChannel uint16 // session's remote channel, owned by conn.mux + remoteChannel uint16 // session's remote channel, owned by conn.connReader conn *Conn // underlying conn - rx chan frames.Frame // frames destined for this session are sent on this chan by conn.mux + rx chan frames.Frame // frames destined for this session are sent on this chan by conn.connReader tx chan frames.FrameBody // non-transfer frames to be sent; session must track disposition txTransfer chan *frames.PerformTransfer // transfer frames to be sent; session must track disposition @@ -139,7 +139,7 @@ func (s *Session) begin(ctx context.Context) error { }() return ctx.Err() case <-s.conn.done: - return s.conn.err() + return s.conn.doneErr case fr = <-s.rx: // received ack that session was created } @@ -148,7 +148,7 @@ func (s *Session) begin(ctx context.Context) error { begin, ok := fr.Body.(*frames.PerformBegin) if !ok { // this codepath is hard to hit (impossible?). if the response isn't a PerformBegin and we've not - // yet seen the remote channel number, the default clause in conn.mux will protect us from that. + // yet seen the remote channel number, the default clause in conn.connReader will protect us from that. // if we have seen the remote channel number then it's likely the session.mux for that channel will // either swallow the frame or blow up in some other way, both causing this call to hang. // deallocate session on error. we can't call @@ -279,7 +279,7 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { select { // conn has completed, exit case <-s.conn.done: - s.err = s.conn.err() + s.err = s.conn.doneErr return // session is being closed by user @@ -298,7 +298,7 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { break EndLoop } case <-s.conn.done: - s.err = s.conn.err() + s.err = s.conn.doneErr return } } diff --git a/session_test.go b/session_test.go index 4ac09d2..ee8b98f 100644 --- a/session_test.go +++ b/session_test.go @@ -35,6 +35,8 @@ func TestSessionClose(t *testing.T) { } channelNum-- return b, nil + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -48,8 +50,9 @@ func TestSessionClose(t *testing.T) { session, err := client.NewSession(ctx, nil) cancel() require.NoErrorf(t, err, "iteration %d", i) - require.Equalf(t, channelNum-1, session.channel, "iteration %d", i) - ctx, cancel = context.WithTimeout(context.Background(), time.Second) + require.Equalf(t, uint16(0), session.channel, "iteration %d", i) + require.Equalf(t, channelNum-1, session.remoteChannel, "iteration %d", i) + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) err = session.Close(ctx) cancel() require.NoErrorf(t, err, "iteration %d", i) @@ -68,6 +71,8 @@ func TestSessionServerClose(t *testing.T) { return mocks.PerformBegin(0) case *frames.PerformEnd: return nil, nil // swallow + case *frames.PerformClose: + return nil, nil // swallow default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -112,6 +117,8 @@ func TestSessionCloseTimeout(t *testing.T) { // sleep to trigger session close timeout time.Sleep(1 * time.Second) return mocks.PerformEnd(0, nil) + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -199,6 +206,8 @@ func TestSessionNewReceiverBatchingOneCredit(t *testing.T) { return mocks.ReceiverAttach(0, tt.Name, 0, ReceiverSettleModeFirst, nil) case *frames.PerformFlow: return nil, nil + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -242,6 +251,8 @@ func TestSessionNewReceiverBatchingEnabled(t *testing.T) { return mocks.ReceiverAttach(0, tt.Name, 0, ReceiverSettleModeFirst, nil) case *frames.PerformFlow: return nil, nil + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -284,6 +295,8 @@ func TestSessionNewReceiverMismatchedLinkName(t *testing.T) { return mocks.PerformEnd(0, nil) case *frames.PerformAttach: return mocks.ReceiverAttach(0, "wrong_name", 0, ReceiverSettleModeFirst, nil) + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -351,6 +364,8 @@ func TestSessionNewSenderMismatchedLinkName(t *testing.T) { return mocks.PerformEnd(0, nil) case *frames.PerformAttach: return mocks.SenderAttach(0, "wrong_name", 0, SenderSettleModeUnsettled) + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -516,6 +531,8 @@ func TestSessionFlowFrameWithEcho(t *testing.T) { return nil, nil case *frames.PerformEnd: return mocks.PerformEnd(0, nil) + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) } @@ -562,6 +579,8 @@ func TestSessionInvalidAttachDeadlock(t *testing.T) { case *frames.PerformAttach: enqueueFrames(tt.Name) return nil, nil + case *frames.PerformClose: + return mocks.PerformClose(nil) default: return nil, fmt.Errorf("unhandled frame %T", req) }