// go-amqp/conn.go
//
// 570 lines
// 10 KiB
// Go

package amqp
import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"net"
	"net/url"

	"github.com/pkg/errors"
)
// Connection defaults. New seeds every Conn with these values, txOpen
// advertises them in the outgoing Open, and rxOpen may lower them to the
// peer's values.

// initialMaxFrameSize is the starting maximum frame size, in bytes.
const initialMaxFrameSize = 512

// initialChannelMax is the starting channel-max value.
const initialChannelMax = 1
// Opt configures a Conn. New applies options in order before the protocol
// handshake starts; an option returning a non-nil error aborts New.
type Opt func(*Conn) error

// OptHostname sets the hostname that is sent in the Open performative.
func OptHostname(hostname string) Opt {
	setHostname := func(c *Conn) error {
		c.hostname = hostname
		return nil
	}
	return setHostname
}
// stateFunc is a single step of the connection handshake. Each step returns
// the step to run next, or nil to stop; New runs steps until one returns nil.
type stateFunc func() stateFunc

// Conn is an AMQP connection running over a net.Conn.
type Conn struct {
	net          net.Conn // underlying transport
	maxFrameSize uint32   // frame size limit in bytes; starts at initialMaxFrameSize
	channelMax   uint16   // channel-max; starts at initialChannelMax
	hostname     string   // hostname sent in the Open performative
	rxBuf        []byte   // read buffer, allocated with maxFrameSize bytes
	err          error    // fatal error recorded by handshake states or the mux

	// SASL negotiation.
	saslHandlers map[Symbol]stateFunc // handler state per SASL mechanism
	saslComplete bool                 // set after a successful SASL outcome

	// Channels serviced by the connection mux.
	readErr    chan error         // read errors from connReader
	rxFrame    chan []byte        // raw frames from connReader
	txFrame    chan *bytes.Buffer // encoded frames to write to the network
	newSession chan *Session      // hands out sessions on fresh channels
	delSession chan *Session      // sessions to remove from routing
}
// Dial opens a TCP connection to the AMQP server named by addr and performs
// the protocol handshake (see New).
//
// addr is parsed as a URL, for example "amqp://broker.example.com:5672".
// Only the "amqp" scheme, or no scheme at all, is accepted. When the URL
// names a host but no port, the IANA-registered AMQP port 5672 is used
// instead of failing with a "missing port" error.
func Dial(addr string, opts ...Opt) (*Conn, error) {
	const defaultPort = "5672"

	u, err := url.Parse(addr)
	if err != nil {
		return nil, err
	}
	if u.Scheme != "amqp" && u.Scheme != "" {
		return nil, fmt.Errorf("unsupported scheme %q", u.Scheme)
	}

	host := u.Host
	if host != "" && u.Port() == "" {
		// Hostname strips IPv6 brackets; JoinHostPort adds them back.
		host = net.JoinHostPort(u.Hostname(), defaultPort)
	}

	conn, err := net.Dial("tcp", host)
	if err != nil {
		return nil, err
	}
	return New(conn, opts...)
}
// New runs the AMQP handshake over the already-connected conn and returns
// the resulting Conn.
//
// The options are applied in order before the handshake starts. If an
// option or any handshake step fails, conn is closed and New returns a nil
// *Conn together with the error.
func New(conn net.Conn, opts ...Opt) (*Conn, error) {
	c := &Conn{
		net:          conn,
		maxFrameSize: initialMaxFrameSize,
		channelMax:   initialChannelMax,
		readErr:      make(chan error),
		rxFrame:      make(chan []byte),
		txFrame:      make(chan *bytes.Buffer),
		newSession:   make(chan *Session),
		delSession:   make(chan *Session),
	}

	for _, opt := range opts {
		if err := opt(c); err != nil {
			// The caller handed conn over to us; don't leak it when
			// configuration fails. The close error is secondary.
			if c.net != nil {
				c.net.Close()
			}
			return nil, err
		}
	}

	// Drive the handshake state machine until a state returns nil.
	for state := c.negotiateProto; state != nil; {
		state = state()
	}
	if c.err != nil {
		// Best effort: the handshake error is what gets reported.
		if c.net != nil {
			c.net.Close()
		}
		return nil, c.err
	}
	return c, nil
}
// Close closes the network connection underlying c and returns its error.
//
// TODO: shutdown AMQP
func (c *Conn) Close() error {
	err := c.net.Close()
	return err
}
// MaxFrameSize reports the connection's maximum frame size in bytes. After
// the handshake this is the smaller of our initial value and the peer's.
func (c *Conn) MaxFrameSize() int {
	size := c.maxFrameSize
	return int(size)
}

