go-amqp/conn.go

515 строки
9.4 KiB
Go
Исходник Обычный вид История

2017-04-01 23:00:36 +03:00
package amqp
import (
"bytes"
"fmt"
"io"
2017-04-01 23:00:36 +03:00
"net"
"net/url"
2017-04-17 06:39:31 +03:00
"time"
"github.com/pkg/errors"
2017-04-01 23:00:36 +03:00
)
// Connection defaults. New copies these into a fresh Conn; txOpen advertises
// them in the Open frame and rxOpen lowers them if the peer's Open carries
// smaller values.

// initialMaxFrameSize is the starting maximum frame size, in bytes.
const initialMaxFrameSize = 512

// initialChannelMax is the starting highest channel number.
const initialChannelMax = 1
type Opt func(*Conn) error
func OptHostname(hostname string) Opt {
return func(c *Conn) error {
c.hostname = hostname
return nil
}
}
2017-04-01 23:00:36 +03:00
type stateFunc func() stateFunc
type Conn struct {
net net.Conn
maxFrameSize uint32
channelMax uint16
hostname string
2017-04-17 06:39:31 +03:00
idleTimeout time.Duration
2017-04-01 23:00:36 +03:00
rxBuf []byte
err error
// SASL
saslHandlers map[Symbol]stateFunc
saslComplete bool
// mux
readErr chan error
rxFrame chan frame
txFrame chan frame
2017-04-01 23:00:36 +03:00
newSession chan *Session
delSession chan *Session
}
// Dial connects to the AMQP server at addr and performs the connection
// handshake via New, applying opts.
//
// addr is parsed as a URL. Only the "amqp" scheme, or no scheme at all
// (e.g. "//host:5672"), is supported; it is dialed over TCP. When the URL
// does not specify a port, the standard AMQP port 5672 is used instead of
// failing with a "missing port" dial error.
func Dial(addr string, opts ...Opt) (*Conn, error) {
	u, err := url.Parse(addr)
	if err != nil {
		return nil, err
	}

	switch u.Scheme {
	case "amqp", "":
		// plain TCP
	default:
		return nil, fmt.Errorf("unsupported scheme %q", u.Scheme)
	}

	// Hostname/Port strip IPv6 brackets, and JoinHostPort re-adds them, so
	// an explicit host:port round-trips unchanged.
	host, port := u.Hostname(), u.Port()
	if port == "" {
		port = "5672"
	}

	conn, err := net.Dial("tcp", net.JoinHostPort(host, port))
	if err != nil {
		return nil, err
	}
	return New(conn, opts...)
}
func New(conn net.Conn, opts ...Opt) (*Conn, error) {
c := &Conn{
net: conn,
maxFrameSize: initialMaxFrameSize,
channelMax: initialChannelMax,
2017-04-17 06:39:31 +03:00
idleTimeout: 1 * time.Minute,
2017-04-01 23:00:36 +03:00
readErr: make(chan error),
rxFrame: make(chan frame),
txFrame: make(chan frame),
2017-04-01 23:00:36 +03:00
newSession: make(chan *Session),
delSession: make(chan *Session),
}
for _, opt := range opts {
if err := opt(c); err != nil {
return nil, err
}
}
for state := c.negotiateProto; state != nil; {
state = state()
}
if c.err != nil && c.net != nil {
c.net.Close()
}
return c, c.err
}
// Close closes the underlying network connection and returns its error.
// The AMQP close handshake is not performed yet.
func (c *Conn) Close() error {
	// TODO: shutdown AMQP
	transport := c.net
	return transport.Close()
}
// MaxFrameSize returns the connection's maximum frame size in bytes: the
// initial value, lowered to the peer's advertised value if that was smaller.
func (c *Conn) MaxFrameSize() int {
	size := c.maxFrameSize
	return int(size)
}
// ChannelMax returns the connection's highest channel number: the initial
// value, lowered to the peer's advertised value if that was smaller.
func (c *Conn) ChannelMax() int {
	highest := c.channelMax
	return int(highest)
}
func (c *Conn) Session() (*Session, error) {
s := <-c.newSession
if s.err != nil {
return nil, s.err
}
2017-04-01 23:00:36 +03:00
s.txFrame(&Begin{
2017-04-01 23:00:36 +03:00
NextOutgoingID: 0,
IncomingWindow: 1,
})
fr := <-s.rx
begin, ok := fr.preformative.(*Begin)
if !ok {
s.Close()
return nil, fmt.Errorf("unexpected begin response: %+v", fr)
2017-04-01 23:00:36 +03:00
}
fmt.Printf("Begin Resp: %+v", begin)
2017-04-01 23:00:36 +03:00
// TODO: record negotiated settings
s.remoteChannel = fr.channel
s.newLink = make(chan *link)
s.delLink = make(chan *link)
2017-04-01 23:00:36 +03:00
go s.startMux()
return s, nil
2017-04-01 23:00:36 +03:00
}
func parseFrame(payload []byte) (Preformative, error) {
pType, err := preformativeType(payload)
2017-04-03 02:39:48 +03:00
if err != nil {
2017-04-01 23:00:36 +03:00
return nil, err
}
var t Preformative
switch pType {
case PreformativeOpen:
t = &Open{}
case PreformativeBegin:
t = &Begin{}
case PreformativeAttach:
t = &Attach{}
case PreformativeFlow:
t = &Flow{}
case PreformativeTransfer:
t = &Transfer{}
case PreformativeDisposition:
t = &Disposition{}
case PreformativeDetach:
t = &Detach{}
case PreformativeEnd:
t = &End{}
case PreformativeClose:
t = &Close{}
default:
return nil, errors.Errorf("unknown preformative type %0x", pType)
}
err = Unmarshal(bytes.NewReader(payload), t)
return t, err
2017-04-01 23:00:36 +03:00
}
// frame is one decoded AMQP frame. connReader produces these on c.rxFrame
// for the mux, which forwards them to the session registered for the
// channel. Frames to be written arrive at the mux on c.txFrame.
type frame struct {
	channel      uint16       // channel number from the frame header
	preformative Preformative // decoded frame body
}
2017-04-01 23:00:36 +03:00
func (c *Conn) startMux() {
go c.connReader()
nextSession := &Session{conn: c, rx: make(chan frame)}
2017-04-01 23:00:36 +03:00
// map channel to session
sessions := make(map[uint16]*Session)
2017-04-17 06:39:31 +03:00
keepalive := time.NewTicker(c.idleTimeout / 2)
var buf bytes.Buffer
writeFrame(&buf, FrameTypeAMQP, 0, nil)
keepaliveFrame := buf.Bytes()
buf.Reset()
2017-04-17 06:39:31 +03:00
2017-04-01 23:00:36 +03:00
fmt.Println("Starting mux")
outer:
2017-04-01 23:00:36 +03:00
for {
if c.err != nil {
panic(c.err) // TODO: graceful close
}
select {
case err := <-c.readErr:
fmt.Println("Got read error")
c.err = err
case fr := <-c.rxFrame:
ch, ok := sessions[fr.channel]
if !ok {
c.err = errors.Errorf("unexpected frame: %+v", fr)
continue outer
2017-04-01 23:00:36 +03:00
}
ch.rx <- fr
2017-04-01 23:00:36 +03:00
case c.newSession <- nextSession:
fmt.Println("Got new session request")
sessions[nextSession.channel] = nextSession
// TODO: handle max session/wrapping
nextSession = &Session{conn: c, channel: nextSession.channel + 1, rx: make(chan frame)}
2017-04-01 23:00:36 +03:00
case s := <-c.delSession:
fmt.Println("Got delete session request")
delete(sessions, s.channel)
case fr := <-c.txFrame:
fmt.Printf("Writing: %d; %+v\n", fr.channel, fr.preformative)
c.err = c.txPreformative(fr)
2017-04-17 06:39:31 +03:00
case <-keepalive.C:
fmt.Printf("Writing: %# 02x\n", keepaliveFrame)
_, c.err = c.net.Write(keepaliveFrame)
2017-04-01 23:00:36 +03:00
}
}
}
func (c *Conn) connReader() {
buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
outer:
2017-04-01 23:00:36 +03:00
for {
n, err := c.net.Read(c.rxBuf[:c.maxFrameSize]) // TODO: send error on frame too large
if err != nil {
c.readErr <- err
return
}
_, err = buf.Write(c.rxBuf[:n])
if err != nil {
c.readErr <- err
return
}
for buf.Len() > 8 { // 8 = min size for header
frameHeader, err := parseFrameHeader(buf.Bytes())
if err != nil {
c.err = err
continue outer
}
if buf.Len() < int(frameHeader.size) {
continue outer
}
payload := make([]byte, frameHeader.size)
_, err = io.ReadFull(buf, payload)
if err != nil {
c.err = err
continue outer
}
preformative, err := parseFrame(payload[8:])
if err != nil {
c.err = err
continue outer
}
c.rxFrame <- frame{channel: frameHeader.channel, preformative: preformative}
}
2017-04-01 23:00:36 +03:00
}
}
// negotiateProto is the first state of the connection-setup state machine.
// It also runs again after a successful SASL outcome. It exchanges the SASL
// protocol header while SASL handlers are configured and authentication has
// not completed; otherwise it exchanges the plain AMQP header.
//
// On connection open, we'll need to handle 4 possible scenarios:
//  1. Straight into AMQP.
//  2. SASL -> AMQP.
//  3. TLS -> AMQP.
//  4. TLS -> SASL -> AMQP
//
// TLS is never requested here yet, so only scenarios 1 and 2 are reachable.
func (c *Conn) negotiateProto() stateFunc {
	if c.saslHandlers == nil || c.saslComplete {
		return c.exchangeProtoHeader(ProtoAMQP)
	}
	return c.exchangeProtoHeader(ProtoSASL)
}
// ProtoIDs: the protocol ID is the fifth byte of the 8-byte protocol header
// ('A', 'M', 'Q', 'P', id, 1, 0, 0) exchanged by exchangeProtoHeader.

