status: handle invalid utf-8 characters (#2109)
fixes #2078 A status with invalid utf-8 characters could still be created, but invalid characters will be replaced with [Unicode replacement character](https://en.wikipedia.org/wiki/Specials_(Unicode_block)#Replacement_character) before being sent out. Those bytes will still be percent encoded. All details added to this invalid status will be dropped.
This commit is contained in:
Родитель
96cefb43cf
Коммит
9c658603f0
|
@ -6143,6 +6143,73 @@ func TestServeExitsWhenListenerClosed(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Service handler returns status with invalid utf8 message.
|
||||
func TestStatusInvalidUTF8Message(t *testing.T) {
|
||||
defer leakcheck.Check(t)
|
||||
|
||||
var (
|
||||
origMsg = string([]byte{0xff, 0xfe, 0xfd})
|
||||
wantMsg = "<22><><EFBFBD>"
|
||||
)
|
||||
|
||||
ss := &stubServer{
|
||||
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
return nil, status.Errorf(codes.Internal, origMsg)
|
||||
},
|
||||
}
|
||||
if err := ss.Start(nil); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMsg {
|
||||
t.Fatalf("ss.client.EmptyCall(_, _) = _, %v (msg %q); want _, err with msg %q", err, status.Convert(err).Message(), wantMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// Service handler returns status with details and invalid utf8 message. Proto
|
||||
// will fail to marshal the status because of the invalid utf8 message. Details
|
||||
// will be dropped when sending.
|
||||
func TestStatusInvalidUTF8Details(t *testing.T) {
|
||||
defer leakcheck.Check(t)
|
||||
|
||||
var (
|
||||
origMsg = string([]byte{0xff, 0xfe, 0xfd})
|
||||
wantMsg = "<22><><EFBFBD>"
|
||||
)
|
||||
|
||||
ss := &stubServer{
|
||||
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
st := status.New(codes.Internal, origMsg)
|
||||
st, err := st.WithDetails(&testpb.Empty{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, st.Err()
|
||||
},
|
||||
}
|
||||
if err := ss.Start(nil); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := ss.client.EmptyCall(ctx, &testpb.Empty{})
|
||||
st := status.Convert(err)
|
||||
if st.Message() != wantMsg {
|
||||
t.Fatalf("ss.client.EmptyCall(_, _) = _, %v (msg %q); want _, err with msg %q", err, st.Message(), wantMsg)
|
||||
}
|
||||
if len(st.Details()) != 0 {
|
||||
// Details should be dropped on the server side.
|
||||
t.Fatalf("RPC status contain details: %v, want no details", st.Details())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientDoesntDeadlockWhileWritingErrornousLargeMessages(t *testing.T) {
|
||||
defer leakcheck.Check(t)
|
||||
for _, e := range listTestEnv() {
|
||||
|
|
|
@ -38,6 +38,7 @@ import (
|
|||
"google.golang.org/grpc/channelz"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"google.golang.org/grpc/internal/grpcrand"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
@ -769,10 +770,10 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
|
|||
stBytes, err := proto.Marshal(p)
|
||||
if err != nil {
|
||||
// TODO: return error instead, when callers are able to handle it.
|
||||
panic(err)
|
||||
grpclog.Errorf("transport: failed to marshal rpc status: %v, error: %v", p, err)
|
||||
} else {
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
|
||||
}
|
||||
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
|
||||
}
|
||||
|
||||
// Attach the trailer metadata.
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/http2"
|
||||
|
@ -442,11 +443,12 @@ const (
|
|||
)
|
||||
|
||||
// encodeGrpcMessage is used to encode status code in header field
|
||||
// "grpc-message".
|
||||
// It checks to see if each individual byte in msg is an
|
||||
// allowable byte, and then either percent encoding or passing it through.
|
||||
// When percent encoding, the byte is converted into hexadecimal notation
|
||||
// with a '%' prepended.
|
||||
// "grpc-message". It does percent encoding and also replaces invalid utf-8
|
||||
// characters with Unicode replacement character.
|
||||
//
|
||||
// It checks to see if each individual byte in msg is an allowable byte, and
|
||||
// then either percent encoding or passing it through. When percent encoding,
|
||||
// the byte is converted into hexadecimal notation with a '%' prepended.
|
||||
func encodeGrpcMessage(msg string) string {
|
||||
if msg == "" {
|
||||
return ""
|
||||
|
@ -463,14 +465,26 @@ func encodeGrpcMessage(msg string) string {
|
|||
|
||||
func encodeGrpcMessageUnchecked(msg string) string {
|
||||
var buf bytes.Buffer
|
||||
lenMsg := len(msg)
|
||||
for i := 0; i < lenMsg; i++ {
|
||||
c := msg[i]
|
||||
if c >= spaceByte && c < tildaByte && c != percentByte {
|
||||
buf.WriteByte(c)
|
||||
} else {
|
||||
buf.WriteString(fmt.Sprintf("%%%02X", c))
|
||||
for len(msg) > 0 {
|
||||
r, size := utf8.DecodeRuneInString(msg)
|
||||
for _, b := range []byte(string(r)) {
|
||||
if size > 1 {
|
||||
// If size > 1, r is not ascii. Always do percent encoding.
|
||||
buf.WriteString(fmt.Sprintf("%%%02X", b))
|
||||
continue
|
||||
}
|
||||
|
||||
// The for loop is necessary even if size == 1. r could be
|
||||
// utf8.RuneError.
|
||||
//
|
||||
// fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
|
||||
if b >= spaceByte && b < tildaByte && b != percentByte {
|
||||
buf.WriteByte(b)
|
||||
} else {
|
||||
buf.WriteString(fmt.Sprintf("%%%02X", b))
|
||||
}
|
||||
}
|
||||
msg = msg[size:]
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
|
|
@ -102,12 +102,14 @@ func TestEncodeGrpcMessage(t *testing.T) {
|
|||
}{
|
||||
{"", ""},
|
||||
{"Hello", "Hello"},
|
||||
{"my favorite character is \u0000", "my favorite character is %00"},
|
||||
{"my favorite character is %", "my favorite character is %25"},
|
||||
{"\u0000", "%00"},
|
||||
{"%", "%25"},
|
||||
{"系统", "%E7%B3%BB%E7%BB%9F"},
|
||||
{string([]byte{0xff, 0xfe, 0xfd}), "%EF%BF%BD%EF%BF%BD%EF%BF%BD"},
|
||||
} {
|
||||
actual := encodeGrpcMessage(tt.input)
|
||||
if tt.expected != actual {
|
||||
t.Errorf("encodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
|
||||
t.Errorf("encodeGrpcMessage(%q) = %q, want %q", tt.input, actual, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -123,10 +125,36 @@ func TestDecodeGrpcMessage(t *testing.T) {
|
|||
{"H%6", "H%6"},
|
||||
{"%G0", "%G0"},
|
||||
{"%E7%B3%BB%E7%BB%9F", "系统"},
|
||||
{"%EF%BF%BD", "<22>"},
|
||||
} {
|
||||
actual := decodeGrpcMessage(tt.input)
|
||||
if tt.expected != actual {
|
||||
t.Errorf("dncodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
|
||||
t.Errorf("dncodeGrpcMessage(%q) = %q, want %q", tt.input, actual, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode an encoded string should get the same thing back, except for invalid
|
||||
// utf8 chars.
|
||||
func TestDecodeEncodeGrpcMessage(t *testing.T) {
|
||||
testCases := []struct {
|
||||
orig string
|
||||
want string
|
||||
}{
|
||||
{"", ""},
|
||||
{"hello", "hello"},
|
||||
{"h%6", "h%6"},
|
||||
{"%G0", "%G0"},
|
||||
{"系统", "系统"},
|
||||
{"Hello, 世界", "Hello, 世界"},
|
||||
|
||||
{string([]byte{0xff, 0xfe, 0xfd}), "<22><><EFBFBD>"},
|
||||
{string([]byte{0xff}) + "Hello" + string([]byte{0xfe}) + "世界" + string([]byte{0xfd}), "<22>Hello<6C>世界<E4B896>"},
|
||||
}
|
||||
for _, tC := range testCases {
|
||||
got := decodeGrpcMessage(encodeGrpcMessage(tC.orig))
|
||||
if got != tC.want {
|
||||
t.Errorf("decodeGrpcMessage(encodeGrpcMessage(%q)) = %q, want %q", tC.orig, got, tC.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче