go-amqp/conn.go

560 строки
10 KiB
Go
Исходник Обычный вид История

2017-04-01 23:00:36 +03:00
package amqp
import (
"bytes"
2017-04-23 04:32:50 +03:00
"crypto/tls"
2017-04-01 23:00:36 +03:00
"fmt"
"io"
2017-04-01 23:00:36 +03:00
"net"
"net/url"
2017-04-17 06:39:31 +03:00
"time"
"github.com/pkg/errors"
2017-04-01 23:00:36 +03:00
)
// Connection defaults. New starts every Conn with these values and txOpen
// advertises them in the Open frame; rxOpen may lower them afterwards.
const (
	initialMaxFrameSize = 512 // max frame size, in bytes
	initialChannelMax   = 1   // highest channel number
)
type ConnOpt func(*Conn) error
2017-04-01 23:00:36 +03:00
func ConnHostname(hostname string) ConnOpt {
return func(c *Conn) error {
c.hostname = hostname
return nil
}
}
func ConnTLS(enable bool) ConnOpt {
2017-04-23 04:32:50 +03:00
return func(c *Conn) error {
c.tlsNegotiation = enable
return nil
}
}
func ConnTLSConfig(tc *tls.Config) ConnOpt {
return func(c *Conn) error {
c.tlsConfig = tc
c.tlsNegotiation = true
2017-04-23 04:32:50 +03:00
return nil
}
}
2017-04-01 23:00:36 +03:00
type stateFunc func() stateFunc
type Conn struct {
net net.Conn
// TLS
tlsNegotiation bool
tlsComplete bool
tlsConfig *tls.Config
2017-04-01 23:00:36 +03:00
maxFrameSize uint32
channelMax uint16
hostname string
2017-04-17 06:39:31 +03:00
idleTimeout time.Duration
2017-04-01 23:00:36 +03:00
rxBuf []byte
err error
// SASL
saslHandlers map[Symbol]stateFunc
saslComplete bool
// mux
readErr chan error
rxFrame chan frame
txFrame chan frame
2017-04-01 23:00:36 +03:00
newSession chan *Session
delSession chan *Session
}
func Dial(addr string, opts ...ConnOpt) (*Conn, error) {
2017-04-01 23:00:36 +03:00
u, err := url.Parse(addr)
if err != nil {
return nil, err
}
2017-04-23 04:32:50 +03:00
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
host = u.Host
port = "5672"
}
2017-04-01 23:00:36 +03:00
switch u.Scheme {
2017-04-23 04:32:50 +03:00
case "amqp", "amqps", "":
2017-04-01 23:00:36 +03:00
default:
return nil, fmt.Errorf("unsupported scheme %q", u.Scheme)
}
2017-04-23 04:32:50 +03:00
conn, err := net.Dial("tcp", host+":"+port)
2017-04-01 23:00:36 +03:00
if err != nil {
return nil, err
}
opts = append([]ConnOpt{
ConnHostname(host),
ConnTLS(u.Scheme == "amqps"),
2017-04-23 04:32:50 +03:00
}, opts...)
c, err := New(conn, opts...)
if err != nil {
return nil, err
}
return c, err
2017-04-01 23:00:36 +03:00
}
func New(conn net.Conn, opts ...ConnOpt) (*Conn, error) {
2017-04-01 23:00:36 +03:00
c := &Conn{
net: conn,
maxFrameSize: initialMaxFrameSize,
channelMax: initialChannelMax,
2017-04-17 06:39:31 +03:00
idleTimeout: 1 * time.Minute,
2017-04-01 23:00:36 +03:00
readErr: make(chan error),
rxFrame: make(chan frame),
txFrame: make(chan frame),
2017-04-01 23:00:36 +03:00
newSession: make(chan *Session),
delSession: make(chan *Session),
}
for _, opt := range opts {
if err := opt(c); err != nil {
return nil, err
}
}
for state := c.negotiateProto; state != nil; {
state = state()
}
if c.err != nil && c.net != nil {
c.net.Close()
}
return c, c.err
}
// Close closes the underlying network connection and returns its error.
//
// TODO: shutdown AMQP
func (c *Conn) Close() error {
	err := c.net.Close()
	return err
}
// MaxFrameSize reports the connection's maximum frame size in bytes: the
// local default, lowered to the peer's value if its Open frame was smaller.
func (c *Conn) MaxFrameSize() int {
	size := c.maxFrameSize
	return int(size)
}

