Config - Update benchmark naming to support annotations (#284)
__Description__

Update benchmark naming to support annotations.

__Major Revisions__

- Update the name passed to `create_benchmark_context` in the executor.
- Keep backward compatibility for model benchmarks using the `_models` suffix.
- Update documentation.
This commit is contained in:

Parent: 35fc06ebd1
Commit: 7d7cd3dc63
````diff
@@ -138,11 +138,11 @@ superbench:
       num_steps: 128
       batch_size: 128
   benchmarks:
-    foo_models:
+    model-benchmarks:foo:
       models:
         - resnet50
       parameters: *param
-    bar_models:
+    model-benchmarks:bar:
       models:
         - vgg19
       parameters: *param
````
````diff
@@ -152,14 +152,14 @@ The above configuration equals to the following:
 ```yaml {6-9,13-16}
 superbench:
   benchmarks:
-    foo_models:
+    model-benchmarks:foo:
       models:
         - resnet50
       parameters:
         num_warmup: 16
         num_steps: 128
         batch_size: 128
-    bar_models:
+    model-benchmarks:bar:
       models:
         - vgg19
       parameters:
````
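The equivalence above relies on standard YAML anchors and aliases: `*param` dereferences a mapping anchored elsewhere in the file with `&param`, so both benchmarks share one parameter block. A minimal self-contained sketch of the mechanism (the `defaults` key is a placeholder, not part of the SuperBench schema):

```yaml
defaults: &param      # anchor: names this mapping "param"
  num_warmup: 16
  num_steps: 128
  batch_size: 128

benchmarks:
  model-benchmarks:foo:
    parameters: *param   # alias: expands to the anchored mapping
```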
````diff
@@ -172,10 +172,16 @@ superbench:
 
 Mappings of `${benchmark_name}: Benchmark`.
 
-There are three types of benchmarks, micro-benchmark, model-benchmark and docker-benchmark.
-For micro-benchmark and docker-benchmark, `${benchmark_name}` should be the exact same as provided benchmarks' name.
-For model-benchmark, `${benchmark_name}` should be in `${name}_models` format,
-each model-benchmark can have a customized name while ending with `_models`.
+There are three types of benchmarks,
+[micro-benchmark](./user-tutorial/benchmarks/micro-benchmarks),
+[model-benchmark](./user-tutorial/benchmarks/model-benchmarks),
+and [docker-benchmark](./user-tutorial/benchmarks/docker-benchmarks).
+Each benchmark has its own unique name listed in the docs.
+
+`${benchmark_name}` can be one of the following:
+* `${benchmark_unique_name}`, exactly the same as the benchmark's own unique name;
+* `${benchmark_unique_name}:${annotation}`, if there's a need to run one benchmark with different settings,
+an annotation separated by `:` can be appended after the benchmark's unique name.
 
 See [`Benchmark` Schema](#benchmark-schema) for benchmark definition.
 
````
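For instance, annotations make it possible to schedule the same benchmark twice with different settings; a minimal sketch of such a config (the annotations `baseline` and `large-batch` are arbitrary placeholders):

```yaml
superbench:
  benchmarks:
    model-benchmarks:baseline:
      models:
        - resnet50
      parameters:
        batch_size: 32
    model-benchmarks:large-batch:
      models:
        - resnet50
      parameters:
        batch_size: 128
```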
````diff
@@ -208,7 +214,7 @@ ${benchmark_name}:
 #### Model-Benchmark
 
 ```yaml
-${name}_models:
+model-benchmarks:${annotation}:
   enable: bool
   modes: [ Mode ]
   frameworks: [ enum ]
````
````diff
@@ -248,7 +254,7 @@ kernel-launch:
 #### Model-Benchmark
 
 ```yaml
-resnet_models:
+model-benchmarks:resnet:
   enable: true
   modes:
     - name: torch.distributed
````
````diff
@@ -6,94 +6,43 @@ id: model-benchmarks
 
 ## PyTorch Model Benchmarks
 
-### `gpt_models`
+### `model-benchmarks`
 
 #### Introduction
 
-Run training or inference tasks with single or half precision for GPT models,
-including gpt2-small, gpt2-medium, gpt2-large and gpt2-xl.
-The supported percentiles are 50, 90, 95, 99, and 99.9.
+Run training or inference tasks with single or half precision for deep learning models,
+including the following categories:
+* GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl
+* BERT: bert-base and bert-large
+* LSTM
+* CNN, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
+  * resnet: resnet18, resnet34, resnet50, resnet101, resnet152
+  * resnext: resnext50_32x4d, resnext101_32x8d
+  * wide_resnet: wide_resnet50_2, wide_resnet101_2
+  * densenet: densenet121, densenet169, densenet201, densenet161
+  * vgg: vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19_bn, vgg19
+  * mnasnet: mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3
+  * mobilenet: mobilenet_v2
+  * shufflenet: shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0
+  * squeezenet: squeezenet1_0, squeezenet1_1
+  * others: alexnet, googlenet, inception_v3
+
+For inference, supported percentiles include
+50<sup>th</sup>, 90<sup>th</sup>, 95<sup>th</sup>, 99<sup>th</sup>, and 99.9<sup>th</sup>.
 
 #### Metrics
 
-| Name | Unit | Description |
-|------|------|-------------|
-| gpt_models/pytorch-${model_name}/fp32_train_step_time | time (ms) | The average training step time with single precision. |
-| gpt_models/pytorch-${model_name}/fp32_train_throughput | throughput (samples/s) | The average training throughput with single precision. |
-| gpt_models/pytorch-${model_name}/fp32_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with single precision. |
-| gpt_models/pytorch-${model_name}/fp32_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with single precision. |
-| gpt_models/pytorch-${model_name}/fp16_train_step_time | time (ms) | The average training step time with half precision. |
-| gpt_models/pytorch-${model_name}/fp16_train_throughput | throughput (samples/s) | The average training throughput with half precision. |
-| gpt_models/pytorch-${model_name}/fp16_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with half precision. |
-| gpt_models/pytorch-${model_name}/fp16_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with half precision. |
-
-### `bert_models`
-
-#### Introduction
-
-Run training or inference tasks with single or half precision for BERT models, including bert-base and bert-large.
-The supported percentiles are 50, 90, 95, 99, and 99.9.
-
-#### Metrics
-
-| Name | Unit | Description |
-|------|------|-------------|
-| bert_models/pytorch-${model_name}/fp32_train_step_time | time (ms) | The average training step time with single precision. |
-| bert_models/pytorch-${model_name}/fp32_train_throughput | throughput (samples/s) | The average training throughput with single precision. |
-| bert_models/pytorch-${model_name}/fp32_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with single precision. |
-| bert_models/pytorch-${model_name}/fp32_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with single precision. |
-| bert_models/pytorch-${model_name}/fp16_train_step_time | time (ms) | The average training step time with half precision. |
-| bert_models/pytorch-${model_name}/fp16_train_throughput | throughput (samples/s) | The average training throughput with half precision. |
-| bert_models/pytorch-${model_name}/fp16_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with half precision. |
-| bert_models/pytorch-${model_name}/fp16_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with half precision. |
-
-### `lstm_models`
-
-#### Introduction
-
-Run training or inference tasks with single or half precision for one bidirectional LSTM model.
-The supported percentiles are 50, 90, 95, 99, and 99.9.
-
-#### Metrics
-
-| Name | Unit | Description |
-|------|------|-------------|
-| lstm_models/pytorch-lstm/fp32_train_step_time | time (ms) | The average training step time with single precision. |
-| lstm_models/pytorch-lstm/fp32_train_throughput | throughput (samples/s) | The average training throughput with single precision. |
-| lstm_models/pytorch-lstm/fp32_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with single precision. |
-| lstm_models/pytorch-lstm/fp32_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with single precision. |
-| lstm_models/pytorch-lstm/fp16_train_step_time | time (ms) | The average training step time with half precision. |
-| lstm_models/pytorch-lstm/fp16_train_throughput | throughput (samples/s) | The average training throughput with half precision. |
-| lstm_models/pytorch-lstm/fp16_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with half precision. |
-| lstm_models/pytorch-lstm/fp16_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with half precision. |
-
-### `cnn_models`
-
-#### Introduction
-
-Run training or inference tasks with single or half precision for CNN models listed in
-[`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
-* resnet: resnet18, resnet34, resnet50, resnet101, resnet152
-* resnext: resnext50_32x4d, resnext101_32x8d
-* wide_resnet: wide_resnet50_2, wide_resnet101_2
-* densenet: densenet121, densenet169, densenet201, densenet161
-* vgg: vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19_bn, vgg19
-* mnasnet: mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3
-* mobilenet: mobilenet_v2
-* shufflenet: shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0
-* squeezenet: squeezenet1_0, squeezenet1_1
-* others: alexnet, googlenet, inception_v3
-The supported percentiles are 50, 90, 95, 99, and 99.9.
-
-#### Metrics
-
-| Name | Unit | Description |
-|------|------|-------------|
-| cnn_models/pytorch-${model_name}/fp32_train_step_time | time (ms) | Train average step time with single precision. |
-| cnn_models/pytorch-${model_name}/fp32_train_throughput | throughput (samples/s) | Train average throughput with single precision. |
-| cnn_models/pytorch-${model_name}/fp32_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with single precision. |
-| cnn_models/pytorch-${model_name}/fp32_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with single precision. |
-| cnn_models/pytorch-${model_name}/fp16_train_step_time | time (ms) | Train average step time with half precision. |
-| cnn_models/pytorch-${model_name}/fp16_train_throughput | throughput (samples/s) | Train average throughput with half precision. |
-| cnn_models/pytorch-${model_name}/fp16_inference_step_time_{percentile} | time (ms) | The {percentile}th percentile inference step time with half precision. |
-| cnn_models/pytorch-${model_name}/fp16_inference_throughput_{percentile} | throughput (samples/s) | The {percentile}th percentile inference throughput with half precision. |
+| Name | Unit | Description |
+|------|------|-------------|
+| model-benchmarks/pytorch-${model_name}/fp32_train_step_time | time (ms) | The average training step time with single precision. |
+| model-benchmarks/pytorch-${model_name}/fp32_train_throughput | throughput (samples/s) | The average training throughput with single precision. |
+| model-benchmarks/pytorch-${model_name}/fp32_inference_step_time | time (ms) | The average inference step time with single precision. |
+| model-benchmarks/pytorch-${model_name}/fp32_inference_throughput | throughput (samples/s) | The average inference throughput with single precision. |
+| model-benchmarks/pytorch-${model_name}/fp32_inference_step_time\_${percentile} | time (ms) | The n<sup>th</sup> percentile inference step time with single precision. |
+| model-benchmarks/pytorch-${model_name}/fp32_inference_throughput\_${percentile} | throughput (samples/s) | The n<sup>th</sup> percentile inference throughput with single precision. |
+| model-benchmarks/pytorch-${model_name}/fp16_train_step_time | time (ms) | The average training step time with half precision. |
+| model-benchmarks/pytorch-${model_name}/fp16_train_throughput | throughput (samples/s) | The average training throughput with half precision. |
+| model-benchmarks/pytorch-${model_name}/fp16_inference_step_time | time (ms) | The average inference step time with half precision. |
+| model-benchmarks/pytorch-${model_name}/fp16_inference_throughput | throughput (samples/s) | The average inference throughput with half precision. |
+| model-benchmarks/pytorch-${model_name}/fp16_inference_step_time\_${percentile} | time (ms) | The n<sup>th</sup> percentile inference step time with half precision. |
+| model-benchmarks/pytorch-${model_name}/fp16_inference_throughput\_${percentile} | throughput (samples/s) | The n<sup>th</sup> percentile inference throughput with half precision. |
````
````diff
@@ -104,15 +104,15 @@ class SuperBenchExecutor():
             argv.append('--{} {}'.format(name, ' '.join(val)))
         return ' '.join(argv)
 
-    def __exec_benchmark(self, context, log_suffix):
+    def __exec_benchmark(self, benchmark_full_name, context):
         """Launch benchmark for context.
 
         Args:
+            benchmark_full_name (str): Benchmark full name.
            context (BenchmarkContext): Benchmark context to launch.
-            log_suffix (str): Log string suffix.
 
         Return:
-            dict: Benchmark results.
+            dict: Benchmark result.
         """
         try:
             benchmark = BenchmarkRegistry.launch_benchmark(context)
````
````diff
@@ -122,15 +122,17 @@ class SuperBenchExecutor():
                     benchmark.result
                 )
                 if benchmark.return_code.value == 0:
-                    logger.info('Executor succeeded in %s.', log_suffix)
+                    logger.info('Executor succeeded in %s.', benchmark_full_name)
                 else:
-                    logger.error('Executor failed in %s.', log_suffix)
-                return json.loads(benchmark.serialized_result)
+                    logger.error('Executor failed in %s.', benchmark_full_name)
+                result = json.loads(benchmark.serialized_result)
+                result['name'] = benchmark_full_name
+                return result
             else:
-                logger.error('Executor failed in %s, invalid context.', log_suffix)
+                logger.error('Executor failed in %s, invalid context.', benchmark_full_name)
         except Exception as e:
             logger.error(e)
-            logger.error('Executor failed in %s.', log_suffix)
+            logger.error('Executor failed in %s.', benchmark_full_name)
         return None
 
     def __get_rank_id(self):
````
````diff
@@ -210,29 +212,32 @@ class SuperBenchExecutor():
             else:
                 logger.warning('Monitor can not support ROCM/CPU platform.')
 
+        benchmark_real_name = benchmark_name.split(':')[0]
         for framework in benchmark_config.frameworks or [Framework.NONE.value]:
-            if benchmark_name.endswith('_models'):
+            if benchmark_real_name == 'model-benchmarks' or (
+                ':' not in benchmark_name and benchmark_name.endswith('_models')
+            ):
                 for model in benchmark_config.models:
-                    log_suffix = 'model-benchmark {}: {}/{}'.format(benchmark_name, framework, model)
-                    logger.info('Executor is going to execute %s.', log_suffix)
+                    full_name = f'{benchmark_name}/{framework}-{model}'
+                    logger.info('Executor is going to execute %s.', full_name)
                     context = BenchmarkRegistry.create_benchmark_context(
                         model,
                         platform=self.__get_platform(),
                         framework=Framework(framework.lower()),
                         parameters=self.__get_arguments(benchmark_config.parameters)
                     )
-                    result = self.__exec_benchmark(context, log_suffix)
+                    result = self.__exec_benchmark(full_name, context)
                     benchmark_results.append(result)
             else:
-                log_suffix = 'micro-benchmark {}'.format(benchmark_name)
-                logger.info('Executor is going to execute %s.', log_suffix)
+                full_name = benchmark_name
+                logger.info('Executor is going to execute %s.', full_name)
                 context = BenchmarkRegistry.create_benchmark_context(
-                    benchmark_name,
+                    benchmark_real_name,
                     platform=self.__get_platform(),
                     framework=Framework(framework.lower()),
                     parameters=self.__get_arguments(benchmark_config.parameters)
                 )
-                result = self.__exec_benchmark(context, log_suffix)
+                result = self.__exec_benchmark(full_name, context)
                 benchmark_results.append(result)
 
         if monitor:
````
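A standalone sketch (not part of the project's API) of the naming rule the executor now applies; `parse_benchmark_name` is a hypothetical helper that mirrors the condition in the diff above, including the backward-compatible `_models` path:

```python
def parse_benchmark_name(benchmark_name):
    """Split a configured benchmark name into its real name and annotation."""
    real_name, _, annotation = benchmark_name.partition(':')
    # Backward compatibility: an un-annotated name ending in "_models"
    # is still treated as a model benchmark.
    is_model_benchmark = real_name == 'model-benchmarks' or (
        ':' not in benchmark_name and benchmark_name.endswith('_models')
    )
    return real_name, annotation, is_model_benchmark


assert parse_benchmark_name('model-benchmarks:resnet') == ('model-benchmarks', 'resnet', True)
assert parse_benchmark_name('foo_models') == ('foo_models', '', True)
assert parse_benchmark_name('kernel-launch:run2') == ('kernel-launch', 'run2', False)
```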
````diff
@@ -243,8 +243,6 @@ class SuperBenchRunner():
             except Exception:
                 logger.error('Invalid content in JSON file: {}'.format(results_file))
                 continue
-            if results_file.parts[-3].endswith('_models'):
-                benchmark_name = '{}/{}'.format(results_file.parts[-3], result['name'])
             if benchmark_name not in results_summary:
                 results_summary[benchmark_name] = defaultdict(list)
             for metric in result['result']:
````
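To illustrate why the path-based `_models` check becomes unnecessary: the executor now embeds the full benchmark name in each result, so the runner can key the summary directly off `result['name']`. A hypothetical, self-contained sketch (the sample records and metric names are made up):

```python
from collections import defaultdict

# Made-up result records in the shape the executor now emits.
results = [
    {'name': 'model-benchmarks:foo/pytorch-resnet50', 'result': {'fp32_train_step_time': [1.0]}},
    {'name': 'kernel-launch', 'result': {'event_overhead': [0.005]}},
]

results_summary = {}
for result in results:
    benchmark_name = result['name']  # no '_models' directory inference needed
    if benchmark_name not in results_summary:
        results_summary[benchmark_name] = defaultdict(list)
    for metric in result['result']:
        results_summary[benchmark_name][metric] += result['result'][metric]

print(results_summary['kernel-launch']['event_overhead'])  # [0.005]
```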
````diff
@@ -13,6 +13,7 @@ from unittest import mock
 import yaml
+from omegaconf import OmegaConf
 
 from superbench.benchmarks import ReturnCode
 from superbench.executor import SuperBenchExecutor
 
 
````
````diff
@@ -135,17 +136,35 @@ class ExecutorTestCase(unittest.TestCase):
         self.executor._sb_enabled = []
         self.executor.exec()
 
-    @mock.patch('superbench.executor.SuperBenchExecutor._SuperBenchExecutor__exec_benchmark')
-    def test_exec_default_benchmarks(self, mock_exec_benchmark):
+    @mock.patch('superbench.benchmarks.BenchmarkRegistry.launch_benchmark')
+    def test_exec_default_benchmarks(self, mock_launch_benchmark):
         """Test execute default benchmarks, mock exec function.
 
         Args:
-            mock_exec_benchmark (function): Mocked __exec_benchmark function.
+            mock_launch_benchmark (function): Mocked BenchmarkRegistry.launch_benchmark function in __exec_benchmark.
         """
-        mock_exec_benchmark.return_value = {}
+        mock_launch_benchmark.return_value = OmegaConf.create(
+            {
+                'name': 'foobar',
+                'return_code': ReturnCode.SUCCESS,
+                'result': {
+                    'return_code': [0],
+                    'metric1': [-1.0],
+                    'metric2': [1.0]
+                },
+                'serialized_result': json.dumps({
+                    'name': 'foobar',
+                    'return_code': 0,
+                }),
+            }
+        )
         self.executor.exec()
 
         self.assertTrue(Path(self.sb_output_dir, 'benchmarks').is_dir())
         for benchmark_name in self.executor._sb_enabled:
-            self.assertTrue(Path(self.sb_output_dir, 'benchmarks', benchmark_name, 'rank0').is_dir())
-            self.assertTrue(Path(self.sb_output_dir, 'benchmarks', benchmark_name, 'rank0', 'results.json').is_file())
+            p = Path(self.sb_output_dir, 'benchmarks', benchmark_name, 'rank0')
+            self.assertTrue(p.is_dir())
+            self.assertTrue((p / 'results.json').is_file())
+            with (p / 'results.json').open() as f:
+                for result in json.load(f):
+                    self.assertIn(benchmark_name, result['name'])
````
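As a side note on the mock above: `OmegaConf.create` is presumably used because it exposes dictionary keys as attributes, so the fake object supports the attribute access the executor performs on a real benchmark object. A minimal illustration:

```python
from omegaconf import OmegaConf

mock_benchmark = OmegaConf.create({'name': 'foobar', 'return_code': 0})
# Attribute-style access works like on the real Benchmark object.
assert mock_benchmark.name == 'foobar'
assert mock_benchmark.return_code == 0
```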