// Source: go-amqp/conn.go (1148 lines, 32 KiB, Go)

package amqp
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"math"
"net"
"net/url"
"sync"
"time"
"github.com/Azure/go-amqp/internal/bitmap"
"github.com/Azure/go-amqp/internal/buffer"
"github.com/Azure/go-amqp/internal/debug"
"github.com/Azure/go-amqp/internal/encoding"
"github.com/Azure/go-amqp/internal/frames"
"github.com/Azure/go-amqp/internal/shared"
)
// Default connection options.
const (
	defaultIdleTimeout  = 1 * time.Minute  // max gap between frames received from the peer before the conn is considered dead
	defaultMaxFrameSize = 65536            // largest frame we advertise we'll accept
	defaultMaxSessions  = 65536            // session (channel) count limit; channel-max is this value minus one
	defaultWriteTimeout = 30 * time.Second // per-write deadline used when the caller's context has no deadline
)
// ConnOptions contains the optional settings for configuring an AMQP connection.
type ConnOptions struct {
	// ContainerID sets the container-id to use when opening the connection.
	//
	// A container ID will be randomly generated if this option is not used.
	ContainerID string

	// HostName sets the hostname sent in the AMQP
	// Open frame and TLS ServerName (if not otherwise set).
	HostName string

	// IdleTimeout specifies the maximum period between
	// receiving frames from the peer.
	//
	// Specify a value less than zero to disable idle timeout.
	//
	// Default: 1 minute (60000000000).
	IdleTimeout time.Duration

	// MaxFrameSize sets the maximum frame size that
	// the connection will accept.
	//
	// Must be 512 or greater.
	//
	// Default: 65536.
	MaxFrameSize uint32

	// MaxSessions sets the maximum number of channels.
	// The value must be greater than zero.
	//
	// Default: 65536.
	MaxSessions uint16

	// Properties sets an entry in the connection properties map sent to the server.
	Properties map[string]any

	// SASLType contains the specified SASL authentication mechanism.
	SASLType SASLType

	// TLSConfig sets the tls.Config to be used during
	// TLS negotiation.
	//
	// This option is for advanced usage, in most scenarios
	// providing a URL scheme of "amqps://" is sufficient.
	TLSConfig *tls.Config

	// WriteTimeout controls the write deadline when writing AMQP frames to the
	// underlying net.Conn and no caller provided context.Context is available or
	// the context contains no deadline (e.g. context.Background()).
	// The timeout is set per write.
	//
	// Setting to a value less than zero means no timeout is set, so writes
	// defer to the underlying behavior of net.Conn with no write deadline.
	//
	// Default: 30s
	WriteTimeout time.Duration

	// test hook: allows faking the TCP/TLS dial in unit tests
	dialer dialer
}
// Dial connects to an AMQP broker.
//
// If the addr includes a scheme, it must be "amqp", "amqps", or "amqp+ssl".
// If no port is provided, 5672 will be used for "amqp" and 5671 for "amqps" or "amqp+ssl".
//
// If username and password information is not empty it's used as SASL PLAIN
// credentials, equal to passing ConnSASLPlain option.
//
// opts: pass nil to accept the default values.
func Dial(ctx context.Context, addr string, opts *ConnOptions) (*Conn, error) {
	conn, err := dialConn(ctx, addr, opts)
	if err != nil {
		return nil, err
	}
	if err := conn.start(ctx); err != nil {
		return nil, err
	}
	return conn, nil
}
// NewConn establishes a new AMQP client connection over conn.
// NOTE: [Conn] takes ownership of the provided [net.Conn] and will close it as required.
// opts: pass nil to accept the default values.
func NewConn(ctx context.Context, conn net.Conn, opts *ConnOptions) (*Conn, error) {
	amqpConn, err := newConn(conn, opts)
	if err != nil {
		return nil, err
	}
	if err := amqpConn.start(ctx); err != nil {
		return nil, err
	}
	return amqpConn, nil
}
// Conn is an AMQP connection.
type Conn struct {
	net          net.Conn      // underlying connection
	dialer       dialer        // used for testing purposes, it allows faking dialing TCP/TLS endpoints
	writeTimeout time.Duration // controls write deadline in absence of a context

	// TLS
	tlsNegotiation bool        // negotiate TLS
	tlsComplete    bool        // TLS negotiation complete
	tlsConfig      *tls.Config // TLS config, default used if nil (ServerName set to Client.hostname)

	// SASL
	saslHandlers map[encoding.Symbol]stateFunc // map of supported handlers keyed by SASL mechanism, SASL not negotiated if nil
	saslComplete bool                          // SASL negotiation complete; internal *except* for SASL auth methods

	// local settings
	maxFrameSize uint32                  // max frame size to accept
	channelMax   uint16                  // maximum number of channels to allow
	hostname     string                  // hostname of remote server (set explicitly or parsed from URL)
	idleTimeout  time.Duration           // maximum period between receiving frames
	properties   map[encoding.Symbol]any // additional properties sent upon connection open
	containerID  string                  // set explicitly or randomly generated

	// peer settings
	peerIdleTimeout  time.Duration // maximum period between sending frames
	peerMaxFrameSize uint32        // maximum frame size peer will accept

	// conn state
	done    chan struct{} // indicates the connection has terminated
	doneErr error         // contains the error state returned from Close(); DO NOT TOUCH outside of conn.go until done has been closed!

	// connReader and connWriter management
	rxtxExit  chan struct{} // signals connReader and connWriter to exit
	closeOnce sync.Once     // ensures that close() is only called once

	// session tracking
	channels            *bitmap.Bitmap
	sessionsByChannel   map[uint16]*Session
	sessionsByChannelMu sync.RWMutex
	abandonedSessionsMu sync.Mutex
	abandonedSessions   []*Session

	// connReader
	rxBuf  buffer.Buffer // incoming bytes buffer
	rxDone chan struct{} // closed when connReader exits
	rxErr  error         // contains last error reading from c.net; DO NOT TOUCH outside of connReader until rxDone has been closed!

	// connWriter
	txFrame chan frameEnvelope // AMQP frames to be sent by connWriter
	txBuf   buffer.Buffer      // buffer for marshaling frames before transmitting
	txDone  chan struct{}      // closed when connWriter exits
	txErr   error              // contains last error writing to c.net; DO NOT TOUCH outside of connWriter until txDone has been closed!
}
// dialer is used to abstract the underlying dialer for testing purposes.
// On success, implementations store the established connection in c.net.
type dialer interface {
	NetDialerDial(ctx context.Context, c *Conn, host, port string) error
	TLSDialWithDialer(ctx context.Context, c *Conn, host, port string) error
}

