go-amqp/conn.go

728 строки
18 KiB
Go
Исходник Обычный вид История

2017-04-01 23:00:36 +03:00
package amqp
import (
"bytes"
2017-04-23 04:32:50 +03:00
"crypto/tls"
"io"
2017-04-30 02:38:15 +03:00
"math"
2017-04-01 23:00:36 +03:00
"net"
"net/url"
2017-04-27 06:35:29 +03:00
"sync"
"sync/atomic"
2017-04-17 06:39:31 +03:00
"time"
2017-04-01 23:00:36 +03:00
)
// Connection defaults, applied by New before any ConnOption runs.
const (
	// defaultMaxFrameSize is both the largest frame we accept and the
	// size assumed for the peer until its Open frame says otherwise.
	defaultMaxFrameSize = 512

	// defaultChannelMax is the channel-max advertised in our Open frame.
	defaultChannelMax = 1

	// defaultIdleTimeout is the idle timeout advertised to the peer and
	// applied to reads from the network.
	defaultIdleTimeout = time.Minute
)
// Errors
var (
ErrTimeout = errorNew("timeout waiting for response")
2017-04-01 23:00:36 +03:00
)
2017-04-30 02:38:15 +03:00
// ConnOption is a function for configuring an AMQP connection.
//
// Options are applied in order by New before connection negotiation
// begins; an option returning a non-nil error aborts New with that error.
type ConnOption func(*Conn) error
2017-04-01 23:00:36 +03:00
2017-04-30 02:38:15 +03:00
// ConnHostname sets the hostname of the server. The hostname is sent in
// the AMQP Open frame and is used as the TLS ServerName unless the
// tls.Config already sets ServerName or InsecureSkipVerify.
//
// This is useful when the AMQP connection is established over a
// pre-established TLS connection, since the server may otherwise not know
// which hostname the client is trying to reach.
func ConnHostname(hostname string) ConnOption {
	setHostname := func(c *Conn) error {
		c.hostname = hostname
		return nil
	}
	return setHostname
}
2017-04-30 02:38:15 +03:00
// ConnTLS toggles TLS negotiation.
func ConnTLS(enable bool) ConnOption {
2017-04-23 04:32:50 +03:00
return func(c *Conn) error {
c.tlsNegotiation = enable
return nil
}
}
2017-04-30 02:38:15 +03:00
// ConnTLSConfig sets the tls.Config to be used during
// TLS negotiation.
//
// This option is for advanced usage, in most scenarios
// providing a URL scheme of "amqps://" or ConnTLS(true)
// is sufficient.
func ConnTLSConfig(tc *tls.Config) ConnOption {
return func(c *Conn) error {
c.tlsConfig = tc
c.tlsNegotiation = true
2017-04-23 04:32:50 +03:00
return nil
}
}
2017-04-30 02:38:15 +03:00
// ConnIdleTimeout specifies the maximum period between receiving
// frames from the peer.
//
// Resolution is milliseconds. A value of zero indicates no timeout.
// This setting is in addition to TCP keepalives.
func ConnIdleTimeout(d time.Duration) ConnOption {
2017-04-27 06:35:29 +03:00
return func(c *Conn) error {
2017-04-30 02:38:15 +03:00
if d < 0 {
return errorNew("idle timeout cannot be negative")
}
2017-04-27 06:35:29 +03:00
c.idleTimeout = d
return nil
}
}
2017-04-30 02:38:15 +03:00
// ConnMaxFrameSize sets the maximum frame size that
// the connection will send our receive.
//
// Must be 512 or greater.
//
// Default: 512
func ConnMaxFrameSize(n uint32) ConnOption {
2017-04-27 06:35:29 +03:00
return func(c *Conn) error {
2017-04-30 02:38:15 +03:00
if n < 512 {
return errorNew("max frame size must be 512 or greater")
}
2017-04-27 06:35:29 +03:00
c.maxFrameSize = n
return nil
}
}
2017-04-01 23:00:36 +03:00
// stateFunc is a single step of the connection-establishment state
// machine. Each step returns the step to run next; returning nil stops the
// machine, either because the connection is open or because the step
// recorded a failure in Conn.err.
type stateFunc func() stateFunc
2017-04-30 02:38:15 +03:00
// Conn is an AMQP connection.
2017-04-01 23:00:36 +03:00
type Conn struct {
net net.Conn // underlying connection
pauseRead int32 // atomically set to indicate connReader should pause reading from network
resumeRead chan struct{} // connReader reads from channel while paused, until channel is closed
// TLS
2017-04-30 02:38:15 +03:00
tlsNegotiation bool // negotiate TLS
tlsComplete bool // TLS negotiation complete
tlsConfig *tls.Config // TLS config, default used if nil (ServerName set to Conn.hostname)
2017-04-01 23:00:36 +03:00
2017-04-30 02:38:15 +03:00
// SASL
saslHandlers map[Symbol]stateFunc // map of supported handlers keyed by SASL mechanism, SASL not negotiated if nil
saslComplete bool // SASL negotiation complete
2017-04-01 23:00:36 +03:00
2017-04-30 02:38:15 +03:00
// local settings
maxFrameSize uint32 // max frame size we accept
channelMax uint16 // maximum number of channels we'll create
hostname string // hostname of remote server (set explicitly or parsed from URL)
idleTimeout time.Duration // maximum period between receiving frames
2017-04-27 06:35:29 +03:00
2017-04-30 02:38:15 +03:00
// peer settings
peerIdleTimeout time.Duration // maximum period between sending frames
peerMaxFrameSize uint32 // maximum frame size peer will accept
2017-04-01 23:00:36 +03:00
2017-04-30 02:38:15 +03:00
// conn state
errMu sync.Mutex // mux holds errMu from start until shutdown completes; operations are sequential before mux is started
err error // error to be returned to client
doneClosed int32 // atomically read/set; used to prevent double close
done chan struct{} // indicates the connection is done
2017-04-01 23:00:36 +03:00
// mux
2017-04-30 02:38:15 +03:00
readErr chan error // connReader notifications of an error
rxProto chan protoHeader // protoHeaders received by connReader
rxFrame chan frame // AMQP frames received by connReader
newSession chan *Session // new Sessions are requested from mux by reading off this channel
delSession chan *Session // session completion is indicated to mux by sending the Session on this channel
2017-04-01 23:00:36 +03:00
}
2017-04-30 02:38:15 +03:00
// Dial connects to an AMQP server.
//
// If the addr includes a scheme, it must be "amqp" or "amqps".
// TLS will be negotiated when the scheme is "amqps".
//
// If no port is provided, 5672 will be used.
func Dial(addr string, opts ...ConnOption) (*Conn, error) {
2017-04-01 23:00:36 +03:00
u, err := url.Parse(addr)
if err != nil {
return nil, err
}
2017-04-23 04:32:50 +03:00
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
host = u.Host
port = "5672" // use default AMQP if parse fails
2017-04-23 04:32:50 +03:00
}
2017-04-01 23:00:36 +03:00
switch u.Scheme {
2017-04-23 04:32:50 +03:00
case "amqp", "amqps", "":
2017-04-01 23:00:36 +03:00
default:
2017-04-30 02:38:15 +03:00
return nil, errorErrorf("unsupported scheme %q", u.Scheme)
2017-04-01 23:00:36 +03:00
}
2017-04-23 04:32:50 +03:00
conn, err := net.Dial("tcp", host+":"+port)
2017-04-01 23:00:36 +03:00
if err != nil {
return nil, err
}
// append default options so user specified can overwrite
2017-04-30 02:38:15 +03:00
opts = append([]ConnOption{
ConnHostname(host),
ConnTLS(u.Scheme == "amqps"),
2017-04-23 04:32:50 +03:00
}, opts...)
c, err := New(conn, opts...)
if err != nil {
return nil, err
}
return c, err
2017-04-01 23:00:36 +03:00
}
2017-04-30 02:38:15 +03:00
// New establishes an AMQP connection on pre-established
// net.Conn.
func New(conn net.Conn, opts ...ConnOption) (*Conn, error) {
2017-04-01 23:00:36 +03:00
c := &Conn{
2017-04-27 06:35:29 +03:00
net: conn,
2017-04-30 02:38:15 +03:00
maxFrameSize: defaultMaxFrameSize,
peerMaxFrameSize: defaultMaxFrameSize,
channelMax: defaultChannelMax,
idleTimeout: defaultIdleTimeout,
2017-04-27 06:35:29 +03:00
done: make(chan struct{}),
2017-04-30 02:38:15 +03:00
readErr: make(chan error, 1), // buffered to ensure connReader doesn't leak
rxProto: make(chan protoHeader),
2017-04-27 06:35:29 +03:00
rxFrame: make(chan frame),
newSession: make(chan *Session),
delSession: make(chan *Session),
2017-04-01 23:00:36 +03:00
}
// apply options
2017-04-01 23:00:36 +03:00
for _, opt := range opts {
if err := opt(c); err != nil {
return nil, err
}
}
// start connReader
2017-04-24 06:24:12 +03:00
go c.connReader()
// run connection establishment state machine
2017-04-01 23:00:36 +03:00
for state := c.negotiateProto; state != nil; {
state = state()
}
// check if err occurred
2017-04-24 06:24:12 +03:00
if c.err != nil {
2017-04-27 06:35:29 +03:00
c.Close()
2017-04-24 06:24:12 +03:00
return nil, c.err
2017-04-01 23:00:36 +03:00
}
// start multiplexor
2017-04-30 02:38:15 +03:00
go c.mux()
2017-04-24 06:24:12 +03:00
return c, nil
2017-04-01 23:00:36 +03:00
}
2017-04-30 02:38:15 +03:00
// Close disconnects the connection.
2017-04-01 23:00:36 +03:00
func (c *Conn) Close() error {
// TODO: shutdown AMQP
c.closeDone() // notify goroutines and blocked functions to exit
2017-04-27 06:35:29 +03:00
// Conn.mux holds err lock until shutdown, we block until
// shutdown completes and we can return the error (if any)
2017-04-27 06:35:29 +03:00
c.errMu.Lock()
defer c.errMu.Unlock()
2017-04-24 06:24:12 +03:00
err := c.net.Close()
if c.err == nil {
c.err = err
}
return c.err
2017-04-01 23:00:36 +03:00
}
2017-04-30 02:38:15 +03:00
// closeDone closes Conn.done if it has not already been closed
2017-04-27 06:35:29 +03:00
func (c *Conn) closeDone() {
if atomic.CompareAndSwapInt32(&c.doneClosed, 0, 1) {
close(c.done)
}
}
2017-04-30 02:38:15 +03:00
// NewSession opens a new AMQP session to the server.
2017-04-23 20:31:45 +03:00
func (c *Conn) NewSession() (*Session, error) {
// get a session allocated by Conn.mux
2017-04-24 06:24:12 +03:00
var s *Session
select {
case <-c.done:
return nil, c.err
case s = <-c.newSession:
}
2017-04-01 23:00:36 +03:00
// send Begin to server
2017-04-27 06:35:29 +03:00
err := s.txFrame(&performBegin{
2017-04-01 23:00:36 +03:00
NextOutgoingID: 0,
IncomingWindow: 1,
})
2017-04-24 06:24:12 +03:00
if err != nil {
s.Close()
return nil, err
}
2017-04-01 23:00:36 +03:00
// wait for response
2017-04-24 06:24:12 +03:00
var fr frame
select {
case <-c.done:
return nil, c.err
case fr = <-s.rx:
}
2017-04-30 02:38:15 +03:00
begin, ok := fr.body.(*performBegin)
if !ok {
s.Close() // deallocate session on error
2017-04-30 02:38:15 +03:00
return nil, errorErrorf("unexpected begin response: %+v", fr)
2017-04-01 23:00:36 +03:00
}
// TODO: record negotiated settings
s.remoteChannel = begin.RemoteChannel
2017-04-01 23:00:36 +03:00
// start Session multiplexor
go s.mux()
return s, nil
2017-04-01 23:00:36 +03:00
}
// keepaliveFrame is an empty (header-only) AMQP frame, sent periodically
// to keep the connection alive. Layout per the AMQP 1.0 frame header:
var keepaliveFrame = []byte{
	0x00, 0x00, 0x00, 0x08, // SIZE: 8 bytes, the header alone
	0x02,       // DOFF: header is 2 four-byte words
	0x00,       // TYPE: AMQP frame
	0x00, 0x00, // channel 0
}
// mux is start in it's own goroutine after initial connection establishment.
// It handles muxing of sessions, keepalives, and connection errors.
2017-04-30 02:38:15 +03:00
func (c *Conn) mux() {
// create the next session to allocate
nextSession := newSession(c, 0)
2017-04-01 23:00:36 +03:00
// map channel to sessions
2017-04-01 23:00:36 +03:00
sessions := make(map[uint16]*Session)
// if Conn.peerIdleTimeout is 0, keepalive will be nil and
// no keepalives will be sent
2017-04-30 02:38:15 +03:00
var keepalive <-chan time.Time
// per spec, keepalives should be sent every 0.5 * idle timeout
2017-04-30 02:38:15 +03:00
if kaInterval := c.peerIdleTimeout / 2; kaInterval > 0 {
ticker := time.NewTicker(kaInterval)
defer ticker.Stop()
keepalive = ticker.C
}
2017-04-01 23:00:36 +03:00
// we hold the errMu lock until error or done
2017-04-27 06:35:29 +03:00
c.errMu.Lock()
defer c.errMu.Unlock()
2017-04-01 23:00:36 +03:00
for {
// check if last loop returned an error
2017-04-01 23:00:36 +03:00
if c.err != nil {
2017-04-27 06:35:29 +03:00
c.closeDone()
2017-04-24 06:24:12 +03:00
return
2017-04-01 23:00:36 +03:00
}
select {
// error from connReader
2017-04-24 06:24:12 +03:00
case c.err = <-c.readErr:
2017-04-01 23:00:36 +03:00
// new frame from connReader
case fr := <-c.rxFrame:
// lookup session and send to Session.mux
ch, ok := sessions[fr.channel]
if !ok {
2017-04-30 02:38:15 +03:00
c.err = errorErrorf("unexpected frame: %#v", fr.body)
continue
2017-04-01 23:00:36 +03:00
}
ch.rx <- fr
2017-04-01 23:00:36 +03:00
// new session request
//
// Continually try to send the next session on the channel,
// then add it to the sessions map. This allows us to control ID
// allocation and prevents the need to have shared map. Since new
// sessions are far less frequent than frames being sent to sessions,
// we can avoid the lock/unlock for session lookup.
case c.newSession <- nextSession:
2017-04-01 23:00:36 +03:00
sessions[nextSession.channel] = nextSession
// create the next session to send
nextSession = newSession(c, nextSession.channel+1) // TODO: enforce max session/wrapping
// session deletion
2017-04-01 23:00:36 +03:00
case s := <-c.delSession:
delete(sessions, s.channel)
// keepalive timer
2017-04-30 02:38:15 +03:00
case <-keepalive:
// TODO: reset timer on non-keepalive transmit
2017-04-17 06:39:31 +03:00
_, c.err = c.net.Write(keepaliveFrame)
// connection is complete
2017-04-27 06:35:29 +03:00
case <-c.done:
return
2017-04-01 23:00:36 +03:00
}
}
}
// frameReader adapts an io.Reader for bytes.Buffer.ReadFrom so that each
// ReadFrom call performs exactly one underlying Read of at most n bytes.
//
// After a successful Read it reports io.EOF, which makes ReadFrom return
// instead of draining the connection. Because ReadFrom treats io.EOF as a
// clean stop, a genuine end of stream from the underlying reader (no data
// and io.EOF) is reported as io.ErrUnexpectedEOF instead; otherwise the
// caller could never detect that the peer closed the connection.
type frameReader struct {
	r io.Reader // underlying reader
	n int64     // max bytes per Read call
}

