**Description**

Cherry-pick bug fixes from v0.8.0 to main.

**Major Revisions**

* Monitor - Fix the cgroup version checking logic (#502)
* Benchmark - Fix matrix size overflow issue in cuBLASLt GEMM (#503)
* Benchmark - Fix wrong torch usage in communication wrapper for distributed inference benchmark (#505)
* Analyzer - Fix bug in Python 3.8 due to pandas API change (#504)
* Benchmark - Fix bug in getting metric from command output when an error happens (#506)
* Monitor - Collect realtime GPU power when benchmarking (#507)
* Benchmark - Add num_workers argument in model benchmark (#511)
* Runner - Remove unreachable condition when writing host list (#512)
* Docker - Update cuda11.8 image to cuda12.1 based on nvcr 23.03 (#513)
* Docs - Fix wrong unit of cpu-memory-bw-latency in doc (#515)
* Docs - Upgrade version and release note (#508)

Co-authored-by: guoshzhao <guzhao@microsoft.com>
Co-authored-by: Ziyue Yang <ziyyang@microsoft.com>
Co-authored-by: Yuting Jiang <yutingjiang@microsoft.com>
Authored by Yifan Xiong on 2023-04-14 20:57:55 +08:00, committed by GitHub
Parent: 97c9a41f14
Commit: 51761b3af1
GPG key ID: 4AEE18F83AFDEB23
41 changed files: 265 additions and 162 deletions

.github/workflows/build-image.yml (vendored)

@@ -24,9 +24,9 @@ jobs:
strategy:
matrix:
include:
- name: cuda11.8
dockerfile: cuda11.8
tags: superbench/main:cuda11.8
- name: cuda12.1
dockerfile: cuda12.1
tags: superbench/main:cuda12.1
- name: cuda11.1.1
dockerfile: cuda11.1.1
tags: superbench/main:cuda11.1.1,superbench/superbench:latest

@@ -1,18 +1,18 @@
FROM nvcr.io/nvidia/pytorch:22.12-py3
FROM nvcr.io/nvidia/pytorch:23.03-py3
# OS:
# - Ubuntu: 20.04
# - OpenMPI: 4.1.5a1
# - Docker Client: 20.10.8
# NVIDIA:
# - CUDA: 11.8.0
# - cuDNN: 8.7.0.84
# - NCCL: v2.15.5-1
# - CUDA: 12.1.0
# - cuDNN: 8.8.1.3
# - NCCL: v2.17.1-1
# Mellanox:
# - OFED: 5.2-2.2.3.0
# - HPC-X: v2.8.3
# - OFED: 5.2-2.2.3.0 # TODO
# - HPC-X: v2.14
# Intel:
# - mlc: v3.9a
# - mlc: v3.10
LABEL maintainer="SuperBench"
@@ -71,37 +71,27 @@ RUN mkdir -p /root/.ssh && \
# Install OFED
ENV OFED_VERSION=5.2-2.2.3.0
RUN cd /tmp && \
wget -q http://content.mellanox.com/ofed/MLNX_OFED-${OFED_VERSION}/MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \
wget -q https://content.mellanox.com/ofed/MLNX_OFED-${OFED_VERSION}/MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \
tar xzf MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \
MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64/mlnxofedinstall --user-space-only --without-fw-update --force --all && \
rm -rf /tmp/MLNX_OFED_LINUX-${OFED_VERSION}*
# Install HPC-X
ENV HPCX_VERSION=v2.14
RUN cd /opt && \
rm -rf hpcx && \
wget -q https://azhpcstor.blob.core.windows.net/azhpc-images-store/hpcx-v2.8.3-gcc-MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tbz && \
tar xf hpcx-v2.8.3-gcc-MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tbz && \
ln -s hpcx-v2.8.3-gcc-MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64 hpcx && \
rm hpcx-v2.8.3-gcc-MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tbz
wget -q https://content.mellanox.com/hpc/hpc-x/${HPCX_VERSION}/hpcx-${HPCX_VERSION}-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda12-gdrcopy2-nccl2.17-x86_64.tbz -O hpcx.tbz && \
tar xf hpcx.tbz && \
mv hpcx-${HPCX_VERSION}-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda12-gdrcopy2-nccl2.17-x86_64 hpcx && \
rm hpcx.tbz
# Install Intel MLC
RUN cd /tmp && \
wget -q https://downloadmirror.intel.com/736634/mlc_v3.9a.tgz -O mlc.tgz && \
wget -q https://downloadmirror.intel.com/763324/mlc_v3.10.tgz -O mlc.tgz && \
tar xzf mlc.tgz Linux/mlc && \
cp ./Linux/mlc /usr/local/bin/ && \
rm -rf ./Linux mlc.tgz
ENV PATH="${PATH}" \
LD_LIBRARY_PATH="/usr/local/lib:${LD_LIBRARY_PATH}" \
SB_HOME=/opt/superbench \
SB_MICRO_PATH=/opt/superbench \
ANSIBLE_DEPRECATION_WARNINGS=FALSE \
ANSIBLE_COLLECTIONS_PATH=/usr/share/ansible/collections
RUN echo PATH="$PATH" > /etc/environment && \
echo LD_LIBRARY_PATH="$LD_LIBRARY_PATH" >> /etc/environment && \
echo SB_MICRO_PATH="$SB_MICRO_PATH" >> /etc/environment
# Install AOCC compiler
RUN cd /tmp && \
wget https://download.amd.com/developer/eula/aocc-compiler/aocc-compiler-4.0.0_1_amd64.deb && \
@@ -115,6 +105,18 @@ RUN cd /tmp && \
mv amd-blis /opt/AMD && \
rm -rf aocl-blis-linux-aocc-4.0.tar.gz
ENV PATH="${PATH}" \
LD_LIBRARY_PATH="/usr/local/lib:${LD_LIBRARY_PATH}" \
SB_HOME=/opt/superbench \
SB_MICRO_PATH=/opt/superbench \
ANSIBLE_DEPRECATION_WARNINGS=FALSE \
ANSIBLE_COLLECTIONS_PATH=/usr/share/ansible/collections
RUN echo PATH="$PATH" > /etc/environment && \
echo LD_LIBRARY_PATH="$LD_LIBRARY_PATH" >> /etc/environment && \
echo SB_MICRO_PATH="$SB_MICRO_PATH" >> /etc/environment
# Add config files
ADD dockerfile/etc /opt/microsoft/

@@ -29,7 +29,7 @@ You need to [clone the code](./development.md#set-up) first before building the
export DOCKER_BUILDKIT=1
docker buildx build \
--platform linux/amd64 --cache-to type=inline,mode=max \
--tag superbench-dev --file dockerfile/cuda11.1.1.dockerfile .
--tag superbench-dev --file dockerfile/cuda12.1.dockerfile .
```
</TabItem>
@@ -39,7 +39,7 @@ docker buildx build \
export DOCKER_BUILDKIT=1
docker buildx build \
--platform linux/amd64 --cache-to type=inline,mode=max \
--tag superbench-dev --file dockerfile/rocm4.2-pytorch1.7.0.dockerfile .
--tag superbench-dev --file dockerfile/rocm5.1.x.dockerfile .
```
</TabItem>

@@ -45,7 +45,7 @@ but it is not strictly necessary.
```bash
# create a new virtual environment
python3 -m venv --system-site-packages ./venv
python3 -m venv ./venv
# activate the virtual environment
source ./venv/bin/activate
@@ -61,7 +61,7 @@ You can clone the source from GitHub and build it.
:::note Note
You should checkout corresponding tag to use release version, for example,
`git clone -b v0.7.0 https://github.com/microsoft/superbenchmark`
`git clone -b v0.8.0 https://github.com/microsoft/superbenchmark`
:::
```bash

@@ -27,7 +27,7 @@ sb deploy -f remote.ini --host-password [password]
:::note Note
You should deploy corresponding Docker image to use release version, for example,
`sb deploy -f local.ini -i superbench/superbench:v0.7.0-cuda11.1.1`
`sb deploy -f local.ini -i superbench/superbench:v0.8.0-cuda12.1`
Note that the version of the git repo only determines the version of the sb CLI, not the version of the sb container. You must specify the container version explicitly, even if you checked out a release tag when cloning.

@@ -70,7 +70,7 @@ superbench:
<TabItem value='example'>
```yaml
version: v0.7
version: v0.8
superbench:
enable: benchmark_1
monitor:

@@ -180,11 +180,11 @@ Performed by [High-Performance Linpack Benchmark for Distributed-Memory Computer
#### Metrics
| Name | Unit | Description |
|---------------------|--------------------|----------------------------------------------------------------------------|
| cpu-hpl/tests_pass | | HPL completed running and correctness test has passed (1: pass, 0: fail). |
| cpu-hpl/throughput | bandwidth (GFlops) | Compute bandwidth. |
| cpu-hpl/time | time (s) | Time elapsed during HPL run. |
| Name | Unit | Description |
|--------------------|--------------------|---------------------------------------------------------------------------|
| cpu-hpl/tests_pass | | HPL completed running and correctness test has passed (1: pass, 0: fail). |
| cpu-hpl/throughput | bandwidth (GFlops) | Compute bandwidth. |
| cpu-hpl/time | time (s) | Time elapsed during HPL run. |
### `cpu-stream`
@@ -216,13 +216,13 @@ performed by [Intel MLC Tool](https://www.intel.com/content/www/us/en/developer/
| Name | Unit | Description |
|-------------------------------------------------------------------------|------------------|---------------------------------------------------------------------|
| cpu-memory-bw-latency/mem\_bandwidth\_matrix\_numa\_[0-9]+\_[0-9]+\_bw | bandwidth (GB/s) | Former NUMA to latter NUMA memory bandwidth. |
| cpu-memory-bw-latency/mem\_bandwidth\_matrix\_numa\_[0-9]+\_[0-9]+\_lat | time (us) | Former NUMA to latter NUMA memory latency. |
| cpu-memory-bw-latency/mem\_max\_bandwidth\_all\_reads\_bw | bandwidth (GB/s) | Whole-CPU maximum memory bandwidth, full read. |
| cpu-memory-bw-latency/mem\_max\_bandwidth\_3_1\_reads-writes\_bw | bandwidth (GB/s) | Whole-CPU maximum memory bandwidth, read : write = 3 : 1. |
| cpu-memory-bw-latency/mem\_max\_bandwidth\_2_1\_reads-writes\_bw | bandwidth (GB/s) | Whole-CPU maximum memory bandwidth, read : write = 2 : 1. |
| cpu-memory-bw-latency/mem\_max\_bandwidth\_1_1\_reads-writes\_bw | bandwidth (GB/s) | Whole-CPU maximum memory bandwidth, read : write = 1 : 1. |
| cpu-memory-bw-latency/mem\_max\_bandwidth\_stream-triad\_like\_bw | bandwidth (GB/s) | Whole-CPU maximum memory bandwidth, with stream-triad like pattern. |
| cpu-memory-bw-latency/mem\_bandwidth\_matrix\_numa\_[0-9]+\_[0-9]+\_bw | bandwidth (MB/s) | Former NUMA to latter NUMA memory bandwidth. |
| cpu-memory-bw-latency/mem\_bandwidth\_matrix\_numa\_[0-9]+\_[0-9]+\_lat | time (ns) | Former NUMA to latter NUMA memory latency. |
| cpu-memory-bw-latency/mem\_max\_bandwidth\_all\_reads\_bw | bandwidth (MB/s) | Whole-CPU maximum memory bandwidth, full read. |
| cpu-memory-bw-latency/mem\_max\_bandwidth\_3_1\_reads-writes\_bw | bandwidth (MB/s) | Whole-CPU maximum memory bandwidth, read : write = 3 : 1. |
| cpu-memory-bw-latency/mem\_max\_bandwidth\_2_1\_reads-writes\_bw | bandwidth (MB/s) | Whole-CPU maximum memory bandwidth, read : write = 2 : 1. |
| cpu-memory-bw-latency/mem\_max\_bandwidth\_1_1\_reads-writes\_bw | bandwidth (MB/s) | Whole-CPU maximum memory bandwidth, read : write = 1 : 1. |
| cpu-memory-bw-latency/mem\_max\_bandwidth\_stream-triad\_like\_bw | bandwidth (MB/s) | Whole-CPU maximum memory bandwidth, with stream-triad like pattern. |
### `mem-bw`

@@ -29,6 +29,8 @@ available tags are listed below for all stable versions.
| Tag | Description |
|-------------------|------------------------------------|
| v0.8.0-cuda12.1 | SuperBench v0.8.0 with CUDA 12.1 |
| v0.8.0-cuda11.1.1 | SuperBench v0.8.0 with CUDA 11.1.1 |
| v0.7.0-cuda11.8 | SuperBench v0.7.0 with CUDA 11.8 |
| v0.7.0-cuda11.1.1 | SuperBench v0.7.0 with CUDA 11.1.1 |
| v0.6.0-cuda11.1.1 | SuperBench v0.6.0 with CUDA 11.1.1 |
@@ -43,6 +45,10 @@ available tags are listed below for all stable versions.
| Tag | Description |
|-------------------------------|--------------------------------------------------|
| v0.8.0-rocm5.1.3 | SuperBench v0.8.0 with ROCm 5.1.3 |
| v0.8.0-rocm5.1.1 | SuperBench v0.8.0 with ROCm 5.1.1 |
| v0.8.0-rocm5.0.1 | SuperBench v0.8.0 with ROCm 5.0.1 |
| v0.8.0-rocm5.0 | SuperBench v0.8.0 with ROCm 5.0 |
| v0.7.0-rocm5.1.3 | SuperBench v0.7.0 with ROCm 5.1.3 |
| v0.7.0-rocm5.1.1 | SuperBench v0.7.0 with ROCm 5.1.1 |
| v0.7.0-rocm5.0.1 | SuperBench v0.7.0 with ROCm 5.0.1 |

@@ -65,7 +65,7 @@ superbench:
example:
```yaml
# SuperBench rules
version: v0.7
version: v0.8
superbench:
rules:
failure-rule:

@@ -58,7 +58,7 @@ superbench:
```yaml title="Example"
# SuperBench rules
version: v0.7
version: v0.8
superbench:
rules:
kernel_launch:

@@ -6,5 +6,5 @@
Provide hardware and software benchmarks for AI systems.
"""
__version__ = '0.7.0'
__version__ = '0.8.0'
__author__ = 'Microsoft'

@@ -31,11 +31,13 @@ def statistic(raw_data_df):
logger.warning('DataAnalyzer: empty data.')
return data_statistics_df
try:
raw_data_df = raw_data_df.apply(pd.to_numeric, errors='coerce')
raw_data_df = raw_data_df.dropna(axis=1, how='all')
data_statistics_df = raw_data_df.describe()
data_statistics_df.loc['1%'] = raw_data_df.quantile(0.01)
data_statistics_df.loc['5%'] = raw_data_df.quantile(0.05)
data_statistics_df.loc['95%'] = raw_data_df.quantile(0.95)
data_statistics_df.loc['99%'] = raw_data_df.quantile(0.99)
data_statistics_df.loc['1%'] = raw_data_df.quantile(0.01, numeric_only=True)
data_statistics_df.loc['5%'] = raw_data_df.quantile(0.05, numeric_only=True)
data_statistics_df.loc['95%'] = raw_data_df.quantile(0.95, numeric_only=True)
data_statistics_df.loc['99%'] = raw_data_df.quantile(0.99, numeric_only=True)
statistics_error = []
for column in list(raw_data_df.columns):
if column not in list(data_statistics_df.columns) and not raw_data_df[column].isnull().all():
@@ -122,6 +124,8 @@ def correlation(raw_data_df):
logger.warning('DataAnalyzer: empty data.')
return data_corr_df
try:
raw_data_df = raw_data_df.apply(pd.to_numeric, errors='coerce')
raw_data_df = raw_data_df.dropna(axis=1, how='all')
data_corr_df = raw_data_df.corr()
statistics_error = []
for column in list(raw_data_df.columns):
@@ -181,6 +185,8 @@ def generate_baseline(raw_data_df, output_dir):
output_dir (str): the directory of output file
"""
try:
raw_data_df = raw_data_df.apply(pd.to_numeric, errors='coerce')
raw_data_df = raw_data_df.dropna(axis=1, how='all')
if not isinstance(raw_data_df, pd.DataFrame):
logger.error('DataAnalyzer: the type of raw data is not pd.DataFrame')
return
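The coercion pattern repeated in all three functions above guards against a pandas API change: newer pandas no longer silently skips non-numeric columns in `describe()`/`quantile()`. A minimal sketch of the step (column names are illustrative, not from the repo):

```python
import pandas as pd

raw = pd.DataFrame({
    'kernel-launch/event_time': ['0.0055', '0.0057', '0.0056'],  # numbers stored as strings
    'hostname': ['node-0', 'node-1', 'node-2'],                  # genuinely non-numeric
})

raw = raw.apply(pd.to_numeric, errors='coerce')  # strings -> floats, text -> NaN
raw = raw.dropna(axis=1, how='all')              # drop columns that became all NaN
stats = raw.describe()
stats.loc['99%'] = raw.quantile(0.99, numeric_only=True)
print(stats)
```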

@@ -285,7 +285,7 @@ class DataDiagnosis(RuleBase):
logger.log_and_raise(exception=IOError, msg='DataDiagnosis: excel_data_output - invalid file path.')
file_handler.output_excel_raw_data(writer, raw_data_df, 'Raw Data')
file_handler.output_excel_data_not_accept(writer, data_not_accept_df, rules)
writer.save()
writer.close()
except Exception as e:
logger.log_and_raise(exception=Exception, msg='DataDiagnosis: excel_data_output - {}'.format(str(e)))

@@ -117,7 +117,7 @@ class ResultSummary(RuleBase):
summary_df = pd.DataFrame()
for category in summary:
for i in range(len(summary[category])):
summary_df = summary_df.append([summary[category][i]], ignore_index=True)
summary_df = pd.concat([summary_df, pd.DataFrame([summary[category][i]])], ignore_index=True)
return summary_df
def _generate_summary(self, round):
@@ -217,7 +217,7 @@ class ResultSummary(RuleBase):
file_handler.merge_column_in_excel(worksheet, row, 1)
else:
logger.error('ResultSummary: excel_data_output - summary is empty.')
writer.save()
writer.close()
except Exception as e:
logger.error('ResultSummary: excel_data_output - {}'.format(str(e)))
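These hunks track pandas 2.0 removals: `DataFrame.append()` is gone (`pd.concat()` is the replacement, as above) and `ExcelWriter.save()` was dropped in favor of `close()`. A minimal sketch of both (data values are illustrative):

```python
import pandas as pd

summary_df = pd.DataFrame()
for row in [{'metric': 'gpu_power', 'max': 291}, {'metric': 'gpu_temperature', 'max': 80}]:
    # pd.concat replaces the removed DataFrame.append
    summary_df = pd.concat([summary_df, pd.DataFrame([row])], ignore_index=True)

# ExcelWriter.save() no longer exists; close() or a context manager flushes the file.
with pd.ExcelWriter('summary.xlsx') as writer:
    summary_df.to_excel(writer, sheet_name='Summary', index=False)
```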

@@ -88,20 +88,21 @@ template <typename T> cudaDataType_t get_datatype() {
}
template <typename Ta, typename Tb, typename Tout>
float timing_matmul_tn(int m, int n, int k, int batch, int warmup, int iter) {
float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, int iter) {
// init matrix
Ta *matrix_a = nullptr;
Tb *matrix_b = nullptr;
Tout *matrix_out = nullptr;
cudaMalloc(&matrix_a, m * k * std::max(batch, 1) * sizeof(Ta));
cudaMalloc(&matrix_b, k * n * std::max(batch, 1) * sizeof(Tb));
cudaMalloc(&matrix_out, m * n * std::max(batch, 1) * sizeof(Tout));
batch = std::max<size_t>(batch, 1);
cudaMalloc(&matrix_a, m * k * batch * sizeof(Ta));
cudaMalloc(&matrix_b, k * n * batch * sizeof(Tb));
cudaMalloc(&matrix_out, m * n * batch * sizeof(Tout));
init_matrix<Ta><<<216, 1024>>>(matrix_a, 1.f, m * k * std::max(batch, 1));
init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * std::max(batch, 1));
init_matrix<Ta><<<216, 1024>>>(matrix_a, 1.f, m * k * batch);
init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * batch);
// init gemm
int lda = k, ldb = k, ldd = m;
size_t lda = k, ldb = k, ldd = m;
std::unique_ptr<cublasLtGemm> gemm = std::make_unique<cublasLtGemm>();
gemm->Init();
gemm->Setup(m, n, k, batch, lda, ldb, ldd, get_datatype<Ta>(), get_datatype<Tb>(), get_datatype<Tout>(),
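For context on #503: with `int` parameters, the byte-count products wrap in 32-bit arithmetic before `cudaMalloc` ever sees them. A compilable sketch of the failure mode (dimensions are illustrative):

```cpp
#include <cstdio>

int main() {
    int m = 65536, k = 65536;  // large but plausible GEMM dimensions
    // m * k is evaluated in int and overflows (undefined behavior; wraps to 0
    // on typical platforms) before the promotion to size_t by sizeof().
    size_t bad = m * k * sizeof(float);
    // Promoting first keeps the full 64-bit value, as the size_t signature does.
    size_t good = (size_t)m * k * sizeof(float);
    printf("bad: %zu, good: %zu\n", bad, good);
    return 0;
}
```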

@@ -5,12 +5,12 @@
void cublasLtGemm::Init() {
cublasLtHandle_t handle;
checkCublasStatus(cublasLtCreate(&handle));
CUBLAS_CHECK(cublasLtCreate(&handle));
handle_.reset(handle);
/* preference can be initialized without arguments */
cublasLtMatmulPreference_t preference;
checkCublasStatus(cublasLtMatmulPreferenceCreate(&preference));
CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference));
preference_.reset(preference);
}
@@ -24,32 +24,32 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
// force c_type
cudaDataType_t c_type = d_type;
// Create matrix descriptors.
checkCublasStatus(
CUBLAS_CHECK(
cublasLtMatrixLayoutCreate(&a_desc, a_type, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
checkCublasStatus(
CUBLAS_CHECK(
cublasLtMatrixLayoutCreate(&b_desc, b_type, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
checkCublasStatus(cublasLtMatrixLayoutCreate(&c_desc, c_type, m, n, ldd));
checkCublasStatus(cublasLtMatrixLayoutCreate(&d_desc, d_type, m, n, ldd));
CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&c_desc, c_type, m, n, ldd));
CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&d_desc, d_type, m, n, ldd));
// strided batch gemm
if (batch > 0) {
int64_t stridea = m * k, strideb = k * n, stridec = m * n, strided = m * n;
checkCublasStatus(
CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stridea, sizeof(stridea)));
checkCublasStatus(
CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea,
sizeof(stridea)));
CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&strideb, sizeof(strideb)));
checkCublasStatus(
CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb,
sizeof(strideb)));
CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stridec, sizeof(stridec)));
checkCublasStatus(
CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec,
sizeof(stridec)));
CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&strided, sizeof(strided)));
CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strided,
sizeof(strided)));
}
a_desc_.reset(a_desc);
b_desc_.reset(b_desc);
@@ -64,7 +64,7 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
gemm_compute_type = CUBLAS_COMPUTE_64F;
cublasLtMatmulDesc_t op_desc = nullptr;
checkCublasStatus(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F));
CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F));
op_desc_.reset(op_desc);
if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) {
@@ -73,33 +73,31 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode));
}
checkCublasStatus(
cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
checkCublasStatus(
cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));
if (a_scale_inverse != nullptr) {
checkCublasStatus(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&a_scale_inverse, sizeof(a_scale_inverse)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&a_scale_inverse, sizeof(a_scale_inverse)));
}
if (b_scale_inverse != nullptr) {
checkCublasStatus(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&b_scale_inverse, sizeof(b_scale_inverse)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&b_scale_inverse, sizeof(b_scale_inverse)));
}
checkCublasStatus(
CUBLAS_CHECK(
cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
}
size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_size) {
checkCublasStatus(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size, sizeof(max_workspace_size)));
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size, sizeof(max_workspace_size)));
int found_algorithm_count = 0;
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
// Though we query all of possible algorithm, we will use the first later
checkCublasStatus(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
c_desc_.get(), d_desc_.get(), preference_.get(),
max_algorithm_count, results.data(), &found_algorithm_count));
CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count,
results.data(), &found_algorithm_count));
if (found_algorithm_count == 0) {
throw std::runtime_error("Unable to find any suitable algorithms");
}
@@ -111,13 +109,13 @@ size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_
void cublasLtGemm::Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta,
void *workspace, size_t workspace_size, cudaStream_t stream) {
checkCublasStatus(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */
matrix_a, /* A */
a_desc_.get(), matrix_b, /* B */
b_desc_.get(), static_cast<const void *>(&beta), /* beta */
matrix_c, /* C */
c_desc_.get(), matrix_d, /* D */
d_desc_.get(), &heuristic_results_.front().algo, /* algo */
workspace, /* workspace */
workspace_size, stream)); /* stream */
CUBLAS_CHECK(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */
matrix_a, /* A */
a_desc_.get(), matrix_b, /* B */
b_desc_.get(), static_cast<const void *>(&beta), /* beta */
matrix_c, /* C */
c_desc_.get(), matrix_d, /* D */
d_desc_.get(), &heuristic_results_.front().algo, /* algo */
workspace, /* workspace */
workspace_size, stream)); /* stream */
}

@@ -10,12 +10,14 @@
#include <cublasLt.h>
inline void checkCublasStatus(cublasStatus_t status) {
if (status != CUBLAS_STATUS_SUCCESS) {
printf("cuBLAS API failed with status %s\n", cublasGetStatusString(status));
throw std::logic_error("cuBLAS API failed");
}
}
#define CUBLAS_CHECK(func) \
do { \
cublasStatus_t status = func; \
if (status != CUBLAS_STATUS_SUCCESS) { \
printf("cuBLAS call %s failed at %s:%d '%s'\n", #func, __FILE__, __LINE__, cublasGetStatusString(status)); \
exit(EXIT_FAILURE); \
} \
} while (0)
class cublasLtGemm {
public:
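The switch from the `checkCublasStatus()` helper to the `CUBLAS_CHECK` macro buys failure context: only a macro can stringize the failing call via `#func` and capture `__FILE__`/`__LINE__` at the call site. A compilable analog with a stand-in status type so the sketch builds without CUDA (all names here are illustrative):

```cpp
#include <cstdio>
#include <cstdlib>

enum Status { OK, FAIL };
const char *status_string(Status s) { return s == OK ? "ok" : "fail"; }

// Same shape as CUBLAS_CHECK above: stringize the call, report where it failed.
#define CHECK(func)                                                            \
    do {                                                                       \
        Status status = func;                                                  \
        if (status != OK) {                                                    \
            printf("call %s failed at %s:%d '%s'\n", #func, __FILE__,          \
                   __LINE__, status_string(status));                           \
            exit(EXIT_FAILURE);                                                \
        }                                                                      \
    } while (0)

Status do_work() { return FAIL; }

int main() {
    CHECK(do_work());  // prints: call do_work() failed at <file>:<line> 'fail'
    return 0;
}
```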

@@ -408,23 +408,21 @@ class CudnnBenchmark(MicroBenchmarkWithInvoke):
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
metric = ''
try:
lines = raw_output.splitlines()
metric = ''
cmd_config = json.loads(self._commands[cmd_idx].split('--config_json')[-1].replace(' ', '')[1:-1])
for key in sorted(cmd_config.keys()):
if 'name' in key:
metric = key + '_' + str(cmd_config[key]) + metric
else:
metric = metric + '_' + key + '_' + str(cmd_config[key])
metric = metric.replace(' ', '').replace(',', '_')
error = False
raw_data = []
for line in lines:
if '[function config]' in line:
metric = ''
metric_json_str = line[line.index('[function config]: ') +
len('[function config]: '):].replace(' ', '').replace(':', '_')[1:-1]
metric_list = metric_json_str.split(',')
for key in metric_list:
if 'name' in key:
metric = key + metric
else:
metric = metric + '_' + key
if '[raw_data]' in line:
raw_data = line[line.index('[raw_data]: ') + len('[raw_data]: '):]
raw_data = raw_data.split(',')

@@ -121,7 +121,7 @@ class DistInferenceModel(torch.nn.Module):
Return:
Tensor after all-gather.
"""
output = torch.empty_like([x.shape[0] * self.num_ranks] + list(x.shape[1:]))
output = torch.empty([x.shape[0] * self.num_ranks] + list(x.shape[1:]), dtype=x.dtype, device=x.device)
dist.all_gather_into_tensor(output, x)
return output
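The original call was a type error: `torch.empty_like()` takes a tensor prototype, not a size list. A minimal sketch of the corrected allocation (`num_ranks` value is illustrative):

```python
import torch

x = torch.randn(4, 8)
num_ranks = 2  # illustrative world size

# Broken: torch.empty_like([...]) raises, because a Python list is not a tensor.
# Fixed: torch.empty takes an explicit size, with dtype/device copied from x.
output = torch.empty([x.shape[0] * num_ranks] + list(x.shape[1:]),
                     dtype=x.dtype, device=x.device)
print(output.shape)  # torch.Size([8, 8])
```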

@@ -78,6 +78,13 @@ class ModelBenchmark(Benchmark):
required=False,
help='The number of batch size.',
)
self._parser.add_argument(
'--num_workers',
type=int,
default=8,
required=False,
help='Number of subprocesses to use for data loading.',
)
self._parser.add_argument(
'--precision',
type=Precision,

@@ -181,7 +181,7 @@ class PytorchBase(ModelBenchmark):
dataset=self._dataset,
batch_size=self._args.batch_size,
shuffle=False,
num_workers=8,
num_workers=self._args.num_workers,
sampler=train_sampler,
drop_last=True,
pin_memory=self._args.pin_memory
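With the hard-coded value removed, the data-loader worker count can be tuned per benchmark. A hypothetical config snippet (benchmark name and values are illustrative, not from the repo):

```yaml
version: v0.8
superbench:
  benchmarks:
    resnet_models:
      enable: true
      parameters:
        batch_size: 32
        num_workers: 4  # subprocesses for the PyTorch DataLoader
```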

@@ -72,6 +72,22 @@ class DeviceManager:
temp = None
return temp
def get_device_power(self, idx):
"""Get the realtime power of device, unit: watt.
Args:
idx (int): device index.
Return:
temp (float): the realtime power of device, None means failed to get the data.
"""
try:
power = nvml.nvmlDeviceGetPowerUsage(self._device_handlers[idx])
except Exception as err:
logger.error('Get device power failed: {}'.format(str(err)))
return None
return int(int(power) / 1000)
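For context on #507: NVML reports power draw in milliwatts, hence the division by 1000 before logging watts. A standalone sketch, assuming the pynvml package and GPU index 0:

```python
import pynvml as nvml

nvml.nvmlInit()
handle = nvml.nvmlDeviceGetHandleByIndex(0)
power_mw = nvml.nvmlDeviceGetPowerUsage(handle)  # NVML returns milliwatts
print('GPU 0 power: {} W'.format(int(power_mw / 1000)))
nvml.nvmlShutdown()
```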
def get_device_power_limit(self, idx):
"""Get the power management limit of device, unit: watt.

@@ -182,15 +182,14 @@ def gen_traffic_pattern_host_groups(host_list, pattern, mpi_pattern_path, benchm
logger.error('Unsupported traffic pattern: {}'.format(pattern.type))
host_groups = __convert_config_to_host_group(config, host_list)
# write traffic pattern host groups to specified path
if pattern.mpi_pattern:
with open(mpi_pattern_path, 'a') as f:
f.write('benchmark_name: {} pattern_type: {}'.format(benchmark_name, pattern.type) + '\n')
for host_group in host_groups:
row = []
for host_list in host_group:
group = ','.join(host_list)
row.append(group)
group = ';'.join(row)
f.write(group + '\n')
f.write('\n')
with open(mpi_pattern_path, 'a') as f:
f.write('benchmark_name: {} pattern_type: {}'.format(benchmark_name, pattern.type) + '\n')
for host_group in host_groups:
row = []
for host_list in host_group:
group = ','.join(host_list)
row.append(group)
group = ';'.join(row)
f.write(group + '\n')
f.write('\n')
return host_groups
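With the `pattern.mpi_pattern` guard removed, every pattern is appended to the file: one header line per benchmark, then one line per host group, with comma-joined host lists separated by semicolons. For a hypothetical pair-wise pattern over four hosts (names illustrative), the written block would look like:

```
benchmark_name: nccl-bw pattern_type: pair-wise
node0,node1;node2,node3
node0,node2;node1,node3
node0,node3;node1,node2
```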

@@ -3,7 +3,7 @@
# Server:
# - Product: HPE Apollo 6500
version: v0.7
version: v0.8
superbench:
enable: null
var:

@@ -4,7 +4,7 @@
# - Product: G482-Z53
# - Link: https://www.gigabyte.cn/FileUpload/Global/MicroSite/553/G482-Z53.html
version: v0.7
version: v0.8
superbench:
enable: null
var:

@@ -1,4 +1,4 @@
version: v0.7
version: v0.8
superbench:
enable: null
monitor:

@@ -1,4 +1,4 @@
version: v0.7
version: v0.8
superbench:
enable: null
monitor:

@@ -1,4 +1,4 @@
version: v0.7
version: v0.8
superbench:
enable: null
monitor:

@@ -3,7 +3,7 @@
# Azure NDm A100 v4
# reference: https://docs.microsoft.com/en-us/azure/virtual-machines/ndm-a100-v4-series
version: v0.7
version: v0.8
superbench:
enable: null
monitor:

@@ -1,5 +1,5 @@
# SuperBench Config
version: v0.7
version: v0.8
superbench:
enable: null
monitor:

@@ -1,5 +1,5 @@
# SuperBench Config
version: v0.7
version: v0.8
superbench:
enable: null
monitor:

@@ -38,16 +38,7 @@ class Monitor(multiprocessing.Process):
self.__unit_MiByte = 1024 * 1024 * 1.0
self.__output_handler = open(self.__output_file, 'a')
self.__cgroup = 1
output = run_command('grep cgroup /proc/filesystems', quiet=True)
if output.returncode != 0:
logger.error('Failed to check the cgroup version, will assume using cgroup V1.')
else:
if 'cgroup2' in output.stdout:
self.__cgroup = 2
logger.info('cgroup version: {}.'.format(self.__cgroup))
def __preprocess(self):
"""Preprocess/preparation operations before the monitoring.
@@ -77,13 +68,15 @@
container_pid = output.stdout
try:
if self.__cgroup == 1:
self._cpu_file = glob.glob('/sys/fs/cgroup/cpuacct/docker/{}*/cpuacct.stat'.format(container_id))[0]
cpu_file_cgroup_v1 = glob.glob('/sys/fs/cgroup/cpuacct/docker/{}*/cpuacct.stat'.format(container_id))
if len(cpu_file_cgroup_v1) > 0:
self._cpu_file = cpu_file_cgroup_v1[0]
self._mem_file = glob.glob(
'/sys/fs/cgroup/memory/docker/{}*/memory.usage_in_bytes'.format(container_id)
)[0]
self._net_file = '/proc/{}/net/dev'.format(container_pid)
else:
self.__cgroup = 2
self._cpu_file = glob.glob(
'/sys/fs/cgroup/system.slice/docker-{}*.scope/cpu.stat'.format(container_id)
)[0]
@@ -99,10 +92,12 @@
)
return False
else:
if self.__cgroup == 1:
self._cpu_file = '/sys/fs/cgroup/cpuacct/cpuacct.stat'
cpu_file_cgroup_v1 = '/sys/fs/cgroup/cpuacct/cpuacct.stat'
if os.path.exists(cpu_file_cgroup_v1):
self._cpu_file = cpu_file_cgroup_v1
self._mem_file = '/sys/fs/cgroup/memory/memory.usage_in_bytes'
else:
self.__cgroup = 2
self._cpu_file = '/sys/fs/cgroup/cpu.stat'
self._mem_file = '/sys/fs/cgroup/memory.stat'
self._net_file = '/proc/net/dev'
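The new logic replaces the upfront `/proc/filesystems` parsing (#502) with a probe-and-fall-back check: if the cgroup v1 accounting file is absent, assume v2. Distilled into a standalone sketch:

```python
import os

cpu_file_cgroup_v1 = '/sys/fs/cgroup/cpuacct/cpuacct.stat'
if os.path.exists(cpu_file_cgroup_v1):
    cgroup, cpu_file = 1, cpu_file_cgroup_v1         # legacy (v1) hierarchy
else:
    cgroup, cpu_file = 2, '/sys/fs/cgroup/cpu.stat'  # unified (v2) hierarchy
print('cgroup v{}: reading {}'.format(cgroup, cpu_file))
```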
@@ -199,6 +194,7 @@
for i in range(device_count):
record.gpu_usage.append(dm.device_manager.get_device_utilization(i))
record.gpu_temperature.append(dm.device_manager.get_device_temperature(i))
record.gpu_power.append(dm.device_manager.get_device_power(i))
record.gpu_power_limit.append(dm.device_manager.get_device_power_limit(i))
mem_used, mem_total = dm.device_manager.get_device_memory(i)
record.gpu_mem_used.append(mem_used)

@@ -14,6 +14,7 @@ class MonitorRecord:
"""Record class to save all monitoring data."""
reduce_ops = {
'gpu_temperature': ReduceType.MAX,
'gpu_power': ReduceType.MAX,
'gpu_power_limit': ReduceType.MIN,
'gpu_corrected_ecc': ReduceType.LAST,
'gpu_uncorrected_ecc': ReduceType.LAST,
@@ -28,6 +29,7 @@
self.__mem_total = None
self.__gpu_usage = list()
self.__gpu_temperature = list()
self.__gpu_power = list()
self.__gpu_power_limit = list()
self.__gpu_mem_used = list()
self.__gpu_mem_total = list()
@@ -112,6 +114,20 @@
"""
self.__gpu_temperature = gpu_temperature
@property
def gpu_power(self):
"""Decoration function to access __gpu_power."""
return self.__gpu_power
@gpu_power.setter
def gpu_power(self, gpu_power):
"""Set the gpu realtime power, unit: Watt.
Args:
gpu_power(list): list of gpu realtime power.
"""
self.__gpu_power = gpu_power
@property
def gpu_power_limit(self):
"""Decoration function to access __gpu_power_limit."""

@@ -387,8 +387,9 @@ class SuperBenchRunner():
metrics_dict[metric].append(value)
for metric, values in metrics_dict.items():
prefix = metric.split(':')[0]
for pattern, reduce_type in MonitorRecord.reduce_ops.items():
if pattern in metric:
if pattern == prefix:
reduce_func = Reducer.get_reduce_func(reduce_type)
metric_name = 'monitor/{}'.format(metric)
metrics_summary[metric_name] = reduce_func(values)
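This one-line change fixes a real aliasing bug: with the old substring test, the `gpu_power` reduce rule also matched `gpu_power_limit` metrics, reducing them with MAX instead of MIN. A quick demonstration:

```python
metric = 'gpu_power_limit:0'
prefix = metric.split(':')[0]

print('gpu_power' in metric)        # True  -- old substring check: wrong bucket
print('gpu_power' == prefix)        # False -- new exact-prefix check
print('gpu_power_limit' == prefix)  # True  -- correct rule wins
```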

@@ -167,6 +167,7 @@ def test_arguments_related_interfaces():
--no_gpu Disable GPU training.
--num_steps int The number of test step.
--num_warmup int The number of warmup step.
--num_workers int Number of subprocesses to use for data loading.
--pin_memory Enable option to pin memory in data loader.
--precision Precision [Precision ...]
Model precision. E.g. fp8_hybrid fp8_e4m3 fp8_e5m2
@@ -206,6 +207,7 @@
--no_gpu Disable GPU training.
--num_steps int The number of test step.
--num_warmup int The number of warmup step.
--num_workers int Number of subprocesses to use for data loading.
--pin_memory Enable option to pin memory in data loader.
--precision Precision [Precision ...]
Model precision. E.g. fp8_hybrid fp8_e4m3 fp8_e5m2

@@ -44,8 +44,8 @@ class MonitorTestCase(unittest.TestCase):
monitor._Monitor__sample_gpu_metrics(record)
gpu_list_metrics = [
record.gpu_usage, record.gpu_temperature, record.gpu_power_limit, record.gpu_mem_used, record.gpu_mem_total,
record.gpu_corrected_ecc, record.gpu_uncorrected_ecc
record.gpu_usage, record.gpu_temperature, record.gpu_power, record.gpu_power_limit, record.gpu_mem_used,
record.gpu_mem_total, record.gpu_corrected_ecc, record.gpu_uncorrected_ecc
]
for metric in gpu_list_metrics:
assert (metric)

@@ -17,6 +17,7 @@ def test_monitor_record():
mr.mem_total = 1024
mr.gpu_usage = [90, 80, 86, 72, 79, 81, 94, 85]
mr.gpu_temperature = [62, 75, 69, 63, 72, 77, 80, 71]
mr.gpu_power = [257, 290, 280, 262, 291, 284, 281, 273]
mr.gpu_power_limit = [400, 400, 400, 350, 400, 400, 400, 400]
mr.gpu_mem_used = [2550, 2680, 2543, 2588, 2612, 2603, 2515, 2593]
mr.gpu_mem_total = [16777216, 16777216, 16777216, 16777216, 16777216, 16777216, 16777216, 16777216]
@@ -59,6 +60,14 @@
'gpu_temperature:5': 77,
'gpu_temperature:6': 80,
'gpu_temperature:7': 71,
'gpu_power:0': 257,
'gpu_power:1': 290,
'gpu_power:2': 280,
'gpu_power:3': 262,
'gpu_power:4': 291,
'gpu_power:5': 284,
'gpu_power:6': 281,
'gpu_power:7': 273,
'gpu_power_limit:0': 400,
'gpu_power_limit:1': 400,
'gpu_power_limit:2': 400,

@@ -0,0 +1,44 @@
---
slug: release-sb-v0.8
title: Releasing SuperBench v0.8
author: Peng Cheng
author_title: SuperBench Team
author_url: https://github.com/cp5555
author_image_url: https://github.com/cp5555.png
tags: [superbench, announcement, release]
---
We are very happy to announce that **SuperBench 0.8.0** is officially released today!
You can install and try SuperBench by following the [Getting Started Tutorial](https://microsoft.github.io/superbenchmark/docs/getting-started/installation).
## SuperBench 0.8.0 Release Notes
### SuperBench Improvements
- Support SuperBench Executor running on Windows.
- Remove fixed rccl version in rocm5.1.x docker file.
- Upgrade networkx version to fix installation compatibility issue.
- Pin setuptools version to v65.7.0.
- Limit ansible_runner version for Python 3.6.
- Support cgroup V2 when reading system metrics in monitor.
- Fix analyzer bug in Python 3.8 due to pandas API change.
- Collect real-time GPU power in monitor.
- Remove unreachable condition when writing host list in MPI mode.
- Upgrade Docker image with CUDA 12.1, NCCL 2.17.1-1, HPC-X v2.14, and MLC v3.10.
- Fix wrong unit of cpu-memory-bw-latency in documentation.
### Micro-benchmark Improvements
- Add STREAM benchmark for sustainable memory bandwidth and the corresponding computation rate.
- Add HPL benchmark for the High-Performance Linpack workload.
- Support flexible warmup and non-random data initialization in cublas-benchmark.
- Support error tolerance in the cuDNN function micro-benchmark.
- Add distributed inference benchmark.
- Support tensor core precisions (e.g., FP8) and batch/shape ranges in cuBLASLt GEMM.
### Model Benchmark Improvements
- Fix torch.dist init issue with multiple models.
- Support TE FP8 in BERT/GPT-2 models.
- Add num_workers configuration in model benchmarks.

@@ -101,7 +101,7 @@ module.exports = {
announcementBar: {
id: 'supportus',
content:
'📢 <a href="https://microsoft.github.io/superbenchmark/blog/release-sb-v0.7">v0.7.0</a> has been released! ' +
'📢 <a href="https://microsoft.github.io/superbenchmark/blog/release-sb-v0.8">v0.8.0</a> has been released! ' +
'⭐️ If you like SuperBench, give it a star on <a target="_blank" rel="noopener noreferrer" href="https://github.com/microsoft/superbenchmark">GitHub</a>! ⭐️',
},
algolia: {

website/package-lock.json (generated)

@@ -1,6 +1,6 @@
{
"name": "superbench-website",
"version": "0.7.0",
"version": "0.8.0",
"lockfileVersion": 1,
"requires": true,
"dependencies": {

@@ -1,6 +1,6 @@
{
"name": "superbench-website",
"version": "0.7.0",
"version": "0.8.0",
"private": true,
"scripts": {
"docusaurus": "docusaurus",