Benchmarks: Add Feature - Provide option to save raw data into file. (#333)

**Description**
Use the config `log_raw_data` to control whether the raw data is logged into a file or not. The default value is `no`. It can be set to `yes` for particular benchmarks that produce large raw output, such as the NCCL/RCCL tests, to save their raw data into a file.
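For example, here is a minimal config sketch that turns raw-data logging on for the NCCL bandwidth test only. The surrounding `superbench`/`benchmarks` nesting is assumed from the usual config layout; only the `log_raw_data` parameter itself comes from this change:

```yaml
# Illustrative sketch, not part of this commit.
superbench:
  benchmarks:
    nccl-bw:
      parameters:
        log_raw_data: yes   # default is no; raw output is appended to rawdata.log
                            # in the benchmark's output directory (see the executor change below)
```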
guoshzhao 2022-04-01 16:26:09 +08:00, committed by GitHub
Parent: d368d90e21
Commit: 6d895da83c
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
33 changed files: 95 additions and 42 deletions


@ -209,6 +209,7 @@ ${benchmark_name}:
parameters:
run_count: int
duration: int
log_raw_data: bool
${argument}: bool | str | int | float | list
```
@ -224,6 +225,7 @@ model-benchmarks:${annotation}:
parameters:
run_count: int
duration: int
log_raw_data: bool
num_warmup: int
num_steps: int
sample_count: int
@ -334,6 +336,18 @@ A list of models to run, only supported in model-benchmark.
Parameters for the benchmark to use, varying across benchmarks.
There are three common parameters shared by all benchmarks:
* run_count: how many times the user wants to run this benchmark; the default value is 1.
* duration: the elapsed time of the benchmark in seconds. It works for all model-benchmarks, but micro-benchmark authors need to consume it themselves.
* log_raw_data: log raw data into a file instead of saving it in the result object; the default value is `False`. Benchmarks with large raw output, such as `nccl-bw`/`rccl-bw`, may want to set it to `True`.
For model-benchmarks, there are a few parameters that control the elapsed time:
* duration: the elapsed time of the benchmark in seconds.
* num_warmup: the number of warmup steps.
* num_steps: the number of test steps.
If `duration > 0` and `num_warmup + num_steps > 0`, the benchmark takes the smaller of the two limits as the elapsed time; otherwise only the specified one takes effect. A sketch of how these parameters can look in a config entry is shown below.
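The following is an illustrative sketch of one model-benchmark entry following the `model-benchmarks:${annotation}` schema above; the annotation name and all values are examples, not taken from this commit:

```yaml
# Illustrative values only; the parameter names follow the schema above.
model-benchmarks:gpt_models:
  parameters:
    run_count: 1        # run the whole benchmark once (default)
    duration: 120       # stop after 120 seconds or after num_warmup + num_steps, whichever is reached first
    num_warmup: 16
    num_steps: 128
    log_raw_data: no    # keep raw data in the result object (default); set to yes for large raw output
```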
## `Mode` Schema
Definition for each benchmark mode; here is an overview of the `Mode` configuration structure:


@ -65,6 +65,12 @@ class Benchmark(ABC):
required=False,
help='The elapsed time of benchmark in seconds.',
)
self._parser.add_argument(
'--log_raw_data',
action='store_true',
default=False,
help='Log raw data into file instead of saving it into result object.',
)
def get_configurable_settings(self):
"""Get all the configurable settings.


@ -59,7 +59,7 @@ class RocmOnnxRuntimeModelBenchmark(RocmDockerBenchmark):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output', raw_output)
self._result.add_raw_data('raw_output', raw_output, self._args.log_raw_data)
content = raw_output.splitlines(False)
try:


@ -78,7 +78,7 @@ class CpuMemBwLatencyBenchmark(MicroBenchmarkWithInvoke):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output)
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
# parse the command to see which command this output belongs to
# the command is formed as ...; mlc --option; ...


@ -268,7 +268,7 @@ class CublasBenchmark(MicroBenchmarkWithInvoke):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output)
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
try:
lines = raw_output.splitlines()
@ -292,7 +292,7 @@ class CublasBenchmark(MicroBenchmarkWithInvoke):
raw_data.pop()
raw_data = [float(item) for item in raw_data]
self._result.add_result(metric.lower() + '_time', statistics.mean(raw_data))
self._result.add_raw_data(metric.lower() + '_time', raw_data)
self._result.add_raw_data(metric.lower() + '_time', raw_data, self._args.log_raw_data)
if 'Error' in line:
error = True
except BaseException as e:


@ -110,7 +110,7 @@ class CudaGemmFlopsBenchmark(GemmFlopsBenchmark):
True if the raw output string is valid and result can be extracted.
"""
precision = self._precision_need_to_run[cmd_idx]
self._result.add_raw_data('raw_output_' + precision, raw_output)
self._result.add_raw_data('raw_output_' + precision, raw_output, self._args.log_raw_data)
valid = True
flops = list()


@ -68,7 +68,7 @@ class CudaMemBwBenchmark(MemBwBenchmark):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + self._args.mem_type[cmd_idx], raw_output)
self._result.add_raw_data('raw_output_' + self._args.mem_type[cmd_idx], raw_output, self._args.log_raw_data)
mem_bw = -1
valid = True


@ -143,7 +143,7 @@ class CudaNcclBwBenchmark(MicroBenchmarkWithInvoke):
if rank > 0:
return True
self._result.add_raw_data('raw_output_' + self._args.operation, raw_output)
self._result.add_raw_data('raw_output_' + self._args.operation, raw_output, self._args.log_raw_data)
content = raw_output.splitlines()
size = -1


@ -402,7 +402,7 @@ class CudnnBenchmark(MicroBenchmarkWithInvoke):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output)
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
try:
lines = raw_output.splitlines()
@ -426,7 +426,7 @@ class CudnnBenchmark(MicroBenchmarkWithInvoke):
raw_data.pop()
raw_data = [float(item) for item in raw_data]
self._result.add_result(metric.lower() + '_time', statistics.mean(raw_data) * 1000)
self._result.add_raw_data(metric.lower() + '_time', raw_data)
self._result.add_raw_data(metric.lower() + '_time', raw_data, self._args.log_raw_data)
if 'Error' in line:
error = True
except BaseException as e:


@ -184,7 +184,7 @@ class DiskBenchmark(MicroBenchmarkWithInvoke):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output)
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
try:
fio_output = json.loads(raw_output)


@ -74,7 +74,7 @@ class GPCNetBenchmark(MicroBenchmarkWithInvoke):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(idx), raw_output)
self._result.add_raw_data('raw_output_' + str(idx), raw_output, self._args.log_raw_data)
try:
# Parse and add result


@ -123,9 +123,9 @@ class GpuBurnBenchmark(MicroBenchmarkWithInvoke):
self._result.add_result(res.split(':')[0].replace(' ', '_').lower() + '_pass', 1)
else:
self._result.add_result(res.split(':')[0].replace(' ', '_').lower() + '_pass', 0)
self._result.add_raw_data('GPU-Burn_result', res)
self._result.add_raw_data('GPU-Burn_result', res, self._args.log_raw_data)
else:
self._result.add_raw_data('GPU Burn Failure: ', failure_msg)
self._result.add_raw_data('GPU Burn Failure: ', failure_msg, self._args.log_raw_data)
self._result.add_result('abort', 1)
return False
except BaseException as e:


@ -122,7 +122,7 @@ class GpuCopyBwBenchmark(MicroBenchmarkWithInvoke):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output)
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
try:
output_lines = [x.strip() for x in raw_output.strip().splitlines()]


@ -187,7 +187,8 @@ class IBLoopbackBenchmark(MicroBenchmarkWithInvoke):
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data(
'raw_output_' + self._args.commands[cmd_idx] + '_IB' + str(self._args.ib_index), raw_output
'raw_output_' + self._args.commands[cmd_idx] + '_IB' + str(self._args.ib_index), raw_output,
self._args.log_raw_data
)
valid = False


@ -336,7 +336,7 @@ class IBBenchmark(MicroBenchmarkWithInvoke):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + self._args.commands[cmd_idx], raw_output)
self._result.add_raw_data('raw_output_' + self._args.commands[cmd_idx], raw_output, self._args.log_raw_data)
# If it's invoked by MPI and rank is not 0, no result is expected
if os.getenv('OMPI_COMM_WORLD_RANK'):


@ -79,7 +79,7 @@ class KernelLaunch(MicroBenchmarkWithInvoke):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output)
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
pattern = r'\d+\.\d+'
result = re.findall(pattern, raw_output)


@ -69,7 +69,7 @@ class MicroBenchmark(Benchmark):
)
return False
self._result.add_raw_data(metric, result)
self._result.add_raw_data(metric, result, self._args.log_raw_data)
self._result.add_result(metric, statistics.mean(result), reduce_type)
if cal_percentile:
self._process_percentile_result(metric, result, reduce_type)


@ -127,7 +127,7 @@ class RocmGemmFlopsBenchmark(GemmFlopsBenchmark):
True if the raw output string is valid and result can be extracted.
"""
precision = self._precision_need_to_run[cmd_idx]
self._result.add_raw_data('raw_output_' + precision, raw_output)
self._result.add_raw_data('raw_output_' + precision, raw_output, self._args.log_raw_data)
content = raw_output.splitlines()
gflops_index = None


@ -60,7 +60,7 @@ class RocmMemBwBenchmark(MemBwBenchmark):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + self._args.mem_type[cmd_idx], raw_output)
self._result.add_raw_data('raw_output_' + self._args.mem_type[cmd_idx], raw_output, self._args.log_raw_data)
mem_bw = -1
value_index = -1


@ -154,7 +154,7 @@ class TCPConnectivityBenchmark(MicroBenchmark):
True if the raw output string is valid and result can be extracted.
"""
host = self.__hosts[idx]
self._result.add_raw_data('raw_output_' + host, raw_output)
self._result.add_raw_data('raw_output_' + host, raw_output, self._args.log_raw_data)
try:
# If socket error or exception happens on TCPing, add result values as failed


@ -127,7 +127,9 @@ class TensorRTInferenceBenchmark(MicroBenchmarkWithInvoke):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data(f'raw_output_{self._args.pytorch_models[cmd_idx]}', raw_output)
self._result.add_raw_data(
f'raw_output_{self._args.pytorch_models[cmd_idx]}', raw_output, self._args.log_raw_data
)
success = False
try:


@ -400,8 +400,8 @@ class ModelBenchmark(Benchmark):
# The unit of step time is millisecond, use it to calculate the throughput with the unit samples/sec.
millisecond_per_second = 1000
throughput = [millisecond_per_second / step_time * self._args.batch_size for step_time in step_times]
self._result.add_raw_data(metric_s, step_times)
self._result.add_raw_data(metric_t, throughput)
self._result.add_raw_data(metric_s, step_times, self._args.log_raw_data)
self._result.add_raw_data(metric_t, throughput, self._args.log_raw_data)
if model_action == ModelAction.TRAIN:
if not self._sync_result(step_times):


@ -3,6 +3,7 @@
"""A module for unified result of benchmarks."""
import os
import json
from enum import Enum
@ -46,7 +47,7 @@ class BenchmarkResult():
"""
return self.__dict__ == rhs.__dict__
def add_raw_data(self, metric, value):
def add_raw_data(self, metric, value, log_raw_data):
"""Add raw benchmark data into result.
Args:
@ -54,6 +55,7 @@ class BenchmarkResult():
value (str or list): raw benchmark data.
For e2e model benchmarks, its type is list.
For micro-benchmarks or docker-benchmarks, its type is string.
log_raw_data (bool): whether to log raw data into file instead of saving it into result object.
Return:
True if succeed to add the raw data.
@ -64,9 +66,14 @@ class BenchmarkResult():
)
return False
if metric not in self.__raw_data:
self.__raw_data[metric] = list()
self.__raw_data[metric].append(value)
if log_raw_data:
with open(os.path.join(os.getcwd(), 'rawdata.log'), 'a') as f:
f.write('metric:{}\n'.format(metric))
f.write('rawdata:{}\n\n'.format(value))
else:
if metric not in self.__raw_data:
self.__raw_data[metric] = list()
self.__raw_data[metric].append(value)
return True


@ -200,6 +200,8 @@ class SuperBenchExecutor():
benchmark_config = self._sb_benchmarks[benchmark_name]
benchmark_results = list()
self.__create_benchmark_dir(benchmark_name)
cwd = os.getcwd()
os.chdir(self.__get_benchmark_dir(benchmark_name))
monitor = None
if self.__get_rank_id() == 0 and self._sb_monitor_config and self._sb_monitor_config.enable:
@ -243,3 +245,4 @@ class SuperBenchExecutor():
if monitor:
monitor.stop()
self.__write_benchmark_results(benchmark_name, benchmark_results)
os.chdir(cwd)


@ -33,7 +33,7 @@ class FakeDockerBenchmark(DockerBenchmark):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output)
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
pattern = r'\d+\.\d+'
result = re.findall(pattern, raw_output)
if len(result) != 2:


@ -3,6 +3,8 @@
"""Tests for RocmOnnxRuntimeModelBenchmark modules."""
from types import SimpleNamespace
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, Platform, ReturnCode
from superbench.benchmarks.result import BenchmarkResult
@ -20,6 +22,7 @@ def test_rocm_onnxruntime_performance():
assert (benchmark._entrypoint == '/stage/onnxruntime-training-examples/huggingface/azureml/run_benchmark.sh')
assert (benchmark._cmd is None)
benchmark._result = BenchmarkResult(benchmark._name, benchmark._benchmark_type, ReturnCode.SUCCESS)
benchmark._args = SimpleNamespace(log_raw_data=False)
raw_output = """
__superbench__ begin bert-large-uncased ngpu=1


@ -54,7 +54,7 @@ class FakeGemmFlopsBenchmark(GemmFlopsBenchmark):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output)
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
try:
params = raw_output.strip('\n').split('--')


@ -53,7 +53,7 @@ class FakeMemBwBenchmark(MemBwBenchmark):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output)
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
try:
params = raw_output.strip('\n').split(' memory=')


@ -69,7 +69,7 @@ class FakeMicroBenchmarkWithInvoke(MicroBenchmarkWithInvoke):
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output)
self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
pattern = r'\d+\.\d+'
result = re.findall(pattern, raw_output)
if len(result) != 2:


@ -121,7 +121,7 @@ class TensorRTInferenceBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
"""Test tensorrt-inference benchmark result parsing."""
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
benchmark = benchmark_cls(self.benchmark_name, parameters='')
benchmark._args = SimpleNamespace(pytorch_models=['model_0', 'model_1'])
benchmark._args = SimpleNamespace(pytorch_models=['model_0', 'model_1'], log_raw_data=False)
benchmark._result = BenchmarkResult(self.benchmark_name, BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=1)
# Positive case - valid raw output


@ -158,6 +158,8 @@ def test_arguments_related_interfaces():
--duration int The elapsed time of benchmark in seconds.
--force_fp32 Enable option to use full float32 precision.
--hidden_size int Hidden size.
--log_raw_data Log raw data into file instead of saving it into
result object.
--model_action ModelAction [ModelAction ...]
Benchmark model process. E.g. train inference.
--no_gpu Disable GPU training.
@ -192,6 +194,8 @@ def test_preprocess():
--duration int The elapsed time of benchmark in seconds.
--force_fp32 Enable option to use full float32 precision.
--hidden_size int Hidden size.
--log_raw_data Log raw data into file instead of saving it into
result object.
--model_action ModelAction [ModelAction ...]
Benchmark model process. E.g. train inference.
--no_gpu Disable GPU training.


@ -49,7 +49,7 @@ class AccumulationBenchmark(MicroBenchmark):
raw_data.append(str(result))
metric = 'accumulation_result'
self._result.add_raw_data(metric, ','.join(raw_data))
self._result.add_raw_data(metric, ','.join(raw_data), self._args.log_raw_data)
self._result.add_result(metric, result)
return True
@ -114,6 +114,8 @@ def test_get_benchmark_configurable_settings():
expected = """optional arguments:
--duration int The elapsed time of benchmark in seconds.
--log_raw_data Log raw data into file instead of saving it into result
object.
--lower_bound int The lower bound for accumulation.
--run_count int The run count of benchmark.
--upper_bound int The upper bound for accumulation."""


@ -3,6 +3,8 @@
"""Tests for BenchmarkResult module."""
import os
from superbench.benchmarks import BenchmarkType, ReturnCode, ReduceType
from superbench.benchmarks.result import BenchmarkResult
@ -10,22 +12,31 @@ from superbench.benchmarks.result import BenchmarkResult
def test_add_raw_data():
"""Test interface BenchmarkResult.add_raw_data()."""
result = BenchmarkResult('micro', BenchmarkType.MICRO, ReturnCode.SUCCESS)
result.add_raw_data('metric1', 'raw log 1')
result.add_raw_data('metric1', 'raw log 2')
result.add_raw_data('metric1', 'raw log 1', False)
result.add_raw_data('metric1', 'raw log 2', False)
assert (result.raw_data['metric1'][0] == 'raw log 1')
assert (result.raw_data['metric1'][1] == 'raw log 2')
assert (result.type == BenchmarkType.MICRO)
assert (result.return_code == ReturnCode.SUCCESS)
result = BenchmarkResult('model', BenchmarkType.MODEL, ReturnCode.SUCCESS)
result.add_raw_data('metric1', [1, 2, 3])
result.add_raw_data('metric1', [4, 5, 6])
result.add_raw_data('metric1', [1, 2, 3], False)
result.add_raw_data('metric1', [4, 5, 6], False)
assert (result.raw_data['metric1'][0] == [1, 2, 3])
assert (result.raw_data['metric1'][1] == [4, 5, 6])
assert (result.type == BenchmarkType.MODEL)
assert (result.return_code == ReturnCode.SUCCESS)
# Test log_raw_data = True.
result = BenchmarkResult('micro', BenchmarkType.MICRO, ReturnCode.SUCCESS)
result.add_raw_data('metric1', 'raw log 1', True)
result.add_raw_data('metric1', 'raw log 2', True)
assert (result.type == BenchmarkType.MICRO)
assert (result.return_code == ReturnCode.SUCCESS)
raw_data_file = os.path.join(os.getcwd(), 'rawdata.log')
assert (os.path.isfile(raw_data_file))
os.remove(raw_data_file)
def test_add_result():
"""Test interface BenchmarkResult.add_result()."""
@ -73,9 +84,9 @@ def test_serialize_deserialize():
result.add_result('metric1', 300, ReduceType.MAX)
result.add_result('metric1', 200, ReduceType.MAX)
result.add_result('metric2', 100, ReduceType.AVG)
result.add_raw_data('metric1', [1, 2, 3])
result.add_raw_data('metric1', [4, 5, 6])
result.add_raw_data('metric1', [7, 8, 9])
result.add_raw_data('metric1', [1, 2, 3], False)
result.add_raw_data('metric1', [4, 5, 6], False)
result.add_raw_data('metric1', [7, 8, 9], False)
start_time = '2021-02-03 16:59:49'
end_time = '2021-02-03 17:00:08'
result.set_timestamp(start_time, end_time)