// ChannelMax reports the connection's channel-max: the local default,
// lowered to the peer's value if its Open frame was smaller.
func (c *Conn) ChannelMax() int {
	limit := c.channelMax
	return int(limit)
}
2017-04-23 20:31:45 +03:00
func (c *Conn) NewSession() (*Session, error) {
s := <-c.newSession
if s.err != nil {
return nil, s.err
}
2017-04-01 23:00:36 +03:00
2017-04-23 20:31:45 +03:00
s.txFrame(&performativeBegin{
2017-04-01 23:00:36 +03:00
NextOutgoingID: 0,
IncomingWindow: 1,
})
fr := <-s.rx
2017-04-23 20:31:45 +03:00
begin, ok := fr.preformative.(*performativeBegin)
if !ok {
s.Close()
return nil, fmt.Errorf("unexpected begin response: %+v", fr)
2017-04-01 23:00:36 +03:00
}
fmt.Printf("Begin Resp: %+v", begin)
2017-04-01 23:00:36 +03:00
// TODO: record negotiated settings
s.remoteChannel = begin.RemoteChannel
2017-04-01 23:00:36 +03:00
go s.startMux()
return s, nil
2017-04-01 23:00:36 +03:00
}
2017-04-23 21:01:44 +03:00
func parseFrame(payload []byte) (preformative, error) {
pType, err := preformativeType(payload)
2017-04-03 02:39:48 +03:00
if err != nil {
2017-04-01 23:00:36 +03:00
return nil, err
}
2017-04-23 21:01:44 +03:00
var t preformative
switch pType {
2017-04-23 21:01:44 +03:00
case preformativeOpen:
2017-04-24 00:58:59 +03:00
t = new(performativeOpen)
2017-04-23 21:01:44 +03:00
case preformativeBegin:
2017-04-24 00:58:59 +03:00
t = new(performativeBegin)
2017-04-23 21:01:44 +03:00
case preformativeAttach:
2017-04-24 00:58:59 +03:00
t = new(performativeAttach)
2017-04-23 21:01:44 +03:00
case preformativeFlow:
2017-04-24 00:58:59 +03:00
t = new(flow)
2017-04-23 21:01:44 +03:00
case preformativeTransfer:
2017-04-24 00:58:59 +03:00
t = new(performativeTransfer)
2017-04-23 21:01:44 +03:00
case preformativeDisposition:
2017-04-24 00:58:59 +03:00
t = new(performativeDisposition)
2017-04-23 21:01:44 +03:00
case preformativeDetach:
2017-04-24 00:58:59 +03:00
t = new(performativeDetach)
2017-04-23 21:01:44 +03:00
case preformativeEnd:
2017-04-24 00:58:59 +03:00
t = new(performativeEnd)
2017-04-23 21:01:44 +03:00
case preformativeClose:
2017-04-24 00:58:59 +03:00
t = new(performativeClose)
default:
return nil, errors.Errorf("unknown preformative type %0x", pType)
}
2017-04-23 20:31:45 +03:00
err = unmarshal(bytes.NewReader(payload), t)
return t, err
2017-04-01 23:00:36 +03:00
}
type frame struct {
channel uint16
2017-04-23 21:01:44 +03:00
preformative preformative
}
var keepaliveFrame = []byte{0x00, 0x00, 0x00, 0x08, 0x02, 0x00, 0x00, 0x00}
2017-04-01 23:00:36 +03:00
func (c *Conn) startMux() {
go c.connReader()
nextSession := newSession(c, 0)
2017-04-01 23:00:36 +03:00
// map channel to session
sessions := make(map[uint16]*Session)
2017-04-17 06:39:31 +03:00
keepalive := time.NewTicker(c.idleTimeout / 2)
2017-04-01 23:00:36 +03:00
fmt.Println("Starting mux")
outer:
2017-04-01 23:00:36 +03:00
for {
if c.err != nil {
panic(c.err) // TODO: graceful close
}
select {
case err := <-c.readErr:
fmt.Println("Got read error")
c.err = err
case fr := <-c.rxFrame:
ch, ok := sessions[fr.channel]
if !ok {
c.err = errors.Errorf("unexpected frame: %+v", fr)
continue outer
2017-04-01 23:00:36 +03:00
}
ch.rx <- fr
2017-04-01 23:00:36 +03:00
case c.newSession <- nextSession:
fmt.Println("Got new session request")
sessions[nextSession.channel] = nextSession
// TODO: handle max session/wrapping
nextSession = newSession(c, nextSession.channel+1)
2017-04-01 23:00:36 +03:00
case s := <-c.delSession:
fmt.Println("Got delete session request")
delete(sessions, s.channel)
case fr := <-c.txFrame:
fmt.Printf("Writing: %d; %+v\n", fr.channel, fr.preformative)
c.err = c.txPreformative(fr)
2017-04-17 06:39:31 +03:00
case <-keepalive.C:
fmt.Printf("Writing: %# 02x\n", keepaliveFrame)
_, c.err = c.net.Write(keepaliveFrame)
2017-04-01 23:00:36 +03:00
}
}
}
func (c *Conn) connReader() {
buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
outer:
2017-04-01 23:00:36 +03:00
for {
n, err := c.net.Read(c.rxBuf[:c.maxFrameSize]) // TODO: send error on frame too large
if err != nil {
c.readErr <- err
return
}
_, err = buf.Write(c.rxBuf[:n])
if err != nil {
c.readErr <- err
return
}
for buf.Len() > 8 { // 8 = min size for header
frameHeader, err := parseFrameHeader(buf.Bytes())
if err != nil {
c.err = err
continue outer
}
if buf.Len() < int(frameHeader.size) {
continue outer
}
payload := make([]byte, frameHeader.size)
_, err = io.ReadFull(buf, payload)
if err != nil {
c.err = err
continue outer
}
preformative, err := parseFrame(payload[8:])
if err != nil {
c.err = err
continue outer
}
c.rxFrame <- frame{channel: frameHeader.channel, preformative: preformative}
}
2017-04-01 23:00:36 +03:00
}
}
/*
On connection open, we'll need to handle 4 possible scenarios:
1. Straight into AMQP.
2. SASL -> AMQP.
3. TLS -> AMQP.
4. TLS -> SASL -> AMQP
*/
func (c *Conn) negotiateProto() stateFunc {
switch {
case c.tlsNegotiation && !c.tlsComplete:
2017-04-23 21:01:44 +03:00
return c.exchangeProtoHeader(protoTLS)
2017-04-01 23:00:36 +03:00
case c.saslHandlers != nil && !c.saslComplete:
2017-04-23 21:01:44 +03:00
return c.exchangeProtoHeader(protoSASL)
2017-04-01 23:00:36 +03:00
default:
2017-04-23 21:01:44 +03:00
return c.exchangeProtoHeader(protoAMQP)
2017-04-01 23:00:36 +03:00
}
}
// Protocol IDs, carried in byte 4 of the "AMQP" protocol header that
// exchangeProtoHeader sends and expects to be echoed back.
const (
	protoAMQP = 0 // plain AMQP
	protoTLS  = 2 // TLS upgrade
	protoSASL = 3 // SASL authentication
)
func (c *Conn) exchangeProtoHeader(proto uint8) stateFunc {
_, c.err = c.net.Write([]byte{'A', 'M', 'Q', 'P', proto, 1, 0, 0})
if c.err != nil {
return nil
}
c.rxBuf = make([]byte, c.maxFrameSize)
n, err := c.net.Read(c.rxBuf)
if err != nil {
c.err = err
return nil
}
fmt.Printf("Read %d bytes.\n", n)
p, err := parseProto(c.rxBuf[:n])
if err != nil {
c.err = err
return nil
}
fmt.Printf("Proto: %s; ProtoID: %d; Version: %d.%d.%d\n",
p.proto,
p.protoID,
p.major,
p.minor,
p.revision,
)
if proto != p.protoID {
c.err = fmt.Errorf("unexpected protocol header %#00x, expected %#00x", p.protoID, proto)
return nil
}
switch proto {
2017-04-23 21:01:44 +03:00
case protoAMQP:
2017-04-01 23:00:36 +03:00
return c.txOpen
2017-04-23 21:01:44 +03:00
case protoTLS:
2017-04-23 04:32:50 +03:00
return c.protoTLS
2017-04-23 21:01:44 +03:00
case protoSASL:
2017-04-01 23:00:36 +03:00
return c.protoSASL
default:
c.err = fmt.Errorf("unknown protocol ID %#02x", p.protoID)
return nil
}
}
2017-04-23 04:32:50 +03:00
func (c *Conn) protoTLS() stateFunc {
if c.tlsConfig == nil {
2017-04-24 00:58:59 +03:00
c.tlsConfig = new(tls.Config)
}
if c.tlsConfig.ServerName == "" && !c.tlsConfig.InsecureSkipVerify {
c.tlsConfig.ServerName = c.hostname
}
c.net = tls.Client(c.net, c.tlsConfig)
c.tlsComplete = true
2017-04-23 04:32:50 +03:00
return c.negotiateProto
}
func (c *Conn) txPreformative(fr frame) error {
2017-04-23 20:31:45 +03:00
data, err := marshal(fr.preformative)
2017-04-01 23:00:36 +03:00
if err != nil {
return err
2017-04-01 23:00:36 +03:00
}
wr := bufPool.New().(*bytes.Buffer)
defer bufPool.Put(wr)
wr.Reset()
2017-04-01 23:00:36 +03:00
2017-04-23 21:01:44 +03:00
err = writeFrame(wr, frameTypeAMQP, fr.channel, data)
if err != nil {
return err
}
2017-04-01 23:00:36 +03:00
_, err = c.net.Write(wr.Bytes())
return err
}
func (c *Conn) txOpen() stateFunc {
c.err = c.txPreformative(frame{
2017-04-23 20:31:45 +03:00
preformative: &performativeOpen{
ContainerID: randString(),
Hostname: c.hostname,
MaxFrameSize: c.maxFrameSize,
ChannelMax: c.channelMax,
},
channel: 0,
})
if c.err != nil {
2017-04-01 23:00:36 +03:00
return nil
}
return c.rxOpen
}
func (c *Conn) rxOpen() stateFunc {
n, err := c.net.Read(c.rxBuf)
if err != nil {
c.err = errors.Wrapf(err, "reading")
2017-04-01 23:00:36 +03:00
return nil
}
fh, err := parseFrameHeader(c.rxBuf[:n])
if err != nil {
c.err = errors.Wrapf(err, "parsing frame header")
2017-04-01 23:00:36 +03:00
return nil
}
2017-04-23 21:01:44 +03:00
if fh.frameType != frameTypeAMQP {
2017-04-01 23:00:36 +03:00
c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
}
2017-04-23 20:31:45 +03:00
var o performativeOpen
err = unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &o)
2017-04-01 23:00:36 +03:00
if err != nil {
c.err = errors.Wrapf(err, "unmarshaling")
2017-04-01 23:00:36 +03:00
return nil
}
fmt.Printf("Rx Open: %#v\n", o)
2017-04-23 21:01:44 +03:00
if o.IdleTimeout > 0 {
c.idleTimeout = o.IdleTimeout
2017-04-17 06:39:31 +03:00
}
2017-04-01 23:00:36 +03:00
if o.MaxFrameSize < c.maxFrameSize {
c.maxFrameSize = o.MaxFrameSize
}
if o.ChannelMax < c.channelMax {
c.channelMax = o.ChannelMax
}
if uint32(len(c.rxBuf)) < c.maxFrameSize {
c.rxBuf = make([]byte, c.maxFrameSize)
}
go c.startMux()
return nil
}
func (c *Conn) protoSASL() stateFunc {
if c.saslHandlers == nil {
// we don't support SASL
c.err = fmt.Errorf("server request SASL, but not configured")
return nil
}
n, err := c.net.Read(c.rxBuf)
if err != nil {
c.err = err
return nil
}
fh, err := parseFrameHeader(c.rxBuf[:n])
if err != nil {
c.err = err
return nil
}
2017-04-23 21:01:44 +03:00
if fh.frameType != frameTypeSASL {
2017-04-01 23:00:36 +03:00
c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
}
2017-04-23 20:31:45 +03:00
var sm saslMechanisms
err = unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &sm)
2017-04-01 23:00:36 +03:00
if err != nil {
c.err = err
return nil
}
for _, mech := range sm.Mechanisms {
if state, ok := c.saslHandlers[mech]; ok {
return state
}
}
// TODO: send some sort of "auth not supported" frame?
c.err = fmt.Errorf("no supported auth mechanism (%v)", sm.Mechanisms)
return nil
}
func (c *Conn) saslOutcome() stateFunc {
n, err := c.net.Read(c.rxBuf)
if err != nil {
c.err = err
return nil
}
fh, err := parseFrameHeader(c.rxBuf[:n])
if err != nil {
c.err = err
return nil
}
2017-04-23 21:01:44 +03:00
if fh.frameType != frameTypeSASL {
2017-04-01 23:00:36 +03:00
c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
}
2017-04-23 20:31:45 +03:00
var so saslOutcome
c.err = unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &so)
2017-04-01 23:00:36 +03:00
if c.err != nil {
return nil
}
2017-04-23 21:01:44 +03:00
if so.Code != codeSASLOK {
2017-04-01 23:00:36 +03:00
c.err = fmt.Errorf("SASL PLAIN auth failed with code %#00x: %s", so.Code, so.AdditionalData)
return nil
}
c.saslComplete = true
return c.negotiateProto
}