func (f *frameReader) Read(p []byte) (int, error) {
	if f.n < int64(len(p)) {
		p = p[:f.n]
	}
	n, err := f.r.Read(p)
	switch {
	case err == io.EOF && n == 0:
		// The peer closed the stream: surface it as a real error.
		return 0, io.ErrUnexpectedEOF
	case err != nil && err != io.EOF:
		return n, err
	}
	// Stop ReadFrom after this single read. Bytes that arrived together
	// with an underlying io.EOF are kept; the EOF resurfaces next call.
	return n, io.EOF
}
2017-04-30 02:38:15 +03:00
// connReader reads from the net.Conn, decodes frames, and passes them
// up via the Conn.rxFrame and Conn.rxProto channels.
2017-04-01 23:00:36 +03:00
func (c *Conn) connReader() {
buf := bufPool.Get().(*bytes.Buffer)
2017-04-30 18:56:16 +03:00
defer bufPool.Put(buf)
buf.Reset()
2017-04-24 06:24:12 +03:00
2017-04-30 02:38:15 +03:00
var (
negotiating = true // true during conn establishment, we should check for protoHeaders
currentHeader frameHeader // keep track of the current header, for frames split across multiple TCP packets
frameInProgress bool // true if we're in the middle of receiving data for currentHeader
2017-04-30 02:38:15 +03:00
)
// frameReader facilitates reading directly into buf
fr := &frameReader{r: c.net, n: int64(c.maxFrameSize)}
2017-04-27 06:35:29 +03:00
for {
// we need to read more if buf doesn't contain the complete frame
// or there's not enough in buf to parse the header
if frameInProgress || buf.Len() < frameHeaderSize {
2017-04-30 18:56:16 +03:00
c.net.SetReadDeadline(time.Now().Add(c.idleTimeout))
_, err := buf.ReadFrom(fr) // TODO: send error on frame too large
2017-04-27 06:35:29 +03:00
if err != nil {
2017-04-30 05:33:03 +03:00
if atomic.LoadInt32(&c.pauseRead) == 1 {
// need to stop reading during TLS negotiation,
// see Conn.startTLS()
c.pauseRead = 0
for range c.resumeRead {
// reads indicate paused, resume on close
}
fr.r = c.net // conn wrapped with TLS
2017-04-30 05:33:03 +03:00
continue
}
2017-04-27 06:35:29 +03:00
c.readErr <- err
return
2017-04-24 06:24:12 +03:00
}
2017-04-27 06:35:29 +03:00
}
// read more if we didn't get enough to parse header
if buf.Len() < frameHeaderSize {
2017-04-27 06:35:29 +03:00
continue
}
// during negotiation, check for proto frames
2017-04-27 06:35:29 +03:00
if negotiating && bytes.Equal(buf.Bytes()[:4], []byte{'A', 'M', 'Q', 'P'}) {
2017-04-30 02:38:15 +03:00
p, err := parseProtoHeader(buf)
if err != nil {
2017-04-24 06:24:12 +03:00
c.readErr <- err
return
}
// we know negotiation is complete once an AMQP proto frame
// is received
2017-04-27 06:35:29 +03:00
if p.ProtoID == protoAMQP {
negotiating = false
2017-04-24 06:24:12 +03:00
}
// send proto header
2017-04-24 06:24:12 +03:00
select {
case <-c.done:
return
2017-04-27 06:35:29 +03:00
case c.rxProto <- p:
}
2017-04-30 02:38:15 +03:00
2017-04-27 06:35:29 +03:00
continue
}
// parse the header if we're not completeing an already
// parsed frame
2017-04-27 06:35:29 +03:00
if !frameInProgress {
var err error
2017-04-27 06:35:29 +03:00
currentHeader, err = parseFrameHeader(buf)
if err != nil {
c.readErr <- err
return
}
frameInProgress = true
}
// check size is reasonable
2017-04-30 02:38:15 +03:00
if currentHeader.Size > math.MaxInt32 { // make max size configurable
c.readErr <- errorNew("payload too large")
return
}
bodySize := int(currentHeader.Size - frameHeaderSize)
2017-04-30 02:38:15 +03:00
// check if we have the full frame
2017-04-30 02:38:15 +03:00
if buf.Len() < bodySize {
2017-04-27 06:35:29 +03:00
continue
}
frameInProgress = false
// check if body is empty (keepalive)
2017-04-30 02:38:15 +03:00
if bodySize == 0 {
continue
2017-04-30 02:38:15 +03:00
}
// parse the frame
2017-04-30 18:56:16 +03:00
payload := bytes.NewBuffer(buf.Next(bodySize))
parsedBody, err := parseFrame(payload)
2017-04-27 06:35:29 +03:00
if err != nil {
c.readErr <- err
return
}
// send to mux
2017-04-27 06:35:29 +03:00
select {
case <-c.done:
return
2017-04-30 18:56:16 +03:00
case c.rxFrame <- frame{channel: currentHeader.Channel, body: parsedBody}:
2017-04-27 06:35:29 +03:00
}
2017-04-01 23:00:36 +03:00
}
}
2017-04-30 02:38:15 +03:00
// negotiateProto determines which proto to negotiate next
2017-04-01 23:00:36 +03:00
func (c *Conn) negotiateProto() stateFunc {
// in the order each must be negotiated
2017-04-01 23:00:36 +03:00
switch {
case c.tlsNegotiation && !c.tlsComplete:
2017-04-23 21:01:44 +03:00
return c.exchangeProtoHeader(protoTLS)
2017-04-01 23:00:36 +03:00
case c.saslHandlers != nil && !c.saslComplete:
2017-04-23 21:01:44 +03:00
return c.exchangeProtoHeader(protoSASL)
2017-04-01 23:00:36 +03:00
default:
2017-04-23 21:01:44 +03:00
return c.exchangeProtoHeader(protoAMQP)
2017-04-01 23:00:36 +03:00
}
}
// Protocol IDs, carried in the fifth byte of an AMQP protocol header.