// ChannelMax reports the connection's channel-max value. After the
// handshake this is the smaller of our initial value and the peer's.
func (c *Conn) ChannelMax() int {
	limit := c.channelMax
	return int(limit)
}
// Session is an AMQP session multiplexed over a Conn.
type Session struct {
	channel       uint16      // local channel number assigned by the Conn mux
	remoteChannel uint16      // channel the peer's Begin arrived on
	conn          *Conn       // owning connection
	err           error       // checked by Conn.Session before use
	rx            chan *frame // frames routed here by the Conn mux
	newLink       chan *link  // fresh links handed out by the session mux
}
// frame is a received frame as routed by the Conn mux: the parsed header
// plus the bytes that follow the header's data offset.
type frame struct {
	header  frameHeader
	payload []byte
}
// Close asks the connection mux to forget this session, after which frames
// on its channel are no longer routed to it. It blocks until the mux
// accepts the request.
func (s *Session) Close() {
	// TODO: send end performative
	conn := s.conn
	conn.delSession <- s
}
// begin opens the session. It sends a Begin performative on s.channel,
// waits on s.rx for the peer's Begin, and records the channel that
// response arrived on in s.remoteChannel.
//
// begin reads s.rx directly, so it must run before the session mux
// (Session.startMux) is started; Conn.Session calls it in that order.
func (s *Session) begin() error {
	begin, err := Marshal(&Begin{
		NextOutgoingID: 0,
		IncomingWindow: 1,
		OutgoingWindow: 1,
	})
	if err != nil {
		return err
	}

	// Use Get, not bufPool.New: calling New directly allocates every time
	// and never reuses the buffers the Conn mux puts back.
	wr := bufPool.Get().(*bytes.Buffer)
	wr.Reset()
	if err := writeFrame(wr, FrameTypeAMQP, s.channel, begin); err != nil {
		bufPool.Put(wr)
		return err
	}
	// Ownership of wr passes to the Conn mux, which writes it and returns
	// it to the pool.
	s.conn.txFrame <- wr

	fr := <-s.rx
	pType, err := preformativeType(fr.payload)
	if err != nil {
		return err
	}
	if pType != PreformativeBegin {
		return fmt.Errorf("unexpected begin response: %+v", fr)
	}

	var resp Begin
	if err := Unmarshal(bytes.NewReader(fr.payload), &resp); err != nil {
		return err
	}
	fmt.Printf("Begin Resp: %+v\n", resp)

	// TODO: record negotiated settings
	s.remoteChannel = fr.header.channel
	return nil
}
// Receiver is a receiving link attached on a Session; Session.Receiver
// creates it.
type Receiver struct {
	link *link // session link whose rx channel delivers this receiver's performatives
}
// Receiver attaches a new receiving link for the source address on this
// session and waits for the peer's Attach response.
//
// NOTE(review): every link is attached with the same fixed name. AMQP link
// names identify a link between two containers, so this is likely to
// conflict once more than one receiver is attached. TODO: generate unique
// names.
func (s *Session) Receiver(source string) (*Receiver, error) {
	// The session mux hands out links with fresh handles.
	l := <-s.newLink

	attach, err := Marshal(&Attach{
		Name:   "ASHJDJKHJA-ASDHJ-ASDHGJH-ASDSAD78Y",
		Handle: l.handle,
		Role:   true,
		Source: &Source{
			Address:      source,
			ExpiryPolicy: "link-attach",
		},
		Target: &Target{
			Address:      "",
			ExpiryPolicy: "link-attach",
		},
	})
	if err != nil {
		return nil, err
	}

	// Use Get, not bufPool.New, so pooled buffers are actually reused.
	wr := bufPool.Get().(*bytes.Buffer)
	wr.Reset()
	if err := writeFrame(wr, FrameTypeAMQP, s.channel, attach); err != nil {
		bufPool.Put(wr)
		return nil, err
	}
	// Ownership of wr passes to the Conn mux, which writes it and returns
	// it to the pool.
	s.conn.txFrame <- wr

	// The session mux delivers the peer's Attach for this handle on l.rx.
	fr := <-l.rx
	resp, ok := fr.(*Attach)
	if !ok {
		return nil, fmt.Errorf("unexpected attach response: %+v", fr)
	}
	fmt.Printf("Attach Resp: %+v\n", resp)
	fmt.Printf("Attach Source: %+v\n", resp.Source)
	fmt.Printf("Attach Target: %+v\n", resp.Target)

	return &Receiver{link: l}, nil
}
// Session creates a new session on the connection.
//
// It takes the next session (and channel) from the connection mux and
// exchanges Begin performatives with the peer. If that fails, it releases
// the session from the mux and returns the error. On success it starts the
// session's own mux.
func (c *Conn) Session() (*Session, error) {
	s := <-c.newSession
	if err := s.err; err != nil {
		return nil, err
	}

	if err := s.begin(); err != nil {
		s.Close()
		return nil, err
	}

	// newLink must exist before the session mux starts serving it.
	s.newLink = make(chan *link)
	go s.startMux()
	return s, nil
}
// link is a session-local endpoint identified by its handle. The session
// mux delivers performatives addressed to that handle on rx.
type link struct {
	handle uint32           // handle assigned by the session mux
	rx     chan interface{} // inbound performatives (e.g. *Attach)
}
// startMux runs the session multiplexer. It hands out new links on
// s.newLink, with consecutive handles starting at 0, and dispatches frames
// received on s.rx to the link they address.
//
// The links map is read and written only by this goroutine. Previously a
// goroutine per frame read it while this loop wrote it, which is a data race.
// Parsing and lookup now happen here, and only the final, possibly blocking,
// delivery runs in a separate goroutine.
//
// TODO: the loop never exits; add a shutdown path.
func (s *Session) startMux() {
	links := make(map[uint32]*link)
	nextLink := &link{rx: make(chan interface{})}
	for {
		select {
		case s.newLink <- nextLink:
			fmt.Println("Got new link request")
			links[nextLink.handle] = nextLink
			// TODO: handle max session/wrapping
			nextLink = &link{handle: nextLink.handle + 1, rx: make(chan interface{})}

		case fr := <-s.rx:
			pType, err := preformativeType(fr.payload)
			if err != nil {
				log.Println("error:", err)
				continue
			}
			switch pType {
			case PreformativeAttach:
				var attach Attach
				if err := Unmarshal(bytes.NewReader(fr.payload), &attach); err != nil {
					log.Println("error:", err)
					continue
				}
				l, ok := links[attach.Handle]
				if !ok {
					// TODO: error (attach for an unknown handle)
					continue
				}
				// Deliver asynchronously so a link that is not currently
				// receiving cannot stall the session mux.
				go func() { l.rx <- &attach }()
			default:
				// TODO: error
				fmt.Printf("frame: %#v\n", fr)
			}
		}
	}
}
// startMux runs the connection multiplexer; its loop has no exit. It starts
// connReader in its own goroutine. Each iteration then services one of: a
// read error, a received frame, a new-session request, a session removal,
// or an outgoing frame.
//
// Sessions are handed out with consecutive channel numbers starting at 0.
// The first nextSession has the zero channel, and each hand-out increments it.
//
// NOTE(review): every error path stores into c.err, and the next iteration
// panics on it. A network or framing error therefore crashes the whole
// process instead of closing the connection.
// NOTE(review): the send to ch.rx is blocking (session rx channels are
// unbuffered). A session that is not reading stalls every other session
// on this connection.
func (c *Conn) startMux() {
go c.connReader()
nextSession := &Session{conn: c, rx: make(chan *frame)}
// map channel to session
sessions := make(map[uint16]*Session)
fmt.Println("Starting mux")
for {
// A previous iteration recorded a fatal error.
if c.err != nil {
panic(c.err) // TODO: graceful close
}
select {
case err := <-c.readErr:
fmt.Println("Got read error")
c.err = err
case rawFrame := <-c.rxFrame:
fmt.Println("Got rxFrame")
frameHeader, err := parseFrameHeader(rawFrame)
if err != nil {
c.err = err
continue
}
// Route by channel number; frames for unknown channels are printed and dropped.
ch, ok := sessions[frameHeader.channel]
if !ok {
fmt.Printf("unexpected frame header: %+v", frameHeader)
continue
}
ch.rx <- &frame{header: frameHeader, payload: rawFrame[frameHeader.dataOffsetBytes():]}
case c.newSession <- nextSession:
fmt.Println("Got new session request")
sessions[nextSession.channel] = nextSession
// TODO: handle max session/wrapping
nextSession = &Session{conn: c, channel: nextSession.channel + 1, rx: make(chan *frame)}
case s := <-c.delSession:
fmt.Println("Got delete session request")
delete(sessions, s.channel)
case fr := <-c.txFrame:
// Write the encoded frame, then return its buffer to the pool.
fmt.Printf("Writing: %# 02x\n", fr)
_, c.err = c.net.Write(fr.Bytes())
bufPool.Put(fr)
}
}
}
// connReader reads AMQP frames from the network and forwards each complete
// frame, header included, to the mux on c.rxFrame. On a read or framing
// error it sends the error on c.readErr and exits.
//
// TCP is a byte stream: a single Read may return part of a frame or several
// frames at once. So each frame is read as a fixed 8-byte header followed by
// the rest of the frame. The length comes from the header's size field
// (AMQP 1.0 section 2.3.1): the first four bytes, big-endian, giving the
// total frame size including the header.
func (c *Conn) connReader() {
	const frameHeaderSize = 8

	buf := c.rxBuf
	for {
		if _, err := io.ReadFull(c.net, buf[:frameHeaderSize]); err != nil {
			c.readErr <- err
			return
		}

		size := binary.BigEndian.Uint32(buf[:4])
		switch {
		case size < frameHeaderSize:
			c.readErr <- fmt.Errorf("invalid frame size %d", size)
			return
		case size > c.maxFrameSize || uint64(size) > uint64(len(buf)):
			c.readErr <- fmt.Errorf("frame size %d exceeds max frame size %d", size, c.maxFrameSize)
			return
		}

		// An empty frame (size == 8, e.g. a heartbeat) has no body to read.
		if _, err := io.ReadFull(c.net, buf[frameHeaderSize:size]); err != nil {
			c.readErr <- err
			return
		}

		// Copy out: buf is reused for the next frame.
		c.rxFrame <- append([]byte(nil), buf[:size]...)
	}
}
/*
On connection open, one of four protocol sequences has to be negotiated:
1. Straight into AMQP.
2. SASL -> AMQP.
3. TLS -> AMQP.
4. TLS -> SASL -> AMQP.
TLS is not implemented yet.
*/
// negotiateProto is the first handshake state. It requests the SASL layer
// when SASL handlers are configured and SASL has not completed yet, and
// plain AMQP otherwise.
func (c *Conn) negotiateProto() stateFunc {
	proto := uint8(ProtoAMQP)
	if c.saslHandlers != nil && !c.saslComplete {
		proto = ProtoSASL
	}
	return c.exchangeProtoHeader(proto)
}
// Protocol IDs sent in the AMQP protocol header. The header is the ASCII
// bytes "AMQP" followed by one of these IDs and the version.
const (
	// ProtoAMQP requests the AMQP protocol itself.
	ProtoAMQP = 0
	// ProtoTLS requests TLS negotiation.
	ProtoTLS = 2
	// ProtoSASL requests the SASL security layer.
	ProtoSASL = 3
)
// exchangeProtoHeader sends the 8-byte AMQP protocol header requesting
// protocol proto, reads the peer's header, and returns the state that
// handles the agreed protocol. A mismatching peer protocol ID is an error.
//
// Exactly one 8-byte header is read. A peer may send its first frame (for
// example sasl-mechanisms or open) right after its header. A plain Read can
// swallow those bytes, leaving the next state waiting forever.
func (c *Conn) exchangeProtoHeader(proto uint8) stateFunc {
	// "AMQP", protocol ID, then version 1.0.0.
	const protoHeaderSize = 8

	_, c.err = c.net.Write([]byte{'A', 'M', 'Q', 'P', proto, 1, 0, 0})
	if c.err != nil {
		return nil
	}

	c.rxBuf = make([]byte, c.maxFrameSize)
	n, err := io.ReadFull(c.net, c.rxBuf[:protoHeaderSize])
	if err != nil {
		c.err = err
		return nil
	}
	fmt.Printf("Read %d bytes.\n", n)

	p, err := parseProto(c.rxBuf[:n])
	if err != nil {
		c.err = err
		return nil
	}
	fmt.Printf("Proto: %s; ProtoID: %d; Version: %d.%d.%d\n",
		p.proto,
		p.protoID,
		p.major,
		p.minor,
		p.revision,
	)

	if proto != p.protoID {
		c.err = fmt.Errorf("unexpected protocol header %#00x, expected %#00x", p.protoID, proto)
		return nil
	}

	switch proto {
	case ProtoAMQP:
		return c.txOpen
	case ProtoTLS:
		// TODO: TLS negotiation. Fail explicitly: returning nil with no
		// error would make New report a successful handshake.
		c.err = fmt.Errorf("TLS negotiation is not supported")
		return nil
	case ProtoSASL:
		return c.protoSASL
	default:
		c.err = fmt.Errorf("unknown protocol ID %#02x", p.protoID)
		return nil
	}
}
// txOpen sends our Open performative on channel 0. The Open advertises the
// container ID, hostname, max frame size and channel-max. txOpen then moves
// on to rxOpen to read the peer's Open.
func (c *Conn) txOpen() stateFunc {
	open, err := Marshal(&Open{
		ContainerID:  "gopher",
		Hostname:     c.hostname,
		MaxFrameSize: c.maxFrameSize,
		ChannelMax:   c.channelMax,
	})
	if err != nil {
		c.err = err
		return nil
	}

	// Use Get, not bufPool.New, so the buffer is actually recycled.
	wr := bufPool.Get().(*bytes.Buffer)
	wr.Reset()
	defer bufPool.Put(wr)

	// A framing error used to be ignored, sending a truncated or empty
	// frame to the peer.
	if err := writeFrame(wr, FrameTypeAMQP, 0, open); err != nil {
		c.err = errors.Wrapf(err, "writing frame")
		return nil
	}
	fmt.Printf("Writing: %# 02x\n", wr.Bytes())

	if _, err := c.net.Write(wr.Bytes()); err != nil {
		c.err = errors.Wrapf(err, "writing")
		return nil
	}
	return c.rxOpen
}
// rxOpen reads the peer's Open performative. It lowers the connection's max
// frame size and channel-max to the peer's values where those are smaller,
// grows rxBuf if needed, and starts the connection mux. It is the last
// handshake state.
func (c *Conn) rxOpen() stateFunc {
	n, err := c.net.Read(c.rxBuf)
	if err != nil {
		c.err = errors.Wrapf(err, "reading")
		return nil
	}
	fh, err := parseFrameHeader(c.rxBuf[:n])
	if err != nil {
		c.err = errors.Wrapf(err, "parsing frame header")
		return nil
	}
	if fh.frameType != FrameTypeAMQP {
		// Stop here. Continuing would start the mux with c.err already set,
		// and the mux panics on a non-nil c.err.
		c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
		return nil
	}

	var o Open
	err = Unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &o)
	if err != nil {
		c.err = errors.Wrapf(err, "unmarshaling")
		return nil
	}
	fmt.Printf("Rx Open: %#v\n", o)

	// AMQP requires max-frame-size >= 512. A zero presumably means the field
	// was absent and decoded as its zero value, i.e. no peer limit. Don't
	// shrink the frame size to nothing in that case.
	if o.MaxFrameSize != 0 && o.MaxFrameSize < c.maxFrameSize {
		c.maxFrameSize = o.MaxFrameSize
	}
	// NOTE(review): a zero ChannelMax is accepted as-is; confirm how
	// Unmarshal treats an absent channel-max field.
	if o.ChannelMax < c.channelMax {
		c.channelMax = o.ChannelMax
	}
	if uint32(len(c.rxBuf)) < c.maxFrameSize {
		c.rxBuf = make([]byte, c.maxFrameSize)
	}

	go c.startMux()
	return nil
}
// protoSASL runs after the peer accepted the SASL protocol header. It reads
// the sasl-mechanisms frame and returns the configured handler state for
// the first offered mechanism, in the peer's order, that has a handler.
func (c *Conn) protoSASL() stateFunc {
	if c.saslHandlers == nil {
		// No SASL handlers were configured for this connection.
		c.err = fmt.Errorf("server request SASL, but not configured")
		return nil
	}

	n, err := c.net.Read(c.rxBuf)
	if err != nil {
		c.err = err
		return nil
	}
	fh, err := parseFrameHeader(c.rxBuf[:n])
	if err != nil {
		c.err = err
		return nil
	}
	if fh.frameType != FrameTypeSASL {
		// Stop here. Continuing would hand control to a SASL handler with
		// c.err already set.
		c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
		return nil
	}

	var sm SASLMechanisms
	err = Unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &sm)
	if err != nil {
		c.err = err
		return nil
	}

	for _, mech := range sm.Mechanisms {
		if state, ok := c.saslHandlers[mech]; ok {
			return state
		}
	}

	// TODO: send some sort of "auth not supported" frame?
	c.err = fmt.Errorf("no supported auth mechanism (%v)", sm.Mechanisms)
	return nil
}
// saslOutcome reads the peer's sasl-outcome frame. On an OK outcome it marks
// SASL as complete and restarts protocol negotiation, which then requests
// plain AMQP. Any other outcome code is reported as an error.
func (c *Conn) saslOutcome() stateFunc {
	n, err := c.net.Read(c.rxBuf)
	if err != nil {
		c.err = err
		return nil
	}
	fh, err := parseFrameHeader(c.rxBuf[:n])
	if err != nil {
		c.err = err
		return nil
	}
	if fh.frameType != FrameTypeSASL {
		// Return now. Otherwise the Unmarshal result assigned to c.err below
		// would overwrite, and so lose, this error.
		c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
		return nil
	}

	var so SASLOutcome
	c.err = Unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &so)
	if c.err != nil {
		return nil
	}
	if so.Code != CodeSASLOK {
		c.err = fmt.Errorf("SASL PLAIN auth failed with code %#00x: %s", so.Code, so.AdditionalData)
		return nil
	}

	c.saslComplete = true
	return c.negotiateProto
}