status: handle invalid utf-8 characters (#2109)

fixes #2078

A status with invalid utf-8 characters could still be created, but invalid characters will be replaced with [Unicode replacement character](https://en.wikipedia.org/wiki/Specials_(Unicode_block)#Replacement_character) before being sent out. Those bytes will still be percent encoded.

All details added to this invalid status will be dropped.
This commit is contained in:
Menghan Li 2018-06-06 10:58:32 -07:00 коммит произвёл GitHub
Родитель 96cefb43cf
Коммит 9c658603f0
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 129 добавлений и 19 удалений

Просмотреть файл

@ -6143,6 +6143,73 @@ func TestServeExitsWhenListenerClosed(t *testing.T) {
}
}
// Service handler returns status with invalid utf8 message.
func TestStatusInvalidUTF8Message(t *testing.T) {
defer leakcheck.Check(t)
var (
origMsg = string([]byte{0xff, 0xfe, 0xfd})
wantMsg = "<22><><EFBFBD>"
)
ss := &stubServer{
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return nil, status.Errorf(codes.Internal, origMsg)
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMsg {
t.Fatalf("ss.client.EmptyCall(_, _) = _, %v (msg %q); want _, err with msg %q", err, status.Convert(err).Message(), wantMsg)
}
}
// Service handler returns status with details and invalid utf8 message. Proto
// will fail to marshal the status because of the invalid utf8 message. Details
// will be dropped when sending.
func TestStatusInvalidUTF8Details(t *testing.T) {
defer leakcheck.Check(t)
var (
origMsg = string([]byte{0xff, 0xfe, 0xfd})
wantMsg = "<22><><EFBFBD>"
)
ss := &stubServer{
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
st := status.New(codes.Internal, origMsg)
st, err := st.WithDetails(&testpb.Empty{})
if err != nil {
return nil, err
}
return nil, st.Err()
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, err := ss.client.EmptyCall(ctx, &testpb.Empty{})
st := status.Convert(err)
if st.Message() != wantMsg {
t.Fatalf("ss.client.EmptyCall(_, _) = _, %v (msg %q); want _, err with msg %q", err, st.Message(), wantMsg)
}
if len(st.Details()) != 0 {
// Details should be dropped on the server side.
t.Fatalf("RPC status contain details: %v, want no details", st.Details())
}
}
func TestClientDoesntDeadlockWhileWritingErrornousLargeMessages(t *testing.T) {
defer leakcheck.Check(t)
for _, e := range listTestEnv() {

Просмотреть файл

@ -38,6 +38,7 @@ import (
"google.golang.org/grpc/channelz"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
@ -769,10 +770,10 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
stBytes, err := proto.Marshal(p)
if err != nil {
// TODO: return error instead, when callers are able to handle it.
panic(err)
grpclog.Errorf("transport: failed to marshal rpc status: %v, error: %v", p, err)
} else {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
}
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
}
// Attach the trailer metadata.

Просмотреть файл

@ -28,6 +28,7 @@ import (
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/golang/protobuf/proto"
"golang.org/x/net/http2"
@ -442,11 +443,12 @@ const (
)
// encodeGrpcMessage is used to encode status code in header field
// "grpc-message".
// It checks to see if each individual byte in msg is an
// allowable byte, and then either percent encoding or passing it through.
// When percent encoding, the byte is converted into hexadecimal notation
// with a '%' prepended.
// "grpc-message". It does percent encoding and also replaces invalid utf-8
// characters with Unicode replacement character.
//
// It checks to see if each individual byte in msg is an allowable byte, and
// then either percent encoding or passing it through. When percent encoding,
// the byte is converted into hexadecimal notation with a '%' prepended.
func encodeGrpcMessage(msg string) string {
if msg == "" {
return ""
@ -463,14 +465,26 @@ func encodeGrpcMessage(msg string) string {
func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c >= spaceByte && c < tildaByte && c != percentByte {
buf.WriteByte(c)
} else {
buf.WriteString(fmt.Sprintf("%%%02X", c))
for len(msg) > 0 {
r, size := utf8.DecodeRuneInString(msg)
for _, b := range []byte(string(r)) {
if size > 1 {
// If size > 1, r is not ascii. Always do percent encoding.
buf.WriteString(fmt.Sprintf("%%%02X", b))
continue
}
// The for loop is necessary even if size == 1. r could be
// utf8.RuneError.
//
// fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
if b >= spaceByte && b < tildaByte && b != percentByte {
buf.WriteByte(b)
} else {
buf.WriteString(fmt.Sprintf("%%%02X", b))
}
}
msg = msg[size:]
}
return buf.String()
}

Просмотреть файл

@ -102,12 +102,14 @@ func TestEncodeGrpcMessage(t *testing.T) {
}{
{"", ""},
{"Hello", "Hello"},
{"my favorite character is \u0000", "my favorite character is %00"},
{"my favorite character is %", "my favorite character is %25"},
{"\u0000", "%00"},
{"%", "%25"},
{"系统", "%E7%B3%BB%E7%BB%9F"},
{string([]byte{0xff, 0xfe, 0xfd}), "%EF%BF%BD%EF%BF%BD%EF%BF%BD"},
} {
actual := encodeGrpcMessage(tt.input)
if tt.expected != actual {
t.Errorf("encodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
t.Errorf("encodeGrpcMessage(%q) = %q, want %q", tt.input, actual, tt.expected)
}
}
}
@ -123,10 +125,36 @@ func TestDecodeGrpcMessage(t *testing.T) {
{"H%6", "H%6"},
{"%G0", "%G0"},
{"%E7%B3%BB%E7%BB%9F", "系统"},
{"%EF%BF%BD", "<22>"},
} {
actual := decodeGrpcMessage(tt.input)
if tt.expected != actual {
t.Errorf("dncodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
t.Errorf("dncodeGrpcMessage(%q) = %q, want %q", tt.input, actual, tt.expected)
}
}
}
// Decode an encoded string should get the same thing back, except for invalid
// utf8 chars.
func TestDecodeEncodeGrpcMessage(t *testing.T) {
testCases := []struct {
orig string
want string
}{
{"", ""},
{"hello", "hello"},
{"h%6", "h%6"},
{"%G0", "%G0"},
{"系统", "系统"},
{"Hello, 世界", "Hello, 世界"},
{string([]byte{0xff, 0xfe, 0xfd}), "<22><><EFBFBD>"},
{string([]byte{0xff}) + "Hello" + string([]byte{0xfe}) + "世界" + string([]byte{0xfd}), "<22>Hello<6C>世界<E4B896>"},
}
for _, tC := range testCases {
got := decodeGrpcMessage(encodeGrpcMessage(tC.orig))
if got != tC.want {
t.Errorf("decodeGrpcMessage(encodeGrpcMessage(%q)) = %q, want %q", tC.orig, got, tC.want)
}
}
}