Benchmarks: Add Feature - Add interface to get all predefine parameters of all benchmarks. (#56)
* Benchmarks: Add Feature - Add interface to get all predefine parameters of all benchmarks.
This commit is contained in:
Родитель
435b2d5eeb
Коммит
fb850af760
|
@ -23,7 +23,7 @@ class Benchmark(ABC):
|
|||
parameters (str): benchmark parameters.
|
||||
"""
|
||||
self._name = name
|
||||
self._argv = list(filter(None, parameters.split(' ')))
|
||||
self._argv = list(filter(None, parameters.split(' '))) if parameters is not None else list()
|
||||
self._benchmark_type = None
|
||||
self._parser = argparse.ArgumentParser(
|
||||
add_help=False,
|
||||
|
|
|
@ -23,7 +23,7 @@ class BenchmarkRegistry:
|
|||
benchmarks: Dict[str, dict] = dict()
|
||||
|
||||
@classmethod
|
||||
def register_benchmark(cls, name, class_def, parameters=None, platform=None):
|
||||
def register_benchmark(cls, name, class_def, parameters='', platform=None):
|
||||
"""Register new benchmark, key is the benchmark name.
|
||||
|
||||
Args:
|
||||
|
@ -67,6 +67,18 @@ class BenchmarkRegistry:
|
|||
|
||||
cls.benchmarks[name][p] = (class_def, parameters)
|
||||
|
||||
benchmark = class_def(name, parameters)
|
||||
benchmark.add_parser_arguments()
|
||||
ret, args, unknown = benchmark.parse_args()
|
||||
if not ret or len(unknown) >= 1:
|
||||
logger.log_and_raise(
|
||||
TypeError,
|
||||
'Registered benchmark has invalid arguments - benchmark: {}, parameters: {}'.format(name, parameters)
|
||||
)
|
||||
else:
|
||||
cls.benchmarks[name]['predefine_param'] = vars(args)
|
||||
logger.info('Benchmark registration - benchmark: {}, predefine_parameters: {}'.format(name, vars(args)))
|
||||
|
||||
@classmethod
|
||||
def is_benchmark_context_valid(cls, benchmark_context):
|
||||
"""Check wether the benchmark context is valid or not.
|
||||
|
@ -143,6 +155,19 @@ class BenchmarkRegistry:
|
|||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_all_benchmark_predefine_settings(cls):
|
||||
"""Get all registered benchmarks' predefine settings.
|
||||
|
||||
Return:
|
||||
benchmark_params (dict[str, dict]): key is benchmark name,
|
||||
value is the dict with structure: {'parameter': default_value}.
|
||||
"""
|
||||
benchmark_params = dict()
|
||||
for name in cls.benchmarks:
|
||||
benchmark_params[name] = cls.benchmarks[name]['predefine_param']
|
||||
return benchmark_params
|
||||
|
||||
@classmethod
|
||||
def launch_benchmark(cls, benchmark_context):
|
||||
"""Select and Launch benchmark.
|
||||
|
|
|
@ -7,6 +7,7 @@ import re
|
|||
|
||||
from superbench.benchmarks import Platform, Framework, BenchmarkType, BenchmarkRegistry, ReturnCode
|
||||
from superbench.benchmarks.micro_benchmarks import MicroBenchmark
|
||||
from superbench.benchmarks.micro_benchmarks.sharding_matmul import ShardingMode
|
||||
|
||||
|
||||
class AccumulationBenchmark(MicroBenchmark):
|
||||
|
@ -196,3 +197,20 @@ def test_launch_benchmark():
|
|||
benchmark = BenchmarkRegistry.launch_benchmark(context)
|
||||
assert (benchmark)
|
||||
assert (benchmark.return_code == ReturnCode.INVALID_ARGUMENT)
|
||||
|
||||
|
||||
def test_get_all_benchmark_predefine_settings():
|
||||
"""Test interface BenchmarkRegistry.get_all_benchmark_predefine_settings()."""
|
||||
benchmark_params = BenchmarkRegistry.get_all_benchmark_predefine_settings()
|
||||
|
||||
# Choose benchmark 'pytorch-sharding-matmul' for testing.
|
||||
benchmark_name = 'pytorch-sharding-matmul'
|
||||
assert (benchmark_name in benchmark_params)
|
||||
assert (benchmark_params[benchmark_name]['run_count'] == 1)
|
||||
assert (benchmark_params[benchmark_name]['duration'] == 0)
|
||||
assert (benchmark_params[benchmark_name]['n'] == 4096)
|
||||
assert (benchmark_params[benchmark_name]['k'] == 4096)
|
||||
assert (benchmark_params[benchmark_name]['m'] == 4096)
|
||||
assert (benchmark_params[benchmark_name]['mode'] == [ShardingMode.ALLREDUCE, ShardingMode.ALLGATHER])
|
||||
assert (benchmark_params[benchmark_name]['num_warmup'] == 10)
|
||||
assert (benchmark_params[benchmark_name]['num_steps'] == 500)
|
||||
|
|
Загрузка…
Ссылка в новой задаче