2017-04-01 23:00:36 +03:00
|
|
|
package amqp
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
2017-04-23 04:32:50 +03:00
|
|
|
"crypto/tls"
|
2017-05-01 08:02:53 +03:00
|
|
|
"io"
|
2017-04-30 02:38:15 +03:00
|
|
|
"math"
|
2017-04-01 23:00:36 +03:00
|
|
|
"net"
|
2017-04-27 06:35:29 +03:00
|
|
|
"sync"
|
2017-04-17 06:39:31 +03:00
|
|
|
"time"
|
2017-04-01 23:00:36 +03:00
|
|
|
)
|
|
|
|
|
2017-05-07 04:24:06 +03:00
|
|
|
// Default connection options, applied by newConn unless overridden
// by the corresponding ConnOption.
const (
	// DefaultIdleTimeout is the default for ConnIdleTimeout.
	DefaultIdleTimeout = time.Minute

	// DefaultMaxFrameSize is the default for ConnMaxFrameSize. It is also
	// the frame size assumed for the peer until its Open frame says otherwise.
	DefaultMaxFrameSize = 512

	// DefaultMaxSessions is the default for ConnMaxSessions.
	DefaultMaxSessions = 1 << 16
)
|
|
|
|
|
|
|
|
// Errors
|
|
|
|
var (
|
|
|
|
ErrTimeout = errorNew("timeout waiting for response")
|
2017-04-01 23:00:36 +03:00
|
|
|
)
|
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// ConnOption is an function for configuring an AMQP connection.
|
2017-05-04 05:30:30 +03:00
|
|
|
type ConnOption func(*conn) error
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-05-04 05:30:30 +03:00
|
|
|
// ConnServerHostname sets the hostname sent in the AMQP
|
2017-04-30 02:38:15 +03:00
|
|
|
// Open frame and TLS ServerName (if not otherwise set).
|
|
|
|
//
|
|
|
|
// This is useful when the AMQP connection will be established
|
|
|
|
// via a pre-established TLS connection as the server may not
|
2017-05-04 05:30:30 +03:00
|
|
|
// know which hostname the client is attempting to connect to.
|
|
|
|
func ConnServerHostname(hostname string) ConnOption {
|
|
|
|
return func(c *conn) error {
|
2017-04-16 21:37:51 +03:00
|
|
|
c.hostname = hostname
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// ConnTLS toggles TLS negotiation.
|
2017-05-04 05:30:30 +03:00
|
|
|
//
|
|
|
|
// Default: false.
|
2017-04-30 02:38:15 +03:00
|
|
|
func ConnTLS(enable bool) ConnOption {
|
2017-05-04 05:30:30 +03:00
|
|
|
return func(c *conn) error {
|
2017-04-23 19:42:48 +03:00
|
|
|
c.tlsNegotiation = enable
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// ConnTLSConfig sets the tls.Config to be used during
|
|
|
|
// TLS negotiation.
|
|
|
|
//
|
|
|
|
// This option is for advanced usage, in most scenarios
|
|
|
|
// providing a URL scheme of "amqps://" or ConnTLS(true)
|
|
|
|
// is sufficient.
|
|
|
|
func ConnTLSConfig(tc *tls.Config) ConnOption {
|
2017-05-04 05:30:30 +03:00
|
|
|
return func(c *conn) error {
|
2017-04-23 19:42:48 +03:00
|
|
|
c.tlsConfig = tc
|
|
|
|
c.tlsNegotiation = true
|
2017-04-23 04:32:50 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// ConnIdleTimeout specifies the maximum period between receiving
|
|
|
|
// frames from the peer.
|
|
|
|
//
|
|
|
|
// Resolution is milliseconds. A value of zero indicates no timeout.
|
|
|
|
// This setting is in addition to TCP keepalives.
|
2017-05-04 05:30:30 +03:00
|
|
|
//
|
|
|
|
// Default: 1 minute.
|
2017-04-30 02:38:15 +03:00
|
|
|
func ConnIdleTimeout(d time.Duration) ConnOption {
|
2017-05-04 05:30:30 +03:00
|
|
|
return func(c *conn) error {
|
2017-04-30 02:38:15 +03:00
|
|
|
if d < 0 {
|
|
|
|
return errorNew("idle timeout cannot be negative")
|
|
|
|
}
|
2017-04-27 06:35:29 +03:00
|
|
|
c.idleTimeout = d
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// ConnMaxFrameSize sets the maximum frame size that
|
2017-05-07 01:26:17 +03:00
|
|
|
// the connection will accept.
|
2017-04-30 02:38:15 +03:00
|
|
|
//
|
|
|
|
// Must be 512 or greater.
|
|
|
|
//
|
2017-05-04 05:30:30 +03:00
|
|
|
// Default: 512.
|
2017-04-30 02:38:15 +03:00
|
|
|
func ConnMaxFrameSize(n uint32) ConnOption {
|
2017-05-04 05:30:30 +03:00
|
|
|
return func(c *conn) error {
|
2017-04-30 02:38:15 +03:00
|
|
|
if n < 512 {
|
|
|
|
return errorNew("max frame size must be 512 or greater")
|
|
|
|
}
|
2017-04-27 06:35:29 +03:00
|
|
|
c.maxFrameSize = n
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-05-04 05:30:30 +03:00
|
|
|
// ConnConnectTimeout configures how long to wait for the
|
|
|
|
// server during connection establishment.
|
|
|
|
//
|
|
|
|
// Once the connection has been established, ConnIdleTimeout
|
|
|
|
// applies. If duration is zero, no timeout will be applied.
|
|
|
|
//
|
|
|
|
// Default: 0.
|
|
|
|
func ConnConnectTimeout(d time.Duration) ConnOption {
|
|
|
|
return func(c *conn) error { c.connectTimeout = d; return nil }
|
|
|
|
}
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2018-01-19 06:36:02 +03:00
|
|
|
// ConnMaxSessions sets the maximum number of channels.
|
|
|
|
//
|
|
|
|
// n must be in the range 1 to 65536.
|
|
|
|
//
|
|
|
|
// BUG: Currently this limits how many channels can ever
|
|
|
|
// be opened on this connection rather than how many
|
|
|
|
// channels can be open at the same time.
|
|
|
|
//
|
|
|
|
// Default: 65536.
|
|
|
|
func ConnMaxSessions(n int) ConnOption {
|
|
|
|
return func(c *conn) error {
|
|
|
|
if n < 1 {
|
|
|
|
return errorNew("max sessions cannot be less than 1")
|
|
|
|
}
|
|
|
|
if n > 65536 {
|
|
|
|
return errorNew("max sessions cannot be greater than 65536")
|
|
|
|
}
|
|
|
|
c.channelMax = uint16(n - 1)
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-05-04 05:30:30 +03:00
|
|
|
// conn is an AMQP connection.
|
|
|
|
type conn struct {
|
|
|
|
net net.Conn // underlying connection
|
|
|
|
connectTimeout time.Duration // time to wait for reads/writes during conn establishment
|
2017-04-23 19:42:48 +03:00
|
|
|
|
|
|
|
// TLS
|
2017-04-30 02:38:15 +03:00
|
|
|
tlsNegotiation bool // negotiate TLS
|
|
|
|
tlsComplete bool // TLS negotiation complete
|
2017-05-04 05:30:30 +03:00
|
|
|
tlsConfig *tls.Config // TLS config, default used if nil (ServerName set to Client.hostname)
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// SASL
|
2017-05-07 04:05:32 +03:00
|
|
|
saslHandlers map[symbol]stateFunc // map of supported handlers keyed by SASL mechanism, SASL not negotiated if nil
|
2017-04-30 02:38:15 +03:00
|
|
|
saslComplete bool // SASL negotiation complete
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// local settings
|
2017-05-07 05:10:33 +03:00
|
|
|
maxFrameSize uint32 // max frame size to accept
|
|
|
|
channelMax uint16 // maximum number of channels to allow
|
2017-04-30 02:38:15 +03:00
|
|
|
hostname string // hostname of remote server (set explicitly or parsed from URL)
|
|
|
|
idleTimeout time.Duration // maximum period between receiving frames
|
2017-04-27 06:35:29 +03:00
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// peer settings
|
|
|
|
peerIdleTimeout time.Duration // maximum period between sending frames
|
|
|
|
peerMaxFrameSize uint32 // maximum frame size peer will accept
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// conn state
|
2017-11-13 07:51:38 +03:00
|
|
|
errMu sync.Mutex // mux holds errMu from start until shutdown completes; operations are sequential before mux is started
|
|
|
|
err error // error to be returned to client
|
|
|
|
doneOnce sync.Once // only close done once
|
|
|
|
done chan struct{} // indicates the connection is done
|
|
|
|
closeOnce sync.Once
|
2017-04-01 23:00:36 +03:00
|
|
|
|
|
|
|
// mux
|
2018-01-19 06:36:02 +03:00
|
|
|
newSession chan newSessionResp // new Sessions are requested from mux by reading off this channel
|
|
|
|
delSession chan *Session // session completion is indicated to mux by sending the Session on this channel
|
|
|
|
connErr chan error // connReader/Writer notifications of an error
|
2017-05-04 09:12:18 +03:00
|
|
|
|
|
|
|
// connReader
|
2018-02-12 02:26:24 +03:00
|
|
|
rxProto chan protoHeader // protoHeaders received by connReader
|
|
|
|
rxFrame chan frame // AMQP frames received by connReader
|
|
|
|
rxDone chan struct{}
|
|
|
|
connReaderRun chan func() // functions to be run by conn reader (set deadline on conn to run)
|
2017-05-04 09:12:18 +03:00
|
|
|
|
|
|
|
// connWriter
|
2018-02-03 22:54:49 +03:00
|
|
|
txFrame chan frame // AMQP frames to be sent by connWriter
|
|
|
|
txBuf buffer // buffer for marshaling frames before transmitting
|
2017-11-13 07:51:38 +03:00
|
|
|
txDone chan struct{}
|
2017-05-04 09:12:18 +03:00
|
|
|
}
|
|
|
|
|
2018-01-19 06:36:02 +03:00
|
|
|
type newSessionResp struct {
|
|
|
|
session *Session
|
|
|
|
err error
|
|
|
|
}
|
|
|
|
|
2017-05-04 09:12:18 +03:00
|
|
|
func newConn(netConn net.Conn, opts ...ConnOption) (*conn, error) {
|
|
|
|
c := &conn{
|
|
|
|
net: netConn,
|
2017-05-07 02:57:27 +03:00
|
|
|
maxFrameSize: DefaultMaxFrameSize,
|
|
|
|
peerMaxFrameSize: DefaultMaxFrameSize,
|
2018-01-19 06:36:02 +03:00
|
|
|
channelMax: DefaultMaxSessions - 1, // -1 because channel-max starts at zero
|
2017-05-07 02:57:27 +03:00
|
|
|
idleTimeout: DefaultIdleTimeout,
|
2017-05-04 09:12:18 +03:00
|
|
|
done: make(chan struct{}),
|
|
|
|
connErr: make(chan error, 2), // buffered to ensure connReader/Writer won't leak
|
|
|
|
rxProto: make(chan protoHeader),
|
|
|
|
rxFrame: make(chan frame),
|
2017-11-13 07:51:38 +03:00
|
|
|
rxDone: make(chan struct{}),
|
2018-02-12 02:26:24 +03:00
|
|
|
connReaderRun: make(chan func(), 1), // buffered to allow queueing function before interrupt
|
2018-01-19 06:36:02 +03:00
|
|
|
newSession: make(chan newSessionResp),
|
2017-05-04 09:12:18 +03:00
|
|
|
delSession: make(chan *Session),
|
|
|
|
txFrame: make(chan frame),
|
2017-11-13 07:51:38 +03:00
|
|
|
txDone: make(chan struct{}),
|
2017-05-04 09:12:18 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
// apply options
|
|
|
|
for _, opt := range opts {
|
|
|
|
if err := opt(c); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
}
|
2018-01-28 19:13:49 +03:00
|
|
|
return c, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *conn) initTLSConfig() {
|
|
|
|
// create a new config if not already set
|
|
|
|
if c.tlsConfig == nil {
|
|
|
|
c.tlsConfig = new(tls.Config)
|
|
|
|
}
|
2017-05-04 09:12:18 +03:00
|
|
|
|
2018-01-28 19:13:49 +03:00
|
|
|
// TLS config must have ServerName or InsecureSkipVerify set
|
|
|
|
if c.tlsConfig.ServerName == "" && !c.tlsConfig.InsecureSkipVerify {
|
|
|
|
c.tlsConfig.ServerName = c.hostname
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *conn) start() error {
|
2017-05-04 09:56:55 +03:00
|
|
|
// start reader
|
2017-05-04 09:12:18 +03:00
|
|
|
go c.connReader()
|
|
|
|
|
|
|
|
// run connection establishment state machine
|
|
|
|
for state := c.negotiateProto; state != nil; {
|
|
|
|
state = state()
|
|
|
|
}
|
|
|
|
|
|
|
|
// check if err occurred
|
|
|
|
if c.err != nil {
|
2017-11-13 07:51:38 +03:00
|
|
|
close(c.txDone) // close here since connWriter hasn't been started yet
|
|
|
|
c.Close()
|
2018-01-28 19:13:49 +03:00
|
|
|
return c.err
|
2017-05-04 09:12:18 +03:00
|
|
|
}
|
|
|
|
|
2017-05-04 09:56:55 +03:00
|
|
|
// start multiplexor and writer
|
2017-05-04 09:12:18 +03:00
|
|
|
go c.mux()
|
2017-05-04 09:56:55 +03:00
|
|
|
go c.connWriter()
|
2017-05-04 09:12:18 +03:00
|
|
|
|
2018-01-28 19:13:49 +03:00
|
|
|
return nil
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
2017-11-13 07:51:38 +03:00
|
|
|
func (c *conn) Close() error {
|
|
|
|
c.closeOnce.Do(func() { c.close() })
|
|
|
|
return c.err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *conn) close() {
|
2017-05-01 08:02:53 +03:00
|
|
|
c.closeDone() // notify goroutines and blocked functions to exit
|
2017-04-27 06:35:29 +03:00
|
|
|
|
2017-05-07 05:10:33 +03:00
|
|
|
// Client.mux holds err lock until shutdown, block until
|
|
|
|
// shutdown completes, then return the error (if any)
|
2017-04-27 06:35:29 +03:00
|
|
|
c.errMu.Lock()
|
|
|
|
defer c.errMu.Unlock()
|
2017-11-13 07:51:38 +03:00
|
|
|
|
|
|
|
// wait for writing to stop, allows it to send the final close frame
|
|
|
|
<-c.txDone
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
err := c.net.Close()
|
|
|
|
if c.err == nil {
|
|
|
|
c.err = err
|
|
|
|
}
|
2017-11-13 07:51:38 +03:00
|
|
|
|
|
|
|
// check rxDone after closing net, otherwise may block
|
|
|
|
// for up to c.idleTimeout
|
|
|
|
<-c.rxDone
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
2017-05-04 05:30:30 +03:00
|
|
|
// closeDone closes Client.done if it has not already been closed
|
|
|
|
func (c *conn) closeDone() {
|
|
|
|
c.doneOnce.Do(func() { close(c.done) })
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
2017-11-13 07:51:38 +03:00
|
|
|
// getErr returns conn.err.
|
|
|
|
//
|
|
|
|
// Must only be called after conn.done is closed.
|
|
|
|
func (c *conn) getErr() error {
|
|
|
|
c.errMu.Lock()
|
|
|
|
defer c.errMu.Unlock()
|
|
|
|
return c.err
|
|
|
|
}
|
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// mux is start in it's own goroutine after initial connection establishment.
|
|
|
|
// It handles muxing of sessions, keepalives, and connection errors.
|
2017-05-04 05:30:30 +03:00
|
|
|
func (c *conn) mux() {
|
2018-01-19 06:36:02 +03:00
|
|
|
var (
|
|
|
|
// create the next session to allocate
|
|
|
|
nextSession = newSessionResp{session: newSession(c, 0)}
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2018-01-19 06:36:02 +03:00
|
|
|
// map channels to sessions
|
|
|
|
sessionsByChannel = make(map[uint16]*Session)
|
|
|
|
sessionsByRemoteChannel = make(map[uint16]*Session)
|
|
|
|
)
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-05-07 05:10:33 +03:00
|
|
|
// hold the errMu lock until error or done
|
2017-04-27 06:35:29 +03:00
|
|
|
c.errMu.Lock()
|
|
|
|
defer c.errMu.Unlock()
|
|
|
|
|
2017-04-01 23:00:36 +03:00
|
|
|
for {
|
2017-05-01 08:02:53 +03:00
|
|
|
// check if last loop returned an error
|
2017-04-01 23:00:36 +03:00
|
|
|
if c.err != nil {
|
2017-04-27 06:35:29 +03:00
|
|
|
c.closeDone()
|
2017-04-24 06:24:12 +03:00
|
|
|
return
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
select {
|
2017-05-01 08:02:53 +03:00
|
|
|
// error from connReader
|
2017-05-04 09:12:18 +03:00
|
|
|
case c.err = <-c.connErr:
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// new frame from connReader
|
2017-04-23 04:31:07 +03:00
|
|
|
case fr := <-c.rxFrame:
|
2018-02-08 08:26:49 +03:00
|
|
|
var (
|
|
|
|
session *Session
|
|
|
|
ok bool
|
|
|
|
)
|
|
|
|
|
|
|
|
switch body := fr.body.(type) {
|
|
|
|
// RemoteChannel should be used when frame is Begin
|
|
|
|
case *performBegin:
|
|
|
|
session, ok = sessionsByChannel[body.RemoteChannel]
|
2018-01-19 06:36:02 +03:00
|
|
|
if !ok {
|
2018-02-08 08:26:49 +03:00
|
|
|
break
|
2018-01-19 06:36:02 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
session.remoteChannel = fr.channel
|
|
|
|
sessionsByRemoteChannel[fr.channel] = session
|
2018-02-08 08:26:49 +03:00
|
|
|
|
|
|
|
default:
|
|
|
|
session, ok = sessionsByRemoteChannel[fr.channel]
|
|
|
|
}
|
|
|
|
|
|
|
|
if !ok {
|
|
|
|
c.err = errorErrorf("unexpected frame: %#v", fr.body)
|
|
|
|
continue
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
2017-12-31 08:22:02 +03:00
|
|
|
|
|
|
|
select {
|
2018-01-19 06:36:02 +03:00
|
|
|
case session.rx <- fr:
|
2017-12-31 08:22:02 +03:00
|
|
|
case <-c.done:
|
|
|
|
return
|
|
|
|
}
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// new session request
|
|
|
|
//
|
|
|
|
// Continually try to send the next session on the channel,
|
|
|
|
// then add it to the sessions map. This allows us to control ID
|
|
|
|
// allocation and prevents the need to have shared map. Since new
|
|
|
|
// sessions are far less frequent than frames being sent to sessions,
|
2017-05-07 05:10:33 +03:00
|
|
|
// this avoids the lock/unlock for session lookup.
|
2017-05-01 08:02:53 +03:00
|
|
|
case c.newSession <- nextSession:
|
2018-01-19 06:36:02 +03:00
|
|
|
if nextSession.err != nil {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
ch := nextSession.session.channel
|
|
|
|
sessionsByChannel[ch] = nextSession.session
|
|
|
|
|
|
|
|
if ch >= c.channelMax {
|
|
|
|
nextSession.session = nil
|
|
|
|
nextSession.err = errorErrorf("reached connection channel max (%d)", c.channelMax)
|
|
|
|
continue
|
|
|
|
}
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// create the next session to send
|
2018-01-19 06:36:02 +03:00
|
|
|
nextSession.session = newSession(c, ch+1)
|
2017-05-01 08:02:53 +03:00
|
|
|
|
|
|
|
// session deletion
|
2017-04-01 23:00:36 +03:00
|
|
|
case s := <-c.delSession:
|
2018-01-19 06:36:02 +03:00
|
|
|
// TODO: allow channel number reuse
|
|
|
|
delete(sessionsByChannel, s.channel)
|
|
|
|
delete(sessionsByRemoteChannel, s.remoteChannel)
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// connection is complete
|
2017-04-27 06:35:29 +03:00
|
|
|
case <-c.done:
|
|
|
|
return
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// connReader reads from the net.Conn, decodes frames, and passes them
|
2017-05-04 05:30:30 +03:00
|
|
|
// up via the conn.rxFrame and conn.rxProto channels.
|
|
|
|
func (c *conn) connReader() {
|
2017-11-13 07:51:38 +03:00
|
|
|
defer close(c.rxDone)
|
|
|
|
|
2018-02-03 22:54:49 +03:00
|
|
|
buf := new(buffer)
|
2017-04-24 06:24:12 +03:00
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
var (
|
2017-05-07 05:10:33 +03:00
|
|
|
negotiating = true // true during conn establishment, check for protoHeaders
|
2017-05-01 08:02:53 +03:00
|
|
|
currentHeader frameHeader // keep track of the current header, for frames split across multiple TCP packets
|
2017-05-07 05:10:33 +03:00
|
|
|
frameInProgress bool // true if in the middle of receiving data for currentHeader
|
2017-04-30 02:38:15 +03:00
|
|
|
)
|
2017-04-23 04:31:07 +03:00
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
for {
|
2018-02-03 22:54:49 +03:00
|
|
|
if buf.len() == 0 {
|
|
|
|
buf.reset()
|
|
|
|
}
|
|
|
|
|
2017-05-07 05:10:33 +03:00
|
|
|
// need to read more if buf doesn't contain the complete frame
|
2017-05-01 08:02:53 +03:00
|
|
|
// or there's not enough in buf to parse the header
|
2018-02-03 22:54:49 +03:00
|
|
|
if frameInProgress || buf.len() < frameHeaderSize {
|
2017-04-30 18:56:16 +03:00
|
|
|
c.net.SetReadDeadline(time.Now().Add(c.idleTimeout))
|
2018-02-03 22:54:49 +03:00
|
|
|
err := buf.readFromOnce(c.net)
|
2017-04-27 06:35:29 +03:00
|
|
|
if err != nil {
|
2017-11-13 07:51:38 +03:00
|
|
|
select {
|
2018-02-12 02:26:24 +03:00
|
|
|
// check if error was due to close in progress
|
2017-11-13 07:51:38 +03:00
|
|
|
case <-c.done:
|
|
|
|
return
|
2018-02-12 02:26:24 +03:00
|
|
|
|
|
|
|
// if there is a pending connReaderRun function, execute it
|
|
|
|
case f := <-c.connReaderRun:
|
|
|
|
f()
|
|
|
|
continue
|
|
|
|
|
|
|
|
// send error to mux and return
|
2017-11-13 07:51:38 +03:00
|
|
|
default:
|
2018-02-12 02:26:24 +03:00
|
|
|
c.connErr <- err
|
|
|
|
return
|
2017-11-13 07:51:38 +03:00
|
|
|
}
|
2017-04-24 06:24:12 +03:00
|
|
|
}
|
2017-04-27 06:35:29 +03:00
|
|
|
}
|
2017-04-23 04:31:07 +03:00
|
|
|
|
2017-05-07 05:10:33 +03:00
|
|
|
// read more if buf doesn't contain enough to parse the header
|
2018-02-03 22:54:49 +03:00
|
|
|
if buf.len() < frameHeaderSize {
|
2017-04-27 06:35:29 +03:00
|
|
|
continue
|
|
|
|
}
|
2017-04-23 04:31:07 +03:00
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// during negotiation, check for proto frames
|
2018-02-03 22:54:49 +03:00
|
|
|
if negotiating && bytes.Equal(buf.bytes()[:4], []byte{'A', 'M', 'Q', 'P'}) {
|
2017-04-30 02:38:15 +03:00
|
|
|
p, err := parseProtoHeader(buf)
|
2017-04-23 04:31:07 +03:00
|
|
|
if err != nil {
|
2017-05-04 09:12:18 +03:00
|
|
|
c.connErr <- err
|
2017-04-24 06:24:12 +03:00
|
|
|
return
|
2017-04-23 04:31:07 +03:00
|
|
|
}
|
|
|
|
|
2017-05-07 05:10:33 +03:00
|
|
|
// negotiation is complete once an AMQP proto frame is received
|
2017-04-27 06:35:29 +03:00
|
|
|
if p.ProtoID == protoAMQP {
|
|
|
|
negotiating = false
|
2017-04-24 06:24:12 +03:00
|
|
|
}
|
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// send proto header
|
2017-04-24 06:24:12 +03:00
|
|
|
select {
|
|
|
|
case <-c.done:
|
|
|
|
return
|
2017-04-27 06:35:29 +03:00
|
|
|
case c.rxProto <- p:
|
|
|
|
}
|
2017-04-30 02:38:15 +03:00
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2017-05-07 05:10:33 +03:00
|
|
|
// parse the header if a frame isn't in progress
|
2017-04-27 06:35:29 +03:00
|
|
|
if !frameInProgress {
|
2017-05-01 08:02:53 +03:00
|
|
|
var err error
|
2017-04-27 06:35:29 +03:00
|
|
|
currentHeader, err = parseFrameHeader(buf)
|
|
|
|
if err != nil {
|
2017-05-04 09:12:18 +03:00
|
|
|
c.connErr <- err
|
2017-04-27 06:35:29 +03:00
|
|
|
return
|
|
|
|
}
|
|
|
|
frameInProgress = true
|
|
|
|
}
|
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// check size is reasonable
|
2017-04-30 02:38:15 +03:00
|
|
|
if currentHeader.Size > math.MaxInt32 { // make max size configurable
|
2017-05-04 09:12:18 +03:00
|
|
|
c.connErr <- errorNew("payload too large")
|
2017-04-30 02:38:15 +03:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
bodySize := int(currentHeader.Size - frameHeaderSize)
|
2017-04-30 02:38:15 +03:00
|
|
|
|
2017-05-07 05:10:33 +03:00
|
|
|
// the full frame has been received
|
2018-02-03 22:54:49 +03:00
|
|
|
if buf.len() < bodySize {
|
2017-04-27 06:35:29 +03:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
frameInProgress = false
|
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// check if body is empty (keepalive)
|
2017-04-30 02:38:15 +03:00
|
|
|
if bodySize == 0 {
|
2017-05-01 08:02:53 +03:00
|
|
|
continue
|
2017-04-30 02:38:15 +03:00
|
|
|
}
|
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// parse the frame
|
2018-02-03 22:54:49 +03:00
|
|
|
b, ok := buf.next(bodySize)
|
|
|
|
if !ok {
|
|
|
|
c.connErr <- io.EOF
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
parsedBody, err := parseFrameBody(&buffer{b: b})
|
2017-04-27 06:35:29 +03:00
|
|
|
if err != nil {
|
2017-05-04 09:12:18 +03:00
|
|
|
c.connErr <- err
|
2017-04-27 06:35:29 +03:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// send to mux
|
2017-04-27 06:35:29 +03:00
|
|
|
select {
|
|
|
|
case <-c.done:
|
|
|
|
return
|
2017-04-30 18:56:16 +03:00
|
|
|
case c.rxFrame <- frame{channel: currentHeader.Channel, body: parsedBody}:
|
2017-04-27 06:35:29 +03:00
|
|
|
}
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-05-04 09:12:18 +03:00
|
|
|
func (c *conn) connWriter() {
|
2017-11-13 07:51:38 +03:00
|
|
|
defer close(c.txDone)
|
|
|
|
|
2017-05-07 01:26:17 +03:00
|
|
|
var (
|
|
|
|
// keepalives are sent at a rate of 1/2 idle timeout
|
|
|
|
keepaliveInterval = c.peerIdleTimeout / 2
|
|
|
|
// 0 disables keepalives
|
|
|
|
keepalivesEnabled = keepaliveInterval > 0
|
|
|
|
// set if enable, nil if not; nil channels block forever
|
|
|
|
keepalive <-chan time.Time
|
|
|
|
)
|
2017-05-04 09:12:18 +03:00
|
|
|
|
2017-05-07 01:26:17 +03:00
|
|
|
if keepalivesEnabled {
|
|
|
|
ticker := time.NewTicker(keepaliveInterval)
|
2017-05-04 09:12:18 +03:00
|
|
|
defer ticker.Stop()
|
|
|
|
keepalive = ticker.C
|
|
|
|
}
|
|
|
|
|
|
|
|
var err error
|
|
|
|
for {
|
|
|
|
if err != nil {
|
|
|
|
c.connErr <- err
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
select {
|
|
|
|
// frame write request
|
|
|
|
case fr := <-c.txFrame:
|
2017-05-04 09:56:55 +03:00
|
|
|
err = c.writeFrame(fr)
|
2017-12-31 08:22:02 +03:00
|
|
|
if err == nil && fr.done != nil {
|
|
|
|
close(fr.done)
|
|
|
|
}
|
2017-05-04 09:12:18 +03:00
|
|
|
|
|
|
|
// keepalive timer
|
2017-05-07 01:26:17 +03:00
|
|
|
case <-keepalive:
|
2017-05-04 09:12:18 +03:00
|
|
|
_, err = c.net.Write(keepaliveFrame)
|
2017-05-07 01:26:17 +03:00
|
|
|
// It would be slightly more efficient in terms of network
|
|
|
|
// resources to reset the timer each time a frame is sent.
|
|
|
|
// However, keepalives are small (8 bytes) and the interval
|
|
|
|
// is usually on the order of minutes. It does not seem
|
|
|
|
// worth it to add extra operations in the write path to
|
|
|
|
// avoid. (To properly reset a timer it needs to be stopped,
|
|
|
|
// possibly drained, then reset.)
|
|
|
|
|
|
|
|
// connection complete
|
2017-05-04 09:12:18 +03:00
|
|
|
case <-c.done:
|
2017-11-13 07:51:38 +03:00
|
|
|
// send close
|
|
|
|
c.writeFrame(frame{
|
2018-02-03 22:54:49 +03:00
|
|
|
type_: frameTypeAMQP,
|
|
|
|
body: &performClose{},
|
2017-11-13 07:51:38 +03:00
|
|
|
})
|
2017-05-04 09:56:55 +03:00
|
|
|
return
|
2017-05-04 09:12:18 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-05-04 09:56:55 +03:00
|
|
|
// writeFrame writes a frame to the network, may only be used
|
|
|
|
// by connWriter after initial negotiation.
|
|
|
|
func (c *conn) writeFrame(fr frame) error {
|
|
|
|
if c.connectTimeout != 0 {
|
|
|
|
c.net.SetWriteDeadline(time.Now().Add(c.connectTimeout))
|
|
|
|
}
|
2017-05-07 02:57:27 +03:00
|
|
|
|
|
|
|
// writeFrame into txBuf
|
2018-02-03 22:54:49 +03:00
|
|
|
c.txBuf.reset()
|
2017-05-04 09:56:55 +03:00
|
|
|
err := writeFrame(&c.txBuf, fr)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2017-05-07 05:10:33 +03:00
|
|
|
// validate the frame isn't exceeding peer's max frame size
|
2018-02-03 22:54:49 +03:00
|
|
|
if uint64(c.txBuf.len()) > uint64(c.peerMaxFrameSize) {
|
2017-11-06 05:46:56 +03:00
|
|
|
return errorErrorf("frame larger than peer's max frame size")
|
2017-05-04 09:56:55 +03:00
|
|
|
}
|
|
|
|
|
2017-05-07 02:57:27 +03:00
|
|
|
// write to network
|
2018-02-03 22:54:49 +03:00
|
|
|
_, err = c.net.Write(c.txBuf.bytes())
|
2017-05-04 09:56:55 +03:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2017-05-07 02:57:27 +03:00
|
|
|
// writeProtoHeader writes an AMQP protocol header to the
|
|
|
|
// network
|
2017-05-04 09:56:55 +03:00
|
|
|
func (c *conn) writeProtoHeader(pID protoID) error {
|
|
|
|
if c.connectTimeout != 0 {
|
|
|
|
c.net.SetWriteDeadline(time.Now().Add(c.connectTimeout))
|
|
|
|
}
|
|
|
|
_, err := c.net.Write([]byte{'A', 'M', 'Q', 'P', byte(pID), 1, 0, 0})
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2017-05-04 09:12:18 +03:00
|
|
|
// keepaliveFrame is an AMQP frame with no body (header only), written by
// connWriter as a keepalive.
var keepaliveFrame = []byte{
	0x00, 0x00, 0x00, 0x08, // size: 8 bytes, i.e. just the header
	0x02,       // data offset, in 4-byte words
	0x00,       // frame type: AMQP
	0x00, 0x00, // channel 0
}
|
|
|
|
|
2017-05-07 02:57:27 +03:00
|
|
|
// wantWriteFrame is used by sessions and links to send frame to
|
|
|
|
// connWriter.
|
2017-05-04 09:12:18 +03:00
|
|
|
func (c *conn) wantWriteFrame(fr frame) {
|
|
|
|
select {
|
|
|
|
case c.txFrame <- fr:
|
|
|
|
case <-c.done:
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-05-04 05:30:30 +03:00
|
|
|
// stateFunc is one state of a state machine.
//
// Calling a state runs it and yields the state to run next.
// A nil result ends the machine.
type stateFunc func() stateFunc
|
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// negotiateProto determines which proto to negotiate next
|
2017-05-04 05:30:30 +03:00
|
|
|
func (c *conn) negotiateProto() stateFunc {
|
2017-05-01 08:02:53 +03:00
|
|
|
// in the order each must be negotiated
|
2017-04-01 23:00:36 +03:00
|
|
|
switch {
|
2017-04-23 19:42:48 +03:00
|
|
|
case c.tlsNegotiation && !c.tlsComplete:
|
2017-04-23 21:01:44 +03:00
|
|
|
return c.exchangeProtoHeader(protoTLS)
|
2017-04-01 23:00:36 +03:00
|
|
|
case c.saslHandlers != nil && !c.saslComplete:
|
2017-04-23 21:01:44 +03:00
|
|
|
return c.exchangeProtoHeader(protoSASL)
|
2017-04-01 23:00:36 +03:00
|
|
|
default:
|
2017-04-23 21:01:44 +03:00
|
|
|
return c.exchangeProtoHeader(protoAMQP)
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-05-04 09:12:18 +03:00
|
|
|
// protoID is the protocol ID byte carried in an AMQP protocol header.
type protoID uint8

// protocol IDs received in protoHeaders
const (
	protoAMQP protoID = 0x0 // AMQP
	protoTLS  protoID = 0x2 // TLS
	protoSASL protoID = 0x3 // SASL
)
|
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// exchangeProtoHeader performs the round trip exchange of protocol
|
|
|
|
// headers, validation, and returns the protoID specific next state.
|
2017-05-04 09:12:18 +03:00
|
|
|
func (c *conn) exchangeProtoHeader(pID protoID) stateFunc {
|
2017-05-01 08:02:53 +03:00
|
|
|
// write the proto header
|
2017-05-04 09:56:55 +03:00
|
|
|
c.err = c.writeProtoHeader(pID)
|
|
|
|
if c.err != nil {
|
|
|
|
return nil
|
|
|
|
}
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// read response header
|
2017-05-07 01:26:17 +03:00
|
|
|
p, err := c.readProtoHeader()
|
|
|
|
if err != nil {
|
|
|
|
c.err = err
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-05-04 09:12:18 +03:00
|
|
|
if pID != p.ProtoID {
|
|
|
|
c.err = errorErrorf("unexpected protocol header %#00x, expected %#00x", p.ProtoID, pID)
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// go to the proto specific state
|
2017-05-04 09:12:18 +03:00
|
|
|
switch pID {
|
2017-04-23 21:01:44 +03:00
|
|
|
case protoAMQP:
|
2017-04-30 02:38:15 +03:00
|
|
|
return c.openAMQP
|
2017-04-23 21:01:44 +03:00
|
|
|
case protoTLS:
|
2017-04-30 02:38:15 +03:00
|
|
|
return c.startTLS
|
2017-04-23 21:01:44 +03:00
|
|
|
case protoSASL:
|
2017-04-30 02:38:15 +03:00
|
|
|
return c.negotiateSASL
|
2017-04-01 23:00:36 +03:00
|
|
|
default:
|
2017-04-30 02:38:15 +03:00
|
|
|
c.err = errorErrorf("unknown protocol ID %#02x", p.ProtoID)
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-05-07 02:57:27 +03:00
|
|
|
// readProtoHeader reads a protocol header packet from c.rxProto.
|
2017-05-07 01:26:17 +03:00
|
|
|
func (c *conn) readProtoHeader() (protoHeader, error) {
|
|
|
|
var deadline <-chan time.Time
|
|
|
|
if c.connectTimeout != 0 {
|
|
|
|
deadline = time.After(c.connectTimeout)
|
|
|
|
}
|
|
|
|
var p protoHeader
|
|
|
|
select {
|
|
|
|
case p = <-c.rxProto:
|
|
|
|
return p, nil
|
|
|
|
case err := <-c.connErr:
|
|
|
|
return p, err
|
|
|
|
case fr := <-c.rxFrame:
|
|
|
|
return p, errorErrorf("unexpected frame %#v", fr)
|
|
|
|
case <-deadline:
|
|
|
|
return p, ErrTimeout
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-05-04 05:30:30 +03:00
|
|
|
// startTLS wraps the conn with TLS and returns to Client.negotiateProto
|
|
|
|
func (c *conn) startTLS() stateFunc {
|
2018-01-28 19:13:49 +03:00
|
|
|
c.initTLSConfig()
|
2017-04-30 05:33:03 +03:00
|
|
|
|
2018-02-12 02:26:24 +03:00
|
|
|
done := make(chan struct{})
|
2017-04-30 05:33:03 +03:00
|
|
|
|
2018-02-12 02:26:24 +03:00
|
|
|
// this function will be executed by connReader
|
|
|
|
c.connReaderRun <- func() {
|
|
|
|
c.net.SetReadDeadline(time.Time{}) // clear timeout
|
|
|
|
|
|
|
|
// wrap existing net.Conn and perform TLS handshake
|
|
|
|
tlsConn := tls.Client(c.net, c.tlsConfig)
|
|
|
|
if c.connectTimeout != 0 {
|
|
|
|
tlsConn.SetWriteDeadline(time.Now().Add(c.connectTimeout))
|
|
|
|
}
|
|
|
|
c.err = tlsConn.Handshake()
|
|
|
|
|
|
|
|
// swap net.Conn
|
|
|
|
c.net = tlsConn
|
|
|
|
c.tlsComplete = true
|
|
|
|
|
|
|
|
close(done)
|
2017-05-04 05:30:30 +03:00
|
|
|
}
|
2018-02-12 02:26:24 +03:00
|
|
|
|
|
|
|
// set deadline to interrupt connReader
|
|
|
|
c.net.SetReadDeadline(time.Time{}.Add(1))
|
|
|
|
|
|
|
|
<-done
|
|
|
|
|
2017-04-30 05:33:03 +03:00
|
|
|
if c.err != nil {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// go to next protocol
|
2017-04-23 04:32:50 +03:00
|
|
|
return c.negotiateProto
|
|
|
|
}
|
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// openAMQP round trips the AMQP open performative
|
2017-05-04 05:30:30 +03:00
|
|
|
func (c *conn) openAMQP() stateFunc {
|
2017-05-01 08:02:53 +03:00
|
|
|
// send open frame
|
2017-05-04 09:56:55 +03:00
|
|
|
c.err = c.writeFrame(frame{
|
2018-02-03 22:54:49 +03:00
|
|
|
type_: frameTypeAMQP,
|
2017-04-30 02:38:15 +03:00
|
|
|
body: &performOpen{
|
2017-11-13 07:51:38 +03:00
|
|
|
ContainerID: string(randBytes(40)),
|
2017-04-23 04:31:07 +03:00
|
|
|
Hostname: c.hostname,
|
|
|
|
MaxFrameSize: c.maxFrameSize,
|
|
|
|
ChannelMax: c.channelMax,
|
2017-04-30 02:38:15 +03:00
|
|
|
IdleTimeout: c.idleTimeout,
|
2017-04-23 04:31:07 +03:00
|
|
|
},
|
|
|
|
channel: 0,
|
2017-04-22 22:56:08 +03:00
|
|
|
})
|
2017-05-04 09:56:55 +03:00
|
|
|
if c.err != nil {
|
|
|
|
return nil
|
|
|
|
}
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// get the response
|
2017-04-24 06:24:12 +03:00
|
|
|
fr, err := c.readFrame()
|
2017-04-01 23:00:36 +03:00
|
|
|
if err != nil {
|
2017-04-24 06:24:12 +03:00
|
|
|
c.err = err
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
2017-04-30 02:38:15 +03:00
|
|
|
o, ok := fr.body.(*performOpen)
|
2017-04-24 06:24:12 +03:00
|
|
|
if !ok {
|
2017-04-30 02:38:15 +03:00
|
|
|
c.err = errorErrorf("unexpected frame type %T", fr.body)
|
2017-04-27 06:35:29 +03:00
|
|
|
return nil
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// update peer settings
|
2017-04-27 06:35:29 +03:00
|
|
|
if o.MaxFrameSize > 0 {
|
2017-04-30 05:33:03 +03:00
|
|
|
c.peerMaxFrameSize = o.MaxFrameSize
|
2017-04-27 06:35:29 +03:00
|
|
|
}
|
2017-04-23 21:01:44 +03:00
|
|
|
if o.IdleTimeout > 0 {
|
2017-05-07 01:26:17 +03:00
|
|
|
// TODO: reject very small idle timeouts
|
2017-04-30 02:38:15 +03:00
|
|
|
c.peerIdleTimeout = o.IdleTimeout
|
2017-04-17 06:39:31 +03:00
|
|
|
}
|
2017-04-01 23:00:36 +03:00
|
|
|
if o.ChannelMax < c.channelMax {
|
|
|
|
c.channelMax = o.ChannelMax
|
|
|
|
}
|
|
|
|
|
2017-05-01 08:02:53 +03:00
|
|
|
// connection established, exit state machine
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// negotiateSASL returns the SASL handler for the first matched
|
|
|
|
// mechanism specified by the server
|
2017-05-04 05:30:30 +03:00
|
|
|
func (c *conn) negotiateSASL() stateFunc {
|
2017-05-03 04:49:31 +03:00
|
|
|
// read mechanisms frame
|
2017-04-24 06:24:12 +03:00
|
|
|
fr, err := c.readFrame()
|
2017-04-01 23:00:36 +03:00
|
|
|
if err != nil {
|
|
|
|
c.err = err
|
|
|
|
return nil
|
|
|
|
}
|
2017-04-30 02:38:15 +03:00
|
|
|
sm, ok := fr.body.(*saslMechanisms)
|
2017-04-24 06:24:12 +03:00
|
|
|
if !ok {
|
2017-04-30 02:38:15 +03:00
|
|
|
c.err = errorErrorf("unexpected frame type %T", fr.body)
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-05-03 04:49:31 +03:00
|
|
|
// return first match in c.saslHandlers based on order received
|
2017-04-01 23:00:36 +03:00
|
|
|
for _, mech := range sm.Mechanisms {
|
|
|
|
if state, ok := c.saslHandlers[mech]; ok {
|
|
|
|
return state
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-05-03 04:49:31 +03:00
|
|
|
// no match
|
2017-05-07 01:26:17 +03:00
|
|
|
c.err = errorErrorf("no supported auth mechanism (%v)", sm.Mechanisms) // TODO: send "auth not supported" frame?
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-05-04 05:30:30 +03:00
|
|
|
// saslOutcome processes the SASL outcome frame and return Client.negotiateProto
|
2017-04-30 02:38:15 +03:00
|
|
|
// on success.
|
|
|
|
//
|
|
|
|
// SASL handlers return this stateFunc when the mechanism specific negotiation
|
|
|
|
// has completed.
|
2017-05-04 05:30:30 +03:00
|
|
|
func (c *conn) saslOutcome() stateFunc {
|
2017-05-03 04:49:31 +03:00
|
|
|
// read outcome frame
|
2017-04-24 06:24:12 +03:00
|
|
|
fr, err := c.readFrame()
|
2017-04-01 23:00:36 +03:00
|
|
|
if err != nil {
|
|
|
|
c.err = err
|
|
|
|
return nil
|
|
|
|
}
|
2017-04-30 02:38:15 +03:00
|
|
|
so, ok := fr.body.(*saslOutcome)
|
2017-04-24 06:24:12 +03:00
|
|
|
if !ok {
|
2017-04-30 02:38:15 +03:00
|
|
|
c.err = errorErrorf("unexpected frame type %T", fr.body)
|
2017-04-27 06:35:29 +03:00
|
|
|
return nil
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
2017-05-03 04:49:31 +03:00
|
|
|
// check if auth succeeded
|
2017-04-23 21:01:44 +03:00
|
|
|
if so.Code != codeSASLOK {
|
2017-05-03 04:49:31 +03:00
|
|
|
c.err = errorErrorf("SASL PLAIN auth failed with code %#00x: %s", so.Code, so.AdditionalData) // implement Stringer for so.Code
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-05-03 04:49:31 +03:00
|
|
|
// return to c.negotiateProto
|
2017-04-01 23:00:36 +03:00
|
|
|
c.saslComplete = true
|
|
|
|
return c.negotiateProto
|
|
|
|
}
|
2017-04-24 06:24:12 +03:00
|
|
|
|
2017-04-30 02:38:15 +03:00
|
|
|
// readFrame is used during connection establishment to read a single frame.
|
|
|
|
//
|
2017-05-07 02:57:27 +03:00
|
|
|
// After setup, conn.mux handles incoming frames.
|
2017-05-04 05:30:30 +03:00
|
|
|
func (c *conn) readFrame() (frame, error) {
|
|
|
|
var deadline <-chan time.Time
|
|
|
|
if c.connectTimeout != 0 {
|
|
|
|
deadline = time.After(c.connectTimeout)
|
|
|
|
}
|
2017-05-04 09:12:18 +03:00
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
var fr frame
|
|
|
|
select {
|
|
|
|
case fr = <-c.rxFrame:
|
|
|
|
return fr, nil
|
2017-05-04 09:12:18 +03:00
|
|
|
case err := <-c.connErr:
|
2017-04-24 06:24:12 +03:00
|
|
|
return fr, err
|
2017-04-24 07:03:50 +03:00
|
|
|
case p := <-c.rxProto:
|
2017-04-30 02:38:15 +03:00
|
|
|
return fr, errorErrorf("unexpected protocol header %#v", p)
|
2017-05-04 05:30:30 +03:00
|
|
|
case <-deadline:
|
2018-02-12 01:55:49 +03:00
|
|
|
return fr, ErrTimeout
|
2017-04-24 06:24:12 +03:00
|
|
|
}
|
|
|
|
}
|