// Mirror of https://github.com/Azure/go-amqp.git (570 lines, 10 KiB, Go).
package amqp
|
|
|
|
import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net"
	"net/url"

	"github.com/pkg/errors"
)
|
|
|
|
// Connection defaults. A new Conn starts with these values and sends them in
// its Open frame; they are lowered if the peer's Open advertises less.
const (
	// initialMaxFrameSize is the starting max frame size, in bytes. It is
	// also the initial length of the connection's receive buffer.
	initialMaxFrameSize = 512

	// initialChannelMax is the starting channel-max value.
	initialChannelMax = 1
)
|
|
|
|
// Opt is a configuration option for a Conn. New applies each Opt to the
// freshly built Conn before protocol negotiation begins; a non-nil error
// aborts construction.
type Opt func(*Conn) error
|
|
|
|
func OptHostname(hostname string) Opt {
|
|
return func(c *Conn) error {
|
|
c.hostname = hostname
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// stateFunc is one step of the connection handshake state machine. Each step
// returns the next step to run, or nil to stop the machine.
type stateFunc func() stateFunc
|
|
|
|
type Conn struct {
|
|
net net.Conn
|
|
|
|
maxFrameSize uint32
|
|
channelMax uint16
|
|
hostname string
|
|
|
|
rxBuf []byte
|
|
err error
|
|
|
|
// SASL
|
|
saslHandlers map[Symbol]stateFunc
|
|
saslComplete bool
|
|
|
|
// mux
|
|
readErr chan error
|
|
rxFrame chan []byte
|
|
txFrame chan *bytes.Buffer
|
|
newSession chan *Session
|
|
delSession chan *Session
|
|
}
|
|
|
|
func Dial(addr string, opts ...Opt) (*Conn, error) {
|
|
u, err := url.Parse(addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var conn net.Conn
|
|
switch u.Scheme {
|
|
case "amqp", "":
|
|
conn, err = net.Dial("tcp", u.Host)
|
|
default:
|
|
return nil, fmt.Errorf("unsupported scheme %q", u.Scheme)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return New(conn, opts...)
|
|
}
|
|
|
|
func New(conn net.Conn, opts ...Opt) (*Conn, error) {
|
|
c := &Conn{
|
|
net: conn,
|
|
maxFrameSize: initialMaxFrameSize,
|
|
channelMax: initialChannelMax,
|
|
readErr: make(chan error),
|
|
rxFrame: make(chan []byte),
|
|
txFrame: make(chan *bytes.Buffer),
|
|
newSession: make(chan *Session),
|
|
delSession: make(chan *Session),
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
if err := opt(c); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
for state := c.negotiateProto; state != nil; {
|
|
state = state()
|
|
}
|
|
|
|
if c.err != nil && c.net != nil {
|
|
c.net.Close()
|
|
}
|
|
|
|
return c, c.err
|
|
}
|
|
|
|
func (c *Conn) Close() error {
|
|
// TODO: shutdown AMQP
|
|
return c.net.Close()
|
|
}
|
|
|
|
func (c *Conn) MaxFrameSize() int {
|
|
return int(c.maxFrameSize)
|
|
}
|
|
|
|
func (c *Conn) ChannelMax() int {
|
|
return int(c.channelMax)
|
|
}
|
|
|
|
type Session struct {
|
|
channel uint16
|
|
remoteChannel uint16
|
|
conn *Conn
|
|
err error
|
|
rx chan *frame
|
|
|
|
newLink chan *link
|
|
}
|
|
|
|
type frame struct {
|
|
header frameHeader
|
|
payload []byte
|
|
}
|
|
|
|
func (s *Session) Close() {
|
|
// TODO: send end preformative
|
|
s.conn.delSession <- s
|
|
}
|
|
|
|
func (s *Session) begin() error {
|
|
begin, err := Marshal(&Begin{
|
|
NextOutgoingID: 0,
|
|
IncomingWindow: 1,
|
|
OutgoingWindow: 1,
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
wr := bufPool.New().(*bytes.Buffer)
|
|
wr.Reset()
|
|
|
|
err = writeFrame(wr, FrameTypeAMQP, s.channel, begin)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.conn.txFrame <- wr
|
|
fr := <-s.rx
|
|
|
|
pType, err := preformativeType(fr.payload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if pType != PreformativeBegin {
|
|
return fmt.Errorf("unexpected begin response: %+v", fr)
|
|
}
|
|
|
|
var resp Begin
|
|
err = Unmarshal(bytes.NewReader(fr.payload), &resp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
fmt.Printf("Begin Resp: %+v", resp)
|
|
// TODO: record negotiated settings
|
|
s.remoteChannel = fr.header.channel
|
|
|
|
return nil
|
|
}
|
|
|
|
type Receiver struct {
|
|
link *link
|
|
}
|
|
|
|
func (s *Session) Receiver(source string) (*Receiver, error) {
|
|
link := <-s.newLink
|
|
attach, err := Marshal(&Attach{
|
|
Name: "ASHJDJKHJA-ASDHJ-ASDHGJH-ASDSAD78Y",
|
|
Handle: link.handle,
|
|
Role: true,
|
|
Source: &Source{
|
|
Address: source,
|
|
ExpiryPolicy: "link-attach",
|
|
},
|
|
Target: &Target{
|
|
Address: "",
|
|
ExpiryPolicy: "link-attach",
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
wr := bufPool.New().(*bytes.Buffer)
|
|
wr.Reset()
|
|
|
|
err = writeFrame(wr, FrameTypeAMQP, s.channel, attach)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.conn.txFrame <- wr
|
|
fr := <-link.rx
|
|
|
|
resp, ok := fr.(*Attach)
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected attach response: %+v", fr)
|
|
}
|
|
|
|
fmt.Printf("Attach Resp: %+v\n", resp)
|
|
fmt.Printf("Attach Source: %+v\n", resp.Source)
|
|
fmt.Printf("Attach Target: %+v\n", resp.Target)
|
|
|
|
r := &Receiver{link: link}
|
|
return r, nil
|
|
}
|
|
|
|
func (c *Conn) Session() (*Session, error) {
|
|
s := <-c.newSession
|
|
if s.err != nil {
|
|
return nil, s.err
|
|
}
|
|
|
|
err := s.begin()
|
|
if err != nil {
|
|
s.Close()
|
|
return nil, err
|
|
}
|
|
|
|
s.newLink = make(chan *link)
|
|
|
|
go s.startMux()
|
|
|
|
return s, nil
|
|
}
|
|
|
|
// link is the session-side record of one link.
type link struct {
	handle uint32           // handle sent in Attach; assigned sequentially by the session mux
	rx     chan interface{} // decoded frames addressed to this handle
}
|
|
|
|
func (s *Session) startMux() {
|
|
links := make(map[uint32]*link)
|
|
nextLink := &link{rx: make(chan interface{})}
|
|
for {
|
|
select {
|
|
case s.newLink <- nextLink:
|
|
fmt.Println("Got new link request")
|
|
links[nextLink.handle] = nextLink
|
|
// TODO: handle max session/wrapping
|
|
nextLink = &link{handle: nextLink.handle + 1, rx: make(chan interface{})}
|
|
case fr := <-s.rx:
|
|
go func() {
|
|
pType, err := preformativeType(fr.payload)
|
|
if err != nil {
|
|
log.Println("error:", err)
|
|
return
|
|
}
|
|
|
|
switch pType {
|
|
case PreformativeAttach:
|
|
var attach Attach
|
|
err = Unmarshal(bytes.NewReader(fr.payload), &attach)
|
|
if err != nil {
|
|
log.Println("error:", err)
|
|
return
|
|
}
|
|
link, ok := links[attach.Handle]
|
|
if ok {
|
|
link.rx <- &attach
|
|
}
|
|
// TODO: error
|
|
default:
|
|
// TODO: error
|
|
fmt.Printf("frame: %#v\n", fr)
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Conn) startMux() {
|
|
go c.connReader()
|
|
|
|
nextSession := &Session{conn: c, rx: make(chan *frame)}
|
|
|
|
// map channel to session
|
|
sessions := make(map[uint16]*Session)
|
|
|
|
fmt.Println("Starting mux")
|
|
|
|
for {
|
|
if c.err != nil {
|
|
panic(c.err) // TODO: graceful close
|
|
}
|
|
|
|
select {
|
|
case err := <-c.readErr:
|
|
fmt.Println("Got read error")
|
|
c.err = err
|
|
|
|
case rawFrame := <-c.rxFrame:
|
|
fmt.Println("Got rxFrame")
|
|
frameHeader, err := parseFrameHeader(rawFrame)
|
|
if err != nil {
|
|
c.err = err
|
|
continue
|
|
}
|
|
|
|
ch, ok := sessions[frameHeader.channel]
|
|
if !ok {
|
|
fmt.Printf("unexpected frame header: %+v", frameHeader)
|
|
continue
|
|
}
|
|
ch.rx <- &frame{header: frameHeader, payload: rawFrame[frameHeader.dataOffsetBytes():]}
|
|
|
|
case c.newSession <- nextSession:
|
|
fmt.Println("Got new session request")
|
|
sessions[nextSession.channel] = nextSession
|
|
// TODO: handle max session/wrapping
|
|
nextSession = &Session{conn: c, channel: nextSession.channel + 1, rx: make(chan *frame)}
|
|
|
|
case s := <-c.delSession:
|
|
fmt.Println("Got delete session request")
|
|
delete(sessions, s.channel)
|
|
|
|
case fr := <-c.txFrame:
|
|
fmt.Printf("Writing: %# 02x\n", fr)
|
|
_, c.err = c.net.Write(fr.Bytes())
|
|
bufPool.Put(fr)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Conn) connReader() {
|
|
for {
|
|
n, err := c.net.Read(c.rxBuf[:c.maxFrameSize]) // TODO: send error on frame too large
|
|
if err != nil {
|
|
c.readErr <- err
|
|
return
|
|
}
|
|
|
|
c.rxFrame <- append([]byte(nil), c.rxBuf[:n]...)
|
|
}
|
|
}
|
|
|
|
/*
|
|
On connection open, we'll need to handle 4 possible scenarios:
|
|
1. Straight into AMQP.
|
|
2. SASL -> AMQP.
|
|
3. TLS -> AMQP.
|
|
4. TLS -> SASL -> AMQP
|
|
*/
|
|
func (c *Conn) negotiateProto() stateFunc {
|
|
switch {
|
|
case c.saslHandlers != nil && !c.saslComplete:
|
|
return c.exchangeProtoHeader(ProtoSASL)
|
|
default:
|
|
return c.exchangeProtoHeader(ProtoAMQP)
|
|
}
|
|
}
|
|
|
|
// Protocol IDs. One of these is sent as the fifth byte of the protocol
// header and compared against the ID in the peer's reply.
const (
	// ProtoAMQP selects plain AMQP.
	ProtoAMQP = 0x0
	// ProtoTLS selects the TLS security layer.
	ProtoTLS = 0x2
	// ProtoSASL selects the SASL security layer.
	ProtoSASL = 0x3
)
|
|
|
|
func (c *Conn) exchangeProtoHeader(proto uint8) stateFunc {
|
|
_, c.err = c.net.Write([]byte{'A', 'M', 'Q', 'P', proto, 1, 0, 0})
|
|
if c.err != nil {
|
|
return nil
|
|
}
|
|
|
|
c.rxBuf = make([]byte, c.maxFrameSize)
|
|
n, err := c.net.Read(c.rxBuf)
|
|
if err != nil {
|
|
c.err = err
|
|
return nil
|
|
}
|
|
|
|
fmt.Printf("Read %d bytes.\n", n)
|
|
|
|
p, err := parseProto(c.rxBuf[:n])
|
|
if err != nil {
|
|
c.err = err
|
|
return nil
|
|
}
|
|
|
|
fmt.Printf("Proto: %s; ProtoID: %d; Version: %d.%d.%d\n",
|
|
p.proto,
|
|
p.protoID,
|
|
p.major,
|
|
p.minor,
|
|
p.revision,
|
|
)
|
|
|
|
if proto != p.protoID {
|
|
c.err = fmt.Errorf("unexpected protocol header %#00x, expected %#00x", p.protoID, proto)
|
|
return nil
|
|
}
|
|
|
|
switch proto {
|
|
case ProtoAMQP:
|
|
return c.txOpen
|
|
case ProtoTLS:
|
|
// TODO
|
|
return nil
|
|
case ProtoSASL:
|
|
return c.protoSASL
|
|
default:
|
|
c.err = fmt.Errorf("unknown protocol ID %#02x", p.protoID)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (c *Conn) txOpen() stateFunc {
|
|
open, err := Marshal(&Open{
|
|
ContainerID: "gopher",
|
|
Hostname: c.hostname,
|
|
MaxFrameSize: c.maxFrameSize,
|
|
ChannelMax: c.channelMax,
|
|
})
|
|
if err != nil {
|
|
c.err = err
|
|
return nil
|
|
}
|
|
|
|
wr := bufPool.New().(*bytes.Buffer)
|
|
wr.Reset()
|
|
defer bufPool.Put(wr)
|
|
|
|
writeFrame(wr, FrameTypeAMQP, 0, open)
|
|
|
|
fmt.Printf("Writing: %# 02x\n", wr.Bytes())
|
|
|
|
_, err = c.net.Write(wr.Bytes())
|
|
if err != nil {
|
|
c.err = errors.Wrapf(err, "writing")
|
|
return nil
|
|
}
|
|
|
|
return c.rxOpen
|
|
}
|
|
|
|
func (c *Conn) rxOpen() stateFunc {
|
|
n, err := c.net.Read(c.rxBuf)
|
|
if err != nil {
|
|
c.err = errors.Wrapf(err, "reading")
|
|
return nil
|
|
}
|
|
|
|
fh, err := parseFrameHeader(c.rxBuf[:n])
|
|
if err != nil {
|
|
c.err = errors.Wrapf(err, "parsing frame header")
|
|
return nil
|
|
}
|
|
|
|
if fh.frameType != FrameTypeAMQP {
|
|
c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
|
|
}
|
|
|
|
var o Open
|
|
err = Unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &o)
|
|
if err != nil {
|
|
c.err = errors.Wrapf(err, "unmarshaling")
|
|
return nil
|
|
}
|
|
|
|
fmt.Printf("Rx Open: %#v\n", o)
|
|
|
|
if o.MaxFrameSize < c.maxFrameSize {
|
|
c.maxFrameSize = o.MaxFrameSize
|
|
}
|
|
if o.ChannelMax < c.channelMax {
|
|
c.channelMax = o.ChannelMax
|
|
}
|
|
|
|
if uint32(len(c.rxBuf)) < c.maxFrameSize {
|
|
c.rxBuf = make([]byte, c.maxFrameSize)
|
|
}
|
|
|
|
go c.startMux()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) protoSASL() stateFunc {
|
|
if c.saslHandlers == nil {
|
|
// we don't support SASL
|
|
c.err = fmt.Errorf("server request SASL, but not configured")
|
|
return nil
|
|
}
|
|
|
|
n, err := c.net.Read(c.rxBuf)
|
|
if err != nil {
|
|
c.err = err
|
|
return nil
|
|
}
|
|
|
|
fh, err := parseFrameHeader(c.rxBuf[:n])
|
|
if err != nil {
|
|
c.err = err
|
|
return nil
|
|
}
|
|
|
|
if fh.frameType != FrameTypeSASL {
|
|
c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
|
|
}
|
|
|
|
var sm SASLMechanisms
|
|
err = Unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &sm)
|
|
if err != nil {
|
|
c.err = err
|
|
return nil
|
|
}
|
|
|
|
for _, mech := range sm.Mechanisms {
|
|
if state, ok := c.saslHandlers[mech]; ok {
|
|
return state
|
|
}
|
|
}
|
|
|
|
// TODO: send some sort of "auth not supported" frame?
|
|
c.err = fmt.Errorf("no supported auth mechanism (%v)", sm.Mechanisms)
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) saslOutcome() stateFunc {
|
|
n, err := c.net.Read(c.rxBuf)
|
|
if err != nil {
|
|
c.err = err
|
|
return nil
|
|
}
|
|
|
|
fh, err := parseFrameHeader(c.rxBuf[:n])
|
|
if err != nil {
|
|
c.err = err
|
|
return nil
|
|
}
|
|
|
|
if fh.frameType != FrameTypeSASL {
|
|
c.err = fmt.Errorf("unexpected frame type %#02x", fh.frameType)
|
|
}
|
|
|
|
var so SASLOutcome
|
|
c.err = Unmarshal(bytes.NewBuffer(c.rxBuf[fh.dataOffsetBytes():n]), &so)
|
|
if c.err != nil {
|
|
return nil
|
|
}
|
|
|
|
if so.Code != CodeSASLOK {
|
|
c.err = fmt.Errorf("SASL PLAIN auth failed with code %#00x: %s", so.Code, so.AdditionalData)
|
|
return nil
|
|
}
|
|
|
|
c.saslComplete = true
|
|
|
|
return c.negotiateProto
|
|
}
|