**Description**

Cherry-pick bug fixes from v0.7.0 to main.

**Major Revisions**

* Benchmarks - Fix missing include in FP8 benchmark (#460)
* Fix bug in TE BERT model (#461)
* Doc - Update benchmark doc (#465)
* Bug: Fix bug for incorrect datatype judgement in cublas-function
source code (#464)
* Support `sb deploy` without pulling image (#466)
* Docs - Upgrade version and release note (#467)

Co-authored-by: Russell J. Hewett <russell.j.hewett@gmail.com>
Co-authored-by: Yuting Jiang <yutingjiang@microsoft.com>
This commit is contained in:
Yifan Xiong 2023-01-28 11:07:06 +08:00 committed by GitHub
Parent f380bc5eff
Commit b07fda155e
No known key found for this signature
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 121 additions and 52 deletions

View file

@ -15,7 +15,7 @@
__SuperBench__ is a validation and profiling tool for AI infrastructure.
📢 [v0.6.0](https://github.com/microsoft/superbenchmark/releases/tag/v0.6.0) has been released!
📢 [v0.7.0](https://github.com/microsoft/superbenchmark/releases/tag/v0.7.0) has been released!
## _Check [aka.ms/superbench](https://aka.ms/superbench) for more details._

View file

@ -97,6 +97,7 @@ sb deploy [--docker-image]
[--host-list]
[--host-password]
[--host-username]
[--no-image-pull]
[--output-dir]
[--private-key]
```
@ -112,6 +113,7 @@ sb deploy [--docker-image]
| `--host-list` `-l` | `None` | Comma separated host list. |
| `--host-password` | `None` | Host password or key passphase if needed. |
| `--host-username` | `None` | Host username if needed. |
| `--no-image-pull` | `False` | Skip pull and use local Docker image. |
| `--output-dir` | `None` | Path to output directory, outputs/{datetime} will be used if not specified. |
| `--private-key` | `None` | Path to private key if needed. |

View file

@ -61,7 +61,7 @@ You can clone the source from GitHub and build it.
:::note Note
You should checkout corresponding tag to use release version, for example,
`git clone -b v0.6.0 https://github.com/microsoft/superbenchmark`
`git clone -b v0.7.0 https://github.com/microsoft/superbenchmark`
:::
```bash

View file

@ -27,7 +27,7 @@ sb deploy -f remote.ini --host-password [password]
:::note Note
You should deploy corresponding Docker image to use release version, for example,
`sb deploy -f local.ini -i superbench/superbench:v0.6.0-cuda11.1.1`
`sb deploy -f local.ini -i superbench/superbench:v0.7.0-cuda11.1.1`
You should note that version of git repo only determines version of sb CLI, and not the sb container. You should define the container version even if you specified a release version for the git clone.

View file

@ -70,7 +70,7 @@ superbench:
<TabItem value='example'>
```yaml
version: v0.6
version: v0.7
superbench:
enable: benchmark_1
monitor:
@ -471,4 +471,3 @@ Available variables in formatted string includes:
+ `ibnetdiscover(str)`: the path of ibnetdiscover output `ibnetdiscover_file.txt`, required in `topo-aware` pattern.
+ `min_dist(int)`: minimum distance of VM pair, required in `topo-aware` pattern.
+ `max_dist(int)`: maximum distance of VM pair, required in `topo-aware` pattern.

View file

@ -66,9 +66,9 @@ Measure the GEMM performance of [`cublasLtMatmul`](https://docs.nvidia.com/cuda/
#### Metrics
| Name | Unit | Description |
|---------------------------------|----------------|---------------------------------|
| cublaslt-gemm/dtype_m_n_k_flops | FLOPS (TFLOPS) | TFLOPS of measured GEMM kernel. |
| Name | Unit | Description |
|------------------------------------------------|----------------|---------------------------------|
| cublaslt-gemm/${dtype}\_${m}\_${n}\_${k}_flops | FLOPS (TFLOPS) | TFLOPS of measured GEMM kernel. |
### `cublas-function`
@ -86,9 +86,11 @@ The supported functions for cuBLAS are as follows:
#### Metrics
| Name | Unit | Description |
|----------------------------------------------------------|-----------|-------------------------------------------------------------------|
| cublas-function/name_${function_name}_${parameters}_time | time (us) | The mean time to execute the cublas function with the parameters. |
| Name | Unit | Description |
|-------------------------------------------------------------------|-----------|----------------------------------------------------------------------------------------------------------------------------------------------|
| cublas-function/name\_${function_name}\_${parameters}_time | time (us) | The mean time to execute the cublas function with the parameters. |
| cublas-function/name\_${function_name}\_${parameters}_correctness | | Whether the calculation results of executing the cublas function with the parameters pass the correctness check if enable correctness check. |
| cublas-function/name\_${function_name}\_${parameters}_error | | The error ratio of the calculation results of executing the cublas function with the parameters if enable correctness check. |
### `cudnn-function`
@ -103,9 +105,9 @@ The supported functions for cuDNN are as follows:
#### Metrics
| Name | Unit | Description |
|---------------------------------------------------------|-----------|------------------------------------------------------------------|
| cudnn-function/name_${function_name}_${parameters}_time | time (us) | The mean time to execute the cudnn function with the parameters. |
| Name | Unit | Description |
|-----------------------------------------------------------|-----------|------------------------------------------------------------------|
| cudnn-function/name\_${function_name}\_${parameters}_time | time (us) | The mean time to execute the cudnn function with the parameters. |
### `tensorrt-inference`
@ -264,9 +266,10 @@ Support the following traffic patterns:
| rccl-bw/${operation}_${msg_size}_algbw | bandwidth (GB/s) | RCCL operation algorithm bandwidth with given message size. |
| rccl-bw/${operation}_${msg_size}_busbw | bandwidth (GB/s) | RCCL operation bus bandwidth with given message size. |
If traffic pattern is specified, the metrics pattern will change to `nccl-bw/${operation}_${serial_index)_${parallel_index):${msg_size}_time`
If mpi mode is enable and traffic pattern is specified, the metrics pattern will change to `nccl-bw/${operation}_${serial_index)_${parallel_index):${msg_size}_time`
- `serial_index` represents the serial index of the host group in serial.
- `parallel_index` represents the parallel index of the host list in parallel.
### `tcp-connectivity`
#### Introduction
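
The metric names in the tables above are patterns with `${...}` placeholders. As a rough illustration of how such a pattern expands into a concrete name, here is a minimal sketch using Python's `string.Template`, whose placeholder syntax happens to match. The helper `metric_name` and the field values are hypothetical and are not part of SuperBench; the Markdown escapes (`\_`) are dropped from the patterns.

```python
from string import Template

# Patterns copied from the metrics tables, with Markdown escapes removed.
CUBLASLT_FLOPS = Template("cublaslt-gemm/${dtype}_${m}_${n}_${k}_flops")
CUBLAS_TIME = Template("cublas-function/name_${function_name}_${parameters}_time")


def metric_name(pattern: Template, **fields) -> str:
    """Fill a metric-name pattern; raises KeyError if a field is missing."""
    return pattern.substitute({key: str(value) for key, value in fields.items()})


# Hypothetical field values, chosen only to illustrate the expansion.
gemm_metric = metric_name(CUBLASLT_FLOPS, dtype="fp16", m=4096, n=4096, k=4096)
time_metric = metric_name(CUBLAS_TIME, function_name="cublassgemm", parameters="example")

print(gemm_metric)
print(time_metric)
```

The real `${parameters}` string is produced by the benchmark itself, so the value used here is a placeholder only.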

View file

@ -30,19 +30,15 @@ including the following categories:
For inference, supported percentiles include
50<sup>th</sup>, 90<sup>th</sup>, 95<sup>th</sup>, 99<sup>th</sup>, and 99.9<sup>th</sup>.
**New: Support fp8_hybrid and fp8_e4m3 precision for BERT models.**
#### Metrics
| Name | Unit | Description |
|---------------------------------------------------------------------------------|------------------------|---------------------------------------------------------------------------|
| model-benchmarks/pytorch-${model_name}/fp32_train_step_time | time (ms) | The average training step time with single precision. |
| model-benchmarks/pytorch-${model_name}/fp32_train_throughput | throughput (samples/s) | The average training throughput with single precision. |
| model-benchmarks/pytorch-${model_name}/fp32_inference_step_time | time (ms) | The average inference step time with single precision. |
| model-benchmarks/pytorch-${model_name}/fp32_inference_throughput | throughput (samples/s) | The average inference throughput with single precision. |
| model-benchmarks/pytorch-${model_name}/fp32_inference_step_time\_${percentile} | time (ms) | The n<sup>th</sup> percentile inference step time with single precision. |
| model-benchmarks/pytorch-${model_name}/fp32_inference_throughput\_${percentile} | throughput (samples/s) | The n<sup>th</sup> percentile inference throughput with single precision. |
| model-benchmarks/pytorch-${model_name}/fp16_train_step_time | time (ms) | The average training step time with half precision. |
| model-benchmarks/pytorch-${model_name}/fp16_train_throughput | throughput (samples/s) | The average training throughput with half precision. |
| model-benchmarks/pytorch-${model_name}/fp16_inference_step_time | time (ms) | The average inference step time with half precision. |
| model-benchmarks/pytorch-${model_name}/fp16_inference_throughput | throughput (samples/s) | The average inference throughput with half precision. |
| model-benchmarks/pytorch-${model_name}/fp16_inference_step_time\_${percentile} | time (ms) | The n<sup>th</sup> percentile inference step time with half precision. |
| model-benchmarks/pytorch-${model_name}/fp16_inference_throughput\_${percentile} | throughput (samples/s) | The n<sup>th</sup> percentile inference throughput with half precision. |
| Name | Unit | Description |
|-----------------------------------------------------------------------------------------|------------------------|------------------------------------------------------------------------------|
| model-benchmarks/pytorch-${model_name}/${precision}_train_step_time | time (ms) | The average training step time with fp32/fp16 precision. |
| model-benchmarks/pytorch-${model_name}/${precision}_train_throughput | throughput (samples/s) | The average training throughput with fp32/fp16 precision. |
| model-benchmarks/pytorch-${model_name}/${precision}_inference_step_time | time (ms) | The average inference step time with fp32/fp16 precision. |
| model-benchmarks/pytorch-${model_name}/${precision}_inference_throughput | throughput (samples/s) | The average inference throughput with fp32/fp16 precision. |
| model-benchmarks/pytorch-${model_name}/${precision}_inference_step_time\_${percentile} | time (ms) | The n<sup>th</sup> percentile inference step time with fp32/fp16 precision. |
| model-benchmarks/pytorch-${model_name}/${precision}_inference_throughput\_${percentile} | throughput (samples/s) | The n<sup>th</sup> percentile inference throughput with fp32/fp16 precision. |
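
The rewritten table folds the per-precision rows into a single `${precision}` placeholder. A small sketch of the names that this placeholder stands for, assuming only the two precisions named in the descriptions (fp32 and fp16) and leaving out the percentile variants; the function `model_metric_names` and the model name are illustrative, not SuperBench code:

```python
from itertools import product
from typing import List

PRECISIONS = ("fp32", "fp16")    # precisions named in the table descriptions
PHASES = ("train", "inference")
KINDS = ("step_time", "throughput")


def model_metric_names(model_name: str) -> List[str]:
    """Expand ${precision} (and phase/kind) into concrete metric names."""
    prefix = f"model-benchmarks/pytorch-{model_name}"
    return [
        f"{prefix}/{precision}_{phase}_{kind}"
        for precision, phase, kind in product(PRECISIONS, PHASES, KINDS)
    ]


# "bert-large" is used here purely as a placeholder model name.
names = model_metric_names("bert-large")
for name in names:
    print(name)
```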

View file

@ -29,6 +29,8 @@ available tags are listed below for all stable versions.
| Tag | Description |
|-------------------|------------------------------------|
| v0.7.0-cuda11.8 | SuperBench v0.7.0 with CUDA 11.8 |
| v0.7.0-cuda11.1.1 | SuperBench v0.7.0 with CUDA 11.1.1 |
| v0.6.0-cuda11.1.1 | SuperBench v0.6.0 with CUDA 11.1.1 |
| v0.5.0-cuda11.1.1 | SuperBench v0.5.0 with CUDA 11.1.1 |
| v0.4.0-cuda11.1.1 | SuperBench v0.4.0 with CUDA 11.1.1 |
@ -41,6 +43,10 @@ available tags are listed below for all stable versions.
| Tag | Description |
|-------------------------------|--------------------------------------------------|
| v0.7.0-rocm5.1.3 | SuperBench v0.7.0 with ROCm 5.1.3 |
| v0.7.0-rocm5.1.1 | SuperBench v0.7.0 with ROCm 5.1.1 |
| v0.7.0-rocm5.0.1 | SuperBench v0.7.0 with ROCm 5.0.1 |
| v0.7.0-rocm5.0 | SuperBench v0.7.0 with ROCm 5.0 |
| v0.6.0-rocm5.1.3 | SuperBench v0.6.0 with ROCm 5.1.3 |
| v0.6.0-rocm5.1.1 | SuperBench v0.6.0 with ROCm 5.1.1 |
| v0.6.0-rocm5.0.1 | SuperBench v0.6.0 with ROCm 5.0.1 |

View file

@ -65,7 +65,7 @@ superbench:
example:
```yaml
# SuperBench rules
version: v0.6
version: v0.7
superbench:
rules:
failure-rule:

View file

@ -58,7 +58,7 @@ superbench:
```yaml title="Example"
# SuperBench rules
version: v0.6
version: v0.7
superbench:
rules:
kernel_launch:

View file

@ -6,5 +6,5 @@
Provide hardware and software benchmarks for AI systems.
"""
__version__ = '0.6.0'
__version__ = '0.7.0'
__author__ = 'Microsoft'

View file

@ -4,6 +4,7 @@
#pragma once
#include <memory>
#include <stdexcept>
#include <stdio.h>
#include <vector>

View file

@ -61,15 +61,18 @@ class TeBertBenchmarkModel(torch.nn.Module):
self._embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
# Build BERT using nn.TransformerEncoderLayer or te.TransformerLayer
# input shape: (seq_len, batch_size, hidden_size)
encoder_layer = te.TransformerLayer(
config.hidden_size,
config.intermediate_size,
config.num_attention_heads,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
layer_type='encoder',
self._encoder_layers = torch.nn.ModuleList(
[
te.TransformerLayer(
config.hidden_size,
config.intermediate_size,
config.num_attention_heads,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
layer_type='encoder',
) for _ in range(config.num_hidden_layers)
]
)
self._encoder_layers = torch.nn.ModuleList([encoder_layer for _ in range(config.num_hidden_layers)])
# BertPooler used in huggingface transformers
# https://github.com/huggingface/transformers/blob/accad48e/src/transformers/models/bert/modeling_bert.py#L893
self._pooler = torch.nn.Sequential(
@ -113,7 +116,6 @@ class PytorchBERT(PytorchBase):
Precision.FLOAT16,
Precision.FP8_HYBRID,
Precision.FP8_E4M3,
Precision.FP8_E5M2,
]
self._optimizer_type = Optimizer.ADAMW
self._loss_fn = torch.nn.CrossEntropyLoss()
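
The first hunk in this file replaces a list that repeats one `te.TransformerLayer` instance with a list comprehension that constructs a new layer on every iteration. Repeating a single object in a list makes every entry the same object, so all "layers" would share one set of parameters. The sketch below shows the aliasing with a plain stand-in class; `Layer` is hypothetical and only mimics an object with mutable state, it is not the Transformer Engine layer.

```python
class Layer:
    """Stand-in for a layer with trainable state (not te.TransformerLayer)."""

    def __init__(self):
        self.weight = [0.0]


num_layers = 3

# Old pattern: one instance repeated, so every slot aliases the same object.
shared = Layer()
aliased_layers = [shared for _ in range(num_layers)]

# New pattern: the constructor runs once per iteration.
independent_layers = [Layer() for _ in range(num_layers)]

# "Update" only the first layer in each list.
aliased_layers[0].weight[0] = 1.0
independent_layers[0].weight[0] = 1.0

print(len({id(layer) for layer in aliased_layers}))      # number of distinct objects
print(len({id(layer) for layer in independent_layers}))  # number of distinct objects
print(aliased_layers[2].weight, independent_layers[2].weight)
```

In the aliased list the update to the first entry is visible through the last entry, while the independent list keeps the last entry unchanged.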

View file

@ -44,6 +44,7 @@ class SuperBenchCommandsLoader(CLICommandsLoader):
ac.argument('docker_username', type=str, help='Docker registry username if authentication is needed.')
ac.argument('docker_password', type=str, help='Docker registry password if authentication is needed.')
ac.argument('no_docker', action='store_true', help='Run on host directly without Docker.')
ac.argument('no_image_pull', action='store_true', help='Skip pull and use local Docker image.')
ac.argument(
'host_file', options_list=('--host-file', '-f'), type=str, help='Path to Ansible inventory host file.'
)

View file

@ -99,6 +99,7 @@ def process_runner_arguments(
docker_username=None,
docker_password=None,
no_docker=False,
no_image_pull=False,
host_file=None,
host_list=None,
host_username=None,
@ -115,6 +116,7 @@ def process_runner_arguments(
docker_username (str, optional): Docker registry username if authentication is needed. Defaults to None.
docker_password (str, optional): Docker registry password if authentication is needed. Defaults to None.
no_docker (bool, optional): Run on host directly without Docker. Defaults to False.
no_image_pull (bool, optional): Skip pull and use local Docker image. Defaults to False.
host_file (str, optional): Path to Ansible inventory host file. Defaults to None.
host_list (str, optional): Comma separated host list. Defaults to None.
host_username (str, optional): Host username if needed. Defaults to None.
@ -149,6 +151,7 @@ def process_runner_arguments(
'password': docker_password,
'registry': split_docker_domain(docker_image)[0],
'skip': no_docker,
'pull': not no_image_pull,
}
)
# Ansible config
@ -209,6 +212,7 @@ def deploy_command_handler(
docker_image='superbench/superbench',
docker_username=None,
docker_password=None,
no_image_pull=False,
host_file=None,
host_list=None,
host_username=None,
@ -228,6 +232,7 @@ def deploy_command_handler(
docker_image (str, optional): Docker image URI. Defaults to superbench/superbench:latest.
docker_username (str, optional): Docker registry username if authentication is needed. Defaults to None.
docker_password (str, optional): Docker registry password if authentication is needed. Defaults to None.
no_image_pull (bool, optional): Skip pull and use local Docker image. Defaults to False.
host_file (str, optional): Path to Ansible inventory host file. Defaults to None.
host_list (str, optional): Comma separated host list. Defaults to None.
host_username (str, optional): Host username if needed. Defaults to None.
@ -243,6 +248,7 @@ def deploy_command_handler(
docker_username=docker_username,
docker_password=docker_password,
no_docker=False,
no_image_pull=no_image_pull,
host_file=host_file,
host_list=host_list,
host_username=host_username,
@ -298,6 +304,7 @@ def run_command_handler(
docker_username=docker_username,
docker_password=docker_password,
no_docker=no_docker,
no_image_pull=False,
host_file=host_file,
host_list=host_list,
host_username=host_username,
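
The hunks above thread a negative flag (`no_image_pull`) through the handlers and store its inverse as `'pull'` in the Docker config. A minimal sketch of that inversion, using the standard-library `argparse` as a stand-in for the CLI framework the project actually uses; `build_parser` and `docker_config` are hypothetical names introduced for illustration:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Stand-in parser; the real CLI registers its arguments differently."""
    parser = argparse.ArgumentParser(prog="sb-deploy-sketch")
    parser.add_argument("--docker-image", default="superbench/superbench")
    parser.add_argument(
        "--no-image-pull",
        action="store_true",
        help="Skip pull and use local Docker image.",
    )
    return parser


def docker_config(args: argparse.Namespace) -> dict:
    """Mirror the diff: the stored 'pull' value is the negation of the flag."""
    return {"image": args.docker_image, "pull": not args.no_image_pull}


default_cfg = docker_config(build_parser().parse_args([]))
local_cfg = docker_config(build_parser().parse_args(["--no-image-pull"]))

print(default_cfg["pull"], local_cfg["pull"])
```

Because the flag is `store_true`, omitting it keeps the previous behavior (the image is pulled), which matches `run_command_handler` passing `no_image_pull=False` unconditionally.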

View file

@ -3,7 +3,7 @@
# Server:
# - Product: HPE Apollo 6500
version: v0.6
version: v0.7
superbench:
enable: null
var:

View file

@ -4,7 +4,7 @@
# - Product: G482-Z53
# - Link: https://www.gigabyte.cn/FileUpload/Global/MicroSite/553/G482-Z53.html
version: v0.6
version: v0.7
superbench:
enable: null
var:

View file

@ -1,4 +1,4 @@
version: v0.6
version: v0.7
superbench:
enable: null
monitor:

View file

@ -1,4 +1,4 @@
version: v0.6
version: v0.7
superbench:
enable: null
monitor:

View file

@ -1,4 +1,4 @@
version: v0.6
version: v0.7
superbench:
enable: null
monitor:

View file

@ -3,7 +3,7 @@
# Azure NDm A100 v4
# reference: https://docs.microsoft.com/en-us/azure/virtual-machines/ndm-a100-v4-series
version: v0.6
version: v0.7
superbench:
enable: null
monitor:

View file

@ -1,5 +1,5 @@
# SuperBench Config
version: v0.6
version: v0.7
superbench:
enable: null
monitor:

View file

@ -1,5 +1,5 @@
# SuperBench Config
version: v0.6
version: v0.7
superbench:
enable: null
monitor:

View file

@ -92,6 +92,7 @@
shell: |
docker pull {{ docker_image }}
become: yes
when: docker_pull | default(true)
throttle: 32
- name: Starting Container
shell: |

View file

@ -183,6 +183,7 @@ class SuperBenchRunner():
'ssh_port': random.randint(1 << 14, (1 << 15) - 1),
'output_dir': str(self._output_path),
'docker_image': self._docker_config.image,
'docker_pull': bool(self._docker_config.pull),
}
if bool(self._docker_config.username) and bool(self._docker_config.password):
extravars.update(
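
The runner now passes `docker_pull` as an extra variable, and the playbook guards the pull task with `when: docker_pull | default(true)`. Jinja's `default` filter only substitutes when the variable is undefined, so a missing variable still means "pull". A rough Python emulation of that condition, where `should_pull` is a hypothetical helper and not part of SuperBench or Ansible:

```python
def should_pull(extravars: dict) -> bool:
    """Mimic `when: docker_pull | default(true)`: a missing key means pull."""
    return bool(extravars.get("docker_pull", True))


# Variable not passed at all: the task would still run.
print(should_pull({"docker_image": "superbench/superbench"}))
# Variable passed explicitly, as the runner does after this change.
print(should_pull({"docker_pull": True}))
print(should_pull({"docker_pull": False}))
```

Keeping the default at true means a playbook invoked without the new variable would behave as it did before this change.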

View file

@ -60,6 +60,12 @@ class SuperBenchCLIScenarioTest(ScenarioTest):
mocked_failure_count.return_value = 0
self.cmd('sb deploy --host-list localhost', checks=[NoneCheck()])
@mock.patch('superbench.runner.SuperBenchRunner.get_failure_count')
def test_sb_deploy_skippull(self, mocked_failure_count):
"""Test sb deploy without docker pull."""
mocked_failure_count.return_value = 0
self.cmd('sb deploy --host-list localhost --no-image-pull', checks=[NoneCheck()])
def test_sb_deploy_no_host(self):
"""Test sb deploy, no host_file or host_list provided, should fail."""
self.cmd('sb deploy', expect_failure=True)

View file

@ -0,0 +1,44 @@
---
slug: release-sb-v0.7
title: Releasing SuperBench v0.7
author: Peng Cheng
author_title: SuperBench Team
author_url: https://github.com/cp5555
author_image_url: https://github.com/cp5555.png
tags: [superbench, announcement, release]
---
We are very happy to announce that **SuperBench 0.7.0 version** is officially released today!
You can install and try superbench by following [Getting Started Tutorial](https://microsoft.github.io/superbenchmark/docs/getting-started/installation).
## SuperBench 0.7.0 Release Notes
### SuperBench Improvement
- Support non-zero return code when "sb deploy" or "sb run" fails in Ansible.
- Support log flushing to the result file during runtime.
- Update version to include revision hash and date.
- Support "pattern" in mpi mode to run tasks in parallel.
- Support topo-aware, all-pair, and K-batch pattern in mpi mode.
- Fix Transformers version to avoid Tensorrt failure.
- Add CUDA11.8 Docker image for NVIDIA arch90 GPUs.
- Support "sb deploy" without pulling image.
### Micro-benchmark Improvements
- Support list of custom config string in cudnn-functions and cublas-functions.
- Support correctness check in cublas-functions.
- Support GEMM-FLOPS for NVIDIA arch90 GPUs.
- Support cuBLASLt FP16 and FP8 GEMM.
- Add wait time option to resolve mem-bw unstable issue.
- Fix bug for incorrect datatype judgement in cublas-function source code.
### Model Benchmark Improvements
- Support FP8 in BERT model training.
### Distributed Benchmark Improvements
- Support pair-wise pattern in IB validation benchmark.
- Support topo-aware, pair-wise, and K-batch pattern in nccl-bw benchmark.

View file

@ -101,7 +101,7 @@ module.exports = {
announcementBar: {
id: 'supportus',
content:
'📢 <a href="https://microsoft.github.io/superbenchmark/blog/release-sb-v0.6">v0.6.0</a> has been released! ' +
'📢 <a href="https://microsoft.github.io/superbenchmark/blog/release-sb-v0.7">v0.7.0</a> has been released! ' +
'⭐️ If you like SuperBench, give it a star on <a target="_blank" rel="noopener noreferrer" href="https://github.com/microsoft/superbenchmark">GitHub</a>! ⭐️',
},
algolia: {

2
website/package-lock.json (generated)
View file

@ -1,6 +1,6 @@
{
"name": "superbench-website",
"version": "0.6.0",
"version": "0.7.0",
"lockfileVersion": 1,
"requires": true,
"dependencies": {

View file

@ -1,6 +1,6 @@
{
"name": "superbench-website",
"version": "0.6.0",
"version": "0.7.0",
"private": true,
"scripts": {
"docusaurus": "docusaurus",