2017-04-01 23:00:36 +03:00
|
|
|
package amqp
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
2017-04-23 04:32:50 +03:00
|
|
|
"crypto/tls"
|
2017-04-01 23:00:36 +03:00
|
|
|
"fmt"
|
|
|
|
"net"
|
|
|
|
"net/url"
|
2017-04-27 06:35:29 +03:00
|
|
|
"sync"
|
|
|
|
"sync/atomic"
|
2017-04-17 06:39:31 +03:00
|
|
|
"time"
|
2017-04-16 21:37:51 +03:00
|
|
|
|
|
|
|
"github.com/pkg/errors"
|
2017-04-01 23:00:36 +03:00
|
|
|
)
|
|
|
|
|
|
|
|
// Connection defaults, applied by New before any ConnOpt runs.
const (
	// initialMaxFrameSize is the default local max frame size in bytes
	// (also the size of the reader's network buffer).
	initialMaxFrameSize = 512
	// initialChannelMax is the default channel-max advertised in open.
	initialChannelMax = 1
)
|
|
|
|
|
2017-04-23 19:42:48 +03:00
|
|
|
// ConnOpt is a functional option that configures a Conn before protocol
// negotiation starts. A non-nil error aborts Dial/New.
type ConnOpt func(c *Conn) error
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-04-23 19:42:48 +03:00
|
|
|
func ConnHostname(hostname string) ConnOpt {
|
2017-04-16 21:37:51 +03:00
|
|
|
return func(c *Conn) error {
|
|
|
|
c.hostname = hostname
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-04-23 23:16:51 +03:00
|
|
|
func ConnTLS(enable bool) ConnOpt {
|
2017-04-23 04:32:50 +03:00
|
|
|
return func(c *Conn) error {
|
2017-04-23 19:42:48 +03:00
|
|
|
c.tlsNegotiation = enable
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func ConnTLSConfig(tc *tls.Config) ConnOpt {
|
|
|
|
return func(c *Conn) error {
|
|
|
|
c.tlsConfig = tc
|
|
|
|
c.tlsNegotiation = true
|
2017-04-23 04:32:50 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
func ConnIdleTimeout(d time.Duration) ConnOpt {
|
|
|
|
return func(c *Conn) error {
|
|
|
|
c.idleTimeout = d
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func ConnMaxFrameSize(n uint32) ConnOpt {
|
|
|
|
return func(c *Conn) error {
|
|
|
|
// TODO: error if 0
|
|
|
|
c.maxFrameSize = n
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-04-01 23:00:36 +03:00
|
|
|
// stateFunc is one step of the negotiation state machine. Each step returns
// the next step to run, or nil when the machine has finished (successfully
// or with c.err set).
type stateFunc func() (next stateFunc)
|
|
|
|
|
|
|
|
type Conn struct {
|
2017-04-23 19:42:48 +03:00
|
|
|
net net.Conn
|
|
|
|
|
|
|
|
// TLS
|
|
|
|
tlsNegotiation bool
|
|
|
|
tlsComplete bool
|
|
|
|
tlsConfig *tls.Config
|
2017-04-01 23:00:36 +03:00
|
|
|
|
|
|
|
maxFrameSize uint32
|
|
|
|
channelMax uint16
|
2017-04-16 21:37:51 +03:00
|
|
|
hostname string
|
2017-04-17 06:39:31 +03:00
|
|
|
idleTimeout time.Duration
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
peerMaxFrameSize uint32
|
|
|
|
|
|
|
|
// startMux holds errMu from start until shutdown complete
|
|
|
|
// operations are sequential before startMux is started and
|
|
|
|
// holding the mutex is not necessary
|
|
|
|
errMu sync.Mutex
|
|
|
|
err error
|
|
|
|
doneClosed int32
|
|
|
|
done chan struct{}
|
2017-04-01 23:00:36 +03:00
|
|
|
|
|
|
|
// SASL
|
|
|
|
saslHandlers map[Symbol]stateFunc
|
|
|
|
saslComplete bool
|
|
|
|
|
|
|
|
// mux
|
|
|
|
readErr chan error
|
2017-04-24 06:24:12 +03:00
|
|
|
rxProto chan proto
|
2017-04-23 04:31:07 +03:00
|
|
|
rxFrame chan frame
|
2017-04-01 23:00:36 +03:00
|
|
|
newSession chan *Session
|
|
|
|
delSession chan *Session
|
|
|
|
}
|
|
|
|
|
2017-04-23 19:42:48 +03:00
|
|
|
func Dial(addr string, opts ...ConnOpt) (*Conn, error) {
|
2017-04-01 23:00:36 +03:00
|
|
|
u, err := url.Parse(addr)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2017-04-23 04:32:50 +03:00
|
|
|
host, port, err := net.SplitHostPort(u.Host)
|
|
|
|
if err != nil {
|
|
|
|
host = u.Host
|
|
|
|
port = "5672"
|
|
|
|
}
|
|
|
|
|
2017-04-01 23:00:36 +03:00
|
|
|
switch u.Scheme {
|
2017-04-23 04:32:50 +03:00
|
|
|
case "amqp", "amqps", "":
|
2017-04-01 23:00:36 +03:00
|
|
|
default:
|
|
|
|
return nil, fmt.Errorf("unsupported scheme %q", u.Scheme)
|
|
|
|
}
|
|
|
|
|
2017-04-23 04:32:50 +03:00
|
|
|
conn, err := net.Dial("tcp", host+":"+port)
|
2017-04-01 23:00:36 +03:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2017-04-23 19:42:48 +03:00
|
|
|
opts = append([]ConnOpt{
|
|
|
|
ConnHostname(host),
|
2017-04-23 23:16:51 +03:00
|
|
|
ConnTLS(u.Scheme == "amqps"),
|
2017-04-23 04:32:50 +03:00
|
|
|
}, opts...)
|
|
|
|
|
|
|
|
c, err := New(conn, opts...)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return c, err
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
2017-04-23 19:42:48 +03:00
|
|
|
func New(conn net.Conn, opts ...ConnOpt) (*Conn, error) {
|
2017-04-01 23:00:36 +03:00
|
|
|
c := &Conn{
|
2017-04-27 06:35:29 +03:00
|
|
|
net: conn,
|
|
|
|
maxFrameSize: initialMaxFrameSize,
|
|
|
|
peerMaxFrameSize: initialMaxFrameSize,
|
|
|
|
channelMax: initialChannelMax,
|
|
|
|
idleTimeout: 1 * time.Minute,
|
|
|
|
done: make(chan struct{}),
|
|
|
|
readErr: make(chan error, 1),
|
|
|
|
rxProto: make(chan proto),
|
|
|
|
rxFrame: make(chan frame),
|
|
|
|
newSession: make(chan *Session),
|
|
|
|
delSession: make(chan *Session),
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
for _, opt := range opts {
|
|
|
|
if err := opt(c); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
go c.connReader()
|
|
|
|
|
2017-04-01 23:00:36 +03:00
|
|
|
for state := c.negotiateProto; state != nil; {
|
|
|
|
state = state()
|
|
|
|
}
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
if c.err != nil {
|
2017-04-27 06:35:29 +03:00
|
|
|
c.Close()
|
2017-04-24 06:24:12 +03:00
|
|
|
return nil, c.err
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
go c.startMux()
|
|
|
|
|
|
|
|
return c, nil
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) Close() error {
|
|
|
|
// TODO: shutdown AMQP
|
2017-04-27 06:35:29 +03:00
|
|
|
c.closeDone()
|
|
|
|
|
|
|
|
c.errMu.Lock()
|
|
|
|
defer c.errMu.Unlock()
|
2017-04-24 06:24:12 +03:00
|
|
|
err := c.net.Close()
|
|
|
|
if c.err == nil {
|
|
|
|
c.err = err
|
|
|
|
}
|
|
|
|
return c.err
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
func (c *Conn) closeDone() {
|
|
|
|
if atomic.CompareAndSwapInt32(&c.doneClosed, 0, 1) {
|
|
|
|
close(c.done)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-04-01 23:00:36 +03:00
|
|
|
func (c *Conn) MaxFrameSize() int {
|
|
|
|
return int(c.maxFrameSize)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) ChannelMax() int {
|
|
|
|
return int(c.channelMax)
|
|
|
|
}
|
|
|
|
|
2017-04-23 20:31:45 +03:00
|
|
|
func (c *Conn) NewSession() (*Session, error) {
|
2017-04-24 06:24:12 +03:00
|
|
|
var s *Session
|
|
|
|
select {
|
|
|
|
case <-c.done:
|
|
|
|
return nil, c.err
|
|
|
|
case s = <-c.newSession:
|
2017-04-22 22:56:08 +03:00
|
|
|
}
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
err := s.txFrame(&performBegin{
|
2017-04-01 23:00:36 +03:00
|
|
|
NextOutgoingID: 0,
|
|
|
|
IncomingWindow: 1,
|
|
|
|
})
|
2017-04-24 06:24:12 +03:00
|
|
|
if err != nil {
|
|
|
|
s.Close()
|
|
|
|
return nil, err
|
|
|
|
}
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
var fr frame
|
|
|
|
select {
|
|
|
|
case <-c.done:
|
|
|
|
return nil, c.err
|
|
|
|
case fr = <-s.rx:
|
|
|
|
}
|
2017-04-27 06:35:29 +03:00
|
|
|
begin, ok := fr.preformative.(*performBegin)
|
2017-04-22 22:56:08 +03:00
|
|
|
if !ok {
|
|
|
|
s.Close()
|
|
|
|
return nil, fmt.Errorf("unexpected begin response: %+v", fr)
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
2017-04-22 22:56:08 +03:00
|
|
|
fmt.Printf("Begin Resp: %+v", begin)
|
2017-04-01 23:00:36 +03:00
|
|
|
// TODO: record negotiated settings
|
2017-04-23 23:16:51 +03:00
|
|
|
s.remoteChannel = begin.RemoteChannel
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-04-22 22:56:08 +03:00
|
|
|
go s.startMux()
|
|
|
|
|
|
|
|
return s, nil
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
2017-04-23 23:16:51 +03:00
|
|
|
// keepaliveFrame is an 8-byte, header-only (empty) frame that startMux writes
// periodically to keep the connection from idling out. Its layout follows the
// AMQP 1.0 frame header.
var keepaliveFrame = []byte{
	0x00, 0x00, 0x00, 0x08, // SIZE: 8 bytes, header only
	0x02,                   // DOFF: header length in 4-byte words
	0x00,                   // TYPE: AMQP frame
	0x00, 0x00,             // channel 0
}
|
|
|
|
|
2017-04-01 23:00:36 +03:00
|
|
|
func (c *Conn) startMux() {
|
2017-04-23 23:16:51 +03:00
|
|
|
nextSession := newSession(c, 0)
|
2017-04-01 23:00:36 +03:00
|
|
|
|
|
|
|
// map channel to session
|
|
|
|
sessions := make(map[uint16]*Session)
|
|
|
|
|
2017-04-17 06:39:31 +03:00
|
|
|
keepalive := time.NewTicker(c.idleTimeout / 2)
|
|
|
|
|
2017-04-01 23:00:36 +03:00
|
|
|
fmt.Println("Starting mux")
|
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
c.errMu.Lock()
|
|
|
|
defer c.errMu.Unlock()
|
|
|
|
|
2017-04-22 19:48:39 +03:00
|
|
|
outer:
|
2017-04-01 23:00:36 +03:00
|
|
|
for {
|
|
|
|
if c.err != nil {
|
2017-04-27 06:35:29 +03:00
|
|
|
c.closeDone()
|
2017-04-24 06:24:12 +03:00
|
|
|
return
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
select {
|
2017-04-24 06:24:12 +03:00
|
|
|
case c.err = <-c.readErr:
|
2017-04-01 23:00:36 +03:00
|
|
|
fmt.Println("Got read error")
|
|
|
|
|
2017-04-23 04:31:07 +03:00
|
|
|
case fr := <-c.rxFrame:
|
|
|
|
ch, ok := sessions[fr.channel]
|
|
|
|
if !ok {
|
2017-04-27 06:35:29 +03:00
|
|
|
c.err = errors.Errorf("unexpected frame: %#v", fr.preformative)
|
2017-04-23 04:31:07 +03:00
|
|
|
continue outer
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
2017-04-23 04:31:07 +03:00
|
|
|
ch.rx <- fr
|
2017-04-01 23:00:36 +03:00
|
|
|
|
|
|
|
case c.newSession <- nextSession:
|
|
|
|
fmt.Println("Got new session request")
|
|
|
|
sessions[nextSession.channel] = nextSession
|
|
|
|
// TODO: handle max session/wrapping
|
2017-04-23 23:16:51 +03:00
|
|
|
nextSession = newSession(c, nextSession.channel+1)
|
2017-04-01 23:00:36 +03:00
|
|
|
|
|
|
|
case s := <-c.delSession:
|
|
|
|
fmt.Println("Got delete session request")
|
|
|
|
delete(sessions, s.channel)
|
|
|
|
|
2017-04-17 06:39:31 +03:00
|
|
|
case <-keepalive.C:
|
|
|
|
fmt.Printf("Writing: %# 02x\n", keepaliveFrame)
|
|
|
|
_, c.err = c.net.Write(keepaliveFrame)
|
2017-04-27 06:35:29 +03:00
|
|
|
case <-c.done:
|
|
|
|
return
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) connReader() {
|
2017-04-23 04:31:07 +03:00
|
|
|
buf := bufPool.Get().(*bytes.Buffer)
|
|
|
|
buf.Reset()
|
2017-04-24 06:24:12 +03:00
|
|
|
rxBuf := make([]byte, c.maxFrameSize)
|
|
|
|
|
|
|
|
negotiating := true
|
2017-04-23 04:31:07 +03:00
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
idleTimeout := c.idleTimeout
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
var currentHeader frameHeader
|
|
|
|
frameInProgress := false
|
|
|
|
var err error
|
2017-04-23 04:31:07 +03:00
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
for {
|
|
|
|
fmt.Println(frameInProgress, buf.Len())
|
|
|
|
if frameInProgress || buf.Len() < 8 { // 8 = min size for header
|
|
|
|
c.net.SetReadDeadline(time.Now().Add(idleTimeout))
|
|
|
|
n, err := c.net.Read(rxBuf[:c.maxFrameSize]) // TODO: send error on frame too large
|
|
|
|
if err != nil {
|
|
|
|
c.readErr <- err
|
|
|
|
return
|
2017-04-24 06:24:12 +03:00
|
|
|
}
|
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
_, err = buf.Write(rxBuf[:n])
|
2017-04-23 04:31:07 +03:00
|
|
|
if err != nil {
|
2017-04-24 06:24:12 +03:00
|
|
|
c.readErr <- err
|
|
|
|
return
|
2017-04-23 04:31:07 +03:00
|
|
|
}
|
2017-04-27 06:35:29 +03:00
|
|
|
}
|
2017-04-23 04:31:07 +03:00
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
if buf.Len() < 8 {
|
|
|
|
continue
|
|
|
|
}
|
2017-04-23 04:31:07 +03:00
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
if negotiating && bytes.Equal(buf.Bytes()[:4], []byte{'A', 'M', 'Q', 'P'}) {
|
|
|
|
p, err := parseProto(buf)
|
2017-04-23 04:31:07 +03:00
|
|
|
if err != nil {
|
2017-04-24 06:24:12 +03:00
|
|
|
c.readErr <- err
|
|
|
|
return
|
2017-04-23 04:31:07 +03:00
|
|
|
}
|
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
if p.ProtoID == protoAMQP {
|
|
|
|
negotiating = false
|
2017-04-24 06:24:12 +03:00
|
|
|
}
|
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
// fmt.Printf("GOT: %#v\n", p)
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
select {
|
|
|
|
case <-c.done:
|
|
|
|
return
|
2017-04-27 06:35:29 +03:00
|
|
|
case c.rxProto <- p:
|
|
|
|
}
|
|
|
|
// fmt.Printf("Buf len: %d\n", buf.Len())
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
if !frameInProgress {
|
|
|
|
currentHeader, err = parseFrameHeader(buf)
|
|
|
|
if err != nil {
|
|
|
|
c.readErr <- err
|
|
|
|
return
|
|
|
|
}
|
|
|
|
frameInProgress = true
|
|
|
|
}
|
|
|
|
// fmt.Printf("GOT: %#v\n", currentHeader)
|
|
|
|
|
|
|
|
if uint64(buf.Len()) < uint64(currentHeader.Size-8) {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
frameInProgress = false
|
|
|
|
|
|
|
|
frameBody := buf.Next(int(currentHeader.Size - 8))
|
|
|
|
|
|
|
|
p, err := parseFrame(bytes.NewBuffer(frameBody))
|
|
|
|
if err != nil {
|
|
|
|
c.readErr <- err
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
fmt.Printf("GOT: %#v\n", p)
|
|
|
|
|
|
|
|
if o, ok := p.(*performOpen); ok && o.MaxFrameSize < c.maxFrameSize {
|
|
|
|
if o.IdleTimeout > 0 && o.IdleTimeout < idleTimeout {
|
|
|
|
idleTimeout = o.IdleTimeout
|
2017-04-24 06:24:12 +03:00
|
|
|
}
|
2017-04-23 04:31:07 +03:00
|
|
|
}
|
2017-04-27 06:35:29 +03:00
|
|
|
|
|
|
|
select {
|
|
|
|
case <-c.done:
|
|
|
|
return
|
|
|
|
case c.rxFrame <- frame{channel: currentHeader.Channel, preformative: p}:
|
|
|
|
}
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
On connection open, we need to handle four possible negotiation sequences:

 1. Straight into AMQP.
 2. SASL -> AMQP.
 3. TLS -> AMQP.
 4. TLS -> SASL -> AMQP.
*/
|
|
|
|
func (c *Conn) negotiateProto() stateFunc {
|
|
|
|
switch {
|
2017-04-23 19:42:48 +03:00
|
|
|
case c.tlsNegotiation && !c.tlsComplete:
|
2017-04-23 21:01:44 +03:00
|
|
|
return c.exchangeProtoHeader(protoTLS)
|
2017-04-01 23:00:36 +03:00
|
|
|
case c.saslHandlers != nil && !c.saslComplete:
|
2017-04-23 21:01:44 +03:00
|
|
|
return c.exchangeProtoHeader(protoSASL)
|
2017-04-01 23:00:36 +03:00
|
|
|
default:
|
2017-04-23 21:01:44 +03:00
|
|
|
return c.exchangeProtoHeader(protoAMQP)
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Protocol IDs, carried in byte 4 of the protocol header
// ('A' 'M' 'Q' 'P' id major minor revision).
const (
	protoAMQP = 0x00 // plain AMQP
	protoTLS  = 0x02 // TLS layer
	protoSASL = 0x03 // SASL layer
)
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
func (c *Conn) exchangeProtoHeader(protoID uint8) stateFunc {
|
|
|
|
_, c.err = c.net.Write([]byte{'A', 'M', 'Q', 'P', protoID, 1, 0, 0})
|
2017-04-01 23:00:36 +03:00
|
|
|
if c.err != nil {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
var p proto
|
|
|
|
select {
|
|
|
|
case p = <-c.rxProto:
|
|
|
|
case c.err = <-c.readErr:
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
2017-04-24 07:03:50 +03:00
|
|
|
case fr := <-c.rxFrame:
|
|
|
|
c.err = errors.Errorf("unexpected frame %#v", fr)
|
2017-04-27 06:35:29 +03:00
|
|
|
return nil
|
2017-04-24 06:24:12 +03:00
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
c.err = ErrTimeout
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
fmt.Printf("Proto: %s; ProtoID: %d; Version: %d.%d.%d\n",
|
2017-04-24 06:24:12 +03:00
|
|
|
p.Proto,
|
|
|
|
p.ProtoID,
|
|
|
|
p.Major,
|
|
|
|
p.Minor,
|
|
|
|
p.Revision,
|
2017-04-01 23:00:36 +03:00
|
|
|
)
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
if protoID != p.ProtoID {
|
|
|
|
c.err = fmt.Errorf("unexpected protocol header %#00x, expected %#00x", p.ProtoID, protoID)
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
switch protoID {
|
2017-04-23 21:01:44 +03:00
|
|
|
case protoAMQP:
|
2017-04-01 23:00:36 +03:00
|
|
|
return c.txOpen
|
2017-04-23 21:01:44 +03:00
|
|
|
case protoTLS:
|
2017-04-23 04:32:50 +03:00
|
|
|
return c.protoTLS
|
2017-04-23 21:01:44 +03:00
|
|
|
case protoSASL:
|
2017-04-01 23:00:36 +03:00
|
|
|
return c.protoSASL
|
|
|
|
default:
|
2017-04-24 06:24:12 +03:00
|
|
|
c.err = fmt.Errorf("unknown protocol ID %#02x", p.ProtoID)
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-04-23 04:32:50 +03:00
|
|
|
func (c *Conn) protoTLS() stateFunc {
|
2017-04-23 19:42:48 +03:00
|
|
|
if c.tlsConfig == nil {
|
2017-04-24 00:58:59 +03:00
|
|
|
c.tlsConfig = new(tls.Config)
|
2017-04-23 23:16:51 +03:00
|
|
|
}
|
|
|
|
if c.tlsConfig.ServerName == "" && !c.tlsConfig.InsecureSkipVerify {
|
|
|
|
c.tlsConfig.ServerName = c.hostname
|
2017-04-23 19:42:48 +03:00
|
|
|
}
|
|
|
|
c.net = tls.Client(c.net, c.tlsConfig)
|
|
|
|
c.tlsComplete = true
|
2017-04-23 04:32:50 +03:00
|
|
|
return c.negotiateProto
|
|
|
|
}
|
|
|
|
|
2017-04-23 04:31:07 +03:00
|
|
|
func (c *Conn) txPreformative(fr frame) error {
|
2017-04-23 20:31:45 +03:00
|
|
|
data, err := marshal(fr.preformative)
|
2017-04-01 23:00:36 +03:00
|
|
|
if err != nil {
|
2017-04-22 22:56:08 +03:00
|
|
|
return err
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
wr := bufPool.New().(*bytes.Buffer)
|
|
|
|
defer bufPool.Put(wr)
|
2017-04-22 22:56:08 +03:00
|
|
|
wr.Reset()
|
2017-04-01 23:00:36 +03:00
|
|
|
|
2017-04-23 21:01:44 +03:00
|
|
|
err = writeFrame(wr, frameTypeAMQP, fr.channel, data)
|
2017-04-22 22:56:08 +03:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2017-04-01 23:00:36 +03:00
|
|
|
|
|
|
|
_, err = c.net.Write(wr.Bytes())
|
2017-04-22 22:56:08 +03:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) txOpen() stateFunc {
|
2017-04-23 04:31:07 +03:00
|
|
|
c.err = c.txPreformative(frame{
|
2017-04-27 06:35:29 +03:00
|
|
|
preformative: &performOpen{
|
2017-04-23 23:16:51 +03:00
|
|
|
ContainerID: randString(),
|
2017-04-23 04:31:07 +03:00
|
|
|
Hostname: c.hostname,
|
|
|
|
MaxFrameSize: c.maxFrameSize,
|
|
|
|
ChannelMax: c.channelMax,
|
|
|
|
},
|
|
|
|
channel: 0,
|
2017-04-22 22:56:08 +03:00
|
|
|
})
|
|
|
|
if c.err != nil {
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
return c.rxOpen
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) rxOpen() stateFunc {
|
2017-04-24 06:24:12 +03:00
|
|
|
fr, err := c.readFrame()
|
2017-04-01 23:00:36 +03:00
|
|
|
if err != nil {
|
2017-04-24 06:24:12 +03:00
|
|
|
c.err = err
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
o, ok := fr.preformative.(*performOpen)
|
2017-04-24 06:24:12 +03:00
|
|
|
if !ok {
|
|
|
|
c.err = fmt.Errorf("unexpected frame type %T", fr.preformative)
|
2017-04-27 06:35:29 +03:00
|
|
|
return nil
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
fmt.Printf("Rx Open: %#v\n", o)
|
|
|
|
|
2017-04-27 06:35:29 +03:00
|
|
|
if o.MaxFrameSize > 0 {
|
|
|
|
c.peerMaxFrameSize = o.MaxFrameSize // TODO: make writer adhere
|
|
|
|
}
|
|
|
|
|
2017-04-23 21:01:44 +03:00
|
|
|
if o.IdleTimeout > 0 {
|
|
|
|
c.idleTimeout = o.IdleTimeout
|
2017-04-17 06:39:31 +03:00
|
|
|
}
|
|
|
|
|
2017-04-01 23:00:36 +03:00
|
|
|
if o.ChannelMax < c.channelMax {
|
|
|
|
c.channelMax = o.ChannelMax
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) protoSASL() stateFunc {
|
|
|
|
if c.saslHandlers == nil {
|
|
|
|
// we don't support SASL
|
|
|
|
c.err = fmt.Errorf("server request SASL, but not configured")
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
fr, err := c.readFrame()
|
2017-04-01 23:00:36 +03:00
|
|
|
if err != nil {
|
|
|
|
c.err = err
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
sm, ok := fr.preformative.(*saslMechanisms)
|
|
|
|
if !ok {
|
|
|
|
c.err = fmt.Errorf("unexpected frame type %T", fr.preformative)
|
2017-04-01 23:00:36 +03:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, mech := range sm.Mechanisms {
|
|
|
|
if state, ok := c.saslHandlers[mech]; ok {
|
|
|
|
return state
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: send some sort of "auth not supported" frame?
|
|
|
|
c.err = fmt.Errorf("no supported auth mechanism (%v)", sm.Mechanisms)
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) saslOutcome() stateFunc {
|
2017-04-24 06:24:12 +03:00
|
|
|
fr, err := c.readFrame()
|
2017-04-01 23:00:36 +03:00
|
|
|
if err != nil {
|
|
|
|
c.err = err
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2017-04-24 06:24:12 +03:00
|
|
|
so, ok := fr.preformative.(*saslOutcome)
|
|
|
|
if !ok {
|
|
|
|
c.err = fmt.Errorf("unexpected frame type %T", fr.preformative)
|
2017-04-27 06:35:29 +03:00
|
|
|
return nil
|
2017-04-01 23:00:36 +03:00
|
|
|
}
|
|
|
|
|
2017-04-23 21:01:44 +03:00
|
|
|
if so.Code != codeSASLOK {
|
2017-04-01 23:00:36 +03:00
|
|
|
c.err = fmt.Errorf("SASL PLAIN auth failed with code %#00x: %s", so.Code, so.AdditionalData)
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
c.saslComplete = true
|
|
|
|
|
|
|
|
return c.negotiateProto
|
|
|
|
}
|
2017-04-24 06:24:12 +03:00
|
|
|
|
|
|
|
var (
	// ErrTimeout is returned when the peer does not respond in time during
	// protocol header exchange or while waiting for a frame.
	ErrTimeout = errors.New("timeout waiting for response")
)
|
|
|
|
|
|
|
|
func (c *Conn) readFrame() (frame, error) {
|
|
|
|
var fr frame
|
|
|
|
select {
|
|
|
|
case fr = <-c.rxFrame:
|
|
|
|
return fr, nil
|
|
|
|
case err := <-c.readErr:
|
|
|
|
return fr, err
|
2017-04-24 07:03:50 +03:00
|
|
|
case p := <-c.rxProto:
|
|
|
|
return fr, errors.Errorf("unexpected protocol header %#v", p)
|
2017-04-24 06:24:12 +03:00
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
return fr, ErrTimeout
|
|
|
|
}
|
|
|
|
}
|