diff --git a/stress/client/main.go b/stress/client/main.go index 916d7f8b..d0f78d62 100644 --- a/stress/client/main.go +++ b/stress/client/main.go @@ -177,45 +177,56 @@ func startServer(server *server, port int) { // stressClient defines client for stress test. type stressClient struct { - testID int - address string - testDurationSecs int - selector *weightedRandomTestSelector - interopClient testpb.TestServiceClient + testID int + address string + selector *weightedRandomTestSelector + interopClient testpb.TestServiceClient + stop <-chan bool } // newStressClient construct a new stressClient. -func newStressClient(id int, addr string, conn *grpc.ClientConn, selector *weightedRandomTestSelector, testDurSecs int) *stressClient { +func newStressClient(id int, addr string, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) *stressClient { client := testpb.NewTestServiceClient(conn) - return &stressClient{testID: id, address: addr, selector: selector, testDurationSecs: testDurSecs, interopClient: client} + return &stressClient{testID: id, address: addr, selector: selector, interopClient: client, stop: stop} } // mainLoop uses weightedRandomTestSelector to select test case and runs the tests. func (c *stressClient) mainLoop(gauge *gauge) { var numCalls int64 timeStarted := time.Now() - for testEndTime := time.Now().Add(time.Duration(c.testDurationSecs) * time.Second); c.testDurationSecs < 0 || time.Now().Before(testEndTime); { - test, err := c.selector.getNextTest() - if err != nil { - grpclog.Printf("%v", err) - continue + for { + done := make(chan bool) + go func() { + test, err := c.selector.getNextTest() + if err != nil { + grpclog.Printf("%v", err) + done <- false + } + switch test { + case "empty_unary": + interop.DoEmptyUnaryCall(c.interopClient) + case "large_unary": + interop.DoLargeUnaryCall(c.interopClient) + case "client_streaming": + interop.DoClientStreaming(c.interopClient) + case "server_streaming": + interop.DoServerStreaming(c.interopClient) + case "empty_stream": + interop.DoEmptyStream(c.interopClient) + default: + grpclog.Fatalf("Unsupported test case: %d", test) + } + done <- true + }() + select { + case <-c.stop: + return + case r := <-done: + if r { + numCalls++ + gauge.set(int64(float64(numCalls) / time.Since(timeStarted).Seconds())) + } } - switch test { - case "empty_unary": - interop.DoEmptyUnaryCall(c.interopClient) - case "large_unary": - interop.DoLargeUnaryCall(c.interopClient) - case "client_streaming": - interop.DoClientStreaming(c.interopClient) - case "server_streaming": - interop.DoServerStreaming(c.interopClient) - case "empty_stream": - interop.DoEmptyStream(c.interopClient) - default: - grpclog.Fatalf("Unsupported test case: %d", test) - } - numCalls++ - gauge.set(int64(float64(numCalls) / time.Since(timeStarted).Seconds())) } } @@ -250,6 +261,8 @@ func main() { var wg sync.WaitGroup wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel) + stop := make(chan bool) + var clientIndex int for serverIndex, address := range addresses { for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ { @@ -260,7 +273,7 @@ func main() { defer conn.Close() for stubIndex := 0; stubIndex < *numStubsPerChannel; stubIndex++ { clientIndex++ - client := newStressClient(clientIndex, address, conn, testSelector, *testDurationSecs) + client := newStressClient(clientIndex, address, conn, testSelector, stop) buf := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, stubIndex+1) go func() { defer wg.Done() @@ -275,6 +288,10 @@ func main() { } } go startServer(metricsServer, *metricsPort) + if *testDurationSecs > 0 { + time.Sleep(time.Duration(*testDurationSecs) * time.Second) + close(stop) + } wg.Wait() grpclog.Printf(" ===== ALL DONE ===== ")