Use exact size, if known, to allocate decompression buffer (#3048)

For large messages this generates far less garbage than ioutil.ReadAll().

Implement for gzip - RFC1952 requires it, and the Go implementation
checks it already (modulo 2^32).
This commit is contained in:
Bryan Boreham 2019-10-04 18:05:56 +01:00 коммит произвёл Doug Fawley
Родитель 47d3cfe042
Коммит dcd1c9748d
3 изменённых файлов: 54 добавлений и 15 удалений

Просмотреть файл

@ -46,6 +46,10 @@ type Compressor interface {
// coding header. The result must be static; the result cannot change
// between calls.
Name() string
// EXPERIMENTAL: if a Compressor implements
// DecompressedSize(compressedBytes []byte) int, gRPC will call it
// to determine the size of the buffer allocated for the result of decompression.
// Return -1 to indicate unknown size.
}
var registeredCompressor = make(map[string]Compressor)

Просмотреть файл

@ -23,6 +23,7 @@ package gzip
import (
"compress/gzip"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
@ -107,6 +108,17 @@ func (z *reader) Read(p []byte) (n int, err error) {
return n, err
}
// RFC1952 specifies that the last four bytes "contains the size of
// the original (uncompressed) input data modulo 2^32."
// gRPC has a max message size of 2GB so we don't need to worry about wraparound.
func (c *compressor) DecompressedSize(buf []byte) int {
last := len(buf)
if last < 4 {
return -1
}
return int(binary.LittleEndian.Uint32(buf[last-4 : last]))
}
func (c *compressor) Name() string {
return Name
}

Просмотреть файл

@ -648,35 +648,58 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei
return nil, st.Err()
}
var size int
if pf == compressionMade {
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default.
if dc != nil {
d, err = dc.Do(bytes.NewReader(d))
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
size = len(d)
} else {
dcReader, err := compressor.Decompress(bytes.NewReader(d))
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
// Read from LimitReader with limit max+1. So if the underlying
// reader is over limit, the result will be bigger than max.
d, err = ioutil.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
d, size, err = decompress(compressor, d, maxReceiveMessageSize)
}
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
} else {
size = len(d)
}
if len(d) > maxReceiveMessageSize {
if size > maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize)
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", size, maxReceiveMessageSize)
}
return d, nil
}
// Using compressor, decompress d, returning data and size.
// Optionally, if data will be over maxReceiveMessageSize, just return the size.
func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize int) ([]byte, int, error) {
dcReader, err := compressor.Decompress(bytes.NewReader(d))
if err != nil {
return nil, 0, err
}
if sizer, ok := compressor.(interface {
DecompressedSize(compressedBytes []byte) int
}); ok {
if size := sizer.DecompressedSize(d); size >= 0 {
if size > maxReceiveMessageSize {
return nil, size, nil
}
// size is used as an estimate to size the buffer, but we
// will read more data if available.
// +MinRead so ReadFrom will not reallocate if size is correct.
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return buf.Bytes(), int(bytesRead), err
}
}
// Read from LimitReader with limit max+1. So if the underlying
// reader is over limit, the result will be bigger than max.
d, err = ioutil.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return d, len(d), err
}
// For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?