285 строки
9.4 KiB
Go
285 строки
9.4 KiB
Go
|
/*
|
||
|
*
|
||
|
* Copyright 2020 gRPC authors.
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*
|
||
|
*/
|
||
|
|
||
|
package test
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"io"
|
||
|
"testing"
|
||
|
|
||
|
"google.golang.org/grpc"
|
||
|
"google.golang.org/grpc/codes"
|
||
|
"google.golang.org/grpc/status"
|
||
|
testpb "google.golang.org/grpc/test/grpc_testing"
|
||
|
)
|
||
|
|
||
|
// ctxKey is the string type of the context keys that the interceptors in
// these tests store with context.WithValue and read back with ctx.Value.
type ctxKey string
|
||
|
|
||
|
func (s) TestChainUnaryServerInterceptor(t *testing.T) {
|
||
|
var (
|
||
|
firstIntKey = ctxKey("firstIntKey")
|
||
|
secondIntKey = ctxKey("secondIntKey")
|
||
|
)
|
||
|
|
||
|
firstInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||
|
if ctx.Value(firstIntKey) != nil {
|
||
|
return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", firstIntKey)
|
||
|
}
|
||
|
if ctx.Value(secondIntKey) != nil {
|
||
|
return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", secondIntKey)
|
||
|
}
|
||
|
|
||
|
firstCtx := context.WithValue(ctx, firstIntKey, 0)
|
||
|
resp, err := handler(firstCtx, req)
|
||
|
if err != nil {
|
||
|
return nil, status.Errorf(codes.Internal, "failed to handle request at firstInt")
|
||
|
}
|
||
|
|
||
|
simpleResp, ok := resp.(*testpb.SimpleResponse)
|
||
|
if !ok {
|
||
|
return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at firstInt")
|
||
|
}
|
||
|
return &testpb.SimpleResponse{
|
||
|
Payload: &testpb.Payload{
|
||
|
Type: simpleResp.GetPayload().GetType(),
|
||
|
Body: append(simpleResp.GetPayload().GetBody(), '1'),
|
||
|
},
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
secondInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||
|
if ctx.Value(firstIntKey) == nil {
|
||
|
return nil, status.Errorf(codes.Internal, "second interceptor should have %v in context", firstIntKey)
|
||
|
}
|
||
|
if ctx.Value(secondIntKey) != nil {
|
||
|
return nil, status.Errorf(codes.Internal, "second interceptor should not have %v in context", secondIntKey)
|
||
|
}
|
||
|
|
||
|
secondCtx := context.WithValue(ctx, secondIntKey, 1)
|
||
|
resp, err := handler(secondCtx, req)
|
||
|
if err != nil {
|
||
|
return nil, status.Errorf(codes.Internal, "failed to handle request at secondInt")
|
||
|
}
|
||
|
|
||
|
simpleResp, ok := resp.(*testpb.SimpleResponse)
|
||
|
if !ok {
|
||
|
return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at secondInt")
|
||
|
}
|
||
|
return &testpb.SimpleResponse{
|
||
|
Payload: &testpb.Payload{
|
||
|
Type: simpleResp.GetPayload().GetType(),
|
||
|
Body: append(simpleResp.GetPayload().GetBody(), '2'),
|
||
|
},
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
lastInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||
|
if ctx.Value(firstIntKey) == nil {
|
||
|
return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", firstIntKey)
|
||
|
}
|
||
|
if ctx.Value(secondIntKey) == nil {
|
||
|
return nil, status.Errorf(codes.Internal, "last interceptor should not have %v in context", secondIntKey)
|
||
|
}
|
||
|
|
||
|
resp, err := handler(ctx, req)
|
||
|
if err != nil {
|
||
|
return nil, status.Errorf(codes.Internal, "failed to handle request at lastInt at lastInt")
|
||
|
}
|
||
|
|
||
|
simpleResp, ok := resp.(*testpb.SimpleResponse)
|
||
|
if !ok {
|
||
|
return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at lastInt")
|
||
|
}
|
||
|
return &testpb.SimpleResponse{
|
||
|
Payload: &testpb.Payload{
|
||
|
Type: simpleResp.GetPayload().GetType(),
|
||
|
Body: append(simpleResp.GetPayload().GetBody(), '3'),
|
||
|
},
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
sopts := []grpc.ServerOption{
|
||
|
grpc.ChainUnaryInterceptor(firstInt, secondInt, lastInt),
|
||
|
}
|
||
|
|
||
|
ss := &stubServer{
|
||
|
unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||
|
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 0)
|
||
|
if err != nil {
|
||
|
return nil, status.Errorf(codes.Aborted, "failed to make payload: %v", err)
|
||
|
}
|
||
|
|
||
|
return &testpb.SimpleResponse{
|
||
|
Payload: payload,
|
||
|
}, nil
|
||
|
},
|
||
|
}
|
||
|
if err := ss.Start(sopts); err != nil {
|
||
|
t.Fatalf("Error starting endpoint server: %v", err)
|
||
|
}
|
||
|
defer ss.Stop()
|
||
|
|
||
|
resp, err := ss.client.UnaryCall(context.Background(), &testpb.SimpleRequest{})
|
||
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.OK {
|
||
|
t.Fatalf("ss.client.UnaryCall(context.Background(), _) = %v, %v; want nil, <status with Code()=OK>", resp, err)
|
||
|
}
|
||
|
|
||
|
respBytes := resp.Payload.GetBody()
|
||
|
if string(respBytes) != "321" {
|
||
|
t.Fatalf("invalid response: want=%s, but got=%s", "321", resp)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s) TestChainOnBaseUnaryServerInterceptor(t *testing.T) {
|
||
|
baseIntKey := ctxKey("baseIntKey")
|
||
|
|
||
|
baseInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||
|
if ctx.Value(baseIntKey) != nil {
|
||
|
return nil, status.Errorf(codes.Internal, "base interceptor should not have %v in context", baseIntKey)
|
||
|
}
|
||
|
|
||
|
baseCtx := context.WithValue(ctx, baseIntKey, 1)
|
||
|
return handler(baseCtx, req)
|
||
|
}
|
||
|
|
||
|
chainInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||
|
if ctx.Value(baseIntKey) == nil {
|
||
|
return nil, status.Errorf(codes.Internal, "chain interceptor should have %v in context", baseIntKey)
|
||
|
}
|
||
|
|
||
|
return handler(ctx, req)
|
||
|
}
|
||
|
|
||
|
sopts := []grpc.ServerOption{
|
||
|
grpc.UnaryInterceptor(baseInt),
|
||
|
grpc.ChainUnaryInterceptor(chainInt),
|
||
|
}
|
||
|
|
||
|
ss := &stubServer{
|
||
|
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||
|
return &testpb.Empty{}, nil
|
||
|
},
|
||
|
}
|
||
|
if err := ss.Start(sopts); err != nil {
|
||
|
t.Fatalf("Error starting endpoint server: %v", err)
|
||
|
}
|
||
|
defer ss.Stop()
|
||
|
|
||
|
resp, err := ss.client.EmptyCall(context.Background(), &testpb.Empty{})
|
||
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.OK {
|
||
|
t.Fatalf("ss.client.EmptyCall(context.Background(), _) = %v, %v; want nil, <status with Code()=OK>", resp, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s) TestChainStreamServerInterceptor(t *testing.T) {
|
||
|
callCounts := make([]int, 4)
|
||
|
|
||
|
firstInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||
|
if callCounts[0] != 0 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[0] should be 0, but got=%d", callCounts[0])
|
||
|
}
|
||
|
if callCounts[1] != 0 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1])
|
||
|
}
|
||
|
if callCounts[2] != 0 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
|
||
|
}
|
||
|
if callCounts[3] != 0 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
|
||
|
}
|
||
|
callCounts[0]++
|
||
|
return handler(srv, stream)
|
||
|
}
|
||
|
|
||
|
secondInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||
|
if callCounts[0] != 1 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
|
||
|
}
|
||
|
if callCounts[1] != 0 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1])
|
||
|
}
|
||
|
if callCounts[2] != 0 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
|
||
|
}
|
||
|
if callCounts[3] != 0 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
|
||
|
}
|
||
|
callCounts[1]++
|
||
|
return handler(srv, stream)
|
||
|
}
|
||
|
|
||
|
lastInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||
|
if callCounts[0] != 1 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
|
||
|
}
|
||
|
if callCounts[1] != 1 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1])
|
||
|
}
|
||
|
if callCounts[2] != 0 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
|
||
|
}
|
||
|
if callCounts[3] != 0 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
|
||
|
}
|
||
|
callCounts[2]++
|
||
|
return handler(srv, stream)
|
||
|
}
|
||
|
|
||
|
sopts := []grpc.ServerOption{
|
||
|
grpc.ChainStreamInterceptor(firstInt, secondInt, lastInt),
|
||
|
}
|
||
|
|
||
|
ss := &stubServer{
|
||
|
fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
|
||
|
if callCounts[0] != 1 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
|
||
|
}
|
||
|
if callCounts[1] != 1 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1])
|
||
|
}
|
||
|
if callCounts[2] != 1 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
|
||
|
}
|
||
|
if callCounts[3] != 0 {
|
||
|
return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
|
||
|
}
|
||
|
callCounts[3]++
|
||
|
return nil
|
||
|
},
|
||
|
}
|
||
|
if err := ss.Start(sopts); err != nil {
|
||
|
t.Fatalf("Error starting endpoint server: %v", err)
|
||
|
}
|
||
|
defer ss.Stop()
|
||
|
|
||
|
stream, err := ss.client.FullDuplexCall(context.Background())
|
||
|
if err != nil {
|
||
|
t.Fatalf("failed to FullDuplexCall: %v", err)
|
||
|
}
|
||
|
|
||
|
_, err = stream.Recv()
|
||
|
if err != io.EOF {
|
||
|
t.Fatalf("failed to recv from stream: %v", err)
|
||
|
}
|
||
|
|
||
|
if callCounts[3] != 1 {
|
||
|
t.Fatalf("callCounts[3] should be 1, but got=%d", callCounts[3])
|
||
|
}
|
||
|
}
|