Move buffer to internal package (#58)

* Move buffer to internal package

* rename write* to append* to fix vet

it's more accurate to the operation anyways
This commit is contained in:
Joel Hendrix 2021-09-07 19:43:35 -07:00 коммит произвёл GitHub
Родитель edbb7b246e
Коммит feea334a91
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
12 изменённых файлов: 541 добавлений и 507 удалений

36
conn.go
Просмотреть файл

@ -10,6 +10,8 @@ import (
"net"
"sync"
"time"
"github.com/Azure/go-amqp/internal/buffer"
)
// Default connection options
@ -202,8 +204,8 @@ type conn struct {
connReaderRun chan func() // functions to be run by conn reader (set deadline on conn to run)
// connWriter
txFrame chan frame // AMQP frames to be sent by connWriter
txBuf buffer // buffer for marshaling frames before transmitting
txFrame chan frame // AMQP frames to be sent by connWriter
txBuf buffer.Buffer // buffer for marshaling frames before transmitting
txDone chan struct{}
}
@ -441,7 +443,7 @@ func (c *conn) mux() {
func (c *conn) connReader() {
defer close(c.rxDone)
buf := new(buffer)
buf := &buffer.Buffer{}
var (
negotiating = true // true during conn establishment, check for protoHeaders
@ -452,21 +454,21 @@ func (c *conn) connReader() {
for {
switch {
// Cheaply reuse free buffer space when fully read.
case buf.len() == 0:
buf.reset()
case buf.Len() == 0:
buf.Reset()
// Prevent excessive/unbounded growth by shifting data to beginning of buffer.
case int64(buf.i) > int64(c.maxFrameSize):
buf.reclaim()
case int64(buf.Size()) > int64(c.maxFrameSize):
buf.Reclaim()
}
// need to read more if buf doesn't contain the complete frame
// or there's not enough in buf to parse the header
if frameInProgress || buf.len() < frameHeaderSize {
if frameInProgress || buf.Len() < frameHeaderSize {
if c.idleTimeout > 0 {
_ = c.net.SetReadDeadline(time.Now().Add(c.idleTimeout))
}
err := buf.readFromOnce(c.net)
err := buf.ReadFromOnce(c.net)
if err != nil {
select {
// check if error was due to close in progress
@ -487,12 +489,12 @@ func (c *conn) connReader() {
}
// read more if buf doesn't contain enough to parse the header
if buf.len() < frameHeaderSize {
if buf.Len() < frameHeaderSize {
continue
}
// during negotiation, check for proto frames
if negotiating && bytes.Equal(buf.bytes()[:4], []byte{'A', 'M', 'Q', 'P'}) {
if negotiating && bytes.Equal(buf.Bytes()[:4], []byte{'A', 'M', 'Q', 'P'}) {
p, err := parseProtoHeader(buf)
if err != nil {
c.connErr <- err
@ -534,7 +536,7 @@ func (c *conn) connReader() {
bodySize := int64(currentHeader.Size - frameHeaderSize)
// the full frame has been received
if int64(buf.len()) < bodySize {
if int64(buf.Len()) < bodySize {
continue
}
frameInProgress = false
@ -545,13 +547,13 @@ func (c *conn) connReader() {
}
// parse the frame
b, ok := buf.next(bodySize)
b, ok := buf.Next(bodySize)
if !ok {
c.connErr <- io.EOF
return
}
parsedBody, err := parseFrameBody(&buffer{b: b})
parsedBody, err := parseFrameBody(buffer.New(b))
if err != nil {
c.connErr <- err
return
@ -638,20 +640,20 @@ func (c *conn) writeFrame(fr frame) error {
}
// writeFrame into txBuf
c.txBuf.reset()
c.txBuf.Reset()
err := writeFrame(&c.txBuf, fr)
if err != nil {
return err
}
// validate the frame isn't exceeding peer's max frame size
requiredFrameSize := c.txBuf.len()
requiredFrameSize := c.txBuf.Len()
if uint64(requiredFrameSize) > uint64(c.peerMaxFrameSize) {
return fmt.Errorf("%T frame size %d larger than peer's max frame size %d", fr, requiredFrameSize, c.peerMaxFrameSize)
}
// write to network
_, err = c.net.Write(c.txBuf.bytes())
_, err = c.net.Write(c.txBuf.Bytes())
return err
}

210
decode.go
Просмотреть файл

@ -8,13 +8,15 @@ import (
"math"
"reflect"
"time"
"github.com/Azure/go-amqp/internal/buffer"
)
// parseFrameHeader reads the header from r and returns the result.
//
// No validation is done.
func parseFrameHeader(r *buffer) (frameHeader, error) {
buf, ok := r.next(8)
func parseFrameHeader(r *buffer.Buffer) (frameHeader, error) {
buf, ok := r.Next(8)
if !ok {
return frameHeader{}, errors.New("invalid frameHeader")
}
@ -37,9 +39,9 @@ func parseFrameHeader(r *buffer) (frameHeader, error) {
// parseProtoHeader reads the proto header from r and returns the results
//
// An error is returned if the protocol is not "AMQP" or if the version is not 1.0.0.
func parseProtoHeader(r *buffer) (protoHeader, error) {
func parseProtoHeader(r *buffer.Buffer) (protoHeader, error) {
const protoHeaderSize = 8
buf, ok := r.next(protoHeaderSize)
buf, ok := r.Next(protoHeaderSize)
if !ok {
return protoHeader{}, errors.New("invalid protoHeader")
}
@ -63,10 +65,10 @@ func parseProtoHeader(r *buffer) (protoHeader, error) {
}
// peekFrameBodyType peeks at the frame body's type code without advancing r.
func peekFrameBodyType(r *buffer) (amqpType, error) {
payload := r.bytes()
func peekFrameBodyType(r *buffer.Buffer) (amqpType, error) {
payload := r.Bytes()
if r.len() < 3 || payload[0] != 0 || amqpType(payload[1]) != typeCodeSmallUlong {
if r.Len() < 3 || payload[0] != 0 || amqpType(payload[1]) != typeCodeSmallUlong {
return 0, errors.New("invalid frame body header")
}
@ -74,7 +76,7 @@ func peekFrameBodyType(r *buffer) (amqpType, error) {
}
// parseFrameBody reads and unmarshals an AMQP frame.
func parseFrameBody(r *buffer) (frameBody, error) {
func parseFrameBody(r *buffer.Buffer) (frameBody, error) {
pType, err := peekFrameBodyType(r)
if err != nil {
return nil, err
@ -137,7 +139,7 @@ func parseFrameBody(r *buffer) (frameBody, error) {
// unmarshaler is fulfilled by types that can unmarshal
// themselves from AMQP data.
type unmarshaler interface {
unmarshal(r *buffer) error
unmarshal(r *buffer.Buffer) error
}
// unmarshal decodes AMQP encoded data into i.
@ -154,7 +156,7 @@ type unmarshaler interface {
// Common map types (map[string]string, map[Symbol]interface{}, and
// map[interface{}]interface{}), will be decoded via conversion to the mapStringAny,
// mapSymbolAny, and mapAnyAny types.
func unmarshal(r *buffer, i interface{}) error {
func unmarshal(r *buffer.Buffer, i interface{}) error {
if tryReadNull(r) {
return nil
}
@ -301,7 +303,7 @@ func unmarshal(r *buffer, i interface{}) error {
case *map[symbol]interface{}:
return (*mapSymbolAny)(t).unmarshal(r)
case *deliveryState:
type_, err := peekMessageType(r.bytes())
type_, err := peekMessageType(r.Bytes())
if err != nil {
return err
}
@ -355,7 +357,7 @@ func unmarshal(r *buffer, i interface{}) error {
//
// The composite from r will be unmarshaled into zero or more fields. An error
// will be returned if typ does not match the decoded type.
func unmarshalComposite(r *buffer, type_ amqpType, fields ...unmarshalField) error {
func unmarshalComposite(r *buffer.Buffer, type_ amqpType, fields ...unmarshalField) error {
cType, numFields, err := readCompositeHeader(r)
if err != nil {
return err
@ -417,9 +419,19 @@ type unmarshalField struct {
// is null.
type nullHandler func() error
func readType(r *buffer.Buffer) (amqpType, error) {
n, err := r.ReadByte()
return amqpType(n), err
}
func peekType(r *buffer.Buffer) (amqpType, error) {
n, err := r.PeekByte()
return amqpType(n), err
}
// readCompositeHeader reads and consumes the composite header from r.
func readCompositeHeader(r *buffer) (_ amqpType, fields int64, _ error) {
type_, err := r.readType()
func readCompositeHeader(r *buffer.Buffer) (_ amqpType, fields int64, _ error) {
type_, err := readType(r)
if err != nil {
return 0, 0, err
}
@ -441,19 +453,19 @@ func readCompositeHeader(r *buffer) (_ amqpType, fields int64, _ error) {
return amqpType(v), fields, err
}
func readListHeader(r *buffer) (length int64, _ error) {
type_, err := r.readType()
func readListHeader(r *buffer.Buffer) (length int64, _ error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
listLength := r.len()
listLength := r.Len()
switch type_ {
case typeCodeList0:
return 0, nil
case typeCodeList8:
buf, ok := r.next(2)
buf, ok := r.Next(2)
if !ok {
return 0, errors.New("invalid length")
}
@ -465,7 +477,7 @@ func readListHeader(r *buffer) (length int64, _ error) {
}
length = int64(buf[1])
case typeCodeList32:
buf, ok := r.next(8)
buf, ok := r.Next(8)
if !ok {
return 0, errors.New("invalid length")
}
@ -483,17 +495,17 @@ func readListHeader(r *buffer) (length int64, _ error) {
return length, nil
}
func readArrayHeader(r *buffer) (length int64, _ error) {
type_, err := r.readType()
func readArrayHeader(r *buffer.Buffer) (length int64, _ error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
arrayLength := r.len()
arrayLength := r.Len()
switch type_ {
case typeCodeArray8:
buf, ok := r.next(2)
buf, ok := r.Next(2)
if !ok {
return 0, errors.New("invalid length")
}
@ -505,7 +517,7 @@ func readArrayHeader(r *buffer) (length int64, _ error) {
}
length = int64(buf[1])
case typeCodeArray32:
buf, ok := r.next(8)
buf, ok := r.Next(8)
if !ok {
return 0, errors.New("invalid length")
}
@ -522,8 +534,8 @@ func readArrayHeader(r *buffer) (length int64, _ error) {
return length, nil
}
func readString(r *buffer) (string, error) {
type_, err := r.readType()
func readString(r *buffer.Buffer) (string, error) {
type_, err := readType(r)
if err != nil {
return "", err
}
@ -531,13 +543,13 @@ func readString(r *buffer) (string, error) {
var length int64
switch type_ {
case typeCodeStr8, typeCodeSym8:
n, err := r.readByte()
n, err := r.ReadByte()
if err != nil {
return "", err
}
length = int64(n)
case typeCodeStr32, typeCodeSym32:
buf, ok := r.next(4)
buf, ok := r.Next(4)
if !ok {
return "", fmt.Errorf("invalid length for type %#02x", type_)
}
@ -546,15 +558,15 @@ func readString(r *buffer) (string, error) {
return "", fmt.Errorf("type code %#02x is not a recognized string type", type_)
}
buf, ok := r.next(length)
buf, ok := r.Next(length)
if !ok {
return "", errors.New("invalid length")
}
return string(buf), nil
}
func readBinary(r *buffer) ([]byte, error) {
type_, err := r.readType()
func readBinary(r *buffer.Buffer) ([]byte, error) {
type_, err := readType(r)
if err != nil {
return nil, err
}
@ -562,13 +574,13 @@ func readBinary(r *buffer) ([]byte, error) {
var length int64
switch type_ {
case typeCodeVbin8:
n, err := r.readByte()
n, err := r.ReadByte()
if err != nil {
return nil, err
}
length = int64(n)
case typeCodeVbin32:
buf, ok := r.next(4)
buf, ok := r.Next(4)
if !ok {
return nil, fmt.Errorf("invalid length for type %#02x", type_)
}
@ -583,19 +595,19 @@ func readBinary(r *buffer) ([]byte, error) {
return make([]byte, 0), nil
}
buf, ok := r.next(length)
buf, ok := r.Next(length)
if !ok {
return nil, errors.New("invalid length")
}
return append([]byte(nil), buf...), nil
}
func readAny(r *buffer) (interface{}, error) {
func readAny(r *buffer.Buffer) (interface{}, error) {
if tryReadNull(r) {
return nil, nil
}
type_, err := r.peekType()
type_, err := peekType(r)
if err != nil {
return nil, errors.New("invalid length")
}
@ -691,7 +703,7 @@ func readAny(r *buffer) (interface{}, error) {
}
}
func readAnyMap(r *buffer) (interface{}, error) {
func readAnyMap(r *buffer.Buffer) (interface{}, error) {
var m map[interface{}]interface{}
err := (*mapAnyAny)(&m).unmarshal(r)
if err != nil {
@ -730,15 +742,15 @@ Loop:
return m, nil
}
func readAnyList(r *buffer) (interface{}, error) {
func readAnyList(r *buffer.Buffer) (interface{}, error) {
var a []interface{}
err := (*list)(&a).unmarshal(r)
return a, err
}
func readAnyArray(r *buffer) (interface{}, error) {
func readAnyArray(r *buffer.Buffer) (interface{}, error) {
// get the array type
buf := r.bytes()
buf := r.Bytes()
if len(buf) < 1 {
return nil, errors.New("invalid length")
}
@ -826,8 +838,8 @@ func readAnyArray(r *buffer) (interface{}, error) {
}
}
func readComposite(r *buffer) (interface{}, error) {
buf := r.bytes()
func readComposite(r *buffer.Buffer) (interface{}, error) {
buf := r.Bytes()
if len(buf) < 2 {
return nil, errors.New("invalid length for composite")
@ -941,8 +953,8 @@ func readComposite(r *buffer) (interface{}, error) {
}
}
func readTimestamp(r *buffer) (time.Time, error) {
type_, err := r.readType()
func readTimestamp(r *buffer.Buffer) (time.Time, error) {
type_, err := readType(r)
if err != nil {
return time.Time{}, err
}
@ -951,13 +963,13 @@ func readTimestamp(r *buffer) (time.Time, error) {
return time.Time{}, fmt.Errorf("invalid type for timestamp %02x", type_)
}
n, err := r.readUint64()
n, err := r.ReadUint64()
ms := int64(n)
return time.Unix(ms/1000, (ms%1000)*1000000).UTC(), err
}
func readInt(r *buffer) (int, error) {
type_, err := r.peekType()
func readInt(r *buffer.Buffer) (int, error) {
type_, err := peekType(r)
if err != nil {
return 0, err
}
@ -995,44 +1007,44 @@ func readInt(r *buffer) (int, error) {
}
}
func readLong(r *buffer) (int64, error) {
type_, err := r.readType()
func readLong(r *buffer.Buffer) (int64, error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
switch type_ {
case typeCodeSmalllong:
n, err := r.readByte()
n, err := r.ReadByte()
return int64(n), err
case typeCodeLong:
n, err := r.readUint64()
n, err := r.ReadUint64()
return int64(n), err
default:
return 0, fmt.Errorf("invalid type for uint32 %02x", type_)
}
}
func readInt32(r *buffer) (int32, error) {
type_, err := r.readType()
func readInt32(r *buffer.Buffer) (int32, error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
switch type_ {
case typeCodeSmallint:
n, err := r.readByte()
n, err := r.ReadByte()
return int32(n), err
case typeCodeInt:
n, err := r.readUint32()
n, err := r.ReadUint32()
return int32(n), err
default:
return 0, fmt.Errorf("invalid type for int32 %02x", type_)
}
}
func readShort(r *buffer) (int16, error) {
type_, err := r.readType()
func readShort(r *buffer.Buffer) (int16, error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
@ -1041,12 +1053,12 @@ func readShort(r *buffer) (int16, error) {
return 0, fmt.Errorf("invalid type for short %02x", type_)
}
n, err := r.readUint16()
n, err := r.ReadUint16()
return int16(n), err
}
func readSbyte(r *buffer) (int8, error) {
type_, err := r.readType()
func readSbyte(r *buffer.Buffer) (int8, error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
@ -1055,12 +1067,12 @@ func readSbyte(r *buffer) (int8, error) {
return 0, fmt.Errorf("invalid type for int8 %02x", type_)
}
n, err := r.readByte()
n, err := r.ReadByte()
return int8(n), err
}
func readUbyte(r *buffer) (uint8, error) {
type_, err := r.readType()
func readUbyte(r *buffer.Buffer) (uint8, error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
@ -1069,11 +1081,11 @@ func readUbyte(r *buffer) (uint8, error) {
return 0, fmt.Errorf("invalid type for ubyte %02x", type_)
}
return r.readByte()
return r.ReadByte()
}
func readUshort(r *buffer) (uint16, error) {
type_, err := r.readType()
func readUshort(r *buffer.Buffer) (uint16, error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
@ -1082,11 +1094,11 @@ func readUshort(r *buffer) (uint16, error) {
return 0, fmt.Errorf("invalid type for ushort %02x", type_)
}
return r.readUint16()
return r.ReadUint16()
}
func readUint32(r *buffer) (uint32, error) {
type_, err := r.readType()
func readUint32(r *buffer.Buffer) (uint32, error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
@ -1095,17 +1107,17 @@ func readUint32(r *buffer) (uint32, error) {
case typeCodeUint0:
return 0, nil
case typeCodeSmallUint:
n, err := r.readByte()
n, err := r.ReadByte()
return uint32(n), err
case typeCodeUint:
return r.readUint32()
return r.ReadUint32()
default:
return 0, fmt.Errorf("invalid type for uint32 %02x", type_)
}
}
func readUlong(r *buffer) (uint64, error) {
type_, err := r.readType()
func readUlong(r *buffer.Buffer) (uint64, error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
@ -1114,17 +1126,17 @@ func readUlong(r *buffer) (uint64, error) {
case typeCodeUlong0:
return 0, nil
case typeCodeSmallUlong:
n, err := r.readByte()
n, err := r.ReadByte()
return uint64(n), err
case typeCodeUlong:
return r.readUint64()
return r.ReadUint64()
default:
return 0, fmt.Errorf("invalid type for uint32 %02x", type_)
}
}
func readFloat(r *buffer) (float32, error) {
type_, err := r.readType()
func readFloat(r *buffer.Buffer) (float32, error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
@ -1133,12 +1145,12 @@ func readFloat(r *buffer) (float32, error) {
return 0, fmt.Errorf("invalid type for float32 %02x", type_)
}
bits, err := r.readUint32()
bits, err := r.ReadUint32()
return math.Float32frombits(bits), err
}
func readDouble(r *buffer) (float64, error) {
type_, err := r.readType()
func readDouble(r *buffer.Buffer) (float64, error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
@ -1147,19 +1159,19 @@ func readDouble(r *buffer) (float64, error) {
return 0, fmt.Errorf("invalid type for float64 %02x", type_)
}
bits, err := r.readUint64()
bits, err := r.ReadUint64()
return math.Float64frombits(bits), err
}
func readBool(r *buffer) (bool, error) {
type_, err := r.readType()
func readBool(r *buffer.Buffer) (bool, error) {
type_, err := readType(r)
if err != nil {
return false, err
}
switch type_ {
case typeCodeBool:
b, err := r.readByte()
b, err := r.ReadByte()
return b != 0, err
case typeCodeBoolTrue:
return true, nil
@ -1170,8 +1182,8 @@ func readBool(r *buffer) (bool, error) {
}
}
func readUint(r *buffer) (value uint64, _ error) {
type_, err := r.readType()
func readUint(r *buffer.Buffer) (value uint64, _ error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
@ -1180,25 +1192,25 @@ func readUint(r *buffer) (value uint64, _ error) {
case typeCodeUint0, typeCodeUlong0:
return 0, nil
case typeCodeUbyte, typeCodeSmallUint, typeCodeSmallUlong:
n, err := r.readByte()
n, err := r.ReadByte()
return uint64(n), err
case typeCodeUshort:
n, err := r.readUint16()
n, err := r.ReadUint16()
return uint64(n), err
case typeCodeUint:
n, err := r.readUint32()
n, err := r.ReadUint32()
return uint64(n), err
case typeCodeUlong:
return r.readUint64()
return r.ReadUint64()
default:
return 0, fmt.Errorf("type code %#02x is not a recognized number type", type_)
}
}
func readUUID(r *buffer) (UUID, error) {
func readUUID(r *buffer.Buffer) (UUID, error) {
var uuid UUID
type_, err := r.readType()
type_, err := readType(r)
if err != nil {
return uuid, err
}
@ -1207,7 +1219,7 @@ func readUUID(r *buffer) (UUID, error) {
return uuid, fmt.Errorf("type code %#00x is not a UUID", type_)
}
buf, ok := r.next(16)
buf, ok := r.Next(16)
if !ok {
return uuid, errors.New("invalid length")
}
@ -1216,17 +1228,17 @@ func readUUID(r *buffer) (UUID, error) {
return uuid, nil
}
func readMapHeader(r *buffer) (count uint32, _ error) {
type_, err := r.readType()
func readMapHeader(r *buffer.Buffer) (count uint32, _ error) {
type_, err := readType(r)
if err != nil {
return 0, err
}
length := r.len()
length := r.Len()
switch type_ {
case typeCodeMap8:
buf, ok := r.next(2)
buf, ok := r.Next(2)
if !ok {
return 0, errors.New("invalid length")
}
@ -1238,7 +1250,7 @@ func readMapHeader(r *buffer) (count uint32, _ error) {
}
count = uint32(buf[1])
case typeCodeMap32:
buf, ok := r.next(8)
buf, ok := r.Next(8)
if !ok {
return 0, errors.New("invalid length")
}
@ -1253,7 +1265,7 @@ func readMapHeader(r *buffer) (count uint32, _ error) {
return 0, fmt.Errorf("invalid map type %#02x", type_)
}
if int(count) > r.len() {
if int(count) > r.Len() {
return 0, errors.New("invalid length")
}
return count, nil

186
encode.go
Просмотреть файл

@ -7,17 +7,19 @@ import (
"math"
"time"
"unicode/utf8"
"github.com/Azure/go-amqp/internal/buffer"
)
// writesFrame encodes fr into buf.
func writeFrame(buf *buffer, fr frame) error {
func writeFrame(buf *buffer.Buffer, fr frame) error {
// write header
buf.write([]byte{
buf.Append([]byte{
0, 0, 0, 0, // size, overwrite later
2, // doff, see frameHeader.DataOffset comment
fr.type_, // frame type
})
buf.writeUint16(fr.channel) // channel
buf.AppendUint16(fr.channel) // channel
// write AMQP frame body
err := marshal(buf, fr.body)
@ -26,12 +28,12 @@ func writeFrame(buf *buffer, fr frame) error {
}
// validate size
if uint(buf.len()) > math.MaxUint32 {
if uint(buf.Len()) > math.MaxUint32 {
return errors.New("frame too large")
}
// retrieve raw bytes
bufBytes := buf.bytes()
bufBytes := buf.Bytes()
// write correct size
binary.BigEndian.PutUint32(bufBytes, uint32(len(bufBytes)))
@ -39,24 +41,24 @@ func writeFrame(buf *buffer, fr frame) error {
}
type marshaler interface {
marshal(*buffer) error
marshal(*buffer.Buffer) error
}
func marshal(wr *buffer, i interface{}) error {
func marshal(wr *buffer.Buffer, i interface{}) error {
switch t := i.(type) {
case nil:
wr.writeByte(byte(typeCodeNull))
wr.AppendByte(byte(typeCodeNull))
case bool:
if t {
wr.writeByte(byte(typeCodeBoolTrue))
wr.AppendByte(byte(typeCodeBoolTrue))
} else {
wr.writeByte(byte(typeCodeBoolFalse))
wr.AppendByte(byte(typeCodeBoolFalse))
}
case *bool:
if *t {
wr.writeByte(byte(typeCodeBoolTrue))
wr.AppendByte(byte(typeCodeBoolTrue))
} else {
wr.writeByte(byte(typeCodeBoolFalse))
wr.AppendByte(byte(typeCodeBoolFalse))
}
case uint:
writeUint64(wr, uint64(t))
@ -71,18 +73,18 @@ func marshal(wr *buffer, i interface{}) error {
case *uint32:
writeUint32(wr, *t)
case uint16:
wr.writeByte(byte(typeCodeUshort))
wr.writeUint16(t)
wr.AppendByte(byte(typeCodeUshort))
wr.AppendUint16(t)
case *uint16:
wr.writeByte(byte(typeCodeUshort))
wr.writeUint16(*t)
wr.AppendByte(byte(typeCodeUshort))
wr.AppendUint16(*t)
case uint8:
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeUbyte),
t,
})
case *uint8:
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeUbyte),
*t,
})
@ -91,21 +93,21 @@ func marshal(wr *buffer, i interface{}) error {
case *int:
writeInt64(wr, int64(*t))
case int8:
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeByte),
uint8(t),
})
case *int8:
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeByte),
uint8(*t),
})
case int16:
wr.writeByte(byte(typeCodeShort))
wr.writeUint16(uint16(t))
wr.AppendByte(byte(typeCodeShort))
wr.AppendUint16(uint16(t))
case *int16:
wr.writeByte(byte(typeCodeShort))
wr.writeUint16(uint16(*t))
wr.AppendByte(byte(typeCodeShort))
wr.AppendUint16(uint16(*t))
case int32:
writeInt32(wr, t)
case *int32:
@ -222,82 +224,82 @@ func marshal(wr *buffer, i interface{}) error {
return nil
}
func writeInt32(wr *buffer, n int32) {
func writeInt32(wr *buffer.Buffer, n int32) {
if n < 128 && n >= -128 {
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeSmallint),
byte(n),
})
return
}
wr.writeByte(byte(typeCodeInt))
wr.writeUint32(uint32(n))
wr.AppendByte(byte(typeCodeInt))
wr.AppendUint32(uint32(n))
}
func writeInt64(wr *buffer, n int64) {
func writeInt64(wr *buffer.Buffer, n int64) {
if n < 128 && n >= -128 {
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeSmalllong),
byte(n),
})
return
}
wr.writeByte(byte(typeCodeLong))
wr.writeUint64(uint64(n))
wr.AppendByte(byte(typeCodeLong))
wr.AppendUint64(uint64(n))
}
func writeUint32(wr *buffer, n uint32) {
func writeUint32(wr *buffer.Buffer, n uint32) {
if n == 0 {
wr.writeByte(byte(typeCodeUint0))
wr.AppendByte(byte(typeCodeUint0))
return
}
if n < 256 {
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeSmallUint),
byte(n),
})
return
}
wr.writeByte(byte(typeCodeUint))
wr.writeUint32(n)
wr.AppendByte(byte(typeCodeUint))
wr.AppendUint32(n)
}
func writeUint64(wr *buffer, n uint64) {
func writeUint64(wr *buffer.Buffer, n uint64) {
if n == 0 {
wr.writeByte(byte(typeCodeUlong0))
wr.AppendByte(byte(typeCodeUlong0))
return
}
if n < 256 {
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeSmallUlong),
byte(n),
})
return
}
wr.writeByte(byte(typeCodeUlong))
wr.writeUint64(n)
wr.AppendByte(byte(typeCodeUlong))
wr.AppendUint64(n)
}
func writeFloat(wr *buffer, f float32) {
wr.writeByte(byte(typeCodeFloat))
wr.writeUint32(math.Float32bits(f))
func writeFloat(wr *buffer.Buffer, f float32) {
wr.AppendByte(byte(typeCodeFloat))
wr.AppendUint32(math.Float32bits(f))
}
func writeDouble(wr *buffer, f float64) {
wr.writeByte(byte(typeCodeDouble))
wr.writeUint64(math.Float64bits(f))
func writeDouble(wr *buffer.Buffer, f float64) {
wr.AppendByte(byte(typeCodeDouble))
wr.AppendUint64(math.Float64bits(f))
}
func writeTimestamp(wr *buffer, t time.Time) {
wr.writeByte(byte(typeCodeTimestamp))
func writeTimestamp(wr *buffer.Buffer, t time.Time) {
wr.AppendByte(byte(typeCodeTimestamp))
ms := t.UnixNano() / int64(time.Millisecond)
wr.writeUint64(uint64(ms))
wr.AppendUint64(uint64(ms))
}
// marshalField is a field to be marshaled
@ -311,7 +313,7 @@ type marshalField struct {
// The returned bytes include the composite header and fields. Fields with
// omit set to true will be encoded as null or omitted altogether if there are
// no non-null fields after them.
func marshalComposite(wr *buffer, code amqpType, fields []marshalField) error {
func marshalComposite(wr *buffer.Buffer, code amqpType, fields []marshalField) error {
// lastSetIdx is the last index to have a non-omitted field.
// start at -1 as it's possible to have no fields in a composite
lastSetIdx := -1
@ -327,7 +329,7 @@ func marshalComposite(wr *buffer, code amqpType, fields []marshalField) error {
// write header only
if lastSetIdx == -1 {
wr.write([]byte{
wr.Append([]byte{
0x0,
byte(typeCodeSmallUlong),
byte(code),
@ -340,20 +342,20 @@ func marshalComposite(wr *buffer, code amqpType, fields []marshalField) error {
writeDescriptor(wr, code)
// write fields
wr.writeByte(byte(typeCodeList32))
wr.AppendByte(byte(typeCodeList32))
// write temp size, replace later
sizeIdx := wr.len()
wr.write([]byte{0, 0, 0, 0})
preFieldLen := wr.len()
sizeIdx := wr.Len()
wr.Append([]byte{0, 0, 0, 0})
preFieldLen := wr.Len()
// field count
wr.writeUint32(uint32(lastSetIdx + 1))
wr.AppendUint32(uint32(lastSetIdx + 1))
// write null to each index up to lastSetIdx
for _, f := range fields[:lastSetIdx+1] {
if f.omit {
wr.writeByte(byte(typeCodeNull))
wr.AppendByte(byte(typeCodeNull))
continue
}
err := marshal(wr, f.value)
@ -363,22 +365,22 @@ func marshalComposite(wr *buffer, code amqpType, fields []marshalField) error {
}
// fix size
size := uint32(wr.len() - preFieldLen)
buf := wr.bytes()
size := uint32(wr.Len() - preFieldLen)
buf := wr.Bytes()
binary.BigEndian.PutUint32(buf[sizeIdx:], size)
return nil
}
func writeDescriptor(wr *buffer, code amqpType) {
wr.write([]byte{
func writeDescriptor(wr *buffer.Buffer, code amqpType) {
wr.Append([]byte{
0x0,
byte(typeCodeSmallUlong),
byte(code),
})
}
func writeString(wr *buffer, str string) error {
func writeString(wr *buffer.Buffer, str string) error {
if !utf8.ValidString(str) {
return errors.New("not a valid UTF-8 string")
}
@ -387,18 +389,18 @@ func writeString(wr *buffer, str string) error {
switch {
// Str8
case l < 256:
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeStr8),
byte(l),
})
wr.writeString(str)
wr.AppendString(str)
return nil
// Str32
case uint(l) < math.MaxUint32:
wr.writeByte(byte(typeCodeStr32))
wr.writeUint32(uint32(l))
wr.writeString(str)
wr.AppendByte(byte(typeCodeStr32))
wr.AppendUint32(uint32(l))
wr.AppendString(str)
return nil
default:
@ -406,24 +408,24 @@ func writeString(wr *buffer, str string) error {
}
}
func writeBinary(wr *buffer, bin []byte) error {
func writeBinary(wr *buffer.Buffer, bin []byte) error {
l := len(bin)
switch {
// List8
case l < 256:
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeVbin8),
byte(l),
})
wr.write(bin)
wr.Append(bin)
return nil
// List32
case uint(l) < math.MaxUint32:
wr.writeByte(byte(typeCodeVbin32))
wr.writeUint32(uint32(l))
wr.write(bin)
wr.AppendByte(byte(typeCodeVbin32))
wr.AppendUint32(uint32(l))
wr.Append(bin)
return nil
default:
@ -431,9 +433,9 @@ func writeBinary(wr *buffer, bin []byte) error {
}
}
func writeMap(wr *buffer, m interface{}) error {
startIdx := wr.len()
wr.write([]byte{
func writeMap(wr *buffer.Buffer, m interface{}) error {
startIdx := wr.Len()
wr.Append([]byte{
byte(typeCodeMap32), // type
0, 0, 0, 0, // size placeholder
0, 0, 0, 0, // length placeholder
@ -537,10 +539,10 @@ func writeMap(wr *buffer, m interface{}) error {
}
// overwrite placeholder size and length
bytes := wr.bytes()[startIdx+1 : startIdx+9]
bytes := wr.Bytes()[startIdx+1 : startIdx+9]
_ = bytes[7] // bounds check hint
length := wr.len() - startIdx - 1 - 4 // -1 for type, -4 for length
length := wr.Len() - startIdx - 1 - 4 // -1 for type, -4 for length
binary.BigEndian.PutUint32(bytes[:4], uint32(length))
binary.BigEndian.PutUint32(bytes[4:8], uint32(pairs))
@ -553,26 +555,26 @@ const (
array32TLSize = 5
)
func writeArrayHeader(wr *buffer, length, typeSize int, type_ amqpType) {
func writeArrayHeader(wr *buffer.Buffer, length, typeSize int, type_ amqpType) {
size := length * typeSize
// array type
if size+array8TLSize <= math.MaxUint8 {
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeArray8), // type
byte(size + array8TLSize), // size
byte(length), // length
byte(type_), // element type
})
} else {
wr.writeByte(byte(typeCodeArray32)) //type
wr.writeUint32(uint32(size + array32TLSize)) // size
wr.writeUint32(uint32(length)) // length
wr.writeByte(byte(type_)) // element type
wr.AppendByte(byte(typeCodeArray32)) //type
wr.AppendUint32(uint32(size + array32TLSize)) // size
wr.AppendUint32(uint32(length)) // length
wr.AppendByte(byte(type_)) // element type
}
}
func writeVariableArrayHeader(wr *buffer, length, elementsSizeTotal int, type_ amqpType) {
func writeVariableArrayHeader(wr *buffer.Buffer, length, elementsSizeTotal int, type_ amqpType) {
// 0xA_ == 1, 0xB_ == 4
// http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-types-v1.0-os.html#doc-idp82960
elementTypeSize := 1
@ -582,16 +584,16 @@ func writeVariableArrayHeader(wr *buffer, length, elementsSizeTotal int, type_ a
size := elementsSizeTotal + (length * elementTypeSize) // size excluding array length
if size+array8TLSize <= math.MaxUint8 {
wr.write([]byte{
wr.Append([]byte{
byte(typeCodeArray8), // type
byte(size + array8TLSize), // size
byte(length), // length
byte(type_), // element type
})
} else {
wr.writeByte(byte(typeCodeArray32)) // type
wr.writeUint32(uint32(size + array32TLSize)) // size
wr.writeUint32(uint32(length)) // length
wr.writeByte(byte(type_)) // element type
wr.AppendByte(byte(typeCodeArray32)) // type
wr.AppendUint32(uint32(size + array32TLSize)) // size
wr.AppendUint32(uint32(length)) // length
wr.AppendByte(byte(type_)) // element type
}
}

Просмотреть файл

@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/Azure/go-amqp/internal/buffer"
"github.com/Azure/go-amqp/internal/testconn"
"github.com/fortytw2/leaktest"
)
@ -194,8 +195,8 @@ func fuzzUnmarshal(data []byte) int {
}
for _, t := range types {
_ = unmarshal(&buffer{b: data}, t)
_, _ = readAny(&buffer{b: data})
_ = unmarshal(buffer.New(data), t)
_, _ = readAny(buffer.New(data))
}
return 0
}

Просмотреть файл

@ -1,4 +1,4 @@
package amqp
package buffer
import (
"encoding/binary"
@ -6,12 +6,16 @@ import (
)
// buffer is similar to bytes.Buffer but specialized for this package
type buffer struct {
type Buffer struct {
b []byte
i int
}
func (b *buffer) next(n int64) ([]byte, bool) {
func New(b []byte) *Buffer {
return &Buffer{b: b}
}
func (b *Buffer) Next(n int64) ([]byte, bool) {
if b.readCheck(n) {
buf := b.b[b.i:len(b.b)]
b.i = len(b.b)
@ -23,29 +27,29 @@ func (b *buffer) next(n int64) ([]byte, bool) {
return buf, true
}
func (b *buffer) skip(n int) {
func (b *Buffer) Skip(n int) {
b.i += n
}
func (b *buffer) reset() {
func (b *Buffer) Reset() {
b.b = b.b[:0]
b.i = 0
}
// reclaim shifts used buffer space to the beginning of the
// underlying slice.
func (b *buffer) reclaim() {
l := b.len()
func (b *Buffer) Reclaim() {
l := b.Len()
copy(b.b[:l], b.b[b.i:])
b.b = b.b[:l]
b.i = 0
}
func (b *buffer) readCheck(n int64) bool {
func (b *Buffer) readCheck(n int64) bool {
return int64(b.i)+n > int64(len(b.b))
}
func (b *buffer) readByte() (byte, error) {
func (b *Buffer) ReadByte() (byte, error) {
if b.readCheck(1) {
return 0, io.EOF
}
@ -55,20 +59,15 @@ func (b *buffer) readByte() (byte, error) {
return byte_, nil
}
func (b *buffer) readType() (amqpType, error) {
n, err := b.readByte()
return amqpType(n), err
}
func (b *buffer) peekType() (amqpType, error) {
func (b *Buffer) PeekByte() (byte, error) {
if b.readCheck(1) {
return 0, io.EOF
}
return amqpType(b.b[b.i]), nil
return b.b[b.i], nil
}
func (b *buffer) readUint16() (uint16, error) {
func (b *Buffer) ReadUint16() (uint16, error) {
if b.readCheck(2) {
return 0, io.EOF
}
@ -78,7 +77,7 @@ func (b *buffer) readUint16() (uint16, error) {
return n, nil
}
func (b *buffer) readUint32() (uint32, error) {
func (b *Buffer) ReadUint32() (uint32, error) {
if b.readCheck(4) {
return 0, io.EOF
}
@ -88,7 +87,7 @@ func (b *buffer) readUint32() (uint32, error) {
return n, nil
}
func (b *buffer) readUint64() (uint64, error) {
func (b *Buffer) ReadUint64() (uint64, error) {
if b.readCheck(8) {
return 0, io.EOF
}
@ -98,7 +97,7 @@ func (b *buffer) readUint64() (uint64, error) {
return n, nil
}
func (b *buffer) readFromOnce(r io.Reader) error {
func (b *Buffer) ReadFromOnce(r io.Reader) error {
const minRead = 512
l := len(b.b)
@ -117,34 +116,45 @@ func (b *buffer) readFromOnce(r io.Reader) error {
return err
}
func (b *buffer) write(p []byte) {
func (b *Buffer) Append(p []byte) {
b.b = append(b.b, p...)
}
func (b *buffer) writeByte(byte_ byte) {
b.b = append(b.b, byte_)
func (b *Buffer) AppendByte(bb byte) {
b.b = append(b.b, bb)
}
func (b *buffer) writeString(s string) {
func (b *Buffer) AppendString(s string) {
b.b = append(b.b, s...)
}
func (b *buffer) len() int {
func (b *Buffer) Len() int {
return len(b.b) - b.i
}
func (b *buffer) bytes() []byte {
func (b *Buffer) Size() int {
return b.i
}
func (b *Buffer) Bytes() []byte {
return b.b[b.i:]
}
func (b *buffer) writeUint16(n uint16) {
func (b *Buffer) Detach() []byte {
temp := b.b
b.b = nil
b.i = 0
return temp
}
func (b *Buffer) AppendUint16(n uint16) {
b.b = append(b.b,
byte(n>>8),
byte(n),
)
}
func (b *buffer) writeUint32(n uint32) {
func (b *Buffer) AppendUint32(n uint32) {
b.b = append(b.b,
byte(n>>24),
byte(n>>16),
@ -153,7 +163,7 @@ func (b *buffer) writeUint32(n uint32) {
)
}
func (b *buffer) writeUint64(n uint64) {
func (b *Buffer) AppendUint64(n uint64) {
b.b = append(b.b,
byte(n>>56),
byte(n>>48),

12
link.go
Просмотреть файл

@ -7,6 +7,8 @@ import (
"fmt"
"sync"
"sync/atomic"
"github.com/Azure/go-amqp/internal/buffer"
)
// link is a unidirectional route.
@ -63,7 +65,7 @@ type link struct {
messages chan Message // used to send completed messages to receiver
unsettledMessages map[string]struct{} // used to keep track of messages being handled downstream
unsettledMessagesLock sync.RWMutex // lock to protect concurrent access to unsettledMessages
buf buffer // buffered bytes for current message
buf buffer.Buffer // buffered bytes for current message
more bool // if true, buf contains a partial message
msg Message // current message being decoded
}
@ -487,7 +489,7 @@ func (l *link) muxReceive(fr performTransfer) error {
// discard message if it's been aborted
if fr.Aborted {
l.buf.reset()
l.buf.Reset()
l.msg = Message{
doneSignal: make(chan struct{}),
}
@ -496,7 +498,7 @@ func (l *link) muxReceive(fr performTransfer) error {
}
// ensure maxMessageSize will not be exceeded
if l.maxMessageSize != 0 && uint64(l.buf.len())+uint64(len(fr.Payload)) > l.maxMessageSize {
if l.maxMessageSize != 0 && uint64(l.buf.Len())+uint64(len(fr.Payload)) > l.maxMessageSize {
msg := fmt.Sprintf("received message larger than max size of %d", l.maxMessageSize)
l.closeWithError(&Error{
Condition: ErrorMessageSizeExceeded,
@ -506,7 +508,7 @@ func (l *link) muxReceive(fr performTransfer) error {
}
// add the payload the the buffer
l.buf.write(fr.Payload)
l.buf.Append(fr.Payload)
// mark as settled if at least one frame is settled
l.msg.settled = l.msg.settled || fr.Settled
@ -534,7 +536,7 @@ func (l *link) muxReceive(fr performTransfer) error {
debug(1, "deliveryID %d after push to receiver - deliveryCount : %d - linkCredit: %d, len(messages): %d, len(inflight): %d", l.msg.deliveryID, l.deliveryCount, l.linkCredit, len(l.messages), len(l.receiver.inFlight.m))
// reset progress
l.buf.reset()
l.buf.Reset()
l.msg = Message{}
// decrement link-credit after entire message received

Просмотреть файл

@ -11,6 +11,8 @@ import (
"strings"
"testing"
"time"
"github.com/Azure/go-amqp/internal/buffer"
)
var exampleFrames = []struct {
@ -43,7 +45,7 @@ var exampleFrames = []struct {
func TestFrameMarshalUnmarshal(t *testing.T) {
for _, tt := range exampleFrames {
t.Run(tt.label, func(t *testing.T) {
var buf buffer
var buf buffer.Buffer
err := writeFrame(&buf, tt.frame)
if err != nil {
@ -78,15 +80,15 @@ func BenchmarkFrameMarshal(b *testing.B) {
for _, tt := range exampleFrames {
b.Run(tt.label, func(b *testing.B) {
b.ReportAllocs()
var buf buffer
var buf buffer.Buffer
for i := 0; i < b.N; i++ {
err := writeFrame(&buf, tt.frame)
if err != nil {
b.Error(fmt.Sprintf("%+v", err))
}
bytesSink = buf.bytes()
buf.reset()
bytesSink = buf.Bytes()
buf.Reset()
}
})
}
@ -95,19 +97,19 @@ func BenchmarkFrameUnmarshal(b *testing.B) {
for _, tt := range exampleFrames {
b.Run(tt.label, func(b *testing.B) {
b.ReportAllocs()
var buf buffer
var buf buffer.Buffer
err := writeFrame(&buf, tt.frame)
if err != nil {
b.Error(fmt.Sprintf("%+v", err))
}
data := buf.bytes()
buf.reset()
data := buf.Bytes()
buf.Reset()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
buf := &buffer{b: data}
buf := buffer.New(data)
_, err := parseFrameHeader(buf)
if err != nil {
b.Errorf("%+v", err)
@ -128,15 +130,15 @@ func BenchmarkMarshal(b *testing.B) {
for _, typ := range allTypes {
b.Run(fmt.Sprintf("%T", typ), func(b *testing.B) {
b.ReportAllocs()
var buf buffer
var buf buffer.Buffer
for i := 0; i < b.N; i++ {
err := marshal(&buf, typ)
if err != nil {
b.Error(fmt.Sprintf("%+v", err))
}
bytesSink = buf.bytes()
buf.reset()
bytesSink = buf.Bytes()
buf.Reset()
}
})
}
@ -145,19 +147,19 @@ func BenchmarkMarshal(b *testing.B) {
func BenchmarkUnmarshal(b *testing.B) {
for _, type_ := range allTypes {
b.Run(fmt.Sprintf("%T", type_), func(b *testing.B) {
var buf buffer
var buf buffer.Buffer
err := marshal(&buf, type_)
if err != nil {
b.Error(fmt.Sprintf("%+v", err))
}
data := buf.bytes()
data := buf.Bytes()
newType := reflect.New(reflect.TypeOf(type_)).Interface()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
err = unmarshal(&buffer{b: data}, newType)
err = unmarshal(buffer.New(data), newType)
if err != nil {
b.Error(fmt.Sprintf("%v", err))
}
@ -171,7 +173,7 @@ func TestMarshalUnmarshal(t *testing.T) {
for _, type_ := range allTypes {
t.Run(fmt.Sprintf("%T", type_), func(t *testing.T) {
var buf buffer
var buf buffer.Buffer
err := marshal(&buf, type_)
if err != nil {
t.Fatal(fmt.Sprintf("%+v", err))
@ -182,7 +184,7 @@ func TestMarshalUnmarshal(t *testing.T) {
name = strings.TrimPrefix(name, "amqp.")
name = strings.TrimPrefix(name, "*amqp.")
path := filepath.Join("fuzz/marshal/corpus", name)
err = ioutil.WriteFile(path, buf.bytes(), 0644)
err = ioutil.WriteFile(path, buf.Bytes(), 0644)
if err != nil {
t.Error(err)
}
@ -215,7 +217,7 @@ func TestMarshalUnmarshal(t *testing.T) {
// Regression test for time calculation bug.
// https://github.com/vcabbage/amqp/issues/173
func TestIssue173(t *testing.T) {
var buf buffer
var buf buffer.Buffer
// NOTE: Dates after the Unix Epoch don't trigger the bug, only
// dates that negative Unix time show the problem.
want := time.Date(1969, 03, 21, 0, 0, 0, 0, time.UTC)
@ -236,7 +238,7 @@ func TestIssue173(t *testing.T) {
func TestReadAny(t *testing.T) {
for _, type_ := range generalTypes {
t.Run(fmt.Sprintf("%T", type_), func(t *testing.T) {
var buf buffer
var buf buffer.Buffer
err := marshal(&buf, type_)
if err != nil {
t.Errorf("%+v", err)

Просмотреть файл

@ -2,6 +2,8 @@ package amqp
import (
"fmt"
"github.com/Azure/go-amqp/internal/buffer"
)
// SASL Codes
@ -20,11 +22,11 @@ const (
type saslCode uint8
func (s saslCode) marshal(wr *buffer) error {
func (s saslCode) marshal(wr *buffer.Buffer) error {
return marshal(wr, uint8(s))
}
func (s *saslCode) unmarshal(r *buffer) error {
func (s *saslCode) unmarshal(r *buffer.Buffer) error {
n, err := readUbyte(r)
*s = saslCode(n)
return err

Просмотреть файл

@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/Azure/go-amqp/internal/buffer"
"github.com/Azure/go-amqp/internal/testconn"
)
@ -226,15 +227,12 @@ func peerResponse(items ...interface{}) ([]byte, error) {
for _, item := range items {
switch v := item.(type) {
case frame:
b := &buffer{
b: make([]byte, 0),
i: 0,
}
b := &buffer.Buffer{}
e := writeFrame(b, v)
if e != nil {
return buf, e
}
buf = append(buf, b.bytes()...)
buf = append(buf, b.Bytes()...)
case []byte:
buf = append(buf, v...)
default:

Просмотреть файл

@ -6,6 +6,8 @@ import (
"fmt"
"sync"
"sync/atomic"
"github.com/Azure/go-amqp/internal/buffer"
)
// Sender sends messages on a single AMQP link.
@ -13,7 +15,7 @@ type Sender struct {
link *link
mu sync.Mutex // protects buf and nextDeliveryTag
buf buffer
buf buffer.Buffer
nextDeliveryTag uint64
}
@ -65,13 +67,13 @@ func (s *Sender) send(ctx context.Context, msg *Message) (chan deliveryState, er
s.mu.Lock()
defer s.mu.Unlock()
s.buf.reset()
s.buf.Reset()
err := msg.marshal(&s.buf)
if err != nil {
return nil, err
}
if s.link.maxMessageSize != 0 && uint64(s.buf.len()) > s.link.maxMessageSize {
if s.link.maxMessageSize != 0 && uint64(s.buf.Len()) > s.link.maxMessageSize {
return nil, fmt.Errorf("encoded message size exceeds max of %d", s.link.maxMessageSize)
}
@ -95,13 +97,13 @@ func (s *Sender) send(ctx context.Context, msg *Message) (chan deliveryState, er
DeliveryID: &deliveryID,
DeliveryTag: deliveryTag,
MessageFormat: &msg.Format,
More: s.buf.len() > 0,
More: s.buf.Len() > 0,
}
for fr.More {
buf, _ := s.buf.next(maxPayloadSize)
buf, _ := s.buf.Next(maxPayloadSize)
fr.Payload = append([]byte(nil), buf...)
fr.More = s.buf.len() > 0
fr.More = s.buf.Len() > 0
if !fr.More {
// SSM=settled: overrides RSM; no acks.
// SSM=unsettled: sender should wait for receiver to ack

456
types.go

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -4,6 +4,7 @@ import (
"math"
"testing"
"github.com/Azure/go-amqp/internal/buffer"
"github.com/stretchr/testify/require"
)
@ -15,9 +16,9 @@ func TestMarshalArrayInt64AsLongArray(t *testing.T) {
// typeCodeSmalllong (1 byte, signed).
ai := arrayInt64([]int64{math.MaxInt8 + 1})
buff := &buffer{}
buff := &buffer.Buffer{}
require.NoError(t, ai.marshal(buff))
require.EqualValues(t, amqpArrayHeaderLength+8, buff.len(), "Expected an AMQP header (4 bytes) + 8 bytes for a long")
require.EqualValues(t, amqpArrayHeaderLength+8, buff.Len(), "Expected an AMQP header (4 bytes) + 8 bytes for a long")
unmarshalled := arrayInt64{}
require.NoError(t, unmarshalled.unmarshal(buff))
@ -30,9 +31,9 @@ func TestMarshalArrayInt64AsSmallLongArray(t *testing.T) {
// we can save some space.
ai := arrayInt64([]int64{math.MaxInt8, math.MinInt8})
buff := &buffer{}
buff := &buffer.Buffer{}
require.NoError(t, ai.marshal(buff))
require.EqualValues(t, amqpArrayHeaderLength+1+1, buff.len(), "Expected an AMQP header (4 bytes) + 1 byte apiece for the two values")
require.EqualValues(t, amqpArrayHeaderLength+1+1, buff.Len(), "Expected an AMQP header (4 bytes) + 1 byte apiece for the two values")
unmarshalled := arrayInt64{}
require.NoError(t, unmarshalled.unmarshal(buff))