Benchmarks: Add Benchmark - Add onnx model benchmarks based on docker image. (#227)

Add the RocmOnnxRuntimeModelBenchmark class to run the benchmarks packaged in the superbench/benchmark:rocm4.3.1-onnxruntime1.9.0 docker image.
This commit is contained in:
guoshzhao 2021-10-27 18:41:40 +08:00 коммит произвёл GitHub
Родитель 6003f2c2a2
Коммит e98a68124e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 185 добавлений и 3 удалений

Просмотреть файл

@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Docker benchmark example for onnxruntime models.
Commands to run:
python3 examples/benchmarks/rocm_onnxruntime_model_benchmark.py
"""
from superbench.benchmarks import BenchmarkRegistry, Framework, Platform
from superbench.common.utils import logger
if __name__ == '__main__':
    # Describe which benchmark to run: the 'ort-models' benchmark on the ROCm
    # platform with the onnxruntime framework.
    context = BenchmarkRegistry.create_benchmark_context(
        'ort-models',
        platform=Platform.ROCM,
        framework=Framework.ONNXRUNTIME,
    )

    benchmark = BenchmarkRegistry.launch_benchmark(context)

    # Only report when launch_benchmark returned a truthy benchmark object.
    if benchmark:
        summary = 'benchmark: {}, return code: {}, result: {}'.format(
            benchmark.name, benchmark.return_code, benchmark.result
        )
        logger.info(summary)

Просмотреть файл

@ -5,5 +5,6 @@
from superbench.benchmarks.docker_benchmarks.docker_base import DockerBenchmark, CudaDockerBenchmark, \
RocmDockerBenchmark
from superbench.benchmarks.docker_benchmarks.rocm_onnxruntime_performance import RocmOnnxRuntimeModelBenchmark
# NOTE(review): two consecutive assignments — this looks like a diff hunk whose
# +/- markers were lost (presumably the first line is the removed version; confirm).
# Read as plain Python, only the second assignment takes effect; it is the one
# that also exports RocmOnnxRuntimeModelBenchmark.
__all__ = ['DockerBenchmark', 'CudaDockerBenchmark', 'RocmDockerBenchmark']
__all__ = ['DockerBenchmark', 'CudaDockerBenchmark', 'RocmDockerBenchmark', 'RocmOnnxRuntimeModelBenchmark']

Просмотреть файл

@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Module of the onnxruntime E2E model benchmarks.
Including:
bert-large-uncased ngpu=1
bert-large-uncased ngpu=8
distilbert-base-uncased ngpu=1
distilbert-base-uncased ngpu=8
gpt2 ngpu=1
gpt2 ngpu=8
facebook/bart-large ngpu=1
facebook/bart-large ngpu=8
roberta-large ngpu=1
roberta-large ngpu=8
"""
from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.benchmarks.docker_benchmarks.docker_base import RocmDockerBenchmark
class RocmOnnxRuntimeModelBenchmark(RocmDockerBenchmark):
    """The onnxruntime E2E model benchmark class.

    Configures the docker image, container name and entrypoint of the benchmark,
    and parses the ``samples_per_second`` values found in its raw output into
    one result per reported model.
    """
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

        # Image uri of the current docker-benchmark.
        self._image_uri = 'superbench/benchmark:rocm4.3.1-onnxruntime1.9.0'

        # Image digest of the current docker-benchmark.
        self._digest = 'f5e6c832e3cdcbba9820c619bb30fa47ca7117aa7f2c15944d17e6983d37ab9a'

        # Container name of the current docker-benchmark.
        self._container_name = 'rocm-onnxruntime-model-benchmarks'

        # Entrypoint option of the current docker-benchmark.
        self._entrypoint = '/stage/onnxruntime-training-examples/huggingface/azureml/run_benchmark.sh'

        # CMD option of the current docker-benchmark.
        self._cmd = None

    def _process_raw_result(self, cmd_idx, raw_output):
        """Function to parse raw results and save the summarized results.

        self._result.add_raw_data() and self._result.add_result() need to be called to save the results.

        The output is scanned line by line. A line containing the
        '__superbench__ begin ' marker starts a new model section; the text after
        the marker, with '-', ' ', '=' and '/' replaced by '_', becomes the metric
        name. The next line containing the "samples_per_second" marker supplies the
        throughput for that model. Value lines seen before any model marker are ignored.

        Args:
            cmd_idx (int): the index of command corresponding with the raw_output.
            raw_output (str): raw output string of the micro-benchmark.

        Return:
            True if the raw output string is valid and result can be extracted,
            False if a value could not be parsed or saved.
        """
        self._result.add_raw_data('raw_output', raw_output)

        name_prefix = '__superbench__ begin '
        value_prefix = ' "samples_per_second": '
        try:
            model_name = None
            for line in raw_output.splitlines(False):
                # Take the text after the marker wherever it occurs in the line.
                # Slicing from column 0 would mis-parse lines that carry leading
                # whitespace or any other prefix before the marker.
                _, name_marker, name_text = line.partition(name_prefix)
                if name_marker:
                    model_name = name_text
                    for char in ('-', ' ', '=', '/'):
                        model_name = model_name.replace(char, '_')
                    continue

                # A throughput value is only meaningful inside a model section.
                if model_name is None:
                    continue

                _, value_marker, value_text = line.partition(value_prefix)
                if value_marker:
                    throughput = float(value_text)
                    self._result.add_result(model_name, throughput)
                    # Wait for the next model marker before accepting another value.
                    model_name = None
        except Exception as e:
            # Exception (not BaseException) so KeyboardInterrupt/SystemExit still propagate.
            logger.error(
                'The result format is invalid - round: {}, benchmark: {}, message: {}.'.format(
                    self._curr_run_index, self._name, str(e)
                )
            )
            return False

        return True
# Make the benchmark class available in the registry under the
# 'onnxruntime-ort-models' name for the ROCm platform.
BenchmarkRegistry.register_benchmark(
    'onnxruntime-ort-models',
    RocmOnnxRuntimeModelBenchmark,
    platform=Platform.ROCM,
)

Просмотреть файл

@ -81,6 +81,12 @@ superbench:
mem_type:
- dtoh
- htod
ort-models:
enable: false
modes:
- name: local
frameworks:
- onnxruntime
gpt_models:
<<: *default_pytorch_mode
models:

Просмотреть файл

@ -53,8 +53,8 @@ superbench:
gemm-flops:
<<: *default_local_mode
parameters:
m: 7680
n: 8192
m: 7680
n: 8192
k: 8192
ib-loopback:
enable: true
@ -82,6 +82,12 @@ superbench:
mem_type:
- dtoh
- htod
ort-models:
enable: false
modes:
- name: local
frameworks:
- onnxruntime
gpt_models:
<<: *default_pytorch_mode
models:

Просмотреть файл

@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tests for RocmOnnxRuntimeModelBenchmark modules."""
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, Platform, ReturnCode
from superbench.benchmarks.result import BenchmarkResult
def test_rocm_onnxruntime_performance():
    """Check the registration, docker settings and output parsing of the onnxruntime model benchmark."""
    benchmark_name = 'onnxruntime-ort-models'

    # The registry lookup yields a (class, predefined parameters) pair; only the class is used here.
    selection = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.ROCM)
    benchmark_class, _ = selection
    assert benchmark_class

    # Verify the docker settings assigned by the constructor.
    benchmark = benchmark_class(benchmark_name)
    assert benchmark._benchmark_type == BenchmarkType.DOCKER
    assert benchmark._image_uri == 'superbench/benchmark:rocm4.3.1-onnxruntime1.9.0'
    assert benchmark._container_name == 'rocm-onnxruntime-model-benchmarks'
    assert benchmark._entrypoint == '/stage/onnxruntime-training-examples/huggingface/azureml/run_benchmark.sh'
    assert benchmark._cmd is None

    # Feed a canned output to the parser and compare the extracted throughputs.
    benchmark._result = BenchmarkResult(benchmark._name, benchmark._benchmark_type, ReturnCode.SUCCESS)
    raw_output = """
__superbench__ begin bert-large-uncased ngpu=1
"samples_per_second": 21.829
__superbench__ begin bert-large-uncased ngpu=8
"samples_per_second": 147.181
__superbench__ begin distilbert-base-uncased ngpu=1
"samples_per_second": 126.827
__superbench__ begin distilbert-base-uncased ngpu=8
"samples_per_second": 966.796
__superbench__ begin gpt2 ngpu=1
"samples_per_second": 20.46
__superbench__ begin gpt2 ngpu=8
"samples_per_second": 151.089
__superbench__ begin facebook/bart-large ngpu=1
"samples_per_second": 66.171
__superbench__ begin facebook/bart-large ngpu=8
"samples_per_second": 370.343
__superbench__ begin roberta-large ngpu=1
"samples_per_second": 37.103
__superbench__ begin roberta-large ngpu=8
"samples_per_second": 274.455
"""
    assert benchmark._process_raw_result(0, raw_output)

    expected_throughput = {
        'bert_large_uncased_ngpu_1': 21.829,
        'bert_large_uncased_ngpu_8': 147.181,
        'distilbert_base_uncased_ngpu_1': 126.827,
        'distilbert_base_uncased_ngpu_8': 966.796,
        'gpt2_ngpu_1': 20.46,
        'gpt2_ngpu_8': 151.089,
        'facebook_bart_large_ngpu_1': 66.171,
        'facebook_bart_large_ngpu_8': 370.343,
        'roberta_large_ngpu_1': 37.103,
        'roberta_large_ngpu_8': 274.455,
    }
    for metric, throughput in expected_throughput.items():
        assert benchmark.result[metric][0] == throughput