// go-amqp/conn.go (Go, 600 lines, 11 KiB)

package amqp
import (
	"bytes"
	"crypto/tls"
	"fmt"
	"net"
	"net/url"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pkg/errors"
)
// Connection defaults. New assigns these to a fresh Conn before any ConnOpt
// is applied, so options may override them.
const (
	initialChannelMax   = 1
	initialMaxFrameSize = 512
)
type ConnOpt func(*Conn) error
2017-04-01 23:00:36 +03:00
// ConnHostname sets the hostname sent in the open performative. It is also
// used as the TLS ServerName when the TLS config does not set one and
// certificate verification is enabled.
func ConnHostname(hostname string) ConnOpt {
	opt := func(c *Conn) error {
		c.hostname = hostname
		return nil
	}
	return opt
}
func ConnTLS(enable bool) ConnOpt {
2017-04-23 04:32:50 +03:00
return func(c *Conn) error {
c.tlsNegotiation = enable
return nil
}
}
func ConnTLSConfig(tc *tls.Config) ConnOpt {
return func(c *Conn) error {
c.tlsConfig = tc
c.tlsNegotiation = true
2017-04-23 04:32:50 +03:00
return nil
}
}
2017-04-27 06:35:29 +03:00
// ConnIdleTimeout sets the connection idle timeout, which the connection
// reader uses as its read deadline. New defaults it to one minute.
func ConnIdleTimeout(d time.Duration) ConnOpt {
	setIdle := func(c *Conn) error {
		c.idleTimeout = d
		return nil
	}
	return setIdle
}
func ConnMaxFrameSize(n uint32) ConnOpt {
return func(c *Conn) error {
// TODO: error if 0
c.maxFrameSize = n
return nil
}
}
2017-04-01 23:00:36 +03:00
// stateFunc is one step of the connection negotiation state machine. Each
// step returns the next step to run, or nil when negotiation has finished;
// New checks c.err afterwards to tell success from failure.
type stateFunc func() stateFunc
type Conn struct {
net net.Conn
// TLS
tlsNegotiation bool
tlsComplete bool
tlsConfig *tls.Config
2017-04-01 23:00:36 +03:00
maxFrameSize uint32
channelMax uint16
hostname string
2017-04-17 06:39:31 +03:00
idleTimeout time.Duration
2017-04-01 23:00:36 +03:00
2017-04-27 06:35:29 +03:00
peerMaxFrameSize uint32
// startMux holds errMu from start until shutdown complete
// operations are sequential before startMux is started and
// holding the mutex is not necessary
errMu sync.Mutex
err error
doneClosed int32
done chan struct{}
2017-04-01 23:00:36 +03:00
// SASL
saslHandlers map[Symbol]stateFunc
saslComplete bool
// mux
readErr chan error
2017-04-24 06:24:12 +03:00
rxProto chan proto
rxFrame chan frame
2017-04-01 23:00:36 +03:00
newSession chan *Session
delSession chan *Session
}
func Dial(addr string, opts ...ConnOpt) (*Conn, error) {
2017-04-01 23:00:36 +03:00
u, err := url.Parse(addr)
if err != nil {
return nil, err
}
2017-04-23 04:32:50 +03:00
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
host = u.Host
port = "5672"
}
2017-04-01 23:00:36 +03:00
switch u.Scheme {
2017-04-23 04:32:50 +03:00
case "amqp", "amqps", "":
2017-04-01 23:00:36 +03:00
default:
return nil, fmt.Errorf("unsupported scheme %q", u.Scheme)
}
2017-04-23 04:32:50 +03:00
conn, err := net.Dial("tcp", host+":"+port)
2017-04-01 23:00:36 +03:00
if err != nil {
return nil, err
}
opts = append([]ConnOpt{
ConnHostname(host),
ConnTLS(u.Scheme == "amqps"),
2017-04-23 04:32:50 +03:00
}, opts...)
c, err := New(conn, opts...)
if err != nil {
return nil, err
}
return c, err
2017-04-01 23:00:36 +03:00
}
func New(conn net.Conn, opts ...ConnOpt) (*Conn, error) {
2017-04-01 23:00:36 +03:00
c := &Conn{
2017-04-27 06:35:29 +03:00
net: conn,
maxFrameSize: initialMaxFrameSize,
peerMaxFrameSize: initialMaxFrameSize,
channelMax: initialChannelMax,
idleTimeout: 1 * time.Minute,
done: make(chan struct{}),
readErr: make(chan error, 1),
rxProto: make(chan proto),
rxFrame: make(chan frame),
newSession: make(chan *Session),
delSession: make(chan *Session),
2017-04-01 23:00:36 +03:00
}
for _, opt := range opts {
if err := opt(c); err != nil {
return nil, err
}
}
2017-04-24 06:24:12 +03:00
go c.connReader()
2017-04-01 23:00:36 +03:00
for state := c.negotiateProto; state != nil; {
state = state()
}
2017-04-24 06:24:12 +03:00
if c.err != nil {
2017-04-27 06:35:29 +03:00
c.Close()
2017-04-24 06:24:12 +03:00
return nil, c.err
2017-04-01 23:00:36 +03:00
}
2017-04-24 06:24:12 +03:00
go c.startMux()
return c, nil
2017-04-01 23:00:36 +03:00
}
func (c *Conn) Close() error {
// TODO: shutdown AMQP
2017-04-27 06:35:29 +03:00
c.closeDone()
c.errMu.Lock()
defer c.errMu.Unlock()
2017-04-24 06:24:12 +03:00
err := c.net.Close()
if c.err == nil {
c.err = err
}
return c.err
2017-04-01 23:00:36 +03:00
}
2017-04-27 06:35:29 +03:00
func (c *Conn) closeDone() {
if atomic.CompareAndSwapInt32(&c.doneClosed, 0, 1) {
close(c.done)
}
}
2017-04-01 23:00:36 +03:00
// MaxFrameSize returns the local maximum frame size in bytes, the value
// advertised to the peer in the open performative.
func (c *Conn) MaxFrameSize() int {
	size := c.maxFrameSize
	return int(size)
}
func (c *Conn) ChannelMax() int {
return int(c.channelMax)
}
2017-04-23 20:31:45 +03:00
func (c *Conn) NewSession() (*Session, error) {
2017-04-24 06:24:12 +03:00
var s *Session
select {
case <-c.done:
return nil, c.err
case s = <-c.newSession:
}
2017-04-01 23:00:36 +03:00
2017-04-27 06:35:29 +03:00
err := s.txFrame(&performBegin{
2017-04-01 23:00:36 +03:00
NextOutgoingID: 0,
IncomingWindow: 1,
})
2017-04-24 06:24:12 +03:00
if err != nil {
s.Close()
return nil, err
}
2017-04-01 23:00:36 +03:00
2017-04-24 06:24:12 +03:00
var fr frame
select {
case <-c.done:
return nil, c.err
case fr = <-s.rx:
}
2017-04-27 06:35:29 +03:00
begin, ok := fr.preformative.(*performBegin)
if !ok {
s.Close()
return nil, fmt.Errorf("unexpected begin response: %+v", fr)
2017-04-01 23:00:36 +03:00
}
fmt.Printf("Begin Resp: %+v", begin)
2017-04-01 23:00:36 +03:00
// TODO: record negotiated settings
s.remoteChannel = begin.RemoteChannel
2017-04-01 23:00:36 +03:00
go s.startMux()
return s, nil
2017-04-01 23:00:36 +03:00
}
// keepaliveFrame is an 8-byte, header-only frame: its size field is 8, so
// it carries no body. startMux writes it periodically to keep the
// connection alive.
// NOTE(review): the remaining bytes presumably encode doff=2, type=0 (AMQP)
// and channel 0 per the AMQP 1.0 frame header layout — confirm.
var keepaliveFrame = []byte{
	0x00, 0x00, 0x00, 0x08, // frame size
	0x02, 0x00, // presumably doff, type
	0x00, 0x00, // presumably channel
}
func (c *Conn) startMux() {
nextSession := newSession(c, 0)
2017-04-01 23:00:36 +03:00
// map channel to session
sessions := make(map[uint16]*Session)
2017-04-17 06:39:31 +03:00
keepalive := time.NewTicker(c.idleTimeout / 2)
2017-04-01 23:00:36 +03:00
fmt.Println("Starting mux")
2017-04-27 06:35:29 +03:00
c.errMu.Lock()
defer c.errMu.Unlock()
outer:
2017-04-01 23:00:36 +03:00
for {
if c.err != nil {
2017-04-27 06:35:29 +03:00
c.closeDone()
2017-04-24 06:24:12 +03:00
return
2017-04-01 23:00:36 +03:00
}
select {
2017-04-24 06:24:12 +03:00
case c.err = <-c.readErr:
2017-04-01 23:00:36 +03:00
fmt.Println("Got read error")
case fr := <-c.rxFrame:
ch, ok := sessions[fr.channel]
if !ok {
2017-04-27 06:35:29 +03:00
c.err = errors.Errorf("unexpected frame: %#v", fr.preformative)
continue outer
2017-04-01 23:00:36 +03:00
}
ch.rx <- fr
2017-04-01 23:00:36 +03:00
case c.newSession <- nextSession:
fmt.Println("Got new session request")
sessions[nextSession.channel] = nextSession
// TODO: handle max session/wrapping
nextSession = newSession(c, nextSession.channel+1)
2017-04-01 23:00:36 +03:00
case s := <-c.delSession:
fmt.Println("Got delete session request")
delete(sessions, s.channel)
2017-04-17 06:39:31 +03:00
case <-keepalive.C:
fmt.Printf("Writing: %# 02x\n", keepaliveFrame)
_, c.err = c.net.Write(keepaliveFrame)
2017-04-27 06:35:29 +03:00
case <-c.done:
return
2017-04-01 23:00:36 +03:00
}
}
}
func (c *Conn) connReader() {
buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
2017-04-24 06:24:12 +03:00
rxBuf := make([]byte, c.maxFrameSize)
negotiating := true
2017-04-27 06:35:29 +03:00
idleTimeout := c.idleTimeout
2017-04-01 23:00:36 +03:00
2017-04-27 06:35:29 +03:00
var currentHeader frameHeader
frameInProgress := false
var err error
2017-04-27 06:35:29 +03:00
for {
fmt.Println(frameInProgress, buf.Len())
if frameInProgress || buf.Len() < 8 { // 8 = min size for header
c.net.SetReadDeadline(time.Now().Add(idleTimeout))
n, err := c.net.Read(rxBuf[:c.maxFrameSize]) // TODO: send error on frame too large
if err != nil {
c.readErr <- err
return
2017-04-24 06:24:12 +03:00
}
2017-04-27 06:35:29 +03:00
_, err = buf.Write(rxBuf[:n])
if err != nil {
2017-04-24 06:24:12 +03:00
c.readErr <- err
return
}
2017-04-27 06:35:29 +03:00
}
2017-04-27 06:35:29 +03:00
if buf.Len() < 8 {
continue
}
2017-04-27 06:35:29 +03:00
if negotiating && bytes.Equal(buf.Bytes()[:4], []byte{'A', 'M', 'Q', 'P'}) {
p, err := parseProto(buf)
if err != nil {
2017-04-24 06:24:12 +03:00
c.readErr <- err
return
}
2017-04-27 06:35:29 +03:00
if p.ProtoID == protoAMQP {
negotiating = false
2017-04-24 06:24:12 +03:00
}
2017-04-27 06:35:29 +03:00
// fmt.Printf("GOT: %#v\n", p)
2017-04-24 06:24:12 +03:00
select {
case <-c.done:
return
2017-04-27 06:35:29 +03:00
case c.rxProto <- p:
}
// fmt.Printf("Buf len: %d\n", buf.Len())
continue
}
if !frameInProgress {
currentHeader, err = parseFrameHeader(buf)
if err != nil {
c.readErr <- err
return
}
frameInProgress = true
}
// fmt.Printf("GOT: %#v\n", currentHeader)
if uint64(buf.Len()) < uint64(currentHeader.Size-8) {
continue
}
frameInProgress = false
frameBody := buf.Next(int(currentHeader.Size - 8))
p, err := parseFrame(bytes.NewBuffer(frameBody))
if err != nil {
c.readErr <- err
return
}
fmt.Printf("GOT: %#v\n", p)
if o, ok := p.(*performOpen); ok && o.MaxFrameSize < c.maxFrameSize {
if o.IdleTimeout > 0 && o.IdleTimeout < idleTimeout {
idleTimeout = o.IdleTimeout
2017-04-24 06:24:12 +03:00
}
}
2017-04-27 06:35:29 +03:00
select {
case <-c.done:
return
case c.rxFrame <- frame{channel: currentHeader.Channel, preformative: p}:
}
2017-04-01 23:00:36 +03:00
}
}
/*
On connection open, we'll need to handle 4 possible scenarios:
1. Straight into AMQP.
2. SASL -> AMQP.
3. TLS -> AMQP.
4. TLS -> SASL -> AMQP
*/
func (c *Conn) negotiateProto() stateFunc {
switch {
case c.tlsNegotiation && !c.tlsComplete:
2017-04-23 21:01:44 +03:00
return c.exchangeProtoHeader(protoTLS)
2017-04-01 23:00:36 +03:00
case c.saslHandlers != nil && !c.saslComplete:
2017-04-23 21:01:44 +03:00
return c.exchangeProtoHeader(protoSASL)
2017-04-01 23:00:36 +03:00
default:
2017-04-23 21:01:44 +03:00
return c.exchangeProtoHeader(protoAMQP)
2017-04-01 23:00:36 +03:00
}
}
// Protocol IDs carried in byte 4 of the protocol header that
// exchangeProtoHeader writes: "AMQP", id, 1, 0, 0.
const (
	protoAMQP = 0x0 // plain AMQP
	protoTLS  = 0x2 // TLS negotiation
	protoSASL = 0x3 // SASL negotiation
)
func (c *Conn) exchangeProtoHeader(protoID uint8) stateFunc {
_, c.err = c.net.Write([]byte{'A', 'M', 'Q', 'P', protoID, 1, 0, 0})
2017-04-01 23:00:36 +03:00
if c.err != nil {
return nil
}
2017-04-24 06:24:12 +03:00
var p proto
select {
case p = <-c.rxProto:
case c.err = <-c.readErr:
2017-04-01 23:00:36 +03:00
return nil
case fr := <-c.rxFrame:
c.err = errors.Errorf("unexpected frame %#v", fr)
2017-04-27 06:35:29 +03:00
return nil
2017-04-24 06:24:12 +03:00
case <-time.After(1 * time.Second):
c.err = ErrTimeout
2017-04-01 23:00:36 +03:00
return nil
}
fmt.Printf("Proto: %s; ProtoID: %d; Version: %d.%d.%d\n",
2017-04-24 06:24:12 +03:00
p.Proto,
p.ProtoID,
p.Major,
p.Minor,
p.Revision,
2017-04-01 23:00:36 +03:00
)
2017-04-24 06:24:12 +03:00
if protoID != p.ProtoID {
c.err = fmt.Errorf("unexpected protocol header %#00x, expected %#00x", p.ProtoID, protoID)
2017-04-01 23:00:36 +03:00
return nil
}
2017-04-24 06:24:12 +03:00
switch protoID {
2017-04-23 21:01:44 +03:00
case protoAMQP:
2017-04-01 23:00:36 +03:00
return c.txOpen
2017-04-23 21:01:44 +03:00
case protoTLS:
2017-04-23 04:32:50 +03:00
return c.protoTLS
2017-04-23 21:01:44 +03:00
case protoSASL:
2017-04-01 23:00:36 +03:00
return c.protoSASL
default:
2017-04-24 06:24:12 +03:00
c.err = fmt.Errorf("unknown protocol ID %#02x", p.ProtoID)
2017-04-01 23:00:36 +03:00
return nil
}
}
2017-04-23 04:32:50 +03:00
func (c *Conn) protoTLS() stateFunc {
if c.tlsConfig == nil {
2017-04-24 00:58:59 +03:00
c.tlsConfig = new(tls.Config)
}
if c.tlsConfig.ServerName == "" && !c.tlsConfig.InsecureSkipVerify {
c.tlsConfig.ServerName = c.hostname
}
c.net = tls.Client(c.net, c.tlsConfig)
c.tlsComplete = true
2017-04-23 04:32:50 +03:00
return c.negotiateProto
}
func (c *Conn) txPreformative(fr frame) error {
2017-04-23 20:31:45 +03:00
data, err := marshal(fr.preformative)
2017-04-01 23:00:36 +03:00
if err != nil {
return err
2017-04-01 23:00:36 +03:00
}
wr := bufPool.New().(*bytes.Buffer)
defer bufPool.Put(wr)
wr.Reset()
2017-04-01 23:00:36 +03:00
2017-04-23 21:01:44 +03:00
err = writeFrame(wr, frameTypeAMQP, fr.channel, data)
if err != nil {
return err
}
2017-04-01 23:00:36 +03:00
_, err = c.net.Write(wr.Bytes())
return err
}
func (c *Conn) txOpen() stateFunc {
c.err = c.txPreformative(frame{
2017-04-27 06:35:29 +03:00
preformative: &performOpen{
ContainerID: randString(),
Hostname: c.hostname,
MaxFrameSize: c.maxFrameSize,
ChannelMax: c.channelMax,
},
channel: 0,
})
if c.err != nil {
2017-04-01 23:00:36 +03:00
return nil
}
return c.rxOpen
}
func (c *Conn) rxOpen() stateFunc {
2017-04-24 06:24:12 +03:00
fr, err := c.readFrame()
2017-04-01 23:00:36 +03:00
if err != nil {
2017-04-24 06:24:12 +03:00
c.err = err
2017-04-01 23:00:36 +03:00
return nil
}
2017-04-27 06:35:29 +03:00
o, ok := fr.preformative.(*performOpen)
2017-04-24 06:24:12 +03:00
if !ok {
c.err = fmt.Errorf("unexpected frame type %T", fr.preformative)
2017-04-27 06:35:29 +03:00
return nil
2017-04-01 23:00:36 +03:00
}
fmt.Printf("Rx Open: %#v\n", o)
2017-04-27 06:35:29 +03:00
if o.MaxFrameSize > 0 {
c.peerMaxFrameSize = o.MaxFrameSize // TODO: make writer adhere
}
2017-04-23 21:01:44 +03:00
if o.IdleTimeout > 0 {
c.idleTimeout = o.IdleTimeout
2017-04-17 06:39:31 +03:00
}
2017-04-01 23:00:36 +03:00
if o.ChannelMax < c.channelMax {
c.channelMax = o.ChannelMax
}
return nil
}
func (c *Conn) protoSASL() stateFunc {
if c.saslHandlers == nil {
// we don't support SASL
c.err = fmt.Errorf("server request SASL, but not configured")
return nil
}
2017-04-24 06:24:12 +03:00
fr, err := c.readFrame()
2017-04-01 23:00:36 +03:00
if err != nil {
c.err = err
return nil
}
2017-04-24 06:24:12 +03:00
sm, ok := fr.preformative.(*saslMechanisms)
if !ok {
c.err = fmt.Errorf("unexpected frame type %T", fr.preformative)
2017-04-01 23:00:36 +03:00
return nil
}
for _, mech := range sm.Mechanisms {
if state, ok := c.saslHandlers[mech]; ok {
return state
}
}
// TODO: send some sort of "auth not supported" frame?
c.err = fmt.Errorf("no supported auth mechanism (%v)", sm.Mechanisms)
return nil
}
func (c *Conn) saslOutcome() stateFunc {
2017-04-24 06:24:12 +03:00
fr, err := c.readFrame()
2017-04-01 23:00:36 +03:00
if err != nil {
c.err = err
return nil
}
2017-04-24 06:24:12 +03:00
so, ok := fr.preformative.(*saslOutcome)
if !ok {
c.err = fmt.Errorf("unexpected frame type %T", fr.preformative)
2017-04-27 06:35:29 +03:00
return nil
2017-04-01 23:00:36 +03:00
}
2017-04-23 21:01:44 +03:00
if so.Code != codeSASLOK {
2017-04-01 23:00:36 +03:00
c.err = fmt.Errorf("SASL PLAIN auth failed with code %#00x: %s", so.Code, so.AdditionalData)
return nil
}
c.saslComplete = true
return c.negotiateProto
}
2017-04-24 06:24:12 +03:00
// ErrTimeout is returned when the peer does not respond within one second
// while the connection waits for a protocol header or a negotiation frame.
var ErrTimeout = errors.New("timeout waiting for response")
func (c *Conn) readFrame() (frame, error) {
var fr frame
select {
case fr = <-c.rxFrame:
return fr, nil
case err := <-c.readErr:
return fr, err
case p := <-c.rxProto:
return fr, errors.Errorf("unexpected protocol header %#v", p)
2017-04-24 06:24:12 +03:00
case <-time.After(1 * time.Second):
return fr, ErrTimeout
}
}