Benchmarks: Microbenchmark - Add distributed inference benchmark cpp implementation (#586)
**Description** Add distributed inference benchmark cpp implementation.
This commit is contained in:
Parent: 1f5031bd74
Commit: 719a427fe7
@@ -418,7 +418,7 @@ Test the performance of large scale matmul operation with multiple GPUs:

#### Introduction

-Test the performance of distributed model inference.
+Test the performance of distributed model inference. Support both PyTorch implementation and cpp implementation.

#### Metrics
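For reference, a minimal sketch of selecting either implementation through the benchmark registry (mirroring the unit tests added later in this diff; `--use_pytorch` picks the PyTorch path, otherwise the wrapper invokes the new cpp binary; the other parameter values here are illustrative):

```python
# Minimal sketch, mirroring the tests later in this diff.
from superbench.benchmarks import BenchmarkRegistry, Framework

# PyTorch implementation path.
context = BenchmarkRegistry.create_benchmark_context(
    'dist-inference', parameters='--use_pytorch', framework=Framework.PYTORCH
)

# cpp implementation path (the default when --use_pytorch is omitted).
context = BenchmarkRegistry.create_benchmark_context(
    'dist-inference', parameters='--num_layers 50 --num_steps 100', framework=Framework.PYTORCH
)
```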
@@ -12,7 +12,7 @@ import torch.distributed as dist

from superbench.common.utils import logger
from superbench.benchmarks import DistributedImpl, DistributedBackend, BenchmarkRegistry, ReturnCode, Precision
-from superbench.benchmarks.micro_benchmarks import MicroBenchmark
+from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke
from superbench.benchmarks.context import Enum
from superbench.benchmarks.reducer import ReduceType
@@ -168,7 +168,7 @@ class DistInferenceModel(torch.nn.Module):
        return activation_out


-class DistInference(MicroBenchmark):
+class DistInference(MicroBenchmarkWithInvoke):
    """The base class of micro-benchmarks."""
    def __init__(self, name, parameters=''):
        """Constructor.
@@ -182,7 +182,9 @@ class DistInference(MicroBenchmark):
        self.__local_rank = 0
        torch.backends.cudnn.benchmark = True
        self.__device = None
        self.__cuda_available = False

+        # For cpp impl path
+        self._bin_name = 'dist_inference'

    def __timer(self):
        """Returns the current time which ensures all previous CUDA events have been finished.
@@ -193,14 +195,19 @@ class DistInference(MicroBenchmark):
        Return:
            Current time in second.
        """
-        if self.__cuda_available:
-            torch.cuda.synchronize()
+        torch.cuda.synchronize()
        return time.time()

    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

+        self._parser.add_argument(
+            '--use_pytorch',
+            action='store_true',
+            required=False,
+            help='Whether to use pytorch implementation. If not, cpp implementation will be used.',
+        )
        self._parser.add_argument(
            '--batch_size',
            type=int,
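The `__timer` change drops the CPU fallback: it now always calls `torch.cuda.synchronize()` before reading the wall clock, since kernel launches are asynchronous and the host timestamp would otherwise not reflect finished device work. A standalone sketch of the same timing pattern (hypothetical helper, not part of this PR):

```python
import time

import torch

def timed(fn, iters=10):
    """Time a GPU workload by synchronizing before each timestamp."""
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - start) / iters
```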
@@ -222,6 +229,20 @@ class DistInference(MicroBenchmark):
            required=False,
            help='Hidden size.',
        )
+        self._parser.add_argument(
+            '--alpha',
+            type=float,
+            default=1.0,
+            required=False,
+            help='Coefficient alpha in D = alpha*(A*B) + beta*(C).',
+        )
+        self._parser.add_argument(
+            '--beta',
+            type=float,
+            default=1.0,
+            required=False,
+            help='Coefficient beta in D = alpha*(A*B) + beta*(C).',
+        )
        self._parser.add_argument(
            '--num_layers',
            type=int,
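The new `alpha`/`beta` flags are the standard GEMM epilogue coefficients in `D = alpha*(A*B) + beta*(C)`, which the cpp implementation applies to both chained GEMMs of each layer (see the `B[m, k] * A[k, n] + C[m, n] = D[m, n]` comments in `dist_inference.cu` below). A NumPy sketch of one layer's math, with illustrative shapes:

```python
import numpy as np

# Illustrative shapes only: in the benchmark, m=hidden_size,
# n=batch_size, k=input_size.
m, n, k, alpha, beta = 4, 2, 3, 1.0, 1.0
A, B, C = np.ones((k, n)), np.ones((m, k)), np.ones((m, n))
E, F = np.ones((k, m)), np.ones((k, n))
D = alpha * (B @ A) + beta * C    # B[m, k] * A[k, n] + C[m, n] = D[m, n]
G = alpha * (E @ D) + beta * F    # E[k, m] * D[m, n] + F[k, n] = G[k, n]
print(D.shape, G.shape)           # (4, 2) (3, 2)
```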
@@ -285,6 +306,12 @@ class DistInference(MicroBenchmark):
            required=False,
            help='Distributed backends. E.g. {}.'.format(' '.join(DistributedBackend.get_values())),
        )
+        self._parser.add_argument(
+            '--use_cuda_graph',
+            action='store_true',
+            required=False,
+            help='Whether to launch kernels in CUDA graph mode.',
+        )

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.
@@ -295,32 +322,41 @@ class DistInference(MicroBenchmark):
        if not super()._preprocess():
            return False

-        if self._args.distributed_impl != DistributedImpl.DDP:
-            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
-            logger.error(
-                'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
-                    self._name, self._args.distributed_impl
-                )
-            )
-            return False
-
-        try:
-            torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
-            self.__world_size = int(os.environ['WORLD_SIZE'])
-            self.__local_rank = int(os.environ['LOCAL_RANK'])
-        except BaseException as e:
-            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
-            torch.distributed.destroy_process_group()
-            logger.error('Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e)))
-            return False
-
-        if torch.cuda.is_available():
-            torch.cuda.set_device(self.__local_rank)
-            self.__device = torch.device('cuda:{}'.format(self.__local_rank))
-            self.__cuda_available = True
-        else:
-            self.__device = torch.device('cpu:{}'.format(self.__local_rank))
-            self.__cuda_available = False
+        if self._args.use_pytorch:
+            # Initialize PyTorch if pytorch impl path
+            if self._args.distributed_impl != DistributedImpl.DDP:
+                return self._set_error_code_and_print_error_msg(
+                    ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE,
+                    'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
+                        self._name, self._args.distributed_impl
+                    )
+                )
+            try:
+                torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
+                self.__world_size = int(os.environ['WORLD_SIZE'])
+                self.__local_rank = int(os.environ['LOCAL_RANK'])
+                assert (torch.cuda.is_available())
+            except BaseException as e:
+                torch.distributed.destroy_process_group()
+                return self._set_error_code_and_print_error_msg(
+                    ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE,
+                    'Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e))
+                )
+            torch.cuda.set_device(self.__local_rank)
+            self.__device = torch.device('cuda:{}'.format(self.__local_rank))
+            self.__cuda_available = True
+        else:
+            # Assemble commands if cpp impl path
+            self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)
+
+            args = '-m %d -n %d -k %d' % (self._args.hidden_size, self._args.batch_size, self._args.input_size)
+            args += ' --alpha %g --beta %g' % (self._args.alpha, self._args.beta)
+            args += ' --num_layers %d --num_warmups %d --num_iters %d' % \
+                (self._args.num_layers, self._args.num_warmup, self._args.num_steps)
+            if self._args.use_cuda_graph:
+                args += ' --use_cuda_graph'
+            self._commands = ['%s %s' % (self.__bin_path, args)]

        return True
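Note the argument mapping when the cpp command is assembled: the wrapper's `--hidden_size`, `--batch_size` and `--input_size` become the binary's `-m`, `-n` and `-k`, and `--num_warmup`/`--num_steps` are renamed `--num_warmups`/`--num_iters`. A sketch reproducing the assembled command line with the values used by the command-generation test below (the bin path is illustrative):

```python
# Reproduces the string built in _preprocess() above.
bin_path = '/opt/superbench/bin/dist_inference'
hidden_size, batch_size, input_size = 3, 1, 2
alpha, beta, num_layers, num_warmup, num_steps = 4.0, 5.0, 6, 7, 8
args = '-m %d -n %d -k %d' % (hidden_size, batch_size, input_size)
args += ' --alpha %g --beta %g' % (alpha, beta)
args += ' --num_layers %d --num_warmups %d --num_iters %d' % (num_layers, num_warmup, num_steps)
print('%s %s' % (bin_path, args))
# /opt/superbench/bin/dist_inference -m 3 -n 1 -k 2 --alpha 4 --beta 5 --num_layers 6 --num_warmups 7 --num_iters 8
```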
@@ -347,8 +383,7 @@ class DistInference(MicroBenchmark):
            self.__device
        )
        model = model.to(dtype=getattr(torch, precision.value))
-        if self.__cuda_available:
-            model = model.cuda()
+        model = model.cuda()
        return model

    def _run_model(self, model, batch_size, input_size, precision, device, num_warmup, num_steps):
@@ -401,38 +436,78 @@ class DistInference(MicroBenchmark):
        Return:
            True if _benchmark succeeds.
        """
-        batch_size = self._args.batch_size
-        input_size = self._args.input_size
-        hidden_size = self._args.hidden_size
-        num_layers = self._args.num_layers
-        computation = self._args.computation_kernel
-        communication = self._args.communication_kernel
-        activation = self._args.activation_kernel
-        precision = self._args.precision
-        num_warmup = self._args.num_warmup
-        num_steps = self._args.num_steps
-
-        if self.__local_rank == 0:
-            logger.info(
-                'Distributed Inference - using {} GPUs: '
-                'batch_size={}, input_size={}, hidden_size={}, num_layers={}, '
-                'computation_kernel={}, communication_kernel={}, activation_kernel={}, precision={}, '
-                'num_warmup={} num_steps={}'.format(
-                    self.__world_size, batch_size, input_size, hidden_size, num_layers, computation, communication,
-                    activation, precision, num_warmup, num_steps
-                )
-            )
-
-        # Prepare model
-        model = self._prepare_model(
-            input_size, hidden_size, num_layers, computation, communication, activation, precision,
-            self.__world_size
-        )
-
-        # Run model
-        step_times = self._run_model(model, batch_size, input_size, precision, self.__device, num_warmup, num_steps)
-
-        # Process data and return
-        return self._process_data(step_times)
+        if self._args.use_pytorch:
+            # Execute PyTorch model if pytorch impl path
+            batch_size = self._args.batch_size
+            input_size = self._args.input_size
+            hidden_size = self._args.hidden_size
+            num_layers = self._args.num_layers
+            computation = self._args.computation_kernel
+            communication = self._args.communication_kernel
+            activation = self._args.activation_kernel
+            precision = self._args.precision
+            num_warmup = self._args.num_warmup
+            num_steps = self._args.num_steps
+
+            if self.__local_rank == 0:
+                logger.info(
+                    'Distributed Inference - using {} GPUs: '
+                    'batch_size={}, input_size={}, hidden_size={}, num_layers={}, '
+                    'computation_kernel={}, communication_kernel={}, activation_kernel={}, precision={}, '
+                    'num_warmup={} num_steps={}'.format(
+                        self.__world_size, batch_size, input_size, hidden_size, num_layers, computation, communication,
+                        activation, precision, num_warmup, num_steps
+                    )
+                )
+
+            # Prepare model
+            model = self._prepare_model(
+                input_size, hidden_size, num_layers, computation, communication, activation, precision,
+                self.__world_size
+            )
+
+            # Run model
+            step_times = self._run_model(model, batch_size, input_size, precision, self.__device, num_warmup, num_steps)
+
+            # Process data and return
+            return self._process_data(step_times)
+        else:
+            # Execute commands if cpp impl path
+            if not super()._benchmark():
+                return False
+            return True
+
+    def _process_raw_result(self, cmd_idx, raw_output):
+        """Function to parse raw results and save the summarized results.
+
+        self._result.add_raw_data() and self._result.add_result() need to be called to save the results.
+
+        Args:
+            cmd_idx (int): the index of command corresponding with the raw_output.
+            raw_output (str): raw output string of the micro-benchmark.
+
+        Return:
+            True if the raw output string is valid and result can be extracted.
+        """
+        self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
+
+        try:
+            output_lines = [x.strip() for x in raw_output.strip().splitlines()]
+            step_time = None
+            for output_line in output_lines:
+                if ' ms per iteration' in output_line:
+                    step_time = float(output_line.split(' ms per iteration')[0].split()[-1])
+                    break
+            return self._process_numeric_result(
+                'step_times', [step_time], reduce_type=ReduceType.MAX, cal_percentile=True
+            )
+        except BaseException as e:
+            return self._set_error_code_and_print_error_msg(
+                ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE,
+                'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format(
+                    self._curr_run_index, self._name, raw_output, str(e)
+                )
+            )

    def _postprocess(self):
        """Postprocess/cleanup operations after the benchmarking.
@@ -443,14 +518,26 @@ class DistInference(MicroBenchmark):
        if not super()._postprocess():
            return False

-        try:
-            torch.distributed.destroy_process_group()
-        except BaseException as e:
-            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE)
-            logger.error('Post process failed - benchmark: {}, message: {}.'.format(self._name, str(e)))
-            return False
+        if self._args.use_pytorch:
+            try:
+                torch.distributed.destroy_process_group()
+            except BaseException as e:
+                return self._set_error_code_and_print_error_msg(
+                    ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE,
+                    'Post process failed - benchmark: {}, message: {}.'.format(self._name, str(e))
+                )

        return True

+    def _set_error_code_and_print_error_msg(self, error_code, error_msg):
+        """Set error code and print error log upon error.
+
+        Return:
+            False, representing error.
+        """
+        self._result.set_return_code(error_code)
+        logger.error(error_msg)
+        return False
+

BenchmarkRegistry.register_benchmark('pytorch-dist-inference', DistInference, parameters='')
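The cpp path reports its timing on stdout and `_process_raw_result` extracts the per-iteration latency. A sketch of that extraction applied to the sample line from `tests/data/dist_inference.log` (added at the end of this diff):

```python
# Same parsing logic as _process_raw_result(), inlined for illustration.
raw_output = 'Time: 173 ms in total, 1.73 ms per iteration, 0.0346 ms per layer'
step_time = None
for line in raw_output.strip().splitlines():
    if ' ms per iteration' in line:
        # Take the number immediately preceding ' ms per iteration'.
        step_time = float(line.split(' ms per iteration')[0].split()[-1])
        break
assert step_time == 1.73
```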
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

cmake_minimum_required(VERSION 3.18)

project(dist_inference LANGUAGES CXX)

find_package(MPI REQUIRED)
include_directories(SYSTEM ${MPI_INCLUDE_PATH})

find_package(CUDAToolkit QUIET)

# Cuda environment
if(CUDAToolkit_FOUND)
    message(STATUS "Found CUDA: " ${CUDAToolkit_VERSION})

    include(../cuda_common.cmake)
    add_executable(dist_inference dist_inference.cu)
    set_property(TARGET dist_inference PROPERTY CUDA_ARCHITECTURES ${NVCC_ARCHS_SUPPORTED})
    target_link_libraries(dist_inference MPI::MPI_CXX nccl cublasLt)
else()
    # ROCm environment
    include(../rocm_common.cmake)
    find_package(hip QUIET)
    if(hip_FOUND)
        message(STATUS "Found ROCm: " ${HIP_VERSION})

        # Convert cuda code to hip code in cpp
        execute_process(COMMAND hipify-perl -print-stats -o dist_inference.cpp dist_inference.cu WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/)

        # link hip device lib
        add_executable(dist_inference dist_inference.cpp)
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DROCM_USE_FLOAT16=1")
        target_link_libraries(dist_inference MPI::MPI_CXX rccl hipblaslt hip::device)
    else()
        message(FATAL_ERROR "No CUDA or ROCm environment found.")
    endif()
endif()

install(TARGETS dist_inference RUNTIME DESTINATION bin)
@@ -0,0 +1,455 @@
/*******************************************************************************
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT License.
 *
 * MIT License
 *
 * Copyright (C) 2022-2023 Advanced Micro Devices, Inc.
 * Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/

#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>
#include <functional>
#include <iostream>
#include <limits>
#include <mpi.h>
#include <string>
#include <unistd.h>
#include <vector>

#if defined(__HIP_PLATFORM_AMD__)
#include <hipblaslt/hipblaslt.h>
#include <rccl/rccl.h>
using cublasLtHalf = hipblasLtHalf;
#else
#include <cublasLt.h>
#include <cuda_fp16.h>
#include <nccl.h>
using cublasLtHalf = half;
#endif

#ifndef CHECK_CUDA_ERROR
#define CHECK_CUDA_ERROR(error) \
    if (error != cudaSuccess) { \
        fprintf(stderr, "Cuda error: '%s'(%d) at %s:%d\n", cudaGetErrorString(error), error, __FILE__, __LINE__); \
        exit(-1); \
    }
#endif

#ifndef CHECK_CUBLASLT_ERROR
#define CHECK_CUBLASLT_ERROR(error) \
    if (error != CUBLAS_STATUS_SUCCESS) { \
        fprintf(stderr, "cuBLASLt error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \
        fprintf(stderr, "\n"); \
        exit(-1); \
    }
#endif

#ifndef CHECK_NCCL_ERROR
#define CHECK_NCCL_ERROR(error) \
    if (error != ncclSuccess) { \
        fprintf(stderr, "NCCL error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \
        fprintf(stderr, "\n"); \
        exit(-1); \
    }
#endif

static void ShowUsage(char *argv[]) {
    std::cerr << "Usage: " << argv[0] << " <options>\n"
              << "options:\n"
              << "\t-h, --help\t\t\t\tShow this help message\n"
              << "\t-m \t\t\tm\t\tGEMM_STRIDED argument m\n"
              << "\t-n \t\t\tn\t\tGEMM_STRIDED argument n\n"
              << "\t-k \t\t\tk \t\tGEMM_STRIDED argument k\n"
              << "\t--alpha \t\talpha \t\tGEMM_STRIDED argument alpha\n"
              << "\t--beta \t\t\tbeta \t\tGEMM_STRIDED argument beta\n"
              << "\t--num_layers \t\t\tnum_layers \t\tNumber of layers in the model\n"
              << "\t--num_warmups \t\t\tnum_warmups \t\tNumber of warmup runs\n"
              << "\t--num_iters \t\t\tnum_iters \t\tNumber of test runs\n"
              << "\t--use_cuda_graph \t\t\tuse_cuda_graph \t\tWhether to launch kernels in CUDA graph mode\n"
              << std::endl;
}

static int ParseArguments(int argc, char *argv[], int64_t *m, int64_t *n, int64_t *k, float *alpha, float *beta,
                          int32_t *num_layers, int32_t *num_warmups, int32_t *num_iters, bool *use_cuda_graph) {
    if (argc >= 2) {
        for (int i = 1; i < argc; ++i) {
            std::string arg = argv[i];

            if ((arg.at(0) == '-') || ((arg.at(0) == '-') && (arg.at(1) == '-'))) {
                if ((arg == "-h") || (arg == "--help")) {
                    return -1;
                } else if ((arg == "-m") && (i + 1 < argc)) {
                    *m = atoi(argv[++i]);
                } else if ((arg == "-n") && (i + 1 < argc)) {
                    *n = atoi(argv[++i]);
                } else if ((arg == "-k") && (i + 1 < argc)) {
                    *k = atoi(argv[++i]);
                } else if ((arg == "--alpha") && (i + 1 < argc)) {
                    *alpha = atof(argv[++i]);
                } else if ((arg == "--beta") && (i + 1 < argc)) {
                    *beta = atof(argv[++i]);
                } else if ((arg == "--num_layers") && (i + 1 < argc)) {
                    *num_layers = atoi(argv[++i]);
                } else if ((arg == "--num_warmups") && (i + 1 < argc)) {
                    *num_warmups = atoi(argv[++i]);
                } else if ((arg == "--num_iters") && (i + 1 < argc)) {
                    *num_iters = atoi(argv[++i]);
                } else if (arg == "--use_cuda_graph") {
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
                    *use_cuda_graph = true;
#else
                    *use_cuda_graph = false;
                    std::cerr << "error with " << arg << std::endl;
                    std::cerr << "not supported by current environment" << std::endl << std::endl;
                    return -1;
#endif
                } else {
                    std::cerr << "error with " << arg << std::endl;
                    std::cerr << "do not recognize option" << std::endl << std::endl;
                    return -1;
                }
            } else {
                std::cerr << "error with " << arg << std::endl;
                std::cerr << "option must start with - or --" << std::endl << std::endl;
                return -1;
            }
        }
    }
    return 0;
}

void InitializeABCDEF(std::vector<cublasLtHalf> &ha, int64_t size_a, std::vector<cublasLtHalf> &hb, int64_t size_b,
                      std::vector<cublasLtHalf> &hc, int64_t size_c, std::vector<cublasLtHalf> &hd, int64_t size_d,
                      std::vector<cublasLtHalf> &he, int64_t size_e, std::vector<cublasLtHalf> &hf, int64_t size_f) {
    srand(1);
    for (int i = 0; i < size_a; ++i) {
        ha[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
    }
    for (int i = 0; i < size_b; ++i) {
        hb[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
    }
    for (int i = 0; i < size_c; ++i) {
        hc[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
    }
    for (int i = 0; i < size_d; ++i) {
        hd[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
    }
    for (int i = 0; i < size_e; ++i) {
        he[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
    }
    for (int i = 0; i < size_f; ++i) {
        hf[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
    }
}

// B[m, k] * A[k, n] + C[m, n] = D[m, n]
// E[k, m] * D[m, n] + F[k, n] = G[k, n]
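// Each layer of the benchmarked model runs these two GEMMs back to back and
// then all-reduces the k x n output G across all ranks (see model_forward below).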
void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t num_layers, int32_t num_warmups,
               int32_t num_iters, bool use_cuda_graph, ncclComm_t nccl_comm) {
    const int kNcclBufAlignment = 512;

    int size_a = k * n;
    int size_b = m * k;
    int size_c = m * n;
    int size_d = m * n;
    int size_e = k * m;
    int size_f = k * n;
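    // Round the all-reduce buffer up to a multiple of kNcclBufAlignment elements.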
    int size_g = (k * n + kNcclBufAlignment - 1) / kNcclBufAlignment * kNcclBufAlignment;

    // Naming: da is in GPU (device) memory. ha is in CPU (host) memory
    std::vector<cublasLtHalf> ha(size_a);
    std::vector<cublasLtHalf> hb(size_b);
    std::vector<cublasLtHalf> hc(size_c);
    std::vector<cublasLtHalf> hd(size_d);
    std::vector<cublasLtHalf> he(size_e);
    std::vector<cublasLtHalf> hf(size_f);
    std::vector<cublasLtHalf> hg(size_g);

    // initial data on host
    InitializeABCDEF(ha, size_a, hb, size_b, hc, size_c, hd, size_d, he, size_e, hf, size_f);

    // allocate memory on device
    void *da, *db, *dc, *dd, *de, *df, *dg;

    // Create stream
    cudaStream_t stream = nullptr;
    CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));

    CHECK_CUDA_ERROR(cudaMalloc(&da, size_a * sizeof(cublasLtHalf)));
    CHECK_CUDA_ERROR(cudaMalloc(&db, size_b * sizeof(cublasLtHalf)));
    CHECK_CUDA_ERROR(cudaMalloc(&dc, size_c * sizeof(cublasLtHalf)));
    CHECK_CUDA_ERROR(cudaMalloc(&dd, size_d * sizeof(cublasLtHalf)));
    CHECK_CUDA_ERROR(cudaMalloc(&de, size_e * sizeof(cublasLtHalf)));
    CHECK_CUDA_ERROR(cudaMalloc(&df, size_f * sizeof(cublasLtHalf)));
    CHECK_CUDA_ERROR(cudaMalloc(&dg, size_g * sizeof(cublasLtHalf)));
    // copy matrices from host to device
    CHECK_CUDA_ERROR(cudaMemcpy(da, ha.data(), sizeof(cublasLtHalf) * size_a, cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(db, hb.data(), sizeof(cublasLtHalf) * size_b, cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(dc, hc.data(), sizeof(cublasLtHalf) * size_c, cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(dd, hd.data(), sizeof(cublasLtHalf) * size_d, cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(de, he.data(), sizeof(cublasLtHalf) * size_e, cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(df, hf.data(), sizeof(cublasLtHalf) * size_f, cudaMemcpyHostToDevice));

    uint64_t workspace_size = 1024 * 1024;
    void *d_workspace;
    CHECK_CUDA_ERROR(cudaMalloc(&d_workspace, workspace_size));
    int returnedAlgoCount = 0;

// cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead.
#if defined(__HIP_PLATFORM_AMD__)
    hipblasLtHandle_t handle;
    hipblasLtMatrixLayout_t matA, matB, matC, matD, matE, matF, matG;
    hipblasLtMatmulDesc_t matmul1, matmul2;
    hipblasLtMatmulPreference_t pref;

    CHECK_CUBLASLT_ERROR(hipblasLtCreate(&handle));

    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, HIPBLAS_R_16F, k, n, k));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, HIPBLAS_R_16F, m, k, m));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIPBLAS_R_16F, m, n, m));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIPBLAS_R_16F, m, n, m));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matE, HIPBLAS_R_16F, k, m, k));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, HIPBLAS_R_16F, k, n, k));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, HIPBLAS_R_16F, k, n, k));

    CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
    CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));

    hipblasOperation_t trans = HIPBLAS_OP_N;
    CHECK_CUBLASLT_ERROR(
        hipblasLtMatmulDescSetAttribute(matmul1, HIPBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
    CHECK_CUBLASLT_ERROR(
        hipblasLtMatmulDescSetAttribute(matmul1, HIPBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));
    CHECK_CUBLASLT_ERROR(
        hipblasLtMatmulDescSetAttribute(matmul2, HIPBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
    CHECK_CUBLASLT_ERROR(
        hipblasLtMatmulDescSetAttribute(matmul2, HIPBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));

    // Set User Preference attributes
    CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref));
    CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceSetAttribute(pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                                               &workspace_size, sizeof(workspace_size)));

    // Get Heuristic results
    hipblasLtMatmulHeuristicResult_t heuristicResult1[1] = {0};
    hipblasLtMatmulHeuristicResult_t heuristicResult2[1] = {0};
    // B[m, k] * A[k, n] + C[m, n] = D[m, n]
    // E[k, m] * D[m, n] + F[k, n] = G[k, n]
    CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul1, matB, matA, matC, matD, pref, 1,
                                                         heuristicResult1, &returnedAlgoCount));
    CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul2, matE, matD, matF, matG, pref, 1,
                                                         heuristicResult2, &returnedAlgoCount));
#else
    cublasLtHandle_t handle;
    cublasLtMatrixLayout_t matA, matB, matC, matD, matE, matF, matG;
    cublasLtMatmulDesc_t matmul1, matmul2;
    cublasLtMatmulPreference_t pref;
    CHECK_CUBLASLT_ERROR(cublasLtCreate(&handle));

    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matA, CUDA_R_16F, k, n, k));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matB, CUDA_R_16F, m, k, m));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matC, CUDA_R_16F, m, n, m));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matD, CUDA_R_16F, m, n, m));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matE, CUDA_R_16F, k, m, k));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matF, CUDA_R_16F, k, n, k));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matG, CUDA_R_16F, k, n, k));

    CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul1, CUBLAS_COMPUTE_16F, CUDA_R_32F));
    CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul2, CUBLAS_COMPUTE_16F, CUDA_R_32F));

    cublasOperation_t trans = CUBLAS_OP_N;
    CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul1, CUBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
    CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul1, CUBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));
    CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul2, CUBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
    CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul2, CUBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));

    // Set User Preference attributes
    CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceCreate(&pref));
    CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                                              &workspace_size, sizeof(workspace_size)));

    // Get Heuristic results
    cublasLtMatmulHeuristicResult_t heuristicResult1[1] = {0};
    cublasLtMatmulHeuristicResult_t heuristicResult2[1] = {0};
    // B[m, k] * A[k, n] + C[m, n] = D[m, n]
    // E[k, m] * D[m, n] + F[k, n] = G[k, n]
    CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, matmul1, matB, matA, matC, matD, pref, 1,
                                                        heuristicResult1, &returnedAlgoCount));
    CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, matmul2, matE, matD, matF, matG, pref, 1,
                                                        heuristicResult2, &returnedAlgoCount));
#endif

    auto model_forward = [&] {
        for (int j = 0; j < num_layers; j++) {
            // B[m, k] * A[k, n] + C[m, n] = D[m, n]
            // E[k, m] * D[m, n] + F[k, n] = G[k, n]
// cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead.
#if defined(__HIP_PLATFORM_AMD__)
            CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul1, &alpha, db, matB, da, matA, &beta, dc, matC, dd, matD,
                                                 &heuristicResult1[0].algo, d_workspace, workspace_size, stream));
            CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul2, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG,
                                                 &heuristicResult2[0].algo, d_workspace, workspace_size, stream));
#else
            CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, matmul1, &alpha, db, matB, da, matA, &beta, dc, matC, dd, matD,
                                                &heuristicResult1[0].algo, d_workspace, workspace_size, stream));
            CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, matmul2, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG,
                                                &heuristicResult2[0].algo, d_workspace, workspace_size, stream));
#endif
            CHECK_NCCL_ERROR(ncclAllReduce(dg, dg, size_g, ncclFloat16, ncclSum, nccl_comm, stream));
        }
    };
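
// Capture the whole multi-layer forward pass into one CUDA graph so the timed
// loop below can replay it with a single launch per iteration; guarded because
// graph-captured NCCL needs NCCL >= 2.9 and CUDA >= 11.3 (or a recent ROCm).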
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
    cudaGraph_t graph;
    cudaGraphExec_t instance;
    if (use_cuda_graph) {
        CHECK_CUDA_ERROR(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
        model_forward();
        CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &graph));
        CHECK_CUDA_ERROR(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
    }
#endif

    std::chrono::steady_clock::time_point start_time, stop_time;
    for (int i = 0; i < num_warmups + num_iters; ++i) {
        if (i == num_warmups) {
            start_time = std::chrono::steady_clock::now();
        }
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
        if (use_cuda_graph) {
            CHECK_CUDA_ERROR(cudaGraphLaunch(instance, stream));
        } else {
            model_forward();
        }
#else
        model_forward();
#endif
        CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
    }
    stop_time = std::chrono::steady_clock::now();
    double duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count();
    fprintf(stdout, "Time: %g ms in total, %g ms per iteration, %g ms per layer\n", duration, duration / num_iters,
            duration / num_iters / num_layers);

#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
    // Destroy graph
    if (use_cuda_graph) {
        CHECK_CUDA_ERROR(cudaGraphExecDestroy(instance));
        CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
    }
#endif

    // Destroy stream
    CHECK_CUDA_ERROR(cudaStreamDestroy(stream));

    CHECK_CUDA_ERROR(cudaFree(da));
    CHECK_CUDA_ERROR(cudaFree(db));
    CHECK_CUDA_ERROR(cudaFree(dc));
    CHECK_CUDA_ERROR(cudaFree(dd));
    CHECK_CUDA_ERROR(cudaFree(de));
    CHECK_CUDA_ERROR(cudaFree(df));
    CHECK_CUDA_ERROR(cudaFree(dg));
    CHECK_CUDA_ERROR(cudaFree(d_workspace));
// cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead.
#if defined(__HIP_PLATFORM_AMD__)
    CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceDestroy(pref));
    CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescDestroy(matmul1));
    CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescDestroy(matmul2));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matA));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matB));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matC));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matD));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matE));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matF));
    CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matG));
    CHECK_CUBLASLT_ERROR(hipblasLtDestroy(handle));
#else
    CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceDestroy(pref));
    CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul1));
    CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul2));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matA));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matB));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matC));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matD));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matE));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matF));
    CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matG));
    CHECK_CUBLASLT_ERROR(cublasLtDestroy(handle));
#endif

    return;
}

int main(int argc, char *argv[]) {
    // Init MPI
    int comm_rank, comm_size;
    MPI_Init(NULL, NULL);
    MPI_Comm_rank(MPI_COMM_WORLD, &comm_rank);
    MPI_Comm_size(MPI_COMM_WORLD, &comm_size);

    // Init NCCL
    int num_local_ranks = 0;
    ncclComm_t nccl_comm;
    ncclUniqueId nccl_id;
    if (comm_rank == 0) {
        CHECK_NCCL_ERROR(ncclGetUniqueId(&nccl_id));
    }
    MPI_Bcast(&nccl_id, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD);
    CHECK_CUDA_ERROR(cudaGetDeviceCount(&num_local_ranks));
    CHECK_CUDA_ERROR(cudaSetDevice(comm_rank % num_local_ranks));
    CHECK_NCCL_ERROR(ncclCommInitRank(&nccl_comm, comm_size, nccl_id, comm_rank));

    // Init parameters with default values
    int64_t m = 80;
    int64_t n = 128;
    int64_t k = 128;
    float alpha = 1;
    float beta = 1;
    int32_t num_layers = 50;
    int32_t num_warmups = 20;
    int32_t num_iters = 100;
    bool use_cuda_graph = false;

    if (ParseArguments(argc, argv, &m, &n, &k, &alpha, &beta, &num_layers, &num_warmups, &num_iters, &use_cuda_graph)) {
        ShowUsage(argv);
        return -1;
    }

    fprintf(stdout,
            "Parameters: m=%ld, n=%ld, k=%ld, alpha=%f, beta=%f, num_layers=%d, num_warmups=%d, num_iters=%d, "
            "use_cuda_graph=%d\n",
            m, n, k, alpha, beta, num_layers, num_warmups, num_iters, (int)use_cuda_graph);

    TestModel(m, n, k, alpha, beta, num_layers, num_warmups, num_iters, use_cuda_graph, nccl_comm);

    CHECK_NCCL_ERROR(ncclCommDestroy(nccl_comm));

    MPI_Finalize();

    return 0;
}
@@ -33,6 +33,8 @@ superbench:
        node_num: 1
+        env:
+          NCCL_ASYNC_ERROR_HANDLING: '0'
    frameworks:
      - pytorch
    common_model_config: &common_model_config
      duration: 0
      num_warmup: 64
@@ -29,6 +29,8 @@ superbench:
        node_num: 1
+        env:
+          NCCL_ASYNC_ERROR_HANDLING: '0'
    frameworks:
      - pytorch
    common_model_config: &common_model_config
      duration: 0
      num_warmup: 64
@@ -3,12 +3,15 @@

"""Tests for distributed inference benchmark."""

import numbers
+import unittest

from tests.helper import decorator
+from tests.helper.testcase import BenchmarkTestCase
import tests.benchmarks.utils as utils
from superbench.benchmarks \
-    import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode, Precision, DistributedImpl, DistributedBackend
+    import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode, Precision, DistributedImpl, DistributedBackend, \
+    Platform
from superbench.benchmarks.micro_benchmarks.dist_inference \
    import DistInference, ComputationKernelType, CommunicationKernelType, ActivationKernelType
from superbench.common.utils import network
|
|||
@decorator.pytorch_test
|
||||
def test_pytorch_dist_inference_normal():
|
||||
"""Test pytorch-dist-inference benchmark on distributed normal case."""
|
||||
context = BenchmarkRegistry.create_benchmark_context('dist-inference', parameters='', framework=Framework.PYTORCH)
|
||||
context = BenchmarkRegistry.create_benchmark_context(
|
||||
'dist-inference', parameters='--use_pytorch', framework=Framework.PYTORCH
|
||||
)
|
||||
world_size = 2
|
||||
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
|
||||
results = utils.simulated_ddp_distributed_benchmark(context, world_size)
|
||||
|
@@ -33,9 +38,12 @@ def test_pytorch_dist_inference_normal():
    assert (benchmark.type == BenchmarkType.MICRO)

    # Check predefined parameters of dist-inference benchmark.
+    assert (benchmark._args.use_pytorch is True)
    assert (benchmark._args.batch_size == 64)
    assert (benchmark._args.input_size == 1024)
    assert (benchmark._args.hidden_size == 1024)
+    assert (benchmark._args.alpha == 1.0)
+    assert (benchmark._args.beta == 1.0)
    assert (benchmark._args.num_layers == 1)
    assert (benchmark._args.computation_kernel == ComputationKernelType.MATMUL)
    assert (benchmark._args.communication_kernel == CommunicationKernelType.ALLREDUCE)
@@ -45,6 +53,7 @@ def test_pytorch_dist_inference_normal():
    assert (benchmark._args.num_steps == 10000)
    assert (benchmark._args.distributed_impl == DistributedImpl.DDP)
    assert (benchmark._args.distributed_backend == DistributedBackend.NCCL)
+    assert (benchmark._args.use_cuda_graph is False)

    # Check results and metrics.
    assert (benchmark.run_count == 1)
@@ -52,14 +61,16 @@ def test_pytorch_dist_inference_normal():
    # step_times
    assert (len(benchmark.raw_data) == 1)
    # return code + (avg, 50th, 90th, 95th, 99th, 99.9th)
-    assert (len(benchmark.result) == 7)
+    assert (7 == len(benchmark.result))


+@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_dist_inference_fake_distributed():
    """Test pytorch-dist-inference benchmark on single gpu."""
-    context = BenchmarkRegistry.create_benchmark_context('dist-inference', parameters='', framework=Framework.PYTORCH)
+    context = BenchmarkRegistry.create_benchmark_context(
+        'dist-inference', parameters='--use_pytorch', framework=Framework.PYTORCH
+    )
    port = network.get_free_port()
    assert (port)
    utils.setup_simulated_ddp_distributed_env(1, 0, port)
@@ -72,9 +83,12 @@ def test_pytorch_dist_inference_fake_distributed():
    assert (benchmark.type == BenchmarkType.MICRO)

    # Check predefined parameters of dist-inference benchmark.
+    assert (benchmark._args.use_pytorch is True)
    assert (benchmark._args.batch_size == 64)
    assert (benchmark._args.input_size == 1024)
    assert (benchmark._args.hidden_size == 1024)
+    assert (benchmark._args.alpha == 1.0)
+    assert (benchmark._args.beta == 1.0)
    assert (benchmark._args.num_layers == 1)
    assert (benchmark._args.computation_kernel == ComputationKernelType.MATMUL)
    assert (benchmark._args.communication_kernel == CommunicationKernelType.ALLREDUCE)
@@ -84,6 +98,7 @@ def test_pytorch_dist_inference_fake_distributed():
    assert (benchmark._args.num_steps == 10000)
    assert (benchmark._args.distributed_impl == DistributedImpl.DDP)
    assert (benchmark._args.distributed_backend == DistributedBackend.NCCL)
+    assert (benchmark._args.use_cuda_graph is False)

    # Check results and metrics.
    assert (benchmark.run_count == 1)
@@ -94,3 +109,127 @@ def test_pytorch_dist_inference_fake_distributed():
    assert (len(benchmark.result) == 7)

    utils.clean_simulated_ddp_distributed_env()
+
+
+class DistInferenceCppImplTest(BenchmarkTestCase, unittest.TestCase):
+    """Test class for pytorch-dist-inference benchmark."""
+    @classmethod
+    def setUpClass(cls):
+        """Hook method for setting up class fixture before running tests in the class."""
+        super().setUpClass()
+        cls.createMockEnvs(cls)
+        cls.createMockFiles(cls, ['bin/dist_inference'])
+
+    def _test_dist_inference_command_generation(self, platform):
+        """Test pytorch-dist-inference cpp impl benchmark command generation."""
+        benchmark_name = 'pytorch-dist-inference'
+        (benchmark_class,
+         predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, platform)
+        assert (benchmark_class)
+
+        batch_size = 1
+        input_size = 2
+        hidden_size = 3
+        alpha = 4.0
+        beta = 5.0
+        num_layers = 6
+        num_warmup = 7
+        num_steps = 8
+        wrapper_params_format_str = \
+            '--batch_size %d --input_size %d --hidden_size %d ' \
+            '--alpha %g --beta %g --num_layers %d --num_warmup %d --num_steps %d --use_cuda_graph'
+        parameters = wrapper_params_format_str % (
+            batch_size, input_size, hidden_size, alpha, beta, num_layers, num_warmup, num_steps
+        )
+        benchmark = benchmark_class(benchmark_name, parameters=parameters)
+
+        # Check basic information
+        assert (benchmark)
+        ret = benchmark._preprocess()
+        assert (ret is True)
+        assert (benchmark.return_code == ReturnCode.SUCCESS)
+        assert (benchmark.name == benchmark_name)
+        assert (benchmark.type == BenchmarkType.MICRO)
+
+        # Check parameters specified in BenchmarkContext.
+        assert (benchmark._args.use_pytorch is False)
+        assert (benchmark._args.batch_size == batch_size)
+        assert (benchmark._args.input_size == input_size)
+        assert (benchmark._args.hidden_size == hidden_size)
+        assert (benchmark._args.alpha == alpha)
+        assert (benchmark._args.beta == beta)
+        assert (benchmark._args.num_layers == num_layers)
+        assert (benchmark._args.num_warmup == num_warmup)
+        assert (benchmark._args.num_steps == num_steps)
+        assert (benchmark._args.use_cuda_graph is True)
+
+        # Check command
+        assert (1 == len(benchmark._commands))
+        for cmd in benchmark._commands:
+            m, n, k = hidden_size, batch_size, input_size
+            bench_params_format_str = \
+                '%s -m %d -n %d -k %d --alpha %g --beta %g ' + \
+                '--num_layers %d --num_warmups %d --num_iters %d --use_cuda_graph'
+            assert (
+                cmd == (
+                    bench_params_format_str %
+                    (benchmark._DistInference__bin_path, m, n, k, alpha, beta, num_layers, num_warmup, num_steps)
+                )
+            )
+
+    @decorator.cuda_test
+    def test_dist_inference_command_generation_cuda(self):
+        """Test pytorch-dist-inference cpp impl benchmark command generation, CUDA case."""
+        self._test_dist_inference_command_generation(Platform.CUDA)
+
+    @decorator.rocm_test
+    def test_dist_inference_command_generation_rocm(self):
+        """Test pytorch-dist-inference cpp impl benchmark command generation, ROCm case."""
+        self._test_dist_inference_command_generation(Platform.ROCM)
+
+    @decorator.load_data('tests/data/dist_inference.log')
+    def _test_dist_inference_result_parsing(self, platform, test_raw_output):
+        """Test pytorch-dist-inference cpp impl benchmark result parsing."""
+        benchmark_name = 'pytorch-dist-inference'
+        (benchmark_class,
+         predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, platform)
+        assert (benchmark_class)
+        benchmark = benchmark_class(benchmark_name, parameters='')
+        assert (benchmark)
+        ret = benchmark._preprocess()
+        assert (ret is True)
+        assert (benchmark.return_code == ReturnCode.SUCCESS)
+        assert (benchmark.name == 'pytorch-dist-inference')
+        assert (benchmark.type == BenchmarkType.MICRO)
+
+        # Positive case - valid raw output.
+        assert (benchmark._process_raw_result(0, test_raw_output))
+        assert (benchmark.return_code == ReturnCode.SUCCESS)
+
+        # step_times
+        assert (len(benchmark.raw_data) == 2)
+        # return code + (avg, 50th, 90th, 95th, 99th, 99.9th)
+        test_latency = float(test_raw_output.splitlines()[-1].split(' ms per iteration')[0].split()[-1])
+        assert (7 == len(benchmark.result))
+        for output_key in benchmark.result:
+            if output_key == 'return_code':
+                assert (benchmark.result[output_key] == [0])
+            else:
+                assert (output_key.startswith('step_times'))
+                assert (len(benchmark.result[output_key]) == 1)
+                assert (isinstance(benchmark.result[output_key][0], numbers.Number))
+                assert (test_latency == benchmark.result[output_key][0])
+
+        # Negative case - invalid raw output.
+        assert (benchmark._process_raw_result(1, 'Invalid raw output') is False)
+        assert (benchmark.return_code == ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
+
+    @decorator.cuda_test
+    def test_dist_inference_result_parsing_cuda(self):
+        """Test pytorch-dist-inference cpp impl benchmark result parsing, CUDA case."""
+        self._test_dist_inference_result_parsing(Platform.CUDA)
+
+    @decorator.rocm_test
+    def test_dist_inference_result_parsing_rocm(self):
+        """Test pytorch-dist-inference cpp impl benchmark result parsing, ROCm case."""
+        self._test_dist_inference_result_parsing(Platform.ROCM)
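The command check above reads the wrapper's private `__bin_path` via Python name mangling; a quick illustration of why the test spells it `_DistInference__bin_path`:

```python
# Attributes spelled __name inside class C are stored as _C__name.
class DistInference:
    def __init__(self):
        self.__bin_path = '/opt/superbench/bin/dist_inference'  # illustrative path

print(DistInference()._DistInference__bin_path)  # /opt/superbench/bin/dist_inference
```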
@@ -0,0 +1,2 @@
Parameters: m=80, n=128, k=128, alpha=1.000000, beta=1.000000, num_layers=50, num_warmups=20, num_iters=100, use_cuda_graph=0
Time: 173 ms in total, 1.73 ms per iteration, 0.0346 ms per layer