Benchmarks: Add Benchmark - Add onnx model benchmarks based on docker image. (#227)
Add the RocmOnnxRuntimeModelBenchmark class to run the benchmarks packaged in the superbench/benchmark:rocm4.3.1-onnxruntime1.9.0 docker image.
This commit is contained in:
Родитель
6003f2c2a2
Коммит
e98a68124e
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Docker benchmark example for onnxruntime models.
|
||||
|
||||
Commands to run:
|
||||
python3 examples/benchmarks/rocm_onnxruntime_model_benchmark.py
|
||||
"""
|
||||
|
||||
from superbench.benchmarks import BenchmarkRegistry, Framework, Platform
|
||||
from superbench.common.utils import logger
|
||||
|
||||
if __name__ == '__main__':
    # Describe which benchmark to run: the 'ort-models' benchmark for the
    # onnxruntime framework on the ROCm platform.
    context = BenchmarkRegistry.create_benchmark_context(
        'ort-models',
        platform=Platform.ROCM,
        framework=Framework.ONNXRUNTIME,
    )

    # Launch it; only report when the launch returned a truthy benchmark object.
    benchmark = BenchmarkRegistry.launch_benchmark(context)
    if benchmark:
        summary = 'benchmark: {}, return code: {}, result: {}'.format(
            benchmark.name, benchmark.return_code, benchmark.result
        )
        logger.info(summary)
|
|
@ -5,5 +5,6 @@
|
|||
|
||||
from superbench.benchmarks.docker_benchmarks.docker_base import DockerBenchmark, CudaDockerBenchmark, \
|
||||
RocmDockerBenchmark
|
||||
from superbench.benchmarks.docker_benchmarks.rocm_onnxruntime_performance import RocmOnnxRuntimeModelBenchmark
|
||||
|
||||
__all__ = ['DockerBenchmark', 'CudaDockerBenchmark', 'RocmDockerBenchmark']
|
||||
__all__ = ['DockerBenchmark', 'CudaDockerBenchmark', 'RocmDockerBenchmark', 'RocmOnnxRuntimeModelBenchmark']
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Module of the onnxruntime E2E model benchmarks.
|
||||
|
||||
Including:
|
||||
bert-large-uncased ngpu=1
|
||||
bert-large-uncased ngpu=8
|
||||
distilbert-base-uncased ngpu=1
|
||||
distilbert-base-uncased ngpu=8
|
||||
gpt2 ngpu=1
|
||||
gpt2 ngpu=8
|
||||
facebook/bart-large ngpu=1
|
||||
facebook/bart-large ngpu=8
|
||||
roberta-large ngpu=1
|
||||
roberta-large ngpu=8
|
||||
"""
|
||||
|
||||
from superbench.common.utils import logger
|
||||
from superbench.benchmarks import BenchmarkRegistry, Platform
|
||||
from superbench.benchmarks.docker_benchmarks.docker_base import RocmDockerBenchmark
|
||||
|
||||
|
||||
class RocmOnnxRuntimeModelBenchmark(RocmDockerBenchmark):
    """The onnxruntime E2E model benchmark class.

    Configures the docker image, container name and entrypoint used for the
    ROCm onnxruntime model benchmarks, and parses the container output into
    one throughput result per reported model section.
    """
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

        # Image uri of the current docker-benchmark.
        self._image_uri = 'superbench/benchmark:rocm4.3.1-onnxruntime1.9.0'

        # Image digest of the current docker-benchmark.
        self._digest = 'f5e6c832e3cdcbba9820c619bb30fa47ca7117aa7f2c15944d17e6983d37ab9a'

        # Container name of the current docker-benchmark.
        self._container_name = 'rocm-onnxruntime-model-benchmarks'

        # Entrypoint option of the current docker-benchmark.
        self._entrypoint = '/stage/onnxruntime-training-examples/huggingface/azureml/run_benchmark.sh'

        # CMD option of the current docker-benchmark.
        self._cmd = None

    def _process_raw_result(self, cmd_idx, raw_output):
        """Function to parse raw results and save the summarized results.

        self._result.add_raw_data() and self._result.add_result() need to be called to save the results.

        The output is scanned line by line. A line containing
        '__superbench__ begin ' opens a section; the text after that marker
        (with '-', ' ', '=' and '/' replaced by '_') becomes the metric name.
        The next line containing ' "samples_per_second": ' supplies the
        throughput value for that section and closes it. Value lines seen
        outside an open section are ignored.

        Args:
            cmd_idx (int): the index of command corresponding with the raw_output.
            raw_output (str): raw output string of the docker benchmark.

        Return:
            True if the raw output string is valid and result can be extracted.
        """
        self._result.add_raw_data('raw_output', raw_output)

        name_prefix = '__superbench__ begin '
        value_prefix = ' "samples_per_second": '
        # One-pass replacement of the characters that are not wanted in metric names.
        sanitize = str.maketrans('- =/', '____')
        model_name = None

        try:
            for line in raw_output.splitlines():
                if name_prefix in line:
                    # Slice at the position where the marker was actually found:
                    # the match is a substring test, so the marker is not
                    # guaranteed to start at column 0.
                    model_name = line.partition(name_prefix)[2].strip().translate(sanitize)
                elif value_prefix in line and model_name is not None:
                    # Same reasoning: take whatever follows the matched prefix,
                    # regardless of how deeply the line is indented.
                    throughput = float(line.partition(value_prefix)[2])
                    self._result.add_result(model_name, throughput)
                    # Close the section so stray value lines are not attributed to it.
                    model_name = None
        except Exception as e:
            # Exception (not BaseException) so that KeyboardInterrupt and
            # SystemExit are not reported as a malformed result.
            logger.error(
                'The result format is invalid - round: {}, benchmark: {}, message: {}.'.format(
                    self._curr_run_index, self._name, str(e)
                )
            )
            return False

        return True
|
||||
|
||||
|
||||
# Register the class with the benchmark registry under the name
# 'onnxruntime-ort-models' for Platform.ROCM.
BenchmarkRegistry.register_benchmark('onnxruntime-ort-models', RocmOnnxRuntimeModelBenchmark, platform=Platform.ROCM)
|
|
@ -81,6 +81,12 @@ superbench:
|
|||
mem_type:
|
||||
- dtoh
|
||||
- htod
|
||||
ort-models:
|
||||
enable: false
|
||||
modes:
|
||||
- name: local
|
||||
frameworks:
|
||||
- onnxruntime
|
||||
gpt_models:
|
||||
<<: *default_pytorch_mode
|
||||
models:
|
||||
|
|
|
@ -53,8 +53,8 @@ superbench:
|
|||
gemm-flops:
|
||||
<<: *default_local_mode
|
||||
parameters:
|
||||
m: 7680
|
||||
n: 8192
|
||||
m: 7680
|
||||
n: 8192
|
||||
k: 8192
|
||||
ib-loopback:
|
||||
enable: true
|
||||
|
@ -82,6 +82,12 @@ superbench:
|
|||
mem_type:
|
||||
- dtoh
|
||||
- htod
|
||||
ort-models:
|
||||
enable: false
|
||||
modes:
|
||||
- name: local
|
||||
frameworks:
|
||||
- onnxruntime
|
||||
gpt_models:
|
||||
<<: *default_pytorch_mode
|
||||
models:
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Tests for RocmOnnxRuntimeModelBenchmark modules."""
|
||||
|
||||
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, Platform, ReturnCode
|
||||
from superbench.benchmarks.result import BenchmarkResult
|
||||
|
||||
|
||||
def test_rocm_onnxruntime_performance():
    """Test onnxruntime model benchmark."""
    benchmark_name = 'onnxruntime-ort-models'
    selected = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.ROCM)
    benchmark_class, predefine_params = selected
    assert (benchmark_class)

    # Check the static docker configuration set by the constructor.
    benchmark = benchmark_class(benchmark_name)
    assert (benchmark._benchmark_type == BenchmarkType.DOCKER)
    assert (benchmark._image_uri == 'superbench/benchmark:rocm4.3.1-onnxruntime1.9.0')
    assert (benchmark._container_name == 'rocm-onnxruntime-model-benchmarks')
    assert (benchmark._entrypoint == '/stage/onnxruntime-training-examples/huggingface/azureml/run_benchmark.sh')
    assert (benchmark._cmd is None)

    # Give the benchmark a result object so the parser has somewhere to store metrics.
    benchmark._result = BenchmarkResult(benchmark._name, benchmark._benchmark_type, ReturnCode.SUCCESS)

    raw_output = """
__superbench__ begin bert-large-uncased ngpu=1
    "samples_per_second": 21.829
__superbench__ begin bert-large-uncased ngpu=8
    "samples_per_second": 147.181
__superbench__ begin distilbert-base-uncased ngpu=1
    "samples_per_second": 126.827
__superbench__ begin distilbert-base-uncased ngpu=8
    "samples_per_second": 966.796
__superbench__ begin gpt2 ngpu=1
    "samples_per_second": 20.46
__superbench__ begin gpt2 ngpu=8
    "samples_per_second": 151.089
__superbench__ begin facebook/bart-large ngpu=1
    "samples_per_second": 66.171
__superbench__ begin facebook/bart-large ngpu=8
    "samples_per_second": 370.343
__superbench__ begin roberta-large ngpu=1
    "samples_per_second": 37.103
__superbench__ begin roberta-large ngpu=8
    "samples_per_second": 274.455
"""
    assert (benchmark._process_raw_result(0, raw_output))

    # Metric name -> throughput expected as the first recorded value.
    expected_throughput = {
        'bert_large_uncased_ngpu_1': 21.829,
        'bert_large_uncased_ngpu_8': 147.181,
        'distilbert_base_uncased_ngpu_1': 126.827,
        'distilbert_base_uncased_ngpu_8': 966.796,
        'gpt2_ngpu_1': 20.46,
        'gpt2_ngpu_8': 151.089,
        'facebook_bart_large_ngpu_1': 66.171,
        'facebook_bart_large_ngpu_8': 370.343,
        'roberta_large_ngpu_1': 37.103,
        'roberta_large_ngpu_8': 274.455,
    }
    for metric, value in expected_throughput.items():
        assert (benchmark.result[metric][0] == value)
|
Загрузка…
Ссылка в новой задаче