// protoAMQP selects the AMQP protocol itself.
const protoAMQP = 0x0

// protoTLS selects TLS negotiation.
const protoTLS = 0x2

// protoSASL selects SASL negotiation.
const protoSASL = 0x3
2017-04-30 02:38:15 +03:00
// exchangeProtoHeader performs the round trip exchange of protocol
// headers, validation, and returns the protoID specific next state.
2017-04-24 06:24:12 +03:00
func (c *Conn) exchangeProtoHeader(protoID uint8) stateFunc {
// write the proto header
c.net.SetWriteDeadline(time.Now().Add(1 * time.Second)) // TODO: make configurable
2017-04-24 06:24:12 +03:00
_, c.err = c.net.Write([]byte{'A', 'M', 'Q', 'P', protoID, 1, 0, 0})
2017-04-01 23:00:36 +03:00
if c.err != nil {
2017-04-30 05:33:03 +03:00
c.err = errorWrapf(c.err, "writing to network")
2017-04-01 23:00:36 +03:00
return nil
}
c.net.SetWriteDeadline(time.Time{})
2017-04-01 23:00:36 +03:00
// read response header
2017-04-30 02:38:15 +03:00
var p protoHeader
2017-04-24 06:24:12 +03:00
select {
case p = <-c.rxProto:
case c.err = <-c.readErr:
2017-04-01 23:00:36 +03:00
return nil
case fr := <-c.rxFrame:
2017-04-30 02:38:15 +03:00
c.err = errorErrorf("unexpected frame %#v", fr)
2017-04-27 06:35:29 +03:00
return nil
2017-04-24 06:24:12 +03:00
case <-time.After(1 * time.Second):
c.err = errorWrapf(ErrTimeout, "timeout")
2017-04-01 23:00:36 +03:00
return nil
}
2017-04-24 06:24:12 +03:00
if protoID != p.ProtoID {
2017-04-30 02:38:15 +03:00
c.err = errorErrorf("unexpected protocol header %#00x, expected %#00x", p.ProtoID, protoID)
2017-04-01 23:00:36 +03:00
return nil
}
// go to the proto specific state
2017-04-24 06:24:12 +03:00
switch protoID {
2017-04-23 21:01:44 +03:00
case protoAMQP:
2017-04-30 02:38:15 +03:00
return c.openAMQP
2017-04-23 21:01:44 +03:00
case protoTLS:
2017-04-30 02:38:15 +03:00
return c.startTLS
2017-04-23 21:01:44 +03:00
case protoSASL:
2017-04-30 02:38:15 +03:00
return c.negotiateSASL
2017-04-01 23:00:36 +03:00
default:
2017-04-30 02:38:15 +03:00
c.err = errorErrorf("unknown protocol ID %#02x", p.ProtoID)
2017-04-01 23:00:36 +03:00
return nil
}
}
2017-04-30 02:38:15 +03:00
// startTLS wraps the conn with TLS and returns to Conn.negotiateProto
func (c *Conn) startTLS() stateFunc {
// create a new config if not already set
if c.tlsConfig == nil {
2017-04-24 00:58:59 +03:00
c.tlsConfig = new(tls.Config)
}
// TLS config must have ServerName or InsecureSkipVerify set
if c.tlsConfig.ServerName == "" && !c.tlsConfig.InsecureSkipVerify {
c.tlsConfig.ServerName = c.hostname
}
2017-04-30 05:33:03 +03:00
// convoluted method to pause connReader, explorer simpler alternatives
c.resumeRead = make(chan struct{}) // 1. create channel
atomic.StoreInt32(&c.pauseRead, 1) // 2. indicate should pause
c.net.SetReadDeadline(time.Time{}.Add(1)) // 3. set deadline to interrupt connReader
c.resumeRead <- struct{}{} // 4. wait for connReader to read from chan, indicating paused
defer close(c.resumeRead) // 5. defer connReader resume by closing channel
c.net.SetReadDeadline(time.Time{}) // 6. clear deadline
// wrap existing net.Conn and perform TLS handshake
2017-04-30 05:33:03 +03:00
conn := tls.Client(c.net, c.tlsConfig)
c.err = conn.Handshake()
if c.err != nil {
return nil
}
// swap net.Conn
2017-04-30 05:33:03 +03:00
c.net = conn
c.tlsComplete = true
// go to next protocol
2017-04-23 04:32:50 +03:00
return c.negotiateProto
}
2017-04-30 02:38:15 +03:00
// txFrame encodes and transmits a frame on the connection
func (c *Conn) txFrame(fr frame) error {
// BUG: This should respect c.peerMaxFrameSize. Should not affect current functionality;
2017-04-30 05:33:03 +03:00
// only transfer frames should be larger than min-max frame size (512).
2017-04-30 02:38:15 +03:00
return writeFrame(c.net, fr) // TODO: buffer?
}
2017-04-30 02:38:15 +03:00
// openAMQP round trips the AMQP open performative
func (c *Conn) openAMQP() stateFunc {
// send open frame
2017-04-30 02:38:15 +03:00
c.err = c.txFrame(frame{
typ: frameTypeAMQP,
body: &performOpen{
ContainerID: randString(),
Hostname: c.hostname,
MaxFrameSize: c.maxFrameSize,
ChannelMax: c.channelMax,
2017-04-30 02:38:15 +03:00
IdleTimeout: c.idleTimeout,
},
channel: 0,
})
if c.err != nil {
2017-04-01 23:00:36 +03:00
return nil
}
// get the response
2017-04-24 06:24:12 +03:00
fr, err := c.readFrame()
2017-04-01 23:00:36 +03:00
if err != nil {
2017-04-24 06:24:12 +03:00
c.err = err
2017-04-01 23:00:36 +03:00
return nil
}
2017-04-30 02:38:15 +03:00
o, ok := fr.body.(*performOpen)
2017-04-24 06:24:12 +03:00
if !ok {
2017-04-30 02:38:15 +03:00
c.err = errorErrorf("unexpected frame type %T", fr.body)
2017-04-27 06:35:29 +03:00
return nil
2017-04-01 23:00:36 +03:00
}
// update peer settings
2017-04-27 06:35:29 +03:00
if o.MaxFrameSize > 0 {
2017-04-30 05:33:03 +03:00
c.peerMaxFrameSize = o.MaxFrameSize
2017-04-27 06:35:29 +03:00
}
2017-04-23 21:01:44 +03:00
if o.IdleTimeout > 0 {
2017-04-30 02:38:15 +03:00
c.peerIdleTimeout = o.IdleTimeout
2017-04-17 06:39:31 +03:00
}
2017-04-01 23:00:36 +03:00
if o.ChannelMax < c.channelMax {
c.channelMax = o.ChannelMax
}
// connection established, exit state machine
2017-04-01 23:00:36 +03:00
return nil
}
2017-04-30 02:38:15 +03:00
// negotiateSASL returns the SASL handler for the first matched
// mechanism specified by the server
func (c *Conn) negotiateSASL() stateFunc {
2017-04-01 23:00:36 +03:00
if c.saslHandlers == nil {
// we don't support SASL
2017-04-30 02:38:15 +03:00
c.err = errorErrorf("server request SASL, but not configured")
2017-04-01 23:00:36 +03:00
return nil
}
2017-04-24 06:24:12 +03:00
fr, err := c.readFrame()
2017-04-01 23:00:36 +03:00
if err != nil {
c.err = err
return nil
}
2017-04-30 02:38:15 +03:00
sm, ok := fr.body.(*saslMechanisms)
2017-04-24 06:24:12 +03:00
if !ok {
2017-04-30 02:38:15 +03:00
c.err = errorErrorf("unexpected frame type %T", fr.body)
2017-04-01 23:00:36 +03:00
return nil
}
for _, mech := range sm.Mechanisms {
if state, ok := c.saslHandlers[mech]; ok {
return state
}
}
// TODO: send some sort of "auth not supported" frame?
2017-04-30 02:38:15 +03:00
c.err = errorErrorf("no supported auth mechanism (%v)", sm.Mechanisms)
2017-04-01 23:00:36 +03:00
return nil
}
2017-04-30 02:38:15 +03:00
// saslOutcome processes the SASL outcome frame and return Conn.negotiateProto
// on success.
//
// SASL handlers return this stateFunc when the mechanism specific negotiation
// has completed.
2017-04-01 23:00:36 +03:00
func (c *Conn) saslOutcome() stateFunc {
2017-04-24 06:24:12 +03:00
fr, err := c.readFrame()
2017-04-01 23:00:36 +03:00
if err != nil {
c.err = err
return nil
}
2017-04-30 02:38:15 +03:00
so, ok := fr.body.(*saslOutcome)
2017-04-24 06:24:12 +03:00
if !ok {
2017-04-30 02:38:15 +03:00
c.err = errorErrorf("unexpected frame type %T", fr.body)
2017-04-27 06:35:29 +03:00
return nil
2017-04-01 23:00:36 +03:00
}
2017-04-23 21:01:44 +03:00
if so.Code != codeSASLOK {
2017-04-30 02:38:15 +03:00
c.err = errorErrorf("SASL PLAIN auth failed with code %#00x: %s", so.Code, so.AdditionalData)
2017-04-01 23:00:36 +03:00
return nil
}
c.saslComplete = true
return c.negotiateProto
}
2017-04-24 06:24:12 +03:00
2017-04-30 02:38:15 +03:00
// readFrame is used during connection establishment to read a single frame.
//
// After setup, Conn.mux handles incoming frames.
2017-04-24 06:24:12 +03:00
func (c *Conn) readFrame() (frame, error) {
var fr frame
select {
case fr = <-c.rxFrame:
return fr, nil
case err := <-c.readErr:
return fr, err
case p := <-c.rxProto:
2017-04-30 02:38:15 +03:00
return fr, errorErrorf("unexpected protocol header %#v", p)
2017-04-24 06:24:12 +03:00
case <-time.After(1 * time.Second):
return fr, errorWrapf(ErrTimeout, "timeout")
2017-04-24 06:24:12 +03:00
}
}