// ProtoAMQP is the protocol ID of the plain AMQP header.
const ProtoAMQP = 0

// ProtoTLS is the protocol ID of the TLS header.
const ProtoTLS = 2

// ProtoSASL is the protocol ID of the SASL header.
const ProtoSASL = 3
// exchangeProtoHeader sends the 8-byte protocol header for proto and reads
// the peer's header in reply. If the peer answers with the same protocol ID,
// it returns the state that continues the handshake for that protocol.
// Otherwise it records an error in c.err and returns nil.
//
// Exactly 8 bytes are read. The peer may send its next frame (for example
// sasl-mechanisms or Open) immediately behind its header. A single larger
// Read would swallow those bytes, and the next state would then wait for
// data that never comes.
func (c *Conn) exchangeProtoHeader(proto uint8) stateFunc {
	_, c.err = c.net.Write([]byte{'A', 'M', 'Q', 'P', proto, 1, 0, 0})
	if c.err != nil {
		return nil
	}

	c.rxBuf = make([]byte, c.maxFrameSize)
	n, err := io.ReadFull(c.net, c.rxBuf[:8])
	if err != nil {
		c.err = err
		return nil
	}
	fmt.Printf("Read %d bytes.\n", n)

	p, err := parseProto(c.rxBuf[:n])
	if err != nil {
		c.err = err
		return nil
	}
	fmt.Printf("Proto: %s; ProtoID: %d; Version: %d.%d.%d\n",
		p.proto,
		p.protoID,
		p.major,
		p.minor,
		p.revision,
	)
	if proto != p.protoID {
		c.err = fmt.Errorf("unexpected protocol header %#00x, expected %#00x", p.protoID, proto)
		return nil
	}

	switch proto {
	case ProtoAMQP:
		return c.txOpen
	case ProtoTLS:
		// TODO: implement TLS. Until then, fail explicitly. Returning nil
		// without an error would end setup "successfully" with no mux running.
		c.err = fmt.Errorf("TLS negotiation is not implemented")
		return nil
	case ProtoSASL:
		return c.protoSASL
	default:
		c.err = fmt.Errorf("unknown protocol ID %#02x", p.protoID)
		return nil
	}
}
func (c *Conn) txPreformative(fr frame) error {
data, err := Marshal(fr.preformative)
2017-04-01 23:00:36 +03:00
if err != nil {
return err
2017-04-01 23:00:36 +03:00
}
wr := bufPool.New().(*bytes.Buffer)
defer bufPool.Put(wr)
wr.Reset()
2017-04-01 23:00:36 +03:00
err = writeFrame(wr, FrameTypeAMQP, fr.channel, data)
if err != nil {
return err
}
2017-04-01 23:00:36 +03:00
_, err = c.net.Write(wr.Bytes())
return err
}
func (c *Conn) txOpen() stateFunc {
c.err = c.txPreformative(frame{
preformative: &Open{
ContainerID: "gopher",
Hostname: c.hostname,
MaxFrameSize: c.maxFrameSize,
ChannelMax: c.channelMax,
},
channel: 0,
})
if c.err != nil {
2017-04-01 23:00:36 +03:00
return nil
}
return c.rxOpen
}
func (c *Conn) rxOpen() stateFunc {
n, err := c.net.Read(c.rxBuf)
if err != nil {
c.err = errors.Wrapf(err, "reading")
2017-04-01 23:00:36 +03:00
return nil
}
fh, err := parseFrameHeader(c.rxBuf[:n])
if err != nil {
c.err = errors.Wrapf(err, "parsing frame header")
2017-04-01 23:00:36 +03:00
return nil
}
2017-04-03 02:39:48 +03:00
if fh.frameType != FrameTypeAMQP {
2017-04-01 23:00:36 +03:00
c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
}
var o Open
err = Unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &o)
if err != nil {
c.err = errors.Wrapf(err, "unmarshaling")
2017-04-01 23:00:36 +03:00
return nil
}
fmt.Printf("Rx Open: %#v\n", o)
2017-04-17 06:39:31 +03:00
if o.IdleTimeout.Duration > 0 {
c.idleTimeout = o.IdleTimeout.Duration
}
2017-04-01 23:00:36 +03:00
if o.MaxFrameSize < c.maxFrameSize {
c.maxFrameSize = o.MaxFrameSize
}
if o.ChannelMax < c.channelMax {
c.channelMax = o.ChannelMax
}
if uint32(len(c.rxBuf)) < c.maxFrameSize {
c.rxBuf = make([]byte, c.maxFrameSize)
}
go c.startMux()
return nil
}
func (c *Conn) protoSASL() stateFunc {
if c.saslHandlers == nil {
// we don't support SASL
c.err = fmt.Errorf("server request SASL, but not configured")
return nil
}
n, err := c.net.Read(c.rxBuf)
if err != nil {
c.err = err
return nil
}
fh, err := parseFrameHeader(c.rxBuf[:n])
if err != nil {
c.err = err
return nil
}
2017-04-03 02:39:48 +03:00
if fh.frameType != FrameTypeSASL {
2017-04-01 23:00:36 +03:00
c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
}
var sm SASLMechanisms
err = Unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &sm)
if err != nil {
c.err = err
return nil
}
for _, mech := range sm.Mechanisms {
if state, ok := c.saslHandlers[mech]; ok {
return state
}
}
// TODO: send some sort of "auth not supported" frame?
c.err = fmt.Errorf("no supported auth mechanism (%v)", sm.Mechanisms)
return nil
}
func (c *Conn) saslOutcome() stateFunc {
n, err := c.net.Read(c.rxBuf)
if err != nil {
c.err = err
return nil
}
fh, err := parseFrameHeader(c.rxBuf[:n])
if err != nil {
c.err = err
return nil
}
2017-04-03 02:39:48 +03:00
if fh.frameType != FrameTypeSASL {
2017-04-01 23:00:36 +03:00
c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
}
var so SASLOutcome
c.err = Unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &so)
if c.err != nil {
return nil
}
if so.Code != CodeSASLOK {
c.err = fmt.Errorf("SASL PLAIN auth failed with code %#00x: %s", so.Code, so.AdditionalData)
return nil
}
c.saslComplete = true
return c.negotiateProto
}