Benchmarks: Microbenchmark - Add distributed inference benchmark cpp implementation (#586)

**Description**
Add distributed inference benchmark cpp implementation.
Ziyue Yang authored 2023-12-11 06:53:51 +08:00, committed by GitHub
Parent: 1f5031bd74
Commit: 719a427fe7
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
8 changed files: 791 additions and 64 deletions
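For orientation, here is a minimal usage sketch (not part of the diff) of how the two implementation paths added by this commit are selected through the benchmark registry. The create_benchmark_context calls mirror the updated unit tests below; the command-line form in the comment reflects the arguments the cpp path assembles in _preprocess, with placeholders rather than real default values.

# Minimal sketch: choosing between the PyTorch and cpp paths of dist-inference.
from superbench.benchmarks import BenchmarkRegistry, Framework

# PyTorch implementation path: pass --use_pytorch to run the torch.distributed model.
pytorch_context = BenchmarkRegistry.create_benchmark_context(
    'dist-inference', parameters='--use_pytorch', framework=Framework.PYTORCH
)

# cpp implementation path (default): the wrapper invokes the compiled dist_inference
# binary with a command of the form
#   dist_inference -m <hidden_size> -n <batch_size> -k <input_size> --alpha <alpha> --beta <beta> \
#     --num_layers <num_layers> --num_warmups <num_warmup> --num_iters <num_steps> [--use_cuda_graph]
cpp_context = BenchmarkRegistry.create_benchmark_context(
    'dist-inference', parameters='', framework=Framework.PYTORCH
)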


@@ -418,7 +418,7 @@ Test the performance of large scale matmul operation with multiple GPUs:
 #### Introduction
-Test the performance of distributed model inference.
+Test the performance of distributed model inference. Support both PyTorch implementation and cpp implementation.
 #### Metrics


@@ -12,7 +12,7 @@ import torch.distributed as dist
 from superbench.common.utils import logger
 from superbench.benchmarks import DistributedImpl, DistributedBackend, BenchmarkRegistry, ReturnCode, Precision
-from superbench.benchmarks.micro_benchmarks import MicroBenchmark
+from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke
 from superbench.benchmarks.context import Enum
 from superbench.benchmarks.reducer import ReduceType
@@ -168,7 +168,7 @@ class DistInferenceModel(torch.nn.Module):
         return activation_out

-class DistInference(MicroBenchmark):
+class DistInference(MicroBenchmarkWithInvoke):
     """The base class of micro-benchmarks."""
     def __init__(self, name, parameters=''):
         """Constructor.
@@ -182,7 +182,9 @@ class DistInference(MicroBenchmark):
         self.__local_rank = 0
         torch.backends.cudnn.benchmark = True
         self.__device = None
-        self.__cuda_available = False
+
+        # For cpp impl path
+        self._bin_name = 'dist_inference'

     def __timer(self):
         """Returns the current time which ensures all previous CUDA events have been finished.
@@ -193,14 +195,19 @@ class DistInference(MicroBenchmark):
         Return:
             Current time in second.
         """
-        if self.__cuda_available:
-            torch.cuda.synchronize()
+        torch.cuda.synchronize()
         return time.time()

     def add_parser_arguments(self):
         """Add the specified arguments."""
         super().add_parser_arguments()

+        self._parser.add_argument(
+            '--use_pytorch',
+            action='store_true',
+            required=False,
+            help='Whether to use pytorch implementation. If not, cpp implementation will be used.',
+        )
         self._parser.add_argument(
             '--batch_size',
             type=int,
@@ -222,6 +229,20 @@ class DistInference(MicroBenchmark):
             required=False,
             help='Hidden size.',
         )
+        self._parser.add_argument(
+            '--alpha',
+            type=float,
+            default=1.0,
+            required=False,
+            help='Coefficient alpha in D = alpha*(A*B) + beta*(C).',
+        )
+        self._parser.add_argument(
+            '--beta',
+            type=float,
+            default=1.0,
+            required=False,
+            help='Coefficient beta in D = alpha*(A*B) + beta*(C).',
+        )
         self._parser.add_argument(
             '--num_layers',
             type=int,
@@ -285,6 +306,12 @@ class DistInference(MicroBenchmark):
             required=False,
             help='Distributed backends. E.g. {}.'.format(' '.join(DistributedBackend.get_values())),
         )
+        self._parser.add_argument(
+            '--use_cuda_graph',
+            action='store_true',
+            required=False,
+            help='Whether to launch kernels in CUDA graph mode.',
+        )

     def _preprocess(self):
         """Preprocess/preparation operations before the benchmarking.
@@ -295,32 +322,41 @@ class DistInference(MicroBenchmark):
         if not super()._preprocess():
             return False

-        if self._args.distributed_impl != DistributedImpl.DDP:
-            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
-            logger.error(
-                'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
-                    self._name, self._args.distributed_impl
+        if self._args.use_pytorch:
+            # Initialize PyTorch if pytorch impl path
+            if self._args.distributed_impl != DistributedImpl.DDP:
+                return self._set_error_code_and_print_error_msg(
+                    ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE,
+                    'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
+                        self._name, self._args.distributed_impl
+                    )
                 )
-            )
-            return False

-        try:
-            torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
-            self.__world_size = int(os.environ['WORLD_SIZE'])
-            self.__local_rank = int(os.environ['LOCAL_RANK'])
-        except BaseException as e:
-            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
-            torch.distributed.destroy_process_group()
-            logger.error('Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e)))
-            return False
+            try:
+                torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
+                self.__world_size = int(os.environ['WORLD_SIZE'])
+                self.__local_rank = int(os.environ['LOCAL_RANK'])
+                assert (torch.cuda.is_available())
+            except BaseException as e:
+                torch.distributed.destroy_process_group()
+                return self._set_error_code_and_print_error_msg(
+                    ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE,
+                    'Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e))
+                )

-        if torch.cuda.is_available():
             torch.cuda.set_device(self.__local_rank)
             self.__device = torch.device('cuda:{}'.format(self.__local_rank))
-            self.__cuda_available = True
         else:
-            self.__device = torch.device('cpu:{}'.format(self.__local_rank))
-            self.__cuda_available = False
+            # Assemble commands if cpp impl path
+            self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)
+
+            args = '-m %d -n %d -k %d' % (self._args.hidden_size, self._args.batch_size, self._args.input_size)
+            args += ' --alpha %g --beta %g' % (self._args.alpha, self._args.beta)
+            args += ' --num_layers %d --num_warmups %d --num_iters %d' % \
+                (self._args.num_layers, self._args.num_warmup, self._args.num_steps)
+            if self._args.use_cuda_graph:
+                args += ' --use_cuda_graph'
+            self._commands = ['%s %s' % (self.__bin_path, args)]

         return True
@@ -347,8 +383,7 @@ class DistInference(MicroBenchmark):
             self.__device
         )
         model = model.to(dtype=getattr(torch, precision.value))
-        if self.__cuda_available:
-            model = model.cuda()
+        model = model.cuda()
         return model

     def _run_model(self, model, batch_size, input_size, precision, device, num_warmup, num_steps):
@@ -401,38 +436,78 @@ class DistInference(MicroBenchmark):
         Return:
             True if _benchmark succeeds.
         """
-        batch_size = self._args.batch_size
-        input_size = self._args.input_size
-        hidden_size = self._args.hidden_size
-        num_layers = self._args.num_layers
-        computation = self._args.computation_kernel
-        communication = self._args.communication_kernel
-        activation = self._args.activation_kernel
-        precision = self._args.precision
-        num_warmup = self._args.num_warmup
-        num_steps = self._args.num_steps
+        if self._args.use_pytorch:
+            # Execute PyTorch model if pytorch impl path
+            batch_size = self._args.batch_size
+            input_size = self._args.input_size
+            hidden_size = self._args.hidden_size
+            num_layers = self._args.num_layers
+            computation = self._args.computation_kernel
+            communication = self._args.communication_kernel
+            activation = self._args.activation_kernel
+            precision = self._args.precision
+            num_warmup = self._args.num_warmup
+            num_steps = self._args.num_steps

-        if self.__local_rank == 0:
-            logger.info(
-                'Distributed Inference - using {} GPUs: '
-                'batch_size={}, input_size={}, hidden_size={}, num_layers={}, '
-                'computation_kernel={}, communication_kernel={}, activation_kernel={}, precision={}, '
-                'num_warmup={} num_steps={}'.format(
-                    self.__world_size, batch_size, input_size, hidden_size, num_layers, computation, communication,
-                    activation, precision, num_warmup, num_steps
-                )
-            )
+            if self.__local_rank == 0:
+                logger.info(
+                    'Distributed Inference - using {} GPUs: '
+                    'batch_size={}, input_size={}, hidden_size={}, num_layers={}, '
+                    'computation_kernel={}, communication_kernel={}, activation_kernel={}, precision={}, '
+                    'num_warmup={} num_steps={}'.format(
+                        self.__world_size, batch_size, input_size, hidden_size, num_layers, computation, communication,
+                        activation, precision, num_warmup, num_steps
+                    )
+                )

-        # Prepare model
-        model = self._prepare_model(
-            input_size, hidden_size, num_layers, computation, communication, activation, precision, self.__world_size
-        )
+            # Prepare model
+            model = self._prepare_model(
+                input_size, hidden_size, num_layers, computation, communication, activation, precision,
+                self.__world_size
+            )

-        # Run model
-        step_times = self._run_model(model, batch_size, input_size, precision, self.__device, num_warmup, num_steps)
+            # Run model
+            step_times = self._run_model(model, batch_size, input_size, precision, self.__device, num_warmup, num_steps)

-        # Process data and return
-        return self._process_data(step_times)
+            # Process data and return
+            return self._process_data(step_times)
+        else:
+            # Execute commands if cpp impl path
+            if not super()._benchmark():
+                return False
+            return True
+
+    def _process_raw_result(self, cmd_idx, raw_output):
+        """Function to parse raw results and save the summarized results.
+
+        self._result.add_raw_data() and self._result.add_result() need to be called to save the results.
+
+        Args:
+            cmd_idx (int): the index of command corresponding with the raw_output.
+            raw_output (str): raw output string of the micro-benchmark.
+
+        Return:
+            True if the raw output string is valid and result can be extracted.
+        """
+        self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
+
+        try:
+            output_lines = [x.strip() for x in raw_output.strip().splitlines()]
+            step_time = None
+            for output_line in output_lines:
+                if ' ms per iteration' in output_line:
+                    step_time = float(output_line.split(' ms per iteration')[0].split()[-1])
+                    break
+            return self._process_numeric_result(
+                'step_times', [step_time], reduce_type=ReduceType.MAX, cal_percentile=True
+            )
+        except BaseException as e:
+            return self._set_error_code_and_print_error_msg(
+                ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE,
+                'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format(
+                    self._curr_run_index, self._name, raw_output, str(e)
+                )
+            )

     def _postprocess(self):
         """Postprocess/cleanup operations after the benchmarking.
@@ -443,14 +518,26 @@ class DistInference(MicroBenchmark):
         if not super()._postprocess():
             return False

-        try:
-            torch.distributed.destroy_process_group()
-        except BaseException as e:
-            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE)
-            logger.error('Post process failed - benchmark: {}, message: {}.'.format(self._name, str(e)))
-            return False
+        if self._args.use_pytorch:
+            try:
+                torch.distributed.destroy_process_group()
+            except BaseException as e:
+                return self._set_error_code_and_print_error_msg(
+                    ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE,
+                    'Post process failed - benchmark: {}, message: {}.'.format(self._name, str(e))
+                )

         return True

+    def _set_error_code_and_print_error_msg(self, error_code, error_msg):
+        """Set error code and print error log upon error.
+
+        Return:
+            False, representing error.
+        """
+        self._result.set_return_code(error_code)
+        logger.error(error_msg)
+        return False
+

 BenchmarkRegistry.register_benchmark('pytorch-dist-inference', DistInference, parameters='')
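As a rough, standalone illustration of the raw-output contract between the new _process_raw_result above and the dist_inference binary, the sketch below extracts the per-iteration latency from the " ms per iteration" line; it assumes only the output format shown in the sample log added by this commit.

# Standalone sketch of the ' ms per iteration' parsing performed by _process_raw_result.
def parse_step_time_ms(raw_output):
    """Return the per-iteration latency in milliseconds, or None if the line is absent."""
    for line in raw_output.strip().splitlines():
        if ' ms per iteration' in line:
            # e.g. 'Time: 173 ms in total, 1.73 ms per iteration, 0.0346 ms per layer'
            return float(line.split(' ms per iteration')[0].split()[-1])
    return None


print(parse_step_time_ms('Time: 173 ms in total, 1.73 ms per iteration, 0.0346 ms per layer'))  # -> 1.73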


@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
cmake_minimum_required(VERSION 3.18)
project(dist_inference LANGUAGES CXX)
find_package(MPI REQUIRED)
include_directories(SYSTEM ${MPI_INCLUDE_PATH})
find_package(CUDAToolkit QUIET)
# Cuda environment
if(CUDAToolkit_FOUND)
message(STATUS "Found CUDA: " ${CUDAToolkit_VERSION})
include(../cuda_common.cmake)
add_executable(dist_inference dist_inference.cu)
set_property(TARGET dist_inference PROPERTY CUDA_ARCHITECTURES ${NVCC_ARCHS_SUPPORTED})
target_link_libraries(dist_inference MPI::MPI_CXX nccl cublasLt)
else()
# ROCm environment
include(../rocm_common.cmake)
find_package(hip QUIET)
if(hip_FOUND)
message(STATUS "Found ROCm: " ${HIP_VERSION})
# Convert cuda code to hip code in cpp
execute_process(COMMAND hipify-perl -print-stats -o dist_inference.cpp dist_inference.cu WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/)
# link hip device lib
add_executable(dist_inference dist_inference.cpp)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DROCM_USE_FLOAT16=1")
target_link_libraries(dist_inference MPI::MPI_CXX rccl hipblaslt hip::device)
else()
message(FATAL_ERROR "No CUDA or ROCm environment found.")
endif()
endif()
install(TARGETS dist_inference RUNTIME DESTINATION bin)


@@ -0,0 +1,455 @@
/*******************************************************************************
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT License.
*
* MIT License
*
* Copyright (C) 2022-2023 Advanced Micro Devices, Inc.
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>
#include <functional>
#include <iostream>
#include <limits>
#include <mpi.h>
#include <string>
#include <unistd.h>
#include <vector>
#if defined(__HIP_PLATFORM_AMD__)
#include <hipblaslt/hipblaslt.h>
#include <rccl/rccl.h>
using cublasLtHalf = hipblasLtHalf;
#else
#include <cublasLt.h>
#include <cuda_fp16.h>
#include <nccl.h>
using cublasLtHalf = half;
#endif
#ifndef CHECK_CUDA_ERROR
#define CHECK_CUDA_ERROR(error) \
if (error != cudaSuccess) { \
fprintf(stderr, "Cuda error: '%s'(%d) at %s:%d\n", cudaGetErrorString(error), error, __FILE__, __LINE__); \
exit(-1); \
}
#endif
#ifndef CHECK_CUBLASLT_ERROR
#define CHECK_CUBLASLT_ERROR(error) \
if (error != CUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "cuBLASLt error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \
fprintf(stderr, "\n"); \
exit(-1); \
}
#endif
#ifndef CHECK_NCCL_ERROR
#define CHECK_NCCL_ERROR(error) \
if (error != ncclSuccess) { \
fprintf(stderr, "NCCL error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \
fprintf(stderr, "\n"); \
exit(-1); \
}
#endif
static void ShowUsage(char *argv[]) {
std::cerr << "Usage: " << argv[0] << " <options>\n"
<< "options:\n"
<< "\t-h, --help\t\t\t\tShow this help message\n"
<< "\t-m \t\t\tm\t\tGEMM_STRIDED argument m\n"
<< "\t-n \t\t\tn\t\tGEMM_STRIDED argument n\n"
<< "\t-k \t\t\tk \t\tGEMM_STRIDED argument k\n"
<< "\t--alpha \t\talpha \t\tGEMM_STRIDED argument alpha\n"
<< "\t--beta \t\t\tbeta \t\tGEMM_STRIDED argument beta\n"
<< "\t--num_layers \t\t\tnum_layers \t\tNumber of layers in the model\n"
<< "\t--num_warmups \t\t\tnum_warmups \t\tNumber of warmup runs\n"
<< "\t--num_iters \t\t\tnum_iters \t\tNumber of test runs\n"
<< "\t--use_cuda_graph \t\t\tuse_cuda_graph \t\tWhether to launch kernels in CUDA graph mode\n"
<< std::endl;
}
static int ParseArguments(int argc, char *argv[], int64_t *m, int64_t *n, int64_t *k, float *alpha, float *beta,
int32_t *num_layers, int32_t *num_warmups, int32_t *num_iters, bool *use_cuda_graph) {
if (argc >= 2) {
for (int i = 1; i < argc; ++i) {
std::string arg = argv[i];
if ((arg.at(0) == '-') || ((arg.at(0) == '-') && (arg.at(1) == '-'))) {
if ((arg == "-h") || (arg == "--help")) {
return -1;
} else if ((arg == "-m") && (i + 1 < argc)) {
*m = atoi(argv[++i]);
} else if ((arg == "-n") && (i + 1 < argc)) {
*n = atoi(argv[++i]);
} else if ((arg == "-k") && (i + 1 < argc)) {
*k = atoi(argv[++i]);
} else if ((arg == "--alpha") && (i + 1 < argc)) {
*alpha = atof(argv[++i]);
} else if ((arg == "--beta") && (i + 1 < argc)) {
*beta = atof(argv[++i]);
} else if ((arg == "--num_layers") && (i + 1 < argc)) {
*num_layers = atoi(argv[++i]);
} else if ((arg == "--num_warmups") && (i + 1 < argc)) {
*num_warmups = atoi(argv[++i]);
} else if ((arg == "--num_iters") && (i + 1 < argc)) {
*num_iters = atoi(argv[++i]);
} else if (arg == "--use_cuda_graph") {
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
*use_cuda_graph = true;
#else
*use_cuda_graph = false;
std::cerr << "error with " << arg << std::endl;
std::cerr << "not supported by current environment" << std::endl << std::endl;
return -1;
#endif
} else {
std::cerr << "error with " << arg << std::endl;
std::cerr << "do not recognize option" << std::endl << std::endl;
return -1;
}
} else {
std::cerr << "error with " << arg << std::endl;
std::cerr << "option must start with - or --" << std::endl << std::endl;
return -1;
}
}
}
return 0;
}
void InitializeABCDEF(std::vector<cublasLtHalf> &ha, int64_t size_a, std::vector<cublasLtHalf> &hb, int64_t size_b,
std::vector<cublasLtHalf> &hc, int64_t size_c, std::vector<cublasLtHalf> &hd, int64_t size_d,
std::vector<cublasLtHalf> &he, int64_t size_e, std::vector<cublasLtHalf> &hf, int64_t size_f) {
srand(1);
for (int i = 0; i < size_a; ++i) {
ha[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
for (int i = 0; i < size_b; ++i) {
hb[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
for (int i = 0; i < size_c; ++i) {
hc[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
for (int i = 0; i < size_d; ++i) {
hd[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
for (int i = 0; i < size_e; ++i) {
he[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
for (int i = 0; i < size_f; ++i) {
hf[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
}
// B[m, k] * A[k, n] + C[m, n] = D[m, n]
// E[k, m] * D[m, n] + F[k, n] = G[k, n]
void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t num_layers, int32_t num_warmups,
int32_t num_iters, bool use_cuda_graph, ncclComm_t nccl_comm) {
const int kNcclBufAlignment = 512;
int size_a = k * n;
int size_b = m * k;
int size_c = m * n;
int size_d = m * n;
int size_e = k * m;
int size_f = k * n;
int size_g = (k * n + kNcclBufAlignment - 1) / kNcclBufAlignment * kNcclBufAlignment;
// Naming: da is in GPU (device) memory. ha is in CPU (host) memory
std::vector<cublasLtHalf> ha(size_a);
std::vector<cublasLtHalf> hb(size_b);
std::vector<cublasLtHalf> hc(size_c);
std::vector<cublasLtHalf> hd(size_d);
std::vector<cublasLtHalf> he(size_e);
std::vector<cublasLtHalf> hf(size_f);
std::vector<cublasLtHalf> hg(size_g);
// initial data on host
InitializeABCDEF(ha, size_a, hb, size_b, hc, size_c, hd, size_d, he, size_e, hf, size_f);
// allocate memory on device
void *da, *db, *dc, *dd, *de, *df, *dg;
// Create stream
cudaStream_t stream = nullptr;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CHECK_CUDA_ERROR(cudaMalloc(&da, size_a * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&db, size_b * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&dc, size_c * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&dd, size_d * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&de, size_e * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&df, size_f * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&dg, size_g * sizeof(cublasLtHalf)));
// copy matrices from host to device
CHECK_CUDA_ERROR(cudaMemcpy(da, ha.data(), sizeof(cublasLtHalf) * size_a, cudaMemcpyHostToDevice));
CHECK_CUDA_ERROR(cudaMemcpy(db, hb.data(), sizeof(cublasLtHalf) * size_b, cudaMemcpyHostToDevice));
CHECK_CUDA_ERROR(cudaMemcpy(dc, hc.data(), sizeof(cublasLtHalf) * size_c, cudaMemcpyHostToDevice));
CHECK_CUDA_ERROR(cudaMemcpy(dd, hd.data(), sizeof(cublasLtHalf) * size_d, cudaMemcpyHostToDevice));
CHECK_CUDA_ERROR(cudaMemcpy(de, he.data(), sizeof(cublasLtHalf) * size_e, cudaMemcpyHostToDevice));
CHECK_CUDA_ERROR(cudaMemcpy(df, hf.data(), sizeof(cublasLtHalf) * size_f, cudaMemcpyHostToDevice));
uint64_t workspace_size = 1024 * 1024;
void *d_workspace;
CHECK_CUDA_ERROR(cudaMalloc(&d_workspace, workspace_size));
int returnedAlgoCount = 0;
// cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead.
#if defined(__HIP_PLATFORM_AMD__)
hipblasLtHandle_t handle;
hipblasLtMatrixLayout_t matA, matB, matC, matD, matE, matF, matG;
hipblasLtMatmulDesc_t matmul1, matmul2;
hipblasLtMatmulPreference_t pref;
CHECK_CUBLASLT_ERROR(hipblasLtCreate(&handle));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, HIPBLAS_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, HIPBLAS_R_16F, m, k, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIPBLAS_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIPBLAS_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matE, HIPBLAS_R_16F, k, m, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, HIPBLAS_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, HIPBLAS_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
hipblasOperation_t trans = HIPBLAS_OP_N;
CHECK_CUBLASLT_ERROR(
hipblasLtMatmulDescSetAttribute(matmul1, HIPBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(
hipblasLtMatmulDescSetAttribute(matmul1, HIPBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(
hipblasLtMatmulDescSetAttribute(matmul2, HIPBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(
hipblasLtMatmulDescSetAttribute(matmul2, HIPBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));
// Set User Preference attributes
CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceSetAttribute(pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size, sizeof(workspace_size)));
// Get Heuristic results
hipblasLtMatmulHeuristicResult_t heuristicResult1[1] = {0};
hipblasLtMatmulHeuristicResult_t heuristicResult2[1] = {0};
// B[m, k] * A[k, n] + C[m, n] = D[m, n]
// E[k, m] * D[m, n] + F[k, n] = G[k, n]
CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul1, matB, matA, matC, matD, pref, 1,
heuristicResult1, &returnedAlgoCount));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul2, matE, matD, matF, matG, pref, 1,
heuristicResult2, &returnedAlgoCount));
#else
cublasLtHandle_t handle;
cublasLtMatrixLayout_t matA, matB, matC, matD, matE, matF, matG;
cublasLtMatmulDesc_t matmul1, matmul2;
cublasLtMatmulPreference_t pref;
CHECK_CUBLASLT_ERROR(cublasLtCreate(&handle));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matA, CUDA_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matB, CUDA_R_16F, m, k, m));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matC, CUDA_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matD, CUDA_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matE, CUDA_R_16F, k, m, k));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matF, CUDA_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matG, CUDA_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul1, CUBLAS_COMPUTE_16F, CUDA_R_32F));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul2, CUBLAS_COMPUTE_16F, CUDA_R_32F));
cublasOperation_t trans = CUBLAS_OP_N;
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul1, CUBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul1, CUBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul2, CUBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul2, CUBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));
// Set User Preference attributes
CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceCreate(&pref));
CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size, sizeof(workspace_size)));
// Get Heuristic results
cublasLtMatmulHeuristicResult_t heuristicResult1[1] = {0};
cublasLtMatmulHeuristicResult_t heuristicResult2[1] = {0};
// B[m, k] * A[k, n] + C[m, n] = D[m, n]
// E[k, m] * D[m, n] + F[k, n] = G[k, n]
CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, matmul1, matB, matA, matC, matD, pref, 1,
heuristicResult1, &returnedAlgoCount));
CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, matmul2, matE, matD, matF, matG, pref, 1,
heuristicResult2, &returnedAlgoCount));
#endif
auto model_forward = [&] {
for (int j = 0; j < num_layers; j++) {
// B[m, k] * A[k, n] + C[m, n] = D[m, n]
// E[k, m] * D[m, n] + F[k, n] = G[k, n]
// cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead.
#if defined(__HIP_PLATFORM_AMD__)
CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul1, &alpha, db, matB, da, matA, &beta, dc, matC, dd, matD,
&heuristicResult1[0].algo, d_workspace, workspace_size, stream));
CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul1, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG,
&heuristicResult2[0].algo, d_workspace, workspace_size, stream));
#else
CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, matmul1, &alpha, db, matB, da, matA, &beta, dc, matC, dd, matD,
&heuristicResult1[0].algo, d_workspace, workspace_size, stream));
CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, matmul1, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG,
&heuristicResult2[0].algo, d_workspace, workspace_size, stream));
#endif
CHECK_NCCL_ERROR(ncclAllReduce(dg, dg, size_g, ncclFloat16, ncclSum, nccl_comm, stream));
}
};
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
cudaGraph_t graph;
cudaGraphExec_t instance;
if (use_cuda_graph) {
CHECK_CUDA_ERROR(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
model_forward();
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &graph));
CHECK_CUDA_ERROR(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
}
#endif
std::chrono::steady_clock::time_point start_time, stop_time;
for (int i = 0; i < num_warmups + num_iters; ++i) {
if (i == num_warmups) {
start_time = std::chrono::steady_clock::now();
}
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
if (use_cuda_graph) {
CHECK_CUDA_ERROR(cudaGraphLaunch(instance, stream));
} else {
model_forward();
}
#else
model_forward();
#endif
CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
}
stop_time = std::chrono::steady_clock::now();
double duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count();
fprintf(stdout, "Time: %g ms in total, %g ms per iteration, %g ms per layer\n", duration, duration / num_iters,
duration / num_iters / num_layers);
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
// Destroy graph
if (use_cuda_graph) {
CHECK_CUDA_ERROR(cudaGraphExecDestroy(instance));
CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
}
#endif
// Destroy stream
CHECK_CUDA_ERROR(cudaStreamDestroy(stream));
CHECK_CUDA_ERROR(cudaFree(da));
CHECK_CUDA_ERROR(cudaFree(db));
CHECK_CUDA_ERROR(cudaFree(dc));
CHECK_CUDA_ERROR(cudaFree(dd));
CHECK_CUDA_ERROR(cudaFree(de));
CHECK_CUDA_ERROR(cudaFree(df));
CHECK_CUDA_ERROR(cudaFree(dg));
CHECK_CUDA_ERROR(cudaFree(d_workspace));
// cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead.
#if defined(__HIP_PLATFORM_AMD__)
CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceDestroy(pref));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescDestroy(matmul1));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescDestroy(matmul2));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matA));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matB));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matC));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matD));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matE));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matF));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matG));
CHECK_CUBLASLT_ERROR(hipblasLtDestroy(handle));
#else
CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceDestroy(pref));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul1));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul2));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matA));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matB));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matC));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matD));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matE));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matF));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matG));
CHECK_CUBLASLT_ERROR(cublasLtDestroy(handle));
#endif
return;
}
int main(int argc, char *argv[]) {
// Init MPI
int comm_rank, comm_size;
MPI_Init(NULL, NULL);
MPI_Comm_rank(MPI_COMM_WORLD, &comm_rank);
MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
// Init NCCL
int num_local_ranks = 0;
ncclComm_t nccl_comm;
ncclUniqueId nccl_id;
if (comm_rank == 0) {
CHECK_NCCL_ERROR(ncclGetUniqueId(&nccl_id));
}
MPI_Bcast(&nccl_id, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD);
CHECK_CUDA_ERROR(cudaGetDeviceCount(&num_local_ranks))
CHECK_CUDA_ERROR(cudaSetDevice(comm_rank % num_local_ranks));
CHECK_NCCL_ERROR(ncclCommInitRank(&nccl_comm, comm_size, nccl_id, comm_rank));
// Init parameters with default values
int64_t m = 80;
int64_t n = 128;
int64_t k = 128;
float alpha = 1;
float beta = 1;
int32_t num_layers = 50;
int32_t num_warmups = 20;
int32_t num_iters = 100;
bool use_cuda_graph = false;
if (ParseArguments(argc, argv, &m, &n, &k, &alpha, &beta, &num_layers, &num_warmups, &num_iters, &use_cuda_graph)) {
ShowUsage(argv);
return -1;
}
fprintf(stdout,
"Parameters: m=%ld, n=%ld, k=%ld, alpha=%f, beta=%f, num_layers=%d, num_warmups=%d, num_iters=%d, "
"use_cuda_graph=%d\n",
m, n, k, alpha, beta, num_layers, num_warmups, num_iters, (int)use_cuda_graph);
TestModel(m, n, k, alpha, beta, num_layers, num_warmups, num_iters, use_cuda_graph, nccl_comm);
CHECK_NCCL_ERROR(ncclCommDestroy(nccl_comm));
MPI_Finalize();
return 0;
}


@@ -33,6 +33,8 @@ superbench:
           node_num: 1
       env:
         NCCL_ASYNC_ERROR_HANDLING: '0'
+      frameworks:
+        - pytorch
     common_model_config: &common_model_config
       duration: 0
       num_warmup: 64


@@ -29,6 +29,8 @@ superbench:
           node_num: 1
       env:
         NCCL_ASYNC_ERROR_HANDLING: '0'
+      frameworks:
+        - pytorch
     common_model_config: &common_model_config
       duration: 0
       num_warmup: 64


@@ -3,12 +3,15 @@
 """Tests for distributed inference benchmark."""

+import numbers
 import unittest

 from tests.helper import decorator
+from tests.helper.testcase import BenchmarkTestCase
 import tests.benchmarks.utils as utils
 from superbench.benchmarks \
-    import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode, Precision, DistributedImpl, DistributedBackend
+    import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode, Precision, DistributedImpl, DistributedBackend, \
+    Platform
 from superbench.benchmarks.micro_benchmarks.dist_inference \
     import DistInference, ComputationKernelType, CommunicationKernelType, ActivationKernelType
 from superbench.common.utils import network
@@ -20,7 +23,9 @@ from superbench.common.utils import network
 @decorator.pytorch_test
 def test_pytorch_dist_inference_normal():
     """Test pytorch-dist-inference benchmark on distributed normal case."""
-    context = BenchmarkRegistry.create_benchmark_context('dist-inference', parameters='', framework=Framework.PYTORCH)
+    context = BenchmarkRegistry.create_benchmark_context(
+        'dist-inference', parameters='--use_pytorch', framework=Framework.PYTORCH
+    )
     world_size = 2
     assert (BenchmarkRegistry.is_benchmark_context_valid(context))
     results = utils.simulated_ddp_distributed_benchmark(context, world_size)
@@ -33,9 +38,12 @@ def test_pytorch_dist_inference_normal():
         assert (benchmark.type == BenchmarkType.MICRO)

         # Check predefined parameters of dist-inference benchmark.
+        assert (benchmark._args.use_pytorch is True)
         assert (benchmark._args.batch_size == 64)
         assert (benchmark._args.input_size == 1024)
         assert (benchmark._args.hidden_size == 1024)
+        assert (benchmark._args.alpha == 1.0)
+        assert (benchmark._args.beta == 1.0)
         assert (benchmark._args.num_layers == 1)
         assert (benchmark._args.computation_kernel == ComputationKernelType.MATMUL)
         assert (benchmark._args.communication_kernel == CommunicationKernelType.ALLREDUCE)
@@ -45,6 +53,7 @@ def test_pytorch_dist_inference_normal():
         assert (benchmark._args.num_steps == 10000)
         assert (benchmark._args.distributed_impl == DistributedImpl.DDP)
         assert (benchmark._args.distributed_backend == DistributedBackend.NCCL)
+        assert (benchmark._args.use_cuda_graph is False)

         # Check results and metrics.
         assert (benchmark.run_count == 1)
@@ -52,14 +61,16 @@ def test_pytorch_dist_inference_normal():
         # step_times
         assert (len(benchmark.raw_data) == 1)
         # return code + (avg, 50th, 90th, 95th, 99th, 99.9th)
-        assert (len(benchmark.result) == 7)
+        assert (7 == len(benchmark.result))


 @decorator.cuda_test
 @decorator.pytorch_test
 def test_pytorch_dist_inference_fake_distributed():
     """Test pytorch-dist-inference benchmark on single gpu."""
-    context = BenchmarkRegistry.create_benchmark_context('dist-inference', parameters='', framework=Framework.PYTORCH)
+    context = BenchmarkRegistry.create_benchmark_context(
+        'dist-inference', parameters='--use_pytorch', framework=Framework.PYTORCH
+    )
     port = network.get_free_port()
     assert (port)
     utils.setup_simulated_ddp_distributed_env(1, 0, port)
@@ -72,9 +83,12 @@ def test_pytorch_dist_inference_fake_distributed():
     assert (benchmark.type == BenchmarkType.MICRO)

     # Check predefined parameters of dist-inference benchmark.
+    assert (benchmark._args.use_pytorch is True)
     assert (benchmark._args.batch_size == 64)
     assert (benchmark._args.input_size == 1024)
     assert (benchmark._args.hidden_size == 1024)
+    assert (benchmark._args.alpha == 1.0)
+    assert (benchmark._args.beta == 1.0)
     assert (benchmark._args.num_layers == 1)
     assert (benchmark._args.computation_kernel == ComputationKernelType.MATMUL)
     assert (benchmark._args.communication_kernel == CommunicationKernelType.ALLREDUCE)
@@ -84,6 +98,7 @@ def test_pytorch_dist_inference_fake_distributed():
     assert (benchmark._args.num_steps == 10000)
     assert (benchmark._args.distributed_impl == DistributedImpl.DDP)
     assert (benchmark._args.distributed_backend == DistributedBackend.NCCL)
+    assert (benchmark._args.use_cuda_graph is False)

     # Check results and metrics.
     assert (benchmark.run_count == 1)
@@ -94,3 +109,127 @@ def test_pytorch_dist_inference_fake_distributed():
     assert (len(benchmark.result) == 7)

     utils.clean_simulated_ddp_distributed_env()
class DistInferenceCppImplTest(BenchmarkTestCase, unittest.TestCase):
"""Test class for pytorch-dist-inference benchmark."""
@classmethod
def setUpClass(cls):
"""Hook method for setting up class fixture before running tests in the class."""
super().setUpClass()
cls.createMockEnvs(cls)
cls.createMockFiles(cls, ['bin/dist_inference'])
def _test_dist_inference_command_generation(self, platform):
"""Test pytorch-dist-inference cpp impl benchmark command generation."""
benchmark_name = 'pytorch-dist-inference'
(benchmark_class,
predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, platform)
assert (benchmark_class)
batch_size = 1
input_size = 2
hidden_size = 3
alpha = 4.0
beta = 5.0
num_layers = 6
num_warmup = 7
num_steps = 8
wrapper_params_format_str = \
'--batch_size %d --input_size %d --hidden_size %d ' \
'--alpha %g --beta %g --num_layers %d --num_warmup %d --num_steps %d --use_cuda_graph'
parameters = wrapper_params_format_str % (
batch_size, input_size, hidden_size, alpha, beta, num_layers, num_warmup, num_steps
)
benchmark = benchmark_class(benchmark_name, parameters=parameters)
# Check basic information
assert (benchmark)
ret = benchmark._preprocess()
assert (ret is True)
assert (benchmark.return_code == ReturnCode.SUCCESS)
assert (benchmark.name == benchmark_name)
assert (benchmark.type == BenchmarkType.MICRO)
# Check parameters specified in BenchmarkContext.
assert (benchmark._args.use_pytorch is False)
assert (benchmark._args.batch_size == batch_size)
assert (benchmark._args.input_size == input_size)
assert (benchmark._args.hidden_size == hidden_size)
assert (benchmark._args.alpha == alpha)
assert (benchmark._args.beta == beta)
assert (benchmark._args.num_layers == num_layers)
assert (benchmark._args.num_warmup == num_warmup)
assert (benchmark._args.num_steps == num_steps)
assert (benchmark._args.use_cuda_graph is True)
# Check command
assert (1 == len(benchmark._commands))
for cmd in benchmark._commands:
m, n, k = hidden_size, batch_size, input_size
bench_params_format_str = \
'%s -m %d -n %d -k %d --alpha %g --beta %g ' + \
'--num_layers %d --num_warmups %d --num_iters %d --use_cuda_graph'
assert (
cmd == (
bench_params_format_str %
(benchmark._DistInference__bin_path, m, n, k, alpha, beta, num_layers, num_warmup, num_steps)
)
)
@decorator.cuda_test
def test_dist_inference_command_generation_cuda(self):
"""Test pytorch-dist-inference cpp impl benchmark command generation, CUDA case."""
self._test_dist_inference_command_generation(Platform.CUDA)
@decorator.rocm_test
def test_dist_inference_command_generation_rocm(self):
"""Test pytorch-dist-inference cpp impl benchmark command generation, ROCm case."""
self._test_dist_inference_command_generation(Platform.ROCM)
@decorator.load_data('tests/data/dist_inference.log')
def _test_dist_inference_result_parsing(self, platform, test_raw_output):
"""Test pytorch-dist-inference cpp impl benchmark result parsing."""
benchmark_name = 'pytorch-dist-inference'
(benchmark_class,
predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, platform)
assert (benchmark_class)
benchmark = benchmark_class(benchmark_name, parameters='')
assert (benchmark)
ret = benchmark._preprocess()
assert (ret is True)
assert (benchmark.return_code == ReturnCode.SUCCESS)
assert (benchmark.name == 'pytorch-dist-inference')
assert (benchmark.type == BenchmarkType.MICRO)
# Positive case - valid raw output.
assert (benchmark._process_raw_result(0, test_raw_output))
assert (benchmark.return_code == ReturnCode.SUCCESS)
# step_times
assert (len(benchmark.raw_data) == 2)
# return code + (avg, 50th, 90th, 95th, 99th, 99.9th)
test_latency = float(test_raw_output.splitlines()[-1].split(' ms per iteration')[0].split()[-1])
assert (7 == len(benchmark.result))
for output_key in benchmark.result:
if output_key == 'return_code':
assert (benchmark.result[output_key] == [0])
else:
assert (output_key.startswith('step_times'))
assert (len(benchmark.result[output_key]) == 1)
assert (isinstance(benchmark.result[output_key][0], numbers.Number))
assert (test_latency == benchmark.result[output_key][0])
# Negative case - invalid raw output.
assert (benchmark._process_raw_result(1, 'Invalid raw output') is False)
assert (benchmark.return_code == ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
@decorator.cuda_test
def test_dist_inference_result_parsing_cuda(self):
"""Test pytorch-dist-inference cpp impl benchmark result parsing, CUDA case."""
self._test_dist_inference_result_parsing(Platform.CUDA)
@decorator.rocm_test
def test_dist_inference_result_parsing_rocm(self):
"""Test pytorch-dist-inference cpp impl benchmark result parsing, ROCm case."""
self._test_dist_inference_result_parsing(Platform.ROCM)


@@ -0,0 +1,2 @@
Parameters: m=80, n=128, k=128, alpha=1.000000, beta=1.000000, num_layers=50, num_warmups=20, num_iters=100, use_cuda_graph=0
Time: 173 ms in total, 1.73 ms per iteration, 0.0346 ms per layer