Config - Update benchmark naming to support annotations (#284)

__Description__

Update benchmark naming to support annotations.

__Major Revisions__
- Update the benchmark name passed to `create_benchmark_context` in the executor.
- Keep backward compatibility for model benchmarks using the `_models` suffix.
- Update documentation.
Yifan Xiong 2022-01-25 17:54:58 +08:00 committed by GitHub
Parent 35fc06ebd1
Commit 7d7cd3dc63
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 96 additions and 119 deletions


@@ -138,11 +138,11 @@ superbench:
num_steps: 128
batch_size: 128
benchmarks:
foo_models:
model-benchmarks:foo:
models:
- resnet50
parameters: *param
bar_models:
model-benchmarks:bar:
models:
- vgg19
parameters: *param
@@ -152,14 +152,14 @@ The above configuration equals to the following:
```yaml {6-9,13-16}
superbench:
benchmarks:
foo_models:
model-benchmarks:foo:
models:
- resnet50
parameters:
num_warmup: 16
num_steps: 128
batch_size: 128
bar_models:
model-benchmarks:bar:
models:
- vgg19
parameters:
@@ -172,10 +172,16 @@ superbench:
Mappings of `${benchmark_name}: Benchmark`.
There are three types of benchmarks, micro-benchmark, model-benchmark and docker-benchmark.
For micro-benchmark and docker-benchmark, `${benchmark_name}` should be the exact same as provided benchmarks' name.
For model-benchmark, `${benchmark_name}` should be in `${name}_models` format,
each model-benchmark can have a customized name while ending with `_models`.
There are three types of benchmarks,
[micro-benchmark](./user-tutorial/benchmarks/micro-benchmarks),
[model-benchmark](./user-tutorial/benchmarks/model-benchmarks),
and [docker-benchmark](./user-tutorial/benchmarks/docker-benchmarks).
Each benchmark has its own unique name listed in the docs.
`${benchmark_name}` can be one of the following:
* `${benchmark_unique_name}`, exactly the benchmark's own unique name;
* `${benchmark_unique_name}:${annotation}`, where, if the same benchmark needs to run with different settings,
an annotation separated by `:` is appended after the benchmark's unique name.
See [`Benchmark` Schema](#benchmark-schema) for benchmark definition.
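
To make the rule concrete, here is a minimal Python sketch of how an annotated name maps back to the registered benchmark; the `resolve_benchmark_name` helper is hypothetical and only illustrates the naming convention, it is not SuperBench code:

```python
# Hypothetical helper, illustrating the naming rule only (not SuperBench code):
# anything after the first ':' is an annotation and is dropped when looking up
# the registered benchmark.
def resolve_benchmark_name(benchmark_name: str) -> str:
    return benchmark_name.split(':')[0]


assert resolve_benchmark_name('kernel-launch') == 'kernel-launch'
assert resolve_benchmark_name('model-benchmarks:foo') == 'model-benchmarks'
assert resolve_benchmark_name('model-benchmarks:bar') == 'model-benchmarks'
```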
@@ -208,7 +214,7 @@ ${benchmark_name}:
#### Model-Benchmark
```yaml
${name}_models:
model-benchmarks:${annotation}:
enable: bool
modes: [ Mode ]
frameworks: [ enum ]
@@ -248,7 +254,7 @@ kernel-launch:
#### Model-Benchmark
```yaml
resnet_models:
model-benchmarks:resnet:
enable: true
modes:
- name: torch.distributed


@@ -6,94 +6,43 @@ id: model-benchmarks
## PyTorch Model Benchmarks
### `gpt_models`
### `model-benchmarks`
#### Introduction
Run training or inference tasks with single or half precision for GPT models,
including gpt2-small, gpt2-medium, gpt2-large and gpt2-xl.
The supported percentiles are 50, 90, 95, 99, and 99.9.
Run training or inference tasks with single or half precision for deep learning models,
including the following categories:
* GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl
* BERT: bert-base and bert-large
* LSTM
* CNN models, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
* resnet: resnet18, resnet34, resnet50, resnet101, resnet152
* resnext: resnext50_32x4d, resnext101_32x8d
* wide_resnet: wide_resnet50_2, wide_resnet101_2
* densenet: densenet121, densenet169, densenet201, densenet161
* vgg: vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19_bn, vgg19
* mnasnet: mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3
* mobilenet: mobilenet_v2
* shufflenet: shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0
* squeezenet: squeezenet1_0, squeezenet1_1
* others: alexnet, googlenet, inception_v3
For inference, supported percentiles include
50<sup>th</sup>, 90<sup>th</sup>, 95<sup>th</sup>, 99<sup>th</sup>, and 99.9<sup>th</sup>.
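
As a purely illustrative sketch (hypothetical latency samples and assumed metric-name formatting, not the SuperBench implementation), percentile metrics could be derived from per-step inference latencies as follows:

```python
# Illustrative only: derive percentile-suffixed inference metrics from
# hypothetical per-step latencies, following the
# fp{16,32}_inference_step_time_${percentile} pattern in the table below.
import numpy as np

step_times_ms = [10.2, 9.8, 11.5, 10.0, 12.3, 9.9, 10.4]  # hypothetical samples
for p in (50, 90, 95, 99):
    value = float(np.percentile(step_times_ms, p))
    print(f'fp16_inference_step_time_{p}', round(value, 2))
```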
#### Metrics
| Name | Unit | Description |
|-------------------------------------------------------------------------|------------------------|---------------------------------------------------------------------------|
| gpt_models/pytorch-${model_name}/fp32_train_step_time | time (ms) | The average training step time with single precision. |
| gpt_models/pytorch-${model_name}/fp32_train_throughput | throughput (samples/s) | The average training throughput with single precision. |
| gpt_models/pytorch-${model_name}/fp32_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with single precision. |
| gpt_models/pytorch-${model_name}/fp32_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with single precision. |
| gpt_models/pytorch-${model_name}/fp16_train_step_time | time (ms) | The average training step time with half precision. |
| gpt_models/pytorch-${model_name}/fp16_train_throughput | throughput (samples/s) | The average training throughput with half precision. |
| gpt_models/pytorch-${model_name}/fp16_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with half precision. |
| gpt_models/pytorch-${model_name}/fp16_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with half precision. |
### `bert_models`
#### Introduction
Run training or inference tasks with single or half precision for BERT models, including bert-base and bert-large.
The supported percentiles are 50, 90, 95, 99, and 99.9.
#### Metrics
| Name | Unit | Description |
|--------------------------------------------------------------------------|------------------------|---------------------------------------------------------------------------|
| bert_models/pytorch-${model_name}/fp32_train_step_time | time (ms) | The average training step time with single precision. |
| bert_models/pytorch-${model_name}/fp32_train_throughput | throughput (samples/s) | The average training throughput with single precision. |
| bert_models/pytorch-${model_name}/fp32_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with single precision. |
| bert_models/pytorch-${model_name}/fp32_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with single precision. |
| bert_models/pytorch-${model_name}/fp16_train_step_time | time (ms) | The average training step time with half precision. |
| bert_models/pytorch-${model_name}/fp16_train_throughput | throughput (samples/s) | The average training throughput with half precision. |
| bert_models/pytorch-${model_name}/fp16_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with half precision. |
| bert_models/pytorch-${model_name}/fp16_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with half precision. |
### `lstm_models`
#### Introduction
Run training or inference tasks with single or half precision for one bidirectional LSTM model.
The supported percentiles are 50, 90, 95, 99, and 99.9.
#### Metrics
| Name | Unit | Description |
|-----------------------------------------------------------------|------------------------|---------------------------------------------------------------------------|
| lstm_models/pytorch-lstm/fp32_train_step_time | time (ms) | The average training step time with single precision. |
| lstm_models/pytorch-lstm/fp32_train_throughput | throughput (samples/s) | The average training throughput with single precision. |
| lstm_models/pytorch-lstm/fp32_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with single precision. |
| lstm_models/pytorch-lstm/fp32_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with single precision. |
| lstm_models/pytorch-lstm/fp16_train_step_time | time (ms) | The average training step time with half precision. |
| lstm_models/pytorch-lstm/fp16_train_throughput | throughput (samples/s) | The average training throughput with half precision. |
| lstm_models/pytorch-lstm/fp16_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with half precision. |
| lstm_models/pytorch-lstm/fp16_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with half precision. |
### `cnn_models`
#### Introduction
Run training or inference tasks with single or half precision for CNN models listed in
[`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
* resnet: resnet18, resnet34, resnet50, resnet101, resnet152
* resnext: resnext50_32x4d, resnext101_32x8d
* wide_resnet: wide_resnet50_2, wide_resnet101_2
* densenet: densenet121, densenet169, densenet201, densenet161
* vgg: vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19_bn, vgg19
* mnasnet: mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3
* mobilenet: mobilenet_v2
* shufflenet: shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0
* squeezenet: squeezenet1_0, squeezenet1_1
* others: alexnet, googlenet, inception_v3
The supported percentiles are 50, 90, 95, 99, and 99.9.
#### Metrics
| Name | Unit | Description |
|-------------------------------------------------------------------------|------------------------|---------------------------------------------------------------------------|
| cnn_models/pytorch-${model_name}/fp32_train_step_time | time (ms) | Train average step time with single precision. |
| cnn_models/pytorch-${model_name}/fp32_train_throughput | throughput (samples/s) | Train average throughput with single precision. |
| cnn_models/pytorch-${model_name}/fp32_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with single precision. |
| cnn_models/pytorch-${model_name}/fp32_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with single precision. |
| cnn_models/pytorch-${model_name}/fp16_train_step_time | time (ms) | Train average step time with half precision. |
| cnn_models/pytorch-${model_name}/fp16_train_throughput | throughput (samples/s) | Train average throughput with half precision. |
| cnn_models/pytorch-${model_name}/fp16_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with half precision. |
| cnn_models/pytorch-${model_name}/fp16_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with half precision. |
| Name | Unit | Description |
|---------------------------------------------------------------------------------|------------------------|---------------------------------------------------------------------------|
| model-benchmarks/pytorch-${model_name}/fp32_train_step_time | time (ms) | The average training step time with single precision. |
| model-benchmarks/pytorch-${model_name}/fp32_train_throughput | throughput (samples/s) | The average training throughput with single precision. |
| model-benchmarks/pytorch-${model_name}/fp32_inference_step_time | time (ms) | The average inference step time with single precision. |
| model-benchmarks/pytorch-${model_name}/fp32_inference_throughput | throughput (samples/s) | The average inference throughput with single precision. |
| model-benchmarks/pytorch-${model_name}/fp32_inference_step_time\_${percentile} | time (ms) | The n<sup>th</sup> percentile inference step time with single precision. |
| model-benchmarks/pytorch-${model_name}/fp32_inference_throughput\_${percentile} | throughput (samples/s) | The n<sup>th</sup> percentile inference throughput with single precision. |
| model-benchmarks/pytorch-${model_name}/fp16_train_step_time | time (ms) | The average training step time with half precision. |
| model-benchmarks/pytorch-${model_name}/fp16_train_throughput | throughput (samples/s) | The average training throughput with half precision. |
| model-benchmarks/pytorch-${model_name}/fp16_inference_step_time | time (ms) | The average inference step time with half precision. |
| model-benchmarks/pytorch-${model_name}/fp16_inference_throughput | throughput (samples/s) | The average inference throughput with half precision. |
| model-benchmarks/pytorch-${model_name}/fp16_inference_step_time\_${percentile} | time (ms) | The n<sup>th</sup> percentile inference step time with half precision. |
| model-benchmarks/pytorch-${model_name}/fp16_inference_throughput\_${percentile} | throughput (samples/s) | The n<sup>th</sup> percentile inference throughput with half precision. |


@@ -104,15 +104,15 @@ class SuperBenchExecutor():
argv.append('--{} {}'.format(name, ' '.join(val)))
return ' '.join(argv)
def __exec_benchmark(self, context, log_suffix):
def __exec_benchmark(self, benchmark_full_name, context):
"""Launch benchmark for context.
Args:
benchmark_full_name (str): Benchmark full name.
context (BenchmarkContext): Benchmark context to launch.
log_suffix (str): Log string suffix.
Return:
dict: Benchmark results.
dict: Benchmark result.
"""
try:
benchmark = BenchmarkRegistry.launch_benchmark(context)
@@ -122,15 +122,17 @@ class SuperBenchExecutor():
benchmark.result
)
if benchmark.return_code.value == 0:
logger.info('Executor succeeded in %s.', log_suffix)
logger.info('Executor succeeded in %s.', benchmark_full_name)
else:
logger.error('Executor failed in %s.', log_suffix)
return json.loads(benchmark.serialized_result)
logger.error('Executor failed in %s.', benchmark_full_name)
result = json.loads(benchmark.serialized_result)
result['name'] = benchmark_full_name
return result
else:
logger.error('Executor failed in %s, invalid context.', log_suffix)
logger.error('Executor failed in %s, invalid context.', benchmark_full_name)
except Exception as e:
logger.error(e)
logger.error('Executor failed in %s.', log_suffix)
logger.error('Executor failed in %s.', benchmark_full_name)
return None
def __get_rank_id(self):
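
For illustration, a minimal sketch (the result shape and names are assumed, not taken from an actual run) of how the reworked `__exec_benchmark` stamps the executor-level full name onto the deserialized result:

```python
import json

# Assume the benchmark serialized its own result like this (shape is illustrative).
serialized_result = json.dumps({'name': 'pytorch-resnet50', 'return_code': 0})

benchmark_full_name = 'model-benchmarks:resnet/pytorch-resnet50'  # built by the executor
result = json.loads(serialized_result)
result['name'] = benchmark_full_name  # overwrite with the executor-level full name
print(result)
# {'name': 'model-benchmarks:resnet/pytorch-resnet50', 'return_code': 0}
```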
@@ -210,29 +212,32 @@ class SuperBenchExecutor():
else:
logger.warning('Monitor can not support ROCM/CPU platform.')
benchmark_real_name = benchmark_name.split(':')[0]
for framework in benchmark_config.frameworks or [Framework.NONE.value]:
if benchmark_name.endswith('_models'):
if benchmark_real_name == 'model-benchmarks' or (
':' not in benchmark_name and benchmark_name.endswith('_models')
):
for model in benchmark_config.models:
log_suffix = 'model-benchmark {}: {}/{}'.format(benchmark_name, framework, model)
logger.info('Executor is going to execute %s.', log_suffix)
full_name = f'{benchmark_name}/{framework}-{model}'
logger.info('Executor is going to execute %s.', full_name)
context = BenchmarkRegistry.create_benchmark_context(
model,
platform=self.__get_platform(),
framework=Framework(framework.lower()),
parameters=self.__get_arguments(benchmark_config.parameters)
)
result = self.__exec_benchmark(context, log_suffix)
result = self.__exec_benchmark(full_name, context)
benchmark_results.append(result)
else:
log_suffix = 'micro-benchmark {}'.format(benchmark_name)
logger.info('Executor is going to execute %s.', log_suffix)
full_name = benchmark_name
logger.info('Executor is going to execute %s.', full_name)
context = BenchmarkRegistry.create_benchmark_context(
benchmark_name,
benchmark_real_name,
platform=self.__get_platform(),
framework=Framework(framework.lower()),
parameters=self.__get_arguments(benchmark_config.parameters)
)
result = self.__exec_benchmark(context, log_suffix)
result = self.__exec_benchmark(full_name, context)
benchmark_results.append(result)
if monitor:
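
Putting the loop above together, here is a small standalone sketch (the config values and the `full_names` helper are hypothetical) of how benchmark names are resolved and full result names are built, for both the new annotated form and the legacy `_models` form:

```python
# Standalone sketch of the naming logic in the executor loop above.
# Benchmark entries and model lists are hypothetical examples.
def full_names(benchmark_name, framework='pytorch', models=()):
    """Return the executor-level full names produced for one benchmark entry."""
    real_name = benchmark_name.split(':')[0]
    is_model_benchmark = real_name == 'model-benchmarks' or (
        ':' not in benchmark_name and benchmark_name.endswith('_models')
    )
    if is_model_benchmark:
        return [f'{benchmark_name}/{framework}-{model}' for model in models]
    return [benchmark_name]


print(full_names('model-benchmarks:resnet', models=['resnet50', 'resnet101']))
# ['model-benchmarks:resnet/pytorch-resnet50', 'model-benchmarks:resnet/pytorch-resnet101']
print(full_names('foo_models', models=['vgg19']))  # legacy '_models' suffix still works
# ['foo_models/pytorch-vgg19']
print(full_names('kernel-launch'))
# ['kernel-launch']
```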


@@ -243,8 +243,6 @@ class SuperBenchRunner():
except Exception:
logger.error('Invalid content in JSON file: {}'.format(results_file))
continue
if results_file.parts[-3].endswith('_models'):
benchmark_name = '{}/{}'.format(results_file.parts[-3], result['name'])
if benchmark_name not in results_summary:
results_summary[benchmark_name] = defaultdict(list)
for metric in result['result']:


@@ -13,6 +13,7 @@ from unittest import mock
import yaml
from omegaconf import OmegaConf
from superbench.benchmarks import ReturnCode
from superbench.executor import SuperBenchExecutor
@@ -135,17 +136,35 @@ class ExecutorTestCase(unittest.TestCase):
self.executor._sb_enabled = []
self.executor.exec()
@mock.patch('superbench.executor.SuperBenchExecutor._SuperBenchExecutor__exec_benchmark')
def test_exec_default_benchmarks(self, mock_exec_benchmark):
@mock.patch('superbench.benchmarks.BenchmarkRegistry.launch_benchmark')
def test_exec_default_benchmarks(self, mock_launch_benchmark):
"""Test execute default benchmarks, mock exec function.
Args:
mock_exec_benchmark (function): Mocked __exec_benchmark function.
mock_launch_benchmark (function): Mocked BenchmarkRegistry.launch_benchmark function in __exec_benchmark.
"""
mock_exec_benchmark.return_value = {}
mock_launch_benchmark.return_value = OmegaConf.create(
{
'name': 'foobar',
'return_code': ReturnCode.SUCCESS,
'result': {
'return_code': [0],
'metric1': [-1.0],
'metric2': [1.0]
},
'serialized_result': json.dumps({
'name': 'foobar',
'return_code': 0,
}),
}
)
self.executor.exec()
self.assertTrue(Path(self.sb_output_dir, 'benchmarks').is_dir())
for benchmark_name in self.executor._sb_enabled:
self.assertTrue(Path(self.sb_output_dir, 'benchmarks', benchmark_name, 'rank0').is_dir())
self.assertTrue(Path(self.sb_output_dir, 'benchmarks', benchmark_name, 'rank0', 'results.json').is_file())
p = Path(self.sb_output_dir, 'benchmarks', benchmark_name, 'rank0')
self.assertTrue(p.is_dir())
self.assertTrue((p / 'results.json').is_file())
with (p / 'results.json').open() as f:
for result in json.load(f):
self.assertIn(benchmark_name, result['name'])