// defaultDialer implements the dialer interface using net.Dialer and tls.Dialer.
type defaultDialer struct{}
// NetDialerDial dials host:port over plain TCP and stores the resulting
// connection in c.net.
func (defaultDialer) NetDialerDial(ctx context.Context, c *Conn, host, port string) error {
	var d net.Dialer
	conn, err := d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
	c.net = conn
	return err
}
// TLSDialWithDialer dials host:port, performing the TLS handshake using
// c.tlsConfig, and stores the resulting connection in c.net.
func (defaultDialer) TLSDialWithDialer(ctx context.Context, c *Conn, host, port string) error {
	d := tls.Dialer{Config: c.tlsConfig}
	conn, err := d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
	c.net = conn
	return err
}
// dialConn parses addr as a URL, applies URL-derived settings (port,
// credentials, hostname) on top of opts, and dials the endpoint over
// plain TCP or TLS depending on the scheme.
func dialConn(ctx context.Context, addr string, opts *ConnOptions) (*Conn, error) {
	u, err := url.Parse(addr)
	if err != nil {
		return nil, err
	}

	host, port := u.Hostname(), u.Port()
	if port == "" {
		// default AMQP ports: 5671 for the TLS schemes, 5672 otherwise
		switch u.Scheme {
		case "amqps", "amqp+ssl":
			port = "5671"
		default:
			port = "5672"
		}
	}

	var cp ConnOptions
	if opts != nil {
		cp = *opts
	}

	// credentials embedded in the URL become SASL PLAIN credentials
	if u.User != nil {
		pass, _ := u.User.Password()
		cp.SASLType = SASLTypePlain(u.User.Username(), pass)
	}

	if cp.HostName == "" {
		cp.HostName = host
	}

	c, err := newConn(nil, &cp)
	if err != nil {
		return nil, err
	}

	switch u.Scheme {
	case "amqp", "":
		err = c.dialer.NetDialerDial(ctx, c, host, port)
	case "amqps", "amqp+ssl":
		c.initTLSConfig()
		// TLS is established by the dialer itself, so no in-band TLS negotiation
		c.tlsNegotiation = false
		err = c.dialer.TLSDialWithDialer(ctx, c, host, port)
	default:
		err = fmt.Errorf("unsupported scheme %q", u.Scheme)
	}
	if err != nil {
		return nil, err
	}
	return c, nil
}
// newConn creates a Conn with defaults applied, then overrides them from opts.
// netConn may be nil when the caller will dial later (see dialConn).
// Returns an error for invalid option values or if SASL configuration fails.
func newConn(netConn net.Conn, opts *ConnOptions) (*Conn, error) {
	c := &Conn{
		dialer:            defaultDialer{},
		net:               netConn,
		maxFrameSize:      defaultMaxFrameSize,
		peerMaxFrameSize:  defaultMaxFrameSize,
		channelMax:        defaultMaxSessions - 1, // -1 because channel-max starts at zero
		idleTimeout:       defaultIdleTimeout,
		containerID:       shared.RandString(40),
		done:              make(chan struct{}),
		rxtxExit:          make(chan struct{}),
		rxDone:            make(chan struct{}),
		txFrame:           make(chan frameEnvelope),
		txDone:            make(chan struct{}),
		sessionsByChannel: map[uint16]*Session{},
		writeTimeout:      defaultWriteTimeout,
	}

	// apply options
	if opts == nil {
		opts = &ConnOptions{}
	}

	if opts.WriteTimeout > 0 {
		c.writeTimeout = opts.WriteTimeout
	} else if opts.WriteTimeout < 0 {
		// a negative value disables write deadlines entirely
		c.writeTimeout = 0
	}
	if opts.ContainerID != "" {
		c.containerID = opts.ContainerID
	}
	if opts.HostName != "" {
		c.hostname = opts.HostName
	}
	if opts.IdleTimeout > 0 {
		c.idleTimeout = opts.IdleTimeout
	} else if opts.IdleTimeout < 0 {
		// a negative value disables the idle timeout
		c.idleTimeout = 0
	}
	if opts.MaxFrameSize > 0 && opts.MaxFrameSize < 512 {
		return nil, fmt.Errorf("invalid MaxFrameSize value %d", opts.MaxFrameSize)
	} else if opts.MaxFrameSize >= 512 {
		// NOTE: was "> 512", which silently ignored the documented-legal
		// minimum value of exactly 512; the doc says "Must be 512 or greater".
		c.maxFrameSize = opts.MaxFrameSize
	}
	if opts.MaxSessions > 0 {
		// -1 because channel-max starts at zero, consistent with the
		// defaultMaxSessions-1 default above. MaxSessions of N must allow
		// channels 0..N-1, not 0..N.
		c.channelMax = opts.MaxSessions - 1
	}
	if opts.SASLType != nil {
		if err := opts.SASLType(c); err != nil {
			return nil, err
		}
	}
	if opts.Properties != nil {
		c.properties = make(map[encoding.Symbol]any)
		for key, val := range opts.Properties {
			c.properties[encoding.Symbol(key)] = val
		}
	}
	if opts.TLSConfig != nil {
		// clone so later caller mutations don't affect us
		c.tlsConfig = opts.TLSConfig.Clone()
	}
	if opts.dialer != nil {
		c.dialer = opts.dialer
	}
	return c, nil
}
// initTLSConfig ensures c.tlsConfig is non-nil and usable: a tls.Config must
// carry either a ServerName or InsecureSkipVerify, so default ServerName to
// the connection's hostname when neither is set.
func (c *Conn) initTLSConfig() {
	cfg := c.tlsConfig
	if cfg == nil {
		cfg = new(tls.Config)
		c.tlsConfig = cfg
	}
	if cfg.ServerName == "" && !cfg.InsecureSkipVerify {
		cfg.ServerName = c.hostname
	}
}
// start establishes the connection and begins multiplexing network IO.
// It is an error to call Start() on a connection that's been closed.
// On context cancellation/expiry the underlying net.Conn is closed and
// the context's error is returned.
func (c *Conn) start(ctx context.Context) (err error) {
	// if the context has a deadline or is cancellable, start the interruptor goroutine.
	// this will close the underlying net.Conn in response to the context.
	if ctx.Done() != nil {
		done := make(chan struct{})
		interruptRes := make(chan error, 1)

		// the deferred func runs after startImpl returns: it stops the
		// interruptor and, if the context fired, overrides err with the
		// context's error (the named return makes this possible).
		defer func() {
			close(done)
			if ctxErr := <-interruptRes; ctxErr != nil {
				// return context error to caller
				err = ctxErr
			}
		}()

		go func() {
			select {
			case <-ctx.Done():
				// unblock any in-flight read/write by closing the conn
				c.closeDuringStart()
				interruptRes <- ctx.Err()
			case <-done:
				interruptRes <- nil
			}
		}()
	}

	if err = c.startImpl(ctx); err != nil {
		return err
	}

	// we can't create the channel bitmap until the connection has been established.
	// this is because our peer can tell us the max channels they support.
	c.channels = bitmap.New(uint32(c.channelMax))

	go c.connWriter()
	go c.connReader()

	return
}
// startImpl drives the connection-establishment state machine, applying the
// caller's deadline (if any) to the underlying net.Conn for the duration.
func (c *Conn) startImpl(ctx context.Context) error {
	// apply, and later clear, the connection establishment deadline
	if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
		_ = c.net.SetDeadline(deadline)
		defer func() { _ = c.net.SetDeadline(time.Time{}) }()
	}

	// advance through the negotiation states; a nil state means done
	state := stateFunc(c.negotiateProto)
	for state != nil {
		next, err := state(ctx)
		if err != nil {
			// terminal failure; tear down the half-established conn
			c.closeDuringStart()
			return err
		}
		state = next
	}
	return nil
}
// Close closes the connection.
func (c *Conn) Close() error {
	c.close()

	// block until the reader and writer goroutines have finished; this
	// avoids racing a terminal-error close() initiated by either goroutine.
	<-c.txDone
	<-c.rxDone

	var connErr *ConnError
	if errors.As(c.doneErr, &connErr) && connErr.RemoteErr == nil && connErr.inner == nil {
		// an empty ConnectionError means the connection was closed by the caller
		return nil
	}

	// there was an error during shut-down or connReader/connWriter
	// experienced a terminal error
	return c.doneErr
}
// close is called once, either from Close() or when connReader/connWriter exits.
// It coordinates the shutdown sequence and records the final error in c.doneErr.
// NOTE: the ordering below is deliberate — do not reorder.
func (c *Conn) close() {
	c.closeOnce.Do(func() {
		defer close(c.done)

		close(c.rxtxExit)

		// wait for writing to stop, allows it to send the final close frame
		<-c.txDone

		closeErr := c.net.Close()

		// check rxDone after closing net, otherwise may block
		// for up to c.idleTimeout
		<-c.rxDone

		if errors.Is(c.rxErr, net.ErrClosed) {
			// this is the expected error when the connection is closed, swallow it
			c.rxErr = nil
		}

		// pick the most meaningful error to surface, in priority order:
		// clean close, peer-sent AMQP error, writer error, reader error,
		// and finally any error from closing the socket itself.
		if c.txErr == nil && c.rxErr == nil && closeErr == nil {
			// if there are no errors, it means user initiated close() and we shut down cleanly
			c.doneErr = &ConnError{}
		} else if amqpErr, ok := c.rxErr.(*Error); ok {
			// we experienced a peer-initiated close that contained an Error. return it
			c.doneErr = &ConnError{RemoteErr: amqpErr}
		} else if c.txErr != nil {
			// c.txErr is already wrapped in a ConnError
			c.doneErr = c.txErr
		} else if c.rxErr != nil {
			c.doneErr = &ConnError{inner: c.rxErr}
		} else {
			c.doneErr = &ConnError{inner: closeErr}
		}
	})
}
// closeDuringStart is a special close to be used only during startup
// (i.e. c.start() and any of its children). It only closes the socket;
// the reader/writer goroutines aren't running yet.
func (c *Conn) closeDuringStart() {
	c.closeOnce.Do(func() { c.net.Close() })
}
// NewSession starts a new session on the connection.
//   - ctx controls waiting for the peer to acknowledge the session
//   - opts contains optional values, pass nil to accept the defaults
//
// If the context's deadline expires or is cancelled before the operation
// completes, an error is returned. If the Session was successfully
// created, it will be cleaned up in future calls to NewSession.
func (c *Conn) NewSession(ctx context.Context, opts *SessionOptions) (*Session, error) {
	// first, dispose of any sessions abandoned by earlier failed calls
	if err := c.freeAbandonedSessions(ctx); err != nil {
		return nil, err
	}

	s, err := c.newSession(opts)
	if err != nil {
		return nil, err
	}

	if err := s.begin(ctx); err != nil {
		// park the session so a later NewSession call can end it properly
		c.abandonSession(s)
		return nil, err
	}

	return s, nil
}
// freeAbandonedSessions sends an end performative for every session that was
// abandoned by a previous NewSession call, then clears the abandoned list.
func (c *Conn) freeAbandonedSessions(ctx context.Context) error {
	c.abandonedSessionsMu.Lock()
	defer c.abandonedSessionsMu.Unlock()

	debug.Log(3, "TX (Conn %p): cleaning up %d abandoned sessions", c, len(c.abandonedSessions))

	for _, session := range c.abandonedSessions {
		endFrame := frames.PerformEnd{}
		if err := session.txFrameAndWait(ctx, &endFrame); err != nil {
			return err
		}
	}

	c.abandonedSessions = nil
	return nil
}
// newSession allocates the next free channel number and registers a new
// Session for it. If no channels remain, the connection is closed and a
// ConnError is returned.
func (c *Conn) newSession(opts *SessionOptions) (*Session, error) {
	c.sessionsByChannelMu.Lock()
	defer c.sessionsByChannelMu.Unlock()

	// grab the lowest available channel number (channels start at 0)
	ch, available := c.channels.Next()
	if !available {
		if err := c.Close(); err != nil {
			return nil, err
		}
		return nil, &ConnError{inner: fmt.Errorf("reached connection channel max (%d)", c.channelMax)}
	}

	s := newSession(c, uint16(ch), opts)
	c.sessionsByChannel[s.channel] = s
	return s, nil
}
// deleteSession unregisters s and releases its channel number for reuse.
func (c *Conn) deleteSession(s *Session) {
	c.sessionsByChannelMu.Lock()
	defer c.sessionsByChannelMu.Unlock()

	delete(c.sessionsByChannel, s.channel)
	c.channels.Remove(uint32(s.channel))
}
// abandonSession records s for deferred cleanup by freeAbandonedSessions.
func (c *Conn) abandonSession(s *Session) {
	c.abandonedSessionsMu.Lock()
	defer c.abandonedSessionsMu.Unlock()
	c.abandonedSessions = append(c.abandonedSessions, s)
}
// connReader reads from the net.Conn, decodes frames, and either handles
// them here as appropriate or sends them to the session.rx channel.
// Runs as a goroutine for the life of the connection; on a terminal error
// it records the error in c.rxErr, closes rxDone, and triggers conn close.
func (c *Conn) connReader() {
	defer func() {
		close(c.rxDone)
		c.close()
	}()

	// maps the peer's channel numbers to our sessions; local to this
	// goroutine so no locking is needed.
	var sessionsByRemoteChannel = make(map[uint16]*Session)
	var err error
	for {
		// errors are checked at the top of the loop so that `continue`
		// after setting err acts as a terminal exit on the next iteration.
		if err != nil {
			debug.Log(0, "RX (connReader %p): terminal error: %v", c, err)
			c.rxErr = err
			return
		}

		var fr frames.Frame
		fr, err = c.readFrame()
		if err != nil {
			continue
		}

		debug.Log(0, "RX (connReader %p): %s", c, fr)

		var (
			session *Session
			ok      bool
		)

		switch body := fr.Body.(type) {
		// Server initiated close.
		case *frames.PerformClose:
			// connWriter will send the close performative ack on its way out.
			// it's a SHOULD though, not a MUST.
			if body.Error == nil {
				return
			}
			err = body.Error
			continue
		// RemoteChannel should be used when frame is Begin
		case *frames.PerformBegin:
			if body.RemoteChannel == nil {
				// a Begin without RemoteChannel is remotely-initiated; since we
				// only support locally-initiated sessions, this is an error
				// TODO: it would be ideal to not have this kill the connection
				err = fmt.Errorf("%T: nil RemoteChannel", fr.Body)
				continue
			}
			c.sessionsByChannelMu.RLock()
			session, ok = c.sessionsByChannel[*body.RemoteChannel]
			c.sessionsByChannelMu.RUnlock()

			if !ok {
				// this can happen if NewSession() exits due to the context expiring/cancelled
				// before the begin ack is received.
				err = fmt.Errorf("unexpected remote channel number %d", *body.RemoteChannel)
				continue
			}

			session.remoteChannel = fr.Channel
			sessionsByRemoteChannel[fr.Channel] = session
		case *frames.PerformEnd:
			session, ok = sessionsByRemoteChannel[fr.Channel]
			if !ok {
				err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel (PerformEnd)", fr.Body, fr.Channel)
				continue
			}
			// we MUST remove the remote channel from our map as soon as we receive
			// the ack (i.e. before passing it on to the session mux) on the session
			// ending since the numbers are recycled.
			delete(sessionsByRemoteChannel, fr.Channel)
			c.deleteSession(session)
		default:
			// pass on performative to the correct session
			session, ok = sessionsByRemoteChannel[fr.Channel]
			if !ok {
				err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel", fr.Body, fr.Channel)
				continue
			}
		}

		// deliver the frame body to the owning session's receive queue
		q := session.rxQ.Acquire()
		q.Enqueue(fr.Body)
		session.rxQ.Release(q)
		debug.Log(2, "RX (connReader %p): mux frame to Session (%p): %s", c, session, fr)
	}
}
// readFrame reads a complete frame from c.net.
// it assumes that any read deadline has already been applied.
// used externally by SASL only.
// A frame may span multiple reads; state is kept in currentHeader and
// frameInProgress across loop iterations.
func (c *Conn) readFrame() (frames.Frame, error) {
	switch {
	// Cheaply reuse free buffer space when fully read.
	case c.rxBuf.Len() == 0:
		c.rxBuf.Reset()

	// Prevent excessive/unbounded growth by shifting data to beginning of buffer.
	case int64(c.rxBuf.Size()) > int64(c.maxFrameSize):
		c.rxBuf.Reclaim()
	}

	var (
		currentHeader   frames.Header // keep track of the current header, for frames split across multiple TCP packets
		frameInProgress bool          // true if in the middle of receiving data for currentHeader
	)

	for {
		// need to read more if buf doesn't contain the complete frame
		// or there's not enough in buf to parse the header
		if frameInProgress || c.rxBuf.Len() < frames.HeaderSize {
			// we MUST reset the idle timeout before each read from net.Conn
			if c.idleTimeout > 0 {
				_ = c.net.SetReadDeadline(time.Now().Add(c.idleTimeout))
			}
			err := c.rxBuf.ReadFromOnce(c.net)
			if err != nil {
				return frames.Frame{}, err
			}
		}

		// parse the header if a frame isn't in progress
		if !frameInProgress {
			// read more if buf doesn't contain enough to parse the header
			// NOTE: we MUST do this ONLY if a frame isn't in progress else we can
			// end up stalling when reading frames with bodies smaller than HeaderSize
			if c.rxBuf.Len() < frames.HeaderSize {
				continue
			}

			var err error
			currentHeader, err = frames.ParseHeader(&c.rxBuf)
			if err != nil {
				return frames.Frame{}, err
			}
			frameInProgress = true
		}

		// check size is reasonable
		if currentHeader.Size > math.MaxInt32 { // make max size configurable
			return frames.Frame{}, errors.New("payload too large")
		}

		// header size is included in currentHeader.Size, so subtract it
		bodySize := int64(currentHeader.Size - frames.HeaderSize)

		// the full frame hasn't been received, keep reading
		if int64(c.rxBuf.Len()) < bodySize {
			continue
		}
		frameInProgress = false

		// check if body is empty (keepalive)
		if bodySize == 0 {
			debug.Log(3, "RX (connReader %p): received keep-alive frame", c)
			continue
		}

		// parse the frame
		b, ok := c.rxBuf.Next(bodySize)
		if !ok {
			return frames.Frame{}, fmt.Errorf("buffer EOF; requested bytes: %d, actual size: %d", bodySize, c.rxBuf.Len())
		}

		parsedBody, err := frames.ParseBody(buffer.New(b))
		if err != nil {
			return frames.Frame{}, err
		}

		return frames.Frame{Channel: currentHeader.Channel, Body: parsedBody}, nil
	}
}
// frameContext is an extended context.Context used to track writes to the network.
// this is required in order to remove ambiguities that can arise when simply waiting
// on context.Context.Done() to be signaled.
type frameContext struct {
	// Ctx contains the caller's context and is used to set the write deadline.
	Ctx context.Context

	// Done is closed when the frame was successfully written to net.Conn or Ctx was cancelled/timed out.
	// Can be nil, but shouldn't be for callers that care about confirmation of sending.
	Done chan struct{}

	// Err contains the context error. MUST be set before closing Done and ONLY read if Done is closed.
	// ONLY Conn.connWriter may write to this field.
	Err error
}

