Benchmarks: Microbenchmark - Add distributed inference benchmark cpp implementation (#586)

**Description**
Add distributed inference benchmark cpp implementation.
Ziyue Yang 2023-12-11 06:53:51 +08:00, committed by GitHub
Parent 1f5031bd74
Commit 719a427fe7
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 791 additions and 64 deletions

View file

@@ -418,7 +418,7 @@ Test the performance of large scale matmul operation with multiple GPUs:
#### Introduction
Test the performance of distributed model inference.
Test the performance of distributed model inference. Supports both PyTorch and cpp implementations.
#### Metrics

View file

@@ -12,7 +12,7 @@ import torch.distributed as dist
from superbench.common.utils import logger
from superbench.benchmarks import DistributedImpl, DistributedBackend, BenchmarkRegistry, ReturnCode, Precision
from superbench.benchmarks.micro_benchmarks import MicroBenchmark
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke
from superbench.benchmarks.context import Enum
from superbench.benchmarks.reducer import ReduceType
@@ -168,7 +168,7 @@ class DistInferenceModel(torch.nn.Module):
return activation_out
class DistInference(MicroBenchmark):
class DistInference(MicroBenchmarkWithInvoke):
"""The base class of micro-benchmarks."""
def __init__(self, name, parameters=''):
"""Constructor.
@@ -182,7 +182,9 @@ class DistInference(MicroBenchmark):
self.__local_rank = 0
torch.backends.cudnn.benchmark = True
self.__device = None
self.__cuda_available = False
# For cpp impl path
self._bin_name = 'dist_inference'
def __timer(self):
"""Returns the current time which ensures all previous CUDA events have been finished.
@@ -193,14 +195,19 @@ class DistInference(MicroBenchmark):
Return:
Current time in second.
"""
if self.__cuda_available:
torch.cuda.synchronize()
torch.cuda.synchronize()
return time.time()
def add_parser_arguments(self):
"""Add the specified arguments."""
super().add_parser_arguments()
self._parser.add_argument(
'--use_pytorch',
action='store_true',
required=False,
help='Whether to use pytorch implementation. If not, cpp implementation will be used.',
)
self._parser.add_argument(
'--batch_size',
type=int,
@@ -222,6 +229,20 @@ class DistInference(MicroBenchmark):
required=False,
help='Hidden size.',
)
self._parser.add_argument(
'--alpha',
type=float,
default=1.0,
required=False,
help='Coefficient alpha in D = alpha*(A*B) + beta*(C).',
)
self._parser.add_argument(
'--beta',
type=float,
default=1.0,
required=False,
help='Coefficient beta in D = alpha*(A*B) + beta*(C).',
)
self._parser.add_argument(
'--num_layers',
type=int,
@@ -285,6 +306,12 @@ class DistInference(MicroBenchmark):
required=False,
help='Distributed backends. E.g. {}.'.format(' '.join(DistributedBackend.get_values())),
)
self._parser.add_argument(
'--use_cuda_graph',
action='store_true',
required=False,
help='Whether to launch kernels in CUDA graph mode.',
)
def _preprocess(self):
"""Preprocess/preparation operations before the benchmarking.
@@ -295,32 +322,41 @@ class DistInference(MicroBenchmark):
if not super()._preprocess():
return False
if self._args.distributed_impl != DistributedImpl.DDP:
self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
logger.error(
'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
self._name, self._args.distributed_impl
if self._args.use_pytorch:
# Initialize PyTorch if pytorch impl path
if self._args.distributed_impl != DistributedImpl.DDP:
return self._set_error_code_and_print_error_msg(
ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE,
'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
self._name, self._args.distributed_impl
)
)
)
return False
try:
torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
self.__world_size = int(os.environ['WORLD_SIZE'])
self.__local_rank = int(os.environ['LOCAL_RANK'])
except BaseException as e:
self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
torch.distributed.destroy_process_group()
logger.error('Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e)))
return False
try:
torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
self.__world_size = int(os.environ['WORLD_SIZE'])
self.__local_rank = int(os.environ['LOCAL_RANK'])
assert (torch.cuda.is_available())
except BaseException as e:
torch.distributed.destroy_process_group()
return self._set_error_code_and_print_error_msg(
ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE,
'Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e))
)
if torch.cuda.is_available():
torch.cuda.set_device(self.__local_rank)
self.__device = torch.device('cuda:{}'.format(self.__local_rank))
self.__cuda_available = True
else:
self.__device = torch.device('cpu:{}'.format(self.__local_rank))
self.__cuda_available = False
# Assemble commands if cpp impl path
self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)
args = '-m %d -n %d -k %d' % (self._args.hidden_size, self._args.batch_size, self._args.input_size)
args += ' --alpha %g --beta %g' % (self._args.alpha, self._args.beta)
args += ' --num_layers %d --num_warmups %d --num_iters %d' % \
(self._args.num_layers, self._args.num_warmup, self._args.num_steps)
if self._args.use_cuda_graph:
args += ' --use_cuda_graph'
self._commands = ['%s %s' % (self.__bin_path, args)]
return True
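Note the wrapper-to-binary argument mapping assembled above: `hidden_size`, `batch_size`, and `input_size` become the GEMM dimensions `-m`, `-n`, and `-k`, and the wrapper's `num_warmup`/`num_steps` are passed as the binary's `--num_warmups`/`--num_iters`. A minimal sketch of the resulting command line, with illustrative parameter values and a hypothetical binary path:

```python
# Sketch only: bin_path and the numeric values are illustrative, not defaults.
bin_path = '/opt/superbench/bin/dist_inference'
hidden_size, batch_size, input_size = 1024, 64, 1024   # mapped to m, n, k
alpha, beta, num_layers, num_warmup, num_steps = 1.0, 1.0, 50, 20, 100

cmd = '%s -m %d -n %d -k %d --alpha %g --beta %g --num_layers %d --num_warmups %d --num_iters %d' % (
    bin_path, hidden_size, batch_size, input_size, alpha, beta, num_layers, num_warmup, num_steps
)
# cmd == '/opt/superbench/bin/dist_inference -m 1024 -n 64 -k 1024 --alpha 1 --beta 1 '
#        '--num_layers 50 --num_warmups 20 --num_iters 100'
```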
@@ -347,8 +383,7 @@ class DistInference(MicroBenchmark):
self.__device
)
model = model.to(dtype=getattr(torch, precision.value))
if self.__cuda_available:
model = model.cuda()
model = model.cuda()
return model
def _run_model(self, model, batch_size, input_size, precision, device, num_warmup, num_steps):
@@ -401,38 +436,78 @@ class DistInference(MicroBenchmark):
Return:
True if _benchmark succeeds.
"""
batch_size = self._args.batch_size
input_size = self._args.input_size
hidden_size = self._args.hidden_size
num_layers = self._args.num_layers
computation = self._args.computation_kernel
communication = self._args.communication_kernel
activation = self._args.activation_kernel
precision = self._args.precision
num_warmup = self._args.num_warmup
num_steps = self._args.num_steps
if self._args.use_pytorch:
# Execute PyTorch model if pytorch impl path
batch_size = self._args.batch_size
input_size = self._args.input_size
hidden_size = self._args.hidden_size
num_layers = self._args.num_layers
computation = self._args.computation_kernel
communication = self._args.communication_kernel
activation = self._args.activation_kernel
precision = self._args.precision
num_warmup = self._args.num_warmup
num_steps = self._args.num_steps
if self.__local_rank == 0:
logger.info(
'Distributed Inference - using {} GPUs: '
'batch_size={}, input_size={}, hidden_size={}, num_layers={}, '
'computation_kernel={}, communication_kernel={}, activation_kernel={}, precision={}, '
'num_warmup={} num_steps={}'.format(
self.__world_size, batch_size, input_size, hidden_size, num_layers, computation, communication,
activation, precision, num_warmup, num_steps
if self.__local_rank == 0:
logger.info(
'Distributed Inference - using {} GPUs: '
'batch_size={}, input_size={}, hidden_size={}, num_layers={}, '
'computation_kernel={}, communication_kernel={}, activation_kernel={}, precision={}, '
'num_warmup={} num_steps={}'.format(
self.__world_size, batch_size, input_size, hidden_size, num_layers, computation, communication,
activation, precision, num_warmup, num_steps
)
)
# Prepare model
model = self._prepare_model(
input_size, hidden_size, num_layers, computation, communication, activation, precision,
self.__world_size
)
# Prepare model
model = self._prepare_model(
input_size, hidden_size, num_layers, computation, communication, activation, precision, self.__world_size
)
# Run model
step_times = self._run_model(model, batch_size, input_size, precision, self.__device, num_warmup, num_steps)
# Run model
step_times = self._run_model(model, batch_size, input_size, precision, self.__device, num_warmup, num_steps)
# Process data and return
return self._process_data(step_times)
else:
# Execute commands if cpp impl path
if not super()._benchmark():
return False
return True
# Process data and return
return self._process_data(step_times)
def _process_raw_result(self, cmd_idx, raw_output):
"""Function to parse raw results and save the summarized results.
self._result.add_raw_data() and self._result.add_result() need to be called to save the results.
Args:
cmd_idx (int): the index of command corresponding with the raw_output.
raw_output (str): raw output string of the micro-benchmark.
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
try:
output_lines = [x.strip() for x in raw_output.strip().splitlines()]
step_time = None
for output_line in output_lines:
if ' ms per iteration' in output_line:
step_time = float(output_line.split(' ms per iteration')[0].split()[-1])
break
return self._process_numeric_result(
'step_times', [step_time], reduce_type=ReduceType.MAX, cal_percentile=True
)
except BaseException as e:
return self._set_error_code_and_print_error_msg(
ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE,
'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format(
self._curr_run_index, self._name, raw_output, str(e)
)
)
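The parser above keys on the ` ms per iteration` substring emitted by the binary (see the sample `dist_inference.log` added at the end of this diff) and takes the token immediately before it as the step time. A minimal standalone sketch of that extraction, using the sample output line:

```python
# Sample line from the dist_inference.log test fixture in this commit.
raw_output = 'Time: 173 ms in total, 1.73 ms per iteration, 0.0346 ms per layer'

step_time = None
for line in raw_output.strip().splitlines():
    if ' ms per iteration' in line:
        # Split off everything after the marker, then take the last token before it.
        step_time = float(line.split(' ms per iteration')[0].split()[-1])
        break

assert step_time == 1.73
```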
def _postprocess(self):
"""Postprocess/cleanup operations after the benchmarking.
@@ -443,14 +518,26 @@ class DistInference(MicroBenchmark):
if not super()._postprocess():
return False
try:
torch.distributed.destroy_process_group()
except BaseException as e:
self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE)
logger.error('Post process failed - benchmark: {}, message: {}.'.format(self._name, str(e)))
return False
if self._args.use_pytorch:
try:
torch.distributed.destroy_process_group()
except BaseException as e:
return self._set_error_code_and_print_error_msg(
ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE,
'Post process failed - benchmark: {}, message: {}.'.format(self._name, str(e))
)
return True
def _set_error_code_and_print_error_msg(self, error_code, error_msg):
"""Set error code and print error log upon error.
Return:
False, representing error.
"""
self._result.set_return_code(error_code)
logger.error(error_msg)
return False
BenchmarkRegistry.register_benchmark('pytorch-dist-inference', DistInference, parameters='')
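As a usage note, the new `--use_pytorch` flag selects between the two code paths: with the flag, the PyTorch `torch.distributed` model loop runs; without it, the benchmark assembles and invokes the `dist_inference` binary shown below. A minimal sketch of creating the two contexts, mirroring the updated tests later in this diff (the cpp-path parameter values are illustrative only):

```python
from superbench.benchmarks import BenchmarkRegistry, Framework

# PyTorch implementation path (as exercised by test_pytorch_dist_inference_normal).
pytorch_ctx = BenchmarkRegistry.create_benchmark_context(
    'dist-inference', parameters='--use_pytorch', framework=Framework.PYTORCH
)

# cpp implementation path (default when --use_pytorch is omitted); values are illustrative.
cpp_ctx = BenchmarkRegistry.create_benchmark_context(
    'dist-inference',
    parameters='--num_layers 50 --num_warmup 20 --num_steps 100 --use_cuda_graph',
    framework=Framework.PYTORCH
)
```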

View file

@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
cmake_minimum_required(VERSION 3.18)
project(dist_inference LANGUAGES CXX)
find_package(MPI REQUIRED)
include_directories(SYSTEM ${MPI_INCLUDE_PATH})
find_package(CUDAToolkit QUIET)
# Cuda environment
if(CUDAToolkit_FOUND)
message(STATUS "Found CUDA: " ${CUDAToolkit_VERSION})
include(../cuda_common.cmake)
add_executable(dist_inference dist_inference.cu)
set_property(TARGET dist_inference PROPERTY CUDA_ARCHITECTURES ${NVCC_ARCHS_SUPPORTED})
target_link_libraries(dist_inference MPI::MPI_CXX nccl cublasLt)
else()
# ROCm environment
include(../rocm_common.cmake)
find_package(hip QUIET)
if(hip_FOUND)
message(STATUS "Found ROCm: " ${HIP_VERSION})
# Convert cuda code to hip code in cpp
execute_process(COMMAND hipify-perl -print-stats -o dist_inference.cpp dist_inference.cu WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/)
# link hip device lib
add_executable(dist_inference dist_inference.cpp)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DROCM_USE_FLOAT16=1")
target_link_libraries(dist_inference MPI::MPI_CXX rccl hipblaslt hip::device)
else()
message(FATAL_ERROR "No CUDA or ROCm environment found.")
endif()
endif()
install(TARGETS dist_inference RUNTIME DESTINATION bin)

View file

@@ -0,0 +1,455 @@
/*******************************************************************************
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT License.
*
* MIT License
*
* Copyright (C) 2022-2023 Advanced Micro Devices, Inc.
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>
#include <functional>
#include <iostream>
#include <limits>
#include <mpi.h>
#include <string>
#include <unistd.h>
#include <vector>
#if defined(__HIP_PLATFORM_AMD__)
#include <hipblaslt/hipblaslt.h>
#include <rccl/rccl.h>
using cublasLtHalf = hipblasLtHalf;
#else
#include <cublasLt.h>
#include <cuda_fp16.h>
#include <nccl.h>
using cublasLtHalf = half;
#endif
#ifndef CHECK_CUDA_ERROR
#define CHECK_CUDA_ERROR(error) \
if (error != cudaSuccess) { \
fprintf(stderr, "Cuda error: '%s'(%d) at %s:%d\n", cudaGetErrorString(error), error, __FILE__, __LINE__); \
exit(-1); \
}
#endif
#ifndef CHECK_CUBLASLT_ERROR
#define CHECK_CUBLASLT_ERROR(error) \
if (error != CUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "cuBLASLt error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \
fprintf(stderr, "\n"); \
exit(-1); \
}
#endif
#ifndef CHECK_NCCL_ERROR
#define CHECK_NCCL_ERROR(error) \
if (error != ncclSuccess) { \
fprintf(stderr, "NCCL error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \
fprintf(stderr, "\n"); \
exit(-1); \
}
#endif
static void ShowUsage(char *argv[]) {
std::cerr << "Usage: " << argv[0] << " <options>\n"
<< "options:\n"
<< "\t-h, --help\t\t\t\tShow this help message\n"
<< "\t-m \t\t\tm\t\tGEMM_STRIDED argument m\n"
<< "\t-n \t\t\tn\t\tGEMM_STRIDED argument n\n"
<< "\t-k \t\t\tk \t\tGEMM_STRIDED argument k\n"
<< "\t--alpha \t\talpha \t\tGEMM_STRIDED argument alpha\n"
<< "\t--beta \t\t\tbeta \t\tGEMM_STRIDED argument beta\n"
<< "\t--num_layers \t\t\tnum_layers \t\tNumber of layers in the model\n"
<< "\t--num_warmups \t\t\tnum_warmups \t\tNumber of warmup runs\n"
<< "\t--num_iters \t\t\tnum_iters \t\tNumber of test runs\n"
<< "\t--use_cuda_graph \t\t\tuse_cuda_graph \t\tWhether to launch kernels in CUDA graph mode\n"
<< std::endl;
}
static int ParseArguments(int argc, char *argv[], int64_t *m, int64_t *n, int64_t *k, float *alpha, float *beta,
int32_t *num_layers, int32_t *num_warmups, int32_t *num_iters, bool *use_cuda_graph) {
if (argc >= 2) {
for (int i = 1; i < argc; ++i) {
std::string arg = argv[i];
if ((arg.at(0) == '-') || ((arg.at(0) == '-') && (arg.at(1) == '-'))) {
if ((arg == "-h") || (arg == "--help")) {
return -1;
} else if ((arg == "-m") && (i + 1 < argc)) {
*m = atoi(argv[++i]);
} else if ((arg == "-n") && (i + 1 < argc)) {
*n = atoi(argv[++i]);
} else if ((arg == "-k") && (i + 1 < argc)) {
*k = atoi(argv[++i]);
} else if ((arg == "--alpha") && (i + 1 < argc)) {
*alpha = atof(argv[++i]);
} else if ((arg == "--beta") && (i + 1 < argc)) {
*beta = atof(argv[++i]);
} else if ((arg == "--num_layers") && (i + 1 < argc)) {
*num_layers = atoi(argv[++i]);
} else if ((arg == "--num_warmups") && (i + 1 < argc)) {
*num_warmups = atoi(argv[++i]);
} else if ((arg == "--num_iters") && (i + 1 < argc)) {
*num_iters = atoi(argv[++i]);
} else if (arg == "--use_cuda_graph") {
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
*use_cuda_graph = true;
#else
*use_cuda_graph = false;
std::cerr << "error with " << arg << std::endl;
std::cerr << "not supported by current environment" << std::endl << std::endl;
return -1;
#endif
} else {
std::cerr << "error with " << arg << std::endl;
std::cerr << "do not recognize option" << std::endl << std::endl;
return -1;
}
} else {
std::cerr << "error with " << arg << std::endl;
std::cerr << "option must start with - or --" << std::endl << std::endl;
return -1;
}
}
}
return 0;
}
void InitializeABCDEF(std::vector<cublasLtHalf> &ha, int64_t size_a, std::vector<cublasLtHalf> &hb, int64_t size_b,
std::vector<cublasLtHalf> &hc, int64_t size_c, std::vector<cublasLtHalf> &hd, int64_t size_d,
std::vector<cublasLtHalf> &he, int64_t size_e, std::vector<cublasLtHalf> &hf, int64_t size_f) {
srand(1);
for (int i = 0; i < size_a; ++i) {
ha[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
for (int i = 0; i < size_b; ++i) {
hb[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
for (int i = 0; i < size_c; ++i) {
hc[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
for (int i = 0; i < size_d; ++i) {
hd[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
for (int i = 0; i < size_e; ++i) {
he[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
for (int i = 0; i < size_f; ++i) {
hf[i] = static_cast<cublasLtHalf>((rand() % 7) - 3);
}
}
// B[m, k] * A[k, n] + C[m, n] = D[m, n]
// E[k, m] * D[m, n] + F[k, n] = G[k, n]
void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t num_layers, int32_t num_warmups,
int32_t num_iters, bool use_cuda_graph, ncclComm_t nccl_comm) {
const int kNcclBufAlignment = 512;
int size_a = k * n;
int size_b = m * k;
int size_c = m * n;
int size_d = m * n;
int size_e = k * m;
int size_f = k * n;
int size_g = (k * n + kNcclBufAlignment - 1) / kNcclBufAlignment * kNcclBufAlignment;
// Naming: da is in GPU (device) memory. ha is in CPU (host) memory
std::vector<cublasLtHalf> ha(size_a);
std::vector<cublasLtHalf> hb(size_b);
std::vector<cublasLtHalf> hc(size_c);
std::vector<cublasLtHalf> hd(size_d);
std::vector<cublasLtHalf> he(size_e);
std::vector<cublasLtHalf> hf(size_f);
std::vector<cublasLtHalf> hg(size_g);
// initial data on host
InitializeABCDEF(ha, size_a, hb, size_b, hc, size_c, hd, size_d, he, size_e, hf, size_f);
// allocate memory on device
void *da, *db, *dc, *dd, *de, *df, *dg;
// Create stream
cudaStream_t stream = nullptr;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CHECK_CUDA_ERROR(cudaMalloc(&da, size_a * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&db, size_b * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&dc, size_c * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&dd, size_d * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&de, size_e * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&df, size_f * sizeof(cublasLtHalf)));
CHECK_CUDA_ERROR(cudaMalloc(&dg, size_g * sizeof(cublasLtHalf)));
// copy matrices from host to device
CHECK_CUDA_ERROR(cudaMemcpy(da, ha.data(), sizeof(cublasLtHalf) * size_a, cudaMemcpyHostToDevice));
CHECK_CUDA_ERROR(cudaMemcpy(db, hb.data(), sizeof(cublasLtHalf) * size_b, cudaMemcpyHostToDevice));
CHECK_CUDA_ERROR(cudaMemcpy(dc, hc.data(), sizeof(cublasLtHalf) * size_c, cudaMemcpyHostToDevice));
CHECK_CUDA_ERROR(cudaMemcpy(dd, hd.data(), sizeof(cublasLtHalf) * size_d, cudaMemcpyHostToDevice));
CHECK_CUDA_ERROR(cudaMemcpy(de, he.data(), sizeof(cublasLtHalf) * size_e, cudaMemcpyHostToDevice));
CHECK_CUDA_ERROR(cudaMemcpy(df, hf.data(), sizeof(cublasLtHalf) * size_f, cudaMemcpyHostToDevice));
uint64_t workspace_size = 1024 * 1024;
void *d_workspace;
CHECK_CUDA_ERROR(cudaMalloc(&d_workspace, workspace_size));
int returnedAlgoCount = 0;
// cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead.
#if defined(__HIP_PLATFORM_AMD__)
hipblasLtHandle_t handle;
hipblasLtMatrixLayout_t matA, matB, matC, matD, matE, matF, matG;
hipblasLtMatmulDesc_t matmul1, matmul2;
hipblasLtMatmulPreference_t pref;
CHECK_CUBLASLT_ERROR(hipblasLtCreate(&handle));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, HIPBLAS_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, HIPBLAS_R_16F, m, k, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIPBLAS_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIPBLAS_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matE, HIPBLAS_R_16F, k, m, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, HIPBLAS_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, HIPBLAS_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
hipblasOperation_t trans = HIPBLAS_OP_N;
CHECK_CUBLASLT_ERROR(
hipblasLtMatmulDescSetAttribute(matmul1, HIPBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(
hipblasLtMatmulDescSetAttribute(matmul1, HIPBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(
hipblasLtMatmulDescSetAttribute(matmul2, HIPBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(
hipblasLtMatmulDescSetAttribute(matmul2, HIPBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));
// Set User Preference attributes
CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceSetAttribute(pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size, sizeof(workspace_size)));
// Get Heuristic results
hipblasLtMatmulHeuristicResult_t heuristicResult1[1] = {0};
hipblasLtMatmulHeuristicResult_t heuristicResult2[1] = {0};
// B[m, k] * A[k, n] + C[m, n] = D[m, n]
// E[k, m] * D[m, n] + F[k, n] = G[k, n]
CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul1, matB, matA, matC, matD, pref, 1,
heuristicResult1, &returnedAlgoCount));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul2, matE, matD, matF, matG, pref, 1,
heuristicResult2, &returnedAlgoCount));
#else
cublasLtHandle_t handle;
cublasLtMatrixLayout_t matA, matB, matC, matD, matE, matF, matG;
cublasLtMatmulDesc_t matmul1, matmul2;
cublasLtMatmulPreference_t pref;
CHECK_CUBLASLT_ERROR(cublasLtCreate(&handle));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matA, CUDA_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matB, CUDA_R_16F, m, k, m));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matC, CUDA_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matD, CUDA_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matE, CUDA_R_16F, k, m, k));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matF, CUDA_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matG, CUDA_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul1, CUBLAS_COMPUTE_16F, CUDA_R_32F));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul2, CUBLAS_COMPUTE_16F, CUDA_R_32F));
cublasOperation_t trans = CUBLAS_OP_N;
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul1, CUBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul1, CUBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul2, CUBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(int32_t)));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute(matmul2, CUBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(int32_t)));
// Set User Preference attributes
CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceCreate(&pref));
CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size, sizeof(workspace_size)));
// Get Heuristic results
cublasLtMatmulHeuristicResult_t heuristicResult1[1] = {0};
cublasLtMatmulHeuristicResult_t heuristicResult2[1] = {0};
// B[m, k] * A[k, n] + C[m, n] = D[m, n]
// E[k, m] * D[m, n] + F[k, n] = G[k, n]
CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, matmul1, matB, matA, matC, matD, pref, 1,
heuristicResult1, &returnedAlgoCount));
CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, matmul2, matE, matD, matF, matG, pref, 1,
heuristicResult2, &returnedAlgoCount));
#endif
auto model_forward = [&] {
for (int j = 0; j < num_layers; j++) {
// B[m, k] * A[k, n] + C[m, n] = D[m, n]
// E[k, m] * D[m, n] + F[k, n] = G[k, n]
// cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead.
#if defined(__HIP_PLATFORM_AMD__)
CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul1, &alpha, db, matB, da, matA, &beta, dc, matC, dd, matD,
&heuristicResult1[0].algo, d_workspace, workspace_size, stream));
CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul1, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG,
&heuristicResult2[0].algo, d_workspace, workspace_size, stream));
#else
CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, matmul1, &alpha, db, matB, da, matA, &beta, dc, matC, dd, matD,
&heuristicResult1[0].algo, d_workspace, workspace_size, stream));
CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, matmul1, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG,
&heuristicResult2[0].algo, d_workspace, workspace_size, stream));
#endif
CHECK_NCCL_ERROR(ncclAllReduce(dg, dg, size_g, ncclFloat16, ncclSum, nccl_comm, stream));
}
};
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
cudaGraph_t graph;
cudaGraphExec_t instance;
if (use_cuda_graph) {
CHECK_CUDA_ERROR(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
model_forward();
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &graph));
CHECK_CUDA_ERROR(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
}
#endif
std::chrono::steady_clock::time_point start_time, stop_time;
for (int i = 0; i < num_warmups + num_iters; ++i) {
if (i == num_warmups) {
start_time = std::chrono::steady_clock::now();
}
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
if (use_cuda_graph) {
CHECK_CUDA_ERROR(cudaGraphLaunch(instance, stream));
} else {
model_forward();
}
#else
model_forward();
#endif
CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
}
stop_time = std::chrono::steady_clock::now();
double duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count();
fprintf(stdout, "Time: %g ms in total, %g ms per iteration, %g ms per layer\n", duration, duration / num_iters,
duration / num_iters / num_layers);
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && (CUDART_VERSION >= 11030 || HIP_VERSION >= 50221310)
// Destroy graph
if (use_cuda_graph) {
CHECK_CUDA_ERROR(cudaGraphExecDestroy(instance));
CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
}
#endif
// Destroy stream
CHECK_CUDA_ERROR(cudaStreamDestroy(stream));
CHECK_CUDA_ERROR(cudaFree(da));
CHECK_CUDA_ERROR(cudaFree(db));
CHECK_CUDA_ERROR(cudaFree(dc));
CHECK_CUDA_ERROR(cudaFree(dd));
CHECK_CUDA_ERROR(cudaFree(de));
CHECK_CUDA_ERROR(cudaFree(df));
CHECK_CUDA_ERROR(cudaFree(dg));
CHECK_CUDA_ERROR(cudaFree(d_workspace));
// cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead.
#if defined(__HIP_PLATFORM_AMD__)
CHECK_CUBLASLT_ERROR(hipblasLtMatmulPreferenceDestroy(pref));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescDestroy(matmul1));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescDestroy(matmul2));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matA));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matB));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matC));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matD));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matE));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matF));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matG));
CHECK_CUBLASLT_ERROR(hipblasLtDestroy(handle));
#else
CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceDestroy(pref));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul1));
CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul2));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matA));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matB));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matC));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matD));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matE));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matF));
CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matG));
CHECK_CUBLASLT_ERROR(cublasLtDestroy(handle));
#endif
return;
}
int main(int argc, char *argv[]) {
// Init MPI
int comm_rank, comm_size;
MPI_Init(NULL, NULL);
MPI_Comm_rank(MPI_COMM_WORLD, &comm_rank);
MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
// Init NCCL
int num_local_ranks = 0;
ncclComm_t nccl_comm;
ncclUniqueId nccl_id;
if (comm_rank == 0) {
CHECK_NCCL_ERROR(ncclGetUniqueId(&nccl_id));
}
MPI_Bcast(&nccl_id, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD);
CHECK_CUDA_ERROR(cudaGetDeviceCount(&num_local_ranks))
CHECK_CUDA_ERROR(cudaSetDevice(comm_rank % num_local_ranks));
CHECK_NCCL_ERROR(ncclCommInitRank(&nccl_comm, comm_size, nccl_id, comm_rank));
// Init parameters with default values
int64_t m = 80;
int64_t n = 128;
int64_t k = 128;
float alpha = 1;
float beta = 1;
int32_t num_layers = 50;
int32_t num_warmups = 20;
int32_t num_iters = 100;
bool use_cuda_graph = false;
if (ParseArguments(argc, argv, &m, &n, &k, &alpha, &beta, &num_layers, &num_warmups, &num_iters, &use_cuda_graph)) {
ShowUsage(argv);
return -1;
}
fprintf(stdout,
"Parameters: m=%ld, n=%ld, k=%ld, alpha=%f, beta=%f, num_layers=%d, num_warmups=%d, num_iters=%d, "
"use_cuda_graph=%d\n",
m, n, k, alpha, beta, num_layers, num_warmups, num_iters, (int)use_cuda_graph);
TestModel(m, n, k, alpha, beta, num_layers, num_warmups, num_iters, use_cuda_graph, nccl_comm);
CHECK_NCCL_ERROR(ncclCommDestroy(nccl_comm));
MPI_Finalize();
return 0;
}
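To make the per-layer math concrete: each layer issues two cuBLASLt GEMMs, B[m, k] * A[k, n] + C[m, n] = D[m, n] followed by E[k, m] * D[m, n] + F[k, n] = G[k, n], and then all-reduces G across ranks with NCCL/RCCL. A minimal single-rank NumPy sketch of one layer (shapes and arithmetic only, no performance intent; the all-reduce is omitted):

```python
import numpy as np

m, n, k = 80, 128, 128          # default GEMM sizes from main()
alpha, beta = 1.0, 1.0

# Random half-precision inputs, analogous to InitializeABCDEF.
A = np.random.rand(k, n).astype(np.float16)
B = np.random.rand(m, k).astype(np.float16)
C = np.random.rand(m, n).astype(np.float16)
E = np.random.rand(k, m).astype(np.float16)
F = np.random.rand(k, n).astype(np.float16)

def layer(A, B, C, E, F):
    D = alpha * (B @ A) + beta * C   # B[m, k] * A[k, n] + C[m, n] = D[m, n]
    G = alpha * (E @ D) + beta * F   # E[k, m] * D[m, n] + F[k, n] = G[k, n]
    return G                         # G would then be summed across ranks (all-reduce)

G = layer(A, B, C, E, F)
assert G.shape == (k, n)
```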

View file

@@ -33,6 +33,8 @@ superbench:
node_num: 1
env:
NCCL_ASYNC_ERROR_HANDLING: '0'
frameworks:
- pytorch
common_model_config: &common_model_config
duration: 0
num_warmup: 64

View file

@@ -29,6 +29,8 @@ superbench:
node_num: 1
env:
NCCL_ASYNC_ERROR_HANDLING: '0'
frameworks:
- pytorch
common_model_config: &common_model_config
duration: 0
num_warmup: 64

View file

@@ -3,12 +3,15 @@
"""Tests for distributed inference benchmark."""
import numbers
import unittest
from tests.helper import decorator
from tests.helper.testcase import BenchmarkTestCase
import tests.benchmarks.utils as utils
from superbench.benchmarks \
import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode, Precision, DistributedImpl, DistributedBackend
import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode, Precision, DistributedImpl, DistributedBackend, \
Platform
from superbench.benchmarks.micro_benchmarks.dist_inference \
import DistInference, ComputationKernelType, CommunicationKernelType, ActivationKernelType
from superbench.common.utils import network
@@ -20,7 +23,9 @@ from superbench.common.utils import network
@decorator.pytorch_test
def test_pytorch_dist_inference_normal():
"""Test pytorch-dist-inference benchmark on distributed normal case."""
context = BenchmarkRegistry.create_benchmark_context('dist-inference', parameters='', framework=Framework.PYTORCH)
context = BenchmarkRegistry.create_benchmark_context(
'dist-inference', parameters='--use_pytorch', framework=Framework.PYTORCH
)
world_size = 2
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
results = utils.simulated_ddp_distributed_benchmark(context, world_size)
@@ -33,9 +38,12 @@ def test_pytorch_dist_inference_normal():
assert (benchmark.type == BenchmarkType.MICRO)
# Check predefined parameters of dist-inference benchmark.
assert (benchmark._args.use_pytorch is True)
assert (benchmark._args.batch_size == 64)
assert (benchmark._args.input_size == 1024)
assert (benchmark._args.hidden_size == 1024)
assert (benchmark._args.alpha == 1.0)
assert (benchmark._args.beta == 1.0)
assert (benchmark._args.num_layers == 1)
assert (benchmark._args.computation_kernel == ComputationKernelType.MATMUL)
assert (benchmark._args.communication_kernel == CommunicationKernelType.ALLREDUCE)
@@ -45,6 +53,7 @@ def test_pytorch_dist_inference_normal():
assert (benchmark._args.num_steps == 10000)
assert (benchmark._args.distributed_impl == DistributedImpl.DDP)
assert (benchmark._args.distributed_backend == DistributedBackend.NCCL)
assert (benchmark._args.use_cuda_graph is False)
# Check results and metrics.
assert (benchmark.run_count == 1)
@@ -52,14 +61,16 @@ def test_pytorch_dist_inference_normal():
# step_times
assert (len(benchmark.raw_data) == 1)
# return code + (avg, 50th, 90th, 95th, 99th, 99.9th)
assert (len(benchmark.result) == 7)
assert (7 == len(benchmark.result))
@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_dist_inference_fake_distributed():
"""Test pytorch-dist-inference benchmark on single gpu."""
context = BenchmarkRegistry.create_benchmark_context('dist-inference', parameters='', framework=Framework.PYTORCH)
context = BenchmarkRegistry.create_benchmark_context(
'dist-inference', parameters='--use_pytorch', framework=Framework.PYTORCH
)
port = network.get_free_port()
assert (port)
utils.setup_simulated_ddp_distributed_env(1, 0, port)
@@ -72,9 +83,12 @@ def test_pytorch_dist_inference_fake_distributed():
assert (benchmark.type == BenchmarkType.MICRO)
# Check predefined parameters of dist-inference benchmark.
assert (benchmark._args.use_pytorch is True)
assert (benchmark._args.batch_size == 64)
assert (benchmark._args.input_size == 1024)
assert (benchmark._args.hidden_size == 1024)
assert (benchmark._args.alpha == 1.0)
assert (benchmark._args.beta == 1.0)
assert (benchmark._args.num_layers == 1)
assert (benchmark._args.computation_kernel == ComputationKernelType.MATMUL)
assert (benchmark._args.communication_kernel == CommunicationKernelType.ALLREDUCE)
@@ -84,6 +98,7 @@ def test_pytorch_dist_inference_fake_distributed():
assert (benchmark._args.num_steps == 10000)
assert (benchmark._args.distributed_impl == DistributedImpl.DDP)
assert (benchmark._args.distributed_backend == DistributedBackend.NCCL)
assert (benchmark._args.use_cuda_graph is False)
# Check results and metrics.
assert (benchmark.run_count == 1)
@@ -94,3 +109,127 @@
assert (len(benchmark.result) == 7)
utils.clean_simulated_ddp_distributed_env()
class DistInferenceCppImplTest(BenchmarkTestCase, unittest.TestCase):
"""Test class for pytorch-dist-inference benchmark."""
@classmethod
def setUpClass(cls):
"""Hook method for setting up class fixture before running tests in the class."""
super().setUpClass()
cls.createMockEnvs(cls)
cls.createMockFiles(cls, ['bin/dist_inference'])
def _test_dist_inference_command_generation(self, platform):
"""Test pytorch-dist-inference cpp impl benchmark command generation."""
benchmark_name = 'pytorch-dist-inference'
(benchmark_class,
predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, platform)
assert (benchmark_class)
batch_size = 1
input_size = 2
hidden_size = 3
alpha = 4.0
beta = 5.0
num_layers = 6
num_warmup = 7
num_steps = 8
wrapper_params_format_str = \
'--batch_size %d --input_size %d --hidden_size %d ' \
'--alpha %g --beta %g --num_layers %d --num_warmup %d --num_steps %d --use_cuda_graph'
parameters = wrapper_params_format_str % (
batch_size, input_size, hidden_size, alpha, beta, num_layers, num_warmup, num_steps
)
benchmark = benchmark_class(benchmark_name, parameters=parameters)
# Check basic information
assert (benchmark)
ret = benchmark._preprocess()
assert (ret is True)
assert (benchmark.return_code == ReturnCode.SUCCESS)
assert (benchmark.name == benchmark_name)
assert (benchmark.type == BenchmarkType.MICRO)
# Check parameters specified in BenchmarkContext.
assert (benchmark._args.use_pytorch is False)
assert (benchmark._args.batch_size == batch_size)
assert (benchmark._args.input_size == input_size)
assert (benchmark._args.hidden_size == hidden_size)
assert (benchmark._args.alpha == alpha)
assert (benchmark._args.beta == beta)
assert (benchmark._args.num_layers == num_layers)
assert (benchmark._args.num_warmup == num_warmup)
assert (benchmark._args.num_steps == num_steps)
assert (benchmark._args.use_cuda_graph is True)
# Check command
assert (1 == len(benchmark._commands))
for cmd in benchmark._commands:
m, n, k = hidden_size, batch_size, input_size
bench_params_format_str = \
'%s -m %d -n %d -k %d --alpha %g --beta %g ' + \
'--num_layers %d --num_warmups %d --num_iters %d --use_cuda_graph'
assert (
cmd == (
bench_params_format_str %
(benchmark._DistInference__bin_path, m, n, k, alpha, beta, num_layers, num_warmup, num_steps)
)
)
@decorator.cuda_test
def test_dist_inference_command_generation_cuda(self):
"""Test pytorch-dist-inference cpp impl benchmark command generation, CUDA case."""
self._test_dist_inference_command_generation(Platform.CUDA)
@decorator.rocm_test
def test_dist_inference_command_generation_rocm(self):
"""Test pytorch-dist-inference cpp impl benchmark command generation, ROCm case."""
self._test_dist_inference_command_generation(Platform.ROCM)
@decorator.load_data('tests/data/dist_inference.log')
def _test_dist_inference_result_parsing(self, platform, test_raw_output):
"""Test pytorch-dist-inference cpp impl benchmark result parsing."""
benchmark_name = 'pytorch-dist-inference'
(benchmark_class,
predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, platform)
assert (benchmark_class)
benchmark = benchmark_class(benchmark_name, parameters='')
assert (benchmark)
ret = benchmark._preprocess()
assert (ret is True)
assert (benchmark.return_code == ReturnCode.SUCCESS)
assert (benchmark.name == 'pytorch-dist-inference')
assert (benchmark.type == BenchmarkType.MICRO)
# Positive case - valid raw output.
assert (benchmark._process_raw_result(0, test_raw_output))
assert (benchmark.return_code == ReturnCode.SUCCESS)
# step_times
assert (len(benchmark.raw_data) == 2)
# return code + (avg, 50th, 90th, 95th, 99th, 99.9th)
test_latency = float(test_raw_output.splitlines()[-1].split(' ms per iteration')[0].split()[-1])
assert (7 == len(benchmark.result))
for output_key in benchmark.result:
if output_key == 'return_code':
assert (benchmark.result[output_key] == [0])
else:
assert (output_key.startswith('step_times'))
assert (len(benchmark.result[output_key]) == 1)
assert (isinstance(benchmark.result[output_key][0], numbers.Number))
assert (test_latency == benchmark.result[output_key][0])
# Negative case - invalid raw output.
assert (benchmark._process_raw_result(1, 'Invalid raw output') is False)
assert (benchmark.return_code == ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
@decorator.cuda_test
def test_dist_inference_result_parsing_cuda(self):
"""Test pytorch-dist-inference cpp impl benchmark result parsing, CUDA case."""
self._test_dist_inference_result_parsing(Platform.CUDA)
@decorator.rocm_test
def test_dist_inference_result_parsing_rocm(self):
"""Test pytorch-dist-inference cpp impl benchmark result parsing, ROCm case."""
self._test_dist_inference_result_parsing(Platform.ROCM)

View file

@@ -0,0 +1,2 @@
Parameters: m=80, n=128, k=128, alpha=1.000000, beta=1.000000, num_layers=50, num_warmups=20, num_iters=100, use_cuda_graph=0
Time: 173 ms in total, 1.73 ms per iteration, 0.0346 ms per layer