Benchmarks: Code Revision - Revise BenchmarkRegistry interfaces for integration with the executor. (#33)

* Revise BenchmarkRegistry interfaces: add the create_benchmark_context() factory, drop check_parameters(), and let launch_benchmark() validate the context itself, reporting ReturnCode.INVALID_ARGUMENT for unknown or invalid arguments (usage sketch below the commit metadata).
* Address review comments.

Co-authored-by: Guoshuai Zhao <guzhao@microsoft.com>
guoshzhao authored on 2021-04-08 23:17:03 +08:00, committed by GitHub
Parent 2871a68b62
Commit 923ce2773f
7 changed files: 144 additions and 159 deletions
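
Reconstructed from the first diff below, the revised flow builds the context through the registry's factory and hands it straight to launch_benchmark(), which now performs the validation that check_parameters() used to do. Treat this as a sketch of the bert-large example rather than an exact copy of the file:

from superbench.benchmarks import Framework, BenchmarkRegistry
from superbench.common.utils import logger

# Build the context through the registry factory instead of constructing
# BenchmarkContext directly; platform defaults to Platform.CPU if omitted.
context = BenchmarkRegistry.create_benchmark_context(
    'bert-large',
    parameters='--batch_size=1 --duration=120 --seq_len=512 --precision=float32 --run_count=2',
    framework=Framework.PYTORCH
)

# launch_benchmark() now does the validation that check_parameters() used to;
# it returns None only when the benchmark is not registered at all.
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
    logger.info(
        'benchmark: {}, return code: {}, result: {}'.format(
            benchmark.name, benchmark.return_code, benchmark.result
        )
    )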


@@ -3,25 +3,21 @@
"""Model benchmark example for bert-large."""
from superbench.benchmarks import Platform, Framework, BenchmarkRegistry, BenchmarkContext
from superbench.benchmarks import Framework, BenchmarkRegistry
from superbench.common.utils import logger
if __name__ == '__main__':
# Create context for bert-large benchmark and run it for 120 * 2 seconds.
context = BenchmarkContext(
context = BenchmarkRegistry.create_benchmark_context(
'bert-large',
Platform.CUDA,
parameters='--batch_size=1 --duration=120 --seq_len=512 --precision=float32 --run_count=2',
framework=Framework.PYTORCH
)
if BenchmarkRegistry.check_parameters(context):
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
logger.info(
'benchmark: {}, return code: {}, result: {}'.format(
benchmark.name, benchmark.return_code, benchmark.result
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
logger.info(
'benchmark: {}, return code: {}, result: {}'.format(
benchmark.name, benchmark.return_code, benchmark.result
)
else:
logger.error('bert-large benchmark does not exist or context/parameters are invalid.')
)


@@ -74,14 +74,14 @@ class Benchmark(ABC):
logger.error('Invalid argument - benchmark: {}, message: {}.'.format(self._name, str(e)))
return False, None, None
ret = True
if len(unknown) > 0:
logger.warning(
'Benchmark has unknown arguments - benchmark: {}, unknown arguments: {}'.format(
self._name, ' '.join(unknown)
)
logger.error(
'Unknown arguments - benchmark: {}, unknown arguments: {}'.format(self._name, ' '.join(unknown))
)
ret = False
return True, args, unknown
return ret, args, unknown
def _preprocess(self):
"""Preprocess/preparation operations before the benchmarking.


@@ -105,32 +105,19 @@ class BenchmarkRegistry:
return benchmark_name
@classmethod
def check_parameters(cls, benchmark_context):
"""Check the validation of customized parameters.
def create_benchmark_context(cls, name, platform=Platform.CPU, parameters='', framework=Framework.NONE):
"""Constructor.
Args:
benchmark_context (BenchmarkContext): the benchmark context.
name (str): name of benchmark in config file.
platform (Platform): Platform types like Platform.CPU, Platform.CUDA, Platform.ROCM.
parameters (str): predefined parameters of benchmark.
framework (Framework): Framework types like Framework.PYTORCH, Framework.ONNX.
Return:
Return True if benchmark exists and context/parameters are valid.
benchmark_context (BenchmarkContext): the benchmark context.
"""
if not cls.is_benchmark_context_valid(benchmark_context):
return False
benchmark_name = cls.__get_benchmark_name(benchmark_context)
platform = benchmark_context.platform
customized_parameters = benchmark_context.parameters
if benchmark_name:
(benchmark_class, params) = cls.__select_benchmark(benchmark_name, platform)
if benchmark_class:
benchmark = benchmark_class(benchmark_name, customized_parameters)
benchmark.add_parser_arguments()
ret, args, unknown = benchmark.parse_args()
if ret and len(unknown) < 1:
return True
return False
return BenchmarkContext(name, platform, parameters, framework)
@classmethod
def get_benchmark_configurable_settings(cls, benchmark_context):
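
Because platform, parameters and framework all default (Platform.CPU, '', Framework.NONE), the factory collapses the common case to a one-liner. A small sketch; 'accumulation' is just the test benchmark registered later in this commit, and any registered name works the same way:

from superbench.benchmarks import Platform, Framework, BenchmarkRegistry

# Both calls build the same default CPU context, per the factory above:
# BenchmarkContext(name, Platform.CPU, '', Framework.NONE).
ctx_default = BenchmarkRegistry.create_benchmark_context('accumulation')
ctx_explicit = BenchmarkRegistry.create_benchmark_context(
    'accumulation', platform=Platform.CPU, parameters='', framework=Framework.NONE
)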


@@ -3,8 +3,7 @@
"""Tests for BenchmarkRegistry module."""
from superbench.benchmarks import Platform, Framework, Precision, \
BenchmarkContext, BenchmarkRegistry, BenchmarkType, ReturnCode
from superbench.benchmarks import Platform, Framework, Precision, BenchmarkRegistry, BenchmarkType, ReturnCode
from superbench.benchmarks.model_benchmarks import ModelBenchmark
@@ -111,7 +110,9 @@ def create_benchmark(params='--num_steps=8'):
parameters='--hidden_size=2',
platform=Platform.CUDA,
)
context = BenchmarkContext('fake-model', Platform.CUDA, parameters=params, framework=Framework.PYTORCH)
context = BenchmarkRegistry.create_benchmark_context(
'fake-model', platform=Platform.CUDA, parameters=params, framework=Framework.PYTORCH
)
name = BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name(context)
assert (name)
(benchmark_class, predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(name, context.platform)
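
The registry tests above reach the private helpers through Python's class-private name mangling (a leading __name inside a class body becomes _ClassName__name). A tiny self-contained illustration; the Foo class is invented for the example:

class Foo:
    @classmethod
    def __helper(cls):      # mangled to _Foo__helper at class creation time
        return 42

# Outside the class the mangled spelling is required, which is exactly
# how the tests call BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name.
assert Foo._Foo__helper() == 42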


@@ -9,7 +9,7 @@ import numbers
import torch
from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision, Platform, BenchmarkContext, ReturnCode
from superbench.benchmarks import BenchmarkRegistry, Precision, ReturnCode
from superbench.benchmarks.model_benchmarks.model_base import Optimizer, DistributedImpl, DistributedBackend
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset
@@ -174,62 +174,58 @@ def test_pytorch_base():
# Register BERT Base benchmark.
BenchmarkRegistry.register_benchmark('pytorch-mnist', PytorchMNIST)
# Launch benchmark for testing.
context = BenchmarkContext(
# Launch benchmark with --no_gpu for testing.
context = BenchmarkRegistry.create_benchmark_context(
'pytorch-mnist',
Platform.CPU,
parameters='--batch_size=32 --num_warmup=8 --num_steps=64 --model_action train inference --no_gpu'
)
assert (BenchmarkRegistry.check_parameters(context))
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark)
assert (benchmark.name == 'pytorch-mnist')
assert (benchmark.return_code == ReturnCode.SUCCESS)
if BenchmarkRegistry.check_parameters(context):
benchmark = BenchmarkRegistry.launch_benchmark(context)
# Test results.
for metric in [
'steptime_train_float32', 'steptime_inference_float32', 'throughput_train_float32',
'throughput_inference_float32'
]:
assert (len(benchmark.raw_data[metric]) == 1)
assert (len(benchmark.raw_data[metric][0]) == 64)
assert (len(benchmark.result[metric]) == 1)
assert (isinstance(benchmark.result[metric][0], numbers.Number))
assert (benchmark.name == 'pytorch-mnist')
assert (benchmark.return_code == ReturnCode.SUCCESS)
# Test _cal_params_count().
assert (benchmark._cal_params_count() == 1199882)
# Test results.
for metric in [
'steptime_train_float32', 'steptime_inference_float32', 'throughput_train_float32',
'throughput_inference_float32'
]:
assert (len(benchmark.raw_data[metric]) == 1)
assert (len(benchmark.raw_data[metric][0]) == 64)
assert (len(benchmark.result[metric]) == 1)
assert (isinstance(benchmark.result[metric][0], numbers.Number))
# Test _judge_gpu_availability().
assert (benchmark._gpu_available is False)
# Test _cal_params_count().
assert (benchmark._cal_params_count() == 1199882)
# Test _init_distributed_setting().
assert (benchmark._args.distributed_impl is None)
assert (benchmark._args.distributed_backend is None)
assert (benchmark._init_distributed_setting() is True)
benchmark._args.distributed_impl = DistributedImpl.DDP
benchmark._args.distributed_backend = DistributedBackend.NCCL
assert (benchmark._init_distributed_setting() is False)
benchmark._args.distributed_impl = DistributedImpl.MIRRORED
assert (benchmark._init_distributed_setting() is False)
# Test _judge_gpu_availability().
assert (benchmark._gpu_available is False)
# Test _init_dataloader().
benchmark._args.distributed_impl = None
assert (benchmark._init_dataloader() is True)
benchmark._args.distributed_impl = DistributedImpl.DDP
assert (benchmark._init_dataloader() is False)
benchmark._args.distributed_impl = DistributedImpl.MIRRORED
assert (benchmark._init_dataloader() is False)
# Test _init_distributed_setting().
assert (benchmark._args.distributed_impl is None)
assert (benchmark._args.distributed_backend is None)
assert (benchmark._init_distributed_setting() is True)
benchmark._args.distributed_impl = DistributedImpl.DDP
benchmark._args.distributed_backend = DistributedBackend.NCCL
assert (benchmark._init_distributed_setting() is False)
benchmark._args.distributed_impl = DistributedImpl.MIRRORED
assert (benchmark._init_distributed_setting() is False)
# Test _init_dataloader().
benchmark._args.distributed_impl = None
assert (benchmark._init_dataloader() is True)
benchmark._args.distributed_impl = DistributedImpl.DDP
assert (benchmark._init_dataloader() is False)
benchmark._args.distributed_impl = DistributedImpl.MIRRORED
assert (benchmark._init_dataloader() is False)
# Test _create_optimizer().
assert (isinstance(benchmark._optimizer, torch.optim.AdamW))
benchmark._optimizer_type = Optimizer.ADAM
assert (benchmark._create_optimizer() is True)
assert (isinstance(benchmark._optimizer, torch.optim.Adam))
benchmark._optimizer_type = Optimizer.SGD
assert (benchmark._create_optimizer() is True)
assert (isinstance(benchmark._optimizer, torch.optim.SGD))
benchmark._optimizer_type = None
assert (benchmark._create_optimizer() is False)
# Test _create_optimizer().
assert (isinstance(benchmark._optimizer, torch.optim.AdamW))
benchmark._optimizer_type = Optimizer.ADAM
assert (benchmark._create_optimizer() is True)
assert (isinstance(benchmark._optimizer, torch.optim.Adam))
benchmark._optimizer_type = Optimizer.SGD
assert (benchmark._create_optimizer() is True)
assert (isinstance(benchmark._optimizer, torch.optim.SGD))
benchmark._optimizer_type = None
assert (benchmark._create_optimizer() is False)
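
The optimizer checks at the end of this hunk flip benchmark._optimizer_type between AdamW, Adam and SGD; underneath, that is only a mapping from an enum to a torch.optim class. A rough illustration that assumes nothing about SuperBench's internals beyond what the test shows (not SuperBench's actual implementation):

import torch

# Conceptual name-to-class mapping behind _create_optimizer() in the test above.
OPTIMIZERS = {
    'adamw': torch.optim.AdamW,
    'adam': torch.optim.Adam,
    'sgd': torch.optim.SGD,
}

model = torch.nn.Linear(4, 2)
opt = OPTIMIZERS['sgd'](model.parameters(), lr=0.01)
assert isinstance(opt, torch.optim.SGD)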


@@ -3,21 +3,20 @@
"""Tests for BERT model benchmarks."""
from superbench.benchmarks import BenchmarkRegistry, Precision, Platform, Framework, BenchmarkContext
from superbench.benchmarks import BenchmarkRegistry, Precision, Platform, Framework
import superbench.benchmarks.model_benchmarks.pytorch_bert as pybert
def test_pytorch_bert_base():
"""Test pytorch-bert-base benchmark."""
context = BenchmarkContext(
context = BenchmarkRegistry.create_benchmark_context(
'bert-base',
Platform.CUDA,
platform=Platform.CUDA,
parameters='--batch_size=32 --num_classes=5 --seq_len=512',
framework=Framework.PYTORCH
)
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
assert (BenchmarkRegistry.check_parameters(context))
benchmark_name = BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name(context)
assert (benchmark_name == 'pytorch-bert-base')
@@ -54,15 +53,14 @@ def test_pytorch_bert_base():
def test_pytorch_bert_large():
"""Test pytorch-bert-large benchmark."""
context = BenchmarkContext(
context = BenchmarkRegistry.create_benchmark_context(
'bert-large',
Platform.CUDA,
platform=Platform.CUDA,
parameters='--batch_size=32 --num_classes=5 --seq_len=512',
framework=Framework.PYTORCH
)
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
assert (BenchmarkRegistry.check_parameters(context))
benchmark_name = BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name(context)
assert (benchmark_name == 'pytorch-bert-large')


@@ -5,7 +5,7 @@
import re
from superbench.benchmarks import Platform, Framework, BenchmarkType, BenchmarkContext, BenchmarkRegistry, ReturnCode
from superbench.benchmarks import Platform, Framework, BenchmarkType, BenchmarkRegistry, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmark
@@ -60,21 +60,21 @@ def test_register_benchmark():
# Register the benchmark for all platform if use default platform.
BenchmarkRegistry.register_benchmark('accumulation', AccumulationBenchmark)
for platform in Platform:
context = BenchmarkContext('accumulation', platform)
context = BenchmarkRegistry.create_benchmark_context('accumulation', platform=platform)
assert (BenchmarkRegistry.is_benchmark_registered(context))
# Register the benchmark for CUDA platform if use platform=Platform.CUDA.
BenchmarkRegistry.register_benchmark('accumulation-cuda', AccumulationBenchmark, platform=Platform.CUDA)
context = BenchmarkContext('accumulation-cuda', Platform.CUDA)
context = BenchmarkRegistry.create_benchmark_context('accumulation-cuda', platform=Platform.CUDA)
assert (BenchmarkRegistry.is_benchmark_registered(context))
context = BenchmarkContext('accumulation-cuda', Platform.ROCM)
context = BenchmarkRegistry.create_benchmark_context('accumulation-cuda', platform=Platform.ROCM)
assert (BenchmarkRegistry.is_benchmark_registered(context) is False)
def test_is_benchmark_context_valid():
"""Test interface BenchmarkRegistry.is_benchmark_context_valid()."""
# Positive case.
context = BenchmarkContext('accumulation', Platform.CPU)
context = BenchmarkRegistry.create_benchmark_context('accumulation', platform=Platform.CPU)
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
# Negative case.
@@ -94,25 +94,13 @@ def test_get_benchmark_name():
# Test benchmark name for different Frameworks.
benchmark_frameworks = [Framework.NONE, Framework.PYTORCH, Framework.TENSORFLOW1, Framework.ONNX]
for i in range(len(benchmark_names)):
context = BenchmarkContext('accumulation', Platform.CPU, framework=benchmark_frameworks[i])
context = BenchmarkRegistry.create_benchmark_context(
'accumulation', platform=Platform.CPU, framework=benchmark_frameworks[i]
)
name = BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name(context)
assert (name == benchmark_names[i])
def test_check_parameters():
"""Test interface BenchmarkRegistry.check_parameters()."""
# Register benchmarks for testing.
BenchmarkRegistry.register_benchmark('accumulation', AccumulationBenchmark)
# Positive case.
context = BenchmarkContext('accumulation', Platform.CPU, parameters='--lower_bound=1')
assert (BenchmarkRegistry.check_parameters(context))
# Negative case.
context = BenchmarkContext('accumulation', Platform.CPU, parameters='--lower=1')
assert (BenchmarkRegistry.check_parameters(context) is False)
def test_get_benchmark_configurable_settings():
"""Test BenchmarkRegistry interface.
@@ -121,7 +109,7 @@ def test_get_benchmark_configurable_settings():
# Register benchmarks for testing.
BenchmarkRegistry.register_benchmark('accumulation', AccumulationBenchmark)
context = BenchmarkContext('accumulation', Platform.CPU)
context = BenchmarkRegistry.create_benchmark_context('accumulation', platform=Platform.CPU)
settings = BenchmarkRegistry.get_benchmark_configurable_settings(context)
expected = """optional arguments:
@@ -140,52 +128,71 @@ def test_launch_benchmark():
)
# Launch benchmark.
context = BenchmarkContext('accumulation', Platform.CPU, parameters='--lower_bound=1')
context = BenchmarkRegistry.create_benchmark_context(
'accumulation', platform=Platform.CPU, parameters='--lower_bound=1'
)
if BenchmarkRegistry.check_parameters(context):
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark)
assert (benchmark.name == 'accumulation')
assert (benchmark.type == BenchmarkType.MICRO)
assert (benchmark.run_count == 1)
assert (benchmark.return_code == ReturnCode.SUCCESS)
assert (benchmark.raw_data == {'accumulation_result': ['1,3,6,10']})
assert (benchmark.result == {'accumulation_result': [10]})
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark)
assert (benchmark.name == 'accumulation')
assert (benchmark.type == BenchmarkType.MICRO)
assert (benchmark.run_count == 1)
assert (benchmark.return_code == ReturnCode.SUCCESS)
assert (benchmark.raw_data == {'accumulation_result': ['1,3,6,10']})
assert (benchmark.result == {'accumulation_result': [10]})
# Replace the timestamp as null.
result = re.sub(r'\"\d+-\d+-\d+ \d+:\d+:\d+\"', 'null', benchmark.serialized_result)
expected = (
'{"name": "accumulation", "type": "micro", "run_count": 1, '
'"return_code": 0, "start_time": null, "end_time": null, '
'"raw_data": {"accumulation_result": ["1,3,6,10"]}, '
'"result": {"accumulation_result": [10]}}'
)
assert (result == expected)
# Replace the timestamp as null.
result = re.sub(r'\"\d+-\d+-\d+ \d+:\d+:\d+\"', 'null', benchmark.serialized_result)
expected = (
'{"name": "accumulation", "type": "micro", "run_count": 1, '
'"return_code": 0, "start_time": null, "end_time": null, '
'"raw_data": {"accumulation_result": ["1,3,6,10"]}, '
'"result": {"accumulation_result": [10]}}'
)
assert (result == expected)
# Launch benchmark with overridden parameters.
context = BenchmarkContext('accumulation', Platform.CPU, parameters='--lower_bound=1 --upper_bound=4')
if BenchmarkRegistry.check_parameters(context):
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark)
assert (benchmark.name == 'accumulation')
assert (benchmark.type == BenchmarkType.MICRO)
assert (benchmark.run_count == 1)
assert (benchmark.return_code == ReturnCode.SUCCESS)
assert (benchmark.raw_data == {'accumulation_result': ['1,3,6']})
assert (benchmark.result == {'accumulation_result': [6]})
context = BenchmarkRegistry.create_benchmark_context(
'accumulation', platform=Platform.CPU, parameters='--lower_bound=1 --upper_bound=4'
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark)
assert (benchmark.name == 'accumulation')
assert (benchmark.type == BenchmarkType.MICRO)
assert (benchmark.run_count == 1)
assert (benchmark.return_code == ReturnCode.SUCCESS)
assert (benchmark.raw_data == {'accumulation_result': ['1,3,6']})
assert (benchmark.result == {'accumulation_result': [6]})
# Replace the timestamp as null.
result = re.sub(r'\"\d+-\d+-\d+ \d+:\d+:\d+\"', 'null', benchmark.serialized_result)
expected = (
'{"name": "accumulation", "type": "micro", "run_count": 1, '
'"return_code": 0, "start_time": null, "end_time": null, '
'"raw_data": {"accumulation_result": ["1,3,6"]}, '
'"result": {"accumulation_result": [6]}}'
)
assert (result == expected)
# Replace the timestamp as null.
result = re.sub(r'\"\d+-\d+-\d+ \d+:\d+:\d+\"', 'null', benchmark.serialized_result)
expected = (
'{"name": "accumulation", "type": "micro", "run_count": 1, '
'"return_code": 0, "start_time": null, "end_time": null, '
'"raw_data": {"accumulation_result": ["1,3,6"]}, '
'"result": {"accumulation_result": [6]}}'
)
assert (result == expected)
# Failed to launch benchmark due to 'benchmark not found'.
context = BenchmarkContext(
context = BenchmarkRegistry.create_benchmark_context(
'accumulation-fail', Platform.CPU, parameters='--lower_bound=1 --upper_bound=4', framework=Framework.PYTORCH
)
assert (BenchmarkRegistry.check_parameters(context) is False)
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark is None)
# Failed to launch benchmark due to 'unknown arguments'.
context = BenchmarkRegistry.create_benchmark_context(
'accumulation', platform=Platform.CPU, parameters='--lower_bound=1 --test=4'
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark)
assert (benchmark.return_code == ReturnCode.INVALID_ARGUMENT)
# Failed to launch benchmark due to 'invalid arguments'.
context = BenchmarkRegistry.create_benchmark_context(
'accumulation', platform=Platform.CPU, parameters='--lower_bound=1 --upper_bound=x'
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark)
assert (benchmark.return_code == ReturnCode.INVALID_ARGUMENT)
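
For the executor, the outcomes now split cleanly: an unregistered name yields None, while a registered benchmark launched with unknown or invalid arguments comes back as an object whose return_code is ReturnCode.INVALID_ARGUMENT. A hedged sketch of caller-side handling; the run_and_report helper is ours, not part of this commit:

from superbench.benchmarks import BenchmarkRegistry, ReturnCode
from superbench.common.utils import logger


def run_and_report(context):
    """Launch one benchmark context and log the outcome (illustrative helper)."""
    benchmark = BenchmarkRegistry.launch_benchmark(context)
    if benchmark is None:
        # Per the tests above, None means the benchmark is not registered
        # for the requested platform/framework at all.
        logger.error('benchmark not found or context invalid.')
        return None
    if benchmark.return_code != ReturnCode.SUCCESS:
        # Unknown or invalid arguments surface here as INVALID_ARGUMENT.
        logger.error(
            'benchmark failed - name: {}, return code: {}'.format(
                benchmark.name, benchmark.return_code
            )
        )
    return benchmark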