From 85034d5828ca1b4e6e3d66c33036ab01b2fb66b3 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Thu, 5 Mar 2015 18:52:06 -0800 Subject: [PATCH] fix double wrapping of rpc status --- server.go | 24 ++++++++++++++++++------ test/end2end_test.go | 43 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/server.go b/server.go index 9152d7ae..6c7a78b2 100644 --- a/server.go +++ b/server.go @@ -241,9 +241,18 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } switch pf { case compressionNone: + statusCode := codes.OK + statusDesc := "" reply, appErr := md.Handler(srv.server, stream.Context(), req) if appErr != nil { - if err := t.WriteStatus(stream, convertCode(appErr), appErr.Error()); err != nil { + if err, ok := appErr.(rpcError); ok { + statusCode = err.code + statusDesc = err.desc + } else { + statusCode = convertCode(appErr) + statusDesc = appErr.Error() + } + if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { log.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) } return @@ -252,8 +261,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Last: true, Delay: false, } - statusCode := codes.OK - statusDesc := "" if err := s.sendProto(t, stream, reply, compressionNone, opts); err != nil { if _, ok := err.(transport.ConnectionError); ok { return @@ -281,9 +288,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp s: stream, p: &parser{s: stream}, } - if err := sd.Handler(srv.server, ss); err != nil { - ss.statusCode = convertCode(err) - ss.statusDesc = err.Error() + if appErr := sd.Handler(srv.server, ss); appErr != nil { + if err, ok := appErr.(rpcError); ok { + ss.statusCode = err.code + ss.statusDesc = err.desc + } else { + ss.statusCode = convertCode(appErr) + ss.statusDesc = appErr.Error() + } } if err := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc); err != nil { log.Printf("grpc: Server.processStreamingRPC failed to write status: %v", err) diff --git a/test/end2end_test.go b/test/end2end_test.go index 006f952d..bfc459a9 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -63,6 +63,10 @@ type testServer struct { } func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + if _, ok := metadata.FromContext(ctx); ok { + // For testing purpose, returns an error if there is attached metadata. + return nil, grpc.Errorf(codes.DataLoss, "got extra metadata") + } return new(testpb.Empty), nil } @@ -100,6 +104,11 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* } func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error { + if _, ok := metadata.FromContext(stream.Context()); ok { + log.Println("REACH HERE !!!") + // For testing purpose, returns an error if there is attached metadata. + return grpc.Errorf(codes.DataLoss, "got extra metadata") + } cs := args.GetResponseParameters() for _, c := range cs { if us := c.GetIntervalUs(); us > 0 { @@ -298,7 +307,16 @@ func TestEmptyUnary(t *testing.T) { defer s.Stop() reply, err := tc.EmptyCall(context.Background(), &testpb.Empty{}) if err != nil || !proto.Equal(&testpb.Empty{}, reply) { - t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want %v, ", reply, err, &testpb.Empty{}) + t.Fatalf("TestService/EmptyCall(_, _) = %v, %v, want %v, ", reply, err, &testpb.Empty{}) + } +} + +func TestFailedEmptyUnary(t *testing.T) { + s, tc := setUp(true, math.MaxUint32) + defer s.Stop() + ctx := metadata.NewContext(context.Background(), testMetadata) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != grpc.Errorf(codes.DataLoss, "got extra metadata") { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, grpc.Errorf(codes.DataLoss, "got extra metadata")) } } @@ -580,6 +598,29 @@ func TestServerStreaming(t *testing.T) { } } +func TestFailedServerStreaming(t *testing.T) { + s, tc := setUp(true, math.MaxUint32) + defer s.Stop() + respParam := make([]*testpb.ResponseParameters, len(respSizes)) + for i, s := range respSizes { + respParam[i] = &testpb.ResponseParameters{ + Size: proto.Int32(int32(s)), + } + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + } + ctx := metadata.NewContext(context.Background(), testMetadata) + stream, err := tc.StreamingOutputCall(ctx, req) + if err != nil { + t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want ", tc, err) + } + if _, err := stream.Recv(); err != grpc.Errorf(codes.DataLoss, "got extra metadata") { + t.Fatalf("%v.Recv() = _, %v, want _, %v", stream, err, grpc.Errorf(codes.DataLoss, "got extra metadata")) + } +} + func TestClientStreaming(t *testing.T) { s, tc := setUp(true, math.MaxUint32) defer s.Stop()