// frameEnvelope is used when sending a frame to connWriter to be written to net.Conn.
type frameEnvelope struct {
	FrameCtx *frameContext
	Frame    frames.Frame
}
// connWriter is the sole writer to c.net after connection establishment.
// It multiplexes outgoing frames from c.txFrame, emits keepalives at half
// the peer's idle timeout, and sends the final close performative on exit.
// Runs as a goroutine for the life of the connection; terminal errors are
// recorded in c.txErr before txDone is closed.
func (c *Conn) connWriter() {
	defer func() {
		close(c.txDone)
		c.close()
	}()

	var (
		// keepalives are sent at a rate of 1/2 idle timeout
		keepaliveInterval = c.peerIdleTimeout / 2
		// 0 disables keepalives
		keepalivesEnabled = keepaliveInterval > 0
		// set if enable, nil if not; nil channels block forever
		keepalive <-chan time.Time
	)

	if keepalivesEnabled {
		ticker := time.NewTicker(keepaliveInterval)
		defer ticker.Stop()
		keepalive = ticker.C
	}

	var err error
	for {
		// errors are checked at the top of the loop so that a failed write
		// in any select case becomes a terminal exit on the next iteration.
		if err != nil {
			debug.Log(0, "TX (connWriter %p): terminal error: %v", c, err)
			c.txErr = err
			return
		}

		select {
		// frame write request
		case env := <-c.txFrame:
			timeout, ctxErr := c.getWriteTimeout(env.FrameCtx.Ctx)
			if ctxErr != nil {
				debug.Log(1, "TX (connWriter %p) getWriteTimeout: %s: %s", c, ctxErr.Error(), env.Frame)
				if env.FrameCtx.Done != nil {
					// the error MUST be set before closing the channel
					env.FrameCtx.Err = ctxErr
					close(env.FrameCtx.Done)
				}
				continue
			}

			debug.Log(0, "TX (connWriter %p) timeout %s: %s", c, timeout, env.Frame)
			err = c.writeFrame(timeout, env.Frame)
			if err == nil && env.FrameCtx.Done != nil {
				close(env.FrameCtx.Done)
			}
			// in the event of write failure, Conn will close and a
			// *ConnError will be propagated to all of the sessions/link.

		// keepalive timer
		case <-keepalive:
			debug.Log(3, "TX (connWriter %p): sending keep-alive frame", c)
			_ = c.net.SetWriteDeadline(time.Now().Add(c.writeTimeout))
			if _, err = c.net.Write(keepaliveFrame); err != nil {
				err = &ConnError{inner: err}
			}
			// It would be slightly more efficient in terms of network
			// resources to reset the timer each time a frame is sent.
			// However, keepalives are small (8 bytes) and the interval
			// is usually on the order of minutes. It does not seem
			// worth it to add extra operations in the write path to
			// avoid. (To properly reset a timer it needs to be stopped,
			// possibly drained, then reset.)

		// connection complete
		case <-c.rxtxExit:
			// send close performative. note that the spec says we
			// SHOULD wait for the ack but we don't HAVE to, in order
			// to be resilient to bad actors etc. so we just send
			// the close performative and exit.
			fr := frames.Frame{
				Type: frames.TypeAMQP,
				Body: &frames.PerformClose{},
			}
			debug.Log(1, "TX (connWriter %p): %s", c, fr)
			c.txErr = c.writeFrame(c.writeTimeout, fr)
			return
		}
	}
}
// writeFrame writes a frame to the network.
// used externally by SASL only.
//   - timeout - the write deadline to set. zero means no deadline
//
// errors are wrapped in a ConnError as they can be returned to outside callers.
func (c *Conn) writeFrame(timeout time.Duration, fr frames.Frame) error {
	// marshal the frame into the transmit buffer
	c.txBuf.Reset()
	if err := frames.Write(&c.txBuf, fr); err != nil {
		return &ConnError{inner: err}
	}

	// refuse to send anything larger than the peer said it will accept
	frameSize := c.txBuf.Len()
	if uint64(frameSize) > uint64(c.peerMaxFrameSize) {
		return &ConnError{inner: fmt.Errorf("%T frame size %d larger than peer's max frame size %d", fr, frameSize, c.peerMaxFrameSize)}
	}

	// apply the requested write deadline; zero clears any existing one
	switch {
	case timeout == 0:
		_ = c.net.SetWriteDeadline(time.Time{})
	case timeout > 0:
		_ = c.net.SetWriteDeadline(time.Now().Add(timeout))
	}

	// write to network
	n, err := c.net.Write(c.txBuf.Bytes())
	if l := c.txBuf.Len(); err != nil && n > 0 && n < l {
		// partial write; log for diagnostics
		debug.Log(1, "TX (writeFrame %p): wrote %d bytes less than len %d: %v", c, n, l, err)
	}
	if err != nil {
		return &ConnError{inner: err}
	}
	return nil
}
// writeProtoHeader writes an AMQP protocol header to the network.
// Layout: "AMQP" + protocol ID + version 1.0.0.
func (c *Conn) writeProtoHeader(pID protoID) error {
	header := []byte{'A', 'M', 'Q', 'P', byte(pID), 1, 0, 0}
	_, err := c.net.Write(header)
	return err
}
// keepaliveFrame is an AMQP frame with no body, used for keepalives.
// Layout: 4-byte size (8, i.e. header only), DOFF 2, type 0 (AMQP), channel 0.
var keepaliveFrame = []byte{0x00, 0x00, 0x00, 0x08, 0x02, 0x00, 0x00, 0x00}
// sendFrame is used by sessions and links to send frames across the network.
// If the connection has already terminated the frame is silently dropped.
func (c *Conn) sendFrame(env frameEnvelope) {
	select {
	case <-c.done:
		// Conn has closed; drop the frame
	case c.txFrame <- env:
		debug.Log(2, "TX (Conn %p): mux frame to connWriter: %s", c, env.Frame)
	}
}
// stateFunc is a state in a state machine.
//
// The state is advanced by returning the next state.
// The state machine concludes when nil is returned.
type stateFunc func(context.Context) (stateFunc, error)
// negotiateProto determines which proto to negotiate next.
// used externally by SASL only.
func (c *Conn) negotiateProto(ctx context.Context) (stateFunc, error) {
	// protocols must be negotiated in this order: TLS, then SASL, then AMQP
	if c.tlsNegotiation && !c.tlsComplete {
		return c.exchangeProtoHeader(protoTLS)
	}
	if c.saslHandlers != nil && !c.saslComplete {
		return c.exchangeProtoHeader(protoSASL)
	}
	return c.exchangeProtoHeader(protoAMQP)
}
// protoID identifies the protocol layer in an AMQP protocol header.
type protoID uint8

