Benchmarks: Add Feature - Add interface to get all predefine parameters of all benchmarks. (#56)

* Benchmarks: Add Feature - Add interface to get all predefine parameters of all benchmarks.
This commit is contained in:
guoshzhao 2021-04-14 22:38:26 +08:00 коммит произвёл GitHub
Родитель 435b2d5eeb
Коммит fb850af760
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 45 добавлений и 2 удалений

Просмотреть файл

@ -23,7 +23,7 @@ class Benchmark(ABC):
parameters (str): benchmark parameters.
"""
self._name = name
self._argv = list(filter(None, parameters.split(' ')))
self._argv = list(filter(None, parameters.split(' '))) if parameters is not None else list()
self._benchmark_type = None
self._parser = argparse.ArgumentParser(
add_help=False,

Просмотреть файл

@ -23,7 +23,7 @@ class BenchmarkRegistry:
benchmarks: Dict[str, dict] = dict()
@classmethod
def register_benchmark(cls, name, class_def, parameters=None, platform=None):
def register_benchmark(cls, name, class_def, parameters='', platform=None):
"""Register new benchmark, key is the benchmark name.
Args:
@ -67,6 +67,18 @@ class BenchmarkRegistry:
cls.benchmarks[name][p] = (class_def, parameters)
benchmark = class_def(name, parameters)
benchmark.add_parser_arguments()
ret, args, unknown = benchmark.parse_args()
if not ret or len(unknown) >= 1:
logger.log_and_raise(
TypeError,
'Registered benchmark has invalid arguments - benchmark: {}, parameters: {}'.format(name, parameters)
)
else:
cls.benchmarks[name]['predefine_param'] = vars(args)
logger.info('Benchmark registration - benchmark: {}, predefine_parameters: {}'.format(name, vars(args)))
@classmethod
def is_benchmark_context_valid(cls, benchmark_context):
"""Check wether the benchmark context is valid or not.
@ -143,6 +155,19 @@ class BenchmarkRegistry:
else:
return None
@classmethod
def get_all_benchmark_predefine_settings(cls):
"""Get all registered benchmarks' predefine settings.
Return:
benchmark_params (dict[str, dict]): key is benchmark name,
value is the dict with structure: {'parameter': default_value}.
"""
benchmark_params = dict()
for name in cls.benchmarks:
benchmark_params[name] = cls.benchmarks[name]['predefine_param']
return benchmark_params
@classmethod
def launch_benchmark(cls, benchmark_context):
"""Select and Launch benchmark.

Просмотреть файл

@ -7,6 +7,7 @@ import re
from superbench.benchmarks import Platform, Framework, BenchmarkType, BenchmarkRegistry, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmark
from superbench.benchmarks.micro_benchmarks.sharding_matmul import ShardingMode
class AccumulationBenchmark(MicroBenchmark):
@ -196,3 +197,20 @@ def test_launch_benchmark():
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark)
assert (benchmark.return_code == ReturnCode.INVALID_ARGUMENT)
def test_get_all_benchmark_predefine_settings():
"""Test interface BenchmarkRegistry.get_all_benchmark_predefine_settings()."""
benchmark_params = BenchmarkRegistry.get_all_benchmark_predefine_settings()
# Choose benchmark 'pytorch-sharding-matmul' for testing.
benchmark_name = 'pytorch-sharding-matmul'
assert (benchmark_name in benchmark_params)
assert (benchmark_params[benchmark_name]['run_count'] == 1)
assert (benchmark_params[benchmark_name]['duration'] == 0)
assert (benchmark_params[benchmark_name]['n'] == 4096)
assert (benchmark_params[benchmark_name]['k'] == 4096)
assert (benchmark_params[benchmark_name]['m'] == 4096)
assert (benchmark_params[benchmark_name]['mode'] == [ShardingMode.ALLREDUCE, ShardingMode.ALLGATHER])
assert (benchmark_params[benchmark_name]['num_warmup'] == 10)
assert (benchmark_params[benchmark_name]['num_steps'] == 500)