Benchmarks: Code Revision - Fix some issues for BERT benchmark. (#58)

Parent: af567cf650
Commit: ce3ed24ab7
examples/benchmarks/pytorch_bert_large.py
@@ -1,22 +1,35 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-"""Model benchmark example for bert-large.
+"""Model benchmark example for bert-large (24-layer, 1024-hidden, 16-heads, 340M parameters).
 
 Commands to run:
   python3 examples/benchmarks/pytorch_bert_large.py (Single GPU)
-  python3 -m torch.distributed.launch --nproc_per_node=8 examples/benchmarks/pytorch_bert_large.py (Distributed)
+  python3 -m torch.distributed.launch --use_env --nproc_per_node=8 examples/benchmarks/pytorch_bert_large.py \
+      --distributed (Distributed)
 """
 
-from superbench.benchmarks import Framework, BenchmarkRegistry
+import argparse
+
+from superbench.benchmarks import Platform, Framework, BenchmarkRegistry
 from superbench.common.utils import logger
 
 if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--distributed', action='store_true', default=False, help='Whether to enable distributed training.'
+    )
+    args = parser.parse_args()
+
+    # Specify the model name and benchmark parameters.
+    model_name = 'bert-large'
+    parameters = '--batch_size 1 --duration 120 --seq_len 128 --precision float32 --run_count 2'
+    if args.distributed:
+        parameters += ' --distributed_impl ddp --distributed_backend nccl'
+
     # Create context for bert-large benchmark and run it for 120 * 2 seconds.
     context = BenchmarkRegistry.create_benchmark_context(
-        'bert-large',
-        parameters='--batch_size 1 --duration 120 --seq_len 8 --precision float32 --run_count 2',
-        framework=Framework.PYTORCH
+        model_name, platform=Platform.CUDA, parameters=parameters, framework=Framework.PYTORCH
    )
 
     benchmark = BenchmarkRegistry.launch_benchmark(context)
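Note on the new launch command: with --use_env, torch.distributed.launch exports each process's local rank via the LOCAL_RANK environment variable instead of appending a --local_rank argument, so the example's own parser (which defines only --distributed) never receives an unknown flag. A minimal standalone sketch of that interaction (hypothetical, not part of this commit):

import argparse
import os

# Parse only the script's own flag, as the example above does.
parser = argparse.ArgumentParser()
parser.add_argument('--distributed', action='store_true', default=False)
args = parser.parse_args()

# Under --use_env the launcher exports LOCAL_RANK; without --use_env it would
# instead append --local_rank=N, which this parser would reject as unknown.
local_rank = int(os.environ.get('LOCAL_RANK', '0'))
print('distributed={}, local_rank={}'.format(args.distributed, local_rank))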
superbench/benchmarks/model_benchmarks/pytorch_bert.py
@@ -171,11 +171,12 @@ class PytorchBERT(PytorchBase):
         self._model.eval()
         while True:
             for idx, sample in enumerate(self._dataloader):
-                torch.cuda.synchronize()
                 start = time.time()
-                sample = sample.cuda()
+                if self._gpu_available:
+                    sample = sample.cuda()
                 self._model(sample)
-                torch.cuda.synchronize()
+                if self._gpu_available:
+                    torch.cuda.synchronize()
                 end = time.time()
                 curr_step += 1
                 if curr_step > self._args.num_warmup:
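The core of this fix: the old loop called sample.cuda() and torch.cuda.synchronize() unconditionally, which raises on machines without a GPU; guarding both calls with self._gpu_available keeps the benchmark runnable on CPU while still synchronizing for honest step timing on GPU. A self-contained sketch of the guarded timing pattern (stand-in names, not the benchmark's real attributes):

import time

import torch

# Stand-ins for the benchmark's self._gpu_available and self._model.
gpu_available = torch.cuda.is_available()
model = torch.nn.Linear(16, 4)
if gpu_available:
    model = model.cuda()

for _ in range(3):
    sample = torch.rand(8, 16)
    start = time.time()
    if gpu_available:
        sample = sample.cuda()
    model(sample)
    if gpu_available:
        # CUDA kernels launch asynchronously; synchronize so the timer
        # captures the device work, not just the kernel launch.
        torch.cuda.synchronize()
    print('step time: {:.3f} ms'.format((time.time() - start) * 1000))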
tests/benchmarks/model_benchmarks/test_pytorch_bert.py
@@ -4,8 +4,8 @@
 """Tests for BERT model benchmarks."""
 
 from tests.helper import decorator
-from superbench.benchmarks import BenchmarkRegistry, Precision, Platform, Framework
-import superbench.benchmarks.model_benchmarks.pytorch_bert as pybert
+from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
+from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
 
 
 @decorator.cuda_test
@@ -15,86 +15,45 @@ def test_pytorch_bert_base():
     context = BenchmarkRegistry.create_benchmark_context(
         'bert-base',
         platform=Platform.CUDA,
-        parameters='--batch_size 32 --num_classes 5 --seq_len 512',
+        parameters='--batch_size 1 --num_classes 5 --seq_len 8 --num_warmup 2 --num_steps 4 \
+            --model_action train inference',
         framework=Framework.PYTORCH
     )
 
     assert (BenchmarkRegistry.is_benchmark_context_valid(context))
 
-    benchmark_name = BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name(context)
-    assert (benchmark_name == 'pytorch-bert-base')
+    benchmark = BenchmarkRegistry.launch_benchmark(context)
 
-    (benchmark_class,
-     predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, context.platform)
-    assert (benchmark_class == pybert.PytorchBERT)
-
-    parameters = context.parameters
-    if predefine_params:
-        parameters = predefine_params + ' ' + parameters
-
-    benchmark = benchmark_class(benchmark_name, parameters)
-    assert (benchmark._preprocess() is True)
-
-    # Predefined parameters of bert-base model.
+    # Check basic information.
+    assert (benchmark)
+    assert (isinstance(benchmark, PytorchBERT))
+    assert (benchmark.name == 'pytorch-bert-base')
+    assert (benchmark.type == BenchmarkType.MODEL)
+
+    # Check predefined parameters of bert-base model.
     assert (benchmark._args.hidden_size == 768)
     assert (benchmark._args.num_hidden_layers == 12)
     assert (benchmark._args.num_attention_heads == 12)
     assert (benchmark._args.intermediate_size == 3072)
 
-    # Parameters from BenchmarkContext.
-    assert (benchmark._args.batch_size == 32)
+    # Check parameters specified in BenchmarkContext.
+    assert (benchmark._args.batch_size == 1)
     assert (benchmark._args.num_classes == 5)
-    assert (benchmark._args.seq_len == 512)
+    assert (benchmark._args.seq_len == 8)
+    assert (benchmark._args.num_warmup == 2)
+    assert (benchmark._args.num_steps == 4)
 
-    # Test Dataset.
+    # Check dataset scale.
     assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size)
 
-    # Test _create_model().
-    assert (benchmark._create_model(Precision.FLOAT32) is True)
-    assert (isinstance(benchmark._model, pybert.BertBenchmarkModel))
-
-
-@decorator.cuda_test
-@decorator.pytorch_test
-def test_pytorch_bert_large():
-    """Test pytorch-bert-large benchmark."""
-    context = BenchmarkRegistry.create_benchmark_context(
-        'bert-large',
-        platform=Platform.CUDA,
-        parameters='--batch_size 32 --num_classes 5 --seq_len 512',
-        framework=Framework.PYTORCH
-    )
-
-    assert (BenchmarkRegistry.is_benchmark_context_valid(context))
-
-    benchmark_name = BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name(context)
-    assert (benchmark_name == 'pytorch-bert-large')
-
-    (benchmark_class,
-     predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, context.platform)
-    assert (benchmark_class is pybert.PytorchBERT)
-
-    parameters = context.parameters
-    if predefine_params:
-        parameters = predefine_params + ' ' + parameters
-
-    benchmark = benchmark_class(benchmark_name, parameters)
-    assert (benchmark._preprocess() is True)
-
-    # Predefined parameters of bert-large model.
-    assert (benchmark._args.hidden_size == 1024)
-    assert (benchmark._args.num_hidden_layers == 24)
-    assert (benchmark._args.num_attention_heads == 16)
-    assert (benchmark._args.intermediate_size == 4096)
-
-    # Parameters from BenchmarkContext.
-    assert (benchmark._args.batch_size == 32)
-    assert (benchmark._args.num_classes == 5)
-    assert (benchmark._args.seq_len == 512)
-
-    # Test Dataset.
-    assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size)
-
-    # Test _create_model().
-    assert (benchmark._create_model(Precision.FLOAT32) is True)
-    assert (isinstance(benchmark._model, pybert.BertBenchmarkModel))
+    # Check results and metrics.
+    assert (benchmark.run_count == 1)
+    assert (benchmark.return_code == ReturnCode.SUCCESS)
+    for metric in [
+        'steptime_train_float32', 'throughput_train_float32', 'steptime_train_float16', 'throughput_train_float16',
+        'steptime_inference_float32', 'throughput_inference_float32', 'steptime_inference_float16',
+        'throughput_inference_float16'
+    ]:
+        assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
+        assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
+        assert (len(benchmark.result[metric]) == benchmark.run_count)
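The rewritten test drives the public entry point (BenchmarkRegistry.launch_benchmark) instead of the registry's name-mangled internals, then checks the shape of the collected metrics. A small sketch of the layout those final assertions encode (the averaging step is an illustrative assumption, not necessarily how SuperBench aggregates):

# raw_data[metric]: one list of per-step values per run;
# result[metric]: one summarized value per run.
run_count, num_steps = 1, 4
raw_data = {'steptime_train_float32': [[30.1, 29.8, 30.3, 30.0] for _ in range(run_count)]}
result = {m: [sum(run) / len(run) for run in runs] for m, runs in raw_data.items()}

assert len(raw_data['steptime_train_float32']) == run_count
assert len(raw_data['steptime_train_float32'][0]) == num_steps
assert len(result['steptime_train_float32']) == run_count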