// protocol IDs received in protoHeaders
const (
	protoAMQP protoID = 0x0 // plain AMQP
	protoTLS  protoID = 0x2 // TLS layered below AMQP
	protoSASL protoID = 0x3 // SASL security layer
)
// exchangeProtoHeader performs the round trip exchange of protocol
// headers, validation, and returns the protoID specific next state.
func (c *Conn) exchangeProtoHeader(pID protoID) (stateFunc, error) {
	// send our header first
	if err := c.writeProtoHeader(pID); err != nil {
		return nil, err
	}

	// then read and validate the peer's response
	p, err := c.readProtoHeader()
	if err != nil {
		return nil, err
	}
	if p.ProtoID != pID {
		return nil, fmt.Errorf("unexpected protocol header %#00x, expected %#00x", p.ProtoID, pID)
	}

	// hand off to the protocol-specific state
	switch pID {
	case protoAMQP:
		return c.openAMQP, nil
	case protoTLS:
		return c.startTLS, nil
	case protoSASL:
		return c.negotiateSASL, nil
	}
	return nil, fmt.Errorf("unknown protocol ID %#02x", p.ProtoID)
}
// readProtoHeader reads a protocol header packet from c.rxProto.
func (c *Conn) readProtoHeader() (protoHeader, error) {
	const protoHeaderSize = 8

	// only read from the network once our buffer has been exhausted.
	// TODO: this preserves existing behavior as some tests rely on this
	// implementation detail (it lets you replay a stream of bytes). we
	// might want to consider removing this and fixing the tests as the
	// protocol doesn't actually work this way.
	if c.rxBuf.Len() == 0 {
		for c.rxBuf.Len() < protoHeaderSize {
			if err := c.rxBuf.ReadFromOnce(c.net); err != nil {
				return protoHeader{}, err
			}
		}
	}

	buf, ok := c.rxBuf.Next(protoHeaderSize)
	if !ok {
		return protoHeader{}, errors.New("invalid protoHeader")
	}
	// bounds check hint to compiler; see golang.org/issue/14808
	_ = buf[protoHeaderSize-1]

	// the first four bytes must spell "AMQP"
	if !bytes.Equal(buf[:4], []byte{'A', 'M', 'Q', 'P'}) {
		return protoHeader{}, fmt.Errorf("unexpected protocol %q", buf[:4])
	}

	p := protoHeader{
		ProtoID:  protoID(buf[4]),
		Major:    buf[5],
		Minor:    buf[6],
		Revision: buf[7],
	}
	// only AMQP 1.0.0 is supported
	if p.Major != 1 || p.Minor != 0 || p.Revision != 0 {
		return protoHeader{}, fmt.Errorf("unexpected protocol version %d.%d.%d", p.Major, p.Minor, p.Revision)
	}
	return p, nil
}
// startTLS wraps the conn with TLS and returns to Client.negotiateProto
func (c *Conn) startTLS(ctx context.Context) (stateFunc, error) {
	c.initTLSConfig()

	// clear any read deadline before handshaking
	_ = c.net.SetReadDeadline(time.Time{})

	// layer TLS on top of the existing connection and handshake
	client := tls.Client(c.net, c.tlsConfig)
	if err := client.HandshakeContext(ctx); err != nil {
		return nil, err
	}

	// from here on all IO goes through the TLS connection
	c.net = client
	c.tlsComplete = true

	// go to next protocol
	return c.negotiateProto, nil
}
// openAMQP round trips the AMQP open performative
func (c *Conn) openAMQP(ctx context.Context) (stateFunc, error) {
	// build and send our open frame
	fr := frames.Frame{
		Type:    frames.TypeAMQP,
		Channel: 0,
		Body: &frames.PerformOpen{
			ContainerID:  c.containerID,
			Hostname:     c.hostname,
			MaxFrameSize: c.maxFrameSize,
			ChannelMax:   c.channelMax,
			IdleTimeout:  c.idleTimeout / 2, // per spec, advertise half our idle timeout
			Properties:   c.properties,
		},
	}
	debug.Log(1, "TX (openAMQP %p): %s", c, fr)

	timeout, err := c.getWriteTimeout(ctx)
	if err != nil {
		return nil, err
	}
	if err = c.writeFrame(timeout, fr); err != nil {
		return nil, err
	}

	// read the peer's open frame
	fr, err = c.readSingleFrame()
	if err != nil {
		return nil, err
	}
	debug.Log(1, "RX (openAMQP %p): %s", c, fr)

	o, ok := fr.Body.(*frames.PerformOpen)
	if !ok {
		return nil, fmt.Errorf("openAMQP: unexpected frame type %T", fr.Body)
	}

	// adopt the peer's advertised limits
	if o.MaxFrameSize > 0 {
		c.peerMaxFrameSize = o.MaxFrameSize
	}
	if o.IdleTimeout > 0 {
		// TODO: reject very small idle timeouts
		c.peerIdleTimeout = o.IdleTimeout
	}
	if o.ChannelMax < c.channelMax {
		c.channelMax = o.ChannelMax
	}

	// connection established, exit state machine
	return nil, nil
}
// negotiateSASL returns the SASL handler for the first matched
// mechanism specified by the server
func (c *Conn) negotiateSASL(context.Context) (stateFunc, error) {
	// the server opens with its list of supported mechanisms
	fr, err := c.readSingleFrame()
	if err != nil {
		return nil, err
	}
	debug.Log(1, "RX (negotiateSASL %p): %s", c, fr)

	sm, ok := fr.Body.(*frames.SASLMechanisms)
	if !ok {
		return nil, fmt.Errorf("negotiateSASL: unexpected frame type %T", fr.Body)
	}

	// pick the first server-listed mechanism we have a handler for
	for _, mech := range sm.Mechanisms {
		if handler, found := c.saslHandlers[mech]; found {
			return handler, nil
		}
	}

	// no match
	return nil, fmt.Errorf("no supported auth mechanism (%v)", sm.Mechanisms) // TODO: send "auth not supported" frame?
}
// saslOutcome processes the SASL outcome frame and return Client.negotiateProto
// on success.
//
// SASL handlers return this stateFunc when the mechanism specific negotiation
// has completed.
// used externally by SASL only.
func (c *Conn) saslOutcome(context.Context) (stateFunc, error) {
	// read outcome frame
	fr, err := c.readSingleFrame()
	if err != nil {
		return nil, err
	}
	debug.Log(1, "RX (saslOutcome %p): %s", c, fr)

	outcome, ok := fr.Body.(*frames.SASLOutcome)
	if !ok {
		return nil, fmt.Errorf("saslOutcome: unexpected frame type %T", fr.Body)
	}

	// anything other than OK means authentication failed
	if outcome.Code != encoding.CodeSASLOK {
		return nil, fmt.Errorf("SASL PLAIN auth failed with code %#00x: %s", outcome.Code, outcome.AdditionalData) // implement Stringer for so.Code
	}

	// return to c.negotiateProto
	c.saslComplete = true
	return c.negotiateProto, nil
}
// readSingleFrame is used during connection establishment to read a single frame.
//
// After setup, conn.connReader handles incoming frames.
func (c *Conn) readSingleFrame() (frames.Frame, error) {
	fr, err := c.readFrame()
	if err != nil {
		var zero frames.Frame
		return zero, err
	}
	return fr, nil
}
// getWriteTimeout returns the timeout as calculated from the context's deadline
// or the default write timeout if the context has no deadline.
// if the context has timed out or was cancelled, an error is returned.
func (c *Conn) getWriteTimeout(ctx context.Context) (time.Duration, error) {
	// bail immediately on an already-cancelled context
	if err := ctx.Err(); err != nil {
		return 0, err
	}

	deadline, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		// no deadline; fall back to the configured default
		return c.writeTimeout, nil
	}

	remaining := time.Until(deadline)
	if remaining <= 0 {
		return 0, context.DeadlineExceeded
	}
	return remaining, nil
}
// protoHeader is the parsed form of an 8-byte AMQP protocol header
// ("AMQP" + protocol ID + major.minor.revision version).
type protoHeader struct {
	ProtoID  protoID
	Major    uint8
	Minor    uint8
	Revision uint8
}