From 719a427fe7884c3e44fe76e7e5a7a5f465aeea4e Mon Sep 17 00:00:00 2001 From: Ziyue Yang Date: Mon, 11 Dec 2023 06:53:51 +0800 Subject: [PATCH] Benchmarks: Microbenchmark - Add distributed inference benchmark cpp implementation (#586) **Description** Add distributed inference benchmark cpp implementation. --- .../benchmarks/micro-benchmarks.md | 2 +- .../micro_benchmarks/dist_inference.py | 205 +++++--- .../dist_inference_cpp/CMakeLists.txt | 40 ++ .../dist_inference_cpp/dist_inference.cu | 455 ++++++++++++++++++ superbench/config/azure_ndmv4.yaml | 2 + superbench/config/azure_ndv4.yaml | 2 + .../micro_benchmarks/test_dist_inference.py | 147 +++++- tests/data/dist_inference.log | 2 + 8 files changed, 791 insertions(+), 64 deletions(-) create mode 100644 superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt create mode 100644 superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu create mode 100644 tests/data/dist_inference.log diff --git a/docs/user-tutorial/benchmarks/micro-benchmarks.md b/docs/user-tutorial/benchmarks/micro-benchmarks.md index ee082e8b..fd0365f4 100644 --- a/docs/user-tutorial/benchmarks/micro-benchmarks.md +++ b/docs/user-tutorial/benchmarks/micro-benchmarks.md @@ -418,7 +418,7 @@ Test the performance of large scale matmul operation with multiple GPUs: #### Introduction -Test the performance of distributed model inference. +Test the performance of distributed model inference. Support both PyTorch implementation and cpp implementation. #### Metrics diff --git a/superbench/benchmarks/micro_benchmarks/dist_inference.py b/superbench/benchmarks/micro_benchmarks/dist_inference.py index 8e51b6bd..ffc955ab 100644 --- a/superbench/benchmarks/micro_benchmarks/dist_inference.py +++ b/superbench/benchmarks/micro_benchmarks/dist_inference.py @@ -12,7 +12,7 @@ import torch.distributed as dist from superbench.common.utils import logger from superbench.benchmarks import DistributedImpl, DistributedBackend, BenchmarkRegistry, ReturnCode, Precision -from superbench.benchmarks.micro_benchmarks import MicroBenchmark +from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke from superbench.benchmarks.context import Enum from superbench.benchmarks.reducer import ReduceType @@ -168,7 +168,7 @@ class DistInferenceModel(torch.nn.Module): return activation_out -class DistInference(MicroBenchmark): +class DistInference(MicroBenchmarkWithInvoke): """The base class of micro-benchmarks.""" def __init__(self, name, parameters=''): """Constructor. @@ -182,7 +182,9 @@ class DistInference(MicroBenchmark): self.__local_rank = 0 torch.backends.cudnn.benchmark = True self.__device = None - self.__cuda_available = False + + # For cpp impl path + self._bin_name = 'dist_inference' def __timer(self): """Returns the current time which ensures all previous CUDA events have been finished. @@ -193,14 +195,19 @@ class DistInference(MicroBenchmark): Return: Current time in second. """ - if self.__cuda_available: - torch.cuda.synchronize() + torch.cuda.synchronize() return time.time() def add_parser_arguments(self): """Add the specified arguments.""" super().add_parser_arguments() + self._parser.add_argument( + '--use_pytorch', + action='store_true', + required=False, + help='Whether to use pytorch implementation. 
If not, cpp implementation will be used.', + ) self._parser.add_argument( '--batch_size', type=int, @@ -222,6 +229,20 @@ class DistInference(MicroBenchmark): required=False, help='Hidden size.', ) + self._parser.add_argument( + '--alpha', + type=float, + default=1.0, + required=False, + help='Coefficient alpha in D = alpha*(A*B) + beta*(C).', + ) + self._parser.add_argument( + '--beta', + type=float, + default=1.0, + required=False, + help='Coefficient beta in D = alpha*(A*B) + beta*(C).', + ) self._parser.add_argument( '--num_layers', type=int, @@ -285,6 +306,12 @@ class DistInference(MicroBenchmark): required=False, help='Distributed backends. E.g. {}.'.format(' '.join(DistributedBackend.get_values())), ) + self._parser.add_argument( + '--use_cuda_graph', + action='store_true', + required=False, + help='Whether to launch kernels in CUDA graph mode.', + ) def _preprocess(self): """Preprocess/preparation operations before the benchmarking. @@ -295,32 +322,41 @@ class DistInference(MicroBenchmark): if not super()._preprocess(): return False - if self._args.distributed_impl != DistributedImpl.DDP: - self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE) - logger.error( - 'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format( - self._name, self._args.distributed_impl + if self._args.use_pytorch: + # Initialize PyTorch if pytorch impl path + if self._args.distributed_impl != DistributedImpl.DDP: + return self._set_error_code_and_print_error_msg( + ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE, + 'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format( + self._name, self._args.distributed_impl + ) ) - ) - return False - try: - torch.distributed.init_process_group(backend=self._args.distributed_backend.value) - self.__world_size = int(os.environ['WORLD_SIZE']) - self.__local_rank = int(os.environ['LOCAL_RANK']) - except BaseException as e: - self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE) - torch.distributed.destroy_process_group() - logger.error('Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e))) - return False + try: + torch.distributed.init_process_group(backend=self._args.distributed_backend.value) + self.__world_size = int(os.environ['WORLD_SIZE']) + self.__local_rank = int(os.environ['LOCAL_RANK']) + assert (torch.cuda.is_available()) + except BaseException as e: + torch.distributed.destroy_process_group() + return self._set_error_code_and_print_error_msg( + ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE, + 'Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e)) + ) - if torch.cuda.is_available(): torch.cuda.set_device(self.__local_rank) self.__device = torch.device('cuda:{}'.format(self.__local_rank)) - self.__cuda_available = True else: - self.__device = torch.device('cpu:{}'.format(self.__local_rank)) - self.__cuda_available = False + # Assemble commands if cpp impl path + self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name) + + args = '-m %d -n %d -k %d' % (self._args.hidden_size, self._args.batch_size, self._args.input_size) + args += ' --alpha %g --beta %g' % (self._args.alpha, self._args.beta) + args += ' --num_layers %d --num_warmups %d --num_iters %d' % \ + (self._args.num_layers, self._args.num_warmup, self._args.num_steps) + if self._args.use_cuda_graph: + args += ' --use_cuda_graph' + self._commands = ['%s %s' % (self.__bin_path, args)] return True @@ -347,8 +383,7 @@ 
class DistInference(MicroBenchmark): self.__device ) model = model.to(dtype=getattr(torch, precision.value)) - if self.__cuda_available: - model = model.cuda() + model = model.cuda() return model def _run_model(self, model, batch_size, input_size, precision, device, num_warmup, num_steps): @@ -401,38 +436,78 @@ class DistInference(MicroBenchmark): Return: True if _benchmark succeeds. """ - batch_size = self._args.batch_size - input_size = self._args.input_size - hidden_size = self._args.hidden_size - num_layers = self._args.num_layers - computation = self._args.computation_kernel - communication = self._args.communication_kernel - activation = self._args.activation_kernel - precision = self._args.precision - num_warmup = self._args.num_warmup - num_steps = self._args.num_steps + if self._args.use_pytorch: + # Execute PyTorch model if pytorch impl path + batch_size = self._args.batch_size + input_size = self._args.input_size + hidden_size = self._args.hidden_size + num_layers = self._args.num_layers + computation = self._args.computation_kernel + communication = self._args.communication_kernel + activation = self._args.activation_kernel + precision = self._args.precision + num_warmup = self._args.num_warmup + num_steps = self._args.num_steps - if self.__local_rank == 0: - logger.info( - 'Distributed Inference - using {} GPUs: ' - 'batch_size={}, input_size={}, hidden_size={}, num_layers={}, ' - 'computation_kernel={}, communication_kernel={}, activation_kernel={}, precision={}, ' - 'num_warmup={} num_steps={}'.format( - self.__world_size, batch_size, input_size, hidden_size, num_layers, computation, communication, - activation, precision, num_warmup, num_steps + if self.__local_rank == 0: + logger.info( + 'Distributed Inference - using {} GPUs: ' + 'batch_size={}, input_size={}, hidden_size={}, num_layers={}, ' + 'computation_kernel={}, communication_kernel={}, activation_kernel={}, precision={}, ' + 'num_warmup={} num_steps={}'.format( + self.__world_size, batch_size, input_size, hidden_size, num_layers, computation, communication, + activation, precision, num_warmup, num_steps + ) ) + + # Prepare model + model = self._prepare_model( + input_size, hidden_size, num_layers, computation, communication, activation, precision, + self.__world_size ) - # Prepare model - model = self._prepare_model( - input_size, hidden_size, num_layers, computation, communication, activation, precision, self.__world_size - ) + # Run model + step_times = self._run_model(model, batch_size, input_size, precision, self.__device, num_warmup, num_steps) - # Run model - step_times = self._run_model(model, batch_size, input_size, precision, self.__device, num_warmup, num_steps) + # Process data and return + return self._process_data(step_times) + else: + # Execute commands if cpp impl path + if not super()._benchmark(): + return False + return True - # Process data and return - return self._process_data(step_times) + def _process_raw_result(self, cmd_idx, raw_output): + """Function to parse raw results and save the summarized results. + + self._result.add_raw_data() and self._result.add_result() need to be called to save the results. + + Args: + cmd_idx (int): the index of command corresponding with the raw_output. + raw_output (str): raw output string of the micro-benchmark. + + Return: + True if the raw output string is valid and result can be extracted. 
+ """ + self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data) + + try: + output_lines = [x.strip() for x in raw_output.strip().splitlines()] + step_time = None + for output_line in output_lines: + if ' ms per iteration' in output_line: + step_time = float(output_line.split(' ms per iteration')[0].split()[-1]) + break + return self._process_numeric_result( + 'step_times', [step_time], reduce_type=ReduceType.MAX, cal_percentile=True + ) + except BaseException as e: + return self._set_error_code_and_print_error_msg( + ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE, + 'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format( + self._curr_run_index, self._name, raw_output, str(e) + ) + ) def _postprocess(self): """Postprocess/cleanup operations after the benchmarking. @@ -443,14 +518,26 @@ class DistInference(MicroBenchmark): if not super()._postprocess(): return False - try: - torch.distributed.destroy_process_group() - except BaseException as e: - self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE) - logger.error('Post process failed - benchmark: {}, message: {}.'.format(self._name, str(e))) - return False + if self._args.use_pytorch: + try: + torch.distributed.destroy_process_group() + except BaseException as e: + return self._set_error_code_and_print_error_msg( + ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE, + 'Post process failed - benchmark: {}, message: {}.'.format(self._name, str(e)) + ) return True + def _set_error_code_and_print_error_msg(self, error_code, error_msg): + """Set error code and print error log upon error. + + Return: + False, representing error. + """ + self._result.set_return_code(error_code) + logger.error(error_msg) + return False + BenchmarkRegistry.register_benchmark('pytorch-dist-inference', DistInference, parameters='') diff --git a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt new file mode 100644 index 00000000..728fffc9 --- /dev/null +++ b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
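The cpp path above assembles the benchmark command in `_preprocess()` and later parses the binary's stdout in `_process_raw_result()`. The axis mapping is easy to miss: the wrapper's `hidden_size`, `batch_size` and `input_size` become the binary's `-m`, `-n` and `-k`, and the wrapper's `--num_warmup`/`--num_steps` become the binary's `--num_warmups`/`--num_iters`. The standalone sketch below only illustrates that mapping; `build_cpp_command` and the example values are hypothetical and not part of the patch. The CMake build definition for the binary itself follows.

```python
# Illustrative sketch of the command assembled by _preprocess() for the cpp path.
# Note the mapping: -m <- hidden_size, -n <- batch_size, -k <- input_size,
# and the flag renames: --num_warmup -> --num_warmups, --num_steps -> --num_iters.


def build_cpp_command(bin_path, args):
    """Hypothetical helper mirroring the string formatting used in _preprocess()."""
    cmd = '%s -m %d -n %d -k %d' % (bin_path, args['hidden_size'], args['batch_size'], args['input_size'])
    cmd += ' --alpha %g --beta %g' % (args['alpha'], args['beta'])
    cmd += ' --num_layers %d --num_warmups %d --num_iters %d' % (
        args['num_layers'], args['num_warmup'], args['num_steps']
    )
    if args.get('use_cuda_graph'):
        cmd += ' --use_cuda_graph'
    return cmd


print(
    build_cpp_command(
        'dist_inference', {
            'hidden_size': 1024, 'batch_size': 64, 'input_size': 1024, 'alpha': 1.0, 'beta': 1.0,
            'num_layers': 1, 'num_warmup': 16, 'num_steps': 10000, 'use_cuda_graph': True
        }
    )
)
# dist_inference -m 1024 -n 64 -k 1024 --alpha 1 --beta 1 --num_layers 1 --num_warmups 16 --num_iters 10000 --use_cuda_graph
```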
+ +cmake_minimum_required(VERSION 3.18) + +project(dist_inference LANGUAGES CXX) + +find_package(MPI REQUIRED) +include_directories(SYSTEM ${MPI_INCLUDE_PATH}) + +find_package(CUDAToolkit QUIET) + +# Cuda environment +if(CUDAToolkit_FOUND) + message(STATUS "Found CUDA: " ${CUDAToolkit_VERSION}) + + include(../cuda_common.cmake) + add_executable(dist_inference dist_inference.cu) + set_property(TARGET dist_inference PROPERTY CUDA_ARCHITECTURES ${NVCC_ARCHS_SUPPORTED}) + target_link_libraries(dist_inference MPI::MPI_CXX nccl cublasLt) +else() + # ROCm environment + include(../rocm_common.cmake) + find_package(hip QUIET) + if(hip_FOUND) + message(STATUS "Found ROCm: " ${HIP_VERSION}) + + # Convert cuda code to hip code in cpp + execute_process(COMMAND hipify-perl -print-stats -o dist_inference.cpp dist_inference.cu WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/) + + # link hip device lib + add_executable(dist_inference dist_inference.cpp) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DROCM_USE_FLOAT16=1") + target_link_libraries(dist_inference MPI::MPI_CXX rccl hipblaslt hip::device) + else() + message(FATAL_ERROR "No CUDA or ROCm environment found.") + endif() +endif() + +install(TARGETS dist_inference RUNTIME DESTINATION bin) diff --git a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu new file mode 100644 index 00000000..2f4e798c --- /dev/null +++ b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu @@ -0,0 +1,455 @@ +/******************************************************************************* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT License. + * + * MIT License + * + * Copyright (C) 2022-2023 Advanced Micro Devices, Inc. + * Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__HIP_PLATFORM_AMD__) +#include +#include +using cublasLtHalf = hipblasLtHalf; +#else +#include +#include +#include +using cublasLtHalf = half; +#endif + +#ifndef CHECK_CUDA_ERROR +#define CHECK_CUDA_ERROR(error) \ + if (error != cudaSuccess) { \ + fprintf(stderr, "Cuda error: '%s'(%d) at %s:%d\n", cudaGetErrorString(error), error, __FILE__, __LINE__); \ + exit(-1); \ + } +#endif + +#ifndef CHECK_CUBLASLT_ERROR +#define CHECK_CUBLASLT_ERROR(error) \ + if (error != CUBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "cuBLASLt error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \ + fprintf(stderr, "\n"); \ + exit(-1); \ + } +#endif + +#ifndef CHECK_NCCL_ERROR +#define CHECK_NCCL_ERROR(error) \ + if (error != ncclSuccess) { \ + fprintf(stderr, "NCCL error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \ + fprintf(stderr, "\n"); \ + exit(-1); \ + } +#endif + +static void ShowUsage(char *argv[]) { + std::cerr << "Usage: " << argv[0] << " \n" + << "options:\n" + << "\t-h, --help\t\t\t\tShow this help message\n" + << "\t-m \t\t\tm\t\tGEMM_STRIDED argument m\n" + << "\t-n \t\t\tn\t\tGEMM_STRIDED argument n\n" + << "\t-k \t\t\tk \t\tGEMM_STRIDED argument k\n" + << "\t--alpha \t\talpha \t\tGEMM_STRIDED argument alpha\n" + << "\t--beta \t\t\tbeta \t\tGEMM_STRIDED argument beta\n" + << "\t--num_layers \t\t\tnum_layers \t\tNumber of layers in the model\n" + << "\t--num_warmups \t\t\tnum_warmups \t\tNumber of warmup runs\n" + << "\t--num_iters \t\t\tnum_iters \t\tNumber of test runs\n" + << "\t--use_cuda_graph \t\t\tuse_cuda_graph \t\tWhether to launch kernels in CUDA graph mode\n" + << std::endl; +} + +static int ParseArguments(int argc, char *argv[], int64_t *m, int64_t *n, int64_t *k, float *alpha, float *beta, + int32_t *num_layers, int32_t *num_warmups, int32_t *num_iters, bool *use_cuda_graph) { + if (argc >= 2) { + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + + if ((arg.at(0) == '-') || ((arg.at(0) == '-') && (arg.at(1) == '-'))) { + if ((arg == "-h") || (arg == "--help")) { + return -1; + } else if ((arg == "-m") && (i + 1 < argc)) { + *m = atoi(argv[++i]); + } else if ((arg == "-n") && (i + 1 < argc)) { + *n = atoi(argv[++i]); + } else if ((arg == "-k") && (i + 1 < argc)) { + *k = atoi(argv[++i]); + } else if ((arg == "--alpha") && (i + 1 < argc)) { + *alpha = atof(argv[++i]); + } else if ((arg == "--beta") && (i + 1 < argc)) { + *beta = atof(argv[++i]); + } else if ((arg == "--num_layers") && (i + 1 < argc)) { + *num_layers = atoi(argv[++i]); + } else if ((arg == "--num_warmups") && (i + 1 < argc)) { + *num_warmups = atoi(argv[++i]); + } else if ((arg == "--num_iters") && (i + 1 < argc)) { + *num_iters = atoi(argv[++i]); + } else if (arg == "--use_cuda_graph") { +#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310) + *use_cuda_graph = true; +#else + *use_cuda_graph = false; + std::cerr << "error with " << arg << std::endl; + std::cerr << "not supported by current environment" << std::endl << std::endl; + return -1; +#endif + } else { + std::cerr << "error with " << arg << std::endl; + std::cerr << "do not recognize option" << std::endl << std::endl; + return -1; + } + } else { + std::cerr << "error with " << arg << std::endl; + std::cerr << "option must start with - or --" << 
std::endl << std::endl; + return -1; + } + } + } + return 0; +} + +void InitializeABCDEF(std::vector &ha, int64_t size_a, std::vector &hb, int64_t size_b, + std::vector &hc, int64_t size_c, std::vector &hd, int64_t size_d, + std::vector &he, int64_t size_e, std::vector &hf, int64_t size_f) { + srand(1); + for (int i = 0; i < size_a; ++i) { + ha[i] = static_cast((rand() % 7) - 3); + } + for (int i = 0; i < size_b; ++i) { + hb[i] = static_cast((rand() % 7) - 3); + } + for (int i = 0; i < size_c; ++i) { + hc[i] = static_cast((rand() % 7) - 3); + } + for (int i = 0; i < size_d; ++i) { + hd[i] = static_cast((rand() % 7) - 3); + } + for (int i = 0; i < size_e; ++i) { + he[i] = static_cast((rand() % 7) - 3); + } + for (int i = 0; i < size_f; ++i) { + hf[i] = static_cast((rand() % 7) - 3); + } +} + +// B[m, k] * A[k, n] + C[m, n] = D[m, n] +// E[k, m] * D[m, n] + F[k, n] = G[k, n] +void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t num_layers, int32_t num_warmups, + int32_t num_iters, bool use_cuda_graph, ncclComm_t nccl_comm) { + const int kNcclBufAlignment = 512; + + int size_a = k * n; + int size_b = m * k; + int size_c = m * n; + int size_d = m * n; + int size_e = k * m; + int size_f = k * n; + int size_g = (k * n + kNcclBufAlignment - 1) / kNcclBufAlignment * kNcclBufAlignment; + + // Naming: da is in GPU (device) memory. ha is in CPU (host) memory + std::vector ha(size_a); + std::vector hb(size_b); + std::vector hc(size_c); + std::vector hd(size_d); + std::vector he(size_e); + std::vector hf(size_f); + std::vector hg(size_g); + + // initial data on host + InitializeABCDEF(ha, size_a, hb, size_b, hc, size_c, hd, size_d, he, size_e, hf, size_f); + + // allocate memory on device + void *da, *db, *dc, *dd, *de, *df, *dg; + + // Create stream + cudaStream_t stream = nullptr; + CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + + CHECK_CUDA_ERROR(cudaMalloc(&da, size_a * sizeof(cublasLtHalf))); + CHECK_CUDA_ERROR(cudaMalloc(&db, size_b * sizeof(cublasLtHalf))); + CHECK_CUDA_ERROR(cudaMalloc(&dc, size_c * sizeof(cublasLtHalf))); + CHECK_CUDA_ERROR(cudaMalloc(&dd, size_d * sizeof(cublasLtHalf))); + CHECK_CUDA_ERROR(cudaMalloc(&de, size_e * sizeof(cublasLtHalf))); + CHECK_CUDA_ERROR(cudaMalloc(&df, size_f * sizeof(cublasLtHalf))); + CHECK_CUDA_ERROR(cudaMalloc(&dg, size_g * sizeof(cublasLtHalf))); + // copy matrices from host to device + CHECK_CUDA_ERROR(cudaMemcpy(da, ha.data(), sizeof(cublasLtHalf) * size_a, cudaMemcpyHostToDevice)); + CHECK_CUDA_ERROR(cudaMemcpy(db, hb.data(), sizeof(cublasLtHalf) * size_b, cudaMemcpyHostToDevice)); + CHECK_CUDA_ERROR(cudaMemcpy(dc, hc.data(), sizeof(cublasLtHalf) * size_c, cudaMemcpyHostToDevice)); + CHECK_CUDA_ERROR(cudaMemcpy(dd, hd.data(), sizeof(cublasLtHalf) * size_d, cudaMemcpyHostToDevice)); + CHECK_CUDA_ERROR(cudaMemcpy(de, he.data(), sizeof(cublasLtHalf) * size_e, cudaMemcpyHostToDevice)); + CHECK_CUDA_ERROR(cudaMemcpy(df, hf.data(), sizeof(cublasLtHalf) * size_f, cudaMemcpyHostToDevice)); + + uint64_t workspace_size = 1024 * 1024; + void *d_workspace; + CHECK_CUDA_ERROR(cudaMalloc(&d_workspace, workspace_size)); + int returnedAlgoCount = 0; + + // cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead. 
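Before the platform-specific cuBLASLt setup below, it helps to spell out the arithmetic each timed layer performs. Per the shape comments above `TestModel`, one layer is two half-precision GEMMs followed by an allreduce of the second result across all ranks. The numpy sketch below is only a reference for those shapes and for the NCCL buffer padding; numpy is not used by the benchmark, and the random values merely draw from the same [-3, 3] range as `InitializeABCDEF` rather than reproducing it exactly.

```python
# Reference for one benchmark "layer" (shapes follow the comments in dist_inference.cu):
#   B[m, k] * A[k, n] + C[m, n] = D[m, n]
#   E[k, m] * D[m, n] + F[k, n] = G[k, n]
# followed by ncclAllReduce(sum) over G across ranks.
import numpy as np

m, n, k = 80, 128, 128        # binary defaults, as echoed in tests/data/dist_inference.log
alpha, beta = 1.0, 1.0
rng = np.random.default_rng(1)

A = rng.integers(-3, 4, size=(k, n)).astype(np.float16)
B = rng.integers(-3, 4, size=(m, k)).astype(np.float16)
C = rng.integers(-3, 4, size=(m, n)).astype(np.float16)
E = rng.integers(-3, 4, size=(k, m)).astype(np.float16)
F = rng.integers(-3, 4, size=(k, n)).astype(np.float16)

D = alpha * (B @ A) + beta * C    # first GEMM (half-precision storage on device)
G = alpha * (E @ D) + beta * F    # second GEMM; G is the tensor that gets allreduced

# The device buffer holding G is padded so its element count is a multiple of 512
# (kNcclBufAlignment) before being passed to ncclAllReduce.
kNcclBufAlignment = 512
size_g = (k * n + kNcclBufAlignment - 1) // kNcclBufAlignment * kNcclBufAlignment
print(D.shape, G.shape, size_g)   # (80, 128) (128, 128) 16384
```

Note also that the accumulation precision differs between the two platform branches that follow: the ROCm branch requests `HIPBLASLT_COMPUTE_F32`, while the CUDA branch requests `CUBLAS_COMPUTE_16F`.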
+#if defined(__HIP_PLATFORM_AMD__) + hipblasLtHandle_t handle; + hipblasLtMatrixLayout_t matA, matB, matC, matD, matE, matF, matG; + hipblasLtMatmulDesc_t matmul1, matmul2; + hipblasLtMatmulPreference_t pref; + + CHECK_CUBLASLT_ERROR(hipblasLtCreate(&handle)); + + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, HIPBLAS_R_16F, k, n, k)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, HIPBLAS_R_16F, m, k, m)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIPBLAS_R_16F, m, n, m)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIPBLAS_R_16F, m, n, m)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matE, HIPBLAS_R_16F, k, m, k)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, HIPBLAS_R_16F, k, n, k)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, HIPBLAS_R_16F, k, n, k)); + + CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F)); + CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F)); + + hipblasOperation_t trans = HIPBLAS_OP_N; + CHECK_CUBLASLT_ERROR( + hipblasLtMatmulDescSetAttribute(matmul1, HIPBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t))); + CHECK_CUBLASLT_ERROR( + hipblasLtMatmulDescSetAttribute(matmul1, HIPBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t))); + CHECK_CUBLASLT_ERROR( + hipblasLtMatmulDescSetAttribute(matmul2, HIPBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t))); + CHECK_CUBLASLT_ERROR( + hipblasLtMatmulDescSetAttribute(matmul2, HIPBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t))); + + // Set User Preference attributes + CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref)); + CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceSetAttribute(pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, sizeof(workspace_size))); + + // Get Heuristic results + hipblasLtMatmulHeuristicResult_t heuristicResult1[1] = {0}; + hipblasLtMatmulHeuristicResult_t heuristicResult2[1] = {0}; + // B[m, k] * A[k, n] + C[m, n] = D[m, n] + // E[k, m] * D[m, n] + F[k, n] = G[k, n] + CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul1, matB, matA, matC, matD, pref, 1, + heuristicResult1, &returnedAlgoCount)); + CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul2, matE, matD, matF, matG, pref, 1, + heuristicResult2, &returnedAlgoCount)); +#else + cublasLtHandle_t handle; + cublasLtMatrixLayout_t matA, matB, matC, matD, matE, matF, matG; + cublasLtMatmulDesc_t matmul1, matmul2; + cublasLtMatmulPreference_t pref; + CHECK_CUBLASLT_ERROR(cublasLtCreate(&handle)); + + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matA, CUDA_R_16F, k, n, k)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matB, CUDA_R_16F, m, k, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matC, CUDA_R_16F, m, n, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matD, CUDA_R_16F, m, n, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matE, CUDA_R_16F, k, m, k)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matF, CUDA_R_16F, k, n, k)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matG, CUDA_R_16F, k, n, k)); + + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul1, CUBLAS_COMPUTE_16F, CUDA_R_32F)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul2, CUBLAS_COMPUTE_16F, CUDA_R_32F)); + + cublasOperation_t trans = CUBLAS_OP_N; + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul1, CUBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t))); + 
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul1, CUBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t))); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul2, CUBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t))); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul2, CUBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t))); + + // Set User Preference attributes + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceCreate(&pref)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, sizeof(workspace_size))); + + // Get Heuristic results + cublasLtMatmulHeuristicResult_t heuristicResult1[1] = {0}; + cublasLtMatmulHeuristicResult_t heuristicResult2[1] = {0}; + // B[m, k] * A[k, n] + C[m, n] = D[m, n] + // E[k, m] * D[m, n] + F[k, n] = G[k, n] + CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, matmul1, matB, matA, matC, matD, pref, 1, + heuristicResult1, &returnedAlgoCount)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, matmul2, matE, matD, matF, matG, pref, 1, + heuristicResult2, &returnedAlgoCount)); +#endif + + auto model_forward = [&] { + for (int j = 0; j < num_layers; j++) { + // B[m, k] * A[k, n] + C[m, n] = D[m, n] + // E[k, m] * D[m, n] + F[k, n] = G[k, n] + // cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead. +#if defined(__HIP_PLATFORM_AMD__) + CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul1, &alpha, db, matB, da, matA, &beta, dc, matC, dd, matD, + &heuristicResult1[0].algo, d_workspace, workspace_size, stream)); + CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul1, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG, + &heuristicResult2[0].algo, d_workspace, workspace_size, stream)); +#else + CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, matmul1, &alpha, db, matB, da, matA, &beta, dc, matC, dd, matD, + &heuristicResult1[0].algo, d_workspace, workspace_size, stream)); + CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, matmul1, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG, + &heuristicResult2[0].algo, d_workspace, workspace_size, stream)); +#endif + CHECK_NCCL_ERROR(ncclAllReduce(dg, dg, size_g, ncclFloat16, ncclSum, nccl_comm, stream)); + } + }; + +#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310) + cudaGraph_t graph; + cudaGraphExec_t instance; + if (use_cuda_graph) { + CHECK_CUDA_ERROR(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal)); + model_forward(); + CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &graph)); + CHECK_CUDA_ERROR(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0)); + } +#endif + + std::chrono::steady_clock::time_point start_time, stop_time; + for (int i = 0; i < num_warmups + num_iters; ++i) { + if (i == num_warmups) { + start_time = std::chrono::steady_clock::now(); + } +#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310) + if (use_cuda_graph) { + CHECK_CUDA_ERROR(cudaGraphLaunch(instance, stream)); + } else { + model_forward(); + } +#else + model_forward(); +#endif + CHECK_CUDA_ERROR(cudaStreamSynchronize(stream)); + } + stop_time = std::chrono::steady_clock::now(); + double duration = std::chrono::duration_cast(stop_time - start_time).count(); + fprintf(stdout, "Time: %g ms in total, %g ms per iteration, %g ms per layer\n", duration, duration / num_iters, + duration / num_iters / num_layers); + +#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && 
NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310) + // Destroy graph + if (use_cuda_graph) { + CHECK_CUDA_ERROR(cudaGraphExecDestroy(instance)); + CHECK_CUDA_ERROR(cudaGraphDestroy(graph)); + } +#endif + + // Destroy stream + CHECK_CUDA_ERROR(cudaStreamDestroy(stream)); + + CHECK_CUDA_ERROR(cudaFree(da)); + CHECK_CUDA_ERROR(cudaFree(db)); + CHECK_CUDA_ERROR(cudaFree(dc)); + CHECK_CUDA_ERROR(cudaFree(dd)); + CHECK_CUDA_ERROR(cudaFree(de)); + CHECK_CUDA_ERROR(cudaFree(df)); + CHECK_CUDA_ERROR(cudaFree(dg)); + CHECK_CUDA_ERROR(cudaFree(d_workspace)); + // cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead. +#if defined(__HIP_PLATFORM_AMD__) + CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceDestroy(pref)); + CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescDestroy(matmul1)); + CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescDestroy(matmul2)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matA)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matB)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matC)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matD)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matE)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matF)); + CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matG)); + CHECK_CUBLASLT_ERROR(hipblasLtDestroy(handle)); +#else + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceDestroy(pref)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul1)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul2)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matA)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matB)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matC)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matD)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matE)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matF)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matG)); + CHECK_CUBLASLT_ERROR(cublasLtDestroy(handle)); +#endif + + return; +} + +int main(int argc, char *argv[]) { + // Init MPI + int comm_rank, comm_size; + MPI_Init(NULL, NULL); + MPI_Comm_rank(MPI_COMM_WORLD, &comm_rank); + MPI_Comm_size(MPI_COMM_WORLD, &comm_size); + + // Init NCCL + int num_local_ranks = 0; + ncclComm_t nccl_comm; + ncclUniqueId nccl_id; + if (comm_rank == 0) { + CHECK_NCCL_ERROR(ncclGetUniqueId(&nccl_id)); + } + MPI_Bcast(&nccl_id, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD); + CHECK_CUDA_ERROR(cudaGetDeviceCount(&num_local_ranks)) + CHECK_CUDA_ERROR(cudaSetDevice(comm_rank % num_local_ranks)); + CHECK_NCCL_ERROR(ncclCommInitRank(&nccl_comm, comm_size, nccl_id, comm_rank)); + + // Init parameters with default values + int64_t m = 80; + int64_t n = 128; + int64_t k = 128; + float alpha = 1; + float beta = 1; + int32_t num_layers = 50; + int32_t num_warmups = 20; + int32_t num_iters = 100; + bool use_cuda_graph = false; + + if (ParseArguments(argc, argv, &m, &n, &k, &alpha, &beta, &num_layers, &num_warmups, &num_iters, &use_cuda_graph)) { + ShowUsage(argv); + return -1; + } + + fprintf(stdout, + "Parameters: m=%ld, n=%ld, k=%ld, alpha=%f, beta=%f, num_layers=%d, num_warmups=%d, num_iters=%d, " + "use_cuda_graph=%d\n", + m, n, k, alpha, beta, num_layers, num_warmups, num_iters, (int)use_cuda_graph); + + TestModel(m, n, k, alpha, beta, num_layers, num_warmups, num_iters, use_cuda_graph, nccl_comm); + + CHECK_NCCL_ERROR(ncclCommDestroy(nccl_comm)); + + MPI_Finalize(); + + return 0; +} diff --git a/superbench/config/azure_ndmv4.yaml 
b/superbench/config/azure_ndmv4.yaml index ccef356c..4914780a 100644 --- a/superbench/config/azure_ndmv4.yaml +++ b/superbench/config/azure_ndmv4.yaml @@ -33,6 +33,8 @@ superbench: node_num: 1 env: NCCL_ASYNC_ERROR_HANDLING: '0' + frameworks: + - pytorch common_model_config: &common_model_config duration: 0 num_warmup: 64 diff --git a/superbench/config/azure_ndv4.yaml b/superbench/config/azure_ndv4.yaml index 9af82655..827626af 100644 --- a/superbench/config/azure_ndv4.yaml +++ b/superbench/config/azure_ndv4.yaml @@ -29,6 +29,8 @@ superbench: node_num: 1 env: NCCL_ASYNC_ERROR_HANDLING: '0' + frameworks: + - pytorch common_model_config: &common_model_config duration: 0 num_warmup: 64 diff --git a/tests/benchmarks/micro_benchmarks/test_dist_inference.py b/tests/benchmarks/micro_benchmarks/test_dist_inference.py index 1c27bc15..e24ec341 100644 --- a/tests/benchmarks/micro_benchmarks/test_dist_inference.py +++ b/tests/benchmarks/micro_benchmarks/test_dist_inference.py @@ -3,12 +3,15 @@ """Tests for distributed inference benchmark.""" +import numbers import unittest from tests.helper import decorator +from tests.helper.testcase import BenchmarkTestCase import tests.benchmarks.utils as utils from superbench.benchmarks \ - import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode, Precision, DistributedImpl, DistributedBackend + import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode, Precision, DistributedImpl, DistributedBackend, \ + Platform from superbench.benchmarks.micro_benchmarks.dist_inference \ import DistInference, ComputationKernelType, CommunicationKernelType, ActivationKernelType from superbench.common.utils import network @@ -20,7 +23,9 @@ from superbench.common.utils import network @decorator.pytorch_test def test_pytorch_dist_inference_normal(): """Test pytorch-dist-inference benchmark on distributed normal case.""" - context = BenchmarkRegistry.create_benchmark_context('dist-inference', parameters='', framework=Framework.PYTORCH) + context = BenchmarkRegistry.create_benchmark_context( + 'dist-inference', parameters='--use_pytorch', framework=Framework.PYTORCH + ) world_size = 2 assert (BenchmarkRegistry.is_benchmark_context_valid(context)) results = utils.simulated_ddp_distributed_benchmark(context, world_size) @@ -33,9 +38,12 @@ def test_pytorch_dist_inference_normal(): assert (benchmark.type == BenchmarkType.MICRO) # Check predefined parameters of dist-inference benchmark. + assert (benchmark._args.use_pytorch is True) assert (benchmark._args.batch_size == 64) assert (benchmark._args.input_size == 1024) assert (benchmark._args.hidden_size == 1024) + assert (benchmark._args.alpha == 1.0) + assert (benchmark._args.beta == 1.0) assert (benchmark._args.num_layers == 1) assert (benchmark._args.computation_kernel == ComputationKernelType.MATMUL) assert (benchmark._args.communication_kernel == CommunicationKernelType.ALLREDUCE) @@ -45,6 +53,7 @@ def test_pytorch_dist_inference_normal(): assert (benchmark._args.num_steps == 10000) assert (benchmark._args.distributed_impl == DistributedImpl.DDP) assert (benchmark._args.distributed_backend == DistributedBackend.NCCL) + assert (benchmark._args.use_cuda_graph is False) # Check results and metrics. 
assert (benchmark.run_count == 1) @@ -52,14 +61,16 @@ def test_pytorch_dist_inference_normal(): # step_times assert (len(benchmark.raw_data) == 1) # return code + (avg, 50th, 90th, 95th, 99th, 99.9th) - assert (len(benchmark.result) == 7) + assert (7 == len(benchmark.result)) @decorator.cuda_test @decorator.pytorch_test def test_pytorch_dist_inference_fake_distributed(): """Test pytorch-dist-inference benchmark on single gpu.""" - context = BenchmarkRegistry.create_benchmark_context('dist-inference', parameters='', framework=Framework.PYTORCH) + context = BenchmarkRegistry.create_benchmark_context( + 'dist-inference', parameters='--use_pytorch', framework=Framework.PYTORCH + ) port = network.get_free_port() assert (port) utils.setup_simulated_ddp_distributed_env(1, 0, port) @@ -72,9 +83,12 @@ def test_pytorch_dist_inference_fake_distributed(): assert (benchmark.type == BenchmarkType.MICRO) # Check predefined parameters of dist-inference benchmark. + assert (benchmark._args.use_pytorch is True) assert (benchmark._args.batch_size == 64) assert (benchmark._args.input_size == 1024) assert (benchmark._args.hidden_size == 1024) + assert (benchmark._args.alpha == 1.0) + assert (benchmark._args.beta == 1.0) assert (benchmark._args.num_layers == 1) assert (benchmark._args.computation_kernel == ComputationKernelType.MATMUL) assert (benchmark._args.communication_kernel == CommunicationKernelType.ALLREDUCE) @@ -84,6 +98,7 @@ def test_pytorch_dist_inference_fake_distributed(): assert (benchmark._args.num_steps == 10000) assert (benchmark._args.distributed_impl == DistributedImpl.DDP) assert (benchmark._args.distributed_backend == DistributedBackend.NCCL) + assert (benchmark._args.use_cuda_graph is False) # Check results and metrics. assert (benchmark.run_count == 1) @@ -94,3 +109,127 @@ def test_pytorch_dist_inference_fake_distributed(): assert (len(benchmark.result) == 7) utils.clean_simulated_ddp_distributed_env() + + +class DistInferenceCppImplTest(BenchmarkTestCase, unittest.TestCase): + """Test class for pytorch-dist-inference benchmark.""" + @classmethod + def setUpClass(cls): + """Hook method for setting up class fixture before running tests in the class.""" + super().setUpClass() + cls.createMockEnvs(cls) + cls.createMockFiles(cls, ['bin/dist_inference']) + + def _test_dist_inference_command_generation(self, platform): + """Test pytorch-dist-inference cpp impl benchmark command generation.""" + benchmark_name = 'pytorch-dist-inference' + (benchmark_class, + predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, platform) + assert (benchmark_class) + + batch_size = 1 + input_size = 2 + hidden_size = 3 + alpha = 4.0 + beta = 5.0 + num_layers = 6 + num_warmup = 7 + num_steps = 8 + wrapper_params_format_str = \ + '--batch_size %d --input_size %d --hidden_size %d ' \ + '--alpha %g --beta %g --num_layers %d --num_warmup %d --num_steps %d --use_cuda_graph' + parameters = wrapper_params_format_str % ( + batch_size, input_size, hidden_size, alpha, beta, num_layers, num_warmup, num_steps + ) + benchmark = benchmark_class(benchmark_name, parameters=parameters) + + # Check basic information + assert (benchmark) + ret = benchmark._preprocess() + assert (ret is True) + assert (benchmark.return_code == ReturnCode.SUCCESS) + assert (benchmark.name == benchmark_name) + assert (benchmark.type == BenchmarkType.MICRO) + + # Check parameters specified in BenchmarkContext. 
+ assert (benchmark._args.use_pytorch is False) + assert (benchmark._args.batch_size == batch_size) + assert (benchmark._args.input_size == input_size) + assert (benchmark._args.hidden_size == hidden_size) + assert (benchmark._args.alpha == alpha) + assert (benchmark._args.beta == beta) + assert (benchmark._args.num_layers == num_layers) + assert (benchmark._args.num_warmup == num_warmup) + assert (benchmark._args.num_steps == num_steps) + assert (benchmark._args.use_cuda_graph is True) + + # Check command + assert (1 == len(benchmark._commands)) + for cmd in benchmark._commands: + m, n, k = hidden_size, batch_size, input_size + bench_params_format_str = \ + '%s -m %d -n %d -k %d --alpha %g --beta %g ' + \ + '--num_layers %d --num_warmups %d --num_iters %d --use_cuda_graph' + assert ( + cmd == ( + bench_params_format_str % + (benchmark._DistInference__bin_path, m, n, k, alpha, beta, num_layers, num_warmup, num_steps) + ) + ) + + @decorator.cuda_test + def test_dist_inference_command_generation_cuda(self): + """Test pytorch-dist-inference cpp impl benchmark command generation, CUDA case.""" + self._test_dist_inference_command_generation(Platform.CUDA) + + @decorator.rocm_test + def test_dist_inference_command_generation_rocm(self): + """Test pytorch-dist-inference cpp impl benchmark command generation, ROCm case.""" + self._test_dist_inference_command_generation(Platform.ROCM) + + @decorator.load_data('tests/data/dist_inference.log') + def _test_dist_inference_result_parsing(self, platform, test_raw_output): + """Test pytorch-dist-inference cpp impl benchmark result parsing.""" + benchmark_name = 'pytorch-dist-inference' + (benchmark_class, + predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, platform) + assert (benchmark_class) + benchmark = benchmark_class(benchmark_name, parameters='') + assert (benchmark) + ret = benchmark._preprocess() + assert (ret is True) + assert (benchmark.return_code == ReturnCode.SUCCESS) + assert (benchmark.name == 'pytorch-dist-inference') + assert (benchmark.type == BenchmarkType.MICRO) + + # Positive case - valid raw output. + assert (benchmark._process_raw_result(0, test_raw_output)) + assert (benchmark.return_code == ReturnCode.SUCCESS) + + # step_times + assert (len(benchmark.raw_data) == 2) + # return code + (avg, 50th, 90th, 95th, 99th, 99.9th) + test_latency = float(test_raw_output.splitlines()[-1].split(' ms per iteration')[0].split()[-1]) + assert (7 == len(benchmark.result)) + for output_key in benchmark.result: + if output_key == 'return_code': + assert (benchmark.result[output_key] == [0]) + else: + assert (output_key.startswith('step_times')) + assert (len(benchmark.result[output_key]) == 1) + assert (isinstance(benchmark.result[output_key][0], numbers.Number)) + assert (test_latency == benchmark.result[output_key][0]) + + # Negative case - invalid raw output. 
+ assert (benchmark._process_raw_result(1, 'Invalid raw output') is False) + assert (benchmark.return_code == ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE) + + @decorator.cuda_test + def test_dist_inference_result_parsing_cuda(self): + """Test pytorch-dist-inference cpp impl benchmark result parsing, CUDA case.""" + self._test_dist_inference_result_parsing(Platform.CUDA) + + @decorator.rocm_test + def test_dist_inference_result_parsing_rocm(self): + """Test pytorch-dist-inference cpp impl benchmark result parsing, ROCm case.""" + self._test_dist_inference_result_parsing(Platform.ROCM) diff --git a/tests/data/dist_inference.log b/tests/data/dist_inference.log new file mode 100644 index 00000000..14b104cf --- /dev/null +++ b/tests/data/dist_inference.log @@ -0,0 +1,2 @@ +Parameters: m=80, n=128, k=128, alpha=1.000000, beta=1.000000, num_layers=50, num_warmups=20, num_iters=100, use_cuda_graph=0 +Time: 173 ms in total, 1.73 ms per iteration, 0.0346 ms per layer
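For reference, the two lines in `tests/data/dist_inference.log` above are exactly what the wrapper's `_process_raw_result()` consumes: it scans for the " ms per iteration" marker and takes the preceding token as the step latency. The snippet below is a hedged, standalone rendering of that parsing logic, using the sample log above as `raw_output`; the actual wrapper additionally records the raw output and reduces `step_times` with MAX across ranks before producing the avg/percentile summaries checked in the tests. The binary itself initializes MPI and binds each rank to a GPU via `comm_rank % num_local_ranks`, so it is meant to be launched with one process per GPU under an MPI launcher such as `mpirun`.

```python
# Standalone sketch of the latency extraction done by _process_raw_result().
raw_output = (
    'Parameters: m=80, n=128, k=128, alpha=1.000000, beta=1.000000, '
    'num_layers=50, num_warmups=20, num_iters=100, use_cuda_graph=0\n'
    'Time: 173 ms in total, 1.73 ms per iteration, 0.0346 ms per layer\n'
)

step_time = None
for line in raw_output.strip().splitlines():
    if ' ms per iteration' in line:
        # Take the token immediately before ' ms per iteration', e.g. '1.73'.
        step_time = float(line.split(' ms per iteration')[0].split()[-1])
        break

assert step_time == 1.73  # reported as the 'step_times' metric
```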